In [1]:
import json
import os
import random
from tqdm import tqdm
import numpy as np
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset, TensorDataset  
torch.backends.cudnn.benchmark=True
from pyhessian import hessian # Hessian computation
import scipy.io
import matplotlib.pyplot as plt


In [2]:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Inputs to hidden layer linear transformation
        self.hidden = nn.Linear(784, 784)
        self.hidden2 = nn.Linear(784, 600)
        self.hidden3 = nn.Linear(600, 400)
        self.hidden4 = nn.Linear(400, 200)
        # Output layer, 62 units 
        self.output = nn.Linear(200, 26)
        
        # Define sigmoid activation and softmax output 
        self.ReLu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        # Pass the input tensor through each of our operations
        x = torch.reshape(x, (-1, 784))
        x = self.hidden(x)
        x = self.ReLu(x)
        x = self.hidden2(x)
        x = self.ReLu(x)
        x = self.hidden3(x)
        x = self.ReLu(x)
        x = self.hidden4(x)
        x = self.ReLu(x)
        x = self.output(x)
        x = self.softmax(x)
        
        return x

In [3]:
def client_update(client_model, optimizer, train_loader, mode, epoch=5):
    """
    This function updates/trains client model on client data
    """
    client_model.train()
    for e in range(epoch):
        for batch_idx, (inputs, target) in enumerate(train_loader):
            inputs, target = inputs.cuda(), target.cuda()
            optimizer.zero_grad()
            output = client_model(inputs)
            loss = nn.functional.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    if mode == 'Average':
        return loss.item()
    if mode == 'HessFuse':
        client_model.eval()
        for batch_idx, (inputs, target) in enumerate(train_loader):
                inputs, target = inputs.cuda(), target.cuda()
                loss2 = torch.nn.CrossEntropyLoss()
                hessian_comp = hessian(client_model.eval(), loss2, data=(inputs, target), cuda=True)
                top_eigenvalues = hessian_comp.trace()
                break

        return loss.item(), top_eigenvalues

In [4]:
def server_aggregate(global_model, client_models, weights):
    """
    This function has aggregation method 'mean'
    """
    ### This will take simple mean of the weights of models ###
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([(weights[i]*(client_models[i].state_dict()[k].float())) for i in range(len(client_models))], 0).mean(0)
            
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

In [5]:
def test(global_model, test_loader):
    """This function test the global model on test data and returns test loss and test accuracy """
    global_model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        count = 0
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
#             target = torch.nn.functional.one_hot(target)
            output = global_model(data)
            test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            count += 1
        
    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)

    return test_loss, acc

In [6]:
num_clients = 10
num_selected = 10
num_rounds = 150
epochs = 5
batch_size = 100
random.seed(13)
np.random.seed(13)
torch.manual_seed(13)

mat = scipy.io.loadmat('./dataset/emnist-letters.mat')
data = mat["dataset"]
writer_ids_train = data['train'][0,0]['writers'][0,0]
writer_ids_train = np.squeeze(writer_ids_train)
X_train = data['train'][0,0]['images'][0,0]
X_train = X_train.reshape((X_train.shape[0], 28, 28), order = "F")
y_train = data['train'][0,0]['labels'][0,0]
y_train = np.squeeze(y_train)
y_train -= 1 #y_train is zero-based
indtemp = list()
vec = writer_ids_train%num_clients
for i in range(num_clients):
    indi = np.where(vec == i)
    indtemp.append(list(indi[0]))

train_loader = [torch.utils.data.DataLoader(TensorDataset(torch.FloatTensor(X_train[indtemp[i]][:][:]), torch.LongTensor(y_train[indtemp[i]][:][:])), batch_size=batch_size, shuffle=True) for i in range(num_clients)]
    

writer_ids_test = data['test'][0,0]['writers'][0,0]
writer_ids_test = np.squeeze(writer_ids_test)
X_test = data['test'][0,0]['images'][0,0]
X_test= X_test.reshape((X_test.shape[0], 28, 28), order = "F")
y_test = data['test'][0,0]['labels'][0,0]
y_test = np.squeeze(y_test)
y_test -= 1 #y_test is zero-based
test_loader = torch.utils.data.DataLoader(TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test)), batch_size=batch_size, shuffle=True)
 


In [7]:
############################################
#### Initializing models and optimizer  ####
############################################

# #### global model ##########
# global_model =  VGG('VGG19').cuda()
global_model =  Network().cuda()
# ############## client models ##############
# client_models = [ VGG('VGG19').cuda() for _ in range(num_selected)]
client_models = [ Network().cuda() for _ in range(num_selected)]
for model in client_models:
    model.load_state_dict(global_model.state_dict()) ### initial synchronizing with global model 

############### optimizers ################
opt = [optim.Adam(model.parameters(), lr=0.00001) for model in client_models]


###### List containing info about learning #########
losses_train = []
losses_test = []
acc_train = []
acc_test = []
# Runnining FL
mode = 'HessFuse'
for r in range(num_rounds):
    # select random clients
    client_idx = np.random.permutation(num_clients)[:num_selected]
    # client update
    losstot = 0
    eigs = np.ones(num_selected)
    for i in tqdm(range(num_selected)):
        if mode == 'HessFuse': #0.658 test acc 1/trace
            loss, eigss = client_update(client_models[client_idx[i]], opt[client_idx[i]], train_loader[client_idx[i]], mode, epoch=epochs)
            eigs[i] = eigss[0]
        if mode == 'Average': #0.638
            loss = client_update(client_models[client_idx[i]], opt[client_idx[i]], train_loader[client_idx[client_idx[i]]], mode, epoch=epochs)
        losstot += loss
        
    weights = eigs/(np.sum(eigs))
    print(weights)
    losses_train.append(loss)
    server_aggregate(global_model, client_models, weights*num_selected)  
    test_loss, acc = test(global_model.eval(), test_loader)
    losses_test.append(test_loss)
    acc_test.append(acc)
    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))
    
    

100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


[0.08693041 0.0766195  0.06779291 0.29004363 0.05919217 0.12081558
 0.0292873  0.04821023 0.12292144 0.09818683]


  0%|          | 0/10 [00:00<?, ?it/s]

0-th round
average train loss -0.045 | test loss -0.346 | test acc: 0.367


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.15196743  0.10375109  0.10239802  0.26527988  0.1701561   0.18987239
  0.05148887 -0.13229914  0.0456384   0.05174697]


  0%|          | 0/10 [00:00<?, ?it/s]

1-th round
average train loss -0.0401 | test loss -0.466 | test acc: 0.482


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[0.03855168 0.04197226 0.04936895 0.13909707 0.10011593 0.09957544
 0.08101252 0.07746192 0.25072013 0.1221241 ]


  0%|          | 0/10 [00:00<?, ?it/s]

2-th round
average train loss -0.0524 | test loss -0.488 | test acc: 0.501


100%|██████████| 10/10 [00:17<00:00,  1.78s/it]


[0.09560898 0.12744989 0.0500896  0.10552131 0.1004416  0.08853844
 0.23389856 0.09465206 0.03942752 0.06437204]


  0%|          | 0/10 [00:00<?, ?it/s]

3-th round
average train loss -0.0534 | test loss -0.5 | test acc: 0.511


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[0.07580187 0.08796482 0.11902596 0.13720005 0.12737071 0.06239172
 0.04670086 0.15205827 0.08344108 0.10804466]


  0%|          | 0/10 [00:00<?, ?it/s]

4-th round
average train loss -0.0519 | test loss -0.508 | test acc: 0.516


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[0.12167808 0.10339406 0.05821965 0.11598935 0.10109785 0.09620405
 0.27017926 0.03246523 0.06044997 0.04032251]


  0%|          | 0/10 [00:00<?, ?it/s]

5-th round
average train loss -0.0622 | test loss -0.512 | test acc: 0.520


100%|██████████| 10/10 [00:17<00:00,  1.80s/it]


[0.01866802 0.02467405 0.02343662 0.05851534 0.15629563 0.23462203
 0.07938078 0.08771688 0.2475633  0.06912735]


  0%|          | 0/10 [00:00<?, ?it/s]

6-th round
average train loss -0.0607 | test loss -0.517 | test acc: 0.523


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[0.09467366 0.01919481 0.04935009 0.02071109 0.07956976 0.42891272
 0.11395024 0.08615901 0.02258291 0.0848957 ]


  0%|          | 0/10 [00:00<?, ?it/s]

7-th round
average train loss -0.0629 | test loss -0.52 | test acc: 0.526


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.03183482  0.10774007  0.21286433  0.09751738  0.04460357  0.16484684
  0.17168386  0.12460142 -0.12761656  0.17192427]


  0%|          | 0/10 [00:00<?, ?it/s]

8-th round
average train loss -0.0456 | test loss -0.523 | test acc: 0.529


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[0.13536005 0.10983373 0.04325535 0.04906575 0.10243137 0.14340498
 0.13998186 0.06927005 0.07738622 0.13001063]


  0%|          | 0/10 [00:00<?, ?it/s]

9-th round
average train loss -0.0603 | test loss -0.525 | test acc: 0.531


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.20206066  0.03194424  0.0998708   0.05462624  0.05001709  0.10926246
  0.15402879  0.11064114 -0.02354931  0.21109788]


  0%|          | 0/10 [00:00<?, ?it/s]

10-th round
average train loss -0.0596 | test loss -0.527 | test acc: 0.532


100%|██████████| 10/10 [00:18<00:00,  1.80s/it]


[0.04858022 0.07950737 0.04562675 0.12081801 0.11497904 0.18950757
 0.1491668  0.06764313 0.13122652 0.05294459]


  0%|          | 0/10 [00:00<?, ?it/s]

11-th round
average train loss -0.0561 | test loss -0.529 | test acc: 0.535


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[0.04481419 0.19100576 0.12187263 0.03954738 0.02370945 0.3800379
 0.04635815 0.04477827 0.03241951 0.07545676]


  0%|          | 0/10 [00:00<?, ?it/s]

12-th round
average train loss -0.0532 | test loss -0.53 | test acc: 0.536


100%|██████████| 10/10 [00:17<00:00,  1.78s/it]


[0.06005227 0.15772613 0.08853322 0.1468821  0.003317   0.06500128
 0.03746626 0.0606118  0.30442586 0.07598409]


  0%|          | 0/10 [00:00<?, ?it/s]

13-th round
average train loss -0.053 | test loss -0.532 | test acc: 0.537


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[ 0.01030354  0.0291703  -0.01101253  0.11824456  0.07450093  0.22429622
  0.20549462  0.16958908  0.12062098  0.0587923 ]


  0%|          | 0/10 [00:00<?, ?it/s]

14-th round
average train loss -0.0516 | test loss -0.533 | test acc: 0.538


100%|██████████| 10/10 [00:17<00:00,  1.78s/it]


[-0.01732204  0.10154822  0.08145714  0.07570107  0.13993495  0.07142636
  0.03366231  0.12224347  0.32139511  0.0699534 ]


  0%|          | 0/10 [00:00<?, ?it/s]

15-th round
average train loss -0.052 | test loss -0.534 | test acc: 0.539


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[0.00948682 0.08166856 0.20089142 0.039154   0.09392373 0.14158577
 0.10164501 0.04145358 0.23146507 0.05872603]


  0%|          | 0/10 [00:00<?, ?it/s]

16-th round
average train loss -0.0463 | test loss -0.547 | test acc: 0.551


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[0.08426392 0.09721728 0.01611958 0.09368292 0.12628875 0.21845736
 0.04242911 0.13304235 0.11748829 0.07101044]


  0%|          | 0/10 [00:00<?, ?it/s]

17-th round
average train loss -0.0597 | test loss -0.592 | test acc: 0.599


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[ 0.0556209  -0.01726525  0.10170619  0.21409333  0.10049245  0.06546543
  0.16554708  0.10938266  0.1395612   0.06539603]


  0%|          | 0/10 [00:00<?, ?it/s]

18-th round
average train loss -0.0628 | test loss -0.596 | test acc: 0.601


100%|██████████| 10/10 [00:17<00:00,  1.78s/it]


[0.06842246 0.07098127 0.05227507 0.09626354 0.2082372  0.04167008
 0.15348571 0.22626129 0.04001717 0.04238621]


  0%|          | 0/10 [00:00<?, ?it/s]

19-th round
average train loss -0.0674 | test loss -0.598 | test acc: 0.604


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[0.05736886 0.20438246 0.12920596 0.09329778 0.08130922 0.12725364
 0.06940023 0.06030946 0.16201537 0.01545701]


  0%|          | 0/10 [00:00<?, ?it/s]

20-th round
average train loss -0.0748 | test loss -0.601 | test acc: 0.606


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[0.02384833 0.03947727 0.17188726 0.03946272 0.10733607 0.07049219
 0.19874573 0.11443064 0.07915163 0.15516817]


  0%|          | 0/10 [00:00<?, ?it/s]

21-th round
average train loss -0.0593 | test loss -0.602 | test acc: 0.606


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[0.24697038 0.02352801 0.07134845 0.08169246 0.2209532  0.07876171
 0.0437918  0.07553701 0.09105862 0.06635834]


  0%|          | 0/10 [00:00<?, ?it/s]

22-th round
average train loss -0.0495 | test loss -0.603 | test acc: 0.608


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[0.03765593 0.08066292 0.08194935 0.11808085 0.0844595  0.0866147
 0.04397549 0.11436246 0.26804242 0.08419638]


  0%|          | 0/10 [00:00<?, ?it/s]

23-th round
average train loss -0.0747 | test loss -0.604 | test acc: 0.609


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[ 0.18280895  0.2040703   0.16146238 -0.02405809  0.05044928  0.22418476
  0.05413629  0.0262285   0.08980837  0.03090926]


  0%|          | 0/10 [00:00<?, ?it/s]

24-th round
average train loss -0.0817 | test loss -0.605 | test acc: 0.610


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[0.19532515 0.04231801 0.03968004 0.09129557 0.17344579 0.11806039
 0.07008512 0.0217497  0.11716055 0.13087969]


  0%|          | 0/10 [00:00<?, ?it/s]

25-th round
average train loss -0.0665 | test loss -0.606 | test acc: 0.610


100%|██████████| 10/10 [00:17<00:00,  1.78s/it]


[ 0.04237349  0.09707517  0.250707    0.02989329  0.12542265  0.13309515
  0.06687386  0.22272176 -0.01761657  0.04945419]


  0%|          | 0/10 [00:00<?, ?it/s]

26-th round
average train loss -0.0689 | test loss -0.606 | test acc: 0.611


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[0.05044039 0.23442877 0.02477923 0.14959568 0.11735792 0.03880371
 0.15849356 0.10794568 0.02372702 0.09442804]


  0%|          | 0/10 [00:00<?, ?it/s]

27-th round
average train loss -0.0598 | test loss -0.607 | test acc: 0.611


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.05342119  0.14157563  0.13184891  0.07818328  0.08835368 -0.06648839
  0.22700517  0.21429006  0.07944819  0.05236228]


  0%|          | 0/10 [00:00<?, ?it/s]

28-th round
average train loss -0.0735 | test loss -0.608 | test acc: 0.611


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[0.01903958 0.15166625 0.08200474 0.10604383 0.08109491 0.02350404
 0.04663782 0.07692518 0.16226887 0.25081478]


  0%|          | 0/10 [00:00<?, ?it/s]

29-th round
average train loss -0.0734 | test loss -0.608 | test acc: 0.611


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[0.05947835 0.03453334 0.09717721 0.09059063 0.10902488 0.22945378
 0.02635284 0.1610796  0.06127993 0.13102944]


  0%|          | 0/10 [00:00<?, ?it/s]

30-th round
average train loss -0.0602 | test loss -0.615 | test acc: 0.618


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.00222731  0.10646543  0.2040206  -0.01583703  0.04868597  0.04051386
  0.03986249  0.24828067  0.06205567  0.26372503]


  0%|          | 0/10 [00:00<?, ?it/s]

31-th round
average train loss -0.0675 | test loss -0.622 | test acc: 0.627


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[ 0.32708024  0.00941589  0.27913505  0.03336308 -0.02839594  0.11060522
  0.04779084  0.11514006  0.01701252  0.08885303]


  0%|          | 0/10 [00:00<?, ?it/s]

32-th round
average train loss -0.0578 | test loss -0.626 | test acc: 0.631


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[ 0.13254361  0.00679028  0.16571322  0.14242974  0.19983709  0.07123596
  0.15906588 -0.10102528  0.12228481  0.10112469]


  0%|          | 0/10 [00:00<?, ?it/s]

33-th round
average train loss -0.0778 | test loss -0.628 | test acc: 0.632


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[0.07118212 0.09371463 0.13971104 0.11519932 0.00376871 0.10811434
 0.02585648 0.2332069  0.08102872 0.12821774]


  0%|          | 0/10 [00:00<?, ?it/s]

34-th round
average train loss -0.0684 | test loss -0.629 | test acc: 0.633


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[ 0.08153848  0.06771802  0.28961062  0.10557918  0.09395515  0.0500557
  0.0714313   0.16397488  0.1703458  -0.09420913]


  0%|          | 0/10 [00:00<?, ?it/s]

35-th round
average train loss -0.0875 | test loss -0.629 | test acc: 0.633


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.08181909  0.0905846  -0.08721377  0.00343994  0.11102548  0.03936768
  0.28408971  0.04985757  0.31300357  0.11402612]


  0%|          | 0/10 [00:00<?, ?it/s]

36-th round
average train loss -0.0654 | test loss -0.63 | test acc: 0.634


100%|██████████| 10/10 [00:18<00:00,  1.83s/it]


[0.05122515 0.0683479  0.01387099 0.33192256 0.23851724 0.11999425
 0.02175903 0.07844837 0.03882256 0.03709195]


  0%|          | 0/10 [00:00<?, ?it/s]

37-th round
average train loss -0.0816 | test loss -0.63 | test acc: 0.634


100%|██████████| 10/10 [00:17<00:00,  1.78s/it]


[ 0.22351597  0.0801293   0.15222097  0.10970111 -0.29209     0.2104243
  0.02053146  0.21724944  0.12005918  0.15825825]


  0%|          | 0/10 [00:00<?, ?it/s]

38-th round
average train loss -0.0666 | test loss -0.629 | test acc: 0.632


100%|██████████| 10/10 [00:18<00:00,  1.81s/it]


[0.14016535 0.01717006 0.0185974  0.02921382 0.13522318 0.15571411
 0.08228855 0.06554164 0.20674095 0.14934493]


  0%|          | 0/10 [00:00<?, ?it/s]

39-th round
average train loss -0.0582 | test loss -0.631 | test acc: 0.634


100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


[ 0.08999197  0.13202126 -0.04253063  0.08677886  0.12504631  0.07787718
  0.06463097  0.38672412  0.07406207  0.0053979 ]


  0%|          | 0/10 [00:00<?, ?it/s]

40-th round
average train loss -0.0706 | test loss -0.631 | test acc: 0.634


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[0.03903316 0.25490962 0.09946263 0.0771436  0.00412444 0.18702654
 0.1807863  0.0426193  0.08468122 0.0302132 ]


  0%|          | 0/10 [00:00<?, ?it/s]

41-th round
average train loss -0.0626 | test loss -0.631 | test acc: 0.635


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[0.10559649 0.12180376 0.09155122 0.12145485 0.14064293 0.0249086
 0.22686875 0.09888411 0.04694993 0.02133936]


  0%|          | 0/10 [00:00<?, ?it/s]

42-th round
average train loss -0.0651 | test loss -0.632 | test acc: 0.635


100%|██████████| 10/10 [00:17<00:00,  1.80s/it]


[ 0.0342676   0.17074405  0.26634662  0.14789212  0.10229391  0.12784528
 -0.00612328  0.05311756  0.10045613  0.00316   ]


  0%|          | 0/10 [00:00<?, ?it/s]

43-th round
average train loss -0.0589 | test loss -0.632 | test acc: 0.635


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[0.10662006 0.02203983 0.02106421 0.24886306 0.02833085 0.09584306
 0.08973691 0.02598044 0.15493589 0.20658569]


  0%|          | 0/10 [00:00<?, ?it/s]

44-th round
average train loss -0.0641 | test loss -0.632 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[ 0.1098599   0.14985503  0.08637031  0.07743809  0.00690532  0.10450289
  0.20447299  0.15695429 -0.01276158  0.11640275]


  0%|          | 0/10 [00:00<?, ?it/s]

45-th round
average train loss -0.0611 | test loss -0.633 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[ 0.01345927  0.0841762   0.50690233  0.15933322 -0.04796877  0.04951885
  0.05640551  0.09982156  0.0773148   0.00103703]


  0%|          | 0/10 [00:00<?, ?it/s]

46-th round
average train loss -0.0818 | test loss -0.632 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[ 0.11962118  0.03727535  0.15798288  0.11631271  0.15386699  0.11686579
  0.05596696  0.18677815  0.0558539  -0.0005239 ]


  0%|          | 0/10 [00:00<?, ?it/s]

47-th round
average train loss -0.0666 | test loss -0.633 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[0.05424805 0.07746238 0.06376416 0.0459073  0.04051378 0.05010758
 0.25512922 0.05010273 0.1994934  0.16327139]


  0%|          | 0/10 [00:00<?, ?it/s]

48-th round
average train loss -0.0749 | test loss -0.633 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.21056727  0.31634176 -1.8466361  -0.02512237  0.21264802  0.21577855
  0.79204604  0.41733139  0.67108153  0.03596389]


  0%|          | 0/10 [00:00<?, ?it/s]

49-th round
average train loss -0.062 | test loss -0.62 | test acc: 0.624


100%|██████████| 10/10 [00:16<00:00,  1.68s/it]


[-0.24943725  0.14640077 -0.45769899 -0.17009699 -0.14921392  0.73025536
  0.92940183  0.2437681  -0.19793428  0.17455537]


  0%|          | 0/10 [00:00<?, ?it/s]

50-th round
average train loss -0.0613 | test loss -0.619 | test acc: 0.623


100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


[-0.13856488 -0.01120703  0.09943502  0.02518623  0.24801945 -0.02257565
  0.04782351  0.21664672  0.08892241  0.44631422]


  0%|          | 0/10 [00:00<?, ?it/s]

51-th round
average train loss -0.0679 | test loss -0.63 | test acc: 0.633


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.05360011  0.0597679   0.1707482   0.02638707  0.26228636  0.08304888
 -0.00507143  0.24438309  0.10132903  0.0035208 ]


  0%|          | 0/10 [00:00<?, ?it/s]

52-th round
average train loss -0.0659 | test loss -0.632 | test acc: 0.635


100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


[ 0.1016239   0.0369239  -0.02351105  0.25001403  0.0511997   0.02041386
  0.13680217  0.03607223  0.28138355  0.1090777 ]


  0%|          | 0/10 [00:00<?, ?it/s]

53-th round
average train loss -0.0767 | test loss -0.633 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[ 0.31252885 -0.0197023   0.34797235  0.16240576  0.1055696  -0.04827495
  0.06265177  0.14633671  0.09236491 -0.16185271]


  0%|          | 0/10 [00:00<?, ?it/s]

54-th round
average train loss -0.0629 | test loss -0.633 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.70s/it]


[ 0.00283709  0.22422974 -0.46881415 -0.74588701 -0.10826888 -0.10472058
 -0.50318236 -0.03306357  1.64131557  1.09555415]


  0%|          | 0/10 [00:00<?, ?it/s]

55-th round
average train loss -0.0625 | test loss -0.619 | test acc: 0.622


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.03027275  0.03856833  0.0323493   0.07402282  0.10231875  0.22647439
  0.18211087 -0.02558468  0.16964265  0.16982482]


  0%|          | 0/10 [00:00<?, ?it/s]

56-th round
average train loss -0.0767 | test loss -0.633 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.80s/it]


[0.09897397 0.06046317 0.08898914 0.04044103 0.02867169 0.55953876
 0.02783531 0.03553953 0.01841444 0.04113295]


  0%|          | 0/10 [00:00<?, ?it/s]

57-th round
average train loss -0.0889 | test loss -0.633 | test acc: 0.635


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[-0.14263358  0.06539599  0.14307995  0.01651114  0.22972291  0.48014512
 -0.13416757  0.17573544  0.15551273  0.01069788]


  0%|          | 0/10 [00:00<?, ?it/s]

58-th round
average train loss -0.0735 | test loss -0.633 | test acc: 0.635


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.1033051   0.2653892   0.06742608  0.26960548  0.48323487  0.13374748
  0.30879393  0.01306639 -1.53027395  0.88570542]


  0%|          | 0/10 [00:00<?, ?it/s]

59-th round
average train loss -0.0704 | test loss -0.622 | test acc: 0.625


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[-0.07353239 -0.01448569 -0.51137283  0.53317374  0.00178764  0.09063203
  0.12631809  0.33376605  0.58122347 -0.06751011]


  0%|          | 0/10 [00:00<?, ?it/s]

60-th round
average train loss -0.0833 | test loss -0.63 | test acc: 0.632


100%|██████████| 10/10 [00:18<00:00,  1.82s/it]


[ 0.05492152  0.59061895  0.112305    0.06586961  0.14081518  0.02508867
 -0.31285932  0.04725147 -0.01140714  0.28739607]


  0%|          | 0/10 [00:00<?, ?it/s]

61-th round
average train loss -0.0552 | test loss -0.633 | test acc: 0.635


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[-0.02106914  0.07016098  0.00708115  0.00884846 -0.0192794   0.1899417
  0.41870709  0.26853688  0.04418959  0.03288269]


  0%|          | 0/10 [00:00<?, ?it/s]

62-th round
average train loss -0.072 | test loss -0.634 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


[ 0.23763732 -0.28728132  0.33245825  0.00972045  0.11640164 -0.05326984
  0.28229953 -0.00432277  0.17806047  0.18829627]


  0%|          | 0/10 [00:00<?, ?it/s]

63-th round
average train loss -0.0656 | test loss -0.634 | test acc: 0.636


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


[-0.06169375  0.46353894  2.93064005 -4.44072856  1.86857554  0.38601615
 13.76493415 -0.07018541 -6.00191411 -7.83918299]


  0%|          | 0/10 [00:00<?, ?it/s]

64-th round
average train loss -0.0652 | test loss -0.0464 | test acc: 0.046


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.20200801  0.07899843  0.109069   -0.00206417 -0.01366042  0.10905791
  0.08540017  0.217648    0.10026231  0.11328076]


  0%|          | 0/10 [00:00<?, ?it/s]

65-th round
average train loss -0.0558 | test loss -0.511 | test acc: 0.514


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[ 0.04824949  0.09943265  0.05159529  0.31256683  0.10523335  0.00759901
  0.11772548  0.09547933 -0.05254469  0.21466324]


  0%|          | 0/10 [00:00<?, ?it/s]

66-th round
average train loss -0.0535 | test loss -0.554 | test acc: 0.556


100%|██████████| 10/10 [00:18<00:00,  1.80s/it]


[ 0.16933445  0.05116815  0.43009541  0.10756666  0.00945857  0.18994709
  0.03970647 -0.02302211 -0.04850282  0.07424812]


  0%|          | 0/10 [00:00<?, ?it/s]

67-th round
average train loss -0.0505 | test loss -0.593 | test acc: 0.595


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.17485341  0.31117375  0.25806859  0.02990365  0.02163692  0.30871132
 -0.00090684  0.19053369  0.08471096 -0.37868545]


  0%|          | 0/10 [00:00<?, ?it/s]

68-th round
average train loss -0.0531 | test loss -0.602 | test acc: 0.605


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.02583006  0.02232863 -0.06141061  0.54553485 -0.00072208 -0.03117334
  0.26201166  0.05700573  0.01994497  0.16065013]


  0%|          | 0/10 [00:00<?, ?it/s]

69-th round
average train loss -0.063 | test loss -0.608 | test acc: 0.610


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.12783554  0.51706596 -0.29734417  0.03047614  0.24201513 -0.14239648
  0.04716051  0.47477473 -0.03147109  0.03188372]


  0%|          | 0/10 [00:00<?, ?it/s]

70-th round
average train loss -0.068 | test loss -0.612 | test acc: 0.615


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.1528098   0.07374101  0.07545688  0.10588339  0.20971566 -0.21684217
  0.04497453  0.13537498  0.11681159  0.30207433]


  0%|          | 0/10 [00:00<?, ?it/s]

71-th round
average train loss -0.0633 | test loss -0.616 | test acc: 0.619


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[-0.06338079  0.02481589 -0.0008831   0.11124265  0.14624099  0.50079594
  0.08989981  0.11618365  0.03878171  0.03630325]


  0%|          | 0/10 [00:00<?, ?it/s]

72-th round
average train loss -0.0673 | test loss -0.619 | test acc: 0.621


100%|██████████| 10/10 [00:18<00:00,  1.82s/it]


[ 0.18605018 -0.0028667   0.15306588  0.13975547  0.08115283  0.05649524
  0.16471299  0.08499934  0.02863881  0.10799597]


  0%|          | 0/10 [00:00<?, ?it/s]

73-th round
average train loss -0.0672 | test loss -0.622 | test acc: 0.624


100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


[ 0.02935947  0.08385688  0.16116576  0.23917182  0.01169255  0.11925969
  0.05022977  0.21890813  0.09084257 -0.00448664]


  0%|          | 0/10 [00:00<?, ?it/s]

74-th round
average train loss -0.0676 | test loss -0.623 | test acc: 0.625


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[-0.07128648 -0.24279086 -0.19285441  0.03786099  0.35920426  0.13209364
  0.25655251  0.04849914  0.60008033  0.07264088]


  0%|          | 0/10 [00:00<?, ?it/s]

75-th round
average train loss -0.0529 | test loss -0.623 | test acc: 0.625


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[ 0.33463253  0.01136025  0.0796332   0.11441417  0.1189026   0.00397713
 -0.01450478  0.29864727 -0.00099314  0.05393076]


  0%|          | 0/10 [00:00<?, ?it/s]

76-th round
average train loss -0.065 | test loss -0.625 | test acc: 0.626


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


[ 0.06418349  0.23663478 -0.46302548  0.32138091  0.10213272 -0.02891255
  0.24547757 -0.01465922  0.2565619   0.28022588]


  0%|          | 0/10 [00:00<?, ?it/s]

77-th round
average train loss -0.073 | test loss -0.625 | test acc: 0.627


100%|██████████| 10/10 [00:16<00:00,  1.70s/it]


[-0.55957344 -0.29187464  0.85123688 -1.47206577  0.07954929 -0.01828545
 -0.92479898  0.86160325 -0.54920014  3.023409  ]


  0%|          | 0/10 [00:00<?, ?it/s]

78-th round
average train loss -0.0649 | test loss -0.574 | test acc: 0.577


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.06519202  0.39787038 -0.05470555  0.12415885 -0.04807462  0.06792893
  0.27668791  0.01606279  0.16631419 -0.01143488]


  0%|          | 0/10 [00:00<?, ?it/s]

79-th round
average train loss -0.0664 | test loss -0.623 | test acc: 0.625


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[ 0.07510666  0.05687621  0.0479404   0.47790968 -0.24563952  0.08179925
  0.19239533 -0.00511587  0.11827402  0.20045384]


  0%|          | 0/10 [00:00<?, ?it/s]

80-th round
average train loss -0.0692 | test loss -0.625 | test acc: 0.627


100%|██████████| 10/10 [00:16<00:00,  1.68s/it]


[-0.05554124  0.13887411  0.15971081  0.08234642 -0.06149071 -0.44813817
  0.15798448  0.18933424 -0.17628106  1.01320111]


  0%|          | 0/10 [00:00<?, ?it/s]

81-th round
average train loss -0.0671 | test loss -0.621 | test acc: 0.622


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[ 0.01163424  0.07594311  0.1870432   0.1312733  -0.0088883   0.03093909
  0.06903088 -0.00361203  0.02743321  0.47920329]


  0%|          | 0/10 [00:00<?, ?it/s]

82-th round
average train loss -0.0559 | test loss -0.627 | test acc: 0.628


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[-0.24122351  0.71157963  0.09805944  0.05863296 -0.03500398  0.06062533
  0.51693687 -0.5228007   0.36466183 -0.01146788]


  0%|          | 0/10 [00:00<?, ?it/s]

83-th round
average train loss -0.0666 | test loss -0.625 | test acc: 0.626


100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


[-2.15479168 -0.40521623  0.22918603  4.93842296 -4.72361727  0.89624815
  1.23502093  0.12003759  0.43295934  0.43175019]


  0%|          | 0/10 [00:00<?, ?it/s]

84-th round
average train loss -0.0652 | test loss -0.263 | test acc: 0.263


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.24011863 -0.01989214  0.08725905  0.10326734  0.03537292  0.0385502
  0.34086816  0.01178693 -0.24653584  0.40920473]


  0%|          | 0/10 [00:00<?, ?it/s]

85-th round
average train loss -0.0603 | test loss -0.592 | test acc: 0.593


100%|██████████| 10/10 [00:17<00:00,  1.70s/it]


[ 0.04982445 -0.01988386 -0.03575896  0.56781259 -0.06336697  0.14603202
  0.02911262  0.05163331  0.22408613  0.05050867]


  0%|          | 0/10 [00:00<?, ?it/s]

86-th round
average train loss -0.0574 | test loss -0.61 | test acc: 0.611


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[-0.04314826  0.03941709  0.22486991  0.09557679  0.14026539  0.01697729
 -0.08613907  0.53664549  0.02338591  0.05214946]


  0%|          | 0/10 [00:00<?, ?it/s]

87-th round
average train loss -0.058 | test loss -0.618 | test acc: 0.620


100%|██████████| 10/10 [00:17<00:00,  1.70s/it]


[ 0.24689279  0.17686493  0.58068961  0.06844498  0.36120932 -0.07174251
  0.05568095 -0.31919653 -0.7623126   0.66346905]


  0%|          | 0/10 [00:00<?, ?it/s]

88-th round
average train loss -0.0704 | test loss -0.618 | test acc: 0.619


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.40464469  0.01110647 -0.02903374 -0.00833684  0.02576965  0.30272923
  0.03901545  0.02168923  0.15703613  0.07537973]


  0%|          | 0/10 [00:00<?, ?it/s]

89-th round
average train loss -0.0693 | test loss -0.624 | test acc: 0.626


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[-0.01307417  0.08007032  0.02885483  0.09131287  0.14320846  0.1465581
  0.23886769  0.03194328  0.10393985  0.14831877]


  0%|          | 0/10 [00:00<?, ?it/s]

90-th round
average train loss -0.0664 | test loss -0.627 | test acc: 0.628


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[ 0.18428758  0.2319236   0.25552829 -0.18450128  0.04608524  0.24948203
  0.05867248  0.04966069  0.03100784  0.07785351]


  0%|          | 0/10 [00:00<?, ?it/s]

91-th round
average train loss -0.0567 | test loss -0.628 | test acc: 0.629


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[-0.00418784  0.36181545 -0.40166541  0.13165236  0.06243412  0.04286879
  0.18571234 -0.38829318  0.05679122  0.95287215]


  0%|          | 0/10 [00:00<?, ?it/s]

92-th round
average train loss -0.055 | test loss -0.623 | test acc: 0.624


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[0.02154015 0.38066134 0.00668456 0.0017155  0.42509711 0.07134856
 0.01838866 0.02730844 0.04023563 0.00702005]


  0%|          | 0/10 [00:00<?, ?it/s]

93-th round
average train loss -0.061 | test loss -0.629 | test acc: 0.630


100%|██████████| 10/10 [00:16<00:00,  1.68s/it]


[ 0.026427    1.01743538  0.13097799  0.11090353  0.00207777  0.04681191
 -0.29609776  0.16467191 -0.03575137 -0.16745635]


  0%|          | 0/10 [00:00<?, ?it/s]

94-th round
average train loss -0.0623 | test loss -0.627 | test acc: 0.628


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[-0.02096525  0.15355334  0.11508989  0.334134    0.0843873  -0.01491887
  0.22872104  0.07051722  0.05986777 -0.01038643]


  0%|          | 0/10 [00:00<?, ?it/s]

95-th round
average train loss -0.0792 | test loss -0.63 | test acc: 0.632


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[ 0.00131389 -0.10643106  0.69497552  0.02283317 -0.05915121  0.18353177
  0.22888932 -0.01442558 -0.00945363  0.05791781]


  0%|          | 0/10 [00:00<?, ?it/s]

96-th round
average train loss -0.0626 | test loss -0.631 | test acc: 0.632


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[ 1.51166432  3.57337738  1.37545499  1.46304959  0.53274432 -0.02015666
 -0.14331076 -0.29794125 -1.31752178 -5.67736015]


  0%|          | 0/10 [00:00<?, ?it/s]

97-th round
average train loss -0.0706 | test loss -0.35 | test acc: 0.352


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[ 0.00256543  0.00133213 -0.30209319 -0.02012979  0.79939858  0.63228913
  0.81056481  0.03749769  0.16063655 -1.12206134]


  0%|          | 0/10 [00:00<?, ?it/s]

98-th round
average train loss -0.0617 | test loss -0.594 | test acc: 0.595


100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


[ 0.03314544  0.06180575  0.08470583  0.17716289  0.07464324  0.48608141
 -0.2003543   0.04739153  0.09434643  0.14107177]


  0%|          | 0/10 [00:00<?, ?it/s]

99-th round
average train loss -0.0709 | test loss -0.619 | test acc: 0.620


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[-0.0041002  -0.03514624  0.07705435 -0.02067164 -0.09740287  0.24253659
  0.87203902  0.02595947 -0.06709208  0.0068236 ]


  0%|          | 0/10 [00:00<?, ?it/s]

100-th round
average train loss -0.0499 | test loss -0.623 | test acc: 0.625


100%|██████████| 10/10 [00:17<00:00,  1.78s/it]


[ 0.54910793 -0.18764528  0.24131064 -0.05676884 -0.01764976  0.16807424
  0.15527259  0.11437656  0.03305582  0.00086609]


  0%|          | 0/10 [00:00<?, ?it/s]

101-th round
average train loss -0.0608 | test loss -0.627 | test acc: 0.629


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


[ 0.61375084 -0.53067408  0.33992331  0.24003307  0.22400388 -0.32059911
  0.02446418 -0.10165509 -0.11166144  0.62241443]


  0%|          | 0/10 [00:00<?, ?it/s]

102-th round
average train loss -0.0588 | test loss -0.625 | test acc: 0.626


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.07123172 -0.03365782  0.01118885  0.15324168  0.06499573  0.01149657
  0.19233104  0.31284478  0.04773314  0.16859431]


  0%|          | 0/10 [00:00<?, ?it/s]

103-th round
average train loss -0.0651 | test loss -0.63 | test acc: 0.631


100%|██████████| 10/10 [00:16<00:00,  1.69s/it]


[ 0.40926756 -0.4300979  -0.00241869  0.29816456  0.00627113 -0.06733814
  1.09939974  0.03091843 -0.19439527 -0.14977143]


  0%|          | 0/10 [00:00<?, ?it/s]

104-th round
average train loss -0.0772 | test loss -0.625 | test acc: 0.626


100%|██████████| 10/10 [00:16<00:00,  1.67s/it]


[ 3.52980679e-01  6.10514043e-01 -2.54650826e-02 -9.46146235e-02
  3.72920565e-02  3.70397816e-02  1.43279238e-01 -4.89298683e-02
 -2.80245885e-04 -1.18159784e-02]


  0%|          | 0/10 [00:00<?, ?it/s]

105-th round
average train loss -0.0717 | test loss -0.629 | test acc: 0.631


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[ 0.51845223 -0.03061452 -0.04681724  0.20449877 -0.19575468  0.07283025
  0.5510111   0.03177345 -0.0750978  -0.03028156]


  0%|          | 0/10 [00:00<?, ?it/s]

106-th round
average train loss -0.0672 | test loss -0.631 | test acc: 0.632


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[-0.13158352  0.04061226 -0.36939179  0.03411397  0.99199967 -0.13195044
  0.20717567  0.08782753  0.24339572  0.02780093]


  0%|          | 0/10 [00:00<?, ?it/s]

107-th round
average train loss -0.05 | test loss -0.629 | test acc: 0.631


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[-0.25736152  0.12910077  0.06843858  0.06571943  0.05267254  0.08965077
  0.08377861  0.35505562  0.30954595  0.10339924]


  0%|          | 0/10 [00:00<?, ?it/s]

108-th round
average train loss -0.0581 | test loss -0.632 | test acc: 0.633


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[ 0.41215901  0.18401995 -0.16802162 -0.25728032 -2.06540475  0.95999549
  1.75374862 -0.18651418  0.07343228  0.29386551]


  0%|          | 0/10 [00:00<?, ?it/s]

109-th round
average train loss -0.071 | test loss -0.605 | test acc: 0.607


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


[ 0.00415793  0.07622767  0.24322706 -0.58940089 -0.16862433 -0.06288878
  0.07224155  0.21463494  0.3995644   0.81086045]


  0%|          | 0/10 [00:00<?, ?it/s]

110-th round
average train loss -0.0652 | test loss -0.625 | test acc: 0.626


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.35633317 -0.02174802  0.10415748 -0.01852333  0.13983981  0.07247629
  0.05521954  0.01152646  0.15011259  0.15060601]


  0%|          | 0/10 [00:00<?, ?it/s]

111-th round
average train loss -0.0666 | test loss -0.632 | test acc: 0.633


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[-0.03727412  0.29264285  0.05415853  0.14749265  0.01846214 -0.11060084
  0.39721875 -0.0176496   0.2615167  -0.00596707]


  0%|          | 0/10 [00:00<?, ?it/s]

112-th round
average train loss -0.0761 | test loss -0.633 | test acc: 0.634


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.01914985 -0.11215798  0.01047266  0.05223083 -0.09053605  0.08124739
  0.48595385  0.07296187  0.24654549  0.23413207]


  0%|          | 0/10 [00:00<?, ?it/s]

113-th round
average train loss -0.0589 | test loss -0.633 | test acc: 0.634


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[ 1.42650587  0.1682396   2.19644043  0.17352019 -1.75825661 -5.77676895
  2.44995474  0.33145509  0.68244186  1.10646778]


  0%|          | 0/10 [00:00<?, ?it/s]

114-th round
average train loss -0.064 | test loss -0.376 | test acc: 0.376


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[-0.07051012 -0.02135425  0.16372947 -0.00549823  0.07237625  0.08769807
  0.00859812  0.55233675  0.16153192  0.05109202]


  0%|          | 0/10 [00:00<?, ?it/s]

115-th round
average train loss -0.0672 | test loss -0.616 | test acc: 0.617


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[-1.39793049 -0.85260412 -4.44673476  0.67177952 13.08833166 -1.18905174
  1.33642418 -5.88688438  0.33428884 -0.65761871]


  0%|          | 0/10 [00:00<?, ?it/s]

116-th round
average train loss -0.0695 | test loss -0.174 | test acc: 0.175


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[ 0.01122945 -0.21981123  0.20701177  0.19548171  0.05703812  0.0010049
  0.18717021 -0.07994146  0.58489662  0.05591991]


  0%|          | 0/10 [00:00<?, ?it/s]

117-th round
average train loss -0.0463 | test loss -0.494 | test acc: 0.494


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[ 0.3004931  -0.01978926 -0.00647686  0.19866455  0.09797416  0.11045215
 -0.08746936  0.09243086  0.17246669  0.14125397]


  0%|          | 0/10 [00:00<?, ?it/s]

118-th round
average train loss -0.05 | test loss -0.547 | test acc: 0.548


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[ 0.49934987 -0.02557238 -0.05240458  0.05571368 -0.01662221 -0.32358395
  0.1364419  -0.206732    0.10329799  0.83011168]


  0%|          | 0/10 [00:00<?, ?it/s]

119-th round
average train loss -0.0665 | test loss -0.562 | test acc: 0.562


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.14687676  0.02412973  0.04510128  0.03179211  0.09379003  0.35688079
  0.06532115  0.06003204 -0.00899482  0.18507092]


  0%|          | 0/10 [00:00<?, ?it/s]

120-th round
average train loss -0.067 | test loss -0.582 | test acc: 0.583


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.10495033  0.02766892  0.26749593  0.09265786  0.09760473  0.03152221
  0.33119214 -0.36007363  0.46620142 -0.05921992]


  0%|          | 0/10 [00:00<?, ?it/s]

121-th round
average train loss -0.0648 | test loss -0.594 | test acc: 0.595


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


[ 0.22197513  0.02042673  0.06671362 -0.06406066  0.05540808  0.14335
 -0.05996799  0.13651138  0.26905204  0.21059167]


  0%|          | 0/10 [00:00<?, ?it/s]

122-th round
average train loss -0.071 | test loss -0.603 | test acc: 0.603


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


[ 0.09096219 -0.02738038 -0.0256537   0.61037618 -0.15837886  0.17440721
  0.29766728  0.031951    0.17698561 -0.17093653]


  0%|          | 0/10 [00:00<?, ?it/s]

123-th round
average train loss -0.0722 | test loss -0.606 | test acc: 0.607


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[-0.19780027  0.25718154 -0.00690276  0.00547735  0.2077033   0.11572097
  0.31083763  0.06312439  0.17666244  0.06799542]


  0%|          | 0/10 [00:00<?, ?it/s]

124-th round
average train loss -0.0713 | test loss -0.611 | test acc: 0.612


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


[ 0.71884655 -0.03605731  0.05224737 -0.00914826  0.17498604 -0.53426421
 -0.06874619 -0.23736545 -0.01668707  0.95618853]


  0%|          | 0/10 [00:00<?, ?it/s]

125-th round
average train loss -0.052 | test loss -0.608 | test acc: 0.610


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.19212894  0.4890258   0.01410548 -0.00924922  0.0082761   0.033214
  0.06951581  0.04295644  0.09366087  0.06636579]


  0%|          | 0/10 [00:00<?, ?it/s]

126-th round
average train loss -0.0581 | test loss -0.615 | test acc: 0.616


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[ 0.1737501   0.48844096  0.01098138 -0.00363281  0.09071601  0.0235136
  0.01599681 -0.00135603  0.10481004  0.09677996]


  0%|          | 0/10 [00:00<?, ?it/s]

127-th round
average train loss -0.0686 | test loss -0.618 | test acc: 0.618


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


[ 2.85737027e-01 -7.21904401e-02  3.13367795e-02  3.15825800e-01
  2.35413846e-01  2.13247977e-01 -1.18694045e-01  1.46709735e-01
 -4.15448410e-05 -3.73451362e-02]


  0%|          | 0/10 [00:00<?, ?it/s]

128-th round
average train loss -0.0686 | test loss -0.62 | test acc: 0.621


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[ 0.47569118 -0.72977281 -0.30143211 -0.23779837 -0.17263421 -0.3908493
  0.92449071  1.65924886 -0.24672578  0.01978183]


  0%|          | 0/10 [00:00<?, ?it/s]

129-th round
average train loss -0.0617 | test loss -0.604 | test acc: 0.605


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 1.23826112e+00 -7.09323312e-02 -9.29806262e-03 -2.14744350e-02
 -8.95508321e-02 -7.33981378e-03 -5.45609814e-03  3.77413776e-04
 -2.05016867e-02 -1.40852740e-02]


  0%|          | 0/10 [00:00<?, ?it/s]

130-th round
average train loss -0.0707 | test loss -0.618 | test acc: 0.619


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.10918228  0.10584739  0.05619056  0.03261869  0.03634701 -0.12721497
  0.02101258 -0.00160791 -0.00762402  0.77524839]


  0%|          | 0/10 [00:00<?, ?it/s]

131-th round
average train loss -0.0628 | test loss -0.621 | test acc: 0.621


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.04049674  0.15044192  0.3781785   0.00108657  0.09290777 -0.00189324
  0.05148964  0.02146242  0.02866002  0.23716967]


  0%|          | 0/10 [00:00<?, ?it/s]

132-th round
average train loss -0.0737 | test loss -0.624 | test acc: 0.625


100%|██████████| 10/10 [00:17<00:00,  1.70s/it]


[ 0.14762006 -0.00357961  0.02562193  0.0341272   0.03478439  0.62371081
 -0.05141902 -0.07800671  0.15573646  0.1114045 ]


  0%|          | 0/10 [00:00<?, ?it/s]

133-th round
average train loss -0.0583 | test loss -0.625 | test acc: 0.626


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.33894139  0.06611378  0.06489275  0.24325543  0.05520431 -0.02484206
  0.06391703  0.13702261 -0.12647455  0.18196932]


  0%|          | 0/10 [00:00<?, ?it/s]

134-th round
average train loss -0.0588 | test loss -0.627 | test acc: 0.627


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


[ 0.03755951  0.0768865   0.03377828  0.85700556 -0.16961884 -0.02834922
  0.87659601 -0.42273849 -0.32049999  0.0593807 ]


  0%|          | 0/10 [00:00<?, ?it/s]

135-th round
average train loss -0.0572 | test loss -0.622 | test acc: 0.622


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[ 0.01849603  0.19549291 -0.07419105  0.13143451  0.05460456  0.31956977
  0.06361667  0.01420699  0.17884678  0.09792281]


  0%|          | 0/10 [00:00<?, ?it/s]

136-th round
average train loss -0.0605 | test loss -0.628 | test acc: 0.628


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


[-0.00194989 -0.01621665  0.12909622 -0.34152498  0.28034483 -0.2299056
  1.06897246 -0.13278528  0.07987879  0.1640901 ]


  0%|          | 0/10 [00:00<?, ?it/s]

137-th round
average train loss -0.0706 | test loss -0.625 | test acc: 0.626


100%|██████████| 10/10 [00:16<00:00,  1.69s/it]


[-0.09188132  0.4802992  -0.00487399  0.22650054 -0.1133777   0.28376946
  0.19411507 -0.00778376  0.03996798 -0.00673547]


  0%|          | 0/10 [00:00<?, ?it/s]

138-th round
average train loss -0.0695 | test loss -0.629 | test acc: 0.630


100%|██████████| 10/10 [00:17<00:00,  1.76s/it]


[-0.02771082 -0.02696441  0.06801885 -0.04795462  0.01306022  0.27123899
  0.53091804 -0.11277679  0.14299043  0.18918011]


  0%|          | 0/10 [00:00<?, ?it/s]

139-th round
average train loss -0.0505 | test loss -0.629 | test acc: 0.631


100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


[ 0.10842179  0.70995826  0.14378859  0.03982155  0.1461573   0.05826953
  0.05361843 -0.03349674 -0.23208932  0.0055506 ]


  0%|          | 0/10 [00:00<?, ?it/s]

140-th round
average train loss -0.0627 | test loss -0.63 | test acc: 0.631


100%|██████████| 10/10 [00:17<00:00,  1.78s/it]


[-0.78428453 -0.00544928  0.42924918  0.50323302  0.05472764  0.2846765
  0.16670941 -0.01771874  0.08449153  0.28436527]


  0%|          | 0/10 [00:00<?, ?it/s]

141-th round
average train loss -0.077 | test loss -0.629 | test acc: 0.630


100%|██████████| 10/10 [00:17<00:00,  1.74s/it]


[-0.01013843  0.01162231  0.08470237  0.14137684  0.09784675  0.27266202
 -0.01581982 -0.00758267  0.4457193  -0.02038866]


  0%|          | 0/10 [00:00<?, ?it/s]

142-th round
average train loss -0.0628 | test loss -0.631 | test acc: 0.632


100%|██████████| 10/10 [00:18<00:00,  1.81s/it]


[ 0.32971578 -0.15048764  0.11543083  0.22200411  0.2060676   0.0172036
  0.08528112 -0.10223527  0.0670358   0.20998406]


  0%|          | 0/10 [00:00<?, ?it/s]

143-th round
average train loss -0.0667 | test loss -0.631 | test acc: 0.632


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.06901797  0.09069634  0.09666475  0.10413804  0.53279677  0.06817726
  0.03877352 -0.0290019  -0.01669464  0.04543191]


  0%|          | 0/10 [00:00<?, ?it/s]

144-th round
average train loss -0.0651 | test loss -0.632 | test acc: 0.632


100%|██████████| 10/10 [00:17<00:00,  1.80s/it]


[ 0.04748576  0.5417029   0.16214705  0.00463968  0.01139179  0.20023876
  0.02083704 -0.03314482  0.08792109 -0.04321926]


  0%|          | 0/10 [00:00<?, ?it/s]

145-th round
average train loss -0.0635 | test loss -0.632 | test acc: 0.633


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[-0.03330334  0.3533866  -0.00167889  0.03544209  0.09840456  0.08234149
  0.02098969  0.03691404  0.12685783  0.28064593]


  0%|          | 0/10 [00:00<?, ?it/s]

146-th round
average train loss -0.0693 | test loss -0.632 | test acc: 0.633


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


[ 0.0043246  -0.00539908  0.06583863  0.08709059 -0.0328614   0.09869506
  0.55924691  0.13910724 -0.06280807  0.14676552]


  0%|          | 0/10 [00:00<?, ?it/s]

147-th round
average train loss -0.0706 | test loss -0.633 | test acc: 0.633


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


[-0.04760065  0.31116381  0.44633389 -0.09138235  0.21250702 -0.0588201
 -0.06661418  0.13087247  0.28178003 -0.11823994]


  0%|          | 0/10 [00:00<?, ?it/s]

148-th round
average train loss -0.0543 | test loss -0.633 | test acc: 0.633


100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


[ 0.05461634  0.00105745  0.2356162   0.05148909  0.05244168  0.10878594
  0.34429206  0.21264746  0.06026211 -0.12120832]
149-th round
average train loss -0.076 | test loss -0.634 | test acc: 0.634


In [8]:
# import random
# choose = (random.randrange(len(datareally['y'])))

# datareally = (dd['f1816_24'])
# x1 = datareally['x'][choose]
# print(datareally['y'][choose])

# import numpy as np
# x1 = np.zeros((784,63))
# counters = np.zeros((63))
# print(x1.shape)
# for i in range(len(datareally['y'])):
#     print(np.array(datareally['x'][i]).shape)
#     label = datareally['y'][i]
#     counters[label] += 1 
#     x1[:, label] += np.array(datareally['x'][i])

# import matplotlib.pyplot as plt
# fig, axes = plt.subplots(8,8, figsize=(8,8))
# for i,ax in enumerate(axes.flat):
#     tempol = x1[:, i]/counters[i]
#     xplot = np.reshape(np.ravel(tempol), (28, 28))
#     ax.imshow(xplot)

