# Exp B3: Ablation Study - Disentangling Direction vs Step Size

## 目的
ソフィアの指摘に基づき、以下を切り分ける：
1. **g_mix再正規化**の効果（暗黙のLR低下を除去）
2. **キャッシュK=1 vs K=16**の効果
3. **真のAnti-aligned（-g_clean）**の効果

### 実験条件
- λ: 0.2, 0.3（最も効果が見える範囲）
- Noise: 40%のみ
- Seeds: 0, 1, 2

### 条件（4種類）
1. **Baseline**: Aligned + renorm=False + K=16（現状のexp_C相当）
2. **+Renorm**: Aligned + renorm=True + K=16
3. **+Renorm+K1**（キャッシュなし）: Aligned + renorm=True + K=1
4. **TrueAnti+K1**: g_value = -g_clean + renorm=True + K=1

### runs数: 24 runs (4条件 × 2λ × 3seeds)
### 推定時間: ~3.7時間（24 runs × 約9.2分/run）

In [None]:
# Mount Google Drive in Colab; every artifact of this experiment is saved under it.
from google.colab import drive
drive.mount('/content/drive')

import os
# Output directory for the checkpoint, results JSON, summary CSV and figure
# written by the later cells. Created if it does not exist yet.
SAVE_DIR = '/content/drive/MyDrive/dual-gradient-results/exp_B3_ablation'
os.makedirs(SAVE_DIR, exist_ok=True)
print(f'Save directory: {SAVE_DIR}')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
import numpy as np
import json
import time

# Use the GPU when one is present. With CUDA, also let cuDNN benchmark
# convolution algorithms (faster for fixed input shapes; may cost bitwise
# run-to-run reproducibility).
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
if _cuda_ok:
    torch.backends.cudnn.benchmark = True
    print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'Device: {device}')

In [None]:
class IndexedDataset(Dataset):
    """Wrap a dataset of ``(input, label)`` items so each item also carries its index.

    ``ds[i]`` returns ``(input, label, i)``. The training loop uses the index
    to look up per-sample clean and noisy labels.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return sample, target, idx

class BasicBlock(nn.Module):
    """Two-convolution residual block (the ResNet-18/34 building block).

    Main path: 3x3 conv (``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    The shortcut is the identity, unless the stride or the channel count
    changes; then a strided 1x1 conv + BN projects the input to the output
    shape. The two paths are summed and passed through a final ReLU.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            projection = nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(planes))
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))

class ResNet18(nn.Module):
    """CIFAR-style ResNet-18: a 3x3 stride-1 stem (no max-pool), then four stages.

    Each stage has two ``BasicBlock``s, with 64/128/256/512 channels. The first
    block of stages 2-4 uses stride 2. Global average pooling feeds a linear
    classifier with ``num_classes`` outputs.
    """

    # (output channels, stride of the stage's first block) for layer1..layer4.
    _STAGE_SPECS = ((64, 1), (128, 2), (256, 2), (512, 2))

    def __init__(self, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Registered as layer1..layer4, in order, via nn.Module.__setattr__.
        for stage_no, (planes, stride) in enumerate(self._STAGE_SPECS, start=1):
            setattr(self, f'layer{stage_no}', self._make_layer(planes, 2, stride))
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        # Only the first block may downsample / change width; the rest are planes->planes.
        head = BasicBlock(self.in_planes, planes, stride)
        tail = [BasicBlock(planes, planes, 1) for _ in range(num_blocks - 1)]
        self.in_planes = planes
        return nn.Sequential(head, *tail)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        h = F.adaptive_avg_pool2d(h, 1)
        return self.linear(torch.flatten(h, 1))

In [None]:
# ---- Training hyper-parameters shared by every run ----
BATCH_SIZE = 256
NUM_WORKERS = 4
EPOCHS = 100
LR = 0.1
NOISE_RATE = 0.4  # fixed: only the 40% label-noise setting is studied

# ---- Sweep grid ----
LAMBDAS = [0.2, 0.3]
SEEDS = [0, 1, 2]

# The four ablation conditions as (name, value_type, renorm, K) rows.
#   value_type: 'aligned' -> value gradient is g_clean, 'anti' -> -g_clean
#   renorm:     re-normalize the mixed gradient g_mix to unit norm
#   K:          refresh period (in steps) of the cached value gradient
_CONDITION_ROWS = (
    ('Baseline',    'aligned', False, 16),
    ('+Renorm',     'aligned', True,  16),
    ('+Renorm+K1',  'aligned', True,  1),
    ('TrueAnti+K1', 'anti',    True,  1),
)
CONDITIONS = [
    {'name': name, 'value_type': value_type, 'renorm': renorm, 'K': k}
    for name, value_type, renorm, k in _CONDITION_ROWS
]

_MINUTES_PER_RUN = 9.2  # assumed wall-clock minutes per run, for the estimate only
total_runs = len(CONDITIONS) * len(LAMBDAS) * len(SEEDS)
print(f'Total runs: {total_runs}')
print(f'Estimated time: {total_runs * _MINUTES_PER_RUN / 60:.1f} hours')

def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch's CPU and all-CUDA-device RNGs with ``seed``."""
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

def load_cifar10():
    """Download (if needed) and return the CIFAR-10 ``(trainset, testset)``.

    Training images get random-crop (padding 4) and horizontal-flip
    augmentation. Both splits are converted to tensors and normalized
    per channel.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    # NOTE(review): these std values differ from the commonly cited CIFAR-10
    # per-channel std (~0.247, 0.243, 0.261); kept unchanged so results stay
    # comparable with earlier runs -- confirm intent.
    channel_std = (0.2023, 0.1994, 0.2010)

    def _to_normalized_tensor():
        return [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]

    augment = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    transform_train = transforms.Compose(augment + _to_normalized_tensor())
    transform_test = transforms.Compose(_to_normalized_tensor())

    common = {'root': './data', 'download': True}
    trainset = torchvision.datasets.CIFAR10(train=True, transform=transform_train, **common)
    testset = torchvision.datasets.CIFAR10(train=False, transform=transform_test, **common)
    return trainset, testset

def get_data_loaders(trainset, testset):
    """Build the ``(train_loader, test_loader)`` pair.

    The training loader wraps ``trainset`` in ``IndexedDataset`` so that each
    batch is ``(inputs, labels, indices)``. It shuffles and drops the final
    partial batch. The test loader keeps the original order. Both loaders use
    pinned memory and persistent worker processes.
    """
    shared = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                  pin_memory=True, persistent_workers=True)
    train_loader = DataLoader(IndexedDataset(trainset), shuffle=True, drop_last=True, **shared)
    test_loader = DataLoader(testset, shuffle=False, **shared)
    return train_loader, test_loader

def inject_noise(labels, noise_rate, seed, num_classes=10):
    """Return a copy of ``labels`` with a fraction of the entries flipped to a wrong class.

    Exactly ``int(noise_rate * len(labels))`` distinct positions are drawn
    uniformly without replacement. Each of them is reassigned uniformly to one
    of the other ``num_classes - 1`` classes (symmetric label noise).

    Randomness comes from a private ``np.random.RandomState(seed + 1000)``
    instead of reseeding NumPy's global RNG. Calling this therefore does not
    disturb the global state established by ``set_seed``. ``RandomState(s)``
    yields the same stream as ``np.random.seed(s)`` followed by
    ``np.random.*`` calls, so the noise pattern for a given seed is the same
    one the global-RNG version produced.

    Args:
        labels: array of clean integer labels (must support ``.copy()`` and
            integer indexing, e.g. ``np.ndarray``); not modified.
        noise_rate: fraction of samples to corrupt, in [0, 1].
        seed: run seed; the RNG is seeded with ``seed + 1000``.
        num_classes: number of classes the labels range over (10 for CIFAR-10).

    Returns:
        A new array (same type/dtype as ``labels``) containing the noisy labels.
    """
    rng = np.random.RandomState(seed + 1000)
    noisy_labels = labels.copy()
    n_noisy = int(noise_rate * len(labels))
    noisy_indices = rng.choice(len(labels), n_noisy, replace=False)
    for idx in noisy_indices:
        wrong_classes = [c for c in range(num_classes) if c != labels[idx]]
        noisy_labels[idx] = rng.choice(wrong_classes)
    return noisy_labels

In [None]:
def evaluate(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` as a fraction.

    Switches the model to eval mode (and leaves it there); runs without autograd.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)
            predictions = model(batch_x).max(dim=1).indices
            n_seen += batch_y.size(0)
            n_correct += (predictions == batch_y).sum().item()
    return n_correct / n_seen

def _flatten_grads(grads):
    """Concatenate per-parameter gradient tensors into one 1-D vector (parameters_to_vector layout)."""
    return torch.cat([g.reshape(-1) for g in grads])


def _assign_flat_grad(params, flat_grad):
    """Copy consecutive slices of ``flat_grad`` into ``p.grad`` for each parameter, in order.

    Each slice is cloned, so ``p.grad`` does not share storage with ``flat_grad``.
    """
    offset = 0
    for p in params:
        numel = p.numel()
        p.grad = flat_grad[offset:offset + numel].view_as(p).clone()
        offset += numel


def train_dual_gradient_ablation(model, train_loader, test_loader,
                                  clean_labels, noisy_labels, lam,
                                  value_type='aligned', renorm=False, K=16):
    """Unified training routine for the B3 ablation (direction vs. step size).

    Each optimizer step uses two gradients taken from ONE forward pass:

    * ``g_struct``: gradient of cross-entropy w.r.t. the noisy labels.
    * ``g_clean``: gradient of cross-entropy w.r.t. the clean labels. It is
      recomputed only every ``K`` steps and cached in between.

    The value gradient is ``g_clean`` (``value_type='aligned'``) or
    ``-g_clean`` (``value_type='anti'``). Both gradients are scaled to unit L2
    norm and mixed as ``g_mix = (1 - lam) * g_struct + lam * g_value``. With
    ``renorm=True``, ``g_mix`` is rescaled to unit norm again. ``g_mix`` is
    written into ``p.grad`` and applied by SGD (momentum 0.9, weight decay
    5e-4, LR x0.1 at epochs 50 and 75).

    Args:
        model: network trained in place; every parameter must receive a gradient.
        train_loader: yields ``(inputs, _, indices)``; ``indices`` index into
            ``clean_labels`` / ``noisy_labels``.
        test_loader: passed to ``evaluate`` every 10 epochs and after training.
        clean_labels: per-sample clean labels for the whole training set.
        noisy_labels: per-sample noisy labels for the whole training set.
        lam: mixing weight of the value gradient.
        value_type: ``'aligned'`` (g_clean) or ``'anti'`` (-g_clean).
        renorm: whether to re-normalize ``g_mix`` to unit norm.
        K: refresh period of the cached value gradient (1 = recompute every batch).

    Returns:
        ``(final_acc, best_acc, avg_cos, avg_norm)``.
        - ``best_acc`` is the max over the every-10-epoch evaluations and the
          final one.
        - ``avg_cos`` / ``avg_norm`` are the means of cos(g_struct, g_value)
          and ||g_mix|| over the first 100 steps (0 if no step ran).

    Raises:
        ValueError: if ``value_type`` is not ``'aligned'``/``'anti'`` or ``K < 1``.
    """
    if value_type not in ('aligned', 'anti'):
        raise ValueError(f"value_type must be 'aligned' or 'anti', got {value_type!r}")
    if K < 1:
        raise ValueError(f'K must be >= 1, got {K}')

    params = list(model.parameters())
    optimizer = optim.SGD(params, lr=LR, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    clean_labels_tensor = torch.tensor(clean_labels, device=device)
    noisy_labels_tensor = torch.tensor(noisy_labels, device=device)

    cached_value_grad = None
    global_step = 0
    best_acc = 0

    # Diagnostics, recorded for the first 100 optimizer steps only.
    cos_log = []
    norm_log = []

    for epoch in range(EPOCHS):
        model.train()
        for inputs, _, indices in train_loader:
            inputs = inputs.to(device, non_blocking=True)
            indices = indices.to(device, non_blocking=True)
            batch_noisy = noisy_labels_tensor[indices]
            batch_clean = clean_labels_tensor[indices]

            refresh_value = cached_value_grad is None or global_step % K == 0

            # A single forward pass feeds both losses. A second forward pass
            # would give identical logits: parameters are unchanged, and BN uses
            # batch statistics in train mode. It would, however, update the BN
            # running statistics twice on refresh steps -- every step for K=1,
            # but only every 16th step for K=16. That confounds the K ablation
            # through the BN statistics used at evaluation time.
            outputs = model(inputs)

            # ===== Structure gradient (from the noisy labels) =====
            loss_struct = criterion(outputs, batch_noisy)
            # Keep the graph only if the value loss will reuse it below.
            g_struct = _flatten_grads(
                torch.autograd.grad(loss_struct, params, retain_graph=refresh_value))

            # ===== Value gradient (from the clean labels), refreshed every K steps =====
            if refresh_value:
                loss_value = criterion(outputs, batch_clean)
                g_clean = _flatten_grads(torch.autograd.grad(loss_value, params))
                # 'anti' is the truly anti-aligned direction: -g_clean.
                cached_value_grad = g_clean if value_type == 'aligned' else -g_clean

            # ===== Normalize each gradient to unit norm and mix =====
            g_struct_norm = g_struct / (g_struct.norm() + 1e-12)
            g_value_norm = cached_value_grad / (cached_value_grad.norm() + 1e-12)
            g_mix = (1 - lam) * g_struct_norm + lam * g_value_norm

            # Re-normalization: ||g_mix|| <= 1 (triangle inequality), with
            # equality only if the two directions coincide. Without renorm,
            # mixing therefore also shrinks the effective step size.
            if renorm:
                g_mix = g_mix / (g_mix.norm() + 1e-12)

            if global_step < 100:
                cos_sim = F.cosine_similarity(g_struct_norm.unsqueeze(0), g_value_norm.unsqueeze(0)).item()
                cos_log.append(cos_sim)
                norm_log.append(g_mix.norm().item())

            # ===== Parameter update: g_mix overwrites every p.grad =====
            _assign_flat_grad(params, g_mix)
            optimizer.step()
            global_step += 1

        scheduler.step()
        if (epoch + 1) % 10 == 0:
            best_acc = max(best_acc, evaluate(model, test_loader))

    final_acc = evaluate(model, test_loader)

    avg_cos = np.mean(cos_log) if cos_log else 0
    avg_norm = np.mean(norm_log) if norm_log else 0

    return final_acc, max(best_acc, final_acc), avg_cos, avg_norm

In [None]:
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)  # clean labels; noise is injected per run
train_loader, test_loader = get_data_loaders(trainset, testset)

# Sanity check: training batches must be (image, label, index) so that the
# training loop can look up clean/noisy labels by sample index.
print('Verifying IndexedDataset...')
sample_batch = next(iter(train_loader))
print(f'  ✓ indices shape: {list(sample_batch[2].shape)}')

# Warm-up only matters on CUDA: cudnn.benchmark is enabled only there, and the
# first passes at a new input shape trigger algorithm autotuning. No gradients
# are needed for that, so run under no_grad instead of building 20 throw-away
# autograd graphs.
print('\nWarming up GPU...')
if torch.cuda.is_available():
    warmup_model = ResNet18().to(device)
    with torch.no_grad():
        for _ in range(20):
            warmup_model(torch.randn(BATCH_SIZE, 3, 32, 32, device=device))
    del warmup_model
    torch.cuda.empty_cache()
    print('Warmup complete.')
else:
    print('No CUDA device; skipping warmup.')

In [None]:
results = []
checkpoint_file = f'{SAVE_DIR}/checkpoint.json'
completed = set()

# Resume support: runs already recorded in the checkpoint are skipped below.
if os.path.exists(checkpoint_file):
    with open(checkpoint_file, 'r') as f:
        results = json.load(f)
    for r in results:
        completed.add((r['condition'], r['lambda'], r['seed']))
    print(f'Checkpoint loaded: {len(completed)} runs')


def _save_checkpoint(records, path):
    """Write ``records`` as JSON to ``path`` without ever leaving a truncated file.

    The JSON goes to ``path + '.tmp'`` first and is then moved over ``path``
    with ``os.replace``. An interruption mid-write (e.g. a runtime disconnect)
    therefore keeps the previous checkpoint intact and loadable by the resume
    logic above.
    """
    tmp_path = f'{path}.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(records, f, indent=2)
    os.replace(tmp_path, path)


ce_baseline = 0.38  # reference test error of plain CE, used only for the status tag

run_counter = 0
exp_start = time.time()

for cond in CONDITIONS:
    print(f'\n{"="*60}')
    print(f'CONDITION: {cond["name"]}')
    print(f'  value_type={cond["value_type"]}, renorm={cond["renorm"]}, K={cond["K"]}')
    print(f'{"="*60}')

    for lam in LAMBDAS:
        for seed in SEEDS:
            run_counter += 1
            key = (cond['name'], lam, seed)

            if key in completed:
                print(f'[{run_counter}/{total_runs}] {cond["name"]} λ={lam} seed={seed} - SKIPPED')
                continue

            print(f'\n[{run_counter}/{total_runs}] {cond["name"]} λ={lam} seed={seed}')
            t0 = time.time()

            set_seed(seed)
            noisy_labels = inject_noise(clean_labels, NOISE_RATE, seed)
            model = ResNet18().to(device)
            # The loaders use persistent_workers=True. Worker processes, and the
            # RNG state driving RandomCrop/RandomHorizontalFlip inside them,
            # would otherwise be created once and reused by every later run,
            # making augmentation depend on run order and on which runs were
            # skipped on resume. Rebuilding after set_seed ties each run's
            # augmentation to its own seed. Rebinding the names also shuts down
            # the previous loaders' workers.
            train_loader, test_loader = get_data_loaders(trainset, testset)

            final_acc, best_acc, avg_cos, avg_norm = train_dual_gradient_ablation(
                model, train_loader, test_loader,
                clean_labels, noisy_labels, lam,
                value_type=cond['value_type'],
                renorm=cond['renorm'],
                K=cond['K']
            )
            elapsed = time.time() - t0

            results.append({
                'experiment': 'exp_B3_ablation',
                'condition': cond['name'],
                'value_type': cond['value_type'],
                'renorm': cond['renorm'],
                'K': cond['K'],
                'lambda': lam,
                'noise_rate': NOISE_RATE,
                'seed': seed,
                'test_acc': final_acc,
                'test_error': 1 - final_acc,
                'best_test_error': 1 - best_acc,
                'avg_cos_similarity': avg_cos,
                'avg_gmix_norm': avg_norm,
                'time_seconds': elapsed
            })

            _save_checkpoint(results, checkpoint_file)

            # Status tag relative to the CE reference error.
            test_error = 1 - final_acc
            if test_error < ce_baseline * 0.7:
                status = ' ✅ IMPROVED'
            elif test_error > ce_baseline:
                status = ' ⚠️ DEGRADED'
            else:
                status = ' ~ marginal'

            print(f'  Error: {1-final_acc:.4f} | Best: {1-best_acc:.4f} | cos: {avg_cos:.4f} | ||g_mix||: {avg_norm:.4f} | Time: {elapsed/60:.1f} min{status}')

print(f'\n{"="*60}')
print(f'EXPERIMENT COMPLETE')
print(f'Total time: {(time.time()-exp_start)/3600:.2f} hours')
print(f'{"="*60}')

In [None]:
import pandas as pd

with open(f'{SAVE_DIR}/exp_B3_results.json', 'w') as f:
    json.dump(results, f, indent=2)

df = pd.DataFrame(results)

_RULE = '=' * 70

print('\n' + _RULE)
print('ABLATION STUDY SUMMARY')
print(_RULE)

# Per (condition, λ): mean/std of test error across seeds, plus mean diagnostics.
summary = (
    df.groupby(['condition', 'lambda'])
      .agg({
          'test_error': ['mean', 'std'],
          'avg_cos_similarity': 'mean',
          'avg_gmix_norm': 'mean',
      })
      .round(4)
)
print(summary)
summary.to_csv(f'{SAVE_DIR}/exp_B3_summary.csv')


def _mean_test_error(condition, lam):
    """Mean test error over seeds for one (condition, λ) cell of ``df``; NaN if absent."""
    mask = (df['condition'] == condition) & (df['lambda'] == lam)
    return df.loc[mask, 'test_error'].mean()


# Disentangling analysis: each question compares two adjacent conditions.
print('\n' + _RULE)
print('KEY COMPARISONS')
print(_RULE)

print('\n【Q1: 再正規化の効果（暗黙のLR低下を除去）】')
for lam in LAMBDAS:
    baseline = _mean_test_error('Baseline', lam)
    renorm = _mean_test_error('+Renorm', lam)
    print(f'  λ={lam}: Baseline={baseline:.3f}, +Renorm={renorm:.3f}, Δ={renorm-baseline:+.3f}')

print('\n【Q2: キャッシュK=1の効果】')
for lam in LAMBDAS:
    renorm = _mean_test_error('+Renorm', lam)
    renorm_k1 = _mean_test_error('+Renorm+K1', lam)
    print(f'  λ={lam}: +Renorm(K=16)={renorm:.3f}, +Renorm+K1={renorm_k1:.3f}, Δ={renorm_k1-renorm:+.3f}')

print('\n【Q3: 真のAnti-aligned（-g_clean）の効果】')
for lam in LAMBDAS:
    aligned = _mean_test_error('+Renorm+K1', lam)
    anti = _mean_test_error('TrueAnti+K1', lam)
    print(f'  λ={lam}: Aligned={aligned:.3f}, TrueAnti={anti:.3f}, Δ={anti-aligned:+.3f}')
    if anti > 0.38:
        print(f'         ✅ TrueAnti is WORSE than CE baseline (0.38) as expected!')
        print(f'         ✅ TrueAnti is WORSE than CE baseline (0.38) as expected!')

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax1, ax2, ax3 = axes


def _style_axis(ax, ylabel, title):
    """Apply the shared λ x-label, y-label, title, legend and grid to one panel."""
    ax.set_xlabel('λ', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.legend()
    ax.grid(alpha=0.3)


# Panel 1: test error (mean ± std over seeds) per condition, against λ.
for cond in CONDITIONS:
    per_lambda = df[df['condition'] == cond['name']].groupby('lambda')['test_error'].agg(['mean', 'std'])
    ax1.errorbar(per_lambda.index, per_lambda['mean'], yerr=per_lambda['std'], marker='o', capsize=4,
                 linewidth=2, markersize=8, label=cond['name'])
ax1.axhline(y=0.38, color='gray', linestyle='--', alpha=0.5, label='CE baseline')
_style_axis(ax1, 'Test Error', 'Ablation: Test Error by Condition')

# Panels 2-3: per-condition mean of a logged diagnostic, with a reference line.
_DIAGNOSTIC_PANELS = (
    # (axis, column, marker, ref y, ref label, y-label, title)
    (ax2, 'avg_gmix_norm', 's', 1.0, 'Unit norm', '||g_mix||', 'Gradient Norm (Step Size)'),
    (ax3, 'avg_cos_similarity', '^', 0, 'Orthogonal', 'cos(g_struct, g_value)', 'Gradient Alignment'),
)
for ax, column, marker, ref_y, ref_label, ylabel, title in _DIAGNOSTIC_PANELS:
    for cond in CONDITIONS:
        means = df[df['condition'] == cond['name']].groupby('lambda')[column].mean()
        ax.plot(means.index, means.values, marker=marker, linewidth=2, markersize=8, label=cond['name'])
    ax.axhline(y=ref_y, color='gray', linestyle='--', alpha=0.5, label=ref_label)
    _style_axis(ax, ylabel, title)

plt.tight_layout()
figure_path = f'{SAVE_DIR}/exp_B3_ablation_plot.png'
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()

print(f'Figure saved to: {figure_path}')