# Import Libraries

In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to local source files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# NOTE(review): machine-specific absolute path; presumably the folder that holds the
# local segmentation_models_pytorch checkout — confirm and make it configurable.
sys.path.append("D:/pytorch")
# Star import: names used in this notebook without an explicit import (torch,
# print_versions, ...) are presumably provided by this module — verify against its source.
from segmentation_models_pytorch.utils.imports import *

# Report library versions (helper presumably comes from the star import above).
print_versions()

# Select device (GPU or CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

  from .autonotebook import tqdm as notebook_tqdm


Torch Version: 2.4.1+cu118
Using device: cuda


# Import Dataset

In [3]:
# Dataset root and the sub-folder names holding images and ground-truth masks
DATASET_DIR = "VH"
img_sub, msk_sub = 'img', 'gt'


def _paths_for(split):
    """Return the (image paths, mask paths) pair for one dataset split."""
    return get_dataset_paths(
        DATASET_DIR,
        split=split,
        img_subdir=img_sub,
        mask_subdir=msk_sub,
        mask_ext='tiff',
    )


# One (images, masks) pair per split
train_imgs, train_masks = _paths_for('train')
val_imgs, val_masks = _paths_for('val')
test_imgs, test_masks = _paths_for('test')

# Sanity check: every split should list as many masks as images
for label, imgs, masks in (
    ("training", train_imgs, train_masks),
    ("validation", val_imgs, val_masks),
    ("test", test_imgs, test_masks),
):
    print(f"Number of {label} images: {len(imgs)}, gt: {len(masks)}")
print('-' * 50)

# Per-class weights derived from the training masks
weights = calculate_class_weights(train_masks)
print("Class Weights:", weights)

Number of training images: 534, gt: 534
Number of validation images: 150, gt: 150
Number of test images: 138, gt: 138
--------------------------------------------------
Class Weights: {np.float32(0.0): np.float64(1.0), np.float32(1.0): np.float64(3.9)}


In [4]:
# Define transformations
# Training pipeline: convert to tensor, then apply random horizontal and vertical flips.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip()
])

# Train dataset with the same transformations for both images and masks
# NOTE(review): the same random-flip pipeline object is passed for the image and for the
# label. Each random flip presumably draws its own decision per call, so an image and its
# mask could end up flipped differently unless Dataset2D synchronises the two calls —
# confirm against the Dataset2D implementation.
train_dataset = Dataset2D(train_imgs, train_masks, transform=transform, transform_label=transform)

# Validation and test datasets only convert to tensors
# (no label transform is supplied here: transform_label=None)
val_dataset = Dataset2D(val_imgs, val_masks, transform=transforms.ToTensor(), transform_label=None)
test_dataset = Dataset2D(test_imgs, test_masks, transform=transforms.ToTensor(), transform_label=None)

# Model Initialization

In [5]:
import torch
import torch.nn as nn

class UNetEncoder(nn.Module):
    """Small two-stage convolutional encoder for single-channel images.

    Each stage is a 3x3 convolution (padding 1), a ReLU and a 2x2 max-pool, so the
    output has 32 channels and a quarter of the input height and width.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) of the two stages, in order
        channel_plan = ((1, 16), (16, 32))
        stages = []
        for in_ch, out_ch in channel_plan:
            stages.extend((
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ))
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` of shape [B, 1, H, W] into features of shape [B, 32, H/4, W/4]."""
        features = self.encoder(x)
        return features

class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    One convolution over the channel-wise concatenation of the input and the previous
    hidden state produces all four gate pre-activations at once.

    Args:
        input_dim: number of channels of the input tensor.
        hidden_dim: number of channels of the hidden and cell states.
        kernel_size: integer size of the square convolution kernel; padding is
            ``kernel_size // 2``.
        bias: whether the gate convolution has a bias term.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Output channels hold the four gates stacked: input, forget, output, candidate.
        self.conv = nn.Conv2d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

    def forward(self, x, hidden):
        """Advance the cell by one step.

        Args:
            x: input tensor of shape [B, input_dim, H, W].
            hidden: tuple ``(h_prev, c_prev)``, each of shape [B, hidden_dim, H, W].

        Returns:
            Tuple ``(h, c)`` with the new hidden and cell states.
        """
        prev_h, prev_c = hidden

        # Joint gate pre-activations, then split into hidden_dim-sized channel groups.
        gates = self.conv(torch.cat((x, prev_h), dim=1))
        in_gate, forget_gate, out_gate, candidate = gates.split(self.hidden_dim, dim=1)

        next_c = torch.sigmoid(forget_gate) * prev_c + torch.sigmoid(in_gate) * torch.tanh(candidate)
        next_h = torch.sigmoid(out_gate) * torch.tanh(next_c)
        return next_h, next_c

class ConvLSTM(nn.Module):
    """Stack of ConvLSTM cells unrolled over the time axis of a sequence.

    Args:
        input_dim: number of channels of each input frame.
        hidden_dim: number of channels of every layer's hidden and cell state.
        kernel_size: kernel size forwarded to each ConvLSTMCell.
        num_layers: number of stacked cells; layer 0 consumes the input frames,
            every later layer consumes the hidden state of the layer below.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers):
        super(ConvLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.layers = nn.ModuleList([
            ConvLSTMCell(
                input_dim=input_dim if i == 0 else hidden_dim,
                hidden_dim=hidden_dim,
                kernel_size=kernel_size
            ) for i in range(num_layers)
        ])

    def forward(self, x):
        """Run the whole sequence through the stack.

        Args:
            x: tensor of shape [B, T, C, H, W] with C == input_dim.

        Returns:
            Tuple ``(outputs, (h, c))``: ``outputs`` is the top layer's hidden state at
            every step, shape [B, T, hidden_dim, H, W]; ``h`` and ``c`` are lists with
            the final hidden / cell state of each layer.
        """
        B, T, _, H, W = x.shape
        # The recurrent state has hidden_dim channels, NOT the input channel count;
        # sizing it from the input only works when input_dim == hidden_dim.
        h, c = self.init_hidden(B, self.hidden_dim, H, W)

        outputs = []
        for t in range(T):
            inp = x[:, t, :, :, :]
            for i, layer in enumerate(self.layers):
                h[i], c[i] = layer(inp, (h[i], c[i]))
                inp = h[i]  # feed this layer's hidden state to the next layer
            outputs.append(h[-1])

        return torch.stack(outputs, dim=1), (h, c)

    def init_hidden(self, B, C, H, W):
        """Create zero-filled hidden and cell states for every layer.

        Args:
            B: batch size.
            C: number of state channels.
            H, W: spatial size of the state.

        Returns:
            Tuple ``(h, c)`` of lists, one [B, C, H, W] tensor per layer, placed on the
            same device and with the same dtype as the module's parameters.
        """
        reference = next(self.parameters())

        def zero_states():
            return [
                torch.zeros(B, C, H, W, device=reference.device, dtype=reference.dtype)
                for _ in range(self.num_layers)
            ]

        return zero_states(), zero_states()

class CNNConvLSTMNet(nn.Module):
    """Per-frame CNN features -> ConvLSTM over time -> convolutional head on the last step.

    Args:
        cnn_backbone: module applied to every frame as a [B, 1, H, W] tensor; it must
            return ``feature_channels`` channels.
        feature_channels: channel count of the backbone output (ConvLSTM input).
        temporal_channels: hidden channel count of the ConvLSTM (decoder input).
        kernel_size: kernel size of the ConvLSTM cells.
        num_layers: number of stacked ConvLSTM layers.
    """

    def __init__(self, cnn_backbone, feature_channels, temporal_channels, kernel_size, num_layers):
        super().__init__()

        # Spatial feature extractor shared by all frames
        self.cnn_backbone = cnn_backbone

        # Temporal model over the per-frame feature maps
        self.conv_lstm = ConvLSTM(
            input_dim=feature_channels,
            hidden_dim=temporal_channels,
            kernel_size=kernel_size,
            num_layers=num_layers,
        )

        # Minimal segmentation head: 3x3 conv -> ReLU -> 1x1 conv -> sigmoid probability
        head = [
            nn.Conv2d(temporal_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*head)

    def forward(self, x):
        """Segment a frame sequence.

        Args:
            x: tensor of shape [B, T, H, W] (one single-channel frame per time step).

        Returns:
            Tensor of shape [B, H', W'] with values in [0, 1], where H', W' is the
            spatial size produced by the backbone.
        """
        # Unpacking also enforces a 4-D input.
        _, num_frames, _, _ = x.shape

        # Backbone features per frame, each frame given a channel axis -> [B, 1, H, W]
        per_frame = [self.cnn_backbone(x[:, t].unsqueeze(1)) for t in range(num_frames)]
        sequence = torch.stack(per_frame, dim=1)  # [B, T, C, H', W']

        # Temporal modelling; only the hidden state of the final step is decoded
        temporal, _ = self.conv_lstm(sequence)
        probabilities = self.decoder(temporal[:, -1])  # [B, 1, H', W']

        return probabilities.squeeze(1)

In [6]:
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights


class EfficientNetB0Backbone(nn.Module):
    """EfficientNet-B0 feature extractor for single-channel images.

    The stem convolution is replaced by a one-channel version, the classification head
    is ignored, a 1x1 convolution reduces the final feature map to ``output_channels``
    and the result is bilinearly resized back to the input resolution.

    Args:
        pretrained: load the default torchvision EfficientNet-B0 weights when True.
        output_channels: number of channels of the returned feature map.
    """

    def __init__(self, pretrained=True, output_channels=32):
        super(EfficientNetB0Backbone, self).__init__()
        # Load EfficientNet-B0 model
        weights = EfficientNet_B0_Weights.DEFAULT if pretrained else None
        self.model = efficientnet_b0(weights=weights)

        # Replace the RGB stem convolution with a single-channel one.
        original_conv = self.model.features[0][0]
        single_channel_conv = nn.Conv2d(
            in_channels=1,  # Single-channel input
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            # `bias` must be a bool; passing the bias tensor itself would be wrong
            # (and raises for a multi-element tensor).
            bias=original_conv.bias is not None,
        )
        if pretrained:
            # Keep the pretrained stem instead of restarting it from random weights:
            # summing the RGB filters gives exactly the response the original stem has
            # for a gray image replicated over the three colour channels.
            with torch.no_grad():
                single_channel_conv.weight.copy_(original_conv.weight.sum(dim=1, keepdim=True))
                if original_conv.bias is not None:
                    single_channel_conv.bias.copy_(original_conv.bias)
        self.model.features[0][0] = single_channel_conv

        # Extract EfficientNet features (exclude classification head)
        self.features = self.model.features

        # Channel count of the last feature block, read from its convolution
        last_feature_channels = self.features[-1][0].out_channels

        # 1x1 convolution reducing the channels to `output_channels`
        self.reduce_channels = nn.Conv2d(
            in_channels=last_feature_channels,
            out_channels=output_channels,
            kernel_size=1,
            bias=False
        )

    def forward(self, x):
        """
        Forward pass for the EfficientNetB0 backbone.
        Args:
            x: Input tensor of shape [B, 1, H, W] (single channel input).
        Returns:
            Features of shape [B, output_channels, H, W].
        """
        input_size = tuple(x.shape[2:])  # (H, W) of the input

        # Extract features using EfficientNet
        features = self.features(x)  # Shape [B, last_feature_channels, H', W']

        # Reduce channels to the desired output
        reduced_features = self.reduce_channels(features)  # Shape [B, output_channels, H', W']

        # Resize to the exact input size. Requesting the target size directly avoids
        # mutating module state on every call and cannot miss the size by a pixel the
        # way a floating-point scale factor can.
        upsampled_features = nn.functional.interpolate(
            reduced_features, size=input_size, mode="bilinear", align_corners=False
        )  # Shape [B, output_channels, H, W]

        return upsampled_features

In [7]:
# Replace DummyCNN with UNetEncoder
# cnn_backbone = UNetEncoder()
cnn_backbone = EfficientNetB0Backbone(pretrained=True, output_channels=32)


# Initialize the CNN-ConvLSTM model
model = CNNConvLSTMNet(cnn_backbone, feature_channels=32, temporal_channels=32, kernel_size=3, num_layers=1)

# Example input: [B, T, H, W]
input_tensor = torch.randn(1, 11, 512, 512)

# Shape smoke test. Run it in eval mode and without autograd so the random input does
# not update the BatchNorm running statistics of the backbone and no graph is kept in
# memory; the previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    output = model(input_tensor)
model.train(was_training)
print("Output shape:", output.shape)  # Expected: [B, H, W]

# Fail loudly (and report the offending shape) if the model does not preserve resolution
if output.shape[1:] != (512, 512):
    raise ValueError(f"The output shape is incorrect! Got {tuple(output.shape)}")

Output shape: torch.Size([1, 512, 512])


# Configuring Training

In [None]:
# Hyperparameters
LEARNING_RATE = 0.001
BATCH_SIZE_TRAIN = 2
BATCH_SIZE_VALID = 2
BATCH_SIZE_TEST = 2

# Define the loss function (DiceLoss or CrossEntropyLoss)
# NOTE(review): `smp` is not imported explicitly in this notebook; presumably it comes
# from the star import in the setup cell — confirm.
loss = smp.utils.losses.DiceLoss()  # Change to CrossEntropyLoss() if needed
#loss = smp.utils.losses.CrossEntropyLoss()

# Define the metric for evaluation. IoU (Intersection over Union) is a standard metric for segmentation.
#metrics = [smp.utils.metrics.mIoU()]
metrics = [smp.utils.metrics.IoU()]

# Initialize the optimizer
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Data loaders (only the training loader shuffles; num_workers=0 loads in the main process)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True, num_workers=0)
valid_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE_VALID, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False, num_workers=0)

# Training loop setup using SMP utilities
# NOTE(review): the model is never moved to DEVICE explicitly in this cell; presumably
# TrainEpoch / ValidEpoch do that with the `device` argument — confirm.
train_epoch = TrainEpoch(model, loss=loss, metrics=metrics, optimizer=opt, device=DEVICE)
valid_epoch = ValidEpoch(model, loss=loss, metrics=metrics, device=DEVICE)

# Verify batch: print the shapes of the first training batch only
for images, labels in train_loader:
    print("Image batch shape:", images.shape)
    print("Label batch shape:", labels.shape)
    break

Image batch shape: torch.Size([2, 11, 512, 512])
Label batch shape: torch.Size([2, 512, 512])


# Training Loop

In [None]:
import os

# Best validation IoU seen so far; a checkpoint is written whenever it improves
max_iou = 0

# Number of epochs to train
EPOCHS = 100

# Checkpoint file. Its folder is created up front so that a missing directory fails
# (or is fixed) immediately instead of crashing torch.save after the first, hours-long epoch.
MODEL_SAVE_PATH = 'F:/CNN_LSTM/test_models/test_convlstm_b0.pth'
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

# Run the training loop for the specified number of epochs
for epoch in range(EPOCHS):
    print(f'\nEpoch: {epoch + 1}/{EPOCHS}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # If validation IoU improves, save the model's state dictionary
    if max_iou < valid_logs['iou_score']:
        max_iou = valid_logs['iou_score']
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print('Model saved!')


Epoch: 1/100
train: 100%|██████████| 267/267 [8:12:49<00:00, 110.75s/it, DiceLoss - 0.3819, iou_score - 0.4769]  
valid: 100%|██████████| 75/75 [10:56<00:00,  8.75s/it, DiceLoss - 0.2654, iou_score - 0.5872]
Model saved!

Epoch: 2/100
train: 100%|██████████| 267/267 [8:08:19<00:00, 109.74s/it, DiceLoss - 0.3057, iou_score - 0.5566]  
valid: 100%|██████████| 75/75 [11:01<00:00,  8.82s/it, DiceLoss - 0.2247, iou_score - 0.6458]
Model saved!

Epoch: 3/100
train: 100%|██████████| 267/267 [7:50:02<00:00, 105.63s/it, DiceLoss - 0.2757, iou_score - 0.5921]  
valid: 100%|██████████| 75/75 [10:23<00:00,  8.32s/it, DiceLoss - 0.3006, iou_score - 0.5608]

Epoch: 4/100
train: 100%|██████████| 267/267 [7:42:06<00:00, 103.84s/it, DiceLoss - 0.2336, iou_score - 0.6433]  
valid: 100%|██████████| 75/75 [10:19<00:00,  8.26s/it, DiceLoss - 0.1911, iou_score - 0.6835]
Model saved!

Epoch: 5/100
train: 100%|██████████| 267/267 [7:43:07<00:00, 104.07s/it, DiceLoss - 0.1997, iou_score - 0.6874]  
valid: 100

# Loading the Pre-trained Model

In [9]:
# Best validation IoU recorded by the training loop. This reads `max_iou` from the
# live kernel, so it is only meaningful after the training cell has run in this session.
print(f"Max IoU: {max_iou}")

Max IoU: 0.623845636844635


In [10]:
# Restore the best checkpoint written by the training loop.
# map_location lets the file load on a machine without a GPU as well, and
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs.
state_dict = torch.load(
    'F:/CNN_LSTM/test_models/test_convlstm_b0.pth',
    map_location=DEVICE,
    weights_only=True,
)
model.load_state_dict(state_dict)
# Use the device selected in the setup cell rather than a hard-coded 'cuda'
model.to(DEVICE)

  model.load_state_dict(torch.load('F:/CNN_LSTM/test_models/test_convlstm_b7.pth'))


CNNConvLSTMNet(
  (cnn_backbone): EfficientNetB0Backbone(
    (model): EfficientNet(
      (features): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU(inplace=True)
        )
        (1): Sequential(
          (0): MBConv(
            (block): Sequential(
              (0): Conv2dNormActivation(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
                (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): SiLU(inplace=True)
              )
              (1): SqueezeExcitation(
                (avgpool): AdaptiveAvgPool2d(output_size=1)
                (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
                (fc2): Conv2d(8, 32, kernel_size=(1, 1), str

# Model Evaluation

In [None]:
from segmentation_models_pytorch.utils.model_eval import display_binary_metrics

# Compute binary segmentation metrics (IoU, precision, recall, F1) on the test loader,
# thresholding the model output at 0.5.
# NOTE(review): the recorded run of this cell raised a ValueError about mismatched
# spatial dimensions ([512] vs [512, 512]). Presumably the helper resizes predictions
# and expects a [B, C, H, W] output, while this model's forward returns [B, H, W] —
# confirm against display_binary_metrics.
metrics_df = display_binary_metrics(model, test_loader, DEVICE, threshold=0.5, 
                               show_iou=True, show_precision=True, show_recall=True, show_f1_score=True)

ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [512] and output size of torch.Size([512, 512]). Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.

In [None]:
# Display the metrics object returned by the evaluation cell (bare last expression
# so the notebook renders it with its rich representation).
metrics_df

# Visualizing Model Predictions

In [None]:
from segmentation_models_pytorch.utils.visualization import visualize_predictions

# For binary segmentation:
# Shows predictions for 5 images drawn from the validation loader, thresholded at 0.5.
visualize_predictions(model, valid_loader, DEVICE, num_images=5, binary=True, threshold=0.5)

# For multiclass segmentation:
# NOTE(review): this example still passes binary=True; presumably it should be
# binary=False for multiclass — confirm against visualize_predictions.
# visualize_predictions(model, valid_loader, DEVICE, num_images=5, binary=True)