In [None]:
from google.colab import drive
# Mount Google Drive under /content/drive so the dataset archive, project code
# and checkpoint folders referenced by the later cells are reachable.
# force_remount=True requests a fresh mount even if Drive is already mounted.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
import zipfile
import os

# Dataset archive on the mounted Drive, and the directory it is unpacked into.
zip_path = '/content/drive/MyDrive/MLDL_repo/GTA5.zip'
extract_path = '/content/dataset'

# Create the target directory; exist_ok=True makes re-running the cell harmless.
os.makedirs(extract_path, exist_ok=True)

# Unpack every archive member into extract_path; the `with` block closes the
# archive afterwards.
# NOTE(review): the whole archive is re-extracted each time this cell runs.
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("Extraction complete!")

Extraction complete!


In [None]:
import sys
# Put the project directory on the import path. The training cell imports
# `model`, `dataset_custom`, `train` and `utils_p`, presumably from this
# directory -- TODO confirm.
# NOTE(review): re-running this cell appends the same entry again.
sys.path.append('/content/drive/MyDrive/MLDL_repo/step5_unet/unet_resnet18')

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Albumentations pipeline passed to the GTA5 dataset in main().
# The first three steps are random photometric augmentations, each applied
# with probability 0.5; the last three are deterministic preprocessing.
# NOTE(review): main() builds one dataset with this transform and then splits
# it, so the random augmentations also reach the validation subset.
transform = A.Compose([
    # Random brightness/contrast/saturation/hue perturbation.
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.5),
    # Random Gaussian blur with a kernel size drawn from 3..5.
    A.GaussianBlur(blur_limit=(3, 5), sigma_limit=(0.1, 1.0), p=0.5),
    # Brightness-only jitter (contrast_limit=0.0 disables the contrast part).
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.0, p=0.5),
    # Resize to height 720, width 1280.
    A.Resize(720, 1280),
    # Per-channel normalisation; these values look like the usual ImageNet
    # statistics, consistent with the pretrained ResNet-18 encoder -- confirm.
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    # Convert the result to a PyTorch tensor.
    ToTensorV2(),
])

In [None]:
# Set the PyTorch CUDA caching-allocator option `expandable_segments:True` for
# this kernel. NOTE(review): presumably intended to reduce memory fragmentation,
# and it likely has to be set before the first CUDA allocation -- confirm
# against the PyTorch CUDA memory-management notes.
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [None]:
import os
import re

def find_latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-numbered ``unet_epoch_<N>.pt`` checkpoint.

    Only file names that match ``unet_epoch_<digits>.pt`` exactly are
    considered, so unrelated files that merely share the prefix (for example
    ``unet_epoch_best.pt``) are never reported as the latest checkpoint.

    Args:
        checkpoint_dir: Directory to scan for checkpoint files.

    Returns:
        ``os.path.join(checkpoint_dir, <file name>)`` for the checkpoint with
        the largest epoch number in its name, or ``None`` when the directory
        does not exist (or is not a directory) or contains no matching file.
    """
    # isdir rather than exists: a regular file at this path would otherwise
    # slip through and make os.listdir raise NotADirectoryError.
    if not os.path.isdir(checkpoint_dir):
        print(f"Checkpoint directory {checkpoint_dir} does not exist.")
        return None

    # Map each well-formed checkpoint name to the epoch number embedded in it.
    name_pattern = re.compile(r"unet_epoch_(\d+)\.pt")
    epoch_by_name = {}
    for fname in os.listdir(checkpoint_dir):
        match = name_pattern.fullmatch(fname)
        if match:
            epoch_by_name[fname] = int(match.group(1))

    if not epoch_by_name:
        print(f"No checkpoints found in {checkpoint_dir}.")
        return None

    # Compare numerically so that epoch 10 ranks above epoch 9.
    latest = max(epoch_by_name, key=epoch_by_name.get)
    print(f"The latest checkpoint is: {latest}")
    return os.path.join(checkpoint_dir, latest)

In [None]:
### UNET - ENCODER RESNET 18

import torch
import torch.nn as nn
import torchvision.transforms as T
import time
import csv
import os
from model.resnet_unet import SResUnet
from torchvision.models import resnet18
from dataset_custom.gta5_aug import GTA5
from train import train, evaluate
from torch.utils.data import DataLoader, random_split
from dataset_custom.labels import GTA5Labels_TaskCV2017
from utils_p import show_predictions_triplet


def main():
    """Train the ResNet-18 U-Net on GTA5 with per-epoch checkpoints and logging.

    Steps performed:
      1. Build the GTA5 dataset with the notebook-level ``transform`` and split
         it 70/30 into train/validation subsets using a fixed seed, so that a
         resumed run sees the same split as the run that wrote the checkpoint.
      2. Create the model, loss and optimizer, and restore both model and
         optimizer state from the newest checkpoint in ``checkpoint_dir``
         (if any).
      3. For every remaining epoch: call ``train``, save a checkpoint, call
         ``evaluate`` on the validation loader, print the metrics, show sample
         predictions and append one row to the CSV log.
      4. Save the final model ``state_dict``.

    Relies on names defined elsewhere in the notebook: ``transform`` and
    ``find_latest_checkpoint``.
    """
    # Parameters
    epochs = 30
    batch_size = 2
    lr = 1e-3
    num_classes = 19
    split_seed = 42        # fixes the train/val partition across (re)runs
    num_vis_images = 5     # samples shown after each epoch
    data_root = '/content/dataset/GTA5'
    checkpoint_dir = "/content/drive/MyDrive/checkpoints_unet_resnet18/"
    log_file = "unet_resnet18_log.csv"

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Dataset
    # NOTE(review): both subsets share `transform`, so its random
    # augmentations are also applied to validation samples.
    dataset = GTA5(root=data_root, transform=transform)
    n_total = len(dataset)
    train_len = int(n_total * 0.7)  # 70% training
    val_len = n_total - train_len
    # A seeded generator keeps the split identical when training is resumed;
    # an unseeded split would move validation images into the training set.
    split_rng = torch.Generator().manual_seed(split_seed)
    train_set, val_set = random_split(dataset, [train_len, val_len], generator=split_rng)

    print(f"Train dataset size: {len(train_set)}")
    print(f"Val dataset size: {len(val_set)}")

    # Dataloaders
    loader_args = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # Model
    model = SResUnet(resnet18, pretrained=True, out_channels=num_classes)
    model.to(device)

    # Loss & Optimizer (label 255 is excluded from the loss)
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Resume from the newest checkpoint, if there is one.
    start_epoch = 0
    latest_ckpt = find_latest_checkpoint(checkpoint_dir)
    print("latest_ckpt is:", latest_ckpt)

    if latest_ckpt:
        print(f"Restore from checkpoint: {latest_ckpt}")
        checkpoint = torch.load(latest_ckpt, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # Restore Adam's running moments as well, so the resumed run continues
        # the optimisation instead of restarting it from a cold optimizer.
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Checkpoints store the 0-based index of the epoch that just finished,
        # so training continues with the following one.
        start_epoch = checkpoint['epoch'] + 1
        print(f"Picking up from epoch {start_epoch + 1}")
    else:
        print("No checkpoint found, start training from scratch")

    gta5_labels = GTA5Labels_TaskCV2017

    # Start a fresh log (with header) for a new run; when resuming and a log
    # already exists, keep it and append so earlier epochs are not lost.
    resuming_log = start_epoch > 0 and os.path.exists(log_file)
    if not resuming_log:
        with open(log_file, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['epoch', 'train_loss', 'val_loss', 'accuracy', 'mIoU', 'dice', 'time_sec', 'vram_MB'])

    # Only the epochs that have not been trained yet are executed.
    for epoch in range(start_epoch, epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        start = time.time()
        train_loss = train(model, train_loader, optimizer, criterion, device)

        # Checkpoint: the file name is 1-based, the stored 'epoch' is 0-based.
        checkpoint_path = os.path.join(checkpoint_dir, f"unet_epoch_{epoch+1}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)
        print(f"Checkpoint stored: {checkpoint_path}")

        # Validation metrics must come from the held-out split, not from the
        # data the model was just trained on.
        val_loss, acc, miou, dice = evaluate(model, val_loader, criterion, device, num_classes)
        end = time.time()

        # VRAM tracking: peak allocation for this epoch, then reset the counter.
        # Both calls are guarded because they require a CUDA device.
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated() / 1e6
            torch.cuda.reset_peak_memory_stats()
        else:
            vram = 0

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Acc: {acc:.4f} | mIoU: {miou:.4f} | Dice: {dice:.4f} | Time: {end - start:.2f}s | VRAM: {vram:.2f} MB")
        print(f"Visualization of {num_vis_images} random samples:")
        show_predictions_triplet(model, train_loader, device, gta5_labels, num_images=num_vis_images, denorm=None)

        with open(log_file, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([epoch+1, train_loss, val_loss, acc, miou, dice, end - start, vram])

    # NOTE(review): the ".pht" extension looks like a typo for ".pth"; it is
    # kept unchanged in case other code loads this exact path.
    final_model_path = "/content/drive/MyDrive/MLDL_repo/step5_unet/unet_resnet18/unet_resnet18_final.pht"
    torch.save(model.state_dict(), final_model_path)
    print(f"Final Model Stored in: {final_model_path}")

# Start training when this cell is executed (in a notebook kernel, as in a
# script run directly, __name__ is "__main__").
if __name__ == "__main__":
    main()

