This is an end-to-end example for DARTS. In this tutorial, you will learn how to run architecture search with the DARTS algorithm and how to retrain a model based on the best architecture found. You can read more about DARTS in the [DARTS paper](https://arxiv.org/abs/1806.09055).

[DARTS](https://github.com/quark0/darts) addresses the scalability challenge of architecture search by formulating the task in a differentiable manner. Their method is based on the continuous relaxation of the architecture representation, allowing efficient search of the architecture using gradient descent.
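A minimal, framework-free sketch of the continuous relaxation on a single edge: instead of picking one operation, the edge outputs a softmax(α)-weighted sum of all candidate operations, and after search the edge is discretized to the operation with the largest α. The scalar "operations" and the helper names `softmax`, `mixed_op` and `derive_op` are hypothetical stand-ins; in the real search space the candidates are convolutions, pooling layers and skip connections applied to tensors.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(a - m) for a in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_op(x, candidate_ops, alpha):
    """Relaxed edge: softmax(alpha)-weighted sum of every candidate op applied to x."""
    weights = softmax(alpha)
    return sum(w * op(x) for w, op in zip(weights, candidate_ops))


def derive_op(alpha):
    """Discretize an edge after search: keep the candidate with the largest architecture weight."""
    return max(range(len(alpha)), key=lambda i: alpha[i])


# Hypothetical scalar stand-ins for the candidate operations of one edge.
candidate_ops = [
    lambda x: x,        # identity, like a skip connection
    lambda x: 0.0,      # zero op
    lambda x: 2.0 * x,  # stand-in for a parametric op such as a convolution
]

# Equal architecture weights give a uniform mixture: the average of 3.0, 0.0 and 6.0.
print(mixed_op(3.0, candidate_ops, [0.0, 0.0, 0.0]))
# After search, the edge keeps only the highest-weighted candidate.
print(derive_op([0.1, -1.0, 2.0]))
```

Because the mixture is differentiable in α, the architecture weights can be trained with gradient descent alongside the network weights.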

The code in this example is implemented by NNI, based on the [official implementation](https://github.com/quark0/darts) and a [popular 3rd-party repo](https://github.com/khanrc/pt.darts). DARTS on NNI is designed to work with arbitrary search spaces and arbitrary datasets. In this use case, a CNN search space tailored for CIFAR10 is implemented to match the original paper.

## Loading the data

In this tutorial we use the CIFAR10 dataset. It will be downloaded into `./data/cifar-10-batches-py` if no local copy is found.

In [None]:
# local helper module of this example (``datasets.py``)
import datasets

dataset_train, dataset_valid = datasets.get_dataset("cifar10")

## Show search space

We use the CNN search space implemented in `./model.py`.

TODO: show the architecture

In [None]:
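from model import CNN

# search setting of the DARTS paper: 16 initial channels, 8 cells
channels, layers = 16, 8
model = CNN(32, 3, channels, 10, layers)  # input size, input channels, initial channels, classes, cells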
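The keys of the architecture exported at the end of this tutorial (for example `normal_n3_p2` and `normal_n3_switch`) suggest the shape of this search space: four intermediate nodes `n2` to `n5` per cell, one operation choice per incoming edge, and a switch that keeps two incoming edges per node. Under that reading, and assuming seven candidate operations per edge (maxpool, avgpool, skipconnect, sepconv3x3, sepconv5x5, dilconv3x3, dilconv5x5; only some of them appear in the exported result), this sketch counts the discrete architectures the relaxation covers. `architectures_per_cell` is a hypothetical helper, not part of `model.py`.

```python
from math import comb

NUM_OPS = 7                       # assumption: number of candidate ops per edge
INTERMEDIATE_NODES = range(2, 6)  # nodes n2..n5, as in the exported keys
INPUTS_PER_NODE = 2               # each node keeps 2 incoming edges (the "_switch" choice)


def architectures_per_cell(num_ops=NUM_OPS):
    """Count discrete cells: choose 2 predecessor edges per node, then one op per kept edge."""
    total = 1
    for node in INTERMEDIATE_NODES:
        # node n_k reads from k predecessors (keys _p0 .. _p{k-1}):
        # the two cell inputs plus the earlier intermediate nodes
        num_predecessors = node
        total *= comb(num_predecessors, INPUTS_PER_NODE) * num_ops ** INPUTS_PER_NODE
    return total


per_cell = architectures_per_cell()
print(per_cell)       # 1037664180 for one cell type
print(per_cell ** 2)  # normal cell and reduction cell are searched independently
```

The count is on the order of 10^18 for the pair of cells, which is why enumerating candidates is impractical and a gradient-based relaxation is attractive.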

## Searching the best architecture

First, we define the loss function, the optimizer, the learning-rate scheduler and the hyperparameters for architecture search.

In [None]:
import torch
import torch.nn as nn

from utils import accuracy

# number of search epochs (the run logged below used 50)
epochs = ...

# batch size
batch_size = ...

# log every `log_frequency` steps
log_frequency = ...

# loss function
criterion = nn.CrossEntropyLoss()

# optimizer for the network weights
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)

# cosine learning-rate schedule over the search epochs
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs, eta_min=0.001)
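The learning-rate schedule above follows the closed form documented for PyTorch's `CosineAnnealingLR` (without restarts): the rate decays from the initial 0.025 to `eta_min` along half a cosine period. A minimal stdlib sketch of that formula; `cosine_annealing_lr` is a hypothetical helper, and `t_max=50` assumes 50 search epochs, as in the log below.

```python
import math


def cosine_annealing_lr(epoch, base_lr=0.025, eta_min=0.001, t_max=50):
    """Closed-form cosine annealing: eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 25, 50):
    print(epoch, round(cosine_annealing_lr(epoch), 6))
# 0 0.025
# 25 0.013
# 50 0.001
```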

Now, let’s use the DARTS algorithm to search for the architecture. Thanks to the continuous relaxation of the architecture, it can be optimized by gradient descent. During training, the trainer alternately updates the network weights and the architecture weights in an end-to-end fashion.
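The alternating scheme can be illustrated with a toy, framework-free sketch of the first-order variant (no unrolling). It assumes a single edge with two hypothetical candidate ops, a learnable linear op `w * x` and a skip connection `x`, noiseless data `y = 3x`, and analytic gradients in place of autograd. Each step first updates the architecture weight `alpha` on the validation loss, then the network weight `w` on the training loss. All names are illustrative; NNI's `DartsTrainer` applies the same idea to real networks.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def predict(x, w, alpha):
    # With two candidate ops, softmax over their weights reduces to sigmoid(alpha) for the linear op.
    p = sigmoid(alpha)
    return p * w * x + (1.0 - p) * x


def loss_and_grads(xs, ys, w, alpha):
    """Mean squared error and its analytic gradients with respect to w and alpha."""
    p = sigmoid(alpha)
    n = len(xs)
    loss = grad_w = grad_alpha = 0.0
    for x, y in zip(xs, ys):
        r = predict(x, w, alpha) - y
        loss += r * r / n
        grad_w += 2.0 * r * p * x / n
        grad_alpha += 2.0 * r * (w - 1.0) * x * p * (1.0 - p) / n
    return loss, grad_w, grad_alpha


def first_order_darts(train, valid, steps=200, lr_w=0.1, lr_alpha=0.1):
    w, alpha = 0.0, 0.0
    for _ in range(steps):
        # 1) architecture step: validation-loss gradient w.r.t. alpha, with w held fixed
        _, _, g_alpha = loss_and_grads(*valid, w, alpha)
        alpha -= lr_alpha * g_alpha
        # 2) weight step: training-loss gradient w.r.t. w, with alpha held fixed
        _, g_w, _ = loss_and_grads(*train, w, alpha)
        w -= lr_w * g_w
    return w, alpha


train = ([1.0, 2.0], [3.0, 6.0])  # y = 3x
valid = ([0.5, 1.5], [1.5, 4.5])
w, alpha = first_order_darts(train, valid)
print("validation loss:", loss_and_grads(*valid, w, alpha)[0])
```

Using separate splits for the two updates mirrors DARTS: the architecture is chosen for how well the trained weights generalize, not for how well they fit the training data.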

In [None]:
from nni.retiarii.oneshot.pytorch import DartsTrainer

# define the Darts Trainer
trainer = DartsTrainer(
    model=model,
    loss=criterion,
    metrics=lambda output, target: accuracy(output, target, topk=(1,)),
    optimizer=optim,
    num_epochs=epochs,
    dataset=dataset_train,
    batch_size=batch_size,
    log_frequency=log_frequency,
    unrolled=False
)

# start training
trainer.fit()
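Only `dataset_train` is passed to the trainer. Assuming it is split into two halves, one for the weight updates and one for the architecture updates (the DARTS paper holds out half of the CIFAR10 training set this way), and that each step consumes one batch from each half, the number of steps per epoch can be checked against the log below. `darts_steps_per_epoch` is a hypothetical helper, and 64 is the batch size that yields the 391 steps shown in the log.

```python
import math


def darts_steps_per_epoch(num_train_samples, batch_size):
    """Steps per epoch when the training set is split in half for weights and architecture."""
    weight_half = num_train_samples // 2
    arch_half = num_train_samples - weight_half
    # one batch from each half per step, so the smaller half bounds the epoch length
    return math.ceil(min(weight_half, arch_half) / batch_size)


print(darts_steps_per_epoch(50000, 64))  # 391 for the 50,000 CIFAR10 training images
```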


In the DARTS paper, the authors also explore a second-order approximation (unrolling) of the architecture gradient instead of the first-order one, to improve performance. To use the second-order approximation, set the `unrolled` parameter to `True`:
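For reference, the DARTS paper approximates the optimal weights by a single unrolled gradient step with step size ξ, differentiates through that step with the chain rule, and evaluates the resulting second-order term with a finite difference (the paper picks a small ε scaled by the norm of the validation gradient). Setting ξ = 0 recovers the first-order update.

```latex
w' = w - \xi \nabla_w \mathcal{L}_{train}(w, \alpha)

\nabla_\alpha \mathcal{L}_{val}(w', \alpha)
  - \xi \, \nabla^2_{\alpha, w} \mathcal{L}_{train}(w, \alpha) \, \nabla_{w'} \mathcal{L}_{val}(w', \alpha)

\nabla^2_{\alpha, w} \mathcal{L}_{train}(w, \alpha) \, \nabla_{w'} \mathcal{L}_{val}(w', \alpha)
  \approx \frac{\nabla_\alpha \mathcal{L}_{train}(w^+, \alpha) - \nabla_\alpha \mathcal{L}_{train}(w^-, \alpha)}{2\epsilon},
\qquad w^\pm = w \pm \epsilon \, \nabla_{w'} \mathcal{L}_{val}(w', \alpha)
```

The finite difference costs two extra forward and backward passes per architecture step, so the unrolled search is slower than the first-order one.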

In [None]:
# define the Darts Trainer
trainer = DartsTrainer(
    model=model,
    loss=criterion,
    metrics=lambda output, target: accuracy(output, target, topk=(1,)),
    optimizer=optim,
    num_epochs=epochs,
    dataset=dataset_train,
    batch_size=batch_size,
    log_frequency=log_frequency,
    unrolled=True  # second-order (unrolled) approximation; False in the first-order run above
)

# start training
trainer.fit()

[2022-01-24 18:10:10] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [1/391]  acc1 0.937500 (0.937500)  loss 0.126625 (0.126625)

[2022-01-24 18:10:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [11/391]  acc1 0.937500 (0.946023)  loss 0.171426 (0.161022)

[2022-01-24 18:10:43] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [21/391]  acc1 0.937500 (0.939732)  loss 0.145051 (0.170434)

[2022-01-24 18:11:00] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [31/391]  acc1 0.953125 (0.943548)  loss 0.224578 (0.165675)

[2022-01-24 18:11:16] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [41/391]  acc1 0.984375 (0.943598)  loss 0.085558 (0.163357)

[2022-01-24 18:11:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [51/391]  acc1 0.921875 (0.942402)  loss 0.185006 (0.161536)

[2022-01-24 18:11:50] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [61/391]  acc1 0.953125 (0.945697)  loss 0.156536 (0.155413)

[2022-01-24 18:12:06] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [71/391]  acc1 0.968750 (0.946963)  loss 0.123197 (0.152378)

[2022-01-24 18:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [81/391]  acc1 0.906250 (0.946373)  loss 0.178447 (0.151914)

[2022-01-24 18:12:39] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [91/391]  acc1 0.953125 (0.946257)  loss 0.134887 (0.150501)

[2022-01-24 18:12:56] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [101/391]  acc1 0.890625 (0.944771)  loss 0.206022 (0.152196)

[2022-01-24 18:13:12] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [111/391]  acc1 1.000000 (0.946227)  loss 0.037209 (0.149334)

[2022-01-24 18:13:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [121/391]  acc1 0.953125 (0.945119)  loss 0.119741 (0.152447)

[2022-01-24 18:13:45] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [131/391]  acc1 0.921875 (0.944656)  loss 0.180555 (0.152467)

[2022-01-24 18:14:02] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [141/391]  acc1 0.921875 (0.945368)  loss 0.146477 (0.150319)

[2022-01-24 18:14:18] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [151/391]  acc1 0.984375 (0.945675)  loss 0.082191 (0.148840)

[2022-01-24 18:14:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [161/391]  acc1 0.875000 (0.945652)  loss 0.240533 (0.150730)

[2022-01-24 18:14:51] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [171/391]  acc1 0.906250 (0.945267)  loss 0.321685 (0.153602)

[2022-01-24 18:15:07] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [181/391]  acc1 0.937500 (0.944751)  loss 0.160337 (0.155523)

[2022-01-24 18:15:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [191/391]  acc1 0.875000 (0.943881)  loss 0.320412 (0.157153)

[2022-01-24 18:15:40] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [201/391]  acc1 0.859375 (0.943719)  loss 0.294951 (0.157609)

[2022-01-24 18:15:56] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [211/391]  acc1 0.937500 (0.943720)  loss 0.165968 (0.157108)

[2022-01-24 18:16:13] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [221/391]  acc1 0.937500 (0.943439)  loss 0.177656 (0.158249)

[2022-01-24 18:16:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [231/391]  acc1 0.984375 (0.943791)  loss 0.108888 (0.157995)

[2022-01-24 18:16:45] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [241/391]  acc1 0.890625 (0.942946)  loss 0.209692 (0.158990)

[2022-01-24 18:17:01] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [251/391]  acc1 0.984375 (0.943103)  loss 0.096665 (0.158933)

[2022-01-24 18:17:18] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [261/391]  acc1 0.906250 (0.942229)  loss 0.240630 (0.161006)

[2022-01-24 18:17:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [271/391]  acc1 0.984375 (0.942286)  loss 0.084514 (0.160753)

[2022-01-24 18:17:50] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [281/391]  acc1 0.890625 (0.941893)  loss 0.317153 (0.162135)

[2022-01-24 18:18:07] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [291/391]  acc1 0.843750 (0.941098)  loss 0.377020 (0.164658)

[2022-01-24 18:18:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [301/391]  acc1 0.921875 (0.941393)  loss 0.171819 (0.164130)

[2022-01-24 18:18:39] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [311/391]  acc1 0.937500 (0.940967)  loss 0.123255 (0.165447)

[2022-01-24 18:18:56] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [321/391]  acc1 0.890625 (0.940275)  loss 0.291045 (0.166802)

[2022-01-24 18:19:12] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [331/391]  acc1 0.953125 (0.940096)  loss 0.201307 (0.167968)

[2022-01-24 18:19:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [341/391]  acc1 0.953125 (0.940112)  loss 0.208557 (0.168172)

[2022-01-24 18:19:45] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [351/391]  acc1 0.906250 (0.939815)  loss 0.225696 (0.169160)

[2022-01-24 18:20:01] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [361/391]  acc1 1.000000 (0.939707)  loss 0.046506 (0.169739)

[2022-01-24 18:20:17] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [371/391]  acc1 0.890625 (0.939522)  loss 0.164111 (0.170325)

[2022-01-24 18:20:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [381/391]  acc1 0.953125 (0.939017)  loss 0.140962 (0.171564)

[2022-01-24 18:20:50] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [50/50] Step [391/391]  acc1 0.900000 (0.939122)  loss 0.208060 (0.171751)

## Visualizing the results and exporting the model

After the search in the CNN search space finishes, the best architecture found is stored in `trainer` and can be exported with:

In [None]:
import json

# export the final architecture found by the trainer
final_architecture = trainer.export()
print('Final architecture:', final_architecture)

# dump the best architecture to a JSON file
with open('checkpoint.json', 'w') as f:
    json.dump(final_architecture, f)

Final architecture: {'normal_n2_p0': 'sepconv3x3', 'normal_n2_p1': 'sepconv3x3', 'normal_n3_p0': 'sepconv3x3', 'normal_n3_p1': 'skipconnect', 'normal_n3_p2': 'skipconnect', 'normal_n4_p0': 'skipconnect', 'normal_n4_p1': 'skipconnect', 'normal_n4_p2': 'skipconnect', 'normal_n4_p3': 'skipconnect', 'normal_n5_p0': 'sepconv3x3', 'normal_n5_p1': 'sepconv3x3', 'normal_n5_p2': 'skipconnect', 'normal_n5_p3': 'skipconnect', 'normal_n5_p4': 'skipconnect', 'reduce_n2_p0': 'maxpool', 'reduce_n2_p1': 'maxpool', 'reduce_n3_p0': 'maxpool', 'reduce_n3_p1': 'maxpool', 'reduce_n3_p2': 'dilconv5x5', 'reduce_n4_p0': 'maxpool', 'reduce_n4_p1': 'maxpool', 'reduce_n4_p2': 'skipconnect', 'reduce_n4_p3': 'skipconnect', 'reduce_n5_p0': 'maxpool', 'reduce_n5_p1': 'avgpool', 'reduce_n5_p2': 'skipconnect', 'reduce_n5_p3': 'skipconnect', 'reduce_n5_p4': 'skipconnect', 'normal_n2_switch': [0, 1], 'normal_n3_switch': [0, 1], 'normal_n4_switch': [0, 1], 'normal_n5_switch': [0, 1], 'reduce_n2_switch': [0, 1], 'reduce_n3_switch': [0, 1], 'reduce_n4_switch': [0, 1], 'reduce_n5_switch': [0, 1]}
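The exported dictionary is flat. Assuming the layout suggested by its keys (each `<cell>_n<node>_p<i>` holds the operation on the edge from predecessor `i` to that node, and `<cell>_n<node>_switch` lists the predecessors the node keeps), a small hypothetical helper can regroup it into a per-cell summary. The excerpt below copies a few entries from the output above.

```python
import re
from collections import defaultdict

SWITCH_KEY = re.compile(r"(normal|reduce)_n(\d+)_switch")


def summarize_cells(arch):
    """Regroup a flat exported architecture into {cell: {node: [(input_index, op), ...]}}."""
    cells = defaultdict(dict)
    for key, chosen_inputs in arch.items():
        match = SWITCH_KEY.fullmatch(key)
        if match is None:
            continue  # per-edge op keys are looked up below
        cell, node = match.group(1), int(match.group(2))
        # keep only the edges selected by the switch, paired with their operation
        cells[cell][node] = [(i, arch[f"{cell}_n{node}_p{i}"]) for i in chosen_inputs]
    return {cell: dict(sorted(nodes.items())) for cell, nodes in cells.items()}


arch_excerpt = {
    'normal_n3_p0': 'sepconv3x3', 'normal_n3_p1': 'skipconnect', 'normal_n3_p2': 'skipconnect',
    'normal_n3_switch': [0, 1],
    'reduce_n3_p0': 'maxpool', 'reduce_n3_p1': 'maxpool', 'reduce_n3_p2': 'dilconv5x5',
    'reduce_n3_switch': [0, 1],
}
print(summarize_cells(arch_excerpt))
```

In the notebook, `summarize_cells(final_architecture)` would apply the same regrouping to the full export.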

## Retrain the model with searched architecture

Now that we have the final architecture from the previous step, we can evaluate it by training it from scratch. To load the architecture, run:

In [None]:
from nni.retiarii import fixed_arch

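# load the architecture from checkpoint.json with ``fixed_arch`` and apply it to the model;
# evaluation setting of the DARTS paper: 36 initial channels, 20 cells, auxiliary head
with fixed_arch('checkpoint.json'):
    model = CNN(32, 3, 36, 10, 20, auxiliary=True)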


Then create an evaluator and train the model:

In [None]:
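import nni.retiarii.evaluator.pytorch as pl

cls = pl.Classification(
    learning_rate=...,
    train_dataloader=pl.DataLoader(dataset_train, batch_size=96, shuffle=True),  # batch size 96 as in the DARTS paper
    val_dataloaders=pl.DataLoader(dataset_valid, batch_size=96),
)
cls.fit(model)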

## Performance of DARTS in NNI

TODO: run the experiments and document the results here.

...

## Todo List

- visualization via TensorBoard
    - metrics
    - current architecture ([reference](https://github.com/quark0/darts))
- checkpointing: model weights, architecture
- monitor logs