In [34]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.transforms import Resize
from transformers import ViTForImageClassification
from datasets import load_metric
import medmnist
from medmnist import INFO, Evaluator

In [35]:
data_flag = 'octmnist'
download = True

NUM_EPOCHS = 3
BATCH_SIZE = 128
lr = 0.001

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

In [36]:
# preprocessing
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

# load the data
train_dataset = DataClass(split='train', transform=data_transform, download=download)
eval_dataset = DataClass(split='val', transform=data_transform, download=download)
test_dataset = DataClass(split='test', transform=data_transform, download=download)

pil_dataset = DataClass(split='train', download=download)

# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

Using downloaded and verified file: /home/zhl038/.medmnist/octmnist.npz
Using downloaded and verified file: /home/zhl038/.medmnist/octmnist.npz
Using downloaded and verified file: /home/zhl038/.medmnist/octmnist.npz
Using downloaded and verified file: /home/zhl038/.medmnist/octmnist.npz


In [37]:
from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

loading feature extractor configuration file https://huggingface.co/google/vit-base-patch16-224-in21k/resolve/main/preprocessor_config.json from cache at /home/zhl038/.cache/huggingface/transformers/7c7f3e780b30eeeacd3962294e5154788caa6d9aa555ed6d5c2f0d2c485eba18.c322cbf30b69973d5aae6c0866f5cba198b5fe51a2fe259d2a506827ec6274bc
Feature extractor ViTFeatureExtractor {
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}



In [38]:
def process_example(example):
    inputs = feature_extractor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs

In [39]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([torch.squeeze(torch.stack((Resize(224)(x[0]),) * 3, axis = 0)) for x in batch]),
        'labels': torch.tensor([x[1] for x in batch])
    }

In [40]:

def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

In [41]:


labels = train_dataset.info['label']
metric = load_metric("accuracy")
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

loading configuration file https://huggingface.co/google/vit-base-patch16-224-in21k/resolve/main/config.json from cache at /home/zhl038/.cache/huggingface/transformers/7bba26dd36a6ff9f6a9b19436dec361727bea03ec70fbfa82b70628109163eaa.92995a56e2eabab0c686015c4ad8275b4f9cbd858ed228f6a08936f2c31667e7
Model config ViTConfig {
  "_name_or_path": "google/vit-base-patch16-224-in21k",
  "architectures": [
    "ViTModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "0",
    "1": "1",
    "2": "2",
    "3": "3"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "0": "0",
    "1": "1",
    "2": "2",
    "3": "3"
  },
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.18.0"
}

loadi

In [42]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=32,
  per_device_eval_batch_size=32,
  evaluation_strategy="steps",
  num_train_epochs=2,
  fp16=True,
  save_steps=500,
  eval_steps=500,
  logging_steps=100,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
#   report_to='tensorboard',
  load_best_model_at_end=True,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [43]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=feature_extractor,
)

Using amp half precision backend


In [None]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

***** Running training *****
  Num examples = 97477
  Num Epochs = 2
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 3048


Step,Training Loss,Validation Loss,Accuracy
500,0.3406,0.407747,0.867891
1000,0.2631,0.254798,0.917836
1500,0.2362,0.255391,0.91562
2000,0.1972,0.221159,0.928545
2500,0.1756,0.203841,0.933807


***** Running Evaluation *****
  Num examples = 10832
  Batch size = 32
Saving model checkpoint to ./vit-base-beans/checkpoint-500
Configuration saved in ./vit-base-beans/checkpoint-500/config.json
Model weights saved in ./vit-base-beans/checkpoint-500/pytorch_model.bin
Feature extractor saved in ./vit-base-beans/checkpoint-500/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 10832
  Batch size = 32
Saving model checkpoint to ./vit-base-beans/checkpoint-1000
Configuration saved in ./vit-base-beans/checkpoint-1000/config.json
Model weights saved in ./vit-base-beans/checkpoint-1000/pytorch_model.bin
Feature extractor saved in ./vit-base-beans/checkpoint-1000/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 10832
  Batch size = 32
Saving model checkpoint to ./vit-base-beans/checkpoint-1500
Configuration saved in ./vit-base-beans/checkpoint-1500/config.json
Model weights saved in ./vit-base-beans/checkpoint-1500/pytorch_model.bin
Feature extr

In [None]:
metrics = trainer.evaluate(test_dataset['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)