In [1]:
!pip install -U \
  "numpy==1.26.4" \
  "pandas==2.2.2" \
  "pyarrow==15.0.2" \
  "datasets==2.19.1" \
  "transformers==4.57.1" \
  "timm>=0.9.12" \
  "evaluate==0.4.2" \
  "accelerate>=0.34.2" \
  "scikit-learn>=1.3" \
  "matplotlib>=3.8" \
  "pillow>=10.3" \
  "protobuf<5"




In [None]:
import os, json, random, math, time
from typing import Dict
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn

from datasets import load_dataset, Image, ClassLabel
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DefaultDataCollator,
    set_seed,
)
import evaluate
from sklearn.metrics import f1_score, confusion_matrix, classification_report


SEED = 42
set_seed(SEED)
torch.backends.cudnn.benchmark = True

OUTPUT_DIR = "outputs/dit_base_rvlcdip_3cls"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model & training hparams
MODEL_NAME   = "microsoft/dit-base"
IMAGE_SIZE   = 224
TRAIN_BS     = 16
EVAL_BS      = 32
GRAD_ACCUM   = 2
EPOCHS       = 12
LR           = 5e-5
WEIGHT_DECAY = 0.05
WARMUP_RATIO = 0.10
PATIENCE     = 3
USE_FP16     = True


In [None]:
import warnings, io
from PIL import Image as PILImage
from datasets import load_dataset, Image, ClassLabel


warnings.filterwarnings("ignore", message="Corrupt EXIF data")
warnings.filterwarnings("ignore", message="cannot identify image file")


ds = load_dataset("chainyo/rvl-cdip")

ds = ds.cast_column("image", Image(decode=False))

# Keep only 3 target classes
target_names = ["email", "invoice", "scientific publication"]  # new ids will be 0,1,2 in this order
full_names = ds["train"].features["label"].names
name2id_full = {n: i for i, n in enumerate(full_names)}
keep_ids_full = [name2id_full[n] for n in target_names]

def keep_label_fn(example):
    # Uses only the integer label â€” no image decoding involved
    return example["label"] in keep_ids_full

ds_small = ds.filter(keep_label_fn)


idfull2new = {name2id_full[n]: i for i, n in enumerate(target_names)}
def relabel_fn(example):
    example["label"] = idfull2new[example["label"]]
    return example

ds_small = ds_small.map(relabel_fn)


new_label_feature = ClassLabel(names=target_names)
for split in list(ds_small.keys()):
    ds_small[split] = ds_small[split].cast_column("label", new_label_feature)


def valid_mask_batch(batch):
    masks = []
    images = batch["image"]  
    for im in images:
        try:
            if im.get("bytes") is not None:
                with io.BytesIO(im["bytes"]) as fh:
                    with PILImage.open(fh) as pil:
                        pil.verify()  
            else:
                
                with PILImage.open(im["path"]) as pil:
                    pil.verify()
            masks.append(True)
        except Exception:
            masks.append(False)
    return {"keep": masks}

for split in list(ds_small.keys()):
    ds_small[split] = ds_small[split].map(
        valid_mask_batch,
        batched=True,
        remove_columns=[],
        desc=f"Validating images in {split}",
    ).filter(lambda keep: keep, input_columns=["keep"])


ds_small = ds_small.cast_column("image", Image(decode=True))


available = set(ds_small.keys())
val_key = "validation" if "validation" in available else ("val" if "val" in available else None)
if val_key is None:

    split_res = ds_small["train"].train_test_split(test_size=0.02, seed=42, stratify_by_column="label")
    ds_small["train"], ds_small["validation"] = split_res["train"], split_res["test"]
    val_key = "validation"

print(ds_small)
print(ds_small["train"].features)
for k in ds_small:
    print(k, "rows:", ds_small[k].num_rows)
print("Using validation split key:", val_key)


Downloading readme: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/119 [00:00<?, ?it/s]

Downloading metadata: 0.00B [00:00, ?B/s]

Loading dataset shards:   0%|          | 0/64 [00:00<?, ?it/s]

Filter:   0%|          | 0/319999 [00:00<?, ? examples/s]

Filter:   0%|          | 0/40000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/40000 [00:00<?, ? examples/s]

Map:   0%|          | 0/59803 [00:00<?, ? examples/s]

Map:   0%|          | 0/7565 [00:00<?, ? examples/s]

Map:   0%|          | 0/7632 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/59803 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/7565 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/7632 [00:00<?, ? examples/s]

Validating images in train:   0%|          | 0/59803 [00:00<?, ? examples/s]

Filter:   0%|          | 0/59803 [00:00<?, ? examples/s]

Validating images in test:   0%|          | 0/7565 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7565 [00:00<?, ? examples/s]

Validating images in val:   0%|          | 0/7632 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7632 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'label', 'keep'],
        num_rows: 59803
    })
    test: Dataset({
        features: ['image', 'label', 'keep'],
        num_rows: 7564
    })
    val: Dataset({
        features: ['image', 'label', 'keep'],
        num_rows: 7632
    })
})
{'image': Image(mode=None, decode=True, id=None), 'label': ClassLabel(names=['email', 'invoice', 'scientific publication'], id=None), 'keep': Value(dtype='bool', id=None)}
train rows: 59803
test rows: 7564
val rows: 7632
Using validation split key: val


In [4]:
id2label = {i: n for i, n in enumerate(target_names)}
label2id = {n: i for i, n in id2label.items()}
num_labels = 3


In [5]:
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, use_fast=True)

def _to_rgb(img):
    return img if img.mode == "RGB" else img.convert("RGB")

def train_transform(examples):
    images = [_to_rgb(img) for img in examples["image"]]
    proc = processor(images=images, return_tensors="pt")
    return {"pixel_values": proc["pixel_values"], "labels": examples["label"]}

def test_transform(examples):
    images = [_to_rgb(img) for img in examples["image"]]
    proc = processor(images=images, return_tensors="pt")
    return {"pixel_values": proc["pixel_values"], "labels": examples["label"]}

train_ds = ds_small["train"].with_transform(train_transform)
val_ds   = ds_small[val_key].with_transform(test_transform)
test_ds  = ds_small["test"].with_transform(test_transform)


In [None]:
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)



Some weights of BeitForImageClassification were not initialized from the model checkpoint at microsoft/dit-base and are newly initialized: ['beit.pooler.layernorm.bias', 'beit.pooler.layernorm.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    acc = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]
    f1m = f1_score(labels, preds, average="macro")
    return {"accuracy": acc, "f1_macro": f1m}


Using the latest cached version of the module from /home/jovyan/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--accuracy/f887c0aab52c2d38e1f8a215681126379eca617f96c447638f751434e8e65b14 (last modified on Fri Nov  7 10:22:50 2025) since it couldn't be found locally at evaluate-metric--accuracy, or remotely on the Hugging Face Hub.


In [None]:
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=TRAIN_BS,
    per_device_eval_batch_size=EVAL_BS,
    gradient_accumulation_steps=GRAD_ACCUM,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    fp16=USE_FP16,
    report_to=[],                  
    seed=SEED,
    dataloader_num_workers=0,     
    remove_unused_columns=False,   
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor,
    data_collator=DefaultDataCollator(),
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=PATIENCE)],
)


  trainer = Trainer(


In [9]:
batch = next(iter(trainer.get_train_dataloader()))
print(batch["pixel_values"].shape, batch["labels"].shape)  # expect [16, 3, 224, 224] and [16]


torch.Size([16, 3, 224, 224]) torch.Size([16])


In [None]:
train_output = trainer.train()
train_output


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro
1,0.0638,0.062512,0.985718,0.985751
2,0.0297,0.024861,0.992925,0.992941
3,0.0232,0.024341,0.993973,0.993984
4,0.024,0.019324,0.995283,0.995292
5,0.0145,0.019069,0.995152,0.995164
6,0.0095,0.028647,0.994628,0.994635


In [None]:
# Test metrics
metrics_test = trainer.evaluate(test_ds, metric_key_prefix="test")
print(metrics_test)

# Per-class report + confusion matrix
preds = trainer.predict(test_ds)
y_true = preds.label_ids
y_pred = preds.predictions.argmax(axis=1)

print("\nClassification report (per-class):")
print(classification_report(y_true, y_pred, target_names=[id2label[i] for i in range(num_labels)]))

cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(5,5))
plt.imshow(cm, interpolation='nearest')
plt.title("Confusion Matrix (3 classes)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.colorbar()
plt.tight_layout()
plt.show()

# Save best checkpoint + processor
best_dir = os.path.join(OUTPUT_DIR, "best")
os.makedirs(best_dir, exist_ok=True)
trainer.save_model(best_dir)
processor.save_pretrained(best_dir)
print("Saved to:", best_dir)
