In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.basics import *
import json
import toml
from tqdm import tqdm
import wandb
from wandb.fastai import WandbCallback

import jkbc.model as m
import jkbc.model.factory as factory
import jkbc.constants as constants
import jkbc.files.torch_files as f
import jkbc.model.metrics as metric
import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.files.fasta as fasta

## Constants

### Data

In [5]:
BASE_DIR = Path("../..")
PATH_DATA = 'data/feather-files'
DATA_SET = 'Range0-5-FixLabelLen400-winsize4096'
DATA_SET = BASE_DIR/PATH_DATA/DATA_SET

with open(DATA_SET/'config.json', 'r') as fp:
    config = json.load(fp)

ALPHABET          = constants.ALPHABET
ALPHABET_VAL      = list(ALPHABET.values())
ALPHABET_STR      = ''.join(ALPHABET_VAL)
ALPHABET_SIZE     = len(ALPHABET.keys())
WINDOW_SIZE       = int(config['maxw']) #maxw = max windowsize
DIMENSIONS_OUT    = int(config['maxl']) # maxl = max label length
EXCLUDED_BACTERIA = ''.join(config['exclude']) # e = list of excluded bacteria

KNOWLEDGE_DISTILLATION = True
TEACHER_NAME = 'bonito' # Set to name of y_teacher output
if KNOWLEDGE_DISTILLATION and not TEACHER_NAME:
    print('WARNING! Must provide name of teacher output when doing knowledge distillation')

### Train/Predict

In [6]:
LR = 1e-3  # default learning rate
BS = 2**6  # batch size
EPOCHS = 400
DEVICE = m.get_available_gpu() #torch.device("cpu")
DEVICE

device(type='cuda', index=0)

### Model

In [7]:
DROP_LAST = False # SET TO TRUE IF IT FAILS ON LAST BATCH
MODEL_DEFINITION = "quartznet5x5.toml"
model, (MODEL_NAME, pred_dim_out) = factory.bonito(WINDOW_SIZE, DEVICE, MODEL_DEFINITION)
MODEL_BASE_DIR = MODEL_NAME
EXPERIMENT = f'i-am-random3'
MODEL_DIR = f'{MODEL_BASE_DIR}/{EXPERIMENT}'
WEIGHTS_DIR = f'{MODEL_DIR}/weights'
model_weights = None#'bestmodel_8'

In [17]:
config = dict(
    window_size = WINDOW_SIZE,
    dimensions_out = DIMENSIONS_OUT,
    excluded_bacteria = EXCLUDED_BACTERIA,
    teacher_name = TEACHER_NAME,
    model_name = MODEL_NAME,
    model_definition = MODEL_DEFINITION,
    pretrained_weights = model_weights,
    knowledge_distillation = KNOWLEDGE_DISTILLATION,
    epochs = EPOCHS,
    batch_size = BS,
    learning_rate = LR,
    dropout = .0,
    weight_decay = .1,
    momentum = .0,
    optimizer = 'AdamW',
    schedular = 'one_cycle',
    kd_temperature = 20,
    kd_alpha = 0.5,
    alphabet = ALPHABET,
    drop_last = DROP_LAST,
    device = DEVICE,
    data_set = DATA_SET
)
wandb.init(config=config)
config = wandb.config

ALPHABET          = config.alphabet
ALPHABET_VAL      = list(ALPHABET.values())
ALPHABET_STR      = ''.join(ALPHABET_VAL)
ALPHABET_SIZE     = len(ALPHABET.keys())

In [18]:
#m.save_setup(MODEL_DIR, config) 
# TODO

In [19]:
# Run to get newest model
if not model_weights:
    model_weights = m.get_newest_model(WEIGHTS_DIR)
print(model_weights)

None


### Loss, metrics and callback

In [20]:
_ctc_loss = metric.CtcLoss(config.window_size, pred_dim_out, config.batch_size, ALPHABET_SIZE)
_kd_loss = metric.KdLoss(alpha=config.kd_alpha, temperature=config.kd_temperature, label_loss=_ctc_loss)
loss = _kd_loss.loss() if config.knowledge_distillation else _ctc_loss.loss()

metrics = [metric.ctc_accuracy(ALPHABET, 5)]

## Load data

In [21]:
# Read data from feather
if config.knowledge_distillation:
    data, teacher = f.load_training_data_with_teacher(config.data_set, config.teacher_name)
    train_dl, valid_dl = prep.convert_to_dataloaders(data, split=.8, batch_size=config.batch_size, teacher=teacher, drop_last=config.drop_last)
else:
    data = f.load_training_data(config.data_set) 
    train_dl, valid_dl = prep.convert_to_dataloaders(data, split=.8, batch_size=config.batch_size, drop_last=config.drop_last)

# Convert to databunch
databunch = DataBunch(train_dl, valid_dl, device=config.device)

## Model

In [22]:
#_, optimizer = m.optimizer(config)
#_, scheduler = m.scheduler(config)
optimizer = partial(torch.optim.AdamW, amsgrad=True, lr=config.learning_rate)

In [23]:
learner = Learner(databunch, model, loss_func=loss, metrics=metrics, opt_func=optimizer, callback_fns=WandbCallback).to_fp16()
#scheduler(learner)

## Train

In [24]:
learner.fit(config.epochs, lr=config.learning_rate, wd=config.weight_decay)

epoch,train_loss,valid_loss,ctc_accuracy,time
0,11581.475586,14188.15332,0.626304,00:02
1,12461.088867,20753.439453,0.639057,00:02
2,14062.364258,18198.173828,0.602925,00:02
3,14124.158203,14130.328125,0.606296,00:02
4,13807.326172,13660.248047,0.601344,00:02
5,13258.286133,13723.455078,0.62662,00:02
6,12561.811523,14234.560547,0.666836,00:02
7,11774.161133,15316.217773,0.673786,00:02
8,11060.816406,16215.998047,0.656721,00:02
9,10372.283203,16508.439453,0.660305,00:02


Better model found at epoch 0 with valid_loss value: 14188.1533203125.
Better model found at epoch 3 with valid_loss value: 14130.328125.
Better model found at epoch 4 with valid_loss value: 13660.248046875.
Loaded best saved model from /user/student.aau.dk/jfraus14/basecaller-p10/nbs/models/wandb/run-20200417_102040-228npz6z/bestmodel.pth


KeyboardInterrupt: 