In [None]:
from datasets import load_dataset
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from transformers import (
    ViTConfig,
    ViTForImageClassification,
    ViTImageProcessor,
    TrainingArguments,
    Trainer,
)


In [None]:
# =============================
# Paths & hyperparameters
# =============================

# Dataset locations -- adjust these to the local directory layout.
IMAGENET_ROOT = "../../../data/imagenet"    # imagefolder style: train/, val/
ANIMAL_ROOT = "../../../data/animal_images"  # same location as in the previous notebook

# Output directories, one per training phase.
OUTPUT_IMAGENET_MODEL_DIR = "./vit_patchX_imagenet"
OUTPUT_ANIMAL_MODEL_DIR = "./vit_patchX_animals"

# Model geometry.
IMAGE_SIZE = 224  # input resolution
PATCH_SIZE = 8    # <--- enter the desired patch size HERE

# Training schedule (shared by both phases unless the name says otherwise).
BATCH_SIZE = 32
LEARNING_RATE = 5e-5
IMAGENET_NUM_EPOCHS = 10
ANIMAL_NUM_EPOCHS = 5

# Reproducibility: seed the torch and NumPy global generators.
SEED = 42

torch.manual_seed(SEED)
np.random.seed(SEED)


In [None]:
# =============================
# Load ImageNet (imagefolder)
# =============================

# Expected directory layout:
#   IMAGENET_ROOT/train/<class>/*.jpg
#   IMAGENET_ROOT/val/<class>/*.jpg


def _load_imagenet_split(subdir):
    """Load ``IMAGENET_ROOT/<subdir>`` as one stand-alone imagefolder dataset.

    split="train" is requested regardless of ``subdir`` because each
    sub-directory is handed to ``load_dataset`` on its own.
    NOTE(review): presumably imagefolder exposes a directory without split
    sub-folders as a single "train" split -- confirm against the datasets docs.
    """
    return load_dataset(
        "imagefolder",
        data_dir=f"{IMAGENET_ROOT}/{subdir}",
        split="train",
    )


imagenet_train = _load_imagenet_split("train")
imagenet_val = _load_imagenet_split("val")

# Show both dataset summaries as the cell output.
imagenet_train, imagenet_val


In [None]:
# =============================
# Labels for ImageNet
# =============================

# Class names come from the "label" feature of the training dataset.
imagenet_label_names = imagenet_train.features["label"].names
num_imagenet_labels = len(imagenet_label_names)

# Forward (index -> name) and reverse (name -> index) lookup tables; both are
# later passed to the ViT config.
id2label_imagenet = dict(enumerate(imagenet_label_names))
label2id_imagenet = {name: idx for idx, name in id2label_imagenet.items()}

# Sanity check: number of classes plus the first five index/name pairs.
num_imagenet_labels, list(id2label_imagenet.items())[:5]


In [None]:
# =============================
# Processor / normalisation
# =============================

# Fetch the processor of an existing ViT checkpoint. It is used for its
# preprocessing settings only (normalisation statistics); no model is loaded
# from that checkpoint here.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# The model config is built with IMAGE_SIZE, so the transforms must resize to
# exactly that resolution. Taking the size from the downloaded processor
# instead would silently ignore IMAGE_SIZE and make preprocessing and model
# disagree as soon as IMAGE_SIZE differs from the checkpoint's resolution.
image_size = IMAGE_SIZE

# Keep the processor (which is saved next to the trained models) consistent
# with the resolution the models are actually trained on.
processor.size = {"height": IMAGE_SIZE, "width": IMAGE_SIZE}

image_mean = processor.image_mean
image_std = processor.image_std

# Show the effective preprocessing parameters as the cell output.
image_size, image_mean, image_std


In [None]:
# =============================
# Transforms & preprocessing
# =============================

from torchvision import transforms

# Training pipeline: resize to a square of side `image_size`, apply a simple
# augmentation (random horizontal flip), convert to a tensor and normalise
# with the processor's mean/std.
train_transform = transforms.Compose(
    [
        transforms.Resize(size=(image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(image_mean, image_std),
    ]
)

# Validation pipeline: identical to the training pipeline minus the random flip.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(image_mean, image_std),
    ]
)

def transform_imagenet_train(examples):
    """Add a "pixel_values" column produced by the training pipeline.

    Args:
        examples: Batch mapping with an "image" entry holding a sequence of
            images that expose ``.convert`` (presumably PIL images -- confirm).

    Returns:
        The same ``examples`` mapping, with "pixel_values" set to the list of
        ``train_transform`` outputs, one per image.
    """
    pixel_values = []
    for image in examples["image"]:
        # Force three channels before the transform pipeline runs.
        rgb_image = image.convert("RGB")
        pixel_values.append(train_transform(rgb_image))
    examples["pixel_values"] = pixel_values
    return examples

def transform_imagenet_val(examples):
    """Add a "pixel_values" column produced by the validation pipeline.

    Args:
        examples: Batch mapping with an "image" entry holding a sequence of
            images that expose ``.convert`` (presumably PIL images -- confirm).

    Returns:
        The same ``examples`` mapping, with "pixel_values" set to the list of
        ``val_transform`` outputs, one per image.
    """
    pixel_values = []
    for image in examples["image"]:
        # Force three channels before the transform pipeline runs.
        rgb_image = image.convert("RGB")
        pixel_values.append(val_transform(rgb_image))
    examples["pixel_values"] = pixel_values
    return examples

# Attach the transforms to the datasets: the augmenting pipeline for the
# training data, the deterministic one for the validation data.
imagenet_train, imagenet_val = (
    imagenet_train.with_transform(transform_imagenet_train),
    imagenet_val.with_transform(transform_imagenet_val),
)


In [None]:
# =============================
# Collate function
# =============================

from torch.utils.data import DataLoader


def collate_fn(batch):
    """Merge a list of examples into one model-ready batch.

    Args:
        batch: Sequence of mappings, each with a "pixel_values" tensor and a
            "label" value.

    Returns:
        Dict with "pixel_values" (the per-example tensors stacked along a new
        leading dimension) and "labels" (a tensor built from the per-example
        labels).
    """
    images, targets = [], []
    for example in batch:
        images.append(example["pixel_values"])
        targets.append(example["label"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }


In [None]:
# =============================
# ViT config with the new patch size (ImageNet)
# =============================

# Load the config of an existing checkpoint; only its architecture dimensions
# are reused, while the patch size is replaced by PATCH_SIZE.
base_config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")

vit_imagenet_config = ViTConfig(
    # Architecture dimensions copied verbatim from the base config.
    **{
        field: getattr(base_config, field)
        for field in (
            "hidden_size",
            "num_hidden_layers",
            "num_attention_heads",
            "intermediate_size",
        )
    },
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,  # <--- IMPORTANT: the new patch size
    num_labels=num_imagenet_labels,
    id2label=id2label_imagenet,
    label2id=label2id_imagenet,
)

# The model is constructed from the config alone (not via from_pretrained).
vit_imagenet_model = ViTForImageClassification(vit_imagenet_config)

# Show the module structure as the cell output.
vit_imagenet_model


In [None]:
# =============================
# TrainingArguments & Trainer (ImageNet)
# =============================

imagenet_training_args = TrainingArguments(
    output_dir=OUTPUT_IMAGENET_MODEL_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=IMAGENET_NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` -- adjust if the installed version rejects it.
    evaluation_strategy="epoch",
    # Saving on the same schedule as evaluation lets the best checkpoint be
    # reloaded at the end of training.
    save_strategy="epoch",
    logging_steps=100,
    # Presumably required so the Trainer keeps the "image" column that the
    # on-the-fly dataset transform reads -- confirm.
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Pass the notebook-wide SEED explicitly; otherwise TrainingArguments uses
    # its own default seed and editing SEED would not affect training.
    seed=SEED,
)

import evaluate
accuracy_metric = evaluate.load("accuracy")

def compute_metrics_imagenet(eval_pred):
    """Compute accuracy for one ImageNet evaluation pass.

    Args:
        eval_pred: Pair that unpacks into ``(logits, labels)``.

    Returns:
        The result of ``accuracy_metric.compute`` for the arg-max class of
        each logit row against the reference labels.
    """
    scores, references = eval_pred
    # The predicted class is the index of the largest logit on the last axis.
    predicted_ids = np.argmax(scores, axis=-1)
    return accuracy_metric.compute(
        references=references,
        predictions=predicted_ids,
    )

# Wire model, data, batching and metric together for the ImageNet phase.
imagenet_trainer = Trainer(
    args=imagenet_training_args,
    model=vit_imagenet_model,
    data_collator=collate_fn,
    compute_metrics=compute_metrics_imagenet,
    train_dataset=imagenet_train,
    eval_dataset=imagenet_val,
)


In [None]:
# =============================
# Train on ImageNet & save
# =============================

# Run the full ImageNet training loop configured in imagenet_training_args.
imagenet_trainer.train()

# Save the best model (the training arguments request load_best_model_at_end)
# together with the processor, so the output directory can later be passed to
# from_pretrained as a self-contained checkpoint.
imagenet_trainer.save_model(OUTPUT_IMAGENET_MODEL_DIR)
processor.save_pretrained(OUTPUT_IMAGENET_MODEL_DIR)


In [None]:
# =============================
# Phase B: load the animal dataset
# =============================

# No `split` argument here, so every split found under ANIMAL_ROOT is returned
# in one dict-like container.
animal_dataset = load_dataset("imagefolder", data_dir=ANIMAL_ROOT)

animal_train = animal_dataset["train"]
# Use the "test" split when it exists; otherwise continue without one (None).
animal_test = animal_dataset["test"] if "test" in animal_dataset else None

# Show the available splits as the cell output.
animal_dataset


In [None]:
# =============================
# Labels for the animal dataset
# =============================

# Class names come from the "label" feature of the training split.
animal_label_names = animal_train.features["label"].names
num_animal_labels = len(animal_label_names)

# Forward (index -> name) and reverse (name -> index) lookup tables; both are
# later passed to from_pretrained for the fine-tuning model.
id2label_animals = dict(enumerate(animal_label_names))
label2id_animals = {name: idx for idx, name in id2label_animals.items()}

# Show the class count and the full index -> name mapping.
num_animal_labels, id2label_animals


In [None]:
# =============================
# Transforms for the animal dataset
# =============================

# The same pipelines as for ImageNet are reused: the augmenting
# `train_transform` for training data and the deterministic `val_transform`
# for held-out data.

def transform_animal(examples):
    """Add a "pixel_values" column using the augmenting training pipeline.

    Args:
        examples: Batch mapping with an "image" entry holding a sequence of
            images that expose ``.convert`` (presumably PIL images -- confirm).

    Returns:
        The same ``examples`` mapping, with "pixel_values" set to the list of
        ``train_transform`` outputs, one per image.
    """
    examples["pixel_values"] = [
        train_transform(img.convert("RGB")) for img in examples["image"]
    ]
    return examples


def transform_animal_eval(examples):
    """Add a "pixel_values" column using the deterministic validation pipeline.

    Evaluation data must not go through the random horizontal flip contained
    in ``train_transform``; otherwise metrics and predictions would vary from
    run to run.

    Args:
        examples: Batch mapping with an "image" entry holding a sequence of
            images that expose ``.convert`` (presumably PIL images -- confirm).

    Returns:
        The same ``examples`` mapping, with "pixel_values" set to the list of
        ``val_transform`` outputs, one per image.
    """
    examples["pixel_values"] = [
        val_transform(img.convert("RGB")) for img in examples["image"]
    ]
    return examples


animal_train = animal_train.with_transform(transform_animal)

if animal_test is not None:
    # Held-out data gets the augmentation-free pipeline.
    animal_test = animal_test.with_transform(transform_animal_eval)


In [None]:
# =============================
# Load the model from the ImageNet checkpoint (from_pretrained)
# =============================

# Start from the weights saved at the end of the ImageNet phase and attach the
# animal label space to the model.
vit_animals_model = ViTForImageClassification.from_pretrained(
    OUTPUT_IMAGENET_MODEL_DIR,
    # The checkpoint's classification head was sized for the ImageNet labels,
    # so a size mismatch with the animal label count has to be tolerated.
    ignore_mismatched_sizes=True,
    num_labels=num_animal_labels,
    label2id=label2id_animals,
    id2label=id2label_animals,
)

# Show the module structure as the cell output.
vit_animals_model


In [None]:
# =============================
# TrainingArguments & Trainer (animals)
# =============================

# Evaluation, best-model reloading and the selection metric are only enabled
# when a test split exists; without one, training runs with evaluation off.
animal_training_args = TrainingArguments(
    output_dir=OUTPUT_ANIMAL_MODEL_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=ANIMAL_NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` -- adjust if the installed version rejects it.
    evaluation_strategy="epoch" if animal_test is not None else "no",
    save_strategy="epoch",
    logging_steps=50,
    # Presumably required so the Trainer keeps the "image" column that the
    # on-the-fly dataset transform reads -- confirm.
    remove_unused_columns=False,
    load_best_model_at_end=animal_test is not None,
    metric_for_best_model="accuracy" if animal_test is not None else None,
    # Pass the notebook-wide SEED explicitly; otherwise TrainingArguments uses
    # its own default seed and editing SEED would not affect training.
    seed=SEED,
)

accuracy_metric = evaluate.load("accuracy")

def compute_metrics_animals(eval_pred):
    """Compute accuracy for one evaluation pass on the animal dataset.

    Args:
        eval_pred: Pair that unpacks into ``(logits, labels)``.

    Returns:
        The result of ``accuracy_metric.compute`` for the arg-max class of
        each logit row against the reference labels.
    """
    logits, labels = eval_pred
    return accuracy_metric.compute(
        # Predicted class = index of the largest logit on the last axis.
        predictions=np.argmax(logits, axis=-1),
        references=labels,
    )

# Wire model, data, batching and metric together for the fine-tuning phase.
# Without a test split both eval_dataset and compute_metrics end up as None.
animal_trainer = Trainer(
    args=animal_training_args,
    model=vit_animals_model,
    data_collator=collate_fn,
    train_dataset=animal_train,
    eval_dataset=animal_test,
    compute_metrics=None if animal_test is None else compute_metrics_animals,
)


In [None]:
# =============================
# Fine-tune on the animal dataset & save
# =============================

# Run the fine-tuning loop configured in animal_training_args.
animal_trainer.train()

# Store the fine-tuned model and the processor side by side in the animal
# output directory.
animal_trainer.save_model(OUTPUT_ANIMAL_MODEL_DIR)
processor.save_pretrained(OUTPUT_ANIMAL_MODEL_DIR)


In [None]:
# =============================
# Evaluation & report
# =============================

# The whole report is skipped when the dataset has no test split.
if animal_test is not None:
    from sklearn.metrics import classification_report, confusion_matrix

    predictions = animal_trainer.predict(animal_test)

    labels_true = predictions.label_ids
    # Predicted class = index of the largest logit on the last axis.
    labels_pred = np.argmax(predictions.predictions, axis=-1)

    # Pin the label set to every known class. Without it, a class that is
    # missing from both the test labels and the predictions makes the number
    # of observed classes differ from len(target_names), and the confusion
    # matrix would no longer line up with animal_label_names.
    all_label_ids = list(range(len(animal_label_names)))

    report = classification_report(
        labels_true,
        labels_pred,
        labels=all_label_ids,
        target_names=animal_label_names,
        output_dict=True,
    )

    # One row per class / aggregate, one column per metric.
    report_df = pd.DataFrame(report).transpose()
    print("\n\nscikit-learn report:\n", report_df)

    report_df.to_csv(f"{OUTPUT_ANIMAL_MODEL_DIR}/classification_report.csv", index=True)

    cm = confusion_matrix(labels_true, labels_pred, labels=all_label_ids)

    # A bare `cm` expression inside an `if` body is not rendered by the
    # notebook, so the matrix is printed explicitly, labelled with class names.
    cm_df = pd.DataFrame(cm, index=animal_label_names, columns=animal_label_names)
    print("\n\nconfusion matrix (rows: true label, columns: predicted label):\n", cm_df)
