In [1]:
import os
import numpy as np

import torch
import pandas as pd
from torchinfo import summary
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
)

from torch.utils.data import Subset
from torchvision.datasets import ImageFolder
from transformers import (
    SiglipForImageClassification,
    TrainingArguments,
    Trainer,
    AutoImageProcessor,
    EarlyStoppingCallback,
)

from src.transformers import train_transforms, val_transforms, test_transforms
from src.callbacks import CHECKPOINT_DIR

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

print("CUDA Available:", torch.cuda.is_available())
print("CUDA Version:", torch.version.cuda)

CUDA Available: True
CUDA Version: 12.6


In [2]:
checkpoint = "google/siglip2-base-patch16-224"

model = SiglipForImageClassification.from_pretrained(
    checkpoint,
    num_labels=2,
)
processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)

summary(model, input_size=(1, 3, 224, 224))

Some weights of SiglipForImageClassification were not initialized from the model checkpoint at google/siglip2-base-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Layer (type:depth-idx)                                  Output Shape              Param #
SiglipForImageClassification                            [1, 2]                    --
├─SiglipVisionTransformer: 1-1                          [1, 768]                  --
│    └─SiglipVisionEmbeddings: 2-1                      [1, 196, 768]             --
│    │    └─Conv2d: 3-1                                 [1, 768, 14, 14]          590,592
│    │    └─Embedding: 3-2                              [1, 196, 768]             150,528
│    └─SiglipEncoder: 2-2                               [1, 196, 768]             --
│    │    └─ModuleList: 3-3                             --                        85,054,464
│    └─LayerNorm: 2-3                                   [1, 196, 768]             1,536
│    └─SiglipMultiheadAttentionPoolingHead: 2-4         [1, 768]                  768
│    │    └─MultiheadAttention: 3-4                     [1, 1, 768]               2,362,368
│    │    └─LayerNorm: 3-5     

In [3]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(
        pred.predictions[0]
        if isinstance(pred.predictions, tuple)
        else pred.predictions,
        axis=-1,
    )

    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="weighted")
    recall = recall_score(labels, preds, average="weighted")
    f1 = f1_score(labels, preds, average="weighted")

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


def collate_fn(batch):
    images, labels = zip(*batch)
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(labels)}


def evaluate_model(trainer, limit=None):
    test_ds = ImageFolder(
        os.path.join("datasets", "rest_test"), transform=test_transforms
    )
    if limit is not None:
        test_ds = Subset(test_ds, range(limit))
    score_rest = trainer.evaluate(test_ds)

    test_ds = ImageFolder(
        os.path.join("datasets", "wit_test"), transform=test_transforms
    )
    if limit is not None:
        test_ds = Subset(test_ds, range(limit))
    score_wit = trainer.evaluate(test_ds)

    return pd.DataFrame([score_rest, score_wit], index=["rest", "wit"])

# Other datasets

In [4]:
TARGET_DIR = os.path.join(CHECKPOINT_DIR, "siglip2", "other")
os.makedirs(TARGET_DIR, exist_ok=True)

train_ds = ImageFolder(
    os.path.join("datasets", "rest_train"), transform=train_transforms
)
# train_ds = Subset(train_ds, range(100))

val_ds = ImageFolder(os.path.join("datasets", "rest_val"), transform=val_transforms)
# val_ds = Subset(val_ds, range(100))

In [5]:
training_args = TrainingArguments(
    output_dir=TARGET_DIR,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=1000,
    learning_rate=5e-4,
    weight_decay=0.01,
    fp16=True,
    logging_dir=os.path.join(TARGET_DIR, "logs"),
    logging_steps=100,
    logging_first_step=True,
    warmup_steps=500,
    load_best_model_at_end=True,
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=2,
    # metric_for_best_model="f1",
    # greater_is_better=True,
    save_total_limit=3,
    report_to=["tensorboard"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    data_collator=collate_fn,
)

try:
    trainer.train()
except KeyboardInterrupt:
    print("Training interrupted. Saving the model...")
finally:
    model.save_pretrained(os.path.join(TARGET_DIR, "model"))
    processor.save_pretrained(os.path.join(TARGET_DIR, "processor"))

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.3599,0.292432,0.885503,0.891092,0.885503,0.8858
2,0.3362,0.388499,0.823901,0.851418,0.823901,0.816936
3,0.3599,0.366436,0.833914,0.837699,0.833914,0.832022
4,0.3813,0.40682,0.815629,0.822322,0.815629,0.816095
5,0.3826,0.404702,0.819983,0.823476,0.819983,0.820429
6,0.3886,0.403304,0.810623,0.814234,0.810623,0.80827


In [6]:
evaluate_model(trainer=trainer).T

Unnamed: 0,rest,wit
eval_loss,0.289871,0.31235
eval_accuracy,0.885963,0.8787
eval_precision,0.892775,0.879517
eval_recall,0.885963,0.8787
eval_f1,0.886357,0.874884
eval_runtime,459.8264,171.7863
eval_samples_per_second,9.993,58.212
eval_steps_per_second,0.313,1.822
epoch,6.0,6.0


# Our dataset

In [7]:
TARGET_DIR = os.path.join(CHECKPOINT_DIR, "siglip2", "wit")
os.makedirs(TARGET_DIR, exist_ok=True)

train_ds = ImageFolder(
    os.path.join("datasets", "wit_train"), transform=train_transforms
)
# train_ds = Subset(train_ds, range(100))
val_ds = ImageFolder(os.path.join("datasets", "wit_val"), transform=val_transforms)
# val_ds = Subset(val_ds, range(100))

In [8]:
training_args = TrainingArguments(
    output_dir=TARGET_DIR,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=1000,
    learning_rate=5e-4,
    weight_decay=0.01,
    fp16=True,
    logging_dir=os.path.join(TARGET_DIR, "logs"),
    logging_steps=100,
    warmup_steps=500,
    logging_first_step=True,
    load_best_model_at_end=True,
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=2,
    # metric_for_best_model="f1",
    # greater_is_better=True,
    save_total_limit=3,
    report_to=["tensorboard"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    data_collator=collate_fn,
)

try:
    trainer.train()
except KeyboardInterrupt:
    print("Training interrupted. Saving the model...")
finally:
    model.save_pretrained(os.path.join(TARGET_DIR, "model"))
    processor.save_pretrained(os.path.join(TARGET_DIR, "processor"))

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.2045,0.195957,0.9188,0.921943,0.9188,0.919573
2,0.2451,0.232362,0.9045,0.906685,0.9045,0.905179
3,0.2816,0.321944,0.8707,0.875127,0.8707,0.86523
4,0.2556,0.275362,0.885,0.889506,0.885,0.886241


In [9]:
evaluate_model(trainer=trainer).T

Unnamed: 0,rest,wit
eval_loss,0.483702,0.188542
eval_accuracy,0.835909,0.9243
eval_precision,0.840014,0.927249
eval_recall,0.835909,0.9243
eval_f1,0.833695,0.925043
eval_runtime,80.3859,171.1234
eval_samples_per_second,57.162,58.437
eval_steps_per_second,1.791,1.829
epoch,4.0,4.0
