In [22]:
import json
import os

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from dataset import RetinaDataset
from model import UNet
from utils import load_data

In [10]:
# Root of the retina dataset on the scratch volume.
DATASET_DIR = "/scratch/y.aboelwafa/Retina_Blood_Vessel_Segmentation/dataset"
train_images, train_masks, test_images, test_masks = load_data(DATASET_DIR)

# Only the training split is augmented; the test split is served as-is.
train_dataset = RetinaDataset(train_images, train_masks, augment=True)
test_dataset = RetinaDataset(test_images, test_masks)

# Pull one (image, mask) pair to confirm the per-sample tensor shapes.
X, y = train_dataset[0]
print(X.shape, y.shape)

torch.Size([3, 512, 512]) torch.Size([1, 512, 512])


In [15]:
# The batch size is a literal here; keep it in sync with BATCH_SIZE in the training cell.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Draw a single collated batch to check the shapes the model will receive.
first_batch_images, first_batch_labels = next(iter(train_dataloader))
for batch_part, batch_tensor in (("images", first_batch_images), ("labels", first_batch_labels)):
    print(f"Shape of {batch_part} in the first batch: {batch_tensor.shape}")

Shape of images in the first batch: torch.Size([4, 3, 512, 512])
Shape of labels in the first batch: torch.Size([4, 1, 512, 512])


In [None]:
BATCH_SIZE = 4
EPOCHS = 5
LR = 1e-4
IN_CHANNELS = 3
OUT_CHANNELS = 1
SEED = 42
THRESHOLD = 0.5  # sigmoid-probability cut-off used to binarise predictions for IoU/Dice
LOG_EVERY = 5  # print the running training loss every LOG_EVERY batches
CHECKPOINT_PATH = "/scratch/y.aboelwafa/Retina_Blood_Vessel_Segmentation/checkpoints/checkpoint.pth"
METRICS_PATH = "/scratch/y.aboelwafa/Retina_Blood_Vessel_Segmentation/metrics/metrics.json"

# The DataLoaders are built in an earlier cell with a literal batch size; fail
# loudly if they drift from the value recorded here.
assert train_dataloader.batch_size == BATCH_SIZE, "train_dataloader batch size != BATCH_SIZE"
assert test_dataloader.batch_size == BATCH_SIZE, "test_dataloader batch size != BATCH_SIZE"

# Output directories may not exist yet; open()/torch.save() would fail without them.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
os.makedirs(os.path.dirname(METRICS_PATH), exist_ok=True)

torch.manual_seed(SEED)  # reproducible weight initialisation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# smp's DiceLoss defaults to from_logits=True (it applies a sigmoid itself), so the
# model output is treated as raw logits. TODO confirm UNet has no final sigmoid.
criterion = smp.losses.DiceLoss(mode='binary')


def confusion_counts(logits, target, threshold=THRESHOLD):
    """Count (tp, fp, fn) foreground pixels for a batch of binary predictions.

    Args:
        logits: raw model output; a sigmoid is applied before thresholding.
        target: ground-truth mask with the same shape as ``logits``; pixels
            above 0.5 are treated as foreground.
        threshold: probability above which a pixel is predicted foreground.

    Returns:
        Tuple of Python ints ``(tp, fp, fn)``.
    """
    predicted = torch.sigmoid(logits) > threshold
    actual = target > 0.5
    tp = (predicted & actual).sum().item()
    fp = (predicted & ~actual).sum().item()
    fn = (~predicted & actual).sum().item()
    return tp, fp, fn


def iou_and_dice(tp, fp, fn, eps=1e-7):
    """Return micro-averaged ``(iou, dice)`` from accumulated pixel counts.

    ``eps`` keeps both scores finite (0.0) when neither prediction nor target
    contains any foreground pixel.
    """
    iou = tp / (tp + fp + fn + eps)
    dice = 2 * tp / (2 * tp + fp + fn + eps)
    return iou, dice


def run_epoch(dataloader, train, epoch):
    """Run one pass over ``dataloader`` and return ``(mean_loss, iou, dice)``.

    With ``train=True`` the model is in training mode and the optimizer steps
    after every batch; otherwise the model is in eval mode with gradients off.
    The loss is the mean of per-batch losses; IoU/Dice are computed from pixel
    counts accumulated over the whole pass.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train(train)
    total_loss, n_batches = 0.0, 0
    tp_total = fp_total = fn_total = 0
    with torch.set_grad_enabled(train):
        for batch_idx, (image, mask) in enumerate(dataloader):
            image = image.to(device)
            # Cast in both phases so train and eval feed the loss the same dtype.
            mask = mask.float().to(device)
            pred = model(image)
            loss = criterion(pred, mask)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if batch_idx % LOG_EVERY == 0:
                    print(f"Epoch {epoch}, batch {batch_idx}, loss: {loss.item()}")
            total_loss += loss.item()
            n_batches += 1
            # detach(): metric bookkeeping must not extend the autograd graph.
            tp, fp, fn = confusion_counts(pred.detach(), mask)
            tp_total += tp
            fp_total += fp
            fn_total += fn
    if n_batches == 0:
        raise ValueError("dataloader yielded no batches")
    iou, dice = iou_and_dice(tp_total, fp_total, fn_total)
    return total_loss / n_batches, iou, dice


metrics = {
    'train_loss': [], 'val_loss': [],
    'train_iou_score': [], 'val_iou_score': [],
    'train_dice_score': [], 'val_dice_score': []
}

for epoch in range(EPOCHS):
    train_loss, train_iou, train_dice = run_epoch(train_dataloader, train=True, epoch=epoch)
    # NOTE: "validation" is the held-out test split; there is no separate val split.
    val_loss, val_iou, val_dice = run_epoch(test_dataloader, train=False, epoch=epoch)

    epoch_results = {
        'train_loss': train_loss, 'val_loss': val_loss,
        'train_iou_score': train_iou, 'val_iou_score': val_iou,
        'train_dice_score': train_dice, 'val_dice_score': val_dice,
    }
    for key, value in epoch_results.items():
        metrics[key].append(value)
    print(
        f"Epoch {epoch}: train loss {train_loss:.4f} | "
        f"val loss {val_loss:.4f}, IoU {val_iou:.4f}, Dice {val_dice:.4f}"
    )

    # Save the latest weights every epoch so an interrupted run keeps its progress.
    # Deliberately not "best on val": val is the test split, and selecting on it
    # would leak test data into model selection.
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        },
        CHECKPOINT_PATH,
    )
    with open(METRICS_PATH, 'w') as f:
        json.dump(metrics, f)

In [None]:
# NOTE(review): this whole cell is commented out. When enabled it plots the first
# `num_images` samples of `train_dataset`, each image next to its mask, as a visual
# sanity check. Because train_dataset was built with augment=True, these may be
# augmented views rather than raw images. Either re-enable it as a deliberate check
# or delete the cell before sharing the notebook.
# num_images = 5
# fig, axs = plt.subplots(num_images, 2, figsize=(10, 5*num_images))

# for i in range(num_images):
#     image_tensor, mask_tensor = train_dataset[i]
    
#     # Convert the image tensor from CxHxW to HxWxC for plotting
#     image = image_tensor.numpy().transpose(1, 2, 0)
#     mask = mask_tensor.numpy().squeeze()  # Remove the channel dimension from the mask
    
#     axs[i, 0].imshow(image)
#     axs[i, 0].axis('off')
#     axs[i, 0].set_title(f'Image {i+1}')
    
#     axs[i, 1].imshow(mask, cmap='gray')
#     axs[i, 1].axis('off')
#     axs[i, 1].set_title(f'Mask {i+1}')

# plt.tight_layout()
# plt.show()