### Libraries and device

In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
import traceback

from models.MultiStage_FtResnet_V2 import CompositeModel
from util import preprocess_image

# Select the compute device once for the whole notebook: use the GPU when
# CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 1. Dataset and Collate function

In [None]:
class ImageQwenDataset(Dataset):
    """
    Dataset pairing each image with its pre-computed Qwen embedding.

    Images are read from image_folder; the embedding of "<name>.<ext>" is read
    from "<name>.pt" inside qwen_folder. Only files whose name appears in
    valid_image_names are exposed, in sorted filename order.

    A sample whose image or embedding contains NaN/Inf values is returned as
    None, so the DataLoader must use a collate function that drops None items.
    """

    def __init__(self, image_folder, qwen_folder, valid_image_names, image_size=384):
        """
        Index the images that are both present on disk and whitelisted.

        Args:
            image_folder (str): Path to the folder containing images.
            qwen_folder (str): Path to the folder containing Qwen embeddings.
            valid_image_names (list): Filenames (not paths) to include.
            image_size (int): Size forwarded to preprocess_image.
        """

        self.image_folder = image_folder
        self.qwen_folder = qwen_folder
        # A set gives O(1) membership tests; scanning a list for every directory
        # entry is quadratic on large folders.
        allowed_names = set(valid_image_names)
        self.image_files = sorted(f for f in os.listdir(image_folder) if f in allowed_names)
        self.image_size = image_size


    def __len__(self):
        """
        Returns:
            int: Number of indexed images.
        """

        return len(self.image_files)


    def __getitem__(self, idx):
        """
        Load the image and the Qwen embedding stored at position idx.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple | None: (image_tensor, qwen_embedding), or None when either
            tensor contains NaN/Inf values. The image tensor is expected to be
            [3, image_size, image_size], assuming preprocess_image returns an
            RGB image of that size (not verifiable from this file).
        """

        # Paths to the image and to the embedding sharing its base filename
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_folder, img_name)
        qwen_path = os.path.join(self.qwen_folder, os.path.splitext(img_name)[0] + ".pt")

        # Load and preprocess the image, then convert it to a tensor
        image = preprocess_image(img_path, self.image_size)
        image = transforms.ToTensor()(image)

        # Always materialise the embedding on the CPU: files saved from a GPU
        # would otherwise fail to load on CPU-only hosts or inside workers.
        # The training loop moves batches to the target device itself.
        qwen_embedding = torch.load(qwen_path, map_location="cpu")

        # Reject corrupted samples instead of poisoning the loss
        if torch.isnan(qwen_embedding).any() or torch.isinf(qwen_embedding).any():
            return None
        if torch.isnan(image).any() or torch.isinf(image).any():
            return None

        return image, qwen_embedding

In [None]:
def collate_fn(batch):
    """
    Collate function that tolerates samples reported as None by the dataset.

    None entries are discarded; the surviving (image, embedding) pairs are
    stacked along a new leading batch dimension.

    Args:
        batch (list): Samples, each either a tuple (image_tensor, qwen_embedding) or None.

    Returns:
        tuple | None: (images_batch, qwen_batch), or None when no valid sample is left.
    """

    # Keep only the usable samples
    valid_samples = [sample for sample in batch if sample is not None]
    if not valid_samples:
        return None

    # Split the pairs into two parallel sequences and stack each of them
    image_tensors, embedding_tensors = zip(*valid_samples)
    return torch.stack(image_tensors), torch.stack(embedding_tensors)

## 2. Saving and Plotting utility

In [None]:
def save_training_plots(train_losses_history, val_losses_history, epoch, checkpoint_dir):
    """
    Render the loss curves and write them as a PNG inside checkpoint_dir.

    The figure has two panels: the full history on the left and, when at least
    two points are available, a zoom on the most recent (up to 10) epochs on
    the right.

    Args:
        train_losses_history: List of training losses, one per epoch.
        val_losses_history: List of validation losses, one per epoch.
        epoch: Zero-based epoch number; the file is named with epoch + 1.
        checkpoint_dir: Directory where the plot will be saved.
    """

    def _finish_panel(title):
        # Shared axis labels, legend and grid for both panels
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(title)
        plt.legend()
        plt.grid(True)

    plt.figure(figsize=(12, 5))

    # Left panel: whole history
    plt.subplot(1, 2, 1)
    plt.plot(train_losses_history, label='Training Loss', color='blue')
    plt.plot(val_losses_history, label='Validation Loss', color='red')
    _finish_panel('Training Progress')

    # Right panel: last 10 epochs, drawn only when there is more than one point
    plt.subplot(1, 2, 2)
    total_epochs = len(train_losses_history)
    first_shown = max(0, total_epochs - 10)
    if total_epochs - first_shown > 1:
        shown_epochs = range(first_shown, total_epochs)
        plt.plot(shown_epochs, train_losses_history[first_shown:], label='Training Loss', color='blue')
        plt.plot(shown_epochs, val_losses_history[first_shown:], label='Validation Loss', color='red')
        _finish_panel('Last 10 Epochs')

    plt.tight_layout()
    output_path = os.path.join(checkpoint_dir, f'training_progress_epoch_{epoch+1}.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, checkpoint_dir, is_best=False, keep_last_n=3):
    """
    Save the training state and, optionally, maintain the set of best models.

    'last_checkpoint.pth' is always (over)written. When is_best is True the
    checkpoint also competes for a slot among the keep_last_n best models
    tracked in 'best_models_info.json' (ordered by ascending validation loss).
    A best-model file is written only if the candidate actually earns a slot,
    so no file is left on disk without an entry in the info file.

    Args:
        model: The model to save.
        optimizer: The optimizer whose state is saved.
        scheduler: The scheduler whose state is saved; may be None, in which
            case no 'scheduler_state_dict' key is stored.
        epoch: Current (zero-based) epoch number.
        train_loss: Training loss for the current epoch.
        val_loss: Validation loss for the current epoch.
        checkpoint_dir: Directory where the checkpoint will be saved.
        is_best: Whether this checkpoint is a candidate best model.
        keep_last_n: Number of best models to keep.

    Returns:
        str | None: Filename of the best-model checkpoint that was written, or
        None when is_best is False or the candidate did not rank in the top
        keep_last_n.
    """

    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    # load_checkpoint treats the scheduler state as optional, so mirror that here
    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()

    # Save the last checkpoint
    torch.save(checkpoint, os.path.join(checkpoint_dir, 'last_checkpoint.pth'))

    if not is_best:
        return None

    best_models_info_path = os.path.join(checkpoint_dir, 'best_models_info.json')

    # Load existing best models info if it exists
    if os.path.exists(best_models_info_path):
        with open(best_models_info_path, 'r') as f:
            best_models = json.load(f)
    else:
        best_models = []

    new_model_info = {
        'epoch': epoch,
        # Plain float so numpy/torch scalars never break JSON serialisation
        'val_loss': float(val_loss),
        'filename': f'best_model_epoch_{epoch+1}_loss_{val_loss:.4f}.pth'
    }

    # Replace any stale entry that points at the same file instead of listing it twice
    best_models = [info for info in best_models if info['filename'] != new_model_info['filename']]
    best_models.append(new_model_info)

    # Order by validation loss; the sort is stable, so on ties older entries win
    best_models.sort(key=lambda info: info['val_loss'])
    slots = max(int(keep_last_n), 0)
    kept_models = best_models[:slots]
    dropped_models = best_models[slots:]
    new_model_kept = any(info is new_model_info for info in kept_models)

    # Write the candidate only if it earned a slot
    if new_model_kept:
        torch.save(checkpoint, os.path.join(checkpoint_dir, new_model_info['filename']))

    # Update the info file with the surviving entries
    with open(best_models_info_path, 'w') as f:
        json.dump(kept_models, f, indent=2)

    # Remove the files of models that fell out of the top N
    for info in dropped_models:
        if info is new_model_info:
            continue  # never written, nothing to delete
        old_file_path = os.path.join(checkpoint_dir, info['filename'])
        if os.path.exists(old_file_path):
            os.remove(old_file_path)

    if not new_model_kept:
        print(f"\n[INFO] Checkpoint not among the best {slots} models, best-model file not written")
        return None

    print(f"\n[INFO] New best model saved: {new_model_info['filename']}")
    return new_model_info['filename']

In [None]:
def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None):
    """
    Restore a training state previously written to checkpoint_path.

    The model weights are always restored. Optimizer and scheduler states are
    restored only for the objects that are passed in; the scheduler state is
    additionally skipped when the checkpoint does not contain one.

    Args:
        checkpoint_path: Path to the checkpoint file.
        model: The model to load the state into.
        optimizer: The optimizer to load the state into (optional).
        scheduler: The scheduler to load the state into (optional).

    Returns:
        Tuple (epoch, train_loss, val_loss); a loss missing from the checkpoint
        is reported as float('inf').
    """

    print(f"[INFO] Loading checkpoint from: {checkpoint_path}")

    # NOTE: weights_only=False unpickles arbitrary Python objects, so only
    # checkpoints from a trusted source must be loaded through this function.
    state = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Model weights are mandatory
    model.load_state_dict(state['model_state_dict'])

    # Optional training state
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer_state_dict'])
    scheduler_state = state.get('scheduler_state_dict') if 'scheduler_state_dict' in state else None
    if scheduler is not None and 'scheduler_state_dict' in state:
        scheduler.load_state_dict(scheduler_state)

    # Bookkeeping values
    epoch = state['epoch']
    missing = float('inf')
    train_loss = state.get('train_loss', missing)
    val_loss = state.get('val_loss', missing)

    print(f"[INFO] Checkpoint loaded: Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    return epoch, train_loss, val_loss

## 3. Early Stopping

In [None]:
class EarlyStopping:
    """
    Stop training once the validation loss has not improved for a while.

    An epoch counts as an improvement only when the loss drops below the best
    value seen so far by more than min_delta. Optionally keeps a CPU copy of
    the weights of the best epoch so they can be restored afterwards.
    """

    def __init__(self, patience=20, min_delta=1e-4, restore_best_weights=True):
        """
        Args:
            patience: Number of consecutive non-improving epochs tolerated before stopping.
            min_delta: Minimum decrease of the loss that qualifies as an improvement.
            restore_best_weights: Whether to snapshot the weights of the best epoch.
        """

        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = float('inf')
        self.counter = 0
        self.best_weights = None
        self.early_stop = False


    def __call__(self, val_loss, model):
        """
        Record the validation loss of an epoch and decide whether to stop.

        Args:
            val_loss: Current validation loss.
            model: The model whose weights are snapshotted on improvement.

        Returns:
            bool: True if training should be stopped, False otherwise.
        """

        # Check for improvement
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                # Snapshot on the CPU so the copy does not occupy accelerator memory
                self.best_weights = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            self.counter += 1

        # Check if we have reached the patience limit
        if self.counter >= self.patience:
            self.early_stop = True

        return self.early_stop


    def restore_weights(self, model):
        """
        Load the snapshotted best weights back into model.

        Does nothing when no snapshot has been taken.

        Args:
            model: The model to restore weights to.
        """

        if self.best_weights is None:
            return

        # Build the state dict once: calling model.state_dict() per key would
        # rebuild the whole mapping for every parameter.
        current_state = model.state_dict()
        best_weights_on_device = {
            k: v.to(current_state[k].device) for k, v in self.best_weights.items()
        }
        model.load_state_dict(best_weights_on_device)

## 4. Training

### Settings

In [None]:
# Run configuration for this notebook. Every path is relative, so it is
# resolved against the working directory the notebook is executed from.

# Paths
IMAGES_FOLDER = os.path.join("..", "..", "CV_data", "miniImageNet")  # Folder containing images
QWEN_EMBEDDINGS_PATH = os.path.join("..", "..", "CV_data", "separated_embeddings", "qwen_384")  # Folder containing Qwen embeddings
CSV_TRAIN = os.path.join("..", "assets", "train_images.csv")  # CSV file with image filenames of training images

# Image parameters
IMAGE_SIZE = 384  # Size to which images will be resized (IMAGE_SIZE x IMAGE_SIZE)

# Training parameters
TRAIN_SIZE = 0.8  # Proportion of data used for training
BATCH_SIZE = 8  # Batch size for training and validation
LEARNING_RATE_RESNET = 1e-4  # Learning rate for the ResNet backbone
LEARNING_RATE_ADAPTER = 1e-4  # Learning rate for the adapter layers
WEIGHT_DECAY = 1e-3  # Weight decay for optimizer
MAX_EPOCHS = 100  # Maximum number of epochs
CHECKPOINT_DIR = os.path.join("..", "models", "MultiStage_FtResnet_V2")  # Directory to save checkpoints
RESUME_TRAINING = True  # Whether to resume training from the last checkpoint
KEEP_BEST_N = 3  # Number of best models to keep
PLOT_SAVE_INTERVAL = 5  # Save plots every n epochs

# Early stopping parameters
EARLY_STOPPING_PATIENCE = 5  # Number of epochs with no improvement to wait before stopping
EARLY_STOPPING_MIN_DELTA = 1e-3  # Minimum change in validation loss to qualify as improvement

# Create the checkpoint directory if it doesn't exist
# (checkpoints, loss history and plots are all written into this directory)
os.makedirs(CHECKPOINT_DIR, exist_ok=True) 

### Dataset and Dataloader

In [None]:
# Extract image filenames from the CSV
images_csv = pd.read_csv(CSV_TRAIN)
images_list = images_csv['filename'].tolist()


# Create dataset and dataloaders
dataset = ImageQwenDataset(IMAGES_FOLDER, QWEN_EMBEDDINGS_PATH, images_list, image_size=IMAGE_SIZE)

train_size = int(TRAIN_SIZE * len(dataset))
val_size = len(dataset) - train_size

# Seed the split with a dedicated generator. Training can be resumed from a
# checkpoint in a fresh session, and an unseeded split would then draw a
# different train/validation partition, leaking validation images into
# training and making validation losses incomparable across sessions.
# NOTE: checkpoints produced before this seed was introduced used another split.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

# collate_fn drops the samples the dataset reports as corrupted (None)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

### Training parameters

In [None]:
# Build the model on the selected device and make both trainable parts
# (ResNet backbone and adapter) receive gradients.
model = CompositeModel().to(device)

for trainable_part in (model.resnet, model.adapter):
    for param in trainable_part.parameters():
        param.requires_grad = True


# Optimizer with one parameter group per part, so each keeps its own learning rate
param_groups = [
    {'params': model.resnet.parameters(), 'lr': LEARNING_RATE_RESNET},
    {'params': model.adapter.parameters(), 'lr': LEARNING_RATE_ADAPTER},
]
optimizer = optim.AdamW(param_groups, weight_decay=WEIGHT_DECAY)

# Halve the learning rates after 3 epochs without a lower validation loss
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    min_lr=1e-5,
    threshold=0,
)

# Regression loss between predicted and target embeddings
loss_function = nn.MSELoss()

# Gradient scaler used together with autocast in the training loop
scaler = torch.amp.GradScaler(device)


# Early stopping
early_stopping = EarlyStopping(
    patience=EARLY_STOPPING_PATIENCE,
    min_delta=EARLY_STOPPING_MIN_DELTA,
    restore_best_weights=True,
)


# Tracking variables (possibly overwritten when training is resumed)
best_val_loss, start_epoch = float('inf'), 0
train_losses_history, val_losses_history = [], []

### Resume training

In [None]:
# Resume training if specified: restore model/optimizer/scheduler from the last
# checkpoint, reload the loss history and rebuild the early-stopping state.
if RESUME_TRAINING:
    last_checkpoint_path = os.path.join(CHECKPOINT_DIR, 'last_checkpoint.pth')
    if os.path.exists(last_checkpoint_path):
        try:
            # Load the last checkpoint; it stores the last completed epoch,
            # so training continues from the following one.
            start_epoch, last_train_loss, best_val_loss = load_checkpoint(
                last_checkpoint_path, model, optimizer, scheduler
            )
            start_epoch += 1

            # Load history if available
            history_path = os.path.join(CHECKPOINT_DIR, 'loss_history.json')
            if os.path.exists(history_path):
                with open(history_path, 'r') as f:
                    history = json.load(f)
                train_losses_history = history.get('train_losses', [])
                val_losses_history = history.get('val_losses', [])

            # Update early stopping with previous history
            if val_losses_history:
                early_stopping.best_loss = min(val_losses_history)
                best_epoch_idx = val_losses_history.index(early_stopping.best_loss)
                early_stopping.counter = len(val_losses_history) - 1 - best_epoch_idx

                # The checkpoint only carries the validation loss of the most
                # recent epoch. Using it as "best so far" would let a worse
                # model be saved as a new best, so take the minimum over the
                # whole recorded history instead.
                best_val_loss = min(best_val_loss, early_stopping.best_loss)
                print(f"[INFO] Early stopping restored: best_loss={early_stopping.best_loss:.4f}, counter={early_stopping.counter}")

            print(f"[INFO] Training resumed from epoch {start_epoch}")

        except Exception as e:
            print(f"[ERROR] Error loading checkpoint: {e}")
            print("[INFO] Start new training from scratch...")
            # Reset every piece of tracking state that may have been partially
            # restored before the failure, not just the epoch counter.
            # NOTE(review): if the failure happened after the model weights were
            # loaded, the model itself is not re-initialised here.
            start_epoch = 0
            best_val_loss = float('inf')
            train_losses_history = []
            val_losses_history = []
            early_stopping.best_loss = float('inf')
            early_stopping.counter = 0
    else:
        print("[INFO] No checkpoint found. Starting new training from scratch...")

### Epoch training and validation

In [None]:
def epoch_training(model, train_dataloader, loss_function, optimizer, scaler):
    """
    Train the model for one epoch with mixed precision.

    Batches that are empty (all samples dropped by the collate function), that
    produce a NaN/Inf loss, or that raise an exception are skipped and do not
    contribute to the reported mean.

    Args:
        model: The model to train.
        train_dataloader: DataLoader for the training data.
        loss_function: Loss function to use.
        optimizer: Optimizer to use.
        scaler: GradScaler for mixed precision training.

    Returns:
        float: Mean training loss over the processed batches, or float('inf')
        when no batch could be processed.
    """

    model.train()
    train_losses = []
    train_loop = tqdm(train_dataloader, desc="[LOOP] Training")

    # Iterate over batches
    for i, batch_data in enumerate(train_loop):
        # The collate function returns a bare None (not a tuple) when every
        # sample of the batch was dropped; indexing it would raise TypeError.
        if batch_data is None or batch_data[0] is None:
            continue

        try:
            # Get data and move to device
            imgs, qwen_embs = batch_data
            imgs = imgs.to(device)
            qwen_embs = qwen_embs.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass with mixed precision; the model is asked for as many
            # output positions as the target embedding sequence has.
            with torch.amp.autocast(device):
                target_seq_len = qwen_embs.shape[1]
                pred = model(imgs, target_seq_len=target_seq_len)
                loss = loss_function(pred, qwen_embs)

            # Check for NaN or Inf loss
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"[ERROR] Invalid loss at batch {i}, skipping.")
                continue

            # Backward pass and optimization step
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Record loss
            batch_loss = loss.item()
            train_losses.append(batch_loss)
            train_loop.set_postfix(loss=f"{batch_loss:.4f}")

        except Exception as e:
            # Deliberate best-effort: one failing batch must not abort the epoch
            print(f"[ERROR] Error at batch {i}: {e}")
            continue

    # Compute mean training loss as a plain Python float
    mean_loss = float(np.mean(train_losses)) if train_losses else float('inf')
    return mean_loss


In [None]:
def epoch_validation(model, val_dataloader, loss_function):
    """
    Evaluate the model on the validation data without updating it.

    Batches that are empty (all samples dropped by the collate function) or
    that produce a NaN/Inf loss are skipped and do not contribute to the mean.

    Args:
        model: The model to validate.
        val_dataloader: DataLoader for the validation data.
        loss_function: Loss function to use.

    Returns:
        float: Mean validation loss over the processed batches, or float('inf')
        when no batch could be processed.
    """

    model.eval()
    val_losses = []
    val_loop = tqdm(val_dataloader, desc="[LOOP] Validation")

    with torch.no_grad():
        # Iterate over batches
        for i, batch_data in enumerate(val_loop):
            # The collate function returns a bare None (not a tuple) when every
            # sample of the batch was dropped; indexing it would raise TypeError.
            if batch_data is None or batch_data[0] is None:
                continue

            # Get data and move to device
            imgs, qwen_embs = batch_data
            imgs = imgs.to(device)
            qwen_embs = qwen_embs.to(device)

            # Forward pass with mixed precision
            with torch.amp.autocast(device):
                pred = model(imgs, target_seq_len=qwen_embs.shape[1])
                loss = loss_function(pred, qwen_embs)

            # Check for NaN or Inf loss
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"[ERROR] Invalid loss at batch {i}, skipping.")
                continue

            # Record loss
            batch_loss = loss.item()
            val_losses.append(batch_loss)
            val_loop.set_postfix(loss=f"{batch_loss:.4f}")

    # Compute mean validation loss as a plain Python float
    mean_loss = float(np.mean(val_losses)) if val_losses else float('inf')
    return mean_loss

### Evaluation

In [None]:
def epoch_evaluation(model, train_loss, val_loss, optimizer, scheduler, epoch, early_stopping):
    """
    End-of-epoch bookkeeping: record the losses, step the learning-rate
    scheduler, report the results and handle early stopping.

    The losses are appended to the module-level lists train_losses_history and
    val_losses_history. When early stopping fires, a checkpoint is written to
    CHECKPOINT_DIR before returning.

    Args:
        model: The model being trained.
        train_loss: Training loss for the current epoch.
        val_loss: Validation loss for the current epoch.
        optimizer: Optimizer used for training.
        scheduler: Learning rate scheduler, stepped with the validation loss.
        epoch: Current (zero-based) epoch number.
        early_stopping: EarlyStopping instance to check for stopping conditions.

    Returns:
        tuple: (stop, train_loss, val_loss) where stop is True if training
        should stop and the two losses are returned unchanged.
    """

    # Save training and validation losses
    train_losses_history.append(train_loss)
    val_losses_history.append(val_loss)

    # Update scheduler (only the first parameter group is reported)
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']

    # Epoch results
    print(f"\n[INFO] EPOCH RESULTS {epoch + 1}:")
    print(f"   Train Loss: {train_loss:.4f}")
    print(f"   Val Loss:   {val_loss:.4f}")
    print(f"   Learning Rate: {new_lr:.2e}")

    print(f"\n[INFO] Scheduler state: {scheduler.num_bad_epochs}/{scheduler.patience}")

    if old_lr != new_lr:
        print(f"[INFO] LR reduced: {old_lr:.2e} → {new_lr:.2e}")

    # Early stopping check
    if early_stopping(val_loss, model):
        print(f"\n[INFO] EARLY STOPPING ACTIVATED!")
        print(f"   Epoch: {epoch + 1}")
        print(f"   Best val_loss: {early_stopping.best_loss:.4f}")
        print(f"   Epochs without improvement: {early_stopping.patience}")

        # A snapshot exists only if an improvement was observed in this session
        # with restore_best_weights enabled; after a resume best_loss may come
        # from the history while no weights were ever captured.
        can_restore = early_stopping.restore_best_weights and early_stopping.best_weights is not None

        if can_restore:
            early_stopping.restore_weights(model)
            print("[INFO] Restored weights of the best model")

            # The model now holds the best weights: save them as a best model
            save_checkpoint(model, optimizer, scheduler, epoch, train_loss,
                            early_stopping.best_loss, CHECKPOINT_DIR, is_best=True, keep_last_n=KEEP_BEST_N)
        else:
            # The model still holds the weights of the current epoch. Labelling
            # them with best_loss would register a false "best" model and could
            # evict a genuine one, so only the regular checkpoint is written.
            print("[INFO] No best weights available to restore, keeping current weights")
            save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, CHECKPOINT_DIR)

        return True, train_loss, val_loss

    print(f"[INFO] Early stopping: {early_stopping.counter}/{early_stopping.patience}")
    return False, train_loss, val_loss

### Full training

In [None]:
training_completed = False
completion_reason = "unknown"

# Bind the per-epoch losses up front: the error handler below saves them, and
# a failure during the very first epoch would otherwise raise NameError there
# and hide the original error.
avg_train_loss = float('inf')
avg_val_loss = float('inf')


def _write_loss_history():
    """Persist the loss history lists to loss_history.json in CHECKPOINT_DIR."""
    history = {
        'train_losses': train_losses_history,
        'val_losses': val_losses_history
    }
    with open(os.path.join(CHECKPOINT_DIR, 'loss_history.json'), 'w') as f:
        json.dump(history, f, indent=2)


try:
    # Main training loop over epochs
    for epoch in range(start_epoch, MAX_EPOCHS):
        print(f"\n[TRAINING] EPOCH {epoch + 1}/{MAX_EPOCHS}")
        print("="*50)

        # Training
        train_loss = epoch_training(model, train_loader, loss_function, optimizer, scaler)
        if train_loss is None:
            continue

        # Validation
        val_loss = epoch_validation(model, val_loader, loss_function)
        if val_loss is None:
            continue

        # Evaluation (appends to the loss history and handles early stopping)
        stopped, avg_train_loss, avg_val_loss = epoch_evaluation(model, train_loss, val_loss, optimizer, scheduler, epoch, early_stopping)
        if stopped:
            # The evaluation step already wrote the checkpoint; also persist the
            # history so the losses of this final epoch are not lost on disk.
            _write_loss_history()
            training_completed = True
            completion_reason = "early_stopping"
            break

        # Save the last checkpoint and, when this epoch is the best so far, the
        # best-model checkpoint in the same call (avoids writing the last
        # checkpoint twice).
        is_best = avg_val_loss < best_val_loss
        if is_best:
            best_val_loss = avg_val_loss
        save_checkpoint(
            model, optimizer, scheduler, epoch, avg_train_loss, avg_val_loss,
            CHECKPOINT_DIR, is_best=is_best, keep_last_n=KEEP_BEST_N
        )

        # Save plots every n epochs
        if (epoch + 1) % PLOT_SAVE_INTERVAL == 0:
            save_training_plots(train_losses_history, val_losses_history, epoch, CHECKPOINT_DIR)

        # Save loss history
        _write_loss_history()

    # If max epochs reached without early stopping
    if not training_completed:
        completion_reason = "max_epochs_reached"
        training_completed = True

except KeyboardInterrupt:
    print(f"\n[INFO] Training manually interrupted at epoch {epoch + 1}")
    completion_reason = "manual_interruption"
    training_completed = True

except Exception as e:
    print(f"\n[ERROR] Error during training: {e}")
    print("\nStack trace:")
    traceback.print_exc()
    print("Saving current state...")
    # NOTE(review): the checkpoint is labelled with the epoch that failed, so a
    # resumed run continues from the following epoch; the stored losses are the
    # ones of the last evaluated epoch (inf if none was evaluated).
    save_checkpoint(model, optimizer, scheduler, epoch, avg_train_loss, avg_val_loss, CHECKPOINT_DIR)
    completion_reason = "error"
    training_completed = True

## 5. Final results

In [None]:
# Number of epochs actually reached: the loop variable only exists if the
# training loop ran at least one iteration.
if 'epoch' in locals():
    final_epoch = epoch + 1
else:
    final_epoch = start_epoch

# Final report
print("[TRAINING] Completed")
print("="*50)
print(f"[INFO] Effective epochs: {final_epoch}")
print(f"[INFO] Best validation loss: {best_val_loss:.4f}")


# Completion reason: one message per known reason, nothing for unknown ones
completion_messages = {
    "early_stopping": f"[INFO] Stopped with early stopping after {early_stopping.patience} epochs without improvement",
    "max_epochs_reached": "[INFO] All scheduled epochs completed",
    "manual_interruption": "[INFO] Training manually interrupted",
    "error": "[ERROR] Training stopped due to an error",
}
if completion_reason in completion_messages:
    print(completion_messages[completion_reason])


# Save final plots (only when at least one epoch was recorded)
if len(train_losses_history) > 0:
    save_training_plots(train_losses_history, val_losses_history, final_epoch-1, CHECKPOINT_DIR)

# Best models info: list the models tracked in the info file, if there is one
best_models_info_path = os.path.join(CHECKPOINT_DIR, 'best_models_info.json')
if os.path.exists(best_models_info_path):
    with open(best_models_info_path, 'r') as f:
        best_models = json.load(f)
    print(f"\n[INFO] Best {len(best_models)} models saved:")
    for rank, model_info in enumerate(best_models, start=1):
        print(f"   {rank}. Epoch {model_info['epoch']+1}: {model_info['filename']}")
        print(f"      Val Loss: {model_info['val_loss']:.4f}")