In [1]:
import os
import random

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_recall_fscore_support,
    precision_score,
    recall_score,
)
from torch.utils.data import Subset
from torchinfo import summary
from torchvision.datasets import ImageFolder
from transformers import (
    AutoImageProcessor,
    EarlyStoppingCallback,
    SiglipForImageClassification,
    Trainer,
    TrainingArguments,
)

from src.callbacks import CHECKPOINT_DIR
from src.transformers import train_transforms, val_transforms, test_transforms

# Reproducibility: seed every RNG before the model's classification head is randomly
# initialised in the next cell (the Trainer's own seeding via TrainingArguments.seed only
# happens later, once a Trainer is built).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

if torch.cuda.is_available():
    # Request deterministic cuDNN kernels. `benchmark` must stay off: cuDNN autotuning
    # selects kernels per run and would undo the determinism requested here.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print("CUDA Available:", torch.cuda.is_available())
print("CUDA Version:", torch.version.cuda)

CUDA Available: True
CUDA Version: 12.6


In [None]:
checkpoint = "google/siglip2-base-patch16-224"

# SigLIP2 vision tower with a 2-way classification head. Loading reports that
# `classifier.weight` / `classifier.bias` are newly initialised, i.e. only the backbone
# comes from the checkpoint.
model = SiglipForImageClassification.from_pretrained(checkpoint, num_labels=2)
processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)

# Layer-by-layer parameter summary for one 3-channel 224x224 image.
summary(model, input_size=(1, 3, 224, 224))

Some weights of SiglipForImageClassification were not initialized from the model checkpoint at google/siglip2-base-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Layer (type:depth-idx)                                  Output Shape              Param #
SiglipForImageClassification                            [1, 2]                    --
├─SiglipVisionTransformer: 1-1                          [1, 768]                  --
│    └─SiglipVisionEmbeddings: 2-1                      [1, 196, 768]             --
│    │    └─Conv2d: 3-1                                 [1, 768, 14, 14]          590,592
│    │    └─Embedding: 3-2                              [1, 196, 768]             150,528
│    └─SiglipEncoder: 2-2                               [1, 196, 768]             --
│    │    └─ModuleList: 3-3                             --                        85,054,464
│    └─LayerNorm: 2-3                                   [1, 196, 768]             1,536
│    └─SiglipMultiheadAttentionPoolingHead: 2-4         [1, 768]                  768
│    │    └─MultiheadAttention: 3-4                     [1, 1, 768]               2,362,368
│    │    └─LayerNorm: 3-5     

In [3]:
def compute_metrics(pred, average="weighted"):
    """Compute classification metrics from a Trainer evaluation result.

    Args:
        pred: object exposing ``label_ids`` and ``predictions``; ``predictions`` is either
            the logits array or a tuple whose first element is the logits.
        average: sklearn averaging mode for precision / recall / F1. With ``"weighted"``
            (the default), recall is mathematically identical to accuracy, so it adds no
            information; ``"macro"`` or ``"binary"`` is more informative per class.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
    labels = pred.label_ids
    preds = np.argmax(logits, axis=-1)

    # One pass for P/R/F1. zero_division=0 yields the same numbers as sklearn's default
    # ("warn" also scores 0) but avoids an UndefinedMetricWarning on every evaluation in
    # which some class is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average=average, zero_division=0
    )

    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

def collate_fn(batch):
    """Collate ``(image_tensor, label)`` pairs into the keyword dict the model expects.

    Images are stacked along a new leading batch dimension; labels become one tensor.
    """
    pixel_values, targets = zip(*batch)
    return dict(
        pixel_values=torch.stack(pixel_values),
        labels=torch.tensor(targets),
    )

def evaluate_model(trainer, limit=None, splits=("rest", "wit"), data_root="datasets"):
    """Evaluate ``trainer``'s model on each held-out test split.

    Args:
        trainer: a ``transformers.Trainer``; ``trainer.evaluate`` is called once per split.
        limit: if given, evaluate on at most ``limit`` samples per split. Samples are picked
            at an evenly spaced stride over the whole split rather than as a prefix, because
            ``ImageFolder`` lists samples class by class and a prefix would only cover the
            first class(es). The selection is deterministic.
        splits: split name prefixes; ``<data_root>/<split>_test`` is loaded for each one.
        data_root: directory that contains the split folders.

    Returns:
        DataFrame with one row per split (indexed by split name) and one column per metric
        returned by ``trainer.evaluate``.

    Raises:
        ValueError: if ``limit`` is given but smaller than 1.
    """
    if limit is not None and limit < 1:
        raise ValueError(f"limit must be a positive integer or None, got {limit!r}")

    scores = []
    for split in splits:
        test_ds = ImageFolder(os.path.join(data_root, f"{split}_test"), transform=test_transforms)
        if limit is not None and limit < len(test_ds):
            step = len(test_ds) // limit  # >= 1 because limit < len(test_ds)
            test_ds = Subset(test_ds, range(0, len(test_ds), step)[:limit])
        scores.append(trainer.evaluate(test_ds))

    return pd.DataFrame(scores, index=list(splits))

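# Other datasets

Fine-tune on the external `rest_train` / `rest_val` splits, then evaluate on both the `rest` and `wit` test sets.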

In [4]:
TARGET_DIR = os.path.join(CHECKPOINT_DIR, "siglip2", "other")
os.makedirs(TARGET_DIR, exist_ok=True)

# External ("rest") training and validation splits, each with its own transform pipeline.
# Smoke test: wrap each split as `Subset(ds, range(0, len(ds), max(1, len(ds) // 100)))` to
# get ~100 samples spread over all classes. Avoid `range(100)`: ImageFolder lists samples
# class by class, so a prefix only covers the first class(es).
train_ds, val_ds = (
    ImageFolder(os.path.join("datasets", f"rest_{split}"), transform=tfm)
    for split, tfm in (("train", train_transforms), ("val", val_transforms))
)

In [5]:
# Trains until early stopping fires; num_train_epochs is only an upper bound.
# NOTE(review): lr_scheduler_type is not set, so the Trainer's default linear decay is spread
# over all 1000 epochs and the LR stays ~constant after warmup -- confirm that's intended.
# metric_for_best_model is not set either, so best-checkpoint selection and early stopping
# follow the default metric (eval loss); set metric_for_best_model="f1" and
# greater_is_better=True to select on F1 instead.
training_args = TrainingArguments(
    output_dir=TARGET_DIR,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=1000,
    learning_rate=5e-5,
    weight_decay=0.01,
    # fp16 mixed precision is only supported on an accelerator; fall back to fp32 on CPU.
    fp16=torch.cuda.is_available(),
    logging_dir=os.path.join(TARGET_DIR, "logs_other"),
    logging_steps=100,
    logging_first_step=True,
    warmup_steps=500,
    load_best_model_at_end=True,
    save_total_limit=3,
    report_to=["tensorboard"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    data_collator=collate_fn,
)

try:
    trainer.train()
except KeyboardInterrupt:
    # An interrupt skips load_best_model_at_end, so the weights saved below are the latest
    # ones, not necessarily the best; report where the best checkpoint (if any) lives.
    print("Training interrupted. Saving the model...")
    print("Best checkpoint so far:", trainer.state.best_model_checkpoint)

# Export only after normal completion or a manual interrupt. Any other exception propagates
# before this point, so it cannot overwrite an earlier export with half-trained weights.
# NOTE(review): the saved processor is the checkpoint's default; training used
# train_transforms / val_transforms -- confirm inference preprocessing matches.
model.save_pretrained(os.path.join(TARGET_DIR, "model_other"))
processor.save_pretrained(os.path.join(TARGET_DIR, "processor_other"))

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0807,0.058582,0.987593,0.987727,0.987593,0.98758
2,0.059,0.051118,0.982586,0.982874,0.982586,0.982559
3,0.071,0.059224,0.982804,0.983103,0.982804,0.982777
4,0.0543,0.035338,0.990422,0.990494,0.990422,0.990415
5,0.0568,0.057108,0.986286,0.986498,0.986286,0.986269
6,0.0467,0.041782,0.990858,0.99093,0.990858,0.990851
7,0.0432,0.061002,0.98781,0.987968,0.98781,0.987797
8,0.0244,0.052385,0.983457,0.983791,0.983457,0.98343
9,0.0392,0.033738,0.99064,0.990644,0.99064,0.990641
10,0.0408,0.036805,0.990422,0.990445,0.990422,0.990418


In [6]:
# Metrics as rows, test splits as columns.
evaluate_model(trainer).transpose()

Unnamed: 0,rest,wit
eval_loss,0.024456,1.483998
eval_accuracy,0.993254,0.7284
eval_precision,0.993265,0.792155
eval_recall,0.993254,0.7284
eval_f1,0.993251,0.737035
eval_runtime,79.9253,169.8421
eval_samples_per_second,57.491,58.878
eval_steps_per_second,1.802,1.843
epoch,18.0,18.0


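# Our dataset

Fine-tune on our `wit_train` / `wit_val` splits, continuing from the model trained above, then evaluate on both the `rest` and `wit` test sets.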

In [7]:
TARGET_DIR = os.path.join(CHECKPOINT_DIR, "siglip2", "wit")
os.makedirs(TARGET_DIR, exist_ok=True)

# Our ("wit") training and validation splits, each with its own transform pipeline.
# Smoke test: wrap each split as `Subset(ds, range(0, len(ds), max(1, len(ds) // 100)))` to
# get ~100 samples spread over all classes. Avoid `range(100)`: ImageFolder lists samples
# class by class, so a prefix only covers the first class(es).
train_ds, val_ds = (
    ImageFolder(os.path.join("datasets", f"wit_{split}"), transform=tfm)
    for split, tfm in (("train", train_transforms), ("val", val_transforms))
)

In [8]:
# NOTE(review): `model` is the same object trained in the previous section, so this run
# continues from those weights rather than from the base checkpoint (the "combined" output
# names suggest that is intended) -- confirm.
# Trains until early stopping fires; num_train_epochs is only an upper bound.
# NOTE(review): lr_scheduler_type is not set, so the Trainer's default linear decay is spread
# over all 1000 epochs and the LR stays ~constant after warmup -- confirm that's intended.
# metric_for_best_model is not set either, so best-checkpoint selection and early stopping
# follow the default metric (eval loss); set metric_for_best_model="f1" and
# greater_is_better=True to select on F1 instead.
training_args = TrainingArguments(
    output_dir=TARGET_DIR,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=1000,
    learning_rate=5e-5,
    weight_decay=0.01,
    # fp16 mixed precision is only supported on an accelerator; fall back to fp32 on CPU.
    fp16=torch.cuda.is_available(),
    logging_dir=os.path.join(TARGET_DIR, "logs_combined"),
    logging_steps=100,
    warmup_steps=500,
    logging_first_step=True,
    load_best_model_at_end=True,
    save_total_limit=3,
    report_to=["tensorboard"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    # NOTE(review): patience is 3 here but 5 for the "rest" run -- confirm intended.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    data_collator=collate_fn,
)

try:
    trainer.train()
except KeyboardInterrupt:
    # An interrupt skips load_best_model_at_end, so the weights saved below are the latest
    # ones, not necessarily the best; report where the best checkpoint (if any) lives.
    print("Training interrupted. Saving the model...")
    print("Best checkpoint so far:", trainer.state.best_model_checkpoint)

# Export only after normal completion or a manual interrupt. Any other exception propagates
# before this point, so it cannot overwrite an earlier export with half-trained weights.
# NOTE(review): the saved processor is the checkpoint's default; training used
# train_transforms / val_transforms -- confirm inference preprocessing matches.
model.save_pretrained(os.path.join(TARGET_DIR, "model_combined"))
processor.save_pretrained(os.path.join(TARGET_DIR, "processor_combined"))

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0848,0.063858,0.9765,0.977085,0.9765,0.976609
2,0.0942,0.050997,0.9818,0.981802,0.9818,0.981756
3,0.0646,0.05308,0.9823,0.982316,0.9823,0.982252
4,0.0648,0.051041,0.9818,0.981844,0.9818,0.98174
5,0.065,0.068069,0.9735,0.97493,0.9735,0.9737


In [9]:
# Metrics as rows, test splits as columns.
evaluate_model(trainer).transpose()

Unnamed: 0,rest,wit
eval_loss,0.461022,0.05264
eval_accuracy,0.902067,0.9825
eval_precision,0.903942,0.982496
eval_recall,0.902067,0.9825
eval_f1,0.901408,0.982456
eval_runtime,79.085,168.7196
eval_samples_per_second,58.102,59.27
eval_steps_per_second,1.821,1.855
epoch,5.0,5.0
