In [1]:
from IPython.display import clear_output


In [2]:
!pip install wandb -qU

import wandb

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import gc

from torch.nn import Sequential

import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
%pip install torchsummary -q
from torchsummary import summary

Note: you may need to restart the kernel to use updated packages.


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class conv_block(nn.Module):
    """Two 3x3 conv + batch-norm stages with a 1x1-projected residual shortcut.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        dropout_rate: Probability passed to the shared ``nn.Dropout2d``.
    """

    def __init__(self, in_c, out_c, dropout_rate=0.4):
        super().__init__()
        # Main path: conv -> bn, twice. Padding of 1 keeps the spatial size.
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.LeakyReLU()
        # Shortcut path: 1x1 conv so the input matches the main path's channels.
        self.residual = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)
        self.dropout = nn.Dropout2d(dropout_rate)

    def forward(self, inputs):
        """Return the activated sum of the main path and the projected shortcut."""
        shortcut = self.residual(inputs)
        hidden = self.dropout(self.relu(self.bn1(self.conv1(inputs))))
        merged = self.bn2(self.conv2(hidden)) + shortcut
        # Dropout is applied before the final activation.
        return self.relu(self.dropout(merged))

class dense_block(nn.Module):
    """Apply a linear layer (plus dropout) across the channel axis of every pixel.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the returned feature map.
        dropout_rate: Probability passed to ``nn.Dropout``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.4):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Map a ``(B, in_channels, H, W)`` tensor to ``(B, out_channels, H, W)``."""
        batch_size, _, height, width = x.size()
        # Flatten the spatial axes and move channels last so nn.Linear mixes
        # channels independently at every spatial location.
        x = x.flatten(2).transpose(1, 2)
        x = self.dropout(self.linear(x))
        x = x.transpose(1, 2)
        # Restore the spatial layout using the *output* channel count; using the
        # input count here breaks whenever in_channels != out_channels.
        return x.reshape(batch_size, self.linear.out_features, height, width)

class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Divisor used for the hidden width (``channels // reduction``).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, hidden, kernel_size=1, padding=0)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Scale ``x`` channel-wise by a sigmoid gate computed from its global average."""
        squeezed = self.global_avg_pool(x)
        excited = self.fc2(self.relu(self.fc1(squeezed)))
        gate = self.sigmoid(excited)
        return x * gate

class SpatialAttention(nn.Module):
    """Produce a single-channel spatial attention map from channel statistics."""

    def __init__(self):
        super().__init__()
        # Two input planes (channel mean and channel max); padding 3 keeps H x W.
        self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(B, 1, H, W)`` map with values squashed by a sigmoid."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(stacked))

class ChannelAttention(nn.Module):
    """Channel attention built from average- and max-pooled descriptors.

    Args:
        in_planes: Number of channels of the input feature map.
        ratio: Divisor used for the hidden width (``in_planes // ratio``).
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        hidden = in_planes // ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden, kernel_size=1, padding=0, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, kernel_size=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(B, C, 1, 1)`` gate; the caller multiplies it with ``x``."""

        def shared_mlp(descriptor):
            # Both pooled descriptors go through the same two 1x1 convolutions.
            return self.fc2(self.relu1(self.fc1(descriptor)))

        from_avg = shared_mlp(self.avg_pool(x))
        from_max = shared_mlp(self.max_pool(x))
        return self.sigmoid(from_avg + from_max)

class AttentionGate(nn.Module):
    """Additive attention gate that re-weights a skip tensor using a gating tensor.

    Args:
        F_g: Channels of the gating tensor ``g``.
        F_l: Channels of the skip tensor ``x``.
        F_int: Channels of the intermediate projection shared by both inputs.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(in_channels, out_channels):
            # 1x1 convolution followed by batch-norm.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_channels),
            ]

        self.W_g = nn.Sequential(*project(F_g, F_int))
        self.W_x = nn.Sequential(*project(F_l, F_int))
        # Collapse to one channel and squash so the result acts as a mask.
        self.psi = nn.Sequential(*project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the single-channel attention mask."""
        combined = self.W_g(g) + self.W_x(x)
        mask = self.psi(self.relu(combined))
        return x * mask

class PyramidPoolingModule(nn.Module):
    """Pool the input at several grid sizes, upsample, concatenate and fuse.

    Args:
        in_channels: Channels of the input (and of the returned) feature map.
        pool_sizes: Output sizes handed to ``nn.AdaptiveAvgPool2d``, one per branch.

    Raises:
        ValueError: If ``pool_sizes`` is empty.
    """

    def __init__(self, in_channels, pool_sizes):
        super().__init__()
        pool_sizes = list(pool_sizes)
        if not pool_sizes:
            raise ValueError("pool_sizes must contain at least one pooling size")
        # Each branch gets an equal share of the channels (at least one).
        branch_channels = max(1, in_channels // len(pool_sizes))
        self.stages = nn.ModuleList([nn.AdaptiveAvgPool2d(pool_size) for pool_size in pool_sizes])
        self.convs = nn.ModuleList(
            [nn.Conv2d(in_channels, branch_channels, kernel_size=1) for _ in pool_sizes]
        )
        # Size the fusion conv from the real concatenated width. Assuming
        # ``in_channels * 2`` is only correct when in_channels is divisible by
        # the number of branches; for those cases this value is identical.
        fused_channels = in_channels + branch_channels * len(pool_sizes)
        self.conv1x1 = nn.Conv2d(fused_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        target_size = (x.size(2), x.size(3))
        pyramids = [
            F.interpolate(conv(stage(x)), size=target_size, mode='bilinear', align_corners=True)
            for stage, conv in zip(self.stages, self.convs)
        ]
        fused = torch.cat([x] + pyramids, dim=1)
        return self.conv1x1(fused)

class FeatureFusionBlock(nn.Module):
    """Concatenate two feature maps along channels and fuse them with a 1x1 conv.

    Args:
        in_channels: Channel count of the concatenated input (``x1`` plus ``x2``).
        out_channels: Channel count of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x1, x2):
        """Return ``relu(conv(cat(x1, x2)))``."""
        joined = torch.cat((x1, x2), dim=1)
        return self.relu(self.conv(joined))

class encoder_block(nn.Module):
    """Encoder stage: residual conv block, three attention steps, then 2x2 max-pool.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        dropout_rate: Dropout probability forwarded to ``conv_block``.
    """

    def __init__(self, in_c, out_c, dropout_rate=0.4):
        super().__init__()
        self.conv = conv_block(in_c, out_c, dropout_rate)
        self.pool = nn.MaxPool2d((2, 2))
        self.se = SEBlock(out_c)
        self.channel_attention = ChannelAttention(out_c)
        self.spatial_attention = SpatialAttention()

    def forward(self, inputs):
        """Return ``(features, pooled)``: the attended features and their pooled version."""
        features = self.se(self.conv(inputs))
        # Each attention module returns a mask that is multiplied back in.
        features = self.channel_attention(features) * features
        features = self.spatial_attention(features) * features
        pooled = self.pool(features)
        return features, pooled

class decoder_block(nn.Module):
    """Decoder stage: transposed-conv upsampling, gated skip fusion, conv and attention.

    Args:
        in_c: Channels of the tensor coming from the previous (deeper) stage.
        out_c: Channels of the skip tensor and of the returned tensor.
        dropout_rate: Dropout probability forwarded to ``conv_block`` and ``dense_block``.
    """

    def __init__(self, in_c, out_c, dropout_rate=0.4):
        super().__init__()
        # Kernel 2 / stride 2 transposed convolution.
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        # The conv block consumes the upsampled tensor concatenated with the skip.
        self.conv = conv_block(out_c + out_c, out_c, dropout_rate)
        self.attention = AttentionGate(F_g=out_c, F_l=out_c, F_int=out_c // 2)
        self.dense = dense_block(out_c, out_c, dropout_rate)
        self.channel_attention = ChannelAttention(out_c)
        self.spatial_attention = SpatialAttention()

    def forward(self, inputs, skip):
        """Fuse the upsampled ``inputs`` with the gated ``skip`` and return the result."""
        upsampled = self.up(inputs)
        # The skip passes through the dense block, then the attention gate
        # (with the upsampled tensor as the gating signal).
        gated_skip = self.attention(upsampled, self.dense(skip))
        merged = self.conv(torch.cat([upsampled, gated_skip], dim=1))
        merged = self.channel_attention(merged) * merged
        merged = self.spatial_attention(merged) * merged
        return merged

class Model(nn.Module):
    """Encoder / bottleneck / decoder segmentation network assembled from the blocks above.

    Args:
        encoder_cfg: Channel counts; each consecutive pair defines one encoder block.
        bottleneck_cfg: ``[in_channels, out_channels]`` of the bottleneck.
        decoder_cfg: Channel counts; each consecutive pair defines one decoder block.
        output_channels: Channels produced by the final 1x1 convolution.
        dropout_rate: Dropout probability used throughout the network.
    """

    def __init__(self, encoder_cfg, bottleneck_cfg, decoder_cfg, output_channels, dropout_rate=0.4):
        super().__init__()

        # Encoder
        self.encoder_blocks = nn.ModuleList([
            encoder_block(c_in, c_out, dropout_rate)
            for c_in, c_out in zip(encoder_cfg[:-1], encoder_cfg[1:])
        ])

        # Bottleneck with residual connection
        b_in, b_out = bottleneck_cfg[0], bottleneck_cfg[1]
        self.b_conv1 = conv_block(b_in, b_out, dropout_rate)
        self.b_conv2 = conv_block(b_out, b_out, dropout_rate)
        self.b_residual = nn.Conv2d(b_in, b_out, kernel_size=1)
        self.se = SEBlock(b_out)
        self.channel_attention = ChannelAttention(b_out)
        self.spatial_attention = SpatialAttention()
        self.ppm = PyramidPoolingModule(b_out, pool_sizes=[1, 2, 3, 6])

        # Decoder
        self.decoder_blocks = nn.ModuleList([
            decoder_block(c_in, c_out, dropout_rate)
            for c_in, c_out in zip(decoder_cfg[:-1], decoder_cfg[1:])
        ])

        # Final convolution
        self.final_conv = nn.Conv2d(decoder_cfg[-1], output_channels, kernel_size=1)
        self.dropout = nn.Dropout2d(dropout_rate)

    def forward(self, inputs):
        """Run the full network and return the output of the final 1x1 convolution."""
        # Encoder: keep the pre-pooling features of every stage as skips.
        skips = []
        x = inputs
        for stage in self.encoder_blocks:
            skip, x = stage(x)
            skips.append(skip)

        # Bottleneck with residual connection
        shortcut = self.b_residual(x)
        x = self.dropout(self.b_conv1(x))
        x = self.b_conv2(x) + shortcut
        x = self.se(x)
        x = self.channel_attention(x) * x
        x = self.spatial_attention(x) * x
        x = self.ppm(x)

        # Decoder: consume the skips deepest-first.
        skips.reverse()
        for idx, stage in enumerate(self.decoder_blocks):
            x = stage(x, skips[idx])

        # Final convolution
        return self.final_conv(self.dropout(x))


In [None]:
"""
This file needs to contain the main training loop. The training code should be encapsulated in a main() function to
avoid any global variables.
"""
        
import matplotlib.pyplot as plt
import numpy as np

from torchvision import datasets
import torchvision.transforms.v2 as transforms

from torchvision.datasets import Cityscapes
from argparse import ArgumentParser

import wandb
import torch.optim.lr_scheduler as lr_scheduler
import gc

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random


try:
    import utils
except Exception:
    import sys
    sys.path.insert(1, '/kaggle/input/5lsm0-neural-networks-for-cv-dataset')
    import utils


def get_arg_parser():
    """Build the command-line parser for the training script.

    Returns:
        An ``ArgumentParser`` that accepts ``--data_path`` (string, default ``"."``).
    """
    # Add more options here and adjust the defaults in run_container.sh as needed.
    options = {
        "--data_path": {"type": str, "default": ".", "help": "Path to the data"},
    }
    cli = ArgumentParser()
    for flag, spec in options.items():
        cli.add_argument(flag, **spec)
    return cli


def generate_random_colormap(num_classes):
    """Return ``{class_index: (r, g, b)}`` with random components in ``[0, 255]``.

    Colours are drawn from the global ``random`` generator, three draws per class
    in ascending class order.
    """
    return {
        class_index: tuple(random.randint(0, 255) for _channel in range(3))
        for class_index in range(num_classes)
    }


# Then, use the colormap as before
def colorize_mask(mask,colormap):
    """Convert a 2-D label mask into an ``(H, W, 3)`` uint8 RGB image.

    Pixels whose label has no entry in ``colormap`` stay black.
    """
    height, width = mask.shape[0], mask.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for label in colormap:
        rgb[mask == label] = colormap[label]
    return rgb

def visualize_images_and_masks(loader):
    """Plot the first four image/mask pairs of the first batch yielded by ``loader``.

    Pairs are laid out side by side on a 2x4 grid; only the first pair is titled.
    """
    images, masks = next(iter(loader))

    fig = plt.figure(figsize=(16, 10))
    for idx in range(4):
        image_ax = fig.add_subplot(2, 4, 2*idx+1, xticks=[], yticks=[])
        mask_ax = fig.add_subplot(2, 4, 2*idx+2, xticks=[], yticks=[])
        is_first = idx == 0

        plt.sca(image_ax)
        # Reorder axes (1, 2, 0) before handing the image to imshow.
        plt.imshow(np.transpose(images[idx].cpu(), (1, 2, 0)))
        if is_first:
            image_ax.set_title('Images')

        plt.sca(mask_ax)
        plt.imshow(masks[idx].squeeze().cpu(), cmap="gray")
        if is_first:
            mask_ax.set_title('Masks')
    plt.show()



import torch.nn.functional as F

class FocalLoss(nn.modules.loss._WeightedLoss):
    """Focal loss: cross-entropy scaled per element by ``(1 - p_t) ** gamma``.

    Args:
        weight: Optional per-class weight tensor, used as the alpha factor that
            balances classes.
        gamma: Focusing exponent; ``gamma=0`` reduces to plain cross-entropy.
        reduction: ``'mean'`` (average over non-ignored targets), ``'sum'`` or
            ``'none'`` (return the per-element loss).
        ignore_index: Target value excluded from the loss.
    """

    def __init__(self, weight=None, gamma=2, reduction='mean', ignore_index=255):
        super().__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Return the focal loss of logits ``input`` against class indices ``target``."""
        # Per-element cross-entropy. The focal modulation must be applied before
        # any reduction: modulating the batch mean would weight every element
        # identically. Ignored targets yield exactly 0 here.
        ce_loss = F.cross_entropy(
            input, target, reduction='none', ignore_index=self.ignore_index
        )
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        valid = target != self.ignore_index
        if self.weight is not None:
            # Replace ignored targets with class 0 so the lookup stays in range;
            # their loss is already 0, so the chosen weight does not matter.
            safe_target = target.masked_fill(~valid, 0)
            alpha = self.weight.to(device=input.device, dtype=input.dtype)
            focal_loss = focal_loss * alpha[safe_target]

        if self.reduction == 'none':
            return focal_loss
        if self.reduction == 'sum':
            return focal_loss.sum()
        # 'mean': average over non-ignored elements; the clamp avoids 0/0 when
        # every target is ignored.
        return focal_loss.sum() / valid.sum().clamp(min=1)
    
    
    
def compute_iou(preds, labels, num_classes):
    """Return the intersection-over-union of every class in ``range(num_classes)``.

    ``preds`` and ``labels`` must support element-wise ``==``, ``&`` and ``|`` and
    a ``.sum().item()`` reduction. A class that appears in neither input gets
    ``nan``.
    """

    def class_iou(cls):
        predicted = preds == cls
        actual = labels == cls
        union = (predicted | actual).sum().item()
        if union == 0:
            return float('nan')
        overlap = (predicted & actual).sum().item()
        return overlap / union

    return [class_iou(cls) for cls in range(num_classes)]
    
    
    
class EarlyStopper:
    """Signal a stop once the validation loss has been too high ``patience`` times.

    Args:
        patience: Number of counted bad checks that triggers a stop.
        min_delta: Margin above the best loss before a check counts as bad.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return ``True`` when training should stop."""
        # A new best loss resets the bad-check counter.
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Losses within ``min_delta`` of the best neither count nor reset.
        threshold = self.min_validation_loss + self.min_delta
        if validation_loss > threshold:
            self.counter += 1
            return self.counter >= self.patience
        return False
    
    
def _prepare_batch(batch, device):
    """Move an ``(images, targets)`` batch to ``device`` and map targets to train ids.

    Targets are multiplied by 255, cast to ``long``, squeezed on axis 1 and passed
    through ``utils.map_id_to_train_id``.
    """
    inputs = batch[0].to(device)
    # NOTE(review): the *255 presumably undoes a [0, 1] scaling applied to the
    # targets by the transform pipeline -- confirm against the label values
    # actually produced by the dataset.
    labels = (batch[1] * 255).to(device).long().squeeze(1)
    labels = utils.map_id_to_train_id(labels).to(device)
    return inputs, labels


def _colorize(mask, colormap):
    """Look up an RGB colour for every label in ``mask`` from a ``(N, 3)`` table."""
    # Clip so unexpected label values cannot index outside the table.
    indices = np.clip(mask.astype(np.int64), 0, len(colormap) - 1)
    return colormap[indices]


def _show_prediction(inputs, labels, outputs, colormap):
    """Plot the first sample of a batch: input image, ground truth and prediction."""
    predicted = torch.argmax(outputs, dim=1)
    print(predicted.unique())

    img = np.transpose(inputs.cpu().numpy()[0], (1, 2, 0))
    img = img * 255  # Rescale the image back to 0-255
    img = np.clip(img, 0, 255)

    label = labels.cpu().numpy()[0]
    pred = predicted.cpu().numpy()[0]
    print("Ground truth label range:", np.min(label), np.max(label))
    print("Predicted label range:", np.min(pred), np.max(pred))

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(img)
    axs[0].set_title('Input Image')
    axs[1].imshow(_colorize(label, colormap))
    axs[1].set_title('Ground Truth')
    axs[2].imshow(_colorize(pred, colormap))
    axs[2].set_title('Prediction')
    plt.show()


def main(args):
    """Train the segmentation model on Cityscapes and save its weights.

    The ``train`` split is divided 70/20/10 into training, validation and test
    subsets. Each epoch trains on the first subset, evaluates on the test subset,
    steps the LR scheduler, plots one prediction and checks early stopping. After
    training, the validation subset is evaluated once and the state dict is
    written to ``model_scripted.pth``.

    Args:
        args: Object with a ``data_path`` attribute used when the Kaggle dataset
            directory cannot be loaded; ``"."`` is used if the attribute is missing.
    """
    from IPython.display import display  # explicit: only a builtin inside notebooks

    print("Starting the main method...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data_path = getattr(args, "data_path", ".")

    # NOTE(review): this pipeline (including the random augmentations) is shared
    # by the training, validation and test subsets because they are all views of
    # one dataset object -- confirm that augmenting evaluation data is intended.
    # NOTE(review): ToTensor presumably already scales to [0, 1], in which case
    # the Lambda below divides by 255 a second time -- verify the value range.
    transform_data = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ColorJitter(contrast=0.7),
        transforms.RandomRotation([-120, 120]),
        transforms.RandomHorizontalFlip(p=0.6),
        transforms.RandomVerticalFlip(p=0.6),
        transforms.GaussianBlur(3, sigma=(0.1, 5)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x / 255.0),  # Scale the tensor values to the range [0, 1]
    ])

    # Prefer the Kaggle mount; fall back to the user-supplied path.
    kaggle_root = "/kaggle/input/5lsm0-neural-networks-for-cv-dataset"
    try:
        transformed_dataset = Cityscapes(
            kaggle_root, split='train', mode='fine', target_type='semantic', transforms=transform_data
        )
    except Exception:
        transformed_dataset = Cityscapes(
            data_path, split='train', mode='fine', target_type='semantic', transforms=transform_data
        )

    split_point1 = int(0.7 * len(transformed_dataset))  # 70% for training
    split_point2 = int(0.9 * len(transformed_dataset))  # 20% for validation, 10% for testing

    # Split the dataset
    transformed_train_dataset = torch.utils.data.Subset(transformed_dataset, range(split_point1))
    transformed_val_dataset = torch.utils.data.Subset(transformed_dataset, range(split_point1, split_point2))
    transformed_test = torch.utils.data.Subset(transformed_dataset, range(split_point2, len(transformed_dataset)))

    bs = 16  # the batch size the train/validation loaders actually use
    print("Batch size ", bs)

    train_loader = DataLoader(transformed_train_dataset, batch_size=bs, shuffle=True, num_workers=18)
    val_loader = DataLoader(transformed_val_dataset, batch_size=bs, shuffle=True, num_workers=18)
    test_loader = DataLoader(transformed_test, batch_size=4, shuffle=False, num_workers=18)

    encoder_cfg = [3, 64, 128, 256, 512]  # Example encoder configuration
    bottleneck_cfg = [512, 1024]  # Example bottleneck configuration
    decoder_cfg = [1024, 512, 256, 128, 64]  # Example decoder configuration
    output_channels = 19  # Number of output channels for segmentation

    # Place the model on the same device the batches are sent to.
    model = Model(encoder_cfg, bottleneck_cfg, decoder_cfg, output_channels).to(device)
    print("Using CUDA" if device.type == "cuda" else "Using CPU")

    display(summary(model, (3, 256, 256)))

    num_epochs = 200
    print(num_epochs)
    learning_rate = 0.001
    print("USING ", learning_rate)

    best_train_loss = float('inf')
    best_test_loss = float('inf')

    # Set the loss function and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    wd = 0.01
    print("WD: ", wd)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=wd, amsgrad=True)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    print(scheduler.get_last_lr(), " LEARNING RATE")

    # One fixed colour table covering every possible 8-bit label, so colours are
    # stable across epochs. A private RandomState leaves the global RNG alone.
    colormap = np.random.RandomState(0).randint(0, 255, size=(256, 3), dtype=np.uint8)

    early_stopper = EarlyStopper(patience=20, min_delta=10)

    # Train the model
    for epoch in range(num_epochs):
        running_loss = 0.0
        test_loss = 0.0

        model.train()
        for i, data in enumerate(train_loader):
            inputs, labels = _prepare_batch(data, device)

            optimizer.zero_grad()
            outputs = model(inputs)
            print(labels.unique())
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            print(f'Epoch {epoch + 1}, Iteration [{i}/{len(train_loader)}], Loss: {running_loss/(i+1)}, Model is in mode training: {model.training}',outputs.shape,inputs.shape)

        clear_output()

        epoch_train_loss = running_loss / len(train_loader)
        if epoch_train_loss < best_train_loss:
            best_train_loss = epoch_train_loss
            print(f"Updated train loss, now it is {best_train_loss}")
        print(f'Finished epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

        model.eval()
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                inputs, labels = _prepare_batch(data, device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                # Report the test loss here, not the training loss.
                print(f'TEST: Epoch {epoch + 1}, Iteration [{i}/{len(test_loader)}], Loss: {test_loss/(i+1)}, Model is in mode training: {model.training}')

            epoch_test_loss = test_loss / len(test_loader)
            if epoch_test_loss < best_test_loss:
                best_test_loss = epoch_test_loss
                print(f"Updated test loss, now it is {best_test_loss}")

        clear_output()

        scheduler.step()
        print(scheduler.get_last_lr(), " NEW LEARNING RATE")
        print(f'Finished TEST epoch [{epoch + 1}/{num_epochs}], Loss: {test_loss / len(test_loader):.4f}')

        # Visualize the first image of the last evaluated batch every epoch.
        _show_prediction(inputs, labels, outputs, colormap)

        # NOTE: the stopper receives the summed (not averaged) test loss, which
        # is what min_delta is compared against.
        if early_stopper.early_stop(test_loss):
            break

    if val_loader:
        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(val_loader):
                # Same label mapping as training: the criterion expects train
                # ids (0-18 plus the ignore value), not raw Cityscapes ids.
                inputs, labels = _prepare_batch(data, device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                print(f'VAL: Iteration [{i}/{len(val_loader)}], Loss: {val_loss/(i+1)}',outputs.shape,inputs.shape)

    # save model
    print(f"Best losses --> Train:{best_train_loss} and Test:{best_test_loss}")
    torch.save(model.state_dict(), "model_scripted.pth")


if __name__ == "__main__":
    gc.collect()
    # Hand main() a real namespace (it reads ``args.data_path``).
    # parse_known_args tolerates the extra argv entries a Jupyter kernel injects.
    cli_args, _unknown = get_arg_parser().parse_known_args()
    main(cli_args)
