In [1]:
import logging
import numpy as np
import os
import pickle
import copy
import random

from pathlib import Path

import torch
from torch import nn
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from fvcore.common.config import CfgNode

from naslib.search_spaces import SimpleCellSearchSpace
from naslib.defaults.trainer import Trainer
from naslib.utils.custom_dataset import CustomDataset
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch
from naslib.search_spaces import NasBench301SearchSpace, SimpleCellSearchSpace
from naslib.utils import set_seed, setup_logger, get_config_from_args
from naslib.search_spaces.core.query_metrics import Metric
from naslib.search_spaces.core.graph import Graph
from naslib.search_spaces.nasbenchasr.primitives import ASRPrimitive, CellLayerNorm, Head, ops, PadConvReluNorm
from naslib.utils import get_project_root
from naslib.search_spaces.core import primitives as core_ops
from naslib.search_spaces.nasbenchasr.conversions import flatten, \
    copy_structure, make_compact_mutable, make_compact_immutable
from naslib.search_spaces.nasbenchasr.encodings import encode_asr
from naslib.utils.encodings import EncodingType

In [2]:
datasets = Path('/home/data')

In [3]:
def transform_to_closest_square_shape(x):
    len_x = len(x)
    closest_square = int(np.round(np.sqrt(len_x)))
    if closest_square ** 2 < len_x: 
        closest_square += 1
    x = np.pad(x, (0, closest_square  ** 2 - len_x))
    return x.reshape((1, closest_square, closest_square))

class NasDataset(Dataset):
    def __init__(self, root_dir, name, queue, norm='min_max'):
        super().__init__()
        
        bin_exists = False
        num_exists = False
        
        assert queue in ['train', 'test', 'val']

        self.root_dir = root_dir 
        self.name = name
        self.type = queue

        x_num = f'X_num_{queue}.npy'
        x_bin = f'X_bin_{queue}.npy'
        y = f'Y_{queue}.npy'
        
        
        if (root_dir / name/ x_bin).exists():
            bin_exists = True
            
        if (root_dir / name/ x_num).exists():
            num_exists = True
        
        self.y = np.load(root_dir / name / y)
        
        if num_exists:
            x_num = np.load(root_dir / name / x_num)
            x_num = (x_num - np.min(x_num, axis=0)) / (np.max(x_num, axis=0) - np.min(x_num, axis=0))
        
        if bin_exists:
            x_bin = np.load(root_dir / name / x_bin)
        
        if num_exists and bin_exists:
            self.x = np.concatenate((x_num, x_bin), axis=1)
        elif num_exists:
            self.x = x_num
        elif bin_exists:
            self.x = x_bin
        else:
            raise NotImplementedError
            
        self.num_features = self.x.shape[1]
        self.num_classes = int(np.max(self.y)) + 1
            
    def __getitem__(self, i):
        x = self.x[i]
        y = self.y[i]

        return  x, y.astype(np.int64)
        
    def __len__(self):
        return len(self.x)
    



In [4]:
with open('/home/table_nas/config.yaml') as f:
    config = CfgNode.load_cfg(f)

In [5]:
class TabNasDataset(CustomDataset):
    def __init__(self, config, ds_train, ds_test, mode='train'):
        super().__init__(config, mode)
        self.ds_train = ds_train
        self.ds_test = ds_test

    def get_transforms(self, config):
        return Compose([ToTensor()]), Compose([ToTensor()])


    def get_data(self, data, train_transform, valid_transform):
        train_data = self.ds_train
        test_data = self.ds_test

        return train_data, test_data
    


In [6]:
class NasTrainer(Trainer):
    @staticmethod
    def build_search_dataloaders(config):
        return train_queue, valid_queue, _

In [7]:


OP_NAMES = ['linear', 'zero']
class Head(ASRPrimitive):

    def __init__(self, filters, num_classes):
        super().__init__(locals())
        self.layers = nn.ModuleList([
            nn.Linear(in_features=filters, out_features=num_classes+1)
        ])

    def forward(self, x, edge_data=None):
        output = self.layers[0](x)
        return output
    
    
class FFN(ASRPrimitive):
        def __init__(
            self,
            *,
            d_token: int,
            d_hidden: int,
            bias_first: bool,
            bias_second: bool,
            dropout: float,
            activation: ModuleType,
        ):
            super().__init__()
            self.linear_first = nn.Linear(
                d_token,
                d_hidden * (2 if _is_glu_activation(activation) else 1),
                bias_first,
            )
            self.activation = _make_nn_module(activation)
            self.dropout = nn.Dropout(dropout)
            self.linear_second = nn.Linear(d_hidden, d_token, bias_second)

        def forward(self, x: Tensor, edge_data=None) -> Tensor:
            x = self.linear_first(x)
            x = self.activation(x)
            x = self.dropout(x)
            x = self.linear_second(x)
            return x
    

class LinearOP(ASRPrimitive):
    def __init__(self, in_features, out_features, dropout_rate=0, name='Linear'):
        super().__init__(locals())
        self.name = name

        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, edge_data=None):
        x = self.linear(x)
        x = self.relu(x)
        x = torch.clamp_max_(x, 20)
        x = self.dropout(x)
        return x

    def __repr__(self):
        return f'{self.__class__}({self.linear})'
    

class NasTabSearchSpace(Graph):
    """
    Contains the interface to the tabular benchmark of nas-bench-asr.
    Note: currently we do not support building a naslib object for
    nas-bench-asr architectures.
    """

    QUERYABLE = True
    OPTIMIZER_SCOPE = [
        'cells_stage_1',
        'cells_stage_2',
        'cells_stage_3',
        'cells_stage_4'
    ]

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.load_labeled = False
        self.max_epoch = 40
        self.max_nodes = 3
        self.accs = None
        self.compact = None

        self.n_blocks = 4
        self.n_cells_per_block = [3, 4, 5, 6]
        self.features = num_features
        self.filters = [600, 600, 600, 600]
        self.cnn_time_reduction_kernels = [8, 8, 8, 8]
        self.cnn_time_reduction_strides = [1, 1, 2, 2]
        self.scells_per_block = [3, 4, 5, 6]
        self.num_classes = num_classes
        self.dropout_rate = 0.0
        self.use_norm = True

        self._create_macro_graph()

    def _create_macro_graph(self):
        cell = self._create_cell()

        # Macrograph defintion
        n_nodes = self.n_blocks + 2
        self.add_nodes_from(range(1, n_nodes + 1))

        for node in range(1, n_nodes):
            self.add_edge(node, node + 1)

        # Create the cell blocks and add them as subgraphs of nodes 2 ... 5
        for idx, node in enumerate(range(2, 2 + self.n_blocks)):
            scope = f'cells_stage_{idx + 1}'
            cells_block = self._create_cells_block(cell, n=self.n_cells_per_block[idx], scope=scope)
            self.nodes[node]['subgraph'] = cells_block.set_input([node - 1])

            # Assign the list of operations to the cell edges
            cells_block.update_edges(
                update_func=lambda edge: _set_cell_edge_ops(edge, filters=self.filters[idx], use_norm=self.use_norm),
                scope=scope,
                private_edge_data=True
            )

        start_node = 1
        for idx, node in enumerate(range(start_node, start_node + self.n_blocks)):
            if node == start_node:
                op = LinearOP(self.features, 600)
            else:
                op = core_ops.Identity()

            self.edges[node, node + 1].set('op', op)
        # Assign the LSTM + Linear layer to the last edge in the macro graph
        self.edges[self.n_blocks + 1, self.n_blocks + 2].set('op', Head(600, 5))

    def _create_cells_block(self, cell, n, scope):
        block = Graph()
        block.name = f'{n}_cells_block'

        block.add_nodes_from(range(1, n + 2))

        for node in range(2, n + 2):
            block.add_node(node, subgraph=cell.copy().set_scope(scope).set_input([node - 1]))

        for node in range(1, n + 2):
            block.add_edge(node, node + 1)

        return block

    def _create_cell(self):
        cell = Graph()
        cell.name = 'cell'

        cell.add_nodes_from(range(1, 8))

        # Create edges
        for i in range(1, 7):
            cell.add_edge(i, i + 1)

        for i in range(1, 6, 2):
            for j in range(i + 2, 8, 2):
                cell.add_edge(i, j)

        cell.add_node(8)
        cell.add_edge(7, 8)  # For optional layer normalization

        return cell

    def query(self, metric=None, dataset=None, path=None, epoch=-1,
              full_lc=False, dataset_api=None):
        """
        Query results from nas-bench-asr
        """
        metric_to_asr = {
            Metric.VAL_ACCURACY: "val_per",
            Metric.TEST_ACCURACY: "test_per",
            Metric.PARAMETERS: "params",
            Metric.FLOPS: "flops",
        }

        assert self.compact is not None
        assert metric in [
            Metric.TRAIN_ACCURACY,
            Metric.TRAIN_LOSS,
            Metric.VAL_ACCURACY,
            Metric.TEST_ACCURACY,
            Metric.PARAMETERS,
            Metric.FLOPS,
            Metric.TRAIN_TIME,
            Metric.RAW,
        ]
        query_results = dataset_api["asr_data"].full_info(self.compact)

        if metric != Metric.VAL_ACCURACY:
            if metric == Metric.TEST_ACCURACY:
                return query_results[metric_to_asr[metric]]
            elif (metric == Metric.PARAMETERS) or (metric == Metric.FLOPS):
                return query_results['info'][metric_to_asr[metric]]
            elif metric in [Metric.TRAIN_ACCURACY, Metric.TRAIN_LOSS,
                            Metric.TRAIN_TIME, Metric.RAW]:
                return -1
        else:
            if full_lc and epoch == -1:
                return [
                    loss for loss in query_results[metric_to_asr[metric]]
                ]
            elif full_lc and epoch != -1:
                return [
                    loss for loss in query_results[metric_to_asr[metric]][:epoch]
                ]
            else:
                # return the value of the metric only at the specified epoch
                return float(query_results[metric_to_asr[metric]][epoch])

    def get_compact(self):
        assert self.compact is not None
        return self.compact

    def get_hash(self):
        return self.get_compact()

    def set_compact(self, compact):
        self.compact = make_compact_immutable(compact)

    def sample_random_architecture(self, dataset_api):
        search_space = [[len(OP_NAMES)] + [2] * (idx + 1) for idx in
                        range(self.max_nodes)]
        flat = flatten(search_space)
        m = [random.randrange(opts) for opts in flat]
        m = copy_structure(m, search_space)

        compact = m
        self.set_compact(compact)
        return compact

    def mutate(self, parent, mutation_rate=1, dataset_api=None):
        """
        This will mutate the cell in one of two ways:
        change an edge; change an op.
        Todo: mutate by adding/removing nodes.
        Todo: mutate the list of hidden nodes.
        Todo: edges between initial hidden nodes are not mutated.
        """
        parent_compact = parent.get_compact()
        parent_compact = make_compact_mutable(parent_compact)
        compact = copy.deepcopy(parent_compact)

        for _ in range(int(mutation_rate)):
            mutation_type = np.random.choice([2])

            if mutation_type == 1:
                # change an edge
                # first pick up a node
                node_id = np.random.choice(3)
                node = compact[node_id]
                # pick up an edge id
                edge_id = np.random.choice(len(node[1:])) + 1
                # edge ops are in [identity, zero] ([0, 1])
                new_edge_op = int(not compact[node_id][edge_id])
                # apply the mutation
                compact[node_id][edge_id] = new_edge_op

            elif mutation_type == 2:
                # change an op
                node_id = np.random.choice(3)
                node = compact[node_id]
                op_id = node[0]
                list_of_ops_ids = list(range(len(OP_NAMES)))
                list_of_ops_ids.remove(op_id)
                new_op_id = random.choice(list_of_ops_ids)
                compact[node_id][0] = new_op_id

        self.set_compact(compact)

    def get_nbhd(self, dataset_api=None):
        """
        Return all neighbors of the architecture
        """
        compact = self.get_compact()
        # edges, ops, hiddens = compact
        nbhd = []

        def add_to_nbhd(new_compact, nbhd):
            print(new_compact)
            nbr = NasBenchASRSearchSpace()
            nbr.set_compact(new_compact)
            nbr_model = torch.nn.Module()
            nbr_model.arch = nbr
            nbhd.append(nbr_model)
            return nbhd

        for node_id in range(len(compact)):
            node = compact[node_id]
            for edge_id in range(len(node)):
                if edge_id == 0:
                    edge_op = compact[node_id][0]
                    list_of_ops_ids = list(range(len(OP_NAMES)))
                    list_of_ops_ids.remove(edge_op)
                    for op_id in list_of_ops_ids:
                        new_compact = copy.deepcopy(compact)
                        new_compact = make_compact_mutable(new_compact)
                        new_compact[node_id][0] = op_id
                        nbhd = add_to_nbhd(new_compact, nbhd)
                else:
                    edge_op = compact[node_id][edge_id]
                    new_edge_op = int(not edge_op)
                    new_compact = copy.deepcopy(compact)
                    new_compact = make_compact_mutable(new_compact)
                    new_compact[node_id][edge_id] = new_edge_op
                    nbhd = add_to_nbhd(new_compact, nbhd)

        random.shuffle(nbhd)
        return nbhd

    def get_type(self):
        return 'asr'

    def get_max_epochs(self):
        return 39

    def encode(self, encoding_type=EncodingType.ADJACENCY_ONE_HOT):
        return encode_asr(self, encoding_type=encoding_type)


def _set_cell_edge_ops(edge, filters, use_norm):
    if use_norm and edge.head == 7:
        edge.data.set('op', core_ops.Identity())
        edge.data.finalize()
    elif edge.head % 2 == 0:  # Edge from intermediate node
        edge.data.set(
            'op', [
                LinearOP(filters, filters),
                ops['zero'](filters, filters)
            ]
        )
    elif edge.tail % 2 == 0:  # Edge to intermediate node. Should always be Identity.
        edge.data.finalize()
    else:
        edge.data.set(
            'op',
            [
                core_ops.Zero(stride=1),
                core_ops.Identity()
            ]
        )

In [None]:
name = 'classif-cat-large-0-covertype'
ds_train = NasDataset(datasets, name, 'train')
ds_test = NasDataset(datasets, name, 'test')

dataset = TabNasDataset(config, ds_train, ds_test)

train_queue, valid_queue, test_queue, train_transform, valid_transform = dataset.get_loaders()
    
search_space = NasTabSearchSpace(ds_train.num_features, ds_train.num_classes)
logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)  # default DEBUG is very verbose

optimizer = DARTSOptimizer(**config.search)
optimizer.adapt_search_space(search_space, config.dataset)


trainer = NasTrainer(optimizer, config)
trainer.search()  # Search for an architecture
trainer.evaluate()  # Evaluate the best architecture

[05/24 07:56:37 nl.defaults.trainer]: Beginning search
[05/24 07:56:40 nl.optimizers.oneshot.darts.optimizer]: Arch weights (alphas, last column argmax): 
-0.000156, -0.000116, 1
+0.000500, +0.001154, 1
+0.000190, -0.000586, 0
+0.000544, -0.000631, 0
-0.000397, +0.002099, 1
-0.000823, -0.001026, 0
-0.001486, -0.001633, 0
+0.000798, +0.000570, 0
+0.000623, +0.000665, 1
[05/24 07:56:40 nl.defaults.trainer]: Epoch 0-0, Train loss: 6.14617, validation loss: 6.32109, learning rate: [0.025]
[05/24 07:56:40 nl.defaults.trainer]: cuda consumption
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  237393 KB |  240189 KB |  516084 KB | 