<a href="https://colab.research.google.com/github/Rainery-Ar/CS64-6/blob/main/Verification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch

def fedavg_svd_aggregation_high_rank(client_data, global_dims, client_weights=None):
    """
    Simulate heterogeneous-rank (R=1, 2, 3) aggregation with SVD.

    Each client's factors are multiplied into a full update ``A @ B``, the
    updates are combined as a weighted sum, and the sum is re-factorised with
    a reduced SVD into ``A_global`` / ``B_global`` (each factor carries the
    square root of the singular values). Every client is then shown the
    leading ``rank`` columns of ``A_global`` and rows of ``B_global``.
    All intermediate results are printed; nothing is returned.

    Args:
        client_data: Non-empty sequence of ``(rank, A, B)`` tuples. ``A @ B``
            is accumulated into a matrix of shape ``global_dims``.
        global_dims: ``(rows, cols)`` of the aggregated matrix. Need not be
            square.
        client_weights: Optional per-client scalar weights, one per entry of
            ``client_data``. Defaults to equal weights ``1 / num_clients``.

    Raises:
        ValueError: If ``client_data`` is empty, or ``client_weights`` does
            not have exactly one entry per client.
    """
    # Validate before printing anything so bad input fails fast and cleanly.
    num_clients = len(client_data)
    if num_clients == 0:
        raise ValueError("client_data must contain at least one client.")
    if client_weights is None:
        client_weights = [1.0 / num_clients] * num_clients
    elif len(client_weights) != num_clients:
        raise ValueError("Length of client_weights must match number of clients.")

    print("✓ FedAvg SVD High-Rank Simulation defined successfully!\n")
    print("==================================================")
    print("=== ORIGINAL CLIENT STATES (True Heterogeneous Rank) ===")
    print("==================================================")

    # --- 1. Client contributions (product first) ---
    # Accumulate in the dtype/device of the first client's A so the running
    # sum matches the inputs instead of torch's global defaults.
    first_A = client_data[0][1]
    aggregated_delta_w = torch.zeros(global_dims, dtype=first_A.dtype, device=first_A.device)

    for i, ((rank, A, B), weight) in enumerate(zip(client_data, client_weights)):
        W = A @ B
        aggregated_delta_w += float(weight) * W

        # Print client info
        print(f"\nClient {i} (rank={rank}):")
        print(f"A shape: {A.shape} | A:\n {A}")
        print(f"B shape: {B.shape} | B:\n {B}")
        print(f"A @ B (full weight matrix / Delta W_{i}):\n {W}")

    # --- 2. Server aggregate target (sum of products result) ---
    print("\n==================================================")
    print("=== AGGREGATED TARGET MATRIX ===")
    print("==================================================")
    print(f"Delta W_agg:\n {aggregated_delta_w}")

    # --- 3. SVD decomposition ---
    # Reduced SVD: U is (rows, k) and Vh is (k, cols) with k = min(rows, cols),
    # so the diag(sqrt(S)) products below are shape-compatible for non-square
    # targets too. For square targets the factor shapes are unchanged.
    U, S, Vh = torch.linalg.svd(aggregated_delta_w, full_matrices=False)

    # Diagnostics: singular values, effective rank, reconstruction error
    print("\n--------------------------------------------------")
    print("=== SVD DIAGNOSTICS ===")
    print("--------------------------------------------------")
    print(f"Singular values S:\n {S}")
    # Singular values at or below max(S) * max(dims) * eps count as zero.
    tol = S.max() * max(global_dims) * torch.finfo(aggregated_delta_w.dtype).eps
    eff_rank = int((S > tol).sum().item())
    print(f"Effective numerical rank (tol={tol:.3e}): {eff_rank}")

    # Factor assignment: split sqrt of singular values to A and B
    S_sqrt = torch.sqrt(S)
    A_global = U @ torch.diag(S_sqrt)
    B_global = torch.diag(S_sqrt) @ Vh

    # Sign alignment: negating both factors leaves A_global @ B_global
    # unchanged while making A_global[0, 0] non-negative.
    if A_global[0, 0] < 0:
        A_global = -A_global
        B_global = -B_global

    # Reconstruction check
    global_reconstruction = A_global @ B_global
    rec_err = torch.linalg.norm(aggregated_delta_w - global_reconstruction, ord='fro').item()
    print(f"Reconstruction Frobenius error: {rec_err:.3e}")

    # --- 4. Print Global State ---
    print("\n==================================================")
    print("=== GLOBAL STATE (SVD DECOMPOSED) ===")
    print("==================================================")
    print(f"Global A shape: {A_global.shape}")
    print(f"Global A:\n {A_global}")
    print(f"Global B shape: {B_global.shape}")
    print(f"Global B:\n {B_global}")
    print(f"Global A @ B (reconstructed):\n {global_reconstruction}")

    # --- 5. Client-side slicing validation (rank-specific broadcast) ---
    print("\n==================================================")
    print("=== LOCAL STATES (RANK-SPECIFIC SLICES) ===")
    print("==================================================")

    for i, (rank, _, _) in enumerate(client_data):
        # Keep the first `rank` singular directions; a rank larger than the
        # number of available directions simply yields the full factors.
        A_sliced = A_global[:, :rank]
        B_sliced = B_global[:rank, :]

        W_reconstructed = A_sliced @ B_sliced

        print(f"\nClient {i} (rank={rank}):")
        print(f"A shape: {A_sliced.shape} | A (sliced):\n {A_sliced}")
        print(f"B shape: {B_sliced.shape} | B (sliced):\n {B_sliced}")
        print(f"A @ B (reconstructed with rank {rank} - Approximation of Delta W_agg):")
        print(W_reconstructed)


# --- Demo inputs: three clients whose factor pairs use inner ranks 1, 2 and 3 ---
DIM = (3, 3)

# Client 0 (R=1): an all-ones column times an all-ones row, a rank-1 product.
R0 = 1
A0 = torch.ones((DIM[0], R0))
B0 = torch.ones((R0, DIM[1]))

# Client 1 (R=2): the two columns of A1 and the two rows of B1 are linearly
# independent, so A1 @ B1 is a genuine rank-2 matrix.
R1 = 2
A1 = torch.tensor([
    [2., 1.],
    [2., 0.],
    [2., 1.],
])
B1 = torch.tensor([
    [2., 2., 2.],
    [1., 3., 1.],
])

# Client 2 (R=3): the factors are 3x3, but their first and third rows are
# identical, so the product only has rank 2 despite the nominal rank of 3.
R2 = 3
A2 = torch.tensor([
    [1., 0., 1.],
    [0., 1., 0.],
    [1., 0., 1.],
])
B2 = torch.tensor([
    [1., 0., 1.],
    [0., 1., 0.],
    [1., 0., 1.],
])

# One (rank, A, B) tuple per client, in client order.
client_data_list = list(zip((R0, R1, R2), (A0, A1, A2), (B0, B1, B2)))

# --- Run the simulation with compact, fixed-point tensor printing ---
torch.set_printoptions(precision=4, sci_mode=False)
fedavg_svd_aggregation_high_rank(client_data_list, DIM)


✓ FedAvg SVD High-Rank Simulation defined successfully!

=== ORIGINAL CLIENT STATES (True Heterogeneous Rank) ===

Client 0 (rank=1):
A shape: torch.Size([3, 1]) | A:
 tensor([[1.],
        [1.],
        [1.]])
B shape: torch.Size([1, 3]) | B:
 tensor([[1., 1., 1.]])
A @ B (full weight matrix / Delta W_0):
 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

Client 1 (rank=2):
A shape: torch.Size([3, 2]) | A:
 tensor([[2., 1.],
        [2., 0.],
        [2., 1.]])
B shape: torch.Size([2, 3]) | B:
 tensor([[2., 2., 2.],
        [1., 3., 1.]])
A @ B (full weight matrix / Delta W_1):
 tensor([[5., 7., 5.],
        [4., 4., 4.],
        [5., 7., 5.]])

Client 2 (rank=3):
A shape: torch.Size([3, 3]) | A:
 tensor([[1., 0., 1.],
        [0., 1., 0.],
        [1., 0., 1.]])
B shape: torch.Size([3, 3]) | B:
 tensor([[1., 0., 1.],
        [0., 1., 0.],
        [1., 0., 1.]])
A @ B (full weight matrix / Delta W_2):
 tensor([[2., 0., 2.],
        [0., 1., 0.],
        [2., 0., 2.]]

In [6]:
import torch

def fedavg_svd_aggregation_high_rank(client_data, global_dims, client_weights=None):
    """
    Simulate heterogeneous-rank (R=1, 2, 3) aggregation with SVD.

    Each client's factors are multiplied into a full update ``A @ B``, the
    updates are combined as a weighted sum, and the sum is re-factorised with
    a reduced SVD into ``A_global`` / ``B_global`` (each factor carries the
    square root of the singular values). Every client is then shown the
    leading ``rank`` columns of ``A_global`` and rows of ``B_global``.
    All intermediate results are printed; nothing is returned.

    Args:
        client_data: Non-empty sequence of ``(rank, A, B)`` tuples. ``A @ B``
            is accumulated into a matrix of shape ``global_dims``; the first
            client's ``A`` fixes the dtype and device of the accumulator.
        global_dims: ``(rows, cols)`` of the aggregated matrix. Need not be
            square.
        client_weights: Optional per-client scalar weights, one per entry of
            ``client_data``. Defaults to equal weights ``1 / num_clients``.

    Raises:
        ValueError: If ``client_data`` is empty, or ``client_weights`` does
            not have exactly one entry per client.
    """
    # Explicit raises (not `assert`) so validation survives `python -O`, and
    # performed up front so bad input fails before any output is produced.
    num_clients = len(client_data)
    if num_clients == 0:
        raise ValueError("client_data must contain at least one client.")
    if client_weights is None:
        client_weights = [1.0 / num_clients] * num_clients
    elif len(client_weights) != num_clients:
        raise ValueError("Length of client_weights must match number of clients.")

    print("✓ FedAvg SVD High-Rank Simulation defined successfully!\n")
    print("=== ORIGINAL CLIENT STATES (True Heterogeneous Rank) ===")

    # --- 1. Client contributions (product first) ---
    dtype = client_data[0][1].dtype
    device = client_data[0][1].device
    aggregated_delta_w = torch.zeros(global_dims, dtype=dtype, device=device)

    for i, ((rank, A, B), weight) in enumerate(zip(client_data, client_weights)):
        W = A @ B
        aggregated_delta_w += float(weight) * W

        # Print client info
        print(f"\nClient {i} (rank={rank}):")
        print(f"A shape: {A.shape} | A:\n {A}")
        print(f"B shape: {B.shape} | B:\n {B}")
        print(f"A @ B (full weight matrix / Delta W_{i}):\n {W}")

    # --- 2. Server aggregate target (sum of products result) ---
    print("=== AGGREGATED TARGET MATRIX ===")
    print(f"Delta W_agg:\n {aggregated_delta_w}")

    # --- 3. SVD decomposition ---
    # Reduced SVD: U is (rows, k) and Vh is (k, cols) with k = min(rows, cols),
    # so the diag(sqrt(S)) products below are shape-compatible for non-square
    # targets too. For square targets the factor shapes are unchanged.
    U, S, Vh = torch.linalg.svd(aggregated_delta_w, full_matrices=False)

    # Diagnostics: singular values, effective rank, reconstruction error
    print("=== SVD DIAGNOSTICS ===")
    print(f"Singular values S:\n {S}")
    # Singular values at or below max(S) * max(dims) * eps count as zero.
    tol = S.max() * max(global_dims) * torch.finfo(aggregated_delta_w.dtype).eps
    eff_rank = int((S > tol).sum().item())
    print(f"Effective numerical rank (tol={tol:.3e}): {eff_rank}")

    # Factor assignment: split sqrt of singular values to A and B
    S_sqrt = torch.sqrt(S)
    A_global = U @ torch.diag(S_sqrt)
    B_global = torch.diag(S_sqrt) @ Vh

    # Sign alignment: negating both factors leaves A_global @ B_global
    # unchanged while making A_global[0, 0] non-negative.
    if A_global[0, 0] < 0:
        A_global = -A_global
        B_global = -B_global

    # Reconstruction check
    global_reconstruction = A_global @ B_global
    rec_err = torch.linalg.norm(aggregated_delta_w - global_reconstruction, ord='fro').item()
    print(f"Reconstruction Frobenius error: {rec_err:.3e}")

    # --- 4. Print Global State ---
    print("=== GLOBAL STATE (SVD DECOMPOSED) ===")
    print(f"Global A shape: {A_global.shape}")
    print(f"Global A:\n {A_global}")
    print(f"Global B shape: {B_global.shape}")
    print(f"Global B:\n {B_global}")
    print(f"Global A @ B (reconstructed):\n {global_reconstruction}")

    # --- 5. Client-side slicing validation (rank-specific broadcast) ---
    print("=== LOCAL STATES (RANK-SPECIFIC SLICES) ===")

    for i, (rank, _, _) in enumerate(client_data):
        # Keep the first `rank` singular directions; a rank larger than the
        # number of available directions simply yields the full factors.
        A_sliced = A_global[:, :rank]
        B_sliced = B_global[:rank, :]

        W_reconstructed = A_sliced @ B_sliced

        print(f"\nClient {i} (rank={rank}):")
        print(f"A shape: {A_sliced.shape} | A (sliced):\n {A_sliced}")
        print(f"B shape: {B_sliced.shape} | B (sliced):\n {B_sliced}")
        print(f"A @ B (reconstructed with rank {rank} - Approximation of Delta W_agg):")
        print(W_reconstructed)


# --- Demo inputs: three clients whose factor pairs use inner ranks 1, 2 and 3 ---
DIM = (3, 3)

# Client 0 (R=1): an all-ones column times an all-ones row, a rank-1 product.
R0 = 1
A0 = torch.ones((DIM[0], R0))
B0 = torch.ones((R0, DIM[1]))

# Client 1 (R=2): the two columns of A1 and the two rows of B1 are linearly
# independent, so A1 @ B1 is a genuine rank-2 matrix.
R1 = 2
A1 = torch.tensor([
    [2., 1.],
    [2., 0.],
    [2., 1.],
])
B1 = torch.tensor([
    [2., 2., 2.],
    [1., 3., 1.],
])

# Client 2 (R=3): identity times an invertible matrix (det(B2) = 1), so the
# product is B2 itself and has full rank 3.
R2 = 3
A2 = torch.eye(3)
B2 = torch.tensor([
    [1., 2., 3.],
    [0., 1., 4.],
    [5., 6., 0.],
])

# One (rank, A, B) tuple per client, in client order.
client_data_list = list(zip((R0, R1, R2), (A0, A1, A2), (B0, B1, B2)))

# --- Run the simulation with compact, fixed-point tensor printing ---
torch.set_printoptions(precision=4, sci_mode=False)
fedavg_svd_aggregation_high_rank(client_data_list, DIM)


✓ FedAvg SVD High-Rank Simulation defined successfully!

=== ORIGINAL CLIENT STATES (True Heterogeneous Rank) ===

Client 0 (rank=1):
A shape: torch.Size([3, 1]) | A:
 tensor([[1.],
        [1.],
        [1.]])
B shape: torch.Size([1, 3]) | B:
 tensor([[1., 1., 1.]])
A @ B (full weight matrix / Delta W_0):
 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

Client 1 (rank=2):
A shape: torch.Size([3, 2]) | A:
 tensor([[2., 1.],
        [2., 0.],
        [2., 1.]])
B shape: torch.Size([2, 3]) | B:
 tensor([[2., 2., 2.],
        [1., 3., 1.]])
A @ B (full weight matrix / Delta W_1):
 tensor([[5., 7., 5.],
        [4., 4., 4.],
        [5., 7., 5.]])

Client 2 (rank=3):
A shape: torch.Size([3, 3]) | A:
 tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
B shape: torch.Size([3, 3]) | B:
 tensor([[1., 2., 3.],
        [0., 1., 4.],
        [5., 6., 0.]])
A @ B (full weight matrix / Delta W_2):
 tensor([[1., 2., 3.],
        [0., 1., 4.],
        [5., 6., 0.]]