In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
import segmentation_models_pytorch as smp

# ------------------------
# Dataset
# ------------------------
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for binary segmentation.

    Every regular file in ``image_dir`` is paired with the file of the same
    name in ``mask_dir``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing one mask per image, stored under the
            same file name as the image.
        transform: Optional callable applied to both the RGB image and the
            single-channel mask. Its output for the mask must support
            ``>`` and ``.float()`` (e.g. a tensor). When omitted, both
            are converted with ``transforms.ToTensor()`` so that the mask can
            still be binarized.
        mask_threshold: Mask values strictly greater than this become 1.0,
            all others 0.0. Defaults to 0.0.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_threshold=0.0):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_threshold = mask_threshold
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories; keep regular files only and sort them so that
        # index -> sample is reproducible between runs.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, binary_mask)``.

        Raises:
            FileNotFoundError: If no mask with the image's file name exists
                in ``mask_dir`` (raised by ``Image.open``).
        """
        name = self.images[idx]

        # convert() returns a fully loaded copy, so the underlying files can
        # be closed as soon as the conversion is done.
        with Image.open(os.path.join(self.image_dir, name)) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(os.path.join(self.mask_dir, name)) as raw_mask:
            mask = raw_mask.convert("L")

        # Without a user transform the mask would still be a PIL image and the
        # comparison below would fail, so fall back to a plain tensor conversion.
        # NOTE(review): the same callable is applied separately to image and
        # mask; a transform with random behaviour would not stay aligned
        # between the two.
        convert = self.transform if self.transform is not None else transforms.ToTensor()
        image = convert(image)
        mask = convert(mask)

        # Normalize mask to binary
        mask = (mask > self.mask_threshold).float()

        return image, mask

# ------------------------
# Transform
# ------------------------
# Preprocessing pipeline: resize to 256x256, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)

# ------------------------
# DataLoader
# ------------------------
# Images and masks live in sibling folders of the Kvasir-SEG directory.
train_dataset = SegmentationDataset(
    image_dir="Kvasir-SEG/images",
    mask_dir="Kvasir-SEG/masks",
    transform=transform,
)
# Batches of 8 samples, reshuffled every epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)

# ------------------------
# ResUNet
# ------------------------
class ResUNet(nn.Module):
    """Thin ``nn.Module`` wrapper around ``smp.Unet`` with a ResNet-34 encoder.

    Args:
        in_channels: Number of channels of the input images (default 3).
        out_classes: Number of output channels produced by the network
            (default 1).
    """

    def __init__(self, in_channels=3, out_classes=1):
        super().__init__()
        # The encoder is initialised from the "imagenet" weights.
        unet_config = {
            "encoder_name": "resnet34",
            "encoder_weights": "imagenet",
            "in_channels": in_channels,
            "classes": out_classes,
        }
        self.model = smp.Unet(**unet_config)

    def forward(self, x):
        """Run ``x`` through the wrapped U-Net and return its output."""
        out = self.model(x)
        return out

# Build the network, then move it to the GPU when one is available.
model = ResUNet()
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------
# Metrics
# ------------------------
def iou_score(pred, target, smooth=1e-6):
    """Return the intersection-over-union of thresholded predictions.

    ``pred`` is passed through a sigmoid and both inputs are binarized at
    0.5. Counts are summed over every element of the tensors at once (not
    per sample).

    Args:
        pred: Tensor of raw scores (sigmoid is applied here).
        target: Tensor of ground-truth values, broadcast-compatible with
            ``pred``.
        smooth: Constant added to numerator and denominator; it makes the
            score 1 when both masks are empty.

    Returns:
        A scalar tensor ``(intersection + smooth) / (union + smooth)``.
    """
    pred_mask = torch.sigmoid(pred) > 0.5
    target_mask = target > 0.5
    overlap = torch.logical_and(pred_mask, target_mask).sum().float()
    combined = torch.logical_or(pred_mask, target_mask).sum().float()
    return (overlap + smooth) / (combined + smooth)

def dice_score(pred, target, smooth=1e-6):
    """Return the Dice coefficient of thresholded predictions.

    ``pred`` is passed through a sigmoid and both inputs are binarized at
    0.5. Counts are summed over every element of the tensors at once (not
    per sample).

    Args:
        pred: Tensor of raw scores (sigmoid is applied here).
        target: Tensor of ground-truth values, broadcast-compatible with
            ``pred``.
        smooth: Constant added to numerator and denominator; it makes the
            score 1 when both masks are empty.

    Returns:
        A scalar tensor
        ``(2 * intersection + smooth) / (|pred| + |target| + smooth)``.
    """
    pred_mask = torch.sigmoid(pred) > 0.5
    target_mask = target > 0.5
    overlap = (pred_mask & target_mask).sum().float()
    total_positive = pred_mask.sum() + target_mask.sum()
    numerator = 2. * overlap + smooth
    denominator = total_positive + smooth
    return numerator / denominator

# ------------------------
# Training Setup
# ------------------------
# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Binary cross-entropy computed directly on the raw model outputs.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Training length and the best epoch-average IoU seen so far.
num_epochs = 10
best_iou = 0.0

# ------------------------
# Training Loop
# ------------------------
# Train for `num_epochs` epochs; after each epoch report the mean loss / IoU /
# Dice over all batches and checkpoint the weights whenever the mean IoU
# improves.
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    epoch_iou = 0
    epoch_dice = 0

    # Number of batches per epoch; used for all per-epoch averages below.
    num_batches = len(train_loader)

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, masks in loop:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics are for reporting only: compute them without autograd
        # tracking and read each value back to the host exactly once.
        with torch.no_grad():
            batch_loss = loss.item()
            batch_iou = iou_score(outputs, masks).item()
            batch_dice = dice_score(outputs, masks).item()

        epoch_loss += batch_loss
        epoch_iou += batch_iou
        epoch_dice += batch_dice
        loop.set_postfix(loss=batch_loss, iou=batch_iou, dice=batch_dice)

    avg_iou = epoch_iou / num_batches
    avg_dice = epoch_dice / num_batches
    print(f"\nEpoch {epoch+1}: Avg Loss = {epoch_loss/num_batches:.4f}, IoU = {avg_iou:.4f}, Dice = {avg_dice:.4f}")

    # Save best model
    # NOTE(review): `avg_iou` is measured on the batches of `train_loader`
    # while the model is still being updated, so "best" here means best
    # training IoU, not validation IoU.
    if avg_iou > best_iou:
        best_iou = avg_iou
        torch.save(model.state_dict(), "best_resunet.pth")
        print(f"✅ Saved Best Model (IoU = {best_iou:.4f})")



Epoch 1/10: 100%|██████████████████████████████████| 125/125 [03:47<00:00,  1.82s/it, dice=0.78, iou=0.639, loss=0.385]


Epoch 1: Avg Loss = 0.5461, IoU = 0.4548, Dice = 0.6060
✅ Saved Best Model (IoU = 0.4548)



Epoch 2/10: 100%|█████████████████████████████████| 125/125 [03:44<00:00,  1.80s/it, dice=0.896, iou=0.812, loss=0.256]


Epoch 2: Avg Loss = 0.3272, IoU = 0.7076, Dice = 0.8251
✅ Saved Best Model (IoU = 0.7076)



Epoch 3/10: 100%|█████████████████████████████████| 125/125 [03:48<00:00,  1.83s/it, dice=0.674, iou=0.508, loss=0.328]


Epoch 3: Avg Loss = 0.2246, IoU = 0.7953, Dice = 0.8843
✅ Saved Best Model (IoU = 0.7953)



Epoch 4/10: 100%|█████████████████████████████████| 125/125 [03:49<00:00,  1.84s/it, dice=0.863, iou=0.758, loss=0.191]


Epoch 4: Avg Loss = 0.1626, IoU = 0.8427, Dice = 0.9134
✅ Saved Best Model (IoU = 0.8427)



Epoch 5/10: 100%|██████████████████████████████████| 125/125 [04:01<00:00,  1.93s/it, dice=0.94, iou=0.886, loss=0.112]


Epoch 5: Avg Loss = 0.1326, IoU = 0.8493, Dice = 0.9170
✅ Saved Best Model (IoU = 0.8493)



Epoch 6/10: 100%|████████████████████████████████| 125/125 [03:48<00:00,  1.83s/it, dice=0.953, iou=0.911, loss=0.0872]


Epoch 6: Avg Loss = 0.1070, IoU = 0.8714, Dice = 0.9304
✅ Saved Best Model (IoU = 0.8714)



Epoch 7/10: 100%|████████████████████████████████| 125/125 [03:52<00:00,  1.86s/it, dice=0.911, iou=0.836, loss=0.0876]


Epoch 7: Avg Loss = 0.0885, IoU = 0.8876, Dice = 0.9398
✅ Saved Best Model (IoU = 0.8876)



Epoch 8/10: 100%|██████████████████████████████████| 125/125 [03:51<00:00,  1.86s/it, dice=0.95, iou=0.905, loss=0.066]


Epoch 8: Avg Loss = 0.0739, IoU = 0.9007, Dice = 0.9474
✅ Saved Best Model (IoU = 0.9007)



Epoch 9/10: 100%|████████████████████████████████| 125/125 [03:43<00:00,  1.79s/it, dice=0.937, iou=0.881, loss=0.0789]


Epoch 9: Avg Loss = 0.0616, IoU = 0.9152, Dice = 0.9555
✅ Saved Best Model (IoU = 0.9152)


Epoch 10/10: 100%|███████████████████████████████| 125/125 [03:42<00:00,  1.78s/it, dice=0.962, iou=0.928, loss=0.0491]


Epoch 10: Avg Loss = 0.0580, IoU = 0.9132, Dice = 0.9544



