<a href="https://colab.research.google.com/github/Luck1e23/STA160-Team-11-Project/blob/Tina/STA160_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive so the shared STA_160 dataset is readable under /content/drive.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install torchxrayvision

Collecting torchxrayvision
  Downloading torchxrayvision-1.4.0-py3-none-any.whl.metadata (18 kB)
Downloading torchxrayvision-1.4.0-py3-none-any.whl (29.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.0/29.0 MB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchxrayvision
Successfully installed torchxrayvision-1.4.0


In [3]:
from PIL import Image
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam

# Visualization tools
import torchvision
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt


# Pre-trained Model: torchxrayvision
import torchxrayvision as xrv
import skimage


In [4]:
# Pretrained torchxrayvision DenseNet-121 pathology classifier (NIH ChestX-ray8 weights).
# The first call downloads the weights file.
xrv_model = xrv.models.DenseNet(
    weights="densenet121-res224-nih",
)

Downloading weights...
If this fails you can run `wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt -O /root/.torchxrayvision/models_data/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt`
[██████████████████████████████████████████████████]


In [None]:
# A superseded first version of NIHXrays used to sit here as a disabled '''...''' string.
# It differed from the active class below in two ways:
#   - it only indexed image files located in folders literally named "images";
#   - it used transforms.ToTensor() ([0, 1] pixel scaling) instead of
#     xrv.datasets.normalize, which is the input scaling torchxrayvision models use.
# See the repository's git history for the full original code.

In [5]:
### NEW DATASET CLASS


class NIHXrays(Dataset):
    """NIH chest X-ray dataset yielding torchxrayvision-normalized grayscale images.

    Each item is ``(img, label, img_name)``:
      * ``img``: float tensor of shape [1, 224, 224]. The image is converted to
        grayscale ("L"), resized to 224x224, then scaled with
        ``xrv.datasets.normalize(..., maxval=255.0)``.
      * ``label``: multi-hot float vector aligned with ``self.all_labels``.
      * ``img_name``: the CSV 'Image Index' filename.

    Args:
        file_path: CSV containing at least 'Image Index' and 'Finding Labels'
            columns. 'Finding Labels' holds '|'-separated label names.
        dataset_root: directory searched recursively for .png/.jpg/.jpeg files.
        list_file: optional text file with one image filename per line. Only
            CSV rows whose 'Image Index' is listed are kept.
        all_labels: optional fixed label vocabulary (column order). Pass another
            split's ``all_labels`` so both splits share identical label columns;
            labels outside it are ignored. Defaults to the sorted set of labels
            found in this split's rows.
    """

    def __init__(self, file_path, dataset_root, list_file=None, all_labels=None):
        self.data = pd.read_csv(file_path)
        self.dataset_root = dataset_root

        # Optional filtering to the images named in a split list
        if list_file:
            with open(list_file, 'r') as f:
                image_list = {line.strip() for line in f}
            self.data = self.data[self.data['Image Index'].isin(image_list)].reset_index(drop=True)

        # Label vocabulary and multi-hot targets, shape [len(self.data), len(self.all_labels)]
        self.all_labels, self.finding_labels = self._encode_labels(
            self.data['Finding Labels'], all_labels
        )
        self.label_map = {label: i for i, label in enumerate(self.all_labels)}

        # Map image filenames to full paths, searching every subfolder of dataset_root
        self.image_map = {}
        for root, _dirs, files in os.walk(dataset_root):
            for fname in files:
                if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_map[fname] = os.path.join(root, fname)

        self.transform = transforms.Compose([
            transforms.Resize((224, 224))
        ])

    @staticmethod
    def _encode_labels(label_strings, all_labels=None):
        """Encode '|'-separated label strings as (vocabulary list, multi-hot float matrix).

        The vocabulary is ``all_labels`` in the given order, or the sorted set of
        labels seen in ``label_strings`` when ``all_labels`` is None. An empty input
        yields a matrix with zero rows instead of failing.
        """
        parsed = [[part.strip() for part in s.split('|')] for s in label_strings]
        if all_labels is None:
            vocab = sorted({lab for labels in parsed for lab in labels})
        else:
            vocab = list(all_labels)
        index = {label: i for i, label in enumerate(vocab)}

        matrix = torch.zeros(len(parsed), len(vocab))
        for row, labels in enumerate(parsed):
            for lab in labels:
                if lab in index:
                    matrix[row, index[lab]] = 1.0
        return vocab, matrix

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(img, label, img_name)`` for row ``idx``.

        Raises:
            FileNotFoundError: if the row's image was not found under dataset_root.
        """
        img_name = self.data.iloc[idx]['Image Index']

        path = self.image_map.get(img_name)
        if path is None:
            raise FileNotFoundError(f"Image {img_name} not found.")

        # Context manager closes the file handle promptly (matters with many workers/files)
        with Image.open(path) as raw:
            img = raw.convert("L")
        img = self.transform(img)
        img = np.array(img).astype(np.float32)
        img = xrv.datasets.normalize(img, maxval=255.0)
        img = torch.from_numpy(img).unsqueeze(0)  # add channel dim -> [1, H, W]

        label = self.finding_labels[idx]
        return img, label, img_name

In [6]:
# All data lives on the shared Google Drive mounted in the first cell.
dataset_root = '/content/drive/Shareddrives/STA_160/dataset'
file_path = os.path.join(dataset_root, 'Data_Entry_2017.csv')
train_val = os.path.join(dataset_root, 'train_val_list.txt')
test = os.path.join(dataset_root, 'test_list.txt')

train_val_data = NIHXrays(file_path, dataset_root, list_file=train_val)
test_data = NIHXrays(file_path, dataset_root, list_file=test)

# Each split builds its label vocabulary from its own rows. The multi-hot columns
# are only comparable between splits if both vocabularies are identical.
assert train_val_data.all_labels == test_data.all_labels, (
    "train_val and test splits have different label sets: "
    f"{set(train_val_data.all_labels) ^ set(test_data.all_labels)}"
)

In [7]:
# Split train_val_data into training and validation sets.
# The NIH data has many images per patient, so split by patient: an image-level
# split places the same patient in both sets and inflates validation scores.
# A fixed seed makes the split reproducible across kernel restarts.
SPLIT_SEED = 42
TRAIN_FRACTION, VALID_FRACTION = 0.8, 0.2

if 'Patient ID' in train_val_data.data.columns:
    patient_ids = train_val_data.data['Patient ID'].to_numpy()
    unique_patients = np.unique(patient_ids)
    rng = np.random.default_rng(SPLIT_SEED)
    n_valid_patients = int(round(VALID_FRACTION * len(unique_patients)))
    valid_patients = rng.choice(unique_patients, size=n_valid_patients, replace=False)
    is_valid = np.isin(patient_ids, valid_patients)
    train_data = torch.utils.data.Subset(train_val_data, np.flatnonzero(~is_valid).tolist())
    valid_data = torch.utils.data.Subset(train_val_data, np.flatnonzero(is_valid).tolist())
else:
    # Fallback: seeded image-level split (patients may appear in both sets)
    train_data, valid_data = random_split(
        train_val_data, [TRAIN_FRACTION, VALID_FRACTION],
        generator=torch.Generator().manual_seed(SPLIT_SEED),
    )
assert len(train_data) + len(valid_data) == len(train_val_data)

# Create DataLoaders for training and validation.
# Worker processes overlap image reads (each item is a file read from Google Drive,
# presumably the bottleneck) with model computation.
n = 32  # batch size
NUM_WORKERS = 2
train_loader = DataLoader(train_data, batch_size=n, shuffle=True, num_workers=NUM_WORKERS)
train_N = len(train_loader.dataset)
valid_loader = DataLoader(valid_data, batch_size=n, num_workers=NUM_WORKERS)
valid_N = len(valid_loader.dataset)


# Data augmentation. NOTE: not applied anywhere yet.
# Sizes are passed as (height, width), the order torchvision expects.
# NOTE(review): ColorJitter on float tensors assumes values in [0, 1], so apply this
# before xrv.datasets.normalize (e.g. to the PIL image), not to the normalized tensor.
# Horizontal flips also mirror chest anatomy (heart side); confirm that is acceptable.
IMG_HEIGHT, IMG_WIDTH = (224, 224)
rand_transforms = transforms.Compose([
    transforms.RandomRotation(25),
    transforms.RandomResizedCrop((IMG_HEIGHT, IMG_WIDTH), scale=(0.8, 1), ratio=(1, 1)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
])

In [8]:

class XRV_Finetune(nn.Module):
    """Frozen torchxrayvision backbone followed by a trainable linear head.

    The backbone's outputs (one per entry in ``base_model.pathologies``) are the
    features for a new ``nn.Linear(len(base_model.pathologies), num_classes)``
    head. The head returns raw logits, so pair it with ``nn.BCEWithLogitsLoss``.

    Args:
        base_model: a torchxrayvision model exposing ``.pathologies``.
        num_classes: number of output logits.
    """

    def __init__(self, base_model, num_classes):
        super().__init__()
        self.base = base_model

        # Freeze backbone weights; only the classifier head is trained.
        for p in self.base.parameters():
            p.requires_grad = False
        # Modules start in train mode; keep the frozen backbone in eval mode.
        self.base.eval()

        # Number of outputs from XRV model
        in_features = len(base_model.pathologies)

        # New classification head
        self.classifier = nn.Linear(in_features, num_classes)

    def train(self, mode=True):
        """Set train/eval mode for the head, but always keep the backbone in eval mode.

        Setting ``requires_grad=False`` stops weight updates. However, BatchNorm
        layers in train mode still overwrite their running mean/var on every
        forward pass, and Dropout would be active. Pinning ``self.base`` to eval
        mode keeps the backbone truly frozen when callers use ``model.train()``.
        """
        super().train(mode)
        self.base.eval()
        return self

    def forward(self, x):
        out = self.base(x)           # shape [B, len(base_model.pathologies)]
        return self.classifier(out)  # shape [B, num_classes], logits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pre-trained model. Loaded fresh here so re-running this cell always starts
# from the original pretrained weights.
xrv_model = xrv.models.DenseNet(weights="densenet121-res224-nih")

# Freeze (XRV_Finetune also freezes it; this keeps the standalone model frozen too)
xrv_model.requires_grad_(False)
print("XRV frozen")

# One output per label in the dataset's vocabulary (previously hardcoded as 15)
N_CLASSES = len(train_val_data.all_labels)

my_model = XRV_Finetune(xrv_model, N_CLASSES).to(device)

# Multi-label targets: an independent sigmoid + binary cross-entropy per class
loss_function = nn.BCEWithLogitsLoss()
# Hand Adam only the trainable (head) parameters; the backbone is frozen.
optimizer = Adam([p for p in my_model.parameters() if p.requires_grad])

XRV frozen


In [11]:
from sklearn.metrics import f1_score
import torch

def compute_f1(y_true, y_pred):
    """Macro-averaged F1 over label columns for multi-hot tensors.

    Args:
        y_true: 0/1 tensor of shape [N, C] with ground-truth labels.
        y_pred: 0/1 tensor of shape [N, C] with thresholded predictions.

    Returns:
        sklearn ``f1_score(..., average="macro", zero_division=0)`` as a float.
    """
    y_true = y_true.detach().cpu().numpy()
    y_pred = y_pred.detach().cpu().numpy()
    return f1_score(y_true, y_pred, average="macro", zero_division=0)


def _run_epoch(model, loader, loss_fn, device, optimizer=None, threshold=0.5):
    """Make one pass over ``loader``, updating weights only if ``optimizer`` is given.

    The caller is responsible for putting the model in train/eval mode.
    Returns ``(mean_loss, macro_f1)``. The mean loss is per sample, assuming
    ``loss_fn`` averages over the batch (reduction="mean", the
    nn.BCEWithLogitsLoss default). Predictions are ``sigmoid(logits) > threshold``.

    Raises:
        ValueError: if the loader yields no batches.
    """
    training = optimizer is not None
    total_loss = 0.0
    n_samples = 0
    all_preds = []
    all_labels = []

    # Only build the autograd graph when we are going to backprop
    with torch.set_grad_enabled(training):
        for imgs, labels, _ in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            outputs = model(imgs)
            batch_loss = loss_fn(outputs, labels)

            if training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            # Weight each batch-mean loss by batch size so the last (smaller) batch
            # counts correctly in the per-sample epoch mean
            batch_size = imgs.size(0)
            total_loss += batch_loss.item() * batch_size
            n_samples += batch_size

            # Convert logits -> 0/1 predictions
            preds = (torch.sigmoid(outputs) > threshold).float()
            all_preds.append(preds.detach().cpu())
            all_labels.append(labels.cpu())

    if n_samples == 0:
        raise ValueError("loader yielded no batches")

    f1 = compute_f1(torch.cat(all_labels), torch.cat(all_preds))
    return total_loss / n_samples, f1


def train(model, train_loader, optimizer, loss_fn, device, check_grad=False):
    """Run one training epoch and report mean loss and macro F1.

    Args:
        model: module producing logits of shape [B, C].
        train_loader: yields ``(imgs, labels, names)`` batches with multi-hot labels.
        optimizer: optimizer over the model's trainable parameters.
        loss_fn: loss on ``(logits, labels)``, assumed to average over the batch.
        device: device the batches are moved to.
        check_grad: if True, print every available parameter gradient after the epoch.

    Returns:
        ``(epoch_loss, f1)``: mean per-sample loss over the epoch and macro F1 of
        predictions thresholded at ``sigmoid > 0.5``.
    """
    model.train()

    epoch_loss, f1 = _run_epoch(model, train_loader, loss_fn, device, optimizer=optimizer)

    if check_grad:
        print("Last Gradient:")
        for p in model.parameters():
            if p.grad is not None:
                print(p.grad)

    print("Train - Loss: {:.4f}, F1: {:.4f}".format(epoch_loss, f1))

    return epoch_loss, f1


def validate(model, valid_loader, loss_fn, device):
    """Evaluate the model on ``valid_loader`` without updating weights.

    Same inputs and return convention as ``train``:
    ``(mean per-sample loss, macro F1)``.
    """
    model.eval()

    epoch_loss, f1 = _run_epoch(model, valid_loader, loss_fn, device)

    print("Valid  - Loss: {:.4f}, F1: {:.4f}".format(epoch_loss, f1))

    return epoch_loss, f1


In [10]:
epochs = 1
history = []  # one row of metrics per epoch

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}")

    train_loss, train_f1 = train(
        model=my_model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_function,
        device=device,
        check_grad=False
    )

    val_loss, val_f1 = validate(
        model=my_model,
        valid_loader=valid_loader,
        loss_fn=loss_function,
        device=device
    )

    # Keep every epoch's metrics instead of overwriting them on the next iteration
    history.append({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_f1": train_f1,
        "val_loss": val_loss,
        "val_f1": val_f1,
    })

pd.DataFrame(history)


Epoch 1


KeyboardInterrupt: 