In [None]:
# Install required libraries
!pip install transformers accelerate datasets torch torchvision pillow

# Import necessary libraries
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TrainingArguments, Trainer
from datasets import load_dataset
from PIL import Image
import torch

# Checkpoint identifier shared by the model weights and the processor
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

# Load the pretrained model; dtype and device placement are both left on "auto"
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)

# Load the matching processor (used below for chat templating, text and images)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Load the dataset; the "train", "validation" and "test" splits are used below
ds = load_dataset("mdwiratathya/ROCO-radiology")

# Debugging function to inspect dataset integrity
def check_data_integrity(dataset):
    """Check that every example holds a PIL image and a string caption.

    Iterates over ``dataset`` and stops at the first entry whose ``"image"``
    value is not a ``PIL.Image.Image`` or whose ``"caption"`` value is not a
    ``str``. The outcome is also reported on stdout.

    Args:
        dataset: Iterable of examples indexable by the keys ``"image"`` and
            ``"caption"``.

    Returns:
        bool: ``True`` if every entry is valid (an empty dataset counts as
        valid), ``False`` as soon as one invalid entry is found.

    Raises:
        KeyError: If an example lacks the ``"image"`` or ``"caption"`` key
            (assuming dict-like examples).
    """
    for idx, example in enumerate(dataset):
        # Explicit checks rather than ``assert``: assertions are stripped when
        # Python runs with -O, which would silently disable this validation.
        if not isinstance(example["image"], Image.Image):
            problem = f"Invalid image format at index {idx}"
        elif not isinstance(example["caption"], str):
            problem = f"Invalid caption format at index {idx}"
        else:
            continue
        print(f"Dataset integrity check failed: {problem}")
        return False
    print("All dataset entries are valid.")
    return True

# Abort early when the training split fails the integrity check
train_split_ok = check_data_integrity(ds["train"])
if not train_split_ok:
    raise ValueError("Dataset contains invalid entries.")

# Preprocessing function with robust error handling
def preprocess_data(example):
    """Turn one image/caption example into fixed-length training features.

    The image is resized to 128x128, the caption is wrapped in a single-turn
    chat prompt, and both are passed through the module-level ``processor``.

    Args:
        example: Mapping with an ``"image"`` entry (an object exposing
            ``resize``, e.g. a PIL image) and a ``"caption"`` string.

    Returns:
        dict: PyTorch tensors under the keys ``"input_ids"``,
        ``"attention_mask"``, ``"pixel_values"``, ``"image_grid_thw"`` and
        ``"labels"``. The token-level entries are padded/truncated to 512
        tokens; ``"labels"`` is a copy of ``"input_ids"`` with padding
        positions set to -100.

    Raises:
        Exception: Any error raised while processing is printed together with
            the offending example and then re-raised unchanged.
    """
    try:
        # Resize image
        image = example["image"].resize((128, 128))
        caption = example["caption"]

        # Prepare text prompt
        # NOTE(review): the caption is placed in the *user* turn and a
        # generation prompt is appended, so the sequence contains no assistant
        # answer to learn from. Confirm this is the intended training target.
        text_prompt = processor.apply_chat_template(
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": caption},
                    ],
                }
            ],
            add_generation_prompt=True,
        )

        # Tokenize and prepare inputs.
        # padding="max_length" gives every example the same sequence length:
        # with a single-item batch, padding=True pads nothing, leaving ragged
        # sequences that cannot be stacked into a batch without a custom
        # padding collator.
        inputs = processor(
            text=[text_prompt],
            images=[image],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )

        # Drop the batch dimension explicitly; a bare squeeze() would also
        # collapse any other dimension that happens to have size 1.
        input_ids = inputs["input_ids"][0]
        attention_mask = inputs["attention_mask"][0]

        # Labels are a copy (not an alias) of the inputs, with padding masked
        # by -100, the index the cross-entropy loss ignores.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": inputs["pixel_values"].squeeze(),
            # The model needs the patch-grid layout to interpret pixel_values;
            # omitting it makes the vision tower fail in the forward pass.
            "image_grid_thw": inputs["image_grid_thw"][0],
            "labels": labels,
        }
    except Exception as e:
        print(f"Error processing example: {example}, {e}")
        raise

# Run the preprocessing over both splits, one example at a time (no batching,
# no worker processes); any failure is reported and then propagated.
try:
    print("Preprocessing train dataset...")
    raw_train_split = ds["train"]
    train_data = raw_train_split.map(preprocess_data, batched=False)

    print("Preprocessing validation dataset...")
    raw_val_split = ds["validation"]
    val_data = raw_val_split.map(preprocess_data, batched=False)

    print("Dataset preprocessing completed.")
except Exception as preprocessing_error:
    print(f"Error during dataset preprocessing: {preprocessing_error}")
    raise

# Training arguments
# Checkpoints and evaluation both run once per epoch; the two strategies must
# match because load_best_model_at_end=True restores the best checkpoint.
# With a per-device batch of 4 and 4 accumulation steps, each optimizer step
# covers 16 examples per device.
# NOTE(review): the model is loaded with torch_dtype="auto"; confirm that
# fp16 mixed precision is compatible with the dtype the checkpoint loads in.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=4,
    # `evaluation_strategy` is the deprecated spelling of this option;
    # `eval_strategy` is the current keyword.
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=True,
)

# Assemble the Trainer from the model, its arguments and the processed splits
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_data,
    "eval_dataset": val_data,
    # The processor is handed over through the `tokenizer` argument
    "tokenizer": processor,
}
trainer = Trainer(**trainer_kwargs)

# Run fine-tuning; a failure is reported on stdout and then propagated
print("Training model...")
try:
    trainer.train()
except Exception as training_error:
    print(f"Error during training: {training_error}")
    raise
else:
    print("Model training completed.")

# Persist the fine-tuned model and its processor side by side
fine_tuned_dir = "./fine_tuned_qwen2vl"
for artifact in (model, processor):
    artifact.save_pretrained(fine_tuned_dir)

# Test the fine-tuned model on the first example of the test split,
# resized the same way as the training images
test_example = ds["test"][0]
test_image = test_example["image"].resize((128, 128))
test_caption = test_example["caption"]

# Prepare inputs for testing: a single user turn with the image and the caption
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": test_caption},
        ],
    }
]
text_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
# Move the inputs to wherever the model actually lives instead of a
# hard-coded "cuda": the model was placed with device_map="auto", so a fixed
# device string fails on CPU-only machines.
inputs = processor(
    text=[text_prompt],
    images=[test_image],
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
).to(model.device)

# Generate output
# Switch to inference mode in case training left the model in train mode.
model.eval()
output_ids = model.generate(**inputs, max_new_tokens=1024)

# generate() returns the prompt tokens followed by the new tokens; slice off
# the prompt so that only the model's continuation is decoded.
prompt_length = inputs["input_ids"].shape[1]
generated_ids = output_ids[:, prompt_length:]
output_text = processor.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)

# Print generated response
print("Generated Response:", output_text[0])