In [17]:
# Imports
import os
import pandas as pd
import numpy as np
from PIL import Image
from sklearn.metrics import roc_auc_score

from fastai.vision.all import *
import timm
import torch
import time
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GroupShuffleSplit

In [None]:
# Load the label table, one-hot encode the findings, and build a
# patient-disjoint train/test split.
csv_path = "/student/csc490_project/shared/labels.csv"
image_dir = "/student/csc490_project/shared/preprocessed_images/preprocessed_images"

df = pd.read_csv(csv_path)

# Keep only rows where 'Image Index' is in preprocessed images.
# .copy() detaches the filtered frame so the column assignment below writes to
# an independent DataFrame rather than a slice (avoids SettingWithCopyWarning).
preprocessed_images = set(os.listdir(image_dir))
df = df[df["Image Index"].isin(preprocessed_images)].copy()
print(f"Filtered dataset contains {len(df)} images.")

# Convert 'Finding Labels' into a list of diseases
df["Finding Labels"] = df["Finding Labels"].apply(lambda x: x.split('|'))

# Convert labels into a binary multi-label format.
# The indicator frame must carry df's own index: DataFrame.join aligns on index
# labels, and after the filter above df no longer necessarily has a contiguous
# 0..n-1 index. Without index=df.index the label rows would be attached to the
# wrong images (or become NaN) whenever the filter drops at least one row.
mlb = MultiLabelBinarizer()
labels = pd.DataFrame(
    mlb.fit_transform(df["Finding Labels"]),
    columns=mlb.classes_,
    index=df.index,
)

# Merge binary labels with the dataset and keep 'Image Index', 'Patient ID', and label columns
df = df.join(labels)
df = df[['Image Index', 'Patient ID'] + list(mlb.classes_)]
print(f"Initial dataset shape: {df.shape}")

# Remove rare diseases with fewer than 2 samples (note: labels start at column index 2 now)
class_counts = df.iloc[:, 2:].sum()
rare_classes = class_counts[class_counts < 2].index.tolist()
print(f"Rare classes with <2 samples: {rare_classes}")
df = df.drop(columns=rare_classes)

# Remove rows where all labels are 0 (images that only had rare labels)
df = df[df.iloc[:, 2:].sum(axis=1) > 0]
print(f"Updated dataset shape after removing rare classes: {df.shape}")

# Split the dataset into training and test sets ensuring no patient appears in both sets.
# GroupShuffleSplit yields positional indices, hence the .iloc lookups below.
gss = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
train_idx, test_idx = next(gss.split(df, groups=df["Patient ID"]))
print(f"Training set contains {len(train_idx)} images.")
print(f"Test set contains {len(test_idx)} images.")
train_df = df.iloc[train_idx]
test_df = df.iloc[test_idx]
# Verify that there is no overlap in patient IDs between training and test sets
print(f"Training set contains {len(train_df)} images from {train_df['Patient ID'].nunique()} unique patients.")
print(f"Test set contains {len(test_df)} images from {test_df['Patient ID'].nunique()} unique patients.")

common_patients = set(train_df["Patient ID"]).intersection(set(test_df["Patient ID"]))
assert len(common_patients) == 0, "OVERLAP OF PATIENTS BETWEEN TRAINING AND TEST SETS"

# Define disease labels (all columns except 'Image Index' and 'Patient ID')
disease_labels = list(train_df.columns[2:])

Filtered dataset contains 112120 images.
Initial dataset shape: (112120, 17)
Rare classes with <2 samples: []
Updated dataset shape after removing rare classes: (112120, 17)
Training set contains 78566 images.
Test set contains 33554 images.
Training set contains 78566 images from 21563 unique patients.
Test set contains 33554 images from 9242 unique patients.


In [19]:
def get_x(row):
    """Return the path of the image file named in ``row['Image Index']``.

    The file name is joined onto the notebook-level ``image_dir``.
    """
    file_name = row['Image Index']
    return os.path.join(image_dir, file_name)

def get_y(row):
    """Return the names of the labels whose indicator in ``row`` equals 1.

    Labels are taken from the notebook-level ``disease_labels`` and returned
    in that same order.
    """
    present = []
    for name, flag in zip(disease_labels, row[disease_labels]):
        if flag == 1:
            present.append(name)
    return present

In [20]:
# Custom Transform for Gamma Correction
class GammaCorrection(Transform):
    """Raise pixel intensities to the power ``gamma`` and return a new image.

    Values are divided by 255 before the power is applied and multiplied by
    255 afterwards, then clipped to 0..255 and stored as uint8.
    """

    def __init__(self, gamma:float=1.0):
        # Exponent applied to the rescaled pixel values; 1.0 leaves them unchanged
        # up to rounding.
        self.gamma = gamma

    def encodes(self, img:PILImage):
        # Original author's note: expects a grayscale PIL image. The arithmetic
        # below is element-wise; assumes 8-bit pixel values — TODO confirm.
        unit_scaled = np.array(img).astype(np.float32) / 255.0
        powered = np.power(unit_scaled, self.gamma) * 255
        as_uint8 = np.clip(powered, 0, 255).astype(np.uint8)
        return PILImage.create(as_uint8)

# Transform to convert image to 3 channels
def to_3channel(img:PILImage):
    """Return ``img`` converted to RGB mode via ``img.convert("RGB")``."""
    rgb_img = img.convert("RGB")
    return rgb_img

In [21]:
# Per-item preprocessing: gamma correction, RGB conversion, resize to 224.
item_steps = [GammaCorrection(gamma=0.8), to_3channel, Resize(224)]

# Batch-level augmentation followed by normalisation. The mean/std values are
# the ImageNet channel statistics (same numbers as fastai's `imagenet_stats`).
augmentations = aug_transforms(flip_vert=False, max_rotate=15, max_zoom=1.0, max_warp=0.)
normalizer = Normalize.from_stats([0.485,0.456,0.406],[0.229,0.224,0.225])

# Create a DataBlock for multi-label classification.
# IndexSplitter(test_idx) makes the rows at positions `test_idx` the validation
# set; every other row of the frame is used for training.
dblock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock(vocab=disease_labels)),
    get_x=get_x,
    get_y=get_y,
    splitter=IndexSplitter(test_idx),
    item_tfms=item_steps,
    batch_tfms=[*augmentations, normalizer],
)

# Create DataLoaders from the full DataFrame
dls = dblock.dataloaders(df, bs=16)

In [22]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# timm backbone with pretrained weights; the classifier head is sized to one
# output per disease label.
model_name = 'coatnet_2_rw_224.sw_in12k_ft_in1k'
num_classes = len(disease_labels)  # This should match the number of labels
model = timm.create_model(model_name, pretrained=True, num_classes=num_classes).to(device)

# Binary cross-entropy on logits, the loss used for the multi-label targets.
loss_func = BCEWithLogitsLossFlat()

In [None]:
# Metrics reported on the validation set after every epoch.
eval_metrics = [
    AccumMetric(accuracy_multi, name='acc_multi'),
    F1ScoreMulti(),
    RocAucMulti(),
]

learn = Learner(dls, model, loss_func=loss_func, metrics=eval_metrics)

# Both callbacks track 'roc_auc_score': one saves a checkpoint named
# 'noverlap_coatnet' when the value improves, the other stops training after
# 3 epochs without improvement.
fit_callbacks = [
    SaveModelCallback(monitor='roc_auc_score', fname='noverlap_coatnet'),
    EarlyStoppingCallback(monitor='roc_auc_score', patience=3),
]

# One-cycle schedule: up to 20 epochs with a peak learning rate of 1e-4.
learn.fit_one_cycle(20, lr_max=1e-4, cbs=fit_callbacks)


epoch,train_loss,valid_loss,acc_multi,f1_score,roc_auc_score,time
0,0.181696,0.18112,0.933613,0.115935,0.804658,12:21
1,0.178057,0.182279,0.933498,0.12012,0.799045,12:22
2,0.179023,0.180526,0.933874,0.080466,0.808925,12:22
3,0.178042,0.176067,0.934456,0.142038,0.821646,12:22
4,0.173035,0.177227,0.934053,0.134575,0.824316,12:22
5,0.171583,0.175443,0.934863,0.140871,0.829179,12:22
6,0.170287,0.174386,0.934392,0.17909,0.83656,12:22
7,0.170414,0.173157,0.934919,0.20953,0.835959,12:22
8,0.163702,0.173252,0.934563,0.202233,0.838041,12:22
9,0.156812,0.173462,0.934796,0.18779,0.840254,12:21


Better model found at epoch 0 with roc_auc_score value: 0.8046582124371116.
Better model found at epoch 2 with roc_auc_score value: 0.8089246767003087.
Better model found at epoch 3 with roc_auc_score value: 0.8216459896328643.
Better model found at epoch 4 with roc_auc_score value: 0.8243162961552135.
Better model found at epoch 5 with roc_auc_score value: 0.8291791225467902.
Better model found at epoch 6 with roc_auc_score value: 0.8365604055161971.
Better model found at epoch 8 with roc_auc_score value: 0.8380413845179437.
Better model found at epoch 9 with roc_auc_score value: 0.8402543577962615.
No improvement since epoch 9: early stopping


  state = torch.load(file, map_location=device, **torch_load_kwargs)


: 