<a href="https://colab.research.google.com/github/Luck1e23/STA160-Team-11-Project/blob/laiq/STA160_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
!pip install torchxrayvision
!pip install iterative-stratification



### Packages

In [7]:
from PIL import Image
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import Adam
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from pathlib import Path

# Visualization tools
import torchvision
import torchvision.transforms.v2 as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt


# Pre-trained Model: torchxrayvision
import torchxrayvision as xrv
import skimage


### Pre-Trained Model

In [8]:
# XRV Pathology Classifiers :: NIH chest X-ray8
xrv_model = xrv.models.DenseNet(weights="densenet121-res224-nih")

Downloading weights...
If this fails you can run `wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt -O /root/.torchxrayvision/models_data/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt`
[██████████████████████████████████████████████████]


### File Paths

In [9]:
# Paths
file_path = '/content/drive/Shareddrives/STA_160/dataset/Data_Entry_2017.csv'
train_val = '/content/drive/Shareddrives/STA_160/dataset/train_val_list.txt'
test = '/content/drive/Shareddrives/STA_160/dataset/test_list.txt'
resized_root = '/content/dataset_resized'   # Where resized images will be saved

### Creating a 224 x 224 Resized Dataset

In [10]:
# Old Code for resizing the images before making a dataset class
# Purpose was to speed up the running of the code
'''
# Target Size
IMG_SIZE = (224, 224)

for root, dirs, files in os.walk(dataset_root):
    rel_path = os.path.relpath(root, dataset_root)
    save_dir = os.path.join(resized_root, rel_path)
    os.makedirs(save_dir, exist_ok=True)

    for f in files:
        if f.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(root, f)

            # Load and convert to grayscale (preserve pixel values)
            img = Image.open(img_path).convert("L")
            arr = np.array(img).astype(np.uint8)

            # Resize using bilinear interpolation
            img_resized = Image.fromarray(arr).resize(IMG_SIZE, Image.BILINEAR)

            # Build output filename (always PNG)
            base = os.path.splitext(f)[0]
            out_path = os.path.join(save_dir, base + ".png")

            # Save as PNG to avoid JPEG compression artifacts
            img_resized.save(out_path, format="PNG")

# Zip the resized dataset
!zip -r -q /content/NIH_resized.zip /content/dataset_resized

# Copy and upload to the shared drive
!cp /content/NIH_resized.zip /content/drive/Shareddrives/STA_160/
'''


'\n# Target Size\nIMG_SIZE = (224, 224)\n\nfor root, dirs, files in os.walk(dataset_root):\n    rel_path = os.path.relpath(root, dataset_root)\n    save_dir = os.path.join(resized_root, rel_path)\n    os.makedirs(save_dir, exist_ok=True)\n\n    for f in files:\n        if f.lower().endswith((\'.png\', \'.jpg\', \'.jpeg\')):\n            img_path = os.path.join(root, f)\n\n            # Load and convert to grayscale (preserve pixel values)\n            img = Image.open(img_path).convert("L")\n            arr = np.array(img).astype(np.uint8)\n\n            # Resize using bilinear interpolation\n            img_resized = Image.fromarray(arr).resize(IMG_SIZE, Image.BILINEAR)\n\n            # Build output filename (always PNG)\n            base = os.path.splitext(f)[0]\n            out_path = os.path.join(save_dir, base + ".png")\n\n            # Save as PNG to avoid JPEG compression artifacts\n            img_resized.save(out_path, format="PNG")\n\n# Zip the resized dataset\n!zip -r -q /co

In [11]:
# Unzipping the resized dataset from shared drives
!unzip -q /content/drive/Shareddrives/STA_160/NIH_resized.zip -d /content/dataset

!mkdir -p /content/dataset_resized
# Finds all image files inside subfolders
!find /content/dataset/content/dataset_resized/ -type f -exec mv -t /content/dataset_resized/ {} +
!rm -rf /content/dataset/content #remove folder

### Dataset

In [12]:
class NIHXrays(Dataset):
    def __init__(self, file_path, dataset_root, list_file=None, transform=None):
        self.data = pd.read_csv(file_path)
        self.dataset_root = dataset_root
        self.transform = transform

        # Optional filtering
        if list_file:
            with open(list_file, 'r') as f:
                image_list = {line.strip() for line in f.readlines()}
            self.data = self.data[self.data['Image Index'].isin(image_list)].reset_index(drop=True)

        # Create label map
        all_labels = set()
        for labels in self.data['Finding Labels']:
            for l in labels.split('|'):
                all_labels.add(l.strip())

        self.all_labels = sorted(all_labels)
        self.label_map = {label: i for i, label in enumerate(self.all_labels)}

        # Build multi-hot label vectors
        self.finding_labels = []
        for labels in self.data['Finding Labels']:
            vec = torch.zeros(len(self.all_labels))
            for l in labels.split('|'):
                if l.strip() in self.label_map:
                    vec[self.label_map[l.strip()]] = 1.0
            self.finding_labels.append(vec)

        self.finding_labels = torch.stack(self.finding_labels).float()

        # Recursively map image filenames to full paths
        self.image_map = {}
        for img_path in Path(dataset_root).rglob("*.png"):
            self.image_map[img_path.name] = str(img_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['Image Index']

        if img_name not in self.image_map:
            raise FileNotFoundError(f"Image {img_name} not found.")

        # Load Image
        img = Image.open(self.image_map[img_name]).convert("L")

        if self.transform is not None:
          img = self.transform(img)

        img = np.array(img).astype(np.float32) # PIL image --> NumPy [224,224]

        # XRV normalizaiton
        img = xrv.datasets.normalize(img, maxval=255)

        #   NumPy --> Tensor
        img = torch.from_numpy(img).unsqueeze(0) # [1, 224, 224]

        label = self.finding_labels[idx].float()

        return img, label, img_name

In [13]:
rand_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip()
])

train_val_data = NIHXrays(file_path, resized_root, list_file=train_val, transform=rand_transforms)
test_data      = NIHXrays(file_path, resized_root, list_file=test, transform=rand_transforms)

dataset_labels = train_val_data.all_labels
print("Using label order:", dataset_labels)

N_CLASSES = len(dataset_labels)

Using label order: ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass', 'No Finding', 'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax']


### Multi-Label Stratified Split for Training and Validation

In [14]:
# Convert labels to NumPy for the stratifier
y = train_val_data.finding_labels.numpy()   # shape: [N, C]
X = np.arange(len(train_val_data))          # dummy feature array

#Stratified Splitting
msss = MultilabelStratifiedShuffleSplit(
    n_splits=1,
    test_size=0.2,
    random_state=42
)
#only want the first split & get train and validation indices
train_idx, valid_idx = next(msss.split(X, y))

# Get train and validation datasets based on the indices
train_data = Subset(train_val_data, train_idx)
valid_data = Subset(train_val_data, valid_idx)

### Data Augmentation and Data Loader

In [15]:
# Create DataLoaders for training and validation
n = 32
train_loader = DataLoader(train_data, batch_size=n, shuffle=True, num_workers=2, pin_memory=True)
train_N = len(train_loader.dataset)
valid_loader = DataLoader(valid_data, batch_size=n, num_workers=2, pin_memory=True)
valid_N = len(valid_loader.dataset)

### Freezing Base Layers of Pre-Trained Model

In [16]:
xrv_labels = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax',
              'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia',
              'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia',
              '', '', '', '']  # padding because XRV uses 18 logits

# Build mapping from dataset_labels → positions in xrv logits
xrv_index_map = [
    xrv_labels.index(lbl) if lbl in xrv_labels else None
    for lbl in dataset_labels
]


class XRV_Finetune(nn.Module):
    def __init__(self, base_model, num_classes, xrv_index_map, no_finding_index):
        super().__init__()
        self.base = base_model
        self.xrv_index_map = xrv_index_map
        self.no_finding_index = no_finding_index

        # DenseNet outputs 18 XRV logits → project them to 15 classes
        self.feature_dim = len(xrv_index_map)

        # Add a learnable embedding vector for "No Finding"
        self.no_finding_emb = nn.Parameter(torch.zeros(1))

        # Classification head
        self.classifier = nn.Linear(self.feature_dim, num_classes)

    def forward(self, x):
        out = self.base(x)  # [B, 18]

        features = []
        for idx in self.xrv_index_map:
            if idx is not None:
                features.append(out[:, idx].unsqueeze(1))
            else:
                # Use learnable embedding
                emb = self.no_finding_emb.expand(out.size(0), 1)
                features.append(emb)

        features = torch.cat(features, dim=1)
        out = self.classifier(features)
        return out

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Pre-trained Model
xrv_model = xrv.models.DenseNet(weights="densenet121-res224-nih")

# Freeze
xrv_model.requires_grad_(False)
print("XRV frozen")


N_CLASSES = 15
no_finding_index = 14

my_model = XRV_Finetune(xrv_model, N_CLASSES, xrv_index_map = xrv_index_map, no_finding_index = no_finding_index).to(device)

XRV frozen


### Loss Function and Optimizer

In [17]:
# Finding class weights
def compute_pos_weights(subset):

    # Get all labels that are in the subset
    full_dataset = subset.dataset
    labels = full_dataset.finding_labels[subset.indices].numpy()

    pos_counts = labels.sum(axis = 0) #counts how many of each disease present in the dataset
    neg_counts = (labels == 0).sum(axis = 0) # counts how many times each disease was NOT present in the dataset
    pos_weight = neg_counts / (pos_counts + 1e-6) # Ratio of negatives to positives. 1e-6 to prevent dividing by 0.

    return torch.tensor(pos_weight, dtype = torch.float32) #Convert to tensor

pos_weight = compute_pos_weights(train_data).to(device) # Addressing class imbalance: add more weight to rare diseases
pos_weight = torch.clamp(pos_weight, max=30) # Because one of the values was > 600

# Using Focal Loss as loss function
class Focal_Loss(nn.Module):
    def __init__(self, gamma=2.0, pos_weight=None):
        super().__init__()
        self.gamma = gamma
        self .pos_weight = pos_weight

    def forward(self, logits, targets):
        # logits: (B, C)
        # targets: (B, C)

        # 1. Compute BCE with logits, elementwise (NO reduction)
        bce = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            pos_weight=self.pos_weight,
            reduction='none'
        )

        # 2. Compute p_t = sigmoid(logits) for focal term
        probs = torch.sigmoid(logits)
        p_t = probs * targets + (1 - probs) * (1 - targets)

        # 3. Apply focal modulation
        focal_factor = (1 - p_t) ** self.gamma

        # 4. Combine
        loss = focal_factor * bce

        # 5. Reduce mean
        return loss.mean()

loss_function = Focal_Loss(gamma=2.0, pos_weight=pos_weight)

optimizer = Adam(my_model.classifier.parameters(), lr=1e-3) #increase classifiers learning rate
my_model = my_model.to(device)

### Model Performance Metrics

In [18]:
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve
import torch

def pr_curve_thresholds(y_true, y_prob, min_thresh=0.1):
    num_labels = y_true.shape[1]
    thresholds = np.zeros(num_labels)

    for i in range(num_labels):
        p, r, t = precision_recall_curve(y_true[:, i], y_prob[:, i])
        f1 = 2 * p * r / (p + r + 1e-9)

        if len(t) == 0:
            thresholds[i] = min_thresh  # Ensure thresholds are not all zero
        else:
            thresholds[i] = t[np.argmax(f1)]

    return thresholds

def compute_f1(y_true, y_prob, thresholds):

    y_true = y_true.cpu()
    y_prob = y_prob.cpu()

    # Converts thresholds into tensors and match the shape of y_prob
    thresholds = torch.tensor(thresholds, dtype = y_prob.dtype)
    y_pred = (y_prob >= thresholds.unsqueeze(0)).int()

    return f1_score(y_true, y_pred, average="macro", zero_division=0)

def compute_auc(y_true, y_prob):

    y_true = y_true.detach().cpu().numpy() #binary labels
    y_prob = y_prob.detach().cpu().numpy() #sigmoid probabilities

    return roc_auc_score(y_true, y_prob, average = "macro")

### Training and Validation Functions

In [19]:
def train(model, train_loader, optimizer, loss_fn, device, check_grad=False):
    model.train()

    total_loss = 0
    all_probs = []
    all_labels = []

    for imgs, labels, _ in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Forward
        outputs = model(imgs)

        # Compute loss
        batch_loss = loss_fn(outputs, labels)

        # Backprop
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

        # Probability for AUC
        probs = torch.sigmoid(outputs)

        all_probs.append(probs.detach().cpu())
        all_labels.append(labels.detach().cpu())

    if check_grad:
        print("Last Gradient:")
        for p in model.parameters():
            if p.grad is not None:
                print(p.grad)

    # Compute AUROC at end of epoch
    all_probs = torch.cat(all_probs)
    all_labels = torch.cat(all_labels)

    auc = compute_auc(all_labels, all_probs)
    epoch_loss = total_loss / len(train_loader)

    print("Train - Loss: {:.4f}, AUC: {:.4f}".format(epoch_loss, auc))

    return epoch_loss, auc

def validate(model, valid_loader, loss_fn, device):
    model.eval()

    total_loss = 0
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels, _ in valid_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            # Compute loss
            outputs = model(imgs)
            batch_loss = loss_fn(outputs, labels)

            total_loss += batch_loss.item()

            # Get probabilities for F1 and AUC
            probs = torch.sigmoid(outputs)

            all_probs.append(probs.detach().cpu())
            all_labels.append(labels.detach().cpu())

    #Store all probabilities and labels
    all_probs = torch.cat(all_probs)
    all_labels = torch.cat(all_labels)

    # Compute loss for every epoch
    epoch_loss = total_loss / len(valid_loader)

    return epoch_loss, all_probs, all_labels


### Fit Model and Evaluations

In [20]:
epochs = 5

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}")

    train_loss, train_auc = train(
        model=my_model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_function,
        device=device,
        check_grad=False
    )

    valid_loss, valid_probs, valid_labels = validate(
        model=my_model,
        valid_loader=valid_loader,
        loss_fn=loss_function,
        device=device
    )

    #Find thresholds per label from validation
    thresholds = pr_curve_thresholds(valid_labels, valid_probs)

    # Find F1 and AUC for validation
    valid_f1 = compute_f1(valid_labels, valid_probs, thresholds)
    valid_auc = compute_auc(valid_labels, valid_probs)

    print("Valid - Loss: {:.4f}, F1: {:.4f}, AUC: {:.4f}".format(valid_loss, valid_f1, valid_auc))


Epoch 1
Train - Loss: 0.2503, AUC: 0.6075
Valid - Loss: 0.2405, F1: 0.1937, AUC: 0.6661

Epoch 2
Train - Loss: 0.2391, AUC: 0.6677
Valid - Loss: 0.2361, F1: 0.1974, AUC: 0.6854

Epoch 3
Train - Loss: 0.2363, AUC: 0.6819
Valid - Loss: 0.2346, F1: 0.2025, AUC: 0.7043

Epoch 4
Train - Loss: 0.2351, AUC: 0.6943
Valid - Loss: 0.2332, F1: 0.2028, AUC: 0.7075

Epoch 5
Train - Loss: 0.2345, AUC: 0.6982
Valid - Loss: 0.2324, F1: 0.2070, AUC: 0.7153


### Fine-Tuning Model

In [21]:
# Fine Tuning the model

# Unfreeze the last block model
for name, param in my_model.base.features.named_parameters():
    if "denseblock4" in name or "transition3" in name or "norm5" in name:
        param.requires_grad = True

#Assign different learning rates
optimizer = Adam([
    {'params': [p for p in my_model.base.parameters() if p.requires_grad], 'lr': 1e-6},
    {'params': my_model.classifier.parameters(), 'lr': 1e-3}
])

In [22]:
epochs = 10


for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}")

    train_loss, train_auc = train(
        model=my_model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_function,
        device=device,
        check_grad=False
    )

    valid_loss, valid_probs, valid_labels = validate(
        model=my_model,
        valid_loader=valid_loader,
        loss_fn=loss_function,
        device=device
    )

    #Find thresholds per label from validation
    thresholds = pr_curve_thresholds(valid_labels, valid_probs)

    # Find F1 and AUC for validation
    valid_f1 = compute_f1(valid_labels, valid_probs, thresholds)
    valid_auc = compute_auc(valid_labels, valid_probs)

    print("Valid - Loss: {:.4f}, F1: {:.4f}, AUC: {:.4f}".format(valid_loss, valid_f1, valid_auc))


Epoch 1
Train - Loss: 0.2335, AUC: 0.7063
Valid - Loss: 0.2318, F1: 0.2068, AUC: 0.7163

Epoch 2
Train - Loss: 0.2329, AUC: 0.7072
Valid - Loss: 0.2311, F1: 0.2110, AUC: 0.7213

Epoch 3
Train - Loss: 0.2321, AUC: 0.7131
Valid - Loss: 0.2308, F1: 0.2089, AUC: 0.7209

Epoch 4
Train - Loss: 0.2315, AUC: 0.7158
Valid - Loss: 0.2295, F1: 0.2084, AUC: 0.7246

Epoch 5
Train - Loss: 0.2311, AUC: 0.7172
Valid - Loss: 0.2293, F1: 0.2121, AUC: 0.7252

Epoch 6
Train - Loss: 0.2307, AUC: 0.7178
Valid - Loss: 0.2296, F1: 0.2120, AUC: 0.7266

Epoch 7
Train - Loss: 0.2303, AUC: 0.7191
Valid - Loss: 0.2287, F1: 0.2118, AUC: 0.7287

Epoch 8
Train - Loss: 0.2304, AUC: 0.7188
Valid - Loss: 0.2289, F1: 0.2123, AUC: 0.7266

Epoch 9
Train - Loss: 0.2296, AUC: 0.7227
Valid - Loss: 0.2290, F1: 0.2122, AUC: 0.7263

Epoch 10
Train - Loss: 0.2294, AUC: 0.7236
Valid - Loss: 0.2280, F1: 0.2123, AUC: 0.7299
