<a href="https://colab.research.google.com/github/Luck1e23/STA160-Team-11-Project/blob/laiq/Resnet50.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install torchxrayvision
!pip install iterative-stratification
!pip install validators matplotlib

Collecting torchxrayvision
  Downloading torchxrayvision-1.4.0-py3-none-any.whl.metadata (18 kB)
Downloading torchxrayvision-1.4.0-py3-none-any.whl (29.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.0/29.0 MB[0m [31m41.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchxrayvision
Successfully installed torchxrayvision-1.4.0
Collecting iterative-stratification
  Downloading iterative_stratification-0.1.9-py3-none-any.whl.metadata (1.3 kB)
Downloading iterative_stratification-0.1.9-py3-none-any.whl (8.5 kB)
Installing collected packages: iterative-stratification
Successfully installed iterative-stratification-0.1.9
Collecting validators
  Downloading validators-0.35.0-py3-none-any.whl.metadata (3.9 kB)
Downloading validators-0.35.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.7/44.7 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: validators
Successfully install

In [3]:
from PIL import Image
import os
import pandas as pd
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import Adam
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from pathlib import Path

# Visualization tools
import torchvision
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as tvF
import torch.nn.functional as F
import matplotlib.pyplot as plt


# Pre-trained Model: torchxrayvision
import torchxrayvision as xrv
import skimage
from torchvision.models import resnet50, ResNet50_Weights

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [5]:
file_path = '/content/drive/Shareddrives/STA_160/dataset/Data_Entry_2017.csv'
train_val = '/content/drive/Shareddrives/STA_160/dataset/train_val_list.txt'
test = '/content/drive/Shareddrives/STA_160/dataset/test_list.txt'
resized_root = '/content/dataset_resized'

In [6]:
!unzip -q /content/drive/Shareddrives/STA_160/NIH_resized.zip -d /content/dataset

!mkdir -p /content/dataset_resized
# Finds all image files inside subfolders
!find /content/dataset/content/dataset_resized/ -type f -exec mv -t /content/dataset_resized/ {} +
!rm -rf /content/dataset/content

In [7]:
### Dataset

class NIHXrays(Dataset):
    def __init__(self, file_path, dataset_root, list_file=None, transform=None):
        self.data = pd.read_csv(file_path)
        self.dataset_root = dataset_root
        self.transform = transform

        # Optional filtering by train_val_list or test_list
        if list_file:
            with open(list_file, 'r') as f:
                image_list = {line.strip() for line in f.readlines()}
            self.data = self.data[self.data['Image Index'].isin(image_list)].reset_index(drop=True)

        # Create label map
        all_labels = set()
        for labels in self.data['Finding Labels']:
            for l in labels.split('|'):
                all_labels.add(l.strip())

        self.all_labels = sorted(all_labels)
        self.label_map = {label: i for i, label in enumerate(self.all_labels)}

        # Build multi hot label vectors
        self.finding_labels = []
        for labels in self.data['Finding Labels']:
            vec = torch.zeros(len(self.all_labels))
            for l in labels.split('|'):
                if l.strip() in self.label_map:
                    vec[self.label_map[l.strip()]] = 1.0
            self.finding_labels.append(vec)

        self.finding_labels = torch.stack(self.finding_labels).float()

        # Recursively map image filenames to full paths
        self.image_map = {}
        for img_path in Path(dataset_root).rglob("*.png"):
            self.image_map[img_path.name] = str(img_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['Image Index']

        if img_name not in self.image_map:
            raise FileNotFoundError(f"Image {img_name} not found.")

        # Load image as grayscale first
        img = Image.open(self.image_map[img_name]).convert("L")

        # Apply augmentation on the PIL image
        if self.transform is not None:
            img = self.transform(img)

        # Convert to NumPy and scale to [0, 1]
        img = np.array(img).astype(np.float32) / 255.0  # [H, W]

        # To tensor with shape [1, H, W]
        img = torch.from_numpy(img).unsqueeze(0)

        # Repeat to make 3 channels for ResNet50: [3, H, W]
        img = img.repeat(3, 1, 1)

        # ImageNet normalization for ResNet50
        img = tvF.normalize(
            img,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        label = self.finding_labels[idx].float()

        return img, label, img_name


In [8]:
# Data augmentation for training
rand_transforms = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip()
])

train_val_data = NIHXrays(file_path, resized_root, list_file=train_val, transform=rand_transforms)
test_data      = NIHXrays(file_path, resized_root, list_file=test, transform=None)

print("Number of train+val images:", len(train_val_data))
print("Number of test images:", len(test_data))
print("Label set:", train_val_data.all_labels)


Number of train+val images: 86524
Number of test images: 25596
Label set: ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass', 'No Finding', 'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax']


In [9]:
### Multi label stratified split

y = train_val_data.finding_labels.numpy()   # [N, C]
X = np.arange(len(train_val_data))

msss = MultilabelStratifiedShuffleSplit(
    n_splits=1,
    test_size=0.2,
    random_state=42
)

train_idx, valid_idx = next(msss.split(X, y))

train_data = Subset(train_val_data, train_idx)
valid_data = Subset(train_val_data, valid_idx)

print("Train samples:", len(train_data))
print("Valid samples:", len(valid_data))

### DataLoaders

batch_size = 32

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False,
                          num_workers=2, pin_memory=True)


Train samples: 69200
Valid samples: 17324


In [10]:
### ResNet50 model

# Number of output classes from dataset
N_CLASSES = len(train_val_data.all_labels)
print("Number of classes:", N_CLASSES)

# Pretrained ResNet50 from torchvision (ImageNet weights)
weights = ResNet50_Weights.IMAGENET1K_V2
resnet_model = resnet50(weights=weights)

# Replace final fully connected layer
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, N_CLASSES)

# First stage: freeze all layers except the new head
for param in resnet_model.parameters():
    param.requires_grad = False

for param in resnet_model.fc.parameters():
    param.requires_grad = True

resnet_model = resnet_model.to(device)

print("Using torchvision ResNet50 backbone")
print("Trainable params in fc:", sum(p.numel() for p in resnet_model.fc.parameters()))


Number of classes: 15
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 173MB/s]


Using torchvision ResNet50 backbone
Trainable params in fc: 30735


In [11]:
### Loss function and optimizer

def compute_pos_weights(subset):
    full_dataset = subset.dataset
    labels = full_dataset.finding_labels[subset.indices].numpy()

    pos_counts = labels.sum(axis=0)
    neg_counts = (labels == 0).sum(axis=0)
    pos_weight = neg_counts / (pos_counts + 1e-6)

    return torch.tensor(pos_weight, dtype=torch.float32)

pos_weight = compute_pos_weights(train_data).to(device)
pos_weight = torch.clamp(pos_weight, max=30)

class Focal_Loss(nn.Module):
    def __init__(self, gamma=2.0, pos_weight=None):
        super().__init__()
        self.gamma = gamma
        self.pos_weight = pos_weight

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            pos_weight=self.pos_weight,
            reduction='none'
        )

        probs = torch.sigmoid(logits)
        p_t = probs * targets + (1 - probs) * (1 - targets)

        focal_factor = (1 - p_t) ** self.gamma
        loss = focal_factor * bce
        return loss.mean()

loss_function = Focal_Loss(gamma=2.0, pos_weight=pos_weight)

optimizer = Adam(resnet_model.fc.parameters(), lr=1e-3)


In [12]:
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve

def pr_curve_thresholds(y_true, y_prob, min_thresh=0.1):
    num_labels = y_true.shape[1]
    thresholds = np.zeros(num_labels)

    for i in range(num_labels):
        p, r, t = precision_recall_curve(y_true[:, i], y_prob[:, i])
        f1 = 2 * p * r / (p + r + 1e-9)

        if len(t) == 0:
            thresholds[i] = min_thresh
        else:
            thresholds[i] = t[np.argmax(f1)]

    return thresholds

def compute_f1(y_true, y_prob, thresholds):
    y_true = y_true.cpu()
    y_prob = y_prob.cpu()

    thresholds = torch.tensor(thresholds, dtype=y_prob.dtype)
    y_pred = (y_prob >= thresholds.unsqueeze(0)).int()

    return f1_score(y_true, y_pred, average="macro", zero_division=0)

def compute_auc(y_true, y_prob):
    y_true = y_true.detach().cpu().numpy()
    y_prob = y_prob.detach().cpu().numpy()
    return roc_auc_score(y_true, y_prob, average="macro")


In [13]:
def train_epoch(model, train_loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0.0
    all_probs = []
    all_labels = []

    for imgs, labels, _ in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(imgs)
        batch_loss = loss_fn(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

        probs = torch.sigmoid(outputs)
        all_probs.append(probs.detach().cpu())
        all_labels.append(labels.detach().cpu())

    all_probs = torch.cat(all_probs)
    all_labels = torch.cat(all_labels)

    auc = compute_auc(all_labels, all_probs)
    epoch_loss = total_loss / len(train_loader)

    print(f"Train - Loss: {epoch_loss:.4f}, AUC: {auc:.4f}")
    return epoch_loss, auc


def validate_epoch(model, valid_loader, loss_fn, device):
    model.eval()
    total_loss = 0.0
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels, _ in valid_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            outputs = model(imgs)
            batch_loss = loss_fn(outputs, labels)

            total_loss += batch_loss.item()

            probs = torch.sigmoid(outputs)
            all_probs.append(probs.detach().cpu())
            all_labels.append(labels.detach().cpu())

    all_probs = torch.cat(all_probs)
    all_labels = torch.cat(all_labels)

    epoch_loss = total_loss / len(valid_loader)
    thresholds = pr_curve_thresholds(all_labels.numpy(), all_probs.numpy())
    valid_f1 = compute_f1(all_labels, all_probs, thresholds)
    valid_auc = compute_auc(all_labels, all_probs)

    print(f"Valid - Loss: {epoch_loss:.4f}, F1: {valid_f1:.4f}, AUC: {valid_auc:.4f}")
    return epoch_loss, valid_f1, valid_auc, thresholds


In [14]:
EPOCHS = 5
best_val_auc = 0.0
best_model_path_resnet = "/content/drive/Shareddrives/STA_160/nih_resnet50_finetuned_best_head.pth"

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")

    train_loss, train_auc = train_epoch(
        model=resnet_model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_function,
        device=device
    )

    valid_loss, valid_f1, valid_auc, thresholds = validate_epoch(
        model=resnet_model,
        valid_loader=valid_loader,
        loss_fn=loss_function,
        device=device
    )

    if valid_auc > best_val_auc:
        best_val_auc = valid_auc
        torch.save(resnet_model.state_dict(), best_model_path_resnet)
        print(f"New best ResNet50 model saved with AUC {best_val_auc:.4f}")



Epoch 1/5
Train - Loss: 0.2478, AUC: 0.6624
Valid - Loss: 0.2440, F1: 0.1987, AUC: 0.7078
New best ResNet50 model saved with AUC 0.7078

Epoch 2/5
Train - Loss: 0.2424, AUC: 0.6962
Valid - Loss: 0.2434, F1: 0.2030, AUC: 0.7150
New best ResNet50 model saved with AUC 0.7150

Epoch 3/5
Train - Loss: 0.2411, AUC: 0.7066
Valid - Loss: 0.2406, F1: 0.2064, AUC: 0.7185
New best ResNet50 model saved with AUC 0.7185

Epoch 4/5
Train - Loss: 0.2404, AUC: 0.7108
Valid - Loss: 0.2441, F1: 0.2017, AUC: 0.7132

Epoch 5/5
Train - Loss: 0.2404, AUC: 0.7133
Valid - Loss: 0.2424, F1: 0.2117, AUC: 0.7181


In [15]:
for name, param in resnet_model.named_parameters():
    if name.startswith("layer4") or name.startswith("fc"):
        param.requires_grad = True
    else:
        param.requires_grad = False

# New optimizer with smaller LR for backbone, slightly higher for head
optimizer = Adam([
    {
        "params": [p for n, p in resnet_model.named_parameters()
                   if n.startswith("layer4") and p.requires_grad],
        "lr": 1e-5
    },
    {
        "params": resnet_model.fc.parameters(),
        "lr": 1e-4
    }
])

In [None]:
EPOCHS_FINE = 10
best_val_auc_fine = 0.0
best_model_path_resnet_fine = "/content/drive/Shareddrives/STA_160/nih_resnet50_finetuned_best_full.pth"

for epoch in range(EPOCHS_FINE):
    print(f"\nFine tune Epoch {epoch + 1}/{EPOCHS_FINE}")

    train_loss, train_auc = train_epoch(
        model=resnet_model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_function,
        device=device
    )

    valid_loss, valid_f1, valid_auc, thresholds = validate_epoch(
        model=resnet_model,
        valid_loader=valid_loader,
        loss_fn=loss_function,
        device=device
    )

    if valid_auc > best_val_auc_fine:
        best_val_auc_fine = valid_auc
        torch.save(resnet_model.state_dict(), best_model_path_resnet_fine)
        print(f"New best fine tuned ResNet50 saved with AUC {best_val_auc_fine:.4f}")


Fine tune Epoch 1/10
