In [None]:
# Experiment cell: augmentation + oversampling for minority class (cats)

import torch
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
import random

from dataset import load_imbalanced_cifar
from model import SimpleCNN
from pipeline.train import train_model
from pipeline.eval import eval_model
from utils import get_device, set_seed, plot_cm
from experiment_logger import log_experiment_to_sheet

# ==== SETUP ====
# Seed and device come from the project helpers in `utils`.
# NOTE(review): presumably set_seed covers the `random` module used for the
# split below — confirm in utils.set_seed.
set_seed(42)
device = get_device()
print("Device:", device)

# ==== DEFINE AUGMENTATION (same as before) ====
# Pipeline used as the per-sample transform for the cat class further down:
# random horizontal flip, rotation by up to 20 degrees either way, and
# brightness/contrast jitter. Arguments are spelled out by keyword so the
# settings are readable without looking up positional order.
cat_aug = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=20),
        transforms.ColorJitter(brightness=0.3, contrast=0.3),
    ]
)


class AugmentedCIFAR(Dataset):
    """View over ``base_dataset`` limited to ``indices``, augmenting one class.

    ``indices`` may contain repeats (for example an oversampled class). The
    transform is selected by the sample's *label*, not by its index, and is
    invoked again on every fetch, so a random transform can give a different
    image each time the same index is read.

    Args:
        base_dataset: Indexable object whose items unpack as ``(img, label)``.
        indices: Positions in ``base_dataset`` exposed by this wrapper, in order.
        augment_label: Label whose samples are passed through ``transform``.
        transform: Callable applied to the image of matching samples. A falsy
            value (the default ``None``) disables augmentation.
    """

    def __init__(self, base_dataset, indices, augment_label=0, transform=None):
        self.base, self.indices = base_dataset, indices
        self.augment_label, self.transform = augment_label, transform

    def __len__(self):
        """Return how many entries are exposed (repeated indices count each time)."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(img, label)`` for position ``idx``, augmented if the label matches."""
        img, label = self.base[self.indices[idx]]

        # Falsy unless the label matches AND a transform was supplied.
        should_augment = label == self.augment_label and self.transform
        if not should_augment:
            return img, label

        return self.transform(img), label


# ==== LOAD ORIGINAL DATA ====
# Requested class sizes, named once so the loader call and the log line
# cannot disagree.
CAT_COUNT = 30
DOG_COUNT = 500

base_dataset = load_imbalanced_cifar(cat_count=CAT_COUNT, dog_count=DOG_COUNT)
print(f"Loaded base dataset ({CAT_COUNT} cats / {DOG_COUNT} dogs)")

# ==== GET INDICES ====
# Read the labels in a single pass over the dataset (each sample is fetched
# once), then derive both index lists from that in-memory label list.
# Label 0 is treated as cat and label 1 as dog; any other label is ignored.
labels = [y for _, y in base_dataset]
cat_indices = [i for i, y in enumerate(labels) if y == 0]
dog_indices = [i for i, y in enumerate(labels) if y == 1]

# ==== TRAIN/TEST SPLIT ====
# Shuffle each class on its own, then cut 80/20 per class so the train and
# test index lists are disjoint within each class.
random.shuffle(cat_indices)
random.shuffle(dog_indices)

n_cat_train = int(0.8 * len(cat_indices))
n_dog_train = int(0.8 * len(dog_indices))

cat_train, cat_test = cat_indices[:n_cat_train], cat_indices[n_cat_train:]
dog_train, dog_test = dog_indices[:n_dog_train], dog_indices[n_dog_train:]

# ==== OVERSAMPLING: make cats == dogs ====
# Repeat the training cat indices until there are exactly as many cat entries
# as dog entries: `repeat_factor` whole copies plus the first `remainder`
# indices. If there were already more cats than dogs, repeat_factor is 0 and
# the cats are truncated to `target` instead.
if not cat_train:
    # Without this guard the division below fails with a bare ZeroDivisionError.
    raise ValueError(
        "No cat samples in the training split; cannot oversample the minority class."
    )

target = len(dog_train)  # number of dogs in train
repeat_factor, remainder = divmod(target, len(cat_train))

oversampled_cats = cat_train * repeat_factor + cat_train[:remainder]

balanced_train_indices = oversampled_cats + dog_train
random.shuffle(balanced_train_indices)

print(f"Train cats: {len(oversampled_cats)}, Train dogs: {len(dog_train)}")
print(f"Balanced train size: {len(balanced_train_indices)}")

# ==== HYPERPARAMETERS ====
# Single source of truth: the same values drive the loaders, the training
# call, and the config that gets logged, so the logged run cannot drift from
# the run that actually happened.
BATCH_SIZE = 32
EPOCHS = 5

# ==== BUILD DATASETS ====
# Train: the balanced (oversampled) index list, with cat_aug applied to
# label-0 samples each time one is fetched.
train_ds = AugmentedCIFAR(
    base_dataset, balanced_train_indices, augment_label=0, transform=cat_aug
)

# Test: the held-out indices in a plain Subset — no oversampling, no cat_aug.
test_ds = Subset(base_dataset, cat_test + dog_test)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

print("Final train size:", len(train_ds))
print("Final test size:", len(test_ds))

# ==== MODEL & TRAINING ====
model = SimpleCNN().to(device)

model, losses = train_model(
    model=model, train_loader=train_loader, device=device, epochs=EPOCHS
)

# ==== EVALUATION ====
# eval_model returns four values; `cm` is handed to plot_cm below
# (presumably a confusion matrix — confirm in pipeline.eval).
overall, cat_acc, dog_acc, cm = eval_model(model, test_loader, device)

print("==== Evaluation with Balanced Augmentation ====")
print("Overall accuracy:", overall)
print("Cat accuracy:", cat_acc)
print("Dog accuracy:", dog_acc)

plot_cm(cm)

# ==== LOG TO GOOGLE SHEETS ====
metrics = {
    "overall_acc": overall,
    "cat_acc": cat_acc,
    "dog_acc": dog_acc,
}

config = {
    "method": "augmentation_balanced",
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    # NOTE(review): this learning rate is only recorded; it is never passed to
    # train_model in this cell. Confirm it matches train_model's own default.
    "lr": 1e-3,
    "cats_original": len(cat_indices),
    "dogs_original": len(dog_indices),
    "cats_after_oversample": len(oversampled_cats),
}

log_experiment_to_sheet(
    experiment_name="augmentation",
    metrics=metrics,
    config=config,
    notes="Balanced augmentation: oversampling + transforms",
)

