# Exp J: Mechanism Decomposition - Why Misaligned Gradients Improve

## 目的
Misaligned（random/anti-label）が低λで改善する「パラドックス」の機構を分解・特定する。

## 仮説
1. **暗黙のLR低下仮説**: ||g_mix|| < 1 による効果的学習率低下
2. **直交ノイズ仮説**: g_structに直交する成分が正則化/探索効果を持つ
3. **幾何効果仮説**: 上記では説明できない表現学習の相互作用

## 実験設計
- **ノイズ率**: η = 0.4
- **λ**: 0.1, 0.2, 0.3
- **Seeds**: 0, 1, 2

### 比較条件（5種類）
1. **CE_baseline**: 通常のCE学習
2. **Misaligned_raw**: Random label mix（renorm=False）
3. **Misaligned_renorm**: Random label mix（renorm=True）
4. **LR_matched**: 学習率を LR × ||g_mix||（初期化時のモデルで最初の10バッチから測定した平均値）に下げた通常のCE
5. **Orthogonal_noise**: g_structに直交するランダム単位ベクトルを、Misalignedと同じ重み λ で混合する（g_mix = (1−λ)·ĝ_struct + λ·ĝ_⊥）

## Runs計算
4条件（CE_baseline以外）× 3λ × 3seeds = **36 runs** + CE_baseline 3 seeds = **39 runs**（CE_baselineはλに依存しないため）

## 判定ロジック
- Misaligned_renorm で改善が消える → 暗黙LR低下が主因
- Misaligned_renorm でも改善が残り、Orthogonal_noise で再現 → 直交ノイズ効果が主因
- どれでも再現しない → 幾何効果（未解明の相互作用）

In [None]:
# ===== Setup =====
# Mount Google Drive; every output of this notebook is written under BASE_DIR on it.
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

EXP_NAME = 'exp_J_mechanism_decomposition'

# Set to the timestamp suffix of an existing results directory
# (e.g. '20240101_120000') to resume an interrupted experiment; None starts a
# fresh timestamped directory. The main loop looks for checkpoint.json inside
# SAVE_DIR, so without this a new session could never find the checkpoint an
# earlier session wrote (each session would get a new TIMESTAMP).
RESUME_TIMESTAMP = None

TIMESTAMP = RESUME_TIMESTAMP or datetime.now().strftime('%Y%m%d_%H%M%S')
BASE_DIR = '/content/drive/MyDrive/dual-gradient-learning/Paper-A'
SAVE_DIR = f'{BASE_DIR}/{EXP_NAME}_{TIMESTAMP}'
if RESUME_TIMESTAMP and not os.path.isdir(SAVE_DIR):
    # Fail loudly rather than silently starting over in a brand-new directory.
    raise FileNotFoundError(f'Resume directory not found: {SAVE_DIR}')
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(f'{SAVE_DIR}/figures', exist_ok=True)

print(f'Experiment: {EXP_NAME}')
print(f'Timestamp: {TIMESTAMP}')
print(f'Save directory: {SAVE_DIR}')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
import numpy as np
import json
import time

# Pick the compute device once; later cells move models and tensors to `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
    # Let cuDNN autotune convolution algorithms for the shapes it sees.
    torch.backends.cudnn.benchmark = True
    print(f'GPU: {torch.cuda.get_device_name(0)}')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
# ===== モデル定義 =====
class IndexedDataset(Dataset):
    """Wrap a (sample, label) dataset so each item also carries its index.

    The training loops use the returned index to look up per-sample labels
    (noisy / random) that are stored outside the wrapped dataset.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return sample, target, idx

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv+BN layers plus an identity or 1x1 projection skip.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # Submodules are created in the same order as before (conv1, bn1, conv2,
        # bn2, shortcut) so seeded weight initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # A projection is required whenever the main branch changes the
        # spatial size or the channel count of x.
        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()  # empty Sequential == identity

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))

class ResNet18(nn.Module):
    """ResNet-18 built from BasicBlock.

    Layout: a 3x3 stride-1 stem conv + BN + ReLU (no max-pooling), four stages
    of two BasicBlocks each (64/128/256/512 channels; stages 2-4 downsample
    with stride 2), global average pooling and a linear classifier.

    Args:
        num_classes: number of output logits.
    """
    # NOTE: submodules are constructed in a fixed order and PyTorch's default
    # conv/linear initialisation is random, so reordering construction would
    # change the seeded initial weights produced after set_seed().
    def __init__(self, num_classes=10):
        super().__init__()
        # Running input-channel count; _make_layer reads and advances it.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 2, 1)
        self.layer2 = self._make_layer(128, 2, 2)
        self.layer3 = self._make_layer(256, 2, 2)
        self.layer4 = self._make_layer(512, 2, 2)
        self.linear = nn.Linear(512, num_classes)
    def _make_layer(self, planes, num_blocks, stride):
        """Return `num_blocks` BasicBlocks in sequence; only the first one uses `stride`."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_planes, planes, s))
            # Later blocks (and the next stage's first block) take `planes` channels in.
            self.in_planes = planes
        return nn.Sequential(*layers)
    def forward(self, x):
        """Map images (assumed (batch, 3, H, W)) to (batch, num_classes) logits."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.adaptive_avg_pool2d(out, 1)  # global average pool -> (batch, 512, 1, 1)
        out = out.view(out.size(0), -1)  # flatten -> (batch, 512)
        return self.linear(out)

In [None]:
# ===== Experiment parameters =====
BATCH_SIZE = 256
NUM_WORKERS = 4
EPOCHS = 100
LR = 0.1
K = 16  # the cached misaligned gradient is recomputed every K optimizer steps

NOISE_RATE = 0.4
LAMBDAS = [0.1, 0.2, 0.3]
SEEDS = [0, 1, 2]

# Experimental conditions
CONDITIONS = [
    {'name': 'CE_baseline',      'method': 'ce'},
    {'name': 'Misaligned_raw',   'method': 'misaligned', 'renorm': False},
    {'name': 'Misaligned_renorm','method': 'misaligned', 'renorm': True},
    {'name': 'LR_matched',       'method': 'lr_matched'},
    {'name': 'Orthogonal_noise', 'method': 'orthogonal_noise'},
]

# Expand conditions into individual runs. CE_baseline does not depend on λ, so
# it gets one run per seed with lambda=None; every other condition runs the
# full λ × seed grid.
experiments = [
    {**cond, 'lambda': lam, 'seed': seed}
    for cond in CONDITIONS
    for lam in ([None] if cond['method'] == 'ce' else LAMBDAS)
    for seed in SEEDS
]

total_runs = len(experiments)
print(f'Total runs: {total_runs}')
# Rough planning figure assuming ~9.5 minutes per run.
print(f'Estimated time: {total_runs * 9.5 / 60:.1f} hours')

# Save the config. LR and BATCH_SIZE are recorded too so the run can be
# reproduced from config.json alone.
config = {
    'experiment': EXP_NAME,
    'timestamp': TIMESTAMP,
    'parameters': {
        'conditions': [c['name'] for c in CONDITIONS],
        'lambdas': LAMBDAS,
        'seeds': SEEDS,
        'noise_rate': NOISE_RATE,
        'epochs': EPOCHS,
        'K': K,
        'lr': LR,
        'batch_size': BATCH_SIZE,
    },
    'total_runs': total_runs
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)
print('Config saved')

In [None]:
# ===== ユーティリティ関数 =====
def set_seed(seed):
    """Seed NumPy's global RNG and torch's CPU and all-CUDA-device RNGs with `seed`.

    Python's built-in `random` module is not seeded.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def load_cifar10():
    """Return (trainset, testset): CIFAR-10 stored under ./data, downloaded if missing.

    Both splits are converted to tensors and normalised per channel with the
    constants below; the training split additionally gets random 32x32 crops
    (4-pixel padding) and random horizontal flips.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    def to_normalized_tensor():
        return [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]

    augmentation = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    transform_train = transforms.Compose(augmentation + to_normalized_tensor())
    transform_test = transforms.Compose(to_normalized_tensor())

    def load_split(is_train, transform):
        return torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)

    # Tuple elements are evaluated left to right, so train is loaded before test.
    return load_split(True, transform_train), load_split(False, transform_test)

def get_data_loaders(trainset, testset):
    """Build the train and test DataLoaders.

    The train loader wraps `trainset` in IndexedDataset, so each batch yields
    (inputs, original_labels, indices); it shuffles and drops the last partial
    batch. The test loader yields (inputs, labels) in a fixed order.
    Batch size and worker count come from the BATCH_SIZE / NUM_WORKERS globals.

    Returns:
        (train_loader, test_loader)
    """
    # DataLoader raises ValueError for persistent_workers=True with
    # num_workers == 0, so only keep workers alive when there are any.
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=NUM_WORKERS > 0,
    )
    train_loader = DataLoader(IndexedDataset(trainset), shuffle=True, drop_last=True, **loader_kwargs)
    test_loader = DataLoader(testset, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

def inject_noise(labels, noise_rate, seed):
    """Return a copy of `labels` with a fixed fraction flipped to a different class.

    Exactly int(noise_rate * len(labels)) positions, chosen without replacement,
    are reassigned uniformly to one of the other 9 classes (classes 0-9 are
    hard-coded). NumPy's global RNG is reseeded with seed + 1000, so the result
    is deterministic per seed and the global RNG state is modified.

    Args:
        labels: NumPy integer array of clean labels in [0, 10).
        noise_rate: fraction of samples to corrupt.
        seed: base seed.

    Returns:
        New array with corrupted labels; `labels` itself is not modified.
    """
    np.random.seed(seed + 1000)
    total = len(labels)
    corrupted = labels.copy()
    flip_positions = np.random.choice(total, int(noise_rate * total), replace=False)
    for pos in flip_positions:
        true_class = labels[pos]
        other_classes = [c for c in range(10) if c != true_class]
        corrupted[pos] = np.random.choice(other_classes)
    return corrupted

def generate_random_labels(labels, seed):
    """Draw one uniform label in 0..9 per element of `labels` (targets of the misaligned gradient).

    Only len(labels) is used; the label values are ignored. NumPy's global RNG
    is reseeded with seed + 2000.
    """
    n_samples = len(labels)
    np.random.seed(seed + 2000)
    return np.random.randint(low=0, high=10, size=n_samples)

def evaluate(model, test_loader):
    """Return the top-1 accuracy (a fraction in [0, 1]) of `model` on `test_loader`.

    The model is switched to eval mode and left there; gradients are disabled.
    Batches are moved to the global `device`.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)
            _, preds = model(batch_x).max(dim=1)
            n_seen += batch_y.size(0)
            n_correct += (preds == batch_y).sum().item()
    return n_correct / n_seen

In [None]:
# ===== 学習関数群 =====

def train_ce_baseline(model, train_loader, test_loader, noisy_labels):
    """Plain cross-entropy training on the noisy labels (reference condition).

    SGD (momentum 0.9, weight decay 5e-4) at the global LR, decayed 10x after
    epochs 50 and 75. Test accuracy is probed every 10 epochs and at the end.

    Returns:
        (final_acc, best_acc, extra): best_acc is the maximum over the probes
        and the final evaluation; extra holds a constant avg_gmix_norm of 1.0
        because no gradient mixing is applied.
    """
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)
    loss_fn = nn.CrossEntropyLoss()
    label_lookup = torch.tensor(noisy_labels, device=device)

    probe_accs = []
    for epoch in range(1, EPOCHS + 1):
        model.train()
        for inputs, _, indices in train_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = label_lookup[indices.to(device, non_blocking=True)]

            optimizer.zero_grad()
            loss_fn(model(inputs), targets).backward()
            optimizer.step()
        scheduler.step()

        if epoch % 10 == 0:
            probe_accs.append(evaluate(model, test_loader))

    final_acc = evaluate(model, test_loader)
    return final_acc, max([0] + probe_accs + [final_acc]), {'avg_gmix_norm': 1.0}


def train_misaligned(model, train_loader, test_loader, noisy_labels, random_labels, lam, renorm):
    """Train on a mix of the noisy-label gradient and a misaligned (random-label) gradient.

    Each optimizer step uses
        g_mix = (1 - lam) * g_struct / ||g_struct|| + lam * g_value / ||g_value||
    where g_struct is the CE gradient w.r.t. the noisy labels and g_value the
    CE gradient w.r.t. `random_labels`. g_value is recomputed every K steps
    (on that step's batch) and reused in between. With `renorm=True`, g_mix is
    rescaled to unit norm, which removes the implicit step-size reduction from
    ||g_mix|| < 1. g_mix replaces .grad before the SGD step, so momentum and
    weight decay are applied to it as usual.

    Returns:
        (final_acc, best_acc, extra); extra['avg_gmix_norm'] is the mean
        ||g_mix|| (before any renormalisation) over all training steps.
    """
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)
    criterion = nn.CrossEntropyLoss()
    params = list(model.parameters())

    noisy_labels_tensor = torch.tensor(noisy_labels, device=device)
    random_labels_tensor = torch.tensor(random_labels, device=device)

    cached_value_grad = None
    global_step = 0
    best_acc = 0
    gmix_norms = []

    for epoch in range(EPOCHS):
        model.train()
        for inputs, _, indices in train_loader:
            inputs = inputs.to(device, non_blocking=True)
            indices = indices.to(device, non_blocking=True)
            batch_noisy = noisy_labels_tensor[indices]

            refresh_value = cached_value_grad is None or global_step % K == 0

            # One forward pass serves both losses. A second train-mode forward
            # would update the BatchNorm running statistics twice on refresh
            # steps, a confound the other conditions do not have.
            optimizer.zero_grad()
            outputs = model(inputs)

            # Structure gradient; keep the graph only when it is reused below.
            criterion(outputs, batch_noisy).backward(retain_graph=refresh_value)
            # parameters_to_vector concatenates into a new tensor, so it is
            # unaffected by the zero_grad() calls that follow.
            g_struct = parameters_to_vector([p.grad for p in params])

            # Value gradient from the random labels, refreshed every K steps.
            if refresh_value:
                optimizer.zero_grad()
                criterion(outputs, random_labels_tensor[indices]).backward()
                cached_value_grad = parameters_to_vector([p.grad for p in params])

            # Normalize each gradient to unit length, then mix.
            g_struct_norm = g_struct / (g_struct.norm() + 1e-12)
            g_value_norm = cached_value_grad / (cached_value_grad.norm() + 1e-12)
            g_mix = (1 - lam) * g_struct_norm + lam * g_value_norm

            gmix_norm = g_mix.norm()
            gmix_norms.append(gmix_norm.item())

            if renorm:
                g_mix = g_mix / (gmix_norm + 1e-12)

            # Install g_mix as the parameters' gradient and take the SGD step.
            optimizer.zero_grad()
            offset = 0
            for p in params:
                numel = p.numel()
                p.grad = g_mix[offset:offset + numel].view(p.shape).clone()
                offset += numel
            optimizer.step()
            global_step += 1

        scheduler.step()
        if (epoch + 1) % 10 == 0:
            best_acc = max(best_acc, evaluate(model, test_loader))

    final_acc = evaluate(model, test_loader)
    return final_acc, max(best_acc, final_acc), {'avg_gmix_norm': np.mean(gmix_norms)}


def train_lr_matched(model, train_loader, test_loader, noisy_labels, lr_scale):
    """Plain CE training with the base learning rate multiplied by `lr_scale`.

    Control for the implicit-LR-reduction hypothesis. The caller passes the
    ||g_mix|| measured for the matching λ as `lr_scale`. Otherwise identical
    to the CE baseline: SGD (momentum 0.9, weight decay 5e-4), 10x decay after
    epochs 50 and 75, and test probes every 10 epochs.

    NOTE(review): this steps along the raw CE gradient, whereas the mixed
    conditions step along unit-normalised gradients whose norm is ||g_mix||.
    Scaling the LR of a raw gradient only matches their step size if
    ||g_struct|| ≈ 1 — confirm this is the intended control.

    Returns:
        (final_acc, best_acc, {'lr_scale': lr_scale})
    """
    optimizer = optim.SGD(
        model.parameters(), lr=LR * lr_scale, momentum=0.9, weight_decay=5e-4
    )
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)
    loss_fn = nn.CrossEntropyLoss()
    targets_by_index = torch.tensor(noisy_labels, device=device)

    best_acc = 0
    for epoch_idx in range(EPOCHS):
        model.train()
        for inputs, _, indices in train_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets_by_index[indices.to(device, non_blocking=True)]

            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
        scheduler.step()

        is_probe_epoch = (epoch_idx + 1) % 10 == 0
        if is_probe_epoch:
            best_acc = max(best_acc, evaluate(model, test_loader))

    final_acc = evaluate(model, test_loader)
    return final_acc, max(best_acc, final_acc), {'lr_scale': lr_scale}


def train_orthogonal_noise(model, train_loader, test_loader, noisy_labels, lam):
    """Mix the unit CE gradient with a fresh random unit direction orthogonal to it.

    Control for the orthogonal-noise hypothesis. Each step uses
        g_mix = (1 - lam) * g_struct / ||g_struct|| + lam * u
    where u is a freshly drawn Gaussian vector with its component along
    g_struct removed, scaled to unit length. g_mix is not renormalised, so for
    non-zero g_struct ||g_mix|| ≈ sqrt((1 - lam)^2 + lam^2) < 1: this condition
    also shrinks the effective step size. That norm is now recorded so it can
    be compared against the misaligned conditions.

    Returns:
        (final_acc, best_acc, extra) with extra['avg_gmix_norm'] the mean
        ||g_mix|| over all training steps.
    """
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)
    criterion = nn.CrossEntropyLoss()
    noisy_labels_tensor = torch.tensor(noisy_labels, device=device)

    params = list(model.parameters())
    param_dim = sum(p.numel() for p in params)
    best_acc = 0
    gmix_norms = []

    for epoch in range(EPOCHS):
        model.train()
        for inputs, _, indices in train_loader:
            inputs = inputs.to(device, non_blocking=True)
            indices = indices.to(device, non_blocking=True)
            batch_labels = noisy_labels_tensor[indices]

            optimizer.zero_grad()
            loss = criterion(model(inputs), batch_labels)
            loss.backward()
            # parameters_to_vector concatenates into a new tensor; no clone needed.
            g_struct = parameters_to_vector([p.grad for p in params])
            g_struct_norm = g_struct / (g_struct.norm() + 1e-12)

            # Random direction with its projection onto g_struct removed.
            random_vec = torch.randn(param_dim, device=device)
            proj = (random_vec @ g_struct_norm) * g_struct_norm
            orthogonal = random_vec - proj
            orthogonal_norm = orthogonal / (orthogonal.norm() + 1e-12)

            # Mix with the same λ weighting as the misaligned condition.
            g_mix = (1 - lam) * g_struct_norm + lam * orthogonal_norm
            gmix_norms.append(g_mix.norm().item())

            # Install g_mix as the parameters' gradient and take the SGD step.
            optimizer.zero_grad()
            offset = 0
            for p in params:
                numel = p.numel()
                p.grad = g_mix[offset:offset + numel].view(p.shape).clone()
                offset += numel
            optimizer.step()

        scheduler.step()
        if (epoch + 1) % 10 == 0:
            best_acc = max(best_acc, evaluate(model, test_loader))

    final_acc = evaluate(model, test_loader)
    return final_acc, max(best_acc, final_acc), {'avg_gmix_norm': np.mean(gmix_norms)}

In [None]:
# ===== Data preparation =====
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)
train_loader, test_loader = get_data_loaders(trainset, testset)

# Generate the noisy and random label vectors once, with seed=0. The same
# vectors are shared by every run; the per-run seed does not change them.
noisy_labels = inject_noise(clean_labels, NOISE_RATE, seed=0)
random_labels = generate_random_labels(clean_labels, seed=0)

print('Data prepared')
print(f'  Noisy labels: {np.mean(noisy_labels != clean_labels)*100:.1f}% corrupted')
print(f'  Random labels: {np.mean(random_labels != clean_labels)*100:.1f}% different from clean')

# GPU warmup before the timed runs. Only meaningful on CUDA (cuDNN benchmark
# mode is enabled there); on CPU, 20 full-batch forwards would only waste time.
# no_grad avoids building autograd graphs that are thrown away immediately.
if torch.cuda.is_available():
    warmup_model = ResNet18().to(device)
    with torch.no_grad():
        for _ in range(20):
            warmup_model(torch.randn(BATCH_SIZE, 3, 32, 32, device=device))
    del warmup_model
    torch.cuda.empty_cache()
print('Warmup complete.')

In [None]:
# ===== Pre-measure ||g_mix|| (for LR_matched) =====
# For each λ, estimate the mean norm of the Misaligned_raw mixture
#     g_mix = (1 - λ) * g_struct / ||g_struct|| + λ * g_value / ||g_value||
# The LR_matched condition multiplies its learning rate by this value.
# NOTE: no optimizer step is taken, so this measures ||g_mix|| at the seed-0
# initialisation over the first 10 batches, not the average over training
# (train_misaligned reports that separately as avg_gmix_norm).

print('Measuring ||g_mix|| for LR matching...')
gmix_norm_by_lambda = {}
criterion = nn.CrossEntropyLoss()
noisy_labels_tensor = torch.tensor(noisy_labels, device=device)
random_labels_tensor = torch.tensor(random_labels, device=device)

for lam in LAMBDAS:
    # Same seed for every λ, hence identical initial weights.
    set_seed(0)
    model = ResNet18().to(device)
    params = list(model.parameters())

    norms = []
    for batch_idx, (inputs, _, indices) in enumerate(train_loader):
        if batch_idx >= 10:  # estimate from the first 10 batches
            break
        inputs = inputs.to(device)
        indices = indices.to(device)

        # One forward pass serves both losses; its graph is retained for the
        # second backward. A train-mode BatchNorm output does not depend on the
        # running statistics, so this gives the same gradients as two passes.
        model.zero_grad()
        outputs = model(inputs)
        criterion(outputs, noisy_labels_tensor[indices]).backward(retain_graph=True)
        g_struct = parameters_to_vector([p.grad for p in params])

        model.zero_grad()
        criterion(outputs, random_labels_tensor[indices]).backward()
        g_value = parameters_to_vector([p.grad for p in params])

        g_struct_norm = g_struct / (g_struct.norm() + 1e-12)
        g_value_norm = g_value / (g_value.norm() + 1e-12)
        g_mix = (1 - lam) * g_struct_norm + lam * g_value_norm
        norms.append(g_mix.norm().item())

    gmix_norm_by_lambda[lam] = np.mean(norms)
    print(f'  λ={lam}: avg ||g_mix|| = {gmix_norm_by_lambda[lam]:.4f}')
    del model

torch.cuda.empty_cache()

In [None]:
# ===== Main experiment loop =====
results = []
checkpoint_file = f'{SAVE_DIR}/checkpoint.json'
completed = set()

# Resume support: skip every (condition, λ, seed) already in the checkpoint.
if os.path.exists(checkpoint_file):
    with open(checkpoint_file, 'r') as f:
        results = json.load(f)
    completed = {(r['condition'], r.get('lambda'), r['seed']) for r in results}
    print(f'Checkpoint loaded: {len(completed)} runs completed')

exp_start = time.time()

for run_counter, exp in enumerate(experiments, start=1):
    cond_name = exp['name']
    method = exp['method']
    lam = exp.get('lambda')
    seed = exp['seed']

    if (cond_name, lam, seed) in completed:
        continue

    # Compare against None rather than relying on truthiness, so λ = 0.0 is
    # still shown; ':g' avoids rounding λ values that need more than one decimal.
    lam_str = f'λ={lam:g}' if lam is not None else 'N/A'
    print(f'\n[{run_counter}/{total_runs}] {cond_name} {lam_str} seed={seed}')
    t0 = time.time()

    set_seed(seed)
    model = ResNet18().to(device)

    # Dispatch to the training routine for this condition.
    if method == 'ce':
        final_acc, best_acc, extra = train_ce_baseline(model, train_loader, test_loader, noisy_labels)
    elif method == 'misaligned':
        renorm = exp.get('renorm', False)
        final_acc, best_acc, extra = train_misaligned(
            model, train_loader, test_loader, noisy_labels, random_labels, lam, renorm
        )
    elif method == 'lr_matched':
        # LR scale = ||g_mix|| pre-measured for this λ.
        lr_scale = gmix_norm_by_lambda[lam]
        final_acc, best_acc, extra = train_lr_matched(
            model, train_loader, test_loader, noisy_labels, lr_scale
        )
    elif method == 'orthogonal_noise':
        final_acc, best_acc, extra = train_orthogonal_noise(
            model, train_loader, test_loader, noisy_labels, lam
        )
    else:
        # Without this, an unknown method would silently reuse the previous
        # run's accuracies (or raise a confusing NameError on the first run).
        raise ValueError(f'Unknown method: {method!r}')

    elapsed = time.time() - t0

    result = {
        'experiment': EXP_NAME,
        'condition': cond_name,
        'method': method,
        'lambda': lam,
        'seed': seed,
        'noise_rate': NOISE_RATE,
        'test_acc': final_acc,
        'test_error': 1 - final_acc,
        'best_acc': best_acc,
        'best_error': 1 - best_acc,
        'time_seconds': elapsed,
        **extra
    }
    results.append(result)

    # Write to a temporary file and atomically swap it in, so an interruption
    # mid-write cannot leave a truncated checkpoint that breaks resuming.
    tmp_checkpoint = f'{checkpoint_file}.tmp'
    with open(tmp_checkpoint, 'w') as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_checkpoint, checkpoint_file)

    # Release this run's model before the next run allocates a new one.
    del model
    torch.cuda.empty_cache()

    print(f'  Error: {1-final_acc:.4f} | Time: {elapsed/60:.1f} min')

    completed_count = len(results)
    avg_time = sum(r['time_seconds'] for r in results) / completed_count
    remaining = total_runs - completed_count
    eta_hours = (remaining * avg_time) / 3600
    print(f'  Progress: {completed_count}/{total_runs} | ETA: {eta_hours:.1f} hours')

print(f'\n{"="*60}')
print(f'EXPERIMENT COMPLETE')
print(f'Total time: {(time.time()-exp_start)/3600:.2f} hours')
print(f'{"="*60}')

In [None]:
# ===== Save results =====
# Write the final results as JSON (full record) and CSV (one row per run).
import pandas as pd

df = pd.DataFrame(results)
df.to_csv(f'{SAVE_DIR}/results.csv', index=False)

with open(f'{SAVE_DIR}/results.json', 'w') as f:
    json.dump(results, f, indent=2)

print(f'Results saved to {SAVE_DIR}/')

In [None]:
# ===== Mechanism analysis =====
import pandas as pd

df = pd.DataFrame(results)

print('='*70)
print('MECHANISM DECOMPOSITION ANALYSIS')
print('='*70)


def _mean_error(frame, condition):
    """Return the mean test_error of `frame` rows with this condition (NaN if there are none)."""
    return frame.loc[frame['condition'] == condition, 'test_error'].mean()


# CE baseline
ce_error = _mean_error(df, 'CE_baseline')
print(f'\nCE Baseline Error: {ce_error:.4f}')

print('\n--- Comparison by λ ---')
for lam in LAMBDAS:
    print(f'\nλ = {lam}:')

    df_l = df[df['lambda'] == lam]

    mis_raw = _mean_error(df_l, 'Misaligned_raw')
    mis_renorm = _mean_error(df_l, 'Misaligned_renorm')
    lr_matched = _mean_error(df_l, 'LR_matched')
    orth_noise = _mean_error(df_l, 'Orthogonal_noise')

    print(f'  Misaligned_raw:   {mis_raw:.4f} (Δ vs CE: {mis_raw - ce_error:+.4f})')
    print(f'  Misaligned_renorm: {mis_renorm:.4f} (Δ vs CE: {mis_renorm - ce_error:+.4f})')
    print(f'  LR_matched:       {lr_matched:.4f} (Δ vs CE: {lr_matched - ce_error:+.4f})')
    print(f'  Orthogonal_noise: {orth_noise:.4f} (Δ vs CE: {orth_noise - ce_error:+.4f})')

    # Decision rule: seed-averaged errors are compared directly (no
    # significance test). The branch order is equivalent to the original rule.
    print('\n  Interpretation:')
    if pd.isna([ce_error, mis_raw, mis_renorm, orth_noise]).any():
        # Every comparison with NaN is False, so missing runs would otherwise
        # fall through all branches and print no interpretation at all.
        print('  → 判定不可（必要な条件の結果が不足）')
    elif mis_raw >= ce_error:
        print('  → Misalignedによる改善なし')
    elif mis_renorm >= ce_error:
        print('  → 暗黙LR低下が主因（renormで改善が消えた）')
    elif orth_noise < ce_error:
        print('  → 直交ノイズ効果が主因（orthogonal_noiseで再現）')
    else:
        print('  → 幾何効果（未解明の相互作用）')

    # Supplementary evidence for the implicit-LR hypothesis: does lowering the
    # LR alone (LR_matched) also beat the CE baseline?
    if not pd.isna([ce_error, lr_matched]).any():
        reproduced = '再現' if lr_matched < ce_error else '非再現'
        print(f'  (参考) LR_matched による CE 比改善: {reproduced}')

In [None]:
# ===== Visualization =====
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 6))

# CE baseline: mean and std over seeds, drawn as a reference line with a ±1σ band.
ce_errors = df.loc[df['condition'] == 'CE_baseline', 'test_error']
ce_error = ce_errors.mean()
ce_std = ce_errors.std()

conditions_to_plot = ['Misaligned_raw', 'Misaligned_renorm', 'LR_matched', 'Orthogonal_noise']
colors = ['C0', 'C1', 'C2', 'C3']
x_positions = np.arange(len(LAMBDAS))
width = 0.2

# Grouped bars: one group per λ, with the i-th condition shifted right by i * width.
for i, (cond, color) in enumerate(zip(conditions_to_plot, colors)):
    per_lambda = [
        df.loc[(df['condition'] == cond) & (df['lambda'] == lam), 'test_error']
        for lam in LAMBDAS
    ]
    ax.bar(
        x_positions + i * width,
        [errs.mean() for errs in per_lambda],
        width,
        yerr=[errs.std() for errs in per_lambda],
        label=cond,
        color=color,
        capsize=3,
        alpha=0.8,
    )

# CE baseline line and its ±1σ band
ax.axhline(y=ce_error, color='red', linestyle='--', linewidth=2, label=f'CE baseline ({ce_error:.3f})')
ax.fill_between([-0.5, len(LAMBDAS)], ce_error - ce_std, ce_error + ce_std, color='red', alpha=0.1)

ax.set_xlabel('λ', fontsize=12)
ax.set_ylabel('Test Error', fontsize=12)
ax.set_title('Mechanism Decomposition: Why Misaligned Improves?', fontsize=13, fontweight='bold')
# Ticks at the centre of each 4-bar group.
ax.set_xticks(x_positions + 1.5 * width)
ax.set_xticklabels([str(l) for l in LAMBDAS])
ax.legend(loc='upper right')
ax.grid(alpha=0.3, axis='y')

plt.tight_layout()
figure_path = f'{SAVE_DIR}/figures/mechanism_decomposition.png'
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()
print(f'Figure saved: {figure_path}')

In [None]:
# ===== Summary =====
# Banner, title and closing rule are printed on separate lines.
print('\n' + '=' * 70, 'EXPERIMENT J: MECHANISM DECOMPOSITION - SUMMARY', '=' * 70, sep='\n')
print(f'\nSave directory: {SAVE_DIR}')