In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
import numpy as np
from glob import glob

In [None]:
class UNetResNet34(nn.Module):
    """U-Net-style segmentation network with a torchvision ResNet-34 encoder.

    The encoder is split into four stages; the decoder upsamples with
    transposed convolutions and merges each encoder stage by element-wise
    addition (not concatenation). The output logits are resized to the
    spatial size of the input so they can be compared pixel-for-pixel with
    full-resolution labels.

    Args:
        num_classes: number of output channels (one logit map per class).
        pretrained: forwarded to ``torchvision.models.resnet34``; ``True``
            loads ImageNet weights (which may be downloaded on first use).
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        # Kept as an attribute so existing checkpoints ("encoder.*" keys) still
        # load; its avgpool/fc classification head is never used in forward().
        self.encoder = models.resnet34(pretrained=pretrained)
        # Stem (conv1 stride 2 + maxpool stride 2) + layer1: 64 ch at 1/4 size.
        self.encoder_layer1 = nn.Sequential(
            self.encoder.conv1, self.encoder.bn1, self.encoder.relu,
            self.encoder.maxpool, self.encoder.layer1)
        self.encoder_layer2 = self.encoder.layer2  # 128 ch, 1/8 size
        self.encoder_layer3 = self.encoder.layer3  # 256 ch, 1/16 size
        self.encoder_layer4 = self.encoder.layer4  # 512 ch, 1/32 size
        self.upconv4 = self.upsample_block(512, 256)
        self.upconv3 = self.upsample_block(256, 128)
        self.upconv2 = self.upsample_block(128, 64)
        self.upconv1 = self.upsample_block(64, 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def upsample_block(self, in_channels, out_channels):
        """Return a 2x upsampling stage: ConvTranspose2d -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _align(tensor, reference):
        """Bilinearly resize ``tensor`` to ``reference``'s H x W if they differ."""
        if tensor.shape[-2:] != reference.shape[-2:]:
            tensor = F.interpolate(tensor, size=tuple(reference.shape[-2:]),
                                   mode='bilinear', align_corners=False)
        return tensor

    def forward(self, x):
        """Return per-class logits of shape (N, num_classes, H, W) for input (N, 3, H, W)."""
        x1 = self.encoder_layer1(x)
        x2 = self.encoder_layer2(x1)
        x3 = self.encoder_layer3(x2)
        x4 = self.encoder_layer4(x3)
        # When H or W is not divisible by 32, the strided encoder rounds sizes
        # down and the 2x decoder no longer lines up with the skips; align first.
        d4 = self.upconv4(x4)
        d3 = self.upconv3(self._align(d4, x3) + x3)
        d2 = self.upconv2(self._align(d3, x2) + x2)
        d1 = self.upconv1(self._align(d2, x1) + x1)
        out = self.final_conv(d1)
        # upconv1 only brings features back to 1/2 of the input resolution;
        # restore full size so the logits match full-resolution label maps.
        return self._align(out, x)

In [None]:
class CMDataset(Dataset):
    """Road-marking segmentation dataset read from ``<root>/<mode>/*.jpg``.

    Each image ``<root>/<mode>/NAME.jpg`` is paired with a colour-coded label
    image ``<root>/<mode>_labels/NAME.jpg``. Labels are converted to a map of
    class indices using ``color_encoding``.

    Args:
        mode: split sub-directory name, e.g. ``'train'`` or ``'val'``.
        num_classes: stored on the instance; not used by the loader itself.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``.
        root: directory containing the split folders; defaults to the current
            working directory (the original behaviour).
    """

    # (class name, RGB colour in the label image); list position = class index.
    color_encoding = [
        ('Bus Lane', (0, 255, 255)),
        ('Cycle Lane', (0, 128, 255)),
        ('Diamond', (178, 102, 255)),
        ('Junction Box', (255, 255, 51)),
        ('Left Arrow', (255, 102, 178)),
        ('Pedestrian Crossing', (255, 255, 0)),
        ('Right Arrow', (255, 0, 127)),
        ('Straight Arrow', (255, 0, 255)),
        ('Slow', (0, 255, 0)),
        ('Straight-Left Arrow', (255, 128, 0)),
        ('Straight-Right Arrow', (255, 0, 0)),
        ('Background', (0, 0, 0))
    ]

    def __init__(self, mode='train', num_classes=12, image_size=(512, 512), root=None):
        self.mode = mode
        self.num_classes = num_classes
        self.image_size = tuple(image_size)
        root = os.getcwd() if root is None else root
        # ImageNet mean/std, matching the pretrained ResNet encoder.
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.data_files = self.get_files(os.path.join(root, mode))
        # Pair images with labels by basename. A plain str.replace(mode, ...)
        # on the full path would also rewrite any other occurrence of `mode`
        # (e.g. a parent directory named "train"), producing wrong paths.
        label_dir = os.path.join(root, f"{mode}_labels")
        self.label_files = [os.path.join(label_dir, os.path.basename(f))
                            for f in self.data_files]

    def get_files(self, folder):
        """Return the ``*.jpg`` paths in ``folder``, sorted for a stable order."""
        return sorted(glob(os.path.join(folder, "*.jpg")))

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, index):
        """Return ``(image, label)``.

        ``image`` is the normalised float tensor from ``self.normalize``;
        ``label`` is an int64 tensor of class indices with shape (H, W).
        """
        with Image.open(self.data_files[index]) as img:
            # convert('RGB') guards against greyscale/RGBA/palette inputs.
            data = img.convert('RGB').resize(self.image_size)
        with Image.open(self.label_files[index]) as lbl:
            # NEAREST keeps label colours exact; an interpolating filter would
            # blend neighbouring class colours into values matching no class.
            label = lbl.convert('RGB').resize(self.image_size, Image.NEAREST)
        data = self.normalize(data)
        label = self.one_hot_encode(np.array(label))
        return data, torch.tensor(label, dtype=torch.long)

    def one_hot_encode(self, label):
        """Map an (H, W, 3) RGB label array to an (H, W) int64 class-index array.

        Despite its name this returns class indices (the format expected by
        ``nn.CrossEntropyLoss``), not a one-hot encoding. Pixels whose colour
        matches no entry of ``color_encoding`` are assigned the Background class
        rather than index 0 ("Bus Lane").
        """
        names = [name for name, _ in self.color_encoding]
        semantic_map = np.full(label.shape[:2], names.index('Background'), dtype=np.int64)
        for class_index, (_, color) in enumerate(self.color_encoding):
            semantic_map[np.all(label == color, axis=-1)] = class_index
        return semantic_map

In [None]:
def train(model, dataloader, optimizer, loss_func):
    """Run one training epoch of ``model`` over ``dataloader``.

    Each batch is moved to the device of the model's first parameter, so the
    model must already be on the target device. A tqdm bar shows progress in
    batches together with the latest batch loss.

    Args:
        model: module producing logits accepted by ``loss_func``.
        dataloader: iterable of ``(data, label)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        loss_func: callable ``(output, label) -> scalar loss tensor``.

    Returns:
        Sample-weighted mean loss over the epoch (0.0 if no batches were seen).
    """
    model.train()
    device = next(model.parameters()).device
    loss_sum = 0.0
    samples_seen = 0
    # Counting batches (not samples) keeps the bar exact even when the last
    # batch is short or the loader's batch_size is None (custom batch_sampler);
    # the context manager closes the bar even if a step raises.
    with tqdm(total=len(dataloader), desc='Training', unit='batch') as tq:
        for data, label in dataloader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            loss = loss_func(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            batch_size = data.size(0)
            loss_sum += batch_loss * batch_size
            samples_seen += batch_size
            tq.set_postfix(loss=f'{batch_loss:.6f}')
            tq.update(1)
    return loss_sum / samples_seen if samples_seen else 0.0

def validate(model, dataloader):
    """Return the pixel accuracy of ``model`` over every pixel in ``dataloader``.

    The model is put in eval mode and run without gradients; predictions are
    the argmax over dim 1 of its output. Accuracy is total correct pixels /
    total pixels, so a short final batch is weighted by its size instead of
    counting as much as a full batch.

    Batches are moved to the device of the model's first parameter (CPU if the
    model has no parameters).

    Returns:
        Float in [0, 1]; 0.0 when the loader yields no pixels.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    correct = 0
    total = 0
    with torch.no_grad():
        for data, label in dataloader:
            data, label = data.to(device), label.to(device)
            prediction = torch.argmax(model(data), dim=1)
            correct += (prediction == label).sum().item()
            total += label.numel()
    return correct / total if total else 0.0

In [None]:
# Model, optimiser, loss and data loaders.
SEED = 0
torch.manual_seed(SEED)  # reproducible decoder init and shuffle order
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 12
model = UNetResNet34(num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
loss_func = nn.CrossEntropyLoss()

# The dataset class defined above is CMDataset (CamVidDataset is not defined
# in the cells above and raised NameError).
train_dataset = CMDataset(mode='train', num_classes=num_classes)
assert len(train_dataset) > 0, "no training images found (expected ./train/*.jpg)"
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)

val_dataset = CMDataset(mode='val', num_classes=num_classes)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)


In [None]:
EPOCHS = 100
CHECKPOINT_STEP = 1  # write checkpoint_epoch_<epoch>.pth every N epochs
VALIDATE_STEP = 1    # run validation every N epochs
# validate() reports pixel accuracy (not mIoU), so track the best of that.
best_val_accuracy = 0.0

for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    train(model, train_loader, optimizer, loss_func)
    if epoch % VALIDATE_STEP == 0:
        val_accuracy = validate(model, val_loader)
        print(f"Validation Accuracy: {val_accuracy:.4f}")
        # Keep a separate copy of the best-scoring weights seen so far.
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"New best validation accuracy; saved best_model.pth")
    if epoch % CHECKPOINT_STEP == 0:
        torch.save(model.state_dict(), f'checkpoint_epoch_{epoch}.pth')