In [2]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import timm
import numpy as np

# Define SeparableConv2d
class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super(SeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding,
            groups=in_channels, bias=bias
        )
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, bias=bias
        )

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

# Define ConvLSTMCell
class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()

        if isinstance(kernel_size, tuple):
            padding = tuple(k // 2 for k in kernel_size)
        else:
            padding = kernel_size // 2

        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias
        )

    def forward(self, x, h_cur, c_cur):
        combined = torch.cat([x, h_cur], dim=1)  # concatenate along channel axis

        conv_output = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(conv_output, conv_output.shape[1] // 4, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

# Define ConvLSTM module
class ConvLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, batch_first=True, bias=True, go_backwards=False):
        super(ConvLSTM, self).__init__()

        self.num_layers = num_layers
        self.go_backwards = go_backwards
        self.cell_list = nn.ModuleList()

        for i in range(0, self.num_layers):
            cur_input_dim = input_dim if i == 0 else hidden_dim

            self.cell_list.append(ConvLSTMCell(
                input_dim=cur_input_dim,
                hidden_dim=hidden_dim,
                kernel_size=kernel_size,
                bias=bias
            ))

    def forward(self, input_tensor):
        # input_tensor: (batch, seq_len, channels, height, width)
        b, seq_len, _, h, w = input_tensor.size()
        hidden_state = self._init_hidden(b, h, w, input_tensor.device)

        seq_indices = range(seq_len)
        if self.go_backwards:
            seq_indices = reversed(seq_indices)

        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):
            h, c = hidden_state[layer_idx]
            output_inner = []

            for t in seq_indices:
                h, c = self.cell_list[layer_idx](cur_layer_input[:, t, :, :, :], h, c)
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

        return layer_output[:, -1], h

    def _init_hidden(self, batch_size, image_height, image_width, device):
        init_states = []
        for i in range(self.num_layers):
            hidden_dim = self.cell_list[i].conv.out_channels // 4
            init_states.append((
                torch.zeros(batch_size, hidden_dim, image_height, image_width).to(device),
                torch.zeros(batch_size, hidden_dim, image_height, image_width).to(device)
            ))
        return init_states

# Swin Transformer Block using timm
class SwinTransformerBlock(nn.Module):
    def __init__(self, embed_dim, depths, num_heads, window_size, img_size, mlp_ratio=4., qkv_bias=True):
        super(SwinTransformerBlock, self).__init__()
        self.model = timm.models.swin_transformer.SwinTransformer(
            img_size=img_size,
            patch_size=4,
            in_chans=embed_dim,
            num_classes=0,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=0.,
            drop_path_rate=0.,
            ape=False,
            patch_norm=True,
            use_checkpoint=False
        )

    def forward(self, x):
        x = self.model.forward_features(x)
        return x


# TBConvL-Net Model Definition
class TBConvLNet(nn.Module):
    def __init__(self):
        super(TBConvLNet, self).__init__()

        # Convergence Path (Downsampling)
        self.conv1_1 = SeparableConv2d(3, 24, kernel_size=3, stride=1, padding=1)
        self.bn1_1 = nn.BatchNorm2d(24)
        self.conv1_2 = SeparableConv2d(24, 24, kernel_size=3, stride=1, padding=1)
        self.bn1_2 = nn.BatchNorm2d(24)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 128x128
        
        # Swin Transformer at downsampled stage
        self.swin_unet_e1 = SwinTransformerBlock(embed_dim=24, depths=[2, 2], num_heads=[3, 6], window_size=8, img_size=128)

        self.conv2_1 = SeparableConv2d(24, 48, kernel_size=3, stride=1, padding=1)
        self.bn2_1 = nn.BatchNorm2d(48)
        self.conv2_2 = SeparableConv2d(48, 48, kernel_size=3, stride=1, padding=1)
        self.bn2_2 = nn.BatchNorm2d(48)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x64
        
        self.swin_unet_e2 = SwinTransformerBlock(embed_dim=48, depths=[2, 2], num_heads=[6, 12], window_size=8,img_size=64)

        self.conv3_1 = SeparableConv2d(48, 96, kernel_size=3, stride=1, padding=1)
        self.bn3_1 = nn.BatchNorm2d(96)
        self.conv3_2 = SeparableConv2d(96, 96, kernel_size=3, stride=1, padding=1)
        self.bn3_2 = nn.BatchNorm2d(96)
        self.drop3 = nn.Dropout2d(0.5)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32x32

        # Skip connection to Swin Transformer
        self.swin_unet_e3 = SwinTransformerBlock(embed_dim=96, depths=[2, 2], num_heads=[12, 24], window_size=8,img_size=32)

        # Dense blocks
        self.conv4_1 = SeparableConv2d(96, 192, kernel_size=3, stride=1, padding=1)
        self.bn4_1 = nn.BatchNorm2d(192)
        self.conv4_2 = SeparableConv2d(192, 192, kernel_size=3, stride=1, padding=1)
        self.bn4_2 = nn.BatchNorm2d(192)
        self.drop4_1 = nn.Dropout2d(0.5)

        # Deconvergence Path (Upsampling)
        self.up6 = nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2)  # Up to 64x64
        self.bn6 = nn.BatchNorm2d(96)

        self.up7 = nn.ConvTranspose2d(96, 48, kernel_size=2, stride=2)  # Up to 128x128
        self.bn7 = nn.BatchNorm2d(48)

        self.up8 = nn.ConvTranspose2d(48, 24, kernel_size=2, stride=2)  # Up to 256x256
        self.bn8 = nn.BatchNorm2d(24)

        # Final convolutions
        self.conv_final = nn.Conv2d(24, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Convergence Path
        conv1 = F.relu(self.bn1_1(self.conv1_1(x)))
        conv1 = F.relu(self.bn1_2(self.conv1_2(conv1)))
        pool1 = self.pool1(conv1)
        print("After pool1:", pool1.shape)

        swin1 = self.swin_unet_e1(pool1)

        # Adjust channels and spatial size
        swin1 = nn.Conv2d(48, 24, kernel_size=1)(swin1)  # Reduce channels to 24
        swin1 = F.interpolate(swin1, size=(128, 128), mode='bilinear', align_corners=False)  # Resize to match expected spatial size
        print("After swin_unet_e1:", swin1.shape)

        conv2 = F.relu(self.bn2_1(self.conv2_1(swin1)))
        conv2 = F.relu(self.bn2_2(self.conv2_2(conv2)))
        pool2 = self.pool2(conv2)
        print("After pool2:", pool2.shape)

        swin2 = self.swin_unet_e2(pool2)

        conv3 = F.relu(self.bn3_1(self.conv3_1(swin2)))
        conv3 = F.relu(self.bn3_2(self.conv3_2(conv3)))
        drop3 = self.drop3(conv3)
        pool3 = self.pool3(drop3)

        swin3 = self.swin_unet_e3(pool3)

        # Dense block at bottom of network
        conv4 = F.relu(self.bn4_1(self.conv4_1(swin3)))
        dense_output = self.drop4_1(F.relu(self.bn4_2(self.conv4_2(conv4))))

        # Deconvergence Path
        up6 = F.relu(self.bn6(self.up6(dense_output) + conv3))  # Skip connection with conv3
        up7 = F.relu(self.bn7(self.up7(up6) + conv2))           # Skip connection with conv2
        up8 = F.relu(self.bn8(self.up8(up7) + conv1))           # Skip connection with conv1

        # Final Convolution
        out = self.sigmoid(self.conv_final(up8))

        return out


# Instantiate and test the model
model = TBConvLNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Create a dummy input tensor with the expected input size
dummy_input = torch.randn(1, 3, 256, 256).to(device)

# Perform a forward pass
output = model(dummy_input)
print("Output shape:", output.shape)


After pool1: torch.Size([1, 24, 128, 128])


RuntimeError: Given groups=1, weight of size [24, 48, 1, 1], expected input[1, 16, 16, 48] to have 48 channels, but got 16 channels instead

In [None]:
# Custom Dataset Class
class CustomDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None, target_transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        mask = Image.open(self.mask_paths[idx]).convert('L')

        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            mask = self.target_transform(mask)
        else:
            mask = transforms.ToTensor()(mask)

        return img, mask

# Image and Mask Paths
def get_image_mask_paths(path):
    images_dir = os.path.join(path, 'x')
    masks_dir = os.path.join(path, 'y')

    image_paths = sorted([os.path.join(images_dir, f) for f in os.listdir(images_dir)])
    mask_paths = sorted([os.path.join(masks_dir, f) for f in os.listdir(masks_dir)])

    return image_paths, mask_paths

# Define Transforms
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

target_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Get Data Paths
train_data_path = '/path/to/train'
valid_data_path = '/path/to/valid'

train_image_paths, train_mask_paths = get_image_mask_paths(train_data_path)
valid_image_paths, valid_mask_paths = get_image_mask_paths(valid_data_path)

# Create Datasets
train_dataset = CustomDataset(train_image_paths, train_mask_paths, transform=transform, target_transform=target_transform)
valid_dataset = CustomDataset(valid_image_paths, valid_mask_paths, transform=transform, target_transform=target_transform)

# Create DataLoaders
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


In [None]:
# Define Dice Loss
def dice_loss(pred, target, smooth=1.):
    pred = pred.contiguous()
    target = target.contiguous()

    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = 1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth))
    return loss.mean()

# Define Optimizer and Loss Function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = dice_loss

# Training Loop
num_epochs = 100
best_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_loader.dataset)

    # Validation
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for images, masks in valid_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            valid_loss += loss.item() * images.size(0)

    valid_loss /= len(valid_loader.dataset)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {valid_loss:.4f}')

    # Save Best Model
    if valid_loss < best_loss:
        best_loss = valid_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print('Model saved.')

# Load Best Model
model.load_state_dict(torch.load('best_model.pth'))


In [None]:
# Evaluation Function
def evaluate_metrics(y_true, y_pred):
    n = y_pred.shape[0]
    all_accuracy = np.zeros(n)
    all_dice = np.zeros(n)
    all_jaccard = np.zeros(n)
    all_sensitivity = np.zeros(n)
    all_specificity = np.zeros(n)

    for i in range(n):
        gt = y_true[i].cpu().numpy().flatten()
        pred = y_pred[i].cpu().numpy().flatten()

        precisions, recalls, thresholds = precision_recall_curve(gt, pred)
        f1 = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
        max_value = np.argmax(f1)
        thres = thresholds[max_value]
        pred_label = (pred >= thres).astype(np.uint8)

        tn, fp, fn, tp = confusion_matrix(gt, pred_label).ravel()

        accuracy = (tp + tn) / (tp + tn + fp + fn)
        iou = tp / (tp + fp + fn + 1e-8)
        dice = 2 * tp / (2 * tp + fp + fn + 1e-8)
        specificity = tn / (tn + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)

        all_accuracy[i] = accuracy
        all_dice[i] = dice
        all_jaccard[i] = iou
        all_sensitivity[i] = recall
        all_specificity[i] = specificity

    print('Accuracy: {:4f}, Dice: {:4f}, Jaccard: {:4f}, Sensitivity: {:4f}, Specificity: {:4f}'.format(
        np.nanmean(all_accuracy), np.nanmean(all_dice), np.nanmean(all_jaccard), np.nanmean(all_sensitivity), np.nanmean(all_specificity)
    ))
    return all_accuracy, all_dice, all_jaccard, all_sensitivity, all_specificity

# Evaluation on Validation Set
model.eval()
all_preds = []
all_masks = []

with torch.no_grad():
    for images, masks in valid_loader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        all_preds.append(outputs)
        all_masks.append(masks)

all_preds = torch.cat(all_preds, dim=0)
all_masks = torch.cat(all_masks, dim=0)

evl = evaluate_metrics(all_masks, all_preds)


In [None]:
# Save Predicted Masks
output_dir = '/path/to/save/segmentations'
gt_dir = '/path/to/save/GT'
os.makedirs(output_dir, exist_ok=True)
os.makedirs(gt_dir, exist_ok=True)

all_preds_np = all_preds.cpu().numpy()
all_masks_np = all_masks.cpu().numpy()

for i in range(all_preds_np.shape[0]):
    pred_mask = all_preds_np[i, 0]
    gt_mask = all_masks_np[i, 0]

    plt.imsave(os.path.join(output_dir, f"{i+1}.png"), pred_mask, cmap='gray')
    plt.imsave(os.path.join(gt_dir, f"{i+1}.png"), gt_mask, cmap='gray')
