In [None]:
# weighted_loss.py

import sys, os
import torch
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
from experiment_logger import log_experiment_to_sheet

# ==== PATHS ====
PROJECT_ROOT = "/content/cnn-fairness-tradeoff/project"
COMMON_PATH = os.path.join(PROJECT_ROOT, "common")
sys.path.append(COMMON_PATH)

# ==== IMPORTS ====
from dataset import load_imbalanced_cifar
from model import SimpleCNN
from pipeline.train import train_model
from pipeline.eval import eval_model
from utils import get_device, set_seed, plot_cm

# ==== SETUP ====
# Seed before anything random happens below (the split and the shuffling
# loader). NOTE(review): presumably set_seed covers torch's global RNG, which
# random_split draws from when no generator is given -- confirm in utils.
set_seed(42)
device = get_device()
print("Device:", device)

# ==== LOAD DATA ====
# Imbalanced cat/dog dataset, requested with 30 cats and 500 dogs.
dataset = load_imbalanced_cifar(dog_count=500, cat_count=30)

# 80/20 train/test partition; the test side takes whatever rounding leaves over.
n_samples = len(dataset)
train_size = int(0.8 * n_samples)
test_size = n_samples - train_size
train_ds, test_ds = random_split(dataset, lengths=[train_size, test_size])

# Only the training loader shuffles; evaluation iterates in a fixed order.
loader_batch = 32
train_loader = DataLoader(train_ds, shuffle=True, batch_size=loader_batch)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=loader_batch)

# ==== CLASS WEIGHTS ====
# Build inverse-frequency class weights for the loss: the rarer a class is in
# the training data, the more each of its samples counts.

# Collect every label in one pass over the dataset (each item is unpacked as an
# (input, label) pair; the counting below treats label 0 as cat and 1 as dog).
labels = torch.tensor([y for _, y in dataset])
num_cats = (labels == 0).sum().item()
num_dogs = (labels == 1).sum().item()

# "\u2192" is a right arrow; the escape keeps the message intact even if the
# file is re-saved with a different encoding.
print(f"Class counts \u2192 Cats: {num_cats}, Dogs: {num_dogs}")

# The weights must describe the data the loss actually sees, so they are taken
# from the training split only; counting the whole dataset would let test-split
# labels influence training. `train_ds` is the Subset returned by random_split,
# whose `indices` point back into `dataset`, so no second pass is needed.
train_labels = labels[train_ds.indices]
train_cats = (train_labels == 0).sum().item()
train_dogs = (train_labels == 1).sum().item()
print(f"Train split counts \u2192 Cats: {train_cats}, Dogs: {train_dogs}")

# A class missing from the training split has no defined inverse frequency;
# fail with a clear message instead of a bare ZeroDivisionError.
if train_cats == 0 or train_dogs == 0:
    raise ValueError(
        "Cannot compute inverse-frequency class weights: the training split "
        f"contains {train_cats} cats and {train_dogs} dogs; every class needs "
        "at least one sample."
    )

# Inverse frequency weighting. Only the ratio between the two weights matters:
# with CrossEntropyLoss's default 'mean' reduction the batch loss is divided by
# the summed weights of its targets, so the small absolute values are harmless.
weight_cat = 1.0 / train_cats
weight_dog = 1.0 / train_dogs

# Index i of the weight tensor applies to class i, and it must live on the same
# device as the model outputs.
class_weights = torch.tensor([weight_cat, weight_dog], device=device)
print("Class weights:", class_weights)

criterion = nn.CrossEntropyLoss(weight=class_weights)

# ==== MODEL ====
# Instantiate the network, then place it on the selected device.
model = SimpleCNN()
model = model.to(device)

# ==== TRAINING ====
# Train for five epochs with the weighted criterion. train_model hands back the
# model together with `losses` (presumably the training-loss history --
# confirm in pipeline.train).
train_args = {
    "model": model,
    "train_loader": train_loader,
    "device": device,
    "criterion": criterion,
    "epochs": 5,
}
model, losses = train_model(**train_args)

# ==== EVALUATION ====
# eval_model yields four values, unpacked here as overall accuracy, cat
# accuracy, dog accuracy and a confusion matrix.
eval_results = eval_model(model, test_loader, device)
overall, cat_acc, dog_acc, cm = eval_results

# Report the three accuracies under a banner naming the method.
print("==== Evaluation with Weighted Loss ====")
for _caption, _score in (
    ("Overall accuracy:", overall),
    ("Cat accuracy:", cat_acc),
    ("Dog accuracy:", dog_acc),
):
    print(_caption, _score)

# Render the confusion matrix.
plot_cm(cm)

# ==== SAVE RESULTS TO GOOGLE SHEETS ====
# Record this run's metrics and settings through log_experiment_to_sheet.

# Accuracies exactly as returned by the evaluation step.
metric_names = ("overall_acc", "cat_acc", "dog_acc")
metric_values = (overall, cat_acc, dog_acc)
metrics = dict(zip(metric_names, metric_values))

# NOTE(review): epochs, batch_size and lr are repeated here as literals rather
# than read from the objects that used them; keep them in sync by hand. `lr` is
# not passed to train_model anywhere in this file -- presumably it mirrors
# train_model's default; verify in pipeline.train.
hyperparameters = {
    "method": "weighted_loss",
    "epochs": 5,
    "batch_size": 32,
    "lr": 1e-3,
}

# Class counts and the per-class loss weights computed earlier in the script.
class_balance = {
    "cat_count": num_cats,
    "dog_count": num_dogs,
    "weight_cat": float(weight_cat),
    "weight_dog": float(weight_dog),
}

# Merge in this order so the keys keep their original sequence.
config = {**hyperparameters, **class_balance}

log_experiment_to_sheet(
    experiment_name="weighted_loss",
    metrics=metrics,
    config=config,
    notes="Inverse-frequency class weights (cost-sensitive learning).",
)