In [None]:
# === ULTRA-OPTIMIZED BLIP Fine-tuning - Maximum Accuracy & Performance ===

import os
import pandas as pd
from PIL import Image, ImageEnhance
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import BlipProcessor, BlipForConditionalGeneration, get_scheduler
from tqdm.auto import tqdm
from torch.cuda.amp import autocast, GradScaler
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from torchvision import transforms
import json
import time
import gc
import warnings
import random
import numpy as np
import math
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings("ignore")

# === ULTRA Configuration - Maximum Accuracy Focus ===
class UltraConfig:
    """Class-level hyperparameter store for the BLIP caption fine-tuning run.

    Every setting is a plain class attribute, readable from the class itself or
    from the module-level ``config`` instance created right after it.

    NOTE(review): some settings are declared but never read by the code in this
    notebook: USE_CUTOUT (cutout is driven by CUTOUT_PROB alone), USE_MIXUP,
    VALIDATION_EVERY_N_STEPS and SAVE_EVERY_IMPROVEMENT. USE_FOCAL_LOSS and
    USE_LABEL_SMOOTHING only attach loss modules to the model and are printed;
    the training loop optimizes the model's own ``outputs.loss``. WARMUP_RATIO is
    only used when USE_COSINE_RESTART is False.
    """

    # ---- Filesystem locations (paths look like a Colab Google Drive mount) ----
    TRAIN_CSV_PATH = "/content/drive/MyDrive/Obss Data/train.csv"
    IMAGE_FOLDER = "/content/drive/MyDrive/Obss Data/train/train"
    BASE_MODEL = "Salesforce/blip-image-captioning-base"
    SAVE_PATH = "/content/drive/MyDrive/Obss Data/blip_finetuned_ULTRA"
    LOG_PATH = "/content/drive/MyDrive/Obss Data/ultra_training_logs.json"

    # ---- Input geometry / tokenization ----
    IMG_SIZE = 384          # side length produced by the random resized crop
    MAX_TEXT_LENGTH = 80    # captions are padded/truncated to this many tokens

    # ---- Optimization ----
    BATCH_SIZE = 10
    ACCUMULATION_STEPS = 3  # effective batch = BATCH_SIZE * ACCUMULATION_STEPS = 30
    LEARNING_RATE = 1.5e-5
    WEIGHT_DECAY = 0.08
    NUM_EPOCHS = 20
    WARMUP_RATIO = 0.15
    MAX_GRAD_NORM = 0.8
    EARLY_STOPPING_PATIENCE = 6  # epochs without validation improvement
    SEED = 42

    # ---- Loss / schedule options ----
    USE_FOCAL_LOSS = True
    FOCAL_ALPHA = 0.75
    FOCAL_GAMMA = 2.0
    USE_LABEL_SMOOTHING = True
    LABEL_SMOOTHING = 0.1
    USE_COSINE_RESTART = True
    RESTART_FACTOR = 2      # first cosine cycle length, in epochs

    # ---- Validation / checkpointing ----
    VALIDATION_SPLIT = 0.12
    VALIDATION_EVERY_N_STEPS = 500
    SAVE_EVERY_IMPROVEMENT = True

    # ---- Data augmentation ----
    USE_ADVANCED_AUGMENTATION = True
    AUGMENTATION_STRENGTH = 0.3
    USE_CUTOUT = True
    USE_MIXUP = False       # mixing images would blend two captions' content
    CUTOUT_PROB = 0.15

config = UltraConfig()

# === Enhanced Reproducibility Setup ===
def set_ultra_seed(seed_value):
    """Seed every RNG used by the pipeline and force deterministic cuDNN kernels.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch (CPU
    and all CUDA devices). It then disables cuDNN autotuning in favor of
    deterministic algorithms.

    Args:
        seed_value: Integer seed applied to every generator.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

    # PYTHONHASHSEED is read only at interpreter startup, so this affects Python
    # processes started afterwards, not the current one.
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Plain text: the previous message carried a mis-encoded (mojibake) emoji.
    print(f"Ultra-seed set to {seed_value} with maximum reproducibility")

# === Advanced Device Setup ===
def setup_ultra_device():
    """Select the training device and apply process-wide CUDA settings.

    With CUDA available, the function does the following:
    - prints the GPU name and total memory;
    - releases cached allocator blocks;
    - caps this process at 95% of GPU memory;
    - enables TF32 for matmul and cuDNN.

    Without CUDA it prints a warning and falls back to the CPU.

    Returns:
        torch.device: ``cuda`` when available, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        print("No GPU detected - training will be very slow")
        return torch.device("cpu")

    device = torch.device("cuda")
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Start from an empty allocator cache, then leave 5% headroom.
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.95)

    # Tolerate only a missing ``torch.backends.cuda.matmul`` attribute on older
    # builds; a bare ``except`` would also swallow KeyboardInterrupt/SystemExit.
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    except AttributeError:
        pass

    return device

device = setup_ultra_device()  # module-level device read by load_ultra_model() and main()
set_ultra_seed(seed_value=config.SEED)

# === Advanced Loss Functions ===
class FocalLoss(torch.nn.Module):
    """Focal loss built on per-element cross-entropy.

    Each element's cross-entropy ``ce`` is scaled by ``alpha * (1 - p_t) ** gamma``.
    Here ``p_t = exp(-ce)`` is the predicted probability of the true class, so
    well-classified elements are down-weighted.

    Args:
        alpha: Constant multiplier applied to every element.
        gamma: Focusing exponent; ``gamma=0`` gives ``alpha`` * cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element tensor.
    """

    def __init__(self, alpha=0.75, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        per_elem_ce = torch.nn.functional.cross_entropy(inputs, targets, reduction='none')
        true_class_prob = per_elem_ce.neg().exp()
        modulating = (1 - true_class_prob) ** self.gamma
        weighted = self.alpha * modulating * per_elem_ce

        if self.reduction == 'sum':
            return weighted.sum()
        if self.reduction == 'mean':
            return weighted.mean()
        return weighted

class LabelSmoothingCrossEntropy(torch.nn.Module):
    """Cross-entropy against a smoothed one-hot target, averaged over rows.

    The true class receives ``1 - smoothing`` of the target mass. The remaining
    ``smoothing`` is spread evenly over the other ``C - 1`` classes, where ``C``
    is the size of the last dimension of ``input``.

    NOTE(review): there is no ``ignore_index`` handling, so padded token
    positions would be counted if this were applied to sequences.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, input, target):
        num_classes = input.size(-1)
        off_target_mass = self.smoothing / (num_classes - 1.)
        target_dist = torch.full_like(input, off_target_mass)
        target_dist.scatter_(-1, target.unsqueeze(-1), 1. - self.smoothing)

        log_probs = torch.nn.functional.log_softmax(input, dim=-1)
        per_row = -(target_dist * log_probs).sum(dim=-1)
        return per_row.mean()

# === Ultra-Advanced Data Augmentation ===
class UltraAugmentation:
    """Training-time augmentation pipeline applied to a PIL image.

    Stages, in order:
    1. an occasional sharpness/contrast boost;
    2. geometric jitter: random resized crop to ``img_size``, horizontal flip,
       small rotation and mild perspective;
    3. color jitter plus occasional blur;
    4. an optional cutout patch.

    NOTE(review): horizontal flips can contradict captions that mention
    "left"/"right". ``quality_augment_prob`` is stored but no method here reads it.
    """

    def __init__(self, img_size=384, strength=0.3, cutout_prob=0.15):
        """Build the geometric and color transform stacks.

        Args:
            img_size: Output side length of the random resized crop.
            strength: Multiplier on the color-jitter ranges.
            cutout_prob: Probability that ``cutout_augment`` masks a patch.
        """
        self.img_size = img_size
        self.cutout_prob = cutout_prob

        bicubic = transforms.InterpolationMode.BICUBIC
        crop = transforms.RandomResizedCrop(
            img_size,
            scale=(0.85, 1.0),
            ratio=(0.85, 1.15),
            interpolation=bicubic
        )
        flip = transforms.RandomHorizontalFlip(p=0.35)
        maybe_rotate = transforms.RandomApply(
            [transforms.RandomRotation(degrees=5, interpolation=bicubic)],
            p=0.2
        )
        # The inner transform has its own p, so the effective perspective
        # probability is 0.15 * 0.3.
        maybe_warp = transforms.RandomApply(
            [transforms.RandomPerspective(distortion_scale=0.1, p=0.3)],
            p=0.15
        )
        self.geometric_transform = transforms.Compose([crop, flip, maybe_rotate, maybe_warp])

        jitter = transforms.ColorJitter(
            brightness=0.15 * strength,
            contrast=0.15 * strength,
            saturation=0.1 * strength,
            hue=0.03 * strength
        )
        maybe_blur = transforms.RandomApply(
            [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5))],
            p=0.1
        )
        self.color_transform = transforms.Compose([jitter, maybe_blur])

        self.quality_augment_prob = 0.1 * strength

    def enhance_image_quality(self, image):
        """With probability 0.1, sharpen by 1.1x and then raise contrast by 1.05x."""
        if random.random() >= 0.1:
            return image
        sharpened = ImageEnhance.Sharpness(image).enhance(1.1)
        return ImageEnhance.Contrast(sharpened).enhance(1.05)

    def cutout_augment(self, image):
        """With probability ``cutout_prob``, overwrite one random square patch.

        The patch side is drawn between 5% and 15% of the shorter image side.
        With equal odds the patch is filled with uniform noise or with the mean
        intensity of the whole image.
        """
        if random.random() >= self.cutout_prob:
            return image

        pixels = transforms.ToTensor()(image)
        height, width = pixels.shape[1], pixels.shape[2]
        short_side = min(height, width)
        side = random.randint(int(0.05 * short_side), int(0.15 * short_side))
        top = random.randint(0, height - side)
        left = random.randint(0, width - side)
        rows = slice(top, top + side)
        cols = slice(left, left + side)

        if random.random() < 0.5:
            pixels[:, rows, cols] = torch.rand(3, side, side)
        else:
            pixels[:, rows, cols] = pixels.mean()

        return transforms.ToPILImage()(pixels)

    def __call__(self, image):
        stages = (
            self.enhance_image_quality,
            self.geometric_transform,
            self.color_transform,
            self.cutout_augment,
        )
        for stage in stages:
            image = stage(image)
        return image

# === Smart Dataset with Caption Analysis ===
class UltraCaptionDataset(Dataset):
    """Image-caption dataset that yields BLIP-ready examples.

    Each row of ``df`` provides ``image_id`` and ``caption``. The image is read
    from ``<image_folder>/<image_id>.jpg``. Items are dicts with
    ``pixel_values``, ``input_ids``, ``attention_mask`` and ``labels``. An item
    is ``None`` when the sample cannot be built, so that a collate function
    which drops ``None`` entries skips it instead of training on fabricated data.

    Args:
        df: DataFrame with ``image_id`` and ``caption`` columns.
        image_folder: Directory containing the ``.jpg`` files.
        processor: BLIP processor used for image preprocessing and tokenization.
        is_training: Enables augmentation (if configured) and caption statistics.
        caption_stats: Accepted for compatibility; not used.
    """

    # PyTorch cross-entropy's default ignore_index: positions with this label
    # contribute nothing to the loss.
    IGNORE_INDEX = -100

    def __init__(self, df, image_folder, processor, is_training=True, caption_stats=None):
        self.df = df.reset_index(drop=True)
        self.image_folder = image_folder
        self.processor = processor
        self.is_training = is_training
        self.max_length = config.MAX_TEXT_LENGTH
        self.caption_stats = caption_stats

        # Augmentation only for training, and only if enabled in config
        if is_training and config.USE_ADVANCED_AUGMENTATION:
            self.augment = UltraAugmentation(config.IMG_SIZE, config.AUGMENTATION_STRENGTH, config.CUTOUT_PROB)
        else:
            self.augment = None

        if is_training:
            self._analyze_captions()

        print(f"Ultra Dataset: {len(self.df)} samples, Training: {is_training}")
        print(f"Advanced Augmentation: {'Enabled' if self.augment else 'Disabled'}")

    def _analyze_captions(self):
        """Print caption word-count statistics and the 10 most frequent lowercase words."""
        captions = [str(cap) for cap in self.df['caption']]
        caption_lengths = [len(cap.split()) for cap in captions]
        if not caption_lengths:
            # max()/min() would raise on an empty frame
            print("Caption Analysis: no captions to analyze")
            return

        print(f"Caption Analysis:")
        print(f"   Avg length: {np.mean(caption_lengths):.1f} words")
        print(f"   Max length: {max(caption_lengths)} words")
        print(f"   Min length: {min(caption_lengths)} words")

        word_freq = Counter(word for cap in captions for word in cap.lower().split())
        print(f"   Most common words: {list(word_freq.most_common(10))}")

    def __len__(self):
        return len(self.df)

    def _load_image(self, image_path):
        """Decode ``image_path`` as RGB and close the file; return ``None`` on failure."""
        try:
            with Image.open(image_path) as img:
                # convert() loads the full pixel data, so unreadable files fail here.
                return img.convert("RGB")
        except Exception as e:
            print(f"Error loading {image_path}: {e}")
            return None

    def _encode_caption(self, caption):
        """Tokenize ``caption`` to ``max_length`` and build labels with padding ignored.

        Returns:
            tuple: ``(input_ids, attention_mask, labels)`` as 1-D tensors.
        """
        text_inputs = self.processor.tokenizer(
            caption,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt"
        )
        input_ids = text_inputs.input_ids.squeeze(0)
        attention_mask = text_inputs.attention_mask.squeeze(0)

        # Without masking, every padded position is a target and the loss is
        # dominated by predicting padding.
        # NOTE(review): relies on the BLIP decoder's cross-entropy keeping the
        # default ignore_index=-100; confirm for the installed transformers version.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX
        return input_ids, attention_mask, labels

    def __getitem__(self, idx):
        """Build example ``idx``, or return ``None`` if it cannot be built."""
        try:
            row = self.df.iloc[idx]
            image_path = os.path.join(self.image_folder, f"{row['image_id']}.jpg")

            image = self._load_image(image_path)
            if image is None:
                # Skip rather than pair a blank placeholder with a real caption.
                return None

            if self.augment and self.is_training:
                image = self.augment(image)

            pixel_values = self.processor(
                images=image,
                return_tensors="pt",
                do_rescale=True,
                do_normalize=True
            ).pixel_values.squeeze(0)

            # Normalize captions: lowercase, trimmed, terminated with a period.
            caption = str(row["caption"]).strip().lower() or "an image"
            if not caption.endswith('.'):
                caption += '.'

            input_ids, attention_mask, labels = self._encode_caption(caption)
            return {
                "pixel_values": pixel_values,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels
            }

        except Exception as e:
            # Return None instead of a synthetic example so broken samples never
            # become training targets.
            print(f"Critical error in dataset[{idx}]: {e}")
            return None

# === Ultra-Robust Collate Function ===
def ultra_collate_fn(batch):
    """Stack dataset items into one batch dict, skipping ``None`` items.

    Returns ``None`` (after printing a message) when no item is usable or when
    the tensors cannot be stacked. Training and evaluation loops can then skip
    that batch.
    """
    usable = [sample for sample in batch if sample is not None]

    if not usable:
        print("Empty batch encountered")
        return None

    dropped = len(batch) - len(usable)
    if dropped:
        print(f"Partial batch: {len(usable)}/{len(batch)} items valid")

    keys = ("pixel_values", "input_ids", "attention_mask", "labels")
    try:
        return {key: torch.stack([sample[key] for sample in usable]) for key in keys}
    except Exception as e:
        print(f"Collate error: {e}")
        return None

# === Load Ultra Model ===
def load_ultra_model():
    """Load the BLIP captioning model and its processor for fine-tuning.

    The model is loaded in float32, moved to the module-level ``device`` and has
    gradient checkpointing enabled. Depending on ``config``, a ``FocalLoss``
    and/or a ``LabelSmoothingCrossEntropy`` is attached as ``model.focal_loss`` /
    ``model.label_smoothing_loss``.

    NOTE(review): the attached loss modules are not used by the training loop in
    this notebook, which optimizes the model's own ``outputs.loss``.

    Returns:
        tuple: ``(model, processor)``.
    """
    print(f"Loading {config.BASE_MODEL} with ultra optimization...")
    processor = BlipProcessor.from_pretrained(config.BASE_MODEL)

    # float32 master weights; mixed precision comes from autocast at train time.
    model = BlipForConditionalGeneration.from_pretrained(
        config.BASE_MODEL,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model = model.to(device)

    # Recompute activations in the backward pass to save memory.
    model.gradient_checkpointing_enable()

    if config.USE_FOCAL_LOSS:
        model.focal_loss = FocalLoss(config.FOCAL_ALPHA, config.FOCAL_GAMMA)
    if config.USE_LABEL_SMOOTHING:
        model.label_smoothing_loss = LabelSmoothingCrossEntropy(config.LABEL_SMOOTHING)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Ultra model loaded on {device}")
    print(f"Model parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    return model, processor

# === Ultra Evaluation with Advanced Metrics ===
def ultra_evaluation(model, val_loader, device, desc="Ultra Validation"):
    """Compute the average validation loss of ``model`` over ``val_loader``.

    Runs without gradients under autocast and skips ``None`` batches. Per-batch
    losses are averaged weighted by batch size, so a smaller final batch does not
    count as much as a full one. The model's previous train/eval mode is
    restored afterwards, even if evaluation raises.

    Args:
        model: Captioning model whose forward returns an object with ``.loss``.
        val_loader: Iterable of batch dicts (or ``None``) with ``pixel_values``,
            ``input_ids``, ``attention_mask`` and ``labels`` tensors.
        device: Device the batch tensors are moved to.
        desc: Label for the progress bar.

    Returns:
        float: Batch-size-weighted mean loss. ``float('inf')`` when no batch was
        evaluated, so an empty validation pass can never look like a new best.
    """
    was_training = model.training
    model.eval()
    weighted_loss_sum = 0.0
    sample_count = 0
    batch_count = 0

    print(f"{desc} - Processing validation...")

    try:
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=desc):
                if batch is None:
                    continue

                pixel_values = batch["pixel_values"].to(device, non_blocking=True)
                input_ids = batch["input_ids"].to(device, non_blocking=True)
                attention_mask = batch["attention_mask"].to(device, non_blocking=True)
                labels = batch["labels"].to(device, non_blocking=True)

                with autocast():
                    outputs = model(
                        pixel_values=pixel_values,
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels
                    )

                batch_size = pixel_values.size(0)
                weighted_loss_sum += outputs.loss.item() * batch_size
                sample_count += batch_size
                batch_count += 1

                # Periodically release cached GPU memory
                if batch_count % 25 == 0:
                    torch.cuda.empty_cache()
    finally:
        # Restore the caller's mode instead of unconditionally switching to train.
        model.train(was_training)

    if sample_count == 0:
        print(f"Ultra validation complete: {batch_count} batches")
        print("No validation batches were evaluated - returning inf")
        return float('inf')

    avg_loss = weighted_loss_sum / sample_count

    print(f"Ultra validation complete: {batch_count} batches")
    print(f"Average validation loss: {avg_loss:.6f}")

    return avg_loss

# === Ultra Training Tracker ===
class UltraTracker:
    """Record per-epoch metrics, track the best validation loss and persist them as JSON.

    Args:
        log_path: JSON file rewritten with the full history after every epoch.
    """

    def __init__(self, log_path):
        self.log_path = log_path
        self.history = []
        self.best_val_loss = float('inf')
        self.best_epoch = 0
        self.start_time = time.time()
        self.improvements = 0
        self.learning_curve = []

    def _improvement_ratio(self, val_loss):
        """Relative drop of ``val_loss`` versus the current best (negative when worse).

        Returns 0 when there is no finite, non-zero best to compare against:
        either before the first epoch, or when the best loss is exactly 0
        (dividing by it would raise ZeroDivisionError).
        """
        best = self.best_val_loss
        if not math.isfinite(best) or best == 0:
            return 0
        return (best - val_loss) / best

    def log_epoch(self, epoch, train_loss, val_loss, lr, save_time=None):
        """Append one epoch's metrics, update the best-so-far record, and save the log.

        Args:
            epoch: Zero-based epoch index (stored 1-based).
            train_loss: Mean training loss for the epoch.
            val_loss: Validation loss for the epoch.
            lr: Learning rate at the end of the epoch.
            save_time: Seconds spent checkpointing, or ``None`` if nothing was saved.
        """
        is_best = self.is_best_model(val_loss)
        entry = {
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'learning_rate': lr,
            'elapsed_time': time.time() - self.start_time,
            'is_best': is_best,
            'save_time': save_time,
            'improvement_ratio': self._improvement_ratio(val_loss)
        }

        if is_best:
            self.best_val_loss = val_loss
            self.best_epoch = epoch + 1
            self.improvements += 1
            entry['improvement_number'] = self.improvements

        self.history.append(entry)
        self.learning_curve.append((epoch + 1, train_loss, val_loss))

        self._save_log()

    def _save_log(self):
        """Best-effort write of the log; failures are printed, not raised."""
        try:
            with open(self.log_path, 'w') as f:
                json.dump({
                    'training_history': self.history,
                    'best_val_loss': self.best_val_loss,
                    'best_epoch': self.best_epoch,
                    'total_improvements': self.improvements,
                    'learning_curve': self.learning_curve
                }, f, indent=2)
        except Exception as e:
            print(f"Failed to save ultra training log: {e}")

    def is_best_model(self, val_loss):
        """True if ``val_loss`` strictly beats the best loss seen so far."""
        return val_loss < self.best_val_loss

    def plot_learning_curve(self):
        """Save a train/val loss-per-epoch plot next to ``config.SAVE_PATH``.

        Does nothing until at least two epochs have been logged.
        """
        if len(self.learning_curve) < 2:
            return

        epochs, train_losses, val_losses = zip(*self.learning_curve)
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            ax.plot(epochs, train_losses, 'b-', label='Training Loss', alpha=0.8)
            ax.plot(epochs, val_losses, 'r-', label='Validation Loss', alpha=0.8)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title('Ultra Training Learning Curve')
            ax.legend()
            ax.grid(True, alpha=0.3)
            fig.savefig(os.path.join(os.path.dirname(config.SAVE_PATH), 'learning_curve.png'), dpi=300, bbox_inches='tight')
        finally:
            # Close this figure even if saving fails, so figures don't accumulate.
            plt.close(fig)

# === MAIN ULTRA TRAINING FUNCTION ===
def main():
    """Run the full BLIP fine-tuning pipeline end to end.

    Steps:
    1. Load ``config.TRAIN_CSV_PATH`` and drop blank captions.
    2. Load the model and processor, split train/validation, and build the
       datasets and loaders.
    3. Train with AdamW, gradient accumulation, AMP and gradient clipping.
    4. Validate after every epoch. On improvement, save model and processor to
       ``config.SAVE_PATH``. Stop early after ``config.EARLY_STOPPING_PATIENCE``
       epochs without improvement.
    5. Plot the learning curve and write ``ultra_training_summary.json``.

    Returns:
        tuple: ``(model, processor, tracker)``. The model is returned as it is
        after the last trained epoch; it is not reloaded from the best checkpoint.
    """
    print("ULTRA-OPTIMIZED BLIP TRAINING - Maximum Accuracy Focus")
    print("=" * 70)

    # Load and analyze data
    print("Loading and analyzing dataset...")
    df = pd.read_csv(config.TRAIN_CSV_PATH)
    df['caption'] = df['caption'].fillna('').astype(str)

    # Drop empty and whitespace-only captions (the dataset would otherwise
    # substitute a generic "an image." target for them).
    df = df[df['caption'].str.strip().str.len() > 0].reset_index(drop=True)
    print(f"Total valid samples: {len(df):,}")

    # Load ultra model
    model, processor = load_ultra_model()

    # Random (non-stratified) train/validation split
    print("Creating optimized train/validation split...")
    train_df, val_df = train_test_split(
        df,
        test_size=config.VALIDATION_SPLIT,
        random_state=config.SEED,
        shuffle=True
    )

    print(f"Training samples: {len(train_df):,}")
    print(f"Validation samples: {len(val_df):,}")

    # Create ultra datasets
    print("Creating ultra datasets...")
    train_dataset = UltraCaptionDataset(train_df, config.IMAGE_FOLDER, processor, is_training=True)
    val_dataset = UltraCaptionDataset(val_df, config.IMAGE_FOLDER, processor, is_training=False)

    # drop_last keeps every training batch at BATCH_SIZE
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        collate_fn=ultra_collate_fn,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        collate_fn=ultra_collate_fn,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True
    )

    print(f"Training batches: {len(train_loader):,}")
    print(f"Validation batches: {len(val_loader):,}")

    # AdamW with the AMSGrad variant
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.999),
        eps=1e-8,
        amsgrad=True
    )

    # Optimizer steps per epoch count only full accumulation windows
    steps_per_epoch = len(train_loader) // config.ACCUMULATION_STEPS
    total_steps = steps_per_epoch * config.NUM_EPOCHS
    warmup_steps = int(config.WARMUP_RATIO * total_steps)

    if config.USE_COSINE_RESTART:
        # First cycle spans RESTART_FACTOR epochs (in optimizer steps); each later
        # cycle is twice as long (T_mult=2). This scheduler has no warmup phase,
        # so WARMUP_RATIO only affects the fallback scheduler below.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=steps_per_epoch * config.RESTART_FACTOR,
            T_mult=2,
            eta_min=config.LEARNING_RATE * 0.01
        )
    else:
        scheduler = get_scheduler(
            "cosine",
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

    # Mixed precision scaler
    scaler = GradScaler()

    # Ultra tracker
    tracker = UltraTracker(config.LOG_PATH)

    print("\nULTRA TRAINING CONFIGURATION:")
    print(f"   Epochs: {config.NUM_EPOCHS}")
    print(f"   Batch size: {config.BATCH_SIZE}")
    print(f"   Accumulation steps: {config.ACCUMULATION_STEPS}")
    print(f"   Effective batch size: {config.BATCH_SIZE * config.ACCUMULATION_STEPS}")
    print(f"   Total steps: {total_steps:,}")
    if config.USE_COSINE_RESTART:
        print("   Warmup steps: 0 (not used with cosine restarts)")
    else:
        print(f"   Warmup steps: {warmup_steps:,}")
    print(f"   Learning rate: {config.LEARNING_RATE}")
    print(f"   Weight decay: {config.WEIGHT_DECAY}")
    print(f"   Early stopping patience: {config.EARLY_STOPPING_PATIENCE}")
    # NOTE(review): these two flags only attach loss modules in load_ultra_model();
    # the loop below optimizes the model's own outputs.loss.
    print(f"   Focal Loss: {config.USE_FOCAL_LOSS}")
    print(f"   Label Smoothing: {config.USE_LABEL_SMOOTHING}")
    print(f"   Cosine Restart: {config.USE_COSINE_RESTART}")

    # === ULTRA TRAINING LOOP ===
    print("\nSTARTING ULTRA-OPTIMIZED TRAINING")
    print("=" * 70)

    model.train()
    patience_counter = 0
    step = 0
    # Defined up front so the final summary also works when NUM_EPOCHS == 0.
    epoch = -1
    current_lr = optimizer.param_groups[0]['lr']

    for epoch in range(config.NUM_EPOCHS):
        epoch_start_time = time.time()
        print(f"\nEPOCH {epoch + 1}/{config.NUM_EPOCHS}")
        print("-" * 50)

        epoch_loss = 0.0
        batches_processed = 0
        # NOTE: this also discards gradients from a trailing partial accumulation
        # window of the previous epoch (when len(train_loader) is not a multiple
        # of ACCUMULATION_STEPS).
        optimizer.zero_grad()

        progress_bar = tqdm(
            train_loader,
            desc=f"Ultra Epoch {epoch + 1}",
            ncols=120,
            postfix={'loss': '0.0000', 'lr': f'{config.LEARNING_RATE:.2e}'}
        )

        for batch_idx, batch in enumerate(progress_bar):
            if batch is None:
                continue

            pixel_values = batch["pixel_values"].to(device, non_blocking=True)
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device, non_blocking=True)
            labels = batch["labels"].to(device, non_blocking=True)

            # Forward pass with mixed precision
            with autocast():
                outputs = model(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                # Scale down so accumulated gradients average over the window
                loss = outputs.loss / config.ACCUMULATION_STEPS

            scaler.scale(loss).backward()
            epoch_loss += loss.item() * config.ACCUMULATION_STEPS
            batches_processed += 1

            # Step once per full accumulation window
            if (batch_idx + 1) % config.ACCUMULATION_STEPS == 0:
                # Unscale first so clipping sees true gradient magnitudes
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.MAX_GRAD_NORM)

                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                step += 1

                current_lr = optimizer.param_groups[0]['lr']
                progress_bar.set_postfix({
                    'loss': f'{loss.item() * config.ACCUMULATION_STEPS:.4f}',
                    'lr': f'{current_lr:.2e}',
                    'step': f'{step}/{total_steps}'
                })

            # Memory management
            if batch_idx % 100 == 0:
                torch.cuda.empty_cache()
                gc.collect()

        # === END OF EPOCH EVALUATION ===
        epoch_train_time = time.time() - epoch_start_time
        print(f"\nEPOCH {epoch + 1} COMPLETE")
        print(f"Training time: {epoch_train_time:.1f}s")

        avg_train_loss = epoch_loss / max(batches_processed, 1)
        current_lr = optimizer.param_groups[0]['lr']

        print("Running ultra validation...")
        val_start_time = time.time()
        val_loss = ultra_evaluation(
            model, val_loader, device,
            f"Epoch {epoch + 1} Ultra Validation"
        )
        val_time = time.time() - val_start_time

        # Save on strict improvement; otherwise count towards early stopping.
        # (The tracker's best is only updated by log_epoch() below, so the
        # "improved" message shows old -> new.)
        save_time = None
        if tracker.is_best_model(val_loss):
            patience_counter = 0
            print("NEW ULTRA BEST MODEL!")

            save_start = time.time()
            os.makedirs(config.SAVE_PATH, exist_ok=True)
            model.save_pretrained(config.SAVE_PATH)
            processor.save_pretrained(config.SAVE_PATH)
            save_time = time.time() - save_start

            print(f"Ultra model saved in {save_time:.2f}s")
            print(f"Validation loss improved: {tracker.best_val_loss:.6f} -> {val_loss:.6f}")
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{config.EARLY_STOPPING_PATIENCE}")

        print(f"\nEPOCH {epoch + 1} ULTRA SUMMARY:")
        print(f"   Train Loss: {avg_train_loss:.6f}")
        print(f"   Val Loss: {val_loss:.6f}")
        print(f"   Learning Rate: {current_lr:.2e}")
        print(f"   Train Time: {epoch_train_time:.1f}s")
        print(f"   Val Time: {val_time:.1f}s")
        if save_time:
            print(f"   Save Time: {save_time:.1f}s")
        print(f"   Best Val Loss: {tracker.best_val_loss:.6f} (Epoch {tracker.best_epoch})")

        tracker.log_epoch(epoch, avg_train_loss, val_loss, current_lr, save_time)

        if patience_counter >= config.EARLY_STOPPING_PATIENCE:
            print(f"\nULTRA EARLY STOPPING TRIGGERED")
            print(f"   No improvement for {config.EARLY_STOPPING_PATIENCE} epochs")
            print(f"   Best ultra model from epoch {tracker.best_epoch}")
            break

        print("-" * 70)

    # === ULTRA TRAINING COMPLETE ===
    total_training_time = time.time() - tracker.start_time
    epochs_trained = epoch + 1

    print("\nULTRA TRAINING COMPLETED!")
    print("=" * 70)
    print(f"Best validation loss: {tracker.best_val_loss:.6f}")
    print(f"Best epoch: {tracker.best_epoch}")
    print(f"Total improvements: {tracker.improvements}")
    print(f"Final epoch: {epochs_trained}")
    print(f"Total training time: {total_training_time:.1f}s ({total_training_time/3600:.2f} hours)")
    print(f"Ultra model saved to: {config.SAVE_PATH}")

    tracker.plot_learning_curve()

    final_summary = {
        'ultra_training_config': {
            'base_model': config.BASE_MODEL,
            'batch_size': config.BATCH_SIZE,
            'accumulation_steps': config.ACCUMULATION_STEPS,
            'learning_rate': config.LEARNING_RATE,
            'weight_decay': config.WEIGHT_DECAY,
            'num_epochs': config.NUM_EPOCHS,
            'warmup_ratio': config.WARMUP_RATIO,
            'early_stopping_patience': config.EARLY_STOPPING_PATIENCE,
            'focal_loss': config.USE_FOCAL_LOSS,
            'label_smoothing': config.USE_LABEL_SMOOTHING,
            'cosine_restart': config.USE_COSINE_RESTART,
            'advanced_augmentation': config.USE_ADVANCED_AUGMENTATION,
            'seed': config.SEED
        },
        'ultra_results': {
            'best_val_loss': tracker.best_val_loss,
            'best_epoch': tracker.best_epoch,
            'total_epochs_trained': epochs_trained,
            'total_improvements': tracker.improvements,
            'total_training_time_seconds': total_training_time,
            'total_training_time_hours': total_training_time / 3600,
            'final_learning_rate': current_lr,
            'total_steps': step,
            'improvement_rate': tracker.improvements / max(epochs_trained, 1)
        },
        'model_info': {
            'total_parameters': sum(p.numel() for p in model.parameters()),
            'trainable_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad),
            'device': str(device),
            'gpu_name': torch.cuda.get_device_name() if torch.cuda.is_available() else 'CPU'
        },
        'data_info': {
            'total_samples': len(df),
            'training_samples': len(train_df),
            'validation_samples': len(val_df),
            'validation_split': config.VALIDATION_SPLIT,
            'training_batches': len(train_loader),
            'validation_batches': len(val_loader)
        }
    }

    # SAVE_PATH only exists if a checkpoint was saved; create it so the summary
    # can always be written.
    os.makedirs(config.SAVE_PATH, exist_ok=True)
    summary_path = os.path.join(config.SAVE_PATH, 'ultra_training_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(final_summary, f, indent=2)

    print(f"Ultra training summary saved to: {summary_path}")

    torch.cuda.empty_cache()
    gc.collect()

    print("\nULTRA OPTIMIZATION COMPLETE - MAXIMUM ACCURACY MODEL READY!")
    print("Your ultra-optimized BLIP model is ready for inference!")

    return model, processor, tracker

# === Ultra Model Inference Function ===
def ultra_inference(model, processor, image_path, device, num_beams=5, max_length=50):
    """Generate a caption for one image file with deterministic beam search.

    Puts ``model`` in eval mode and leaves it there. Decoding uses ``num_beams``
    beams, a 1.1 repetition penalty and no repeated bigrams.

    Args:
        model: BLIP captioning model.
        processor: Matching BLIP processor.
        image_path: Path to the image file.
        device: Device the processed inputs are moved to.
        num_beams: Beam width.
        max_length: Maximum generated sequence length.

    Returns:
        str: The decoded, stripped caption. Returns ``"Error generating caption"``
        if loading or generation fails; the error is printed.
    """
    model.eval()

    try:
        # The context manager closes the file even if decoding fails.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        inputs = processor(images=image, return_tensors="pt").to(device)

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                do_sample=False,  # deterministic beam search
                temperature=1.0,
                repetition_penalty=1.1,
                length_penalty=1.0,
                no_repeat_ngram_size=2
            )

        caption = processor.decode(generated_ids[0], skip_special_tokens=True)
        return caption.strip()

    except Exception as e:
        print(f"Error in ultra inference: {e}")
        return "Error generating caption"

# === Quick Test Function ===
def quick_ultra_test():
    """Smoke-test model loading, dataset construction and DataLoader construction.

    No batch is actually drawn from the loader, so missing image files are not
    detected here.

    Returns:
        bool: ``True`` if every step completes; ``False`` (after printing the
        error) otherwise.
    """
    print("Quick Ultra Test")
    print("-" * 30)

    try:
        model, processor = load_ultra_model()
        print("Ultra model loaded successfully")

        sample_frame = pd.DataFrame({
            'image_id': ['test1', 'test2'],
            'caption': ['a test image', 'another test caption'],
        })
        smoke_dataset = UltraCaptionDataset(
            sample_frame,
            config.IMAGE_FOLDER,
            processor,
            is_training=False,
        )
        print("Ultra dataset creation successful")

        DataLoader(smoke_dataset, batch_size=1, collate_fn=ultra_collate_fn)
        print("Ultra data loader creation successful")

        print("All ultra components working correctly!")
        return True
    except Exception as e:
        print(f"Ultra test failed: {e}")
        return False

if __name__ == "__main__":
    # To smoke-test the setup first, gate main() on quick_ultra_test() returning
    # True; by default training starts immediately.
    main()

In [None]:
import os
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

# === Load fine-tuned model and processor ===
model_path = "/content/drive/MyDrive/Obss Data/blip_finetuned_ULTRA"
# Fall back to CPU so the cell still runs (slowly) on a runtime without a GPU.
inference_device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained(model_path)
model = BlipForConditionalGeneration.from_pretrained(model_path).to(inference_device).eval()

# === Load test CSV and test images folder ===
test_csv = "/content/drive/MyDrive/Obss Data/test.csv"  # contains column 'image_id'
test_image_folder = "/content/drive/MyDrive/Obss Data/test/test"
test_df = pd.read_csv(test_csv)

results = []
failed_ids = []

# === Inference loop: one beam-search caption per test image ===
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    image_id = row["image_id"]
    image_path = os.path.join(test_image_folder, f"{image_id}.jpg")

    try:
        # Context manager closes the file even if decoding fails
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(inference_device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=30,
                num_beams=5,
                do_sample=False
            )
        caption = processor.decode(output_ids[0], skip_special_tokens=True).strip()
    except Exception as e:
        # Keep one row per image so the submission stays aligned with test.csv;
        # the error text is written in place of a caption.
        caption = f"[ERROR: {str(e)}]"
        failed_ids.append(image_id)

    results.append({
        "image_id": image_id,
        "caption": caption
    })

# === Save predictions ===
submission_df = pd.DataFrame(results)
assert len(submission_df) == len(test_df), "every test image must get exactly one row"
submission_path = "/content/drive/MyDrive/Obss Data/submission_blip_finetuned_v11.csv"
submission_df.to_csv(submission_path, index=False)
if failed_ids:
    print(f"WARNING: {len(failed_ids)} images failed to caption; first few: {failed_ids[:5]}")
print(f"Captions saved to {submission_path}")
