In [2]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.transforms import Resize
from transformers import ViTForImageClassification
from datasets import load_metric
import medmnist
from medmnist import INFO, Evaluator

In [3]:
# --- Dataset selection ------------------------------------------------------
data_flag = 'octmnist'   # key used to look the dataset up in medmnist.INFO
download = True          # forwarded to every dataset constructor

# --- Training knobs ---------------------------------------------------------
# NOTE(review): in the visible notebook BATCH_SIZE only sizes the DataLoaders
# and `lr` is not read anywhere; the TrainingArguments cell hard-codes its own
# batch size and learning rate.
NUM_EPOCHS, BATCH_SIZE, lr = 3, 128, 0.001

# --- Metadata derived from the selected dataset -----------------------------
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])  # dataset class, resolved by name
task, n_channels = info['task'], info['n_channels']
n_classes = len(info['label'])

In [4]:
# Preprocessing: convert each image to a tensor, then normalise it with
# mean 0.5 and std 0.5.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5]),
])

# One transformed dataset per split, constructed in train -> val -> test order.
train_dataset, eval_dataset, test_dataset = (
    DataClass(split=split_name, transform=data_transform, download=download)
    for split_name in ('train', 'val', 'test')
)

# The training split once more, this time with no transform attached
# (presumably to get at the raw images -- confirm against its users).
pil_dataset = DataClass(split='train', download=download)

# DataLoader wrappers: a shuffled one for training and two ordered ones, at
# twice the batch size, for scoring the train and test splits.
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(train_dataset, batch_size=2 * BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(test_dataset, batch_size=2 * BATCH_SIZE, shuffle=False)

Using downloaded and verified file: /home/zhl038/.medmnist/octmnist.npz
Using downloaded and verified file: /home/zhl038/.medmnist/octmnist.npz
Using downloaded and verified file: /home/zhl038/.medmnist/octmnist.npz
Using downloaded and verified file: /home/zhl038/.medmnist/octmnist.npz


In [5]:
from transformers import ViTFeatureExtractor

# Checkpoint identifier; the same name is passed to the model's from_pretrained call.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# Feature extractor built from that checkpoint; it is handed to the Trainer
# (as `tokenizer`) and referenced by `process_example`.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

In [6]:
def process_example(example):
    """Run one example through the ViT feature extractor and attach its label.

    Args:
        example: Mapping with an ``'image'`` entry, which is passed to
            ``feature_extractor``, and a ``'labels'`` entry.

    Returns:
        The object returned by ``feature_extractor(..., return_tensors='pt')``
        with ``example['labels']`` stored under its ``'labels'`` key.

    NOTE(review): nothing in the visible notebook calls this helper; batches
    are assembled by ``collate_fn`` instead.
    """
    encoded = feature_extractor(example['image'], return_tensors='pt')
    encoded['labels'] = example['labels']
    return encoded

In [7]:
def collate_fn(batch, resize=None):
    """Collate ``(image, label)`` samples into the dict passed to the model.

    Args:
        batch: Sequence of ``(image, label)`` pairs. Each image must be a
            channel-first tensor ``(C, H, W)``; each label is a scalar or a
            NumPy array.
        resize: Optional callable applied to every image before the channel
            handling. Defaults to ``Resize(224)``.

    Returns:
        dict with two entries:
            ``'pixel_values'``: the resized images stacked along a new leading
                batch dimension; single-channel images are replicated to three
                channels, images with more channels are passed through as is.
            ``'labels'``: the labels stacked into one tensor whose leading
                dimension is the batch size (dtype follows NumPy's inference).
    """
    if resize is None:
        # Build the transform once per batch instead of once per sample.
        resize = Resize(224)

    def _to_three_channels(image):
        """Resize ``image`` and replicate a lone grey channel three times."""
        image = resize(image)
        # Only touch the channel axis: a blanket squeeze() would also drop any
        # other size-1 dimension, and stacking three copies of an image that
        # is already RGB would produce a (3, 3, H, W) sample.
        if image.shape[0] == 1:
            image = image.expand(3, -1, -1)
        return image

    pixel_values = torch.stack([_to_three_channels(sample[0]) for sample in batch])
    # Stack the labels in NumPy first: torch.tensor() on a list of ndarrays
    # converts element by element and emits a slow-path UserWarning.
    labels = torch.as_tensor(np.asarray([sample[1] for sample in batch]))
    return {'pixel_values': pixel_values, 'labels': labels}

In [8]:

def compute_metrics(p):
    """Score a set of predictions with the module-level ``metric``.

    Args:
        p: Object exposing ``predictions`` (per-class scores; the arg-max over
            axis 1 is taken as the predicted class) and ``label_ids``.

    Returns:
        Whatever ``metric.compute`` returns for those predictions/references.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)

In [9]:


# MedMNIST stores the label table as a dict keyed by the class index as a
# string ('0', '1', ...) whose values are the human-readable class names.
labels = train_dataset.info['label']
metric = load_metric("accuracy")
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    # Iterate over .items(): enumerating the dict directly yields only its
    # keys, which recorded '0', '1', ... as the class names in the config.
    id2label={str(idx): name for idx, name in labels.items()},
    label2id={name: str(idx) for idx, name in labels.items()},
)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    # Output location.
    # NOTE(review): the directory name mentions "beans" although the notebook
    # trains on the dataset selected by `data_flag`.
    output_dir="./vit-base-beans",

    # Optimisation.
    num_train_epochs=NUM_EPOCHS,
    learning_rate=2e-4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    fp16=True,

    # Evaluate and checkpoint on the same 500-step cadence, keep at most two
    # checkpoints, and restore the best one when training ends.
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    load_best_model_at_end=True,

    # Logging / publishing (the tensorboard reporter is left disabled).
    logging_steps=100,
    push_to_hub=False,
    # report_to='tensorboard',

    # Presumably needed because the datasets are not Hugging Face datasets and
    # batches are built by a custom collator -- confirm.
    remove_unused_columns=False,
)

In [11]:
from transformers import Trainer

# Wire the model, its training configuration, the train/validation splits and
# the batching / scoring helpers into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    # Data: training split and the 'val' split used for periodic evaluation.
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Batch assembly and metric computation defined earlier in the notebook.
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # The feature extractor is passed through the `tokenizer` slot.
    tokenizer=feature_extractor,
)

Using amp half precision backend


In [12]:
# Fine-tune the model; the returned object carries the run's metrics.
train_results = trainer.train()
# Save the model through the trainer.
trainer.save_model()
# Report the training metrics, then persist them together with the trainer state.
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

***** Running training *****
  Num examples = 97477
  Num Epochs = 3
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 4572
  'labels': torch.tensor([x[1] for x in batch])


Step,Training Loss,Validation Loss,Accuracy
500,0.3362,0.338959,0.885524
1000,0.2657,0.261041,0.916728
1500,0.2509,0.276622,0.907866
2000,0.2086,0.22109,0.927806
2500,0.1984,0.20441,0.933623
3000,0.1865,0.197979,0.936115
3500,0.1336,0.199917,0.937131
4000,0.1241,0.177992,0.941562
4500,0.103,0.172229,0.945716


***** Running Evaluation *****
  Num examples = 10832
  Batch size = 32
Saving model checkpoint to ./vit-base-beans/checkpoint-500
Configuration saved in ./vit-base-beans/checkpoint-500/config.json
Model weights saved in ./vit-base-beans/checkpoint-500/pytorch_model.bin
Feature extractor saved in ./vit-base-beans/checkpoint-500/preprocessor_config.json
Deleting older checkpoint [vit-base-beans/checkpoint-2500] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 10832
  Batch size = 32
Saving model checkpoint to ./vit-base-beans/checkpoint-1000
Configuration saved in ./vit-base-beans/checkpoint-1000/config.json
Model weights saved in ./vit-base-beans/checkpoint-1000/pytorch_model.bin
Feature extractor saved in ./vit-base-beans/checkpoint-1000/preprocessor_config.json
Deleting older checkpoint [vit-base-beans/checkpoint-3000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 10832
  Batch size = 32
Saving model checkpoint to ./vit-base-

***** train metrics *****
  epoch                    =           3.0
  total_flos               = 21105135067GF
  train_loss               =         0.214
  train_runtime            =    0:47:34.07
  train_samples_per_second =       102.461
  train_steps_per_second   =         1.602


In [13]:
# Score the fine-tuned model on the 'test' split.
# NOTE(review): these test-split numbers are logged and saved under the
# "eval" name -- confirm that label is intended.
metrics = trainer.evaluate(test_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** Running Evaluation *****
  Num examples = 1000
  Batch size = 32


***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =      0.804
  eval_loss               =     0.6555
  eval_runtime            = 0:00:04.41
  eval_samples_per_second =    226.341
  eval_steps_per_second   =      3.621
