<a href="https://colab.research.google.com/github/Luck1e23/STA160-Team-11-Project/blob/Tina/STA160_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install torchxrayvision
!pip install iterative-stratification



### Packages

In [3]:
from PIL import Image
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import Adam
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from pathlib import Path

# Visualization tools
import torchvision
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt


# Pre-trained Model: torchxrayvision
import torchxrayvision as xrv
import skimage


### Pre-Trained Model

In [4]:
# XRV Pathology Classifiers :: NIH chest X-ray8
xrv_model = xrv.models.DenseNet(weights="densenet121-res224-nih")

### File Paths

In [5]:
# Paths
file_path = '/content/drive/Shareddrives/STA_160/dataset/Data_Entry_2017.csv'
train_val = '/content/drive/Shareddrives/STA_160/dataset/train_val_list.txt'
test = '/content/drive/Shareddrives/STA_160/dataset/test_list.txt'
dataset_root = '/content/dataset_raw'           # Original dataset (copied from Drive)
resized_root = '/content/dataset_resized'   # Where resized images will be saved

### Creating a 224 x 224 Resized Dataset

In [None]:
# Old Code for resizing the images before making a dataset class
# Purpose was to speed up the running of the code
'''
# Target Size
IMG_SIZE = (224, 224)

for root, dirs, files in os.walk(dataset_root):
    rel_path = os.path.relpath(root, dataset_root)
    save_dir = os.path.join(resized_root, rel_path)
    os.makedirs(save_dir, exist_ok=True)

    for f in files:
        if f.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(root, f)

            # Load and convert to grayscale (preserve pixel values)
            img = Image.open(img_path).convert("L")
            arr = np.array(img).astype(np.uint8)

            # Resize using bilinear interpolation
            img_resized = Image.fromarray(arr).resize(IMG_SIZE, Image.BILINEAR)

            # Build output filename (always PNG)
            base = os.path.splitext(f)[0]
            out_path = os.path.join(save_dir, base + ".png")

            # Save as PNG to avoid JPEG compression artifacts
            img_resized.save(out_path, format="PNG")

# Zip the resized dataset
!zip -r -q /content/NIH_resized.zip /content/dataset_resized

# Copy and upload to the shared drive
!cp /content/NIH_resized.zip /content/drive/Shareddrives/STA_160/
'''


### Dataset

In [None]:
# Dataset Class
class NIHXrays(Dataset):
    def __init__(self, file_path, dataset_root, list_file=None):
        self.data = pd.read_csv(file_path)
        self.dataset_root = dataset_root

        # Optional filtering
        if list_file:
            with open(list_file, 'r') as f:
                image_list = {line.strip() for line in f.readlines()}
            self.data = self.data[self.data['Image Index'].isin(image_list)].reset_index(drop=True)

        # Create label map
        all_labels = set()
        for labels in self.data['Finding Labels']:
            for l in labels.split('|'):
                all_labels.add(l.strip())

        self.all_labels = sorted(all_labels)
        self.label_map = {label: i for i, label in enumerate(self.all_labels)}

        # Build multi-hot label vectors
        self.finding_labels = []
        for labels in self.data['Finding Labels']:
            vec = torch.zeros(len(self.all_labels))
            for l in labels.split('|'):
                if l.strip() in self.label_map:
                    vec[self.label_map[l.strip()]] = 1.0
            self.finding_labels.append(vec)

        self.finding_labels = torch.stack(self.finding_labels).float()

        # Recursively map image filenames to full paths
        self.image_map = {}
        for img_path in Path(dataset_root).rglob("*.png"):
            self.image_map[img_path.name] = str(img_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['Image Index']

        if img_name not in self.image_map:
            raise FileNotFoundError(f"Image {img_name} not found.")

        # Load Image
        img = Image.open(self.image_map[img_name]).convert("L")
        img = np.array(img).astype(np.float32) # PIL image --> NumPy [224,224]

        # XRV normalizaiton
        img = xrv.datasets.normalize(img, maxval=255)

        #   NumPy --> Tensor
        img = torch.from_numpy(img).unsqueeze(0) # [1, 224, 224]

        label = self.finding_labels[idx].float()

        return img, label, img_name

In [None]:
train_val_data = NIHXrays(file_path, resized_root, list_file = train_val)
test_data = NIHXrays(file_path, resized_root, list_file = test)

### Multi-Label Stratified Split for Training and Validation

In [None]:
# Convert labels to NumPy for the stratifier
y = train_val_data.finding_labels.numpy()   # shape: [N, C]
X = np.arange(len(train_val_data))          # dummy feature array

#Stratified Splitting
msss = MultilabelStratifiedShuffleSplit(
    n_splits=1,
    test_size=0.2,
    random_state=42
)
#only want the first split & get train and validation indices
train_idx, valid_idx = next(msss.split(X, y))

# Get train and validation datasets based on the indices
train_data = Subset(train_val_data, train_idx)
valid_data = Subset(train_val_data, valid_idx)

### Data Augmentation and Data Loader

In [None]:
# Create DataLoaders for training and validation
n = 32
train_loader = DataLoader(train_data, batch_size=n, shuffle=True, num_workers=2, pin_memory=True)
train_N = len(train_loader.dataset)
valid_loader = DataLoader(valid_data, batch_size=n, num_workers=2, pin_memory=True)
valid_N = len(valid_loader.dataset)

# Data Augmentation
IMG_WIDTH, IMG_HEIGHT = (224, 224)
rand_transforms = transforms.Compose([
    transforms.RandomRotation(25),
    transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale = (0.8, 1), ratio = (1, 1)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2)
])

### Loss Function & Freezing Base Layers of Pre-Trained Model

In [None]:
# XRV pathologies
xrv_labels = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax',
              'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia',
              'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia',
              '', '', '', '']

# Your 15 dataset labels (example)
dataset_labels = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax',
                  'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia',
                  'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'No findings']

# Create mapping from dataset_labels → xrv_labels index
xrv_index_map = [xrv_labels.index(lbl) if lbl in xrv_labels else None for lbl in dataset_labels]


class XRV_Finetune(nn.Module):
    def __init__(self, base_model, num_classes, xrv_index_map):
        super().__init__()
        self.base = base_model
        self.xrv_index_map = xrv_index_map
        self.in_features = len(xrv_index_map)  # 15, including "No findings"

        # Freeze backbone
        for p in self.base.parameters():
            p.requires_grad = False

        # New classification head
        self.classifier = nn.Linear(self.in_features, num_classes)

    def forward(self, x):
        out = self.base(x)  # [B, 18] from DenseNet-121
        xrv_features = []

        for idx in self.xrv_index_map:
            if idx is not None:
                xrv_features.append(out[:, idx].unsqueeze(1)) # Selects entire batch of specific disease and goes from [Batch size] --> [Batch size, 1]
            else:
                # placeholder zero for "No findings", classifier can learn weights
                xrv_features.append(torch.zeros(out.size(0), 1, device=out.device))

        xrv_features = torch.cat(xrv_features, dim=1)  # [Batch size, 15]
        out = self.classifier(xrv_features)            # [Batch size, 15]

        return out

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Pre-trained Model
xrv_model = xrv.models.DenseNet(weights="densenet121-res224-nih")

# Freeze
xrv_model.requires_grad_(False)
print("XRV frozen")


N_CLASSES = 15

my_model = XRV_Finetune(xrv_model, N_CLASSES, xrv_index_map = xrv_index_map).to(device)

# Finding class weights
def compute_pos_weights(subset):

    # Get all labels that are in the subset
    full_dataset = subset.dataset
    labels = full_dataset.finding_labels[subset.indices].numpy()

    pos_counts = labels.sum(axis = 0) #counts how many of each disease present in the dataset
    neg_counts = (labels == 0).sum(axis = 0) # counts how many times each disease was NOT present in the dataset
    pos_weight = neg_counts / (pos_counts + 1e-6) # Ratio of negatives to positives. 1e-6 to prevent dividing by 0.

    return torch.tensor(pos_weight, dtype = torch.float32) #Convert to tensor

# Choosing loss function
pos_weight = compute_pos_weights(train_data).to(device) # Addressing class imbalance: add more weight to rare diseases
pos_weight = torch.clamp(pos_weight, max=30) # Because one of the values was > 600

loss_function = nn.BCEWithLogitsLoss(pos_weight = pos_weight) #Since data is multi-label binary classification
optimizer = Adam(my_model.classifier.parameters(), lr=1e-2) #increase classifiers learning rate
my_model = my_model.to(device)

XRV frozen


### Model Performance Metrics

In [None]:
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve
import torch

def pr_curve_thresholds(y_true, y_prob):
    num_labels = y_true.shape[1]
    thresholds = np.zeros(num_labels)

    for i in range(num_labels):
        p, r, t = precision_recall_curve(y_true[:, i], y_prob[:, i])
        f1 = 2 * p * r / (p + r + 1e-9)
        thresholds[i] = t[np.argmax(f1)]

    return thresholds

def compute_f1(y_true, y_prob, thresholds):

    y_true = y_true.cpu()
    y_prob = y_prob.cpu()

    # Converts thresholds into tensors and match the shape of y_prob
    thresholds = torch.tensor(thresholds, dtype = y_prob.dtype)
    y_pred = (y_prob >= thresholds).int()

    return f1_score(y_true, y_pred, average="macro", zero_division=0)

def compute_auc(y_true, y_prob):

    y_true = y_true.detach().cpu().numpy() #binary labels
    y_prob = y_prob.detach().cpu().numpy() #sigmoid probabilities

    return roc_auc_score(y_true, y_prob, average = "macro")

### Training and Validation Functions

In [None]:
def train(model, train_loader, optimizer, loss_fn, device, check_grad=False):
    model.train()

    total_loss = 0
    all_probs = []
    all_labels = []

    for imgs, labels, _ in train_loader:

        # Applying Data Augmentation
        augmented_imgs = []

        for img in imgs:  # iterate batch
            pil_img = transforms.ToPILImage()(img)       # convert tensor → PIL
            pil_img = rand_transforms(pil_img)           # apply augmentation
            aug_img = transforms.ToTensor()(pil_img)     # back to tensor
            augmented_imgs.append(aug_img)

        imgs = torch.stack(augmented_imgs).to(device)    # [B,1,224,224]
        labels = labels.to(device)

        # Forward
        outputs = model(imgs)

        # Compute loss
        batch_loss = loss_fn(outputs, labels)

        # Backprop
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

        # Probability for AUC
        probs = torch.sigmoid(outputs)

        all_probs.append(probs.detach().cpu())
        all_labels.append(labels.detach().cpu())

    if check_grad:
        print("Last Gradient:")
        for p in model.parameters():
            if p.grad is not None:
                print(p.grad)

    # Compute AUROC at end of epoch
    all_probs = torch.cat(all_probs)
    all_labels = torch.cat(all_labels)

    auc = compute_auc(all_labels, all_probs)
    epoch_loss = total_loss / train_N

    print("Train - Loss: {:.4f}, AUC: {:.4f}".format(epoch_loss, auc))

    return epoch_loss, auc

def validate(model, valid_loader, loss_fn, device):
    model.eval()

    total_loss = 0
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels, _ in valid_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            # Compute loss
            outputs = model(imgs)
            batch_loss = loss_fn(outputs, labels)

            total_loss += batch_loss.item()

            # Get probabilities for F1 and AUC
            probs = torch.sigmoid(outputs)

            all_probs.append(probs.detach().cpu())
            all_labels.append(labels.detach().cpu())

    #Store all probabilities and labels
    all_probs = torch.cat(all_probs)
    all_labels = torch.cat(all_labels)

    # Compute loss for every epoch
    epoch_loss = total_loss / valid_N

    return epoch_loss, all_probs, all_labels


### Fit Model and Evaluations

In [None]:
epochs = 10


for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}")

    train_loss, train_auc = train(
        model=my_model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_function,
        device=device,
        check_grad=False
    )

    valid_loss, valid_probs, valid_labels = validate(
        model=my_model,
        valid_loader=valid_loader,
        loss_fn=loss_function,
        device=device
    )

    #Find thresholds per label from validation
    thresholds = pr_curve_thresholds(valid_labels, valid_probs)

    # Find F1 and AUC for validation
    valid_f1 = compute_f1(valid_labels, valid_probs, thresholds)
    valid_auc = compute_auc(valid_labels, valid_probs)

    print("Valid - Loss: {:.4f}, F1: {:.4f}, AUC: {:.4f}".format(valid_loss, valid_f1, valid_auc))


Epoch 1
Train - Loss: 0.0323, AUC: 0.5026
Valid - Loss: 0.0324, F1: 0.1258, AUC: 0.4996

Epoch 2
Train - Loss: 0.0323, AUC: 0.5034
Valid - Loss: 0.0326, F1: 0.1258, AUC: 0.5001

Epoch 3
Train - Loss: 0.0322, AUC: 0.5137
Valid - Loss: 0.0331, F1: 0.1258, AUC: 0.5003

Epoch 4
Train - Loss: 0.0322, AUC: 0.5125
Valid - Loss: 0.0332, F1: 0.1258, AUC: 0.5002

Epoch 5
Train - Loss: 0.0322, AUC: 0.5085
Valid - Loss: 0.0337, F1: 0.1258, AUC: 0.5003

Epoch 6
Train - Loss: 0.0323, AUC: 0.5110
Valid - Loss: 0.0338, F1: 0.1258, AUC: 0.5007

Epoch 7
Train - Loss: 0.0322, AUC: 0.5121
Valid - Loss: 0.0343, F1: 0.1258, AUC: 0.5003

Epoch 8
Train - Loss: 0.0322, AUC: 0.5147
Valid - Loss: 0.0345, F1: 0.1258, AUC: 0.5004

Epoch 9
Train - Loss: 0.0323, AUC: 0.5119
Valid - Loss: 0.0343, F1: 0.1258, AUC: 0.5002

Epoch 10
Train - Loss: 0.0322, AUC: 0.5117
Valid - Loss: 0.0347, F1: 0.1258, AUC: 0.5003


### Fine-Tuning Model

In [None]:
# Fine Tuning the model

# Unfreeze the last block model
unfreeze_blocks = ["denseblock4", "transition3", "norm5"]
for name, param in my_model.base.features.named_parameters():
    if any(block in name for block in unfreeze_blocks):
        param.requires_grad = True

#Assign different learning rates
optimizer = Adam([
    {'params': [p for p in my_model.base.parameters() if p.requires_grad], 'lr': 1e-6},
    {'params': my_model.classifier.parameters(), 'lr': 1e-3}
])

In [None]:
epochs = 8


for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}")

    train_loss, train_auc = train(
        model=my_model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_function,
        device=device,
        check_grad=False
    )

    valid_loss, valid_probs, valid_labels = validate(
        model=my_model,
        valid_loader=valid_loader,
        loss_fn=loss_function,
        device=device
    )

    #Find thresholds per label from validation
    thresholds = pr_curve_thresholds(valid_labels, valid_probs)

    # Find F1 and AUC for validation
    valid_f1 = compute_f1(valid_labels, valid_probs, thresholds)
    valid_auc = compute_auc(valid_labels, valid_probs)

    print("Valid - Loss: {:.4f}, F1: {:.4f}, AUC: {:.4f}".format(valid_loss, valid_f1, valid_auc))


Epoch 1


KeyboardInterrupt: 

In [None]:
print(xrv_model.pathologies)
print(len(xrv_model.pathologies))


['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', '', '', '']
18


In [None]:
print(pos_weight)

tensor([  9.4831,  49.7850,  29.4665,  62.6799,   9.0377,  59.2962,  68.7782,
        658.2381,   5.2956,  20.7195,   0.7123,  17.3705,  38.0632,  96.2191,
         31.6509], device='cuda:0')


In [None]:
print(my_model)


XRV_Finetune(
  (base): XRV-DenseNet121-densenet121-res224-nih
  (classifier): Linear(in_features=14, out_features=15, bias=True)
)
