In [None]:
!pip install torchtoolbox

# **Libraries**

In [None]:
import torchvision
import torchvision.transforms as transforms

from torchtoolbox.transform import Cutout

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import optim

from copy import deepcopy
import os
import numpy as np
from tqdm import tqdm
import glob
import shutil
import random
import datetime
import time

# **Data loading**

In [None]:
def data_loader(dataset, train_batch_size, test_batch_size, root='dataset/', num_workers=0):
    """
    Build train/test DataLoaders for CIFAR-10, CIFAR-100 or SVHN.

    Args:
        dataset: one of 'cifar10', 'cifar100', 'svhn'.
        train_batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of the unshuffled test loader.
        root: directory the torchvision datasets are downloaded to / read from.
        num_workers: worker processes used by both loaders (0 = main process).

    Returns:
        (train_loader, test_loader)

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    supported = ('cifar10', 'cifar100', 'svhn')
    if dataset not in supported:
        # Without this check an unknown name fell through every branch and crashed
        # later with an UnboundLocalError on `train_data`.
        raise ValueError('Unsupported dataset %r; expected one of: %s' % (dataset, ', '.join(supported)))

    # NOTE: these mean/std values are applied to every dataset, including SVHN and CIFAR-100.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        Cutout(),  # torchtoolbox Cutout, applied before ToTensor
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if dataset == 'svhn':
        # SVHN selects its split with `split=` instead of `train=`.
        train_data = torchvision.datasets.SVHN(root=root, split='train', transform=transform_train, download=True)
        test_data = torchvision.datasets.SVHN(root=root, split='test', transform=transform_test, download=True)
    else:
        dataset_cls = torchvision.datasets.CIFAR10 if dataset == 'cifar10' else torchvision.datasets.CIFAR100
        train_data = dataset_cls(root=root, train=True, transform=transform_train, download=True)
        test_data = dataset_cls(root=root, train=False, transform=transform_test, download=True)

    train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader

# **Generate network**

In [None]:
class GenerateNet(nn.Module):
    """
    Build a network from a node-graph configuration.

    ``net_config`` maps node names to ``{'config': ..., 'inbound_nodes': [...]}``.
    The layer type is chosen from a substring of the node name ('conv', 'bn', 'relu',
    'fc', 'max', 'dropout'). 'add', 'concat' and 'lambda' nodes have no parameters and
    are evaluated directly in ``forward``. The name 'input' refers to the tensor passed
    to ``forward``.

    NOTE: for 'fc' nodes the config dict is mutated in place (``output_size`` is set to
    ``n_class``), i.e. the caller's ``net_config`` is modified.
    """

    def __init__(self, net_config, n_class):
        super().__init__()
        self.net_config = net_config
        self.n_class = n_class  # number of classes in the output
        self._add_model_from_dict()
        # Internal graph description: [node_name, *inbound_nodes] for every node, in config order.
        self.node_list = [[node_name] + list(node['inbound_nodes']) for node_name, node in self.net_config.items()]

    # Layer factories. They are staticmethods (not lambdas stored on the instance) so the
    # module stays picklable; `self.get_*_from_dict(cfg)` keeps working as before.
    @staticmethod
    def get_conv_from_dict(x):
        return nn.Conv2d(in_channels=x['in_channels'], out_channels=x['out_channels'],
                         kernel_size=x['kernel_size'], padding=x['padding'], stride=x['stride'])

    @staticmethod
    def get_bn_from_dict(x):
        return nn.BatchNorm2d(x['input_size'])

    @staticmethod
    def get_linear_from_dict(x):
        return nn.Linear(x['input_size'], x['output_size'])

    @staticmethod
    def get_maxpooling_from_dict(x):
        return nn.MaxPool2d(kernel_size=x['kernel_size'], stride=x['stride'])

    @staticmethod
    def get_dropout_from_dict(x):
        return nn.Dropout2d(p=x['dropout_rate'])

    def _add_model_from_dict(self):
        """Register one sub-module per parametrised node, named after the node."""
        for node_name in self.net_config:
            node_config = self.net_config[node_name]['config']
            if 'conv' in node_name:
                self.add_module(node_name, self.get_conv_from_dict(node_config))
            elif 'bn' in node_name:
                self.add_module(node_name, self.get_bn_from_dict(node_config))
            elif 'relu' in node_name:
                self.add_module(node_name, nn.ReLU())
            elif 'fc' in node_name:
                node_config['output_size'] = self.n_class
                self.add_module(node_name, self.get_linear_from_dict(node_config))
            elif 'max' in node_name:
                self.add_module(node_name, self.get_maxpooling_from_dict(node_config))
            elif 'dropout' in node_name:
                self.add_module(node_name, self.get_dropout_from_dict(node_config))  # dropout layer

    @staticmethod
    def _conv_with_extra_inputs(conv, inp):
        """
        Apply ``conv`` to ``inp`` when ``inp`` has more channels than ``conv.in_channels``.

        The weight is zero-padded along the input-channel axis, so the extra trailing
        channels contribute nothing and the result equals ``conv(inp[:, :conv.in_channels])``.
        Padding the live parameter (instead of copying ``.data`` into a fresh layer) keeps
        the original bias, lets gradients reach ``conv.weight``/``conv.bias`` and stays on
        the parameter's device.
        """
        extra = inp.shape[1] - conv.in_channels
        weight = nn.functional.pad(conv.weight, (0, 0, 0, 0, 0, extra))
        return nn.functional.conv2d(inp, weight, conv.bias, conv.stride, conv.padding,
                                    conv.dilation, conv.groups)

    def forward(self, x, device=None):
        """
        Evaluate the graph on ``x`` and return the output of the last node evaluated.

        ``device`` is accepted for backward compatibility and is no longer needed.
        """
        layers = dict(self.named_children())
        pending = list(self.node_list)  # shallow copy: the inner lists are never mutated
        final_node = None
        layer_out = {'input': x}
        while pending:
            pending_before = len(pending)
            # NOTE: removing from `pending` while iterating skips the element after each
            # removed one; it is picked up on the next pass of the while-loop. This
            # scheduling decides which node ends up as `final_node`, so it is kept as-is.
            for node in pending:
                node_name = node[0]
                inbound_nodes = node[1:]
                if not set(inbound_nodes).issubset(layer_out):
                    continue
                if 'add' in node_name:
                    if len(inbound_nodes) != 2:
                        raise ValueError('Inbound_nodes error')
                    layer_out[node_name] = layer_out[inbound_nodes[0]] + layer_out[inbound_nodes[1]]
                elif 'concat' in node_name:
                    if len(inbound_nodes) != 2:
                        raise ValueError('Inbound_nodes error')
                    layer_out[node_name] = torch.cat(
                        (layer_out[inbound_nodes[0]], layer_out[inbound_nodes[1]]), 1)
                elif 'fc' in node_name:
                    # flatten(1) also works for non-contiguous inputs, unlike view()
                    layer_out[node_name] = layers[node_name](layer_out[inbound_nodes[0]].flatten(1))
                elif 'lambda' in node_name:
                    layer_out[node_name] = 0.5 * layer_out[inbound_nodes[0]]
                else:
                    inp = layer_out[inbound_nodes[0]]
                    layer = layers[node_name]
                    if 'conv' in node_name and inp.shape[1] > layer.in_channels:
                        # More input channels than the conv expects (e.g. after a dropout
                        # config change): ignore the extra channels via zero weights.
                        layer_out[node_name] = self._conv_with_extra_inputs(layer, inp)
                    else:
                        layer_out[node_name] = layer(inp)
                final_node = node_name
                pending.remove(node)
            assert len(pending) < pending_before, 'Net configuration error!'

        return layer_out[final_node]

# **Network morphisms**

In [None]:
from network_config import init_config, dropout_config

class NetworkMorphisms(object):
    def __init__(self, dataset, in_channels=3, picture_size=(32, 32)):
        self.in_channels = in_channels
        self.picture_size = picture_size
        if dataset == 'cifar10' or dataset == 'svhn':
            self.n_class = 10
        elif dataset == 'cifar100':
            self.n_class = 100
        else:
            print('\tInvalid input dataset name at NetworkMorphisms()')
            exit(1)
        self.teacher_config = None
        self.student_config = None  
        self.teacher = None
        self.student = None
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.train_loader, self.test_loader = data_loader(dataset, train_batch_size=128, test_batch_size=100)

    def load_teacher(self, model_path):
        """
        load teacher network from check point file
        """
        assert os.path.isfile(model_path), 'The model path does not exist'
        check_point = torch.load(model_path)
        self.teacher = GenerateNet(check_point['model_config'], self.n_class)
        self.teacher_config = check_point['model_config']
        self.teacher.load_state_dict(check_point['model_state_dict'])

    def initial_network(self, epochs=20, lr=0.05, model_folder='', model_config=None):
        """
        Build the initial teacher from ``model_config`` (default: ``init_config``) and train it.

        A checkpoint is written via ``save_model`` each time the test accuracy improves,
        together with that epoch's test loss.

        Args:
            epochs: number of training epochs.
            lr: SGD learning rate.
            model_folder: folder passed to ``save_model``.
            model_config: node-graph config; deep-copied so the caller's dict is not mutated.
        """
        model_config = deepcopy(init_config if model_config is None else model_config)
        self.teacher_config = model_config
        self.teacher = GenerateNet(model_config, self.n_class).to(self.device)

        optimizer = optim.SGD(params=self.teacher.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        loss_func = torch.nn.CrossEntropyLoss()
        # NOTE(review): T_max is fixed at 8 regardless of `epochs`; past 8 epochs the cosine
        # schedule rises again -- confirm this is intended.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=8)

        best_acc = 0
        best_loss = 0
        for epoch in range(epochs):
            self._train(epoch, optimizer, loss_func)
            correct, total, loss = self._eval(epoch, loss_func)
            acc = correct / total
            if acc > best_acc:
                best_acc = acc
                best_loss = loss
                # Save the accuracy just reached (previously the *old* best was written).
                self.save_model(best_acc, best_loss, self.teacher.state_dict(), self.teacher_config, model_folder)
            scheduler.step()
        print('\nBest: accuracy: %f, loss: %f\n' % (best_acc, best_loss))

    def train(self, epochs=17, lr=0.05, save_folder='./', one_cycle=False, early_stopping=False):
        optimizer = optim.SGD(params=self.teacher.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

        if one_cycle:
          # One Cycle LR
          scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr*2, epochs=epochs, steps_per_epoch=int(len(self.train_loader)/128))
        else:
          # Cosine Annealing LR
          scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=8)
                                                  
        loss_func = torch.nn.CrossEntropyLoss()
        run_history = []
        run_loss = []
        self.teacher = self.teacher.to(self.device)
        
        # Early stopping criteria
        best_loss = float('inf')
        patience = 0#int(epochs/10)+1
        counter = 1

        for epoch in range(epochs):
            self._train(epoch, optimizer, loss_func)
            correct, total, loss = self._eval(epoch, loss_func)
            acc = correct / total
            run_history.append(acc)
            run_loss.append(loss)
            
            if early_stopping:
              # Check if the validation loss has improved
              if loss < best_loss:
                  best_loss = loss
                  counter = 0
              else:
                  counter += 1

              # Check if we need to stop training
              if counter >= patience:
                  print(f"Stopping early at epoch {epoch}")
                  break
            
            scheduler.step()
        self.save_model(np.mean(run_history[-3:]), np.mean(run_loss[-3:]), self.teacher.state_dict(), self.teacher_config, save_folder)
        return run_history, run_loss

    def change_teacher(self, student_weight):
        """Promote the current student (config + given weights) to be the new teacher."""
        promoted = GenerateNet(self.student_config, self.n_class)
        self.teacher = promoted
        self.teacher_config = deepcopy(self.student_config)
        promoted.load_state_dict(student_weight)

    def generate_node_name(self, name):
        """
        Generate a new node name
        """
        same_node = 0
        for node_name in self.student_config:
            if name in node_name:
                same_node += 1
        return name + str(same_node + 1)

    def add(self, node_index: int):
        """
        Create the 'add motif' of https://arxiv.org/pdf/1806.02639.pdf at ``node_index``.

        The conv-bn-relu block at ``node_index`` is duplicated (copied weights plus small
        noise on the conv). Both relu outputs go through a 'lambda' node (0.5 * x) and are
        summed by a new 'add' node. The nodes returned by ``get_next_nodes(relu_index)``
        are rewired from the original relu to that sum. The result becomes the new teacher.
        """
        self.student_config = deepcopy(self.teacher_config)
        student_weight = self.teacher.state_dict()
        nodes_list = self.get_nodes_list()
        node_name, bn_index, bn_name, relu_index, relu_name = self.get_conv_bn_relu(nodes_list, node_index)

        lambda1 = self.generate_node_name('lambda')
        self.student_config[lambda1] = {'config': '', 'inbound_nodes': [relu_name]}

        conv1 = self.generate_node_name('conv')
        bn1 = self.generate_node_name('bn')
        relu1 = self.generate_node_name('relu')
        self.student_config[conv1] = deepcopy(self.student_config[node_name])
        self.student_config[bn1] = deepcopy(self.student_config[bn_name])
        self.student_config[bn1]['inbound_nodes'] = [conv1]
        self.student_config[relu1] = deepcopy(self.student_config[relu_name])
        self.student_config[relu1]['inbound_nodes'] = [bn1]

        lambda2 = self.generate_node_name('lambda')
        self.student_config[lambda2] = deepcopy(self.student_config[lambda1])
        self.student_config[lambda2]['inbound_nodes'] = [relu1]

        add1 = self.generate_node_name('add')
        self.student_config[add1] = {'config': '', 'inbound_nodes': [lambda1, lambda2]}

        next_nodes_index = self.get_next_nodes(relu_index)
        self.replace_student_node_inbound(nodes_list, next_nodes_index, relu_name, add1)

        self.student = GenerateNet(self.student_config, self.n_class)
        node_weight = student_weight[node_name + '.weight']
        # torch noise (instead of numpy) keeps the dtype and device of the teacher weights,
        # so this also works when the teacher lives on the GPU (e.g. after train()).
        student_weight[conv1 + '.weight'] = node_weight + torch.randn_like(node_weight) * (node_weight.std() * 0.01)
        student_weight[conv1 + '.bias'] = student_weight[node_name + '.bias']
        student_weight[bn1 + '.weight'] = student_weight[bn_name + '.weight']
        student_weight[bn1 + '.bias'] = student_weight[bn_name + '.bias']
        student_weight[bn1 + '.running_mean'] = student_weight[bn_name + '.running_mean']
        student_weight[bn1 + '.running_var'] = student_weight[bn_name + '.running_var']
        self.student.load_state_dict(student_weight)

        self.change_teacher(student_weight)

    def concat(self, node_index: int):
        """
        Create the 'concatenation motif' of https://arxiv.org/pdf/1806.02639.pdf at ``node_index``.

        The conv-bn-relu block at ``node_index`` is split along its output channels.
        The first ``filters // 2`` channels move to a new conv/bn/relu branch (with small
        noise on the conv weights), and the rest stay in the original block. The two relu
        outputs are concatenated as [new, original], which restores the original channel
        order. The nodes returned by ``get_next_nodes(relu_index)`` are rewired to the
        concatenation. The result becomes the new teacher. Odd filter counts are supported.
        """
        self.student_config = deepcopy(self.teacher_config)
        student_weight = self.teacher.state_dict()
        nodes_list = self.get_nodes_list()
        node_name, bn_index, bn_name, relu_index, relu_name = self.get_conv_bn_relu(nodes_list, node_index)

        filters = self.student_config[node_name]['config']['out_channels']
        assert filters >= 2, 'Cannot split a convolution with fewer than 2 filters'
        half = filters // 2      # channels moved to the new branch
        rest = filters - half    # channels kept by the original conv (differs from `half` if odd)
        self.student_config[node_name]['config']['out_channels'] = rest

        conv1 = self.generate_node_name('conv')
        bn1 = self.generate_node_name('bn')
        relu1 = self.generate_node_name('relu')
        self.student_config[conv1] = deepcopy(self.student_config[node_name])
        self.student_config[conv1]['config']['out_channels'] = half

        self.student_config[bn_name]['config']['input_size'] = rest
        self.student_config[bn1] = deepcopy(self.student_config[bn_name])
        self.student_config[bn1]['config']['input_size'] = half
        self.student_config[bn1]['inbound_nodes'] = [conv1]

        self.student_config[relu1] = deepcopy(self.student_config[relu_name])
        self.student_config[relu1]['inbound_nodes'] = [bn1]

        concat1 = self.generate_node_name('concat')
        self.student_config[concat1] = {'config': '', 'inbound_nodes': [relu1, relu_name]}

        next_conv_index = self.get_next_nodes(relu_index)
        self.replace_student_node_inbound(nodes_list, next_conv_index, relu_name, concat1)

        self.student = GenerateNet(self.student_config, self.n_class)

        weight = student_weight[node_name + '.weight']
        bias = student_weight[node_name + '.bias']
        moved_weight = weight[:half, :, :, :]
        # torch noise keeps dtype/device of the teacher weights (works on GPU too).
        student_weight[conv1 + '.weight'] = moved_weight + torch.randn_like(moved_weight) * (moved_weight.std() * 0.01)
        student_weight[conv1 + '.bias'] = bias[:half]
        student_weight[node_name + '.weight'] = weight[half:, :, :, :]
        student_weight[node_name + '.bias'] = bias[half:]

        for suffix in ('.weight', '.bias', '.running_mean', '.running_var'):
            bn_param = student_weight[bn_name + suffix]
            student_weight[bn1 + suffix] = bn_param[:half]
            student_weight[bn_name + suffix] = bn_param[half:]
        self.student.load_state_dict(student_weight)

        self.change_teacher(student_weight)

    def wider2net_conv2d(self, node_index: int, new_width=None):
        """
        Widen the conv at ``node_index`` (Net2WiderNet) when it feeds another conv.

        ``new_width - out_channels`` randomly chosen filters are replicated, together with
        their BN parameters. In the next conv, reached possibly through max-pool and/or
        dropout nodes, the input weights of every replicated channel are divided by its
        number of copies, so their sum is preserved. Small noise is added to the new
        weights. If ``new_width`` is None the number of filters is doubled. The result
        becomes the new teacher.
        """
        self.student_config = deepcopy(self.teacher_config)
        student_weight = self.teacher.state_dict()
        nodes_list = self.get_nodes_list()
        node_name, bn_index, bn_name, relu_index, relu_name = self.get_conv_bn_relu(nodes_list, node_index)

        next_nodes = self.get_next_nodes(relu_index)
        assert len(next_nodes) == 1, 'Wrong place for widder'
        next_node_name = nodes_list[next_nodes[0]][0]

        assert 'lambda' not in next_node_name, 'Wider inside add or concatenate block'

        # Walk over parameter-free max-pool / dropout nodes (in any order and number) to the
        # layer whose input channels must grow. The successor of a node is the first node
        # whose first inbound node it is.
        while 'max' in next_node_name or 'dropout' in next_node_name:
            successor = next((node for node in nodes_list if node[1:2] == [next_node_name]), None)
            assert successor is not None, 'Wrong place for widder'
            next_node_name = successor[0]
        assert 'fc' not in next_node_name, 'Last convolutional layer'

        teacher_w1, teacher_b1 = student_weight[node_name + '.weight'], student_weight[node_name + '.bias']
        alpha, beta = student_weight[bn_name + '.weight'], student_weight[bn_name + '.bias']
        mean, var = student_weight[bn_name + '.running_mean'], student_weight[bn_name + '.running_var']
        teacher_w2, teacher_b2 = student_weight[next_node_name + '.weight'], student_weight[next_node_name + '.bias']

        original_filters = teacher_w1.shape[0]
        if new_width is None:
            new_width = self.student_config[node_name]['config']['out_channels'] * 2
        n = new_width - original_filters
        assert n > 0, "New width smaller than teacher width"
        index = np.random.randint(original_filters, size=n)
        # number of copies of each replicated channel (its new copies + the original)
        factors = np.bincount(index)[index] + 1.
        factors = torch.from_numpy(factors.reshape((1, -1, 1, 1))).to(device=teacher_w2.device, dtype=teacher_w2.dtype)

        new_w1 = teacher_w1[index, :, :, :]
        new_b1 = teacher_b1[index]
        new_w2 = teacher_w2[:, index, :, :] / factors

        # torch noise keeps dtype/device of the teacher weights (works on GPU too).
        new_w1 = new_w1 + torch.randn_like(new_w1) * (new_w1.std() * 0.05)
        student_w1 = torch.cat((teacher_w1, new_w1), 0)
        student_b1 = torch.cat((teacher_b1, new_b1), 0)

        alpha = torch.cat((alpha, alpha[index]))
        beta = torch.cat((beta, beta[index]))
        mean = torch.cat((mean, mean[index]))
        var = torch.cat((var, var[index]))

        new_w2 = new_w2 + torch.randn_like(new_w2) * (new_w2.std() * 0.05)
        student_w2 = torch.cat((teacher_w2, new_w2), 1)
        # the original replicated channels get the same divided weights as their copies
        student_w2[:, index, :, :] = new_w2

        self.student_config[node_name]['config']['out_channels'] = new_width
        self.student_config[bn_name]['config']['input_size'] = new_width
        self.student_config[next_node_name]['config']['in_channels'] = new_width
        student_weight[node_name + '.weight'], student_weight[node_name + '.bias'] = student_w1, student_b1
        student_weight[next_node_name + '.weight'], student_weight[next_node_name + '.bias'] = student_w2, teacher_b2
        student_weight[bn_name + '.weight'], student_weight[bn_name + '.bias'] = alpha, beta
        student_weight[bn_name + '.running_mean'], student_weight[bn_name + '.running_var'] = mean, var

        self.student = GenerateNet(self.student_config, self.n_class)
        self.student.load_state_dict(student_weight)

        self.change_teacher(student_weight)

    def wider2net_conv2d_fc(self, node_index: int, new_width=None):
        """
        Widen the conv at ``node_index`` when it is the last conv before the fully connected layer.

        Like ``wider2net_conv2d``, but the consumer whose inputs grow is the 'fc' layer
        (reached possibly through max-pool/dropout nodes). If ``new_width`` is None the
        number of filters is doubled. The result becomes the new teacher.
        """
        self.student_config = deepcopy(self.teacher_config)
        student_weight = self.teacher.state_dict()
        nodes_list = self.get_nodes_list()
        node_name, bn_index, bn_name, relu_index, relu_name = self.get_conv_bn_relu(nodes_list, node_index)

        next_node_index = self.get_next_nodes(relu_index)
        assert len(next_node_index) == 1, 'Wrong place for widder'
        next_node_index = next_node_index[0]
        next_node_name = nodes_list[next_node_index][0]

        if 'max' in next_node_name:
            next_node_index = self.get_next_nodes(next_node_index)[0]
            next_node_name = nodes_list[next_node_index][0]
        if 'dropout' in next_node_name:
            for idx, node in enumerate(nodes_list):
                if node[1] == next_node_name:
                    next_node_index, next_node_name = idx, node[0]
                    for idx2, node2 in enumerate(nodes_list):
                        if node2[1] == next_node_name:
                            next_node_index, next_node_name = idx2, node2[0]
                            break
        assert 'fc' in next_node_name, 'there is not a fully connected layer'

        teacher_w1, teacher_b1 = student_weight[node_name + ".weight"], student_weight[node_name + '.bias']
        alpha, beta = student_weight[bn_name + '.weight'], student_weight[bn_name + '.bias']
        mean, var = student_weight[bn_name + '.running_mean'], student_weight[bn_name + '.running_var']
        teacher_w2, teacher_b2 = student_weight[next_node_name + '.weight'], student_weight[next_node_name + '.bias']

        original_filters = teacher_w1.shape[0]
        if new_width is None:
            new_width = self.student_config[node_name]['config']['out_channels'] * 2
        n = new_width - original_filters
        assert n > 0, "New width smaller than teacher width"

        index = np.random.randint(original_filters, size=n)
        # number of copies of each replicated channel (its new copies + the original)
        factors = np.bincount(index)[index] + 1.
        factors = torch.from_numpy(factors.reshape((-1, 1))).to(device=teacher_w2.device, dtype=teacher_w2.dtype)
        new_w1 = teacher_w1[index, :, :, :]
        new_b1 = teacher_b1[index]

        # NOTE(review): fc weight columns are indexed by conv channel, i.e. this assumes the
        # flattened fc input has exactly one feature per channel (1x1 spatial size), which
        # matches setting the fc 'input_size' to new_width below -- confirm for new configs.
        new_w2 = teacher_w2.T[index, :] / factors

        alpha = torch.cat((alpha, alpha[index]))
        beta = torch.cat((beta, beta[index]))
        mean = torch.cat((mean, mean[index]))
        var = torch.cat((var, var[index]))

        # torch noise keeps dtype/device of the teacher weights (works on GPU too).
        new_w1 = new_w1 + torch.randn_like(new_w1) * (new_w1.std() * 0.05)
        student_w1 = torch.cat((teacher_w1, new_w1))
        student_b1 = torch.cat((teacher_b1, new_b1))
        new_w2 = new_w2 + torch.randn_like(new_w2) * (new_w2.std() * 0.05)
        student_w2 = torch.cat((teacher_w2.T, new_w2))
        # the original replicated channels get the same divided weights as their copies
        student_w2[index, :] = new_w2
        student_w2 = student_w2.T.contiguous()

        self.student_config[node_name]['config']['out_channels'] = new_width
        self.student_config[bn_name]['config']['input_size'] = new_width
        self.student_config[next_node_name]['config']['input_size'] = new_width
        student_weight[node_name + '.weight'], student_weight[node_name + '.bias'] = student_w1, student_b1
        student_weight[next_node_name + '.weight'], student_weight[next_node_name + '.bias'] = student_w2, teacher_b2
        student_weight[bn_name + '.weight'], student_weight[bn_name + '.bias'] = alpha, beta
        student_weight[bn_name + '.running_mean'], student_weight[bn_name + '.running_var'] = mean, var

        self.student = GenerateNet(self.student_config, self.n_class)
        self.student.load_state_dict(student_weight)

        self.change_teacher(student_weight)

    def deeper2net_conv2d(self, node_index: int):
        """
        Insert an identity-initialised conv-bn-relu block after the block at ``node_index`` (Net2DeeperNet).

        The new 3x3 conv starts as an identity mapping (plus small noise) with zero bias.
        The new BatchNorm also starts as an identity (weight=1, bias=0, running_mean=0,
        running_var=1). In eval mode the new block therefore passes the non-negative relu
        output through approximately unchanged. The nodes returned by
        ``get_next_nodes(relu_index)`` are rewired to the new relu. The result becomes the
        new teacher.
        """
        self.student_config = deepcopy(self.teacher_config)
        student_weight = self.teacher.state_dict()
        nodes_list = self.get_nodes_list()
        node_name, bn_index, bn_name, relu_index, relu_name = self.get_conv_bn_relu(nodes_list, node_index)

        conv1 = self.generate_node_name('conv')
        bn1 = self.generate_node_name('bn')
        relu1 = self.generate_node_name('relu')
        filters = self.student_config[node_name]['config']['out_channels']
        # The identity kernel must have the same size as the configured conv; previously it
        # used the *source* conv's kernel size, which broke loading for non-3x3 convs.
        kernel_size = 3

        self.student_config[conv1] = {
            'config': {'in_channels': filters, 'out_channels': filters, 'kernel_size': kernel_size,
                       'padding': kernel_size // 2, 'stride': 1},
            'inbound_nodes': [relu_name]}

        self.student_config[bn1] = deepcopy(self.student_config[bn_name])
        self.student_config[bn1]['inbound_nodes'] = [conv1]

        self.student_config[relu1] = deepcopy(self.student_config[relu_name])
        self.student_config[relu1]['inbound_nodes'] = [bn1]

        next_nodes_index = self.get_next_nodes(relu_index)
        self.replace_student_node_inbound(nodes_list, next_nodes_index, relu_name, relu1)

        centre = (kernel_size - 1) // 2
        student_w = torch.zeros((filters, filters, kernel_size, kernel_size))
        for i in range(filters):
            student_w[i, i, centre, centre] = 1.
        student_w = student_w + torch.randn_like(student_w) * (student_w.std() * 0.01)
        student_weight[conv1 + '.weight'] = student_w
        student_weight[conv1 + '.bias'] = torch.zeros(student_weight[node_name + '.bias'].shape)
        # Identity BatchNorm: copying the previous BN's statistics (computed on pre-relu
        # conv outputs) would not preserve the network function.
        student_weight[bn1 + '.weight'] = torch.ones(filters)
        student_weight[bn1 + '.bias'] = torch.zeros(filters)
        student_weight[bn1 + '.running_mean'] = torch.zeros(filters)
        student_weight[bn1 + '.running_var'] = torch.ones(filters)
        self.student = GenerateNet(self.student_config, self.n_class)

        self.student.load_state_dict(student_weight)
        self.change_teacher(student_weight)

    def skip(self, node_index: int, change_teacher=False):
        """
        Add a skip connection: ``deeper2net_conv2d`` at ``node_index``, then sum 0.5 * (new relu)
        and 0.5 * (new conv) with an 'add' node.

        ``deeper2net_conv2d`` always replaces the teacher. The skip connection itself becomes
        the teacher only when ``change_teacher`` is true; otherwise it is left in
        ``self.student`` / ``self.student_config``.
        """
        names_before = [entry[0] for entry in self.get_nodes_list(teacher=True)]
        self.deeper2net_conv2d(node_index)
        names_after = [entry[0] for entry in self.get_nodes_list(teacher=True)]
        added_names = list(set(names_after) - set(names_before))
        new_relu_name = [name for name in added_names if 'relu' in name][0]
        new_conv_name = [name for name in added_names if 'conv' in name][0]

        self.student_config = deepcopy(self.teacher_config)
        student_weight = self.teacher.state_dict()

        relu_branch = self.generate_node_name('lambda')
        self.student_config[relu_branch] = {'config': '', 'inbound_nodes': [new_relu_name]}
        conv_branch = self.generate_node_name('lambda')
        self.student_config[conv_branch] = {'config': '', 'inbound_nodes': [new_conv_name]}

        merge_node = self.generate_node_name('add')
        self.student_config[merge_node] = {'config': '', 'inbound_nodes': [relu_branch, conv_branch]}
        nodes_list = self.get_nodes_list()

        # position of the new relu in the student node list (last match, None if absent)
        positions = [pos for pos, entry in enumerate(nodes_list) if entry[0] == new_relu_name]
        new_relu_index = positions[-1] if positions else None

        consumers = self.get_next_nodes(new_relu_index)
        self.replace_student_node_inbound(nodes_list, consumers, new_relu_name, merge_node)

        self.student = GenerateNet(self.student_config, self.n_class)
        self.student.load_state_dict(student_weight)
        if change_teacher:
            self.change_teacher(student_weight)

    def _train(self, epoch, optimizer, loss_func):
        """
        Run one training epoch of ``self.teacher`` over ``self.train_loader``.

        A tqdm bar shows the running mean loss and accuracy. Returns nothing.
        """
        self.teacher.train()
        running_loss = 0
        n_correct = 0
        n_seen = 0
        n_batches = len(self.train_loader)
        with tqdm(total=n_batches, desc='train epoch %d' % epoch, colour='black') as bar:
            for step, (inputs, targets) in enumerate(self.train_loader):
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                optimizer.zero_grad()
                logits = self.teacher(inputs, self.device)
                batch_loss = loss_func(logits, targets)
                batch_loss.backward()
                optimizer.step()

                running_loss += batch_loss.item()
                n_seen += targets.size(0)
                n_correct += logits.max(1)[1].eq(targets).sum().item()

                bar.set_postfix({'step': step, 'length of train': n_batches,
                                 'Loss': '%.3f' % (running_loss / (step + 1)),
                                 'Acc': '%.3f%% (%d/%d)' % (100. * n_correct / n_seen, n_correct, n_seen)})
                bar.update(1)

    def _eval(self, epoch, loss_func):
        """
        Evaluate ``self.teacher`` on ``self.test_loader`` without gradients.

        Returns:
            (correct, total, mean_loss): number of correct predictions, number of samples,
            and the summed batch loss divided by the number of batches.
        """
        self.teacher.eval()
        summed_loss = 0
        hits = 0
        seen = 0
        n_batches = len(self.test_loader)
        with torch.no_grad(), tqdm(total=n_batches, desc='eval epoch %d' % epoch, colour='black') as bar:
            for step, (inputs, targets) in enumerate(self.test_loader):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                logits = self.teacher(inputs, self.device)
                summed_loss += loss_func(logits, targets).item()
                predicted = logits.max(1)[1]
                seen += targets.size(0)
                hits += predicted.eq(targets).sum().item()
                bar.set_postfix({'step': step, 'length of eval': n_batches,
                                 'Loss': '%.3f' % (summed_loss / (step + 1)),
                                 'Acc': '%.3f%% (%d/%d)' % (100. * hits / seen, hits, seen)})
                bar.update(1)
        return hits, seen, summed_loss / n_batches

    def replace_student_node_inbound(self, node_list, nodes_index, original_inbound_node_name, new_inbound_node_name):
        """
        Replace the old inbound node of the nodes with the new inbound node name
        """
        for index in nodes_index:
            for idx, element in enumerate(self.student_config[node_list[index][0]]['inbound_nodes']):
                if element == original_inbound_node_name:
                    self.student_config[node_list[index][0]]['inbound_nodes'][idx] = new_inbound_node_name

    def get_nodes_list(self, teacher=False):
        """
        Flatten a node config into rows of ``[node_name, *inbound_node_names]``.

        Args:
            teacher: read ``self.teacher_config`` when truthy, otherwise ``self.student_config``.

        Returns:
            list of new lists, one per node, in the config's iteration order.
        """
        config = self.teacher_config if teacher else self.student_config
        return [[name] + config[name]['inbound_nodes'] for name in config]

    def return_available_nodes(self):
        """
        Before the network morphism, check which teacher nodes each morphism operation can be applied to.

        Only conv nodes are candidates. Indices refer to positions in
        ``self.get_nodes_list(teacher=True)``, and every list is in ascending index order.
        The "chain" below is conv -> second -> third -> fourth, always following the first successor.

        Rules:
          * deeper2net_conv2d: every conv node.
          * skip / add / concat: conv nodes with at most one successor.
          * wider2net_conv2d_fc: the fourth node's name contains 'fc', or contains 'max' and its
            single successor's name contains 'fc'.
          * wider2net_conv2d: the conv has exactly one successor; the third node has exactly one
            successor (the fourth); the fourth node has a single inbound node, is neither fc nor
            lambda, and has at most one successor, which must not be fc.
        Convs whose chain ends before the fourth node are only eligible for deeper/skip/add/concat.

        Returns:
            dict with keys 'wider2net_conv2d', 'wider2net_conv2d_fc', 'deeper2net_conv2d',
            'add', 'concat', 'skip' (in that order), each mapping to a separate list of node indices.
        """
        nodes_list = self.get_nodes_list(teacher=True)

        wider2net_conv2d = []
        wider2net_conv2d_fc = []
        deeper2net_conv2d = []
        # conv nodes feeding at most one consumer; shared criterion of skip / add / concat
        single_output_convs = []

        for i, element in enumerate(nodes_list):
            if 'conv' not in element[0]:
                continue
            deeper2net_conv2d.append(i)

            second = self.get_next_nodes(i)
            if len(second) <= 1:
                single_output_convs.append(i)

            # Follow the first successor three hops. A chain that ends early cannot be widened;
            # indexing an empty successor list would raise IndexError.
            if not second:
                continue
            third = self.get_next_nodes(second[0])
            if not third:
                continue
            fourth = self.get_next_nodes(third[0])
            if not fourth:
                continue
            fourth_name = nodes_list[fourth[0]][0]

            # wider2net_conv2d_fc: the chain reaches an fc node directly or through a max node.
            # Substring tests on the node *name*, consistent with every other fc check here.
            if 'max' in fourth_name:
                fifth = self.get_next_nodes(fourth[0])
                if len(fifth) == 1 and 'fc' in nodes_list[fifth[0]][0]:
                    wider2net_conv2d_fc.append(i)
            if 'fc' in fourth_name:
                wider2net_conv2d_fc.append(i)

            # wider2net_conv2d: a single linear path into a plain (non-fc, non-lambda) consumer.
            if len(second) > 1 or len(fourth) > 1:
                continue
            if len(nodes_list[fourth[0]][1:]) > 1:
                continue
            if 'fc' in fourth_name or 'lambda' in fourth_name:
                continue
            # The fifth-node checks deliberately apply to every consumer type, not only conv/max.
            fifth = self.get_next_nodes(fourth[0])
            if len(fifth) > 1:
                continue
            if len(fifth) == 1 and 'fc' in nodes_list[fifth[0]][0]:
                continue
            wider2net_conv2d.append(i)

        available = {'wider2net_conv2d': wider2net_conv2d, 'wider2net_conv2d_fc': wider2net_conv2d_fc,
                     'deeper2net_conv2d': deeper2net_conv2d, 'add': list(single_output_convs),
                     'concat': list(single_output_convs), 'skip': list(single_output_convs)}

        return available

    @staticmethod
    def get_conv_bn_relu(nodes_list, node_index):
        """
        Locate the batch-norm and ReLU nodes that follow a conv node.

        The bn node is a node whose name contains 'bn' and whose first inbound node is the conv;
        the relu node is a node whose name contains 'relu' and whose first inbound node is that bn.
        If several nodes match, the last one in ``nodes_list`` wins.

        Args:
            nodes_list: rows of ``[node_name, *inbound_node_names]``.
            node_index: index of the conv node in ``nodes_list``.

        Returns:
            tuple ``(conv_name, bn_index, bn_name, relu_index, relu_name)``.

        Raises:
            AssertionError: if the node is not a conv or the bn/relu node cannot be found.
        """
        node_name = nodes_list[node_index][0]

        assert 'conv' in node_name, 'Wrong layer index'
        bn_index, bn_name, relu_index, relu_name = None, None, None, None
        for idx, node in enumerate(nodes_list):
            # nodes without inbound nodes have no node[1]
            if len(node) > 1 and node[1] == node_name and 'bn' in node[0]:
                bn_index, bn_name = idx, node[0]
        for idx, node in enumerate(nodes_list):
            if len(node) > 1 and node[1] == bn_name and 'relu' in node[0]:
                relu_index, relu_name = idx, node[0]
        # Compare with None explicitly: index 0 is a valid position but is falsy.
        assert all(value is not None for value in (bn_index, bn_name, relu_index, relu_name)), \
            'bn_index or  bn_name or relu_index or relu_name must not be None'
        return node_name, bn_index, bn_name, relu_index, relu_name

    @staticmethod
    def save_model(acc, loss, model_state_dict, model_config, folder):
        """
        Save a checkpoint dict to ``<folder>/model.pkl`` with ``torch.save``.

        Args:
            acc: stored under 'best_acc'.
            loss: stored under 'loss_func'.
            model_state_dict: stored under 'model_state_dict'.
            model_config: stored under 'model_config'.
            folder: target directory; created together with any missing parents.
                An empty string means the current working directory.
        """
        check_point = {
            'best_acc': acc,
            'loss_func': loss,
            'model_state_dict': model_state_dict,
            'model_config': model_config
        }
        if folder:
            # makedirs handles nested paths and an already-existing folder
            os.makedirs(folder, exist_ok=True)
        torch.save(check_point, os.path.join(folder, 'model.pkl'))

    def number_of_parameter(self):
        """Return the total number of scalar elements across all parameters of ``self.teacher``."""
        total = 0
        for param in self.teacher.parameters():
            total += param.numel()
        return total

    def plot_model(self, folder):
        """
        Export the teacher network to ONNX at ``<folder>/model.onnx``.

        ONNX is a standard model format, so the export can be transferred between
        platforms and frameworks.

        Args:
            folder: output directory; created together with any missing parents.
        """
        if folder:
            os.makedirs(folder, exist_ok=True)
        # The dummy input must be on the teacher's device; _eval feeds the teacher inputs moved
        # to self.device, so build it there instead of on the CPU.
        dummy_input = torch.rand(1, self.in_channels, self.picture_size[0], self.picture_size[1],
                                 device=self.device)
        # NOTE(review): _eval calls self.teacher(x, self.device) but only x is passed here;
        # confirm the forward's device argument is optional, otherwise the export fails.
        torch.onnx.export(self.teacher, dummy_input, os.path.join(folder, 'model.onnx'),
                          opset_version=15, input_names=['input'], output_names=['output'],
                          operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
    
    def get_next_nodes(self, node_index, teacher=True):
        """
        Return the indices of the nodes that list ``node_index``'s node among their inbound nodes.

        Row 0 of the node list is never reported as a successor.
        """
        rows = self.get_nodes_list(teacher=teacher)
        return [position for position, row in enumerate(rows[1:], start=1)
                if rows[node_index][0] in row[1:]]

    def get_previous_node(self, node_name, teacher=True):
        """Return the inbound node names of the first node called ``node_name``, or None if absent."""
        matches = (row[1:] for row in self.get_nodes_list(teacher=teacher) if row[0] == node_name)
        return next(matches, None)

    def get_number_of_nodes(self, teacher=True):
        """Count the nodes of the teacher graph (default) or, with ``teacher=False``, the student graph."""
        rows = self.get_nodes_list(teacher=teacher)
        return len(rows)

# **Organism**


In [None]:
class Organism(object):
    """
    One hill-climbing candidate: a fresh working folder plus a NetworkMorphisms model
    that is mutated by random morphisms and then trained.
    """

    def __init__(self, number, model, epoch=''):
        """
        Args:
            number: organism index, used in the folder name.
            model: dataset name forwarded to ``NetworkMorphisms`` (e.g. 'cifar100').
            epoch: optional string prefix for the folder name.
        """
        self.number = number
        self.folder = epoch + 'model' + str(self.number) + '/'
        # Always start from an empty folder: drop leftovers from a previous run.
        if os.path.isdir(self.folder):
            shutil.rmtree(self.folder)
        os.mkdir(self.folder)
        self.model = NetworkMorphisms(model)

    def random_modification(self):
        """
        Apply one randomly chosen network morphism to ``self.model``.

        An operation is drawn uniformly among those with at least one applicable node,
        then one of its nodes is drawn uniformly and ``self.model.<operation>(node_index)``
        is called.

        Returns:
            str: name of the applied operation.

        Raises:
            ValueError: if no operation has any applicable node.
        """
        available_modifications = self.model.return_available_nodes()
        # Choose among non-empty operations only; redrawing until a non-empty one comes up
        # would never terminate when all of them are empty.
        candidates = [name for name, nodes in available_modifications.items() if len(nodes) > 0]
        if not candidates:
            raise ValueError('No network morphism is applicable to the current model')
        random_modification = random.choice(candidates)
        random_index = random.choice(list(available_modifications[random_modification]))
        print(random_modification, random_index)
        getattr(self.model, random_modification)(random_index)
        # ONNX export is disabled here; call self.model.plot_model(self.folder) to re-enable it.
        return random_modification

    def train(self, epochs=17, lr=0.05, save_folder='./', one_cycle=False, early_stopping=False):
        """Train the organism's model; returns whatever ``NetworkMorphisms.train`` returns."""
        return self.model.train(epochs, lr, save_folder=save_folder, one_cycle=one_cycle, early_stopping=early_stopping)

# **Hill climbing**

In [None]:
class HillClimb(object):
    """
    Hill-climbing architecture search driven by random network morphisms.

    Each step builds ``number_of_organism`` candidates from the current best checkpoint
    (``best/model.pkl``), applies random morphisms, trains them, and keeps the candidate with
    the highest mean accuracy over the last three entries of its training history.
    Progress is appended to ``best/results.txt``.
    """

    # Candidates at or above this parameter count are re-drawn (avoid too complex models).
    MAX_PARAMETERS = 200000000

    def __init__(self, number_of_organism, epochs, load_model_path, model):
        """
        Args:
            number_of_organism: candidates trained per hill-climbing step.
            epochs: number of hill-climbing steps.
            load_model_path: checkpoint copied to ``best/model.pkl`` as the starting point.
            model: dataset name forwarded to ``Organism``/``NetworkMorphisms`` (e.g. 'cifar100').
        """
        self.number_of_organism = number_of_organism
        self.epochs = epochs
        self.load_model_path = load_model_path
        self.model = model
        self.time = 0  # accumulated training wall-clock time in seconds

    @staticmethod
    def _cuda_sync():
        """Wait for queued CUDA work so timings are accurate; no-op on machines without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _reset_dir(path):
        """Remove ``path`` if it is a directory, then create it empty."""
        if os.path.isdir(path):
            shutil.rmtree(path)
        os.mkdir(path)

    @staticmethod
    def _append_results(*lines):
        """Append the given strings verbatim to best/results.txt."""
        with open('best/results.txt', 'a') as result_file:
            for line in lines:
                result_file.write(line)

    def _mutate(self, index, organism, number_of_modifications):
        """Reload the best checkpoint into ``organism`` and apply random morphisms, redrawing until small enough."""
        while True:
            print('Model loading %d' % index)
            torch.cuda.empty_cache()
            organism.model.load_teacher(model_path='best/model.pkl')
            modifications = [organism.random_modification() for _ in range(number_of_modifications)]
            print('Organism %d: modifications: %s' % (index, modifications))
            n_params = organism.model.number_of_parameter()
            if n_params < self.MAX_PARAMETERS:
                print('Number of parameters: %d' % n_params)
                print('Number of nodes: %d' % organism.model.get_number_of_nodes(False))
                return
            print('Repeat drawing of network morphism function: %d' % n_params)

    def start(self, number_of_modifications=5, organisms_train_epochs=17, organisms_train_lr=0.05, one_cycle=False, early_stopping=False):
        """
        Run ``self.epochs`` hill-climbing steps, starting from ``self.load_model_path``.

        Removes existing ``model*/`` folders and recreates ``best/``. After each step the winning
        organism's ``model.pkl`` (and ``model.onnx`` if present) is copied into ``best/``.
        """
        for model_dir in glob.glob('model*/'):
            shutil.rmtree(model_dir)
        self._reset_dir('best')
        shutil.copyfile(self.load_model_path, 'best/model.pkl')

        for epoch in range(self.epochs):
            print('Step %d' % (epoch+1))
            list_of_organisms = [Organism(i, self.model) for i in range(self.number_of_organism)]
            list_of_result = []
            list_of_loss = []
            for i, organism in enumerate(list_of_organisms):
                self._mutate(i, organism, number_of_modifications)

                torch.cuda.empty_cache()
                self._cuda_sync()
                start = time.time()
                acc_history, loss_history = organism.train(epochs=organisms_train_epochs, lr=organisms_train_lr,
                                                           save_folder=organism.folder,
                                                           one_cycle=one_cycle, early_stopping=early_stopping)
                self._cuda_sync()
                elapsed = time.time() - start

                # Score = mean of the last three recorded epochs (smooths single-epoch noise).
                result = np.mean(acc_history[-3:])
                loss = np.mean(loss_history[-3:])
                self.time += elapsed
                list_of_result.append(result)
                list_of_loss.append(loss)
                print('\nElapsed time (hrs): %.4f' % (elapsed/3600))
                print('Organism %d result: %f %f\n' % (i, result, loss))
                print('Total training time (hrs) so far: %.4f\n' % (self.time/3600))

                # Write training result of organism
                self._append_results(
                    'Step: %d, organism %d accuracy: %f loss: %f, ' % (epoch+1, i, result, loss),
                    'number of nodes: %d    number of parameters: %d\n' % (
                        organism.model.get_number_of_nodes(False), organism.model.number_of_parameter()))

            best = list_of_result.index(max(list_of_result))
            winner = list_of_organisms[best]
            shutil.copyfile(winner.folder + 'model.pkl', 'best/model.pkl')
            if os.path.exists(winner.folder + 'model.onnx'):
                shutil.copyfile(winner.folder + 'model.onnx', 'best/model.onnx')
            print('\nBest: %d, result: %f, %f\n' % (best, list_of_result[best], list_of_loss[best]))
            print('Total training time (hrs) for step %d: %.4f\n' % (epoch+1, self.time/3600))

            self._append_results(
                'Total training time (hrs) for step %d: %.4f\n' % (epoch+1, self.time/3600),
                '\nStep: %d, best accuracy: %f best loss: %f\n' % (epoch+1, list_of_result[best], list_of_loss[best]),
                'number of nodes: %d    number of parameters: %d\n\n\n' % (
                    winner.model.get_number_of_nodes(False), winner.model.number_of_parameter()))

        self._append_results('Total hillclimbing time (hrs): %.4f\n' % (self.time/3600))

    def eval(self, epochs=200, lr=0.05):
        """
        Final training of the best model (``best/model.pkl``), saving to folder 'test'.

        Prints and appends to best/results.txt the best accuracy of the run, the loss at that
        same epoch index, and the total execution time.
        """
        torch.cuda.empty_cache()
        model = NetworkMorphisms(self.model)
        model.load_teacher(model_path='best/model.pkl')

        self._cuda_sync()
        start = time.time()
        train_history, train_loss = model.train(epochs=epochs, lr=lr, save_folder='test')
        self._cuda_sync()
        elapsed = time.time() - start

        print('\nTotal final training time (hrs): %.4f\n' % (elapsed/3600))
        self.time += elapsed
        best = train_history.index(max(train_history))
        print(train_history[best], train_loss[best])
        self._append_results(
            'Final model acc(epoch:%d): %.4f loss(epoch:%d): %.4f\nTotal execution time (hrs): %.4f\n' % (
                epochs, train_history[best], epochs, train_loss[best], self.time/3600),
            'number of nodes: %d    number of parameters: %d\n\n\n' % (
                model.get_number_of_nodes(), model.number_of_parameter()))

# **Main**

In [None]:
def initial_network(dataset='cifar100', epochs=3, model_folder='initial/', model_config=None):
    """
    Create the seed network via ``NetworkMorphisms.initial_network``.

    ``hill_climb`` loads its starting point from ``initial/model.pkl``, so keep ``model_folder``
    in sync with ``load_model_path`` there.

    Args:
        dataset: dataset name passed to ``NetworkMorphisms``.
        epochs: epochs passed to ``initial_network``.
        model_folder: output folder passed to ``initial_network``.
        model_config: architecture config; defaults to the notebook-level ``dropout_config``.
    """
    model = NetworkMorphisms(dataset)
    if model_config is None:
        model_config = dropout_config
    model.initial_network(epochs=epochs, model_folder=model_folder, model_config=model_config)

In [None]:
# Create the seed network that hill_climb loads from initial/model.pkl.
initial_network()

In [None]:
def hill_climb(number_of_organism=3, steps=5, load_model_path='initial/model.pkl', dataset='cifar100',
               number_of_modifications=5, organisms_train_epochs=17, organisms_train_lr=0.05,
               one_cycle=True, early_stopping=True, final_epochs=100):
    """
    Run the hill-climbing search and then the final training of the best model.

    Args:
        number_of_organism: candidates per step.
        steps: number of hill-climbing steps.
        load_model_path: starting checkpoint (written by ``initial_network``).
        dataset: dataset name passed to ``HillClimb``.
        number_of_modifications: random morphisms applied to each candidate.
        organisms_train_epochs: training epochs per candidate.
        organisms_train_lr: learning rate per candidate.
        one_cycle: forwarded to candidate training.
        early_stopping: forwarded to candidate training.
        final_epochs: epochs for the final training of the best model.
    """
    evolution = HillClimb(number_of_organism=number_of_organism, epochs=steps,
                          load_model_path=load_model_path, model=dataset)
    evolution.start(number_of_modifications=number_of_modifications,
                    organisms_train_epochs=organisms_train_epochs, organisms_train_lr=organisms_train_lr,
                    one_cycle=one_cycle, early_stopping=early_stopping)
    evolution.eval(epochs=final_epochs)

In [None]:
# Run the hill-climbing search, then the final training of the best model.
hill_climb()