<a href="https://colab.research.google.com/github/Luck1e23/STA160-Team-11-Project/blob/Tina/STA160_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
'''
#Unzip to local SSD
!unzip -q /content/drive/Shareddrives/STA_160/archive.zip -d /content/dataset_raw
'''

In [None]:
!pip install torchxrayvision

Collecting torchxrayvision
  Downloading torchxrayvision-1.4.0-py3-none-any.whl.metadata (18 kB)
Downloading torchxrayvision-1.4.0-py3-none-any.whl (29.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.0/29.0 MB[0m [31m62.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchxrayvision
Successfully installed torchxrayvision-1.4.0


In [None]:
from PIL import Image
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam

# Visualization tools
import torchvision
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt


# Pre-trained Model: torchxrayvision
import torchxrayvision as xrv
import skimage


In [None]:
# XRV Pathology Classifiers :: NIH chest X-ray8
xrv_model = xrv.models.DenseNet(weights="densenet121-res224-nih")

Downloading weights...
If this fails you can run `wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt -O /root/.torchxrayvision/models_data/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt`
[██████████████████████████████████████████████████]


In [None]:

# Paths
file_path = '/content/drive/Shareddrives/STA_160/dataset/Data_Entry_2017.csv'
train_val = '/content/drive/Shareddrives/STA_160/dataset/train_val_list.txt'
test = '/content/drive/Shareddrives/STA_160/dataset/test_list.txt'
dataset_root = '/content/dataset'           # Original dataset (copied from Drive)
resized_root = '/content/dataset_resized'   # Where resized images will be saved

In [None]:
# Old Code for resizing the images before making a dataset class
# Purpose was to speed up the running of the code
'''
# Target Size
IMG_SIZE = (224, 224)

# Create folder for resized images
os.makedirs(resized_root, exist_ok=True)

# Loop through all images
for root, dirs, files in os.walk(dataset_root):
    # Keep folder structure
    rel_path = os.path.relpath(root, dataset_root)
    save_dir = os.path.join(resized_root, rel_path)
    os.makedirs(save_dir, exist_ok=True)

    for f in files:
        if f.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(root, f)
            img = Image.open(img_path).convert("L")  # grayscale
            img = img.resize(IMG_SIZE, Image.BILINEAR)
            img.save(os.path.join(save_dir, f))

# Zip the resized dataset
!zip -r -q /content/NIH_resized.zip /content/dataset_resized

# Copy and upload to the shared drive
!cp /content/NIH_resized.zip /content/drive/Shareddrives/STA_160/
'''

In [None]:
# Unzipping the resized dataset from shared drives
!unzip -q /content/drive/Shareddrives/STA_160/NIH_resized.zip -d /content/dataset

# Make sure the destination folder exists
!mkdir -p /content/dataset_resized

# Move contents only if source exists
!if [ -d /content/dataset/content/dataset_resized ]; then mv /content/dataset/content/dataset_resized/* /content/dataset_resized/; fi

# Remove the empty nested folder safely
!rm -rf /content/dataset/content


In [None]:
# Dataset Class
class NIHXrays(Dataset):
    def __init__(self, file_path, dataset_root, list_file=None):
        self.data = pd.read_csv(file_path)
        self.dataset_root = dataset_root

        # Optional filtering
        if list_file:
            with open(list_file, 'r') as f:
                image_list = {line.strip() for line in f.readlines()}
            self.data = self.data[self.data['Image Index'].isin(image_list)].reset_index(drop=True)

        # Create label map
        all_labels = set()
        for labels in self.data['Finding Labels']:
            for l in labels.split('|'):
                all_labels.add(l.strip())

        self.all_labels = sorted(all_labels)
        self.label_map = {label: i for i, label in enumerate(self.all_labels)}

        # Build multi-hot label vectors
        self.finding_labels = []
        for labels in self.data['Finding Labels']:
            vec = torch.zeros(len(self.all_labels))
            for l in labels.split('|'):
                if l.strip() in self.label_map:
                    vec[self.label_map[l.strip()]] = 1.0
            self.finding_labels.append(vec)

        self.finding_labels = torch.stack(self.finding_labels)

        # Map image filenames to paths
        self.image_map = {}
        for root, dirs, files in os.walk(dataset_root):
            for f in files:
                if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_map[f] = os.path.join(root, f)

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['Image Index']

        if img_name not in self.image_map:
            raise FileNotFoundError(f"Image {img_name} not found.")

        img = Image.open(self.image_map[img_name]).convert("L")
        img = self.transform(img) # tensor [1,224,224]

        # Convert to numpy for XRV normalization
        img = img[0].numpy() # [224,224]
        img = xrv.datasets.normalize(img, maxval=1.0)

        img = torch.from_numpy(img).unsqueeze(0)

        label = self.finding_labels[idx]
        return img, label, img_name

In [None]:
train_val_data = NIHXrays(file_path, resized_root, list_file = train_val)
test_data = NIHXrays(file_path, resized_root, list_file = test)

In [None]:
# Split the train_val_data into training and validation datasets
train_data, valid_data = torch.utils.data.random_split(train_val_data, [0.8, 0.2])

# Create DataLoaders for training and validation
n = 32
train_loader = DataLoader(train_data, batch_size=n, shuffle=True, num_workers=2, pin_memory=True)
train_N = len(train_loader.dataset)
valid_loader = DataLoader(valid_data, batch_size=n, num_workers=2, pin_memory=True)
valid_N = len(valid_loader.dataset)


# Data Augmentation
IMG_WIDTH, IMG_HEIGHT = (224, 224)
rand_transforms = transforms.Compose([
    transforms.RandomRotation(25),
    transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale = (0.8, 1), ratio = (1, 1)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2)
])

In [None]:

class XRV_Finetune(nn.Module):
    def __init__(self, base_model, num_classes):
        super().__init__()
        self.base = base_model

        # Freeze backbone
        for p in self.base.parameters():
            p.requires_grad = False

        # Number of outputs from XRV model
        in_features = len(base_model.pathologies)

        # New classification head
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        out = self.base(x)            # shape [B, in_features]
        out = self.classifier(out)    # shape [B, num_classes]
        return out

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Pre-trained Model
xrv_model = xrv.models.DenseNet(weights="densenet121-res224-nih")

# Freeze
xrv_model.requires_grad_(False)
print("XRV frozen")


N_CLASSES = 15

my_model = XRV_Finetune(xrv_model, N_CLASSES).to(device)

# Finding class weights
def compute_pos_weights(subset):

    # Get all labels that are in the subset
    full_dataset = subset.dataset
    labels = full_dataset.finding_labels[subset.indices].numpy()

    pos_counts = labels.sum(axis = 0) #counts how many of each disease present in the dataset
    neg_counts = (labels == 0).sum(axis = 0) # counts how many times each disease was NOT present in the dataset
    pos_weight = neg_counts / (pos_counts + 1e-6) # Ratio of negatives to positives. 1e-6 to prevent dividing by 0.

    return torch.tensor(pos_weight, dtype = torch.float32) #Convert to tensor

# Choosing loss function
pos_weight = compute_pos_weights(train_data).to(device) # Addressing class imbalance: add more weight to rare diseases
loss_function = nn.BCEWithLogitsLoss(pos_weight = pos_weight) #Since data is multi-label binary classification
optimizer = Adam(my_model.parameters())
my_model = my_model.to(device)

XRV frozen


In [19]:
from sklearn.metrics import f1_score, roc_auc_score
import torch

def compute_f1(y_true, y_pred):

    y_true = y_true.cpu()
    y_pred = y_pred.cpu()

    return f1_score(y_true, y_pred, average="macro", zero_division=0)

def compute_auc(y_true, y_prob):

    y_true = y_true.detach().cpu().numpy() #binary labels
    y_prob = y_prob.detach().cpu().numpy() #sigmoid probabilities

    return roc_auc_score(y_true, y_prob, average = "macro")

def train(model, train_loader, optimizer, loss_fn, device, check_grad=False):
    model.train()

    total_loss = 0
    all_probs = []
    all_preds = []
    all_labels = []

    for imgs, labels, _ in train_loader:

        augmented_imgs = []

        for img in imgs:  # iterate batch
            pil_img = transforms.ToPILImage()(img)       # convert tensor → PIL
            pil_img = rand_transforms(pil_img)           # apply augmentation
            aug_img = transforms.ToTensor()(pil_img)     # back to tensor
            augmented_imgs.append(aug_img)

        imgs = torch.stack(augmented_imgs).to(device)    # [B,1,224,224]


        imgs = imgs.to(device)
        labels = labels.to(device)

        # Forward
        outputs = model(imgs)

        # Compute loss
        batch_loss = loss_fn(outputs, labels)

        # Backprop
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

        # Convert logits → predictions
        probs = torch.sigmoid(outputs)
        preds = (probs > 0.5).float()

        all_probs.append(probs.detach().cpu())
        all_preds.append(preds.cpu())
        all_labels.append(labels.detach().cpu())

    if check_grad:
        print("Last Gradient:")
        for p in model.parameters():
            if p.grad is not None:
                print(p.grad)

    # Compute F1 and AUROC at end of epoch
    all_probs = torch.cat(all_probs)
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)
    auc = compute_auc(all_labels, all_probs)
    epoch_loss = total_loss / train_N

    print("Train - Loss: {:.4f}, F1: {:.4f}, AUC: {:.4f}".format(epoch_loss, f1, auc))

    return epoch_loss, f1, auc

def validate(model, valid_loader, loss_fn, device):
    model.eval()

    total_loss = 0
    all_probs = []
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels, _ in valid_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            # Compute loss
            outputs = model(imgs)
            batch_loss = loss_fn(outputs, labels)

            total_loss += batch_loss.item()

            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            all_probs.append(probs.detach().cpu())
            all_preds.append(preds.cpu())
            all_labels.append(labels.detach().cpu())

    # Compute F1 and AUROC at end of epoch
    all_probs = torch.cat(all_probs)
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)
    auc = compute_auc(all_labels, all_probs)
    epoch_loss = total_loss / valid_N

    print("Valid  - Loss: {:.4f}, F1: {:.4f}, AUC: {:.4f}".format(epoch_loss, f1, auc))

    return epoch_loss, f1, auc


In [22]:
epochs = 3

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}")

    train_loss, train_f1, train_auc = train(
        model=my_model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_function,
        device=device,
        check_grad=False
    )

    valid_loss, valid_f1, valid_auc = validate(
        model=my_model,
        valid_loader=valid_loader,
        loss_fn=loss_function,
        device=device
    )


Epoch 1
Train - Loss: 0.0399, F1: 0.1104, AUC: 0.5009
Valid  - Loss: 0.0404, F1: 0.0921, AUC: 0.5001

Epoch 2
Train - Loss: 0.0399, F1: 0.1095, AUC: 0.5051
Valid  - Loss: 0.0404, F1: 0.0923, AUC: 0.4999

Epoch 3


KeyboardInterrupt: 

In [None]:
# Fine Tuning the model


# Unfreeze the base model
xrv_model.requires_grad_(True)
optimizer = Adam(my_model.parameters(), lr=.000001)