In [None]:
import random
import numpy as np
import torch

def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Applies the same seed to Python's ``random`` module, NumPy's global RNG,
    the PyTorch CPU generator and the generators of all CUDA devices, in that
    order, so that repeated runs draw the same random numbers.

    Args:
        seed: Integer seed passed to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


SEED = 42
set_seed(SEED)

In [None]:
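# Fresh-machine / Colab setup: uncomment to fetch the repo and install its dependencies.
# The `%cd` is needed because requirements.txt and configs/ live inside the cloned repo.
# REPO = "https://github.com/annanasnas/semantic_segmentation-25.git"
# !git clone $REPO
# %cd semantic_segmentation-25
# %pip install -q -r requirements.txt pyyaml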

## Config

In [None]:
import yaml

import subprocess
import sys

# Every hyperparameter read below comes from this YAML file.
with open("configs/deeplabv2.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Download the dataset with the kernel's own interpreter: a bare `python` resolved
# from PATH may belong to a different environment. check=True stops "Run All" here
# if the script fails, instead of letting the failure surface later as a
# missing-file error inside the dataset.
subprocess.run([sys.executable, "download_data.py"], check=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = cfg["train"]["batch_size"]
epochs = cfg["train"]["epochs"]
data_dir = cfg["data"]["root"]
learning_rate = cfg["train"]["learning_rate"]

## DataLoaders

In [None]:
from datasets.cityscapes import CityscapesDataset
from models.deeplabv2.deeplabv2 import get_deeplab_v2
from train import train_model
from torchvision import transforms
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader


# Per-channel RGB mean/std of ImageNet.
# NOTE(review): presumably chosen to match an ImageNet-pretrained backbone; confirm.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Image-only pipeline: convert with ToTensor, then normalize channel-wise.
image_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

# Both splits share the root directory and image transform; only the split name differs.
train_dataset, val_dataset = (
    CityscapesDataset(root_dir=data_dir, split=split_name, image_transform=image_transforms)
    for split_name in ("train", "val")
)

# Only the training split is shuffled; validation keeps a fixed order.
train_dataloader, val_dataloader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=4)
    for dataset, is_train in ((train_dataset, True), (val_dataset, False))
)

## Training

In [None]:
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn

# Names of the 19 Cityscapes evaluation classes, in trainId order.
# NOTE(review): assumes CityscapesDataset yields trainId masks (0-18, 255 = ignore),
# consistent with ignore_index=255 below; confirm against the dataset class.
class_names = [
    "road", "sidewalk", "building", "wall", "fence", "pole",
    "traffic light", "traffic sign", "vegetation", "terrain", "sky",
    "person", "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle",
]

# Move the model to the target device before constructing the optimizer, as the
# torch.optim documentation recommends.
model = get_deeplab_v2()
model.to(device)

optimizer = optim.SGD(model.optim_parameters(lr=learning_rate), momentum=0.9, weight_decay=5e-4)
# Pixels labelled 255 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)
# Poly decay over the configured number of epochs instead of a hard-coded 50.
# NOTE(review): assumes one scheduler.step() per epoch; confirm.
scheduler = optim.lr_scheduler.PolynomialLR(optimizer, total_iters=epochs, power=0.9)
scaler = GradScaler()
# NOTE(review): `scheduler` and `scaler` are not passed to train_model below, so nothing
# in this notebook steps them. Pass them through if train_model's signature accepts them.

train_model(
    model, criterion, optimizer,
    train_dataloader, val_dataloader, class_names, device, epochs,
    model_name='DeepLabV2',
)