<a href="https://colab.research.google.com/github/Rainery-Ar/CS64-6/blob/main/3%20tasks%2Cweek9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Seed both RNGs so the simulated client parameters are reproducible.
torch.manual_seed(42)
np.random.seed(42)

print("="*80)
print(" Federated Learning LoRA Heterogeneous Rank Aggregation")
print(" Theoretical Verification Experiment")
print(" (Fully Corresponds to Code 1 Logic)")
print("="*80)

# 0. Experiment Configuration
print("\n[Experiment Configuration]")
d, k = 10, 10  # Base matrix dimensions: each B_i is d × r_i, each A_i is r_i × k
R_GLOBAL = 100  # Global model rank
client_ranks = [20, 25, 30]  # Client heterogeneous ranks
# Uniform client weights, derived from the client count so this list cannot
# drift out of sync with client_ranks when a client is added or removed.
weights = [1 / len(client_ranks)] * len(client_ranks)

# Relation between the summed client rank and the global rank. The equal case
# gets its own symbol instead of being reported as '>'.
total_rank = sum(client_ranks)
if total_rank < R_GLOBAL:
    rank_relation = '<'
elif total_rank == R_GLOBAL:
    rank_relation = '='
else:
    rank_relation = '>'

print(f"Base matrix dimensions: d={d}, k={k}")
print(f"Global model rank: {R_GLOBAL}")
print(f"Client ranks: {client_ranks}")
print(f"Total rank: {total_rank} {rank_relation} {R_GLOBAL}")
print(f"Client weights: {weights}")

# Simulate Client States
print("\n[Simulating Client LoRA Parameters]")
client_states = []
for i, rank in enumerate(client_ranks):
    # Draw A before B: the draw order determines the values produced by the
    # seeded generator, so it must stay fixed for reproducible results.
    # A: rank × k (e.g., 20 × 10)
    A = torch.randn(rank, k)
    # B: d × rank (e.g., 10 × 20)
    B = torch.randn(d, rank)
    client_states.append({'A': A, 'B': B, 'rank': rank})
    print(f"Client {i+1}: A shape={A.shape}, B shape={B.shape}, rank={rank}")

# Task 1: Parameter Stacking Aggregation
print("\n" + "="*80)
print("[Task 1] SVD Broadcast-based Parameter Stacking Aggregation")
print("="*80)

# Collect the per-client factors in client order.
a_blocks = [client['A'] for client in client_states]
b_blocks = [client['B'] for client in client_states]

# Every A_i is placed below the previous one (rows concatenated); every B_i is
# placed to the right of the previous one (columns concatenated).
A_stacked = torch.cat(a_blocks, dim=0)
B_stacked = torch.cat(b_blocks, dim=1)

# The stacked rank is the number of rows of A_stacked.
r_global_stacked = A_stacked.size(0)

# Product of the stacked factors; by block multiplication this is the sum of
# the individual B_i @ A_i products.
# NOTE(review): `weights` is not applied here, unlike the weighted aggregate
# built in Task 2 — confirm this is intended.
Delta_W_stacked = torch.matmul(B_stacked, A_stacked)

# Factorise the stacked update; these factors feed the error table below.
U, S, Vt = torch.linalg.svd(Delta_W_stacked)

print(f"Client ranks: {client_ranks}")
print(f"Global stacked rank r_G: {r_global_stacked} (sum of all client ranks)")
print(f"A_stacked shape: {A_stacked.shape} (all A vertically stacked)")
print(f"B_stacked shape: {B_stacked.shape} (all B horizontally stacked)")
print(f"Delta_W_stacked shape: {Delta_W_stacked.shape}")
print(f"SVD factors: U {U.shape}, S {S.shape}, Vt {Vt.shape}")
print(f"\n Task 1 Complete: All client information preserved, global rank={r_global_stacked}")

def _recon_errors_from_svd(delta_w, U, S, Vt, r_list, title):
    """Print a table of rank-r reconstruction errors for ``delta_w`` (Task 1).

    For every requested rank the leading singular triplets are split into
    ``B_r = U_r * sqrt(S_r)`` and ``A_r = sqrt(S_r) * Vt_r``. The table shows
    the norm of ``delta_w - B_r @ A_r`` and that error divided by the norm of
    ``delta_w``.

    Args:
        delta_w: Matrix the errors are measured against.
        U: Left factor, as returned by ``torch.linalg.svd``.
        S: 1-D tensor of singular values.
        Vt: Right factor, as returned by ``torch.linalg.svd``.
            The three factors are presumably computed from ``delta_w``;
            this function does not verify that.
        r_list: Iterable of integer ranks to tabulate.
        title: Label printed in the table header.

    Returns:
        None. The table is written to stdout.
    """
    print("\n[Task 1 • Reconstruction Error Table]")
    print(f"Target matrix: {title}  |  shape={tuple(delta_w.shape)}")
    norm_ref = torch.linalg.norm(delta_w).item()
    print(f"{'r':>6}  {'||ΔW - UΣVᵀ||_F':>20}  {'RelErr':>12}")
    print("-" * 44)
    max_rank = S.numel()
    for r in r_list:
        # Clamp into [0, max_rank]. Without the lower bound a negative rank
        # would slice from the end of S and silently use the wrong triplets.
        r_eff = max(0, min(r, max_rank))
        s_half = torch.sqrt(S[:r_eff])
        # Scale columns of U and rows of Vt directly; this is the same product
        # as multiplying by diag(s_half) without building the dense matrix.
        B_r = U[:, :r_eff] * s_half
        A_r = s_half.unsqueeze(1) * Vt[:r_eff, :]
        rec = B_r @ A_r
        err = torch.linalg.norm(delta_w - rec).item()
        # The floor on the denominator avoids dividing by zero for an
        # all-zero target matrix.
        rel = err / max(norm_ref, 1e-12)
        row = f"{r:6d}  {err:20.6e}  {rel:12.6e}"
        if r_eff != r:
            # Make it visible that this row is not a true rank-r result.
            row += f"  (clamped to r={r_eff})"
        print(row)

# Ranks to tabulate: rank 1, the smallest and the largest client rank, and the
# min(d, k) ceiling. The set removes duplicates; sorting gives ascending order.
r_candidates_task1 = sorted({1, min(client_ranks), max(client_ranks), min(d, k)})

# Reuse the factors already computed from Delta_W_stacked instead of running
# another SVD.
_recon_errors_from_svd(
    Delta_W_stacked,
    U,
    S,
    Vt,
    r_candidates_task1,
    "Delta_W_stacked",
)


# Task 2: SVD Dimensionality Reduction Reconstruction
print("\n" + "="*80)
print("[Task 2] SVD-based Dimensionality Reduction Reconstruction")
print("="*80)

# Step 1: dense update B_i @ A_i for every client.
delta_W_list = [torch.matmul(client['B'], client['A']) for client in client_states]

# Step 2: weighted sum of the dense updates, accumulated from 0 in client order.
Delta_W_agg = 0
for weight_i, update_i in zip(weights, delta_W_list):
    Delta_W_agg = Delta_W_agg + weight_i * update_i

# Step 3: factorise the aggregate.
U, S, Vt = torch.linalg.svd(Delta_W_agg)

# Step 4: the usable rank is the requested global rank capped by the number of
# singular values and by both matrix dimensions.
r_target = R_GLOBAL
r_actual = min(r_target, len(S), d, k)
U_r, S_r, Vt_r = U[:, :r_actual], S[:r_actual], Vt[:r_actual, :]

# Step 5: split sqrt(Σ) between the two factors so that B' @ A' = U_r Σ_r Vt_r.
S_diag_sqrt = torch.diag(S_r.sqrt())
B_prime = torch.matmul(U_r, S_diag_sqrt)  # d × r_actual
A_prime = torch.matmul(S_diag_sqrt, Vt_r)  # r_actual × k

# Distance between the aggregate and the product of the rebuilt factors.
Delta_W_reconstructed = torch.matmul(B_prime, A_prime)
reconstruction_error = torch.linalg.norm(Delta_W_agg - Delta_W_reconstructed)

print(f"Target rank r_target: {r_target}")
print(f"Actual reconstructed rank r_actual: {r_actual} (limited by min(d,k)={min(d,k)})")
print(f"Delta_W_agg shape: {Delta_W_agg.shape}")
print(f"B' shape: {B_prime.shape}")
print(f"A' shape: {A_prime.shape}")
print(f"Reconstruction error (Frobenius Norm): {reconstruction_error.item():.6f}")
print(f"\n Task 2 Complete: Optimal {r_actual}-rank approximation of all client updates")

def _recon_errors_from_svd_T2(delta_w, U, S, Vt, r_list, title):
    """Print a table of rank-r reconstruction errors for ``delta_w`` (Task 2).

    For every requested rank the leading singular triplets are split into
    ``B_r = U_r * sqrt(S_r)`` and ``A_r = sqrt(S_r) * Vt_r``. The table shows
    the norm of ``delta_w - B_r @ A_r`` and that error divided by the norm of
    ``delta_w``.

    Args:
        delta_w: Matrix the errors are measured against.
        U: Left factor, as returned by ``torch.linalg.svd``.
        S: 1-D tensor of singular values.
        Vt: Right factor, as returned by ``torch.linalg.svd``.
            The three factors are presumably computed from ``delta_w``;
            this function does not verify that.
        r_list: Iterable of integer ranks to tabulate.
        title: Label printed in the table header.

    Returns:
        None. The table is written to stdout.
    """
    print("\n[Task 2 • Reconstruction Error Table]")
    print(f"Target matrix: {title}  |  shape={tuple(delta_w.shape)}")
    norm_ref = torch.linalg.norm(delta_w).item()
    print(f"{'r':>6}  {'||ΔW - UΣVᵀ||_F':>20}  {'RelErr':>12}")
    print("-" * 44)
    max_rank = S.numel()
    for r in r_list:
        # Clamp into [0, max_rank]. Without the lower bound a negative rank
        # would slice from the end of S and silently use the wrong triplets.
        r_eff = max(0, min(r, max_rank))
        s_half = torch.sqrt(S[:r_eff])
        # Scale columns of U and rows of Vt directly; this is the same product
        # as multiplying by diag(s_half) without building the dense matrix.
        B_r = U[:, :r_eff] * s_half
        A_r = s_half.unsqueeze(1) * Vt[:r_eff, :]
        rec = B_r @ A_r
        err = torch.linalg.norm(delta_w - rec).item()
        # The floor on the denominator avoids dividing by zero for an
        # all-zero target matrix.
        rel = err / max(norm_ref, 1e-12)
        row = f"{r:6d}  {err:20.6e}  {rel:12.6e}"
        if r_eff != r:
            # Make it visible that this row is not a true rank-r result.
            row += f"  (clamped to r={r_eff})"
        print(row)

# Ranks to tabulate: rank 1, the smallest and the largest client rank, and the
# min(d, k) ceiling. The set removes duplicates; sorting gives ascending order.
r_candidates_task2 = sorted({1, min(client_ranks), max(client_ranks), min(d, k)})

# Reuse the factors already computed from Delta_W_agg instead of running
# another SVD.
_recon_errors_from_svd_T2(
    Delta_W_agg,
    U,
    S,
    Vt,
    r_candidates_task2,
    "Delta_W_agg",
)


# Task 3: Engineering Decision Inference for Rank Mismatch
print("\n" + "="*80)
print("[Task 3] Engineering Decision Inference: Global High-Rank vs Local Low-Rank")
print("="*80)

# Largest rank any client trains with; the inference sections below compare it
# against the global rank.
max_local_rank = max(client_ranks)

# Emit the four overview lines in one call; the newline-joined text produces
# the same output as four separate prints.
print("\n".join([
    f"Global model rank: {R_GLOBAL}",
    f"Client ranks: {client_ranks}",
    f"Maximum local rank: {max_local_rank}",
    f"Rank mismatch: Global {R_GLOBAL} > Max local {max_local_rank}",
]))

#Inference 1: Zero-Padding/Dilution
print("\n" + "-"*80)
print("Inference 1: Zero-Padding/Dilution")
print("-"*80)

# Value every entry of the simulated global factors starts from. Named once so
# the tensors and the printed messages cannot disagree.
GLOBAL_INIT_VALUE = 5.0

# The demonstration reads row `max_local_rank` as the "un-updated" region and
# pads every client up to R_GLOBAL rows, so the global rank must be strictly
# larger than every client rank. Fail early with a clear message instead of an
# opaque shape/index error further down.
if max_local_rank >= R_GLOBAL:
    raise ValueError(
        f"Zero-padding demo needs R_GLOBAL ({R_GLOBAL}) > max client rank "
        f"({max_local_rank})"
    )

# Simulate global model's initial parameters
A_global_init = torch.ones(R_GLOBAL, k) * GLOBAL_INIT_VALUE
B_global_init = torch.ones(d, R_GLOBAL) * GLOBAL_INIT_VALUE

# Each client pads zeros to global rank
A_clients_padded = []
B_clients_padded = []

for state in client_states:
    r_local = state['rank']
    # Client values occupy the leading r_local rows of A / columns of B; the
    # remaining entries stay zero.
    A_padded = torch.zeros(R_GLOBAL, k)
    A_padded[:r_local, :] = state['A']

    B_padded = torch.zeros(d, R_GLOBAL)
    B_padded[:, :r_local] = state['B']

    A_clients_padded.append(A_padded)
    B_clients_padded.append(B_padded)

# FedAvg mixing
weight_global = 0.5
weight_local = 0.5

# Start from the down-weighted global factors, then add each client's padded
# contribution scaled by weight_local * w_i.
A_scheme1 = weight_global * A_global_init
B_scheme1 = weight_global * B_global_init

for A_pad, B_pad, w in zip(A_clients_padded, B_clients_padded, weights):
    A_scheme1 += weight_local * w * A_pad
    B_scheme1 += weight_local * w * B_pad

# Verify dilution effect
val_updated = A_scheme1[0, 0].item()  # Row 0 receives values from every client
val_diluted = A_scheme1[max_local_rank, 0].item()  # First row no client writes to

print(f"Operation flow:")
print(f"  1. Each client pads A (r_i×{k}) with zeros to ({R_GLOBAL}×{k})")
print(f"  2. FedAvg: A_new = {weight_global}×A_global + {weight_local}×Σ(w_i×A_padded_i)")
print(f"\nNumerical verification (global initial value={GLOBAL_INIT_VALUE}):")
print(f"  Updated region (r=1~{max_local_rank}): {val_updated:.2f}")
print(f"  Diluted region (r={max_local_rank+1}~{R_GLOBAL}): {val_diluted:.2f}")
print(f"  Theoretical value: {GLOBAL_INIT_VALUE} × {weight_global} = {GLOBAL_INIT_VALUE * weight_global:.2f}")
print(f"\n  Problem: Un-updated dimensions from {GLOBAL_INIT_VALUE} → {val_diluted:.2f} (diluted by {(GLOBAL_INIT_VALUE-val_diluted)/GLOBAL_INIT_VALUE*100:.0f}%)")

#Inference 2: Truncation/Discarding High-Rank
print("\n" + "-"*80)
print("Inference 2: ΔW Aggregation with Truncation")
print("-"*80)

# Dense update of every client and their weighted aggregate.
delta_W_list_2 = [state['B'] @ state['A'] for state in client_states]
Delta_W_global = sum(w * dw for w, dw in zip(weights, delta_W_list_2))

# Mix with a constant stand-in for the global model's own Delta_W.
Delta_W_global_init = torch.ones(d, k) * 0.1
Delta_W_scheme2 = weight_global * Delta_W_global_init + weight_local * Delta_W_global

# SVD decomposition
U2, S2, Vt2 = torch.linalg.svd(Delta_W_scheme2)

# A d × k matrix has at most min(d, k) singular values. The padded global
# factors only have R_GLOBAL rank slots, so R_GLOBAL must cap the slice too.
# Otherwise the padded assignments below fail with a shape mismatch whenever
# R_GLOBAL < min(d, k).
r_slice = min(d, k, R_GLOBAL)

U_final = U2[:, :r_slice]
S_final = S2[:r_slice]
Vt_final = Vt2[:r_slice, :]

# Split sqrt(Σ) between the two factors so that B @ A = U Σ Vt on the kept rank.
S_diag_sqrt2 = torch.diag(torch.sqrt(S_final))
B_reconstructed = U_final @ S_diag_sqrt2
A_reconstructed = S_diag_sqrt2 @ Vt_final

# Pad to global rank: the reconstructed factors fill the leading r_slice
# rows/columns and everything after them stays zero.
A_global_2 = torch.zeros(R_GLOBAL, k)
A_global_2[:r_slice, :] = A_reconstructed

B_global_2 = torch.zeros(d, R_GLOBAL)
B_global_2[:, :r_slice] = B_reconstructed

# Rank slots of the global model that carry no information.
wasted_rank = R_GLOBAL - r_slice

print(f"Operation flow:")
print(f"  1. Compute Delta_W_i = B_i @ A_i ({d}×{k} matrix)")
print(f"  2. FedAvg: ΔW = {weight_global}×ΔW_global + {weight_local}×Σ(w_i×ΔW_i)")
print(f"  3. SVD decomposition and reconstruction")
print(f"\nNumerical verification:")
print(f"  Delta_W shape: {Delta_W_scheme2.shape}")
print(f"  SVD actual rank: {r_slice} (= min({d}, {k}, {R_GLOBAL}))")
print(f"  A_global_2 shape: {A_global_2.shape}")
print(f"  B_global_2 shape: {B_global_2.shape}")
print(f"  Non-zero rank range: r=1~{r_slice}")
print(f"  Zero rank range: r={r_slice+1}~{R_GLOBAL}")
print(f"\n  Problem: Global model {wasted_rank} rank dimensions completely wasted ({wasted_rank/R_GLOBAL*100:.0f}%)")

# Final Summary
print("\n" + "="*80)
print("[Experiment Summary]")
print("="*80)

# Percentages quoted more than once below; computed a single time here.
# 5.0 is the initial value of the simulated global factors in Inference 1.
dilution_pct = (5.0-val_diluted)/5.0*100
wasted_pct = wasted_rank/R_GLOBAL*100

# One entry per output line, printed in order.
summary_lines = [
    "\n Task 1 Verification:",
    f"  - Stacking aggregation handles heterogeneous ranks, global rank={r_global_stacked}",
    "  - All client information preserved, no loss",
    "\n Task 2 Verification:",
    f"  - SVD reconstruction error: {reconstruction_error.item():.6f}",
    f"  - Optimal {r_actual}-rank approximation (limited by min(d,k)={min(d,k)})",
    "\n Task 3 Verification:",
    f"  - Inference 1 (Zero-padding): Un-updated dimensions diluted by {dilution_pct:.0f}%",
    f"  - Inference 2 (Truncation): {wasted_rank} rank dimensions wasted ({wasted_pct:.0f}%)",
    "\n Key Findings:",
    f"  Scheme 1: Preserves high-rank structure → but dilutes {dilution_pct:.0f}% parameter information",
    f"  Scheme 2: SVD optimal approximation → but wastes {wasted_pct:.0f}% rank expressiveness",
    "\n  Fundamental trade-off between:",
    "  - High-rank structure vs Information density",
    "  - Global consistency vs Local optimality",
]
for summary_line in summary_lines:
    print(summary_line)

print("\n" + "="*80)
print("Experiment Complete!")
print("="*80)

 Federated Learning LoRA Heterogeneous Rank Aggregation
 Theoretical Verification Experiment
 (Fully Corresponds to Code 1 Logic)

[Experiment Configuration]
Base matrix dimensions: d=10, k=10
Global model rank: 100
Client ranks: [20, 25, 30]
Total rank: 75 < 100
Client weights: [0.3333333333333333, 0.3333333333333333, 0.3333333333333333]

[Simulating Client LoRA Parameters]
Client 1: A shape=torch.Size([20, 10]), B shape=torch.Size([10, 20]), rank=20
Client 2: A shape=torch.Size([25, 10]), B shape=torch.Size([10, 25]), rank=25
Client 3: A shape=torch.Size([30, 10]), B shape=torch.Size([10, 30]), rank=30

[Task 1] SVD Broadcast-based Parameter Stacking Aggregation
Client ranks: [20, 25, 30]
Global stacked rank r_G: 75 (sum of all client ranks)
A_stacked shape: torch.Size([75, 10]) (all A vertically stacked)
B_stacked shape: torch.Size([10, 75]) (all B horizontally stacked)
Delta_W_stacked shape: torch.Size([10, 10])
SVD factors: U torch.Size([10, 10]), S torch.Size([10]), Vt torch.Size