In [1]:
import torch
import numpy as np

# --- 0. Setup ---
d, k = 10, 10
client_ranks = [3, 5, 8]
weights = [1/3, 1/3, 1/3]
client_states = []

for rank in client_ranks:
    # A: d x r
    A = torch.randn(d, rank)
    # B: r x k
    B = torch.randn(rank, k)
    client_states.append({'A': A, 'B': B})

# --- 1. Stacking Aggregation ---
# A matrices are stacked along the column (dim=1)
A_stacked = torch.cat([state['A'] for state in client_states], dim=1)

# B matrices are stacked along the row (dim=0)
B_stacked = torch.cat([state['B'] for state in client_states], dim=0)

# The global rank r_G is the sum of local ranks
r_global_stacked = A_stacked.shape[1]

# --- 2. Structural Verification (B_stacked * A_stacked) ---
delta_W_stacked = B_stacked @ A_stacked

# --- 3. SVD for Broadcasting (Simulated) ---
# SVD is performed on the product to find a low-rank representation for broadcasting
U, S, Vt = torch.linalg.svd(delta_W_stacked)

# We can choose a target rank r_target for broadcasting, e.g., the max client rank
r_target = max(client_ranks)

# Truncate U, S, Vt to the target rank
U_broadcast = U[:, :r_target]
S_broadcast = S[:r_target]
Vt_broadcast = Vt[:r_target, :]

# Reconstruct a compressed global model (B_global, A_global) for broadcasting
S_diag_sqrt = torch.diag(torch.sqrt(S_broadcast))
B_global_compressed = U_broadcast @ S_diag_sqrt
A_global_compressed = S_diag_sqrt @ Vt_broadcast


print("--- Task 1 Results (Stacking and SVD) ---")
print(f"Total Global Stacked Rank (r_G): {r_global_stacked}")
print(f"A_stacked shape: {A_stacked.shape}")
print(f"B_stacked shape: {B_stacked.shape}")
print(f"Delta_W_stacked shape: {delta_W_stacked.shape}")
print(f"Compressed B_global shape (r={r_target}): {B_global_compressed.shape}")
print(f"Compressed A_global shape (r={r_target}): {A_global_compressed.shape}")

--- Task 1 Results (Stacking and SVD) ---
Total Global Stacked Rank (r_G): 16
A_stacked shape: torch.Size([10, 16])
B_stacked shape: torch.Size([16, 10])
Delta_W_stacked shape: torch.Size([16, 16])
Compressed B_global shape (r=8): torch.Size([16, 8])
Compressed A_global shape (r=8): torch.Size([8, 16])


In [3]:
import torch

# --- 0. Setup ---
# d, k 保持 10x10，确保 Delta_W 维度一致
d, k = 10, 10
# 客户端的 LoRA 秩是异构的
client_ranks = [3, 5, 8]
weights = [1/3, 1/3, 1/3]
# 服务器设定的全局目标秩
r_target = 5

client_updates = []
for rank in client_ranks:
    # A: d x r (10 x r)
    A = torch.randn(d, rank)
    # B: r x k (r x 10)
    B = torch.randn(rank, k)

    # CORRECT: 计算 Delta_W 为 A @ B，结果是 10 x 10 矩阵
    delta_W = A @ B
    client_updates.append(delta_W)

# --- 1. Aggregation of Full Delta_W ---
# 对所有 10x10 的 Delta_W 矩阵进行 FedAvg 求和
delta_W_agg = sum(w * dw for w, dw in zip(weights, client_updates))

# --- 2. SVD Decomposition and Truncation ---
# 对聚合结果进行 SVD 分解
U, S, Vt = torch.linalg.svd(delta_W_agg)

# 截断 U, S, Vt 到目标秩 r_target
U_sliced = U[:, :r_target]
S_sliced = S[:r_target]
Vt_sliced = Vt[:r_target, :]

# --- 3. Reconstruction of A' and B' ---
# 创建奇异值矩阵的平方根对角阵
S_diag_sqrt = torch.diag(torch.sqrt(S_sliced))

# B' = U_r * sqrt(Sigma_r) -> 10 x r_target (10 x 5)
B_prime = U_sliced @ S_diag_sqrt

# A' = sqrt(Sigma_r) * Vt_r -> r_target x 10 (5 x 10)
A_prime = S_diag_sqrt @ Vt_sliced

# 验证重建质量
delta_W_reconstructed = B_prime @ A_prime
reconstruction_error = torch.linalg.norm(delta_W_agg - delta_W_reconstructed)


print("--- Task 2 Results (SVD Reconstruction) ---")
print(f"Target Rank (r_target): {r_target}")
print(f"B_prime shape: {B_prime.shape}")
print(f"A_prime shape: {A_prime.shape}")
print(f"Reconstruction Error (Norm): {reconstruction_error.item():.4e}")

--- Task 2 Results (SVD Reconstruction) ---
Target Rank (r_target): 5
B_prime shape: torch.Size([10, 5])
A_prime shape: torch.Size([5, 10])
Reconstruction Error (Norm): 3.7511e+00


In [6]:
import torch

# --- 0. Setup ---
r_G = 100 # Global Rank
r_L = 70  # Local Rank
d, k = 10, 10 # Base matrix dimensions (Delta_W must be d x k = 10x10)
weight_L = 0.5
weight_G = 0.5

# 初始化全局模型参数 (所有元素为 5.0)
A_global = torch.ones(d, r_G) * 5.0 # 10 x 100
B_global = torch.ones(r_G, k) * 5.0 # 100 x 10

# 本地更新参数 (所有元素为 1.0)
A_local = torch.ones(d, r_L) * 1.0 # 10 x 70
B_local = torch.ones(r_L, k) * 1.0 # 70 x 10


# --- 推理 1: 补零设置 (Zero-Padding/Setting) ---

# 1a. 将本地 A 和 B 补零到 r_G (100)
A_local_padded = torch.zeros(d, r_G)
A_local_padded[:, :r_L] = A_local
B_local_padded = torch.zeros(r_G, k)
B_local_padded[:r_L, :] = B_local

# 1b. 执行 FedAvg
A_global_1 = weight_G * A_global + weight_L * A_local_padded
B_global_1 = weight_G * B_global + weight_L * B_local_padded

# 检查更新区域 (r=1 to 70) 和稀释区域 (r=71 to 100) 的值
val_updated_1 = A_global_1[0, 0].item()  # 5 * 0.5 + 1 * 0.5 = 3.0
val_diluted_1 = A_global_1[0, r_L].item() # 5 * 0.5 + 0 * 0.5 = 2.5


# --- 推理 2: 截断/舍弃 (Truncation/Discarding) ---

# 2a. 计算 Delta_W (A @ B), 结果为 10x10
Delta_W_global = A_global @ B_global
Delta_W_local = A_local @ B_local

# 2b. 对 10x10 的 Delta_W 进行 FedAvg
Delta_W_global_2 = weight_G * Delta_W_global + weight_L * Delta_W_local

# 2c. 最终 SVD 重构到 r_G (r=100)
# 由于 Delta_W 只有 10x10，其最大秩 r_slice = 10
U, S, Vt = torch.linalg.svd(Delta_W_global_2)
r_slice = min(d, k)
U_final = U[:, :r_slice]
S_final = S[:r_slice]
Vt_final = Vt[:r_slice, :]

# 重构 B_prime 和 A_prime (秩为 10)
S_diag_sqrt = torch.diag(torch.sqrt(S_final))
B_reconstructed_r10 = U_final @ S_diag_sqrt # 10 x 10
A_reconstructed_r10 = S_diag_sqrt @ Vt_final # 10 x 10

# 填充到 r_G = 100 的结构尺寸，即用零填充 r=11 到 r=100 的空间
A_global_2 = torch.zeros(d, r_G)
A_global_2[:, :r_slice] = A_reconstructed_r10 # 10x100

B_global_2 = torch.zeros(r_G, k)
B_global_2[:r_slice, :] = B_reconstructed_r10 # 100x10


print("--- Task 3 Results (Mismatch Inference) ---")
print(f"Global Rank (r_G): {r_G}, Local Rank (r_L): {r_L}, Base Dims: {d}x{k}")

print("\n[Inference 1: Zero-Padding/Setting]")
print(f"Value in updated section (r=1 to 70): {val_updated_1:.2f}")
print(f"Value in DILUTED section (r=71 to 100): {val_diluted_1:.2f}")

print("\n[Inference 2: Truncation/Discarding]")
print(f"Delta_W_global_2 shape (Aggregated Delta W): {Delta_W_global_2.shape}")
print(f"Final A_global_2 shape (Reconstructed): {A_global_2.shape}")
print(f"Final B_global_2 shape (Reconstructed): {B_global_2.shape}")
print(f"Value in Truncated section of A_global_2 (r=71): {A_global_2[0, r_L].item():.2f}")
print("Note: In Inference 2, the local update fully utilizes the 10x10 Delta_W space, but the final reconstructed A/B matrices only have non-zero values up to rank 10 (r_slice), demonstrating the lost rank potential.")

--- Task 3 Results (Mismatch Inference) ---
Global Rank (r_G): 100, Local Rank (r_L): 70, Base Dims: 10x10

[Inference 1: Zero-Padding/Setting]
Value in updated section (r=1 to 70): 3.00
Value in DILUTED section (r=71 to 100): 2.50

[Inference 2: Truncation/Discarding]
Delta_W_global_2 shape (Aggregated Delta W): torch.Size([10, 10])
Final A_global_2 shape (Reconstructed): torch.Size([10, 100])
Final B_global_2 shape (Reconstructed): torch.Size([100, 10])
Value in Truncated section of A_global_2 (r=71): 0.00
Note: In Inference 2, the local update fully utilizes the 10x10 Delta_W space, but the final reconstructed A/B matrices only have non-zero values up to rank 10 (r_slice), demonstrating the lost rank potential.
