In [1]:
import torch
print("PyTorch Version:", torch.__version__)
print("CUDA Available:", torch.cuda.is_available())
print("Number of GPUs:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
print("Done")

PyTorch Version: 2.6.0+cu118
CUDA Available: True
Number of GPUs: 1
GPU Name: NVIDIA GeForce RTX 4090
Done


In [7]:
import os
import sys
import numpy as np
import cv2
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score, confusion_matrix
import matplotlib.pyplot as plt
import time

# Data directories
train_image_dir = "E:/Code/ClassPlusSeg/Final_Dataset/Isic_2018_training_Images_Cleaned_v1"
train_mask_dir = "E:/Code/ClassPlusSeg/Final_Dataset/Isic_2018_training_ground_truth_V1"
val_image_dir = "E:/Code/ClassPlusSeg/Final_Dataset/Isic_2018_validation_Images_Cleaned_v1"
val_mask_dir = "E:/Code/ClassPlusSeg/Final_Dataset/Isic_2018_validation_ground_truth_V1"
test_image_dir = "E:/Code/ClassPlusSeg/Final_Dataset/Isic_2018_test_Images_Cleaned_v1"
test_mask_dir = "E:/Code/ClassPlusSeg/Final_Dataset/Isic_2018_test_ground_truth_V1"

# Global parameters
H = 256
W = 256
batch_size = 16
lr = 1e-4
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SE Block Definition
class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

# BatchNormReLU
class BatchNormReLU(nn.Module):
    def __init__(self, num_features):
        super(BatchNormReLU, self).__init__()
        self.bn = nn.BatchNorm2d(num_features)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        return x

# Residual Block with SE
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, strides=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=strides, padding=1, bias=False)
        self.bn_relu1 = BatchNormReLU(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_relu2 = BatchNormReLU(out_channels)
        self.se = SEBlock(out_channels)

        self.shortcut = nn.Sequential()
        if strides != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=strides, padding=0, bias=False)
            )

    def forward(self, x):
        identity = self.shortcut(x)
        x = self.conv1(x)
        x = self.bn_relu1(x)
        x = self.conv2(x)
        x = self.bn_relu2(x)
        x = self.se(x)
        x = x + identity
        return x

# Decoder Block with SE
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.concat_channels = in_channels + skip_channels
        self.res_block = ResidualBlock(self.concat_channels, out_channels, strides=1)
        self.se = SEBlock(out_channels)

    def forward(self, x, skip):
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)
        x = self.res_block(x)
        x = self.se(x)
        return x

# Dense PPM Bridge
class DensePPMBridge(nn.Module):
    def __init__(self, in_channels=256, out_channels=512):
        super(DensePPMBridge, self).__init__()
        growth_rate = 128

        self.dense1 = nn.Sequential(
            nn.Conv2d(in_channels, growth_rate, kernel_size=3, stride=2, padding=1, bias=False),
            BatchNormReLU(growth_rate)
        )
        self.dense2 = nn.Sequential(
            nn.Conv2d(in_channels + growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNormReLU(growth_rate)
        )

        self.ppm_pool1 = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_channels + 2 * growth_rate, 128, 1, bias=False), BatchNormReLU(128))
        self.ppm_pool2 = nn.Sequential(nn.AdaptiveAvgPool2d(2), nn.Conv2d(in_channels + 2 * growth_rate, 128, 1, bias=False), BatchNormReLU(128))
        self.ppm_pool3 = nn.Sequential(nn.AdaptiveAvgPool2d(3), nn.Conv2d(in_channels + 2 * growth_rate, 128, 1, bias=False), BatchNormReLU(128))
        self.ppm_pool6 = nn.Sequential(nn.AdaptiveAvgPool2d(6), nn.Conv2d(in_channels + 2 * growth_rate, 128, 1, bias=False), BatchNormReLU(128))
        
        self.final_conv = nn.Conv2d(in_channels + 2 * growth_rate + 4 * 128, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        d1 = self.dense1(x)
        d1_cat = torch.cat([F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True), d1], dim=1)
        d2 = self.dense2(d1_cat)
        dense_out = torch.cat([F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True), d1, d2], dim=1)

        ppm1 = self.ppm_pool1(dense_out)
        ppm1 = F.interpolate(ppm1, size=dense_out.size()[2:], mode='bilinear', align_corners=True)
        ppm2 = self.ppm_pool2(dense_out)
        ppm2 = F.interpolate(ppm2, size=dense_out.size()[2:], mode='bilinear', align_corners=True)
        ppm3 = self.ppm_pool3(dense_out)
        ppm3 = F.interpolate(ppm3, size=dense_out.size()[2:], mode='bilinear', align_corners=True)
        ppm6 = self.ppm_pool6(dense_out)
        ppm6 = F.interpolate(ppm6, size=dense_out.size()[2:], mode='bilinear', align_corners=True)

        ppm_out = torch.cat([dense_out, ppm1, ppm2, ppm3, ppm6], dim=1)
        out = self.final_conv(ppm_out)
        return out

# SE-ResUNet Model
class SEResUNet(nn.Module):
    def __init__(self, input_shape=(256, 256, 3)):
        super(SEResUNet, self).__init__()
        self.input_shape = input_shape

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_relu1 = BatchNormReLU(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.se1 = SEBlock(64)
        self.shortcut1 = nn.Conv2d(3, 64, kernel_size=1, stride=1, padding=0, bias=False)

        self.encoder2 = ResidualBlock(64, 128, strides=2)
        self.encoder3 = ResidualBlock(128, 256, strides=2)
        self.bridge = DensePPMBridge(256, 512)

        self.decoder1 = DecoderBlock(512, 256, 256)
        self.decoder2 = DecoderBlock(256, 128, 128)
        self.decoder3 = DecoderBlock(128, 64, 64)

        self.output = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        s1 = self.conv1(x)
        s1 = self.bn_relu1(s1)
        s1 = self.conv2(s1)
        s1 = self.se1(s1)
        shortcut = self.shortcut1(x)
        s1 = s1 + shortcut

        s2 = self.encoder2(s1)
        s3 = self.encoder3(s2)
        b = self.bridge(s3)

        d1 = self.decoder1(b, s3)
        d2 = self.decoder2(d1, s2)
        d3 = self.decoder3(d2, s1)

        out = self.output(d3)
        out = self.sigmoid(out)
        return out

# Metrics and Utility Functions
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-15):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        y_pred = y_pred.view(-1)
        y_true = y_true.view(-1)
        intersection = (y_pred * y_true).sum()
        dice = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)
        return 1.0 - dice

def dice_coef(y_pred, y_true, smooth=1e-15):
    y_pred = y_pred.view(-1)
    y_true = y_true.view(-1)
    intersection = (y_pred * y_true).sum()
    return (2. * intersection + smooth) / (y_pred.sum() + y_true.sum() + smooth)

def iou(y_pred, y_true, smooth=1e-15):
    y_pred = y_pred.view(-1)
    y_true = y_true.view(-1)
    intersection = (y_pred * y_true).sum()
    union = y_true.sum() + y_pred.sum() - intersection
    return (intersection + smooth) / (union + smooth)

def accuracy(y_pred, y_true):
    y_pred = y_pred.view(-1)
    y_true = y_true.view(-1)
    correct = (y_pred == y_true).float().sum()
    total = y_true.numel()
    return correct / total

def create_dir(path):
    try:
        if not os.path.exists(path):
            os.makedirs(path)
            print(f"Created directory: {path}")
        else:
            print(f"Directory already exists: {path}")
    except OSError as e:
        print(f"Error creating directory {path}: {e}")
        raise

def load_data():
    train_x = sorted(glob(os.path.join(train_image_dir, "*.jpg")))
    train_y = sorted(glob(os.path.join(train_mask_dir, "*.png")))
    valid_x = sorted(glob(os.path.join(val_image_dir, "*.jpg")))
    valid_y = sorted(glob(os.path.join(val_mask_dir, "*.png")))
    test_x = sorted(glob(os.path.join(test_image_dir, "*.jpg")))
    test_y = sorted(glob(os.path.join(test_mask_dir, "*.png")))
    return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)

class ISICDataset(Dataset):
    def __init__(self, images, masks):
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = cv2.resize(img, (W, H))
        img = img / 255.0
        img = img.astype(np.float32)
        img = np.transpose(img, (2, 0, 1))

        mask_path = self.masks[idx]
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, (W, H))
        mask = mask / 255.0
        mask = mask.astype(np.float32)
        mask = np.expand_dims(mask, axis=0)
        return torch.tensor(img, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)

# Main Execution
np.random.seed(42)
torch.manual_seed(42)

create_dir("Segmentation_Results")
create_dir("files")

(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data()

print(f"Train: {len(train_x)} - {len(train_y)}")
print(f"Valid: {len(valid_x)} - {len(valid_y)}")
print(f"Test: {len(test_x)} - {len(test_y)}")

train_dataset = ISICDataset(train_x, train_y)
valid_dataset = ISICDataset(valid_x, valid_y)
test_dataset = ISICDataset(test_x, test_y)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Model, loss, optimizer
model = SEResUNet(input_shape=(H, W, 3)).to(device)
criterion = DiceLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
best_val_loss = float('inf')
model_path = os.path.join("files", "segModel.pth")
history = {
    'epoch': [],
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_dice': [],
    'val_iou': [],
    'val_acc': [],
    'train_time': [],
    'val_time': []
}

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    train_start_time = time.time()
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)
    for images, masks in train_bar:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)
        outputs_bin = (outputs > 0.5).float()
        train_acc += accuracy(outputs_bin, masks).item() * images.size(0)
        train_bar.set_postfix({"Train Loss": loss.item(), "Train Acc": train_acc / (train_loss + 1e-10)})
    train_loss /= len(train_loader.dataset)
    train_acc /= len(train_loader.dataset)
    train_time = (time.time() - train_start_time) * 1000 / len(train_loader)  # ms per batch

    model.eval()
    val_loss = 0.0
    val_dice = 0.0
    val_iou = 0.0
    val_acc = 0.0
    val_start_time = time.time()
    val_bar = tqdm(valid_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", leave=False)
    with torch.no_grad():
        for images, masks in val_bar:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item() * images.size(0)
            outputs_bin = (outputs > 0.5).float()
            val_dice += dice_coef(outputs_bin, masks).item() * images.size(0)
            val_iou += iou(outputs_bin, masks).item() * images.size(0)
            val_acc += accuracy(outputs_bin, masks).item() * images.size(0)
            val_bar.set_postfix({"Val Loss": loss.item()})
    
    val_loss /= len(valid_loader.dataset)
    val_dice /= len(valid_loader.dataset)
    val_iou /= len(valid_loader.dataset)
    val_acc /= len(valid_loader.dataset)
    val_time = (time.time() - val_start_time) * 1000 / len(valid_loader)  # ms per batch

    # Save history
    history['epoch'].append(epoch + 1)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)
    history['val_iou'].append(val_iou)
    history['val_acc'].append(val_acc)
    history['train_time'].append(train_time)
    history['val_time'].append(val_time)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, Val IoU: {val_iou:.4f}, Val Acc: {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), model_path)
        print(f"Saved best model at epoch {epoch+1}")


Created directory: Segmentation_Results
Created directory: files
Train: 2594 - 2594
Valid: 100 - 100
Test: 1000 - 1000


                                                                                                                       

Epoch 1/100, Train Loss: 0.3531, Train Acc: 0.8731, Val Loss: 0.2809, Val Dice: 0.7273, Val IoU: 0.5805, Val Acc: 0.8627
Saved best model at epoch 1


                                                                                                                       

Epoch 2/100, Train Loss: 0.2166, Train Acc: 0.9127, Val Loss: 0.2130, Val Dice: 0.7894, Val IoU: 0.6642, Val Acc: 0.8954
Saved best model at epoch 2


                                                                                                                       

Epoch 3/100, Train Loss: 0.1906, Train Acc: 0.9227, Val Loss: 0.2087, Val Dice: 0.7923, Val IoU: 0.6703, Val Acc: 0.9025
Saved best model at epoch 3


                                                                                                                       

Epoch 4/100, Train Loss: 0.1754, Train Acc: 0.9277, Val Loss: 0.1910, Val Dice: 0.8099, Val IoU: 0.6885, Val Acc: 0.9047
Saved best model at epoch 4


                                                                                                                       

Epoch 5/100, Train Loss: 0.1649, Train Acc: 0.9322, Val Loss: 0.1904, Val Dice: 0.8103, Val IoU: 0.6941, Val Acc: 0.9083
Saved best model at epoch 5


                                                                                                                       

Epoch 6/100, Train Loss: 0.1641, Train Acc: 0.9328, Val Loss: 0.1784, Val Dice: 0.8223, Val IoU: 0.7069, Val Acc: 0.9141
Saved best model at epoch 6


                                                                                                                       

Epoch 7/100, Train Loss: 0.1571, Train Acc: 0.9352, Val Loss: 0.1929, Val Dice: 0.8076, Val IoU: 0.6834, Val Acc: 0.9037


                                                                                                                       

Epoch 8/100, Train Loss: 0.1499, Train Acc: 0.9381, Val Loss: 0.1895, Val Dice: 0.8108, Val IoU: 0.6914, Val Acc: 0.9077


                                                                                                                       

Epoch 9/100, Train Loss: 0.1488, Train Acc: 0.9393, Val Loss: 0.1712, Val Dice: 0.8291, Val IoU: 0.7199, Val Acc: 0.9196
Saved best model at epoch 9


                                                                                                                       

Epoch 10/100, Train Loss: 0.1440, Train Acc: 0.9408, Val Loss: 0.1655, Val Dice: 0.8352, Val IoU: 0.7236, Val Acc: 0.9148
Saved best model at epoch 10


                                                                                                                       

Epoch 11/100, Train Loss: 0.1422, Train Acc: 0.9410, Val Loss: 0.1500, Val Dice: 0.8505, Val IoU: 0.7428, Val Acc: 0.9185
Saved best model at epoch 11


                                                                                                                       

Epoch 12/100, Train Loss: 0.1366, Train Acc: 0.9434, Val Loss: 0.1629, Val Dice: 0.8377, Val IoU: 0.7233, Val Acc: 0.9138


                                                                                                                       

Epoch 13/100, Train Loss: 0.1365, Train Acc: 0.9436, Val Loss: 0.1987, Val Dice: 0.8016, Val IoU: 0.6832, Val Acc: 0.9112


                                                                                                                       

Epoch 14/100, Train Loss: 0.1343, Train Acc: 0.9439, Val Loss: 0.1947, Val Dice: 0.8054, Val IoU: 0.6814, Val Acc: 0.9073


                                                                                                                       

Epoch 15/100, Train Loss: 0.1365, Train Acc: 0.9434, Val Loss: 0.1429, Val Dice: 0.8575, Val IoU: 0.7548, Val Acc: 0.9240
Saved best model at epoch 15


                                                                                                                       

Epoch 16/100, Train Loss: 0.1279, Train Acc: 0.9472, Val Loss: 0.1712, Val Dice: 0.8291, Val IoU: 0.7132, Val Acc: 0.9134


                                                                                                                       

Epoch 17/100, Train Loss: 0.1291, Train Acc: 0.9468, Val Loss: 0.1684, Val Dice: 0.8317, Val IoU: 0.7228, Val Acc: 0.9200


                                                                                                                       

Epoch 18/100, Train Loss: 0.1316, Train Acc: 0.9449, Val Loss: 0.1626, Val Dice: 0.8377, Val IoU: 0.7293, Val Acc: 0.9202


                                                                                                                       

Epoch 19/100, Train Loss: 0.1333, Train Acc: 0.9449, Val Loss: 0.1667, Val Dice: 0.8336, Val IoU: 0.7203, Val Acc: 0.9133


                                                                                                                       

Epoch 20/100, Train Loss: 0.1231, Train Acc: 0.9490, Val Loss: 0.1276, Val Dice: 0.8727, Val IoU: 0.7765, Val Acc: 0.9318
Saved best model at epoch 20


                                                                                                                       

Epoch 21/100, Train Loss: 0.1217, Train Acc: 0.9494, Val Loss: 0.1607, Val Dice: 0.8395, Val IoU: 0.7296, Val Acc: 0.9153


                                                                                                                       

Epoch 22/100, Train Loss: 0.1236, Train Acc: 0.9492, Val Loss: 0.1480, Val Dice: 0.8522, Val IoU: 0.7538, Val Acc: 0.9297


                                                                                                                       

Epoch 23/100, Train Loss: 0.1174, Train Acc: 0.9512, Val Loss: 0.1716, Val Dice: 0.8286, Val IoU: 0.7163, Val Acc: 0.9159


                                                                                                                       

Epoch 24/100, Train Loss: 0.1194, Train Acc: 0.9509, Val Loss: 0.1267, Val Dice: 0.8735, Val IoU: 0.7772, Val Acc: 0.9306
Saved best model at epoch 24


                                                                                                                       

Epoch 25/100, Train Loss: 0.1178, Train Acc: 0.9504, Val Loss: 0.1449, Val Dice: 0.8553, Val IoU: 0.7561, Val Acc: 0.9293


                                                                                                                       

Epoch 26/100, Train Loss: 0.1148, Train Acc: 0.9520, Val Loss: 0.1238, Val Dice: 0.8764, Val IoU: 0.7838, Val Acc: 0.9357
Saved best model at epoch 26


                                                                                                                       

Epoch 27/100, Train Loss: 0.1136, Train Acc: 0.9531, Val Loss: 0.1331, Val Dice: 0.8670, Val IoU: 0.7669, Val Acc: 0.9264


                                                                                                                       

Epoch 28/100, Train Loss: 0.1126, Train Acc: 0.9534, Val Loss: 0.1339, Val Dice: 0.8663, Val IoU: 0.7706, Val Acc: 0.9331


                                                                                                                       

Epoch 29/100, Train Loss: 0.1093, Train Acc: 0.9546, Val Loss: 0.1201, Val Dice: 0.8800, Val IoU: 0.7884, Val Acc: 0.9362
Saved best model at epoch 29


                                                                                                                       

Epoch 30/100, Train Loss: 0.1119, Train Acc: 0.9532, Val Loss: 0.1311, Val Dice: 0.8691, Val IoU: 0.7738, Val Acc: 0.9328


                                                                                                                       

Epoch 31/100, Train Loss: 0.1108, Train Acc: 0.9533, Val Loss: 0.1546, Val Dice: 0.8455, Val IoU: 0.7421, Val Acc: 0.9261


                                                                                                                       

Epoch 32/100, Train Loss: 0.1151, Train Acc: 0.9517, Val Loss: 0.1431, Val Dice: 0.8570, Val IoU: 0.7600, Val Acc: 0.9308


                                                                                                                       

Epoch 33/100, Train Loss: 0.1092, Train Acc: 0.9543, Val Loss: 0.1361, Val Dice: 0.8641, Val IoU: 0.7661, Val Acc: 0.9308


                                                                                                                       

Epoch 34/100, Train Loss: 0.1147, Train Acc: 0.9517, Val Loss: 0.1163, Val Dice: 0.8839, Val IoU: 0.7957, Val Acc: 0.9392
Saved best model at epoch 34


                                                                                                                       

Epoch 35/100, Train Loss: 0.1136, Train Acc: 0.9525, Val Loss: 0.1652, Val Dice: 0.8349, Val IoU: 0.7223, Val Acc: 0.9177


                                                                                                                       

Epoch 36/100, Train Loss: 0.1161, Train Acc: 0.9516, Val Loss: 0.1371, Val Dice: 0.8631, Val IoU: 0.7632, Val Acc: 0.9280


                                                                                                                       

Epoch 37/100, Train Loss: 0.1096, Train Acc: 0.9544, Val Loss: 0.1391, Val Dice: 0.8610, Val IoU: 0.7595, Val Acc: 0.9290


                                                                                                                       

Epoch 38/100, Train Loss: 0.1076, Train Acc: 0.9553, Val Loss: 0.1269, Val Dice: 0.8732, Val IoU: 0.7812, Val Acc: 0.9366


                                                                                                                       

Epoch 39/100, Train Loss: 0.1063, Train Acc: 0.9559, Val Loss: 0.1217, Val Dice: 0.8785, Val IoU: 0.7869, Val Acc: 0.9365


                                                                                                                       

Epoch 40/100, Train Loss: 0.1058, Train Acc: 0.9560, Val Loss: 0.1299, Val Dice: 0.8702, Val IoU: 0.7777, Val Acc: 0.9364


                                                                                                                       

Epoch 41/100, Train Loss: 0.1055, Train Acc: 0.9562, Val Loss: 0.1263, Val Dice: 0.8738, Val IoU: 0.7794, Val Acc: 0.9326


                                                                                                                       

Epoch 42/100, Train Loss: 0.1009, Train Acc: 0.9574, Val Loss: 0.1243, Val Dice: 0.8758, Val IoU: 0.7815, Val Acc: 0.9339


                                                                                                                       

Epoch 43/100, Train Loss: 0.1022, Train Acc: 0.9573, Val Loss: 0.1202, Val Dice: 0.8799, Val IoU: 0.7874, Val Acc: 0.9347


                                                                                                                       

Epoch 44/100, Train Loss: 0.1038, Train Acc: 0.9565, Val Loss: 0.1179, Val Dice: 0.8822, Val IoU: 0.7944, Val Acc: 0.9400


                                                                                                                       

Epoch 45/100, Train Loss: 0.1007, Train Acc: 0.9577, Val Loss: 0.1636, Val Dice: 0.8365, Val IoU: 0.7309, Val Acc: 0.9245


                                                                                                                       

Epoch 46/100, Train Loss: 0.1050, Train Acc: 0.9558, Val Loss: 0.1226, Val Dice: 0.8775, Val IoU: 0.7881, Val Acc: 0.9385


                                                                                                                       

Epoch 47/100, Train Loss: 0.1015, Train Acc: 0.9576, Val Loss: 0.1248, Val Dice: 0.8753, Val IoU: 0.7838, Val Acc: 0.9363


                                                                                                                       

Epoch 48/100, Train Loss: 0.0977, Train Acc: 0.9586, Val Loss: 0.1105, Val Dice: 0.8895, Val IoU: 0.8037, Val Acc: 0.9417
Saved best model at epoch 48


                                                                                                                       

Epoch 49/100, Train Loss: 0.0985, Train Acc: 0.9585, Val Loss: 0.1337, Val Dice: 0.8664, Val IoU: 0.7663, Val Acc: 0.9256


                                                                                                                       

Epoch 50/100, Train Loss: 0.0988, Train Acc: 0.9586, Val Loss: 0.1288, Val Dice: 0.8713, Val IoU: 0.7768, Val Acc: 0.9338


                                                                                                                       

Epoch 51/100, Train Loss: 0.1010, Train Acc: 0.9577, Val Loss: 0.1242, Val Dice: 0.8759, Val IoU: 0.7850, Val Acc: 0.9374


                                                                                                                       

Epoch 52/100, Train Loss: 0.0976, Train Acc: 0.9590, Val Loss: 0.1119, Val Dice: 0.8882, Val IoU: 0.8011, Val Acc: 0.9407


                                                                                                                       

Epoch 53/100, Train Loss: 0.0972, Train Acc: 0.9590, Val Loss: 0.1099, Val Dice: 0.8903, Val IoU: 0.8046, Val Acc: 0.9422
Saved best model at epoch 53


                                                                                                                       

Epoch 54/100, Train Loss: 0.0950, Train Acc: 0.9599, Val Loss: 0.1058, Val Dice: 0.8942, Val IoU: 0.8102, Val Acc: 0.9434
Saved best model at epoch 54


                                                                                                                       

Epoch 55/100, Train Loss: 0.0950, Train Acc: 0.9597, Val Loss: 0.1166, Val Dice: 0.8835, Val IoU: 0.7933, Val Acc: 0.9373


                                                                                                                       

Epoch 56/100, Train Loss: 0.0914, Train Acc: 0.9612, Val Loss: 0.1096, Val Dice: 0.8904, Val IoU: 0.8052, Val Acc: 0.9410


                                                                                                                       

Epoch 57/100, Train Loss: 0.0987, Train Acc: 0.9589, Val Loss: 0.1666, Val Dice: 0.8335, Val IoU: 0.7163, Val Acc: 0.8978


                                                                                                                       

Epoch 58/100, Train Loss: 0.1011, Train Acc: 0.9573, Val Loss: 0.1214, Val Dice: 0.8787, Val IoU: 0.7862, Val Acc: 0.9331


                                                                                                                       

Epoch 59/100, Train Loss: 0.0958, Train Acc: 0.9601, Val Loss: 0.1214, Val Dice: 0.8787, Val IoU: 0.7860, Val Acc: 0.9353


                                                                                                                       

Epoch 60/100, Train Loss: 0.0919, Train Acc: 0.9616, Val Loss: 0.1187, Val Dice: 0.8814, Val IoU: 0.7909, Val Acc: 0.9354


                                                                                                                       

Epoch 61/100, Train Loss: 0.0886, Train Acc: 0.9627, Val Loss: 0.1232, Val Dice: 0.8768, Val IoU: 0.7874, Val Acc: 0.9387


                                                                                                                       

Epoch 62/100, Train Loss: 0.0949, Train Acc: 0.9604, Val Loss: 0.1160, Val Dice: 0.8841, Val IoU: 0.7948, Val Acc: 0.9375


                                                                                                                       

Epoch 63/100, Train Loss: 0.0910, Train Acc: 0.9619, Val Loss: 0.1174, Val Dice: 0.8827, Val IoU: 0.7937, Val Acc: 0.9376


                                                                                                                       

Epoch 64/100, Train Loss: 0.0918, Train Acc: 0.9611, Val Loss: 0.1224, Val Dice: 0.8777, Val IoU: 0.7849, Val Acc: 0.9339


                                                                                                                       

Epoch 65/100, Train Loss: 0.0934, Train Acc: 0.9607, Val Loss: 0.1114, Val Dice: 0.8887, Val IoU: 0.8034, Val Acc: 0.9429


                                                                                                                       

Epoch 66/100, Train Loss: 0.0937, Train Acc: 0.9604, Val Loss: 0.1896, Val Dice: 0.8104, Val IoU: 0.6933, Val Acc: 0.9150


                                                                                                                       

Epoch 67/100, Train Loss: 0.0980, Train Acc: 0.9588, Val Loss: 0.1533, Val Dice: 0.8467, Val IoU: 0.7416, Val Acc: 0.9264


                                                                                                                       

Epoch 68/100, Train Loss: 0.0991, Train Acc: 0.9594, Val Loss: 0.1097, Val Dice: 0.8904, Val IoU: 0.8063, Val Acc: 0.9422


                                                                                                                       

Epoch 69/100, Train Loss: 0.0880, Train Acc: 0.9629, Val Loss: 0.1086, Val Dice: 0.8915, Val IoU: 0.8074, Val Acc: 0.9438


                                                                                                                       

Epoch 70/100, Train Loss: 0.0839, Train Acc: 0.9643, Val Loss: 0.1174, Val Dice: 0.8827, Val IoU: 0.7948, Val Acc: 0.9388


                                                                                                                       

Epoch 71/100, Train Loss: 0.0845, Train Acc: 0.9643, Val Loss: 0.1177, Val Dice: 0.8824, Val IoU: 0.7943, Val Acc: 0.9383


                                                                                                                       

Epoch 72/100, Train Loss: 0.0878, Train Acc: 0.9637, Val Loss: 0.1229, Val Dice: 0.8771, Val IoU: 0.7861, Val Acc: 0.9361


                                                                                                                       

Epoch 73/100, Train Loss: 0.0915, Train Acc: 0.9617, Val Loss: 0.1110, Val Dice: 0.8890, Val IoU: 0.8024, Val Acc: 0.9404


                                                                                                                       

Epoch 74/100, Train Loss: 0.0913, Train Acc: 0.9613, Val Loss: 0.1186, Val Dice: 0.8814, Val IoU: 0.7921, Val Acc: 0.9370


                                                                                                                       

Epoch 75/100, Train Loss: 0.0922, Train Acc: 0.9610, Val Loss: 0.1235, Val Dice: 0.8766, Val IoU: 0.7825, Val Acc: 0.9331


                                                                                                                       

Epoch 76/100, Train Loss: 0.0933, Train Acc: 0.9609, Val Loss: 0.1107, Val Dice: 0.8893, Val IoU: 0.8041, Val Acc: 0.9417


                                                                                                                       

Epoch 77/100, Train Loss: 0.0831, Train Acc: 0.9646, Val Loss: 0.1261, Val Dice: 0.8739, Val IoU: 0.7779, Val Acc: 0.9304


                                                                                                                       

Epoch 78/100, Train Loss: 0.0867, Train Acc: 0.9636, Val Loss: 0.1177, Val Dice: 0.8824, Val IoU: 0.7955, Val Acc: 0.9397


                                                                                                                       

Epoch 79/100, Train Loss: 0.0829, Train Acc: 0.9651, Val Loss: 0.1083, Val Dice: 0.8918, Val IoU: 0.8087, Val Acc: 0.9440


                                                                                                                       

Epoch 80/100, Train Loss: 0.0835, Train Acc: 0.9645, Val Loss: 0.1504, Val Dice: 0.8497, Val IoU: 0.7462, Val Acc: 0.9284


                                                                                                                       

Epoch 81/100, Train Loss: 0.0878, Train Acc: 0.9632, Val Loss: 0.1146, Val Dice: 0.8855, Val IoU: 0.7976, Val Acc: 0.9390


                                                                                                                       

Epoch 82/100, Train Loss: 0.0869, Train Acc: 0.9631, Val Loss: 0.1122, Val Dice: 0.8878, Val IoU: 0.8030, Val Acc: 0.9409


                                                                                                                       

Epoch 83/100, Train Loss: 0.0832, Train Acc: 0.9648, Val Loss: 0.1171, Val Dice: 0.8830, Val IoU: 0.7948, Val Acc: 0.9392


                                                                                                                       

Epoch 84/100, Train Loss: 0.0834, Train Acc: 0.9647, Val Loss: 0.1315, Val Dice: 0.8685, Val IoU: 0.7713, Val Acc: 0.9338


                                                                                                                       

Epoch 85/100, Train Loss: 0.0881, Train Acc: 0.9628, Val Loss: 0.1211, Val Dice: 0.8790, Val IoU: 0.7898, Val Acc: 0.9390


                                                                                                                       

Epoch 86/100, Train Loss: 0.0819, Train Acc: 0.9655, Val Loss: 0.1260, Val Dice: 0.8740, Val IoU: 0.7790, Val Acc: 0.9320


                                                                                                                       

Epoch 87/100, Train Loss: 0.0807, Train Acc: 0.9654, Val Loss: 0.1156, Val Dice: 0.8844, Val IoU: 0.7967, Val Acc: 0.9396


                                                                                                                       

Epoch 88/100, Train Loss: 0.0818, Train Acc: 0.9654, Val Loss: 0.1075, Val Dice: 0.8925, Val IoU: 0.8097, Val Acc: 0.9425


                                                                                                                       

Epoch 89/100, Train Loss: 0.0804, Train Acc: 0.9662, Val Loss: 0.1607, Val Dice: 0.8394, Val IoU: 0.7366, Val Acc: 0.9262


                                                                                                                       

Epoch 90/100, Train Loss: 0.0884, Train Acc: 0.9629, Val Loss: 0.1181, Val Dice: 0.8819, Val IoU: 0.7917, Val Acc: 0.9369


                                                                                                                       

Epoch 91/100, Train Loss: 0.0824, Train Acc: 0.9651, Val Loss: 0.1308, Val Dice: 0.8693, Val IoU: 0.7710, Val Acc: 0.9251


                                                                                                                       

Epoch 92/100, Train Loss: 0.0879, Train Acc: 0.9632, Val Loss: 0.1094, Val Dice: 0.8907, Val IoU: 0.8073, Val Acc: 0.9417


                                                                                                                       

Epoch 93/100, Train Loss: 0.0803, Train Acc: 0.9659, Val Loss: 0.1143, Val Dice: 0.8857, Val IoU: 0.7982, Val Acc: 0.9387


                                                                                                                       

Epoch 94/100, Train Loss: 0.0774, Train Acc: 0.9673, Val Loss: 0.1217, Val Dice: 0.8784, Val IoU: 0.7870, Val Acc: 0.9360


                                                                                                                       

Epoch 95/100, Train Loss: 0.1065, Train Acc: 0.9552, Val Loss: 0.1365, Val Dice: 0.8635, Val IoU: 0.7648, Val Acc: 0.9299


                                                                                                                       

Epoch 96/100, Train Loss: 0.0897, Train Acc: 0.9623, Val Loss: 0.1275, Val Dice: 0.8725, Val IoU: 0.7773, Val Acc: 0.9332


                                                                                                                       

Epoch 97/100, Train Loss: 0.0856, Train Acc: 0.9640, Val Loss: 0.1186, Val Dice: 0.8814, Val IoU: 0.7923, Val Acc: 0.9390


                                                                                                                       

Epoch 98/100, Train Loss: 0.0820, Train Acc: 0.9656, Val Loss: 0.1066, Val Dice: 0.8934, Val IoU: 0.8094, Val Acc: 0.9433


                                                                                                                       

Epoch 99/100, Train Loss: 0.0788, Train Acc: 0.9667, Val Loss: 0.1201, Val Dice: 0.8800, Val IoU: 0.7907, Val Acc: 0.9398


                                                                                                                       

Epoch 100/100, Train Loss: 0.0806, Train Acc: 0.9663, Val Loss: 0.1109, Val Dice: 0.8891, Val IoU: 0.8047, Val Acc: 0.9420


                                                                                                                       

ValueError: Classification metrics can't handle a mix of continuous and binary targets

In [9]:
import os
import numpy as np
import torch
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pandas as pd

# Evaluate on test dataset
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
test_loss = 0.0
test_dice = 0.0
test_iou = 0.0
test_acc = 0.0
test_predictions = []
test_labels = []
test_start_time = time.time()

with torch.no_grad():
    test_bar = tqdm(test_loader, desc="Testing", leave=False)
    for images, masks in test_bar:
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)
        test_loss += loss.item() * images.size(0)
        outputs_bin = (outputs > 0.5).float()  # Binarize predictions
        test_dice += dice_coef(outputs_bin, masks).item() * images.size(0)
        test_iou += iou(outputs_bin, masks).item() * images.size(0)
        test_acc += accuracy(outputs_bin, masks).item() * images.size(0)
        test_predictions.append(outputs_bin.cpu().numpy().flatten().astype(np.int32))  # Convert to binary integers
        test_labels.append(masks.cpu().numpy().flatten().astype(np.int32))  # Convert to binary integers
        test_bar.set_postfix({"Test Loss": loss.item()})

test_loss /= len(test_loader.dataset)
test_dice /= len(test_loader.dataset)
test_iou /= len(test_loader.dataset)
test_acc /= len(test_loader.dataset)
test_time = (time.time() - test_start_time) * 1000 / len(test_loader)  # ms per batch

# Compute confusion matrix
test_predictions = np.concatenate(test_predictions)
test_labels = np.concatenate(test_labels)
cm = confusion_matrix(test_labels, test_predictions, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

print("\nTest Metrics:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Dice: {test_dice:.4f}")
print(f"Test IoU: {test_iou:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print("\nConfusion Matrix:")
print(f"True Negatives: {tn}")
print(f"False Positives: {fp}")
print(f"False Negatives: {fn}")
print(f"True Positives: {tp}")

# Generate plots
# Plot 1: Training and Validation Loss
plt.figure(figsize=(10, 6))
plt.plot(history['epoch'], history['train_loss'], label='Training Loss')
plt.plot(history['epoch'], history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join("Segmentation_Results", "loss_plot.png"))
plt.close()

# Plot 2: Validation Dice and IoU
plt.figure(figsize=(10, 6))
plt.plot(history['epoch'], history['val_dice'], label='Validation Dice')
plt.plot(history['epoch'], history['val_iou'], label='Validation IoU')
plt.xlabel('Epoch')
plt.ylabel('Score')
plt.title('Validation Dice and IoU Over Epochs')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join("Segmentation_Results", "dice_iou_plot.png"))
plt.close()

# Plot 3: Training and Validation Accuracy
plt.figure(figsize=(10, 6))
plt.plot(history['epoch'], history['train_acc'], label='Training Accuracy')
plt.plot(history['epoch'], history['val_acc'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy Over Epochs')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join("Segmentation_Results", "accuracy_plot.png"))
plt.close()

# Plot 4: Sample Test Predictions
num_samples = 3
fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
axes[0, 0].set_title("Input Image")
axes[0, 1].set_title("Ground Truth")
axes[0, 2].set_title("Predicted Mask")

with torch.no_grad():
    for i, (images, masks) in enumerate(test_loader):
        if i >= num_samples:
            break
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        outputs_bin = (outputs > 0.5).float()
        
        img = images[0].cpu().numpy().transpose(1, 2, 0) * 255.0
        mask = masks[0][0].cpu().numpy() * 255.0
        pred = outputs_bin[0][0].cpu().numpy() * 255.0

        axes[i, 0].imshow(img.astype(np.uint8))
        axes[i, 0].axis('off')
        axes[i, 1].imshow(mask, cmap='gray')
        axes[i, 1].axis('off')
        axes[i, 2].imshow(pred, cmap='gray')
        axes[i, 2].axis('off')

plt.tight_layout()
plt.savefig(os.path.join("Segmentation_Results", "test_predictions.png"))
plt.close()

# Generate metrics table
metrics_table = []
for epoch in range(num_epochs):
    test_acc_epoch = test_acc  # Use final test accuracy for all epochs
    val_acc_epoch = history['val_acc'][epoch]
    train_acc_epoch = history['train_acc'][epoch]
    abs_variance = abs(test_acc_epoch - val_acc_epoch)
    time_per_step = (history['train_time'][epoch] + history['val_time'][epoch]) / 2  # Average of train and val
    metrics_table.append({
        'Epoch': epoch + 1,
        'Test Accuracy (%)': test_acc_epoch * 100,
        'Validation Accuracy (%)': val_acc_epoch * 100,
        'Training Accuracy (%)': train_acc_epoch * 100,
        'Absolute Variance': abs_variance * 100,
        'Time per Step (ms/step)': time_per_step
    })

metrics_df = pd.DataFrame(metrics_table)
metrics_df.to_csv(os.path.join("Segmentation_Results", "metrics_table.csv"), index=False)
print("\nMetrics table saved to Segmentation_Results/metrics_table.csv")

                                                                                                                       


Test Metrics:
Test Loss: 0.1408
Test Dice: 0.8593
Test IoU: 0.7589
Test Accuracy: 0.9220

Confusion Matrix:
True Negatives: 44956796
False Positives: 2292415
False Negatives: 2772513
True Positives: 15514276

Metrics table saved to Segmentation_Results/metrics_table.csv


In [11]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from glob import glob
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F

# Data directories
train_image_dir = "E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Training_Input_Cleaned_Images"
test_image_dir = "E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Test_Input_Cleaned_Images"
train_pseudo_mask_dir = "E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Training_Pseudo_Masks"
test_pseudo_mask_dir = "E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Test_Pseudo_Masks"
model_path = "files/segModel.pth"

# Global parameters
H = 256
W = 256
batch_size = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SE Block Definition
class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

# BatchNormReLU
class BatchNormReLU(nn.Module):
    def __init__(self, num_features):
        super(BatchNormReLU, self).__init__()
        self.bn = nn.BatchNorm2d(num_features)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        return x

# Residual Block with SE
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, strides=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=strides, padding=1, bias=False)
        self.bn_relu1 = BatchNormReLU(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_relu2 = BatchNormReLU(out_channels)
        self.se = SEBlock(out_channels)

        self.shortcut = nn.Sequential()
        if strides != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=strides, padding=0, bias=False)
            )

    def forward(self, x):
        identity = self.shortcut(x)
        x = self.conv1(x)
        x = self.bn_relu1(x)
        x = self.conv2(x)
        x = self.bn_relu2(x)
        x = self.se(x)
        x = x + identity
        return x

# Decoder Block with SE
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.concat_channels = in_channels + skip_channels
        self.res_block = ResidualBlock(self.concat_channels, out_channels, strides=1)
        self.se = SEBlock(out_channels)

    def forward(self, x, skip):
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)
        x = self.res_block(x)
        x = self.se(x)
        return x

# Dense PPM Bridge
class DensePPMBridge(nn.Module):
    def __init__(self, in_channels=256, out_channels=512):
        super(DensePPMBridge, self).__init__()
        growth_rate = 128

        self.dense1 = nn.Sequential(
            nn.Conv2d(in_channels, growth_rate, kernel_size=3, stride=2, padding=1, bias=False),
            BatchNormReLU(growth_rate)
        )
        self.dense2 = nn.Sequential(
            nn.Conv2d(in_channels + growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNormReLU(growth_rate)
        )

        self.ppm_pool1 = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_channels + 2 * growth_rate, 128, 1, bias=False), BatchNormReLU(128))
        self.ppm_pool2 = nn.Sequential(nn.AdaptiveAvgPool2d(2), nn.Conv2d(in_channels + 2 * growth_rate, 128, 1, bias=False), BatchNormReLU(128))
        self.ppm_pool3 = nn.Sequential(nn.AdaptiveAvgPool2d(3), nn.Conv2d(in_channels + 2 * growth_rate, 128, 1, bias=False), BatchNormReLU(128))
        self.ppm_pool6 = nn.Sequential(nn.AdaptiveAvgPool2d(6), nn.Conv2d(in_channels + 2 * growth_rate, 128, 1, bias=False), BatchNormReLU(128))
        
        self.final_conv = nn.Conv2d(in_channels + 2 * growth_rate + 4 * 128, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        d1 = self.dense1(x)
        d1_cat = torch.cat([F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True), d1], dim=1)
        d2 = self.dense2(d1_cat)
        dense_out = torch.cat([F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True), d1, d2], dim=1)

        ppm1 = self.ppm_pool1(dense_out)
        ppm1 = F.interpolate(ppm1, size=dense_out.size()[2:], mode='bilinear', align_corners=True)
        ppm2 = self.ppm_pool2(dense_out)
        ppm2 = F.interpolate(ppm2, size=dense_out.size()[2:], mode='bilinear', align_corners=True)
        ppm3 = self.ppm_pool3(dense_out)
        ppm3 = F.interpolate(ppm3, size=dense_out.size()[2:], mode='bilinear', align_corners=True)
        ppm6 = self.ppm_pool6(dense_out)
        ppm6 = F.interpolate(ppm6, size=dense_out.size()[2:], mode='bilinear', align_corners=True)

        ppm_out = torch.cat([dense_out, ppm1, ppm2, ppm3, ppm6], dim=1)
        out = self.final_conv(ppm_out)
        return out

# SE-ResUNet Model
class SEResUNet(nn.Module):
    def __init__(self, input_shape=(256, 256, 3)):
        super(SEResUNet, self).__init__()
        self.input_shape = input_shape

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_relu1 = BatchNormReLU(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.se1 = SEBlock(64)
        self.shortcut1 = nn.Conv2d(3, 64, kernel_size=1, stride=1, padding=0, bias=False)

        self.encoder2 = ResidualBlock(64, 128, strides=2)
        self.encoder3 = ResidualBlock(128, 256, strides=2)
        self.bridge = DensePPMBridge(256, 512)

        self.decoder1 = DecoderBlock(512, 256, 256)
        self.decoder2 = DecoderBlock(256, 128, 128)
        self.decoder3 = DecoderBlock(128, 64, 64)

        self.output = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        s1 = self.conv1(x)
        s1 = self.bn_relu1(s1)
        s1 = self.conv2(s1)
        s1 = self.se1(s1)
        shortcut = self.shortcut1(x)
        s1 = s1 + shortcut

        s2 = self.encoder2(s1)
        s3 = self.encoder3(s2)
        b = self.bridge(s3)

        d1 = self.decoder1(b, s3)
        d2 = self.decoder2(d1, s2)
        d3 = self.decoder3(d2, s1)

        out = self.output(d3)
        out = self.sigmoid(out)
        return out

# Dataset for loading images (no masks needed for pseudo mask generation)
class ISICImageDataset(Dataset):
    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = cv2.resize(img, (W, H))
        img = img / 255.0
        img = img.astype(np.float32)
        img = np.transpose(img, (2, 0, 1))
        return torch.tensor(img, dtype=torch.float32), img_path

# Utility function to create directories
def create_dir(path):
    try:
        if not os.path.exists(path):
            os.makedirs(path)
            print(f"Created directory: {path}")
        else:
            print(f"Directory already exists: {path}")
    except OSError as e:
        print(f"Error creating directory {path}: {e}")
        raise

# Load image paths
def load_image_paths():
    train_images = sorted(glob(os.path.join(train_image_dir, "*.jpg")))
    test_images = sorted(glob(os.path.join(test_image_dir, "*.jpg")))
    return train_images, test_images

# Generate and save pseudo masks
def generate_pseudo_masks(model, image_loader, output_dir):
    model.eval()
    with torch.no_grad():
        for images, img_paths in tqdm(image_loader, desc=f"Generating pseudo masks for {output_dir}"):
            images = images.to(device)
            outputs = model(images)
            outputs_bin = (outputs > 0.5).float()  # Binarize predictions

            for i in range(images.size(0)):
                pred_mask = outputs_bin[i, 0].cpu().numpy() * 255.0  # Scale to 0-255 for saving as image
                pred_mask = pred_mask.astype(np.uint8)
                img_name = os.path.basename(img_paths[i]).replace('.jpg', '.png')
                output_path = os.path.join(output_dir, img_name)
                cv2.imwrite(output_path, pred_mask)

# Main execution
if __name__ == "__main__":
    # Set random seed for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)

    # Create output directories
    create_dir(train_pseudo_mask_dir)
    create_dir(test_pseudo_mask_dir)

    # Load image paths
    train_images, test_images = load_image_paths()
    print(f"Training images: {len(train_images)}")
    print(f"Test images: {len(test_images)}")

    # Create datasets and dataloaders
    train_dataset = ISICImageDataset(train_images)
    test_dataset = ISICImageDataset(test_images)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Load model
    model = SEResUNet(input_shape=(H, W, 3)).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded model from {model_path}")

    # Generate pseudo masks
    print("Generating pseudo masks for training set...")
    generate_pseudo_masks(model, train_loader, train_pseudo_mask_dir)
    print("Generating pseudo masks for test set...")
    generate_pseudo_masks(model, test_loader, test_pseudo_mask_dir)
    print("Pseudo mask generation completed.")

Created directory: E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Training_Pseudo_Masks
Created directory: E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Test_Pseudo_Masks
Training images: 25331
Test images: 8238
Loaded model from files/segModel.pth
Generating pseudo masks for training set...


Generating pseudo masks for E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Training_Pseudo_Masks: 100%|█| 1584/1584 [07:3


Generating pseudo masks for test set...


Generating pseudo masks for E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Test_Pseudo_Masks: 100%|█| 515/515 [02:36<00:0

Pseudo mask generation completed.





In [31]:
import pandas as pd
import numpy as np
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm  # Import tqdm for progress bar

# Paths
train_image_dir = "E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Training_Input_Cleaned_Images"
train_pseudo_mask_dir = "E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Training_Pseudo_Masks"
groundtruth = "E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Training_GroundTruth_Transformed.csv"

# Image size
IMG_SIZE = 224

# Label mapping
label_mapping = {
    'NV': 0, 'MEL': 1, 'BCC': 2, 'BKL': 3,
    'AK': 4, 'SCC': 5, 'VASC': 6, 'DF': 7
}

# Verify paths
for path in [groundtruth, train_image_dir, train_pseudo_mask_dir]:
    if not os.path.exists(path):
        raise FileNotFoundError(f"The path {path} does not exist. Please check and update the path.")

# Custom Dataset for SMOTE
class ISIC2019FeatureDataset(Dataset):
    def __init__(self, dataframe, image_dir, pseudo_mask_dir):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.pseudo_mask_dir = pseudo_mask_dir
        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx]["image"]
        img_path = os.path.join(self.image_dir, f"{img_name}.jpg")
        mask_path = os.path.join(self.pseudo_mask_dir, f"{img_name}.png")
        label_str = self.dataframe.iloc[idx]["diagnosis"]
        label = label_mapping[label_str]

        # Verify image and mask exist
        if not os.path.exists(img_path):
            print(f"Warning: Image {img_path} not found. Skipping.")
            return None
        if not os.path.exists(mask_path):
            print(f"Warning: Mask {mask_path} not found. Skipping.")
            return None

        # Load image and mask
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        # Apply resize transform
        image = self.transform(image)
        mask = transforms.Resize((IMG_SIZE, IMG_SIZE))(mask)
        mask = np.array(mask, dtype=np.float32) / 255.0

        # Flatten image features (RGB + mask)
        image = image.numpy().flatten()
        mask = mask.flatten()
        features = np.concatenate([image, mask])

        return features, label

# Load ground truth
try:
    df = pd.read_csv(groundtruth)
    print("Ground truth loaded successfully. Columns:", df.columns.tolist())
except Exception as e:
    raise Exception(f"Failed to load CSV at {groundtruth}: {str(e)}")

# Verify required columns
required_columns = ["image", "diagnosis"]
if not all(col in df.columns for col in required_columns):
    raise ValueError(f"CSV must contain columns: {required_columns}. Found: {df.columns.tolist()}")

# Create dataset for feature extraction
feature_dataset = ISIC2019FeatureDataset(df, train_image_dir, train_pseudo_mask_dir)
features = []
labels = []

# Extract features and labels with tqdm progress bar
for item in tqdm(feature_dataset, desc="Extracting features"):
    if item is None:  # Skip invalid entries
        continue
    feature, label = item
    features.append(feature)
    labels.append(label)

features = np.array(features)
labels = np.array(labels)

# Apply SMOTE
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(features, labels)

# Create new DataFrame for resampled data
resampled_data = []
for i, (feature, label) in enumerate(zip(X_resampled, y_resampled)):
    img_name = f"synthetic_{i}"
    label_str = [k for k, v in label_mapping.items() if v == label][0]
    resampled_data.append({"image": img_name, "diagnosis": label_str})

resampled_df = pd.DataFrame(resampled_data)

# Save resampled DataFrame
output_path = "E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Training_GroundTruth_SMOTE.csv"
resampled_df.to_csv(output_path, index=False)
print(f"SMOTE-balanced CSV saved to: {output_path}")

# Print class distribution after SMOTE
print("Class distribution after SMOTE:")
print(resampled_df['diagnosis'].value_counts())

Ground truth loaded successfully. Columns: ['image', 'diagnosis']


Extracting features: 100%|███████████████████████████████████████████████████████| 25331/25331 [08:25<00:00, 50.07it/s]


SMOTE-balanced CSV saved to: E:/Code/ClassPlusSeg/Final_Dataset/ISIC_2019_Training_GroundTruth_SMOTE.csv
Class distribution after SMOTE:
diagnosis
NV      12875
MEL     12875
BKL     12875
DF      12875
SCC     12875
BCC     12875
VASC    12875
AK      12875
Name: count, dtype: int64


In [None]:
import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from imblearn.over_sampling import SMOTE
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights, densenet169, DenseNet169_Weights
from tqdm import tqdm
import os
import time

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define SE (Squeeze-and-Excitation) Block
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

# Define CBAM (Convolutional Block Attention Module) + SE Block
class CBAM_SE(nn.Module):
    def __init__(self, channels, reduction=16):
        super(CBAM_SE, self).__init__()
        # Channel Attention (CBAM)
        self.channel_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.channel_max_pool = nn.AdaptiveMaxPool2d(1)
        self.channel_mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False)
        )
        self.channel_sigmoid = nn.Sigmoid()

        # Spatial Attention (CBAM)
        self.spatial_conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        self.spatial_sigmoid = nn.Sigmoid()

        # SE Block
        self.se = SEBlock(channels, reduction)

    def forward(self, x):
        # Channel Attention (CBAM)
        avg_pool = self.channel_avg_pool(x).view(x.size(0), x.size(1))
        max_pool = self.channel_max_pool(x).view(x.size(0), x.size(1))
        avg_out = self.channel_mlp(avg_pool)
        max_out = self.channel_mlp(max_pool)
        channel_att = self.channel_sigmoid(avg_out + max_out).view(x.size(0), x.size(1), 1, 1)
        x = x * channel_att

        # Spatial Attention (CBAM)
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        spatial_input = torch.cat([avg_pool, max_pool], dim=1)
        spatial_att = self.spatial_sigmoid(self.spatial_conv(spatial_input))
        x = x * spatial_att

        # Apply SE Block
        x = self.se(x)
        return x

# Define the Ensemble Model
class EnsembleModel(nn.Module):
    def __init__(self, num_classes=8):
        super(EnsembleModel, self).__init__()

        # Load Pretrained models
        self.efficientnet = efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1)
        self.densenet = densenet169(weights=DenseNet169_Weights.IMAGENET1K_V1)

        # Extract feature extractor part
        self.efficientnet = self.efficientnet.features
        self.densenet = self.densenet.features

        # Pooling & Flattening
        self.pool = nn.AdaptiveMaxPool2d(1)
        self.flatten = nn.Flatten()

        # CBAM + SE Attention after concatenation
        self.cbam_se = CBAM_SE(1792 + 1664)

        # Fully Connected layers
        self.fc = nn.Sequential(
            nn.Linear(1792 + 1664, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        # Feature extraction
        feat_a = self.efficientnet(x)
        feat_b = self.densenet(x)

        # Concatenate features
        combined = torch.cat((feat_a, feat_b), dim=1)

        # Apply CBAM + SE attention
        combined = self.cbam_se(combined)

        # Max Pooling and Flatten
        combined = self.pool(combined)
        combined = self.flatten(combined)

        # Fully connected layers
        out = self.fc(combined)
        return out

# Feature Extractor for SMOTE
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.efficientnet = efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1).features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

    def forward(self, x):
        features = self.efficientnet(x)
        features = self.pool(features)
        features = self.flatten(features)
        return features

# Label mapping for multiclass classification
label_mapping = {
    "MEL": 0,  # Melanoma
    "NV": 1,   # Melanocytic Nevus
    "BCC": 2,  # Basal Cell Carcinoma
    "AK": 3,   # Actinic Keratosis
    "BKL": 4,  # Benign Keratosis
    "DF": 5,   # Dermatofibroma
    "VASC": 6, # Vascular Lesion
    "SCC": 7   # Squamous Cell Carcinoma
}

# Custom Dataset for SMOTE Feature Extraction
class ISIC2019FeatureDataset(Dataset):
    def __init__(self, dataframe, image_dir, transform=None):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx]["image"]
        img_path = os.path.join(self.image_dir, f"{img_name}.jpg")
        label_str = self.dataframe.iloc[idx]["diagnosis"]
        label = label_mapping[label_str]

        if not os.path.exists(img_path):
            print(f"Warning: Image {img_path} not found. Skipping.")
            return None

        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)

        return image, label

# Dataset class for training/validation/test
class SkinLesionDataset(Dataset):
    def __init__(self, dataframe, image_folder, synthetic_image_folder=None, transform=None):
        self.dataframe = dataframe
        self.image_folder = image_folder
        self.synthetic_image_folder = synthetic_image_folder
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx]["image"]
        label_str = self.dataframe.iloc[idx]["diagnosis"]
        label = label_mapping[label_str]

        # Determine image path (synthetic or original)
        if self.synthetic_image_folder and "synthetic" in img_name:
            img_path = os.path.join(self.synthetic_image_folder, f"{img_name}.jpg")
        else:
            img_path = os.path.join(self.image_folder, f"{img_name}.jpg")

        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image {img_path} not found.")

        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

# Load training dataset
train_csv_path = r"E:\Code\ClassPlusSeg\Final_Dataset\ISIC_2019_Training_GroundTruth_Transformed.csv"
train_df = pd.read_csv(train_csv_path)
print("Training dataset loaded. Columns:", train_df.columns.tolist())

# Split training data into train (80%) and validation (20%)
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42, stratify=train_df["diagnosis"])
print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}")

# Load test dataset
test_csv_path = r"E:\Code\ClassPlusSeg\Final_Dataset\ISIC_2019_Testing_GroundTruth_Transformed_Simplified.csv"
test_df = pd.read_csv(test_csv_path)
print(f"Test dataset loaded. Columns: {test_df.columns.tolist()}, Test size: {len(test_df)}")

# Extract features for SMOTE using EfficientNet-B4
image_dir = r"E:\Code\ClassPlusSeg\Final_Dataset\ISIC_2019_Training_Input_Cleaned_Images"
feature_dataset = ISIC2019FeatureDataset(train_df, image_dir)
feature_loader = DataLoader(feature_dataset, batch_size=32, shuffle=False, num_workers=4)

feature_extractor = FeatureExtractor().to(device)
feature_extractor.eval()

features = []
labels = []

with torch.no_grad():
    for images, lbls in tqdm(feature_loader, desc="Extracting features for SMOTE"):
        if images is None:
            continue
        images = images.to(device)
        feats = feature_extractor(images)
        features.append(feats.cpu().numpy())
        labels.append(lbls.numpy())

features = np.vstack(features)
labels = np.concatenate(labels)

# Apply SMOTE
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(features, labels)

# Create synthetic DataFrame (map to original images for simplicity)
resampled_data = []
original_images = train_df["image"].values
original_labels = train_df["diagnosis"].values
label_to_images = {lbl: [] for lbl in label_mapping.keys()}
for img, lbl in zip(original_images, original_labels):
    label_to_images[lbl].append(img)

for i, label in enumerate(tqdm(y_resampled, desc="Creating synthetic DataFrame")):
    label_str = [k for k, v in label_mapping.items() if v == label][0]
    # Randomly select an original image from the same class
    img_name = np.random.choice(label_to_images[label_str])
    resampled_data.append({"image": img_name, "diagnosis": label_str})

train_smote_df = pd.DataFrame(resampled_data)
print("SMOTE applied to training data. New train size:", len(train_smote_df))
print("Training class distribution after SMOTE:")
print(train_smote_df['diagnosis'].value_counts())

# Image Augmentation for training
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/Test Transform (no augmentation)
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create datasets
train_image_folder = r"E:\Code\ClassPlusSeg\Final_Dataset\ISIC_2019_Training_Input_Cleaned_Images"
test_image_folder = r"E:\Code\ClassPlusSeg\Final_Dataset\ISIC_2019_Testing_Input_Cleaned_Images"
synthetic_image_folder = None  # Not using synthetic images
train_dataset = SkinLesionDataset(train_smote_df, train_image_folder, synthetic_image_folder, train_transform)
val_dataset = SkinLesionDataset(val_df, train_image_folder, transform=val_test_transform)
test_dataset = SkinLesionDataset(test_df, test_image_folder, transform=val_test_transform)

# Create DataLoaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
print(f"Train DataLoader size: {len(train_loader.dataset)}, Validation DataLoader size: {len(val_loader.dataset)}, Test DataLoader size: {len(test_loader.dataset)}")

# Initialize model
num_classes = len(label_mapping)
model = EnsembleModel(num_classes=num_classes).to(device)
print("Model initialized and moved to GPU.")

# Loss, Optimizer, and Scheduler
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

# Evaluation Function
def evaluate(model, loader, criterion, device, desc="Evaluating"):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    all_preds, all_labels = [], []
    start_time = time.time()

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    elapsed_time = time.time() - start_time
    ms_per_step = (elapsed_time / len(loader)) * 1000  # ms/step
    return (total_loss / len(loader), 100 * correct / total, ms_per_step, np/array(all_preds), np.array(all_labels))

# Training Function
def train_model(model, train_loader, val_loader, test_loader, criterion, optimizer, scheduler, device, epochs=50, patience=5, save_path='best_ensemble_model.pth'):
    best_val_acc = 0
    epochs_no_improve = 0
    best_model_path = save_path

    for epoch in range(epochs):
        model.train()
        running_loss, correct, total = 0, 0, 0
        start_time = time.time()
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            progress_bar.set_postfix({"Loss": f"{loss.item():.4f}"})

        train_time = time.time() - start_time
        train_ms_per_step = (train_time / len(train_loader)) * 1000  # ms/step
        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total

        # Evaluate on validation and test sets
        val_loss, val_acc, val_ms_per_step, _, _ = evaluate(model, val_loader, criterion, device, desc="Validating")
        test_loss, test_acc, test_ms_per_step, test_preds, test_labels = evaluate(model, test_loader, criterion, device, desc="Testing")

        # Compute absolute variance
        abs_variance = abs(test_acc - val_acc)

        # Update scheduler
        scheduler.step(val_loss)

        # Print metrics
        print(f"\nEpoch {epoch+1}/{epochs}:")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc:.2f}%")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
        print(f"Absolute Variance (Test - Validation Acc): {abs_variance:.2f}%")
        print(f"Time per Step (ms/step) - Train: {train_ms_per_step:.2f}, Val: {val_ms_per_step:.2f}, Test: {test_ms_per_step:.2f}")
        print(f"Learning Rate: {scheduler.get_last_lr()[0]:.6f}")

        # Confusion Matrix
        cm = confusion_matrix(test_labels, test_preds)
        print("Test Confusion Matrix:")
        print(cm)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_no_improve = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, best_model_path)
            print(f"New best model saved with Val Acc: {val_acc:.2f}%")
        else:
            epochs_no_improve += 1

        # Early stopping
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Load and evaluate best model
    checkpoint = torch.load(best_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\nLoaded best model with Val Acc: {checkpoint['val_acc']:.2f}%")

    # Final evaluation
    final_train_loss, final_train_acc, final_train_ms, _, _ = evaluate(model, train_loader, criterion, device, desc="Final Train Eval")
    final_val_loss, final_val_acc, final_val_ms, _, _ = evaluate(model, val_loader, criterion, device, desc="Final Val Eval")
    final_test_loss, final_test_acc, final_test_ms, final_test_preds, final_test_labels = evaluate(model, test_loader, criterion, device, desc="Final Test Eval")
    final_abs_variance = abs(final_test_acc - final_val_acc)
    final_cm = confusion_matrix(final_test_labels, final_test_preds)

    print("\nFinal Metrics (Best Model):")
    print(f"Train Loss: {final_train_loss:.4f}, Train Acc: {final_train_acc:.2f}%")
    print(f"Validation Loss: {final_val_loss:.4f}, Validation Acc: {final_val_acc:.2f}%")
    print(f"Test Loss: {final_test_loss:.4f}, Test Acc: {final_test_acc:.2f}%")
    print(f"Absolute Variance (Test - Validation Acc): {final_abs_variance:.2f}%")
    print(f"Time per Step (ms/step) - Train: {final_train_ms:.2f}, Val: {final_val_ms:.2f}, Test: {final_test_ms:.2f}")
    print("Final Test Confusion Matrix:")
    print(final_cm)

    return model

# Run training
def main():
    model = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=50,
        patience=5,
        save_path='best_ensemble_model_low_memory.pth'
    )
    return model

if __name__ == "__main__":
    main()