In [None]:
%pip install ipywidgets torch torchvision setuptools tensorflow kagglehub opencv-python numpy pandas scipy scikit-learn pillow matplotlib

In [None]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

import kagglehub
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from scipy.ndimage import label
from tensorflow import keras
from keras._tf_keras.keras.preprocessing.image import array_to_img

In [None]:
class GlaucomaDataset(Dataset):
    """Fundus images with optic-disc (OD) / optic-cup (OC) masks, fully loaded into memory.

    For every directory ``d`` in ``root_dir`` the images are read from
    ``d/Images_Square`` and, unless ``split == 'test'``, the matching masks from
    ``d/Masks_Square/<image stem>.png``. Mask pixels equal to 1 form the OD
    channel and pixels equal to 2 the OC channel.

    Args:
        root_dir: One directory path, or an iterable of directory paths.
        split: ``'test'`` loads images only; any other value also loads masks.
        output_size: ``(H, W)`` that images (bilinear) and masks (nearest) are
            resized to.

    Items are ``img`` (a (3, H, W) float tensor) when ``split == 'test'``,
    otherwise ``(img, seg)`` where ``seg`` stacks the OD mask (channel 0) and the
    OC mask (channel 1) — (2, H, W) assuming single-channel mask files.

    A sample whose image or mask fails to load is reported and skipped as a
    whole, so ``images[i]`` and ``segs[i]`` always come from the same file.
    ``image_filenames`` lists the basenames of the loaded samples, in order.
    """

    def __init__(self, root_dir, split='train', output_size=(256, 256)):
        # A single path is accepted too; iterating a bare string would walk its characters.
        if isinstance(root_dir, (str, os.PathLike)):
            root_dir = [root_dir]
        self.output_size = output_size
        self.root_dir = root_dir
        self.split = split
        self.images = []
        self.segs = []
        # Basenames of the samples actually loaded, parallel to self.images.
        self.image_filenames = []

        for direct in self.root_dir:
            image_dir = os.path.join(direct, "Images_Square")
            mask_dir = os.path.join(direct, "Masks_Square")
            # Sorted so the sample order does not depend on the filesystem listing order.
            filenames = sorted(p for p in os.listdir(image_dir) if not p.startswith('.'))

            for k, filename in enumerate(filenames):
                print(f'Loading {split} image {k}/{len(filenames)}...', end='\r')
                try:
                    img = self._load_image(os.path.join(image_dir, filename))
                except Exception as e:
                    print(f"\nError loading image {filename}: {e}")
                    continue

                seg = None
                if split != 'test':
                    # splitext handles any extension length (".jpg", ".jpeg", ...).
                    mask_name = os.path.splitext(filename)[0] + ".png"
                    try:
                        seg = self._load_mask(os.path.join(mask_dir, mask_name))
                    except Exception as e:
                        print(f"\nError loading segmentation for {filename}: {e}")
                        continue

                # Append only once every part loaded, so images and segs stay paired by index.
                self.images.append(img)
                if seg is not None:
                    self.segs.append(seg)
                self.image_filenames.append(filename)

        print(f'Successfully loaded {split} dataset.', ' ' * 50)

    def _load_image(self, img_path):
        """Read an image as an RGB float tensor (3, H, W) resized bilinearly to output_size."""
        with Image.open(img_path) as pil_img:
            img = transforms.functional.to_tensor(np.array(pil_img.convert('RGB')))
        return transforms.functional.resize(
            img, self.output_size, interpolation=transforms.InterpolationMode.BILINEAR)

    def _load_mask(self, seg_path):
        """Read a label mask and return its OD (==1) and OC (==2) channels, nearest-resized."""
        with Image.open(seg_path) as pil_mask:
            mask = np.array(pil_mask)
        channels = []
        for value in (1, 2):  # channel 0: optic disc, channel 1: optic cup
            binary = torch.from_numpy((mask == value).astype(np.float32)[None, :, :])
            channels.append(transforms.functional.resize(
                binary, self.output_size, interpolation=transforms.InterpolationMode.NEAREST))
        return torch.cat(channels, dim=0)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.split == 'test':
            return img
        return img, self.segs[idx]


class TestDataset(Dataset):
    """Fundus images paired with glaucoma labels from a JSON index, loaded lazily.

    ``json_path`` must hold a JSON object whose values are records with
    ``"ImgName"`` (a file name inside ``image_dir``) and ``"Label"`` (converted
    with ``int``). Only visible files that have such a record are kept, so every
    index yields a labelled sample; files without one are reported and skipped.

    Items are ``(img, label)`` with ``img`` a (3, H, W) float tensor resized
    (bilinear) to ``output_size`` and ``label`` an ``int``.
    """

    def __init__(self, image_dir, json_path, output_size=(256, 256)):
        self.output_size = output_size
        self.image_dir = image_dir

        # Load ground truth glaucoma labels from the JSON file
        with open(json_path, 'r') as f:
            self.ground_truth = json.load(f)

        # Map image filename -> label (a later duplicate "ImgName" overrides an earlier one)
        self.filename_to_label = {
            record["ImgName"]: record["Label"] for record in self.ground_truth.values()
        }

        # Keep only visible, labelled files; sorted for a deterministic order.
        self.image_filenames = []
        for name in sorted(os.listdir(image_dir)):
            if name.startswith('.'):
                continue
            if name in self.filename_to_label:
                self.image_filenames.append(name)
            else:
                print(f"Skipping {name}: no glaucoma label in {json_path}")

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        # Errors propagate instead of returning (None, None): the default collate
        # function cannot batch None and would fail with an unrelated TypeError.
        file_id = self.image_filenames[idx]
        img_path = os.path.join(self.image_dir, file_id)
        with Image.open(img_path) as pil_img:
            img = transforms.functional.to_tensor(np.array(pil_img.convert('RGB')))
        img = transforms.functional.resize(
            img, self.output_size, interpolation=transforms.InterpolationMode.BILINEAR)
        return img, int(self.filename_to_label[file_id])


In [None]:
# Small constant added to denominators (see vertical_cup_to_disc_ratio) to avoid division by zero.
EPS = 1.0e-7

def compute_dice_coef(input, target):
    """Mean of the per-sample Dice coefficients over the batch (first) dimension.

    Args:
        input: predicted masks, indexed as ``input[k, :, :]`` (so at least 3-D, (B, H, W)).
        target: ground-truth masks with the same leading batch size.

    Returns:
        The average of ``dice_coef_sample`` over the B samples. An empty batch
        raises ZeroDivisionError.
    """
    n_samples = input.shape[0]
    running_total = 0
    for idx in range(n_samples):
        running_total = running_total + dice_coef_sample(input[idx, :, :], target[idx, :, :])
    return running_total / n_samples

def dice_coef_sample(input, target):
    """Dice coefficient 2|A∩B| / (|A| + |B|) between two binary masks.

    Both tensors are flattened, so any shape works as long as the element
    counts match. When both masks are empty the score is defined as 1.0
    (perfect agreement) instead of the 0/0 NaN the bare formula gives, which
    would otherwise turn the batch and epoch averages into NaN.

    Returns:
        0-d floating-point tensor on the inputs' device.
    """
    iflat = input.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    numerator = 2. * (iflat * tflat).sum()
    denominator = iflat.sum() + tflat.sum()
    if denominator == 0:
        return torch.ones_like(numerator)
    return numerator / denominator

def vertical_diameter(binary_segmentation):
    """Per-image vertical size of a batch of binary masks, in pixels.

    Expects a (B, H, W) array: foreground pixels are counted down each column
    (axis 1) and the largest column count is returned for each image, giving a
    (B,) array. This equals the vertical extent when each column's foreground
    is contiguous.
    """
    column_heights = np.sum(binary_segmentation, axis=1)
    return np.max(column_heights, axis=1)

def vertical_cup_to_disc_ratio(od, oc):
    """Vertical cup-to-disc ratio (vCDR) for each image of a batch.

    ``od`` / ``oc`` are the optic-disc and optic-cup masks passed to
    ``vertical_diameter``. EPS keeps the division finite: an empty disc yields
    cup/EPS (0 if the cup is empty too, very large otherwise).
    """
    disc = vertical_diameter(od)
    cup = vertical_diameter(oc)
    return cup / (disc + EPS)

def compute_vCDR_error(pred_od, pred_oc, gt_od, gt_oc):
    """Mean absolute vCDR error between predicted and ground-truth OD/OC masks.

    Returns:
        (mean_abs_error, predicted_vCDRs, ground_truth_vCDRs), where the last
        two are the per-image ratios from ``vertical_cup_to_disc_ratio``.
    """
    predicted = vertical_cup_to_disc_ratio(pred_od, pred_oc)
    reference = vertical_cup_to_disc_ratio(gt_od, gt_oc)
    absolute_errors = np.abs(reference - predicted)
    return np.mean(absolute_errors), predicted, reference

def refine_seg(pred):
    """Keep only the largest connected foreground component of each mask in a batch.

    Args:
        pred: CPU tensor indexed as (B, H, W) whose non-zero entries are
            foreground; ``.numpy()`` is called on it, so it must be on the CPU.

    Returns:
        float32 tensor with the same shape, 1.0 on the largest component
        (connectivity as defined by ``scipy.ndimage.label``'s default structure)
        and 0.0 elsewhere. An image with no foreground stays all-zero; it no
        longer turns into an all-ones mask. An empty batch returns an empty tensor.
    """
    np_pred = pred.numpy()
    largest_ccs = []
    for sample in np_pred:
        labeled, n_components = label(sample)
        if n_components == 0:
            largest_cc = np.zeros(labeled.shape, dtype=bool)
        else:
            sizes = np.bincount(labeled.ravel())[1:]  # drop background label 0
            largest_cc = labeled == np.argmax(sizes) + 1
        largest_ccs.append(torch.tensor(largest_cc, dtype=torch.float32))
    if not largest_ccs:
        return torch.zeros(np_pred.shape, dtype=torch.float32)
    return torch.stack(largest_ccs)


In [None]:
class UNet(nn.Module):
    """U-Net with six 2x max-pool stages that outputs per-pixel probabilities.

    A 256x256 input reaches 4x4 at the bottleneck. The decoder upsamples
    (``Up``) and concatenates the matching encoder output at every level. The
    final 1x1 convolution passes through a sigmoid, so each of the
    ``n_classes`` channels is an independent probability map.

    NOTE: attribute names and their creation order define the state_dict keys
    and parameter order that saved checkpoints rely on — keep them unchanged.
    """

    def __init__(self, n_channels=3, n_classes=2):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.epoch = 0  # epoch counter advanced by the training loop; plain attribute, not in state_dict
        # Up upsamples without changing channels and then concatenates the skip
        # tensor, so the bottleneck and decoder outputs are halved to make each
        # concatenation match the next Up's in_channels.
        halve = 2
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.down5 = Down(1024, 2048)
        self.down6 = Down(2048, 4096 // halve)
        self.up1 = Up(4096, 2048 // halve)
        self.up2 = Up(2048, 1024 // halve)
        self.up3 = Up(1024, 512 // halve)
        self.up4 = Up(512, 256 // halve)
        self.up5 = Up(256, 128 // halve)
        self.up6 = Up(128, 64)
        self.output_layer = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every stage output; the last one is the bottleneck.
        features = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            features.append(stage(features[-1]))
        # Decoder: consume the skip connections deepest-first.
        out = features.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4, self.up5, self.up6):
            out = stage(out, features.pop())
        return torch.sigmoid(self.output_layer(out))

class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) layers.

    padding=1 keeps the spatial size. ``mid_channels`` (the width between the
    two convolutions) falls back to ``out_channels`` when it is falsy.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling stage: 2x2 max-pool followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling stage: bilinear 2x upsample, pad to the skip tensor's size, concatenate, DoubleConv.

    The DoubleConv sees ``in_channels`` (skip + upsampled channels) and uses
    ``in_channels // 2`` as its middle width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        # F.pad takes (left, right, top, bottom) for the last two dims; an odd
        # remainder goes to the right/bottom side.
        left, top = pad_w // 2, pad_h // 2
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])
        return self.conv(torch.cat([x2, upsampled], dim=1))

class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to one score map per class."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.conv(x)

In [None]:
# Download the Kaggle glaucoma datasets; `path` is the local download folder.
path = kagglehub.dataset_download("arnavjain1/glaucoma-datasets")

# G1020 + ORIGA form the training set; REFUGE is used for validation.
train_dir, train_dir2 = (os.path.join(path, name) for name in ('G1020', 'ORIGA'))
train_dirs = [train_dir, train_dir2]
val_dir = [os.path.join(path, 'REFUGE')]

train_set = GlaucomaDataset(train_dirs, split='train')
val_set = GlaucomaDataset(val_dir, split='val')

# Same loader settings for both splits; only training batches are shuffled.
loader_options = dict(batch_size=8, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_options)
val_loader = DataLoader(val_set, shuffle=False, **loader_options)

Successfully loaded train dataset.                                                   
Successfully loaded train dataset.                                                   
Successfully loaded val dataset.                                                   


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model, loss, optimizer. UNet.forward already applies a sigmoid, so BCELoss
# (which expects probabilities) is the matching criterion.
model = UNet(n_channels=3, n_classes=2).to(device)
seg_loss = torch.nn.BCELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Resume from an earlier run's checkpoint when one exists. The existence check
# keeps Restart & Run All working on a machine without the file, and restoring
# the optimizer state keeps Adam's running moments consistent with the weights.
# weights_only=True: the checkpoint only holds state dicts, so the restricted
# unpickler is sufficient.
checkpoint_path = 'unet_model.pth'
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
    print(f"No checkpoint at {checkpoint_path!r}; training from scratch.")

# Training and validation bookkeeping
nb_train_batches = len(train_loader)
nb_val_batches = len(val_loader)
nb_iter = 0
val_losses = []
train_losses = []
# best_val_auc, iters, train_accuracy and val_accuracy are not referenced by
# the training/evaluation cells of this notebook.
best_val_auc = 0.
iters = list(range(1, 10))
train_accuracy = []
val_accuracy = []
num_epochs = 25

  checkpoint = torch.load('unet_model.pth', map_location=device)


In [None]:
def _segmentation_metrics(logits, seg_gts):
    """Score one batch of predicted probability maps against the ground-truth masks.

    Channel 0 is the optic disc (OD) and channel 1 the optic cup (OC).
    Predictions are thresholded at 0.5 and reduced to their largest connected
    component with ``refine_seg``; scoring runs on the CPU because
    ``refine_seg`` and the vCDR helpers work on NumPy arrays.

    Returns:
        (dice_od, dice_oc, vCDR_error, pred_vCDR): batch-mean Dice scores as
        Python floats, the mean absolute vCDR error, and the per-image
        predicted vCDR array.
    """
    pred_od = refine_seg((logits[:, 0, :, :] >= 0.5).type(torch.int8).cpu())
    pred_oc = refine_seg((logits[:, 1, :, :] >= 0.5).type(torch.int8).cpu())
    gt_od = seg_gts[:, 0, :, :].type(torch.int8).cpu()
    gt_oc = seg_gts[:, 1, :, :].type(torch.int8).cpu()
    dsc_od = compute_dice_coef(pred_od, gt_od).item()
    dsc_oc = compute_dice_coef(pred_oc, gt_oc).item()
    vCDR_error, pred_vCDR, _ = compute_vCDR_error(
        pred_od.numpy(), pred_oc.numpy(), gt_od.numpy(), gt_oc.numpy())
    return dsc_od, dsc_oc, vCDR_error, pred_vCDR


# Relative path, matching the torch.load('unet_model.pth', ...) calls in this
# notebook; the former absolute '/content/...' path only coincided with them
# when the working directory happened to be /content.
save_path = 'unet_model.pth'

while model.epoch < num_epochs:
    # Per-epoch accumulators
    train_vCDRs, val_vCDRs = [], []
    train_loss, val_loss = 0., 0.
    train_dsc_od, val_dsc_od = 0., 0.
    train_dsc_oc, val_dsc_oc = 0., 0.
    train_vCDR_error, val_vCDR_error = 0., 0.

    model.train()
    for k, (imgs, seg_gts) in enumerate(train_loader):
        imgs, seg_gts = imgs.to(device), seg_gts.to(device)

        # Forward pass
        logits = model(imgs)
        loss = seg_loss(logits, seg_gts)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss / nb_train_batches
        train_losses.append(batch_loss)  # per-batch loss history

        with torch.no_grad():
            dsc_od, dsc_oc, vCDR_error, pred_vCDR = _segmentation_metrics(logits, seg_gts)
        train_dsc_od += dsc_od / nb_train_batches
        train_dsc_oc += dsc_oc / nb_train_batches
        train_vCDRs += pred_vCDR.tolist()
        train_vCDR_error += vCDR_error / nb_train_batches
        nb_iter += 1
        print('Epoch {}, iter {}/{}, loss {:.6f}'.format(model.epoch + 1, k + 1, nb_train_batches, batch_loss) + ' ' * 20,
              end='\r')

    model.eval()
    with torch.no_grad():
        for k, (imgs, seg_gts) in enumerate(val_loader):
            imgs, seg_gts = imgs.to(device), seg_gts.to(device)
            logits = model(imgs)
            batch_loss = seg_loss(logits, seg_gts).item()
            val_loss += batch_loss / nb_val_batches
            val_losses.append(batch_loss)  # per-batch loss history
            print('Validation iter {}/{}'.format(k + 1, nb_val_batches) + ' ' * 50,
                  end='\r')
            dsc_od, dsc_oc, vCDR_error, pred_vCDR = _segmentation_metrics(logits, seg_gts)
            val_dsc_od += dsc_od / nb_val_batches
            val_dsc_oc += dsc_oc / nb_val_batches
            val_vCDRs += pred_vCDR.tolist()
            val_vCDR_error += vCDR_error / nb_val_batches

    print('Epoch {}'.format(model.epoch + 1) + ' ' * 50)
    print('Loss: {:.4f} (train), {:.4f} (val)'.format(train_loss, val_loss))
    print('Dice Score - OD segmentation: {:.4f} (train), {:.4f} (val)'.format(train_dsc_od, val_dsc_od))
    print('Dice Score - OC segmentation: {:.4f} (train), {:.4f} (val)'.format(train_dsc_oc, val_dsc_oc))
    print('vCDR error: {:.4f} (train), {:.4f} (val)'.format(train_vCDR_error, val_vCDR_error))
    # End of epoch
    model.epoch += 1

    # Checkpoint after every epoch so an interrupted run keeps the progress of
    # all completed epochs (previously only a full run was ever saved).
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': model.epoch,
    }, save_path)

  checkpoint = torch.load('unet_model.pth', map_location=device)


Epoch 1                                                  
Loss: 0.0050 (train), 0.0092 (val)
Dice Score - OD segmentation: 0.8919 (train), 0.7881 (val)
Dice Score - OC segmentation: 0.7498 (train), 0.8014 (val)
vCDR error: 1.0212 (train), 0.1764 (val)
Epoch 2                                                  
Loss: 0.0043 (train), 0.0080 (val)
Dice Score - OD segmentation: 0.8987 (train), 0.7908 (val)
Dice Score - OC segmentation: 0.7553 (train), 0.8407 (val)
vCDR error: 1.0104 (train), 0.1537 (val)
Epoch 3                                                  
Loss: 0.0037 (train), 0.0077 (val)
Dice Score - OD segmentation: 0.9069 (train), 0.8185 (val)
Dice Score - OC segmentation: 0.7617 (train), 0.8307 (val)
vCDR error: 1.0112 (train), 0.1717 (val)
Epoch 4                                                  
Loss: 0.0034 (train), 0.0086 (val)
Dice Score - OD segmentation: 0.9112 (train), 0.7827 (val)
Dice Score - OC segmentation: 0.7712 (train), 0.8339 (val)
vCDR error: 0.9739 (train), 0.142

In [None]:
test_dir = os.path.join(path, 'REFUGE', 'train', 'Images')
json_path = os.path.join(path, 'REFUGE', 'train', 'index.json')
# NOTE(review): the validation set in the data-loading cell is also built from
# the REFUGE folder; confirm these images are disjoint from it, otherwise the
# numbers below are not a held-out test.

# Threshold for vCDR classification (images with predicted vCDR > 0.6 are labeled as glaucoma)
vCDR_threshold = 0.6
test_set = TestDataset(test_dir, json_path)
# shuffle=False keeps the loader order aligned with test_set.image_filenames (zipped below).
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

# weights_only=True: the checkpoint only holds state dicts, so the restricted
# unpickler suffices and torch.load does not execute arbitrary pickled code.
checkpoint = torch.load('unet_model.pth', map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
tp, tn, fp, fn = 0, 0, 0, 0
output_file_path = "glaucoma_predictions.txt"

# Perform predictions and collect data
predictions, ground_truth_labels, report_lines = [], [], []
with torch.no_grad():
    for (img, ground_truth_label), filename in zip(test_loader, test_set.image_filenames):
        logits = model(img.to(device))

        # Largest-component OD/OC masks -> predicted vertical cup-to-disc ratio
        pred_od = refine_seg((logits[:, 0, :, :] >= 0.5).type(torch.int8).cpu())
        pred_oc = refine_seg((logits[:, 1, :, :] >= 0.5).type(torch.int8).cpu())
        pred_vCDR = vertical_cup_to_disc_ratio(pred_od.numpy(), pred_oc.numpy())[0]
        predicted_label = int(pred_vCDR > vCDR_threshold)
        ground_truth_label = int(ground_truth_label.item())
        predictions.append(predicted_label)
        ground_truth_labels.append(ground_truth_label)
        report_lines.append(
            f"Image: {filename}, vCDR: {pred_vCDR:.2f}, Prediction: {predicted_label}, Ground Truth: {ground_truth_label}\n"
        )

        # Update TP, TN, FP, FN counts
        if predicted_label == 1 and ground_truth_label == 1:
            tp += 1
        elif predicted_label == 0 and ground_truth_label == 0:
            tn += 1
        elif predicted_label == 1 and ground_truth_label == 0:
            fp += 1
        elif predicted_label == 0 and ground_truth_label == 1:
            fn += 1

# Overall accuracy (NaN instead of a NumPy empty-mean warning when nothing was evaluated)
if predictions:
    accuracy = float(np.mean([pred == gt for pred, gt in zip(predictions, ground_truth_labels)]))
else:
    accuracy = float('nan')
# Accuracy alone is misleading when the classes are imbalanced, so report both error types.
sensitivity = tp / (tp + fn) if tp + fn else float('nan')
specificity = tn / (tn + fp) if tn + fp else float('nan')

# Header line first, then one line per image, written in a single pass.
with open(output_file_path, "w") as f:
    f.write(f"Overall Accuracy: {accuracy * 100:.2f}%\n")
    f.writelines(report_lines)

print(f"Test Accuracy for Glaucoma Classification: {accuracy * 100:.2f}%")
print(f"True Positives (TP): {tp}")
print(f"True Negatives (TN): {tn}")
print(f"False Positives (FP): {fp}")
print(f"False Negatives (FN): {fn}")
print(f"Sensitivity (TP / (TP + FN)): {sensitivity * 100:.2f}%")
print(f"Specificity (TN / (TN + FP)): {specificity * 100:.2f}%")
print(f"Predictions have been saved to {output_file_path}")


  checkpoint = torch.load('unet_model.pth', map_location=device)


Test Accuracy for Glaucoma Classification: 89.50%
True Positives (TP): 33
True Negatives (TN): 325
False Positives (FP): 35
False Negatives (FN): 7
Predictions have been saved to glaucoma_predictions.txt
