In [2]:
from resnet import *
from cifar_very_tiny import *
from cifar_dataset import *    
import torch as t 
import numpy as np
import tqdm
import matplotlib.pylab as plt
import matplotlib.cm as cm
%matplotlib inline
plt.rcParams['figure.figsize']=(12,9)
plt.rcParams['font.size']= 20

In [3]:
train_loader, test_loader, train_loader_no_augumentation = cifar10_loader(batch_size=128)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


100.0%

Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified
Files already downloaded and verified


In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epoch_num = 50
basic_results = [] # результаты без дистилляции. Каждый элемент списка - accuracy
resnet_results = [] # результаты с дистилляцией ResNet. Каждый элемент списка - кортеж вида (accuracy, beta, temp)
cnn_results = [] # результаты с дистилляцией CNN. Каждый элемент списка - кортеж вида (accuracy, beta, temp)

In [5]:
def accuracy(student):
        total = 0 
        correct = 0
        with t.no_grad():
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)
                out = student(x)
                correct += t.eq(t.argmax(out, 1), y).sum()
                total+=len(x)
        return (correct/total).cpu().detach().numpy()

In [6]:
# 5 запусков --- без дистилляции

In [None]:
for _ in range(5):
    student = Cifar_Very_Tiny(10).to(device)
    optim = t.optim.Adam(student.parameters())
    #optim = t.optim.SGD(student.parameters(), lr=0.001)
    crit = nn.CrossEntropyLoss()
    for _ in range(epoch_num):
        tq = tqdm.tqdm(train_loader_no_augumentation)
        losses = []
        for x,y in tq:
            x = x.to(device)
            y = y.to(device)
            student.zero_grad()
            out = student(x)
            loss = crit(student(x), y)
            losses.append(loss.cpu().detach().numpy())
            loss.backward()
            optim.step()
            tq.set_description('current loss:{}'.format(np.mean(losses[-10:])))
    basic_results.append(accuracy(student))
    print ('accuracy', basic_results[-1])

current loss:2.261925458908081: 100%|██████████| 391/391 [00:27<00:00, 14.13it/s] 
current loss:2.19840145111084: 100%|██████████| 391/391 [00:27<00:00, 14.39it/s]  
current loss:2.1343493461608887: 100%|██████████| 391/391 [00:26<00:00, 14.56it/s]
current loss:2.072148561477661: 100%|██████████| 391/391 [00:26<00:00, 14.59it/s] 
current loss:2.012829303741455: 100%|██████████| 391/391 [00:26<00:00, 14.56it/s] 
current loss:1.9595016241073608: 100%|██████████| 391/391 [00:26<00:00, 14.66it/s]
current loss:1.911120057106018: 100%|██████████| 391/391 [00:26<00:00, 14.62it/s] 
current loss:1.8669869899749756: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s]
current loss:1.8269790410995483: 100%|██████████| 391/391 [00:26<00:00, 14.54it/s]
current loss:1.7901004552841187: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s]
current loss:1.78475022315979:  36%|███▋      | 142/391 [00:09<00:08, 27.83it/s]  

In [7]:
kl = nn.KLDivLoss(reduction='batchmean')
sm = nn.Softmax(dim=1)

def distill(out, batch_logits, temp):
    g = sm(out/temp)
    f = F.log_softmax(batch_logits/temp)    
    return kl(f, g)

In [12]:
# 20 запусков --- с ResNet
logits = np.load('./logits_resnet.npy')
for _ in range(20):
    beta = np.random.uniform()
    temp = 10**(np.random.uniform(low=-1, high=1)) # температура от 0.1 до 10
    print ('hyperparameters', beta, temp)
    student = Cifar_Very_Tiny(10).to(device)
    optim = t.optim.Adam(student.parameters())
    #optim = t.optim.SGD(student.parameters(), lr=0.001)
    crit = nn.CrossEntropyLoss()
    for _ in range(epoch_num):
        tq = tqdm.tqdm(train_loader_no_augumentation)
        losses = []
        for batch_id, (x,y) in enumerate(tq):
            x = x.to(device)
            y = y.to(device)
            batch_logits = t.Tensor(logits[128*batch_id:128*(batch_id+1)]).to(device)            
            student.zero_grad()
            out = student(x)
            student_loss = crit(student(x), y)
            #distillation_loss = 0.0 # здесь твой код!
            distillation_loss = distill(out, batch_logits, temp)
            loss = (1-beta) * student_loss + beta*distillation_loss
            losses.append(loss.cpu().detach().numpy())
            loss.backward()
            optim.step()
            tq.set_description('current loss:{}'.format(np.mean(losses[-10:])))
    resnet_results.append((accuracy(student), beta, temp))
    print ('accuracy', resnet_results[-1])

  0%|          | 0/391 [00:00<?, ?it/s]

hyperparameters 0.11958219685629079 0.17182169343405432


current loss:1.733682632446289: 100%|██████████| 391/391 [00:10<00:00, 35.72it/s]
current loss:1.7021169662475586: 100%|██████████| 391/391 [00:11<00:00, 35.34it/s]
current loss:1.6770436763763428: 100%|██████████| 391/391 [00:11<00:00, 35.26it/s]
current loss:1.6521902084350586: 100%|██████████| 391/391 [00:10<00:00, 35.78it/s]
current loss:1.6224422454833984: 100%|██████████| 391/391 [00:10<00:00, 35.75it/s]
current loss:1.5858001708984375: 100%|██████████| 391/391 [00:10<00:00, 36.20it/s]
current loss:1.5463708639144897: 100%|██████████| 391/391 [00:10<00:00, 35.78it/s]
current loss:1.5095635652542114: 100%|██████████| 391/391 [00:10<00:00, 35.86it/s]
current loss:1.4784539937973022: 100%|██████████| 391/391 [00:11<00:00, 35.29it/s]
current loss:1.4535796642303467: 100%|██████████| 391/391 [00:10<00:00, 35.89it/s]
current loss:1.4331159591674805: 100%|██████████| 391/391 [00:10<00:00, 35.63it/s]
current loss:1.415586233139038: 100%|██████████| 391/391 [00:10<00:00, 35.84it/s]
curren

accuracy (array(0.51199996, dtype=float32), 0.11958219685629079, 0.17182169343405432)
hyperparameters 0.7932574995506428 2.6571499396075287


current loss:-1.431886076927185: 100%|██████████| 391/391 [00:10<00:00, 36.06it/s]
current loss:-1.4360668659210205: 100%|██████████| 391/391 [00:10<00:00, 36.38it/s]
current loss:-1.4395372867584229: 100%|██████████| 391/391 [00:11<00:00, 35.35it/s]
current loss:-1.4427677392959595: 100%|██████████| 391/391 [00:10<00:00, 36.27it/s]
current loss:-1.445885419845581: 100%|██████████| 391/391 [00:10<00:00, 36.40it/s]
current loss:-1.4490022659301758: 100%|██████████| 391/391 [00:10<00:00, 36.05it/s]
current loss:-1.4521690607070923: 100%|██████████| 391/391 [00:10<00:00, 36.20it/s]
current loss:-1.4554526805877686: 100%|██████████| 391/391 [00:10<00:00, 35.68it/s]
current loss:-1.4587643146514893: 100%|██████████| 391/391 [00:10<00:00, 35.57it/s]
current loss:-1.4620280265808105: 100%|██████████| 391/391 [00:10<00:00, 36.05it/s]
current loss:-1.4652376174926758: 100%|██████████| 391/391 [00:10<00:00, 36.17it/s]
current loss:-1.46833336353302: 100%|██████████| 391/391 [00:10<00:00, 36.32it

accuracy (array(0.3902, dtype=float32), 0.7932574995506428, 2.6571499396075287)
hyperparameters 0.9920685676830809 0.6343817086326983


current loss:-2.347463369369507: 100%|██████████| 391/391 [00:10<00:00, 36.94it/s]
current loss:-2.35337495803833: 100%|██████████| 391/391 [00:10<00:00, 36.78it/s]
current loss:-2.3556501865386963: 100%|██████████| 391/391 [00:10<00:00, 36.52it/s]
current loss:-2.3571090698242188: 100%|██████████| 391/391 [00:10<00:00, 36.90it/s]
current loss:-2.3582098484039307: 100%|██████████| 391/391 [00:10<00:00, 37.11it/s]
current loss:-2.3591067790985107: 100%|██████████| 391/391 [00:10<00:00, 37.16it/s]
current loss:-2.359865665435791: 100%|██████████| 391/391 [00:10<00:00, 36.53it/s]
current loss:-2.3605077266693115: 100%|██████████| 391/391 [00:10<00:00, 36.79it/s]
current loss:-2.3610587120056152: 100%|██████████| 391/391 [00:10<00:00, 36.24it/s]
current loss:-2.361546039581299: 100%|██████████| 391/391 [00:10<00:00, 36.17it/s]
current loss:-2.361983060836792: 100%|██████████| 391/391 [00:10<00:00, 36.64it/s]
current loss:-2.3623762130737305: 100%|██████████| 391/391 [00:10<00:00, 36.85it/s

accuracy (array(0.2794, dtype=float32), 0.9920685676830809, 0.6343817086326983)
hyperparameters 0.7436227300342308 2.139444215767114


current loss:-1.1925791501998901: 100%|██████████| 391/391 [00:10<00:00, 36.74it/s]
current loss:-1.2029459476470947: 100%|██████████| 391/391 [00:10<00:00, 36.29it/s]
current loss:-1.2106733322143555: 100%|██████████| 391/391 [00:10<00:00, 36.16it/s]
current loss:-1.2171456813812256: 100%|██████████| 391/391 [00:10<00:00, 36.10it/s]
current loss:-1.2229564189910889: 100%|██████████| 391/391 [00:10<00:00, 36.03it/s]
current loss:-1.2284257411956787: 100%|██████████| 391/391 [00:10<00:00, 36.26it/s]
current loss:-1.2336169481277466: 100%|██████████| 391/391 [00:10<00:00, 36.70it/s]
current loss:-1.2385437488555908: 100%|██████████| 391/391 [00:10<00:00, 36.57it/s]
current loss:-1.2432814836502075: 100%|██████████| 391/391 [00:10<00:00, 36.41it/s]
current loss:-1.247802734375: 100%|██████████| 391/391 [00:10<00:00, 36.38it/s]
current loss:-1.2521226406097412: 100%|██████████| 391/391 [00:10<00:00, 36.12it/s]
current loss:-1.2562007904052734: 100%|██████████| 391/391 [00:10<00:00, 35.92it

accuracy (array(0.4201, dtype=float32), 0.7436227300342308, 2.139444215767114)
hyperparameters 0.5732453120378002 2.5671887944190748


current loss:-0.4065067768096924: 100%|██████████| 391/391 [00:10<00:00, 36.20it/s]
current loss:-0.42194756865501404: 100%|██████████| 391/391 [00:10<00:00, 35.74it/s]
current loss:-0.4368494153022766: 100%|██████████| 391/391 [00:10<00:00, 35.92it/s]
current loss:-0.4513775408267975: 100%|██████████| 391/391 [00:10<00:00, 36.10it/s]
current loss:-0.46572819352149963: 100%|██████████| 391/391 [00:10<00:00, 36.52it/s]
current loss:-0.478985071182251: 100%|██████████| 391/391 [00:10<00:00, 36.57it/s]
current loss:-0.49138593673706055: 100%|██████████| 391/391 [00:10<00:00, 36.31it/s]
current loss:-0.5029922723770142: 100%|██████████| 391/391 [00:10<00:00, 36.07it/s]
current loss:-0.5138979554176331: 100%|██████████| 391/391 [00:10<00:00, 36.15it/s]
current loss:-0.5241125822067261: 100%|██████████| 391/391 [00:10<00:00, 36.20it/s]
current loss:-0.5335608720779419: 100%|██████████| 391/391 [00:10<00:00, 36.35it/s]
current loss:-0.5422219038009644: 100%|██████████| 391/391 [00:10<00:00, 3

accuracy (array(0.43469998, dtype=float32), 0.5732453120378002, 2.5671887944190748)
hyperparameters 0.03930156523190198 4.140869457065747


current loss:2.0230765342712402: 100%|██████████| 391/391 [00:10<00:00, 35.74it/s]
current loss:1.9432204961776733: 100%|██████████| 391/391 [00:10<00:00, 35.69it/s]
current loss:1.8703712224960327: 100%|██████████| 391/391 [00:10<00:00, 36.68it/s]
current loss:1.8064231872558594: 100%|██████████| 391/391 [00:10<00:00, 36.53it/s]
current loss:1.7514501810073853: 100%|██████████| 391/391 [00:10<00:00, 36.17it/s]
current loss:1.7042007446289062: 100%|██████████| 391/391 [00:10<00:00, 36.51it/s]
current loss:1.6626619100570679: 100%|██████████| 391/391 [00:10<00:00, 36.53it/s]
current loss:1.6258199214935303: 100%|██████████| 391/391 [00:10<00:00, 35.98it/s]
current loss:1.592835783958435: 100%|██████████| 391/391 [00:10<00:00, 35.81it/s]
current loss:1.5627743005752563: 100%|██████████| 391/391 [00:10<00:00, 36.11it/s]
current loss:1.5345200300216675: 100%|██████████| 391/391 [00:10<00:00, 35.93it/s]
current loss:1.508158802986145: 100%|██████████| 391/391 [00:10<00:00, 36.79it/s]
curren

accuracy (array(0.55009997, dtype=float32), 0.03930156523190198, 4.140869457065747)
hyperparameters 0.7500580818687124 0.9828856557056701


current loss:-1.2158348560333252: 100%|██████████| 391/391 [00:11<00:00, 35.46it/s]
current loss:-1.227516531944275: 100%|██████████| 391/391 [00:11<00:00, 34.90it/s]
current loss:-1.2353882789611816: 100%|██████████| 391/391 [00:11<00:00, 34.42it/s]
current loss:-1.2410627603530884: 100%|██████████| 391/391 [00:11<00:00, 34.55it/s]
current loss:-1.2454416751861572: 100%|██████████| 391/391 [00:11<00:00, 34.83it/s]
current loss:-1.2489577531814575: 100%|██████████| 391/391 [00:11<00:00, 33.87it/s]
current loss:-1.2519474029541016: 100%|██████████| 391/391 [00:11<00:00, 35.14it/s]
current loss:-1.254575490951538: 100%|██████████| 391/391 [00:11<00:00, 34.75it/s]
current loss:-1.2569167613983154: 100%|██████████| 391/391 [00:11<00:00, 34.68it/s]
current loss:-1.2590267658233643: 100%|██████████| 391/391 [00:11<00:00, 34.50it/s]
current loss:-1.2609597444534302: 100%|██████████| 391/391 [00:11<00:00, 34.42it/s]
current loss:-1.262738585472107: 100%|██████████| 391/391 [00:11<00:00, 34.24i

accuracy (array(0.4114, dtype=float32), 0.7500580818687124, 0.9828856557056701)
hyperparameters 0.10129593583806318 3.5658258086369043


current loss:1.7608306407928467: 100%|██████████| 391/391 [00:10<00:00, 37.09it/s]
current loss:1.6908928155899048: 100%|██████████| 391/391 [00:10<00:00, 36.80it/s]
current loss:1.629123330116272: 100%|██████████| 391/391 [00:10<00:00, 36.86it/s]
current loss:1.5725090503692627: 100%|██████████| 391/391 [00:10<00:00, 36.90it/s]
current loss:1.5219169855117798: 100%|██████████| 391/391 [00:10<00:00, 37.18it/s]
current loss:1.4777776002883911: 100%|██████████| 391/391 [00:10<00:00, 37.82it/s]
current loss:1.4408150911331177: 100%|██████████| 391/391 [00:10<00:00, 38.01it/s]
current loss:1.4104893207550049: 100%|██████████| 391/391 [00:10<00:00, 37.68it/s]
current loss:1.3850791454315186: 100%|██████████| 391/391 [00:10<00:00, 36.97it/s]
current loss:1.3628618717193604: 100%|██████████| 391/391 [00:10<00:00, 36.90it/s]
current loss:1.3429731130599976: 100%|██████████| 391/391 [00:10<00:00, 37.66it/s]
current loss:1.32404363155365: 100%|██████████| 391/391 [00:10<00:00, 37.50it/s]
current

accuracy (array(0.53499997, dtype=float32), 0.10129593583806318, 3.5658258086369043)
hyperparameters 0.5928250619685678 1.7646429195182196


current loss:-0.4925658702850342: 100%|██████████| 391/391 [00:11<00:00, 34.83it/s]
current loss:-0.5046738982200623: 100%|██████████| 391/391 [00:11<00:00, 35.31it/s]
current loss:-0.5157442092895508: 100%|██████████| 391/391 [00:11<00:00, 35.13it/s]
current loss:-0.5261997580528259: 100%|██████████| 391/391 [00:11<00:00, 34.80it/s]
current loss:-0.5361726880073547: 100%|██████████| 391/391 [00:11<00:00, 34.54it/s]
current loss:-0.545760989189148: 100%|██████████| 391/391 [00:11<00:00, 34.76it/s]
current loss:-0.554977536201477: 100%|██████████| 391/391 [00:11<00:00, 34.51it/s]
current loss:-0.5638519525527954: 100%|██████████| 391/391 [00:11<00:00, 34.88it/s]
current loss:-0.5723615288734436: 100%|██████████| 391/391 [00:11<00:00, 34.42it/s]
current loss:-0.5803617238998413: 100%|██████████| 391/391 [00:11<00:00, 34.73it/s]
current loss:-0.587859570980072: 100%|██████████| 391/391 [00:11<00:00, 34.72it/s]
current loss:-0.5949331521987915: 100%|██████████| 391/391 [00:11<00:00, 34.35i

accuracy (array(0.44869998, dtype=float32), 0.5928250619685678, 1.7646429195182196)
hyperparameters 0.04381179951812075 0.9081812206423826


current loss:2.0773870944976807: 100%|██████████| 391/391 [00:11<00:00, 34.44it/s]
current loss:2.029015064239502: 100%|██████████| 391/391 [00:11<00:00, 34.43it/s]
current loss:1.9742813110351562: 100%|██████████| 391/391 [00:11<00:00, 34.27it/s]
current loss:1.9151643514633179: 100%|██████████| 391/391 [00:11<00:00, 33.66it/s]
current loss:1.8550344705581665: 100%|██████████| 391/391 [00:11<00:00, 34.36it/s]
current loss:1.7973073720932007: 100%|██████████| 391/391 [00:11<00:00, 34.58it/s]
current loss:1.7457234859466553: 100%|██████████| 391/391 [00:11<00:00, 34.98it/s]
current loss:1.7001034021377563: 100%|██████████| 391/391 [00:11<00:00, 34.62it/s]
current loss:1.65984308719635: 100%|██████████| 391/391 [00:11<00:00, 34.60it/s]
current loss:1.6241931915283203: 100%|██████████| 391/391 [00:11<00:00, 34.97it/s]
current loss:1.5919965505599976: 100%|██████████| 391/391 [00:11<00:00, 34.97it/s]
current loss:1.5630522966384888: 100%|██████████| 391/391 [00:11<00:00, 34.30it/s]
current

accuracy (array(0.5203, dtype=float32), 0.04381179951812075, 0.9081812206423826)
hyperparameters 0.16132921600620231 5.259476554867887


current loss:1.5156606435775757: 100%|██████████| 391/391 [00:10<00:00, 35.78it/s]
current loss:1.4650264978408813: 100%|██████████| 391/391 [00:11<00:00, 35.35it/s]
current loss:1.413596749305725: 100%|██████████| 391/391 [00:11<00:00, 35.31it/s]
current loss:1.358938455581665: 100%|██████████| 391/391 [00:11<00:00, 35.44it/s]
current loss:1.3045425415039062: 100%|██████████| 391/391 [00:11<00:00, 35.15it/s]
current loss:1.254073143005371: 100%|██████████| 391/391 [00:11<00:00, 35.42it/s]
current loss:1.2108075618743896: 100%|██████████| 391/391 [00:11<00:00, 35.47it/s]
current loss:1.1744104623794556: 100%|██████████| 391/391 [00:10<00:00, 35.83it/s]
current loss:1.143571376800537: 100%|██████████| 391/391 [00:11<00:00, 34.82it/s]
current loss:1.1167261600494385: 100%|██████████| 391/391 [00:10<00:00, 36.22it/s]
current loss:1.0932228565216064: 100%|██████████| 391/391 [00:11<00:00, 35.40it/s]
current loss:1.0722665786743164: 100%|██████████| 391/391 [00:11<00:00, 35.29it/s]
current 

accuracy (array(0.5216, dtype=float32), 0.16132921600620231, 5.259476554867887)
hyperparameters 0.5489897558452683 0.11909060341740987


current loss:-0.1960679441690445: 100%|██████████| 391/391 [00:10<00:00, 36.46it/s]
current loss:-0.2344983071088791: 100%|██████████| 391/391 [00:10<00:00, 35.61it/s]
current loss:-0.2518623471260071: 100%|██████████| 391/391 [00:10<00:00, 35.83it/s]
current loss:-0.261246919631958: 100%|██████████| 391/391 [00:11<00:00, 34.37it/s]
current loss:-0.267011821269989: 100%|██████████| 391/391 [00:11<00:00, 35.36it/s]
current loss:-0.2709406912326813: 100%|██████████| 391/391 [00:11<00:00, 34.84it/s]
current loss:-0.273771733045578: 100%|██████████| 391/391 [00:10<00:00, 35.95it/s]
current loss:-0.27591782808303833: 100%|██████████| 391/391 [00:11<00:00, 35.08it/s]
current loss:-0.27758318185806274: 100%|██████████| 391/391 [00:11<00:00, 34.80it/s]
current loss:-0.27894943952560425: 100%|██████████| 391/391 [00:11<00:00, 35.02it/s]
current loss:-0.2801061272621155: 100%|██████████| 391/391 [00:11<00:00, 34.32it/s]
current loss:-0.2811063230037689: 100%|██████████| 391/391 [00:11<00:00, 33.

accuracy (array(0.33449998, dtype=float32), 0.5489897558452683, 0.11909060341740987)
hyperparameters 0.14405238895431893 0.9176800855999695


current loss:1.5842519998550415: 100%|██████████| 391/391 [00:10<00:00, 35.77it/s]
current loss:1.5370876789093018: 100%|██████████| 391/391 [00:11<00:00, 35.39it/s]
current loss:1.4889137744903564: 100%|██████████| 391/391 [00:10<00:00, 35.85it/s]
current loss:1.4396626949310303: 100%|██████████| 391/391 [00:11<00:00, 34.95it/s]
current loss:1.3945993185043335: 100%|██████████| 391/391 [00:11<00:00, 35.21it/s]
current loss:1.3560881614685059: 100%|██████████| 391/391 [00:10<00:00, 35.56it/s]
current loss:1.3230241537094116: 100%|██████████| 391/391 [00:11<00:00, 35.26it/s]
current loss:1.2934398651123047: 100%|██████████| 391/391 [00:10<00:00, 35.83it/s]
current loss:1.2672721147537231: 100%|██████████| 391/391 [00:11<00:00, 34.97it/s]
current loss:1.2437098026275635: 100%|██████████| 391/391 [00:10<00:00, 35.65it/s]
current loss:1.222394347190857: 100%|██████████| 391/391 [00:11<00:00, 35.52it/s]
current loss:1.2026903629302979: 100%|██████████| 391/391 [00:11<00:00, 34.32it/s]
curre

accuracy (array(0.5493, dtype=float32), 0.14405238895431893, 0.9176800855999695)
hyperparameters 0.2966415123820938 0.6998974378114141


current loss:0.8962749242782593: 100%|██████████| 391/391 [00:10<00:00, 35.69it/s]
current loss:0.8629859089851379: 100%|██████████| 391/391 [00:11<00:00, 35.06it/s]
current loss:0.8370498418807983: 100%|██████████| 391/391 [00:11<00:00, 34.47it/s]
current loss:0.8127754926681519: 100%|██████████| 391/391 [00:11<00:00, 34.90it/s]
current loss:0.7901501655578613: 100%|██████████| 391/391 [00:10<00:00, 35.88it/s]
current loss:0.7697151899337769: 100%|██████████| 391/391 [00:11<00:00, 34.39it/s]
current loss:0.7516400814056396: 100%|██████████| 391/391 [00:11<00:00, 35.39it/s]
current loss:0.7358689904212952: 100%|██████████| 391/391 [00:11<00:00, 34.99it/s]
current loss:0.7221025228500366: 100%|██████████| 391/391 [00:11<00:00, 34.65it/s]
current loss:0.709987223148346: 100%|██████████| 391/391 [00:10<00:00, 35.76it/s]
current loss:0.6989326477050781: 100%|██████████| 391/391 [00:11<00:00, 34.85it/s]
current loss:0.6888738870620728: 100%|██████████| 391/391 [00:11<00:00, 35.41it/s]
curre

accuracy (array(0.4894, dtype=float32), 0.2966415123820938, 0.6998974378114141)
hyperparameters 0.5591809580041471 6.929433578193837


current loss:-0.33389338850975037: 100%|██████████| 391/391 [00:11<00:00, 35.35it/s]
current loss:-0.35187989473342896: 100%|██████████| 391/391 [00:11<00:00, 35.03it/s]
current loss:-0.3681008219718933: 100%|██████████| 391/391 [00:11<00:00, 34.74it/s]
current loss:-0.3832077383995056: 100%|██████████| 391/391 [00:11<00:00, 35.13it/s]
current loss:-0.3974228799343109: 100%|██████████| 391/391 [00:11<00:00, 35.22it/s]
current loss:-0.41100192070007324: 100%|██████████| 391/391 [00:11<00:00, 34.73it/s]
current loss:-0.4242301881313324: 100%|██████████| 391/391 [00:11<00:00, 34.91it/s]
current loss:-0.4373413026332855: 100%|██████████| 391/391 [00:11<00:00, 34.92it/s]
current loss:-0.45022153854370117: 100%|██████████| 391/391 [00:11<00:00, 34.88it/s]
current loss:-0.46278223395347595: 100%|██████████| 391/391 [00:10<00:00, 35.86it/s]
current loss:-0.4751318395137787: 100%|██████████| 391/391 [00:11<00:00, 35.16it/s]
current loss:-0.4871085584163666: 100%|██████████| 391/391 [00:11<00:00

accuracy (array(0.4567, dtype=float32), 0.5591809580041471, 6.929433578193837)
hyperparameters 0.7708544429051166 0.1881820379132395


current loss:-1.2618229389190674: 100%|██████████| 391/391 [00:11<00:00, 35.52it/s]
current loss:-1.285182237625122: 100%|██████████| 391/391 [00:10<00:00, 35.67it/s]
current loss:-1.2971251010894775: 100%|██████████| 391/391 [00:11<00:00, 34.50it/s]
current loss:-1.304699420928955: 100%|██████████| 391/391 [00:10<00:00, 35.68it/s]
current loss:-1.3099257946014404: 100%|██████████| 391/391 [00:10<00:00, 35.64it/s]
current loss:-1.313735008239746: 100%|██████████| 391/391 [00:11<00:00, 34.16it/s]
current loss:-1.3166072368621826: 100%|██████████| 391/391 [00:11<00:00, 35.46it/s]
current loss:-1.3188191652297974: 100%|██████████| 391/391 [00:11<00:00, 34.80it/s]
current loss:-1.3205407857894897: 100%|██████████| 391/391 [00:11<00:00, 33.60it/s]
current loss:-1.3219339847564697: 100%|██████████| 391/391 [00:11<00:00, 33.91it/s]
current loss:-1.3230860233306885: 100%|██████████| 391/391 [00:11<00:00, 34.82it/s]
current loss:-1.324037790298462: 100%|██████████| 391/391 [00:11<00:00, 33.64it

accuracy (array(0.343, dtype=float32), 0.7708544429051166, 0.1881820379132395)
hyperparameters 0.2001505951088114 2.0082100955162296


current loss:1.3350355625152588: 100%|██████████| 391/391 [00:11<00:00, 35.43it/s]
current loss:1.294571042060852: 100%|██████████| 391/391 [00:11<00:00, 35.23it/s]
current loss:1.2522237300872803: 100%|██████████| 391/391 [00:11<00:00, 35.01it/s]
current loss:1.2055028676986694: 100%|██████████| 391/391 [00:10<00:00, 35.78it/s]
current loss:1.1566532850265503: 100%|██████████| 391/391 [00:11<00:00, 34.83it/s]
current loss:1.109885573387146: 100%|██████████| 391/391 [00:11<00:00, 34.40it/s]
current loss:1.0670220851898193: 100%|██████████| 391/391 [00:11<00:00, 35.28it/s]
current loss:1.0294773578643799: 100%|██████████| 391/391 [00:11<00:00, 35.16it/s]
current loss:0.9979041814804077: 100%|██████████| 391/391 [00:11<00:00, 33.75it/s]
current loss:0.9715303182601929: 100%|██████████| 391/391 [00:11<00:00, 34.96it/s]
current loss:0.9490405321121216: 100%|██████████| 391/391 [00:11<00:00, 35.26it/s]
current loss:0.9290874600410461: 100%|██████████| 391/391 [00:11<00:00, 35.20it/s]
curren

accuracy (array(0.5191, dtype=float32), 0.2001505951088114, 2.0082100955162296)
hyperparameters 0.8575734834141558 2.5363177272872863


current loss:-1.7248703241348267: 100%|██████████| 391/391 [00:11<00:00, 34.27it/s]
current loss:-1.728928804397583: 100%|██████████| 391/391 [00:11<00:00, 35.49it/s]
current loss:-1.7320436239242554: 100%|██████████| 391/391 [00:11<00:00, 34.97it/s]
current loss:-1.7346241474151611: 100%|██████████| 391/391 [00:10<00:00, 35.75it/s]
current loss:-1.7370202541351318: 100%|██████████| 391/391 [00:11<00:00, 34.22it/s]
current loss:-1.7391421794891357: 100%|██████████| 391/391 [00:11<00:00, 35.00it/s]
current loss:-1.741145133972168: 100%|██████████| 391/391 [00:11<00:00, 35.41it/s]
current loss:-1.743058204650879: 100%|██████████| 391/391 [00:11<00:00, 33.89it/s]
current loss:-1.7448794841766357: 100%|██████████| 391/391 [00:11<00:00, 35.33it/s]
current loss:-1.746617078781128: 100%|██████████| 391/391 [00:11<00:00, 35.19it/s]
current loss:-1.7483100891113281: 100%|██████████| 391/391 [00:11<00:00, 34.98it/s]
current loss:-1.7499576807022095: 100%|██████████| 391/391 [00:11<00:00, 34.41it

accuracy (array(0.3509, dtype=float32), 0.8575734834141558, 2.5363177272872863)
hyperparameters 0.08789189176403112 0.3030171051381871


current loss:1.878554105758667: 100%|██████████| 391/391 [00:11<00:00, 34.59it/s]
current loss:1.8440635204315186: 100%|██████████| 391/391 [00:11<00:00, 34.40it/s]
current loss:1.8124792575836182: 100%|██████████| 391/391 [00:11<00:00, 34.44it/s]
current loss:1.7793323993682861: 100%|██████████| 391/391 [00:11<00:00, 33.83it/s]
current loss:1.7439508438110352: 100%|██████████| 391/391 [00:11<00:00, 34.94it/s]
current loss:1.7061455249786377: 100%|██████████| 391/391 [00:11<00:00, 35.25it/s]
current loss:1.6658546924591064: 100%|██████████| 391/391 [00:11<00:00, 34.01it/s]
current loss:1.624132752418518: 100%|██████████| 391/391 [00:11<00:00, 33.57it/s]
current loss:1.583040475845337: 100%|██████████| 391/391 [00:11<00:00, 33.94it/s]
current loss:1.5447977781295776: 100%|██████████| 391/391 [00:11<00:00, 35.00it/s]
current loss:1.5099928379058838: 100%|██████████| 391/391 [00:11<00:00, 34.37it/s]
current loss:1.4789745807647705: 100%|██████████| 391/391 [00:11<00:00, 33.93it/s]
current

accuracy (array(0.525, dtype=float32), 0.08789189176403112, 0.3030171051381871)
hyperparameters 0.5389962357205916 1.85106538587003


current loss:-0.24212546646595: 100%|██████████| 391/391 [00:11<00:00, 33.45it/s]
current loss:-0.2638351023197174: 100%|██████████| 391/391 [00:11<00:00, 35.22it/s]
current loss:-0.28170520067214966: 100%|██████████| 391/391 [00:11<00:00, 35.12it/s]
current loss:-0.29786163568496704: 100%|██████████| 391/391 [00:11<00:00, 34.44it/s]
current loss:-0.31253162026405334: 100%|██████████| 391/391 [00:11<00:00, 34.11it/s]
current loss:-0.325778067111969: 100%|██████████| 391/391 [00:11<00:00, 33.84it/s]
current loss:-0.3377898335456848: 100%|██████████| 391/391 [00:11<00:00, 34.31it/s]
current loss:-0.34869876503944397: 100%|██████████| 391/391 [00:11<00:00, 34.86it/s]
current loss:-0.35881564021110535: 100%|██████████| 391/391 [00:11<00:00, 33.76it/s]
current loss:-0.3682805001735687: 100%|██████████| 391/391 [00:11<00:00, 34.58it/s]
current loss:-0.3774101138114929: 100%|██████████| 391/391 [00:11<00:00, 34.83it/s]
current loss:-0.38604119420051575: 100%|██████████| 391/391 [00:11<00:00, 

accuracy (array(0.45409998, dtype=float32), 0.5389962357205916, 1.85106538587003)


In [None]:
# 20 запусков --- с CNN
logits = np.load('./logits_cnn.npy')
for _ in range(20):
    beta = np.random.uniform()
    temp = 10**(np.random.uniform(low=-1, high=1)) # температура от 0.1 до 10  
    print ('hyperparameters', beta, temp)
    student = Cifar_Very_Tiny(10).to(device)
    optim = t.optim.Adam(student.parameters())
    #optim = t.optim.SGD(student.parameters(), lr=0.001)
    crit = nn.CrossEntropyLoss()
    for _ in range(epoch_num):
        tq = tqdm.tqdm(train_loader_no_augumentation)
        losses = []
        for batch_id, (x,y) in enumerate(tq):
            x = x.to(device)
            y = y.to(device)            
            batch_logits = t.Tensor(logits[128*batch_id:128*(batch_id+1)]).to(device)            
            student.zero_grad()
            out = student(x)
            student_loss = crit(student(x), y)
            #distillation_loss = 0.0 # здесь твой код!
            distillation_loss = distill(out, batch_logits, temp)
            loss = (1-beta) * student_loss + beta*distillation_loss
            losses.append(loss.cpu().detach().numpy())
            loss.backward()
            optim.step()
            tq.set_description('current loss:{}'.format(np.mean(losses[-10:])))
    cnn_results.append((accuracy(student), beta, temp))
    print ('accuracy', cnn_results[-1])

In [None]:
# график, разрез beta-accuracy
plt.scatter([0]*len(basic_results), basic_results, c='r', label='Без дистилляции')
#plt.scatter([r[1] for r in cnn_results], [r[0] for r in cnn_results], c='r', marker='x', label='Дистилляция CNN')
plt.scatter([r[1] for r in resnet_results], [r[0] for r in resnet_results], c='r', marker='d', label='Дистилляция ResNet')
plt.legend(loc='best')
plt.xlabel('Beta')
plt.ylabel('Accuracy')
plt.savefig('scatter_beta_acc.png')

In [None]:
# график, разрез accuracy-Temp
plt.scatter([0.0]*len(basic_results), basic_results, c='r', label='Без дистилляции')
#plt.scatter([np.log10(r[2]) for r in cnn_results], [r[0] for r in cnn_results], c='r', marker='x', label='Дистилляция CNN')
plt.scatter([np.log10(r[2]) for r in resnet_results], [r[0] for r in resnet_results], c='r', marker='d', label='Дистилляция ResNet')
plt.legend(loc='best')
plt.xlabel('log(T)')
plt.ylabel('Accuracy')
plt.savefig('scatter_temp_acc.png')

In [None]:
# график, разрез Temp-beta. Цвет --- точность относительно всех результатов

# для резульатов без дистилляции у нас будут повторяющиеся точки. 
# обычно в таком случае в график вносят небольшой шум, чтобы точки можно было визуально различить
eps = 0.02 

all_results = basic_results+[r[0] for r in cnn_results] + [r[0] for r in resnet_results]
max_ = np.max(all_results)
min_ = np.min(all_results)
colors = [cm.seismic((r-min_)/(max_-min_)) for r in basic_results]
plt.scatter(np.random.randn(len(basic_results))*eps, np.random.randn(len(basic_results))*eps, c=colors, label='Без дистилляции')
#colors = [cm.seismic((r[0]-min_)/(max_-min_)) for r in cnn_results]
#plt.scatter([np.log10(r[2]) + np.random.randn()*eps for r in cnn_results], [r[1] + np.random.randn()*eps for r in cnn_results], c=colors, marker='x', label='Дистилляция CNN')
#colors = [cm.seismic((r[0]-min_)/(max_-min_)) for r in resnet_results]
#plt.scatter([np.log10(r[2]) + np.random.randn()*eps for r in resnet_results], [r[1] + np.random.randn()*eps for r in resnet_results], c=colors, marker='d', label='Дистилляция ResNet')
plt.legend(loc='best')
plt.xlabel('log(T)')
plt.ylabel('Beta')
plt.savefig('scatter_temp_beta.png')