## Vision Transformer (Pre-Trained)

In [None]:
import os
import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import torch
import timm

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split

import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


In [None]:
class ChestXRayDataset(Dataset):
    """Dataset pairing chest X-ray image files with per-image label vectors.

    Args:
        image_paths: Sequence of image file paths, read with ``cv2.imread``.
        labels: Indexable collection of label vectors aligned with
            ``image_paths`` (entry ``i`` belongs to image ``i``).
        transform: Optional callable applied to the RGB numpy image before it
            is returned.

    Each item is ``(image, label)`` where ``label`` is a float32 tensor.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def _get_label(self, idx):
        """Return label vector ``idx`` as a float32 tensor."""
        # Upstream label arrays may be float64; casting here means the default
        # collate produces float32 targets, matching the (default float32)
        # logits that BCEWithLogitsLoss compares them against.
        return torch.as_tensor(self.labels[idx], dtype=torch.float32)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, return (image, label).

        Raises:
            FileNotFoundError: if OpenCV cannot read the file at that path.
        """
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            image = self.transform(image)

        return image, self._get_label(idx)

In [None]:
# Dataset root. Defaults to the original author's location; set the
# CHEST_XRAY_DATA_DIR environment variable to run on another machine.
data_path = os.environ.get(
    'CHEST_XRAY_DATA_DIR',
    '/Users/ananyajain/Desktop/CSC413/CSC413-Final-Project/archive/sample',
)
# The next two paths are joined onto data_path when the data is loaded.
images_dir = 'sample/images'
labels_csv = 'sample_labels.csv'

In [None]:
# Read the label CSV and keep only rows whose image file exists locally.
labels_df = pd.read_csv(os.path.join(data_path, labels_csv))

# Map each PNG's file name to its full path on disk.
image_path = {
    os.path.basename(p): p
    for p in glob.glob(os.path.join(data_path, images_dir, '*.png'))
}

labels_df = labels_df[labels_df['Image Index'].map(os.path.basename).isin(image_path)]

print('Total Images:', len(image_path), ', Total Input Rows:', labels_df.shape[0])

if labels_df.empty:
    # Fail loudly here rather than with an opaque error from the binarizer below.
    raise FileNotFoundError(
        f"No labelled images found under {os.path.join(data_path, images_dir)!r}; "
        "check data_path / images_dir / labels_csv."
    )

new_labels_df = pd.DataFrame()
new_labels_df['Id'] = labels_df['Image Index'].copy()
# 'Finding Labels' holds one or more findings separated by '|'.
new_labels_df['labels'] = labels_df['Finding Labels'].apply(lambda val: val.split('|'))

# One column per distinct finding (order given by mlb.classes_).
mlb = MultiLabelBinarizer()
# float32 so targets match the model's float32 logits in BCEWithLogitsLoss.
labels = np.asarray(mlb.fit_transform(new_labels_df['labels']), dtype=np.float32)

image_paths = [image_path[os.path.basename(x)] for x in new_labels_df['Id']]

In [None]:
# Image pipeline: numpy RGB array -> PIL -> 224x224 resize -> random
# flip/rotation augmentation -> tensor normalised per channel.
# NOTE(review): these are the usual ImageNet mean/std; timm ViT weights may
# expect different normalisation — confirm via timm.data.resolve_data_config.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_pipeline_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_pipeline_steps)

In [None]:
# 80/10/10 train/val/test split, seeded for reproducibility.
train_paths, val_test_paths, train_labels, val_test_labels = train_test_split(
    image_paths, labels, test_size=0.2, random_state=42)
val_paths, test_paths, val_labels, test_labels = train_test_split(
    val_test_paths, val_test_labels, test_size=0.5, random_state=42)

# Validation/test must be deterministic: reuse the training pipeline with its
# random augmentations removed, so metrics are not computed on randomly
# flipped/rotated images.
_RANDOM_AUGMENTATIONS = (
    transforms.RandomHorizontalFlip,
    transforms.RandomVerticalFlip,
    transforms.RandomRotation,
    transforms.RandomAffine,
)
eval_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, _RANDOM_AUGMENTATIONS)]
)

train_dataset = ChestXRayDataset(train_paths, train_labels, transform)
val_dataset = ChestXRayDataset(val_paths, val_labels, eval_transform)
test_dataset = ChestXRayDataset(test_paths, test_labels, eval_transform)

batch_size = 16
loader_train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
loader_test = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Prefer the GPU when one is visible, and report which device was chosen.
if torch.cuda.is_available():
    device = torch.device("cuda")
    _device_suffix = f"({torch.cuda.get_device_name(device)})"
else:
    device = torch.device("cpu")
    _device_suffix = ""
print("Using device: ", device, _device_suffix)

In [None]:
# Load the pretrained 'vit_base_patch16_224' backbone from timm and replace its
# classifier head with one sized for our label set.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
# Derive the output width from the fitted binarizer instead of hard-coding it,
# so the head always matches the number of label columns.
num_classes = len(mlb.classes_)
model.head = nn.Linear(model.head.in_features, num_classes)
model.to(device)
# Multi-label task: an independent sigmoid + binary cross-entropy per class.
criterion = nn.BCEWithLogitsLoss()
optimizer = Adam(model.parameters(), lr=1e-4)

In [None]:
def _train_one_epoch(model, criterion, optimizer, loader, device, threshold=0.5):
    """Run one optimisation pass over ``loader``; return (mean loss, label-wise accuracy %)."""
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total_labels = 0
    total_samples = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        # BCEWithLogitsLoss expects float targets; cast in case the loader
        # yields float64 (or integer) label tensors.
        labels = labels.to(device=device, dtype=torch.float32)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = outputs.detach().sigmoid() > threshold
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels.bool()).sum().item()
        total_labels += labels.numel()
        total_samples += inputs.size(0)

    # max(..., 1) keeps an empty loader from dividing by zero.
    epoch_loss = running_loss / max(total_samples, 1)
    epoch_acc = running_corrects / max(total_labels, 1) * 100
    return epoch_loss, epoch_acc


def train_model(model, criterion, optimizer, loader_train, loader_val, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Accuracy is label-wise: the percentage of individual (sample, class)
    predictions, after thresholding the sigmoid at 0.5, that equal the target.
    Correct negatives count as well as correct positives.

    Args:
        model: Network producing one logit per class.
        criterion: Loss taking (logits, float targets), e.g. BCEWithLogitsLoss.
        optimizer: Optimizer over ``model``'s parameters.
        loader_train: DataLoader yielding (inputs, labels) training batches.
        loader_val: DataLoader passed to ``validate_model`` after each epoch.
        num_epochs: Number of passes over ``loader_train``.

    Returns:
        (train_losses, val_losses, train_accuracies, val_accuracies), one entry
        per epoch.
    """
    # Use the device the model already lives on rather than a global.
    device = next(model.parameters()).device
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _train_one_epoch(
            model, criterion, optimizer, loader_train, device)
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%')

        val_loss, val_acc = validate_model(model, loader_val, criterion)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

    return train_losses, val_losses, train_accuracies, val_accuracies


In [None]:
def validate_model(model, loader_val, criterion, threshold=0.5):
    """Evaluate ``model`` on ``loader_val`` without updating it.

    Args:
        model: Network producing one logit per class.
        loader_val: DataLoader yielding (inputs, labels) batches.
        criterion: Loss taking (logits, float targets), e.g. BCEWithLogitsLoss.
        threshold: Sigmoid probability above which a class counts as predicted.

    Returns:
        (mean loss per sample, label-wise accuracy in percent). Accuracy is the
        share of individual (sample, class) predictions equal to the target.
        An empty loader yields (0.0, 0.0).
    """
    # Use the device the model already lives on rather than a global.
    device = next(model.parameters()).device
    model.eval()
    total_samples = 0
    total_labels = 0
    total_correct = 0
    running_loss = 0.0

    with torch.no_grad():
        for inputs, labels in loader_val:
            inputs = inputs.to(device)
            # BCEWithLogitsLoss expects float targets; cast in case the loader
            # yields float64 (or integer) label tensors.
            labels = labels.to(device=device, dtype=torch.float32)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            predicted = outputs.sigmoid() > threshold

            running_loss += loss.item() * inputs.size(0)
            total_correct += (predicted == labels.bool()).sum().item()
            total_labels += labels.numel()
            total_samples += inputs.size(0)

    # max(..., 1) keeps an empty loader from dividing by zero.
    val_loss = running_loss / max(total_samples, 1)
    accuracy = total_correct / max(total_labels, 1) * 100
    print(f'Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%')
    return val_loss, accuracy

In [None]:
# Fine-tune for 10 epochs, keeping the per-epoch curves for the plots below.
train_losses, val_losses, train_accuracies, val_accuracies = train_model(
    model,
    criterion,
    optimizer,
    loader_train,
    loader_val,
    num_epochs=10,
)

In [None]:
# validate_model returns (loss, accuracy); unpack it so the accuracy can be
# formatted as a number (formatting the whole tuple raises TypeError).
test_loss, test_accuracy = validate_model(model, loader_test, criterion)
print(f'Test Accuracy: {test_accuracy:.2f}%')

In [None]:
# 2x2 grid of per-epoch curves: losses on the top row, accuracies on the bottom.
_metric_panels = [
    (train_losses, 'Training Loss', 'Loss', 'Training Loss Over Epochs'),
    (val_losses, 'Validation Loss', 'Loss', 'Validation Loss Over Epochs'),
    (train_accuracies, 'Training Accuracy', 'Accuracy (%)', 'Training Accuracy Over Epochs'),
    (val_accuracies, 'Validation Accuracy', 'Accuracy (%)', 'Validation Accuracy Over Epochs'),
]

plt.figure(figsize=(12, 10))
for _slot, (_series, _legend, _y_axis, _heading) in enumerate(_metric_panels, start=1):
    plt.subplot(2, 2, _slot)
    plt.plot(range(1, len(_series) + 1), _series, label=_legend)
    plt.xlabel('Epoch')
    plt.ylabel(_y_axis)
    plt.title(_heading)
    plt.legend()

plt.tight_layout()
plt.savefig('training_and_validation_metrics.png')
plt.show()