## Skip Removal with Knowledge Distillation

SkipClip performs gradual skip-connection removal with knowledge distillation (KD). KD is a model compression technique in which a smaller, shallower model (the student) learns to mimic a larger pre-trained model (the teacher) by transferring learned knowledge and label representations from the teacher to the student. SkipClip uses a pre-trained, over-parameterized model as the teacher, which is not updated during the training of the student network. We achieve skip removal by letting the teacher teach the student to perform well on basecalling: at the start of every training epoch, SkipClip removes the skip connection of one block, starting from the input side, while continuing to perform KD. This is repeated until all skip connections have been removed from the student network. SkipClip thus gets the best of both worlds: a highly accurate and topologically regular neural network without skip connections.

Download a pre-trained model to use as the teacher:
``` bash
$ cd models
$ bash download_teacher.sh
```

In [1]:
import os
import sys
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter
from pathlib import Path
from importlib import import_module
import logging
from rubicon.data import load_numpy_shuf,load_numpy_full
from bonito.data import load_script
from bonito.util import load_symbol, init,load_model
from rubicon.util import __models__, default_config, default_data
from rubicon.util import load_model_prune, load_model_prune_for_kd
from rubicon.kdtraining import load_state, KDTrainer
import toml
import torch
import numpy as np
from torch.utils.data import DataLoader
from contextlib import redirect_stdout
import subprocess
import shutil
from prettytable import PrettyTable
import logging
from rubicon.basemodule.prune import count_parameters
_logger = logging.getLogger(__name__)

No ROCm runtime is found, using ROCM_HOME='/opt/rocm-5.1.0'


In [2]:
import os
import sys
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter
from pathlib import Path
from importlib import import_module
import logging
from rubicon.data import load_numpy_shuf,load_numpy_full
from bonito.data import load_script
from bonito.util import load_symbol, init,load_model
from rubicon.util import __models__, default_config, default_data
from rubicon.util import load_model_prune, load_model_prune_for_kd
from rubicon.kdtraining import load_state, KDTrainer
import toml
import torch
import numpy as np
from torch.utils.data import DataLoader
from contextlib import redirect_stdout
import subprocess
import shutil
from prettytable import PrettyTable
import logging
from rubicon.basemodule.prune import count_parameters
_logger = logging.getLogger(__name__)

In [3]:
# --- SkipClip experiment configuration -------------------------------------
# Every name below is read by the training cell that follows.

# Knowledge-distillation settings, forwarded to KDTrainer as temp= / alpha=.
# (Presumably the softmax temperature and the soft/hard loss weighting --
# confirm in rubicon.kdtraining.)
temp = 1
alpha = 0.1

# Pruning settings; the training cell below does not read these.
structural = False
l1_scoring = True
prune_dim = 0
prune_proportion = 0.1

# Output location; an existing directory is deleted first when force=True.
save_directory = "skip_results"
workdir = os.path.expanduser(save_directory)
onnx_name = str(prune_proportion)
force = True

# Student / teacher models.
pretrained = ""          # non-empty -> resume the student from this model directory
teacher = True
teacher_directory = "../rubicon/models/bonito"
# NOTE: `type` shadows the builtin; the name is kept because the training
# cell reads it.  Only "rubiconqabas-mp" can build a new student there.
type = "rubiconqabas-mp"
# Only used when `pretrained` is set (forwarded to load_model_prune_for_kd).
stud_type = type
no_prune = True          # assumed: skip removal only, no weight pruning -- confirm

seed = 25
config = "../rubicon/models/configs/config.toml"
device = "cuda"
quant = False

# Training data.
directory = default_data  # training-data directory (rubicon.util.default_data)
full = False
chunks = 128
valid_chunks = 128
batch = 128

# Optimisation.
restore_optim = False
save_optim_every = 10
grad_accum_split = 1
epochs = 10
lr = 2e-3
# Forwarded to KDTrainer.fit(); assumed to be the number of epochs between
# successive skip-connection removals (1 = one per epoch) -- confirm.
skip_stride = 1

In [None]:
# --- SkipClip: gradual skip-connection removal with knowledge distillation --
# Uses the settings from the configuration cell (temp, alpha, save_directory,
# force, pretrained, teacher_directory, seed, config, device, type, full,
# chunks, valid_chunks, batch, restore_optim, save_optim_every,
# grad_accum_split, epochs, lr).

# Optional settings: honour values defined earlier in the notebook, else
# fall back to defaults so this cell runs on its own.
skip_stride = globals().get("skip_stride", 1)  # assumed: one removal per epoch
directory = globals().get("directory") or default_data
stud_type = globals().get("stud_type", type)     # only used with `pretrained`
no_prune = globals().get("no_prune", True)       # only used with `pretrained`

# Fail fast, before deleting anything or loading models: a new student can
# only be built for this model type.
if not pretrained and type != "rubiconqabas-mp":
    raise ValueError(
        "Unsupported student type {!r}: set type='rubiconqabas-mp' or "
        "provide a pretrained student.".format(type)
    )

workdir = os.path.expanduser(save_directory)
_logger.info("Save path: {}".format(workdir))

if os.path.exists(workdir):
    if not force:
        # Raise instead of exit(): exit() would shut down the notebook kernel.
        raise FileExistsError(
            "[error] %s exists, remove it to continue or set force=True to delete it."
            % workdir
        )
    shutil.rmtree(workdir)

init(seed, device)
# Keep `device` as the user's setting so re-running the cell passes the same
# value to init(); use a separate name for the torch.device object.
torch_device = torch.device(device)

# Resolve the model config file.  The parsed dict gets its own name so that
# `config` (the path) is not overwritten and the cell can be re-run.
if pretrained:
    dirname = pretrained
    if not os.path.isdir(dirname) and os.path.isdir(os.path.join(__models__, dirname)):
        dirname = os.path.join(__models__, dirname)
    config_file = os.path.join(dirname, "config.toml")
else:
    config_file = config
model_config = toml.load(config_file)
os.makedirs(workdir, exist_ok=True)

teacher_path = os.path.expanduser(teacher_directory)
_logger.info("Loading Teacher Model:{}".format(teacher_path))
model_teacher = load_model(teacher_path, torch_device, half=False)

# Build the student.  Anything it prints to stdout (including the
# count_parameters summary) is written to model_student.txt.
# redirect_stdout restores sys.stdout when the block exits, even on error.
with open(os.path.join(workdir, "model_student.txt"), "w") as f, redirect_stdout(f):
    if pretrained:
        _logger.info("Using Pretrained Student:{}".format(pretrained))
        model_student = load_model_prune_for_kd(
            pretrained, torch_device, half=False,
            load_model_type=stud_type, no_prune=no_prune,
        )
    else:
        # type == "rubiconqabas-mp" was validated at the top of the cell.
        _logger.info("Training a new student:{}".format(type))
        model_student = load_symbol(model_config, "RubiconSkipTrim")(model_config)
    count_parameters(model_student)

# Optional LR scheduler factory, resolved from the model config by
# package/symbol name.
if model_config.get("lr_scheduler"):
    sched_config = model_config["lr_scheduler"]
    lr_scheduler_fn = getattr(
        import_module(sched_config["package"]), sched_config["symbol"]
    )(**sched_config)
else:
    lr_scheduler_fn = None

_logger.info("Total parameters in model teacher:{}".format(sum(p.numel() for p in model_teacher.parameters())))
_logger.info("Total parameters in model student:{}".format(sum(p.numel() for p in model_student.parameters())))

_logger.info("Loading Data")
if full:
    _logger.info("Full dataset training")
    train_loader_kwargs, valid_loader_kwargs = load_numpy_full(None, directory)
elif chunks:
    _logger.info("Not full dataset training with shuffling")
    # assumes load_numpy_shuf(chunks, valid_chunks, directory) -- TODO confirm signature
    train_loader_kwargs, valid_loader_kwargs = load_numpy_shuf(
        chunks, valid_chunks, directory
    )
else:
    _logger.warning("Please define the training data correctly")
    raise ValueError("Please define the training data correctly")

loader_kwargs = {"batch_size": batch, "num_workers": 1, "pin_memory": True}
train_loader = DataLoader(**loader_kwargs, **train_loader_kwargs)
valid_loader = DataLoader(**loader_kwargs, **valid_loader_kwargs)

_logger.info("Starting SkipTrim")

# The teacher stays frozen; the student learns from it while skip
# connections are removed during fit().
trainer = KDTrainer(
    model_teacher, model_student, torch_device, train_loader, valid_loader,
    use_amp=False,
    lr_scheduler_fn=lr_scheduler_fn,
    restore_optim=restore_optim,
    save_optim_every=save_optim_every,
    grad_accum_split=grad_accum_split,
    temp=temp,
    alpha=alpha,
)

trainer.fit(workdir, epochs, lr, skip_stride)

11/07/2023 08:33:00 AM [INFO] Save path: skip_results


[2023-11-07 08:33:00] INFO (__main__/MainThread) Save path: skip_results
[2023-11-07 08:33:00] INFO (__main__/MainThread) Save path: skip_results


11/07/2023 08:33:00 AM [INFO] Loading Teacher Model:../rubicon/models/bonito


[2023-11-07 08:33:00] INFO (__main__/MainThread) Loading Teacher Model:../rubicon/models/bonito
[2023-11-07 08:33:00] INFO (__main__/MainThread) Loading Teacher Model:../rubicon/models/bonito
