## Setup Device & Environment

In [30]:
from datasets import load_dataset
import numpy as np

import torch
from transformers import AutoImageProcessor
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

import evaluate

In [31]:
device = torch.device('cuda')

device

device(type='cuda')

## Data Preparation

### Load Dataset

In [32]:
raw_dataset = load_dataset("imagefolder", data_dir="./datasets/chest_xray")

Resolving data files: 100%|██████████| 5216/5216 [00:00<00:00, 23723.41it/s]
Resolving data files: 100%|██████████| 624/624 [00:00<00:00, 312357.76it/s]


In [33]:
print(raw_dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5216
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})


### Setup Labels

In [34]:
labels = labels = raw_dataset["train"].features["label"].names
print(labels)

['NORMAL', 'PNEUMONIA']


In [35]:
label2id, id2label = dict(), dict()

for i, label in enumerate(labels):
  label2id[i] = label
  id2label[label] = i

In [36]:
print(label2id)
print(id2label)

{0: 'NORMAL', 1: 'PNEUMONIA'}
{'NORMAL': 0, 'PNEUMONIA': 1}


### Transforming Data

In [37]:
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

In [38]:
size = (image_processor.size["height"], image_processor.size["width"])
resizer = RandomResizedCrop(size)
normalize = Normalize(image_processor.image_mean, image_processor.image_std)

In [39]:
_transforms = Compose([resizer, ToTensor(), normalize])

In [40]:
def transforms(examples):
  examples["image"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
  # del examples["image"]
  return examples

In [41]:
dataset = raw_dataset.with_transform(transforms)

In [13]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5216
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})


### Preparing metrics for the model

In [14]:
accuracy = evaluate.load("accuracy")

In [15]:
def compute_metrics(eval_pred):
  predictions = np.argmax(eval_pred.predictions, axis=1)
  return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)

## Setting Up Model

In [16]:
from torch import nn
from transformers import ViTForImageClassification

class MyCompositeModel(nn.Module):
  def __init__(self):
    super(MyCompositeModel, self).__init__()
    
    self.pretrained = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    self.my_new_layers = nn.Sequential(
      nn.Linear(1000, 100),
      nn.ReLU(),
      nn.Linear(100, 2)
    )
  
  def forward(self, x):
    x = self.pretrained(x).logits
    x = self.my_new_layers(x)
    return x
  
model = MyCompositeModel()
model = model.to(device)

## Training Setup

### Data Loaders

In [17]:
training_loader = torch.utils.data.DataLoader(dataset['train'], batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset['test'], batch_size=16, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset['test'], batch_size=16, shuffle=True)

### Loss Function

In [18]:
loss_fn = torch.nn.CrossEntropyLoss()

### Optimizer

In [19]:
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.my_new_layers.parameters(), lr=0.001, momentum=0.9)

In [20]:
from torchmetrics import Accuracy

accuracy_metric = Accuracy(task='multiclass', num_classes=2).to(device)

### Single Epoch Training Function

In [21]:
def train_one_epoch(epoch_index, tb_writer):
  running_loss = 0.
  running_accuracy = 0.
  last_loss = 0.

  # Here, we use enumerate(training_loader) instead of
  # iter(training_loader) so that we can track the batch
  # index and do some intra-epoch reporting
  for i, data in enumerate(training_loader):
  
    # Every data instance is an input + label pair
    inputs = data['image'].to(device)
    labels = data['label'].to(device)
  
    # Zero your gradients for every batch!
    optimizer.zero_grad()

    # Make predictions for this batch
    outputs = model(inputs)

    # Compute the loss and its gradients
    loss = loss_fn(outputs, labels)
    loss.backward()
    
    training_accuracy = accuracy_metric(outputs, labels)

    # Adjust learning weights
    optimizer.step()

    # Gather data and report
    running_loss += loss.item()
    running_accuracy += training_accuracy
    
    # print('batch {}', i)
    
    if i % 10 == 0:
      last_loss = running_loss / 10 # loss per batch
      last_accuracy = running_accuracy / 10 # accuracy per batch
      print('  batch {} loss: {} training_accuracy: {}'.format(i + 1, last_loss, last_accuracy))
      tb_x = epoch_index * len(training_loader) + i + 1
      tb_writer.add_scalar('Loss/train', last_loss, tb_x)
      running_loss = 0.
      running_accuracy = 0.
  
  return last_loss

### Many Epochs Training Function

In [22]:
def train_many_epochs(epochs, writer):
  best_vloss = 1_000_000.

  for epoch_number in range(epochs):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    running_vacc = 0.0
    
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
      for i, vdata in enumerate(validation_loader):
        vinputs = vdata['image'].to(device)
        vlabels = vdata['label'].to(device)
        voutputs = model(vinputs)
        
        vloss = loss_fn(voutputs, vlabels)
        running_vloss += vloss
        
        vacc = accuracy_metric(voutputs, vlabels)
        running_vacc += vacc

    avg_vloss = running_vloss / (i + 1)
    avg_vacc = running_vacc / (i + 1)
    print('LOSS train {} valid {} ACCURACY validation {}'.format(avg_loss, avg_vloss, avg_vacc))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars(
      'Training vs. Validation Loss',
      { 'Training' : avg_loss, 'Validation' : avg_vloss },
      epoch_number + 1
    )
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
      best_vloss = avg_vloss
      model_path = 'model_{}_{}'.format(timestamp, epoch_number)
      torch.save(model.state_dict(), model_path)

    epoch_number += 1

## Training Model

In [23]:
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
sum_writer = SummaryWriter('runs/chest_trainer_{}'.format(timestamp))

In [24]:
train_many_epochs(10, sum_writer)

EPOCH 1:
  batch 1 loss: 0.07893961071968078 training_accuracy: 0.01875000074505806
  batch 11 loss: 0.6197352081537246 training_accuracy: 0.637499988079071
  batch 21 loss: 0.5389519900083541 training_accuracy: 0.762499988079071
  batch 31 loss: 0.5545953154563904 training_accuracy: 0.75
  batch 41 loss: 0.5299925386905671 training_accuracy: 0.762499988079071
  batch 51 loss: 0.5513634532690048 training_accuracy: 0.7250000238418579
  batch 61 loss: 0.541490113735199 training_accuracy: 0.737500011920929
  batch 71 loss: 0.5416361808776855 training_accuracy: 0.7125000357627869
  batch 81 loss: 0.5312318652868271 training_accuracy: 0.7250000238418579
  batch 91 loss: 0.4416340708732605 training_accuracy: 0.8062500357627869
  batch 101 loss: 0.5152628064155579 training_accuracy: 0.71875
  batch 111 loss: 0.5291282564401627 training_accuracy: 0.7250000238418579
  batch 121 loss: 0.45343744456768037 training_accuracy: 0.8187500238418579
  batch 131 loss: 0.5031083509325981 training_accuracy

In [106]:
model = model.to('cpu')

for i, image in enumerate(raw_dataset['validation']):
  input = dataset['validation'][i]['image']
  label_true = dataset['validation'][i]['label']
  logits = model(input[None, ...]).detach().numpy()
  label_pred = np.argmax(logits)
  print(i, " - " ,label_pred, label_true)

0  -  1 0
1  -  1 0
2  -  1 0
3  -  1 0
4  -  1 0
5  -  1 0
6  -  0 0
7  -  0 0
8  -  1 1
9  -  1 1
10  -  1 1
11  -  1 1
12  -  1 1
13  -  1 1
14  -  1 1
15  -  1 1


###