In [None]:
import os
from typing import Iterable
from datasets import load_metric
import numpy as np
from PIL import Image
import torch
from transformers import TrainingArguments, Trainer, ViTForImageClassification, ViTImageProcessor

In [None]:
# Root folder of the dataset; load_data looks for one sub-directory per class inside it.
dataset_dir = 'dataset/Garbage classification'

# Class sub-directory names. The index of a name in this list is its class id
# (see the id2label / label2id mappings passed to the model).
image_class_names = [
    'metal',
    'glass',
    'paper',
    'trash',
    'cardboard',
    'plastic',
]

# Used both as the batch size for load_data and as per_device_train_batch_size.
batch_size = 64

In [None]:
def load_data(dataset_dir: str, image_class_names: list[str], batch_size: int):
    """Read every image under ``dataset_dir`` and group the images into batches.

    ``dataset_dir`` must contain one sub-directory per entry of
    ``image_class_names``; every file inside a sub-directory is opened as an
    image of that class.

    Args:
        dataset_dir: Root directory holding one sub-directory per class.
        image_class_names: Sub-directory names. The index of a name in this
            list is the integer label assigned to the images found in it.
        batch_size: Maximum number of images per returned batch.

    Returns:
        A list of dicts ``{'image': [...], 'labels': [...]}``. ``'image'``
        holds RGB PIL images and ``'labels'`` the matching integer labels.
        Every batch has ``batch_size`` items except possibly the last one,
        which holds the remainder.

    Raises:
        ValueError: If ``batch_size`` is not a positive number.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    batches = []
    images = []
    labels = []
    # enumerate supplies the integer label that goes with each class name.
    for label, class_name in enumerate(image_class_names):
        class_dir = os.path.join(dataset_dir, class_name)
        # Sorted so the result does not depend on the file system's listing order.
        for image_name in sorted(os.listdir(class_dir)):
            image_path = os.path.join(class_dir, image_name)
            # Keep an actual image object (the image processor downstream
            # consumes images, not flattened arrays). convert() reads the pixel
            # data into a new image, so the file handle can be released here.
            with Image.open(image_path) as image:
                images.append(image.convert('RGB'))
            labels.append(label)

            if len(images) == batch_size:
                batches.append({'image': images, 'labels': labels})
                images = []
                labels = []

    # Keep the images left over after the last full batch instead of dropping them.
    if images:
        batches.append({'image': images, 'labels': labels})
    return batches

# Build the list of image/label batches from the dataset folder.
ds = load_data(
    dataset_dir=dataset_dir,
    image_class_names=image_class_names,
    batch_size=batch_size,
)

In [None]:
# Image processor loaded from the same checkpoint name as the model; `transform`
# uses it to turn the loaded images into pixel values.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

def transform(batch: dict, processor):
    """Run the image processor over one batch and attach the batch's labels.

    Args:
        batch: A dict with an ``'image'`` entry (the images to process) and a
            ``'labels'`` entry, i.e. one element of the list returned by
            ``load_data``.
        processor: Callable invoked as ``processor(images, return_tensors='pt')``.
            Its return value must support item assignment.

    Returns:
        The processor's output with ``batch['labels']`` stored under the
        ``'labels'`` key.
    """
    # Hand the processor a concrete list of the batch's images.
    inputs = processor(list(batch['image']), return_tensors='pt')    # process into pixel values
    inputs['labels'] = batch['labels']
    return inputs

processed_ds = [transform(batch, processor) for batch in ds]

# The training and evaluation cells read `prepared_ds['train']` and
# `prepared_ds['validation']`, so those splits are built here. Each processed
# batch is unpacked into per-example dicts, which is the item layout that
# `collate_fn` consumes (one 'pixel_values' tensor and one label per item).
examples = [
    {'pixel_values': pixel_values, 'labels': label}
    for processed in processed_ds
    for pixel_values, label in zip(processed['pixel_values'], processed['labels'])
]

# Seeded shuffle so the split is identical after a kernel restart.
# NOTE(review): the 80/20 ratio and the seed are choices made here — adjust as needed.
validation_fraction = 0.2
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(examples), generator=split_generator).tolist()
num_validation = int(len(examples) * validation_fraction)

prepared_ds = {
    'train': [examples[i] for i in shuffled_indices[num_validation:]],
    'validation': [examples[i] for i in shuffled_indices[:num_validation]],
}

In [None]:
def collate_fn(batch):
    """Merge a list of per-example dicts into a single batch dict.

    Every element of ``batch`` must provide a ``'pixel_values'`` tensor and a
    ``'labels'`` value. The pixel tensors are stacked along a new leading
    dimension and the labels are collected into one tensor.
    """
    stacked_pixels = torch.stack(tuple(example['pixel_values'] for example in batch))
    label_tensor = torch.tensor(list(example['labels'] for example in batch))
    return dict(pixel_values=stacked_pixels, labels=label_tensor)

# Accuracy metric object; compute_metrics calls its compute() method.
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in favour of
# the separate `evaluate` package — confirm against the installed version.
metric = load_metric("accuracy")
def compute_metrics(p):
    """Score a set of predictions with the module-level ``metric``.

    ``p`` must expose ``predictions`` (scores with the classes along axis 1)
    and ``label_ids`` (the reference labels). The predicted class of each row
    is the arg-max along axis 1; the result of ``metric.compute`` is returned
    unchanged.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)

In [None]:
# Load the pretrained ViT checkpoint for image classification with one label
# per entry of image_class_names; both mappings use the list index as class id.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(image_class_names),
    id2label=dict((str(index), name) for index, name in enumerate(image_class_names)),
    label2id=dict((name, str(index)) for index, name in enumerate(image_class_names)),
)

In [None]:
# Hyper-parameters and bookkeeping options for the Trainer.
training_args = TrainingArguments(
    output_dir="./vit-trained",
    per_device_train_batch_size=batch_size,
    evaluation_strategy="steps",
    num_train_epochs=4,
    # Mixed precision is only requested when a CUDA device is present, so the
    # notebook also runs on CPU-only machines instead of failing at this cell.
    fp16=torch.cuda.is_available(),
    # Checkpoints and evaluations happen on the same step interval, which
    # load_best_model_at_end relies on.
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # The dataset items are plain dicts consumed directly by collate_fn, so no
    # column pruning should be attempted.
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

# Wire the model, the data splits and the helper callables into a Trainer.
trainer = Trainer(
    # what to train and how
    model=model,
    args=training_args,
    # data
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    data_collator=collate_fn,
    # evaluation and preprocessing helpers
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

# Run training, then persist the model, the training metrics and the trainer state.
# The statements are order-dependent: the metrics come from the result of train().
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

In [None]:
# Evaluate on the validation split, then log and save the resulting metrics under the "eval" name.
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)