In [12]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from scipy.io import loadmat

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


In [19]:
# Load a single annotation file and print its top-level keys, to see which
# variable name the mask is stored under inside the .mat container.
# NOTE(review): `loadmat` is already imported in the first cell; the import is
# repeated here, so this cell also runs when executed on its own.
from scipy.io import loadmat
mat = loadmat('ORIGA/ORIGA/Semi-automatic-annotations/544.mat')
print(mat.keys())

dict_keys(['__header__', '__version__', '__globals__', 'mask'])


In [20]:
class ORIGAMatMaskDataset(Dataset):
    """Dataset pairing each image with a segmentation mask stored in a ``.mat`` file.

    Images (``.jpg`` / ``.png``) are listed from ``image_dir`` and masks
    (``.mat``) from ``mask_dir``. An image is paired with the ``.mat`` file
    whose stem equals the image stem; if there is none, the first ``.mat``
    file whose name contains the image stem is used.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ``.mat`` annotation files.
        transform_img: Optional callable applied to the RGB PIL image.
        transform_mask: Optional callable applied to the PIL mask image.
        mask_key: Name of the variable inside the ``.mat`` file that holds
            the mask. Stray surrounding quote characters are stripped, so
            ``"'mask'"`` and ``"mask"`` are equivalent.
    """

    def __init__(self, image_dir, mask_dir, transform_img=None, transform_mask=None, mask_key="mask"):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.png')))
        self.mats = sorted(f for f in os.listdir(mask_dir) if f.lower().endswith('.mat'))
        self.transform_img = transform_img
        self.transform_mask = transform_mask
        # The previous default was the literal string "'mask'" (quotes included),
        # which can never match the real key; drop surrounding quotes defensively.
        self.mask_key = mask_key.strip("'\"")

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.images)

    @staticmethod
    def _match_mat(base, mats):
        """Return the ``.mat`` file name for image stem ``base``, or ``None``.

        An exact stem match wins. Plain substring matching alone would pair
        image ``5`` with ``15.mat`` because that name sorts first. It is kept
        only as a fallback for names such as ``5_mask.mat``.
        """
        for name in mats:
            if os.path.splitext(name)[0] == base:
                return name
        for name in mats:
            if base in name:
                return name
        return None

    @staticmethod
    def _to_label_tensor(mask):
        """Convert a transformed mask to an integer (``long``) label tensor.

        Accepts a tensor or anything ``np.array`` can convert (for example a
        PIL image when no mask transform is configured). A float tensor whose
        values lie in [0, 1] is assumed to come from ``ToTensor``, which
        divides uint8 data by 255. It is scaled back and rounded so labels
        1, 2, ... are not truncated to 0 by the integer cast.
        """
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.array(mask))
        if mask.is_floating_point() and mask.numel() > 0 and mask.max() <= 1.0:
            mask = torch.round(mask * 255)
        return torch.squeeze(mask).long()

    def __getitem__(self, idx):
        """Load and return ``(image, mask)`` for position ``idx``.

        Raises:
            FileNotFoundError: If no ``.mat`` file matches the image name.
            KeyError: If ``mask_key`` is not a variable in the ``.mat`` file.
        """
        img_name = self.images[idx]
        base = os.path.splitext(img_name)[0]
        mat_name = self._match_mat(base, self.mats)
        if mat_name is None:
            raise FileNotFoundError(f"No .mat file found for image {img_name}")
        img_path = os.path.join(self.image_dir, img_name)
        mat_path = os.path.join(self.mask_dir, mat_name)

        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')

        mat_data = loadmat(mat_path)
        if self.mask_key not in mat_data:
            raise KeyError(f"Key '{self.mask_key}' not found in {mat_name}. Available keys: {mat_data.keys()}")
        mask = np.array(mat_data[self.mask_key]).astype(np.uint8)
        if mask.ndim == 3:
            mask = mask.squeeze()
        mask_img = Image.fromarray(mask)

        if self.transform_img:
            image = self.transform_img(image)
        if self.transform_mask:
            mask_img = self.transform_mask(mask_img)
        return image, self._to_label_tensor(mask_img)

In [5]:
class ORIGAMatMaskDataset(Dataset):
    """Dataset pairing images with ``.mat`` masks by sorted position.

    Images (``.jpg`` / ``.png``) are listed from ``image_dir`` and masks
    (``.mat``) from ``mask_dir``; both lists are sorted and item ``idx`` pairs
    ``images[idx]`` with ``mats[idx]``.

    NOTE(review): positional pairing is only correct when both directories
    hold the same set of file stems. Nothing here verifies that, so confirm it
    for the data in use.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ``.mat`` annotation files.
        transform_img: Optional callable applied to the RGB PIL image.
        transform_mask: Optional callable applied to the PIL mask image.
        mask_key: Name of the variable inside the ``.mat`` file that holds the mask.
        mask_size: Stored on the instance; not used by this class itself.
    """

    def __init__(self, image_dir, mask_dir, transform_img=None, transform_mask=None, mask_key="label", mask_size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.png')))
        self.mats = sorted(f for f in os.listdir(mask_dir) if f.lower().endswith('.mat'))
        self.transform_img = transform_img
        self.transform_mask = transform_mask
        self.mask_key = mask_key
        self.mask_size = mask_size

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.images)

    @staticmethod
    def _to_label_tensor(mask):
        """Convert a transformed mask to an integer (``long``) label tensor.

        Accepts a tensor or anything ``np.array`` can convert (for example a
        PIL image when no mask transform is configured). A float tensor whose
        values lie in [0, 1] is assumed to come from ``ToTensor``, which
        divides uint8 data by 255. It is scaled back and rounded so labels
        1, 2, ... are not truncated to 0 by the integer cast.
        """
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.array(mask))
        if mask.is_floating_point() and mask.numel() > 0 and mask.max() <= 1.0:
            mask = torch.round(mask * 255)
        return torch.squeeze(mask).long()

    def __getitem__(self, idx):
        """Load and return ``(image, mask)`` for position ``idx``.

        Raises:
            IndexError: If there are fewer ``.mat`` files than images.
            KeyError: If ``mask_key`` is not a variable in the ``.mat`` file.
        """
        img_name = self.images[idx]
        mat_name = self.mats[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mat_path = os.path.join(self.mask_dir, mat_name)

        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')

        mat_data = loadmat(mat_path)
        if self.mask_key not in mat_data:
            # Still a KeyError, but one that says which file and which keys exist.
            raise KeyError(f"Key '{self.mask_key}' not found in {mat_name}. Available keys: {mat_data.keys()}")
        mask = np.array(mat_data[self.mask_key]).astype(np.uint8)
        if mask.ndim == 3:
            mask = mask.squeeze()
        mask_img = Image.fromarray(mask)

        if self.transform_img:
            image = self.transform_img(image)
        if self.transform_mask:
            mask_img = self.transform_mask(mask_img)
        return image, self._to_label_tensor(mask_img)

In [14]:
# Images: resize to 256x256 and convert to a float tensor (ToTensor scales
# 8-bit pixel values into [0, 1]).
img_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
])

# Masks: nearest-neighbour resize so no in-between label values are invented,
# then PILToTensor, which keeps the raw integer values. ToTensor must not be
# used here: it divides uint8 data by 255, so class ids 1 and 2 would become
# ~0.004 / ~0.008 and collapse to 0 when the dataset casts the mask to long.
mask_transform = transforms.Compose([
    transforms.Resize((256,256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),  # uint8 [1, H, W]; squeezed and cast to long in the dataset
])

In [28]:
# Build the training set from the ROI images and their .mat annotations,
# then wrap it in a shuffled loader that yields batches of 8.
train_dataset = ORIGAMatMaskDataset(
    'ORIGA/ORIGA/roi_images',
    'ORIGA/ORIGA/Semi-automatic-annotations',
    transform_img=img_transform,
    transform_mask=mask_transform,
    mask_key='mask',
)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)


In [29]:
class UNet(nn.Module):
    """Four-level U-Net: encoder, bottleneck and a decoder with skip connections.

    Each stage is two 3x3 conv + batch-norm + ReLU layers. Downsampling uses
    2x2 max pooling, upsampling uses stride-2 transposed convolutions, and a
    final 1x1 convolution maps 64 features to ``out_channels`` score maps.
    Padding of 1 keeps spatial size within a stage, so the output has the same
    height and width as the input.

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of output channels (one per class).
    """

    def __init__(self, in_channels=3, out_channels=3):  # 3 for background, OD, OC
        super().__init__()

        def double_conv(c_in, c_out):
            # Two (conv -> batch-norm -> ReLU) units; only the first changes the width.
            layers = []
            for src in (c_in, c_out):
                layers.append(nn.Conv2d(src, c_out, 3, padding=1))
                layers.append(nn.BatchNorm2d(c_out))
                layers.append(nn.ReLU(inplace=True))
            return nn.Sequential(*layers)

        # Encoder path.
        self.enc1 = double_conv(in_channels, 64)
        self.enc2 = double_conv(64, 128)
        self.enc3 = double_conv(128, 256)
        self.enc4 = double_conv(256, 512)
        self.pool = nn.MaxPool2d(2)

        self.bottleneck = double_conv(512, 1024)

        # Decoder path: each stage halves the channels on upsampling, and the
        # concatenated skip connection doubles them again before the conv block.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = double_conv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = double_conv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = double_conv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = double_conv(128, 64)

        self.conv_last = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        """Run the network on ``x`` and return the per-class score maps."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))

        feat = self.bottleneck(self.pool(skip4))

        # Walk back up, deepest level first: upsample, join the skip, refine.
        stages = (
            (self.up4, self.dec4, skip4),
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, refine, skip in stages:
            feat = refine(torch.cat((upsample(feat), skip), dim=1))

        return self.conv_last(feat)

In [30]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

# One output channel per class; the optimizer is bound to this model's parameters.
num_classes = 3
model = UNet(in_channels=3, out_channels=num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

Using device: cuda


In [26]:
# Display the architecture of the model created in the setup cell above.
# Do not construct a new UNet here: rebinding `model` after `optimizer` has been
# built leaves the optimizer stepping the old instance's parameters, so the model
# used in the training loop would never be updated.
model.to(device)

UNet(
  (enc1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (enc2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (enc3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_r

In [None]:
# Train for a fixed number of epochs and report the mean batch loss after each one.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0
    for images, masks in train_loader:
        # masks: [B, H, W] with class labels
        images, masks = images.to(device), masks.to(device)

        # Clear gradients left over from the previous step before the new pass.
        optimizer.zero_grad()
        outputs = model(images)  # [B, num_classes, H, W]
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")