# Stage 3 모델 비교 (Baseline 동등 조건): Attention vs Transformer MIL

Baseline과 완전히 동일한 조건(데이터/하이퍼파라미터/평가)에서 모델(Architecture)만 달리하여 성능을 비교합니다.

- 데이터: Stage 2에서 생성한 bag (baseline 동일 스냅샷 강제)
- 학습/평가: baseline과 동일 설정 (WeightedBCE, Adam, ReduceLROnPlateau, EarlyStopping 등)
- 비교 모델: AttentionMIL vs TransformerMIL

참고: 이 노트북은 내부 유틸을 재사용하기 위해 `experiments/arcface/agent/stage3_baseline_transformer.py` 모듈을 임포트하여 실행합니다.

In [None]:
# 모듈 임포트 및 환경 확인
import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import experiments.arcface.agent.stage3_baseline_transformer as exp

# Report which device the experiment module selected.
print('Using device:', exp.device)
if exp.device.type == 'cuda':
    # Ask for the name of the device actually in use instead of hard-coding
    # index 0, which would print the wrong GPU when exp.device is e.g. cuda:1.
    print('GPU:', torch.cuda.get_device_name(exp.device))

# Apply the same seed (42) that is re-applied before each model is built below.
exp.seed_everything(42)

# Load the same data as the baseline. Per the original note, the loader also
# validates the baseline data size -- that check is inside exp and not visible here.
train_loader, val_loader, test_loader = exp.load_data_loaders(batch_size=16)
print('Data loaders ready.')


In [None]:
# Training configuration (kept identical to the baseline).
criterion = exp.WeightedBCE(fp_weight=2.0)
learning_rate = 1e-3
max_epochs = 10
patience = 3            # early-stopping patience handed to exp.train_model
scheduler_patience = 1  # patience of the ReduceLROnPlateau scheduler

# Per-model evaluation outputs and training histories, keyed by display name.
results, histories = {}, {}

# Banner summarising the shared experimental setup.
print('🔬 모델 비교 실험 (Baseline과 완전 동일 조건)')
print('=' * 60)
print('손실 함수: WeightedBCE(fp_weight=2.0)')
print(f'학습률: {learning_rate}')
print(f'최대 에포크: {max_epochs}, Patience: {patience}')
print(f'Scheduler Patience: {scheduler_patience}')
print('=' * 60)

# ---------------------------------------------------------------------------
# 1) AttentionMIL
# ---------------------------------------------------------------------------
# Re-seed right before building the model so both architectures are created
# and trained from the same seed (presumably resets every RNG -- see exp).
exp.seed_everything(42)
att_model = exp.AttentionMIL(
    input_dim=256,
    hidden_dim=128,
    dropout_p=0.1,
).to(exp.device)
att_opt = torch.optim.Adam(att_model.parameters(), lr=learning_rate)
att_sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
    att_opt,
    mode='max',
    factor=0.5,
    patience=scheduler_patience,
    verbose=True,
)
att_model, att_hist = exp.train_model(
    att_model,
    att_opt,
    att_sch,
    train_loader,
    val_loader,
    criterion,
    max_epochs=max_epochs,
    patience=patience,
    name='attention_mil',
)
att_val = exp.evaluate(att_model, val_loader, criterion)
att_tst = exp.evaluate(att_model, test_loader, criterion)
results['Attention'] = dict(val=att_val, test=att_tst)
histories['Attention'] = att_hist

# ---------------------------------------------------------------------------
# 2) TransformerMIL
# ---------------------------------------------------------------------------
# Same seed, optimizer, scheduler and training call as above; only the
# architecture differs.
exp.seed_everything(42)
tr_model = exp.TransformerMIL(
    input_dim=256,
    hidden_dim=128,
    num_heads=4,
    num_layers=2,
    dropout_p=0.1,
).to(exp.device)
tr_opt = torch.optim.Adam(tr_model.parameters(), lr=learning_rate)
tr_sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
    tr_opt,
    mode='max',
    factor=0.5,
    patience=scheduler_patience,
    verbose=True,
)
tr_model, tr_hist = exp.train_model(
    tr_model,
    tr_opt,
    tr_sch,
    train_loader,
    val_loader,
    criterion,
    max_epochs=max_epochs,
    patience=patience,
    name='transformer_mil',
)
tr_val = exp.evaluate(tr_model, val_loader, criterion)
tr_tst = exp.evaluate(tr_model, test_loader, criterion)
results['Transformer'] = dict(val=tr_val, test=tr_tst)
histories['Transformer'] = tr_hist

print('✅ 두 모델 학습/평가 완료')


In [None]:
# Final report (apply the validation-optimal threshold to the test set)
def find_best_threshold(probs, labels):
    """Notebook-local wrapper that delegates to ``exp.find_best_threshold``.

    Both arguments are forwarded unchanged and the delegate's return value is
    returned as-is. The report loop in this cell unpacks that result as a
    ``(threshold, best validation F1)`` pair.
    """
    return exp.find_best_threshold(probs, labels)

# Per-model test metrics at the validation-tuned threshold, keyed by model name.
final_results = {}
print('\n📊 모델별 최종 성능 비교')
print('='*80)
for name, res in results.items():
    val_res, tst_res = res['val'], res['test']
    # Pick the decision threshold on validation data only, then apply it
    # unchanged to the test probabilities so the test set is never tuned on.
    thr, best_f1_val = find_best_threshold(val_res['probs'], val_res['labels'])
    # NOTE(review): assumes tst_res['probs'] is a NumPy array (element-wise
    # `>=` and `.astype` are required) -- confirm against exp.evaluate.
    test_preds_adj = (tst_res['probs'] >= thr).astype(int)
    # zero_division=0 makes the score 0 instead of emitting a warning when
    # the denominator is empty (e.g. no positive predictions).
    acc = accuracy_score(tst_res['labels'], test_preds_adj)
    f1 = f1_score(tst_res['labels'], test_preds_adj, zero_division=0)
    prec = precision_score(tst_res['labels'], test_preds_adj, zero_division=0)
    rec = recall_score(tst_res['labels'], test_preds_adj, zero_division=0)
    # AUC is read from the evaluation result rather than recomputed here.
    auc_v = tst_res['auc']
    final_results[name] = {
        'threshold': thr, 'accuracy': acc, 'f1': f1, 'precision': prec, 'recall': rec, 'auc': auc_v,
        'test_probs': tst_res['probs'], 'test_labels': tst_res['labels'], 'test_preds_adj': test_preds_adj,
    }
    print(f"\n{name}:")
    print(f'  최적 임계값: {thr:.3f} (Val F1: {best_f1_val:.3f})')
    print(f'  Test Accuracy: {acc:.3f}')
    print(f'  Test F1: {f1:.3f}')
    print(f'  Test Precision: {prec:.3f}')
    print(f'  Test Recall: {rec:.3f}')
    print(f'  Test AUC: {auc_v:.3f}')

# Fixed-width summary table, one row per model.
print('\n' + '='*80)
print('📈 모델 성능 요약 테이블')
print('='*80)
print(f"{'Model':<15} {'Accuracy':<10} {'F1':<8} {'Precision':<11} {'Recall':<8} {'AUC':<8}")
print('-'*80)
for name, r in final_results.items():
    print(f"{name:<15} {r['accuracy']:<10.3f} {r['f1']:<8.3f} {r['precision']:<11.3f} {r['recall']:<8.3f} {r['auc']:<8.3f}")

# Best model per metric as a (name, metrics) pair; (None, None) when there
# are no results, so the prints below are skipped instead of max() raising.
best_auc = max(final_results.items(), key=lambda x: x[1]['auc']) if final_results else (None, None)
best_f1  = max(final_results.items(), key=lambda x: x[1]['f1']) if final_results else (None, None)
print('\n🏆 최고 성능:')
if best_auc[0] is not None:
    print(f"  AUC 기준: {best_auc[0]} (AUC: {best_auc[1]['auc']:.3f})")
if best_f1[0] is not None:
    print(f"  F1 기준:  {best_f1[0]} (F1: {best_f1[1]['f1']:.3f})")
