In [17]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [18]:
import os
import torch
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import torchvision.utils as utils
from torch import nn
from tqdm import tqdm
from timeit import default_timer as timer
import albumentations as A
from albumentations import ToTensorV2

In [19]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)
random.seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [20]:
class DRIVE_dataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = [img for img in os.listdir(img_dir) if img.endswith('.tif')]
        self.valid_pairs = []
        for img in self.images:
            if 'training' in img or 'test' in img:
                base_name = img.split('_')[0]
                mask_name = f"{base_name}_manual1.gif"
                mask_path = os.path.join(self.mask_dir, mask_name)
                if os.path.exists(mask_path):
                    self.valid_pairs.append((img, mask_name))
        if not self.valid_pairs:
            raise ValueError("No valid image-mask pairs")
        print(f"Found {len(self.valid_pairs)} valid pairs in {img_dir}")

    def __len__(self):
        return len(self.valid_pairs)

    def __getitem__(self, index):
        image_name, mask_name = self.valid_pairs[index]
        image_path = os.path.join(self.img_dir, image_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        image = np.array(Image.open(image_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"))

        # Binarize mask: make sure it's 0 and 1
        mask = (mask > 0).astype(np.uint8)

        # If needed, expand mask dims to HWC
        if mask.ndim == 2:
            mask = np.expand_dims(mask, axis=-1)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
            mask = mask.float()
            mask = mask.permute(2,0,1)

        return image, mask

In [21]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

class Unet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(Unet, self).__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.bottle_neck = DoubleConv(512, 1024)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(128, 64)
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(self.pool(x1))
        x3 = self.enc3(self.pool(x2))
        x4 = self.enc4(self.pool(x3))
        x = self.bottle_neck(self.pool(x4))
        x = self.up1(x)
        x = torch.cat([x, x4], dim=1)
        x = self.dec1(x)
        x = self.up2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.dec2(x)
        x = self.up3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.dec3(x)
        x = self.up4(x)
        x = torch.cat([x, x1], dim=1)
        x = self.dec4(x)
        return self.out_conv(x)

In [22]:
def iou_score(pred, target, smooth=1e-6):
    pred = torch.sigmoid(pred)
    pred = (pred > 0.5).float()
    target = target.float()

    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection

    if union == 0:
        return 0.0
    return (intersection + smooth) / (union + smooth)

def dice_score(pred, target, smooth=1e-6):
    pred = torch.sigmoid(pred)
    pred = (pred > 0.5).float()
    target = target.float()
    pred_sum = pred.sum().item()
    target_sum = target.sum().item()
    #print(f"Pred sum: {pred_sum}, Target sum: {target_sum}")
    intersection = (pred * target).sum()
    denominator = pred.sum() + target.sum() + smooth
    if denominator == smooth:
        return 0.0
    return (2. * intersection + smooth) / denominator

def dice_loss(pred, target, smooth=1e-6):
    pred = torch.sigmoid(pred)
    pred = (pred > 0.5).float()
    target = target.float()
    intersection = (pred * target).sum()
    denominator = pred.sum() + target.sum() + smooth
    if denominator == smooth:
        return 1.0
    return 1 - ((2. * intersection + smooth) / denominator)

def combined_loss(y_pred, y_true):
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20.0]).to(device))(y_pred, y_true)
    dice = dice_loss(y_pred, y_true)
    return 0.4 * bce + 0.6 * dice

In [23]:

def train_step(model, dataloader, criterion, optimizer):
    model.train()
    train_loss, train_dice,train_iou = 0, 0,0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        loss = criterion(y_pred, y)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_dice += dice_score(y_pred, y)
        train_iou += iou_score(y_pred,y)
    train_loss /= len(dataloader)
    train_dice /= len(dataloader)
    train_iou /= len(dataloader)
    return train_loss, train_dice,train_iou

In [24]:

def test_step(model, dataloader, criterion, epoch):
    model.eval()
    test_loss, test_dice,test_iou = 0, 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            loss = criterion(y_pred, y)
            test_loss += loss.item()
            test_dice += dice_score(y_pred, y)
            test_iou += iou_score(y_pred,y)
            #if batch == 8 :
                #utils.save_image(pred, f"pred_epoch_{epoch+1}.png")
                #utils.save_image(y, f"target_epoch_{epoch+1}.png")
            if batch == 0:
                pred = torch.sigmoid(y_pred)
                pred = (pred > 0.5).float()

    test_loss /= len(dataloader)
    test_dice /= len(dataloader)
    test_iou /= len(dataloader)
    return test_loss, test_dice,test_iou


In [25]:
def train(model, train_dataloader, test_dataloader, optimizer, criterion, epochs=100):

    results = {"train_loss": [], "train_dice": [],"train_iou": [], "test_loss": [], "test_dice": [],"test_iou":[]}
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    for epoch in tqdm(range(epochs)):
        train_loss, train_dice, train_iou = train_step(model, train_dataloader, criterion, optimizer)
        test_loss, test_dice,test_iou = test_step(model, test_dataloader, criterion, epoch)
        print(f"Epoch: {epoch+1} | train_loss: {train_loss:.4f} | train_dice: {train_dice:.4f} | train_iou : {train_iou:.4f} |"
              f"test_loss: {test_loss:.4f} | test_dice: {test_dice:.4f} | test_iou : {test_iou:.4f}")
        #scheduler.step(test_loss)
        results["train_loss"].append(train_loss)
        results["train_dice"].append(train_dice)
        results["train_iou"].append(train_iou)
        results["test_loss"].append(test_loss)
        results["test_dice"].append(test_dice)
        results["test_iou"].append(test_iou)
    return results


In [26]:
if __name__ == "__main__":
    transform = A.Compose([
    A.Resize(512, 512),
    #A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    #A.RandomBrightnessContrast(p=0.3),
    #A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.05, rotate_limit=10, p=0.5),
    A.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0),
    ToTensorV2(),
    ])
    image_transform = T.Compose([
        transform,
        #T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_data = DRIVE_dataset(
        img_dir="/content/drive/MyDrive/retina-vessel-segmentation/data/DRIVE/training/images",
        mask_dir="/content/drive/MyDrive/retina-vessel-segmentation/data/DRIVE/training/1st_manual",
        transform=transform
    )
    test_data = DRIVE_dataset(
        img_dir="/content/drive/MyDrive/retina-vessel-segmentation/data/DRIVE/test/images",
        mask_dir="/content/drive/MyDrive/retina-vessel-segmentation/data/DRIVE/test/1st_manual",
        transform=transform
    )
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, num_workers=0)
    test_dataloader = DataLoader(test_data, batch_size=2, shuffle=False, num_workers=0)
    model = Unet(in_channels=3, out_channels=1).to(device)
    criterion = combined_loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

Found 20 valid pairs in /content/drive/MyDrive/retina-vessel-segmentation/data/DRIVE/training/images
Found 20 valid pairs in /content/drive/MyDrive/retina-vessel-segmentation/data/DRIVE/test/images


In [27]:
start_time = timer()
results = train(model, train_dataloader, test_dataloader, optimizer, criterion, epochs=50)
end_time = timer()
print(f"Total training time: {end_time - start_time:.3f} seconds")

  2%|▏         | 1/50 [00:06<05:36,  6.87s/it]

Epoch: 1 | train_loss: 1.0258 | train_dice: 0.2541 | train_iou : 0.1460 |test_loss: 1.3538 | test_dice: 0.0000 | test_iou : 0.0000


  4%|▍         | 2/50 [00:13<05:33,  6.95s/it]

Epoch: 2 | train_loss: 0.8229 | train_dice: 0.3053 | train_iou : 0.1808 |test_loss: 1.3418 | test_dice: 0.0002 | test_iou : 0.0001


  6%|▌         | 3/50 [00:20<05:25,  6.92s/it]

Epoch: 3 | train_loss: 0.7289 | train_dice: 0.3673 | train_iou : 0.2261 |test_loss: 1.2784 | test_dice: 0.1244 | test_iou : 0.0669


  8%|▊         | 4/50 [00:27<05:19,  6.95s/it]

Epoch: 4 | train_loss: 0.6600 | train_dice: 0.4291 | train_iou : 0.2739 |test_loss: 0.7630 | test_dice: 0.6054 | test_iou : 0.4357


 10%|█         | 5/50 [00:34<05:12,  6.95s/it]

Epoch: 5 | train_loss: 0.6226 | train_dice: 0.4575 | train_iou : 0.2982 |test_loss: 0.6010 | test_dice: 0.6095 | test_iou : 0.4391


 12%|█▏        | 6/50 [00:41<05:10,  7.06s/it]

Epoch: 6 | train_loss: 0.5936 | train_dice: 0.4940 | train_iou : 0.3302 |test_loss: 0.5463 | test_dice: 0.6003 | test_iou : 0.4315


 14%|█▍        | 7/50 [00:49<05:05,  7.10s/it]

Epoch: 7 | train_loss: 0.5856 | train_dice: 0.4877 | train_iou : 0.3238 |test_loss: 0.5991 | test_dice: 0.4920 | test_iou : 0.3281


 16%|█▌        | 8/50 [00:56<04:58,  7.11s/it]

Epoch: 8 | train_loss: 0.5717 | train_dice: 0.5025 | train_iou : 0.3366 |test_loss: 0.5385 | test_dice: 0.5815 | test_iou : 0.4137


 18%|█▊        | 9/50 [01:03<04:57,  7.26s/it]

Epoch: 9 | train_loss: 0.5537 | train_dice: 0.5257 | train_iou : 0.3571 |test_loss: 0.5493 | test_dice: 0.5345 | test_iou : 0.3655


 20%|██        | 10/50 [01:11<04:49,  7.24s/it]

Epoch: 10 | train_loss: 0.5470 | train_dice: 0.5229 | train_iou : 0.3555 |test_loss: 0.5009 | test_dice: 0.6129 | test_iou : 0.4427


 22%|██▏       | 11/50 [01:18<04:43,  7.26s/it]

Epoch: 11 | train_loss: 0.5469 | train_dice: 0.5194 | train_iou : 0.3518 |test_loss: 0.5176 | test_dice: 0.5699 | test_iou : 0.3994


 24%|██▍       | 12/50 [01:25<04:35,  7.26s/it]

Epoch: 12 | train_loss: 0.5213 | train_dice: 0.5483 | train_iou : 0.3793 |test_loss: 0.5183 | test_dice: 0.5746 | test_iou : 0.4053


 26%|██▌       | 13/50 [01:32<04:29,  7.27s/it]

Epoch: 13 | train_loss: 0.5084 | train_dice: 0.5610 | train_iou : 0.3912 |test_loss: 0.4922 | test_dice: 0.5943 | test_iou : 0.4233


 28%|██▊       | 14/50 [01:40<04:19,  7.22s/it]

Epoch: 14 | train_loss: 0.5110 | train_dice: 0.5517 | train_iou : 0.3831 |test_loss: 0.4837 | test_dice: 0.6103 | test_iou : 0.4397


 30%|███       | 15/50 [01:47<04:12,  7.20s/it]

Epoch: 15 | train_loss: 0.5076 | train_dice: 0.5528 | train_iou : 0.3832 |test_loss: 0.4967 | test_dice: 0.5852 | test_iou : 0.4147


 32%|███▏      | 16/50 [01:54<04:04,  7.18s/it]

Epoch: 16 | train_loss: 0.4910 | train_dice: 0.5780 | train_iou : 0.4076 |test_loss: 0.4721 | test_dice: 0.6187 | test_iou : 0.4486


 34%|███▍      | 17/50 [02:01<03:56,  7.18s/it]

Epoch: 17 | train_loss: 0.4925 | train_dice: 0.5668 | train_iou : 0.3970 |test_loss: 0.4702 | test_dice: 0.6164 | test_iou : 0.4463


 36%|███▌      | 18/50 [02:08<03:48,  7.15s/it]

Epoch: 18 | train_loss: 0.4907 | train_dice: 0.5640 | train_iou : 0.3948 |test_loss: 0.4736 | test_dice: 0.6077 | test_iou : 0.4368


 38%|███▊      | 19/50 [02:15<03:41,  7.15s/it]

Epoch: 19 | train_loss: 0.4696 | train_dice: 0.5938 | train_iou : 0.4234 |test_loss: 0.4701 | test_dice: 0.6026 | test_iou : 0.4319


 40%|████      | 20/50 [02:22<03:35,  7.17s/it]

Epoch: 20 | train_loss: 0.4791 | train_dice: 0.5719 | train_iou : 0.4014 |test_loss: 0.4633 | test_dice: 0.6120 | test_iou : 0.4414


 42%|████▏     | 21/50 [02:30<03:27,  7.17s/it]

Epoch: 21 | train_loss: 0.4628 | train_dice: 0.5931 | train_iou : 0.4224 |test_loss: 0.4892 | test_dice: 0.5842 | test_iou : 0.4134


 44%|████▍     | 22/50 [02:37<03:21,  7.19s/it]

Epoch: 22 | train_loss: 0.4509 | train_dice: 0.6101 | train_iou : 0.4403 |test_loss: 0.4689 | test_dice: 0.5950 | test_iou : 0.4242


 46%|████▌     | 23/50 [02:44<03:13,  7.16s/it]

Epoch: 23 | train_loss: 0.4640 | train_dice: 0.5799 | train_iou : 0.4096 |test_loss: 0.4559 | test_dice: 0.6155 | test_iou : 0.4452


 48%|████▊     | 24/50 [02:51<03:06,  7.19s/it]

Epoch: 24 | train_loss: 0.4378 | train_dice: 0.6180 | train_iou : 0.4484 |test_loss: 0.4459 | test_dice: 0.6263 | test_iou : 0.4564


 50%|█████     | 25/50 [02:58<02:59,  7.17s/it]

Epoch: 25 | train_loss: 0.4442 | train_dice: 0.6008 | train_iou : 0.4304 |test_loss: 0.4515 | test_dice: 0.6231 | test_iou : 0.4531


 52%|█████▏    | 26/50 [03:06<02:52,  7.20s/it]

Epoch: 26 | train_loss: 0.4308 | train_dice: 0.6146 | train_iou : 0.4443 |test_loss: 0.4442 | test_dice: 0.6350 | test_iou : 0.4657


 54%|█████▍    | 27/50 [03:13<02:45,  7.18s/it]

Epoch: 27 | train_loss: 0.4203 | train_dice: 0.6268 | train_iou : 0.4572 |test_loss: 0.4443 | test_dice: 0.6263 | test_iou : 0.4563


 56%|█████▌    | 28/50 [03:20<02:38,  7.19s/it]

Epoch: 28 | train_loss: 0.4247 | train_dice: 0.6180 | train_iou : 0.4486 |test_loss: 0.4335 | test_dice: 0.6537 | test_iou : 0.4859


 58%|█████▊    | 29/50 [03:27<02:30,  7.18s/it]

Epoch: 29 | train_loss: 0.4169 | train_dice: 0.6215 | train_iou : 0.4519 |test_loss: 0.4627 | test_dice: 0.5991 | test_iou : 0.4281


 60%|██████    | 30/50 [03:34<02:23,  7.17s/it]

Epoch: 30 | train_loss: 0.4057 | train_dice: 0.6320 | train_iou : 0.4626 |test_loss: 0.4461 | test_dice: 0.6301 | test_iou : 0.4603


 62%|██████▏   | 31/50 [03:42<02:16,  7.21s/it]

Epoch: 31 | train_loss: 0.3957 | train_dice: 0.6401 | train_iou : 0.4718 |test_loss: 0.4281 | test_dice: 0.6560 | test_iou : 0.4884


 64%|██████▍   | 32/50 [03:49<02:09,  7.19s/it]

Epoch: 32 | train_loss: 0.3801 | train_dice: 0.6554 | train_iou : 0.4887 |test_loss: 0.4317 | test_dice: 0.6509 | test_iou : 0.4829


 66%|██████▌   | 33/50 [03:56<02:02,  7.23s/it]

Epoch: 33 | train_loss: 0.3790 | train_dice: 0.6522 | train_iou : 0.4848 |test_loss: 0.4426 | test_dice: 0.6333 | test_iou : 0.4638


 68%|██████▊   | 34/50 [04:03<01:55,  7.20s/it]

Epoch: 34 | train_loss: 0.3663 | train_dice: 0.6640 | train_iou : 0.4981 |test_loss: 0.4545 | test_dice: 0.6176 | test_iou : 0.4473


 70%|███████   | 35/50 [04:10<01:47,  7.20s/it]

Epoch: 35 | train_loss: 0.3585 | train_dice: 0.6699 | train_iou : 0.5046 |test_loss: 0.4325 | test_dice: 0.6450 | test_iou : 0.4765


 72%|███████▏  | 36/50 [04:17<01:40,  7.16s/it]

Epoch: 36 | train_loss: 0.3535 | train_dice: 0.6773 | train_iou : 0.5128 |test_loss: 0.4587 | test_dice: 0.6166 | test_iou : 0.4461


 74%|███████▍  | 37/50 [04:25<01:33,  7.17s/it]

Epoch: 37 | train_loss: 0.3635 | train_dice: 0.6599 | train_iou : 0.4933 |test_loss: 0.4199 | test_dice: 0.6863 | test_iou : 0.5228


 76%|███████▌  | 38/50 [04:32<01:25,  7.16s/it]

Epoch: 38 | train_loss: 0.3415 | train_dice: 0.6814 | train_iou : 0.5178 |test_loss: 0.4234 | test_dice: 0.6686 | test_iou : 0.5026


 78%|███████▊  | 39/50 [04:39<01:18,  7.17s/it]

Epoch: 39 | train_loss: 0.3319 | train_dice: 0.6905 | train_iou : 0.5286 |test_loss: 0.4267 | test_dice: 0.6688 | test_iou : 0.5027


 80%|████████  | 40/50 [04:46<01:11,  7.17s/it]

Epoch: 40 | train_loss: 0.3250 | train_dice: 0.6971 | train_iou : 0.5356 |test_loss: 0.4358 | test_dice: 0.6558 | test_iou : 0.4882


 82%|████████▏ | 41/50 [04:53<01:04,  7.17s/it]

Epoch: 41 | train_loss: 0.3267 | train_dice: 0.6982 | train_iou : 0.5373 |test_loss: 0.4214 | test_dice: 0.6843 | test_iou : 0.5205


 84%|████████▍ | 42/50 [05:01<00:57,  7.21s/it]

Epoch: 42 | train_loss: 0.3377 | train_dice: 0.6787 | train_iou : 0.5145 |test_loss: 0.4157 | test_dice: 0.6884 | test_iou : 0.5253


 86%|████████▌ | 43/50 [05:08<00:50,  7.17s/it]

Epoch: 43 | train_loss: 0.3179 | train_dice: 0.7003 | train_iou : 0.5394 |test_loss: 0.4239 | test_dice: 0.6701 | test_iou : 0.5043


 88%|████████▊ | 44/50 [05:15<00:43,  7.20s/it]

Epoch: 44 | train_loss: 0.3066 | train_dice: 0.7114 | train_iou : 0.5529 |test_loss: 0.4168 | test_dice: 0.6897 | test_iou : 0.5268


 90%|█████████ | 45/50 [05:22<00:35,  7.17s/it]

Epoch: 45 | train_loss: 0.2960 | train_dice: 0.7217 | train_iou : 0.5651 |test_loss: 0.4206 | test_dice: 0.6987 | test_iou : 0.5372


 92%|█████████▏| 46/50 [05:29<00:28,  7.18s/it]

Epoch: 46 | train_loss: 0.2933 | train_dice: 0.7256 | train_iou : 0.5701 |test_loss: 0.4249 | test_dice: 0.7003 | test_iou : 0.5390


 94%|█████████▍| 47/50 [05:36<00:21,  7.16s/it]

Epoch: 47 | train_loss: 0.2939 | train_dice: 0.7226 | train_iou : 0.5665 |test_loss: 0.4155 | test_dice: 0.6954 | test_iou : 0.5333


 96%|█████████▌| 48/50 [05:44<00:14,  7.17s/it]

Epoch: 48 | train_loss: 0.2936 | train_dice: 0.7152 | train_iou : 0.5573 |test_loss: 0.4222 | test_dice: 0.7248 | test_iou : 0.5688


 98%|█████████▊| 49/50 [05:51<00:07,  7.17s/it]

Epoch: 49 | train_loss: 0.2737 | train_dice: 0.7424 | train_iou : 0.5912 |test_loss: 0.4128 | test_dice: 0.6966 | test_iou : 0.5348


100%|██████████| 50/50 [05:58<00:00,  7.17s/it]

Epoch: 50 | train_loss: 0.2683 | train_dice: 0.7459 | train_iou : 0.5952 |test_loss: 0.4148 | test_dice: 0.7152 | test_iou : 0.5569
Total training time: 358.384 seconds





In [28]:
torch.save(model.state_dict(), 'unet_weights.pth')
