In [1]:
# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard torchvision

# Install Gemma release branch from Hugging Face
%pip install "transformers>=4.51.3"

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.15.2" \
  "peft==0.14.0" \
  "pillow==11.1.0" \
  protobuf \
  sentencepiece




In [2]:
# !pip install flash_attn

In [None]:
# from google.colab import userdata
from huggingface_hub import login
login()


In [4]:
from datasets import load_dataset, DatasetDict
from PIL import Image
import os

# System message for the assistant
system_message = """Your task is to:
    - Identify the **document type**.
    - Determine whether the document is **Real** or **Fake** based on below reasoning:
    - Suspicious or inconsistent entries.
    - Font inconsistencies.
    - Violations of standard banking or accounting practices.
    - Textual or numeric manipulation (e.g., formatting issues, overwritten values).
    - Metadata mismatches (e.g., conflicting dates, fake signatures/stamps).
    - Unnatural linguistic patterns or overly generic phrasing.
    - Semantic inconsistencies or hallucinated data.

Return your output in the following json format:
DocumentType: <e.g., Bank Statement, Salary Slip, ID Card>
Authenticity: <Original, Fraud, Real, Fake, Genuine>
Reason: <Clear, concise explanation with observed issues related to authenticity>
"""

# User prompt template
user_prompt = """Authenticity: {category}, DocumentType: {doctype}, Reason: {reason}"""

from datasets import load_dataset
from PIL import Image

def resize_half_min256(img: Image.Image) -> Image.Image:
    w, h = img.size
    new_w = max(w // 2, 256)
    new_h = max(h // 2, 256)
    if w // 2 < 256 or h // 2 < 256:
        aspect = w / h
        if w < h:
            new_w = 256
            new_h = int(256 / aspect)
        else:
            new_h = 256
            new_w = int(256 * aspect)
        new_w = min(new_w, w)
        new_h = min(new_h, h)
    return img.resize((new_w, new_h), Image.LANCZOS)

def resize_data(sample):
    img = sample["image"]
    img = resize_half_min256(img)
    sample["image"] = img
    return sample


def format_data(sample):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": resize_data(sample)["image"],
                },
                {
                    "type": "text",
                    "text": "Idenitify the documentType and authenticity with reason",
                },
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text",
                         "text": user_prompt.format(
                            category=sample["category"],
                            doctype=sample["documentType"],
                            reason=sample["reason"]),
                        },
                        ],
        }
    ]

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    # Iterate through each conversation
    for msg in messages:
        # Get content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))
    return image_inputs


In [5]:



# Load dataset from the hub
# dataset = load_dataset("AliceRolan/realfakedataset", split="train")
train_dataset, eval_dataset = load_dataset("AliceRolan/IDCardDataset", split=["train[:100%]",  "test[:100%]"])

# Convert dataset to OAI messages
# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
train_dataset = [format_data(sample) for sample in train_dataset]
eval_dataset = [format_data(sample) for sample in eval_dataset]
# test_dataset = [format_data(sample) for sample in test_dataset]
# print(dataset[0]["messages"])

train-00000-of-00001.parquet:   0%|          | 0.00/434M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/144M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [6]:
print(train_dataset[0])

[{'role': 'system', 'content': [{'type': 'text', 'text': 'Your task is to:\n    - Identify the **document type**.\n    - Determine whether the document is **Real** or **Fake** based on below reasoning:\n    - Suspicious or inconsistent entries.\n    - Font inconsistencies.\n    - Violations of standard banking or accounting practices.\n    - Textual or numeric manipulation (e.g., formatting issues, overwritten values).\n    - Metadata mismatches (e.g., conflicting dates, fake signatures/stamps).\n    - Unnatural linguistic patterns or overly generic phrasing.\n    - Semantic inconsistencies or hallucinated data.\n\nReturn your output in the following json format:\nDocumentType: <e.g., Bank Statement, Salary Slip, ID Card>\nAuthenticity: <Original, Fraud, Real, Fake, Genuine>\nReason: <Clear, concise explanation with observed issues related to authenticity>\n'}]}, {'role': 'user', 'content': [{'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=567x358 at 0x7CC7E976C850>}, {'

In [7]:
# print(dataset[1900]["messages"])

In [8]:
import torch
# import flash_attn_2_cuda
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-3-4b-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [9]:
# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

# import torch
# from google.colab import drive
# # import shutil

# # Mount Google Drive
# drive.mount('/content/drive')

# # Load Model with PEFT adapter
# model = AutoModelForImageTextToText.from_pretrained(
#   '/content/gemma-currency-FT',
#   device_map="auto",
#   torch_dtype=torch.bfloat16,
#   attn_implementation="eager"
# )
# processor = AutoProcessor.from_pretrained('/content/gemma-currency-FT')



config.json:   0%|          | 0.00/815 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [10]:
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=4,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)


In [11]:
from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-idcard-FT",     # directory to save and repository id
    num_train_epochs=1,                         # number of training epochs
    per_device_train_batch_size=1,              # batch size per device during training
    gradient_accumulation_steps=4,              # number of steps before performing a backward/update pass
    gradient_checkpointing=True,                # use gradient checkpointing to save memory
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=25,                            # log every 5 steps
    save_strategy="epoch",                      # save checkpoint every epoch
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                          # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",               # use constant learning rate scheduler
    # push_to_hub=True,                           # push model to hub
    report_to="tensorboard",                    # report metrics to tensorboard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use reentrant checkpointing
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
)
args.remove_unused_columns = False # important for collator

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example)
        text = processor.apply_chat_template(
            example, add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
    labels = batch["input_ids"].clone()

    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100

    batch["labels"] = labels
    return batch


In [12]:
!pip install evaluate nltk rouge-score
import nltk
nltk.download('punkt')

Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge-score
  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=8f86a86930fa9ef5f56a108959f4b5563024694996f95dc588644e3cd395362d
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge-score
Installing collected packages: rouge-score
Successfully installed rouge-score-0.1.2


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [13]:
# import numpy as np
import traceback
import json
import evaluate


# Make sure you have these loaded
rouge_metric = evaluate.load("rouge")
bleu_metric = evaluate.load("bleu")
rouge_results = None
def compute_metrics(eval_pred):
    try:
        raw_logits, raw_labels = eval_pred
        predicted_ids = np.argmax(raw_logits[0], axis=-1)

        decoded_preds = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        labels = np.where(raw_labels != -100, raw_labels, processor.tokenizer.pad_token_id)
        decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)

        def extract_assistant_response(text,tx):
            """
            Extracts content after the last 'Assistant:' marker.
            If not found, falls back to trying the 'Answer:' marker.
            """
            # print("org:",text)
            # --- FIX IS HERE: Added fallback logic ---
            # 1. Try to split by "Assistant:" first
            parts = text.rsplit("Assistant:", 1)
            if len(parts) > 1:
                # print(f"{tx}_Assistant:",parts[1].strip() if len(parts) > 1 else "")
                return parts[1].strip()
            parts = text.rsplit("assistant:", 1)
            if len(parts) > 1:
                # print(f"{tx}_assistant:",parts[1].strip() if len(parts) > 1 else "")
                return parts[1].strip()
                assistant

            # 2. If that fails, try to split by "Answer:"
            parts = text.rsplit("Answer:", 1)
            if len(parts) > 1:
                # print(f"{tx}_Answer:",parts[1].strip() if len(parts) > 1 else "")
                return parts[1].strip()

            # 3. If both fail, return an empty string
            # print(f"{tx}_orig:{text}")
            return ""
            # ----------------------------------------
        # print("decoded_preds:",decoded_preds)
        # print("decoded_labels:",decoded_labels)
        pred_responses = [extract_assistant_response(p,"pred_responses") for p in decoded_preds]
        label_responses = [extract_assistant_response(l,"label_responses") for l in decoded_labels]

        # Calculate ROUGE (it's more robust to empty strings)
        rouge_results = rouge_metric.compute(predictions=pred_responses, references=label_responses)
        # print("rouge_results:",rouge_results)
        # Tokenize for BLEU score
        pred_tokenized = [pred.split() for pred in pred_responses]
        label_tokenized = [[label.split()] for label in label_responses]

        # --- FIX IS HERE ---
        bleu_results = {"bleu": 0.0} # Default score
        # Check if there are any non-empty prediction strings to score
        if any(pred_tokenized):
            bleu_results = bleu_metric.compute(predictions=pred_tokenized, references=label_tokenized)
        # -------------------

        # -------------------
        # print("bleu_results:",bleu_results)
        all_metrics = {
            "rouge1": rouge_results["rouge1"],
            "rouge2": rouge_results["rouge2"],
            "rougeL": rouge_results["rougeL"],
            "rougeLsum": rouge_results["rougeLsum"],
            "bleu": bleu_results["bleu"],
        }
        return all_metrics

    except Exception as e:
        print(f"Error in compute_metrics: {e}")
        traceback.print_exc()
        return {"rouge1": rouge_results["rouge1"],
            "rouge2": rouge_results["rouge2"],
            "rougeL": rouge_results["rougeL"],
            "rougeLsum": rouge_results["rougeLsum"],
                "bleu": 0}

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

In [14]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    # eval_dataset=eval_dataset,
    # tokenizer=processor
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
    # compute_metrics=compute_metrics,
)


In [None]:
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
# trainer.save_model()


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
25,2.7174
50,0.0977
75,0.041
100,0.0391
125,0.0399
150,0.0397
175,0.0371
200,0.0379
225,0.0377
250,0.0377


In [16]:
import pandas as pd
metrics = trainer.state.log_history
dfm=pd.DataFrame(metrics)
dfm

Unnamed: 0,loss,grad_norm,learning_rate,mean_token_accuracy,epoch,step,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss
0,2.7174,1.791844,0.0002,0.900931,0.033333,25,,,,,
1,0.0977,1.666703,0.0002,0.99307,0.066667,50,,,,,
2,0.041,0.393393,0.0002,0.993436,0.1,75,,,,,
3,0.0391,1.056044,0.0002,0.993512,0.133333,100,,,,,
4,0.0399,1.288428,0.0002,0.993465,0.166667,125,,,,,
5,0.0397,0.401747,0.0002,0.993548,0.2,150,,,,,
6,0.0371,0.996804,0.0002,0.993786,0.233333,175,,,,,
7,0.0379,0.874247,0.0002,0.992934,0.266667,200,,,,,
8,0.0377,0.728275,0.0002,0.993561,0.3,225,,,,,
9,0.0377,0.35873,0.0002,0.993816,0.333333,250,,,,,


In [18]:
dfm

Unnamed: 0,loss,grad_norm,learning_rate,mean_token_accuracy,epoch,step,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss
0,2.7174,1.791844,0.0002,0.900931,0.033333,25,,,,,
1,0.0977,1.666703,0.0002,0.99307,0.066667,50,,,,,
2,0.041,0.393393,0.0002,0.993436,0.1,75,,,,,
3,0.0391,1.056044,0.0002,0.993512,0.133333,100,,,,,
4,0.0399,1.288428,0.0002,0.993465,0.166667,125,,,,,
5,0.0397,0.401747,0.0002,0.993548,0.2,150,,,,,
6,0.0371,0.996804,0.0002,0.993786,0.233333,175,,,,,
7,0.0379,0.874247,0.0002,0.992934,0.266667,200,,,,,
8,0.0377,0.728275,0.0002,0.993561,0.3,225,,,,,
9,0.0377,0.35873,0.0002,0.993816,0.333333,250,,,,,


In [19]:
trainer.save_model(args.output_dir)
processor.save_pretrained(args.output_dir)

['gemma-idcard-FT/processor_config.json']

In [20]:
trainer.push_to_hub()

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

events.out.tfevents.1754842390.f544100c312f.1803.0:   0%|          | 0.00/16.2k [00:00<?, ?B/s]

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/2.76G [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.62k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/AliceRolan/gemma-idcard-FT/commit/d5e0a15d0216171c35b74bdee89ea91712e20eea', commit_message='End of training', commit_description='', oid='d5e0a15d0216171c35b74bdee89ea91712e20eea', pr_url=None, repo_url=RepoUrl('https://huggingface.co/AliceRolan/gemma-idcard-FT', endpoint='https://huggingface.co', repo_type='model', repo_id='AliceRolan/gemma-idcard-FT'), pr_revision=None, pr_num=None)

In [21]:
# eval_results = trainer.evaluate()
# print(eval_results)

# free the memory again
# del model
# del trainer
torch.cuda.empty_cache()


In [22]:
import gc
import time

def clear_memory():
    # Delete variables if they exist in the current global scope
    if 'inputs' in globals(): del globals()['inputs']
    if 'model' in globals(): del globals()['model']
    if 'processor' in globals(): del globals()['processor']
    if 'trainer' in globals(): del globals()['trainer']
    if 'peft_model' in globals(): del globals()['peft_model']
    if 'bnb_config' in globals(): del globals()['bnb_config']
    time.sleep(2)

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(2)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(2)
    gc.collect()
    time.sleep(2)

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

clear_memory()

GPU allocated memory: 0.02 GB
GPU reserved memory: 4.01 GB


In [23]:
import torch, gc

gc.collect()

# Force empty cache
torch.cuda.empty_cache()

# (Optional) Reset CUDA context for the process (drastic)
torch.cuda.ipc_collect()

print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


Allocated: 0.02 GB
Reserved:  4.01 GB


In [None]:
from peft import PeftModel

# Load Model base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

['merged_model/processor_config.json']

In [None]:
import os
print(os.getcwd())

/workspace


In [24]:
import torch
from peft import PeftModel
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig


args = SFTConfig(
    output_dir="gemma-idcard-FT",     # directory to save and repository id
    num_train_epochs=1,                         # number of training epochs
    per_device_train_batch_size=1,              # batch size per device during training
    gradient_accumulation_steps=4,              # number of steps before performing a backward/update pass
    gradient_checkpointing=True,                # use gradient checkpointing to save memory
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=5,                            # log every 5 steps
    save_strategy="epoch",                      # save checkpoint every epoch
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                         # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",               # use constant learning rate scheduler
    # push_to_hub=True,                           # push model to hub
    report_to="tensorboard",                    # report metrics to tensorboard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use reentrant checkpointing
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
    per_device_eval_batch_size=1, # Reduce the evaluation batch size
)
args.remove_unused_columns = False # important for collator

# Load Model with PEFT adapter and 4-bit quantization for evaluation
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForImageTextToText.from_pretrained(
  "/content/gemma-idcard-FT",
  device_map="auto",
  torch_dtype=torch.bfloat16,
  attn_implementation="eager",
  quantization_config=bnb_config, # Add quantization config
)

processor = AutoProcessor.from_pretrained("/content/gemma-idcard-FT")
# eval_dataset = load_dataset("AliceRolan/CurrencyDataset", split="test")
# eval_dataset = [format_data(sample) for sample in eval_dataset] # Apply format_data to eval_dataset


eval_trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=None,  # Optional or use validation dataset
    eval_dataset=eval_dataset,  # Use your validation dataset here
    # tokenizer=processor,
    processing_class=processor,
    # peft_config=peft_config
    data_collator=collate_fn,
    # compute_metrics=compute_metrics,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [25]:
results1 = eval_trainer.evaluate()
print(results1)

{'eval_loss': 0.008958653546869755, 'eval_model_preparation_time': 0.0149, 'eval_runtime': 487.2082, 'eval_samples_per_second': 2.053, 'eval_steps_per_second': 2.053}


In [26]:
import math
print("Perplexity:", math.exp(results1["eval_loss"]))

Perplexity: 1.0089989023855734


In [27]:
eval_trainer.push_to_hub()

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

events.out.tfevents.1754848676.f544100c312f.1803.1:   0%|          | 0.00/420 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/2.72G [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.62k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/AliceRolan/gemma-idcard-FT/commit/3b3101eda975e156f2a2a0faa23bab93a8e5aead', commit_message='End of training', commit_description='', oid='3b3101eda975e156f2a2a0faa23bab93a8e5aead', pr_url=None, repo_url=RepoUrl('https://huggingface.co/AliceRolan/gemma-idcard-FT', endpoint='https://huggingface.co', repo_type='model', repo_id='AliceRolan/gemma-idcard-FT'), pr_revision=None, pr_num=None)

In [28]:
system_message

'Your task is to:\n    - Identify the **document type**.\n    - Determine whether the document is **Real** or **Fake** based on below reasoning:\n    - Suspicious or inconsistent entries.\n    - Font inconsistencies.\n    - Violations of standard banking or accounting practices.\n    - Textual or numeric manipulation (e.g., formatting issues, overwritten values).\n    - Metadata mismatches (e.g., conflicting dates, fake signatures/stamps).\n    - Unnatural linguistic patterns or overly generic phrasing.\n    - Semantic inconsistencies or hallucinated data.\n\nReturn your output in the following json format:\nDocumentType: <e.g., Bank Statement, Salary Slip, ID Card>\nAuthenticity: <Original, Fraud, Real, Fake, Genuine>\nReason: <Clear, concise explanation with observed issues related to authenticity>\n'

In [29]:
def execute_prompt(img):
  messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": f"You are a helpful banking assistant and are a forensic financial analyst specializing in detecting fraud, forgery, and AI-generated content in banking documents.{system_message}"}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": "Idenitify the documentType and authenticity as Real or fake and provide reason for authenticity if found as fake and provide fraud score as fraudScore and prediction confidence score as confidenceScore. Return your output in json format"}
        ]
    }
]

  inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
  ).to(model.device, dtype=torch.bfloat16)

  input_len = inputs["input_ids"].shape[-1]

  with torch.inference_mode():
      generation = model.generate(**inputs, max_new_tokens=1024, do_sample=True,top_p=1.0, temperature=0.5)
      generation = generation[0][input_len:]

  decoded = processor.decode(generation, skip_special_tokens=True)
  return decoded


In [30]:
decode = execute_prompt(Image.open("/content/SlovakID-Fake-1.png"))
print(decode)

Authenticity: Fake, DocumentType: SlovakIDCard, Reason: Manipulated or edited. Face Replacement or morphing


In [41]:
val_dataset = load_dataset("AliceRolan/IDCardDataset", split=["test"])

In [42]:
val_dataset[0][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1134x716>,
 'label': 0,
 'documentType': 'SlovakIDCard',
 'category': 'Real',
 'filename': 'Real_SlovakIDCard_001.png',
 'reason': 'Geniune. No Manipulation'}

In [43]:
import json
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
# Initialize the ROUGE scorer
  # 'rouge1', 'rouge2', 'rougeL' measure overlap of unigrams, bigrams,
  # and the longest common subsequence, respectively.
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
# Initialize the NLTK smoothing function for BLEU
chencherry = SmoothingFunction()
results_list = []
# Initialize variables to store the sum of scores for averaging
total_scores = {
    'filename': None,
    'rouge1_f': 0,
    'rouge2_f': 0,
    'rougeL_f': 0,
    'bleu': 0
}
i=0
for data in val_dataset[0]:
  # print(data)
  output = execute_prompt(data['image'])

  # llm_output_json = output.replace('`','').strip()
  import json
  from rouge_score import rouge_scorer
  from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

  # --- 1. Define the Ground Truth and LLM Output ---

  # The ground truth string from your example
  # ground_truth_text = "Authenticity: Real, DocumentType: IndianCurrency, Reason: Geniune. No Manipulation"

  # # The LLM's output in JSON format
  # ground_truth_text = {
  #   "DocumentType": data["documentType"],
  #   "Authenticity": data["category"],
  #   "Reason": data["reason"]
  # }

  # For a fair comparison, we'll convert the LLM's JSON output into a single string.
  # We'll concatenate the relevant values.
  ground_truth_text = (
      f"DocumentType: {data['documentType']},"
      f"Authenticity: {data['category']}, "
      f"Reason: {data['reason']}"
  )
  candidate_text = output.replace('`','').strip()

  # print("--- Texts for Comparison ---")
  # print(f"Reference (Ground Truth): {ground_truth_text}")
  # print(f"Candidate (LLM Output): {candidate_text}\n")


  # --- 2. Calculate ROUGE Scores ---


  # Calculate scores
  rouge_scores = scorer.score(ground_truth_text, candidate_text)

  # # Add F1-scores to totals
  # total_scores['filename'] += data['filename']
  # total_scores['rouge1_f'] += rouge_scores['rouge1'].fmeasure
  # total_scores['rouge2_f'] += rouge_scores['rouge2'].fmeasure
  # total_scores['rougeL_f'] += rouge_scores['rougeL'].fmeasure

  # print("Individual ROUGE Scores:")
  # print(f"  ROUGE-1 F1: {rouge_scores['rouge1'].fmeasure:.4f}")
  # print(f"  ROUGE-2 F1: {rouge_scores['rouge2'].fmeasure:.4f}")
  # print(f"  ROUGE-L F1: {rouge_scores['rougeL'].fmeasure:.4f}")
 # --- BLEU Calculation ---
  reference_tokens = [ground_truth_text.lower().split()]
  candidate_tokens = candidate_text.lower().split()

  bleu_score = sentence_bleu(
      reference_tokens,
      candidate_tokens,
      weights=(0.25, 0.25, 0.25, 0.25), # Standard BLEU-4
      smoothing_function=chencherry.method1
  )

  # total_scores['bleu'] += bleu_score
  # Store all individual results in a dictionary
  individual_results = {
      'item_id': i + 1,
      'filename': data['filename'],
      'rouge1_f': rouge_scores['rouge1'].fmeasure,
      'rouge2_f': rouge_scores['rouge2'].fmeasure,
      'rougeL_f': rouge_scores['rougeL'].fmeasure,
      'bleu': bleu_score,
      # 'ground_truth': ground_truth_text, # Optional: for context
      'llm_output': candidate_text # Optional: for context
  }
  print("Processing completed for file",data['filename'])
  i+=1
  # Add the dictionary to our list
  results_list.append(individual_results)

print("Processing complete. Storing results in DataFrame.")

  # print(f"Individual BLEU Score: {bleu_score:.4f}")


Processing completed for file Real_SlovakIDCard_001.png
Processing completed for file Real_SlovakIDCard_002.png
Processing completed for file Real_SlovakIDCard_003.png
Processing completed for file Real_SlovakIDCard_004.png
Processing completed for file Real_SlovakIDCard_005.png
Processing completed for file Real_SlovakIDCard_006.png
Processing completed for file Real_SlovakIDCard_007.png
Processing completed for file Real_SlovakIDCard_008.png
Processing completed for file Real_SlovakIDCard_009.png
Processing completed for file Real_SlovakIDCard_010.png
Processing completed for file Real_SlovakIDCard_011.png
Processing completed for file Real_SlovakIDCard_012.png
Processing completed for file Real_SlovakIDCard_013.png
Processing completed for file Real_SlovakIDCard_014.png
Processing completed for file Real_SlovakIDCard_015.png
Processing completed for file Real_SlovakIDCard_016.png
Processing completed for file Real_SlovakIDCard_017.png
Processing completed for file Real_SlovakIDCard_

In [44]:

# --- 4. Calculate and Display Average Scores ---

num_items = len(val_dataset)
import pandas as pd

# Create a DataFrame from the list of dictionaries
df = pd.DataFrame(results_list)
df

Unnamed: 0,item_id,filename,rouge1_f,rouge2_f,rougeL_f,bleu,llm_output
0,1,Real_SlovakIDCard_001.png,1.0,0.714286,0.750000,0.382603,"Authenticity: Real, DocumentType: SlovakIDCard..."
1,2,Real_SlovakIDCard_002.png,1.0,0.714286,0.750000,0.382603,"Authenticity: Real, DocumentType: SlovakIDCard..."
2,3,Real_SlovakIDCard_003.png,1.0,0.714286,0.750000,0.382603,"Authenticity: Real, DocumentType: SlovakIDCard..."
3,4,Real_SlovakIDCard_004.png,1.0,0.714286,0.750000,0.382603,"Authenticity: Real, DocumentType: SlovakIDCard..."
4,5,Real_SlovakIDCard_005.png,1.0,0.714286,0.750000,0.382603,"Authenticity: Real, DocumentType: SlovakIDCard..."
...,...,...,...,...,...,...,...
995,996,Fake_SlovakIDCard_496.png,1.0,0.818182,0.833333,0.648412,"Authenticity: Fake, DocumentType: SlovakIDCard..."
996,997,Fake_SlovakIDCard_497.png,1.0,0.818182,0.833333,0.648412,"Authenticity: Fake, DocumentType: SlovakIDCard..."
997,998,Fake_SlovakIDCard_498.png,1.0,0.818182,0.833333,0.648412,"Authenticity: Fake, DocumentType: SlovakIDCard..."
998,999,Fake_SlovakIDCard_499.png,1.0,0.818182,0.833333,0.648412,"Authenticity: Fake, DocumentType: SlovakIDCard..."


In [45]:
df.to_csv("IDCardDataset-GemmaFT-Results.csv", index=False)

In [46]:
df.describe()

Unnamed: 0,item_id,rouge1_f,rouge2_f,rougeL_f,bleu
count,1000.0,1000.0,1000.0,1000.0,1000.0
mean,500.5,0.850639,0.556905,0.663428,0.332782
std,288.819436,0.229038,0.293344,0.1752,0.226833
min,1.0,0.352941,0.111111,0.235294,0.020256
25%,250.75,0.5,0.111111,0.4,0.022702
50%,500.5,1.0,0.714286,0.75,0.382603
75%,750.25,1.0,0.714286,0.75,0.382603
max,1000.0,1.0,0.818182,0.833333,0.648412
