In [None]:
# Notes: This was run on Google Cloud in a collaborative notebook.
# I have moved the notebook into the GitHub repository, updated the data-loading code accordingly, and added the data to the repository so that the model can be built locally.
# Also note that this currently runs on the CPU, not the GPU. To use a GPU, move the model and each batch of data to your GPU device (e.g. with `.to(device)`).

In [None]:
import numpy as np
import pandas as pd
import argparse
import os
import random
import shutil
import time
import warnings
import PIL

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from pytorch_pretrained_vit import ViT, load_pretrained_weights

import torchxrayvision as xrv

from sklearn import model_selection

In [None]:
# Seed every RNG the notebook may touch (Python, numpy, torch) so runs are reproducible.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 32
num_workers = 8  # DataLoader worker processes; change this depending on your hardware specs

In [None]:
# Download the pretrained ViT backbone and replace its classification head with a
# freshly initialised linear layer for our labels (0: No COVID, 1: COVID).
model_name = 'B_16'
num_classes = 2

model = ViT(model_name, pretrained=True)

# Size the new head from the existing one instead of hard-coding 768, so changing
# model_name to a backbone with a different hidden size keeps working.
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes, bias=True)
image_size = model.image_size[0]


In [None]:

import pandas as pd
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Preprocessing: resize, center-crop to 224x224, convert to a tensor, then map each
# channel from [0, 1] to [-1, 1] (mean = std = 0.5).
t = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# ImageFolder assigns one integer label per class sub-directory of root_path.
root_path = './chest-xray/'
data = ImageFolder(root=root_path, transform=t)

# Use every sample actually on disk instead of a hard-coded count (previously 1684):
# a smaller dataset would raise IndexError and a larger one would be silently truncated.
index = list(range(len(data)))  # all samples

targets = pd.DataFrame(data.targets, columns=['target'])

targets = targets.iloc[index]
data = torch.utils.data.Subset(data, index)

# 80/20 train/test split; no explicit generator is passed, so it draws from torch's global RNG.
train_size = int(0.8 * len(data))
test_size = len(data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])

print('--- Number of Samples ---')
print('--- 0: No COVID, 1: COVID ---')
print(targets.groupby('target').size())

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=False)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=False)

print("batches:", len(train_loader), len(test_loader))

In [None]:
# Optional cell, this shows an example image with the transformations applied
import matplotlib.pyplot as plt

# Show the first image of one training batch as the model sees it.
im, lb = next(iter(train_loader))
print(lb)
im = im[0].numpy()
# The transform normalised pixels to [-1, 1] (mean = std = 0.5); undo that so imshow,
# which expects floats in [0, 1], does not clip the image.
img = np.clip(np.transpose(im, (1, 2, 0)) * 0.5 + 0.5, 0.0, 1.0)
plt.figure(figsize=(15, 7))
plt.axis('off')
plt.imshow(img)
plt.title(f"train sample, label: {lb[0].item()}")
plt.show()

In [None]:
# Class-imbalance weights for the loss. The dataset has many more negatives than
# positives (noted as roughly 7:1); each class is weighted by
# (largest class count / its own count), so the minority class counts for more.
# groupby sorts by label, so position 0 is "No COVID" and position 1 is "COVID".
classes = targets.groupby('target').size()
majority_count = max(classes)
class_imbalance_weights = torch.tensor(
    [majority_count / classes[0], majority_count / classes[1]],
    dtype=torch.float,
)
criterion = nn.CrossEntropyLoss(weight=class_imbalance_weights)

# Plain SGD over all parameters; frozen ones (requires_grad=False) get no gradient
# and are therefore skipped by the optimizer's update.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)


In [None]:
# FREEZE BLOCKS METHOD

def set_trainable(model, block, trainable=True):
    """Freeze or unfreeze every parameter that belongs to one indexed block.

    A parameter belongs to block ``block`` when its dotted name contains the
    consecutive components ``blocks`` and ``str(block)`` (e.g. ``blocks.3.weight``
    or ``something.blocks.3.attn.weight`` for block 3). All other parameters
    are left untouched.

    Args:
        model: an ``nn.Module`` whose parameters are inspected via ``named_parameters``.
        block: index of the block to (un)freeze.
        trainable: ``True`` sets ``requires_grad=True`` (unfreeze); ``False`` freezes.
    """
    target = str(block)
    for name, param in model.named_parameters():
        parts = name.split('.')
        # Compare whole name components: a plain substring test for "blocks.1"
        # would also match "blocks.10" and "blocks.11".
        if any(a == 'blocks' and b == target for a, b in zip(parts, parts[1:])):
            param.requires_grad = trainable


In [None]:
from sklearn.metrics import *

def classification_metrics(Y_pred, Y_true, Y_score=None):
    """Compute binary classification metrics.

    Args:
        Y_pred: predicted hard labels (0/1).
        Y_true: ground-truth labels (0/1).
        Y_score: optional positive-class scores (e.g. the softmax probability of
            class 1). ROC AUC is computed from these when given; otherwise it falls
            back to ``Y_pred``, which only yields a single-threshold AUC.

    Returns:
        Tuple ``(accuracy, auc, precision, recall, f1)``.
    """
    acc = accuracy_score(Y_true, Y_pred)
    auc = roc_auc_score(Y_true, Y_pred if Y_score is None else Y_score)
    precision = precision_score(Y_true, Y_pred)
    recall = recall_score(Y_true, Y_pred)
    f1 = f1_score(Y_true, Y_pred)
    return acc, auc, precision, recall, f1


def evaluate(model, loader):
    """Run ``model`` over every batch of ``loader``, then print and return its metrics.

    The model is left in eval mode; callers that continue training must call
    ``model.train()`` again.

    Returns:
        Tuple ``(accuracy, auc, precision, recall, f1)`` from ``classification_metrics``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    # Send inputs to wherever the model lives, so this also works after moving it to a GPU.
    device = next(model.parameters()).device
    y_true, y_pred, y_score = [], [], []
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for x, y in loader:
            logits = model(x.to(device))
            y_true.append(y.to('cpu').long())
            y_pred.append(logits.argmax(dim=1).cpu())
            # Probability of class 1 (COVID) ranks samples for the ROC curve.
            y_score.append(torch.softmax(logits, dim=1)[:, 1].cpu())

    if not y_true:
        raise ValueError("evaluate: loader yielded no batches")

    all_y_true = torch.cat(y_true).numpy()
    all_y_pred = torch.cat(y_pred).numpy()
    all_y_score = torch.cat(y_score).numpy()

    acc, auc, precision, recall, f1 = classification_metrics(all_y_pred, all_y_true, all_y_score)
    print(f"acc: {acc:.3f}, auc: {auc:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1:.3f}")
    return (acc, auc, precision, recall, f1)

In [None]:
# TRAINING LOOP
# Gradual unfreezing: start with all transformer blocks frozen, then unfreeze one
# more block per epoch, starting from the last block (index num_blocks - 1) and moving backwards.
num_blocks = 12  # number of transformer blocks in the ViT backbone
train_loss_arr = []    # mean training loss per epoch
test_metrics_arr = []  # (acc, auc, precision, recall, f1) on the test set per epoch

# Freeze all blocks
for block in range(num_blocks):
    set_trainable(model, block, trainable=False)


epochs = 12
for epoch in range(epochs):
    model.train()

    # Unfreeze one block per epoch; once every block is trainable there is nothing left to do.
    block_to_unfreeze = num_blocks - 1 - epoch
    if block_to_unfreeze >= 0:
        set_trainable(model, block_to_unfreeze, trainable=True)

    running_loss = 0.0
    for x, y in train_loader:
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Average of the per-batch losses (already a scalar, so no np.mean needed).
    train_loss = running_loss / len(train_loader)
    train_loss_arr.append(train_loss)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch+1, train_loss))
    test_metrics_arr.append(evaluate(model, test_loader))