In [None]:
# Experiment cell: augmentation – minority-only data augmentation (cats)

import torch
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import transforms

from dataset import load_imbalanced_cifar
from model import SimpleCNN
from pipeline.train import train_model
from pipeline.eval import eval_model
from utils import get_device, set_seed, plot_cm
from experiment_logger import log_experiment_to_sheet

# ==== SETUP ====
set_seed(42)
device = get_device()
print("Device:", device)

# ==== DEFINE AUGMENTED DATASET ====

# Random, label-preserving perturbations that will be applied to cat samples only.
# NOTE(review): these run on whatever load_imbalanced_cifar yields (PIL image or
# tensor) — confirm the type. For float tensors, ColorJitter may clamp values to
# [0, 1], which would distort inputs that are already mean/std-normalized.
_cat_aug_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
]
cat_aug = transforms.Compose(_cat_aug_steps)


class AugmentedCIFAR(Dataset):
    """Dataset wrapper that applies ``transform`` only to samples of one label.

    Samples whose label equals ``augment_for_label`` (here: cats = 0) are passed
    through ``transform`` each time they are fetched. All other samples, and every
    sample when ``transform`` is None, are returned unchanged.

    Args:
        base_dataset: indexable, sized dataset yielding ``(img, label)`` pairs.
        augment_for_label: the label whose samples get transformed.
        transform: callable applied to the image, or None to disable augmentation.
    """

    def __init__(self, base_dataset, augment_for_label=0, transform=None):
        self.base, self.augment_label, self.transform = (
            base_dataset,
            augment_for_label,
            transform,
        )

    def __len__(self):
        # Wrapping never adds or drops samples.
        return len(self.base)

    def __getitem__(self, idx):
        sample, target = self.base[idx]
        # Guard clause: pass through anything that should not be augmented.
        if self.transform is None or target != self.augment_label:
            return sample, target
        return self.transform(sample), target


# ==== LOAD DATA ====
# Hyperparameters are defined once so the loader/training calls and the logged
# config below cannot drift apart.
EPOCHS = 5
BATCH_SIZE = 32
# NOTE(review): LR is only logged — it is not passed to train_model, so the rate
# actually used is whatever train_model defaults to. Confirm the two match.
LR = 1e-3

base_dataset = load_imbalanced_cifar(cat_count=30, dog_count=500)

# Split the raw (un-augmented) dataset FIRST, then wrap only the training subset
# with the cat-only augmentation. Wrapping the full dataset before splitting would
# also randomly flip/rotate/jitter the cats in the test set, so evaluation would
# run on distorted, non-deterministic inputs. The split is drawn from the same
# global RNG over a dataset of the same length, so the indices are unchanged.
train_size = int(0.8 * len(base_dataset))
test_size = len(base_dataset) - train_size
train_subset, test_ds = random_split(base_dataset, [train_size, test_size])
train_ds = AugmentedCIFAR(train_subset, augment_for_label=0, transform=cat_aug)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train size: {len(train_ds)}, Test size: {len(test_ds)}")

# ==== MODEL & TRAINING ====
model = SimpleCNN().to(device)

model, losses = train_model(
    model=model,
    train_loader=train_loader,
    device=device,
    epochs=EPOCHS,
)

# ==== EVALUATION ====
# The test set is never augmented. With cat_count=30 and an 80/20 random split it
# holds only about 6 cats on average, so cat_acc is a very noisy estimate.
overall, cat_acc, dog_acc, cm = eval_model(model, test_loader, device)

print("==== Evaluation with Augmentation ====")
print("Overall accuracy:", overall)
print("Cat accuracy:", cat_acc)
print("Dog accuracy:", dog_acc)

plot_cm(cm)

# ==== LOG TO GOOGLE SHEETS ====
# Class counts over the full (train + test) dataset, before any augmentation.
# Label 0 is cat (the class augmented above) and label 1 is dog.
labels = torch.tensor([y for _, y in base_dataset])
num_cats = (labels == 0).sum().item()
num_dogs = (labels == 1).sum().item()

metrics = {
    "overall_acc": overall,
    "cat_acc": cat_acc,
    "dog_acc": dog_acc,
}

config = {
    "method": "augmentation",
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "lr": LR,
    "augmentation": "cat_only_flip_rotate_color",
    "cat_count_original": num_cats,
    "dog_count_original": num_dogs,
}

log_experiment_to_sheet(
    experiment_name="augmentation",
    metrics=metrics,
    config=config,
    notes="Applied augmentation only to minority class (cats).",
)