# Exp K: Learning Dynamics - Critical Slowing Down

## 目的
臨界点近傍での学習ダイナミクスを観測し、「critical slowing down」の有無を確認する。
相転移的挙動の時間発展的証拠を得る。

## 実験設計
- **ノイズ率**: η = 0.4, 0.8
- **λ**: 
  - Ordered側: 0.30, 0.35
  - 臨界近傍: 0.43, 0.45, 0.47
  - Collapse側: 0.55
- **Seeds**: 0, 1, 2, 3, 4（揺らぎを見るため多め）

## 観測量
- 学習曲線（毎エポック記録）
- 収束時間（テストエラーが初めて 0.20 を下回るエポック。到達しなかった run は「未収束」として記録）
- 揺らぎの時間発展

## Runs計算
6 λ × 2 η × 5 seeds = **60 runs**

## 期待される結果
- 臨界点近傍で収束が遅くなる（critical slowing down）
- 臨界点近傍で揺らぎが大きくなる

In [None]:
# ===== セットアップ =====
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

EXP_NAME = 'exp_K_learning_dynamics'
TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
BASE_DIR = '/content/drive/MyDrive/dual-gradient-learning/Paper-A'

# The default run directory embeds TIMESTAMP, so re-running this cell (for
# example after a Colab disconnect) creates a brand-new, empty directory and
# the checkpoint.json written by the main loop is never found again.
# To continue an interrupted experiment, set RESUME_DIR to the existing run
# directory before executing this cell; leave it as None for a fresh run.
RESUME_DIR = None
SAVE_DIR = RESUME_DIR if RESUME_DIR else f'{BASE_DIR}/{EXP_NAME}_{TIMESTAMP}'

# Creating the nested figures/ directory also creates SAVE_DIR itself.
os.makedirs(f'{SAVE_DIR}/figures', exist_ok=True)

print(f'Experiment: {EXP_NAME}')
print(f'Timestamp: {TIMESTAMP}')
print(f'Save directory: {SAVE_DIR}')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
import numpy as np
import json
import time

# Pick the compute device once; every later cell reads the global `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
    # Let cuDNN auto-tune convolution algorithms for the input shapes it sees.
    torch.backends.cudnn.benchmark = True
    print(f'GPU: {torch.cuda.get_device_name(0)}')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
# ===== モデル定義 =====
class IndexedDataset(Dataset):
    """Wrap a dataset so that every sample also reports its own index.

    The wrapped dataset must return ``(img, label)`` pairs; this wrapper
    returns ``(img, label, idx)`` so callers can look up per-sample data
    (e.g. replacement labels) by index.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return sample, target, idx

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers.

    The skip connection is the identity unless the block changes the spatial
    resolution (``stride != 1``) or the channel count, in which case a
    1x1 conv + batch-norm projection is used instead.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            # An empty Sequential acts as the identity mapping.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)

class ResNet18(nn.Module):
    """ResNet-18 built from ``BasicBlock``s with a 3x3, stride-1 stem.

    Layout: stem conv + batch-norm, four stages of two blocks each with
    64/128/256/512 channels (stages 2-4 downsample with stride 2), global
    average pooling, and a linear classifier over ``num_classes`` outputs.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Channel count entering the next block; advanced by _make_layer.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 2, 1)
        self.layer2 = self._make_layer(128, 2, 2)
        self.layer3 = self._make_layer(256, 2, 2)
        self.layer4 = self._make_layer(512, 2, 2)
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(BasicBlock(self.in_planes, planes, block_stride))
            self.in_planes = planes
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = F.adaptive_avg_pool2d(out, 1)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

In [None]:
# ===== 実験パラメータ =====
BATCH_SIZE = 256
NUM_WORKERS = 4
EPOCHS = 100
LR = 0.1
K = 16

# Lambda grid, each value tagged with its regime: ordered / critical / collapse
LAMBDA_REGIONS = {
    0.30: 'ordered',
    0.35: 'ordered',
    0.43: 'critical',
    0.45: 'critical',
    0.47: 'critical',
    0.55: 'collapse',
}
LAMBDAS = list(LAMBDA_REGIONS)

NOISE_RATES = [0.4, 0.8]
SEEDS = [0, 1, 2, 3, 4]  # several seeds so run-to-run fluctuations are visible

# One entry per run, ordered by noise rate, then lambda, then seed
experiments = [
    {
        'lambda': lam,
        'noise_rate': noise_rate,
        'seed': seed,
        'region': LAMBDA_REGIONS[lam],
    }
    for noise_rate in NOISE_RATES
    for lam in LAMBDAS
    for seed in SEEDS
]

total_runs = len(experiments)
print(f'Total runs: {total_runs}')
# The estimate uses a fixed 9.5 minutes per run.
print(f'Estimated time: {total_runs * 9.5 / 60:.1f} hours')

# Persist the configuration next to the results
parameters = {
    'lambdas': LAMBDAS,
    'lambda_regions': LAMBDA_REGIONS,
    'noise_rates': NOISE_RATES,
    'seeds': SEEDS,
    'epochs': EPOCHS,
    'K': K,
}
config = {
    'experiment': EXP_NAME,
    'timestamp': TIMESTAMP,
    'parameters': parameters,
    'total_runs': total_runs,
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)
print('Config saved')

In [None]:
# ===== ユーティリティ関数 =====
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: `random` is not among the notebook's top-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def load_cifar10():
    """Download (if needed) and return the CIFAR-10 ``(trainset, testset)`` pair.

    The training set is augmented with a padded random crop and a random
    horizontal flip; both splits are converted to tensors and normalized with
    the same per-channel mean/std constants.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_pipeline = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    data_root = './data'
    trainset = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=train_pipeline)
    testset = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=test_pipeline)
    return trainset, testset

def get_data_loaders(trainset, testset):
    """Build the training and test ``DataLoader``s.

    The training set is wrapped in ``IndexedDataset`` so each batch also
    carries the dataset indices of its samples; it is shuffled and drops the
    last incomplete batch. The test loader keeps dataset order and all samples.

    Args:
        trainset: Dataset yielding ``(img, label)`` pairs for training.
        testset: Dataset yielding ``(img, label)`` pairs for evaluation.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # DataLoader rejects persistent_workers=True when num_workers == 0, so
    # only keep workers alive when there are worker processes to keep.
    shared_kwargs = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=NUM_WORKERS > 0,
    )
    train_loader = DataLoader(IndexedDataset(trainset), shuffle=True,
                              drop_last=True, **shared_kwargs)
    test_loader = DataLoader(testset, shuffle=False, **shared_kwargs)
    return train_loader, test_loader

def inject_noise(labels, noise_rate, seed, num_classes=10):
    """Return a copy of ``labels`` with a fixed fraction flipped to a wrong class.

    ``int(noise_rate * len(labels))`` distinct positions are chosen without
    replacement, and each is reassigned uniformly among the other classes, so
    every selected label is guaranteed to change. ``labels`` is not modified.

    A private ``RandomState`` seeded with ``seed + 1000`` is used, so the
    global NumPy generator is left untouched. The draws are the same as
    seeding the global generator with ``seed + 1000`` and calling
    ``np.random.choice``.

    Args:
        labels: Integer class labels; must support ``.copy()`` and indexing.
        noise_rate: Fraction of labels to corrupt.
        seed: Run seed; the noise pattern is a deterministic function of it.
        num_classes: Number of classes to draw replacement labels from.

    Returns:
        A copy of ``labels`` with the selected entries replaced.
    """
    rng = np.random.RandomState(seed + 1000)
    noisy_labels = labels.copy()
    n_noisy = int(noise_rate * len(labels))
    noisy_indices = rng.choice(len(labels), n_noisy, replace=False)
    for idx in noisy_indices:
        # Exclude the current class so the label really changes.
        choices = [c for c in range(num_classes) if c != labels[idx]]
        noisy_labels[idx] = rng.choice(choices)
    return noisy_labels

def evaluate(model, test_loader):
    """Return the top-1 accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and left in it. Batches are moved to
    the global ``device``. The result is ``correct / total`` as a float.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            # max(1) returns (values, indices); the indices are the predictions.
            predictions = model(images).max(1)[1]
            n_correct += (predictions == labels).sum().item()
            n_seen += labels.size(0)
    return n_correct / n_seen

In [None]:
# ===== 学習関数（高解像度ログ版） =====
def train_with_dynamics_logging(model, train_loader, test_loader, clean_labels, noisy_labels, lam):
    """Train with dual-gradient learning and record the dynamics every epoch.

    Each step computes a "structure" gradient from the noisy labels and mixes
    it with a "value" gradient computed from the clean labels. The value
    gradient is recomputed every ``K`` steps and reused in between. Both
    gradients are scaled to unit norm and combined as
    ``(1 - lam) * g_struct + lam * g_value`` before the SGD step.

    Uses the globals ``LR``, ``EPOCHS``, ``K`` and ``device``.

    Args:
        model: Network to train in place.
        train_loader: Yields ``(inputs, labels, indices)``; the loader's own
            labels are ignored and looked up by index instead.
        test_loader: Loader passed to ``evaluate`` after every epoch.
        clean_labels: Per-sample clean labels, indexed by dataset index.
        noisy_labels: Per-sample noisy labels, indexed by dataset index.
        lam: Weight of the value gradient in the mix.

    Returns:
        ``(final_acc, best_acc, dynamics)``. ``dynamics`` maps ``'epoch'``,
        ``'test_acc'``, ``'test_error'``, ``'train_loss'``, ``'avg_cos'`` and
        ``'avg_gmix_norm'`` to per-epoch lists.
    """
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    # Keep both label sets on the device so batches can be gathered by index.
    clean_labels_tensor = torch.tensor(clean_labels, device=device)
    noisy_labels_tensor = torch.tensor(noisy_labels, device=device)

    cached_value_grad = None
    global_step = 0

    # Per-epoch record
    dynamics = {
        'epoch': [],
        'test_acc': [],
        'test_error': [],
        'train_loss': [],
        'avg_cos': [],
        'avg_gmix_norm': []
    }

    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0
        epoch_cos = []
        epoch_norm = []
        n_batches = 0

        for inputs, _, indices in train_loader:
            inputs = inputs.to(device, non_blocking=True)
            indices = indices.to(device, non_blocking=True)
            batch_noisy = noisy_labels_tensor[indices]
            batch_clean = clean_labels_tensor[indices]

            # Structure gradient (noisy labels).
            # No retain_graph: the value gradient below runs its own forward
            # pass, so this graph is never backpropagated through again.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss_struct = criterion(outputs, batch_noisy)
            loss_struct.backward()
            g_struct = parameters_to_vector([p.grad for p in model.parameters()]).clone()

            epoch_loss += loss_struct.item()
            n_batches += 1

            # Value gradient (clean labels), refreshed every K steps
            if global_step % K == 0 or cached_value_grad is None:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss_value = criterion(outputs, batch_clean)
                loss_value.backward()
                cached_value_grad = parameters_to_vector([p.grad for p in model.parameters()]).clone()

            # Cosine similarity between the two unit-norm gradients
            g_struct_norm = g_struct / (g_struct.norm() + 1e-12)
            g_value_norm = cached_value_grad / (cached_value_grad.norm() + 1e-12)
            cos_sim = (g_struct_norm @ g_value_norm).item()
            epoch_cos.append(cos_sim)

            # Mix
            g_mix = (1 - lam) * g_struct_norm + lam * g_value_norm
            epoch_norm.append(g_mix.norm().item())

            # Scatter the flat mixed gradient back into each parameter's .grad
            optimizer.zero_grad()
            idx = 0
            for p in model.parameters():
                numel = p.numel()
                p.grad = g_mix[idx:idx+numel].view(p.shape).clone()
                idx += numel
            optimizer.step()
            global_step += 1

        scheduler.step()

        # Evaluate and record after every epoch
        test_acc = evaluate(model, test_loader)
        dynamics['epoch'].append(epoch + 1)
        dynamics['test_acc'].append(test_acc)
        dynamics['test_error'].append(1 - test_acc)
        dynamics['train_loss'].append(epoch_loss / n_batches)
        dynamics['avg_cos'].append(np.mean(epoch_cos))
        dynamics['avg_gmix_norm'].append(np.mean(epoch_norm))

    final_acc = dynamics['test_acc'][-1]
    best_acc = max(dynamics['test_acc'])

    return final_acc, best_acc, dynamics

In [None]:
# ===== データ準備 =====
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)
train_loader, test_loader = get_data_loaders(trainset, testset)

print('Data prepared')

# GPU warmup: a few throwaway forward passes before the first timed run.
# Skipped on the CPU fallback, where it only costs time, and run under
# no_grad because nothing is backpropagated through these passes.
if device.type == 'cuda':
    warmup_model = ResNet18().to(device)
    with torch.no_grad():
        for _ in range(20):
            warmup_model(torch.randn(BATCH_SIZE, 3, 32, 32, device=device))
    del warmup_model
    torch.cuda.empty_cache()
print('Warmup complete.')

In [None]:
# ===== メイン実験ループ =====
results = []
checkpoint_file = f'{SAVE_DIR}/checkpoint.json'
completed = set()

# Resume: reload finished runs and remember their (lambda, noise_rate, seed) keys.
if os.path.exists(checkpoint_file):
    with open(checkpoint_file, 'r') as f:
        results = json.load(f)
    for r in results:
        completed.add((r['lambda'], r['noise_rate'], r['seed']))
    print(f'Checkpoint loaded: {len(completed)} runs completed')

run_counter = 0
exp_start = time.time()

for exp in experiments:
    run_counter += 1
    lam = exp['lambda']
    noise_rate = exp['noise_rate']
    seed = exp['seed']
    region = exp['region']

    key = (lam, noise_rate, seed)
    if key in completed:
        continue

    print(f'\n[{run_counter}/{total_runs}] λ={lam:.2f} ({region}) η={noise_rate} seed={seed}')
    t0 = time.time()

    set_seed(seed)
    noisy_labels = inject_noise(clean_labels, noise_rate, seed)
    model = ResNet18().to(device)

    final_acc, best_acc, dynamics = train_with_dynamics_logging(
        model, train_loader, test_loader,
        clean_labels, noisy_labels, lam
    )
    elapsed = time.time() - t0

    # Convergence time: first epoch (1-based) whose test error is below 0.20;
    # None when the run never gets there.
    convergence_epoch = None
    for i, err in enumerate(dynamics['test_error']):
        if err < 0.20:
            convergence_epoch = i + 1
            break

    result = {
        'experiment': EXP_NAME,
        'lambda': lam,
        'region': region,
        'noise_rate': noise_rate,
        'seed': seed,
        'test_acc': final_acc,
        'test_error': 1 - final_acc,
        'best_acc': best_acc,
        'best_error': 1 - best_acc,
        'convergence_epoch': convergence_epoch,
        'time_seconds': elapsed,
        'dynamics': dynamics
    }
    results.append(result)

    # Write to a temporary file and swap it in, so an interruption during the
    # write cannot leave a truncated checkpoint.json behind.
    tmp_checkpoint_file = f'{checkpoint_file}.tmp'
    with open(tmp_checkpoint_file, 'w') as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_checkpoint_file, checkpoint_file)

    conv_str = f'conv@{convergence_epoch}' if convergence_epoch is not None else 'no conv'
    print(f'  Error: {1-final_acc:.4f} | {conv_str} | Time: {elapsed/60:.1f} min')

    # ETA from the mean duration of all recorded runs (including resumed ones)
    completed_count = len(results)
    avg_time = sum([r['time_seconds'] for r in results]) / completed_count
    remaining = total_runs - completed_count
    eta_hours = (remaining * avg_time) / 3600
    print(f'  Progress: {completed_count}/{total_runs} | ETA: {eta_hours:.1f} hours')

print(f'\n{"="*60}')
print(f'EXPERIMENT COMPLETE')
print(f'Total time: {(time.time()-exp_start)/3600:.2f} hours')
print(f'{"="*60}')

In [None]:
# ===== 結果保存 =====
import pandas as pd

# Full results, including the per-epoch dynamics of every run
results_json_path = f'{SAVE_DIR}/results.json'
with open(results_json_path, 'w') as f:
    json.dump(results, f, indent=2)

# Flat table for the CSV: every field except the nested 'dynamics'
results_flat = []
for r in results:
    row = dict(r)
    row.pop('dynamics', None)
    results_flat.append(row)

df = pd.DataFrame(results_flat)
df.to_csv(f'{SAVE_DIR}/results.csv', index=False)

print(f'Results saved to {SAVE_DIR}/')

In [None]:
# ===== Critical Slowing Down 分析 =====
import pandas as pd
import numpy as np

# One row per run, without the nested per-epoch dynamics
summary_rows = []
for r in results:
    row = dict(r)
    row.pop('dynamics', None)
    summary_rows.append(row)
df = pd.DataFrame(summary_rows)

rule = '=' * 70
print(rule)
print('CRITICAL SLOWING DOWN ANALYSIS')
print(rule)

for noise_rate in NOISE_RATES:
    print(f'\n--- Noise Rate: {int(noise_rate*100)}% ---')
    df_n = df[df['noise_rate'] == noise_rate]

    # Per-region statistics
    for region in ['ordered', 'critical', 'collapse']:
        df_r = df_n[df_n['region'] == region]
        if df_r.empty:
            continue

        conv_col = df_r['convergence_epoch']
        conv_epochs = conv_col.dropna()
        no_conv_rate = conv_col.isna().mean()
        err_col = df_r['test_error']

        print(f'\n  {region.upper()}:')
        print(f'    Mean error: {err_col.mean():.4f} ± {err_col.std():.4f}')
        if len(conv_epochs) > 0:
            print(f'    Convergence epoch: {conv_epochs.mean():.1f} ± {conv_epochs.std():.1f}')
        print(f'    No convergence rate: {no_conv_rate*100:.1f}%')

    # Evidence of critical slowing down: mean convergence epoch over the runs
    # that converged, critical region vs. ordered region.
    ordered_conv = df_n.loc[df_n['region'] == 'ordered', 'convergence_epoch'].dropna().mean()
    critical_conv = df_n.loc[df_n['region'] == 'critical', 'convergence_epoch'].dropna().mean()

    both_available = not (np.isnan(ordered_conv) or np.isnan(critical_conv))
    # Flag a slowdown when the critical region needs over 20% more epochs.
    if both_available and critical_conv > ordered_conv * 1.2:
        print(f'\n  ⚡ CRITICAL SLOWING DOWN DETECTED')
        print(f'     Ordered conv: {ordered_conv:.1f} epochs')
        print(f'     Critical conv: {critical_conv:.1f} epochs')
        print(f'     Slowdown factor: {critical_conv/ordered_conv:.2f}x')

In [None]:
# ===== 可視化: 学習曲線 =====
import matplotlib.pyplot as plt

# One row per noise rate, one column per region. The grid is derived from
# NOISE_RATES rather than hard-coded, and squeeze=False keeps `axes` 2-D
# even when there is only a single noise rate.
regions = ['ordered', 'critical', 'collapse']
n_rows, n_cols = len(NOISE_RATES), len(regions)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)

for i, noise_rate in enumerate(NOISE_RATES):
    results_n = [r for r in results if r['noise_rate'] == noise_rate]

    for j, region in enumerate(regions):
        ax = axes[i, j]
        results_r = [r for r in results_n if r['region'] == region]

        if not results_r:
            ax.set_visible(False)
            continue

        # One colour per lambda within the region
        lambdas_in_region = sorted(set(r['lambda'] for r in results_r))
        colors = plt.cm.viridis(np.linspace(0, 1, len(lambdas_in_region)))

        for lam, color in zip(lambdas_in_region, colors):
            results_l = [r for r in results_r if r['lambda'] == lam]

            # Every seed as a thin, faint curve
            for r in results_l:
                ax.plot(r['dynamics']['epoch'], r['dynamics']['test_error'],
                        alpha=0.3, color=color, linewidth=1)

            # Mean over seeds as a thick curve
            epochs = results_l[0]['dynamics']['epoch']
            mean_errors = np.mean([r['dynamics']['test_error'] for r in results_l], axis=0)
            ax.plot(epochs, mean_errors, color=color, linewidth=2, label=f'λ={lam}')

        ax.set_xlabel('Epoch', fontsize=11)
        ax.set_ylabel('Test Error', fontsize=11)
        ax.set_title(f'{region.upper()} | η={int(noise_rate*100)}%', fontsize=12, fontweight='bold')
        ax.legend(fontsize=9)
        ax.grid(alpha=0.3)
        ax.set_ylim([0, 1])

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/figures/learning_curves.png', dpi=300, bbox_inches='tight')
plt.show()
print(f'Figure saved: {SAVE_DIR}/figures/learning_curves.png')

In [None]:
# ===== 可視化: 収束時間 =====
import matplotlib.pyplot as plt

import pandas as pd

# One panel per noise rate; squeeze=False keeps `axes` indexable even when
# NOISE_RATES has a single entry.
n_panels = len(NOISE_RATES)
fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)

for ax, noise_rate in zip(axes[0], NOISE_RATES):
    df_n = df[df['noise_rate'] == noise_rate]

    # Non-converged runs carry None; coerce to float so they become NaN and
    # the column is numeric even when no run converged at all.
    conv = pd.to_numeric(df_n['convergence_epoch'], errors='coerce')

    # Convergence time per lambda (mean/std over the runs that converged)
    stats = conv.groupby(df_n['lambda']).agg(['mean', 'std'])

    # Fraction of runs per lambda that never converged
    no_conv = conv.isna().groupby(df_n['lambda']).mean()

    ax.errorbar(stats.index, stats['mean'], yerr=stats['std'],
                marker='o', capsize=4, linewidth=2, markersize=8, color='C0')

    # Show the non-converged fraction, which the mean alone hides.
    # x is in data coordinates, y in axes coordinates (top of the panel).
    for lam, frac in no_conv.items():
        if frac > 0:
            ax.text(lam, 0.98, f'no conv: {frac:.0%}', rotation=90,
                    transform=ax.get_xaxis_transform(),
                    ha='center', va='top', fontsize=8, color='C3')

    # Highlight the critical region
    ax.axvspan(0.43, 0.47, alpha=0.2, color='red', label='Critical region')

    ax.set_xlabel('λ', fontsize=12)
    ax.set_ylabel('Convergence Epoch', fontsize=12)
    ax.set_title(f'Convergence Time: η = {int(noise_rate*100)}%', fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/figures/convergence_time.png', dpi=300, bbox_inches='tight')
plt.show()
print(f'Figure saved: {SAVE_DIR}/figures/convergence_time.png')

In [None]:
# ===== サマリー =====
banner = '=' * 70
print('\n' + banner)
print('EXPERIMENT K: LEARNING DYNAMICS - SUMMARY')
print(banner)

# Headline numbers: mean convergence epoch per region, over converged runs
for noise_rate in NOISE_RATES:
    df_n = df[df['noise_rate'] == noise_rate]
    conv_col = df_n['convergence_epoch']

    ordered_conv = conv_col[df_n['region'] == 'ordered'].dropna().mean()
    critical_conv = conv_col[df_n['region'] == 'critical'].dropna().mean()

    print(f'\nη = {int(noise_rate*100)}%:')

    if np.isnan(ordered_conv):
        print('  Ordered: N/A')
    else:
        print(f'  Ordered region convergence: {ordered_conv:.1f} epochs')

    if np.isnan(critical_conv):
        print('  Critical: N/A')
    else:
        print(f'  Critical region convergence: {critical_conv:.1f} epochs')

print(f'\nSave directory: {SAVE_DIR}')