In [1]:
import os
import numpy as np

import torch
import pandas as pd
from torchinfo import summary
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
)

from torch.utils.data import Subset
from torchvision.datasets import ImageFolder
from transformers import (
    TrainingArguments,
    Trainer,
    AutoModelForImageClassification,
    AutoImageProcessor,
    EarlyStoppingCallback,
)

from src.transformers import train_transforms, val_transforms, test_transforms
from src.callbacks import CHECKPOINT_DIR

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

print("CUDA Available:", torch.cuda.is_available())
print("CUDA Version:", torch.version.cuda)

CUDA Available: True
CUDA Version: 12.6


In [2]:
checkpoint = "timm/mobilenetv4_hybrid_medium.e500_r224_in1k"

model = AutoModelForImageClassification.from_pretrained(checkpoint)
processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)

summary(model, input_size=(1, 3, 224, 224))

Layer (type:depth-idx)                                  Output Shape              Param #
TimmWrapperForImageClassification                       [1, 1000]                 --
├─MobileNetV3: 1-1                                      [1, 1000]                 --
│    └─Conv2d: 2-1                                      [1, 32, 112, 112]         864
│    └─BatchNormAct2d: 2-2                              [1, 32, 112, 112]         64
│    │    └─Identity: 3-1                               [1, 32, 112, 112]         --
│    │    └─ReLU: 3-2                                   [1, 32, 112, 112]         --
│    └─Sequential: 2-3                                  [1, 960, 7, 7]            --
│    │    └─Sequential: 3-3                             [1, 48, 56, 56]           43,360
│    │    └─Sequential: 3-4                             [1, 80, 28, 28]           59,712
│    │    └─Sequential: 3-5                             [1, 160, 14, 14]          1,947,920
│    │    └─Sequential: 3-6                 

In [3]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(
        pred.predictions[0]
        if isinstance(pred.predictions, tuple)
        else pred.predictions,
        axis=-1,
    )

    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="weighted")
    recall = recall_score(labels, preds, average="weighted")
    f1 = f1_score(labels, preds, average="weighted")

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


def collate_fn(batch):
    images, labels = zip(*batch)
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(labels)}


def evaluate_model(trainer, limit=None):
    test_ds = ImageFolder(
        os.path.join("datasets", "rest_test"), transform=test_transforms
    )
    if limit is not None:
        test_ds = Subset(test_ds, range(limit))
    score_rest = trainer.evaluate(test_ds)

    test_ds = ImageFolder(
        os.path.join("datasets", "wit_test"), transform=test_transforms
    )
    if limit is not None:
        test_ds = Subset(test_ds, range(limit))
    score_wit = trainer.evaluate(test_ds)

    return pd.DataFrame([score_rest, score_wit], index=["rest", "wit"])

# Other datasets

In [4]:
TARGET_DIR = os.path.join(CHECKPOINT_DIR, "mobilenet", "other")
os.makedirs(TARGET_DIR, exist_ok=True)

train_ds = ImageFolder(
    os.path.join("datasets", "rest_train"), transform=train_transforms
)
# train_ds = Subset(train_ds, range(100))

val_ds = ImageFolder(os.path.join("datasets", "rest_val"), transform=val_transforms)
# val_ds = Subset(val_ds, range(100))

In [5]:
training_args = TrainingArguments(
    output_dir=TARGET_DIR,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=1000,
    learning_rate=5e-4,
    weight_decay=0.01,
    fp16=True,
    logging_dir=os.path.join(TARGET_DIR, "logs"),
    logging_steps=100,
    logging_first_step=True,
    warmup_steps=500,
    load_best_model_at_end=True,
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=2,
    # metric_for_best_model="f1",
    # greater_is_better=True,
    save_total_limit=3,
    report_to=["tensorboard"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    data_collator=collate_fn,
)

try:
    trainer.train()
except KeyboardInterrupt:
    print("Training interrupted. Saving the model...")
finally:
    model.save_pretrained(os.path.join(TARGET_DIR, "model"))
    processor.save_pretrained(os.path.join(TARGET_DIR, "processor"))

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.1323,0.180474,0.958206,0.959881,0.958206,0.958023
2,0.0866,0.110209,0.959948,0.962068,0.959948,0.960037
3,0.0764,0.065332,0.980192,0.980294,0.980192,0.980204
4,0.0671,0.050342,0.985633,0.985637,0.985633,0.985635
5,0.0613,0.074897,0.976926,0.977269,0.976926,0.976952
6,0.0529,0.041778,0.986069,0.986116,0.986069,0.98606
7,0.0584,0.064377,0.988681,0.988691,0.988681,0.988683
8,0.0517,0.042357,0.98781,0.98781,0.98781,0.987809
9,0.0569,6.401874,0.603396,0.72645,0.603396,0.506545
10,0.0369,0.033462,0.989769,0.989794,0.989769,0.989765


The `save_pretrained` method is disabled for TimmWrapperImageProcessor. The image processor configuration is saved directly in `config.json` when `save_pretrained` is called for saving the model.


In [6]:
evaluate_model(trainer=trainer).T

Unnamed: 0,rest,wit
eval_loss,0.02697,0.861446
eval_accuracy,0.990207,0.7116
eval_precision,0.99022,0.785282
eval_recall,0.990207,0.7116
eval_f1,0.990203,0.720564
eval_runtime,28.7855,58.8255
eval_samples_per_second,159.629,169.994
eval_steps_per_second,5.003,5.321
epoch,15.0,15.0


# Our dataset

In [7]:
TARGET_DIR = os.path.join(CHECKPOINT_DIR, "mobilenet", "wit")
os.makedirs(TARGET_DIR, exist_ok=True)

train_ds = ImageFolder(
    os.path.join("datasets", "wit_train"), transform=train_transforms
)
# train_ds = Subset(train_ds, range(100))
val_ds = ImageFolder(os.path.join("datasets", "wit_val"), transform=val_transforms)
# val_ds = Subset(val_ds, range(100))

In [8]:
training_args = TrainingArguments(
    output_dir=TARGET_DIR,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=1000,
    learning_rate=5e-4,
    weight_decay=0.01,
    fp16=True,
    logging_dir=os.path.join(TARGET_DIR, "logs"),
    logging_steps=100,
    warmup_steps=500,
    logging_first_step=True,
    load_best_model_at_end=True,
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=2,
    # metric_for_best_model="f1",
    # greater_is_better=True,
    save_total_limit=3,
    report_to=["tensorboard"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    data_collator=collate_fn,
)

try:
    trainer.train()
except KeyboardInterrupt:
    print("Training interrupted. Saving the model...")
finally:
    model.save_pretrained(os.path.join(TARGET_DIR, "model"))
    processor.save_pretrained(os.path.join(TARGET_DIR, "processor"))

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.1099,0.092464,0.9638,0.96434,0.9638,0.963942
2,0.0955,0.065102,0.9767,0.976756,0.9767,0.976721
3,0.0868,0.069809,0.9758,0.975874,0.9758,0.975693
4,0.0781,0.084138,0.9697,0.970194,0.9697,0.96944
5,0.0796,0.062597,0.976,0.976041,0.976,0.975905
6,0.0729,0.078142,0.9723,0.972758,0.9723,0.972077
7,0.0764,0.057125,0.9793,0.979284,0.9793,0.97929
8,0.0738,0.055456,0.9823,0.982282,0.9823,0.982274
9,0.073,0.070988,0.9801,0.980077,0.9801,0.980083
10,0.0574,0.049798,0.9836,0.983585,0.9836,0.983577


In [9]:
evaluate_model(trainer=trainer).T

Unnamed: 0,rest,wit
eval_loss,0.49886,0.046354
eval_accuracy,0.892927,0.9834
eval_precision,0.896322,0.98344
eval_recall,0.892927,0.9834
eval_f1,0.891907,0.983415
eval_runtime,29.0881,61.2401
eval_samples_per_second,157.968,163.292
eval_steps_per_second,4.95,5.111
epoch,14.0,14.0
