In [1]:
import io
import os
import optuna
import numpy as np
import pandas as pd
from tqdm import tqdm as base_tqdm
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchmetrics.segmentation import DiceScore, MeanIoU
from torchmetrics.classification import BinaryF1Score, BinaryJaccardIndex

from bdd_100k_dataset_local import BDD100KDatasetLocal
from model import SmallUNet

# Module-wide compute device: CUDA when PyTorch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def load_data_local(batch_size: int = 16) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Build the BDD100K train / val / test loaders from the local ``./copied/Dataset`` tree.

    Each dataset is created with the same transform: resize to 256x256, then
    ``ToTensor``. How ``BDD100KDatasetLocal`` applies that transform to the
    masks is defined in its own module and is not visible here.

    Args:
        batch_size: Batch size used for all three loaders (default 16).

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles its samples.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    train_dataset_bdd = BDD100KDatasetLocal(
        images_dir='./copied/Dataset/100k_images_train/bdd100k/images/100k/train',
        masks_dir='./copied/Dataset/bdd100k_lane_labels_trainval/bdd100k/labels/lane/masks/train',
        transform=transform,
    )
    val_dataset_bdd = BDD100KDatasetLocal(
        images_dir='./copied/Dataset/100k_images_val/bdd100k/images/100k/val',
        masks_dir='./copied/Dataset/bdd100k_lane_labels_trainval/bdd100k/labels/lane/masks/val',
        transform=transform,
    )
    # NOTE(review): the masks path below sits under the "trainval" label folder;
    # confirm that a masks/test directory actually exists there.
    test_dataset = BDD100KDatasetLocal(
        images_dir='./copied/Dataset/100k_images_test/bdd100k/images/100k/test',
        masks_dir='./copied/Dataset/bdd100k_lane_labels_trainval/bdd100k/labels/lane/masks/test',
        transform=transform,
    )

    # Shuffle only the training split so validation / test order stays fixed.
    train_loader = DataLoader(train_dataset_bdd, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset_bdd, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

class BCEDiceLoss(nn.Module):
    """Weighted sum of BCE-with-logits and a soft Dice loss for binary masks.

    ``loss = bce_weight * BCE + (1 - bce_weight) * Dice``

    Args:
        bce_weight: Weight of the BCE term; the Dice term gets ``1 - bce_weight``.
        pos_weight: Optional tensor forwarded to ``nn.BCEWithLogitsLoss`` to
            up-weight the positive class. ``None`` means no re-weighting.
    """

    def __init__(self, bce_weight=0.5, pos_weight=None):
        super().__init__()
        # None is BCEWithLogitsLoss's own default, so pos_weight can be forwarded as-is.
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        self.bce_weight = bce_weight

    def forward(self, preds, targets):
        """Return the combined loss as a scalar tensor.

        Args:
            preds: Raw logits (sigmoid is applied internally for the Dice term).
            targets: Ground-truth mask of the same shape as ``preds``; it is
                passed to ``BCEWithLogitsLoss`` unchanged.
        """
        bce_loss = self.bce(preds, targets)

        # Dice works on probabilities, not logits.
        probs = torch.sigmoid(preds)

        smooth = 1e-6  # keeps the ratio defined when both masks are empty
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # permuted tensors, are flattened instead of raising a RuntimeError.
        probs_flat = probs.reshape(-1)
        targets_flat = targets.reshape(-1)

        # Dice is computed over the whole flattened batch, not per sample.
        intersection = (probs_flat * targets_flat).sum()
        dice_loss = 1 - (2 * intersection + smooth) / (probs_flat.sum() + targets_flat.sum() + smooth)

        return self.bce_weight * bce_loss + (1 - self.bce_weight) * dice_loss

def plot_and_save_curve(values, title, ylabel, filename):
    """Plot one per-epoch series and write the figure to ``filename``.

    Args:
        values: Sequence of numbers, one per epoch (x axis is the index).
        title: Figure title.
        ylabel: Label of the y axis.
        filename: Output path handed to ``savefig``.
    """
    fig, ax = plt.subplots()
    ax.plot(values)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True)
    fig.savefig(filename)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def plot_with_mean_curve(values1, values2, label1, label2, title, ylabel, filename):
    """Plot two per-epoch series plus their element-wise mean and save the figure.

    Args:
        values1, values2: The two series; the mean is computed pairwise and
            stops at the shorter of the two.
        label1, label2: Legend labels for the two series.
        title: Figure title.
        ylabel: Label of the y axis.
        filename: Output path handed to ``savefig``.
    """
    averaged = []
    for first, second in zip(values1, values2):
        averaged.append((first + second) / 2)

    fig, ax = plt.subplots()
    ax.plot(values1, label=label1)
    ax.plot(values2, label=label2)
    ax.plot(averaged, label='Mean', linestyle='--')
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend()
    fig.savefig(filename)
    plt.close(fig)

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25, model_path='unet_lane_detection.pth', start_epoch=0, patience=3):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Each epoch runs one pass over ``train_loader`` and one over ``val_loader``.
    Validation tracks the loss plus Dice (binary F1) and IoU (binary Jaccard)
    for the lane class and, by inverting predictions and masks, for the
    background class. After the loop, loss / Dice / IoU curves are written as
    PNG files to ``./metrics``.

    Args:
        model: Network producing one logit per pixel; expected to already be on
            the module-level ``device``.
        train_loader: Loader yielding ``(images, masks)`` batches for training.
        val_loader: Loader yielding ``(images, masks)`` batches for validation.
        criterion: Loss called as ``criterion(logits, masks)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Stepped once per epoch with the validation loss
            (``scheduler.step(val_loss)``).
        num_epochs: Exclusive upper bound of the epoch range.
        model_path: Where the best checkpoint (lowest validation loss) is saved;
            ``None`` disables checkpointing.
        start_epoch: First epoch index of the range.
        patience: Number of consecutive epochs without a validation-loss
            improvement after which training stops early.

    Returns:
        Lane Dice score of the last validated epoch, or ``0`` if no epoch ran.
    """
    best_val_loss = float('inf')

    train_loss_history = []
    val_loss_history = []
    dice_history_lane = []
    iou_history_lane = []
    dice_history_bg = []
    iou_history_bg = []

    epochs_without_improvement = 0

    # Defined before the loop so the return value exists even when the epoch
    # range is empty (start_epoch >= num_epochs).
    val_dice_lane = 0

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0
        train_samples_seen = 0

        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")
        for images, masks in train_loader_tqdm:
            # Use the shared module-level device (same one the metrics use)
            # instead of a hard-coded .cuda(), so CPU-only hosts work too.
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch does not skew the mean.
            train_loss += loss.item() * images.size(0)
            train_samples_seen += images.size(0)
            # Running mean over the samples seen so far in this epoch.
            train_loader_tqdm.set_postfix({'Loss': train_loss / train_samples_seen})

        train_loss = train_loss / len(train_loader.dataset)

        model.eval()
        val_loss = 0
        val_samples_seen = 0
        # Fresh metric objects every epoch so state does not leak across epochs.
        # Lane class (positive class = 1); binary F1 is the Dice coefficient.
        dice_lane = BinaryF1Score().to(device)
        iou_lane = BinaryJaccardIndex().to(device)

        # Background class (positive class = 0) - fed inverted predictions and targets
        dice_bg = BinaryF1Score().to(device)
        iou_bg = BinaryJaccardIndex().to(device)

        val_loader_tqdm = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} - Validation', mininterval=3.0)
        with torch.no_grad():
            for images, masks in val_loader_tqdm:
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)

                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)
                val_samples_seen += images.size(0)

                # Logits -> probabilities -> hard 0/1 prediction at threshold 0.5.
                preds_bin = (torch.sigmoid(outputs) > 0.5).int()

                # NOTE(review): assumes masks hold only the values 0 and 1 - confirm
                # against BDD100KDatasetLocal.
                # LANE: positive = 1
                dice_lane.update(preds_bin, masks)
                iou_lane.update(preds_bin, masks)

                # BACKGROUND: invert
                dice_bg.update(1 - preds_bin, 1 - masks)
                iou_bg.update(1 - preds_bin, 1 - masks)

                val_loader_tqdm.set_postfix({
                    'Loss': val_loss / val_samples_seen,
                    'Dice Lane': dice_lane.compute().item(),
                    'IoU Lane': iou_lane.compute().item()
                })

        val_loss = val_loss / len(val_loader.dataset)

        val_dice_lane = dice_lane.compute().item()
        val_iou_lane = iou_lane.compute().item()
        val_dice_bg = dice_bg.compute().item()
        val_iou_bg = iou_bg.compute().item()

        train_loss_history.append(train_loss)
        val_loss_history.append(val_loss)
        dice_history_lane.append(val_dice_lane)
        iou_history_lane.append(val_iou_lane)
        dice_history_bg.append(val_dice_bg)
        iou_history_bg.append(val_iou_bg)

        scheduler.step(val_loss)

        tqdm.write(f'Epoch {epoch+1}/{num_epochs}, '
            f'Train Loss: {train_loss:.4f}, '
            f'Val Loss: {val_loss:.4f}, '
            f'Dice Lane: {val_dice_lane:.4f}, '
            f'IoU Lane: {val_iou_lane:.4f}, '
            f'Dice Bg: {val_dice_bg:.4f}, '
            f'IoU Bg: {val_iou_bg:.4f}'
        )

        # Checkpoint on every new best validation loss; otherwise count towards early stopping.
        if val_loss < best_val_loss:
            best_val_loss = val_loss

            epochs_without_improvement = 0

            if model_path is not None:
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, model_path)
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

    os.makedirs("./metrics", exist_ok=True)

    plot_and_save_curve(train_loss_history, "Training Loss", "Loss", "./metrics/train_loss.png")
    plot_and_save_curve(val_loss_history, "Validation Loss", "Loss", "./metrics/val_loss.png")
    plot_and_save_curve(dice_history_lane, "Lane Line Dice Score", "Dice", "./metrics/dice_score_lane.png")
    plot_and_save_curve(iou_history_lane, "Lane Line IoU Score", "IoU", "./metrics/iou_score_lane.png")
    plot_and_save_curve(dice_history_bg, "Background Dice Score", "Dice", "./metrics/dice_score_bg.png")
    plot_and_save_curve(iou_history_bg, "Background IoU Score", "IoU", "./metrics/iou_score_bg.png")

    plot_with_mean_curve(dice_history_lane, dice_history_bg, "Lane", "Background", "Dice Score", "Dice", "./metrics/dice_score_all.png")
    plot_with_mean_curve(iou_history_lane, iou_history_bg, "Lane", "Background", "IoU Score", "IoU", "./metrics/iou_score_all.png")

    # Same value the last epoch computed; unlike reading the metric object here,
    # this is also defined when the loop body never executed.
    return val_dice_lane


def start_training():
    """Train a fresh SmallUNet for up to 20 epochs with fixed hyperparameters.

    Builds the model, the BCE+Dice criterion, an Adam optimizer and a
    ReduceLROnPlateau scheduler, loads the local BDD100K loaders and hands
    everything to ``train_model``. The best checkpoint is written to
    ``./best_smallunet_lane_detection_20.pth``.

    NOTE(review): the literal hyperparameter values presumably come from an
    earlier Optuna study (see ``objective``) - confirm their provenance.
    """
    save_model_path = './best_smallunet_lane_detection_20.pth'

    # .to(device) instead of .cuda(): keeps the model on the same device as
    # pos_weight below and does not fail outright on a CPU-only host.
    model = SmallUNet(in_channels=3, out_channels=1, base_dropout=0.07574735102229871).to(device)

    # Positive-class (lane) weight for the BCE term of the combined loss.
    lane_weight = 6.767409237000989
    pos_weight = torch.tensor([lane_weight]).to(device)
    criterion = BCEDiceLoss(bce_weight=0.29775550050971283, pos_weight=pos_weight)

    optimizer = optim.Adam(model.parameters(), lr=0.0014768215380281198)
    # Halve the learning rate after 2 epochs without validation-loss improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    train_loader, val_loader, _ = load_data_local()
    train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=20, model_path=save_model_path)


def objective(trial):
    """Optuna objective: fine-tune the pretrained SmallUNet briefly and score it.

    Samples loss weights, dropout and learning rate from ``trial``, loads the
    weights stored in ``./best_smallunet_lane_detection.pth``, trains for up to
    5 epochs without saving checkpoints, and returns whatever ``train_model``
    returns (the lane Dice score), which the study is meant to maximize.
    """
    checkpoint_path = './best_smallunet_lane_detection.pth'

    # Search space; entries are sampled in the order they are listed.
    params = {
        "bce_weight": trial.suggest_float("bce_weight", 0.1, 1.0),
        "lane_weight": trial.suggest_float("lane_weight", 1.0, 20.0),
        "base_dropout": trial.suggest_float("base_dropout", 0.05, 0.3),
        "lr": trial.suggest_float("lr", 1e-5, 1e-2, log=True),
    }

    # Build the network with the sampled dropout, then warm-start it from the checkpoint.
    net = SmallUNet(in_channels=3, out_channels=1, base_dropout=params["base_dropout"]).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    net.load_state_dict(checkpoint['model_state_dict'])

    # Loss, optimizer and LR schedule driven by the sampled values.
    lane_pos_weight = torch.tensor([params["lane_weight"]]).to(device)
    loss_fn = BCEDiceLoss(bce_weight=params["bce_weight"], pos_weight=lane_pos_weight)
    adam = torch.optim.Adam(net.parameters(), lr=params["lr"])
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adam, mode='min', factor=0.5, patience=2)

    # The test loader is not needed during the search.
    train_loader, val_loader, _ = load_data_local()

    # Short run per trial; model_path=None turns checkpoint saving off.
    return train_model(
        net,
        train_loader,
        val_loader,
        loss_fn,
        adam,
        plateau_scheduler,
        num_epochs=5,
        model_path=None,
    )

# Entry point: run one full training pass with the fixed hyperparameters in start_training().
if __name__ == "__main__":
    start_training()
    print("Hello from cluster")
    
    # Hyperparameter search (disabled). Uncomment to run a 20-trial Optuna study
    # that maximizes the value returned by objective().
    # study = optuna.create_study(direction="maximize")
    # study.optimize(objective, n_trials=20)

    # print("Best trial:")
    # print(study.best_trial)
    # print("✅ Best params:", study.best_params)
