In [6]:
import os
from PIL import Image
import numpy as np
from tqdm import tqdm

def get_image_paths(folder):
    img_dir = os.path.join(folder, "image")
    return [os.path.join(img_dir, fname) for fname in os.listdir(img_dir) if fname.endswith('.png')]

all_image_paths = []
for split in ["train", "val", "test"]:
    all_image_paths.extend(get_image_paths(split))

means = []
stds = []

for img_path in tqdm(all_image_paths, desc="Calculating mean/std"):
    img = Image.open(img_path).convert("RGB")
    img = np.array(img) / 255.0  # scale to [0,1]
    means.append(np.mean(img, axis=(0,1)))
    stds.append(np.std(img, axis=(0,1)))

mean = np.mean(means, axis=0)
std = np.mean(stds, axis=0)

print(f"Dataset mean: {mean}")
print(f"Dataset std: {std}")

Calculating mean/std: 100%|██████████| 300/300 [00:01<00:00, 267.74it/s]

Dataset mean: [0.11803523 0.12175034 0.12110489]
Dataset std: [0.16351671 0.16600904 0.16600859]





U-Net

In [21]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

# -----------------------
# Dataset
# -----------------------
class SegmentationDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.images = os.listdir(img_dir)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])  # assumes same filename

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # grayscale

        if self.transform is not None:
            image = self.transform(image)
            # For mask: only resize and ToTensor, no Normalize!
            mask_transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()
            ])
            mask = mask_transform(mask)

        mask = (mask > 0).float()  # binarize
        return image, mask

# -----------------------
# U-Net
# -----------------------
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024)

        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        self.conv_final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.upconv4(b)
        d4 = torch.cat((d4, e4), dim=1)
        d4 = self.dec4(d4)

        d3 = self.upconv3(d4)
        d3 = torch.cat((d3, e3), dim=1)
        d3 = self.dec3(d3)

        d2 = self.upconv2(d3)
        d2 = torch.cat((d2, e2), dim=1)
        d2 = self.dec2(d2)

        d1 = self.upconv1(d2)
        d1 = torch.cat((d1, e1), dim=1)
        d1 = self.dec1(d1)

        return torch.sigmoid(self.conv_final(d1))

# -----------------------
# Metrics
# -----------------------
def dice_score(pred, target, smooth=1e-6):
    pred = (pred > 0.5).float()
    target = target.float()
    intersection = (pred * target).sum()
    return ((2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)).item()

def iou_score(pred, target, smooth=1e-6):
    pred = (pred > 0.5).float()
    target = target.float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return ((intersection + smooth) / (union + smooth)).item()

def pixel_accuracy(pred, target):
    pred = (pred > 0.5).float()
    correct = (pred == target).float().sum()
    return (correct / target.numel()).item()

def precision_score(pred, target, smooth=1e-6):
    pred = (pred > 0.5).float()
    target = target.float()
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    return (tp / (tp + fp + smooth)).item()

def recall_score(pred, target, smooth=1e-6):
    pred = (pred > 0.5).float()
    target = target.float()
    tp = (pred * target).sum()
    fn = ((1 - pred) * target).sum()
    return (tp / (tp + fn + smooth)).item()

# -----------------------
# Data Setup
# -----------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.11803523, 0.12175034, 0.12110489], std=[0.16351671, 0.16600904, 0.16600859])
])

train_dataset = SegmentationDataset(
    img_dir="train/image",
    mask_dir="train/mask",
    transform=transform
)
val_dataset = SegmentationDataset(
    img_dir="val/image",
    mask_dir="val/mask",
    transform=transform
)
test_dataset = SegmentationDataset(
    img_dir="test/image",
    mask_dir="test/mask",
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)
test_loader = DataLoader(test_dataset, batch_size=8)

# -----------------------
# U-Net Training
# -----------------------
model = UNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
bce = nn.BCELoss()

num_epochs = 20
u_net_best_test_loss = float("inf")
patience = 5
epochs_no_improve = 0

for epoch in range(num_epochs):
    # --- Training ---
    model.train()
    u_net_train_loss, u_net_train_dice, u_net_train_iou, u_net_train_acc = 0, 0, 0, 0
    for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        imgs, masks = imgs.to(device), masks.to(device)
        preds = model(imgs)
        loss = 0.5 * bce(preds, masks) + 0.5 * (1 - dice_score(preds, masks))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        u_net_train_loss += loss.item()
        u_net_train_dice += dice_score(preds, masks)
        u_net_train_iou += iou_score(preds, masks)
        u_net_train_acc += pixel_accuracy(preds, masks)

    # --- Validation ---
    model.eval()
    u_net_val_loss, u_net_val_dice, u_net_val_iou, u_net_val_acc, u_net_val_prec, u_net_val_rec = 0, 0, 0, 0, 0, 0
    with torch.no_grad():
        for imgs, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            loss = 0.5 * bce(preds, masks) + 0.5 * (1 - dice_score(preds, masks))
            u_net_val_loss += loss.item()
            u_net_val_dice += dice_score(preds, masks)
            u_net_val_iou += iou_score(preds, masks)
            u_net_val_acc += pixel_accuracy(preds, masks)
            u_net_val_prec += precision_score(preds, masks)
            u_net_val_rec += recall_score(preds, masks)

    # --- Test ---
    u_net_test_loss, u_net_test_dice, u_net_test_iou, u_net_test_acc, u_net_test_prec, u_net_test_rec = 0, 0, 0, 0, 0, 0
    with torch.no_grad():
        for imgs, masks in tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Test"):
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            loss = 0.5 * bce(preds, masks) + 0.5 * (1 - dice_score(preds, masks))
            u_net_test_loss += loss.item()
            u_net_test_dice += dice_score(preds, masks)
            u_net_test_iou += iou_score(preds, masks)
            u_net_test_acc += pixel_accuracy(preds, masks)
            u_net_test_prec += precision_score(preds, masks)
            u_net_test_rec += recall_score(preds, masks)

    # --- Print metrics ---
    print(f"\nEpoch {epoch+1}:")
    print(f" Train Loss={u_net_train_loss/len(train_loader):.4f} | "
          f"Dice={u_net_train_dice/len(train_loader):.4f} | IoU={u_net_train_iou/len(train_loader):.4f} | Acc={u_net_train_acc/len(train_loader):.4f}")
    print(f" Val   Loss={u_net_val_loss/len(val_loader):.4f} | "
          f"Dice={u_net_val_dice/len(val_loader):.4f} | IoU={u_net_val_iou/len(val_loader):.4f} | Acc={u_net_val_acc/len(val_loader):.4f} | "
          f"Prec={u_net_val_prec/len(val_loader):.4f} | Recall={u_net_val_rec/len(val_loader):.4f}")
    print(f" Test  Loss={u_net_test_loss/len(test_loader):.4f} | "
          f"Dice={u_net_test_dice/len(test_loader):.4f} | IoU={u_net_test_iou/len(test_loader):.4f} | Acc={u_net_test_acc/len(test_loader):.4f} | "
          f"Prec={u_net_test_prec/len(test_loader):.4f} | Recall={u_net_test_rec/len(test_loader):.4f}")

    # --- Early stopping ---
    if u_net_test_loss < u_net_best_test_loss:
        u_net_best_test_loss = u_net_test_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), "best_unet.pth")
        print("✅ Model saved!")
    else:
        epochs_no_improve += 1
        print(f"EarlyStopping counter: {epochs_no_improve} of {patience}")
        if epochs_no_improve >= patience:
            print("⏹️ Early stopping triggered!")
            break


Epoch 1/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.11it/s]
Epoch 1/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 11.67it/s]
Epoch 1/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 12.78it/s]



Epoch 1:
 Train Loss=0.8455 | Dice=0.0116 | IoU=0.0058 | Acc=0.5501
 Val   Loss=0.8042 | Dice=0.0267 | IoU=0.0135 | Acc=0.7578 | Prec=0.0136 | Recall=0.7947
 Test  Loss=0.8015 | Dice=0.0311 | IoU=0.0158 | Acc=0.7628 | Prec=0.0159 | Recall=0.8274
✅ Model saved!


Epoch 2/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.18it/s]
Epoch 2/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.50it/s]
Epoch 2/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.27it/s]



Epoch 2:
 Train Loss=0.7431 | Dice=0.0012 | IoU=0.0006 | Acc=0.9891
 Val   Loss=0.7396 | Dice=0.0010 | IoU=0.0005 | Acc=0.9930 | Prec=0.0015 | Recall=0.0007
 Test  Loss=0.7390 | Dice=0.0018 | IoU=0.0009 | Acc=0.9926 | Prec=0.0026 | Recall=0.0013
✅ Model saved!


Epoch 3/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.22it/s]
Epoch 3/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.06it/s]
Epoch 3/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 12.98it/s]



Epoch 3:
 Train Loss=0.7182 | Dice=0.0000 | IoU=0.0000 | Acc=0.9942
 Val   Loss=0.7162 | Dice=0.0000 | IoU=0.0000 | Acc=0.9946 | Prec=0.0000 | Recall=0.0000
 Test  Loss=0.7160 | Dice=0.0000 | IoU=0.0000 | Acc=0.9942 | Prec=0.0000 | Recall=0.0000
✅ Model saved!


Epoch 4/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.05it/s]
Epoch 4/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.28it/s]
Epoch 4/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 11.79it/s]



Epoch 4:
 Train Loss=0.7050 | Dice=0.0000 | IoU=0.0000 | Acc=0.9948
 Val   Loss=0.7039 | Dice=0.0000 | IoU=0.0000 | Acc=0.9950 | Prec=0.0000 | Recall=0.0000
 Test  Loss=0.7037 | Dice=0.0000 | IoU=0.0000 | Acc=0.9946 | Prec=0.0000 | Recall=0.0000
✅ Model saved!


Epoch 5/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  4.94it/s]
Epoch 5/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.42it/s]
Epoch 5/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 12.93it/s]



Epoch 5:
 Train Loss=0.6944 | Dice=0.0000 | IoU=0.0000 | Acc=0.9951
 Val   Loss=0.6959 | Dice=0.0000 | IoU=0.0000 | Acc=0.9952 | Prec=0.0000 | Recall=0.0000
 Test  Loss=0.6957 | Dice=0.0000 | IoU=0.0000 | Acc=0.9947 | Prec=0.0000 | Recall=0.0000
✅ Model saved!


Epoch 6/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.11it/s]
Epoch 6/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 11.56it/s]
Epoch 6/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.16it/s]



Epoch 6:
 Train Loss=0.6052 | Dice=0.1609 | IoU=0.0981 | Acc=0.9956
 Val   Loss=0.4268 | Dice=0.5251 | IoU=0.3626 | Acc=0.9958 | Prec=0.4880 | Recall=0.5846
 Test  Loss=0.4266 | Dice=0.5252 | IoU=0.3724 | Acc=0.9955 | Prec=0.4970 | Recall=0.5878
✅ Model saved!


Epoch 7/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.16it/s]
Epoch 7/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.44it/s]
Epoch 7/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.23it/s]



Epoch 7:
 Train Loss=0.3792 | Dice=0.5965 | IoU=0.4306 | Acc=0.9969
 Val   Loss=0.3827 | Dice=0.5899 | IoU=0.4257 | Acc=0.9972 | Prec=0.7136 | Recall=0.5112
 Test  Loss=0.4089 | Dice=0.5378 | IoU=0.3857 | Acc=0.9968 | Prec=0.6804 | Recall=0.4741
✅ Model saved!


Epoch 8/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.01it/s]
Epoch 8/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 11.88it/s]
Epoch 8/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.14it/s]



Epoch 8:
 Train Loss=0.3128 | Dice=0.7135 | IoU=0.5570 | Acc=0.9976
 Val   Loss=0.3107 | Dice=0.7175 | IoU=0.5624 | Acc=0.9972 | Prec=0.6168 | Recall=0.8734
 Test  Loss=0.3211 | Dice=0.6966 | IoU=0.5397 | Acc=0.9969 | Prec=0.6109 | Recall=0.8321
✅ Model saved!


Epoch 9/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.12it/s]
Epoch 9/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 10.73it/s]
Epoch 9/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 12.85it/s]



Epoch 9:
 Train Loss=0.2946 | Dice=0.7337 | IoU=0.5822 | Acc=0.9976
 Val   Loss=0.2894 | Dice=0.7422 | IoU=0.5913 | Acc=0.9977 | Prec=0.6921 | Recall=0.8039
 Test  Loss=0.3009 | Dice=0.7195 | IoU=0.5678 | Acc=0.9974 | Prec=0.6839 | Recall=0.7769
✅ Model saved!


Epoch 10/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  4.95it/s]
Epoch 10/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.36it/s]
Epoch 10/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 12.88it/s]



Epoch 10:
 Train Loss=0.2662 | Dice=0.7766 | IoU=0.6357 | Acc=0.9979
 Val   Loss=0.2647 | Dice=0.7753 | IoU=0.6360 | Acc=0.9981 | Prec=0.7350 | Recall=0.8240
 Test  Loss=0.2807 | Dice=0.7436 | IoU=0.5952 | Acc=0.9976 | Prec=0.7145 | Recall=0.7922
✅ Model saved!


Epoch 11/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.17it/s]
Epoch 11/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 11.18it/s]
Epoch 11/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 12.94it/s]



Epoch 11:
 Train Loss=0.2514 | Dice=0.7933 | IoU=0.6582 | Acc=0.9981
 Val   Loss=0.2478 | Dice=0.7978 | IoU=0.6657 | Acc=0.9983 | Prec=0.7575 | Recall=0.8479
 Test  Loss=0.2697 | Dice=0.7545 | IoU=0.6096 | Acc=0.9977 | Prec=0.7209 | Recall=0.8057
✅ Model saved!


Epoch 12/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.17it/s]
Epoch 12/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.20it/s]
Epoch 12/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.20it/s]



Epoch 12:
 Train Loss=0.2406 | Dice=0.8035 | IoU=0.6726 | Acc=0.9982
 Val   Loss=0.2526 | Dice=0.7779 | IoU=0.6374 | Acc=0.9983 | Prec=0.8112 | Recall=0.7506
 Test  Loss=0.2764 | Dice=0.7310 | IoU=0.5812 | Acc=0.9977 | Prec=0.7679 | Recall=0.7050
EarlyStopping counter: 1 of 5


Epoch 13/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.24it/s]
Epoch 13/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.27it/s]
Epoch 13/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.14it/s]



Epoch 13:
 Train Loss=0.2298 | Dice=0.8132 | IoU=0.6860 | Acc=0.9983
 Val   Loss=0.2264 | Dice=0.8168 | IoU=0.6912 | Acc=0.9985 | Prec=0.8279 | Recall=0.8069
 Test  Loss=0.2448 | Dice=0.7804 | IoU=0.6439 | Acc=0.9980 | Prec=0.7908 | Recall=0.7800
✅ Model saved!


Epoch 14/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.04it/s]
Epoch 14/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.35it/s]
Epoch 14/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.03it/s]



Epoch 14:
 Train Loss=0.2213 | Dice=0.8199 | IoU=0.6960 | Acc=0.9984
 Val   Loss=0.2209 | Dice=0.8165 | IoU=0.6905 | Acc=0.9985 | Prec=0.8289 | Recall=0.8073
 Test  Loss=0.2450 | Dice=0.7687 | IoU=0.6301 | Acc=0.9979 | Prec=0.7685 | Recall=0.7831
EarlyStopping counter: 1 of 5


Epoch 15/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.26it/s]
Epoch 15/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.23it/s]
Epoch 15/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 12.96it/s]



Epoch 15:
 Train Loss=0.2086 | Dice=0.8356 | IoU=0.7185 | Acc=0.9985
 Val   Loss=0.2373 | Dice=0.7751 | IoU=0.6350 | Acc=0.9979 | Prec=0.6758 | Recall=0.9163
 Test  Loss=0.2579 | Dice=0.7346 | IoU=0.5858 | Acc=0.9972 | Prec=0.6350 | Recall=0.8836
EarlyStopping counter: 2 of 5


Epoch 16/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.24it/s]
Epoch 16/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 11.69it/s]
Epoch 16/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.09it/s]



Epoch 16:
 Train Loss=0.1999 | Dice=0.8438 | IoU=0.7308 | Acc=0.9986
 Val   Loss=0.2089 | Dice=0.8200 | IoU=0.6956 | Acc=0.9986 | Prec=0.8623 | Recall=0.7844
 Test  Loss=0.2242 | Dice=0.7898 | IoU=0.6569 | Acc=0.9982 | Prec=0.8342 | Recall=0.7620
✅ Model saved!


Epoch 17/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.27it/s]
Epoch 17/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.20it/s]
Epoch 17/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.33it/s]



Epoch 17:
 Train Loss=0.1910 | Dice=0.8530 | IoU=0.7445 | Acc=0.9987
 Val   Loss=0.2013 | Dice=0.8272 | IoU=0.7063 | Acc=0.9985 | Prec=0.8099 | Recall=0.8471
 Test  Loss=0.2193 | Dice=0.7917 | IoU=0.6586 | Acc=0.9981 | Prec=0.7740 | Recall=0.8254
✅ Model saved!


Epoch 18/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.32it/s]
Epoch 18/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.28it/s]
Epoch 18/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 13.29it/s]



Epoch 18:
 Train Loss=0.1839 | Dice=0.8587 | IoU=0.7533 | Acc=0.9988
 Val   Loss=0.1898 | Dice=0.8444 | IoU=0.7311 | Acc=0.9987 | Prec=0.8340 | Recall=0.8559
 Test  Loss=0.2084 | Dice=0.8078 | IoU=0.6812 | Acc=0.9982 | Prec=0.7970 | Recall=0.8294
✅ Model saved!


Epoch 19/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  5.32it/s]
Epoch 19/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 12.39it/s]
Epoch 19/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 12.78it/s]



Epoch 19:
 Train Loss=0.1728 | Dice=0.8730 | IoU=0.7750 | Acc=0.9989
 Val   Loss=0.1865 | Dice=0.8411 | IoU=0.7265 | Acc=0.9987 | Prec=0.8250 | Recall=0.8605
 Test  Loss=0.2051 | Dice=0.8043 | IoU=0.6760 | Acc=0.9982 | Prec=0.7931 | Recall=0.8321
✅ Model saved!


Epoch 20/20 - Training: 100%|██████████| 24/24 [00:04<00:00,  4.98it/s]
Epoch 20/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 11.98it/s]
Epoch 20/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 12.10it/s]



Epoch 20:
 Train Loss=0.1679 | Dice=0.8752 | IoU=0.7784 | Acc=0.9989
 Val   Loss=0.1847 | Dice=0.8379 | IoU=0.7218 | Acc=0.9987 | Prec=0.8572 | Recall=0.8214
 Test  Loss=0.2020 | Dice=0.8038 | IoU=0.6760 | Acc=0.9983 | Prec=0.8283 | Recall=0.7907
✅ Model saved!


In [None]:
print(f"Best Test Loss: {u_net_test_loss/len(test_loader):.4f} at epoch {epoch + 1 - epochs_no_improve}")
print(f"Best Test Dice: {u_net_test_dice/len(test_loader):.4f}")
print(f"Best Test IoU: {u_net_test_iou/len(test_loader):.4f}")
print(f"Best Test Accuracy: {u_net_test_acc/len(test_loader):.4f}")
print(f"Best Test Precision: {u_net_test_prec/len(test_loader):.4f}")
print(f"Best Test Recall: {u_net_test_rec/len(test_loader):.4f}")

Best Test Loss: 0.2020 at epoch 17
Best Test Dice: 0.8038
Best Test IoU: 0.6760
Best Test Accuracy: 0.9983
Best Test Precision: 0.8283
Best Test Recall: 0.7907


In [57]:
u_net_loss = u_net_test_loss/len(test_loader)
u_net_dice = u_net_test_dice/len(test_loader)
u_net_iou = u_net_test_iou/len(test_loader)

In [25]:
import torchvision
from torchvision.models.segmentation import deeplabv3_resnet50

# -----------------------
# DeepLabV3 Model
# -----------------------
class DeepLabV3(nn.Module):
    def __init__(self, num_classes=1):
        super(DeepLabV3, self).__init__()
        self.model = deeplabv3_resnet50(weights=None, num_classes=1)
        # If you want to use pretrained backbone:
        self.model = deeplabv3_resnet50(weights="DEFAULT")
        self.model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, x):
        out = self.model(x)['out']
        return torch.sigmoid(out)

# -----------------------
# Training DeepLabV3
# -----------------------
deeplab_model = DeepLabV3().to(device)
optimizer = optim.Adam(deeplab_model.parameters(), lr=1e-4)
bce = nn.BCELoss()
num_epochs = 20
deep_lab_best_test_loss = float("inf")
patience = 5
epochs_no_improve = 0

for epoch in range(num_epochs):
    # --- Training ---
    deeplab_model.train()
    deep_lab_train_loss, deep_lab_train_dice, deep_lab_train_iou, deep_lab_train_acc = 0, 0, 0, 0
    for imgs, masks in tqdm(train_loader, desc=f"DeepLabV3 Epoch {epoch+1}/{num_epochs} - Training"):
        imgs, masks = imgs.to(device), masks.to(device)
        preds = deeplab_model(imgs)
        loss = 0.5 * bce(preds, masks) + 0.5 * (1 - dice_score(preds, masks))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        deep_lab_train_loss += loss.item()
        deep_lab_train_dice += dice_score(preds, masks)
        deep_lab_train_iou += iou_score(preds, masks)
        deep_lab_train_acc += pixel_accuracy(preds, masks)

    # --- Validation ---
    deeplab_model.eval()
    deep_lab_val_loss, deep_lab_val_dice, deep_lab_val_iou, deep_lab_val_acc, deep_lab_val_prec, deep_lab_val_rec = 0, 0, 0, 0, 0, 0
    with torch.no_grad():
        for imgs, masks in tqdm(val_loader, desc=f"DeepLabV3 Epoch {epoch+1}/{num_epochs} - Validation"):
            imgs, masks = imgs.to(device), masks.to(device)
            preds = deeplab_model(imgs)
            loss = 0.5 * bce(preds, masks) + 0.5 * (1 - dice_score(preds, masks))
            deep_lab_val_loss += loss.item()
            deep_lab_val_dice += dice_score(preds, masks)
            deep_lab_val_iou += iou_score(preds, masks)
            deep_lab_val_acc += pixel_accuracy(preds, masks)
            deep_lab_val_prec += precision_score(preds, masks)
            deep_lab_val_rec += recall_score(preds, masks)

    # --- Test ---
    deep_lab_test_loss, deep_lab_test_dice, deep_lab_test_iou, deep_lab_test_acc, deep_lab_test_prec, deep_lab_test_rec = 0, 0, 0, 0, 0, 0
    with torch.no_grad():
        for imgs, masks in tqdm(test_loader, desc=f"DeepLabV3 Epoch {epoch+1}/{num_epochs} - Test"):
            imgs, masks = imgs.to(device), masks.to(device)
            preds = deeplab_model(imgs)
            loss = 0.5 * bce(preds, masks) + 0.5 * (1 - dice_score(preds, masks))
            deep_lab_test_loss += loss.item()
            deep_lab_test_dice += dice_score(preds, masks)
            deep_lab_test_iou += iou_score(preds, masks)
            deep_lab_test_acc += pixel_accuracy(preds, masks)
            deep_lab_test_prec += precision_score(preds, masks)
            deep_lab_test_rec += recall_score(preds, masks)

    # --- Print metrics ---
    print(f"\nDeepLabV3 Epoch {epoch+1}:")
    print(f" Train Loss={deep_lab_train_loss/len(train_loader):.4f} | "
          f"Dice={deep_lab_train_dice/len(train_loader):.4f} | IoU={deep_lab_train_iou/len(train_loader):.4f} | Acc={deep_lab_train_acc/len(train_loader):.4f}")
    print(f" Val   Loss={deep_lab_val_loss/len(val_loader):.4f} | "
          f"Dice={deep_lab_val_dice/len(val_loader):.4f} | IoU={deep_lab_val_iou/len(val_loader):.4f} | Acc={deep_lab_val_acc/len(val_loader):.4f} | "
          f"Prec={deep_lab_val_prec/len(val_loader):.4f} | Recall={deep_lab_val_rec/len(val_loader):.4f}")
    print(f" Test  Loss={deep_lab_test_loss/len(test_loader):.4f} | "
          f"Dice={deep_lab_test_dice/len(test_loader):.4f} | IoU={deep_lab_test_iou/len(test_loader):.4f} | Acc={deep_lab_test_acc/len(test_loader):.4f} | "
          f"Prec={deep_lab_test_prec/len(test_loader):.4f} | Recall={deep_lab_test_rec/len(test_loader):.4f}")

    # --- Early stopping ---
    if deep_lab_test_loss < deep_lab_best_test_loss:
        deep_lab_best_test_loss = deep_lab_test_loss
        epochs_no_improve = 0
        torch.save(deeplab_model.state_dict(), "best_deeplabv3.pth")
        print("✅ DeepLabV3 Model saved!")
    else:
        epochs_no_improve += 1
        print(f"EarlyStopping counter: {epochs_no_improve} of {patience}")
        if epochs_no_improve >= patience:
            print("⏹️ Early stopping triggered!")
            break

DeepLabV3 Epoch 1/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.36it/s]
DeepLabV3 Epoch 1/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 15.25it/s]
DeepLabV3 Epoch 1/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 16.07it/s]



DeepLabV3 Epoch 1:
 Train Loss=0.8097 | Dice=0.0137 | IoU=0.0070 | Acc=0.7282
 Val   Loss=0.9446 | Dice=0.0000 | IoU=0.0000 | Acc=0.9692 | Prec=0.0000 | Recall=0.0000
 Test  Loss=0.9199 | Dice=0.0000 | IoU=0.0000 | Acc=0.9721 | Prec=0.0000 | Recall=0.0000
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 2/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.64it/s]
DeepLabV3 Epoch 2/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 15.21it/s]
DeepLabV3 Epoch 2/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.68it/s]



DeepLabV3 Epoch 2:
 Train Loss=0.7111 | Dice=0.0000 | IoU=0.0000 | Acc=0.9921
 Val   Loss=0.7000 | Dice=0.0000 | IoU=0.0000 | Acc=0.9901 | Prec=0.0000 | Recall=0.0000
 Test  Loss=0.6968 | Dice=0.0000 | IoU=0.0000 | Acc=0.9902 | Prec=0.0000 | Recall=0.0000
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 3/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.54it/s]
DeepLabV3 Epoch 3/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 15.49it/s]
DeepLabV3 Epoch 3/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 16.26it/s]



DeepLabV3 Epoch 3:
 Train Loss=0.6694 | Dice=0.0000 | IoU=0.0000 | Acc=0.9942
 Val   Loss=0.6884 | Dice=0.0000 | IoU=0.0000 | Acc=0.9924 | Prec=0.0000 | Recall=0.0000
 Test  Loss=0.7018 | Dice=0.0000 | IoU=0.0000 | Acc=0.9919 | Prec=0.0000 | Recall=0.0000
EarlyStopping counter: 1 of 5


DeepLabV3 Epoch 4/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.49it/s]
DeepLabV3 Epoch 4/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.60it/s]
DeepLabV3 Epoch 4/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.41it/s]



DeepLabV3 Epoch 4:
 Train Loss=0.6484 | Dice=0.0000 | IoU=0.0000 | Acc=0.9948
 Val   Loss=0.6421 | Dice=0.0000 | IoU=0.0000 | Acc=0.9943 | Prec=0.0000 | Recall=0.0000
 Test  Loss=0.6444 | Dice=0.0000 | IoU=0.0000 | Acc=0.9935 | Prec=0.0000 | Recall=0.0000
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 5/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.43it/s]
DeepLabV3 Epoch 5/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.57it/s]
DeepLabV3 Epoch 5/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.64it/s]



DeepLabV3 Epoch 5:
 Train Loss=0.6325 | Dice=0.0000 | IoU=0.0000 | Acc=0.9950
 Val   Loss=0.6312 | Dice=0.0000 | IoU=0.0000 | Acc=0.9952 | Prec=0.0000 | Recall=0.0000
 Test  Loss=0.6315 | Dice=0.0000 | IoU=0.0000 | Acc=0.9945 | Prec=0.0000 | Recall=0.0000
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 6/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.62it/s]
DeepLabV3 Epoch 6/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 15.19it/s]
DeepLabV3 Epoch 6/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.87it/s]



DeepLabV3 Epoch 6:
 Train Loss=0.5976 | Dice=0.0399 | IoU=0.0224 | Acc=0.9953
 Val   Loss=0.5641 | Dice=0.1158 | IoU=0.0670 | Acc=0.9921 | Prec=0.1093 | Recall=0.1334
 Test  Loss=0.5826 | Dice=0.0890 | IoU=0.0542 | Acc=0.9912 | Prec=0.1001 | Recall=0.1112
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 7/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.44it/s]
DeepLabV3 Epoch 7/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.83it/s]
DeepLabV3 Epoch 7/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.68it/s]



DeepLabV3 Epoch 7:
 Train Loss=0.4555 | Dice=0.3015 | IoU=0.1897 | Acc=0.9957
 Val   Loss=0.4848 | Dice=0.2574 | IoU=0.1638 | Acc=0.9940 | Prec=0.3133 | Recall=0.2516
 Test  Loss=0.5050 | Dice=0.2311 | IoU=0.1506 | Acc=0.9937 | Prec=0.3838 | Recall=0.2206
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 8/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.47it/s]
DeepLabV3 Epoch 8/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 15.01it/s]
DeepLabV3 Epoch 8/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.65it/s]



DeepLabV3 Epoch 8:
 Train Loss=0.3667 | Dice=0.4572 | IoU=0.3096 | Acc=0.9961
 Val   Loss=0.3528 | Dice=0.4747 | IoU=0.3281 | Acc=0.9966 | Prec=0.6331 | Recall=0.4115
 Test  Loss=0.3869 | Dice=0.4086 | IoU=0.2847 | Acc=0.9961 | Prec=0.6585 | Recall=0.3574
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 9/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.42it/s]
DeepLabV3 Epoch 9/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.48it/s]
DeepLabV3 Epoch 9/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.32it/s]



DeepLabV3 Epoch 9:
 Train Loss=0.3029 | Dice=0.5667 | IoU=0.4019 | Acc=0.9965
 Val   Loss=0.3534 | Dice=0.4608 | IoU=0.3213 | Acc=0.9964 | Prec=0.7018 | Recall=0.4072
 Test  Loss=0.3658 | Dice=0.4380 | IoU=0.3097 | Acc=0.9957 | Prec=0.6134 | Recall=0.3974
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 10/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.43it/s]
DeepLabV3 Epoch 10/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.90it/s]
DeepLabV3 Epoch 10/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.63it/s]



DeepLabV3 Epoch 10:
 Train Loss=0.2617 | Dice=0.6307 | IoU=0.4650 | Acc=0.9969
 Val   Loss=0.2606 | Dice=0.6217 | IoU=0.4639 | Acc=0.9970 | Prec=0.6314 | Recall=0.6252
 Test  Loss=0.2873 | Dice=0.5699 | IoU=0.4164 | Acc=0.9962 | Prec=0.5892 | Recall=0.5928
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 11/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.66it/s]
DeepLabV3 Epoch 11/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.34it/s]
DeepLabV3 Epoch 11/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 16.09it/s]



DeepLabV3 Epoch 11:
 Train Loss=0.2265 | Dice=0.6865 | IoU=0.5256 | Acc=0.9973
 Val   Loss=0.2615 | Dice=0.6055 | IoU=0.4553 | Acc=0.9972 | Prec=0.6906 | Recall=0.5581
 Test  Loss=0.2965 | Dice=0.5372 | IoU=0.3934 | Acc=0.9966 | Prec=0.6711 | Recall=0.4943
EarlyStopping counter: 1 of 5


DeepLabV3 Epoch 12/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.58it/s]
DeepLabV3 Epoch 12/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 15.20it/s]
DeepLabV3 Epoch 12/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 16.16it/s]



DeepLabV3 Epoch 12:
 Train Loss=0.2090 | Dice=0.7093 | IoU=0.5518 | Acc=0.9975
 Val   Loss=0.2306 | Dice=0.6623 | IoU=0.5051 | Acc=0.9969 | Prec=0.5870 | Recall=0.7686
 Test  Loss=0.2550 | Dice=0.6127 | IoU=0.4587 | Acc=0.9964 | Prec=0.5704 | Recall=0.6958
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 13/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.55it/s]
DeepLabV3 Epoch 13/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 15.33it/s]
DeepLabV3 Epoch 13/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 16.21it/s]



DeepLabV3 Epoch 13:
 Train Loss=0.1959 | Dice=0.7243 | IoU=0.5688 | Acc=0.9976
 Val   Loss=0.2352 | Dice=0.6434 | IoU=0.4942 | Acc=0.9972 | Prec=0.6450 | Recall=0.6613
 Test  Loss=0.2672 | Dice=0.5817 | IoU=0.4352 | Acc=0.9963 | Prec=0.6086 | Recall=0.6111
EarlyStopping counter: 1 of 5


DeepLabV3 Epoch 14/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.41it/s]
DeepLabV3 Epoch 14/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.84it/s]
DeepLabV3 Epoch 14/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.49it/s]



DeepLabV3 Epoch 14:
 Train Loss=0.1783 | Dice=0.7493 | IoU=0.6012 | Acc=0.9978
 Val   Loss=0.2281 | Dice=0.6469 | IoU=0.4931 | Acc=0.9972 | Prec=0.6983 | Recall=0.6129
 Test  Loss=0.2626 | Dice=0.5809 | IoU=0.4376 | Acc=0.9965 | Prec=0.6800 | Recall=0.5494
EarlyStopping counter: 2 of 5


DeepLabV3 Epoch 15/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.47it/s]
DeepLabV3 Epoch 15/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.91it/s]
DeepLabV3 Epoch 15/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 16.04it/s]



DeepLabV3 Epoch 15:
 Train Loss=0.1702 | Dice=0.7568 | IoU=0.6112 | Acc=0.9980
 Val   Loss=0.1764 | Dice=0.7391 | IoU=0.5919 | Acc=0.9976 | Prec=0.6590 | Recall=0.8463
 Test  Loss=0.2000 | Dice=0.6927 | IoU=0.5397 | Acc=0.9970 | Prec=0.6337 | Recall=0.7820
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 16/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.53it/s]
DeepLabV3 Epoch 16/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 15.08it/s]
DeepLabV3 Epoch 16/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.93it/s]



DeepLabV3 Epoch 16:
 Train Loss=0.1501 | Dice=0.7895 | IoU=0.6538 | Acc=0.9982
 Val   Loss=0.2030 | Dice=0.6814 | IoU=0.5284 | Acc=0.9976 | Prec=0.7485 | Recall=0.6359
 Test  Loss=0.2348 | Dice=0.6195 | IoU=0.4771 | Acc=0.9969 | Prec=0.6867 | Recall=0.5981
EarlyStopping counter: 1 of 5


DeepLabV3 Epoch 17/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.49it/s]
DeepLabV3 Epoch 17/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.24it/s]
DeepLabV3 Epoch 17/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.20it/s]



DeepLabV3 Epoch 17:
 Train Loss=0.1434 | Dice=0.7958 | IoU=0.6617 | Acc=0.9983
 Val   Loss=0.1868 | Dice=0.7052 | IoU=0.5558 | Acc=0.9978 | Prec=0.7604 | Recall=0.6624
 Test  Loss=0.2293 | Dice=0.6212 | IoU=0.4790 | Acc=0.9970 | Prec=0.6972 | Recall=0.5950
EarlyStopping counter: 2 of 5


DeepLabV3 Epoch 18/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.31it/s]
DeepLabV3 Epoch 18/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.67it/s]
DeepLabV3 Epoch 18/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 16.04it/s]



DeepLabV3 Epoch 18:
 Train Loss=0.1403 | Dice=0.7963 | IoU=0.6626 | Acc=0.9983
 Val   Loss=0.1980 | Dice=0.6825 | IoU=0.5345 | Acc=0.9974 | Prec=0.6919 | Recall=0.6827
 Test  Loss=0.2250 | Dice=0.6322 | IoU=0.4905 | Acc=0.9967 | Prec=0.6693 | Recall=0.6423
EarlyStopping counter: 3 of 5


DeepLabV3 Epoch 19/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.60it/s]
DeepLabV3 Epoch 19/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 14.53it/s]
DeepLabV3 Epoch 19/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.91it/s]



DeepLabV3 Epoch 19:
 Train Loss=0.1366 | Dice=0.7983 | IoU=0.6660 | Acc=0.9983
 Val   Loss=0.1558 | Dice=0.7560 | IoU=0.6137 | Acc=0.9979 | Prec=0.7152 | Recall=0.8070
 Test  Loss=0.1818 | Dice=0.7058 | IoU=0.5563 | Acc=0.9973 | Prec=0.6817 | Recall=0.7500
✅ DeepLabV3 Model saved!


DeepLabV3 Epoch 20/20 - Training: 100%|██████████| 24/24 [00:03<00:00,  6.61it/s]
DeepLabV3 Epoch 20/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 15.26it/s]
DeepLabV3 Epoch 20/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 15.85it/s]


DeepLabV3 Epoch 20:
 Train Loss=0.1326 | Dice=0.8019 | IoU=0.6702 | Acc=0.9983
 Val   Loss=0.1730 | Dice=0.7204 | IoU=0.5691 | Acc=0.9977 | Prec=0.7072 | Recall=0.7386
 Test  Loss=0.2017 | Dice=0.6656 | IoU=0.5171 | Acc=0.9970 | Prec=0.6669 | Recall=0.6928
EarlyStopping counter: 1 of 5





In [26]:
print(f"Best Test Loss: {deep_lab_test_loss/len(test_loader):.4f} at epoch {epoch + 1 - epochs_no_improve}")
print(f"Best Test Dice: {deep_lab_test_dice/len(test_loader):.4f}")
print(f"Best Test IoU: {deep_lab_test_iou/len(test_loader):.4f}")
print(f"Best Test Accuracy: {deep_lab_test_acc/len(test_loader):.4f}")
print(f"Best Test Precision: {deep_lab_test_prec/len(test_loader):.4f}")
print(f"Best Test Recall: {deep_lab_test_rec/len(test_loader):.4f}")

Best Test Loss: 0.2017 at epoch 19
Best Test Dice: 0.6656
Best Test IoU: 0.5171
Best Test Accuracy: 0.9970
Best Test Precision: 0.6669
Best Test Recall: 0.6928


In [58]:
deep_lab_loss = deep_lab_test_loss/len(test_loader)
deep_lab_dice = deep_lab_test_dice/len(test_loader)
deep_lab_iou = deep_lab_test_iou/len(test_loader)

In [31]:
import torchvision
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# -----------------------
# FPN Model
# -----------------------
class FPNHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(FPNHead, self).__init__()
        self.conv = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.out = nn.Conv2d(128, num_classes, 1)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.out(x)
        return x

class FPNNet(nn.Module):
    def __init__(self, num_classes=1):
        super(FPNNet, self).__init__()
        # Use a pre-trained ResNet50 FPN backbone
        self.backbone = resnet_fpn_backbone('resnet50', pretrained=True)
        self.head = FPNHead(256, num_classes)  # 256 is the FPN out_channels

    def forward(self, x):
        features = self.backbone(x)
        # Use the highest resolution FPN output ('0')
        x = features['0']
        x = self.head(x)
        x = nn.functional.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
        return torch.sigmoid(x)

# -----------------------
# Training FPN
# -----------------------
fpn_model = FPNNet().to(device)
optimizer = optim.Adam(fpn_model.parameters(), lr=1e-4)
bce = nn.BCELoss()
num_epochs = 20
best_test_loss = float("inf")
patience = 5
epochs_no_improve = 0

for epoch in range(num_epochs):
    # --- Training ---
    fpn_model.train()
    fpn_train_loss, fpn_train_dice, fpn_train_iou, fpn_train_acc = 0, 0, 0, 0
    for imgs, masks in tqdm(train_loader, desc=f"FPN Epoch {epoch+1}/{num_epochs} - Training"):
        imgs, masks = imgs.to(device), masks.to(device)
        preds = fpn_model(imgs)
        loss = 0.5 * bce(preds, masks) + 0.5 * (1 - dice_score(preds, masks))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        fpn_train_loss += loss.item()
        fpn_train_dice += dice_score(preds, masks)
        fpn_train_iou += iou_score(preds, masks)
        fpn_train_acc += pixel_accuracy(preds, masks)

    # --- Validation ---
    fpn_model.eval()
    fpn_val_loss, fpn_val_dice, fpn_val_iou, fpn_val_acc, fpn_val_prec, fpn_val_rec = 0, 0, 0, 0, 0, 0
    with torch.no_grad():
        for imgs, masks in tqdm(val_loader, desc=f"FPN Epoch {epoch+1}/{num_epochs} - Validation"):
            imgs, masks = imgs.to(device), masks.to(device)
            preds = fpn_model(imgs)
            loss = 0.5 * bce(preds, masks) + 0.5 * (1 - dice_score(preds, masks))
            fpn_val_loss += loss.item()
            fpn_val_dice += dice_score(preds, masks)
            fpn_val_iou += iou_score(preds, masks)
            fpn_val_acc += pixel_accuracy(preds, masks)
            fpn_val_prec += precision_score(preds, masks)
            fpn_val_rec += recall_score(preds, masks)

    # --- Test ---
    fpn_test_loss, fpn_test_dice, fpn_test_iou, fpn_test_acc, fpn_test_prec, fpn_test_rec = 0, 0, 0, 0, 0, 0
    with torch.no_grad():
        for imgs, masks in tqdm(test_loader, desc=f"FPN Epoch {epoch+1}/{num_epochs} - Test"):
            imgs, masks = imgs.to(device), masks.to(device)
            preds = fpn_model(imgs)
            loss = 0.5 * bce(preds, masks) + 0.5 * (1 - dice_score(preds, masks))
            fpn_test_loss += loss.item()
            fpn_test_dice += dice_score(preds, masks)
            fpn_test_iou += iou_score(preds, masks)
            fpn_test_acc += pixel_accuracy(preds, masks)
            fpn_test_prec += precision_score(preds, masks)
            fpn_test_rec += recall_score(preds, masks)

    # --- Print metrics ---
    print(f"\nFPN Epoch {epoch+1}:")
    print(f" Train Loss={fpn_train_loss/len(train_loader):.4f} | "
          f"Dice={fpn_train_dice/len(train_loader):.4f} | IoU={fpn_train_iou/len(train_loader):.4f} | Acc={fpn_train_acc/len(train_loader):.4f}")
    print(f" Val   Loss={fpn_val_loss/len(val_loader):.4f} | "
          f"Dice={fpn_val_dice/len(val_loader):.4f} | IoU={fpn_val_iou/len(val_loader):.4f} | Acc={fpn_val_acc/len(val_loader):.4f} | "
          f"Prec={fpn_val_prec/len(val_loader):.4f} | Recall={fpn_val_rec/len(val_loader):.4f}")
    print(f" Test  Loss={fpn_test_loss/len(test_loader):.4f} | "
          f"Dice={fpn_test_dice/len(test_loader):.4f} | IoU={fpn_test_iou/len(test_loader):.4f} | Acc={fpn_test_acc/len(test_loader):.4f} | "
          f"Prec={fpn_test_prec/len(test_loader):.4f} | Recall={fpn_test_rec/len(test_loader):.4f}")

    # --- Early stopping ---
    if fpn_test_loss < best_test_loss:
        best_test_loss = fpn_test_loss
        epochs_no_improve = 0
        torch.save(fpn_model.state_dict(), "best_fpn.pth")
        print("✅ FPN Model saved!")
    else:
        epochs_no_improve += 1
        print(f"EarlyStopping counter: {epochs_no_improve} of {patience}")
        if epochs_no_improve >= patience:
            print("⏹️ Early stopping triggered!")
            break

FPN Epoch 1/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 14.59it/s]
FPN Epoch 1/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 22.70it/s]
FPN Epoch 1/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 25.01it/s]



FPN Epoch 1:
 Train Loss=0.5302 | Dice=0.0001 | IoU=0.0000 | Acc=0.9739
 Val   Loss=0.5070 | Dice=0.0000 | IoU=0.0000 | Acc=0.9958 | Prec=0.0000 | Recall=0.0000
 Test  Loss=0.5074 | Dice=0.0000 | IoU=0.0000 | Acc=0.9954 | Prec=0.0000 | Recall=0.0000
✅ FPN Model saved!


FPN Epoch 2/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 14.66it/s]
FPN Epoch 2/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.85it/s]
FPN Epoch 2/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.32it/s]



FPN Epoch 2:
 Train Loss=0.4533 | Dice=0.1060 | IoU=0.0623 | Acc=0.9959
 Val   Loss=0.3656 | Dice=0.2788 | IoU=0.1679 | Acc=0.9964 | Prec=0.8312 | Recall=0.1757
 Test  Loss=0.3983 | Dice=0.2144 | IoU=0.1307 | Acc=0.9960 | Prec=0.7160 | Recall=0.1398
✅ FPN Model saved!


FPN Epoch 3/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 13.84it/s]
FPN Epoch 3/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.85it/s]
FPN Epoch 3/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 18.67it/s]



FPN Epoch 3:
 Train Loss=0.3305 | Dice=0.3482 | IoU=0.2198 | Acc=0.9965
 Val   Loss=0.2799 | Dice=0.4483 | IoU=0.2926 | Acc=0.9968 | Prec=0.8008 | Recall=0.3224
 Test  Loss=0.3134 | Dice=0.3826 | IoU=0.2526 | Acc=0.9963 | Prec=0.7183 | Recall=0.2832
✅ FPN Model saved!


FPN Epoch 4/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 14.29it/s]
FPN Epoch 4/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 14.29it/s]
FPN Epoch 4/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.99it/s]
FPN Epoch 4/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.99it/s]
FPN Epoch 4/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.99it/s]
FPN Epoch 4/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.99it/s]



FPN Epoch 4:
 Train Loss=0.2522 | Dice=0.5031 | IoU=0.3462 | Acc=0.9969
 Val   Loss=0.3960 | Dice=0.2163 | IoU=0.1274 | Acc=0.9964 | Prec=0.8820 | Recall=0.1294
 Test  Loss=0.3605 | Dice=0.2891 | IoU=0.1877 | Acc=0.9962 | Prec=0.8136 | Recall=0.1953
EarlyStopping counter: 1 of 5


FPN Epoch 5/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 14.92it/s]
FPN Epoch 5/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 14.92it/s]
FPN Epoch 5/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 22.74it/s]
FPN Epoch 5/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 22.74it/s]
FPN Epoch 5/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.32it/s]




FPN Epoch 5:
 Train Loss=0.2152 | Dice=0.5763 | IoU=0.4173 | Acc=0.9972
 Val   Loss=0.2340 | Dice=0.5380 | IoU=0.3726 | Acc=0.9973 | Prec=0.8755 | Recall=0.3967
 Test  Loss=0.2468 | Dice=0.5140 | IoU=0.3619 | Acc=0.9969 | Prec=0.8144 | Recall=0.3944
✅ FPN Model saved!


FPN Epoch 6/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.47it/s]
FPN Epoch 6/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.47it/s]
FPN Epoch 6/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 21.15it/s]
FPN Epoch 6/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 21.15it/s]
FPN Epoch 6/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.79it/s]




FPN Epoch 6:
 Train Loss=0.1620 | Dice=0.6816 | IoU=0.5214 | Acc=0.9977
 Val   Loss=0.1751 | Dice=0.6551 | IoU=0.4923 | Acc=0.9977 | Prec=0.8161 | Recall=0.5614
 Test  Loss=0.1887 | Dice=0.6297 | IoU=0.4719 | Acc=0.9972 | Prec=0.7466 | Recall=0.5668
✅ FPN Model saved!


FPN Epoch 7/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.06it/s]
FPN Epoch 7/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.06it/s]
FPN Epoch 7/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.33it/s]
FPN Epoch 7/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.33it/s]
FPN Epoch 7/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 23.78it/s]
FPN Epoch 7/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 23.78it/s]



FPN Epoch 7:
 Train Loss=0.1391 | Dice=0.7267 | IoU=0.5746 | Acc=0.9979
 Val   Loss=0.1923 | Dice=0.6206 | IoU=0.4551 | Acc=0.9977 | Prec=0.9181 | Recall=0.4756
 Test  Loss=0.2085 | Dice=0.5901 | IoU=0.4335 | Acc=0.9973 | Prec=0.8748 | Recall=0.4693
EarlyStopping counter: 1 of 5


FPN Epoch 8/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.36it/s]
FPN Epoch 8/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.36it/s]
FPN Epoch 8/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.93it/s]
FPN Epoch 8/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.93it/s]
FPN Epoch 8/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 25.39it/s]




FPN Epoch 8:
 Train Loss=0.1176 | Dice=0.7691 | IoU=0.6261 | Acc=0.9982
 Val   Loss=0.1358 | Dice=0.7329 | IoU=0.5827 | Acc=0.9981 | Prec=0.8533 | Recall=0.6484
 Test  Loss=0.1645 | Dice=0.6774 | IoU=0.5226 | Acc=0.9976 | Prec=0.8124 | Recall=0.5983
✅ FPN Model saved!


FPN Epoch 9/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.62it/s]
FPN Epoch 9/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.62it/s]
FPN Epoch 9/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 24.82it/s]
FPN Epoch 9/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 24.82it/s]
FPN Epoch 9/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 25.68it/s]




FPN Epoch 9:
 Train Loss=0.1157 | Dice=0.7727 | IoU=0.6333 | Acc=0.9983
 Val   Loss=0.1130 | Dice=0.7783 | IoU=0.6396 | Acc=0.9983 | Prec=0.8077 | Recall=0.7560
 Test  Loss=0.1398 | Dice=0.7267 | IoU=0.5765 | Acc=0.9977 | Prec=0.7660 | Recall=0.7136
✅ FPN Model saved!


FPN Epoch 10/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.76it/s]
FPN Epoch 10/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.76it/s]
FPN Epoch 10/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 24.58it/s]
FPN Epoch 10/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 24.58it/s]
FPN Epoch 10/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 23.24it/s]
FPN Epoch 10/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 23.24it/s]



FPN Epoch 10:
 Train Loss=0.0961 | Dice=0.8114 | IoU=0.6834 | Acc=0.9985
 Val   Loss=0.1328 | Dice=0.7387 | IoU=0.5904 | Acc=0.9982 | Prec=0.8982 | Recall=0.6351
 Test  Loss=0.1565 | Dice=0.6935 | IoU=0.5395 | Acc=0.9977 | Prec=0.8534 | Recall=0.6030
EarlyStopping counter: 1 of 5


FPN Epoch 11/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.19it/s]
FPN Epoch 11/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.19it/s]
FPN Epoch 11/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.54it/s]
FPN Epoch 11/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.54it/s]
FPN Epoch 11/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 25.02it/s]




FPN Epoch 11:
 Train Loss=0.0952 | Dice=0.8131 | IoU=0.6863 | Acc=0.9985
 Val   Loss=0.1059 | Dice=0.7920 | IoU=0.6568 | Acc=0.9984 | Prec=0.8326 | Recall=0.7565
 Test  Loss=0.1290 | Dice=0.7480 | IoU=0.6012 | Acc=0.9979 | Prec=0.8001 | Recall=0.7131
✅ FPN Model saved!


FPN Epoch 12/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.15it/s]
FPN Epoch 12/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.15it/s]
FPN Epoch 12/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.32it/s]
FPN Epoch 12/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.32it/s]
FPN Epoch 12/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 25.27it/s]
FPN Epoch 12/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 25.27it/s]



FPN Epoch 12:
 Train Loss=0.0877 | Dice=0.8279 | IoU=0.7070 | Acc=0.9986
 Val   Loss=0.1171 | Dice=0.7698 | IoU=0.6279 | Acc=0.9984 | Prec=0.8898 | Recall=0.6817
 Test  Loss=0.1370 | Dice=0.7320 | IoU=0.5832 | Acc=0.9980 | Prec=0.8633 | Recall=0.6494
EarlyStopping counter: 1 of 5


FPN Epoch 13/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.24it/s]
FPN Epoch 13/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.24it/s]
FPN Epoch 13/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 21.50it/s]
FPN Epoch 13/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 21.50it/s]
FPN Epoch 13/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.84it/s]




FPN Epoch 13:
 Train Loss=0.0824 | Dice=0.8383 | IoU=0.7220 | Acc=0.9987
 Val   Loss=0.0987 | Dice=0.8062 | IoU=0.6781 | Acc=0.9986 | Prec=0.8592 | Recall=0.7606
 Test  Loss=0.1254 | Dice=0.7553 | IoU=0.6112 | Acc=0.9980 | Prec=0.8204 | Recall=0.7102
✅ FPN Model saved!


FPN Epoch 14/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.40it/s]
FPN Epoch 14/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.40it/s]
FPN Epoch 14/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 24.17it/s]
FPN Epoch 14/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 24.17it/s]
FPN Epoch 14/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 22.38it/s]
FPN Epoch 14/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 22.38it/s]



FPN Epoch 14:
 Train Loss=0.0779 | Dice=0.8472 | IoU=0.7352 | Acc=0.9987
 Val   Loss=0.1067 | Dice=0.7904 | IoU=0.6555 | Acc=0.9984 | Prec=0.8612 | Recall=0.7341
 Test  Loss=0.1258 | Dice=0.7545 | IoU=0.6098 | Acc=0.9980 | Prec=0.8345 | Recall=0.6935
EarlyStopping counter: 1 of 5


FPN Epoch 15/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.28it/s]
FPN Epoch 15/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.28it/s]
FPN Epoch 15/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.76it/s]
FPN Epoch 15/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.76it/s]
FPN Epoch 15/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.42it/s]
FPN Epoch 15/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.42it/s]



FPN Epoch 15:
 Train Loss=0.0753 | Dice=0.8522 | IoU=0.7428 | Acc=0.9988
 Val   Loss=0.1074 | Dice=0.7892 | IoU=0.6536 | Acc=0.9985 | Prec=0.9014 | Recall=0.7030
 Test  Loss=0.1329 | Dice=0.7408 | IoU=0.5939 | Acc=0.9980 | Prec=0.8792 | Recall=0.6494
EarlyStopping counter: 2 of 5


FPN Epoch 16/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.08it/s]
FPN Epoch 16/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.08it/s]
FPN Epoch 16/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.87it/s]
FPN Epoch 16/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.87it/s]
FPN Epoch 16/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.77it/s]




FPN Epoch 16:
 Train Loss=0.0743 | Dice=0.8542 | IoU=0.7461 | Acc=0.9988
 Val   Loss=0.1025 | Dice=0.7987 | IoU=0.6675 | Acc=0.9985 | Prec=0.8759 | Recall=0.7357
 Test  Loss=0.1249 | Dice=0.7569 | IoU=0.6138 | Acc=0.9980 | Prec=0.8396 | Recall=0.6999
✅ FPN Model saved!


FPN Epoch 17/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 14.92it/s]
FPN Epoch 17/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 14.92it/s]
FPN Epoch 17/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 20.70it/s]
FPN Epoch 17/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 20.70it/s]
FPN Epoch 17/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 25.26it/s]




FPN Epoch 17:
 Train Loss=0.0755 | Dice=0.8518 | IoU=0.7423 | Acc=0.9988
 Val   Loss=0.0953 | Dice=0.8131 | IoU=0.6881 | Acc=0.9985 | Prec=0.7803 | Recall=0.8516
 Test  Loss=0.1120 | Dice=0.7818 | IoU=0.6452 | Acc=0.9980 | Prec=0.7623 | Recall=0.8106
✅ FPN Model saved!


FPN Epoch 18/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.37it/s]
FPN Epoch 18/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.37it/s]
FPN Epoch 18/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 24.30it/s]
FPN Epoch 18/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 24.30it/s]
FPN Epoch 18/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 25.45it/s]
FPN Epoch 18/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 25.45it/s]



FPN Epoch 18:
 Train Loss=0.0784 | Dice=0.8462 | IoU=0.7341 | Acc=0.9987
 Val   Loss=0.0933 | Dice=0.8170 | IoU=0.6932 | Acc=0.9985 | Prec=0.7874 | Recall=0.8507
 Test  Loss=0.1160 | Dice=0.7741 | IoU=0.6357 | Acc=0.9980 | Prec=0.7675 | Recall=0.7972
EarlyStopping counter: 1 of 5


FPN Epoch 19/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.55it/s]
FPN Epoch 19/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.55it/s]
FPN Epoch 19/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.74it/s]
FPN Epoch 19/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 23.74it/s]
FPN Epoch 19/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 21.53it/s]
FPN Epoch 19/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 21.53it/s]



FPN Epoch 19:
 Train Loss=0.0705 | Dice=0.8618 | IoU=0.7574 | Acc=0.9989
 Val   Loss=0.1002 | Dice=0.8032 | IoU=0.6728 | Acc=0.9986 | Prec=0.9065 | Recall=0.7224
 Test  Loss=0.1290 | Dice=0.7482 | IoU=0.6031 | Acc=0.9981 | Prec=0.8856 | Recall=0.6607
EarlyStopping counter: 2 of 5


FPN Epoch 20/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.37it/s]
FPN Epoch 20/20 - Training: 100%|██████████| 24/24 [00:01<00:00, 15.37it/s]
FPN Epoch 20/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 22.85it/s]
FPN Epoch 20/20 - Validation: 100%|██████████| 6/6 [00:00<00:00, 22.85it/s]
FPN Epoch 20/20 - Test: 100%|██████████| 8/8 [00:00<00:00, 24.10it/s]


FPN Epoch 20:
 Train Loss=0.0699 | Dice=0.8627 | IoU=0.7599 | Acc=0.9989
 Val   Loss=0.0970 | Dice=0.8097 | IoU=0.6824 | Acc=0.9986 | Prec=0.8946 | Recall=0.7418
 Test  Loss=0.1249 | Dice=0.7570 | IoU=0.6149 | Acc=0.9981 | Prec=0.8767 | Recall=0.6792
EarlyStopping counter: 3 of 5





In [32]:
print(f"Best Test Loss: {fpn_test_loss/len(test_loader):.4f} at epoch {epoch + 1 - epochs_no_improve}")
print(f"Best Test Dice: {fpn_test_dice/len(test_loader):.4f}")
print(f"Best Test IoU: {fpn_test_iou/len(test_loader):.4f}")
print(f"Best Test Accuracy: {fpn_test_acc/len(test_loader):.4f}")
print(f"Best Test Precision: {fpn_test_prec/len(test_loader):.4f}")
print(f"Best Test Recall: {fpn_test_rec/len(test_loader):.4f}")

Best Test Loss: 0.1249 at epoch 17
Best Test Dice: 0.7570
Best Test IoU: 0.6149
Best Test Accuracy: 0.9981
Best Test Precision: 0.8767
Best Test Recall: 0.6792


In [59]:
fpn_loss = fpn_test_loss/len(test_loader)
fpn_dice = fpn_test_dice/len(test_loader)
fpn_iou = fpn_test_iou/len(test_loader)

In [60]:
# Model ranking based on test_dice (higher is better) and test_iou (higher is better)
# Make sure to divide by the number of test samples to get values in [0, 1]
num_test_samples = len(test_loader.dataset)
model_results = [
    {
        "name": "U-Net",
        "test_loss": u_net_loss,
        "test_dice": u_net_dice,
        "test_iou": u_net_iou
    },
    {
        "name": "DeepLabV3",
        "test_loss": deep_lab_loss,
        "test_dice": deep_lab_dice,
        "test_iou": deep_lab_iou
    },
    {
        "name": "FPN",
        "test_loss": fpn_loss,
        "test_dice": fpn_dice,
        "test_iou": fpn_iou
    }
]

# Rank by Dice (higher is better)
dice_ranking = sorted(model_results, key=lambda x: x["test_dice"], reverse=True)
print("Ranking by Dice Score:")
for rank, result in enumerate(dice_ranking, 1):
    print(f"Rank {rank}: {result['name']} (Dice: {result['test_dice']:.4f})")

print("\nRanking by IoU Score:")
# Rank by IoU (higher is better)
iou_ranking = sorted(model_results, key=lambda x: x["test_iou"], reverse=True)
for rank, result in enumerate(iou_ranking, 1):
    print(f"Rank {rank}: {result['name']} (IoU: {result['test_iou']:.4f})")

Ranking by Dice Score:
Rank 1: U-Net (Dice: 0.8038)
Rank 2: FPN (Dice: 0.7570)
Rank 3: DeepLabV3 (Dice: 0.6656)

Ranking by IoU Score:
Rank 1: U-Net (IoU: 0.6760)
Rank 2: FPN (IoU: 0.6149)
Rank 3: DeepLabV3 (IoU: 0.5171)
