In [None]:
! pip install --quiet transformers
! pip install --quiet "transformers[torch]"
! pip install --quiet torch
! pip install --quiet datasets
! pip install --quiet evaluate
! pip install --quiet scikit-learn
! pip install --quiet tensorboard
! pip install --quiet matplotlib
! pip install --quiet ipywidgets

In [2]:
import random
from PIL import ImageDraw, ImageFont, Image

import torch
import numpy as np

import evaluate
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoImageProcessor, AutoModelForImageClassification
from transformers import Trainer, TrainingArguments

from huggingface_hub import notebook_login

In [None]:
# Log in to the Hugging Face Hub from within the notebook. The training configuration
# further down enables `push_to_hub`, which needs an authenticated session.
notebook_login()

# **Dataset**

In [None]:
# Load the rice-disease image dataset from the Hub; the result is indexed by split
# name below (e.g. dataset['train']).
dataset = load_dataset("cvmil/rice-disease-02")

In [None]:
# Class names, read from the `label` feature of the training split.
labels = dataset['train'].features['label'].names

# Show the dataset summary, then the number of classes and their names.
print(dataset)
print("\n\nlabels:", len(labels), labels)

In [None]:
def _load_label_font(size: int = 32):
    """Return a font for the class captions that works on any platform."""
    # "Chalkduster.ttf" ships with macOS only. ImageFont.truetype raises OSError when
    # a font cannot be found, so try a font that is common on Linux next and finally
    # fall back to Pillow's built-in default font (which may ignore `size`).
    for font_name in ("Chalkduster.ttf", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(font_name, size)
        except OSError:
            continue
    return ImageFont.load_default()


def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(300, 300)):
    """Build a contact sheet with a few random images per class.

    Each class occupies one row of the grid; every tile is captioned with the
    class name.

    Args:
        ds: a dataset split with an 'image' column (PIL images) and a 'label'
            column whose feature exposes ``.names``.
        seed: seed used to shuffle the examples of each class.
        examples_per_class: maximum number of tiles per row. Classes with fewer
            examples simply leave the remaining tiles black.
        size: ``(width, height)`` of a single tile in pixels.

    Returns:
        A PIL RGB image of size
        ``(examples_per_class * width, number_of_classes * height)``.
    """
    w, h = size
    labels = ds.features['label'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = _load_label_font(32)

    for label_id, label in enumerate(labels):
        # Only the label column is needed to decide membership, so pass just that
        # column to the predicate. The comparison is trivial, so no worker pool is
        # spawned for it.
        ds_label = ds.filter(lambda value: value == label_id, input_columns=['label'])

        # Never ask for more rows than the class actually has.
        n_shown = min(examples_per_class, len(ds_label))
        ds_slice = ds_label.shuffle(seed).select(range(n_shown))

        # Row = class, column = i-th sample of that class.
        for i, example in enumerate(ds_slice):
            box = (i * w, label_id * h)
            grid.paste(example['image'].resize(size), box=box)
            draw.text(box, label, "white", font=font)

    return grid

# A new random seed on every execution, so re-running the cell shows different samples.
show_examples(dataset['train'], seed=random.randint(0, 1337), examples_per_class=3)


## **Dataset Preprocessing**

### Image Processor

In [None]:
# Image processor loaded from the same checkpoint name that is used for the model
# further down; the bare expression on the last line displays its configuration.
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
processor

In [None]:
# Two-way mapping between class names and integer ids; ids follow the order of `labels`.
id2label = dict(enumerate(labels))
label2id = {class_name: class_id for class_id, class_name in id2label.items()}
print(label2id)

def transforms(batch):
    """Turn a batch of raw examples into model inputs.

    The function is attached with ``with_transform``, so it runs every time
    examples are accessed. It therefore must not print or do other per-call
    debugging work.

    Args:
        batch: dict of columns holding an 'image' list (PIL images) and a
            'label' list.

    Returns:
        The processor output (PyTorch tensors) with the class ids added under
        the 'labels' key.
    """
    # Convert every image to RGB so the processor always receives 3-channel input.
    # A local list is used so the incoming batch is left untouched.
    rgb_images = [image.convert('RGB') for image in batch['image']]
    inputs = processor(rgb_images, return_tensors='pt')
    inputs['labels'] = batch['label']
    return inputs

def collate_fn(batch):
    """Merge a list of per-example dicts into a single batch of tensors.

    Args:
        batch: list of dicts, each with a 'pixel_values' tensor and a 'labels' id.

    Returns:
        Dict with 'pixel_values' stacked along a new first dimension and
        'labels' as a 1-D tensor.
    """
    images = [example['pixel_values'] for example in batch]
    stacked_images = torch.stack(images)
    label_ids = torch.tensor([example['labels'] for example in batch])
    return {'pixel_values': stacked_images, 'labels': label_ids}

In [None]:
# Attach `transforms` to the dataset; it is applied on the fly when examples are
# accessed rather than being materialized up front.
processed_dataset = dataset.with_transform(transforms)
processed_dataset

### Evaluation Metrics

In [9]:
# Accuracy metric from the `evaluate` library, loaded once and reused for every evaluation.
accuracy = evaluate.load('accuracy')


def compute_metrics(eval_preds):
    """Score one evaluation pass.

    Args:
        eval_preds: a pair ``(logits, labels)``.

    Returns:
        The result of ``accuracy.compute`` for the predicted vs. reference classes.
    """
    logits, references = eval_preds
    # The predicted class is the position of the largest logit along axis 1.
    predicted_ids = np.argmax(logits, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)

# **Model**

In [None]:
# Load the pretrained ViT checkpoint with a classification head sized for our labels.
# `ignore_mismatched_sizes` allows loading even though the checkpoint's head presumably
# has a different number of outputs than `len(labels)`; the mismatched weights are then
# newly initialized (NOTE(review): confirm against the load-time warning).
model = AutoModelForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels = len(labels),
    id2label = id2label,
    label2id = label2id,
    ignore_mismatched_sizes = True
)

# Bare expression: displays the module structure of the model.
model

Freeze all parameters except those of the new classifier layer, so that only the classification head is trained.

In [11]:
# Train only the classification head: every parameter whose name does not start
# with 'classifier' is excluded from gradient updates.
for name, p in model.named_parameters():
    if name.startswith('classifier'):
        continue  # leave the head trainable
    p.requires_grad = False

We can check how many parameters the model has in total and how many of them will actually be trained.

In [None]:
# Total vs. trainable parameter counts; generator expressions avoid building
# throwaway lists of per-tensor sizes.
num_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)

print(f"{num_params = :,} | {trainable_params = :,}")

# **Training**

In [14]:
# Fine-tuning configuration.
training_args = TrainingArguments(
    output_dir="./vit-base-patch16-224_01",
    per_device_train_batch_size=16,
    # Evaluate and checkpoint on the same schedule; `load_best_model_at_end`
    # needs a checkpoint for every evaluation it compares.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=100,
    num_train_epochs=1,
    learning_rate=3e-4,
    save_total_limit=2,
    # Keep the raw 'image' / 'label' columns: the on-the-fly transform reads them,
    # so they must not be dropped before the transform runs.
    remove_unused_columns=False,
    # `push_to_hub` is a boolean switch. The target repository is named through
    # `hub_model_id`; passing the repo id as `push_to_hub` merely acts as a truthy
    # value and the intended organisation/repo name is ignored.
    push_to_hub=True,
    hub_model_id='cvmil/vit-base-patch16-224_01',
    report_to='tensorboard',
    load_best_model_at_end=True,
)

In [15]:
import inspect

# Newer transformers releases renamed the Trainer keyword `tokenizer` to
# `processing_class` and deprecated the old name. Because the installed version is
# not pinned, pick whichever keyword the installed Trainer actually accepts.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_processor_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
    # Hand the image processor to the Trainer under the supported keyword.
    **{_processor_kwarg: processor},
)

In [None]:
# Run fine-tuning with the Trainer configured above.
trainer.train()

In [None]:
# Final evaluation on the test split, which was used neither for training nor for
# the per-epoch evaluation.
trainer.evaluate(processed_dataset['test'])