In [None]:
!pip install transformers datasets accelerate evaluate scikit-learn matplotlib seaborn pandas

In [None]:
import gc
import json
import os
import time
import traceback
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torchvision.transforms as transforms

from datasets import load_dataset
from transformers import (
    AutoModelForImageClassification,
    AutoImageProcessor,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
    accuracy_score
)
import evaluate

def set_seed(seed=42):
    """Seed every RNG this notebook touches.

    Covers Python's ``random`` module, NumPy's global RNG, and PyTorch
    (CPU generator plus all CUDA devices).
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# Informational only: nothing below passes `device` anywhere (the Trainer picks
# the device itself).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Every CSV, plot and checkpoint of this milestone is written below RESULTS_DIR.
RESULTS_DIR = "results/ms3"
# makedirs creates missing parents, so this also creates RESULTS_DIR itself.
os.makedirs(os.path.join(RESULTS_DIR, "models"), exist_ok=True)
print(f"Results will be saved to {RESULTS_DIR}")

In [None]:
# Checkpoints compared in MS3.
#   model_name  -> Hugging Face checkpoint id passed to from_pretrained
#   description -> human-readable label that train_model writes to the
#                  "architecture" column (it falls back to model_name when absent),
#                  matching the style of the MS1/MS2 reference rows.
MODELS = {
    "deit_base": {
        "model_name": "facebook/deit-base-patch16-224",
        "description": "DeiT-Base (MS3)",
    },
    "deit_small": {
        "model_name": "facebook/deit-small-patch16-224",
        "description": "DeiT-Small (MS3)",
    },
    "swin_base": {
        "model_name": "microsoft/swin-base-patch4-window7-224",
        "description": "Swin-Base (MS3)",
    },
    "beit_base": {
        "model_name": "microsoft/beit-base-patch16-224",
        "description": "BEiT-Base (MS3)",
    },
}

print("Models for comparison:")
for name, config in MODELS.items():
    print(f"  - {name}: {config['model_name']}")

## Load Dataset

In [None]:
dataset = load_dataset("cifar10")
# The ClassLabel feature stores the class names in label-id order.
labels = dataset["train"].features["label"].names
num_labels = dataset["train"].features["label"].num_classes

print(f"Classes: {labels}")
print(f"Training images: {dataset['train'].num_rows}")
print(f"Test images: {dataset['test'].num_rows}")

## Augmentation

In [None]:
class AugmentationConfig:
    """Switches and hyper-parameters for batch-level CutMix / MixUp.

    All arguments are optional; the defaults reproduce the original settings.

    Args:
        use_cutmix: enable CutMix on training batches.
        use_mixup: enable MixUp on training batches.
        cutmix_prob: per-batch probability used by the training loop to decide on
            CutMix (checked before MixUp).
        mixup_prob: per-batch probability used by the training loop to decide on MixUp.
        cutmix_alpha: ``alpha`` of the Beta(alpha, alpha) draw for the CutMix ratio.
        mixup_alpha: ``alpha`` of the Beta(alpha, alpha) draw for the MixUp ratio.
    """

    def __init__(
        self,
        use_cutmix=True,
        use_mixup=True,
        cutmix_prob=0.4,
        mixup_prob=0.2,
        cutmix_alpha=0.8,
        mixup_alpha=0.6,
    ):
        self.use_cutmix = use_cutmix
        self.use_mixup = use_mixup
        self.cutmix_prob = cutmix_prob
        self.mixup_prob = mixup_prob
        self.cutmix_alpha = cutmix_alpha
        self.mixup_alpha = mixup_alpha


def get_train_transforms():
    """Build the image-level augmentation pipeline for training images.

    It is applied to each PIL image before the model's image processor.
    """
    augmentations = [
        transforms.RandomHorizontalFlip(p=0.5),                 # mirror half of the images
        transforms.RandomRotation(degrees=15),                  # rotate by up to +/-15 degrees
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # shift only, up to 10% per axis
    ]
    return transforms.Compose(augmentations)


def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2


def cutmix_data(images, labels, alpha=1.0):
    """Apply CutMix: paste a random box from a shuffled copy of the batch.

    Args:
        images: 4-D image batch tensor; the box is cut over dimensions 2 and 3.
        labels: per-sample targets aligned with ``images`` along dimension 0.
        alpha: ``alpha`` of the Beta(alpha, alpha) draw for the initial ratio.
            ``alpha <= 0`` disables mixing (ratio 1, empty box) instead of
            raising inside ``np.random.beta``.

    Returns:
        ``(mixed_images, labels, labels[index], lam)``. ``lam`` is a Python float:
        the fraction of each image that still comes from its own sample. It is
        recomputed from the clipped box area. ``images`` itself is not modified.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(images.size(0)).to(images.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)

    # Paste into a copy so the caller's tensor is left untouched.
    mixed_images = images.clone()
    mixed_images[:, :, bbx1:bbx2, bby1:bby2] = images[index, :, bbx1:bbx2, bby1:bby2]

    # Clipping at the borders can shrink the box, so derive lam from the real area.
    box_area = float((bbx2 - bbx1) * (bby2 - bby1))
    lam = 1.0 - box_area / (images.size()[-1] * images.size()[-2])
    return mixed_images, labels, labels[index], lam


def mixup_data(images, labels, alpha=1.0):
    """Apply MixUp: blend each sample with a randomly paired sample of the batch.

    Args:
        images: batch tensor; dimension 0 indexes samples.
        labels: per-sample targets aligned with ``images`` along dimension 0.
        alpha: ``alpha`` of the Beta(alpha, alpha) draw for the blend ratio.
            ``alpha <= 0`` disables mixing (ratio 1) instead of raising inside
            ``np.random.beta``.

    Returns:
        ``(mixed_images, labels, labels[index], lam)`` where
        ``mixed_images = lam * images + (1 - lam) * images[index]`` and ``lam`` is a
        Python float.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    index = torch.randperm(images.size(0)).to(images.device)
    mixed_images = lam * images + (1 - lam) * images[index]
    return mixed_images, labels, labels[index], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Interpolate a loss between two target sets.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    It is meant for the label pairs returned by ``cutmix_data`` / ``mixup_data``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


# Module-level defaults: train_transforms is applied to each training image, and
# aug_config is passed to AugmentedTrainer inside train_model.
train_transforms = get_train_transforms()
aug_config = AugmentationConfig()

## Custom Trainer

In [None]:
class AugmentedTrainer(Trainer):
    """HF ``Trainer`` that applies CutMix or MixUp to training batches.

    On each training step, one uniform draw picks the augmentation:
    - CutMix with probability ``cutmix_prob`` (if enabled);
    - otherwise MixUp with probability ``mixup_prob`` (if enabled);
    - otherwise the batch is used unchanged.

    Mixed batches are scored with the interpolated cross-entropy from
    ``mixup_criterion``. Evaluation batches are never mixed and use the model's
    own loss.

    Args:
        aug_config: an ``AugmentationConfig``; a default one is created when omitted.
        *args, **kwargs: forwarded unchanged to ``Trainer``.
    """

    def __init__(self, *args, aug_config=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.aug_config = aug_config or AugmentationConfig()
        # CrossEntropyLoss keeps no per-call state, so one instance is reused.
        self._mix_loss_fn = nn.CrossEntropyLoss()

    def _sample_mixing(self):
        """Draw the mixing for this batch: ``(cutmix_data | mixup_data | None, alpha)``."""
        cfg = self.aug_config
        # A disabled augmentation gets a zero-width slice. Otherwise, turning CutMix
        # off would silently add its probability to MixUp's.
        cutmix_p = cfg.cutmix_prob if cfg.use_cutmix else 0.0
        mixup_p = cfg.mixup_prob if cfg.use_mixup else 0.0
        r = np.random.rand()
        if r < cutmix_p:
            return cutmix_data, cfg.cutmix_alpha
        if r < cutmix_p + mixup_p:
            return mixup_data, cfg.mixup_alpha
        return None, None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the (possibly mixed) training loss, or the plain loss outside training.

        ``"labels"`` is popped from ``inputs``. ``num_items_in_batch`` is accepted to
        match the ``Trainer`` signature but is not used.
        """
        labels = inputs.pop("labels")
        mix_fn, alpha = self._sample_mixing() if self.model.training else (None, None)

        if mix_fn is None:
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
        else:
            inputs["pixel_values"], labels_a, labels_b, lam = mix_fn(
                inputs["pixel_values"], labels, alpha
            )
            outputs = model(**inputs)
            loss = mixup_criterion(self._mix_loss_fn, outputs.logits, labels_a, labels_b, lam)

        return (loss, outputs) if return_outputs else loss

## Metrics

In [None]:
def compute_metrics(eval_pred):
    """Classification metrics for the HF Trainer evaluation loop.

    Args:
        eval_pred: ``(predictions, labels)`` from the Trainer. ``predictions`` holds
            per-class scores with classes along axis 1 (or a tuple whose first
            element does); ``labels`` holds the integer targets.

    Returns:
        dict with ``accuracy``, ``f1_macro``, ``precision_macro`` and ``recall_macro``.
        The Trainer prefixes these keys (e.g. ``eval_f1_macro``).
    """
    predictions, labels = eval_pred
    # Some models return extra outputs next to the logits; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    preds = np.argmax(predictions, axis=1)

    # sklearn alone yields the same accuracy / macro-F1 as evaluate's "accuracy"
    # and "f1" metrics, without fetching metric scripts from the Hub.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )

    return {
        "accuracy": accuracy_score(labels, preds),
        "f1_macro": f1,
        "precision_macro": precision,
        "recall_macro": recall,
    }

## CSV Export

In [None]:
def save_training_history(trainer, model_key, output_dir=None):
    """Export a trainer's ``state.log_history`` as two CSV files.

    Writes ``<model_key>_training_history.csv`` (periodic training-loss logs) and
    ``<model_key>_eval_history.csv`` (one row per evaluation). A file is written only
    when it has at least one row; missing values default to 0.

    Args:
        trainer: object exposing ``state.log_history`` (a list of dicts), e.g. a HF Trainer.
        model_key: prefix for the output file names.
        output_dir: destination directory; defaults to ``RESULTS_DIR``.
    """
    out_dir = RESULTS_DIR if output_dir is None else output_dir

    train_rows = []
    eval_rows = []
    for entry in trainer.state.log_history:
        if "eval_loss" in entry:
            eval_rows.append({
                "step": entry.get("step", 0),
                "epoch": entry.get("epoch", 0),
                "eval_loss": entry.get("eval_loss", 0),
                "eval_accuracy": entry.get("eval_accuracy", 0),
                "eval_f1_macro": entry.get("eval_f1_macro", 0),
                "eval_precision_macro": entry.get("eval_precision_macro", 0),
                "eval_recall_macro": entry.get("eval_recall_macro", 0),
            })
        elif "loss" in entry:
            # Periodic training logs carry "loss"; the end-of-training summary uses
            # "train_loss" and is therefore skipped.
            train_rows.append({
                "step": entry.get("step", 0),
                "epoch": entry.get("epoch", 0),
                "loss": entry.get("loss", 0),
                "learning_rate": entry.get("learning_rate", 0),
            })

    if train_rows or eval_rows:
        os.makedirs(out_dir, exist_ok=True)
    if train_rows:
        pd.DataFrame(train_rows).to_csv(
            os.path.join(out_dir, f"{model_key}_training_history.csv"), index=False
        )
    if eval_rows:
        pd.DataFrame(eval_rows).to_csv(
            os.path.join(out_dir, f"{model_key}_eval_history.csv"), index=False
        )


def save_classification_report(y_true, y_pred, model_key, label_names, output_dir=None):
    """Write sklearn's classification report to ``<model_key>_classification_report.csv``.

    Assumes ``y_true`` / ``y_pred`` are integer class ids ``0..len(label_names)-1``
    in ``label_names`` order (as set up by train_model's id2label).

    Args:
        output_dir: destination directory; defaults to ``RESULTS_DIR``.

    Returns:
        The report as the nested dict produced by ``output_dict=True``.
    """
    out_dir = RESULTS_DIR if output_dir is None else output_dir
    # Pinning the label set keeps target_names aligned even if a class never
    # appears in y_true or y_pred (sklearn would otherwise raise).
    report = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(label_names))),
        target_names=label_names,
        output_dict=True,
        zero_division=0,
    )
    os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame(report).transpose().to_csv(
        os.path.join(out_dir, f"{model_key}_classification_report.csv")
    )
    return report


def save_confusion_matrix(y_true, y_pred, model_key, label_names, output_dir=None):
    """Write the labelled confusion matrix to ``<model_key>_confusion_matrix.csv``.

    Rows are true classes and columns predicted classes, both in ``label_names``
    order. Assumes integer class ids ``0..len(label_names)-1``.

    Args:
        output_dir: destination directory; defaults to ``RESULTS_DIR``.

    Returns:
        The confusion matrix as a NumPy array of shape ``(len(label_names),) * 2``.
    """
    out_dir = RESULTS_DIR if output_dir is None else output_dir
    # Without explicit labels, a class absent from both inputs would shrink the
    # matrix and break the DataFrame index/columns below.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(label_names))))
    os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame(cm, index=label_names, columns=label_names).to_csv(
        os.path.join(out_dir, f"{model_key}_confusion_matrix.csv")
    )
    return cm


def save_per_class_metrics(y_true, y_pred, model_key, label_names, output_dir=None):
    """Write per-class precision / recall / F1 / support to ``<model_key>_per_class_metrics.csv``.

    Assumes integer class ids ``0..len(label_names)-1`` in ``label_names`` order.

    Args:
        output_dir: destination directory; defaults to ``RESULTS_DIR``.

    Returns:
        DataFrame with columns ``class, precision, recall, f1_score, support``,
        one row per entry of ``label_names``.
    """
    out_dir = RESULTS_DIR if output_dir is None else output_dir
    # Explicit labels guarantee one value per class, so the arrays always line up
    # with label_names (absent classes get 0 and support 0).
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true,
        y_pred,
        labels=list(range(len(label_names))),
        average=None,
        zero_division=0,
    )

    df_metrics = pd.DataFrame({
        "class": label_names,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "support": support.astype(int),
    })
    os.makedirs(out_dir, exist_ok=True)
    df_metrics.to_csv(os.path.join(out_dir, f"{model_key}_per_class_metrics.csv"), index=False)
    return df_metrics


def save_overall_metrics(metrics_dict, model_key, output_dir=None):
    """Write one row of summary metrics to ``<model_key>_overall_metrics.csv``.

    Args:
        metrics_dict: flat mapping of column name -> value.
        model_key: prefix for the output file name.
        output_dir: destination directory; defaults to ``RESULTS_DIR``.
    """
    out_dir = RESULTS_DIR if output_dir is None else output_dir
    os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame([metrics_dict]).to_csv(
        os.path.join(out_dir, f"{model_key}_overall_metrics.csv"), index=False
    )

## Training

In [None]:
# Optimisation hyperparameters shared by every model in MODELS.
LEARNING_RATE = 2e-4
BATCH_SIZE = 64               # per device for training; evaluation uses 2x this
NUM_EPOCHS = 50               # upper bound; early stopping usually ends training sooner
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1            # fraction of training steps used for LR warm-up
EARLY_STOPPING_PATIENCE = 5   # evaluations (one per epoch) without f1_macro improvement


def _build_datasets(processor):
    """Return ``(train_ds, val_ds)``: CIFAR-10 splits with on-the-fly preprocessing.

    Training images pass through ``train_transforms`` and then ``processor``.
    Evaluation images are only converted to RGB and processed. ``with_transform``
    runs the function lazily when rows are read, so random augmentations are
    redrawn on each access.
    """

    def preprocess_train(examples):
        images = []
        for img in examples["img"]:
            img = img.convert("RGB")
            if train_transforms is not None:
                img = train_transforms(img)
            images.append(img)
        inputs = processor(images, return_tensors="pt")
        inputs["labels"] = examples["label"]
        return inputs

    def preprocess_val(examples):
        images = [img.convert("RGB") for img in examples["img"]]
        inputs = processor(images, return_tensors="pt")
        inputs["labels"] = examples["label"]
        return inputs

    train_ds = dataset["train"].with_transform(preprocess_train)
    # NOTE(review): the evaluation set here is the CIFAR-10 *test* split. It also
    # drives early stopping and best-checkpoint selection, so the reported "test"
    # metrics are optimistically biased. Kept for comparability with the MS1/MS2
    # numbers. TODO: confirm they used the same protocol, or carve a validation
    # split out of "train".
    val_ds = dataset["test"].with_transform(preprocess_val)
    return train_ds, val_ds


def _build_training_args(output_dir):
    """Shared TrainingArguments: per-epoch eval/save, reload the best checkpoint by f1_macro."""
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE * 2,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_ratio=WARMUP_RATIO,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="f1_macro",
        greater_is_better=True,
        save_total_limit=2,
        # Keep the raw "img"/"label" columns: the with_transform functions read them.
        remove_unused_columns=False,
        push_to_hub=False,
        report_to="none",
        seed=42,
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=2,
    )


def train_model(model_key, model_config):
    """Fine-tune one pretrained checkpoint on CIFAR-10 and export its results.

    Args:
        model_key: short identifier used in printed headers and output file names.
        model_config: dict with ``"model_name"`` (checkpoint id for
            ``from_pretrained``). It may also contain ``"description"``, which is
            written to the ``architecture`` column and falls back to the model name.

    Returns:
        ``(overall, df_perclass)``: a dict of summary metrics and the per-class
        metrics DataFrame. CSVs, the fine-tuned model and its processor are written
        below ``RESULTS_DIR``.
    """
    model_name = model_config["model_name"]
    print(f"\n{'='*60}")
    print(f"Training: {model_key}")
    print(f"Model: {model_name}")
    print(f"{'='*60}")

    torch.cuda.empty_cache()

    processor = AutoImageProcessor.from_pretrained(model_name)
    train_ds, val_ds = _build_datasets(processor)

    # ignore_mismatched_sizes lets checkpoints whose classifier head has a different
    # number of classes load anyway, with a newly initialised head for num_labels.
    model = AutoModelForImageClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        id2label={str(i): label for i, label in enumerate(labels)},
        label2id={label: str(i) for i, label in enumerate(labels)}
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {total_params:,} (trainable: {trainable_params:,})")

    output_dir = f"{RESULTS_DIR}/models/{model_key}"

    trainer = AugmentedTrainer(
        model=model,
        args=_build_training_args(output_dir),
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)],
        aug_config=aug_config
    )

    start_time = time.time()
    train_result = trainer.train()
    training_time = time.time() - start_time
    print(f"Training completed in {training_time/60:.1f} minutes")

    # One predict() pass yields both the metrics and the raw predictions. An extra
    # trainer.evaluate() call would repeat the pass and also log one more "eval_loss"
    # entry. That would inflate epochs_trained by one and duplicate the last row of
    # the eval-history CSV. metric_key_prefix="eval" keeps the usual "eval_*" keys.
    predictions = trainer.predict(val_ds, metric_key_prefix="eval")
    eval_result = predictions.metrics
    print(f"Test Accuracy: {eval_result['eval_accuracy']:.4f}")
    print(f"Test Loss: {eval_result['eval_loss']:.4f}")

    logits = predictions.predictions
    if isinstance(logits, tuple):  # extra model outputs, logits first
        logits = logits[0]
    y_pred = np.argmax(logits, axis=1)
    y_true = predictions.label_ids

    save_training_history(trainer, model_key)
    save_classification_report(y_true, y_pred, model_key, labels)
    save_confusion_matrix(y_true, y_pred, model_key, labels)
    df_perclass = save_per_class_metrics(y_true, y_pred, model_key, labels)

    overall = {
        "model": model_key,
        "accuracy": eval_result["eval_accuracy"],
        "f1_macro": eval_result["eval_f1_macro"],
        "precision_macro": eval_result["eval_precision_macro"],
        "recall_macro": eval_result["eval_recall_macro"],
        "test_loss": eval_result["eval_loss"],
        "train_loss": train_result.metrics.get("train_loss", 0),
        # One evaluation is logged per epoch (eval_strategy="epoch").
        "epochs_trained": sum(1 for entry in trainer.state.log_history if "eval_loss" in entry),
        "training_time_min": training_time / 60,
        "total_params": total_params,
        "architecture": model_config.get("description", model_name),
    }
    save_overall_metrics(overall, model_key)

    trainer.save_model(output_dir)
    processor.save_pretrained(output_dir)

    print(f"CSVs saved for {model_key}")

    # Drop the large objects before the next model is loaded.
    del model, trainer, predictions
    gc.collect()
    torch.cuda.empty_cache()

    return overall, df_perclass

## Start Training

In [None]:
# Train every model in MODELS in turn. A failing model is reported and skipped so
# the rest of the comparison still runs.
all_results = []
all_perclass = {}
failed_models = []

for i, (model_key, model_config) in enumerate(MODELS.items(), 1):
    print(f"\n[{i}/{len(MODELS)}] {model_key}")

    try:
        overall, perclass = train_model(model_key, model_config)
        all_results.append(overall)
        all_perclass[model_key] = perclass
        # Rewritten after every model so partial results survive a later crash.
        pd.DataFrame(all_results).to_csv(f"{RESULTS_DIR}/all_models_comparison.csv", index=False)
    except Exception as e:
        print(f"Error training {model_key}: {e}")
        traceback.print_exc()
        failed_models.append(model_key)

    # Done outside the except clause: by now the exception and its traceback frames
    # are released, so tensors a failed run left behind can actually be freed.
    gc.collect()
    torch.cuda.empty_cache()

if failed_models:
    print(f"\nFailed models: {', '.join(failed_models)}")
print("\nTraining completed!")

## Compare Results

In [None]:
# Hard-coded reference results from the earlier milestones (MS1, MS2). They are
# merged with the MS3 runs in the next cell. The keys mirror the columns written to
# all_models_comparison.csv. precision_macro, recall_macro and training_time_min are
# absent here, so they become NaN for these rows after the concat.
# NOTE(review): f1_macro equals accuracy in all three entries. Confirm these were
# copied from the F1 results and not from the accuracy column.
# NOTE(review): total_params are rounded to 0.1M. TODO: confirm against the source runs.
ms1_ms2_results = [
    {
        "model": "vit_base_ms1",
        "accuracy": 0.9896,
        "f1_macro": 0.9896,
        "test_loss": 0.0436,
        "train_loss": 0.6220,
        "epochs_trained": 21,
        "total_params": 85800000,
        "architecture": "ViT Baseline (MS1)",
        "source": "MS1"
    },
    {
        "model": "hybrid_vit_ms2",
        "accuracy": 0.9870,
        "f1_macro": 0.9870,
        "test_loss": 0.0528,
        "train_loss": 0.6238,
        "epochs_trained": 14,
        "total_params": 118300000,
        "architecture": "Hybrid ViT (MS2)",
        "source": "MS2"
    },
    {
        "model": "resnet50_ms1",
        "accuracy": 0.9750,
        "f1_macro": 0.9750,
        "test_loss": 0.0862,
        "train_loss": 0.6887,
        "epochs_trained": 49,
        "total_params": 23500000,
        "architecture": "ResNet-50 (MS1)",
        "source": "MS1"
    }
]

# One row per reference model, same column names as the MS3 comparison CSV.
df_ms1_ms2 = pd.DataFrame(ms1_ms2_results)

In [None]:
# Merge the MS3 runs (read back from the training loop's CSV) with the hard-coded
# MS1/MS2 baselines, best accuracy first.
df_ms3 = pd.read_csv(f"{RESULTS_DIR}/all_models_comparison.csv").assign(source="MS3")

df_all = (
    pd.concat([df_ms3, df_ms1_ms2], ignore_index=True)
    .sort_values("accuracy", ascending=False)
    .reset_index(drop=True)
)

print("Complete comparison of all models:")
print(df_all[["model", "accuracy", "f1_macro", "test_loss", "total_params", "source"]].to_string())

df_all.to_csv(f"{RESULTS_DIR}/complete_comparison.csv", index=False)

In [None]:
# One colour per milestone, shared by both panels (unknown sources fall back to green).
SOURCE_COLORS = {"MS1": "tab:green", "MS2": "tab:orange", "MS3": "tab:blue"}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: accuracy per model.
bar_colors = [SOURCE_COLORS.get(s, "tab:green") for s in df_all["source"]]
ax1.barh(df_all["model"], df_all["accuracy"], color=bar_colors)
ax1.invert_yaxis()  # df_all is sorted best-first; put the best model on top
ax1.set_xlabel("Accuracy")
ax1.set_title("Accuracy Comparison")
# Zoom into the high-accuracy range, but lower the bound when a model scores below
# 0.97 so its bar is not clipped out of view.
ax1.set_xlim([min(0.97, df_all["accuracy"].min() - 0.005), 1.0])

# Right: accuracy against parameter count, one series per milestone.
for source, color in SOURCE_COLORS.items():
    subset = df_all.loc[df_all["source"] == source]
    ax2.scatter(
        subset["total_params"] / 1e6,
        subset["accuracy"],
        c=color,
        s=100,
        label=source,
        alpha=0.7
    )

ax2.set_xlabel("Parameters (millions)")
ax2.set_ylabel("Accuracy")
ax2.set_title("Accuracy vs Model Size")
ax2.legend()
ax2.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig(f"{RESULTS_DIR}/comparison_plots.png", dpi=150)
plt.show()

In [None]:
# Per-class F1 of every MS3 model that trained successfully (classes x models).
if all_perclass:
    df_perclass_all = pd.concat(
        [
            per_class[["class", "f1_score"]]
            .rename(columns={"f1_score": "f1"})
            .assign(model=key)
            for key, per_class in all_perclass.items()
        ],
        ignore_index=True,
    )
    pivot = df_perclass_all.pivot(index="class", columns="model", values="f1")

    # The colour scale starts at 0.95 to separate near-perfect scores. It extends
    # lower when any class scores below that, so weak classes don't saturate silently.
    f1_floor = min(0.95, float(pivot.min().min()))

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.heatmap(pivot, annot=True, fmt=".3f", cmap="RdYlGn", vmin=f1_floor, vmax=1.0, ax=ax)
    ax.set_title("F1-Score per Class (MS3 Models)")
    fig.tight_layout()
    fig.savefig(f"{RESULTS_DIR}/perclass_heatmap.png", dpi=150)
    plt.show()

In [None]:
# List the CSV and PNG artefacts produced in RESULTS_DIR (checkpoints live in models/).
print("Saved files:\n")
for f in sorted(os.listdir(RESULTS_DIR)):
    if not f.endswith((".csv", ".png")):
        continue
    print(f"  {f}")