In [1]:
import sys
sys.path.append('../code')
from resnet import *
from cifar_very_tiny import *
from cifar_dataset import *    
import torch as t 
import numpy as np
import tqdm
import matplotlib.pylab as plt
import matplotlib.cm as cm
import json
import hyperparams
from importlib import reload

%matplotlib inline
plt.rcParams['figure.figsize']=(12,9)
plt.rcParams['font.size']= 20

In [2]:
# добавил в загрузку валидационную выборку
# обрати внимание, maxsize --- это размер совокпного обучения и валидации
# поэтому размер обучающей выборки совпадает с тем, что было до этого
_, test_loader, train_loader_no_augumentation, valid_loader = cifar10_loader(batch_size=128, split_train_val=True,
                                                                             maxsize=10112*2)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epoch_num = 100
run_num = 5 # количество запусков эксперимента
# версия нужна, чтобы различать старые и новые результаты экспериментов. 
# менять нужно каждый раз, когда есть хотя бы незначительные изменения в эксперименте
experiment_version = '18' 

validate_every_epoch = 5 # каждые 5 эпох отслеживать параметры модели

# с этими гиперпараметрами мы начинаем эксперименты
start_beta = 0.9914 #0.3 
start_temp  = 6.5 #10**(0.5)

In [4]:
def accuracy(student):
        student.eval()
        total = 0 
        correct = 0
        with t.no_grad():
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)
                out = student(x)
                correct += t.eq(t.argmax(out, 1), y).sum()
                total+=len(x)
        student.train()
        return (correct/total).cpu().detach().numpy()

In [5]:
# запуск без дистилляции
for _ in range(run_num):
    internal_results = []
    student = Cifar_Very_Tiny(10).to(device)
    optim = t.optim.Adam(student.parameters())    
    crit = nn.CrossEntropyLoss()
    for e in range(epoch_num):
        tq = tqdm.tqdm(train_loader_no_augumentation)
        losses = []
        for x,y in tq:
            x = x.to(device)
            y = y.to(device)
            student.zero_grad()            
            loss = crit(student(x), y)
            losses.append(loss.cpu().detach().numpy())
            loss.backward()
            optim.step()
            tq.set_description('current loss:{}'.format(np.mean(losses[-10:])))        
        if e==0 or (e+1)%validate_every_epoch == 0: # если номер эпохи делится на 5 или эпоха - первая             
            test_loss = []
            student.eval()
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)                            
                test_loss.append(crit(student(x), y).detach().cpu().numpy())                 
            test_loss = float(np.mean(test_loss))
            acc = float(accuracy(student))
            student.train()
            internal_results.append({'epoch': e, 'test loss':test_loss, 'accuracy':acc})
            print (internal_results[-1])

    with open('exp'+experiment_version+'_basic.jsonl', 'a') as out:
        out.write(json.dumps({'results':internal_results, 'version': experiment_version})+'\n')

RuntimeError: CUDA error: out of memory

In [6]:
kl = nn.KLDivLoss(reduction='batchmean')
sm = nn.Softmax(dim=1)

def distill(out, batch_logits, temp):
    g = sm(out/temp)
    f = F.log_softmax(batch_logits/temp)    
    return kl(f, g)

In [10]:
# Запуск --- с CNN-дистилляцией
# в качестве значений гиперпараметров ставим  start_beta, start_temp
logits = np.load('./logits_cnn.npy')
for _ in range(run_num):
    internal_results = []
    beta = start_beta
    temp = start_temp
    student = Cifar_Very_Tiny(10).to(device)
    optim = t.optim.Adam(student.parameters())   
    crit = nn.CrossEntropyLoss()
    for e in range(epoch_num):
        tq = tqdm.tqdm(train_loader_no_augumentation)
        losses = []
        for batch_id, (x,y) in enumerate(tq):
            x = x.to(device)
            y = y.to(device)            
            batch_logits = t.Tensor(logits[128*batch_id:128*(batch_id+1)]).to(device)            
            student.zero_grad()
            out = student(x)
            student_loss = crit(out, y)            
            distillation_loss = distill(out, batch_logits, temp)
            loss = (1-beta) * student_loss + beta*distillation_loss
            losses.append(loss.cpu().detach().numpy())
            loss.backward()
            optim.step()
            tq.set_description('current loss:{}'.format(np.mean(losses[-10:])))
        if e==0 or (e+1)%validate_every_epoch == 0: # если номер эпохи делится на 5 или эпоха - первая             
            test_loss = []
            student.eval()
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)                            
                test_loss.append(crit(student(x), y).detach().cpu().numpy())                 
            test_loss = float(np.mean(test_loss))
            acc = float(accuracy(student))
            student.train()
            internal_results.append({'epoch': e, 'test loss':test_loss, 'accuracy':acc})
            print (internal_results[-1])

            
    with open('exp'+experiment_version+'_distill.jsonl', 'a') as out:
        out.write(json.dumps({'results':internal_results, 'version': experiment_version})+'\n')

FileNotFoundError: [Errno 2] No such file or directory: './logits_cnn.npy'

In [11]:
# Запуск --- со случаными значениями гиперпараметров
crit = nn.CrossEntropyLoss()

# определяем функцию потерь как замкнутую относительно аргументов функцию
# нужно для подсчета градиентов гиперпараметров по двухуровневой оптимизации
def param_loss(batch,model,h):
    x,y,batch_logits = batch    
    beta,beta2,temp = h
    out = model(x)
    beta = F.sigmoid(beta)
    beta2 = F.sigmoid(beta2)
    temp = F.sigmoid(temp) * 10
    distillation_loss = distill(out, batch_logits, temp)
    student_loss = crit(out, y)                
    loss = beta * distillation_loss + beta2 * student_loss
    return loss

logits = np.load('../code/logits_cnn.npy')
for _ in range(run_num):
    internal_results = []
    
    # теперь beta и temp - не числа, а тензоры, по которым можно считать градиент
    beta1 = t.nn.Parameter(t.tensor(np.random.uniform(low=-1, high=1), device=device), requires_grad=True)
    beta2 = t.nn.Parameter(t.tensor(np.random.uniform(low=-1, high=1), device=device), requires_grad=True)
    temp = t.nn.Parameter(t.tensor(np.random.uniform(low=-2, high=0), device=device), requires_grad=True)    
    h = [beta1, beta2, temp]
    
    student = Cifar_Very_Tiny(10).to(device)
    optim = t.optim.Adam(student.parameters())   
    
 
    
    for e in range(epoch_num): # хочется посмотреть куда сойдутся гиперпараметры, поэтому возьмем побольше эпох
        tq = tqdm.tqdm(train_loader_no_augumentation)
        losses = []
        for batch_id, ((x,y)) in enumerate(tq):
            x = x.to(device)
            y = y.to(device)            
            batch_logits = t.Tensor(logits[128*batch_id:128*(batch_id+1)]).to(device) 
            
            optim.zero_grad()
            loss = param_loss((x,y,batch_logits), student,h)
            losses.append(loss.cpu().detach().numpy())
            loss.backward()
            optim.step()
            tq.set_description('current loss:{}'.format(np.mean(losses[-10:])))
        if e==0 or (e+1)%validate_every_epoch == 0: # если номер эпохи делится на 5 или эпоха - первая             
            test_loss = []
            student.eval()
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)                            
                test_loss.append(crit(student(x), y).detach().cpu().numpy())                 
            test_loss = float(np.mean(test_loss))
            
            
            acc = float(accuracy(student))
            student.train()
            internal_results.append({'epoch': e, 'test loss':test_loss, 'accuracy':acc, 
                                     'temp':float(10*F.sigmoid(h[2]).cpu().detach().numpy()),
                                     'beta1':float(F.sigmoid(h[0]).cpu().detach().numpy()),
                                     'beta2':float(F.sigmoid(h[1]).cpu().detach().numpy())})
            
            print (internal_results[-1])

            
    with open('exp'+experiment_version+'_dist_h_rand.jsonl', 'a') as out:
        out.write(json.dumps({'results':internal_results, 'version': experiment_version})+'\n')

  f = F.log_softmax(batch_logits/temp)
current loss:1.5960102081298828: 100%|██████████| 79/79 [00:02<00:00, 37.20it/s]
current loss:1.5144392251968384:   5%|▌         | 4/79 [00:00<00:02, 35.73it/s]

{'epoch': 0, 'test loss': 1.6371701955795288, 'accuracy': 0.39069998264312744, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:1.4105151891708374: 100%|██████████| 79/79 [00:02<00:00, 37.93it/s]
current loss:1.3029096126556396: 100%|██████████| 79/79 [00:02<00:00, 38.36it/s]
current loss:1.2245540618896484: 100%|██████████| 79/79 [00:02<00:00, 37.72it/s]
current loss:1.1704442501068115: 100%|██████████| 79/79 [00:02<00:00, 37.13it/s]
current loss:1.1405836343765259:   5%|▌         | 4/79 [00:00<00:02, 37.06it/s]

{'epoch': 4, 'test loss': 1.5349537134170532, 'accuracy': 0.48240000009536743, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:1.1241776943206787: 100%|██████████| 79/79 [00:02<00:00, 39.29it/s]
current loss:1.0838569402694702: 100%|██████████| 79/79 [00:02<00:00, 39.06it/s]
current loss:1.0467839241027832: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s]
current loss:1.0155231952667236: 100%|██████████| 79/79 [00:02<00:00, 38.69it/s]
current loss:0.9861445426940918: 100%|██████████| 79/79 [00:02<00:00, 39.09it/s]
current loss:0.9413947463035583:   5%|▌         | 4/79 [00:00<00:01, 38.27it/s]

{'epoch': 9, 'test loss': 1.5232237577438354, 'accuracy': 0.5163999795913696, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.9588192701339722: 100%|██████████| 79/79 [00:02<00:00, 38.65it/s]
current loss:0.9337289929389954: 100%|██████████| 79/79 [00:02<00:00, 38.91it/s]
current loss:0.9107760190963745: 100%|██████████| 79/79 [00:02<00:00, 38.76it/s]
current loss:0.8922513127326965: 100%|██████████| 79/79 [00:02<00:00, 38.41it/s]
current loss:0.8778277635574341: 100%|██████████| 79/79 [00:02<00:00, 39.03it/s]
current loss:0.81922847032547:   5%|▌         | 4/79 [00:00<00:02, 36.68it/s]  

{'epoch': 14, 'test loss': 1.607627272605896, 'accuracy': 0.5275999903678894, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.8582963943481445: 100%|██████████| 79/79 [00:02<00:00, 38.44it/s]
current loss:0.8443190455436707: 100%|██████████| 79/79 [00:02<00:00, 39.25it/s]
current loss:0.8239636421203613: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s]
current loss:0.8073336482048035: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s]
current loss:0.7944084405899048: 100%|██████████| 79/79 [00:02<00:00, 39.26it/s]
current loss:0.7320681214332581:   5%|▌         | 4/79 [00:00<00:02, 36.76it/s]

{'epoch': 19, 'test loss': 1.7182430028915405, 'accuracy': 0.5331000089645386, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.783896267414093: 100%|██████████| 79/79 [00:02<00:00, 38.82it/s] 
current loss:0.7730891704559326: 100%|██████████| 79/79 [00:02<00:00, 39.41it/s]
current loss:0.768798291683197: 100%|██████████| 79/79 [00:02<00:00, 38.63it/s] 
current loss:0.7551109790802002: 100%|██████████| 79/79 [00:02<00:00, 39.15it/s]
current loss:0.7394734621047974: 100%|██████████| 79/79 [00:02<00:00, 39.32it/s]
current loss:0.6798076033592224:   5%|▌         | 4/79 [00:00<00:01, 38.67it/s]

{'epoch': 24, 'test loss': 1.7722703218460083, 'accuracy': 0.5317999720573425, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.7238637804985046: 100%|██████████| 79/79 [00:02<00:00, 38.44it/s]
current loss:0.7106261849403381: 100%|██████████| 79/79 [00:02<00:00, 39.18it/s]
current loss:0.6961642503738403: 100%|██████████| 79/79 [00:02<00:00, 39.17it/s]
current loss:0.6856549978256226: 100%|██████████| 79/79 [00:02<00:00, 38.83it/s]
current loss:0.6754071712493896: 100%|██████████| 79/79 [00:02<00:00, 38.72it/s]
current loss:0.6342300772666931:   5%|▌         | 4/79 [00:00<00:01, 37.67it/s]

{'epoch': 29, 'test loss': 1.9106009006500244, 'accuracy': 0.5277999639511108, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.6666207313537598: 100%|██████████| 79/79 [00:02<00:00, 38.89it/s]
current loss:0.6614529490470886: 100%|██████████| 79/79 [00:02<00:00, 38.84it/s]
current loss:0.6568856239318848: 100%|██████████| 79/79 [00:02<00:00, 38.63it/s]
current loss:0.6544619202613831: 100%|██████████| 79/79 [00:02<00:00, 39.27it/s]
current loss:0.650668740272522: 100%|██████████| 79/79 [00:02<00:00, 38.57it/s] 
current loss:0.6004081964492798:   5%|▌         | 4/79 [00:00<00:02, 35.22it/s]

{'epoch': 34, 'test loss': 2.239879608154297, 'accuracy': 0.5108000040054321, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.652887761592865: 100%|██████████| 79/79 [00:02<00:00, 38.56it/s] 
current loss:0.6533041596412659: 100%|██████████| 79/79 [00:02<00:00, 39.17it/s]
current loss:0.6506632566452026: 100%|██████████| 79/79 [00:02<00:00, 38.73it/s]
current loss:0.6480476260185242: 100%|██████████| 79/79 [00:02<00:00, 38.65it/s]
current loss:0.6474117040634155: 100%|██████████| 79/79 [00:02<00:00, 39.37it/s]
current loss:0.5998724102973938:   5%|▌         | 4/79 [00:00<00:02, 37.20it/s]

{'epoch': 39, 'test loss': 2.3187503814697266, 'accuracy': 0.5126999616622925, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.6354327201843262: 100%|██████████| 79/79 [00:02<00:00, 38.95it/s]
current loss:0.6243346929550171: 100%|██████████| 79/79 [00:02<00:00, 39.29it/s]
current loss:0.6133341789245605: 100%|██████████| 79/79 [00:02<00:00, 38.63it/s]
current loss:0.593894898891449: 100%|██████████| 79/79 [00:02<00:00, 38.89it/s] 
current loss:0.5799328684806824: 100%|██████████| 79/79 [00:02<00:00, 39.18it/s]
current loss:0.5297158360481262:   5%|▌         | 4/79 [00:00<00:02, 37.42it/s]

{'epoch': 44, 'test loss': 2.2390029430389404, 'accuracy': 0.5214999914169312, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.5692687630653381: 100%|██████████| 79/79 [00:02<00:00, 38.83it/s] 
current loss:0.5626218318939209: 100%|██████████| 79/79 [00:02<00:00, 38.50it/s] 
current loss:0.553619384765625: 100%|██████████| 79/79 [00:01<00:00, 39.72it/s]  
current loss:0.5465501546859741: 100%|██████████| 79/79 [00:02<00:00, 39.20it/s] 
current loss:0.5357116460800171: 100%|██████████| 79/79 [00:02<00:00, 38.82it/s] 
current loss:0.49664750695228577:   5%|▌         | 4/79 [00:00<00:02, 37.41it/s]

{'epoch': 49, 'test loss': 2.369778633117676, 'accuracy': 0.5221999883651733, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.5263174176216125: 100%|██████████| 79/79 [00:02<00:00, 39.00it/s] 
current loss:0.5206528902053833: 100%|██████████| 79/79 [00:02<00:00, 38.91it/s] 
current loss:0.5129932165145874: 100%|██████████| 79/79 [00:02<00:00, 38.79it/s] 
current loss:0.509372353553772: 100%|██████████| 79/79 [00:02<00:00, 38.93it/s]  
current loss:0.5086177587509155: 100%|██████████| 79/79 [00:02<00:00, 38.97it/s] 
current loss:0.47425299882888794:   5%|▌         | 4/79 [00:00<00:01, 39.67it/s]

{'epoch': 54, 'test loss': 2.4762070178985596, 'accuracy': 0.526199996471405, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.5095704793930054: 100%|██████████| 79/79 [00:02<00:00, 38.94it/s] 
current loss:0.5111590623855591: 100%|██████████| 79/79 [00:02<00:00, 38.82it/s] 
current loss:0.5241861939430237: 100%|██████████| 79/79 [00:02<00:00, 38.58it/s] 
current loss:0.5506923794746399: 100%|██████████| 79/79 [00:01<00:00, 39.63it/s] 
current loss:0.5679622888565063: 100%|██████████| 79/79 [00:02<00:00, 39.03it/s] 
current loss:0.48071956634521484:   5%|▌         | 4/79 [00:00<00:01, 39.09it/s]

{'epoch': 59, 'test loss': 2.610971212387085, 'accuracy': 0.5200999975204468, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.5597604513168335: 100%|██████████| 79/79 [00:02<00:00, 39.02it/s] 
current loss:0.552208662033081: 100%|██████████| 79/79 [00:02<00:00, 39.25it/s]  
current loss:0.5308359265327454: 100%|██████████| 79/79 [00:02<00:00, 38.46it/s] 
current loss:0.5037561655044556: 100%|██████████| 79/79 [00:02<00:00, 39.32it/s] 
current loss:0.48582831025123596: 100%|██████████| 79/79 [00:02<00:00, 39.12it/s]
current loss:0.4253630042076111:   5%|▌         | 4/79 [00:00<00:01, 38.63it/s] 

{'epoch': 64, 'test loss': 2.75834321975708, 'accuracy': 0.5102999806404114, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.47640419006347656: 100%|██████████| 79/79 [00:02<00:00, 38.61it/s]
current loss:0.46889644861221313: 100%|██████████| 79/79 [00:02<00:00, 39.12it/s]
current loss:0.4647294580936432: 100%|██████████| 79/79 [00:02<00:00, 38.85it/s] 
current loss:0.4584428668022156: 100%|██████████| 79/79 [00:01<00:00, 39.55it/s] 
current loss:0.4582298696041107: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s] 
current loss:0.39456668496131897:   5%|▌         | 4/79 [00:00<00:01, 37.73it/s]

{'epoch': 69, 'test loss': 2.919712543487549, 'accuracy': 0.5103999972343445, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.45655736327171326: 100%|██████████| 79/79 [00:02<00:00, 38.57it/s]
current loss:0.45208925008773804: 100%|██████████| 79/79 [00:02<00:00, 38.49it/s]
current loss:0.4473051130771637: 100%|██████████| 79/79 [00:02<00:00, 38.96it/s] 
current loss:0.4468528628349304: 100%|██████████| 79/79 [00:01<00:00, 39.57it/s] 
current loss:0.44671887159347534: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s]
current loss:0.3732871115207672:   5%|▌         | 4/79 [00:00<00:01, 39.76it/s] 

{'epoch': 74, 'test loss': 2.965271234512329, 'accuracy': 0.5062999725341797, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.4448484480381012: 100%|██████████| 79/79 [00:02<00:00, 39.43it/s] 
current loss:0.442873477935791: 100%|██████████| 79/79 [00:01<00:00, 39.51it/s]  
current loss:0.44493111968040466: 100%|██████████| 79/79 [00:02<00:00, 38.86it/s]
current loss:0.4443860650062561: 100%|██████████| 79/79 [00:02<00:00, 38.83it/s] 
current loss:0.43697911500930786: 100%|██████████| 79/79 [00:02<00:00, 39.20it/s]
current loss:0.3787512183189392:   5%|▌         | 4/79 [00:00<00:01, 39.16it/s] 

{'epoch': 79, 'test loss': 3.009542226791382, 'accuracy': 0.5130000114440918, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.4440193772315979: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s] 
current loss:0.45250368118286133: 100%|██████████| 79/79 [00:02<00:00, 38.82it/s]
current loss:0.4434415400028229: 100%|██████████| 79/79 [00:02<00:00, 38.82it/s] 
current loss:0.44719696044921875: 100%|██████████| 79/79 [00:01<00:00, 39.53it/s]
current loss:0.4544309675693512: 100%|██████████| 79/79 [00:02<00:00, 38.34it/s] 
current loss:0.3821108937263489:   6%|▋         | 5/79 [00:00<00:01, 40.01it/s] 

{'epoch': 84, 'test loss': 3.0682780742645264, 'accuracy': 0.505899965763092, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.44917458295822144: 100%|██████████| 79/79 [00:02<00:00, 39.01it/s]
current loss:0.43252310156822205: 100%|██████████| 79/79 [00:02<00:00, 39.31it/s]
current loss:0.4202929437160492: 100%|██████████| 79/79 [00:02<00:00, 38.48it/s] 
current loss:0.4098083972930908: 100%|██████████| 79/79 [00:02<00:00, 38.55it/s] 
current loss:0.4009183347225189: 100%|██████████| 79/79 [00:02<00:00, 38.96it/s] 
current loss:0.3485493063926697:   6%|▋         | 5/79 [00:00<00:01, 40.08it/s] 

{'epoch': 89, 'test loss': 3.1625051498413086, 'accuracy': 0.5066999793052673, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.39604371786117554: 100%|██████████| 79/79 [00:02<00:00, 39.30it/s]
current loss:0.3874768614768982: 100%|██████████| 79/79 [00:02<00:00, 39.37it/s] 
current loss:0.38485151529312134: 100%|██████████| 79/79 [00:02<00:00, 38.84it/s]
current loss:0.3806740641593933: 100%|██████████| 79/79 [00:02<00:00, 39.49it/s] 
current loss:0.3770963251590729: 100%|██████████| 79/79 [00:01<00:00, 39.57it/s] 
current loss:0.33689016103744507:   6%|▋         | 5/79 [00:00<00:01, 40.79it/s]

{'epoch': 94, 'test loss': 3.2295210361480713, 'accuracy': 0.5102999806404114, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:0.3763147294521332: 100%|██████████| 79/79 [00:02<00:00, 38.95it/s] 
current loss:0.3724122643470764: 100%|██████████| 79/79 [00:02<00:00, 39.26it/s] 
current loss:0.3709509074687958: 100%|██████████| 79/79 [00:01<00:00, 39.78it/s] 
current loss:0.37264224886894226: 100%|██████████| 79/79 [00:01<00:00, 39.69it/s]
current loss:0.3720141649246216: 100%|██████████| 79/79 [00:02<00:00, 39.38it/s] 
current loss:3.3858792781829834:   5%|▌         | 4/79 [00:00<00:02, 37.30it/s]

{'epoch': 99, 'test loss': 3.3103854656219482, 'accuracy': 0.5076999664306641, 'temp': 3.068796694278717, 'beta1': 0.478068470954895, 'beta2': 0.5175783634185791}


current loss:2.423539638519287: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s] 
current loss:2.2782561779022217:   5%|▌         | 4/79 [00:00<00:01, 39.04it/s]

{'epoch': 0, 'test loss': 1.6697720289230347, 'accuracy': 0.37449997663497925, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:2.085214138031006: 100%|██████████| 79/79 [00:02<00:00, 38.96it/s] 
current loss:1.8229516744613647: 100%|██████████| 79/79 [00:02<00:00, 39.11it/s]
current loss:1.6474193334579468: 100%|██████████| 79/79 [00:01<00:00, 39.88it/s]
current loss:1.5267300605773926: 100%|██████████| 79/79 [00:02<00:00, 38.92it/s]
current loss:1.477131724357605:   5%|▌         | 4/79 [00:00<00:02, 36.54it/s] 

{'epoch': 4, 'test loss': 1.453606128692627, 'accuracy': 0.534500002861023, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:1.4359080791473389: 100%|██████████| 79/79 [00:02<00:00, 38.50it/s]
current loss:1.359168291091919: 100%|██████████| 79/79 [00:02<00:00, 38.76it/s] 
current loss:1.3036820888519287: 100%|██████████| 79/79 [00:02<00:00, 38.92it/s]
current loss:1.2511323690414429: 100%|██████████| 79/79 [00:02<00:00, 39.08it/s]
current loss:1.2028300762176514: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s]
current loss:1.1974743604660034:   5%|▌         | 4/79 [00:00<00:01, 39.53it/s]

{'epoch': 9, 'test loss': 1.5240607261657715, 'accuracy': 0.5557000041007996, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:1.1588349342346191: 100%|██████████| 79/79 [00:02<00:00, 39.45it/s]
current loss:1.1265394687652588: 100%|██████████| 79/79 [00:02<00:00, 39.42it/s]
current loss:1.0921647548675537: 100%|██████████| 79/79 [00:02<00:00, 38.91it/s]
current loss:1.0577216148376465: 100%|██████████| 79/79 [00:02<00:00, 38.19it/s]
current loss:1.0324304103851318: 100%|██████████| 79/79 [00:02<00:00, 39.02it/s]
current loss:1.0324180126190186:   5%|▌         | 4/79 [00:00<00:01, 38.40it/s]

{'epoch': 14, 'test loss': 1.590728998184204, 'accuracy': 0.5625999569892883, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:1.0084675550460815: 100%|██████████| 79/79 [00:02<00:00, 38.90it/s]
current loss:0.9854482412338257: 100%|██████████| 79/79 [00:02<00:00, 39.24it/s]
current loss:0.9659941792488098: 100%|██████████| 79/79 [00:01<00:00, 39.62it/s]
current loss:0.9450408220291138: 100%|██████████| 79/79 [00:02<00:00, 39.30it/s]
current loss:0.9240888357162476: 100%|██████████| 79/79 [00:01<00:00, 39.75it/s]
current loss:0.9180734753608704:   5%|▌         | 4/79 [00:00<00:01, 39.64it/s]

{'epoch': 19, 'test loss': 1.6864498853683472, 'accuracy': 0.5658999681472778, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.9029666781425476: 100%|██████████| 79/79 [00:02<00:00, 38.93it/s]
current loss:0.885595440864563: 100%|██████████| 79/79 [00:02<00:00, 38.84it/s] 
current loss:0.8697985410690308: 100%|██████████| 79/79 [00:02<00:00, 38.59it/s]
current loss:0.8532823324203491: 100%|██████████| 79/79 [00:02<00:00, 38.82it/s]
current loss:0.8397396206855774: 100%|██████████| 79/79 [00:02<00:00, 39.17it/s]
current loss:0.8460346460342407:   5%|▌         | 4/79 [00:00<00:01, 39.11it/s]

{'epoch': 24, 'test loss': 1.7827633619308472, 'accuracy': 0.5667999982833862, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.8313273191452026: 100%|██████████| 79/79 [00:02<00:00, 39.25it/s]
current loss:0.8227823376655579: 100%|██████████| 79/79 [00:01<00:00, 39.54it/s]
current loss:0.8118001222610474: 100%|██████████| 79/79 [00:01<00:00, 39.81it/s]
current loss:0.8213736414909363: 100%|██████████| 79/79 [00:02<00:00, 39.21it/s]
current loss:0.8264223337173462: 100%|██████████| 79/79 [00:01<00:00, 39.51it/s]
current loss:0.809027910232544:   6%|▋         | 5/79 [00:00<00:01, 40.25it/s] 

{'epoch': 29, 'test loss': 1.9346833229064941, 'accuracy': 0.5489000082015991, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.8187869191169739: 100%|██████████| 79/79 [00:02<00:00, 39.35it/s]
current loss:0.7966429591178894: 100%|██████████| 79/79 [00:02<00:00, 38.73it/s]
current loss:0.7705046534538269: 100%|██████████| 79/79 [00:02<00:00, 39.04it/s]
current loss:0.7511571049690247: 100%|██████████| 79/79 [00:02<00:00, 38.75it/s]
current loss:0.7407791018486023: 100%|██████████| 79/79 [00:02<00:00, 39.36it/s]
current loss:0.7425050139427185:   5%|▌         | 4/79 [00:00<00:01, 39.54it/s]

{'epoch': 34, 'test loss': 1.8427491188049316, 'accuracy': 0.568399965763092, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.7305631041526794: 100%|██████████| 79/79 [00:02<00:00, 39.21it/s]
current loss:0.7214564085006714: 100%|██████████| 79/79 [00:02<00:00, 38.12it/s]
current loss:0.7122157216072083: 100%|██████████| 79/79 [00:02<00:00, 39.34it/s]
current loss:0.7034059762954712: 100%|██████████| 79/79 [00:02<00:00, 39.23it/s]
current loss:0.6964942216873169: 100%|██████████| 79/79 [00:02<00:00, 38.93it/s]
current loss:0.689768373966217:   5%|▌         | 4/79 [00:00<00:01, 38.96it/s] 

{'epoch': 39, 'test loss': 1.9760221242904663, 'accuracy': 0.5649999976158142, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.6956966519355774: 100%|██████████| 79/79 [00:02<00:00, 38.95it/s]
current loss:0.6890276670455933: 100%|██████████| 79/79 [00:02<00:00, 38.50it/s]
current loss:0.686158299446106: 100%|██████████| 79/79 [00:02<00:00, 38.78it/s] 
current loss:0.6774289011955261: 100%|██████████| 79/79 [00:02<00:00, 38.75it/s]
current loss:0.6693255305290222: 100%|██████████| 79/79 [00:02<00:00, 38.71it/s]
current loss:0.648051917552948:   5%|▌         | 4/79 [00:00<00:01, 39.00it/s] 

{'epoch': 44, 'test loss': 2.143141508102417, 'accuracy': 0.5622999668121338, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.6637268662452698: 100%|██████████| 79/79 [00:01<00:00, 39.82it/s]
current loss:0.653684139251709: 100%|██████████| 79/79 [00:02<00:00, 39.12it/s] 
current loss:0.6460921764373779: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s]
current loss:0.6430311799049377: 100%|██████████| 79/79 [00:02<00:00, 38.91it/s]
current loss:0.6391170024871826: 100%|██████████| 79/79 [00:02<00:00, 38.76it/s]
current loss:0.6273821592330933:   5%|▌         | 4/79 [00:00<00:01, 39.33it/s]

{'epoch': 49, 'test loss': 2.2679316997528076, 'accuracy': 0.5586999654769897, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.6415634751319885: 100%|██████████| 79/79 [00:02<00:00, 39.06it/s]
current loss:0.6387597322463989: 100%|██████████| 79/79 [00:02<00:00, 39.04it/s]
current loss:0.6267525553703308: 100%|██████████| 79/79 [00:02<00:00, 39.49it/s]
current loss:0.6153958439826965: 100%|██████████| 79/79 [00:02<00:00, 38.99it/s]
current loss:0.6295173764228821: 100%|██████████| 79/79 [00:02<00:00, 38.99it/s]
current loss:0.6561276912689209:   6%|▋         | 5/79 [00:00<00:01, 41.10it/s]

{'epoch': 54, 'test loss': 2.2642080783843994, 'accuracy': 0.5647000074386597, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.619256854057312: 100%|██████████| 79/79 [00:01<00:00, 39.54it/s] 
current loss:0.5965635180473328: 100%|██████████| 79/79 [00:02<00:00, 39.46it/s]
current loss:0.5938411951065063: 100%|██████████| 79/79 [00:02<00:00, 39.01it/s]
current loss:0.5873284339904785: 100%|██████████| 79/79 [00:02<00:00, 38.92it/s]
current loss:0.5854867696762085: 100%|██████████| 79/79 [00:02<00:00, 38.74it/s]
current loss:0.574897825717926:   5%|▌         | 4/79 [00:00<00:01, 39.55it/s] 

{'epoch': 59, 'test loss': 2.34574818611145, 'accuracy': 0.5629000067710876, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.5791704654693604: 100%|██████████| 79/79 [00:02<00:00, 39.03it/s]
current loss:0.5733369588851929: 100%|██████████| 79/79 [00:02<00:00, 39.21it/s]
current loss:0.5729442834854126: 100%|██████████| 79/79 [00:01<00:00, 39.53it/s]
current loss:0.5792938470840454: 100%|██████████| 79/79 [00:02<00:00, 39.05it/s]
current loss:0.5842593312263489: 100%|██████████| 79/79 [00:02<00:00, 38.88it/s]
current loss:0.5570176839828491:   6%|▋         | 5/79 [00:00<00:01, 40.13it/s]

{'epoch': 64, 'test loss': 2.579885244369507, 'accuracy': 0.552899956703186, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.590684711933136: 100%|██████████| 79/79 [00:02<00:00, 39.22it/s] 
current loss:0.5902456045150757: 100%|██████████| 79/79 [00:02<00:00, 38.97it/s]
current loss:0.5922551155090332: 100%|██████████| 79/79 [00:02<00:00, 39.44it/s]
current loss:0.5875311493873596: 100%|██████████| 79/79 [00:01<00:00, 39.55it/s]
current loss:0.5768929719924927: 100%|██████████| 79/79 [00:02<00:00, 39.03it/s]
current loss:0.5616417527198792:   5%|▌         | 4/79 [00:00<00:01, 38.62it/s]

{'epoch': 69, 'test loss': 2.7479865550994873, 'accuracy': 0.5432999730110168, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.5711977481842041: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s]
current loss:0.5718890428543091: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s]
current loss:0.5730878114700317: 100%|██████████| 79/79 [00:02<00:00, 38.70it/s]
current loss:0.5677412152290344: 100%|██████████| 79/79 [00:02<00:00, 39.16it/s]
current loss:0.567437469959259: 100%|██████████| 79/79 [00:02<00:00, 39.17it/s] 
current loss:0.581380307674408:   5%|▌         | 4/79 [00:00<00:01, 38.08it/s] 

{'epoch': 74, 'test loss': 3.0017051696777344, 'accuracy': 0.525600016117096, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.5645322799682617: 100%|██████████| 79/79 [00:02<00:00, 38.84it/s]
current loss:0.5622052550315857: 100%|██████████| 79/79 [00:02<00:00, 38.80it/s]
current loss:0.5629364252090454: 100%|██████████| 79/79 [00:02<00:00, 39.28it/s]
current loss:0.5651446580886841: 100%|██████████| 79/79 [00:02<00:00, 39.24it/s]
current loss:0.5555983781814575: 100%|██████████| 79/79 [00:02<00:00, 38.84it/s]
current loss:0.5361375212669373:   5%|▌         | 4/79 [00:00<00:01, 38.94it/s]

{'epoch': 79, 'test loss': 2.83712100982666, 'accuracy': 0.5370999574661255, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.549796462059021: 100%|██████████| 79/79 [00:02<00:00, 39.08it/s] 
current loss:0.54936683177948: 100%|██████████| 79/79 [00:02<00:00, 38.55it/s]  
current loss:0.5331680178642273: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s]
current loss:0.5271758437156677: 100%|██████████| 79/79 [00:02<00:00, 39.19it/s]
current loss:0.5312039852142334: 100%|██████████| 79/79 [00:02<00:00, 38.73it/s]
current loss:0.5170463919639587:   5%|▌         | 4/79 [00:00<00:01, 38.94it/s]

{'epoch': 84, 'test loss': 2.9259445667266846, 'accuracy': 0.5307999849319458, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.5275338888168335: 100%|██████████| 79/79 [00:02<00:00, 38.67it/s]
current loss:0.5187145471572876: 100%|██████████| 79/79 [00:02<00:00, 39.15it/s] 
current loss:0.5025306344032288: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s] 
current loss:0.47770053148269653: 100%|██████████| 79/79 [00:02<00:00, 38.95it/s]
current loss:0.4626256823539734: 100%|██████████| 79/79 [00:02<00:00, 39.12it/s] 
current loss:0.4611351788043976:   5%|▌         | 4/79 [00:00<00:01, 38.93it/s] 

{'epoch': 89, 'test loss': 2.821408271789551, 'accuracy': 0.538599967956543, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.4578748643398285: 100%|██████████| 79/79 [00:02<00:00, 39.11it/s] 
current loss:0.451593816280365: 100%|██████████| 79/79 [00:02<00:00, 38.94it/s]  
current loss:0.44906872510910034: 100%|██████████| 79/79 [00:02<00:00, 38.54it/s]
current loss:0.4474538266658783: 100%|██████████| 79/79 [00:02<00:00, 39.28it/s] 
current loss:0.44626474380493164: 100%|██████████| 79/79 [00:02<00:00, 39.21it/s]
current loss:0.43205326795578003:   5%|▌         | 4/79 [00:00<00:01, 38.93it/s]

{'epoch': 94, 'test loss': 2.819812536239624, 'accuracy': 0.5498999953269958, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:0.4398050904273987: 100%|██████████| 79/79 [00:02<00:00, 39.25it/s] 
current loss:0.4350927472114563: 100%|██████████| 79/79 [00:02<00:00, 39.32it/s] 
current loss:0.43122926354408264: 100%|██████████| 79/79 [00:02<00:00, 39.33it/s]
current loss:0.4329541325569153: 100%|██████████| 79/79 [00:02<00:00, 39.08it/s] 
current loss:0.4349820017814636: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s] 
current loss:2.470644474029541:   6%|▋         | 5/79 [00:00<00:01, 40.77it/s] 

{'epoch': 99, 'test loss': 2.8421289920806885, 'accuracy': 0.5498999953269958, 'temp': 1.8425272405147552, 'beta1': 0.5354272127151489, 'beta2': 0.4940754175186157}


current loss:1.8089697360992432: 100%|██████████| 79/79 [00:02<00:00, 38.77it/s]
current loss:1.7020447254180908:   5%|▌         | 4/79 [00:00<00:01, 39.17it/s]

{'epoch': 0, 'test loss': 1.6454490423202515, 'accuracy': 0.38189998269081116, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:1.514905333518982: 100%|██████████| 79/79 [00:02<00:00, 39.06it/s] 
current loss:1.3315057754516602: 100%|██████████| 79/79 [00:02<00:00, 38.36it/s]
current loss:1.2163641452789307: 100%|██████████| 79/79 [00:02<00:00, 39.09it/s]
current loss:1.134932518005371: 100%|██████████| 79/79 [00:02<00:00, 39.41it/s] 
current loss:1.11722993850708:   5%|▌         | 4/79 [00:00<00:01, 39.35it/s]  

{'epoch': 4, 'test loss': 1.5327696800231934, 'accuracy': 0.5320000052452087, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:1.0719091892242432: 100%|██████████| 79/79 [00:02<00:00, 39.25it/s]
current loss:1.02322256565094: 100%|██████████| 79/79 [00:02<00:00, 38.76it/s]  
current loss:0.984525203704834: 100%|██████████| 79/79 [00:01<00:00, 39.57it/s] 
current loss:0.950641930103302: 100%|██████████| 79/79 [00:02<00:00, 38.89it/s] 
current loss:0.9200813174247742: 100%|██████████| 79/79 [00:02<00:00, 39.32it/s]
current loss:0.9075573086738586:   5%|▌         | 4/79 [00:00<00:01, 38.68it/s]

{'epoch': 9, 'test loss': 1.5292139053344727, 'accuracy': 0.5601999759674072, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.8897256851196289: 100%|██████████| 79/79 [00:02<00:00, 39.05it/s]
current loss:0.8637951612472534: 100%|██████████| 79/79 [00:02<00:00, 38.70it/s]
current loss:0.8426017761230469: 100%|██████████| 79/79 [00:02<00:00, 38.66it/s]
current loss:0.8255480527877808: 100%|██████████| 79/79 [00:02<00:00, 39.28it/s]
current loss:0.8103519678115845: 100%|██████████| 79/79 [00:02<00:00, 39.20it/s]
current loss:0.7879716753959656:   5%|▌         | 4/79 [00:00<00:01, 38.42it/s]

{'epoch': 14, 'test loss': 1.6340559720993042, 'accuracy': 0.5656999945640564, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.7946528196334839: 100%|██████████| 79/79 [00:02<00:00, 38.66it/s]
current loss:0.7784470915794373: 100%|██████████| 79/79 [00:02<00:00, 38.70it/s]
current loss:0.7661382555961609: 100%|██████████| 79/79 [00:02<00:00, 38.84it/s]
current loss:0.7502151131629944: 100%|██████████| 79/79 [00:02<00:00, 39.48it/s]
current loss:0.7327073812484741: 100%|██████████| 79/79 [00:01<00:00, 39.65it/s]
current loss:0.7035826444625854:   6%|▋         | 5/79 [00:00<00:01, 40.92it/s]

{'epoch': 19, 'test loss': 1.8101122379302979, 'accuracy': 0.5597999691963196, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.7188721895217896: 100%|██████████| 79/79 [00:01<00:00, 40.04it/s]
current loss:0.7047933340072632: 100%|██████████| 79/79 [00:02<00:00, 39.07it/s]
current loss:0.694301187992096: 100%|██████████| 79/79 [00:02<00:00, 38.70it/s] 
current loss:0.6854023337364197: 100%|██████████| 79/79 [00:02<00:00, 39.09it/s]
current loss:0.6769618391990662: 100%|██████████| 79/79 [00:02<00:00, 39.26it/s]
current loss:0.6612550616264343:   5%|▌         | 4/79 [00:00<00:01, 39.58it/s]

{'epoch': 24, 'test loss': 2.004169464111328, 'accuracy': 0.5572999715805054, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.6739623546600342: 100%|██████████| 79/79 [00:02<00:00, 38.94it/s]
current loss:0.6729492545127869: 100%|██████████| 79/79 [00:02<00:00, 39.49it/s]
current loss:0.6743632555007935: 100%|██████████| 79/79 [00:02<00:00, 39.12it/s]
current loss:0.6591769456863403: 100%|██████████| 79/79 [00:02<00:00, 39.11it/s]
current loss:0.6443625092506409: 100%|██████████| 79/79 [00:02<00:00, 38.77it/s]
current loss:0.6107137799263:   6%|▋         | 5/79 [00:00<00:01, 39.39it/s]   

{'epoch': 29, 'test loss': 1.8520857095718384, 'accuracy': 0.5780999660491943, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.6415621638298035: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s]
current loss:0.6493255496025085: 100%|██████████| 79/79 [00:02<00:00, 38.50it/s]
current loss:0.6423073410987854: 100%|██████████| 79/79 [00:02<00:00, 38.75it/s]
current loss:0.6315152049064636: 100%|██████████| 79/79 [00:02<00:00, 39.33it/s]
current loss:0.6166327595710754: 100%|██████████| 79/79 [00:02<00:00, 38.79it/s]
current loss:0.5760847926139832:   5%|▌         | 4/79 [00:00<00:02, 37.12it/s]

{'epoch': 34, 'test loss': 2.0125880241394043, 'accuracy': 0.5701000094413757, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.6062019467353821: 100%|██████████| 79/79 [00:02<00:00, 38.66it/s]
current loss:0.5993397831916809: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s]
current loss:0.5888211727142334: 100%|██████████| 79/79 [00:02<00:00, 38.64it/s]
current loss:0.5840606093406677: 100%|██████████| 79/79 [00:02<00:00, 38.89it/s]
current loss:0.5751473903656006: 100%|██████████| 79/79 [00:02<00:00, 39.27it/s]
current loss:0.5260634422302246:   5%|▌         | 4/79 [00:00<00:01, 38.24it/s]

{'epoch': 39, 'test loss': 2.0264904499053955, 'accuracy': 0.5776000022888184, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.5689722299575806: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s]
current loss:0.562309205532074: 100%|██████████| 79/79 [00:01<00:00, 39.76it/s] 
current loss:0.5544530749320984: 100%|██████████| 79/79 [00:01<00:00, 39.57it/s]
current loss:0.5473325252532959: 100%|██████████| 79/79 [00:02<00:00, 38.96it/s]
current loss:0.5432132482528687: 100%|██████████| 79/79 [00:02<00:00, 39.08it/s]
current loss:0.49602729082107544:   6%|▋         | 5/79 [00:00<00:01, 40.33it/s]

{'epoch': 44, 'test loss': 2.0549473762512207, 'accuracy': 0.5805999636650085, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.5396857261657715: 100%|██████████| 79/79 [00:01<00:00, 39.58it/s]
current loss:0.539698600769043: 100%|██████████| 79/79 [00:02<00:00, 39.48it/s]  
current loss:0.5356298089027405: 100%|██████████| 79/79 [00:02<00:00, 39.11it/s] 
current loss:0.5346709489822388: 100%|██████████| 79/79 [00:02<00:00, 39.06it/s] 
current loss:0.5315684080123901: 100%|██████████| 79/79 [00:02<00:00, 38.38it/s] 
current loss:0.48177745938301086:   5%|▌         | 4/79 [00:00<00:01, 39.04it/s]

{'epoch': 49, 'test loss': 2.137270212173462, 'accuracy': 0.5798999667167664, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.525619626045227: 100%|██████████| 79/79 [00:02<00:00, 38.56it/s]  
current loss:0.5219794511795044: 100%|██████████| 79/79 [00:02<00:00, 38.96it/s] 
current loss:0.5203210115432739: 100%|██████████| 79/79 [00:02<00:00, 38.68it/s] 
current loss:0.5176366567611694: 100%|██████████| 79/79 [00:02<00:00, 39.18it/s] 
current loss:0.5153437852859497: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s] 
current loss:0.4691760540008545:   6%|▋         | 5/79 [00:00<00:01, 40.10it/s] 

{'epoch': 54, 'test loss': 2.2966532707214355, 'accuracy': 0.5730999708175659, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.5129520893096924: 100%|██████████| 79/79 [00:02<00:00, 39.09it/s] 
current loss:0.5100735425949097: 100%|██████████| 79/79 [00:02<00:00, 39.38it/s] 
current loss:0.5083873271942139: 100%|██████████| 79/79 [00:02<00:00, 38.91it/s] 
current loss:0.5032636523246765: 100%|██████████| 79/79 [00:02<00:00, 38.97it/s] 
current loss:0.49525755643844604: 100%|██████████| 79/79 [00:02<00:00, 39.12it/s]
current loss:0.469393253326416:   6%|▋         | 5/79 [00:00<00:01, 40.04it/s]  

{'epoch': 59, 'test loss': 2.469184398651123, 'accuracy': 0.5692999958992004, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.49074864387512207: 100%|██████████| 79/79 [00:01<00:00, 39.57it/s]
current loss:0.4881882667541504: 100%|██████████| 79/79 [00:01<00:00, 39.54it/s] 
current loss:0.48912930488586426: 100%|██████████| 79/79 [00:02<00:00, 38.70it/s]
current loss:0.4888829290866852: 100%|██████████| 79/79 [00:02<00:00, 39.13it/s] 
current loss:0.4851200580596924: 100%|██████████| 79/79 [00:01<00:00, 39.90it/s] 
current loss:0.4540731608867645:   5%|▌         | 4/79 [00:00<00:01, 39.24it/s] 

{'epoch': 64, 'test loss': 2.6614129543304443, 'accuracy': 0.5605999827384949, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.4821108877658844: 100%|██████████| 79/79 [00:02<00:00, 39.27it/s] 
current loss:0.4857872426509857: 100%|██████████| 79/79 [00:02<00:00, 39.00it/s] 
current loss:0.4851261079311371: 100%|██████████| 79/79 [00:02<00:00, 39.22it/s] 
current loss:0.48902425169944763: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s]
current loss:0.48160290718078613: 100%|██████████| 79/79 [00:02<00:00, 38.86it/s]
current loss:0.4502229690551758:   6%|▋         | 5/79 [00:00<00:01, 40.05it/s]

{'epoch': 69, 'test loss': 2.5461385250091553, 'accuracy': 0.567799985408783, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.4788668751716614: 100%|██████████| 79/79 [00:01<00:00, 39.59it/s] 
current loss:0.4728289544582367: 100%|██████████| 79/79 [00:02<00:00, 39.00it/s] 
current loss:0.46877479553222656: 100%|██████████| 79/79 [00:02<00:00, 38.82it/s]
current loss:0.4616360664367676: 100%|██████████| 79/79 [00:02<00:00, 39.10it/s] 
current loss:0.46098217368125916: 100%|██████████| 79/79 [00:02<00:00, 38.76it/s]
current loss:0.43421608209609985:   5%|▌         | 4/79 [00:00<00:01, 39.51it/s]

{'epoch': 74, 'test loss': 2.5861732959747314, 'accuracy': 0.5608999729156494, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.46171045303344727: 100%|██████████| 79/79 [00:01<00:00, 39.52it/s]
current loss:0.46178144216537476: 100%|██████████| 79/79 [00:02<00:00, 39.45it/s]
current loss:0.4672144949436188: 100%|██████████| 79/79 [00:02<00:00, 39.15it/s] 
current loss:0.47685885429382324: 100%|██████████| 79/79 [00:01<00:00, 39.52it/s]
current loss:0.4855410158634186: 100%|██████████| 79/79 [00:02<00:00, 38.94it/s] 
current loss:0.46299177408218384:   6%|▋         | 5/79 [00:00<00:01, 39.83it/s]

{'epoch': 79, 'test loss': 2.7042429447174072, 'accuracy': 0.5507999658584595, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.4949016571044922: 100%|██████████| 79/79 [00:02<00:00, 39.47it/s] 
current loss:0.4943664073944092: 100%|██████████| 79/79 [00:01<00:00, 39.51it/s] 
current loss:0.49354735016822815: 100%|██████████| 79/79 [00:02<00:00, 38.97it/s]
current loss:0.4833245277404785: 100%|██████████| 79/79 [00:02<00:00, 39.32it/s] 
current loss:0.48177972435951233: 100%|██████████| 79/79 [00:02<00:00, 39.39it/s]
current loss:0.45067664980888367:   5%|▌         | 4/79 [00:00<00:01, 39.02it/s]

{'epoch': 84, 'test loss': 2.580810070037842, 'accuracy': 0.5590999722480774, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.47394388914108276: 100%|██████████| 79/79 [00:02<00:00, 39.20it/s]
current loss:0.4568236470222473: 100%|██████████| 79/79 [00:02<00:00, 39.12it/s] 
current loss:0.4439719319343567: 100%|██████████| 79/79 [00:02<00:00, 38.87it/s] 
current loss:0.43318867683410645: 100%|██████████| 79/79 [00:02<00:00, 38.90it/s]
current loss:0.42361927032470703: 100%|██████████| 79/79 [00:02<00:00, 39.21it/s]
current loss:0.41115957498550415:   5%|▌         | 4/79 [00:00<00:01, 39.10it/s]

{'epoch': 89, 'test loss': 2.711174964904785, 'accuracy': 0.5584999918937683, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.41844111680984497: 100%|██████████| 79/79 [00:01<00:00, 39.61it/s]
current loss:0.410287082195282: 100%|██████████| 79/79 [00:02<00:00, 39.25it/s]  
current loss:0.4035255014896393: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s] 
current loss:0.40146908164024353: 100%|██████████| 79/79 [00:02<00:00, 39.06it/s]
current loss:0.396422415971756: 100%|██████████| 79/79 [00:02<00:00, 39.26it/s]  
current loss:0.38108888268470764:   5%|▌         | 4/79 [00:00<00:01, 39.71it/s]

{'epoch': 94, 'test loss': 2.837930679321289, 'accuracy': 0.5537999868392944, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:0.39236539602279663: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s]
current loss:0.38860076665878296: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s]
current loss:0.38530707359313965: 100%|██████████| 79/79 [00:02<00:00, 38.37it/s]
current loss:0.3835321068763733: 100%|██████████| 79/79 [00:02<00:00, 38.07it/s] 
current loss:0.3813697099685669: 100%|██████████| 79/79 [00:02<00:00, 39.15it/s] 
current loss:1.6702147722244263:   5%|▌         | 4/79 [00:00<00:01, 39.27it/s]

{'epoch': 99, 'test loss': 2.8840367794036865, 'accuracy': 0.555899977684021, 'temp': 2.7162519097328186, 'beta1': 0.7153030037879944, 'beta2': 0.3486984670162201}


current loss:1.2539260387420654: 100%|██████████| 79/79 [00:02<00:00, 39.04it/s]
current loss:1.208921194076538:   5%|▌         | 4/79 [00:00<00:01, 38.43it/s] 

{'epoch': 0, 'test loss': 1.7738889455795288, 'accuracy': 0.3563999831676483, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:1.1240736246109009: 100%|██████████| 79/79 [00:02<00:00, 39.03it/s]
current loss:1.0392874479293823: 100%|██████████| 79/79 [00:02<00:00, 38.90it/s]
current loss:0.9807825088500977: 100%|██████████| 79/79 [00:02<00:00, 39.41it/s]
current loss:0.9313694834709167: 100%|██████████| 79/79 [00:02<00:00, 38.65it/s]
current loss:0.888186514377594:   5%|▌         | 4/79 [00:00<00:02, 37.40it/s] 

{'epoch': 4, 'test loss': 1.3984673023223877, 'accuracy': 0.5245000123977661, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.893180251121521: 100%|██████████| 79/79 [00:01<00:00, 39.58it/s] 
current loss:0.8599916696548462: 100%|██████████| 79/79 [00:02<00:00, 39.43it/s]
current loss:0.8312677145004272: 100%|██████████| 79/79 [00:02<00:00, 39.09it/s]
current loss:0.8035972714424133: 100%|██████████| 79/79 [00:02<00:00, 38.93it/s]
current loss:0.7787098288536072: 100%|██████████| 79/79 [00:01<00:00, 39.55it/s]
current loss:0.7336326837539673:   5%|▌         | 4/79 [00:00<00:01, 39.05it/s]

{'epoch': 9, 'test loss': 1.5256586074829102, 'accuracy': 0.5277000069618225, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.7571166157722473: 100%|██████████| 79/79 [00:02<00:00, 38.69it/s]
current loss:0.7364963293075562: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s]
current loss:0.7197533845901489: 100%|██████████| 79/79 [00:02<00:00, 38.89it/s]
current loss:0.7038689851760864: 100%|██████████| 79/79 [00:02<00:00, 39.26it/s]
current loss:0.6917222738265991: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s]
current loss:0.6391732096672058:   5%|▌         | 4/79 [00:00<00:01, 37.76it/s]

{'epoch': 14, 'test loss': 1.5707740783691406, 'accuracy': 0.5547999739646912, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.6832338571548462: 100%|██████████| 79/79 [00:02<00:00, 39.19it/s]
current loss:0.6662374138832092: 100%|██████████| 79/79 [00:02<00:00, 38.88it/s]
current loss:0.6513720750808716: 100%|██████████| 79/79 [00:02<00:00, 39.12it/s]
current loss:0.6379522085189819: 100%|██████████| 79/79 [00:02<00:00, 38.86it/s]
current loss:0.6266793012619019: 100%|██████████| 79/79 [00:02<00:00, 38.74it/s]
current loss:0.5722004771232605:   5%|▌         | 4/79 [00:00<00:01, 37.60it/s]

{'epoch': 19, 'test loss': 1.7338749170303345, 'accuracy': 0.5512999892234802, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.6160462498664856: 100%|██████████| 79/79 [00:02<00:00, 39.16it/s]
current loss:0.6061490774154663: 100%|██████████| 79/79 [00:02<00:00, 38.95it/s]
current loss:0.5956818461418152: 100%|██████████| 79/79 [00:02<00:00, 38.86it/s]
current loss:0.5889562368392944: 100%|██████████| 79/79 [00:01<00:00, 39.78it/s]
current loss:0.583314061164856: 100%|██████████| 79/79 [00:02<00:00, 39.42it/s] 
current loss:0.5280445218086243:   5%|▌         | 4/79 [00:00<00:01, 38.60it/s]

{'epoch': 24, 'test loss': 1.747589349746704, 'accuracy': 0.5561999678611755, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.5726593136787415: 100%|██████████| 79/79 [00:02<00:00, 38.99it/s]
current loss:0.5578781962394714: 100%|██████████| 79/79 [00:02<00:00, 39.04it/s]
current loss:0.5443860292434692: 100%|██████████| 79/79 [00:02<00:00, 38.87it/s]
current loss:0.5324739813804626: 100%|██████████| 79/79 [00:02<00:00, 38.85it/s] 
current loss:0.5245113968849182: 100%|██████████| 79/79 [00:02<00:00, 38.99it/s] 
current loss:0.47833213210105896:   5%|▌         | 4/79 [00:00<00:01, 39.19it/s]

{'epoch': 29, 'test loss': 1.8418891429901123, 'accuracy': 0.5568000078201294, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.5207871198654175: 100%|██████████| 79/79 [00:02<00:00, 38.97it/s] 
current loss:0.5180486440658569: 100%|██████████| 79/79 [00:02<00:00, 38.96it/s] 
current loss:0.5155802965164185: 100%|██████████| 79/79 [00:02<00:00, 38.83it/s] 
current loss:0.5069218873977661: 100%|██████████| 79/79 [00:02<00:00, 38.94it/s] 
current loss:0.5011433362960815: 100%|██████████| 79/79 [00:02<00:00, 39.17it/s] 
current loss:0.45226845145225525:   5%|▌         | 4/79 [00:00<00:01, 39.45it/s]

{'epoch': 34, 'test loss': 1.9713488817214966, 'accuracy': 0.5449000000953674, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.5006364583969116: 100%|██████████| 79/79 [00:02<00:00, 39.03it/s] 
current loss:0.49743857979774475: 100%|██████████| 79/79 [00:02<00:00, 39.12it/s]
current loss:0.48863863945007324: 100%|██████████| 79/79 [00:02<00:00, 39.27it/s]
current loss:0.4814920425415039: 100%|██████████| 79/79 [00:02<00:00, 38.85it/s] 
current loss:0.4800458550453186: 100%|██████████| 79/79 [00:02<00:00, 39.22it/s] 
current loss:0.4525145888328552:   5%|▌         | 4/79 [00:00<00:01, 39.50it/s] 

{'epoch': 39, 'test loss': 2.0824134349823, 'accuracy': 0.5376999974250793, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.4918295741081238: 100%|██████████| 79/79 [00:02<00:00, 38.84it/s] 
current loss:0.5003794431686401: 100%|██████████| 79/79 [00:02<00:00, 39.49it/s] 
current loss:0.5088450312614441: 100%|██████████| 79/79 [00:02<00:00, 39.13it/s] 
current loss:0.48550671339035034: 100%|██████████| 79/79 [00:02<00:00, 39.16it/s]
current loss:0.46913281083106995: 100%|██████████| 79/79 [00:02<00:00, 38.86it/s]
current loss:0.4307404160499573:   6%|▋         | 5/79 [00:00<00:01, 41.20it/s]

{'epoch': 44, 'test loss': 2.1418442726135254, 'accuracy': 0.5365999937057495, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.45080462098121643: 100%|██████████| 79/79 [00:02<00:00, 39.02it/s]
current loss:0.43551015853881836: 100%|██████████| 79/79 [00:02<00:00, 38.68it/s]
current loss:0.4239957332611084: 100%|██████████| 79/79 [00:02<00:00, 39.26it/s] 
current loss:0.4151502251625061: 100%|██████████| 79/79 [00:02<00:00, 39.30it/s] 
current loss:0.4086225628852844: 100%|██████████| 79/79 [00:02<00:00, 38.89it/s] 
current loss:0.38482168316841125:   6%|▋         | 5/79 [00:00<00:01, 39.81it/s]

{'epoch': 49, 'test loss': 2.179382801055908, 'accuracy': 0.5462999939918518, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.40263137221336365: 100%|██████████| 79/79 [00:02<00:00, 38.79it/s]
current loss:0.39765793085098267: 100%|██████████| 79/79 [00:02<00:00, 38.68it/s]
current loss:0.392251193523407: 100%|██████████| 79/79 [00:02<00:00, 39.13it/s]  
current loss:0.3880766034126282: 100%|██████████| 79/79 [00:02<00:00, 39.39it/s] 
current loss:0.3831044137477875: 100%|██████████| 79/79 [00:02<00:00, 38.86it/s] 
current loss:0.365244060754776:   5%|▌         | 4/79 [00:00<00:01, 39.45it/s]  

{'epoch': 54, 'test loss': 2.316977024078369, 'accuracy': 0.5447999835014343, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.3799169361591339: 100%|██████████| 79/79 [00:02<00:00, 39.23it/s] 
current loss:0.37605080008506775: 100%|██████████| 79/79 [00:01<00:00, 39.60it/s]
current loss:0.37034159898757935: 100%|██████████| 79/79 [00:02<00:00, 38.76it/s]
current loss:0.3669285774230957: 100%|██████████| 79/79 [00:02<00:00, 38.78it/s] 
current loss:0.363444983959198: 100%|██████████| 79/79 [00:02<00:00, 39.13it/s]  
current loss:0.35312363505363464:   5%|▌         | 4/79 [00:00<00:01, 38.08it/s]

{'epoch': 59, 'test loss': 2.409700393676758, 'accuracy': 0.5400999784469604, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.3622265160083771: 100%|██████████| 79/79 [00:02<00:00, 39.05it/s] 
current loss:0.3620818257331848: 100%|██████████| 79/79 [00:02<00:00, 38.87it/s] 
current loss:0.3624568283557892: 100%|██████████| 79/79 [00:02<00:00, 38.79it/s] 
current loss:0.3587300777435303: 100%|██████████| 79/79 [00:02<00:00, 39.44it/s] 
current loss:0.3611127734184265: 100%|██████████| 79/79 [00:02<00:00, 38.86it/s] 
current loss:0.3472211956977844:   5%|▌         | 4/79 [00:00<00:01, 39.47it/s] 

{'epoch': 64, 'test loss': 2.528999090194702, 'accuracy': 0.5327999591827393, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.35524457693099976: 100%|██████████| 79/79 [00:02<00:00, 38.51it/s]
current loss:0.3571789860725403: 100%|██████████| 79/79 [00:02<00:00, 38.78it/s] 
current loss:0.35575607419013977: 100%|██████████| 79/79 [00:02<00:00, 39.01it/s]
current loss:0.3506823480129242: 100%|██████████| 79/79 [00:02<00:00, 39.01it/s] 
current loss:0.34805774688720703: 100%|██████████| 79/79 [00:02<00:00, 39.27it/s]
current loss:0.3254656195640564:   5%|▌         | 4/79 [00:00<00:01, 39.81it/s] 

{'epoch': 69, 'test loss': 2.524679660797119, 'accuracy': 0.5376999974250793, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.3466694951057434: 100%|██████████| 79/79 [00:01<00:00, 39.82it/s] 
current loss:0.3461199104785919: 100%|██████████| 79/79 [00:02<00:00, 39.23it/s] 
current loss:0.3455655872821808: 100%|██████████| 79/79 [00:01<00:00, 39.50it/s] 
current loss:0.34822240471839905: 100%|██████████| 79/79 [00:02<00:00, 39.03it/s]
current loss:0.35596033930778503: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s]
current loss:0.33101770281791687:   6%|▋         | 5/79 [00:00<00:01, 40.59it/s]

{'epoch': 74, 'test loss': 2.554171323776245, 'accuracy': 0.5422999858856201, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.37831172347068787: 100%|██████████| 79/79 [00:02<00:00, 39.30it/s]
current loss:0.39090341329574585: 100%|██████████| 79/79 [00:02<00:00, 38.43it/s]
current loss:0.3952457904815674: 100%|██████████| 79/79 [00:02<00:00, 39.14it/s] 
current loss:0.3869093954563141: 100%|██████████| 79/79 [00:02<00:00, 39.06it/s] 
current loss:0.3790518045425415: 100%|██████████| 79/79 [00:02<00:00, 39.20it/s] 
current loss:0.33927369117736816:   6%|▋         | 5/79 [00:00<00:01, 40.21it/s]

{'epoch': 79, 'test loss': 2.58917498588562, 'accuracy': 0.5345999598503113, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.36041751503944397: 100%|██████████| 79/79 [00:02<00:00, 38.65it/s]
current loss:0.3481639623641968: 100%|██████████| 79/79 [00:02<00:00, 39.01it/s] 
current loss:0.337345689535141: 100%|██████████| 79/79 [00:02<00:00, 39.19it/s]  
current loss:0.33003395795822144: 100%|██████████| 79/79 [00:02<00:00, 38.69it/s]
current loss:0.3248136639595032: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s] 
current loss:0.3103242814540863:   5%|▌         | 4/79 [00:00<00:01, 38.37it/s]

{'epoch': 84, 'test loss': 2.6618869304656982, 'accuracy': 0.5378999710083008, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.3201521039009094: 100%|██████████| 79/79 [00:02<00:00, 38.81it/s] 
current loss:0.3156408667564392: 100%|██████████| 79/79 [00:02<00:00, 39.06it/s] 
current loss:0.31512850522994995: 100%|██████████| 79/79 [00:02<00:00, 38.84it/s]
current loss:0.3144482970237732: 100%|██████████| 79/79 [00:02<00:00, 38.99it/s] 
current loss:0.31188908219337463: 100%|██████████| 79/79 [00:02<00:00, 39.27it/s]
current loss:0.29393720626831055:   5%|▌         | 4/79 [00:00<00:01, 39.07it/s]

{'epoch': 89, 'test loss': 2.707547426223755, 'accuracy': 0.5383999943733215, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.3135595917701721: 100%|██████████| 79/79 [00:02<00:00, 38.97it/s] 
current loss:0.320098340511322: 100%|██████████| 79/79 [00:02<00:00, 38.88it/s]  
current loss:0.32597988843917847: 100%|██████████| 79/79 [00:02<00:00, 39.25it/s]
current loss:0.3356485068798065: 100%|██████████| 79/79 [00:02<00:00, 38.79it/s] 
current loss:0.34437888860702515: 100%|██████████| 79/79 [00:02<00:00, 38.66it/s]
current loss:0.30845779180526733:   5%|▌         | 4/79 [00:00<00:02, 37.34it/s]

{'epoch': 94, 'test loss': 2.7661001682281494, 'accuracy': 0.5327999591827393, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:0.3458069860935211: 100%|██████████| 79/79 [00:02<00:00, 39.18it/s] 
current loss:0.3288576602935791: 100%|██████████| 79/79 [00:02<00:00, 38.97it/s] 
current loss:0.31279075145721436: 100%|██████████| 79/79 [00:02<00:00, 39.10it/s]
current loss:0.31260743737220764: 100%|██████████| 79/79 [00:02<00:00, 38.94it/s]
current loss:0.3201320767402649: 100%|██████████| 79/79 [00:02<00:00, 39.19it/s] 
current loss:1.4003757238388062:   6%|▋         | 5/79 [00:00<00:01, 39.93it/s]

{'epoch': 99, 'test loss': 2.6767613887786865, 'accuracy': 0.5393999814987183, 'temp': 4.141729772090912, 'beta1': 0.684798538684845, 'beta2': 0.39808404445648193}


current loss:1.0347975492477417: 100%|██████████| 79/79 [00:01<00:00, 39.85it/s]
current loss:0.9772519469261169:   5%|▌         | 4/79 [00:00<00:01, 39.33it/s]

{'epoch': 0, 'test loss': 1.603238821029663, 'accuracy': 0.40789997577667236, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.8846572637557983: 100%|██████████| 79/79 [00:02<00:00, 39.24it/s]
current loss:0.7978424429893494: 100%|██████████| 79/79 [00:02<00:00, 38.62it/s]
current loss:0.7390905618667603: 100%|██████████| 79/79 [00:02<00:00, 39.00it/s]
current loss:0.6981330513954163: 100%|██████████| 79/79 [00:02<00:00, 38.84it/s]
current loss:0.6954877972602844:   5%|▌         | 4/79 [00:00<00:01, 38.97it/s]

{'epoch': 4, 'test loss': 1.4145580530166626, 'accuracy': 0.5271999835968018, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.6666755676269531: 100%|██████████| 79/79 [00:02<00:00, 38.71it/s]
current loss:0.6389302015304565: 100%|██████████| 79/79 [00:02<00:00, 38.66it/s]
current loss:0.6167778372764587: 100%|██████████| 79/79 [00:02<00:00, 38.96it/s]
current loss:0.5972195863723755: 100%|██████████| 79/79 [00:02<00:00, 39.19it/s]
current loss:0.5797063112258911: 100%|██████████| 79/79 [00:02<00:00, 38.83it/s]
current loss:0.5601258873939514:   6%|▋         | 5/79 [00:00<00:01, 40.38it/s]

{'epoch': 9, 'test loss': 1.4362932443618774, 'accuracy': 0.552899956703186, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.5657404065132141: 100%|██████████| 79/79 [00:02<00:00, 39.18it/s]
current loss:0.5532761812210083: 100%|██████████| 79/79 [00:02<00:00, 39.17it/s]
current loss:0.5396425724029541: 100%|██████████| 79/79 [00:02<00:00, 39.04it/s]
current loss:0.5270166397094727: 100%|██████████| 79/79 [00:02<00:00, 38.95it/s] 
current loss:0.5150847434997559: 100%|██████████| 79/79 [00:02<00:00, 38.85it/s] 
current loss:0.48820480704307556:   6%|▋         | 5/79 [00:00<00:01, 39.88it/s]

{'epoch': 14, 'test loss': 1.5456156730651855, 'accuracy': 0.5583999752998352, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.5008901357650757: 100%|██████████| 79/79 [00:02<00:00, 39.22it/s] 
current loss:0.4904411733150482: 100%|██████████| 79/79 [00:02<00:00, 39.08it/s] 
current loss:0.4763355851173401: 100%|██████████| 79/79 [00:02<00:00, 39.10it/s] 
current loss:0.46558699011802673: 100%|██████████| 79/79 [00:02<00:00, 38.78it/s]
current loss:0.4545329511165619: 100%|██████████| 79/79 [00:02<00:00, 38.58it/s] 
current loss:0.433424711227417:   5%|▌         | 4/79 [00:00<00:01, 38.48it/s]  

{'epoch': 19, 'test loss': 1.6460036039352417, 'accuracy': 0.5644999742507935, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.443354994058609: 100%|██████████| 79/79 [00:02<00:00, 39.05it/s]  
current loss:0.44033393263816833: 100%|██████████| 79/79 [00:02<00:00, 38.57it/s]
current loss:0.4379383623600006: 100%|██████████| 79/79 [00:02<00:00, 38.74it/s] 
current loss:0.4354320466518402: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s] 
current loss:0.4321622848510742: 100%|██████████| 79/79 [00:02<00:00, 38.86it/s] 
current loss:0.4054502844810486:   5%|▌         | 4/79 [00:00<00:02, 37.20it/s] 

{'epoch': 24, 'test loss': 1.8328330516815186, 'accuracy': 0.5511999726295471, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.43203991651535034: 100%|██████████| 79/79 [00:02<00:00, 38.72it/s]
current loss:0.42661523818969727: 100%|██████████| 79/79 [00:02<00:00, 39.38it/s]
current loss:0.4127601981163025: 100%|██████████| 79/79 [00:01<00:00, 39.77it/s] 
current loss:0.4024673402309418: 100%|██████████| 79/79 [00:02<00:00, 38.72it/s] 
current loss:0.3956065773963928: 100%|██████████| 79/79 [00:02<00:00, 38.90it/s] 
current loss:0.3778845965862274:   5%|▌         | 4/79 [00:00<00:01, 38.50it/s] 

{'epoch': 29, 'test loss': 1.8964576721191406, 'accuracy': 0.5530999898910522, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.38937872648239136: 100%|██████████| 79/79 [00:02<00:00, 39.25it/s]
current loss:0.38365381956100464: 100%|██████████| 79/79 [00:02<00:00, 38.64it/s]
current loss:0.38064518570899963: 100%|██████████| 79/79 [00:02<00:00, 39.40it/s]
current loss:0.377204030752182: 100%|██████████| 79/79 [00:02<00:00, 39.03it/s]  
current loss:0.37648430466651917: 100%|██████████| 79/79 [00:02<00:00, 39.24it/s]
current loss:0.3550637364387512:   6%|▋         | 5/79 [00:00<00:01, 40.38it/s] 

{'epoch': 34, 'test loss': 2.035360097885132, 'accuracy': 0.5428000092506409, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.377963662147522: 100%|██████████| 79/79 [00:02<00:00, 39.44it/s]  
current loss:0.38006290793418884: 100%|██████████| 79/79 [00:02<00:00, 39.36it/s]
current loss:0.3776240348815918: 100%|██████████| 79/79 [00:02<00:00, 39.11it/s] 
current loss:0.3757201135158539: 100%|██████████| 79/79 [00:02<00:00, 39.41it/s] 
current loss:0.3781293034553528: 100%|██████████| 79/79 [00:02<00:00, 38.98it/s] 
current loss:0.3480686545372009:   6%|▋         | 5/79 [00:00<00:01, 40.81it/s] 

{'epoch': 39, 'test loss': 2.1344892978668213, 'accuracy': 0.5443999767303467, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.3720863461494446: 100%|██████████| 79/79 [00:02<00:00, 38.93it/s] 
current loss:0.3685370981693268: 100%|██████████| 79/79 [00:02<00:00, 38.54it/s] 
current loss:0.3633870482444763: 100%|██████████| 79/79 [00:02<00:00, 39.30it/s] 
current loss:0.3562277555465698: 100%|██████████| 79/79 [00:02<00:00, 39.35it/s] 
current loss:0.349719375371933: 100%|██████████| 79/79 [00:02<00:00, 38.44it/s]  
current loss:0.3262895941734314:   5%|▌         | 4/79 [00:00<00:01, 38.58it/s] 

{'epoch': 44, 'test loss': 2.140925407409668, 'accuracy': 0.5489999651908875, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.3418233394622803: 100%|██████████| 79/79 [00:02<00:00, 38.67it/s] 
current loss:0.33046576380729675: 100%|██████████| 79/79 [00:02<00:00, 39.25it/s]
current loss:0.32151126861572266: 100%|██████████| 79/79 [00:02<00:00, 39.27it/s]
current loss:0.31493377685546875: 100%|██████████| 79/79 [00:02<00:00, 38.59it/s]
current loss:0.3109075427055359: 100%|██████████| 79/79 [00:02<00:00, 39.43it/s] 
current loss:0.2828836441040039:   6%|▋         | 5/79 [00:00<00:01, 39.90it/s] 

{'epoch': 49, 'test loss': 2.177776336669922, 'accuracy': 0.5511999726295471, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.3070005774497986: 100%|██████████| 79/79 [00:01<00:00, 39.63it/s] 
current loss:0.3056342303752899: 100%|██████████| 79/79 [00:02<00:00, 38.41it/s] 
current loss:0.3025677800178528: 100%|██████████| 79/79 [00:02<00:00, 39.00it/s] 
current loss:0.3020032048225403: 100%|██████████| 79/79 [00:02<00:00, 39.09it/s] 
current loss:0.3014821410179138: 100%|██████████| 79/79 [00:02<00:00, 38.93it/s] 
current loss:0.27229824662208557:   5%|▌         | 4/79 [00:00<00:01, 38.40it/s]

{'epoch': 54, 'test loss': 2.2675278186798096, 'accuracy': 0.5489999651908875, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.300153523683548: 100%|██████████| 79/79 [00:02<00:00, 39.01it/s]  
current loss:0.2998940348625183: 100%|██████████| 79/79 [00:02<00:00, 38.76it/s] 
current loss:0.2976965308189392: 100%|██████████| 79/79 [00:02<00:00, 38.68it/s] 
current loss:0.29423609375953674: 100%|██████████| 79/79 [00:02<00:00, 39.00it/s]
current loss:0.2926936149597168: 100%|██████████| 79/79 [00:02<00:00, 38.78it/s] 
current loss:0.2657621204853058:   5%|▌         | 4/79 [00:00<00:01, 38.32it/s] 

{'epoch': 59, 'test loss': 2.469849109649658, 'accuracy': 0.5410999655723572, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.28968939185142517: 100%|██████████| 79/79 [00:02<00:00, 39.16it/s]
current loss:0.2921612560749054: 100%|██████████| 79/79 [00:02<00:00, 38.88it/s] 
current loss:0.29128116369247437: 100%|██████████| 79/79 [00:01<00:00, 39.61it/s]
current loss:0.2920967638492584: 100%|██████████| 79/79 [00:02<00:00, 38.89it/s] 
current loss:0.2975485920906067: 100%|██████████| 79/79 [00:02<00:00, 38.77it/s] 
current loss:0.2676340937614441:   5%|▌         | 4/79 [00:00<00:02, 37.03it/s] 

{'epoch': 64, 'test loss': 2.622582197189331, 'accuracy': 0.5327000021934509, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.3032524585723877: 100%|██████████| 79/79 [00:02<00:00, 39.02it/s] 
current loss:0.3111141324043274: 100%|██████████| 79/79 [00:02<00:00, 39.22it/s] 
current loss:0.32421135902404785: 100%|██████████| 79/79 [00:02<00:00, 39.44it/s]
current loss:0.3185969591140747: 100%|██████████| 79/79 [00:02<00:00, 38.82it/s] 
current loss:0.3016517162322998: 100%|██████████| 79/79 [00:02<00:00, 38.86it/s] 
current loss:0.28751322627067566:   5%|▌         | 4/79 [00:00<00:01, 39.71it/s]

{'epoch': 69, 'test loss': 2.8601434230804443, 'accuracy': 0.5156999826431274, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.2872140109539032: 100%|██████████| 79/79 [00:02<00:00, 38.88it/s] 
current loss:0.27746352553367615: 100%|██████████| 79/79 [00:02<00:00, 38.97it/s]
current loss:0.27236929535865784: 100%|██████████| 79/79 [00:02<00:00, 39.13it/s]
current loss:0.2665298879146576: 100%|██████████| 79/79 [00:02<00:00, 39.23it/s] 
current loss:0.26403898000717163: 100%|██████████| 79/79 [00:01<00:00, 39.50it/s]
current loss:0.24909044802188873:   5%|▌         | 4/79 [00:00<00:01, 38.98it/s]

{'epoch': 74, 'test loss': 2.665229320526123, 'accuracy': 0.5372999906539917, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.26138216257095337: 100%|██████████| 79/79 [00:02<00:00, 39.20it/s]
current loss:0.26029476523399353: 100%|██████████| 79/79 [00:02<00:00, 38.89it/s]
current loss:0.2593821883201599: 100%|██████████| 79/79 [00:02<00:00, 39.31it/s] 
current loss:0.26192548871040344: 100%|██████████| 79/79 [00:02<00:00, 39.02it/s]
current loss:0.25982439517974854: 100%|██████████| 79/79 [00:02<00:00, 38.87it/s]
current loss:0.23744256794452667:   5%|▌         | 4/79 [00:00<00:01, 39.39it/s]

{'epoch': 79, 'test loss': 2.843296527862549, 'accuracy': 0.5331999659538269, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.2591767907142639: 100%|██████████| 79/79 [00:02<00:00, 39.06it/s] 
current loss:0.2592661380767822: 100%|██████████| 79/79 [00:02<00:00, 38.88it/s] 
current loss:0.26049619913101196: 100%|██████████| 79/79 [00:02<00:00, 38.88it/s]
current loss:0.2561143636703491: 100%|██████████| 79/79 [00:02<00:00, 39.28it/s] 
current loss:0.2514943480491638: 100%|██████████| 79/79 [00:02<00:00, 39.11it/s] 
current loss:0.23177914321422577:   5%|▌         | 4/79 [00:00<00:01, 39.21it/s]

{'epoch': 84, 'test loss': 2.9388058185577393, 'accuracy': 0.5340999960899353, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.2486763447523117: 100%|██████████| 79/79 [00:02<00:00, 38.46it/s] 
current loss:0.24514582753181458: 100%|██████████| 79/79 [00:02<00:00, 38.83it/s]
current loss:0.24460415542125702: 100%|██████████| 79/79 [00:02<00:00, 39.01it/s]
current loss:0.24602976441383362: 100%|██████████| 79/79 [00:02<00:00, 38.90it/s]
current loss:0.24352219700813293: 100%|██████████| 79/79 [00:02<00:00, 39.16it/s]
current loss:0.22276847064495087:   5%|▌         | 4/79 [00:00<00:01, 39.31it/s]

{'epoch': 89, 'test loss': 3.125361680984497, 'accuracy': 0.5249999761581421, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.24225327372550964: 100%|██████████| 79/79 [00:02<00:00, 39.02it/s]
current loss:0.240276500582695: 100%|██████████| 79/79 [00:02<00:00, 38.41it/s]  
current loss:0.2342139035463333: 100%|██████████| 79/79 [00:02<00:00, 38.64it/s] 
current loss:0.23635125160217285: 100%|██████████| 79/79 [00:02<00:00, 38.66it/s]
current loss:0.24128273129463196: 100%|██████████| 79/79 [00:02<00:00, 39.02it/s]
current loss:0.2249315083026886:   6%|▋         | 5/79 [00:00<00:01, 40.34it/s] 

{'epoch': 94, 'test loss': 3.243072032928467, 'accuracy': 0.5185999870300293, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


current loss:0.23840448260307312: 100%|██████████| 79/79 [00:02<00:00, 39.11it/s]
current loss:0.23770686984062195: 100%|██████████| 79/79 [00:02<00:00, 39.27it/s]
current loss:0.22614264488220215: 100%|██████████| 79/79 [00:02<00:00, 38.62it/s]
current loss:0.22679536044597626: 100%|██████████| 79/79 [00:02<00:00, 39.06it/s]
current loss:0.22590534389019012: 100%|██████████| 79/79 [00:02<00:00, 38.78it/s]


{'epoch': 99, 'test loss': 3.188584089279175, 'accuracy': 0.527899980545044, 'temp': 2.6265543699264526, 'beta1': 0.2749079763889313, 'beta2': 0.3148784935474396}


In [7]:
# Запуск --- с CNN-дистилляцией и оптимизацией гиперпараметров, 2-beta
crit = nn.CrossEntropyLoss()
# определяем функцию потерь как замкнутую относительно аргументов функцию
# нужно для подсчета градиентов гиперпараметров по двухуровневой оптимизации
def param_loss(batch,model,h):
    x,y,batch_logits = batch    
    beta,beta2,temp = h
    out = model(x)
    beta = F.sigmoid(beta)
    beta2 = F.sigmoid(beta2)
    temp = F.sigmoid(temp) * 9.9+0.1
    distillation_loss = distill(out, batch_logits, temp)
    student_loss = crit(out, y)                
    loss = beta * distillation_loss + beta2 * student_loss
    return loss
# определяем функцию валидационную функцию потерь как замкнутую относительно аргументов функцию
# нужно для подсчета градиентов гиперпараметров по двухуровневой оптимизации
def hyperparam_loss(batch, model):
    x,y = batch
    out = model(x)
    student_loss = crit(out, y)            
    return student_loss

hist = []
logits = np.load('../code/logits_cnn.npy')
for _ in range(run_num):
    internal_results = []    
    # теперь beta и temp - не числа, а тензоры, по которым можно считать градиент
    beta1 = t.nn.Parameter(t.tensor(np.random.uniform(low=-1, high=1), device=device), requires_grad=True)
    beta2 = t.nn.Parameter(t.tensor(np.random.uniform(low=-1, high=1), device=device), requires_grad=True)
    temp = t.nn.Parameter(t.tensor(np.random.uniform(low=-2, high=0), device=device), requires_grad=True)    
    h = [beta1, beta2, temp]
    
    student = Cifar_Very_Tiny(10).to(device)
    optim = t.optim.Adam(student.parameters())   
    
    # параметры Adam и функцию подсчета градиента 
    # взял из статьи по DARTS (выбор архитектуры сети градиентными методами)
    # там также используется оптимизация гиперпараметров
    optim2 = t.optim.SGD(h,  lr=10e4)   
    hyper_grad_calc = hyperparams.AdamHyperGradCalculator(student, param_loss, hyperparam_loss, optim, h)
    
    crit = t.nn.CrossEntropyLoss()

    for e in range(epoch_num): # хочется посмотреть куда сойдутся гиперпараметры, поэтому возьмем побольше эпох
        
        
        tq = tqdm.tqdm(zip(train_loader_no_augumentation, valid_loader))
        losses = []
        for batch_id, ((x,y), (v_x, v_y)) in enumerate(tq):
       
            x = x.to(device)
            y = y.to(device)            
                      
            batch_logits = t.Tensor(logits[128*batch_id:128*(batch_id+1)]).to(device) 
            # если настала пора понаблюдать за траекторий гиперпараметров
          
            #print (batch_id, 'train mini')
            v_x = v_x.to(device)
            v_y = v_y.to(device)  
            optim2.zero_grad()            
            hyper_grad_calc.calc_gradients((x,y,batch_logits), (v_x, v_y))    
            t.nn.utils.clip_grad_value_(h, 1.0)
            for h_ in h:
                h_.grad = t.where(t.isnan(h_.grad), t.zeros_like(h_.grad), h_.grad)  

            optim2.step()                         
            optim.zero_grad()
            loss = param_loss((x,y,batch_logits), student,h)
            losses.append(loss.cpu().detach().numpy())
            loss.backward()
            optim.step()
            tq.set_description('current loss:{}'.format(np.mean(losses[-10:])))
    
        if e==0 or (e+1)%validate_every_epoch == 0: # если номер эпохи делится на 5 или эпоха - первая             
            test_loss = []
            student.eval()
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)                            
                test_loss.append(crit(student(x), y).detach().cpu().numpy())                 
            test_loss = float(np.mean(test_loss))
            test_loss2 = []            
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)                            
                test_loss2.append(crit(student(x), y).detach().cpu().numpy())                 
            print (float(np.mean(test_loss2)))
            
            
            acc = float(accuracy(student))
            student.train()
            internal_results.append({'epoch': e, 'test loss':test_loss, 'accuracy':acc, 
                                     'temp':float(0.1+9.9*F.sigmoid(h[2]).cpu().detach().numpy()),
                                     'beta1':float(F.sigmoid(h[0]).cpu().detach().numpy()),
                                     'beta2':float(F.sigmoid(h[1]).cpu().detach().numpy())})
            
            print (internal_results[-1])

            
    with open('exp'+experiment_version+'_dist_h_b2_optim.jsonl', 'a') as out:
        out.write(json.dumps({'results':internal_results, 'version': experiment_version})+'\n')

  f = F.log_softmax(batch_logits/temp)
current loss:1.3017537593841553: : 79it [00:04, 16.93it/s]


1.6263879537582397


current loss:1.2302647829055786: : 2it [00:00, 16.78it/s]

{'epoch': 0, 'test loss': 1.6265851259231567, 'accuracy': 0.4016999900341034, 'temp': 4.4052257329225535, 'beta1': 0.3685993254184723, 'beta2': 0.605717658996582}


current loss:1.166755199432373: : 79it [00:04, 17.41it/s] 
current loss:1.0916316509246826: : 79it [00:04, 17.44it/s]
current loss:1.0362969636917114: : 79it [00:04, 17.29it/s]
current loss:0.9906137585639954: : 79it [00:04, 17.74it/s]


1.346238374710083


current loss:0.9420437812805176: : 2it [00:00, 16.90it/s]

{'epoch': 4, 'test loss': 1.3413652181625366, 'accuracy': 0.5242999792098999, 'temp': 4.329772618412972, 'beta1': 0.37247970700263977, 'beta2': 0.6241470575332642}


current loss:0.9590458869934082: : 79it [00:04, 17.26it/s]
current loss:0.9331226348876953: : 79it [00:04, 17.34it/s]
current loss:0.9084598422050476: : 79it [00:04, 17.33it/s]
current loss:0.889809250831604: : 79it [00:04, 17.77it/s] 
current loss:0.8741071820259094: : 79it [00:04, 17.84it/s]


1.388424038887024


current loss:0.8140456080436707: : 2it [00:00, 16.85it/s]

{'epoch': 9, 'test loss': 1.3779600858688354, 'accuracy': 0.5389999747276306, 'temp': 4.206782177090645, 'beta1': 0.3790625035762787, 'beta2': 0.6461889743804932}


current loss:0.8618078231811523: : 79it [00:04, 17.52it/s]
current loss:0.8500486612319946: : 79it [00:04, 17.84it/s]
current loss:0.8383117914199829: : 79it [00:04, 17.62it/s]
current loss:0.8269823789596558: : 79it [00:04, 17.46it/s]
current loss:0.8175268173217773: : 79it [00:04, 17.23it/s]


1.4895967245101929


current loss:0.7428112030029297: : 2it [00:00, 16.72it/s]

{'epoch': 14, 'test loss': 1.475982904434204, 'accuracy': 0.5383999943733215, 'temp': 4.074038794636726, 'beta1': 0.3861335515975952, 'beta2': 0.6673938035964966}


current loss:0.8063219785690308: : 79it [00:04, 17.27it/s]
current loss:0.7999871373176575: : 79it [00:04, 17.54it/s]
current loss:0.7917097806930542: : 79it [00:04, 17.43it/s]
current loss:0.7834007740020752: : 79it [00:04, 17.34it/s]
current loss:0.7768707275390625: : 79it [00:04, 17.25it/s]


1.612917423248291


current loss:0.7017958760261536: : 2it [00:00, 17.26it/s]

{'epoch': 19, 'test loss': 1.6107280254364014, 'accuracy': 0.5328999757766724, 'temp': 3.9294792354106907, 'beta1': 0.3938654363155365, 'beta2': 0.6864962577819824}


current loss:0.7792097330093384: : 79it [00:04, 17.43it/s]
current loss:0.7932737469673157: : 79it [00:04, 17.58it/s]
current loss:0.7910440564155579: : 79it [00:04, 17.61it/s]
current loss:0.7907085418701172: : 79it [00:04, 17.63it/s]
current loss:0.7785986661911011: : 79it [00:04, 17.77it/s]


1.6780295372009277


current loss:0.6722232699394226: : 2it [00:00, 17.37it/s]

{'epoch': 24, 'test loss': 1.6730891466140747, 'accuracy': 0.5382999777793884, 'temp': 3.74751678109169, 'beta1': 0.40325868129730225, 'beta2': 0.7098405957221985}


current loss:0.7626968622207642: : 79it [00:04, 17.66it/s]
current loss:0.7482647895812988: : 79it [00:04, 17.29it/s]
current loss:0.7342264652252197: : 79it [00:04, 17.47it/s]
current loss:0.7257925271987915: : 79it [00:04, 17.38it/s]
current loss:0.7144201397895813: : 79it [00:04, 17.70it/s]


1.813779592514038


current loss:0.655958354473114: : 2it [00:00, 17.85it/s]

{'epoch': 29, 'test loss': 1.8175066709518433, 'accuracy': 0.5386999845504761, 'temp': 3.605783733725548, 'beta1': 0.41076087951660156, 'beta2': 0.7245112061500549}


current loss:0.7063024640083313: : 79it [00:04, 17.67it/s]
current loss:0.6972044706344604: : 79it [00:04, 17.25it/s]
current loss:0.6903601288795471: : 79it [00:04, 17.26it/s]
current loss:0.6889689564704895: : 79it [00:04, 17.73it/s]
current loss:0.6914924383163452: : 79it [00:04, 17.49it/s]


1.9631751775741577


current loss:0.6544389724731445: : 2it [00:00, 17.14it/s]

{'epoch': 34, 'test loss': 1.9444810152053833, 'accuracy': 0.5339999794960022, 'temp': 3.4654952168464663, 'beta1': 0.4179129898548126, 'beta2': 0.7376765012741089}


current loss:0.6949753761291504: : 79it [00:04, 17.18it/s]
current loss:0.6944025754928589: : 79it [00:04, 17.24it/s]
current loss:0.6915723085403442: : 79it [00:04, 17.11it/s]
current loss:0.693548321723938: : 79it [00:04, 17.60it/s] 
current loss:0.6872674822807312: : 79it [00:04, 17.41it/s]


2.055521011352539


current loss:0.624057948589325: : 2it [00:00, 17.26it/s]

{'epoch': 39, 'test loss': 2.055975914001465, 'accuracy': 0.5302000045776367, 'temp': 3.3209330022335055, 'beta1': 0.4252382218837738, 'beta2': 0.7496042251586914}


current loss:0.6881000399589539: : 79it [00:04, 17.35it/s]
current loss:0.6940128207206726: : 79it [00:04, 17.31it/s]
current loss:0.6919234991073608: : 79it [00:04, 17.55it/s]
current loss:0.689541220664978: : 79it [00:04, 17.58it/s] 
current loss:0.6878783106803894: : 79it [00:04, 17.56it/s]


2.1459834575653076


current loss:0.6331682205200195: : 2it [00:00, 17.93it/s]

{'epoch': 44, 'test loss': 2.1393513679504395, 'accuracy': 0.527899980545044, 'temp': 3.1979086309671403, 'beta1': 0.4316053092479706, 'beta2': 0.7574560642242432}


current loss:0.6754481196403503: : 79it [00:04, 17.41it/s]
current loss:0.6695823669433594: : 79it [00:04, 17.67it/s]
current loss:0.6540367603302002: : 79it [00:04, 17.54it/s]
current loss:0.6546670198440552: : 79it [00:04, 17.39it/s]
current loss:0.6587449908256531: : 79it [00:04, 17.13it/s]


2.2269349098205566


current loss:0.58729487657547: : 2it [00:00, 16.99it/s]  

{'epoch': 49, 'test loss': 2.2176873683929443, 'accuracy': 0.5349000096321106, 'temp': 3.0571114599704745, 'beta1': 0.4389081299304962, 'beta2': 0.7654333710670471}


current loss:0.6701923608779907: : 79it [00:04, 17.54it/s]
current loss:0.715788722038269: : 79it [00:04, 17.42it/s] 
current loss:0.7244351506233215: : 79it [00:04, 17.61it/s]
current loss:0.6850741505622864: : 79it [00:04, 17.64it/s]
current loss:0.6541073322296143: : 79it [00:04, 17.42it/s]


2.232384443283081


current loss:0.6118568778038025: : 2it [00:00, 18.05it/s]

{'epoch': 54, 'test loss': 2.2518935203552246, 'accuracy': 0.5306999683380127, 'temp': 2.887581527233124, 'beta1': 0.4481451213359833, 'beta2': 0.774332582950592}


current loss:0.6349664926528931: : 79it [00:04, 17.52it/s]
current loss:0.6166360974311829: : 79it [00:04, 17.55it/s]
current loss:0.6074409484863281: : 79it [00:04, 17.61it/s]
current loss:0.597339391708374: : 79it [00:04, 17.40it/s] 
current loss:0.5874051451683044: : 79it [00:04, 17.37it/s]


2.329512357711792


current loss:0.5192108750343323: : 2it [00:00, 16.88it/s]

{'epoch': 59, 'test loss': 2.3320417404174805, 'accuracy': 0.5393999814987183, 'temp': 2.790560179948807, 'beta1': 0.4541175365447998, 'beta2': 0.7787737250328064}


current loss:0.5823472738265991: : 79it [00:04, 17.59it/s]
current loss:0.5777136087417603: : 79it [00:04, 17.66it/s]
current loss:0.573780357837677: : 79it [00:04, 17.36it/s] 
current loss:0.571344256401062: : 79it [00:04, 17.44it/s] 
current loss:0.5693166851997375: : 79it [00:04, 17.43it/s]


2.588453531265259


current loss:0.48007139563560486: : 2it [00:00, 17.71it/s]

{'epoch': 64, 'test loss': 2.6021358966827393, 'accuracy': 0.5345999598503113, 'temp': 2.703070491552353, 'beta1': 0.4595061242580414, 'beta2': 0.7820281982421875}


current loss:0.5813384056091309: : 79it [00:04, 17.50it/s]
current loss:0.5856188535690308: : 79it [00:04, 17.78it/s]
current loss:0.5882053971290588: : 79it [00:04, 17.81it/s]
current loss:0.5955738425254822: : 79it [00:04, 17.55it/s]
current loss:0.6090722680091858: : 79it [00:04, 17.51it/s]


2.695180654525757


current loss:0.4929533004760742: : 2it [00:00, 17.82it/s]

{'epoch': 69, 'test loss': 2.679417371749878, 'accuracy': 0.5313000082969666, 'temp': 2.6101455211639406, 'beta1': 0.465158611536026, 'beta2': 0.7858591675758362}


current loss:0.6254734992980957: : 79it [00:04, 17.57it/s]
current loss:0.6249877214431763: : 79it [00:04, 17.45it/s]
current loss:0.6401739120483398: : 79it [00:04, 17.55it/s]
current loss:0.63704913854599: : 79it [00:04, 17.50it/s]  
current loss:0.6129701733589172: : 79it [00:04, 17.12it/s]


2.7919790744781494


current loss:0.5548305511474609: : 2it [00:00, 17.68it/s]

{'epoch': 74, 'test loss': 2.7881078720092773, 'accuracy': 0.5270000100135803, 'temp': 2.4848822161555293, 'beta1': 0.47257867455482483, 'beta2': 0.7916404008865356}


current loss:0.5965907573699951: : 79it [00:04, 17.54it/s]
current loss:0.5854742527008057: : 79it [00:04, 17.49it/s]
current loss:0.5795331597328186: : 79it [00:04, 17.05it/s]
current loss:0.5704843997955322: : 79it [00:04, 17.21it/s]
current loss:0.5764371156692505: : 79it [00:04, 17.59it/s]


2.9776740074157715


current loss:0.524960994720459: : 2it [00:00, 18.27it/s] 

{'epoch': 79, 'test loss': 2.980046272277832, 'accuracy': 0.5191999673843384, 'temp': 2.341389679908753, 'beta1': 0.48140376806259155, 'beta2': 0.7972025871276855}


current loss:0.5659851431846619: : 79it [00:04, 17.78it/s]
current loss:0.5777302980422974: : 79it [00:04, 17.97it/s]
current loss:0.5929580926895142: : 79it [00:04, 17.53it/s]
current loss:0.6029939651489258: : 79it [00:04, 17.48it/s]
current loss:0.6075765490531921: : 79it [00:04, 17.70it/s]


3.0794172286987305


current loss:0.5297022461891174: : 2it [00:00, 17.15it/s]

{'epoch': 84, 'test loss': 3.090209484100342, 'accuracy': 0.5155999660491943, 'temp': 2.2143125981092453, 'beta1': 0.4897114336490631, 'beta2': 0.8012228012084961}


current loss:0.6196204423904419: : 79it [00:04, 17.43it/s]
current loss:0.6581814289093018: : 79it [00:04, 17.49it/s]
current loss:0.6264867186546326: : 79it [00:04, 17.66it/s]
current loss:0.6121066808700562: : 79it [00:04, 17.39it/s]
current loss:0.6100121736526489: : 79it [00:04, 17.45it/s]


3.1776177883148193


current loss:0.5146216750144958: : 2it [00:00, 18.19it/s]

{'epoch': 89, 'test loss': 3.198621988296509, 'accuracy': 0.5157999992370605, 'temp': 2.067721372842789, 'beta1': 0.499505877494812, 'beta2': 0.806324303150177}


current loss:0.5953188538551331: : 79it [00:04, 17.27it/s]
current loss:0.5727754831314087: : 79it [00:04, 17.52it/s]
current loss:0.556590735912323: : 79it [00:04, 17.19it/s] 
current loss:0.5573680996894836: : 79it [00:04, 17.30it/s] 
current loss:0.5822465419769287: : 79it [00:04, 17.06it/s]


3.193737745285034


current loss:0.4723166525363922: : 2it [00:00, 17.43it/s] 

{'epoch': 94, 'test loss': 3.209338665008545, 'accuracy': 0.5250999927520752, 'temp': 1.9713625445961953, 'beta1': 0.5069623589515686, 'beta2': 0.8095009326934814}


current loss:0.5877798795700073: : 79it [00:04, 17.49it/s]
current loss:0.6155301928520203: : 79it [00:04, 17.77it/s]
current loss:0.646986722946167: : 79it [00:04, 17.28it/s] 
current loss:0.6520911455154419: : 79it [00:04, 17.46it/s]
current loss:0.6438767910003662: : 79it [00:04, 17.39it/s]


3.821662187576294


current loss:3.2018206119537354: : 2it [00:00, 17.12it/s]

{'epoch': 99, 'test loss': 3.836545467376709, 'accuracy': 0.4976999759674072, 'temp': 1.8488009482622147, 'beta1': 0.5172783732414246, 'beta2': 0.8133730888366699}


current loss:2.207810878753662: : 79it [00:04, 17.48it/s] 


1.6860008239746094


current loss:2.064730405807495: : 2it [00:00, 18.17it/s] 

{'epoch': 0, 'test loss': 1.684362530708313, 'accuracy': 0.3822999894618988, 'temp': 1.3063847661018373, 'beta1': 0.3523814380168915, 'beta2': 0.37267357110977173}


current loss:1.889472246170044: : 79it [00:04, 17.61it/s] 
current loss:1.827779769897461: : 79it [00:04, 17.52it/s] 
current loss:1.8729130029678345: : 79it [00:04, 17.29it/s]
current loss:1.986703872680664: : 79it [00:04, 17.13it/s] 


1.5299526453018188


current loss:1.8080615997314453: : 2it [00:00, 17.59it/s]

{'epoch': 4, 'test loss': 1.5363012552261353, 'accuracy': 0.5148000121116638, 'temp': 0.8071360290050507, 'beta1': 0.41637226939201355, 'beta2': 0.4106115996837616}


current loss:2.091416835784912: : 79it [00:04, 17.09it/s] 
current loss:2.306811571121216: : 79it [00:04, 17.32it/s] 
current loss:3.164769411087036: : 79it [00:04, 17.51it/s] 
current loss:4.781073093414307: : 79it [00:04, 17.38it/s] 
current loss:7.654008388519287: : 79it [00:04, 17.29it/s] 


1.6256177425384521


current loss:7.257728576660156: : 2it [00:00, 17.24it/s]

{'epoch': 9, 'test loss': 1.615188717842102, 'accuracy': 0.5187999606132507, 'temp': 0.22855860330164435, 'beta1': 0.6943212747573853, 'beta2': 0.46838265657424927}


current loss:20.418262481689453: : 79it [00:04, 17.37it/s]
current loss:26.890380859375: : 79it [00:04, 17.50it/s]   
current loss:27.474544525146484: : 79it [00:04, 17.28it/s]
current loss:26.288782119750977: : 79it [00:04, 17.42it/s]
current loss:21.45860481262207: : 79it [00:04, 17.76it/s] 


1.6907296180725098


current loss:19.683996200561523: : 2it [00:00, 17.53it/s]

{'epoch': 14, 'test loss': 1.6806920766830444, 'accuracy': 0.5285999774932861, 'temp': 0.10392555391881615, 'beta1': 0.9666738510131836, 'beta2': 0.5680816769599915}


current loss:22.872547149658203: : 79it [00:04, 17.46it/s]
current loss:22.856340408325195: : 79it [00:04, 17.74it/s]
current loss:21.639982223510742: : 79it [00:04, 17.61it/s]
current loss:20.86452865600586: : 79it [00:04, 17.44it/s] 
current loss:19.35684585571289: : 79it [00:04, 17.21it/s] 


1.5219448804855347


current loss:18.737703323364258: : 2it [00:00, 17.68it/s]

{'epoch': 19, 'test loss': 1.51130211353302, 'accuracy': 0.5561000108718872, 'temp': 0.1017801220092224, 'beta1': 0.9836464524269104, 'beta2': 0.6354319453239441}


current loss:18.641048431396484: : 79it [00:04, 17.28it/s]
current loss:19.100934982299805: : 79it [00:04, 17.45it/s]
current loss:19.963939666748047: : 79it [00:04, 17.41it/s]
current loss:17.685489654541016: : 79it [00:04, 17.45it/s]
current loss:16.597429275512695: : 79it [00:04, 17.51it/s]


1.50128173828125


current loss:17.795312881469727: : 2it [00:00, 17.46it/s]

{'epoch': 24, 'test loss': 1.4922659397125244, 'accuracy': 0.5616999864578247, 'temp': 0.10129461983451621, 'beta1': 0.9882352352142334, 'beta2': 0.6864341497421265}


current loss:15.740602493286133: : 79it [00:04, 17.33it/s]
current loss:15.487452507019043: : 79it [00:04, 17.33it/s]
current loss:16.42426109313965: : 79it [00:04, 17.21it/s] 
current loss:15.354474067687988: : 79it [00:04, 17.52it/s]
current loss:15.08299446105957: : 79it [00:04, 17.43it/s] 


1.4360291957855225


current loss:15.45877456665039: : 2it [00:00, 16.95it/s] 

{'epoch': 29, 'test loss': 1.4349623918533325, 'accuracy': 0.574400007724762, 'temp': 0.10104475357220509, 'beta1': 0.9904458522796631, 'beta2': 0.7230094075202942}


current loss:17.471233367919922: : 79it [00:04, 17.38it/s]
current loss:14.586740493774414: : 79it [00:04, 17.47it/s]
current loss:15.507814407348633: : 79it [00:04, 17.36it/s]
current loss:15.973030090332031: : 79it [00:04, 17.50it/s]
current loss:15.861468315124512: : 79it [00:04, 17.22it/s]


1.4429905414581299


current loss:14.229557991027832: : 2it [00:00, 17.20it/s]

{'epoch': 34, 'test loss': 1.4360575675964355, 'accuracy': 0.5681999921798706, 'temp': 0.10081379642506363, 'beta1': 0.9924569129943848, 'beta2': 0.7560235261917114}


current loss:14.145917892456055: : 79it [00:04, 17.52it/s]
current loss:14.595109939575195: : 79it [00:04, 17.46it/s]
current loss:14.279806137084961: : 79it [00:04, 17.20it/s]
current loss:14.140406608581543: : 79it [00:04, 17.30it/s]
current loss:13.604597091674805: : 79it [00:04, 17.39it/s]


1.4411834478378296


current loss:13.393040657043457: : 2it [00:00, 17.91it/s]

{'epoch': 39, 'test loss': 1.442201018333435, 'accuracy': 0.5661999583244324, 'temp': 0.10066727134762915, 'beta1': 0.9938012957572937, 'beta2': 0.7841594219207764}


current loss:13.544346809387207: : 79it [00:04, 17.49it/s]
current loss:12.885472297668457: : 79it [00:04, 17.41it/s]
current loss:12.197389602661133: : 79it [00:04, 17.39it/s]
current loss:12.227052688598633: : 79it [00:04, 17.46it/s]
current loss:11.607261657714844: : 79it [00:04, 17.29it/s]


1.3651387691497803


current loss:11.18017292022705: : 2it [00:00, 17.26it/s] 

{'epoch': 44, 'test loss': 1.368108868598938, 'accuracy': 0.5704999566078186, 'temp': 0.10062498785482604, 'beta1': 0.9942559599876404, 'beta2': 0.8000763058662415}


current loss:12.161767959594727: : 79it [00:04, 17.50it/s]
current loss:11.842790603637695: : 79it [00:04, 17.33it/s]
current loss:11.797865867614746: : 79it [00:04, 17.31it/s]
current loss:12.151239395141602: : 79it [00:04, 17.57it/s]
current loss:12.128068923950195: : 79it [00:04, 17.76it/s]


1.366507887840271


current loss:12.375667572021484: : 2it [00:00, 17.98it/s]

{'epoch': 49, 'test loss': 1.3625026941299438, 'accuracy': 0.5740999579429626, 'temp': 0.10057742596654862, 'beta1': 0.994662880897522, 'beta2': 0.813188910484314}


current loss:12.542646408081055: : 79it [00:04, 17.65it/s]
current loss:11.780023574829102: : 79it [00:04, 17.48it/s]
current loss:12.454862594604492: : 79it [00:04, 17.34it/s]
current loss:11.620166778564453: : 79it [00:04, 17.49it/s]
current loss:10.649702072143555: : 79it [00:04, 17.21it/s]


1.362784743309021


current loss:10.2969331741333: : 2it [00:00, 17.57it/s]  

{'epoch': 54, 'test loss': 1.3673474788665771, 'accuracy': 0.5745999813079834, 'temp': 0.1005316047670931, 'beta1': 0.9950838685035706, 'beta2': 0.8267570734024048}


current loss:10.66209602355957: : 79it [00:04, 17.38it/s] 
current loss:11.39321517944336: : 79it [00:04, 17.16it/s] 
current loss:10.781686782836914: : 79it [00:04, 17.21it/s]
current loss:11.768160820007324: : 79it [00:04, 17.12it/s]
current loss:10.69367504119873: : 79it [00:04, 17.49it/s] 


1.3274120092391968


current loss:10.143837928771973: : 2it [00:00, 17.70it/s]

{'epoch': 59, 'test loss': 1.319204330444336, 'accuracy': 0.5794000029563904, 'temp': 0.10048247261765937, 'beta1': 0.9955077171325684, 'beta2': 0.8402719497680664}


current loss:10.624434471130371: : 79it [00:04, 17.72it/s]
current loss:10.388486862182617: : 79it [00:04, 17.43it/s]
current loss:10.89540958404541: : 79it [00:04, 17.53it/s] 
current loss:10.103076934814453: : 79it [00:04, 17.36it/s]
current loss:11.046026229858398: : 79it [00:04, 17.32it/s]


1.3560689687728882


current loss:10.125893592834473: : 2it [00:00, 17.70it/s]

{'epoch': 64, 'test loss': 1.351547360420227, 'accuracy': 0.5708999633789062, 'temp': 0.1004597566123266, 'beta1': 0.9957343935966492, 'beta2': 0.8498545289039612}


current loss:11.514570236206055: : 79it [00:04, 17.41it/s]
current loss:11.151843070983887: : 79it [00:04, 17.35it/s]
current loss:10.728367805480957: : 79it [00:04, 17.24it/s]
current loss:9.573473930358887: : 79it [00:04, 17.21it/s] 
current loss:9.329168319702148: : 79it [00:04, 17.36it/s] 


1.3013715744018555


current loss:9.182336807250977: : 2it [00:00, 17.48it/s]

{'epoch': 69, 'test loss': 1.2923656702041626, 'accuracy': 0.5857999920845032, 'temp': 0.10042708722794487, 'beta1': 0.9960238933563232, 'beta2': 0.8593095541000366}


current loss:9.711549758911133: : 79it [00:04, 17.30it/s] 
current loss:8.930342674255371: : 79it [00:04, 17.33it/s] 
current loss:9.20984935760498: : 79it [00:04, 17.12it/s]  
current loss:10.977529525756836: : 79it [00:04, 17.27it/s]
current loss:10.394938468933105: : 79it [00:04, 17.32it/s]


1.3827677965164185


current loss:10.62585735321045: : 2it [00:00, 17.27it/s]

{'epoch': 74, 'test loss': 1.3903284072875977, 'accuracy': 0.5712000131607056, 'temp': 0.10040070990908134, 'beta1': 0.9962450861930847, 'beta2': 0.8669282793998718}


current loss:10.44104290008545: : 79it [00:04, 17.19it/s] 
current loss:8.806753158569336: : 79it [00:04, 17.73it/s] 
current loss:8.723278045654297: : 79it [00:04, 17.54it/s] 
current loss:8.621718406677246: : 79it [00:04, 17.32it/s] 
current loss:9.563055038452148: : 79it [00:04, 17.45it/s] 


1.349684238433838


current loss:8.485699653625488: : 2it [00:00, 18.15it/s]

{'epoch': 79, 'test loss': 1.3500251770019531, 'accuracy': 0.5781999826431274, 'temp': 0.10039130732257036, 'beta1': 0.9963551759719849, 'beta2': 0.8728551268577576}


current loss:10.433225631713867: : 79it [00:04, 17.58it/s]
current loss:9.27551555633545: : 79it [00:04, 17.38it/s]  
current loss:8.534539222717285: : 79it [00:04, 17.66it/s] 
current loss:9.265989303588867: : 79it [00:04, 17.55it/s]
current loss:9.146355628967285: : 79it [00:04, 17.41it/s] 


1.3075238466262817


current loss:8.519092559814453: : 2it [00:00, 17.58it/s]

{'epoch': 84, 'test loss': 1.3141335248947144, 'accuracy': 0.5877999663352966, 'temp': 0.10037264239581419, 'beta1': 0.9965299963951111, 'beta2': 0.8787373900413513}


current loss:9.119911193847656: : 79it [00:04, 17.27it/s] 
current loss:8.472445487976074: : 79it [00:04, 17.31it/s]
current loss:8.34369945526123: : 79it [00:04, 17.36it/s]  
current loss:9.090272903442383: : 79it [00:04, 17.34it/s]
current loss:10.009836196899414: : 79it [00:04, 17.34it/s]


1.3350269794464111


current loss:8.698501586914062: : 2it [00:00, 17.86it/s]

{'epoch': 89, 'test loss': 1.341927170753479, 'accuracy': 0.5745999813079834, 'temp': 0.10035360740475881, 'beta1': 0.9966963529586792, 'beta2': 0.8843344449996948}


current loss:9.51005744934082: : 79it [00:04, 17.47it/s]  
current loss:9.59361457824707: : 79it [00:04, 17.17it/s]  
current loss:8.737044334411621: : 79it [00:04, 17.41it/s] 
current loss:8.58899974822998: : 79it [00:04, 17.18it/s]  
current loss:8.654080390930176: : 79it [00:04, 17.39it/s] 


1.3102450370788574


current loss:8.305673599243164: : 2it [00:00, 17.30it/s]

{'epoch': 94, 'test loss': 1.3080577850341797, 'accuracy': 0.5809000134468079, 'temp': 0.1003383695075172, 'beta1': 0.9968379735946655, 'beta2': 0.8889551162719727}


current loss:9.089746475219727: : 79it [00:04, 17.31it/s] 
current loss:8.537192344665527: : 79it [00:04, 17.18it/s] 
current loss:9.70354175567627: : 79it [00:04, 17.06it/s]  
current loss:8.724169731140137: : 79it [00:04, 17.62it/s] 
current loss:8.522262573242188: : 79it [00:04, 17.36it/s] 


1.3361525535583496


current loss:3.524574041366577: : 2it [00:00, 17.29it/s]

{'epoch': 99, 'test loss': 1.3426684141159058, 'accuracy': 0.5785999894142151, 'temp': 0.10031822155033297, 'beta1': 0.9970160722732544, 'beta2': 0.8939468860626221}


current loss:2.4842097759246826: : 79it [00:04, 17.47it/s]


1.6453959941864014


current loss:2.3424184322357178: : 2it [00:00, 17.21it/s]

{'epoch': 0, 'test loss': 1.6412477493286133, 'accuracy': 0.38999998569488525, 'temp': 1.72330559194088, 'beta1': 0.5569825172424316, 'beta2': 0.45634862780570984}


current loss:2.12237286567688: : 79it [00:04, 17.56it/s]  
current loss:1.9827409982681274: : 79it [00:04, 17.39it/s]
current loss:1.9068918228149414: : 79it [00:04, 17.53it/s]
current loss:1.8712937831878662: : 79it [00:04, 17.25it/s]


1.5183348655700684


current loss:1.7520898580551147: : 2it [00:00, 17.07it/s]

{'epoch': 4, 'test loss': 1.5281392335891724, 'accuracy': 0.517799973487854, 'temp': 1.318281637132168, 'beta1': 0.5807989239692688, 'beta2': 0.4997516870498657}


current loss:1.8661448955535889: : 79it [00:04, 17.67it/s]
current loss:1.8786866664886475: : 79it [00:04, 17.69it/s]
current loss:1.9158365726470947: : 79it [00:04, 17.52it/s]
current loss:2.0213332176208496: : 79it [00:04, 17.29it/s]
current loss:2.152205228805542: : 79it [00:04, 17.49it/s] 


1.6003848314285278


current loss:1.972718596458435: : 2it [00:00, 18.19it/s]

{'epoch': 9, 'test loss': 1.612813949584961, 'accuracy': 0.5417999625205994, 'temp': 0.7532065220177173, 'beta1': 0.6407650709152222, 'beta2': 0.5420528054237366}


current loss:2.653510570526123: : 79it [00:04, 17.73it/s] 
current loss:3.810690402984619: : 79it [00:04, 17.54it/s] 
current loss:7.772858619689941: : 79it [00:04, 17.14it/s] 
current loss:25.053096771240234: : 79it [00:04, 17.46it/s]
current loss:24.821840286254883: : 79it [00:04, 17.56it/s]


2.139064073562622


current loss:24.80022430419922: : 2it [00:00, 17.74it/s]

{'epoch': 14, 'test loss': 2.129584789276123, 'accuracy': 0.4876999855041504, 'temp': 0.11453955612378196, 'beta1': 0.9164961576461792, 'beta2': 0.6194297671318054}


current loss:20.838794708251953: : 79it [00:04, 17.39it/s]
current loss:23.06540870666504: : 79it [00:04, 17.44it/s] 
current loss:22.03142738342285: : 79it [00:04, 17.47it/s] 
current loss:21.80172348022461: : 79it [00:04, 17.49it/s] 
current loss:20.273380279541016: : 79it [00:04, 17.23it/s]


1.6875358819961548


current loss:20.63673210144043: : 2it [00:00, 17.72it/s] 

{'epoch': 19, 'test loss': 1.6948771476745605, 'accuracy': 0.5397999882698059, 'temp': 0.10408810215012637, 'beta1': 0.968441367149353, 'beta2': 0.6874310374259949}


current loss:20.790523529052734: : 79it [00:04, 17.32it/s]
current loss:20.615705490112305: : 79it [00:04, 17.54it/s]
current loss:20.07293128967285: : 79it [00:04, 17.65it/s] 
current loss:19.251379013061523: : 79it [00:04, 17.36it/s]
current loss:18.65338706970215: : 79it [00:04, 17.13it/s] 


1.6580461263656616


current loss:18.15836524963379: : 2it [00:00, 16.85it/s]

{'epoch': 24, 'test loss': 1.6705914735794067, 'accuracy': 0.5582999587059021, 'temp': 0.10247617135755718, 'beta1': 0.9791955947875977, 'beta2': 0.7280718088150024}


current loss:17.22364044189453: : 79it [00:04, 17.30it/s] 
current loss:18.149738311767578: : 79it [00:04, 17.20it/s]
current loss:16.423505783081055: : 79it [00:04, 17.35it/s]
current loss:16.01973533630371: : 79it [00:04, 17.36it/s] 
current loss:14.967008590698242: : 79it [00:04, 17.58it/s]


1.5474236011505127


current loss:14.919621467590332: : 2it [00:00, 17.14it/s]

{'epoch': 29, 'test loss': 1.5362927913665771, 'accuracy': 0.5662999749183655, 'temp': 0.10193672356836032, 'beta1': 0.9833767414093018, 'beta2': 0.7524330019950867}


current loss:15.05053424835205: : 79it [00:04, 17.46it/s] 
current loss:16.73500633239746: : 79it [00:04, 17.21it/s] 
current loss:14.943699836730957: : 79it [00:04, 17.70it/s]
current loss:14.836767196655273: : 79it [00:04, 17.41it/s]
current loss:14.610345840454102: : 79it [00:04, 17.22it/s]


1.563080072402954


current loss:13.23208999633789: : 2it [00:00, 17.03it/s] 

{'epoch': 34, 'test loss': 1.5697364807128906, 'accuracy': 0.5546000003814697, 'temp': 0.10142792898841435, 'beta1': 0.9870505332946777, 'beta2': 0.7747579216957092}


current loss:13.977572441101074: : 79it [00:04, 17.52it/s]
current loss:13.356289863586426: : 79it [00:04, 17.36it/s]
current loss:13.21729564666748: : 79it [00:04, 17.27it/s] 
current loss:13.65520191192627: : 79it [00:04, 17.51it/s] 
current loss:13.546887397766113: : 79it [00:04, 17.56it/s]


1.5240182876586914


current loss:14.477784156799316: : 2it [00:00, 18.27it/s]

{'epoch': 39, 'test loss': 1.5270607471466064, 'accuracy': 0.56659996509552, 'temp': 0.10128945759061025, 'beta1': 0.9884278178215027, 'beta2': 0.7890336513519287}


current loss:13.692753791809082: : 79it [00:04, 17.64it/s]
current loss:13.647787094116211: : 79it [00:04, 17.30it/s]
current loss:13.021466255187988: : 79it [00:04, 17.40it/s]
current loss:12.827122688293457: : 79it [00:04, 17.37it/s]
current loss:11.717390060424805: : 79it [00:04, 17.29it/s]


1.4740103483200073


current loss:12.788212776184082: : 2it [00:00, 17.10it/s]

{'epoch': 44, 'test loss': 1.469451665878296, 'accuracy': 0.5674999952316284, 'temp': 0.10113833082214115, 'beta1': 0.9898278713226318, 'beta2': 0.802385151386261}


current loss:11.917309761047363: : 79it [00:04, 17.25it/s]
current loss:12.232942581176758: : 79it [00:04, 17.49it/s]
current loss:11.614970207214355: : 79it [00:04, 17.44it/s]
current loss:12.3844633102417: : 79it [00:04, 17.81it/s]  
current loss:11.136541366577148: : 79it [00:04, 17.14it/s]


1.5047553777694702


current loss:12.591941833496094: : 2it [00:00, 17.69it/s]

{'epoch': 49, 'test loss': 1.498121976852417, 'accuracy': 0.56659996509552, 'temp': 0.1009899801661959, 'beta1': 0.991007387638092, 'beta2': 0.814919114112854}


current loss:11.038308143615723: : 79it [00:04, 17.55it/s]
current loss:11.284210205078125: : 79it [00:04, 17.35it/s]
current loss:10.572509765625: : 79it [00:04, 17.52it/s]   
current loss:10.889418601989746: : 79it [00:04, 17.57it/s]
current loss:10.350168228149414: : 79it [00:04, 17.35it/s]


1.4685354232788086


current loss:10.282727241516113: : 2it [00:00, 17.91it/s]

{'epoch': 54, 'test loss': 1.4601397514343262, 'accuracy': 0.5709999799728394, 'temp': 0.10088770267757355, 'beta1': 0.9918827414512634, 'beta2': 0.8259462118148804}


current loss:10.142888069152832: : 79it [00:04, 17.76it/s]
current loss:9.4686279296875: : 79it [00:04, 17.54it/s]   
current loss:9.153495788574219: : 79it [00:04, 17.53it/s] 
current loss:9.111323356628418: : 79it [00:04, 17.38it/s] 
current loss:9.393434524536133: : 79it [00:04, 17.41it/s] 


1.4251667261123657


current loss:9.847163200378418: : 2it [00:00, 17.75it/s]

{'epoch': 59, 'test loss': 1.4314521551132202, 'accuracy': 0.5751999616622925, 'temp': 0.10081224197492702, 'beta1': 0.9925051927566528, 'beta2': 0.8335803151130676}


current loss:9.99797248840332: : 79it [00:04, 17.40it/s]  
current loss:9.063674926757812: : 79it [00:04, 17.62it/s] 
current loss:9.1594877243042: : 79it [00:04, 17.55it/s]  
current loss:8.977715492248535: : 79it [00:04, 17.35it/s] 
current loss:10.230879783630371: : 79it [00:04, 17.42it/s]


1.4553900957107544


current loss:10.177414894104004: : 2it [00:00, 17.61it/s]

{'epoch': 64, 'test loss': 1.4553489685058594, 'accuracy': 0.5744999647140503, 'temp': 0.10076743844329031, 'beta1': 0.9929185509681702, 'beta2': 0.8402206897735596}


current loss:9.637990951538086: : 79it [00:04, 17.55it/s] 
current loss:8.924570083618164: : 79it [00:04, 17.48it/s] 
current loss:8.787263870239258: : 79it [00:04, 17.32it/s]
current loss:10.251033782958984: : 79it [00:04, 17.52it/s]
current loss:9.04606819152832: : 79it [00:04, 17.36it/s]  


1.4289311170578003


current loss:8.744405746459961: : 2it [00:00, 17.40it/s]

{'epoch': 69, 'test loss': 1.424481749534607, 'accuracy': 0.5823000073432922, 'temp': 0.10069838044728385, 'beta1': 0.9935007691383362, 'beta2': 0.8469810485839844}


current loss:9.10175609588623: : 79it [00:04, 17.20it/s] 
current loss:9.115968704223633: : 79it [00:04, 17.14it/s]
current loss:9.600275993347168: : 79it [00:04, 17.21it/s] 
current loss:8.366682052612305: : 79it [00:04, 17.20it/s] 
current loss:9.056707382202148: : 79it [00:04, 17.63it/s]


1.3880707025527954


current loss:7.808822154998779: : 2it [00:00, 18.24it/s]

{'epoch': 74, 'test loss': 1.3893109560012817, 'accuracy': 0.5774999856948853, 'temp': 0.10066376108516124, 'beta1': 0.9938415288925171, 'beta2': 0.8527933359146118}


current loss:8.732566833496094: : 79it [00:04, 17.61it/s] 
current loss:8.396478652954102: : 79it [00:04, 17.50it/s]
current loss:8.34264850616455: : 79it [00:04, 17.46it/s]  
current loss:8.687150955200195: : 79it [00:04, 17.37it/s] 
current loss:8.221601486206055: : 79it [00:04, 17.46it/s] 


1.4219697713851929


current loss:8.6785306930542: : 2it [00:00, 17.91it/s]  

{'epoch': 79, 'test loss': 1.4264967441558838, 'accuracy': 0.5694000124931335, 'temp': 0.10063982759529609, 'beta1': 0.9940548539161682, 'beta2': 0.8574373722076416}


current loss:8.90186595916748: : 79it [00:04, 17.38it/s]  
current loss:8.063011169433594: : 79it [00:04, 17.19it/s] 
current loss:8.242015838623047: : 79it [00:04, 17.32it/s] 
current loss:9.690329551696777: : 79it [00:04, 17.42it/s] 
current loss:8.356471061706543: : 79it [00:04, 17.44it/s] 


1.4239037036895752


current loss:8.01689338684082: : 2it [00:00, 18.07it/s] 

{'epoch': 84, 'test loss': 1.4411511421203613, 'accuracy': 0.5684999823570251, 'temp': 0.10060690156096826, 'beta1': 0.9943504929542542, 'beta2': 0.8628915548324585}


current loss:7.379761695861816: : 79it [00:04, 17.31it/s] 
current loss:8.643304824829102: : 79it [00:04, 17.46it/s] 
current loss:7.505954742431641: : 79it [00:04, 17.38it/s] 
current loss:7.1326494216918945: : 79it [00:04, 17.46it/s]
current loss:8.072772979736328: : 79it [00:04, 17.55it/s] 


1.425058364868164


current loss:7.711479187011719: : 2it [00:00, 17.95it/s]

{'epoch': 89, 'test loss': 1.4272652864456177, 'accuracy': 0.5766000151634216, 'temp': 0.10057121201971314, 'beta1': 0.9946755170822144, 'beta2': 0.8683989644050598}


current loss:7.613584995269775: : 79it [00:04, 17.53it/s] 
current loss:8.278158187866211: : 79it [00:04, 17.80it/s] 
current loss:8.253838539123535: : 79it [00:04, 17.32it/s] 
current loss:7.533140659332275: : 79it [00:04, 17.41it/s] 
current loss:8.17845344543457: : 79it [00:04, 17.38it/s]  


1.4087982177734375


current loss:8.054102897644043: : 2it [00:00, 17.53it/s]

{'epoch': 94, 'test loss': 1.3898144960403442, 'accuracy': 0.5740000009536743, 'temp': 0.1005482727152412, 'beta1': 0.994895875453949, 'beta2': 0.8729050755500793}


current loss:8.820046424865723: : 79it [00:04, 17.44it/s] 
current loss:8.20452880859375: : 79it [00:04, 17.41it/s]  
current loss:7.831916809082031: : 79it [00:04, 17.20it/s] 
current loss:7.827820777893066: : 79it [00:04, 17.30it/s] 
current loss:8.872997283935547: : 79it [00:04, 17.19it/s] 


1.40928316116333


current loss:1.9876203536987305: : 2it [00:00, 17.02it/s]

{'epoch': 99, 'test loss': 1.3927905559539795, 'accuracy': 0.5733000040054321, 'temp': 0.10051414796071186, 'beta1': 0.9952371120452881, 'beta2': 0.8780838251113892}


current loss:1.5005288124084473: : 79it [00:04, 17.14it/s]


1.5992683172225952


current loss:1.3969227075576782: : 2it [00:00, 18.16it/s]

{'epoch': 0, 'test loss': 1.5954548120498657, 'accuracy': 0.412200003862381, 'temp': 2.7808780491352083, 'beta1': 0.33312711119651794, 'beta2': 0.5532147884368896}


current loss:1.3318697214126587: : 79it [00:04, 17.41it/s]
current loss:1.243456482887268: : 79it [00:04, 17.27it/s] 
current loss:1.1862386465072632: : 79it [00:04, 17.37it/s]
current loss:1.1601108312606812: : 79it [00:04, 17.41it/s]


1.351966381072998


current loss:1.0749001502990723: : 2it [00:00, 17.21it/s]

{'epoch': 4, 'test loss': 1.3491554260253906, 'accuracy': 0.5335000157356262, 'temp': 2.5015751585364345, 'beta1': 0.3484426736831665, 'beta2': 0.5871961712837219}


current loss:1.142987847328186: : 79it [00:04, 17.57it/s] 
current loss:1.1269235610961914: : 79it [00:04, 17.53it/s]
current loss:1.1112641096115112: : 79it [00:04, 17.56it/s]
current loss:1.0974267721176147: : 79it [00:04, 17.49it/s]
current loss:1.0971723794937134: : 79it [00:04, 18.00it/s]


1.3690646886825562


current loss:0.958907425403595: : 2it [00:00, 17.56it/s]

{'epoch': 9, 'test loss': 1.3595818281173706, 'accuracy': 0.5684999823570251, 'temp': 2.105599582195282, 'beta1': 0.37314245104789734, 'beta2': 0.6237569451332092}


current loss:1.0979527235031128: : 79it [00:04, 17.25it/s]
current loss:1.101043939590454: : 79it [00:04, 17.47it/s] 
current loss:1.0983481407165527: : 79it [00:04, 17.30it/s]
current loss:1.0946415662765503: : 79it [00:04, 17.35it/s]
current loss:1.0888391733169556: : 79it [00:04, 17.29it/s]


1.4547626972198486


current loss:0.9042747616767883: : 2it [00:00, 17.05it/s]

{'epoch': 14, 'test loss': 1.4595491886138916, 'accuracy': 0.5813999772071838, 'temp': 1.7322538033127786, 'beta1': 0.400315523147583, 'beta2': 0.6446808576583862}


current loss:1.0817731618881226: : 79it [00:04, 17.71it/s]
current loss:1.0697282552719116: : 79it [00:04, 17.43it/s]
current loss:1.0534794330596924: : 79it [00:04, 17.37it/s]
current loss:1.0391234159469604: : 79it [00:04, 17.45it/s]
current loss:1.0354092121124268: : 79it [00:04, 17.46it/s]


1.5875221490859985


current loss:0.845422089099884: : 2it [00:00, 17.91it/s] 

{'epoch': 19, 'test loss': 1.5866981744766235, 'accuracy': 0.5821999907493591, 'temp': 1.4832664325833322, 'beta1': 0.4205058515071869, 'beta2': 0.657599925994873}


current loss:1.0361417531967163: : 79it [00:04, 17.38it/s]
current loss:1.031614899635315: : 79it [00:04, 17.27it/s] 
current loss:1.0268189907073975: : 79it [00:04, 17.50it/s]
current loss:1.0355955362319946: : 79it [00:04, 17.46it/s]
current loss:1.0497790575027466: : 79it [00:04, 17.31it/s]


1.725553274154663


current loss:0.8745366930961609: : 2it [00:00, 17.88it/s]

{'epoch': 24, 'test loss': 1.7123384475708008, 'accuracy': 0.5780999660491943, 'temp': 1.1849796645343305, 'beta1': 0.45335376262664795, 'beta2': 0.6706154942512512}


current loss:1.085391879081726: : 79it [00:04, 17.45it/s] 
current loss:1.125555157661438: : 79it [00:04, 17.24it/s] 
current loss:1.1869940757751465: : 79it [00:04, 17.12it/s]
current loss:1.3142344951629639: : 79it [00:04, 17.06it/s]
current loss:1.5835106372833252: : 79it [00:04, 17.40it/s]


2.0876739025115967


current loss:1.5030862092971802: : 2it [00:00, 17.53it/s]

{'epoch': 29, 'test loss': 2.076810598373413, 'accuracy': 0.5299000144004822, 'temp': 0.7519449181854725, 'beta1': 0.5279669165611267, 'beta2': 0.6910529732704163}


current loss:1.8537189960479736: : 79it [00:04, 17.43it/s]
current loss:2.57281756401062: : 79it [00:04, 17.35it/s]  
current loss:3.475846529006958: : 79it [00:04, 17.44it/s] 
current loss:7.173414707183838: : 79it [00:04, 17.55it/s] 
current loss:12.740545272827148: : 79it [00:04, 17.72it/s]


2.350623369216919


current loss:15.09340763092041: : 2it [00:00, 17.46it/s] 

{'epoch': 34, 'test loss': 2.3573100566864014, 'accuracy': 0.4740999937057495, 'temp': 0.17951429830864074, 'beta1': 0.7738242745399475, 'beta2': 0.7295075058937073}


current loss:19.661043167114258: : 79it [00:04, 17.18it/s]
current loss:19.819074630737305: : 79it [00:04, 17.18it/s]
current loss:21.940622329711914: : 79it [00:04, 17.47it/s]
current loss:23.9019832611084: : 79it [00:04, 17.31it/s]  
current loss:23.904666900634766: : 79it [00:04, 17.37it/s]


2.2056961059570312


current loss:23.474130630493164: : 2it [00:00, 18.37it/s]

{'epoch': 39, 'test loss': 2.202270984649658, 'accuracy': 0.5239999890327454, 'temp': 0.10645935547072441, 'beta1': 0.9503583908081055, 'beta2': 0.7803885340690613}


current loss:21.59208106994629: : 79it [00:04, 17.62it/s] 
current loss:20.515945434570312: : 79it [00:04, 17.37it/s]
current loss:21.082500457763672: : 79it [00:04, 17.02it/s]
current loss:21.66577911376953: : 79it [00:04, 17.50it/s] 
current loss:17.93790054321289: : 79it [00:04, 18.00it/s] 


1.9219821691513062


current loss:16.080747604370117: : 2it [00:00, 17.87it/s]

{'epoch': 44, 'test loss': 1.9118198156356812, 'accuracy': 0.5543999671936035, 'temp': 0.10319026241195389, 'beta1': 0.971928060054779, 'beta2': 0.8104509115219116}


current loss:18.547481536865234: : 79it [00:04, 17.44it/s]
current loss:19.493732452392578: : 79it [00:04, 17.39it/s]
current loss:16.883651733398438: : 79it [00:04, 17.20it/s]
current loss:15.786150932312012: : 79it [00:04, 17.37it/s]
current loss:15.474260330200195: : 79it [00:04, 17.62it/s]


1.8127706050872803


current loss:14.534029960632324: : 2it [00:00, 17.99it/s]

{'epoch': 49, 'test loss': 1.8064144849777222, 'accuracy': 0.5728999972343445, 'temp': 0.1022214190219529, 'beta1': 0.979822039604187, 'beta2': 0.8315444588661194}


current loss:15.33851146697998: : 79it [00:04, 17.34it/s] 
current loss:15.234037399291992: : 79it [00:04, 17.35it/s]
current loss:18.292680740356445: : 79it [00:04, 17.37it/s]
current loss:14.301790237426758: : 79it [00:04, 17.26it/s]
current loss:15.550074577331543: : 79it [00:04, 17.65it/s]


1.675970196723938


current loss:13.248035430908203: : 2it [00:00, 17.20it/s]

{'epoch': 54, 'test loss': 1.6667640209197998, 'accuracy': 0.5879999995231628, 'temp': 0.10156287831050577, 'beta1': 0.9852301478385925, 'beta2': 0.8486856818199158}


current loss:14.688562393188477: : 79it [00:04, 17.38it/s]
current loss:14.828277587890625: : 79it [00:04, 17.23it/s]
current loss:14.625328063964844: : 79it [00:04, 16.92it/s]
current loss:13.563840866088867: : 79it [00:04, 17.14it/s]
current loss:13.723840713500977: : 79it [00:04, 17.01it/s]


1.646972894668579


current loss:10.708666801452637: : 2it [00:00, 17.59it/s]

{'epoch': 59, 'test loss': 1.6382004022598267, 'accuracy': 0.5913000106811523, 'temp': 0.1012589908001246, 'beta1': 0.9879960417747498, 'beta2': 0.8594293594360352}


current loss:15.064672470092773: : 79it [00:04, 17.38it/s]
current loss:14.051183700561523: : 79it [00:04, 17.24it/s]
current loss:15.690114974975586: : 79it [00:04, 17.39it/s]
current loss:14.178590774536133: : 79it [00:04, 17.35it/s]
current loss:14.349615097045898: : 79it [00:04, 17.51it/s]


1.7043737173080444


current loss:10.007050514221191: : 2it [00:00, 18.17it/s]

{'epoch': 64, 'test loss': 1.6989179849624634, 'accuracy': 0.5782999992370605, 'temp': 0.10102620548132109, 'beta1': 0.9901912212371826, 'beta2': 0.870535671710968}


current loss:12.813369750976562: : 79it [00:04, 17.38it/s]
current loss:12.869026184082031: : 79it [00:04, 17.27it/s]
current loss:13.337236404418945: : 79it [00:04, 17.36it/s]
current loss:13.604547500610352: : 79it [00:04, 17.40it/s]
current loss:11.411198616027832: : 79it [00:04, 17.52it/s]


1.6449536085128784


current loss:8.527569770812988: : 2it [00:00, 17.37it/s]

{'epoch': 69, 'test loss': 1.6423214673995972, 'accuracy': 0.5791000127792358, 'temp': 0.10091225362557453, 'beta1': 0.9912883043289185, 'beta2': 0.8768554925918579}


current loss:11.870660781860352: : 79it [00:04, 17.26it/s]
current loss:12.278867721557617: : 79it [00:04, 17.28it/s]
current loss:11.368908882141113: : 79it [00:04, 17.29it/s]
current loss:11.511972427368164: : 79it [00:04, 17.01it/s]
current loss:11.618374824523926: : 79it [00:04, 17.40it/s]


1.5818010568618774


current loss:8.183626174926758: : 2it [00:00, 17.71it/s]

{'epoch': 74, 'test loss': 1.5888737440109253, 'accuracy': 0.5884000062942505, 'temp': 0.10083206928175059, 'beta1': 0.9920628666877747, 'beta2': 0.8827288150787354}


current loss:10.619778633117676: : 79it [00:04, 17.66it/s]
current loss:10.743253707885742: : 79it [00:04, 17.27it/s]
current loss:10.424576759338379: : 79it [00:04, 17.43it/s]
current loss:10.409128189086914: : 79it [00:04, 17.42it/s]
current loss:12.108007431030273: : 79it [00:04, 17.60it/s]


1.6229883432388306


current loss:10.071783065795898: : 2it [00:00, 17.37it/s]

{'epoch': 79, 'test loss': 1.6081045866012573, 'accuracy': 0.583299994468689, 'temp': 0.10076136132120156, 'beta1': 0.9927073121070862, 'beta2': 0.8878860473632812}


current loss:11.12480640411377: : 79it [00:04, 17.49it/s] 
current loss:10.230615615844727: : 79it [00:04, 17.37it/s]
current loss:10.555429458618164: : 79it [00:04, 17.16it/s]
current loss:12.191644668579102: : 79it [00:04, 17.19it/s]
current loss:10.595218658447266: : 79it [00:04, 17.25it/s]


1.5699445009231567


current loss:7.222309112548828: : 2it [00:00, 16.98it/s]

{'epoch': 84, 'test loss': 1.555162787437439, 'accuracy': 0.5825999975204468, 'temp': 0.10068489238692564, 'beta1': 0.993415355682373, 'beta2': 0.8925594091415405}


current loss:9.727373123168945: : 79it [00:04, 17.49it/s] 
current loss:9.922648429870605: : 79it [00:04, 17.14it/s] 
current loss:9.46332836151123: : 79it [00:04, 17.35it/s]  
current loss:10.73973560333252: : 79it [00:04, 17.03it/s] 
current loss:9.265539169311523: : 79it [00:04, 17.29it/s] 


1.481425166130066


current loss:6.414854526519775: : 2it [00:00, 18.15it/s]

{'epoch': 89, 'test loss': 1.479904294013977, 'accuracy': 0.5871999859809875, 'temp': 0.10063563021976735, 'beta1': 0.9939071536064148, 'beta2': 0.8974898457527161}


current loss:9.454645156860352: : 79it [00:04, 17.72it/s] 
current loss:8.70400333404541: : 79it [00:04, 17.20it/s]  
current loss:8.680898666381836: : 79it [00:04, 17.33it/s]
current loss:9.497525215148926: : 79it [00:04, 17.16it/s] 
current loss:9.097417831420898: : 79it [00:04, 17.32it/s] 


1.4693427085876465


current loss:6.252307891845703: : 2it [00:00, 17.58it/s] 

{'epoch': 94, 'test loss': 1.473243236541748, 'accuracy': 0.5945999622344971, 'temp': 0.10058792970594368, 'beta1': 0.994339644908905, 'beta2': 0.9008243680000305}


current loss:9.017457962036133: : 79it [00:04, 17.20it/s]
current loss:8.971270561218262: : 79it [00:04, 17.35it/s] 
current loss:9.720405578613281: : 79it [00:04, 17.36it/s]
current loss:8.157176971435547: : 79it [00:04, 17.59it/s] 
current loss:9.009834289550781: : 79it [00:04, 17.35it/s] 


1.4814186096191406


current loss:3.2201220989227295: : 2it [00:00, 17.10it/s]

{'epoch': 99, 'test loss': 1.4732648134231567, 'accuracy': 0.5879999995231628, 'temp': 0.10055283983092522, 'beta1': 0.9946680068969727, 'beta2': 0.9046599864959717}


current loss:2.376424551010132: : 79it [00:04, 17.18it/s] 


1.6793439388275146


current loss:2.2229902744293213: : 2it [00:00, 17.63it/s]

{'epoch': 0, 'test loss': 1.684191346168518, 'accuracy': 0.36730000376701355, 'temp': 2.024406258761883, 'beta1': 0.6150349974632263, 'beta2': 0.44633322954177856}


current loss:1.989084005355835: : 79it [00:04, 17.67it/s] 
current loss:1.8340895175933838: : 79it [00:04, 17.66it/s]
current loss:1.7672102451324463: : 79it [00:04, 17.40it/s]
current loss:1.6953630447387695: : 79it [00:04, 17.26it/s]


1.5162278413772583


current loss:1.6081504821777344: : 2it [00:00, 18.10it/s]

{'epoch': 4, 'test loss': 1.5135022401809692, 'accuracy': 0.5282999873161316, 'temp': 1.7441188097000124, 'beta1': 0.6260484457015991, 'beta2': 0.48420044779777527}


current loss:1.6265742778778076: : 79it [00:04, 17.64it/s]
current loss:1.5581908226013184: : 79it [00:04, 17.39it/s]
current loss:1.5151921510696411: : 79it [00:04, 17.16it/s]
current loss:1.4744462966918945: : 79it [00:04, 17.36it/s]
current loss:1.4489563703536987: : 79it [00:04, 17.44it/s]


1.5952386856079102


current loss:1.3275378942489624: : 2it [00:00, 16.75it/s]

{'epoch': 9, 'test loss': 1.5861977338790894, 'accuracy': 0.5530999898910522, 'temp': 1.6018170669674874, 'beta1': 0.634571373462677, 'beta2': 0.5072849988937378}


current loss:1.432022213935852: : 79it [00:04, 17.51it/s] 
current loss:1.4295490980148315: : 79it [00:04, 17.12it/s]
current loss:1.4397127628326416: : 79it [00:04, 17.30it/s]
current loss:1.437403917312622: : 79it [00:04, 17.36it/s] 
current loss:1.4581263065338135: : 79it [00:04, 17.29it/s]


1.6867806911468506


current loss:1.3056551218032837: : 2it [00:00, 17.67it/s]

{'epoch': 14, 'test loss': 1.6958664655685425, 'accuracy': 0.5548999905586243, 'temp': 1.2997662879526617, 'beta1': 0.6551996469497681, 'beta2': 0.5323436260223389}


current loss:1.492081642150879: : 79it [00:04, 17.34it/s] 
current loss:1.560826063156128: : 79it [00:04, 17.35it/s] 
current loss:1.657751441001892: : 79it [00:04, 17.48it/s] 
current loss:1.7847782373428345: : 79it [00:04, 17.52it/s]
current loss:2.0425941944122314: : 79it [00:04, 17.43it/s]


1.8954963684082031


current loss:1.7992138862609863: : 2it [00:00, 17.55it/s]

{'epoch': 19, 'test loss': 1.9018796682357788, 'accuracy': 0.541700005531311, 'temp': 0.7983975932002068, 'beta1': 0.6957959532737732, 'beta2': 0.5632132887840271}


current loss:2.0321223735809326: : 79it [00:04, 17.56it/s]
current loss:2.2811851501464844: : 79it [00:04, 17.81it/s]
current loss:3.453968048095703: : 79it [00:04, 17.34it/s] 
current loss:6.231575965881348: : 79it [00:04, 17.45it/s] 
current loss:12.061065673828125: : 79it [00:04, 17.31it/s]


2.6783196926116943


current loss:10.887161254882812: : 2it [00:00, 17.58it/s]

{'epoch': 24, 'test loss': 2.6615688800811768, 'accuracy': 0.45879998803138733, 'temp': 0.20807000659406186, 'beta1': 0.8178674578666687, 'beta2': 0.6029783487319946}


current loss:18.086135864257812: : 79it [00:04, 17.51it/s]
current loss:23.948457717895508: : 79it [00:04, 17.37it/s]
current loss:28.83186912536621: : 79it [00:04, 17.06it/s] 
current loss:25.5826473236084: : 79it [00:04, 17.37it/s]  
current loss:26.459346771240234: : 79it [00:04, 17.42it/s]


2.2847487926483154


current loss:22.437896728515625: : 2it [00:00, 17.87it/s]

{'epoch': 29, 'test loss': 2.275291681289673, 'accuracy': 0.49889999628067017, 'temp': 0.10627972384681926, 'beta1': 0.9553734660148621, 'beta2': 0.6781066656112671}


current loss:24.99847984313965: : 79it [00:04, 17.44it/s] 
current loss:21.82062530517578: : 79it [00:04, 17.58it/s] 
current loss:20.436946868896484: : 79it [00:04, 17.15it/s]
current loss:20.062946319580078: : 79it [00:04, 17.25it/s]
current loss:18.77225685119629: : 79it [00:04, 17.29it/s] 


1.9158421754837036


current loss:16.6851806640625: : 2it [00:00, 17.33it/s]  

{'epoch': 34, 'test loss': 1.9134944677352905, 'accuracy': 0.5496999621391296, 'temp': 0.10302779744961299, 'beta1': 0.974614679813385, 'beta2': 0.7281284928321838}


current loss:17.934062957763672: : 79it [00:04, 17.50it/s]
current loss:20.229228973388672: : 79it [00:04, 17.70it/s]
current loss:18.021320343017578: : 79it [00:04, 17.42it/s]
current loss:17.733652114868164: : 79it [00:04, 17.13it/s]
current loss:16.23242950439453: : 79it [00:04, 17.42it/s] 


1.8377934694290161


current loss:15.267830848693848: : 2it [00:00, 18.11it/s]

{'epoch': 39, 'test loss': 1.8464224338531494, 'accuracy': 0.5551999807357788, 'temp': 0.10194569284649334, 'beta1': 0.9824831485748291, 'beta2': 0.7607954144477844}


current loss:15.926977157592773: : 79it [00:04, 17.55it/s]
current loss:15.227025985717773: : 79it [00:04, 17.59it/s]
current loss:14.736856460571289: : 79it [00:04, 17.14it/s]
current loss:15.271665573120117: : 79it [00:04, 17.11it/s]
current loss:15.378694534301758: : 79it [00:04, 17.37it/s]


1.8055354356765747


current loss:12.3842191696167: : 2it [00:00, 17.13it/s]  

{'epoch': 44, 'test loss': 1.8051515817642212, 'accuracy': 0.5620999932289124, 'temp': 0.10152188346983167, 'beta1': 0.985927939414978, 'beta2': 0.7828115224838257}


current loss:15.09801197052002: : 79it [00:04, 17.63it/s] 
current loss:13.269538879394531: : 79it [00:04, 17.28it/s]
current loss:13.280378341674805: : 79it [00:04, 17.48it/s]
current loss:14.450419425964355: : 79it [00:04, 17.32it/s]
current loss:14.237818717956543: : 79it [00:04, 17.30it/s]


1.6510016918182373


current loss:12.412138938903809: : 2it [00:00, 17.76it/s]

{'epoch': 49, 'test loss': 1.644473671913147, 'accuracy': 0.5787999629974365, 'temp': 0.1013780645618681, 'beta1': 0.9873884916305542, 'beta2': 0.7998993992805481}


current loss:14.40449333190918: : 79it [00:04, 17.26it/s] 
current loss:14.51506233215332: : 79it [00:04, 17.48it/s] 
current loss:12.551908493041992: : 79it [00:04, 17.64it/s]
current loss:14.115409851074219: : 79it [00:04, 17.47it/s]
current loss:13.815803527832031: : 79it [00:04, 17.26it/s]


1.6568082571029663


current loss:9.870450019836426: : 2it [00:00, 18.11it/s] 

{'epoch': 54, 'test loss': 1.6552788019180298, 'accuracy': 0.578499972820282, 'temp': 0.10113004541562987, 'beta1': 0.9894604682922363, 'beta2': 0.8156688809394836}


current loss:13.542129516601562: : 79it [00:04, 17.23it/s]
current loss:11.883852005004883: : 79it [00:04, 17.38it/s]
current loss:12.641687393188477: : 79it [00:04, 17.41it/s]
current loss:13.092538833618164: : 79it [00:04, 17.52it/s]
current loss:12.280055046081543: : 79it [00:04, 17.60it/s]


1.599993348121643


current loss:9.474387168884277: : 2it [00:00, 17.58it/s] 

{'epoch': 59, 'test loss': 1.594719409942627, 'accuracy': 0.5788999795913696, 'temp': 0.10101106054135017, 'beta1': 0.9905807375907898, 'beta2': 0.8269152641296387}


current loss:12.470891952514648: : 79it [00:04, 17.36it/s]
current loss:12.53990650177002: : 79it [00:04, 17.47it/s] 
current loss:11.99651050567627: : 79it [00:04, 17.64it/s] 
current loss:11.880683898925781: : 79it [00:04, 17.60it/s]
current loss:12.7490234375: : 79it [00:04, 17.36it/s]     


1.6730968952178955


current loss:9.331271171569824: : 2it [00:00, 16.90it/s]

{'epoch': 64, 'test loss': 1.6589418649673462, 'accuracy': 0.5697000026702881, 'temp': 0.10088979513457162, 'beta1': 0.9916535019874573, 'beta2': 0.8368543386459351}


current loss:12.576833724975586: : 79it [00:04, 17.18it/s]
current loss:12.453725814819336: : 79it [00:04, 17.42it/s]
current loss:11.291839599609375: : 79it [00:04, 17.37it/s]
current loss:10.84695816040039: : 79it [00:04, 17.22it/s] 
current loss:11.325445175170898: : 79it [00:04, 17.47it/s]


1.5518059730529785


current loss:8.069636344909668: : 2it [00:00, 17.89it/s]

{'epoch': 69, 'test loss': 1.5549708604812622, 'accuracy': 0.583299994468689, 'temp': 0.10082868255412905, 'beta1': 0.9922676682472229, 'beta2': 0.8465946912765503}


current loss:10.732223510742188: : 79it [00:04, 17.64it/s]
current loss:12.201116561889648: : 79it [00:04, 17.35it/s]
current loss:10.404167175292969: : 79it [00:04, 17.45it/s]
current loss:10.758280754089355: : 79it [00:04, 17.45it/s]
current loss:10.069955825805664: : 79it [00:04, 17.62it/s]


1.5480557680130005


current loss:7.46298360824585: : 2it [00:00, 18.15it/s] 

{'epoch': 74, 'test loss': 1.5521162748336792, 'accuracy': 0.5870000123977661, 'temp': 0.10076037959734095, 'beta1': 0.9928790926933289, 'beta2': 0.8547149300575256}


current loss:9.595479965209961: : 79it [00:04, 17.51it/s] 
current loss:10.01147174835205: : 79it [00:04, 17.56it/s] 
current loss:10.717489242553711: : 79it [00:04, 17.49it/s]
current loss:10.779531478881836: : 79it [00:04, 17.17it/s]
current loss:10.244649887084961: : 79it [00:04, 17.27it/s]


1.5656206607818604


current loss:8.7559814453125: : 2it [00:00, 17.73it/s]  

{'epoch': 79, 'test loss': 1.5635879039764404, 'accuracy': 0.5733000040054321, 'temp': 0.10071124701571535, 'beta1': 0.9933211207389832, 'beta2': 0.8616700768470764}


current loss:10.438821792602539: : 79it [00:04, 17.46it/s]
current loss:10.100204467773438: : 79it [00:04, 17.49it/s]
current loss:9.62955093383789: : 79it [00:04, 17.45it/s]  
current loss:9.606049537658691: : 79it [00:04, 17.41it/s] 
current loss:9.35466480255127: : 79it [00:04, 17.38it/s]  


1.5557118654251099


current loss:8.031904220581055: : 2it [00:00, 17.72it/s]

{'epoch': 84, 'test loss': 1.5594738721847534, 'accuracy': 0.5812000036239624, 'temp': 0.10066903433034896, 'beta1': 0.9937485456466675, 'beta2': 0.8683751225471497}


current loss:9.375593185424805: : 79it [00:04, 17.40it/s] 
current loss:9.628148078918457: : 79it [00:04, 17.48it/s] 
current loss:9.811529159545898: : 79it [00:04, 17.43it/s] 
current loss:10.106728553771973: : 79it [00:04, 17.34it/s]
current loss:9.27039909362793: : 79it [00:04, 17.24it/s] 


1.4858653545379639


current loss:7.096207141876221: : 2it [00:00, 17.72it/s]

{'epoch': 89, 'test loss': 1.4765255451202393, 'accuracy': 0.5812000036239624, 'temp': 0.10062017359741732, 'beta1': 0.9941619038581848, 'beta2': 0.8745586276054382}


current loss:8.893206596374512: : 79it [00:04, 17.45it/s]
current loss:9.465353965759277: : 79it [00:04, 17.61it/s] 
current loss:8.891962051391602: : 79it [00:04, 17.27it/s]
current loss:8.629135131835938: : 79it [00:04, 17.33it/s]
current loss:8.205286979675293: : 79it [00:04, 17.34it/s] 


1.4912112951278687


current loss:6.338035583496094: : 2it [00:00, 17.18it/s]

{'epoch': 94, 'test loss': 1.4971920251846313, 'accuracy': 0.5813999772071838, 'temp': 0.10059259139961796, 'beta1': 0.9944108128547668, 'beta2': 0.8793951869010925}


current loss:8.705863952636719: : 79it [00:04, 17.42it/s] 
current loss:8.156615257263184: : 79it [00:04, 17.38it/s] 
current loss:8.352819442749023: : 79it [00:04, 17.45it/s] 
current loss:7.964221000671387: : 79it [00:04, 17.29it/s] 
current loss:8.589632034301758: : 79it [00:04, 17.45it/s] 


1.454845905303955
{'epoch': 99, 'test loss': 1.454610824584961, 'accuracy': 0.5884999632835388, 'temp': 0.10057063108179137, 'beta1': 0.9946000576019287, 'beta2': 0.8833026885986328}


In [1]:
with open("exp6_basic.jsonl", "r") as read_file:
    data_b = [json.loads(line) for line in read_file]
with open("exp6_distill.jsonl", "r") as read_file:
    data_d = [json.loads(line) for line in read_file]
with open("exp6_dist_h_rand.jsonl", "r") as read_file:
    data_dr = [json.loads(line) for line in read_file]
with open("exp6_dist_h_optim.jsonl", "r") as read_file:
    data_h = [json.loads(line) for line in read_file]

FileNotFoundError: [Errno 2] No such file or directory: 'exp6_basic.jsonl'

In [None]:
from matplotlib import pylab as plt
plt.rcParams['font.family'] = 'DejaVu Serif'
plt.rcParams['lines.linewidth'] = 2
plt.rcParams['lines.markersize'] = 12
plt.rcParams['xtick.labelsize'] = 24
plt.rcParams['ytick.labelsize'] = 24
plt.rcParams['legend.fontsize'] = 24
plt.rcParams['axes.titlesize'] = 36
plt.rcParams['axes.labelsize'] = 24

epoch_b = np.array([data_b[2]['results'][i]['epoch'] for i in range(len(data_b[0]['results']))])
loss_b = np.array([subdata['results'][i]['test loss'] for i in range(len(data_b[0]['results'])) for subdata in data_b]).reshape(epoch_b.shape[0], -1)
plt.plot(epoch_b, loss_b.mean(1), '-', color='red', label='без дистилляции')
plt.fill_between(epoch_b, loss_b.mean(1)-loss_b.std(1), loss_b.mean(1)+loss_b.std(1), alpha=0.2, color='red')

epoch_d = np.array([data_d[2]['results'][i]['epoch'] for i in range(len(data_d[2]['results']))])
loss_d = np.array([subdata['results'][i]['test loss'] for i in range(len(data_d[0]['results'])) for subdata in data_d]).reshape(epoch_d.shape[0], -1)
plt.plot(epoch_d, loss_d.mean(1), '-', color='blue', label='оптимальные гипепараметров')
plt.fill_between(epoch_d, loss_d.mean(1)-loss_d.std(1), loss_d.mean(1)+loss_d.std(1), alpha=0.2, color='blue')

epoch_dr = np.array([data_dr[2]['results'][i]['epoch'] for i in range(len(data_dr[2]['results']))])
loss_dr = np.array([subdata['results'][i]['test loss'] for i in range(len(data_dr[0]['results'])) for subdata in data_dr]).reshape(epoch_dr.shape[0], -1)
plt.plot(epoch_dr, loss_dr.mean(1), '-', color='black', label='случайные гипепараметры')
plt.fill_between(epoch_dr, loss_dr.mean(1)-loss_dr.std(1), loss_dr.mean(1)+loss_dr.std(1), alpha=0.2, color='black')


epoch_h = np.array([data_dr[2]['results'][i]['epoch'] for i in range(len(data_h[2]['results']))])
loss_h = np.array([subdata['results'][i]['test loss'] for i in range(len(data_d[0]['results'])) for subdata in data_h]).reshape(epoch_h.shape[0], -1)
plt.plot(epoch_h, loss_h.mean(1), '-', color='green', label='оптимизация гипепараметры')
plt.fill_between(epoch_h, loss_h.mean(1)-loss_h.std(1), loss_h.mean(1)+loss_h.std(1), alpha=0.2, color='green')

plt.xlabel('Количество эпох')
plt.ylabel('Потеря на тестовой выборке')

plt.legend()
plt.savefig('loss.pdf')

In [None]:
epoch_b = np.array([data_b[2]['results'][i]['epoch'] for i in range(len(data_b[0]['results']))])
acc_b = np.array([subdata['results'][i]['accuracy'] for i in range(len(data_b[0]['results'])) for subdata in data_b]).reshape(epoch_b.shape[0], -1)
plt.plot(epoch_b, acc_b.mean(1), '-', color='red', label='без дистилляции')
plt.fill_between(epoch_b, acc_b.mean(1)-acc_b.std(1), acc_b.mean(1)+acc_b.std(1), alpha=0.2, color='red')

epoch_d = np.array([data_d[2]['results'][i]['epoch'] for i in range(len(data_d[2]['results']))])
acc_d = np.array([subdata['results'][i]['accuracy'] for i in range(len(data_d[0]['results'])) for subdata in data_d]).reshape(epoch_d.shape[0], -1)
plt.plot(epoch_d, acc_d.mean(1), '-', color='blue', label='оптимальные гипепараметры')
plt.fill_between(epoch_d, acc_d.mean(1)-acc_d.std(1), acc_d.mean(1)+acc_d.std(1), alpha=0.2, color='blue')

epoch_h = np.array([data_h[2]['results'][i]['epoch'] for i in range(len(data_h[2]['results']))])
acc_h = np.array([subdata['results'][i]['accuracy'] for i in range(len(data_d[0]['results'])) for subdata in data_h]).reshape(epoch_h.shape[0], -1)
plt.plot(epoch_h, acc_h.mean(1), '-', color='green', label='оптимизация гиперпараметров')
plt.fill_between(epoch_h, acc_h.mean(1)-acc_h.std(1), acc_h.mean(1)+acc_h.std(1), alpha=0.2, color='green')

epoch_dr = np.array([data_dr[2]['results'][i]['epoch'] for i in range(len(data_dr[2]['results']))])
acc_dr = np.array([subdata['results'][i]['accuracy'] for i in range(len(data_dr[0]['results'])) for subdata in data_dr]).reshape(epoch_h.shape[0], -1)
plt.plot(epoch_dr, acc_dr.mean(1), '-', color='black', label='случайные гиперпараметры')
plt.fill_between(epoch_dr, acc_dr.mean(1)-acc_h.std(1), acc_dr.mean(1)+acc_dr.std(1), alpha=0.2, color='black')


plt.xlabel('Количество эпох')
plt.ylabel('Точность классификации')
plt.legend()
plt.savefig('acc.pdf')

In [None]:
epoch_b = np.hstack((epoch_b, epoch_b, epoch_b, epoch_b, epoch_b))

In [None]:
epoch_b = np.array([data_b[2]['results'][i]['epoch'] for i in range(len(data_b[0]['results']))])
epoch_b.reshape(41, 1)
epoch_b = np.hstack((epoch_b, epoch_b, epoch_b, epoch_b, epoch_b))
loss_b = np.array([subdata['results'][i]['test loss'] for i in range(len(data_b[0]['results'])) for subdata in data_b]).reshape(epoch_b.shape[0], -1)
plt.scatter(epoch_b, loss_b, color='red', marker='.', label='без дистилляции')
#plt.fill_between(epoch_b, loss_b.mean(1)-loss_b.std(1), loss_b.mean(1)+loss_b.std(1), alpha=0.2, color='red')

epoch_d = np.array([data_d[2]['results'][i]['epoch'] for i in range(len(data_d[2]['results']))])
epoch_d.reshape(41, 1)
epoch_d = np.hstack((epoch_d, epoch_d, epoch_d, epoch_d, epoch_d))
loss_d = np.array([subdata['results'][i]['test loss'] for i in range(len(data_d[0]['results'])) for subdata in data_d]).reshape(epoch_d.shape[0], -1)
plt.scatter(epoch_d, loss_d, marker='d', color='blue', label='оптимальные гипепараметры')
#plt.fill_between(epoch_d, loss_d.mean(1)-loss_d.std(1), loss_d.mean(1)+loss_d.std(1), alpha=0.2, color='blue')

epoch_dr = np.array([data_dr[2]['results'][i]['epoch'] for i in range(len(data_dr[2]['results']))])
epoch_dr.reshape(41, 1)
epoch_dr = np.hstack((epoch_dr, epoch_dr, epoch_dr, epoch_dr, epoch_dr))
loss_dr = np.array([subdata['results'][i]['test loss'] for i in range(len(data_dr[0]['results'])) for subdata in data_dr]).reshape(epoch_dr.shape[0], -1)
plt.scatter(epoch_dr, loss_dr, marker='x', color='black', label='случайные гипепараметры')
#plt.fill_between(epoch_dr, loss_dr.mean(1)-loss_dr.std(1), loss_dr.mean(1)+loss_dr.std(1), alpha=0.2, color='black')


epoch_h = np.array([data_dr[2]['results'][i]['epoch'] for i in range(len(data_h[2]['results']))])
epoch_h.reshape(41, 1)
epoch_h = np.hstack((epoch_h, epoch_h, epoch_h, epoch_h, epoch_h))
loss_h = np.array([subdata['results'][i]['test loss'] for i in range(len(data_d[0]['results'])) for subdata in data_h]).reshape(epoch_h.shape[0], -1)
plt.scatter(epoch_h, loss_h, marker='+', color='green', label='оптимизация гипепараметров')
#plt.fill_between(epoch_h, loss_h.mean(1)-loss_h.std(1), loss_h.mean(1)+loss_h.std(1), alpha=0.2, color='green')

plt.xlabel('Количество эпох')
plt.ylabel('Потеря на тестовой выборке')
plt.legend()
plt.savefig('scatter_plot_loss.pdf')

In [None]:
epoch_d = np.array([data_d[2]['results'][i]['epoch'] for i in range(len(data_d[2]['results']))])
beta_d = np.array([data_d[2]['results'][i]['beta'] for i in range(len(data_d[2]['results']))])
plt.plot(epoch_d, beta_d, '-', color='blue', label='дистилляция без оптимизации гипепараметров')
plt.fill_between(epoch_d, beta_d-beta_d.std(), beta_d+beta_d.std(), alpha=0.2, color='blue')

epoch_h = np.array([data_h[2]['results'][i]['epoch'] for i in range(len(data_h[2]['results']))])
beta_h = np.array([data_h[2]['results'][i]['beta'] for i in range(len(data_h[2]['results']))])
plt.plot(epoch_h, beta_h, '-', color='green', label='дистилляция с оптимизацией гипепараметров')
plt.fill_between(epoch_h, beta_h-beta_h.std(), beta_h+beta_h.std(), alpha=0.2, color='green')

plt.legend()
plt.savefig('3.eps')

In [None]:
epoch_d = np.array([data_d[2]['results'][i]['epoch'] for i in range(len(data_d[2]['results']))])
temp_d = np.array([data_d[2]['results'][i]['temp'] for i in range(len(data_d[2]['results']))])
plt.plot(epoch_d, temp_d, '-', color='blue', label='дистилляция без оптимизации гипепараметров')
plt.fill_between(epoch_d, temp_d-temp_d.std(), temp_d+temp_d.std(), alpha=0.2, color='blue')

epoch_h = np.array([data_h[2]['results'][i]['epoch'] for i in range(len(data_h[2]['results']))])
temp_h = np.array([data_h[2]['results'][i]['temp'] for i in range(len(data_h[2]['results']))])
plt.plot(epoch_h, temp_h, '-', color='green', label='дистилляция с оптимизацией гипепараметров')
plt.fill_between(temp_h, temp_h-temp_h.std(), temp_h+temp_h.std(), alpha=0.2, color='green')

plt.legend()
plt.savefig('4.eps')

In [None]:
l[0]

In [None]:
cm.seismic(l[0])

In [None]:
acc_dr = np.array([subdata['results'][i]['accuracy'] for i in range(len(data_dr[0]['results'])) for subdata in data_dr]).reshape(epoch_dr.shape[0], -1)
acc_h = np.array([subdata['results'][i]['accuracy'] for i in range(len(data_h[0]['results'])) for subdata in data_h]).reshape(epoch_h.shape[0], -1)
all_results = list(acc_dr) + list(acc_h)
max_ = np.max(all_results)
min_ = np.min(all_results)

colors = [cm.seismic((r-min_)/(max_-min_)) for r in acc_dr.flatten()]
temp_dr = np.array([subdata['results'][i]['temp'] for i in range(len(data_dr[0]['results'])) for subdata in data_dr]).reshape(epoch_dr.shape[0], -1)
beta_dr = np.array([subdata['results'][i]['beta'] for i in range(len(data_dr[0]['results'])) for subdata in data_dr]).reshape(epoch_dr.shape[0], -1)
plt.scatter(beta_dr.flatten(), temp_dr.flatten(), marker='d', c=colors, label='случайные гипепараметры')

colors = [cm.seismic((r-min_)/(max_-min_)) for r in acc_h.flatten()]
temp_h = np.array([subdata['results'][i]['temp'] for i in range(len(data_h[0]['results'])) for subdata in data_h]).reshape(epoch_h.shape[0], -1)
beta_h = np.array([subdata['results'][i]['beta'] for i in range(len(data_h[0]['results'])) for subdata in data_h]).reshape(epoch_h.shape[0], -1)
plt.scatter(beta_h, temp_h, marker='x', c=colors, label='оптимизация гипепараметров')

plt.xlabel('beta')
plt.ylabel('$T_0$')
plt.legend()
plt.savefig('scatter_plot_beta_temp.pdf')

In [None]:
max_

In [None]:
"""
посмотреть, куда сходятся гиперпараметры.
Задача скорее всего невыпуклая по гиперпараметрам, поэтому может быть несколько точек экстремума.

Взять одно, наилучшее значение гиперпараметров.

Посчитать дистилляцию БЕЗ оптимизации гиперпараметров с наилушчими значениями.

НЕ ЗАБУДЬ ПОМЕНЯТЬ ИМЯ ФАЙЛА ДЛЯ СОХРАНЕНИЯ
"""

In [None]:
"""
Посчитать дистилляцию с оптимизацей гиперпараметров, в качестве начальной точки взять не случайные значения,
а start_beta, start_temp.

НЕ ЗАБУДЬ ПОМЕНЯТЬ ИМЯ ФАЙЛА ДЛЯ СОХРАНЕНИЯ
"""

In [None]:
"""
Построить график функции потерь на тесте в зависимости от эпохи. 
На графике должны быть линии для :
    - оптимизации без дистилляции
    - оптимизации с дистилляцией без оптимизации гиперпараметров, значения соответсвутют start_temp, start_beta
    - оптимизации с дистилляцией без оптимизации гиперпараметров, значения соответсвутют оптимизированным значениям гиперпараметров
    - оптимизации с дистилляцией c оптимизацией гиперпараметров, начальное приближение соответсвуeт start_temp, start_beta
    - оптимизации с дистилляцией c оптимизацией гиперпараметров, начальное приближение случайное
"""

In [None]:
"""
Построить график точности на тесте в зависимости от эпохи. 
На графике должны быть линии для :
    - оптимизации без дистилляции
    - оптимизации с дистилляцией без оптимизации гиперпараметров, значения соответсвутют start_temp, start_beta
    - оптимизации с дистилляцией без оптимизации гиперпараметров, значения соответсвутют оптимизированным значениям гиперпараметров
    - оптимизации с дистилляцией c оптимизацией гиперпараметров, начальное приближение соответсвуeт start_temp, start_beta
    - оптимизации с дистилляцией c оптимизацией гиперпараметров, начальное приближение случайное
"""

In [None]:
"""
Построить график беты в зависимости от эпохи. 
На графике должны быть линии для :    
    - оптимизации с дистилляцией без оптимизации гиперпараметров, значения соответсвутют start_temp, start_beta
    - оптимизации с дистилляцией без оптимизации гиперпараметров, значения соответсвутют оптимизированным значениям гиперпараметров
    - оптимизации с дистилляцией c оптимизацией гиперпараметров, начальное приближение соответсвуeт start_temp, start_beta
    - оптимизации с дистилляцией c оптимизацией гиперпараметров, начальное приближение случайное
"""

In [None]:
"""
Построить график температуры в зависимости от эпохи. 
На графике должны быть линии для :    
    - оптимизации с дистилляцией без оптимизации гиперпараметров, значения соответсвутют start_temp, start_beta
    - оптимизации с дистилляцией без оптимизации гиперпараметров, значения соответсвутют оптимизированным значениям гиперпараметров
    - оптимизации с дистилляцией c оптимизацией гиперпараметров, начальное приближение соответсвуeт start_temp, start_beta
    - оптимизации с дистилляцией c оптимизацией гиперпараметров, начальное приближение случайное
"""