In [None]:
from PIL import Image
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
import shutil
import random

import matplotlib.pyplot as plt
import os 

import torch 
import torch.nn as nn

import time

from torchvision import transforms
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split


In [None]:
!wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_semantics.zip

In [None]:
!unzip /kaggle/working/data_semantics.zip

In [None]:
# !pip install ultralytics torch opencv-python numpy
!pip install tqdm

In [None]:
os.listdir("/kaggle/working/training/semantic_rgb")[:10]

# EDA of the images and the labels

In [None]:
img = Image.open("/kaggle/working/training/image_2/000115_10.png")

In [None]:
plt.figure(figsize=(15, 30))
plt.imshow(img)
plt.axis("off")
plt.show()

## Labels

In [None]:
label_img = Image.open("/kaggle/working/training/semantic_rgb/000115_10.png")

In [None]:
label_img_numpy = np.array(label_img)

In [None]:
# Reshape to (num_pixels, 3)
pixels = label_img_numpy.reshape(-1, 3)

# Get unique RGB triplets
unique_colors = np.unique(pixels, axis=0)

print("Unique colors (labels):", unique_colors)
print("Number of labels:", len(unique_colors))


In [None]:
print(f"There are {len(unique_colors)} unique labels in this specific image")

### This specific image contains 22 labels!

In [None]:
plt.figure(figsize=(15, 30))
plt.imshow(label_img_numpy)
plt.axis("off")
plt.show()

# Get the number of all labels 

In [None]:


# # Path to semantic RGB masks
# mask_dir = Path("/kaggle/working/training/semantic_rgb")

# # Extract unique colors
# colors = set()
# for mask_path in tqdm(mask_dir.glob("*.png"), desc="Scanning masks"):
#     mask = cv2.imread(str(mask_path))[:, :, ::-1]  # BGR → RGB
#     unique_colors = np.unique(mask.reshape(-1, 3), axis=0)
#     colors.update(map(tuple, unique_colors))

# # Sort colors and show first 20 for reference
# colors = sorted(colors)
# print(f"✅ Found {len(colors)} unique colors.")
# print(colors)

# RGB palette of the semantic masks. The position of a color in this list is
# the integer class index used everywhere else in the notebook.
colors = [
    (0, 0, 0), (0, 0, 70), (0, 0, 90), (0, 0, 110), (0, 0, 142), (0, 0, 230),
    (0, 60, 100), (0, 80, 100), (70, 70, 70), (70, 130, 180), (81, 0, 81),
    (102, 102, 156), (107, 142, 35), (111, 74, 0), (119, 11, 32), (128, 64, 128),
    (150, 100, 100), (150, 120, 90), (152, 251, 152), (153, 153, 153),
    (180, 165, 180), (190, 153, 153), (220, 20, 60), (220, 220, 0),
    (230, 150, 140), (244, 35, 232), (250, 170, 30), (250, 170, 160),
    (255, 0, 0),
]

# Persist the mapping as one "<class index>: <(R, G, B)>" line per class.
color_map_file = Path("/kaggle/working/color_map.txt")
color_map_lines = [f"{i}: {c}\n" for i, c in enumerate(colors)]
with open(color_map_file, "w") as f:
    f.writelines(color_map_lines)

print(f"Color map saved to {color_map_file}")


# Define the UNet

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages.

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels produced by both convolutions.
        kernel_size: Side length of the square convolution kernel. The
            padding is ``(kernel_size - 1) // 2``, which keeps the spatial
            size unchanged for odd kernel sizes.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        pad = (kernel_size - 1) // 2

        # The first stage maps in_ch -> out_ch, the second out_ch -> out_ch.
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size, padding=pad),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv stages on ``x`` and return the resulting feature map."""
        return self.net(x)

In [None]:
class UNet(nn.Module):
    """Small U-Net with two down-sampling and two up-sampling stages.

    The encoder widens 3 -> 64 -> 128 -> 256 channels with a 2x max-pool
    between stages. The decoder upsamples bilinearly, concatenates the
    matching encoder feature map (skip connection) and narrows back to 64
    channels. A final 1x1 convolution produces ``n_classes`` logit maps; no
    activation is applied to the output.

    Args:
        n_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, n_classes=1):
        super().__init__()
        # Encoder
        self.d1 = DoubleConv(3, 64)
        self.d2 = DoubleConv(64, 128)
        self.d3 = DoubleConv(128, 256)
        # Decoder: input channels = upsampled features + skip connection.
        self.u1 = DoubleConv(256+128, 128)
        self.u2 = DoubleConv(128+64, 64)
        self.out_conv = nn.Conv2d(64, n_classes, 1)
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor = 2, mode = "bilinear", align_corners=True)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for conv weights, zero biases, identity BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _upsample_like(self, x, skip):
        """Upsample ``x`` to the spatial size of ``skip``.

        ``MaxPool2d(2)`` floors odd sizes, so a plain 2x upsample can end up
        one pixel short of the skip tensor and make ``torch.cat`` fail. When
        an exact 2x upsample matches, the original ``self.up`` layer is used;
        otherwise ``x`` is interpolated directly to the skip tensor's size.
        """
        target_hw = tuple(skip.shape[-2:])
        if (2 * x.shape[-2], 2 * x.shape[-1]) == target_hw:
            return self.up(x)
        return nn.functional.interpolate(
            x, size=target_hw, mode="bilinear", align_corners=True
        )

    def forward(self, x):
        """Return per-pixel class logits with the same height/width as ``x``."""
        x1 = self.d1(x)
        x2 = self.d2(self.pool(x1))
        x3 = self.d3(self.pool(x2))

        x = self._upsample_like(x3, x2)
        x = self.u1(torch.cat([x, x2], dim=1))
        x = self._upsample_like(x, x1)
        x = self.u2(torch.cat([x, x1], dim=1))
        return self.out_conv(x)

In [None]:
# import torch
# from torch.utils.data import Dataset
# from PIL import Image
# import numpy as np

# # 1️⃣ Load the color map from your file
# color_map_file = "/kaggle/working/color_map.txt"
# color_to_class = {}

# with open(color_map_file, "r") as f:
#     for line in f:
#         line = line.strip()
#         if line:
#             cls, rgb = line.split(":")
#             cls = int(cls.strip())
#             rgb = rgb.strip()[1:-1]  # remove parentheses
#             r, g, b = [int(x) for x in rgb.split(",")]
#             color_to_class[(r, g, b)] = cls

# # 2️⃣ Dataset class
# class SegmentationDataset(Dataset):
#     def __init__(self, image_paths, label_paths, transform=None):
#         self.image_paths = image_paths
#         self.label_paths = label_paths
#         self.transform = transform

#     def __len__(self):
#         return len(self.image_paths)

#     def __getitem__(self, idx):
#         # Load image
#         img = Image.open(self.image_paths[idx]).convert("RGB")
#         # img = np.array(img)
#         # Load RGB mask
#         label_rgb = Image.open(self.label_paths[idx]).convert("RGB")
#         # label_rgb = np.array(label_rgb)

#         if self.transform:
#             img = self.transform[0](img)
#             label_rgb = self.transform[1](label_rgb)
#         # Convert RGB mask to class indices
#         label = np.zeros((label_rgb.shape[0], label_rgb.shape[1]), dtype=np.int64)
#         for color, cls_idx in color_to_class.items():
#             mask = np.all(label_rgb == color, axis=-1)
#             label[mask] = cls_idx

#         # label = torch.from_numpy(label)
#         # img = torch.from_numpy(img)
#         return img, label


In [None]:
def rgb_to_class(mask, colors):
    """Convert an RGB label image into a 2-D map of class indices.

    Args:
        mask: Array-like image of shape (H, W, 3), e.g. a PIL RGB image.
        colors: Sequence of RGB triplets; the position of a color in the
            sequence is its class index.

    Returns:
        ``np.int64`` array of shape (H, W). Pixels whose color is not in
        ``colors`` keep the initial value 0.
    """
    pixels = np.array(mask)
    class_map = np.zeros(pixels.shape[:2], dtype=np.int64)
    for class_id, rgb in enumerate(colors):
        # A pixel belongs to the class only if all three channels match.
        is_class = (pixels == rgb).all(axis=-1)
        class_map[is_class] = class_id
    return class_map

In [None]:
class SegmentationDataset(torch.utils.data.Dataset):
    """Pairs of RGB images and class-index masks loaded from file paths.

    Args:
        image_paths: Sequence of paths to the input images.
        mask_paths: Sequence of paths to the RGB-encoded label images, in the
            same order as ``image_paths``.
        colors: Sequence of RGB triplets; the index of a color is its class id
            (passed to ``rgb_to_class``).
        img_transform: Optional callable applied to the PIL image.
        mask_transform: Optional callable applied to the PIL mask. It must
            return something ``np.array`` can turn into an (H, W, 3) array.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.

    Note:
        The two transforms are applied independently, so random spatial
        augmentations would not stay aligned between image and mask.
    """

    def __init__(self, image_paths, mask_paths, colors, img_transform=None, mask_transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Got {len(image_paths)} images but {len(mask_paths)} masks; "
                "every image needs exactly one mask."
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.colors = colors
        self.img_transform = img_transform
        self.mask_transform = mask_transform

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``mask`` is a ``torch.long`` tensor of shape (H, W) holding class
        indices; ``image`` is whatever ``img_transform`` returns (the PIL
        image itself when no transform is given).
        """
        img = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("RGB")  # colors are decoded below

        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        mask = rgb_to_class(mask, self.colors)  # (H, W) int64 array
        mask = torch.from_numpy(mask).long()    # (H, W) class indices, no channel dim
        return img, mask

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_paths)


In [None]:
# Images and masks are resized to the same fixed (height, width); both
# dimensions are divisible by 4, matching the two 2x poolings of the UNet.
resize_hw = (368, 1224)

image_transformation = transforms.Compose([
    transforms.Resize(resize_hw, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
])

# NEAREST keeps the exact palette colors (no blended in-between colors).
# No ToTensor here: the mask stays a PIL image so rgb_to_class can compare
# raw RGB values against the palette.
label_transformation = transforms.Compose([
    transforms.Resize(resize_hw, interpolation=transforms.InterpolationMode.NEAREST),
])

In [None]:
# Build aligned lists of image and mask paths. Image and mask are matched by
# file name, so both directories must contain exactly the same names;
# otherwise positional pairing would silently combine the wrong files.
image_dir = "/kaggle/working/training/image_2"
label_dir = "/kaggle/working/training/semantic_rgb"

img_names = sorted(os.listdir(image_dir))
label_names = sorted(os.listdir(label_dir))

if img_names != label_names:
    unmatched = sorted(set(img_names) ^ set(label_names))
    raise ValueError(
        f"Image and mask directories do not contain the same files; "
        f"{len(unmatched)} unmatched name(s), e.g. {unmatched[:5]}"
    )

img_files = [os.path.join(image_dir, name) for name in img_names]
label_files = [os.path.join(label_dir, name) for name in label_names]


In [None]:

# Hold out 20% of the image/mask pairs. Both lists are passed in one call so
# they are shuffled with the same permutation, and the fixed random_state
# makes the split reproducible.
split = train_test_split(img_files, label_files, test_size=0.2, random_state=42)
train_imgs, test_imgs, train_labels, test_labels = split

print(f"Train size: {len(train_imgs)}, Test size: {len(test_imgs)}")

In [None]:
# Training data: reshuffled every epoch.
train_sem_seg_dataset = SegmentationDataset(train_imgs, train_labels, colors, image_transformation, label_transformation)
train_loader = DataLoader(train_sem_seg_dataset, batch_size=2, shuffle=True)

# Held-out data: NOT shuffled. The Dice loss is reduced over the whole batch,
# so its value depends on which samples share a batch; a fixed order keeps the
# validation loss comparable across epochs for best-model selection.
test_sem_seg_dataset = SegmentationDataset(test_imgs, test_labels, colors, image_transformation, label_transformation)
val_loader = DataLoader(test_sem_seg_dataset, batch_size=2, shuffle=False)

In [None]:
# Sanity check: print type, shape and present class ids for the first
# (at most 11) training samples.
for i in range(min(11, len(train_sem_seg_dataset))):
    img, label = train_sem_seg_dataset[i]

    print("Image")
    print(type(img))
    print(img.shape)

    print("label")
    print(type(label))
    print(label.shape)
    print(torch.unique(label))
    

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiClassDiceLoss(nn.Module):
    """Soft Dice loss averaged over classes.

    Args:
        num_classes: Number of classes C (size of the logit channel axis).
        smooth: Constant added to numerator and denominator of the Dice ratio.
    """

    def __init__(self, num_classes, smooth=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.smooth = smooth

    def forward(self, pred, target):
        """Compute ``1 - mean_c(dice_c)``.

        Args:
            pred: Raw logits of shape [N, C, H, W].
            target: Integer class labels in ``0..C-1`` of shape [N, H, W].

        Returns:
            Scalar tensor with the loss.
        """
        # Reduce over batch and both spatial axes -> one value per class.
        reduce_dims = (0, 2, 3)

        probs = F.softmax(pred, dim=1)
        onehot = (
            F.one_hot(target.long(), num_classes=self.num_classes)  # [N, H, W, C]
            .permute(0, 3, 1, 2)                                     # [N, C, H, W]
            .float()
        )

        overlap = torch.sum(probs * onehot, dim=reduce_dims)
        # Sum of predicted and true mass per class (the Dice denominator).
        total_mass = torch.sum(probs, dim=reduce_dims) + torch.sum(onehot, dim=reduce_dims)

        dice_per_class = (2. * overlap + self.smooth) / (total_mass + self.smooth)
        return 1.0 - dice_per_class.mean()


In [None]:
# class DiceLoss(nn.Module):
#     def __init__(self):
#         super().__init__()

#     def forward(self, pred, target):
#         pred = torch.sigmoid(pred)
#         smooth = 1.0
#         # print(f"pred shape {pred.shape}, target shape {target.shape}")
#         intersection = (pred * target).sum()
#         dice = (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
#         return 1.0 - dice


In [None]:
# Multi-class Dice loss. The class count is taken from the palette so it
# cannot drift from the label encoding.
criterion = MultiClassDiceLoss(len(colors))

In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(images, masks)`` batches that supports ``len``.
        optimizer: Optimizer updating the parameters of ``model``.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar loss.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses.
    """
    model.train()
    batch_losses = []
    for imgs, masks in loader:
        imgs = imgs.to(device)
        masks = masks.to(device).long()  # class indices must be int64

        optimizer.zero_grad()
        loss = criterion(model(imgs), masks)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(loader)

In [None]:
def validate(model, loader, criterion, device):
    """Compute the mean per-batch loss over ``loader`` without updating the model.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(images, masks)`` batches that supports ``len``.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar loss.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses.
    """
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for imgs, masks in loader:
            # Cast to int64 exactly as train_epoch does, so index-based losses
            # (e.g. CrossEntropyLoss) accept masks of any integer dtype.
            imgs, masks = imgs.to(device), masks.to(device).long()
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            total_loss += loss.item()
    return total_loss / len(loader)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# One output channel per palette color, so the model always agrees with the
# label encoding.
model = UNet(n_classes=len(colors))
model.to(device)

In [None]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
epochs = 10
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
for imgs, masks in train_loader:
    print(imgs.shape, imgs.dtype)   # [N,3,368,1224], float32
    print(masks.shape, masks.dtype) # [N,368,1224], int64
    break

In [None]:
# Train for `epochs` epochs and checkpoint the weights whenever the
# validation loss reaches a new minimum.
best_val_loss = float('inf')
for epoch in range(epochs):
    epoch_start = time.time()

    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss = validate(model, val_loader, criterion, device)

    improved = val_loss < best_val_loss
    if improved:
        print("Best Result! Model will be saved")
        best_val_loss = val_loss
        torch.save(model.state_dict(), "model.pt")

    # The reported time covers training, validation and checkpointing.
    elapsed_time = time.time() - epoch_start
    print(f"Epoch {epoch+1:02d}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Time: {elapsed_time:.2f}s")
    

In [None]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Rebuild the network and restore the best checkpoint. map_location lets a
# checkpoint saved on GPU be loaded on a CPU-only runtime (and vice versa).
model = UNet(n_classes=len(colors)).to(device)
model.load_state_dict(torch.load("model.pt", map_location=device))

In [None]:
# Palette lookup table: row i is the RGB color of class i. It is built from
# `colors` (the list used to encode the masks) so that decoding can never
# drift from encoding.
color_map = np.array(colors, dtype=np.uint8)  # (num_classes, 3)

sample_idx = 0  # which element of the batch to display

model.eval()
with torch.no_grad():
    for imgs, masks in val_loader:
        logits = model(imgs.to(device))             # (B, num_classes, H, W)
        pred_classes = torch.argmax(logits, dim=1)  # (B, H, W)

        # Fancy indexing with the class map turns indices back into colors.
        img_np = imgs[sample_idx].cpu().numpy().transpose(1, 2, 0)   # (H, W, 3)
        mask_np = color_map[masks[sample_idx].cpu().numpy()]         # (H, W, 3)
        pred_np = color_map[pred_classes[sample_idx].cpu().numpy()]  # (H, W, 3)

        fig, axes = plt.subplots(3, 1, figsize=(9, 10))
        panels = [
            ("Input Image", img_np),
            ("Ground Truth", mask_np),
            ("Prediction", pred_np),
        ]
        for ax, (title, picture) in zip(axes, panels):
            ax.imshow(picture)
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        plt.show()
        break  # only the first batch is visualised
