In [None]:
import os
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm
import textwrap

import matplotlib.pyplot as plt
#from typing import List, Dict, Any, Optional
import warnings
warnings.filterwarnings('ignore')

# Environment sanity check: report what PyTorch can see before any model is loaded.
print(torch.cuda.is_available())   # Should be True
print(torch.cuda.device_count())   # Should be >= 1
if torch.cuda.is_available():
    # get_device_name(0) raises when no CUDA device is present, so it is only
    # queried on GPU hosts; the notebook itself supports a CPU fallback.
    print(torch.cuda.get_device_name(0))  # Should print your GPU's name
else:
    print("No CUDA device detected; falling back to CPU.")
print(torch.version.cuda)   # Should show CUDA version
print(torch.backends.cudnn.version())  # Should print a version number

# Fix SSL certificate issues
# SECURITY NOTE(review): this replaces ssl's default HTTPS context factory with the
# unverified one, so code relying on that default skips certificate verification
# for the rest of the kernel session. Prefer repairing the local CA bundle and
# removing this override before sharing or deploying the notebook.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# Hugging Face imports
from transformers import (
    AutoProcessor, 
    AutoModelForVisualQuestionAnswering,
    Trainer,
    TrainingArguments,
    VisionEncoderDecoderModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel
)
import open_clip
from open_clip import create_model_from_pretrained, get_tokenizer
from datasets import Dataset, load_dataset, DatasetDict
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
import evaluate

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds for reproducibility
# The training/validation subsets are drawn with Python's `random.sample`, so the
# stdlib generator must be seeded as well as torch and numpy.
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

# Load METEOR metric
# NOTE(review): `meteor` is not referenced again in the visible cells — presumably
# it is used by an evaluation step elsewhere; confirm or drop.
meteor = evaluate.load('meteor')

### Fine-Tune BLIP Model

In [None]:
# Pretrained BLIP VQA checkpoint: the processor handles image preprocessing and
# exposes the tokenizer; the model is loaded in half precision.
BLIP_CHECKPOINT = "Salesforce/blip-vqa-base"

processor = AutoProcessor.from_pretrained(BLIP_CHECKPOINT)
model = AutoModelForVisualQuestionAnswering.from_pretrained(
    BLIP_CHECKPOINT,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Count the parameters that are trainable before any adapter is attached.
trainable_params = 0
for param in model.parameters():
    if param.requires_grad:
        trainable_params += param.numel()
print(f"Number of trainable parameters: {trainable_params:,}")

In [None]:
# Pull the Path-VQA dataset from the Hugging Face Hub and report the split sizes.
dataset = load_dataset("flaviagiammarino/path-vqa")

for split_label, split_key in (("Train", "train"), ("Validation", "validation"), ("Test", "test")):
    print(f"{split_label}: {len(dataset[split_key])} samples")

# Keep a handle on each split for the cells that follow.
train_dataset, val_dataset, test_dataset = (
    dataset[key] for key in ("train", "validation", "test")
)

In [None]:
from torch.utils.data import Dataset

class BlipDataset(Dataset):
    """Map-style dataset pairing a preprocessed image with its raw question/answer text.

    Args:
        dataset: Indexable collection of records; each record is read through the
            "image", "question" and "answer" keys.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result exposes a ``"pixel_values"`` entry.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Number of records in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the image tensor plus the untokenized question and answer for ``idx``."""
        item = self.dataset[idx]  # always int index!
        img = item["image"].convert('RGB')
        # Only the image goes through the processor here; question and answer are
        # returned as plain strings, so text padding/truncation options do not apply.
        encoding = self.processor(images=img, return_tensors="pt")
        return {
            # Assumes pixel_values comes back with a leading batch axis of size 1.
            # Only that axis is dropped: a bare squeeze() would also collapse any
            # other size-1 axis and change the tensor rank.
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "question": item["question"],
            "answer": item["answer"]
        }

In [None]:
# Draw a random subset of each split. The request is capped at the split size
# because random.sample raises ValueError when asked for more items than exist.
train_indices = random.sample(range(len(train_dataset)), min(8000, len(train_dataset)))
val_indices = random.sample(range(len(val_dataset)), min(800, len(val_dataset)))

train_list = [train_dataset[i] for i in train_indices]
val_list   = [val_dataset[i] for i in val_indices]

train_ds = BlipDataset(train_list, processor)
val_ds   = BlipDataset(val_list, processor)

# max_q / max_a are consumed as the tokenizer `max_length` by the collate function,
# so they are measured in tokens (special tokens included). Counting whitespace-
# separated words under-estimates the tokenized length and silently truncates text.
question_token_ids = processor.tokenizer([item["question"] for item in train_list])["input_ids"]
answer_token_ids = processor.tokenizer([item["answer"] for item in train_list])["input_ids"]
max_q = max(len(ids) for ids in question_token_ids)
max_a = max(len(ids) for ids in answer_token_ids)

print(f"Maximum token length of training question: {max_q}")
print(f"Maximum token length of training answer: {max_a}")

In [None]:
# for name, module in model.named_modules():
#     print(name)

In [None]:
# Module-name patterns that receive LoRA adapters.
# NOTE(review): the grouping labels below are the author's; confirm which
# sub-models each pattern actually matches by listing model.named_modules().
BLIP_LORA_TARGETS = [
    # vision encoder, fused QKV + output
    "self_attn.qkv",
    "self_attn.projection",
    # text self-attention: query / key / value + output projection
    "attention.self.query",
    "attention.self.key",
    "attention.self.value",
    "attention.output.dense",
    # text cross-attention: query / key / value + output projection
    "crossattention.self.query",
    "crossattention.self.key",
    "crossattention.self.value",
    "crossattention.output.dense",
]

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type=None,
    target_modules=BLIP_LORA_TARGETS,
)

# Wrap the base model with the adapters, move it to the training device and
# print the trainable-parameter summary. Rebinding `model` means re-running this
# cell wraps the already-wrapped model again.
model = get_peft_model(model, lora_config)
model.to(device)
model.print_trainable_parameters()

In [None]:
def collate_fn(batch):
    """Assemble a BLIP training batch from a list of dataset samples.

    Each sample supplies "pixel_values", "question" and "answer". Questions become
    the model inputs and answers become the labels; both are tokenized with the
    processor's tokenizer and padded/truncated to ``max(max_q, max_a)``.

    Returns:
        Dict with "pixel_values", "input_ids", "attention_mask" and "labels".
    """
    images, questions, answers = [], [], []
    for sample in batch:
        images.append(sample["pixel_values"])
        questions.append(sample["question"])
        answers.append(sample["answer"])

    # Images first, so a shape mismatch surfaces before any tokenization work.
    pixel_values = torch.stack(images)

    def _tokenize(texts):
        # Fixed-width padding: every batch gets the same sequence length.
        return processor.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=max(max_q, max_a),
        )

    question_enc = _tokenize(questions)
    answer_enc = _tokenize(answers)

    return {
        "pixel_values": pixel_values,
        "input_ids": question_enc["input_ids"],
        "attention_mask": question_enc["attention_mask"],
        "labels": answer_enc["input_ids"],
    }

In [None]:
training_args = TrainingArguments(
    # Forward slashes: the previous ".\model" / ".\logs" literals contained the
    # invalid escape sequences "\m" and "\l" and only denote a sub-directory on
    # Windows. "./model" and "./logs" work on every platform.
    output_dir="./model",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    learning_rate=5e-5,
    warmup_steps=50,
    dataloader_num_workers=0,
    logging_dir="./logs",
    logging_steps=100, 
    report_to=["tensorboard"],  
    # No eval_steps is given here — NOTE(review): presumably the evaluation
    # interval then follows logging_steps; confirm against the installed version.
    eval_strategy="steps",
    fp16=True,                   
    remove_unused_columns=False, 
    label_names= ["labels"]
)

class CustomTrainer(Trainer):
    """Trainer variant that never forwards ``num_items_in_batch`` to the parent loss."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Delegate to ``Trainer.compute_loss`` after discarding ``num_items_in_batch``.

        The keyword argument is accepted but not passed on, and an entry with the
        same name is removed from ``inputs`` (in place) when present.
        """
        inputs.pop("num_items_in_batch", None)
        return super().compute_loss(model, inputs, return_outputs=return_outputs)


# Wire the LoRA-wrapped BLIP model, the sampled splits and the batch collator together.
trainer = CustomTrainer(
    args=training_args,
    model=model,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)

In [None]:
# Run the BLIP fine-tuning loop; the call is the cell's last expression, so its
# return value is what the notebook displays.
print("Start training...")
trainer.train()

### Fine-Tune BiomedCLIP + Mistral 7B

In [None]:
# Load the dataset for the second experiment. This rebinds `dataset` and the three
# split variables that the BLIP section used earlier.
# NOTE(review): the repo id reads "path-va" — confirm it is the intended dataset.
dataset = load_dataset("YYYWei/path-va")
train_dataset, val_dataset, test_dataset = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

In [None]:
# Authenticate with the Hugging Face Hub. login() is called without arguments, so
# no token is hardcoded in the notebook.
# NOTE(review): presumably required for the Mistral checkpoint loaded below — confirm.
from huggingface_hub import login
login()

In [None]:
from transformers import BitsAndBytesConfig

# BiomedCLIP image encoder together with its train/val preprocessing transforms.
clip_model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224',
    device=device
)
# The encoder is only used for feature extraction under torch.no_grad(), so put it
# in eval mode: train-mode layers (dropout, stochastic depth, batch norm) would
# otherwise make the extracted features non-deterministic.
clip_model.eval()

# 4-bit quantization settings for the LLM. An earlier 8-bit config bound to the same
# name was overwritten before it was ever used, so it has been dropped.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",           # Best quality for 4bit
    bnb_4bit_compute_dtype=torch.bfloat16, # Use bfloat16 math for speed & stability
    bnb_4bit_use_double_quant=False
)

llm_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
#model.config.use_cache = False  # Disable cache for LLM

# LLM
llm_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    use_fast=True
)
# Make sure BOS/EOS exist, then reuse EOS as the padding token.
# NOTE(review): if either branch ever adds a token, the LLM embedding matrix is not
# resized here — confirm the checkpoint's tokenizer already defines both.
if llm_tokenizer.bos_token_id is None:
    llm_tokenizer.add_special_tokens({'bos_token': '<bos>'})
if llm_tokenizer.eos_token_id is None:
    llm_tokenizer.add_special_tokens({'eos_token': '<eos>'})
llm_tokenizer.pad_token = llm_tokenizer.eos_token

# Dummy image (black image): run it through the encoder once to discover the
# width of the image feature vector.
dummy = preprocess_val(Image.new("RGB", (224, 224))).unsqueeze(0).to(device)
with torch.no_grad():
    dummy_feat = clip_model.encode_image(dummy)  # shape: [1, D]

vision_feature_dim = dummy_feat.shape[-1]  # This is what we need

# Linear map from the image feature width to the LLM hidden size; created here as a
# free-standing module on `device`.
projection = nn.Linear(vision_feature_dim, llm_model.config.hidden_size).to(device)

In [None]:
from torch.utils.data import Dataset

class BiomedCLIPCaptionDataset(Dataset):
    """Map-style dataset yielding a projected image embedding and its caption text.

    Args:
        items: Indexable collection of records read through the "image" and
            "answer" keys; the answer string is used as the caption.
    """

    def __init__(self, items):
        self.items = items

    def __len__(self):
        """Number of wrapped records."""
        return len(self.items)

    def __getitem__(self, idx):
        """Encode record ``idx`` and return ``{"vision_prefix", "caption"}``."""
        record = self.items[idx]
        rgb_image = record["image"].convert("RGB")
        caption_text = record["answer"]

        # Preprocess, add a batch axis of one and move to the compute device.
        pixel_batch = preprocess_val(rgb_image).unsqueeze(0).to(device)

        # The CLIP encoder runs without gradient tracking ...
        with torch.no_grad():
            clip_features = clip_model.encode_image(pixel_batch)

        # ... whereas the projection is applied outside the no_grad block; the
        # result is then stripped of its batch axis and moved to the CPU.
        projected = projection(clip_features)
        return {"vision_prefix": projected.squeeze(0).cpu(), "caption": caption_text}

In [None]:
def collate_fn(batch):
    """Build an LLM batch laid out as [vision slot] + "Caption:" prompt + caption + EOS.

    Args:
        batch: List of samples, each with "vision_prefix" (tensor) and "caption" (str).

    Returns:
        Dict with "vision_prefix" (stacked prefixes), "input_ids", "attention_mask"
        and "labels". Labels are -100 on the leading slot, the prompt and every
        padding position, so the loss covers only caption tokens and the closing EOS.
    """
    from torch.nn.utils.rnn import pad_sequence

    vision_prefix = torch.stack([b["vision_prefix"] for b in batch])
    captions = [b["caption"] for b in batch]
    B = len(batch)

    # Tokenize the fixed textual prompt 'Caption:'
    prompt_tok = llm_tokenizer(
        ["Caption:"] * B,
        add_special_tokens=False,
        padding="longest",
        return_tensors="pt"
    )
    prompt_ids = prompt_tok.input_ids  # [B, P]
    prompt_mask = prompt_tok.attention_mask

    # Tokenize the ground-truth captions WITHOUT special tokens: the sequence already
    # has its single leading slot, so an automatically added BOS would land in the
    # middle of the input and become a training target. An explicit EOS is appended
    # instead so the model is trained to terminate the caption.
    cap_tok = llm_tokenizer(
        captions,
        add_special_tokens=False,
        truncation=True
    )
    eos_id = llm_tokenizer.eos_token_id
    cap_seqs = [
        torch.tensor(list(ids) + [eos_id], dtype=torch.long)
        for ids in cap_tok["input_ids"]
    ]
    # Pad on the right by hand so the layout does not depend on the tokenizer's
    # padding side, and keep an explicit mask: pad and EOS may share the same id,
    # so padding cannot be recognized from the token ids alone.
    cap_ids = pad_sequence(
        cap_seqs, batch_first=True, padding_value=llm_tokenizer.pad_token_id
    )  # [B, C]
    cap_mask = pad_sequence(
        [torch.ones_like(seq) for seq in cap_seqs], batch_first=True, padding_value=0
    )

    # Build input_ids: [bos slot] + prompt_ids + cap_ids
    bos_col = torch.full((B, 1), llm_tokenizer.bos_token_id, dtype=torch.long)
    input_ids = torch.cat([bos_col, prompt_ids, cap_ids], dim=1)

    # Labels: ignore the bos slot, the prompt and all padding; learn caption + EOS.
    labels = torch.cat([
        torch.full((B, 1), -100, dtype=torch.long),
        torch.full_like(prompt_ids, -100),
        cap_ids.masked_fill(cap_mask == 0, -100)
    ], dim=1)

    # Attention mask: 1 for bos slot + prompt + real caption tokens, 0 for padding
    attention_mask = torch.cat([
        torch.ones((B, 1), dtype=prompt_mask.dtype),
        prompt_mask,
        cap_mask.to(prompt_mask.dtype)
    ], dim=1)

    return {
        "vision_prefix": vision_prefix,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

In [None]:
# Module names that receive LoRA adapters (the names suggest the attention
# q/k/v/o projections plus the MLP gate projection — verify against the model).
MISTRAL_LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"]

# task_type is deliberately left unset here.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=MISTRAL_LORA_TARGETS,
)

# Attach the adapters to the quantized LLM and print the trainable-parameter summary.
peft_llm = get_peft_model(llm_model, lora_config)
peft_llm.print_trainable_parameters()

In [None]:
class VisualInstructMistral(nn.Module):
    """Causal-LM wrapper whose first input position carries an image embedding.

    The token embedding at sequence position 0 is overwritten with ``vision_prefix``
    before the language model is called.

    Args:
        peft_llm: Language model exposing ``get_input_embeddings()``, ``__call__``
            and ``generate``.
        vision_projection: Optional module that produces the vision prefixes
            upstream. When given, it is registered as a submodule so its
            parameters are part of ``self.parameters()`` (and therefore seen by
            an optimizer built from this model, and included in its state dict).
            It is NOT applied inside ``forward``; gradients reach it only if the
            incoming ``vision_prefix`` tensors still carry their autograd history.
            Defaults to ``None``, which leaves it unregistered.
    """

    def __init__(self, peft_llm, vision_projection=None):
        super().__init__()
        self.llm = peft_llm
        # Registering the projection here is what makes it trainable: a
        # free-standing module outside this wrapper is never handed to the optimizer.
        self.vision_projection = vision_projection

    def _embed_with_prefix(self, input_ids, vision_prefix):
        """Embed ``input_ids`` and write ``vision_prefix`` into sequence position 0."""
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        inputs_embeds[:, 0, :] = vision_prefix.to(device)
        return inputs_embeds

    def forward(self, input_ids=None, vision_prefix=None, attention_mask=None, labels=None):
        """Run the language model on prefix-injected embeddings.

        Returns whatever the wrapped model returns for
        ``(inputs_embeds, attention_mask, labels)``.
        """
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        inputs_embeds = self._embed_with_prefix(input_ids, vision_prefix)
        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels.to(device) if labels is not None else None
        )

    def generate(self, input_ids, attention_mask, vision_prefix, **gen_kwargs):
        """Generate from prefix-injected embeddings; extra kwargs go to ``llm.generate``."""
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        inputs_embeds = self._embed_with_prefix(input_ids, vision_prefix)
        return self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            bos_token_id=llm_tokenizer.bos_token_id,
            eos_token_id=llm_tokenizer.eos_token_id,
            pad_token_id=llm_tokenizer.pad_token_id,
            **gen_kwargs
        )


# Hand the projection layer to the wrapper so that it is trained and checkpointed
# together with the LoRA adapters instead of staying at its random initialization.
model = VisualInstructMistral(peft_llm, vision_projection=projection)

In [None]:
# Sample a random subset of each split. The request is capped at the split size
# because random.sample raises ValueError when asked for more items than exist.
train_indices = random.sample(range(len(train_dataset)), min(4000, len(train_dataset)))
val_indices   = random.sample(range(len(val_dataset)), min(800, len(val_dataset)))

# Materialize the selected records, then wrap them for the captioning model.
train_list = [train_dataset[i] for i in train_indices]
val_list   = [val_dataset[i] for i in val_indices]

train_ds = BiomedCLIPCaptionDataset(train_list)
val_ds   = BiomedCLIPCaptionDataset(val_list)

In [None]:
training_args = TrainingArguments(
    # Output locations and reporting
    output_dir="model",
    logging_dir="logs",
    report_to=["tensorboard"],
    # Optimization schedule
    num_train_epochs=3,
    learning_rate=1e-4,
    warmup_steps=50,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Log the loss and run evaluation every 50 steps
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=50,
    # Both mixed-precision flags are switched off
    fp16=False,
    bf16=False,
    # Data loading: single process, no pinned memory
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    # Presumably kept off so the collator's custom "vision_prefix" entry is not
    # filtered out before reaching the model — confirm.
    remove_unused_columns=False,
    label_names=["labels"],
)

# Trainer for the vision-prefixed LLM: sampled splits, the caption collator and
# the LLM tokenizer.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=llm_tokenizer,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)

In [None]:
# Launch fine-tuning of the vision-prefixed LLM; the return value is displayed as the cell output.
trainer.train()