In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import copy
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import networkx as nx
from torch.nn.utils import parameters_to_vector, vector_to_parameters

In [2]:
# Graph implementation
def generate_graph(cluster_sizes=[100,100], pin=0.5, pout=0.01, seed=0):
    """Generate a random connected graph"""
    probs = np.array([[pin, pout],[pout, pin]])
    while True:
        g = nx.stochastic_block_model(cluster_sizes, probs)
        if nx.algorithms.components.is_connected(g):
            return g


cluster_sizes = [10, 10]
pin = 0.5
#pout = 0.01
pout = 0.01
seed = 0
alpha = 1e-2
lamda = 1e-3
eta = 1e-3
d0 = 9
no_users = sum(cluster_sizes)
batch_size = 50
epochs = 1
it = 2000
G = generate_graph(cluster_sizes, pin, pout, seed)

#nx.draw(G, with_labels=True, node_size=100, alpha=1, linewidths=10)
#plt.show()

In [3]:
# Metropolis weights 
number_nodes = G.number_of_nodes()
weights = np.zeros([number_nodes, number_nodes])
for edge in G.edges():
  i, j = edge[0], edge[1]
  weights[i - 1][j - 1] = 1 / (1 + np.max([G.degree(i), G.degree(j)]))
  weights[j - 1][i - 1] = weights[i - 1][j - 1]

print(weights)

weights = weights + np.diag(1 - np.sum(weights, axis=0))

metropolis_weights = weights
print(metropolis_weights)


[[0.         0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
  0.11111111 0.11111111 0.11111111 0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.        ]
 [0.11111111 0.         0.11111111 0.11111111 0.11111111 0.11111111
  0.11111111 0.11111111 0.11111111 0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.        ]
 [0.11111111 0.11111111 0.         0.14285714 0.         0.
  0.14285714 0.14285714 0.14285714 0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.        ]
 [0.11111111 0.11111111 0.14285714 0.         0.         0.14285714
  0.         0.14285714 0.14285714 0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.        ]
 [0.11111111 0.11111111 0.         0.         0.         0.
  0.         0.16666667 0.         0.         0.         0.
  0.         0.         0.         0

In [4]:
def load_dataset():
    transforms_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,))])
    mnist_data_train = datasets.MNIST('./data/mnist', train=True, download=True, transform=transforms_mnist)
    mnist_data_test = datasets.MNIST('./data/mnist', train=False, download=True, transform=transforms_mnist)

    return mnist_data_train, mnist_data_test

In [5]:
def degrees(A):
    """Return the degrees of each node of a graph from its adjacency matrix"""
    return np.sum(A, axis=0).reshape(A.shape[0], 1)

def node_degree(n, G):
    cnt = 0
    for i in G.neighbors(n):
        cnt += 1
    return cnt

def get_neighbors(n, G):
    neighbors_list = []
    for i in G.neighbors(n):
        neighbors_list.append(int(i))
    return neighbors_list

In [6]:
# Dataset partitioning
def random_split(X, y, n, seed):
    """Equally split data between n agents"""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(y.size)
    X_split = np.array_split(X[perm], n)  #np.stack to keep as a np array
    y_split = np.array_split(y[perm], n)
    return X_split, y_split





X_train = np.load('X_train.npy')
X_test = np.load('X_test.npy')
y_train = np.load('y_train.npy')
y_test = np.load('y_test.npy')


X, y = random_split(X_train, y_train, no_users, 1234)

In [7]:
X_train.shape

(14087, 9)

In [8]:
datapoints = {}
count = 0
W1 = np.array([2.0, 2.0, 3.0, 3.0])
W2 = np.array([-2.0, 2.0, 3.0, -3.0])
W3 = 2 * W1
W4 = 2  * W2
W = [W1, W2]
m = 200
n = 4

scaler = [1.0, -1.0]

noise_sd = 0.001
for i, cluster_size in enumerate(cluster_sizes):
    for j in range(cluster_size):
        features = np.random.normal(loc=0.0, scale=1.0, size=(m, n))
        label = np.dot(features, W[i ]) + np.random.normal(0,noise_sd)
        data = X[count]
        data[:, 0:4] *= scaler[i]
        datapoints[count] = {
                'features': data,
                'degree': node_degree(count, G),
                'label': y[count],
                'neighbors': get_neighbors(count, G),
                'exact_weights': torch.from_numpy(W[i])
            }
        count += 1

In [9]:
class MyDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = torch.FloatTensor(data)
        self.targets = torch.FloatTensor(targets).unsqueeze(-1)
        
    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]

        return x, y
    
    def __len__(self):
        return len(self.data)


In [10]:
class MLP_Net(nn.Module):
    def __init__(self, user_id):
        super(MLP_Net, self).__init__()
        self.fc1 = nn.Linear(9, 4, bias=False)
        self.fc2 = nn.Linear(4, 1, bias=False)
        #self.fc3 = nn.Linear(200, 10)
        self.user_id = user_id

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        output = self.fc2(x)
        #output = self.fc3(x)
        return output

In [11]:
from typing import Iterable, Optional

def grads_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Convert parameters to one vector

    Args:
        parameters (Iterable[Tensor]): an iterator of Tensors that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector
    """
    # Flag for the device where the parameter is located
    param_device = None

    vec = []
    for param in parameters:
        # Ensure the parameters are located in the same device
        param_device = param.grad

        vec.append(param_device.view(-1))
    return torch.cat(vec)

In [12]:
model = MLP_Net(user_id=0)

lr = 0.01

dataloader = DataLoader(MyDataset(datapoints[19]["features"], datapoints[19]["label"]), batch_size=100, shuffle=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for i in range(100):
    for (x, y) in dataloader:
        criterion = nn.MSELoss()
        optimizer.zero_grad()
        yhat = model(x)
        print(y.size())
        print(yhat.size())
        loss = criterion(yhat, y)
        
        loss.backward()
        print(i, loss, grads_to_vector(model.parameters()), parameters_to_vector(model.parameters()))
        #optimizer.step()
        new_model = parameters_to_vector(model.parameters()) - lr * grads_to_vector(model.parameters())
        vector_to_parameters(parameters=model.parameters(), vec=new_model)
        #if i % 50 ==0:
            #lr *= 0.9
            

#parameters_to_vector(model.parameters())

torch.Size([100, 1])
torch.Size([100, 1])
0 tensor(51.3979, grad_fn=<MseLossBackward0>) tensor([-1.1087e+00,  5.5799e-01,  6.3292e-01,  9.8480e-01, -7.4522e-01,
        -4.7028e-01, -8.7950e-01, -9.5351e-01, -1.3999e+00,  5.6491e-01,
        -3.7858e-01, -5.7189e-01,  9.8729e-01, -1.3380e+00, -1.5386e+00,
        -1.3571e+00,  1.6465e+00,  3.0235e-01,  1.6303e-03, -1.3636e-02,
         2.8218e-02, -3.2794e-02,  4.4032e-02,  5.2516e-02,  4.4513e-02,
        -5.7070e-02, -3.2253e-02, -1.6077e+00,  1.7316e+00,  1.1764e-01,
        -2.7916e-01, -1.0065e-02, -1.1904e-01, -2.4345e-02,  2.4428e+00,
        -2.2569e+00, -5.8705e+00, -2.4265e+00, -4.5961e+00, -2.8717e+00]) tensor([ 0.2174, -0.0693, -0.2605, -0.3262,  0.1340,  0.1778,  0.3215,  0.1843,
         0.2819, -0.1330, -0.3203,  0.1090,  0.2042, -0.2964, -0.1564,  0.0786,
         0.0686, -0.0298, -0.0785,  0.0474,  0.2879,  0.0576, -0.2817, -0.3173,
        -0.2821,  0.1581,  0.1648, -0.2803, -0.0643,  0.0141,  0.0529,  0.2593,
       

6 tensor(15.0719, grad_fn=<MseLossBackward0>) tensor([ 0.8612, -1.7823,  0.5173, -1.3180,  0.1214,  0.6647,  0.1340, -0.1695,
         2.0680,  0.0390, -0.0453,  0.1095, -0.0127,  0.0334,  0.0115,  0.0295,
        -0.0114, -0.0217,  0.2373, -0.4072, -1.0512, -0.1470, -0.2342, -0.4460,
        -0.0678, -0.4390,  0.5376, -0.1597,  0.1214,  0.1738,  0.2174, -0.0360,
        -0.1370, -0.0848, -0.0979, -0.0821,  0.9944, -0.1596,  0.1263, -0.0073]) tensor([ 0.5347,  0.0308, -0.7347, -0.8840,  0.5203,  0.4822,  0.8255,  1.3063,
         0.2657, -0.1994, -0.2505,  0.0652,  0.1026, -0.1739, -0.0466,  0.1979,
         0.0028, -0.0129, -0.2818,  0.4920, -0.0117,  0.4893, -0.8391, -1.0310,
        -0.8319,  0.8867,  0.4522, -0.1554, -0.1765, -0.0055,  0.0279,  0.2507,
         0.0350, -0.1503,  0.0424,  0.0174,  2.0244, -0.0572,  1.8978, -0.2528],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
6 tensor(16.2499, grad_fn=<MseLossBackward0>) tensor([ 1.1276, -1.6622,  2.5219

7 tensor(17.0610, grad_fn=<MseLossBackward0>) tensor([ 1.5735, -3.2579,  3.0709, -3.9829,  2.5402,  2.5587,  2.4717, -0.9180,
         1.3120,  0.0253, -0.0302,  0.0590, -0.0280,  0.0307,  0.0176,  0.0269,
        -0.0239, -0.0206,  0.2222, -0.6103, -0.8619,  1.7886, -1.9709, -1.1736,
        -1.7406, -0.0783,  0.0648, -0.2514,  0.3311, -0.0888,  0.4199, -0.2105,
        -0.4037, -0.3123, -0.0114, -0.0491,  2.1683, -0.0426,  2.3968,  0.1390]) tensor([ 0.5350,  0.0982, -0.8718, -0.8302,  0.4983,  0.4581,  0.8190,  1.3832,
         0.1171, -0.2029, -0.2460,  0.0594,  0.1046, -0.1769, -0.0480,  0.1948,
         0.0051, -0.0115, -0.2965,  0.5626, -0.0495,  0.4560, -0.7852, -1.0431,
        -0.7838,  0.9468,  0.4147, -0.1528, -0.1760, -0.0109,  0.0167,  0.2478,
         0.0389, -0.1501,  0.0546,  0.0233,  2.0743, -0.0460,  1.8933, -0.2495],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
7 tensor(14.0604, grad_fn=<MseLossBackward0>) tensor([-7.5766e-01, -9.2022e-02,

11 tensor(5.0243, grad_fn=<MseLossBackward0>) tensor([-5.4115e+00,  5.6999e+00, -3.6406e+00,  4.1359e+00, -3.0370e+00,
        -8.5895e+00, -3.1099e+00, -1.4816e+00,  2.4218e+00, -8.0875e-04,
         1.1410e-03, -1.4426e-03,  1.0606e-03, -7.5762e-04, -5.0858e-04,
        -7.5315e-04, -5.4906e-04,  1.0360e-04,  5.5203e-01, -7.5810e-01,
         5.2182e-01, -6.3961e-01,  5.9275e-01,  6.0563e-01,  5.4320e-01,
         5.0815e-01, -1.7075e+00,  5.8465e-01, -6.1581e-01,  3.9332e-01,
        -4.4683e-01,  3.2811e-01,  9.2800e-01,  3.3599e-01,  1.6006e-01,
        -2.6164e-01, -4.9369e+00, -4.0735e-02, -1.4863e+00, -3.2985e-01]) tensor([ 6.5227e-01,  1.9792e-01, -1.1683e+00, -6.4548e-01,  4.0341e-01,
         4.3415e-01,  7.9046e-01,  1.4943e+00, -1.9336e-01, -2.0967e-01,
        -2.3664e-01,  4.9758e-02,  1.0747e-01, -1.8048e-01, -4.9106e-02,
         1.9107e-01,  9.4716e-03, -9.7626e-03, -2.9199e-01,  7.3983e-01,
        -1.1422e-01,  4.1348e-01, -6.4825e-01, -1.1284e+00, -6.6916e-01,
    

13 tensor(13.7054, grad_fn=<MseLossBackward0>) tensor([ 1.2078, -1.7235,  1.0263, -1.0853, -0.1501,  1.3476, -0.2868,  1.0516,
         0.9044, -0.0124,  0.0150, -0.0124,  0.0088, -0.0148, -0.0075, -0.0148,
         0.0129,  0.0147, -0.1429, -0.6186, -0.6260,  0.6039, -1.6494,  0.0563,
        -1.3440,  0.4506,  2.3515, -0.1574,  0.1170,  0.0447,  0.0567,  0.1232,
        -0.1067,  0.0988, -0.1959, -0.0716,  0.8178, -0.0338,  1.7765,  0.0221]) tensor([ 0.7420,  0.1666, -1.1463, -0.6420,  0.4263,  0.5104,  0.8233,  1.5246,
        -0.2869, -0.2091, -0.2374,  0.0505,  0.1073, -0.1803, -0.0490,  0.1913,
         0.0092, -0.0099, -0.2698,  0.7694, -0.1080,  0.4259, -0.6201, -1.1545,
        -0.6522,  0.9048,  0.5414, -0.1584, -0.1576, -0.0458, -0.0068,  0.2314,
         0.0406, -0.1497,  0.0850,  0.0270,  2.2760,  0.0143,  1.9154, -0.2327],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
13 tensor(14.9442, grad_fn=<MseLossBackward0>) tensor([ 1.9890e+00, -3.5501e+0

16 tensor(11.4743, grad_fn=<MseLossBackward0>) tensor([ 0.6060, -0.4707,  2.0655, -2.4391,  1.9808,  2.4368,  1.9430,  1.3491,
        -0.1699, -0.0447,  0.0604, -0.0362,  0.0108, -0.0249, -0.0127, -0.0284,
         0.0471, -0.0111, -0.2067, -0.4728,  0.8343, -0.4054, -0.4205,  0.7556,
        -0.4073,  1.1991, -1.3957, -0.0824,  0.0081, -0.0687,  0.2033, -0.1562,
        -0.2036, -0.1703, -0.1511,  0.0576,  2.1264, -0.1299, -0.4642,  0.0939]) tensor([ 0.8181,  0.2491, -1.1989, -0.5324,  0.4004,  0.4702,  0.8335,  1.5598,
        -0.3761, -0.2010, -0.2485,  0.0615,  0.1049, -0.1774, -0.0486,  0.1942,
         0.0048, -0.0123, -0.2319,  0.8596, -0.0484,  0.4251, -0.5077, -1.2176,
        -0.5648,  0.8315,  0.5971, -0.1522, -0.1571, -0.0566, -0.0163,  0.2255,
         0.0460, -0.1467,  0.0977,  0.0213,  2.3172,  0.0557,  1.9042, -0.2267],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
16 tensor(13.5324, grad_fn=<MseLossBackward0>) tensor([ 3.3394e-01, -5.9296e-0

21 tensor(13.1608, grad_fn=<MseLossBackward0>) tensor([ 1.4218, -1.7952,  0.6903, -1.1694, -0.0998,  1.3664, -0.1816,  1.4003,
         0.3283, -0.1926,  0.2127, -0.0925,  0.0432, -0.0835,  0.0028, -0.0691,
         0.0703,  0.0439, -0.5148, -0.1626, -0.2978,  0.1975, -1.1954,  0.2979,
        -0.8534,  0.7387,  2.5624, -0.1122,  0.0715, -0.0138,  0.0650,  0.0431,
        -0.1316,  0.0149, -0.1693, -0.0480,  1.3216, -0.2719,  1.4096,  0.0578]) tensor([ 0.9535,  0.2752, -1.1686, -0.4724,  0.4036,  0.4947,  0.8809,  1.6044,
        -0.5035, -0.1536, -0.3060,  0.1041,  0.0992, -0.1719, -0.0404,  0.1996,
        -0.0109, -0.0164, -0.1492,  0.9527,  0.0181,  0.4633, -0.3974, -1.3329,
        -0.4859,  0.7360,  0.6689, -0.1467, -0.1533, -0.0668, -0.0239,  0.2199,
         0.0430, -0.1421,  0.1077,  0.0177,  2.3943,  0.1539,  1.9568, -0.2165],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
21 tensor(13.9567, grad_fn=<MseLossBackward0>) tensor([ 2.2771, -3.6381,  1.44

22 tensor(13.7629, grad_fn=<MseLossBackward0>) tensor([-1.8780,  1.1510,  1.0064, -0.1648,  0.4314,  1.2910,  0.0810, -0.9033,
        -2.1106, -0.3660,  0.4125, -0.2073, -0.0320,  0.0751,  0.1714,  0.0623,
        -0.0126, -0.1065, -0.3485,  0.2804,  0.3547, -0.5836,  1.0584,  1.2736,
         0.9479, -1.2478, -0.1510,  0.1783, -0.1868, -0.1217,  0.0432, -0.0668,
        -0.1617, -0.0553,  0.0703,  0.2295, -0.8419, -0.6420, -1.7594, -0.0375]) tensor([ 0.9350,  0.3452, -1.1883, -0.4230,  0.3942,  0.4641,  0.8814,  1.5967,
        -0.5131, -0.1283, -0.3343,  0.1182,  0.0995, -0.1730, -0.0409,  0.1978,
        -0.0170, -0.0142, -0.1251,  0.9766,  0.0414,  0.4464, -0.3407, -1.3452,
        -0.4398,  0.7044,  0.6492, -0.1424, -0.1562, -0.0672, -0.0276,  0.2188,
         0.0457, -0.1403,  0.1118,  0.0161,  2.3799,  0.1941,  1.9301, -0.2159],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
22 tensor(11.0367, grad_fn=<MseLossBackward0>) tensor([ 8.3784e-01, -7.4380e-0

        -4.3070e-02,  1.4133e+00, -8.1114e-01,  1.2652e+00,  7.1901e-02]) tensor([ 1.0387,  0.4039, -1.1765, -0.3530,  0.4180,  0.4712,  0.9205,  1.6396,
        -0.5405,  0.0692, -0.5504,  0.2100,  0.1230, -0.2096, -0.0636,  0.1585,
        -0.0582,  0.0309, -0.0434,  1.0587,  0.0705,  0.4699, -0.2493, -1.4621,
        -0.3707,  0.6371,  0.6528, -0.1396, -0.1511, -0.0757, -0.0299,  0.2138,
         0.0413, -0.1378,  0.1163,  0.0086,  2.4432,  0.5069,  2.0047, -0.2054],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
27 tensor(12.8244, grad_fn=<MseLossBackward0>) tensor([ 2.4159e+00, -3.6979e+00,  1.1540e+00, -1.7620e+00,  6.5230e-01,
         9.9853e-01,  6.8244e-01,  3.7440e-02, -8.7944e-01, -3.7359e-01,
         2.9087e-01, -9.5399e-02, -5.8522e-01,  5.3560e-01,  5.7046e-01,
         6.3031e-01,  2.9413e-01, -2.3713e-01, -1.1490e-03, -4.3246e-01,
        -8.0653e-01,  1.9369e+00, -1.9720e+00, -1.6275e+00, -1.9218e+00,
         5.2614e-01, -4.5761e-01, -1.618

29 tensor(11.4246, grad_fn=<MseLossBackward0>) tensor([ 1.0848, -1.4387,  0.5600, -1.3512,  0.1778,  1.2832,  0.0981,  1.4231,
         0.0982, -0.4835,  0.5524, -0.5262, -0.1297,  0.0399,  0.2933,  0.1336,
         0.1784, -0.0228, -0.3006, -0.3109, -0.2994,  0.4654, -1.2277, -0.1010,
        -1.0034,  0.5461,  2.5471, -0.0836,  0.0509,  0.0263,  0.0383,  0.0367,
        -0.0899,  0.0203, -0.1420, -0.0365,  1.3389, -0.9411,  1.2465,  0.0750]) tensor([ 1.0484,  0.4624, -1.1753, -0.3119,  0.4255,  0.4586,  0.9301,  1.6431,
        -0.5406,  0.1995, -0.6915,  0.2728,  0.1469, -0.2407, -0.0873,  0.1238,
        -0.0889,  0.0664, -0.0305,  1.1134,  0.0722,  0.4594, -0.1948, -1.4853,
        -0.3236,  0.6045,  0.6337, -0.1371, -0.1511, -0.0793, -0.0312,  0.2122,
         0.0414, -0.1363,  0.1183,  0.0059,  2.4468,  0.7195,  2.0164, -0.2029],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
29 tensor(12.0058, grad_fn=<MseLossBackward0>) tensor([ 2.3891, -3.6225,  1.08

35 tensor(8.2827, grad_fn=<MseLossBackward0>) tensor([ 2.4018, -3.3916,  1.1143, -1.4124,  0.3020,  0.7148,  0.3542,  0.4362,
         0.0109, -0.1035, -0.4585,  0.3174, -1.3328,  1.2731,  1.0560,  1.3094,
        -0.0510, -0.4785,  0.3693, -0.7538, -0.1224,  1.7261, -1.9637, -1.1912,
        -1.9494,  1.0255,  0.2321, -0.1742,  0.2502, -0.0504,  0.0972, -0.0358,
        -0.0988, -0.0561,  0.0127,  0.0358,  0.1042, -0.3081,  0.9016,  0.0944]) tensor([ 9.9820e-01,  7.3142e-01, -1.1793e+00, -1.5704e-01,  4.4534e-01,
         3.7300e-01,  9.5022e-01,  1.5925e+00, -4.9773e-01,  6.2696e-01,
        -1.0903e+00,  4.8283e-01,  2.1775e-01, -3.5381e-01, -1.6930e-01,
         7.9006e-03, -1.1312e-01,  2.2340e-01, -1.2108e-01,  1.4146e+00,
         2.5736e-02,  3.3561e-01,  6.8762e-02, -1.4614e+00, -7.3171e-02,
         5.1759e-01,  4.7634e-01, -1.2704e-01, -1.6034e-01, -9.2338e-02,
        -3.4914e-02,  2.0775e-01,  4.3608e-02, -1.3441e-01,  1.2386e-01,
        -2.2093e-03,  2.4117e+00,  1.3695e

        -0.0407, -0.1050,  0.0307, -0.0830,  0.6516, -0.3784, -0.8719,  0.0668]) tensor([ 1.0079,  0.7830, -1.1796, -0.1276,  0.4603,  0.3635,  0.9645,  1.5662,
        -0.4753,  0.7196, -1.1337,  0.5007,  0.2320, -0.3886, -0.1871, -0.0211,
        -0.0723,  0.2919, -0.1506,  1.4981, -0.0031,  0.3035,  0.1339, -1.4513,
        -0.0042,  0.4835,  0.4325, -0.1262, -0.1639, -0.0945, -0.0353,  0.2060,
         0.0432, -0.1353,  0.1249, -0.0065,  2.4115,  1.4749,  2.0874, -0.2080],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
37 tensor(8.0548, grad_fn=<MseLossBackward0>) tensor([ 0.0693, -0.5507, -0.0557, -0.7540,  0.3397,  0.5323,  0.1897,  0.8976,
        -0.3400,  0.0442, -0.1844, -0.1454,  0.0737, -0.0434, -0.0373, -0.1503,
        -0.0629,  0.3179, -0.1241, -0.3024, -0.6043,  0.7411, -1.2112, -0.4248,
        -1.0757,  0.6083,  2.0222, -0.0683,  0.0521,  0.0814, -0.0227,  0.0438,
        -0.0117,  0.0543, -0.1041,  0.0078,  0.7857,  0.2076,  0.6538,  0.0900]

43 tensor(8.4528, grad_fn=<MseLossBackward0>) tensor([-2.2321e-01, -3.5567e-01, -4.6006e-01, -5.4980e-01,  6.0859e-01,
         5.0157e-01,  5.4126e-01, -7.3556e-01, -7.4256e-01, -4.4141e-01,
        -2.1587e-01,  6.3843e-02,  3.2694e-01,  3.4063e-02,  5.1433e-01,
         2.8562e-02, -1.5787e+00, -5.6617e-01,  3.2959e-01, -4.7749e-01,
        -9.8803e-01, -1.1296e+00,  4.8898e-01, -5.2832e-02,  4.8142e-01,
        -4.8298e-01,  1.0995e+00,  6.3176e-02, -4.1269e-02, -1.0928e-03,
         4.6985e-02, -4.9483e-02, -8.5734e-02, -5.3108e-02,  2.0186e-02,
         1.2080e-01,  3.3577e-02, -4.3933e-01, -4.3152e-01,  3.7563e-02]) tensor([ 1.0782,  0.9233, -1.0801,  0.0394,  0.4574,  0.3375,  0.9543,  1.4500,
        -0.3697,  0.8877, -1.0919,  0.5249,  0.2292, -0.4592, -0.2454, -0.0665,
         0.1660,  0.4464, -0.2565,  1.7807,  0.0241,  0.2484,  0.3252, -1.3440,
         0.1896,  0.3691,  0.1359, -0.1205, -0.1800, -0.1098, -0.0445,  0.2074,
         0.0457, -0.1312,  0.1289, -0.0148,  2.33

45 tensor(5.2342, grad_fn=<MseLossBackward0>) tensor([ 3.8542e-01, -4.7105e-01,  6.2598e-01, -1.8044e+00,  1.2491e+00,
         1.1746e+00,  1.3055e+00,  1.4095e+00,  3.6269e-01,  1.7108e-01,
        -2.7443e-01,  5.8161e-01, -9.3778e-01,  9.0093e-01,  1.0335e+00,
         9.3365e-01,  3.7385e-01, -7.2166e-01,  2.3741e-01, -5.5303e-01,
         6.4602e-01, -7.2195e-01,  5.0693e-01,  1.6186e-01,  4.1929e-01,
         5.7235e-01,  8.8899e-01, -7.2580e-03,  1.5973e-02, -5.9968e-02,
         1.5662e-01, -1.3056e-01, -1.2995e-01, -1.3263e-01, -1.1081e-01,
        -8.6132e-04,  1.4001e+00, -2.7894e-01, -4.3771e-01,  1.4419e-01]) tensor([ 1.1179,  0.9471, -1.0384,  0.0886,  0.4381,  0.3401,  0.9338,  1.4312,
        -0.3332,  0.9247, -1.0578,  0.5272,  0.2127, -0.4679, -0.2647, -0.0667,
         0.2513,  0.4802, -0.2922,  1.8689,  0.0291,  0.2744,  0.3553, -1.2651,
         0.2266,  0.3162,  0.0665, -0.1220, -0.1816, -0.1137, -0.0482,  0.2100,
         0.0471, -0.1280,  0.1296, -0.0172,  2.31

        -3.8964e-02, -8.7238e-01,  7.6613e-02, -1.2186e-01, -8.0207e-04]) tensor([ 1.2175e+00,  9.6292e-01, -9.5190e-01,  1.5707e-01,  4.1289e-01,
         3.4942e-01,  9.0137e-01,  1.3823e+00, -2.9102e-01,  9.8501e-01,
        -9.9157e-01,  5.3447e-01,  1.6711e-01, -4.5521e-01, -2.8725e-01,
        -3.6118e-02,  3.4520e-01,  4.9788e-01, -3.0149e-01,  1.9997e+00,
        -1.1781e-02,  3.2994e-01,  3.6191e-01, -1.1150e+00,  2.5365e-01,
         2.4153e-01,  1.7695e-03, -1.2482e-01, -1.8379e-01, -1.2307e-01,
        -5.1569e-02,  2.1133e-01,  4.6168e-02, -1.2486e-01,  1.3326e-01,
        -1.8044e-02,  2.2783e+00,  1.6382e+00,  2.2441e+00, -2.4272e-01],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
50 tensor(7.5262, grad_fn=<MseLossBackward0>) tensor([-0.0657, -0.3748, -0.3995, -0.6706,  0.7846,  0.5678,  0.7341, -0.5762,
        -0.6295, -0.3817, -0.1665,  0.1871,  0.2244,  0.0617,  0.4396,  0.0154,
        -1.0951, -0.1804, -0.0307, -0.1601, -0.8822, -1.2134, 

51 tensor(4.9001, grad_fn=<MseLossBackward0>) tensor([ 0.2666, -0.2384,  0.2912, -1.5510,  1.1881,  1.4991,  1.2625,  0.9873,
         0.3642,  0.2747, -0.3096,  0.5862, -0.9106,  0.8708,  1.0772,  0.9041,
         0.4219, -0.5303,  0.0463, -0.3382,  0.4813, -0.6555,  0.6155,  0.4630,
         0.5236,  0.0155,  0.6469, -0.0205,  0.0143, -0.0526,  0.1519, -0.1346,
        -0.1402, -0.1365, -0.1088, -0.0345,  1.2865, -0.0864, -0.4736,  0.1365]) tensor([ 1.2234,  0.9889, -0.9477,  0.1757,  0.4073,  0.3451,  0.8938,  1.3810,
        -0.2695,  1.0000, -0.9720,  0.5260,  0.1660, -0.4603, -0.3053, -0.0376,
         0.3654,  0.5000, -0.3011,  2.0287, -0.0129,  0.3342,  0.3701, -1.0679,
         0.2693,  0.2589, -0.0290, -0.1232, -0.1870, -0.1240, -0.0524,  0.2114,
         0.0465, -0.1246,  0.1332, -0.0186,  2.2829,  1.6418,  2.2503, -0.2448],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
51 tensor(5.6163, grad_fn=<MseLossBackward0>) tensor([-0.2686,  0.1410, -0.5381

53 tensor(1.3791, grad_fn=<MseLossBackward0>) tensor([-1.9716e+00,  2.0183e+00, -6.5485e-01, -1.2583e+00,  1.6418e+00,
         5.2351e-02,  1.7660e+00, -2.7264e-03, -1.1136e+00, -5.3385e-01,
         5.7021e-01, -3.4017e-01,  1.8245e-01, -6.9254e-02, -5.2969e-01,
        -6.2361e-02, -1.1329e-01,  4.9948e-02,  2.1492e-01, -7.5867e-01,
         1.1849e+00, -3.1901e+00,  3.2939e+00,  2.3767e+00,  3.2788e+00,
         1.4864e+00, -5.6584e+00,  2.1520e-01, -2.2030e-01,  7.1479e-02,
         1.3735e-01, -1.7921e-01, -5.7143e-03, -1.9276e-01,  2.9760e-04,
         1.2155e-01,  1.0578e+00, -6.3464e-01, -1.0268e+00,  6.5423e-02]) tensor([ 1.2591,  0.9954, -0.9216,  0.1920,  0.4026,  0.3259,  0.8840,  1.3615,
        -0.2512,  1.0142, -0.9311,  0.5325,  0.1499, -0.4600, -0.3286, -0.0297,
         0.3868,  0.4980, -0.3009,  2.0823, -0.0189,  0.3486,  0.3677, -1.0184,
         0.2774,  0.2717, -0.0597, -0.1212, -0.1904, -0.1278, -0.0516,  0.2103,
         0.0459, -0.1247,  0.1356, -0.0165,  2.27

55 tensor(5.6424, grad_fn=<MseLossBackward0>) tensor([ 0.4701, -0.6939, -0.5531,  0.6728, -1.5895, -0.4273, -1.2919,  0.3294,
        -0.0580,  0.0513,  0.1168, -0.9760,  0.5437, -0.7285, -0.0933, -0.5103,
         0.0637,  0.5025,  0.4545, -0.9596,  0.0386,  0.4525, -0.8560, -0.8381,
        -0.6359, -0.2286,  0.0146, -0.0948,  0.0753,  0.0922, -0.0492,  0.1023,
         0.0753,  0.0515,  0.0177, -0.1051, -0.3487,  0.0912, -0.7343, -0.0416]) tensor([ 1.3167,  0.9682, -0.9156,  0.2139,  0.3787,  0.3101,  0.8595,  1.3505,
        -0.2152,  1.0301, -0.9224,  0.5360,  0.1431, -0.4577, -0.3383, -0.0266,
         0.3982,  0.4895, -0.2796,  2.0979, -0.0381,  0.3666,  0.3489, -0.9977,
         0.2662,  0.2586, -0.0221, -0.1228, -0.1890, -0.1310, -0.0513,  0.2098,
         0.0448, -0.1244,  0.1375, -0.0158,  2.2651,  1.6410,  2.2703, -0.2494],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
55 tensor(6.2961, grad_fn=<MseLossBackward0>) tensor([ 9.5241e-03, -4.6362e-01,

60 tensor(5.4286, grad_fn=<MseLossBackward0>) tensor([ 0.6348, -0.9188, -0.4698,  0.5582, -1.4626, -0.3508, -1.2172,  0.3941,
         0.0851,  0.0400,  0.1834, -0.9095,  0.4978, -0.6784, -0.0530, -0.4501,
         0.1839,  0.5013,  0.2545, -0.7398,  0.0280,  0.4237, -0.7536, -0.6681,
        -0.5223, -0.3008,  0.0690, -0.0782,  0.0580,  0.0438, -0.0268,  0.1007,
         0.0571,  0.0417, -0.0148, -0.0535, -0.2436,  0.0058, -0.6443, -0.0339]) tensor([ 1.4160e+00,  9.8069e-01, -8.9648e-01,  2.5894e-01,  3.5220e-01,
         2.5841e-01,  8.3065e-01,  1.3022e+00, -1.4013e-01,  1.0638e+00,
        -8.6657e-01,  5.8326e-01,  9.8040e-02, -4.3947e-01, -3.6928e-01,
        -1.7104e-03,  4.1279e-01,  4.2517e-01, -2.5892e-01,  2.2100e+00,
        -5.9655e-02,  3.7096e-01,  3.4597e-01, -8.4706e-01,  2.9651e-01,
         2.9254e-01, -5.4994e-02, -1.1616e-01, -1.9681e-01, -1.4206e-01,
        -4.6291e-02,  2.0548e-01,  3.9165e-02, -1.2639e-01,  1.3987e-01,
        -6.4808e-03,  2.2661e+00,  1.6296e

         9.7225e-02,  5.5940e-01, -6.9098e-01, -1.2555e+00,  1.7201e-02]) tensor([ 1.4212,  1.0158, -0.8934,  0.2667,  0.3518,  0.2392,  0.8360,  1.2881,
        -0.1229,  1.0678, -0.8418,  0.6027,  0.0834, -0.4349, -0.3826,  0.0063,
         0.4106,  0.4029, -0.2578,  2.2494, -0.0550,  0.3618,  0.3564, -0.7943,
         0.3176,  0.3274, -0.1175, -0.1123, -0.2019, -0.1426, -0.0444,  0.2035,
         0.0385, -0.1277,  0.1407, -0.0046,  2.2728,  1.6203,  2.3330, -0.2554],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
62 tensor(5.6249, grad_fn=<MseLossBackward0>) tensor([-1.2186,  0.6582, -0.2445,  0.2099, -0.1928,  0.1916, -0.3019,  0.4144,
        -0.3910, -0.2858, -0.0034,  0.3376, -0.0329, -0.0915,  0.3798,  0.0232,
        -0.0047,  0.3230, -2.2561,  1.7265,  0.0571,  1.7992, -1.5090, -1.7995,
        -1.6687,  0.2069,  0.9604,  0.0735, -0.0944,  0.0287, -0.0161,  0.0555,
        -0.0130,  0.0349, -0.0783, -0.1007, -0.2255, -0.0474,  2.4418, -0.0122]) tenso

        -0.1651, -0.1406, -0.0952, -0.0712,  0.9304,  0.2522, -0.5477,  0.1089]) tensor([ 1.5302e+00,  1.0248e+00, -8.7277e-01,  3.0313e-01,  3.2550e-01,
         1.8730e-01,  8.3132e-01,  1.2349e+00, -5.8600e-02,  1.0960e+00,
        -7.9261e-01,  6.5911e-01,  3.5769e-02, -4.1673e-01, -4.1281e-01,
         2.7598e-02,  4.0331e-01,  3.2418e-01, -1.8932e-01,  2.3353e+00,
        -9.3152e-02,  3.6659e-01,  3.3818e-01, -6.6806e-01,  3.3782e-01,
         3.7052e-01, -1.3431e-01, -1.0996e-01, -2.0730e-01, -1.4334e-01,
        -4.3035e-02,  2.0229e-01,  3.9552e-02, -1.2591e-01,  1.4461e-01,
         1.3577e-03,  2.2947e+00,  1.6158e+00,  2.3633e+00, -2.5843e-01],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
67 tensor(5.3018, grad_fn=<MseLossBackward0>) tensor([-1.6070e-01,  1.6273e-01, -8.8338e-01,  1.5769e+00, -1.2799e+00,
        -1.0451e+00, -1.2343e+00, -3.7996e-01,  2.1970e-01, -1.0193e-03,
        -6.2095e-01, -9.9503e-01,  1.0110e+00, -5.0078e-01, -6.7878e-

         0.0065,  0.0808, -0.0350, -0.0378, -0.2471,  0.2050,  0.4852,  0.0028]) tensor([ 1.5625,  1.0156, -0.8568,  0.3055,  0.3253,  0.1769,  0.8321,  1.2126,
        -0.0503,  1.1001, -0.7796,  0.6835,  0.0185, -0.4114, -0.4175,  0.0340,
         0.3976,  0.3078, -0.1756,  2.3644, -0.1044,  0.3838,  0.3170, -0.6528,
         0.3233,  0.3695, -0.0988, -0.1098, -0.2081, -0.1435, -0.0432,  0.2026,
         0.0399, -0.1242,  0.1459,  0.0050,  2.2922,  1.6169,  2.3780, -0.2591],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
69 tensor(3.7835, grad_fn=<MseLossBackward0>) tensor([ 0.7171, -1.1455,  0.5901,  0.4384, -1.1680, -0.4812, -1.1079,  0.4116,
        -0.0325,  0.2315, -0.4257,  0.1458, -0.4638,  0.3507,  0.2849,  0.3405,
         0.3694, -0.1763,  0.2417, -0.3020,  0.4381,  1.3073, -1.3393, -1.1161,
        -1.4136, -0.7279,  1.3338, -0.1576,  0.1870, -0.0388, -0.0855,  0.0816,
         0.0448,  0.0796,  0.0523, -0.0030, -0.5757,  0.3246, -0.3952, -0.0126]

74 tensor(1.7815, grad_fn=<MseLossBackward0>) tensor([-1.5274,  1.4711, -0.2601, -0.8970,  1.1488, -0.4809,  1.2275,  0.0641,
        -0.6122, -0.5731,  0.5354, -0.1085, -0.0212,  0.1063, -0.6500,  0.1154,
        -0.0171,  0.0459,  1.2569, -2.0037,  1.7761, -3.2176,  3.1840,  2.6610,
         3.0737,  1.9287, -6.8783,  0.1735, -0.1671,  0.0296,  0.1019, -0.1305,
         0.0546, -0.1395, -0.0073,  0.0696,  0.1814, -0.5258, -1.6135, -0.0112]) tensor([ 1.6246,  1.0499, -0.8551,  0.3079,  0.3101,  0.1279,  0.8417,  1.1494,
        -0.0135,  1.1103, -0.7224,  0.7224, -0.0240, -0.4196, -0.4471,  0.0375,
         0.3863,  0.2453, -0.1197,  2.4340, -0.1021,  0.3778,  0.2970, -0.5537,
         0.3284,  0.4434, -0.1727, -0.1032, -0.2170, -0.1397, -0.0400,  0.2027,
         0.0447, -0.1217,  0.1496,  0.0130,  2.3099,  1.6092,  2.4172, -0.2624],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
75 tensor(5.1829, grad_fn=<MseLossBackward0>) tensor([-9.3182e-01,  5.5517e-01,

        -0.1697, -0.1474, -0.1047, -0.0721,  0.8836,  0.2702, -0.4762,  0.1135]) tensor([ 1.6515,  1.0565, -0.8543,  0.3086,  0.3078,  0.1217,  0.8481,  1.1317,
         0.0048,  1.1214, -0.7177,  0.7254, -0.0343, -0.4219, -0.4532,  0.0351,
         0.3846,  0.2180, -0.0898,  2.4473, -0.1206,  0.3761,  0.2919, -0.5211,
         0.3290,  0.4498, -0.1638, -0.1023, -0.2180, -0.1390, -0.0397,  0.2033,
         0.0460, -0.1202,  0.1502,  0.0144,  2.3230,  1.6137,  2.4167, -0.2627],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
76 tensor(5.1721, grad_fn=<MseLossBackward0>) tensor([-0.3359,  0.2390, -0.8450,  1.2669, -1.0012, -0.7883, -1.0095, -0.1971,
        -0.2163,  0.0534, -0.5892, -1.1084,  1.0313, -0.4705, -0.7135, -0.8121,
        -0.0748, -0.2085,  0.6276, -0.5081, -1.4779,  0.6362, -0.4795, -0.0109,
        -0.4362, -0.7835,  0.9581, -0.0783,  0.1122,  0.0823, -0.2324,  0.1771,
         0.1722,  0.1824,  0.0825, -0.0419, -0.2614,  0.0469, -0.6857, -0.0623]

79 tensor(4.3503, grad_fn=<MseLossBackward0>) tensor([ 0.6206, -0.4300,  0.6408, -1.2932,  1.3446,  1.5899,  1.3495,  0.5590,
         0.7371,  0.4896, -0.3513,  0.0725, -0.3708,  0.4630,  0.6545,  0.4737,
         0.2479, -0.1047,  0.0181, -0.3063,  0.1037, -0.7778,  0.6956,  0.9222,
         0.6012,  0.0772,  0.2425,  0.0207, -0.0327, -0.0573,  0.1389, -0.1451,
        -0.1664, -0.1466, -0.1063, -0.0657,  0.8509,  0.2879, -0.4630,  0.1120]) tensor([ 1.6863,  1.0690, -0.8493,  0.3108,  0.2977,  0.1016,  0.8517,  1.1010,
         0.0250,  1.1284, -0.6970,  0.7427, -0.0586, -0.4273, -0.4609,  0.0349,
         0.3795,  0.1820, -0.0561,  2.4789, -0.1287,  0.3736,  0.2838, -0.4916,
         0.3290,  0.4678, -0.1702, -0.1001, -0.2204, -0.1390, -0.0377,  0.2037,
         0.0473, -0.1183,  0.1511,  0.0187,  2.3339,  1.6150,  2.4354, -0.2636],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
79 tensor(5.1281, grad_fn=<MseLossBackward0>) tensor([-0.3852,  0.2736, -0.8177

        -0.0265,  0.0421, -0.0212, -0.0192, -0.0225, -0.0656, -0.4607, -0.0200]) tensor([ 1.7183,  1.0559, -0.8389,  0.3207,  0.2810,  0.0900,  0.8396,  1.0834,
         0.0326,  1.1335, -0.6884,  0.7538, -0.0712, -0.4306, -0.4591,  0.0350,
         0.3740,  0.1702, -0.0408,  2.4972, -0.1386,  0.3873,  0.2642, -0.5019,
         0.3150,  0.4555, -0.1283, -0.1002, -0.2204, -0.1395, -0.0370,  0.2044,
         0.0468, -0.1168,  0.1519,  0.0221,  2.3324,  1.6177,  2.4441, -0.2641],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
81 tensor(5.3658, grad_fn=<MseLossBackward0>) tensor([ 4.1840e-01, -7.2724e-01,  8.6337e-02,  8.3300e-01,  4.9948e-02,
        -5.2582e-01, -5.8844e-01, -3.4364e-01,  6.4337e-01, -1.3628e-01,
         4.1794e-02,  1.8540e-01,  4.4176e-01, -5.8421e-02, -1.7805e-01,
        -2.3556e-01, -1.9025e-01,  8.4151e-01, -6.5882e-01,  3.9666e-01,
         2.8728e-01,  1.0649e+00, -7.4276e-01, -1.0750e+00, -1.0201e+00,
         6.9224e-03,  2.4901e+00, 

torch.Size([100, 1])
86 tensor(4.9883, grad_fn=<MseLossBackward0>) tensor([-3.3500e-01, -5.7256e-02, -1.4149e-01, -2.4361e-02,  1.9526e-01,
         3.2956e-01,  6.4104e-02,  1.3169e-01,  4.8644e-02, -2.3185e-01,
         1.7972e-01,  5.6435e-01, -2.0576e-01,  4.4179e-02,  3.8754e-01,
         1.9348e-01,  2.5006e-01,  2.4087e-01, -2.5649e+00,  2.2862e+00,
         2.2468e-01,  1.8164e+00, -1.5607e+00, -1.7413e+00, -1.6901e+00,
         4.2873e-01,  1.5692e+00,  1.0820e-03, -2.3575e-02, -3.2399e-02,
         4.3452e-02, -2.2442e-02, -5.2870e-02, -3.7906e-02, -2.0523e-02,
        -1.5642e-01, -1.1856e-01,  9.3654e-03,  2.6298e+00,  5.4415e-03]) tensor([ 1.7663,  1.0823, -0.8417,  0.3265,  0.2610,  0.0578,  0.8427,  1.0402,
         0.0450,  1.1391, -0.6567,  0.7848, -0.1169, -0.4376, -0.4602,  0.0372,
         0.3681,  0.1194, -0.0111,  2.5682, -0.1464,  0.3984,  0.2400, -0.4864,
         0.3002,  0.4797, -0.1184, -0.0951, -0.2264, -0.1399, -0.0306,  0.2022,
         0.0464, -0.1167,  0

torch.Size([100, 1])
torch.Size([100, 1])
88 tensor(4.5253, grad_fn=<MseLossBackward0>) tensor([ 0.4504, -0.6968, -0.4023,  0.3898, -1.0881,  0.1387, -0.8868,  0.4469,
         0.2508,  0.2630,  0.0333, -0.6378,  0.2747, -0.2583,  0.2039, -0.0382,
         0.0256,  0.3791,  0.1623, -0.4810,  0.0072,  0.1481, -0.4667, -1.0067,
        -0.1756, -0.2461, -0.0106, -0.0576,  0.0218,  0.0619, -0.0238,  0.0663,
         0.0461,  0.0101, -0.0021, -0.0183, -0.0165, -0.1197, -0.4028, -0.0221]) tensor([ 1.7887,  1.0962, -0.8404,  0.3270,  0.2505,  0.0422,  0.8435,  1.0229,
         0.0447,  1.1459, -0.6491,  0.7879, -0.1332, -0.4398, -0.4639,  0.0359,
         0.3621,  0.1008,  0.0378,  2.5624, -0.1488,  0.3778,  0.2525, -0.4581,
         0.3175,  0.4806, -0.1341, -0.0934, -0.2279, -0.1399, -0.0280,  0.2011,
         0.0463, -0.1167,  0.1507,  0.0300,  2.3688,  1.6235,  2.4842, -0.2636],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
88 tensor(5.2539, grad_fn=<MseLossBac

93 tensor(4.9159, grad_fn=<MseLossBackward0>) tensor([-2.7691e-01, -4.5216e-02, -2.0744e-01,  1.1212e-02,  1.5405e-01,
         2.6840e-01,  2.4338e-02,  1.2945e-01,  9.8488e-02, -2.1115e-01,
         2.0150e-01,  6.0532e-01, -2.4669e-01,  6.2323e-02,  3.9591e-01,
         2.2396e-01,  2.8289e-01,  2.2655e-01, -2.5883e+00,  2.3879e+00,
         1.8392e-01,  1.8127e+00, -1.5671e+00, -1.7481e+00, -1.6952e+00,
         4.7226e-01,  1.7015e+00,  7.2904e-04, -3.5535e-02, -4.5948e-02,
         5.1964e-02, -3.6664e-02, -6.3993e-02, -5.5953e-02, -3.9093e-02,
        -1.6791e-01, -7.7661e-02,  4.8456e-02,  2.6517e+00,  2.9616e-03]) tensor([ 1.8243,  1.1431, -0.8325,  0.3338,  0.2284,  0.0115,  0.8446,  0.9974,
         0.0625,  1.1537, -0.6299,  0.8102, -0.1813, -0.4455, -0.4541,  0.0358,
         0.3592,  0.0545,  0.0661,  2.6270, -0.1470,  0.3900,  0.2321, -0.4523,
         0.3027,  0.4893, -0.1209, -0.0894, -0.2320, -0.1391, -0.0203,  0.1975,
         0.0407, -0.1180,  0.1515,  0.0305,  2.39

95 tensor(3.6364, grad_fn=<MseLossBackward0>) tensor([ 0.3579, -0.6393,  0.3441,  0.5077, -1.2219, -0.3263, -1.1412,  0.6092,
         0.3643,  0.1394, -0.2287,  0.1620, -0.3304,  0.1953,  0.0782,  0.1458,
         0.5323, -0.1685,  0.2644, -0.0975,  0.4433,  1.2048, -1.3040, -0.4472,
        -1.1661, -0.4458,  1.5195, -0.1203,  0.1217,  0.0110, -0.0978,  0.0893,
         0.0475,  0.0666,  0.0298,  0.0184, -0.3577,  0.3478, -0.3446,  0.0042]) tensor([ 1.8306,  1.1771, -0.8222,  0.3189,  0.2319,  0.0045,  0.8633,  0.9883,
         0.0606,  1.1577, -0.6265,  0.8114, -0.2009, -0.4483, -0.4565,  0.0324,
         0.3541,  0.0240,  0.1176,  2.6173, -0.1503,  0.3578,  0.2576, -0.4073,
         0.3313,  0.4855, -0.1648, -0.0873, -0.2329, -0.1377, -0.0169,  0.1954,
         0.0401, -0.1187,  0.1532,  0.0330,  2.4159,  1.6345,  2.5151, -0.2603],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
95 tensor(5.5885, grad_fn=<MseLossBackward0>) tensor([-0.7230,  0.2888,  0.5131

98 tensor(4.8652, grad_fn=<MseLossBackward0>) tensor([-2.0240e-01, -1.0668e-01, -1.3915e-01, -5.4749e-02,  2.1268e-01,
         3.1495e-01,  8.7864e-02,  1.0971e-01,  1.0258e-01, -1.9211e-01,
         1.9823e-01,  6.2965e-01, -2.7715e-01,  8.2173e-02,  4.1041e-01,
         2.5049e-01,  2.8079e-01,  1.8323e-01, -2.5843e+00,  2.4286e+00,
         1.9114e-01,  1.8001e+00, -1.5417e+00, -1.7222e+00, -1.6664e+00,
         4.4680e-01,  1.7801e+00,  2.9345e-02, -5.6081e-02, -2.6715e-02,
         6.5313e-02, -3.7717e-02, -7.8159e-02, -5.6546e-02, -3.5296e-02,
        -1.2910e-01, -7.2859e-02,  6.6838e-02,  2.6347e+00,  2.8306e-04]) tensor([ 1.8633,  1.1833, -0.8183,  0.3336,  0.2119, -0.0128,  0.8516,  0.9741,
         0.0774,  1.1637, -0.6167,  0.8142, -0.2175, -0.4566, -0.4453,  0.0270,
         0.3429, -0.0038,  0.1151,  2.6635, -0.1449,  0.3855,  0.2263, -0.4364,
         0.3015,  0.4879, -0.1263, -0.0876, -0.2334, -0.1364, -0.0147,  0.1956,
         0.0390, -0.1171,  0.1537,  0.0313,  2.43

99 tensor(5.5246, grad_fn=<MseLossBackward0>) tensor([-0.7242,  0.2922,  0.5150, -1.0362,  1.1850,  1.1379,  1.0973, -0.3764,
        -1.3422,  0.0121, -0.3324,  0.3943, -0.1275,  0.2606,  0.5135,  0.2898,
        -0.5045, -0.1082, -0.3379,  0.1417, -1.3193, -0.7351,  0.8902,  0.1030,
         0.9367, -1.4021,  0.5188,  0.1093, -0.0763, -0.0389,  0.1034, -0.1149,
        -0.1215, -0.1135,  0.0175,  0.1602, -0.3998,  0.0349,  0.0036, -0.0288]) tensor([ 1.8557,  1.2152, -0.8185,  0.3136,  0.2310, -0.0138,  0.8797,  0.9659,
         0.0760,  1.1624, -0.6143,  0.8144, -0.2231, -0.4620, -0.4498,  0.0203,
         0.3341, -0.0289,  0.1532,  2.6452, -0.1528,  0.3420,  0.2662, -0.3918,
         0.3412,  0.4879, -0.1855, -0.0849, -0.2349, -0.1358, -0.0117,  0.1930,
         0.0376, -0.1187,  0.1541,  0.0330,  2.4454,  1.6339,  2.5332, -0.2580],
       grad_fn=<CatBackward0>)
torch.Size([100, 1])
torch.Size([100, 1])
99 tensor(4.1838, grad_fn=<MseLossBackward0>) tensor([ 0.7226, -0.5040,  0.6118

In [13]:
parameters_to_vector(model.parameters())

tensor([ 1.8768,  1.1992, -0.8163,  0.3325,  0.2062, -0.0229,  0.8547,  0.9673,
         0.0898,  1.1644, -0.6107,  0.8203, -0.2283, -0.4644, -0.4407,  0.0198,
         0.3353, -0.0331,  0.1337,  2.6766, -0.1442,  0.3831,  0.2246, -0.4307,
         0.3011,  0.4861, -0.1293, -0.0865, -0.2343, -0.1352, -0.0121,  0.1942,
         0.0371, -0.1175,  0.1541,  0.0319,  2.4447,  1.6360,  2.5613, -0.2576],
       grad_fn=<CatBackward0>)

In [14]:
class CNN_Net(nn.Module):
    def __init__(self):
        super(CNN_Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2,2)
        self.dropout = nn.Dropout(p=0.2)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        output = self.fc2(x)
        return output

In [15]:
class ClientUpdate(object):
    def __init__(self, dataset, batchSize, alpha, lamda, epochs, projection_list, projected_weights):
        self.train_loader = DataLoader(MyDataset(dataset["features"], dataset["label"]), batch_size=batchSize, shuffle=True)
        #self.learning_rate = learning_rate
        self.epochs = epochs
        self.batchSize = batchSize

    def train(self, model):
        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.5)

        e_loss = []
        for epoch in range(1, self.epochs+1):
            train_loss = 0
            model.train()
            for i, (data, labels) in zip(range(1), self.train_loader):
                data, labels = data, labels
                optimizer.zero_grad() 
                output = model(data)  
                loss = criterion(output, labels)
                #loss += mu/2 * torch.norm(client_param.data - server_param.data)**2
                loss.backward()
                grads = grads_to_vector(model.parameters())
                #optimizer.step()
                train_loss += loss.item()*data.size(0)
                weights = parameters_to_vector(model.parameters())
                mat_vec_sum = torch.zeros_like(weights)
                for j in G.neighbors(model.user_id):
                    mat_vec_sum = torch.add(mat_vec_sum, torch.matmul(torch.transpose(projection_list[model.user_id][j], 0, 1), 
                                                         projected_weights[j][model.user_id] - projected_weights[model.user_id][j]))
                
                model_update = parameters_to_vector(model.parameters()) - alpha * (grads + lamda * mat_vec_sum)
                
            vector_to_parameters(parameters=model.parameters(), vec=model_update)
                

            train_loss = train_loss/self.batchSize#len(self.train_loader.dataset) 
            e_loss.append(train_loss)

        total_loss = e_loss#sum(e_loss)/len(e_loss)

        return model.state_dict(), total_loss

In [16]:
# Preparing projection matrices
models = [MLP_Net(user_id=i) for i in range(no_users)]
#temp = MLP_Net()
projection_list = []
projected_weights = []

def update_ProjWeight(projection_list, projected_weights, first_run=True):
    #projected_weights = []
    for i in range(no_users):
        neighbors_mat = []
        neighbors_weights = []
        for j in range(no_users):
            if j in G.neighbors(i):
                with torch.no_grad():
                    if first_run == True:
                        row, column = d0, parameters_to_vector(models[i].parameters()).size()[0]
                        mat = torch.zeros((row, column))
                        mat.fill_diagonal_(1.0 + 1.0 * float(np.random.randn(1)))
                        neighbors_mat.append(mat)
                        neighbors_weights.append(torch.matmul(mat, parameters_to_vector(models[i].parameters())))
                    else:
                        neighbors_weights.append(torch.matmul(projection_list[i][j], parameters_to_vector(models[i].parameters())))
            else:
                neighbors_mat.append(0)
                neighbors_weights.append(0)
        if first_run == True:
            projection_list.append(neighbors_mat)
        projected_weights.append(neighbors_weights)

update_ProjWeight(projection_list, projected_weights)



In [17]:
print(projection_list[0])

[0, 0, 0, 0, 0, 0, tensor([[-0.0910,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0910,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0910,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
       

In [18]:
def testing(model, dataset, bs, criterion): 
    test_loss = 0
    correct = 0
    test_loader = DataLoader(MyDataset(X_test, y_test), batch_size=bs)
    l = len(test_loader)
    model.eval()
    for data, labels in test_loader:
        data, labels = data, labels
        output = model(data)
        loss = criterion(output, labels)
        test_loss += loss.item()*data.size(0)
        #_, pred = torch.max(output, 1)
        #correct += pred.eq(labels.data.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    
    return test_loss

In [19]:
def rel_error(model):
    return (torch.norm(parameters_to_vector(model.parameters()) - datapoints[model.user_id]['exact_weights']) / torch.norm(datapoints[model.user_id]['exact_weights'])).detach()

In [20]:
torch.norm(datapoints[model.user_id]['exact_weights'])

tensor(5.0990, dtype=torch.float64)

In [21]:
model = MLP_Net(user_id=0)

from torch.nn.utils import parameters_to_vector, vector_to_parameters

with torch.no_grad():    
    params = parameters_to_vector(model.parameters())

    print(params)

params *= 2.

vector_to_parameters(parameters=model.parameters(), vec=params)

parameters_to_vector(model.parameters())





tensor([-0.2990,  0.0029, -0.2107, -0.2244,  0.2317, -0.2778,  0.1295,  0.0966,
        -0.0209,  0.2944, -0.0308,  0.2008, -0.2231,  0.0903, -0.2803,  0.2837,
        -0.3285, -0.0821, -0.0325,  0.1278,  0.0808, -0.2796, -0.2289,  0.3253,
        -0.0260,  0.1448, -0.2644, -0.2124, -0.1185,  0.1464,  0.2814,  0.2077,
        -0.2523, -0.1028,  0.2875, -0.2432, -0.1797, -0.0783, -0.2253,  0.4485])


tensor([-0.5980,  0.0058, -0.4213, -0.4488,  0.4633, -0.5556,  0.2591,  0.1932,
        -0.0417,  0.5888, -0.0616,  0.4016, -0.4463,  0.1806, -0.5606,  0.5674,
        -0.6570, -0.1642, -0.0651,  0.2556,  0.1615, -0.5591, -0.4578,  0.6505,
        -0.0521,  0.2896, -0.5287, -0.4248, -0.2371,  0.2928,  0.5627,  0.4154,
        -0.5046, -0.2056,  0.5750, -0.4863, -0.3594, -0.1567, -0.4507,  0.8970],
       grad_fn=<CatBackward0>)

In [22]:
#global_model = CNN_Net().cuda()
models = [MLP_Net(user_id=i) for i in range(no_users)]
dummy_models = [MLP_Net(user_id=i) for i in range(no_users)]

#model.load_state_dict(global_model.state_dict())

criterion = nn.MSELoss()


train_loss = []
test_loss = []
test_accuracy = []
total_rel_error = []

for curr_round in tqdm(range(1, it+1)):
    w, local_loss = [], []

    
    for i in range(no_users):
        dummy_models[i].load_state_dict(models[i].state_dict())
        local_update = ClientUpdate(dataset=datapoints[i], batchSize=batch_size, alpha=alpha, lamda=lamda, epochs=1, projection_list=projection_list, projected_weights=projected_weights)
        weights, loss = local_update.train(dummy_models[i])
        w.append(weights)
        local_loss.append(loss)
        models[i].load_state_dict(w[i])
    
    if curr_round % 50 == 0:
        alpha *= 0.9
        
    
    
    # Update prjection matrix
    
    #print(projection_list[0], projected_weights[0])
    
    for i in range(no_users):
        weights = parameters_to_vector(models[i].parameters())
        for j in G.neighbors(i):
            mat_vec_sum = torch.zeros(d0)
            for k in G.neighbors(i):
                mat_vec_sum = torch.add(mat_vec_sum, projected_weights[i][k] - projected_weights[k][i])
            temp_mat = torch.outer(mat_vec_sum, weights).clone()


            projection_list[i][j] = torch.add(projection_list[i][j], -1 * eta * lamda * temp_mat)
                                         
    projected_weights = []                                          
    update_ProjWeight(projection_list, projected_weights, first_run=False)
        
        
        
    
    




          
            

    local_test_acc = []
    local_test_loss = []
    user_rel_error = 0
    for k in range(no_users):
      
        g_loss = testing(models[i], datapoints[i], 50, criterion)
        local_test_loss.append(g_loss)
        #user_rel_error += rel_error(models[i])
    
    
        

    g_loss = sum(local_test_loss) / len(local_test_loss)
    #total_rel_error.append(user_rel_error / no_users)
    
    

    test_loss.append(g_loss)
    #test_accuracy.append(g_accuracy)
    print("Training_loss %2.5f"% (test_loss[-1]))

  0%|          | 1/2000 [00:00<25:59,  1.28it/s]

Training_loss 45.46236


  0%|          | 2/2000 [00:01<26:51,  1.24it/s]

Training_loss 44.77860


  0%|          | 3/2000 [00:02<25:59,  1.28it/s]

Training_loss 44.27951


  0%|          | 4/2000 [00:03<26:27,  1.26it/s]

Training_loss 43.61483


  0%|          | 5/2000 [00:04<31:39,  1.05it/s]

Training_loss 42.92942


  0%|          | 6/2000 [00:05<33:13,  1.00it/s]

Training_loss 42.22031


  0%|          | 7/2000 [00:06<33:37,  1.01s/it]

Training_loss 41.38924


  0%|          | 8/2000 [00:07<34:57,  1.05s/it]

Training_loss 40.55104


  0%|          | 9/2000 [00:09<43:43,  1.32s/it]

Training_loss 38.93654


  0%|          | 10/2000 [00:11<48:30,  1.46s/it]

Training_loss 37.77702


  1%|          | 11/2000 [00:13<55:04,  1.66s/it]

Training_loss 36.30158


  1%|          | 12/2000 [00:15<57:11,  1.73s/it]

Training_loss 34.74037


  1%|          | 13/2000 [00:16<54:23,  1.64s/it]

Training_loss 33.04861


  1%|          | 14/2000 [00:18<51:34,  1.56s/it]

Training_loss 31.67837


  1%|          | 15/2000 [00:19<50:13,  1.52s/it]

Training_loss 30.45491


  1%|          | 16/2000 [00:20<46:41,  1.41s/it]

Training_loss 28.83760


  1%|          | 17/2000 [00:21<44:32,  1.35s/it]

Training_loss 27.28283


  1%|          | 18/2000 [00:23<42:53,  1.30s/it]

Training_loss 25.97546


  1%|          | 19/2000 [00:25<50:27,  1.53s/it]

Training_loss 24.50185


  1%|          | 20/2000 [00:29<1:16:15,  2.31s/it]

Training_loss 23.40433


  1%|          | 21/2000 [00:31<1:18:50,  2.39s/it]

Training_loss 22.36016


  1%|          | 22/2000 [00:33<1:07:16,  2.04s/it]

Training_loss 21.61693


  1%|          | 23/2000 [00:34<59:24,  1.80s/it]  

Training_loss 20.46785


  1%|          | 24/2000 [00:35<53:08,  1.61s/it]

Training_loss 19.94251


  1%|▏         | 25/2000 [00:36<48:46,  1.48s/it]

Training_loss 19.48666


  1%|▏         | 26/2000 [00:37<45:24,  1.38s/it]

Training_loss 18.70886


  1%|▏         | 27/2000 [00:39<43:12,  1.31s/it]

Training_loss 18.35434


  1%|▏         | 28/2000 [00:40<42:29,  1.29s/it]

Training_loss 17.88449


  1%|▏         | 29/2000 [00:41<41:58,  1.28s/it]

Training_loss 17.55983


  2%|▏         | 30/2000 [00:43<53:42,  1.64s/it]

Training_loss 17.41804


  2%|▏         | 31/2000 [00:47<1:07:57,  2.07s/it]

Training_loss 17.52664


  2%|▏         | 32/2000 [00:49<1:14:41,  2.28s/it]

Training_loss 17.21831


  2%|▏         | 33/2000 [00:52<1:18:51,  2.41s/it]

Training_loss 17.24532


  2%|▏         | 34/2000 [00:54<1:13:29,  2.24s/it]

Training_loss 17.45346


  2%|▏         | 35/2000 [00:56<1:08:34,  2.09s/it]

Training_loss 17.57334


  2%|▏         | 36/2000 [00:57<1:00:31,  1.85s/it]

Training_loss 17.60733


  2%|▏         | 37/2000 [00:58<54:37,  1.67s/it]  

Training_loss 17.38238


  2%|▏         | 38/2000 [00:59<49:29,  1.51s/it]

Training_loss 17.15529


  2%|▏         | 39/2000 [01:01<47:18,  1.45s/it]

Training_loss 17.01805


  2%|▏         | 40/2000 [01:02<46:22,  1.42s/it]

Training_loss 17.01682


  2%|▏         | 41/2000 [01:03<44:27,  1.36s/it]

Training_loss 16.79065


  2%|▏         | 42/2000 [01:04<43:21,  1.33s/it]

Training_loss 16.67323


  2%|▏         | 43/2000 [01:06<45:58,  1.41s/it]

Training_loss 16.87122


  2%|▏         | 44/2000 [01:07<44:33,  1.37s/it]

Training_loss 17.23959


  2%|▏         | 45/2000 [01:09<43:35,  1.34s/it]

Training_loss 17.47915


  2%|▏         | 46/2000 [01:10<44:36,  1.37s/it]

Training_loss 17.34358


  2%|▏         | 47/2000 [01:11<44:24,  1.36s/it]

Training_loss 17.07524


  2%|▏         | 48/2000 [01:13<42:36,  1.31s/it]

Training_loss 16.92674


  2%|▏         | 49/2000 [01:14<41:01,  1.26s/it]

Training_loss 16.80567


  2%|▎         | 50/2000 [01:15<41:09,  1.27s/it]

Training_loss 16.67976


  3%|▎         | 51/2000 [01:17<43:39,  1.34s/it]

Training_loss 16.44895


  3%|▎         | 52/2000 [01:18<43:05,  1.33s/it]

Training_loss 16.65519


  3%|▎         | 53/2000 [01:19<41:49,  1.29s/it]

Training_loss 16.99399


  3%|▎         | 54/2000 [01:20<39:36,  1.22s/it]

Training_loss 16.92186


  3%|▎         | 55/2000 [01:21<38:41,  1.19s/it]

Training_loss 16.97451


  3%|▎         | 56/2000 [01:22<37:06,  1.15s/it]

Training_loss 17.15155


  3%|▎         | 57/2000 [01:23<37:17,  1.15s/it]

Training_loss 16.82957


  3%|▎         | 58/2000 [01:24<36:41,  1.13s/it]

Training_loss 16.70756


  3%|▎         | 59/2000 [01:25<35:27,  1.10s/it]

Training_loss 16.65943


  3%|▎         | 60/2000 [01:27<35:28,  1.10s/it]

Training_loss 16.61578


  3%|▎         | 61/2000 [01:28<35:34,  1.10s/it]

Training_loss 16.46315


  3%|▎         | 62/2000 [01:29<35:41,  1.11s/it]

Training_loss 16.28452


  3%|▎         | 63/2000 [01:30<35:50,  1.11s/it]

Training_loss 16.24197


  3%|▎         | 64/2000 [01:32<42:57,  1.33s/it]

Training_loss 16.52859


  3%|▎         | 65/2000 [01:35<58:22,  1.81s/it]

Training_loss 16.42128


  3%|▎         | 66/2000 [01:37<59:59,  1.86s/it]

Training_loss 16.45028


  3%|▎         | 67/2000 [01:38<53:50,  1.67s/it]

Training_loss 16.64060


  3%|▎         | 68/2000 [01:39<49:44,  1.55s/it]

Training_loss 16.54936


  3%|▎         | 69/2000 [01:41<48:23,  1.50s/it]

Training_loss 16.51918


  4%|▎         | 70/2000 [01:42<48:19,  1.50s/it]

Training_loss 16.30855


  4%|▎         | 71/2000 [01:44<48:08,  1.50s/it]

Training_loss 16.41202


  4%|▎         | 72/2000 [01:45<50:52,  1.58s/it]

Training_loss 16.04620


  4%|▎         | 73/2000 [01:47<51:35,  1.61s/it]

Training_loss 16.15463


  4%|▎         | 74/2000 [01:49<50:45,  1.58s/it]

Training_loss 15.97897


  4%|▍         | 75/2000 [01:50<48:39,  1.52s/it]

Training_loss 16.46732


  4%|▍         | 76/2000 [01:51<44:25,  1.39s/it]

Training_loss 16.38977


  4%|▍         | 77/2000 [01:52<44:46,  1.40s/it]

Training_loss 16.37760


  4%|▍         | 78/2000 [01:55<52:20,  1.63s/it]

Training_loss 16.15393


  4%|▍         | 79/2000 [01:57<55:44,  1.74s/it]

Training_loss 16.11586


  4%|▍         | 80/2000 [01:58<52:08,  1.63s/it]

Training_loss 16.12028


  4%|▍         | 81/2000 [02:01<1:01:41,  1.93s/it]

Training_loss 16.15081


  4%|▍         | 82/2000 [02:02<59:59,  1.88s/it]  

Training_loss 16.13739


  4%|▍         | 83/2000 [02:04<1:00:57,  1.91s/it]

Training_loss 16.45052


  4%|▍         | 84/2000 [02:06<56:19,  1.76s/it]  

Training_loss 16.60716


  4%|▍         | 85/2000 [02:07<50:28,  1.58s/it]

Training_loss 16.70121


  4%|▍         | 86/2000 [02:08<46:27,  1.46s/it]

Training_loss 16.54456


  4%|▍         | 87/2000 [02:09<46:05,  1.45s/it]

Training_loss 16.31087


  4%|▍         | 88/2000 [02:11<45:21,  1.42s/it]

Training_loss 16.25019


  4%|▍         | 89/2000 [02:12<43:30,  1.37s/it]

Training_loss 16.17428


  4%|▍         | 90/2000 [02:13<43:51,  1.38s/it]

Training_loss 16.21722


  5%|▍         | 91/2000 [02:15<43:56,  1.38s/it]

Training_loss 16.21095


  5%|▍         | 92/2000 [02:16<42:41,  1.34s/it]

Training_loss 16.06153


  5%|▍         | 93/2000 [02:18<43:14,  1.36s/it]

Training_loss 16.00530


  5%|▍         | 94/2000 [02:19<45:04,  1.42s/it]

Training_loss 16.09065


  5%|▍         | 95/2000 [02:21<48:24,  1.52s/it]

Training_loss 15.96290


  5%|▍         | 96/2000 [02:22<46:10,  1.46s/it]

Training_loss 15.87660


  5%|▍         | 97/2000 [02:23<42:50,  1.35s/it]

Training_loss 15.77390


  5%|▍         | 98/2000 [02:25<43:49,  1.38s/it]

Training_loss 15.89123


  5%|▍         | 99/2000 [02:26<44:23,  1.40s/it]

Training_loss 16.09853


  5%|▌         | 100/2000 [02:28<43:56,  1.39s/it]

Training_loss 16.24659


  5%|▌         | 101/2000 [02:29<41:56,  1.32s/it]

Training_loss 16.07429


  5%|▌         | 102/2000 [02:30<40:31,  1.28s/it]

Training_loss 16.43224


  5%|▌         | 103/2000 [02:31<41:08,  1.30s/it]

Training_loss 16.41641


  5%|▌         | 104/2000 [02:33<42:11,  1.34s/it]

Training_loss 16.55300


  5%|▌         | 105/2000 [02:34<42:07,  1.33s/it]

Training_loss 16.47550


  5%|▌         | 106/2000 [02:35<41:25,  1.31s/it]

Training_loss 16.57981


  5%|▌         | 107/2000 [02:38<50:30,  1.60s/it]

Training_loss 16.61744


  5%|▌         | 108/2000 [02:39<54:11,  1.72s/it]

Training_loss 16.71633


  5%|▌         | 109/2000 [02:41<53:22,  1.69s/it]

Training_loss 17.00516


  6%|▌         | 110/2000 [02:43<53:05,  1.69s/it]

Training_loss 17.58314


  6%|▌         | 111/2000 [02:46<1:08:34,  2.18s/it]

Training_loss 17.03974


  6%|▌         | 112/2000 [02:50<1:22:04,  2.61s/it]

Training_loss 16.85636


  6%|▌         | 113/2000 [02:51<1:12:28,  2.30s/it]

Training_loss 16.56378


  6%|▌         | 114/2000 [02:53<1:05:01,  2.07s/it]

Training_loss 16.35438


  6%|▌         | 115/2000 [02:55<1:01:02,  1.94s/it]

Training_loss 16.45952


  6%|▌         | 116/2000 [02:56<55:59,  1.78s/it]  

Training_loss 16.19572


  6%|▌         | 117/2000 [02:57<53:33,  1.71s/it]

Training_loss 16.13235


  6%|▌         | 118/2000 [02:59<49:37,  1.58s/it]

Training_loss 16.50719


  6%|▌         | 119/2000 [03:01<51:59,  1.66s/it]

Training_loss 16.71994


  6%|▌         | 120/2000 [03:02<51:30,  1.64s/it]

Training_loss 16.65295


  6%|▌         | 121/2000 [03:03<48:08,  1.54s/it]

Training_loss 16.50208


  6%|▌         | 122/2000 [03:05<45:31,  1.45s/it]

Training_loss 16.44139


  6%|▌         | 123/2000 [03:06<44:00,  1.41s/it]

Training_loss 16.73090


  6%|▌         | 124/2000 [03:08<48:10,  1.54s/it]

Training_loss 16.59126


  6%|▋         | 125/2000 [03:12<1:09:21,  2.22s/it]

Training_loss 16.42958


  6%|▋         | 126/2000 [03:14<1:08:58,  2.21s/it]

Training_loss 16.32221


  6%|▋         | 127/2000 [03:16<1:05:49,  2.11s/it]

Training_loss 16.66120


  6%|▋         | 128/2000 [03:17<1:01:27,  1.97s/it]

Training_loss 16.78265


  6%|▋         | 129/2000 [03:19<58:04,  1.86s/it]  

Training_loss 16.77286


  6%|▋         | 130/2000 [03:21<57:05,  1.83s/it]

Training_loss 16.31350


  7%|▋         | 131/2000 [03:23<57:36,  1.85s/it]

Training_loss 16.41048


  7%|▋         | 132/2000 [03:25<59:35,  1.91s/it]

Training_loss 16.57332


  7%|▋         | 133/2000 [03:26<58:19,  1.87s/it]

Training_loss 16.49023


  7%|▋         | 134/2000 [03:28<52:35,  1.69s/it]

Training_loss 16.29037


  7%|▋         | 135/2000 [03:29<46:57,  1.51s/it]

Training_loss 16.22397


  7%|▋         | 136/2000 [03:30<42:24,  1.37s/it]

Training_loss 16.06041


  7%|▋         | 137/2000 [03:31<38:58,  1.26s/it]

Training_loss 15.91206


  7%|▋         | 138/2000 [03:32<37:19,  1.20s/it]

Training_loss 15.68855


  7%|▋         | 139/2000 [03:33<36:47,  1.19s/it]

Training_loss 15.74695


  7%|▋         | 140/2000 [03:34<34:43,  1.12s/it]

Training_loss 15.87361


  7%|▋         | 141/2000 [03:35<35:58,  1.16s/it]

Training_loss 15.80866


  7%|▋         | 142/2000 [03:37<36:46,  1.19s/it]

Training_loss 15.84775


  7%|▋         | 143/2000 [03:38<35:42,  1.15s/it]

Training_loss 15.86352


  7%|▋         | 144/2000 [03:39<35:05,  1.13s/it]

Training_loss 16.09082


  7%|▋         | 145/2000 [03:40<36:59,  1.20s/it]

Training_loss 16.40703


  7%|▋         | 146/2000 [03:42<47:32,  1.54s/it]

Training_loss 16.22852


  7%|▋         | 147/2000 [03:46<1:07:46,  2.19s/it]

Training_loss 16.23040


  7%|▋         | 148/2000 [03:47<58:56,  1.91s/it]  

Training_loss 16.31460


  7%|▋         | 149/2000 [03:49<51:41,  1.68s/it]

Training_loss 16.29091


  8%|▊         | 150/2000 [03:50<46:12,  1.50s/it]

Training_loss 16.50687


  8%|▊         | 151/2000 [03:51<42:53,  1.39s/it]

Training_loss 16.80645


  8%|▊         | 152/2000 [03:52<40:07,  1.30s/it]

Training_loss 16.84357


  8%|▊         | 153/2000 [03:53<38:32,  1.25s/it]

Training_loss 16.88430


  8%|▊         | 154/2000 [03:54<37:07,  1.21s/it]

Training_loss 16.93709


  8%|▊         | 155/2000 [03:55<36:28,  1.19s/it]

Training_loss 16.82261


  8%|▊         | 156/2000 [03:56<35:33,  1.16s/it]

Training_loss 16.79334


  8%|▊         | 157/2000 [03:57<35:53,  1.17s/it]

Training_loss 16.61888


  8%|▊         | 158/2000 [03:59<35:46,  1.17s/it]

Training_loss 16.42808


  8%|▊         | 159/2000 [04:02<51:33,  1.68s/it]

Training_loss 16.22600


  8%|▊         | 160/2000 [04:06<1:16:06,  2.48s/it]

Training_loss 16.20254


  8%|▊         | 161/2000 [04:08<1:08:39,  2.24s/it]

Training_loss 16.33711


  8%|▊         | 162/2000 [04:09<1:02:24,  2.04s/it]

Training_loss 16.08878


  8%|▊         | 163/2000 [04:10<55:36,  1.82s/it]  

Training_loss 15.95203


  8%|▊         | 164/2000 [04:12<50:27,  1.65s/it]

Training_loss 15.96483


  8%|▊         | 165/2000 [04:13<47:55,  1.57s/it]

Training_loss 16.11927


  8%|▊         | 166/2000 [04:14<44:08,  1.44s/it]

Training_loss 16.08101


  8%|▊         | 167/2000 [04:15<40:42,  1.33s/it]

Training_loss 16.05473


  8%|▊         | 168/2000 [04:16<38:10,  1.25s/it]

Training_loss 16.14341


  8%|▊         | 169/2000 [04:17<36:40,  1.20s/it]

Training_loss 16.29874


  8%|▊         | 170/2000 [04:18<34:48,  1.14s/it]

Training_loss 16.26092


  9%|▊         | 171/2000 [04:19<33:42,  1.11s/it]

Training_loss 16.18849


  9%|▊         | 172/2000 [04:20<32:57,  1.08s/it]

Training_loss 15.77020


  9%|▊         | 173/2000 [04:22<33:23,  1.10s/it]

Training_loss 15.73229


  9%|▊         | 174/2000 [04:23<32:59,  1.08s/it]

Training_loss 15.55820


  9%|▉         | 175/2000 [04:24<32:45,  1.08s/it]

Training_loss 15.64104


  9%|▉         | 176/2000 [04:25<32:25,  1.07s/it]

Training_loss 15.76790


  9%|▉         | 177/2000 [04:26<32:57,  1.08s/it]

Training_loss 15.83256


  9%|▉         | 178/2000 [04:27<32:41,  1.08s/it]

Training_loss 15.95017


  9%|▉         | 179/2000 [04:28<32:38,  1.08s/it]

Training_loss 15.86476


  9%|▉         | 180/2000 [04:30<38:19,  1.26s/it]

Training_loss 16.06257


  9%|▉         | 181/2000 [04:37<1:36:06,  3.17s/it]

Training_loss 16.01213


  9%|▉         | 182/2000 [04:39<1:23:07,  2.74s/it]

Training_loss 15.82576


  9%|▉         | 183/2000 [04:40<1:07:46,  2.24s/it]

Training_loss 15.98779


  9%|▉         | 184/2000 [04:41<57:46,  1.91s/it]  

Training_loss 16.10663


  9%|▉         | 185/2000 [04:42<49:28,  1.64s/it]

Training_loss 16.05501


  9%|▉         | 186/2000 [04:43<44:08,  1.46s/it]

Training_loss 15.97610


  9%|▉         | 187/2000 [04:45<41:34,  1.38s/it]

Training_loss 15.97846


  9%|▉         | 188/2000 [04:46<41:02,  1.36s/it]

Training_loss 16.33126


  9%|▉         | 189/2000 [04:47<37:45,  1.25s/it]

Training_loss 16.35612


 10%|▉         | 190/2000 [04:48<37:20,  1.24s/it]

Training_loss 16.30519


 10%|▉         | 191/2000 [04:49<35:38,  1.18s/it]

Training_loss 16.37390


 10%|▉         | 192/2000 [04:50<34:57,  1.16s/it]

Training_loss 16.57513


 10%|▉         | 193/2000 [04:51<33:45,  1.12s/it]

Training_loss 16.69907


 10%|▉         | 194/2000 [04:52<33:13,  1.10s/it]

Training_loss 16.38286


 10%|▉         | 195/2000 [04:53<33:46,  1.12s/it]

Training_loss 16.06997


 10%|▉         | 196/2000 [04:55<33:16,  1.11s/it]

Training_loss 16.16551


 10%|▉         | 197/2000 [04:56<32:30,  1.08s/it]

Training_loss 15.95099


 10%|▉         | 198/2000 [04:57<32:30,  1.08s/it]

Training_loss 15.78550


 10%|▉         | 199/2000 [04:58<32:15,  1.07s/it]

Training_loss 15.80966


 10%|█         | 200/2000 [04:59<33:02,  1.10s/it]

Training_loss 15.57934


 10%|█         | 201/2000 [05:00<32:17,  1.08s/it]

Training_loss 15.54366


 10%|█         | 202/2000 [05:01<34:26,  1.15s/it]

Training_loss 15.82990


 10%|█         | 203/2000 [05:02<33:54,  1.13s/it]

Training_loss 15.84996


 10%|█         | 204/2000 [05:03<33:03,  1.10s/it]

Training_loss 15.73777


 10%|█         | 205/2000 [05:04<32:17,  1.08s/it]

Training_loss 15.66145


 10%|█         | 206/2000 [05:06<33:23,  1.12s/it]

Training_loss 15.62301


 10%|█         | 207/2000 [05:07<32:24,  1.08s/it]

Training_loss 15.85517


 10%|█         | 208/2000 [05:08<32:32,  1.09s/it]

Training_loss 15.90111


 10%|█         | 209/2000 [05:09<32:06,  1.08s/it]

Training_loss 16.34756


 10%|█         | 210/2000 [05:10<32:06,  1.08s/it]

Training_loss 16.23536


 11%|█         | 211/2000 [05:11<30:48,  1.03s/it]

Training_loss 16.19024


 11%|█         | 212/2000 [05:12<30:32,  1.03s/it]

Training_loss 16.18142


 11%|█         | 213/2000 [05:13<30:48,  1.03s/it]

Training_loss 16.01789


 11%|█         | 214/2000 [05:14<31:17,  1.05s/it]

Training_loss 15.96035


 11%|█         | 215/2000 [05:15<31:24,  1.06s/it]

Training_loss 15.90371


 11%|█         | 216/2000 [05:16<31:08,  1.05s/it]

Training_loss 15.92818


 11%|█         | 217/2000 [05:21<1:09:49,  2.35s/it]

Training_loss 15.85871


 11%|█         | 218/2000 [05:23<59:49,  2.01s/it]  

Training_loss 15.57008


 11%|█         | 219/2000 [05:24<53:09,  1.79s/it]

Training_loss 15.76187


 11%|█         | 220/2000 [05:25<49:12,  1.66s/it]

Training_loss 15.74382


 11%|█         | 221/2000 [05:27<51:51,  1.75s/it]

Training_loss 15.86723


 11%|█         | 222/2000 [05:29<48:11,  1.63s/it]

Training_loss 15.91915


 11%|█         | 223/2000 [05:30<50:51,  1.72s/it]

Training_loss 15.63660


 11%|█         | 224/2000 [05:32<47:29,  1.60s/it]

Training_loss 15.59977


 11%|█▏        | 225/2000 [05:33<48:17,  1.63s/it]

Training_loss 15.59294


 11%|█▏        | 226/2000 [05:35<44:22,  1.50s/it]

Training_loss 15.74774


 11%|█▏        | 227/2000 [05:36<39:58,  1.35s/it]

Training_loss 15.96672


 11%|█▏        | 228/2000 [05:37<37:12,  1.26s/it]

Training_loss 16.18252


 11%|█▏        | 229/2000 [05:38<37:21,  1.27s/it]

Training_loss 16.29416


 12%|█▏        | 230/2000 [05:40<40:00,  1.36s/it]

Training_loss 15.87336


 12%|█▏        | 231/2000 [05:41<43:04,  1.46s/it]

Training_loss 15.71036


 12%|█▏        | 232/2000 [05:43<42:11,  1.43s/it]

Training_loss 15.63415


 12%|█▏        | 233/2000 [05:44<41:01,  1.39s/it]

Training_loss 15.72348


 12%|█▏        | 234/2000 [05:45<37:29,  1.27s/it]

Training_loss 15.70525


 12%|█▏        | 235/2000 [05:46<37:41,  1.28s/it]

Training_loss 15.33398


 12%|█▏        | 236/2000 [05:47<36:21,  1.24s/it]

Training_loss 15.49925


 12%|█▏        | 237/2000 [05:49<37:02,  1.26s/it]

Training_loss 15.69260


 12%|█▏        | 238/2000 [05:50<39:48,  1.36s/it]

Training_loss 15.64551


 12%|█▏        | 239/2000 [05:52<40:13,  1.37s/it]

Training_loss 15.60722


 12%|█▏        | 240/2000 [05:53<41:42,  1.42s/it]

Training_loss 15.44855


 12%|█▏        | 241/2000 [05:55<40:37,  1.39s/it]

Training_loss 15.66122


 12%|█▏        | 242/2000 [05:56<38:20,  1.31s/it]

Training_loss 15.38820


 12%|█▏        | 243/2000 [05:57<37:17,  1.27s/it]

Training_loss 15.57711


 12%|█▏        | 244/2000 [05:58<37:22,  1.28s/it]

Training_loss 15.63644


 12%|█▏        | 245/2000 [05:59<36:44,  1.26s/it]

Training_loss 16.23190


 12%|█▏        | 246/2000 [06:01<39:51,  1.36s/it]

Training_loss 16.41170


 12%|█▏        | 247/2000 [06:02<39:21,  1.35s/it]

Training_loss 15.91106


 12%|█▏        | 248/2000 [06:03<38:11,  1.31s/it]

Training_loss 15.94477


 12%|█▏        | 249/2000 [06:05<38:58,  1.34s/it]

Training_loss 16.00735


 12%|█▎        | 250/2000 [06:06<39:04,  1.34s/it]

Training_loss 15.89991


 13%|█▎        | 251/2000 [06:08<40:42,  1.40s/it]

Training_loss 15.99546


 13%|█▎        | 252/2000 [06:10<46:25,  1.59s/it]

Training_loss 15.79584


 13%|█▎        | 253/2000 [06:12<49:54,  1.71s/it]

Training_loss 15.59107


 13%|█▎        | 254/2000 [06:14<52:11,  1.79s/it]

Training_loss 15.51365


 13%|█▎        | 255/2000 [06:16<53:37,  1.84s/it]

Training_loss 15.53021


 13%|█▎        | 256/2000 [06:17<49:24,  1.70s/it]

Training_loss 15.44535


 13%|█▎        | 257/2000 [06:19<49:48,  1.71s/it]

Training_loss 15.28541


 13%|█▎        | 258/2000 [06:21<56:06,  1.93s/it]

Training_loss 15.27009


 13%|█▎        | 259/2000 [06:23<54:36,  1.88s/it]

Training_loss 15.57769


 13%|█▎        | 260/2000 [06:25<57:56,  2.00s/it]

Training_loss 15.27278


 13%|█▎        | 261/2000 [06:27<57:59,  2.00s/it]

Training_loss 15.10251


 13%|█▎        | 262/2000 [06:29<59:16,  2.05s/it]

Training_loss 14.96835


 13%|█▎        | 263/2000 [06:31<56:45,  1.96s/it]

Training_loss 14.92740


 13%|█▎        | 264/2000 [06:33<51:56,  1.80s/it]

Training_loss 15.07742


 13%|█▎        | 265/2000 [06:36<1:03:27,  2.19s/it]

Training_loss 14.90128


 13%|█▎        | 266/2000 [06:39<1:13:31,  2.54s/it]

Training_loss 15.00540


 13%|█▎        | 267/2000 [06:43<1:22:32,  2.86s/it]

Training_loss 15.22758


 13%|█▎        | 268/2000 [06:45<1:16:58,  2.67s/it]

Training_loss 15.28308


 13%|█▎        | 269/2000 [06:46<1:05:25,  2.27s/it]

Training_loss 15.36464


 14%|█▎        | 270/2000 [06:47<55:40,  1.93s/it]  

Training_loss 15.36687


 14%|█▎        | 271/2000 [06:49<48:40,  1.69s/it]

Training_loss 15.45740


 14%|█▎        | 272/2000 [06:50<44:01,  1.53s/it]

Training_loss 15.38750


 14%|█▎        | 273/2000 [06:51<42:05,  1.46s/it]

Training_loss 15.47346


 14%|█▎        | 274/2000 [06:53<42:16,  1.47s/it]

Training_loss 15.58268


 14%|█▍        | 275/2000 [06:55<51:36,  1.80s/it]

Training_loss 15.47730


 14%|█▍        | 276/2000 [06:57<49:07,  1.71s/it]

Training_loss 15.28310


 14%|█▍        | 277/2000 [06:58<44:53,  1.56s/it]

Training_loss 15.05328


 14%|█▍        | 278/2000 [06:59<41:44,  1.45s/it]

Training_loss 15.26132


 14%|█▍        | 279/2000 [07:00<38:37,  1.35s/it]

Training_loss 15.52276


 14%|█▍        | 280/2000 [07:01<35:10,  1.23s/it]

Training_loss 15.45940


 14%|█▍        | 281/2000 [07:02<32:50,  1.15s/it]

Training_loss 15.46966


 14%|█▍        | 282/2000 [07:03<31:24,  1.10s/it]

Training_loss 15.44975


 14%|█▍        | 283/2000 [07:04<30:12,  1.06s/it]

Training_loss 15.49479


 14%|█▍        | 284/2000 [07:05<30:30,  1.07s/it]

Training_loss 15.50509


 14%|█▍        | 285/2000 [07:06<30:52,  1.08s/it]

Training_loss 15.64731


 14%|█▍        | 286/2000 [07:07<31:13,  1.09s/it]

Training_loss 15.68007


 14%|█▍        | 287/2000 [07:09<32:51,  1.15s/it]

Training_loss 15.65173


 14%|█▍        | 288/2000 [07:11<42:32,  1.49s/it]

Training_loss 15.60297


 14%|█▍        | 289/2000 [07:12<39:55,  1.40s/it]

Training_loss 15.85883


 14%|█▍        | 290/2000 [07:13<40:11,  1.41s/it]

Training_loss 15.65472


 15%|█▍        | 291/2000 [07:16<49:08,  1.73s/it]

Training_loss 15.77101


 15%|█▍        | 292/2000 [07:18<51:25,  1.81s/it]

Training_loss 15.97006


 15%|█▍        | 293/2000 [07:20<51:43,  1.82s/it]

Training_loss 16.03610


 15%|█▍        | 294/2000 [07:22<54:18,  1.91s/it]

Training_loss 16.02722


 15%|█▍        | 295/2000 [07:24<54:37,  1.92s/it]

Training_loss 16.29329


 15%|█▍        | 296/2000 [07:28<1:09:37,  2.45s/it]

Training_loss 15.89970


 15%|█▍        | 297/2000 [07:29<1:02:57,  2.22s/it]

Training_loss 16.18093


 15%|█▍        | 298/2000 [07:30<54:57,  1.94s/it]  

Training_loss 15.91376


 15%|█▍        | 299/2000 [07:32<50:53,  1.80s/it]

Training_loss 15.84315


 15%|█▌        | 300/2000 [07:33<46:52,  1.65s/it]

Training_loss 15.89787


 15%|█▌        | 301/2000 [07:35<48:30,  1.71s/it]

Training_loss 16.33314


 15%|█▌        | 302/2000 [07:37<46:48,  1.65s/it]

Training_loss 16.00816


 15%|█▌        | 303/2000 [07:38<43:06,  1.52s/it]

Training_loss 15.86496


 15%|█▌        | 304/2000 [07:39<43:22,  1.53s/it]

Training_loss 15.41205


 15%|█▌        | 305/2000 [07:41<46:19,  1.64s/it]

Training_loss 15.39004


 15%|█▌        | 306/2000 [07:43<45:09,  1.60s/it]

Training_loss 15.46527


 15%|█▌        | 307/2000 [07:44<42:57,  1.52s/it]

Training_loss 15.34985


 15%|█▌        | 308/2000 [07:45<40:40,  1.44s/it]

Training_loss 15.56712


 15%|█▌        | 309/2000 [07:47<38:52,  1.38s/it]

Training_loss 16.01618


 16%|█▌        | 310/2000 [07:48<36:49,  1.31s/it]

Training_loss 16.04090


 16%|█▌        | 311/2000 [07:49<36:21,  1.29s/it]

Training_loss 16.41081


 16%|█▌        | 312/2000 [07:50<36:57,  1.31s/it]

Training_loss 16.41854


 16%|█▌        | 313/2000 [07:52<37:26,  1.33s/it]

Training_loss 16.69346


 16%|█▌        | 314/2000 [07:53<37:58,  1.35s/it]

Training_loss 16.44874


 16%|█▌        | 314/2000 [07:54<42:27,  1.51s/it]


KeyboardInterrupt: 

In [None]:
#Training_loss 5.33078 with no communication

In [None]:
#plot.plot(test_loss)
parameters_to_vector(models[19].parameters())

In [None]:
for j in G.neighbors(0):
    print(j)

In [None]:
parameters_to_vector(models[0].parameters())

In [None]:
projection_list[0]

In [None]:
projected_weights[0]

In [None]:
test_loss = np.array(test_loss)
total_rel_error = np.array(total_rel_error)

In [None]:
print(test_loss)

In [None]:
np.save( 'training_loss_sheave_fml_alpha' + str(alpha).replace('.', '_') + "_eta_"+ str(eta).replace('.', '_') +  '_pout' + str(pout).replace('.', '_') + '+d0_' + str(d0), test_loss)
#np.save('relative_error_sheave_fml' + str(lamda).replace('.', '_'), total_rel_error)

In [None]:
'training_loss_sheave_fml' + str(lamda).replace('.', '_'), test_loss