In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    LlavaForConditionalGeneration,
    LlavaProcessor,
    TrainingArguments,
    Trainer
)
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

class CustomLlavaDataset(Dataset):
    def __init__(self, image_paths, questions, answers, processor):
        self.image_paths = image_paths
        self.questions = questions
        self.answers = answers
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')

        # Format the conversation
        text = f"USER: {self.questions[idx]}\nASSISTANT: {self.answers[idx]}"

        # Process image and text separately
        vision_x = self.processor.image_processor(image, return_tensors="pt")
        language_x = self.processor.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True
        )

        # Combine into single dict and remove batch dimension
        inputs = {
            "pixel_values": vision_x.pixel_values.squeeze(0),
            "input_ids": language_x.input_ids.squeeze(0),
            "attention_mask": language_x.attention_mask.squeeze(0),
        }

        return inputs

def train_llava(image_paths, questions, answers, output_dir="./llava_finetuned", num_epochs=3):
    # Initialize model and processor
    model_id = "llava-hf/llava-1.5-7b-hf"
    processor = LlavaProcessor.from_pretrained(model_id)

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load model
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True
    ).to(device)

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare dataset
    dataset = CustomLlavaDataset(image_paths, questions, answers, processor)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        logging_steps=10,
        save_strategy="epoch",
        fp16=device=="cuda",
        optim="adamw_torch",
        gradient_checkpointing=True,
        no_cuda=device=="cpu"
    )

    # Custom data collator
    def collate_fn(batch):
        collated = {
            'pixel_values': torch.stack([x['pixel_values'] for x in batch]).to(device),
            'input_ids': torch.stack([x['input_ids'] for x in batch]).to(device),
            'attention_mask': torch.stack([x['attention_mask'] for x in batch]).to(device),
        }
        return collated

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collate_fn
    )

    # Train
    trainer.train()
    trainer.save_model()

    return model, processor

def generate_response(model, processor, image_path, question, max_length=128):
    device = next(model.parameters()).device
    image = Image.open(image_path).convert('RGB')
    prompt = f"USER: {question}\nASSISTANT:"

    # Process image and text separately
    vision_x = processor.image_processor(image, return_tensors="pt")
    language_x = processor.tokenizer(
        prompt,
        return_tensors="pt",
        padding=True
    )

    # Combine and move to device
    inputs = {
        "pixel_values": vision_x.pixel_values.to(device),
        "input_ids": language_x.input_ids.to(device),
        "attention_mask": language_x.attention_mask.to(device),
    }

    outputs = model.generate(
        **inputs,
        max_length=max_length,
        num_beams=4,
        temperature=0.8,
        do_sample=True
    )

    response = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("ASSISTANT: ")[-1]


In [2]:
# Example usage
if __name__ == "__main__":
    # Sample data (replace with your own.
    image_paths = ["./sample_img/red_car.png", "./sample_img/palm_beach.png"]
    questions = ["What is in this image?", "Describe this scene."]
    answers = ["A red car parked on the street.", "A sunny beach with palm trees."]

    # Install required packages in Colab
    """
    !pip install -q transformers accelerate safetensors sentencepiece
    !pip install -q git+https://github.com/huggingface/transformers.git
    """

    # Train
    model, processor = train_llava(image_paths, questions, answers)

    # Test
    test_image = "./sample_img/red_car.png"
    test_question = "What do you see in this image?"
    response = generate_response(model, processor, test_image, test_question)
    print(f"Generated response: {response}")

preprocessor_config.json:   0%|          | 0.00/505 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.62M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

Using device: cuda


config.json:   0%|          | 0.00/950 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/70.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.18G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned