# This notebook demonstrates how to fine-tune a ViT model

## 1. Initial preparation

### 1.1 Install dependencies

In [None]:
! pip install datasets transformers accelerate matplotlib

### 1.2 Set up the HF token. A token with write access is needed if you want to push the fine-tuned model to the Hub; read access is enough to download the foundation/base ViT model from the HF Hub

In [None]:
from huggingface_hub import interpreter_login

# Ask for the Hugging Face access token interactively so that it is never
# hard-coded in the notebook.
interpreter_login()

### 1.3 Check device

In [None]:
import torch

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 2. Data preparation 

### 2.1 Load data from HF using the ``datasets`` library

In [None]:
from datasets import load_dataset

# Local directory passed to `load_dataset` as its cache location.
CACHE_DIR = './hf_cache'

# Load the medicinal-plant image dataset from the Hugging Face Hub; its
# `train` and `test` splits are used further down.
data = load_dataset("funkepal/medicinal_plant_images",cache_dir = CACHE_DIR)

In [None]:
# Print a summary of the loaded dataset object.
print(data)

## 3. Pre-processing Dataset

### 3.1 Load the image processor of the pre-trained model

In [None]:
from transformers import ViTImageProcessor

# Hub id of the base ViT checkpoint; reused later when loading the classifier.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# The image processor exposes the settings read by the transforms below
# (`size`, `image_mean`, `image_std`).
processor = ViTImageProcessor.from_pretrained(model_name_or_path)



In [None]:
# Inspect the processor's configuration.
print(processor)

In [None]:
def process_example(example):
    """Convert a single dataset example into model-ready tensors.

    The ViT image processor plays the role a tokenizer plays for text: it
    turns the image into numeric tensors the model can consume.

    Args:
        example: A dataset row holding an ``image`` entry and a ``label`` entry.

    Returns:
        The processor output (requested as PyTorch tensors) with the example's
        class id attached under the ``labels`` key.
    """
    inputs = processor(example['image'], return_tensors='pt')
    # The dataset column is called ``label`` (singular) -- the collate function
    # and the ``features['label']`` lookup in this notebook both rely on that --
    # so reading ``example['labels']`` here raised a KeyError.
    inputs['labels'] = example['label']
    return inputs

In [None]:
# ## This function accept batch of image and convert it to tensor
# def transform(example_batch):
#     # Take a list of PIL images and turn them to pixel values
#     inputs = processor([x for x in example_batch['image']], return_tensors='pt')

#     # Don't forget to include the labels!
#     inputs['labels'] = example_batch['labels']
#     print(inputs)
#     return inputs

In [None]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Target (height, width) taken from the processor so the transforms produce
# images of the size the processor is configured for.
size =(processor.size['height'],processor.size['width'])

# Normalize with the processor's own mean/std so training inputs match what the
# processor would produce at inference time.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
# Training pipeline: random crop + random horizontal flip as augmentation,
# then tensor conversion and normalization.
train_transforms = Compose(
        [
            RandomResizedCrop(size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

# Validation pipeline: no random augmentation -- resize, center crop,
# tensor conversion and normalization only.
val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(size),
            ToTensor(),
            normalize,
        ]
    )

def preprocess_train(example_batch):
    """Add a ``pixel_values`` entry built with the training transforms.

    Every image in ``example_batch["image"]`` is converted to RGB and passed
    through ``train_transforms``; the batch dict is updated in place and
    returned.
    """
    transformed = []
    for picture in example_batch["image"]:
        rgb_picture = picture.convert("RGB")
        transformed.append(train_transforms(rgb_picture))
    example_batch["pixel_values"] = transformed
    return example_batch

def preprocess_val(example_batch):
    """Add a ``pixel_values`` entry built with the validation transforms.

    Every image in ``example_batch["image"]`` is converted to RGB and passed
    through ``val_transforms``; the batch dict is updated in place and
    returned.
    """
    transformed = []
    for picture in example_batch["image"]:
        rgb_picture = picture.convert("RGB")
        transformed.append(val_transforms(rgb_picture))
    example_batch["pixel_values"] = transformed
    return example_batch

In [None]:
# Use the dataset's existing splits: `train` for training and `test` for
# validation (no additional splitting is performed here).
train_ds = data['train']
val_ds = data['test']

# Register the matching preprocessing function on each split.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)


## 4. Training and Evaluation

- Define a collate function.

- Define an evaluation metric. During training, the model should be evaluated on its prediction accuracy. You should define a compute_metrics function accordingly.

- Load a pretrained checkpoint

- Define the training configuration.

### 4.1 Define collate function
- Batches are coming in as lists of dicts, so you can just unpack + stack those into batch tensors.

In [None]:
import torch

def collate_fn(batch):
    """Merge a list of per-example dicts into a single batch dict.

    Reads ``pixel_values`` and ``label`` from every example and returns a dict
    with the stacked ``pixel_values`` tensor and a ``labels`` tensor.
    """
    images = []
    targets = []
    for example in batch:
        images.append(example['pixel_values'])
        targets.append(example['label'])
    return {'pixel_values': torch.stack(images), 'labels': torch.tensor(targets)}


### 4.2 Define an evaluation metric

In [None]:
import numpy as np


def compute_metrics(p):
    """Return top-1 accuracy for an evaluation result.

    ``datasets.load_metric`` is no longer available in current ``datasets``
    releases, so importing it fails with an unpinned install. Accuracy is
    simple enough to compute directly with NumPy, which also avoids an extra
    dependency.

    Args:
        p: Object exposing ``predictions`` (per-class scores, one row per
            example) and ``label_ids`` (the reference class ids).

    Returns:
        ``{"accuracy": float}`` -- the fraction of rows whose highest-scoring
        class equals the reference label.
    """
    predicted = np.argmax(p.predictions, axis=1)
    references = np.asarray(p.label_ids)
    return {"accuracy": float(np.mean(predicted == references))}


### 4.3 Load a pretrained checkpoint
- Pass `id2label` and `label2id` mappings to get human-readable labels in the Hub widget (if you choose to push_to_hub).

In [None]:
from transformers import ViTForImageClassification

# Class names taken from the `label` feature of the training split.
labels = data['train'].features['label'].names

# String ids "0", "1", ... in the same order as `labels`.
label_keys = [str(index) for index in range(len(labels))]

# Load the base checkpoint with a classification head sized for our classes
# and with both id <-> name mappings attached.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=dict(zip(label_keys, labels)),
    label2id=dict(zip(labels, label_keys)),
)


### 4.4 Define the training configuration.

In [None]:
from transformers import TrainingArguments



# Training configuration for the fine-tuning run.
training_args = TrainingArguments(
  output_dir="./vit-medicinal-plant-finetune-v2",
  per_device_train_batch_size=10,
  eval_strategy="steps",
  num_train_epochs=10,
  # fp16 mixed precision is only usable on a CUDA device. The notebook allows a
  # CPU fallback (see the device check), so enable it only when CUDA is
  # available instead of failing on CPU-only machines.
  fp16=torch.cuda.is_available(),
  # Save and evaluate on the same schedule so the best checkpoint can be
  # restored at the end of training.
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  # Keep the `image` column: the preprocessing functions read it to build
  # `pixel_values`.
  remove_unused_columns=False,
  push_to_hub=False,
  load_best_model_at_end=True,

)


### 4.5 Set our trainer

In [None]:
from transformers import Trainer

# Wire the model, training configuration, collate function, metric and the two
# dataset splits into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    # NOTE(review): newer transformers releases appear to rename this argument
    # to `processing_class` -- confirm against the installed version.
    tokenizer=processor,
)


### 4.6 Train

In [None]:
# Run the fine-tuning loop, then persist the model, the training metrics and
# the trainer state.
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

### 4.7 Visualize some metrics logged by the HF Trainer

- During training with the HF Trainer, all trainer logs are available from ``trainer.state.log_history``.
- ``trainer.state.log_history`` contains both the training-step logs and the evaluation logs.

In [None]:
# save log_history as pickle

# import pickle

# Example data
# data = trainer.state.log_history
# # Define the pickle file path
# pickle_file = './output.pkl'

# # Save data to pickle file
# with open(pickle_file, 'wb') as f:
#     pickle.dump(complex_data, f)

# print(f'Data saved to {pickle_file}')

In [None]:
import pickle

# File used to persist the Trainer log history between sessions.
LOG_DUMP_PATH = 'trainer_log_dump.pkl'

if 'trainer' in globals() and trainer.state.log_history:
    # A trained Trainer is live in this kernel: take its log directly and
    # refresh the dump, so the notebook runs top to bottom without needing a
    # pre-existing pickle file and later sessions can plot without retraining.
    trainer_log_dump = trainer.state.log_history
    with open(LOG_DUMP_PATH, 'wb') as file:
        pickle.dump(trainer_log_dump, file)
else:
    # No live training log: fall back to a dump written by an earlier run.
    # NOTE: pickle.load can execute arbitrary code -- only load files that you
    # created yourself.
    with open(LOG_DUMP_PATH, 'rb') as file:
        trainer_log_dump = pickle.load(file)


In [None]:
#visualize some metric

import matplotlib.pyplot as plt

from matplotlib import pyplot as plt

# Evaluation records are the log entries that carry an `eval_loss` key.
# The log is read from `trainer_log_dump` directly rather than being rebound to
# `data`: that name holds the dataset loaded at the top of the notebook, and
# overwriting it breaks any earlier cell that is re-run afterwards.
eval_data = [entry for entry in trainer_log_dump if 'eval_loss' in entry]

# Extract metrics
steps = [entry['step'] for entry in eval_data]
eval_loss = [entry['eval_loss'] for entry in eval_data]
eval_accuracy = [entry['eval_accuracy'] for entry in eval_data]

# One figure, two panels: loss on the left, accuracy on the right.
plt.style.use('dark_background')
fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(10, 5))

# Plot eval_loss
loss_ax.plot(steps, eval_loss, marker='.', linestyle='-', color='b', label='Eval Loss')
loss_ax.set_xlabel('Steps')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Evaluation Loss Over Steps')
loss_ax.legend()

# Plot eval_accuracy
accuracy_ax.plot(steps, eval_accuracy, marker='.', linestyle='-', color='g', label='Eval Accuracy')
accuracy_ax.set_xlabel('Steps')
accuracy_ax.set_ylabel('Accuracy')
accuracy_ax.set_title('Evaluation Accuracy Over Steps')
accuracy_ax.legend()

fig.tight_layout()
plt.show()





In [None]:
import matplotlib.pyplot as plt


# Training-step records: entries that carry all three training metrics and are
# not evaluation records.
# The log is read from `trainer_log_dump` directly rather than being rebound to
# `data`: that name holds the dataset loaded at the top of the notebook, and
# overwriting it breaks any earlier cell that is re-run afterwards.
required_keys = ('loss', 'grad_norm', 'learning_rate')
non_eval_data = [
    entry
    for entry in trainer_log_dump
    if 'eval_loss' not in entry and all(key in entry for key in required_keys)
]

# Extract metrics
steps = [entry['step'] for entry in non_eval_data]
loss = [entry['loss'] for entry in non_eval_data]
grad_norm = [entry['grad_norm'] for entry in non_eval_data]
learning_rate = [entry['learning_rate'] for entry in non_eval_data]

# Sanity check: all series must have the same length.
print(f"{len(steps)}-{len(loss)}-{len(grad_norm)}-{len(learning_rate)}")

# Use dark background style
plt.style.use('dark_background')

# One figure, three panels: loss, gradient norm, learning rate.
fig, (loss_ax, grad_ax, lr_ax) = plt.subplots(1, 3, figsize=(15, 5))

# Plot loss
loss_ax.plot(steps, loss, marker='.', linestyle='-', color='c', label='Loss')
loss_ax.set_xlabel('Steps')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss Over Steps')
loss_ax.legend()

# Plot grad_norm
grad_ax.plot(steps, grad_norm, marker='.', linestyle='-', color='m', label='Grad Norm')
grad_ax.set_xlabel('Steps')
grad_ax.set_ylabel('Grad Norm')
grad_ax.set_title('Gradient Norm Over Steps')
grad_ax.legend()

# Plot learning_rate
lr_ax.plot(steps, learning_rate, marker='.', linestyle='-', color='y', label='Learning Rate')
lr_ax.set_xlabel('Steps')
lr_ax.set_ylabel('Learning Rate')
lr_ax.set_title('Learning Rate Over Steps')
lr_ax.legend()

fig.tight_layout()
plt.show()




### 4.7 Evaluate

In [None]:

# Score the fine-tuned model on the held-out split. The original call passed
# `train_ds`, which reports accuracy on randomly augmented training images
# under the "eval" label instead of measuring generalization.
metrics = trainer.evaluate(val_ds)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


### 4.8 Inference model from checkpoint

In [None]:

## Load the image processor and the model from a local training checkpoint
# NOTE(review): the checkpoint step number is hard-coded -- confirm this
# directory exists for the run being inspected.
checkpoint_path = "./vit-medicinal-plant-finetune-v2/checkpoint-4800"

inference_processor = ViTImageProcessor.from_pretrained(checkpoint_path)
# Preprocess the first validation image into PyTorch tensors.
inputs = inference_processor(val_ds[0]['image'], return_tensors="pt")
inference_model = ViTForImageClassification.from_pretrained(checkpoint_path)
# Inference only: gradients are not needed.
with torch.no_grad():
    logits = inference_model(**inputs).logits

# Highest-scoring class index, mapped to its name via the model config.
predicted_label = logits.argmax(-1).item()
predicted_class = inference_model.config.id2label[predicted_label]

print(predicted_class)


In [None]:
## Converting model to pt format


from transformers import ViTForImageClassification,ViTImageProcessor
import torch

## load model from checkpoint
# NOTE(review): this path points at a different run directory than the
# training output above ("...-finetune-v2") -- confirm it is intentional.
checkpoint_path = "./vit-medicinal-plant-finetune/checkpoint-1900"
save_path = "./torch_model"
inference_processor = ViTImageProcessor.from_pretrained(checkpoint_path)
# inputs = inference_processor(val_ds[0]['image'], return_tensors="pt")
inference_model = ViTForImageClassification.from_pretrained(checkpoint_path)

# torch.save(inference_model.state_dict(),f"{save_path}/checkpoint-1900.pt")
# List every parameter name with its tensor shape. The state dict is built once
# and iterated as (name, tensor) pairs instead of being rebuilt on every loop
# iteration.
for param_tensor, weights in inference_model.state_dict().items():
    print(param_tensor, "\t", weights.size())






In [6]:
## load model from hub

from transformers import ViTImageProcessor,ViTForImageClassification
import torch
from PIL import Image

# Hub id of the published fine-tuned model.
checkpoint_path = "funkepal/vit-medicinal-plant-finetune"
# NOTE(review): differs from the './hf_cache' directory used when loading the
# dataset -- confirm that two separate cache locations are intended.
CACHE_DIR = './.hf_cache'

inference_processor = ViTImageProcessor.from_pretrained(checkpoint_path,cache_dir=CACHE_DIR)
inference_model = ViTForImageClassification.from_pretrained(checkpoint_path,cache_dir=CACHE_DIR)


# Open the file in a context manager so the handle is released, and convert to
# RGB to match the training/validation preprocessing (grayscale or RGBA inputs
# would otherwise reach the processor with a different channel layout).
with Image.open('1.jpg') as image_file:
    sample_image = image_file.convert('RGB')
inputs = inference_processor(sample_image, return_tensors = 'pt')
# Inference only: gradients are not needed.
with torch.no_grad():
    logits = inference_model(**inputs).logits

# Highest-scoring class index, mapped to its name via the model config.
predicted_label = logits.argmax(-1).item()
predicted_class = inference_model.config.id2label[predicted_label]

print(predicted_class)


Nooni
