In [None]:
import argparse
import os
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split


#  Utilities 

def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Args:
        seed: Value handed to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA generators).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def get_device():
    """Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.

    Returns:
        A ``torch.device`` for the first available backend in that order.
    """
    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        name = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)


#  Data loading 

def load_dataset(npz_path: str):
    """Read the train/test arrays from an ``.npz`` archive.

    Args:
        npz_path: Path to an archive containing the keys ``Xtr``, ``Str``,
            ``Xts`` and ``Yts``.

    Returns:
        The tuple ``(Xtr, Str, Xts, Yts)`` in that order.

    Raises:
        KeyError: If one of the four keys is missing from the archive.
    """
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the arrays have been materialised.
    with np.load(npz_path) as data:
        Xtr, Str = data["Xtr"], data["Str"]
        Xts, Yts = data["Xts"], data["Yts"]
    return Xtr, Str, Xts, Yts


class NpzImageDataset(Dataset):
    """Dataset over an in-memory image array, stored as float32 NCHW.

    The input is cast to float32, divided by 255 when its maximum exceeds 1.0,
    and rearranged to ``(N, C, H, W)``. Labels, when given, are cast to int64.
    Items are ``x`` alone when no labels were given, otherwise ``(x, int_label)``.
    """

    def __init__(self, X, y=None):
        images = X.astype(np.float32)
        # Values above 1.0 are taken to mean a 0..255 pixel scale.
        if images.max() > 1.0:
            images = images / 255.0

        self.X = self._as_nchw(images)
        self.y = None if y is None else y.astype(np.int64)

    @staticmethod
    def _as_nchw(arr):
        """Return ``arr`` laid out as (N, C, H, W), or raise ValueError."""
        rank = arr.ndim

        if rank == 3:
            # (N, H, W): insert a single grayscale channel axis.
            return arr[:, None, :, :]

        if rank == 4:
            if arr.shape[1] in (1, 3):
                # Treated as already channels-first.
                return arr
            if arr.shape[-1] in (1, 3):
                # Channels-last -> channels-first.
                return arr.transpose(0, 3, 1, 2)
            raise ValueError(f"Unexpected 4D image shape {arr.shape}")

        if rank != 2:
            raise ValueError(f"Unexpected image array ndim={arr.ndim}")

        # Flat vectors: try a square grayscale image first, then square RGB.
        count, flat = arr.shape
        root = int(np.sqrt(flat))
        if root * root == flat:
            return arr.reshape(count, 1, root, root)
        if flat % 3 == 0:
            per_channel = flat // 3
            root = int(np.sqrt(per_channel))
            if root * root == per_channel:
                # Flat RGB is read as NHWC and then moved to NCHW.
                return arr.reshape(count, root, root, 3).transpose(0, 3, 1, 2)
        raise ValueError(f"Cannot infer image shape from flat dimension {flat}")

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = torch.from_numpy(self.X[idx])
        if self.y is not None:
            return sample, int(self.y[idx])
        return sample


def infer_in_channels(X) -> int:
    """Guess the number of image channels from the shape of ``X``.

    Rules, by number of dimensions:
      * 3-D ``(N, H, W)``: 1 channel.
      * 2-D flat vectors: 3 when the length is not a perfect square but one
        third of it is; otherwise 1 (including lengths that fit neither).
      * 4-D: the size of axis 1 if it is 1 or 3, else the size of the last
        axis if that is 1 or 3.

    Raises:
        ValueError: For any other shape, including 4-D arrays where neither
            candidate axis has size 1 or 3.
    """
    rank = X.ndim

    if rank == 3:
        return 1

    if rank == 2:
        flat = X.shape[1]
        root = int(np.sqrt(flat))
        if root * root != flat and flat % 3 == 0:
            per_channel = flat // 3
            root = int(np.sqrt(per_channel))
            if root * root == per_channel:
                return 3
        return 1

    if rank == 4:
        # Channels-first is checked before channels-last.
        for axis in (1, -1):
            if X.shape[axis] in (1, 3):
                return X.shape[axis]

    raise ValueError(f"Unexpected image shape for infer_in_channels: {X.shape}")


#  Model 

class SmallCNN(nn.Module):
    """Four-conv classifier: two (conv-BN-ReLU) x2 + max-pool stages, then a linear head.

    Args:
        in_ch: Number of input image channels.
        num_classes: Size of the output logit vector.
        dropout: Dropout probability applied before the final linear layer.
    """

    def __init__(self, in_ch: int = 1, num_classes: int = 3, dropout: float = 0.3):
        super().__init__()
        width = 32

        # (in, out) channels of the four 3x3 convolutions, in order.
        conv_plan = [
            (in_ch, width),
            (width, width),
            (width, 2 * width),
            (2 * width, 2 * width),
        ]

        layers = []
        for position, (n_in, n_out) in enumerate(conv_plan):
            layers.append(nn.Conv2d(n_in, n_out, 3, padding=1))
            layers.append(nn.BatchNorm2d(n_out))
            layers.append(nn.ReLU(inplace=True))
            # Downsample by 2 after every second convolution.
            if position % 2 == 1:
                layers.append(nn.MaxPool2d(2))

        # Kept as one flat Sequential so sub-module indices stay 0..13.
        self.features = nn.Sequential(*layers)

        # LazyLinear infers its input width on the first forward pass.
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.LazyLinear(num_classes),
        )

    def forward(self, x):
        feats = self.features(x)
        return self.head(feats)


#  Co-Teaching + T / T_hat core 

def select_small_loss_indices(loss_vec: torch.Tensor, remember_rate: float) -> torch.Tensor:
    """Return the indices of the smallest-loss samples in a batch.

    Args:
        loss_vec: 1-D tensor of per-sample losses.
        remember_rate: Fraction of the batch to keep. The resulting count is
            clamped to ``[1, len(loss_vec)]``, so rates above 1 keep everything
            and rates at or below 0 keep a single sample.

    Returns:
        Index tensor ordered from smallest loss upwards. Empty (dtype long) when
        ``loss_vec`` is empty.
    """
    n = loss_vec.shape[0]
    if n == 0:
        # topk cannot be asked for k >= 1 elements of an empty tensor.
        return torch.empty(0, dtype=torch.long, device=loss_vec.device)
    # Clamp so an out-of-range rate cannot request more items than exist.
    k = min(n, max(1, int(remember_rate * n)))
    return torch.topk(-loss_vec, k=k).indices


def schedule_remember_rate(epoch: int, total_epochs: int, noise_rate: float, warmup_epochs: int = 5) -> float:
    """Fraction of each batch to keep at a given epoch.

    During the first ``warmup_epochs`` epochs everything is kept (1.0). After
    that the dropped fraction grows linearly until it reaches ``noise_rate``
    at the final epoch, and the result is floored at 0.0.

    Args:
        epoch: Zero-based epoch index.
        total_epochs: Planned number of epochs.
        noise_rate: Target fraction to discard; ``None`` is treated as 0.5.
        warmup_epochs: Number of initial epochs with no discarding.
    """
    if epoch < warmup_epochs:
        return 1.0

    target = 0.5 if noise_rate is None else noise_rate

    # Length of the linear ramp; at least 1 to avoid dividing by zero.
    ramp_len = max(1, total_epochs - warmup_epochs)
    progress = (epoch - warmup_epochs + 1) / ramp_len
    if progress > 1.0:
        progress = 1.0

    keep = 1.0 - target * progress
    return float(keep) if keep > 0.0 else 0.0


@torch.no_grad()
def disagreement_rate(model1, model2, loader, device):
    """Fraction of samples on which the two models predict different classes.

    Both models are put in eval mode for the duration of the pass, so dropout is
    disabled and BatchNorm uses, and does not update, its running statistics.
    Each model's previous train/eval mode is restored before returning.

    Args:
        model1: First classifier (an ``nn.Module`` returning logits).
        model2: Second classifier.
        loader: Iterable yielding ``(inputs, labels)`` batches; labels are ignored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``disagreements / samples`` as a float; 0.0 when the loader is empty.
    """
    previous_modes = (model1.training, model2.training)
    model1.eval()
    model2.eval()

    tot, disagree = 0, 0
    try:
        for xb, _ in loader:
            xb = xb.to(device)
            p1 = model1(xb).argmax(1).cpu()
            p2 = model2(xb).argmax(1).cpu()
            disagree += (p1 != p2).sum().item()
            tot += xb.size(0)
    finally:
        # Leave the models in whatever mode the caller had them in.
        model1.train(previous_modes[0])
        model2.train(previous_modes[1])
    return disagree / max(1, tot)


def noisy_ce_with_T_per_sample(logits: torch.Tensor, noisy_labels: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """Per-sample negative log-likelihood of the noisy labels under matrix ``T``.

    The softmax of ``logits`` is right-multiplied by ``T``, clamped to
    ``[1e-8, 1]``, and the entry at each sample's noisy label is turned into
    ``-log(p)``.

    Args:
        logits: ``(batch, classes)`` model outputs.
        noisy_labels: ``(batch,)`` integer labels used as column indices.
        T: Matrix applied as ``softmax(logits) @ T``.

    Returns:
        1-D tensor with one loss value per sample (no reduction).
    """
    clean_posterior = torch.softmax(logits, dim=1)
    # Clamp keeps log() finite when T assigns zero mass to the observed label.
    noisy_posterior = torch.matmul(clean_posterior, T).clamp(min=1e-8, max=1.0)

    rows = torch.arange(noisy_labels.size(0), device=noisy_labels.device)
    picked = noisy_posterior[rows, noisy_labels]
    return picked.log().neg()


def run_epoch_coteaching_T(
    model1,
    model2,
    opt1,
    opt2,
    train_loader,
    device,
    remember_rate: float,
    T: torch.Tensor,
):
    """Run one co-teaching epoch where both peers use the T-corrected loss.

    For every batch each model ranks the samples by its own per-sample loss,
    keeps the ``remember_rate`` fraction with the smallest loss, and the *other*
    model is updated on that subset.

    Args:
        model1, model2: The two peer networks (set to train mode here).
        opt1, opt2: Optimizers for ``model1`` and ``model2`` respectively.
        train_loader: Iterable of ``(inputs, noisy_labels)`` batches.
        device: Device batches are moved to.
        remember_rate: Fraction of each batch kept for the update.
        T: Matrix passed to ``noisy_ce_with_T_per_sample``.
    """
    for net in (model1, model2):
        net.train()

    def _apply_update(optimizer, objective):
        # Standard zero-grad / backward / step sequence for one peer.
        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        out_a = model1(batch_x)
        out_b = model2(batch_x)

        # Ranking losses are detached: they only drive sample selection.
        rank_a = noisy_ce_with_T_per_sample(out_a, batch_y, T).detach()
        rank_b = noisy_ce_with_T_per_sample(out_b, batch_y, T).detach()

        keep_by_a = select_small_loss_indices(rank_a, remember_rate)
        keep_by_b = select_small_loss_indices(rank_b, remember_rate)

        # Cross update: each model learns from the samples its peer trusts.
        objective_a = noisy_ce_with_T_per_sample(out_a[keep_by_b], batch_y[keep_by_b], T).mean()
        objective_b = noisy_ce_with_T_per_sample(out_b[keep_by_a], batch_y[keep_by_a], T).mean()

        _apply_update(opt1, objective_a)
        _apply_update(opt2, objective_b)


def evaluate(model, loader, device) -> float:
    """Top-1 accuracy of ``model`` over ``loader``, as a percentage.

    The model is switched to eval mode and left there.

    Args:
        model: Classifier returning ``(batch, classes)`` logits.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device inputs and labels are moved to.

    Returns:
        ``100 * correct / total``; 0.0 when the loader yields nothing.
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(1)
            n_hit += (predicted == targets).sum().item()
            n_seen += targets.size(0)
    # max(1, ...) guards the empty-loader case against division by zero.
    return 100.0 * n_hit / max(1, n_seen)


#  Estimate T_hat (CIFAR) 

class AnchorNet(nn.Module):
    """Thin wrapper that delegates to a ``SmallCNN`` stored as ``self.backbone``."""

    def __init__(self, in_ch: int, num_classes: int):
        super().__init__()
        # SmallCNN's dropout is left at its default value here.
        self.backbone = SmallCNN(in_ch=in_ch, num_classes=num_classes)

    def forward(self, x):
        # Returns whatever the backbone returns (its class logits).
        return self.backbone(x)


def estimate_transition_matrix_anchor(
    Xtr: np.ndarray,
    Str: np.ndarray,
    in_ch: int,
    num_classes: int = 3,
    epochs: int = 10,
    batch_size: int = 256,
    top_percent: float = 0.05,
    device=None,
):
    """Estimate a transition matrix from noisy data via high-confidence samples.

    An ``AnchorNet`` is trained with plain cross-entropy on the noisy labels.
    Then, for each class ``i``, the samples labelled ``i`` are ranked by their
    predicted probability of class ``i`` and the softmax vectors of the top
    ``top_percent`` fraction (at least one sample) are averaged into row ``i``.

    Args:
        Xtr: Training images, in any layout ``NpzImageDataset`` accepts.
        Str: Noisy integer labels aligned with ``Xtr``.
        in_ch: Number of input channels for the network.
        num_classes: Number of classes (matrix is ``num_classes`` square).
        epochs: Training epochs for the auxiliary classifier.
        batch_size: Batch size for both training and prediction.
        top_percent: Fraction of each class's samples averaged into its row.
        device: Device to run on; chosen by ``get_device()`` when ``None``.

    Returns:
        ``(num_classes, num_classes)`` float32 array. A class with no samples
        leaves its row at zero.
    """
    device = get_device() if device is None else device

    net = AnchorNet(in_ch=in_ch, num_classes=num_classes).to(device)
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)

    loader = DataLoader(
        NpzImageDataset(Xtr, Str),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    # Stage 1: fit the auxiliary classifier directly on the noisy labels.
    for _ in range(epochs):
        net.train()
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_loss = F.cross_entropy(net(inputs), targets)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()

    # Stage 2: gather logits together with the labels of the same batches, so
    # the shuffled order does not matter.
    net.eval()
    logit_chunks, label_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            logit_chunks.append(net(inputs.to(device)).cpu())
            label_chunks.append(targets.clone())

    probs = F.softmax(torch.cat(logit_chunks, dim=0), dim=1).numpy()
    labels = torch.cat(label_chunks, dim=0).numpy()

    # Stage 3: average the most confident predictions within each labelled class.
    T_hat = np.zeros((num_classes, num_classes), dtype=np.float32)
    for cls in range(num_classes):
        members = probs[labels == cls]
        n_members = members.shape[0]
        if n_members == 0:
            continue
        n_anchor = max(1, int(n_members * top_percent))
        ranked = np.argsort(-members[:, cls])
        T_hat[cls] = members[ranked[:n_anchor]].mean(axis=0)

    return T_hat


#  Known T for FashionMNIST0.3 / 0.6 

def get_known_transition_matrix_from_name(base_name: str, num_classes: int) -> np.ndarray:
    """Look up the hard-coded 3x3 transition matrix for a dataset name.

    The name is matched case-insensitively by substring: ``"0.3"`` is tried
    first, then ``"0.6"``.

    Args:
        base_name: Dataset file or base name.
        num_classes: Expected number of classes; must be 3.

    Returns:
        A ``(3, 3)`` float32 array.

    Raises:
        ValueError: If the name contains neither tag, or (checked second) if
            ``num_classes`` is not 3.
    """
    known_matrices = (
        (
            "0.3",
            [
                [0.7, 0.3, 0.0],
                [0.0, 0.7, 0.3],
                [0.3, 0.0, 0.7],
            ],
        ),
        (
            "0.6",
            [
                [0.4, 0.3, 0.3],
                [0.3, 0.4, 0.3],
                [0.3, 0.3, 0.4],
            ],
        ),
    )

    lowered = base_name.lower()
    rows = next((matrix for tag, matrix in known_matrices if tag in lowered), None)
    if rows is None:
        raise ValueError(f"Unknown dataset name for known T: {base_name}")

    if num_classes != 3:
        raise ValueError(
            f"Current known-T implementation assumes 3 classes, got num_classes={num_classes}."
        )
    return np.array(rows, dtype=np.float32)


#  Full training loop 
def run_coteaching_T_on_dataset(
    npz_path: str,
    epochs: int,
    runs: int,
    batch_size: int,
    lr: float,
    warmup_epochs: int = 5,
    seed0: int = 2025,
):
    """Train and evaluate co-teaching with a transition matrix on one dataset.

    For each of ``runs`` repetitions (seeded ``seed0 + r``) the noisy training
    set is split 80/20 into train/validation, a transition matrix is chosen
    (hard-coded when the file name contains "0.3" or "0.6", otherwise estimated
    from the training split), two ``SmallCNN`` peers are trained, and the test
    accuracy at the best-validation epoch is recorded. Progress and the final
    mean and standard deviation are printed; nothing is returned.

    Args:
        npz_path: Path to the ``.npz`` archive read by ``load_dataset``.
        epochs: Training epochs per run.
        runs: Number of independent repetitions.
        batch_size: Batch size for the train/validation/test loaders.
        lr: Adam learning rate for both peers.
        warmup_epochs: Epochs during which no samples are discarded.
        seed0: Seed of the first run.
    """
    Xtr, Str, Xts, Yts = load_dataset(npz_path)
    in_ch = infer_in_channels(Xtr)
    num_classes = len(np.unique(Str))
    device = get_device()

    base = os.path.basename(npz_path).lower()
    ds_name = Path(npz_path).stem

    # known average noise for FashionMNIST
    known_noise = None
    if "0.3" in base:
        known_noise = 0.3
    elif "0.6" in base:
        known_noise = 0.6

    results = []

    for r in range(runs):
        seed = seed0 + r
        set_seed(seed)
        print(f"\n==== Dataset [{ds_name}] | Run {r+1}/{runs} (Co-Teaching + T/T_hat) ====", flush=True)

        # Stratified split on the noisy labels; validation labels are noisy too.
        idx_all = np.arange(len(Str))
        tr_idx, va_idx = train_test_split(
            idx_all, test_size=0.2, stratify=Str, random_state=seed
        )
        X_tr, y_tr = Xtr[tr_idx], Str[tr_idx]
        X_va, y_va = Xtr[va_idx], Str[va_idx]

        # choose T
        if known_noise is not None:
            T_np = get_known_transition_matrix_from_name(base, num_classes)
            avg_noise_rate = known_noise
            print(f"[{ds_name}] using TRUE T, noise rate={avg_noise_rate:.3f}", flush=True)
        else:
            print(f"[{ds_name}] estimating T_hat by anchor method ...", flush=True)
            T_hat = estimate_transition_matrix_anchor(
                X_tr,
                y_tr,
                in_ch=in_ch,
                num_classes=num_classes,
                epochs=10,
                batch_size=256,
                top_percent=0.05,
                device=device,
            )
            # Average off-diagonal mass: one minus the mean of the diagonal.
            avg_noise_rate = float(1.0 - np.trace(T_hat) / num_classes)
            print(f"[{ds_name}] estimated avg noise rate ≈ {avg_noise_rate:.3f}", flush=True)
            T_np = T_hat.astype(np.float32)

        T_train = torch.from_numpy(T_np).to(device)

        ds_tr = NpzImageDataset(X_tr, y_tr)
        ds_va = NpzImageDataset(X_va, y_va)
        ds_ts = NpzImageDataset(Xts, Yts)

        train_loader = DataLoader(
            ds_tr,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
        )
        valid_loader = DataLoader(
            ds_va,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
        )
        test_loader = DataLoader(
            ds_ts,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
        )

        model1 = SmallCNN(in_ch=in_ch, num_classes=num_classes).to(device)
        model2 = SmallCNN(in_ch=in_ch, num_classes=num_classes).to(device)

        opt1 = torch.optim.Adam(model1.parameters(), lr=lr)
        opt2 = torch.optim.Adam(model2.parameters(), lr=lr)

        best_val = 0.0
        best_test = 0.0

        for ep in range(epochs):
            rem = schedule_remember_rate(ep, epochs, avg_noise_rate, warmup_epochs)

            run_epoch_coteaching_T(
                model1, model2, opt1, opt2, train_loader, device, rem, T_train
            )

            disc = disagreement_rate(model1, model2, valid_loader, device)
            acc1 = evaluate(model1, valid_loader, device)
            acc2 = evaluate(model2, valid_loader, device)
            val_acc = max(acc1, acc2)

            print(
                f"[{ds_name}] run {r+1}/{runs} | "
                f"ep {ep+1:03d}/{epochs} "
                f"rem={rem:.3f} "
                f"disc={disc:.3f} "
                f"val1={acc1:.1f} val2={acc2:.1f}",
                flush=True,
            )

            # The epoch is selected on validation accuracy.
            # NOTE(review): the reported number is the better of the two peers
            # *on the test set*, not the peer that won on validation - confirm
            # this is the intended protocol.
            if val_acc > best_val:
                best_val = val_acc
                test_acc1 = evaluate(model1, test_loader, device)
                test_acc2 = evaluate(model2, test_loader, device)
                best_test = max(test_acc1, test_acc2)

        print(
            f"[{ds_name}] run {r+1}/{runs} finished | "
            f"best_test={best_test:.2f}%",
            flush=True,
        )

        results.append(best_test)

    arr = np.array(results, dtype=np.float32)
    mean = float(arr.mean())
    # Sample standard deviation needs at least two runs.
    std = float(arr.std(ddof=1) if len(arr) > 1 else 0.0)
    print(f"\n[{ds_name}] TEST Top-1 over {runs} runs: {mean:.2f} ± {std:.2f} %", flush=True)


def main(argv=None):
    """Parse command-line options and run the experiment on each dataset.

    Args:
        argv: Argument list to parse. When ``None`` the process arguments are
            parsed with unknown options ignored (which tolerates the extra
            arguments a notebook kernel passes); otherwise ``argv`` is parsed
            strictly.

    Datasets whose file cannot be found are skipped with a warning. A name
    given without the ``.npz`` extension is retried with ``.npz`` appended.
    """
    p = argparse.ArgumentParser(description="Co-Teaching + T/T_hat (minimal core version)")
    p.add_argument(
        "--datasets",
        type=str,
        nargs="+",
        default=["FashionMNIST0.3.npz", "FashionMNIST0.6.npz", "CIFAR.npz"],
    )
    p.add_argument("--epochs", type=int, default=80)
    p.add_argument("--runs", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--warmup_epochs", type=int, default=5)

    if argv is None:
        args, _ = p.parse_known_args()
    else:
        args = p.parse_args(argv)

    print("Device:", get_device(), flush=True)

    for name in args.datasets:
        fpath = Path(name)
        # Names such as "FashionMNIST0.3" have a non-empty Path.suffix (".3"),
        # so test for the real extension and append ".npz" to the whole name
        # rather than replacing the apparent suffix.
        if not fpath.exists() and fpath.suffix.lower() != ".npz":
            alt = fpath.with_name(fpath.name + ".npz")
            if alt.exists():
                fpath = alt
        if not fpath.exists():
            print(f"[WARN] Missing {fpath}. Skipping.", flush=True)
            continue

        run_coteaching_T_on_dataset(
            str(fpath),
            epochs=args.epochs,
            runs=args.runs,
            batch_size=args.batch_size,
            lr=args.lr,
            warmup_epochs=args.warmup_epochs,
        )


# Entry point: call main() with no argument list, so it parses the process
# arguments and ignores any options it does not recognise.
if __name__ == "__main__":
    main()

Device: mps

==== Dataset [FashionMNIST0.3] | Run 1/10 (Co-Teaching + T/T_hat) ====
[FashionMNIST0.3] using TRUE T, noise rate=0.300
[FashionMNIST0.3] run 1/10 | ep 001/80 rem=1.000 disc=0.008 val1=69.1 val2=69.1
[FashionMNIST0.3] run 1/10 | ep 002/80 rem=1.000 disc=0.010 val1=69.3 val2=69.3
[FashionMNIST0.3] run 1/10 | ep 003/80 rem=1.000 disc=0.009 val1=69.3 val2=69.2
[FashionMNIST0.3] run 1/10 | ep 004/80 rem=1.000 disc=0.012 val1=69.1 val2=69.6
[FashionMNIST0.3] run 1/10 | ep 005/80 rem=1.000 disc=0.009 val1=69.6 val2=69.4
[FashionMNIST0.3] run 1/10 | ep 006/80 rem=0.996 disc=0.010 val1=69.4 val2=69.5
[FashionMNIST0.3] run 1/10 | ep 007/80 rem=0.992 disc=0.009 val1=69.5 val2=69.3
[FashionMNIST0.3] run 1/10 | ep 008/80 rem=0.988 disc=0.006 val1=69.6 val2=69.6
[FashionMNIST0.3] run 1/10 | ep 009/80 rem=0.984 disc=0.008 val1=69.6 val2=69.5
[FashionMNIST0.3] run 1/10 | ep 010/80 rem=0.980 disc=0.006 val1=69.4 val2=69.6
[FashionMNIST0.3] run 1/10 | ep 011/80 rem=0.976 disc=0.012 val1=69