## Set Up Device & Environment

In [1]:
from datasets import load_dataset
import numpy as np

import torch
from transformers import AutoImageProcessor
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

import evaluate

KeyboardInterrupt: 

In [None]:
# Use the GPU when one is available; fall back to the CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing on the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

## Data Preparation

### Load Dataset

In [None]:
# Build a DatasetDict from the image folders under ./datasets/chest_xray.
dataset = load_dataset(path="imagefolder", data_dir="./datasets/chest_xray")

In [None]:
# Show the available splits, their features and row counts.
print(f"{dataset}")

### Setup Labels

In [None]:
# Class names in label-index order, read from the training split's ClassLabel
# feature (presumably the sub-folder names found by the imagefolder builder).
labels = dataset["train"].features["label"].names
print(labels)

In [None]:
# label2id: class name -> integer id; id2label: integer id -> class name.
# Ids follow the order of `labels`, which is the ClassLabel index order.
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

In [None]:
# Display both lookup tables, one per line.
print(label2id, id2label, sep="\n")

### Transforming Data

In [None]:
# Preprocessing config (target size, normalization mean/std) used by the
# torchvision pipeline below.
# NOTE(review): this checkpoint ("-in21k") differs from the one loaded in
# MyCompositeModel ("google/vit-base-patch16-224"); confirm their
# preprocessing settings match.
image_processor = AutoImageProcessor.from_pretrained(
  "google/vit-base-patch16-224-in21k"
)

In [None]:
# (height, width) target taken from the processor config. The crop is random,
# so this step doubles as data augmentation.
size = tuple(image_processor.size[dim] for dim in ("height", "width"))
resizer = RandomResizedCrop(size)
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

In [None]:
# Per-image pipeline: random resized crop -> tensor -> normalize.
_transforms = Compose(transforms=[resizer, ToTensor(), normalize])

In [None]:
def transforms(examples):
  """Replace every image in a batch with its processed, normalized tensor.

  Intended as a lazy ``with_transform`` callback: converts each image to RGB,
  runs it through ``_transforms``, and mutates and returns ``examples``.
  """
  converted = []
  for picture in examples["image"]:
    converted.append(_transforms(picture.convert("RGB")))
  examples["image"] = converted
  return examples

In [None]:
# RandomResizedCrop is augmentation: keep it for the training split only, and
# evaluate the other splits on a deterministic resize so validation/test
# metrics do not depend on random crops.
from datasets import DatasetDict
from torchvision.transforms import Resize

_eval_transforms = Compose([Resize(size), ToTensor(), normalize])


def _eval_transform(examples):
  """Deterministic counterpart of ``transforms`` used for non-training splits."""
  examples["image"] = [_eval_transforms(img.convert("RGB")) for img in examples["image"]]
  return examples


dataset = DatasetDict({
  split: split_ds.with_transform(transforms if split == "train" else _eval_transform)
  for split, split_ds in dataset.items()
})

In [None]:
# Confirm the splits are unchanged after attaching the lazy transform.
print(f"{dataset}")

### Preparing Metrics for the Model

In [None]:
# Accuracy metric from the `evaluate` hub, used by compute_metrics.
accuracy = evaluate.load(path="accuracy")

In [None]:
def compute_metrics(eval_pred):
  """Return the `evaluate` accuracy of argmax predictions against the references.

  ``eval_pred`` must expose ``predictions`` (scores with the class axis at 1)
  and ``label_ids``; presumably a transformers ``EvalPrediction`` -- confirm.
  NOTE(review): not called anywhere in the visible notebook.
  """
  logits = eval_pred.predictions
  predicted_ids = np.argmax(logits, axis=1)
  return accuracy.compute(references=eval_pred.label_ids, predictions=predicted_ids)

## Setting Up the Model

In [None]:
from torch import nn
from transformers import ViTForImageClassification

class MyCompositeModel(nn.Module):
  """Pretrained ViT classifier with a small MLP head stacked on its logits.

  The pretrained model's own classification logits (1000 of them for
  ``google/vit-base-patch16-224``) are passed through Linear -> ReLU -> Linear
  to produce ``num_labels`` scores per example.

  Args:
    num_labels: number of classes produced by the new head (default 2).
    hidden_size: width of the hidden layer in the new head (default 100).
  """

  def __init__(self, num_labels=2, hidden_size=100):
    super().__init__()

    self.pretrained = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    # Read the backbone's logit count from its config rather than hard-coding
    # 1000, so a different checkpoint cannot silently mismatch the first Linear.
    backbone_num_logits = self.pretrained.config.num_labels
    # Layer order/indices are unchanged, so saved state_dict keys stay compatible.
    self.my_new_layers = nn.Sequential(
      nn.Linear(backbone_num_logits, hidden_size),
      nn.ReLU(),
      nn.Linear(hidden_size, num_labels),
    )

  def forward(self, x):
    """Map a batch of pixel values to ``num_labels`` logits per example.

    Assumes ``x`` is a normalized image batch as produced by ``transforms``
    (presumably (batch, 3, 224, 224)) -- confirm against the data pipeline.
    """
    logits = self.pretrained(x).logits
    return self.my_new_layers(logits)
  
# Instantiate the composite model and move its parameters to the chosen device.
model = MyCompositeModel().to(device)

## Training Setup

### Data Loaders

In [None]:
# Shared batch size for all loaders.
BATCH_SIZE = 16

# Shuffle only the training data: evaluation metrics do not depend on order,
# and a fixed order keeps evaluation passes reproducible.
training_loader = torch.utils.data.DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset['test'], batch_size=BATCH_SIZE, shuffle=False)

# Select models on a dedicated validation split when the dataset has one
# (presumably built by imagefolder from a "val" folder -- confirm), so the test
# split stays unseen; otherwise fall back to 'test' as before.
validation_split = 'validation' if 'validation' in dataset else 'test'
validation_loader = torch.utils.data.DataLoader(dataset[validation_split], batch_size=BATCH_SIZE, shuffle=False)

### Loss Function

In [None]:
# Multi-class cross-entropy on raw logits, averaged over the batch.
loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

### Optimizer

In [None]:
# SGD with momentum over the new head's parameters only, so optimizer.step()
# never updates the pretrained ViT weights. (No requires_grad freeze is applied,
# so backward() still computes gradients for the backbone.)
head_parameters = model.my_new_layers.parameters()
optimizer = torch.optim.SGD(params=head_parameters, lr=0.001, momentum=0.9)

In [None]:
from torchmetrics import Accuracy

# torchmetrics accuracy for the 2-class problem, placed on the same device as
# the model outputs it will receive.
accuracy_metric = Accuracy(task='multiclass', num_classes=2)
accuracy_metric = accuracy_metric.to(device)

### Single Epoch Training Function

In [None]:
def train_one_epoch(epoch_index, tb_writer, report_every=10):
  """Train ``model`` for one pass over ``training_loader``.

  After every ``report_every`` batches, and once more for a final partial
  window, the window's mean loss and mean batch accuracy are printed. The loss
  is also logged to TensorBoard under 'Loss/train'.

  Args:
    epoch_index: zero-based epoch number, used to compute the global step.
    tb_writer: TensorBoard ``SummaryWriter`` that receives the training loss.
    report_every: number of batches per reporting window (default 10).

  Returns:
    Mean per-batch loss of the last reporting window (0.0 if the loader is empty).
  """
  num_batches = len(training_loader)
  running_loss = 0.
  running_accuracy = 0.
  window_size = 0
  last_loss = 0.

  # enumerate() gives the batch index for intra-epoch reporting.
  for i, data in enumerate(training_loader):
    # Each batch is a dict holding an image tensor batch and a label batch.
    inputs = data['image'].to(device)
    targets = data['label'].to(device)

    # Gradients accumulate across backward() calls, so clear them per batch.
    optimizer.zero_grad()

    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()

    # Convert to Python floats right away so the running sums do not hold
    # device tensors (or autograd history) across batches.
    batch_accuracy = accuracy_metric(outputs.detach(), targets).item()

    optimizer.step()

    running_loss += loss.item()
    running_accuracy += batch_accuracy
    window_size += 1

    # Report when the window is full or the epoch ends. Dividing by the number
    # of batches actually seen keeps a short final window correctly averaged.
    if window_size == report_every or i + 1 == num_batches:
      last_loss = running_loss / window_size  # loss per batch
      last_accuracy = running_accuracy / window_size  # accuracy per batch
      print('  batch {} loss: {} training_accuracy: {}'.format(i + 1, last_loss, last_accuracy))
      tb_x = epoch_index * num_batches + i + 1
      tb_writer.add_scalar('Loss/train', last_loss, tb_x)
      running_loss = 0.
      running_accuracy = 0.
      window_size = 0

  return last_loss

### Multi-Epoch Training Function

In [None]:
def _validate():
  """Evaluate ``model`` on ``validation_loader`` without tracking gradients.

  Returns:
    ``(mean loss per sample, accuracy over the whole split)`` as Python floats,
    or ``(nan, nan)`` when the loader yields no batches.
  """
  # Eval mode disables dropout and uses population statistics for normalization.
  model.eval()
  # The metric's internal state also accumulates the training batches passed
  # through it, so clear it before aggregating validation results.
  accuracy_metric.reset()
  total_loss = 0.0
  total_samples = 0

  with torch.no_grad():
    for vdata in validation_loader:
      vinputs = vdata['image'].to(device)
      vlabels = vdata['label'].to(device)
      voutputs = model(vinputs)

      # loss_fn returns a batch mean; weight it by batch size so a short final
      # batch counts proportionally.
      batch_size = vlabels.size(0)
      total_loss += loss_fn(voutputs, vlabels).item() * batch_size
      total_samples += batch_size
      accuracy_metric.update(voutputs, vlabels)

  if total_samples == 0:
    return float('nan'), float('nan')
  return total_loss / total_samples, accuracy_metric.compute().item()


def train_many_epochs(epochs, writer):
  """Train for ``epochs`` epochs, validating after each one.

  After each epoch, the training loss (the value returned by train_one_epoch)
  and the validation loss are logged to ``writer``. The model's state_dict is
  saved to ``model_<timestamp>_<epoch>`` whenever validation loss improves.

  Relies on notebook globals: model, validation_loader, loss_fn,
  accuracy_metric, device and timestamp.
  """
  best_vloss = float('inf')

  for epoch_number in range(epochs):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking and training-mode layers are on.
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    avg_vloss, avg_vacc = _validate()
    print('LOSS train {} valid {} ACCURACY validation {}'.format(avg_loss, avg_vloss, avg_vacc))

    writer.add_scalars(
      'Training vs. Validation Loss',
      { 'Training' : avg_loss, 'Validation' : avg_vloss },
      epoch_number + 1
    )
    writer.flush()

    # Keep a checkpoint of the best model so far (NaN never compares smaller,
    # so an empty validation split never triggers a save).
    if avg_vloss < best_vloss:
      best_vloss = avg_vloss
      model_path = 'model_{}_{}'.format(timestamp, epoch_number)
      torch.save(model.state_dict(), model_path)

## Training Model

In [None]:
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Created in its own cell so further epochs can be appended to the same
# TensorBoard run; the timestamp also names saved checkpoints.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
sum_writer = SummaryWriter(f'runs/chest_trainer_{timestamp}')

In [None]:
# Run two epochs, logging to the writer created above.
train_many_epochs(epochs=2, writer=sum_writer)

###