In [1]:
import sys
sys.path.append('../code')
from resnet import *
# from funcs import *
from cifar_very_tiny import *
from cifar_tiny import *
from cifar_dataset import *    
import torch as t 
import numpy as np
from numpy import polyfit
from numpy import polyval
import tqdm
import matplotlib.pylab as plt
import matplotlib.cm as cm
import json
# import hyperparams
from importlib import reload
from scipy.interpolate import interp1d
from PIL import Image
%matplotlib inline
plt.rcParams['figure.figsize']=(12,9)
plt.rcParams['font.size']= 20

In [3]:
epoch_num = 25
# epoch_num = 50

run_num = 2 # количество запусков эксперимента

# версия нужна, чтобы различать старые и новые результаты экспериментов. 
# менять нужно каждый раз, когда есть хотя бы незначительные изменения в эксперименте
experiment_version = '3'

validate_every_epoch = 5 

# train_splines_every_epoch = 5 # каждые 5 эпох отслеживать траекторию гиперпараметров
# train_splines_every_epoch = 2
# train_splines_every_epoch = 3
train_splines_every_epoch = 10

# размер мини-эпохи в батчах, за которую у нас производится либо обучение спайлов, либо их использование
mini_epoch_size = 10

start_beta = 0.5
start_temp  = 1.0

In [4]:
train_loader_no_augumentation, valid_loader, test_loader = cifar10_loader(batch_size=128, split_train_val=True,
                                                                             maxsize=128*100, use_aug=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
device = 'cuda' if t.cuda.is_available() else 'cpu'

In [6]:
# This code is a modiciation of https://github.com/khanrc/pt.darts/blob/master/architect.py
import copy
import torch
import math
from torch.optim import adam


class AdamHyperGradCalculator():
    """ Compute gradients of hyperparameters wrt parameters are optimizaed by Adam """
    def __init__(self,  net, parameters_loss_function, hyperparameters_loss_function, optimizer, h, additional_params):
        """
        Args:
            net
            w_momentum: weights momentum
        """
        self.net = net
        self.v_net =  None # lazy
        self.mu_feat, self.log_sigma_feat = additional_params
        self.w = list(self.net.parameters()) + list(self.mu_feat.parameters()) + [self.log_sigma_feat]
        self.w_loss = parameters_loss_function #data,model, h
        self.h_loss = hyperparameters_loss_function #x,y,model
        self.optimizer = optimizer
        self.h = list(h)
        


    def virtual_step(self, trn):
        """
        Compute unrolled weight w' (virtual step)
        Step process:
        1) forward
        2) calc loss
        3) compute gradient (by backprop)
        4) update gradient
        """
        # forward & calc loss
        lr = self.optimizer.param_groups[0]['lr']
        h = self.h 
        optimizer = self.optimizer
        
        loss = self.w_loss(trn, self.net, h) # L_trn(w)

        # compute gradient
        gradients = torch.autograd.grad(loss, list(self.net.parameters()) + list(self.mu_feat.parameters()) + [self.log_sigma_feat], allow_unused=True)
        # do virtual step (update gradient)
        # below operations do not need gradient tracking
        with torch.no_grad():
            # dict key is not the value, but the pointer. So original network weight have to
            # be iterated also.
            for w, vw, g in zip(list(self.net.parameters()) + list(self.mu_feat.parameters()) + [self.log_sigma_feat], list(self.v_net.parameters())  + list(self.mu_feat.parameters()) + [self.log_sigma_feat], gradients):           
                #state = optimizer.state[w]
                
                # Lazy state initialization: not ready yet                    
                #if len(state) == 0:
                #    return 
                """vw_ = w.clone() 
                adam([vw_],
                             [g],
                             [state['exp_avg'].clone()],
                             [state['exp_avg_sq'].clone()],  
                             None,                           
                             [state['step']+1],
                             amsgrad = False, 
                             weight_decay = 0.0,
                             beta1 = optimizer.param_groups[0]['betas'][0],
                             beta2 = optimizer.param_groups[0]['betas'][1],
                             lr = optimizer.param_groups[0]['lr'],
                             eps = optimizer.param_groups[0]['eps'])
                                                                               
                vw.copy_(vw_)
                """
               
                vw.copy_(w - optimizer.defaults['lr'] * g)
            
            
    def calc_gradients(self, trn, val):
        """ Compute unrolled loss and backward its gradients
        Args:
            xi: learning rate for virtual gradient step (same as net lr)
            w_optim: weights optimizer - for virtual step
        """
        lr = self.optimizer.param_groups[0]['lr']
        h = self.h 
        optimizer = self.optimizer
        #for w in self.net.parameters():
        #        state = optimizer.state[w]
        #        if len(state)==0:
        #            print ('not ready')
        #            return
                
        if self.v_net is None:
        
            self.v_net = copy.deepcopy(self.net)        
        # do virtual step (calc w`)
        self.virtual_step(trn)

        # calc unrolled loss
        loss = self.h_loss(val, self.v_net) # L_val(w`)
           
        v_grads = torch.autograd.grad(loss,list(self.v_net.parameters()) + list(self.mu_feat.parameters()) + [self.log_sigma_feat])
        dw = v_grads

        hessian = self.compute_hessian(dw, trn)

        # update final gradient = dalpha - xi*hessian
        with torch.no_grad():
            for alpha,  he in zip(h,  hessian):
                alpha.grad =  -lr*he

    def compute_hessian(self, dw, trn):
        """
        dw = dw` { L_val(w`, alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
        eps = 0.01 / ||dw||
        """
        h = self.h
        norm = torch.cat([w.view(-1) for w in dw]).norm()

        eps = 1e-2 / norm
      

        # w+ = w + eps*dw`
        with torch.no_grad():
            for p, d in zip(list(self.net.parameters()) + list(self.mu_feat.parameters()) + [self.log_sigma_feat], dw):
                p += eps * d
        loss = self.w_loss(trn, self.net, h)
        dalpha_pos = torch.autograd.grad(loss, h) # dalpha { L_trn(w+) }

        # w- = w - eps*dw`
        with torch.no_grad():
            for p, d in zip(self.net.parameters(), dw):
                p -= 2. * eps * d
        loss = self.w_loss(trn, self.net, h)
        dalpha_neg = torch.autograd.grad(loss, h) # dalpha { L_trn(w-) }

        # recover w
        with torch.no_grad():
            for p, d in zip(self.net.parameters(), dw):
                p += eps * d

        hessian = [(p-n) / (2.*eps) for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian

In [7]:
# определяем функцию потерь как замкнутую относительно аргументов функцию
# нужно для подсчета градиентов гиперпараметров по двухуровневой оптимизации
def param_loss_mi(batch,model,h):
    lam1 = h[0]
    x,y,teacher_feat,mu_feat,log_sigma_feat = batch    
    student_feat, student_logits = model.get_features(x, [2,3])
    class_loss = crit(student_logits, y)
    sigma2 = torch.log(1+torch.exp(log_sigma_feat))
    feat_loss = ((mu_feat(teacher_feat) - student_feat)**2).sum(1).mean()/(2*sigma2) + 0.5*torch.log(sigma2)*np.prod(teacher_feat.shape[1:])
    loss = class_loss * (1.0-lam1) + feat_loss * lam1
    return loss

# определяем функцию валидационную функцию потерь как замкнутую относительно аргументов функцию
# нужно для подсчета градиентов гиперпараметров по двухуровневой оптимизации
def hyperparam_loss_mi(batch, model):
    x,y = batch
    student_feat, student_logits = model.get_features(x, [2,3])
    class_loss = crit(student_logits, y)            
    return class_loss

crit = nn.CrossEntropyLoss()

def dist_with_opt(experiment_version, train_loader_no_augumentation, test_loader, validation_loader, validate_every_epoch, lambdas=None, clip_grad=10e-3, seed=42):
    np.random.seed(seed)
    t.manual_seed(seed)
    
    # for lam1 in [1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 0.1, 0.5]:
    lam1 = t.nn.Parameter(t.tensor(np.random.uniform(low=0.0, high=1.0), device=device), requires_grad=True)

    if lambdas is not None: # non-random initialization
        lam1.data *= 0
        lam1.data += lambdas[0]

    student = Cifar_Very_Tiny(10).cpu()
    teacher = Cifar_Tiny(10).cpu() 
    teacher.load_state_dict(torch.load('tiny_cifar10.model?raw=true', map_location=torch.device('cpu')))
    #scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.5)   
    mu_feat = nn.Linear(128, 128).cpu()
    log_sigma_feat = torch.nn.Parameter(torch.zeros(1).cpu())

    #mu_logit = nn.Linear(10, 10).cpu()
    #log_sigma_logit = torch.nn.Parameter(torch.zeros(1).cpu())
    h =[lam1]

    optim = torch.optim.Adam(list(student.parameters()) + list(mu_feat.parameters()) + [log_sigma_feat])    
    optim2 = torch.optim.Adam(h)    
    hyper_grad_calc = AdamHyperGradCalculator(student, param_loss_mi, hyperparam_loss_mi, optim, h, [mu_feat, log_sigma_feat])
    val_load_iter = iter(validation_loader)

    for e in range(25):
        tq = tqdm.tqdm(train_loader_no_augumentation)
        losses = []

        for batch_id, (x,y) in enumerate(tq):
            try:
                 (v_x, v_y) = next(val_load_iter)
            except:                    
                val_load_iter = iter(val_load)
                (v_x, v_y) = next(val_load_iter)

                
            x = x.to(device)
            y = y.to(device)

            v_x = v_x.to(device)
            v_y = v_y.to(device)
            optim2.zero_grad()
            teacher_feat, teacher_logits = teacher.get_features(x, [2,3])
            hyper_grad_calc.calc_gradients((x,y,teacher_logits,mu_feat,log_sigma_feat), (v_x, v_y))                    
            t.nn.utils.clip_grad_value_(h, clip_grad)

            for h_ in h:
                if h_.grad is not None:
                    h_.grad = t.where(t.isnan(h_.grad), t.zeros_like(h_.grad), h_.grad)
            
            optim2.step()
            if lam1 > 1.0:
                lam1.data*=0.0
                lam1.data+=1.0
            if lam1 < 0.0:
                lam1.data*=0.0
                   
            optim.zero_grad()
            student_feat, student_logits = student.get_features(x, [2,3])
            # class_loss = crit(student_logits, y)
            # sigma2 = torch.log(1+torch.exp(log_sigma_feat))
            # feat_loss = ((mu_feat(teacher_feat) - student_feat)**2).sum(1).mean()/(2*sigma2) + 0.5*torch.log(sigma2)*np.prod(teacher_feat.shape[1:])
            #logit_loss =((mu_feat(teacher_feat) - student_feat)**2).sum(1).mean()/(2*sigma2) + 0.5*torch.log(sigma2)*np.prod(teacher_feat.shape[1:])
            #lam1 = 0.5
            loss = param_loss_mi((x,y,teacher_logits,mu_feat,log_sigma_feat), student,h)
            loss.backward()
            optim.step()
            losses.append(loss.detach().cpu().numpy())
            tq.set_description('current loss:{}'.format(np.mean(losses[-10:])))      
        #scheduler.step()
        # student.eval()

        if e==0 or (e+1)%validate_every_epoch == 0:
            test_loss = []
            student.eval()
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)
                student_feat, student_logits = student.get_features(x, [2,3])
                test_loss.append(crit(student_logits, y).detach().cpu().numpy())
            test_loss = float(np.mean(test_loss))
            val_loss = []
            for x,y in validation_loader:
                x = x.to(device)
                y = y.to(device)
                student_feat, student_logits = student.get_features(x, [2,3])
                val_loss.append(crit(student_logits, y).detach().cpu().numpy())
            val_loss = float(np.mean(val_loss))
        
        ac = float(accuracy(student, test_loader))
        student.train()

        # if not hyperopt:
        internal_results.append({'epoch': e, 'test loss':test_loss, 'val loss':val_loss, 'accuracy':ac,
                             'lambda1':float(lam1.cpu().detach().numpy()),
                            })
        # else:
        #     val_acc = float(accuracy(student, validation_loader))
        #     internal_results.append({'epoch': e, 'test loss':test_loss, 'val loss':val_loss, 'accuracy':ac,
        #                          'lambda1':float(lam1.cpu().detach().numpy()),
        #                           'val acc':val_acc})
        print (internal_results[-1])

    # if not hyperopt: # outer function optimization
    with open('../logs/acc_mi_'+experiment_version+'.txt','a') as out:
        out.write(json.dumps({'results':internal_results, 'version': exp_ver})+'\n')
    # else:
    #     # inner function for hyperopt optimization
    #     return max([res['val acc'] for res in internal_results])
        
        # with open('../logs/acc_mi_'+experiment_version+'.txt','a') as out:
        #     out.write('{}: {}: {}\n'.format(lam1, e, ac))

In [8]:
dist_with_opt(experiment_version, train_loader_no_augumentation, test_loader, valid_loader, validate_every_epoch, lambdas=[1e-4])

  0%|          | 0/90 [00:00<?, ?it/s]


TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'

In [None]:
def dist_with_no_opt(experiment_version, train_loader_no_augumentation, test_loader, validation_loader, validate_every_epoch, lambdas=None, file=True, no_tqdm=False, clip_grad=10e-3, seed=42):
    np.random.seed(seed)
    t.manual_seed(seed)

    lam1 = t.nn.Parameter(t.tensor(np.random.uniform(low=0.0, high=1.0), device=device), requires_grad=True)

    if lambdas is not None: # non-random initialization
        lam1.data *= 0
        lam1.data += lambdas[0]
    
    student = Cifar_Very_Tiny(10).cpu()
    teacher = Cifar_Tiny(10).cpu() 
    teacher.load_state_dict(torch.load('tiny_cifar10.model?raw=true', map_location=torch.device('cpu')))
    #scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.5)   
    mu_feat = nn.Linear(128, 64).cpu()
    log_sigma_feat = torch.nn.Parameter(torch.zeros(1).cpu())

    #mu_logit = nn.Linear(10, 10).cpu()
    #log_sigma_logit = torch.nn.Parameter(torch.zeros(1).cpu())


    optim = torch.optim.Adam(list(student.parameters()) + list(mu_feat.parameters()) + [log_sigma_feat])    
    val_load_iter = iter(val_load)

    for e in range(25):
        tq = tqdm.tqdm(train_loader_no_augumentation)
        if no_tqdm:
            tq = train_loader_no_augumentation
        losses = []

        for batch_id, (x,y) in enumerate(tq):
            try:
                 (v_x, v_y) = next(val_load_iter)
            except:                    
                val_load_iter = iter(val_load)
                (v_x, v_y) = next(val_load_iter)
                
            x = x.to(device)
            y = y.to(device)
            optim.zero_grad()
            student_feat, student_logits = student.get_features(x, [3,4])
            # class_loss = crit(student_logits, y)
            # sigma2 = torch.log(1+torch.exp(log_sigma_feat))
            # feat_loss = ((mu_feat(teacher_feat) - student_feat)**2).sum(1).mean()/(2*sigma2) + 0.5*torch.log(sigma2)*np.prod(teacher_feat.shape[1:])
            #logit_loss =((mu_feat(teacher_feat) - student_feat)**2).sum(1).mean()/(2*sigma2) + 0.5*torch.log(sigma2)*np.prod(teacher_feat.shape[1:])
            #lam1 = 0.5
            loss = param_loss_mi((x,y,teacher_logits,mu_feat,log_sigma_feat), student,lam1)

            loss.backward()
            optim.step()
            losses.append(loss.detach().cpu().numpy())
            if not no_tqdm:
                tq.set_description('current loss:{}'.format(np.mean(losses[-10:])))
        #scheduler.step()
        
        if e==0 or (e+1)%validate_every_epoch == 0:
            test_loss = []
            student.eval()
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)
                student_feat, student_logits = student.get_features(x, [3,4])
                test_loss.append(crit(student_logits, y).detach().cpu().numpy())
            test_loss = float(np.mean(test_loss))
            val_loss = []
            for x,y in validation_loader:
                x = x.to(device)
                y = y.to(device)
                student_feat, student_logits = student.get_features(x, [3,4])
                val_loss.append(crit(student_logits, y).detach().cpu().numpy())
            val_loss = float(np.mean(val_loss))
        
        ac = float(accuracy(student, test_loader))
        
        if file:
            internal_results.append({'epoch': e, 'test loss':test_loss, 'val loss':val_loss, 'accuracy':ac,
                                 'lambda1':float(lam1.cpu().detach().numpy()),
                                })
        else:
            val_acc = float(accuracy(student, validation_loader))
            internal_results.append({'epoch': e, 'test loss':test_loss, 'val loss':val_loss, 'accuracy':ac,
                                 'lambda1':float(lam1.cpu().detach().numpy()),
                                  'val acc':val_acc})
        
        print (internal_results[-1])

    if file: # outer function optimization
        with open('../logs/acc_mi_'+experiment_version+'.txt','a') as out:
            out.write(json.dumps({'results':internal_results, 'version': experiment_version})+'\n')
    else:
        # inner function for hyperopt optimization
        return max([res['val acc'] for res in internal_results])
        
        # with open('../logs/acc_mi_'+experiment_version+'.txt','a') as out:
        #     out.write('{}: {}: {}\n'.format(lam1, e, ac))


def dist_hyperopt(experiment_version, run_num, tr_load, t_load, val_load, validate_every_epoch, trial_num):
    np.random.seed(42)
    t.manual_seed(42)

    for _ in range(run_num):
        cost_function = lambda lambdas: -dist_with_no_opt(experiment_version, train_loader_no_augumentation, test_loader, validation_loader, validate_every_epoch, lambdas = best_lambdas['lambda1'], file=False, no_tqdm=True) # validation accuracy * (-1) -> min
       
        best_lambdas = fmin(fn=cost_function,                             
        #space= [ hp.uniform('lambda1', 0.0, 1.0), hp.uniform('lambda2', 0.0, 1.0), hp.uniform('temp', 0.1, 10.0)],
        space= [ hp.uniform('lambda1', 0.0, 1.0)],  
        algo=tpe.suggest,
        max_evals=trial_num)
        #cifar_with_validation_set(exp_ver, 1, epoch_num, filename, tr_s_epoch, m_e, tr_load, t_load, val_load, validate_every_epoch, lambdas = [best_lambdas['lambda1'], best_lambdas['lambda2'], best_lambdas['temp']],  mode='no-opt')
        dist_with_no_opt(experiment_version, train_loader_no_augumentation, test_loader, validation_loader, validate_every_epoch, lambdas = best_lambdas['lambda1'])