# Imports and Setup

In [None]:
import os
import gc
import joblib
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import segmentation_models_pytorch as smp
from PIL import Image
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from collections import defaultdict

import mlflow
from mlflow.tracking import MlflowClient
import optuna
import optuna.visualization as viz
from optuna.integration import PyTorchLightningPruningCallback
from pytorch_lightning.loggers import MLFlowLogger

# Fix the global random seed (dataloader workers included) so runs are repeatable
pl.seed_everything(42, workers=True)

# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Classes and Definitions

In [None]:
class DepthDistillationDataset(Dataset):
    """Dataset yielding (image, teacher depth, segmentation mask) triples.

    Every ``*.png`` file found in ``img_dir`` is one sample. Its depth map and
    segmentation mask are loaded from ``depth_dir`` and ``seg_dir`` as ``.npy``
    files that share the image's base name.

    Args:
        img_dir: Directory containing the RGB ``.png`` images.
        depth_dir: Directory containing the teacher depth maps (``.npy``).
        seg_dir: Directory containing the pre-processed segmentation masks
            (``.npy``, already stored as class IDs).
        num_classes: Number of segmentation classes. Only stored on the
            instance; it is not used by the loading code.
        is_train: When True, random augmentations are applied per sample.
        prob: Mapping with the keys ``'horizontal_flip'`` and
            ``'color_jitter'`` giving each augmentation's probability.
            Defaults to 0.5 for both when omitted.
    """

    def __init__(self, img_dir, depth_dir, seg_dir, num_classes, is_train=False, prob=None):
        self.img_dir = img_dir
        self.depth_dir = depth_dir
        self.seg_dir = seg_dir
        self.num_classes = num_classes
        self.is_train = is_train
        self.prob = prob if prob is not None else {'horizontal_flip': 0.5, 'color_jitter': 0.5}

        # Sorted so that sample indices are stable across runs and machines
        self.images = sorted(f for f in os.listdir(img_dir) if f.endswith('.png'))

    @staticmethod
    def _base_name(img_name):
        """Return ``img_name`` without its final extension.

        Only the trailing extension is removed, so a name that happens to
        contain ``.png`` in the middle keeps that part intact.
        """
        return os.path.splitext(img_name)[0]

    def __len__(self):
        """Return the number of images discovered in ``img_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            Tuple ``(image, teacher_depth, seg_mask)``: the ImageNet-normalised
            image tensor, the depth map as a float tensor and the mask as a
            long tensor.
        """
        img_name = self.images[idx]
        base_name = self._base_name(img_name)

        # Load image
        image = Image.open(os.path.join(self.img_dir, img_name)).convert('RGB')

        # Load depth
        teacher_depth = np.load(os.path.join(self.depth_dir, f"{base_name}.npy"))

        # Load pre-processed segmentation mask (already class IDs!)
        seg_mask = np.load(os.path.join(self.seg_dir, f"{base_name}.npy"))

        # Augmentations
        if self.is_train:
            if torch.rand(1).item() < self.prob['horizontal_flip']:
                # The flip is geometric, so it must hit image and both labels alike.
                # np.fliplr returns a negatively-strided view; copy so that
                # torch.from_numpy accepts the arrays below.
                image = TF.hflip(image)
                teacher_depth = np.fliplr(teacher_depth).copy()
                seg_mask = np.fliplr(seg_mask).copy()

            if torch.rand(1).item() < self.prob['color_jitter']:
                # Photometric only: the labels stay untouched
                image = T.ColorJitter(brightness=0.25, contrast=0.25)(image)

        # To tensors
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        teacher_depth = torch.from_numpy(teacher_depth).float()
        seg_mask = torch.from_numpy(seg_mask).long()

        return image, teacher_depth, seg_mask

In [None]:
class DataModule(pl.LightningDataModule):
    """Lightning data module serving the train, validation and test splits.

    Each split is described by three directories (images, depth maps and
    segmentation masks) that are handed to ``DepthDistillationDataset``.
    """

    def __init__(self, 
                 train_img_dir, train_depth_dir, train_seg_dir,
                 val_img_dir, val_depth_dir, val_seg_dir,
                 test_img_dir, test_depth_dir, test_seg_dir,
                 num_classes=32,
                 batch_size=4, 
                 num_workers=2,
                 prob=None):
        """
        Args:
            train_img_dir, train_depth_dir, train_seg_dir: Training data directories
            val_img_dir, val_depth_dir, val_seg_dir: Validation data directories
            test_img_dir, test_depth_dir, test_seg_dir: Test data directories
            num_classes: Number of segmentation classes forwarded to the datasets
            batch_size: Batch size for dataloaders
            num_workers: Number of workers for dataloaders
            prob: Dictionary with augmentation probabilities (training split only)
        """
        super().__init__()
        # Directories, one triple per split
        self.train_img_dir, self.train_depth_dir, self.train_seg_dir = (
            train_img_dir, train_depth_dir, train_seg_dir
        )
        self.val_img_dir, self.val_depth_dir, self.val_seg_dir = (
            val_img_dir, val_depth_dir, val_seg_dir
        )
        self.test_img_dir, self.test_depth_dir, self.test_seg_dir = (
            test_img_dir, test_depth_dir, test_seg_dir
        )

        # Loader / dataset settings
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.num_workers = num_workers
        if prob is None:
            prob = {'horizontal_flip': 0.5, 'color_jitter': 0.5}
        self.prob = prob

    def _make_split_dataset(self, img_dir, depth_dir, seg_dir, is_train, prob=None):
        """Build the dataset for one split from its three directories."""
        return DepthDistillationDataset(
            img_dir=img_dir,
            depth_dir=depth_dir,
            seg_dir=seg_dir,
            num_classes=self.num_classes,
            is_train=is_train,
            prob=prob,
        )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the module-wide settings."""
        # Persistent workers are only legal when there is at least one worker
        keep_workers_alive = self.num_workers > 0
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=keep_workers_alive,
        )

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when None)."""
        if stage in ('fit', None):
            # Only the training split is augmented
            self.train_dataset = self._make_split_dataset(
                self.train_img_dir, self.train_depth_dir, self.train_seg_dir,
                is_train=True, prob=self.prob,
            )
            self.val_dataset = self._make_split_dataset(
                self.val_img_dir, self.val_depth_dir, self.val_seg_dir,
                is_train=False,
            )

        if stage in ('test', None):
            self.test_dataset = self._make_split_dataset(
                self.test_img_dir, self.test_depth_dir, self.test_seg_dir,
                is_train=False,
            )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Sequential loader over the validation split."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Sequential loader over the test split."""
        return self._make_loader(self.test_dataset, shuffle=False)

In [None]:
class DepthSegMultiTaskLoss(nn.Module):
    """
    Multi-task loss with learnable uncertainty weighting.
    Based on "Multi-Task Learning Using Uncertainty to Weigh Losses" (Kendall et al., 2018)

    The depth task uses MSE and the segmentation task uses cross-entropy.
    Each task loss ``L`` contributes ``0.5 * exp(-s) * L + 0.5 * s`` where
    ``s`` is that task's learnable log-variance. Because of the ``0.5 * s``
    term the total can become negative once ``s`` drops below zero.

    Args:
        init_log_var_depth: Initial log-variance for the depth task.
        init_log_var_seg: Initial log-variance for the segmentation task.
    """
    def __init__(self, init_log_var_depth=0.0, init_log_var_seg=0.0):
        super().__init__()
        # float(...) guards against integer inputs: torch.tensor(0) is int64,
        # and an integer tensor cannot be a Parameter that requires grad.
        self.log_var_depth = nn.Parameter(torch.tensor(float(init_log_var_depth)))
        self.log_var_seg = nn.Parameter(torch.tensor(float(init_log_var_seg)))
        
        self.mse_loss = nn.MSELoss()
        self.ce_loss = nn.CrossEntropyLoss()
    
    def forward(self, depth_pred, depth_target, seg_pred, seg_target):
        """Compute the weighted multi-task loss.

        Args:
            depth_pred: Predicted depth; a 4-D input has its dim 1 squeezed.
            depth_target: Target depth; a 4-D input has its dim 1 squeezed.
            seg_pred: Segmentation logits passed to ``nn.CrossEntropyLoss``.
            seg_target: Segmentation class-ID targets.

        Returns:
            Dict with ``total_loss`` (to backpropagate), the raw ``depth_loss``
            and ``seg_loss``, and the detached precisions ``weight_depth`` and
            ``weight_seg`` for logging.
        """
        # Ensure correct shapes: drop a singleton channel dim on both sides so
        # MSELoss never broadcasts [B,H,W] against [B,1,H,W] into [B,B,H,W].
        if depth_pred.dim() == 4:
            depth_pred = depth_pred.squeeze(1)
        if depth_target.dim() == 4:
            depth_target = depth_target.squeeze(1)
        
        # Task losses
        depth_loss = self.mse_loss(depth_pred, depth_target)
        seg_loss = self.ce_loss(seg_pred, seg_target)
        
        # Uncertainty weighting
        # precision = 1/σ², using log_var = log(σ²) for stability
        precision_depth = torch.exp(-self.log_var_depth)
        precision_seg = torch.exp(-self.log_var_seg)
        
        weighted_depth = 0.5 * precision_depth * depth_loss + 0.5 * self.log_var_depth
        weighted_seg = 0.5 * precision_seg * seg_loss + 0.5 * self.log_var_seg
        
        total_loss = weighted_depth + weighted_seg
        
        # Return individual losses for logging
        return {
            'total_loss': total_loss,
            'depth_loss': depth_loss,
            'seg_loss': seg_loss,
            'weight_depth': precision_depth.detach(),
            'weight_seg': precision_seg.detach()
        }

In [None]:
class DepthSegDistillationModule(pl.LightningModule):
    """Lightning wrapper that trains a two-output student model.

    The student must return ``(depth, segmentation)`` for a batch of images.
    Both outputs are scored by ``DepthSegMultiTaskLoss``, whose learnable
    weights are optimised together with the student.
    """

    def __init__(self, student_model, lr=3e-4, weight_decay=0.0):
        super().__init__()
        self.student = student_model
        self.lr = lr
        self.weight_decay = weight_decay

        # Multi-task loss with uncertainty weighting
        self.loss_fn = DepthSegMultiTaskLoss()

        self.save_hyperparameters(ignore=['student_model'])

    def forward(self, x):
        """Run the student on ``x`` and return its outputs unchanged."""
        return self.student(x)

    def _run_stage(self, batch, stage, log_weights=False):
        """Forward one batch, log ``<stage>_*`` epoch metrics, return the total loss."""
        images, teacher_depth, seg_mask = batch
        depth_pred, seg_pred = self(images)
        losses = self.loss_fn(depth_pred, teacher_depth, seg_pred, seg_mask)

        metric_keys = ['total_loss', 'depth_loss', 'seg_loss']
        if log_weights:
            # The learned task weights are tracked for the training stage only
            metric_keys += ['weight_depth', 'weight_seg']

        for key in metric_keys:
            self.log(f'{stage}_{key}', losses[key], on_step=False, on_epoch=True)

        return losses['total_loss']

    def training_step(self, batch, batch_idx):
        """Training loss for one batch; also logs the task weights."""
        return self._run_stage(batch, 'train', log_weights=True)

    def validation_step(self, batch, batch_idx):
        """Validation loss for one batch."""
        return self._run_stage(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Test loss for one batch."""
        return self._run_stage(batch, 'test')

    def configure_optimizers(self):
        """AdamW over all parameters plus a plateau scheduler on ``val_total_loss``."""
        # self.parameters() also covers the loss function's learnable weights
        adamw = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            adamw, mode='min', factor=0.5, patience=20
        )
        scheduler_config = {'scheduler': plateau, 'monitor': 'val_total_loss'}
        return {'optimizer': adamw, 'lr_scheduler': scheduler_config}

In [None]:
class MultiTaskUnet(nn.Module):
    """U-Net encoder/decoder shared by a depth head and a segmentation head.

    Args:
        num_classes: Number of output channels of the segmentation head.
        encoder_name: Encoder identifier passed to ``smp.Unet``.
        encoder_weights: Encoder weights identifier passed to ``smp.Unet``.
    """

    def __init__(self, num_classes, encoder_name='resnet18', encoder_weights='imagenet'):
        super().__init__()
        # A stock single-class U-Net is built only to borrow its encoder and decoder
        unet = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=1,
            activation=None,
        )
        self.encoder = unet.encoder
        self.decoder = unet.decoder

        # The stock head's first layer tells us how many channels the decoder emits
        head_in_channels = unet.segmentation_head[0].in_channels

        # Two heads: 1x1 convolutions on top of the shared decoder features
        self.depth_head = nn.Conv2d(head_in_channels, 1, kernel_size=1)
        self.seg_head = nn.Conv2d(head_in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        """Return ``(depth, seg)`` computed from the shared decoder features."""
        decoded = self.decoder(self.encoder(x))
        return self.depth_head(decoded), self.seg_head(decoded)

# Dataset Verification

In [None]:
# Dataset locations: one image / depth / segmentation directory per split
TRAIN_IMG_DIR = '../CamVid/train/'
TRAIN_DEPTH_DIR = '../CamVid/train_labels/train_depths/'
TRAIN_SEG_DIR = '../CamVid/train_labels/train_seg_npy/'

VAL_IMG_DIR = '../CamVid/val/'
VAL_DEPTH_DIR = '../CamVid/val_labels/val_depths/'
VAL_SEG_DIR = '../CamVid/val_labels/val_seg_npy/'

TEST_IMG_DIR = '../CamVid/test/'
TEST_DEPTH_DIR = '../CamVid/test_labels/test_depths/'
TEST_SEG_DIR = '../CamVid/test_labels/test_seg_npy/'

CLASS_DICT_PATH = '../CamVid/class_dict.csv'

# Probability of applying each training-time augmentation
prob = {'horizontal_flip': 0.5, 'color_jitter': 0.3}

# Build the data module and materialise every split
data_module = DataModule(
    train_img_dir=TRAIN_IMG_DIR, train_depth_dir=TRAIN_DEPTH_DIR, train_seg_dir=TRAIN_SEG_DIR,
    val_img_dir=VAL_IMG_DIR, val_depth_dir=VAL_DEPTH_DIR, val_seg_dir=VAL_SEG_DIR,
    test_img_dir=TEST_IMG_DIR, test_depth_dir=TEST_DEPTH_DIR, test_seg_dir=TEST_SEG_DIR,
    batch_size=4,
    num_workers=2,
    prob=prob,
)
data_module.setup()

# Sanity check 1: how many samples each split contains
for split_label, split_dataset in (
    ("Training", data_module.train_dataset),
    ("Validation", data_module.val_dataset),
    ("Test", data_module.test_dataset),
):
    print(f"{split_label} samples: {len(split_dataset)}")

# Sanity check 2: tensor shapes of the first training sample
sample_img, sample_depth, sample_mask = data_module.train_dataset[0]
print(f"\nImage Size: {sample_img.shape}")
print(f"Depth Map Size: {sample_depth.shape}")
print(f"Segmentation Mask Size: {sample_mask.shape}")

All shapes are 720x960, so the dataset is loading the samples correctly.

# MLFlow and Optuna Setup

## MLFlow URI and Experiment Name

In [None]:
# Experiment under which every run of this study is recorded
EXPERIMENT_NAME = "depth_distillation_and_semantic_segmentation_optimization"

# Point MLflow at the local "mlruns" store, then select the experiment
mlflow.set_tracking_uri("mlruns")
mlflow.set_experiment(EXPERIMENT_NAME)

print(f"MLflow tracking URI: {mlflow.get_tracking_uri()}")
print(f"Experiment: {EXPERIMENT_NAME}")

## Optuna Objective Function

In [None]:
def objective(trial):
    """
    Optuna objective function for hyperparameter optimization.
    Trains depth distillation model and logs to MLflow.

    One nested MLflow run is opened per trial. The trial trains a
    ``MultiTaskUnet`` student, evaluates the best checkpoint on the test
    split and logs the model plus checkpoint as artifacts.

    Args:
        trial: The ``optuna`` trial used to sample hyperparameters and to
            report intermediate values for pruning.

    Returns:
        The best ``val_total_loss`` seen during training, or ``float('inf')``
        when the trial fails with an unexpected exception.

    Raises:
        optuna.TrialPruned: Re-raised when the pruning callback stops the trial.
    """
    
    # Suggest hyperparameters
    params = {
        
        # Training parameters
        'batch_size': trial.suggest_categorical('batch_size', [16, 24, 32]),
        'learning_rate': trial.suggest_float('learning_rate', 5e-5, 1e-3),
        'num_epochs': trial.suggest_int('num_epochs', 80, 200),
        
        # Augmentation probabilities
        'prob_horizontal_flip': trial.suggest_float('prob_horizontal_flip', 0.0, 0.7),
        'prob_color_jitter': trial.suggest_float('prob_color_jitter', 0.0, 0.7),
        
        # Optimizer parameters
        'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-3),
        
        # Fixed parameters
        'encoder_name': 'resnet18',
        'encoder_weights': 'imagenet',
        'num_workers': 2,
        'gradient_clip_val': 1.0,
    }
    
    # Start MLflow run (nested under parent run)
    with mlflow.start_run(nested=True, run_name=f"trial_{trial.number}"):
        
        # Log all hyperparameters
        mlflow.log_params(params)
        mlflow.log_param("trial_number", trial.number)
        
        # Create augmentation probabilities dict
        prob = {
            'horizontal_flip': params['prob_horizontal_flip'],
            'color_jitter': params['prob_color_jitter']
        }
        
        # Create data module
        data_module = DataModule(
            train_img_dir=TRAIN_IMG_DIR,
            train_depth_dir=TRAIN_DEPTH_DIR,
            train_seg_dir=TRAIN_SEG_DIR,
            val_img_dir=VAL_IMG_DIR,
            val_depth_dir=VAL_DEPTH_DIR,
            val_seg_dir=VAL_SEG_DIR,
            test_img_dir=TEST_IMG_DIR,
            test_depth_dir=TEST_DEPTH_DIR,
            test_seg_dir=TEST_SEG_DIR,
            batch_size=params['batch_size'],
            num_workers=params['num_workers'],
            prob=prob
        )
        data_module.setup()
        
        # Create student model
        student = MultiTaskUnet(
            num_classes=data_module.train_dataset.num_classes,
            encoder_name=params['encoder_name'],
            encoder_weights=params['encoder_weights'],
        )
        
        num_params = sum(p.numel() for p in student.parameters())
        mlflow.log_param("model_parameters", num_params)
        
        # Create Lightning module
        lightning_module = DepthSegDistillationModule(
            student_model=student,
            lr=params['learning_rate'],
            weight_decay=params['weight_decay']
        )
        
        # Setup MLflow logger for this trial. The tracking URI is passed
        # explicitly so the Lightning logger writes into the same store as the
        # active run instead of relying on its own default location.
        mlflow_logger = MLFlowLogger(
            experiment_name=EXPERIMENT_NAME,
            tracking_uri=mlflow.get_tracking_uri(),
            run_id=mlflow.active_run().info.run_id
        )
        
        # Setup callbacks
        checkpoint_callback = ModelCheckpoint(
            monitor='val_total_loss',
            mode='min',
            filename=f'trial_{trial.number}_best',
            save_top_k=1,
            verbose=False
        )
        
        early_stopping = EarlyStopping(
            monitor='val_total_loss',
            patience=15,
            mode='min',
            verbose=False
        )
        
        # Optuna pruning callback
        pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_total_loss")
        
        # Create trainer
        trainer = pl.Trainer(
            max_epochs=params['num_epochs'],
            logger=mlflow_logger,
            callbacks=[checkpoint_callback, early_stopping, pruning_callback],
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
            devices=1,
            gradient_clip_val=params['gradient_clip_val'],
            log_every_n_steps=10,
            enable_progress_bar=False,
            enable_model_summary=False
        )
        
        # Train the model
        try:
            trainer.fit(lightning_module, data_module)
            
            # Prefer the checkpoint callback's score; fall back to the last
            # logged validation loss if no checkpoint score was recorded.
            if checkpoint_callback.best_model_score is not None:
                best_val_loss = float(checkpoint_callback.best_model_score)
            else:
                best_val_loss = float(trainer.callback_metrics.get('val_total_loss', float('inf')))
            
            # Log best metrics
            mlflow.log_metric("best_val_total_loss", best_val_loss)
            mlflow.log_metric("epochs_trained", trainer.current_epoch)
            
            # Evaluate on test set. An empty best_model_path (no checkpoint
            # written) is turned into None so the current weights are tested.
            test_results = trainer.test(
                lightning_module, 
                data_module, 
                ckpt_path=checkpoint_callback.best_model_path or None
            )
            
            if test_results:
                test_loss = test_results[0].get('test_total_loss')
                # Compare against None: a loss of exactly 0.0 (or a negative
                # one) is a legitimate value and must still be logged.
                if test_loss is not None:
                    mlflow.log_metric("test_total_loss", test_loss)
            
            # Log model artifact
            mlflow.pytorch.log_model(student, "model")
            
            # Log checkpoint (only if path exists)
            if checkpoint_callback.best_model_path:
                mlflow.log_artifact(checkpoint_callback.best_model_path, "checkpoints")
            
            print(f"Trial {trial.number}: val_total_loss={best_val_loss:.4f}, "
                  f"epochs={trainer.current_epoch}")
            
            return best_val_loss
            
        except optuna.TrialPruned:
            print(f"Trial {trial.number} pruned")
            raise
        except Exception as e:
            # Deliberate best-effort: a broken trial must not abort the whole
            # study, so it is reported and scored as the worst possible value.
            print(f"Trial {trial.number} failed: {e}")
            import traceback
            traceback.print_exc()  # Print full traceback for debugging
            mlflow.log_param("status", "failed")
            # Truncated so an over-long message cannot be rejected by MLflow
            # and raise from inside this handler.
            mlflow.log_param("error", str(e)[:250])
            return float('inf')
        finally:
            # CRITICAL: Clean up resources after each trial
            print(f"Cleaning up Trial {trial.number}...")
            
            # All four names are bound before the try block, so they always exist here
            del trainer, lightning_module, student, data_module
            
            # Force garbage collection
            gc.collect()
            
            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            
            print(f"Trial {trial.number} cleanup complete")

# Train Models -> Launch Optuna Study

In [None]:
N_TRIALS = 50

# Close any run left open by an earlier cell before opening the parent run
mlflow.end_run()

# Re-seed so the study starts from a known random state
pl.seed_everything(42, workers=True)

# Parent MLflow run: every trial logs into a nested child of this run
with mlflow.start_run(run_name="optuna_depth_distillation_and_semantic_seg") as parent_run:

    # Record the study configuration
    study_config = {
        "optimization_metric": "val_total_loss",
        "n_trials": N_TRIALS,
        "model_type": "Depth_Distillation_and_Semantic_Segmentation_UNet",
        "dataset": "CamVid",
    }
    for config_key, config_value in study_config.items():
        mlflow.log_param(config_key, config_value)

    # Sampler and pruner are seeded / configured up front, then the study is
    # created on a SQLite store and reloaded if it already exists there
    tpe_sampler = optuna.samplers.TPESampler(seed=42)
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=15, n_warmup_steps=25)
    study = optuna.create_study(
        study_name="depth_distillation_and_semantic_segmentation_optimization",
        direction="minimize",
        sampler=tpe_sampler,
        pruner=median_pruner,
        storage="sqlite:///optuna_study_depth_and_seg_weighted.db",
        load_if_exists=True,
    )

    # Run optimization
    print("Starting Optuna optimization...")
    study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True)

    # Copy the winning trial's parameters and score onto the parent run
    best_trial = study.best_trial
    mlflow.log_params({f"best_{k}": v for k, v in best_trial.params.items()})
    mlflow.log_metric("best_val_total_loss", best_trial.value)

    banner = '=' * 60
    print(f"\n{banner}")
    print("Optimization Complete!")
    print(f"{banner}")
    print(f"Best trial number: {best_trial.number}")
    print(f"Best validation loss: {best_trial.value:.4f}")
    print("Best hyperparameters:")
    for key, value in best_trial.params.items():
        print(f"  {key}: {value}")

    # Generate and log optimization visualizations
    print("\nGenerating optimization visualizations...")

    try:
        os.makedirs("plots", exist_ok=True)

        # (plot function, output file) pairs, rendered in this order; the
        # first failure aborts the remaining plots and is reported below
        plot_specs = (
            (viz.plot_optimization_history, "plots/optimization_history.html"),
            (viz.plot_param_importances, "plots/param_importances.html"),
            (viz.plot_parallel_coordinate, "plots/parallel_coordinate.html"),
            (viz.plot_slice, "plots/slice_plot.html"),
        )
        for make_plot, html_path in plot_specs:
            figure = make_plot(study)
            figure.write_html(html_path)
            mlflow.log_artifact(html_path, "optimization_plots")

        print("Optimization plots logged to MLflow")

    except Exception as e:
        print(f"Warning: Could not generate visualizations: {e}")

# Evaluate Best Model

## Load Best Model

### Identify Best Run

In [None]:
# Initialize MLflow client
client = MlflowClient()

# Set experiment name
EXPERIMENT_NAME = "depth_distillation_and_semantic_segmentation_optimization"

# Get experiment
experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
if experiment is None:
    raise ValueError(f"Experiment '{EXPERIMENT_NAME}' not found!")

experiment_id = experiment.experiment_id

# Search for best run.
# No "metric > 0" filter: the uncertainty-weighted total loss contains
# 0.5 * log_var terms and can legitimately be negative, so such a filter
# would hide the best runs. Instead, rank everything by the metric and keep
# only trial runs that actually logged it. Requiring 'trial_number' also
# excludes the parent study run, which logs the same metric name but owns no
# checkpoint or model hyperparameters.
ranked_runs = client.search_runs(
    experiment_ids=[experiment_id],
    order_by=["metrics.best_val_total_loss ASC"],
)
all_runs = [
    run for run in ranked_runs
    if 'best_val_total_loss' in run.data.metrics
    and 'trial_number' in run.data.params
]

if len(all_runs) == 0:
    raise ValueError("No runs found with best_val_total_loss metric!")

best_run = all_runs[0]
best_run_id = best_run.info.run_id

print(f"{'='*60}")
print("Best Run Found!")
print(f"{'='*60}")
print(f"Run ID: {best_run_id}")
# The metric is guaranteed to exist by the filter above, so the float format is safe
print(f"Best Val Loss: {best_run.data.metrics['best_val_total_loss']:.4f}")

try:
    # Download artifacts directory
    artifacts_path = client.download_artifacts(best_run_id, "")
    print(f"\nDownloaded artifacts to: {artifacts_path}")
    
    # Look for ckpt model in checkpoints
    checkpoints_dir = os.path.join(artifacts_path, "checkpoints")
    
    if os.path.exists(checkpoints_dir):
        # Find ckpt file
        trial_number = best_run.data.params.get('trial_number', '0')
        # Sorted so the choice does not depend on directory listing order
        ckpt_files = sorted(f for f in os.listdir(checkpoints_dir)
                            if f.endswith('.ckpt'))
        
        if ckpt_files:
            best_checkpoint_path = os.path.join(checkpoints_dir, ckpt_files[0])
            print(f"Loading model from: {best_checkpoint_path}")
        else:
            raise FileNotFoundError(f"No ckpt files found in {checkpoints_dir}")
    else:
        raise FileNotFoundError(f"checkpoints directory not found at {artifacts_path}")
        
except Exception as e:
    # Report where the lookup failed, then let the error stop the notebook cell
    print(f"Error loading model from artifacts: {e}")
    raise

### Evaluate over Validation Set
We load and evaluate the model on the validation set to make sure everything is right and that we get the same "best_val_loss".

In [None]:
# Get best hyperparameters from MLflow (params are stored as strings, so the
# numeric ones are converted back)
best_params = {
    'encoder_name': best_run.data.params['encoder_name'],
    'encoder_weights': best_run.data.params['encoder_weights'],
    'learning_rate': float(best_run.data.params['learning_rate']),
    'weight_decay': float(best_run.data.params['weight_decay']),
    'batch_size': int(best_run.data.params['batch_size'])
}

print("Best Hyperparameters:")
for key, value in best_params.items():
    print(f"  {key}: {value}")

# Create test data module
test_data_module = DataModule(
    train_img_dir=TRAIN_IMG_DIR,
    train_depth_dir=TRAIN_DEPTH_DIR,
    train_seg_dir=TRAIN_SEG_DIR,
    val_img_dir=VAL_IMG_DIR,
    val_depth_dir=VAL_DEPTH_DIR,
    val_seg_dir=VAL_SEG_DIR,
    test_img_dir=VAL_IMG_DIR, # We use val split
    test_depth_dir=VAL_DEPTH_DIR, # We use val split
    test_seg_dir=VAL_SEG_DIR, # We use val split
    batch_size=best_params['batch_size'],
    num_workers=2,
    prob=None
)
test_data_module.setup()

# Create the model with best architecture.
# encoder_weights=None: the checkpoint loaded below replaces every parameter
# and buffer, so fetching pretrained encoder weights first would be wasted work.
student = MultiTaskUnet(
    num_classes=test_data_module.train_dataset.num_classes,
    encoder_name=best_params['encoder_name'],
    encoder_weights=None,
)

# Create Lightning module
lightning_module = DepthSegDistillationModule(
    student_model=student,
    lr=best_params['learning_rate'],
    weight_decay=best_params['weight_decay']
)

# Load checkpoint onto the CPU: the freshly built module lives there, and the
# trainer moves it to the accelerator when testing starts.
checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
lightning_module.load_state_dict(checkpoint['state_dict'])
lightning_module.eval()

print("\nBest model loaded successfully!")

# Create trainer for testing
test_trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    logger=False
)

# Test the model
print("\n" + "="*60)
print("Testing Best Model...")
print("="*60)

test_results = test_trainer.test(lightning_module, test_data_module)

print("\n" + "="*60)
print("FINAL TEST RESULTS")
print("="*60)
for key, value in test_results[0].items():
    print(f"  {key}: {value:.4f}")
print("="*60)

Perfect. We got the same loss, so the model was found and loaded correctly.
Now, let's evaluate on the test split.

### Evaluate over Test Set

In [None]:
# Get best hyperparameters from MLflow (params are stored as strings, so the
# numeric ones are converted back)
best_params = {
    'encoder_name': best_run.data.params['encoder_name'],
    'encoder_weights': best_run.data.params['encoder_weights'],
    'learning_rate': float(best_run.data.params['learning_rate']),
    'weight_decay': float(best_run.data.params['weight_decay']),
    'batch_size': int(best_run.data.params['batch_size'])
}

print("Best Hyperparameters:")
for key, value in best_params.items():
    print(f"  {key}: {value}")

# Create test data module
test_data_module = DataModule(
    train_img_dir=TRAIN_IMG_DIR,
    train_depth_dir=TRAIN_DEPTH_DIR,
    train_seg_dir=TRAIN_SEG_DIR,
    val_img_dir=VAL_IMG_DIR,
    val_depth_dir=VAL_DEPTH_DIR,
    val_seg_dir=VAL_SEG_DIR,
    test_img_dir=TEST_IMG_DIR,
    test_depth_dir=TEST_DEPTH_DIR,
    test_seg_dir=TEST_SEG_DIR,
    batch_size=best_params['batch_size'],
    num_workers=2,
    prob=None
)
test_data_module.setup()

# Create the model with best architecture.
# encoder_weights=None: the checkpoint loaded below replaces every parameter
# and buffer, so fetching pretrained encoder weights first would be wasted work.
student = MultiTaskUnet(
    num_classes=test_data_module.train_dataset.num_classes,
    encoder_name=best_params['encoder_name'],
    encoder_weights=None,
)

# Create Lightning module
lightning_module = DepthSegDistillationModule(
    student_model=student,
    lr=best_params['learning_rate'],
    weight_decay=best_params['weight_decay']
)

# Load checkpoint onto the CPU: the freshly built module lives there, and the
# trainer moves it to the accelerator when testing starts.
checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
lightning_module.load_state_dict(checkpoint['state_dict'])
lightning_module.eval()

print("\nBest model loaded successfully!")

# Create trainer for testing
test_trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    logger=False
)

# Test the model
print("\n" + "="*60)
print("Testing Best Model...")
print("="*60)

test_results = test_trainer.test(lightning_module, test_data_module)

print("\n" + "="*60)
print("FINAL TEST RESULTS")
print("="*60)
for key, value in test_results[0].items():
    print(f"  {key}: {value:.4f}")
print("="*60)

## Visualize Test Predictions

In [None]:
def visualize_predictions(model, dataset, num_samples=5, seed=42, title_prefix="Test"):
    """Plot student predictions next to the teacher depth and GT segmentation.

    A contiguous window of samples is taken from ``dataset``, starting at an
    offset drawn from the seeded global ``random`` generator. One figure per
    sample is displayed and also written to
    ``{title_prefix.lower()}_predictions_sample_{k}.png`` in the working
    directory.

    Args:
        model: Network that returns a ``(depth, segmentation_logits)`` pair for
            a batched image tensor. It is moved to the module-level ``device``
            and switched to eval mode.
        dataset: Indexable dataset whose items unpack to
            ``(image, teacher_depth, seg_mask_gt)`` tensors.
        num_samples: Maximum number of consecutive samples to plot. It is
            clamped to ``len(dataset)``.
        seed: Seed applied to the global ``random`` module before the window
            offset is drawn (this reseeds the global generator).
        title_prefix: Label used in the first panel title and, lower-cased, in
            the output file names.

    Returns:
        None.
    """

    model = model.to(device)
    model.eval()

    random.seed(seed)

    # Never request more samples than the dataset holds.
    count = min(num_samples, len(dataset))
    if count <= 0:
        return None

    # random.randrange(0, stop) raises ValueError when stop <= 0, which the
    # original hit whenever the dataset had `num_samples` items or fewer.
    # In that case the only valid window starts at index 0.
    last_start = len(dataset) - count
    offset = random.randrange(0, last_start) if last_start > 0 else 0

    # Constants used to undo the input normalisation for display.
    # NOTE(review): assumes the dataset normalises images with these
    # per-channel statistics - confirm against the dataset transform.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    with torch.no_grad():
        for i in range(count):
            # Create figure with gridspec
            fig = plt.figure(figsize=(15, 10))
            gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.3, wspace=0.3)

            image, teacher_depth, seg_mask_gt = dataset[offset + i]

            # Get student predictions (a batch dimension is added for the forward pass)
            student_depth_pred, student_seg_pred = model(image.unsqueeze(0).to(device))
            student_depth_pred = student_depth_pred.squeeze().cpu().numpy()
            # Class index per pixel = argmax over dim 1 of the segmentation output
            student_seg_pred = torch.argmax(student_seg_pred, dim=1).squeeze().cpu().numpy()

            # Denormalize image for visualization
            img_display = torch.clamp(image.clone() * std + mean, 0, 1)

            # [0, 0] and [1, 0]: Original Image (spans 2 rows)
            ax_img = fig.add_subplot(gs[:, 0])
            ax_img.imshow(img_display.permute(1, 2, 0))
            ax_img.set_title(f'{title_prefix} Sample {i+1}: Input Image', fontsize=12, fontweight='bold')
            ax_img.axis('off')

            # [0, 1]: Teacher Depth Map
            ax_teacher_depth = fig.add_subplot(gs[0, 1])
            im1 = ax_teacher_depth.imshow(teacher_depth.numpy(), cmap='Spectral')
            ax_teacher_depth.set_title('Teacher Depth (DA3)', fontsize=12, fontweight='bold')
            ax_teacher_depth.axis('off')
            plt.colorbar(im1, ax=ax_teacher_depth, fraction=0.046, pad=0.04)

            # [0, 2]: Ground Truth Semantic Segmentation
            ax_seg_gt = fig.add_subplot(gs[0, 2])
            im2 = ax_seg_gt.imshow(seg_mask_gt.numpy(), cmap='tab20')
            ax_seg_gt.set_title('GT Semantic Segmentation', fontsize=12, fontweight='bold')
            ax_seg_gt.axis('off')
            plt.colorbar(im2, ax=ax_seg_gt, fraction=0.046, pad=0.04)

            # [1, 1]: Student Predicted Depth Map
            # NOTE(review): the encoder name in this title is hard-coded;
            # confirm it matches the encoder of the model being plotted.
            ax_student_depth = fig.add_subplot(gs[1, 1])
            im3 = ax_student_depth.imshow(student_depth_pred, cmap='Spectral')
            ax_student_depth.set_title('Student Depth (U-Net ResNet18)', fontsize=12, fontweight='bold')
            ax_student_depth.axis('off')
            plt.colorbar(im3, ax=ax_student_depth, fraction=0.046, pad=0.04)

            # [1, 2]: Student Predicted Semantic Segmentation
            ax_student_seg = fig.add_subplot(gs[1, 2])
            im4 = ax_student_seg.imshow(student_seg_pred, cmap='tab20')
            ax_student_seg.set_title('Student Semantic Segmentation', fontsize=12, fontweight='bold')
            ax_student_seg.axis('off')
            plt.colorbar(im4, ax=ax_student_seg, fraction=0.046, pad=0.04)

            fig.savefig(f'{title_prefix.lower()}_predictions_sample_{i+1}.png', dpi=150, bbox_inches='tight')
            plt.show()
            # pyplot keeps every figure alive until it is closed explicitly;
            # release it so repeated calls do not accumulate open figures.
            plt.close(fig)

print("Visualizing predictions on TEST set:")
# Evaluation-mode dataset over the test split (augmentation disabled)
test_dataset = DepthDistillationDataset(
    TEST_IMG_DIR,
    TEST_DEPTH_DIR,
    TEST_SEG_DIR,
    num_classes=32,
    is_train=False,
)
# Plot five consecutive test samples using the bare student network
visualize_predictions(
    model=lightning_module.student,
    dataset=test_dataset,
    num_samples=5,
    title_prefix="Test",
)

## Calculate Metrics

In [None]:
def calculate_metrics(model, dataloader, split_name="Test"):
    """Compute depth-estimation errors of the student against the teacher depth.

    Each batch contributes its mean error weighted by the batch size, so the
    reported values are per-sample averages over the whole loader. The metrics
    are printed and returned.

    Args:
        model: Network returning a ``(depth, segmentation)`` pair for a batch
            of images; only the depth output is evaluated here. It is moved to
            the module-level ``device`` and switched to eval mode.
        dataloader: Iterable yielding ``(images, teacher_depth, seg_mask_gt)``
            batches. The segmentation mask is ignored.
        split_name: Label used in the printed header.

    Returns:
        dict with keys ``'MSE'``, ``'MAE'``, ``'RMSE'`` and ``'Abs Rel'``.
        Every value is ``nan`` when the dataloader yields no samples.
    """
    model.to(device)
    model.eval()

    total_mse = 0.0
    total_mae = 0.0
    total_abs_rel = 0.0
    num_samples = 0

    with torch.no_grad():
        # The segmentation mask is not needed for depth metrics, so it is
        # neither moved to the device nor used.
        for images, teacher_depth, _ in dataloader:
            images = images.to(device)
            teacher_depth = teacher_depth.to(device)

            # Get predictions; drop dim 1 so the prediction lines up with the target
            student_depth_pred, _ = model(images)
            student_depth_pred = student_depth_pred.squeeze(1)

            # Calculate metrics (batch means)
            mse = F.mse_loss(student_depth_pred, teacher_depth)
            mae = F.l1_loss(student_depth_pred, teacher_depth)
            # The small epsilon guards against division by an exactly-zero depth.
            # NOTE(review): a non-positive teacher depth would still distort
            # this ratio - confirm the teacher maps are strictly positive.
            abs_rel = torch.mean(torch.abs(student_depth_pred - teacher_depth) / (teacher_depth + 1e-8))

            batch_size = images.size(0)
            total_mse += mse.item() * batch_size
            total_mae += mae.item() * batch_size
            total_abs_rel += abs_rel.item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        # An empty loader used to end in ZeroDivisionError; report undefined
        # metrics instead of crashing.
        nan = float('nan')
        metrics = {'MSE': nan, 'MAE': nan, 'RMSE': nan, 'Abs Rel': nan}
    else:
        mean_mse = total_mse / num_samples
        metrics = {
            'MSE': mean_mse,
            'MAE': total_mae / num_samples,
            'RMSE': np.sqrt(mean_mse),
            'Abs Rel': total_abs_rel / num_samples
        }

    print(f"\n{split_name} Set Metrics:")
    print("=" * 40)
    for metric_name, value in metrics.items():
        print(f"  {metric_name:12s}: {value:.4f}")
    print("=" * 40)

    return metrics

# Build a fresh data module so this cell does not depend on earlier cell state
test_data_module = DataModule(
    batch_size=best_params['batch_size'],
    num_workers=2,
    prob=None,
    test_img_dir=TEST_IMG_DIR,
    test_depth_dir=TEST_DEPTH_DIR,
    test_seg_dir=TEST_SEG_DIR,
    val_img_dir=VAL_IMG_DIR,
    val_depth_dir=VAL_DEPTH_DIR,
    val_seg_dir=VAL_SEG_DIR,
    train_img_dir=TRAIN_IMG_DIR,
    train_depth_dir=TRAIN_DEPTH_DIR,
    train_seg_dir=TRAIN_SEG_DIR,
)
# Only the test stage is requested here
test_data_module.setup(stage='test')

# Depth metrics of the bare student network over the TEST loader
test_loader = test_data_module.test_dataloader()
test_metrics = calculate_metrics(
    lightning_module.student,
    test_loader,
    split_name="Test",
)

## Test Set Detailed Analysis

In [None]:
def detailed_test_analysis(model, dataset, num_samples=10):
    """Report per-sample depth errors and locate the best and worst samples.

    The first ``min(num_samples, len(dataset))`` items of ``dataset`` are run
    through the model one at a time. Per-sample MSE/MAE against the teacher
    depth are printed, followed by their mean and standard deviation.

    Args:
        model: Network returning a ``(depth, segmentation)`` pair for a batched
            image; only the depth output is used. It is moved to the
            module-level ``device`` and switched to eval mode.
        dataset: Indexable dataset whose items unpack to
            ``(image, teacher_depth, seg_mask_gt)``.
        num_samples: Upper bound on the number of leading samples analysed.

    Returns:
        ``(worst_sample_idx, best_sample_idx)``: dataset indices of the samples
        with the highest and lowest MSE (first occurrence on ties). Each is
        ``-1`` when no sample qualified, e.g. for an empty dataset.
    """
    model.to(device)
    model.eval()

    errors = []

    worst_sample_idx = -1
    best_sample_idx = -1
    best_mse = np.inf
    # Start below every possible MSE: starting at 0 meant a sample with an
    # MSE of exactly 0 could never be chosen and -1 leaked out as the index.
    worst_mse = -np.inf

    with torch.no_grad():
        for i in range(min(num_samples, len(dataset))):
            image, teacher_depth, _ = dataset[i]

            # Get prediction (a batch dimension is added for the forward pass)
            student_depth_pred, _ = model(image.unsqueeze(0).to(device))
            student_depth_pred = student_depth_pred.squeeze().cpu()

            # Calculate per-sample error
            mse = F.mse_loss(student_depth_pred, teacher_depth).item()
            mae = F.l1_loss(student_depth_pred, teacher_depth).item()

            errors.append({
                'sample_idx': i,
                'MSE': mse,
                'MAE': mae
            })

            # Track worst and best samples by MSE; the strict comparisons keep
            # the first occurrence on ties.
            if mse > worst_mse:
                worst_mse = mse
                worst_sample_idx = i

            if mse < best_mse:
                best_mse = mse
                best_sample_idx = i

    # Print per-sample results
    print("\nPer-Sample Test Results:")
    print("-" * 50)
    for error in errors:
        print(f"Sample {error['sample_idx']:3d}: MSE={error['MSE']:.4f}, MAE={error['MAE']:.4f}")

    # Statistical summary
    print("\nStatistical Summary:")
    print("-" * 50)
    if errors:
        mse_values = [e['MSE'] for e in errors]
        mae_values = [e['MAE'] for e in errors]
        print(f"MSE - Mean: {np.mean(mse_values):.4f}, Std: {np.std(mse_values):.4f}")
        print(f"MAE - Mean: {np.mean(mae_values):.4f}, Std: {np.std(mae_values):.4f}")
    else:
        # np.mean([]) would only emit a RuntimeWarning and print nan
        print("No samples were analysed.")

    return worst_sample_idx, best_sample_idx

def visualize_edge_predictions(model, dataset, worst_sample_idx, best_sample_idx):
    """Show the worst and best predicted samples in a 2x3 comparison figure.

    The top row holds the worst sample and the bottom row the best one; each
    row shows the input image, the teacher depth and the student depth. The
    figure is saved as ``worst_best_test_depth_predictions.png`` and displayed.
    """

    model = model.to(device)
    model.eval()

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Per-channel statistics used to undo the input normalisation for display
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # One (dataset index, first-panel title) pair per figure row, top to bottom
    rows = (
        (worst_sample_idx, f'Worst predicted Test sample'),
        (best_sample_idx, f'Best predicted Test sample'),
    )

    with torch.no_grad():
        for row_axes, (sample_idx, heading) in zip(axes, rows):
            ax_input, ax_teacher, ax_student = row_axes

            image, teacher_depth, seg_mask_gt = dataset[sample_idx]

            # Student forward pass on a single-image batch
            student_depth_pred, student_seg_pred = model(image.unsqueeze(0).to(device))
            student_pred = student_depth_pred.squeeze().cpu().numpy()

            # Bring the image back to the [0, 1] range for display
            img_display = torch.clamp(image.clone() * std + mean, 0, 1)

            # Column 0: input image
            ax_input.imshow(img_display.permute(1, 2, 0))
            ax_input.set_title(heading)
            ax_input.axis('off')

            # Column 1: teacher depth
            ax_teacher.imshow(teacher_depth.numpy(), cmap='Spectral')
            ax_teacher.set_title('Teacher Depth (DA3)')
            ax_teacher.axis('off')

            # Column 2: student depth
            ax_student.imshow(student_pred, cmap='Spectral')
            ax_student.set_title('Student Depth (U-Net ResNet18)')
            ax_student.axis('off')

    plt.tight_layout()
    plt.savefig(f'worst_best_test_depth_predictions.png', dpi=150, bbox_inches='tight')
    plt.show()

# Run detailed analysis
# `num_classes` is a required argument of DepthDistillationDataset; leaving it
# out raises TypeError. The value 32 and is_train=False match the test dataset
# built for the prediction visualisation earlier in the notebook.
test_dataset = DepthDistillationDataset(
    img_dir=TEST_IMG_DIR,
    depth_dir=TEST_DEPTH_DIR,
    seg_dir=TEST_SEG_DIR,
    num_classes=32,
    is_train=False,
)

# Analyse every test sample, then plot the two extremes by depth MSE
worst_sample_idx, best_sample_idx = detailed_test_analysis(lightning_module.student, test_dataset, num_samples=len(test_dataset))
print(f"Worst predicted sample found at image #{worst_sample_idx}")
print(f"Best predicted sample found at image #{best_sample_idx}")
print()

visualize_edge_predictions(lightning_module.student, test_dataset, worst_sample_idx, best_sample_idx)