# Cross-Validation Pipeline for TCD-SegFormer

This notebook implements k-fold cross-validation (5 folds by default) for robust model evaluation, allowing for:
- Testing both TrueResSegformer and standard Segformer models
- Evaluating with or without class weights
- Comparing metrics across different folds
- Visualizing aggregate performance

The cross-validation process will train models on different data splits and evaluate their performance, providing robust statistics on model capabilities.

In [None]:
# Import necessary libraries
import os
import sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import json

# Set matplotlib style
plt.style.use('ggplot')

# Import our project modules
from config import Config
from cross_validation import run_cross_validation
from utils import get_logger, set_seed

In [None]:
# Notebook-wide logger obtained from the project's utils module.
logger = get_logger()
# Marks this session as notebook-driven.
is_notebook = True

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# Seed the random number generators via the project helper, for reproducibility.
SEED = 42
set_seed(SEED)
logger.info(f"Random seed set to {SEED}")

## Configuration Setup

Set up the configuration for cross-validation. You can experiment with different settings by modifying the parameters below.

In [None]:
# Start from the project's default configuration and override selected fields below.
config = Config()

# Timestamped output folder, so repeated runs never overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Top-level overrides, applied to `config` in the order listed (dicts keep insertion order).
base_overrides = {
    # --- Base configuration ---
    "output_dir": f"./outputs/cv_{timestamp}",
    "dataset_name": "restor/tcd",  # adjust if using a different dataset
    # --- Model ---
    "model_name": "nvidia/mit-b0",  # other backbones: mit-b1, mit-b2, ...
    "use_true_res_segformer": True,  # True -> TrueResSegformer, False -> standard Segformer
    # --- Class weights ---
    "class_weights_enabled": True,  # True to use class weights, False otherwise
    # --- Training ---
    "num_epochs": 10,  # kept low for faster cross-validation; increase for production
    "train_batch_size": 4,  # adjust to available GPU memory
    "eval_batch_size": 8,  # adjust to available GPU memory
    "learning_rate": 1e-5,
    "apply_loss_at_original_resolution": True,  # upsample logits so the loss is taken at original resolution
}
for option_name, option_value in base_overrides.items():
    config[option_name] = option_value

# Cross-validation settings live in a nested mapping on the config.
cv_settings = config["cross_validation"]
cv_settings["enabled"] = True
cv_settings["num_folds"] = 5  # standard 5-fold CV
cv_settings["metrics_to_track"] = ["f1_score_class_1", "IoU_class_1", "accuracy"]

# Performance options
config["mixed_precision"] = True  # mixed precision for faster training
config["num_workers"] = 4  # adjust to available CPU cores

# Create the output directory and persist the configuration for reproducibility.
os.makedirs(config["output_dir"], exist_ok=True)
config_path = os.path.join(config["output_dir"], "cross_validation_config.json")
config.save(config_path)
logger.info(f"Configuration saved to {config_path}")

# Echo the key settings, one per line.
architecture = 'TrueResSegformer' if config['use_true_res_segformer'] else 'Standard Segformer'
summary_lines = [
    "\n===== Cross-Validation Configuration =====\n",
    f"Model: {architecture}",
    f"Base model: {config['model_name']}",
    f"Class weights enabled: {config['class_weights_enabled']}",
    f"Loss at original resolution: {config['apply_loss_at_original_resolution']}",
    f"Number of folds: {config['cross_validation']['num_folds']}",
    f"Output directory: {config['output_dir']}\n",
]
for line in summary_lines:
    print(line)

## Run Cross-Validation

Execute the cross-validation process using the configuration defined above. This will:
1. Split the dataset into `num_folds` folds (5 by default, set via `config["cross_validation"]["num_folds"]`)
2. Train a model on each fold's training split
3. Evaluate it on the corresponding validation split
4. Collect and aggregate metrics across all folds

In [None]:
# Run cross-validation with the configuration built above.
# The returned `cv_results` is kept for interactive inspection; the analysis cells
# below re-load the results from the JSON file written to the output directory.
logger.info("Starting cross-validation process...")

cv_results = run_cross_validation(
    config=config,
    num_folds=config["cross_validation"]["num_folds"],
    logger_obj=logger,
    # Pass the session flag defined in the setup cell instead of a duplicated literal,
    # so the two cannot drift apart.
    is_notebook=is_notebook,
)

logger.info("Cross-validation completed!")

## Analyze Results

Analyze and visualize the results of cross-validation.

In [None]:
# Load results from the saved JSON file.
# NOTE(review): assumes run_cross_validation writes to
# <output_dir>/cross_validation/cv_results.json — confirm against cross_validation.py.
results_path = os.path.join(config["output_dir"], "cross_validation", "cv_results.json")
# JSON is UTF-8 by specification; do not depend on the platform's default encoding.
with open(results_path, "r", encoding="utf-8") as f:
    saved_results = json.load(f)

# Display aggregate metrics. Numeric values are shown with 4 decimals (the same
# (int, float) test the per-fold table uses). Anything else, e.g. a nested structure
# or null, is printed as-is instead of raising TypeError on the ".4f" format spec.
print("\n===== Aggregate Metrics =====\n")
for key, value in saved_results["aggregate_metrics"].items():
    if isinstance(value, (int, float)):
        print(f"{key}: {value:.4f}")
    else:
        print(f"{key}: {value}")

In [None]:
# Build a per-fold metrics table: one row per fold (numbered from 1), keeping only
# the scalar numeric entries of each fold's metrics mapping.
metrics_per_fold = saved_results["metrics_per_fold"]
fold_metrics_list = [
    {
        "fold": fold_number,
        **{
            name: score
            for name, score in fold_metrics.items()
            if isinstance(score, (int, float))
        },
    }
    for fold_number, fold_metrics in enumerate(metrics_per_fold, start=1)
]

metrics_df = pd.DataFrame(fold_metrics_list).set_index("fold")

# Display the DataFrame
metrics_df

In [None]:
# Visualize key metrics across folds as a grouped bar chart (one bar group per fold).
# A column is selected when its name contains any of these substrings. The trend and
# boxplot cells below reuse `available_metrics`.
key_metrics = ["f1_score_class_1", "IoU_class_1", "accuracy", "precision_class_1", "recall_class_1"]
available_metrics = [m for m in metrics_df.columns if any(km in m for km in key_metrics)]

if available_metrics:
    # Create the figure and axes once and hand the axes to DataFrame.plot. Without
    # `ax`, DataFrame.plot opens its own figure, so a separate plt.figure() call only
    # leaves an empty extra figure behind.
    fig, ax = plt.subplots(figsize=(14, 8))
    metrics_df[available_metrics].plot(kind="bar", ax=ax, rot=0)
    ax.set_title("Metrics Across Folds", fontsize=16)
    ax.set_ylabel("Score", fontsize=14)
    ax.set_xlabel("Fold", fontsize=14)
    ax.grid(axis="y", alpha=0.3)
    ax.set_ylim(0, 1.0)

    # Add data labels on top of bars
    for container in ax.containers:
        ax.bar_label(container, fmt='%.3f', fontsize=9)

    ax.legend(fontsize=12)
    fig.tight_layout()

    # Save BEFORE plt.show(): with the default notebook inline backend the figure is
    # closed once shown, so a later plt.savefig() would write a new, blank figure.
    metrics_chart_path = os.path.join(config["output_dir"], "metrics_by_fold.png")
    fig.savefig(metrics_chart_path)
    plt.show()
    print(f"Metrics visualization saved to: {metrics_chart_path}")

In [None]:
# Plot metrics trend across folds using line charts (one line per selected metric).
if available_metrics:
    fig = plt.figure(figsize=(12, 6))
    for metric in available_metrics:
        plt.plot(metrics_df.index, metrics_df[metric], marker='o', linewidth=2, label=metric)

    plt.title("Metrics Trend Across Folds", fontsize=16)
    plt.xlabel("Fold", fontsize=14)
    plt.ylabel("Score", fontsize=14)
    plt.ylim(0, 1)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12)
    plt.xticks(metrics_df.index)  # one tick per fold number
    plt.tight_layout()

    # Save BEFORE plt.show(): with the default notebook inline backend the figure is
    # closed once shown, so saving afterwards would write a blank image.
    trend_chart_path = os.path.join(config["output_dir"], "metrics_trend.png")
    fig.savefig(trend_chart_path)
    plt.show()
    print(f"Trend visualization saved to: {trend_chart_path}")

## Display Boxplot of Metrics Distribution

Visualize the distribution of each metric across all folds using box plots.

In [None]:
# Create boxplots showing the distribution of each selected metric across folds.
if available_metrics:
    fig = plt.figure(figsize=(12, 6))

    # Convert to long format for seaborn: one row per (fold, metric) pair.
    metrics_long = metrics_df[available_metrics].reset_index().melt(
        id_vars='fold',
        value_vars=available_metrics,
        var_name='Metric',
        value_name='Score'
    )

    # Create boxplot (seaborn draws on the current axes of `fig`).
    sns.boxplot(x='Metric', y='Score', data=metrics_long)
    plt.title("Distribution of Metrics Across Folds", fontsize=16)
    plt.xlabel("Metric", fontsize=14)
    plt.ylabel("Score", fontsize=14)
    plt.grid(axis='y', alpha=0.3)
    plt.xticks(rotation=45)
    plt.tight_layout()

    # Save BEFORE plt.show(): with the default notebook inline backend the figure is
    # closed once shown, so saving afterwards would write a blank image.
    boxplot_path = os.path.join(config["output_dir"], "metrics_boxplot.png")
    fig.savefig(boxplot_path)
    plt.show()
    print(f"Boxplot visualization saved to: {boxplot_path}")

## Find Best Performing Model

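Identify the fold whose model scores highest on a single target metric (`IoU_class_1` by default; change `target_metric` in the cell below to rank by another metric). Then report that model's directory and its metrics.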

In [None]:
# Report the fold with the highest score on `target_metric` (IoU by default; F1 is another option).
target_metric = "IoU_class_1"  # Change this to your preferred metric
if target_metric not in metrics_df.columns:
    print(f"Target metric '{target_metric}' not found in results.")
else:
    best_fold = metrics_df[target_metric].idxmax()
    best_fold_metrics = metrics_df.loc[best_fold]
    # metrics_df is indexed by 1-based fold number; best_model_dirs is a 0-based list.
    best_model_dir = saved_results["best_model_dirs"][best_fold - 1]

    header_lines = (
        "\n===== Best Performing Model =====\n",
        f"Best fold: {best_fold}",
        f"Best model directory: {best_model_dir}",
        "\nPerformance metrics:",
    )
    for line in header_lines:
        print(line)
    for metric_name, metric_value in best_fold_metrics.items():
        print(f"  {metric_name}: {metric_value:.4f}")

## Conclusion

The cross-validation results provide robust estimates of model performance across different data splits. These results can be used to:

1. Determine the most reliable model configuration
2. Compare different architectures (TrueResSegformer vs. standard Segformer) by re-running this notebook with `use_true_res_segformer` toggled
3. Evaluate the impact of class weights by re-running with `class_weights_enabled` toggled
4. Assess model stability across different data splits (e.g. from the per-fold spread shown in the box plots)

For detailed analysis, review the saved metrics and visualizations in the output directory.