## Quantization-Aware Basecalling Neural Architecture Search (QABAS)

We use neural architecture search (NAS) to explore the design options for a basecaller. Specifically, we use differentiable NAS (DNAS), a weight-sharing approach in which only one supernetwork is trained and a sub-network is then distilled out of it. The search space consists of all candidate options for the model; for RUBICON it is defined in arch/basemodelquant.py.

In [1]:
import os
import sys
from argparse import ArgumentParser 
from argparse import ArgumentDefaultsHelpFormatter
from pathlib import Path
from importlib import import_module
import torch.nn as nn
from os import system
from bonito.data import load_numpy
from rubicon.data import load_numpy_shuf,load_numpy_full
from bonito.data import load_script
from rubicon.util import __models__, default_data
from bonito.util import load_symbol, init
from rubicon.training import load_state, Trainer
import json
import toml
import torch
import numpy as np
from torch.utils.data import DataLoader
from rubicon.tools.nni.nni.retiarii.nn.pytorch.api import LayerChoice, InputChoice
from rubicon.nas.dartsbasecalling import DartsBasecalling
from rubicon.nas.proxylessbasecalling import ProxylessBasecalling
import torch.onnx
import time
import datetime
import logging
# Raise the interpreter's maximum recursion depth to 20000.
# NOTE(review): the reason is not visible in this notebook — presumably some
# later step recurses deeply; confirm before lowering or removing.
sys.setrecursionlimit(20000)

# Logger named after the current module; used by every cell below.
_logger = logging.getLogger(__name__)
def get_parameters(model, keys=None, mode='include'):
    """Lazily yield parameters of ``model``, optionally filtered by name.

    Args:
        model: Object exposing ``named_parameters()``, which yields
            ``(name, parameter)`` pairs.
        keys: Collection of substrings tested against each parameter name
            with ``in``. When ``None``, no filtering happens and ``mode`` is
            ignored.
        mode: ``'include'`` yields the parameters whose name contains at
            least one key; ``'exclude'`` yields those whose name contains
            none of the keys.

    Yields:
        The selected parameter objects, in ``named_parameters()`` order.

    Raises:
        ValueError: If ``keys`` is given and ``mode`` is neither
            ``'include'`` nor ``'exclude'``. Because this is a generator,
            the error surfaces on first iteration, not at call time.
    """
    if keys is None:
        yield from (param for _, param in model.named_parameters())
        return

    if mode not in ('include', 'exclude'):
        raise ValueError('do not support: %s' % mode)

    # A parameter is emitted when "its name matches some key" agrees with
    # what the mode asks for (match for include, no match for exclude).
    want_match = mode == 'include'
    for param_name, param in model.named_parameters():
        has_match = any(key in param_name for key in keys)
        if has_match == want_match:
            yield param

In [2]:
# Run configuration: every setting read by the search cell is defined here.

# --- output locations -------------------------------------------------------
save_directory = "temp"
workdir = os.path.expanduser(save_directory)  # directory handed to trainer.fit
arc_checkpoint = "final_arch.json"            # JSON file for the exported architecture

# --- inputs -----------------------------------------------------------------
config = "../rubicon/models/configs/config.toml"  # TOML model configuration
directory = "../rubicon/data/dna_r9.4.1"          # training-data directory
full = False        # True selects load_numpy_full, otherwise load_numpy_shuf
chunks = 128        # forwarded to load_numpy_shuf
valid_chunks = 128  # forwarded to load_numpy_shuf

# --- search strategy --------------------------------------------------------
nas = 'proxy'          # 'darts' or 'proxy': selects the trainer class
hardware = 'aie_lut'   # forwarded as applied_hardware
# NOTE(review): kept as a string exactly as before; confirm the trainer
# accepts this type for ref_latency.
reference_latency = "65"
grad_reg_loss_type = "add#linear"
grad_reg_loss_lambda = 6e-1

# --- optimisation -----------------------------------------------------------
epochs = 5
lr = 2e-3    # passed to trainer.fit (and as opt_learning_rate for DARTS)
ctlr = 2e-3  # passed as ctrl_learning_rate
rub_sched = True
dart_sched = True
rub_arch_opt = True
prox_arch_opt = True

# --- miscellaneous ----------------------------------------------------------
seed = 25
rubicon = True
default = True
device = "CPU"
# !which python
# assert(torch.cuda.is_available())

In [None]:
# Run the architecture search end to end: prepare the working directory,
# load the model configuration and data, build the trainer selected by
# ``nas``, fit it, and write the exported architecture to ``arc_checkpoint``.
import shutil  # needed only by this cell, for recursive directory removal

_logger.info("Start date and time:{}".format(datetime.datetime.now()))

# Start from an empty working directory. shutil.rmtree is used because
# os.rmdir raises OSError as soon as the directory contains anything.
if os.path.exists(workdir):
    print("[error] %s exists. Removing." % workdir)
    shutil.rmtree(workdir)

os.makedirs(workdir, exist_ok=True)
# init(seed, device)
# device = torch.device(device)

# ``config`` holds the TOML path on entry and is rebound to the parsed
# dictionary below, so later code sees the parsed configuration.
config_file = config
if not os.path.exists(config_file):
    print("[error] %s does not exist" % config_file)
    sys.exit(1)
config = toml.load(config_file)

# Reject unknown search types up front; otherwise ``trainer`` would never be
# assigned and the failure would only show up as a NameError at fit time.
if nas not in ("darts", "proxy"):
    _logger.warning("Please specify which type of NAS using --nas argument")
    sys.exit(1)

_logger.info("[loading model]")
model = load_symbol(config, 'BaseModelQuant')(config)

_logger.info("NAS type:{}".format(nas))
_logger.info("Hardware type:{}".format(hardware))
_logger.info("Reference latency:{}".format(reference_latency))
_logger.info("lambda:{}".format(grad_reg_loss_lambda))
_logger.info("[loading data]")
if full:
    _logger.info("Full dataset training")
    train_loader_kwargs, valid_loader_kwargs = load_numpy_full(None, directory)
elif chunks:
    _logger.info("Not full dataset training with shuffling")
    train_loader_kwargs, valid_loader_kwargs = load_numpy_shuf(
        chunks, valid_chunks, directory
    )
else:
    _logger.warning("Please define the training data correctly")
    sys.exit(1)

# NOTE(review): the notebook never defines a batch size (the script version
# read it from ``args.batch``); 64 is an assumed default — confirm.
batch = 64
loader_kwargs = {
    "batch_size": batch, "num_workers": 8, "pin_memory": True
}
train_loader = DataLoader(**loader_kwargs, **train_loader_kwargs)
valid_loader = DataLoader(**loader_kwargs, **valid_loader_kwargs)

# Parameters of the latency regularisation term. The "lambda" key for
# "add#linear" follows NNI's ProxylessNAS trainer convention.
# NOTE(review): confirm the rubicon trainers read the same key; other loss
# types need parameters this notebook does not define, so None is passed.
if grad_reg_loss_type == "add#linear":
    grad_reg_loss_params = {"lambda": grad_reg_loss_lambda}
else:
    grad_reg_loss_params = None

# NOTE(review): ``accuracy`` is neither defined nor imported in the visible
# notebook; the ``metrics`` lambdas below raise NameError if a trainer ever
# calls them. Import it from the right module once its location is known.

if nas == 'darts':
    #### setting optimizer #######
    optimizer = None
    _logger.info("Starting DARTS NAS")

    #### setting lr scheduler #######
    _logger.info("Scheduler: Linear Warmup")
    sched_config = config.get("lr_scheduler")
    if sched_config:
        # Build the scheduler once and print that same object.
        lr_scheduler_fn = getattr(
            import_module(sched_config["package"]), sched_config["symbol"]
        )(**sched_config)
        print("building scheduler", lr_scheduler_fn)
    else:
        print("no scheduler")
        lr_scheduler_fn = None

    trainer = DartsBasecalling(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        optimizer=optimizer,
        lr_scheduler_fn=lr_scheduler_fn,
        ctrl_learning_rate=ctlr,
        opt_learning_rate=lr,
        applied_hardware=hardware,
        metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
        log_frequency=10,
        grad_reg_loss_type=grad_reg_loss_type,
        grad_reg_loss_params=grad_reg_loss_params,
        dummy_input=(344, 1, 9),
        ref_latency=reference_latency,
        rubicon=rubicon,
        default=default,
        num_epochs=epochs,
    )

elif nas == 'proxy':
    #### setting optimizer STEP 2 UPDATE WEIGHTS #######
    optimizer = None
    _logger.info("Starting ProxylessNAS")
    lr_scheduler_fn = None

    trainer = ProxylessBasecalling(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        optimizer=optimizer,
        lr_scheduler_fn=lr_scheduler_fn,
        ctrl_learning_rate=ctlr,
        applied_hardware=hardware,
        metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
        log_frequency=10,
        grad_reg_loss_type=grad_reg_loss_type,
        grad_reg_loss_params=grad_reg_loss_params,
        dummy_input=(344, 1, 9),
        ref_latency=reference_latency,
        rubicon=rubicon,
        default=default,
        num_epochs=epochs,
        rub_sched=rub_sched,
        dart_sched=dart_sched,
        # The parameter cell names these flags ``*_arch_opt``; the trainer
        # keywords are ``*_ctrl_opt``.
        rub_ctrl_opt=rub_arch_opt,
        prox_ctrl_opt=prox_arch_opt,
    )

trainer.fit(workdir, epochs, lr)
final_architecture = trainer.export()
_logger.info("Final architecture:{}".format(final_architecture))

# Write the exported architecture to the JSON checkpoint. The same expanded
# path is used for writing and for the log line so the two cannot disagree.
arc_path = os.path.expanduser(arc_checkpoint)
with open(arc_path, "w") as out_file:
    json.dump(final_architecture, out_file, indent=6)

_logger.info("JSON file saved at:{}".format(arc_path))
_logger.info("End date and time:{}".format(datetime.datetime.now()))

11/07/2023 06:55:50 AM [INFO] Start date and time:2023-11-07 06:55:50.655266


[2023-11-07 06:55:50] INFO (__main__/MainThread) Start date and time:2023-11-07 06:55:50.655266
[2023-11-07 06:55:50] INFO (__main__/MainThread) Start date and time:2023-11-07 06:55:50.655266
[error] temp exists. Removing.


11/07/2023 06:55:50 AM [INFO] [loading model]


[2023-11-07 06:55:50] INFO (__main__/MainThread) [loading model]
[2023-11-07 06:55:50] INFO (__main__/MainThread) [loading model]
BaseModelQuant model


11/07/2023 06:55:56 AM [INFO] NAS type:proxy


[2023-11-07 06:55:56] INFO (__main__/MainThread) NAS type:proxy
[2023-11-07 06:55:56] INFO (__main__/MainThread) NAS type:proxy


11/07/2023 06:55:56 AM [INFO] Hardware type:aie_lut


[2023-11-07 06:55:56] INFO (__main__/MainThread) Hardware type:aie_lut
[2023-11-07 06:55:56] INFO (__main__/MainThread) Hardware type:aie_lut


11/07/2023 06:55:56 AM [INFO] Reference latency:65


[2023-11-07 06:55:56] INFO (__main__/MainThread) Reference latency:65
[2023-11-07 06:55:56] INFO (__main__/MainThread) Reference latency:65


11/07/2023 06:55:56 AM [INFO] lambda:0.6


[2023-11-07 06:55:56] INFO (__main__/MainThread) lambda:0.6
[2023-11-07 06:55:56] INFO (__main__/MainThread) lambda:0.6


11/07/2023 06:55:56 AM [INFO] [loading data]


[2023-11-07 06:55:56] INFO (__main__/MainThread) [loading data]
[2023-11-07 06:55:56] INFO (__main__/MainThread) [loading data]


11/07/2023 06:55:56 AM [INFO] Not full dataset training with shuffling


[2023-11-07 06:55:56] INFO (__main__/MainThread) Not full dataset training with shuffling
[2023-11-07 06:55:56] INFO (__main__/MainThread) Not full dataset training with shuffling


11/07/2023 06:55:56 AM [INFO] Dataset length: 128/1221470


[2023-11-07 06:55:56] INFO (rubicon.data/MainThread) Dataset length: 128/1221470
[2023-11-07 06:55:56] INFO (rubicon.data/MainThread) Dataset length: 128/1221470
