In [1]:
from datasets import load_dataset, DatasetDict, load_from_disk
from huggingface_hub import login
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    DefaultDataCollator,
    TrainingArguments,
    Trainer,
)
import evaluate
import numpy as np

import os

In [2]:
login(os.getenv("HF_READ"))

In [3]:
dataset = load_dataset("SABR22/threat_classification")
#dataset = load_from_disk("threat_dataset\\threatv2\hf_threat_dataset_format")

README.md:   0%|          | 0.00/465 [00:00<?, ?B/s]

train-00000-of-00014.parquet:   0%|          | 0.00/353M [00:00<?, ?B/s]

train-00001-of-00014.parquet:   0%|          | 0.00/21.1M [00:00<?, ?B/s]

train-00002-of-00014.parquet:   0%|          | 0.00/153M [00:00<?, ?B/s]

train-00003-of-00014.parquet:   0%|          | 0.00/469M [00:00<?, ?B/s]

train-00004-of-00014.parquet:   0%|          | 0.00/216M [00:00<?, ?B/s]

train-00005-of-00014.parquet:   0%|          | 0.00/329M [00:00<?, ?B/s]

train-00006-of-00014.parquet:   0%|          | 0.00/258M [00:00<?, ?B/s]

train-00007-of-00014.parquet:   0%|          | 0.00/292M [00:00<?, ?B/s]

train-00008-of-00014.parquet:   0%|          | 0.00/174M [00:00<?, ?B/s]

train-00009-of-00014.parquet:   0%|          | 0.00/269M [00:00<?, ?B/s]

train-00010-of-00014.parquet:   0%|          | 0.00/279M [00:00<?, ?B/s]

train-00011-of-00014.parquet:   0%|          | 0.00/177M [00:00<?, ?B/s]

train-00012-of-00014.parquet:   0%|          | 0.00/64.9M [00:00<?, ?B/s]

train-00013-of-00014.parquet:   0%|          | 0.00/58.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/24573 [00:00<?, ? examples/s]

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 24573
    })
})

In [5]:
dataset = dataset["train"].train_test_split(test_size=0.15, seed=42)
dataset = DatasetDict({
    "train": dataset["train"],
    "validation": dataset["test"],
})

In [6]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 20887
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 3686
    })
})

In [7]:
labels = dataset['train'].features['label'].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [15]:
label2id

{'non-threat': '0', 'threat': '1'}

In [8]:
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [9]:
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transfroms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [10]:
def transforms(examples):
    examples["pixel_values"] = [_transfroms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

In [11]:
dataset = dataset.with_transform(transforms)

In [12]:
data_collator = DefaultDataCollator()

In [21]:
# accuracy = evaluate.load("accuracy")

# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     predictions = np.argmax(predictions, axis=1)
#     return accuracy.compute(predictions=predictions, references=labels)

def compute_metrics(pred):
    from sklearn.metrics import f1_score, precision_score, recall_score
    logits, labels = pred
    predictions = np.argmax(logits, axis=1)
    f1 = f1_score(labels, predictions, average="binary")
    precision = precision_score(labels, predictions, average="binary")
    recall = recall_score(labels, predictions, average="binary")
    return {"f1": f1, "precision": precision, "recall": recall}


In [22]:
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
login(os.getenv("HF_WRITE"))

In [24]:
training_args = TrainingArguments(
    output_dir="./ViT-threat-classification-v2",
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=32,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_steps=20,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    save_total_limit=2,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

trainer.train()

  0%|          | 0/1630 [00:00<?, ?it/s]

{'loss': 0.6866, 'grad_norm': 2.142916679382324, 'learning_rate': 3.6809815950920245e-06, 'epoch': 0.06}
{'loss': 0.6314, 'grad_norm': 1.8821499347686768, 'learning_rate': 7.361963190184049e-06, 'epoch': 0.12}
{'loss': 0.5293, 'grad_norm': 1.7498693466186523, 'learning_rate': 1.1042944785276074e-05, 'epoch': 0.18}
{'loss': 0.3508, 'grad_norm': 1.277146339416504, 'learning_rate': 1.4723926380368098e-05, 'epoch': 0.25}
{'loss': 0.2569, 'grad_norm': 1.3036167621612549, 'learning_rate': 1.8404907975460123e-05, 'epoch': 0.31}
{'loss': 0.201, 'grad_norm': 2.536496162414551, 'learning_rate': 2.208588957055215e-05, 'epoch': 0.37}
{'loss': 0.1282, 'grad_norm': 0.8879151940345764, 'learning_rate': 2.5766871165644174e-05, 'epoch': 0.43}
{'loss': 0.0938, 'grad_norm': 3.226513147354126, 'learning_rate': 2.9447852760736196e-05, 'epoch': 0.49}
{'loss': 0.1061, 'grad_norm': 3.6998634338378906, 'learning_rate': 2.965235173824131e-05, 'epoch': 0.55}
{'loss': 0.083, 'grad_norm': 4.1234049797058105, 'lear

  0%|          | 0/116 [00:00<?, ?it/s]

{'eval_loss': 0.05755264312028885, 'eval_f1': 0.9465648854961832, 'eval_precision': 0.9738219895287958, 'eval_recall': 0.9207920792079208, 'eval_runtime': 40.4558, 'eval_samples_per_second': 91.112, 'eval_steps_per_second': 2.867, 'epoch': 1.0}
{'loss': 0.0588, 'grad_norm': 0.3646495044231415, 'learning_rate': 2.638036809815951e-05, 'epoch': 1.04}
{'loss': 0.0404, 'grad_norm': 0.5889979600906372, 'learning_rate': 2.5971370143149284e-05, 'epoch': 1.1}
{'loss': 0.0503, 'grad_norm': 0.21267925202846527, 'learning_rate': 2.5562372188139063e-05, 'epoch': 1.16}
{'loss': 0.0532, 'grad_norm': 4.012630939483643, 'learning_rate': 2.5153374233128835e-05, 'epoch': 1.23}
{'loss': 0.0556, 'grad_norm': 1.1993904113769531, 'learning_rate': 2.474437627811861e-05, 'epoch': 1.29}
{'loss': 0.0519, 'grad_norm': 3.324092149734497, 'learning_rate': 2.4335378323108386e-05, 'epoch': 1.35}
{'loss': 0.0559, 'grad_norm': 0.4710441529750824, 'learning_rate': 2.3926380368098158e-05, 'epoch': 1.41}
{'loss': 0.0437, 

  0%|          | 0/116 [00:00<?, ?it/s]

{'eval_loss': 0.03974929824471474, 'eval_f1': 0.9641367806505421, 'eval_precision': 0.9747048903878583, 'eval_recall': 0.9537953795379538, 'eval_runtime': 37.3348, 'eval_samples_per_second': 98.728, 'eval_steps_per_second': 3.107, 'epoch': 2.0}
{'loss': 0.0356, 'grad_norm': 1.7082453966140747, 'learning_rate': 1.983640081799591e-05, 'epoch': 2.02}
{'loss': 0.02, 'grad_norm': 0.10511762648820877, 'learning_rate': 1.9427402862985686e-05, 'epoch': 2.08}
{'loss': 0.0386, 'grad_norm': 0.73386549949646, 'learning_rate': 1.901840490797546e-05, 'epoch': 2.14}
{'loss': 0.0568, 'grad_norm': 4.3746185302734375, 'learning_rate': 1.8609406952965237e-05, 'epoch': 2.21}
{'loss': 0.0387, 'grad_norm': 0.9029757380485535, 'learning_rate': 1.8200408997955012e-05, 'epoch': 2.27}
{'loss': 0.0423, 'grad_norm': 4.824296474456787, 'learning_rate': 1.7791411042944788e-05, 'epoch': 2.33}
{'loss': 0.0402, 'grad_norm': 0.8919986486434937, 'learning_rate': 1.738241308793456e-05, 'epoch': 2.39}
{'loss': 0.0328, 'gr

  0%|          | 0/116 [00:00<?, ?it/s]

{'eval_loss': 0.04087326303124428, 'eval_f1': 0.9646672144617913, 'eval_precision': 0.9607201309328969, 'eval_recall': 0.9686468646864687, 'eval_runtime': 37.3983, 'eval_samples_per_second': 98.561, 'eval_steps_per_second': 3.102, 'epoch': 3.0}
{'loss': 0.0257, 'grad_norm': 0.18956641852855682, 'learning_rate': 1.3292433537832312e-05, 'epoch': 3.0}
{'loss': 0.0182, 'grad_norm': 0.2626650631427765, 'learning_rate': 1.2883435582822087e-05, 'epoch': 3.06}
{'loss': 0.0267, 'grad_norm': 1.263091802597046, 'learning_rate': 1.247443762781186e-05, 'epoch': 3.12}
{'loss': 0.0239, 'grad_norm': 0.3936562240123749, 'learning_rate': 1.2065439672801638e-05, 'epoch': 3.19}
{'loss': 0.0291, 'grad_norm': 0.8159292340278625, 'learning_rate': 1.1656441717791411e-05, 'epoch': 3.25}
{'loss': 0.0238, 'grad_norm': 4.220809459686279, 'learning_rate': 1.1247443762781187e-05, 'epoch': 3.31}
{'loss': 0.0354, 'grad_norm': 0.48322027921676636, 'learning_rate': 1.0838445807770962e-05, 'epoch': 3.37}
{'loss': 0.032,

  0%|          | 0/116 [00:00<?, ?it/s]

{'eval_loss': 0.03816292807459831, 'eval_f1': 0.9650122050447518, 'eval_precision': 0.9518459069020867, 'eval_recall': 0.9785478547854786, 'eval_runtime': 37.7324, 'eval_samples_per_second': 97.688, 'eval_steps_per_second': 3.074, 'epoch': 4.0}
{'loss': 0.0227, 'grad_norm': 0.1687212437391281, 'learning_rate': 6.339468302658486e-06, 'epoch': 4.04}
{'loss': 0.021, 'grad_norm': 1.6843472719192505, 'learning_rate': 5.930470347648262e-06, 'epoch': 4.1}
{'loss': 0.0201, 'grad_norm': 0.06507040560245514, 'learning_rate': 5.521472392638037e-06, 'epoch': 4.17}
{'loss': 0.0187, 'grad_norm': 0.188401997089386, 'learning_rate': 5.112474437627812e-06, 'epoch': 4.23}
{'loss': 0.0187, 'grad_norm': 2.627936363220215, 'learning_rate': 4.703476482617587e-06, 'epoch': 4.29}
{'loss': 0.0202, 'grad_norm': 7.665559768676758, 'learning_rate': 4.294478527607362e-06, 'epoch': 4.35}
{'loss': 0.0118, 'grad_norm': 0.4879707396030426, 'learning_rate': 3.885480572597137e-06, 'epoch': 4.41}
{'loss': 0.0084, 'grad_n

  0%|          | 0/116 [00:00<?, ?it/s]

{'eval_loss': 0.03812766447663307, 'eval_f1': 0.9656862745098039, 'eval_precision': 0.9563106796116505, 'eval_recall': 0.9752475247524752, 'eval_runtime': 38.4128, 'eval_samples_per_second': 95.958, 'eval_steps_per_second': 3.02, 'epoch': 4.99}
{'train_runtime': 2859.5456, 'train_samples_per_second': 36.522, 'train_steps_per_second': 0.57, 'train_loss': 0.06734145838607308, 'epoch': 4.99}


TrainOutput(global_step=1630, training_loss=0.06734145838607308, metrics={'train_runtime': 2859.5456, 'train_samples_per_second': 36.522, 'train_steps_per_second': 0.57, 'total_flos': 8.081174644968112e+18, 'train_loss': 0.06734145838607308, 'epoch': 4.992343032159265})

In [25]:
trainer.push_to_hub(token=os.getenv("HF_WRITE"))

CommitInfo(commit_url='https://huggingface.co/SABR22/ViT-threat-classification-v2/commit/32dfeebba898d14ce27b579ec3a8ac34903fc867', commit_message='End of training', commit_description='', oid='32dfeebba898d14ce27b579ec3a8ac34903fc867', pr_url=None, repo_url=RepoUrl('https://huggingface.co/SABR22/ViT-threat-classification-v2', endpoint='https://huggingface.co', repo_type='model', repo_id='SABR22/ViT-threat-classification-v2'), pr_revision=None, pr_num=None)