# Image Classification Training Notebook

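This notebook fine-tunes an image classifier (by default `google/vit-base-patch16-224-in21k`) on the `acidtib/tcg-magic-cards` dataset with the Hugging Face `Trainer`, using the arguments configured in the cells below.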

In [None]:
# Install required packages if needed
!pip install transformers datasets evaluate torch torchvision huggingface_hub

In [None]:
# Authenticate this notebook session with the Hugging Face Hub (interactive token prompt).
import huggingface_hub

huggingface_hub.notebook_login()

In [None]:
# Disable the Weights & Biases integration through its environment switch.
import os

os.environ.update(WANDB_DISABLED="true")

In [None]:
import logging
import os
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

import transformers
from transformers import (
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)

In [3]:
@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        default="google/vit-base-patch16-224-in21k",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    cache_dir: Optional[str] = field(
        default=None
    )
    model_revision: str = field(
        default="main"
    )
    image_processor_name: str = field(default=None)
    token: str = field(default=None)
    trust_remote_code: bool = field(default=False)
    ignore_mismatched_sizes: bool = field(default=False)

@dataclass
class DataTrainingArguments:
    """Arguments describing which dataset to load and how to split/sample it."""

    dataset_name: str = field(
        default="acidtib/tcg-magic-cards",
        metadata={"help": "Dataset identifier on the Hugging Face Hub (or a local path)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Dataset configuration name, if the dataset has several."},
    )
    train_val_split: Optional[float] = field(
        default=0.15,
        metadata={"help": "Fraction of the train split held out for validation, in [0, 1)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Optional cap on the number of training examples."},
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Optional cap on the number of evaluation examples."},
    )
    image_column_name: str = field(
        default="image",
        metadata={"help": "Dataset column holding the images."},
    )
    label_column_name: str = field(
        default="label",
        metadata={"help": "Dataset column holding the class labels."},
    )

    def __post_init__(self):
        """Reject a split fraction that would leave no training data.

        Raises:
            ValueError: if ``train_val_split`` is a float outside ``[0, 1)``.
        """
        # Only fractional splits are validated; None (or a non-float) means "don't split".
        if isinstance(self.train_val_split, float) and not 0.0 <= self.train_val_split < 1.0:
            raise ValueError(
                f"train_val_split must be in [0, 1), got {self.train_val_split}"
            )

In [None]:
# Set up training arguments.
# `eval_strategy` and `save_strategy` must match for `load_best_model_at_end=True`.
training_args = TrainingArguments(
    output_dir="./models/tcg_magic/",
    # Trainer would otherwise drop columns the model's forward() does not accept,
    # including the raw image column that the set_transform callbacks read.
    remove_unused_columns=False,
    do_train=True,
    do_eval=True,
    learning_rate=2e-5,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_strategy="steps",
    logging_steps=100,
    # Replaces the removed `evaluation_strategy` argument.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    save_total_limit=3,
    seed=420,
    # fp16 mixed precision needs a CUDA device; fall back to fp32 elsewhere.
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    # Replaces the deprecated `push_to_hub_model_id`; a bare name is pushed under the user's namespace.
    hub_model_id="tcg-magic-classifier",
)

# Initialize model and data arguments with their defaults.
model_args = ModelArguments()
data_args = DataTrainingArguments()

In [6]:
# Configure root logging (timestamped, INFO level and above) for this session.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(name=__name__)

# Seed the RNGs via transformers.set_seed so runs are reproducible.
set_seed(seed=training_args.seed)

In [None]:
# Load all splits of the dataset.
dataset = load_dataset(
    data_args.dataset_name,
    data_args.dataset_config_name,
    cache_dir=model_args.cache_dir,
    token=model_args.token,
    trust_remote_code=model_args.trust_remote_code,
)

# Carve a validation split out of "train" only when the dataset does not already
# provide one, so an existing validation split is never silently replaced. The
# explicit seed makes the split reproducible regardless of the global RNG state.
if (
    "validation" not in dataset
    and isinstance(data_args.train_val_split, float)
    and data_args.train_val_split > 0.0
):
    split = dataset["train"].train_test_split(
        test_size=data_args.train_val_split, seed=training_args.seed
    )
    dataset["train"] = split["train"]
    dataset["validation"] = split["test"]

# Fail here with a clear message rather than with a KeyError further down.
if "validation" not in dataset:
    raise ValueError(
        "No 'validation' split available: the dataset does not provide one and "
        "data_args.train_val_split is not a positive float."
    )

# Optionally subsample the splits (e.g. for quick debugging runs).
if data_args.max_train_samples is not None:
    n_train = min(data_args.max_train_samples, len(dataset["train"]))
    dataset["train"] = dataset["train"].shuffle(seed=training_args.seed).select(range(n_train))
if data_args.max_eval_samples is not None:
    n_eval = min(data_args.max_eval_samples, len(dataset["validation"]))
    dataset["validation"] = dataset["validation"].shuffle(seed=training_args.seed).select(range(n_eval))

In [8]:
# Build label <-> id lookups from the label column's class names (the column is
# expected to be a ClassLabel feature, which exposes `.names`). Ids are stored as strings.
labels = dataset["train"].features[data_args.label_column_name].names
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

# Accuracy metric consumed by compute_metrics during evaluation.
metric = evaluate.load(path="accuracy", cache_dir=model_args.cache_dir)

def compute_metrics(p):
    """Compute accuracy for one evaluation pass.

    Args:
        p: The prediction object passed by ``Trainer``. ``p.predictions`` holds the
            per-class scores (presumably logits), and ``p.label_ids`` the reference
            class ids.

    Returns:
        The dict returned by the loaded ``accuracy`` metric.
    """
    scores = p.predictions
    # Models that return extra outputs yield a tuple whose first element is the logits.
    if isinstance(scores, tuple):
        scores = scores[0]
    # Class scores are on the last axis; pick the highest-scoring class per example.
    return metric.compute(predictions=np.argmax(scores, axis=-1), references=p.label_ids)

In [None]:
# Load model config, classification model and image processor.
# Download/auth options from ModelArguments are forwarded to every from_pretrained call.
hub_kwargs = {
    "cache_dir": model_args.cache_dir,
    "revision": model_args.model_revision,
    "token": model_args.token,
    "trust_remote_code": model_args.trust_remote_code,
}

config = AutoConfig.from_pretrained(
    model_args.model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    **hub_kwargs,
)

model = AutoModelForImageClassification.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    **hub_kwargs,
)

# Honour an explicitly named image processor; otherwise use the model checkpoint's own.
image_processor = AutoImageProcessor.from_pretrained(
    model_args.image_processor_name or model_args.model_name_or_path,
    **hub_kwargs,
)

In [10]:
# Resolve the target image size from the processor config: a single int for
# "shortest_edge" configs, otherwise an explicit (height, width) pair.
processor_size = image_processor.size
size = (
    processor_size["shortest_edge"]
    if "shortest_edge" in processor_size
    else (processor_size["height"], processor_size["width"])
)

# Normalise with the processor's mean/std when it exposes them; otherwise identity.
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
    normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
else:
    normalize = Lambda(lambda x: x)

# Training pipeline: random resized crop + horizontal flip for augmentation.
_train_transforms = Compose(
    [RandomResizedCrop(size), RandomHorizontalFlip(), ToTensor(), normalize]
)

# Validation pipeline: deterministic resize + centre crop.
_val_transforms = Compose(
    [Resize(size), CenterCrop(size), ToTensor(), normalize]
)

def train_transforms(example_batch):
    """Add augmented ``pixel_values`` tensors to a batch of training examples.

    Each image in the configured image column is converted to RGB and run
    through the random-crop/flip training pipeline.

    Args:
        example_batch: Batch dict from ``datasets`` mapping column name -> list of values.

    Returns:
        The same dict, updated in place with a ``"pixel_values"`` list.
    """
    pixel_values = []
    for image in example_batch[data_args.image_column_name]:
        pixel_values.append(_train_transforms(image.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

def val_transforms(example_batch):
    """Add deterministic ``pixel_values`` tensors to a batch of validation examples.

    Each image in the configured image column is converted to RGB and run
    through the resize/centre-crop validation pipeline.

    Args:
        example_batch: Batch dict from ``datasets`` mapping column name -> list of values.

    Returns:
        The same dict, updated in place with a ``"pixel_values"`` list.
    """
    pixel_values = []
    for image in example_batch[data_args.image_column_name]:
        pixel_values.append(_val_transforms(image.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

# Attach the on-the-fly transforms: augmentation for training, deterministic
# preprocessing for validation.
for split_name, batch_transform in (
    ("train", train_transforms),
    ("validation", val_transforms),
):
    dataset[split_name].set_transform(batch_transform)

In [11]:
# Set up data collator.
def collate_fn(examples):
    """Stack per-example tensors and labels into a model-ready batch.

    Args:
        examples: List of transformed examples, each with ``"pixel_values"`` and
            the configured label column.

    Returns:
        Dict with stacked ``"pixel_values"`` and a ``"labels"`` tensor.
    """
    label_key = data_args.label_column_name
    batch_pixels = torch.stack([ex["pixel_values"] for ex in examples])
    batch_labels = torch.tensor([ex[label_key] for ex in examples])
    return {"pixel_values": batch_pixels, "labels": batch_labels}

In [12]:
# Initialize the Trainer with the model, arguments, collator, splits and metric function.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

In [None]:
# Train the model, then persist the final model, metrics and trainer state to output_dir.
train_result = trainer.train()
trainer.save_model()
# The Trainer was not given the image processor, so save_model() does not write
# preprocessor_config.json; save it explicitly so output_dir can be reloaded with
# AutoImageProcessor / pipeline().
if trainer.is_world_process_zero():
    image_processor.save_pretrained(training_args.output_dir)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

In [None]:
# Evaluate on the validation split (the best checkpoint is loaded at the end of
# training because load_best_model_at_end=True), then log and save the metrics.
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics(split="eval", metrics=metrics)
trainer.save_metrics(split="eval", metrics=metrics)

In [None]:
# Model-card metadata. When pushing is enabled, the card is uploaded together with
# the model; otherwise it is written to the local output directory only.
kwargs = {
    "finetuned_from": model_args.model_name_or_path,
    "tasks": "image-classification",
    "dataset": data_args.dataset_name,
    "tags": ["image-classification", "vision"],
}
if training_args.push_to_hub:
    trainer.push_to_hub(**kwargs)
else:
    trainer.create_model_card(**kwargs)