In [None]:
import torch
import torchvision
import torchvision.transforms as tvt
from torch.utils.data import DataLoader
import numpy as np
import random
import warnings
warnings.filterwarnings('ignore')
from Fairnet import FairNetViT,FairNetCelebATrainer,evaluate_model,print_metrics,BaselineViT

def seed_everything(seed):
    """Seed every random number generator used here and make cuDNN deterministic.

    Args:
        seed: Value passed unchanged to the Python ``random`` module, NumPy,
            and PyTorch (CPU generator and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Device configuration
# Prefer the second GPU ('cuda:1') as before, but only when it actually exists:
# on a single-GPU host 'cuda:1' is an invalid device ordinal and the first
# transfer to it fails, so fall back to 'cuda:0' there. Without CUDA, use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
seed_everything(45)

# Dataset parameters
image_size = 64
batch_size = 128


def _celeba_transform():
    """Return a fresh preprocessing pipeline: resize, to-tensor, normalize.

    Images are resized to (image_size, image_size), converted to tensors, and
    normalized per channel with mean 0.5 and std 0.5.
    """
    steps = [
        tvt.Resize((image_size, image_size)),
        tvt.ToTensor(),
        tvt.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return tvt.Compose(steps)


# Data loading
print("Loading CelebA dataset...")
dataset = torchvision.datasets.CelebA(
    "./celebA/",
    split='train',
    transform=_celeba_transform(),
)
test_dataset = torchvision.datasets.CelebA(
    "./celebA/",
    split='test',
    transform=_celeba_transform(),
)

# Training batches are shuffled and the trailing partial batch is dropped;
# test batches keep the dataset order and include every sample.
training_data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)
print(f"Training samples: {len(dataset)}, Test samples: {len(test_dataset)}")

# Rows of the comparison table, as (metrics-dict key, display label).
_COMPARISON_METRICS = (
    ('accuracy', 'Overall Accuracy'),
    ('balanced_accuracy', 'Balanced Accuracy'),
    ('acc_non_blond_female', 'Non-Blond Female Acc'),
    ('acc_non_blond_male', 'Non-Blond Male Acc'),
    ('acc_blond_female', 'Blond Female Acc'),
    ('acc_blond_male', 'Blond Male Acc ⭐'),
    ('worst_group_accuracy', 'Worst Group Acc'),
    ('EOP', 'Equal Opportunity'),
    ('EOD', 'Equalized Odds'),
)

# Fairness-gap metrics: for these, a decrease counts as an improvement.
_LOWER_IS_BETTER = frozenset({'EOP', 'EOD', 'accuracy_gap'})


def _print_comparison_table(baseline_metrics, fairnet_metrics):
    """Print one Baseline / FairNet / delta row per entry of _COMPARISON_METRICS."""
    print("\n📈 Performance Comparison:")
    print("-" * 50)
    print(f"{'Metric':<30} {'Baseline':<12} {'FairNet':<12} {'Δ':<10}")
    print("-" * 50)

    for key, name in _COMPARISON_METRICS:
        # A metric missing from either dict is reported as 0.
        base_val = baseline_metrics.get(key, 0)
        fair_val = fairnet_metrics.get(key, 0)
        delta = fair_val - base_val

        # The check mark follows the metric's direction: down is good for
        # fairness gaps, up is good for everything else.
        improved = delta < 0 if key in _LOWER_IS_BETTER else delta > 0
        improvement = "✓" if improved else ""

        print(f"{name:<30} {base_val:<12.4f} {fair_val:<12.4f} {delta:+.4f} {improvement}")

    print("-" * 50)


def _print_key_findings(baseline_metrics, fairnet_metrics):
    """Print the minority/majority accuracy changes and a one-line verdict."""
    blond_male_improvement = (
        fairnet_metrics.get('acc_blond_male', 0) -
        baseline_metrics.get('acc_blond_male', 0)
    )
    majority_change = (
        fairnet_metrics.get('acc_non_blond_female', 0) -
        baseline_metrics.get('acc_non_blond_female', 0)
    )

    print("\n🎯 Key Findings:")
    print(f"  • Blond Male (Minority) Improvement: {blond_male_improvement:+.4f}")
    print(f"  • Non-Blond Female (Majority) Change: {majority_change:+.4f}")

    # "Success" tolerates at most a 0.01 accuracy drop on the majority group.
    if blond_male_improvement > 0 and majority_change >= -0.01:
        print("\n✅ SUCCESS: FairNet improved minority group performance "
              "without significant loss on majority group!")
    elif blond_male_improvement > 0:
        print("\n⚠️ PARTIAL SUCCESS: Minority improved but some majority impact")
    else:
        print("\n❌ Need tuning: Consider adjusting hyperparameters")


def demo_celeba():
    """
    Run the FairNet demo on the CelebA dataset.

    Trains a baseline ViT and a FairNet ViT on ``training_data_loader``,
    evaluates both on ``test_data_loader``, and prints a side-by-side
    comparison with particular attention to the minority group (Blond Male).
    Reads the module-level ``device``, ``training_data_loader`` and
    ``test_data_loader``.

    NOTE(review): ``ViTConfig`` and ``train_baseline`` are used below but are
    not among the imports visible at the top of this cell; they must already
    be defined in the running session. Confirm where they come from.

    Returns:
        tuple: ``(baseline_model, fairnet_model, baseline_metrics,
        fairnet_metrics)``. The two metrics objects are whatever
        ``evaluate_model`` returns and are read here through ``.get``.
    """
    print("="*60)
    print("FairNet Demo on CelebA Dataset")
    print("="*60)
    print(f"\nDevice: {device}")
    print("Task: Gender Classification (Male/Female)")
    print("Sensitive Attribute: Hair Color (Blond/Non-Blond)")
    print("Minority Group: Blond Male")

    # ViT configuration.
    # image_size here is a separate literal from the module-level image_size
    # used by the data transforms; keep the two in sync when changing either.
    vit_config = ViTConfig(
        num_hidden_layers=8,
        num_attention_heads=8,
        intermediate_size=768,
        image_size=64,
        patch_size=16
    )

    # ================================================================
    # Experiment 1: baseline model
    # ================================================================
    print("\n" + "#"*60)
    print("# Experiment 1: Baseline ViT")
    print("#"*60)

    baseline_model = BaselineViT(vit_config)
    baseline_model = train_baseline(
        baseline_model,
        training_data_loader,
        device,
        epochs=15,
        lr=1e-5
    )

    baseline_metrics = evaluate_model(
        baseline_model,
        test_data_loader,
        device,
        model_type='baseline'
    )
    print_metrics(baseline_metrics, "Baseline ViT Results")

    # ================================================================
    # Experiment 2: FairNet
    # ================================================================
    print("\n" + "#"*60)
    print("# Experiment 2: FairNet ViT")
    print("#"*60)

    fairnet_model = FairNetViT(
        vit_config,
        lora_rank=8,
        lora_alpha=16.0,
        lora_threshold=0.5,
        detector_hidden=128
    )

    trainer = FairNetCelebATrainer(fairnet_model, device)

    # Four-stage training
    trainer.train_full(
        training_data_loader,
        stage1_epochs=10,      # base model training
        stage2_epochs=5,       # detector training
        stage4_epochs=10,      # LoRA training
        stage1_lr=1e-5,
        stage2_lr=1e-4,
        stage4_lr=1e-4
    )

    fairnet_metrics = evaluate_model(
        fairnet_model,
        test_data_loader,
        device,
        model_type='fairnet'
    )
    print_metrics(fairnet_metrics, "FairNet ViT Results")

    # ================================================================
    # Results comparison
    # ================================================================
    print("\n" + "#"*60)
    print("# Comparison: Baseline vs FairNet")
    print("#"*60)

    _print_comparison_table(baseline_metrics, fairnet_metrics)

    # Summary of the key improvements
    _print_key_findings(baseline_metrics, fairnet_metrics)

    return baseline_model, fairnet_model, baseline_metrics, fairnet_metrics


# ============================================================================
# Run the demo
# ============================================================================

if __name__ == "__main__":
    baseline_model, fairnet_model, baseline_metrics, fairnet_metrics = demo_celeba()

    # Save the trained weights of both models, baseline first.
    print("\nSaving models...")
    for _model, _checkpoint_path in (
        (baseline_model, "baseline_vit_celeba.pth"),
        (fairnet_model, "fairnet_vit_celeba.pth"),
    ):
        torch.save(_model.state_dict(), _checkpoint_path)
    print("Models saved!")

Loading CelebA dataset...
Training samples: 162770, Test samples: 19962
FairNet Demo on CelebA Dataset

Device: cuda:1
Task: Gender Classification (Male/Female)
Sensitive Attribute: Hair Color (Blond/Non-Blond)
Minority Group: Blond Male

############################################################
# Experiment 1: Baseline ViT
############################################################

Training Baseline ViT Model


Epoch 1/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:42<00:00,  7.83it/s, loss=0.1757]


Epoch 1: Avg Loss = 0.2661


Epoch 2/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:40<00:00,  7.91it/s, loss=0.1051]


Epoch 2: Avg Loss = 0.1481


Epoch 3/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:35<00:00,  8.19it/s, loss=0.1302]


Epoch 3: Avg Loss = 0.1200


Epoch 4/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:34<00:00,  8.22it/s, loss=0.1043]


Epoch 4: Avg Loss = 0.0979


Epoch 5/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:33<00:00,  8.30it/s, loss=0.1203]


Epoch 5: Avg Loss = 0.0807


Epoch 6/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:32<00:00,  8.34it/s, loss=0.0736]


Epoch 6: Avg Loss = 0.0649


Epoch 7/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:33<00:00,  8.31it/s, loss=0.0461]


Epoch 7: Avg Loss = 0.0489


Epoch 8/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:33<00:00,  8.26it/s, loss=0.0523]


Epoch 8: Avg Loss = 0.0380


Epoch 9/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:32<00:00,  8.35it/s, loss=0.0720]


Epoch 9: Avg Loss = 0.0285


Epoch 10/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:32<00:00,  8.32it/s, loss=0.0166]


Epoch 10: Avg Loss = 0.0239


Epoch 11/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:33<00:00,  8.30it/s, loss=0.0056]


Epoch 11: Avg Loss = 0.0194


Epoch 12/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:30<00:00,  8.42it/s, loss=0.0116]


Epoch 12: Avg Loss = 0.0170


Epoch 13/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:30<00:00,  8.46it/s, loss=0.0332]


Epoch 13: Avg Loss = 0.0151


Epoch 14/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:27<00:00,  8.61it/s, loss=0.0197]


Epoch 14: Avg Loss = 0.0139


Epoch 15/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:27<00:00,  8.61it/s, loss=0.0168]


Epoch 15: Avg Loss = 0.0114


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 156/156 [00:15<00:00,  9.82it/s]



Baseline ViT Results

📊 Overall Performance:
  Accuracy: 0.9582
  Balanced Accuracy: 0.9569

👥 Group Performance:
  Non-Blond Female: 0.9568 (n=9767)
  Non-Blond Male:   0.9555 (n=7535)
  Blond Female:     0.9855 (n=2480)
  Blond Male:       0.7667 (n=180) ⬅️ Minority Group

🎯 Worst Group Accuracy: 0.7667

⚖️ Fairness Metrics:
  EOP (Equal Opportunity): 0.1889
  EOD (Equalized Odds):    0.1088
  Accuracy Gap:            0.1889

############################################################
# Experiment 2: FairNet ViT
############################################################

Stage 1: Training Base Model


Epoch 1/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:29<00:00,  8.52it/s, loss=0.1662]


Epoch 1: Avg Loss = 0.2597


Epoch 2/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:28<00:00,  8.53it/s, loss=0.1411]


Epoch 2: Avg Loss = 0.1518


Epoch 3/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:28<00:00,  8.57it/s, loss=0.1142]


Epoch 3: Avg Loss = 0.1201


Epoch 4/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:28<00:00,  8.57it/s, loss=0.1060]


Epoch 4: Avg Loss = 0.0980


Epoch 5/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:28<00:00,  8.54it/s, loss=0.1227]


Epoch 5: Avg Loss = 0.0801


Epoch 6/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:27<00:00,  8.59it/s, loss=0.0585]


Epoch 6: Avg Loss = 0.0633


Epoch 7/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:27<00:00,  8.60it/s, loss=0.0437]


Epoch 7: Avg Loss = 0.0487


Epoch 8/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:27<00:00,  8.60it/s, loss=0.0473]


Epoch 8: Avg Loss = 0.0372


Epoch 9/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:27<00:00,  8.59it/s, loss=0.0213]


Epoch 9: Avg Loss = 0.0292


Epoch 10/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:28<00:00,  8.55it/s, loss=0.0450]


Epoch 10: Avg Loss = 0.0229

Stage 2: Training Bias Detector


Epoch 1/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:08<00:00,  9.93it/s, loss=0.2199, acc=0.9112]


Epoch 1: Loss=0.2186, Acc=0.9112


Epoch 2/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:08<00:00,  9.91it/s, loss=0.2551, acc=0.9250]


Epoch 2: Loss=0.1868, Acc=0.9250


Epoch 3/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:08<00:00,  9.87it/s, loss=0.2690, acc=0.9288]


Epoch 3: Loss=0.1775, Acc=0.9288


Epoch 4/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:09<00:00,  9.83it/s, loss=0.1620, acc=0.9308]


Epoch 4: Loss=0.1719, Acc=0.9308


Epoch 5/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:08<00:00,  9.90it/s, loss=0.2228, acc=0.9335]


Epoch 5: Loss=0.1664, Acc=0.9335

Stage 3: Building Prototypes


Building prototypes: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:12<00:00,  9.60it/s]



Prototype Statistics:
  Female + Non-Blond: 71595 samples
  Female + Blond: 22869 samples
  Male + Non-Blond: 66837 samples
  Male + Blond: 1387 samples

Stage 4: Training Conditional LoRA


Epoch 1/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:10<00:00,  9.70it/s, loss=0.0002, c_loss=0.0000, t_loss=0.0021]


Epoch 1: Loss=0.0132, C_Loss=0.0121, T_Loss=0.0117


Epoch 2/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:11<00:00,  9.68it/s, loss=0.0007, c_loss=0.0000, t_loss=0.0072]


Epoch 2: Loss=0.0070, C_Loss=0.0054, T_Loss=0.0162


Epoch 3/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:11<00:00,  9.63it/s, loss=0.0011, c_loss=0.0000, t_loss=0.0111]


Epoch 3: Loss=0.0052, C_Loss=0.0036, T_Loss=0.0164


Epoch 4/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:11<00:00,  9.65it/s, loss=0.0007, c_loss=0.0000, t_loss=0.0065]


Epoch 4: Loss=0.0042, C_Loss=0.0028, T_Loss=0.0149


Epoch 5/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:12<00:00,  9.63it/s, loss=0.0007, c_loss=0.0000, t_loss=0.0073]


Epoch 5: Loss=0.0037, C_Loss=0.0024, T_Loss=0.0129


Epoch 6/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:11<00:00,  9.63it/s, loss=0.0006, c_loss=0.0000, t_loss=0.0062]


Epoch 6: Loss=0.0028, C_Loss=0.0016, T_Loss=0.0120


Epoch 7/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:11<00:00,  9.66it/s, loss=0.0014, c_loss=0.0000, t_loss=0.0140]


Epoch 7: Loss=0.0027, C_Loss=0.0017, T_Loss=0.0099


Epoch 8/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:11<00:00,  9.67it/s, loss=0.0002, c_loss=0.0000, t_loss=0.0023]


Epoch 8: Loss=0.0024, C_Loss=0.0015, T_Loss=0.0086


Epoch 9/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:11<00:00,  9.64it/s, loss=0.0883, c_loss=0.0836, t_loss=0.0471]


Epoch 9: Loss=0.0017, C_Loss=0.0010, T_Loss=0.0073


Epoch 10/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1271/1271 [02:12<00:00,  9.61it/s, loss=0.0004, c_loss=0.0000, t_loss=0.0044]


Epoch 10: Loss=0.0019, C_Loss=0.0013, T_Loss=0.0064


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 156/156 [00:15<00:00,  9.92it/s]



FairNet ViT Results

📊 Overall Performance:
  Accuracy: 0.9559
  Balanced Accuracy: 0.9576

👥 Group Performance:
  Non-Blond Female: 0.9431 (n=9767)
  Non-Blond Male:   0.9676 (n=7535)
  Blond Female:     0.9782 (n=2480)
  Blond Male:       0.8556 (n=180) ⬅️ Minority Group

🎯 Worst Group Accuracy: 0.8556

⚖️ Fairness Metrics:
  EOP (Equal Opportunity): 0.1121
  EOD (Equalized Odds):    0.0736
  Accuracy Gap:            0.1121

🔧 LoRA Activation Rate: 0.1265

############################################################
# Comparison: Baseline vs FairNet
############################################################

📈 Performance Comparison:
--------------------------------------------------
Metric                         Baseline     FairNet      Δ         
--------------------------------------------------
Overall Accuracy               0.9582       0.9559       -0.0023 
Balanced Accuracy              0.9569       0.9576       +0.0007 ✓
Non-Blond Female Acc    