In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install -q datasets transformers
!pip install transformers[torch]
!pip install accelerate -U



In [3]:
from huggingface_hub import notebook_login, login

login("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")
# notebook_login()

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
from datasets import load_dataset
from datasets import load_metric
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import AutoImageProcessor
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
import torch
import numpy as np
from PIL import Image
import cv2
import os
from torchvision.transforms import (
    RandomRotation,
    ColorJitter,
    GaussianBlur,
    RandomAffine,
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

In [5]:
dataset = load_dataset("imagefolder", data_dir="/content/drive/MyDrive/Datasets/hf_head_dataset")
id2label = {0 : "Normal", 1 : "Abnormal"}
label2id = {"Normal" : 0, "Abnormal" : 1}
# model_checkpoint = "microsoft/beit-base-patch16-224"
# model_checkpoint = "microsoft/dit-base-finetuned-rvlcdip"
# model_checkpoint = "Zetatech/pvt-tiny-224"
# model_checkpoint = "MBZUAI/swiftformer-xs"
# model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
model_checkpoint = "google/vit-base-patch16-224"
batch_size = 64
image_processor  = AutoImageProcessor.from_pretrained(model_checkpoint)

Resolving data files:   0%|          | 0/1000 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/240 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/300 [00:00<?, ?it/s]

In [6]:
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")

train_transforms = Compose(
        [
            Resize(size),
            # RandomResizedCrop(crop_size),
            RandomHorizontalFlip(),
            # RandomRotation(degrees=30),
            ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.2),
            # GaussianBlur(kernel_size=7),
            RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=5),
            ToTensor(),
            normalize,
        ]
    )

val_transforms = Compose(
        [
            Resize(size),
            ToTensor(),
            normalize,
        ]
    )

def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch

def preprocess_val(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

In [7]:
train_ds = dataset["train"]
val_ds = dataset["validation"]
test_ds = dataset["test"]
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)
test_ds.set_transform(preprocess_val)

In [8]:
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes = True,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
model_name = model_checkpoint.split("/")[-1]

args = TrainingArguments(
    f"{model_name}",
    remove_unused_columns=False,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=30,
    warmup_ratio=0.1,
    logging_steps=30,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)

metric = load_metric("accuracy")

  metric = load_metric("accuracy")


In [10]:
# def compute_metrics(eval_pred):
#     """Computes accuracy on a batch of predictions"""
#     predictions = np.argmax(eval_pred.predictions, axis=1)
#     return metric.compute(predictions=predictions, references=eval_pred.label_ids)

def compute_metrics(eval_pred):
    """Computes accuracy, precision, recall, and F1-score on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    labels = eval_pred.label_ids

    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, average='weighted')
    recall = recall_score(labels, predictions, average='weighted')
    f1 = f1_score(labels, predictions, average='weighted')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [11]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

In [12]:
train_results = trainer.train()
print(train_results)
test_results = trainer.evaluate(test_ds)
print(test_results)
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1 Score
1,No log,0.584086,0.733333,0.677011,0.733333,0.64793
2,No log,0.572685,0.733333,0.537778,0.733333,0.620513
3,No log,0.608915,0.720833,0.722227,0.720833,0.721518
4,No log,0.533219,0.745833,0.720498,0.745833,0.672658
5,No log,0.531422,0.7625,0.741013,0.7625,0.741552
6,No log,0.528419,0.758333,0.748588,0.758333,0.695905
7,No log,0.521981,0.775,0.77,0.775,0.728571
8,0.556400,0.520447,0.783333,0.774,0.783333,0.748148
9,0.556400,0.504403,0.770833,0.761638,0.770833,0.765024
10,0.556400,0.484481,0.8125,0.805086,0.8125,0.794117


  _warn_prf(average, modifier, msg_start, len(result))


TrainOutput(global_step=120, training_loss=0.29066893259684246, metrics={'train_runtime': 1718.9255, 'train_samples_per_second': 17.453, 'train_steps_per_second': 0.07, 'total_flos': 2.32475968843776e+18, 'train_loss': 0.29066893259684246, 'epoch': 30.0})


{'eval_loss': 0.4410024881362915, 'eval_accuracy': 0.8566666666666667, 'eval_precision': 0.8522571872571872, 'eval_recall': 0.8566666666666667, 'eval_f1_score': 0.8517268099292696, 'eval_runtime': 209.1833, 'eval_samples_per_second': 1.434, 'eval_steps_per_second': 0.024, 'epoch': 30.0}
***** train metrics *****
  epoch                    =         30.0
  total_flos               = 2165101178GF
  train_loss               =       0.2907
  train_runtime            =   0:28:38.92
  train_samples_per_second =       17.453
  train_steps_per_second   =         0.07


In [13]:
trainer.predict(test_ds)

PredictionOutput(predictions=array([[ 3.3304825e+00, -2.5004392e+00],
       [ 2.2165306e+00, -1.5902930e+00],
       [ 2.9848788e+00, -1.5227240e+00],
       [ 2.6635427e+00, -2.5224636e+00],
       [ 1.1937218e+00, -7.5723988e-01],
       [ 3.2451699e+00, -3.1203432e+00],
       [ 2.9070003e+00, -2.8376637e+00],
       [ 8.9994067e-01, -1.9225847e+00],
       [-8.9169294e-02, -1.9362932e-01],
       [ 8.5416210e-01, -9.4588429e-01],
       [ 2.0232575e+00, -1.5426316e+00],
       [ 1.3317599e+00, -1.3865502e+00],
       [ 1.4092352e+00, -8.1956685e-01],
       [ 2.1408925e+00, -1.2749697e+00],
       [ 1.1710196e+00, -4.7222561e-01],
       [ 1.6628503e+00, -2.8681242e+00],
       [ 3.4437612e-01, -1.2054706e-02],
       [ 2.4297497e+00, -2.1347966e+00],
       [ 2.9887469e+00, -3.0184093e+00],
       [ 1.7454376e+00, -7.8144008e-01],
       [ 2.0842812e+00, -2.1996326e+00],
       [ 2.9874361e+00, -2.6103559e+00],
       [ 1.9342282e+00, -2.2506452e+00],
       [ 4.4279230e-01, -1.9