In [None]:
# Mount Google Drive at /content/drive; every project, dataset and checkpoint
# path used by this notebook lives under that mount point (Colab-only cell).
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.cluster import MiniBatchKMeans
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
import time
import torch.optim as optim
import os
import sys
import math

from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.utils.data import DataLoader

In [None]:
# Locations on the mounted Google Drive: project code and the two datasets.
PROJECT_PATH = '/content/drive/MyDrive/Colab Notebooks/MLME2025_project'
CITYSCAPES_DIR = '/content/drive/MyDrive/Cityspaces/'
GTA5_DIR = '/content/drive/MyDrive/GTA5/'

# Checkpoint files are derived from PROJECT_PATH / CHECKPOINT_DIR so that the
# project root and the run directory are each spelled exactly once and the two
# checkpoint paths cannot drift apart when one of them is edited.
CHECKPOINT_DIR = f'{PROJECT_PATH}/models/BiSeNet/checkpoints_training_DACS_PLD'
BEST_MODEL_SAVE_PATH = f'{CHECKPOINT_DIR}/PLD_DACS_best_model_BiSeNet.pth'
LAST_EPOCH_SAVE_PATH = f'{CHECKPOINT_DIR}/PLD_DACS_last_epoch_BiSeNet.pth'


# Resize target (height, width) for Cityscapes images and labels.
H_CITYSCAPES = 512
W_CITYSCAPES = 1024

# Resize target (height, width) for GTA5 images and labels.
H_GTA5 = 720
W_GTA5 = 1280

# Segmentation classes and data-loading settings.
NUM_CLASSES = 19
BATCH_SIZE = 4
NUM_WORKERS = 4

# SGD hyper-parameters (LEARNING_RATE is also the base of the LR schedule).
LEARNING_RATE = 0.025
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0001

In [None]:
# Work from the project root and make its packages (datasets, utils, models)
# importable by the import statements that follow.
os.chdir(PROJECT_PATH)
# Guarded so that re-running this cell does not keep appending the same
# directory to sys.path.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

from datasets.cityscapes import CityScapes
from datasets.gta5 import GTA5
from utils.augmentation import AdditiveGaussianNoise, RandomHorizontalFlipPair
from utils.utils import poly_lr_scheduler_with_backbone, fast_hist, per_class_iou, mean_iou
from utils.dacs_mixer import DACS_Mixer
from models.BiSeNet.build_bisenet import BiSeNet
from utils.PLD import PLD_Adapter
from utils.PLD_dacs_mixer import PLD_DACS_Mixer

In [None]:
# Seed the Python, NumPy and PyTorch random number generators.
# NOTE(review): the seeds are not uniform (42 for torch/NumPy, 0 for `random`
# and for the explicit CUDA seed) -- confirm this is intentional.
random.seed(0)
np.random.seed(42)
torch.manual_seed(42)
# Must stay after torch.manual_seed so that this value is the one in effect
# for the current CUDA device.
torch.cuda.manual_seed(0)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Augmentation transformer (AUG1): random horizontal flip with p=0.5. It is
# passed to the GTA5 dataset separately from the image/label pipelines below
# (presumably so image and label are flipped together -- see utils.augmentation).
aug_transform = RandomHorizontalFlipPair(p=0.5)

# Per-channel normalisation statistics shared by both image pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _image_pipeline(height, width, before_resize=(), after_to_tensor=()):
    """Image pipeline: [extras] -> bilinear resize -> ToTensor -> [extras] -> Normalize."""
    steps = [
        *before_resize,
        transforms.Resize((height, width), interpolation=InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        *after_to_tensor,
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
    return transforms.Compose(steps)


def _label_to_long_tensor(label_img):
    """Convert a label image into an int64 tensor via a NumPy array."""
    return torch.from_numpy(np.array(label_img)).long()


def _label_pipeline(height, width):
    """Label pipeline: nearest-neighbour resize, then conversion to an int64 tensor."""
    return transforms.Compose([
        transforms.Resize((height, width), interpolation=InterpolationMode.NEAREST),
        transforms.Lambda(_label_to_long_tensor),
    ])


# Data transformer (X): only the source domain gets photometric augmentation.
data_transforms = {
    'source': _image_pipeline(
        H_GTA5, W_GTA5,
        before_resize=[
            # AUG3: colour jitter applied with probability 0.5
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05)],
                p=0.5,
            ),
        ],
        # AUG4: additive Gaussian noise on the tensor, before normalisation
        after_to_tensor=[AdditiveGaussianNoise(std=0.01, p=0.5)],
    ),
    'target': _image_pipeline(H_CITYSCAPES, W_CITYSCAPES),
}

# Label transformer (Y)
label_transforms = {
    'source': _label_pipeline(H_GTA5, W_GTA5),
    'target': _label_pipeline(H_CITYSCAPES, W_CITYSCAPES),
}

In [None]:
# Datasets: GTA5 is the labelled source domain; Cityscapes provides the
# target training split and the validation split.

source_train_dataset = GTA5(
    data_path=GTA5_DIR,
    transform=data_transforms['source'],
    label_transform=label_transforms['source'],
    aug_transform=aug_transform,
)

# Both Cityscapes splits share the same location and transform pipelines.
_cityscapes_common = dict(
    data_path=CITYSCAPES_DIR,
    transform=data_transforms['target'],
    label_transform=label_transforms['target'],
)
target_train_dataset = CityScapes(split='train', **_cityscapes_common)
val_dataset = CityScapes(split='val', **_cityscapes_common)

# Number of samples per split; train_model uses these to average epoch losses.
dataset_sizes = {
    split_name: len(split_dataset)
    for split_name, split_dataset in (
        ('source_train', source_train_dataset),
        ('target_train', target_train_dataset),
        ('val', val_dataset),
    )
}

In [None]:
# Display the number of samples in each split.
dataset_sizes

{'source_train': 2500, 'target_train': 1572, 'val': 500}

In [None]:
# Data loaders. All three share batch size, worker count and pinned memory;
# only the two training loaders shuffle.
_loader_common = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

source_train_loader = DataLoader(source_train_dataset, shuffle=True, **_loader_common)
target_train_loader = DataLoader(target_train_dataset, shuffle=True, **_loader_common)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_common)


In [None]:
# Settings that the pseudo-label corrector and the mixer must agree on.
_pseudo_label_common = dict(
    num_classes=NUM_CLASSES,
    ignore_label=255,
    confidence_threshold=0.968,
)

# Object that corrects the target pseudo labels (see utils.PLD).
PLD = PLD_Adapter(feature_dim=256, **_pseudo_label_common)

# DACS mixer that consumes the corrected pseudo labels (see utils.PLD_dacs_mixer).
dacs_mixer = PLD_DACS_Mixer(**_pseudo_label_common)

In [None]:
def _primary_output(outputs):
    """Return the first element when the model returned a tuple, else the output itself."""
    if isinstance(outputs, tuple):
        return outputs[0]
    return outputs


def _batch_hist(preds, labels):
    """Run `fast_hist` on flattened CPU copies of a batch's predictions and labels."""
    return fast_hist(
        preds.detach().cpu().numpy().flatten(),
        labels.detach().cpu().numpy().flatten(),
        NUM_CLASSES
    )


def train_model(model, source_data_loader, target_data_loader, val_data_loader,
    dataset_sizes, criterion, optimizer, last_epoch_save_path, best_model_save_path,
    num_epochs=1, init_lr=0.01, prev_num_epoch=0, prev_best_miou=0, total_number_epochs=50,
    device='cuda'):
    """Train with a supervised source loss plus a DACS mixed-sample loss, validating each epoch.

    Every step draws one source batch and one target batch (the target loader
    is restarted whenever it is exhausted), computes the source loss, builds a
    mixed sample from PLD-corrected target pseudo-labels through the
    module-level `PLD` and `dacs_mixer` objects, and optimises
    `loss_s + lambda_mix * loss_m`. After each epoch the model is validated,
    the learning rate is updated, and checkpoints are written.

    Args:
        model: network to train; its forward may return a tensor or a tuple
            whose first element is used as the logits.
        source_data_loader: loader of labelled source batches (dicts with 'x', 'y').
        target_data_loader: loader of target batches (only 'x' is read).
        val_data_loader: loader of labelled validation batches (dicts with 'x', 'y').
        dataset_sizes: dict with 'source_train' and 'val' sample counts, used
            to average the epoch losses.
        criterion: loss called as criterion(logits, labels); assumed to return
            a per-batch mean (the default for nn.CrossEntropyLoss).
        optimizer: optimizer stepped once per source batch.
        last_epoch_save_path: checkpoint path overwritten after every epoch.
        best_model_save_path: checkpoint path written when validation mIoU improves.
        num_epochs: number of epochs to run in this call.
        init_lr: base learning rate handed to the LR scheduler.
        prev_num_epoch: epoch counter handed to the LR scheduler; incremented
            after every epoch.
        prev_best_miou: validation mIoU that must be exceeded before the best
            checkpoint is overwritten.
        total_number_epochs: schedule length handed to the LR scheduler.
        device: device the batches are moved to.

    Returns:
        (model, time_elapsed, best_miou, best_per_class_iou). The last item is
        None when no epoch of this call improved on `prev_best_miou`.
    """

    since = time.time()
    best_miou = prev_best_miou
    best_per_class_iou = None

    iter_target_data_loader = iter(target_data_loader)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        since_epoch = time.time()

        # TRAINING
        model.train()
        running_loss_source = 0.0
        running_loss_mixed = 0.0
        running_loss_total = 0.0

        hist_source_train = np.zeros((NUM_CLASSES, NUM_CLASSES))
        hist_mixed_train = np.zeros((NUM_CLASSES, NUM_CLASSES))

        progress_bar = tqdm(source_data_loader, desc="Training")

        for i, source_batch in enumerate(progress_bar):
            # The target loader is cycled independently of the source epoch.
            try:
                target_batch = next(iter_target_data_loader)
            except StopIteration:
                iter_target_data_loader = iter(target_data_loader)
                target_batch = next(iter_target_data_loader)

            X_s = source_batch['x'].to(device)
            Y_s = source_batch['y'].to(device)
            X_t = target_batch['x'].to(device)
            batch_size = X_s.size(0)

            optimizer.zero_grad()

            # 1. Training on source sample (X_s, Y_s)
            outputs_s = _primary_output(model(X_s))
            loss_s = criterion(outputs_s, Y_s)

            preds_s = torch.argmax(outputs_s, dim=1)
            hist_source_train += _batch_hist(preds_s, Y_s)

            # 2. Training on mixed sample (X_m, Y_m)

            # One forward pass only, and without autograd: these logits are
            # used solely to derive pseudo-labels, so running the model more
            # than once (or keeping its graph) would waste compute/memory and
            # apply extra BatchNorm statistic updates on the same batch.
            with torch.no_grad():
                logits_t = _primary_output(model(X_t))
            # Generate corrected pseudo-labels for the target batch using PLD
            Y_t_pseudo = PLD(model, X_t, logits_t)
            # Perform mixing between source and target images using DACS with corrected pseudo-labels
            X_m, Y_m, lambda_mix, _ = dacs_mixer(model, X_s, Y_s, X_t, Y_t_pseudo)

            if X_m is not None and Y_m is not None:
                outputs_m = _primary_output(model(X_m))
                loss_m = criterion(outputs_m, Y_m)

                preds_m = torch.argmax(outputs_m, dim=1)
                hist_mixed_train += _batch_hist(preds_m, Y_m)
            else:
                # The mixer produced no sample: the mixed term contributes nothing.
                loss_m = torch.tensor(0.0, device=device)
                lambda_mix = 0.0

            # 3. Total loss
            total_batch_loss = loss_s + lambda_mix * loss_m

            total_batch_loss.backward()
            optimizer.step()

            # Weight each batch mean by its size so that dividing by the
            # number of samples below yields a true per-sample average.
            running_loss_source += loss_s.item() * batch_size
            running_loss_mixed += loss_m.item() * batch_size
            running_loss_total += total_batch_loss.item() * batch_size

            # DEBUG INFO
            progress_bar.set_postfix({
                "batch" : i,
                "loss_s": round(loss_s.item(), 4),
                "loss_m": round(loss_m.item(), 4),
                "lambda_mix": round(lambda_mix, 4) if isinstance(lambda_mix, float) else float(lambda_mix),
                "total_loss": round(total_batch_loss.item(), 4)
            })

        epoch_loss_source = running_loss_source / dataset_sizes['source_train']
        epoch_loss_mixed = running_loss_mixed / dataset_sizes['source_train']
        epoch_loss_total = running_loss_total / dataset_sizes['source_train']

        # Training mIoU is computed over source and mixed predictions together.
        hist_combined = hist_source_train + hist_mixed_train
        ious_combined = per_class_iou(hist_combined) * 100
        miou_combined = mean_iou(ious_combined)

        print(f'train Loss: {epoch_loss_total:.4f} Acc: {miou_combined:.4f}')

        # VALIDATION
        model.eval()
        running_loss_val = 0.0
        hist_val = np.zeros((NUM_CLASSES, NUM_CLASSES))

        progress_bar_val = tqdm(val_data_loader, desc="Validation")

        for j, batch_val in enumerate(progress_bar_val):
            inputs_val = batch_val['x'].to(device)
            labels_val = batch_val['y'].to(device)

            with torch.no_grad():
                outputs_val = _primary_output(model(inputs_val))
                loss_val = criterion(outputs_val, labels_val)

                # Same per-sample weighting as in training.
                running_loss_val += loss_val.item() * inputs_val.size(0)

                preds_val = torch.argmax(outputs_val, dim=1)
                hist_val += _batch_hist(preds_val, labels_val)

            # DEBUG INFO
            progress_bar_val.set_postfix({
                "batch" : j,
                "loss_val": round(loss_val.item(), 4)
            })

        epoch_loss_val = running_loss_val / dataset_sizes['val']
        ious_val = per_class_iou(hist_val) * 100
        miou_val = mean_iou(ious_val)

        print(f'val Loss: {epoch_loss_val:.4f} Acc: {miou_val:.4f}')

        # Update learning rate
        # NOTE(review): the scheduler receives the counter *before* it is
        # incremented, so the rate used for the next epoch looks one step
        # behind the schedule -- confirm against poly_lr_scheduler_with_backbone.
        next_lr = poly_lr_scheduler_with_backbone(optimizer, init_lr, prev_num_epoch, total_number_epochs)
        prev_num_epoch += 1

        # save the best model
        if miou_val > best_miou:
            best_miou = miou_val
            best_per_class_iou = ious_val
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, best_model_save_path)

        # save last epoch model
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, last_epoch_save_path)

        time_epoch = time.time() - since_epoch
        print(f'Epoch complete in {time_epoch // 60:.0f}m {time_epoch % 60:.0f}s')
        print(f'Next Learning Rate: {next_lr}')
        print(f'Last Lambda_Mix: {lambda_mix}')
        print()

    time_elapsed = time.time() - since
    print('-' * 20)
    print()
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val MIOU: {best_miou:.4f}')
    print(f'Best val per class IOU: {best_per_class_iou}')
    print()
    print(f'Total Epochs completed: {prev_num_epoch}')

    return model, time_elapsed, best_miou, best_per_class_iou

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [None]:
context_path = 'resnet18'

model = BiSeNet(num_classes=NUM_CLASSES, context_path=context_path)
model = model.to(device)

# Two parameter groups: the context path is trained at a tenth of the base
# learning rate, the modules listed in model.mul_lr at the full rate.
optimizer = optim.SGD(
    params=[
    {'params': model.context_path.parameters(), 'lr': LEARNING_RATE * 0.1, 'initial_lr': LEARNING_RATE * 0.1},
    {'params': [p for module in model.mul_lr for p in module.parameters()], 'lr': LEARNING_RATE, 'initial_lr': LEARNING_RATE}
    ],
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY
)

# Pixels labelled 255 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

# load previous checkpoint
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU runtime still loads when `device` is "cpu".
checkpoint = torch.load(LAST_EPOCH_SAVE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 187MB/s]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:00<00:00, 225MB/s]


In [None]:
# Cache file for the source centroids used by PLD.
# NOTE(review): the stored cell output reports loading from a
# 'checkpoints_training_PLCA_DACS' directory, which differs from this path --
# confirm the output/cache is not stale.
CENTROIDS_SAVE_PATH = '/content/drive/MyDrive/Colab Notebooks/MLME2025_project/models/BiSeNet/checkpoints_training_DACS_PLD/source_centroids.pth'

# Try the cache first; load_centroids' return value tells whether it succeeded.
centroids_loaded = PLD.load_centroids(device=device, load_path=CENTROIDS_SAVE_PATH)

# Cache miss: compute the centroids with the current model over the source
# training loader and save them to the same path.
if not centroids_loaded:
    PLD.compute_source_centroids(
        save_path=CENTROIDS_SAVE_PATH,
        device=device,
        source_data_loader=source_train_loader,
        model=model,
    )

Caricamento dei centroidi da /content/drive/MyDrive/Colab Notebooks/MLME2025_project/models/BiSeNet/checkpoints_training_PLCA_DACS/source_centroids.pth...
Centroidi caricati con successo.


In [None]:
# Arguments for this training run, grouped by role.
_train_kwargs = dict(
    # model and optimisation objects
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    # data
    source_data_loader=source_train_loader,
    target_data_loader=target_train_loader,
    val_data_loader=val_loader,
    dataset_sizes=dataset_sizes,
    # checkpoint destinations
    last_epoch_save_path=LAST_EPOCH_SAVE_PATH,
    best_model_save_path=BEST_MODEL_SAVE_PATH,
    # schedule: run 10 epochs of a 50-epoch schedule, starting the epoch
    # counter at 40 (presumably the epochs already completed by the loaded
    # checkpoint -- verify before re-running)
    num_epochs=10,
    init_lr=LEARNING_RATE,
    prev_num_epoch=40,
    total_number_epochs=50,
    # validation mIoU to beat before the best checkpoint is overwritten
    # (presumably the best score of the earlier epochs -- verify)
    prev_best_miou=33.554278,
)

model, time_elapsed, best_miou, best_per_class_iou = train_model(**_train_kwargs)

Epoch 1/10
----------


Training: 100%|██████████| 625/625 [23:08<00:00,  2.22s/it, batch=624, loss_s=0.102, loss_m=0.049, lambda_mix=0.821, total_loss=0.142]


train Loss: 0.0495 Acc: 71.1025


Validation: 100%|██████████| 125/125 [02:04<00:00,  1.00it/s, batch=124, loss_val=0.507]


val Loss: 0.2437 Acc: 31.8929
Epoch complete in 25m 14s
Next Learning Rate: [0.0005873094715440095, 0.005873094715440094]
Last Lambda_Mix: 0.8209212422370911

Epoch 2/10
----------


Training: 100%|██████████| 625/625 [14:33<00:00,  1.40s/it, batch=624, loss_s=0.126, loss_m=0.0767, lambda_mix=0.792, total_loss=0.186]


train Loss: 0.0496 Acc: 71.6126


Validation: 100%|██████████| 125/125 [00:17<00:00,  7.17it/s, batch=124, loss_val=0.603]


val Loss: 0.2397 Acc: 32.5526
Epoch complete in 14m 51s
Next Learning Rate: [0.0005341770966113464, 0.005341770966113463]
Last Lambda_Mix: 0.7920171618461609

Epoch 3/10
----------


Training: 100%|██████████| 625/625 [14:30<00:00,  1.39s/it, batch=624, loss_s=0.172, loss_m=0.0984, lambda_mix=0.857, total_loss=0.256]


train Loss: 0.0490 Acc: 71.6789


Validation: 100%|██████████| 125/125 [00:17<00:00,  7.16it/s, batch=124, loss_val=0.598]


val Loss: 0.2346 Acc: 32.8807
Epoch complete in 14m 48s
Next Learning Rate: [0.00048044977359257274, 0.004804497735925726]
Last Lambda_Mix: 0.8565302491188049

Epoch 4/10
----------


Training: 100%|██████████| 625/625 [14:33<00:00,  1.40s/it, batch=624, loss_s=0.142, loss_m=0.0266, lambda_mix=0.828, total_loss=0.164]


train Loss: 0.0495 Acc: 71.4818


Validation: 100%|██████████| 125/125 [00:17<00:00,  7.20it/s, batch=124, loss_val=0.565]


val Loss: 0.2469 Acc: 32.1832
Epoch complete in 14m 52s
Next Learning Rate: [0.00042604477233329193, 0.004260447723332918]
Last Lambda_Mix: 0.828330934047699

Epoch 5/10
----------


Training: 100%|██████████| 625/625 [14:32<00:00,  1.40s/it, batch=624, loss_s=0.131, loss_m=0.0573, lambda_mix=0.805, total_loss=0.177]


train Loss: 0.0491 Acc: 71.6685


Validation: 100%|██████████| 125/125 [00:17<00:00,  7.10it/s, batch=124, loss_val=0.535]


val Loss: 0.2404 Acc: 32.1902
Epoch complete in 14m 51s
Next Learning Rate: [0.00037085413874382193, 0.0037085413874382183]
Last Lambda_Mix: 0.8054693341255188

Epoch 6/10
----------


Training: 100%|██████████| 625/625 [14:34<00:00,  1.40s/it, batch=624, loss_s=0.128, loss_m=0.0943, lambda_mix=0.802, total_loss=0.204]


train Loss: 0.0490 Acc: 71.7480


Validation: 100%|██████████| 125/125 [00:17<00:00,  7.16it/s, batch=124, loss_val=0.61]


val Loss: 0.2459 Acc: 32.0551
Epoch complete in 14m 52s
Next Learning Rate: [0.00031473135294854187, 0.003147313529485418]
Last Lambda_Mix: 0.8019260168075562

Epoch 7/10
----------


Training: 100%|██████████| 625/625 [14:28<00:00,  1.39s/it, batch=624, loss_s=0.127, loss_m=0.0839, lambda_mix=0.803, total_loss=0.194]


train Loss: 0.0489 Acc: 71.7706


Validation: 100%|██████████| 125/125 [00:17<00:00,  7.10it/s, batch=124, loss_val=0.525]


val Loss: 0.2264 Acc: 33.0039
Epoch complete in 14m 47s
Next Learning Rate: [0.00025746665870904475, 0.0025746665870904468]
Last Lambda_Mix: 0.8031673431396484

Epoch 8/10
----------


Training: 100%|██████████| 625/625 [14:31<00:00,  1.39s/it, batch=624, loss_s=0.107, loss_m=0.089, lambda_mix=0.764, total_loss=0.175]


train Loss: 0.0487 Acc: 72.1880


Validation: 100%|██████████| 125/125 [00:17<00:00,  7.06it/s, batch=124, loss_val=0.497]


val Loss: 0.2375 Acc: 32.2048
Epoch complete in 14m 49s
Next Learning Rate: [0.0001987358121886906, 0.001987358121886906]
Last Lambda_Mix: 0.7637500166893005

Epoch 9/10
----------


Training: 100%|██████████| 625/625 [14:31<00:00,  1.39s/it, batch=624, loss_s=0.13, loss_m=0.0734, lambda_mix=0.842, total_loss=0.192]


train Loss: 0.0487 Acc: 71.8261


Validation: 100%|██████████| 125/125 [00:17<00:00,  7.12it/s, batch=124, loss_val=0.586]


val Loss: 0.2361 Acc: 31.7848
Epoch complete in 14m 50s
Next Learning Rate: [0.00013797296614612162, 0.001379729661461216]
Last Lambda_Mix: 0.8421937823295593

Epoch 10/10
----------


Training: 100%|██████████| 625/625 [14:30<00:00,  1.39s/it, batch=624, loss_s=0.107, loss_m=0.0323, lambda_mix=0.796, total_loss=0.133]


train Loss: 0.0485 Acc: 72.2341


Validation: 100%|██████████| 125/125 [00:17<00:00,  7.08it/s, batch=124, loss_val=0.583]


val Loss: 0.2451 Acc: 32.1583
Epoch complete in 14m 49s
Next Learning Rate: [7.393788183141576e-05, 0.0007393788183141575]
Last Lambda_Mix: 0.7964637875556946

--------------------

Training complete in 158m 42s
Best val MIOU: 33.554278
Best val per class IOU: None

Total Epochs completed: 50
