In [1]:
# ------------------------------
# Cell 0: Define Configuration
# ------------------------------
class Config:
    dataset_name = "coco_caption"  # or your dataset
    max_samples = 10               # limit for testing
    image_token = "<image>"        # token used in LLaVA prompts

config = Config()
print(f"Set!")


Set!


In [None]:
# Cell 1: Install Required Dependencies
"""
Install all necessary packages for LLaVA fine-tuning with DeepSpeed support.
We need transformers, deepspeed, peft (for LoRA), and other ML libraries.
"""
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install --upgrade --no-cache-dir transformers sentencepiece
!pip install --upgrade transformers sentencepiece
!pip install deepspeed>=0.12.0     
!pip install peft==0.9.0
!pip install accelerate>=0.25.0    
!pip install datasets==3.6.0
!pip install Pillow>=9.0.0         
!pip install requests              
!pip install pycocotools           
!pip install wandb                 
!pip show peft  


[0mLooking in indexes: https://download.pytorch.org/whl/cu118
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2

In [None]:
import sys
required_version = "0.1.99"

try:
    import sentencepiece as spm
    from packaging import version
    if version.parse(spm.__version__) >= version.parse(required_version):
        print(f"SentencePiece version {spm.__version__} is compatible ")
    else:
        print(f"SentencePiece version {spm.__version__} is too old ")
except ImportError:
    print("SentencePiece is not installed ❌")


SentencePiece version 0.2.1 is compatible ✅


In [None]:
# Cell 2: Import Required Libraries
"""
Import all necessary libraries for the fine-tuning process.
Each import serves a specific purpose in our pipeline.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import transformers
from transformers import (
    AutoTokenizer,                    
    TrainingArguments,              
    Trainer,                         
    EarlyStoppingCallback            
)
from peft import (
    LoraConfig,                      
    get_peft_model,                  
    TaskType,                        
    prepare_model_for_kbit_training  
)
import deepspeed                     
from datasets import load_dataset    
import os
import json
from PIL import Image               
import requests                     
from typing import Dict, List, Any, Optional
import numpy as np
import warnings
warnings.filterwarnings('ignore')  

print("All libraries imported successfully!")

All libraries imported successfully!


In [None]:
# ------------------------------
# Cell 3: Configuration & Hyperparameters
# ------------------------------
class Config:
    
    model_name = "llava-hf/llava-1.5-7b-hf"
    processor_name = "llava-hf/llava-1.5-7b-hf"

    
    dataset_name = "jxie/coco_captions"  
    max_samples = 1000
    image_token = "<image>"

    
    lora_r = 16
    lora_alpha = 32
    lora_dropout = 0.1
    lora_target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ]

   
    output_dir = "./llava-coco-lora"
    num_train_epochs = 3
    per_device_train_batch_size = 2
    per_device_eval_batch_size = 2
    gradient_accumulation_steps = 8
    learning_rate = 2e-4
    warmup_steps = 100
    logging_steps = 50
    save_steps = 500
    eval_steps = 500
    max_grad_norm = 1.0
    dataloader_num_workers = 4

config = Config()
print(f"Configuration loaded. Training on {torch.cuda.device_count()} GPU(s)")

Configuration loaded. Training on 1 GPU(s)


In [None]:
# ------------------------------
# Load LLaVA Processor and Model
# ------------------------------
from transformers import AutoProcessor


processor = AutoProcessor.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    trust_remote_code=True
)


import importlib

llava_module = importlib.import_module("transformers.models.llava.modeling_llava")
LlavaForConditionalGeneration = getattr(llava_module, "LlavaForConditionalGeneration")


model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    trust_remote_code=True,
    device_map="auto"
)

print("Processor and model loaded successfully!")


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Processor and model loaded successfully!


In [None]:

from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset
import torch


config.dataset_name = "yerevann/coco-karpathy" 
config.max_samples = None


print(f"Loading dataset: {config.dataset_name}")


try:
    dataset_stream = load_dataset(config.dataset_name, split="train", streaming=True)
    dataset_stream = dataset_stream.take(config.max_samples)
    print("Dataset loaded successfully!")
except Exception as e:
    print(f"Error loading dataset: {e}")
    # Fallback options
    print("Trying alternative dataset...")
    try:
        config.dataset_name = "HuggingFaceM4/COCO"
        dataset_stream = load_dataset(config.dataset_name, split="train", streaming=True)
        dataset_stream = dataset_stream.take(config.max_samples)
        print("Alternative dataset loaded!")
    except Exception as e2:
        print(f"Alternative also failed: {e2}")
        print("You may need to use a different dataset or check your internet connection")

class COCODataset(Dataset):
    def __init__(self, dataset_stream, processor, max_length=512, max_samples=None):
        self.processor = processor
        self.max_length = max_length
        self.max_samples = max_samples
        
        
        if max_samples is not None:
            self.dataset = []
            for i, example in enumerate(dataset_stream):
                if i >= max_samples:
                    break
                self.dataset.append(example)
        else:
            self.dataset = list(dataset_stream)
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        
       
        try:
            if isinstance(sample.get('image'), Image.Image):
                image = sample['image']
            elif 'url' in sample:
                image = Image.new('RGB', (224, 224), color='white')
                print(f"Using placeholder image for sample {idx} (URL: {sample.get('url', 'N/A')})")
            else:
                image = Image.open(sample['image']).convert('RGB')
        except Exception as e:
            print(f"Error loading image at index {idx}: {e}")
            
            image = Image.new('RGB', (224, 224), color='white')
        
    
        if 'caption' in sample:
            caption = sample['caption']
        elif 'sentences' in sample and len(sample['sentences']) > 0:
          
            caption = sample['sentences'][0]
        elif 'text' in sample:
            caption = sample['text']
        else:
           
            caption = "A photo."
            print(f"No caption found for sample {idx}, using fallback")
        
        conversation = [
            {"role": "user", "content": f"{config.image_token}\nDescribe this image."},
            {"role": "assistant", "content": caption}
        ]
        
        text = self.processor.apply_chat_template(
            conversation, 
            tokenize=False, 
            add_generation_prompt=False
        )
        
        inputs = self.processor(
            text=text,
            images=image,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        labels = inputs["input_ids"].clone()
        
        
        user_tokenized = self.processor.tokenizer.apply_chat_template(
            [conversation[0]],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        )
        user_length = user_tokenized.shape[1]
        labels[:, :user_length] = -100
        
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": labels.squeeze(0)
        }


try:
    train_dataset = COCODataset(dataset_stream, processor, max_samples=config.max_samples)
    print(f"COCO dataset ready with {len(train_dataset)} samples for fine-tuning.")
    
   
    if len(train_dataset) > 0:
        test_sample = train_dataset[0]
        print(f"Sample shapes:")
        for key, value in test_sample.items():
            print(f"  {key}: {value.shape}")
    else:
        print("Warning: Dataset is empty!")
        
except Exception as e:
    print(f"Error creating dataset: {e}")
    print("Please check the dataset format and try again.")

Loading dataset: yerevann/coco-karpathy
Dataset loaded successfully!
Error creating dataset: unsupported operand type(s) for -: 'NoneType' and 'int'
Please check the dataset format and try again.


In [None]:

"""
Configure and apply LoRA (Low-Rank Adaptation) safely.
Ensures modules exist, supports k-bit training, and prepares for checkpoints.
"""
print("Configuring LoRA...")

from peft import LoraConfig, get_peft_model, TaskType

config.lora_r = 8               
config.lora_alpha = 16          
config.lora_dropout = 0.05      
config.lora_target_modules = ["q_proj", "v_proj"]  



lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
    target_modules=config.lora_target_modules,
    bias="none",
    inference_mode=False,
)


valid_modules = [name for name, _ in model.named_modules()]
missing = [m for m in lora_config.target_modules if not any(m in v for v in valid_modules)]
if missing:
    print(f"Warning: Some LoRA target_modules not found in model: {missing}")
    print(f"Available modules (sample): {valid_modules[:50]}")

    lora_config.target_modules = [m for m in lora_config.target_modules if m in valid_modules]


model = get_peft_model(model, lora_config)


trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,}")
print(f"Total parameters: {total_params:,}")
print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%")


model.gradient_checkpointing_enable()
print("LoRA applied successfully and model ready for training!")


Configuring LoRA...
Trainable parameters: 4,980,736
Total parameters: 7,068,407,808
Percentage trainable: 0.07%
LoRA applied successfully and model ready for training!


In [None]:

"""
Create or load cached training/validation datasets and define the data collator.
This avoids reprocessing the dataset on restart and ensures reproducibility.
"""
import pickle
import json
print("Preparing datasets...")

cache_dir = os.path.join(config.output_dir, "dataset_cache")
os.makedirs(cache_dir, exist_ok=True)
train_cache_path = os.path.join(cache_dir, "train_dataset.pkl")
val_cache_path = os.path.join(cache_dir, "val_dataset.pkl")
cache_info_path = os.path.join(cache_dir, "cache_info.json")


if processor.tokenizer.pad_token_id is None:
    processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
    print(f"pad_token_id was None, set to eos_token_id ({processor.tokenizer.eos_token_id})")


cache_valid = False
if os.path.exists(train_cache_path) and os.path.exists(val_cache_path) and os.path.exists(cache_info_path):
    try:
        with open(cache_info_path, "r") as f:
            cache_info = json.load(f)

        if (cache_info.get("dataset_name") == config.dataset_name and 
            cache_info.get("max_samples") == config.max_samples):
            cache_valid = True
            print("Found valid cached datasets...")
    except:
        print("Cache info corrupted, will regenerate...")

if cache_valid:
    print("Loading cached datasets...")
    with open(train_cache_path, "rb") as f:
        train_dataset = pickle.load(f)
    with open(val_cache_path, "rb") as f:
        val_dataset = pickle.load(f)
else:
    print("No valid cached dataset found. Processing datasets from scratch...")
    

    dataset_stream = load_dataset(config.dataset_name, split="train", streaming=True)
    
  
    full_dataset = COCODataset(
        dataset_stream, 
        processor, 
        max_length=512, 
        max_samples=config.max_samples
    )
    
  
    dataset_size = len(full_dataset)
    val_size = max(1, int(0.1 * dataset_size))  
    train_size = dataset_size - val_size
    
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    

    try:
        with open(train_cache_path, "wb") as f:
            pickle.dump(train_dataset, f)
        with open(val_cache_path, "wb") as f:
            pickle.dump(val_dataset, f)
        

        cache_info = {
            "dataset_name": config.dataset_name,
            "max_samples": config.max_samples,
            "train_size": train_size,
            "val_size": val_size
        }
        with open(cache_info_path, "w") as f:
            json.dump(cache_info, f)
            
        print("Datasets processed and cached successfully!")
    except Exception as e:
        print(f"Warning: Could not cache datasets: {e}")

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")


class LLaVADataCollator:
    """
    Custom data collator for LLaVA training.
    Handles proper batching of images, text, and labels.
    """
    def __init__(self, processor):
        self.processor = processor
    
    def __call__(self, batch):
  
        input_ids = [item["input_ids"] for item in batch]
        attention_masks = [item["attention_mask"] for item in batch]
        pixel_values = [item["pixel_values"] for item in batch]
        labels = [item["labels"] for item in batch]
        
 
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.processor.tokenizer.pad_token_id
        )
        attention_masks = torch.nn.utils.rnn.pad_sequence(
            attention_masks, batch_first=True, padding_value=0
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100
        )
        
      
        pixel_values = torch.stack(pixel_values)
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_masks,
            "pixel_values": pixel_values,
            "labels": labels
        }

data_collator = LLaVADataCollator(processor)
print("Data collator created successfully!")


if len(train_dataset) > 0:
    try:
        sample_batch = [train_dataset[0], train_dataset[0]]
        test_batch = data_collator(sample_batch)
        print(f"Data collator test successful. Batch shapes:")
        for key, value in test_batch.items():
            print(f"  {key}: {value.shape}")
    except Exception as e:
        print(f"Warning: Data collator test failed: {e}")

Preparing datasets...
No valid cached dataset found. Processing datasets from scratch...
Datasets processed and cached successfully!
Training samples: 74505
Validation samples: 8278
Data collator created successfully!
Using placeholder image for sample 81069 (URL: http://images.cocodataset.org/train2014/COCO_train2014_000000397355.jpg)
Using placeholder image for sample 81069 (URL: http://images.cocodataset.org/train2014/COCO_train2014_000000397355.jpg)
Data collator test successful. Batch shapes:
  input_ids: torch.Size([2, 10])
  attention_mask: torch.Size([2, 10])
  pixel_values: torch.Size([2, 3, 336, 336])
  labels: torch.Size([2, 10])


In [12]:
# ------------------------------
# Cell: Check Transformers & SentencePiece Versions
# ------------------------------
import transformers
import sentencepiece
import torch

print(f"Transformers version: {transformers.__version__}")
print(f"SentencePiece version: {sentencepiece.__version__}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
print(f"bfloat16 supported: {torch.cuda.is_bf16_supported()}")


Transformers version: 4.56.1
SentencePiece version: 0.2.1
PyTorch version: 2.8.0.dev20250319+cu128
CUDA available: True
CUDA device count: 1
bfloat16 supported: True


In [None]:
# Cell 8: Training Arguments Configuration (Restart-Safe, FP16, No DeepSpeed)
from transformers import TrainingArguments
import torch

print("Setting up training arguments...")


defaults = {
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 1,
    "learning_rate": 5e-5,
    "warmup_steps": 0,
    "max_grad_norm": 1.0,
    "dataloader_num_workers": 0,
    "num_train_epochs": 3,
    "logging_steps": 10,
    "eval_steps": 50,
    "save_steps": 50,
    "output_dir": "./output"
}


for key, value in defaults.items():
    if not hasattr(config, key):
        setattr(config, key, value)
        print(f"Set default {key}: {value}")

training_args = TrainingArguments(
    output_dir=config.output_dir,
    num_train_epochs=config.num_train_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    learning_rate=config.learning_rate,
    warmup_steps=config.warmup_steps,
    max_grad_norm=config.max_grad_norm,
    weight_decay=0.01,
    logging_dir=f"{config.output_dir}/logs",
    logging_steps=config.logging_steps,
    eval_strategy="steps",
    eval_steps=config.eval_steps,
    save_strategy="steps",
    save_steps=config.save_steps,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    dataloader_num_workers=config.dataloader_num_workers,
    remove_unused_columns=False,
    fp16=True,   
    bf16=False,
    seed=42,
    data_seed=42,
    report_to=None
)

print("✓ TrainingArguments created successfully (Single-GPU, FP16)")


Setting up training arguments...
✓ TrainingArguments created successfully (Single-GPU, FP16)


In [None]:

"""
Create proper dataset and data collator that handles LLaVA's <image> token requirements.
"""
import torch
from torch.utils.data import Dataset
from PIL import Image
import requests
from io import BytesIO

class LLaVADataset(Dataset):
    """Custom dataset that properly formats inputs for LLaVA training."""
    
    def __init__(self, hf_dataset, processor, image_token="<image>", max_length=512):
        self.dataset = hf_dataset
        self.processor = processor
        self.image_token = image_token
        self.max_length = max_length
        

        if self.image_token not in self.processor.tokenizer.vocab:
            self.processor.tokenizer.add_tokens([self.image_token])
            print(f"Added {self.image_token} token to tokenizer")
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        

        if 'image' in item:
            image = item['image']
            if isinstance(image, str):

                response = requests.get(image)
                image = Image.open(BytesIO(response.content)).convert('RGB')
            elif not isinstance(image, Image.Image):
                image = Image.fromarray(image).convert('RGB')
        else:

            image = Image.new('RGB', (224, 224), color='white')
        

        if 'caption' in item:
            caption = item['caption']
        elif 'text' in item:
            caption = item['text']
        else:
            caption = "Describe this image."
        

        conversation_text = f"Human: {self.image_token} {caption}\nAssistant: This is an image showing {caption}"
        

        image_inputs = self.processor.image_processor(image, return_tensors="pt")
        

        text_inputs = self.processor.tokenizer(
            conversation_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        )

        labels = text_inputs["input_ids"].clone()
        

        assistant_start = conversation_text.find("Assistant:")
        if assistant_start != -1:

            assistant_tokens = self.processor.tokenizer(
                conversation_text[:assistant_start], 
                return_tensors="pt"
            )["input_ids"]
            labels[:, :assistant_tokens.shape[1]] = -100
        
        return {
            "pixel_values": image_inputs["pixel_values"].squeeze(0),
            "input_ids": text_inputs["input_ids"].squeeze(0),
            "attention_mask": text_inputs["attention_mask"].squeeze(0),
            "labels": labels.squeeze(0)
        }

class LLaVADataCollator:
    """Data collator that properly handles LLaVA multimodal inputs."""
    
    def __init__(self, tokenizer, pad_token_id=None):
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id or tokenizer.pad_token_id or tokenizer.eos_token_id
    
    def __call__(self, batch):

        pixel_values = torch.stack([item["pixel_values"] for item in batch])
        

        max_length = max(item["input_ids"].shape[0] for item in batch)
        
        # Pad input_ids, attention_mask, and labels
        input_ids = []
        attention_mask = []
        labels = []
        
        for item in batch:
            seq_len = item["input_ids"].shape[0]
            pad_length = max_length - seq_len
            

            padded_input_ids = torch.cat([
                item["input_ids"], 
                torch.full((pad_length,), self.pad_token_id, dtype=torch.long)
            ])
            input_ids.append(padded_input_ids)
            

            padded_attention = torch.cat([
                item["attention_mask"],
                torch.zeros(pad_length, dtype=torch.long)
            ])
            attention_mask.append(padded_attention)
            

            padded_labels = torch.cat([
                item["labels"],
                torch.full((pad_length,), -100, dtype=torch.long)
            ])
            labels.append(padded_labels)
        
        return {
            "pixel_values": pixel_values,
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_mask),
            "labels": torch.stack(labels)
        }


print("Creating fixed LLaVA dataset...")


from datasets import load_dataset
dataset = load_dataset("HuggingFaceM4/COCO", split="train", streaming=False)
small_dataset = dataset.select(range(min(100, len(dataset))))  

train_dataset = LLaVADataset(
    hf_dataset=small_dataset,
    processor=processor,
    image_token="<image>",
    max_length=512
)

data_collator = LLaVADataCollator(
    tokenizer=processor.tokenizer,
    pad_token_id=processor.tokenizer.pad_token_id
)

print(f"Dataset created with {len(train_dataset)} samples")


print("Testing dataset...")
try:
    sample = train_dataset[0]
    print("Sample keys:", sample.keys())
    print("Input IDs shape:", sample["input_ids"].shape)
    print("Pixel values shape:", sample["pixel_values"].shape)
    print("Labels shape:", sample["labels"].shape)
    
  
    decoded_text = processor.tokenizer.decode(sample["input_ids"], skip_special_tokens=False)
    has_image_token = "<image>" in decoded_text
    print(f"Text contains <image> token: {has_image_token}")
    print(f"Sample text: {decoded_text[:200]}...")
    
    
    batch = data_collator([sample])
    print("Batch keys:", batch.keys())
    print("Batch pixel_values shape:", batch["pixel_values"].shape)
    print("Batch input_ids shape:", batch["input_ids"].shape)
    
    if has_image_token:
        print("✓ Dataset properly formatted for LLaVA!")
    else:
        print("⚠ Warning: <image> token not found in text")
        
except Exception as e:
    print(f"Dataset test failed: {e}")
    raise

print("Dataset and data collator ready!")

Creating fixed LLaVA dataset...


README.md: 0.00B [00:00, ?B/s]

COCO.py: 0.00B [00:00, ?B/s]

The repository for HuggingFaceM4/COCO contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/HuggingFaceM4/COCO.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


Downloading data:   0%|          | 0.00/36.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.5G [00:00<?, ?B/s]

FSTimeoutError: 

In [None]:

"""
Custom collator for image + text inputs.
Handles batching of pixel_values and tokenized inputs for LoRA fine-tuning.
"""
from torch.utils.data import DataLoader
from torch import nn
import torch

class LLaVADataCollator:
    def __init__(self, tokenizer, image_key="pixel_values", text_keys=["input_ids", "attention_mask", "labels"]):
        self.tokenizer = tokenizer
        self.image_key = image_key
        self.text_keys = text_keys

    def __call__(self, batch):
        collated = {}


        collated[self.image_key] = torch.stack([b[self.image_key] for b in batch])


        for key in self.text_keys:
            collated[key] = nn.utils.rnn.pad_sequence(
                [b[key].squeeze(0) for b in batch],
                batch_first=True,
                padding_value=self.tokenizer.pad_token_id
            )

        return collated

# Create instance
data_collator = LLaVADataCollator(tokenizer=processor.tokenizer)
print("✅ Data collator ready!")


✅ Data collator ready!


In [None]:

"""
Initialize standard Hugging Face Trainer for single-GPU LLaVA fine-tuning.
This avoids DeepSpeed and custom trainer classes that can cause conflicts.
"""
print("Initializing standard Hugging Face trainer...")
from torch.utils.data import random_split
from transformers import Trainer, EarlyStoppingCallback
import os


min_val_size = 2
val_size = max(min_val_size, len(train_dataset) // 5)
train_size = len(train_dataset) - val_size


if val_size >= min_val_size and len(train_dataset) > 10:
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
    use_evaluation = True
    print(f"Created split - Training: {len(train_dataset)}, Validation: {len(val_dataset)}")
else:
    val_dataset = None
    use_evaluation = False
    print(f"Dataset too small for validation. Training on {len(train_dataset)} samples only.")


from transformers import TrainingArguments

training_args_clean = TrainingArguments(
    output_dir=getattr(config, 'output_dir', './output'),
    num_train_epochs=getattr(config, 'num_train_epochs', 3),
    per_device_train_batch_size=getattr(config, 'per_device_train_batch_size', 1),
    per_device_eval_batch_size=getattr(config, 'per_device_eval_batch_size', 1),
    gradient_accumulation_steps=getattr(config, 'gradient_accumulation_steps', 4),
    learning_rate=getattr(config, 'learning_rate', 2e-4),
    warmup_steps=getattr(config, 'warmup_steps', 100),
    max_grad_norm=getattr(config, 'max_grad_norm', 1.0),
    weight_decay=0.01,
    logging_dir=f"{getattr(config, 'output_dir', './output')}/logs",
    logging_steps=getattr(config, 'logging_steps', 10),
    save_strategy="steps",
    save_steps=getattr(config, 'save_steps', 100),
    eval_strategy="steps" if use_evaluation else "no",
    eval_steps=getattr(config, 'eval_steps', 100) if use_evaluation else None,
    load_best_model_at_end=use_evaluation,
    metric_for_best_model="eval_loss" if use_evaluation else None,
    greater_is_better=False,
    dataloader_num_workers=0,  # Avoid multiprocessing issues
    remove_unused_columns=False,  # Keep all columns for multimodal data
    fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    seed=42,
    data_seed=42,
    report_to=None,  
)

print("Clean training arguments created without DeepSpeed")
print(f"Batch size: {training_args_clean.per_device_train_batch_size}")
print(f"Gradient accumulation: {training_args_clean.gradient_accumulation_steps}")
print(f"Effective batch size: {training_args_clean.per_device_train_batch_size * training_args_clean.gradient_accumulation_steps}")
print(f"Learning rate: {training_args_clean.learning_rate}")
print(f"Evaluation enabled: {use_evaluation}")
# --- Create Custom Trainer Class for LLaVA ---
class LLaVATrainerCustom(Trainer):
    """Custom trainer for LLaVA that properly handles multimodal inputs."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Custom loss computation for LLaVA multimodal training."""
        labels = inputs.get("labels")

        outputs = model(**inputs)

        if labels is not None:
            loss = outputs.loss
        else:
            logits = outputs.logits
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs["input_ids"][..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return (loss, outputs) if return_outputs else loss


trainer = LLaVATrainerCustom(
    model=model,
    args=training_args_clean,
    train_dataset=train_dataset,
    eval_dataset=val_dataset if use_evaluation else None,
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
    callbacks=[
        EarlyStoppingCallback(
            early_stopping_patience=3, 
            early_stopping_threshold=0.01
        )
    ] if use_evaluation else []
)

print("Standard Hugging Face Trainer initialized successfully!")
print(f"Model type: {type(model).__name__}")
print(f"Trainer type: {type(trainer).__name__}")

checkpoint_dir = None
output_dir = training_args_clean.output_dir
if os.path.exists(output_dir):
    checkpoints = [
        os.path.join(output_dir, d) for d in os.listdir(output_dir)
        if os.path.isdir(os.path.join(output_dir, d)) and d.startswith("checkpoint")
    ]
    if checkpoints:
        checkpoint_dir = max(checkpoints, key=os.path.getmtime)
        print(f"Found checkpoint: {checkpoint_dir}")
    else:
        print("No existing checkpoints found.")

if len(train_dataset) == 0:
    raise ValueError("Train dataset is empty!")

print(f"\nTraining setup verified:")
print(f"- Training samples: {len(train_dataset)}")
print(f"- Validation samples: {len(val_dataset) if val_dataset else 0}")
print(f"- Output directory: {output_dir}")

try:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        device_count = torch.cuda.device_count()
        current_device = torch.cuda.current_device()
        total_memory = torch.cuda.get_device_properties(current_device).total_memory / 1e9
        allocated_memory = torch.cuda.memory_allocated(current_device) / 1e9
        
        print(f"\nGPU Status:")
        print(f"- Available devices: {device_count}")
        print(f"- Current device: {current_device}")
        print(f"- Total memory: {total_memory:.1f}GB")
        print(f"- Allocated memory: {allocated_memory:.1f}GB")
        print(f"- Free memory: {total_memory - allocated_memory:.1f}GB")
    
    print("\n" + "="*50)
    print("STARTING TRAINING")
    print("="*50)
    
    train_result = trainer.train(resume_from_checkpoint=checkpoint_dir)
    
    metrics = train_result.metrics
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    
    print("\n" + "="*50)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("="*50)
    
    print(f"Final training loss: {metrics.get('train_loss', 'N/A'):.4f}")
    print(f"Training runtime: {metrics.get('train_runtime', 'N/A'):.2f} seconds")
    print(f"Training samples per second: {metrics.get('train_samples_per_second', 'N/A'):.2f}")
    
    if use_evaluation and 'eval_loss' in metrics:
        print(f"Final validation loss: {metrics.get('eval_loss', 'N/A'):.4f}")
    
    trainer.save_model()
    processor.save_pretrained(output_dir) 
    
    print(f"\nModel and processor saved to: {output_dir}")
    
except RuntimeError as e:
    error_str = str(e)
    print(f"\nRuntimeError occurred: {error_str}")
    
    if "CUDA out of memory" in error_str:
        print("\nCUDA OUT OF MEMORY SOLUTIONS:")
        print("1. Reduce per_device_train_batch_size (current: {})".format(training_args_clean.per_device_train_batch_size))
        print("2. Increase gradient_accumulation_steps (current: {})".format(training_args_clean.gradient_accumulation_steps))
        print("3. Enable gradient checkpointing in model")
        print("4. Use a smaller model or reduce max_length")
        
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1e9
            reserved = torch.cuda.memory_reserved(0) / 1e9
            print(f"\nGPU Memory at error:")
            print(f"- Allocated: {allocated:.1f}GB")
            print(f"- Reserved: {reserved:.1f}GB")
    raise
    
except Exception as e:
    print(f"\nTraining failed: {type(e).__name__}: {e}")
    raise

print("\nTraining process completed!")

Initializing standard Hugging Face trainer...
Created split - Training: 59603, Validation: 14900
Clean training arguments created without DeepSpeed
Batch size: 2
Gradient accumulation: 8
Effective batch size: 16
Learning rate: 0.0002
Evaluation enabled: True
Standard Hugging Face Trainer initialized successfully!
Model type: LlavaForConditionalGeneration
Trainer type: LLaVATrainerCustom
No existing checkpoints found.

Training setup verified:
- Training samples: 59603
- Validation samples: 14900
- Output directory: ./llava-coco-lora

GPU Status:
- Available devices: 1
- Current device: 0
- Total memory: 85.1GB
- Allocated memory: 28.3GB
- Free memory: 56.8GB

STARTING TRAINING
Using placeholder image for sample 23355 (URL: http://images.cocodataset.org/train2014/COCO_train2014_000000574217.jpg)
Using placeholder image for sample 68015 (URL: http://images.cocodataset.org/train2014/COCO_train2014_000000539296.jpg)
Using placeholder image for sample 9866 (URL: http://images.cocodataset.or

ValueError: Image features and image tokens do not match: tokens: 0, features 4718592

In [None]:
"""
Save the fine-tuned LLaVA model, LoRA adapters, and processor.
Ensures restart-safe reloads without DeepSpeed.
"""
print("Saving fine-tuned model...")

os.makedirs(config.output_dir, exist_ok=True)

trainer.save_model(config.output_dir)

processor.save_pretrained(config.output_dir)

adapter_dir = os.path.join(config.output_dir, "lora_adapter")
os.makedirs(adapter_dir, exist_ok=True)
model.save_pretrained(adapter_dir)

print(f"✅ Model and LoRA adapter saved under {config.output_dir}")


In [None]:
"""
Evaluate the fine-tuned model and test on a sample image.
Safe token slicing and device handling.
"""
print("Running evaluation...")

if val_dataset is not None and len(val_dataset) > 0:
    eval_results = trainer.evaluate()
    print("Evaluation Results:")
    for key, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f"{key}: {value:.4f}")
else:
    print("No validation dataset provided, skipping evaluation.")

if val_dataset is not None and len(val_dataset) > 0:
    print("\nTesting model on a sample image...")
    sample_idx = 0
    sample = val_dataset[sample_idx]

    with torch.no_grad():
        pixel_values = sample["pixel_values"].unsqueeze(0).to(model.device)
        input_text = f"{config.image_token}\nDescribe this image."

        inputs = processor(
            text=input_text,
            images=None,  
            return_tensors="pt",
            padding=True
        )

        inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
        inputs["pixel_values"] = pixel_values

       
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id
        )

       
        generated_text = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
        input_prompt = processor.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
        generated_response = generated_text[len(input_prompt):].strip()

        print(f"Generated description: {generated_response}")
else:
    print("No validation samples available for testing.")


In [None]:


if torch.cuda.is_available():
    torch.cuda.empty_cache()


print("="*60)
print("FINE-TUNING SUMMARY")
print("="*60)
print(f"✓ Model: {config.model_name}")
print(f"✓ Dataset: {getattr(config, 'dataset_name', 'custom')} ({getattr(config, 'max_samples', len(train_dataset))} samples)")
print(f"✓ Training method: LoRA (no DeepSpeed)")
if 'trainable_params' in globals() and 'total_params' in globals():
    print(f"✓ Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2*_
