In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
from datasets import load_dataset
from transformers import TrainingArguments, Trainer

from src.model import load_vit_model
from src.utils import transform_function, collate_fn, compute_metrics

# Seed torch before the model is created in the next cell, so that any randomly
# initialised weights are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)


In [None]:
# Load the BloodMNIST subset of MedMNIST v2 from the Hugging Face Hub and read the
# class names from the `label` feature.
raw_ds = load_dataset("albertvillanova/medmnist-v2", "bloodmnist")
labels = raw_ds["train"].features["label"].names
print("Labels:", labels)

# One output per class (presumably load_vit_model sizes the classifier head from
# num_labels — confirm in src.model).
model, image_processor = load_vit_model(num_labels=len(labels))


def _apply_image_processor(examples):
    """Preprocess a batch of raw dataset examples using the ViT image processor."""
    return transform_function(examples, image_processor)


# `with_transform` applies the preprocessing lazily, whenever rows are accessed.
ds = raw_ds.with_transform(_apply_image_processor)


In [None]:
# Fine-tuning configuration. Evaluation, checkpointing and loss logging all happen once
# per epoch. The previous hard-coded 374 steps appears to be ceil(len(train) / 32),
# i.e. one epoch on a single device; the "epoch" strategies express that directly and
# stay correct if the batch size or device count changes.
# With load_best_model_at_end=True, the checkpoint with the best eval metric
# (eval_loss unless metric_for_best_model is set) is reloaded after training.
training_args = TrainingArguments(
    output_dir="./vit-bloodmnist-notebook",
    per_device_train_batch_size=32,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`; the old
    # name is kept to match the version this notebook targets (it also passes
    # `tokenizer=` to Trainer below).
    evaluation_strategy="epoch",
    save_strategy="epoch",  # must match the eval strategy for load_best_model_at_end
    logging_strategy="epoch",
    num_train_epochs=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    # Keep the raw `image` column. It is not an argument of the model's forward(), so
    # Trainer would otherwise remove it before the lazy `with_transform` preprocessing
    # runs (presumably transform_function reads it to build pixel_values — confirm).
    remove_unused_columns=False,
    # The string "none" disables every logging integration. `None` means "all
    # installed integrations" (wandb, tensorboard, ...) in transformers 4.x.
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    # Saved alongside each checkpoint. NOTE(review): newer transformers versions name
    # this argument `processing_class`.
    tokenizer=image_processor,
)

In [None]:
# Fine-tune with the Trainer configured above. The returned TrainOutput (global step,
# average training loss, runtime metrics) is displayed as this cell's output.
trainer.train()

In [None]:
# Evaluate the weights the Trainer now holds on its eval_dataset (the validation split).
# Because load_best_model_at_end=True, these are the best checkpoint's weights rather
# than the last epoch's. Metric keys carry the "eval_" prefix.
results = trainer.evaluate(metric_key_prefix="eval")
print(results)

In [None]:
# Per-epoch training vs. validation loss, taken from the Trainer's log history.
# log_history mixes several kinds of rows that can share the same `epoch` value:
#   - training logs (`loss`),
#   - evaluation logs (`eval_loss`, ...),
#   - the end-of-training summary (`train_loss`, ...),
#   - one more eval row for every later `trainer.evaluate()` call.
# Zero-filling and averaging across those rows halved (or worse) the plotted losses.
# Each metric is therefore read only from the rows that actually report it.
log_df = pd.DataFrame(trainer.state.log_history)

train_loss = log_df.dropna(subset=["loss"]).groupby("epoch")["loss"].first()
# `first()` keeps the in-training evaluation at each epoch. The extra row added by the
# post-training `trainer.evaluate()` (best checkpoint) is not mixed into the last point.
val_loss = log_df.dropna(subset=["eval_loss"]).groupby("epoch")["eval_loss"].first()

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(train_loss.index, train_loss, marker="o", label="Training Loss")
ax.plot(val_loss.index, val_loss, marker="o", label="Validation Loss")
ax.set(xlabel="Epoch", ylabel="Loss", title="Training & Validation Loss")
ax.legend()
plt.show()