# Stage 3 Final DS-MIL: Fair Comparison with Baseline

이 노트북은 Stage 2에서 생성한 MIL Bag 데이터를 입력으로 받아 DS-MIL (Dual-Stream Multiple Instance Learning) 모델을 학습하고 baseline AttentionMIL과 공정한 비교를 수행합니다.

## 실험 목적
- **공정한 비교**: baseline AttentionMIL과 동일한 조건에서 DS-MIL 성능 평가
- **DS-MIL 효과 검증**: dual-stream architecture가 단일 attention 방식보다 우수한지 확인
- **아키텍처 분석**: critical instance identification과 cross-attention의 효과 분석

## 실험 조건 (Baseline과 동일)
1. **데이터셋**: `bags_arcface_margin_0.4_50p_random_*.pkl` (baseline과 동일)
2. **손실함수**: WeightedBCE(fp_weight=2.0)
3. **최적화**: Adam optimizer (lr=1e-3), ReduceLROnPlateau scheduler (validation AUC 기준, factor=0.5, patience=1)
4. **임계값 최적화**: validation set에서 F1 score 기준으로 최적 threshold 선택
5. **평가 지표**: Accuracy, F1, Precision, Recall, AUC

## DS-MIL vs AttentionMIL 아키텍처 차이점

### AttentionMIL (Baseline)
```
instances → attention → bag_representation → classifier → prediction
```

### DS-MIL (This Work)
```
instances → instance_classifier → critical_instance_identification
         ↓
         cross_attention → bag_representation → bag_classifier → prediction
```

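**핵심 차이점:**
- **Dual Loss**: bag-level loss와 instance-level loss(각 bag의 최대 instance logit에 적용)를 합산하여 동시에 학습
- **Critical Instance**: instance classifier의 logit이 가장 높은 instance를 각 bag의 critical instance로 자동 선택
- **Cross-Attention**: critical instance의 query와 각 instance의 query 사이의 scaled dot-product로 attention 가중치를 계산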

In [1]:
# 환경 설정
import os
import random
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, roc_curve, confusion_matrix
from tqdm import tqdm

# GPU selection: enumerate devices in PCI bus order and expose only the GPU index
# given by the MIL_STAGE3_GPU environment variable (default '3') to this process.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': os.environ.get('MIL_STAGE3_GPU', '3'),
})

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
if device.type != 'cuda':
    print('CUDA를 사용할 수 없습니다. CPU 모드로 실행됩니다.')
else:
    # Index 0 is the first device left visible by CUDA_VISIBLE_DEVICES.
    print('GPU:', torch.cuda.get_device_name(0))

# Fix the seeds of every RNG used by the notebook (Python, NumPy, PyTorch).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


Using device: cuda
GPU: NVIDIA GeForce RTX 3090


In [None]:
# Load the Stage 2 bag files (the same "random" dataset files as the baseline).
embedding_margin = '0.4'
bags_dir = '/workspace/MIL/data/processed/bags'
train_pkl, val_pkl, test_pkl = (
    os.path.join(bags_dir, f'bags_arcface_margin_{embedding_margin}_50p_random_{split}.pkl')
    for split in ('train', 'val', 'test')
)

print('Loading MIL bags...')
try:
    # NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
    with open(train_pkl, 'rb') as f:
        train_data = pickle.load(f)
    with open(val_pkl, 'rb') as f:
        val_data = pickle.load(f)
    with open(test_pkl, 'rb') as f:
        test_data = pickle.load(f)
except Exception as e:
    # Report the failure, then propagate it; a missing file gets its own message.
    if isinstance(e, FileNotFoundError):
        print(f'Error loading data files: {e}')
    else:
        print(f'Unexpected error loading data: {e}')
    raise

# Instance-mean features: e.g. a (10, 5, 256) bag becomes a (10, 256) array.
# Averaging over axis 1 gives a simpler per-instance representation.
def to_instance_means(bags):
    """Average every bag over its axis 1 and return the results as float32 arrays.

    Args:
        bags: iterable of arrays exposing ``mean(axis=1)`` and ``astype``.

    Returns:
        list: one float32 array per input bag, with axis 1 averaged away.
    """
    target_dtype = np.float32
    return [instances.mean(axis=1).astype(target_dtype) for instances in bags]

# Per-split features (instance means) and labels.
train_features, val_features, test_features = (
    to_instance_means(split['bags']) for split in (train_data, val_data, test_data)
)
train_labels, val_labels, test_labels = (
    split['labels'] for split in (train_data, val_data, test_data)
)

print(f'Train bags: {len(train_labels)}, Val bags: {len(val_labels)}, Test bags: {len(test_labels)}')
# np.bincount reports how many bags carry each (non-negative integer) label.
for split_name, split_labels in (('train', train_labels), ('val', val_labels), ('test', test_labels)):
    print(f'Class distribution in {split_name}: {np.bincount(split_labels)}')

In [3]:
# Dataset 클래스 (on‑the‑fly Tensor 변환)

class MILDataset(Dataset):
    """Dataset over (features, label) pairs that converts to tensors on access.

    ``features`` and ``labels`` are stored as given; each ``__getitem__`` call
    builds fresh float32 tensors for the requested index.
    """

    def __init__(self, features, labels):
        # features: list of np.ndarray, labels: list of int
        self.features, self.labels = features, labels

    def __len__(self):
        """Return the number of bags (taken from the label sequence)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for bag ``idx``, both as float32 tensors."""
        bag, target = self.features[idx], self.labels[idx]
        return (
            torch.tensor(bag, dtype=torch.float32),
            torch.tensor(target, dtype=torch.float32),
        )

batch_size = 16
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(
    dataset=MILDataset(train_features, train_labels),
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=MILDataset(val_features, val_labels),
    batch_size=batch_size,
    shuffle=False,
)
test_loader = DataLoader(
    dataset=MILDataset(test_features, test_labels),
    batch_size=batch_size,
    shuffle=False,
)


In [None]:
# 모델 정의 및 초기화 함수 – DSMIL (Dual‑Stream Multiple Instance Learning)

class DSMIL(nn.Module):
    """
    Dual-Stream MIL model based on the DSMIL architecture.

    Two streams are computed for every bag:
    1. Instance stream: a linear classifier scores each instance; the
       highest-scoring instance is taken as the bag's critical instance.
    2. Bag stream: every instance is compared with the critical instance
       (scaled dot-product of their query vectors), the resulting attention
       weights pool the value vectors, and a linear layer maps the pooled
       representation to one bag logit.

    ``forward`` returns the logits of both streams so that a caller can
    combine a bag-level and an instance-level loss.

    Args:
        input_dim: dimensionality of each instance feature vector.
        att_dim: dimensionality of the query vectors used for attention.
        dropout: dropout probability applied at the input of the value network.
    """
    def __init__(self, input_dim=256, att_dim=128, dropout=0.1):
        super().__init__()
        # Instance classifier: predicts a score per instance
        self.instance_fc = nn.Linear(input_dim, 1)
        # Query network for attention
        self.q_net = nn.Sequential(
            nn.Linear(input_dim, att_dim),
            nn.ReLU(),
            nn.Linear(att_dim, att_dim),
            nn.Tanh(),
        )
        # Value network (with dropout)
        self.v_net = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
        )
        # Bag classifier: maps aggregated representation to a single logit
        self.bag_fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        """
        Forward pass.
        Args:
            x (Tensor): bag of instance features with shape (batch_size, num_instances, input_dim)
        Returns:
            bag_logits (Tensor): shape (batch_size,), logits for bag labels
            inst_logits (Tensor): shape (batch_size, num_instances), logits for instances
            att_weights (Tensor): shape (batch_size, num_instances), attention weights
                (softmax over the instance axis, so each row sums to 1)
        """
        # Instance logits for each instance in the bag
        inst_logits = self.instance_fc(x).squeeze(-1)  # (batch_size, num_instances)

        # Identify the critical instance for each bag via the highest instance logit.
        # argmax yields integer indices, so the selection itself carries no gradient.
        top_indices = torch.argmax(inst_logits, dim=1)  # (batch_size,)

        # Compute query vectors for all instances
        Q = self.q_net(x)  # (batch_size, num_instances, att_dim)

        # The critical instance's query is already one of the rows of Q (q_net is a
        # deterministic per-instance map with no dropout), so gather it instead of
        # running q_net a second time on the selected features.
        batch_indices = torch.arange(x.size(0), device=x.device)
        q_max = Q[batch_indices, top_indices]  # (batch_size, att_dim)

        # Compute attention scores via the inner product between Q and q_max
        att_scores = torch.bmm(Q, q_max.unsqueeze(-1)).squeeze(-1)  # (batch_size, num_instances)

        # Normalize attention scores (scaled by sqrt(att_dim) before the softmax)
        att_weights = torch.softmax(att_scores / (Q.size(-1) ** 0.5), dim=1)  # (batch_size, num_instances)

        # Compute value representations
        V = self.v_net(x)  # (batch_size, num_instances, input_dim)

        # Aggregate the values using attention weights
        bag_repr = torch.bmm(att_weights.unsqueeze(1), V).squeeze(1)  # (batch_size, input_dim)

        # Predict bag logits
        bag_logits = self.bag_fc(bag_repr).squeeze(-1)  # (batch_size,)

        return bag_logits, inst_logits, att_weights

class MeanPoolingModel(nn.Module):
    """
    Minimal baseline: average the instance features of a bag and map the
    averaged vector to one logit with a single linear layer.
    """
    def __init__(self, input_dim=256):
        super().__init__()
        head = nn.Linear(input_dim, 1)
        # Xavier-uniform weights and a zero bias for the only layer.
        nn.init.xavier_uniform_(head.weight)
        nn.init.zeros_(head.bias)
        self.fc = head

    def forward(self, x):
        """Return one logit per bag for ``x`` of shape (batch, instances, input_dim)."""
        pooled = x.mean(dim=1)  # average over the instance axis
        return self.fc(pooled).squeeze(-1)

# Model instances
mil_model = DSMIL(input_dim=256, att_dim=128, dropout=0.1).to(device)
base_model = MeanPoolingModel(input_dim=256).to(device)

# Loss function and optimisers
# Note: A Weighted BCE loss will be defined in the training pipeline later
criterion = nn.BCEWithLogitsLoss()
optimizer_mil  = torch.optim.Adam(mil_model.parameters(), lr=1e-3)
optimizer_base = torch.optim.Adam(base_model.parameters(), lr=1e-3)
# mode='max': the learning rate is halved (factor=0.5) once the metric passed to
# scheduler.step() has not increased for more than `patience` epochs.
# The `verbose` argument is intentionally not passed: it is deprecated since
# PyTorch 2.2 and not accepted by newer releases. Read the current rate from
# optimizer.param_groups (or scheduler.get_last_lr()) instead.
scheduler_mil  = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_mil, mode='max', factor=0.5, patience=1)
scheduler_base = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_base, mode='max', factor=0.5, patience=1)

In [None]:
# 학습 및 평가 함수 (Early Stopping 포함)

def train_one_epoch(model, optimizer, loader):
    """Train ``model`` for one pass over ``loader``.

    A ``DSMIL`` model is optimised with the sum of two losses: the bag-level
    loss and the loss on each bag's maximum instance logit. Any other model is
    optimised with ``criterion(logits, y)``. Uses the module-level
    ``criterion`` and ``device``.

    Returns:
        tuple: (sample-weighted mean loss over ``len(loader.dataset)``,
        accuracy of the 0.5-thresholded bag predictions).
    """
    model.train()
    running_loss = 0.0
    epoch_preds, epoch_labels = [], []

    for X, y in tqdm(loader, desc='Train', leave=False):
        X = X.to(device)
        y = y.to(device)
        optimizer.zero_grad()

        try:
            if not isinstance(model, DSMIL):
                # Other models: a model exposing `instance_fc` returns a tuple whose
                # first element is the bag logits; otherwise the output is the logits.
                output = model(X)
                logits = output[0] if hasattr(model, 'instance_fc') else output
                loss = criterion(logits, y)
            else:
                # Bag-level predictions come from the bag logits.
                logits, inst_logits, _ = model(X)
                # Logit of the critical instance = maximum instance logit per bag.
                critical_logits, _ = inst_logits.max(dim=1)
                # Dual loss: bag-level + instance-level.
                loss = criterion(logits, y) + criterion(critical_logits, y)

            loss.backward()
            # Clip the gradient norm to stabilise training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running_loss += loss.item() * y.size(0)
            hard_preds = torch.sigmoid(logits).ge(0.5).float()
            epoch_preds.extend(hard_preds.cpu().numpy())
            epoch_labels.extend(y.cpu().numpy())

        except Exception as e:
            print(f'Error in training step: {e}')
            raise

    mean_loss = running_loss / len(loader.dataset)
    return mean_loss, accuracy_score(epoch_labels, epoch_preds)

def evaluate(model, loader):
    """Evaluate ``model`` on every batch of ``loader`` without gradient tracking.

    For a ``DSMIL`` model the loss is the same dual loss used for training
    (bag-level loss plus the loss on each bag's maximum instance logit); for any
    other model it is ``criterion(logits, y)``. Hard predictions use a fixed 0.5
    probability threshold. Uses the module-level ``criterion`` and ``device``.

    Returns:
        tuple: (loss, accuracy, auc, f1, probabilities, labels, predictions)
        where ``loss`` is the sample-weighted total divided by
        ``len(loader.dataset)`` and the last three entries are NumPy arrays.
        AUC is reported as 0.0 when the labels contain a single class (it is
        undefined there). F1 is 0.0 only when it is genuinely undefined or zero.
    """
    model.eval()
    total_loss = 0.0
    probs_all = []
    preds_all = []
    labels_all = []

    with torch.no_grad():
        for X, y in tqdm(loader, desc='Eval', leave=False):
            X, y = X.to(device), y.to(device)

            try:
                if isinstance(model, DSMIL):
                    bag_logits, inst_logits, _ = model(X)
                    top_logits = inst_logits.max(dim=1).values

                    # Same loss as in training so train/val losses are comparable
                    loss_bag = criterion(bag_logits, y)
                    loss_top = criterion(top_logits, y)
                    loss = loss_bag + loss_top

                    logits = bag_logits
                else:
                    logits = model(X)[0] if hasattr(model, 'instance_fc') else model(X)
                    loss = criterion(logits, y)

                total_loss += loss.item() * y.size(0)
                probs = torch.sigmoid(logits)
                preds = (probs >= 0.5).float()

                probs_all.extend(probs.cpu().numpy())
                preds_all.extend(preds.cpu().numpy())
                labels_all.extend(y.cpu().numpy())

            except Exception as e:
                print(f'Error in evaluation step: {e}')
                raise

    # Metrics
    acc = accuracy_score(labels_all, preds_all)
    auc = roc_auc_score(labels_all, probs_all) if len(set(labels_all)) > 1 else 0.0
    # Do not zero out F1 just because every prediction falls in one class: an
    # all-positive predictor has a well-defined, non-zero F1. zero_division=0
    # covers the undefined case (no positive predictions and no positive labels).
    f1 = f1_score(labels_all, preds_all, zero_division=0)

    return total_loss / len(loader.dataset), acc, auc, f1, np.array(probs_all), np.array(labels_all), np.array(preds_all)

def train_model(model, optimizer, scheduler, train_loader, val_loader, max_epochs=10, patience=3, name='model'):
    """Full training loop with early stopping on validation AUC.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader``, and steps
    ``scheduler`` with the validation AUC. Whenever the validation AUC improves,
    the weights are snapshotted in memory and written to ``best_{name}.pth`` in
    the current working directory. Training stops after ``patience`` consecutive
    epochs without improvement or after ``max_epochs`` epochs.

    Args:
        model: model to train (updated in place).
        optimizer: optimizer bound to ``model``'s parameters.
        scheduler: LR scheduler whose ``step`` accepts the monitored metric.
        train_loader: loader for the training split.
        val_loader: loader for the validation split.
        max_epochs: maximum number of epochs.
        patience: early-stopping patience in epochs.
        name: label used in log messages and in the checkpoint file name.

    Returns:
        The same ``model`` object, with the best-AUC weights loaded when at least
        one epoch improved on the initial best AUC of 0.0.
    """
    best_auc = 0.0
    best_state = None
    epochs_no_improve = 0

    print(f"\nStarting training for {name}...")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    for epoch in range(1, max_epochs+1):
        print(f"\nEpoch {epoch}/{max_epochs} – {name}")
        tr_loss, tr_acc = train_one_epoch(model, optimizer, train_loader)
        val_loss, val_acc, val_auc, val_f1, _, _, _ = evaluate(model, val_loader)
        print(f"  Train Loss: {tr_loss:.4f}, Acc: {tr_acc:.4f}")
        print(f"  Val   Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}, F1: {val_f1:.4f}")

        # Report learning-rate changes here: the scheduler's own `verbose`
        # printing is deprecated in recent PyTorch releases.
        lr_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_auc)
        lr_after = [group['lr'] for group in optimizer.param_groups]
        if lr_after != lr_before:
            print(f"  Learning rate changed: {lr_before} -> {lr_after}")

        if val_auc > best_auc:
            best_auc = val_auc
            # clone() is essential: on a CPU model `.cpu()` returns the very same
            # tensor, so without a copy the snapshot would share storage with the
            # live parameters and silently follow every later optimizer step.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, f'best_{name}.pth')
            print(f"  ✅ New best AUC: {best_auc:.4f} – model saved.")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"  No improvement. Patience: {epochs_no_improve}/{patience}")
            if epochs_no_improve >= patience:
                print("  🛑 Early stopping triggered.")
                break

    # Restore the best snapshot (load_state_dict copies it onto the model's device).
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"\nLoaded best model with AUC: {best_auc:.4f}")

    return model

In [None]:
# ==============================================================================
# Final Pipeline: DS-MIL training with validation‑based threshold search
#
# 이 셀은 baseline AttentionMIL과 동일한 조건으로 DS-MIL 모델을 학습합니다:
# 1. 동일한 데이터셋 사용 (random split)
# 2. 동일한 손실함수 (WeightedBCE, fp_weight=2.0)
# 3. 동일한 threshold 최적화 방식 (F1 기준)
# 
# DS-MIL의 핵심 특징:
# - Dual-stream architecture: bag-level + instance-level classification
# - Critical instance identification을 통한 attention mechanism
# - Dual loss로 더 robust한 학습
# ==============================================================================

import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_curve, auc, confusion_matrix

# Weighted BCE loss - same setting as the baseline
class WeightedBCE(nn.Module):
    """BCE-with-logits loss that up-weights samples whose label is 0.

    Each per-sample loss is multiplied by ``1 + fp_weight`` when the label
    equals 0 and by 1 otherwise; the weighted losses are then averaged. With
    ``fp_weight=2.0`` a label-0 sample therefore counts three times as much.
    """
    def __init__(self, fp_weight=2.0):
        super().__init__()
        self.fp_weight = fp_weight
        # reduction='none' keeps per-sample losses so they can be re-weighted.
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, labels):
        """Return the mean of the re-weighted per-sample BCE losses."""
        per_sample = self.bce(logits, labels)
        is_negative = labels.eq(0).float()
        weights = 1 + self.fp_weight * is_negative
        return (per_sample * weights).mean()

# Use Weighted BCE as the criterion for the DSMIL model.
# NOTE: this rebinds the module-level `criterion` that train_one_epoch() and
# evaluate() read, so it must run before training starts.
criterion = WeightedBCE(fp_weight=2.0)
print("Using WeightedBCE loss with fp_weight=2.0 (same as baseline)")

# Initialise a fresh DSMIL model
mil_model_final = DSMIL(input_dim=256, att_dim=128, dropout=0.1).to(device)
optimizer_final = torch.optim.Adam(mil_model_final.parameters(), lr=1e-3)
# mode='max' because train_model() steps the scheduler with the validation AUC.
# `verbose` is intentionally not passed: it is deprecated since PyTorch 2.2 and
# not accepted by newer releases.
scheduler_final = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_final, mode='max', factor=0.5, patience=1
)

# Train the model
mil_model_final = train_model(
    mil_model_final, optimizer_final, scheduler_final,
    train_loader, val_loader, max_epochs=10, patience=3, name='dsmil_final'
)

# Evaluate on validation and test (the last returned element, the 0.5-threshold
# predictions, is discarded; thresholded predictions are recomputed later)
val_loss_final, val_acc_final, val_auc_final, val_f1_final, val_probs_final, val_labels_final, _ = evaluate(
    mil_model_final, val_loader
)
test_loss_final, test_acc_final, test_auc_final, test_f1_final, test_probs_final, test_labels_final, _ = evaluate(
    mil_model_final, test_loader
)

# Threshold search maximising F1 - same procedure as the baseline
def find_best_threshold(probs, labels):
    """Return ``(threshold, f1)`` for the threshold with the highest F1 score.

    37 evenly spaced thresholds in [0.05, 0.95] are tried (same range as the
    baseline). Ties go to the lowest threshold. If no threshold reaches an F1
    above 0, the fallback ``(0.5, 0.0)`` is returned.
    """
    candidates = np.linspace(0.05, 0.95, 37)
    scores = [
        f1_score(labels, (probs >= cut).astype(int), zero_division=0)
        for cut in candidates
    ]
    # np.argmax returns the first index of the maximum, i.e. the lowest threshold.
    top = int(np.argmax(scores))
    if scores[top] > 0.0:
        return candidates[top], scores[top]
    return 0.5, 0.0

# Determine the best threshold on the validation set (never on the test set)
best_thr_final, best_f1_valid = find_best_threshold(val_probs_final, val_labels_final)
print(f'Best validation F1 threshold: {best_thr_final:.3f} (F1={best_f1_valid:.3f})')

# Apply the validation-selected threshold to the test set
test_preds_adj_final = (test_probs_final >= best_thr_final).astype(int)
acc_final = accuracy_score(test_labels_final, test_preds_adj_final)
f1_final = f1_score(test_labels_final, test_preds_adj_final, zero_division=0)
prec_final = precision_score(test_labels_final, test_preds_adj_final, zero_division=0)
recall_final = recall_score(test_labels_final, test_preds_adj_final, zero_division=0)

# The baseline numbers below are hard-coded reference values, not recomputed here.
print('\n' + '='*60)
print('DS-MIL FINAL RESULTS (vs Baseline Comparison)')
print('='*60)
print('Final test metrics (Weighted BCE + optimised threshold):')
print(f'  Accuracy: {acc_final:.3f}, F1: {f1_final:.3f}, Precision: {prec_final:.3f}, Recall: {recall_final:.3f}, AUC: {test_auc_final:.3f}')
print('\nBaseline AttentionMIL Results (for comparison):')
print('  Accuracy: 0.792, F1: 0.759, Precision: 0.750, Recall: 0.768, AUC: 0.829')
print('\nPerformance Comparison:')
print(f'  AUC improvement: {test_auc_final:.3f} - 0.829 = {test_auc_final - 0.829:+.3f}')
print(f'  F1 improvement: {f1_final:.3f} - 0.759 = {f1_final - 0.759:+.3f}')
print(f'  Accuracy improvement: {acc_final:.3f} - 0.792 = {acc_final - 0.792:+.3f}')

# Confusion matrix (rows: true label, columns: predicted label)
cm_final = confusion_matrix(test_labels_final.astype(int), test_preds_adj_final.astype(int), labels=[0,1])
plt.figure(figsize=(4,3))
sns.heatmap(
    cm_final, annot=True, fmt='d', cmap='Blues',
    xticklabels=['Genuine','Forged'], yticklabels=['Genuine','Forged']
)
plt.title(f'DS-MIL Confusion Matrix (Thr={best_thr_final:.2f})')
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.tight_layout()
plt.show()

# ROC curve with baseline comparison
fpr_final, tpr_final, _ = roc_curve(test_labels_final, test_probs_final)
auc_final_value = auc(fpr_final, tpr_final)
plt.figure(figsize=(6,5))
plt.plot(fpr_final, tpr_final, color='blue', linewidth=2, 
         label=f'DS-MIL (AUC={auc_final_value:.3f})')
# Only the baseline's scalar AUC is known here, not its ROC curve. An AUC is an
# area, not a true-positive rate, so drawing it as a horizontal line on these axes
# would be misleading. Show it as a legend-only entry (an empty line) instead.
plt.plot([], [], color='red', linestyle='--', alpha=0.7,
         label='Baseline AttentionMIL (AUC=0.829)')
plt.plot([0,1],[0,1],'k--', alpha=0.5, label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve – DS-MIL vs Baseline Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Summary of improvements
print('\n' + '='*60)
print('PERFORMANCE ANALYSIS SUMMARY')
print('='*60)
if test_auc_final > 0.829:
    print("✅ DS-MIL shows IMPROVEMENT over baseline AttentionMIL")
    print(f"   - AUC improved by {test_auc_final - 0.829:.3f} points")
else:
    print("❌ DS-MIL shows DECLINE compared to baseline AttentionMIL")
    print(f"   - AUC decreased by {0.829 - test_auc_final:.3f} points")

if f1_final > 0.759:
    print(f"✅ F1 Score improved by {f1_final - 0.759:.3f} points")
else:
    print(f"❌ F1 Score decreased by {0.759 - f1_final:.3f} points")

print(f"\nKey factors in DS-MIL architecture:")
print(f"- Dual-stream learning (bag + instance level)")
print(f"- Critical instance identification")
print(f"- Cross-attention mechanism")
print(f"- Combined loss function (bag_loss + instance_loss)")

In [None]:
# 실험 결과 저장 및 요약
import json
from datetime import datetime

# Create the output directories for result and model files
import os
for out_dir in ('/workspace/MIL/output/results', '/workspace/MIL/output/models'):
    os.makedirs(out_dir, exist_ok=True)

# Baseline AttentionMIL test metrics (hard-coded reference values) and the
# corresponding DS-MIL differences, computed once and reused below.
baseline_auc, baseline_f1, baseline_acc = 0.829, 0.759, 0.792
auc_gain = test_auc_final - baseline_auc
f1_gain = f1_final - baseline_f1
acc_gain = acc_final - baseline_acc
n_params = sum(p.numel() for p in mil_model_final.parameters())

# Assemble the experiment record section by section
data_info = {
    'dataset': f'bags_arcface_margin_{embedding_margin}_50p_random',
    'train_samples': len(train_labels),
    'val_samples': len(val_labels), 
    'test_samples': len(test_labels),
    'train_pos_ratio': np.mean(train_labels),
    'val_pos_ratio': np.mean(val_labels),
    'test_pos_ratio': np.mean(test_labels),
}
model_config = {
    'input_dim': 256,
    'attention_dim': 128,
    'dropout': 0.1,
    'total_parameters': n_params,
}
training_info = {
    'loss_function': 'WeightedBCE',
    'fp_weight': 2.0,
    'optimizer': 'Adam',
    'learning_rate': 1e-3,
    'scheduler': 'ReduceLROnPlateau',
    'max_epochs': 10,
    'patience': 3,
    'best_val_auc': float(val_auc_final),
}
validation_metrics = {
    'loss': float(val_loss_final),
    'accuracy': float(val_acc_final),
    'auc': float(val_auc_final),
    'f1': float(val_f1_final),
}
# Test accuracy/F1/precision/recall are the values at the validation-selected threshold
test_metrics = {
    'loss': float(test_loss_final),
    'accuracy': float(acc_final),
    'auc': float(test_auc_final), 
    'f1': float(f1_final),
    'precision': float(prec_final),
    'recall': float(recall_final),
    'best_threshold': float(best_thr_final),
}
baseline_comparison = {
    'baseline_auc': baseline_auc,
    'baseline_f1': baseline_f1,
    'baseline_accuracy': baseline_acc,
    'auc_improvement': float(auc_gain),
    'f1_improvement': float(f1_gain),
    'accuracy_improvement': float(acc_gain),
}
results = {
    'experiment': 'DS-MIL vs AttentionMIL Comparison',
    'timestamp': datetime.now().isoformat(),
    'model': 'DS-MIL',
    'data': data_info,
    'model_config': model_config,
    'training': training_info,
    'results': {
        'validation': validation_metrics,
        'test': test_metrics,
        'baseline_comparison': baseline_comparison,
    },
}

# Save the record as JSON
results_file = '/workspace/MIL/output/results/dsmil_vs_baseline_results.json'
with open(results_file, 'w') as f:
    json.dump(results, f, indent=2)

# Save the model weights together with its config and metrics
model_file = '/workspace/MIL/output/models/dsmil_final_model.pth'
checkpoint = {
    'model_state_dict': mil_model_final.state_dict(),
    'model_config': results['model_config'],
    'results': results['results'],
}
torch.save(checkpoint, model_file)

print(f"\n📁 Results saved to: {results_file}")
print(f"🔧 Model saved to: {model_file}")

# Final summary
banner = '=' * 80
print(f"\n{banner}")
print(f"EXPERIMENT COMPLETED: DS-MIL vs Baseline AttentionMIL")
print(f"{banner}")
print(f"Dataset: ArcFace margin {embedding_margin}, 50% positive, random split")
print(f"DS-MIL Parameters: {n_params:,}")
print(f"Training completed with best validation AUC: {val_auc_final:.4f}")
print(f"\nFINAL TEST RESULTS:")
print(f"  DS-MIL:     AUC={test_auc_final:.3f}, F1={f1_final:.3f}, Acc={acc_final:.3f}")
print(f"  Baseline:   AUC=0.829, F1=0.759, Acc=0.792")
print(f"  Difference: AUC={auc_gain:+.3f}, F1={f1_gain:+.3f}, Acc={acc_gain:+.3f}")

if test_auc_final > baseline_auc:
    print(f"🎉 DS-MIL OUTPERFORMS baseline by {auc_gain:.3f} AUC points!")
else:
    print(f"📉 DS-MIL underperforms baseline by {baseline_auc - test_auc_final:.3f} AUC points.")
    
print(f"{banner}")