In [None]:
# -*- coding: utf-8 -*-
import time, random, math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# =========================
# Configuration
# =========================
# One full experiment (ERM baseline + p4 grid search) is run per seed.
SEEDS = [42, 43, 44, 45, 46]   # Multiple random seeds
# Labels: digit 4 -> class 0 (minority), digit 9 -> class 1 (majority).
IMBALANCE_RATIO = 0.98         # Training set: ratio of digit 9 (majority); digit 4 (minority) ratio = 1 - IMBALANCE_RATIO
TOTAL_TRAIN = 5000             # Number of training examples sampled per seed
BATCH_SIZE = 100               # Mini-batch size of the training DataLoader

ERM_STEPS = 8000               # Number of training steps for ERM baseline
GRID_STEP = 0.02               # Spacing of the p4 grid over [0, 1]
STEPS_PER_GRID = 8000          # Number of training steps for each p4 in grid search
# SGD hyper-parameters; used as the default arguments of train_steps().
LR = 1e-3
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0

TEST_MAJ_RATIO = 0.50          # Fraction of digit 9 in the test set (50/50 split, i.e. p4 = 0.5 at test time)
# Nominal p4 prior implied by IMBALANCE_RATIO. Only used to draw the
# "Train prior" line in the plot; the loss weights use the empirical class
# frequencies of the sampled training set instead.
P4_TRAIN_PRIOR = 1.0 - IMBALANCE_RATIO  # p4 prior in training set

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# =========================
# Utility functions
# =========================
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also enables cuDNN autotuning and leaves cuDNN deterministic mode off,
    so GPU runs are not guaranteed to be bitwise reproducible even with a
    fixed seed.

    Args:
        seed: Value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = True
    cudnn.deterministic = False

def load_mnist_49():
    """Load MNIST (downloading to ./data if needed) and keep digits 4 and 9.

    Labels are remapped to binary classes: 4 -> 0 (minority), 9 -> 1 (majority).

    Returns:
        Tuple ``(Xtr, ytr, Xte, yte)``: stacked image tensors produced by
        ``transforms.ToTensor`` and ``torch.long`` label tensors for the
        train and test splits.
    """
    to_tensor = transforms.ToTensor()
    train_set = datasets.MNIST("./data", train=True,  download=True, transform=to_tensor)
    test_set  = datasets.MNIST("./data", train=False, download=True, transform=to_tensor)

    # Original digit -> binary class id.
    class_of = {4: 0, 9: 1}

    def filt(ds):
        kept = []
        for i in range(len(ds)):
            image, digit = ds[i]
            if digit in class_of:
                kept.append((image, class_of[digit]))
        images = torch.stack([img for img, _ in kept])
        labels = torch.tensor([lab for _, lab in kept], dtype=torch.long)
        return images, labels

    Xtr, ytr = filt(train_set)
    Xte, yte = filt(test_set)
    return Xtr, ytr, Xte, yte

def sample_imbalanced_train(X, y, total, maj_ratio, seed):
    """Sample a class-imbalanced training subset.

    Re-seeds all random generators with ``set_seed(seed)``, then draws
    random samples of class 1 (majority) and class 0 (minority) and
    shuffles them together.

    Args:
        X: Input tensor indexed along dim 0.
        y: Label tensor with values 0 (minority) and 1 (majority).
        total: Requested number of samples.
        maj_ratio: Fraction of ``total`` drawn from class 1, in [0, 1].
        seed: Seed passed to ``set_seed`` before sampling.

    Returns:
        Tuple ``(X_sub, y_sub)``. If a class has fewer samples than
        requested, all of them are used and the result is shorter than
        ``total``.

    Raises:
        ValueError: If ``maj_ratio`` is outside [0, 1].
    """
    if not 0.0 <= maj_ratio <= 1.0:
        raise ValueError(f"maj_ratio must be in [0, 1], got {maj_ratio}")
    set_seed(seed)
    idx4 = (y == 0).nonzero(as_tuple=True)[0]
    idx9 = (y == 1).nonzero(as_tuple=True)[0]
    # The small tolerance keeps float error from dropping a sample:
    # e.g. 100 * 0.29 == 28.999999999999996, which int() alone turns into 28.
    n_maj = int(total * maj_ratio + 1e-9)
    n_min = total - n_maj
    sel9 = idx9[torch.randperm(len(idx9))[:n_maj]]
    sel4 = idx4[torch.randperm(len(idx4))[:n_min]]
    idx = torch.cat([sel4, sel9], dim=0)
    idx = idx[torch.randperm(len(idx))]
    return X[idx], y[idx]

def build_test_with_ratio(X, y, maj_ratio=0.5, total=1000):
    """Build a test subset with a given fraction of majority-class samples.

    Takes the first samples of each class in dataset order (so the same
    samples are chosen on every call) and shuffles their order with the
    global torch RNG.

    Args:
        X: Input tensor indexed along dim 0.
        y: Label tensor with values 0 (minority) and 1 (majority).
        maj_ratio: Fraction of ``total`` taken from class 1, in [0, 1].
        total: Requested number of samples.

    Returns:
        Tuple ``(X_sub, y_sub)``. If a class has fewer samples than
        requested, all of them are used and the result is shorter than
        ``total``.

    Raises:
        ValueError: If ``maj_ratio`` is outside [0, 1].
    """
    if not 0.0 <= maj_ratio <= 1.0:
        raise ValueError(f"maj_ratio must be in [0, 1], got {maj_ratio}")
    # The small tolerance keeps float error from dropping a sample:
    # e.g. 100 * 0.29 == 28.999999999999996, which int() alone turns into 28.
    n_maj = int(total * maj_ratio + 1e-9)
    n_min = total - n_maj
    idx4 = (y == 0).nonzero(as_tuple=True)[0][:n_min]
    idx9 = (y == 1).nonzero(as_tuple=True)[0][:n_maj]
    idx = torch.cat([idx4, idx9], dim=0)
    idx = idx[torch.randperm(len(idx))]
    return X[idx], y[idx]

class LeNet5(nn.Module):
    """LeNet-5-style CNN with ReLU activations, average pooling and 2 outputs.

    The first linear layer expects 16*4*4 features, which matches
    single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool  = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1   = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2   = nn.Linear(in_features=120, out_features=84)
        self.fc3   = nn.Linear(in_features=84, out_features=2)

    def forward(self, x):
        """Return the two class logits for a batch of images."""
        feats = x
        # Conv stages on 28x28 input: 28 -> 24 -> 12, then 12 -> 8 -> 4.
        for conv in (self.conv1, self.conv2):
            feats = self.pool(F.relu(conv(feats)))
        # Flatten everything but the batch dimension.
        hidden = feats.view(feats.size(0), -1)
        for fc in (self.fc1, self.fc2):
            hidden = F.relu(fc(hidden))
        return self.fc3(hidden)

# Unreduced cross-entropy: one loss value per sample, so that each sample
# can be re-weighted before averaging.
ce = nn.CrossEntropyLoss(reduction='none')

def compute_weighted_loss(logits, targets, p_vec, p_train_vec):
    """Return the batch mean of importance-weighted cross-entropy.

    Each sample with label ``t`` gets weight ``p_vec[t] / p_train_vec[t]``;
    the weighted per-sample losses are averaged over the batch size (not
    over the sum of the weights).

    Args:
        logits: Model outputs passed to ``ce``.
        targets: Integer class labels, also used to index the two vectors.
        p_vec: Target weight of each class.
        p_train_vec: Training frequency of each class.

    Returns:
        Scalar loss tensor.
    """
    target_prob = p_vec[targets]
    train_prob = p_train_vec[targets]
    sample_weight = target_prob / train_prob
    per_sample = ce(logits, targets)
    weighted = per_sample * sample_weight
    return weighted.mean()

def train_steps(model, loader, steps, p_vec, p_train_vec, lr=LR, momentum=MOMENTUM, wd=WEIGHT_DECAY, verbose_every=0):
    """Run a fixed number of SGD steps on the importance-weighted loss.

    The loader is restarted whenever it is exhausted, so ``steps`` may span
    any number of passes over the data.

    Args:
        model: Network to train in place (switched to train mode).
        loader: Iterable yielding ``(inputs, labels)`` batches.
        steps: Number of optimizer steps to perform.
        p_vec: Target class weights forwarded to ``compute_weighted_loss``.
        p_train_vec: Training class frequencies forwarded likewise.
        lr: SGD learning rate.
        momentum: SGD momentum.
        wd: SGD weight decay.
        verbose_every: If non-zero, print the loss every this many steps.
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
    batches = iter(loader)
    step = 0
    while step < steps:
        step += 1
        batch = next(batches, None)
        if batch is None:
            # Loader exhausted: start another pass over the data.
            batches = iter(loader)
            batch = next(batches)
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        loss = compute_weighted_loss(model(inputs), labels, p_vec, p_train_vec)
        loss.backward()
        optimizer.step()

        if verbose_every and step % verbose_every == 0:
            print(f"[train] step {step}/{steps} loss={loss.item():.4f}")

@torch.no_grad()
def evaluate(model, loader):
    """Compute overall, per-group and worst-group accuracy.

    Groups are the two class labels 0 and 1.

    Args:
        model: Network to evaluate (switched to eval mode).
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(overall, acc_g, worst)`` where ``acc_g`` maps each group to
        its accuracy. A group with no samples has accuracy NaN, in which
        case ``worst`` is NaN as well; ``overall`` is NaN if the loader
        yields no samples.
    """
    model.eval()
    groups = (0, 1)
    nan = float('nan')
    correct = 0
    total = 0
    correct_g = {g: 0 for g in groups}
    total_g = {g: 0 for g in groups}
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        hits = model(xb).argmax(1) == yb
        correct += hits.sum().item()
        total += yb.numel()
        for g in groups:
            mask = (yb == g)
            correct_g[g] += hits[mask].sum().item()
            total_g[g] += mask.sum().item()
    # Treat an empty loader like an empty group (NaN) instead of raising.
    overall = correct / total if total > 0 else nan
    acc_g = {g: (correct_g[g] / total_g[g] if total_g[g] > 0 else nan) for g in groups}
    # min() over values containing NaN depends on argument order, so report
    # NaN explicitly whenever any group is empty.
    if all(total_g[g] > 0 for g in groups):
        worst = min(acc_g.values())
    else:
        worst = nan
    return overall, acc_g, worst

# =========================
# Main loop: multi-seed experiments
# =========================
# Candidate p4 values from 0 to 1 in GRID_STEP increments; the 1e-8 slack
# makes np.arange include the upper end point.
grid = np.arange(0.0, 1.0+1e-8, GRID_STEP)
# For each p4: list of test metrics, one entry appended per seed.
acc_overall_map = {float(p4): [] for p4 in grid}
acc_worst_map   = {float(p4): [] for p4 in grid}
acc_min_map     = {float(p4): [] for p4 in grid}
acc_maj_map     = {float(p4): [] for p4 in grid}

# ERM baseline metrics, one entry per seed.
erm_overall_list, erm_worst_list, erm_min_list, erm_maj_list = [], [], [], []

t0_all = time.time()

# Load and filter MNIST once; every seed draws its subsets from these tensors.
Xtr_all, ytr_all, Xte_all, yte_all = load_mnist_49()
print("Filtered train:", Xtr_all.shape, torch.bincount(ytr_all).tolist())
print("Filtered test :",  Xte_all.shape, torch.bincount(yte_all).tolist())

for seed in SEEDS:
    print(f"\n========== SEED {seed} ==========")
    set_seed(seed)

    # Test subset of 1000 samples with TEST_MAJ_RATIO of digit 9.
    Xte, yte = build_test_with_ratio(Xte_all, yte_all, maj_ratio=TEST_MAJ_RATIO, total=1000)
    print("Test counts (4,9):", torch.bincount(yte).tolist())

    # Imbalanced training subset; the sampler re-seeds the RNGs with `seed`.
    Xtr, ytr = sample_imbalanced_train(Xtr_all, ytr_all, TOTAL_TRAIN, IMBALANCE_RATIO, seed)
    # Copy of the training inputs; it is not modified further in this section.
    Xtr_mod = Xtr.clone()
    # Empirical class frequencies of the sampled training set.
    p_train_counts = torch.bincount(ytr, minlength=2).float()
    p_train_vec = (p_train_counts / p_train_counts.sum()).to(device)
    print("Imbalanced train counts:", p_train_counts.tolist(), "  p_train:", (p_train_vec.cpu().numpy().round(6)).tolist())

    # The same loaders are shared by the ERM run and every grid-search run.
    train_loader = DataLoader(TensorDataset(Xtr_mod, ytr), batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    test_loader  = DataLoader(TensorDataset(Xte, yte),     batch_size=1000,    shuffle=False)

    # === ERM (p = p_train) ===
    # With p_vec equal to p_train_vec every sample weight is 1, i.e. the
    # loss is plain unweighted cross-entropy.
    print("\n=== ERM (p = p_train) ===")
    model_erm = LeNet5().to(device)
    train_steps(model_erm, train_loader, steps=ERM_STEPS, p_vec=p_train_vec, p_train_vec=p_train_vec)
    erm_overall, erm_accg, erm_worst = evaluate(model_erm, test_loader)
    print(f"ERM test Acc: overall={erm_overall:.4f}, 4(min)={erm_accg[0]:.4f}, 9(maj)={erm_accg[1]:.4f}, worst={erm_worst:.4f}")
    erm_overall_list.append(erm_overall)
    erm_worst_list.append(erm_worst)
    erm_min_list.append(erm_accg[0])
    erm_maj_list.append(erm_accg[1])

    # === Grid search (from scratch) ===
    # A freshly initialised model is trained for every p4, with class
    # weights p_vec / p_train_vec where p_vec = [p4, 1 - p4].
    print("\n=== Grid search (from scratch) ===")
    for p4 in grid:
        # Plain float so it matches the keys of the result maps.
        p4 = float(p4)
        p_vec = torch.tensor([p4, 1.0-p4], dtype=torch.float32, device=device)
        model = LeNet5().to(device)
        train_steps(model, train_loader, steps=STEPS_PER_GRID, p_vec=p_vec, p_train_vec=p_train_vec)
        overall, accg, worst = evaluate(model, test_loader)

        acc_overall_map[p4].append(overall)
        acc_worst_map[p4].append(worst)
        acc_min_map[p4].append(accg[0])
        acc_maj_map[p4].append(accg[1])

        print(f"  p4={p4:.2f} -> overall={overall:.4f} | 4={accg[0]:.4f} 9={accg[1]:.4f} | worst={worst:.4f}")

print(f"\nTotal wall time: {(time.time()-t0_all)/60:.1f} min")

# =========================
# Aggregate results
# =========================
def mean_std(arr):
    """Return ``(mean, sample std)`` of ``arr`` as Python floats.

    The standard deviation uses ``ddof=1``; it is reported as 0.0 when
    ``arr`` has fewer than two elements.
    """
    mean = float(np.mean(arr))
    if len(arr) > 1:
        spread = float(np.std(arr, ddof=1))
    else:
        spread = 0.0
    return mean, spread

# Across-seed (mean, sample std) of each ERM baseline metric.
(
    (erm_overall_mean, erm_overall_std),
    (erm_worst_mean, erm_worst_std),
    (erm_min_mean, erm_min_std),
    (erm_maj_mean, erm_maj_std),
) = (
    mean_std(values)
    for values in (erm_overall_list, erm_worst_list, erm_min_list, erm_maj_list)
)

print("\n==== Aggregated over seeds ====")
print(f"ERM overall = {erm_overall_mean:.4f} ± {erm_overall_std:.4f}")
print(f"ERM worst   = {erm_worst_mean:.4f} ± {erm_worst_std:.4f}")
print(f"ERM min(4)  = {erm_min_mean:.4f} ± {erm_min_std:.4f}")
print(f"ERM maj(9)  = {erm_maj_mean:.4f} ± {erm_maj_std:.4f}")

# Grid values in ascending order; shared x-axis for all aggregated curves.
p4s = np.array(sorted(acc_overall_map))

def agg_curve(metric_map):
    """Aggregate a per-p4 metric map into mean and std curves over ``p4s``.

    Args:
        metric_map: Maps each value in ``p4s`` to a list of per-seed metrics.

    Returns:
        Tuple ``(means, stds)`` of NumPy arrays ordered like ``p4s``.
    """
    stats = [mean_std(metric_map[p]) for p in p4s]
    means = np.array([m for m, _ in stats])
    stds = np.array([s for _, s in stats])
    return means, stds

# Mean / std curves over the p4 grid for each metric.
(
    (overall_mean, overall_std),
    (worst_mean, worst_std),
    (min_mean, min_std),
    (maj_mean, maj_std),
) = (
    agg_curve(metric_map)
    for metric_map in (acc_overall_map, acc_worst_map, acc_min_map, acc_maj_map)
)

# Grid positions with the highest mean overall / worst-group accuracy.
best_overall_idx = int(overall_mean.argmax())
best_worst_idx   = int(worst_mean.argmax())
print(f"[Overall-mean best] p4={p4s[best_overall_idx]:.2f}, mean overall={overall_mean[best_overall_idx]:.4f}±{overall_std[best_overall_idx]:.4f}, "
      f"mean worst={worst_mean[best_overall_idx]:.4f}±{worst_std[best_overall_idx]:.4f}")
print(f"[Worst-mean  best] p4={p4s[best_worst_idx]:.2f}, mean overall={overall_mean[best_worst_idx]:.4f}±{overall_std[best_worst_idx]:.4f}, "
      f"mean worst={worst_mean[best_worst_idx]:.4f}±{worst_std[best_worst_idx]:.4f}")

# =========================
# Plot curves
# =========================
plt.figure(figsize=(8,6))

# Each curve is drawn as its mean line plus a shaded band of one std.
for curve_mean, curve_std, curve_label in (
    (overall_mean, overall_std, 'Overall Acc (mean)'),
    (worst_mean, worst_std, 'Worst-Group Acc (mean)'),
    (min_mean, min_std, 'Minority (digit 4) Acc (mean)'),
):
    plt.plot(p4s, curve_mean, label=curve_label)
    plt.fill_between(p4s, curve_mean - curve_std, curve_mean + curve_std, alpha=0.2)

# Reference lines: test-time prior (p4 = 0.5) and nominal training prior.
plt.axvline(x=0.5, linestyle='--', label='LR p4 = 0.5 (test prior)')
plt.axvline(x=P4_TRAIN_PRIOR, linestyle=':', label='Train prior p4')

plt.xlabel('Group weight p4 (digit 4)')
plt.ylabel('Accuracy')
plt.title('Accuracy vs p4 (From Scratch, multi-seed mean ± std)')
plt.legend()
plt.grid(True)
plt.show()