In [12]:
import os
os.chdir('..')
import sys
import pandas as pd
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, cohen_kappa_score

from tqdm import tqdm
import csv

In [2]:
# ========================
# region Data Sets/Loaders
# ========================
class RetinaMultiLabelDataset(Dataset):
    """
    Labelled image dataset driven by a CSV file.

    The CSV is read with pandas; in every row the first column is used as the
    image file name (relative to `image_dir`) and all remaining columns are
    converted to a float32 label tensor.
    """

    def __init__(self, csv_file, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        # one sample per CSV row
        return self.data.shape[0]

    def __getitem__(self, idx):
        record = self.data.iloc[idx]

        # first column -> file name, opened and forced to 3-channel RGB
        image = Image.open(os.path.join(self.image_dir, record.iloc[0])).convert("RGB")

        # every other column -> float32 label vector
        label_tensor = torch.tensor(record[1:].values.astype("float32"))

        if self.transform:
            image = self.transform(image)
        return image, label_tensor

class RetinaMultiLabelDataset_WithoutLabels(Dataset):
    """
    Unlabelled image dataset: every regular, non-hidden file in `image_dir` is one sample.

    Each item is `(image, file_name)`, so predictions can be matched back to
    the file they were computed for.
    """

    def __init__(self, image_dir, transform=None):
        # os.listdir returns entries in arbitrary order, so sort to make the
        # idx -> file mapping (and anything written from it) reproducible.
        # Sub-directories and hidden entries (e.g. .ipynb_checkpoints, .DS_Store)
        # are skipped because they cannot be opened as images.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith(".") and os.path.isfile(os.path.join(image_dir, name))
        )
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        img = Image.open(img_path).convert("RGB")  # force 3 channels
        if self.transform:
            img = self.transform(img)
        return img, img_name


def get_dataloaders(dataset_path: str, img_size=256, batch_size=32, num_workers=4):
    """
    Build the train / val / offsite-test / onsite-test data loaders.

    Expected layout under `dataset_path`:
        labels/train.csv, labels/val.csv, labels/offsite_test.csv
        images/train, images/val, images/offsite_test, images/onsite_test

    Args:
        dataset_path: root directory of the dataset.
        img_size: images are resized to (img_size, img_size).
        batch_size: batch size of the train, val and offsite-test loaders.
        num_workers: worker processes per loader (previously fixed at 4);
            pass 0 to load in the main process.

    Returns:
        (train_loader, val_loader, test_loader, onsite_test_loader). Only the
        train loader shuffles. The onsite loader has no labels, yields
        (image, file_name) pairs and always uses batch_size=1.
    """

    # transforms (the same pipeline is applied to every split)
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # commonly used ImageNet channel statistics
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])

    # paths
    train_csv = os.path.join( dataset_path, "labels/train.csv" )
    val_csv   = os.path.join( dataset_path, "labels/val.csv" )
    test_csv  = os.path.join( dataset_path, "labels/offsite_test.csv" )

    train_image_dir = os.path.join( dataset_path, "images/train" )
    val_image_dir =   os.path.join( dataset_path, "images/val" )
    test_image_dir =  os.path.join( dataset_path, "images/offsite_test" )
    onsite_test_image_dir =  os.path.join( dataset_path, "images/onsite_test" )

    # dataset & dataloader
    train_ds =       RetinaMultiLabelDataset(train_csv, train_image_dir, transform)
    val_ds   =       RetinaMultiLabelDataset(val_csv, val_image_dir, transform)
    test_ds  =       RetinaMultiLabelDataset(test_csv, test_image_dir, transform)
    onsite_test_ds = RetinaMultiLabelDataset_WithoutLabels(onsite_test_image_dir, transform)

    train_loader =        DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader   =        DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader  =        DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    onsite_test_loader  = DataLoader(onsite_test_ds, batch_size=1, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader, onsite_test_loader

In [3]:
# ========================
# region BUILD MODEL
# ========================

def build_model(backbone="resnet18", num_classes=3):
    """
    Create a torchvision backbone whose final Linear layer has `num_classes` outputs.

    Supported values for `backbone` are "resnet18" and "efficientnet"
    (EfficientNet-B0); anything else raises ValueError. The network is built
    with `weights=None`, i.e. without torchvision's pretrained weights.
    """
    if backbone not in ("resnet18", "efficientnet"):
        raise ValueError(f"Unsupported backbone: {backbone}")

    if backbone == "resnet18":
        # TODO: should weights stay None, or use "IMAGENET1K_V1"?
        net = models.resnet18(weights=None)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net

    net = models.efficientnet_b0(weights=None)
    # classifier[1] is the Linear head; swap it, keeping its input width
    old_head: nn.Linear = net.classifier[1] # type: ignore[assignment]
    net.classifier[1] = nn.Linear(old_head.in_features, num_classes)
    return net


def get_model(backbone="resnet18", pretrained_params=None, freeze_backbone=False, num_classes=3):
    """
    Build a model, set which parameters are trainable, optionally load a
    checkpoint, and print a parameter summary.

    Args:
        backbone: architecture name forwarded to `build_model`.
        pretrained_params: path to a state-dict file, or None to keep the
            freshly initialised weights.
        freeze_backbone: if True only nn.Linear layers stay trainable,
            otherwise every parameter is trainable.
        num_classes: number of outputs of the classification head.

    Returns:
        The configured model (not moved to any device by this function).

    Exits:
        Calls sys.exit(2) if the checkpoint cannot be loaded into the chosen
        architecture.
    """

    model = build_model(backbone, num_classes)

    # parameters freezing
    if freeze_backbone:
        print('FREEZING: freezing model backbone (non-Linear layers)')
        model = freeze_non_linear_layers(model)

    else:
        print('FREEZING: Unfreezing all layers')
        for p in model.parameters():
            p.requires_grad = True

    # pretrained params
    if pretrained_params is not None:
        # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
        state_dict = torch.load(pretrained_params, map_location="cpu")
        try:
            model.load_state_dict(state_dict)
        except Exception as err:
            # `Exception` instead of a bare except, so KeyboardInterrupt is not
            # swallowed; the underlying error is reported so the mismatch can be diagnosed.
            print(f"ERROR: Incompatible backbone ({backbone}) and params file ({pretrained_params})\n{type(err).__name__}: {err}\nexiting ...")
            sys.exit(2)

    # print param amounts
    all_params, trainable_params = get_parameter_count(model)
    print('=====================')
    print('    LOADED MODEL')
    print('---------------------')
    print('backbone:', backbone)
    print('pretrained params:', pretrained_params)
    print('parameter count:  {:_d}'.format(all_params))
    print('trainable params: {:_d}'.format(trainable_params))
    print('frozen params:    {:_d}'.format(all_params-trainable_params))
    print('=====================')

    return model

In [4]:
# ========================
# region HELPERS
# ========================

def freeze_non_linear_layers(model):
    """
    Make only the nn.Linear layers of `model` trainable.

    All parameters are first marked as not requiring gradients; then the
    parameters of every nn.Linear module are switched back on. The same
    model object is returned.
    """
    # freeze everything ...
    model.requires_grad_(False)
    # ... then re-enable gradients for the Linear layers only
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.requires_grad_(True)
    return model

def get_parameter_count(model):
    """
    Count the scalar parameters of `model`.

    Returns:
        (all_params, trainable_params), where trainable parameters are those
        with `requires_grad` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

def ensure_parent_exists(file: str):
    """
    Create the parent directory of `file` (and any missing ancestors).

    A bare file name such as "out.csv" has an empty dirname; in that case
    there is nothing to create, and calling os.makedirs("") would raise
    FileNotFoundError, so the call is skipped.
    """
    parent = os.path.dirname(file)
    if parent:
        os.makedirs(parent, exist_ok=True)

# ========================
# region CUSTOM LOSS
# ========================

def FocalLoss(x, targets=None):
    """
    Focal Loss: A loss function designed to address class imbalance by downweighting easy examples and focusing
    training on hard, misclassified ones.

    Placeholder: not implemented yet. The optional `targets` argument lets the
    stub be called as `loss_fn(outputs, labels)`, the way the training loop
    calls its loss, so the caller gets NotImplementedError rather than a
    TypeError about the number of arguments.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("FocalLoss is not implemented yet")

def ClassBalancedLoss(x, targets=None):
    """
    Class-Balanced Loss: Re-weight the BCE loss according to class frequency. This is a common method for handling
    class imbalance.

    Placeholder: not implemented yet. The optional `targets` argument lets the
    stub be called as `loss_fn(outputs, labels)`, the way the training loop
    calls its loss, so the caller gets NotImplementedError rather than a
    TypeError about the number of arguments.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("ClassBalancedLoss is not implemented yet")

In [5]:
# ========================
# region predict
# ========================
def predict(
        model: nn.Module,
        loader: DataLoader,
        csv_path="onsite_test_submission.csv",
    ):
    """
    Run `model` over an unlabelled loader and write thresholded predictions to a CSV.

    Args:
        model: network producing one logit per label; it is switched to eval mode.
        loader: yields `(images, file_names)` batches of any batch size.
        csv_path: output file; ".csv" is appended when missing and the parent
            directory is created when the path has one.

    The CSV has the header `id,D,G,A` followed by one row per image holding
    the file name and a 0/1 flag per label (sigmoid probability > 0.5).
    """

    model.eval()
    data = []
    print(f'generating predictions for {len(loader.dataset)} images') # type: ignore
    with torch.no_grad():
        for imgs, img_names in tqdm(loader):
            logits = model(imgs.to(DEVICE))
            probs = torch.sigmoid(logits).cpu().numpy()
            preds = (probs > 0.5).astype(int)
            # Emit one row per sample so batches larger than 1 are not
            # truncated to their first element.
            for img_name, sample_preds in zip(img_names, preds):
                data.append([img_name, *sample_preds.tolist()])

    # write to csv
    if not csv_path.endswith(".csv"): csv_path += ".csv"
    # A bare file name (such as the default) has no parent directory to create.
    if os.path.dirname(csv_path):
        ensure_parent_exists(csv_path)
    print(f'writing predictions to {csv_path}')
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["id","D","G","A"])
        writer.writerows(data)

In [6]:
# ========================
# region test
# ========================
def test(
        model: nn.Module,
        loader: DataLoader,
    ):
    """
    Evaluate `model` on a labelled loader and report per-disease metrics.

    Predictions are sigmoid probabilities thresholded at 0.5. For each of the
    three label columns (DR, Glaucoma, AMD) accuracy, macro-averaged
    precision / recall / F1 and Cohen's kappa are computed.

    Returns:
        A DataFrame with one row per metric and the columns DR, Glaucoma, AMD
        plus three summaries taken over the disease columns only:
        SUM, MEAN and WEIGHTEDMEAN (weighted by each disease's share of
        positive labels; uniform weights if there are no positives at all).
    """

    print(f'Testing model on {len(loader.dataset)} images') # type: ignore

    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for imgs, labels in tqdm(loader, colour="magenta"):
            imgs = imgs.to(DEVICE)
            outputs = model(imgs)
            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = (probs > 0.5).astype(int)
            y_true.extend(labels.numpy())
            y_pred.extend(preds)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    # compute metrics
    disease_names = ["DR", "Glaucoma", "AMD"]
    results_data = []
    disease_counts = []

    for i, disease in enumerate(disease_names):  # compute metrics for every disease
        y_t = y_true[:, i]
        y_p = y_pred[:, i]

        disease_counts.append(y_t.sum())  # number of positive samples for this disease
        acc =       accuracy_score(y_t, y_p)
        precision = precision_score(y_t, y_p, average="macro", zero_division=0)
        recall =    recall_score(y_t, y_p, average="macro", zero_division=0)
        f1 =        f1_score(y_t, y_p, average="macro", zero_division=0)
        kappa =     cohen_kappa_score(y_t, y_p)

        results_data.append([disease, acc, precision, recall, f1, kappa])

    disease_count = np.array(disease_counts, dtype=float)
    results = pd.DataFrame(
        data=results_data,
        columns=["Disease", "Accuracy", "Precision", "Recall", "F1-score", "Kappa"],
    ).set_index("Disease")

    results = results.T

    # Snapshot the per-disease values BEFORE adding summary columns, so that
    # each summary is computed from the three disease columns only and not
    # from previously appended summaries.
    values = results.to_numpy(dtype=float, copy=True)

    # Weight of each disease = its share of all positive labels.
    total_positives = disease_count.sum()
    if total_positives > 0:
        weights = disease_count / total_positives
    else:
        weights = np.full(len(disease_names), 1.0 / len(disease_names))

    results["SUM"] = values.sum(axis=1)
    results["MEAN"] = values.mean(axis=1)
    results["WEIGHTEDMEAN"] = (values * weights).sum(axis=1)
    print("========================")
    print("DISEASE SPECIFIC METRICS:\n")
    print(results.T)

    return results

In [7]:
# ========================
# region train
# ========================
def train(
        model,
        train_loader,
        val_loader,
        optimizer,
        loss_fn,
        epochs=10,
        save_name="checkpoints/best.pt",
    ):
    """
    Train `model` for `epochs` epochs, validating after each one.

    After every epoch the sample-weighted mean training and validation losses
    are printed. Whenever the validation loss is lower than any seen so far,
    the model's state dict is written to `save_name` (its parent directory is
    created up front).

    Returns:
        `save_name`, the path the best checkpoint is written to.
    """

    ensure_parent_exists(save_name)
    print('loss function:', loss_fn)

    lowest_val_loss = float("inf")
    for epoch_no in range(1, epochs + 1):
        # ---- optimisation pass ----
        model.train()
        running_train = 0
        progress = tqdm(train_loader, desc=f"epoch {epoch_no}/{epochs}", colour="purple")
        for batch_imgs, batch_labels in progress:
            batch_imgs = batch_imgs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            optimizer.zero_grad()
            batch_loss = loss_fn(model(batch_imgs), batch_labels)
            batch_loss.backward()
            optimizer.step()

            # weight the batch loss by batch size so the epoch figure is a per-sample mean
            running_train += batch_loss.item() * batch_imgs.size(0)
        epoch_train_loss = running_train / len(train_loader.dataset) # type: ignore

        # ---- validation pass (no gradients) ----
        model.eval()
        running_val = 0
        with torch.no_grad():
            for batch_imgs, batch_labels in val_loader:
                batch_imgs = batch_imgs.to(DEVICE)
                batch_labels = batch_labels.to(DEVICE)
                batch_loss = loss_fn(model(batch_imgs), batch_labels)
                running_val += batch_loss.item() * batch_imgs.size(0)
        epoch_val_loss = running_val / len(val_loader.dataset) # type: ignore

        print(f"Train Loss: {epoch_train_loss:.4f} Val Loss: {epoch_val_loss:.4f}")

        # ---- keep the checkpoint with the lowest validation loss ----
        improved = epoch_val_loss < lowest_val_loss
        if improved:
            lowest_val_loss = epoch_val_loss
            print(f"    ..saving best checkpoint to: {save_name}")
            torch.save(model.state_dict(), save_name)

    return save_name

In [None]:
# ========================
# region MAIN
# ========================
def main(
        mode = "train",
        backbone = "resnet18",
        dataset_path = "dataset",
        pretrained_params: str|None = None,
        save_name: str|None = None,
        freeze_backbone = True, # freeze non-linear layers
        loss_fn = nn.BCEWithLogitsLoss,
        attention = None,
        predict_csv = "onsite_test_submission.csv",
        # save_dir="checkpoints",
        epochs=10, batch_size=32, lr=1e-4, img_size=256,
    ):
    
    # MODE
    match mode:
        case "train": # - train -------
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

            print(f'Training {backbone} for {epochs} epochs')
            train(
                model,
                train_loader,
                val_loader,
                optimizer,
                loss_fn,
                epochs=epochs,
                save_name=save_name,
            )
            test(
                model,
                test_loader,
            )
        
        case "test":  # - test ---------
            test(
                model,
                test_loader,
            )
        
        case "predict": # - predict ---
            print("Predicting ...")
            predict(
                model,
                onsite_test_loader,
                csv_path=predict_csv,
            )
        
        case "none": # - none ---------
            print("Just a test, nothing to see here")
        
        case _:
            print("oh no, no mode is matched??")

In [9]:
# ========================
# region CLI
# ========================

# Device used for the model and every batch: CUDA when available, else CPU.
DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('using device:', DEVICE)

# Loss lookup by short name. "bce" is a ready-to-call module instance; the
# other two entries are functions that currently raise NotImplementedError.
LOSS_FUNCS = dict(
    bce=nn.BCEWithLogitsLoss(),
    focal=FocalLoss,
    class_balanced=ClassBalancedLoss,
)

using device: cpu


In [13]:
# Loaders for the ODIR dataset: 256x256 images, batches of 32.
train_loader, val_loader, test_loader, onsite_test_loader = get_dataloaders(
    dataset_path="ODIR_dataset",
    img_size=256,
    batch_size=32,
)

# ResNet-18 with the saved checkpoint loaded; only Linear layers stay trainable.
print('Building model')
model = get_model(
    backbone="resnet18",
    pretrained_params="checkpoints/best_resnet18.pt",
    freeze_backbone=True,
)
model = model.to(DEVICE)

Building model
FREEZING: freezing model backbone (non-Linear layers)
    LOADED MODEL
---------------------
backbone: resnet18
pretrained params: checkpoints/best_resnet18.pt
parameter count:  11_178_051
trainable params: 1_539
frozen params:    11_176_512


In [14]:
# get metrics
# Collect thresholded predictions and ground-truth labels for the offsite test set.
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for imgs, labels in tqdm(test_loader, colour="magenta"):
        logits = model(imgs.to(DEVICE))
        probabilities = torch.sigmoid(logits).cpu().numpy()
        y_pred.extend((probabilities > 0.5).astype(int))  # 0/1 decision per label
        y_true.extend(labels.numpy())

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Per-disease metrics: one [name, accuracy, precision, recall, f1, kappa] row per label column.
disease_names = ["DR", "Glaucoma", "AMD"]
results_data = []
disease_counts = []  # number of positive samples per disease
macro_kwargs = dict(average="macro", zero_division=0)

for column, disease in enumerate(disease_names):
    truth, guess = y_true[:, column], y_pred[:, column]

    disease_counts.append(truth.sum())
    results_data.append([
        disease,
        accuracy_score(truth, guess),
        precision_score(truth, guess, **macro_kwargs),
        recall_score(truth, guess, **macro_kwargs),
        f1_score(truth, guess, **macro_kwargs),
        cohen_kappa_score(truth, guess),
    ])

disease_count = np.array(disease_counts)

100%|[35m██████████[0m| 7/7 [00:17<00:00,  2.52s/it]


In [40]:
# Metrics as rows, diseases as columns.
results = (
    pd.DataFrame(
        results_data,
        columns=["Disease", "Accuracy", "Precision", "Recall", "F1-score", "Kappa"],
    )
    .set_index("Disease")
    .T
)

In [45]:
# Each disease's share of all positive labels; used to weight the per-disease metrics.
weights = disease_count / np.sum(disease_count)

In [47]:
# Rebuild the table (metrics as rows, diseases as columns) ...
results = (
    pd.DataFrame(
        results_data,
        columns=["Disease", "Accuracy", "Precision", "Recall", "F1-score", "Kappa"],
    )
    .set_index("Disease")
    .T
)

# ... and append summary columns, all computed from the three disease columns only.
values = results.values
results = results.assign(
    SUM=np.sum(values, axis=1),
    MEAN=np.mean(values, axis=1),
    WEIGHTEDMEAN=np.sum(values * weights, axis=1),
    AVE=np.average(values, axis=1, weights=weights),  # same quantity as WEIGHTEDMEAN, via numpy
)
results

Disease,DR,Glaucoma,AMD,SUM,MEAN,WEIGHTEDMEAN,AVE
Accuracy,0.71,0.845,0.85,2.405,0.801667,0.755948,0.755948
Precision,0.691919,0.819196,0.670054,2.18117,0.727057,0.719197,0.719197
Recall,0.72619,0.731923,0.756384,2.214498,0.738166,0.73067,0.73067
F1-score,0.690667,0.760053,0.69752,2.148239,0.71608,0.707495,0.707495
Kappa,0.395833,0.525413,0.400958,1.322205,0.440735,0.42646,0.42646
