In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import nni
from nni.nas.nn.pytorch import LayerChoice, ModelSpace, MutableDropout, MutableLinear
from nni.nas.evaluator.pytorch import Classification


In [2]:
from aeon.datasets import load_classification
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import TensorDataset #, DataLoader
from nni.nas.evaluator.pytorch.lightning import DataLoader
from einops import rearrange

X, y = load_classification("Tiselac", extract_path="/workdir/data")
print(" Shape of X = ", X.shape, type(X), X.dtype)
print(" Shape of y = ", y.shape, type(y), y.dtype)
y = y.astype(int)
display(y)

X = rearrange(X, "n v t -> n t v")
in_feat = X.shape[2]

# Target
y = y - 1
y_unique = np.unique(y)
num_classes = len(y_unique)
print("num_classes", num_classes)
y = np.eye(num_classes)[y]

# Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.40, random_state = 1, stratify = y)
X_test, X_val, y_test, y_val = train_test_split(X_test, y_test, test_size = 0.50, random_state = 1, stratify = y_test)


# Normalize
# X_time_train, X_time_val, X_time_test = normalize_across_time(X_time_train, X_time_val, X_time_test, X_time.shape[2])


# Datasets
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)

X_val = torch.tensor(X_val, dtype=torch.float32)
y_val = torch.tensor(y_val, dtype=torch.float32)

X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)

print(X_train.shape, y_train.shape)

# Dataloaders
batch_size = 64
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers=4)

val_dataset = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle=False, num_workers=4)

test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle=False, num_workers=4)


 Shape of X =  (99687, 10, 23) <class 'numpy.ndarray'> float64
 Shape of y =  (99687,) <class 'numpy.ndarray'> <U1


array([6, 1, 6, ..., 3, 4, 5])

num_classes 9
torch.Size([59812, 23, 10]) torch.Size([59812, 9])


In [3]:
class BottleneckCell(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super(BottleneckCell, self).__init__()

        self.linear_in = MutableLinear(in_features, hidden_features)
        self.linear_out = MutableLinear(hidden_features, out_features)

    def forward(self, x):
        output = self.linear_out(F.relu(self.linear_in(x)))
        return output

class MyModelSpace(ModelSpace, label_prefix='backbone'):
    def __init__(self, in_feat, hidden_feat, out_feat):
        super().__init__()
        
        self.layer1 = LayerChoice([
            BottleneckCell(in_feat, 200, hidden_feat),
            BottleneckCell(in_feat, 50, hidden_feat),
        ], label='layer1')
        
        self.act1 = LayerChoice([
            nn.ReLU(),
            nn.SELU(),
        ], label='act1')
        
        self.flatten = nn.Flatten()
        
        self.layer2 = LayerChoice([
            BottleneckCell(23 * hidden_feat, 200, hidden_feat),
            BottleneckCell(23 * hidden_feat, 50, hidden_feat),
        ], label='layer2')
        
        self.act2 = LayerChoice([
            nn.ReLU(),
            nn.SELU(),
        ], label='act2')
        
        self.layer3 = nn.Linear(hidden_feat, out_feat)

    def forward(self, x):
        x = self.act1(self.layer1(x))
        x = self.flatten(x)
        x = self.act2(self.layer2(x))
        output = self.layer3(x)
        output = F.log_softmax(output, dim=1)
        return output

model_space = MyModelSpace(in_feat, 100, num_classes)
model_space

MyModelSpace(
  (layer1): LayerChoice(
    label='backbone/layer1'
    (0): BottleneckCell(
      (linear_in): MutableLinear(in_features=10, out_features=200)
      (linear_out): MutableLinear(in_features=200, out_features=100)
    )
    (1): BottleneckCell(
      (linear_in): MutableLinear(in_features=10, out_features=50)
      (linear_out): MutableLinear(in_features=50, out_features=100)
    )
  )
  (act1): LayerChoice(
    label='backbone/act1'
    (0): ReLU()
    (1): SELU()
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer2): LayerChoice(
    label='backbone/layer2'
    (0): BottleneckCell(
      (linear_in): MutableLinear(in_features=2300, out_features=200)
      (linear_out): MutableLinear(in_features=200, out_features=100)
    )
    (1): BottleneckCell(
      (linear_in): MutableLinear(in_features=2300, out_features=50)
      (linear_out): MutableLinear(in_features=50, out_features=100)
    )
  )
  (act2): LayerChoice(
    label='backbone/act2'
    (0): ReLU()
    (1):

In [4]:
for X, y in train_loader:
    print(X.shape, X.dtype, y.shape, y.dtype)
    
    out = model_space(X)
    print(out.shape)
    break


torch.Size([64, 23, 10]) torch.float32 torch.Size([64, 9]) torch.float32
torch.Size([64, 9])


In [5]:
evaluator = Classification(
    criterion=nn.CrossEntropyLoss,
    learning_rate=1e-3,
    weight_decay=1e-4,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    max_epochs=10,
    gpus=1,
    # fast_dev_run=True,
    num_classes=num_classes,
)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
from nni.nas.strategy import DARTS
from nni.nas.strategy import GumbelDARTS
strategy = DARTS()

# nni.nas.nn.pytorch.LayerChoice.
# nni.nas.nn.pytorch.InputChoice.
# nni.nas.nn.pytorch.ParametrizedModule (only when parameters are choices and type is in MutableLinear, MutableConv2d, MutableBatchNorm2d, MutableLayerNorm, MutableMultiheadAttention).
# nni.nas.nn.pytorch.Repeat.
# nni.nas.nn.pytorch.Cell.


In [7]:
from nni.nas.experiment import NasExperiment

experiment = NasExperiment(model_space, evaluator, strategy)
experiment.run()


[2024-02-04 10:57:30] [32mConfig is not provided. Will try to infer.[0m
[2024-02-04 10:57:30] [32mStrategy is found to be a one-shot strategy. Setting execution engine to "sequential" and format to "raw".[0m
[2024-02-04 10:57:30] [32mCheckpoint saved to /root/nni-experiments/2wmj9yvs/checkpoint.[0m
[2024-02-04 10:57:30] [32mExperiment initialized successfully. Starting exploration strategy...[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]

  | Name            | Type                 | Params
---------------------------------------------------------
0 | training_module | ClassificationModule | 629 K 
---------------------------------------------------------
629 K     Trainable params
0         Non-trainable params
629 K     Total params
2.517     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

[2024-02-04 11:01:05] [32mWaiting for models submitted to engine to finish...[0m
[2024-02-04 11:01:05] [32mExperiment is completed.[0m


True

In [9]:
exported_arch = experiment.export_top_models(formatter='dict')[0]
exported_arch


{'backbone/layer1': 0,
 'backbone/act1': 0,
 'backbone/layer2': 1,
 'backbone/act2': 0}

In [12]:
from nni.nas.space import model_context

with model_context(exported_arch):
    final_model = MyModelSpace(in_feat, 100, num_classes)

print(final_model)
# train_loader = DataLoader(train_data, batch_size=96, num_workers=6)  # Use the original training data

max_epochs = 100

evaluator = Classification(
    learning_rate = 1e-3,
    weight_decay = 1e-4,
    train_dataloaders = train_loader,
    val_dataloaders = val_loader,
    max_epochs = max_epochs,
    gpus = 1,
    export_onnx = False,
    fast_dev_run = False,
    num_classes = num_classes,
)

evaluator.fit(final_model)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]

  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | metrics   | ModuleDict       | 0     
2 | _model    | MyModelSpace     | 143 K 
-----------------------------------------------
143 K     Trainable params
0         Non-trainable params
143 K     Total params
0.573     Total estimated model params size (MB)


MyModelSpace(
  (layer1): BottleneckCell(
    (linear_in): Linear(in_features=10, out_features=200, bias=True)
    (linear_out): Linear(in_features=200, out_features=100, bias=True)
  )
  (act1): ReLU()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer2): BottleneckCell(
    (linear_in): Linear(in_features=2300, out_features=50, bias=True)
    (linear_out): Linear(in_features=50, out_features=100, bias=True)
  )
  (act2): ReLU()
  (layer3): Linear(in_features=100, out_features=9, bias=True)
)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[2024-02-04 19:37:54] [32mIntermediate result: 0.0  (Index 2)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:38:04] [32mIntermediate result: 0.0  (Index 3)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:38:14] [32mIntermediate result: 0.0  (Index 4)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:38:24] [32mIntermediate result: 0.0  (Index 5)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:38:34] [32mIntermediate result: 0.0  (Index 6)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:38:43] [32mIntermediate result: 0.0  (Index 7)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:38:53] [32mIntermediate result: 0.0  (Index 8)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:39:04] [32mIntermediate result: 0.0  (Index 9)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:39:13] [32mIntermediate result: 0.0  (Index 10)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:39:24] [32mIntermediate result: 0.0  (Index 11)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:39:34] [32mIntermediate result: 0.0  (Index 12)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:39:44] [32mIntermediate result: 0.0  (Index 13)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:39:54] [32mIntermediate result: 0.0  (Index 14)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:40:04] [32mIntermediate result: 0.0  (Index 15)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:40:13] [32mIntermediate result: 0.0  (Index 16)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:40:23] [32mIntermediate result: 0.0  (Index 17)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:40:34] [32mIntermediate result: 0.0  (Index 18)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:40:44] [32mIntermediate result: 0.0  (Index 19)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:40:54] [32mIntermediate result: 0.0  (Index 20)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:41:04] [32mIntermediate result: 0.0  (Index 21)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:41:13] [32mIntermediate result: 0.0  (Index 22)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:41:23] [32mIntermediate result: 0.0  (Index 23)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:41:33] [32mIntermediate result: 0.0  (Index 24)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:41:43] [32mIntermediate result: 0.0  (Index 25)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:41:53] [32mIntermediate result: 0.0  (Index 26)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:42:03] [32mIntermediate result: 0.0  (Index 27)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:42:13] [32mIntermediate result: 0.0  (Index 28)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:42:23] [32mIntermediate result: 0.0  (Index 29)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:42:33] [32mIntermediate result: 0.0  (Index 30)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:42:43] [32mIntermediate result: 0.0  (Index 31)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:42:54] [32mIntermediate result: 0.0  (Index 32)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:43:04] [32mIntermediate result: 0.0  (Index 33)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:43:14] [32mIntermediate result: 0.0  (Index 34)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:43:24] [32mIntermediate result: 0.0  (Index 35)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:43:34] [32mIntermediate result: 0.0  (Index 36)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:43:44] [32mIntermediate result: 0.0  (Index 37)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:43:55] [32mIntermediate result: 0.0  (Index 38)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:44:07] [32mIntermediate result: 0.0  (Index 39)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:44:17] [32mIntermediate result: 0.0  (Index 40)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:44:27] [32mIntermediate result: 0.0  (Index 41)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:44:37] [32mIntermediate result: 0.0  (Index 42)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:44:47] [32mIntermediate result: 0.0  (Index 43)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:44:57] [32mIntermediate result: 0.0  (Index 44)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:45:06] [32mIntermediate result: 0.0  (Index 45)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:45:16] [32mIntermediate result: 0.0  (Index 46)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:45:26] [32mIntermediate result: 0.0  (Index 47)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:45:36] [32mIntermediate result: 0.0  (Index 48)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:45:46] [32mIntermediate result: 0.0  (Index 49)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:45:56] [32mIntermediate result: 0.0  (Index 50)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:46:06] [32mIntermediate result: 0.0  (Index 51)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:46:16] [32mIntermediate result: 0.0  (Index 52)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:46:26] [32mIntermediate result: 0.0  (Index 53)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:46:36] [32mIntermediate result: 0.0  (Index 54)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:46:46] [32mIntermediate result: 0.0  (Index 55)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:46:55] [32mIntermediate result: 0.0  (Index 56)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:47:05] [32mIntermediate result: 0.0  (Index 57)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:47:15] [32mIntermediate result: 0.0  (Index 58)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:47:25] [32mIntermediate result: 0.0  (Index 59)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:47:35] [32mIntermediate result: 0.0  (Index 60)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:47:45] [32mIntermediate result: 0.0  (Index 61)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:47:55] [32mIntermediate result: 0.0  (Index 62)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:48:05] [32mIntermediate result: 0.0  (Index 63)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:48:15] [32mIntermediate result: 0.0  (Index 64)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:48:25] [32mIntermediate result: 0.0  (Index 65)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:48:35] [32mIntermediate result: 0.0  (Index 66)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:48:45] [32mIntermediate result: 0.0  (Index 67)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:48:56] [32mIntermediate result: 0.0  (Index 68)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:49:06] [32mIntermediate result: 0.0  (Index 69)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:49:16] [32mIntermediate result: 0.0  (Index 70)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:49:26] [32mIntermediate result: 0.0  (Index 71)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:49:36] [32mIntermediate result: 0.0  (Index 72)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:49:46] [32mIntermediate result: 0.0  (Index 73)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:49:56] [32mIntermediate result: 0.0  (Index 74)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:50:05] [32mIntermediate result: 0.0  (Index 75)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:50:15] [32mIntermediate result: 0.0  (Index 76)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:50:25] [32mIntermediate result: 0.0  (Index 77)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:50:35] [32mIntermediate result: 0.0  (Index 78)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:50:45] [32mIntermediate result: 0.0  (Index 79)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:50:54] [32mIntermediate result: 0.0  (Index 80)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:51:04] [32mIntermediate result: 0.0  (Index 81)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:51:14] [32mIntermediate result: 0.0  (Index 82)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:51:24] [32mIntermediate result: 0.0  (Index 83)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:51:34] [32mIntermediate result: 0.0  (Index 84)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:51:43] [32mIntermediate result: 0.0  (Index 85)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:51:53] [32mIntermediate result: 0.0  (Index 86)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:52:03] [32mIntermediate result: 0.0  (Index 87)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:52:13] [32mIntermediate result: 0.0  (Index 88)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:52:22] [32mIntermediate result: 0.0  (Index 89)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:52:32] [32mIntermediate result: 0.0  (Index 90)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:52:42] [32mIntermediate result: 0.0  (Index 91)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:52:52] [32mIntermediate result: 0.0  (Index 92)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:53:02] [32mIntermediate result: 0.0  (Index 93)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:53:12] [32mIntermediate result: 0.0  (Index 94)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:53:21] [32mIntermediate result: 0.0  (Index 95)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:53:31] [32mIntermediate result: 0.0  (Index 96)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:53:41] [32mIntermediate result: 0.0  (Index 97)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:53:51] [32mIntermediate result: 0.0  (Index 98)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:54:01] [32mIntermediate result: 0.0  (Index 99)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:54:11] [32mIntermediate result: 0.0  (Index 100)[0m


Validation: 0it [00:00, ?it/s]

[2024-02-04 19:54:20] [32mIntermediate result: 0.0  (Index 101)[0m
[2024-02-04 19:54:21] [32mFinal result: 0.0[0m
