In [None]:
# Mount Google Drive so every dataset / checkpoint path under /content/drive resolves.
# force_remount=True re-attaches the drive even if a previous session left it mounted.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT, force_remount=True)

In [None]:
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F


class FundusVesselDataset(Dataset):
    """Paired fundus images and vessel masks read from two directories.

    Images and labels are matched by identical filename. Each item is a pair of
    float32 tensors:

    * image: shape (3, H, W), z-score normalised per image over all pixels and
      channels (channel order as loaded by cv2.imread, i.e. BGR);
    * label: shape (1, H, W), values in {0.0, 1.0}.

    Args:
        image_dir: directory of input images; every regular file in it is a sample.
        label_dir: directory of masks using the same filenames as ``image_dir``.
        transform: stored for API compatibility; ``__getitem__`` does not apply it.
        label_threshold: masks are scaled to [0, 1] and binarised with ``> label_threshold``.
    """

    def __init__(self, image_dir, label_dir, transform=None, label_threshold=0.5):
        self.image_dir = image_dir
        self.label_dir = label_dir
        # Keep regular files only: Jupyter/Colab can create folders such as
        # `.ipynb_checkpoints` next to the data, and cv2 cannot read a directory.
        self.filenames = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )
        self.transform = transform
        self.label_threshold = label_threshold

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _read(path, *flags):
        """cv2.imread that raises FileNotFoundError instead of silently returning None."""
        data = cv2.imread(path, *flags)
        if data is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.filenames[idx])
        label_path = os.path.join(self.label_dir, self.filenames[idx])

        image = self._read(img_path).astype(np.float32)
        label = self._read(label_path, cv2.IMREAD_GRAYSCALE).astype(np.float32)

        # Z-score normalize (statistics over the whole image, all channels together)
        mean, std = image.mean(), image.std()
        image = (image - mean) / (std + 1e-8)

        image = np.transpose(image, (2, 0, 1))  # HWC → CHW

        # Scale to [0, 1] and then threshold, so resampling / compression artefacts in the
        # masks do not turn into soft targets; this matches the `> 0.5` binarisation used
        # when the test predictions are scored.
        label = (label / 255.0 > self.label_threshold).astype(np.float32)
        label = np.expand_dims(label, axis=0)

        return torch.tensor(image), torch.tensor(label)


In [None]:
from torch.utils.data import DataLoader

# === Paths ===
# Both training folders live side by side under the same segmentation root.
train_root = "/content/drive/MyDrive/EECE 490 Project/Segmentation Set"
train_image_dir = os.path.join(train_root, "Preprocessed_All_Train_Images")
train_label_dir = os.path.join(train_root, "Preprocessed_All_Train_Labels")

# === Dataset + DataLoader ===
train_dataset = FundusVesselDataset(train_image_dir, train_label_dir)
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,      # new sample order every epoch
    num_workers=2,
    pin_memory=True,
)


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F

# === DropBlock2D ===
class DropBlock2D(nn.Module):
    """DropBlock regularisation for 4-D feature maps (Ghiasi et al., 2018).

    In training mode, contiguous ``block_size x block_size`` squares are zeroed.
    One spatial mask is drawn per sample and broadcast over all channels. The
    surviving activations are then rescaled by ``total / kept``. In eval mode,
    or when ``drop_prob == 0``, the input is returned unchanged.

    Args:
        block_size: side length of each dropped square (odd or even).
        drop_prob: target fraction of activations to drop.
    """

    def __init__(self, block_size, drop_prob):
        super(DropBlock2D, self).__init__()
        self.block_size = block_size
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply DropBlock to ``x`` of shape (B, C, H, W); identity outside training."""
        if not self.training or self.drop_prob == 0.0:
            return x

        gamma = self._compute_gamma(x)
        B, C, H, W = x.shape
        # Bernoulli block centres, one channel-shared mask per sample.
        mask = (torch.rand(B, 1, H, W, device=x.device) < gamma).float()
        # Grow every centre into a block_size x block_size square.
        block_mask = F.max_pool2d(mask, self.block_size, stride=1, padding=self.block_size // 2)
        # With an even block_size the padded pooling returns (H+1, W+1); crop back to the
        # input size (a no-op for odd block sizes).
        block_mask = block_mask[:, :, :H, :W]
        block_mask = 1 - block_mask
        out = x * block_mask
        # Rescale so the expected activation magnitude is preserved; clamp avoids /0
        # when every position was dropped.
        out = out * (block_mask.numel() / block_mask.sum().clamp(min=1.0))
        return out

    def _compute_gamma(self, x):
        """Bernoulli rate for block centres that drops roughly ``drop_prob`` of the activations.

        Each axis's valid-centre count is clamped to at least 1. This stops a feature
        map smaller than ``block_size`` from dividing by zero or producing a
        sign-flipped ratio.
        """
        _, _, h, w = x.size()
        valid_h = max(h - self.block_size + 1, 1)
        valid_w = max(w - self.block_size + 1, 1)
        return self.drop_prob / (self.block_size ** 2) * (h * w) / (valid_h * valid_w)

# === Res2Block ===
class Res2Block(nn.Module):
    """Res2Net-style multi-scale block.

    The block runs a 1x1 conv (in→out channels), then hierarchical 3x3 convs over
    ``scale`` channel groups, then a 1x1 conv. Every conv is followed by BatchNorm
    and ReLU. Group 0 passes through untouched. Group i (i >= 1) is conv(group_i +
    processed group_{i-1}), so later groups see progressively larger receptive
    fields. No identity shortcut is added to the output.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output; must be divisible by ``scale``.
        scale: number of channel groups (>= 1).

    Raises:
        ValueError: if ``scale < 1`` or ``out_channels`` is not divisible by ``scale``.
            Either case would otherwise fail later in ``forward``, because
            ``torch.chunk`` would yield groups that do not match the per-group conv widths.
    """

    def __init__(self, in_channels, out_channels, scale=4):
        super(Res2Block, self).__init__()
        if scale < 1 or out_channels % scale != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be a positive multiple of scale ({scale})"
            )
        self.scale = scale
        width = out_channels // scale
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.convs = nn.ModuleList([
            nn.Conv2d(width, width, 3, padding=1, bias=False) for _ in range(scale - 1)
        ])
        self.bns = nn.ModuleList([
            nn.BatchNorm2d(width) for _ in range(scale - 1)
        ])

        self.conv3 = nn.Conv2d(out_channels, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        spx = list(torch.chunk(out, self.scale, 1))
        for i in range(1, self.scale):
            # spx[i - 1] is already the processed previous group (or the raw group 0).
            spx[i] = self.relu(self.bns[i - 1](self.convs[i - 1](spx[i] + spx[i - 1])))
        out = torch.cat(spx, 1)
        out = self.bn3(self.conv3(out))
        return self.relu(out)

# === Spatial Attention ===
class SpatialAttention(nn.Module):
    """Parameter-free spatial reweighting.

    The channels are summed into a single (B, 1, H, W) saliency map. A softmax is
    taken jointly over all H*W positions of each sample. The input is multiplied by
    the resulting weights, which sum to 1 per sample, so the output is scaled down
    by roughly 1 / (H*W).
    """

    def forward(self, x):
        batch = x.size(0)
        saliency = x.sum(dim=1, keepdim=True)  # (B, 1, H, W)
        weights = F.softmax(saliency.view(batch, -1), dim=-1).view_as(saliency)
        return x * weights

# === PAM ===
class PAM(nn.Module):
    """Position attention module: self-attention across all spatial positions.

    Query and key are projected to ``in_channels // 8`` channels and the value keeps
    ``in_channels``. For every position, the output aggregates the values of all
    positions, weighted by a softmax over query·key affinities. The result is mixed
    back in through a learnable ``gamma`` initialised to 0, so the module starts as
    the identity.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        # Creation order (query, key, value, gamma) is kept so state_dict keys and init RNG use match.
        self.query_conv = nn.Conv2d(in_channels, reduced, 1)
        self.key_conv = nn.Conv2d(in_channels, reduced, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.size()
        positions = h * w
        q = self.query_conv(x).view(n, -1, positions).transpose(1, 2)  # (B, HW, C')
        k = self.key_conv(x).view(n, -1, positions)                    # (B, C', HW)
        affinity = torch.softmax(torch.bmm(q, k), dim=-1)              # (B, HW, HW)
        v = self.value_conv(x).view(n, -1, positions)                  # (B, C, HW)
        context = torch.bmm(v, affinity.transpose(1, 2)).view(n, c, h, w)
        return self.gamma * context + x

# === CAM ===
class CAM(nn.Module):
    """Channel attention module: self-attention across channels.

    Each channel is flattened to an HW vector. A (C, C) affinity is the softmax of
    the channel Gram matrix, and every output channel is the affinity-weighted sum
    of all input channels. The result is mixed back in through a learnable
    ``gamma`` initialised to 0.

    Args:
        in_channels: accepted for signature symmetry with ``PAM``; not used.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.size()
        feats = x.view(n, c, -1)                                              # (B, C, HW)
        affinity = F.softmax(torch.bmm(feats, feats.transpose(1, 2)), dim=-1)  # (B, C, C)
        context = torch.bmm(affinity, feats).view(n, c, h, w)
        return self.gamma * context + x

# === DA-Res2UNet ===
class DARes2UNet(nn.Module):
    """Three-level U-Net with Res2 blocks and a dual-attention bottleneck.

    The encoder uses Res2Block + DropBlock at 16/32/64/128 channels. The bottleneck
    applies CAM → PAM → SpatialAttention in series. The decoder uses transposed-conv
    upsampling with skip concatenation.

    Input: (B, 3, H, W). Output: (B, 1, H, W) per-pixel probabilities. The sigmoid is
    already applied, so pair the output with ``nn.BCELoss``, not ``BCEWithLogitsLoss``.

    Any H, W is accepted. Three 2x max-pools require sides divisible by 8 for the skip
    concatenations, so other sizes are zero-padded on the bottom/right and the output
    is cropped back. Inputs that are already multiples of 8 are processed exactly as
    before.
    """

    # Total spatial reduction of the encoder (three MaxPool2d(2) stages).
    _DOWNSAMPLE = 8

    def __init__(self):
        super(DARes2UNet, self).__init__()
        self.enc1 = nn.Sequential(Res2Block(3, 16), DropBlock2D(7, 0.1))
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = nn.Sequential(Res2Block(16, 32), DropBlock2D(7, 0.1))
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = nn.Sequential(Res2Block(32, 64), DropBlock2D(7, 0.1))
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = nn.Sequential(Res2Block(64, 128), DropBlock2D(7, 0.1))

        self.pam = PAM(128)
        self.cam = CAM(128)
        self.sa = SpatialAttention()

        self.up3 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec3 = Res2Block(128, 64)

        self.up2 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec2 = Res2Block(64, 32)

        self.up1 = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.dec1 = Res2Block(32, 16)

        self.final = nn.Conv2d(16, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h, w = x.shape[-2:]
        pad_h = (-h) % self._DOWNSAMPLE
        pad_w = (-w) % self._DOWNSAMPLE
        if pad_h or pad_w:
            # Zero padding; with the per-image z-scored inputs used in this notebook,
            # 0 corresponds to the image mean.
            x = F.pad(x, (0, pad_w, 0, pad_h))

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        att = self.sa(self.pam(self.cam(e4)))

        d3 = self.up3(att)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))

        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        out = self.sigmoid(self.final(d1))
        # Drop any padding so the output matches the caller's H x W.
        return out[..., :h, :w]

In [None]:
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# === Initialize model, optimizer, scheduler ===
# Seed before building the model so weight init, DropBlock masks and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DARes2UNet().to(device)

criterion = nn.BCELoss()  # the model already applies a sigmoid, so plain BCE on probabilities
optimizer = optim.Adam(model.parameters(), lr=0.01)
# NOTE: T_max=40 matches a 40-epoch run. If this scheduler is stepped for more epochs,
# cosine annealing raises the LR again after epoch 40.
scheduler = CosineAnnealingLR(optimizer, T_max=40)


def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None, map_location=None):
    """Restore training state from either checkpoint format written by this notebook.

    Args:
        checkpoint_path: path to a ``.pth`` file.
        model: module whose weights are restored.
        optimizer / scheduler: restored too when given and the checkpoint contains them.
        map_location: forwarded to ``torch.load`` (e.g. ``device``) so a GPU-saved
            checkpoint can be loaded on a CPU-only runtime.

    Returns:
        The stored ``'epoch'`` for full checkpoints, otherwise 0.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:  # New format (full checkpoint)
        print("🔄 Loading model, optimizer, scheduler states...")
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return checkpoint.get('epoch', 0)
    # Old format (model weights only)
    print("🔄 Loading model weights only...")
    model.load_state_dict(checkpoint)
    return 0


# Set to a checkpoint path to resume training; None trains from scratch.
# e.g. "/content/drive/MyDrive/EECE 490 Project/SegTraining3/best_model.pth"
RESUME_CHECKPOINT = None
start_epoch = 0
if RESUME_CHECKPOINT is not None:
    start_epoch = load_checkpoint(RESUME_CHECKPOINT, model, optimizer, scheduler, map_location=device)


In [None]:
import os
import cv2
import time
import torch

# === Setup: run length, output locations ===
EPOCHS = 40
best_loss = float('inf')  # no checkpoint yet, so the first epoch always saves one

log_dir = "/content/drive/MyDrive/EECE 490 Project/SegTraining"
log_path, model_path = (
    os.path.join(log_dir, fname) for fname in ("training_log.txt", "best_model.pth")
)
os.makedirs(log_dir, exist_ok=True)  # make sure the Drive folder exists before anything is written

def dice_coef(preds, targets, threshold=0.5, eps=1e-6):
    """Smoothed Dice score between thresholded predictions and targets.

    ``preds`` is binarised with ``> threshold``; ``targets`` is used as-is. The sums
    run over every element, so a batch is scored as one pooled mask. ``eps`` is added
    to both numerator and denominator, so two empty masks score 1.

    Returns:
        A 0-d tensor.
    """
    hard = (preds > threshold).to(torch.float32)
    overlap = torch.sum(hard * targets)
    denominator = hard.sum() + targets.sum() + eps
    return (2. * overlap + eps) / denominator

# === Training Loop ===
for epoch in range(EPOCHS):
    start_time = time.time()
    model.train()
    epoch_loss = 0.0
    dice_total = 0.0

    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        dice_total += dice_coef(outputs, masks).item()

    scheduler.step()
    num_batches = len(train_loader)
    # Report per-batch means; a raw sum over batches would depend on the dataset size.
    avg_loss = epoch_loss / num_batches
    avg_dice = dice_total / num_batches
    epoch_time = time.time() - start_time

    # === Print ===
    print("="*60)
    print(f"🧠 EPOCH {epoch+1}/{EPOCHS}")
    print(f"📉 Loss: {avg_loss:.4f} | 🎯 Dice: {avg_dice:.4f} | ⏱ Time: {epoch_time:.2f}s")
    print("="*60)

    # === Save metrics to log file ===
    with open(log_path, "a") as f:
        f.write(f"Epoch {epoch+1}/{EPOCHS} - Loss: {avg_loss:.4f}, Dice: {avg_dice:.4f}\n")

    # === Save best model checkpoint (selected on training loss; there is no validation split) ===
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), model_path)
        print("✅ Best model saved to Google Drive!")

    # === Save snapshot every 5 epochs ===
    if (epoch + 1) % 5 == 0:
        model.eval()
        with torch.no_grad():
            img, _ = next(iter(train_loader))
            # Take the first sample's single channel. `.squeeze()` on a full batch would
            # leave a (B, H, W) array, which cv2.imwrite cannot encode as an image.
            pred = model(img.to(device)).cpu().numpy()[0, 0]

            pred_img = (pred * 255).clip(0, 255).astype(np.uint8)
            snapshot_path = os.path.join(log_dir, f"pred_epoch_{epoch+1}.png")
            cv2.imwrite(snapshot_path, pred_img)
            print(f"📸 Snapshot saved: pred_epoch_{epoch+1}.png")


In [None]:
import os
import cv2
import time
import torch

# === Ensure Save Directory Exists ===
log_dir = "/content/drive/MyDrive/EECE 490 Project/SegTraining2"
os.makedirs(log_dir, exist_ok=True)
log_path = os.path.join(log_dir, "training_log.txt")
model_path = os.path.join(log_dir, "best_model.pth")

# === Setup ===
EPOCHS = 100
# The previously hard-coded 28.4986 appears to be a best loss *summed over batches*
# (the first training cell logged sums). This cell compares the *per-batch mean* loss,
# so that value could never act as a real threshold. The output folder here is
# presumably fresh, so start from +inf: the first resumed epoch writes the initial
# checkpoint.
best_loss = float('inf')

def dice_coef(preds, targets, threshold=0.5, eps=1e-6):
    """Dice coefficient of binarised ``preds`` vs ``targets``, pooled over every element.

    Computes ``(2*|P∩T| + eps) / (|P| + |T| + eps)`` with ``P = preds > threshold``
    and returns it as a 0-d tensor.
    """
    binary_preds = torch.gt(preds, threshold).float()
    numerator = 2. * (binary_preds * targets).sum() + eps
    return numerator / (binary_preds.sum() + targets.sum() + eps)

# === Training Loop ===
# Resumes at `start_epoch` using the model / optimizer / scheduler currently in memory.
# After a kernel restart, restore them from a checkpoint before running this cell.
start_epoch = 37
# Fast-forward the LR schedule to `start_epoch`. Stepping only while it is behind makes
# this cell safe to re-run, and it does not over-advance a scheduler that earlier cells
# already stepped.
# NOTE(review): the scheduler was created with T_max=40. Over EPOCHS=100, cosine
# annealing drives the LR to ~0 at epoch 40 and back up to its initial value near
# epoch 80. Confirm this cyclic schedule is intended.
while scheduler.last_epoch < start_epoch:
    scheduler.step()
for epoch in range(start_epoch, EPOCHS):
    start_time = time.time()
    model.train()
    epoch_loss = 0.0
    dice_total = 0.0

    # === TRAINING LOOP ===
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        dice_total += dice_coef(outputs, masks).item()

    scheduler.step()
    avg_train_loss = epoch_loss / len(train_loader)
    avg_train_dice = dice_total / len(train_loader)
    epoch_time = time.time() - start_time

    # === PRINT METRICS ===
    print("="*60)
    print(f"🧠 EPOCH {epoch+1}/{EPOCHS}")
    print(f"📉 Train Loss: {avg_train_loss:.4f} | 🎯 Train Dice: {avg_train_dice:.4f}")
    print(f"⏱ Time: {epoch_time:.2f}s")
    print("="*60)

    # === SAVE LOG (one line per epoch) ===
    with open(log_path, "a") as f:
        f.write(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {avg_train_loss:.4f}, Train Dice: {avg_train_dice:.4f}\n")

    # === SAVE BEST MODEL (full checkpoint so training can be resumed) ===
    if avg_train_loss < best_loss:
        best_loss = avg_train_loss
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, model_path)
        print("✅ Best model + optimizer + scheduler saved to Google Drive!")

    # === SNAPSHOT EVERY 5 EPOCHS ===
    if (epoch + 1) % 5 == 0:
        model.eval()
        with torch.no_grad():
            img, _ = next(iter(train_loader))
            pred = model(img.to(device)).cpu().numpy()
            pred_img = pred[0, 0]  # First image, first channel (H x W)
            pred_img = (pred_img * 255).clip(0, 255).astype(np.uint8)

            snapshot_path = os.path.join(log_dir, f"pred_epoch_{epoch+1}.png")
            cv2.imwrite(snapshot_path, pred_img)
            print(f"📸 Snapshot saved: pred_epoch_{epoch+1}.png")


In [None]:
import os
import cv2
import time
import torch

# === Setup: run length, output locations ===
EPOCHS = 100
best_loss = float('inf')  # no checkpoint yet, so the first epoch always saves one

log_dir = "/content/drive/MyDrive/EECE 490 Project/SegTraining3"
log_path, model_path = (
    os.path.join(log_dir, fname) for fname in ("training_log.txt", "best_model.pth")
)
os.makedirs(log_dir, exist_ok=True)  # make sure the Drive folder exists before anything is written

def dice_coef(preds, targets, threshold=0.5, eps=1e-6):
    # Batch-pooled Dice: threshold the predictions, then compute
    # (2 * intersection + eps) / (|pred| + |target| + eps) as a 0-d tensor.
    # eps on both sides makes two empty masks score exactly 1.
    mask = preds.gt(threshold).float()
    inter = mask.mul(targets).sum()
    return (2. * inter + eps) / (mask.sum() + targets.sum() + eps)

# === Training Loop ===
# NOTE: model weights and Adam moments carry over from whatever already ran in this kernel.
# The scheduler built at initialisation uses T_max=40. Across EPOCHS=100, cosine annealing
# would bring the LR back up to its initial value around epoch 80. Restart the schedule
# from the initial LR so that a single cosine decay spans exactly this run.
for param_group in optimizer.param_groups:
    param_group['lr'] = param_group.get('initial_lr', param_group['lr'])
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

for epoch in range(EPOCHS):
    start_time = time.time()
    model.train()
    epoch_loss = 0.0
    dice_total = 0.0

    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        dice_total += dice_coef(outputs, masks).item()

    scheduler.step()
    num_batches = len(train_loader)
    # Per-batch means, so the numbers do not depend on how many batches an epoch has.
    avg_loss = epoch_loss / num_batches
    avg_dice = dice_total / num_batches
    epoch_time = time.time() - start_time

    # === Print ===
    print("="*60)
    print(f"🧠 EPOCH {epoch+1}/{EPOCHS}")
    print(f"📉 Loss: {avg_loss:.4f} | 🎯 Dice: {avg_dice:.4f} | ⏱ Time: {epoch_time:.2f}s")
    print("="*60)

    # === Save metrics to log file ===
    with open(log_path, "a") as f:
        f.write(f"Epoch {epoch+1}/{EPOCHS} - Loss: {avg_loss:.4f}, Dice: {avg_dice:.4f}\n")

    # === Save best model checkpoint (weights only; selected on training loss) ===
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), model_path)
        print("✅ Best model saved to Google Drive!")

    # === Save snapshot every 5 epochs ===
    if (epoch + 1) % 5 == 0:
        model.eval()
        with torch.no_grad():
            img, _ = next(iter(train_loader))
            pred = model(img.to(device)).cpu().numpy()

            # Save only the first image in the batch
            pred_img = pred[0, 0]  # (H, W)
            pred_img = (pred_img * 255).clip(0, 255).astype(np.uint8)

            snapshot_path = os.path.join(log_dir, f"pred_epoch_{epoch+1}.png")
            cv2.imwrite(snapshot_path, pred_img)
            print(f"📸 Snapshot saved: pred_epoch_{epoch+1}.png")


In [None]:
# === Reload model for inference ===
checkpoint_path = "/content/drive/MyDrive/EECE 490 Project/SegTraining3/best_model.pth"
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime (and vice versa).
checkpoint = torch.load(checkpoint_path, map_location=device)
# Accept both formats written by this notebook: a bare state_dict, or a full checkpoint
# dict that nests the weights under 'model_state_dict'.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    checkpoint = checkpoint['model_state_dict']

model = DARes2UNet()
model.load_state_dict(checkpoint)
model = model.to(device)
model.eval()  # Set to evaluation mode (disables DropBlock, uses BatchNorm running stats)


In [None]:
# === Paths ===
test_root = "/content/drive/MyDrive/EECE 490 Project/Segmentation Set"
test_image_dir = os.path.join(test_root, "Preprocessed_All_Test_Images")
test_label_dir = os.path.join(test_root, "Preprocessed_All_Test_Labels")

# === Dataset + DataLoader ===
# One image per batch, in the dataset's sorted filename order, so a loop index maps
# directly onto test_dataset.filenames.
test_dataset = FundusVesselDataset(test_image_dir, test_label_dir)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)


In [None]:
import matplotlib.pyplot as plt
import os
import cv2
import torch

# === Inference Save Directory ===
save_dir = "/content/drive/MyDrive/EECE 490 Project/Segmentation Set/Test_Predictions"
os.makedirs(save_dir, exist_ok=True)


def _to_displayable(chw_image):
    """Turn a z-scored CHW BGR array into an HWC RGB array in [0, 1] for plt.imshow."""
    hwc = chw_image.transpose(1, 2, 0)
    lo, hi = hwc.min(), hwc.max()
    hwc = (hwc - lo) / (hi - lo + 1e-8)
    return hwc[..., ::-1]  # cv2 loads BGR; matplotlib expects RGB


# === Inference Loop ===
model.eval()  # DropBlock off, BatchNorm uses running statistics
with torch.no_grad():
    for idx, (img, mask) in enumerate(test_loader):
        img = img.to(device)

        # batch_size=1 and shuffle=False, so idx indexes the dataset's file list directly.
        filename = test_loader.dataset.filenames[idx]

        pred = model(img).cpu().numpy()[0, 0]  # [batch, channel, H, W] → [H, W]

        # === Save prediction under the input's filename so the metrics cell can pair it with its label ===
        # NOTE(review): if test files are .jpg, the probability map is JPEG-compressed
        # before it is re-read and thresholded; a lossless format would be safer.
        pred_img = (pred * 255).clip(0, 255).astype(np.uint8)
        save_path = os.path.join(save_dir, filename)
        if not cv2.imwrite(save_path, pred_img):
            print(f"🚨 Failed to write prediction: {save_path}")

        # Optional: show the first image
        if idx == 0:
            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            axes[0].imshow(_to_displayable(img.cpu().numpy()[0]))
            axes[0].set_title("Input Image")
            axes[1].imshow(mask.cpu().numpy()[0, 0], cmap='gray')
            axes[1].set_title("Ground Truth")
            axes[2].imshow(pred, cmap='gray')
            axes[2].set_title("Prediction")
            plt.show()


In [None]:
import os
import cv2
import numpy as np

# === Directories ===
seg_root = "/content/drive/MyDrive/EECE 490 Project/Segmentation Set"
pred_dir = os.path.join(seg_root, "Test_Predictions")
gt_dir = os.path.join(seg_root, "Preprocessed_All_Test_Labels")

# === Per-image metric accumulators (one entry per evaluated image) ===
dice_scores, accuracies, sensitivities, specificities = [], [], [], []

# === Metric Calculation Function ===
def compute_metrics(pred, gt):
    """Pixel-wise confusion-matrix metrics for a single image.

    Both arrays are binarised with ``> 0.5``.

    Returns:
        ``(dice, accuracy, sensitivity, specificity)``. Each denominator carries a
        ``1e-8`` term, so an image with no vessel pixels gets dice and sensitivity 0
        instead of raising.
    """
    p = pred > 0.5
    g = gt > 0.5
    tp = np.sum(p & g)
    tn = np.sum(~p & ~g)
    fp = np.sum(p & ~g)
    fn = np.sum(~p & g)

    dice = (2 * tp) / (2 * tp + fp + fn + 1e-8)
    acc = (tp + tn) / (tp + tn + fp + fn + 1e-8)
    sen = tp / (tp + fn + 1e-8)
    spe = tn / (tn + fp + 1e-8)
    return dice, acc, sen, spe

# === Loop over predictions ===
# Sorted so images are processed in a stable order across runs and filesystems.
for filename in sorted(os.listdir(pred_dir)):
    pred_path = os.path.join(pred_dir, filename)
    gt_path   = os.path.join(gt_dir, filename)

    # === Ensure both files exist ===
    if not os.path.exists(gt_path):
        print(f"🚨 Ground truth missing for: {filename}")
        continue

    # === Load images ===
    pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    gt   = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

    # === Skip unreadable images ===
    if pred is None or gt is None:
        print(f"🚨 Problem loading {filename}")
        continue

    # === Skip size mismatches (the element-wise comparisons would otherwise broadcast or raise) ===
    if pred.shape != gt.shape:
        print(f"🚨 Shape mismatch for {filename}: prediction {pred.shape} vs ground truth {gt.shape}")
        continue

    # === Normalize to [0, 1] ===
    pred = pred / 255.0
    gt   = gt / 255.0

    # === Compute metrics ===
    dice, acc, sen, spe = compute_metrics(pred, gt)
    dice_scores.append(dice)
    accuracies.append(acc)
    sensitivities.append(sen)
    specificities.append(spe)

# === Report average metrics (mean of per-image scores) ===
if not dice_scores:
    # np.mean([]) would print nan plus a RuntimeWarning; say what went wrong instead.
    print("🚨 No prediction/label pairs were evaluated; check pred_dir and gt_dir.")
else:
    print(f"🖼️ Images evaluated: {len(dice_scores)}")
    print(f"🎯 Average Dice Coefficient: {np.mean(dice_scores):.4f}")
    print(f"✅ Average Accuracy: {np.mean(accuracies):.4f}")
    print(f"🔥 Average Sensitivity (Recall): {np.mean(sensitivities):.4f}")
    print(f"🛡️ Average Specificity: {np.mean(specificities):.4f}")


In [None]:
import os

# === Directories ===
test_image_dir = "/content/drive/MyDrive/EECE 490 Project/Segmentation Set/Preprocessed_All_Test_Images"
test_label_dir = "/content/drive/MyDrive/EECE 490 Project/Segmentation Set/Preprocessed_All_Test_Labels"

# === Filter valid image files ===
image_extensions = {".png", ".jpg", ".jpeg", ".tif"}


def _image_names(directory):
    """Sorted names in `directory` whose lower-cased extension is in `image_extensions`."""
    return sorted(
        name for name in os.listdir(directory)
        if os.path.splitext(name)[1].lower() in image_extensions
    )


image_files = _image_names(test_image_dir)
label_files = _image_names(test_label_dir)

# === Counts ===
print(f"🖼️ Total Test Images: {len(image_files)}")
print(f"🎯 Total Test Labels: {len(label_files)}")

# === Check for mismatches (images and labels are paired by exact filename) ===
image_set, label_set = set(image_files), set(label_files)
missing_labels = image_set - label_set
missing_images = label_set - image_set

if missing_labels or missing_images:
    if missing_labels:
        print(f"🚨 Images without labels: {missing_labels}")
    if missing_images:
        print(f"🚨 Labels without images: {missing_images}")
else:
    print("✅ All image and label filenames MATCH perfectly!")


In [None]:
# === Load your custom image ===
img_path = "/content/drive/MyDrive/EECE 490 Project/Generated_Images_DR/seed0002.png"
raw_image = cv2.imread(img_path)
if raw_image is None:
    # cv2.imread returns None instead of raising on a missing or unreadable file.
    raise FileNotFoundError(f"Could not read image: {img_path}")
image = raw_image.astype(np.float32)

# === Preprocess (match training: per-image z-score, HWC → CHW) ===
mean, std = image.mean(), image.std()
image = (image - mean) / (std + 1e-8)
image = np.transpose(image, (2, 0, 1))  # HWC → CHW
image = torch.tensor(image).unsqueeze(0).to(device)  # Add batch dim

# === Inference ===
model.eval()  # make sure DropBlock is disabled even if the last cell run was a training cell
with torch.no_grad():
    pred = model(image)
    pred = pred.cpu().numpy()[0, 0]  # Remove batch & channel dims

# === Save or display prediction ===
pred_img = (pred * 255).clip(0, 255).astype(np.uint8)
cv2.imwrite("/content/drive/MyDrive/pred_custom.png", pred_img)
