In [None]:
# The dataset, checkpoints and prediction outputs used throughout this notebook live on
# Google Drive, so the Drive has to be mounted before any other cell runs.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [None]:
# !pip install -r /content/drive/MyDrive/ML/requirement.txt

In [None]:
# !pip install py-spy

In [None]:
# !python /content/drive/MyDrive/ML/scripts/segmentation_train.py --data_name ISIC --data_dir /content/drive/MyDrive/ML/dataset --out_dir /content/drive/MyDrive/ML/output/Test --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --lr 1e-4 --batch_size 4 --lr_anneal_steps 1

Training

In [None]:
# !python /content/drive/MyDrive/ML/scripts/segmentation_train.py --data_name ISIC --data_dir /content/drive/MyDrive/ML/dataset --out_dir /content/drive/MyDrive/ML/output/Test --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --lr 1e-4 --batch_size 4 --lr_anneal_steps 1

segmentation

In [None]:
# !python /content/drive/MyDrive/ML/scripts/segmentation_sample.py --data_name ISIC --data_dir /content/drive/MyDrive/ML/dataset --out_dir /content/drive/MyDrive/ML/output/Test/segmentation --model_path /content/drive/MyDrive/ML/output/Test/emasavedmodel_0.9999_000002.pt --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 100 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --num_ensemble 5

In [None]:
# !python /content/drive/MyDrive/ML/scripts/segmentation_env.py --inp_pth /content/drive/MyDrive/ML/output/Test/segmentation --out_pth /content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Validation_GroundTruth


In [None]:
# ============================== Imports and Dependencies ==============================

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from zipfile import ZipFile

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, confusion_matrix

# ================================ Separable Convolution =================================

class SeparableConv2d(nn.Module):
    """
    Depthwise-separable 2-D convolution.

    A per-channel spatial convolution (``groups=in_channels``) is followed by a 1x1
    convolution that mixes channels and maps ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, bias=True):
        super(SeparableConv2d, self).__init__()
        # One spatial filter per input channel: groups == in_channels.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        # 1x1 projection to the requested number of output channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

# ================================== ConvLSTM2D ========================================

class ConvLSTMCell(nn.Module):
    """
    A single convolutional LSTM cell.

    The input and the previous hidden state are concatenated along the channel axis and one
    convolution produces the pre-activations of all four gates (input, forget, output,
    candidate), each ``hidden_channels`` wide.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        # padding = kernel_size // 2 keeps H and W unchanged for odd kernel sizes.
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """
        Args:
            input_tensor: (B, input_channels, H, W) input for this time step.
            cur_state: tuple ``(h, c)`` of the previous hidden and cell states.
        Returns:
            tuple ``(h_next, c_next)``.
        """
        h_prev, c_prev = cur_state
        gates = self.conv(torch.cat([input_tensor, h_prev], dim=1))
        in_pre, forget_pre, out_pre, cand_pre = gates.chunk(4, dim=1)

        c_next = torch.sigmoid(forget_pre) * c_prev + torch.sigmoid(in_pre) * torch.tanh(cand_pre)
        h_next = torch.sigmoid(out_pre) * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, spatial_size, device):
        """Return zero-filled ``(h, c)`` of shape (batch_size, hidden_channels, H, W)."""
        height, width = spatial_size
        shape = (batch_size, self.hidden_channels, height, width)
        return tuple(torch.zeros(shape, device=device) for _ in range(2))

class ConvLSTM2D(nn.Module):
    """
    Stack of ConvLSTM cells run over a sequence of feature maps.

    Only the final hidden state of the last layer is returned (no per-step outputs).
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, bias=True, num_layers=1):
        super(ConvLSTM2D, self).__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels

        # The first layer consumes the raw input; deeper layers consume the previous layer's h.
        self.layers = nn.ModuleList(
            ConvLSTMCell(input_channels if depth == 0 else hidden_channels,
                         hidden_channels, kernel_size, bias)
            for depth in range(num_layers)
        )

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: (batch, seq_len, channels, height, width).
        Returns:
            Hidden state of the last layer after the last time step,
            shape (batch, hidden_channels, height, width).
        """
        batch_size, seq_len, _, height, width = input_tensor.size()
        # One (h, c) pair per layer, all starting from zeros.
        states = [
            layer.init_hidden(batch_size, (height, width), input_tensor.device)
            for layer in self.layers
        ]

        for step in range(seq_len):
            feed = input_tensor[:, step]
            for depth, layer in enumerate(self.layers):
                states[depth] = layer(feed, states[depth])
                feed = states[depth][0]  # this layer's h feeds the next layer
        return feed

# ============================== Swin Transformer Blocks ================================

class WindowAttention(nn.Module):
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        """
        Args:
            dim (int): Number of input channels. Must be divisible by ``num_heads``.
            window_size (tuple): Height and width of the window.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            attn_drop (float): Dropout ratio of attention weights.
            proj_drop (float): Dropout ratio of output.
        Raises:
            ValueError: If ``dim`` is not divisible by ``num_heads``.
        """
        super(WindowAttention, self).__init__()
        # Fail at construction instead of with an opaque reshape error in forward().
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # One learnable bias per (relative offset, head): (2*Wh-1) * (2*Ww-1) offsets.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        # Pair-wise relative position index for each token inside the window.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # indexing="ij" is the legacy default made explicit (avoids the meshgrid UserWarning).
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift offsets to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1)  # row-major flattening of (dh, dw)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # Query, Key, Value
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Initialize relative position bias table
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, Wh*Ww, C), windows ordered
               batch-major (all windows of sample 0, then sample 1, ...).
            mask: (num_windows, Wh*Ww, Wh*Ww) additive attention mask, or None.
        Returns:
            Tensor of shape (num_windows*B, Wh*Ww, C).
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B_, nH, N, C//nH)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B_, nH, N, N)

        # Add relative position bias
        n_tokens = self.window_size[0] * self.window_size[1]
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(n_tokens, n_tokens, -1)  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window mask over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # (B_, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer block: (shifted-)window multi-head self attention followed by an MLP,
    each wrapped in a pre-LayerNorm residual connection.

    With ``shift_size > 0`` the block performs SW-MSA: the (padded) map is cyclically shifted
    before window partitioning, and an additive attention mask keeps tokens that were wrapped
    around the border from attending to tokens that are not their spatial neighbours.
    With ``shift_size == 0`` it performs plain W-MSA.
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 attn_drop=0., proj_drop=0.):
        """
        Args:
            dim (int): Number of channels C of the token features.
            num_heads (int): Attention heads (must divide ``dim``).
            window_size (int): Side length of the square attention window.
            shift_size (int): Cyclic shift before partitioning; 0 disables shifting.
            mlp_ratio (float): MLP hidden width as a multiple of ``dim``.
            qkv_bias (bool): Whether the q/k/v projection has a bias.
            attn_drop (float): Dropout on attention weights.
            proj_drop (float): Dropout after the attention projection and at the end of the MLP.
        """
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size  # W
        self.shift_size = shift_size    # S
        self.mlp_ratio = mlp_ratio

        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, (window_size, window_size), num_heads, qkv_bias, attn_drop, proj_drop)

        self.drop_path = nn.Identity()  # Can implement stochastic depth if desired
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(proj_drop)
        )

        # Kept for backward compatibility; forward() reads shift_size directly.
        self.shift_partition = self.shift_size > 0

    def _shifted_window_mask(self, Hp, Wp, device, dtype=torch.float32):
        """
        Build the SW-MSA attention mask for a padded (Hp, Wp) map.

        Each position is labelled by which of the (up to) 3x3 regions it belonged to before the
        cyclic shift. Token pairs inside a window with different labels get -100 (effectively
        zero weight after softmax); pairs with the same label get 0.

        Returns:
            Tensor of shape (num_windows, ws*ws, ws*ws), windows in row-major order.
        """
        ws, ss = self.window_size, self.shift_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=device, dtype=dtype)
        region_slices = (slice(0, -ws), slice(-ws, -ss), slice(-ss, None))
        label = 0
        for h in region_slices:
            for w in region_slices:
                img_mask[:, h, w, :] = label
                label += 1
        # Same window partition as forward(): (1, Hp/ws, ws, Wp/ws, ws, 1) -> (nW, ws*ws)
        mask_windows = img_mask.view(1, Hp // ws, ws, Wp // ws, ws, 1)
        mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws * ws)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

    def forward(self, x, hw=None):
        """
        Args:
            x: input features with shape (B, H*W, C).
            hw: optional ``(H, W)`` of the token grid. When omitted the grid is assumed to be
                square (H = W = sqrt(H*W)), which is what existing callers rely on.
        Returns:
            Tensor with the same shape as ``x``.
        """
        B, L, C = x.shape
        if hw is None:
            H = W = int(round(np.sqrt(L)))
        else:
            H, W = hw
        assert L == H * W, "Input feature has wrong size"
        ws = self.window_size

        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # Pad H and W up to multiples of the window size *before* shifting, so the cyclic
        # shift and the attention mask both operate on the same padded grid.
        pad_b = (ws - H % ws) % ws
        pad_r = (ws - W % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape

        # Cyclic shift (+ mask so wrapped-around tokens do not mix)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = self._shifted_window_mask(Hp, Wp, x.device, x.dtype)
        else:
            shifted_x = x
            attn_mask = None

        # Window partition -> (num_windows*B, ws*ws, C), batch-major as WindowAttention expects
        x_windows = shifted_x.view(B, Hp // ws, ws, Wp // ws, ws, C)
        x_windows = x_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws * ws, C)

        attn_windows = self.attn(x_windows, mask=attn_mask)

        # Merge windows back into a (B, Hp, Wp, C) map
        shifted_x = attn_windows.view(B, Hp // ws, Wp // ws, ws, ws, C)
        shifted_x = shifted_x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, C)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # Remove padding and return to token layout
        x = x[:, :H, :W, :].contiguous().view(B, H * W, C)

        # Residuals: add the merged, un-shifted attention output (not the raw window tensor)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

# =============================== Dice Loss Function ====================================

class DiceLoss(nn.Module):
    """
    Soft Dice loss (1 - Dice coefficient) for binary segmentation.

    The coefficient is computed globally over every element of the batch (all samples and
    pixels are flattened together), not averaged per sample.
    """
    def __init__(self, smooth=1.0):
        """
        Args:
            smooth (float): Added to numerator and denominator; avoids division by zero when
                both prediction and target are empty.
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred (torch.Tensor): Predicted mask probabilities, e.g. (B, 1, H, W).
            y_true (torch.Tensor): Ground truth masks with the same number of elements.
        Returns:
            torch.Tensor: Scalar Dice loss.
        """
        # reshape (not view) so non-contiguous inputs (transposed/sliced tensors) also work.
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1)

        intersection = (y_pred * y_true).sum()
        dice = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)

        return 1 - dice

# ================================ Main Model ============================================

import torch
import torch.nn as nn

import torch
import torch.nn as nn

class SwinUNet(nn.Module):
    """
    Swin U-Net architecture for image segmentation.

    Encoder: separable-conv stages with Swin blocks and three 2x max-pools, then a densely
    connected bottleneck. Decoder: three transposed-conv upsamplings; at each one the encoder
    skip and the upsampled map are fed as a length-2 sequence into a ConvLSTM2D, then refined
    by a Swin block and separable convolutions. The output passes through a sigmoid.

    Input height and width must be divisible by 8 (three 2x poolings). The Swin blocks are
    called without an explicit grid size, so feature maps are expected to be square (H == W).
    """
    def __init__(self, input_channels=3, output_channels=1,
                 embed_dim=32, num_heads=(4, 8), window_size=4,
                 mlp_ratio=4., depth=2):
        """
        Args:
            input_channels (int): Channels of the input image.
            output_channels (int): Channels of the output map (1 for binary segmentation).
            embed_dim (int): Unused; kept so existing call sites keep working.
            num_heads (sequence of int): ``num_heads[0]`` is used by swin_unet_E1 and
                swin_unet_D1, ``num_heads[1]`` by swin_unet_E2, swin_unet_D2 and swin_unet_D3.
                Each value must divide the channel count of the blocks it is used for.
            window_size (int): Swin window size; every Swin block shifts by window_size // 2.
            mlp_ratio (float): MLP hidden width ratio inside the Swin blocks.
            depth (int): Unused; kept so existing call sites keep working.
        """
        super(SwinUNet, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        # All Swin blocks in this network use shifted windows.
        shift_size = window_size // 2

        # Initial convolutional layers (full resolution)
        self.conv1 = SeparableConv2d(input_channels, 24, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(24)
        self.conv2 = SeparableConv2d(24, 24, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(24)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> 1/2

        # First Swin Transformer Block
        self.swin_unet_E1 = SwinTransformerBlock(
            dim=24, num_heads=num_heads[0], window_size=window_size,
            shift_size=shift_size, mlp_ratio=mlp_ratio
        )

        # Second convolutional block
        self.conv3 = SeparableConv2d(24, 48, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(48)
        self.conv4 = SeparableConv2d(48, 48, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(48)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> 1/4

        # Second Swin Transformer Block
        self.swin_unet_E2 = SwinTransformerBlock(
            dim=48, num_heads=num_heads[1], window_size=window_size,
            shift_size=shift_size, mlp_ratio=mlp_ratio
        )

        # Third convolutional block
        self.conv5 = SeparableConv2d(48, 96, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(96)
        self.conv6 = SeparableConv2d(96, 96, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(96)
        self.drop5 = nn.Dropout(0.5)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> 1/8

        # Bottleneck convolutions with dense connections
        self.conv7 = SeparableConv2d(96, 192, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(192)
        self.conv8 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(192)
        self.drop6_1 = nn.Dropout(0.5)

        self.conv9 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(192)
        self.conv10 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn10 = nn.BatchNorm2d(192)
        self.drop6_2 = nn.Dropout(0.5)

        self.concat1 = nn.Sequential(
            SeparableConv2d(384, 192, kernel_size=3, padding=1),
            SeparableConv2d(192, 192, kernel_size=3, padding=1)
        )
        self.drop6_3 = nn.Dropout(0.5)

        # First Upsampling Block (1/8 -> 1/4)
        self.up1 = nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2)
        self.bn_up1 = nn.BatchNorm2d(96)
        self.relu_up1 = nn.ReLU(inplace=True)
        self.convLSTM1 = ConvLSTM2D(input_channels=96, hidden_channels=384, kernel_size=3, num_layers=1)
        self.swin_unet_D1 = SwinTransformerBlock(
            dim=384, num_heads=num_heads[0], window_size=window_size,
            shift_size=shift_size, mlp_ratio=mlp_ratio
        )
        self.conv11 = SeparableConv2d(384, 48, kernel_size=3, padding=1)
        self.conv12 = SeparableConv2d(48, 48, kernel_size=3, padding=1)

        # Second Upsampling Block (1/4 -> 1/2)
        self.up2 = nn.ConvTranspose2d(48, 48, kernel_size=2, stride=2)
        self.bn_up2 = nn.BatchNorm2d(48)
        self.relu_up2 = nn.ReLU(inplace=True)
        self.convLSTM2 = ConvLSTM2D(input_channels=48, hidden_channels=96, kernel_size=3, num_layers=1)
        self.swin_unet_D2 = SwinTransformerBlock(
            dim=96, num_heads=num_heads[1], window_size=window_size,
            shift_size=shift_size, mlp_ratio=mlp_ratio
        )
        self.conv13 = SeparableConv2d(96, 24, kernel_size=3, padding=1)
        self.conv14 = SeparableConv2d(24, 24, kernel_size=3, padding=1)

        # Third Upsampling Block (1/2 -> full resolution)
        self.up3 = nn.ConvTranspose2d(24, 24, kernel_size=2, stride=2)
        self.bn_up3 = nn.BatchNorm2d(24)
        self.relu_up3 = nn.ReLU(inplace=True)
        self.convLSTM3 = ConvLSTM2D(input_channels=24, hidden_channels=48, kernel_size=3, num_layers=1)
        self.swin_unet_D3 = SwinTransformerBlock(
            dim=48, num_heads=num_heads[1], window_size=window_size,
            shift_size=shift_size, mlp_ratio=mlp_ratio
        )
        self.conv15 = SeparableConv2d(48, 24, kernel_size=3, padding=1)
        self.conv16 = SeparableConv2d(24, 24, kernel_size=3, padding=1)

        # Output Layer (honours output_channels; 1 by default)
        self.final_conv1 = nn.Conv2d(24, 2, kernel_size=3, padding=1)
        self.final_relu = nn.ReLU(inplace=True)
        self.final_conv2 = nn.Conv2d(2, output_channels, kernel_size=1, padding=0)
        self.final_sigmoid = nn.Sigmoid()

    @staticmethod
    def _swin_on_map(block, feat):
        """Apply a token-based Swin block to a (B, C, H, W) map and return a (B, C, H, W) map."""
        B, C, H, W = feat.shape
        tokens = feat.flatten(2).transpose(1, 2)  # (B, H*W, C)
        return block(tokens).transpose(1, 2).reshape(B, C, H, W)

    @staticmethod
    def _as_sequence(skip, up):
        """Stack an encoder skip and a decoder map into a length-2 sequence (B, 2, C, H, W)."""
        return torch.stack([skip, up], dim=1)

    def forward(self, x):
        """
        Forward pass of the Swin U-Net model.
        Args:
            x: Input tensor with shape (B, input_channels, H, W); H == W, divisible by 8.
        Returns:
            torch.Tensor: Sigmoid output with shape (B, output_channels, H, W).
        """
        # Encoder stage 1 (full resolution -> 1/2)
        x1 = self.bn1(self.conv1(x))
        x1 = self.bn2(self.conv2(x1))                       # (B, 24, H, W)
        p1 = self.pool1(x1)                                 # (B, 24, H/2, W/2)
        swin_E1 = self._swin_on_map(self.swin_unet_E1, p1)

        # Encoder stage 2 (1/2 -> 1/4)
        x2 = self.bn3(self.conv3(swin_E1))
        x2 = self.bn4(self.conv4(x2))                       # (B, 48, H/2, W/2)
        p2 = self.pool2(x2)                                 # (B, 48, H/4, W/4)
        swin_E2 = self._swin_on_map(self.swin_unet_E2, p2)

        # Encoder stage 3 (1/4 -> 1/8)
        x3 = self.bn5(self.conv5(swin_E2))
        x3 = self.drop5(self.bn6(self.conv6(x3)))           # (B, 96, H/4, W/4)
        p3 = self.pool3(x3)                                 # (B, 96, H/8, W/8)

        # Bottleneck with dense connection
        x4 = self.bn7(self.conv7(p3))
        x4 = self.drop6_1(self.bn8(self.conv8(x4)))         # (B, 192, H/8, W/8)
        x5 = self.bn9(self.conv9(x4))
        x5 = self.drop6_2(self.bn10(self.conv10(x5)))       # (B, 192, H/8, W/8)
        concat = self.concat1(torch.cat([x5, x4], dim=1))   # (B, 192, H/8, W/8)
        concat = self.drop6_3(concat)

        # Decoder stage 1 (1/8 -> 1/4), skip = x3
        up1 = self.relu_up1(self.bn_up1(self.up1(concat)))  # (B, 96, H/4, W/4)
        lstm1 = self.convLSTM1(self._as_sequence(x3, up1))  # (B, 384, H/4, W/4)
        swin_D1 = self._swin_on_map(self.swin_unet_D1, lstm1)
        conv6 = self.conv12(self.conv11(swin_D1))           # (B, 48, H/4, W/4)

        # Decoder stage 2 (1/4 -> 1/2), skip = x2
        up2 = self.relu_up2(self.bn_up2(self.up2(conv6)))   # (B, 48, H/2, W/2)
        lstm2 = self.convLSTM2(self._as_sequence(x2, up2))  # (B, 96, H/2, W/2)
        swin_D2 = self._swin_on_map(self.swin_unet_D2, lstm2)
        conv7 = self.conv14(self.conv13(swin_D2))           # (B, 24, H/2, W/2)

        # Decoder stage 3 (1/2 -> full resolution), skip = x1
        up3 = self.relu_up3(self.bn_up3(self.up3(conv7)))   # (B, 24, H, W)
        lstm3 = self.convLSTM3(self._as_sequence(x1, up3))  # (B, 48, H, W)
        swin_D3 = self._swin_on_map(self.swin_unet_D3, lstm3)
        conv8 = self.conv16(self.conv15(swin_D3))           # (B, 24, H, W)

        # Output head
        final = self.final_relu(self.final_conv1(conv8))    # (B, 2, H, W)
        final = self.final_conv2(final)                     # (B, output_channels, H, W)
        return self.final_sigmoid(final)

# ================================== Dataset Class ======================================

class SegmentationDataset(Dataset):
    """
    Paired image/mask dataset for segmentation.

    Images and masks are paired by position after sorting the file names of each directory,
    so both directories must contain the same number of image files, in corresponding sorted
    order. Files whose extension is not in ``IMAGE_EXTENSIONS`` (e.g. text files or hidden
    files) are ignored so they cannot shift the pairing.
    """

    # Extensions (lower-case) treated as images; everything else in the directories is skipped.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.gif', '.webp')

    def __init__(self, images_dir, masks_dir, transform=None, mask_transform=None):
        """
        Args:
            images_dir (str): Directory with the input images.
            masks_dir (str): Directory with the ground-truth masks.
            transform (callable, optional): Applied to every image, and to every mask as well
                unless ``mask_transform`` is given.
            mask_transform (callable, optional): Applied to masks instead of ``transform``
                (e.g. a nearest-neighbour resize so masks stay binary).
        Raises:
            ValueError: If the directories contain different numbers of image files.
        """
        super(SegmentationDataset, self).__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.mask_transform = mask_transform

        self.images = self._list_images(images_dir)
        self.masks = self._list_images(masks_dir)
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but "
                f"{len(self.masks)} masks in {masks_dir!r}; they must pair one-to-one."
            )

    @classmethod
    def _list_images(cls, directory):
        """Sorted names of the regular files in ``directory`` with an image extension."""
        return sorted(
            name for name in os.listdir(directory)
            if os.path.splitext(name)[1].lower() in cls.IMAGE_EXTENSIONS
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)``; image converted to RGB, mask to single-channel 'L'."""
        img_path = os.path.join(self.images_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')

        mask_path = os.path.join(self.masks_dir, self.masks[idx])
        mask = Image.open(mask_path).convert('L')

        if self.transform:
            image = self.transform(image)
        # Without an explicit mask_transform the image transform is reused (original behaviour).
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_tf:
            mask = mask_tf(mask)

        return image, mask

# =============================== Data Loading and Preprocessing ========================

# Define image dimensions
im_height = 256
im_width = 256

# One seed for the train/validation split, DataLoader shuffling and model initialisation,
# so re-running the notebook reproduces the same split and starting weights.
SEED = 42
torch.manual_seed(SEED)

# Define transformations
# NOTE(review): SegmentationDataset applies this same transform to the masks, so masks are
# resized with Resize's default interpolation (bilinear in torchvision) and pick up
# fractional values along object borders; a nearest-neighbour resize for masks keeps them binary.
transform = transforms.Compose([
    transforms.Resize((im_height, im_width)),
    transforms.ToTensor(),  # Converts to [0,1]
])

# Paths to the dataset (update these paths as per your directory structure)
train_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Training_Input'
train_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Training_GroundTruth'
test_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Test_Input'
test_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Test_GroundTruth'

# Create datasets
train_dataset = SegmentationDataset(train_images_dir, train_masks_dir, transform=transform)
test_dataset = SegmentationDataset(test_images_dir, test_masks_dir, transform=transform)

# Split training data into training and validation sets (80-20 split).
# A dedicated seeded generator makes the split identical on every run.
train_size = int(0.8 * len(train_dataset))
valid_size = len(train_dataset) - train_size
train_subset, valid_subset = torch.utils.data.random_split(
    train_dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(SEED)
)

# Create DataLoaders
batch_size = 5

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# =============================== Training Setup ==========================================

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Instantiate the model and move it to the GPU if available
model = SwinUNet(input_channels=3, output_channels=1, embed_dim=32, num_heads=[4, 8], window_size=4, mlp_ratio=4., depth=2)
model = model.to(device)

# Define loss function and optimizer
criterion = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Learning-rate schedule (cut LR 4x after 5 epochs without val improvement) and early stopping
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.25, patience=5, verbose=True, min_lr=1e-9)
early_stopping_patience = 9
best_val_loss = np.inf
epochs_no_improve = 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
from torchinfo import summary

# torchinfo traces one forward pass, so it needs a concrete input shape:
# a single 3-channel 256x256 image, laid out as (batch_size, channels, height, width).
input_size = (1, 3, 256, 256)

summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
summary(model, input_size=input_size, device=summary_device)

Layer (type:depth-idx)                   Output Shape              Param #
SwinUNet                                 [1, 1, 256, 256]          --
├─SeparableConv2d: 1-1                   [1, 24, 256, 256]         --
│    └─Conv2d: 2-1                       [1, 3, 256, 256]          30
│    └─Conv2d: 2-2                       [1, 24, 256, 256]         96
├─BatchNorm2d: 1-2                       [1, 24, 256, 256]         48
├─SeparableConv2d: 1-3                   [1, 24, 256, 256]         --
│    └─Conv2d: 2-3                       [1, 24, 256, 256]         240
│    └─Conv2d: 2-4                       [1, 24, 256, 256]         600
├─BatchNorm2d: 1-4                       [1, 24, 256, 256]         48
├─MaxPool2d: 1-5                         [1, 24, 128, 128]         --
├─SwinTransformerBlock: 1-6              [1, 16384, 24]            --
│    └─LayerNorm: 2-5                    [1, 16384, 24]            48
│    └─WindowAttention: 2-6              [1024, 16, 24]            196
│    │    └─

In [None]:
# =============================== Training Loop ===========================================

num_epochs = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Best (lowest validation loss) weights are written here.
checkpoint_path = r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    print(f"Starting epoch {epoch + 1}/{num_epochs}")

    # ---- Training ----
    model.train()
    running_loss = 0.0
    for images, masks in train_loader:
        images = images.to(device)  # (B, 3, H, W)
        masks = masks.to(device)    # (B, 1, H, W)

        optimizer.zero_grad()
        outputs = model(images)      # (B, 1, H, W)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean accounts for a smaller last batch.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in valid_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            val_loss += loss.item() * images.size(0)

    val_loss /= len(valid_loader.dataset)

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Scheduler step (ReduceLROnPlateau watches the validation loss)
    scheduler.step(val_loss)

    # Early stopping / checkpointing on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), checkpoint_path)
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= early_stopping_patience:
            print('Early stopping!')
            break


Epoch 0
Epoch 1/1, Training Loss: 0.5978, Validation Loss: 0.5450


In [None]:

# ================================== Prediction ==========================================

# Load the best model weights.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
# weights_only=True restricts unpickling to tensors/containers (safer, and avoids the
# FutureWarning torch.load emits when the argument is omitted).
model.load_state_dict(torch.load(
    r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth',
    map_location='cuda' if torch.cuda.is_available() else 'cpu',
    weights_only=True,
))
model.eval()

# Function to save predictions and ground truth
def _to_uint8_image(values):
    """
    Convert a float array in [0, 1] to an 8-bit grayscale PIL image.

    Values are scaled by 255, rounded to the nearest integer and clipped to [0, 255]. Plain
    truncation would turn 0.99999 * 255 into 254, and out-of-range values would wrap
    around when cast to uint8.
    """
    return Image.fromarray(np.clip(np.rint(values * 255.0), 0, 255).astype(np.uint8))


def save_predictions(model, dataloader, save_dir_pred, save_dir_gt, device):
    """
    Saves the predicted masks and ground truth masks.

    Images are numbered 1, 2, 3, ... in the order the dataloader yields them. For image n:
    - the predicted score map (channel 0, not thresholded) is written to
      ``save_dir_pred/<n>.png``;
    - the ground-truth mask (channel 0) is written to ``save_dir_gt/<n>.tiff``.
    Both are scaled from [0, 1] to 8-bit grayscale.

    Args:
        model (nn.Module): Trained model; it is switched to eval mode here.
        dataloader (iterable): Yields (images, masks) batches. Predictions and masks are
            indexed as (B, C, H, W) and only channel 0 is saved.
        save_dir_pred (str): Directory to save predicted masks (created if missing).
        save_dir_gt (str): Directory to save ground truth masks (created if missing).
        device (str): Device to run the model on.
    """
    os.makedirs(save_dir_pred, exist_ok=True)
    os.makedirs(save_dir_gt, exist_ok=True)
    # Ensure dropout is off and batch-norm uses running stats regardless of the caller's state.
    model.eval()

    # Running count, so the numbering does not rely on `dataloader.batch_size`.
    # That attribute is None with a batch_sampler and absent on plain iterables.
    saved = 0
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(dataloader):
            if batch_idx % 100 == 0:
                print(f"Batch {batch_idx}: saving from test image {saved + 1}")

            preds = model(images.to(device)).cpu().numpy()
            gts = masks.cpu().numpy()  # masks are only saved, so they never need the device

            for pred, gt in zip(preds, gts):
                saved += 1
                _to_uint8_image(pred[0]).save(os.path.join(save_dir_pred, f"{saved}.png"))
                _to_uint8_image(gt[0]).save(os.path.join(save_dir_gt, f"{saved}.tiff"))

# Output folders on Drive: one for predicted masks, one for the matching ground truths
save_dir_pred, save_dir_gt = (
    r'/content/drive/MyDrive/output/segmented predicted images',
    r'/content/drive/MyDrive/output/segmented ground truth',
)

# Write every test-set prediction next to its ground-truth mask
save_predictions(
    model=model,
    dataloader=test_loader,
    save_dir_pred=save_dir_pred,
    save_dir_gt=save_dir_gt,
    device=device,
)

  model.load_state_dict(torch.load(r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth'))


0th Test Image
100th Test Image


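The two evaluation cells below should be run separately (only one of them is needed). The first is an older version that is fully commented out. The second is the active version: it also binarises the ground-truth masks and always obtains all four confusion counts (TN, FP, FN, TP), even when an image contains only one class.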

In [None]:
# # =================================== Evaluation =========================================

# def evaluate_metrics_pytorch(model, dataloader, device):
#     """
#     Evaluates various metrics for segmentation performance.
#     Args:
#         model (nn.Module): Trained model.
#         dataloader (DataLoader): DataLoader for test data.
#         device (str): Device to run the model on.
#     Returns:
#         dict: Dictionary containing average metrics.
#     """
#     model.eval()
#     all_accuracy = []
#     all_dice = []
#     all_jaccard = []
#     all_sensitivity = []
#     all_specificity = []

#     with torch.no_grad():
#         for images, masks in dataloader:
#             images = images.to(device)
#             masks = masks.to(device)

#             outputs = model(images)
#             preds = outputs > 0.5  # Binary mask

#             preds = preds.cpu().numpy().astype(np.uint8)
#             masks = masks.cpu().numpy().astype(np.uint8)

#             for pred, mask in zip(preds, masks):
#                 pred_flat = pred.flatten()
#                 mask_flat = mask.flatten()

#                 # Precision-Recall Curve to find optimal threshold
#                 precisions, recalls, thresholds = precision_recall_curve(mask_flat, pred.flatten())
#                 f1 = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
#                 max_idx = np.argmax(f1)
#                 optimal_thresh = thresholds[max_idx] if max_idx < len(thresholds) else 0.5

#                 # Apply optimal threshold
#                 pred_opt = (pred_flat >= optimal_thresh).astype(np.uint8)

#                 # Confusion matrix
#                 tn, fp, fn, tp = confusion_matrix(mask_flat, pred_opt).ravel()

#                 # Calculate metrics
#                 accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-8)
#                 iou = tp / (tp + fp + fn + 1e-8)
#                 dice = (2 * tp) / (2 * tp + fp + fn + 1e-8)
#                 specificity = tn / (tn + fp + 1e-8)
#                 sensitivity = tp / (tp + fn + 1e-8)

#                 all_accuracy.append(accuracy)
#                 all_jaccard.append(iou)
#                 all_dice.append(dice)
#                 all_specificity.append(specificity)
#                 all_sensitivity.append(sensitivity)

#     # Compute average metrics
#     metrics = {
#         'Accuracy': np.mean(all_accuracy),
#         'Dice': np.mean(all_dice),
#         'Jaccard': np.mean(all_jaccard),
#         'Sensitivity': np.mean(all_sensitivity),
#         'Specificity': np.mean(all_specificity)
#     }

#     print(f"Accuracy: {metrics['Accuracy']:.4f}, Dice: {metrics['Dice']:.4f}, Jaccard: {metrics['Jaccard']:.4f}, "
#           f"Sensitivity: {metrics['Sensitivity']:.4f}, Specificity: {metrics['Specificity']:.4f}")

#     return metrics

# # Evaluate the model
# metrics = evaluate_metrics_pytorch(model, test_loader, device)

In [None]:
# =================================== Evaluation =========================================

from sklearn.metrics import confusion_matrix, precision_recall_curve
import numpy as np
import torch

def _best_f1_threshold(truth_flat, score_flat):
    """
    Return the score threshold that maximises F1 for one image, or None.

    This looks at the image's own ground truth, so it is an oracle (optimistic) choice.
    None is returned when the best point is the final precision/recall pair, which has no
    associated threshold.
    """
    precisions, recalls, thresholds = precision_recall_curve(truth_flat, score_flat)
    f1 = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
    max_idx = int(np.argmax(f1))
    # precision_recall_curve returns one more (precision, recall) pair than thresholds.
    return thresholds[max_idx] if max_idx < len(thresholds) else None


def evaluate_metrics_pytorch(model, dataloader, device, threshold=0.5, tune_threshold=False):
    """
    Evaluates various metrics for segmentation performance.

    Each image's predicted score map is binarised and compared with its ground-truth mask.
    Accuracy, Dice, Jaccard (IoU), sensitivity and specificity are computed per image and
    then averaged over all images.

    Args:
        model (nn.Module): Trained model returning per-pixel foreground scores (the
            SwinUNet in this notebook ends in a sigmoid, so scores are in [0, 1]).
        dataloader (iterable): Yields (images, masks) batches. Masks are expected as floats
            in [0, 1], as produced by ``transforms.ToTensor``.
        device (str): Device to run the model on.
        threshold (float): A pixel is predicted foreground when its score is > threshold.
        tune_threshold (bool): If True, choose a per-image threshold that maximises F1
            against that image's own ground truth. This uses the labels, so the results are
            an optimistic upper bound, not a fair test score. Images with an empty ground
            truth fall back to ``threshold``. Default False.
    Returns:
        dict: Average 'Accuracy', 'Dice', 'Jaccard', 'Sensitivity' and 'Specificity'
        (NaN when the dataloader yields no images).
    """
    model.eval()
    names = ('Accuracy', 'Dice', 'Jaccard', 'Sensitivity', 'Specificity')
    per_image = {name: [] for name in names}
    eps = 1e-8  # guards 0/0 for empty masks / empty predictions

    with torch.no_grad():
        for images, masks in dataloader:
            scores = model(images.to(device)).cpu().numpy()
            # Masks are floats in [0, 1] and resizing blurs their edges. Binarise at 0.5
            # rather than truncating with astype(uint8), which would map 0.99 to background.
            truths = masks.cpu().numpy() > 0.5

            for score, truth in zip(scores, truths):
                score_flat = score.ravel()
                truth_flat = truth.ravel()

                best = None
                if tune_threshold and truth_flat.any():
                    best = _best_f1_threshold(truth_flat, score_flat)
                # precision_recall_curve thresholds use ">=", the fixed threshold uses ">".
                pred_flat = score_flat > threshold if best is None else score_flat >= best

                tp = np.count_nonzero(pred_flat & truth_flat)
                fp = np.count_nonzero(pred_flat & ~truth_flat)
                fn = np.count_nonzero(~pred_flat & truth_flat)
                tn = pred_flat.size - tp - fp - fn

                per_image['Accuracy'].append((tp + tn) / (tp + tn + fp + fn + eps))
                per_image['Jaccard'].append(tp / (tp + fp + fn + eps))
                per_image['Dice'].append((2 * tp) / (2 * tp + fp + fn + eps))
                per_image['Specificity'].append(tn / (tn + fp + eps))
                per_image['Sensitivity'].append(tp / (tp + fn + eps))

    # Compute average metrics (NaN instead of a numpy warning when nothing was evaluated)
    metrics = {
        name: float(np.mean(values)) if values else float('nan')
        for name, values in per_image.items()
    }

    print(f"Accuracy: {metrics['Accuracy']:.4f}, Dice: {metrics['Dice']:.4f}, Jaccard: {metrics['Jaccard']:.4f}, "
          f"Sensitivity: {metrics['Sensitivity']:.4f}, Specificity: {metrics['Specificity']:.4f}")

    return metrics

# Score the trained model on the held-out test split
metrics = evaluate_metrics_pytorch(model=model, dataloader=test_loader, device=device)

In [None]:
# model.summary()

In [None]:
# Scratch cell (kernel sanity check, displays 3); safe to delete.
sum((1, 2))

In [None]:
!pip install torchinfo

Swin U-Net with Bidirectional ConvLSTM Skip Connections

In [None]:
# ============================== Imports and Dependencies ==============================

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from zipfile import ZipFile

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, confusion_matrix

# ================================ Separable Convolution =================================

class SeparableConv2d(nn.Module):
    """
    Depthwise-separable 2-D convolution.

    First applies a per-channel k x k convolution (``groups=in_channels``). Then a 1 x 1
    convolution mixes channels and maps ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, bias=True):
        super().__init__()
        # Spatial filtering: one filter per input channel
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        # Channel mixing with a 1x1 kernel
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            bias=bias,
        )

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

# ================================== ConvLSTM2D ========================================

class ConvLSTMCell(nn.Module):
    """
    Single convolutional LSTM cell.

    The input and the previous hidden state are concatenated along channels and passed
    through one convolution. Its output is split, in order, into the input, forget, output
    and candidate ("g") gate pre-activations.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        # kernel_size // 2 padding keeps H and W unchanged for odd kernel sizes
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        h_prev, c_prev = cur_state
        gates = self.conv(torch.cat([input_tensor, h_prev], dim=1))
        in_pre, forget_pre, out_pre, cand_pre = gates.chunk(4, dim=1)

        in_gate = torch.sigmoid(in_pre)
        forget_gate = torch.sigmoid(forget_pre)
        out_gate = torch.sigmoid(out_pre)
        candidate = torch.tanh(cand_pre)

        c_new = forget_gate * c_prev + in_gate * candidate
        h_new = out_gate * torch.tanh(c_new)
        return h_new, c_new

    def init_hidden(self, batch_size, spatial_size, device):
        """Zero (h, c) states of shape (batch_size, hidden_channels, height, width)."""
        height, width = spatial_size
        shape = (batch_size, self.hidden_channels, height, width)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)

class ConvLSTM2D(nn.Module):
    """
    Stack of ConvLSTM cells run over a sequence of 2-D feature maps.

    Input is (batch, seq_len, channels, height, width). The output is the top layer's
    hidden state at every step, (batch, seq_len, hidden_channels, height, width). With
    ``reverse=True`` the sequence is consumed from the last step to the first, and the
    outputs are flipped back so index t still lines up with input step t.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, bias=True, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        # Only the first layer sees the raw input channels; deeper layers see hidden states.
        self.layers = nn.ModuleList(
            ConvLSTMCell(input_channels if depth == 0 else hidden_channels,
                         hidden_channels, kernel_size, bias)
            for depth in range(num_layers)
        )

    def forward(self, input_tensor, reverse=False):
        batch_size, seq_len, _, height, width = input_tensor.size()
        # One zero-initialised (h, c) pair per layer
        states = [cell.init_hidden(batch_size, (height, width), input_tensor.device)
                  for cell in self.layers]

        step_order = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
        step_outputs = []
        for t in step_order:
            layer_input = input_tensor[:, t]
            for depth, cell in enumerate(self.layers):
                states[depth] = cell(layer_input, states[depth])
                layer_input = states[depth][0]  # hidden state feeds the next layer
            step_outputs.append(layer_input)

        stacked = torch.stack(step_outputs, dim=1)
        return stacked.flip(dims=[1]) if reverse else stacked

class BidirectionalConvLSTM2D(nn.Module):
    """
    Implements a Bidirectional ConvLSTM2D layer.

    The input sequence is processed forwards and backwards by two independent ConvLSTM2D
    stacks, and the result is reduced to a single feature map. The output concatenates,
    along channels:
    - the forward direction's final hidden state (after reading steps 0..T-1), and
    - the backward direction's final hidden state (after reading steps T-1..0).
    This way both halves have seen the whole sequence.

    NOTE: earlier versions took index -1 of the backward stream, which has only seen the
    last element. Checkpoints trained with that version still load (same parameter names)
    but should be retrained.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, num_layers=1, bias=True):
        super(BidirectionalConvLSTM2D, self).__init__()
        self.forward_conv_lstm = ConvLSTM2D(input_channels, hidden_channels, kernel_size, bias=bias, num_layers=num_layers)
        self.backward_conv_lstm = ConvLSTM2D(input_channels, hidden_channels, kernel_size, bias=bias, num_layers=num_layers)

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: (batch, seq_len, channels, height, width)
        Returns:
            torch.Tensor: (batch, 2 * hidden_channels, height, width)
        """
        forward_output = self.forward_conv_lstm(input_tensor, reverse=False)   # (B, T, hidden, H, W)
        backward_output = self.backward_conv_lstm(input_tensor, reverse=True)  # (B, T, hidden, H, W)
        # backward_output is flipped back into input order. Its entry at step 0 is the last
        # one the backward pass computed, i.e. its summary of the full sequence. Its entry
        # at step T-1 saw only one element (same distinction as nn.LSTM's h_n vs output[-1]).
        return torch.cat([forward_output[:, -1], backward_output[:, 0]], dim=1)

# ============================== Swin Transformer Blocks ================================

class WindowAttention(nn.Module):
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.

    Attention is computed independently inside each window of Wh x Ww tokens. A learned
    per-head bias is added to every attention logit. It is looked up from a table indexed
    by the (row, column) offset between the two tokens.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        """
        Args:
            dim (int): Number of input channels.
            window_size (tuple): Height and width of the window.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            attn_drop (float): Dropout ratio of attention weights.
            proj_drop (float): Dropout ratio of output.
        """
        super(WindowAttention, self).__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        # NOTE(review): forward() reshapes channels into (num_heads, C // num_heads), so dim
        # must be divisible by num_heads.
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 1/sqrt(head_dim) scaling of the q·k dot products

        # Define a parameter table of relative position bias
        # (one learnable value per head for each possible (dy, dx) offset inside a window)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # Get pair-wise relative position index for each token inside the window.
        # relative_coords[i, j] = (row, col) of token i minus (row, col) of token j. Offsets
        # are shifted to be non-negative, and the row offset is multiplied by (2*Ww-1), so
        # every (dy, dx) pair maps to a unique flat index into the table above.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # Buffer: follows .to(device) and is stored in the state_dict, but is not trained.
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # Query, Key, Value
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Initialize relative position bias table
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, Wh*Ww, C)
            mask: (num_windows, Wh*Ww, Wh*Ww) or None. Added to the attention logits before
                the softmax (e.g. 0 to keep a pair, a large negative value to block it).
        Returns:
            torch.Tensor: (num_windows*B, Wh*Ww, C)
        """
        B_, N, C = x.shape
        # N must equal Wh*Ww for the relative position bias below to broadcast.
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # 3, B_, nH, N, C//nH
        q, k, v = qkv[0], qkv[1], qkv[2]  # each has shape (B_, nH, N, C//nH)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B_, nH, N, N)

        # Add relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # (B_, nH, N, N)

        if mask is not None:
            # B_ is viewed as (B_ // nW, nW): rows are assumed to be grouped per image, with
            # that image's nW windows consecutive. The same per-window mask is broadcast
            # over images and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = F.softmax(attn, dim=-1)
        else:
            attn = F.softmax(attn, dim=-1)

        attn = self.attn_drop(attn)

        # Merge heads back into the channel dimension, then project.
        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # (B_, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block with W-MSA and SW-MSA.

    Operates on a square token grid: input (B, H*W, C) with H == W. The block computes
        x = x + WindowAttention(LayerNorm(x))
        x = x + MLP(LayerNorm(x))
    The map is zero-padded up to a multiple of ``window_size`` before windowing and cropped
    afterwards.

    With ``shift_size > 0`` the padded map is cyclically shifted before windowing (SW-MSA).
    An attention mask then stops tokens that the roll wrapped around from attending to
    tokens that are not their spatial neighbours, following the Swin Transformer paper.
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 attn_drop=0., proj_drop=0.):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size  # W
        self.shift_size = shift_size    # S
        self.mlp_ratio = mlp_ratio

        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, (window_size, window_size), num_heads, qkv_bias, attn_drop, proj_drop)

        self.drop_path = nn.Identity()  # Can implement stochastic depth if desired
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(proj_drop)
        )

    @staticmethod
    def _shift_attn_mask(Hp, Wp, window_size, shift_size, device=None, dtype=torch.float32):
        """
        Build the SW-MSA attention mask for a padded Hp x Wp map.

        After a roll by ``-shift_size`` the map consists of 3 x 3 regions in total: along
        each axis, the original interior plus the two bands brought in by the roll.
        Token pairs from the same region get 0; pairs from different regions get -100, so
        the softmax gives them ~zero weight.

        Returns:
            torch.Tensor: (num_windows, window_size**2, window_size**2), with windows in the
            same row-major order as the window partition in ``forward``.
        """
        img_mask = torch.zeros((1, Hp, Wp, 1), device=device, dtype=dtype)
        bands = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
        region = 0
        for rows in bands:
            for cols in bands:
                img_mask[:, rows, cols, :] = region
                region += 1

        mask_windows = img_mask.view(1, Hp // window_size, window_size, Wp // window_size, window_size, 1)
        mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # nonzero <=> different regions
        return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

    def forward(self, x):
        """
        Args:
            x: input features with shape (B, H*W, C), H == W
        Returns:
            torch.Tensor: (B, H*W, C)
        """
        H = W = int(np.sqrt(x.shape[1]))
        B, L, C = x.shape
        assert L == H * W, "Input feature has wrong size"
        ws, shift = self.window_size, self.shift_size

        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # Pad H and W up to multiples of the window size (before shifting, as in reference Swin)
        pad_b = (ws - H % ws) % ws
        pad_r = (ws - W % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape

        # Cyclic shift; the mask keeps attention within spatially contiguous regions
        if shift > 0:
            x = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
            attn_mask = self._shift_attn_mask(Hp, Wp, ws, shift, device=x.device, dtype=x.dtype)
        else:
            attn_mask = None

        # Window partition: (num_windows*B, ws*ws, C), image-major then row-major windows
        windows = x.view(B, Hp // ws, ws, Wp // ws, ws, C)
        windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws * ws, C)

        windows = self.attn(windows, mask=attn_mask)

        # Merge windows back into a (B, Hp, Wp, C) map
        x = windows.view(B, Hp // ws, Wp // ws, ws, ws, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, C)

        # Reverse cyclic shift
        if shift > 0:
            x = torch.roll(x, shifts=(shift, shift), dims=(1, 2))

        # Remove padding
        x = x[:, :H, :W, :].contiguous().view(B, H * W, C)

        # Residual around attention, then residual around the MLP (FFN)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

# =============================== Dice Loss Function ====================================

class DiceLoss(nn.Module):
    """
    Dice Loss function to maximize the Dice coefficient.
    Suitable for binary segmentation tasks.

    The coefficient is computed over all elements of the batch at once (not per image):
        dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)
    and the loss is ``1 - dice``.
    """
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth  # added to numerator and denominator; avoids 0/0 on empty masks

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred (torch.Tensor): Predicted mask probabilities with shape (B, 1, H, W)
            y_true (torch.Tensor): Ground truth masks with shape (B, 1, H, W)
        Returns:
            torch.Tensor: Dice loss (scalar)
        """
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. permuted or sliced tensors
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1)

        intersection = (y_pred * y_true).sum()
        dice = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)

        return 1 - dice

# ================================ Main Model ============================================

class SwinUNet(nn.Module):
    """
    Swin U-Net architecture for image segmentation with bidirectional ConvLSTM layers.

    Encoder:
    - three separable-conv stages (24 -> 48 -> 96 channels), each followed by 2x max-pooling;
    - the first two stages are followed by a Swin Transformer block;
    - a densely connected 192-channel bottleneck.
    Decoder: three transposed-conv 2x upsampling stages. At each stage the encoder skip
    feature and the upsampled feature are stacked as a 2-step sequence, fused by a
    bidirectional ConvLSTM, refined by a Swin block, then passed through two separable convs.
    The head outputs a single-channel sigmoid probability map.

    Input must be square (the Swin blocks assume H == W) with H divisible by 8 (three
    pool/upsample levels whose skip and upsampled features are stacked).

    Args:
        input_channels (int): Channels of the input image.
        output_channels (int): Stored on the instance; the head always outputs 1 channel.
        embed_dim (int): Currently unused (kept for interface compatibility).
        num_heads (sequence of 2 ints): Attention heads. ``num_heads[0]`` is used by the E1
            and D1 Swin blocks, ``num_heads[1]`` by E2, D2 and D3. Each block's channel
            count must be divisible by its head count.
        window_size (int): Swin window size; every Swin block shifts by ``window_size // 2``.
        mlp_ratio (float): Hidden-size multiplier of the Swin MLPs.
        depth (int): Currently unused (kept for interface compatibility).
    """
    def __init__(self, input_channels=3, output_channels=1,
                 embed_dim=32, num_heads=(4, 8), window_size=4,
                 mlp_ratio=4., depth=2):
        super(SwinUNet, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        # NOTE(review): the encoder/bottleneck conv -> BN pairs have no activation between
        # them (nonlinearity only comes from the Swin blocks); confirm this is intended.

        # Initial convolutional layers
        self.conv1 = SeparableConv2d(input_channels, 24, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(24)
        self.conv2 = SeparableConv2d(24, 24, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(24)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # H -> H/2

        # First Swin Transformer Block
        self.swin_unet_E1 = SwinTransformerBlock(
            dim=24,  # matches the 24-channel conv stage (embed_dim is not used)
            num_heads=num_heads[0],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )

        # Second convolutional block
        self.conv3 = SeparableConv2d(24, 48, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(48)
        self.conv4 = SeparableConv2d(48, 48, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(48)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # H/2 -> H/4

        # Second Swin Transformer Block
        self.swin_unet_E2 = SwinTransformerBlock(
            dim=48,
            num_heads=num_heads[1],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )

        # Third convolutional block
        self.conv5 = SeparableConv2d(48, 96, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(96)
        self.conv6 = SeparableConv2d(96, 96, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(96)
        self.drop5 = nn.Dropout(0.5)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # H/4 -> H/8

        # Bottleneck convolutions with dense connections
        self.conv7 = SeparableConv2d(96, 192, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(192)
        self.conv8 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(192)
        self.drop6_1 = nn.Dropout(0.5)

        self.conv9 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(192)
        self.conv10 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn10 = nn.BatchNorm2d(192)
        self.drop6_2 = nn.Dropout(0.5)

        self.concat1 = nn.Sequential(
            SeparableConv2d(384, 192, kernel_size=3, padding=1),
            SeparableConv2d(192, 192, kernel_size=3, padding=1)
        )
        self.drop6_3 = nn.Dropout(0.5)

        # First Upsampling Block
        self.up1 = nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2)  # H/8 -> H/4
        self.bn_up1 = nn.BatchNorm2d(96)
        self.relu_up1 = nn.ReLU(inplace=True)
        self.bidirectional_convLSTM1 = BidirectionalConvLSTM2D(input_channels=96, hidden_channels=192, kernel_size=3, num_layers=1)
        self.swin_unet_D1 = SwinTransformerBlock(
            dim=192 * 2,  # bidirectional output: forward + backward hidden channels
            num_heads=num_heads[0],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )
        self.conv11 = SeparableConv2d(192 * 2, 48, kernel_size=3, padding=1)
        self.conv12 = SeparableConv2d(48, 48, kernel_size=3, padding=1)

        # Second Upsampling Block
        self.up2 = nn.ConvTranspose2d(48, 48, kernel_size=2, stride=2)  # H/4 -> H/2
        self.bn_up2 = nn.BatchNorm2d(48)
        self.relu_up2 = nn.ReLU(inplace=True)
        self.bidirectional_convLSTM2 = BidirectionalConvLSTM2D(input_channels=48, hidden_channels=96, kernel_size=3, num_layers=1)
        self.swin_unet_D2 = SwinTransformerBlock(
            dim=96 * 2,
            num_heads=num_heads[1],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )
        self.conv13 = SeparableConv2d(96 * 2, 24, kernel_size=3, padding=1)
        self.conv14 = SeparableConv2d(24, 24, kernel_size=3, padding=1)

        # Third Upsampling Block
        self.up3 = nn.ConvTranspose2d(24, 24, kernel_size=2, stride=2)  # H/2 -> H
        self.bn_up3 = nn.BatchNorm2d(24)
        self.relu_up3 = nn.ReLU(inplace=True)
        self.bidirectional_convLSTM3 = BidirectionalConvLSTM2D(input_channels=24, hidden_channels=48, kernel_size=3, num_layers=1)
        self.swin_unet_D3 = SwinTransformerBlock(
            dim=48 * 2,
            num_heads=num_heads[1],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )
        self.conv15 = SeparableConv2d(48 * 2, 24, kernel_size=3, padding=1)
        self.conv16 = SeparableConv2d(24, 24, kernel_size=3, padding=1)

        # Output Layer
        self.final_conv1 = nn.Conv2d(24, 2, kernel_size=3, padding=1)
        self.final_relu = nn.ReLU(inplace=True)
        self.final_conv2 = nn.Conv2d(2, 1, kernel_size=1, padding=0)
        self.final_sigmoid = nn.Sigmoid()

    @staticmethod
    def _to_tokens(feature_map):
        """(B, C, H, W) feature map -> (B, H*W, C) token sequence for a Swin block."""
        return feature_map.flatten(2).transpose(1, 2)

    @staticmethod
    def _to_map(tokens, height, width):
        """(B, H*W, C) token sequence -> (B, C, H, W) feature map."""
        batch, _, channels = tokens.shape
        return tokens.transpose(1, 2).reshape(batch, channels, height, width)

    def _swin(self, block, feature_map):
        """Run a Swin block on a (B, C, H, W) map and return a map of the same shape."""
        height, width = feature_map.shape[-2:]
        return self._to_map(block(self._to_tokens(feature_map)), height, width)

    def forward(self, x):
        """
        Forward pass of the Swin U-Net model.
        Args:
            x: Input tensor with shape (B, input_channels, H, W), H == W, H divisible by 8
               (the notebook uses 256 x 256).
        Returns:
            torch.Tensor: Output segmentation probabilities with shape (B, 1, H, W)
        """
        # ---- Encoder ---------------------------------------------------------------
        x1 = self.bn2(self.conv2(self.bn1(self.conv1(x))))               # (B, 24, H, W)
        p1 = self.pool1(x1)                                               # (B, 24, H/2, W/2)
        swin_E1 = self._swin(self.swin_unet_E1, p1)                       # (B, 24, H/2, W/2)

        x2 = self.bn4(self.conv4(self.bn3(self.conv3(swin_E1))))          # (B, 48, H/2, W/2)
        p2 = self.pool2(x2)                                               # (B, 48, H/4, W/4)
        swin_E2 = self._swin(self.swin_unet_E2, p2)                       # (B, 48, H/4, W/4)

        x3 = self.drop5(self.bn6(self.conv6(self.bn5(self.conv5(swin_E2)))))  # (B, 96, H/4, W/4)
        p3 = self.pool3(x3)                                                   # (B, 96, H/8, W/8)

        # ---- Bottleneck with a dense (concatenating) connection ----------------------
        x4 = self.drop6_1(self.bn8(self.conv8(self.bn7(self.conv7(p3)))))    # (B, 192, H/8, W/8)
        x5 = self.drop6_2(self.bn10(self.conv10(self.bn9(self.conv9(x4)))))  # (B, 192, H/8, W/8)
        bottleneck = self.drop6_3(self.concat1(torch.cat([x5, x4], dim=1)))  # (B, 192, H/8, W/8)

        # ---- Decoder: upsample, fuse with the skip via BiConvLSTM, Swin, convs --------
        up1 = self.relu_up1(self.bn_up1(self.up1(bottleneck)))              # (B, 96, H/4, W/4)
        d1 = self.bidirectional_convLSTM1(torch.stack([x3, up1], dim=1))    # (B, 384, H/4, W/4)
        d1 = self.conv12(self.conv11(self._swin(self.swin_unet_D1, d1)))    # (B, 48, H/4, W/4)

        up2 = self.relu_up2(self.bn_up2(self.up2(d1)))                      # (B, 48, H/2, W/2)
        d2 = self.bidirectional_convLSTM2(torch.stack([x2, up2], dim=1))    # (B, 192, H/2, W/2)
        d2 = self.conv14(self.conv13(self._swin(self.swin_unet_D2, d2)))    # (B, 24, H/2, W/2)

        up3 = self.relu_up3(self.bn_up3(self.up3(d2)))                      # (B, 24, H, W)
        d3 = self.bidirectional_convLSTM3(torch.stack([x1, up3], dim=1))    # (B, 96, H, W)
        d3 = self.conv16(self.conv15(self._swin(self.swin_unet_D3, d3)))    # (B, 24, H, W)

        # ---- Head --------------------------------------------------------------------
        logits = self.final_conv2(self.final_relu(self.final_conv1(d3)))    # (B, 1, H, W)
        return self.final_sigmoid(logits)

# ================================== Dataset Class ======================================

class SegmentationDataset(Dataset):
    """
    Custom Dataset for paired image / mask segmentation data.

    Images are read from ``images_dir`` and masks from ``masks_dir``. Pairing is
    positional: the i-th image file (sorted by name) is matched with the i-th
    mask file (sorted by name), so both folders must list corresponding files
    in the same order. Only files with a common image extension are used;
    anything else (e.g. text files or hidden files such as ``.DS_Store``) is
    skipped, because ``PIL.Image.open`` cannot read it.

    Args:
        images_dir (str): Folder with the input images (loaded as RGB).
        masks_dir (str): Folder with the ground-truth masks (loaded as grayscale).
        transform (callable, optional): Applied to both the image and the mask.

    Raises:
        ValueError: If the two folders contain a different number of image files.
    """
    # Lower-case file suffixes recognised as images.
    IMAGE_EXTENSIONS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png', '.tif', '.tiff')

    def __init__(self, images_dir, masks_dir, transform=None):
        super(SegmentationDataset, self).__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform

        self.images = self._list_image_files(images_dir)
        self.masks = self._list_image_files(masks_dir)
        # Pairing is purely positional, so a count mismatch would silently pair
        # images with the wrong masks (or only fail when the last index is read).
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but "
                f"{len(self.masks)} masks in {masks_dir!r}"
            )

    @classmethod
    def _list_image_files(cls, folder):
        """Return the sorted names of the non-hidden image files directly in ``folder``."""
        return sorted(
            name for name in os.listdir(folder)
            if not name.startswith('.') and name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``, transformed if a transform was given."""
        # Load image
        img_path = os.path.join(self.images_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')  # Ensure RGB

        # Load mask
        mask_path = os.path.join(self.masks_dir, self.masks[idx])
        mask = Image.open(mask_path).convert('L')    # Grayscale

        # Apply transformations
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

# =============================== Data Loading and Preprocessing ========================

# Define image dimensions
im_height = 256
im_width = 256

# Define transformations
transform = transforms.Compose([
    transforms.Resize((im_height, im_width)),
    transforms.ToTensor(),  # Converts to [0,1]
])

# Paths to the dataset (update these paths as per your directory structure)
train_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Training_Input'
train_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Training_GroundTruth'
test_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Test_Input'
test_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Test_GroundTruth'

# Create datasets
train_dataset = SegmentationDataset(train_images_dir, train_masks_dir, transform=transform)
test_dataset = SegmentationDataset(test_images_dir, test_masks_dir, transform=transform)

# Split training data into training and validation sets (80-20 split).
# A seeded generator makes the split identical on every Restart & Run All, so the
# validation loss used for checkpointing / early stopping is comparable across runs.
split_seed = 42
train_size = int(0.8 * len(train_dataset))
valid_size = len(train_dataset) - train_size
train_subset, valid_subset = torch.utils.data.random_split(
    train_dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders
batch_size = 5
# Never request more worker processes than there are CPU cores; PyTorch warns
# when num_workers exceeds what the machine can usefully run.
num_workers = min(4, os.cpu_count() or 1)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# =============================== Training Setup ==========================================

# Pick the compute device first, then build the network and place it there in one step.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SwinUNet(
    input_channels=3,
    output_channels=1,
    embed_dim=32,
    num_heads=[4, 8],
    window_size=4,
    mlp_ratio=4.,
    depth=2,
).to(device)  # GPU when available, otherwise CPU

# Initialize weights using Kaiming Normal initialization
def initialize_weights(module):
    """
    Kaiming-normal (fan_in, ReLU gain) initialisation for conv and linear layers.

    Meant to be passed to ``model.apply``. Every Conv2d, ConvTranspose2d, Conv3d
    and Linear module gets a Kaiming-normal weight and, if it has a bias, a zero
    bias; any other module type is left untouched.
    """
    kaiming_layer_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear, nn.Conv3d)
    if not isinstance(module, kaiming_layer_types):
        return
    nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
    if module.bias is not None:
        nn.init.zeros_(module.bias)

# Re-initialise every conv / linear layer of the freshly built network.
model.apply(initialize_weights)

# Objective: soft Dice loss. Optimiser: Adam with a 1e-4 learning rate.
criterion = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Divide the LR by 4 after 5 epochs without validation improvement, never below 1e-9.
# NOTE(review): recent PyTorch releases deprecate `verbose` on LR schedulers; confirm
# against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.25,
    patience=5,
    verbose=True,
    min_lr=1e-9,
)

# Early-stopping bookkeeping consumed by the training loop.
early_stopping_patience = 9
best_val_loss, epochs_no_improve = np.inf, 0

# =============================== Training Loop ===========================================

num_epochs = 1  # You can adjust the number of epochs
# `device` is the one selected in the Training Setup section above.

# The best-so-far weights are written here whenever validation loss improves,
# and reloaded in the Prediction section below.
best_model_path = r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth'
# torch.save does not create missing parent folders. Create the folder up front rather
# than failing with FileNotFoundError after a full epoch of training.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    for images, masks in train_loader:
        images = images.to(device)  # (B, 3, 256, 256)
        masks = masks.to(device)    # (B, 1, 256, 256)

        optimizer.zero_grad()
        outputs = model(images)      # (B, 1, 256, 256)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so a short final batch counts proportionally.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)

    # ---- Validation pass (no gradients) ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in valid_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            val_loss += loss.item() * images.size(0)

    val_loss /= len(valid_loader.dataset)

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # ReduceLROnPlateau monitors the validation loss.
    scheduler.step(val_loss)

    # Early stopping: checkpoint on improvement; stop after
    # `early_stopping_patience` consecutive epochs without one.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), best_model_path)
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= early_stopping_patience:
            print('Early stopping!')
            break

# ================================== Prediction ==========================================

# Load the best model weights. map_location lets a checkpoint written from a GPU
# session be loaded in a kernel where `device` is 'cpu'.
model.load_state_dict(torch.load(
    r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth',
    map_location=device,
))
model.eval()

# Function to save predictions and ground truth
def save_predictions(model, dataloader, save_dir_pred, save_dir_gt, device):
    """
    Saves the predicted masks and ground truth masks, one file per sample.

    Samples are numbered 1, 2, 3, ... in dataloader order. Sample ``n``'s
    prediction is written to ``<save_dir_pred>/<n>.png`` and its ground truth to
    ``<save_dir_gt>/<n>.tiff``. Both are 8-bit grayscale images obtained by
    multiplying channel 0 of the [0, 1] tensors by 255.

    Args:
        model (nn.Module): Trained model.
        dataloader (DataLoader): DataLoader for test data.
        save_dir_pred (str): Directory to save predicted masks (created if missing).
        save_dir_gt (str): Directory to save ground truth masks (created if missing).
        device (str): Device to run the model on.
    """
    os.makedirs(save_dir_pred, exist_ok=True)
    os.makedirs(save_dir_gt, exist_ok=True)

    model.eval()  # inference mode for dropout / batch norm, regardless of caller state
    sample_idx = 0  # running count of samples written; robust to a short final batch
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(dataloader):
            if batch_idx % 100 == 0:
                # batch_idx counts batches, not images: report the image count instead.
                print(f"{sample_idx}th Test Image")

            preds = model(images.to(device)).cpu().numpy()
            # Ground truth is only written back to disk, so it never needs the GPU.
            gts = masks.cpu().numpy()

            for pred_mask, gt_mask in zip(preds[:, 0], gts[:, 0]):
                sample_idx += 1

                # Save predicted mask
                pred_img = Image.fromarray((pred_mask * 255).astype(np.uint8))
                pred_img.save(os.path.join(save_dir_pred, f"{sample_idx}.png"))

                # Save ground truth mask
                gt_img = Image.fromarray((gt_mask * 255).astype(np.uint8))
                gt_img.save(os.path.join(save_dir_gt, f"{sample_idx}.tiff"))

# Output folders for the written masks (save_predictions creates them if missing).
save_dir_pred, save_dir_gt = (
    r'/content/drive/MyDrive/output/segmented predicted images',
    r'/content/drive/MyDrive/output/segmented ground truth',
)

# Run the trained model over the test split and write predictions + ground truth.
save_predictions(
    model,
    test_loader,
    save_dir_pred=save_dir_pred,
    save_dir_gt=save_dir_gt,
    device=device,
)

# =================================== Evaluation =========================================

from sklearn.metrics import confusion_matrix, precision_recall_curve
import numpy as np
import torch

def evaluate_metrics_pytorch(model, dataloader, device, threshold=0.5):
    """
    Evaluates per-image segmentation metrics and averages them over the dataset.

    Each predicted probability map is binarised at ``threshold`` and each
    ground-truth mask at 0.5. Accuracy, Dice, Jaccard (IoU), sensitivity and
    specificity are then computed per image from pixel-level confusion counts
    and averaged. The threshold is fixed rather than tuned per image: picking it
    from each image's own ground truth would leak labels into the reported scores.

    Args:
        model (nn.Module): Trained model returning per-pixel probabilities.
        dataloader (iterable): Yields ``(images, masks)`` batches.
        device (str): Device to run the model on.
        threshold (float): Probability above which a pixel is predicted foreground.
    Returns:
        dict: Average 'Accuracy', 'Dice', 'Jaccard', 'Sensitivity' and
        'Specificity' (NaN for each if the dataloader yields no samples).
    """
    model.eval()
    eps = 1e-8  # avoids 0/0 on images where a class is absent
    per_image = {'Accuracy': [], 'Dice': [], 'Jaccard': [], 'Sensitivity': [], 'Specificity': []}

    with torch.no_grad():
        for images, masks in dataloader:
            outputs = model(images.to(device))

            # One row of booleans per image: (B, H*W).
            preds = (outputs > threshold).flatten(1).cpu().numpy()
            # Compare against 0.5 instead of casting to uint8: a float->uint8 cast
            # truncates fractional mask values (e.g. resized edges) to background.
            # Works for [0, 1] masks and for raw 0/255 masks alike.
            gts = (masks > 0.5).flatten(1).cpu().numpy()

            tp = np.sum(preds & gts, axis=1).astype(np.float64)
            tn = np.sum(~preds & ~gts, axis=1).astype(np.float64)
            fp = np.sum(preds & ~gts, axis=1).astype(np.float64)
            fn = np.sum(~preds & gts, axis=1).astype(np.float64)

            per_image['Accuracy'].extend((tp + tn) / (tp + tn + fp + fn + eps))
            per_image['Dice'].extend((2 * tp) / (2 * tp + fp + fn + eps))
            per_image['Jaccard'].extend(tp / (tp + fp + fn + eps))
            per_image['Sensitivity'].extend(tp / (tp + fn + eps))
            per_image['Specificity'].extend(tn / (tn + fp + eps))

    # Compute average metrics (explicit NaN instead of a numpy warning when empty).
    metrics = {name: float(np.mean(values)) if values else float('nan')
               for name, values in per_image.items()}

    print(f"Accuracy: {metrics['Accuracy']:.4f}, Dice: {metrics['Dice']:.4f}, Jaccard: {metrics['Jaccard']:.4f}, "
          f"Sensitivity: {metrics['Sensitivity']:.4f}, Specificity: {metrics['Specificity']:.4f}")

    return metrics

# Score the trained model on the held-out test split.
metrics = evaluate_metrics_pytorch(model=model, dataloader=test_loader, device=device)

BIDIRECTIONAL CONVLSTM WITH MORE EVALUATION METRICS

In [None]:
# ============================== Imports and Dependencies ==============================

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    precision_recall_curve, confusion_matrix,
    jaccard_score, f1_score, precision_score, recall_score
)

# ================================ Separable Convolution =================================

class SeparableConv2d(nn.Module):
    """
    Depthwise-separable 2-D convolution.

    A per-channel ("depthwise", ``groups=in_channels``) ``kernel_size`` convolution
    is followed by a 1x1 ("pointwise") convolution that mixes channels and maps
    ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, bias=True):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            bias=bias,
        )

    def forward(self, x):
        # Spatial filtering per channel, then 1x1 channel mixing.
        return self.pointwise(self.depthwise(x))

# ================================== ConvLSTM2D ========================================

class ConvLSTMCell(nn.Module):
    """
    Single ConvLSTM cell: an LSTM whose gate projections are 2-D convolutions.

    One convolution over the channel-wise concatenation of the input and the
    previous hidden state produces all four gates; its output channels are split,
    in order, into the input, forget, output and candidate ("g") gates.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,  # keeps H and W unchanged for odd kernel sizes
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """
        Args:
            input_tensor: input at the current step, channels on dim 1.
            cur_state: ``(h, c)`` pair with ``hidden_channels`` channels each.
        Returns:
            tuple: ``(h_next, c_next)``.
        """
        h_prev, c_prev = cur_state
        gates = self.conv(torch.cat([input_tensor, h_prev], dim=1))
        in_gate, forget_gate, out_gate, candidate = gates.chunk(4, dim=1)

        c_next = torch.sigmoid(forget_gate) * c_prev + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, spatial_size, device):
        """Return zero ``(h, c)`` states of shape (batch_size, hidden_channels, H, W)."""
        height, width = spatial_size
        state_shape = (batch_size, self.hidden_channels, height, width)
        return tuple(torch.zeros(state_shape, device=device) for _ in range(2))

class ConvLSTM2D(nn.Module):
    """
    Stack of ConvLSTM cells applied over a (batch, seq_len, C, H, W) sequence.

    Every layer starts from a zero hidden/cell state. With ``reverse=True`` the
    time steps are consumed from last to first; either way, the returned outputs
    are indexed in the original time order.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, bias=True, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        # The first layer reads the raw input; deeper layers read the layer below's hidden state.
        self.layers = nn.ModuleList(
            ConvLSTMCell(input_channels if depth == 0 else hidden_channels,
                         hidden_channels, kernel_size, bias)
            for depth in range(num_layers)
        )

    def forward(self, input_tensor, reverse=False):
        """
        Args:
            input_tensor: (batch, seq_len, channels, height, width) sequence.
            reverse: when True, process the time steps from last to first.
        Returns:
            torch.Tensor: (batch, seq_len, hidden_channels, height, width) outputs of
            the top layer, where index t is the output produced at time step t.
        """
        batch_size, seq_len, _, height, width = input_tensor.size()

        # One (h, c) pair per layer, all zeros on the input's device.
        states = [cell.init_hidden(batch_size, (height, width), input_tensor.device)
                  for cell in self.layers]

        visit_order = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
        # Writing each step's output at its own index keeps the original time order.
        step_outputs = [None] * seq_len
        for t in visit_order:
            layer_input = input_tensor[:, t]
            for depth, cell in enumerate(self.layers):
                states[depth] = cell(layer_input, states[depth])
                layer_input = states[depth][0]  # this layer's h feeds the next layer
            step_outputs[t] = layer_input

        return torch.stack(step_outputs, dim=1)

class BidirectionalConvLSTM2D(nn.Module):
    """
    Bidirectional ConvLSTM2D that summarises a sequence into one feature map.

    Two independent ConvLSTM2D stacks read the sequence forwards and backwards.
    The result concatenates, along channels, each direction's final state: the
    forward stack's output after the LAST time step and the backward stack's
    output after the FIRST time step. Both have consumed the whole sequence at
    that point, the same convention as ``h_n`` of a bidirectional
    ``torch.nn.LSTM``.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, num_layers=1, bias=True):
        super(BidirectionalConvLSTM2D, self).__init__()
        self.forward_conv_lstm = ConvLSTM2D(input_channels, hidden_channels, kernel_size, bias=bias, num_layers=num_layers)
        self.backward_conv_lstm = ConvLSTM2D(input_channels, hidden_channels, kernel_size, bias=bias, num_layers=num_layers)

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: (batch, seq_len, channels, height, width).
        Returns:
            torch.Tensor: (batch, 2 * hidden_channels, height, width).
        """
        # Both outputs are (batch, seq_len, hidden_channels, H, W) in original time order.
        forward_output = self.forward_conv_lstm(input_tensor, reverse=False)
        backward_output = self.backward_conv_lstm(input_tensor, reverse=True)
        # The forward pass finishes at t = seq_len - 1; the backward pass finishes at t = 0.
        # Reading the backward output at t = seq_len - 1 instead would give a feature
        # that has only seen the last sequence element.
        return torch.cat([forward_output[:, -1], backward_output[:, 0]], dim=1)

# ============================== Swin Transformer Blocks ================================

class WindowAttention(nn.Module):
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.

    Self-attention is computed independently inside each window of tokens. A
    learned bias, looked up by the relative (row, column) offset between two
    tokens of the same window, is added to the attention logits. An optional
    additive ``mask`` can further suppress specific token pairs per window.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        """
        Args:
            dim (int): Number of input channels.
            window_size (tuple): Height and width of the window.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            attn_drop (float): Dropout ratio of attention weights.
            proj_drop (float): Dropout ratio of output.
        """
        super(WindowAttention, self).__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        # NOTE: forward reshapes C into (num_heads, C // num_heads), so dim is
        # expected to be divisible by num_heads.
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 1/sqrt(head_dim) scaling of q before q @ k^T

        # Define a parameter table of relative position bias
        # (one row per possible (dy, dx) offset, one column per head)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # Get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Row offset * (2*Ww-1) + column offset gives a unique flat index into the bias table.
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # Query, Key, Value
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Initialize relative position bias table
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, Wh*Ww, C)
            mask: (num_windows, Wh*Ww, Wh*Ww) or None; added to the attention
                logits of the matching window before the softmax.
        Returns:
            torch.Tensor: (num_windows*B, Wh*Ww, C) attended features.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # 3, B_, nH, N, C//nH
        q, k, v = qkv[0], qkv[1], qkv[2]  # each has shape (B_, nH, N, C//nH)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B_, nH, N, N)

        # Add relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # (B_, nH, N, N)

        if mask is not None:
            # The mask is broadcast over batch and heads. This view assumes windows
            # are laid out batch-major in x (row index = b * nW + window).
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = F.softmax(attn, dim=-1)
        else:
            attn = F.softmax(attn, dim=-1)

        attn = self.attn_drop(attn)

        # Weighted sum of values, heads merged back into C, then output projection.
        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # (B_, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block with W-MSA and SW-MSA.

    Pre-norm block: LayerNorm -> (shifted) window attention -> residual, then
    LayerNorm -> MLP -> residual. Inputs are token sequences (B, H*W, C) taken
    from a square H x W feature map.

    With ``shift_size > 0`` the (padded) map is cyclically rolled before being
    split into windows (SW-MSA). The roll places tokens from opposite borders in
    the same window, so an additive attention mask keeps them from attending to
    each other, as in the reference Swin Transformer.
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 attn_drop=0., proj_drop=0.):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size  # W
        self.shift_size = shift_size    # S
        self.mlp_ratio = mlp_ratio

        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, (window_size, window_size), num_heads, qkv_bias, attn_drop, proj_drop)

        self.drop_path = nn.Identity()  # Can implement stochastic depth if desired
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(proj_drop)
        )

    @staticmethod
    def _shift_window_mask(height, width, window_size, shift_size, device):
        """
        Build the additive SW-MSA attention mask for a (height, width) map.

        The cyclic shift splits the map into (up to) 3 x 3 regions. Tokens from
        different regions share a window only because of the wrap-around, so their
        pair gets -100 (about zero weight after softmax); same-region pairs get 0.
        ``height`` and ``width`` must be multiples of ``window_size``.

        Returns:
            torch.Tensor: (num_windows, window_size**2, window_size**2), windows in
            row-major order, matching the window partition used in ``forward``.
        """
        region_ids = torch.zeros((1, height, width, 1), device=device)
        bands = (slice(0, -window_size),
                 slice(-window_size, -shift_size),
                 slice(-shift_size, None))
        region = 0
        for h_band in bands:
            for w_band in bands:
                region_ids[:, h_band, w_band, :] = region
                region += 1

        mask_windows = region_ids.view(1, height // window_size, window_size,
                                       width // window_size, window_size, 1)
        mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # (nW, N, N)
        return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

    def forward(self, x):
        """
        Args:
            x: input features with shape (B, H*W, C), H == W
        Returns:
            torch.Tensor: output features with shape (B, H*W, C)
        """
        H = W = int(np.sqrt(x.shape[1]))
        B, L, C = x.shape
        assert L == H * W, "Input feature has wrong size"
        window_size = self.window_size
        shift_size = self.shift_size

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Pad H and W up to multiples of window_size before shifting, so that the
        # roll and the attention mask both operate on the padded grid.
        pad_b = (window_size - H % window_size) % window_size
        pad_r = (window_size - W % window_size) % window_size
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape

        # Cyclic shift (SW-MSA) plus the mask that undoes its wrap-around adjacency
        if shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
            attn_mask = self._shift_window_mask(Hp, Wp, window_size, shift_size, x.device)
        else:
            shifted_x = x
            attn_mask = None

        # Window partition: (num_windows*B, window_size*window_size, C), batch-major
        x_windows = shifted_x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
        x_windows = x_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)

        # Attention
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # Merge windows back into a (B, Hp, Wp, C) map
        shifted_x = attn_windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, C)
        shifted_x = shifted_x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, C)

        # Reverse cyclic shift
        if shift_size > 0:
            x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # Remove padding
        x = x[:, :H, :W, :].contiguous().view(B, H * W, C)

        # Residual + FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

# =============================== Dice Loss Function ====================================

class DiceLoss(nn.Module):
    """
    Soft Dice loss for binary segmentation: ``1 - Dice(y_pred, y_true)``.

    The Dice coefficient is computed once over all elements of the batch (not
    averaged per image). ``smooth`` is added to numerator and denominator so
    that an empty prediction against an empty mask gives a loss of 0, not 0/0.
    """
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred (torch.Tensor): Predicted mask probabilities, e.g. (B, 1, H, W)
            y_true (torch.Tensor): Ground truth masks with the same number of elements
        Returns:
            torch.Tensor: Scalar Dice loss
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted or transposed
        # tensors, are accepted too.
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1)

        intersection = (y_pred * y_true).sum()
        dice = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)

        return 1 - dice

# ================================ Main Model ============================================

class SwinUNet(nn.Module):
    """
    Swin U-Net architecture for image segmentation with bidirectional ConvLSTM layers.

    Encoder: three stages of separable convolutions + batch norm with 2x max
    pooling; Swin Transformer blocks follow the first two poolings. Bottleneck:
    two densely connected separable-conv blocks. Decoder: three transposed-conv
    upsampling stages. At each stage the encoder skip feature and the upsampled
    feature form a length-2 sequence for a bidirectional ConvLSTM, followed by a
    Swin block and two separable convolutions. A final conv head + sigmoid
    produces the probability map.

    Args:
        input_channels (int): Channels of the input image.
        output_channels (int): Channels of the output probability map.
        embed_dim (int): Accepted for interface compatibility; not used.
        num_heads (list[int]): Attention heads; entries 0 and 1 are used.
        window_size (int): Swin window size; every block shifts by window_size // 2.
        mlp_ratio (float): Hidden-size ratio of each Swin MLP.
        depth (int): Accepted for interface compatibility; not used.

    The input must be square with a side divisible by 8 (three 2x poolings).
    Shape comments in ``forward`` are for 256x256 inputs.
    """
    def __init__(self, input_channels=3, output_channels=1,
                 embed_dim=32, num_heads=[4, 8], window_size=4,
                 mlp_ratio=4., depth=2):
        super(SwinUNet, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        # Initial convolutional layers
        self.conv1 = SeparableConv2d(input_channels, 24, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(24)
        self.conv2 = SeparableConv2d(24, 24, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(24)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # S -> S/2

        # First Swin Transformer Block
        self.swin_unet_E1 = SwinTransformerBlock(
            dim=24,  # matches the 24-channel conv stage (embed_dim is not used)
            num_heads=num_heads[0],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )

        # Second convolutional block
        self.conv3 = SeparableConv2d(24, 48, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(48)
        self.conv4 = SeparableConv2d(48, 48, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(48)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # S/2 -> S/4

        # Second Swin Transformer Block
        self.swin_unet_E2 = SwinTransformerBlock(
            dim=48,
            num_heads=num_heads[1],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )

        # Third convolutional block (Bottleneck)
        self.conv5 = SeparableConv2d(48, 96, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(96)
        self.conv6 = SeparableConv2d(96, 96, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(96)
        self.drop5 = nn.Dropout(0.5)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # S/4 -> S/8

        # Bottleneck convolutions with dense connections
        self.conv7 = SeparableConv2d(96, 192, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(192)
        self.conv8 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(192)
        self.drop6_1 = nn.Dropout(0.5)

        self.conv9 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(192)
        self.conv10 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn10 = nn.BatchNorm2d(192)
        self.drop6_2 = nn.Dropout(0.5)

        self.concat1 = nn.Sequential(
            SeparableConv2d(384, 192, kernel_size=3, padding=1),
            SeparableConv2d(192, 192, kernel_size=3, padding=1)
        )
        self.drop6_3 = nn.Dropout(0.5)

        # First Upsampling Block
        self.up1 = nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2)  # S/8 -> S/4
        self.bn_up1 = nn.BatchNorm2d(96)
        self.relu_up1 = nn.ReLU(inplace=True)
        self.bidirectional_convLSTM1 = BidirectionalConvLSTM2D(input_channels=96, hidden_channels=192, kernel_size=3, num_layers=1)
        self.swin_unet_D1 = SwinTransformerBlock(
            dim=192 * 2,  # Adjusted for bidirectional output
            num_heads=num_heads[0],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )
        self.conv11 = SeparableConv2d(192 * 2, 48, kernel_size=3, padding=1)
        self.conv12 = SeparableConv2d(48, 48, kernel_size=3, padding=1)

        # Second Upsampling Block
        self.up2 = nn.ConvTranspose2d(48, 48, kernel_size=2, stride=2)  # S/4 -> S/2
        self.bn_up2 = nn.BatchNorm2d(48)
        self.relu_up2 = nn.ReLU(inplace=True)
        self.bidirectional_convLSTM2 = BidirectionalConvLSTM2D(input_channels=48, hidden_channels=96, kernel_size=3, num_layers=1)
        self.swin_unet_D2 = SwinTransformerBlock(
            dim=96 * 2,
            num_heads=num_heads[1],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )
        self.conv13 = SeparableConv2d(96 * 2, 24, kernel_size=3, padding=1)
        self.conv14 = SeparableConv2d(24, 24, kernel_size=3, padding=1)

        # Third Upsampling Block
        self.up3 = nn.ConvTranspose2d(24, 24, kernel_size=2, stride=2)  # S/2 -> S
        self.bn_up3 = nn.BatchNorm2d(24)
        self.relu_up3 = nn.ReLU(inplace=True)
        self.bidirectional_convLSTM3 = BidirectionalConvLSTM2D(input_channels=24, hidden_channels=48, kernel_size=3, num_layers=1)
        self.swin_unet_D3 = SwinTransformerBlock(
            dim=48 * 2,
            num_heads=num_heads[1],
            window_size=window_size,
            shift_size=window_size // 2,
            mlp_ratio=mlp_ratio
        )
        self.conv15 = SeparableConv2d(48 * 2, 24, kernel_size=3, padding=1)
        self.conv16 = SeparableConv2d(24, 24, kernel_size=3, padding=1)

        # Output Layer
        self.final_conv1 = nn.Conv2d(24, 2, kernel_size=3, padding=1)
        self.final_relu = nn.ReLU(inplace=True)
        self.final_conv2 = nn.Conv2d(2, output_channels, kernel_size=1, padding=0)
        self.final_sigmoid = nn.Sigmoid()

    @staticmethod
    def _map_to_tokens(feature_map):
        """(B, C, H, W) feature map -> (B, H*W, C) token sequence for a Swin block."""
        return feature_map.flatten(2).transpose(1, 2)

    @staticmethod
    def _tokens_to_map(tokens, like):
        """(B, H*W, C) tokens -> (B, C, H, W) map, taking B, H and W from ``like``."""
        batch, _, height, width = like.shape
        return tokens.transpose(1, 2).reshape(batch, tokens.shape[-1], height, width)

    def forward(self, x):
        """
        Forward pass of the Swin U-Net model.
        Args:
            x: Input tensor with shape (B, input_channels, S, S), S divisible by 8
        Returns:
            torch.Tensor: Sigmoid probabilities with shape (B, output_channels, S, S)
        """
        # ---------------- Encoder ----------------
        # Initial Convolutions
        x1 = self.conv1(x)          # (B, 24, 256, 256)
        x1 = self.bn1(x1)
        x1 = self.conv2(x1)         # (B, 24, 256, 256)
        x1 = self.bn2(x1)
        p1 = self.pool1(x1)         # (B, 24, 128, 128)

        # First Swin Transformer Block (map -> tokens -> block -> map)
        swin_E1 = self.swin_unet_E1(self._map_to_tokens(p1))   # (B, 128*128, 24)
        swin_E1 = self._tokens_to_map(swin_E1, like=p1)         # (B, 24, 128, 128)

        # Second Convolutional Block
        x2 = self.conv3(swin_E1)    # (B, 48, 128, 128)
        x2 = self.bn3(x2)
        x2 = self.conv4(x2)         # (B, 48, 128, 128)
        x2 = self.bn4(x2)
        p2 = self.pool2(x2)         # (B, 48, 64, 64)

        # Second Swin Transformer Block
        swin_E2 = self.swin_unet_E2(self._map_to_tokens(p2))   # (B, 64*64, 48)
        swin_E2 = self._tokens_to_map(swin_E2, like=p2)         # (B, 48, 64, 64)

        # Third Convolutional Block (Bottleneck)
        x3 = self.conv5(swin_E2)    # (B, 96, 64, 64)
        x3 = self.bn5(x3)
        x3 = self.conv6(x3)         # (B, 96, 64, 64)
        x3 = self.bn6(x3)
        x3 = self.drop5(x3)
        p3 = self.pool3(x3)         # (B, 96, 32, 32)

        # Bottleneck Convolutions with Dense Connections
        x4 = self.conv7(p3)         # (B, 192, 32, 32)
        x4 = self.bn7(x4)
        x4 = self.conv8(x4)         # (B, 192, 32, 32)
        x4 = self.bn8(x4)
        x4 = self.drop6_1(x4)

        x5 = self.conv9(x4)         # (B, 192, 32, 32)
        x5 = self.bn9(x5)
        x5 = self.conv10(x5)        # (B, 192, 32, 32)
        x5 = self.bn10(x5)
        x5 = self.drop6_2(x5)

        concat = torch.cat([x5, x4], dim=1)  # (B, 384, 32, 32)
        concat = self.concat1(concat)        # (B, 192, 32, 32)
        concat = self.drop6_3(concat)        # (B, 192, 32, 32)

        # ---------------- Decoder ----------------
        # First Upsampling Block
        up1 = self.up1(concat)               # (B, 96, 64, 64)
        up1 = self.bn_up1(up1)
        up1 = self.relu_up1(up1)

        # Skip feature and upsampled feature form a length-2 sequence for the ConvLSTM
        up1_seq = torch.stack([x3, up1], dim=1)                      # (B, 2, 96, 64, 64)
        bidir_convLSTM1_out = self.bidirectional_convLSTM1(up1_seq)  # (B, 192*2, 64, 64)

        # Swin Transformer Block in Decoder
        swin_D1 = self.swin_unet_D1(self._map_to_tokens(bidir_convLSTM1_out))  # (B, 64*64, 192*2)
        swin_D1 = self._tokens_to_map(swin_D1, like=bidir_convLSTM1_out)        # (B, 192*2, 64, 64)

        # Further Convolutions
        conv6 = self.conv11(swin_D1)        # (B, 48, 64, 64)
        conv6 = self.conv12(conv6)          # (B, 48, 64, 64)

        # Second Upsampling Block
        up2 = self.up2(conv6)               # (B, 48, 128, 128)
        up2 = self.bn_up2(up2)
        up2 = self.relu_up2(up2)

        up2_seq = torch.stack([x2, up2], dim=1)                      # (B, 2, 48, 128, 128)
        bidir_convLSTM2_out = self.bidirectional_convLSTM2(up2_seq)  # (B, 96*2, 128, 128)

        swin_D2 = self.swin_unet_D2(self._map_to_tokens(bidir_convLSTM2_out))  # (B, 128*128, 96*2)
        swin_D2 = self._tokens_to_map(swin_D2, like=bidir_convLSTM2_out)        # (B, 96*2, 128, 128)

        conv7 = self.conv13(swin_D2)        # (B, 24, 128, 128)
        conv7 = self.conv14(conv7)          # (B, 24, 128, 128)

        # Third Upsampling Block
        up3 = self.up3(conv7)               # (B, 24, 256, 256)
        up3 = self.bn_up3(up3)
        up3 = self.relu_up3(up3)

        up3_seq = torch.stack([x1, up3], dim=1)                      # (B, 2, 24, 256, 256)
        bidir_convLSTM3_out = self.bidirectional_convLSTM3(up3_seq)  # (B, 48*2, 256, 256)

        swin_D3 = self.swin_unet_D3(self._map_to_tokens(bidir_convLSTM3_out))  # (B, 256*256, 48*2)
        swin_D3 = self._tokens_to_map(swin_D3, like=bidir_convLSTM3_out)        # (B, 48*2, 256, 256)

        conv8 = self.conv15(swin_D3)        # (B, 24, 256, 256)
        conv8 = self.conv16(conv8)          # (B, 24, 256, 256)

        # Final Output Convolutions
        final = self.final_conv1(conv8)      # (B, 2, 256, 256)
        final = self.final_relu(final)
        final = self.final_conv2(final)      # (B, output_channels, 256, 256)
        final = self.final_sigmoid(final)    # (B, output_channels, 256, 256)

        return final

# ================================== Dataset Class ======================================

class SegmentationDataset(Dataset):
    """
    Custom Dataset for image segmentation tasks.

    Pairs the i-th image of ``images_dir`` with the i-th mask of ``masks_dir``
    after both folder listings have been filtered to image files and sorted,
    so the two folders must hold the same number of images whose sorted
    orders correspond one-to-one.

    Args:
        images_dir (str): Folder containing the input images.
        masks_dir (str): Folder containing the ground-truth masks.
        transform (callable, optional): Applied to every image, and also to
            every mask unless ``mask_transform`` is given (original behaviour).
        mask_transform (callable, optional): Applied to masks instead of
            ``transform`` -- e.g. a Resize with nearest-neighbour
            interpolation so resized masks stay binary.

    Raises:
        ValueError: If the two folders contain a different number of images.
    """
    # Only files with these (case-insensitive) extensions are loaded; any
    # other file in the folders (e.g. text files shipped next to the images)
    # is ignored so it can neither crash Image.open nor shift the pairing.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff')

    def __init__(self, images_dir, masks_dir, transform=None, mask_transform=None):
        super(SegmentationDataset, self).__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        # Fall back to the shared transform so existing callers behave as before
        self.mask_transform = transform if mask_transform is None else mask_transform

        self.images = self._list_images(images_dir)
        self.masks = self._list_images(masks_dir)
        # Pairing is positional, so a count mismatch means mismatched pairs
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but "
                f"{len(self.masks)} masks in {masks_dir!r}"
            )

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted names of the image files in ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load image; `with` closes the file handle once pixels are converted
        img_path = os.path.join(self.images_dir, self.images[idx])
        with Image.open(img_path) as img:
            image = img.convert('RGB')  # Ensure RGB

        # Load mask
        mask_path = os.path.join(self.masks_dir, self.masks[idx])
        with Image.open(mask_path) as msk:
            mask = msk.convert('L')    # Grayscale

        # Apply transformations
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

# =============================== Data Loading and Preprocessing ========================

# Define image dimensions
im_height = 256
im_width = 256

# Seed for the train/validation split so it is identical on every run
SPLIT_SEED = 42

# Define transformations
# NOTE: SegmentationDataset applies this transform to the masks as well.
# Resize interpolates bilinearly by default, so resized masks get fractional
# values along object borders -- threshold masks (e.g. > 0.5) before scoring.
transform = transforms.Compose([
    transforms.Resize((im_height, im_width)),
    transforms.ToTensor(),  # Converts to [0,1]
])

# Paths to the dataset (update these paths as per your directory structure)
train_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Training_Input'
train_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Training_GroundTruth'
test_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Test_Input'
test_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Test_GroundTruth'

# Create datasets
train_dataset = SegmentationDataset(train_images_dir, train_masks_dir, transform=transform)
test_dataset = SegmentationDataset(test_images_dir, test_masks_dir, transform=transform)

# Split training data into training and validation sets (80-20 split).
# A dedicated seeded generator makes the split reproducible without touching
# the global RNG state.
train_size = int(0.8 * len(train_dataset))
valid_size = len(train_dataset) - train_size
train_subset, valid_subset = torch.utils.data.random_split(
    train_dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoaders
batch_size = 5

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# =============================== Training Setup ==========================================

# Pick the GPU when one is available, otherwise run on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the Swin U-Net and place its parameters on the chosen device
model = SwinUNet(
    input_channels=3,
    output_channels=1,
    embed_dim=32,
    num_heads=[4, 8],
    window_size=4,
    mlp_ratio=4.,
    depth=2,
).to(device)

# Initialize weights using Kaiming Normal initialization
def initialize_weights(module):
    """
    Kaiming-normal (fan_in, ReLU gain) initialisation hook for ``model.apply``.

    Every nn.Conv2d, nn.ConvTranspose2d, nn.Linear and nn.Conv3d receives a
    Kaiming-normal weight and, when it has a bias, a zero bias. Modules of
    any other type are left untouched.
    """
    if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear, nn.Conv3d)):
        return
    nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
    if module.bias is not None:
        nn.init.zeros_(module.bias)

model.apply(initialize_weights)

# Define loss function and optimizer
criterion = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Define learning rate scheduler and early stopping parameters.
# The LR is multiplied by 0.25 after 5 epochs without validation improvement.
# `verbose` is deprecated for ReduceLROnPlateau since PyTorch 2.2, so it is
# not passed; read the current LR from `optimizer.param_groups` instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.25, patience=5, min_lr=1e-9
)
early_stopping_patience = 9
best_val_loss = np.inf
epochs_no_improve = 0

# =============================== Training Loop ===========================================

num_epochs = 1  # You can adjust the number of epochs
log_every = 50  # Print a progress line every `log_every` training batches

# Where the best (lowest validation loss) weights are written
best_model_path = r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth'
# torch.save does not create missing parent directories
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    model.train()
    running_loss = 0.0
    num_batches = len(train_loader)
    for batch_idx, (images, masks) in enumerate(train_loader, start=1):
        images = images.to(device)  # (B, 3, 256, 256)
        masks = masks.to(device)    # (B, 1, 256, 256)

        optimizer.zero_grad()
        outputs = model(images)      # (B, 1, 256, 256)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so the epoch mean is per sample even when the
        # last batch is smaller
        running_loss += batch_loss * images.size(0)

        # Periodic progress instead of one line per batch
        if batch_idx % log_every == 0 or batch_idx == num_batches:
            print(f"  Batch {batch_idx}/{num_batches}, loss: {batch_loss:.4f}")

    epoch_loss = running_loss / len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in valid_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            val_loss += loss.item() * images.size(0)

    val_loss /= len(valid_loader.dataset)

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Scheduler step (ReduceLROnPlateau watches the validation loss)
    scheduler.step(val_loss)
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        # Save the best model
        torch.save(model.state_dict(), best_model_path)
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= early_stopping_patience:
            print('Early stopping!')
            break

# ================================== Prediction ==========================================

# Load the best model weights.
# map_location lets a checkpoint saved from the GPU load on a CPU-only runtime;
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
model.load_state_dict(torch.load(
    r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth',
    map_location=device,
    weights_only=True,
))
model.eval()

# Function to save predictions and ground truth
def save_predictions(model, dataloader, save_dir_pred, save_dir_gt, device):
    """
    Saves the predicted masks and ground truth masks.

    Predicted probabilities are scaled by 255 and written as 8-bit PNGs, and
    ground-truth masks are written the same way as TIFFs. Files are numbered
    1, 2, 3, ... in dataloader order, so a prediction and its ground truth
    share the same number.

    Args:
        model (nn.Module): Trained model.
        dataloader (DataLoader): DataLoader for test data.
        save_dir_pred (str): Directory to save predicted masks.
        save_dir_gt (str): Directory to save ground truth masks.
        device (str): Device to run the model on.
    """
    os.makedirs(save_dir_pred, exist_ok=True)
    os.makedirs(save_dir_gt, exist_ok=True)

    # Inference mode: dropout off, BatchNorm uses running statistics
    model.eval()
    saved = 0  # samples written so far; drives the file numbering
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(dataloader):
            if batch_idx % 100 == 0:
                print(f"Batch {batch_idx}: {saved} test images saved so far")

            preds = model(images.to(device)).cpu().numpy()
            # Ground truth is only written out, so it never needs the device
            gts = masks.cpu().numpy()

            for pred_mask, gt_mask in zip(preds[:, 0], gts[:, 0]):
                saved += 1

                # Save predicted mask
                pred_img = Image.fromarray((pred_mask * 255).astype(np.uint8))
                pred_img.save(os.path.join(save_dir_pred, f"{saved}.png"))

                # Save ground truth mask
                gt_img = Image.fromarray((gt_mask * 255).astype(np.uint8))
                gt_img.save(os.path.join(save_dir_gt, f"{saved}.tiff"))

# Output folders for the predicted masks and their matching ground truth
save_dir_pred = r'/content/drive/MyDrive/output/segmented predicted images'
save_dir_gt = r'/content/drive/MyDrive/output/segmented ground truth'

# Write every test-set prediction and ground-truth mask to disk
save_predictions(
    model,
    test_loader,
    save_dir_pred=save_dir_pred,
    save_dir_gt=save_dir_gt,
    device=device,
)

# =================================== Evaluation =========================================

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator > 0 else 0.0


def _binary_segmentation_metrics(pred, mask):
    """
    Per-image confusion-matrix metrics for two equally sized binary masks.

    Undefined ratios (zero denominator) are reported as 0.0, which matches
    sklearn's ``zero_division=0`` behaviour.
    """
    pred = np.asarray(pred, dtype=bool).ravel()
    mask = np.asarray(mask, dtype=bool).ravel()
    tp = int(np.count_nonzero(pred & mask))
    fp = int(np.count_nonzero(pred & ~mask))
    fn = int(np.count_nonzero(~pred & mask))
    tn = pred.size - tp - fp - fn

    # For binary masks Dice == F1 and Sensitivity == Recall
    dice = _safe_ratio(2 * tp, 2 * tp + fp + fn)
    sensitivity = _safe_ratio(tp, tp + fn)
    return {
        'Accuracy': _safe_ratio(tp + tn, pred.size),
        'Dice': dice,
        'Jaccard': _safe_ratio(tp, tp + fp + fn),
        'Sensitivity': sensitivity,
        'Specificity': _safe_ratio(tn, tn + fp),
        'Precision': _safe_ratio(tp, tp + fp),
        'Recall': sensitivity,
        'F1-Score': dice,
    }


def evaluate_metrics_pytorch(model, dataloader, device, threshold=0.5):
    """
    Evaluates various metrics for segmentation performance.

    Metrics are computed per image and then averaged over all images.
    Predictions are binarised with ``outputs > threshold``. Ground-truth masks
    are binarised with ``masks > 0.5`` on their float values *before* any
    integer cast: a resized mask holds fractional values on object borders,
    and casting to uint8 first would truncate e.g. 0.99 to 0.

    Args:
        model (nn.Module): Trained model returning probabilities (B, 1, H, W).
        dataloader (DataLoader): Iterable of (images, masks) test batches.
        device (str): Device to run the model on.
        threshold (float): Probability threshold for a foreground prediction.
    Returns:
        dict: Metric name -> mean over all images (NaN if there were none).
    """
    model.eval()
    metric_names = ('Accuracy', 'Dice', 'Jaccard', 'Sensitivity',
                    'Specificity', 'Precision', 'Recall', 'F1-Score')
    all_metrics = {name: [] for name in metric_names}

    with torch.no_grad():
        for images, masks in dataloader:
            outputs = model(images.to(device))
            preds = (outputs > threshold).cpu().numpy()  # Binary mask
            gts = (masks > 0.5).cpu().numpy()

            for pred, gt in zip(preds, gts):
                for name, value in _binary_segmentation_metrics(pred, gt).items():
                    all_metrics[name].append(value)

    # Compute average metrics (NaN instead of a mean-of-empty warning)
    avg_metrics = {
        metric: (np.mean(values) if values else float('nan'))
        for metric, values in all_metrics.items()
    }

    print("Evaluation Metrics:")
    for metric, value in avg_metrics.items():
        print(f"{metric}: {value:.4f}")

    return avg_metrics

# Compute and print the averaged segmentation metrics on the test set
metrics = evaluate_metrics_pytorch(model, test_loader, device=device)


Epoch 1/1




Train loader image count: 1
Train loader image count: 2
Train loader image count: 3
Train loader image count: 4
Train loader image count: 5
Train loader image count: 6
Train loader image count: 7
Train loader image count: 8
Train loader image count: 9
Train loader image count: 10
Train loader image count: 11
Train loader image count: 12
Train loader image count: 13
Train loader image count: 14
Train loader image count: 15
Train loader image count: 16
Train loader image count: 17
Train loader image count: 18
Train loader image count: 19
Train loader image count: 20
Train loader image count: 21
Train loader image count: 22
Train loader image count: 23
Train loader image count: 24
Train loader image count: 25
Train loader image count: 26
Train loader image count: 27
Train loader image count: 28
Train loader image count: 29
Train loader image count: 30
Train loader image count: 31
Train loader image count: 32
Train loader image count: 33
Train loader image count: 34
Train loader image coun

  model.load_state_dict(torch.load(r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth'))


0th Test Image
100th Test Image
Evaluation Metrics:
Accuracy: 0.7963
Dice: 0.4221
Jaccard: 0.3285
Sensitivity: 0.3730
Specificity: 0.9808
Precision: 0.7010
Recall: 0.3730
F1-Score: 0.4221


In [None]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
from torchinfo import summary

# Layer-by-layer summary for one 3-channel 256x256 image
# (shape order: batch_size, channels, height, width)
input_size = (1, 3, 256, 256)

# Last expression of the cell, so the summary is rendered as the cell output
summary(
    model,
    input_size=input_size,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

Layer (type:depth-idx)                   Output Shape              Param #
SwinUNet                                 [1, 1, 256, 256]          --
├─SeparableConv2d: 1-1                   [1, 24, 256, 256]         --
│    └─Conv2d: 2-1                       [1, 3, 256, 256]          30
│    └─Conv2d: 2-2                       [1, 24, 256, 256]         96
├─BatchNorm2d: 1-2                       [1, 24, 256, 256]         48
├─SeparableConv2d: 1-3                   [1, 24, 256, 256]         --
│    └─Conv2d: 2-3                       [1, 24, 256, 256]         240
│    └─Conv2d: 2-4                       [1, 24, 256, 256]         600
├─BatchNorm2d: 1-4                       [1, 24, 256, 256]         48
├─MaxPool2d: 1-5                         [1, 24, 128, 128]         --
├─SwinTransformerBlock: 1-6              [1, 16384, 24]            --
│    └─LayerNorm: 2-5                    [1, 16384, 24]            48
│    └─WindowAttention: 2-6              [1024, 16, 24]            196
│    │    └─

In [None]:
# ============================== Imports and Dependencies ==============================

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    precision_recall_curve, confusion_matrix,
    jaccard_score, f1_score, precision_score, recall_score
)

# ================================ Separable Convolution =================================

class SeparableConv2d(nn.Module):
    """
    Depthwise-separable 2D convolution.

    A per-channel (depthwise) ``kernel_size`` convolution followed by a 1x1
    (pointwise) convolution that mixes channels and maps ``in_channels`` to
    ``out_channels``. With ``kernel_size=3, padding=1`` the spatial size is
    preserved.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, bias=True):
        super(SeparableConv2d, self).__init__()
        # groups == in_channels gives every input channel its own filter
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        # 1x1 convolution that combines the per-channel responses
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            bias=bias,
        )

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

# ================================== ConvLSTM2D ========================================

class ConvLSTMCell(nn.Module):
    """
    Single ConvLSTM cell: an LSTM whose gate transforms are 2D convolutions.

    One convolution over ``cat([input, h])`` yields all four gate
    pre-activations at once, stacked along channels in the order
    input, forget, output, candidate. ``padding = kernel_size // 2`` keeps
    the spatial size unchanged for odd kernel sizes.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """
        Advance the cell by one time step.

        Args:
            input_tensor: (B, input_channels, H, W) input for this step.
            cur_state: tuple (h, c), each (B, hidden_channels, H, W).
        Returns:
            tuple (h_next, c_next) shaped like ``cur_state``.
        """
        h_prev, c_prev = cur_state
        gates = self.conv(torch.cat([input_tensor, h_prev], dim=1))
        in_gate, forget_gate, out_gate, candidate = gates.chunk(4, dim=1)

        c_next = torch.sigmoid(forget_gate) * c_prev + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, spatial_size, device):
        """Return fresh zero (h, c) tensors of shape (batch_size, hidden_channels, H, W)."""
        height, width = spatial_size
        shape = (batch_size, self.hidden_channels, height, width)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)

class ConvLSTM2D(nn.Module):
    """
    Stack of ConvLSTM cells run over a sequence of feature maps.

    Args:
        input_channels (int): Channels of each input frame.
        hidden_channels (int): Hidden/cell-state channels of every layer.
        kernel_size (int): Convolution kernel size of every cell.
        bias (bool): Whether the gate convolutions use a bias.
        num_layers (int): Number of stacked cells; layer k > 0 consumes the
            hidden state of layer k - 1.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, bias=True, num_layers=1):
        super(ConvLSTM2D, self).__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.layers = nn.ModuleList(
            ConvLSTMCell(input_channels if idx == 0 else hidden_channels,
                         hidden_channels, kernel_size, bias)
            for idx in range(num_layers)
        )

    def forward(self, input_tensor, reverse=False):
        """
        Run the stack over the whole sequence.

        Args:
            input_tensor: (batch, seq_len, channels, height, width).
            reverse: If True, time steps are processed last-to-first and the
                returned sequence is flipped back, so index t still lines up
                with input step t.
        Returns:
            (batch, seq_len, hidden_channels, height, width): top-layer hidden
            state at every step.
        """
        batch_size, seq_len, _, height, width = input_tensor.size()
        # One zero-initialised (h, c) pair per layer
        states = [cell.init_hidden(batch_size, (height, width), input_tensor.device)
                  for cell in self.layers]

        order = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
        step_outputs = []
        for t in order:
            layer_input = input_tensor[:, t]
            for idx, cell in enumerate(self.layers):
                states[idx] = cell(layer_input, states[idx])
                layer_input = states[idx][0]  # hidden state feeds the next layer
            step_outputs.append(layer_input)

        stacked = torch.stack(step_outputs, dim=1)
        return stacked.flip(dims=[1]) if reverse else stacked

class BidirectionalConvLSTM2D(nn.Module):
    """
    Implements a Bidirectional ConvLSTM2D layer.

    One ConvLSTM2D reads the sequence first-to-last and a second one reads it
    last-to-first. The result concatenates, along channels, the output of
    each direction after it has consumed the *whole* sequence:

    * forward direction  -> time index ``-1`` of its output sequence;
    * backward direction -> time index ``0`` (ConvLSTM2D flips a reversed pass
      back to input order, so index 0 is the backward pass's final step).

    Taking the backward output at index ``-1`` would return the state after
    the backward pass had read only the last frame, so the backward half
    would never see the earlier frames. Parameters are unchanged, so older
    checkpoints still load, but their outputs differ and retraining is
    advisable.

    Args:
        input_channels (int): Channels of each frame.
        hidden_channels (int): Hidden channels per direction.
        kernel_size (int): ConvLSTM kernel size.
        num_layers (int): Stacked ConvLSTM cells per direction.
        bias (bool): Whether the gate convolutions use a bias.

    Input:  (batch, seq_len, input_channels, H, W)
    Output: (batch, 2 * hidden_channels, H, W)
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, num_layers=1, bias=True):
        super(BidirectionalConvLSTM2D, self).__init__()
        self.forward_conv_lstm = ConvLSTM2D(input_channels, hidden_channels, kernel_size, bias=bias, num_layers=num_layers)
        self.backward_conv_lstm = ConvLSTM2D(input_channels, hidden_channels, kernel_size, bias=bias, num_layers=num_layers)

    def forward(self, input_tensor):
        # (batch, seq_len, hidden_channels, H, W), aligned with input order
        forward_output = self.forward_conv_lstm(input_tensor, reverse=False)
        backward_output = self.backward_conv_lstm(input_tensor, reverse=True)
        # Final state of each direction: forward at the last step,
        # backward at the first step (its last processed frame)
        return torch.cat([forward_output[:, -1], backward_output[:, 0]], dim=1)

# ============================== Swin Transformer Blocks ================================

class WindowAttention(nn.Module):
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.

    Attention is computed independently inside each window of
    ``window_size[0] * window_size[1]`` tokens. A learned per-head bias, looked
    up by the relative (row, col) offset between query and key token, is added
    to the attention logits, and an optional additive ``mask`` can suppress
    selected token pairs (e.g. for shifted windows).
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        """
        Args:
            dim (int): Number of input channels. Split into ``num_heads`` heads
                of ``dim // num_heads`` channels; the reshape in ``forward``
                fails unless ``dim`` is divisible by ``num_heads``.
            window_size (tuple): Height and width of the window (indexed as
                ``window_size[0]`` and ``window_size[1]``).
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            attn_drop (float): Dropout ratio of attention weights.
            proj_drop (float): Dropout ratio of output.
        """
        super(WindowAttention, self).__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # 1 / sqrt(head_dim) scaling of the dot-product logits
        self.scale = head_dim ** -0.5

        # Define a parameter table of relative position bias: one row per
        # possible (dy, dx) offset inside a window, one column per head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # Get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        # Signed (dy, dx) offset between every pair of tokens in the window
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Row-major flattening: index = dy_shifted * (2*Ww - 1) + dx_shifted
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # A buffer, not a parameter: not trained, but moved by .to(device) and
        # stored in the state_dict
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # Query, Key, Value
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Initialize relative position bias table
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, Wh*Ww, C). Windows
                must be ordered batch-major (all windows of sample 0, then
                sample 1, ...) for the mask broadcast below to line up.
            mask: (num_windows, Wh*Ww, Wh*Ww) or None. Added to every head's
                logits before the softmax; presumably 0 for allowed pairs and
                a large negative value for blocked ones.
        Returns:
            Tensor of shape (num_windows*B, Wh*Ww, C).
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # 3, B_, nH, N, C//nH
        q, k, v = qkv[0], qkv[1], qkv[2]  # each has shape (B_, nH, N, C//nH)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B_, nH, N, N)

        # Add relative position bias (the view below requires N == Wh*Ww)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # (B_, nH, N, N)

        if mask is not None:
            nW = mask.shape[0]
            # Broadcast the per-window mask over the batch and head dimensions
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = F.softmax(attn, dim=-1)
        else:
            attn = F.softmax(attn, dim=-1)

        attn = self.attn_drop(attn)

        # Weighted sum of values, heads concatenated back into C channels
        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # (B_, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block with W-MSA and SW-MSA.

    Pre-norm residual block on a square token grid:
        x = x + WindowAttention(LayerNorm(x))
        x = x + MLP(LayerNorm(x))

    With ``shift_size > 0`` the (padded) map is cyclically rolled by
    ``-shift_size`` before it is split into windows. The roll wraps tokens
    from the bottom/right border into windows next to tokens from the
    top/left border, so an additive attention mask (-100 between tokens that
    came from different pre-roll regions) is passed to the attention to keep
    it spatially local. Without that mask, shifted windows would mix features
    from opposite image borders.

    Args:
        dim (int): Channel dimension C of the tokens.
        num_heads (int): Number of attention heads.
        window_size (int): Side length of the square attention windows.
        shift_size (int): Cyclic shift for SW-MSA; 0 means plain W-MSA.
        mlp_ratio (float): MLP hidden size relative to ``dim``.
        qkv_bias (bool): Learnable bias on the query/key/value projection.
        attn_drop (float): Dropout on attention weights.
        proj_drop (float): Dropout on the attention and MLP outputs.
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 attn_drop=0., proj_drop=0.):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size  # W
        self.shift_size = shift_size    # S
        self.mlp_ratio = mlp_ratio

        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, (window_size, window_size), num_heads, qkv_bias, attn_drop, proj_drop)

        self.drop_path = nn.Identity()  # Can implement stochastic depth if desired
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(proj_drop)
        )

    def _shifted_window_mask(self, Hp, Wp, shift, device, dtype):
        """
        Build the SW-MSA attention mask for a padded (Hp, Wp) map.

        Returns a (num_windows, ws*ws, ws*ws) tensor holding 0 for token pairs
        from the same pre-roll region and -100 otherwise. Windows are ordered
        row-major, matching the window partition in ``forward``.
        """
        ws = self.window_size
        # Label the 3x3 grid of regions that the cyclic roll stitches together
        region = torch.zeros((Hp, Wp), device=device)
        slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
        label = 0
        for h_slice in slices:
            for w_slice in slices:
                region[h_slice, w_slice] = label
                label += 1

        # Same partition as the feature map: (Hp/ws, ws, Wp/ws, ws) -> windows
        region_windows = (region.view(Hp // ws, ws, Wp // ws, ws)
                          .permute(0, 2, 1, 3)
                          .reshape(-1, ws * ws))
        diff = region_windows.unsqueeze(1) - region_windows.unsqueeze(2)
        return torch.zeros_like(diff, dtype=dtype).masked_fill(diff != 0, -100.0)

    def forward(self, x):
        """
        Args:
            x: input features with shape (B, H*W, C); the grid is assumed square.
        Returns:
            Tensor of shape (B, H*W, C).
        """
        H = W = int(np.sqrt(x.shape[1]))
        B, L, C = x.shape
        assert L == H * W, "Input feature has wrong size"
        ws = self.window_size

        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # Pad H and W up to multiples of the window size *before* shifting, so
        # the roll and the mask both operate on the same padded grid
        pad_b = (ws - H % ws) % ws
        pad_r = (ws - W % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape

        # A map that fits in a single window has nothing to shift
        shift = self.shift_size if min(Hp, Wp) > ws else 0

        # Cyclic shift + matching attention mask
        if shift > 0:
            x = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
            attn_mask = self._shifted_window_mask(Hp, Wp, shift, x.device, x.dtype)
        else:
            attn_mask = None

        # Window partition -> (num_windows*B, ws*ws, C), batch-major
        x_windows = (x.view(B, Hp // ws, ws, Wp // ws, ws, C)
                     .permute(0, 1, 3, 2, 4, 5)
                     .contiguous()
                     .view(-1, ws * ws, C))

        # Attention
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # Merge windows back to (B, Hp, Wp, C)
        x = (attn_windows.view(B, Hp // ws, Wp // ws, ws, ws, C)
             .permute(0, 1, 3, 2, 4, 5)
             .contiguous()
             .view(B, Hp, Wp, C))

        # Reverse cyclic shift
        if shift > 0:
            x = torch.roll(x, shifts=(shift, shift), dims=(1, 2))

        # Remove padding
        x = x[:, :H, :W, :].contiguous().view(B, H * W, C)

        # Residual connections around attention and MLP
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

# =============================== Dice Loss Function ====================================

class DiceLoss(nn.Module):
    """
    Dice Loss function to maximize the Dice coefficient.
    Suitable for binary segmentation tasks.

    The soft Dice coefficient is computed over all elements of the whole
    batch at once (not averaged per sample):

        Dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    and the loss is ``1 - Dice``. ``smooth`` keeps the ratio defined, and
    equal to 1, when both prediction and target are all zeros.
    """
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred (torch.Tensor): Predicted mask probabilities with shape (B, 1, H, W)
            y_true (torch.Tensor): Ground truth masks with shape (B, 1, H, W)
        Returns:
            torch.Tensor: Dice loss (scalar)
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # sliced tensors, are accepted instead of raising
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1)

        intersection = (y_pred * y_true).sum()
        dice = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)

        return 1 - dice

# ================================ Main Model ============================================

class SwinUNet(nn.Module):
    """
    Swin U-Net architecture for image segmentation with bidirectional ConvLSTM skip fusion.

    Encoder: three separable-conv stages (24 -> 48 -> 96 channels), each followed by
    2x max-pooling; the first two pooled maps also pass through a Swin Transformer block.
    Bottleneck: two 192-channel conv pairs whose outputs are concatenated (dense
    connection) and fused back to 192 channels.
    Decoder: each stage upsamples with a transposed conv, runs a BidirectionalConvLSTM2D
    over the 2-step sequence [encoder skip, upsampled map], then a Swin block and two
    separable convs. The head ends in a sigmoid, so outputs are per-pixel probabilities.

    Because the skip and upsampled maps are stacked, the input height and width must
    survive three 2x pool / 2x upsample round trips (i.e. be divisible by 8).

    Args:
        input_channels (int): Channels of the input image.
        output_channels (int): Channels of the output mask (default 1).
        embed_dim (int): Accepted for interface compatibility; not used (the first Swin
            block runs at 24 channels).
        num_heads (list[int]): Attention heads; entries 0 and 1 are used by different
            blocks. Only read, never mutated.
        window_size (int): Swin window size; every Swin block shifts by window_size // 2.
        mlp_ratio (float): Forwarded to every SwinTransformerBlock.
        depth (int): Accepted for interface compatibility; not used.
    """
    def __init__(self, input_channels=3, output_channels=1,
                 embed_dim=32, num_heads=[4, 8], window_size=4,
                 mlp_ratio=4., depth=2):
        super(SwinUNet, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        shift_size = window_size // 2  # all Swin blocks below use shifted windows

        # Initial convolutional layers
        # NOTE(review): the encoder/bottleneck convs are followed by BatchNorm but no
        # activation; confirm this is intended.
        self.conv1 = SeparableConv2d(input_channels, 24, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(24)
        self.conv2 = SeparableConv2d(24, 24, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(24)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 256x256 -> 128x128

        # First Swin Transformer Block
        self.swin_unet_E1 = SwinTransformerBlock(
            dim=24,  # Changed from embed_dim=32 to 24
            num_heads=num_heads[0],
            window_size=window_size,
            shift_size=shift_size,
            mlp_ratio=mlp_ratio
        )

        # Second convolutional block
        self.conv3 = SeparableConv2d(24, 48, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(48)
        self.conv4 = SeparableConv2d(48, 48, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(48)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 128x128 -> 64x64

        # Second Swin Transformer Block
        self.swin_unet_E2 = SwinTransformerBlock(
            dim=48,
            num_heads=num_heads[1],
            window_size=window_size,
            shift_size=shift_size,
            mlp_ratio=mlp_ratio
        )

        # Third convolutional block (Bottleneck)
        self.conv5 = SeparableConv2d(48, 96, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(96)
        self.conv6 = SeparableConv2d(96, 96, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(96)
        self.drop5 = nn.Dropout(0.5)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x64 -> 32x32

        # Bottleneck convolutions with dense connections
        self.conv7 = SeparableConv2d(96, 192, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(192)
        self.conv8 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(192)
        self.drop6_1 = nn.Dropout(0.5)

        self.conv9 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(192)
        self.conv10 = SeparableConv2d(192, 192, kernel_size=3, padding=1)
        self.bn10 = nn.BatchNorm2d(192)
        self.drop6_2 = nn.Dropout(0.5)

        self.concat1 = nn.Sequential(
            SeparableConv2d(384, 192, kernel_size=3, padding=1),
            SeparableConv2d(192, 192, kernel_size=3, padding=1)
        )
        self.drop6_3 = nn.Dropout(0.5)

        # First Upsampling Block
        self.up1 = nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2)  # 32x32 -> 64x64
        self.bn_up1 = nn.BatchNorm2d(96)
        self.relu_up1 = nn.ReLU(inplace=True)
        self.bidirectional_convLSTM1 = BidirectionalConvLSTM2D(input_channels=96, hidden_channels=192, kernel_size=3, num_layers=1)
        self.swin_unet_D1 = SwinTransformerBlock(
            dim=192 * 2,  # Adjusted for bidirectional output
            num_heads=num_heads[0],
            window_size=window_size,
            shift_size=shift_size,
            mlp_ratio=mlp_ratio
        )
        self.conv11 = SeparableConv2d(192 * 2, 48, kernel_size=3, padding=1)
        self.conv12 = SeparableConv2d(48, 48, kernel_size=3, padding=1)

        # Second Upsampling Block
        self.up2 = nn.ConvTranspose2d(48, 48, kernel_size=2, stride=2)  # 64x64 -> 128x128
        self.bn_up2 = nn.BatchNorm2d(48)
        self.relu_up2 = nn.ReLU(inplace=True)
        self.bidirectional_convLSTM2 = BidirectionalConvLSTM2D(input_channels=48, hidden_channels=96, kernel_size=3, num_layers=1)
        self.swin_unet_D2 = SwinTransformerBlock(
            dim=96 * 2,
            num_heads=num_heads[1],
            window_size=window_size,
            shift_size=shift_size,
            mlp_ratio=mlp_ratio
        )
        self.conv13 = SeparableConv2d(96 * 2, 24, kernel_size=3, padding=1)
        self.conv14 = SeparableConv2d(24, 24, kernel_size=3, padding=1)

        # Third Upsampling Block
        self.up3 = nn.ConvTranspose2d(24, 24, kernel_size=2, stride=2)  # 128x128 -> 256x256
        self.bn_up3 = nn.BatchNorm2d(24)
        self.relu_up3 = nn.ReLU(inplace=True)
        self.bidirectional_convLSTM3 = BidirectionalConvLSTM2D(input_channels=24, hidden_channels=48, kernel_size=3, num_layers=1)
        self.swin_unet_D3 = SwinTransformerBlock(
            dim=48 * 2,
            num_heads=num_heads[1],
            window_size=window_size,
            shift_size=shift_size,
            mlp_ratio=mlp_ratio
        )
        self.conv15 = SeparableConv2d(48 * 2, 24, kernel_size=3, padding=1)
        self.conv16 = SeparableConv2d(24, 24, kernel_size=3, padding=1)

        # Output Layer
        self.final_conv1 = nn.Conv2d(24, 2, kernel_size=3, padding=1)
        self.final_relu = nn.ReLU(inplace=True)
        self.final_conv2 = nn.Conv2d(2, output_channels, kernel_size=1, padding=0)
        self.final_sigmoid = nn.Sigmoid()

    def _swin_2d(self, block, feature_map):
        """Run a token-based Swin block on a (B, C, H, W) map and return a (B, C, H, W) map."""
        batch, channels, height, width = feature_map.shape
        tokens = feature_map.flatten(2).transpose(1, 2)  # (B, H*W, C)
        # SwinTransformerBlock.forward needs the spatial size to un-flatten the tokens.
        tokens = block(tokens, height, width)            # (B, H*W, C)
        return tokens.transpose(1, 2).reshape(batch, channels, height, width)

    def forward(self, x):
        """
        Forward pass of the Swin U-Net model.
        Args:
            x: Input tensor with shape (B, input_channels, H, W); the shape comments
               below assume a 256x256 input.
        Returns:
            torch.Tensor: Sigmoid mask probabilities with shape (B, output_channels, H, W)
        """
        # Initial Convolutions
        x1 = self.conv1(x)          # (B, 24, 256, 256)
        x1 = self.bn1(x1)
        x1 = self.conv2(x1)         # (B, 24, 256, 256)
        x1 = self.bn2(x1)
        p1 = self.pool1(x1)         # (B, 24, 128, 128)

        # First Swin Transformer Block
        swin_E1 = self._swin_2d(self.swin_unet_E1, p1)  # (B, 24, 128, 128)

        # Second Convolutional Block
        x2 = self.conv3(swin_E1)    # (B, 48, 128, 128)
        x2 = self.bn3(x2)
        x2 = self.conv4(x2)          # (B, 48, 128, 128)
        x2 = self.bn4(x2)
        p2 = self.pool2(x2)          # (B, 48, 64, 64)

        # Second Swin Transformer Block
        swin_E2 = self._swin_2d(self.swin_unet_E2, p2)  # (B, 48, 64, 64)

        # Third Convolutional Block (Bottleneck)
        x3 = self.conv5(swin_E2)    # (B, 96, 64, 64)
        x3 = self.bn5(x3)
        x3 = self.conv6(x3)          # (B, 96, 64, 64)
        x3 = self.bn6(x3)
        x3 = self.drop5(x3)
        p3 = self.pool3(x3)          # (B, 96, 32, 32)

        # Bottleneck Convolutions with Dense Connections
        x4 = self.conv7(p3)          # (B, 192, 32, 32)
        x4 = self.bn7(x4)
        x4 = self.conv8(x4)          # (B, 192, 32, 32)
        x4 = self.bn8(x4)
        x4 = self.drop6_1(x4)

        x5 = self.conv9(x4)          # (B, 192, 32, 32)
        x5 = self.bn9(x5)
        x5 = self.conv10(x5)         # (B, 192, 32, 32)
        x5 = self.bn10(x5)
        x5 = self.drop6_2(x5)

        concat = torch.cat([x5, x4], dim=1)  # (B, 384, 32, 32)
        concat = self.concat1(concat)         # (B, 192, 32, 32)
        concat = self.drop6_3(concat)         # (B, 192, 32, 32)

        # First Upsampling Block
        up1 = self.up1(concat)                 # (B, 96, 64, 64)
        up1 = self.bn_up1(up1)
        up1 = self.relu_up1(up1)

        # Fuse skip and upsampled maps as a 2-step sequence for the ConvLSTM
        up1_seq = torch.stack([x3, up1], dim=1)  # (B, 2, 96, 64, 64)
        bidir_convLSTM1_out = self.bidirectional_convLSTM1(up1_seq)  # (B, 192*2, 64, 64)

        # Swin Transformer Block in Decoder
        swin_D1 = self._swin_2d(self.swin_unet_D1, bidir_convLSTM1_out)  # (B, 192*2, 64, 64)

        # Further Convolutions
        conv6 = self.conv11(swin_D1)        # (B, 48, 64, 64)
        conv6 = self.conv12(conv6)          # (B, 48, 64, 64)

        # Second Upsampling Block
        up2 = self.up2(conv6)               # (B, 48, 128, 128)
        up2 = self.bn_up2(up2)
        up2 = self.relu_up2(up2)

        up2_seq = torch.stack([x2, up2], dim=1)  # (B, 2, 48, 128, 128)
        bidir_convLSTM2_out = self.bidirectional_convLSTM2(up2_seq)  # (B, 96*2, 128, 128)

        # Swin Transformer Block in Decoder
        swin_D2 = self._swin_2d(self.swin_unet_D2, bidir_convLSTM2_out)  # (B, 96*2, 128, 128)

        # Further Convolutions
        conv7 = self.conv13(swin_D2)        # (B, 24, 128, 128)
        conv7 = self.conv14(conv7)          # (B, 24, 128, 128)

        # Third Upsampling Block
        up3 = self.up3(conv7)               # (B, 24, 256, 256)
        up3 = self.bn_up3(up3)
        up3 = self.relu_up3(up3)

        up3_seq = torch.stack([x1, up3], dim=1)  # (B, 2, 24, 256, 256)
        bidir_convLSTM3_out = self.bidirectional_convLSTM3(up3_seq)  # (B, 48*2, 256, 256)

        # Swin Transformer Block in Decoder
        swin_D3 = self._swin_2d(self.swin_unet_D3, bidir_convLSTM3_out)  # (B, 48*2, 256, 256)

        # Further Convolutions
        conv8 = self.conv15(swin_D3)        # (B, 24, 256, 256)
        conv8 = self.conv16(conv8)          # (B, 24, 256, 256)

        # Final Output Convolutions
        final = self.final_conv1(conv8)      # (B, 2, 256, 256)
        final = self.final_relu(final)
        final = self.final_conv2(final)      # (B, output_channels, 256, 256)
        final = self.final_sigmoid(final)    # (B, output_channels, 256, 256)

        return final

# ================================== Dataset Class ======================================

class SegmentationDataset(Dataset):
    """
    Custom Dataset for image segmentation tasks.

    The i-th image is paired with the i-th mask after both directory listings are
    sorted by file name, so image and mask names must sort into the same order.
    Files without an image extension (e.g. license or readme text files) are ignored.

    Args:
        images_dir (str): Folder containing the input images.
        masks_dir (str): Folder containing the ground-truth masks.
        transform (callable, optional): Applied to the image and, unless
            ``mask_transform`` is given, also to the mask.
        mask_transform (callable, optional): Applied to the mask instead of
            ``transform``; e.g. a Resize with nearest-neighbour interpolation so
            masks stay binary.
    Raises:
        ValueError: If the two folders contain a different number of image files.
    """
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff')

    def __init__(self, images_dir, masks_dir, transform=None, mask_transform=None):
        super(SegmentationDataset, self).__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.mask_transform = mask_transform

        self.images = self._list_image_files(images_dir)
        self.masks = self._list_image_files(masks_dir)
        # Pairing is positional, so a count mismatch means images and masks are misaligned.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but "
                f"{len(self.masks)} masks in {masks_dir!r}"
            )

    @classmethod
    def _list_image_files(cls, directory):
        """Return the sorted names of files in `directory` that have an image extension."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load image; `with` closes the file handle once the converted copy exists
        img_path = os.path.join(self.images_dir, self.images[idx])
        with Image.open(img_path) as img:
            image = img.convert('RGB')  # Ensure RGB

        # Load mask
        mask_path = os.path.join(self.masks_dir, self.masks[idx])
        with Image.open(mask_path) as msk:
            mask = msk.convert('L')    # Grayscale

        # Apply transformations
        if self.transform:
            image = self.transform(image)
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_tf:
            mask = mask_tf(mask)

        return image, mask

# =============================== Data Loading and Preprocessing ========================

# Define image dimensions
im_height = 256
im_width = 256

# Seed for the train/validation split. The validation loss drives LR scheduling,
# early stopping and checkpoint selection, so the split must be reproducible.
split_seed = 42

# Define transformations
# NOTE(review): SegmentationDataset applies this same transform to the masks; Resize
# presumably interpolates (bilinear by default), so mask edges may become fractional.
transform = transforms.Compose([
    transforms.Resize((im_height, im_width)),
    transforms.ToTensor(),  # Converts to [0,1]
])

# Paths to the dataset (update these paths as per your directory structure)
train_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Training_Input'
train_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Training_GroundTruth'
test_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Test_Input'
test_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Test_GroundTruth'

# Create datasets
train_dataset = SegmentationDataset(train_images_dir, train_masks_dir, transform=transform)
test_dataset = SegmentationDataset(test_images_dir, test_masks_dir, transform=transform)

# Split training data into training and validation sets (80-20 split), seeded
train_size = int(0.8 * len(train_dataset))
valid_size = len(train_dataset) - train_size
train_subset, valid_subset = torch.utils.data.random_split(
    train_dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders
batch_size = 5

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# =============================== Training Setup ==========================================

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the network and move its parameters/buffers onto the chosen device.
model = SwinUNet(
    input_channels=3,
    output_channels=1,
    embed_dim=32,
    num_heads=[4, 8],
    window_size=4,
    mlp_ratio=4.,
    depth=2,
).to(device)

# Initialize weights using Kaiming Normal initialization
_KAIMING_INIT_TYPES = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear, nn.Conv3d)


def initialize_weights(module):
    """
    Re-initialise one module in place; intended for ``model.apply(initialize_weights)``.

    Convolutions (2-D, transposed 2-D, 3-D) and Linear layers get Kaiming-normal
    weights (fan_in, ReLU gain) and zero biases; every other module is left untouched.
    """
    if not isinstance(module, _KAIMING_INIT_TYPES):
        return
    nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
    if module.bias is not None:
        nn.init.zeros_(module.bias)

model.apply(initialize_weights)

# Objective: soft Dice on the sigmoid outputs; Adam with a fixed starting learning rate.
criterion = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Quarter the LR after 5 epochs without validation improvement (never below 1e-9);
# stop training entirely after 9 such epochs.
# NOTE(review): newer PyTorch releases deprecate the scheduler `verbose` flag; confirm
# against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.25,
    patience=5,
    verbose=True,
    min_lr=1e-9,
)
early_stopping_patience = 9
best_val_loss = float('inf')
epochs_no_improve = 0

# =============================== Training Loop ===========================================

num_epochs = 1  # You can adjust the number of epochs
log_every_n_batches = 50  # progress-print interval (printing every batch floods the notebook)
best_model_path = r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth'
# torch.save fails if the destination folder does not exist yet.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    model.train()
    running_loss = 0.0
    seen_images = 0
    num_batches = len(train_loader)
    for batch_idx, (images, masks) in enumerate(train_loader, start=1):
        images = images.to(device)  # (B, 3, 256, 256)
        masks = masks.to(device)    # (B, 1, 256, 256)

        optimizer.zero_grad()
        outputs = model(images)      # (B, 1, 256, 256)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch loss is a per-image mean even with a short last batch
        running_loss += loss.item() * images.size(0)
        seen_images += images.size(0)
        if batch_idx % log_every_n_batches == 0 or batch_idx == num_batches:
            print(f"  Batch {batch_idx}/{num_batches}, running training loss: {running_loss / seen_images:.4f}")

    epoch_loss = running_loss / len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in valid_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            val_loss += loss.item() * images.size(0)

    val_loss /= len(valid_loader.dataset)

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Scheduler step
    scheduler.step(val_loss)

    # Early stopping; checkpoint only on improvement
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        # Save the best model
        torch.save(model.state_dict(), best_model_path)
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= early_stopping_patience:
            print('Early stopping!')
            break

# ================================== Prediction ==========================================

# Load the best model weights saved during training. map_location maps the stored
# tensors onto the current device, so a checkpoint written on GPU still loads on a CPU runtime.
model.load_state_dict(torch.load(
    r'/content/drive/MyDrive/model/modelWeights_Swin_Trans_Weights_Swin_Trans_Leather.pth',
    map_location=device,
))
model.eval()

# Function to save predictions and ground truth
def save_predictions(model, dataloader, save_dir_pred, save_dir_gt, device):
    """
    Saves the predicted masks and ground truth masks as 8-bit grayscale images.

    Files are numbered 1, 2, 3, ... in dataloader order. Predictions are written as
    ``<n>.png`` into ``save_dir_pred`` and ground truths as ``<n>.tiff`` into
    ``save_dir_gt``. Values are clipped to [0, 1] and scaled to [0, 255].

    Args:
        model (nn.Module): Trained model (switched to eval mode here).
        dataloader (DataLoader): DataLoader for test data yielding (images, masks).
        save_dir_pred (str): Directory to save predicted masks (created if missing).
        save_dir_gt (str): Directory to save ground truth masks (created if missing).
        device (str): Device to run the model on.
    """
    os.makedirs(save_dir_pred, exist_ok=True)
    os.makedirs(save_dir_gt, exist_ok=True)

    model.eval()  # same as evaluate_metrics_pytorch: never predict with dropout/BN in train mode
    image_index = 0  # running file number; does not depend on dataloader.batch_size
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(dataloader):
            if batch_idx % 100 == 0:
                print(f"Saving test batch {batch_idx} (starting at image {image_index + 1})")

            outputs = model(images.to(device))
            preds = outputs.cpu().numpy()
            # Ground truth is only written to disk, so it never needs to visit the device.
            gts = masks.cpu().numpy()

            for pred_mask, gt_mask in zip(preds[:, 0], gts[:, 0]):
                image_index += 1

                # Save predicted mask
                pred_img = Image.fromarray((np.clip(pred_mask, 0.0, 1.0) * 255).astype(np.uint8))
                pred_img.save(os.path.join(save_dir_pred, f"{image_index}.png"))

                # Save ground truth mask
                gt_img = Image.fromarray((np.clip(gt_mask, 0.0, 1.0) * 255).astype(np.uint8))
                gt_img.save(os.path.join(save_dir_gt, f"{image_index}.tiff"))

# Output folders: one for the model's predicted masks, one for the matching ground truth
save_dir_pred = r'/content/drive/MyDrive/output/segmented predicted images'
save_dir_gt = r'/content/drive/MyDrive/output/segmented ground truth'

# Write every test-set prediction and its ground-truth mask to disk
save_predictions(
    model=model,
    dataloader=test_loader,
    save_dir_pred=save_dir_pred,
    save_dir_gt=save_dir_gt,
    device=device,
)

# =================================== Evaluation =========================================

def evaluate_metrics_pytorch(model, dataloader, device, threshold=0.5, mask_threshold=0.5):
    """
    Evaluates per-image segmentation metrics and returns their averages.

    Predictions are binarised with ``outputs > threshold`` and ground-truth masks with
    ``masks > mask_threshold``; both thresholds are applied to the float values. Ratios
    whose denominator is zero count as 0, matching sklearn's ``zero_division=0``.

    Args:
        model (nn.Module): Trained model (switched to eval mode here).
        dataloader (iterable): Yields (images, masks) batches of tensors.
        device (str): Device to run the model on.
        threshold (float): Probability threshold for predicted foreground.
        mask_threshold (float): Threshold for ground-truth foreground.
    Returns:
        dict: Metric name -> mean over all images (NaN if the dataloader is empty).
    """
    model.eval()
    metric_names = ('Accuracy', 'Dice', 'Jaccard', 'Sensitivity',
                    'Specificity', 'Precision', 'Recall', 'F1-Score')
    all_metrics = {name: [] for name in metric_names}

    def _ratio(num, den):
        # zero_division=0 semantics: an undefined ratio counts as 0
        return num / den if den else 0.0

    with torch.no_grad():
        for images, masks in dataloader:
            outputs = model(images.to(device))
            preds = (outputs > threshold).cpu().numpy()  # Binary mask
            # Threshold the float masks BEFORE any integer cast: casting to uint8 first would
            # truncate every fractional (interpolated) pixel to 0 and silently shrink the GT.
            gts = masks.cpu().numpy() > mask_threshold

            for pred, gt in zip(preds, gts):
                pred_flat = pred.ravel()
                gt_flat = gt.ravel()

                # One pass of confusion counts replaces five separate sklearn metric calls
                tp = int(np.count_nonzero(pred_flat & gt_flat))
                fp = int(np.count_nonzero(pred_flat & ~gt_flat))
                fn = int(np.count_nonzero(~pred_flat & gt_flat))
                tn = int(pred_flat.size) - tp - fp - fn

                accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-8)
                iou = _ratio(tp, tp + fp + fn)
                dice = _ratio(2 * tp, 2 * tp + fp + fn)
                specificity = tn / (tn + fp + 1e-8)
                sensitivity = _ratio(tp, tp + fn)
                precision = _ratio(tp, tp + fp)
                recall = sensitivity
                f1 = dice  # F1-Score is the same as Dice coefficient for binary classification

                all_metrics['Accuracy'].append(accuracy)
                all_metrics['Jaccard'].append(iou)
                all_metrics['Dice'].append(dice)
                all_metrics['Specificity'].append(specificity)
                all_metrics['Sensitivity'].append(sensitivity)
                all_metrics['Precision'].append(precision)
                all_metrics['Recall'].append(recall)
                all_metrics['F1-Score'].append(f1)

    # Compute average metrics (explicit NaN instead of np.mean([]) warnings)
    avg_metrics = {metric: (np.mean(values) if values else float('nan'))
                   for metric, values in all_metrics.items()}

    print("Evaluation Metrics:")
    for metric, value in avg_metrics.items():
        print(f"{metric}: {value:.4f}")

    return avg_metrics

# Average the per-image segmentation metrics over the held-out test set
metrics = evaluate_metrics_pytorch(model=model, dataloader=test_loader, device=device)


In [None]:
!pip install torchinfo

In [None]:
from torchinfo import summary

# A single dummy RGB 256x256 batch, laid out as (batch_size, channels, height, width).
input_size = (1, 3, 256, 256)

# Trace the network on the same kind of device used for training and show the layer table.
summary(
    model,
    input_size=input_size,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [None]:
# ============================== Imports and Dependencies ==============================

import os
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.metrics import (
    confusion_matrix,
    jaccard_score, f1_score, precision_score, recall_score
)

# ================================ Separable Convolution =================================

class SeparableConv2d(nn.Module):
    """
    Depthwise-separable 2-D convolution.

    A per-channel spatial convolution (``groups=in_channels``) is followed by a 1x1
    convolution that mixes channels into ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, bias=True):
        super(SeparableConv2d, self).__init__()
        # Spatial filtering, one filter per input channel
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        # Channel mixing with a 1x1 kernel
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            bias=bias,
        )

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

# ================================== ConvLSTM2D ========================================

class ConvLSTMCell(nn.Module):
    """
    A single ConvLSTM cell: an LSTM whose gate transforms are 2-D convolutions.

    One convolution over the channel-concatenation ``[input, h]`` yields all four gate
    pre-activations, laid out in the order input, forget, output, candidate.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,  # keeps H and W unchanged for odd kernel sizes
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        h_prev, c_prev = cur_state
        gate_logits = self.conv(torch.cat((input_tensor, h_prev), dim=1))
        # Exactly 4 * hidden_channels channels, so four equal chunks
        in_logit, forget_logit, out_logit, cand_logit = gate_logits.chunk(4, dim=1)

        c_new = torch.sigmoid(forget_logit) * c_prev + torch.sigmoid(in_logit) * torch.tanh(cand_logit)
        h_new = torch.sigmoid(out_logit) * torch.tanh(c_new)
        return h_new, c_new

    def init_hidden(self, batch_size, spatial_size, device):
        """Zero (h, c) pair of shape (batch_size, hidden_channels, height, width)."""
        height, width = spatial_size
        state_shape = (batch_size, self.hidden_channels, height, width)
        return (torch.zeros(state_shape, device=device),
                torch.zeros(state_shape, device=device))

class ConvLSTM2D(nn.Module):
    """
    A stack of ConvLSTM cells unrolled over a sequence of feature maps.

    Input is (batch, seq_len, channels, height, width); the output has the same layout
    with ``hidden_channels`` channels and holds the top layer's hidden state at every
    step. With ``reverse=True`` the sequence is consumed last-to-first, and the outputs
    are flipped back so index t still corresponds to input step t.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, bias=True, num_layers=1):
        super(ConvLSTM2D, self).__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        # Layer 0 reads the input; deeper layers read the hidden state below them.
        self.layers = nn.ModuleList(
            ConvLSTMCell(input_channels if depth == 0 else hidden_channels,
                         hidden_channels, kernel_size, bias)
            for depth in range(num_layers)
        )

    def forward(self, input_tensor, reverse=False):
        batch_size, seq_len, _, height, width = input_tensor.size()
        states = [
            self.layers[depth].init_hidden(batch_size, (height, width), input_tensor.device)
            for depth in range(self.num_layers)
        ]

        step_order = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
        per_step = []
        for step in step_order:
            layer_input = input_tensor[:, step]
            for depth, cell in enumerate(self.layers):
                h, c = cell(layer_input, states[depth])
                states[depth] = (h, c)
                layer_input = h
            per_step.append(layer_input)

        stacked = torch.stack(per_step, dim=1)
        # Restore chronological order for the reversed pass
        return stacked.flip(dims=[1]) if reverse else stacked

class BidirectionalConvLSTM2D(nn.Module):
    """
    Bidirectional ConvLSTM2D that returns a single fused feature map.

    Two independent ConvLSTM2D stacks read the sequence forwards and backwards. The
    result is their outputs at the LAST sequence index, concatenated along channels:
    (batch, 2 * hidden_channels, H, W).
    NOTE: the backward pass starts from the last element, so its half of the output
    reflects only that final element.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, num_layers=1, bias=True):
        super(BidirectionalConvLSTM2D, self).__init__()
        self.forward_conv_lstm = ConvLSTM2D(input_channels, hidden_channels, kernel_size, bias=bias, num_layers=num_layers)
        self.backward_conv_lstm = ConvLSTM2D(input_channels, hidden_channels, kernel_size, bias=bias, num_layers=num_layers)

    def forward(self, input_tensor):
        # input_tensor: (batch, seq_len, channels, H, W)
        fwd_seq = self.forward_conv_lstm(input_tensor, reverse=False)
        bwd_seq = self.backward_conv_lstm(input_tensor, reverse=True)
        # Slice the final step first, then join forward and backward channels
        return torch.cat((fwd_seq[:, -1], bwd_seq[:, -1]), dim=1)

# =============================== Patch Embedding ========================================

class PatchEmbed(nn.Module):
    """
    Image-to-patch embedding.

    A convolution with kernel = stride = ``patch_size`` cuts the image into
    non-overlapping patches and projects each to ``embed_dim`` channels. The patch
    grid is then flattened into a token sequence and layer-normalised.
    """
    def __init__(self, img_size=256, patch_size=4, in_chans=3, embed_dim=32):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # [B, C, H, W] -> [B, embed_dim, H/p, W/p] -> [B, num_patches, embed_dim]
        patch_grid = self.proj(x)
        tokens = patch_grid.flatten(start_dim=2).transpose(1, 2)
        return self.norm(tokens)

# ============================== Swin Transformer Blocks ================================

class WindowAttention(nn.Module):
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.

    Tokens are expected to be already partitioned into windows of ``Wh * Ww`` tokens.
    Attention runs independently per window. A learned per-head bias, indexed by the
    relative (row, col) offset of each token pair, is added to the attention logits.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        """
        Args:
            dim (int): Number of input channels; must be divisible by ``num_heads``.
            window_size (int or tuple): Height and width of the window. An int is
                treated as a square ``(window_size, window_size)`` window.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            attn_drop (float): Dropout ratio of attention weights.
            proj_drop (float): Dropout ratio of output.
        Raises:
            ValueError: If ``dim`` is not divisible by ``num_heads``; the per-head
                reshape in ``forward`` would otherwise fail at run time.
        """
        super(WindowAttention, self).__init__()
        if isinstance(window_size, int):
            window_size = (window_size, window_size)
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # One learnable bias per head for each of the (2*Wh-1)*(2*Ww-1) possible offsets
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # Get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1)  # row-major flattening of (dy, dx)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # Query, Key, Value
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Initialize relative position bias table
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, Wh*Ww, C)
            mask: (num_windows, Wh*Ww, Wh*Ww) additive logit mask, or None
        Returns:
            Tensor of shape (num_windows*B, Wh*Ww, C).
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # 3, B_, nH, N, C//nH
        q, k, v = qkv[0], qkv[1], qkv[2]  # each has shape (B_, nH, N, C//nH)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B_, nH, N, N)

        # Add relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # (B_, nH, N, N)

        if mask is not None:
            # Broadcast the per-window mask over the batch and head dimensions
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = F.softmax(attn, dim=-1)

        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # (B_, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block with W-MSA and SW-MSA.

    Pre-LayerNorm window attention followed by a pre-LayerNorm MLP, each with a residual
    connection. With ``shift_size == 0`` attention is computed inside regular
    non-overlapping windows (W-MSA). With ``shift_size > 0`` the feature map is
    cyclically rolled before windowing (SW-MSA). An additive mask then prevents tokens
    that were wrapped around by the roll from attending to tokens from the opposite
    border.

    Args:
        dim: channel count C of the tokens (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
        window_size: side length of the square attention windows.
        shift_size: cyclic shift applied before windowing, in [0, window_size).
        mlp_ratio: stored only; the MLP width is ``mlp_hidden_dim``.
        qkv_bias: whether the q/k/v projection has a bias.
        attn_drop: dropout applied to attention weights.
        proj_drop: dropout after the attention projection and inside the MLP.
        mlp_hidden_dim: hidden width of the feed-forward MLP.
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 attn_drop=0., proj_drop=0., mlp_hidden_dim=512):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size  # W
        self.shift_size = shift_size    # S
        self.mlp_ratio = mlp_ratio

        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, (window_size, window_size), num_heads, qkv_bias, attn_drop, proj_drop)

        self.drop_path = nn.Identity()  # Can implement stochastic depth if desired
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(proj_drop)
        )

    def _window_partition(self, x):
        """Split a (B, Hp, Wp, C) map into (B*nW, W*W, C) windows, row-major inside each window."""
        B, Hp, Wp, C = x.shape
        ws = self.window_size
        x = x.reshape(B, Hp // ws, ws, Wp // ws, ws, C)
        # -> (B, nH_win, nW_win, ws, ws, C): batch-major, then window row, then window column
        return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, C)

    def _window_reverse(self, windows, B, Hp, Wp):
        """Inverse of `_window_partition`: (B*nW, W*W, C) -> (B, Hp, Wp, C)."""
        ws = self.window_size
        C = windows.shape[-1]
        x = windows.reshape(B, Hp // ws, Wp // ws, ws, ws, C)
        return x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, C)

    def _shifted_window_mask(self, Hp, Wp, device, dtype):
        """
        Build the SW-MSA attention mask for a padded (Hp, Wp) map.

        The rolled map is split into 3x3 regions. Token pairs from different regions
        get -100 (effectively zero attention after softmax) and same-region pairs get 0.
        Returns a tensor shaped (num_windows, W*W, W*W).
        """
        ws, shift = self.window_size, self.shift_size
        regions = torch.zeros((1, Hp, Wp, 1), device=device, dtype=dtype)
        spans = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
        label = 0
        for h_span in spans:
            for w_span in spans:
                regions[:, h_span, w_span, :] = label
                label += 1
        region_windows = self._window_partition(regions).squeeze(-1)  # (nW, W*W)
        diff = region_windows.unsqueeze(1) - region_windows.unsqueeze(2)
        return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)

    def forward(self, x, H, W, mask_matrix=None):
        """
        Args:
            x: input features with shape (B, H*W, C)
            H, W: spatial dimensions
            mask_matrix: optional additive attention mask (num_windows, W*W, W*W) for the
                padded map, used only when shift_size > 0. When None and the block is
                shifted, the standard shifted-window mask is built automatically.
        Returns:
            Tensor of shape (B, H*W, C).
        """
        B, L, C = x.shape
        assert L == H * W, "Input feature has wrong size"
        ws = self.window_size

        shortcut = x
        x = self.norm1(x).reshape(B, H, W, C)

        # Pad right/bottom so both spatial sizes are multiples of the window size
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape

        # Cyclic shift (SW-MSA) or plain windows (W-MSA)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
            if attn_mask is None:
                attn_mask = self._shifted_window_mask(Hp, Wp, x.device, x.dtype)
        else:
            shifted_x = x
            attn_mask = None

        # Partition -> attention -> merge; partition and reverse are exact inverses
        x_windows = self._window_partition(shifted_x)              # (B*nW, W*W, C)
        attn_windows = self.attn(x_windows, mask=attn_mask)        # (B*nW, W*W, C)
        shifted_x = self._window_reverse(attn_windows, B, Hp, Wp)  # (B, Hp, Wp, C)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # Remove padding
        x = x[:, :H, :W, :].reshape(B, H * W, C)

        # Residual attention, then residual MLP
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

# =============================== Dice Loss Function ====================================

class DiceLoss(nn.Module):
    """
    Soft Dice loss (``1 - Dice``) for binary segmentation.

    The Dice coefficient is computed once over every element of the batch, not per
    sample. ``smooth`` is added to both numerator and denominator. This keeps the ratio
    finite and makes the loss 0 when prediction and target are both all zeros.
    """
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred (torch.Tensor): predicted foreground probabilities, e.g. (B, 1, H, W)
            y_true (torch.Tensor): ground-truth mask with the same number of elements
        Returns:
            torch.Tensor: scalar Dice loss
        """
        flat_pred = y_pred.view(-1)
        flat_true = y_true.view(-1)

        overlap = torch.sum(flat_pred * flat_true)
        numerator = 2. * overlap + self.smooth
        denominator = flat_pred.sum() + flat_true.sum() + self.smooth
        return 1 - numerator / denominator

# =============================== Swin UNet Model ========================================

class SwinUNet(nn.Module):
    """
    Swin U-Net architecture for image segmentation.

    Pipeline:
    - The input is patch-embedded (``PatchEmbed``, patch size 4).
    - It passes through ``depth`` encoder stages of Swin blocks, separated by stride-2
      convolutions, and then a Swin bottleneck.
    - ``depth`` decoder stages follow. Each stage after the first upsamples with a
      stride-2 transposed convolution and concatenates the matching encoder features.
    - The decoder output is resized back to the input's spatial size and mapped to
      ``output_channels`` sigmoid probabilities, so the output aligns pixel-for-pixel
      with the input.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the output probability map.
        filter_num_begin: channels at stage 0; stage i uses ``filter_num_begin * 2**i``.
        depth: number of encoder/decoder stages.
        stack_num_down: Swin blocks per encoder stage and in the bottleneck.
        stack_num_up: Swin blocks per decoder stage.
        num_heads: attention heads per stage (indexable up to ``depth - 1``).
        window_size: window side per stage (indexable up to ``depth - 1``).
        num_mlp: hidden width of every Swin block's MLP.
        shift_window: if True, blocks inside each stack alternate between regular
            windows (even index) and shifted windows with shift ``window_size // 2``
            (odd index).
        **kwargs: accepted and ignored.
    """
    def __init__(self, input_channels=3, output_channels=1,
                 filter_num_begin=32, depth=4, stack_num_down=2, stack_num_up=2,
                 num_heads=[4, 8, 8, 8], window_size=[4, 2, 2, 2], num_mlp=512,
                 shift_window=True, **kwargs):
        super(SwinUNet, self).__init__()

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.filter_num_begin = filter_num_begin
        self.depth = depth
        self.stack_num_down = stack_num_down
        self.stack_num_up = stack_num_up
        self.num_heads = num_heads
        self.window_size = window_size
        self.num_mlp = num_mlp
        self.shift_window = shift_window

        # Channels at each level, e.g. [32, 64, 128, 256] for depth=4
        self.filter_nums = [filter_num_begin * (2 ** i) for i in range(depth)]

        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=256, patch_size=4, in_chans=input_channels, embed_dim=self.filter_nums[0]
        )

        # Encoder
        self.encoder_layers = nn.ModuleList()
        for i in range(depth):
            self.encoder_layers.append(
                self._make_stack(self.filter_nums[i], num_heads[i], window_size[i], stack_num_down)
            )
            if i < depth - 1:
                setattr(self, f"downsample_{i}", nn.Conv2d(
                    self.filter_nums[i], self.filter_nums[i + 1], kernel_size=2, stride=2)
                )

        # Bottleneck
        self.bottleneck_layers = self._make_stack(
            self.filter_nums[-1], num_heads[-1], window_size[-1], stack_num_down
        )

        # Decoder: track the channel count actually flowing out of each stage so the next
        # upsampling layer is sized correctly (stages after the first carry concatenated
        # skip features, i.e. twice their level's width).
        self.decoder_layers = nn.ModuleList()
        prev_channels = self.filter_nums[-1]  # channels leaving the bottleneck
        for i in range(depth):
            level = depth - i - 1
            if i == 0:
                stage_channels = self.filter_nums[level]
            else:
                setattr(self, f"upsample_{i}", nn.ConvTranspose2d(
                    prev_channels, self.filter_nums[level], kernel_size=2, stride=2)
                )
                stage_channels = self.filter_nums[level] * 2  # after concatenation with the skip
            self.decoder_layers.append(
                self._make_stack(stage_channels, num_heads[level], window_size[level], stack_num_up)
            )
            prev_channels = stage_channels

        # Final Convolution (input width = channels produced by the last decoder stage)
        self.final_conv = nn.Sequential(
            nn.Conv2d(prev_channels, self.filter_nums[0] // 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.filter_nums[0] // 2, output_channels, kernel_size=1)
        )

    def _make_stack(self, dim, heads, win, count):
        """Build `count` Swin blocks; with shift_window, odd-indexed blocks use shifted windows."""
        shift = win // 2 if self.shift_window else 0
        return nn.ModuleList(
            SwinTransformerBlock(
                dim=dim,
                num_heads=heads,
                window_size=win,
                shift_size=shift if j % 2 == 1 else 0,
                mlp_hidden_dim=self.num_mlp,
            )
            for j in range(count)
        )

    @staticmethod
    def _run_stack(blocks, x):
        """Apply token-space Swin blocks to a (B, C, H, W) map and return a (B, C, H, W) map."""
        B, C, H, W = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
        for blk in blocks:
            tokens = blk(tokens, H, W)
        return tokens.transpose(1, 2).reshape(B, C, H, W)

    def forward(self, x):
        input_hw = x.shape[-2:]

        # Initial Patch Embedding
        x = self.patch_embed(x)  # x shape: [B, num_patches, embed_dim]
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        assert H * W == N, "SwinUNet expects a square grid of patches"
        x = x.transpose(1, 2).reshape(B, C, H, W)  # [B, C, H, W]

        # Encoder
        encodings = []
        for i, layers in enumerate(self.encoder_layers):
            x = self._run_stack(layers, x)
            encodings.append(x)
            if i < self.depth - 1:
                x = getattr(self, f"downsample_{i}")(x)

        # Bottleneck
        x = self._run_stack(self.bottleneck_layers, x)

        # Decoder
        for i in range(self.depth):
            if i > 0:
                x = getattr(self, f"upsample_{i}")(x)
                skip_connection = encodings[self.depth - i - 1]
                x = torch.cat([x, skip_connection], dim=1)  # Concatenate along channels
            x = self._run_stack(self.decoder_layers[i], x)

        # Patch embedding reduced the resolution; restore the input size so the output
        # lines up with full-resolution masks.
        x = F.interpolate(x, size=input_hw, mode='bilinear', align_corners=False)

        # Final Convolution
        x = self.final_conv(x)
        return torch.sigmoid(x)
# ================================== Dataset Class ======================================

class SegmentationDataset(Dataset):
    """
    Custom Dataset for image segmentation tasks.

    Images are read from ``images_dir`` and masks from ``masks_dir``. The two are paired
    by position after sorting each directory's image files by name, so the i-th sorted
    image must correspond to the i-th sorted mask. Hidden files and non-image files
    (e.g. licence/attribution ``.txt`` files shipped next to the data) are ignored.
    Images are returned as RGB and masks as single-channel ('L') PIL images, each passed
    through ``transform`` when one is given.

    Raises:
        ValueError: if the two directories contain different numbers of image files.
    """

    # File extensions treated as images when listing a directory.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff')

    def __init__(self, images_dir, masks_dir, transform=None):
        super(SegmentationDataset, self).__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform

        self.images = self._list_image_files(images_dir)
        self.masks = self._list_image_files(masks_dir)
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {images_dir!r} but "
                f"{len(self.masks)} masks in {masks_dir!r}; they are paired by sorted position."
            )

    @classmethod
    def _list_image_files(cls, directory):
        """Return the sorted names of visible image files directly inside ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith('.')
            and os.path.splitext(name)[1].lower() in cls._IMAGE_EXTENSIONS
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load image; the context manager closes the file once the converted copy exists
        img_path = os.path.join(self.images_dir, self.images[idx])
        with Image.open(img_path) as img:
            image = img.convert('RGB')  # Ensure RGB

        # Load mask
        mask_path = os.path.join(self.masks_dir, self.masks[idx])
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')  # Grayscale

        # Apply transformations
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

# =============================== Data Loading and Preprocessing ========================

# Define image dimensions
im_height = 256
im_width = 256

# Images are resized bilinearly. Masks use nearest-neighbour so resizing cannot create
# fractional label values along lesion borders.
image_transform = transforms.Compose([
    transforms.Resize((im_height, im_width)),
    transforms.ToTensor(),  # Converts to [0,1]
])
mask_transform = transforms.Compose([
    transforms.Resize((im_height, im_width), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),  # Converts to [0,1]
])


def transform(img):
    """Resize + tensorize a PIL image, choosing the pipeline by mode.

    SegmentationDataset converts images to 'RGB' and masks to 'L' before calling the
    transform, so single-channel inputs are treated as masks.
    """
    return mask_transform(img) if img.mode == 'L' else image_transform(img)


# Paths to the dataset (update these paths as per your directory structure)
train_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Training_Input'
train_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Training_GroundTruth'
test_images_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1-2_Test_Input'
test_masks_dir = r'/content/drive/MyDrive/ML/dataset/ISIC2018_Task1_Test_GroundTruth'

# Create datasets
train_dataset = SegmentationDataset(train_images_dir, train_masks_dir, transform=transform)
test_dataset = SegmentationDataset(test_images_dir, test_masks_dir, transform=transform)

# Split training data into training and validation sets (80-20 split).
# A seeded generator makes the split identical across kernel restarts.
split_seed = 42
train_size = int(0.8 * len(train_dataset))
valid_size = len(train_dataset) - train_size
train_subset, valid_subset = torch.utils.data.random_split(
    train_dataset, [train_size, valid_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders
batch_size = 5

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# =============================== Training Setup ==========================================

# Seed torch's global RNG so parameter initialisation and training-loader shuffling
# are repeatable across kernel restarts.
torch.manual_seed(42)

# Instantiate the model
model = SwinUNet(input_channels=3, output_channels=1,
                 filter_num_begin=32, depth=4, stack_num_down=2, stack_num_up=2,
                 num_heads=[4, 8, 8, 8], window_size=[4, 2, 2, 2], num_mlp=512,
                 shift_window=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)  # Move to GPU if available

# Weight initialisation applied to every submodule via model.apply
def initialize_weights(module):
    """
    Re-initialise one module in place; intended to be passed to ``model.apply``.

    Conv2d, ConvTranspose2d, Conv3d and Linear layers get Kaiming-normal weights
    (``mode='fan_in'``, ReLU gain) and a zeroed bias when they have one. All other
    module types are left untouched.
    """
    kaiming_targets = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear, nn.Conv3d)
    if not isinstance(module, kaiming_targets):
        return

    nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
    if module.bias is not None:
        nn.init.zeros_(module.bias)

model.apply(initialize_weights)

# Define loss function and optimizer
criterion = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Define learning rate scheduler and early stopping parameters.
# `verbose` is deprecated for LR schedulers in recent PyTorch releases, so it is not
# passed; the current LR can be read from optimizer.param_groups[0]['lr'].
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.25, patience=5, min_lr=1e-9)
early_stopping_patience = 9
best_val_loss = np.inf
epochs_no_improve = 0

# =============================== Training Loop ===========================================

num_epochs = 1  # You can adjust the number of epochs
best_model_path = r'/content/drive/MyDrive/model/modelWeights_Swin_UNet.pth'
# torch.save does not create missing parent directories, so create the folder up front.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)
log_every = 50  # print training progress every N batches instead of on every batch

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    model.train()
    running_loss = 0.0
    for batch_idx, (images, masks) in enumerate(train_loader, start=1):
        if batch_idx == 1 or batch_idx % log_every == 0:
            print(f"Train loader batch count: {batch_idx}/{len(train_loader)}")
        images = images.to(device)  # (B, 3, 256, 256)
        masks = masks.to(device)    # (B, 1, 256, 256)

        optimizer.zero_grad()
        outputs = model(images)      # (B, 1, 256, 256)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the epoch mean
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in valid_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            val_loss += loss.item() * images.size(0)

    val_loss /= len(valid_loader.dataset)

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Scheduler step (report the LR explicitly; the scheduler is not asked to log)
    scheduler.step(val_loss)
    print(f"Learning rate after scheduler step: {optimizer.param_groups[0]['lr']:.2e}")

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        # Save the best model
        torch.save(model.state_dict(), best_model_path)
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= early_stopping_patience:
            print('Early stopping!')
            break

# ================================== Prediction ==========================================

# Load the best model weights. map_location places the tensors on the current device
# even if the checkpoint was written from a different one (e.g. GPU run, CPU reload).
model.load_state_dict(torch.load(r'/content/drive/MyDrive/model/modelWeights_Swin_UNet.pth', map_location=device))
model.eval()

# Function to save predictions and ground truth
def save_predictions(model, dataloader, save_dir_pred, save_dir_gt, device):
    """
    Saves the predicted masks and ground truth masks as 8-bit images.

    Channel 0 of each model output is scaled by 255 and written as ``<n>.png``. The
    matching ground-truth channel 0 is written as ``<n>.tiff``. Here ``n`` is a 1-based
    running index over samples in loader order, so it does not depend on
    ``dataloader.batch_size``, which is None for loaders built from a custom batch
    sampler.

    Args:
        model (nn.Module): Trained model.
        dataloader (DataLoader): DataLoader for test data.
        save_dir_pred (str): Directory to save predicted masks.
        save_dir_gt (str): Directory to save ground truth masks.
        device (str): Device to run the model on.
    """
    os.makedirs(save_dir_pred, exist_ok=True)
    os.makedirs(save_dir_gt, exist_ok=True)
    model.eval()

    sample_number = 0  # 1-based file index shared by prediction and ground truth
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(dataloader):
            if batch_idx % 100 == 0:
                print(f"{batch_idx}th Test Batch")  # Adjust as per your dataset

            preds = model(images.to(device)).cpu().numpy()
            gts = masks.cpu().numpy()  # ground truth never needs to visit the device

            for pred, gt in zip(preds, gts):
                sample_number += 1

                # Save predicted mask
                pred_img = Image.fromarray((pred[0] * 255).astype(np.uint8))
                pred_img.save(os.path.join(save_dir_pred, f"{sample_number}.png"))

                # Save ground truth mask
                gt_img = Image.fromarray((gt[0] * 255).astype(np.uint8))
                gt_img.save(os.path.join(save_dir_gt, f"{sample_number}.tiff"))

# Output folders for the predicted masks and the matching ground-truth masks
save_dir_pred = r'/content/drive/MyDrive/output/segmented_predicted_images'
save_dir_gt = r'/content/drive/MyDrive/output/segmented_ground_truth'

# Write one prediction / ground-truth image pair per test sample
save_predictions(
    model=model,
    dataloader=test_loader,
    save_dir_pred=save_dir_pred,
    save_dir_gt=save_dir_gt,
    device=device,
)

# =================================== Evaluation =========================================

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or 0.0 when the denominator is 0."""
    return float(numerator) / float(denominator) if denominator else 0.0


def evaluate_metrics_pytorch(model, dataloader, device, threshold=0.5):
    """
    Evaluates per-image segmentation metrics and returns their means.

    Predictions are binarised with ``outputs > threshold``. Ground-truth masks are
    binarised with ``masks > 0.5`` before any integer cast, so soft mask values produced
    by resizing (e.g. 0.8 on a lesion border) count as foreground instead of being
    truncated to 0. Every metric is derived from one confusion count per image. Ratios
    whose denominator is 0 are reported as 0, matching scikit-learn's ``zero_division=0``.

    Args:
        model (nn.Module): Trained model returning probabilities shaped like the masks.
        dataloader: Iterable of ``(images, masks)`` batches (e.g. a DataLoader).
        device (str): Device to run the model on.
        threshold (float): Probability cut-off for a foreground prediction.
    Returns:
        dict: Mean of each metric over all images (NaN when the loader yields nothing).
    """
    model.eval()
    metric_names = ('Accuracy', 'Dice', 'Jaccard', 'Sensitivity',
                    'Specificity', 'Precision', 'Recall', 'F1-Score')
    all_metrics = {name: [] for name in metric_names}

    with torch.no_grad():
        for images, masks in dataloader:
            outputs = model(images.to(device))
            preds = (outputs > threshold).cpu().numpy()  # bool foreground mask
            gts = (masks > 0.5).cpu().numpy()            # binarise before any cast

            for pred, gt in zip(preds, gts):
                pred_fg = pred.ravel()
                gt_fg = gt.ravel()

                tp = int(np.count_nonzero(pred_fg & gt_fg))
                fp = int(np.count_nonzero(pred_fg & ~gt_fg))
                fn = int(np.count_nonzero(~pred_fg & gt_fg))
                tn = pred_fg.size - tp - fp - fn

                dice = _safe_ratio(2 * tp, 2 * tp + fp + fn)
                sensitivity = _safe_ratio(tp, tp + fn)

                all_metrics['Accuracy'].append(_safe_ratio(tp + tn, pred_fg.size))
                all_metrics['Dice'].append(dice)
                all_metrics['Jaccard'].append(_safe_ratio(tp, tp + fp + fn))
                all_metrics['Sensitivity'].append(sensitivity)
                all_metrics['Specificity'].append(_safe_ratio(tn, tn + fp))
                all_metrics['Precision'].append(_safe_ratio(tp, tp + fp))
                all_metrics['Recall'].append(sensitivity)  # recall == sensitivity
                all_metrics['F1-Score'].append(dice)       # F1 == Dice for binary masks

    # Compute average metrics
    avg_metrics = {
        metric: (float(np.mean(values)) if values else float('nan'))
        for metric, values in all_metrics.items()
    }

    print("Evaluation Metrics:")
    for metric, value in avg_metrics.items():
        print(f"{metric}: {value:.4f}")

    return avg_metrics

# Score the reloaded best checkpoint on the held-out test split
metrics = evaluate_metrics_pytorch(model=model, dataloader=test_loader, device=device)