In [1]:
import sys
sys.path.append('../code')
from resnet import *
from fashionmnist_funcs import *
from fashionmnist_net import *
from fashionmnist_dataset import *
import torch as t
import numpy as np
from numpy import polyfit
from numpy import polyval
import tqdm
import matplotlib.pylab as plt
import matplotlib.cm as cm
import json
import hyperparams
from importlib import reload
from scipy.interpolate import interp1d
from PIL import Image
%matplotlib inline
plt.rcParams['figure.figsize']=(12,9)
plt.rcParams['font.size']= 20


# Data loading

In [2]:
# Build the data loaders used by every experiment below. The first returned
# value is bound to `_` and is not used by the cells below.
# (An earlier setting, maxsize=8848*2, is noted here for reference.)
(
    _,
    test_loader,
    train_loader_no_augumentation,
    valid_loader,
) = fashionmnist_loader(
    batch_size=128,
    split_train_val=True,
    maxsize=10112*2,
)

# Experiments

In [3]:
# --- Run schedule -----------------------------------------------------------
epoch_num = 100  # alternative tried earlier: 50
run_num = 5  # number of experiment runs
validate_every_epoch = 5

# --- Bookkeeping ------------------------------------------------------------
# The version is needed to tell old and new experiment results apart.
# Change it every time the experiment changes, even slightly.
experiment_version = '1'

# --- Spline settings --------------------------------------------------------
# Alternatives tried earlier: 5 (track the hyperparameter trajectory every
# 5 epochs), 3 and 10.
train_splines_every_epoch = 2

# Size of a mini-epoch in batches, during which the splines are either
# trained or used.
mini_epoch_size = 10

# --- Initial values passed to fashionmnist_base in the cells below ----------
start_beta, start_temp = 0.9914, 6.5

### without distillation

In [4]:
# Run without distillation; 'nodistil' is the file-name tag handed to
# fashionmnist_base for this experiment.
filename = 'nodistil'
fashionmnist_base(
    experiment_version, run_num, epoch_num, start_beta, start_temp, filename,
    tr_load=train_loader_no_augumentation,
    t_load=test_loader,
    validate_every_epoch=validate_every_epoch,
)

current loss:0.653059184551239: 100%|██████████| 79/79 [00:03<00:00, 22.78it/s] 
current loss:0.6535857915878296:   4%|▍         | 3/79 [00:00<00:03, 21.20it/s]

{'epoch': 0, 'test loss': 0.6432048678398132, 'accuracy': 0.7724999785423279}


current loss:0.5091562867164612: 100%|██████████| 79/79 [00:03<00:00, 23.14it/s] 
current loss:0.4601860046386719:  19%|█▉        | 15/79 [00:00<00:03, 20.76it/s]


KeyboardInterrupt: 

### with distillation and lambda2=0

In [5]:
# Run with CNN distillation, mode 'distil-1' (the "lambda2=0" section of this
# notebook).
filename = 'distil-1'
fashionmnist_base(
    experiment_version, run_num, epoch_num, start_beta, start_temp, filename,
    tr_load=train_loader_no_augumentation,
    t_load=test_loader,
    validate_every_epoch=validate_every_epoch,
    mode='distil-1',
)

  f = F.log_softmax(batch_logits/temp)
current loss:0.6592805981636047: 100%|██████████| 79/79 [00:03<00:00, 23.58it/s]
current loss:0.6613246202468872:   4%|▍         | 3/79 [00:00<00:03, 20.97it/s]

{'epoch': 0, 'test loss': 0.6718989610671997, 'accuracy': 0.7657999992370605}


current loss:0.5368021726608276:  97%|█████████▋| 77/79 [00:03<00:00, 22.24it/s]


KeyboardInterrupt: 

### with distillation and lambda1=0

In [6]:
# Distillation run in mode 'distil-2' (the "lambda1=0" section of this
# notebook).
filename = 'distil-2'
fashionmnist_base(
    experiment_version, run_num, epoch_num, start_beta, start_temp, filename,
    tr_load=train_loader_no_augumentation,
    t_load=test_loader,
    validate_every_epoch=validate_every_epoch,
    mode='distil-2',
)

current loss:9.001847267150879: 100%|██████████| 79/79 [00:03<00:00, 22.86it/s]
current loss:9.066611289978027:   4%|▍         | 3/79 [00:00<00:03, 20.99it/s]

{'epoch': 0, 'test loss': 2.9407753944396973, 'accuracy': 0.06599999964237213}


current loss:8.972392082214355: 100%|██████████| 79/79 [00:03<00:00, 22.61it/s]
current loss:8.950529098510742: 100%|██████████| 79/79 [00:03<00:00, 22.54it/s]
current loss:8.837343215942383:  10%|█         | 8/79 [00:00<00:03, 19.31it/s]


KeyboardInterrupt: 

### with random metaparameters

In [7]:
# Run with random hyperparameter values (mode 'random').
filename = 'random'
fashionmnist_base(
    experiment_version, run_num, epoch_num, start_beta, start_temp, filename,
    tr_load=train_loader_no_augumentation,
    t_load=test_loader,
    validate_every_epoch=validate_every_epoch,
    mode='random',
)

current loss:3.058194398880005: 100%|██████████| 79/79 [00:03<00:00, 23.51it/s] 
current loss:3.1982827186584473:   3%|▎         | 2/79 [00:00<00:05, 13.80it/s]

{'epoch': 0, 'test loss': 1.215680718421936, 'accuracy': 0.6782000064849854, 'temp': 1.9649925231933594, 'lambda1': 0.6926254630088806, 'lambda2': 0.27192527055740356}


current loss:3.0269060134887695: 100%|██████████| 79/79 [00:03<00:00, 22.29it/s]
current loss:2.918316125869751:  18%|█▊        | 14/79 [00:00<00:02, 22.38it/s] 


KeyboardInterrupt: 

### with metaparameter optimization

In [8]:
# Run with CNN distillation and hyperparameter optimization, 2-lambda
# (mode 'opt'); this variant also receives the validation loader.
filename = 'opt'
fashionmnist_with_validation_set(
    experiment_version, run_num, epoch_num, filename,
    tr_s_epoch=train_splines_every_epoch,
    m_e=mini_epoch_size,
    tr_load=train_loader_no_augumentation,
    t_load=test_loader,
    val_load=valid_loader,
    validate_every_epoch=validate_every_epoch,
    mode='opt',
)

current loss:101.75953674316406: : 79it [00:10,  7.66it/s]


2.2814853191375732


current loss:111.34952545166016: : 1it [00:00,  6.90it/s]

{'epoch': 0, 'test loss': 2.285008430480957, 'accuracy': 0.10000000149011612, 'temp': 0.1, 'lambda1': 1.0, 'lambda2': 1.0}


current loss:101.60304260253906: : 79it [00:10,  7.73it/s]
current loss:99.77582550048828: : 71it [00:09,  7.71it/s] 


KeyboardInterrupt: 

### with linear models

In [9]:
# Run with CNN distillation and hyperparameter optimization, 2-lambda,
# with linear models (mode 'splines').
filename = 'splines'
fashionmnist_with_validation_set(
    experiment_version, run_num, epoch_num, filename,
    tr_s_epoch=train_splines_every_epoch,
    m_e=mini_epoch_size,
    tr_load=train_loader_no_augumentation,
    t_load=test_loader,
    val_load=valid_loader,
    validate_every_epoch=validate_every_epoch,
    mode='splines',
)

current loss:101.78016662597656: : 79it [00:07, 10.43it/s]


2.302297830581665


current loss:111.41935729980469: : 1it [00:00,  6.51it/s]

{'epoch': 0, 'test loss': 2.2998642921447754, 'accuracy': 0.10000000149011612, 'temp': 0.1, 'lambda1': 1.0, 'lambda2': 1.0}


current loss:101.51590728759766: : 79it [00:07, 10.31it/s]
current loss:97.34861755371094: : 35it [00:03,  9.64it/s] 


KeyboardInterrupt: 

### with hyperopt

In [10]:
# Run through fashionmnist_with_hyperopt with trial_num=5 (the "with hyperopt"
# section of this notebook); it receives the same loaders and spline settings
# as the validation-set runs.
filename = 'hyperopt'

fashionmnist_with_hyperopt(
    experiment_version, run_num, epoch_num, filename,
    tr_s_epoch=train_splines_every_epoch,
    m_e=mini_epoch_size,
    tr_load=train_loader_no_augumentation,
    t_load=test_loader,
    val_load=valid_loader,
    validate_every_epoch=validate_every_epoch,
    trial_num=5,
)

current loss:1.7769527435302734: : 79it [00:05, 15.46it/s]


0.6659116148948669


current loss:1.7355767488479614: : 2it [00:00, 14.84it/s]

{'epoch': 0, 'test loss': 0.6651864647865295, 'accuracy': 0.7678999900817871, 'temp': 1.0, 'lambda1': 0.10000000149011612, 'lambda2': 1.0, 'val acc': 0.7697784900665283}


current loss:1.645639419555664: : 79it [00:04, 16.05it/s] 
current loss:1.5809215307235718: : 79it [00:04, 16.03it/s]
current loss:1.5371875762939453: : 79it [00:04, 15.93it/s]
current loss:1.460813283920288: : 16it [00:01, 14.36it/s] 


KeyboardInterrupt: 

# Results

In [None]:
# Load the experiment logs (.jsonl files) with open_data_json, one variable per
# experiment.
# NOTE(review): the paths hard-code the "exp23" tag while experiment_version is
# set to '1' earlier in this notebook — confirm these are the intended logs.
data_b = open_data_json("../log/fashionmnist_exp23_nodistil.jsonl")
# data_d1 / data_d2 must be loaded here: the test-loss and accuracy plots below
# reference them and would otherwise raise NameError on a fresh kernel.
data_d1 = open_data_json("../log/fashionmnist_exp23_distil-1.jsonl")
data_d2 = open_data_json("../log/fashionmnist_exp23_distil-2.jsonl")
data_r = open_data_json("../log/fashionmnist_exp23_random.jsonl")
data_o = open_data_json("../log/fashionmnist_exp23_opt.jsonl")
data_s = open_data_json("../log/fashionmnist_exp23_splines.jsonl")
data_h = open_data_json("../log/fashionmnist_exp23_hyperopt.jsonl")

In [None]:
from matplotlib import pylab as plt

# Styling for the result figures: serif font, thicker lines, larger markers,
# and larger text for ticks, legend, titles and axis labels.
plt.rcParams.update({
    'font.family': 'DejaVu Serif',
    'lines.linewidth': 2,
    'lines.markersize': 12,
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'legend.fontsize': 24,
    'axes.titlesize': 36,
    'axes.labelsize': 24,
})

In [None]:
# Test loss for each experiment (Russian-language figure).
# Legend labels follow the data source: data_r is loaded from the "random" log
# and data_o from the "opt" log, matching the labels used in the accuracy plot.
plot_data_params(data_b, 'test loss', 'без дистилляции', 'tab:blue', '')
plot_data_params(data_d1, 'test loss', 'дистилляция с lambda2=0', 'tab:orange', '')
plot_data_params(data_d2, 'test loss', 'дистилляция с lambda1=0', 'black', '')
plot_data_params(data_r, 'test loss', 'случайные гиперпараметры', 'tab:red', '')
plot_data_params(data_o, 'test loss', 'оптимизация гиперпараметров', 'tab:green', '')
plot_data_params(data_h, 'test loss', 'hyperopt', 'tab:brown', '')

plt.xlabel('Количество эпох')
plt.ylabel('Потеря на тестовой выборке')

plt.legend()
plt.savefig('../figs/fashionmnist_loss_23.pdf')

In [None]:
# Accuracy for each experiment (Russian-language figure).
# Each row is (log, legend label, colour); the curves are drawn in this order.
for _log, _label, _color in (
    (data_b, 'без дистилляции', 'tab:red'),
    (data_d1, 'lambda2=0', 'tab:blue'),
    (data_d2, 'lambda1=0', 'tab:brown'),
    (data_o, 'оптимизация гиперпараметров', 'tab:green'),
    (data_r, 'случайные гиперпараметры', 'tab:orange'),
    (data_s, 'прогнозирование гиперпараметров', 'black'),
    (data_h, 'hyperopt', 'navy'),
):
    plot_data_params(_log, 'accuracy', _label, _color, '')

plt.xlabel('Количество эпох')
plt.ylabel('Точность классификации')
plt.legend()
plt.savefig('../figs/fashionmnist_acc_23.pdf')

In [None]:
# Accuracy for the main experiments (English-language figure).
# The last argument differs per curve ('+', 'x', '.', '4') — presumably a
# marker style; verify against plot_data_params.
plot_data_params(data_b, 'accuracy', 'without distillation', 'black', '+')
# Legend label spelling fixed (was 'metaparamter optimization').
plot_data_params(data_o, 'accuracy', 'metaparameter optimization', 'green', 'x')
plot_data_params(data_r, 'accuracy', 'random metaparameters', 'blue', '.')
plot_data_params(data_s, 'accuracy', 'metaparameter prediction', 'red', '4')

plt.xlabel('Epoch number')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig('../figs/fashionmnist_acc_'+experiment_version+'_eng.pdf')

In [None]:
# lambda1 values: prediction vs. optimization (English-language figure).
# Each row is (log, legend label, colour, last plot_data_params argument —
# presumably a marker style; verify against plot_data_params).
for _log, _label, _color, _marker in (
    (data_s, 'metaparameter prediction', 'red', '+'),
    (data_o, 'metaparameter optimization', 'green', 'x'),
):
    plot_data_params(_log, 'lambda1', _label, _color, _marker)

plt.xlabel('Iteration number')
plt.ylabel(r'$\lambda_1$')
plt.legend()
plt.savefig('../figs/fashionmnist_lambda1_iter' + experiment_version + '_eng.pdf')
plt.show()

In [None]:
# lambda1 values: prediction vs. optimization (Russian-language colour figure).
# Each row is (log, legend label, colour); the curves are drawn in this order.
for _log, _label, _color in (
    (data_s, 'прогнозирование метапараметров', 'red'),
    (data_o, 'оптимизация метапараметров', 'green'),
):
    plot_data_params(_log, 'lambda1', _label, _color, '')

plt.xlabel('Число итераций')
plt.ylabel(r'$\lambda_1$')

plt.legend()
plt.savefig('../figs/fashionmnist_lambda1_iter' + experiment_version + '_color.pdf')
plt.show()

In [None]:
# lambda2 values: prediction vs. optimization (English-language figure).
# Each row is (log, legend label, colour, last plot_data_params argument —
# presumably a marker style; verify against plot_data_params).
for _log, _label, _color, _marker in (
    (data_s, 'metaparameter prediction', 'red', '+'),
    (data_o, 'metaparameter optimization', 'green', 'x'),
):
    plot_data_params(_log, 'lambda2', _label, _color, _marker)

plt.xlabel('Iteration number')
plt.ylabel(r'$\lambda_2$')
plt.legend()
plt.savefig('../figs/fashionmnist_lambda2_iter' + experiment_version + '_eng.pdf')
plt.show()

In [None]:
# lambda2 values: prediction vs. optimization (Russian-language colour figure).
# Each row is (log, legend label, colour); the curves are drawn in this order.
for _log, _label, _color in (
    (data_s, 'прогнозирование метапараметров', 'red'),
    (data_o, 'оптимизация метапараметров', 'green'),
):
    plot_data_params(_log, 'lambda2', _label, _color, '')

plt.xlabel('Число итераций')
plt.ylabel(r'$\lambda_2$')
plt.legend()
plt.savefig('../figs/fashionmnist_lambda2_iter' + experiment_version + '_color.pdf')
plt.show()

In [None]:
# Temperature values: prediction vs. optimization (English-language figure).
# Each row is (log, legend label, colour, last plot_data_params argument —
# presumably a marker style; verify against plot_data_params).
for _log, _label, _color, _marker in (
    (data_s, 'metaparameter prediction', 'red', '+'),
    (data_o, 'metaparameter optimization', 'green', 'x'),
):
    plot_data_params(_log, 'temp', _label, _color, _marker)

plt.xlabel('Iteration number')
plt.ylabel(r'$T$')
plt.legend()
plt.savefig('../figs/fashionmnist_temp_iter' + experiment_version + '_eng.pdf')
plt.show()

In [None]:
# Temperature values: prediction vs. optimization (Russian-language colour
# figure). Each row is (log, legend label, colour), drawn in this order.
for _log, _label, _color in (
    (data_s, 'прогнозирование метапараметров', 'red'),
    (data_o, 'оптимизация метапараметров', 'green'),
):
    plot_data_params(_log, 'temp', _label, _color, '')

plt.xlabel('Число итераций')
plt.ylabel(r'$T$')
plt.legend()
plt.savefig('../figs/fashionmnist_temp_iter' + experiment_version + '_color.pdf')
plt.show()