### Libraries and device

In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import pandas as pd
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from transformers import ResNetModel
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
import traceback

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
# Kept as a plain string because later cells pass it to torch.amp.autocast / GradScaler.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


## 1. Define model

### Adapter layer

In [2]:
class MultiStageAdapter(nn.Module):
    """
    Adapter that receives the feature maps of 4 ResNet stages and converts them
    into a sequence of embeddings for Qwen.

    Each stage is projected to ``out_dim`` channels with a 1x1 convolution, resized
    to the spatial size of the last stage, concatenated along the channel axis, fused
    by a two-layer 1x1-conv block, flattened into a sequence, resampled to the
    requested length and layer-normalised.

    Args:
        stage_channels: Channel count of each stage feature map, in stage order.
        out_dim: Embedding dimension of the produced sequence.
        hidden_multiplier: Width multiplier of the hidden fusion layer.
    """
    def __init__(self, stage_channels=(256, 512, 1024, 2048), out_dim=2048, hidden_multiplier=2):
        super().__init__()
        # Immutable default (tuple) instead of a shared mutable list; any iterable is accepted.
        stage_channels = tuple(stage_channels)
        self.projections = nn.ModuleList([
            nn.Conv2d(c, out_dim, kernel_size=1) for c in stage_channels
        ])
        self.fusion = nn.Sequential(
            nn.Conv2d(out_dim * len(stage_channels), out_dim * hidden_multiplier, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(out_dim * hidden_multiplier, out_dim, kernel_size=1)
        )
        # Layer normalisation applied to the final sequence.
        self.final_norm = nn.LayerNorm(out_dim)

    def forward(self, stage0, stage1, stage2, stage3, target_seq_len=196):
        """
        Fuse the four stage feature maps into a ``(B, target_seq_len, out_dim)`` sequence.

        Args:
            stage0, stage1, stage2, stage3: 4-D feature maps ``(B, C_i, H_i, W_i)``;
                all are resized to the spatial size of ``stage3``.
            target_seq_len: Number of tokens in the output sequence.

        Returns:
            Tensor of shape ``(B, target_seq_len, out_dim)``.
        """
        _, _, Ht, Wt = stage3.shape
        proj_feats = []
        for feat, proj in zip([stage0, stage1, stage2, stage3], self.projections):
            x = proj(feat)
            # Bring every stage to the resolution of the deepest (smallest) feature map.
            x = F.interpolate(x, size=(Ht, Wt), mode='bilinear', align_corners=False)
            proj_feats.append(x)

        fused = torch.cat(proj_feats, dim=1)  # (B, out_dim*4, Ht, Wt)
        fused = self.fusion(fused)           # (B, out_dim, Ht, Wt)
        seq = fused.flatten(2)               # (B, out_dim, Ht*Wt)
        # Resample the flattened spatial axis to the requested sequence length.
        seq = F.interpolate(seq, size=target_seq_len, mode='linear', align_corners=False)
        seq = seq.permute(0, 2, 1)           # (B, L, C)

        # Normalisation is the last step.
        seq = self.final_norm(seq)

        return seq

### Composite Model

In [3]:
class CompositeModel(nn.Module):
    """
    ResNet-50 backbone followed by a MultiStageAdapter.

    The outputs of the backbone's encoder stages are captured with forward hooks
    and handed to the adapter, which turns them into a token sequence.
    """
    def __init__(self):
        super().__init__()
        self.resnet = ResNetModel.from_pretrained("microsoft/resnet-50")
        self.adapter = MultiStageAdapter()

    def forward(self, pixel_values, target_seq_len=196):
        """
        Run the backbone and project its stage outputs into a sequence.

        Args:
            pixel_values: Image batch fed to the ResNet backbone.
            target_seq_len: Length of the sequence produced by the adapter.

        Returns:
            The adapter output for the first four encoder stages.

        Raises:
            KeyError: If the backbone exposes fewer than four encoder stages.
        """
        stage_outputs = {}

        def make_hook(idx):
            def hook(module, inputs, output):
                stage_outputs[idx] = output
            return hook

        handles = []
        try:
            for idx, stage in enumerate(self.resnet.encoder.stages):
                handles.append(stage.register_forward_hook(make_hook(idx)))
            self.resnet(pixel_values)  # forward pass; only the hooked outputs are used
        finally:
            # Always detach the hooks, even if the backbone raised; otherwise they
            # would pile up on the stages across calls.
            for handle in handles:
                handle.remove()

        return self.adapter(
            stage_outputs[0],
            stage_outputs[1],
            stage_outputs[2],
            stage_outputs[3],
            target_seq_len,
        )

## 2. Dataset and Collate function

In [4]:
class ImageQwenDataset(Dataset):
    """
    Dataset pairing each image with a precomputed embedding tensor stored on disk.

    For an image ``name.ext`` found in ``image_folder`` the target is loaded from
    ``qwen_folder/name.pt``. The image is rescaled so that its shorter side is about
    ``image_size``, centre-cropped to a square, resized to exactly
    ``image_size x image_size`` and converted with ``transforms.ToTensor``.

    Args:
        image_folder: Directory containing the images.
        qwen_folder: Directory containing one ``.pt`` tensor per image.
        valid_image_names: Iterable of file names to keep; other files are ignored.
        image_size: Side length of the square image fed to the model.
    """

    def __init__(self, image_folder, qwen_folder, valid_image_names, image_size=384):
        self.image_folder = image_folder
        self.qwen_folder = qwen_folder
        # A set gives O(1) membership tests (a list made the filter O(n*m)) and
        # also works when a one-shot iterable is passed.
        allowed_names = set(valid_image_names)
        self.image_files = sorted(f for f in os.listdir(image_folder) if f in allowed_names)
        self.image_size = image_size

    def __len__(self):
        """Return the number of usable images."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Load one ``(image, embedding)`` pair.

        Returns:
            ``(image_tensor, qwen_embedding)``, or ``None`` when either tensor
            contains NaN/Inf values so that the collate function can drop the sample.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_folder, img_name)
        qwen_path = os.path.join(self.qwen_folder, os.path.splitext(img_name)[0] + ".pt")

        # Close the underlying file deterministically; convert() returns an
        # independent in-memory copy.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Scale so the shorter side is ~image_size (int() may truncate by one pixel).
        W, H = image.size
        scale = self.image_size / min(W, H)
        new_W = int(W * scale)
        new_H = int(H * scale)
        image = image.resize((new_W, new_H), resample=Image.BICUBIC)

        # Centre crop to a square.
        W, H = image.size
        crop_size = min(W, H)
        left = (W - crop_size) // 2
        top = (H - crop_size) // 2
        image = image.crop((left, top, left + crop_size, top + crop_size))

        # The final resize guarantees the exact target size despite the truncation above.
        transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor()
        ])
        image = transform(image)

        # Load on the CPU regardless of the device the tensor was saved from;
        # the training loop moves whole batches to the target device.
        qwen_embedding = torch.load(qwen_path, map_location="cpu")  # [L, D]

        # Skip samples with corrupted targets or images.
        if not torch.isfinite(qwen_embedding).all():
            return None
        if not torch.isfinite(image).all():
            return None

        return image, qwen_embedding

In [5]:
def collate_fn(batch):
    """
    Build a batch from dataset samples, ignoring samples returned as ``None``.

    Args:
        batch: List of ``(image, embedding)`` tuples, possibly containing ``None``.

    Returns:
        ``(images_batch, qwen_batch)`` stacked along a new first dimension, or
        ``None`` when no valid sample is left.
    """
    # Discard the samples the dataset flagged as unusable.
    valid_samples = list(filter(lambda sample: sample is not None, batch))
    if not valid_samples:
        return None

    images_batch = torch.stack([image for image, _ in valid_samples])
    qwen_batch = torch.stack([embedding for _, embedding in valid_samples])
    return images_batch, qwen_batch

## 3. Saving and plotting utilities

In [6]:
def save_training_plots(train_losses_history, val_losses_history, epoch, checkpoint_dir):
    """
    Save a two-panel figure of the loss curves to ``checkpoint_dir``.

    The left panel shows the full history; the right panel zooms on the last
    10 epochs and is left empty when fewer than two points are available.

    Args:
        train_losses_history: List of training losses, one per epoch.
        val_losses_history: List of validation losses, one per epoch.
        epoch: Current (0-based) epoch number, used in the file name.
        checkpoint_dir: Directory where the PNG is written.
    """
    fig, (full_ax, recent_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Full history
    full_ax.plot(train_losses_history, label='Training Loss', color='blue')
    full_ax.plot(val_losses_history, label='Validation Loss', color='red')
    full_ax.set_xlabel('Epoch')
    full_ax.set_ylabel('Loss')
    full_ax.set_title('Training Progress')
    full_ax.legend()
    full_ax.grid(True)

    # Zoom on the most recent epochs (at most 10)
    n_epochs = len(train_losses_history)
    first_recent = max(0, n_epochs - 10)
    recent_epochs = range(first_recent, n_epochs)
    if len(recent_epochs) > 1:
        recent_ax.plot(recent_epochs, train_losses_history[first_recent:], label='Training Loss', color='blue')
        recent_ax.plot(recent_epochs, val_losses_history[first_recent:], label='Validation Loss', color='red')
        recent_ax.set_xlabel('Epoch')
        recent_ax.set_ylabel('Loss')
        recent_ax.set_title('Last 10 Epochs')
        recent_ax.legend()
        recent_ax.grid(True)

    fig.tight_layout()
    output_path = os.path.join(checkpoint_dir, f'training_progress_epoch_{epoch+1}.png')
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

In [7]:
def save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, checkpoint_dir, is_best=False, keep_last_n=3):
    """
    Save a training checkpoint and, optionally, maintain the N best models.

    ``last_checkpoint.pth`` is always (over)written. When ``is_best`` is True the
    checkpoint is also considered for the ranking stored in
    ``best_models_info.json``: it is written to its own file only if it ranks within
    the ``keep_last_n`` lowest validation losses, and files that fall out of the
    ranking are deleted.

    Args:
        model: The model to save.
        optimizer: The optimizer whose state is saved.
        scheduler: The scheduler whose state is saved.
        epoch: Current (0-based) epoch number.
        train_loss: Training loss for the current epoch.
        val_loss: Validation loss for the current epoch.
        checkpoint_dir: Directory where the checkpoint is saved.
        is_best: Whether this checkpoint is a candidate for the best-models list.
        keep_last_n: Number of best models to keep (despite the name, ranking is
            by validation loss, not recency).

    Returns:
        str: File name of the saved best-model checkpoint, or None when
        ``is_best`` is False or the candidate did not make it into the ranking.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }

    # Always refresh the resume checkpoint.
    torch.save(checkpoint, os.path.join(checkpoint_dir, 'last_checkpoint.pth'))

    if not is_best:
        return None

    best_models_info_path = os.path.join(checkpoint_dir, 'best_models_info.json')

    # Load the current ranking, if any.
    if os.path.exists(best_models_info_path):
        with open(best_models_info_path, 'r') as f:
            best_models = json.load(f)
    else:
        best_models = []

    new_model_info = {
        'epoch': epoch,
        # Plain float so the entry is JSON-serialisable whatever numeric type was passed.
        'val_loss': float(val_loss),
        'filename': f'best_model_epoch_{epoch+1}_loss_{val_loss:.4f}.pth'
    }
    # Replace any stale entry pointing at the same file instead of duplicating it.
    best_models = [m for m in best_models if m['filename'] != new_model_info['filename']]
    best_models.append(new_model_info)

    # Rank by validation loss and split into kept / dropped entries.
    best_models.sort(key=lambda m: m['val_loss'])
    kept_models = best_models[:keep_last_n]
    dropped_models = best_models[keep_last_n:]
    new_model_kept = any(m is new_model_info for m in kept_models)

    # Write the new file first so the best weights are never missing from disk,
    # and only when it actually ranks: otherwise it would be an untracked orphan.
    if new_model_kept:
        torch.save(checkpoint, os.path.join(checkpoint_dir, new_model_info['filename']))

    # Delete the files that fell out of the ranking (never one that is still kept).
    kept_filenames = {m['filename'] for m in kept_models}
    for model_info in dropped_models:
        if model_info['filename'] in kept_filenames:
            continue
        old_file_path = os.path.join(checkpoint_dir, model_info['filename'])
        if os.path.exists(old_file_path):
            os.remove(old_file_path)

    # Persist the updated ranking.
    with open(best_models_info_path, 'w') as f:
        json.dump(kept_models, f, indent=2)

    if not new_model_kept:
        return None

    print(f"New best model saved: {new_model_info['filename']}")
    return new_model_info['filename']

In [8]:
def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None):
    """
    Restore model (and optionally optimizer/scheduler) state from a checkpoint file.

    The tensors are mapped onto the module-level ``device``.

    Args:
        checkpoint_path: Path to the checkpoint file.
        model: Model that receives ``model_state_dict``.
        optimizer: Optional optimizer that receives ``optimizer_state_dict``.
        scheduler: Optional scheduler; restored only if the checkpoint contains
            ``scheduler_state_dict``.

    Returns:
        Tuple ``(epoch, train_loss, val_loss)``; a missing loss defaults to ``inf``.
    """
    print(f"Loading checkpoint from: {checkpoint_path}")

    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoint files that come from a trusted source.
    state = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model.load_state_dict(state['model_state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer_state_dict'])

    scheduler_state_available = 'scheduler_state_dict' in state
    if scheduler is not None and scheduler_state_available:
        scheduler.load_state_dict(state['scheduler_state_dict'])

    epoch = state['epoch']
    # Older checkpoints may lack the loss entries, hence the defaults.
    missing = float('inf')
    train_loss = state.get('train_loss', missing)
    val_loss = state.get('val_loss', missing)

    print(f"Checkpoint loaded: Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    return epoch, train_loss, val_loss

## 4. Early Stopping

In [9]:
class EarlyStopping:
    """
    Tracks the validation loss and signals when training should stop.

    An epoch counts as an improvement only if the loss drops below the best
    value seen so far by more than ``min_delta``.
    """

    def __init__(self, patience=20, min_delta=1e-4, restore_best_weights=True):
        """
        Configure the stopping criterion.

        Args:
            patience: Number of consecutive non-improving epochs tolerated before stopping.
            min_delta: Minimum decrease of the loss that qualifies as an improvement.
            restore_best_weights: Whether to keep a CPU snapshot of the best weights
                so they can be restored later.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = float('inf')
        self.counter = 0
        self.best_weights = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        """
        Update the tracker with the latest validation loss.

        Args:
            val_loss: Current validation loss.
            model: Model whose weights are snapshotted on improvement.

        Returns:
            bool: True once the patience has been exhausted, False otherwise.
        """
        improved = val_loss < self.best_loss - self.min_delta

        if improved:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                # Snapshot on the CPU so the copy does not occupy device memory.
                snapshot = {}
                for name, tensor in model.state_dict().items():
                    snapshot[name] = tensor.cpu().clone()
                self.best_weights = snapshot
        else:
            self.counter += 1

        # The flag is sticky: once set it is never cleared.
        if self.counter >= self.patience:
            self.early_stop = True

        return self.early_stop

    def restore_weights(self, model):
        """
        Load the best snapshot back into ``model``; no-op if none was taken.
        """
        if self.best_weights is None:
            return

        current_state = model.state_dict()
        # Move each saved tensor back to the device of the matching live tensor.
        restored = {
            name: tensor.to(current_state[name].device)
            for name, tensor in self.best_weights.items()
        }
        model.load_state_dict(restored)

## 5. Training

### Settings

In [10]:
# Paths (all relative to the notebook's working directory)
# Folder containing the training images
IMAGES_FOLDER = os.path.join("..", "..", "CV_data", "miniImageNet")
# Folder holding one "<image stem>.pt" target tensor per image (read by ImageQwenDataset)
QWEN_EMBEDDINGS_PATH = os.path.join("..", "..", "CV_data", "separated_embeddings", "qwen_384")
# CSV whose 'filename' column lists the images used for training
CSV_TRAIN = os.path.join("..", "assets", "train_images.csv")

# Image parameters
# Side length (pixels) of the square image produced by the dataset
IMAGE_SIZE = 384

# Training parameters
BATCH_SIZE = 8                  # samples per batch
LEARNING_RATE_RESNET = 1e-3     # AdamW learning rate of the ResNet parameter group
LEARNING_RATE_ADAPTER = 1e-4    # AdamW learning rate of the adapter parameter group
WEIGHT_DECAY = 1e-3             # AdamW weight decay shared by both parameter groups
MAX_EPOCHS = 100                # upper bound on the number of epochs
# Checkpoints, loss history and plots are all written to this directory
CHECKPOINT_DIR = os.path.join("..", "models", "composite_model_checkpoints")
RESUME_TRAINING = True          # if True, try to resume from last_checkpoint.pth in CHECKPOINT_DIR
KEEP_BEST_N = 3                 # number of best-model files kept on disk

# Early stopping parameters
EARLY_STOPPING_PATIENCE = 5     # non-improving epochs tolerated before stopping
EARLY_STOPPING_MIN_DELTA = 1e-3 # minimum val-loss decrease counted as an improvement

# Create the checkpoint directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

### Dataset and Dataloader

In [11]:
# The 'filename' column of the training CSV selects which images are used.
images_csv = pd.read_csv(CSV_TRAIN)
images_list = images_csv['filename'].tolist()

dataset = ImageQwenDataset(IMAGES_FOLDER, QWEN_EMBEDDINGS_PATH, images_list, image_size=IMAGE_SIZE)

# 80/20 train/validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so that a resumed run sees the SAME train/validation partition.
# With an unseeded split, resuming from a checkpoint would move validation images
# into the training set and make the validation losses incomparable across sessions.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Use the configured batch size instead of a hard-coded literal.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [12]:
model = CompositeModel().to(device)

# Both the backbone and the adapter are trained (full fine-tuning).
for trainable_module in (model.resnet, model.adapter):
    trainable_module.requires_grad_(True)

# Separate parameter groups so the backbone and the adapter get their own learning rate.
param_groups = [
    {'params': model.resnet.parameters(), 'lr': LEARNING_RATE_RESNET},
    {'params': model.adapter.parameters(), 'lr': LEARNING_RATE_ADAPTER},
]
optimizer = optim.AdamW(param_groups, weight_decay=WEIGHT_DECAY)

# Halve the learning rates after 3 epochs without validation improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    min_lr=1e-5,
)
loss_function = nn.MSELoss()

# Gradient scaler used together with torch.amp.autocast in the training loop.
scaler = torch.amp.GradScaler(device)

# Early stopping
early_stopping = EarlyStopping(
    patience=EARLY_STOPPING_PATIENCE,
    min_delta=EARLY_STOPPING_MIN_DELTA,
    restore_best_weights=True,
)

# Tracking variables (may be overwritten by the resume cell)
best_val_loss = float('inf')
start_epoch = 0
train_losses_history, val_losses_history = [], []

### Resume training

In [13]:
if RESUME_TRAINING:
    last_checkpoint_path = os.path.join(CHECKPOINT_DIR, 'last_checkpoint.pth')
    if os.path.exists(last_checkpoint_path):
        try:
            start_epoch, last_train_loss, last_val_loss = load_checkpoint(
                last_checkpoint_path, model, optimizer, scheduler
            )
            # The checkpoint stores the index of the last finished epoch.
            start_epoch += 1
            # Fallback when no history is available: the last epoch's loss.
            best_val_loss = last_val_loss

            # Load history if available
            history_path = os.path.join(CHECKPOINT_DIR, 'loss_history.json')
            if os.path.exists(history_path):
                with open(history_path, 'r') as f:
                    history = json.load(f)
                train_losses_history = history.get('train_losses', [])
                val_losses_history = history.get('val_losses', [])

            # Update best-loss tracking and early stopping with the previous history
            if val_losses_history:
                best_history_loss = min(val_losses_history)
                # The last checkpoint is not necessarily the best one: use the best
                # loss of the whole history so worse epochs are not saved as "best".
                best_val_loss = min(best_val_loss, best_history_loss)
                early_stopping.best_loss = best_history_loss
                best_epoch_idx = val_losses_history.index(best_history_loss)
                early_stopping.counter = len(val_losses_history) - 1 - best_epoch_idx
                print(f"Early stopping restored: best_loss={early_stopping.best_loss:.4f}, counter={early_stopping.counter}")

            print(f"Training resumed from epoch {start_epoch}")

        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            traceback.print_exc()
            print("Start new training from scratch...")
            # Reset every tracking variable that may have been partially updated above.
            start_epoch = 0
            best_val_loss = float('inf')
            train_losses_history = []
            val_losses_history = []
            early_stopping.best_loss = float('inf')
            early_stopping.counter = 0
            # NOTE: the model/optimizer may already hold partially loaded state;
            # re-run the model setup cell for a truly clean start.
            print("Warning: model/optimizer state may be partially loaded; re-run the setup cell for a clean start.")
    else:
        print("No checkpoint found. Starting new training from scratch...")

No checkpoint found. Starting new training from scratch...


### Epoch training and validation

In [14]:
def epoch_training(model, train_dataloader, loss_function, optimizer, scaler):
    """
    Run one mixed-precision training epoch and return the mean batch loss.

    Batches are moved to the module-level ``device``. A batch is skipped when it is
    empty, when its loss is NaN/Inf, or when processing it raises an exception.

    Args:
        model: Model called as ``model(imgs, target_seq_len=...)``.
        train_dataloader: Iterable yielding ``(images, embeddings)`` batches, or
            ``None`` for a batch without valid samples.
        loss_function: Callable computing the loss from ``(pred, target)``.
        optimizer: Optimizer stepped through ``scaler``.
        scaler: Gradient scaler used for the backward pass and the optimizer step.

    Returns:
        Mean of the recorded batch losses, or ``float('inf')`` if no batch contributed.
    """
    model.train()
    train_losses = []
    train_loop = tqdm(train_dataloader, desc="Training")

    for i, batch_data in enumerate(train_loop):
        # The collate function may return None (not a tuple) when every sample of
        # the batch was filtered out; indexing it would raise a TypeError.
        if batch_data is None or batch_data[0] is None:
            continue

        try:
            imgs, qwen_embs = batch_data
            imgs = imgs.to(device)
            qwen_embs = qwen_embs.to(device)

            optimizer.zero_grad()

            with torch.amp.autocast(device):
                # The prediction is resampled to the length of the target sequence.
                target_seq_len = qwen_embs.shape[1]
                pred = model(imgs, target_seq_len=target_seq_len)
                loss = loss_function(pred, qwen_embs)

            # Do not back-propagate an invalid loss.
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"\n Loss non valida al batch {i}, saltato.")
                continue

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss.item())
            train_loop.set_postfix(loss=f"{loss.item():.4f}")

        except Exception as e:
            # Deliberate best effort: one failing batch must not abort the epoch.
            print(f"\n Errore batch {i}: {e}")
            continue

    mean_loss = np.mean(train_losses) if train_losses else float('inf')
    return mean_loss


In [15]:
def epoch_validation(model, val_dataloader, loss_function):
    """
    Evaluate the model on the validation set and return the mean batch loss.

    Runs under ``torch.no_grad()`` with autocast on the module-level ``device``.
    Empty batches and batches with a NaN/Inf loss are skipped.

    Args:
        model: Model called as ``model(imgs, target_seq_len=...)``.
        val_dataloader: Iterable yielding ``(images, embeddings)`` batches, or
            ``None`` for a batch without valid samples.
        loss_function: Callable computing the loss from ``(pred, target)``.

    Returns:
        Mean of the recorded batch losses, or ``float('inf')`` if no batch contributed.
    """
    model.eval()
    val_losses = []
    val_loop = tqdm(val_dataloader, desc="Validation")
    num_val_batches = 0

    with torch.no_grad():
        for i, batch_data in enumerate(val_loop):
            # The collate function may return None (not a tuple) when every sample
            # of the batch was filtered out; indexing it would raise a TypeError.
            if batch_data is None or batch_data[0] is None:
                continue

            imgs, qwen_embs = batch_data
            imgs = imgs.to(device)
            qwen_embs = qwen_embs.to(device)

            # Counts the non-empty batches processed so far (used in the message below).
            num_val_batches += 1

            with torch.amp.autocast(device):
                pred = model(imgs, target_seq_len=qwen_embs.shape[1])
                loss = loss_function(pred, qwen_embs)

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Invalid loss at batch {num_val_batches}, skipping.")
                continue

            val_losses.append(loss.item())
            val_loop.set_postfix(loss=f"{loss.item():.4f}")

    mean_loss = np.mean(val_losses) if val_losses else float('inf')

    return mean_loss

### Evaluation

In [16]:
def epoch_evaluation(model, train_loss, val_loss, optimizer, scheduler, epoch, early_stopping):
    """
    End-of-epoch bookkeeping: record losses, step the scheduler, report the results
    and check early stopping (saving a checkpoint when it triggers).

    Appends to the module-level ``train_losses_history`` / ``val_losses_history``
    lists and reads the module-level ``CHECKPOINT_DIR`` and ``KEEP_BEST_N``.

    Args:
        model: The model being trained.
        train_loss: Training loss for the current epoch.
        val_loss: Validation loss for the current epoch.
        optimizer: Optimizer used for training (read for the learning rate).
        scheduler: Learning rate scheduler, stepped with ``val_loss``.
        epoch: Current (0-based) epoch number.
        early_stopping: EarlyStopping instance to check for stopping conditions.

    Returns:
        bool: True if training should stop, False otherwise.
        train_loss: Training loss for the current epoch.
        val_loss: Validation loss for the current epoch.
    """

    train_losses_history.append(train_loss)
    val_losses_history.append(val_loss)

    # Update scheduler (only the first parameter group is reported)
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']

    # Epoch results
    print(f"\n EPOCH RESULTS {epoch + 1}:")
    print(f"   Train Loss: {train_loss:.4f}")
    print(f"   Val Loss:   {val_loss:.4f}")
    print(f"   Learning Rate: {new_lr:.2e}")
    if old_lr != new_lr:
        print(f" LR reduced: {old_lr:.2e} → {new_lr:.2e}")

    # Early stopping check
    if early_stopping(val_loss, model):
        print(f"\n EARLY STOPPING ACTIVATED!")
        print(f"   Epoch: {epoch + 1}")
        print(f"   Best val_loss: {early_stopping.best_loss:.4f}")
        # Report the actual counter: after a resume it can exceed the patience.
        print(f"   Epochs without improvement: {early_stopping.counter}")

        # A snapshot exists only if an improvement was observed in this session.
        weights_restored = (
            early_stopping.restore_best_weights
            and early_stopping.best_weights is not None
        )

        if weights_restored:
            early_stopping.restore_weights(model)
            print(" Restored weights of the best model")
            # The model now holds the best weights: save them with the best loss.
            # NOTE: the file name carries the current epoch, not the epoch at
            # which these weights were obtained.
            save_checkpoint(model, optimizer, scheduler, epoch, train_loss,
                            early_stopping.best_loss, CHECKPOINT_DIR, is_best=True, keep_last_n=KEEP_BEST_N)
        else:
            # Nothing was restored, so these are the current-epoch weights: save a
            # regular checkpoint with their true loss instead of labelling them as best.
            save_checkpoint(model, optimizer, scheduler, epoch, train_loss,
                            val_loss, CHECKPOINT_DIR)

        return True, train_loss, val_loss

    print(f"Early stopping: {early_stopping.counter}/{early_stopping.patience}")
    return False, train_loss, val_loss

### Full training

In [None]:
training_completed = False
completion_reason = "unknown"

# Defined up front so the exception handlers below never hit an undefined name,
# even when the failure happens before the first epoch has been evaluated.
epoch = start_epoch - 1
last_evaluated_epoch = None
avg_train_loss = avg_val_loss = None


def save_loss_history():
    """Write the per-epoch loss lists to loss_history.json in CHECKPOINT_DIR."""
    loss_history = {
        'train_losses': train_losses_history,
        'val_losses': val_losses_history
    }
    with open(os.path.join(CHECKPOINT_DIR, 'loss_history.json'), 'w') as f:
        json.dump(loss_history, f, indent=2)


try:
    for epoch in range(start_epoch, MAX_EPOCHS):
        print(f"\n EPOCH {epoch + 1}/{MAX_EPOCHS}")
        print("="*50)

        # Training (the epoch functions return inf, not None, when no batch was usable)
        train_loss = epoch_training(model, train_loader, loss_function, optimizer, scaler)
        if train_loss is None or not np.isfinite(train_loss):
            print("No valid training loss for this epoch, skipping evaluation.")
            continue

        # Validation
        val_loss = epoch_validation(model, val_loader, loss_function)
        if val_loss is None or not np.isfinite(val_loss):
            print("No valid validation loss for this epoch, skipping evaluation.")
            continue

        # Evaluation
        stopped, avg_train_loss, avg_val_loss = epoch_evaluation(model, train_loss, val_loss, optimizer, scheduler, epoch, early_stopping)
        last_evaluated_epoch = epoch
        if stopped:
            # Persist the history so the final epoch is not lost on early stopping.
            save_loss_history()
            training_completed = True
            completion_reason = "early_stopping"
            break

        # Save the checkpoint once; on a new best it is also added to the best models.
        is_best = avg_val_loss < best_val_loss
        if is_best:
            best_val_loss = avg_val_loss
        save_checkpoint(
            model, optimizer, scheduler, epoch, avg_train_loss, avg_val_loss,
            CHECKPOINT_DIR, is_best=is_best, keep_last_n=KEEP_BEST_N
        )
        if is_best:
            print("New best model saved!")

        # Save plots every 5 epochs
        if (epoch + 1) % 5 == 0:
            save_training_plots(train_losses_history, val_losses_history, epoch, CHECKPOINT_DIR)

        # Save loss history
        save_loss_history()


    if not training_completed:
        completion_reason = "max_epochs_reached"
        training_completed = True

except KeyboardInterrupt:
    print(f"\nTraining manually interrupted at epoch {epoch + 1}")
    completion_reason = "manual_interruption"
    training_completed = True

except Exception as e:
    print(f"\nError during training: {e}")
    print("\nStack trace:")
    traceback.print_exc()
    if last_evaluated_epoch is not None:
        print("Saving current state...")
        try:
            # Label the checkpoint with the last fully evaluated epoch so that a
            # resume repeats the interrupted epoch instead of skipping it.
            save_checkpoint(model, optimizer, scheduler, last_evaluated_epoch, avg_train_loss, avg_val_loss, CHECKPOINT_DIR)
        except Exception as save_error:
            print(f"Could not save the current state: {save_error}")
    else:
        print("No completed epoch yet, nothing to save.")
    completion_reason = "error"
    training_completed = True


 EPOCH 1/100


Training:  13%|█▎        | 672/5100 [04:36<29:51,  2.47it/s, loss=2.6175]

## 6. Final results

In [None]:
# 1-based number of the last epoch reached (falls back to start_epoch if the loop never ran).
final_epoch = epoch + 1 if 'epoch' in locals() else start_epoch

separator = "=" * 50
print("\n" + separator)
print("TRAINING COMPLETED")
print(separator)
print(f"Effective epochs: {final_epoch}")
print(f"Best validation loss: {best_val_loss:.4f}")


# One message per known completion reason; unknown reasons print nothing.
completion_messages = {
    "early_stopping": f"Stopped with early stopping after {early_stopping.patience} epochs without improvement",
    "max_epochs_reached": "All scheduled epochs completed",
    "manual_interruption": "Training manually interrupted",
    "error": "Training stopped due to an error",
}
if completion_reason in completion_messages:
    print(completion_messages[completion_reason])


# Save final plots
if train_losses_history:
    save_training_plots(train_losses_history, val_losses_history, final_epoch-1, CHECKPOINT_DIR)

# Best models info
best_models_info_path = os.path.join(CHECKPOINT_DIR, 'best_models_info.json')
if os.path.exists(best_models_info_path):
    with open(best_models_info_path, 'r') as f:
        best_models = json.load(f)
    print(f"\n Best {len(best_models)} models saved:")
    for rank, model_info in enumerate(best_models, start=1):
        print(f"   {rank}. Epoch {model_info['epoch']+1}: {model_info['filename']}")
        print(f"      Val Loss: {model_info['val_loss']:.4f}")