Brain MRI Binary Classification PyTorch ViT

In [None]:
# Install a pinned gdown release, then fetch the dataset archive as brain.zip
!pip install --upgrade gdown==v4.6.3
!gdown --fuzzy 18RfTvv5NBKuUgMDJjxXb7BLYz31sT_aH --output brain.zip
# Unpack the archive quietly (-q)
!unzip -q brain.zip

In [None]:
import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

In [None]:
# Side length (pixels) used when previewing a raw image
IMG_SIZE = 256
# Class directory name -> integer label
class_map = dict(no=0, yes=1)
# Prefer the GPU when one is visible to PyTorch
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
# Collect every JPEG path under ./brain/<class>/, sorted for a reproducible order.
# The suffix is compared case-insensitively on a single listing so that each file
# appears exactly once (separate '*.jpg' / '*.JPG' patterns return duplicates on
# case-insensitive platforms) and variants such as '.JPEG' are not missed.
img_paths = sorted(
    p for p in glob('./brain/*/*')
    if p.lower().endswith(('.jpg', '.jpeg'))
)

In [None]:
# Number of image paths collected
len(img_paths)

In [None]:
# Preview one sample as RGB, resized to an IMG_SIZE x IMG_SIZE square
path = img_paths[-9]
img = Image.open(path)
img = img.convert("RGB")
img = img.resize((IMG_SIZE, IMG_SIZE))
print(path, img.size)
plt.imshow(img)

In [None]:
# Extract the class of one image from its path
img_path = img_paths[-9]
print(img_path)
# The label is the name of the directory that directly contains the image.
# os.path is used instead of str.split('/') so that platform-specific
# separators (e.g. backslashes returned by glob on Windows) are handled.
cls = os.path.basename(os.path.dirname(img_path))
print(cls)
# Integer class index
print(class_map[cls])

#### Dataset, Dataloader

In [None]:
class BrainDataset(torch.utils.data.Dataset):
    """Map image file paths to ``(image, label)`` samples.

    The label of a sample is the name of the directory that directly
    contains the image, looked up in ``class_map`` ('no' -> 0, 'yes' -> 1).

    Args:
        paths: sequence of image file paths.
        transform: callable applied to the RGB PIL image; its result is
            returned as the image of the sample.
    """

    def __init__(self, paths, transform):
        self.paths = paths
        self.transform = transform
        self.class_map = {
            'no': 0,
            'yes': 1
        }

    def __len__(self):
        return len(self.paths)

    def _label_for(self, path):
        """Return the integer class index encoded in ``path``'s parent directory.

        Raises:
            KeyError: if the parent directory name is not a key of ``class_map``.
        """
        # os.path copes with the platform's separators, unlike str.split('/')
        cls = os.path.basename(os.path.dirname(path))
        try:
            return self.class_map[cls]
        except KeyError:
            raise KeyError(
                f"unknown class directory {cls!r} for image {path!r}"
            ) from None

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at ``idx``.

        Returns:
            tuple of (transformed image, label as a 0-dim ``torch.long`` tensor).
        """
        path = self.paths[idx]
        # The context manager closes the underlying file once the pixel data
        # has been copied by convert(), instead of leaving the handle open.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        img = self.transform(img)

        label = torch.tensor(self._label_for(path), dtype=torch.long)
        return img, label

In [None]:
# Hold out 20% of the paths for validation; the fixed seed keeps the split reproducible
split_options = {'test_size': 0.2, 'random_state': 5566}
train_paths, val_paths = train_test_split(img_paths, **split_options)

In [None]:
# Preprocessing preset shipped with the pretrained ViT-B/16 weights
transforms = torchvision.models.ViT_B_16_Weights.DEFAULT.transforms()

train_ds = BrainDataset(train_paths, transforms)
# The validation set must be built from the held-out paths; using train_paths
# here would make validation, early stopping and evaluation run on training data.
val_ds = BrainDataset(val_paths, transforms)

In [None]:
# Show the preprocessing pipeline that is applied to every image
print(transforms)

In [None]:
# Fetch the first training sample to inspect the transformed tensor's shape and its label
img, label = train_ds[0]
img.shape, label

In [None]:
# Left panel: the transformed (normalised) tensor shown as-is, channels last
plt.subplot(1, 2, 1)
plt.imshow(img.permute(1, 2, 0))

# Right panel: undo the per-channel normalisation before displaying.
# NOTE(review): mean/std are presumably the statistics used by `transforms` - confirm.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
channels_last = img.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
img_raw = np.clip(std * channels_last + mean, 0, 1)
print(img_raw.shape)
plt.subplot(1, 2, 2)
plt.imshow(img_raw)
plt.show()

In [None]:
# Batch size shared by both loaders
BS = 32
# Only the training loader shuffles; validation keeps dataset order
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BS, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BS)

Model

In [None]:
# Pretrained ViT-B/16 backbone
pretrained_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
model = torchvision.models.vit_b_16(weights=pretrained_weights)

# Freeze every pretrained parameter
model.requires_grad_(False)

# Replace the classification head with a fresh 2-class linear layer; it is
# created after the freeze, so it is the only trainable part
model.heads = nn.Sequential(nn.Linear(768, 2))

In [None]:
# Smoke test: push one random 3x224x224 input through the model and
# inspect the output shape (presumably (1, 2) given the 2-class head)
inputs = torch.randn(1, 3, 224, 224)
outputs = model(inputs)
outputs.shape

#### Training

In [None]:
# Move the model to the selected device before creating the optimizer
model = model.to(device)
# Cross-entropy over the class logits
loss_fn = nn.CrossEntropyLoss()
# Adam with its default hyper-parameters over all model parameters
optimizer = torch.optim.Adam(model.parameters())

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Batches are moved to the module-level ``device`` before the forward pass.

    Returns:
        tuple of (mean of the per-batch losses, fraction of samples whose
        arg-max prediction equals the label).
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)

    model.train()  # training mode
    running_loss = 0
    n_correct = 0
    for x, y in tqdm(dataloader, leave=False):
        x = x.to(device)
        y = y.to(device)

        # reset gradients left over from the previous step
        optimizer.zero_grad()

        # forward pass and loss
        logits = model(x)
        batch_loss = loss_fn(logits, y)

        # backward pass and parameter update
        batch_loss.backward()
        optimizer.step()

        # bookkeeping: logits are (N, classes), so arg-max over dim 1
        running_loss += batch_loss.item()
        hits = logits.argmax(dim=1) == y
        n_correct += hits.sum().item()

    return running_loss / n_batches, n_correct / n_samples


def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Batches are moved to the module-level ``device`` before the forward pass.

    Returns:
        tuple of (mean of the per-batch losses, fraction of samples whose
        arg-max prediction equals the label).
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)

    model.eval()  # evaluation mode
    running_loss = 0
    n_correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            running_loss += loss_fn(logits, y).item()

            hits = logits.argmax(1) == y
            n_correct += hits.sum().item()

    return running_loss / n_batches, n_correct / n_samples

EPOCHS = 100
# Per-epoch history of the four tracked metrics
logs = {
    name: []
    for name in ('train_loss', 'train_acc', 'val_loss', 'val_acc')
}
# Early stopping: give up after `patience` consecutive epochs without a new
# best validation loss
patience = 5
counter = 0
best_loss = np.inf

for epoch in tqdm(range(EPOCHS)):
    train_loss, train_acc = train(train_loader, model, loss_fn, optimizer)
    val_loss, val_acc = test(val_loader, model, loss_fn)

    print(f'EPOCH: {epoch:04d} \
    train_loss: {train_loss:.4f}, train_acc: {train_acc:.3f} \
    val_loss: {val_loss:.4f}, val_acc: {val_acc:.3f} ')

    epoch_metrics = {
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
    }
    for name, value in epoch_metrics.items():
        logs[name].append(value)

    # always keep the most recent weights
    torch.save(model.state_dict(), "last.pth")

    # a new best validation loss resets the counter and refreshes best.pth
    if val_loss < best_loss:
        counter = 0
        best_loss = val_loss
        torch.save(model.state_dict(), "best.pth")
        continue

    # no improvement this epoch
    counter += 1
    if counter >= patience:
        print("Earlystop!")
        break

#### Evaluation

In [None]:
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
    recall_score, # sensitivity
    precision_score,
    f1_score,
    roc_curve,
    auc,
)
import pandas as pd

In [None]:
# Restore the weights with the lowest validation loss. map_location remaps the
# saved tensors onto the current device, so a checkpoint written on a GPU can
# also be loaded on a CPU-only machine.
model.load_state_dict(torch.load('best.pth', map_location=device))
_ = model.eval().to(device)

In [None]:
# Inference over the validation loader
y_pred = []      # predicted class index per sample
y_pred_raw = []  # predicted probability of class 1 per sample
y_true = []      # ground-truth label per sample

with torch.no_grad():
    for x, y in tqdm(val_loader):
        x = x.to(device)
        logits = model(x)
        probs = logits.softmax(dim=1)  # logits -> class probabilities
        y_pred_raw.append(probs[:, 1])
        y_pred.append(probs.argmax(dim=1))
        y_true.append(y)

# Concatenate the per-batch tensors and convert them to NumPy arrays
y_pred = torch.cat(y_pred).cpu().numpy()
y_pred_raw = torch.cat(y_pred_raw).cpu().numpy()
y_true = torch.cat(y_true).cpu().numpy()

In [None]:
# Shapes of the collected predictions, class-1 probabilities and labels
y_pred.shape, y_pred_raw.shape, y_true.shape

In [None]:
# First three entries of each array
y_pred[:3], y_pred_raw[:3], y_true[:3]

In [None]:
# Per-class precision / recall / F1 report, printed with three decimals;
# target_names follow the label order 0 -> "NO", 1 -> "YES"
print(classification_report(y_true, y_pred,
                            target_names=["NO", "YES"],
                            digits=3))

In [None]:
# Sensitivity (recall), precision and F1 of the hard predictions
metric_table = (
    ("Sensitivity:", recall_score),
    ("Precision:  ", precision_score),
    ("F1 score:   ", f1_score),
)
for title, metric_fn in metric_table:
    print(title, metric_fn(y_true, y_pred))

In [None]:
# Confusion matrix:
#   rows: ground truth
#   columns: prediction

cm = confusion_matrix(y_true, y_pred)
print(cm)

In [None]:
# Draw the confusion matrix with readable class names on both axes
disp = ConfusionMatrixDisplay(cm, display_labels=["No", "Yes"])
disp.plot()
plt.show()

In [None]:
# ROC points computed from the class-1 probabilities (values in 0~1)
fp_rate, tp_rate, threshold = roc_curve(y_true, y_pred_raw)

# One row per threshold
df = pd.DataFrame(dict(FPR=fp_rate, TPR=tp_rate, Threshold=threshold))
df

In [None]:
# Area under the ROC curve, computed from the FPR/TPR points
auc_score = auc(fp_rate, tp_rate)
print(f'AUC: {auc_score:.4f}')

In [None]:
# Plot the ROC curve: TPR against FPR, one marker per threshold
plt.plot(fp_rate, tp_rate, marker="^")
plt.title('ROC curve')
plt.xlabel('False Positive Rate (FPR) 1-Specificity')
plt.ylabel('True Positive Rate (TPR) Sensitivity')
plt.show()