In [1]:
# Hugging Face Hub checkpoint used for both the ViT base model and the image processor.
model_checkpoint = "google/vit-base-patch16-224-in21k"
model_checkpoint  # echo the checkpoint name as the cell output

'google/vit-base-patch16-224-in21k'

In [2]:
import os 
import torch
from peft import PeftModel, LoraConfig, get_peft_model
from transformers import AutoModelForImageClassification


def print_model_size(path):
    """Print the combined size, in megabytes (10**6 bytes), of the files in ``path``.

    Only regular files directly inside ``path`` are counted; sub-directories
    are neither descended into nor counted.

    Args:
        path: directory to measure (e.g. a saved adapter directory).
    """
    total_bytes = 0
    # ``with`` closes the scandir iterator (and its directory handle) promptly.
    with os.scandir(path) as entries:
        for entry in entries:
            # Skip directory entries: their stat size is filesystem metadata,
            # not the size of their contents.
            if entry.is_file():
                total_bytes += entry.stat().st_size

    # Fixed-point formatting. A bare ``.2`` spec means "2 significant digits"
    # and switches to scientific notation (e.g. "1.2e+01") for sizes >= 10 MB.
    print(f"Model size: {total_bytes / 1e6:.2f} MB")


def print_trainable_parameters(model, label):
    """Print how many of ``model``'s parameters are trainable.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs (e.g. a ``torch.nn.Module``).
        label: prefix identifying the model in the printed line.
    """
    total = trainable = 0
    for _, param in model.named_parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    # A parameter-less model would otherwise raise ZeroDivisionError.
    share = 100 * trainable / total if total else 0.0
    print(f"{label} trainable parameters: {trainable:,}/{total:,} ({share:.2f}%)")


def split_dataset(dataset, test_size=0.1, seed=None):
    """Split ``dataset`` into a train and a test subset.

    Args:
        dataset: object exposing ``train_test_split`` (a Hugging Face
            ``datasets.Dataset`` in this notebook).
        test_size: fraction of rows held out for the test split.
        seed: optional RNG seed forwarded to ``train_test_split``; pass an int
            to make the split reproducible across kernel restarts.

    Returns:
        A ``(train, test)`` tuple.
    """
    splits = dataset.train_test_split(test_size=test_size, seed=seed)
    # Index by key instead of unpacking ``.values()`` so the returned order
    # never depends on the mapping's insertion order.
    return splits["train"], splits["test"]
    

def create_label_mappings(dataset):
    """Build ``(label2id, id2label)`` lookup dicts for a dataset's label column.

    The class names come from ``dataset.features["label"].names``; each name's
    position in that list is used as its integer id.
    """
    class_names = dataset.features["label"].names
    id2label = dict(enumerate(class_names))
    label2id = {name: idx for idx, name in id2label.items()}
    return label2id, id2label

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from datasets import load_dataset

# Food-101 images: only the first 10,000 rows of the "train" split are loaded.
dataset1 = load_dataset("food101", split="train[:10000]")

# Optional second dataset of cat and dog pictures. Its label column is named
# "labels", so it is renamed to "label" so the same helpers work for both datasets.
# dataset2 = load_dataset("microsoft/cats_vs_dogs", split="train", trust_remote_code=True)
# dataset2 = dataset2.rename_column("labels", "label")

# Hold out 10% of each dataset for evaluation (see split_dataset).
dataset1_train, dataset1_test = split_dataset(dataset1)
# dataset2_train, dataset2_test = split_dataset(dataset2)

Downloading readme: 100%|██████████| 10.5k/10.5k [00:00<00:00, 22.8MB/s]
Downloading data: 100%|██████████| 490M/490M [00:43<00:00, 11.1MB/s] 
Downloading data: 100%|██████████| 464M/464M [00:38<00:00, 12.0MB/s] 
Downloading data: 100%|██████████| 472M/472M [00:39<00:00, 12.1MB/s] 
Downloading data: 100%|██████████| 464M/464M [00:38<00:00, 12.1MB/s] 
Downloading data: 100%|██████████| 475M/475M [00:43<00:00, 11.0MB/s] 
Downloading data: 100%|██████████| 470M/470M [00:42<00:00, 11.2MB/s] 
Downloading data: 100%|██████████| 478M/478M [00:39<00:00, 12.0MB/s] 
Downloading data: 100%|██████████| 486M/486M [00:40<00:00, 11.9MB/s] 
Downloading data: 100%|██████████| 423M/423M [00:34<00:00, 12.1MB/s] 
Downloading data: 100%|██████████| 413M/413M [00:38<00:00, 10.9MB/s] 
Downloading data: 100%|██████████| 426M/426M [00:35<00:00, 12.1MB/s] 
Generating train split: 100%|██████████| 75750/75750 [00:02<00:00, 27537.62 examples/s]
Generating validation split: 100%|██████████| 25250/25250 [00:00<00:0

In [4]:
# Label <-> id lookup tables taken from the label column of the full (unsplit) dataset.
dataset1_label2id, dataset1_id2label = create_label_mappings(dataset1)
# dataset2_label2id, dataset2_id2label = create_label_mappings(dataset2)

In [8]:
# Everything needed to fine-tune, evaluate, save, and later reload the food model.
model1 = dict(
    train_data=dataset1_train,
    test_data=dataset1_test,
    label2id=dataset1_label2id,
    id2label=dataset1_id2label,
    epochs=5,
    path="./lora-model1",  # directory the fine-tuned adapter is saved to
)

In [10]:
from transformers import AutoImageProcessor

# Image processor shipped with the checkpoint; its size/mean/std settings drive the
# torchvision preprocessing pipeline. ``use_fast=True`` requests the fast implementation.
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint, use_fast=True)

In [13]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    Resize,
    ToTensor,
)

# Image transform for training and evaluation: resize the shorter side to the
# processor's target height, center-crop a square of that size, convert to a
# tensor, and normalise with the checkpoint's mean/std.
_target_side = image_processor.size["height"]
preprocess_pipeline = Compose(
    [
        Resize(_target_side),
        CenterCrop(_target_side),
        ToTensor(),
        Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
    ]
)

def preprocess(batch):
    """Add a ``"pixel_values"`` list to ``batch`` and return the same batch.

    Each entry of ``batch["image"]`` (presumably a PIL image; it must support
    ``.convert("RGB")``) is converted to RGB and then passed through
    ``preprocess_pipeline``. The batch dict is modified in place.
    """
    rgb_images = (picture.convert("RGB") for picture in batch["image"])
    batch["pixel_values"] = list(map(preprocess_pipeline, rgb_images))
    return batch

In [14]:
# Register ``preprocess`` as an on-the-fly transform on both splits, so pixel
# tensors are produced only when rows are read.
for split_key in ("train_data", "test_data"):
    model1[split_key].set_transform(preprocess)

In [16]:
import numpy as np
import evaluate
import torch
from peft import PeftModel, LoraConfig, get_peft_model
from transformers import AutoModelForImageClassification


# Accuracy metric from the ``evaluate`` library; consumed by ``compute_metrics``.
metric = evaluate.load("accuracy")


def data_collate(examples):
    """Collate a list of preprocessed examples into one model-ready batch.

    Each example must provide a ``"pixel_values"`` tensor and an integer
    ``"label"``. The result holds the pixel tensors stacked along a new leading
    dimension, plus the labels as a tensor under the key ``"labels"``.
    """
    pixel_batch, label_list = [], []
    for example in examples:
        pixel_batch.append(example["pixel_values"])
        label_list.append(example["label"])
    return {
        "pixel_values": torch.stack(pixel_batch),
        "labels": torch.tensor(label_list),
    }


def compute_metrics(eval_pred):
    """Return the accuracy of one evaluation pass as the ``metric`` dict.

    ``eval_pred.predictions`` holds raw scores, presumably shaped
    (num_examples, num_classes); the predicted class is the argmax over
    axis 1. ``eval_pred.label_ids`` holds the reference labels.
    """
    logits = eval_pred.predictions
    references = eval_pred.label_ids
    predicted_class_ids = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_class_ids, references=references)


def get_base_model(label2id, id2label):
    """Load ``model_checkpoint`` as an image-classification model.

    The label mappings define the classification head; with
    ``ignore_mismatched_sizes=True``, loading proceeds even when checkpoint
    weights do not match the shapes of that head.
    """
    head_kwargs = {
        "label2id": label2id,
        "id2label": id2label,
        "ignore_mismatched_sizes": True,
    }
    return AutoModelForImageClassification.from_pretrained(model_checkpoint, **head_kwargs)


def build_lora_model(
    label2id,
    id2label,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=None,
):
    """Wrap the base model with LoRA adapters for parameter-efficient fine-tuning.

    Prints the trainable-parameter count before and after wrapping.

    Args:
        label2id: label-name -> id mapping forwarded to ``get_base_model``.
        id2label: id -> label-name mapping forwarded to ``get_base_model``.
        r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: dropout probability used inside the LoRA layers.
        target_modules: names of the sub-modules that receive LoRA adapters;
            ``None`` means ``["query", "value"]``.

    Returns:
        The PEFT model returned by ``get_peft_model``.
    """
    model = get_base_model(label2id, id2label)
    print_trainable_parameters(model, label="Base model")

    if target_modules is None:
        target_modules = ["query", "value"]

    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=list(target_modules),
        lora_dropout=lora_dropout,
        bias="none",
        # The classifier head is newly initialised rather than adapted by LoRA,
        # so it is trained in full and saved together with the adapter.
        modules_to_save=["classifier"],
    )

    lora_model = get_peft_model(model, config)
    print_trainable_parameters(lora_model, label="LoRA")

    return lora_model


In [25]:
from transformers import TrainingArguments

batch_size = 16
training_arguments = TrainingArguments(
    output_dir="./model-checkpoints",
    # Keep every dataset column: the on-the-fly transform reads "image" and the
    # collator reads "label", neither of which is a model forward() argument.
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # 16 x 4 = 64 examples per optimizer step (per device).
    gradient_accumulation_steps=4,
    # Train on a CUDA GPU when one is present; otherwise fall back to the CPU.
    use_cpu=not torch.cuda.is_available(),
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    label_names=["labels"],
)

In [26]:
from transformers import Trainer

training_arguments.num_train_epochs = model1["epochs"]

# Fresh LoRA-wrapped classifier for the food dataset (prints base vs. LoRA trainable counts).
lora_model = build_lora_model(model1["label2id"], model1["id2label"])

trainer = Trainer(
    model=lora_model,
    args=training_arguments,
    train_dataset=model1["train_data"],
    eval_dataset=model1["test_data"],
    # Handing the image processor to the Trainer makes ``save_model`` write it
    # next to the adapter, so it can be reloaded from model1["path"].
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=data_collate,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Base model trainable parameters: 85,876,325/85,876,325 (100.00%)
LoRA trainable parameters: 667,493/86,543,818 (0.77%)


In [None]:
from transformers import Trainer

# Experiment registry: every value carries the same keys as ``model1`` (train/test
# data, label mappings, epochs, save path). Iterate over this explicit dict;
# ``config`` exists only as a local inside ``build_lora_model``, so looping over it
# at notebook level raises NameError. Note that running this cell retrains model1.
model_configs = {"model1": model1}

for name, cfg in model_configs.items():
    training_arguments.num_train_epochs = cfg["epochs"]

    trainer = Trainer(
        build_lora_model(cfg["label2id"], cfg["id2label"]),
        training_arguments,
        train_dataset=cfg["train_data"],
        eval_dataset=cfg["test_data"],
        tokenizer=image_processor,
        compute_metrics=compute_metrics,
        data_collator=data_collate,
    )

    results = trainer.train()
    evaluation_results = trainer.evaluate(cfg["test_data"])
    print(f"{name} evaluation accuracy: {evaluation_results['eval_accuracy']}")

    # Save the fine-tuned adapter for this configuration and report its size.
    trainer.save_model(cfg["path"])
    print_model_size(cfg["path"])

In [27]:
results = trainer.train()

# Evaluate the model the trainer ends up holding (``load_best_model_at_end=True``
# selects the best checkpoint by accuracy) on the held-out split.
evaluation_results = trainer.evaluate(eval_dataset=model1["test_data"])
accuracy = evaluation_results["eval_accuracy"]
print(f"Evaluation accuracy: {accuracy}")

# Persist the fine-tuned model to disk and report its on-disk size.
saved_path = model1["path"]
trainer.save_model(saved_path)
print_model_size(saved_path)

  1%|▏         | 10/700 [01:22<1:32:34,  8.05s/it]

{'loss': 2.4884, 'grad_norm': 0.7463366389274597, 'learning_rate': 0.004928571428571429, 'epoch': 0.07}


  3%|▎         | 20/700 [02:42<1:30:46,  8.01s/it]

{'loss': 0.4479, 'grad_norm': 0.3320975601673126, 'learning_rate': 0.004857142857142858, 'epoch': 0.14}


  4%|▍         | 30/700 [04:03<1:29:51,  8.05s/it]

{'loss': 0.3172, 'grad_norm': 0.7150269746780396, 'learning_rate': 0.004785714285714286, 'epoch': 0.21}


  6%|▌         | 40/700 [05:23<1:28:50,  8.08s/it]

{'loss': 0.2234, 'grad_norm': 0.6246882081031799, 'learning_rate': 0.004714285714285714, 'epoch': 0.28}


  7%|▋         | 50/700 [06:44<1:26:29,  7.98s/it]

{'loss': 0.2972, 'grad_norm': 0.6688875555992126, 'learning_rate': 0.004642857142857143, 'epoch': 0.36}


  9%|▊         | 60/700 [08:03<1:23:58,  7.87s/it]

{'loss': 0.2321, 'grad_norm': 0.6568211913108826, 'learning_rate': 0.004571428571428572, 'epoch': 0.43}


 10%|█         | 70/700 [09:23<1:22:24,  7.85s/it]

{'loss': 0.2576, 'grad_norm': 1.062619924545288, 'learning_rate': 0.0045000000000000005, 'epoch': 0.5}


 11%|█▏        | 80/700 [10:42<1:21:40,  7.90s/it]

{'loss': 0.2315, 'grad_norm': 0.6849544048309326, 'learning_rate': 0.004428571428571428, 'epoch': 0.57}


 13%|█▎        | 90/700 [12:02<1:20:38,  7.93s/it]

{'loss': 0.2553, 'grad_norm': 1.1308133602142334, 'learning_rate': 0.004357142857142857, 'epoch': 0.64}


 14%|█▍        | 100/700 [13:21<1:18:24,  7.84s/it]

{'loss': 0.2589, 'grad_norm': 0.44937172532081604, 'learning_rate': 0.004285714285714286, 'epoch': 0.71}


 16%|█▌        | 110/700 [14:40<1:18:23,  7.97s/it]

{'loss': 0.2284, 'grad_norm': 0.6282328963279724, 'learning_rate': 0.004214285714285715, 'epoch': 0.78}


 17%|█▋        | 120/700 [15:59<1:16:08,  7.88s/it]

{'loss': 0.2339, 'grad_norm': 0.9051182270050049, 'learning_rate': 0.0041428571428571434, 'epoch': 0.85}


 19%|█▊        | 130/700 [17:19<1:15:10,  7.91s/it]

{'loss': 0.3136, 'grad_norm': 1.1355319023132324, 'learning_rate': 0.004071428571428571, 'epoch': 0.92}


 20%|██        | 140/700 [18:37<1:13:31,  7.88s/it]

{'loss': 0.3191, 'grad_norm': 0.7820829153060913, 'learning_rate': 0.004, 'epoch': 0.99}


                                                   
 20%|██        | 140/700 [19:31<1:13:31,  7.88s/it]

{'eval_loss': 0.25439518690109253, 'eval_accuracy': 0.924, 'eval_runtime': 48.1668, 'eval_samples_per_second': 20.761, 'eval_steps_per_second': 1.308, 'epoch': 0.99}


 21%|██▏       | 150/700 [20:48<1:22:13,  8.97s/it]

{'loss': 0.1766, 'grad_norm': 0.7268509864807129, 'learning_rate': 0.003928571428571429, 'epoch': 1.07}


 23%|██▎       | 160/700 [22:11<1:14:24,  8.27s/it]

{'loss': 0.1532, 'grad_norm': 0.17718355357646942, 'learning_rate': 0.0038571428571428576, 'epoch': 1.14}


 24%|██▍       | 170/700 [23:31<1:11:29,  8.09s/it]

{'loss': 0.1621, 'grad_norm': 0.9461444616317749, 'learning_rate': 0.0037857142857142855, 'epoch': 1.21}


 26%|██▌       | 180/700 [24:51<1:09:50,  8.06s/it]

{'loss': 0.1494, 'grad_norm': 0.5911280512809753, 'learning_rate': 0.0037142857142857147, 'epoch': 1.28}


 27%|██▋       | 190/700 [26:14<1:09:02,  8.12s/it]

{'loss': 0.148, 'grad_norm': 0.2960309386253357, 'learning_rate': 0.0036428571428571426, 'epoch': 1.35}


 29%|██▊       | 200/700 [27:33<1:06:54,  8.03s/it]

{'loss': 0.1304, 'grad_norm': 0.6738382577896118, 'learning_rate': 0.0035714285714285718, 'epoch': 1.42}


 30%|███       | 210/700 [28:56<1:06:46,  8.18s/it]

{'loss': 0.133, 'grad_norm': 0.361554890871048, 'learning_rate': 0.0034999999999999996, 'epoch': 1.49}


 31%|███▏      | 220/700 [30:20<1:07:45,  8.47s/it]

{'loss': 0.1476, 'grad_norm': 0.5292849540710449, 'learning_rate': 0.003428571428571429, 'epoch': 1.56}


 33%|███▎      | 230/700 [31:42<1:04:29,  8.23s/it]

{'loss': 0.1382, 'grad_norm': 0.6103649139404297, 'learning_rate': 0.003357142857142857, 'epoch': 1.63}


 34%|███▍      | 240/700 [33:03<1:02:28,  8.15s/it]

{'loss': 0.2002, 'grad_norm': 0.49247828125953674, 'learning_rate': 0.003285714285714286, 'epoch': 1.71}


 36%|███▌      | 250/700 [34:24<1:00:03,  8.01s/it]

{'loss': 0.111, 'grad_norm': 0.3469056189060211, 'learning_rate': 0.0032142857142857147, 'epoch': 1.78}


 37%|███▋      | 260/700 [35:47<1:01:44,  8.42s/it]

{'loss': 0.1803, 'grad_norm': 0.756784200668335, 'learning_rate': 0.003142857142857143, 'epoch': 1.85}


 39%|███▊      | 270/700 [37:08<58:36,  8.18s/it]  

{'loss': 0.1699, 'grad_norm': 0.2297605276107788, 'learning_rate': 0.0030714285714285717, 'epoch': 1.92}


 40%|████      | 280/700 [38:30<56:39,  8.09s/it]

{'loss': 0.1413, 'grad_norm': 0.3629840612411499, 'learning_rate': 0.003, 'epoch': 1.99}


                                                 
 40%|████      | 281/700 [39:33<56:19,  8.06s/it]

{'eval_loss': 0.2305273860692978, 'eval_accuracy': 0.929, 'eval_runtime': 51.8342, 'eval_samples_per_second': 19.292, 'eval_steps_per_second': 1.215, 'epoch': 2.0}


 41%|████▏     | 290/700 [40:45<1:03:27,  9.29s/it]

{'loss': 0.0555, 'grad_norm': 0.3083096444606781, 'learning_rate': 0.002928571428571429, 'epoch': 2.06}


 43%|████▎     | 300/700 [42:08<55:56,  8.39s/it]  

{'loss': 0.0448, 'grad_norm': 0.20777277648448944, 'learning_rate': 0.002857142857142857, 'epoch': 2.13}


 44%|████▍     | 310/700 [43:34<57:34,  8.86s/it]

{'loss': 0.0516, 'grad_norm': 0.07422656565904617, 'learning_rate': 0.002785714285714286, 'epoch': 2.2}


 46%|████▌     | 320/700 [44:59<54:20,  8.58s/it]

{'loss': 0.0527, 'grad_norm': 0.40222257375717163, 'learning_rate': 0.0027142857142857142, 'epoch': 2.27}


 47%|████▋     | 330/700 [46:23<52:43,  8.55s/it]

{'loss': 0.0452, 'grad_norm': 0.47378888726234436, 'learning_rate': 0.002642857142857143, 'epoch': 2.34}


 49%|████▊     | 340/700 [47:46<48:47,  8.13s/it]

{'loss': 0.0538, 'grad_norm': 0.5559144616127014, 'learning_rate': 0.0025714285714285713, 'epoch': 2.42}


 50%|█████     | 350/700 [49:01<43:40,  7.49s/it]

{'loss': 0.0507, 'grad_norm': 0.36676421761512756, 'learning_rate': 0.0025, 'epoch': 2.49}


 51%|█████▏    | 360/700 [50:15<42:05,  7.43s/it]

{'loss': 0.0402, 'grad_norm': 0.41433799266815186, 'learning_rate': 0.002428571428571429, 'epoch': 2.56}


 53%|█████▎    | 370/700 [51:30<40:47,  7.42s/it]

{'loss': 0.0504, 'grad_norm': 0.9752363562583923, 'learning_rate': 0.002357142857142857, 'epoch': 2.63}


 54%|█████▍    | 380/700 [52:44<39:12,  7.35s/it]

{'loss': 0.0646, 'grad_norm': 0.9773350358009338, 'learning_rate': 0.002285714285714286, 'epoch': 2.7}


 56%|█████▌    | 390/700 [53:58<38:05,  7.37s/it]

{'loss': 0.0579, 'grad_norm': 0.5188830494880676, 'learning_rate': 0.002214285714285714, 'epoch': 2.77}


 57%|█████▋    | 400/700 [55:13<37:39,  7.53s/it]

{'loss': 0.0707, 'grad_norm': 0.6139543056488037, 'learning_rate': 0.002142857142857143, 'epoch': 2.84}


 59%|█████▊    | 410/700 [56:29<36:53,  7.63s/it]

{'loss': 0.0516, 'grad_norm': 0.5116169452667236, 'learning_rate': 0.0020714285714285717, 'epoch': 2.91}


 60%|██████    | 420/700 [57:48<36:14,  7.77s/it]

{'loss': 0.0521, 'grad_norm': 0.4523029029369354, 'learning_rate': 0.002, 'epoch': 2.98}


                                                 
 60%|██████    | 422/700 [58:49<35:57,  7.76s/it]

{'eval_loss': 0.2218172401189804, 'eval_accuracy': 0.937, 'eval_runtime': 45.3382, 'eval_samples_per_second': 22.056, 'eval_steps_per_second': 1.39, 'epoch': 3.0}


 61%|██████▏   | 430/700 [1:12:59<9:23:05, 125.13s/it] 

{'loss': 0.0158, 'grad_norm': 0.12725329399108887, 'learning_rate': 0.0019285714285714288, 'epoch': 3.06}


 63%|██████▎   | 440/700 [1:14:20<48:50, 11.27s/it]   

{'loss': 0.018, 'grad_norm': 0.0740182176232338, 'learning_rate': 0.0018571428571428573, 'epoch': 3.13}


 64%|██████▍   | 450/700 [1:15:38<33:10,  7.96s/it]

{'loss': 0.0089, 'grad_norm': 0.0729297548532486, 'learning_rate': 0.0017857142857142859, 'epoch': 3.2}


 66%|██████▌   | 460/700 [1:16:58<31:44,  7.94s/it]

{'loss': 0.0082, 'grad_norm': 0.05495526269078255, 'learning_rate': 0.0017142857142857144, 'epoch': 3.27}


 67%|██████▋   | 470/700 [1:18:17<30:19,  7.91s/it]

{'loss': 0.0051, 'grad_norm': 0.06741288304328918, 'learning_rate': 0.001642857142857143, 'epoch': 3.34}


 69%|██████▊   | 480/700 [1:19:38<29:43,  8.11s/it]

{'loss': 0.01, 'grad_norm': 0.017896447330713272, 'learning_rate': 0.0015714285714285715, 'epoch': 3.41}


 70%|███████   | 490/700 [1:20:55<26:10,  7.48s/it]

{'loss': 0.0111, 'grad_norm': 0.037134576588869095, 'learning_rate': 0.0015, 'epoch': 3.48}


 71%|███████▏  | 500/700 [1:22:09<24:48,  7.44s/it]

{'loss': 0.0183, 'grad_norm': 0.033515818417072296, 'learning_rate': 0.0014285714285714286, 'epoch': 3.55}


 73%|███████▎  | 510/700 [1:23:23<23:33,  7.44s/it]

{'loss': 0.0114, 'grad_norm': 0.14146484434604645, 'learning_rate': 0.0013571428571428571, 'epoch': 3.62}


 74%|███████▍  | 520/700 [1:24:38<22:24,  7.47s/it]

{'loss': 0.0078, 'grad_norm': 0.0054623158648610115, 'learning_rate': 0.0012857142857142856, 'epoch': 3.69}


 76%|███████▌  | 530/700 [1:25:55<21:58,  7.76s/it]

{'loss': 0.0103, 'grad_norm': 0.1460408866405487, 'learning_rate': 0.0012142857142857144, 'epoch': 3.77}


 77%|███████▋  | 540/700 [1:27:14<20:51,  7.82s/it]

{'loss': 0.0141, 'grad_norm': 0.0630146712064743, 'learning_rate': 0.001142857142857143, 'epoch': 3.84}


 79%|███████▊  | 550/700 [1:28:32<19:31,  7.81s/it]

{'loss': 0.0124, 'grad_norm': 0.41994595527648926, 'learning_rate': 0.0010714285714285715, 'epoch': 3.91}


 80%|████████  | 560/700 [1:29:51<18:20,  7.86s/it]

{'loss': 0.0196, 'grad_norm': 0.08102837204933167, 'learning_rate': 0.001, 'epoch': 3.98}


                                                   


{'eval_loss': 0.20818603038787842, 'eval_accuracy': 0.94, 'eval_runtime': 46.1359, 'eval_samples_per_second': 21.675, 'eval_steps_per_second': 1.366, 'epoch': 4.0}


 81%|████████▏ | 570/700 [1:31:55<20:38,  9.52s/it]

{'loss': 0.0064, 'grad_norm': 0.05691816657781601, 'learning_rate': 0.0009285714285714287, 'epoch': 4.05}


 83%|████████▎ | 580/700 [1:33:15<15:52,  7.94s/it]

{'loss': 0.0033, 'grad_norm': 0.030334725975990295, 'learning_rate': 0.0008571428571428572, 'epoch': 4.12}


 84%|████████▍ | 590/700 [1:34:34<14:40,  8.01s/it]

{'loss': 0.0036, 'grad_norm': 0.005318993702530861, 'learning_rate': 0.0007857142857142857, 'epoch': 4.19}


 86%|████████▌ | 600/700 [1:36:02<14:52,  8.92s/it]

{'loss': 0.0027, 'grad_norm': 0.014028681442141533, 'learning_rate': 0.0007142857142857143, 'epoch': 4.26}


 87%|████████▋ | 610/700 [1:37:30<13:18,  8.88s/it]

{'loss': 0.0038, 'grad_norm': 0.010969904251396656, 'learning_rate': 0.0006428571428571428, 'epoch': 4.33}


 89%|████████▊ | 620/700 [1:38:52<10:52,  8.15s/it]

{'loss': 0.0047, 'grad_norm': 0.07930019497871399, 'learning_rate': 0.0005714285714285715, 'epoch': 4.4}


 90%|█████████ | 630/700 [1:40:12<09:18,  7.98s/it]

{'loss': 0.0033, 'grad_norm': 0.021491754800081253, 'learning_rate': 0.0005, 'epoch': 4.48}


 91%|█████████▏| 640/700 [1:41:33<08:02,  8.04s/it]

{'loss': 0.0035, 'grad_norm': 0.06546765565872192, 'learning_rate': 0.0004285714285714286, 'epoch': 4.55}


 93%|█████████▎| 650/700 [1:42:53<06:41,  8.04s/it]

{'loss': 0.0026, 'grad_norm': 0.005614639725536108, 'learning_rate': 0.00035714285714285714, 'epoch': 4.62}


 94%|█████████▍| 660/700 [1:44:15<05:30,  8.26s/it]

{'loss': 0.0034, 'grad_norm': 0.02387746423482895, 'learning_rate': 0.00028571428571428574, 'epoch': 4.69}


 96%|█████████▌| 670/700 [1:45:41<04:16,  8.54s/it]

{'loss': 0.0024, 'grad_norm': 0.03190445527434349, 'learning_rate': 0.0002142857142857143, 'epoch': 4.76}


 97%|█████████▋| 680/700 [1:47:04<02:48,  8.43s/it]

{'loss': 0.0097, 'grad_norm': 0.016122493892908096, 'learning_rate': 0.00014285714285714287, 'epoch': 4.83}


 99%|█████████▊| 690/700 [1:48:28<01:25,  8.60s/it]

{'loss': 0.0026, 'grad_norm': 0.12231434881687164, 'learning_rate': 7.142857142857143e-05, 'epoch': 4.9}


100%|██████████| 700/700 [1:49:52<00:00,  8.45s/it]

{'loss': 0.0023, 'grad_norm': 0.03155376389622688, 'learning_rate': 0.0, 'epoch': 4.97}


                                                   
100%|██████████| 700/700 [1:50:43<00:00,  8.45s/it]

{'eval_loss': 0.20703613758087158, 'eval_accuracy': 0.946, 'eval_runtime': 50.3043, 'eval_samples_per_second': 19.879, 'eval_steps_per_second': 1.252, 'epoch': 4.97}


100%|██████████| 700/700 [1:50:43<00:00,  9.49s/it]


{'train_runtime': 6643.6853, 'train_samples_per_second': 6.773, 'train_steps_per_second': 0.105, 'train_loss': 0.13161656854408127, 'epoch': 4.97}


100%|██████████| 63/63 [00:48<00:00,  1.29it/s]


Evaluation accuracy: 0.946
Model size: 2.7 MB


In [28]:
def build_inference_model(label2id, id2label, lora_adapter_path):
    """Build the model used to run inference.

    Loads a fresh base model through ``get_base_model`` and attaches the
    fine-tuned LoRA adapter stored at ``lora_adapter_path``.
    """
    base_model = get_base_model(label2id, id2label)
    return PeftModel.from_pretrained(base_model, lora_adapter_path)


def predict(image, model, image_processor):
    """Return the class name the model assigns to a single image.

    Args:
        image: object with a PIL-style ``convert`` method; it is converted to
            RGB before processing.
        model: classifier called as ``model(**inputs)``; its output exposes
            ``.logits`` and ``model.config.id2label`` maps class ids to names.
        image_processor: callable turning the RGB image into PyTorch inputs.

    NOTE(review): training fed images through ``preprocess_pipeline``
    (resize + center-crop); confirm ``image_processor`` applies an equivalent
    transform so training and inference inputs match.
    """
    rgb_image = image.convert("RGB")
    model_inputs = image_processor(rgb_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**model_inputs).logits

    predicted_id = logits.argmax(-1).item()
    return model.config.id2label[predicted_id]


In [29]:
# Rebuild the fine-tuned classifier (base checkpoint + saved LoRA adapter) and
# reload the image processor from the same directory the trainer saved to.
model1["inference_model"] = build_inference_model(model1["label2id"], model1["id2label"], model1["path"]) 
model1["image_processor"] = AutoImageProcessor.from_pretrained(model1["path"])

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [30]:
# Sample photos downloaded to spot-check the fine-tuned model.
images = dict(
    image1="https://www.allrecipes.com/thmb/AtViolcfVtInHgq_mRtv4tPZASQ=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/ALR-187822-baked-chicken-wings-4x3-5c7b4624c8554f3da5aabb7d3a91a209.jpg",
    image2="https://www.simplyrecipes.com/thmb/KE6iMblr3R2Db6oE8HdyVsFSj2A=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/__opt__aboutcom__coeus__resources__content_migration__simply_recipes__uploads__2019__09__easy-pepperoni-pizza-lead-3-1024x682-583b275444104ef189d693a64df625da.jpg",
)

In [32]:
import io

from PIL import Image
import requests

# Download the sample image. The timeout stops a stalled connection from hanging
# the kernel, and raise_for_status() fails loudly instead of handing an HTTP error
# page to PIL. Reading ``.content`` (rather than the raw stream) also applies any
# Content-Encoding the server used.
response = requests.get(images["image2"], timeout=30)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content))

# Distinct names so this cell does not rebind the notebook-level ``image_processor``
# that the training cells use.
inference_model = model1["inference_model"]
inference_processor = model1["image_processor"]

prediction = predict(image, inference_model, inference_processor)
print(f"Prediction: {prediction}")

Prediction: pizza
