In [8]:
# Mount Google Drive so its contents are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Put the Drive root on the import path before importing the project module.
# NOTE(review): presumably preprocessing.py lives in /content/drive/MyDrive — confirm.
import sys
sys.path.append('/content/drive/MyDrive')
from preprocessing import FederatedDataBuilder

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [10]:
"""
Task Arithmetic federated learning - fully fixed version.
Resolves all known issues.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import numpy as np
from torch.utils.data import DataLoader, Subset
from preprocessing import FederatedDataBuilder

# ============================================================
# Fixed SparseSGDM
# ============================================================
class SparseSGDM(optim.SGD):
    """SGD with momentum whose updates are gated by per-parameter masks.

    ``masks`` is an optional mapping from a parameter tensor to a tensor that
    is multiplied element-wise into that parameter's gradient and, again,
    into the final update. Parameters without an entry are updated like
    plain SGD with momentum.
    """

    def __init__(self, params, lr=0.001, momentum=0.9, weight_decay=0.0, dampening=0, masks=None):
        super().__init__(params, lr=lr, momentum=momentum, weight_decay=weight_decay, dampening=dampening)
        self.masks = masks

    def _mask_for(self, p):
        """Return the mask registered for ``p``, or ``None`` when there is none."""
        if self.masks is not None and p in self.masks:
            return self.masks[p]
        return None

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            step_size = group['lr']
            decay = group['weight_decay']
            beta = group['momentum']
            damp = group['dampening']

            for p in group['params']:
                if p.grad is None:
                    continue

                mask = self._mask_for(p)

                # Mask the raw gradient (out-of-place, p.grad is left intact).
                update = p.grad if mask is None else p.grad.mul(mask)

                # L2 weight decay is added after the gradient mask.
                if decay != 0:
                    update = update.add(p, alpha=decay)

                if beta != 0:
                    state = self.state[p]
                    if 'momentum_buffer' in state:
                        velocity = state['momentum_buffer']
                        velocity.mul_(beta).add_(update, alpha=1 - damp)
                    else:
                        # First step: the buffer starts as a copy of the update.
                        velocity = state['momentum_buffer'] = torch.clone(update).detach()
                    update = velocity

                # Mask again so momentum / weight decay cannot move masked-out
                # entries; the stored buffer itself is not modified here.
                if mask is not None:
                    update = update.mul(mask)

                p.add_(update, alpha=-step_size)

        return loss


# ============================================================
# Fisher sensitivity and mask calibration
# ============================================================
def compute_fisher_sensitivity_head(model, dataloader, criterion, device, num_batches=5):
    """Estimate a per-entry sensitivity score for the parameters of ``model.head``.

    For up to ``num_batches`` batches the loss is back-propagated and the
    squared gradient of every trainable head parameter is accumulated; the
    accumulated value is then divided by the number of batches processed.

    Args:
        model: module exposing a ``head`` sub-module; it is switched to eval
            mode and left in that mode.
        dataloader: iterable yielding ``(inputs, labels)`` pairs.
        criterion: callable ``criterion(outputs, labels)`` returning the loss.
        device: device the batches are moved to.
        num_batches: maximum number of batches to use.

    Returns:
        dict mapping each trainable head parameter to a tensor of the same
        shape. When no batch is processed (empty loader or
        ``num_batches <= 0``) the scores are all zeros instead of NaN.
    """
    model.eval()

    head_params = [p for p in model.head.parameters() if p.requires_grad]
    sensitivity = {p: torch.zeros_like(p.data) for p in head_params}

    processed = 0
    # zip() stops on range() first, so no batch beyond num_batches is fetched.
    for _, (inputs, labels) in zip(range(num_batches), dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        model.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        for p in head_params:
            if p.grad is not None:
                sensitivity[p] += p.grad.data ** 2

        processed += 1

    # Guard against 0/0 -> NaN scores when nothing was processed.
    if processed > 0:
        for p in sensitivity:
            sensitivity[p] /= processed

    return sensitivity


def calibrate_masks(sensitivity_scores, sparsity_ratio=0.1, keep_least_sensitive=True):
    """Build 0/1 float masks selecting a fraction of entries by sensitivity.

    All scores are pooled and a single global threshold is chosen so that
    ``k = int(total_entries * sparsity_ratio)`` entries are selected. Entries
    tied with the threshold are all selected, so slightly more than ``k``
    entries can be active when scores repeat.

    Args:
        sensitivity_scores: mapping from a key (e.g. a parameter) to a score
            tensor.
        sparsity_ratio: fraction of entries to mark active, in ``[0, 1]``.
        keep_least_sensitive: when True the lowest-scoring entries are
            selected, otherwise the highest-scoring ones.

    Returns:
        dict with the same keys, each mapped to a float tensor of 0s and 1s
        shaped like the corresponding score tensor.

    Raises:
        ValueError: if ``sparsity_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= sparsity_ratio <= 1.0:
        raise ValueError(f"sparsity_ratio must be in [0, 1], got {sparsity_ratio}")

    if not sensitivity_scores:
        return {}

    # reshape(-1) also works for non-contiguous score tensors.
    all_scores = torch.cat([s.reshape(-1) for s in sensitivity_scores.values()])
    num_params = all_scores.numel()
    k = int(num_params * sparsity_ratio)

    # torch.kthvalue requires 1 <= k <= n; selecting nothing means all-zero masks.
    if k <= 0:
        return {p: torch.zeros_like(score, dtype=torch.float32)
                for p, score in sensitivity_scores.items()}

    if keep_least_sensitive:
        threshold = torch.kthvalue(all_scores, k).values.item()
        return {p: (score <= threshold).float() for p, score in sensitivity_scores.items()}

    # The k largest values start at the (n - k + 1)-th smallest one.
    threshold = torch.kthvalue(all_scores, num_params - k + 1).values.item()
    return {p: (score >= threshold).float() for p, score in sensitivity_scores.items()}


# ============================================================
# Model definition
# ============================================================
# Module-level cache holding the backbone once it has been loaded.
GLOBAL_DINO_BACKBONE = None

def get_dino_backbone():
    """Return the cached DINO backbone, loading it via torch.hub on first use.

    The loaded object is stored in ``GLOBAL_DINO_BACKBONE`` and the very same
    object is returned on every later call.
    """
    global GLOBAL_DINO_BACKBONE
    if GLOBAL_DINO_BACKBONE is not None:
        return GLOBAL_DINO_BACKBONE

    print("Loading DINO backbone...")
    GLOBAL_DINO_BACKBONE = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
    print("‚úì DINO loaded")
    return GLOBAL_DINO_BACKBONE


class DINOCIFAR100(nn.Module):
    """Frozen DINO backbone followed by a trainable linear classification head.

    The backbone is obtained from ``get_dino_backbone()``, all of its
    parameters have ``requires_grad`` disabled, and it is always run under
    ``torch.no_grad()`` in eval mode. Only ``head`` is trainable.

    Args:
        num_classes: number of output logits of the head.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.backbone = get_dino_backbone()
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        # Use the backbone's reported feature width when available; 384 is
        # the value previously hard-coded for dino_vits16.
        feature_dim = getattr(self.backbone, 'embed_dim', 384)
        self.head = nn.Linear(feature_dim, num_classes)
        nn.init.xavier_uniform_(self.head.weight)
        nn.init.zeros_(self.head.bias)

    def train(self, mode=True):
        """Set the training mode of the head while keeping the backbone in eval mode.

        The backbone is a frozen feature extractor, so ``model.train()`` must
        not re-enable any train-time behaviour inside it.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits for ``x``; backbone features carry no gradient."""
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(features)


# ============================================================
# FedAvg aggregation
# ============================================================
def fed_avg_aggregate(global_model, local_weights, client_sample_counts):
    """Compute the sample-weighted average of client state dicts (FedAvg).

    Entries whose key contains ``'backbone'`` or ``'num_batches_tracked'``
    are not averaged; they keep the global model's current value.

    Args:
        global_model: module whose ``state_dict()`` provides the key set and
            the values of the entries that are not averaged.
        local_weights: sequence of client state dicts.
        client_sample_counts: number of samples per client, aligned with
            ``local_weights``; used as averaging weights.

    Returns:
        A new state dict. When there is nothing to aggregate (no clients, or
        the sample counts sum to zero) it is an unchanged copy of the global
        state instead of a zeroed head.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if len(local_weights) != len(client_sample_counts):
        raise ValueError(
            f"got {len(local_weights)} state dicts but "
            f"{len(client_sample_counts)} sample counts"
        )

    global_dict = copy.deepcopy(global_model.state_dict())
    total_samples = sum(client_sample_counts)

    # Nothing to average: keep the current global weights rather than
    # returning zeros (no clients) or dividing by zero (no samples).
    if not local_weights or total_samples <= 0:
        return global_dict

    averaged_keys = [
        k for k in global_dict
        if 'num_batches_tracked' not in k and 'backbone' not in k
    ]

    for k in averaged_keys:
        reference = global_dict[k]
        # Accumulate in floating point, then cast back so the entry keeps
        # the dtype it has in the global model.
        acc_dtype = reference.dtype if reference.is_floating_point() else torch.float32
        accumulator = torch.zeros_like(reference, dtype=acc_dtype)
        for weights, count in zip(local_weights, client_sample_counts):
            contribution = weights[k].to(device=accumulator.device, dtype=acc_dtype)
            accumulator += contribution * (count / total_samples)
        global_dict[k] = accumulator.to(reference.dtype)

    return global_dict


# ============================================================
# Local training (Task Arithmetic)
# ============================================================
def local_train_task_arithmetic(model, train_dataset, client_indices, device,
                                 sparsity_ratio=0.1, local_epochs=4, lr=0.1, verbose=False):
    """Run one client's sparse fine-tuning of the model head.

    Steps: build a loader over the client's samples, compute head
    sensitivity scores, turn them into masks that select the
    least-sensitive entries, then train the head with ``SparseSGDM``.

    Args:
        model: model with a ``head`` sub-module; trained in place.
        train_dataset: dataset the client indices refer to.
        client_indices: iterable of sample indices owned by the client.
        device: device used for the model and the batches.
        sparsity_ratio: fraction of head entries allowed to change.
        local_epochs: number of passes over the client data.
        lr: learning rate for the optimizer.
        verbose: print progress information when True.

    Returns:
        Tuple ``(state_dict, num_local_samples, mean_of_epoch_losses)``.
    """
    def log(message):
        if verbose:
            print(message)

    model.train()
    model.to(device)

    client_data = Subset(train_dataset, list(client_indices))
    loader = DataLoader(client_data, batch_size=128, shuffle=True, num_workers=0)
    loss_fn = nn.CrossEntropyLoss()

    log(f"  Local samples: {len(client_data)}")
    log("  Computing Fisher sensitivity...")

    # Sensitivity is computed for the head only.
    scores = compute_fisher_sensitivity_head(model, loader, loss_fn, device, num_batches=5)

    log(f"  Calibrating masks (sparsity={sparsity_ratio})...")

    masks = calibrate_masks(scores, sparsity_ratio=sparsity_ratio, keep_least_sensitive=True)

    # Mask statistics: share of head entries that may be updated.
    head_size = sum(p.numel() for p in model.head.parameters())
    num_active = sum((mask > 0).sum().item() for mask in masks.values())
    active_share = num_active / head_size

    log(f"  Active params: {num_active}/{head_size} ({active_share:.2%})")
    log(f"  Training for {local_epochs} epochs...")

    # The sensitivity pass left the model in eval mode; switch back.
    model.train()
    optimizer = SparseSGDM(model.head.parameters(), lr=lr, momentum=0.9, masks=masks)

    epoch_means = []
    for epoch_number in range(1, local_epochs + 1):
        batch_losses = []
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            batch_loss = loss_fn(model(inputs), labels)
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.head.parameters(), max_norm=1.0)
            optimizer.step()

            batch_losses.append(batch_loss.item())

        mean_loss = np.mean(batch_losses)
        epoch_means.append(mean_loss)
        log(f"    Epoch {epoch_number}: loss={mean_loss:.4f}")

    return model.state_dict(), len(client_data), np.mean(epoch_means)


# ============================================================
# Evaluation function
# ============================================================
def evaluate(model, loader, device):
    """Evaluate ``model`` on ``loader`` with cross-entropy loss and top-1 accuracy.

    The loss is averaged per sample (not per batch), so a smaller final
    batch does not get more weight than its share of the data.

    Args:
        model: module mapping inputs to class logits; switched to eval mode.
        loader: iterable yielding ``(inputs, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)``. For a loader that yields no
        samples this is ``(nan, 0.0)`` instead of a ZeroDivisionError.
    """
    model.eval()
    # Sum over samples so the final division yields a true per-sample mean.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    correct = 0
    total = 0
    loss_sum = 0.0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return float('nan'), 0.0

    return loss_sum / total, 100 * correct / total


# ============================================================
# Main federated training loop
# ============================================================
def run_federated_task_arithmetic(rounds=50, num_clients=100, sampling_rate=0.1,
                                   sparsity=0.1, local_epochs=4, lr=0.1, verbose_client=False):
    """Run federated training with sparse (masked) local fine-tuning.

    Every round a random subset of clients trains a deep copy of the global
    model with ``local_train_task_arithmetic``; the results are combined
    with ``fed_avg_aggregate`` and the global model is evaluated on the test
    set.

    Args:
        rounds: number of communication rounds.
        num_clients: number of clients the training data is split across.
        sampling_rate: fraction of clients selected per round (at least one).
        sparsity: mask ratio forwarded to local training.
        local_epochs: local epochs per selected client.
        lr: local learning rate.
        verbose_client: print details for every client instead of only the
            first one of each round.

    Returns:
        Tuple ``(history, global_model)`` where ``history`` has the keys
        ``"accuracy"``, ``"loss"``, ``"train_loss"`` and ``"round"``.
    """
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rule = '=' * 70

    print(f"\n{rule}")
    print("Federated Learning with Task Arithmetic")
    print(rule)
    print(f"Device: {DEVICE}")
    print(f"Clients: {num_clients}, Sampling: {sampling_rate}")
    print(f"Rounds: {rounds}, Local Epochs: {local_epochs}")
    print(f"Sparsity: {sparsity}, Learning Rate: {lr}")
    print(f"{rule}\n")

    # Data preparation
    print("Preparing data...")
    builder = FederatedDataBuilder(K=num_clients)
    dict_users = builder.get_iid_partition()
    test_loader = DataLoader(builder.test_dataset, batch_size=256, shuffle=False, num_workers=0)

    # Global model initialisation
    print("Initializing global model...")
    global_model = DINOCIFAR100(num_classes=100).to(DEVICE)

    # Baseline performance before any training
    init_loss, init_acc = evaluate(global_model, test_loader, DEVICE)
    print(f"Initial accuracy: {init_acc:.2f}% (expected ~1% for random init)\n")

    history = {"accuracy": [], "loss": [], "train_loss": [], "round": []}

    # Number of clients taking part in each round
    m = max(int(sampling_rate * num_clients), 1)
    print(f"Starting training ({m} clients per round)...\n")

    for round_number in range(1, rounds + 1):
        print(rule)
        print(f"Round {round_number}/{rounds}")
        print(rule)

        round_weights, round_counts, round_losses = [], [], []

        # Random client selection without replacement
        selected_clients = np.random.choice(range(num_clients), m, replace=False)
        if m > 5:
            print(f"Selected clients: {selected_clients[:5]}...")
        else:
            print(f"Selected: {selected_clients}")

        for position, client_id in enumerate(selected_clients, start=1):
            # Detailed output only for the first client unless requested for all.
            show_details = bool(verbose_client or position == 1)
            if show_details:
                print(f"\nClient {position}/{m} (ID: {client_id}):")

            # Each client starts from its own deep copy of the global model.
            client_model = copy.deepcopy(global_model)

            client_state, sample_count, client_loss = local_train_task_arithmetic(
                client_model,
                builder.train_dataset,
                dict_users[client_id],
                DEVICE,
                sparsity_ratio=sparsity,
                local_epochs=local_epochs,
                lr=lr,
                verbose=show_details
            )

            round_weights.append(client_state)
            round_counts.append(sample_count)
            round_losses.append(client_loss)

        # FedAvg aggregation
        print(f"\nAggregating {len(round_weights)} models...")
        aggregated = fed_avg_aggregate(global_model, round_weights, round_counts)
        global_model.load_state_dict(aggregated, strict=False)

        # Global evaluation
        test_loss, test_acc = evaluate(global_model, test_loader, DEVICE)
        avg_train_loss = np.mean(round_losses)

        history["accuracy"].append(test_acc)
        history["loss"].append(test_loss)
        history["train_loss"].append(avg_train_loss)
        history["round"].append(round_number)

        print(f"\n{rule}")
        print(f"Round {round_number} Results:")
        print(f"  Avg Train Loss: {avg_train_loss:.4f}")
        print(f"  Test Loss: {test_loss:.4f}")
        print(f"  Test Accuracy: {test_acc:.2f}%")
        print(f"{rule}\n")

        # Sanity warning only: training continues, it is not stopped here.
        if round_number >= 6 and test_acc < 2:
            print("‚ö†Ô∏è  WARNING: Accuracy still < 2% after 5 rounds!")
            print("   This suggests a fundamental problem. Consider debugging.")

    # Final summary
    print(f"\n{rule}")
    print("Training Complete!")
    print(rule)
    print(f"Initial Accuracy: {init_acc:.2f}%")
    print(f"Final Accuracy: {history['accuracy'][-1]:.2f}%")
    print(f"Best Accuracy: {max(history['accuracy']):.2f}%")
    print(f"{rule}\n")

    return history, global_model


# ============================================================
# Main program
# ============================================================
if __name__ == "__main__":
    # Script entry point: run one federated experiment, then plot and save
    # the accuracy and loss curves.
    import matplotlib.pyplot as plt

    # Single experiment
    print("\n" + "#"*70)
    print("# Running Task Arithmetic Experiment")
    print("#"*70 + "\n")

    history, model = run_federated_task_arithmetic(
        rounds=30,
        num_clients=100,
        sampling_rate=0.1,
        sparsity=0.5,
        local_epochs=4,
        lr=0.1,
        verbose_client=False  # set to True to see detailed output for every client
    )

    # Plotting: test accuracy on the left, train/test loss on the right
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    ax1.plot(history['round'], history['accuracy'], 'b-o', linewidth=2, markersize=6)
    ax1.set_xlabel('Communication Round', fontsize=12)
    ax1.set_ylabel('Test Accuracy (%)', fontsize=12)
    ax1.set_title('Test Accuracy vs Round', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    # Leave 10% headroom above the best accuracy.
    ax1.set_ylim([0, max(history['accuracy']) * 1.1])

    ax2.plot(history['round'], history['train_loss'], 'g-s', label='Train Loss', linewidth=2, markersize=6)
    ax2.plot(history['round'], history['loss'], 'r-^', label='Test Loss', linewidth=2, markersize=6)
    ax2.set_xlabel('Communication Round', fontsize=12)
    ax2.set_ylabel('Loss', fontsize=12)
    ax2.set_title('Training Dynamics', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    # Save before show() so the file is written regardless of the backend.
    plt.savefig('task_arithmetic_results.png', dpi=150, bbox_inches='tight')
    print("\n‚úì Figure saved: task_arithmetic_results.png\n")
    plt.show()



######################################################################
# Experiment: Sparsity = 0.1
######################################################################


FedAvg + Task Arithmetic
Clients: 100, Sampling Rate: 0.1
Rounds: 20, Local Epochs: 4
Sparsity: 0.1, Learning Rate: 0.1
Device: cuda

Preparing data...
Creating IID partition for 100 clients...

Verifying Partition
Total samples: 45000/45000
No overlap
‚úì Avg classes per client: 99.0


Initializing global model...
Loading DINO backbone (first time only)...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


‚úì DINO backbone loaded
Total params: 21,704,164
Trainable params (head): 38,500
Frozen params (backbone): 21,665,664

Starting training (10 clients per round)...


Round 1/20
Selected clients: [40 42 81  5 77]...

Client 1/10 (ID: 40):
  Local data: 450 samples
  Computing Fisher sensitivity (head only)...
  Computed sensitivity for 2 parameter tensors
  Calibrating masks (sparsity=0.1)...
  Mask stats: 3850/38500 params active (10.00%)
  Sparse fine-tuning (4 epochs)...
    Epoch 1/4: loss=14.9837
    Epoch 2/4: loss=15.2387
    Epoch 3/4: loss=15.0581
    Epoch 4/4: loss=15.1615

Client 2/10 (ID: 42):
  Local data: 450 samples
  Computing Fisher sensitivity (head only)...
  Computed sensitivity for 2 parameter tensors
  Calibrating masks (sparsity=0.1)...
  Mask stats: 3850/38500 params active (10.00%)
  Sparse fine-tuning (4 epochs)...
    Epoch 1/4: loss=15.3726
    Epoch 2/4: loss=15.1657
    Epoch 3/4: loss=15.2694
    Epoch 4/4: loss=15.1609

Client 3/10 (ID: 81):
  Local data

KeyboardInterrupt: 

In [None]:
"""
Federated Task Arithmetic diagnostic script (final version).
Used to investigate why model accuracy is extremely low (~1%).

Uses the DINOCIFAR100 model class.
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import numpy as np

from preprocessing import FederatedDataBuilder
from taskarithmetic import compute_fisher_sensitivity, calibrate_masks

# ÂØºÂÖ•Ê®°Âûã - ÂÖºÂÆπ‰∏çÂêåÁöÑÂëΩÂêç
try:
    from fed_avg_iid import DINOCIFAR100Fixed as DINOCIFAR100
except ImportError:
    from fed_avg_iid import DINOCIFAR100


def diagnose_mask_problem(sparsity_ratio=0.1):
    """
    Diagnose whether the generated masks are too restrictive.

    Builds a model and one client's data loader, computes sensitivity
    scores and masks, then prints how many entries of each trainable
    parameter are active (mask sum) versus frozen.

    Args:
        sparsity_ratio: ratio forwarded to ``calibrate_masks``.

    Returns:
        False when no entry is active or fewer than 1% of the counted
        entries are active, True otherwise.
    """
    print("\n" + "="*70)
    print("ËØäÊñ≠ 1: Ê£ÄÊü•Êé©Á†ÅÁîüÊàê")
    print("="*70)

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare data
    builder = FederatedDataBuilder(K=10)
    dict_users = builder.get_iid_partition()

    # Create the model
    model = DINOCIFAR100(num_classes=100).to(DEVICE)

    # Data of a single client
    local_subset = Subset(builder.train_dataset, list(dict_users[0]))
    local_loader = DataLoader(local_subset, batch_size=32, shuffle=True)

    criterion = nn.CrossEntropyLoss()

    # Compute sensitivity scores
    print(f"\nËÆ°ÁÆóFisherÊïèÊÑüÂ∫¶ (sparsity={sparsity_ratio})...")
    sensitivity_scores = compute_fisher_sensitivity(
        model, local_loader, criterion, DEVICE, num_batches=5
    )

    # Generate masks
    masks = calibrate_masks(
        sensitivity_scores,
        sparsity_ratio=sparsity_ratio,
        keep_least_sensitive=True
    )

    # Analyse the masks
    print("\nÊé©Á†ÅÁªüËÆ°:")
    print("-" * 70)

    total_params = 0
    frozen_params = 0
    active_params = 0

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        mask = masks.get(param)
        if mask is None:
            continue

        num_params = int(param.numel())
        num_active = int(mask.sum().item())
        num_frozen = num_params - num_active

        total_params += num_params
        frozen_params += num_frozen
        active_params += num_active

        # Guard against empty parameters.
        active_ratio = 100 * num_active / num_params if num_params else 0.0
        print(f"{name:30s} | Total: {num_params:8d} | "
              f"Active: {num_active:8d} ({active_ratio:5.1f}%) | "
              f"Frozen: {num_frozen:8d}")

    # When no trainable parameter has a mask, total_params is 0; report 0%
    # instead of raising ZeroDivisionError so the check below can fail cleanly.
    overall_ratio = 100 * active_params / total_params if total_params else 0.0

    print("-" * 70)
    print(f"{'ÊÄªËÆ°':30s} | Total: {total_params:8d} | "
          f"Active: {active_params:8d} ({overall_ratio:5.1f}%) | "
          f"Frozen: {frozen_params:8d}")

    # Key checks
    if active_params == 0:
        print("\n‚ùå ‰∏•ÈáçÈîôËØØ: ÊâÄÊúâÂèÇÊï∞ÈÉΩË¢´ÂÜªÁªì!")
        print("   - Ê®°ÂûãÊó†Ê≥ïÂ≠¶‰π†")
        print("   - ÈúÄË¶ÅÊ£ÄÊü•calibrate_masksÂÆûÁé∞")
        return False

    if active_params < total_params * 0.01:  # less than 1%
        print("\n‚ö†Ô∏è  Ë≠¶Âëä: ÂèØÊõ¥Êñ∞ÂèÇÊï∞ËøáÂ∞ë!")
        print(f"   - Âè™Êúâ{overall_ratio:.2f}%ÁöÑÂèÇÊï∞ÂèØ‰ª•Êõ¥Êñ∞")
        print("   - Âª∫ËÆÆÂ¢ûÂ§ßsparsity_ratio")
        return False

    print("\n‚úì Êé©Á†ÅÁîüÊàêÊ≠£Â∏∏")
    return True


def diagnose_training_step():
    """
    Diagnose whether a single training step works.

    Checks, in order: that no backbone parameter requires grad, that a
    forward pass runs, that backward produces a gradient on the head weight,
    and that one SGD step changes the head weight.

    Returns:
        True when all checks pass, False at the first failed check.
    """
    print("\n" + "="*70)
    print("ËØäÊñ≠ 2: Ê£ÄÊü•ËÆ≠ÁªÉÊ≠•È™§")
    print("="*70)

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare data
    builder = FederatedDataBuilder(K=10)
    dict_users = builder.get_iid_partition()

    # Create the model
    model = DINOCIFAR100(num_classes=100).to(DEVICE)

    # Check that the backbone is frozen
    print("\nÊ£ÄÊü•backboneÂÜªÁªìÁä∂ÊÄÅ:")
    # Counts parameter tensors (not individual entries) with requires_grad set.
    backbone_params_trainable = sum(p.requires_grad for p in model.backbone.parameters())
    print(f"BackboneÂèØËÆ≠ÁªÉÂèÇÊï∞Êï∞: {backbone_params_trainable}")
    if backbone_params_trainable > 0:
        print("‚ùå ÈîôËØØ: BackboneÂ∫îËØ•Ë¢´ÂÆåÂÖ®ÂÜªÁªì!")
        return False
    print("‚úì BackboneÂ∑≤Ê≠£Á°ÆÂÜªÁªì")

    # Inspect the head
    print("\nHeadÂèÇÊï∞:")
    for name, param in model.head.named_parameters():
        print(f"  {name}: requires_grad={param.requires_grad}, shape={param.shape}")

    # Prepare local data
    local_subset = Subset(builder.train_dataset, list(dict_users[0]))
    local_loader = DataLoader(local_subset, batch_size=32, shuffle=True)

    # Fetch a single batch
    inputs, targets = next(iter(local_loader))
    inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

    # Forward pass
    print("\nÊµãËØïÂâçÂêë‰º†Êí≠:")
    model.eval()
    with torch.no_grad():
        outputs = model(inputs)
        print(f"  ËæìÂá∫ÂΩ¢Áä∂: {outputs.shape}")
        print(f"  ËæìÂá∫ËåÉÂõ¥: [{outputs.min().item():.2f}, {outputs.max().item():.2f}]")

        # Check the initial accuracy on this one batch
        _, predicted = outputs.max(1)
        correct = predicted.eq(targets).sum().item()
        acc = 100. * correct / targets.size(0)
        print(f"  ÂàùÂßãÁ≤æÂ∫¶ (ÈöèÊú∫): {acc:.2f}%")

        # Warning only; the diagnostic continues either way.
        if acc < 0.5 or acc > 5:
            print(f"  ‚ö†Ô∏è  Ë≠¶Âëä: ÂàùÂßãÁ≤æÂ∫¶ÂºÇÂ∏∏ (ÊúüÊúõ~1%)")

    # Test the backward pass
    print("\nÊµãËØïÂèçÂêë‰º†Êí≠:")
    model.train()
    criterion = nn.CrossEntropyLoss()

    # Snapshot the initial weights for the comparison after the step
    initial_weight = model.head.weight.clone()

    # Train for one step
    optimizer = torch.optim.SGD(model.head.parameters(), lr=0.1)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()

    # Check the gradient
    if model.head.weight.grad is None:
        print("  ‚ùå ÈîôËØØ: Ê≤°ÊúâËÆ°ÁÆóÊ¢ØÂ∫¶!")
        return False

    grad_norm = model.head.weight.grad.norm().item()
    print(f"  Ê¢ØÂ∫¶ËåÉÊï∞: {grad_norm:.4f}")

    # Warning only; a tiny gradient does not fail the diagnostic by itself.
    if grad_norm < 1e-6:
        print("  ‚ö†Ô∏è  Ë≠¶Âëä: Ê¢ØÂ∫¶ËøáÂ∞è")

    # Update the weights
    optimizer.step()

    # Check whether the weights changed
    weight_change = (model.head.weight - initial_weight).abs().max().item()
    print(f"  ÊùÉÈáçÊúÄÂ§ßÂèòÂåñ: {weight_change:.6f}")

    if weight_change < 1e-8:
        print("  ‚ùå ÈîôËØØ: ÊùÉÈáçÊ≤°ÊúâÊõ¥Êñ∞!")
        return False

    print("  ‚úì ËÆ≠ÁªÉÊ≠•È™§Ê≠£Â∏∏")
    return True


def diagnose_aggregation():
    """
    Diagnose whether FedAvg aggregation changes the global model.

    Two freshly created models get random noise added to their head weight;
    their state dicts are aggregated with equal sample counts and loaded
    into a third (global) model.

    Returns:
        True when the global head weight changed by at least 1e-8 in its
        largest entry, False otherwise.
    """
    divider = "="*70
    print("\n" + divider)
    print("ËØäÊñ≠ 3: Ê£ÄÊü•FedAvgËÅöÂêà")
    print(divider)

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    from fed_avg_iid import fed_avg_aggregate

    # Global model
    global_model = DINOCIFAR100(num_classes=100).to(DEVICE)

    # Two simulated client models with randomly perturbed head weights
    simulated_states = []
    for _ in range(2):
        client_model = DINOCIFAR100(num_classes=100).to(DEVICE)
        with torch.no_grad():
            head_weight = client_model.head.weight
            head_weight.add_(torch.randn_like(head_weight) * 0.1)
        simulated_states.append(client_model.state_dict())

    sample_counts = [100, 100]

    # Run the aggregation
    print("\nÊâßË°åËÅöÂêà...")
    weight_before = global_model.head.weight.clone()

    aggregated_state = fed_avg_aggregate(global_model, simulated_states, sample_counts)
    global_model.load_state_dict(aggregated_state, strict=False)

    # Largest absolute change of any head weight entry
    max_delta = (global_model.head.weight - weight_before).abs().max().item()
    print(f"ÂÖ®Â±ÄÊ®°ÂûãÊùÉÈáçÊúÄÂ§ßÂèòÂåñ: {max_delta:.6f}")

    if max_delta >= 1e-8:
        print("‚úì ËÅöÂêàÊ≠£Â∏∏")
        return True

    print("‚ùå ÈîôËØØ: ËÅöÂêàÂêéÂÖ®Â±ÄÊ®°ÂûãÊùÉÈáçÊ≤°ÊúâÊîπÂèò!")
    return False


def test_without_task_arithmetic():
    """
    Run plain FedAvg (no Task Arithmetic) as a baseline check.

    For up to three rounds, two randomly chosen clients each train the head
    for at most four SGD steps starting from the global weights; the results
    are aggregated and the global model is evaluated on the test set.

    Returns:
        True as soon as a round reaches more than 3% test accuracy, False if
        none of the three rounds does.
    """
    divider = "="*70
    print("\n" + divider)
    print("ËØäÊñ≠ 4: ÊµãËØïÊ†áÂáÜFedAvg (Êó†Task Arithmetic)")
    print(divider)

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data preparation
    builder = FederatedDataBuilder(K=10)
    dict_users = builder.get_iid_partition()
    test_loader = DataLoader(builder.test_dataset, batch_size=256, shuffle=False)

    # Global model
    global_model = DINOCIFAR100(num_classes=100).to(DEVICE)

    from fed_avg_iid import fed_avg_aggregate, evaluate_global

    print("\nËøêË°å3ËΩÆÊ†áÂáÜFedAvg...")

    for round_number in (1, 2, 3):
        # Pick 2 of the 10 clients
        chosen_clients = np.random.choice(range(10), 2, replace=False)

        collected_states = []
        collected_counts = []

        for client_idx in chosen_clients:
            # Local model starts from the current global weights
            client_model = DINOCIFAR100(num_classes=100).to(DEVICE)
            client_model.load_state_dict(global_model.state_dict())

            client_data = Subset(builder.train_dataset, list(dict_users[client_idx]))
            client_loader = DataLoader(client_data, batch_size=32, shuffle=True)

            optimizer = torch.optim.SGD(client_model.head.parameters(), lr=0.1, momentum=0.9)
            criterion = nn.CrossEntropyLoss()

            client_model.train()

            # At most J=4 steps; zip() stops on range() first, so no fifth
            # batch is pulled from the loader, and it also ends early when
            # the loader runs out of batches.
            for _, (inputs, targets) in zip(range(4), client_loader):
                inputs = inputs.to(DEVICE)
                targets = targets.to(DEVICE)

                optimizer.zero_grad()
                step_loss = criterion(client_model(inputs), targets)
                step_loss.backward()
                optimizer.step()

            collected_states.append(client_model.state_dict())
            collected_counts.append(len(dict_users[client_idx]))

        # Aggregate
        aggregated_state = fed_avg_aggregate(global_model, collected_states, collected_counts)
        global_model.load_state_dict(aggregated_state, strict=False)

        # Evaluate
        test_loss, test_acc = evaluate_global(global_model, test_loader, DEVICE)
        print(f"Round {round_number}: Test Acc = {test_acc:.2f}%")

        if test_acc < 1.0:
            print("  ‚ö†Ô∏è  Á≤æÂ∫¶‰ªçÁÑ∂Ëøá‰Ωé!")
        elif test_acc > 3.0:
            print("  ‚úì Á≤æÂ∫¶ÂºÄÂßãÊèêÂçá,Âü∫Á°ÄÊµÅÁ®ãÊ≠£Â∏∏")
            return True

    return False


def main():
    """
    Run all four diagnostics in order, then print a summary and suggestions.
    """
    banner = "üîç"*35
    divider = "="*70

    print("\n" + banner)
    print("      ËÅîÈÇ¶Task Arithmetic ËØäÊñ≠Â∑•ÂÖ∑")
    print(banner)

    # Diagnostic 1: masks
    mask_ok = diagnose_mask_problem(sparsity_ratio=0.1)
    # Diagnostic 2: training step
    training_ok = diagnose_training_step()
    # Diagnostic 3: aggregation
    aggregation_ok = diagnose_aggregation()
    # Diagnostic 4: FedAvg without Task Arithmetic
    fedavg_ok = test_without_task_arithmetic()

    def verdict(ok):
        """Return the pass/fail label printed in the summary."""
        return '‚úì Ê≠£Â∏∏' if ok else '‚ùå ÂºÇÂ∏∏'

    def show(lines):
        """Print every line of ``lines`` in order."""
        for line in lines:
            print(line)

    # Summary
    print("\n" + divider)
    print("ËØäÊñ≠ÊÄªÁªì")
    print(divider)
    print(f"1. Êé©Á†ÅÁîüÊàê: {verdict(mask_ok)}")
    print(f"2. ËÆ≠ÁªÉÊ≠•È™§: {verdict(training_ok)}")
    print(f"3. FedAvgËÅöÂêà: {verdict(aggregation_ok)}")
    print(f"4. Ê†áÂáÜFedAvg: {verdict(fedavg_ok)}")

    # Suggestions
    print("\n" + divider)
    print("Âª∫ËÆÆ:")
    print(divider)

    if not mask_ok:
        show((
            "1. Ê£ÄÊü•calibrate_masksÂáΩÊï∞ÂÆûÁé∞",
            "2. Â∞ùËØïÊõ¥Â§ßÁöÑsparsity_ratio (Â¶Ç0.5)",
            "3. Á°ÆËÆ§keep_least_sensitiveÈÄªËæëÊ≠£Á°Æ",
        ))

    if not training_ok:
        show((
            "1. Ê£ÄÊü•Ê®°ÂûãÂàùÂßãÂåñ",
            "2. Á°ÆËÆ§backboneÊ≠£Á°ÆÂÜªÁªì",
            "3. Ë∞ÉÊï¥Â≠¶‰π†Áéá",
        ))

    if not fedavg_ok:
        show((
            "1. Âü∫Á°ÄFedAvgÂ∞±ÊúâÈóÆÈ¢ò,ÂÖà‰øÆÂ§çÂÆÉ",
            "2. Ê£ÄÊü•Êï∞ÊçÆÂä†ËΩΩ",
            "3. Â¢ûÂä†Êú¨Âú∞ËÆ≠ÁªÉÊ≠•Êï∞",
        ))

        # Only the baseline failed while every other diagnostic passed.
        if mask_ok and training_ok and aggregation_ok:
            show((
                "1. ÈóÆÈ¢òÂèØËÉΩÂú®Êï∞ÊçÆÂ§ÑÁêÜÊàñÊ®°ÂûãÊû∂ÊûÑ",
                "2. Â∞ùËØïËøêË°åfed_avg_iid.pyÁúãÊòØÂê¶Ê≠£Â∏∏",
            ))

    show((
        "\nüí° Âø´ÈÄü‰øÆÂ§çÂª∫ËÆÆ:",
        "   - ÂÖàÁ°Æ‰øùÊ†áÂáÜFedAvgËÉΩwork (Á≤æÂ∫¶>5%)",
        "   - ÂÜçÂä†ÂÖ•Task Arithmetic",
        "   - ‰ΩøÁî®ËæÉÂ§ßÁöÑsparsity_ratioÂºÄÂßãÊµãËØï",
    ))


if __name__ == "__main__":
    main()


üîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîç
      ËÅîÈÇ¶Task Arithmetic ËØäÊñ≠Â∑•ÂÖ∑
üîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîçüîç

ËØäÊñ≠ 1: Ê£ÄÊü•Êé©Á†ÅÁîüÊàê
Creating IID partition for 10 clients...
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main



ËÆ°ÁÆóFisherÊïèÊÑüÂ∫¶ (sparsity=0.1)...
Calculating sensitivity over 5 batches...

Êé©Á†ÅÁªüËÆ°:
----------------------------------------------------------------------
backbone.cls_token             | Total:      384 | Active:        0 (  0.0%) | Frozen:      384
backbone.pos_embed             | Total:    75648 | Active:      316 (  0.4%) | Frozen:    75332
backbone.patch_embed.proj.weight | Total:   294912 | Active:        0 (  0.0%) | Frozen:   294912
backbone.patch_embed.proj.bias | Total:      384 | Active:        0 (  0.0%) | Frozen:      384
backbone.blocks.0.norm1.weight | Total:      384 | Active:      207 ( 53.9%) | Frozen:      177
backbone.blocks.0.norm1.bias   | Total:      384 | Active:       44 ( 11.5%) | Frozen:      340
backbone.blocks.0.attn.qkv.weight | Total:   442368 | Active:   351693 ( 79.5%) | Frozen:    90675
backbone.blocks.0.attn.qkv.bias | Total:     1152 | Active:      685 ( 59.5%) | Frozen:      467
backbone.blocks.0.attn.proj.weight | Total:   147456 | Ac

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main



Ê£ÄÊü•backboneÂÜªÁªìÁä∂ÊÄÅ:
BackboneÂèØËÆ≠ÁªÉÂèÇÊï∞Êï∞: 150
‚ùå ÈîôËØØ: BackboneÂ∫îËØ•Ë¢´ÂÆåÂÖ®ÂÜªÁªì!

ËØäÊñ≠ 3: Ê£ÄÊü•FedAvgËÅöÂêà
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main



ÊâßË°åËÅöÂêà...
ÂÖ®Â±ÄÊ®°ÂûãÊùÉÈáçÊúÄÂ§ßÂèòÂåñ: 0.325009
‚úì ËÅöÂêàÊ≠£Â∏∏

ËØäÊñ≠ 4: ÊµãËØïÊ†áÂáÜFedAvg (Êó†Task Arithmetic)
Creating IID partition for 10 clients...
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main



ËøêË°å3ËΩÆÊ†áÂáÜFedAvg...
Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Downloading/Loading DINO ViT-S/16...


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Round 1: Test Acc = 6.55%
  ‚úì Á≤æÂ∫¶ÂºÄÂßãÊèêÂçá,Âü∫Á°ÄÊµÅÁ®ãÊ≠£Â∏∏

ËØäÊñ≠ÊÄªÁªì
1. Êé©Á†ÅÁîüÊàê: ‚úì Ê≠£Â∏∏
2. ËÆ≠ÁªÉÊ≠•È™§: ‚ùå ÂºÇÂ∏∏
3. FedAvgËÅöÂêà: ‚úì Ê≠£Â∏∏
4. Ê†áÂáÜFedAvg: ‚úì Ê≠£Â∏∏

Âª∫ËÆÆ:
1. Ê£ÄÊü•Ê®°ÂûãÂàùÂßãÂåñ
2. Á°ÆËÆ§backboneÊ≠£Á°ÆÂÜªÁªì
3. Ë∞ÉÊï¥Â≠¶‰π†Áéá

üí° Âø´ÈÄü‰øÆÂ§çÂª∫ËÆÆ:
   - ÂÖàÁ°Æ‰øùÊ†áÂáÜFedAvgËÉΩwork (Á≤æÂ∫¶>5%)
   - ÂÜçÂä†ÂÖ•Task Arithmetic
   - ‰ΩøÁî®ËæÉÂ§ßÁöÑsparsity_ratioÂºÄÂßãÊµãËØï
