## Data Preparation

In [25]:
from datasets import load_dataset
import numpy as np

import evaluate

### Load Dataset

In [26]:
# Read the chest X-ray images from disk with the generic "imagefolder" loader;
# the result is a DatasetDict with one entry per split found under data_dir.
dataset = load_dataset(
  "imagefolder",
  data_dir="./datasets/chest_xray",
)

Resolving data files: 100%|██████████| 5216/5216 [00:00<00:00, 16098.77it/s]
Resolving data files: 100%|██████████| 624/624 [00:00<00:00, 312097.03it/s]


In [27]:
# Show the available splits together with their features and row counts.
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5216
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})


### Setup Labels

In [28]:
# Class names taken from the `label` feature of the training split.
# (The original had a duplicated `labels = labels = ...` assignment target.)
labels = dataset["train"].features["label"].names
print(labels)

['NORMAL', 'PNEUMONIA']


In [29]:
label2id, id2label = dict(), dict()

for i, label in enumerate(labels):
  label2id[i] = label
  id2label[label] = i

In [30]:
print(label2id)
print(id2label)

{0: 'NORMAL', 1: 'PNEUMONIA'}
{'NORMAL': 0, 'PNEUMONIA': 1}


### Transforming Data

In [31]:
from transformers import AutoImageProcessor
from torchvision.transforms import RandomResizedCrop, Resize, Compose, Normalize, ToTensor

In [32]:
# Preprocessing settings shipped with the ViT checkpoint; the cells below read
# its target size and its normalisation mean/std.
image_processor = AutoImageProcessor.from_pretrained(
  "google/vit-base-patch16-224-in21k"
)

In [33]:
size = (image_processor.size["height"], image_processor.size["width"])
resizer = RandomResizedCrop(size)
normalize = Normalize(image_processor.image_mean, image_processor.image_std)

In [34]:
_transforms = Compose([resizer, ToTensor(), normalize])

In [35]:
def transforms(examples):
  examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
  del examples["image"]
  return examples

In [36]:
dataset  = dataset.with_transform(transforms)

In [37]:
print(dataset['train'])

Dataset({
    features: ['image', 'label'],
    num_rows: 5216
})


### Preparing metrics for the model

In [38]:
# Load the "accuracy" metric from the `evaluate` library; compute_metrics uses it.
accuracy = evaluate.load("accuracy")

In [39]:
def compute_metrics(eval_pred):
  """Score one evaluation pass with the `accuracy` metric.

  Args:
    eval_pred: object exposing `predictions` (per-class scores, one row per
      example) and `label_ids` (the reference class ids).

  Returns:
    Whatever `accuracy.compute` returns for the arg-max class of each row
    compared against the reference ids.
  """
  scores, references = eval_pred.predictions, eval_pred.label_ids
  predicted_ids = np.argmax(scores, axis=1)  # highest-scoring class per example
  return accuracy.compute(predictions=predicted_ids, references=references)

### Setting Up Model

In [40]:
from transformers import AutoModelForImageClassification

# Load the pretrained ViT checkpoint for image classification, with one output
# per entry in `labels` and the name/id lookup tables attached. The load message
# below shows the classifier weights are newly initialised, so the model must be
# fine-tuned before use.
model = AutoModelForImageClassification.from_pretrained(
  "google/vit-base-patch16-224-in21k",
  num_labels=len(labels),
  id2label=id2label,
  label2id=label2id
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [41]:
# Move the model to the GPU.
# NOTE(review): the device is hard-coded to "cuda", so this cell presumably
# fails on a machine without a CUDA device — confirm that is acceptable.
model = model.to("cuda")

In [42]:
# Display the device the model currently lives on.
model.device

device(type='cuda', index=0)

### Training The Model

In [43]:
from transformers import TrainingArguments
from transformers import Trainer
from transformers import DefaultDataCollator

In [44]:
# Hyper-parameters and bookkeeping options for the Trainer.
training_args = TrainingArguments(
  output_dir = "pneumonia_model",       # directory name for this run's outputs
  evaluation_strategy="epoch",          # evaluate once per epoch ...
  save_strategy="epoch",                # ... and checkpoint on the same schedule
  learning_rate=5e-5,
  per_device_train_batch_size=12,
  per_device_eval_batch_size=12,
  num_train_epochs=1,
  load_best_model_at_end=True,          # reload the best checkpoint when training ends
  metric_for_best_model="accuracy",     # "best" is judged by the accuracy from compute_metrics
  remove_unused_columns=False,          # keep the "image" column, which `transforms` reads to build pixel_values
)

In [45]:
# Wire model, arguments, data and metric together.
# NOTE(review): evaluation (and therefore best-model selection) uses the "test"
# split rather than the 16-row "validation" split — confirm this is intended,
# since it means the test set is also used for model selection.
trainer = Trainer(
  model=model,
  args=training_args,
  data_collator=DefaultDataCollator(),
  train_dataset=dataset["train"],
  eval_dataset=dataset["test"],
  tokenizer=image_processor,
  compute_metrics=compute_metrics
)

In [47]:
# Fine-tune the model; with the arguments above this runs a single epoch and
# evaluates at the end of it.
trainer.train()

                                                 
100%|██████████| 435/435 [03:28<00:00,  2.76it/s]

{'eval_loss': 0.39118561148643494, 'eval_accuracy': 0.8846153846153846, 'eval_runtime': 12.8126, 'eval_samples_per_second': 48.702, 'eval_steps_per_second': 4.058, 'epoch': 1.0}


100%|██████████| 435/435 [03:29<00:00,  2.08it/s]

{'train_runtime': 209.3945, 'train_samples_per_second': 24.91, 'train_steps_per_second': 2.077, 'train_loss': 0.2167022705078125, 'epoch': 1.0}





TrainOutput(global_step=435, training_loss=0.2167022705078125, metrics={'train_runtime': 209.3945, 'train_samples_per_second': 24.91, 'train_steps_per_second': 2.077, 'train_loss': 0.2167022705078125, 'epoch': 1.0})

In [49]:
# Evaluate the trained model on the Trainer's eval_dataset and display the metrics.
trainer.evaluate()

100%|██████████| 52/52 [00:12<00:00,  4.29it/s]


{'eval_loss': 0.4036100506782532,
 'eval_accuracy': 0.8733974358974359,
 'eval_runtime': 13.0143,
 'eval_samples_per_second': 47.947,
 'eval_steps_per_second': 3.996,
 'epoch': 1.0}

In [50]:
# Slice out every row of the validation split as a dict of columns; the cells
# below read its 'label' and 'pixel_values' entries.
dataset_val = dataset['validation'][:]

# Reference labels for the validation images, for comparison with the predictions below.
dataset_val['label']

[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]

In [None]:
# image = dataset_val["image"]

# dataset_val['pixel_values']

In [51]:
# Run inference on the CPU, in eval mode so that training-only behaviour such as
# dropout is switched off.
model.to('cpu')
model.eval()

# One forward pass per validation image. Iterating over the column directly
# removes the hard-coded sample count of 16.
for image in dataset_val["pixel_values"]:
  pred = model(image[None, ...])  # add a leading batch dimension of size 1

  logits = pred.logits.detach().numpy()[0]
  pred_class = np.argmax(logits)  # index of the highest-scoring class

  print(logits, pred_class)

[-1.6577818  1.7506535] 1
[-0.6380501   0.66631216] 1
[-2.0885642  2.1848955] 1
[-1.3128109  1.3182287] 1
[-2.1820817  2.316295 ] 1
[-2.0649233  2.1608455] 1
[ 1.012402   -0.93442696] 0
[ 1.6163981 -1.7511051] 0
[-2.1520638  2.2770422] 1
[-2.1702387  2.3037913] 1
[-2.1235359  2.2995427] 1
[-2.1038806  2.2296224] 1
[-2.1452243  2.2427146] 1
[-2.1580727  2.3234913] 1
[-2.1790838  2.2776232] 1
[-2.0695882  2.1638727] 1


In [74]:
# Take the first 16 rows of the test split after a shuffle with a fixed seed
# (seed=1), then display their reference labels.
dataset_test_part = dataset['test'].shuffle(seed=1)[:16]
dataset_test_part['label']

[0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1]

In [75]:
# Run inference on the CPU, in eval mode so that training-only behaviour such as
# dropout is switched off.
model.to('cpu')
model.eval()

# One forward pass per sampled test image. Iterating over the column directly
# removes the hard-coded sample count of 16.
for image in dataset_test_part["pixel_values"]:
  pred = model(image[None, ...])  # add a leading batch dimension of size 1

  logits = pred.logits.detach().numpy()[0]
  pred_class = np.argmax(logits)  # index of the highest-scoring class

  print(logits, pred_class)

[ 1.2149278 -1.4033091] 0
[-2.1667926  2.2876368] 1
[-2.167377   2.2972054] 1
[-1.926069   2.1151662] 1
[-2.1040063  2.204447 ] 1
[ 1.6650624 -1.7860425] 0
[-1.0361713  1.1633366] 1
[ 0.21347035 -0.13928215] 0
[ 1.5281473 -1.6495514] 0
[-2.0556026  2.25637  ] 1
[-1.5510341  1.6027886] 1
[-2.1792178  2.293909 ] 1
[-2.176153   2.3182933] 1
[-1.9805977  2.1227136] 1
[-1.0566965  1.0367076] 1
[-1.5449934  1.7151619] 1


###