# Co-Teaching++ (GPU, Forget-Rate Schedule + Dropout/BN)

# 1.导入库并使用GPU

In [3]:
# ==== Deterministic setup (reproducibility) ====
import os, random, numpy as np, torch, math
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PIN_MEMORY = bool(torch.cuda.is_available())
NUM_WORKERS = 0  # keep 0 to avoid multi-process RNG issues unless necessary

# Make cuDNN deterministic globally
try:
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
except Exception as _e:
    pass

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    try:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    except Exception as _e:
        pass
print("Device:", DEVICE)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

Device: cuda
GPU: NVIDIA GeForce RTX 2070


# 2.数据加载与预处理

In [5]:
def set_seed(seed=42):
    import os, random, numpy as np, torch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def _infer_hw_c_from_flat_dim(D:int):
    s = int(np.sqrt(D) + 1e-8)
    if s*s == D: return s, s, 1
    if D % 3 == 0:
        s3 = int(np.sqrt(D//3) + 1e-8)
        if s3*s3*3 == D: return s3, s3, 3
    return None, None, None

def load_npz(path:str):
    d = np.load(path)
    Xtr, Str = d["Xtr"], d["Str"]
    Xts, Yts = d["Xts"], d["Yts"]
    if Xtr.ndim == 4 and Xtr.shape[-1]==3: H,W,C = Xtr.shape[1], Xtr.shape[2], 3
    elif Xtr.ndim == 3: H,W,C = Xtr.shape[1], Xtr.shape[2], 1
    elif Xtr.ndim == 2:
        H,W,C = _infer_hw_c_from_flat_dim(Xtr.shape[1])
        if H is None: raise ValueError("Cannot infer HWC")
    else: raise ValueError(f"Bad shape: {Xtr.shape}")
    num_classes = int(max(Str.max(), Yts.max())+1)
    return (Xtr, Str, Xts, Yts, (H,W,C), num_classes)

class NPZImageDataset(Dataset):
    def __init__(self, X, y, shape_hw_c):
        X = X.astype(np.float32)
        H,W,C = shape_hw_c
        if X.ndim == 2:
            X = X.reshape(-1,H,W) if C==1 else X.reshape(-1,H,W,C)
        if X.max()>1.5: X = X/255.0
        if X.ndim == 3: X = X[:,None,:,:]
        elif X.ndim == 4: X = np.transpose(X,(0,3,1,2))
        X = (X - X.mean())/(X.std()+1e-6)
        self.X, self.y = X, y.astype(np.int64)
    def __len__(self): return len(self.X)
    def __getitem__(self,i): return torch.from_numpy(self.X[i]), int(self.y[i])


# 3.定义卷积神经网络 CNN 模型

In [7]:
class CNN_BN_Drop(nn.Module):
    def __init__(self, in_ch=1, num_classes=10, p_drop=0.4):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(p_drop),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(p_drop),
        )
        self.adapt = nn.AdaptiveAvgPool2d((7,7))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128*7*7, 256), nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(256, num_classes)
        )
    def forward(self,x):
        x = self.features(x)
        x = self.adapt(x)
        return self.classifier(x)

def make_model(in_ch, C, p_drop=0.4):
    return CNN_BN_Drop(in_ch=in_ch, num_classes=C, p_drop=p_drop)


# 4.算法训练与评估

In [9]:
def topk_indices_by_small_loss(logits, y, keep_ratio):
    losses = F.cross_entropy(logits, y, reduction='none')
    k = max(1, int(keep_ratio * len(losses)))
    return torch.topk(-losses, k=k).indices

def coteach_epoch(model1, model2, loader, opt1, opt2, keep_ratio):
    model1.train(); model2.train(); total=0.0
    for x,y in loader:
        x,y = x.to(DEVICE), y.to(DEVICE)
        l1, l2 = model1(x), model2(x)
        idx1 = topk_indices_by_small_loss(l1, y, keep_ratio)
        idx2 = topk_indices_by_small_loss(l2, y, keep_ratio)
        loss1 = F.cross_entropy(l1[idx2], y[idx2])
        loss2 = F.cross_entropy(l2[idx1], y[idx1])
        opt1.zero_grad(); loss1.backward(); opt1.step()
        opt2.zero_grad(); loss2.backward(); opt2.step()
        total += (loss1.item()+loss2.item())/2
    return total/len(loader)

@torch.no_grad()
def evaluate(model, loader):
    model.eval(); corr=0; tot=0
    for x,y in loader:
        x,y = x.to(DEVICE), y.to(DEVICE)
        pred = model(x).argmax(1)
        corr += (pred==y).sum().item(); tot += y.size(0)
    return corr/tot

def linear_schedule(epoch, max_epoch, final_forget):
    return min(final_forget, final_forget * epoch / max(1, max_epoch))

def cosine_schedule(epoch, max_epoch, final_forget):
    t = epoch / max(1, max_epoch)
    return final_forget * (0.5 - 0.5*math.cos(math.pi * t))


# 5.训练配置：Co-teaching的单独训练

In [11]:
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

def train_config_coteach_scheduled(
    dataset_path, epochs, wd, lr, seed=2025,
    final_forget=0.6, sched='linear', p_drop=0.4,
    warmup_epochs=8,
    tr_idx=None, va_idx=None,
    generator=None
):
    # 固定所有随机性
    set_seed(seed)

    # 加载数据
    Xtr, Str, Xts, Yts, hwc, C = load_npz(dataset_path)

    # 若外部未传入折索引，则使用固定8:2划分
    if tr_idx is None or va_idx is None:
        tr_idx, va_idx = train_test_split(
            np.arange(len(Str)), test_size=0.2, stratify=Str, random_state=seed
        )

    # 构建数据集
    tr = NPZImageDataset(Xtr[tr_idx], Str[tr_idx], shape_hw_c=hwc)
    va = NPZImageDataset(Xtr[va_idx], Str[va_idx], shape_hw_c=hwc)
    ts = NPZImageDataset(Xts, Yts, shape_hw_c=hwc)

    # DataLoader（训练集传入 generator 锁定 shuffle）
    tr_loader = DataLoader(tr, batch_size=256, shuffle=True,
                           num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                           generator=generator)
    va_loader = DataLoader(va, batch_size=256, shuffle=False,
                           num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    ts_loader = DataLoader(ts, batch_size=256, shuffle=False,
                           num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    # 初始化模型和优化器
    in_ch = hwc[2]
    m1 = make_model(in_ch, C, p_drop).to(DEVICE)
    m2 = make_model(in_ch, C, p_drop).to(DEVICE)
    opt1 = torch.optim.AdamW(m1.parameters(), lr=lr, weight_decay=wd)
    opt2 = torch.optim.AdamW(m2.parameters(), lr=lr, weight_decay=wd)

    # ======= 训练 =======
    for ep in range(1, epochs + 1):
        if ep <= warmup_epochs:
            m1.train(); m2.train()
            for x, y in tr_loader:
                x, y = x.to(DEVICE), y.to(DEVICE)
                opt1.zero_grad(set_to_none=True)
                F.cross_entropy(m1(x), y).backward()
                opt1.step()
                opt2.zero_grad(set_to_none=True)
                F.cross_entropy(m2(x), y).backward()
                opt2.step()
            continue

        # Warmup 结束后进入 Co-Teaching
        t_ep = ep - warmup_epochs
        T_total = max(1, epochs - warmup_epochs)
        if sched == 'cosine':
            forget = cosine_schedule(t_ep, T_total, final_forget)
        else:
            forget = linear_schedule(t_ep, T_total, final_forget)

        keep_ratio = float(min(1.0, max(0.0, 1.0 - forget)))
        coteach_epoch(m1, m2, tr_loader, opt1, opt2, keep_ratio)

    # ======= 评估 =======
    va_acc = max(evaluate(m1, va_loader), evaluate(m2, va_loader))
    ts_acc = max(evaluate(m1, ts_loader), evaluate(m2, ts_loader))
    val_loss = 1.0 - va_acc
    return val_loss, ts_acc


# Co-teaching的十折训练

In [12]:
from sklearn.model_selection import KFold

def tune_and_report_coteach_plus_cv(
    dataset_path,
    lr=1e-3,
    it_values=(30,),
    wd_values=(5e-2,),
    seed=42,
    final_forget=0.6,
    sched='linear',
    p_drop=0.5,
    num_folds=10
):
    print(f"==== 10-Fold CV on Dataset: {dataset_path} ====")

    # 固定主随机种子
    set_seed(seed)

    # 一次性加载数据
    Xtr, Str, Xts, Yts, hwc, C = load_npz(dataset_path)
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=seed)

    fold_accs, fold_losses = [], []

    for fold, (train_idx, val_idx) in enumerate(kf.split(Xtr)):
        print(f"\n===== Fold {fold + 1}/{num_folds} =====")
        # 每折单独 generator，确保 shuffle 顺序一致
        g = torch.Generator(device='cpu').manual_seed(seed + fold)

        best = None
        for it in it_values:
            for wd in wd_values:
                vloss, tacc = train_config_coteach_scheduled(
                    dataset_path=dataset_path,
                    epochs=it, wd=wd, lr=lr,
                    seed=seed + fold,              # 每折固定种子
                    final_forget=final_forget,
                    sched=sched,
                    p_drop=p_drop,
                    tr_idx=train_idx, va_idx=val_idx,
                    generator=g
                )
                print(f"Fold {fold+1} | wd={wd}, it={it} | val_loss={vloss:.4f}, test_acc={tacc*100:.2f}%")

                if (best is None) or (vloss < best[0] - 1e-12) or (abs(vloss - best[0]) < 1e-12 and tacc > best[1]):
                    best = (vloss, tacc, wd, it)

        fold_losses.append(best[0])
        fold_accs.append(best[1])

    mean_acc = float(np.mean(fold_accs))
    std_acc  = float(np.std(fold_accs))
    mean_loss = float(np.mean(fold_losses))

    print("\n==== 10-Fold Cross Validation Result ====")
    print(f"Mean Val Loss: {mean_loss:.4f}")
    print(f"Mean Test Accuracy: {mean_acc*100:.2f}% ± {std_acc*100:.2f}%")
    return mean_loss, mean_acc


# 6.网格调参 + 报告最佳结果

In [14]:
def tune_and_report_coteach_plus(
    dataset_path,
    lr=1e-3,
    it_values=(500,),
    wd_values=(1e-4, 1e-3, 5e-2, 1e-2, 5e-1, 1e-1),
    seed=42,
    final_forget=0.6,
    sched='linear',
    p_drop=0.4,
    return_best=False
):
    print(f"==== Dataset: {dataset_path} ====")
    print("Tuned configs (both orientations):")
    best = None
    set_seed(seed)

    for it in it_values:
        for wd in wd_values:
            g = torch.Generator(device='cpu').manual_seed(seed)
            vloss, tacc = train_config_coteach_scheduled(
                dataset_path=dataset_path,
                epochs=it,
                wd=wd,
                lr=lr,
                seed=seed,
                final_forget=final_forget,
                sched=sched,
                p_drop=p_drop,
                generator=g
            )
            print(f"wd={wd}, it={it} | val_loss={vloss:.4f}, test_acc={tacc*100:.2f}%")
            if (best is None) or (tacc > best[1] + 1e-12) or (abs(tacc - best[1]) < 1e-12 and vloss < best[0]):
                best = (vloss, tacc, wd, it)
    print(f"** Best: wd={best[2]}, it={best[3]} | test_acc={best[1]*100:.2f}%, val_loss={best[0]:.4f}")

In [15]:
# 指定你的数据集路径
DATASET_1 = "datasets/FashionMNIST0.3.npz"
DATASET_2 = "datasets/FashionMNIST0.6.npz"
DATASET_3 = "datasets/CIFAR.npz"

# 调参过程

In [17]:
tune_and_report_coteach_plus(
    DATASET_1, 
    lr=1e-3, 
    it_values=(10,),
    wd_values=(5e-2,1e-2,5e-1,1e-1), 
    seed=42,
    final_forget=0.3, 
    sched='linear', 
    p_drop=0.5
)

==== Dataset: datasets/FashionMNIST0.3.npz ====
Tuned configs (both orientations):
wd=0.05, it=10 | val_loss=0.3164, test_acc=98.60%
wd=0.01, it=10 | val_loss=0.3158, test_acc=98.57%
wd=0.5, it=10 | val_loss=0.3153, test_acc=98.30%
wd=0.1, it=10 | val_loss=0.3150, test_acc=98.63%
** Best: wd=0.1, it=10 | test_acc=98.63%, val_loss=0.3150


In [18]:
tune_and_report_coteach_plus(
    DATASET_2, 
    lr=1e-4, 
    it_values=(30,),
    wd_values=(1e-4,5e-1,1e-1,), 
    seed=42,
    final_forget=0.75, 
    sched='linear', 
    p_drop=0.5
)

==== Dataset: datasets/FashionMNIST0.6.npz ====
Tuned configs (both orientations):
wd=0.0001, it=30 | val_loss=0.6228, test_acc=97.13%
wd=0.5, it=30 | val_loss=0.6217, test_acc=97.13%
wd=0.1, it=30 | val_loss=0.6206, test_acc=96.53%
** Best: wd=0.5, it=30 | test_acc=97.13%, val_loss=0.6217


In [33]:
tune_and_report_coteach_plus(
    DATASET_3,
    lr=1e-4, 
    it_values=(10,), 
    wd_values=(5e-2,1e-4,5e-1,1e-1),
    seed=42, 
    final_forget=0.74,
    sched='linear', 
    p_drop=0.5
)

==== Dataset: datasets/CIFAR.npz ====
Tuned configs (both orientations):
wd=0.05, it=10 | val_loss=0.6520, test_acc=41.73%
wd=0.0001, it=10 | val_loss=0.6530, test_acc=37.07%
wd=0.5, it=10 | val_loss=0.6390, test_acc=51.37%
wd=0.1, it=10 | val_loss=0.6667, test_acc=33.33%
** Best: wd=0.5, it=10 | test_acc=51.37%, val_loss=0.6390


In [35]:
tune_and_report_coteach_plus(
    DATASET_3,
    lr=1e-4, 
    it_values=(10,20,30), 
    wd_values=(5e-1,),
    seed=42, 
    final_forget=0.74,
    sched='linear', 
    p_drop=0.5
)

==== Dataset: datasets/CIFAR.npz ====
Tuned configs (both orientations):
wd=0.5, it=10 | val_loss=0.6453, test_acc=34.50%
wd=0.5, it=20 | val_loss=0.6250, test_acc=53.33%
wd=0.5, it=30 | val_loss=0.6130, test_acc=72.47%
** Best: wd=0.5, it=30 | test_acc=72.47%, val_loss=0.6130


# 最佳参数

In [37]:
tune_and_report_coteach_plus_cv(
    dataset_path=DATASET_1,
    lr=1e-3,
    it_values=(10,),
    wd_values=(1e-1,),
    seed=42,
    final_forget=0.3,
    sched='linear',
    p_drop=0.5,
    num_folds=10
)

==== 10-Fold CV on Dataset: datasets/FashionMNIST0.3.npz ====

===== Fold 1/10 =====
Fold 1 | wd=0.1, it=10 | val_loss=0.3039, test_acc=98.63%
** Best in Fold 1: wd=0.1, it=10 | val_loss=0.3039, test_acc=98.63%

===== Fold 2/10 =====
Fold 2 | wd=0.1, it=10 | val_loss=0.2972, test_acc=98.60%
** Best in Fold 2: wd=0.1, it=10 | val_loss=0.2972, test_acc=98.60%

===== Fold 3/10 =====
Fold 3 | wd=0.1, it=10 | val_loss=0.3178, test_acc=98.83%
** Best in Fold 3: wd=0.1, it=10 | val_loss=0.3178, test_acc=98.83%

===== Fold 4/10 =====
Fold 4 | wd=0.1, it=10 | val_loss=0.3011, test_acc=98.47%
** Best in Fold 4: wd=0.1, it=10 | val_loss=0.3011, test_acc=98.47%

===== Fold 5/10 =====
Fold 5 | wd=0.1, it=10 | val_loss=0.3022, test_acc=98.57%
** Best in Fold 5: wd=0.1, it=10 | val_loss=0.3022, test_acc=98.57%

===== Fold 6/10 =====
Fold 6 | wd=0.1, it=10 | val_loss=0.3072, test_acc=98.73%
** Best in Fold 6: wd=0.1, it=10 | val_loss=0.3072, test_acc=98.73%

===== Fold 7/10 =====
Fold 7 | wd=0.1, it=1

(0.3087777777777778, 0.9868333333333335)

In [41]:
tune_and_report_coteach_plus_cv(
    dataset_path=DATASET_2,
    lr=1e-4,
    it_values=(30,),
    wd_values=(5e-1,),
    seed=42,
    final_forget=0.75,
    sched='linear',
    p_drop=0.5,
    num_folds=10
)

==== 10-Fold CV on Dataset: datasets/FashionMNIST0.6.npz ====

===== Fold 1/10 =====
Fold 1 | wd=0.5, it=30 | val_loss=0.6133, test_acc=96.33%
** Best in Fold 1: wd=0.5, it=30 | val_loss=0.6133, test_acc=96.33%

===== Fold 2/10 =====
Fold 2 | wd=0.5, it=30 | val_loss=0.6061, test_acc=95.63%
** Best in Fold 2: wd=0.5, it=30 | val_loss=0.6061, test_acc=95.63%

===== Fold 3/10 =====
Fold 3 | wd=0.5, it=30 | val_loss=0.6200, test_acc=97.53%
** Best in Fold 3: wd=0.5, it=30 | val_loss=0.6200, test_acc=97.53%

===== Fold 4/10 =====
Fold 4 | wd=0.5, it=30 | val_loss=0.6233, test_acc=96.57%
** Best in Fold 4: wd=0.5, it=30 | val_loss=0.6233, test_acc=96.57%

===== Fold 5/10 =====
Fold 5 | wd=0.5, it=30 | val_loss=0.6139, test_acc=95.90%
** Best in Fold 5: wd=0.5, it=30 | val_loss=0.6139, test_acc=95.90%

===== Fold 6/10 =====
Fold 6 | wd=0.5, it=30 | val_loss=0.5961, test_acc=95.47%
** Best in Fold 6: wd=0.5, it=30 | val_loss=0.5961, test_acc=95.47%

===== Fold 7/10 =====
Fold 7 | wd=0.5, it=3

(0.6141666666666665, 0.9305)

In [44]:
tune_and_report_coteach_plus_cv(
    dataset_path=DATASET_3,
    lr=1e-4,
    it_values=(30,), 
    wd_values=(5e-1,),
    seed=42,
    final_forget=0.74,
    sched='linear',
    p_drop=0.5,
    num_folds=10
)

==== 10-Fold CV on Dataset: datasets/CIFAR.npz ====

===== Fold 1/10 =====
Fold 1 | wd=0.5, it=30 | val_loss=0.6113, test_acc=65.23%
** Best in Fold 1: wd=0.5, it=30 | val_loss=0.6113, test_acc=65.23%

===== Fold 2/10 =====
Fold 2 | wd=0.5, it=30 | val_loss=0.6200, test_acc=68.43%
** Best in Fold 2: wd=0.5, it=30 | val_loss=0.6200, test_acc=68.43%

===== Fold 3/10 =====
Fold 3 | wd=0.5, it=30 | val_loss=0.6213, test_acc=65.53%
** Best in Fold 3: wd=0.5, it=30 | val_loss=0.6213, test_acc=65.53%

===== Fold 4/10 =====
Fold 4 | wd=0.5, it=30 | val_loss=0.6280, test_acc=61.07%
** Best in Fold 4: wd=0.5, it=30 | val_loss=0.6280, test_acc=61.07%

===== Fold 5/10 =====
Fold 5 | wd=0.5, it=30 | val_loss=0.6207, test_acc=64.37%
** Best in Fold 5: wd=0.5, it=30 | val_loss=0.6207, test_acc=64.37%

===== Fold 6/10 =====
Fold 6 | wd=0.5, it=30 | val_loss=0.6153, test_acc=56.40%
** Best in Fold 6: wd=0.5, it=30 | val_loss=0.6153, test_acc=56.40%

===== Fold 7/10 =====
Fold 7 | wd=0.5, it=30 | val_lo

(0.6208666666666668, 0.6308666666666667)