# This notebook demonstrates how to fine-tune a ViT model

## 1. Initial preparation

### 1.1 Install dependencies

In [1]:
! pip install datasets transformers accelerate torchvision scikit-learn

zsh:1: /Users/arifhamid/Desktop/portfolio/vit_trainer/.venv/bin/pip: bad interpreter: /Users/arifhamid/Desktop/portfolio/vit_training/.venv/bin/python: no such file or directory
You should consider upgrading via the '/Users/arifhamid/.pyenv/versions/3.10.5/bin/python3.10 -m pip install --upgrade pip' command.[0m[33m
[0m

### 1.2 Set up the HF token. A token with write access is needed because, besides downloading the foundation/base ViT model from the HF Hub, this notebook pushes the dataset to the Hub

In [None]:
from huggingface_hub import interpreter_login

# Authenticate with the Hugging Face Hub using the token described in the section heading.
interpreter_login()

## 2. Data preparation 

### 2.1 Create a custom dataset using the ``datasets`` library

In [None]:
from datasets import load_dataset

# Local folder of images read by the 'imagefolder' loader below.
LOCAL_DATASET_PATH = './Medicinal-plant-dataset'
# Cache directory; not passed to the load_dataset call in this cell, but used by later cells.
CACHE_DIR = './.hf_cache/'

# Load the local image folder as a dataset.
ds = load_dataset('imagefolder',data_dir=LOCAL_DATASET_PATH)


In [None]:
# load_dataset(...) called without `split` returns a DatasetDict, which has no
# train_test_split method; split the underlying 'train' Dataset instead.
# The result holds an 80% 'train' split and a 20% 'test' split.
data = ds['train'].train_test_split(test_size=0.2)
data

In [None]:
# Upload the train/test splits held in `data` to the Hugging Face Hub dataset repo named below.
data.push_to_hub("funkepal/medicinal_plant_images")

In [None]:
# load data from hf hub
from datasets import load_dataset

# Hub repository id of the dataset to load.
HF_DATASET_PATH = 'funkepal/medicinal_plant_images'
# Local directory passed as the download cache.
CACHE_DIR = './.hf_cache/'

hf_ds = load_dataset(HF_DATASET_PATH,cache_dir=CACHE_DIR)

# Label feature of the training split; shown as the cell output.
labels = hf_ds['train'].features['label']
labels

## 3. Pre-processing Dataset

### 3.1 Load the pre-trained model's image processor

In [None]:
from transformers import ViTImageProcessor

# Checkpoint id used here for the image processor and reused later when loading the model.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# Image processor for that checkpoint, cached under CACHE_DIR.
processor = ViTImageProcessor.from_pretrained(model_name_or_path,cache_dir=CACHE_DIR)


In [None]:
# Show the processor's configuration.
print(processor)

In [None]:
def process_example(example):
    """Convert one dataset example into tensor inputs.

    image -> ViTImageProcessor -> tensor. The processor plays the role a
    tokenizer plays for text: it turns the image into numeric tensors.

    Args:
        example: A single dataset row with an 'image' entry and a 'label' entry.

    Returns:
        The processor output (requested as PyTorch tensors) with the
        example's class id stored under the 'labels' key.
    """
    inputs = processor(example['image'], return_tensors='pt')
    # The dataset column is 'label' (singular); 'labels' is only the output key.
    inputs['labels'] = example['label']
    return inputs

### 3.2 Apply data augmentation to our images
 - this step applies transformations to the image dataset

In [None]:


from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Target (height, width), read from the image processor's configuration.
size = tuple(processor.size[dim] for dim in ('height', 'width'))

# Channel normalisation using the processor's mean and std.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)

# Training pipeline: random resized crop and random horizontal flip,
# followed by tensor conversion and normalisation.
train_transforms = Compose([RandomResizedCrop(size), RandomHorizontalFlip(), ToTensor(), normalize])

# Validation pipeline: resize and centre crop,
# followed by tensor conversion and normalisation.
val_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])

def preprocess_train(example_batch):
    """Attach training-time pixel tensors to a batch of examples.

    Each image in ``example_batch["image"]`` is converted to RGB and passed
    through ``train_transforms``; the results are stored, in order, under
    ``example_batch["pixel_values"]``. The same batch object is returned.
    """
    transformed = []
    for image in example_batch["image"]:
        rgb_image = image.convert("RGB")
        transformed.append(train_transforms(rgb_image))
    example_batch["pixel_values"] = transformed
    return example_batch

def preprocess_val(example_batch):
    """Attach validation-time pixel tensors to a batch of examples.

    Each image in ``example_batch["image"]`` is converted to RGB and passed
    through ``val_transforms``; the results are stored, in order, under
    ``example_batch["pixel_values"]``. The same batch object is returned.
    """
    transformed = []
    for image in example_batch["image"]:
        rgb_image = image.convert("RGB")
        transformed.append(val_transforms(rgb_image))
    example_batch["pixel_values"] = transformed
    return example_batch

In [None]:
# Use the Hub dataset's 'train' split for training and its 'test' split for validation.
train_ds = hf_ds['train']
val_ds = hf_ds['test']

# Register the preprocessing functions as the datasets' transforms.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

# Inspect the first training example.
train_ds[0]

## 4. Training and Evaluation

- Define a collate function.

- Define an evaluation metric. During training, the model should be evaluated on its prediction accuracy. You should define a compute_metrics function accordingly.

- Load a pretrained checkpoint

- Define the training configuration.

### 4.1 Define collate function
- Batches are coming in as lists of dicts, so you can just unpack + stack those into batch tensors.

In [None]:
import torch

def collate_fn(batch):
    """Merge a list of per-example dicts into a single batch dict.

    Args:
        batch: Sequence of dicts, each holding a 'pixel_values' tensor and a
            'label' value.

    Returns:
        Dict with 'pixel_values' (the per-example tensors combined with
        ``torch.stack``) and 'labels' (a tensor built from the 'label' values).
    """
    pixel_tensors = []
    label_values = []
    for example in batch:
        pixel_tensors.append(example['pixel_values'])
        label_values.append(example['label'])
    return {
        'pixel_values': torch.stack(pixel_tensors),
        'labels': torch.tensor(label_values),
    }


### 4.2 Define an evaluation metric

In [None]:
import numpy as np

# `datasets.load_metric` is no longer available in current `datasets` releases,
# so accuracy is computed directly with numpy instead.
def compute_metrics(p):
    """Compute top-1 accuracy for a set of predictions.

    Args:
        p: Object exposing ``predictions`` (scores indexed as
            [example, class]) and ``label_ids`` (true class ids).

    Returns:
        ``{"accuracy": value}`` where value is the fraction of examples whose
        highest-scoring class equals the true class id.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return {"accuracy": float(np.mean(predicted_ids == np.asarray(p.label_ids)))}


### 4.3 Load a pretrained checkpoint
- Pass `id2label` and `label2id` mappings so that the Hub widget shows human-readable labels (if you choose to push_to_hub).

In [None]:
from transformers import ViTForImageClassification

# Class names from the training split's label feature.
labels = hf_ds['train'].features['label'].names

# Load the base checkpoint with a classification head sized to our label set,
# passing both directions of the id (as string) <-> class-name mapping.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    cache_dir=CACHE_DIR,
    num_labels=len(labels),
    label2id=dict((name, str(idx)) for idx, name in enumerate(labels)),
    id2label=dict((str(idx), name) for idx, name in enumerate(labels)),
)


### 4.4 Define the training configuration.

In [None]:
from transformers import TrainingArguments
import torch

def check_device():
    """Return the name of the preferred available torch device.

    Returns:
        'cuda' if CUDA is available, otherwise 'mps' if the MPS backend is
        available, otherwise 'cpu'.
    """
    if torch.cuda.is_available():
        return 'cuda'
    return 'mps' if torch.backends.mps.is_available() else 'cpu'
    
# Resolve the torch device for this machine.
device = torch.device(check_device())

training_args_mac = TrainingArguments(
    # Output location and checkpoint retention.
    output_dir="./vit-medicinal-plant-finetune",
    save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True,
    push_to_hub=False,
    # Evaluation and logging cadence (in steps).
    eval_strategy="steps",
    eval_steps=200,
    logging_steps=10,
    # Optimisation settings.
    num_train_epochs=1,
    per_device_train_batch_size=10,
    learning_rate=2e-4,
    # Presumably required so the 'image' column survives for the
    # preprocess_* transforms - TODO confirm.
    remove_unused_columns=False,
)


### 4.5 Set up our trainer

In [None]:
from transformers import Trainer

# Wire the model, training configuration, datasets, collate function,
# metric function and image processor into a single Trainer.
trainer = Trainer(
    args=training_args_mac,
    model=model,
    tokenizer=processor,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)


### 4.6 Train

In [None]:
# Run training and keep the returned result object for its metrics.
train_results = trainer.train()
# Save the model, then log and save the training metrics, then save the trainer state.
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

### 4.7 Evaluate

In [None]:

# Evaluate on the held-out split (val_ds) rather than the training data, so the
# metrics recorded under "eval" come from examples the model was not trained on.
metrics = trainer.evaluate(val_ds)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


### 4.8 Run inference with the model from a checkpoint

In [None]:

## load model from checkpoint
# NOTE: the step number in the checkpoint directory depends on the training run;
# point this at a checkpoint that exists under the training output_dir.
checkpoint_path = "./vit-medicinal-plant-finetune/checkpoint-1900"

inference_processor = ViTImageProcessor.from_pretrained(checkpoint_path)
# Convert to RGB first, consistent with preprocess_train / preprocess_val.
inputs = inference_processor(val_ds[0]['image'].convert("RGB"), return_tensors="pt")
inference_model = ViTForImageClassification.from_pretrained(checkpoint_path)
# No gradients are needed for a forward pass used only for prediction.
with torch.no_grad():
    logits = inference_model(**inputs).logits

# Highest-scoring class id, then its name from the model config.
predicted_label = logits.argmax(-1).item()
predicted_class = inference_model.config.id2label[predicted_label]

print(predicted_class)


In [None]:
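## Converting model to pt format (the torch.save call below is currently commented out; this cell only lists the state-dict entries)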


from transformers import ViTForImageClassification,ViTImageProcessor
import torch

## load model from checkpoint
checkpoint_path = "./vit-medicinal-plant-finetune/checkpoint-1900"
save_path = "./torch_model"
inference_processor = ViTImageProcessor.from_pretrained(checkpoint_path)
# inputs = inference_processor(val_ds[0]['image'], return_tensors="pt")
inference_model = ViTForImageClassification.from_pretrained(checkpoint_path)

# torch.save(inference_model.state_dict(),f"{save_path}/checkpoint-1900.pt")
# Build the state dict once and iterate its items, instead of calling
# state_dict() again for every parameter; prints each entry's name and size.
for param_tensor, weights in inference_model.state_dict().items():
    print(param_tensor, "\t", weights.size())
