In [9]:
%%writefile models.py
"""
Unified Model Collection for PFN Experiments
Includes: PartitionedMLP, BottleneckMLP, DeepMLP, PartitionedCNN, DeepPartitionedCNN, BottleneckCNN, VGG11NoBN
"""

import torch
import torch.nn as nn
from typing import List, Dict


# ==================== MLP MODELS ====================

class PartitionedLinear(nn.Module):
    """Linear layer whose output is produced by independent sub-blocks.

    Every sub-block is an ``nn.Linear`` that maps the full input to
    ``out_features // num_blocks`` outputs; ``forward`` concatenates the block
    outputs along the last dimension.

    Args:
        in_features: Size of the last input dimension.
        out_features: Total number of output features across all blocks.
        num_blocks: Number of sub-blocks. Must be positive and divide
            ``out_features`` evenly; otherwise the concatenated output would
            silently have fewer than ``out_features`` columns.

    Raises:
        ValueError: If ``num_blocks`` is not positive or does not divide
            ``out_features``.
    """

    def __init__(self, in_features: int, out_features: int, num_blocks: int):
        super().__init__()
        if num_blocks <= 0:
            raise ValueError(f"num_blocks must be positive, got {num_blocks}")
        if out_features % num_blocks != 0:
            raise ValueError(
                f"out_features ({out_features}) must be divisible by "
                f"num_blocks ({num_blocks})"
            )
        self.in_features = in_features
        self.out_features = out_features
        self.num_blocks = num_blocks
        self.block_size = out_features // num_blocks

        self.blocks = nn.ModuleList([
            nn.Linear(in_features, self.block_size, bias=True)
            for _ in range(num_blocks)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every block to ``x`` and concatenate along the last dim."""
        return torch.cat([block(x) for block in self.blocks], dim=-1)

    def get_block_params(self, block_idx: int) -> List[nn.Parameter]:
        """Return the parameters (weight and bias) of one sub-block."""
        return list(self.blocks[block_idx].parameters())

    def get_block_gradients(self, block_idx: int) -> torch.Tensor:
        """Return the flattened, concatenated gradients of one sub-block.

        Parameters without a gradient are skipped. If no parameter of the
        block has a gradient yet, a single-element zero tensor on the block's
        device is returned instead.
        """
        block = self.blocks[block_idx]
        grads = [p.grad.flatten() for p in block.parameters() if p.grad is not None]
        if grads:
            return torch.cat(grads)
        return torch.tensor([0.0], device=next(block.parameters()).device)


class PartitionedMLP(nn.Module):
    """MNIST classifier built from block-partitioned linear layers.

    Data path: flatten to 784 -> 512 units (4 blocks of 128) -> ReLU ->
    Dropout(0.2) -> 256 units (2 blocks of 128) -> ReLU -> Dropout(0.2)
    -> 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = PartitionedLinear(784, 512, num_blocks=4)
        self.dropout1 = nn.Dropout(0.2)
        self.layer2 = PartitionedLinear(512, 256, num_blocks=2)
        self.dropout2 = nn.Dropout(0.2)
        self.layer3 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten the input to rows of 784 values and return class logits."""
        flat = x.view(-1, 784)
        hidden = self.dropout1(self.relu(self.layer1(flat)))
        hidden = self.dropout2(self.relu(self.layer2(hidden)))
        return self.layer3(hidden)

    def get_all_block_gradients(self) -> Dict[str, List[torch.Tensor]]:
        """Return flattened gradients per layer, one tensor per block.

        ``layer1`` yields four entries, ``layer2`` two, and ``layer3`` one.
        A single-element zero tensor stands in when ``layer3`` has no
        gradients yet.
        """
        device = next(self.parameters()).device
        head_grads = [
            p.grad.flatten() for p in self.layer3.parameters() if p.grad is not None
        ]
        if head_grads:
            head_entry = [torch.cat(head_grads)]
        else:
            head_entry = [torch.tensor([0.0], device=device)]
        return {
            'layer1': [self.layer1.get_block_gradients(i) for i in range(4)],
            'layer2': [self.layer2.get_block_gradients(i) for i in range(2)],
            'layer3': head_entry,
        }

    def get_parameter_groups(self) -> List[Dict]:
        """Return optimizer parameter groups, one per block, each at lr 0.001."""
        groups = []
        for layer_name, layer, block_count in (
            ('layer1', self.layer1, 4),
            ('layer2', self.layer2, 2),
        ):
            for i in range(block_count):
                groups.append({
                    'params': layer.get_block_params(i),
                    'name': f'{layer_name}_block{i}',
                    'lr': 0.001,
                })
        groups.append({
            'params': list(self.layer3.parameters()),
            'name': 'layer3_block0',
            'lr': 0.001,
        })
        return groups


class BottleneckMLP(nn.Module):
    """Hourglass-shaped MLP with a narrow middle layer.

    Layer widths: 784 -> 256 -> 64 -> ``bottleneck_width`` -> 64 -> 256 -> 10.
    With the default ``bottleneck_width`` of 8 the middle layer has 8 units.
    Every hidden layer is followed by ReLU; Dropout(0.1) follows every hidden
    layer except the bottleneck layer itself.
    """

    def __init__(self, bottleneck_width: int = 8):
        super().__init__()
        # Encoder half, narrowing down to the bottleneck.
        self.enc1 = nn.Linear(784, 256)
        self.enc2 = nn.Linear(256, 64)
        self.enc3 = nn.Linear(64, bottleneck_width)
        # Decoder half, widening again before the 10-way output.
        self.dec1 = nn.Linear(bottleneck_width, 64)
        self.dec2 = nn.Linear(64, 256)
        self.output = nn.Linear(256, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        # Plain list on purpose: the modules are already registered through
        # the attributes above.
        self.layers = [self.enc1, self.enc2, self.enc3, self.dec1, self.dec2, self.output]
        self.layer_names = ['enc1', 'enc2', 'bottleneck', 'dec1', 'dec2', 'output']

    def forward(self, x):
        """Flatten the input to rows of 784 values and return class logits."""
        h = x.view(-1, 784)
        for idx, layer in enumerate(self.layers[:-1]):
            h = self.relu(layer(h))
            # Index 2 is the bottleneck layer, which gets no dropout.
            if idx != 2:
                h = self.dropout(h)
        return self.output(h)

    def get_all_block_gradients(self) -> Dict[str, List[torch.Tensor]]:
        """Return one flattened gradient tensor per layer, keyed by layer name.

        A layer without gradients is represented by a single-element zero
        tensor on the model's device.
        """
        device = next(self.parameters()).device
        gradients = {}
        for name, layer in zip(self.layer_names, self.layers):
            pieces = [p.grad.flatten() for p in layer.parameters() if p.grad is not None]
            if pieces:
                gradients[name] = [torch.cat(pieces)]
            else:
                gradients[name] = [torch.tensor([0.0], device=device)]
        return gradients

    def get_parameter_groups(self) -> List[Dict]:
        """Return one optimizer parameter group per layer, each at lr 0.001."""
        groups = []
        for name, layer in zip(self.layer_names, self.layers):
            groups.append({
                'params': list(layer.parameters()),
                'name': f'{name}_block0',
                'lr': 0.001,
            })
        return groups


class DeepMLP(nn.Module):
    """Deep plain MLP (no residual connections) for 784-dimensional inputs.

    The network has ``num_hidden_layers`` ReLU hidden layers of width
    ``hidden_dim`` followed by a 10-way linear output layer.

    Args:
        num_hidden_layers: Number of hidden layers; must be at least 1.
        hidden_dim: Width of every hidden layer.

    Raises:
        ValueError: If ``num_hidden_layers`` is smaller than 1. Without this
            check the layer list and ``layer_names`` would have different
            lengths, and the output layer would be silently missing from the
            gradient and parameter groups.
    """

    def __init__(self, num_hidden_layers: int = 10, hidden_dim: int = 128):
        super().__init__()
        if num_hidden_layers < 1:
            raise ValueError(
                f"num_hidden_layers must be at least 1, got {num_hidden_layers}"
            )
        layers = [nn.Linear(784, hidden_dim)]
        layers += [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_hidden_layers - 1)]
        layers.append(nn.Linear(hidden_dim, 10))
        self.layers = nn.ModuleList(layers)
        self.relu = nn.ReLU()
        # One name per entry of self.layers, in the same order.
        self.layer_names = [f'hidden_{i}' for i in range(num_hidden_layers)] + ['output']

    def forward(self, x):
        """Flatten the input to rows of 784 values and return class logits."""
        x = x.view(-1, 784)
        for layer in self.layers[:-1]:
            x = self.relu(layer(x))
        return self.layers[-1](x)

    def get_all_block_gradients(self) -> Dict[str, List[torch.Tensor]]:
        """Return one flattened gradient tensor per layer, keyed by layer name.

        A layer without gradients is represented by a single-element zero
        tensor on the model's device.
        """
        gradients = {}
        device = next(self.parameters()).device
        for name, layer in zip(self.layer_names, self.layers):
            grads = [p.grad.flatten() for p in layer.parameters() if p.grad is not None]
            gradients[name] = [torch.cat(grads)] if grads else [torch.tensor([0.0], device=device)]
        return gradients

    def get_parameter_groups(self) -> List[Dict]:
        """Return one optimizer parameter group per layer, each at lr 0.001."""
        return [{'params': list(layer.parameters()), 'name': f'{name}_block0', 'lr': 0.001}
                for name, layer in zip(self.layer_names, self.layers)]


# ==================== CNN MODELS ====================

class ChannelPartitionedConv2d(nn.Module):
    """Convolution split along the output channels into independent blocks.

    Each block is its own ``nn.Conv2d`` that sees all input channels and
    produces ``out_channels // num_blocks`` output channels; ``forward``
    concatenates the block outputs along the channel dimension (dim 1).

    Args:
        in_channels: Number of input channels.
        out_channels: Total number of output channels across all blocks.
        kernel_size: Kernel size passed to every block.
        stride: Stride passed to every block.
        padding: Padding passed to every block.
        num_blocks: Number of parallel blocks; must be positive and divide
            ``out_channels`` evenly.

    Raises:
        ValueError: If ``num_blocks`` is not positive or does not divide
            ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, num_blocks: int = 4):
        super().__init__()
        # Explicit checks instead of `assert`, which is stripped under -O.
        if num_blocks <= 0:
            raise ValueError(f"num_blocks must be positive, got {num_blocks}")
        if out_channels % num_blocks != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be divisible by "
                f"num_blocks ({num_blocks})"
            )
        self.num_blocks = num_blocks
        self.block_size = out_channels // num_blocks
        self.blocks = nn.ModuleList([
            nn.Conv2d(in_channels, self.block_size, kernel_size, stride, padding)
            for _ in range(num_blocks)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every block to ``x`` and concatenate along the channel dim."""
        return torch.cat([block(x) for block in self.blocks], dim=1)


class PartitionedCNN(nn.Module):
    """Three-stage CNN whose conv layers are split into ``num_blocks`` parallel blocks.

    Each stage is a channel-partitioned 3x3 convolution followed by ReLU and
    2x2 max pooling (3 -> 64 -> 128 -> 128 channels); global average pooling
    and a linear layer produce the class logits.
    """

    def __init__(self, num_blocks: int = 4, num_classes: int = 10):
        super().__init__()
        self.num_blocks = num_blocks
        self.conv1 = ChannelPartitionedConv2d(3, 64, 3, padding=1, num_blocks=num_blocks)
        self.conv2 = ChannelPartitionedConv2d(64, 128, 3, padding=1, num_blocks=num_blocks)
        self.conv3 = ChannelPartitionedConv2d(128, 128, 3, padding=1, num_blocks=num_blocks)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, num_classes)
        self.layer_names = ['conv1', 'conv2', 'conv3']
        # Plain list: the conv modules are already registered as attributes.
        self._layers = [self.conv1, self.conv2, self.conv3]

    def forward(self, x):
        """Run the three conv stages, pool globally, and return class logits."""
        stage1 = self.pool(self.relu(self.conv1(x)))
        stage2 = self.pool(self.relu(self.conv2(stage1)))
        stage3 = self.pool(self.relu(self.conv3(stage2)))
        pooled = self.gap(stage3)
        return self.fc(pooled.view(stage3.size(0), -1))

    def get_all_block_gradients(self) -> Dict[str, List[torch.Tensor]]:
        """Return flattened gradients: one tensor per conv block, one for ``fc``.

        A block without gradients is represented by a single-element zero
        tensor on the model's device.
        """
        device = next(self.parameters()).device
        gradients = {}
        for name, layer in zip(self.layer_names, self._layers):
            per_block = []
            for block in layer.blocks:
                pieces = [p.grad.flatten() for p in block.parameters() if p.grad is not None]
                if pieces:
                    per_block.append(torch.cat(pieces))
                else:
                    per_block.append(torch.tensor([0.0], device=device))
            gradients[name] = per_block
        head = [p.grad.flatten() for p in self.fc.parameters() if p.grad is not None]
        if head:
            gradients['fc'] = [torch.cat(head)]
        else:
            gradients['fc'] = [torch.tensor([0.0], device=device)]
        return gradients

    def get_parameter_groups(self) -> List[Dict]:
        """Return one optimizer group per conv block plus one for ``fc`` (lr 0.001)."""
        groups = [
            {'params': list(block.parameters()), 'name': f'{name}_block{i}', 'lr': 0.001}
            for name, layer in zip(self.layer_names, self._layers)
            for i, block in enumerate(layer.blocks)
        ]
        groups.append({'params': list(self.fc.parameters()), 'name': 'fc_block0', 'lr': 0.001})
        return groups


class DeepPartitionedCNN(nn.Module):
    """Deep, block-partitioned CNN without residual connections or BatchNorm.

    A 3 -> 32 channel stem is followed by ``num_layers - 1`` partitioned 3x3
    convolutions. At every hidden index that is a positive multiple of 4 the
    channel count doubles (capped at 128) and, in ``forward``, a 2x2 max pool
    is applied after that layer. Global average pooling and a linear
    classifier produce the logits.
    """

    def __init__(self, num_layers: int = 12, num_classes: int = 10, num_blocks: int = 4):
        super().__init__()
        self.num_layers = num_layers
        self.num_blocks = num_blocks

        # Stem layer.
        self.conv_in = ChannelPartitionedConv2d(3, 32, 3, padding=1, num_blocks=num_blocks)

        # Deep stack of hidden layers.
        self.hidden_layers = nn.ModuleList()
        width = 32
        for depth in range(num_layers - 1):
            next_width = width
            if depth > 0 and depth % 4 == 0:
                next_width = min(width * 2, 128)
            self.hidden_layers.append(
                ChannelPartitionedConv2d(width, next_width, 3, padding=1, num_blocks=num_blocks)
            )
            width = next_width

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(width, num_classes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2)

        hidden_names = [f'conv_{i}' for i in range(len(self.hidden_layers))]
        self.layer_names = ['conv_in'] + hidden_names + ['classifier']

    def forward(self, x):
        """Run the stem and hidden stack, pool globally, and return logits."""
        out = self.relu(self.conv_in(x))

        for depth, conv in enumerate(self.hidden_layers):
            out = self.relu(conv(out))
            # Downsample after the same layers that widened the channels.
            if depth > 0 and depth % 4 == 0:
                out = self.maxpool(out)

        out = self.pool(out)
        return self.classifier(out.view(out.size(0), -1))

    def get_all_block_gradients(self) -> Dict[str, List[torch.Tensor]]:
        """Return flattened gradients per layer, one tensor per block.

        Keys are ``conv_in``, ``conv_<i>`` for each hidden layer, and
        ``classifier``. A block without gradients is represented by a
        single-element zero tensor on the model's device.
        """
        device = next(self.parameters()).device

        def per_block(layer):
            collected = []
            for block in layer.blocks:
                pieces = [p.grad.flatten() for p in block.parameters() if p.grad is not None]
                collected.append(torch.cat(pieces) if pieces else torch.tensor([0.0], device=device))
            return collected

        gradients = {'conv_in': per_block(self.conv_in)}
        for idx, layer in enumerate(self.hidden_layers):
            gradients[f'conv_{idx}'] = per_block(layer)

        head = [p.grad.flatten() for p in self.classifier.parameters() if p.grad is not None]
        gradients['classifier'] = [torch.cat(head)] if head else [torch.tensor([0.0], device=device)]
        return gradients

    def get_parameter_groups(self) -> List[Dict]:
        """Return one optimizer group per block plus one for the classifier (lr 0.001)."""
        stages = [('conv_in', self.conv_in)]
        stages.extend((f'conv_{idx}', layer) for idx, layer in enumerate(self.hidden_layers))

        groups = []
        for stage_name, layer in stages:
            for i, block in enumerate(layer.blocks):
                groups.append({
                    'params': list(block.parameters()),
                    'name': f'{stage_name}_block{i}',
                    'lr': 0.001,
                })

        groups.append({
            'params': list(self.classifier.parameters()),
            'name': 'classifier_block0',
            'lr': 0.001,
        })
        return groups


class BottleneckCNN(nn.Module):
    """CNN with a deliberately narrow middle stage and no BatchNorm.

    Stages: two conv+ReLU+max-pool encoders (3 -> 64 -> 128 channels), a 1x1
    convolution down to ``bottleneck_channels`` (default 4), two conv+ReLU
    decoders (-> 128 -> 256 channels) ending in global average pooling, and a
    linear classifier.
    """

    def __init__(self, bottleneck_channels: int = 4, num_classes: int = 10):
        super().__init__()
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, bottleneck_channels, 1), nn.ReLU(),
        )
        self.dec1 = nn.Sequential(
            nn.Conv2d(bottleneck_channels, 128, 3, padding=1), nn.ReLU(),
        )
        self.dec2 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(256, num_classes)
        self.layer_names = ['enc1', 'enc2', 'bottleneck', 'dec1', 'dec2', 'classifier']
        # Plain list: the stages are already registered as attributes.
        self._layers = [self.enc1, self.enc2, self.bottleneck, self.dec1, self.dec2, self.classifier]

    def forward(self, x):
        """Run encoder, bottleneck and decoder stages and return class logits."""
        encoded = self.enc2(self.enc1(x))
        squeezed = self.bottleneck(encoded)
        decoded = self.dec2(self.dec1(squeezed))
        return self.classifier(decoded.view(decoded.size(0), -1))

    def get_all_block_gradients(self) -> Dict[str, List[torch.Tensor]]:
        """Return one flattened gradient tensor per stage, keyed by stage name.

        A stage without gradients is represented by a single-element zero
        tensor on the model's device.
        """
        device = next(self.parameters()).device
        gradients = {}
        for name, layer in zip(self.layer_names, self._layers):
            pieces = [p.grad.flatten() for p in layer.parameters() if p.grad is not None]
            if pieces:
                gradients[name] = [torch.cat(pieces)]
            else:
                gradients[name] = [torch.tensor([0.0], device=device)]
        return gradients

    def get_parameter_groups(self) -> List[Dict]:
        """Return one optimizer parameter group per stage, each at lr 0.001."""
        groups = []
        for name, layer in zip(self.layer_names, self._layers):
            groups.append({
                'params': list(layer.parameters()),
                'name': f'{name}_block0',
                'lr': 0.001,
            })
        return groups


# Added: VGG helpers (no BatchNorm by default)
def make_vgg_layers(cfg: List, batch_norm: bool = False, in_channels: int = 3):
    """Build a VGG-style feature extractor from a layer configuration.

    Args:
        cfg: Sequence whose entries are either the string ``'M'`` (insert a
            2x2, stride-2 max pool) or an int giving the output channel count
            of a 3x3, padding-1 convolution followed by ReLU.
        batch_norm: When True, insert ``nn.BatchNorm2d`` between every
            convolution and its ReLU.
        in_channels: Number of channels of the input image. Defaults to 3,
            the value that was previously hard-coded.

    Returns:
        An ``nn.Sequential`` containing the layers in ``cfg`` order.
    """
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
        if batch_norm:
            layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
        else:
            layers += [conv2d, nn.ReLU(inplace=True)]
        # The next convolution consumes this layer's output channels.
        in_channels = v
    return nn.Sequential(*layers)

class VGG11NoBN(nn.Module):
    """VGG-11 without BatchNorm, ending in AdaptiveAvgPool2d(1) -> Linear(512, num_classes)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Configuration 'A' (VGG-11): ints are conv widths, 'M' is a max pool.
        cfg_A = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        self.features = make_vgg_layers(cfg_A, batch_norm=False)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features, pool globally, and return class logits."""
        feats = self.features(x)
        pooled = self.pool(feats)
        return self.classifier(pooled.view(feats.size(0), -1))

    def get_all_block_gradients(self) -> Dict[str, List[torch.Tensor]]:
        """Return one flattened gradient tensor per conv layer and for the classifier.

        Conv layers are keyed ``conv0``, ``conv1``, ... in feature order. A
        layer without gradients is represented by a single-element zero tensor
        on the model's device.
        """
        device = next(self.parameters()).device
        convs = [m for m in self.features if isinstance(m, nn.Conv2d)]
        gradients = {}
        for conv_idx, conv in enumerate(convs):
            pieces = [p.grad.flatten() for p in conv.parameters() if p.grad is not None]
            if pieces:
                gradients[f'conv{conv_idx}'] = [torch.cat(pieces)]
            else:
                gradients[f'conv{conv_idx}'] = [torch.tensor([0.0], device=device)]
        head = [p.grad.flatten() for p in self.classifier.parameters() if p.grad is not None]
        if head:
            gradients['classifier'] = [torch.cat(head)]
        else:
            gradients['classifier'] = [torch.tensor([0.0], device=device)]
        return gradients

    def get_parameter_groups(self) -> List[Dict]:
        """Return one optimizer group per conv layer plus one for the classifier (lr 0.001)."""
        convs = [m for m in self.features if isinstance(m, nn.Conv2d)]
        groups = [
            {'params': list(conv.parameters()), 'name': f'conv{conv_idx}_block0', 'lr': 0.001}
            for conv_idx, conv in enumerate(convs)
        ]
        groups.append({
            'params': list(self.classifier.parameters()),
            'name': 'classifier_block0',
            'lr': 0.001,
        })
        return groups


# ==================== MODEL FACTORY ====================

def get_model(scenario: str, dataset: str, **kwargs) -> nn.Module:
    """Create the model for a (scenario, dataset) pair.

    ``dataset`` ``'mnist'`` selects the MLP family; ``'cifar10'`` and
    ``'cifar100'`` select the CNN family (100 classes for ``'cifar100'``,
    otherwise 10). ``scenario`` is one of ``'standard'``, ``'bottleneck'`` or
    ``'deep'``. Optional ``kwargs`` read here are ``bottleneck_width`` and
    ``num_layers``; any other keyword is ignored.

    Raises:
        ValueError: If the scenario/dataset combination is not recognised.
    """
    num_classes = 100 if dataset == 'cifar100' else 10

    if dataset == 'mnist':
        family = 'mlp'
    elif dataset in ['cifar10', 'cifar100']:
        family = 'cnn'
    else:
        family = None

    if family == 'mlp':
        if scenario == 'standard':
            return PartitionedMLP()
        if scenario == 'bottleneck':
            width = kwargs.get('bottleneck_width', 8)
            return BottleneckMLP(bottleneck_width=width)
        if scenario == 'deep':
            depth = kwargs.get('num_layers', 10)
            return DeepMLP(num_hidden_layers=depth)

    if family == 'cnn':
        # The standard CIFAR scenario uses VGG-11 without BatchNorm.
        if scenario == 'standard':
            return VGG11NoBN(num_classes=num_classes)
        if scenario == 'bottleneck':
            width = kwargs.get('bottleneck_width', 4)
            return BottleneckCNN(bottleneck_channels=width, num_classes=num_classes)
        if scenario == 'deep':
            depth = kwargs.get('num_layers', 12)
            return DeepPartitionedCNN(num_layers=depth, num_classes=num_classes, num_blocks=4)

    raise ValueError(f"Unknown scenario '{scenario}' or dataset '{dataset}'")

Overwriting models.py


In [10]:
%%writefile pfn.py
"""
Parameter Flow Network (PFN) - minimal, physics-inspired version
Core principle: capacity = gradient energy; no black magic
"""

import torch
import numpy as np
from torch.optim import Optimizer
from collections import deque, defaultdict
from typing import Dict, List, Set, Tuple, Optional


class PFNGraphBuilder:
    """Build a layered flow network from per-block gradient norms.

    The capacity of a block is its gradient L2 norm, divided by the mean norm
    over all blocks, averaged over the most recent ``history_size`` calls and
    optionally scaled up linearly with the layer index.

    Node ids: 0 is the source, blocks are numbered from 1 in layer order, and
    the sink gets the last id.

    Args:
        use_hessian_approx: Stored for interface compatibility only; nothing
            in this class reads it.
        depth_penalty: When True (default) capacities of layer ``i`` are
            multiplied by ``1 + 0.2 * i``; when False no depth scaling is
            applied.
        history_size: Number of most recent energy snapshots that are averaged.
    """

    def __init__(self, use_hessian_approx: bool = False, depth_penalty: bool = True, history_size: int = 5):
        self.use_hessian_approx = use_hessian_approx
        self.depth_penalty = depth_penalty
        self.num_nodes = 0
        self.source = 0
        self.sink = 0
        self.node_map = {}
        self.node_names = {}
        self.layer_names = []
        self.current_epoch = 0
        self.total_epochs = 1
        self.history_size = history_size
        self.grad_history: List[Dict[str, List[float]]] = []
        self.debug = True  # debug output enabled

    def setup_topology(self, gradients: Dict[str, List[torch.Tensor]]):
        """Assign node ids to every (layer, block) pair in ``gradients`` order."""
        self.layer_names = list(gradients.keys())
        self.node_map = {}
        self.node_names = {}

        current_id = 1
        for layer_name in self.layer_names:
            for block_idx in range(len(gradients[layer_name])):
                self.node_map[(layer_name, block_idx)] = current_id
                self.node_names[current_id] = f"{layer_name}_block{block_idx}"
                current_id += 1

        self.source = 0
        self.sink = current_id
        self.num_nodes = current_id + 1

    def _compute_normalized_energy(self, gradients: Dict[str, List[torch.Tensor]]) -> Dict[str, List[float]]:
        """Return each block's gradient L2 norm divided by the mean norm."""
        raw_norms = {
            layer_name: [grad.norm(2).item() for grad in blocks]
            for layer_name, blocks in gradients.items()
        }
        all_norms = [norm for norms in raw_norms.values() for norm in norms]
        if not all_norms:
            # No blocks at all: avoid np.mean([]), which yields NaN and a warning.
            return raw_norms

        # Mean norm used for normalization; epsilon guards against all-zero grads.
        avg_norm = np.mean(all_norms) + 1e-9
        return {
            layer_name: [norm / avg_norm for norm in norms]
            for layer_name, norms in raw_norms.items()
        }

    def _get_smoothed_energy(self, current_energy: Dict[str, List[float]]) -> Dict[str, List[float]]:
        """Average energies over the recent history to reduce noise.

        The current snapshot is appended to ``grad_history`` (trimmed to
        ``history_size`` entries) before averaging.
        """
        self.grad_history.append(current_energy)
        if len(self.grad_history) > self.history_size:
            self.grad_history.pop(0)

        if len(self.grad_history) == 1:
            return current_energy

        smoothed = {}
        for layer_name, energies in current_energy.items():
            smoothed[layer_name] = []
            for b_idx, energy in enumerate(energies):
                # Only snapshots that contain this layer and block contribute.
                values = [
                    h[layer_name][b_idx]
                    for h in self.grad_history
                    if layer_name in h and b_idx < len(h[layer_name])
                ]
                smoothed[layer_name].append(np.mean(values) if values else energy)

        return smoothed

    def build_graph(self, gradients: Dict[str, List[torch.Tensor]]) -> Tuple[np.ndarray, Dict]:
        """Build the capacity matrix of the flow network.

        Args:
            gradients: Mapping from layer name to a list of per-block gradient
                tensors; dict order defines layer order.

        Returns:
            ``(capacity, info)`` where ``capacity`` is a square matrix indexed
            by node id and ``info['normalized_energy']`` holds the smoothed
            energies used to fill it.
        """
        self.setup_topology(gradients)
        capacity = np.zeros((self.num_nodes, self.num_nodes))

        # Normalize, then smooth over recent history.
        normalized_energy = self._compute_normalized_energy(gradients)
        smoothed_energy = self._get_smoothed_energy(normalized_energy)

        last_layer_idx = len(self.layer_names) - 1

        for layer_idx, layer_name in enumerate(self.layer_names):
            # Depth compensation is linear in the layer index.
            depth_scale = 1.0 + 0.2 * layer_idx if self.depth_penalty else 1.0

            if layer_idx < last_layer_idx:
                next_layer = self.layer_names[layer_idx + 1]
                num_next_blocks = len(smoothed_energy[next_layer])
            else:
                next_layer = None
                num_next_blocks = 0

            for b_idx, energy in enumerate(smoothed_energy[layer_name]):
                u = self.node_map[(layer_name, b_idx)]
                scaled_energy = max(energy * depth_scale, 0.01)  # capacity floor of 0.01

                # 1. Injection edges: only the first layer connects to the source.
                if layer_idx == 0:
                    capacity[self.source, u] = scaled_energy

                # 2. Transfer edges: connect to blocks b_idx-1, b_idx, b_idx+1
                #    of the next layer.
                # NOTE(review): when the next layer has fewer blocks, a block
                # whose index has no counterpart within +/-1 gets no outgoing
                # edge - confirm this is intended.
                if next_layer is not None:
                    for offset in (-1, 0, 1):
                        nb_idx = b_idx + offset
                        if 0 <= nb_idx < num_next_blocks:
                            v = self.node_map[(next_layer, nb_idx)]
                            # Direct edge weight 1.0, neighbouring edges 0.3.
                            weight = 1.0 if offset == 0 else 0.3
                            capacity[u, v] = scaled_energy * weight

                # 3. Extraction edges: the last layer connects to the sink.
                if layer_idx == last_layer_idx:
                    capacity[u, self.sink] = scaled_energy

        return capacity, {'normalized_energy': smoothed_energy}

    def get_node_name(self, node_idx: int) -> str:
        """Return a readable name for a node id (``Source``, ``Sink`` or block name)."""
        if node_idx == self.source:
            return "Source"
        if node_idx == self.sink:
            return "Sink"
        return self.node_names.get(node_idx, f"Node_{node_idx}")


class IncrementalPushRelabel:
    """Simplified push-relabel maximum-flow solver on a dense capacity matrix."""

    def __init__(self):
        self.flow = defaultdict(float)
        self.height = {}
        self.excess = {}
        self.n = 0

    def find_min_cut(self, capacity: np.ndarray, source: int, sink: int) -> Tuple[float, List, Set, Set]:
        """Solve for the flow and return ``(max_flow, cut_edges, S_set, T_set)``.

        ``S_set`` holds the nodes reachable from ``source`` through edges with
        residual capacity above 1e-9, ``T_set`` the remaining nodes, and
        ``cut_edges`` every ``(u, v)`` with ``u`` in ``S_set``, ``v`` in
        ``T_set`` and capacity above 1e-9.
        """
        self.n = capacity.shape[0]
        max_flow = self._solve(capacity, source, sink)

        # Breadth-first search over the residual graph, starting at the source.
        reachable = {source}
        frontier = deque([source])
        while frontier:
            node = frontier.popleft()
            for nxt in range(self.n):
                if nxt in reachable:
                    continue
                residual = capacity[node][nxt] - self.flow.get((node, nxt), 0)
                if residual > 1e-9:
                    reachable.add(nxt)
                    frontier.append(nxt)

        remainder = set(range(self.n)) - reachable
        cut_edges = []
        for u in reachable:
            for v in remainder:
                if capacity[u][v] > 1e-9:
                    cut_edges.append((u, v))

        return max_flow, cut_edges, reachable, remainder

    def _solve(self, capacity: np.ndarray, source: int, sink: int) -> float:
        """Run push-relabel and return the excess accumulated at ``sink``.

        The main loop is capped at ``2 * n * n`` iterations. Heights, excesses
        and flows are reset on every call.
        """
        n = self.n
        self.height = dict.fromkeys(range(n), 0)
        self.height[source] = n
        self.excess = dict.fromkeys(range(n), 0.0)
        self.flow = defaultdict(float)

        # Preflow: saturate every edge leaving the source.
        for v in range(n):
            saturating = capacity[source][v]
            if saturating > 1e-9:
                self.flow[(source, v)] = saturating
                self.flow[(v, source)] = -saturating
                self.excess[v] = saturating
                self.excess[source] -= saturating

        active = deque(
            v for v in range(n)
            if v != source and v != sink and self.excess[v] > 1e-9
        )
        iteration_cap = n * n * 2

        for _ in range(iteration_cap):
            if not active:
                break
            u = active.popleft()
            if self.excess[u] <= 1e-9:
                continue

            # Push along admissible edges (residual > 0, height drop of one).
            for v in range(n):
                if self.excess[u] <= 1e-9:
                    break
                res = capacity[u][v] - self.flow.get((u, v), 0)
                if res > 1e-9 and self.height[u] == self.height[v] + 1:
                    delta = min(self.excess[u], res)
                    self.flow[(u, v)] += delta
                    self.flow[(v, u)] -= delta
                    self.excess[u] -= delta
                    self.excess[v] += delta
                    is_inner = v != source and v != sink
                    if is_inner and self.excess[v] > 1e-9 and v not in active:
                        active.append(v)

            # Relabel if excess remains: lift u just above its lowest
            # residual neighbour and queue it again.
            if self.excess[u] > 1e-9:
                neighbour_heights = [
                    self.height[v]
                    for v in range(n)
                    if capacity[u][v] - self.flow.get((u, v), 0) > 1e-9
                ]
                if neighbour_heights:
                    self.height[u] = min(neighbour_heights) + 1
                    active.append(u)

        return max(self.excess.get(sink, 0), 0)


# Alias so callers can refer to the solver under the name DinicSolver.
DinicSolver = IncrementalPushRelabel


class BottleneckOptimizer:
    """Optimizer wrapper that gently raises the LR of bottleneck parameter groups.

    On every update each named, non-BN group's learning rate first drifts
    towards ``base_lr``; groups that touch a min-cut edge are then boosted
    relative to their current rate, capped at ``base_lr * max_boost``.
    Learning rates are never reset outright.
    """

    def __init__(self, optimizer: Optimizer, base_lr: float = 0.001,
                 base_boost: float = 1.3, max_boost: float = 3.0, decay_factor: float = 0.95):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.base_boost = base_boost
        self.max_boost = max_boost
        self.decay_factor = decay_factor
        # Map group name -> index in optimizer.param_groups (named groups only).
        self.name_to_idx = {}
        for idx, group in enumerate(optimizer.param_groups):
            if 'name' in group:
                self.name_to_idx[group['name']] = idx
        self.bottleneck_history: List[List[str]] = []
        self.boost_counts: Dict[str, int] = defaultdict(int)
        self.debug = True
        self.step_count = 0

    def _node_to_param_group(self, node_name: str) -> Optional[str]:
        """Resolve a graph node name to a group name (exact match, then substring)."""
        if node_name in self.name_to_idx:
            return node_name
        return next((name for name in self.name_to_idx if node_name in name), None)

    def update_learning_rates(self, S_set: Set[int], T_set: Set[int],
                              cut_edges: List[Tuple[int, int]],
                              capacity_matrix: np.ndarray,
                              graph_builder, flow_deficit: float = 0.0):
        """Decay all learning rates towards ``base_lr`` and boost bottleneck groups.

        Every endpoint of a cut edge other than the source and sink counts as
        a bottleneck node. The cut edges are appended to
        ``bottleneck_history`` as ``"u_name->v_name"`` strings.
        """
        self.step_count += 1
        verbose = self.debug and self.step_count % 10 == 0

        # Debug output every tenth call.
        if verbose:
            total_potential = np.sum(capacity_matrix[graph_builder.source, :])
            print(f"\n[PFN Debug] Step {self.step_count}")
            print(f"  Total Potential Flow: {total_potential:.4f}")
            print(f"  Graph Partition: S={len(S_set)}, T={len(T_set)}")

            # Show only the first five cut edges.
            bottleneck_names = [
                f"{graph_builder.get_node_name(u)}->{graph_builder.get_node_name(v)}"
                for u, v in cut_edges[:5]
            ]
            print(f"  Cut Edges: {bottleneck_names}")

        # Gently pull every named, non-BN group's LR back towards base_lr.
        for group in self.optimizer.param_groups:
            if 'name' not in group or 'bn' in group['name'].lower():
                continue
            group['lr'] = group['lr'] * self.decay_factor + self.base_lr * (1 - self.decay_factor)

        # Collect bottleneck nodes from the cut edges.
        bottleneck_nodes = set()
        for u, v in cut_edges:
            if v != graph_builder.sink:
                bottleneck_nodes.add(v)
            if u != graph_builder.source:
                bottleneck_nodes.add(u)

        # Boost relative to the current (already decayed) learning rate.
        boosted_names = []
        for node_id in bottleneck_nodes:
            param_name = self._node_to_param_group(graph_builder.get_node_name(node_id))
            if not param_name or param_name not in self.name_to_idx:
                continue
            target_group = self.optimizer.param_groups[self.name_to_idx[param_name]]

            # The boost grows with flow_deficit, whose contribution is capped at 0.5.
            boost = self.base_boost * (1.0 + min(flow_deficit, 0.5))
            new_lr = min(target_group['lr'] * boost, self.base_lr * self.max_boost)

            target_group['lr'] = new_lr
            self.boost_counts[param_name] += 1
            boosted_names.append(f"{param_name}({new_lr:.5f})")

        if verbose and boosted_names:
            print(f"  Boosted: {boosted_names[:3]}")

        edge_labels = []
        for u, v in cut_edges:
            edge_labels.append(f"{graph_builder.get_node_name(u)}->{graph_builder.get_node_name(v)}")
        self.bottleneck_history.append(edge_labels)

    def apply_gradient_clipping(self, model, T_set: Set[int], graph_builder):
        """Placeholder; currently does nothing."""
        pass

    def step(self):
        """Delegate to the wrapped optimizer's ``step``."""
        self.optimizer.step()

    def zero_grad(self):
        """Delegate to the wrapped optimizer's ``zero_grad``."""
        self.optimizer.zero_grad()

    def get_statistics(self) -> Dict:
        """Return the recorded cut-edge history and per-group boost counts."""
        return {
            'bottleneck_history': self.bottleneck_history,
            'boost_counts': dict(self.boost_counts),
        }

Overwriting pfn.py


In [11]:
%%writefile run_experiment.py
"""PFN Experiment - Parameter Flow Network for Neural Network Optimization."""

import os, json, argparse, torch, numpy as np
import torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm
from datetime import datetime
from typing import Dict, List, Optional
from models import get_model
from pfn import PFNGraphBuilder, IncrementalPushRelabel, BottleneckOptimizer


# ==================== CONFIG ====================
# Per-dataset experiment grid. Each dataset entry lists epoch counts
# ('epochs_list'), batch sizes ('batches_list') and scenario dicts.
# NOTE(review): the code that consumes these keys is not visible here;
# presumably 'scenario' and extras such as 'bottleneck_width' / 'num_layers'
# are forwarded to get_model, and 'pixel_noise' / 'label_noise' configure
# NoisyDataset - confirm against the runner.
CONFIG = {
    'mnist': {
        'epochs_list': [30],
        'batches_list': [32],
        'scenarios': [
            {'name': '0.Original', 'scenario': 'standard'},
            {'name': '1.Bottleneck', 'scenario': 'bottleneck', 'bottleneck_width': 4},
            {'name': '2.Deep', 'scenario': 'deep', 'num_layers': 15, 'lr': 0.0005},
            {'name': '3.Noisy', 'scenario': 'standard', 'pixel_noise': 0.3, 'label_noise': 0.15, 'samples': 100},
        ]
    },
    'cifar10': {
        'epochs_list': [50],
        'batches_list': [64,128,256,512],
        'scenarios': [
            {'name': '0.Original', 'scenario': 'standard'},
            {'name': '1.Bottleneck', 'scenario': 'bottleneck', 'bottleneck_width': 2},
            {'name': '2.Deep', 'scenario': 'deep', 'num_layers': 15, 'lr': 0.0005},
            {'name': '3.Noisy', 'scenario': 'standard', 'pixel_noise': 0.2, 'label_noise': 0.1, 'samples': 200},
        ]
    },
    # Unlike the other datasets, cifar100 has no '0.Original' scenario.
    'cifar100': {
        'epochs_list': [100],
        'batches_list': [1024],
        'scenarios': [
            {'name': '1.Bottleneck', 'scenario': 'bottleneck', 'bottleneck_width': 2},
            {'name': '2.Deep', 'scenario': 'deep', 'num_layers': 15, 'lr': 0.0005},
            {'name': '3.Noisy', 'scenario': 'standard', 'pixel_noise': 0.15, 'label_noise': 0.1, 'samples': 300},
        ]
    },
}
PFN_INTERVAL = 100  # larger interval to reduce interference with Adam
PFN_WARMUP_STEPS = 200  # no PFN intervention during the first 200 steps so Adam can stabilize
# NOTE(review): differs from the 0.001 used in the models' parameter groups - confirm which applies.
BASE_LR = 0.0001


# ==================== BASELINE CACHE ====================
class BaselineCache:
    """On-disk JSON cache for baseline training results.

    Each entry is one file ``<cache_dir>/<dataset>_<scenario>_b<batch>_e<epochs>.json``
    holding the final accuracy, the per-epoch loss and accuracy histories, and
    a timestamp. ``cache_dir`` is read on every call, so it may be reassigned
    after construction.
    """

    # Keys a cache entry must contain to be considered usable.
    _REQUIRED_KEYS = ('acc', 'loss', 'acc_hist')

    def __init__(self, cache_dir: str = './results-baseline'):
        """Remember ``cache_dir`` and create it if it does not exist yet."""
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def _get_cache_key(self, dataset: str, scenario: str, batch_size: int, epochs: int) -> str:
        """Build the cache key (also the file stem) for one configuration."""
        return f"{dataset}_{scenario}_b{batch_size}_e{epochs}"

    def _cache_path(self, dataset: str, scenario: str, batch_size: int, epochs: int) -> str:
        """Return the path of the JSON file backing one configuration."""
        key = self._get_cache_key(dataset, scenario, batch_size, epochs)
        return os.path.join(self.cache_dir, f"{key}.json")

    def get(self, dataset: str, scenario: str, batch_size: int, epochs: int) -> Optional[Dict]:
        """Return the cached baseline result, or ``None`` on a miss.

        A missing, unreadable, malformed, or incomplete cache file is treated
        as a miss so the caller simply retrains the baseline.
        """
        cache_file = self._cache_path(dataset, scenario, batch_size, epochs)

        # EAFP: open directly instead of exists()+open(); only I/O and parse
        # errors count as a miss (KeyboardInterrupt etc. still propagate).
        try:
            with open(cache_file, 'r') as f:
                data = json.load(f)
        except (OSError, ValueError):
            return None

        # Reject entries the caller could not use (it indexes these keys).
        if not isinstance(data, dict) or any(k not in data for k in self._REQUIRED_KEYS):
            return None
        return data

    def save(self, dataset: str, scenario: str, batch_size: int, epochs: int,
             acc: float, loss_hist: List[float], acc_hist: List[float]):
        """Write a baseline result to the cache, replacing any existing entry."""
        cache_file = self._cache_path(dataset, scenario, batch_size, epochs)

        data = {
            'acc': float(acc),
            'loss': loss_hist,
            'acc_hist': acc_hist,
            'timestamp': datetime.now().isoformat()
        }

        with open(cache_file, 'w') as f:
            json.dump(data, f)


# Module-wide cache instance; used by run_scenario and reconfigured by the CLI entry point.
baseline_cache = BaselineCache()


# ==================== DATA ====================
class NoisyDataset(Dataset):
    """Dataset wrapper that adds pixel noise and random label flips.

    Label flips are drawn once at construction from a generator seeded with
    42, so the corrupted labels are the same on every run. Pixel noise is
    Gaussian, scaled by ``pixel_noise``, and re-sampled on each item access.

    Args:
        base_dataset: indexable dataset yielding ``(image, label)`` pairs.
        pixel_noise: scale of the Gaussian noise added to each image (0 disables).
        label_noise: probability of replacing a label with a different class.
        num_classes: number of classes; flipped labels come from ``range(num_classes)``.
    """

    def __init__(self, base_dataset, pixel_noise=0.0, label_noise=0.0, num_classes=10):
        self.base = base_dataset
        self.pixel_noise = pixel_noise
        self.num_classes = num_classes

        # Private generator: RandomState(42) yields the same stream as the
        # legacy np.random.seed(42) did, but without resetting the global
        # NumPy RNG for the rest of the program.
        rng = np.random.RandomState(42)

        self.noisy_labels = []
        for i in range(len(base_dataset)):
            _, label = base_dataset[i]
            label = label.item() if isinstance(label, torch.Tensor) else label
            if label_noise > 0 and rng.random_sample() < label_noise:
                candidates = [l for l in range(num_classes) if l != label]
                # int(): keep every stored label a plain Python int instead of
                # mixing in np.int64 for the flipped ones.
                label = int(rng.choice(candidates))
            self.noisy_labels.append(label)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.base)

    def __getitem__(self, idx):
        """Return ``(image, noisy_label)``; the base dataset's own label is ignored."""
        img, _ = self.base[idx]
        label = self.noisy_labels[idx]
        if self.pixel_noise > 0:
            img = img + torch.randn_like(img) * self.pixel_noise
        return img, label


class SmallDataset(Dataset):
    """Dataset wrapper that keeps at most ``samples_per_class`` items per class.

    The first ``samples_per_class`` occurrences of each label (in base-dataset
    order) are kept. The resulting index list is grouped by class, in
    ascending class order.

    Args:
        base_dataset: indexable dataset yielding ``(image, label)`` pairs.
        samples_per_class: maximum number of samples kept for each class.
        num_classes: number of classes; labels must lie in ``range(num_classes)``.
    """

    def __init__(self, base_dataset, samples_per_class, num_classes=10):
        self.base = base_dataset
        indices_per_class = {c: [] for c in range(num_classes)}

        # Count of classes still below quota. Once it reaches zero no further
        # sample can be selected, so stop scanning: every base_dataset[idx]
        # access loads and transforms a sample, which is the expensive part.
        unfilled = num_classes if samples_per_class > 0 else 0

        for idx in range(len(base_dataset)):
            if unfilled == 0:
                break
            _, label = base_dataset[idx]
            label = label.item() if isinstance(label, torch.Tensor) else label
            bucket = indices_per_class[label]
            if len(bucket) < samples_per_class:
                bucket.append(idx)
                if len(bucket) >= samples_per_class:
                    unfilled -= 1

        self.indices = [i for c in range(num_classes) for i in indices_per_class[c]]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the base-dataset item behind the ``idx``-th selected sample."""
        return self.base[self.indices[idx]]


def get_loaders(dataset_name, batch_size, pixel_noise=0, label_noise=0, samples_per_class=None):
    """Build the train and test DataLoaders for one dataset.

    Args:
        dataset_name: 'mnist', 'cifar100', or anything else (treated as CIFAR-10).
        batch_size: training batch size; the test loader uses twice this value.
        pixel_noise: if > 0, the training set is wrapped in NoisyDataset.
        label_noise: if > 0, the training set is wrapped in NoisyDataset.
        samples_per_class: if truthy, the training set is first limited with
            SmallDataset.

    Returns:
        ``(train_loader, test_loader)``.
    """
    num_classes = 10 if dataset_name != 'cifar100' else 100

    if dataset_name == 'mnist':
        mnist_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        train_ds = datasets.MNIST('./data', train=True, download=True, transform=mnist_tf)
        test_ds = datasets.MNIST('./data', train=False, transform=mnist_tf)
    else:
        # Both CIFAR variants share the same normalization statistics here.
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        train_tf = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        if dataset_name != 'cifar100':
            # CIFAR-10 path
            cifar_cls, root, download = datasets.CIFAR10, './data', True
        elif os.path.exists('/kaggle/input/cifar100'):
            # Kaggle adaptation: expose the read-only input folder under the
            # 'cifar-100-python' directory name that torchvision looks for.
            tmp_root = './tmp_cifar'
            os.makedirs(tmp_root, exist_ok=True)
            target_link = os.path.join(tmp_root, 'cifar-100-python')
            if not os.path.exists(target_link):
                os.symlink('/kaggle/input/cifar100', target_link)
            cifar_cls, root, download = datasets.CIFAR100, tmp_root, False
        else:
            cifar_cls, root, download = datasets.CIFAR100, './data', True

        train_ds = cifar_cls(root=root, train=True, download=download, transform=train_tf)
        test_ds = cifar_cls(root=root, train=False, download=download, transform=test_tf)

    # Optional training-set modifications: subsample first, then add noise.
    if samples_per_class:
        train_ds = SmallDataset(train_ds, samples_per_class, num_classes)
    needs_noise = pixel_noise > 0 or label_noise > 0
    if needs_noise:
        train_ds = NoisyDataset(train_ds, pixel_noise, label_noise, num_classes)

    train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size * 2, shuffle=False, num_workers=2)
    return train_loader, test_loader


# ==================== DEVICE ====================
def get_device(force_cpu=False):
    """Pick the torch device: CUDA first, then MPS, else CPU.

    Args:
        force_cpu: when truthy, skip accelerator detection and return the CPU.
    """
    if not force_cpu:
        if torch.cuda.is_available():
            return torch.device('cuda')
        # Older torch builds have no `mps` backend attribute at all.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device('mps')
    return torch.device('cpu')


# ==================== TRAINING ====================
def train_epoch(model, loader, opt, criterion, device, pfn=None, step=0, epoch=0, total_epochs=1):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: network to train; with PFN it must provide
            ``get_all_block_gradients()``.
        loader: iterable of ``(inputs, targets)`` batches.
        opt: plain optimizer, used when ``pfn`` is falsy.
        criterion: loss function called as ``criterion(model(x), y)``.
        device: device the batches are moved to.
        pfn: optional dict with keys 'gb' (graph builder), 'solver' and 'opt'
            (the optimizer wrapper that is stepped instead of ``opt``).
        step: global step counter carried across epochs.
        epoch: current epoch index, handed to the graph builder.
        total_epochs: total epoch count, handed to the graph builder.

    Returns:
        ``(mean batch loss, updated step counter)``.
    """
    model.train()
    stepper = pfn['opt'] if pfn else opt
    running_loss = 0

    for x, y in tqdm(loader, desc="    train", leave=False):
        x = x.to(device)
        y = y.to(device)

        stepper.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()

        # PFN intervenes only after warm-up, and only every PFN_INTERVAL steps.
        pfn_due = pfn and step > PFN_WARMUP_STEPS and step % PFN_INTERVAL == 0
        if pfn_due:
            builder = pfn['gb']
            # Gradients are detached and copied to the CPU before graph building.
            grads = {
                name: [g.detach().cpu() for g in tensors]
                for name, tensors in model.get_all_block_gradients().items()
            }
            builder.current_epoch = epoch
            builder.total_epochs = total_epochs

            cap, meta = builder.build_graph(grads)
            max_flow, cuts, S, T = pfn['solver'].find_min_cut(cap, builder.source, builder.sink)

            # Relative gap between cut capacity and achieved flow; with no cut
            # edges the capacity falls back to 1.0.
            if cuts:
                total_cap = sum(cap[u][v] for u, v in cuts)
            else:
                total_cap = 1.0
            flow_deficit = (total_cap - max_flow) / (total_cap + 1e-9)
            stepper.update_learning_rates(S, T, cuts, cap, builder, flow_deficit)

        stepper.step()
        running_loss += loss.item()
        step += 1

    return running_loss / len(loader), step


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a fraction.

    Puts the model in eval mode and runs without gradient tracking. The
    predicted class is the argmax over dimension 1 of the model output.
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(1)
            n_hit += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return n_hit / n_seen


def train_model(config: dict, device, use_pfn=False):
    """Train one model from scratch and evaluate it after every epoch.

    Args:
        config: run settings. Required keys: 'dataset', 'batch_size',
            'scenario', 'epochs'. Optional: 'lr', 'pixel_noise',
            'label_noise', 'samples', 'bottleneck_width', 'num_layers'.
        device: torch device the model and batches are placed on.
        use_pfn: when true, Adam is built from the model's parameter groups
            and wrapped in a BottleneckOptimizer driven by the PFN graph
            builder and solver; otherwise plain Adam is used.

    Returns:
        ``(final test accuracy, per-epoch loss list, per-epoch accuracy list)``.
    """
    lr = config.get('lr', BASE_LR)

    train_loader, test_loader = get_loaders(
        config['dataset'],
        config['batch_size'],
        config.get('pixel_noise', 0),
        config.get('label_noise', 0),
        config.get('samples'),
    )

    model = get_model(
        config['scenario'],
        config['dataset'],
        bottleneck_width=config.get('bottleneck_width', 8),
        num_layers=config.get('num_layers', 10),
    )
    model = model.to(device)

    pfn = None
    if not use_pfn:
        opt = optim.Adam(model.parameters(), lr=lr)
    else:
        groups = model.get_parameter_groups()
        for group in groups:
            group['lr'] = lr
        opt = optim.Adam(groups, lr=lr)
        pfn = {
            'gb': PFNGraphBuilder(history_size=5),
            'solver': IncrementalPushRelabel(),
            'opt': BottleneckOptimizer(opt, lr, base_boost=1.3, max_boost=3.0, decay_factor=0.95)
        }
        # PFN debug output is enabled only for short runs (<= 15 epochs).
        verbose = config['epochs'] <= 15
        pfn['gb'].debug = verbose
        pfn['opt'].debug = verbose

    n_epochs = config['epochs']
    # The cosine schedule is attached to the underlying Adam optimizer.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
    criterion = nn.CrossEntropyLoss()

    step = 0
    loss_hist = []
    acc_hist = []
    for ep in range(n_epochs):
        loss, step = train_epoch(model, train_loader, opt, criterion, device, pfn, step, ep, n_epochs)
        scheduler.step()
        acc = evaluate(model, test_loader, device)
        loss_hist.append(loss)
        acc_hist.append(acc)
        print(f"    [{ep+1:2d}/{config['epochs']}] loss={loss:.4f} acc={acc:.4f}")

    return acc, loss_hist, acc_hist


def _baseline_cache_tag(config):
    """Return a cache tag covering every setting that changes the baseline run."""
    return (
        f"{config['scenario']}"
        f"_lr{config['lr']}_bw{config['bottleneck_width']}_nl{config['num_layers']}"
        f"_pn{config['pixel_noise']}_ln{config['label_noise']}_s{config['samples']}"
    )


def run_scenario(dataset, cfg, epochs, batch, device, results_dir):
    """Run one scenario: baseline (cached when possible) versus PFN training.

    Args:
        dataset: dataset name forwarded to training.
        cfg: scenario dict from CONFIG; 'name' and 'scenario' are required,
            every other key overrides a default.
        epochs: default epoch count (overridden by ``cfg['epochs']``).
        batch: default batch size (overridden by ``cfg['batch_size']``).
        device: torch device to train on.
        results_dir: parent folder; outputs go to ``<results_dir>/<name>/``.

    Returns:
        Dict with 'name', 'baseline' and 'pfn' final accuracies, and
        'improvement' (pfn minus baseline).

    Side effects:
        Writes ``log.json`` and, when plotting works, ``curves.png`` into the
        scenario folder; may write a baseline cache entry.
    """
    name = cfg['name']
    os.makedirs(os.path.join(results_dir, name), exist_ok=True)

    config = {
        'dataset': dataset, 'scenario': cfg['scenario'],
        'epochs': cfg.get('epochs', epochs), 'batch_size': cfg.get('batch_size', batch),
        'lr': cfg.get('lr', BASE_LR), 'bottleneck_width': cfg.get('bottleneck_width', 8),
        'num_layers': cfg.get('num_layers', 10), 'pixel_noise': cfg.get('pixel_noise', 0),
        'label_noise': cfg.get('label_noise', 0), 'samples': cfg.get('samples')
    }

    # The cache key must distinguish scenarios that share a model type but
    # differ in data or hyper-parameters (e.g. '0.Original' and '3.Noisy' are
    # both 'standard'); otherwise one would reuse the other's baseline.
    # Entries written under the old scenario-only key are simply not found.
    cache_tag = _baseline_cache_tag(config)

    print(f"  [Baseline]", end=' ')
    cached = baseline_cache.get(dataset, cache_tag, config['batch_size'], config['epochs'])

    if cached:
        print("(缓存)")
        base_acc = cached['acc']
        base_loss = cached['loss']
        base_acc_hist = cached['acc_hist']
    else:
        print("(训练)")
        base_acc, base_loss, base_acc_hist = train_model(config, device, False)
        baseline_cache.save(dataset, cache_tag, config['batch_size'], config['epochs'],
                            base_acc, base_loss, base_acc_hist)

    print(f"  [PFN]")
    pfn_acc, pfn_loss, pfn_acc_hist = train_model(config, device, True)

    with open(os.path.join(results_dir, name, 'log.json'), 'w') as f:
        json.dump({'baseline': {'loss': base_loss, 'acc': base_acc_hist}, 'pfn': {'loss': pfn_loss, 'acc': pfn_acc_hist}}, f)

    # Plotting is best-effort: a missing matplotlib or a drawing error must
    # not lose the training results, but the reason is reported, not hidden.
    try:
        import matplotlib.pyplot as plt
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
        eps = range(1, config['epochs'] + 1)
        ax1.plot(eps, base_loss, 'o-', label='Baseline')
        ax1.plot(eps, pfn_loss, 's-', label='PFN')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(alpha=0.3)
        ax2.plot(eps, base_acc_hist, 'o-', label='Baseline')
        ax2.plot(eps, pfn_acc_hist, 's-', label='PFN')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(results_dir, name, 'curves.png'), dpi=150)
        plt.close()
    except Exception as exc:
        print(f"  [warn] curves.png not saved: {exc}")

    return {'name': name, 'baseline': base_acc, 'pfn': pfn_acc, 'improvement': pfn_acc - base_acc}


def run_experiment(dataset='mnist', force_cpu=False, scenarios=None,
                    epochs_override=None, batch_override=None):
    """Run every selected scenario of one dataset and write a summary.

    Args:
        dataset: key into CONFIG.
        force_cpu: forwarded to get_device.
        scenarios: optional list of keywords; a scenario is kept when any
            keyword occurs in its name, ignoring case (so 'noisy', 'Noisy'
            and '3.Noisy' all select '3.Noisy').
        epochs_override: epoch count; defaults to the first 'epochs_list' entry.
        batch_override: batch size; defaults to the first 'batches_list' entry.

    Side effects:
        Creates ``./results_<timestamp>/<dataset>_e<epochs>_b<batch>/`` and
        writes ``summary.json`` (plus per-scenario outputs) into it.
    """
    device = get_device(force_cpu)
    cfg = CONFIG[dataset]
    epochs = epochs_override if epochs_override is not None else cfg['epochs_list'][0]
    batch = batch_override if batch_override is not None else cfg['batches_list'][0]

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f'./results_{timestamp}/{dataset}_e{epochs}_b{batch}'
    os.makedirs(results_dir, exist_ok=True)

    print("=" * 60)
    print(f"  PFN EXPERIMENT - {dataset.upper()}")
    print(f"  Device: {device} | Epochs: {epochs} | Batch: {batch}")
    print("=" * 60)

    scenario_configs = cfg['scenarios']
    if scenarios:
        # Lowercase both sides: the name was already lowercased, so a keyword
        # with any capital letter could never match. This also subsumes the
        # exact-name match.
        wanted = [kw.lower() for kw in scenarios]
        scenario_configs = [
            s for s in scenario_configs
            if any(kw in s['name'].lower() for kw in wanted)
        ]

    results = []
    for i, scenario_cfg in enumerate(scenario_configs):
        print(f"\n[{i+1}/{len(scenario_configs)}] {scenario_cfg['name']}")
        print("-" * 40)
        results.append(run_scenario(dataset, scenario_cfg, epochs, batch, device, results_dir))
        print(f"  >> Base={results[-1]['baseline']:.4f} | PFN={results[-1]['pfn']:.4f} | Δ={results[-1]['improvement']:+.4f}")

    print("\n" + "=" * 60)
    wins = sum(1 for r in results if r['improvement'] > 0)
    # Guard the average against an empty selection.
    avg = sum(r['improvement'] for r in results) / len(results) if results else 0
    for r in results:
        print(f"{r['name']:<20} {r['baseline']:.4f}  {r['pfn']:.4f}  {r['improvement']:+.4f} {'✓' if r['improvement'] > 0 else '✗'}")
    print("-" * 60)
    print(f"AVG: {avg:+.4f}  ({wins}/{len(results)} wins)")

    with open(os.path.join(results_dir, 'summary.json'), 'w') as f:
        json.dump({'results': results, 'avg': avg, 'wins': wins}, f, indent=2)

    print(f"\nSaved: {results_dir}/")


if __name__ == '__main__':
    # Command-line entry point: runs run_experiment for every requested
    # dataset x epochs x batch-size combination.
    parser = argparse.ArgumentParser(description='PFN Experiment Runner')
    parser.add_argument('--dataset', default=None, choices=['mnist', 'cifar10', 'cifar100'],
                        help='Dataset to use')
    # Validate here so an unknown name is a usage error, not a KeyError later.
    parser.add_argument('--datasets', nargs='+', default=None, choices=list(CONFIG),
                        help='Multiple datasets')
    parser.add_argument('--cpu', action='store_true', help='Force CPU')
    parser.add_argument('--scenarios', nargs='+', default=None, help='Specific scenarios')
    parser.add_argument('--no-cache', action='store_true', help='Ignore cache')
    parser.add_argument('--epochs', type=int, default=None, help='Override epochs')
    parser.add_argument('--batch', type=int, default=None, help='Override batch size')
    parser.add_argument('--epochs-list', nargs='+', type=int, default=None)
    parser.add_argument('--batches-list', nargs='+', type=int, default=None)

    args = parser.parse_args()

    if args.no_cache:
        # Write fresh baselines to a separate folder so the main cache is
        # left untouched ...
        baseline_cache.cache_dir = './results-baseline-nocache'
        os.makedirs(baseline_cache.cache_dir, exist_ok=True)
        # ... and make every lookup miss. Switching folders alone would still
        # reuse entries left behind by an earlier --no-cache run.
        baseline_cache.get = lambda *_args, **_kwargs: None

    # --datasets wins over --dataset; with neither, run every configured dataset.
    if args.datasets:
        targets = args.datasets
    elif args.dataset:
        targets = [args.dataset]
    else:
        targets = list(CONFIG.keys())

    for ds in targets:
        cfg = CONFIG[ds]
        # Precedence: explicit list > single override > CONFIG default.
        if args.epochs_list is not None:
            epochs_list = args.epochs_list
        elif args.epochs is not None:
            epochs_list = [args.epochs]
        else:
            epochs_list = cfg['epochs_list']

        if args.batches_list is not None:
            batches_list = args.batches_list
        elif args.batch is not None:
            batches_list = [args.batch]
        else:
            batches_list = cfg['batches_list']

        for ep in epochs_list:
            for bs in batches_list:
                run_experiment(ds, args.cpu, args.scenarios, epochs_override=ep, batch_override=bs)

Overwriting run_experiment.py


In [None]:
!python run_experiment.py --dataset cifar100

  PFN EXPERIMENT - CIFAR100
  Device: cuda | Epochs: 100 | Batch: 1024

[1/3] 1.Bottleneck
----------------------------------------
  [Baseline] (训练)
    [ 1/100] loss=4.5802 acc=0.0221                                             
    [ 2/100] loss=4.4725 acc=0.0270                                             
    [ 3/100] loss=4.4033 acc=0.0309                                             
    [ 4/100] loss=4.3474 acc=0.0410                                             
    [ 5/100] loss=4.2975 acc=0.0465                                             
    [ 6/100] loss=4.2515 acc=0.0452                                             
    [ 7/100] loss=4.2090 acc=0.0520                                             
    [ 8/100] loss=4.1717 acc=0.0571                                             
    [ 9/100] loss=4.1417 acc=0.0638                                             
    [10/100] loss=4.1147 acc=0.0640                                             
    [11/100] loss=4.0909 acc=0.0700     