In [1]:
import json
import torch
import logging
import datasets
import pandas as pd
from PIL import Image
import torch.nn as nn
from tqdm.auto import tqdm
from datasets import load_dataset

In [2]:
import numpy as np
from datasets import Dataset
from torch.utils.data import DataLoader

### Set up Logger

In [3]:
# Clear previous handlers to avoid duplicate logs in Jupyter
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

# Configure logging
logging.basicConfig(
    level=logging.INFO,  # Change to DEBUG for more verbosity
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]  # Ensures it logs to Jupyter cell output
)

logger = logging.getLogger(__name__)
logger.info("Logging is set up in the notebook!")

2025-07-08 04:55:30,137 - INFO - Logging is set up in the notebook!


### Load the MultiDomain Dataset

In [4]:
prefix = "Generate a one word or single number answer for the given image and question"

In [5]:
def prepend_prefix(example):
    example['question'] = prefix + ': ' + example['question']
    return example

In [6]:
dataset = load_dataset("dutta18/multi-domain-VQA-1.5K")

In [7]:
train_set, val_set = dataset['train'], dataset['validation']

In [8]:
train_set = train_set.map(prepend_prefix)
val_set = val_set.map(prepend_prefix)

### Importing Model

In [9]:
device = "cuda"

In [10]:
from torch.amp import autocast
from torch.nn.utils import clip_grad_norm_
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig

In [11]:
model_id = "google/paligemma-3b-pt-224"

### Intialize Quantisation Configs

In [12]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  
    bnb_4bit_use_double_quant=True,  # Use double quantization for memory savings
)

### Load Model

In [13]:
base_model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, 
    quantization_config = bnb_config, 
    attn_implementation = "flash_attention_2",
    torch_dtype = torch.float16, 
    device_map = 'auto'
)

processor = PaliGemmaProcessor.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [14]:
image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
bos_token = processor.tokenizer.convert_tokens_to_ids("<bos>")

### Intialize DORA Configs

In [15]:
dora_config = LoraConfig(
    r = 16,
    lora_alpha = 16*2,       # Scaling factor
    lora_dropout = 0.05,     # Dropout rate
    target_modules = ["q_proj", "k_proj", "v_proj",  "o_proj", "out_proj", "gate_proj", "up_proj", "down_proj"],
    use_dora = True  
)

In [16]:
base_model = prepare_model_for_kbit_training(base_model)
quantized_dora_base_model = get_peft_model(base_model, dora_config)

### Calculate Number of Params: ~ 24.4 M

In [17]:
def report_trainable_params():
    
    # Simple param report
    trainable = sum(p.numel() for p in quantized_dora_base_model.parameters() if p.requires_grad)
    print(f"Total trainable params: {trainable/1e6:.1f} M")

In [18]:
report_trainable_params()

Total trainable params: 24.4 M


### Setting Dataloaders

In [19]:
def collate_fn(examples):
    texts = [
        f"<image> <bos> answer {example['question']}" for example in examples
    ]  
    labels = [example['answer'] for example in examples]
    images = [example["image"].convert("RGB") for example in examples]

    tokens = processor(text=texts, images=images, suffix=labels, return_tensors="pt", padding="longest")
    tokens = {k: v.to(device) for k, v in tokens.items()}
    tokens["pixel_values"] = tokens["pixel_values"].to(torch.bfloat16)

    return tokens

In [20]:
batchSize_ = 4

In [21]:
train_loader = DataLoader(train_set, batch_size=batchSize_, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=batchSize_, shuffle=False, collate_fn=collate_fn)

### Validation Function

In [22]:
@torch.no_grad()
def do_validation():
    quantized_dora_base_model.eval()
    val_loss = 0.0
    
    for batch in tqdm(val_loader, desc="Validating"):
        
        with autocast(device_type='cuda', dtype=torch.bfloat16):
            outputs = quantized_dora_base_model(**batch)
            val_loss += outputs.loss.item()

    avg_val_loss = val_loss / len(val_loader)
    torch.cuda.empty_cache()
    quantized_dora_base_model.train()
    return avg_val_loss

### Training Hyperparams

In [23]:
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

In [32]:
epochs = 5
weight_decay = 0.001
learning_rate = 5e-4
gradient_accumulation_steps = 2

In [33]:
optimizer = torch.optim.AdamW(quantized_dora_base_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [34]:
total_train_steps = len(train_loader) // gradient_accumulation_steps * epochs
warmup_steps = int(0.05 * total_train_steps)

In [35]:
print(total_train_steps, warmup_steps)

935 46


In [36]:
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_train_steps
)

In [37]:
global_step = 0
best_val_loss = float("inf")
quantized_dora_base_model.config.use_cache = False

In [38]:
_ = quantized_dora_base_model.train()

In [39]:
saveDir = '/home/aritrad/main/PaliGemma-3B/MOE/Multidomain/chkpts/'

## Native PyTorch Training Loop

##### I am using val_loss as the checkpointing criteria, but any other metric which test text generation quality can be used here.

##### MAX GPU USAGE = 24 GB

In [None]:
for epoch in tqdm(range(epochs)):
    
    total_loss = 0.0

    for idx, batch in enumerate(train_loader):

        with autocast(device_type = 'cuda', dtype = torch.bfloat16):
            outputs = quantized_dora_base_model(**batch)
            loss = outputs.loss / gradient_accumulation_steps
        
        loss.backward()
        total_loss += loss.item()

        # Accumulate Grads and step optimizer and log train-loss.
        if (idx+1) % gradient_accumulation_steps == 0:
            clip_grad_norm_(quantized_dora_base_model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            logger.info(f"[Epoch {epoch+1} | Idx: {idx} | Optim Step: {global_step} | Loss: {loss.item():.4f}]")

            # Evaluation loop
            if global_step % 60 == 0:
                avg_val_loss = do_validation()
                logger.info(f"Validation Loss at: {idx+1} -> {avg_val_loss:.4f}\n")
    
                if avg_val_loss < best_val_loss:
                    quantized_dora_base_model.save_pretrained(os.path.join(saveDir, 'PaliGemma-MultiDomain-QDORA-chkpt-1500-16R.pt'))
                    logger.info(f"***** Checkpoint Saved *****\n")
                    best_val_loss = avg_val_loss
            
    logger.info(f"Epoch {epoch+1} completed. Avg loss: {total_loss / len(train_loader):.4f}\n\n")

  0%|          | 0/5 [00:00<?, ?it/s]

  return fn(*args, **kwargs)
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
2025-07-08 04:57:17,701 - INFO - [Epoch 1 | Idx: 1 | Optim Step: 1 | Loss: 2.4060]
2025-07-08 04:57:20,984 - INFO - [Epoch 1 | Idx: 3 | Optim Step: 2 | Loss: 1.5784]
2025-07-08 04:57:24,273 - INFO - [Epoch 1 | Idx: 5 | Optim Step: 3 | Loss: 2.1227]
2025-07-08 04:57:27,574 - INFO - [Epoch 1 | Idx: 7 | Optim Step: 4 | Loss: 1.7689]
2025-07-08 04:57:30,866 - INFO - [Epoch 1 | Idx: 9 | Optim Step: 5 | Loss: 1.1190]
2025-07-08 04:57:34,135 - INFO - [Epoch 1 | Idx: 11 | O

Validating:   0%|          | 0/150 [00:00<?, ?it/s]