In [1]:
import os
import numpy as np

import torch
import pandas as pd
from torchinfo import summary
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
)

from torch.utils.data import Subset
from torchvision.datasets import ImageFolder
from transformers import TrainingArguments, Trainer, AutoModelForImageClassification, AutoImageProcessor, EarlyStoppingCallback

from src.transformers import train_transforms, val_transforms, test_transforms
from src.callbacks import CHECKPOINT_DIR

if torch.cuda.is_available():
    # Reproducibility settings for cuDNN. `benchmark` must be off when
    # `deterministic` is on: the benchmark auto-tuner may select a different
    # convolution algorithm from run to run, which defeats the purpose of
    # requesting deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Report the CUDA runtime that the rest of the notebook will train on.
print("CUDA Available:", torch.cuda.is_available())
print("CUDA Version:", torch.version.cuda)

CUDA Available: True
CUDA Version: 12.6


In [2]:
# Hub id of the checkpoint that both the model and the image processor are
# loaded from.
checkpoint = "timm/mobilenetv4_hybrid_medium.e500_r224_in1k"

# `model` is the object handed to every Trainer further down; `processor` is
# only used later for `save_pretrained`.
# NOTE(review): the recorded summary below reports a [1, 1000] output, i.e. the
# classification head is left at the checkpoint's size. Confirm whether it
# should be resized to the number of classes in the ImageFolder datasets.
model = AutoModelForImageClassification.from_pretrained(checkpoint)
processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)

# Print the layer / output-shape / parameter-count table for an input of
# size (1, 3, 224, 224).
summary(model, input_size=(1, 3, 224, 224))

Layer (type:depth-idx)                                  Output Shape              Param #
TimmWrapperForImageClassification                       [1, 1000]                 --
├─MobileNetV3: 1-1                                      [1, 1000]                 --
│    └─Conv2d: 2-1                                      [1, 32, 112, 112]         864
│    └─BatchNormAct2d: 2-2                              [1, 32, 112, 112]         64
│    │    └─Identity: 3-1                               [1, 32, 112, 112]         --
│    │    └─ReLU: 3-2                                   [1, 32, 112, 112]         --
│    └─Sequential: 2-3                                  [1, 960, 7, 7]            --
│    │    └─Sequential: 3-3                             [1, 48, 56, 56]           43,360
│    │    └─Sequential: 3-4                             [1, 80, 28, 28]           59,712
│    │    └─Sequential: 3-5                             [1, 160, 14, 14]          1,947,920
│    │    └─Sequential: 3-6                 

In [3]:
def compute_metrics(pred):
    """Compute classification metrics for one evaluation pass.

    Args:
        pred: Prediction object exposing ``label_ids`` and ``predictions``.
            ``predictions`` is either the score array itself or a tuple whose
            first element is that array.

    Returns:
        dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
        Precision, recall and F1 use ``average='weighted'``.
    """
    labels = pred.label_ids
    scores = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
    # Predicted class = index of the highest score along the last axis.
    preds = np.argmax(scores, axis=-1)

    # zero_division=0 returns the same value as scikit-learn's default
    # ("warn" acts as 0) but without emitting UndefinedMetricWarning whenever
    # a class has no predicted or no true samples in the evaluated split.
    weighted = {"average": "weighted", "zero_division": 0}

    return {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, **weighted),
        'recall': recall_score(labels, preds, **weighted),
        'f1': f1_score(labels, preds, **weighted),
    }

def collate_fn(batch):
    """Collate ``(image, label)`` samples into the dict the Trainer consumes.

    Args:
        batch: Iterable of 2-item samples ``(image_tensor, label)``.

    Returns:
        dict with ``"pixel_values"`` (the images stacked along a new leading
        dimension) and ``"labels"`` (a tensor built from the labels).
    """
    # Transpose the list of samples into one tuple of images and one of labels.
    images, labels = zip(*batch)
    pixel_values = torch.stack(images)
    label_tensor = torch.tensor(labels)
    return {"pixel_values": pixel_values, "labels": label_tensor}

def evaluate_model(trainer, limit=None):
    """Evaluate ``trainer`` on the "rest" and "wit" test sets.

    Args:
        trainer: Object whose ``evaluate(dataset)`` returns a dict of metrics.
        limit: Optional cap on the number of samples taken from the start of
            each test set. Values larger than a dataset are clamped to its
            size, so an oversized limit no longer produces out-of-range
            indices. ``None`` evaluates the full sets.

    Returns:
        pandas.DataFrame with one row per test set (index ``"rest"`` and
        ``"wit"``) and one column per key returned by ``trainer.evaluate``.
    """
    scores = {}
    for name in ("rest", "wit"):
        test_ds = ImageFolder(os.path.join("datasets", f"{name}_test"), transform=test_transforms)
        if limit is not None:
            # NOTE(review): this takes the first `limit` samples in dataset
            # order; presumably ImageFolder lists samples class by class, so a
            # small limit may cover a single class. Treat limited runs as
            # smoke tests only -- confirm.
            test_ds = Subset(test_ds, range(min(limit, len(test_ds))))
        scores[name] = trainer.evaluate(test_ds)

    return pd.DataFrame(list(scores.values()), index=list(scores))

# Other datasets

In [None]:
# Every artifact of this run (checkpoints, logs, exported weights) is written
# below <CHECKPOINT_DIR>/mobilenet/other.
TARGET_DIR = os.path.join(CHECKPOINT_DIR, "mobilenet", "other")
os.makedirs(TARGET_DIR, exist_ok=True)

# Training and validation splits, each with its own transform pipeline.
train_ds = ImageFolder(root=os.path.join("datasets", "rest_train"), transform=train_transforms)
val_ds = ImageFolder(root=os.path.join("datasets", "rest_val"), transform=val_transforms)

# Debug aid: uncomment to restrict both splits to their first 100 samples.
# train_ds = Subset(train_ds, range(100))
# val_ds = Subset(val_ds, range(100))

In [5]:
# Training configuration for the run on the "rest" splits. Checkpoints and
# TensorBoard logs are written under TARGET_DIR.
training_args = TrainingArguments(
    output_dir=TARGET_DIR,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Evaluation and checkpointing both happen once per epoch.
    eval_strategy="epoch",
    save_strategy="epoch",
    # Upper bound only -- presumably the early-stopping callback passed to the
    # Trainer below ends the run long before this.
    num_train_epochs=1000,
    learning_rate=5e-5,
    weight_decay=0.01,
    # NOTE(review): the recorded table has blank Validation Loss entries for
    # the first epochs; half precision may be the cause -- confirm.
    fp16=True,
    logging_dir=os.path.join(TARGET_DIR, "logs_other"),
    logging_steps=100,
    logging_first_step=True,
    warmup_steps=500,
    # No explicit selection metric is configured while the two lines below
    # stay commented out.
    # NOTE(review): presumably the library then falls back to the validation
    # loss, which is blank for several recorded epochs; confirm that it is a
    # reliable signal for best-model selection and early stopping, or enable
    # the F1 settings.
    load_best_model_at_end=True,
    # metric_for_best_model="f1",
    # greater_is_better=True,
    save_total_limit=3,
    report_to=["tensorboard"],
)

# `model` is the object loaded at the top of the notebook; metrics come from
# compute_metrics and batches are assembled by collate_fn.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    data_collator=collate_fn
)

try:
    trainer.train()
except KeyboardInterrupt:
    # A manual interrupt is swallowed so the cell still finishes with an export.
    print("Training interrupted. Saving the model...")
finally:
    # Runs after normal completion, after an interrupt and after any other
    # exception: the weights currently held by `model` are always exported.
    model.save_pretrained(os.path.join(TARGET_DIR, "model_other"))
    # The recorded output shows that this call only emits a notice for this
    # processor type (saving is disabled; its config is stored with the model).
    processor.save_pretrained(os.path.join(TARGET_DIR, "processor_other"))

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0531,,0.990858,0.990912,0.990858,0.990862
2,0.0412,,0.989116,0.989172,0.989116,0.989121
3,0.0339,,0.992381,0.992408,0.992381,0.992384
4,0.0262,,0.991293,0.991366,0.991293,0.991287
5,0.026,0.039583,0.993905,0.993941,0.993905,0.993907
6,0.0091,0.026733,0.9963,0.996318,0.9963,0.996301
7,0.0141,0.031203,0.994993,0.994997,0.994993,0.994994
8,0.017,0.027747,0.994776,0.994788,0.994776,0.994777
9,0.0209,0.022525,0.995646,0.995647,0.995646,0.995647
10,0.011,0.031016,0.994776,0.994811,0.994776,0.994778


The `save_pretrained` method is disabled for TimmWrapperImageProcessor. The image processor configuration is saved directly in `config.json` when `save_pretrained` is called for saving the model.


In [6]:
evaluate_model(trainer=trainer).T

Unnamed: 0,rest,wit
eval_loss,,
eval_accuracy,0.995212,0.7961
eval_precision,0.995212,0.805932
eval_recall,0.995212,0.7961
eval_f1,0.995212,0.799276
eval_runtime,29.7428,60.6196
eval_samples_per_second,154.491,164.963
eval_steps_per_second,4.842,5.163
epoch,20.0,20.0


# Our dataset

In [None]:
# Every artifact of this run (checkpoints, logs, exported weights) is written
# below <CHECKPOINT_DIR>/mobilenet/wit.
TARGET_DIR = os.path.join(CHECKPOINT_DIR, "mobilenet", "wit")
os.makedirs(TARGET_DIR, exist_ok=True)

# Training and validation splits, each with its own transform pipeline.
train_ds = ImageFolder(root=os.path.join("datasets", "wit_train"), transform=train_transforms)
val_ds = ImageFolder(root=os.path.join("datasets", "wit_val"), transform=val_transforms)

# Debug aid: uncomment to restrict both splits to their first 100 samples.
# train_ds = Subset(train_ds, range(100))
# val_ds = Subset(val_ds, range(100))

In [8]:
# Training configuration for the run on the "wit" splits. Checkpoints and
# TensorBoard logs are written under TARGET_DIR.
training_args = TrainingArguments(
    output_dir=TARGET_DIR,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Evaluation and checkpointing both happen once per epoch.
    eval_strategy="epoch",
    save_strategy="epoch",
    # Upper bound only -- presumably the early-stopping callback passed to the
    # Trainer below ends the run long before this.
    num_train_epochs=1000,
    learning_rate=5e-5,
    weight_decay=0.01,
    # NOTE(review): the recorded table has blank Validation Loss entries for
    # the first epochs and a 5.13 spike at epoch 3; half precision may be the
    # cause -- confirm.
    fp16=True,
    logging_dir=os.path.join(TARGET_DIR, "logs_combined"),
    logging_steps=100,
    warmup_steps=500,
    logging_first_step=True,
    # No explicit selection metric is configured while the two lines below
    # stay commented out.
    # NOTE(review): presumably the library then falls back to the validation
    # loss; confirm that it is a reliable signal for best-model selection and
    # early stopping, or enable the F1 settings.
    load_best_model_at_end=True,
    # metric_for_best_model="f1",
    # greater_is_better=True,
    save_total_limit=3,
    report_to=["tensorboard"],
)

# NOTE(review): `model` is not reloaded between the previous training cell and
# this one, so this run starts from whatever weights that cell left in it
# rather than from the pretrained checkpoint. The "combined" artifact names
# suggest this is intentional -- confirm.
# Early-stopping patience is 3 here, whereas the previous run uses 5.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    data_collator=collate_fn
)

try:
    trainer.train()
except KeyboardInterrupt:
    # A manual interrupt is swallowed so the cell still finishes with an export.
    print("Training interrupted. Saving the model...")
finally:
    # Runs after normal completion, after an interrupt and after any other
    # exception: the weights currently held by `model` are always exported.
    model.save_pretrained(os.path.join(TARGET_DIR, "model_combined"))
    # NOTE(review): the notice recorded after the previous training cell says
    # save_pretrained is disabled for this processor type, so this call
    # presumably writes nothing -- confirm.
    processor.save_pretrained(os.path.join(TARGET_DIR, "processor_combined"))

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.127,,0.9787,0.978713,0.9787,0.978634
2,0.0761,,0.9806,0.980605,0.9806,0.980548
3,0.0753,5.131879,0.9775,0.977467,0.9775,0.977469
4,0.077,0.056476,0.9792,0.979342,0.9792,0.9791
5,0.0557,0.061594,0.9801,0.980247,0.9801,0.980005
6,0.0597,0.062916,0.9776,0.977801,0.9776,0.977475
7,0.0496,0.051668,0.984,0.984,0.984,0.984
8,0.056,0.042003,0.9857,0.9857,0.9857,0.985673
9,0.0492,0.043823,0.985,0.985103,0.985,0.985026
10,0.0483,0.045573,0.9857,0.985698,0.9857,0.985675


In [9]:
evaluate_model(trainer=trainer).T

Unnamed: 0,rest,wit
eval_loss,0.514079,0.042751
eval_accuracy,0.907726,0.9852
eval_precision,0.910556,0.985185
eval_recall,0.907726,0.9852
eval_f1,0.906976,0.985181
eval_runtime,29.2139,60.7672
eval_samples_per_second,157.288,164.562
eval_steps_per_second,4.929,5.151
epoch,11.0,11.0
