### В этом файле проверяется работа DARTS на простом классификаторе cifar-10
#### В один из слоёв классивикатора вставляется на выбор один из двух слоёв: `good_conv` - выбрав его, модель достигнет 86% accuracy, и `bad_conv` - потолок 75% accuracy
**Результаты можно посмотреть в tensorboard** (веса, даваемые  `good_conv` ожидаются больше весов `bad_conv`)

In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
from torchsummary import summary

from dataloaders import get_loaders, get_test_loader
from utils import get_logger, AverageMeter, accuracy, check_tensor_in_list
from fbnet_training_functions_supernet import TrainerSupernet
from architect import Architect

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

# Configs

In [4]:
CONFIG_ARCH = {
    'dataloading' : {
        'batch_size' : 1000,
        'path_to_save_data' : './cifar10_data/',
        'w_share_in_train' : 0.8,
    },
    'logging' : {
        'path_to_log_file' : './logs_darts/logger/',
        'path_to_tensorboard_logs' : './logs_darts/tf',
    },
    'optimizer' : {
        # SGD parameters for w
        'w_lr' : 0.1,
        'w_momentum' : 0.9,
        'w_weight_decay' : 1e-4,
        # Adam parameters for thetas
        'thetas_lr' : 0.01,
        'betas' : (0.5, 0.999),
        'thetas_weight_decay' : 1e-3,
    },
    'train_settings' : {
        'eta_min' : 0.005,
        'cnt_epochs' : 200, # 90
        'print_freq' : 50, # show logging information
        'path_to_save_model' : './logs_darts/best_model.pth',
        ## CosineAnnealingLR settings
        'eta_min' : 0.005,
    }
}

# Logger

In [5]:
manual_seed = 1
#np.random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)
torch.backends.cudnn.benchmark = True

logger = get_logger(CONFIG_ARCH['logging']['path_to_log_file'])
writer = SummaryWriter(log_dir=CONFIG_ARCH['logging']['path_to_tensorboard_logs'])

# Dataloader

In [6]:
train_w_loader, train_thetas_loader = get_loaders(CONFIG_ARCH['dataloading']['w_share_in_train'],
                                                  CONFIG_ARCH['dataloading']['batch_size'],
                                                  CONFIG_ARCH['dataloading']['path_to_save_data'],
                                                  logger)
# test_loader - валидация
test_loader = get_test_loader(CONFIG_ARCH['dataloading']['batch_size'],
                              CONFIG_ARCH['dataloading']['path_to_save_data'])

Files already downloaded and verified
Files already downloaded and verified


# Model

#### OPS - варианты для выбора
#### единственное требование: размерность входа и размерность выхода у них одинакова
#### (без этого не заработает)

In [7]:
PRIMITIVES = ['good_conv', 'bad_conv']

OPS = {
    'good_conv' : lambda C_in, C_out : nn.Sequential(
        nn.Conv2d(C_in, C_out, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    ),
    # спецаиально подобраный слой, сильно даунгрейдящий скор модели
    'bad_conv' : lambda C_in, C_out: nn.Sequential(
            nn.Conv2d(C_in, C_in, (2,1), padding=(1,0)),
            nn.Conv2d(C_in, C_in, (1,2), padding=(0,1)),
            nn.Conv2d(C_in, C_in, (2,1)),
            nn.Conv2d(C_in, C_in, (1,2)),
            nn.Conv2d(C_in, C_out, (1,1)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
    )
}

class MixedOp(nn.Module):
    def __init__(self, C_in, C_out):
        super(MixedOp, self).__init__()
        ops_names = PRIMITIVES
        self.ops = nn.ModuleList([OPS[op_name](C_in, C_out) for op_name in ops_names])
        
        alphas_good = Variable(torch.Tensor([1.0 / len(ops_names)]).cuda(), requires_grad=True)
        alphas_bad  = Variable(torch.Tensor([1.0 / len(ops_names)]).cuda(), requires_grad=True)
        # Кривой DARTS-STYLE (но работает правильно). Для 1 GPU можно не вызывать register_buffer
        # Можно сделать Parameter с requared grad = false
        self.register_buffer("alphas_good", alphas_good)
        self.register_buffer("alphas_bad" , alphas_bad )
        
        self._arch_parameters = [self.alphas_good, self.alphas_bad]
        
    def forward(self, x):
        # m[0] - Кривой DARTS-STYLE (но работает правильно)
        # По-хорошему, надо переписать на нормальный список
        return sum(m[0] * op(x) for m, op in zip(self._arch_parameters, self.ops))

#### Модель с вставленной MixedOp на место одного из слоёв

In [8]:
class Mixed_Model(nn.Module):
    def __init__(self, cnt_classes=10):
        super(Mixed_Model, self).__init__()
        self.start = nn.Sequential(
                    # block 1
                    nn.Conv2d(3, 32, (3,3), padding=(3 - 1) // 2),
                    nn.ReLU(),

                    nn.Conv2d(32, 32, (3,3), padding=(3 - 1) // 2),
                    nn.ReLU(),
                    nn.MaxPool2d(2, 2),

                    # block 2
                    nn.Conv2d(32, 64, (3,3), padding=(3 - 1) // 2),
                    nn.ReLU(),

                    nn.Conv2d(64, 64, (3,3), padding=(3 - 1) // 2),
                    nn.ReLU(),
                    nn.MaxPool2d(2, 2),

                    # block 3
                    nn.Conv2d(64, 128, (3,3), padding=(3 - 1) // 2),
                    nn.ReLU()
        )
        
        # подбираем этот слой
        self.search = MixedOp(128, 128)
        
        self.end = nn.Sequential(Flatten(),
                                 nn.Linear(2048, cnt_classes),
                                 nn.Softmax(dim=1)
        )
        
        self._criterion = nn.CrossEntropyLoss()
    
    def forward(self, x):
        y = self.start(x)
        y = self.search(y)
        y = self.end(y)
        return y
    
    def _loss(self, xs, ys): # as in the Network  class
        logits = self(xs)
        return self._criterion(logits, ys)
    
    def arch_parameters(self):
        return self.search._arch_parameters

In [9]:
cnt_classes = 10
model = Mixed_Model(cnt_classes)
model = model.cuda()

summary(model, (3, 32, 32))

# Loss, Optimizer and Scheduler

In [10]:
snapshot = 10
criterion = nn.CrossEntropyLoss().cuda()

alpha_optim = torch.optim.Adam(params=model.arch_parameters(),
                                       lr=CONFIG_ARCH['optimizer']['thetas_lr'],
                                       betas=CONFIG_ARCH['optimizer']['betas'],
                                       weight_decay=CONFIG_ARCH['optimizer']['thetas_weight_decay'])
optimizer = torch.optim.SGD(params=filter(lambda p: p.requires_grad, model.parameters()),
                                  lr=CONFIG_ARCH['optimizer']['w_lr'], 
                                  momentum=CONFIG_ARCH['optimizer']['w_momentum'],
                                  weight_decay=CONFIG_ARCH['optimizer']['w_weight_decay'])
last_epoch = -1
alpha_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(alpha_optim,
                                                         T_max=CONFIG_ARCH['train_settings']['cnt_epochs'],
                                                         eta_min=CONFIG_ARCH['train_settings']['eta_min'],
                                                         last_epoch=last_epoch)
w_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=CONFIG_ARCH['train_settings']['cnt_epochs'],
                                                         eta_min=CONFIG_ARCH['train_settings']['eta_min'],
                                                         last_epoch=last_epoch)

# Training

In [11]:
am_tr_top1, am_tr_losses   = AverageMeter(), AverageMeter()
am_val_top1, am_val_losses = AverageMeter(), AverageMeter()

architect = Architect(model)
for epoch in range(CONFIG_ARCH['train_settings']['cnt_epochs']):
    
    model = model.train()
    for (xs, ys), (valxs, valys) in zip(train_w_loader, train_thetas_loader) :
        xs, ys       = xs.cuda(non_blocking=True), ys.cuda(non_blocking=True)
        valxs, valys = valxs.cuda(non_blocking=True), valys.cuda(non_blocking=True)
        
        # первые 10 эпох учим только веса модели
        if epoch > 10:
            alpha_optim.zero_grad()
            architect.unrolled_backward(xs, ys, valxs, valys, alpha_optim.param_groups[0]['lr'], alpha_optim)
            alpha_optim.step()
            
        optimizer.zero_grad()
        outs = model(xs)
        loss = criterion(outs, ys)
        prec1 = accuracy(outs, ys, topk=(1,))
        
        am_tr_losses.update(loss.item(), xs.shape[0])
        am_tr_top1.update(prec1[0].item(), xs.shape[0])
        
        loss.backward() 
        optimizer.step()
    
    # Валидация
    model = model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.cuda(), y.cuda()
            outs = model(X)
            
            loss = criterion(outs, y)
            am_val_losses.update(loss.item(), X.shape[0])
            
            prec1 = accuracy(outs, y, topk=(1,))
            am_val_top1.update(prec1[0].item(), X.shape[0])
    
    logger.info("Epoch " + str(epoch) + " :\tLoss (" + str(round(am_tr_losses.get_avg(), 3)) \
                + " ; " + str(round(am_val_losses.get_avg(), 3)) \
                + ") ;\tAccuracy (" + str(round(am_tr_top1.get_avg(), 3)) \
                + " ; " + str(round(am_val_top1.get_avg(), 3)) + ")")
    
    writer.add_scalars('loss', {'train'      : am_tr_losses.get_avg(),
                                'validation' : am_val_losses.get_avg()}, epoch)
    writer.add_scalars('accuracy', {'train'      : am_tr_top1.get_avg(),
                                    'validation' : am_val_top1.get_avg()}, epoch)
    
    for avg in [am_tr_top1, am_tr_losses, am_val_top1, am_val_losses]:
        avg.reset()
    
    if epoch % snapshot == 0 or epoch == CONFIG_ARCH['train_settings']['cnt_epochs'] - 1:
        torch.save(model.state_dict(), './logs_darts/model_'+ str(epoch) +'.pth')
    
    # первые 10 эпох учим только веса модели
    if epoch > 10:
        writer.add_scalar('learning_rate-alphas', alpha_optim.param_groups[0]['lr'], epoch)
        writer.add_scalars('alphas', {'weight_of_good' : model.arch_parameters()[0][0].item(),
                                      'weight_of_bad'  : model.arch_parameters()[1][0].item()}, epoch)
        alpha_scheduler.step()
            
    writer.add_scalar('learning_rate-weights', optimizer.param_groups[0]['lr'], epoch)   
    w_scheduler.step()

05/15 10:07:29 AM | Epoch 0 :	Loss (2.303 ; 2.303) ;	Accuracy (0.115 ; 0.11)
05/15 10:07:35 AM | Epoch 1 :	Loss (2.303 ; 2.302) ;	Accuracy (0.117 ; 0.116)
05/15 10:07:41 AM | Epoch 2 :	Loss (2.302 ; 2.302) ;	Accuracy (0.121 ; 0.154)
05/15 10:07:47 AM | Epoch 3 :	Loss (2.299 ; 2.286) ;	Accuracy (0.106 ; 0.1)
05/15 10:07:53 AM | Epoch 4 :	Loss (2.252 ; 2.209) ;	Accuracy (0.185 ; 0.257)
05/15 10:07:59 AM | Epoch 5 :	Loss (2.208 ; 2.182) ;	Accuracy (0.243 ; 0.269)
05/15 10:08:05 AM | Epoch 6 :	Loss (2.189 ; 2.17) ;	Accuracy (0.258 ; 0.283)
05/15 10:08:10 AM | Epoch 7 :	Loss (2.181 ; 2.151) ;	Accuracy (0.268 ; 0.301)
05/15 10:08:16 AM | Epoch 8 :	Loss (2.163 ; 2.144) ;	Accuracy (0.286 ; 0.31)
05/15 10:08:22 AM | Epoch 9 :	Loss (2.144 ; 2.118) ;	Accuracy (0.307 ; 0.333)
05/15 10:08:28 AM | Epoch 10 :	Loss (2.126 ; 2.101) ;	Accuracy (0.326 ; 0.354)
05/15 10:08:40 AM | Epoch 11 :	Loss (2.132 ; 2.114) ;	Accuracy (0.321 ; 0.34)
05/15 10:08:53 AM | Epoch 12 :	Loss (2.123 ; 2.074) ;	Accuracy (0.33

05/15 10:27:20 AM | Epoch 105 :	Loss (1.715 ; 1.738) ;	Accuracy (0.746 ; 0.722)
05/15 10:27:32 AM | Epoch 106 :	Loss (1.714 ; 1.746) ;	Accuracy (0.747 ; 0.712)
05/15 10:27:44 AM | Epoch 107 :	Loss (1.71 ; 1.728) ;	Accuracy (0.751 ; 0.733)
05/15 10:27:56 AM | Epoch 108 :	Loss (1.71 ; 1.725) ;	Accuracy (0.75 ; 0.736)
05/15 10:28:08 AM | Epoch 109 :	Loss (1.709 ; 1.738) ;	Accuracy (0.752 ; 0.72)
05/15 10:28:20 AM | Epoch 110 :	Loss (1.707 ; 1.728) ;	Accuracy (0.754 ; 0.732)
05/15 10:28:32 AM | Epoch 111 :	Loss (1.701 ; 1.729) ;	Accuracy (0.76 ; 0.73)
05/15 10:28:44 AM | Epoch 112 :	Loss (1.7 ; 1.724) ;	Accuracy (0.761 ; 0.737)
05/15 10:28:56 AM | Epoch 113 :	Loss (1.701 ; 1.726) ;	Accuracy (0.76 ; 0.733)
05/15 10:29:09 AM | Epoch 114 :	Loss (1.703 ; 1.732) ;	Accuracy (0.758 ; 0.728)
05/15 10:29:21 AM | Epoch 115 :	Loss (1.699 ; 1.724) ;	Accuracy (0.762 ; 0.736)
05/15 10:29:33 AM | Epoch 116 :	Loss (1.7 ; 1.718) ;	Accuracy (0.761 ; 0.741)
05/15 10:29:45 AM | Epoch 117 :	Loss (1.7 ; 1.724) 