<a href="https://colab.research.google.com/github/Luck1e23/STA160-Team-11-Project/blob/Tina/EfficientNet_B4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive (Colab only); the dataset paths defined
# in the cells below all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install torchxrayvision
!pip install iterative-stratification
!pip install efficientnet_pytorch

Collecting torchxrayvision
  Downloading torchxrayvision-1.4.0-py3-none-any.whl.metadata (18 kB)
Downloading torchxrayvision-1.4.0-py3-none-any.whl (29.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.0/29.0 MB[0m [31m36.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchxrayvision
Successfully installed torchxrayvision-1.4.0
Collecting iterative-stratification
  Downloading iterative_stratification-0.1.9-py3-none-any.whl.metadata (1.3 kB)
Downloading iterative_stratification-0.1.9-py3-none-any.whl (8.5 kB)
Installing collected packages: iterative-stratification
Successfully installed iterative-stratification-0.1.9
Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: efficientnet_pytorch
  Building wheel for efficientnet_pytorch (setup.py) ... [?25l[?25hdone
  Created wheel for efficientnet_pytorch: filename=eff

In [3]:
from PIL import Image
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import Adam
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from pathlib import Path

# Visualization tools
import torchvision
import torchvision.transforms.v2 as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Pre-trained Model: torchxrayvision
import torchxrayvision as xrv
import skimage

# Pre-trained Model: EfficientNet
from efficientnet_pytorch import EfficientNet


In [4]:
# Paths
# Metadata CSV; NIHXrays reads its 'Image Index' and 'Finding Labels' columns.
file_path = '/content/drive/Shareddrives/STA_160/dataset/Data_Entry_2017.csv'
# Image-list files (one image filename per line) used to filter the CSV down to
# the train/validation pool and the test set respectively.
train_val = '/content/drive/Shareddrives/STA_160/dataset/train_val_list.txt'
test = '/content/drive/Shareddrives/STA_160/dataset/test_list.txt'
# Local (non-Drive) folder that holds the resized PNG images; it is populated by
# the unzip/move cell below and searched recursively by NIHXrays.
resized_root = '/content/dataset_resized'   # Where resized images will be saved

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Unzip the resized dataset archive from the shared drive into /content/dataset
# (-q: quiet, no per-file listing).
!unzip -q /content/drive/Shareddrives/STA_160/NIH_resized.zip -d /content/dataset

# Create the flat destination folder (no error if it already exists).
!mkdir -p /content/dataset_resized
# Move every regular file found under the extracted .../content/dataset_resized/
# tree (including subfolders) directly into /content/dataset_resized/.
!find /content/dataset/content/dataset_resized/ -type f -exec mv -t /content/dataset_resized/ {} +
# Delete the extracted /content/dataset/content folder once the files are moved.
!rm -rf /content/dataset/content #remove folder

In [7]:
# Dataset Class
class NIHXrays(Dataset):
    """Multi-label chest X-ray dataset backed by a metadata CSV and a folder of PNGs.

    Args:
        file_path: CSV with at least the columns 'Image Index' (image filename)
            and 'Finding Labels' ('|'-separated label names).
        dataset_root: Directory searched recursively for ``*.png`` files.
        list_file: Optional text file with one image filename per line; when
            given, only CSV rows whose 'Image Index' appears in it are kept.
        transform: Optional callable applied to the loaded PIL image.

    Attributes:
        data: The (optionally filtered) metadata DataFrame.
        all_labels: Sorted list of every label name present in ``data``.
        label_map: Label name -> column index in the multi-hot vectors.
        finding_labels: Float tensor with one multi-hot row per CSV row.
        image_map: Image filename -> full path of the PNG under ``dataset_root``.
    """

    def __init__(self, file_path, dataset_root, list_file=None, transform=None):
        self.data = pd.read_csv(file_path)
        self.dataset_root = dataset_root
        self.transform = transform

        # Optional filtering: keep only the images named in the list file.
        if list_file:
            with open(list_file, 'r') as handle:
                wanted = set(line.strip() for line in handle)
            keep_mask = self.data['Image Index'].isin(wanted)
            self.data = self.data[keep_mask].reset_index(drop=True)

        # Split every row's label string once; reused for the vocabulary and the vectors.
        names_per_row = [
            [token.strip() for token in raw.split('|')]
            for raw in self.data['Finding Labels']
        ]

        # Label vocabulary, in sorted order so column indices are stable.
        self.all_labels = sorted({name for names in names_per_row for name in names})
        self.label_map = {name: column for column, name in enumerate(self.all_labels)}

        # One multi-hot vector per row, stacked into a single tensor.
        rows = []
        for names in names_per_row:
            hot = torch.zeros(len(self.all_labels))
            for name in names:
                hot[self.label_map[name]] = 1.0
            rows.append(hot)
        self.finding_labels = torch.stack(rows).float()

        # Recursively map image filenames to full paths.
        self.image_map = {
            png.name: str(png) for png in Path(dataset_root).rglob("*.png")
        }

    def __len__(self):
        """Number of rows in the (filtered) metadata table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label, image_name)`` for row ``idx``.

        The image is opened as 8-bit grayscale, passed through ``transform``
        when one is set, converted to float32, rescaled with
        ``xrv.datasets.normalize(..., maxval=255)`` and given a leading channel
        dimension.

        Raises:
            FileNotFoundError: If no PNG with the row's filename was found
                under ``dataset_root``.
        """
        img_name = self.data.iloc[idx]['Image Index']

        location = self.image_map.get(img_name)
        if location is None:
            raise FileNotFoundError(f"Image {img_name} not found.")

        picture = Image.open(location).convert("L")
        if self.transform is not None:
            picture = self.transform(picture)

        # PIL image -> float32 NumPy array -> XRV normalization
        pixels = np.array(picture).astype(np.float32)
        pixels = xrv.datasets.normalize(pixels, maxval=255)

        # NumPy -> tensor with a leading channel axis
        tensor = torch.from_numpy(pixels).unsqueeze(0)

        return tensor, self.finding_labels[idx].float(), img_name

In [8]:
# Data Augmentation
# Random rotation (RandomRotation(10)) followed by a random horizontal flip.
# NIHXrays applies the transform to the PIL image inside __getitem__, before
# the NumPy conversion and XRV normalization.
rand_transforms = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip()
])

# NOTE(review): the transform is stored on the dataset object, so it runs on
# every item fetched from train_val_data, including through any Subset that
# wraps it. The test set is built without augmentation.
train_val_data = NIHXrays(file_path, resized_root, list_file=train_val, transform=rand_transforms)
test_data      = NIHXrays(file_path, resized_root, list_file=test, transform=None)

In [9]:
# Convert labels to NumPy for the stratifier
y = train_val_data.finding_labels.numpy()   # shape: [N, C]
X = np.arange(len(train_val_data))          # dummy feature array (only row indices matter)

# Multi-label stratified split: a single 80/20 shuffle with a fixed seed
msss = MultilabelStratifiedShuffleSplit(
    n_splits=1,
    test_size=0.2,
    random_state=42
)
# Only want the first split & get train and validation indices
train_idx, valid_idx = next(msss.split(X, y))

# Training subset keeps the random augmentations attached to train_val_data.
train_data = Subset(train_val_data, train_idx)

# Validation must be deterministic. A Subset of train_val_data would inherit its
# random rotation/flip, making validation loss, AUC and the PR-curve thresholds
# vary from run to run. Build the validation subset on an un-augmented view of
# the same rows instead (same CSV + same list file => same row order).
valid_source = NIHXrays(file_path, resized_root, list_file=train_val, transform=None)
assert valid_source.data['Image Index'].equals(train_val_data.data['Image Index']), \
    "un-augmented dataset is not row-aligned with train_val_data"
valid_data = Subset(valid_source, valid_idx)

In [10]:
# Create DataLoaders for training and validation
n = 32  # batch size used by both loaders

# Settings shared by the two loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=n, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_data, **loader_kwargs)

# Number of samples behind each loader
train_N = len(train_loader.dataset)
valid_N = len(valid_loader.dataset)

In [11]:
# Finding class weights
def compute_pos_weights(subset):
    """Per-class negative/positive ratio for the samples in ``subset``.

    Args:
        subset: Object exposing ``.dataset.finding_labels`` (multi-hot label
            tensor, one row per sample) and ``.indices`` (rows belonging to the
            subset), e.g. a ``torch.utils.data.Subset`` of ``NIHXrays``.

    Returns:
        float32 tensor with one entry per class: ``negatives / (positives + 1e-6)``.
        The small epsilon avoids division by zero, so a class with no positives
        yields a very large weight rather than ``inf``.
    """
    # Multi-hot labels of just the rows in this subset, as a NumPy array
    label_matrix = subset.dataset.finding_labels[subset.indices].numpy()

    n_positive = label_matrix.sum(axis=0)          # times each class is present
    n_negative = (label_matrix == 0).sum(axis=0)   # times each class is absent
    ratio = n_negative / (n_positive + 1e-6)

    return torch.tensor(ratio, dtype=torch.float32)

# Using Focal Loss as loss function
class Focal_Loss(nn.Module):
    """Focal loss for multi-label classification, computed from raw logits.

    Each element's loss is the binary cross-entropy with logits (optionally
    weighted for positives via ``pos_weight``), scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the predicted probability of the true outcome. The result
    is averaged over all elements.

    Args:
        gamma: Focusing exponent; ``0`` reduces the loss to (weighted) BCE.
        pos_weight: Optional per-class weight for positive targets, as accepted
            by ``F.binary_cross_entropy_with_logits``. A tensor is used as-is;
            any other array-like is converted to a float32 tensor.

    ``pos_weight`` is registered as a buffer, so it follows the module through
    ``.to(device)``, ``.cuda()``, ``.double()`` etc. It is still readable as
    ``self.pos_weight`` (``None`` when no weight was given).
    """

    def __init__(self, gamma=2.0, pos_weight=None):
        super().__init__()
        self.gamma = gamma
        if pos_weight is not None and not torch.is_tensor(pos_weight):
            pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32)
        # Buffer (not a Parameter): moved with the module, never optimised.
        self.register_buffer("pos_weight", pos_weight)

    def forward(self, logits, targets):
        """Return the mean focal loss.

        Args:
            logits: Raw model outputs, shape (B, C).
            targets: Float multi-hot targets, shape (B, C).
        """
        # 1. Element-wise BCE with logits (no reduction yet)
        bce = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            pos_weight=self.pos_weight,
            reduction='none'
        )

        # 2. p_t: probability the model assigns to the true outcome
        probs = torch.sigmoid(logits)
        p_t = probs * targets + (1 - probs) * (1 - targets)

        # 3. Focal modulation down-weights elements that are already well classified
        focal_factor = (1 - p_t) ** self.gamma

        # 4. Combine and reduce to a scalar
        return (focal_factor * bce).mean()

In [12]:
# Using the same 224 x 224 image size for EfficientNet-B4

class EfficientNetB4_Classifier(nn.Module):
    """ImageNet-pretrained EfficientNet-B4 feature extractor with a linear head.

    Args:
        num_classes: Number of output logits (one per label).

    Attributes:
        backbone: ``EfficientNet.from_pretrained("efficientnet-b4")``.
        feature_dim: Size of the pooled feature vector fed to the head (1792);
            it has to equal the channel count returned by
            ``backbone.extract_features``.
        classifier: ``nn.Linear(feature_dim, num_classes)`` producing raw logits.
    """

    def __init__(self, num_classes=15):
        super().__init__()
        self.backbone = EfficientNet.from_pretrained("efficientnet-b4")
        self.feature_dim = 1792
        self.classifier = nn.Linear(self.feature_dim, num_classes)

    def forward(self, x):
        """Map a batch of images (B, 1 or 3, H, W) to logits (B, num_classes)."""
        batch_size = x.size(0)

        # Grayscale input: tile the single channel to the 3 channels the backbone expects.
        is_single_channel = x.shape[1] == 1
        inputs = x.repeat(1, 3, 1, 1) if is_single_channel else x

        feature_map = self.backbone.extract_features(inputs)

        # Global average pool to (B, C, 1, 1), then flatten to (B, C)
        pooled = F.adaptive_avg_pool2d(feature_map, 1)
        return self.classifier(pooled.view(batch_size, -1))

# Build the classifier and place it on the selected device
effnet_model = EfficientNetB4_Classifier(num_classes=15)
effnet_model = effnet_model.to(device)

# Freeze backbone; Classifier unfrozen
effnet_model.backbone.requires_grad_(False)
print("EfficientNet frozen")

# Addressing class imbalance: add more weight to rare diseases.
# Weights are capped at 30 because one of the raw values was > 600.
pos_weight = compute_pos_weights(train_data).to(device).clamp(max=30)
loss_function = Focal_Loss(gamma=2.0, pos_weight=pos_weight)

# Only the classifier head is optimised at this stage (higher learning rate for the head)
optimizer = Adam(effnet_model.classifier.parameters(), lr=1e-3)



Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b4-6ed6700e.pth


100%|██████████| 74.4M/74.4M [00:00<00:00, 112MB/s]


Loaded pretrained weights for efficientnet-b4
EfficientNet frozen


In [13]:
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve
import torch

def pr_curve_thresholds(y_true, y_prob, min_thresh=0.1):
    """Pick, for every label, the decision threshold that maximises F1.

    Args:
        y_true: Array of shape (N, C) with binary ground-truth labels.
        y_prob: Array of shape (N, C) with predicted probabilities.
        min_thresh: Fallback threshold used for a label whose precision-recall
            curve returns no thresholds.

    Returns:
        NumPy float array of length C with one threshold per label.
    """
    chosen = []
    for column in range(y_true.shape[1]):
        precision, recall, cutoffs = precision_recall_curve(
            y_true[:, column], y_prob[:, column]
        )

        if len(cutoffs) == 0:
            # No candidate threshold for this label: use the fallback
            chosen.append(min_thresh)
            continue

        # F1 at every point of the curve (epsilon guards against 0/0)
        f_scores = 2 * precision * recall / (precision + recall + 1e-9)
        chosen.append(cutoffs[np.argmax(f_scores)])

    return np.array(chosen, dtype=float)

def compute_f1(y_true, y_prob, thresholds):
    """Macro-averaged F1 after binarising probabilities with per-label thresholds.

    Args:
        y_true: Tensor (N, C) of binary ground-truth labels.
        y_prob: Tensor (N, C) of predicted probabilities.
        thresholds: Per-label thresholds (length C); a probability greater than
            or equal to its label's threshold counts as a positive prediction.

    Returns:
        Macro F1 from ``sklearn.metrics.f1_score`` with ``zero_division=0``.
    """
    truth = y_true.cpu()
    scores = y_prob.cpu()

    # Thresholds as a tensor of the same dtype, with a leading axis so they
    # broadcast across the batch dimension.
    cutoffs = torch.tensor(thresholds, dtype=scores.dtype)
    predictions = (scores >= cutoffs[None]).int()

    return f1_score(truth, predictions, average="macro", zero_division=0)

def compute_auc(y_true, y_prob):
    """Macro-averaged ROC AUC over all labels.

    Args:
        y_true: Tensor (N, C) of binary ground-truth labels.
        y_prob: Tensor (N, C) of sigmoid probabilities.

    Returns:
        The value of ``sklearn.metrics.roc_auc_score(..., average="macro")``.
    """
    labels_np, probs_np = (
        tensor.detach().cpu().numpy() for tensor in (y_true, y_prob)
    )
    return roc_auc_score(labels_np, probs_np, average="macro")

In [14]:
def train_epoch(model, train_loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network mapping an image batch to logits.
        train_loader: Iterable yielding ``(images, labels, names)`` batches.
        optimizer: Optimiser stepped once per batch.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, auc)``: the mean of the per-batch losses and the macro
        ROC AUC over every sample seen in the epoch. Both are also printed.
    """
    model.train()
    running_loss = 0.0
    prob_batches, label_batches = [], []

    for images, targets, _names in train_loader:
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(images)
        step_loss = loss_fn(logits, targets)
        step_loss.backward()
        optimizer.step()

        running_loss += step_loss.item()

        # Keep CPU copies of probabilities and labels for the epoch-level AUC
        prob_batches.append(torch.sigmoid(logits).detach().cpu())
        label_batches.append(targets.detach().cpu())

    epoch_probs = torch.cat(prob_batches)
    epoch_labels = torch.cat(label_batches)

    auc = compute_auc(epoch_labels, epoch_probs)
    epoch_loss = running_loss / len(train_loader)  # average over batches

    print(f"Train - Loss: {epoch_loss:.4f}, AUC: {auc:.4f}")
    return epoch_loss, auc

def validate_epoch(model, valid_loader, loss_fn, device):
    """Evaluate ``model`` on ``valid_loader`` without updating any weights.

    Args:
        model: Network mapping an image batch to logits.
        valid_loader: Iterable yielding ``(images, labels, names)`` batches.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, valid_f1, valid_auc, thresholds)``: mean per-batch loss,
        macro F1 at the per-label thresholds chosen on this same data, macro
        ROC AUC, and those thresholds. The three metrics are also printed.
    """
    model.eval()
    running_loss = 0.0
    prob_batches, label_batches = [], []

    with torch.no_grad():
        for images, targets, _names in valid_loader:
            images, targets = images.to(device), targets.to(device)

            logits = model(images)
            running_loss += loss_fn(logits, targets).item()

            prob_batches.append(torch.sigmoid(logits).cpu())
            label_batches.append(targets.cpu())

    epoch_probs = torch.cat(prob_batches)
    epoch_labels = torch.cat(label_batches)

    epoch_loss = running_loss / len(valid_loader)  # average over batches

    # Thresholds are tuned on the validation predictions, then scored on them
    thresholds = pr_curve_thresholds(epoch_labels.numpy(), epoch_probs.numpy())
    valid_f1 = compute_f1(epoch_labels, epoch_probs, thresholds)
    valid_auc = compute_auc(epoch_labels, epoch_probs)

    print(f"Valid - Loss: {epoch_loss:.4f}, F1: {valid_f1:.4f}, AUC: {valid_auc:.4f}")
    return epoch_loss, valid_f1, valid_auc, thresholds

In [15]:
# Stage 1: train only the classifier head, checkpointing on best validation AUC
EPOCHS = 5
best_val_auc = 0.0
best_model_path_effnet = "/content/drive/Shareddrives/STA_160/nih_efficientnetb4_finetuned_best_head.pth"

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")

    train_loss, train_auc = train_epoch(
        effnet_model, train_loader, optimizer, loss_function, device
    )
    valid_loss, valid_f1, valid_auc, thresholds = validate_epoch(
        effnet_model, valid_loader, loss_function, device
    )

    # Nothing to save unless this epoch strictly beats the best AUC so far
    if not (valid_auc > best_val_auc):
        continue

    best_val_auc = valid_auc
    torch.save(effnet_model.state_dict(), best_model_path_effnet)
    print(f"New best EfficientNet-B4 model saved with AUC {best_val_auc:.4f}")


Epoch 1/5


KeyboardInterrupt: 

In [None]:
# Unfreeze the last block of the EfficientNet backbone
last_block = effnet_model.backbone._blocks[-1]
last_block.requires_grad_(True)

# New optimizer with separate learning rates: the head keeps 1e-3, the
# newly unfrozen backbone block uses a smaller 1e-4.
optimizer = Adam([
    dict(params=effnet_model.classifier.parameters(), lr=1e-3),
    dict(params=last_block.parameters(), lr=1e-4),
])

In [None]:
# Stage 2: fine-tune head + last backbone block, checkpointing on best validation AUC
EPOCHS_FINE = 20
best_val_auc_fine = 0.0
best_model_path_effnet_fine = "/content/drive/Shareddrives/STA_160/nih_efficientnetb4_finetuned_best_full.pth"

for epoch in range(EPOCHS_FINE):
    print(f"\nFine tune Epoch {epoch + 1}/{EPOCHS_FINE}")

    train_loss, train_auc = train_epoch(
        effnet_model, train_loader, optimizer, loss_function, device
    )
    valid_loss, valid_f1, valid_auc, thresholds = validate_epoch(
        effnet_model, valid_loader, loss_function, device
    )

    # Nothing to save unless this epoch strictly beats the best AUC so far
    if not (valid_auc > best_val_auc_fine):
        continue

    best_val_auc_fine = valid_auc
    torch.save(effnet_model.state_dict(), best_model_path_effnet_fine)
    print(f"New best fine tuned EfficientNet-B4 saved with AUC {best_val_auc_fine:.4f}")