In [1]:

import os
import time
import copy
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from options import args_parser
from update import LocalUpdate, test_inference
from models import (CNNCifar, CNNMnist, ResNet18, 
                   GeneratorCifar, DiscriminatorCifar, Generator, Discriminator,
                   generate_images, filter_images, add_backdoor_trigger_cifar, add_backdoor_trigger_mnist)
from utils import get_dataset, average_weights, exp_details, create_poisoned_dataset
from unlearn import (
    train_generator_ungan,
    SyntheticImageDataset,
    partition_synthetic_data_iid,
    get_synthetic_subset
)
from evaluate_mia import evaluate_mia

In [2]:
import os
import time
import copy
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from options import args_parser
from update import LocalUpdate, test_inference
from models import (CNNCifar, CNNMnist, ResNet18, 
                   GeneratorCifar, DiscriminatorCifar, Generator, Discriminator,
                   generate_images, filter_images, add_backdoor_trigger_cifar, add_backdoor_trigger_mnist)
from utils import get_dataset, average_weights, exp_details, create_poisoned_dataset
from unlearn import (
    train_generator_ungan,
    SyntheticImageDataset,
    partition_synthetic_data_iid,
    get_synthetic_subset
)
from evaluate_mia import evaluate_mia

# Restrict this process to GPU 0. The variable is only honored if it is set
# before CUDA is first initialized in this kernel, so keep it ahead of any
# code that touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


# =================== Image Visualization Functions ===================

def denormalize_image(tensor, dataset_type='cifar'):
    """Undo the per-channel dataset normalization and clamp to [0, 1].

    Args:
        tensor: normalized image tensor whose channel axis is third from the
            end, e.g. (C, H, W) or (N, C, H, W); the statistics are broadcast
            through a (C, 1, 1) view.
        dataset_type: 'cifar' uses the 3-channel CIFAR-10 statistics; any other
            value falls back to the single-channel MNIST statistics.

    Returns:
        ``tensor * std + mean`` clamped to [0, 1], computed on ``tensor``'s device.
    """
    if dataset_type == 'cifar':
        mean_values = [0.4914, 0.4822, 0.4465]
        std_values = [0.2023, 0.1994, 0.2010]
    else:  # mnist
        mean_values = [0.1307]
        std_values = [0.3081]

    # Create the statistics on the input's device so CUDA tensors can be
    # denormalized without a CPU/GPU device-mismatch error.
    mean = torch.tensor(mean_values, device=tensor.device).view(-1, 1, 1)
    std = torch.tensor(std_values, device=tensor.device).view(-1, 1, 1)

    denormalized = tensor * std + mean
    return torch.clamp(denormalized, 0, 1)


def visualize_hybrid_generation(generator, discriminator, dataset, unseen_dataset, 
                               forget_idxs, device, z_dim=100, num_samples=16, dataset_type='cifar'):
    """Visualize the hybrid-image generation pipeline as a 4-row image grid.

    Rows: (1) original forget samples, (2) unseen reference samples,
    (3) raw generator output, (4) generator output passed through
    ``style_blend`` toward the mean unseen image (strength 0.3). The figure is
    saved to ``./hybrid_generation_process_{dataset_type}.png`` and shown.

    Args:
        generator: module mapping (B, z_dim) noise to images; set to eval mode.
        discriminator: only set to eval mode here; not otherwise used.
        dataset: indexable dataset returning (image, label) pairs.
        unseen_dataset: indexable dataset returning (image, label) pairs.
        forget_idxs: indices into ``dataset`` (list, set or array).
        device: device used for generation.
        z_dim: generator noise dimension.
        num_samples: grid budget; each row shows ``num_samples // 4`` images
            (at least one).
        dataset_type: 'mnist' is rendered in grayscale, anything else as RGB.

    Returns:
        CPU tensor with the row-4 (hybrid) images.

    Raises:
        ValueError: if ``unseen_dataset`` is empty (no style reference).
    """
    generator.eval()
    discriminator.eval()

    per_row = max(1, num_samples // 4)
    # list() so that sets / arrays of indices can be sliced positionally.
    forget_list = list(forget_idxs)

    # 1. Original forget images
    forget_samples, forget_labels = [], []
    for idx in forget_list[:per_row]:
        img, label = dataset[idx]
        forget_samples.append(img)
        forget_labels.append(label)

    # 2. Unseen images (style reference)
    unseen_samples = [unseen_dataset[i][0] for i in range(min(per_row, len(unseen_dataset)))]
    if not unseen_samples:
        raise ValueError("unseen_dataset is empty; cannot build a style reference")

    with torch.no_grad():
        # 3. Plain generator output
        noise = torch.randn((per_row, z_dim), device=device)
        basic_generated = generator(noise).cpu()

        # 4. Hybrid images: generator output blended toward the mean unseen image
        unseen_style = torch.stack(unseen_samples).mean(dim=0, keepdim=True).to(device)
        noise = torch.randn((per_row, z_dim), device=device)
        hybrid_generated = style_blend(generator(noise), unseen_style, strength=0.3).cpu()

    # squeeze=False keeps `axes` 2-D even when a row has a single column.
    fig, axes = plt.subplots(4, per_row, figsize=(16, 12), squeeze=False)
    fig.suptitle(f'Advanced UNGAN: Hybrid Image Generation Process ({dataset_type.upper()})', 
                 fontsize=16, fontweight='bold')
    # Hide every axis up front so cells left empty (short rows) have no frame.
    for ax in axes.flat:
        ax.axis('off')

    stages = [
        (forget_samples, "1. Original Forget Data"),
        (unseen_samples, "2. Unseen Style Reference"),
        (basic_generated, "3. Basic Generated"),
        (hybrid_generated, "4. Hybrid (Forget+Unseen)"),
    ]

    for row, (images, title) in enumerate(stages):
        for col in range(len(images)):
            ax = axes[row, col]
            img = denormalize_image(images[col], dataset_type)

            # MNIST is grayscale, CIFAR is color
            if dataset_type == 'mnist':
                ax.imshow(img.squeeze().numpy(), cmap='gray')  # (H, W)
            else:
                ax.imshow(img.permute(1, 2, 0).numpy())  # (H, W, C)

            if col == 0:
                ax.text(-0.1, 0.5, title, rotation=90, verticalalignment='center',
                        fontsize=12, fontweight='bold', transform=ax.transAxes)

            if row == 0:
                label = forget_labels[col]
                # Tensor labels would otherwise be printed as "tensor(6)".
                if torch.is_tensor(label):
                    label = label.item()
                ax.set_title(f'Class: {label}', fontsize=10)

    plt.tight_layout()
    plt.savefig(f'./hybrid_generation_process_{dataset_type}.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    return hybrid_generated


def save_individual_images(hybrid_images, save_dir='./generated_samples', dataset_type='cifar'):
    """Write every image in ``hybrid_images`` to ``save_dir`` as a PNG file.

    Each image is denormalized with ``denormalize_image`` and converted with
    torchvision's ``ToPILImage``. For MNIST the singleton dimensions are
    squeezed away first. Files are named
    ``hybrid_sample_{dataset_type}_{index:03d}.png``. ``save_dir`` is created
    if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)
    to_pil = transforms.ToPILImage()
    grayscale = dataset_type == 'mnist'

    for index, image in enumerate(hybrid_images):
        restored = denormalize_image(image, dataset_type)
        if grayscale:
            # (1, 28, 28) -> (28, 28) for single-channel output
            restored = restored.squeeze()
        file_name = f'hybrid_sample_{dataset_type}_{index:03d}.png'
        to_pil(restored).save(os.path.join(save_dir, file_name))

    print(f"💾 Saved {len(hybrid_images)} {dataset_type.upper()} hybrid images to {save_dir}/")


def visualize_advanced_ungan_results(generator, discriminator, dataset, unseen_dataset,
                                   forget_idxs, filtered_images, filtered_labels, 
                                   device, args, dataset_type='cifar'):
    """Produce all Advanced UNGAN visual artifacts.

    1. Renders the 4-stage hybrid-generation grid with
       ``visualize_hybrid_generation``. This is best effort: on any exception
       the error is printed and the first 16 ``filtered_images`` are used
       instead.
    2. Saves up to 50 ``filtered_images`` to ``./generated_samples/`` and the
       first 20 original forget samples to ``./original_samples/``.

    Args:
        generator, discriminator: models forwarded to the grid visualization.
        dataset: indexable (image, label) dataset addressed by ``forget_idxs``.
        unseen_dataset: style-reference dataset for the grid.
        forget_idxs: forget indices (list, set or array).
        filtered_images: tensor of filtered hybrid images.
        filtered_labels: unused.
        device: generation device.
        args: only ``args.z_dim`` is read.
        dataset_type: 'mnist' or 'cifar' (controls denormalization/rendering).

    Returns:
        The hybrid images from step 1, or the fallback slice of ``filtered_images``.
    """
    print(f"\n🎨 ========== Visualizing Advanced UNGAN Results ({dataset_type.upper()}) ==========")

    # 1. Hybrid generation process grid
    print("📸 1. Creating hybrid generation process visualization...")
    try:
        hybrid_samples = visualize_hybrid_generation(
            generator, discriminator, dataset, unseen_dataset, 
            forget_idxs, device, args.z_dim, num_samples=16, dataset_type=dataset_type
        )
    except Exception as e:
        # Deliberately best effort: a plotting failure must not abort the experiment.
        print(f"⚠️  Visualization error: {e}")
        hybrid_samples = filtered_images[:16]

    # 2. Save individual images
    print("💾 2. Saving individual hybrid images...")
    if len(filtered_images) > 0:
        save_individual_images(filtered_images[:50], dataset_type=dataset_type)

        # Also save original forget images for side-by-side comparison.
        os.makedirs('./original_samples', exist_ok=True)
        to_pil = transforms.ToPILImage()
        # list() so sets / arrays of indices can be sliced.
        for i, idx in enumerate(list(forget_idxs)[:20]):
            img, label = dataset[idx]
            img_denorm = denormalize_image(img, dataset_type)

            if dataset_type == 'mnist':
                img_denorm = img_denorm.squeeze()

            # Tensor labels would otherwise end up as "tensor(6)" in the file name.
            label_value = label.item() if torch.is_tensor(label) else label
            to_pil(img_denorm).save(
                f'./original_samples/forget_sample_{dataset_type}_{i:03d}_class{label_value}.png')

        print(f"💾 Also saved original {dataset_type.upper()} forget samples to ./original_samples/")
    else:
        print("⚠️  No hybrid images to save")

    print("✅ Visualization completed! Check these files:")
    print(f"   📊 hybrid_generation_process_{dataset_type}.png - 4단계 생성 과정")
    print("   🖼️  ./generated_samples/ - Hybrid 이미지들")  
    print("   📂 ./original_samples/ - 원본 Forget 이미지들")

    return hybrid_samples

In [3]:
# =================== Advanced UNGAN Functions ===================

def extract_intermediate_features(discriminator, images):
    """Return flattened activations after the first five layers of ``discriminator.model``.

    Layers 0-4 (inclusive) are applied in order. If the model has fewer than
    five layers, all of them are applied. The activations are flattened to
    ``(batch, -1)`` with ``view``.
    """
    activations = images
    # zip stops as soon as range(5) is exhausted, so at most five layers are pulled.
    for _, block in zip(range(5), discriminator.model):
        activations = block(activations)
    batch = activations.size(0)
    return activations.view(batch, -1)


def style_blend(generated_imgs, style_reference, strength=0.3):
    """Shift generated images' per-channel statistics toward a style reference.

    This is a simplified AdaIN. Each image is standardized with its own mean
    and std over dims 2 and 3, then rescaled with a std and mean that mix its
    own statistics (weight ``1 - strength``) with the reference's statistics
    (weight ``strength``). The result is clamped to [-1, 1].

    Args:
        generated_imgs: tensor with spatial dims at positions 2 and 3,
            e.g. (N, C, H, W).
        style_reference: tensor whose dims-2/3 statistics broadcast against
            those of ``generated_imgs`` (e.g. a single (1, C, H, W) image).
        strength: mixing weight given to the reference statistics.
    """
    spatial = [2, 3]
    src_mean = generated_imgs.mean(dim=spatial, keepdim=True)
    src_std = generated_imgs.std(dim=spatial, keepdim=True)
    ref_mean = style_reference.mean(dim=spatial, keepdim=True)
    ref_std = style_reference.std(dim=spatial, keepdim=True)

    keep = 1 - strength
    target_std = src_std * keep + ref_std * strength
    target_mean = src_mean * keep + ref_mean * strength

    # The 1e-8 term guards against division by zero for constant channels.
    standardized = (generated_imgs - src_mean) / (src_std + 1e-8)
    return torch.clamp(standardized * target_std + target_mean, -1, 1)


def train_hybrid_generator(generator, discriminator, dataset, unseen_dataset, 
                          retain_idxs, forget_idxs, device,
                          alpha=0.7, beta=0.3, z_dim=100, batch_size=64, epochs=50):
    """Adversarially train ``generator`` to mix forget-set and unseen-set traits.

    The discriminator is trained as a single-output classifier with soft
    targets: forget samples -> 0, unseen samples -> 1, generated samples -> 0.5.
    The generator minimizes
    ``alpha * BCE(D(G(z)), 0) + beta * MSE(feat(G(z)), feat(unseen))``, where
    ``feat`` is ``extract_intermediate_features``. If feature extraction fails
    (missing ``.model``, non-iterable model, or a shape error), the style term
    falls back to MSE between per-channel spatial means and a warning is
    printed once.

    Args:
        generator: module mapping (B, z_dim) noise to images.
        discriminator: module whose outputs go into ``F.binary_cross_entropy``
            against (B, 1) targets, so it presumably ends in a sigmoid — TODO confirm.
        dataset: dataset indexed by ``forget_idxs``.
        unseen_dataset: dataset of (image, label) pairs used as the style source.
        retain_idxs: accepted for interface compatibility; not used.
        forget_idxs: forget-set indices (list, set or array).
        device: training device.
        alpha, beta: weights of the forget-mimic and unseen-style terms.
        z_dim: generator noise dimension.
        batch_size: combined batch size; each source contributes
            ``batch_size // 2`` samples (at least 1) per step.
        epochs: number of passes; each pass runs
            ``min(len(forget_loader), len(unseen_loader))`` steps.

    Returns:
        The same ``generator`` object, trained in place (left in train mode).
    """
    print(f"[Hybrid UNGAN] Training with α={alpha} (forget) + β={beta} (unseen)")

    g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

    # Each source supplies half of the combined batch; never let that drop to 0.
    half_batch = max(1, batch_size // 2)
    # Subset indexes `indices` positionally, so a set of indices must become a list.
    forget_subset = Subset(dataset, list(forget_idxs))
    forget_loader = DataLoader(forget_subset, batch_size=half_batch, shuffle=True, drop_last=True)
    unseen_loader = DataLoader(unseen_dataset, batch_size=half_batch, shuffle=True, drop_last=True)

    if min(len(forget_loader), len(unseen_loader)) == 0:
        print(f"[Hybrid UNGAN] Warning: forget ({len(forget_subset)}) or unseen "
              f"({len(unseen_dataset)}) set is smaller than the per-source batch "
              f"({half_batch}); no training steps will run")

    generator.train()
    discriminator.train()

    warned_feature_fallback = False

    for epoch in range(epochs):
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        batch_count = 0

        # zip walks both loaders in lockstep and stops at the shorter one.
        for (forget_batch, _), (unseen_batch, _) in zip(forget_loader, unseen_loader):
            forget_batch = forget_batch.to(device)
            unseen_batch = unseen_batch.to(device)

            n = min(forget_batch.size(0), unseen_batch.size(0))
            forget_batch = forget_batch[:n]
            unseen_batch = unseen_batch[:n]

            # ---- Discriminator step: forget -> 0, unseen -> 1, generated -> 0.5 ----
            d_optimizer.zero_grad()
            forget_targets = torch.zeros((n, 1), device=device)
            unseen_targets = torch.ones((n, 1), device=device)
            fake_targets = torch.full((n, 1), 0.5, device=device)

            d_loss_forget = F.binary_cross_entropy(discriminator(forget_batch), forget_targets)
            d_loss_unseen = F.binary_cross_entropy(discriminator(unseen_batch), unseen_targets)

            z = torch.randn((n, z_dim), device=device)
            fake_imgs = generator(z)
            d_loss_fake = F.binary_cross_entropy(discriminator(fake_imgs.detach()), fake_targets)

            d_loss = d_loss_forget + d_loss_unseen + d_loss_fake
            d_loss.backward()
            d_optimizer.step()

            # ---- Generator step ----
            g_optimizer.zero_grad()
            z = torch.randn((n, z_dim), device=device)
            gen_imgs = generator(z)

            # Strategy A: look like forget data to the discriminator (target 0).
            loss_forget_mimic = F.binary_cross_entropy(discriminator(gen_imgs), forget_targets)

            # Strategy B: match the discriminator's intermediate features of unseen data.
            try:
                # The unseen features are only a regression target; computing them
                # without grad avoids an unused graph through the discriminator.
                with torch.no_grad():
                    unseen_features = extract_intermediate_features(discriminator, unseen_batch)
                gen_features = extract_intermediate_features(discriminator, gen_imgs)
                loss_unseen_style = F.mse_loss(gen_features, unseen_features)
            except (AttributeError, TypeError, RuntimeError) as err:
                # Fall back to pixel-level similarity of per-channel means.
                if not warned_feature_fallback:
                    print(f"[Hybrid UNGAN] Feature-level style loss unavailable ({err}); "
                          f"using pixel-mean fallback")
                    warned_feature_fallback = True
                loss_unseen_style = F.mse_loss(gen_imgs.mean(dim=[2, 3]), unseen_batch.mean(dim=[2, 3]))

            g_loss = alpha * loss_forget_mimic + beta * loss_unseen_style
            g_loss.backward()
            g_optimizer.step()

            epoch_g_loss += g_loss.item()
            epoch_d_loss += d_loss.item()
            batch_count += 1

        if (epoch + 1) % 10 == 0:
            avg_g_loss = epoch_g_loss / max(batch_count, 1)
            avg_d_loss = epoch_d_loss / max(batch_count, 1)
            print(f"Epoch {epoch+1}/{epochs} | G Loss: {avg_g_loss:.4f} | D Loss: {avg_d_loss:.4f}")

    print("[Hybrid UNGAN] Training completed!")
    return generator


def generate_hybrid_images(generator, forget_idxs, unseen_dataset, dataset, 
                          device='cpu', z_dim=100, style_strength=0.3, batch_size=32):
    """Generate one image per forget index, optionally style-blended toward unseen data.

    The generator is driven by noise only, so image ``k`` is paired with the
    label of the ``k``-th forget sample purely by position.
    NOTE(review): unless the generator is somehow class-aware, the returned
    labels need not describe the image content — confirm this is intended.

    Args:
        generator: module mapping (B, z_dim) noise to images; set to eval mode.
        forget_idxs: forget-set indices into ``dataset`` (list, set or array).
        unseen_dataset: (image, label) dataset. The mean of its first 100
            images is the style reference. Only read when ``style_strength > 0``.
        dataset: (image, label) dataset supplying the forget labels.
        device: generation device.
        z_dim: generator noise dimension.
        style_strength: ``style_blend`` strength; 0 disables blending.
        batch_size: generation batch size (default 32).

    Returns:
        (images, labels): CPU tensor with ``len(forget_idxs)`` generated images,
        and a LongTensor of the positionally paired forget labels.

    Raises:
        ValueError: if ``forget_idxs`` is empty, or if ``style_strength > 0``
            and ``unseen_dataset`` is empty.
    """
    generator.eval()
    device = torch.device(device)
    forget_list = list(forget_idxs)
    num_samples = len(forget_list)
    if num_samples == 0:
        raise ValueError("generate_hybrid_images: forget_idxs is empty; nothing to generate")

    # Keep the forget-set label information
    forget_labels = torch.tensor([dataset[i][1] for i in forget_list], dtype=torch.long)

    unseen_style = None
    if style_strength > 0:
        if len(unseen_dataset) == 0:
            raise ValueError("generate_hybrid_images: unseen_dataset is empty; cannot blend style")
        # Style reference: mean of (up to) the first 100 unseen images
        reference = torch.stack([unseen_dataset[i][0] for i in range(min(100, len(unseen_dataset)))])
        unseen_style = reference.mean(dim=0, keepdim=True).to(device)

    step = max(1, int(batch_size))
    chunks = []
    with torch.no_grad():
        for start in range(0, num_samples, step):
            count = min(step, num_samples - start)
            noise = torch.randn((count, z_dim), device=device)
            gen_imgs = generator(noise)

            # Post-processing: blend toward the unseen style
            if unseen_style is not None:
                gen_imgs = style_blend(gen_imgs, unseen_style, strength=style_strength)

            chunks.append(gen_imgs.cpu())

    return torch.cat(chunks, dim=0), forget_labels


def advanced_filter_images(discriminator, images, labels, 
                          forget_threshold=0.6, device='cpu',
                          quality_low=0.1, quality_high=0.9):
    """Keep images the discriminator scores as forget-like and within a quality band.

    An image with discriminator score ``p`` is kept when ``p < forget_threshold``
    (forget-like; forget data is trained toward 0 in ``train_hybrid_generator``)
    and ``quality_low < p < quality_high``.

    Args:
        discriminator: module returning one score per image, e.g. (N, 1) or (N,);
            set to eval mode.
        images: tensor of N images.
        labels: tensor of N labels aligned with ``images``.
        forget_threshold: exclusive upper bound for "forget-like".
        device: device on which the discriminator runs.
        quality_low, quality_high: exclusive score band of the quality filter.

    Returns:
        (filtered_images, filtered_labels), indexed from the input tensors and
        kept on their original devices.
    """
    discriminator.eval()
    device = torch.device(device)

    with torch.no_grad():
        # reshape(-1) (not squeeze) keeps a 1-D score vector even for a single image.
        scores = discriminator(images.to(device)).reshape(-1)

        # Forget-like: low score = classified as forget
        forget_mask = scores < forget_threshold
        # Quality filtering
        quality_mask = (scores > quality_low) & (scores < quality_high)
        final_mask = forget_mask & quality_mask

        # The mask lives on `device`; move it to each indexed tensor's device,
        # since a CUDA mask cannot index a CPU tensor.
        filtered_imgs = images[final_mask.to(images.device)]
        filtered_labels = labels[final_mask.to(labels.device)]

        print(f"[Advanced Filter] {len(images)} → {len(filtered_imgs)} images")
        print(f"  Forget-like: {forget_mask.sum().item()}")
        print(f"  Quality: {quality_mask.sum().item()}")
        print(f"  Final: {final_mask.sum().item()}")

    return filtered_imgs, filtered_labels


def advanced_ungan_workflow(generator, discriminator, dataset, unseen_dataset,
                           retain_idxs, forget_idxs, args, device):
    """Run the three-stage Advanced UNGAN pipeline: train, generate, filter.

    The hyperparameters are fixed here:
    - loss weights alpha = 0.7 (forget) and beta = 0.3 (unseen);
    - 30 generator epochs;
    - style strength 0.3;
    - forget threshold 0.6.

    Only ``args.z_dim`` and ``args.local_bs`` are read from ``args``.
    NOTE(review): ``args.gen_threshold`` is not consulted; the filter
    threshold is hard-coded to 0.6 — confirm that is intended.

    Returns:
        (filtered_images, filtered_labels, trained_generator)
    """
    print("\n========== Advanced UNGAN: Forget + Unseen Fusion ==========")

    print("Step 1: Training Hybrid Generator...")
    trained_generator = train_hybrid_generator(
        generator=generator, discriminator=discriminator,
        dataset=dataset, unseen_dataset=unseen_dataset,
        retain_idxs=retain_idxs, forget_idxs=forget_idxs, device=device,
        alpha=0.7,  # weight of the forget-mimic term
        beta=0.3,   # weight of the unseen-style term
        z_dim=args.z_dim, batch_size=args.local_bs, epochs=30,
    )

    print("Step 2: Generating Hybrid Images...")
    hybrid_images, hybrid_labels = generate_hybrid_images(
        generator=trained_generator, forget_idxs=forget_idxs,
        unseen_dataset=unseen_dataset, dataset=dataset, device=device,
        z_dim=args.z_dim, style_strength=0.3,
    )

    print("Step 3: Advanced Filtering...")
    filtered_images, filtered_labels = advanced_filter_images(
        discriminator=discriminator, images=hybrid_images, labels=hybrid_labels,
        forget_threshold=0.6, device=device,
    )

    generated_count = len(hybrid_images)
    kept_count = len(filtered_images)
    print(f"Generated {generated_count} hybrid samples")
    print(f"Filtered to {kept_count} high-quality hybrid samples")
    if generated_count > 0:
        print(f"Quality ratio: {kept_count/generated_count*100:.1f}%")

    return filtered_images, filtered_labels, trained_generator


def evaluate_hybrid_quality(model, filtered_images, filtered_labels, 
                           forget_idxs, unseen_dataset, dataset, device):
    """Compare hybrid images against forget and unseen samples.

    Uses up to 50 samples of each kind:
      * ``content_preservation``: cosine similarity between the model's mean
        softmax prediction on forget samples and on hybrid samples.
      * ``style_similarity``: ``1 - |mu_h - mu_u| - |sigma_h - sigma_u|``, from the
        global pixel mean/std of the hybrid (h) and unseen (u) samples.

    Args:
        model: classifier returning per-class scores; set to eval mode.
        filtered_images: tensor of hybrid images.
        filtered_labels: unused.
        forget_idxs: indices into ``dataset`` (list, set or array).
        unseen_dataset: indexable (image, label) dataset.
        dataset: indexable (image, label) dataset.
        device: device for model inference.

    Returns:
        dict with both scores and (mean, std) tuples for the forget, hybrid
        and unseen samples, or ``{}`` when there is nothing to compare.
    """
    model.eval()

    if len(filtered_images) == 0:
        print("No filtered images to evaluate")
        return {}

    # list() so sets / arrays of indices can be sliced.
    forget_list = list(forget_idxs)[:50]
    unseen_count = min(50, len(unseen_dataset))
    if not forget_list or unseen_count == 0:
        print("No forget or unseen samples to compare against")
        return {}

    forget_imgs = torch.stack([dataset[i][0] for i in forget_list])
    hybrid_imgs = filtered_images[:50]

    # 1. Content preservation: class-distribution similarity
    with torch.no_grad():
        forget_probs = F.softmax(model(forget_imgs.to(device)), dim=1)
        hybrid_probs = F.softmax(model(hybrid_imgs.to(device)), dim=1)
        class_similarity = F.cosine_similarity(
            forget_probs.mean(dim=0),
            hybrid_probs.mean(dim=0),
            dim=0
        ).item()

    # 2. Global pixel statistics
    unseen_sample = torch.stack([unseen_dataset[i][0] for i in range(unseen_count)])

    forget_stats = (forget_imgs.mean().item(), forget_imgs.std().item())
    hybrid_stats = (hybrid_imgs.mean().item(), hybrid_imgs.std().item())
    unseen_stats = (unseen_sample.mean().item(), unseen_sample.std().item())

    # Similarity of the hybrid statistics to the unseen statistics
    style_similarity = 1 - abs(hybrid_stats[0] - unseen_stats[0]) - abs(hybrid_stats[1] - unseen_stats[1])

    results = {
        'content_preservation': class_similarity,
        'style_similarity': style_similarity,
        'forget_stats': forget_stats,
        'hybrid_stats': hybrid_stats,
        'unseen_stats': unseen_stats
    }

    print(f"\n--- Hybrid Quality Evaluation ---")
    print(f"Content Preservation: {class_similarity:.4f}")
    print(f"Style Similarity: {style_similarity:.4f}")
    print(f"Forget Stats: μ={forget_stats[0]:.3f}, σ={forget_stats[1]:.3f}")
    print(f"Hybrid Stats: μ={hybrid_stats[0]:.3f}, σ={hybrid_stats[1]:.3f}")
    print(f"Unseen Stats: μ={unseen_stats[0]:.3f}, σ={unseen_stats[1]:.3f}")

    return results


In [4]:
# =================== Main Functions ===================

def move_dataset_to_device(dataset, device):
    """Materialize an iterable of (x, y) pairs as a TensorDataset on ``device``.

    Features are stacked along a new leading dimension. Labels may be Python
    numbers or tensors. ``torch.as_tensor`` is used for labels so that tensor
    labels are not re-copied; ``torch.tensor`` on a tensor warns and copies.
    """
    images = []
    labels = []
    for x, y in dataset:
        images.append(x.to(device))
        labels.append(torch.as_tensor(y, device=device))
    return TensorDataset(torch.stack(images), torch.stack(labels))


def evaluate_backdoor_asr(model, dataset, target_label, device, dataset_type='cifar',
                          batch_size=256, exclude_target_class=False):
    """Measure the backdoor attack success rate (ASR) of ``model``.

    The trigger is ``add_backdoor_trigger_cifar`` for 'cifar' and
    ``add_backdoor_trigger_mnist`` otherwise. It is applied to each sample
    individually. The ASR is the fraction of triggered samples predicted as
    ``target_label``.

    Args:
        model: classifier; set to eval mode.
        dataset: indexable dataset of (image, label) pairs.
        target_label: backdoor target class.
        device: inference device.
        dataset_type: selects the trigger function.
        batch_size: number of triggered samples per forward pass.
        exclude_target_class: if True, skip samples whose true label already
            equals ``target_label``; otherwise those samples count toward the
            ASR whenever they are classified correctly. Default False keeps
            the original metric.

    Returns:
        ASR in [0, 1]; 0.0 when no samples are evaluated.
    """
    model.eval()

    # Pick the trigger function for the dataset type
    backdoor_func = add_backdoor_trigger_cifar if dataset_type == 'cifar' else add_backdoor_trigger_mnist
    batch_size = max(1, int(batch_size))

    def _count_hits(samples):
        # One forward pass over a stacked batch of triggered samples.
        preds = model(torch.stack(samples).to(device)).argmax(dim=1)
        return int((preds == target_label).sum().item())

    total = 0
    correct = 0
    pending = []
    with torch.no_grad():
        for i in range(len(dataset)):
            x, y = dataset[i]
            if exclude_target_class and int(y) == target_label:
                continue
            pending.append(backdoor_func(x))
            if len(pending) == batch_size:
                correct += _count_hits(pending)
                total += len(pending)
                pending = []
        if pending:
            correct += _count_hits(pending)
            total += len(pending)

    return correct / total if total else 0.0


def select_model(args):
    """Build the classifier named by ``args.model`` for ``args.dataset``.

    Supported combinations:
    - 'cnn' → CNNCifar / CNNMnist;
    - 'resnet' or 'resnet18' → ResNet18. For MNIST, its ``conv1`` is replaced
      by a 1-input-channel 3x3 conv.

    Any other combination raises NotImplementedError.
    """
    def _mnist_resnet():
        net = ResNet18(num_classes=args.num_classes)
        # Swap the first conv so the network takes 1-channel MNIST input
        # (assumes ResNet18's stem conv is the `conv1` attribute — TODO confirm).
        net.conv1 = torch.nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        return net

    # Builders are lambdas so only the selected model is constructed.
    builders = {
        ('cifar', 'cnn'): lambda: CNNCifar(args=args),
        ('cifar', 'resnet'): lambda: ResNet18(num_classes=args.num_classes),
        ('cifar', 'resnet18'): lambda: ResNet18(num_classes=args.num_classes),
        ('mnist', 'cnn'): lambda: CNNMnist(args=args),
        ('mnist', 'resnet'): _mnist_resnet,
        ('mnist', 'resnet18'): _mnist_resnet,
    }

    build = builders.get((args.dataset, args.model))
    if build is None:
        raise NotImplementedError(f"Model {args.model} not implemented for {args.dataset}")
    return build()


def select_generator_discriminator(args):
    """Build the (generator, discriminator) pair matching ``args.dataset``.

    'cifar' → GeneratorCifar / DiscriminatorCifar on 3x32x32 images;
    'mnist' → Generator / Discriminator on 1x28x28 images. ``args.z_dim`` sets
    the generator noise size. Any other dataset raises NotImplementedError.
    """
    if args.dataset not in ('cifar', 'mnist'):
        raise NotImplementedError(f"Dataset {args.dataset} not supported")

    if args.dataset == 'cifar':
        shape = (3, 32, 32)
        gen_cls, disc_cls = GeneratorCifar, DiscriminatorCifar
    else:
        shape = (1, 28, 28)
        gen_cls, disc_cls = Generator, Discriminator

    # Generator is constructed before the discriminator (tuple evaluates left to right).
    return gen_cls(z_dim=args.z_dim, img_shape=shape), disc_cls(img_shape=shape)

In [5]:
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def main():
    """Run the full experiment: FedAvg training, Advanced UNGAN unlearning, evaluation.

    Stages:
      1. Load data and build a backdoor-poisoned training set. The malicious
         client is ``forget_client`` (0), the target label is 6 and the
         poison ratio is 0.8.
      2. Federated averaging for ``args.epochs`` rounds.
      3. Advanced UNGAN synthetic-data generation and visualization, then 10
         unlearning rounds over the retain clients plus the synthetic dataset.
      4. Test accuracy, MIA AUC, backdoor ASR and unseen-set accuracy, each
         measured before and after unlearning.
      5. Save both models and ``./advanced_ungan_results.json``.
    """
    start_time = time.time()
    args = args_parser()
    device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'

    exp_details(args)

    # Seed NumPy / PyTorch when the parser exposes a seed, so data partitioning,
    # client sampling and GAN noise are reproducible.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        np.random.seed(int(seed))
        torch.manual_seed(int(seed))

    forget_client = 0   # client to unlearn; also the malicious (poisoned) client
    target_label = 6    # backdoor target class, used for poisoning and ASR

    # ===================== 1. Load and initialize datasets =====================
    train_dataset, test_dataset, unseen_dataset, user_groups = get_dataset(args)

    print(f"Train dataset: {len(train_dataset)} samples")
    print(f"Test dataset: {len(test_dataset)} samples") 
    print(f"Unseen dataset: {len(unseen_dataset)} samples")

    # Build the backdoor-poisoned training set
    full_dataset, user_groups = create_poisoned_dataset(train_dataset, user_groups, args,
                                                        malicious_client=forget_client,
                                                        target_label=target_label,
                                                        poison_ratio=0.8)

    # Model initialization (chosen per dataset)
    global_model = select_model(args).to(device)
    global_model.train()

    # Generator / Discriminator initialization (chosen per dataset)
    generator, discriminator = select_generator_discriminator(args)
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    global_weights = global_model.state_dict()
    train_loss, train_accuracy = [], []

    # Unlearning targets
    forget_idxs = user_groups[forget_client]
    # Set membership keeps this O(n) instead of O(n * |forget|) for list/array indices;
    # int() normalizes numpy / tensor scalars so hashing matches plain ints.
    forget_set = {int(i) for i in forget_idxs}
    retain_idxs = [i for i in range(len(train_dataset)) if i not in forget_set]
    test_idxs = np.random.choice(len(test_dataset), min(len(forget_idxs), len(test_dataset)), replace=False)

    print(f"\nModel: {args.model.upper()}")
    print(f"Forget client: {forget_client}")
    print(f"Forget data size: {len(forget_idxs)}")
    print(f"Retain data size: {len(retain_idxs)}")

    # ===================== 2. Federated learning (FedAvg) =====================
    # NOTE(review): the forget (malicious) client is skipped in every round, so
    # the "pretrained" model never trains on client 0's data (presumably the
    # poisoned samples). The before/after ASR and MIA comparisons below then
    # compare two models that both exclude client 0 — confirm this is intended.
    # Setting `args.exclude_forget_client = False` lets the client participate.
    exclude_forget_client = getattr(args, 'exclude_forget_client', True)

    print("\n========== Starting Federated Learning ==========")
    for epoch in tqdm(range(args.epochs), desc='Global Training Rounds'):
        local_weights, local_losses = [], []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            if exclude_forget_client and idx == forget_client:
                continue  # skip the malicious client

            local_model = LocalUpdate(args=args, dataset=full_dataset, idxs=user_groups[idx])
            w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(loss)

        if local_weights:
            global_weights = average_weights(local_weights)
            global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses) if local_losses else 0
        acc, _ = test_inference(args, global_model, test_dataset)
        train_loss.append(loss_avg)
        train_accuracy.append(acc)

        if (epoch + 1) % 20 == 0:
            print(f"Round {epoch+1}: Loss {loss_avg:.4f}, Acc {acc*100:.2f}%")

    federated_time = time.time() - start_time
    print(f"\n========== Federated Learning Completed ==========")
    print(f"Training Time: {federated_time:.2f}s")

    # Back up the pretrained model
    pretrained_model = copy.deepcopy(global_model)

    # ===================== 3. Advanced UNGAN unlearning =====================
    unlearn_start_time = time.time()

    filtered_images, filtered_labels, trained_generator = advanced_ungan_workflow(
        generator, discriminator, full_dataset, unseen_dataset,
        retain_idxs, forget_idxs, args, device
    )

    visualize_advanced_ungan_results(
        trained_generator, discriminator, full_dataset, unseen_dataset,
        forget_idxs, filtered_images, filtered_labels, device, args, dataset_type=args.dataset
    )

    if len(filtered_images) > 0:
        # Hybrid quality evaluation
        quality_results = evaluate_hybrid_quality(
            global_model, filtered_images, filtered_labels,
            forget_idxs, unseen_dataset, full_dataset, device
        )

        print("\n========== Performing Unlearning with Synthetic Data ==========")
        synthetic_dataset = SyntheticImageDataset(filtered_images, filtered_labels)

        unlearned_model = copy.deepcopy(pretrained_model)
        unlearned_model.train()

        # Every client except the forget client takes part in unlearning.
        retain_clients = [idx for idx in range(args.num_users) if idx != forget_client]

        for unlearn_epoch in range(10):  # unlearning rounds
            local_weights, local_losses = [], []

            for idx in retain_clients:
                local_model = LocalUpdate(args=args, dataset=full_dataset, idxs=user_groups[idx])
                w, loss = local_model.update_weights(
                    model=copy.deepcopy(unlearned_model), 
                    global_round=unlearn_epoch
                )
                local_weights.append(copy.deepcopy(w))
                local_losses.append(loss)

            # Extra "client" trained on the synthetic data.
            # NOTE(review): assumes LocalUpdate treats idxs=None as "use the
            # whole dataset" — confirm against update.py.
            synthetic_local = LocalUpdate(args=args, dataset=synthetic_dataset, idxs=None)
            w_syn, loss_syn = synthetic_local.update_weights(
                model=copy.deepcopy(unlearned_model),
                global_round=unlearn_epoch
            )
            local_weights.append(copy.deepcopy(w_syn))
            local_losses.append(loss_syn)

            # Global model update
            unlearned_weights = average_weights(local_weights)
            unlearned_model.load_state_dict(unlearned_weights)

            if (unlearn_epoch + 1) % 5 == 0:
                loss_avg = sum(local_losses) / len(local_losses)
                acc, _ = test_inference(args, unlearned_model, test_dataset)
                print(f"Unlearn Epoch {unlearn_epoch + 1}: Loss {loss_avg:.4f}, Acc {acc*100:.2f}%")
    else:
        quality_results = {}
        print("Warning: No high-quality synthetic data generated")
        print("Skipping unlearning due to insufficient synthetic data")
        unlearned_model = pretrained_model

    unlearn_time = time.time() - unlearn_start_time

    # ===================== 4. Comprehensive evaluation =====================
    print("\n========== Comprehensive Evaluation ==========")

    test_acc_before, test_loss_before = test_inference(args, pretrained_model, test_dataset)
    test_acc_after, test_loss_after = test_inference(args, unlearned_model, test_dataset)
    retention = _safe_ratio(test_acc_after, test_acc_before)

    print(f"[Test Performance]")
    print(f"  Before: {test_acc_before*100:.2f}% | After: {test_acc_after*100:.2f}%")
    print(f"  Retention: {retention*100:.1f}%")

    # MIA evaluation
    print(f"\n[MIA Evaluation]")
    all_idxs = set(range(len(full_dataset)))
    non_member_candidates = list(all_idxs - set(forget_idxs))
    shadow_idxs = np.random.choice(non_member_candidates, 
                                 min(len(forget_idxs), len(non_member_candidates)), 
                                 replace=False)

    # NOTE(review): test_idxs index test_dataset but are passed as retain_idxs
    # next to dataset=full_dataset — confirm evaluate_mia reads them from test_dataset.
    mia_before = evaluate_mia(
        model=pretrained_model, dataset=full_dataset, test_dataset=test_dataset,
        forget_idxs=forget_idxs, retain_idxs=test_idxs, shadow_idxs=shadow_idxs,
        device=device, save_path="./mia_before_advanced.json"
    )

    mia_after = evaluate_mia(
        model=unlearned_model, dataset=full_dataset, test_dataset=test_dataset,
        forget_idxs=forget_idxs, retain_idxs=test_idxs, shadow_idxs=shadow_idxs,
        device=device, save_path="./mia_after_advanced.json"
    )

    print(f"  Before: {mia_before['auc']:.4f} | After: {mia_after['auc']:.4f}")
    print(f"  Privacy Gain: {mia_before['auc'] - mia_after['auc']:.4f}")

    # Backdoor ASR evaluation
    print(f"\n[Backdoor ASR Evaluation]")
    asr_before = evaluate_backdoor_asr(pretrained_model, test_dataset, target_label, device, args.dataset)
    asr_after = evaluate_backdoor_asr(unlearned_model, test_dataset, target_label, device, args.dataset)

    print(f"  Before: {asr_before*100:.2f}% | After: {asr_after*100:.2f}%")
    print(f"  Attack Reduction: {(asr_before - asr_after)*100:.2f}%")

    # Unseen dataset evaluation
    if unseen_dataset is not None:
        print(f"\n[Unseen Dataset Evaluation]")
        unseen_acc_before, _ = test_inference(args, pretrained_model, unseen_dataset)
        unseen_acc_after, _ = test_inference(args, unlearned_model, unseen_dataset)

        print(f"  Before: {unseen_acc_before*100:.2f}% | After: {unseen_acc_after*100:.2f}%")
        print(f"  Preservation: {_safe_ratio(unseen_acc_after, unseen_acc_before)*100:.1f}%")

    # ===================== 5. Save results and summary =====================

    # Make sure the checkpoint directory exists before saving.
    save_dir = os.path.dirname(args.save_model)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(pretrained_model.state_dict(), args.save_model)
    # splitext (rather than str.replace('.pth', ...)) never yields the pretrained
    # path itself, so the pretrained checkpoint is not overwritten.
    root, ext = os.path.splitext(args.save_model)
    unlearned_save_path = f"{root}_unlearned{ext}"
    torch.save(unlearned_model.state_dict(), unlearned_save_path)

    total_time = time.time() - start_time
    results = {
        'experiment_type': 'Advanced_UNGAN_Federated_Unlearning',
        'configuration': {
            'dataset': args.dataset,
            'model': args.model,
            'num_users': args.num_users,
            'epochs': args.epochs,
            'forget_client': forget_client,
        },
        'timing': {
            'federated_learning': federated_time,
            'unlearning': unlearn_time,
            'total': total_time
        },
        'performance': {
            'test_acc_before': float(test_acc_before),
            'test_acc_after': float(test_acc_after),
            'retention_rate': float(retention)
        },
        'privacy': {
            'mia_auc_before': float(mia_before['auc']),
            'mia_auc_after': float(mia_after['auc']),
            'privacy_improvement': float(mia_before['auc'] - mia_after['auc'])
        },
        'security': {
            'asr_before': float(asr_before),
            'asr_after': float(asr_after),
            'attack_mitigation': float(asr_before - asr_after)
        },
        'synthetic_data': {
            'generated_samples': len(filtered_images),
            'quality_metrics': quality_results
        }
    }

    with open('./advanced_ungan_results.json', 'w') as f:
        json.dump(results, f, indent=4)

    print(f"\nTotal time: {total_time:.2f}s | Results saved to ./advanced_ungan_results.json")

In [6]:
if __name__ == "__main__":
    # Command-line arguments for an MNIST + CNN run, injected into sys.argv so
    # that args_parser() inside main() picks them up.
    import sys
    DATASET, MODEL = 'mnist', 'cnn'
    sys.argv = [
        'main.py',
        '--epochs', '100',         # reduced for experimentation
        '--num_users', '10',
        '--frac', '1.0',
        '--local_ep', '10',        # reduced for experimentation
        '--local_bs', '64',
        '--lr', '0.01',
        '--momentum', '0.9',
        '--dataset', DATASET,
        '--model', MODEL,
        '--iid', '1',
        '--gpu', '0',
        '--num_classes', '10',
        '--load_model', 'None',
        # Checkpoint name derived from the dataset/model actually being run.
        '--save_model', f'./saved_models/advanced_{DATASET}_{MODEL}.pth',
        '--z_dim', '100',
        '--gen_threshold', '0.7',  # filtering threshold
        '--num_gen_samples', '128'
    ]

In [None]:
# Execute the full experiment: federated training, unlearning, evaluation and saving.
main()

===== Experiment Settings =====
Model           : cnn
Dataset         : mnist
Num Clients     : 10
Fraction        : 1.0
IID             : 1
Local Epochs    : 10
Batch Size      : 64
Learning Rate   : 0.01
Generator z_dim : 100
Disc. Threshold : 0.7


100%|██████████| 9.91M/9.91M [00:02<00:00, 4.91MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 145kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.48MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.57MB/s]


Train dataset: 55000 samples
Test dataset: 10000 samples
Unseen dataset: 5000 samples

Model: CNN
Forget client: 0
Forget data size: 5500
Retain data size: 49500



Global Training Rounds:   5%|▌         | 5/100 [01:24<26:25, 16.69s/it]