In [1]:
from keras.utils import to_categorical
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm
import networkx as nx
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import networkx as nx

In [2]:
from sklearn.datasets import load_svmlight_file


def get_data(data):
    data = load_svmlight_file(data)
    return data[0].toarray(), data[1]
data = "./abalone_scale.txt"
X, y = get_data(data)
#y[y == -1] = 0
# Convert labels to one-hot encoding
#y = to_categorical(y)

#data = "./a1a_t"
#X_test,y_test = get_data(data)
#y_test[y_test == -1] = 0
#y_test = to_categorical(y_test)


In [3]:
# Dataset partitioning
def random_split(X, y, n, seed):
    """Equally split data between n agents"""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(y.shape[0])
    X_split = np.array_split(X[perm], n)  #np.stack to keep as a np array
    y_split = np.array_split(y[perm], n)
    return X_split, y_split

In [4]:
no_users = 30

In [5]:
X, y = random_split(X, y, 3, 42)

In [6]:
X1 = X[0][:, 0:4]
X2 = X[1][:, 2: 6]
X3 = X[2][:, 3:]
y1 = y[0]
y2 = y[1]
y3 = y[2]

X_test = [np.concatenate([X[1], X[2]], axis=0), np.concatenate([X[0], X[2]], axis=0), np.concatenate([X[0], X[1]], axis=0)]
y_test = [np.concatenate((y[1], y[2])), np.concatenate((y[0], y[2])), np.concatenate((y[0], y[1]))]


subset_ranges = [np.arange(0, 4), np.arange(2, 6), np.arange(3, 8)]
subset_lengths = [subset_ranges[0].shape[0], subset_ranges[1].shape[0], subset_ranges[2].shape[0]]

In [7]:
X_test[0].shape

(2784, 8)

In [8]:
X_test[0].shape

(2784, 8)

In [9]:
# Graph implementation
def generate_graph(cluster_sizes=[100,100], pin=0.5, pout=0.01, seed=0):
    """Generate a random connected graph"""
    probs = np.array([[pin, pout, pout],[pout, pin, pout],[pout, pout, pin]])
    while True:
        g = nx.stochastic_block_model(cluster_sizes, probs, seed=0)
        if nx.algorithms.components.is_connected(g):
            return g


cluster_sizes = [10, 10, 10]
#features_sizes = [8, 7, 6, 5]
pin = 0.5
pout = 0.1
seed = 0
alpha = 1e-2
lamda = 1e-1#1e-3#1e-1#1e-3
eta = 1 * 1e-2
d0 = min(subset_lengths)
no_users = sum(cluster_sizes)
batch_size = 50
epochs = 1
it = 500
G = generate_graph(cluster_sizes, pin, pout, seed)

# Set a random seed for reproducibility
seed = 17
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
#nx.draw(G, with_labels=True, node_size=100, alpha=1, linewidths=10)
#plt.show()

In [10]:
d0

4

In [11]:
def degrees(A):
    """Return the degrees of each node of a graph from its adjacency matrix"""
    return np.sum(A, axis=0).reshape(A.shape[0], 1)

def node_degree(n, G):
    cnt = 0
    for i in G.neighbors(n):
        cnt += 1
    return cnt

def get_neighbors(n, G):
    neighbors_list = []
    for i in G.neighbors(n):
        neighbors_list.append(int(i))
    return neighbors_list

In [12]:
y1.shape

(1393,)

In [13]:
datapoints = {}
count = 0

X1, y1 = random_split(X1, y1, 10, 42)
X2, y2 = random_split(X2, y2, 10, 42)
X3, y3 = random_split(X3, y3, 10, 42)

X_train = [X1, X2, X3]
y_train = [y1, y2, y3]
input_sizes = [X1[0].shape[1], X2[0].shape[1], X3[0].shape[1]]

for i, cluster_size in enumerate(cluster_sizes):
    for j in range(cluster_size):
        
        test_features = X_train[i][j]#X_test[:, subset_ranges[i]]
        test_label = y_train[i][j]#y_test
        datapoints[count] = {
                'features': X_train[i][j],
                'degree': node_degree(count, G),
                'label': y_train[i][j],
                'neighbors': get_neighbors(count, G),
                'input_size': X_train[i][j].shape[1],
                'test_features':X_test[i][:, subset_ranges[i]],
                'test_labels': y_test[i]
            }
        count += 1

In [14]:
datapoints[21]['test_features'].shape

(2785, 5)

In [15]:
class MyDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = torch.FloatTensor(data)
        self.targets = torch.FloatTensor(targets)
        
    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]

        return x, y
    
    def __len__(self):
        return len(self.data)


In [16]:

class MLP_Net(nn.Module):
    def __init__(self, input_size, num_classes, user_id):
        super(MLP_Net, self).__init__()
        self.fc1 = nn.Linear(input_size, num_classes, bias=True)
        self.user_id = user_id

    def forward(self, x):
        x = torch.flatten(x, 1)
        output = self.fc1(x)#F.softmax(self.fc1(x), dim=1)  # Applying softmax along the second dimension
        return output

In [17]:
from typing import Iterable, Optional

def grads_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Convert parameters to one vector

    Args:
        parameters (Iterable[Tensor]): an iterator of Tensors that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector
    """
    # Flag for the device where the parameter is located
    param_device = None

    vec = []
    for param in parameters:
        # Ensure the parameters are located in the same device
        param_device = param.grad

        vec.append(param_device.view(-1))
    return torch.cat(vec)

In [18]:
model = MLP_Net(datapoints[21]["input_size"], 1, user_id=0)

lr = 0.01

dataloader = DataLoader(MyDataset(datapoints[21]["features"], datapoints[21]["label"]), batch_size=50, shuffle=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for i in range(1000):
    for j, (x, y) in zip(range(1), dataloader):
        criterion = nn.MSELoss()
        optimizer.zero_grad()
        yhat = model(x)


        
        # Calculate accuracy
        #_, predicted = torch.max(yhat, 1)
        #_, predicted_true = torch.max(y, 1)
        #correct_predictions = (predicted == predicted_true).sum().item()
         

        loss = criterion(yhat.squeeze(), y)
        
        
        loss.backward()
        print(i, loss.detach())
        #optimizer.step()
        new_model = parameters_to_vector(model.parameters()) - lr * grads_to_vector(model.parameters())
        vector_to_parameters(parameters=model.parameters(), vec=new_model)
        #if i % 50 ==0:
            #lr *= 0.9
            

#parameters_to_vector(model.parameters())

0 tensor(96.4009)
1 tensor(120.1769)
2 tensor(111.0524)
3 tensor(91.6952)
4 tensor(94.1627)
5 tensor(94.1745)
6 tensor(80.5656)
7 tensor(69.3334)
8 tensor(70.9690)
9 tensor(71.9623)
10 tensor(56.2823)
11 tensor(52.8805)
12 tensor(55.0263)
13 tensor(45.1626)
14 tensor(45.2221)
15 tensor(45.4920)
16 tensor(44.4830)
17 tensor(41.0744)
18 tensor(47.4335)
19 tensor(36.9328)
20 tensor(38.0889)
21 tensor(34.5194)
22 tensor(27.8188)
23 tensor(34.5506)
24 tensor(35.8637)
25 tensor(41.8037)
26 tensor(25.5844)
27 tensor(30.0256)
28 tensor(32.2292)
29 tensor(37.1375)
30 tensor(31.9085)
31 tensor(31.6938)
32 tensor(37.3044)
33 tensor(30.9458)
34 tensor(17.9449)
35 tensor(29.8354)
36 tensor(27.1558)
37 tensor(33.5450)
38 tensor(32.5139)
39 tensor(30.4092)
40 tensor(22.7408)
41 tensor(25.3071)
42 tensor(26.2734)
43 tensor(35.3796)
44 tensor(24.6531)
45 tensor(23.9695)
46 tensor(34.5260)
47 tensor(32.8605)
48 tensor(32.5869)
49 tensor(18.6828)
50 tensor(22.8432)
51 tensor(24.4910)
52 tensor(18.6819)
5

421 tensor(9.5394)
422 tensor(9.0814)
423 tensor(8.5772)
424 tensor(11.6208)
425 tensor(9.0468)
426 tensor(7.1423)
427 tensor(14.6393)
428 tensor(11.9258)
429 tensor(9.6304)
430 tensor(7.2514)
431 tensor(12.4283)
432 tensor(6.9642)
433 tensor(15.0882)
434 tensor(6.6011)
435 tensor(9.1185)
436 tensor(11.8664)
437 tensor(12.8153)
438 tensor(10.7700)
439 tensor(7.2885)
440 tensor(13.2076)
441 tensor(7.9973)
442 tensor(12.6127)
443 tensor(16.0539)
444 tensor(9.9215)
445 tensor(6.0890)
446 tensor(6.2137)
447 tensor(14.3051)
448 tensor(8.5768)
449 tensor(4.5137)
450 tensor(13.5124)
451 tensor(9.8070)
452 tensor(9.1943)
453 tensor(8.4796)
454 tensor(10.6170)
455 tensor(8.8816)
456 tensor(14.6223)
457 tensor(12.3801)
458 tensor(7.7698)
459 tensor(9.2065)
460 tensor(9.3257)
461 tensor(8.6773)
462 tensor(8.8561)
463 tensor(7.3800)
464 tensor(8.1687)
465 tensor(6.6027)
466 tensor(12.0275)
467 tensor(7.5526)
468 tensor(8.9036)
469 tensor(11.6197)
470 tensor(11.1041)
471 tensor(16.3770)
472 tensor(

897 tensor(5.3784)
898 tensor(7.0077)
899 tensor(7.0254)
900 tensor(10.8680)
901 tensor(12.9191)
902 tensor(6.8479)
903 tensor(5.5857)
904 tensor(7.5197)
905 tensor(10.0674)
906 tensor(7.6830)
907 tensor(8.9906)
908 tensor(7.1564)
909 tensor(10.6470)
910 tensor(8.5121)
911 tensor(7.6536)
912 tensor(8.9909)
913 tensor(9.8533)
914 tensor(4.3170)
915 tensor(9.0209)
916 tensor(9.3485)
917 tensor(5.5439)
918 tensor(7.2766)
919 tensor(10.0440)
920 tensor(9.2125)
921 tensor(3.4506)
922 tensor(9.9528)
923 tensor(10.2199)
924 tensor(9.3591)
925 tensor(5.6228)
926 tensor(6.5693)
927 tensor(7.9162)
928 tensor(5.7938)
929 tensor(9.6396)
930 tensor(9.0934)
931 tensor(5.9705)
932 tensor(10.9917)
933 tensor(9.6978)
934 tensor(6.6489)
935 tensor(9.1834)
936 tensor(9.1887)
937 tensor(10.8827)
938 tensor(7.6818)
939 tensor(9.0877)
940 tensor(12.2672)
941 tensor(7.8016)
942 tensor(6.3757)
943 tensor(12.5605)
944 tensor(12.6414)
945 tensor(4.7607)
946 tensor(9.3541)
947 tensor(7.5439)
948 tensor(10.1912)


In [19]:
class ClientUpdate(object):
    def __init__(self, dataset, batchSize, alpha, lamda, epochs, projection_list, projected_weights):
        self.train_loader = DataLoader(MyDataset(dataset["features"], dataset["label"]), batch_size=batchSize, shuffle=True)
        #self.learning_rate = learning_rate
        self.epochs = epochs
        self.batchSize = batchSize

    def train(self, model):
        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=alpha)

        e_loss = []
        for epoch in range(1, self.epochs+1):
            train_loss = 0
            #model.train()
            for i, (data, labels) in zip(range(1), self.train_loader):
                data, labels = data, labels
                optimizer.zero_grad() 
                output = model(data)  
                loss = criterion(output.squeeze(), labels)
                #loss += mu/2 * torch.norm(client_param.data - server_param.data)**2
                loss.backward()
                grads = grads_to_vector(model.parameters())
                #optimizer.step()
                train_loss += loss.item()*data.size(0)
                weights = parameters_to_vector(model.parameters())
                mat_vec_sum = torch.zeros_like(weights)
                for j in G.neighbors(model.user_id):
                    mat_vec_sum = torch.add(mat_vec_sum, torch.matmul(torch.transpose(projection_list[model.user_id][j], 0, 1), 
                                                         projected_weights[model.user_id][j] - projected_weights[j][model.user_id]))
                
                model_update = parameters_to_vector(model.parameters()) - alpha * (grads + lamda * mat_vec_sum)
                
            vector_to_parameters(parameters=model.parameters(), vec=model_update)
                

            train_loss = train_loss/self.batchSize#len(self.train_loader.dataset) 
            e_loss.append(train_loss)

        total_loss = e_loss#sum(e_loss)/len(e_loss)

        return model.state_dict(), total_loss

In [20]:
# Preparing projection matrices
models = [MLP_Net(input_size=datapoints[i]['input_size'], num_classes=1, user_id=i) for i in range(no_users)]
#temp = MLP_Net()
projection_list = []
projected_weights = []

def update_ProjWeight(projection_list, projected_weights, models, first_run=True):
    for i in range(no_users):
        neighbors_mat = []
        neighbors_weights = []
        for j in range(no_users):
            if j in G.neighbors(i):
                with torch.no_grad():
                    if first_run == True:
                        # Specify the dimensions of the rectangular matrix
                        row, column = d0, parameters_to_vector(models[i].parameters()).size()[0]

                        # Generate random values for the diagonal from a normal distribution
                        diag_values = torch.ones(min(row, column))

                        # Create a rectangular matrix with diagonal elements
                        mat = torch.diag(diag_values)

                        # If the matrix is larger than the diagonal vector, fill the remaining elements with zeros
                       
                        mat = torch.cat((mat, torch.zeros(row, column - row)), dim=1)

                        

                        # Append the matrix to the list
                        neighbors_mat.append(mat)
                        neighbors_weights.append(torch.matmul(mat, parameters_to_vector(models[i].parameters())))
                    else:
                        neighbors_weights.append(torch.matmul(projection_list[i][j], parameters_to_vector(models[i].parameters())))
            else:
                neighbors_mat.append(0)
                neighbors_weights.append(0)
        if first_run == True:
            projection_list.append(neighbors_mat)
        projected_weights.append(neighbors_weights)

update_ProjWeight(projection_list, projected_weights, models)



In [21]:
def testing(model, dataset, bs, criterion): 
    test_loss = 0
    correct = 0
    test_loader = DataLoader(MyDataset(dataset["test_features"], dataset["test_labels"]), batch_size=bs)
    l = len(test_loader)
    model.eval()
    for data, labels in test_loader:
        data, labels = data, labels
        output = model(data)
        loss = criterion(output.squeeze(), labels)
        test_loss += loss.item()*data.size(0)
        #_, pred = torch.max(output, 1)
        #_, predicted_true = torch.max(labels, 1)
        #correct += pred.eq(predicted_true.data.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    #correct /= len(test_loader.dataset)
    
    return test_loss#, correct

In [22]:
#global_model = CNN_Net().cuda()
models = [MLP_Net(input_size=datapoints[i]['input_size'], num_classes=1, user_id=i) for i in range(no_users)]
dummy_models = [MLP_Net(input_size=datapoints[i]['input_size'], num_classes=1, user_id=i) for i in range(no_users)]

#model.load_state_dict(global_model.state_dict())

criterion = nn.MSELoss()


train_loss = []
test_loss = []
test_accuracy = []
total_rel_error = []

in_cluster_proj_norm = []
out_cluster_proj_norm = []
in_cluster_proj_diff_norm = []
out_cluster_proj_diff_norm = []
task_loss = {'0':[],
                '1':[],
                '2':[]}
task_rel_error = {'0':[],
            '1':[],
            '2':[]}

for curr_round in tqdm(range(1, it+1)):
    w, local_loss = [], []

    
    for i in range(no_users):
        dummy_models[i].load_state_dict(models[i].state_dict())
        local_update = ClientUpdate(dataset=datapoints[i], batchSize=batch_size, alpha=alpha, lamda=lamda, epochs=1, projection_list=projection_list, projected_weights=projected_weights)
        weights, loss = local_update.train(dummy_models[i])
        w.append(weights)
        local_loss.append(loss)
        models[i].load_state_dict(w[i])
        
    
    
    # Update prjection matrix
    projected_weights = []
    update_ProjWeight(projection_list, projected_weights, models, first_run=False)
    
    #print(projection_list[0], projected_weights[0])
    
    for i in range(no_users):
        weights = parameters_to_vector(models[i].parameters())
        for j in G.neighbors(i):
            temp_mat = torch.outer(projected_weights[i][j] - projected_weights[j][i], weights)
            projection_list[i][j] = torch.add(projection_list[i][j], -1 * eta * lamda * temp_mat)
    
    in_cluster_proj_norm_round = 0
    out_cluster_proj_norm_round = 0
    in_cluster_proj_diff_round = 0
    out_cluster_proj_diff_round = 0
    in_edges = 0
    out_edges = 0
    
    for i in range(no_users//3):
        for j in G.neighbors(i):
            if j < no_users//3:
                in_edges += 1
                in_cluster_proj_norm_round += torch.norm(projection_list[i][j]).detach().numpy()
                in_cluster_proj_diff_round += torch.norm(projected_weights[i][j] - projected_weights[j][i]).detach().numpy()
            else:
                out_edges += 1
                out_cluster_proj_norm_round += torch.norm(projection_list[i][j]).detach().numpy()
                out_cluster_proj_diff_round += torch.norm(projected_weights[i][j] - projected_weights[j][i]).detach().numpy()
    in_cluster_proj_norm.append(in_cluster_proj_norm_round / in_edges)
    out_cluster_proj_norm.append(out_cluster_proj_norm_round / out_edges)
    in_cluster_proj_diff_norm.append(in_cluster_proj_diff_round / in_edges)
    out_cluster_proj_diff_norm.append(out_cluster_proj_diff_round / out_edges)
            
                                         
                                              
    
        
        
        
    
    




          
            

    local_test_acc = []
    local_test_loss = []
    user_rel_error = 0
    per_task_loss = []
    per_task_rel_error = []
    
    for k in range(no_users):
      
        g_loss = testing(models[k], datapoints[k], 50, criterion)
        local_test_loss.append(g_loss)
        #user_rel_error += acc#rel_error(models[i])
        if (k + 1) % 10 == 0:
            task_loss[str(k // 10)].append(sum(per_task_loss) / 10)
            #task_rel_error[str(k // 10)].append(sum(per_task_rel_error) / 10)
            per_task_loss = []
            #per_task_rel_error = []
        per_task_loss.append(g_loss)
        #per_task_rel_error.append(acc)#rel_error(models[i]))
    
    
        

    g_loss = sum(local_test_loss) / len(local_test_loss)
    #total_rel_error.append(user_rel_error / no_users)
    
    

    test_loss.append(g_loss)
    #test_accuracy.append(g_accuracy)
    print("Training_loss %2.5f"% (test_loss[-1]))

  0%|          | 1/500 [00:01<08:55,  1.07s/it]

Training_loss 101.21219


  0%|          | 2/500 [00:02<08:55,  1.08s/it]

Training_loss 94.19116


  1%|          | 3/500 [00:03<08:22,  1.01s/it]

Training_loss 87.84749


  1%|          | 4/500 [00:04<08:12,  1.01it/s]

Training_loss 81.96899


  1%|          | 5/500 [00:05<09:29,  1.15s/it]

Training_loss 76.59020


  1%|          | 6/500 [00:09<17:27,  2.12s/it]

Training_loss 71.66375


  1%|▏         | 7/500 [00:12<21:02,  2.56s/it]

Training_loss 67.10812


  2%|▏         | 8/500 [00:14<18:36,  2.27s/it]

Training_loss 62.94853


  2%|▏         | 9/500 [00:15<15:51,  1.94s/it]

Training_loss 59.15367


  2%|▏         | 10/500 [00:16<13:40,  1.67s/it]

Training_loss 55.61642


  2%|▏         | 11/500 [00:17<12:03,  1.48s/it]

Training_loss 52.45607


  2%|▏         | 12/500 [00:18<10:57,  1.35s/it]

Training_loss 49.47430


  3%|▎         | 13/500 [00:19<10:08,  1.25s/it]

Training_loss 46.76617


  3%|▎         | 14/500 [00:20<09:23,  1.16s/it]

Training_loss 44.28576


  3%|▎         | 15/500 [00:22<09:16,  1.15s/it]

Training_loss 42.04618


  3%|▎         | 16/500 [00:23<09:32,  1.18s/it]

Training_loss 39.92235


  3%|▎         | 17/500 [00:24<09:20,  1.16s/it]

Training_loss 38.00380


  4%|▎         | 18/500 [00:25<09:04,  1.13s/it]

Training_loss 36.24697


  4%|▍         | 19/500 [00:26<08:32,  1.07s/it]

Training_loss 34.59729


  4%|▍         | 20/500 [00:27<08:36,  1.08s/it]

Training_loss 33.08346


  4%|▍         | 21/500 [00:28<08:11,  1.03s/it]

Training_loss 31.70297


  4%|▍         | 22/500 [00:29<08:28,  1.06s/it]

Training_loss 30.41503


  5%|▍         | 23/500 [00:30<08:12,  1.03s/it]

Training_loss 29.18917


  5%|▍         | 24/500 [00:31<08:49,  1.11s/it]

Training_loss 28.05814


  5%|▌         | 25/500 [00:32<08:22,  1.06s/it]

Training_loss 27.05124


  5%|▌         | 26/500 [00:33<08:40,  1.10s/it]

Training_loss 26.11480


  5%|▌         | 27/500 [00:34<08:04,  1.03s/it]

Training_loss 25.19500


  6%|▌         | 28/500 [00:35<07:44,  1.02it/s]

Training_loss 24.37255


  6%|▌         | 29/500 [00:37<08:31,  1.09s/it]

Training_loss 23.61613


  6%|▌         | 30/500 [00:39<11:14,  1.43s/it]

Training_loss 22.91516


  6%|▌         | 31/500 [00:40<10:38,  1.36s/it]

Training_loss 22.28425


  6%|▋         | 32/500 [00:41<10:30,  1.35s/it]

Training_loss 21.65582


  7%|▋         | 33/500 [00:43<10:19,  1.33s/it]

Training_loss 21.07655


  7%|▋         | 34/500 [00:44<09:27,  1.22s/it]

Training_loss 20.55051


  7%|▋         | 35/500 [00:45<09:24,  1.21s/it]

Training_loss 20.05919


  7%|▋         | 36/500 [00:46<08:50,  1.14s/it]

Training_loss 19.59517


  7%|▋         | 37/500 [00:47<08:52,  1.15s/it]

Training_loss 19.18333


  8%|▊         | 38/500 [00:48<08:12,  1.07s/it]

Training_loss 18.79825


  8%|▊         | 39/500 [00:49<07:46,  1.01s/it]

Training_loss 18.42195


  8%|▊         | 40/500 [00:50<07:50,  1.02s/it]

Training_loss 18.06407


  8%|▊         | 41/500 [00:51<07:31,  1.02it/s]

Training_loss 17.74334


  8%|▊         | 42/500 [00:51<07:21,  1.04it/s]

Training_loss 17.45419


  9%|▊         | 43/500 [00:52<07:25,  1.02it/s]

Training_loss 17.17148


  9%|▉         | 44/500 [00:54<07:39,  1.01s/it]

Training_loss 16.90722


  9%|▉         | 45/500 [00:55<08:00,  1.06s/it]

Training_loss 16.66566


  9%|▉         | 46/500 [00:56<08:22,  1.11s/it]

Training_loss 16.41835


  9%|▉         | 47/500 [00:57<07:58,  1.06s/it]

Training_loss 16.18495


 10%|▉         | 48/500 [00:58<07:47,  1.03s/it]

Training_loss 15.97277


 10%|▉         | 49/500 [00:59<07:42,  1.03s/it]

Training_loss 15.78049


 10%|█         | 50/500 [01:00<08:10,  1.09s/it]

Training_loss 15.58739


 10%|█         | 51/500 [01:01<08:41,  1.16s/it]

Training_loss 15.41877


 10%|█         | 52/500 [01:03<08:25,  1.13s/it]

Training_loss 15.24066


 11%|█         | 53/500 [01:03<08:04,  1.08s/it]

Training_loss 15.07899


 11%|█         | 54/500 [01:04<07:48,  1.05s/it]

Training_loss 14.92840


 11%|█         | 55/500 [01:06<07:51,  1.06s/it]

Training_loss 14.78655


 11%|█         | 56/500 [01:07<09:42,  1.31s/it]

Training_loss 14.64758


 11%|█▏        | 57/500 [01:11<14:12,  1.92s/it]

Training_loss 14.51572


 12%|█▏        | 58/500 [01:13<14:38,  1.99s/it]

Training_loss 14.38473


 12%|█▏        | 59/500 [01:14<13:34,  1.85s/it]

Training_loss 14.27269


 12%|█▏        | 60/500 [01:15<11:37,  1.59s/it]

Training_loss 14.16447


 12%|█▏        | 61/500 [01:16<09:57,  1.36s/it]

Training_loss 14.05736


 12%|█▏        | 62/500 [01:17<09:23,  1.29s/it]

Training_loss 13.95388


 13%|█▎        | 63/500 [01:18<08:55,  1.23s/it]

Training_loss 13.85379


 13%|█▎        | 64/500 [01:20<09:21,  1.29s/it]

Training_loss 13.75793


 13%|█▎        | 65/500 [01:21<09:05,  1.25s/it]

Training_loss 13.66584


 13%|█▎        | 66/500 [01:22<08:31,  1.18s/it]

Training_loss 13.57868


 13%|█▎        | 67/500 [01:23<08:07,  1.13s/it]

Training_loss 13.49039


 14%|█▎        | 68/500 [01:25<08:46,  1.22s/it]

Training_loss 13.41363


 14%|█▍        | 69/500 [01:26<09:49,  1.37s/it]

Training_loss 13.33029


 14%|█▍        | 70/500 [01:27<08:57,  1.25s/it]

Training_loss 13.24794


 14%|█▍        | 71/500 [01:28<08:19,  1.16s/it]

Training_loss 13.17231


 14%|█▍        | 72/500 [01:29<08:16,  1.16s/it]

Training_loss 13.10108


 15%|█▍        | 73/500 [01:31<09:00,  1.27s/it]

Training_loss 13.03200


 15%|█▍        | 74/500 [01:32<09:15,  1.30s/it]

Training_loss 12.95857


 15%|█▌        | 75/500 [01:33<09:06,  1.29s/it]

Training_loss 12.89709


 15%|█▌        | 76/500 [01:35<09:37,  1.36s/it]

Training_loss 12.83776


 15%|█▌        | 77/500 [01:37<11:51,  1.68s/it]

Training_loss 12.78559


 16%|█▌        | 78/500 [01:42<17:27,  2.48s/it]

Training_loss 12.71825


 16%|█▌        | 79/500 [01:46<20:32,  2.93s/it]

Training_loss 12.65637


 16%|█▌        | 80/500 [01:47<17:26,  2.49s/it]

Training_loss 12.59564


 16%|█▌        | 81/500 [01:48<14:38,  2.10s/it]

Training_loss 12.53390


 16%|█▋        | 82/500 [01:49<12:14,  1.76s/it]

Training_loss 12.47642


 17%|█▋        | 83/500 [01:50<10:26,  1.50s/it]

Training_loss 12.42706


 17%|█▋        | 84/500 [01:51<09:28,  1.37s/it]

Training_loss 12.37756


 17%|█▋        | 85/500 [01:52<08:44,  1.26s/it]

Training_loss 12.33240


 17%|█▋        | 86/500 [01:53<07:55,  1.15s/it]

Training_loss 12.28380


 17%|█▋        | 87/500 [01:54<07:31,  1.09s/it]

Training_loss 12.23442


 18%|█▊        | 88/500 [01:55<07:20,  1.07s/it]

Training_loss 12.19029


 18%|█▊        | 89/500 [01:56<07:00,  1.02s/it]

Training_loss 12.14206


 18%|█▊        | 90/500 [01:57<07:25,  1.09s/it]

Training_loss 12.09138


 18%|█▊        | 91/500 [01:59<07:42,  1.13s/it]

Training_loss 12.04652


 18%|█▊        | 92/500 [02:00<07:58,  1.17s/it]

Training_loss 12.00421


 19%|█▊        | 93/500 [02:01<07:31,  1.11s/it]

Training_loss 11.95998


 19%|█▉        | 94/500 [02:02<07:31,  1.11s/it]

Training_loss 11.91583


 19%|█▉        | 95/500 [02:03<07:35,  1.12s/it]

Training_loss 11.87600


 19%|█▉        | 96/500 [02:04<07:42,  1.15s/it]

Training_loss 11.83909


 19%|█▉        | 97/500 [02:05<07:28,  1.11s/it]

Training_loss 11.79783


 20%|█▉        | 98/500 [02:06<06:57,  1.04s/it]

Training_loss 11.75833


 20%|█▉        | 99/500 [02:08<08:40,  1.30s/it]

Training_loss 11.72110


 20%|██        | 100/500 [02:10<08:56,  1.34s/it]

Training_loss 11.67639


 20%|██        | 101/500 [02:10<08:10,  1.23s/it]

Training_loss 11.64025


 20%|██        | 102/500 [02:12<07:51,  1.18s/it]

Training_loss 11.60250


 21%|██        | 103/500 [02:13<08:02,  1.22s/it]

Training_loss 11.56413


 21%|██        | 104/500 [02:14<08:37,  1.31s/it]

Training_loss 11.52740


 21%|██        | 105/500 [02:15<08:03,  1.22s/it]

Training_loss 11.48849


 21%|██        | 106/500 [02:17<07:52,  1.20s/it]

Training_loss 11.44894


 21%|██▏       | 107/500 [02:18<07:42,  1.18s/it]

Training_loss 11.41013


 22%|██▏       | 108/500 [02:19<07:30,  1.15s/it]

Training_loss 11.37539


 22%|██▏       | 109/500 [02:20<06:51,  1.05s/it]

Training_loss 11.33882


 22%|██▏       | 110/500 [02:21<06:44,  1.04s/it]

Training_loss 11.30652


 22%|██▏       | 111/500 [02:22<06:34,  1.01s/it]

Training_loss 11.26890


 22%|██▏       | 112/500 [02:23<06:40,  1.03s/it]

Training_loss 11.23312


 23%|██▎       | 113/500 [02:24<06:52,  1.07s/it]

Training_loss 11.19912


 23%|██▎       | 114/500 [02:25<07:00,  1.09s/it]

Training_loss 11.16331


 23%|██▎       | 115/500 [02:26<07:08,  1.11s/it]

Training_loss 11.13071


 23%|██▎       | 116/500 [02:27<07:23,  1.15s/it]

Training_loss 11.09810


 23%|██▎       | 117/500 [02:28<06:53,  1.08s/it]

Training_loss 11.06712


 24%|██▎       | 118/500 [02:29<06:34,  1.03s/it]

Training_loss 11.03785


 24%|██▍       | 119/500 [02:30<06:36,  1.04s/it]

Training_loss 11.00640


 24%|██▍       | 120/500 [02:31<06:15,  1.01it/s]

Training_loss 10.97601


 24%|██▍       | 121/500 [02:32<06:56,  1.10s/it]

Training_loss 10.94675


 24%|██▍       | 122/500 [02:34<07:33,  1.20s/it]

Training_loss 10.91475


 25%|██▍       | 123/500 [02:36<08:23,  1.34s/it]

Training_loss 10.88333


 25%|██▍       | 124/500 [02:38<09:57,  1.59s/it]

Training_loss 10.85161


 25%|██▌       | 125/500 [02:40<11:10,  1.79s/it]

Training_loss 10.82399


 25%|██▌       | 126/500 [02:43<13:53,  2.23s/it]

Training_loss 10.79401


 25%|██▌       | 127/500 [02:45<12:40,  2.04s/it]

Training_loss 10.76286


 26%|██▌       | 128/500 [02:46<11:04,  1.79s/it]

Training_loss 10.73642


 26%|██▌       | 129/500 [02:47<09:27,  1.53s/it]

Training_loss 10.70837


 26%|██▌       | 130/500 [02:48<08:16,  1.34s/it]

Training_loss 10.68037


 26%|██▌       | 131/500 [02:49<07:36,  1.24s/it]

Training_loss 10.65478


 26%|██▋       | 132/500 [02:50<07:11,  1.17s/it]

Training_loss 10.63193


 27%|██▋       | 133/500 [02:51<07:18,  1.20s/it]

Training_loss 10.60499


 27%|██▋       | 134/500 [02:52<07:35,  1.24s/it]

Training_loss 10.57847


 27%|██▋       | 135/500 [02:53<07:05,  1.16s/it]

Training_loss 10.54718


 27%|██▋       | 136/500 [02:54<06:44,  1.11s/it]

Training_loss 10.52062


 27%|██▋       | 137/500 [02:56<07:27,  1.23s/it]

Training_loss 10.49575


 28%|██▊       | 138/500 [02:57<07:36,  1.26s/it]

Training_loss 10.46936


 28%|██▊       | 139/500 [02:59<07:40,  1.27s/it]

Training_loss 10.44155


 28%|██▊       | 140/500 [03:00<07:35,  1.27s/it]

Training_loss 10.41648


 28%|██▊       | 141/500 [03:01<06:52,  1.15s/it]

Training_loss 10.39296


 28%|██▊       | 142/500 [03:02<06:46,  1.13s/it]

Training_loss 10.36918


 29%|██▊       | 143/500 [03:03<06:21,  1.07s/it]

Training_loss 10.34287


 29%|██▉       | 144/500 [03:04<06:11,  1.04s/it]

Training_loss 10.31767


 29%|██▉       | 145/500 [03:05<05:51,  1.01it/s]

Training_loss 10.29362


 29%|██▉       | 146/500 [03:06<06:23,  1.08s/it]

Training_loss 10.26608


 29%|██▉       | 147/500 [03:07<06:24,  1.09s/it]

Training_loss 10.24303


 30%|██▉       | 148/500 [03:09<07:13,  1.23s/it]

Training_loss 10.22076


 30%|██▉       | 149/500 [03:10<07:19,  1.25s/it]

Training_loss 10.19408


 30%|███       | 150/500 [03:11<06:51,  1.18s/it]

Training_loss 10.16698


 30%|███       | 151/500 [03:12<07:08,  1.23s/it]

Training_loss 10.14199


 30%|███       | 152/500 [03:13<06:57,  1.20s/it]

Training_loss 10.12341


 31%|███       | 153/500 [03:14<06:49,  1.18s/it]

Training_loss 10.10319


 31%|███       | 154/500 [03:16<07:23,  1.28s/it]

Training_loss 10.07858


 31%|███       | 155/500 [03:18<09:03,  1.57s/it]

Training_loss 10.05560


 31%|███       | 156/500 [03:19<08:07,  1.42s/it]

Training_loss 10.03419


 31%|███▏      | 157/500 [03:20<07:12,  1.26s/it]

Training_loss 10.01061


 32%|███▏      | 158/500 [03:21<06:36,  1.16s/it]

Training_loss 9.98923


 32%|███▏      | 159/500 [03:22<06:04,  1.07s/it]

Training_loss 9.96775


 32%|███▏      | 160/500 [03:23<05:43,  1.01s/it]

Training_loss 9.94607


 32%|███▏      | 161/500 [03:24<05:51,  1.04s/it]

Training_loss 9.92598


 32%|███▏      | 162/500 [03:25<05:47,  1.03s/it]

Training_loss 9.90585


 33%|███▎      | 163/500 [03:26<05:27,  1.03it/s]

Training_loss 9.88674


 33%|███▎      | 164/500 [03:27<05:24,  1.04it/s]

Training_loss 9.86462


 33%|███▎      | 165/500 [03:28<05:39,  1.01s/it]

Training_loss 9.84018


 33%|███▎      | 166/500 [03:29<06:24,  1.15s/it]

Training_loss 9.81979


 33%|███▎      | 167/500 [03:30<06:23,  1.15s/it]

Training_loss 9.80094


 34%|███▎      | 168/500 [03:31<06:02,  1.09s/it]

Training_loss 9.78108


 34%|███▍      | 169/500 [03:32<05:41,  1.03s/it]

Training_loss 9.76012


 34%|███▍      | 170/500 [03:34<06:10,  1.12s/it]

Training_loss 9.73957


 34%|███▍      | 171/500 [03:35<06:25,  1.17s/it]

Training_loss 9.71841


 34%|███▍      | 172/500 [03:36<06:07,  1.12s/it]

Training_loss 9.69761


 35%|███▍      | 173/500 [03:37<05:59,  1.10s/it]

Training_loss 9.67949


 35%|███▍      | 174/500 [03:38<05:52,  1.08s/it]

Training_loss 9.65836


 35%|███▌      | 175/500 [03:39<05:43,  1.06s/it]

Training_loss 9.63696


 35%|███▌      | 176/500 [03:40<05:29,  1.02s/it]

Training_loss 9.61682


 35%|███▌      | 177/500 [03:41<05:09,  1.04it/s]

Training_loss 9.59528


 36%|███▌      | 178/500 [03:42<05:00,  1.07it/s]

Training_loss 9.57706


 36%|███▌      | 179/500 [03:43<04:57,  1.08it/s]

Training_loss 9.55839


 36%|███▌      | 180/500 [03:43<04:52,  1.10it/s]

Training_loss 9.53960


 36%|███▌      | 181/500 [03:44<04:44,  1.12it/s]

Training_loss 9.52043


 36%|███▋      | 182/500 [03:45<04:50,  1.09it/s]

Training_loss 9.50050


 37%|███▋      | 183/500 [03:46<04:57,  1.07it/s]

Training_loss 9.48491


 37%|███▋      | 184/500 [03:47<05:00,  1.05it/s]

Training_loss 9.46882


 37%|███▋      | 185/500 [03:48<04:51,  1.08it/s]

Training_loss 9.45022


 37%|███▋      | 186/500 [03:49<05:08,  1.02it/s]

Training_loss 9.43247


 37%|███▋      | 187/500 [03:50<05:11,  1.01it/s]

Training_loss 9.41510


 38%|███▊      | 188/500 [03:51<04:57,  1.05it/s]

Training_loss 9.39934


 38%|███▊      | 189/500 [03:52<04:48,  1.08it/s]

Training_loss 9.38434


 38%|███▊      | 190/500 [03:53<04:39,  1.11it/s]

Training_loss 9.36957


 38%|███▊      | 191/500 [03:54<04:45,  1.08it/s]

Training_loss 9.35643


 38%|███▊      | 192/500 [03:55<04:36,  1.11it/s]

Training_loss 9.33781


 39%|███▊      | 193/500 [03:55<04:31,  1.13it/s]

Training_loss 9.32028


 39%|███▉      | 194/500 [03:56<04:24,  1.16it/s]

Training_loss 9.30232


 39%|███▉      | 195/500 [03:57<04:24,  1.15it/s]

Training_loss 9.28448


 39%|███▉      | 196/500 [03:58<04:24,  1.15it/s]

Training_loss 9.26510


 39%|███▉      | 197/500 [03:59<04:30,  1.12it/s]

Training_loss 9.25051


 40%|███▉      | 198/500 [04:00<04:27,  1.13it/s]

Training_loss 9.23282


 40%|███▉      | 199/500 [04:01<05:03,  1.01s/it]

Training_loss 9.21749


 40%|████      | 200/500 [04:02<05:21,  1.07s/it]

Training_loss 9.20154


 40%|████      | 201/500 [04:03<05:20,  1.07s/it]

Training_loss 9.18949


 40%|████      | 202/500 [04:04<05:06,  1.03s/it]

Training_loss 9.17518


 41%|████      | 203/500 [04:05<04:53,  1.01it/s]

Training_loss 9.16061


 41%|████      | 204/500 [04:07<05:27,  1.11s/it]

Training_loss 9.14549


 41%|████      | 205/500 [04:08<05:27,  1.11s/it]

Training_loss 9.13291


 41%|████      | 206/500 [04:09<05:17,  1.08s/it]

Training_loss 9.11957


 41%|████▏     | 207/500 [04:10<05:08,  1.05s/it]

Training_loss 9.10208


 42%|████▏     | 208/500 [04:11<04:54,  1.01s/it]

Training_loss 9.09193


 42%|████▏     | 209/500 [04:12<04:42,  1.03it/s]

Training_loss 9.07443


 42%|████▏     | 210/500 [04:13<04:55,  1.02s/it]

Training_loss 9.06158


 42%|████▏     | 211/500 [04:14<05:15,  1.09s/it]

Training_loss 9.04690


 42%|████▏     | 212/500 [04:15<05:23,  1.12s/it]

Training_loss 9.02882


 43%|████▎     | 213/500 [04:16<04:59,  1.04s/it]

Training_loss 9.01418


 43%|████▎     | 214/500 [04:17<04:47,  1.00s/it]

Training_loss 8.99986


 43%|████▎     | 215/500 [04:18<04:38,  1.02it/s]

Training_loss 8.98493


 43%|████▎     | 216/500 [04:19<04:44,  1.00s/it]

Training_loss 8.97108


 43%|████▎     | 217/500 [04:20<05:00,  1.06s/it]

Training_loss 8.95943


 44%|████▎     | 218/500 [04:21<04:39,  1.01it/s]

Training_loss 8.94439


 44%|████▍     | 219/500 [04:22<04:47,  1.02s/it]

Training_loss 8.93230


 44%|████▍     | 220/500 [04:23<04:48,  1.03s/it]

Training_loss 8.92167


 44%|████▍     | 221/500 [04:24<04:46,  1.03s/it]

Training_loss 8.90788


 44%|████▍     | 222/500 [04:25<04:29,  1.03it/s]

Training_loss 8.89615


 45%|████▍     | 223/500 [04:26<04:36,  1.00it/s]

Training_loss 8.88307


 45%|████▍     | 224/500 [04:27<04:28,  1.03it/s]

Training_loss 8.87284


 45%|████▌     | 225/500 [04:28<04:26,  1.03it/s]

Training_loss 8.86207


 45%|████▌     | 226/500 [04:29<04:26,  1.03it/s]

Training_loss 8.84848


 45%|████▌     | 227/500 [04:30<04:22,  1.04it/s]

Training_loss 8.83378


 46%|████▌     | 228/500 [04:31<04:45,  1.05s/it]

Training_loss 8.82046


 46%|████▌     | 229/500 [04:32<05:18,  1.18s/it]

Training_loss 8.80670


 46%|████▌     | 230/500 [04:34<05:23,  1.20s/it]

Training_loss 8.79349


 46%|████▌     | 231/500 [04:35<05:45,  1.28s/it]

Training_loss 8.78393


 46%|████▋     | 232/500 [04:37<05:47,  1.30s/it]

Training_loss 8.77429


 47%|████▋     | 233/500 [04:38<05:24,  1.22s/it]

Training_loss 8.76015


 47%|████▋     | 234/500 [04:39<05:15,  1.19s/it]

Training_loss 8.74645


 47%|████▋     | 235/500 [04:40<04:51,  1.10s/it]

Training_loss 8.73456


 47%|████▋     | 236/500 [04:41<04:44,  1.08s/it]

Training_loss 8.72047


 47%|████▋     | 237/500 [04:42<04:34,  1.05s/it]

Training_loss 8.70646


 48%|████▊     | 238/500 [04:43<04:30,  1.03s/it]

Training_loss 8.69786


 48%|████▊     | 239/500 [04:44<04:20,  1.00it/s]

Training_loss 8.68413


 48%|████▊     | 240/500 [04:45<04:19,  1.00it/s]

Training_loss 8.67276


 48%|████▊     | 241/500 [04:46<04:19,  1.00s/it]

Training_loss 8.66483


 48%|████▊     | 242/500 [04:47<04:26,  1.03s/it]

Training_loss 8.65343


 49%|████▊     | 243/500 [04:48<04:17,  1.00s/it]

Training_loss 8.64119


 49%|████▉     | 244/500 [04:48<04:08,  1.03it/s]

Training_loss 8.63148


 49%|████▉     | 245/500 [04:49<04:05,  1.04it/s]

Training_loss 8.62091


 49%|████▉     | 246/500 [04:50<04:00,  1.05it/s]

Training_loss 8.61422


 49%|████▉     | 247/500 [04:51<03:59,  1.06it/s]

Training_loss 8.60534


 50%|████▉     | 248/500 [04:52<03:57,  1.06it/s]

Training_loss 8.59336


 50%|████▉     | 249/500 [04:53<03:54,  1.07it/s]

Training_loss 8.58386


 50%|█████     | 250/500 [04:54<03:55,  1.06it/s]

Training_loss 8.57283


 50%|█████     | 251/500 [04:55<04:05,  1.01it/s]

Training_loss 8.56484


 50%|█████     | 252/500 [04:57<04:51,  1.17s/it]

Training_loss 8.55419


 51%|█████     | 253/500 [04:59<05:34,  1.35s/it]

Training_loss 8.54577


 51%|█████     | 254/500 [05:00<05:15,  1.28s/it]

Training_loss 8.53597


 51%|█████     | 255/500 [05:01<05:05,  1.25s/it]

Training_loss 8.52829


 51%|█████     | 256/500 [05:02<04:40,  1.15s/it]

Training_loss 8.51835


 51%|█████▏    | 257/500 [05:03<04:32,  1.12s/it]

Training_loss 8.50993


 52%|█████▏    | 258/500 [05:04<04:15,  1.06s/it]

Training_loss 8.49971


 52%|█████▏    | 259/500 [05:05<04:00,  1.00it/s]

Training_loss 8.48940


 52%|█████▏    | 260/500 [05:06<04:15,  1.06s/it]

Training_loss 8.47935


 52%|█████▏    | 261/500 [05:07<04:11,  1.05s/it]

Training_loss 8.46814


 52%|█████▏    | 262/500 [05:08<04:25,  1.12s/it]

Training_loss 8.46178


 53%|█████▎    | 263/500 [05:10<05:02,  1.27s/it]

Training_loss 8.44933


 53%|█████▎    | 264/500 [05:11<04:39,  1.18s/it]

Training_loss 8.44136


 53%|█████▎    | 265/500 [05:12<04:23,  1.12s/it]

Training_loss 8.43182


 53%|█████▎    | 266/500 [05:14<05:30,  1.41s/it]

Training_loss 8.42402


 53%|█████▎    | 267/500 [05:16<06:14,  1.61s/it]

Training_loss 8.41390


 54%|█████▎    | 268/500 [05:17<05:46,  1.49s/it]

Training_loss 8.40464


 54%|█████▍    | 269/500 [05:18<05:21,  1.39s/it]

Training_loss 8.39415


 54%|█████▍    | 270/500 [05:19<04:45,  1.24s/it]

Training_loss 8.38816


 54%|█████▍    | 271/500 [05:20<04:27,  1.17s/it]

Training_loss 8.38100


 54%|█████▍    | 272/500 [05:21<04:07,  1.09s/it]

Training_loss 8.37115


 55%|█████▍    | 273/500 [05:22<04:21,  1.15s/it]

Training_loss 8.36150


 55%|█████▍    | 274/500 [05:24<04:28,  1.19s/it]

Training_loss 8.35251


 55%|█████▌    | 275/500 [05:25<04:52,  1.30s/it]

Training_loss 8.34350


 55%|█████▌    | 276/500 [05:27<05:18,  1.42s/it]

Training_loss 8.33503


 55%|█████▌    | 277/500 [05:28<05:20,  1.44s/it]

Training_loss 8.32642


 56%|█████▌    | 278/500 [05:31<06:12,  1.68s/it]

Training_loss 8.32290


 56%|█████▌    | 279/500 [05:32<05:29,  1.49s/it]

Training_loss 8.31485


 56%|█████▌    | 280/500 [05:34<06:56,  1.89s/it]

Training_loss 8.30516


 56%|█████▌    | 281/500 [05:37<08:07,  2.22s/it]

Training_loss 8.29408


 56%|█████▋    | 282/500 [05:40<08:01,  2.21s/it]

Training_loss 8.28596


 57%|█████▋    | 283/500 [05:41<07:35,  2.10s/it]

Training_loss 8.27757


 57%|█████▋    | 284/500 [05:42<06:22,  1.77s/it]

Training_loss 8.27028


 57%|█████▋    | 285/500 [05:43<05:23,  1.51s/it]

Training_loss 8.26356


 57%|█████▋    | 286/500 [05:44<04:44,  1.33s/it]

Training_loss 8.25499


 57%|█████▋    | 287/500 [05:45<04:27,  1.26s/it]

Training_loss 8.24762


 58%|█████▊    | 288/500 [05:46<04:17,  1.21s/it]

Training_loss 8.23777


 58%|█████▊    | 289/500 [05:48<04:07,  1.17s/it]

Training_loss 8.22846


 58%|█████▊    | 290/500 [05:49<04:05,  1.17s/it]

Training_loss 8.22228


 58%|█████▊    | 291/500 [05:50<03:59,  1.14s/it]

Training_loss 8.21496


 58%|█████▊    | 292/500 [05:51<04:00,  1.16s/it]

Training_loss 8.20796


 59%|█████▊    | 293/500 [05:52<03:47,  1.10s/it]

Training_loss 8.20101


 59%|█████▉    | 294/500 [05:53<03:35,  1.05s/it]

Training_loss 8.19570


 59%|█████▉    | 295/500 [05:54<03:30,  1.03s/it]

Training_loss 8.18816


 59%|█████▉    | 296/500 [05:55<03:34,  1.05s/it]

Training_loss 8.18176


 59%|█████▉    | 297/500 [05:56<03:30,  1.03s/it]

Training_loss 8.17459


 60%|█████▉    | 298/500 [05:57<03:17,  1.02it/s]

Training_loss 8.16941


 60%|█████▉    | 299/500 [05:58<03:22,  1.01s/it]

Training_loss 8.16371


 60%|██████    | 300/500 [05:59<03:32,  1.06s/it]

Training_loss 8.15735


 60%|██████    | 301/500 [06:00<03:34,  1.08s/it]

Training_loss 8.15047


 60%|██████    | 302/500 [06:01<03:27,  1.05s/it]

Training_loss 8.14132


 61%|██████    | 303/500 [06:02<03:12,  1.02it/s]

Training_loss 8.13592


 61%|██████    | 304/500 [06:03<03:08,  1.04it/s]

Training_loss 8.12937


 61%|██████    | 305/500 [06:04<03:09,  1.03it/s]

Training_loss 8.12196


 61%|██████    | 306/500 [06:05<03:08,  1.03it/s]

Training_loss 8.11400


 61%|██████▏   | 307/500 [06:06<03:05,  1.04it/s]

Training_loss 8.10592


 62%|██████▏   | 308/500 [06:07<03:05,  1.04it/s]

Training_loss 8.10047


 62%|██████▏   | 309/500 [06:08<03:09,  1.01it/s]

Training_loss 8.09168


 62%|██████▏   | 310/500 [06:09<03:06,  1.02it/s]

Training_loss 8.08398


 62%|██████▏   | 311/500 [06:10<02:59,  1.05it/s]

Training_loss 8.07799


 62%|██████▏   | 312/500 [06:11<02:58,  1.05it/s]

Training_loss 8.07190


 63%|██████▎   | 313/500 [06:11<02:53,  1.08it/s]

Training_loss 8.06562


 63%|██████▎   | 314/500 [06:12<02:51,  1.09it/s]

Training_loss 8.05946


 63%|██████▎   | 315/500 [06:14<03:06,  1.01s/it]

Training_loss 8.05682


 63%|██████▎   | 316/500 [06:15<03:02,  1.01it/s]

Training_loss 8.05151


 63%|██████▎   | 317/500 [06:16<03:03,  1.00s/it]

Training_loss 8.04715


 64%|██████▎   | 318/500 [06:17<03:45,  1.24s/it]

Training_loss 8.03905


 64%|██████▍   | 319/500 [06:19<04:20,  1.44s/it]

Training_loss 8.03307


 64%|██████▍   | 320/500 [06:21<04:17,  1.43s/it]

Training_loss 8.03180


 64%|██████▍   | 321/500 [06:22<03:52,  1.30s/it]

Training_loss 8.02632


 64%|██████▍   | 322/500 [06:23<03:34,  1.20s/it]

Training_loss 8.02111


 65%|██████▍   | 323/500 [06:24<03:35,  1.22s/it]

Training_loss 8.01314


 65%|██████▍   | 324/500 [06:25<03:29,  1.19s/it]

Training_loss 8.00659


 65%|██████▌   | 325/500 [06:26<03:16,  1.12s/it]

Training_loss 8.00350


 65%|██████▌   | 326/500 [06:27<03:25,  1.18s/it]

Training_loss 7.99959


 65%|██████▌   | 327/500 [06:28<03:13,  1.12s/it]

Training_loss 7.99421


 66%|██████▌   | 328/500 [06:29<02:58,  1.04s/it]

Training_loss 7.98965


 66%|██████▌   | 329/500 [06:30<02:55,  1.03s/it]

Training_loss 7.98073


 66%|██████▌   | 330/500 [06:31<02:45,  1.03it/s]

Training_loss 7.97330


 66%|██████▌   | 331/500 [06:32<02:40,  1.05it/s]

Training_loss 7.96791


 66%|██████▋   | 332/500 [06:33<02:41,  1.04it/s]

Training_loss 7.96272


 67%|██████▋   | 333/500 [06:34<02:37,  1.06it/s]

Training_loss 7.95658


 67%|██████▋   | 334/500 [06:35<02:40,  1.03it/s]

Training_loss 7.95136


 67%|██████▋   | 335/500 [06:36<02:37,  1.05it/s]

Training_loss 7.94670


 67%|██████▋   | 336/500 [06:37<02:55,  1.07s/it]

Training_loss 7.94079


 67%|██████▋   | 337/500 [06:38<02:53,  1.07s/it]

Training_loss 7.93809


 68%|██████▊   | 338/500 [06:39<02:48,  1.04s/it]

Training_loss 7.93451


 68%|██████▊   | 339/500 [06:40<02:40,  1.00it/s]

Training_loss 7.93057


 68%|██████▊   | 340/500 [06:41<02:48,  1.06s/it]

Training_loss 7.92589


 68%|██████▊   | 341/500 [06:42<02:46,  1.05s/it]

Training_loss 7.92181


 68%|██████▊   | 342/500 [06:43<02:47,  1.06s/it]

Training_loss 7.91479


 69%|██████▊   | 343/500 [06:44<02:40,  1.02s/it]

Training_loss 7.90652


 69%|██████▉   | 344/500 [06:45<02:46,  1.07s/it]

Training_loss 7.90170


 69%|██████▉   | 345/500 [06:46<02:46,  1.07s/it]

Training_loss 7.89726


 69%|██████▉   | 346/500 [06:47<02:38,  1.03s/it]

Training_loss 7.89356


 69%|██████▉   | 347/500 [06:48<02:34,  1.01s/it]

Training_loss 7.88689


 70%|██████▉   | 348/500 [06:49<02:31,  1.00it/s]

Training_loss 7.88180


 70%|██████▉   | 349/500 [06:50<02:23,  1.05it/s]

Training_loss 7.87633


 70%|███████   | 350/500 [06:51<02:25,  1.03it/s]

Training_loss 7.87350


 70%|███████   | 351/500 [06:52<02:21,  1.05it/s]

Training_loss 7.86867


 70%|███████   | 352/500 [06:53<02:15,  1.09it/s]

Training_loss 7.86661


 71%|███████   | 353/500 [06:54<02:23,  1.03it/s]

Training_loss 7.86259


 71%|███████   | 354/500 [06:55<02:27,  1.01s/it]

Training_loss 7.85852


 71%|███████   | 355/500 [06:56<02:35,  1.07s/it]

Training_loss 7.85102


 71%|███████   | 356/500 [06:57<02:31,  1.05s/it]

Training_loss 7.84453


 71%|███████▏  | 357/500 [06:58<02:29,  1.04s/it]

Training_loss 7.84007


 72%|███████▏  | 358/500 [06:59<02:28,  1.05s/it]

Training_loss 7.83504


 72%|███████▏  | 359/500 [07:00<02:22,  1.01s/it]

Training_loss 7.82893


 72%|███████▏  | 360/500 [07:02<02:34,  1.10s/it]

Training_loss 7.82409


 72%|███████▏  | 361/500 [07:03<02:30,  1.08s/it]

Training_loss 7.81719


 72%|███████▏  | 362/500 [07:04<02:31,  1.10s/it]

Training_loss 7.81469


 73%|███████▎  | 363/500 [07:05<02:34,  1.13s/it]

Training_loss 7.81074


 73%|███████▎  | 364/500 [07:06<02:31,  1.11s/it]

Training_loss 7.80904


 73%|███████▎  | 365/500 [07:07<02:24,  1.07s/it]

Training_loss 7.80662


 73%|███████▎  | 366/500 [07:08<02:26,  1.09s/it]

Training_loss 7.80211


 73%|███████▎  | 367/500 [07:09<02:23,  1.08s/it]

Training_loss 7.79623


 74%|███████▎  | 368/500 [07:10<02:15,  1.02s/it]

Training_loss 7.79199


 74%|███████▍  | 369/500 [07:11<02:08,  1.02it/s]

Training_loss 7.78962


 74%|███████▍  | 370/500 [07:12<02:04,  1.05it/s]

Training_loss 7.78320


 74%|███████▍  | 371/500 [07:13<02:02,  1.06it/s]

Training_loss 7.77908


 74%|███████▍  | 372/500 [07:14<02:00,  1.06it/s]

Training_loss 7.77446


 75%|███████▍  | 373/500 [07:15<01:54,  1.11it/s]

Training_loss 7.77160


 75%|███████▍  | 374/500 [07:16<01:55,  1.09it/s]

Training_loss 7.76741


 75%|███████▌  | 375/500 [07:16<01:53,  1.11it/s]

Training_loss 7.76571


 75%|███████▌  | 376/500 [07:18<02:02,  1.01it/s]

Training_loss 7.76362


 75%|███████▌  | 377/500 [07:19<02:02,  1.01it/s]

Training_loss 7.75804


 76%|███████▌  | 378/500 [07:20<02:06,  1.03s/it]

Training_loss 7.75371


 76%|███████▌  | 379/500 [07:21<02:05,  1.03s/it]

Training_loss 7.75088


 76%|███████▌  | 380/500 [07:22<02:04,  1.04s/it]

Training_loss 7.74766


 76%|███████▌  | 381/500 [07:23<02:11,  1.11s/it]

Training_loss 7.74533


 76%|███████▋  | 382/500 [07:24<02:09,  1.10s/it]

Training_loss 7.74098


 77%|███████▋  | 383/500 [07:25<02:07,  1.09s/it]

Training_loss 7.73795


 77%|███████▋  | 384/500 [07:27<02:20,  1.21s/it]

Training_loss 7.73397


 77%|███████▋  | 385/500 [07:29<02:41,  1.40s/it]

Training_loss 7.72941


 77%|███████▋  | 386/500 [07:30<02:25,  1.28s/it]

Training_loss 7.72704


 77%|███████▋  | 387/500 [07:31<02:14,  1.19s/it]

Training_loss 7.72366


 78%|███████▊  | 388/500 [07:32<02:06,  1.13s/it]

Training_loss 7.71873


 78%|███████▊  | 389/500 [07:33<01:58,  1.07s/it]

Training_loss 7.71557


 78%|███████▊  | 390/500 [07:34<02:00,  1.10s/it]

Training_loss 7.71106


 78%|███████▊  | 391/500 [07:35<01:51,  1.03s/it]

Training_loss 7.70662


 78%|███████▊  | 392/500 [07:36<01:53,  1.05s/it]

Training_loss 7.70188


 79%|███████▊  | 393/500 [07:37<01:57,  1.10s/it]

Training_loss 7.70008


 79%|███████▉  | 394/500 [07:39<02:30,  1.42s/it]

Training_loss 7.69794


 79%|███████▉  | 395/500 [07:41<03:00,  1.72s/it]

Training_loss 7.69421


 79%|███████▉  | 396/500 [07:43<02:58,  1.71s/it]

Training_loss 7.68830


 79%|███████▉  | 397/500 [07:44<02:44,  1.60s/it]

Training_loss 7.68415


 80%|███████▉  | 398/500 [07:46<02:28,  1.45s/it]

Training_loss 7.68192


 80%|███████▉  | 399/500 [07:47<02:11,  1.30s/it]

Training_loss 7.67893


 80%|████████  | 400/500 [07:48<02:01,  1.21s/it]

Training_loss 7.67545


 80%|████████  | 401/500 [07:49<01:57,  1.19s/it]

Training_loss 7.67237


 80%|████████  | 402/500 [07:50<01:55,  1.18s/it]

Training_loss 7.66977


 81%|████████  | 403/500 [07:51<01:51,  1.15s/it]

Training_loss 7.66494


 81%|████████  | 404/500 [07:52<01:48,  1.13s/it]

Training_loss 7.66318


 81%|████████  | 405/500 [07:54<02:00,  1.26s/it]

Training_loss 7.66142


 81%|████████  | 406/500 [07:55<01:49,  1.16s/it]

Training_loss 7.65701


 81%|████████▏ | 407/500 [07:55<01:41,  1.09s/it]

Training_loss 7.65336


 82%|████████▏ | 408/500 [07:56<01:36,  1.05s/it]

Training_loss 7.65091


 82%|████████▏ | 409/500 [07:57<01:33,  1.03s/it]

Training_loss 7.64956


 82%|████████▏ | 410/500 [07:58<01:32,  1.03s/it]

Training_loss 7.64642


 82%|████████▏ | 411/500 [07:59<01:29,  1.01s/it]

Training_loss 7.64264


 82%|████████▏ | 412/500 [08:00<01:27,  1.01it/s]

Training_loss 7.64110


 83%|████████▎ | 413/500 [08:01<01:25,  1.02it/s]

Training_loss 7.63776


 83%|████████▎ | 414/500 [08:02<01:26,  1.01s/it]

Training_loss 7.63442


 83%|████████▎ | 415/500 [08:03<01:25,  1.01s/it]

Training_loss 7.63062


 83%|████████▎ | 416/500 [08:04<01:23,  1.01it/s]

Training_loss 7.62781


 83%|████████▎ | 417/500 [08:05<01:25,  1.03s/it]

Training_loss 7.62422


 84%|████████▎ | 418/500 [08:07<01:31,  1.12s/it]

Training_loss 7.62222


 84%|████████▍ | 419/500 [08:08<01:30,  1.12s/it]

Training_loss 7.61904


 84%|████████▍ | 420/500 [08:10<01:44,  1.30s/it]

Training_loss 7.61531


 84%|████████▍ | 421/500 [08:12<02:05,  1.58s/it]

Training_loss 7.61164


 84%|████████▍ | 422/500 [08:14<02:09,  1.66s/it]

Training_loss 7.60832


 85%|████████▍ | 423/500 [08:15<02:02,  1.59s/it]

Training_loss 7.60519


 85%|████████▍ | 424/500 [08:17<02:00,  1.58s/it]

Training_loss 7.60360


 85%|████████▌ | 425/500 [08:19<02:16,  1.83s/it]

Training_loss 7.60290


 85%|████████▌ | 426/500 [08:22<02:42,  2.19s/it]

Training_loss 7.59801


 85%|████████▌ | 427/500 [08:25<02:47,  2.30s/it]

Training_loss 7.59581


 86%|████████▌ | 428/500 [08:26<02:34,  2.14s/it]

Training_loss 7.59221


 86%|████████▌ | 429/500 [08:28<02:28,  2.09s/it]

Training_loss 7.58951


 86%|████████▌ | 430/500 [08:31<02:43,  2.33s/it]

Training_loss 7.58921


 86%|████████▌ | 431/500 [08:34<02:41,  2.34s/it]

Training_loss 7.58760


 86%|████████▋ | 432/500 [08:35<02:13,  1.97s/it]

Training_loss 7.58346


 87%|████████▋ | 433/500 [08:36<01:51,  1.67s/it]

Training_loss 7.58045


 87%|████████▋ | 434/500 [08:37<01:42,  1.55s/it]

Training_loss 7.57982


 87%|████████▋ | 435/500 [08:38<01:29,  1.38s/it]

Training_loss 7.57944


 87%|████████▋ | 436/500 [08:39<01:22,  1.28s/it]

Training_loss 7.57839


 87%|████████▋ | 437/500 [08:40<01:13,  1.17s/it]

Training_loss 7.57421


 88%|████████▊ | 438/500 [08:41<01:15,  1.22s/it]

Training_loss 7.57089


 88%|████████▊ | 439/500 [08:42<01:07,  1.11s/it]

Training_loss 7.56759


 88%|████████▊ | 440/500 [08:43<01:08,  1.14s/it]

Training_loss 7.56520


 88%|████████▊ | 441/500 [08:44<01:01,  1.04s/it]

Training_loss 7.56280


 88%|████████▊ | 442/500 [08:45<01:00,  1.04s/it]

Training_loss 7.56043


 89%|████████▊ | 443/500 [08:46<00:57,  1.00s/it]

Training_loss 7.55912


 89%|████████▉ | 444/500 [08:47<00:55,  1.01it/s]

Training_loss 7.55609


 89%|████████▉ | 445/500 [08:48<00:55,  1.02s/it]

Training_loss 7.55424


 89%|████████▉ | 446/500 [08:49<00:55,  1.04s/it]

Training_loss 7.55221


 89%|████████▉ | 447/500 [08:50<00:53,  1.02s/it]

Training_loss 7.54828


 90%|████████▉ | 448/500 [08:51<00:54,  1.06s/it]

Training_loss 7.54788


 90%|████████▉ | 449/500 [08:52<00:54,  1.07s/it]

Training_loss 7.54659


 90%|█████████ | 450/500 [08:53<00:52,  1.04s/it]

Training_loss 7.54504


 90%|█████████ | 451/500 [08:54<00:50,  1.03s/it]

Training_loss 7.54492


 90%|█████████ | 452/500 [08:56<00:52,  1.08s/it]

Training_loss 7.54347


 91%|█████████ | 453/500 [08:57<00:58,  1.25s/it]

Training_loss 7.54236


 91%|█████████ | 454/500 [08:58<00:53,  1.16s/it]

Training_loss 7.54032


 91%|█████████ | 455/500 [08:59<00:49,  1.10s/it]

Training_loss 7.53963


 91%|█████████ | 456/500 [09:00<00:45,  1.04s/it]

Training_loss 7.53622


 91%|█████████▏| 457/500 [09:01<00:44,  1.03s/it]

Training_loss 7.53460


 92%|█████████▏| 458/500 [09:02<00:45,  1.10s/it]

Training_loss 7.53287


 92%|█████████▏| 459/500 [09:03<00:43,  1.07s/it]

Training_loss 7.52954


 92%|█████████▏| 460/500 [09:04<00:41,  1.04s/it]

Training_loss 7.52766


 92%|█████████▏| 461/500 [09:05<00:40,  1.05s/it]

Training_loss 7.52671


 92%|█████████▏| 462/500 [09:06<00:39,  1.04s/it]

Training_loss 7.52391


 93%|█████████▎| 463/500 [09:07<00:36,  1.02it/s]

Training_loss 7.51909


 93%|█████████▎| 464/500 [09:08<00:35,  1.02it/s]

Training_loss 7.51611


 93%|█████████▎| 465/500 [09:09<00:35,  1.01s/it]

Training_loss 7.51511


 93%|█████████▎| 466/500 [09:10<00:34,  1.02s/it]

Training_loss 7.51181


 93%|█████████▎| 467/500 [09:11<00:32,  1.01it/s]

Training_loss 7.50988


 94%|█████████▎| 468/500 [09:12<00:30,  1.04it/s]

Training_loss 7.50798


 94%|█████████▍| 469/500 [09:13<00:31,  1.00s/it]

Training_loss 7.50711


 94%|█████████▍| 470/500 [09:14<00:29,  1.02it/s]

Training_loss 7.50762


 94%|█████████▍| 471/500 [09:15<00:29,  1.00s/it]

Training_loss 7.50455


 94%|█████████▍| 472/500 [09:16<00:27,  1.00it/s]

Training_loss 7.50355


 95%|█████████▍| 473/500 [09:17<00:25,  1.05it/s]

Training_loss 7.50176


 95%|█████████▍| 474/500 [09:18<00:25,  1.03it/s]

Training_loss 7.49759


 95%|█████████▌| 475/500 [09:19<00:25,  1.02s/it]

Training_loss 7.49812


 95%|█████████▌| 476/500 [09:20<00:25,  1.05s/it]

Training_loss 7.49594


 95%|█████████▌| 477/500 [09:23<00:33,  1.44s/it]

Training_loss 7.49410


 96%|█████████▌| 478/500 [09:24<00:31,  1.43s/it]

Training_loss 7.48985


 96%|█████████▌| 479/500 [09:25<00:27,  1.32s/it]

Training_loss 7.48842


 96%|█████████▌| 480/500 [09:26<00:24,  1.21s/it]

Training_loss 7.48561


 96%|█████████▌| 481/500 [09:27<00:22,  1.18s/it]

Training_loss 7.48269


 96%|█████████▋| 482/500 [09:28<00:20,  1.16s/it]

Training_loss 7.48220


 97%|█████████▋| 483/500 [09:29<00:19,  1.12s/it]

Training_loss 7.47991


 97%|█████████▋| 484/500 [09:30<00:16,  1.05s/it]

Training_loss 7.47956


 97%|█████████▋| 485/500 [09:31<00:15,  1.00s/it]

Training_loss 7.47404


 97%|█████████▋| 486/500 [09:32<00:13,  1.01it/s]

Training_loss 7.47249


 97%|█████████▋| 487/500 [09:33<00:12,  1.05it/s]

Training_loss 7.47006


 98%|█████████▊| 488/500 [09:34<00:11,  1.02it/s]

Training_loss 7.46588


 98%|█████████▊| 489/500 [09:35<00:10,  1.05it/s]

Training_loss 7.46350


 98%|█████████▊| 490/500 [09:36<00:09,  1.06it/s]

Training_loss 7.46256


 98%|█████████▊| 491/500 [09:37<00:08,  1.03it/s]

Training_loss 7.45969


 98%|█████████▊| 492/500 [09:38<00:08,  1.00s/it]

Training_loss 7.45863


 99%|█████████▊| 493/500 [09:39<00:06,  1.02it/s]

Training_loss 7.46002


 99%|█████████▉| 494/500 [09:40<00:05,  1.06it/s]

Training_loss 7.45887


 99%|█████████▉| 495/500 [09:41<00:04,  1.08it/s]

Training_loss 7.45882


 99%|█████████▉| 496/500 [09:41<00:03,  1.11it/s]

Training_loss 7.45460


 99%|█████████▉| 497/500 [09:42<00:02,  1.13it/s]

Training_loss 7.45135


100%|█████████▉| 498/500 [09:43<00:01,  1.16it/s]

Training_loss 7.44968


100%|█████████▉| 499/500 [09:44<00:00,  1.15it/s]

Training_loss 7.44841


100%|██████████| 500/500 [09:45<00:00,  1.17s/it]

Training_loss 7.44543





In [23]:
test_loss = np.array(test_loss)
total_rel_error = np.array(total_rel_error)

in_cluster_proj_norm = np.array(in_cluster_proj_norm)
out_cluster_proj_norm = np.array(out_cluster_proj_norm)
in_cluster_proj_diff_norm = np.array(in_cluster_proj_diff_norm)
out_cluster_proj_diff_norm = np.array(out_cluster_proj_diff_norm)


In [24]:
'''
  0%|          | 1/2000 [00:12<6:59:49, 12.60s/it]
Training_loss 0.69317,   Accuracy 0.52177
  0%|          | 2/2000 [00:25<6:58:41, 12.57s/it]
Training_loss 0.69256,   Accuracy 0.52523
  0%|          | 3/2000 [00:37<6:55:30, 12.48s/it]
Training_loss 0.69282,   Accuracy 0.52400
  0%|          | 4/2000 [00:50<6:56:56, 12.53s/it]
Training_loss 0.69216,   Accuracy 0.52552
  0%|          | 5/2000 [01:05<7:35:34, 13.70s/it]
Training_loss 0.69178,   Accuracy 0.52765
  0%|          | 6/2000 [01:18<7:24:07, 13.36s/it]
Training_loss 0.69037,   Accuracy 0.55240
  0%|          | 7/2000 [01:32<7:32:02, 13.61s/it]
Training_loss 0.68986,   Accuracy 0.55653
  0%|          | 8/2000 [01:47<7:42:39, 13.94s/it]
Training_loss 0.68921,   Accuracy 0.56338
'''

'\n  0%|          | 1/2000 [00:12<6:59:49, 12.60s/it]\nTraining_loss 0.69317,   Accuracy 0.52177\n  0%|          | 2/2000 [00:25<6:58:41, 12.57s/it]\nTraining_loss 0.69256,   Accuracy 0.52523\n  0%|          | 3/2000 [00:37<6:55:30, 12.48s/it]\nTraining_loss 0.69282,   Accuracy 0.52400\n  0%|          | 4/2000 [00:50<6:56:56, 12.53s/it]\nTraining_loss 0.69216,   Accuracy 0.52552\n  0%|          | 5/2000 [01:05<7:35:34, 13.70s/it]\nTraining_loss 0.69178,   Accuracy 0.52765\n  0%|          | 6/2000 [01:18<7:24:07, 13.36s/it]\nTraining_loss 0.69037,   Accuracy 0.55240\n  0%|          | 7/2000 [01:32<7:32:02, 13.61s/it]\nTraining_loss 0.68986,   Accuracy 0.55653\n  0%|          | 8/2000 [01:47<7:42:39, 13.94s/it]\nTraining_loss 0.68921,   Accuracy 0.56338\n'

In [25]:
np.save( 'training_loss_sheave_fml' + str(lamda).replace('.', '_')+ '_pout' + str(pout).replace('.', '_'), test_loss)
np.save('relative_error_sheave_fml' + str(lamda).replace('.', '_')+ '_pout' + str(pout).replace('.', '_'), total_rel_error)
np.save( 'in_cluster_proj_norm_sheave_fml' + str(lamda).replace('.', '_')+ '_pout' + str(pout).replace('.', '_'), in_cluster_proj_norm)
np.save('out_cluster_proj_norm_sheave_fml' + str(lamda).replace('.', '_')+ '_pout' + str(pout).replace('.', '_'), out_cluster_proj_norm)
np.save( 'in_cluster_proj_diff_norm_sheave_fml' + str(lamda).replace('.', '_')+ '_pout' + str(pout).replace('.', '_'), in_cluster_proj_diff_norm)
np.save('out_cluster_proj_diff_norm_sheave_fml' + str(lamda).replace('.', '_')+ '_pout' + str(pout).replace('.', '_'), out_cluster_proj_diff_norm)

In [26]:
'training_loss_sheave_fml' + str(lamda).replace('.', '_')+ '_pout' + str(pout).replace('.', '_')

'training_loss_sheave_fml0_1_pout0_1'

In [27]:
# task_loss dictionary
for key, value in task_loss.items():
    # Convert the list to a NumPy array
    array_loss = np.array(value)
    
    # Save the NumPy array using the specified format
    np.save('training_loss_sheave_fml_task' + key + '_' + str(lamda).replace('.', '_') + '_pout' + str(pout).replace('.', '_'), array_loss)

# task_rel_error dictionary
for key, value in task_rel_error.items():
    # Convert the list to a NumPy array
    array_rel_error = np.array(value)
    
    # Save the NumPy array using the specified format
    np.save('relative_error_sheave_fml_task' + key + '_' + str(lamda).replace('.', '_') + '_pout' + str(pout).replace('.', '_'), array_rel_error)