
# Co-Teaching on FashionMNIST (0.3 & 0.6) — **Fixed Loader v2**
This version fixes the `Unexpected image shape (N, 784)` error by **automatically reshaping flattened images** to `(N, 1, 28, 28)`.


In [2]:
import os, json, math, random, time
from pathlib import Path
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

import pandas as pd

torch.backends.cudnn.benchmark = True
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)
print("Working dir:", Path.cwd())


Using device: cpu
Working dir: /Users/wukris/5328


In [3]:
from dataclasses import dataclass

@dataclass
class Config:
    dataset_path: str
    n_classes: int = 3
    batch_size: int = 256
    epochs: int = 25
    lr: float = 1e-3
    weight_decay: float = 5e-4
    num_workers: int = 0     # safer default on macOS
    noise_rate: float = 0.3
    num_gradual: int = 10
    exponent: float = 1.0
    n_trials: int = 10
    out_dir: str = "./results_coteaching"

def set_seed(seed: int):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

def ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)


In [4]:
def _to_chw4(x: np.ndarray) -> np.ndarray:
    """Convert images to (N,C,H,W). Supports:
    - (N, H, W)  -> (N, 1, H, W)
    - (N, 784)   -> (N, 1, 28, 28)   # flattened FashionMNIST
    - (N, 3072)  -> (N, 3, 32, 32)   # flattened CIFAR-style (not used here but safe)
    - (N, C, H, W) with C in {1,3} -> return as-is
    """
    if x.ndim == 4 and x.shape[1] in (1, 3):
        return x
    if x.ndim == 3:
        # assume grayscale (N,H,W)
        return x[:, None, :, :]
    if x.ndim == 2:
        N, D = x.shape
        if D == 28*28:
            return x.reshape(N, 1, 28, 28)
        if D == 32*32*3:
            return x.reshape(N, 3, 32, 32)
        # fallback: try square grayscale
        s = int(round(D ** 0.5))
        if s*s == D:
            return x.reshape(N, 1, s, s)
    raise ValueError(f"Unexpected image shape {x.shape}")

def load_npz_dataset(path: str):
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Dataset not found: {path.resolve()}" )
    data = np.load(path)
    Xtr, Str = data['Xtr'], data['Str']
    Xts, Yts = data['Xts'], data['Yts']

    # float + scaling if needed
    Xtr = Xtr.astype(np.float32); Xts = Xts.astype(np.float32)
    if Xtr.max() > 1.0 or Xts.max() > 1.0:
        Xtr /= 255.0; Xts /= 255.0

    Xtr = _to_chw4(Xtr)
    Xts = _to_chw4(Xts)

    return Xtr, Str.astype(np.int64), Xts, Yts.astype(np.int64)

def split_train_val(X, y, val_ratio=0.2, seed=0):
    n = len(y); idx = np.arange(n)
    rng = np.random.default_rng(seed); rng.shuffle(idx)
    n_val = int(n * val_ratio)
    return idx[n_val:], idx[:n_val]

def make_loaders(Xtr, Str, Xts, Yts, tr_idx, val_idx, batch_size=256, num_workers=0):
    Xtr_t = torch.from_numpy(Xtr[tr_idx]); Str_t = torch.from_numpy(Str[tr_idx])
    Xval_t = torch.from_numpy(Xtr[val_idx]); Sval_t = torch.from_numpy(Str[val_idx])
    Xts_t  = torch.from_numpy(Xts);          Yts_t  = torch.from_numpy(Yts)

    tr_ds = TensorDataset(Xtr_t, Str_t)
    val_ds= TensorDataset(Xval_t, Sval_t)
    ts_ds = TensorDataset(Xts_t, Yts_t)

    tr_loader = DataLoader(tr_ds, batch_size=batch_size, shuffle=True,  drop_last=True,  num_workers=num_workers)
    val_loader= DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers)
    ts_loader = DataLoader(ts_ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers)
    return tr_loader, val_loader, ts_loader


In [5]:
def preview_dataset(path):
    Xtr, Str, Xts, Yts = load_npz_dataset(path)
    print("Train X shape:", Xtr.shape, "| noisy y:", Str.shape)
    print("Test  X shape:", Xts.shape, "| clean y:", Yts.shape)
    print("Train X stats: min", float(Xtr.min()), "max", float(Xtr.max()), "mean", float(Xtr.mean()))


In [6]:
class SmallCNN(nn.Module):
    def __init__(self, n_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2,2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool2 = nn.AdaptiveAvgPool2d((7,7))
        self.fc1 = nn.Linear(128*7*7, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = torch.relu(self.bn3(self.conv3(x)))
        x = self.pool2(x)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [7]:
def coteaching_schedule(epoch, noise_rate=0.3, num_gradual=10, exponent=1.0):
    t = min(epoch / max(1, num_gradual), 1.0)
    forget_rate = noise_rate * (t ** exponent)
    return forget_rate, 1.0 - forget_rate

def select_small_loss_idx(losses, remember_rate):
    B = losses.shape[0]
    k = max(int(remember_rate * B), 1)
    _, idx = torch.topk(losses, k=k, largest=False)
    return idx

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval(); correct = 0; total = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = model(xb).argmax(1)
        correct += (pred==yb).sum().item(); total += yb.size(0)
    return correct / max(total,1)

@torch.no_grad()
def evaluate_ensemble(m1, m2, loader, device):
    m1.eval(); m2.eval(); correct = 0; total = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = (m1(xb)+m2(xb)).argmax(1)
        correct += (pred==yb).sum().item(); total += yb.size(0)
    return correct / max(total,1)


In [8]:
def train_coteaching_one_trial(cfg: Config, seed: int):
    set_seed(seed)
    Xtr, Str, Xts, Yts = load_npz_dataset(cfg.dataset_path)
    tr_idx, val_idx = split_train_val(Xtr, Str, val_ratio=0.2, seed=seed)
    tr_loader, val_loader, ts_loader = make_loaders(
        Xtr, Str, Xts, Yts, tr_idx, val_idx,
        batch_size=cfg.batch_size, num_workers=cfg.num_workers
    )

    m1 = SmallCNN(cfg.n_classes).to(DEVICE)
    m2 = SmallCNN(cfg.n_classes).to(DEVICE)
    opt1 = torch.optim.AdamW(m1.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    opt2 = torch.optim.AdamW(m2.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sch1 = torch.optim.lr_scheduler.CosineAnnealingLR(opt1, T_max=cfg.epochs)
    sch2 = torch.optim.lr_scheduler.CosineAnnealingLR(opt2, T_max=cfg.epochs)
    ce = nn.CrossEntropyLoss(reduction='none')

    hist = {"seed": seed, "ts_acc_ens": []}

    for ep in range(cfg.epochs):
        m1.train(); m2.train()
        fr, rr = coteaching_schedule(ep, cfg.noise_rate, cfg.num_gradual, cfg.exponent)
        for xb, yb in tr_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            l1 = ce(m1(xb), yb); l2 = ce(m2(xb), yb)
            idx1 = select_small_loss_idx(l1.detach(), rr)
            idx2 = select_small_loss_idx(l2.detach(), rr)

            loss1 = ce(m1(xb)[idx2], yb[idx2]).mean()
            loss2 = ce(m2(xb)[idx1], yb[idx1]).mean()

            opt1.zero_grad(set_to_none=True); loss1.backward(); opt1.step()
            opt2.zero_grad(set_to_none=True); loss2.backward(); opt2.step()

        sch1.step(); sch2.step()
        hist["ts_acc_ens"].append(evaluate_ensemble(m1, m2, ts_loader, DEVICE))

    return hist["ts_acc_ens"][-1], hist


In [9]:
def run_trials_for_dataset(cfg: Config, dataset_tag: str):
    ensure_dir(cfg.out_dir)
    out_dir = Path(cfg.out_dir) / f"{dataset_tag}_coteaching"
    ensure_dir(out_dir)

    accs = []
    for k in range(cfg.n_trials):
        seed = 1234 + k
        print(f"\n[Dataset={dataset_tag}] Trial {k+1}/{cfg.n_trials} (seed={seed})")
        acc, hist = train_coteaching_one_trial(cfg, seed)
        accs.append(acc)
        with open(out_dir / f"trial_{k+1:02d}.json", "w") as f:
            json.dump(hist, f)

    mean = float(np.mean(accs)); std = float(np.std(accs, ddof=1)) if len(accs)>1 else 0.0
    summary = {
        "dataset": dataset_tag, "algo": "Co-Teaching",
        "n_trials": cfg.n_trials, "epochs": cfg.epochs,
        "batch_size": cfg.batch_size, "noise_rate": cfg.noise_rate,
        "mean_test_acc": mean, "std_test_acc": std,
    }
    with open(out_dir / "summary.json", "w") as f: json.dump(summary, f, indent=2)

    csv_path = Path(cfg.out_dir) / "summary_all.csv"
    df_row = pd.DataFrame([summary])
    if csv_path.exists():
        df_old = pd.read_csv(csv_path)
        df = pd.concat([df_old, df_row], ignore_index=True)
    else:
        df = df_row
    df.to_csv(csv_path, index=False)

    print(f"[Done] {dataset_tag}: mean±std = {mean:.2%} ± {std:.2%}")
    return mean, std


In [10]:
PATH_03 = "FashionMNIST0.3.npz"
PATH_06 = "FashionMNIST0.6.npz"

cfg03 = Config(dataset_path=PATH_03, noise_rate=0.3, epochs=25, n_trials=10, num_workers=0, out_dir="./results_coteaching")
cfg06 = Config(dataset_path=PATH_06, noise_rate=0.6, epochs=35, n_trials=10, num_workers=0, out_dir="./results_coteaching")

print(cfg03)
print(cfg06)

preview_dataset(PATH_03)
preview_dataset(PATH_06)

run_trials_for_dataset(cfg03, dataset_tag="FashionMNIST0.3")
run_trials_for_dataset(cfg06, dataset_tag="FashionMNIST0.6")


Config(dataset_path='FashionMNIST0.3.npz', n_classes=3, batch_size=256, epochs=25, lr=0.001, weight_decay=0.0005, num_workers=0, noise_rate=0.3, num_gradual=10, exponent=1.0, n_trials=10, out_dir='./results_coteaching')
Config(dataset_path='FashionMNIST0.6.npz', n_classes=3, batch_size=256, epochs=35, lr=0.001, weight_decay=0.0005, num_workers=0, noise_rate=0.6, num_gradual=10, exponent=1.0, n_trials=10, out_dir='./results_coteaching')
Train X shape: (18000, 1, 28, 28) | noisy y: (18000,)
Test  X shape: (3000, 1, 28, 28) | clean y: (3000,)
Train X stats: min 0.0 max 1.0 mean 0.28434571623802185
Train X shape: (18000, 1, 28, 28) | noisy y: (18000,)
Test  X shape: (3000, 1, 28, 28) | clean y: (3000,)
Train X stats: min 0.0 max 1.0 mean 0.28434571623802185

[Dataset=FashionMNIST0.3] Trial 1/10 (seed=1234)

[Dataset=FashionMNIST0.3] Trial 2/10 (seed=1235)

[Dataset=FashionMNIST0.3] Trial 3/10 (seed=1236)

[Dataset=FashionMNIST0.3] Trial 4/10 (seed=1237)

[Dataset=FashionMNIST0.3] Trial 5/1

(0.7425333333333333, 0.20105525313906825)