# Swin-Tiny Experiment with 20,000 Samples
- **Purpose**: Compare hierarchical Swin-Tiny against ViT’s global approach for document layouts.
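- **Details**: Swin’s shifted-window attention builds hierarchical features from local windows, which suits structured document layouts. It reached **79.55%** accuracy on the 2,000-image balanced validation subset (best epoch) and on the 2,000-image test subset.
- **Outcome**: Slightly outperforms ViT overall, especially on the finance-related classes (e.g., Invoice per-class accuracy on the test subset: **77.6%**).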

In [None]:
# Step 1: Installing required libraries & Setting up the Environment
!pip install -q transformers datasets torch torchvision accelerate

import torch
from transformers import SwinForImageClassification, AutoImageProcessor , TrainingArguments, Trainer, EarlyStoppingCallback
from datasets import load_dataset, IterableDataset
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from collections import defaultdict
from torchvision import transforms
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"GPU name: {torch.cuda.get_device_name(0)}")
!free -h
!df -h

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m75.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m57.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m44.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Step 2: Loading Dataset with Balanced Streaming and Augmentation
# rvl_cdip is loaded through a repository script; opting in explicitly avoids the
# interactive "Do you wish to run the custom code?" prompt that blocks Run All.
# Only do this for repositories you trust: the script executes locally.
dataset = load_dataset("aharley/rvl_cdip", streaming=True, trust_remote_code=True)
# Integer label -> RVL-CDIP class name, in the dataset's label order.
label_map = {0: "letter", 1: "form", 2: "email", 3: "handwritten", 4: "advertisement",
             5: "scientific report", 6: "scientific publication", 7: "specification",
             8: "file folder", 9: "news article", 10: "budget", 11: "invoice",
             12: "presentation", 13: "questionnaire", 14: "resume", 15: "memo"}
num_labels = len(label_map)
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

class BalancedStreamingDataset(torch.utils.data.IterableDataset):
    """Class-balanced, capped view over a streaming dataset split.

    Walks ``dataset_split`` in its native order and keeps at most
    ``total_samples // num_classes`` examples per label, stopping as soon as
    ``len(self)`` examples have been yielded. Each item is a dict with
    ``pixel_values`` (image-processor output with the leading batch axis
    squeezed away) and ``labels`` (the example's integer label).

    Every ``__iter__`` call restarts the stream with fresh per-class counters,
    so each epoch and each evaluation pass sees the same subset in the same order.

    Subclasses ``torch.utils.data.IterableDataset`` directly. The previous base,
    ``datasets.IterableDataset``, expects its own ``__init__`` to run, and this
    class never called it.

    Args:
        dataset_split: iterable of dicts with ``"image"`` (PIL image) and
            ``"label"`` (int) keys, e.g. ``dataset["train"]``.
        total_samples: maximum number of examples to yield per pass.
        num_classes: number of labels used to derive the per-class quota.
        augment: apply random rotation/translation/colour jitter. Enable for
            training only; evaluation splits must be deterministic.
        image_processor: callable turning a PIL image into ``pixel_values``.
            Defaults to the module-level ``processor``.
    """

    def __init__(self, dataset_split, total_samples, num_classes=16,
                 augment=True, image_processor=None):
        super().__init__()
        self.dataset = dataset_split
        self.total_samples = total_samples
        self.num_classes = num_classes
        self.target_per_class = total_samples // num_classes  # 1250 for 20,000 / 16
        # Stored for Trainer compatibility; the stream order itself is fixed.
        self._epoch = 0
        self.image_processor = image_processor if image_processor is not None else processor
        self.augment = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]) if augment else None

    def __iter__(self):
        limit = len(self)
        if limit <= 0:
            # A zero quota would otherwise read the entire stream and yield nothing.
            return
        class_counts = defaultdict(int)
        samples_yielded = 0
        for example in self.dataset:
            label = example["label"]
            if class_counts[label] >= self.target_per_class:
                continue  # quota for this class already filled
            class_counts[label] += 1
            image = example["image"].convert("RGB")
            if self.augment is not None:
                image = self.augment(image)
            inputs = self.image_processor(images=image, return_tensors="pt")
            yield {
                "pixel_values": inputs["pixel_values"].squeeze(0),
                "labels": label
            }
            samples_yielded += 1
            if samples_yielded >= limit:
                break

    def __len__(self):
        # Per-class quotas cap the total, so when total_samples is not divisible
        # by num_classes the reachable count is target_per_class * num_classes.
        return min(self.total_samples, self.target_per_class * self.num_classes)

    def set_epoch(self, epoch: int):
        self._epoch = epoch

train_size = 20000
val_size = 2000
test_size = 2000
# Augment only the training split. Validation/test inputs must be deterministic
# so metrics are reproducible and repeated passes see identical images.
train_dataset = BalancedStreamingDataset(dataset["train"], train_size, num_classes=num_labels, augment=True)
val_dataset = BalancedStreamingDataset(dataset["validation"], val_size, num_classes=num_labels, augment=False)
test_dataset = BalancedStreamingDataset(dataset["test"], test_size, num_classes=num_labels, augment=False)
print(f"Training size: {train_size}, Validation size: {val_size}, Test size: {test_size}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/6.15k [00:00<?, ?B/s]

rvl_cdip.py:   0%|          | 0.00/4.80k [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/2.64k [00:00<?, ?B/s]

The repository for aharley/rvl_cdip contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/aharley/rvl_cdip.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


preprocessor_config.json:   0%|          | 0.00/255 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Training size: 20000, Validation size: 2000, Test size: 2000


In [None]:
# Step 3: Load Pre-trained Swin Model
# The ImageNet checkpoint has a 1000-way head. ignore_mismatched_sizes swaps it
# for a freshly initialised num_labels-way head. id2label/label2id store the
# class names in the config, so the saved model reports "invoice" rather than "LABEL_11".
model = SwinForImageClassification.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",
    num_labels=num_labels,
    id2label=label_map,
    label2id={name: idx for idx, name in label_map.items()},
    ignore_mismatched_sizes=True,
)
model.to(device)
# torch.cuda.memory_allocated() fails on a CPU-only runtime.
if torch.cuda.is_available():
    print(f"GPU memory allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB")

config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/113M [00:00<?, ?B/s]

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([16]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([16, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


GPU memory allocated: 106.13 MB


In [None]:
# Step 4: Define Metrics
def compute_metrics(eval_pred):
    """Compute accuracy and support-weighted precision/recall/F1 for the Trainer.

    Args:
        eval_pred: ``(logits, labels)`` pair (the Trainer passes an
            ``EvalPrediction``). Logits are argmax-ed over the last axis.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1``. The Trainer
        prefixes these keys (e.g. ``eval_accuracy``).
    """
    logits, labels = eval_pred
    # Some model outputs arrive as a tuple (logits, extra outputs...); keep the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    accuracy = accuracy_score(labels, predictions)
    # zero_division=0 gives a never-predicted class 0 precision, which matches
    # sklearn's default value, but without emitting UndefinedMetricWarning on every eval.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average="weighted", zero_division=0
    )
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

In [None]:
# Step 5: Setting Up The Training Arguments
# run_name labels the experiment-tracker run (wandb, per the training log). Add
# report_to="none" for unattended runs that cannot answer wandb's login prompt.
training_args = TrainingArguments(
    output_dir="./rvl_cdip_swin",
    run_name="rvl_cdip_swin_20000",
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    save_total_limit=2,  # cap per-epoch checkpoints on disk; the best one is retained
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=8,  # effective batch size 8 x 8 = 64
    num_train_epochs=7,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir="./logs",
    logging_steps=50,
    # Mixed precision requires CUDA; fall back to fp32 on CPU instead of erroring.
    fp16=torch.cuda.is_available(),
)

In [None]:
# Step 6: Train the Model with Early Stopping
# Training stops once metric_for_best_model fails to improve for 2 consecutive evaluations.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
print("Starting training...")
trainer.train()
# torch.cuda.memory_allocated() fails on a CPU-only runtime.
if torch.cuda.is_available():
    print(f"GPU memory allocated post-training: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB")

Starting training...


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mgavhaneprasad14092001[0m ([33mgavhaneprasad14092001-indian-school-of-mines[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.135,1.145428,0.657,0.692723,0.657,0.656576
2,0.8994,0.914054,0.7245,0.739782,0.7245,0.723739
3,0.7527,0.851549,0.745,0.760701,0.745,0.746
4,0.6534,0.754308,0.779,0.787363,0.779,0.779905
5,0.6241,0.706114,0.7955,0.801254,0.7955,0.796672
6,0.5303,0.678224,0.7945,0.800946,0.7945,0.796291


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: a728b165-7e03-411e-87ba-08c713a15961)')' thrown while requesting GET https://huggingface.co/datasets/rvl_cdip/resolve/main/data/rvl-cdip.tar.gz
Retrying in 1s [Retry 1/5].
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 54b1c4bf-4329-407a-9a1c-2ffb852f320a)')' thrown while requesting GET https://huggingface.co/datasets/rvl_cdip/resolve/main/data/rvl-cdip.tar.gz
Retrying in 1s [Retry 1/5].
'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: dcc05993-80bc-4d19-8cc3-fde32d88b3f4)')' thrown while requesting GET https://huggingface.co/datasets/rvl_cdip/resolve/main/data/rvl-cdip.tar.gz
Retrying in 1s [Retry 1/5].
'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(

GPU memory allocated post-training: 334.84 MB


In [None]:
# Step 7: Evaluate on Test Set
# One predict() pass yields both the aggregate metrics and the per-example
# predictions. The test stream is then downloaded once, and both views come from
# the same inputs. Unlike evaluate(), predict() does not fire the on_evaluate
# callbacks (e.g. early stopping), so test data cannot leak into them.
# test_dataset from Step 2 restarts its stream on every iteration, so it is reused as-is.
predictions = trainer.predict(test_dataset, metric_key_prefix="test")
test_results = predictions.metrics
print("Test Results:", test_results)

preds = np.argmax(predictions.predictions, axis=-1)
labels = predictions.label_ids
financial_classes = {1: "form (proxy for tax forms)",
                     10: "budget (proxy for financial reports)",
                     11: "invoice (direct match)",
                     15: "memo (proxy for financial reports)"}
# Per-class accuracy: the fraction of that class's test samples predicted correctly (its recall).
for cls, description in financial_classes.items():
    mask = labels == cls
    n_samples = int(mask.sum())
    if n_samples == 0:
        print(f"Accuracy for {description}: n/a (no test samples)")
        continue
    acc = float((preds[mask] == cls).mean())
    print(f"Accuracy for {description}: {acc:.4f} (n={n_samples})")

'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 54c6a8e7-7264-48d8-91bd-72d6bb782a5e)')' thrown while requesting GET https://huggingface.co/datasets/rvl_cdip/resolve/main/data/rvl-cdip.tar.gz
Retrying in 1s [Retry 1/5].


Test Results: {'eval_loss': 0.7023600935935974, 'eval_accuracy': 0.7955, 'eval_precision': 0.8024634006386182, 'eval_recall': 0.7955, 'eval_f1': 0.7970911594085774, 'eval_runtime': 228.2623, 'eval_samples_per_second': 8.762, 'eval_steps_per_second': 1.095, 'epoch': 6.9792}
Accuracy for form (proxy for tax forms): 0.6800
Accuracy for budget (proxy for financial reports): 0.6960
Accuracy for invoice (direct match): 0.7760
Accuracy for memo (proxy for financial reports): 0.7600


In [None]:
# Step 8: Save the Model
model.save_pretrained("./rvl_cdip_swin_model")
processor.save_pretrained("./rvl_cdip_swin_model")
!du -sh ./rvl_cdip_swin_model
!df -h

# Optional: Save to Google Drive (Colab)
from google.colab import drive
drive.mount('/content/drive')
!cp -r ./rvl_cdip_swin_model /content/drive/MyDrive/rvl_cdip_swin_model

106M	./rvl_cdip_swin_model
Filesystem      Size  Used Avail Use% Mounted on
overlay         113G   42G   72G  37% /
tmpfs            64M     0   64M   0% /dev
shm             5.7G   16K  5.7G   1% /dev/shm
/dev/root       2.0G  1.2G  820M  59% /usr/sbin/docker-init
/dev/sda1        92G   72G   20G  79% /kaggle/input
tmpfs           6.4G  292K  6.4G   1% /var/colab
tmpfs           6.4G     0  6.4G   0% /proc/acpi
tmpfs           6.4G     0  6.4G   0% /proc/scsi
tmpfs           6.4G     0  6.4G   0% /sys/firmware
Mounted at /content/drive
