## Setup Device & Environment

In [1]:
from datasets import load_dataset
import numpy as np

import torch
from transformers import AutoImageProcessor
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda')

device

device(type='cuda')

## Data Preparation

### Load Dataset

In [3]:
raw_dataset = load_dataset("imagefolder", data_dir="./datasets/chest_xray")

Resolving data files: 100%|██████████| 5216/5216 [00:00<00:00, 15691.73it/s]
Resolving data files: 100%|██████████| 624/624 [00:00<00:00, 208230.22it/s]


In [4]:
print(raw_dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5216
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})


### Setup Labels

In [5]:
labels = labels = raw_dataset["train"].features["label"].names
print(labels)

['NORMAL', 'PNEUMONIA']


In [6]:
label2id, id2label = dict(), dict()

for i, label in enumerate(labels):
  label2id[i] = label
  id2label[label] = i

In [7]:
print(label2id)
print(id2label)

{0: 'NORMAL', 1: 'PNEUMONIA'}
{'NORMAL': 0, 'PNEUMONIA': 1}


### Transforming Data

In [8]:
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

In [9]:
size = (image_processor.size["height"], image_processor.size["width"])
resizer = RandomResizedCrop(size)
normalize = Normalize(image_processor.image_mean, image_processor.image_std)

In [10]:
_transforms = Compose([resizer, ToTensor(), normalize])

In [11]:
def transforms(examples):
  examples["image"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
  # del examples["image"]
  return examples

In [12]:
dataset = raw_dataset.with_transform(transforms)

In [13]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5216
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})


### Preparing metrics for the model

In [14]:
accuracy = evaluate.load("accuracy")

In [15]:
def compute_metrics(eval_pred):
  predictions = np.argmax(eval_pred.predictions, axis=1)
  return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)

## Setting Up Model

In [16]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

In [17]:
from torch import nn
from transformers import ViTForImageClassification

class MyCompositeModel(nn.Module):
  def __init__(self):
    super(MyCompositeModel, self).__init__()
    
    self.pretrained = ViTForImageClassification.from_pretrained(
      "google/vit-base-patch16-224",
      num_labels=1000
    )
    self.my_new_layers = nn.Sequential(
      nn.LayerNorm(1000),
      nn.Linear(1000, 64),
      nn.ReLU(),
      nn.Linear(64, 128),
      nn.ReLU(),
      nn.Linear(128, 64),
      nn.ReLU(),
      nn.Linear(64, 2)
    )
  
  def forward(self, x):
    x = self.pretrained(x).logits
    x = self.my_new_layers(x)
    return x
  
model = MyCompositeModel()
model = model.to(device)

get_n_params(model.my_new_layers)

82770

## Training Setup

### Data Loaders

In [18]:
training_loader = torch.utils.data.DataLoader(dataset['train'], batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset['test'], batch_size=16, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset['test'], batch_size=16, shuffle=True)

### Loss Function

In [19]:
loss_fn = torch.nn.CrossEntropyLoss()

### Optimizer

In [20]:
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.my_new_layers.parameters(), lr=0.001, momentum=0.9)

In [21]:
from torchmetrics import Accuracy

accuracy_metric = Accuracy(task='multiclass', num_classes=2).to(device)

### Single Epoch Training Function

In [22]:
def train_one_epoch(epoch_index, tb_writer, logging_frequency):
  running_loss = 0.
  running_accuracy = 0.
  last_loss = 0.

  # Here, we use enumerate(training_loader) instead of
  # iter(training_loader) so that we can track the batch
  # index and do some intra-epoch reporting
  for i, data in enumerate(training_loader):
  
    # Every data instance is an input + label pair
    inputs = data['image'].to(device)
    labels = data['label'].to(device)
  
    # Zero your gradients for every batch!
    optimizer.zero_grad()

    # Make predictions for this batch
    outputs = model(inputs)

    # Compute the loss and its gradients
    loss = loss_fn(outputs, labels)
    loss.backward()
    
    training_accuracy = accuracy_metric(outputs, labels)

    # Adjust learning weights
    optimizer.step()

    # Gather data and report
    running_loss += loss.item()
    running_accuracy += training_accuracy
    
    # print('batch {}', i)
    
    if (i+1) % logging_frequency == 0:
      last_loss = running_loss / logging_frequency # loss per batch
      last_accuracy = running_accuracy / logging_frequency # accuracy per batch
      print('  batch {} loss: {} training_accuracy: {}'.format(i + 1, last_loss, last_accuracy))
      tb_x = epoch_index * len(training_loader) + i + 1
      tb_writer.add_scalar('Loss/train', last_loss, tb_x)
      running_loss = 0.
      running_accuracy = 0.
  
  return last_loss

### Many Epochs Training Function

In [23]:
def train_many_epochs(epochs, writer, logging_frequency):
  best_vloss = 1_000_000.

  for epoch_number in range(epochs):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer, logging_frequency)

    running_vloss = 0.0
    running_vacc = 0.0
    
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
      for i, vdata in enumerate(validation_loader):
        vinputs = vdata['image'].to(device)
        vlabels = vdata['label'].to(device)
        voutputs = model(vinputs)
        
        vloss = loss_fn(voutputs, vlabels)
        running_vloss += vloss
        
        vacc = accuracy_metric(voutputs, vlabels)
        running_vacc += vacc

    avg_vloss = running_vloss / (i + 1)
    avg_vacc = running_vacc / (i + 1)
    print('LOSS train {} valid {} ACCURACY validation {}'.format(avg_loss, avg_vloss, avg_vacc))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars(
      'Training vs. Validation Loss',
      { 'Training' : avg_loss, 'Validation' : avg_vloss },
      epoch_number + 1
    )
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
      best_vloss = avg_vloss
      model_path = 'model_{}_{}'.format(timestamp, epoch_number)
      torch.save(model.state_dict(), model_path)

    epoch_number += 1

## Training Model

In [24]:
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
sum_writer = SummaryWriter('runs/chest_trainer_{}'.format(timestamp))

In [25]:
train_many_epochs(2, sum_writer, 50)

EPOCH 1:
  batch 50 loss: 0.6328732705116272 training_accuracy: 0.7425000071525574
  batch 100 loss: 0.5719836354255676 training_accuracy: 0.7425000071525574
  batch 150 loss: 0.5658124047517776 training_accuracy: 0.7387499809265137
  batch 200 loss: 0.5604046785831451 training_accuracy: 0.731249988079071
  batch 250 loss: 0.538584772348404 training_accuracy: 0.7400000095367432
  batch 300 loss: 0.5009221410751343 training_accuracy: 0.7674999833106995
LOSS train 0.5009221410751343 valid 0.6239355802536011 ACCURACY validation 0.625
EPOCH 2:
  batch 50 loss: 0.5129212892055511 training_accuracy: 0.7137500047683716
  batch 100 loss: 0.4810864895582199 training_accuracy: 0.7299999594688416
  batch 150 loss: 0.44143265306949614 training_accuracy: 0.8025000095367432
  batch 200 loss: 0.38980439096689223 training_accuracy: 0.8462499976158142
  batch 250 loss: 0.38685330003499985 training_accuracy: 0.8424999713897705
  batch 300 loss: 0.34139414593577383 training_accuracy: 0.85999995470047
LOS

In [26]:
model = model.to('cpu')
model.eval()

for i, image in enumerate(raw_dataset['validation']):
  input = dataset['validation'][i]['image']
  label_true = dataset['validation'][i]['label']
  logits = model(input[None, ...]).detach().numpy()
  label_pred = np.argmax(logits)
  print(i, " - " ,label_pred, label_true)

0  -  1 0
1  -  1 0
2  -  1 0
3  -  1 0
4  -  0 0
5  -  1 0
6  -  1 0
7  -  1 0
8  -  1 1
9  -  1 1
10  -  1 1
11  -  1 1
12  -  1 1
13  -  1 1
14  -  1 1
15  -  1 1


In [28]:
model = model.to(device)
train_many_epochs(4, sum_writer, 25)

EPOCH 1:
  batch 25 loss: 0.27799043387174605 training_accuracy: 0.8774999976158142
  batch 50 loss: 0.28583341389894484 training_accuracy: 0.8924999833106995
  batch 75 loss: 0.3371810042858124 training_accuracy: 0.8424999713897705
  batch 100 loss: 0.25420293360948565 training_accuracy: 0.8949999809265137
  batch 125 loss: 0.267878472507 training_accuracy: 0.8924999833106995
  batch 150 loss: 0.27907154023647307 training_accuracy: 0.8849999904632568
  batch 175 loss: 0.23940888077020644 training_accuracy: 0.8949999809265137
  batch 200 loss: 0.3184510940313339 training_accuracy: 0.8624999523162842
  batch 225 loss: 0.32735574573278425 training_accuracy: 0.8549999594688416
  batch 250 loss: 0.24683617442846298 training_accuracy: 0.8849999904632568
  batch 275 loss: 0.24874511897563933 training_accuracy: 0.8999999761581421
  batch 300 loss: 0.275392587184906 training_accuracy: 0.8725000023841858
  batch 325 loss: 0.2588343775272369 training_accuracy: 0.9074999690055847
LOSS train 0.258

In [73]:
model = model.to('cpu')
model.eval()

cum_error = 0

testing_fragment = dataset['test'].shuffle(seed=1)[:50]
inputs = testing_fragment['image']
labels_true = testing_fragment['label']

for i in range(len(inputs)):
  input = inputs[i]
  label_true = labels_true[i]
  logits = model(input[None, ...]).detach().numpy()
  label_pred = np.argmax(logits)
  
  cum_error += abs(label_pred - label_true)
  print('index {}: true/predicted: {}/{}'.format(i, label_true, label_pred))
  
error = cum_error / 50

accuracy = 1 - error

print('testing accuracy: {}'.format(accuracy))

index 0: true/predicted: 0/0
index 1: true/predicted: 1/1
index 2: true/predicted: 1/1
index 3: true/predicted: 1/1
index 4: true/predicted: 1/1
index 5: true/predicted: 0/0
index 6: true/predicted: 0/0
index 7: true/predicted: 0/1
index 8: true/predicted: 0/0
index 9: true/predicted: 1/1
index 10: true/predicted: 0/0
index 11: true/predicted: 1/1
index 12: true/predicted: 1/1
index 13: true/predicted: 1/1
index 14: true/predicted: 0/0
index 15: true/predicted: 1/1
index 16: true/predicted: 1/1
index 17: true/predicted: 1/1
index 18: true/predicted: 1/1
index 19: true/predicted: 0/1
index 20: true/predicted: 1/1
index 21: true/predicted: 1/1
index 22: true/predicted: 0/1
index 23: true/predicted: 1/1
index 24: true/predicted: 1/1
index 25: true/predicted: 1/1
index 26: true/predicted: 0/0
index 27: true/predicted: 0/0
index 28: true/predicted: 0/0
index 29: true/predicted: 0/1
index 30: true/predicted: 1/1
index 31: true/predicted: 1/1
index 32: true/predicted: 1/1
index 33: true/predi

###