## Data Preparation

In [1]:
from datasets import load_dataset
import numpy as np

import evaluate

  from .autonotebook import tqdm as notebook_tqdm


### Load Dataset

In [2]:
dataset = load_dataset("imagefolder", data_dir="./datasets/chest_xray")

Resolving data files: 100%|██████████| 5216/5216 [00:00<00:00, 17045.74it/s]
Resolving data files: 100%|██████████| 624/624 [00:00<00:00, 312320.49it/s]


In [3]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5216
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})


### Setup Labels

In [4]:
labels = labels = dataset["train"].features["label"].names
print(labels)

['NORMAL', 'PNEUMONIA']


In [5]:
label2id, id2label = dict(), dict()

for i, label in enumerate(labels):
  label2id[i] = label
  id2label[label] = i

In [6]:
print(label2id)
print(id2label)

{0: 'NORMAL', 1: 'PNEUMONIA'}
{'NORMAL': 0, 'PNEUMONIA': 1}


### Transforming Data

In [7]:
from transformers import AutoImageProcessor
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

In [8]:
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

In [9]:
size = (image_processor.size["height"], image_processor.size["width"])
resizer = RandomResizedCrop(size)
normalize = Normalize(image_processor.image_mean, image_processor.image_std)

In [10]:
_transforms = Compose([resizer, ToTensor(), normalize])

In [11]:
def transforms(examples):
  examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
  del examples["image"]
  return examples

In [12]:
dataset  = dataset.with_transform(transforms)

In [13]:
print(dataset['train'])

Dataset({
    features: ['image', 'label'],
    num_rows: 5216
})


### Preparing metrics for the model

In [14]:
accuracy = evaluate.load("accuracy")

In [15]:
def compute_metrics(eval_pred):
  predictions = np.argmax(eval_pred.predictions, axis=1)
  return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)

### Setting Up Model

In [16]:
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
  "google/vit-base-patch16-224-in21k",
  num_labels=len(labels),
  id2label=id2label,
  label2id=label2id
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
model = model.to("cuda")

In [18]:
model.device

device(type='cuda', index=0)

### Training The Model

In [19]:
from transformers import TrainingArguments
from transformers import Trainer
from transformers import DefaultDataCollator

In [20]:
training_args = TrainingArguments(
  output_dir = "pneumonia_model",
  evaluation_strategy="epoch",
  save_strategy="epoch",
  learning_rate=5e-5,
  per_device_train_batch_size=12,
  per_device_eval_batch_size=12,
  num_train_epochs=10,
  load_best_model_at_end=True,
  metric_for_best_model="accuracy",
  remove_unused_columns=False,
)

In [21]:
trainer = Trainer(
  model=model,
  args=training_args,
  data_collator=DefaultDataCollator(),
  train_dataset=dataset["train"],
  eval_dataset=dataset["test"],
  tokenizer=image_processor,
  compute_metrics=compute_metrics
)

In [22]:
trainer.train()

                                                  
 10%|█         | 435/4350 [03:13<23:10,  2.82it/s]

{'eval_loss': 0.37469497323036194, 'eval_accuracy': 0.8798076923076923, 'eval_runtime': 14.3549, 'eval_samples_per_second': 43.469, 'eval_steps_per_second': 3.622, 'epoch': 1.0}


 11%|█▏        | 500/4350 [03:40<26:16,  2.44it/s]  

{'loss': 0.248, 'learning_rate': 4.4252873563218394e-05, 'epoch': 1.15}


                                                  
 20%|██        | 870/4350 [06:20<19:38,  2.95it/s]

{'eval_loss': 0.31853964924812317, 'eval_accuracy': 0.9150641025641025, 'eval_runtime': 11.5778, 'eval_samples_per_second': 53.896, 'eval_steps_per_second': 4.491, 'epoch': 2.0}


 23%|██▎       | 1000/4350 [07:15<21:51,  2.55it/s] 

{'loss': 0.1716, 'learning_rate': 3.850574712643678e-05, 'epoch': 2.3}


                                                   
 30%|███       | 1305/4350 [09:32<16:52,  3.01it/s]

{'eval_loss': 0.31958839297294617, 'eval_accuracy': 0.8990384615384616, 'eval_runtime': 11.9235, 'eval_samples_per_second': 52.334, 'eval_steps_per_second': 4.361, 'epoch': 3.0}


 34%|███▍      | 1500/4350 [10:48<18:50,  2.52it/s]  

{'loss': 0.1357, 'learning_rate': 3.275862068965517e-05, 'epoch': 3.45}


                                                   
 40%|████      | 1740/4350 [12:34<14:34,  2.99it/s]

{'eval_loss': 0.25209927558898926, 'eval_accuracy': 0.9246794871794872, 'eval_runtime': 11.3581, 'eval_samples_per_second': 54.939, 'eval_steps_per_second': 4.578, 'epoch': 4.0}


 46%|████▌     | 2000/4350 [14:16<14:42,  2.66it/s]  

{'loss': 0.124, 'learning_rate': 2.7011494252873566e-05, 'epoch': 4.6}


                                                   
 50%|█████     | 2175/4350 [15:36<11:25,  3.17it/s]

{'eval_loss': 0.32331418991088867, 'eval_accuracy': 0.9038461538461539, 'eval_runtime': 11.463, 'eval_samples_per_second': 54.436, 'eval_steps_per_second': 4.536, 'epoch': 5.0}


 57%|█████▋    | 2500/4350 [17:44<12:04,  2.55it/s]  

{'loss': 0.1144, 'learning_rate': 2.1264367816091954e-05, 'epoch': 5.75}


                                                   
 60%|██████    | 2610/4350 [18:38<08:54,  3.26it/s]

{'eval_loss': 0.27962639927864075, 'eval_accuracy': 0.8974358974358975, 'eval_runtime': 11.4941, 'eval_samples_per_second': 54.289, 'eval_steps_per_second': 4.524, 'epoch': 6.0}


 69%|██████▉   | 3000/4350 [21:11<09:06,  2.47it/s]  

{'loss': 0.0987, 'learning_rate': 1.5517241379310346e-05, 'epoch': 6.9}


                                                   
 70%|███████   | 3045/4350 [21:39<06:39,  3.26it/s]

{'eval_loss': 0.214422345161438, 'eval_accuracy': 0.9342948717948718, 'eval_runtime': 11.2741, 'eval_samples_per_second': 55.348, 'eval_steps_per_second': 4.612, 'epoch': 7.0}


                                                     
 80%|████████  | 3480/4350 [24:41<04:37,  3.14it/s]

{'eval_loss': 0.2941562533378601, 'eval_accuracy': 0.9118589743589743, 'eval_runtime': 11.402, 'eval_samples_per_second': 54.727, 'eval_steps_per_second': 4.561, 'epoch': 8.0}


 80%|████████  | 3500/4350 [24:49<05:23,  2.62it/s]

{'loss': 0.0963, 'learning_rate': 9.770114942528738e-06, 'epoch': 8.05}


                                                   
 90%|█████████ | 3915/4350 [27:42<02:18,  3.13it/s]

{'eval_loss': 0.34147024154663086, 'eval_accuracy': 0.907051282051282, 'eval_runtime': 11.3741, 'eval_samples_per_second': 54.861, 'eval_steps_per_second': 4.572, 'epoch': 9.0}


 92%|█████████▏| 4000/4350 [28:17<02:14,  2.60it/s]

{'loss': 0.0817, 'learning_rate': 4.022988505747127e-06, 'epoch': 9.2}


                                                      
100%|██████████| 4350/4350 [45:06<00:00,  2.70it/s]

{'eval_loss': 0.2847822308540344, 'eval_accuracy': 0.9118589743589743, 'eval_runtime': 13.1616, 'eval_samples_per_second': 47.411, 'eval_steps_per_second': 3.951, 'epoch': 10.0}


100%|██████████| 4350/4350 [45:07<00:00,  1.61it/s]

{'train_runtime': 2707.5956, 'train_samples_per_second': 19.264, 'train_steps_per_second': 1.607, 'train_loss': 0.1291817969837408, 'epoch': 10.0}





TrainOutput(global_step=4350, training_loss=0.1291817969837408, metrics={'train_runtime': 2707.5956, 'train_samples_per_second': 19.264, 'train_steps_per_second': 1.607, 'train_loss': 0.1291817969837408, 'epoch': 10.0})

In [23]:
dataset_val = dataset['validation'][:]

dataset_val['label']

[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]

In [24]:
# image = dataset_val["image"]

# dataset_val['pixel_values']

In [25]:
model.to('cpu')

for i in range(0, 16):  
  image = dataset_val["pixel_values"][i]
  
  pred = model(image[None, ...])
  
  logits = pred.logits.detach().numpy()[0]
  pred_class = np.argmax(logits)
  
  print(logits, pred_class)

[ 1.0386438 -0.986711 ] 0
[-0.13800992  0.19367892] 1
[ 0.20198247 -0.15564975] 0
[-0.95488036  0.848078  ] 1
[-1.2200586  1.0992448] 1
[-0.5879604  0.5493035] 1
[ 1.4296321 -1.0805426] 0
[ 2.228238  -2.0731947] 0
[-2.9751143  2.8376467] 1
[-2.969958  2.867535] 1
[-1.8585757  1.7708261] 1
[-2.1129858  2.0069702] 1
[-2.2598677  2.1570873] 1
[-2.9136753  2.78434  ] 1
[-2.8000977  2.695844 ] 1
[-2.7378416  2.681501 ] 1


###