In [None]:
# ============================
#  1. Install & Import
# ============================
!pip install -U "transformers>=4.43.0" "accelerate>=0.33.0" "bitsandbytes>=0.43" "peft>=0.11" "trl>=0.9.0" datasets tqdm

import torch, json, os
from datasets import load_dataset
from transformers import Idefics3ForConditionalGeneration, AutoProcessor
from peft import LoraConfig, get_peft_model
from trl import SFTConfig, SFTTrainer

# Process-wide environment tweaks, applied in one go:
#   * CUDA_VISIBLE_DEVICES="0"       -> only GPU index 0 is exposed to this process
#   * TOKENIZERS_PARALLELISM="false" -> sets the tokenizers parallelism flag to off
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Ask PyTorch to release any cached CUDA memory before the model is loaded.
torch.cuda.empty_cache()

# ============================
#  2. Load dataset
# ============================
# Fetch every split of the invoices/receipts dataset and show its summary.
ds = load_dataset("mychen76/invoices-and-receipts_ocr_v1")
print(ds)

# Peek at the field names of the first training record.
first_train_record = ds["train"][0]
print(first_train_record.keys())

# ============================
#  3. Model + Processor
# ============================
base_checkpoint = "HuggingFaceTB/SmolVLM-Instruct"

# Training later in this notebook runs with bf16 mixed precision, so load the
# weights in bfloat16 whenever the GPU supports it; loading float16 weights and
# then autocasting to bf16 mixes two different 16-bit formats. Fall back to
# float16 on hardware without bf16 support.
bf16_available = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
model_dtype = torch.bfloat16 if bf16_available else torch.float16

model = Idefics3ForConditionalGeneration.from_pretrained(
    base_checkpoint,
    torch_dtype=model_dtype,
    device_map="auto"
)

# The processor is loaded from the same checkpoint as the model.
processor = AutoProcessor.from_pretrained(base_checkpoint)

# ============================
#  4. Apply LoRA (QLoRA optional)
# ============================
# Adapter hyper-parameters: rank 8, alpha 32, 10% dropout, applied to the
# modules named "q_proj" and "v_proj"; bias terms are left untouched.
lora_settings = dict(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
lora_config = LoraConfig(**lora_settings)

# Wrap the base model with the adapters and report how many parameters
# still require gradients afterwards.
model = get_peft_model(model, lora_config)

trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f"Trainable parameters: {trainable_params:,}")

# ============================
#  5. Preprocess Dataset
# ============================
def _parse_structured_field(raw):
    """Best-effort decode of the dataset's "json" field into a JSON-serializable object.

    Tries, in order: strict JSON, a Python literal (e.g. a single-quoted dict
    repr), and finally the naive single-to-double quote replacement. Returns
    ``{"error": "invalid_json"}`` when nothing works.
    """
    import ast  # local import: only this helper needs it

    if isinstance(raw, (dict, list)):
        return raw
    if not isinstance(raw, str):
        return {"error": "invalid_json"}

    decoders = (
        json.loads,
        ast.literal_eval,
        lambda text: json.loads(text.replace("'", '"')),
    )
    for decode in decoders:
        try:
            value = decode(raw)
            json.dumps(value)  # reject literals that cannot be serialized back to JSON
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        return value
    return {"error": "invalid_json"}


def preprocess_fn(example, tokenizer=None, max_length=512):
    """Turn one dataset row into a fixed-length causal-LM training example.

    The prompt and the target JSON are placed in ONE token sequence
    (``prompt + "\\n" + target [+ EOS]``) so that ``labels`` line up
    position-by-position with ``input_ids``. Prompt and padding positions are
    set to -100 so only the target tokens contribute to the loss.

    Args:
        example: mapping with a ``"parsed_data"`` key holding a JSON string;
            its ``"json"`` entry is the structured target.
        tokenizer: callable tokenizer to use. Defaults to the notebook-level
            ``processor.tokenizer``.
        max_length: total sequence length after truncation / right-padding.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``, each a
        plain Python list of length ``max_length``.

    NOTE(review): only the ``parsed_data`` text field is read here; no image
    features are produced, so training is text-only -- confirm that is intended.
    """
    tok = processor.tokenizer if tokenizer is None else tokenizer

    parsed = json.loads(example["parsed_data"])
    structured_json = _parse_structured_field(parsed.get("json", "{}"))

    prompt = "Extract all invoice fields and return as JSON."
    target = json.dumps(structured_json)

    # Tokenize the two parts separately so the prompt/target boundary is exact.
    prompt_ids = list(tok(prompt + "\n")["input_ids"])
    target_ids = list(tok(target, add_special_tokens=False)["input_ids"])

    # Terminate the answer with EOS (when the tokenizer defines one) so the
    # model can learn where the JSON ends; it is dropped again if truncated.
    eos_id = getattr(tok, "eos_token_id", None)
    if eos_id is not None:
        target_ids.append(eos_id)

    input_ids = (prompt_ids + target_ids)[:max_length]
    n_prompt = min(len(prompt_ids), len(input_ids))
    labels = [-100] * n_prompt + input_ids[n_prompt:]

    # Right-pad everything to max_length; padded positions are ignored by
    # both attention (mask 0) and the loss (label -100).
    pad_id = getattr(tok, "pad_token_id", None)
    if pad_id is None:
        pad_id = eos_id if eos_id is not None else 0
    n_pad = max_length - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * n_pad
    input_ids = input_ids + [pad_id] * n_pad
    labels = labels + [-100] * n_pad

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

# Run preprocess_fn over each split, dropping every original column so only
# the fields returned by preprocess_fn remain.
train_split = ds["train"]
tokenized_train = train_split.map(
    preprocess_fn,
    remove_columns=train_split.column_names,
)

valid_split = ds["valid"]
tokenized_valid = valid_split.map(
    preprocess_fn,
    remove_columns=valid_split.column_names,
)

# Sanity check: show which fields the first processed training row carries.
first_tokenized_row = tokenized_train[0]
print(first_tokenized_row.keys())

# ============================
#  6. Config for Training
# ============================
# Only request bf16 mixed precision when the GPU really supports it; asking
# for it unconditionally fails on hardware without bf16 support.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

sft_config = SFTConfig(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,   # effective batch of 2 * 4 = 8 per device
    learning_rate=1e-4,
    num_train_epochs=3,
    fp16=False,          # fp16 AMP stays off: the weights are already loaded in a 16-bit dtype
    bf16=use_bf16,       # enabled only if the GPU supports it
    output_dir="./outputs",
    logging_steps=50,
    eval_strategy="steps",   # without a strategy, eval_steps is ignored and eval never runs
    eval_steps=200,
    save_steps=500,
    report_to="none"
)

# ============================
#  7. Trainer
# ============================
# Collect the trainer inputs: the LoRA-wrapped model, the training config and
# the two pre-tokenized splits.
# NOTE(review): depending on the installed TRL version, the default collator
# may rebuild `labels` from `input_ids` -- confirm the precomputed labels are used.
trainer_kwargs = {
    "model": model,
    "args": sft_config,
    "train_dataset": tokenized_train,
    "eval_dataset": tokenized_valid,
}
trainer = SFTTrainer(**trainer_kwargs)

# Launch fine-tuning.
trainer.train()

# ============================
#  8. Save Fine-tuned Model
# ============================
# Write the model first, then the processor, into the same directory.
# NOTE(review): `model` is the PEFT-wrapped model at this point, so this
# presumably stores the adapter weights rather than a merged model -- confirm.
export_dir = "./fine_tuned_model"
for artifact in (model, processor):
    artifact.save_pretrained(export_dir)

# ============================
#  9. Quick Evaluation
# ============================
# Put the model in eval mode so dropout (including the LoRA dropout of 0.1)
# is disabled and the reported loss is deterministic.
model.eval()

sample = tokenized_valid[0]

# Follow the model's own device instead of assuming "cuda"
# (the model was loaded with device_map="auto").
eval_device = model.device
input_ids = torch.tensor([sample["input_ids"]]).to(eval_device)
attention_mask = torch.tensor([sample["attention_mask"]]).to(eval_device)
labels = torch.tensor([sample["labels"]]).to(eval_device)

with torch.no_grad():
    # Pass the attention mask so padded positions are not attended to.
    loss = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
    ).loss
print("Sample loss:", loss.item())


In [None]:
# Shapes of the tensors built in the "Quick Evaluation" step. `inputs` is never
# defined at notebook scope, so use the global `input_ids` tensor instead.
print(input_ids.shape, labels.shape)


In [None]:
# ==========================
# Invoice Parser Evaluation
# ==========================
import torch
import json
from datasets import load_dataset
from transformers import Idefics3ForConditionalGeneration, AutoProcessor

# ---------------------------
# Config
# ---------------------------
# Where the data comes from and which split is evaluated.
dataset_id, split = "mychen76/invoices-and-receipts_ocr_v1", "valid"

# Directory holding the fine-tuned model and processor.
model_dir = "./fine_tuned_model"

# Number of samples to evaluate and the token budget per sequence.
max_samples, max_len = 5, 512

# ---------------------------
# Load Model & Processor
# ---------------------------
# Both artifacts are read from the same directory; dtype and device placement
# are left to the library ("auto").
load_kwargs = {"torch_dtype": "auto", "device_map": "auto"}
model = Idefics3ForConditionalGeneration.from_pretrained(model_dir, **load_kwargs)
processor = AutoProcessor.from_pretrained(model_dir)

# Inference only from here on.
model.eval()

# ---------------------------
# Load Dataset
# ---------------------------
# Load all splits, keep only the configured one, and report its size.
dataset_splits = load_dataset(dataset_id)
ds = dataset_splits[split]
num_examples = len(ds)
print(f"Loaded {num_examples} samples from {dataset_id}/{split}")

# ---------------------------
# Evaluation Loop
# ---------------------------
# For each sample: (1) teacher-forced loss on prompt + ground truth, with the
# prompt masked out, and (2) free generation conditioned on the prompt ONLY.
tokenizer = processor.tokenizer
prompt = "Extract all invoice fields and return as JSON."

# The prompt is the same for every sample, so tokenize it once. Its length is
# the exact boundary between conditioning tokens and answer tokens.
prompt_ids = tokenizer(f"{prompt}\n", return_tensors="pt")["input_ids"]
prompt_len = prompt_ids.shape[1]

for i, ex in enumerate(ds.select(range(min(max_samples, len(ds))))):
    parsed_data = json.loads(ex["parsed_data"])
    target_str = parsed_data.get("json", "{}")

    try:
        target_json = json.loads(target_str.replace("'", '"'))
    except Exception:
        target_json = {"error": "invalid_json"}

    target_text = json.dumps(target_json)

    # --- Loss: prompt ids + target ids, no padding ---
    # A single sequence needs no padding, so no pad tokens can leak into the
    # loss and the prompt mask does not depend on the tokenizer's padding side.
    target_ids = tokenizer(
        target_text,
        add_special_tokens=False,
        return_tensors="pt",
    )["input_ids"]
    input_ids = torch.cat([prompt_ids, target_ids], dim=1)[:, :max_len].to(model.device)
    attention_mask = torch.ones_like(input_ids)

    # Mask prompt tokens so loss only applies to the answer
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100

    # Compute loss
    with torch.no_grad():
        loss = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        ).loss

    print(f"\n[{i}] Loss: {loss.item():.4f}")

    # --- Generation: condition on the prompt only ---
    # Feeding the ground truth (as the loss input does) would leak the answer
    # into the "prediction".
    gen_inputs = {
        "input_ids": prompt_ids.to(model.device),
        "attention_mask": torch.ones_like(prompt_ids).to(model.device),
    }
    gen_tokens = model.generate(**gen_inputs, max_new_tokens=256)

    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = gen_tokens[:, prompt_len:]
    pred = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]

    print("Prediction:", pred[:300])
    print("Ground Truth:", target_text[:300])
    print("-" * 80)
