<a href="https://colab.research.google.com/github/nateraw/huggingface-hub-examples/blob/main/huggingface_timm_trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Install timm from the fork/branch named `hf-save-and-push`; presumably this is what
# provides the `timm.models.hub` save/push helpers used later in the notebook (confirm).
! pip install git+https://github.com/nateraw/pytorch-image-models.git@hf-save-and-push --upgrade
# Hugging Face libraries imported by the training code below.
! pip install transformers datasets huggingface_hub
# git-lfs and a stored git credential helper; presumably required for pushing to the Hub over git (confirm).
! apt install git-lfs
! git config --global credential.helper store

In [26]:
# Interactive login to the Hugging Face Hub: prompts for credentials and saves a token locally.
! huggingface-cli login


        _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
        _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
        _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
        _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
        _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

        
Username: nateraw
Password: 
Login successful
Your token has been saved to /root/.huggingface/token


To train and push a `timm` model, it's as simple as this:

```python
import timm

# Build a model 🔧
model = timm.create_model('resnet18', pretrained=True, num_classes=4)

# Push it to the 🤗 hub 
timm.models.hub.push_to_hf_hub(
    model,
    repo_path_or_name='resnet18-random-classifier',
    commit_message='😎 Pushed from timm',
    git_user='nateraw',
    git_email='naterawdata@gmail.com',
    config={'num_classes': model.num_classes}
)

# Load from hub 🔥
model_reloaded = timm.create_model(
    'hf_hub:nateraw/resnet18-random-classifier',
    pretrained=True
)
```

## Fine-tuning a `timm` model

Instead of pushing up a random classifier like the snippet above does, let's fine-tune a `timm` model using the `Trainer` from `transformers`.

In [2]:
import json
import logging
import os
from datetime import datetime
from typing import Optional

import datasets
import numpy as np
import timm
import torch
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import HfArgumentParser, Trainer, TrainingArguments
from transformers.modeling_utils import ModelOutput

logger = logging.getLogger(__name__)

In [4]:
class TimmTrainer(Trainer):
    """`Trainer` subclass adapted to `timm` image models.

    Two hooks are overridden:

    * ``compute_loss`` calls the model with the ``pixel_values`` tensor only,
      treats the result as logits, and applies cross-entropy.
    * ``_save`` writes the model with ``timm.models.hub.save_pretrained_for_hf``
      and stores the training arguments next to it.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the forward pass and compute the cross-entropy loss.

        Args:
            model: Callable applied to ``inputs["pixel_values"]``; its return
                value is used as the logits.
            inputs: Mapping with a ``"pixel_values"`` entry and, optionally,
                a ``"labels"`` entry.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Optional and unused. Accepted so the override
                keeps working with `Trainer` versions that pass this argument
                to ``compute_loss``.

        Returns:
            The loss (``None`` when ``inputs`` has no labels), or a tuple of
            the loss and a ``ModelOutput`` holding ``logits`` and ``loss``
            when ``return_outputs`` is true.
        """
        logits = model(inputs["pixel_values"])
        labels = inputs.get("labels")

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            # Take the class count from the logits instead of `model.num_classes`:
            # the `model` handed in here may be a wrapper (e.g. `torch.nn.DataParallel`)
            # that does not expose the underlying timm model's attributes.
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return (loss, ModelOutput(logits=logits, loss=loss)) if return_outputs else loss

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write the model and the training arguments to ``output_dir``.

        Args:
            output_dir: Target directory; falls back to ``self.args.output_dir``
                when ``None``. Created if it does not exist.
            state_dict: Not used by this implementation.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Save the model
        timm.models.hub.save_pretrained_for_hf(self.model, output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

In [29]:
def collate_fn(examples):
    """Collate per-example dicts into a single batch dict.

    Args:
        examples: Sequence of mappings, each with a ``"pixel_values"`` tensor
            and a ``"labels"`` value.

    Returns:
        Dict with ``"pixel_values"`` (the tensors stacked along a new leading
        dimension) and ``"labels"`` (a tensor built from the label values).
    """
    images, targets = [], []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["labels"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }

def compute_metrics(p):
    """Compute top-1 and top-5 accuracy for an evaluation result.

    Args:
        p: Object exposing ``predictions`` (converted to a float32 tensor)
            and ``label_ids`` (converted to a tensor).

    Returns:
        Dict with keys ``'acc1'`` and ``'acc5'`` holding the two values
        returned by ``timm.utils.metrics.accuracy`` for ``topk=(1, 5)``.
    """
    scores = torch.tensor(p.predictions, dtype=torch.float32)
    targets = torch.tensor(p.label_ids)
    top1, top5 = timm.utils.metrics.accuracy(scores, targets, topk=(1, 5))
    return {'acc1': top1, 'acc5': top5}

In [30]:
class TimmTransforms:
    """Batch-level image loading and preprocessing built from a timm data config.

    Two pipelines are created from the same ``data_config`` with
    ``create_transform``: one with ``is_training=True`` and one with
    ``is_training=False``. The public ``train_transforms`` / ``val_transforms``
    methods take a batch dict containing ``"image_file_path"`` and add a
    ``"pixel_values"`` entry to it.
    """

    def __init__(self, data_config):
        """Build the training and validation pipelines.

        Args:
            data_config: Mapping of keyword arguments forwarded to
                ``create_transform``.
        """
        self.data_config = data_config
        self._train_transforms = create_transform(is_training=True, **data_config)
        self._val_transforms = create_transform(is_training=False, **data_config)

    def pil_loader(self, path):
        """Open the image file at ``path`` and return it converted to RGB."""
        with open(path, "rb") as f:
            # Read through the handle opened above rather than re-opening `path`,
            # and convert while the handle is still open so the pixel data is
            # read before the `with` block closes the file.
            im = Image.open(f)
            im = im.convert("RGB")
        return im

    def _apply(self, transform, example_batch):
        """Load each file in ``example_batch["image_file_path"]``, apply ``transform``, and store the results under ``"pixel_values"``."""
        example_batch["pixel_values"] = [
            transform(self.pil_loader(f)) for f in example_batch["image_file_path"]
        ]
        return example_batch

    def train_transforms(self, example_batch):
        """Apply _train_transforms across a batch.

        The batch dict is updated in place and also returned.
        """
        return self._apply(self._train_transforms, example_batch)

    def val_transforms(self, example_batch):
        """Apply _val_transforms across a batch.

        The batch dict is updated in place and also returned.
        """
        return self._apply(self._val_transforms, example_batch)

In [None]:
model_name = 'resnet18'
dataset_name = 'cats_vs_dogs'
# Timestamp makes the output directory unique per run
timestamp = datetime.now().strftime("%Y-%m-%d-%H%M%S")
hf_hub_model_id = 'nateraw/my-cool-timm-model-3'

# Fraction of train to hold out as val (only used when the dataset has no 'validation' split)
val_split_percent = 0.20
# Fraction of val to hold out as test (only used when the dataset has no 'test' split)
test_split_percent = 0.25

# Define Training Arguments
training_args = TrainingArguments(
    output_dir=f'{model_name}-{dataset_name}-{timestamp}',
    # Keep every dataset column: the transforms below read `image_file_path`
    remove_unused_columns=False,
    evaluation_strategy='epoch',
    report_to='tensorboard',
    push_to_hub=True,
    logging_strategy='steps',
    logging_steps=10,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    fp16=True,
    # Training is capped at 10 steps
    max_steps=10,
    hub_model_id=hf_hub_model_id
)

# Init Dataset
ds = datasets.load_dataset(dataset_name, task="image-classification")

# Init model; the classifier size is taken from the dataset's label names
model = timm.create_model(
    model_name,
    pretrained=True,
    num_classes=len(ds["train"].features["labels"].names),
)

# Define transforms
data_config = resolve_data_config({}, model=model)
transforms = TimmTransforms(data_config)

# Create any missing splits. The splits reuse the Trainer seed so that rerunning
# the notebook yields the same train/validation/test partitions.
if 'validation' not in ds:
    split = ds["train"].train_test_split(val_split_percent, seed=training_args.seed)
    ds["train"] = split["train"]
    ds["validation"] = split["test"]
if 'test' not in ds:
    split = ds["validation"].train_test_split(test_split_percent, seed=training_args.seed)
    ds["validation"] = split["train"]
    ds["test"] = split["test"]

# Init Trainer; transforms are attached with `with_transform` on each split
trainer = TimmTrainer(
    model,
    training_args,
    data_collator=collate_fn,
    train_dataset=ds["train"].with_transform(transforms.train_transforms),
    eval_dataset=ds["validation"].with_transform(transforms.val_transforms),
    tokenizer=None,
    compute_metrics=compute_metrics,
)

# Training
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Evaluation on the held-out test split, using the validation transforms
eval_results = trainer.evaluate(ds["test"].with_transform(transforms.val_transforms), metric_key_prefix='test')
trainer.log_metrics("test", eval_results)
trainer.save_metrics("test", eval_results)

# Write model card and (optionally) push to hub
kwargs = {
    "finetuned_from": model_name,
    "tasks": "image-classification",
    "dataset": dataset_name,
    "tags": ["image-classification", "timm"],
}
if training_args.push_to_hub:
    trainer.push_to_hub(**kwargs)
else:
    trainer.create_model_card(**kwargs)

PyTorch: setting up devices
Using custom data configuration default
Reusing dataset cats_vs_dogs (/root/.cache/huggingface/datasets/cats_vs_dogs/default/0.0.0/e44d25c0884431043d9eff89690884a4794720faf7b5ef0ed48191fa2f79295b)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/cats_vs_dogs/default/0.0.0/e44d25c0884431043d9eff89690884a4794720faf7b5ef0ed48191fa2f79295b/cache-bd7ae1e44740468e.arrow
  "Argument interpolation should be of type InterpolationMode instead of int. "
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/cats_vs_dogs/default/0.0.0/e44d25c0884431043d9eff89690884a4794720faf7b5ef0ed48191fa2f79295b/cache-49bea2d3d459fac4.arrow and /root/.cache/huggingface/datasets/cats_vs_dogs/default/0.0.0/e44d25c0884431043d9eff89690884a4794720faf7b5ef0ed48191fa2f79295b/cache-91b28ba72fcdf5d7.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/cats_vs_dogs/default/0.0.0/e44d25c0884431043d9eff89690884a4794720faf7b5ef0ed48191fa2f79295b/cache-37a8014fc1bdf1d4.arrow and /root/.cache/huggingface/datasets/cats_vs_dogs/default/0.0.0/e44d25c0884431043d9eff89690884a4794720faf7b5ef0ed48191fa2f79295b/cache-b9e243ab39d4a966.arrow
Cloning https

Epoch,Training Loss,Validation Loss
