In [26]:
MOVEMENT_DIM = 5
GENRE_DIM = 5
STYLE_DIM = 6
PRETRAINING = True

In [27]:


# get the images
from PIL import Image
import os, json, torch
from augmentation import augment_images_for_pretraining

MAIN_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
VISION_DEVICE = MAIN_DEVICE
try:  # Check if running in Colab
    from google.colab import drive
    IN_COLAB = True
    print("running in Google Colab")
    mount_path = '/content/drive'
    if not os.path.exists(mount_path):
        drive.mount(mount_path)
    imgs_directory_path = '/content/drive/MyDrive/ArtEmbed'
    pretraining_metadata = '/content/drive/MyDrive/ArtEmbed/wikiart_metadata_with_pretraining_groundtruth.json'

except ImportError:  # Not Colab
    from pathlib import Path
    IN_COLAB = False

    try:
        BASE_DIR = Path(__file__).resolve().parent  # works in scripts
        print("running from laptop, probably")
        VISION_DEVICE = "cpu" # not even GPU mem on laptop
    except NameError:
        BASE_DIR = Path.cwd()  # fallback for notebooks
        print("running from IDAS, probably")

    imgs_directory_path = BASE_DIR / "paintings"
    pretraining_metadata = BASE_DIR / "metadata" / "wikiart_metadata_with_pretraining_groundtruth.json"



running from IDAS, probably


In [28]:

# --- Import libraries ---
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch

# --- Load BLIP-2 model and processor ---
model_name = "Salesforce/blip2-flan-t5-xl"
local_model_path =  BASE_DIR / "blip2_model"

if os.path.exists(local_model_path):
    print("Loading model from local directory...")
    processor = Blip2Processor.from_pretrained(local_model_path, use_fast=True)
    print("Processor loaded")
    blip2 = Blip2ForConditionalGeneration.from_pretrained(local_model_path)
    print("Model loaded")
else:
    print("Downloading model from Hugging Face...")
    processor = Blip2Processor.from_pretrained(model_name, use_fast=True)
    blip2 = Blip2ForConditionalGeneration.from_pretrained(model_name)

    # Save to local directory for future use
    processor.save_pretrained(local_model_path)
    blip2.save_pretrained(local_model_path)

blip2.to(VISION_DEVICE)  # Load model on CPU first if on computer
print(f"model sent to {VISION_DEVICE}")

# Freeze vision encoder to save memory; we are not training the vision encoder
for param in blip2.vision_model.parameters():
    param.requires_grad = False

Loading model from local directory...
Processor loaded


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded
model sent to cuda


In [29]:
import os
import json
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader


class ImagePathDataset(Dataset):
    """Dataset that loads images on-the-fly from paths."""
    
    def __init__(self, image_paths, targets, processor):
        """
        Args:
            image_paths: List of full paths to images
            targets: List of target vectors (numpy arrays or lists)
            processor: BLIP2 processor for image preprocessing
        """
        self.image_paths = image_paths
        self.targets = targets
        self.processor = processor
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # Load image on-the-fly
        img = Image.open(self.image_paths[idx]).convert("RGB")
        
        # Process image
        pixel_values = self.processor(images=img, return_tensors="pt").pixel_values.squeeze(0)
        
        # Convert target to tensor
        target_tensor = torch.tensor(self.targets[idx], dtype=torch.float32)
        
        return pixel_values, target_tensor


def get_targets_from_metadata(image_ids, pretraining_metadata):
    """
    Get pretraining targets from metadata without loading images.
    
    Args:
        image_ids: List of image IDs
        pretraining_metadata: Dictionary with image metadata including pretraining_groundtruth
        
    Returns:
        targets: List of target vectors
        valid_indices: List of indices where targets were found
    """
    targets = []
    valid_indices = []
    
    for idx, img_id in enumerate(image_ids):
        metadata = pretraining_metadata.get(img_id, None)
        if metadata is not None:
            target = metadata.get("pretraining_groundtruth", None)
            if target is not None:
                targets.append(target)
                valid_indices.append(idx)
            else:
                print(f"Warning: No pretraining_groundtruth for image {img_id}")
        else:
            print(f"Warning: No metadata found for image {img_id}")
    
    print(f"Found {len(targets)} valid targets out of {len(image_ids)} images")
    return targets, valid_indices


def create_train_test_loaders(
    imgs_directory_path, 
    pretraining_metadata_path, 
    processor, 
    batch_size_train=16, 
    batch_size_test=32, 
    test_percentage=0.1,
    num_workers=0,  # Added for potential multiprocessing
):
    """
    Create train and test dataloaders that load images on-the-fly.
    
    Args:
        imgs_directory_path: Path to directory containing images
        pretraining_metadata_path: Path to JSON file with metadata
        processor: BLIP2 processor
        batch_size_train: Batch size for training
        batch_size_test: Batch size for testing
        test_percentage: Fraction of data to use for testing
        num_workers: Number of worker processes for data loading
        
    Returns:
        train_loader, test_loader: PyTorch DataLoaders
    """
    # --- Scan folder for images ---
    image_paths = []
    image_ids = []
    all_files = sorted(os.listdir(imgs_directory_path))
    
    for file_name in all_files:
        if file_name.lower().endswith((".jpg", ".jpeg", ".png")):
            path = os.path.join(imgs_directory_path, file_name)
            image_paths.append(path)
            image_ids.append(file_name.split("_")[0])
    
    print(f"Found {len(image_paths)} images. Sample ids: {image_ids[:10]}")
    
    # --- Load metadata ---
    with open(pretraining_metadata_path, 'r', encoding="utf-8") as f:
        metadata = json.load(f)
    print(f"Loaded metadata for {len(metadata)} paintings.")
    
    # --- Get targets and filter to valid images ---
    targets, valid_indices = get_targets_from_metadata(image_ids, metadata)
    
    # Filter image_paths to only those with valid targets
    image_paths = [image_paths[i] for i in valid_indices]
    image_ids = [image_ids[i] for i in valid_indices]
    
    if len(image_paths) == 0:
        raise ValueError("No valid images found with pretraining targets!")
    
    # --- Split train/test ---
    num_images = len(image_paths)
    num_test = int(num_images * test_percentage)
    
    # Random split
    indices = list(range(num_images))
    random.shuffle(indices)
    test_indices = set(indices[:num_test])
    
    train_paths = [image_paths[i] for i in range(num_images) if i not in test_indices]
    train_targets = [targets[i] for i in range(num_images) if i not in test_indices]
    
    test_paths = [image_paths[i] for i in range(num_images) if i in test_indices]
    test_targets = [targets[i] for i in range(num_images) if i in test_indices]
    
    print(f"Train: {len(train_paths)} images, Test: {len(test_paths)} images")
    
    # --- Create datasets ---
    train_dataset = ImagePathDataset(train_paths, train_targets, processor)
    test_dataset = ImagePathDataset(test_paths, test_targets, processor)
    
    # --- Create dataloaders ---
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size_train,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size_test,
        shuffle=False,
        pin_memory=True,
        num_workers=num_workers
    )
    
    return train_loader, test_loader


if PRETRAINING:
    train_loader, test_loader = create_train_test_loaders(
        imgs_directory_path=imgs_directory_path,
        pretraining_metadata_path=pretraining_metadata,
        processor=processor,
        batch_size_train=16,
        batch_size_test=32,
        test_percentage=0.1,
    )

Found 102 images. Sample ids: ['000001', '000002', '000003', '000004', '000005', '000006', '000007', '000008', '000009', '000010']
Loaded metadata for 6021 paintings.
Found 102 valid targets out of 102 images
Train: 92 images, Test: 10 images


In [30]:


def print_gpu_mem(prefix="GPU"):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2   # MB
        reserved = torch.cuda.memory_reserved() / 1024**2     # MB
        print(f"{prefix} Memory — Allocated: {allocated:.2f} MB | Reserved: {reserved:.2f} MB")
    else:
        print("CUDA not available")

In [37]:


from torch import nn

import torch
import torch.nn as nn


class BLIP2MultiHeadRegression(nn.Module):
    def __init__(self, blip2_model,
                 use_style_head=True,
                 train_qformer=False,
                 train_vision=False):
        super().__init__()

        # --- Core model ---
        self.blip2 = blip2_model
        self.use_style_head = use_style_head

        # --- Control what's trainable ---
        for param in self.blip2.vision_model.parameters():
            param.requires_grad = train_vision
        for param in self.blip2.qformer.parameters():
            param.requires_grad = train_qformer

        # --- Move modules to appropriate devices ---
        self.blip2.vision_model.to(VISION_DEVICE)
        self.blip2.qformer.to(MAIN_DEVICE)

        # query_tokens is an nn.Parameter → rewrap properly after moving
        self.blip2.query_tokens = nn.Parameter(
            self.blip2.query_tokens.to(MAIN_DEVICE)
        )

        # --- Config info ---
        num_query_tokens = blip2_model.config.num_query_tokens
        hidden_size = blip2_model.config.qformer_config.hidden_size
        feature_dim = num_query_tokens * hidden_size

        print(f"Num query tokens: {num_query_tokens}")
        print(f"Hidden size: {hidden_size}")
        print(f"Feature dim: {feature_dim}")
        print(f"Use style head: {use_style_head}")

        # --- Shared feature extraction ---
        self.shared_features = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2)
        ).to(MAIN_DEVICE)

        # --- Movement and Genre heads ---
        self.movement_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, MOVEMENT_DIM)
        ).to(MAIN_DEVICE)

        self.genre_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, GENRE_DIM)
        ).to(MAIN_DEVICE)

        # --- Style head (always defined, but only used if enabled) ---
        self.style_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, STYLE_DIM)
        ).to(MAIN_DEVICE)

    def forward(self, images, return_features=False):
        """
        Forward pass with optional CPU/GPU split for vision model.

        Args:
            images: [batch_size, 3, H, W]
            return_features: If True, also return shared features

        Returns:
            dict with keys: 'movement', 'genre', 'style', 'combined', optionally 'features'
        """

        # --- Vision encoding ---
        images_vision = images.to(VISION_DEVICE)

        if self.training and next(self.blip2.vision_model.parameters()).requires_grad:
            vision_outputs = self.blip2.vision_model(pixel_values=images_vision)
        else:
            with torch.no_grad():
                vision_outputs = self.blip2.vision_model(pixel_values=images_vision)

        image_embeds = vision_outputs.last_hidden_state.to(MAIN_DEVICE)  # move to GPU

        # --- Q-Former processing ---
        query_tokens = self.blip2.query_tokens.expand(images.shape[0], -1, -1).to(MAIN_DEVICE)
        image_attention_mask = torch.ones(image_embeds.shape[:-1], dtype=torch.long).to(MAIN_DEVICE)

        query_outputs = self.blip2.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )

        # --- Flatten Q-Former output ---
        query_hidden_states = query_outputs.last_hidden_state
        flattened = query_hidden_states.flatten(start_dim=1)

        # --- Shared features ---
        shared_features = self.shared_features(flattened)

        # --- Regression heads ---
        movement_scores = torch.sigmoid(self.movement_head(shared_features))
        genre_scores = torch.sigmoid(self.genre_head(shared_features))
        style_scores = torch.sigmoid(self.style_head(shared_features))

        outputs = {
            'movement': movement_scores,
            'genre': genre_scores,
            'style': style_scores,
            'combined': torch.cat([movement_scores, genre_scores, style_scores], dim=1)
        }

        if return_features:
            outputs['features'] = shared_features

        return outputs


class WeightedMultiHeadLoss(nn.Module):
    def __init__(self, movement_weight=1.0, genre_weight=0.7, style_weight=0.8, use_style=True):
        super().__init__()
        self.movement_weight = movement_weight
        self.genre_weight = genre_weight
        self.style_weight = style_weight
        self.use_style = use_style

    def forward(self, predictions, targets, confidences=None):
        """
        Args:
            predictions: dict with 'movement', 'genre', 'style'
            targets: tensor [batch, total_dim] (already prepared)
            confidences: dict with confidence scores (optional)
        """
        # Split targets using global dims
        movement_target = targets[:, :MOVEMENT_DIM]
        genre_target    = targets[:, MOVEMENT_DIM : MOVEMENT_DIM + GENRE_DIM]
        style_target    = targets[:, MOVEMENT_DIM + GENRE_DIM :]

        mse = nn.MSELoss(reduction='none')

        # Movement loss
        movement_loss = mse(predictions['movement'], movement_target)
        if confidences is not None and 'movement' in confidences:
            movement_loss = movement_loss * confidences['movement']
        movement_loss = movement_loss.mean() * self.movement_weight

        # Genre loss
        genre_loss = mse(predictions['genre'], genre_target)
        if confidences is not None and 'genre' in confidences:
            genre_loss = genre_loss * confidences['genre']
        genre_loss = genre_loss.mean() * self.genre_weight

        total_loss = movement_loss + genre_loss
        loss_dict = {'movement': movement_loss.item(), 'genre': genre_loss.item()}

        # Style loss
        if self.use_style:
            style_loss = mse(predictions['style'], style_target)
            if confidences is not None and 'style' in confidences:
                style_loss = style_loss * confidences['style']
            style_loss = style_loss.mean() * self.style_weight
            total_loss += style_loss
            loss_dict['style'] = style_loss.item()

        loss_dict['total'] = total_loss.item()
        return total_loss, loss_dict

In [31]:

def augment_batch(image_paths, targets, processor):
    """
    Load images, create flipped versions, and return pixel values.
    
    Args:
        image_paths: List of file paths to images
        targets: List of target vectors
        processor: BLIP2 processor
        
    Returns:
        pixel_values: Tensor of shape [batch_size*2, 3, H, W]
        doubled_targets: List with doubled targets
    """
    images = []
    doubled_targets = []
    
    for img_path, target in zip(image_paths, targets):
        # Load image
        img = Image.open(img_path).convert("RGB")
        
        # Add original
        images.append(img)
        doubled_targets.append(target)
        
        # Add flipped
        flipped_img = ImageOps.mirror(img)
        images.append(flipped_img)
        doubled_targets.append(target)
    
    # Process all images at once
    pixel_values = processor(images=images, return_tensors="pt").pixel_values
    
    return pixel_values, doubled_targets

In [32]:
import time
def train_epoch(model, dataloader, optimizer, criterion, device, processor):
    model.train()
    total_loss = 0.0
    start_time = time.time()
    
    for step, (image_paths, targets) in enumerate(dataloader):
        # Augment the batch
        print(f"image_paths, {image_paths}")
        pixel_values, doubled_targets = augment_batch(image_paths, targets, processor)
        
        pixel_values = pixel_values.to(device, non_blocking=True)
        targets_tensor = torch.tensor(doubled_targets, dtype=torch.float32).to(device, non_blocking=True)
        
        optimizer.zero_grad()
        predictions = model(pixel_values)
        loss, loss_dict = criterion(predictions, targets_tensor)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        time_elapsed = time.time() - start_time
        
        if step % 10 == 0:
            num_images_so_far = step * len(image_paths) * 2  # *2 because of augmentation
            print(f"Step {step}/{len(dataloader)} | Images: {num_images_so_far} | Time: {time_elapsed:.2f}s | Loss: {loss.item():.4f}")
            print_gpu_mem()
    
    num_batches = len(dataloader)
    total_images = num_batches * dataloader.batch_size * 2  # *2 for augmentation
    avg_loss = total_loss / num_batches
    epoch_time = time.time() - start_time
    
    print(f"Epoch complete | Avg Loss: {avg_loss:.4f} | Total images: {total_images} | Time: {epoch_time:.2f}s")
    
    return avg_loss

In [33]:
def test_epoch(model, dataloader, criterion, device, processor):
    model.eval()
    total_loss = 0.0
    
    with torch.no_grad():
        for image_paths, targets in dataloader:
            # Augment the batch (same as training for consistency)
            pixel_values, doubled_targets = augment_batch(image_paths, targets, processor)
            
            pixel_values = pixel_values.to(device, non_blocking=True)
            targets_tensor = torch.tensor(doubled_targets, dtype=torch.float32).to(device, non_blocking=True)
            
            predictions = model(pixel_values)
            loss, _ = criterion(predictions, targets_tensor)
            total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Validation complete | Avg Loss: {avg_loss:.4f}")
    return avg_loss

In [34]:
from datetime import datetime
import os
def save_progress(model, save_path):

    os.makedirs(save_path, exist_ok=True)
    time_str = datetime.now().strftime("%Y%m%d_%H%M%S")  # e.g., 20251013_170512
    checkpoint_file = os.path.join(save_path, f"model_{time_str}.pt")

    state_dict = {
        "shared_features": model.shared_features.state_dict(),
        "movement_head": model.movement_head.state_dict(),
        "genre_head": model.genre_head.state_dict(),
        "style_head": model.style_head.state_dict(),
    }

    # Optionally include Q-Former if it's being trained
    if any(p.requires_grad for p in model.blip2.qformer.parameters()):
        state_dict["qformer"] = model.blip2.qformer.state_dict()
        
    torch.save(state_dict, checkpoint_file)
    print(f"✅ Saved fine-tuned modules to: {checkpoint_file}")

In [40]:
def train_model(model, train_loader, val_loader, optimizer, criterion, device, processor, 
                num_epochs=10, save_path=None, scheduler=None, early_stopping_patience=None):
    """
    Train the model with optional learning rate scheduling and early stopping.
    
    Args:
        model: Model to train
        train_loader: Training DataLoader
        val_loader: Validation DataLoader
        optimizer: Optimizer
        criterion: Loss function
        device: Device to train on
        processor: BLIP2 processor for augmentation
        num_epochs: Number of epochs to train
        save_path: Path to save checkpoints
        scheduler: Optional learning rate scheduler
        early_stopping_patience: Stop if validation loss doesn't improve for N epochs (None = disabled)
    
    Returns:
        history: Dictionary with training history
    """
    print("Starting training")
    
    history = {
        "train_loss": [],
        "val_loss": [],
        "learning_rates": []
    }
    
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    
    for epoch in range(1, num_epochs + 1):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch}/{num_epochs}")
        if scheduler:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Learning rate: {current_lr:.2e}")
            history["learning_rates"].append(current_lr)
        print(f"{'='*60}")
        
        # Training
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, processor)
        history["train_loss"].append(train_loss)
        
        # Validation
        if val_loader is not None:
            val_loss = test_epoch(model, val_loader, criterion, device, processor)
            history["val_loss"].append(val_loss)
            
            # Check for improvement
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
                if save_path is not None:
                    best_model_path = save_path / "best_model.pt"
                    torch.save(model.state_dict(), best_model_path)
                    print(f"✓ New best model saved! Val loss: {val_loss:.4f}")
            else:
                epochs_without_improvement += 1
                print(f"No improvement for {epochs_without_improvement} epoch(s)")
            
            # Early stopping
            if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
                print(f"\nEarly stopping triggered after {epoch} epochs")
                print(f"Best validation loss: {best_val_loss:.4f}")
                break
            
            # Learning rate scheduling
            if scheduler:
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()
        
        # Save checkpoint
        if save_path is not None:
            checkpoint_path = save_path / f"checkpoint_epoch_{epoch}.pt"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss if val_loader else None,
                'history': history
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")
    
    print("\n" + "="*60)
    print("Training complete")
    if val_loader is not None:
        print(f"Best validation loss: {best_val_loss:.4f}")
    print("="*60)
    
    return history


In [42]:

def pretrain_model():
    """
    Pretrain the model without style head.
    """
    print("="*60)
    print("PRETRAINING MODE (no style head)")
    print("="*60)
    
    # Model setup
    pretrain_model = BLIP2MultiHeadRegression(
        blip2,
        use_style_head=False,
        train_qformer=False,
        train_vision=False
    )
    
    # Loss and optimizer
    pretrain_criterion = WeightedMultiHeadLoss(
        movement_weight=1.0,
        genre_weight=0.7,
        use_style=False
    ).to(MAIN_DEVICE)
    
    optimizer = torch.optim.AdamW(
        pretrain_model.parameters(),
        lr=1e-4,
        weight_decay=0.01  # Added weight decay for regularization
    )
    
    # Learning rate scheduler (optional but recommended)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3,
        min_lr=1e-6
    )
    
    # Save directory
    save_dir = BASE_DIR / "checkpoints"
    save_dir.mkdir(exist_ok=True, parents=True)
    
    # Train
    history = train_model(
        model=pretrain_model,
        train_loader=train_loader,
        val_loader=test_loader,
        optimizer=optimizer,
        criterion=pretrain_criterion,
        device=MAIN_DEVICE,
        processor=processor,
        num_epochs=10,
        save_path=save_dir,
        scheduler=scheduler,
        early_stopping_patience=5  # Stop if no improvement for 5 epochs
    )
    
    # Save final history
    import json
    history_path = save_dir / "training_history.json"
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"\nTraining history saved to {history_path}")
    
    return history


if PRETRAINING:
    pretrain_model()

PRETRAINING MODE (no style head)
Num query tokens: 32
Hidden size: 768
Feature dim: 24576
Use style head: False
Starting training

Epoch 1/10
Learning rate: 1.00e-04


AttributeError: 'Tensor' object has no attribute 'read'

In [None]:


import os
import glob
import torch
from transformers import Blip2Processor
from augmentation import augment_annotated_images

# --- Global variables for lazy loading ---
_model, _processor = None, None

def get_latest_checkpoint(checkpoint_dir="./checkpoints"):
    """Return the latest checkpoint path or None if none exist."""
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_epoch_*.pt"))
    if not checkpoint_files:
        return None
    # Sort by epoch number
    checkpoint_files.sort(key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split("_")[-1]))
    return checkpoint_files[-1]

import os
from datetime import datetime
import torch

BASE_DIR = "/path/to/your/project"  # replace with your BASE_DIR

def save_model_checkpoint(model):
    checkpoint_dir = os.path.join(BASE_DIR, ".checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Generate abbreviated timestamp (YYMMDD_HHMMSS)
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")

    # Build checkpoint path
    checkpoint_path = os.path.join(checkpoint_dir, f"model_{timestamp}.pt")

    # Save model state
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved to {checkpoint_path}")


def initialize_model_for_webaccess():
    """
    Initialize the BLIP2 multi-head regression model and processor.
    Loads the latest checkpoint if available.
    """
    model = BLIP2MultiHeadRegression(
        blip2,
        use_style_head=True,
        train_qformer=True,
        train_vision=False
    )

    latest_ckpt = get_latest_checkpoint()
    if latest_ckpt is not None:
        model.load_state_dict(torch.load(latest_ckpt, map_location="cpu"))
        print(f"Loaded model weights from {latest_ckpt}")
    else:
        print("No checkpoint found, using untrained weights.")

    model.eval()

    processor = Blip2Processor.from_pretrained(local_model_path, use_fast=True)
    return model, processor

def get_model_and_processor():
    """
    Lazy-load the model and processor.
    """
    global _model, _processor
    if _model is None or _processor is None:
        _model, _processor = initialize_model_for_webaccess()
        print(f"Model and processor ready")
    return _model, _processor

def forward_images(images):
    model, processor = get_model_and_processor()
    model.eval()

    # Process all images as a batch
    inputs = processor(images=images, return_tensors="pt").pixel_values

    with torch.no_grad():
        outputs = model(inputs)

    embeddings = outputs["combined"].cpu().tolist()
    print(f"Forward pass completed on {len(images)} images")
    return embeddings

def backward_single_image(image, target, lr=1e-5):
    """
    Perform a single training step on one image.
    """
    model, processor = get_model_and_processor()
    criterion = WeightedMultiHeadLoss(movement_weight=1.0, genre_weight=1.0, use_style=True).to(MAIN_DEVICE)

    augmented_images, augmented_targets = augment_annotated_images([image], [target])
    print(f"Augmented to {len(augmented_images)} images for training")
    

    model.train()
    inputs = processor(images=augmented_images, return_tensors="pt").pixel_values.to(MAIN_DEVICE)
    target_tensor = torch.tensor(augmented_targets, dtype=torch.float32).to(MAIN_DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    optimizer.zero_grad()

    outputs = model(inputs)
    print("Model outputs obtained", outputs.keys())
    loss, loss_dict = criterion(outputs, target_tensor)

    print("Backward pass with loss:", loss.item(), loss_dict)

    loss.backward()
    optimizer.step()


    return loss.item(), loss_dict