## Load Library

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Tuple, Dict, Any, Optional
from dataclasses import dataclass
import tqdm
import matplotlib.pyplot as plt

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data loading

In [3]:
def load_dataset(file_path, val_ratio=0.2, random_state=42):
    dataset = np.load(file_path)
    Xtr, Str = dataset['Xtr'], dataset['Str']
    Xts, Yts = dataset['Xts'], dataset['Yts']

    # Shuffle & split (80% train, 20% validation)
    np.random.seed(random_state)
    indices = np.arange(len(Str))
    np.random.shuffle(indices)

    split_idx = int(len(Str) * (1 - val_ratio))
    train_idx, val_idx = indices[:split_idx], indices[split_idx:]

    X_train, y_train = Xtr[train_idx], Str[train_idx]
    X_val, y_val = Xtr[val_idx], Str[val_idx]

    return X_train, y_train, X_val, y_val, Xts, Yts

def reshape_mnist(*arrays):
    reshaped = []
    for arr in arrays:
        if arr.ndim == 1:
            arr = torch.tensor(arr, dtype=torch.long)
        elif arr.ndim >= 2:
            arr = arr.reshape(-1, 1, 28, 28)
            arr = torch.tensor(arr, dtype=torch.float32) / 255.0
        reshaped.append(arr)
    return tuple(reshaped)

def reshape_cifar(*arrays):
    reshaped = []
    for arr in arrays:
        if arr.ndim == 1:
            arr = torch.tensor(arr, dtype=torch.long)
        elif arr.ndim >= 2:
            arr = np.transpose(arr, (0, 3, 1, 2))
            arr = torch.tensor(arr, dtype=torch.float32) / 255.0
        reshaped.append(arr)
    return tuple(reshaped)


In [4]:
Xtr_03, Str_03, Xval_03, Sval_03, Xts_03, Yts_03 = reshape_mnist(*load_dataset('datasets/FashionMNIST0.3.npz'))
Xtr_06, Str_06, Xval_06, Sval_06, Xts_06, Yts_06 = reshape_mnist(*load_dataset('datasets/FashionMNIST0.6.npz'))
Xtr_cifar, Str_cifar, Xval_cifar, Sval_cifar, Xts_cifar, Yts_cifar = reshape_cifar(*load_dataset('datasets/CIFAR.npz'))

In [None]:
print("Xtr_03:", Xtr_03.shape)
print("Xtr_06:", Xtr_06.shape)
print("Xtr_cifar:", Xtr_cifar.shape)


In [None]:
# Mnist data checking
plt.imshow(Xtr_03[114, -1, :, :], cmap='gray') #[pic number,_,_,_]
plt.title(f"Label: {Str_03[0]}")
plt.show()

In [None]:
# CIFAR data checking
plt.imshow(np.transpose(Xtr_cifar[514], (1, -1, 0)))
plt.title(f"Label: {Str_cifar[0]}")
plt.show()

4# Main Program

## Classifier: ResNet-34

In [5]:
# ---------------------------
# Config
# ---------------------------
@dataclass
class Config:
    epochs: int = 60
    batch_size: int = 128
    lr: float = 1e-2
    momentum: float = 0.9
    weight_decay: float = 1e-3
    milestones: Tuple[int, int] = (30, 60)
    lam: float = 0.3                 # weight for cycle term
    label_smoothing: float = 0.0
    diag_reg_weight: float = 0.0     # >0 to enforce diagonal-dominance
    diag_margin: float = 0.05
    use_tqdm: bool = True
    tqdm_leave: bool = False
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
# ---------------------------
# ResNet-34 (classifier-oriented)
# ---------------------------
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = None
        if stride != 1 or in_planes != planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, 1, stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)



class ResNet34(nn.Module):
    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(in_channels, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(64, 3)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 6, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [BasicBlock(self.inplanes, planes, stride)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


In [None]:
'''
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes,mode='cifar10'):
        super(ResNet, self).__init__()
        self.in_planes = 64
        if mode == 'mnist':
            self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=0, bias=False)
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))


    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, revision=True):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)

        out = self.linear(out)

        clean = F.softmax(out, 1)

        return clean
    def ResNet34(num_classes):
        return ResNet(BasicBlock, [3,4,6,3], num_classes)
'''

In [7]:
# ---------------------------
# Column-stochastic T/T'
# ---------------------------
class ColumnStochastic(nn.Module):
    """
    Column-stochastic via column-wise softmax. Init near identity to encourage diagonal dominance.
    """
    def __init__(self, num_classes: int, init_identity: bool = True):
        super().__init__()
        logits = torch.zeros(num_classes, num_classes)
        if init_identity:
            logits += -4.0
            for i in range(num_classes):
                logits[i, i] = 4.0
        self.logits = nn.Parameter(logits)

    def forward(self):
        return F.softmax(self.logits, dim=0)  # column-wise softmax

    @torch.no_grad()
    def as_numpy(self):
        return self.forward().detach().cpu().numpy()

    def diag_dominance_loss(self, margin: float = 0.05):
        T = self.forward().detach()
        C = T.shape[0]
        loss = 0.0
        for j in range(C):
            diag = T[j, j]
            off = torch.cat([T[:j, j], T[j+1:, j]]) if C > 1 else T[j:j+1, j]
            max_off = off.max() if off.numel() > 0 else torch.tensor(0.0, device=off.device)
            loss += F.relu(max_off - diag + margin)
        return loss


# ---------------------------
# Utils
# ---------------------------
def one_hot(idx: torch.Tensor, num_classes: int, smoothing: float = 0.0):
    y = F.one_hot(idx, num_classes).float()
    return y * (1.0 - smoothing) + smoothing / num_classes if smoothing > 0 else y

def ce_soft(pred_probs: torch.Tensor, target_probs: torch.Tensor, eps: float = 1e-12):
    pred = torch.clamp(pred_probs, eps, 1.0)
    return -(target_probs * pred.log()).sum(dim=1).mean()

@torch.no_grad()
def top1_acc(logits: torch.Tensor, y: torch.Tensor):
    return (logits.argmax(1) == y).float().mean().item()


# ---------------------------
# CCR Trainer (class-based, tqdm)
# ---------------------------
class CCRTrainer:
    """
    Train single ResNet-34 classifier with CCR losses:
      L1 = CE(y~, T @ p_clean)
      L2 = CE(p_clean, T' @ y~)
      L3 = CE(p_clean, T' @ (T @ p_clean))
      Total: L = L1 + L2 + lam * L3
    """
    def __init__(self, input_shape: Tuple[int, int, int], num_classes: int, cfg: Config):
        self.cfg = cfg
        C_in, _, _ = input_shape
        self.device = torch.device(cfg.device)
        self.model = ResNet34(C_in, num_classes).to(self.device)
        self.T = ColumnStochastic(num_classes, init_identity=True).to(self.device)
        self.Tp = ColumnStochastic(num_classes, init_identity=True).to(self.device)

        params = list(self.model.parameters()) + list(self.T.parameters()) + list(self.Tp.parameters())
        self.opt = torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
        self.sched = torch.optim.lr_scheduler.MultiStepLR(self.opt, milestones=list(cfg.milestones), gamma=0.1)

    def _train_one_epoch(self, loader: DataLoader, epoch: int) -> float:
        self.model.train(); self.T.train(); self.Tp.train()
        running = 0.0
        iterator = tqdm(loader, desc=f"Epoch {epoch}/{self.cfg.epochs}", leave=self.cfg.tqdm_leave) \
                  if self.cfg.use_tqdm else loader

        for it, (xb, yb) in enumerate(iterator, start=1):
            xb, yb = xb.to(self.device), yb.to(self.device)
            self.opt.zero_grad()

            logits = self.model(xb)
            p_clean = F.softmax(logits, dim=1)

            T = self.T()
            Tp = self.Tp()

            # L1: CE(y~, T @ p_clean)
            p_noisy = (T @ p_clean.T).T
            L1 = F.nll_loss(torch.log(torch.clamp(p_noisy, 1e-12, 1.0)), yb)

            # L2: CE(p_clean, T' @ y~)
            y_prob = one_hot(yb, p_clean.shape[1], smoothing=self.cfg.label_smoothing)
            proj_clean = (Tp @ y_prob.T).T
            L2 = ce_soft(p_clean, proj_clean)

            # L3: CE(p_clean, T'(T p_clean))
            cyc = (Tp @ (T @ p_clean.T)).T
            L3 = ce_soft(p_clean, cyc)

            loss = L1 + L2 + self.cfg.lam * L3

            if self.cfg.diag_reg_weight > 0.0:
                loss = loss + self.cfg.diag_reg_weight * (
                    self.T.diag_dominance_loss(self.cfg.diag_margin) +
                    self.Tp.diag_dominance_loss(self.cfg.diag_margin)
                )

            loss.backward()
            self.opt.step()

            running += float(loss.item())
            if self.cfg.use_tqdm:
                self._set_postfix(iterator, it, running)

        return running / max(1, it)

    def _set_postfix(self, iterator, it, running):
        lr_curr = self.opt.param_groups[0]["lr"]
        iterator.set_postfix(loss=f"{running/it:.4f}", lr=f"{lr_curr:.3e}")

    @torch.no_grad()
    def _eval_loader(self, loader: DataLoader) -> float:
        self.model.eval()
        accs = []
        iterator = tqdm(loader, desc="Eval", leave=self.cfg.tqdm_leave) if self.cfg.use_tqdm else loader
        for xb, yb in iterator:
            xb, yb = xb.to(self.device), yb.to(self.device)
            logits = self.model(xb)
            accs.append(top1_acc(logits, yb))
        return float(np.mean(accs)) if len(accs) > 0 else 0.0

    def fit(self, train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader) -> Dict[str, Any]:
        best_val, best_test = 0.0, 0.0
        for ep in range(1, self.cfg.epochs + 1):
            train_loss = self._train_one_epoch(train_loader, ep)
            self.sched.step()

            val_acc = self._eval_loader(val_loader)
            test_acc = self._eval_loader(test_loader)
            if val_acc > best_val:
                best_val, best_test = val_acc, test_acc

            print(f"Epoch {ep:03d} | train_loss={train_loss:.4f}  val={val_acc*100:.2f}% "
                  f"test={test_acc*100:.2f}%  best@val={best_test*100:.2f}%")

        return {
            "best_val_acc": best_val,
            "best_test_acc": best_test,
            "T": self.T.as_numpy(),
            "Tp": self.Tp.as_numpy(),
            "state_dict": self.model.state_dict()
        }

In [None]:
# =============================================================
# 🔬 Quick Test: ONLY use mnist0.3 tensors already loaded above
#     (append this block at the very bottom of your file)
# =============================================================
if __name__ == "__main__":
    from torch.utils.data import TensorDataset, DataLoader

    # ---- 0) 绑定 tqdm：兼容你当前的 `import tqdm` 顶部导入写法 ----
    try:
        # 如果是 `import tqdm`
        if hasattr(tqdm, "tqdm"):
            tqdm = tqdm.tqdm
    except NameError:
        # 兜底：如果上面没导入成功，就用更智能的 auto 版本
        from tqdm.auto import tqdm  # noqa: F401

    # ---- 1) 选择 mnist0.3 的“已载入张量作为网络输入” ----
    Xtr, Str = Xtr_cifar, Str_cifar
    Xva, Sva = Xval_cifar, Sval_cifar
    Xte, Yte = Xts_cifar, Yts_cifar

    # ---- 2) 基本校验：保证后面训练真的用到你载入的张量 ----


    # ---- 3) 构建 DataLoader（仅使用 mnist0.3 这套）----
    def make_loader(X, y, bs=128, shuffle=False, num_workers=2):
        ds = TensorDataset(X, y)
        return DataLoader(ds, batch_size=bs, shuffle=shuffle, pin_memory=True, num_workers=num_workers)

    batch_size = 128
    train_loader = make_loader(Xtr, Str, bs=batch_size, shuffle=True)
    val_loader   = make_loader(Xva, Sva, bs=batch_size)
    test_loader  = make_loader(Xte, Yte, bs=batch_size)

    # ---- 4) 打印一个 batch，确认网络真实输入的是你载入的张量 ----
    xb_chk, yb_chk = next(iter(train_loader))
    print(f"\nSanity batch -> xb: {tuple(xb_chk.shape)} (min={xb_chk.min():.3f}, max={xb_chk.max():.3f}), "
          f"yb: {tuple(yb_chk.shape)}, labels in [ {int(yb_chk.min())} , {int(yb_chk.max())} ]")

    # ---- 5) 组装配置，跑一个短训练（冒烟测试），确认流程正确 ----
    input_shape = tuple(Xtr.shape[1:])                # (1,28,28)
    num_classes = int(Str.max().item()) + 1

    cfg = Config(
        epochs=60,                 # 先小跑 5 个 epoch 验证流程；确认无误再改回 60
        batch_size=batch_size,
        lr=1e-2,
        momentum=0.9,
        weight_decay=1e-3,
        milestones=(30, 60),
        lam=0.3,
        label_smoothing=0.0,
        diag_reg_weight=0.0,
        diag_margin=0.05,
        use_tqdm=True,
        tqdm_leave=False,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    print(f"\n🚀 Start CCR on MNIST0.3 with your loaded tensors "
          f"(input={input_shape}, classes={num_classes}, device={cfg.device})")
    trainer = CCRTrainer(input_shape, num_classes, cfg)

    # ---- 6) 训练并汇报结果（确保输出与你预期格式一致）----
    result = trainer.fit(train_loader, val_loader, test_loader)




Sanity batch -> xb: (128, 3, 32, 32) (min=0.000, max=1.000), yb: (128,), labels in [ 0 , 2 ]

🚀 Start CCR on MNIST0.3 with your loaded tensors (input=(3, 32, 32), classes=3, device=cuda)


Epoch 1/60:   0%|          | 0/94 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 001 | train_loss=2.8359  val=33.38% test=33.23%  best@val=33.23%


Epoch 2/60:   0%|          | 0/94 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 002 | train_loss=2.5763  val=35.45% test=34.88%  best@val=34.88%


Epoch 3/60:   0%|          | 0/94 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 003 | train_loss=2.5341  val=35.44% test=53.12%  best@val=34.88%


Epoch 4/60:   0%|          | 0/94 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 004 | train_loss=2.5217  val=34.66% test=41.04%  best@val=34.88%


Epoch 5/60:   0%|          | 0/94 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 005 | train_loss=2.5225  val=36.93% test=55.06%  best@val=55.06%


Epoch 6/60:   0%|          | 0/94 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 006 | train_loss=2.5167  val=37.15% test=43.76%  best@val=43.76%


Epoch 7/60:   0%|          | 0/94 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Eval:   0%|          | 0/24 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x000001FEBB491DC0>
Traceback (most recent call last):
  File "d:\Anaconda3\envs\my_pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 1663, in __del__
    self._shutdown_workers()
  File "d:\Anaconda3\envs\my_pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 1627, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "d:\Anaconda3\envs\my_pytorch\lib\multiprocessing\process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "d:\Anaconda3\envs\my_pytorch\lib\multiprocessing\popen_spawn_win32.py", line 108, in wait
    res = _winapi.WaitForSingleObject(int(self._handle), msecs)
KeyboardInterrupt: 


In [11]:

    print("\n✅ Test (MNIST0.3) complete.")
    print(f"Best Test Accuracy @ Val = {result['best_test_acc']*100:.2f}%")

    # 可选：检查学到的转移矩阵
    T_est, Tp_est = result["T"], result["Tp"]
    np.set_printoptions(precision=4, suppress=True)
    print("\nForward T (shape:", T_est.shape, ") first 3 cols:\n", T_est[:, :min(3, T_est.shape[1])])
    print("\nBackward T' (shape:", Tp_est.shape, ") first 3 cols:\n", Tp_est[:, :min(3, Tp_est.shape[1])])


✅ Test (MNIST0.3) complete.
Best Test Accuracy @ Val = 52.94%

Forward T (shape: (3, 3) ) first 3 cols:
 [[0.999  0.0005 0.0005]
 [0.0005 0.999  0.0005]
 [0.0005 0.0005 0.999 ]]

Backward T' (shape: (3, 3) ) first 3 cols:
 [[0.999  0.0005 0.0005]
 [0.0005 0.999  0.0005]
 [0.0005 0.0005 0.999 ]]


In [None]:
T_est