In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
from torchvision.datasets import Cityscapes
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
%load_ext autoreload
%autoreload 2
# Image preprocessing applied to every input picture: conversion to a tensor
# followed by per-channel normalisation with the ImageNet mean/std (the
# encoder created further down is initialised with ImageNet weights).
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
input_transform = transforms.Compose([_to_tensor, _normalize])


def target_transform(mask):
    """Convert a segmentation mask into an ``int64`` tensor.

    Args:
        mask: Anything ``np.array`` accepts; in this file it is the target
            image handed over by the Cityscapes dataset.

    Returns:
        A ``torch.Tensor`` of dtype ``torch.int64`` holding the same values
        as ``mask``, unchanged.

    NOTE(review): the values are passed through as-is. torchvision's
    'semantic' target presumably carries the raw Cityscapes label IDs while
    the model below is built with 19 output classes -- confirm that the
    training code remaps or ignores the remaining IDs.
    """
    label_array = np.array(mask)
    return torch.from_numpy(label_array).to(dtype=torch.long)
root_dir = '/tmp/cityscapes'

# Settings shared by all three splits, kept in one place so the train / val /
# test pipelines cannot silently drift apart.
_dataset_kwargs = dict(root=root_dir,
                       mode='fine',
                       target_type='semantic',
                       transform=input_transform,
                       target_transform=target_transform)
_loader_kwargs = dict(batch_size=32, num_workers=4)

# Training split: reshuffled every epoch.
train_dataset = Cityscapes(split='train', **_dataset_kwargs)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)

# Validation split: iterated in a fixed order so that evaluation batches (and
# anything derived from "the first few batches") are the same on every run.
# The original shuffled this loader, unlike the test loader below.
val_dataset = Cityscapes(split='val', **_dataset_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# Test split: fixed order as well.
# NOTE(review): the public Cityscapes test annotations are presumably
# placeholders -- confirm before reporting metrics computed on this split.
test_dataset = Cityscapes(split='test', **_dataset_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
# Sanity check: fetch a single training batch and report its layout.
# Nothing is printed when the loader yields no batches.
_first_batch = next(iter(train_loader), None)
if _first_batch is not None:
    images, masks = _first_batch
    print(images.shape)  # torch.Size([32, 3, H, W]) with batch_size=32
    print(masks.shape)   # torch.Size([32, H, W])
    print(masks.dtype)   # torch.int64 (set by target_transform)
import segmentation_models_pytorch as smp
import torch

# Select the compute device: CUDA when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the U-Net and move it onto the selected device.
model = smp.Unet(
    encoder_name="resnet34",     # backbone; other options include 'resnet50', etc.
    encoder_weights="imagenet",  # start from ImageNet pre-trained encoder weights
    in_channels=3,               # RGB input
    classes=19,                  # number of output classes for Cityscapes
).to(device)

# Dump the architecture for a quick visual check.
print(model)


from train import train
# Training run configuration.
NUM_EPOCHS = 5
# Adam over every model parameter (encoder included).
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Launch training. `num_examples`, `scheduler` and `freeze_encoder` are
# forwarded unchanged; their exact meaning is defined by `train` in train.py,
# which is not visible from this file.
train(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCHS,
    num_examples=5,
    scheduler=None,
    freeze_encoder=False,
)