In [1]:
# Run after kernel restart
import gc, os, shutil, torch
gc.collect()
torch.cuda.empty_cache()
# Optional: remove datasets cache (careful: this deletes dataset cache files)
cache_dir = os.path.expanduser('~/.cache/huggingface/datasets')
if os.path.exists(cache_dir):
    print("Clearing HF datasets cache (may free large disk but will re-download if needed)...")
    # shutil.rmtree(cache_dir)  # uncomment only if you want to clear cache
print("Cleanup done. Re-run next steps.")


Cleanup done. Re-run next steps.


In [2]:
# MedGemma Fine-tune + Eval + Adapter Save (2xT4)
# Notebook-style Python script with cell separators and explanatory comments.
# Designed for: alvinl29/medical-vision-llm-dataset-v2 (HF parquet; columns: image, question, answer, image_description, body_part, modality, conversations)

In [3]:
# ==========================
# Cell 0: Setup environment
# - Installs required libraries (run once)
# - Mounts Drive if needed
# - Sets useful environment variables
# ==========================


 #pip installs (uncomment to run in a fresh Colab/Kaggle env)
!pip install -q transformers datasets accelerate peft bitsandbytes safetensors sentencepiece transformers[torch] torchvision timm git+https://github.com/huggingface/peft.git
# Note: On Kaggle, some packages may be preinstalled; adjust as needed.


# If you use Google Drive for checkpoints, mount it (Colab only)
# from google.colab import drive
# drive.mount('/content/gdrive')

# Working directory for checkpoints and artifacts
WORKDIR = "/content/medgemma_finetune" # change to your Drive path if needed
import os
os.makedirs(WORKDIR, exist_ok=True)
print("Workdir:", WORKDIR)

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m92.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m71.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m53.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━

In [4]:
# install protobuf that restores the legacy API
!pip install --upgrade --force-reinstall "protobuf==3.20.3"



Collecting protobuf==3.20.3
  Downloading protobuf-3.20.3-py2.py3-none-any.whl.metadata (720 bytes)
Downloading protobuf-3.20.3-py2.py3-none-any.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 6.33.0
    Uninstalling protobuf-6.33.0:
      Successfully uninstalled protobuf-6.33.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
opentelemetry-proto 1.37.0 requires protobuf<7.0,>=5.0, but you have protobuf 3.20.3 which is incompatible.
onnx 1.18.0 requires protobuf>=4.25.1, but you have protobuf 3.20.3 which is incompatible.
a2a-sdk 0.3.10 requires protobuf>=5.29

In [5]:
# ==========================
# Cell 1: Imports & utility functions
# - Central imports, monitoring helpers
# ==========================


import gc
import time
import json
from pathlib import Path
from typing import Optional, Dict, Any

import torch
import numpy as np
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    BitsAndBytesConfig,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, PeftModel

# optional: psutil to monitor memory
try:
    import psutil
except Exception:
    psutil = None


def print_mem(prefix: str = ""):
    """Print simple CPU + GPU memory stats to help monitoring."""
    if psutil:
        vm = psutil.virtual_memory()
        print(f"{prefix} CPU RAM: {vm.available/1024**3:.2f} GB available / {vm.total/1024**3:.2f} GB total")
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)} - {torch.cuda.memory_reserved(i)/1024**3:.2f} GB reserved, {torch.cuda.memory_allocated(i)/1024**3:.2f} GB allocated")

print_mem("startup")

2025-12-12 16:30:02.580550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765557002.771060      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765557002.826593      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


startup CPU RAM: 29.26 GB available / 31.35 GB total
GPU 0: Tesla T4 - 0.00 GB reserved, 0.00 GB allocated
GPU 1: Tesla T4 - 0.00 GB reserved, 0.00 GB allocated


In [6]:
# ==========================
# Cell 2: Dataset inspection (HF parquet)
# - Loads the HF dataset ID and prints schema + sample rows
# - USER: confirm dataset_id below
# ==========================

HF_DATASET_ID = "alvinl29/medical-vision-llm-dataset-v2"
print("Loading dataset id:", HF_DATASET_ID)
raw = load_dataset(HF_DATASET_ID)
print("Splits:", list(raw.keys()))
for s in raw:
    print(s, len(raw[s]))

print("Train columns:", raw['train'].column_names)
print("Sample train row (first):")
print(raw['train'][0])

# Inspect the 'image' feature type and one image if present
if 'image' in raw['train'].column_names:
    print("Image feature type:", raw['train'].features['image'])


Loading dataset id: alvinl29/medical-vision-llm-dataset-v2


README.md:   0%|          | 0.00/412 [00:00<?, ?B/s]

data/train-00000-of-00002.parquet:   0%|          | 0.00/408M [00:00<?, ?B/s]

data/train-00001-of-00002.parquet:   0%|          | 0.00/412M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/210M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3035 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/758 [00:00<?, ? examples/s]

Splits: ['train', 'validation']
train 3035
validation 758
Train columns: ['image', 'conversations', 'image_description', 'question', 'answer', 'dataset_source', 'modality', 'body_part', 'sample_id', 'instruction']
Sample train row (first):
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=800x969 at 0x7CEA902E3DD0>, 'conversations': {'role': ['user', 'assistant'], 'content': ['Describe the medical findings in this image.', 'Plain abdominal radiographs suggested subileus.']}, 'image_description': 'Plain abdominal radiographs suggested subileus.', 'question': 'Describe the medical findings.', 'answer': 'Plain abdominal radiographs suggested subileus.', 'dataset_source': 'ROCO', 'modality': 'X-ray', 'body_part': 'Abdomen', 'sample_id': 'roco_405', 'instruction': 'Analyze this medical image.'}
Image feature type: Image(mode=None, decode=True)


In [7]:
# ==========================
# Cell 3: Column mapping & prompt template
# - Set the exact column names observed in your dataset
# - Define the prompt template used for training
# ==========================


# Column mapping (edit if your dataset uses different names)
IMAGE_COL = 'image'
QUESTION_COL = 'question'
ANSWER_COL = 'answer'
IMAGE_DESC_COL = 'image_description' # note: fix exact spelling in dataset if different
BODY_PART_COL = 'body_part'
MODALITY_COL = 'modality'
# conversations and instruction exist but we'll use question + image_description as canonical


# Prompt template: keeps metadata and question; model must generate the answer after 'Answer:'
PROMPT_TEMPLATE = (
"You are a medical vision-language assistant.\n"
"Analyze the following image and answer the user's question.\n\n"
"Image Modality: {modality}\n"
"Body Part: {body_part}\n"
"Image Description: {image_description}\n\n"
"Question: {question}\n"
"Answer:"
)


print("Prompt example:\n", PROMPT_TEMPLATE.format(
modality="X-ray", body_part="Abdomen", image_description="Plain abdominal radiographs suggested subileus.", question="Describe the medical findings."))

Prompt example:
 You are a medical vision-language assistant.
Analyze the following image and answer the user's question.

Image Modality: X-ray
Body Part: Abdomen
Image Description: Plain abdominal radiographs suggested subileus.

Question: Describe the medical findings.
Answer:


In [8]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [9]:
# ==========================
# Cell 4: Tokenizer + Processor + model config placeholders
# - Choose base model name (change to the exact HF medgemma variant you want)
# - Prepare BitsAndBytesConfig for 4-bit loading
# ==========================

BASE_MODEL = "google/medgemma-4b-it"  # change if you have a different repo id

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
# Attempt to load a multimodal processor if available; otherwise we'll use PIL transforms
try:
    processor = AutoProcessor.from_pretrained(BASE_MODEL)
    print("Loaded AutoProcessor for model")
except Exception as e:
    processor = None
    print("No AutoProcessor available; will apply manual PIL transforms if needed.")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loaded AutoProcessor for model


In [10]:
# ==========================
# Cell 5: Preprocessing function (CORRECTED)
# - ONLY builds prompt strings.
# - Does NOT tokenize. We leave that for the processor in the collator.
# ==========================

from PIL import Image
import requests
from io import BytesIO

def build_prompt_str(row):
    """
    Build the textual prompt. Must include the <image> token.
    """
    modality = row.get(MODALITY_COL, "") or "Unknown"
    body_part = row.get(BODY_PART_COL, "") or "Unknown"
    image_desc = row.get(IMAGE_DESC_COL, "") or ""
    question = row.get(QUESTION_COL, "") or ""

    # Ensure we use the exact tag the processor expects
    prompt = (
        "<image>\n"
        "You are a medical vision-language assistant.\n"
        "Analyze the following image and answer the user's question.\n\n"
        f"Image Modality: {modality}\n"
        f"Body Part: {body_part}\n"
        f"Image Description: {image_desc}\n\n"
        f"Question: {question}\n"
        "Answer:"
    )
    return prompt

def preprocess_batch_light(batch):
    """
    Revised preprocessing:
    - Just adds the 'text_prompt' and 'text_answer' fields.
    - Does NOT tokenize inputs or remove images.
    """
    prompts = []
    answers = []
    n = len(batch[QUESTION_COL])
    
    for i in range(n):
        row = {k: batch[k][i] for k in batch.keys()}
        prompts.append(build_prompt_str(row))
        answers.append(row.get(ANSWER_COL, "") or "")

    return {
        "text_prompt": prompts,
        "text_answer": answers,
        # We keep the original image column automatically
    }

In [11]:
# ==========================
# Cell 6: Prepare datasets (map preprocess)
# - Applies preprocessing to train and validation splits
# - Optionally create small samples for quick debug
# ==========================

print_mem('before dataset map')

train_ds = raw['train']
val_ds = raw['validation'] if 'validation' in raw else raw['test'] if 'test' in raw else None

# For debug: create small subsets
# tiny_train = train_ds.select(range(min(32, len(train_ds))))
# tiny_val = val_ds.select(range(min(16, len(val_ds))))

print("Mapping train dataset... this may take time")
train_proc = train_ds.map(preprocess_batch_light, batched=True, batch_size=64, remove_columns=[col for col in train_ds.column_names if col!='image'])

if val_ds is not None:
    val_proc = val_ds.map(preprocess_batch_light, batched=True, remove_columns=val_ds.column_names)
else:
    val_proc = None

print("Mapped datasets. Train examples:", len(train_proc), "Val examples:", len(val_proc) if val_proc else 0)
print_mem('after dataset map')


before dataset map CPU RAM: 28.08 GB available / 31.35 GB total
GPU 0: Tesla T4 - 0.00 GB reserved, 0.00 GB allocated
GPU 1: Tesla T4 - 0.00 GB reserved, 0.00 GB allocated
Mapping train dataset... this may take time


Map:   0%|          | 0/3035 [00:00<?, ? examples/s]

Map:   0%|          | 0/758 [00:00<?, ? examples/s]

Mapped datasets. Train examples: 3035 Val examples: 758
after dataset map CPU RAM: 28.15 GB available / 31.35 GB total
GPU 0: Tesla T4 - 0.00 GB reserved, 0.00 GB allocated
GPU 1: Tesla T4 - 0.00 GB reserved, 0.00 GB allocated


In [13]:
# ==========================
# Cell 7: Collator (CORRECTED)
# - Handles tokenization + Image processing together
# - Dynamically creates labels by masking the user prompt
# ==========================
import torch

class OnTheFlyCollator:
    def __init__(self, processor):
        self.processor = processor
        # Fallback padding token if processor doesn't set it
        if self.processor.tokenizer.pad_token_id is None:
             self.processor.tokenizer.pad_token_id = self.processor.tokenizer.eos_token_id

    def __call__(self, features):
        images = [f[IMAGE_COL] for f in features]
        prompts = [f['text_prompt'] for f in features]
        answers = [f['text_answer'] for f in features]
        
        # 1. Prepare full input (Prompt + Answer)
        # The processor handles <image> expansion automatically here
        full_texts = [p + " " + a for p, a in zip(prompts, answers)]
        
        batch = self.processor(
            text=full_texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=256, # Adjust based on your memory constraints
            return_tensors="pt"
        )
        
        # 2. Create Labels (Masking the instruction/prompt part)
        input_ids = batch['input_ids']
        labels = input_ids.clone()
        
        # We need to calculate where the "Answer" starts to mask everything before it
        # We do this by processing JUST the prompt to get its length
        # Note: This adds a tiny overhead but is the safest way to get exact token alignment
        prompt_batch = self.processor(
            text=prompts,
            images=images,
            padding="longest", # Padding strategy might differ, but we only need lengths
            truncation=True,
            max_length=256,
            return_tensors="pt"
        )
        
        # Mask out the prompt tokens in the labels
        for i, prompt_ids in enumerate(prompt_batch['input_ids']):
            # The length of the prompt tokens
            prompt_len = len(prompt_ids)
            
            # Safety check: ensure prompt length is less than total length
            if prompt_len < len(labels[i]):
                labels[i, :prompt_len] = -100
            else:
                # If prompt was truncated to be the whole sequence, ignore this sample
                labels[i, :] = -100

        batch['labels'] = labels
        
        return batch

# Initialize the new collator
collator = OnTheFlyCollator(processor=processor)

In [14]:
# ==========================
# Cell 8: Model load (4-bit) + attach LoRA (for 2xT4 training)
# - This cell loads base model in 4-bit via bitsandbytes and attaches LoRA
# - Use accelerate/transformers device_map to automatically shard across GPUs
# ==========================

print_mem('before model load')

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map='auto',
    trust_remote_code=False,
)

# LoRA config (r=16 => approx 80-120M trainable params depending on target modules)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # adjust if model layer names differ
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.gradient_checkpointing_enable()

print("Model + LoRA prepared. Trainable params (approx):")
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {trainable:,}")
print_mem('after model load')

before model load CPU RAM: 28.15 GB available / 31.35 GB total
GPU 0: Tesla T4 - 0.00 GB reserved, 0.00 GB allocated
GPU 1: Tesla T4 - 0.00 GB reserved, 0.00 GB allocated


config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]



model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

Model + LoRA prepared. Trainable params (approx):
Trainable params: 6,447,104
after model load CPU RAM: 26.09 GB available / 31.35 GB total
GPU 0: Tesla T4 - 0.30 GB reserved, 0.14 GB allocated
GPU 1: Tesla T4 - 4.48 GB reserved, 2.89 GB allocated


In [16]:
# ==========================
# Cell 9: Custom Trainer with generation-based eval
# - Trainer subclass overrides evaluate() to generate answers and compute metrics
# - Metrics: exact-match and token-level F1
# ==========================
# ==========================
# Cell: Metrics + Custom Trainer (no load_metric)
# ==========================

import re
import torch
from transformers import Trainer

# ----- Normalization -----
def normalize_text(s):
    s = s.lower().strip()
    s = re.sub(r"[^a-z0-9 ]+", " ", s)
    s = " ".join(s.split())
    return s

# ----- Token-level F1 -----
def token_f1(pred: str, gold: str):
    p_tokens = pred.split()
    g_tokens = gold.split()
    if not p_tokens or not g_tokens:
        return 0.0
    common = set(p_tokens) & set(g_tokens)
    prec = len(common) / len(p_tokens)
    rec = len(common) / len(g_tokens)
    if prec + rec == 0:
        return 0.0
    return 2 * (prec * rec) / (prec + rec)

# ----- Custom Trainer -----
class MMTrainer(Trainer):
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        ds = eval_dataset if eval_dataset is not None else self.eval_dataset
        if ds is None:
            return {}
        
        self.model.eval()
        device = self.model.device
        total = 0
        matches = 0
        f1s = 0.0
        
        for i in range(len(ds)):
            ex = ds[i]
            batch = collator([ex])

            # Move required tensors
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask', 'pixel_values']
            }

            with torch.no_grad():
                gen = self.model.generate(
                    **inputs,
                    max_new_tokens=32,
                    do_sample=False
                )
                output = tokenizer.decode(gen[0], skip_special_tokens=True)

            # Extract prediction
            if "Answer:" in output:
                pred = output.split("Answer:")[-1].strip()
            else:
                pred = output.strip()

            gold = ex.get("answer", "")
            pred_n = normalize_text(pred)
            gold_n = normalize_text(gold)

            if pred_n == gold_n and gold_n != "":
                matches += 1
            f1s += token_f1(pred_n, gold_n)
            total += 1
        
        accuracy = matches / total if total > 0 else 0.0
        avg_f1 = f1s / total if total > 0 else 0.0
        
        metrics = {
            f"{metric_key_prefix}_exact_match": accuracy,
            f"{metric_key_prefix}_token_f1": avg_f1,
            f"{metric_key_prefix}_samples": total
        }

        print("Eval metrics:", metrics)
        return metrics


In [17]:
# ==========================
# Cell 10: TrainingArguments + Trainer init
# - Configured for 2xT4 training
# - Checkpointing to WORKDIR
# ==========================

training_args = TrainingArguments(
    output_dir=WORKDIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    fp16=True,
    num_train_epochs=3,
    logging_steps=50,
    save_steps=500,
    save_total_limit=5,
    eval_strategy="steps",
    eval_steps=500,
    learning_rate=2e-4,
    remove_unused_columns=False,
    push_to_hub=False,
)

# ==========================
# Create Trainer (minimal args only)
# ==========================

trainer = MMTrainer(
    model=model,
    args=training_args,
    train_dataset=train_proc,
    eval_dataset=val_proc,
    data_collator=collator
)

# Now patch missing attributes manually AFTER construction
trainer.label_names = ["labels"]       # ensures Trainer looks for "labels"
trainer.tokenizer = tokenizer          # optional but useful for save/train logs

print("Trainer created successfully.")
print("label_names =", trainer.label_names)


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead.


Trainer created successfully.
label_names = ['labels']


In [18]:
# ==========================
# Cell 11: Resume logic and run training
# - If a checkpoint exists in WORKDIR, resume
# - Save adapter only frequently
# ==========================

# find latest checkpoint
ckpts = [d for d in os.listdir(WORKDIR) if d.startswith("checkpoint")]
ckpt_to_resume = None
if ckpts:
    ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]) if '-' in x else 0)
    ckpt_to_resume = os.path.join(WORKDIR, ckpts_sorted[-1])
    print("Resuming from checkpoint:", ckpt_to_resume)

trainer.train(resume_from_checkpoint=ckpt_to_resume if ckpt_to_resume else None)

# Save the small adapter weights after training
ADAPTER_DIR = os.path.join(WORKDIR, "final_adapter")
os.makedirs(ADAPTER_DIR, exist_ok=True)
model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)
print("Adapter saved to", ADAPTER_DIR)




<IPython.core.display.Javascript object>

KeyboardInterrupt: 

In [None]:
# ==========================
# Cell 12: Sanity-check inference using base+adapter (fast, optional)
# - reload base in 4-bit + adapter and run few samples
# ==========================

# reload base in 4-bit and attach adapter for quick inference
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map='auto',
    trust_remote_code=False,
)
model_with_adapter = PeftModel.from_pretrained(base, ADAPTER_DIR)
model_with_adapter.eval()

# run a quick generate on first val example
if val_proc and len(val_proc) > 0:
    ex = val_proc[0]
    batch = collator([ex])
    device = model_with_adapter.device
    inputs = {k: v.to(device) for k,v in batch.items() if k in ['input_ids','attention_mask','pixel_values']}
    with torch.no_grad():
        gen = model_with_adapter.generate(**inputs, max_new_tokens=64)
        print("Generated:", tokenizer.decode(gen[0], skip_special_tokens=True))

In [None]:
# ==========================
# Cell 13: Prepare adapter download (copy ADAPTER_DIR to local/Drive/Kaggle dataset)
# - This cell packages the adapter folder for transfer to P100 session
# ==========================

# Example: zip the adapter for download
import shutil
shp = os.path.join(WORKDIR, 'adapter_bundle.zip')
if os.path.exists(shp):
    os.remove(shp)
shutil.make_archive(shp.replace('.zip',''), 'zip', ADAPTER_DIR)
print('Adapter archive created at', shp)
