### Install and import the required packages

In [None]:
! pip install --quiet "transformers[torch]"
! pip install --quiet evaluate
! pip install --quiet ipywidgets
! pip install --quiet datasets
! pip install --quiet pillow
! pip install --quiet scikit-learn
! pip install --quiet tensorboard

In [None]:
from datetime import datetime

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from transformers import Trainer, TrainingArguments, set_seed

# from huggingface_hub import notebook_login
# notebook_login()

### Select Model and Dataset

In [None]:
# Hugging Face Hub identifiers for the pre-trained checkpoint and the training data.
model_path, dataset_path = (
    "google/vit-base-patch16-224",  # pre-trained ViT checkpoint to fine-tune
    "cvmil/rice-disease-02",  # labelled image dataset
)

# **Dataset**

### **Import Dataset from Hugging Face**

We download the dataset from the Hugging Face Hub using the `load_dataset` function from the `datasets` library.

In [None]:
# Download every split of the dataset from the Hugging Face Hub.
dataset = load_dataset(dataset_path)

# `.names` on the `label` feature lists the class names in id order.
train_features = dataset["train"].features
labels = train_features["label"].names

print(dataset)
print(f"\n\nNumber of classes: {len(labels)}")
for class_id, class_name in enumerate(labels):
    print(f"{class_id}: {class_name}")


### **Image Processor**

Now we initialize the image processor that belongs to the pre-trained model, using the `AutoImageProcessor` class.<br>
The `from_pretrained` method loads the processor with the preprocessing configuration and parameters stored at the specified model path.


In [None]:
# Load the preprocessing configuration that ships with the checkpoint, so our
# inputs are prepared the same way the checkpoint's were.
processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path=model_path)
processor  # rich display of the processor's settings

### **Dataset Processing**

1. `transforms(batch)`: Converts images to RGB, processes them into tensors, and attaches the labels.
2. `collate_fn(batch)`: Stacks a list of processed examples into batched `pixel_values` and `labels` tensors for model input.


In [None]:
# Mappings between integer class ids and class names. They are passed to the
# model below so its config can translate predictions back to names.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

def transforms(batch):
    """Turn a batch of raw dataset rows into model inputs.

    Args:
        batch: dict with an ``'image'`` list (presumably PIL images, since
            ``.convert`` is called on them) and a ``'label'`` list of class ids.
            ``batch['image']`` is overwritten in place with the RGB versions.

    Returns:
        The image processor's output (tensors, ``return_tensors='pt'``) with an
        added ``'labels'`` entry holding ``batch['label']``.
    """
    # Force 3 channels so grayscale/RGBA images don't break preprocessing.
    rgb_images = [image.convert('RGB') for image in batch['image']]
    batch['image'] = rgb_images
    model_inputs = processor(rgb_images, return_tensors='pt')
    model_inputs['labels'] = batch['label']
    return model_inputs

def collate_fn(batch):
    """Collate processed examples into one batch for the model.

    Args:
        batch: list of examples, each with a ``'pixel_values'`` tensor and an
            integer ``'labels'`` value.

    Returns:
        dict with ``'pixel_values'`` (per-example tensors stacked along a new
        leading dim) and ``'labels'`` (a 1-D tensor of the labels).
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    label_ids = torch.tensor([example['labels'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': label_ids}

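Attach the `transforms` function to the dataset with `with_transform`. It is applied on the fly whenever examples are accessed, so images are preprocessed lazily instead of all at once.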

In [None]:
# `with_transform` runs `transforms` lazily on each accessed batch rather than
# preprocessing every split up front.
processed_dataset = dataset.with_transform(transform=transforms)
processed_dataset

### **Evaluation Metrics**

This section defines a `compute_metrics` function that computes the accuracy of the model's predictions.<br>
It uses the `evaluate` library to load the accuracy metric. Accuracy is computed by comparing the predicted labels (the highest-scoring class for each example) with the true labels.

In [None]:
accuracy = evaluate.load('accuracy')


def compute_metrics(eval_preds):
    """Compute classification accuracy for the Trainer's evaluation loop.

    Args:
        eval_preds: ``(logits, label_ids)`` pair produced by the Trainer; logits
            are assumed to be ``(num_examples, num_classes)`` — TODO confirm.

    Returns:
        The dict produced by the ``evaluate`` accuracy metric.
    """
    predicted_logits, reference_ids = eval_preds
    # Index of the highest-scoring class for each example.
    predicted_ids = np.argmax(predicted_logits, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=reference_ids)

# **Model**

In [None]:
# Seed all RNGs *before* building the model. The resized classification head is
# randomly initialised inside `from_pretrained`, which happens before the Trainer
# applies `TrainingArguments.seed`. Without this, the head (and every result)
# changes on each Restart & Run All. 42 matches the TrainingArguments default seed.
set_seed(42)

model = AutoModelForImageClassification.from_pretrained(
    model_path,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    # The checkpoint's classification head has a different number of outputs;
    # allow it to be discarded and re-initialised for our label set.
    ignore_mismatched_sizes=True,
)

model

Freeze all parameters except those of the newly initialized classification head, so that only the head is trained.

In [None]:
# Only parameters belonging to the classification head(s) stay trainable; all
# other (backbone) weights are frozen. `distillation_classifier` presumably
# covers distillation models (e.g. DeiT) that carry a second head.
HEAD_PREFIXES = ('classifier', 'distillation_classifier')
for param_name, parameter in model.named_parameters():
    parameter.requires_grad = param_name.startswith(HEAD_PREFIXES)

Let's check how many parameters the model has in total and how many will actually be trained.

In [None]:
# (element count, trainable?) for every parameter tensor in the model.
param_stats = [(tensor.numel(), tensor.requires_grad) for tensor in model.parameters()]
num_params = sum(size for size, _ in param_stats)
trainable_params = sum(size for size, is_trainable in param_stats if is_trainable)

print(f"Total parameters: {num_params:,} | Trainable parameters: {trainable_params:,}")

# **Training**

In [None]:
# Each run writes to a directory named after the checkpoint and today's date (MMDDYYYY).
model_name = model_path.rsplit('/', 1)[-1]
dt_string = datetime.now().strftime("%m%d%Y")
output_dir = f"./training_output/{model_name}_{dt_string}"
logging_dir = f"{output_dir}/logs"

training_args = TrainingArguments(
    # Where checkpoints and TensorBoard logs are written.
    output_dir=output_dir,
    logging_dir=logging_dir,
    report_to=['tensorboard'],
    logging_steps=100,
    # Optimisation schedule.
    num_train_epochs=10,
    per_device_train_batch_size=16,
    learning_rate=3e-4,
    # Evaluate and checkpoint once per epoch, keep at most 5 checkpoints, and
    # reload the best one (by eval loss, the default metric) when training ends.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=5,
    load_best_model_at_end=True,
    # Keep the raw `image`/`label` columns, which `transforms` needs; otherwise the
    # Trainer drops columns the model's forward() does not accept.
    remove_unused_columns=False,
    # NOTE(review): pushing requires Hugging Face authentication, but the
    # `notebook_login()` call in the imports cell is commented out. Log in
    # (or set HF_TOKEN) before running, or set this to False.
    push_to_hub=True,
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    # Data: lazily-transformed splits plus the batching function.
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
    data_collator=collate_fn,
    # The processor is saved alongside the model; metrics are reported at each evaluation.
    processing_class=processor,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning from scratch (no checkpoint resume); the returned object's
# `.metrics` is saved in the next cell.
train_results = trainer.train(resume_from_checkpoint=None)

In [None]:
train_metrics = train_results.metrics

# Save the final model (the best checkpoint is reloaded at the end because
# load_best_model_at_end=True), then the training metrics and Trainer state.
trainer.save_model()
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()