In [None]:
import os
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw
import json
import numpy as np
import torch.nn as nn
import timm
from tqdm import tqdm


class IDD20KLLDataset(Dataset):
    """Segmentation dataset built from IDD ``gtFine`` polygon annotations.

    Every ``*_gtFine_polygons.json`` file found recursively under ``mask_root``
    is paired with ``<base>_leftImg8bit.jpg`` located in the same relative
    sub-directory under ``image_root``. Annotations whose image is missing are
    skipped with a warning.

    Args:
        image_root: Root directory of the RGB images.
        mask_root: Root directory of the ``*_gtFine_polygons.json`` files.
        transforms: Optional callable applied to the PIL image. When set, the
            mask is also resized to ``mask_size`` and returned as a LongTensor;
            otherwise both image and mask are returned as PIL images.
        mask_size: Size passed to ``torchvision.transforms.Resize`` for the mask
            when ``transforms`` is set; it should match the spatial size that
            ``transforms`` produces. Defaults to the previously hard-coded
            ``(224, 224)``.
        label_to_id: Optional mapping from a polygon's ``"label"`` string to a
            class index in 0-255. When ``None`` a binary mask is produced
            (0 = not covered by any polygon, 1 = covered).
    """

    _SUFFIX = "_gtFine_polygons.json"

    def __init__(self, image_root, mask_root, transforms=None, mask_size=(224, 224), label_to_id=None):
        self.image_paths = []
        self.mask_paths = []
        self.transforms = transforms
        self.mask_size = mask_size
        self.label_to_id = label_to_id

        # Collect all image and mask paths
        for subdir, dirs, files in tqdm(os.walk(mask_root)):
            # Sort in place so traversal order (and therefore dataset indices)
            # does not depend on the filesystem's directory-listing order.
            dirs.sort()
            for file in sorted(files):
                if not file.endswith(self._SUFFIX):
                    continue
                base_name = file[: -len(self._SUFFIX)]
                mask_path = Path(subdir) / file
                image_path = Path(image_root) / Path(subdir).relative_to(mask_root) / f"{base_name}_leftImg8bit.jpg"

                if image_path.exists():
                    self.mask_paths.append(mask_path)
                    self.image_paths.append(image_path)
                else:
                    print(f"Warning: Image not found for mask: {mask_path}")

        print(f"Found {len(self.image_paths)} images")
        print(f"Found {len(self.mask_paths)} masks")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        With ``transforms`` set, ``image`` is whatever ``transforms`` returns and
        ``mask`` is a ``torch.long`` tensor of class indices resized to
        ``mask_size``. Without ``transforms`` both are PIL images.
        """
        # Context manager releases the file handle once the pixels are decoded.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")

        mask = self.create_class_mask(self.mask_paths[idx], image.size, self.label_to_id)

        if self.transforms:
            image = self.transforms(image)
            # NEAREST keeps every pixel a valid class index; the default bilinear
            # filter would blend neighbouring labels at object boundaries.
            resize_mask = transforms.Resize(self.mask_size, interpolation=transforms.InterpolationMode.NEAREST)
            mask = torch.from_numpy(np.array(resize_mask(mask), dtype=np.int64))

        return image, mask

    @staticmethod
    def create_class_mask(json_path, image_size, label_to_id=None):
        """Rasterise the polygons of a gtFine JSON file into a single-channel mask.

        Args:
            json_path: Path to a ``*_gtFine_polygons.json`` file.
            image_size: ``(width, height)`` of the target image, as returned by
                ``PIL.Image.size``.
            label_to_id: Optional ``{label: class_index}`` mapping (indices must
                fit in 0-255). Polygons whose label is not in the mapping are
                left as background (0). When ``None`` every polygon is filled
                with 1, giving a binary mask.

        Returns:
            A PIL ``"L"`` image of size ``image_size``. Polygons are drawn in file
            order onto one canvas, so where they overlap the later polygon wins
            (values are never summed, so they cannot overflow uint8).
        """
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        img_width, img_height = image_size
        mask = Image.new("L", (img_width, img_height), 0)  # "L" = single-channel 8-bit
        draw = ImageDraw.Draw(mask)

        for obj in data["objects"]:
            if obj.get("deleted"):
                continue
            polygon = [(point[0], point[1]) for point in obj["polygon"]]
            # Skip polygons with fewer than 2 points
            if len(polygon) < 2:
                print(f"Warning: Skipping invalid polygon with {len(polygon)} points in {json_path}")
                continue
            if label_to_id is None:
                class_id = 1
            else:
                class_id = label_to_id.get(obj.get("label"))
                if class_id is None:
                    continue
            draw.polygon(polygon, outline=class_id, fill=class_id)

        return mask


In [None]:

# Dataset location: set the IDD_ROOT environment variable to point elsewhere;
# the default is the original author's local Windows path.
IDD_ROOT = os.environ.get("IDD_ROOT", "E:\\Projects\\Finished\\Semantic Segmentation\\idd20kII")
mask_root = os.path.join(IDD_ROOT, "gtFine")
image_root = os.path.join(IDD_ROOT, "leftImg8bit")
train_image_root = os.path.join(image_root, "train")
train_mask_root = os.path.join(mask_root, "train")

# (height, width) fed to the backbone. Masks are resized to 224x224 inside
# IDD20KLLDataset, so keep this value in sync with it.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 8
SEED = 42

# Seed torch's global RNG so the shuffled batch order and any randomly
# initialised layers are reproducible across Restart & Run All.
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

train_dataset = IDD20KLLDataset(train_image_root, train_mask_root, transforms=transform)

# 34 classes in the Indian Driving Dataset (per the original author).
# NOTE(review): the masks built by IDD20KLLDataset are not derived from the
# polygon labels here, so most of these classes never appear as targets —
# confirm the intended label -> class-id mapping.
num_labels = 34

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

class FCNSegmentationModel(nn.Module):
    """Minimal FCN-style segmentation network on a timm backbone.

    The deepest backbone feature map is projected to per-class logits with a
    1x1 convolution, and the logits are bilinearly upsampled to the input's
    spatial size.

    Projecting *before* upsampling is mathematically identical to the reverse
    order (a 1x1 conv is linear and bilinear weights sum to 1). However, it
    interpolates ``num_classes`` channels instead of the backbone's 512, which
    cuts activation memory and compute in that stage by roughly 512/num_classes.
    The state_dict keys are unchanged, so old checkpoints still load.

    Args:
        num_classes: Number of output logit channels.
        backbone_name: timm model name used as the encoder (default ``"resnet34"``).
        pretrained: Whether timm loads pretrained backbone weights (default ``True``).
    """

    def __init__(self, num_classes, backbone_name="resnet34", pretrained=True):
        super().__init__()
        # features_only=True: the backbone returns a list of intermediate feature maps.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, features_only=True)
        # Channel count of the deepest stage (512 for resnet34) instead of a hard-coded constant.
        in_channels = self.backbone.feature_info.channels()[-1]
        self.segmentation_head = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes, H, W)`` for an NCHW input ``x``."""
        features = self.backbone(x)[-1]  # deepest feature map
        logits = self.segmentation_head(features)
        # Upsample to the exact input size, so inputs whose sides are not
        # multiples of the backbone stride still give logits aligned with the mask.
        return nn.functional.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=True)

model = FCNSegmentationModel(num_classes=num_labels)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 10

# One pass over train_loader per epoch; the mean per-batch loss is printed and
# a checkpoint of the weights is written after every epoch.
for epoch_no in range(1, num_epochs + 1):
    model.train()
    loss_total = 0.0
    for batch_images, batch_masks in tqdm(train_loader, desc=f"Epoch {epoch_no}/{num_epochs}"):
        batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch_images), batch_masks)
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()

    mean_loss = loss_total / len(train_loader)
    print(f"Epoch {epoch_no}/{num_epochs}, Loss: {mean_loss:.4f}")

    torch.save(model.state_dict(), f"resnet_segmentation_epoch_{epoch_no}.pth")


250it [00:00, 461.93it/s]


Found 7034 images
Found 7034 masks


Epoch 1/10:   0%|          | 0/880 [00:00<?, ?it/s]

Features shape before upsampling: torch.Size([8, 512, 7, 7])
Features shape after upsampling: torch.Size([8, 512, 224, 224])


Epoch 1/10:   0%|          | 1/880 [00:07<1:48:04,  7.38s/it]

Features shape before upsampling: torch.Size([8, 512, 7, 7])
Features shape after upsampling: torch.Size([8, 512, 224, 224])


Epoch 1/10:   0%|          | 2/880 [00:17<2:09:09,  8.83s/it]


KeyboardInterrupt: 