In [15]:
import torch
import math
import numpy as np

In [16]:
def aggregate(W, R, C):
    """Fold per-client closed-form least-squares models into one model.

    Clients are merged one at a time. With a running aggregate (weights ``Wt``,
    Gram matrix ``Ct``, pseudo-inverse ``Rt``) and client ``k``, each step is

        S  = Ct + C[k]
        Wt <- (I - Rt C[k] + Rt C[k] S^-1 C[k]) Wt
            + (I - R[k] Ct + R[k] Ct S^-1 Ct) W[k]
        Ct <- Ct + C[k]

    When the Gram matrices involved are invertible, each step equals
    ``S^-1 (Ct Wt + C[k] W[k])``. That is, the result matches a single model
    fitted on the union of all clients' data.

    Args:
        W: sequence of per-client weight matrices, each (d, k).
        R: sequence of per-client (pseudo-)inverses of the matching ``C[i]``, (d, d).
        C: sequence of per-client Gram matrices (optionally with a ridge term), (d, d).

    Returns:
        ``(Wt, Ct)``: the aggregated weights and the summed Gram matrix.

    Raises:
        ValueError: if ``W`` is empty.
    """
    if len(W) == 0:
        raise ValueError("aggregate() needs at least one client")
    if len(W) == 1:
        # Nothing to merge. Still return a (weights, Gram) pair so callers
        # can always unpack two values.
        return W[0], C[0]

    dim = C[0].shape[0]
    # Follow the inputs' dtype/device instead of forcing float64 on CPU.
    eye = torch.eye(dim, dtype=C[0].dtype, device=C[0].device)

    Wt, Ct, Rt = W[0], C[0], R[0]
    last = len(W) - 1
    for k in range(1, len(W)):
        Ck, Rk, Wk = C[k], R[k], W[k]
        inv_sum = torch.inverse(Ct + Ck)  # S^-1, shared by both terms
        RtCk = Rt @ Ck
        RkCt = Rk @ Ct
        Wt = (eye - RtCk + RtCk @ inv_sum @ Ck) @ Wt + (eye - RkCt + RkCt @ inv_sum @ Ct) @ Wk
        Ct = Ct + Ck
        if k < last:
            # Only needed for the next merge step.
            Rt = torch.pinverse(Ct)
    return Wt, Ct

In [17]:
def RI(W, C, nc, rg):
    """Remove the accumulated ridge term from aggregated regularised weights.

    If each of ``nc`` clients added ``rg * I`` to its Gram matrix, the
    aggregated matrix is ``C = G + nc*rg*I``, where ``G`` is the plain Gram
    matrix. Given ridge weights ``W = C^-1 b``, the plain least-squares weights
    are ``G^-1 b = G^-1 (G + nc*rg*I) W = W + nc*rg * G^-1 W``. The identity
    is exact when ``G`` is invertible; a pseudo-inverse is used for ``G^-1``.

    Args:
        W: aggregated ridge weights, (d, k).
        C: aggregated regularised Gram matrix, (d, d).
        nc: number of clients that each contributed ``rg * I``.
        rg: per-client ridge coefficient.

    Returns:
        The de-regularised weights, same shape as ``W``.
    """
    shift = nc * rg
    eye = torch.eye(C.shape[0], dtype=C.dtype, device=C.device)
    R_origin = torch.pinverse(C - shift * eye)
    return W + (shift * R_origin) @ W

In [18]:
def generate_data(num_samples=10000, num_features=512, num_classes=10, generator=None):
    """Draw a synthetic classification dataset.

    Features are standard-normal samples. Labels are drawn uniformly from
    ``range(num_classes)`` and one-hot encoded.

    Args:
        num_samples: number of rows.
        num_features: number of feature columns.
        num_classes: number of label classes (width of ``Z``).
        generator: optional ``torch.Generator`` for reproducible draws; the
            global RNG is used when ``None``.

    Returns:
        X: (num_samples, num_features) float64 features.
        Z: (num_samples, num_classes) float64 one-hot labels.
    """
    X = torch.randn(num_samples, num_features, generator=generator).double()
    labels = torch.randint(low=0, high=num_classes, size=(num_samples,), generator=generator)
    # Use the requested class count rather than labels.max() + 1, so Z always
    # has num_classes columns even if the largest label is never drawn.
    Z = torch.eye(num_classes)[labels].double()
    return X, Z

In [19]:
def partition_data(X, Z, num_client):
    """Split ``(X, Z)`` row-wise into ``num_client`` contiguous shards.

    Every row goes to exactly one client, so the shards together cover the
    same data as the centralised reference. If the row count is not a multiple
    of ``num_client``, the first ``len(X) % num_client`` shards get one extra
    row. For evenly divisible inputs the shards are equal slices of size
    ``len(X) // num_client``.

    Args:
        X: feature matrix, rows are samples.
        Z: label/target matrix with the same number of rows as ``X``.
        num_client: number of shards (>= 1).

    Returns:
        ``(data_x, data_z)``: two lists of ``num_client`` tensors each.

    Raises:
        ValueError: if ``num_client < 1`` or ``X`` and ``Z`` differ in row count.
    """
    if num_client < 1:
        raise ValueError("num_client must be at least 1, got {}".format(num_client))
    if X.shape[0] != Z.shape[0]:
        raise ValueError("X and Z must have the same number of rows")
    data_x = list(torch.tensor_split(X, num_client))
    data_z = list(torch.tensor_split(Z, num_client))
    return data_x, data_z

In [20]:
def training(data_x, data_z, num_client, rg):
    """Fit closed-form least-squares models, with and without ridge, per client.

    For each client ``i`` with data ``x = data_x[i]`` and targets
    ``z = data_z[i]``:
      - ``C[i]   = x^T x``                  (Gram matrix)
      - ``R[i]   = pinverse(C[i])``
      - ``CRg[i] = C[i] + rg * I``           (ridge-regularised Gram matrix)
      - ``RRg[i] = pinverse(CRg[i])``
      - ``W[i]   = pinverse(x) @ z``         (minimum-norm least squares)
      - ``WRg[i] = RRg[i] @ x^T @ z``        (ridge solution)

    Args:
        data_x: per-client feature matrices.
        data_z: per-client target matrices.
        num_client: number of leading clients to train.
        rg: ridge coefficient.

    Returns:
        ``(C, CRg, R, RRg, W, WRg)``: six lists of length ``num_client``.
    """
    C, CRg, R, RRg = [], [], [], []
    W, WRg = [], []
    for i in range(num_client):
        x, z = data_x[i], data_z[i]
        gram = x.T @ x
        # Identity sized and typed from the data rather than hard-coded 512/float64.
        eye = torch.eye(gram.shape[0], dtype=gram.dtype, device=gram.device)
        gram_rg = gram + rg * eye
        ridge_inv = torch.pinverse(gram_rg)

        C.append(gram)
        R.append(torch.pinverse(gram))
        CRg.append(gram_rg)
        RRg.append(ridge_inv)
        W.append(torch.pinverse(x) @ z)
        WRg.append(ridge_inv @ x.T @ z)

    return C, CRg, R, RRg, W, WRg

In [21]:
def diff(W_agg, W_total):
    """Entry-wise L1 distance between two weight matrices.

    Returns a 0-d tensor, detached from autograd, equal to
    ``sum(|W_total - W_agg|)``.
    """
    abs_gap = (W_total - W_agg).abs()
    total = abs_gap.sum()
    return total.data

In [23]:
# Experiment: compare federated aggregation with the centralised least-squares
# solution as the number of clients grows.
#   diff1 - plain (unregularised) aggregation
#   diff2 - ridge-regularised aggregation, then the ridge term removed via RI
SEED = 0
runs = 5
rg = 1
num_clients = [2, 10, 20, 50, 100, 200]

torch.manual_seed(SEED)  # make the synthetic data, and therefore all outputs, reproducible

diffs_1 = []
diffs_2 = []
for t in range(runs):
    print("Run #{}".format(t))
    X, Z = generate_data()

    # Centralised reference: least-squares weights fitted on all of X at once.
    W_total = torch.pinverse(X) @ Z

    diff_per_run_1 = []
    diff_per_run_2 = []
    for num_client in num_clients:
        print("Client amount:{}".format(num_client))
        data_x, data_z = partition_data(X, Z, num_client)
        C, CRg, R, RRg, W, WRg = training(data_x, data_z, num_client, rg)
        W_agg, _ = aggregate(W, R, C)
        W_aggRg, C_aggRg = aggregate(WRg, RRg, CRg)
        # Each client added rg*I, so the aggregate carries num_client*rg*I; strip it.
        W_aggRg_flip = RI(W_aggRg, C_aggRg, num_client, rg)
        diff1 = float(diff(W_agg, W_total))
        diff2 = float(diff(W_aggRg_flip, W_total))
        print("Difference between aggregation of weights without regularization with {} clients in total:{}".format(num_client, diff1))
        print("Difference between aggregation of weights via flipping regularization with {} clients in total:{}".format(num_client, diff2))
        diff_per_run_1.append(diff1)
        diff_per_run_2.append(diff2)
    diffs_1.append(diff_per_run_1)
    diffs_2.append(diff_per_run_2)

# Shape (runs, len(num_clients)).
diffs_1 = np.array(diffs_1)
diffs_2 = np.array(diffs_2)

Run #0
Client amount:2
Difference between aggregation of weights with regularization with 2 clients in total:7.968852425263746e-14
Difference between aggregation of weights via flipping regularization with 2 clients in total:4.9031221918062104e-14
Client amount:10
Difference between aggregation of weights with regularization with 10 clients in total:1.7735499856653524e-12
Difference between aggregation of weights via flipping regularization with 10 clients in total:1.7689026911193637e-12
Client amount:20
Difference between aggregation of weights with regularization with 20 clients in total:0.966210706713434
Difference between aggregation of weights via flipping regularization with 20 clients in total:5.074599091205582e-10
Client amount:50
Difference between aggregation of weights with regularization with 50 clients in total:5.937192917708385
Difference between aggregation of weights via flipping regularization with 50 clients in total:8.552229799273888e-10
Client amount:100
Difference 

In [25]:
# Average each column (one per client count) across the independent runs.
mean1 = diffs_1.mean(axis=0)
mean2 = diffs_2.mean(axis=0)
for summary in (mean1, mean2):
    print(summary)

[7.82773497e-14 1.75617839e-12 9.76156643e-01 5.97395172e+00
 5.92788566e+04 3.67204871e+12]
[4.94720776e-14 1.73883758e-12 5.09290985e-10 8.44973762e-10
 7.57385149e-10 7.81168590e-10]
