In [3]:
# =============================================================================
# CELL 1: Environment Setup and Installations
# =============================================================================

# Check GPU availability
!nvidia-smi

Wed Jul  2 15:44:26 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.86.15              Driver Version: 570.86.15      CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA H200                    On  |   00000000:19:00.0 Off |                    0 |
| N/A   44C    P0             80W /  700W |       1MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
# =============================================================================
# CELL 2: Imports and Basic Setup
# =============================================================================

import os, random, math, json
from pathlib import Path
from typing import List
import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import convert_image_dtype
from scipy.stats import spearmanr, pearsonr

from transformers import AutoModel, AutoVideoProcessor
from sentence_transformers import SentenceTransformer
import decord
decord.bridge.set_bridge('torch')
from decord import VideoReader

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def rank_corr(a: List[float], b: List[float]):
    return spearmanr(a, b).correlation, pearsonr(a, b)[0]

set_seed()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Running on', device)
print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

Running on cuda
GPU Memory: 150.1 GB


In [6]:
# =============================================================================
# CELL 3: Data Configuration and Analysis
# =============================================================================

# Data paths
DATA_DIR   = "TaobaoAIGC/data"
TRAIN_CSV  = f"{DATA_DIR}/train/labels/train_labels.csv"
VAL_CSV    = f"{DATA_DIR}/val/labels/val_labels.csv"  # This will be our test set for submission
TRAIN_VID  = f"{DATA_DIR}/train/videos"
VAL_VID    = f"{DATA_DIR}/val/videos"

# Analyze data structure
print("Analyzing data structure...")
train_df = pd.read_csv(TRAIN_CSV)
val_df = pd.read_csv(VAL_CSV)

print(f"Training data columns: {train_df.columns.tolist()}")
print(f"Training data shape: {train_df.shape}")
print(f"Validation data columns: {val_df.columns.tolist()}")
print(f"Validation data shape: {val_df.shape}")

# Check if validation has ground truth
has_mos_in_val = 'Overall_MOS' in val_df.columns
print(f"Validation has ground truth MOS: {has_mos_in_val}")

if has_mos_in_val:
    print("Using 80-20 train-val split from training data")
else:
    print("Validation set has no ground truth - will be used for final submission")

Analyzing data structure...
Training data columns: ['Prompt', 'Overall_MOS', 'Traditional_MOS', 'Alignment_MOS', 'Aesthetic_MOS', 'Temporal_MOS', 'video_name']
Training data shape: (4000, 7)
Validation data columns: ['Prompt', 'video_name']
Validation data shape: (500, 2)
Validation has ground truth MOS: False
Validation set has no ground truth - will be used for final submission


In [7]:
# =============================================================================
# CELL 4: Updated Dataset Class for Both Train and Inference
# =============================================================================

class TaobaoVDDataset(Dataset):
    MOS_COLS = ['Traditional_MOS','Alignment_MOS','Aesthetic_MOS','Temporal_MOS','Overall_MOS']
    
    def __init__(self, csv_file, video_dir, processor, num_frames=64, mode='train'):
        self.df = pd.read_csv(csv_file)
        self.video_dir = Path(video_dir)
        self.processor = processor
        self.num_frames = num_frames
        self.mode = mode
        
        # Check if we have ground truth labels
        self.has_labels = all(col in self.df.columns for col in self.MOS_COLS)
        print(f"Dataset mode: {mode}, Has labels: {self.has_labels}, Samples: {len(self.df)}")
    
    def _sample(self, path):
        vr = VideoReader(str(path))
        idx = np.linspace(0, len(vr)-1, self.num_frames).astype(int)
        idx = np.clip(idx, 0, len(vr)-1)
        clip = vr.get_batch(idx).permute(0,3,1,2)  # N,C,H,W
        return convert_image_dtype(clip, torch.float32)
    
    def __len__(self): 
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        frames = self._sample(self.video_dir / row["video_name"])
        
        result = {
            "clip": frames, 
            "prompt": row["Prompt"], 
            "video_name": row["video_name"]
        }
        
        if self.has_labels:
            # Training mode - include labels
            labels = (
                pd.to_numeric(row[self.MOS_COLS], errors="coerce")
                  .fillna(row["Overall_MOS"] if "Overall_MOS" in row else 3.0)
                  .astype(np.float32)
                  .values
            )
            result["labels"] = torch.tensor(labels, dtype=torch.float32)
        else:
            # Inference mode - no labels available
            result["labels"] = torch.zeros(5, dtype=torch.float32)  # dummy labels
            
        return result

In [8]:
# =============================================================================
# CELL 5: OPTIMIZED Memory-Efficient Collate Function for H200
# =============================================================================

class OptimizedGPUCollate:
    """Ultra memory-efficient collate with chunked processing and immediate cleanup"""
    def __init__(self, vproc, tenc, device='cuda', max_frames=64):
        self.vproc = vproc
        self.tenc = tenc
        self.device = device
        self.max_frames = max_frames

    def __call__(self, batch):
        prompts = [b["prompt"] for b in batch]
        video_names = [b["video_name"] for b in batch]
        labels = torch.stack([b["labels"] for b in batch])

        # Process videos with aggressive memory management
        processed_videos = []
        
        for i, b in enumerate(batch):
            clip = b["clip"]  # Shape: (T, C, H, W)
            
            try:
                # Limit frames to prevent OOM
                if clip.shape[0] > self.max_frames:
                    indices = torch.linspace(0, clip.shape[0]-1, self.max_frames).long()
                    clip = clip[indices]
                
                # Convert to numpy immediately to save GPU memory
                frames_np = []
                for frame_idx in range(clip.shape[0]):
                    frame = clip[frame_idx].permute(1, 2, 0)  # HWC
                    frame_np = (frame.cpu().numpy() * 255).astype(np.uint8)
                    frames_np.append(frame_np)
                
                # Clear clip tensor immediately
                del clip
                
                # Process with V-JEPA2 using CPU inputs
                with torch.no_grad():
                    processed = self.vproc(frames_np, return_tensors="pt")
                    
                    # Extract video tensor and move to device with proper dtype
                    if "pixel_values_videos" in processed:
                        video_tensor = processed["pixel_values_videos"][0]
                    elif "pixel_values" in processed:
                        video_tensor = processed["pixel_values"][0]
                    else:
                        key = next(iter(processed))
                        video_tensor = processed[key][0]
                    
                    # Convert to half precision to save memory
                    video_tensor = video_tensor.half()
                
                # Clear intermediate data
                del processed, frames_np
                processed_videos.append(video_tensor)
                
                # Aggressive cleanup after each video
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    
            except Exception as e:
                print(f"Error processing video {i}: {e}")
                # Create smaller dummy tensor to maintain batch structure
                dummy = torch.zeros((3, min(self.max_frames, 32), 224, 224), dtype=torch.half)
                processed_videos.append(dummy)
                torch.cuda.empty_cache()

        # Stack videos efficiently
        try:
            pixel_values_videos = torch.stack(processed_videos)
        except RuntimeError as e:
            print(f"Error stacking videos: {e}")
            # Fallback: process with even smaller dummy tensors
            pixel_values_videos = torch.zeros((len(batch), 3, 32, 224, 224), dtype=torch.half)
        
        # Clear processed videos list
        del processed_videos

        # Text processing with memory optimization
        with torch.no_grad():
            # Process text in smaller chunks if batch is large
            if len(prompts) > 4:
                text_embeddings = []
                for i in range(0, len(prompts), 2):
                    chunk = prompts[i:i+2]
                    chunk_emb = self.tenc.encode(
                        chunk, 
                        convert_to_tensor=True, 
                        normalize_embeddings=True,
                        device=self.device,
                        batch_size=len(chunk)
                    )
                    text_embeddings.append(chunk_emb)
                text_emb = torch.cat(text_embeddings, dim=0)
                del text_embeddings
            else:
                text_emb = self.tenc.encode(
                    prompts, 
                    convert_to_tensor=True, 
                    normalize_embeddings=True,
                    device=self.device,
                    batch_size=len(prompts)
                )

        # Final cleanup
        torch.cuda.empty_cache()

        return {
            "pixel_values_videos": pixel_values_videos,
            "text_emb": text_emb,
            "labels": labels,
            "video_names": video_names,
        }

In [9]:
# =============================================================================
# CELL 6: IMPROVED Model Architecture with Strategic Layer Freezing
# =============================================================================
VJEPA_ID = "facebook/vjepa2-vitg-fpc64-384-ssv2"   # ViT-G for H200
TEXT_ID  = "BAAI/bge-large-en-v1.5"

class OptimizedMOSHead(nn.Module):
    def __init__(self, dv, dt, h=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dv + dt),
            nn.Linear(dv + dt, h), 
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(h, h//2), 
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(h//2, 5)  # 4 sub-MOS + Overall
        )
    
    def forward(self, v, t):
        return self.net(torch.cat([v, t], dim=-1))

class VQualAModel(nn.Module):
    def __init__(self, device='cuda'):
        super().__init__()
        
        # Video processor
        self.vproc = AutoVideoProcessor.from_pretrained(VJEPA_ID)
        
        # Video encoder - KEEP ALL IN FP32 to avoid gradient issues
        self.venc = AutoModel.from_pretrained(
            VJEPA_ID,
            torch_dtype=torch.float32,  # Use FP32 for proper gradients
            output_hidden_states=True,
            attn_implementation="sdpa",
        )
        
        # IMPROVED STRATEGIC LAYER FREEZING - Much more aggressive
        frozen_count = 0
        trainable_count = 0
        total_layers = 0
        
        # Count total layers first
        for name, p in self.venc.named_parameters():
            if "encoder.layer." in name:
                layer_match = name.split("encoder.layer.")[1].split(".")[0]
                if layer_match.isdigit():
                    total_layers = max(total_layers, int(layer_match) + 1)
        
        print(f"Total transformer layers detected: {total_layers}")
        
        # Freeze bottom 85% of layers + embeddings + pooler
        freeze_until_layer = int(total_layers * 0.85)  # Freeze bottom 85%
        print(f"Freezing layers 0-{freeze_until_layer-1}, training layers {freeze_until_layer}-{total_layers-1}")
        
        for name, p in self.venc.named_parameters():
            should_freeze = False
            
            # Always freeze embeddings and pooler
            if "embeddings" in name or "pooler" in name:
                should_freeze = True
            
            # Freeze bottom 85% of transformer layers
            elif "encoder.layer." in name:
                layer_match = name.split("encoder.layer.")[1].split(".")[0]
                if layer_match.isdigit():
                    layer_num = int(layer_match)
                    if layer_num < freeze_until_layer:
                        should_freeze = True
            
            # Apply freezing
            if should_freeze:
                p.requires_grad = False
                frozen_count += 1
            else:
                p.requires_grad = True
                trainable_count += 1

        print(f"IMPROVED VJEPA2 Layer Status:")
        print(f"  Frozen parameters: {frozen_count}")
        print(f"  Trainable parameters: {trainable_count}")
        print(f"  Trainable ratio: {trainable_count/(frozen_count+trainable_count)*100:.1f}%")
        print(f"  Memory savings: ~{(frozen_count/(frozen_count+trainable_count))*100:.0f}% reduction in gradient computation")

        # DISABLE gradient checkpointing to fix gradient flow
        if hasattr(self.venc, 'gradient_checkpointing_enable'):
            self.venc.gradient_checkpointing_disable()
        
        # Text encoder on GPU
        self.tenc = SentenceTransformer(TEXT_ID, device=device)

        # Get dimensions
        dv = self.venc.config.hidden_size
        dt = self.tenc.get_sentence_embedding_dimension()
        
        # Prediction head
        self.head = OptimizedMOSHead(dv, dt, h=512)
        print(f"Model dimensions: Video={dv}, Text={dt}")

    def forward(self, pixel_values_videos, text_emb):
        """
        Clean forward pass without gradient checkpointing
        """
        # Keep everything in FP32 for stable gradients
        pixel_values_videos = pixel_values_videos.to(self.venc.device, dtype=torch.float32)
        
        # Simple forward pass without gradient checkpointing tricks
        outputs = self.venc(pixel_values_videos=pixel_values_videos, output_hidden_states=True)
        
        # Get CLS token
        cls_token = outputs.last_hidden_state[:, 0]  # Already FP32
        
        # Ensure text embeddings match
        text_emb = text_emb.to(cls_token.device, dtype=cls_token.dtype)
        
        # MOS prediction
        mos_scores = self.head(cls_token, text_emb)
        
        return mos_scores

In [10]:
# =============================================================================
# CELL 7: Training Functions with Hybrid Loss Function
# =============================================================================

import gc
import os
import wandb
import torch.nn.functional as F

def rank_corr(a: List[float], b: List[float]):
    from scipy.stats import spearmanr, pearsonr
    if len(a) == 0 or len(b) == 0:
        return 0.0, 0.0
    return spearmanr(a, b).correlation, pearsonr(a, b)[0]

def ultra_memory_cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

import numpy as np

class AdaptiveLossManager:
    def __init__(self, initial_alpha=0.7, initial_beta=0.3):
        self.alpha = initial_alpha
        self.beta = initial_beta
        self.mae_history = []
        self.ranking_history = []
        self.adaptation_rate = 0.1
        
    def update_weights(self, mae_loss, ranking_loss):
        self.mae_history.append(mae_loss)
        self.ranking_history.append(ranking_loss)
        
        if len(self.mae_history) > 10:
            self.mae_history = self.mae_history[-10:]
            self.ranking_history = self.ranking_history[-10:]
        
        if len(self.mae_history) >= 3:
            mae_trend = np.mean(self.mae_history[-3:]) / np.mean(self.mae_history[-6:-3]) if len(self.mae_history) >= 6 else 1.0
            ranking_trend = np.mean(self.ranking_history[-3:]) / np.mean(self.ranking_history[-6:-3]) if len(self.ranking_history) >= 6 else 1.0
            
            if mae_trend > 1.1 and ranking_trend < 0.9:
                self.alpha = max(0.5, self.alpha - self.adaptation_rate)
                self.beta = min(0.5, self.beta + self.adaptation_rate)
            elif ranking_trend > 1.1 and mae_trend < 0.9:
                self.alpha = min(0.8, self.alpha + self.adaptation_rate)
                self.beta = max(0.2, self.beta - self.adaptation_rate)
        
        total = self.alpha + self.beta
        self.alpha = self.alpha / total
        self.beta = self.beta / total
    
    def get_weights(self):
        return self.alpha, self.beta

loss_manager = AdaptiveLossManager(initial_alpha=0.7, initial_beta=0.3)

def hybrid_loss_fn(pred, target, loss_manager=None, epoch=0):
    device = pred.device
    batch_size = pred.shape[0]
    
    smooth_l1_loss = F.smooth_l1_loss(pred, target, beta=0.1)
    
    ranking_loss = torch.tensor(0.0, device=device)
    margin = 0.2
    
    if batch_size > 1:
        total_pairs = 0
        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                for dim in range(pred.shape[1]):
                    pred_diff = pred[i, dim] - pred[j, dim]
                    target_diff = target[i, dim] - target[j, dim]
                    
                    if target_diff > 0.1:
                        ranking_loss += torch.clamp(margin - pred_diff, min=0)
                    elif target_diff < -0.1:
                        ranking_loss += torch.clamp(margin + pred_diff, min=0)
                    
                    total_pairs += 1
        
        if total_pairs > 0:
            ranking_loss = ranking_loss / total_pairs
    
    scale_weights = torch.where(target < 2.5, 1.5,
                               torch.where(target > 4.0, 1.5, 1.0))
    scale_loss = F.mse_loss(pred * scale_weights, target * scale_weights)
    
    if loss_manager is not None:
        alpha, beta = loss_manager.get_weights()
        loss_manager.update_weights(smooth_l1_loss.item(), ranking_loss.item())
    else:
        alpha, beta = 0.7, 0.3
    
    gamma = 0.1
    total_loss = alpha * smooth_l1_loss + beta * ranking_loss + gamma * scale_loss
    
    return total_loss, {
        'total_loss': total_loss.item(),
        'mae_loss': smooth_l1_loss.item(),
        'ranking_loss': ranking_loss.item(),
        'scale_loss': scale_loss.item(),
        'alpha': alpha,
        'beta': beta
    }

def train_epoch(model, loader, optimizer, scaler, accumulation_steps=16, epoch=0):
    model.train()
    total_loss = 0
    total_mae = 0
    total_ranking = 0
    num_batches = 0
    
    optimizer.zero_grad(set_to_none=True)
    
    for i, batch in enumerate(loader):
        try:
            with torch.amp.autocast('cuda', dtype=torch.float16):
                outputs = model(batch['pixel_values_videos'], batch['text_emb'])
                loss, loss_components = hybrid_loss_fn(outputs, batch['labels'].to(outputs.device), loss_manager, epoch)
                loss = loss / accumulation_steps
            
            del outputs
            scaler.scale(loss).backward()
            
            if (i + 1) % accumulation_steps == 0:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                
                current_lrs = [group['lr'] for group in optimizer.param_groups]
                wandb.log({
                    "train/gradient_norm": grad_norm.item(),
                    "train/mae_loss": loss_components['mae_loss'],
                    "train/ranking_loss": loss_components['ranking_loss'],
                    "train/lr_text": current_lrs[0] if len(current_lrs) > 0 else 0,
                    "train/lr_video": current_lrs[1] if len(current_lrs) > 1 else 0,
                    "train/lr_head": current_lrs[2] if len(current_lrs) > 2 else 0,
                    "train/step": epoch * len(loader) + i
                })
                
                ultra_memory_cleanup()
            
            total_loss += loss_components['total_loss']
            total_mae += loss_components['mae_loss']
            total_ranking += loss_components['ranking_loss']
            num_batches += 1
            
            del loss
            
            if i % 2 == 0:
                ultra_memory_cleanup()
            
            if i % 50 == 0:
                allocated = torch.cuda.memory_allocated() / 1e9
                avg_loss = total_loss / max(num_batches, 1)
                current_lrs = [group['lr'] for group in optimizer.param_groups]
                print(f"    Batch {i}/{len(loader)}, Loss: {avg_loss:.4f}, LRs: {current_lrs[0]:.2e}/{current_lrs[1]:.2e}/{current_lrs[2]:.2e}, Memory: {allocated:.1f}GB")
                
                if allocated > 120:
                    ultra_memory_cleanup()
                    
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                optimizer.zero_grad(set_to_none=True)
                ultra_memory_cleanup()
                wandb.log({"train/oom_errors": 1, "train/step": epoch * len(loader) + i})
                continue
            else:
                raise e
    
    if num_batches % accumulation_steps != 0:
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
    
    ultra_memory_cleanup()
    
    avg_loss = total_loss / max(num_batches, 1) if num_batches > 0 else 0.0
    
    current_lrs = [group['lr'] for group in optimizer.param_groups]
    wandb.log({
        "train/epoch_loss": avg_loss,
        "train/epoch_mae": total_mae / max(num_batches, 1),
        "train/epoch_ranking": total_ranking / max(num_batches, 1),
        "train/epoch": epoch,
        "train/batches_processed": num_batches,
        "train/final_lr_text": current_lrs[0] if len(current_lrs) > 0 else 0,
        "train/final_lr_video": current_lrs[1] if len(current_lrs) > 1 else 0,
        "train/final_lr_head": current_lrs[2] if len(current_lrs) > 2 else 0,
    })
    
    return avg_loss

def evaluate(model, loader, epoch=0):
    model.eval()
    
    all_predictions = [[] for _ in range(5)]
    all_ground_truth = [[] for _ in range(5)]
    eval_loss = 0
    num_eval_batches = 0
    
    with torch.no_grad():
        for i, batch in enumerate(loader):
            try:
                with torch.amp.autocast('cuda', dtype=torch.float16):
                    outputs = model(batch['pixel_values_videos'], batch['text_emb'])
                    loss, _ = hybrid_loss_fn(outputs, batch['labels'].to(outputs.device), loss_manager, epoch)
                
                for dim in range(5):
                    all_ground_truth[dim].extend(batch['labels'][:, dim].cpu().tolist())
                    all_predictions[dim].extend(outputs[:, dim].cpu().tolist())
                
                eval_loss += loss.item()
                num_eval_batches += 1
                
                del outputs, loss
                
                if i % 10 == 0:
                    ultra_memory_cleanup()
                
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    ultra_memory_cleanup()
                    wandb.log({"eval/oom_errors": 1, "eval/epoch": epoch})
                    continue
                else:
                    raise e
    
    total_srocc = 0
    total_plcc = 0
    
    for dim in range(5):
        if len(all_ground_truth[dim]) > 0:
            srocc, plcc = rank_corr(all_ground_truth[dim], all_predictions[dim])
            total_srocc += srocc if not np.isnan(srocc) else 0
            total_plcc += plcc if not np.isnan(plcc) else 0
    
    final_score = (total_srocc + total_plcc) / 10
    avg_eval_loss = eval_loss / num_eval_batches if num_eval_batches > 0 else 0.0
    
    wandb.log({
        "eval/loss": avg_eval_loss,
        "eval/final_score": final_score,
        "eval/total_srocc": total_srocc,
        "eval/total_plcc": total_plcc,
        "eval/epoch": epoch,
        "eval/num_samples": len(all_ground_truth[0])
    })
    
    return final_score

In [11]:
# =============================================================================
# CELL 8: Data Preparation Strategy
# =============================================================================

# Check if validation has ground truth
val_df = pd.read_csv(VAL_CSV)
has_val_labels = 'Overall_MOS' in val_df.columns

if has_val_labels:
    # Case 1: Validation has labels - use it for validation
    print("Validation set has ground truth - using for model validation")
    
    # Use all training data for training
    train_dataset = TaobaoVDDataset(TRAIN_CSV, TRAIN_VID, None, mode='train')
    val_dataset = TaobaoVDDataset(VAL_CSV, VAL_VID, None, mode='val')
    
else:
    # Case 2: Validation has no labels - split training data
    print("Validation set has no ground truth - splitting training data")
    
    train_df = pd.read_csv(TRAIN_CSV)
    
    # 80-20 split
    split_idx = int(0.8 * len(train_df))
    train_subset = train_df.iloc[:split_idx]
    val_subset = train_df.iloc[split_idx:]
    
    # Save temporary CSV files for the split
    train_subset.to_csv('temp_train.csv', index=False)
    val_subset.to_csv('temp_val.csv', index=False)
    
    train_dataset = TaobaoVDDataset('temp_train.csv', TRAIN_VID, None, mode='train')
    val_dataset = TaobaoVDDataset('temp_val.csv', TRAIN_VID, None, mode='val')

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")


Validation set has no ground truth - splitting training data
Dataset mode: train, Has labels: True, Samples: 3200
Dataset mode: val, Has labels: True, Samples: 800
Training samples: 3200
Validation samples: 800


In [12]:
# =============================================================================
# CELL 9: Model Initialization with Clean Checkpoint Loading
# =============================================================================

def load_clean_checkpoint_and_continue_training(checkpoint_path, device='cuda'):
    """Load clean checkpoint from model_final_clean.pt and prepare for continued training"""
    print(f"Loading clean checkpoint from: {checkpoint_path}")
    
    # Initialize fresh model first
    model = VQualAModel(device=device).to(device)
    
    # Load clean checkpoint (following the same pattern as evaluation code)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    best_score = checkpoint['best_score']
    
    print(f"✅ Clean model loaded with score: {best_score:.4f}")
    
    # Extract epoch info for continued training
    previous_epoch = checkpoint.get('epoch', 5)  # Get the epoch from clean checkpoint
    start_epoch = previous_epoch + 1  # Continue from next epoch
    
    print(f"Previous training completed at epoch: {previous_epoch}")
    print(f"Will resume training from epoch: {start_epoch}")
    
    return model, start_epoch, best_score

set_seed(42)

# Load clean checkpoint (following the same pattern as your evaluation code)
CHECKPOINT_PATH = "model_final_clean.pt"
model, start_epoch, previous_best_score = load_clean_checkpoint_and_continue_training(CHECKPOINT_PATH, device)

# Log model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model loaded on {device}")
print(f"Video encoder: {VJEPA_ID}")
print(f"Text encoder: {TEXT_ID}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"GPU Memory after model load: {torch.cuda.memory_allocated()/1e9:.1f} GB")

Loading clean checkpoint from: model_final_clean.pt
Total transformer layers detected: 40
Freezing layers 0-33, training layers 34-39
IMPROVED VJEPA2 Layer Status:
  Frozen parameters: 549
  Trainable parameters: 294
  Trainable ratio: 34.9%
  Memory savings: ~65% reduction in gradient computation
Model dimensions: Video=1408, Text=1024
✅ Clean model loaded with score: 0.6256
Previous training completed at epoch: 4
Will resume training from epoch: 5
Model loaded on cuda
Video encoder: facebook/vjepa2-vitg-fpc64-384-ssv2
Text encoder: BAAI/bge-large-en-v1.5
Total parameters: 1,371,080,325
Trainable parameters: 509,865,349
GPU Memory after model load: 5.6 GB


In [None]:
# =============================================================================
# CELL 9.5: Updated Configuration for Continued Training
# =============================================================================

import wandb

torch.cuda.empty_cache()
gc.collect()

# Configuration - Training for 5 more epochs
BATCH_SIZE = 6
NUM_FRAMES = 64
GRADIENT_ACCUMULATION_STEPS = 32
EPOCHS = 5  # Additional epochs to train
LR = 2e-4 

# Initialize W&B with continued training info
wandb.init(
    project="vquala-h200-optimization",
    name=f"vjepa2-continued-epoch{start_epoch}-bs{BATCH_SIZE}-frames{NUM_FRAMES}",
    config={
        "batch_size": BATCH_SIZE,
        "num_frames": NUM_FRAMES,
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
        "effective_batch_size": BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
        "epochs": EPOCHS,
        "additional_epochs": EPOCHS,
        "starting_epoch": start_epoch,
        "previous_best_score": previous_best_score,
        "learning_rate": LR,
        "video_encoder": VJEPA_ID,
        "text_encoder": TEXT_ID,
        "resolution": "384x384",
        "precision": "mixed_fp16",
        "optimizer": "AdamW",
        "scheduler": "CosineAnnealingLR",
        "loss_function": "hybrid_mae_ranking",
        "gpu": "H200",
        "memory_optimization": True,
        "checkpoint_loaded": CHECKPOINT_PATH
    },
    tags=["h200", "vjepa2", "video-quality", "continued-training", "clean-checkpoint"],
    notes=f"Continued training from clean checkpoint with previous best score {previous_best_score:.4f}"
)

wandb.config.update({
    "total_parameters": total_params,
    "trainable_parameters": trainable_params,
    "frozen_parameters": total_params - trainable_params,
    "trainable_ratio": trainable_params / total_params
})

# Watch model (lightweight logging)
wandb.watch(model, log="parameters", log_freq=100, log_graph=False)

In [14]:
train_dataset.processor = model.vproc
val_dataset.processor = model.vproc
train_dataset.num_frames = NUM_FRAMES
val_dataset.num_frames = NUM_FRAMES

collator = OptimizedGPUCollate(model.vproc, model.tenc, device=device, max_frames=NUM_FRAMES)

train_dl = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE,
    shuffle=True, 
    collate_fn=collator, 
    drop_last=True,
    num_workers=0,
    pin_memory=False
)

val_dl = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE,
    shuffle=False, 
    collate_fn=collator,
    num_workers=0,
    pin_memory=False
)

wandb.config.update({
    "train_batches": len(train_dl),
    "val_batches": len(val_dl),
    "train_samples": len(train_dataset),
    "val_samples": len(val_dataset)
})



In [15]:
# =============================================================================
# CELL 11: Updated Training Setup for Continued Training
# =============================================================================

from torch.optim.lr_scheduler import OneCycleLR
BASE_LR = 2e-4
MAX_LR = 8e-4
TOTAL_STEPS = len(train_dl) * EPOCHS

def create_discriminative_param_groups(model, base_lr=BASE_LR):
    param_groups = []
    
    text_params = list(model.tenc.parameters())
    if text_params:
        param_groups.append({
            'params': text_params,
            'lr': base_lr * 0.1,
            'name': 'text_encoder'
        })
    
    video_params = []
    for name, param in model.venc.named_parameters():
        if param.requires_grad:
            video_params.append(param)
    
    if video_params:
        param_groups.append({
            'params': video_params,
            'lr': base_lr * 0.5,
            'name': 'video_encoder'
        })
    
    head_params = list(model.head.parameters())
    if head_params:
        param_groups.append({
            'params': head_params,
            'lr': base_lr * 2.0,
            'name': 'mos_head'
        })
    
    return param_groups

param_groups = create_discriminative_param_groups(model, BASE_LR)

optimizer = optim.AdamW(
    param_groups,
    lr=BASE_LR,
    weight_decay=1e-2,
    betas=(0.9, 0.999),
    eps=1e-8
)

scheduler = OneCycleLR(
    optimizer,
    max_lr=[BASE_LR * 0.1, BASE_LR * 0.5, BASE_LR * 2.0],
    total_steps=TOTAL_STEPS,
    epochs=EPOCHS,
    steps_per_epoch=len(train_dl),
    pct_start=0.3,
    div_factor=25,
    final_div_factor=1000,
    anneal_strategy='cos'
)

scaler = torch.amp.GradScaler('cuda', init_scale=1024.0)

# Use a more conservative scheduler for continued training
from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LR/20)

# Memory test and logging
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

sample_batch = next(iter(train_dl))
model.eval()
with torch.no_grad():
    with torch.amp.autocast('cuda', dtype=torch.float16):
        outputs = model(sample_batch['pixel_values_videos'], sample_batch['text_emb'])

peak_mem = torch.cuda.max_memory_allocated() / 1e9
gpu_total = torch.cuda.get_device_properties(0).total_memory / 1e9

wandb.log({
    "system/peak_memory_test_gb": peak_mem,
    "system/gpu_total_memory_gb": gpu_total,
    "system/memory_efficiency_percent": (peak_mem / gpu_total) * 100,
    "system/video_shape": str(sample_batch['pixel_values_videos'].shape),
    "system/text_shape": str(sample_batch['text_emb'].shape)
})

print(f"Memory test: {peak_mem:.1f}GB ({peak_mem/150*100:.1f}%)")

del sample_batch, outputs
torch.cuda.empty_cache()

Memory test: 27.6GB (18.4%)


In [None]:
# =============================================================================
# CELL 12: Updated Training Loop for Continued Training
# =============================================================================

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

wandb.config.update({
    "discriminative_lrs": True,
    "dynamic_loss_weighting": True,
    "scheduler": "OneCycleLR",
    "base_lr": BASE_LR,
    "max_lr": MAX_LR,
    "lr_text": BASE_LR * 0.1,
    "lr_video": BASE_LR * 0.5,
    "lr_head": BASE_LR * 2.0,
    "loss_components": ["smooth_l1", "ranking", "scale"],
    "warmup_pct": 0.3
})

PATIENCE = 5

torch.cuda.empty_cache()
gc.collect()

torch.cuda.reset_peak_memory_stats()
best_score = previous_best_score
patience_counter = 0

import time
training_start_time = time.time()

try:
    for epoch in range(EPOCHS):
        actual_epoch = start_epoch + epoch + 1
        epoch_start_time = time.time()
        print(f"\nEPOCH {actual_epoch}/{start_epoch + EPOCHS}")
        
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        
        train_loss = train_epoch(
            model, train_dl, optimizer, scaler, 
            GRADIENT_ACCUMULATION_STEPS, actual_epoch
        )
        
        if train_loss > 0:
            torch.cuda.empty_cache()
            final_score = evaluate(model, val_dl, actual_epoch)
        else:
            final_score = 0.0
        
        peak_mem = torch.cuda.max_memory_allocated() / 1e9
        current_mem = torch.cuda.memory_allocated() / 1e9
        epoch_time = time.time() - epoch_start_time
        
        current_lrs = [group['lr'] for group in optimizer.param_groups]
        
        wandb.log({
            "epoch/train_loss": train_loss,
            "epoch/final_score": final_score,
            "epoch/peak_memory_gb": peak_mem,
            "epoch/current_memory_gb": current_mem,
            "epoch/memory_efficiency_percent": (peak_mem / 150) * 100,
            "epoch/epoch_time_minutes": epoch_time / 60,
            "epoch/actual_epoch": actual_epoch,
            "epoch/training_epoch": epoch + 1,
            "epoch/lr_text": current_lrs[0] if len(current_lrs) > 0 else 0,
            "epoch/lr_video": current_lrs[1] if len(current_lrs) > 1 else 0,
            "epoch/lr_head": current_lrs[2] if len(current_lrs) > 2 else 0,
            "epoch/loss_alpha": loss_manager.alpha,
            "epoch/loss_beta": loss_manager.beta
        })
        
        print(f"Loss: {train_loss:.4f}, Score: {final_score:.4f}, Peak Memory: {peak_mem:.1f}GB, Time: {epoch_time/60:.1f}min")
        
        if final_score > best_score:
            best_score = final_score
            best_score_epoch = actual_epoch
            patience_counter = 0
            
            checkpoint = {
                'epoch': actual_epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss_manager_state': {
                    'alpha': loss_manager.alpha,
                    'beta': loss_manager.beta,
                    'mae_history': loss_manager.mae_history,
                    'ranking_history': loss_manager.ranking_history
                },
                'best_score': best_score,
                'original_checkpoint': CHECKPOINT_PATH,
                'training_resumed_from_epoch': start_epoch
            }
            
            checkpoint_path = f"enhanced_checkpoint_epoch_{actual_epoch}_score_{best_score:.4f}.pt"
            torch.save(checkpoint, checkpoint_path)
            wandb.save(checkpoint_path)
            
            wandb.log({
                "best/final_score": best_score,
                "best/actual_epoch": actual_epoch
            })
            
        else:
            patience_counter += 1
            wandb.log({"training/patience_counter": patience_counter})
        
        if patience_counter >= PATIENCE:
            wandb.log({"training/early_stopped": True, "training/stopped_epoch": actual_epoch})
            break
        
        torch.cuda.empty_cache()

except Exception as e:
    wandb.log({"training/error": str(e)})
    raise e

finally:
    total_training_time = time.time() - training_start_time
    final_improvement = best_score - previous_best_score
    
    wandb.log({
        "final/best_score": best_score,
        "final/total_improvement": final_improvement,
        "final/total_training_time_hours": total_training_time / 3600,
        "final/peak_memory_usage_gb": torch.cuda.max_memory_allocated() / 1e9,
        "final/gpu_utilization_percent": (torch.cuda.max_memory_allocated() / 1e9 / 150) * 100
    })
    
    print(f"Previous best score: {previous_best_score:.4f}")
    print(f"New best score: {best_score:.4f}")
    print(f"Total improvement: {final_improvement:.4f}")
    print(f"Total training time: {total_training_time/3600:.2f} hours")
    
    torch.cuda.empty_cache()
    wandb.finish()


EPOCH 6/10
    Batch 0/533, Loss: 0.2673, LRs: 8.00e-07/4.00e-06/1.60e-05, Memory: 6.3GB
    Batch 50/533, Loss: 0.3135, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 100/533, Loss: 0.3087, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 150/533, Loss: 0.3195, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 200/533, Loss: 0.3153, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 250/533, Loss: 0.3201, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 300/533, Loss: 0.3251, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 350/533, Loss: 0.3256, LRs: 8.00e-07/4.00e-06/1.60e-05, Memory: 7.5GB
    Batch 400/533, Loss: 0.3272, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB
    Batch 450/533, Loss: 0.3220, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 500/533, Loss: 0.3186, LRs: 1.00e-05/1.00e-05/1.00e-05, Memory: 7.5GB


[h264 @ 0x3cd996c0] Reference 5 >= 5
[h264 @ 0x3cd996c0] error while decoding MB 15 42, bytestream 9292
[h264 @ 0x4147c2c0] left block unavailable for requested intra mode
[h264 @ 0x4147c2c0] error while decoding MB 0 25, bytestream 45493


Loss: 0.3172, Score: 0.6287, Peak Memory: 120.6GB, Time: 90.6min

EPOCH 7/10
    Batch 0/533, Loss: 0.2472, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 50/533, Loss: 0.3217, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB
    Batch 100/533, Loss: 0.2977, LRs: 8.00e-07/4.00e-06/1.60e-05, Memory: 7.5GB
    Batch 150/533, Loss: 0.3151, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 200/533, Loss: 0.3193, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 250/533, Loss: 0.3207, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 300/533, Loss: 0.3158, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 350/533, Loss: 0.3168, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 400/533, Loss: 0.3181, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 450/533, Loss: 0.3206, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 500/533, Loss: 0.3226, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB


[h264 @ 0x3cb86f80] Reference 5 >= 5
[h264 @ 0x3cb86f80] error while decoding MB 15 42, bytestream 9292
[h264 @ 0x41517e00] left block unavailable for requested intra mode
[h264 @ 0x41517e00] error while decoding MB 0 25, bytestream 45493


Loss: 0.3212, Score: 0.6270, Peak Memory: 120.6GB, Time: 90.9min

EPOCH 8/10
    Batch 0/533, Loss: 0.1868, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 50/533, Loss: 0.3227, LRs: 1.00e-05/1.00e-05/1.00e-05, Memory: 7.5GB
    Batch 100/533, Loss: 0.3290, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 150/533, Loss: 0.3221, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB
    Batch 200/533, Loss: 0.3135, LRs: 8.00e-07/4.00e-06/1.60e-05, Memory: 7.5GB
    Batch 250/533, Loss: 0.3081, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 300/533, Loss: 0.3061, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 350/533, Loss: 0.3095, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 400/533, Loss: 0.3087, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 450/533, Loss: 0.3039, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB
    Batch 500/533, Loss: 0.3030, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB


[h264 @ 0x3cafdfc0] Reference 5 >= 5
[h264 @ 0x3cafdfc0] error while decoding MB 15 42, bytestream 9292
[h264 @ 0x3ca2e880] left block unavailable for requested intra mode
[h264 @ 0x3ca2e880] error while decoding MB 0 25, bytestream 45493


Loss: 0.3048, Score: 0.6281, Peak Memory: 120.6GB, Time: 91.3min

EPOCH 9/10
    Batch 0/533, Loss: 0.2907, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 50/533, Loss: 0.2646, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB
    Batch 100/533, Loss: 0.2806, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 150/533, Loss: 0.2894, LRs: 1.00e-05/1.00e-05/1.00e-05, Memory: 7.5GB
    Batch 200/533, Loss: 0.2991, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 250/533, Loss: 0.2934, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB
    Batch 300/533, Loss: 0.2955, LRs: 8.00e-07/4.00e-06/1.60e-05, Memory: 7.5GB
    Batch 350/533, Loss: 0.3025, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 400/533, Loss: 0.3044, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 450/533, Loss: 0.3058, LRs: 1.00e-05/1.00e-05/1.00e-05, Memory: 7.5GB
    Batch 500/533, Loss: 0.3066, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB


[h264 @ 0x3ccb8740] Reference 5 >= 5
[h264 @ 0x3ccb8740] error while decoding MB 15 42, bytestream 9292
[h264 @ 0x3cb19740] left block unavailable for requested intra mode
[h264 @ 0x3cb19740] error while decoding MB 0 25, bytestream 45493


Loss: 0.3065, Score: 0.6301, Peak Memory: 120.6GB, Time: 91.0min

EPOCH 10/10
    Batch 0/533, Loss: 0.3455, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB
    Batch 50/533, Loss: 0.2797, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 100/533, Loss: 0.2756, LRs: 1.68e-06/4.57e-06/1.54e-05, Memory: 7.5GB
    Batch 150/533, Loss: 0.2740, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB
    Batch 200/533, Loss: 0.2716, LRs: 9.12e-06/9.43e-06/1.06e-05, Memory: 7.5GB
    Batch 250/533, Loss: 0.2706, LRs: 1.00e-05/1.00e-05/1.00e-05, Memory: 7.5GB
    Batch 300/533, Loss: 0.2738, LRs: 6.82e-06/7.93e-06/1.21e-05, Memory: 7.5GB
    Batch 350/533, Loss: 0.2792, LRs: 3.98e-06/6.07e-06/1.39e-05, Memory: 7.5GB


In [None]:
# =============================================================================
# CELL 13: Save Final Model
# =============================================================================

# Save the complete trained model with all necessary information (NO wandb config)
print("Saving final model without wandb dependencies...")

# Create a completely clean checkpoint with only essential information
final_checkpoint = {
    'epoch': best_score_epoch if 'best_score_epoch' in locals() else (start_epoch + EPOCHS - 1),
    'model_state_dict': model.state_dict(),  # Clean state dict only
    'best_score': best_score,
    'original_checkpoint': CHECKPOINT_PATH,
    'training_resumed_from_epoch': start_epoch,
    'total_epochs_trained': (best_score_epoch if 'best_score_epoch' in locals() else (start_epoch + EPOCHS - 1)),
    'model_config': {
        'vjepa_id': VJEPA_ID,
        'text_encoder_id': TEXT_ID,
        'video_dim': model.venc.config.hidden_size,
        'text_dim': model.tenc.get_sentence_embedding_dimension(),
        'head_hidden': 512
    },
    'training_config': {
        'batch_size': BATCH_SIZE,
        'num_frames': NUM_FRAMES,
        'gradient_accumulation_steps': GRADIENT_ACCUMULATION_STEPS,
        'effective_batch_size': BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
        'learning_rate': LR,
        'additional_epochs': EPOCHS,
        'resolution': '384x384',
        'precision': 'mixed_fp16',
        'optimizer': 'AdamW',
        'scheduler': 'CosineAnnealingLR',
        'loss_function': 'hybrid_mae_ranking'
    },
    'dataset_info': {
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'train_batches': len(train_dl),
        'val_batches': len(val_dl)
    },
    'performance_metrics': {
        'previous_best_score': previous_best_score,
        'final_best_score': best_score,
        'improvement': best_score - previous_best_score,
        'total_parameters': total_params,
        'trainable_parameters': trainable_params
    }
}

# Save the clean final model
final_model_path = f'model_final_continued_epoch_{final_checkpoint["epoch"]}_score_{best_score:.4f}.pt'
torch.save(final_checkpoint, final_model_path)

print(f"Clean final model saved to: {final_model_path}")
print(f"Model details:")
print(f"  - Final epoch: {final_checkpoint['epoch']}")
print(f"  - Best score: {best_score:.4f}")
print(f"  - Improvement: {best_score - previous_best_score:.4f}")
print(f"  - Started from: {CHECKPOINT_PATH}")
print(f"  - Total parameters: {total_params:,}")
print(f"  - Trainable parameters: {trainable_params:,}")

In [None]:
# =============================================================================
# TEST DATASET EVALUATION
# =============================================================================

import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader

# Configuration
BATCH_SIZE = 1
NUM_FRAMES = 64

# Test dataset paths (update these paths according to your test data structure)
TEST_CSV = "TaobaoAIGC/data/test/labels/test_labels.csv"  # or wherever your test CSV is
TEST_VID = "TaobaoAIGC/data/test/videos"  # or wherever your test videos are

print(f"Model loaded with score: {best_score:.4f}")

# Create test dataset
print("Setting up test dataset...")
test_df = pd.read_csv(TEST_CSV)
print(f"Test dataset contains {len(test_df)} samples")

test_ds = TaobaoVDDataset(TEST_CSV, TEST_VID, model.vproc, num_frames=NUM_FRAMES, mode='inference')
collator = OptimizedGPUCollate(model.vproc, model.tenc, device=device, max_frames=NUM_FRAMES)
test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collator, num_workers=0)

# Generate predictions
model.eval()
predictions = []
video_names = []

print("Generating test predictions...")
with torch.no_grad():
    for i, batch in enumerate(test_dl):
        try:
            with torch.amp.autocast('cuda', dtype=torch.float16):
                outputs = model(batch['pixel_values_videos'], batch['text_emb'])
                overall_score = outputs[0, -1].cpu().item()
            
            predictions.append(overall_score)
            video_names.append(batch['video_names'][0])
            
        except Exception as e:
            print(f"Warning: Error processing test video {i}: {e}")
            predictions.append(3.0)  # Default score
            video_names.append(batch['video_names'][0])
            
        if i % 50 == 0:
            print(f"Processed {i+1}/{len(test_dl)} test videos")
            torch.cuda.empty_cache()

# Save test submission
test_submission_df = pd.DataFrame({
    'video_name': video_names, 
    'Overall_MOS': predictions
})

test_submission_df.to_excel('test_prediction.xlsx', index=False)
test_submission_df.to_csv('test_prediction.csv', index=False)  # Also save as CSV

# Create test README
runtime_per_video = NUM_FRAMES / 20
total_params = sum(p.numel() for p in model.parameters())
model_flops = (total_params * 2) / 1e9

test_readme_content = f'''Test Dataset Evaluation Results
==============================
Runtime per video [s]: {runtime_per_video:.2f}
Flops [GFLOPs]: {model_flops:.1f}
CPU[1] / GPU[0]: 0
Extra Data use: 0
LLM use: 0

Model Details:
- Architecture: V-JEPA2-ViT-G + BGE-Large multimodal model
- Frames per video: {NUM_FRAMES}
- Resolution: 384x384
- Training validation score: {best_score:.4f}
- Test samples processed: {len(predictions)}

Prediction Statistics:
- Min prediction: {min(predictions):.4f}
- Max prediction: {max(predictions):.4f}
- Mean prediction: {np.mean(predictions):.4f}
- Std prediction: {np.std(predictions):.4f}

Files Generated:
- test_prediction.xlsx (Excel format)
- test_prediction.csv (CSV format)
- test_readme.txt (this file)
'''

with open('test_readme.txt', 'w') as f:
    f.write(test_readme_content)

# Display results
print(f"\n✅ Test evaluation completed!")
print(f"Test videos processed: {len(predictions)}")
print(f"Prediction statistics:")
print(f"  Min: {min(predictions):.4f}")
print(f"  Max: {max(predictions):.4f}")
print(f"  Mean: {np.mean(predictions):.4f}")
print(f"  Std: {np.std(predictions):.4f}")
print(f"\nFiles created:")
print(f"  📁 test_prediction.xlsx")
print(f"  📁 test_prediction.csv") 
print(f"  📁 test_readme.txt")

# Optional: Show first few predictions
print(f"\nFirst 5 test predictions:")
for i in range(min(5, len(test_submission_df))):
    video_name = test_submission_df.iloc[i]['video_name']
    score = test_submission_df.iloc[i]['Overall_MOS']
    print(f"  {video_name}: {score:.4f}")

# Clean up
torch.cuda.empty_cache()
print(f"\nGPU memory: {torch.cuda.memory_allocated()/1e9:.1f}GB")