In [1]:
import os
from datasets import load_dataset

# Cache directory for the Hugging Face download. Set the OLIVES_CACHE_DIR
# environment variable to relocate it instead of editing the notebook; the
# fallback is the original (user-specific) scratch path.
scratch_path = os.environ.get('OLIVES_CACHE_DIR', '/scratch/zdiao7/OLIVES_Dataset')
os.makedirs(scratch_path, exist_ok=True)

# Training split of the 'biomarker_detection' configuration.
olives = load_dataset(
    'gOLIVES/OLIVES_Dataset',
    'biomarker_detection',
    split='train',
    cache_dir=scratch_path,
)

print(f"Dataset downloaded to: {scratch_path}")
print(f"Number of samples in the dataset: {len(olives)}")

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/32 [00:00<?, ?files/s]

Generating train split:   0%|          | 0/78822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3871 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/38 [00:00<?, ?it/s]

Dataset downloaded to: /scratch/zdiao7/OLIVES_Dataset
Number of samples in the dataset: 78822


In [2]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torchvision.models as models
# Switch the dataset's output format to torch so indexed/filtered rows yield tensors.
olives = olives.with_format(type="torch")

In [3]:
# Drop training rows whose biomarker labels B1/B3/B4/B5 contain NaN.
# `input_columns` passes only these four columns to the predicate (as
# positional arguments), so the Image column does not have to be decoded
# just to inspect the labels.
olives = olives.filter(
    lambda b1, b3, b4, b5: not any(
        torch.isnan(torch.as_tensor(value)).any() for value in (b1, b3, b4, b5)
    ),
    input_columns=['B1', 'B3', 'B4', 'B5'],
)
print(f"Filtered dataset contains {len(olives)} samples.")

Filter:   0%|          | 0/78822 [00:00<?, ? examples/s]

Filtered dataset contains 17591 samples.


In [4]:
def is_valid_sample(data_point):
    """Return True when a dataset row is usable for training.

    A row is accepted only if its ``'Image'`` is not None and each of the
    clinical values (``'BCVA'``, ``'CST'``) and biomarker labels (``'B1'``,
    ``'B3'``, ``'B4'``, ``'B5'``) is present (not None) and contains no NaN.

    Args:
        data_point: mapping from column name to value; values may be tensors
            (e.g. after ``with_format("torch")``) or plain Python numbers.

    Returns:
        bool: True if the row passes every check.
    """
    if data_point['Image'] is None:
        return False
    for key in ('BCVA', 'CST', 'B1', 'B3', 'B4', 'B5'):
        value = data_point[key]
        # `torch.as_tensor` reuses an existing tensor; `torch.tensor(tensor)`
        # would copy-construct it and emit a UserWarning for every row.
        if value is None or torch.isnan(torch.as_tensor(value)).any():
            return False
    return True

# Apply the stricter row check (image present, clinical values and labels not NaN).
olives = olives.filter(function=is_valid_sample)
print(f"Filtered dataset contains {len(olives)} samples.")

Filter:   0%|          | 0/17591 [00:00<?, ? examples/s]

  and not torch.isnan(torch.tensor(data_point['BCVA']))
  and not torch.isnan(torch.tensor(data_point['CST']))
  and not any(torch.isnan(torch.tensor(data_point[key])) for key in ['B1', 'B3', 'B4', 'B5'])


Filtered dataset contains 17444 samples.


In [8]:
class OlivesDataset(Dataset):
    """Map-style dataset yielding a single-channel image plus clinical and label vectors.

    Each item is a dict with:
      * ``'Image'``    - float32 image tensor with shape (1, H, W) before ``transform``,
      * ``'Clinical'`` - float32 tensor ``[BCVA, CST]``,
      * ``'Label'``    - float32 tensor ``[B1, B3, B4, B5]``.

    Args:
        data: indexable rows supporting ``len(data)`` and ``data[idx]`` returning a
            mapping with the keys above and an ``'Image'`` tensor (e.g. a Hugging
            Face ``Dataset`` after ``with_format("torch")``).
        transform: optional callable applied to the (1, H, W) image tensor.
        scale_to_unit: if True, uint8 images are divided by 255 so pixels lie in
            [0, 1] (the range ``Normalize(mean=0.5, std=0.5)`` maps to [-1, 1]).
            Defaults to False, keeping the raw 0-255 values so checkpoints trained
            on them stay compatible; enabling it requires retraining.
    """

    _CLINICAL_KEYS = ('BCVA', 'CST')
    _LABEL_KEYS = ('B1', 'B3', 'B4', 'B5')

    def __init__(self, data, transform=None, scale_to_unit=False):
        self.data = data
        self.transform = transform
        self.scale_to_unit = scale_to_unit

    def _to_single_channel(self, raw):
        """Convert an image tensor to float32 with shape (1, H, W)."""
        was_uint8 = raw.dtype == torch.uint8
        img = raw.float()
        # Channels-last RGB (H, W, 3) -> channels-first; (3, H, W) inputs are left alone.
        if img.ndim == 3 and img.shape[-1] == 3 and img.shape[0] not in (1, 3):
            img = img.permute(2, 0, 1)
        if img.ndim == 3 and img.shape[0] == 3:
            # Collapse RGB to one channel by averaging.
            img = img.mean(dim=0, keepdim=True)
        elif img.ndim == 2:
            img = img.unsqueeze(0)
        if self.scale_to_unit and was_uint8:
            img = img / 255.0
        return img

    def __getitem__(self, idx):
        # Fetch the row once and reuse it: each `self.data[idx]` on a formatted
        # Hugging Face dataset formats the whole row again, image included.
        row = self.data[idx]

        img = self._to_single_channel(row['Image'])
        if self.transform:
            img = self.transform(img)

        clinical = torch.tensor([float(row[k]) for k in self._CLINICAL_KEYS], dtype=torch.float32)
        label = torch.tensor([float(row[k]) for k in self._LABEL_KEYS], dtype=torch.float32)
        return {'Image': img, 'Clinical': clinical, 'Label': label}

    def __len__(self):
        return len(self.data)

In [9]:
# Image preprocessing (reused for the test set later): resize each image tensor
# to 384x384, then apply (x - 0.5) / 0.5.
# NOTE(review): Normalize(0.5, 0.5) maps [0, 1] to [-1, 1]; if the decoded
# images are uint8, OlivesDataset's `.float()` leaves them in 0-255 - confirm
# the intended input range.
transform = transforms.Compose(
    [
        transforms.Resize(size=(384, 384)),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
train_dataset = OlivesDataset(data=olives, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

In [11]:
import torch.nn.functional as F
class AttentionBlock(nn.Module):
    """Additive attention gate that re-weights skip features ``x`` using a gating signal ``g``.

    ``g`` is bilinearly resized to the spatial size of ``x``. Both tensors are
    projected to ``F_int`` channels by a 1x1 convolution followed by BatchNorm,
    then summed and passed through ReLU. A 1x1 convolution + BatchNorm + sigmoid
    reduces the result to a one-channel map that multiplies ``x``.

    Args:
        F_g: channel count of the gating signal ``g``.
        F_l: channel count of the skip features ``x``.
        F_int: channel count of the intermediate projections.
    """

    @staticmethod
    def _project(in_channels, out_channels):
        """1x1 convolution (with bias) followed by BatchNorm over ``out_channels``."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(out_channels),
        )

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Attribute names form the state_dict keys; keep them (and creation order) stable.
        self.W_g = self._project(F_g, F_int)
        self.W_x = self._project(F_l, F_int)
        self.psi = nn.Sequential(*self._project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate = F.interpolate(g, size=x.shape[2:], mode='bilinear', align_corners=True)
        coefficients = self.psi(self.relu(self.W_g(gate) + self.W_x(x)))
        return x * coefficients

class AttentionUNet(nn.Module):
    """U-Net-style encoder/decoder with attention gates on the three deepest skip connections.

    Four encoder stages (64/128/256/512 channels, each followed by 2x2 max-pooling)
    feed a 1024-channel bottleneck. The decoder upsamples with bilinear
    interpolation. Skips from encoder stages 4, 3 and 2 are re-weighted by an
    AttentionBlock gated on the coarser decoder feature; stage 1 is concatenated
    unchanged. A final 1x1 convolution maps 64 channels to ``output_channels``,
    and the output has the same spatial size as the input.

    Args:
        input_channels: channel count of the input image.
        output_channels: channel count of the returned feature map.
    """

    def __init__(self, input_channels=1, output_channels=256):
        super().__init__()
        # Submodule names and creation order are kept so state_dicts load and
        # seeded initialisation is unchanged.
        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.middle = self.conv_block(512, 1024)

        # AttentionBlock(gating channels, skip channels, intermediate channels).
        self.att4 = AttentionBlock(1024, 512, 256)
        self.att3 = AttentionBlock(512, 256, 128)
        self.att2 = AttentionBlock(256, 128, 64)

        # Each decoder stage sees [upsampled coarser feature, (gated) skip].
        self.dec4 = self.conv_block(1024 + 512, 512)
        self.dec3 = self.conv_block(512 + 256, 256)
        self.dec2 = self.conv_block(256 + 128, 128)
        self.dec1 = self.conv_block(128 + 64, 64)

        self.output_conv = nn.Conv2d(64, output_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv, padding=1 -> BatchNorm -> ReLU) layers; spatial size is preserved."""
        layers = []
        for conv_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    @staticmethod
    def _resize_like(tensor, reference):
        """Bilinearly resize ``tensor`` to the spatial size of ``reference``."""
        return F.interpolate(tensor, size=reference.shape[2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(F.max_pool2d(skip1, 2))
        skip3 = self.enc3(F.max_pool2d(skip2, 2))
        skip4 = self.enc4(F.max_pool2d(skip3, 2))
        bottleneck = self.middle(F.max_pool2d(skip4, 2))

        up4 = self.dec4(torch.cat([self._resize_like(bottleneck, skip4), self.att4(bottleneck, skip4)], dim=1))
        up3 = self.dec3(torch.cat([self._resize_like(up4, skip3), self.att3(up4, skip3)], dim=1))
        up2 = self.dec2(torch.cat([self._resize_like(up3, skip2), self.att2(up3, skip2)], dim=1))
        up1 = self.dec1(torch.cat([self._resize_like(up2, skip1), skip1], dim=1))
        return self.output_conv(up1)

In [17]:
class MultiAttentionUNet(nn.Module):
    """Fuse pooled AttentionUNet image features with an embedding of two clinical values.

    The image branch (1 input channel, 256 output channels) is global-average-pooled
    to a 256-vector. The clinical branch maps 2 inputs to 16 features. The
    concatenated 272 features pass through an MLP that returns 4 outputs per sample.
    """

    def __init__(self):
        super().__init__()
        # Same submodule names and creation order, so saved state_dicts still load.
        self.image_model = AttentionUNet(input_channels=1, output_channels=256)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.clinical_fc = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 16))
        self.fusion_fc = nn.Sequential(nn.Linear(256 + 16, 128), nn.ReLU(), nn.Linear(128, 4))

    def forward(self, images, clinical_features):
        feature_map = self.image_model(images)
        pooled = torch.flatten(self.global_pool(feature_map), start_dim=1)
        clinical_embedding = self.clinical_fc(clinical_features)
        return self.fusion_fc(torch.cat((pooled, clinical_embedding), dim=1))

In [18]:
# Training configuration: GPU when available, BCE on raw logits, Adam at 1e-4,
# and the learning rate multiplied by 0.4 every 6 scheduler steps.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = MultiAttentionUNet().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=6, gamma=0.4)

def train_model(model, train_loader, criterion, optimizer, scheduler, epochs=10, save_path="trained_attention_unet.pth"):
    """Train ``model`` for ``epochs`` epochs, print the mean batch loss per epoch, then save it.

    Each batch must be a dict with ``'Image'``, ``'Clinical'`` and ``'Label'``
    tensors. The model is called as ``model(images, clinical)`` and the loss as
    ``criterion(outputs, labels)``. ``scheduler.step()`` runs once per epoch.
    The ``state_dict`` is written to ``save_path`` only after the final epoch.

    Tensors are moved to the device of the model's parameters (CPU for a
    parameter-less model) rather than a global ``device``. An epoch with no
    batches reports ``nan`` instead of dividing by zero.

    Args:
        model: module to train (set to train mode here).
        train_loader: iterable of batch dicts; it need not define ``len``.
        criterion: loss function.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler stepped once per epoch.
        epochs: number of passes over ``train_loader``.
        save_path: file path for ``torch.save`` of the final ``state_dict``.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            images = batch['Image'].to(device)
            clinical = batch['Clinical'].to(device)
            labels = batch['Label'].to(device)
            optimizer.zero_grad()
            loss = criterion(model(images, clinical), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        scheduler.step()
        mean_loss = total_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {mean_loss:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

In [11]:
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=40,
    save_path="trained_attention_unet.pth",
)

Epoch 1/40, Loss: 0.5099
Epoch 2/40, Loss: 0.3966
Epoch 3/40, Loss: 0.3446
Epoch 4/40, Loss: 0.3169
Epoch 5/40, Loss: 0.2790
Epoch 6/40, Loss: 0.2445
Epoch 7/40, Loss: 0.1936
Epoch 8/40, Loss: 0.1663
Epoch 9/40, Loss: 0.1436
Epoch 10/40, Loss: 0.1130
Epoch 11/40, Loss: 0.0878
Epoch 12/40, Loss: 0.0629
Epoch 13/40, Loss: 0.0316
Epoch 14/40, Loss: 0.0239
Epoch 15/40, Loss: 0.0192
Epoch 16/40, Loss: 0.0147
Epoch 17/40, Loss: 0.0136
Epoch 18/40, Loss: 0.0122
Epoch 19/40, Loss: 0.0070
Epoch 20/40, Loss: 0.0059
Epoch 21/40, Loss: 0.0060
Epoch 22/40, Loss: 0.0062
Epoch 23/40, Loss: 0.0044
Epoch 24/40, Loss: 0.0042
Epoch 25/40, Loss: 0.0037
Epoch 26/40, Loss: 0.0032
Epoch 27/40, Loss: 0.0028
Epoch 28/40, Loss: 0.0030
Epoch 29/40, Loss: 0.0026
Epoch 30/40, Loss: 0.0025
Epoch 31/40, Loss: 0.0023
Epoch 32/40, Loss: 0.0025
Epoch 33/40, Loss: 0.0022
Epoch 34/40, Loss: 0.0021
Epoch 35/40, Loss: 0.0021
Epoch 36/40, Loss: 0.0028
Epoch 37/40, Loss: 0.0020
Epoch 38/40, Loss: 0.0021
Epoch 39/40, Loss: 0.

In [15]:
test_olives = load_dataset('gOLIVES/OLIVES_Dataset', 'biomarker_detection', split='test',cache_dir=scratch_path)
from sklearn.metrics import f1_score, roc_auc_score
import numpy as np

# NOTE(review): unlike the training split, the test split is not passed through
# `is_valid_sample`; NaN clinical values or labels would propagate into the
# metrics below.
test_olives = test_olives.with_format("torch")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MultiAttentionUNet().to(device)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one;
# weights_only restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load("trained_attention_unet.pth", map_location=device, weights_only=True)
)
model.eval()

test_dataset = OlivesDataset(test_olives, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

def _collect_labels_and_probabilities(model, data_loader):
    """Run ``model`` over ``data_loader`` without gradients.

    Returns:
        (labels, probabilities): numpy arrays stacked over all batches, one row
        per sample; probabilities are ``sigmoid`` of the model outputs.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    label_batches, probability_batches = [], []
    with torch.no_grad():
        for batch in data_loader:
            outputs = model(batch['Image'].to(device), batch['Clinical'].to(device))
            probability_batches.append(torch.sigmoid(outputs).cpu().numpy())
            # Labels are only used for scoring, so they are not moved to the model's device.
            label_batches.append(batch['Label'].cpu().numpy())

    if not label_batches:
        raise ValueError("data_loader yielded no batches; nothing to evaluate")
    return np.vstack(label_batches), np.vstack(probability_batches)


def evaluate_f1_and_roc_auc(model, data_loader, threshold=0.85):
    """Print and return F1 and ROC-AUC scores for multi-label biomarker predictions.

    Sigmoid probabilities above ``threshold`` count as positive predictions for
    F1. ROC-AUC uses the probabilities directly.

    NOTE(review): 0.85 is well above the usual 0.5 for BCE-trained outputs; if it
    was tuned on this same test split, the reported F1 is optimistic - confirm.

    Args:
        model: module called as ``model(images, clinical)``; switched to eval mode.
        data_loader: iterable of batch dicts with ``'Image'``, ``'Clinical'``, ``'Label'``.
        threshold: decision threshold applied to each sigmoid probability.

    Returns:
        tuple: (f1_micro, f1_macro, per_biomarker_f1, roc_auc_micro,
        roc_auc_macro, per_biomarker_roc_auc).

    Raises:
        ValueError: if ``data_loader`` is empty (sklearn may also raise, e.g.
            when a biomarker column contains a single class).
    """
    model.eval()

    all_labels, all_probabilities = _collect_labels_and_probabilities(model, data_loader)
    all_predictions = (all_probabilities > threshold).astype(int)

    f1_micro = f1_score(all_labels, all_predictions, average='micro')
    f1_macro = f1_score(all_labels, all_predictions, average='macro')
    per_biomarker_f1 = [
        f1_score(all_labels[:, i], all_predictions[:, i], average='binary')
        for i in range(all_labels.shape[1])
    ]

    roc_auc_micro = roc_auc_score(all_labels, all_probabilities, average='micro')
    roc_auc_macro = roc_auc_score(all_labels, all_probabilities, average='macro')
    per_biomarker_roc_auc = [
        roc_auc_score(all_labels[:, i], all_probabilities[:, i])
        for i in range(all_labels.shape[1])
    ]

    print(f"F1 Score (Micro): {f1_micro:.4f}")
    print(f"F1 Score (Macro): {f1_macro:.4f}")
    for i, f1 in enumerate(per_biomarker_f1):
        print(f"F1 Score (Biomarker {i + 1}): {f1:.4f}")

    print(f"ROC-AUC (Micro): {roc_auc_micro:.4f}")
    print(f"ROC-AUC (Macro): {roc_auc_macro:.4f}")
    for i, auc_score in enumerate(per_biomarker_roc_auc):
        print(f"ROC-AUC (Biomarker {i + 1}): {auc_score:.4f}")

    return f1_micro, f1_macro, per_biomarker_f1, roc_auc_micro, roc_auc_macro, per_biomarker_roc_auc

(
    f1_micro,
    f1_macro,
    per_biomarker_f1,
    roc_auc_micro,
    roc_auc_macro,
    per_biomarker_roc_auc,
) = evaluate_f1_and_roc_auc(model, data_loader=test_loader)

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]



F1 Score (Micro): 0.7143
F1 Score (Macro): 0.6514
F1 Score (Biomarker 1): 0.6579
F1 Score (Biomarker 2): 0.8105
F1 Score (Biomarker 3): 0.5657
F1 Score (Biomarker 4): 0.5714
ROC-AUC (Micro): 0.8905
ROC-AUC (Macro): 0.8572
ROC-AUC (Biomarker 1): 0.8312
ROC-AUC (Biomarker 2): 0.8088
ROC-AUC (Biomarker 3): 0.8436
ROC-AUC (Biomarker 4): 0.9455


In [16]:
def evaluate_accuracy(model, test_dataset, threshold=0.85, batch_size=32):
    """Print and return the per-biomarker accuracy of thresholded predictions.

    A biomarker is predicted positive when ``sigmoid(output) > threshold``.
    Accuracy is the fraction of samples where that prediction equals the label
    (labels cast to int). Samples are scored in mini-batches through a
    non-shuffling DataLoader instead of one forward pass per sample.

    Args:
        model: module called as ``model(images, clinical)``; switched to eval mode.
        test_dataset: map-style dataset of dicts with ``'Image'``, ``'Clinical'``
            and ``'Label'`` tensors.
        threshold: decision threshold on each sigmoid probability.
        batch_size: samples per forward pass.

    Returns:
        numpy.ndarray: accuracy per label column (length = label width).

    Raises:
        ValueError: if ``test_dataset`` is empty.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    correct_counts = None
    total = 0
    with torch.no_grad():
        for batch in loader:
            outputs = model(batch['Image'].to(device), batch['Clinical'].to(device))
            predictions = (torch.sigmoid(outputs).cpu().numpy() > threshold).astype(int)
            labels = batch['Label'].cpu().numpy().astype(int)
            hits = (predictions == labels).sum(axis=0)
            correct_counts = hits if correct_counts is None else correct_counts + hits
            total += labels.shape[0]

    if total == 0:
        raise ValueError("test_dataset is empty; accuracy is undefined")

    per_biomarker_accuracy = correct_counts / total

    for i, accuracy in enumerate(per_biomarker_accuracy):
        print(f"Biomarker {i + 1} Accuracy: {accuracy:.2%}")

    return per_biomarker_accuracy

per_biomarker_accuracy = evaluate_accuracy(
    model=model,
    test_dataset=test_dataset,
    threshold=0.85,
)




Biomarker 1 Accuracy: 76.49%
Biomarker 2 Accuracy: 78.38%
Biomarker 3 Accuracy: 82.51%
Biomarker 4 Accuracy: 96.67%
