In [16]:
from datasets import load_dataset, DatasetDict
from huggingface_hub import login
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    DefaultDataCollator,
    TrainingArguments,
    Trainer,
)
import evaluate
import numpy as np

import os

In [2]:
login(os.getenv("HF_READ"))

In [3]:
dataset = load_dataset("SABR22/threat_classification")

README.md:   0%|          | 0.00/445 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/202 [00:00<?, ? examples/s]

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 202
    })
})

In [17]:
dataset = dataset["train"].train_test_split(test_size=8, seed=42)
dataset = DatasetDict({
    "train": dataset["train"],
    "validation": dataset["test"],
})

In [18]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 162
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 8
    })
})

In [19]:
labels = dataset['train'].features['label'].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [20]:
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [21]:
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transfroms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [22]:
def transforms(examples):
    examples["pixel_values"] = [_transfroms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

In [23]:
dataset = dataset.with_transform(transforms)

In [24]:
data_collator = DefaultDataCollator()

In [25]:
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [26]:
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [27]:
login(os.getenv("HF_WRITE"))

In [31]:
training_args = TrainingArguments(
    output_dir="./ViT-threat-classification",
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-6,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    warmup_steps=500,
    logging_steps=5,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    weight_decay=0.01,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

trainer.train()

  0%|          | 0/30 [00:00<?, ?it/s]

{'loss': 0.3565, 'grad_norm': 4.208386421203613, 'learning_rate': 1e-08, 'epoch': 0.49}
{'loss': 0.328, 'grad_norm': 4.1375603675842285, 'learning_rate': 2e-08, 'epoch': 0.98}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.45563793182373047, 'eval_accuracy': 0.875, 'eval_runtime': 0.0816, 'eval_samples_per_second': 98.08, 'eval_steps_per_second': 24.52, 'epoch': 0.98}
{'loss': 0.3272, 'grad_norm': 4.664910316467285, 'learning_rate': 3e-08, 'epoch': 1.46}
{'loss': 0.3226, 'grad_norm': 4.1039347648620605, 'learning_rate': 4e-08, 'epoch': 1.95}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.47360897064208984, 'eval_accuracy': 0.75, 'eval_runtime': 0.0813, 'eval_samples_per_second': 98.429, 'eval_steps_per_second': 24.607, 'epoch': 1.95}
{'loss': 0.3305, 'grad_norm': 4.455121040344238, 'learning_rate': 5e-08, 'epoch': 2.44}
{'loss': 0.3619, 'grad_norm': 4.2679362297058105, 'learning_rate': 6e-08, 'epoch': 2.93}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.45681232213974, 'eval_accuracy': 1.0, 'eval_runtime': 0.0743, 'eval_samples_per_second': 107.619, 'eval_steps_per_second': 26.905, 'epoch': 2.93}
{'train_runtime': 35.4274, 'train_samples_per_second': 13.718, 'train_steps_per_second': 0.847, 'train_loss': 0.33779648939768475, 'epoch': 2.93}


TrainOutput(global_step=30, training_loss=0.33779648939768475, metrics={'train_runtime': 35.4274, 'train_samples_per_second': 13.718, 'train_steps_per_second': 0.847, 'total_flos': 3.688618705654579e+16, 'train_loss': 0.33779648939768475, 'epoch': 2.926829268292683})

In [32]:
trainer.push_to_hub(token=os.getenv("HF_WRITE"))

model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/SABR22/ViT-threat-classification/commit/404c2610a2e8f49a091eceedec33ee969345ca96', commit_message='End of training', commit_description='', oid='404c2610a2e8f49a091eceedec33ee969345ca96', pr_url=None, repo_url=RepoUrl('https://huggingface.co/SABR22/ViT-threat-classification', endpoint='https://huggingface.co', repo_type='model', repo_id='SABR22/ViT-threat-classification'), pr_revision=None, pr_num=None)