In [None]:
# Mount Google Drive at /content/drive (the google.colab package is only available on Colab).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install the third-party packages imported by the cells below (no versions are pinned here).
!pip install torch torchvision opencv-python



## HELPER FUNCTIONS

In [None]:
## HELPER FUNCTIONS
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms

# Per-class RGB matching ranges, indexed by class id.
# Each entry is [(min_r, min_g, min_b), (max_r, max_g, max_b)]. rgb_to_class_idx treats both
# bounds as inclusive; one_hot_to_rgb renders a class with the first tuple (the lower bound),
# not with the nominal colour quoted in the trailing comment of each row.
# NOTE(review): the nominal colours look like the CamVid 11-class palette, where for example
# (128, 128, 128) is usually "Sky" rather than "Road" — confirm the class names against the
# dataset's label file before relying on them.
idx_to_color = [
    [(120, 120, 120), (135, 135, 135)],  # 0: Road (approximate range for [128, 128, 128])
    [(120, 0, 0), (135, 15, 15)],        # 1: Sidewalk (approximate range for [128, 0, 0])
    [(185, 185, 120), (200, 200, 135)],  # 2: Building (approximate range for [192, 192, 128])
    [(120, 55, 120), (135, 75, 135)],    # 3: Vegetation (approximate range for [128, 64, 128])
    [(50, 30, 215), (70, 50, 230)],      # 4: Sky (approximate range for [60, 40, 222])
    [(120, 120, 0), (135, 135, 15)],     # 5: Traffic Sign (approximate range for [128, 128, 0])
    [(185, 120, 120), (200, 135, 135)],  # 6: Pedestrian (approximate range for [192, 128, 128])
    [(55, 55, 120), (75, 75, 135)],      # 7: Vehicle (approximate range for [64, 64, 128])
    [(55, 0, 120), (75, 15, 135)],       # 8: Pole (approximate range for [64, 0, 128])
    [(55, 55, 0), (75, 75, 15)],         # 9: Fence (approximate range for [64, 64, 0])
    [(0, 120, 185), (15, 135, 200)]      # 10: Road Marking (approximate range for [0, 128, 192])
]

# Class index -> category id. The mapping is the identity over the 11 classes, in this order:
# 0 Road, 1 Sidewalk, 2 Building, 3 Vegetation, 4 Sky, 5 Traffic Sign,
# 6 Pedestrian, 7 Vehicle, 8 Pole, 9 Fence, 10 Road Marking.
name_to_category = {class_idx: class_idx for class_idx in range(11)}

def one_hot_to_rgb(one_hot_mask, idx_to_color):
    """Render a batch of one-hot (or per-class score) masks as RGB images.

    The class of every pixel is the argmax over the channel dimension; each class
    is then painted with the first tuple of its ``idx_to_color`` entry. The lookup
    is done with a single palette-indexing operation on the mask's own device, so
    no per-sample CPU/numpy round trip is needed.

    Args:
        one_hot_mask: Tensor of shape (B, num_classes, H, W).
        idx_to_color: Sequence indexed by class id whose entries are
            ``[first_rgb, second_rgb]`` pairs; only ``first_rgb`` is used here.

    Returns:
        float32 tensor of shape (B, 3, H, W) on the same device as
        ``one_hot_mask``; colour components are divided by 255. Pixels whose
        class id has no entry in ``idx_to_color`` stay black (all zeros).

    Raises:
        ValueError: If ``one_hot_mask`` does not have exactly four dimensions.
    """
    if one_hot_mask.dim() != 4:
        raise ValueError("Expected input with batch dimension [B, C, H, W]")

    device = one_hot_mask.device
    num_classes = one_hot_mask.shape[1]

    # One palette row per possible class id; rows without a colour entry remain zero (black).
    palette = torch.zeros(max(num_classes, len(idx_to_color)), 3, dtype=torch.float32, device=device)
    if len(idx_to_color) > 0:
        palette[:len(idx_to_color)] = torch.tensor(
            [color_range[0] for color_range in idx_to_color],
            dtype=torch.float32,
            device=device,
        )
    palette = palette / 255.0

    class_indices = torch.argmax(one_hot_mask, dim=1)  # (B, H, W)
    # Indexing yields (B, H, W, 3); move the colour axis to channel position.
    return palette[class_indices].permute(0, 3, 1, 2).contiguous()



def canny_edge_extraction(mask):
    """Compute a Canny edge map for an RGB mask.

    The mask is converted to a numpy array, reduced to grayscale and passed through
    cv2.Canny with thresholds 100 and 200. The result is returned as a float32 tensor
    with a leading channel axis, with the edge values divided by 255.
    """
    mask_array = np.array(mask)
    grayscale = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
    edge_map = cv2.Canny(grayscale, threshold1=100, threshold2=200)
    # Prepend the channel axis: (H, W) -> (1, H, W)
    edge_map = edge_map[np.newaxis, ...]
    edge_tensor = torch.tensor(edge_map, dtype=torch.float32)
    return edge_tensor / 255.0

def rgb_to_class_idx(mask, color_ranges=None, unknown_idx=0):
    """Convert an RGB mask to a map of class indices using inclusive colour ranges.

    Args:
        mask: RGB mask convertible to an array of shape (H, W, 3), e.g. a PIL image.
        color_ranges: Sequence of ``(min_rgb, max_rgb)`` pairs indexed by class id.
            Defaults to the module-level ``idx_to_color`` table.
        unknown_idx: Index assigned to pixels that fall inside none of the ranges.
            The default of 0 keeps the historical behaviour, in which such pixels
            are indistinguishable from class 0; pass a dedicated value to tell
            them apart.

    Returns:
        torch.long tensor of shape (H, W). If ranges overlap, the class that
        appears later in ``color_ranges`` wins.

    Raises:
        ValueError: If the mask is not of shape (H, W, 3).
    """
    if color_ranges is None:
        color_ranges = idx_to_color

    mask = np.asarray(mask)
    if mask.ndim != 3 or mask.shape[2] != 3:
        raise ValueError(f"Expected an RGB mask of shape (H, W, 3), got {mask.shape}")

    class_mask = np.full(mask.shape[:2], unknown_idx, dtype=np.int32)
    for class_idx, (min_color, max_color) in enumerate(color_ranges):
        # A pixel belongs to the class only when all three channels are inside the range.
        in_range = ((mask >= min_color) & (mask <= max_color)).all(axis=2)
        class_mask[in_range] = class_idx
    return torch.tensor(class_mask, dtype=torch.long)

def one_hot_encode(mask, num_classes):
    """One-hot encode an (H, W) class-index mask into a float (num_classes, H, W) tensor."""
    encoded = torch.nn.functional.one_hot(mask, num_classes=num_classes)  # (H, W, num_classes)
    channels_first = encoded.permute(2, 0, 1)
    return channels_first.float()

def preprocess_mask(mask, num_classes=11):
    """Build a combined mask tensor: one-hot class channels followed by a Canny edge channel.

    The RGB mask is mapped to class indices, one-hot encoded into ``num_classes``
    channels, and the single-channel edge map from ``canny_edge_extraction`` is
    appended along the channel dimension.
    """
    class_channels = one_hot_encode(rgb_to_class_idx(mask), num_classes)
    edge_channel = canny_edge_extraction(mask)
    # Channel order: the num_classes one-hot planes first, then the edge plane.
    return torch.cat((class_channels, edge_channel), dim=0)


## MAKING THE CUSTOM DATASET

In [None]:
## DATASET

from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image

class CityscapesDataset(Dataset):
    """Paired image / RGB-mask dataset read from two directories.

    Every file in ``image_dir`` is expected to have a mask with the same file
    name in ``mask_dir``. Items are ``(image, one_hot_mask)`` where the mask is
    converted from RGB to class indices and then one-hot encoded.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the RGB masks (same file names).
        transform: Optional callable applied to the image only.
        num_classes: Number of channels of the one-hot mask.
        mask_size: Optional ``(height, width)``. When given, the mask is resized
            to this size with nearest-neighbour sampling (so no new colours are
            created) before encoding. ``transform`` never touches the mask, so
            without ``mask_size`` the mask keeps its on-disk resolution even when
            the image is resized.
    """

    def __init__(self, image_dir, mask_dir, transform=None, num_classes=11, mask_size=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # os.listdir returns entries in arbitrary order and also lists sub-directories;
        # keep regular files only and sort them so indices are reproducible across runs.
        self.image_files = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )
        self.transform = transform
        self.num_classes = num_classes
        self.mask_size = mask_size

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and its mask; return ``(image, one_hot_mask)``."""
        file_name = self.image_files[idx]
        image_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name)

        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("RGB")

        if self.mask_size is not None:
            height, width = self.mask_size
            # PIL expects (width, height); NEAREST keeps the label colours intact.
            mask = mask.resize((width, height), Image.NEAREST)

        # Convert mask from RGB to class indices and one-hot encode it
        class_mask = rgb_to_class_idx(mask)
        one_hot_mask = one_hot_encode(class_mask, self.num_classes)

        if self.transform:
            image = self.transform(image)

        return image, one_hot_mask


## MODEL DEFINITION

In [None]:
## DEFINING THE MODEL
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

class PatchEmbedding(nn.Module):
    """Project an image into a grid of patch embeddings with a strided convolution.

    The Conv2d uses ``patch_size`` as both kernel size and stride, so every
    non-overlapping patch_size x patch_size patch becomes one ``embed_dim`` vector.
    """

    def __init__(self, in_channels, embed_dim, patch_size=4):
        super().__init__()
        self.patch_embed = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map (B, in_channels, H, W) to the convolution's patch grid output."""
        patches = self.patch_embed(x)
        return patches

class SwinTransformerBlock(nn.Module):
    """Transformer block with multi-head self-attention inside local windows.

    The feature map is split into non-overlapping ``window_size`` x ``window_size``
    windows and attention is computed among the tokens of each window, followed
    by a position-wise feed-forward network. Both sub-layers are pre-normalised
    with LayerNorm and wrapped in residual connections.

    Earlier, all pixels were flattened and handed to ``nn.MultiheadAttention`` as
    a tensor of shape (1, B*H*W, C). With the default ``batch_first=False`` that
    is a sequence of length 1, so every token attended only to itself and no
    spatial mixing took place. The parameters (and state_dict keys) are unchanged.

    NOTE: windows are not shifted between blocks, so stacked blocks do not
    exchange information across window borders by themselves.

    Args:
        dim: Channel dimension of the tokens; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Side length of the square attention window, in tokens.
    """

    def __init__(self, dim, num_heads, window_size=7):
        super(SwinTransformerBlock, self).__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.window_size = window_size
        self.ln1 = nn.LayerNorm(dim)
        self.msa = nn.MultiheadAttention(dim, num_heads)
        self.ln2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.ReLU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        """Apply windowed attention and the FFN to a (B, C, H, W) map; shape is preserved."""
        batch_size, channels, height, width = x.shape
        ws = self.window_size

        # Channels-last layout so LayerNorm / Linear act on the channel axis.
        x = x.permute(0, 2, 3, 1)  # (B, H, W, C)
        normed = self.ln1(x)

        # Pad bottom/right so both spatial sizes are multiples of the window size.
        pad_h = (-height) % ws
        pad_w = (-width) % ws
        padded_h, padded_w = height + pad_h, width + pad_w
        if pad_h or pad_w:
            normed = F.pad(normed, (0, 0, 0, pad_w, 0, pad_h))
        rows, cols = padded_h // ws, padded_w // ws

        # (B, H', W', C) -> (B * num_windows, ws * ws, C)
        windows = (
            normed.reshape(batch_size, rows, ws, cols, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(batch_size * rows * cols, ws * ws, channels)
        )
        # nn.MultiheadAttention (batch_first=False) expects (sequence, batch, embed).
        tokens = windows.transpose(0, 1)

        key_padding_mask = None
        if pad_h or pad_w:
            # Mark padded positions so they are never attended to. The pad is always
            # smaller than one window, so every window keeps at least one real token.
            is_pad = torch.zeros(padded_h, padded_w, dtype=torch.bool, device=x.device)
            is_pad[height:, :] = True
            is_pad[:, width:] = True
            key_padding_mask = (
                is_pad.reshape(rows, ws, cols, ws)
                .permute(0, 2, 1, 3)
                .reshape(rows * cols, ws * ws)
                .repeat(batch_size, 1)
            )

        attended, _ = self.msa(
            tokens, tokens, tokens,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )

        # Undo the window partition and drop the padding again.
        attended = (
            attended.transpose(0, 1)
            .reshape(batch_size, rows, cols, ws, ws, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(batch_size, padded_h, padded_w, channels)
        )[:, :height, :width, :]
        x = x + attended

        # Position-wise feed-forward network with its own residual connection.
        x = x + self.ffn(self.ln2(x))

        # Back to (B, C, H, W)
        return x.permute(0, 3, 1, 2)


class Encoder(nn.Module):
    """Encoder: patch embedding followed by two transformer blocks with 3 heads each."""

    def __init__(self, in_channels, embed_dim):
        super().__init__()
        self.patch_embed = PatchEmbedding(in_channels, embed_dim)
        blocks = [SwinTransformerBlock(embed_dim, num_heads=3) for _ in range(2)]
        self.swin_transformers = nn.ModuleList(blocks)

    def forward(self, x):
        """Embed the input into patches, then run the blocks in sequence."""
        features = self.patch_embed(x)
        for block in self.swin_transformers:
            features = block(features)
        return features

class CRFBlock(nn.Module):
    """Conditional Residual Fusion block: two 3x3 convolutions, then WAM, then OLM."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.lrelu = nn.LeakyReLU(0.2)
        self.olm = OppositionLearningMechanism(out_channels)
        self.wam = WeightAssignmentMechanism(out_channels)

    def forward(self, x, condition):
        """Transform ``x`` and fuse it with ``condition``.

        The LeakyReLU sits between the two convolutions only; the weight
        assignment runs on the second convolution's output, and the opposition
        mechanism (which adapts ``condition`` to ``x``) runs last.
        """
        features = self.conv2(self.lrelu(self.conv1(x)))
        weighted = self.wam(features)
        return self.olm(weighted, condition)


class OppositionLearningMechanism(nn.Module):
    """Blend features with a condition map as ``x * (1 - c) + c``.

    ``c`` is the condition after a 1x1 convolution (``x_channels`` in and out)
    and bilinear resizing to the spatial size of ``x``.
    """

    def __init__(self, x_channels):
        super().__init__()
        # 1x1 projection of the condition; channel count stays x_channels.
        self.channel_adjust = nn.Conv2d(in_channels=x_channels, out_channels=x_channels, kernel_size=1)

    def forward(self, x, condition):
        """Return ``x * (1 - c) + c`` with ``c`` derived from ``condition``."""
        projected = self.channel_adjust(condition)
        target_size = (x.shape[2], x.shape[3])
        resized = F.interpolate(projected, size=target_size, mode="bilinear", align_corners=True)
        return x * (1 - resized) + resized

class WeightAssignmentMechanism(nn.Module):
    """Channel gating: scale each channel by a sigmoid weight.

    The weight is the global average (over H and W) of a 1x1 convolution of the
    input, passed through a sigmoid, giving one value per sample and channel.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate."""
        pooled = self.pool(self.conv(x))  # (B, C, 1, 1)
        gate = self.sigmoid(pooled)
        return x * gate

class Decoder(nn.Module):
    """Decoder: CRF block, two transformer blocks (3 heads each), 3x3 output conv and tanh.

    None of the layers changes the spatial size, so the output has the same
    height and width as the feature map passed in as ``x``.
    """

    def __init__(self, embed_dim, out_channels):
        super().__init__()
        self.crf_block = CRFBlock(embed_dim, embed_dim)
        blocks = [SwinTransformerBlock(embed_dim, num_heads=3) for _ in range(2)]
        self.swin_transformers = nn.ModuleList(blocks)
        self.final_conv = nn.Conv2d(embed_dim, out_channels, kernel_size=3, padding=1)

    def forward(self, x, low_level_feat):
        """Fuse ``x`` with ``low_level_feat`` and decode to ``out_channels`` values in [-1, 1]."""
        fused = self.crf_block(x, low_level_feat)
        for block in self.swin_transformers:
            fused = block(fused)
        projected = self.final_conv(fused)
        return torch.tanh(projected)

class MultiScaleDiscriminator(nn.Module):
    """Ensemble of two PatchGAN discriminators whose outputs are averaged.

    NOTE(review): both discriminators receive the input at the same resolution,
    so no multi-scale processing actually happens here. The result is also
    reduced to a single 0-dim tensor for the whole batch, so a hinge loss applied
    to it acts on the overall mean rather than on individual patches. Changing
    either would alter the return shape, so it is left for a deliberate decision.
    """

    def __init__(self, in_channels):
        super().__init__()
        blocks = [PatchGANDiscriminator(in_channels) for _ in range(2)]
        self.discriminator_blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Return the mean over all discriminators, samples and patches (0-dim tensor)."""
        scores = [block(x) for block in self.discriminator_blocks]
        return torch.stack(scores, dim=0).mean()

class PatchGANDiscriminator(nn.Module):
    """Single PatchGAN discriminator.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 channels) with LeakyReLU(0.2),
    InstanceNorm on all but the first, followed by a stride-1 4x4 convolution that
    outputs one score channel.
    """

    def __init__(self, in_channels):
        super().__init__()
        # First stage has no normalisation.
        stages = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Remaining downsampling stages: conv -> instance norm -> leaky ReLU.
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        # Final projection to a single-channel patch score map.
        stages.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return the patch score map for ``x``."""
        return self.layers(x)

class SC_UNet(nn.Module):
    """SC-UNet: encoder and decoder (generator) bundled with a discriminator for GAN training."""

    def __init__(self, in_channels, embed_dim, out_channels):
        super().__init__()
        self.encoder = Encoder(in_channels, embed_dim)
        self.decoder = Decoder(embed_dim, out_channels)
        self.discriminator = MultiScaleDiscriminator(out_channels)

    def forward(self, x):
        """Return ``(generated, discriminator_output)`` for the input batch ``x``.

        The encoder output is used twice: as the decoder input and as its
        conditioning (low-level) feature map.
        """
        features = self.encoder(x)
        generated = self.decoder(features, features)
        realism = self.discriminator(generated)
        return generated, realism


# Instantiate the model (3 input channels, 96-dim patch embeddings, 3 output channels)
# and print its module tree.
model = SC_UNet(in_channels=3, embed_dim=96, out_channels=3)
print(model)


SC_UNet(
  (encoder): Encoder(
    (patch_embed): PatchEmbedding(
      (patch_embed): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    )
    (swin_transformers): ModuleList(
      (0-1): 2 x SwinTransformerBlock(
        (ln1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (msa): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=96, out_features=96, bias=True)
        )
        (ln2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (ffn): Sequential(
          (0): Linear(in_features=96, out_features=384, bias=True)
          (1): ReLU()
          (2): Linear(in_features=384, out_features=96, bias=True)
        )
      )
    )
  )
  (decoder): Decoder(
    (crf_block): CRFBlock(
      (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (lrelu): LeakyReLU(negative_slope=0.2)
      (olm): OppositionLearningMecha

## LOSS FUNCTIONS AND OPTIMISERS

In [None]:
## LOSS FUNCTION AND OPTIMISER
import torch.optim as optim
import torch.nn.functional as F

def hinge_adversarial_loss(real, fake):
    """Hinge loss for the discriminator.

    Penalises scores on real samples that are below +1 and scores on fake
    samples that are above -1; each term is averaged over all elements.
    """
    real_term = F.relu(1.0 - real).mean()
    fake_term = F.relu(1.0 + fake).mean()
    return real_term + fake_term

def feature_matching_loss(real_features, fake_features):
    """Mean absolute error between the RGB rendering of a one-hot mask and generated images.

    ``real_features`` is a one-hot mask batch that is first rendered to RGB with
    ``one_hot_to_rgb``; if its shape differs from ``fake_features`` it is
    bilinearly resized to the spatial size of ``fake_features``.
    """
    target_rgb = one_hot_to_rgb(real_features, idx_to_color)

    # Bring the target to the generator's output resolution when the shapes disagree.
    if target_rgb.shape != fake_features.shape:
        target_rgb = F.interpolate(
            target_rgb,
            size=fake_features.shape[2:],
            mode="bilinear",
            align_corners=False,
        )

    return (target_rgb - fake_features).abs().mean()

def perceptual_loss(real, fake, feature_extractor):
    """Mean absolute difference between extracted features of real and fake inputs.

    The features are compared directly. They are not routed through
    ``feature_matching_loss``, because that function renders its first argument
    as a one-hot mask (argmax -> palette colours), which is meaningless for
    arbitrary feature maps and produces a 3-channel tensor that generally does
    not match the feature shape.

    Args:
        real: Batch of real inputs.
        fake: Batch of generated inputs.
        feature_extractor: Callable returning either one feature tensor or a
            list/tuple of feature tensors (e.g. one per layer).

    Returns:
        Scalar tensor: the L1 distance per feature map, averaged over the maps.

    Raises:
        ValueError: If the extractor returns no features, or a different number
            of feature maps for ``real`` and ``fake``.
    """
    real_features = feature_extractor(real)
    fake_features = feature_extractor(fake)

    # Normalise the single-tensor case to a one-element list.
    if torch.is_tensor(real_features):
        real_features = [real_features]
    if torch.is_tensor(fake_features):
        fake_features = [fake_features]
    real_features = list(real_features)
    fake_features = list(fake_features)

    if len(real_features) != len(fake_features):
        raise ValueError(
            f"feature_extractor returned {len(real_features)} feature maps for real "
            f"but {len(fake_features)} for fake"
        )
    if not real_features:
        raise ValueError("feature_extractor returned no features")

    layer_losses = [
        torch.mean(torch.abs(real_map - fake_map))
        for real_map, fake_map in zip(real_features, fake_features)
    ]
    return torch.stack(layer_losses).mean()

# Instantiate the optimizers.
# The generator optimizer must only own the encoder and decoder: with model.parameters() it
# would also update the discriminator when minimising the generator loss.
# Both optimizers are bound to the parameters of the `model` object that exists when this
# cell runs; if `model` is re-created later, they have to be re-created as well.
generator_optimizer = optim.Adam(
    list(model.encoder.parameters()) + list(model.decoder.parameters()),
    lr=0.0002,
    betas=(0.5, 0.999),
)
discriminator_optimizer = optim.Adam(model.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))


## ACCURACY, VALIDATION AND TRAINING LOOP

In [None]:
import torch
import torch.nn.functional as F

def compute_accuracy(predictions, targets, color_ranges=None):
    """Pixel accuracy of predictions against one-hot targets.

    Two prediction formats are supported:

    * Per-class scores with the same channel count as ``targets``: the predicted
      class is the argmax over channels.
    * 3-channel RGB images while ``targets`` has a different channel count: each
      pixel is assigned the class whose palette colour (first tuple of its
      ``color_ranges`` entry, divided by 255 — the colour ``one_hot_to_rgb``
      renders with) is nearest in squared RGB distance. Taking the argmax over
      the R/G/B channels, as was done before, compares channel indices with
      class indices and yields a meaningless number.

    Any other channel combination falls back to the plain argmax comparison.

    Args:
        predictions: Tensor of shape (B, C, H, W).
        targets: One-hot tensor of shape (B, num_classes, H', W').
        color_ranges: Palette used for RGB predictions; defaults to the
            module-level ``idx_to_color``.

    Returns:
        0-dim tensor with the fraction of pixels whose predicted class equals
        the target class.
    """
    # Ensure predictions and targets have the same spatial dimensions
    if predictions.shape[2:] != targets.shape[2:]:
        predictions = F.interpolate(predictions, size=targets.shape[2:], mode="bilinear", align_corners=False)

    target_classes = torch.argmax(targets, dim=1)  # Convert one-hot to class index

    is_rgb_prediction = predictions.shape[1] == 3 and targets.shape[1] != 3
    if is_rgb_prediction:
        if color_ranges is None:
            color_ranges = idx_to_color
        rgb = predictions.detach()
        palette = torch.tensor(
            [color_range[0] for color_range in color_ranges],
            dtype=rgb.dtype,
            device=rgb.device,
        ) / 255.0

        # Running nearest-colour search: one (B, H, W) distance map at a time keeps
        # memory low compared with materialising all class distances at once.
        best_distance = torch.full_like(rgb[:, 0], float("inf"))
        predicted_classes = torch.zeros_like(best_distance, dtype=torch.long)
        for class_idx, color in enumerate(palette):
            distance = ((rgb - color.view(1, 3, 1, 1)) ** 2).sum(dim=1)
            closer = distance < best_distance
            predicted_classes[closer] = class_idx
            best_distance = torch.where(closer, distance, best_distance)
    else:
        predicted_classes = torch.argmax(predictions, dim=1)  # Class with the highest score

    # Compute accuracy
    correct = (predicted_classes == target_classes).float()
    return correct.sum() / correct.numel()

def validate(model, val_dataloader, device):
    """Evaluate the model on the validation set.

    For every batch the generated images are resized to the mask resolution and
    compared with the target by mean squared error. When the generator's channel
    count differs from the one-hot mask (e.g. 3 RGB channels vs. 11 classes) the
    mask is first rendered to RGB with ``one_hot_to_rgb``; comparing the raw
    tensors would fail because their channel dimensions cannot be broadcast.

    Args:
        model: Model whose forward returns the generated images, or a tuple
            whose first element is the generated images.
        val_dataloader: Iterable of ``(images, one_hot_masks)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over batches; ``(0.0, 0.0)`` when
        the dataloader yields no batches.
    """
    was_training = model.training
    model.eval()
    val_loss = 0.0
    val_accuracy = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for images, masks in val_dataloader:
                images, masks = images.to(device), masks.to(device)

                # Generate fake images
                outputs = model(images)
                fake_images = outputs[0] if isinstance(outputs, tuple) else outputs

                # Resize fake images to match masks' spatial dimensions
                if fake_images.shape[2:] != masks.shape[2:]:
                    fake_images = F.interpolate(fake_images, size=masks.shape[2:], mode="bilinear", align_corners=False)

                # Pick a target with the same channel layout as the generator output
                if fake_images.shape[1] == masks.shape[1]:
                    target = masks
                else:
                    target = one_hot_to_rgb(masks, idx_to_color)

                # Calculate MSE loss
                val_loss += F.mse_loss(fake_images, target).item()

                # Calculate accuracy
                val_accuracy += compute_accuracy(fake_images, masks).item()
                num_batches += 1
    finally:
        # Leave the model in the mode the caller had it in (matters when called mid-training).
        model.train(was_training)

    if num_batches == 0:
        return 0.0, 0.0

    return val_loss / num_batches, val_accuracy / num_batches

def train(model, train_dataloader, val_dataloader, num_epochs=100,
          generator_optimizer=None, discriminator_optimizer=None, device=None):
    """Adversarial training loop for the generator (encoder + decoder) and discriminator.

    Args:
        model: Model exposing ``encoder``, ``decoder`` and ``discriminator``
            sub-modules and whose forward returns
            ``(generated_images, discriminator_output)``.
        train_dataloader: Iterable of ``(images, one_hot_masks)`` batches.
        val_dataloader: Validation loader; currently unused because the
            validation call below is commented out.
        num_epochs: Number of passes over ``train_dataloader``.
        generator_optimizer: Optimizer for the encoder/decoder parameters. When
            omitted, Adam (lr=0.0002, betas=(0.5, 0.999)) is created over the
            encoder and decoder of *this* ``model``.
        discriminator_optimizer: Optimizer for the discriminator parameters;
            created the same way when omitted.
        device: Device the batches are moved to; defaults to the device of the
            model's parameters.

    The optimizers used to be read from module-level globals. Those are bound to
    whichever model object existed when they were created, so after re-creating
    ``model`` the loop stepped optimizers that did not own the trained model's
    parameters and nothing was updated. The global generator optimizer also
    covered the discriminator's parameters.
    """
    if device is None:
        device = next(model.parameters()).device
    if generator_optimizer is None:
        generator_params = list(model.encoder.parameters()) + list(model.decoder.parameters())
        generator_optimizer = torch.optim.Adam(generator_params, lr=0.0002, betas=(0.5, 0.999))
    if discriminator_optimizer is None:
        discriminator_optimizer = torch.optim.Adam(
            model.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)
        )

    # Avoid a ZeroDivisionError when the loader yields no batches.
    num_batches = max(len(train_dataloader), 1)

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch in case an evaluation switched it off.
        model.train()
        train_loss = 0.0
        train_accuracy = 0.0
        for images, masks in train_dataloader:
            images, masks = images.to(device), masks.to(device)

            # Convert one-hot mask to RGB format for the discriminator
            rgb_masks = one_hot_to_rgb(masks, idx_to_color)

            # 1. Train the discriminator (generator output carries no gradient here)
            discriminator_optimizer.zero_grad()
            with torch.no_grad():
                fake_images, _ = model(images)
            real_pred = model.discriminator(rgb_masks)  # Use RGB masks
            fake_pred = model.discriminator(fake_images)
            d_loss = hinge_adversarial_loss(real_pred, fake_pred)
            d_loss.backward()
            discriminator_optimizer.step()

            # 2. Train the generator; the forward pass already returns the discriminator's
            #    output on the generated images, so it is not evaluated a second time.
            generator_optimizer.zero_grad()
            fake_images, fake_pred = model(images)
            g_loss_adv = -torch.mean(fake_pred)
            g_loss_fm = feature_matching_loss(masks, fake_images)
            g_loss = g_loss_adv + 10 * g_loss_fm  # Adjust weights based on experiments
            g_loss.backward()
            generator_optimizer.step()

            # Update training loss and accuracy
            train_loss += g_loss.item()
            train_accuracy += compute_accuracy(fake_images.detach(), masks).item()

        # Average training loss and accuracy for the epoch
        train_loss /= num_batches
        train_accuracy /= num_batches

        # Validate the model
        # val_loss, val_accuracy = validate(model, val_dataloader, device)

        # Print metrics for this epoch
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        # print(f"  Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")


# TRAINING

In [None]:
from torchvision import transforms

# Define transformations for images (applied to the input image only, not to the mask)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Load training dataset
train_dataset = CityscapesDataset(image_dir='add_path_here', mask_dir='add_path_here', transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=46, shuffle=True)

# Load validation dataset
val_dataset = CityscapesDataset(image_dir='add_path_here', mask_dir='add_path_here', transform=transform)
val_dataloader = DataLoader(val_dataset, batch_size=46, shuffle=False)

# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SC_UNet(in_channels=3, embed_dim=96, out_channels=3).to(device)

# `model` is a new object here, so optimizers created earlier still hold the parameters of
# the previous instance and would never update this one. Re-create them for the new model;
# the generator optimizer owns only the encoder and decoder.
generator_optimizer = torch.optim.Adam(
    list(model.encoder.parameters()) + list(model.decoder.parameters()),
    lr=0.0002,
    betas=(0.5, 0.999),
)
discriminator_optimizer = torch.optim.Adam(model.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
device

device(type='cuda')

In [None]:
# Start training for 2 epochs. The validation call inside train() is commented out,
# so only training metrics are printed and val_dataloader is not used.
train(model, train_dataloader, val_dataloader, num_epochs=2)

Epoch [1/2]
  Train Loss: 6.4544, Train Accuracy: 0.7044
Epoch [2/2]
  Train Loss: 6.4549, Train Accuracy: 0.7045


In [None]:
model_save_path = "add_path_here"
# The pickled full model gets its own file: writing it to the same path as the state_dict
# would overwrite the checkpoint that load_state_dict() expects.
full_model_save_path = model_save_path + ".full"

# Weights only (load with model.load_state_dict(torch.load(...)))
torch.save(model.state_dict(), model_save_path)

# Whole pickled module (requires the class definitions to be importable when loading)
torch.save(model, full_model_save_path)



# USING THE MODEL

In [None]:
import torch

# NOTE(review): the training cell builds SC_UNet with in_channels=3, while this cell uses
# in_channels=11. load_state_dict() raises a size-mismatch error unless "sc_unet_model.pth"
# was produced by an 11-channel model — confirm which checkpoint is meant here.
model = SC_UNet(in_channels=11, embed_dim=96, out_channels=3).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime as well.
model.load_state_dict(torch.load("sc_unet_model.pth", map_location=device))
model.eval()


In [None]:
from PIL import Image
import numpy as np

def preprocess_mask(mask_path, idx_to_color, num_classes=11):
    """Load an RGB mask from disk and convert it to a batched one-hot tensor.

    A pixel belongs to a class when all three of its channels lie inside that
    class's inclusive ``(min_rgb, max_rgb)`` range. Matching only against the
    first tuple of each entry, as was done before, compares pixels with the lower
    bound of the range (e.g. (120, 120, 120)) rather than with the colours that
    occur in the masks, so virtually no pixel matched and the result was all zeros.

    Note: this definition replaces the earlier ``preprocess_mask(mask, num_classes)``
    helper of the same name once this cell has run.

    Args:
        mask_path: Path of the RGB mask image.
        idx_to_color: Sequence of ``(min_rgb, max_rgb)`` pairs indexed by class id.
        num_classes: Number of channels of the result; entries of ``idx_to_color``
            beyond this count are ignored.

    Returns:
        float32 tensor of shape (1, num_classes, H, W) on the global ``device``.
        Pixels that match no range are zero in every channel; if ranges overlap,
        the class that appears later wins.
    """
    # Load mask and convert to numpy array
    mask = Image.open(mask_path).convert("RGB")
    mask_np = np.array(mask)

    # Initialize one-hot encoded mask with shape (num_classes, H, W)
    one_hot_mask = np.zeros((num_classes, mask_np.shape[0], mask_np.shape[1]), dtype=np.float32)

    # Populate one-hot encoded mask
    for class_idx, (min_color, max_color) in enumerate(idx_to_color[:num_classes]):
        match_pixels = np.all((mask_np >= min_color) & (mask_np <= max_color), axis=-1)
        # Clear any earlier assignment first so each pixel stays strictly one-hot.
        one_hot_mask[:, match_pixels] = 0
        one_hot_mask[class_idx][match_pixels] = 1

    # Convert to torch tensor and add a batch dimension
    one_hot_mask = torch.tensor(one_hot_mask).unsqueeze(0).to(device)
    return one_hot_mask

# Example usage: replace the placeholder path with a real RGB mask file.
one_hot_mask = preprocess_mask("path_to_mask_image.png", idx_to_color)


In [None]:
with torch.no_grad():
    # SC_UNet.forward returns (generated_batch, discriminator_output); keep the images only.
    model_output = model(one_hot_mask)
    generated_batch = model_output[0] if isinstance(model_output, tuple) else model_output
    # ToPILImage needs a single (C, H, W) image, so take the first sample of the batch.
    # Indexing the tuple with [0] alone yields the whole 4-D batch, which it rejects.
    generated_image = (generated_batch[0] * 255).clamp(0, 255).byte().cpu()
    generated_image_pil = transforms.ToPILImage()(generated_image)

generated_image_pil.show()