# Fine-tuning a Vision Transformer for Document Validation (Valid/Faked Classification)

This notebook demonstrates how to fine-tune an open-source Vision Transformer (ViT) model to classify document images as `valid` or `faked` based on their visual features. You will need a dataset of document images, each labeled as either `valid` or `faked`.


In [2]:
# 2. Import Libraries
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
from sklearn.model_selection import train_test_split
from transformers import ViTFeatureExtractor, ViTForImageClassification
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix


  from .autonotebook import tqdm as notebook_tqdm


## 3. Dataset Preparation

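You need a dataset of document images organized into one subfolder per class (`dataset/valid/` and `dataset/faked/`). The subfolder names must match the keys of `LABELS` in the code below. Only files with a `.jpg`, `.jpeg`, or `.png` extension are picked up; other files are ignored.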

Example directory structure:
```
dataset/
├── valid/
│   ├── doc1.jpg
│   ├── doc2.png
│   └── ...
└── faked/
    ├── forged1.jpg
    ├── tampered2.png
    └── ...
```

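Define your dataset path and create a custom `Dataset` class. The cell below then gathers the image paths and splits them 80/20 into stratified training and validation sets. Finally, it wraps each split in a `DataLoader`.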


In [None]:
# Define your dataset root directory
DATA_DIR = "./custom_document_dataset"  # **Update this path to your dataset**

# Class name -> integer id. The subfolder names under DATA_DIR must match these keys.
LABELS = {"valid": 0, "faked": 1}
# Both derived from LABELS so the label set is defined in exactly one place
# and the id -> name mapping can never disagree with the name -> id mapping.
ID2LABEL = {label_id: label_name for label_name, label_id in LABELS.items()}
NUM_CLASSES = len(LABELS)  # 'valid', 'faked'

class DocumentDataset(Dataset):
    """Map-style dataset that loads document images and returns model-ready tensors.

    Args:
        image_paths: Paths of the image files to load.
        labels: Integer class ids, aligned index-for-index with ``image_paths``.
        feature_extractor: Image preprocessor, called as
            ``feature_extractor(images=<PIL image>, return_tensors="pt")``;
            its ``.pixel_values`` attribute is used as the model input.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, feature_extractor):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) must have the same length."
            )
        self.image_paths = image_paths
        self.labels = labels
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor, "labels": int64 scalar tensor}`` for item ``idx``."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Close the file handle immediately instead of leaking it until garbage collection;
        # convert() produces a fully loaded copy that remains valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values

        # The extractor is given a single image, so its output presumably has a leading
        # batch axis of size 1; drop only that axis (a bare squeeze() would drop any size-1 axis).
        return {
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }

# File extensions accepted as document images (compared case-insensitively, so ".JPG" counts too).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
SEED = 42  # Controls both the train/validation split and the training shuffle order
BATCH_SIZE = 16

# Gather all image paths and labels.
# Directory listings are sorted because os.listdir returns entries in arbitrary order,
# which would make the seeded split below differ between machines/filesystems.
all_image_paths = []
all_labels = []
images_per_class = {}

for label_name, label_id in LABELS.items():
    class_dir = os.path.join(DATA_DIR, label_name)
    class_paths = []
    if os.path.isdir(class_dir):
        class_paths = [
            os.path.join(class_dir, img_name)
            for img_name in sorted(os.listdir(class_dir))
            if img_name.lower().endswith(IMAGE_EXTENSIONS)
        ]
    images_per_class[label_name] = len(class_paths)
    all_image_paths.extend(class_paths)
    all_labels.extend([label_id] * len(class_paths))

if not all_image_paths:
    raise ValueError(f"No images found in {DATA_DIR}. Please check your dataset path and structure.")

# A class with no images would otherwise be skipped silently, producing a model that
# has only ever seen one class.
empty_classes = [name for name, count in images_per_class.items() if count == 0]
if empty_classes:
    raise ValueError(
        f"No images found for class(es) {empty_classes} under {DATA_DIR}. "
        f"Each of {list(LABELS)} needs its own subfolder of images."
    )

# Split dataset into training and validation sets (stratified so both keep the class balance)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    all_image_paths, all_labels, test_size=0.2, random_state=SEED, stratify=all_labels
)

print(f"Found {len(train_paths)} training images and {len(val_paths)} validation images.")

# Initialize feature extractor (e.g., for ViT)
# NOTE(review): newer transformers releases deprecate ViTFeatureExtractor in favour of
# ViTImageProcessor — consider switching when upgrading.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

train_dataset = DocumentDataset(train_paths, train_labels, feature_extractor)
val_dataset = DocumentDataset(val_paths, val_labels, feature_extractor)

# Create DataLoaders. A dedicated seeded generator makes the shuffle order reproducible
# without depending on whatever the global torch RNG state happens to be.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


## 4. Load Pre-trained Vision Transformer Model


In [None]:
# Seed before loading: the replacement classification head is freshly (randomly)
# initialised, so seeding makes that initialisation reproducible. 42 matches the split seed.
torch.manual_seed(42)

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",  # You can choose other ViT models
    num_labels=NUM_CLASSES,
    id2label=ID2LABEL,
    label2id=LABELS,
    # This checkpoint ships a 1000-class ImageNet head. Swapping it for a NUM_CLASSES-way
    # head is a weight-shape mismatch that from_pretrained refuses unless explicitly allowed;
    # the mismatched head is then re-initialised and learned during fine-tuning.
    ignore_mismatched_sizes=True,
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"Model loaded and moved to {device}.")


## 5. Fine-tuning the Model


In [None]:
# AdamW over every model parameter; 5e-5 is the learning rate used for fine-tuning.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_epochs = 3  # Adjust as needed


def train_one_epoch(net, loader, opt, dev):
    """Run one optimisation pass over `loader`, updating `net` in place.

    Returns the mean of the per-batch losses reported by the model.
    """
    running_loss = 0
    for step_batch in tqdm(loader, desc="Training"):
        inputs = step_batch["pixel_values"].to(dev)
        targets = step_batch["labels"].to(dev)

        opt.zero_grad()
        # Passing labels makes the model compute the classification loss itself.
        batch_loss = net(pixel_values=inputs, labels=targets).loss
        running_loss += batch_loss.item()
        batch_loss.backward()
        opt.step()
    return running_loss / len(loader)


model.train()
for epoch_number in range(1, num_epochs + 1):
    print(f"Epoch {epoch_number}/{num_epochs}")
    epoch_loss = train_one_epoch(model, train_dataloader, optimizer, device)
    print(f"Average training loss: {epoch_loss:.4f}")

print("Fine-tuning complete!")


## 6. Evaluation


In [None]:
model.eval()
all_preds = []
# Deliberately not named `all_labels`: that global already holds the full dataset's labels
# gathered in the data-preparation cell, and overwriting it would leave stale hidden state.
val_true_labels = []

with torch.no_grad():
    for batch in tqdm(val_dataloader, desc="Evaluating"):
        pixel_values = batch["pixel_values"].to(device)
        logits = model(pixel_values=pixel_values).logits
        predictions = torch.argmax(logits, dim=-1)

        all_preds.extend(predictions.cpu().numpy())
        # Ground truth is only consumed by sklearn on the CPU, so no device round-trip.
        val_true_labels.extend(batch["labels"].numpy())

# Pin the class order explicitly so names line up with ids, and so the report does not
# fail or misalign when a class is absent from the predictions or the validation split.
class_ids = list(range(NUM_CLASSES))
class_names = [ID2LABEL[i] for i in class_ids]

print("\nClassification Report:")
print(
    classification_report(
        val_true_labels, all_preds, labels=class_ids, target_names=class_names, zero_division=0
    )
)

print(f"\nConfusion Matrix (rows = true, columns = predicted; order: {class_names}):")
print(confusion_matrix(val_true_labels, all_preds, labels=class_ids))


## 7. Save the Fine-tuned Model


In [None]:
# Output directory for the fine-tuned weights, config (including the label maps) and preprocessor.
SAVE_DIR = "./fine_tuned_vit_classifier"
os.makedirs(SAVE_DIR, exist_ok=True)

# The model and the feature extractor are written side by side so they can be reloaded together.
for artifact in (model, feature_extractor):
    artifact.save_pretrained(SAVE_DIR)

print(f"Fine-tuned model and feature extractor saved to {SAVE_DIR}")

# To load the model later:
# loaded_feature_extractor = ViTFeatureExtractor.from_pretrained(SAVE_DIR)
# loaded_model = ViTForImageClassification.from_pretrained(SAVE_DIR)


## Next Steps:

1.  **Prepare your dataset:** Create a directory structure as described in Section 3 with 'valid' and 'faked' document images.
2.  **Update `DATA_DIR`:** Modify the `DATA_DIR` variable in Section 3 to point to your dataset.
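3.  **Run all cells:** Restart the kernel and execute every cell in order from top to bottom. Later cells depend on variables defined in earlier ones, such as `model`, the DataLoaders, and `device`.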
4.  **Integrate:** Once you have a fine-tuned model, you can integrate it into your Streamlit `app.py` to classify uploaded documents.
