##  Basic Library imports

In [None]:
# %pip install -U transformers datasets peft trl bitsandbytes accelerate qwen-vl-utils pillow tensorboard

In [1]:
import os
import pandas as pd
import json
from urllib.parse import urlparse
import warnings
import torch
from PIL import Image
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["OMP_NUM_THREADS"] = "16"
os.environ["UNSLOTH_DISABLE_TRAINER_PATCHING"] = "1"  
os.environ["UNSLOTH_NO_CUDA_EXTENSIONS"] = "1"

warnings.filterwarnings("ignore")

DATASET_FOLDER = "../dataset/"

In [2]:
FULL_TRAINING_MODE = True

# Set to True to run inference on test set and generate submission
RUN_INFERENCE = False

# Set to True to skip dataset creation if already created
SKIP_DATASET_CREATION = False

# Set to True to skip image downloading if already done
SKIP_IMAGE_DOWNLOAD = False

INFERENCE_CHECKPOINT = None  # None = use default output_dir

INFERENCE_ONLY = False

RESUME_FROM_CHECKPOINT = None



# Options for resuming:
# RESUME_FROM_CHECKPOINT = "./qwen2-vl-7b-price-predictor-best/periodic_checkpoint"  # Latest periodic
# RESUME_FROM_CHECKPOINT = "./qwen2-vl-7b-price-predictor-best/checkpoint-epoch-1"  # Specific epoch
# RESUME_FROM_CHECKPOINT = True  # Auto-detect latest checkpoint

print("Configuration:")
print(f"  Full Training Mode: {FULL_TRAINING_MODE}")
print(f"  Run Inference: {RUN_INFERENCE}")
print(f"  Skip Dataset Creation: {SKIP_DATASET_CREATION}")
print(f"  Skip Image Download: {SKIP_IMAGE_DOWNLOAD}")



"""
development (validation mode):
FULL_TRAINING_MODE = False
INFERENCE_ONLY = False
RUN_INFERENCE = False
SKIP_DATASET_CREATION = False
SKIP_IMAGE_DOWNLOAD = False

full training:
FULL_TRAINING_MODE = True   # Use all data
INFERENCE_ONLY = False
RUN_INFERENCE = False
SKIP_DATASET_CREATION = True   # Reuse
SKIP_IMAGE_DOWNLOAD = True     # Reuse

generate submission:
FULL_TRAINING_MODE = True   # Doesn't matter
INFERENCE_ONLY = True       # Skip training!
RUN_INFERENCE = True        # Generate predictions
SKIP_DATASET_CREATION = True
SKIP_IMAGE_DOWNLOAD = True

"""

Configuration:
  Full Training Mode: True
  Run Inference: False
  Skip Dataset Creation: False
  Skip Image Download: False


"\ndevelopment (validation mode):\nFULL_TRAINING_MODE = False\nINFERENCE_ONLY = False\nRUN_INFERENCE = False\nSKIP_DATASET_CREATION = False\nSKIP_IMAGE_DOWNLOAD = False\n\nfull training:\nFULL_TRAINING_MODE = True   # Use all data\nINFERENCE_ONLY = False\nRUN_INFERENCE = False\nSKIP_DATASET_CREATION = True   # Reuse\nSKIP_IMAGE_DOWNLOAD = True     # Reuse\n\ngenerate submission:\nFULL_TRAINING_MODE = True   # Doesn't matter\nINFERENCE_ONLY = True       # Skip training!\nRUN_INFERENCE = True        # Generate predictions\nSKIP_DATASET_CREATION = True\nSKIP_IMAGE_DOWNLOAD = True\n\n"

In [4]:
import json
import random

print("=" * 60)
print("Loading Pre-Split Datasets...")
print("=" * 60)

# Load training split
print("\n📂 Loading training data from train_split.jsonl...")
train_dataset = []
with open(f"{DATASET_FOLDER}/train_split.jsonl", 'r', encoding='utf-8') as f:
    for line_num, line in enumerate(f, 1):
        try:
            train_dataset.append(json.loads(line))
        except json.JSONDecodeError as e:
            print(f"⚠️ Warning: Skipping malformed JSON on line {line_num}: {e}")
            continue

print(f"✅ Loaded {len(train_dataset)} training samples")

# Load validation split
print("\n📂 Loading validation data from validation_split.jsonl...")
validation_dataset = []
with open(f"{DATASET_FOLDER}/validation_split.jsonl", 'r', encoding='utf-8') as f:
    for line_num, line in enumerate(f, 1):
        try:
            validation_dataset.append(json.loads(line))
        except json.JSONDecodeError as e:
            print(f"⚠️ Warning: Skipping malformed JSON on line {line_num}: {e}")
            continue

print(f"✅ Loaded {len(validation_dataset)} validation samples")

print("\n" + "=" * 60)
print("✅ Datasets ready for DDP training!")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(validation_dataset)}")
print("=" * 60)


Loading Pre-Split Datasets...

📂 Loading training data from train_split.jsonl...
✅ Loaded 67499 training samples

📂 Loading validation data from validation_split.jsonl...
✅ Loaded 7500 validation samples

✅ Datasets ready for DDP training!
Training samples: 67499
Validation samples: 7500


## Finetuning

In [None]:
import torch
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    BitsAndBytesConfig,
)
from unsloth import FastVisionModel
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import gc
import numpy as np

torch.cuda.empty_cache()
gc.collect()
torch.cuda.reset_peak_memory_stats()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
train_dataset_path = "./train_split.jsonl"
validation_dataset_path = "./validation_split.jsonl"
output_dir = "./qwen2.5-vl-3b-price-predictor-best"

print(f"Model: {model_id}")
print(f"Output directory: {output_dir}")

In [None]:
import torch
import torch.distributed as dist
from transformers import TrainerCallback
from PIL import Image
import numpy as np
import gc

class SMAPEEvaluationCallback(TrainerCallback):
    """
    DDP-aware SMAPE evaluation callback.
    - Runs SMAPE computation ONLY on rank 0 (main GPU)
    - All other GPUs wait at a barrier
    - Evaluates every `eval_steps` steps
    - NO early stopping (just monitoring)
    """
    def __init__(self, validation_dataset, processor, eval_steps=150):
        self.validation_dataset = validation_dataset
        self.processor = processor
        self.eval_steps = eval_steps
        self.smape_history = []
        self.best_smape = float('inf')
        
    def on_step_end(self, args, state, control, **kwargs):
        # Only evaluate at specified intervals
        if state.global_step % self.eval_steps != 0:
            return control
        
        # Check if we're in DDP mode
        is_distributed = dist.is_initialized()
        is_main_process = not is_distributed or dist.get_rank() == 0
        
        if is_main_process:
            print(f"\n{'=' * 60}")
            print(f"🎯 SMAPE Evaluation at Step {state.global_step}")
            print(f"{'=' * 60}")
            
            model = kwargs.get("model")
            if model is None:
                print("⚠️ Warning: Model not found in callback kwargs")
                if is_distributed:
                    dist.barrier()  # Still need to sync with other processes
                return control
            
            # Set model to eval mode
            model.eval()
            
            predictions = []
            actuals = []
            
            # Run inference on validation set
            with torch.no_grad():
                for i, example in enumerate(self.validation_dataset):
                    try:
                        # Prepare input
                        messages = example["messages"]
                        text = self.processor.apply_chat_template(
                            messages, tokenize=False, add_generation_prompt=True
                        )
                        
                        # Extract images
                        images = []
                        for msg in messages:
                            if isinstance(msg["content"], list):
                                for item in msg["content"]:
                                    if isinstance(item, dict) and item.get("type") == "image":
                                        img_path = item["image"]
                                        try:
                                            img = Image.open(img_path).convert("RGB")
                                            images.append(img)
                                        except Exception as e:
                                            print(f"⚠️ Error loading image {img_path}: {e}")
                                            continue
                        
                        # Process inputs
                        inputs = self.processor(
                            text=[text],
                            images=[images if images else None],
                            return_tensors="pt",
                            padding=True,
                        ).to(model.device)
                        
                        # Generate prediction
                        generated_ids = model.generate(
                            **inputs,
                            max_new_tokens=20,
                            num_beams=1,
                            do_sample=False,
                            pad_token_id=self.processor.tokenizer.pad_token_id,
                            eos_token_id=self.processor.tokenizer.eos_token_id,
                        )
                        
                        # Decode prediction
                        generated_ids_trimmed = generated_ids[0][len(inputs.input_ids[0]):]
                        prediction = self.processor.decode(
                            generated_ids_trimmed,
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=False,
                        ).strip()
                        
                        # Parse prediction
                        try:
                            pred_price = float(prediction)
                        except (ValueError, TypeError):
                            pred_price = 0.0
                        
                        # Get actual price from assistant response
                        actual_price = None
                        for msg in messages:
                            if msg["role"] == "assistant":
                                try:
                                    actual_price = float(msg["content"])
                                except (ValueError, TypeError):
                                    continue
                        
                        if actual_price is not None:
                            predictions.append(pred_price)
                            actuals.append(actual_price)
                        
                        # Clean up
                        del inputs, generated_ids
                        
                        # Progress indicator (every 50 samples)
                        if (i + 1) % 50 == 0:
                            print(f"  Evaluated {i + 1}/{len(self.validation_dataset)} samples...")
                            
                    except Exception as e:
                        print(f"⚠️ Error processing sample {i}: {e}")
                        continue
            
            # Calculate SMAPE
            if len(predictions) > 0:
                smape = self._calculate_smape(actuals, predictions)
                self.smape_history.append(smape)
                
                # Update best SMAPE
                if smape < self.best_smape:
                    self.best_smape = smape
                    print(f"\n🎉 New best SMAPE: {smape:.2f}% (improved!)")
                else:
                    print(f"\n📊 Current SMAPE: {smape:.2f}% (best: {self.best_smape:.2f}%)")
                
                print(f"Samples evaluated: {len(predictions)}")
            else:
                print("\n⚠️ No valid predictions - SMAPE not calculated")
            
            # Cleanup memory
            torch.cuda.empty_cache()
            gc.collect()
            
            # Set model back to training mode
            model.train()
            
            print(f"{'=' * 60}\n")
        
        # CRITICAL: Synchronize all processes
        # Non-main processes wait here while rank 0 does SMAPE eval
        if is_distributed:
            dist.barrier()  # All GPUs must reach this point before continuing
            if not is_main_process:
                print(f"[Rank {dist.get_rank()}] ⏸️  Waited for SMAPE evaluation to complete")
        
        return control
    
    def _calculate_smape(self, actual, predicted):
        """Calculate Symmetric Mean Absolute Percentage Error"""
        actual = np.array(actual)
        predicted = np.array(predicted)
        
        numerator = np.abs(predicted - actual)
        denominator = (np.abs(actual) + np.abs(predicted)) / 2
        
        # Avoid division by zero
        mask = denominator != 0
        smape_values = np.zeros_like(numerator)
        smape_values[mask] = numerator[mask] / denominator[mask]
        
        return np.mean(smape_values) * 100


In [None]:
from transformers import TrainerCallback

print("🚀 Loading Qwen2.5-VL-3B with Unsloth optimization...")
model, tokenizer = FastVisionModel.from_pretrained(
    model_name=model_id,
    max_seq_length=384,
    load_in_4bit=True,
    dtype=torch.bfloat16,
    # use_gradient_checkpointing=True,
    use_gradient_checkpointing=False,
    trust_remote_code=True,
)

# Enable training mode
model = FastVisionModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0.,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", 
                   "gate_proj", "up_proj", "down_proj"],
    bias="none",
    # use_gradient_checkpointing="unsloth",
    use_gradient_checkpointing=False,
    random_state=42,
)


In [None]:
print("Loading processor...")
min_pix = 256 * 28 * 28
max_pix = 512 * 28 * 28

processor = AutoProcessor.from_pretrained(model_id, min_pixels=min_pix, max_pixels=max_pix, trust_remote_code=True)

if processor.tokenizer.pad_token is None:
    processor.tokenizer.pad_token = processor.tokenizer.eos_token
    processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
processor.tokenizer.padding_side = "right"
print(f"   EOS token: '{processor.tokenizer.eos_token}' (ID: {processor.tokenizer.eos_token_id})")
print(f"   PAD token: '{processor.tokenizer.pad_token}' (ID: {processor.tokenizer.pad_token_id})")
print("Processor loaded successfully!")
print(f"Vocab size: {len(processor.tokenizer)}")
print(f"   Min pixels: {min_pix:,} (~256x256 after resize)")
print(f"   Max pixels: {max_pix:,} (~512x512 after resize)")

In [None]:
def collate_fn(examples):
    """
    Custom data collator for Qwen2-VL multimodal inputs.
    
    What it does:
    1. Applies chat template to convert messages to text format
    2. Extracts image paths from message structure
    3. Processes both text and images together with the processor
    4. Creates labels for causal language modeling
    """
    texts = [
        processor.apply_chat_template(
            example["messages"], 
            tokenize=False,
            add_generation_prompt=False
        )
        for example in examples
    ]

    image_inputs = []
    for example in examples:
        images = []
        for msg in example["messages"]:
            if isinstance(msg["content"], list):
                for item in msg["content"]:
                    if isinstance(item, dict) and item.get("type") == "image":
                        try:
                            img = Image.open(item["image"]).convert("RGB")
                            images.append(img)
                        except Exception as e:
                            print(f"⚠️ Error loading image: {e}")
                            continue
        image_inputs.append(images if images else None)

    batch = processor(
        text=texts,
        images=image_inputs,
        return_tensors="pt",
        padding=True,
    )

    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    
    return batch

print("Custom collator function defined!")


In [None]:
# Training configuration using SFTConfig
training_args = SFTConfig(
    # Basic settings
    output_dir=output_dir,
    num_train_epochs=3,
    # Batch size (A100-optimized)
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch = 16
    # Memory optimization
    gradient_checkpointing=False,
    # A100-specific optimizations
    optim="adamw_torch_fused",
    tf32=True,  # A100 TensorFloat-32
    bf16=True,
    # Training hyperparameters
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=0.3,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # Logging
    logging_steps=25,
    logging_first_step=True,
    # Evaluation & checkpointing
    eval_strategy="no",
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=False,
    # VLM-specific (CRITICAL!)
    remove_unused_columns=False,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
    max_length=384,
    packing=False,
    # A100 performance
    dataloader_num_workers=16,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=True,
    dataloader_prefetch_factor=8,
    # Reproducibility
    torch_empty_cache_steps=20,
    seed=42,
    data_seed=42,
    ddp_find_unused_parameters=False,
)

print("Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(
    f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
)
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Mixed precision: bf16={training_args.bf16}")


In [None]:
class MemoryCleanupCallback(TrainerCallback):
    """Clean up GPU memory between epochs"""
    def on_epoch_end(self, args, state, control, **kwargs):
        torch.cuda.empty_cache()
        gc.collect()
        print("\n🧹 GPU memory cleaned up")
        return control


In [None]:
import os
import shutil
import json
import signal
import torch.distributed as dist
from transformers import TrainerCallback

class PeriodicCheckpointCallback(TrainerCallback):
    """
    DDP-safe periodic checkpoint callback.
    - Saves every N steps (step-based)
    - Only rank 0 saves the checkpoint
    - Other ranks wait at barrier
    - Handles keyboard interrupt gracefully
    """
    def __init__(self, save_steps=250, checkpoint_dir="periodic_checkpoint"):
        self.save_steps = save_steps
        self.checkpoint_dir = checkpoint_dir
        self.last_save_step = 0
        self.saving_in_progress = False
        
    def on_step_end(self, args, state, control, **kwargs):
        # Check if it's time to save
        if state.global_step - self.last_save_step < self.save_steps:
            return control
        
        # Check if we're in DDP mode
        is_distributed = dist.is_initialized()
        is_main_process = not is_distributed or dist.get_rank() == 0
        
        if is_main_process:
            self.saving_in_progress = True  # Flag to prevent Ctrl+C during save
            
            checkpoint_path = os.path.join(args.output_dir, self.checkpoint_dir)
            
            # Remove old periodic checkpoint
            if os.path.exists(checkpoint_path):
                print(f"\n🔄 Replacing periodic checkpoint at step {state.global_step}")
                shutil.rmtree(checkpoint_path)
            else:
                print(f"\n💾 Creating periodic checkpoint at step {state.global_step}")
            
            # Get model from kwargs
            model = kwargs.get("model")
            if model is None:
                print("⚠️ Warning: Model not found in callback kwargs, skipping checkpoint")
                self.saving_in_progress = False
                if is_distributed:
                    dist.barrier()
                return control
            
            try:
                # Save model checkpoint
                model.save_pretrained(checkpoint_path)
                
                # Save trainer state
                state_path = os.path.join(checkpoint_path, "trainer_state.json")
                with open(state_path, 'w') as f:
                    state_dict = {
                        'epoch': state.epoch,
                        'global_step': state.global_step,
                        'max_steps': state.max_steps,
                        'num_train_epochs': state.num_train_epochs,
                        'log_history': state.log_history[-10:],  # Keep last 10 entries
                        'best_metric': state.best_metric,
                        'best_model_checkpoint': state.best_model_checkpoint,
                    }
                    json.dump(state_dict, f, indent=2)
                
                print(f"✅ Periodic checkpoint saved (step {state.global_step})")
                self.last_save_step = state.global_step
                
            except Exception as e:
                print(f"❌ Error saving checkpoint: {e}")
            finally:
                self.saving_in_progress = False
        
        # CRITICAL: Synchronize all processes
        if is_distributed:
            dist.barrier()  # All GPUs wait until rank 0 finishes saving
            if not is_main_process:
                print(f"[Rank {dist.get_rank()}] ⏸️  Waited for checkpoint save to complete")
        
        return control


In [None]:
import signal
import sys

class GracefulInterruptHandler:
    """
    Handles Ctrl+C gracefully - waits if checkpoint is being saved
    """
    def __init__(self, checkpoint_callback):
        self.checkpoint_callback = checkpoint_callback
        self.interrupted = False
        signal.signal(signal.SIGINT, self.handle_interrupt)
    
    def handle_interrupt(self, signum, frame):
        if self.checkpoint_callback.saving_in_progress:
            print("\n⚠️  Checkpoint save in progress... Please wait!")
            print("(Forcing exit may corrupt the checkpoint)")
            return
        
        self.interrupted = True
        print("\n🛑 Keyboard interrupt received. Stopping training gracefully...")
        print("(Last periodic checkpoint is safe to use)")
        sys.exit(0)

In [None]:
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=None,
    data_collator=collate_fn,
    processing_class=processor.tokenizer,
)
trainer.add_callback(MemoryCleanupCallback())

# ADD PERIODIC CHECKPOINT CALLBACK (Always enabled!)
periodic_checkpoint = PeriodicCheckpointCallback(
    save_steps=250,  # Save every 10 steps
    checkpoint_dir="periodic_checkpoint"  # Single rolling checkpoint
)
trainer.add_callback(periodic_checkpoint)
interrupt_handler = GracefulInterruptHandler(periodic_checkpoint)
smape_callback = SMAPEEvaluationCallback(
    validation_dataset=validation_dataset,
    processor=processor,
    eval_steps=150  # Evaluate every 150 steps
)
trainer.add_callback(smape_callback)


print("✅ Periodic checkpoint callback added (every 10 steps)")

print("\n" + "=" * 50)
print("Trainer initialized successfully!")
print("=" * 50)

In [None]:
if not INFERENCE_ONLY:    
    print("\n" + "=" * 50)
    print("Starting Fine-Tuning...")
    print("=" * 50 + "\n")

    if RESUME_FROM_CHECKPOINT:
        print(f"🔄 Resuming from checkpoint: {RESUME_FROM_CHECKPOINT}")
    
    trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT)

    print("\n" + "=" * 50)
    print("Training Complete!")
    if smape_callback:
        print(f"Best SMAPE achieved: {smape_callback.best_smape:.2f}%")
        print(f"SMAPE history: {smape_callback.smape_history}")
    print("=" * 50)

    print("\n" + "=" * 60)
    print("TRAINING SUMMARY")
    print("=" * 60)
    print(f"Total epochs: {trainer.state.epoch}")
    print(f"Total steps: {trainer.state.global_step}")
    print(f"Best checkpoint: {trainer.state.best_model_checkpoint}")

    if smape_callback:
        print(f"\nSMAPE Performance:")
        print(f"  Initial SMAPE: {smape_callback.smape_history[0]:.2f}%")
        print(f"  Final SMAPE: {smape_callback.smape_history[-1]:.2f}%")
        print(f"  Best SMAPE: {smape_callback.best_smape:.2f}%")
        print(
            f"  Improvement: {smape_callback.smape_history[0] - smape_callback.best_smape:.2f}%"
        )
        print(f"\n  Epoch-by-epoch SMAPE:")
        for i, smape in enumerate(smape_callback.smape_history, 1):
            print(f"    Epoch {i}: {smape:.2f}%")

    print("=" * 60)

    print("\n" + "=" * 50)
    print("Saving the Best Model...")
    print("=" * 50)

    trainer.save_model(output_dir)

    processor.save_pretrained(output_dir)

    print(f"\n*** Model and processor saved to: {output_dir}")
    print("\nSaved files:")
    print("  - adapter_config.json (LoRA configuration)")
    print("  - adapter_model.safetensors (LoRA weights)")
    print("  - tokenizer files")
    print("  - processor configuration")
else:
    print("\n***  Skipping training (INFERENCE_ONLY=True)")

In [None]:
from datasets import load_dataset
from tqdm import tqdm
import torch.distributed as dist
import gc

if RUN_INFERENCE:
    is_distributed = dist.is_initialized()
    is_main_process = not is_distributed or dist.get_rank() == 0

    if is_main_process:
        print("\n" + "=" * 60)
        print("Running Inference on Test Set...")
        print("=" * 60)
    
        test_dataset = load_dataset(
            "json", data_files=f"{DATASET_FOLDER}/test.jsonl", split="train"
        )
        print(f"Test samples: {len(test_dataset)}")
    
        checkpoint_path = INFERENCE_CHECKPOINT if INFERENCE_CHECKPOINT else output_dir
        if os.path.exists(f"{checkpoint_path}/adapter_model.safetensors"):
            print(f"\n*** Loading adapter from: {checkpoint_path}")
            from peft import PeftModel
    
            try:
                _ = model.device
            except:
                print("Loading base model...")
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model_id,
                    quantization_config=bnb_config,
                    device_map="auto",
                    trust_remote_code=True,
                )
    
            model = PeftModel.from_pretrained(model, checkpoint_path)
            model.eval()
            print("*** Fine-tuned adapter loaded!")
        else:
            print("*** Using model from training session (no saved adapter found)")
            model.eval()
    
        predictions = []
        sample_ids = []
        batch_size = 16
    
        for batch_start in tqdm(range(0, len(test_dataset), batch_size), desc="Inference"):
            batch_end = min(batch_start + batch_size, len(test_dataset))
            batch_examples = test_dataset[batch_start:batch_end]
    
            # Ensure batch_examples is properly formatted
            if not isinstance(batch_examples["sample_id"], list):
                batch_examples = {k: [v] for k, v in batch_examples.items()}
    
            batch_sample_ids = batch_examples["sample_id"]
            batch_messages = batch_examples["messages"]
    
            # Initialize batch containers
            batch_texts = []
            batch_images = []
    
            # Process all messages in a SINGLE loop (FIX: was doing two loops!)
            for messages in batch_messages:
                # Process text
                text = processor.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                batch_texts.append(text)
                
                # Process images
                images = []
                for msg in messages:
                    if isinstance(msg["content"], list):
                        for item in msg["content"]:
                            if isinstance(item, dict) and item.get("type") == "image":
                                try:
                                    img = Image.open(item["image"]).convert("RGB")
                                    images.append(img)
                                except Exception as e:
                                    print(f"⚠️ Error loading image {item['image']}: {e}")
                                    continue
                
                # Append images list (or None if no images)
                batch_images.append(images if images else None)
    
            # Process inputs
            inputs = processor(
                text=batch_texts,
                images=batch_images,
                return_tensors="pt",
                padding=True,
            ).to("cuda")
    
            # Generate predictions
            with torch.no_grad():
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=20,
                    num_beams=1,
                    do_sample=False,
                    temperature=None,
                    top_p=None,
                    pad_token_id=processor.tokenizer.pad_token_id,
                    eos_token_id=processor.tokenizer.eos_token_id,
                )
    
            # Decode predictions
            for i, (in_ids, out_ids) in enumerate(zip(inputs.input_ids, generated_ids)):
                generated_ids_trimmed = out_ids[len(in_ids):]
                prediction = processor.decode(
                    generated_ids_trimmed,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                ).strip()
    
                try:
                    predicted_price = float(prediction)
                except (ValueError, TypeError):
                    print(
                        f"⚠️ Warning: Non-numeric prediction '{prediction}' for sample {batch_sample_ids[i]}, using 0.0"
                    )
                    predicted_price = 0.0
    
                predictions.append(predicted_price)
                sample_ids.append(batch_sample_ids[i])
            
            # Clean up batch memory
            del inputs, generated_ids
    
            # Progress update
            if batch_end % 500 == 0 or batch_end == len(test_dataset):
                print(f"  Processed {batch_end}/{len(test_dataset)} samples...")
    
        # Create submission DataFrame
        submission_df = pd.DataFrame({"sample_id": sample_ids, "price": predictions})
        submission_path = "test_out.csv"
    
        # Validate submission format
        print("\n✅ Validating submission format...")
        assert "sample_id" in submission_df.columns, "Missing sample_id column"
        assert "price" in submission_df.columns, "Missing price column"
        assert len(submission_df) == len(test_dataset), (
            f"Mismatch: {len(submission_df)} predictions vs {len(test_dataset)} test samples"
        )
        assert submission_df["sample_id"].duplicated().sum() == 0, (
            "Duplicate sample_ids found!"
        )
        print("✅ Submission format validated")
    
        # Check for failed predictions
        failed_predictions = sum(1 for p in predictions if p == 0.0)
        if failed_predictions > len(predictions) * 0.1:
            print(
                f"\n⚠️ WARNING: {failed_predictions}/{len(predictions)} predictions failed (returned 0.0)"
            )
            print("This suggests the model may need more training or better prompts")
    
        # Save submission
        submission_df.to_csv(submission_path, index=False)
    
        print(f"\n{'=' * 60}")
        print(f"✅ Submission saved to: {submission_path}")
        print(f"Total predictions: {len(predictions)}")
        print(f"\nSample predictions:")
        print(submission_df.head(10))
        print(f"{'=' * 60}")
    
        # Final memory cleanup
        del test_dataset, submission_df
        torch.cuda.empty_cache()
        gc.collect()
        print("\n✅ GPU memory cleaned up")
    
    else:
        print(f"[Rank {dist.get_rank()}] Skipping inference (only rank 0 runs inference)")
    
    # CRITICAL: Sync all processes after inference completes
    if is_distributed:
        dist.barrier()
        print(f"[Rank {dist.get_rank() if dist.is_initialized() else 0}] ✅ All processes synchronized after inference")
