In [1]:
from keras.utils import to_categorical
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm
import networkx as nx
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import networkx as nx

In [2]:
from sklearn.datasets import load_svmlight_file


def get_data(data):
    """Read a svmlight/libsvm-format file and return dense features and labels.

    Args:
        data: Path (or file-like object) handed straight to
            ``load_svmlight_file``.

    Returns:
        A ``(features, labels)`` pair, where ``features`` is the result of
        calling ``.toarray()`` on the loaded feature matrix and ``labels`` is
        the second element returned by the loader.
    """
    loaded = load_svmlight_file(data)
    sparse_features, labels = loaded[0], loaded[1]
    return sparse_features.toarray(), labels
# Path of the svmlight-format dataset, relative to the notebook's working directory.
data = "./abalone_scale.txt"
# X: dense feature matrix, y: label vector (both as returned by get_data).
# NOTE: X and y are re-bound to lists of per-group splits a few cells below.
X, y = get_data(data)
#y[y == -1] = 0
# Convert labels to one-hot encoding
#y = to_categorical(y)

#data = "./a1a_t"
#X_test,y_test = get_data(data)
#y_test[y_test == -1] = 0
#y_test = to_categorical(y_test)


In [3]:
# Dataset partitioning
def random_split(X, y, n, seed):
    """Shuffle the samples and cut them into ``n`` near-equal consecutive parts.

    A single permutation, drawn from ``np.random.default_rng(seed)``, is applied
    to both ``X`` and ``y`` so rows and labels stay aligned. The cut is made with
    ``np.array_split``, so part sizes differ by at most one when the number of
    samples is not a multiple of ``n``.

    Args:
        X: Array indexed by sample along its first axis.
        y: Array of labels; ``y.shape[0]`` defines the number of samples.
        n: Number of parts to produce.
        seed: Seed for the permutation generator.

    Returns:
        ``(X_parts, y_parts)``: two lists of ``n`` arrays each.
    """
    shuffled_idx = np.random.default_rng(seed).permutation(y.shape[0])
    shuffled_X, shuffled_y = X[shuffled_idx], y[shuffled_idx]
    return np.array_split(shuffled_X, n), np.array_split(shuffled_y, n)

In [4]:
# Total number of agents (graph nodes).
# NOTE: re-assigned to sum(cluster_sizes) in the graph-configuration cell below,
# which also evaluates to 30 for cluster_sizes = [10, 10, 10].
no_users = 30

In [5]:
# Shuffle the full dataset and cut it into 3 groups (one per cluster of agents).
# NOTE: this re-binds X and y from arrays to lists of 3 arrays, so the cell is
# not idempotent — re-running it without reloading the data will fail.
X, y = random_split(X, y, 3, 42)

In [6]:
# Each group only sees a window of the feature columns of its own rows:
# group 1 -> columns 0..3, group 2 -> columns 2..5, group 3 -> columns 3 onward.
X1, X2, X3 = X[0][:, 0:4], X[1][:, 2:6], X[2][:, 3:]
y1, y2, y3 = y[0], y[1], y[2]

# Held-out data for group g is the union of the rows of the two *other* groups.
# All feature columns are kept here; the per-group column window is applied
# later through subset_ranges.
other_groups = ((1, 2), (0, 2), (0, 1))
X_test = [np.concatenate([X[a], X[b]], axis=0) for a, b in other_groups]
y_test = [np.concatenate((y[a], y[b])) for a, b in other_groups]

# Column indices observed by each group, and how many columns that is.
subset_ranges = [np.arange(lo, hi) for lo, hi in ((0, 4), (2, 6), (3, 8))]
subset_lengths = [cols.shape[0] for cols in subset_ranges]

In [7]:
# Number of rows in group 1's feature window.
len(X1)

1393

In [8]:
# Shape of group 1's held-out set (rows of groups 2 and 3, all feature columns).
X_test[0].shape

(2784, 8)

In [9]:
# NOTE(review): duplicate of the previous inspection cell; can be deleted.
X_test[0].shape

(2784, 8)

In [10]:
# Graph implementation
def generate_graph(cluster_sizes=(100, 100), pin=0.5, pout=0.01, seed=0, max_tries=100):
    """Generate a connected random graph from a stochastic block model.

    The block probability matrix is built to match the number of clusters:
    ``pin`` on the diagonal (edges inside a cluster) and ``pout`` everywhere
    else (edges between clusters).

    Args:
        cluster_sizes: Sequence with the number of nodes in each cluster.
        pin: Edge probability between two nodes of the same cluster.
        pout: Edge probability between two nodes of different clusters.
        seed: Seed for the first sampling attempt. When it is an integer, each
            retry uses ``seed + attempt`` so a disconnected draw is not simply
            reproduced; any other value is forwarded unchanged.
        max_tries: Maximum number of graphs to sample before giving up.

    Returns:
        The first sampled ``networkx`` graph that is connected.

    Raises:
        RuntimeError: If no connected graph is found within ``max_tries``.
    """
    sizes = list(cluster_sizes)
    n_clusters = len(sizes)

    # Symmetric block matrix sized from the input rather than hard-coded to 3x3.
    probs = np.full((n_clusters, n_clusters), pout, dtype=float)
    np.fill_diagonal(probs, pin)

    for attempt in range(max_tries):
        # With a fixed integer seed every draw would be identical, so a
        # disconnected first draw would loop forever; shift the seed per retry.
        if isinstance(seed, (int, np.integer)):
            attempt_seed = seed + attempt
        else:
            attempt_seed = seed
        g = nx.stochastic_block_model(sizes, probs, seed=attempt_seed)
        if nx.algorithms.components.is_connected(g):
            return g

    raise RuntimeError(
        "Could not sample a connected graph in %d attempts "
        "(cluster_sizes=%r, pin=%r, pout=%r)" % (max_tries, sizes, pin, pout)
    )


# --- Graph layout -----------------------------------------------------------
cluster_sizes = [10, 10, 10]  # number of agents in each of the three clusters
#features_sizes = [8, 7, 6, 5]
pin = 0.5   # edge probability inside a cluster
pout = 0.1  # edge probability between clusters
seed = 0    # seed passed to generate_graph (re-assigned below for torch/numpy)

# --- Optimisation hyper-parameters -------------------------------------------
alpha = 1e-2  # step size of the local model update in ClientUpdate.train
lamda = 0  # weight of the neighbour-disagreement term; previously tried 1e-1, 1e-3
eta = 1 * 1e-2  # step size of the projection-matrix update in the main loop
d0 = min(subset_lengths)  # number of rows of every projection matrix
no_users = sum(cluster_sizes)  # total number of agents / graph nodes
batch_size = 20  # mini-batch size handed to ClientUpdate
epochs = 1  # NOTE(review): the main loop passes the literal epochs=1, not this value
it = 2000  # number of communication rounds in the main loop
G = generate_graph(cluster_sizes, pin, pout, seed)

# Set a random seed for reproducibility
# (done after the graph is generated; re-uses the name `seed` with a new value)
seed = 17
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
#nx.draw(G, with_labels=True, node_size=100, alpha=1, linewidths=10)
#plt.show()

In [11]:
# Inspect the shared projection dimension d0 = min(subset_lengths).
# NOTE(review): the column ranges defined earlier have lengths [4, 4, 5], so this
# should display 4; the recorded output of 8 looks stale — re-run to confirm.
d0

8

In [12]:
def degrees(A):
    """Compute node degrees from an adjacency matrix.

    Args:
        A: Adjacency matrix as an array with a ``shape`` attribute.

    Returns:
        The column sums of ``A`` arranged as a column vector of shape
        ``(A.shape[0], 1)``.
    """
    n_nodes = A.shape[0]
    column_totals = np.sum(A, axis=0)
    return column_totals.reshape(n_nodes, 1)

def node_degree(n, G):
    """Count the neighbours of node ``n`` by exhausting ``G.neighbors(n)``."""
    return sum(1 for _ in G.neighbors(n))

def get_neighbors(n, G):
    """Return the neighbours of node ``n`` as a list of plain ``int`` ids.

    The order is whatever ``G.neighbors(n)`` yields; every id is passed
    through ``int``.
    """
    return [int(node) for node in G.neighbors(n)]

In [13]:
# Label vector of group 1 before it is split per agent.
# NOTE: the next cell re-binds y1 to a list of arrays, after which `.shape`
# is no longer available on it.
y1.shape

(1393,)

In [14]:
# Per-agent record, keyed by global node id (0 .. no_users - 1).
datapoints = {}
count = 0

# Split every group's data across the agents of the matching cluster. The number
# of parts follows cluster_sizes so the indexing X_train[i][j] below is always
# valid (the original hard-coded 10, which equals each current cluster size).
# NOTE: X1..X3 / y1..y3 are re-bound from arrays to lists of arrays here, so this
# cell is not idempotent — re-run the feature-window cell before re-running it.
X1, y1 = random_split(X1, y1, cluster_sizes[0], 42)
X2, y2 = random_split(X2, y2, cluster_sizes[1], 42)
X3, y3 = random_split(X3, y3, cluster_sizes[2], 42)

X_train = [X1, X2, X3]
y_train = [y1, y2, y3]
input_sizes = [X1[0].shape[1], X2[0].shape[1], X3[0].shape[1]]

for i, cluster_size in enumerate(cluster_sizes):
    for j in range(cluster_size):
        # Agent `count` is the j-th member of cluster i. Node ids are assumed to
        # be numbered cluster by cluster in G — TODO confirm against the graph
        # generator's node ordering.
        agent_features = X_train[i][j]
        datapoints[count] = {
                'features': agent_features,
                'degree': node_degree(count, G),
                'label': y_train[i][j],
                'neighbors': get_neighbors(count, G),
                'input_size': agent_features.shape[1],
                # Held-out rows of the other two groups, restricted to the
                # feature columns this cluster observes.
                'test_features': X_test[i][:, subset_ranges[i]],
                'test_labels': y_test[i]
            }
        count += 1

In [15]:
# Shape of the held-out feature matrix of agent 21 (an agent of the third cluster).
# NOTE(review): test_features is sliced with subset_ranges[2] = arange(3, 8), i.e.
# 5 columns; the recorded output showing 8 columns looks stale — re-run to confirm.
datapoints[21]['test_features'].shape

(2785, 8)

In [16]:
class MyDataset(Dataset):
    """In-memory dataset of ``(features, target)`` pairs stored as float tensors.

    Args:
        data: Array-like of samples, indexed by sample along the first axis.
        targets: Array-like of targets aligned with ``data``.
        transform: Optional callable applied to the feature tensor of a sample
            each time it is fetched. ``None`` (the default) returns the stored
            features unchanged.
    """

    def __init__(self, data, targets, transform=None):
        self.data = torch.FloatTensor(data)
        self.targets = torch.FloatTensor(targets)
        # Previously accepted but silently dropped; keep it so callers that pass
        # a transform actually get it applied.
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        x = self.data[index]
        y = self.targets[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        """Number of samples, taken from the feature tensor."""
        return len(self.data)


In [17]:

class MLP_Net(nn.Module):
    """Single linear layer model tagged with the id of the agent that owns it.

    Args:
        input_size: Number of input features after flattening.
        num_classes: Number of output units of the linear layer.
        user_id: Identifier stored on the module as ``self.user_id``.
    """

    def __init__(self, input_size, num_classes, user_id):
        super().__init__()
        self.fc1 = nn.Linear(input_size, num_classes, bias=True)
        self.user_id = user_id

    def forward(self, x):
        """Flatten everything after the batch axis and apply the linear layer.

        No activation is applied; the raw layer output is returned.
        """
        flat = torch.flatten(x, 1)
        return self.fc1(flat)

In [18]:
from typing import Iterable, Optional

def grads_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Flatten the gradients of ``parameters`` into a single 1-D vector.

    This is the gradient counterpart of
    ``torch.nn.utils.parameters_to_vector``: entries appear in the same order
    and with the same sizes, so the result can be combined element-wise with
    the flattened parameters.

    Args:
        parameters (Iterable[Tensor]): tensors whose ``.grad`` attributes are
            collected, typically ``model.parameters()`` after ``backward()``.

    Returns:
        A 1-D tensor holding all gradients concatenated. A parameter whose
        ``.grad`` is ``None`` (it took no part in the last backward pass)
        contributes zeros, which keeps the vector aligned with the parameter
        vector. An empty iterable yields an empty tensor.
    """
    flat_grads = []
    for param in parameters:
        grad = param.grad
        if grad is None:
            # No gradient recorded: behave like a zero gradient instead of
            # failing with an AttributeError on ``None.view``.
            grad = torch.zeros_like(param)
        # reshape (unlike view) also accepts non-contiguous gradients.
        flat_grads.append(grad.reshape(-1))
    if not flat_grads:
        return torch.empty(0)
    return torch.cat(flat_grads)

In [19]:
# Sanity check: fit one stand-alone linear model on agent 21's local data with
# hand-written SGD steps (the same flatten / update / write-back mechanics that
# ClientUpdate uses).
model = MLP_Net(datapoints[21]["input_size"], 1, user_id=0)

lr = 0.01

dataloader = DataLoader(MyDataset(datapoints[21]["features"], datapoints[21]["label"]), batch_size=50, shuffle=True)
# The optimizer is only used to clear gradients; the step itself is done by hand.
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Loss module is loop-invariant, so build it once.
criterion = nn.MSELoss()

for step in range(1000):
    # One freshly shuffled mini-batch per step. The batch variables are named so
    # they do not overwrite the notebook-level `y` (list of label splits), which
    # the original loop variables `x, y` did.
    batch_x, batch_y = next(iter(dataloader))

    optimizer.zero_grad()
    yhat = model(batch_x)
    loss = criterion(yhat.squeeze(), batch_y)
    loss.backward()
    print(step, loss.detach())

    # Manual gradient step on the flattened parameter vector.
    with torch.no_grad():
        new_model = parameters_to_vector(model.parameters()) - lr * grads_to_vector(model.parameters())
        vector_to_parameters(parameters=model.parameters(), vec=new_model)

0 tensor(93.4025)
1 tensor(121.6946)
2 tensor(102.6583)
3 tensor(82.3093)
4 tensor(73.0132)
5 tensor(92.3217)
6 tensor(80.1299)
7 tensor(57.3422)
8 tensor(65.2588)
9 tensor(58.0577)
10 tensor(40.5504)
11 tensor(51.3492)
12 tensor(50.2891)
13 tensor(52.1893)
14 tensor(55.1758)
15 tensor(27.6117)
16 tensor(31.0062)
17 tensor(35.5602)
18 tensor(32.8030)
19 tensor(24.6145)
20 tensor(25.1938)
21 tensor(24.8841)
22 tensor(28.8012)
23 tensor(25.2397)
24 tensor(28.0119)
25 tensor(22.4082)
26 tensor(18.4860)
27 tensor(21.7617)
28 tensor(26.3294)
29 tensor(17.7479)
30 tensor(21.5105)
31 tensor(30.6467)
32 tensor(17.9162)
33 tensor(15.8744)
34 tensor(20.9609)
35 tensor(18.4591)
36 tensor(16.5488)
37 tensor(13.0537)
38 tensor(16.8153)
39 tensor(15.0364)
40 tensor(19.5861)
41 tensor(10.0754)
42 tensor(11.2041)
43 tensor(10.7943)
44 tensor(16.9486)
45 tensor(22.1211)
46 tensor(10.7859)
47 tensor(14.2005)
48 tensor(20.1441)
49 tensor(10.5829)
50 tensor(22.5816)
51 tensor(10.2437)
52 tensor(16.6312)
5

439 tensor(9.4729)
440 tensor(5.3435)
441 tensor(4.2114)
442 tensor(10.4373)
443 tensor(8.4458)
444 tensor(10.3180)
445 tensor(7.3400)
446 tensor(5.0345)
447 tensor(10.2591)
448 tensor(14.1265)
449 tensor(7.1168)
450 tensor(10.3710)
451 tensor(9.2006)
452 tensor(5.4707)
453 tensor(11.8797)
454 tensor(9.3633)
455 tensor(5.8007)
456 tensor(4.8002)
457 tensor(10.9880)
458 tensor(5.7708)
459 tensor(11.2766)
460 tensor(7.2226)
461 tensor(12.4360)
462 tensor(10.8187)
463 tensor(6.8543)
464 tensor(7.0458)
465 tensor(6.4301)
466 tensor(6.3984)
467 tensor(6.2020)
468 tensor(5.4771)
469 tensor(6.8973)
470 tensor(11.4558)
471 tensor(8.1193)
472 tensor(14.7236)
473 tensor(6.6412)
474 tensor(10.7313)
475 tensor(9.8787)
476 tensor(4.7234)
477 tensor(7.1869)
478 tensor(10.2079)
479 tensor(7.2461)
480 tensor(7.9890)
481 tensor(6.7481)
482 tensor(12.0549)
483 tensor(9.2832)
484 tensor(9.9671)
485 tensor(6.9775)
486 tensor(6.5490)
487 tensor(11.9139)
488 tensor(8.0578)
489 tensor(7.1097)
490 tensor(11.8

900 tensor(7.2022)
901 tensor(6.9912)
902 tensor(10.9864)
903 tensor(4.6179)
904 tensor(9.5163)
905 tensor(6.4326)
906 tensor(7.1401)
907 tensor(7.8353)
908 tensor(7.3568)
909 tensor(4.9576)
910 tensor(12.2050)
911 tensor(8.0864)
912 tensor(5.3518)
913 tensor(9.5107)
914 tensor(7.0288)
915 tensor(4.7146)
916 tensor(6.3132)
917 tensor(7.7449)
918 tensor(7.4608)
919 tensor(7.0597)
920 tensor(4.7534)
921 tensor(6.5064)
922 tensor(5.9731)
923 tensor(6.2760)
924 tensor(11.1439)
925 tensor(10.4024)
926 tensor(5.6694)
927 tensor(8.4276)
928 tensor(7.7747)
929 tensor(10.2556)
930 tensor(7.3119)
931 tensor(7.8629)
932 tensor(6.1120)
933 tensor(7.3113)
934 tensor(6.3813)
935 tensor(8.3312)
936 tensor(6.6076)
937 tensor(8.5386)
938 tensor(8.9513)
939 tensor(7.9242)
940 tensor(5.4387)
941 tensor(10.2702)
942 tensor(13.1192)
943 tensor(6.1367)
944 tensor(10.7050)
945 tensor(7.1904)
946 tensor(9.4496)
947 tensor(6.0622)
948 tensor(13.3876)
949 tensor(7.8252)
950 tensor(9.2321)
951 tensor(8.2854)
952

In [20]:
class ClientUpdate(object):
    """One agent's local training step with a neighbour-disagreement penalty.

    Args:
        dataset: Mapping with keys ``"features"`` and ``"label"`` holding the
            agent's local training data.
        batchSize: Mini-batch size of the local data loader.
        alpha: Step size of the local update.
        lamda: Weight of the neighbour-disagreement term.
        epochs: Number of local epochs; each epoch consumes ONE mini-batch.
        projection_list: ``projection_list[i][j]`` is the projection matrix
            agent ``i`` uses towards neighbour ``j``.
        projected_weights: ``projected_weights[i][j]`` is agent ``i``'s
            parameter vector projected with ``projection_list[i][j]``.

    The constructor arguments are stored on the instance and used by
    :meth:`train`; previously they were ignored in favour of module-level
    globals of the same names. The graph ``G`` is still read from module scope
    because it is not part of the constructor signature.
    """

    def __init__(self, dataset, batchSize, alpha, lamda, epochs, projection_list, projected_weights):
        self.train_loader = DataLoader(MyDataset(dataset["features"], dataset["label"]), batch_size=batchSize, shuffle=True)
        self.epochs = epochs
        self.batchSize = batchSize
        self.alpha = alpha
        self.lamda = lamda
        self.projection_list = projection_list
        self.projected_weights = projected_weights

    def train(self, model):
        """Run the local update(s) on ``model`` in place.

        Per epoch, one mini-batch is drawn and the flattened parameters ``w``
        are updated as
        ``w <- w - alpha * (grad + lamda * sum_j P_ij^T (P_ij w_i - P_ji w_j))``
        where ``j`` ranges over the neighbours of ``model.user_id`` in ``G``
        and the projected weights are the ones supplied to the constructor.

        Args:
            model: Module exposing ``user_id``; its parameters are overwritten.

        Returns:
            ``(state_dict, losses)``: the updated ``model.state_dict()`` and a
            list with one mean mini-batch loss per epoch.
        """
        criterion = nn.MSELoss()
        uid = model.user_id

        e_loss = []
        for epoch in range(1, self.epochs + 1):
            train_loss = 0
            samples_seen = 0
            # zip(range(1), ...) limits the epoch to a single mini-batch.
            for _, (data, labels) in zip(range(1), self.train_loader):
                model.zero_grad()
                output = model(data)
                loss = criterion(output.squeeze(), labels)
                loss.backward()
                train_loss += loss.item() * data.size(0)
                samples_seen += data.size(0)

                # Pure tensor arithmetic from here on; no autograd graph needed.
                with torch.no_grad():
                    grads = grads_to_vector(model.parameters())
                    weights = parameters_to_vector(model.parameters())

                    # Pull of the neighbours, mapped back to parameter space.
                    mat_vec_sum = torch.zeros_like(weights)
                    for j in G.neighbors(uid):
                        disagreement = self.projected_weights[uid][j] - self.projected_weights[j][uid]
                        mat_vec_sum = mat_vec_sum + torch.matmul(
                            torch.transpose(self.projection_list[uid][j], 0, 1), disagreement)

                    model_update = weights - self.alpha * (grads + self.lamda * mat_vec_sum)
                    # Written back inside the batch loop so an empty loader
                    # cannot reach an undefined `model_update`.
                    vector_to_parameters(parameters=model.parameters(), vec=model_update)

            # Normalise by the samples actually processed; equals the original
            # division by batchSize whenever the batch is full.
            e_loss.append(train_loss / samples_seen if samples_seen else 0.0)

        total_loss = e_loss

        return model.state_dict(), total_loss

In [21]:
# Preparing projection matrices
# One model per agent, sized to that agent's feature window.
# NOTE(review): `models` is re-created in the experiment-setup cell further down,
# so the instances built here only serve the initial update_ProjWeight call.
models = [MLP_Net(input_size=datapoints[i]['input_size'], num_classes=1, user_id=i) for i in range(no_users)]
#temp = MLP_Net()
# projection_list[i][j]: projection matrix of agent i towards neighbour j (0 for non-neighbours).
projection_list = []
# projected_weights[i][j]: agent i's parameters projected with projection_list[i][j] (0 for non-neighbours).
projected_weights = []

def update_ProjWeight(projection_list, projected_weights, models, first_run=True):
    """Fill the per-edge projection matrices and projected weight vectors.

    For every agent ``i`` one row is appended to ``projected_weights`` (and, on
    the first run, to ``projection_list``). Entry ``j`` of a row is ``0`` when
    ``j`` is not a neighbour of ``i`` in the module-level graph ``G``.

    Args:
        projection_list: List extended in place on the first run with, for each
            neighbour pair, a ``d0 x n_params`` matrix with ones on the main
            diagonal and zeros elsewhere (``d0`` and ``no_users`` are read from
            module scope). Read-only when ``first_run`` is false.
        projected_weights: List extended in place with
            ``projection_list[i][j] @ w_i`` for every neighbour pair. Callers
            are expected to pass an empty list.
        models: Sequence of per-agent models, indexed by agent id.
        first_run: When true, create the projection matrices; otherwise re-use
            the ones already stored in ``projection_list``.

    Returns:
        None. Both lists are modified in place.
    """
    for i in range(no_users):
        neighbors_mat = []
        neighbors_weights = []
        with torch.no_grad():
            # Flattened parameters of agent i; identical for all its neighbours,
            # so compute them once per agent.
            weights_i = parameters_to_vector(models[i].parameters())
            for j in range(no_users):
                if not G.has_edge(i, j):
                    # Placeholder keeps row indices aligned with agent ids.
                    neighbors_mat.append(0)
                    neighbors_weights.append(0)
                    continue

                if first_run:
                    # Deterministic "select the first d0 coordinates" matrix.
                    # torch.eye builds the rectangular identity directly and
                    # also works when the model has fewer parameters than d0,
                    # where the former diag + cat construction failed.
                    mat = torch.eye(d0, weights_i.numel())
                    neighbors_mat.append(mat)
                else:
                    mat = projection_list[i][j]
                neighbors_weights.append(torch.matmul(mat, weights_i))

        if first_run:
            projection_list.append(neighbors_mat)
        projected_weights.append(neighbors_weights)

# First run: creates the projection matrices and the initial projected weights
# for the models defined in this cell.
update_ProjWeight(projection_list, projected_weights, models)



In [22]:
def testing(model, dataset, bs, criterion):
    """Compute the sample-weighted average loss of ``model`` on held-out data.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        dataset: Mapping with keys ``"test_features"`` and ``"test_labels"``.
        bs: Batch size used to iterate over the held-out data.
        criterion: Loss callable taking ``(prediction, target)``. Each batch
            loss is weighted by the batch size, so the result is the per-sample
            mean when the criterion averages over the batch — assumed to be the
            case for the callers in this notebook.

    Returns:
        The accumulated loss divided by the number of held-out samples.
    """
    test_loader = DataLoader(MyDataset(dataset["test_features"], dataset["test_labels"]), batch_size=bs)
    model.eval()

    test_loss = 0
    # Evaluation only: avoid building autograd graphs for every batch.
    with torch.no_grad():
        for data, labels in test_loader:
            output = model(data)
            loss = criterion(output.squeeze(), labels)
            test_loss += loss.item() * data.size(0)

    test_loss /= len(test_loader.dataset)

    return test_loss

In [23]:
# Fresh per-agent models for the experiment. `dummy_models` are scratch copies
# that the main loop loads each agent's current state into before training.
models = [MLP_Net(input_size=datapoints[i]['input_size'], num_classes=1, user_id=i) for i in range(no_users)]
dummy_models = [MLP_Net(input_size=datapoints[i]['input_size'], num_classes=1, user_id=i) for i in range(no_users)]

# The models above replace the ones the projection state was built from, so
# rebuild that state for the new models. Otherwise the first round would use
# projected weights of discarded models, and a re-run of this cell would start
# from the projection matrices left over by the previous run.
projection_list = []
projected_weights = []
update_ProjWeight(projection_list, projected_weights, models)

criterion = nn.MSELoss()

# Per-round history of aggregate metrics.
train_loss = []
test_loss = []
test_accuracy = []
total_rel_error = []

# Per-round diagnostics on the edges of the first cluster: mean projection-matrix
# norm and mean projected-weight disagreement, split into in-cluster and
# cross-cluster edges.
in_cluster_proj_norm = []
out_cluster_proj_norm = []
in_cluster_proj_diff_norm = []
out_cluster_proj_diff_norm = []

# Per-round mean held-out loss / relative error of each task (cluster), keyed by
# the cluster index as a string.
task_loss = {'0':[],
                '1':[],
                '2':[]}
task_rel_error = {'0':[],
            '1':[],
            '2':[]}

for curr_round in tqdm(range(1, it+1)):
    w, local_loss = [], []

    # --- 1. Local model updates ----------------------------------------------
    # Every agent trains a scratch copy of its model. All agents read the
    # projected weights computed at the end of the previous round.
    for i in range(no_users):
        dummy_models[i].load_state_dict(models[i].state_dict())
        local_update = ClientUpdate(dataset=datapoints[i], batchSize=batch_size, alpha=alpha, lamda=lamda, epochs=epochs, projection_list=projection_list, projected_weights=projected_weights)
        weights, loss = local_update.train(dummy_models[i])
        w.append(weights)
        local_loss.append(loss)
        models[i].load_state_dict(w[i])

    # --- 2. Refresh projected weights with the updated models -----------------
    projected_weights = []
    update_ProjWeight(projection_list, projected_weights, models, first_run=False)

    # --- 3. Gradient step on every edge's projection matrix -------------------
    # P_ij <- P_ij - eta * lamda * (P_ij w_i - P_ji w_j) w_i^T
    for i in range(no_users):
        weights = parameters_to_vector(models[i].parameters())
        for j in G.neighbors(i):
            temp_mat = torch.outer(projected_weights[i][j] - projected_weights[j][i], weights)
            projection_list[i][j] = torch.add(projection_list[i][j], -1 * eta * lamda * temp_mat)

    # --- 4. Diagnostics on the edges of the first cluster ---------------------
    in_cluster_proj_norm_round = 0
    out_cluster_proj_norm_round = 0
    in_cluster_proj_diff_round = 0
    out_cluster_proj_diff_round = 0
    in_edges = 0
    out_edges = 0

    first_cluster_end = no_users // 3
    for i in range(first_cluster_end):
        for j in G.neighbors(i):
            proj_norm = torch.norm(projection_list[i][j]).detach().numpy()
            diff_norm = torch.norm(projected_weights[i][j] - projected_weights[j][i]).detach().numpy()
            if j < first_cluster_end:
                in_edges += 1
                in_cluster_proj_norm_round += proj_norm
                in_cluster_proj_diff_round += diff_norm
            else:
                out_edges += 1
                out_cluster_proj_norm_round += proj_norm
                out_cluster_proj_diff_round += diff_norm
    # max(..., 1) avoids a ZeroDivisionError if the cluster has no edge of a kind;
    # the corresponding sum is 0 in that case.
    in_cluster_proj_norm.append(in_cluster_proj_norm_round / max(in_edges, 1))
    out_cluster_proj_norm.append(out_cluster_proj_norm_round / max(out_edges, 1))
    in_cluster_proj_diff_norm.append(in_cluster_proj_diff_round / max(in_edges, 1))
    out_cluster_proj_diff_norm.append(out_cluster_proj_diff_round / max(out_edges, 1))

    # --- 5. Held-out evaluation, overall and per task --------------------------
    local_test_loss = []
    per_task_loss = []

    for k in range(no_users):
        g_loss = testing(models[k], datapoints[k], 50, criterion)
        local_test_loss.append(g_loss)

        # Record this agent BEFORE closing the task window. The original closed
        # the window first, so each task averaged the wrong ten agents, the first
        # task divided nine values by ten, and the last agent was never counted.
        per_task_loss.append(g_loss)
        if (k + 1) % 10 == 0:
            task_loss[str(k // 10)].append(sum(per_task_loss) / len(per_task_loss))
            per_task_loss = []

    g_loss = sum(local_test_loss) / len(local_test_loss)

    test_loss.append(g_loss)
    # NOTE: despite the label, the printed value is the mean held-out (test) loss.
    print("Training_loss %2.5f"% (test_loss[-1]))

  0%|          | 1/500 [00:01<08:31,  1.03s/it]

Training_loss 96.52692


  0%|          | 2/500 [00:01<08:11,  1.01it/s]

Training_loss 88.24042


  1%|          | 3/500 [00:02<08:11,  1.01it/s]

Training_loss 80.92075


  1%|          | 4/500 [00:04<08:20,  1.01s/it]

Training_loss 74.09452


  1%|          | 5/500 [00:05<08:23,  1.02s/it]

Training_loss 68.20425


  1%|          | 6/500 [00:05<08:10,  1.01it/s]

Training_loss 62.74623


  1%|▏         | 7/500 [00:06<08:02,  1.02it/s]

Training_loss 57.88087


  2%|▏         | 8/500 [00:07<08:00,  1.02it/s]

Training_loss 53.62113


  2%|▏         | 9/500 [00:08<07:55,  1.03it/s]

Training_loss 49.77535


  2%|▏         | 10/500 [00:09<08:18,  1.02s/it]

Training_loss 46.22095


  2%|▏         | 11/500 [00:11<08:20,  1.02s/it]

Training_loss 43.26291


  2%|▏         | 12/500 [00:12<08:14,  1.01s/it]

Training_loss 40.36154


  3%|▎         | 13/500 [00:13<08:08,  1.00s/it]

Training_loss 37.86119


  3%|▎         | 14/500 [00:13<08:03,  1.01it/s]

Training_loss 35.51697


  3%|▎         | 15/500 [00:14<07:52,  1.03it/s]

Training_loss 33.40331


  3%|▎         | 16/500 [00:15<07:49,  1.03it/s]

Training_loss 31.49765


  3%|▎         | 17/500 [00:16<07:56,  1.01it/s]

Training_loss 29.84443


  4%|▎         | 18/500 [00:17<07:52,  1.02it/s]

Training_loss 28.29955


  4%|▍         | 19/500 [00:18<07:54,  1.01it/s]

Training_loss 26.89806


  4%|▍         | 20/500 [00:19<07:57,  1.01it/s]

Training_loss 25.64248


  4%|▍         | 21/500 [00:20<08:07,  1.02s/it]

Training_loss 24.48926


  4%|▍         | 22/500 [00:21<08:01,  1.01s/it]

Training_loss 23.40231


  5%|▍         | 23/500 [00:22<07:50,  1.01it/s]

Training_loss 22.42820


  5%|▍         | 24/500 [00:23<07:42,  1.03it/s]

Training_loss 21.54532


  5%|▌         | 25/500 [00:24<07:41,  1.03it/s]

Training_loss 20.71339


  5%|▌         | 26/500 [00:25<08:04,  1.02s/it]

Training_loss 20.01520


  5%|▌         | 27/500 [00:26<08:12,  1.04s/it]

Training_loss 19.40723


  6%|▌         | 28/500 [00:28<08:13,  1.05s/it]

Training_loss 18.79964


  6%|▌         | 29/500 [00:29<07:59,  1.02s/it]

Training_loss 18.27291


  6%|▌         | 30/500 [00:29<07:47,  1.01it/s]

Training_loss 17.72186


  6%|▌         | 31/500 [00:31<08:31,  1.09s/it]

Training_loss 17.27516


  6%|▋         | 32/500 [00:32<08:16,  1.06s/it]

Training_loss 16.87788


  7%|▋         | 33/500 [00:33<08:04,  1.04s/it]

Training_loss 16.49411


  7%|▋         | 34/500 [00:34<08:43,  1.12s/it]

Training_loss 16.11219


  7%|▋         | 35/500 [00:35<08:28,  1.09s/it]

Training_loss 15.77456


  7%|▋         | 36/500 [00:36<08:20,  1.08s/it]

Training_loss 15.45231


  7%|▋         | 37/500 [00:37<08:10,  1.06s/it]

Training_loss 15.19542


  8%|▊         | 38/500 [00:38<08:04,  1.05s/it]

Training_loss 14.90682


  8%|▊         | 39/500 [00:39<08:13,  1.07s/it]

Training_loss 14.62693


  8%|▊         | 40/500 [00:40<08:05,  1.05s/it]

Training_loss 14.40819


  8%|▊         | 41/500 [00:42<08:24,  1.10s/it]

Training_loss 14.20120


  8%|▊         | 42/500 [00:43<08:27,  1.11s/it]

Training_loss 13.98941


  9%|▊         | 43/500 [00:44<08:11,  1.07s/it]

Training_loss 13.79011


  9%|▉         | 44/500 [00:45<08:05,  1.06s/it]

Training_loss 13.60271


  9%|▉         | 45/500 [00:46<07:56,  1.05s/it]

Training_loss 13.41995


  9%|▉         | 46/500 [00:47<07:43,  1.02s/it]

Training_loss 13.25755


  9%|▉         | 47/500 [00:48<07:35,  1.01s/it]

Training_loss 13.08118


 10%|▉         | 48/500 [00:49<07:31,  1.00it/s]

Training_loss 12.95223


 10%|▉         | 49/500 [00:50<07:37,  1.02s/it]

Training_loss 12.80835


 10%|█         | 50/500 [00:51<07:33,  1.01s/it]

Training_loss 12.67553


 10%|█         | 51/500 [00:52<07:43,  1.03s/it]

Training_loss 12.54444


 10%|█         | 52/500 [00:53<07:57,  1.07s/it]

Training_loss 12.38893


 11%|█         | 53/500 [00:54<08:20,  1.12s/it]

Training_loss 12.26629


 11%|█         | 54/500 [00:55<08:39,  1.17s/it]

Training_loss 12.15202


 11%|█         | 55/500 [00:57<08:45,  1.18s/it]

Training_loss 12.02949


 11%|█         | 56/500 [00:58<09:36,  1.30s/it]

Training_loss 11.91490


 11%|█▏        | 57/500 [00:59<09:23,  1.27s/it]

Training_loss 11.80343


 12%|█▏        | 58/500 [01:01<09:56,  1.35s/it]

Training_loss 11.71063


 12%|█▏        | 59/500 [01:02<09:22,  1.28s/it]

Training_loss 11.62591


 12%|█▏        | 60/500 [01:03<09:04,  1.24s/it]

Training_loss 11.51725


 12%|█▏        | 61/500 [01:04<08:22,  1.15s/it]

Training_loss 11.42015


 12%|█▏        | 62/500 [01:05<07:57,  1.09s/it]

Training_loss 11.33298


 13%|█▎        | 63/500 [01:06<07:45,  1.06s/it]

Training_loss 11.24433


 13%|█▎        | 64/500 [01:07<07:34,  1.04s/it]

Training_loss 11.17945


 13%|█▎        | 65/500 [01:08<07:24,  1.02s/it]

Training_loss 11.10948


 13%|█▎        | 66/500 [01:09<07:24,  1.02s/it]

Training_loss 11.02775


 13%|█▎        | 67/500 [01:10<07:16,  1.01s/it]

Training_loss 10.93887


 14%|█▎        | 68/500 [01:11<07:14,  1.01s/it]

Training_loss 10.86544


 14%|█▍        | 69/500 [01:12<07:10,  1.00it/s]

Training_loss 10.80898


 14%|█▍        | 70/500 [01:13<07:29,  1.05s/it]

Training_loss 10.74730


 14%|█▍        | 71/500 [01:14<07:54,  1.11s/it]

Training_loss 10.68251


 14%|█▍        | 72/500 [01:15<07:39,  1.07s/it]

Training_loss 10.62592


 15%|█▍        | 73/500 [01:16<07:38,  1.07s/it]

Training_loss 10.55873


 15%|█▍        | 74/500 [01:18<07:47,  1.10s/it]

Training_loss 10.49687


 15%|█▌        | 75/500 [01:19<07:38,  1.08s/it]

Training_loss 10.43436


 15%|█▌        | 76/500 [01:20<07:26,  1.05s/it]

Training_loss 10.37429


 15%|█▌        | 77/500 [01:21<07:24,  1.05s/it]

Training_loss 10.31834


 16%|█▌        | 78/500 [01:22<07:22,  1.05s/it]

Training_loss 10.25594


 16%|█▌        | 79/500 [01:23<07:15,  1.03s/it]

Training_loss 10.19213


 16%|█▌        | 80/500 [01:24<07:14,  1.03s/it]

Training_loss 10.14413


 16%|█▌        | 81/500 [01:25<07:13,  1.03s/it]

Training_loss 10.07415


 16%|█▋        | 82/500 [01:26<07:20,  1.05s/it]

Training_loss 10.01985


 17%|█▋        | 83/500 [01:27<07:40,  1.11s/it]

Training_loss 9.97343


 17%|█▋        | 84/500 [01:29<08:31,  1.23s/it]

Training_loss 9.93054


 17%|█▋        | 85/500 [01:30<09:13,  1.33s/it]

Training_loss 9.88739


 17%|█▋        | 86/500 [01:32<09:02,  1.31s/it]

Training_loss 9.83804


 17%|█▋        | 87/500 [01:33<08:31,  1.24s/it]

Training_loss 9.80497


 18%|█▊        | 88/500 [01:34<08:48,  1.28s/it]

Training_loss 9.76764


 18%|█▊        | 89/500 [01:35<08:32,  1.25s/it]

Training_loss 9.73441


 18%|█▊        | 90/500 [01:36<08:10,  1.20s/it]

Training_loss 9.69780


 18%|█▊        | 91/500 [01:37<08:03,  1.18s/it]

Training_loss 9.65031


 18%|█▊        | 92/500 [01:39<08:12,  1.21s/it]

Training_loss 9.60911


 19%|█▊        | 93/500 [01:40<08:06,  1.19s/it]

Training_loss 9.56719


 19%|█▉        | 94/500 [01:41<07:55,  1.17s/it]

Training_loss 9.53515


 19%|█▉        | 95/500 [01:42<07:59,  1.18s/it]

Training_loss 9.49973


 19%|█▉        | 96/500 [01:43<08:01,  1.19s/it]

Training_loss 9.47058


 19%|█▉        | 97/500 [01:45<08:13,  1.22s/it]

Training_loss 9.43260


 20%|█▉        | 98/500 [01:47<09:41,  1.45s/it]

Training_loss 9.40106


 20%|█▉        | 99/500 [01:48<09:36,  1.44s/it]

Training_loss 9.36654


 20%|██        | 100/500 [01:49<09:16,  1.39s/it]

Training_loss 9.32405


 20%|██        | 101/500 [01:50<08:39,  1.30s/it]

Training_loss 9.30559


 20%|██        | 102/500 [01:51<08:13,  1.24s/it]

Training_loss 9.27618


 21%|██        | 103/500 [01:52<07:28,  1.13s/it]

Training_loss 9.24706


 21%|██        | 104/500 [01:53<06:51,  1.04s/it]

Training_loss 9.21032


 21%|██        | 105/500 [01:54<06:30,  1.01it/s]

Training_loss 9.18665


 21%|██        | 106/500 [01:55<06:26,  1.02it/s]

Training_loss 9.15500


 21%|██▏       | 107/500 [01:56<06:08,  1.07it/s]

Training_loss 9.12486


 22%|██▏       | 108/500 [01:57<06:06,  1.07it/s]

Training_loss 9.09544


 22%|██▏       | 109/500 [01:58<06:16,  1.04it/s]

Training_loss 9.06537


 22%|██▏       | 110/500 [01:59<06:17,  1.03it/s]

Training_loss 9.03722


 22%|██▏       | 111/500 [02:00<06:16,  1.03it/s]

Training_loss 9.00971


 22%|██▏       | 112/500 [02:01<06:49,  1.05s/it]

Training_loss 8.98440


 23%|██▎       | 113/500 [02:03<07:49,  1.21s/it]

Training_loss 8.96094


 23%|██▎       | 114/500 [02:04<08:03,  1.25s/it]

Training_loss 8.93917


 23%|██▎       | 115/500 [02:05<07:45,  1.21s/it]

Training_loss 8.91568


 23%|██▎       | 116/500 [02:06<07:38,  1.19s/it]

Training_loss 8.88804


 23%|██▎       | 117/500 [02:07<07:33,  1.18s/it]

Training_loss 8.86414


 24%|██▎       | 118/500 [02:09<07:37,  1.20s/it]

Training_loss 8.83848


 24%|██▍       | 119/500 [02:10<07:33,  1.19s/it]

Training_loss 8.81344


 24%|██▍       | 120/500 [02:11<06:57,  1.10s/it]

Training_loss 8.79501


 24%|██▍       | 121/500 [02:11<06:24,  1.02s/it]

Training_loss 8.76936


 24%|██▍       | 122/500 [02:12<06:10,  1.02it/s]

Training_loss 8.74936


 25%|██▍       | 123/500 [02:13<06:05,  1.03it/s]

Training_loss 8.73124


 25%|██▍       | 124/500 [02:14<06:08,  1.02it/s]

Training_loss 8.70823


 25%|██▌       | 125/500 [02:15<05:52,  1.06it/s]

Training_loss 8.67570


 25%|██▌       | 126/500 [02:16<05:39,  1.10it/s]

Training_loss 8.66208


 25%|██▌       | 127/500 [02:17<05:51,  1.06it/s]

Training_loss 8.64311


 26%|██▌       | 128/500 [02:18<06:04,  1.02it/s]

Training_loss 8.62670


 26%|██▌       | 129/500 [02:19<05:56,  1.04it/s]

Training_loss 8.60242


 26%|██▌       | 130/500 [02:20<05:51,  1.05it/s]

Training_loss 8.58841


 26%|██▌       | 131/500 [02:21<05:41,  1.08it/s]

Training_loss 8.57415


 26%|██▋       | 132/500 [02:22<05:46,  1.06it/s]

Training_loss 8.55716


 27%|██▋       | 133/500 [02:23<05:42,  1.07it/s]

Training_loss 8.53676


 27%|██▋       | 134/500 [02:24<05:41,  1.07it/s]

Training_loss 8.51944


 27%|██▋       | 135/500 [02:25<05:34,  1.09it/s]

Training_loss 8.50121


 27%|██▋       | 136/500 [02:25<05:40,  1.07it/s]

Training_loss 8.48532


 27%|██▋       | 137/500 [02:26<05:30,  1.10it/s]

Training_loss 8.46451


 28%|██▊       | 138/500 [02:27<05:27,  1.11it/s]

Training_loss 8.45544


 28%|██▊       | 139/500 [02:28<05:24,  1.11it/s]

Training_loss 8.43212


 28%|██▊       | 140/500 [02:29<05:18,  1.13it/s]

Training_loss 8.41758


 28%|██▊       | 141/500 [02:30<05:19,  1.12it/s]

Training_loss 8.40653


 28%|██▊       | 142/500 [02:31<05:18,  1.12it/s]

Training_loss 8.38306


 29%|██▊       | 143/500 [02:32<05:24,  1.10it/s]

Training_loss 8.37011


 29%|██▉       | 144/500 [02:33<05:47,  1.03it/s]

Training_loss 8.35886


 29%|██▉       | 145/500 [02:34<06:29,  1.10s/it]

Training_loss 8.33235


 29%|██▉       | 146/500 [02:35<06:36,  1.12s/it]

Training_loss 8.31159


 29%|██▉       | 147/500 [02:36<06:31,  1.11s/it]

Training_loss 8.29727


 30%|██▉       | 148/500 [02:37<06:11,  1.05s/it]

Training_loss 8.28376


 30%|██▉       | 149/500 [02:38<05:57,  1.02s/it]

Training_loss 8.26219


 30%|███       | 150/500 [02:39<05:49,  1.00it/s]

Training_loss 8.24996


 30%|███       | 151/500 [02:40<05:37,  1.04it/s]

Training_loss 8.24171


 30%|███       | 152/500 [02:41<05:32,  1.05it/s]

Training_loss 8.22164


 31%|███       | 153/500 [02:42<05:25,  1.07it/s]

Training_loss 8.21224


 31%|███       | 154/500 [02:43<05:34,  1.03it/s]

Training_loss 8.20719


 31%|███       | 155/500 [02:44<05:25,  1.06it/s]

Training_loss 8.20149


 31%|███       | 156/500 [02:45<05:23,  1.06it/s]

Training_loss 8.18520


 31%|███▏      | 157/500 [02:46<05:18,  1.08it/s]

Training_loss 8.17216


 32%|███▏      | 158/500 [02:47<05:20,  1.07it/s]

Training_loss 8.15339


 32%|███▏      | 159/500 [02:48<05:20,  1.06it/s]

Training_loss 8.13836


 32%|███▏      | 160/500 [02:49<05:16,  1.07it/s]

Training_loss 8.12379


 32%|███▏      | 161/500 [02:50<05:48,  1.03s/it]

Training_loss 8.11285


 32%|███▏      | 162/500 [02:51<05:44,  1.02s/it]

Training_loss 8.10539


 33%|███▎      | 163/500 [02:52<05:36,  1.00it/s]

Training_loss 8.08810


 33%|███▎      | 164/500 [02:53<05:28,  1.02it/s]

Training_loss 8.07966


 33%|███▎      | 165/500 [02:54<05:25,  1.03it/s]

Training_loss 8.07122


 33%|███▎      | 166/500 [02:55<05:23,  1.03it/s]

Training_loss 8.06244


 33%|███▎      | 167/500 [02:56<05:23,  1.03it/s]

Training_loss 8.05360


 34%|███▎      | 168/500 [02:57<05:18,  1.04it/s]

Training_loss 8.04472


 34%|███▍      | 169/500 [02:58<05:16,  1.04it/s]

Training_loss 8.03387


 34%|███▍      | 170/500 [02:58<05:13,  1.05it/s]

Training_loss 8.01872


 34%|███▍      | 171/500 [02:59<05:12,  1.05it/s]

Training_loss 8.01174


 34%|███▍      | 172/500 [03:00<05:18,  1.03it/s]

Training_loss 7.99937


 35%|███▍      | 173/500 [03:01<05:20,  1.02it/s]

Training_loss 7.99348


 35%|███▍      | 174/500 [03:03<05:40,  1.04s/it]

Training_loss 7.98056


 35%|███▌      | 175/500 [03:04<05:33,  1.03s/it]

Training_loss 7.97226


 35%|███▌      | 176/500 [03:05<05:29,  1.02s/it]

Training_loss 7.96353


 35%|███▌      | 177/500 [03:06<06:02,  1.12s/it]

Training_loss 7.95308


 36%|███▌      | 178/500 [03:07<05:53,  1.10s/it]

Training_loss 7.94184


 36%|███▌      | 179/500 [03:08<05:47,  1.08s/it]

Training_loss 7.93382


 36%|███▌      | 180/500 [03:09<05:39,  1.06s/it]

Training_loss 7.92207


 36%|███▌      | 181/500 [03:10<05:28,  1.03s/it]

Training_loss 7.91272


 36%|███▋      | 182/500 [03:11<05:22,  1.01s/it]

Training_loss 7.90436


 37%|███▋      | 183/500 [03:12<05:22,  1.02s/it]

Training_loss 7.88965


 37%|███▋      | 184/500 [03:13<05:18,  1.01s/it]

Training_loss 7.88427


 37%|███▋      | 185/500 [03:14<05:12,  1.01it/s]

Training_loss 7.87760


 37%|███▋      | 186/500 [03:15<05:10,  1.01it/s]

Training_loss 7.86726


 37%|███▋      | 187/500 [03:16<05:10,  1.01it/s]

Training_loss 7.86189


 38%|███▊      | 188/500 [03:17<05:08,  1.01it/s]

Training_loss 7.85268


 38%|███▊      | 189/500 [03:18<05:08,  1.01it/s]

Training_loss 7.83960


 38%|███▊      | 190/500 [03:19<05:07,  1.01it/s]

Training_loss 7.83052


 38%|███▊      | 191/500 [03:20<05:05,  1.01it/s]

Training_loss 7.82195


 38%|███▊      | 192/500 [03:21<05:11,  1.01s/it]

Training_loss 7.80844


 39%|███▊      | 193/500 [03:22<05:39,  1.11s/it]

Training_loss 7.80087


 39%|███▉      | 194/500 [03:23<05:34,  1.09s/it]

Training_loss 7.79455


 39%|███▉      | 195/500 [03:24<05:22,  1.06s/it]

Training_loss 7.78612


 39%|███▉      | 196/500 [03:25<05:14,  1.03s/it]

Training_loss 7.78044


 39%|███▉      | 197/500 [03:26<05:06,  1.01s/it]

Training_loss 7.77296


 40%|███▉      | 198/500 [03:27<05:04,  1.01s/it]

Training_loss 7.76637


 40%|███▉      | 199/500 [03:28<05:09,  1.03s/it]

Training_loss 7.75894


 40%|████      | 200/500 [03:29<05:08,  1.03s/it]

Training_loss 7.75484


 40%|████      | 201/500 [03:30<05:05,  1.02s/it]

Training_loss 7.75052


 40%|████      | 202/500 [03:31<05:02,  1.02s/it]

Training_loss 7.74439


 41%|████      | 203/500 [03:32<05:04,  1.02s/it]

Training_loss 7.73837


 41%|████      | 204/500 [03:33<05:08,  1.04s/it]

Training_loss 7.72806


 41%|████      | 205/500 [03:35<05:08,  1.05s/it]

Training_loss 7.72407


 41%|████      | 206/500 [03:36<05:16,  1.08s/it]

Training_loss 7.71667


 41%|████▏     | 207/500 [03:37<05:16,  1.08s/it]

Training_loss 7.71172


 42%|████▏     | 208/500 [03:38<05:32,  1.14s/it]

Training_loss 7.70326


 42%|████▏     | 209/500 [03:39<05:25,  1.12s/it]

Training_loss 7.69988


 42%|████▏     | 210/500 [03:40<05:13,  1.08s/it]

Training_loss 7.69283


 42%|████▏     | 211/500 [03:41<05:05,  1.06s/it]

Training_loss 7.68685


 42%|████▏     | 212/500 [03:42<04:59,  1.04s/it]

Training_loss 7.67763


 43%|████▎     | 213/500 [03:43<04:54,  1.03s/it]

Training_loss 7.67065


 43%|████▎     | 214/500 [03:44<04:46,  1.00s/it]

Training_loss 7.66355


 43%|████▎     | 215/500 [03:45<04:46,  1.01s/it]

Training_loss 7.65881


 43%|████▎     | 216/500 [03:46<04:42,  1.01it/s]

Training_loss 7.65421


 43%|████▎     | 217/500 [03:47<04:38,  1.02it/s]

Training_loss 7.64933


 44%|████▎     | 218/500 [03:48<04:37,  1.02it/s]

Training_loss 7.64321


 44%|████▍     | 219/500 [03:49<04:38,  1.01it/s]

Training_loss 7.63665


 44%|████▍     | 220/500 [03:50<04:47,  1.03s/it]

Training_loss 7.63467


 44%|████▍     | 221/500 [03:51<04:47,  1.03s/it]

Training_loss 7.63484


 44%|████▍     | 222/500 [03:52<04:43,  1.02s/it]

Training_loss 7.62803


 45%|████▍     | 223/500 [03:53<04:55,  1.07s/it]

Training_loss 7.62237


 45%|████▍     | 224/500 [03:55<05:09,  1.12s/it]

Training_loss 7.62035


 45%|████▌     | 225/500 [03:56<04:55,  1.08s/it]

Training_loss 7.61375


 45%|████▌     | 226/500 [03:57<04:49,  1.06s/it]

Training_loss 7.60623


 45%|████▌     | 227/500 [03:58<04:45,  1.04s/it]

Training_loss 7.59838


 46%|████▌     | 228/500 [03:59<04:39,  1.03s/it]

Training_loss 7.59156


 46%|████▌     | 229/500 [04:00<04:33,  1.01s/it]

Training_loss 7.58610


 46%|████▌     | 230/500 [04:01<04:31,  1.00s/it]

Training_loss 7.57903


 46%|████▌     | 231/500 [04:01<04:26,  1.01it/s]

Training_loss 7.57322


 46%|████▋     | 232/500 [04:03<04:56,  1.11s/it]

Training_loss 7.56570


 47%|████▋     | 233/500 [04:04<05:12,  1.17s/it]

Training_loss 7.55997


 47%|████▋     | 234/500 [04:05<04:58,  1.12s/it]

Training_loss 7.55591


 47%|████▋     | 235/500 [04:06<04:46,  1.08s/it]

Training_loss 7.55281


 47%|████▋     | 236/500 [04:07<04:35,  1.04s/it]

Training_loss 7.54874


 47%|████▋     | 237/500 [04:08<04:27,  1.02s/it]

Training_loss 7.53931


 48%|████▊     | 238/500 [04:09<04:41,  1.08s/it]

Training_loss 7.53006


 48%|████▊     | 239/500 [04:11<04:55,  1.13s/it]

Training_loss 7.52216


 48%|████▊     | 240/500 [04:12<04:46,  1.10s/it]

Training_loss 7.51759


 48%|████▊     | 241/500 [04:13<04:36,  1.07s/it]

Training_loss 7.51844


 48%|████▊     | 242/500 [04:14<04:31,  1.05s/it]

Training_loss 7.51355


 49%|████▊     | 243/500 [04:15<04:23,  1.02s/it]

Training_loss 7.51386


 49%|████▉     | 244/500 [04:16<04:18,  1.01s/it]

Training_loss 7.51007


 49%|████▉     | 245/500 [04:16<04:13,  1.00it/s]

Training_loss 7.50293


 49%|████▉     | 246/500 [04:17<04:12,  1.00it/s]

Training_loss 7.50107


 49%|████▉     | 247/500 [04:18<04:10,  1.01it/s]

Training_loss 7.49383


 50%|████▉     | 248/500 [04:19<04:09,  1.01it/s]

Training_loss 7.48901


 50%|████▉     | 249/500 [04:20<04:05,  1.02it/s]

Training_loss 7.48013


 50%|█████     | 250/500 [04:21<04:08,  1.00it/s]

Training_loss 7.48029


 50%|█████     | 251/500 [04:22<04:07,  1.01it/s]

Training_loss 7.47404


 50%|█████     | 252/500 [04:23<04:05,  1.01it/s]

Training_loss 7.47162


 51%|█████     | 253/500 [04:24<04:03,  1.01it/s]

Training_loss 7.46821


 51%|█████     | 254/500 [04:26<04:24,  1.07s/it]

Training_loss 7.46167


 51%|█████     | 255/500 [04:27<04:25,  1.08s/it]

Training_loss 7.45621


 51%|█████     | 256/500 [04:28<04:17,  1.05s/it]

Training_loss 7.45189


 51%|█████▏    | 257/500 [04:29<04:08,  1.02s/it]

Training_loss 7.44731


 52%|█████▏    | 258/500 [04:30<04:08,  1.03s/it]

Training_loss 7.44077


 52%|█████▏    | 259/500 [04:31<04:05,  1.02s/it]

Training_loss 7.43772


 52%|█████▏    | 260/500 [04:32<04:02,  1.01s/it]

Training_loss 7.43622


 52%|█████▏    | 261/500 [04:33<04:01,  1.01s/it]

Training_loss 7.42816


 52%|█████▏    | 262/500 [04:34<04:07,  1.04s/it]

Training_loss 7.41486


 53%|█████▎    | 263/500 [04:35<04:05,  1.04s/it]

Training_loss 7.41230


 53%|█████▎    | 264/500 [04:36<03:56,  1.00s/it]

Training_loss 7.40738


 53%|█████▎    | 265/500 [04:37<03:50,  1.02it/s]

Training_loss 7.40070


 53%|█████▎    | 266/500 [04:38<03:48,  1.02it/s]

Training_loss 7.39693


 53%|█████▎    | 267/500 [04:39<03:50,  1.01it/s]

Training_loss 7.39379


 54%|█████▎    | 268/500 [04:40<03:49,  1.01it/s]

Training_loss 7.38788


 54%|█████▍    | 269/500 [04:41<03:48,  1.01it/s]

Training_loss 7.38323


 54%|█████▍    | 270/500 [04:42<04:15,  1.11s/it]

Training_loss 7.37682


 54%|█████▍    | 271/500 [04:43<04:09,  1.09s/it]

Training_loss 7.37410


 54%|█████▍    | 272/500 [04:44<04:03,  1.07s/it]

Training_loss 7.37433


 55%|█████▍    | 273/500 [04:45<03:55,  1.04s/it]

Training_loss 7.37185


 55%|█████▍    | 274/500 [04:46<03:53,  1.03s/it]

Training_loss 7.36741


 55%|█████▌    | 275/500 [04:47<03:47,  1.01s/it]

Training_loss 7.36053


 55%|█████▌    | 276/500 [04:48<03:44,  1.00s/it]

Training_loss 7.35498


 55%|█████▌    | 277/500 [04:49<03:43,  1.00s/it]

Training_loss 7.35356


 56%|█████▌    | 278/500 [04:50<03:42,  1.00s/it]

Training_loss 7.34195


 56%|█████▌    | 279/500 [04:51<03:41,  1.00s/it]

Training_loss 7.33733


 56%|█████▌    | 280/500 [04:52<03:39,  1.00it/s]

Training_loss 7.33298


 56%|█████▌    | 281/500 [04:53<03:36,  1.01it/s]

Training_loss 7.33387


 56%|█████▋    | 282/500 [04:54<03:35,  1.01it/s]

Training_loss 7.32942


 57%|█████▋    | 283/500 [04:55<03:33,  1.02it/s]

Training_loss 7.32744


 57%|█████▋    | 284/500 [04:56<03:33,  1.01it/s]

Training_loss 7.32591


 57%|█████▋    | 285/500 [04:57<03:43,  1.04s/it]

Training_loss 7.32174


 57%|█████▋    | 286/500 [04:58<04:01,  1.13s/it]

Training_loss 7.31855


 57%|█████▋    | 287/500 [04:59<03:51,  1.09s/it]

Training_loss 7.31611


 58%|█████▊    | 288/500 [05:00<03:43,  1.05s/it]

Training_loss 7.31584


 58%|█████▊    | 289/500 [05:01<03:36,  1.02s/it]

Training_loss 7.31191


 58%|█████▊    | 290/500 [05:02<03:35,  1.03s/it]

Training_loss 7.30750


 58%|█████▊    | 291/500 [05:04<03:47,  1.09s/it]

Training_loss 7.29991


 58%|█████▊    | 292/500 [05:05<03:42,  1.07s/it]

Training_loss 7.29815


 59%|█████▊    | 293/500 [05:06<03:36,  1.04s/it]

Training_loss 7.29837


 59%|█████▉    | 294/500 [05:07<03:38,  1.06s/it]

Training_loss 7.29311


 59%|█████▉    | 295/500 [05:08<03:31,  1.03s/it]

Training_loss 7.28960


 59%|█████▉    | 296/500 [05:09<03:27,  1.02s/it]

Training_loss 7.28752


 59%|█████▉    | 297/500 [05:10<03:23,  1.00s/it]

Training_loss 7.28249


 60%|█████▉    | 298/500 [05:11<03:21,  1.00it/s]

Training_loss 7.27921


 60%|█████▉    | 299/500 [05:12<03:18,  1.01it/s]

Training_loss 7.27958


 60%|██████    | 300/500 [05:13<03:17,  1.01it/s]

Training_loss 7.27703


 60%|██████    | 301/500 [05:14<03:36,  1.09s/it]

Training_loss 7.27651


 60%|██████    | 302/500 [05:15<03:35,  1.09s/it]

Training_loss 7.27512


 61%|██████    | 303/500 [05:16<03:27,  1.06s/it]

Training_loss 7.27000


 61%|██████    | 304/500 [05:17<03:21,  1.03s/it]

Training_loss 7.26594


 61%|██████    | 305/500 [05:18<03:17,  1.01s/it]

Training_loss 7.26647


 61%|██████    | 306/500 [05:19<03:14,  1.00s/it]

Training_loss 7.25981


 61%|██████▏   | 307/500 [05:20<03:11,  1.01it/s]

Training_loss 7.25586


 62%|██████▏   | 308/500 [05:21<03:08,  1.02it/s]

Training_loss 7.25333


 62%|██████▏   | 309/500 [05:22<03:07,  1.02it/s]

Training_loss 7.25199


 62%|██████▏   | 310/500 [05:23<03:06,  1.02it/s]

Training_loss 7.25048


 62%|██████▏   | 311/500 [05:24<03:07,  1.01it/s]

Training_loss 7.24262


 62%|██████▏   | 312/500 [05:25<03:06,  1.01it/s]

Training_loss 7.24405


 63%|██████▎   | 313/500 [05:26<03:06,  1.01it/s]

Training_loss 7.24473


 63%|██████▎   | 314/500 [05:27<03:05,  1.00it/s]

Training_loss 7.23985


 63%|██████▎   | 315/500 [05:28<03:04,  1.00it/s]

Training_loss 7.23810


 63%|██████▎   | 316/500 [05:29<03:04,  1.00s/it]

Training_loss 7.22795


 63%|██████▎   | 317/500 [05:30<03:21,  1.10s/it]

Training_loss 7.22895


 64%|██████▎   | 318/500 [05:31<03:16,  1.08s/it]

Training_loss 7.22710


 64%|██████▍   | 319/500 [05:32<03:11,  1.06s/it]

Training_loss 7.23188


 64%|██████▍   | 320/500 [05:33<03:11,  1.06s/it]

Training_loss 7.22727


 64%|██████▍   | 321/500 [05:34<03:09,  1.06s/it]

Training_loss 7.22116


 64%|██████▍   | 322/500 [05:35<03:02,  1.03s/it]

Training_loss 7.21590


 65%|██████▍   | 323/500 [05:36<02:56,  1.00it/s]

Training_loss 7.21450


 65%|██████▍   | 324/500 [05:37<02:52,  1.02it/s]

Training_loss 7.20988


 65%|██████▌   | 325/500 [05:38<02:51,  1.02it/s]

Training_loss 7.20924


 65%|██████▌   | 326/500 [05:39<02:51,  1.02it/s]

Training_loss 7.21177


 65%|██████▌   | 327/500 [05:40<03:07,  1.09s/it]

Training_loss 7.20824


 66%|██████▌   | 328/500 [05:42<03:10,  1.11s/it]

Training_loss 7.20280


 66%|██████▌   | 329/500 [05:43<03:10,  1.11s/it]

Training_loss 7.20188


 66%|██████▌   | 330/500 [05:44<03:30,  1.24s/it]

Training_loss 7.19831


 66%|██████▌   | 331/500 [05:46<03:41,  1.31s/it]

Training_loss 7.19052


 66%|██████▋   | 332/500 [05:47<03:37,  1.29s/it]

Training_loss 7.19021


 67%|██████▋   | 333/500 [05:48<03:27,  1.24s/it]

Training_loss 7.19401


 67%|██████▋   | 334/500 [05:49<03:13,  1.16s/it]

Training_loss 7.19373


 67%|██████▋   | 335/500 [05:50<03:06,  1.13s/it]

Training_loss 7.19419


 67%|██████▋   | 336/500 [05:51<02:58,  1.09s/it]

Training_loss 7.19115


 67%|██████▋   | 337/500 [05:52<02:52,  1.06s/it]

Training_loss 7.19285


 68%|██████▊   | 338/500 [05:53<02:47,  1.03s/it]

Training_loss 7.19021


 68%|██████▊   | 339/500 [05:54<02:46,  1.03s/it]

Training_loss 7.18793


 68%|██████▊   | 340/500 [05:55<02:41,  1.01s/it]

Training_loss 7.18439


 68%|██████▊   | 341/500 [05:56<02:39,  1.00s/it]

Training_loss 7.18070


 68%|██████▊   | 342/500 [05:57<02:37,  1.00it/s]

Training_loss 7.17969


 69%|██████▊   | 343/500 [05:58<02:40,  1.02s/it]

Training_loss 7.17426


 69%|██████▉   | 344/500 [05:59<02:36,  1.01s/it]

Training_loss 7.17739


 69%|██████▉   | 345/500 [06:00<02:34,  1.00it/s]

Training_loss 7.17625


 69%|██████▉   | 346/500 [06:01<02:41,  1.05s/it]

Training_loss 7.17004


 69%|██████▉   | 347/500 [06:03<02:50,  1.11s/it]

Training_loss 7.16979


 70%|██████▉   | 348/500 [06:04<02:56,  1.16s/it]

Training_loss 7.16377


 70%|██████▉   | 349/500 [06:05<02:59,  1.19s/it]

Training_loss 7.16240


 70%|███████   | 350/500 [06:06<02:49,  1.13s/it]

Training_loss 7.16285


 70%|███████   | 351/500 [06:07<02:41,  1.08s/it]

Training_loss 7.16045


 70%|███████   | 352/500 [06:08<02:35,  1.05s/it]

Training_loss 7.16165


 71%|███████   | 353/500 [06:09<02:32,  1.04s/it]

Training_loss 7.15589


 71%|███████   | 354/500 [06:10<02:29,  1.02s/it]

Training_loss 7.15449


 71%|███████   | 355/500 [06:11<02:26,  1.01s/it]

Training_loss 7.15221


 71%|███████   | 356/500 [06:12<02:23,  1.01it/s]

Training_loss 7.14723


 71%|███████▏  | 357/500 [06:13<02:22,  1.00it/s]

Training_loss 7.14195


 72%|███████▏  | 358/500 [06:14<02:20,  1.01it/s]

Training_loss 7.13979


 72%|███████▏  | 359/500 [06:15<02:19,  1.01it/s]

Training_loss 7.14036


 72%|███████▏  | 360/500 [06:16<02:18,  1.01it/s]

Training_loss 7.13946


 72%|███████▏  | 361/500 [06:17<02:22,  1.02s/it]

Training_loss 7.13827


 72%|███████▏  | 362/500 [06:18<02:36,  1.14s/it]

Training_loss 7.13474


 73%|███████▎  | 363/500 [06:19<02:30,  1.10s/it]

Training_loss 7.13518


 73%|███████▎  | 364/500 [06:20<02:26,  1.08s/it]

Training_loss 7.13335


 73%|███████▎  | 365/500 [06:21<02:22,  1.05s/it]

Training_loss 7.13358


 73%|███████▎  | 366/500 [06:22<02:18,  1.03s/it]

Training_loss 7.13115


 73%|███████▎  | 367/500 [06:23<02:15,  1.02s/it]

Training_loss 7.12777


 74%|███████▎  | 368/500 [06:24<02:16,  1.03s/it]

Training_loss 7.12587


 74%|███████▍  | 369/500 [06:26<02:14,  1.03s/it]

Training_loss 7.12502


 74%|███████▍  | 370/500 [06:27<02:12,  1.02s/it]

Training_loss 7.12326


 74%|███████▍  | 371/500 [06:27<02:10,  1.01s/it]

Training_loss 7.12069


 74%|███████▍  | 372/500 [06:28<02:07,  1.00it/s]

Training_loss 7.11716


 75%|███████▍  | 373/500 [06:29<02:04,  1.02it/s]

Training_loss 7.11532


 75%|███████▍  | 374/500 [06:30<02:04,  1.01it/s]

Training_loss 7.11291


 75%|███████▌  | 375/500 [06:31<02:03,  1.01it/s]

Training_loss 7.11271


 75%|███████▌  | 376/500 [06:32<02:02,  1.01it/s]

Training_loss 7.10823


 75%|███████▌  | 377/500 [06:34<02:25,  1.18s/it]

Training_loss 7.11016


 76%|███████▌  | 378/500 [06:35<02:31,  1.24s/it]

Training_loss 7.11051


 76%|███████▌  | 379/500 [06:37<02:30,  1.24s/it]

Training_loss 7.10917


 76%|███████▌  | 380/500 [06:38<02:26,  1.22s/it]

Training_loss 7.10811


 76%|███████▌  | 381/500 [06:39<02:30,  1.27s/it]

Training_loss 7.10866


 76%|███████▋  | 382/500 [06:41<02:31,  1.28s/it]

Training_loss 7.10573


 77%|███████▋  | 383/500 [06:42<02:26,  1.25s/it]

Training_loss 7.10205


 77%|███████▋  | 384/500 [06:43<02:25,  1.25s/it]

Training_loss 7.09672


 77%|███████▋  | 385/500 [06:44<02:21,  1.23s/it]

Training_loss 7.09869


 77%|███████▋  | 386/500 [06:45<02:17,  1.21s/it]

Training_loss 7.09789


 77%|███████▋  | 387/500 [06:46<02:15,  1.20s/it]

Training_loss 7.09620


 78%|███████▊  | 388/500 [06:48<02:13,  1.19s/it]

Training_loss 7.09397


 78%|███████▊  | 389/500 [06:49<02:09,  1.17s/it]

Training_loss 7.08809


 78%|███████▊  | 390/500 [06:50<02:14,  1.22s/it]

Training_loss 7.08511


 78%|███████▊  | 391/500 [06:51<02:07,  1.17s/it]

Training_loss 7.08435


 78%|███████▊  | 392/500 [06:52<02:01,  1.13s/it]

Training_loss 7.08183


 79%|███████▊  | 393/500 [06:53<01:56,  1.09s/it]

Training_loss 7.07650


 79%|███████▉  | 394/500 [06:54<01:56,  1.10s/it]

Training_loss 7.07496


 79%|███████▉  | 395/500 [06:55<01:54,  1.09s/it]

Training_loss 7.07653


 79%|███████▉  | 396/500 [06:57<02:02,  1.17s/it]

Training_loss 7.07388


 79%|███████▉  | 397/500 [06:58<02:14,  1.30s/it]

Training_loss 7.07483


 80%|███████▉  | 398/500 [07:00<02:13,  1.31s/it]

Training_loss 7.07354


 80%|███████▉  | 399/500 [07:01<02:07,  1.27s/it]

Training_loss 7.07636


 80%|████████  | 400/500 [07:02<02:13,  1.33s/it]

Training_loss 7.07598


 80%|████████  | 401/500 [07:05<02:42,  1.64s/it]

Training_loss 7.07610


 80%|████████  | 402/500 [07:07<03:08,  1.93s/it]

Training_loss 7.07312


 81%|████████  | 403/500 [07:10<03:27,  2.14s/it]

Training_loss 7.06900


 81%|████████  | 404/500 [07:12<03:26,  2.15s/it]

Training_loss 7.06732


 81%|████████  | 405/500 [07:14<03:11,  2.02s/it]

Training_loss 7.06378


 81%|████████  | 406/500 [07:16<03:09,  2.02s/it]

Training_loss 7.06074


 81%|████████▏ | 407/500 [07:18<03:24,  2.20s/it]

Training_loss 7.06066


 82%|████████▏ | 408/500 [07:21<03:20,  2.17s/it]

Training_loss 7.05897


 82%|████████▏ | 409/500 [07:23<03:31,  2.32s/it]

Training_loss 7.06095


 82%|████████▏ | 410/500 [07:25<03:21,  2.24s/it]

Training_loss 7.06084


 82%|████████▏ | 411/500 [07:26<02:52,  1.94s/it]

Training_loss 7.06153


 82%|████████▏ | 412/500 [07:28<02:26,  1.66s/it]

Training_loss 7.05997


 83%|████████▎ | 413/500 [07:29<02:11,  1.51s/it]

Training_loss 7.06131


 83%|████████▎ | 414/500 [07:30<02:01,  1.42s/it]

Training_loss 7.06324


 83%|████████▎ | 415/500 [07:31<01:56,  1.37s/it]

Training_loss 7.06238


 83%|████████▎ | 416/500 [07:33<02:00,  1.44s/it]

Training_loss 7.06481


 83%|████████▎ | 417/500 [07:36<02:51,  2.06s/it]

Training_loss 7.06277


 84%|████████▎ | 418/500 [07:40<03:28,  2.55s/it]

Training_loss 7.05947


 84%|████████▍ | 419/500 [07:43<03:29,  2.59s/it]

Training_loss 7.05815


 84%|████████▍ | 420/500 [07:44<02:54,  2.18s/it]

Training_loss 7.05620


 84%|████████▍ | 421/500 [07:47<03:07,  2.37s/it]

Training_loss 7.05574


 84%|████████▍ | 422/500 [07:48<02:43,  2.09s/it]

Training_loss 7.05358


 85%|████████▍ | 423/500 [07:51<03:10,  2.48s/it]

Training_loss 7.05643


 85%|████████▍ | 424/500 [07:54<03:01,  2.39s/it]

Training_loss 7.05870


 85%|████████▌ | 425/500 [07:55<02:45,  2.21s/it]

Training_loss 7.05590


 85%|████████▌ | 426/500 [07:57<02:34,  2.09s/it]

Training_loss 7.05228


 85%|████████▌ | 427/500 [07:59<02:24,  1.99s/it]

Training_loss 7.04896


 86%|████████▌ | 428/500 [08:00<02:10,  1.82s/it]

Training_loss 7.04498


 86%|████████▌ | 429/500 [08:02<01:58,  1.67s/it]

Training_loss 7.04487


 86%|████████▌ | 430/500 [08:03<01:47,  1.54s/it]

Training_loss 7.04292


 86%|████████▌ | 431/500 [08:04<01:44,  1.51s/it]

Training_loss 7.04139


 86%|████████▋ | 432/500 [08:06<01:35,  1.41s/it]

Training_loss 7.03940


 87%|████████▋ | 433/500 [08:08<02:00,  1.80s/it]

Training_loss 7.04077


 87%|████████▋ | 434/500 [08:12<02:31,  2.29s/it]

Training_loss 7.03914


 87%|████████▋ | 435/500 [08:15<02:57,  2.73s/it]

Training_loss 7.03053


 87%|████████▋ | 436/500 [08:17<02:23,  2.25s/it]

Training_loss 7.03014


 87%|████████▋ | 437/500 [08:18<02:03,  1.96s/it]

Training_loss 7.02690


 88%|████████▊ | 438/500 [08:19<01:45,  1.70s/it]

Training_loss 7.02293


 88%|████████▊ | 439/500 [08:20<01:31,  1.50s/it]

Training_loss 7.02228


 88%|████████▊ | 440/500 [08:21<01:20,  1.34s/it]

Training_loss 7.02209


 88%|████████▊ | 441/500 [08:22<01:13,  1.24s/it]

Training_loss 7.01717


 88%|████████▊ | 442/500 [08:23<01:07,  1.17s/it]

Training_loss 7.01601


 89%|████████▊ | 443/500 [08:24<01:03,  1.11s/it]

Training_loss 7.01718


 89%|████████▉ | 444/500 [08:25<01:03,  1.14s/it]

Training_loss 7.01579


 89%|████████▉ | 445/500 [08:26<01:04,  1.18s/it]

Training_loss 7.01713


 89%|████████▉ | 446/500 [08:27<01:00,  1.12s/it]

Training_loss 7.01497


 89%|████████▉ | 447/500 [08:29<00:58,  1.10s/it]

Training_loss 7.01511


 90%|████████▉ | 448/500 [08:29<00:55,  1.07s/it]

Training_loss 7.01343


 90%|████████▉ | 449/500 [08:31<00:54,  1.06s/it]

Training_loss 7.01012


 90%|█████████ | 450/500 [08:32<00:52,  1.04s/it]

Training_loss 7.01292


 90%|█████████ | 451/500 [08:33<00:50,  1.03s/it]

Training_loss 7.01020


 90%|█████████ | 452/500 [08:34<00:49,  1.03s/it]

Training_loss 7.01307


 91%|█████████ | 453/500 [08:35<00:55,  1.18s/it]

Training_loss 7.01655


 91%|█████████ | 454/500 [08:37<01:00,  1.32s/it]

Training_loss 7.01530


 91%|█████████ | 455/500 [08:38<00:59,  1.32s/it]

Training_loss 7.01673


 91%|█████████ | 456/500 [08:40<01:00,  1.36s/it]

Training_loss 7.00900


 91%|█████████▏| 457/500 [08:43<01:19,  1.85s/it]

Training_loss 7.00619


 92%|█████████▏| 458/500 [08:46<01:39,  2.38s/it]

Training_loss 7.00487


 92%|█████████▏| 459/500 [08:49<01:39,  2.42s/it]

Training_loss 7.00250


 92%|█████████▏| 460/500 [08:50<01:22,  2.05s/it]

Training_loss 7.00182


 92%|█████████▏| 461/500 [08:52<01:22,  2.13s/it]

Training_loss 7.00217


 92%|█████████▏| 462/500 [08:54<01:13,  1.92s/it]

Training_loss 7.00325


 93%|█████████▎| 463/500 [08:55<01:01,  1.66s/it]

Training_loss 6.99815


 93%|█████████▎| 464/500 [08:56<00:53,  1.48s/it]

Training_loss 6.99961


 93%|█████████▎| 465/500 [08:57<00:47,  1.36s/it]

Training_loss 7.00100


 93%|█████████▎| 466/500 [08:58<00:47,  1.39s/it]

Training_loss 6.99260


 93%|█████████▎| 467/500 [09:00<00:48,  1.47s/it]

Training_loss 6.99215


 94%|█████████▎| 468/500 [09:02<00:54,  1.70s/it]

Training_loss 6.98996


 94%|█████████▍| 469/500 [09:04<00:55,  1.80s/it]

Training_loss 6.99159


 94%|█████████▍| 470/500 [09:06<00:54,  1.81s/it]

Training_loss 6.99081


 94%|█████████▍| 471/500 [09:08<00:54,  1.88s/it]

Training_loss 6.98748


 94%|█████████▍| 472/500 [09:09<00:48,  1.72s/it]

Training_loss 6.98873


 95%|█████████▍| 473/500 [09:11<00:41,  1.55s/it]

Training_loss 6.98639


 95%|█████████▍| 474/500 [09:12<00:39,  1.51s/it]

Training_loss 6.98184


 95%|█████████▌| 475/500 [09:14<00:41,  1.66s/it]

Training_loss 6.98083


 95%|█████████▌| 476/500 [09:16<00:40,  1.70s/it]

Training_loss 6.98097


 95%|█████████▌| 477/500 [09:17<00:38,  1.67s/it]

Training_loss 6.98085


 96%|█████████▌| 478/500 [09:20<00:43,  1.98s/it]

Training_loss 6.98097


 96%|█████████▌| 479/500 [09:24<00:54,  2.61s/it]

Training_loss 6.98301


 96%|█████████▌| 480/500 [09:28<00:59,  2.99s/it]

Training_loss 6.98404


 96%|█████████▌| 481/500 [09:31<00:54,  2.88s/it]

Training_loss 6.98754


 96%|█████████▋| 482/500 [09:33<00:48,  2.67s/it]

Training_loss 6.98764


 97%|█████████▋| 483/500 [09:37<00:52,  3.11s/it]

Training_loss 6.98702


 97%|█████████▋| 484/500 [09:41<00:55,  3.46s/it]

Training_loss 6.98236


 97%|█████████▋| 485/500 [09:44<00:49,  3.32s/it]

Training_loss 6.97904


 97%|█████████▋| 486/500 [09:48<00:48,  3.47s/it]

Training_loss 6.98238


 97%|█████████▋| 487/500 [09:51<00:43,  3.38s/it]

Training_loss 6.98069


 98%|█████████▊| 488/500 [09:56<00:44,  3.69s/it]

Training_loss 6.98279


 98%|█████████▊| 489/500 [09:58<00:36,  3.35s/it]

Training_loss 6.98519


 98%|█████████▊| 490/500 [10:01<00:32,  3.27s/it]

Training_loss 6.98244


 98%|█████████▊| 491/500 [10:05<00:30,  3.44s/it]

Training_loss 6.98386


 98%|█████████▊| 492/500 [10:09<00:28,  3.62s/it]

Training_loss 6.98372


 99%|█████████▊| 493/500 [10:12<00:24,  3.48s/it]

Training_loss 6.98425


 99%|█████████▉| 494/500 [10:15<00:19,  3.28s/it]

Training_loss 6.98477


 99%|█████████▉| 495/500 [10:19<00:17,  3.48s/it]

Training_loss 6.98714


 99%|█████████▉| 496/500 [10:22<00:13,  3.28s/it]

Training_loss 6.98625


 99%|█████████▉| 497/500 [10:26<00:10,  3.49s/it]

Training_loss 6.98543


100%|█████████▉| 498/500 [10:29<00:06,  3.41s/it]

Training_loss 6.98569


100%|█████████▉| 499/500 [10:31<00:03,  3.02s/it]

Training_loss 6.98373


100%|██████████| 500/500 [10:34<00:00,  1.27s/it]

Training_loss 6.98397





In [24]:
# Re-bind every recorded metric container to a NumPy array built from it.
# NOTE(review): presumably these are Python lists filled by the training loop
# in an earlier cell — confirm.
test_loss, total_rel_error = np.array(test_loss), np.array(total_rel_error)

# Same conversion for the four in-/out-of-cluster projection norm histories.
(
    in_cluster_proj_norm,
    out_cluster_proj_norm,
    in_cluster_proj_diff_norm,
    out_cluster_proj_diff_norm,
) = (
    np.array(in_cluster_proj_norm),
    np.array(out_cluster_proj_norm),
    np.array(in_cluster_proj_diff_norm),
    np.array(out_cluster_proj_diff_norm),
)


In [25]:
# Pasted console log kept as a bare string literal: tqdm progress lines plus
# "Training_loss ..., Accuracy ..." lines for iterations 1-8 of a 2000-iteration run.
# The string is not assigned to anything; as the last expression of the cell its
# only effect is to be echoed back as the cell output.
# NOTE(review): presumably copied from a different experiment for reference —
# confirm whether it is still needed.
'''
  0%|          | 1/2000 [00:12<6:59:49, 12.60s/it]
Training_loss 0.69317,   Accuracy 0.52177
  0%|          | 2/2000 [00:25<6:58:41, 12.57s/it]
Training_loss 0.69256,   Accuracy 0.52523
  0%|          | 3/2000 [00:37<6:55:30, 12.48s/it]
Training_loss 0.69282,   Accuracy 0.52400
  0%|          | 4/2000 [00:50<6:56:56, 12.53s/it]
Training_loss 0.69216,   Accuracy 0.52552
  0%|          | 5/2000 [01:05<7:35:34, 13.70s/it]
Training_loss 0.69178,   Accuracy 0.52765
  0%|          | 6/2000 [01:18<7:24:07, 13.36s/it]
Training_loss 0.69037,   Accuracy 0.55240
  0%|          | 7/2000 [01:32<7:32:02, 13.61s/it]
Training_loss 0.68986,   Accuracy 0.55653
  0%|          | 8/2000 [01:47<7:42:39, 13.94s/it]
Training_loss 0.68921,   Accuracy 0.56338
'''

'\n  0%|          | 1/2000 [00:12<6:59:49, 12.60s/it]\nTraining_loss 0.69317,   Accuracy 0.52177\n  0%|          | 2/2000 [00:25<6:58:41, 12.57s/it]\nTraining_loss 0.69256,   Accuracy 0.52523\n  0%|          | 3/2000 [00:37<6:55:30, 12.48s/it]\nTraining_loss 0.69282,   Accuracy 0.52400\n  0%|          | 4/2000 [00:50<6:56:56, 12.53s/it]\nTraining_loss 0.69216,   Accuracy 0.52552\n  0%|          | 5/2000 [01:05<7:35:34, 13.70s/it]\nTraining_loss 0.69178,   Accuracy 0.52765\n  0%|          | 6/2000 [01:18<7:24:07, 13.36s/it]\nTraining_loss 0.69037,   Accuracy 0.55240\n  0%|          | 7/2000 [01:32<7:32:02, 13.61s/it]\nTraining_loss 0.68986,   Accuracy 0.55653\n  0%|          | 8/2000 [01:47<7:42:39, 13.94s/it]\nTraining_loss 0.68921,   Accuracy 0.56338\n'

In [26]:
# Write each metric array to its own file via np.save (which appends ".npy"
# when the name does not already end with it).
# All six names share one ending built from `lamda` and `pout`, with every '.'
# replaced by '_' so the numeric values do not introduce dots into the name.
_file_suffix = str(lamda).replace('.', '_') + '_pout' + str(pout).replace('.', '_')

# (file-name stem, array) pairs, saved in this order.
# Note: the 'training_loss...' file holds the `test_loss` array.
for _stem, _values in (
    ('training_loss_sheave_fml', test_loss),
    ('relative_error_sheave_fml', total_rel_error),
    ('in_cluster_proj_norm_sheave_fml', in_cluster_proj_norm),
    ('out_cluster_proj_norm_sheave_fml', out_cluster_proj_norm),
    ('in_cluster_proj_diff_norm_sheave_fml', in_cluster_proj_diff_norm),
    ('out_cluster_proj_diff_norm_sheave_fml', out_cluster_proj_diff_norm),
):
    np.save(_stem + _file_suffix, _values)

In [27]:
# Display the file stem under which the training-loss array is saved for the
# current `lamda` / `pout` values (dots replaced by underscores).
"".join([
    "training_loss_sheave_fml",
    str(lamda).replace(".", "_"),
    "_pout",
    str(pout).replace(".", "_"),
])

'training_loss_sheave_fml0_pout0_1'

In [28]:
# Persist the per-task metric histories: one .npy file per dictionary entry.
# Every file name ends with the same run tag derived from `lamda` and `pout`
# (dots replaced by underscores).
_run_tag = '_' + str(lamda).replace('.', '_') + '_pout' + str(pout).replace('.', '_')

# task_loss entries are written first, then task_rel_error entries.
for _stem, _per_task in (
    ('training_loss_sheave_fml_task', task_loss),
    ('relative_error_sheave_fml_task', task_rel_error),
):
    for _task_key, _history in _per_task.items():
        # The key is concatenated straight into the file name, so it must be a string.
        np.save(_stem + _task_key + _run_tag, np.array(_history))

In [33]:
def findMedianSortedArrays(nums1, nums2):
    """Return the median of the combined contents of two ascending-sorted sequences.

    A binary search over the shorter sequence looks for a split of both
    sequences such that every element left of the split is <= every element
    right of it; the median is then read from the elements adjacent to the split.

    Args:
        nums1: Sequence of numbers sorted in ascending order (may be empty).
        nums2: Sequence of numbers sorted in ascending order (may be empty).

    Returns:
        For an odd total number of elements, the middle element itself.
        For an even total, the two middle elements summed and divided by 2.

    Raises:
        ValueError: If both inputs are empty, or if no valid split is found
            because the values do not compare consistently (for example NaN).

    Note:
        Sortedness is not verified; unsorted input is not detected and may
        produce a wrong result.
    """
    # Search over the shorter sequence so the split derived for the longer one
    # always stays within its bounds.
    if len(nums1) > len(nums2):
        nums1, nums2 = nums2, nums1

    x, y = len(nums1), len(nums2)
    if x + y == 0:
        # Without this guard the sentinels below would give (-inf + inf) / 2 == nan.
        raise ValueError("cannot compute the median of two empty arrays")

    # Number of elements that belong to the left half of the combined split.
    left_size = (x + y + 1) // 2
    low, high = 0, x

    while low <= high:
        partitionX = (low + high) // 2
        partitionY = left_size - partitionX

        # +/- infinity stand in for "no element" on an empty side of a split.
        maxX = float('-inf') if partitionX == 0 else nums1[partitionX - 1]
        minX = float('inf') if partitionX == x else nums1[partitionX]

        maxY = float('-inf') if partitionY == 0 else nums2[partitionY - 1]
        minY = float('inf') if partitionY == y else nums2[partitionY]

        if maxX <= minY and maxY <= minX:
            # Correct split found.
            if (x + y) % 2 == 0:
                # Even total: average the two middle elements.
                return (max(maxX, maxY) + min(minX, minY)) / 2
            # Odd total: the left half holds the extra (middle) element.
            return max(maxX, maxY)
        elif maxX > minY:
            # Too many elements taken from nums1: move the split left.
            high = partitionX - 1
        else:
            # Too few elements taken from nums1: move the split right.
            low = partitionX + 1

    # Reached only when comparisons are inconsistent (e.g. NaN values); the
    # original silently returned None here.
    raise ValueError(
        "no valid partition found; inputs must be sorted and contain only "
        "mutually comparable values (no NaN)"
    )

# Demo call: five elements in total, so the median is the single middle value
# of the merged sequence [1, 2, 3, 3, 4].
nums1, nums2 = [1, 3], [2, 3, 4]
result = findMedianSortedArrays(nums1, nums2)
print("Median:", result)

Median: 3
