In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.basics import *
import json
from tqdm import tqdm

import jkbc.model as m
import jkbc.utils.constants as constants
import jkbc.utils.torch_files as f
import jkbc.utils.general as g
import jkbc.utils.metrics as metric
import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.utils.fasta as fasta

In [3]:
# Initialise random libs and setup cudnn
random_seed = 25 # MAGIC!!
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Constants

### Data

In [4]:
BASE_DIR = Path("../../..")
PATH_DATA = 'data/feather-files'
DATA_SET = 'Range0-50-FixLabelLen400-winsize4096'
FEATHER_FOLDER = BASE_DIR/PATH_DATA/DATA_SET

with open(FEATHER_FOLDER/'config.json', 'r') as fp:
    config = json.load(fp)

ALPHABET       = constants.ALPHABET
ALPHABET_VAL   = list(ALPHABET.values())
ALPHABET_STR   = ''.join(ALPHABET_VAL)
ALPHABET_SIZE  = len(ALPHABET.keys())
WINDOW_SIZE    = int(config['maxw']) #maxw = max windowsize
DIMENSIONS_OUT = int(config['maxl']) # maxl = max label length
STRIDE         = WINDOW_SIZE

KNOWLEGDE_DISTILLATION = True
TEACHER_OUTPUT = 'bonito-pretrained-Valid[3.625368118286133]-CTC[90.3227304562121]' # Set to name of y_teacher output
if KNOWLEGDE_DISTILLATION and not TEACHER_OUTPUT:
    print('WARNING! Must provide name of teacher output when doing knowledge distillation')

In [5]:
METRICS = [metric.ctc_accuracy(ALPHABET, 5)]
SAVE_CALLBACK = partial(metric.SaveModelCallback, every='epoch', monitor='valid_loss')

### Train/Predict

In [6]:
LR = 1e-3  # default learning rate
BS = 2**6  # batch size
EPOCHS = 20
DEVICE = torch.device("cuda:0") #torch.device("cpu")

### Model

In [7]:
import bonito_basic as model_file
DIMENSIONS_PREDICTION_OUT = WINDOW_SIZE//3+1
DROP_LAST = False # SET TO TRUE IF IT FAILS ON LAST BATCH

SPECIFIC_MODEL_WEIGHTS = None

## Load data

In [8]:
# Read data from feather
if KNOWLEGDE_DISTILLATION:
    data, teacher = f.load_training_data_with_teacher(FEATHER_FOLDER, TEACHER_OUTPUT)
    train_dl, valid_dl = prep.convert_to_dataloaders(data, split=.8, batch_size=BS, teacher=teacher, drop_last=DROP_LAST)
else:
    data = f.load_training_data(FEATHER_FOLDER) 
    train_dl, valid_dl = prep.convert_to_dataloaders(data, split=.8, batch_size=BS, drop_last=DROP_LAST)

# Convert to databunch
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)

## Model

In [9]:
_ctc_loss = metric.CtcLoss(WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT, BS, ALPHABET_SIZE)

loss_funcs = {}
for t in [1,2,4,8,16,32]:
    for a in np.arange(0,1.1,0.1):
        loss_funcs[f't={t},a={a}'] = metric.KdLoss(alpha=a, temperature=t, label_loss=_ctc_loss).loss()

In [10]:
optimizers = {'AdamW': partial(torch.optim.AdamW, amsgrad=True, lr=LR)}

In [11]:
## Model_name, ctc_accuracy, loss_function, optimizer
models = [partial(model_file.model, DEVICE, WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT)]
with open('hyper-output', 'w') as f:
    for model in models:
        for l_key, loss in loss_funcs.items():
            for o_key, optim in optimizers.items():
                m, MODEL_NAME = model()
                MODEL_NAME = f'{MODEL_NAME}-windowsize={WINDOW_SIZE}'
                MODEL_DIR = f'weights/{MODEL_NAME}'

                # Create learner
                learner = Learner(databunch, m, loss_func=loss, model_dir=MODEL_DIR, metrics=METRICS, opt_func=optim)

                # FIT
                learner.fit(1, lr=LR, callbacks=[SAVE_CALLBACK(learner)])
                f.write(MODEL_NAME, learner.validate()[1], l_key, o_key, '\n')
                del m

epoch,train_loss,valid_loss,ctc_accuracy,time
0,1.473975,1.446206,0.677583,00:32


bonito-windowsize=4096 0.6775834682087424 t=1,a=0.0 AdamW


epoch,train_loss,valid_loss,ctc_accuracy,time
0,11.340686,8.586564,0.671101,00:32


bonito-windowsize=4096 0.6711011558265813 t=1,a=0.1 AdamW


epoch,train_loss,valid_loss,ctc_accuracy,time
0,18.557541,13.768038,0.658321,00:32


bonito-windowsize=4096 0.658321344790943 t=1,a=0.2 AdamW


epoch,train_loss,valid_loss,ctc_accuracy,time


KeyboardInterrupt: 