In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import numpy as np

In [2]:
class CityscapesDataset(Dataset):
    """Paired Cityscapes images and ``labelIds`` masks for one split.

    The constructor walks the following layout and records one
    (image path, mask path) pair per image file::

        <root>/leftImg8bit/<split>/<city>/<name>_leftImg8bit.png
        <root>/gtFine/<split>/<city>/<name>_gtFine_labelIds.png

    Args:
        root: Directory containing the ``leftImg8bit`` and ``gtFine`` folders.
        split: Sub-folder name of the split to load (default ``'train'``).
        transform: Optional callable applied to the RGB ``PIL.Image``.
        target_transform: Optional callable applied to the mask ``PIL.Image``.

    Note:
        ``transform`` and ``target_transform`` are applied independently, so
        random spatial augmentations placed in them would not stay aligned
        between image and mask.

    Raises:
        FileNotFoundError: From ``os.listdir`` if the image split directory
            does not exist. Mask files are not checked here; a missing mask
            surfaces when the sample is loaded.
    """

    IMAGE_SUFFIX = '_leftImg8bit.png'
    MASK_SUFFIX = '_gtFine_labelIds.png'

    def __init__(self, root, split='train', transform=None, target_transform=None):
        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        self.images_dir = os.path.join(root, 'leftImg8bit', split)
        self.masks_dir = os.path.join(root, 'gtFine', split)

        self.images = []
        self.masks = []

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across machines and runs.
        for city in sorted(os.listdir(self.images_dir)):
            city_images_dir = os.path.join(self.images_dir, city)
            if not os.path.isdir(city_images_dir):
                # Stray files (README, .DS_Store, ...) are not city folders.
                continue
            city_masks_dir = os.path.join(self.masks_dir, city)

            for file_name in sorted(os.listdir(city_images_dir)):
                if not file_name.endswith(self.IMAGE_SUFFIX):
                    continue
                # Swap only the trailing suffix to derive the mask file name.
                stem = file_name[:-len(self.IMAGE_SUFFIX)]
                self.images.append(os.path.join(city_images_dir, file_name))
                self.masks.append(os.path.join(city_masks_dir, stem + self.MASK_SUFFIX))

    def __len__(self):
        """Return the number of (image, mask) pairs found."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair at position ``idx``.

        The image is converted to RGB; the mask keeps the mode stored in the
        file. Each is passed through its transform when one was supplied.
        """
        # Both convert() and copy() produce in-memory images, so the
        # underlying files can be closed as soon as the `with` blocks exit.
        with Image.open(self.images[idx]) as image_file:
            image = image_file.convert('RGB')
        with Image.open(self.masks[idx]) as mask_file:
            mask = mask_file.copy()

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)

        return image, mask


In [3]:
# Image preprocessing pipeline: resize to 256x256, convert to a tensor, then
# normalise each channel with the given per-channel mean / std.
preprocessing_steps = [
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = T.Compose(preprocessing_steps)

def target_transform(mask):
    """Turn a PIL label mask into a 256x256 ``torch.long`` tensor.

    Nearest-neighbour resampling is used so resizing never blends two label
    values into a new one.
    """
    resized_mask = mask.resize((256, 256), Image.NEAREST)
    label_array = np.array(resized_mask)
    return torch.from_numpy(label_array).long()


In [4]:
# Build the training dataset from the current working directory and wrap it
# in a shuffling loader that yields batches of 4 samples.
root_dir = os.getcwd()
train_dataset = CityscapesDataset(
    root=root_dir,
    split='train',
    transform=transform,
    target_transform=target_transform,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
)

In [5]:
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.optim as optim

# U-Net with a ResNet-34 encoder initialised from ImageNet weights.
# One output channel per class; 34 is the class count used for the
# Cityscapes masks in this notebook.
unet_config = dict(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=34,
)
model = smp.Unet(**unet_config)

In [6]:
# Pick CUDA when a GPU is visible to PyTorch, otherwise fall back to the CPU,
# and place the model on that device.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
model = model.to(device)
print(device)

cpu


In [7]:
# Per-pixel cross-entropy; target pixels equal to 255 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)
# Adam over every model parameter with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [8]:
# Training loop
num_epochs = 10

# Print options are global and do not change between batches, so they are
# set once here instead of on every iteration.
torch.set_printoptions(threshold=10000, edgeitems=1000, linewidth=1000)

model.train()
# Guard the per-epoch average against an empty loader (division by zero).
num_batches = max(len(train_loader), 1)
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        # Move the batch to the model's device; the loss needs integer targets.
        inputs, labels = inputs.to(device), labels.to(device).long()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean batch loss for this epoch.
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/num_batches}')


Epoch 1/10, Loss: 0.8305092011728594
Epoch 2/10, Loss: 0.6165137884597625
Epoch 3/10, Loss: 0.5504282839516158
Epoch 4/10, Loss: 0.5203260881125286
Epoch 5/10, Loss: 0.4882466739384077
Epoch 6/10, Loss: 0.4641659569035294
Epoch 7/10, Loss: 0.44401915393449287
Epoch 8/10, Loss: 0.428942819155993
Epoch 9/10, Loss: 0.4123560603988427
Epoch 10/10, Loss: 0.4018541305896736


In [9]:
def decode_segmap(image, colormap):
    """Convert a 2-D integer label map into an RGB image.

    Args:
        image: Integer array of shape ``(H, W)`` holding one class index per
            pixel.
        colormap: Palette with one ``(R, G, B)`` row per class index; either
            a NumPy array or a nested sequence.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``. Pixels whose label has no row
        in ``colormap`` are left black.
    """
    labels = np.asarray(image)
    palette = np.asarray(colormap).astype(np.uint8)

    rgb = np.zeros(labels.shape + (3,), dtype=np.uint8)
    # A single palette lookup replaces one full-image comparison per class.
    # The mask keeps labels outside the palette black instead of raising
    # IndexError or wrapping around on negative values.
    in_palette = (labels >= 0) & (labels < len(palette))
    rgb[in_palette] = palette[labels[in_palette], :3]
    return rgb

In [39]:
# Run the model on one training image.
# The path is built from root_dir (the dataset root used for training)
# instead of a machine-specific absolute path, so the cell works on any
# checkout that has the same dataset layout.
test_img_path = os.path.join(
    root_dir, 'leftImg8bit', 'train', 'aachen',
    'aachen_000000_000019_leftImg8bit.png',
)
test_img = Image.open(test_img_path).convert("RGB")
# Same preprocessing as training, plus a leading batch dimension of 1.
test_img = transform(test_img).unsqueeze(0).to(device)

# Switch to inference mode: after training the model is still in train()
# mode, where BatchNorm would normalise with (and update its running
# statistics from) this single-image batch.
model.eval()
with torch.no_grad():
    logits = model(test_img)
# Highest-scoring class per pixel, as an (H, W) NumPy array.
output = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()



In [12]:
def create_cityscapes_label_colormap(use_train_ids=False):
    """Build the Cityscapes colour palette as a ``(256, 3)`` ``uint8`` array.

    By default the palette is indexed by raw ``labelId`` (0-33), which is the
    encoding stored in the ``*_gtFine_labelIds.png`` masks this notebook
    trains on. The previous version was indexed by ``trainId`` (0 = road,
    ..., 18 = bicycle) with invented entries above 18, so label-id
    predictions were rendered in the wrong colours.

    Args:
        use_train_ids: When True, index the palette by the 19 evaluation
            ``trainId`` values instead (rows 0-18; every other row is black).

    Returns:
        ``np.ndarray`` of shape ``(256, 3)``, dtype ``uint8``; row ``i`` is
        the RGB colour of class ``i``. Rows without a class are black.
    """
    # (name, labelId, trainId, RGB) following the Cityscapes label table.
    # trainId 255 marks classes that are not part of the 19-class evaluation.
    labels = [
        ('unlabeled',            0, 255, (0, 0, 0)),
        ('ego vehicle',          1, 255, (0, 0, 0)),
        ('rectification border', 2, 255, (0, 0, 0)),
        ('out of roi',           3, 255, (0, 0, 0)),
        ('static',               4, 255, (0, 0, 0)),
        ('dynamic',              5, 255, (111, 74, 0)),
        ('ground',               6, 255, (81, 0, 81)),
        ('road',                 7,   0, (128, 64, 128)),
        ('sidewalk',             8,   1, (244, 35, 232)),
        ('parking',              9, 255, (250, 170, 160)),
        ('rail track',          10, 255, (230, 150, 140)),
        ('building',            11,   2, (70, 70, 70)),
        ('wall',                12,   3, (102, 102, 156)),
        ('fence',               13,   4, (190, 153, 153)),
        ('guard rail',          14, 255, (180, 165, 180)),
        ('bridge',              15, 255, (150, 100, 100)),
        ('tunnel',              16, 255, (150, 120, 90)),
        ('pole',                17,   5, (153, 153, 153)),
        ('polegroup',           18, 255, (153, 153, 153)),
        ('traffic light',       19,   6, (250, 170, 30)),
        ('traffic sign',        20,   7, (220, 220, 0)),
        ('vegetation',          21,   8, (107, 142, 35)),
        ('terrain',             22,   9, (152, 251, 152)),
        ('sky',                 23,  10, (70, 130, 180)),
        ('person',              24,  11, (220, 20, 60)),
        ('rider',               25,  12, (255, 0, 0)),
        ('car',                 26,  13, (0, 0, 142)),
        ('truck',               27,  14, (0, 0, 70)),
        ('bus',                 28,  15, (0, 60, 100)),
        ('caravan',             29, 255, (0, 0, 90)),
        ('trailer',             30, 255, (0, 0, 110)),
        ('train',               31,  16, (0, 80, 100)),
        ('motorcycle',          32,  17, (0, 0, 230)),
        ('bicycle',             33,  18, (119, 11, 32)),
    ]

    colormap = np.zeros((256, 3), dtype=np.uint8)
    for _name, label_id, train_id, color in labels:
        index = train_id if use_train_ids else label_id
        if index == 255:
            # Not an evaluated class: leave the row black.
            continue
        colormap[index] = color
    return colormap

In [40]:
# Build the palette, then paint the predicted label map with it.
colormap = create_cityscapes_label_colormap()
segmentation_map = decode_segmap(image=output, colormap=colormap)

In [41]:
# No earlier cell imports cv2, so it is imported here; without this the cell
# raises NameError on a fresh kernel.
import cv2

img = cv2.imread(test_img_path)
# cv2.imread signals failure by returning None rather than raising.
if img is None:
    raise FileNotFoundError(f"Could not read image: {test_img_path}")

cv2.imshow("Img", img)
# segmentation_map is RGB; OpenCV windows expect BGR channel order.
cv2.imshow('Segmentation Map', cv2.cvtColor(segmentation_map, cv2.COLOR_RGB2BGR))
# Block until a key is pressed, then close both windows.
cv2.waitKey(0)
cv2.destroyAllWindows()


In [24]:
# Persist only the model parameters (state dict), not the whole module.
model_save_path = 'unet_cityscapes.pth'
torch.save(obj=model.state_dict(), f=model_save_path)


In [25]:
# Restore the saved parameters onto the active device.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model_load_path = 'unet_cityscapes.pth'
saved_state = torch.load(model_load_path, map_location=device)
model.load_state_dict(saved_state)
model.eval()  # inference mode; as the last expression this also displays the model

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track