In [None]:
# Notebook setup: (re)load the autoreload extension in mode 2 so edited
# modules (e.g. the jkbc.utils helpers) are re-imported before each cell runs,
# and render matplotlib figures inline in the notebook.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from fastai.basics import *

import jkbc.utils.general as g
import jkbc.utils.files as f
import jkbc.utils.metrics as metric
import jkbc.utils.postprocessing as pop
import jkbc.utils.preprocessing as prep

## Constants

### Data

In [None]:
# Alphabet definition is taken from the postprocessing module so that
# training and decoding agree on the symbol <-> index mapping.
BLANK_ID = pop.BLANK_ID
ALPHABET = pop.ALPHABET  # mapping of index -> symbol, e.g. {0: '-', 1: 'A', 2: 'B'}
ALPHABET_VAL = [*ALPHABET.values()]   # symbols in mapping order
ALPHABET_STR = ''.join(ALPHABET_VAL)  # same symbols as one string
ALPHABET_SIZE = len(ALPHABET_VAL)

# Windowing of the input and the fixed label length of the data set.
WINDOW_SIZE = 300
STRIDE = 300
DIMENSIONS_OUT = 70

# Location of the feather files; the data set folder name encodes the label length.
BASE_DIR = Path("../..")
PATH_DATA = 'data/feather-files'
DATA_SET = f'Range0-5000-FixLabelLen{DIMENSIONS_OUT}'
FEATHER_FOLDER = BASE_DIR / PATH_DATA / DATA_SET

### Train/Predict

In [None]:
LR = 1e-3  # default learning rate (used when lr_find() gives no suggestion)
BS = 1024  # batch size
EPOCHS_BEFORE_SAVING = 10    # epochs per fit_one_cycle() call between checkpoints
SAVE_MODEL_ITERATIONS = 500  # number of fit/validate/save rounds
# Use the first GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing
# when the model is moved to the device.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_NAME = g.get_notebook_name()
MODEL_DIR = BASE_DIR/'models'/MODEL_NAME
SPECIFIC_MODEL = None #Set to specific name of model ('None' uses the newest)

### Model

In [None]:
# Number of output positions produced by the model per window.
DIMENSIONS_PREDICTION_OUT = 2 * DIMENSIONS_OUT - 1
# Set to True if training fails on the last (possibly smaller) batch.
# NOTE(review): presumably needed because the CTC loss is built for a fixed BS - confirm.
DROP_LAST = True
LOSS_FUNC = metric.ctc_loss(DIMENSIONS_PREDICTION_OUT, BS, ALPHABET_SIZE)
# NOTE(review): the meaning of the literal 5 passed to ctc_error is not visible here - confirm.
METRICS = [
    metric.ErrorRate(metric.ctc_error(ALPHABET, 5)),
]

## Load data

In [None]:
# Load the raw data set from the feather folder.
data = f.read_data_from_feather_file(FEATHER_FOLDER)

# Split 80/20 into train/validation loaders and wrap them in a DataBunch.
# Intermediate objects are deleted as soon as they are no longer needed.
train_dl, valid_dl = prep.convert_to_dataloaders(
    data,
    split=.8,
    batch_size=BS,
    drop_last=DROP_LAST,
)
del data
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)
del train_dl, valid_dl

## Model

### Build
*Put the entire model definition and its constants in the same cell*

In [None]:
# Number of channels produced by every ResidualBlock in `model` except the last one.
NF = 123
def conv(ni, nf, ks):
    """Build a stride-1 1D convolution padded with ``ks // 2`` on both sides.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        ks: kernel size.

    Returns:
        An ``nn.Conv1d``. With this padding an odd ``ks`` keeps the sequence
        length unchanged, while an even ``ks`` makes the output one step longer.
    """
    return nn.Conv1d(
        in_channels=ni,
        out_channels=nf,
        kernel_size=ks,
        stride=1,
        padding=ks // 2,
    )

def ConvLayer(ni, nf, ks):
    """Return convolution -> batch norm -> ReLU as one ``nn.Sequential``.

    Args:
        ni: number of input channels.
        nf: number of output channels (also the BatchNorm1d feature count).
        ks: kernel size forwarded to ``conv``.
    """
    stages = [
        conv(ni, nf, ks),
        nn.BatchNorm1d(nf),
        nn.ReLU(),
    ]
    return nn.Sequential(*stages)

def ConvBlock(a, b, c, ni, nf):
    """Wrap a single ``ConvLayer`` with kernel size ``a`` in an ``nn.Sequential``.

    Args:
        a: kernel size of the only layer in the block.
        b: accepted but currently unused.
        c: accepted but currently unused.
        ni: number of input channels.
        nf: number of output channels.
    """
    layers = [ConvLayer(ni, nf, a)]
    return nn.Sequential(*layers)

class ResidualBlock(nn.Module):
    """Two parallel ``ConvBlock`` branches whose outputs are summed and rectified.

    Both branches are built with the same arguments, so they share the same
    architecture but hold separate weights.
    """

    def __init__(self, a, b, c, ni, nf):
        """Create the two branches; all arguments are forwarded to ``ConvBlock``."""
        super().__init__()
        self.module = ConvBlock(a, b, c, ni, nf)
        self.residual = ConvBlock(a, b, c, ni, nf)

    def forward(self, x):
        """Return ``ReLU(module(x) + residual(x))``."""
        main_branch = self.module(x)
        side_branch = self.residual(x)
        summed = main_branch + side_branch
        activation = nn.ReLU()
        return activation(summed)

# Kernel sizes follow the pattern 2**k + 1, which is always odd.
#
# The previous version spelled the powers with `^`, which in Python is bitwise
# XOR and binds looser than `+` (so `2^2+1 == 2 ^ 3 == 1`). The effective
# kernel sizes were therefore 3, 1, 6, 7, 4, 5 instead of 3, 5, 9, 17, 33, 65,
# and the two accidental even kernels each lengthened the sequence by one step,
# which is why the Linear layer needed WINDOW_SIZE+2 input features.
#
# With odd kernels and padding ks//2 the sequence length is preserved, so the
# Linear layer now takes WINDOW_SIZE features.
# NOTE(review): assumes the input is (batch, 1, WINDOW_SIZE) - confirm against the data loader.
# NOTE: this changes parameter shapes, so checkpoints saved with the old
# (XOR) kernel sizes cannot be loaded into this model.
#
# The last block outputs DIMENSIONS_PREDICTION_OUT channels; the Linear layer
# then maps the length axis to ALPHABET_SIZE scores and LogSoftmax normalises
# them over that last axis (dim=2).
model = nn.Sequential(
    ResidualBlock(2+1, 2**2+1, 2**3+1, 1, NF)
    ,ResidualBlock(2**2+1, 2**5+1, 2**6+1, NF, NF)
    ,ResidualBlock(2**3+1, 2**8+1, 2**8+1, NF, NF)
    ,ResidualBlock(2**4+1, 2**8+1, 2**8+1, NF, NF)
    ,ResidualBlock(2**5+1, 2**8+1, 2**8+1, NF, NF)
    ,ResidualBlock(2**6+1, 2**8+1, 2**8+1, NF, DIMENSIONS_PREDICTION_OUT)
    ,nn.Linear(WINDOW_SIZE, ALPHABET_SIZE)
    ,nn.LogSoftmax(dim=2)
).to(device=DEVICE)

In [None]:
# Bundle data, model, loss and metrics; checkpoints are read from / written to MODEL_DIR.
learner = Learner(
    databunch,
    model,
    loss_func=LOSS_FUNC,
    model_dir=MODEL_DIR,
    metrics=METRICS,
)

In [None]:
# Best-effort restore of previously saved weights: use SPECIFIC_MODEL when set,
# otherwise the newest checkpoint found in MODEL_DIR. Failing to load is not
# fatal (training then starts from scratch), but the reason is reported so a
# real problem (e.g. a checkpoint that no longer matches the architecture) is
# not mistaken for "no checkpoint yet". Only Exception is caught so that
# KeyboardInterrupt / SystemExit still propagate.
try:
    if SPECIFIC_MODEL:
        model_name = SPECIFIC_MODEL
    else:
        model_name = g.get_newest_model(MODEL_DIR)

    learner = learner.load(model_name)
    print('Model weights loaded', model_name)
except Exception as err:
    print('No model weights available')
    print('Reason:', repr(err))

## Train

In [None]:
# Run the learning-rate finder and plot its result.
# NOTE(review): plot(suggestion=True) presumably stores the suggested rate as
# learner.recorder.min_grad_lr, which the next cell reads - confirm.
learner.lr_find()
learner.recorder.plot(suggestion=True)

In [None]:
# Default to LR if lr_find() has not been run
try:
    lr = learner.recorder.min_grad_lr
except:
    lr = LR
lr

In [None]:
# Train in rounds of EPOCHS_BEFORE_SAVING epochs and checkpoint after each
# round, embedding the validation results in the checkpoint name.
# The loop variable is deliberately not called `iter`: in a notebook it would
# overwrite the builtin iter() for every cell executed afterwards.
for iteration in range(SAVE_MODEL_ITERATIONS):
    print(iteration)
    learner.fit_one_cycle(EPOCHS_BEFORE_SAVING, max_lr=lr)
    # NOTE(review): presumably (validation loss, value of the single metric in METRICS) - confirm.
    valid, ctc  = learner.validate()
    learner.save(MODEL_NAME+f'-Valid[{valid}]-CTC[{ctc}]')

## Predict

In [None]:
# Predict on the validation tensors.
x, y, _ = databunch.valid_dl.tensors
x_device = x.to(device=DEVICE)
# Switch to evaluation mode so BatchNorm uses its running statistics (and does
# not update them from this data), and disable autograd so no graph is built
# for the forward pass. Training cells switch back to train mode themselves
# when fitting.
model.eval()
with torch.no_grad():
    y_pred = model(x_device).cpu().numpy()

In [None]:
# Range handed to g.get_stats; neither bound may exceed DIMENSIONS_PREDICTION_OUT
# (the stop value is exclusive, hence the +1).
# NOTE(review): presumably the beam sizes to evaluate - confirm against g.get_stats.
r_start = min(5, DIMENSIONS_PREDICTION_OUT)
r_stop = min(DIMENSIONS_PREDICTION_OUT+1, 50)
r = range(r_start, r_stop)

# Show statistics for two samples: index 0 and index BS-1.
for index in (0, BS-1):
    prediction = y_pred[index]
    actual = pop.convert_idx_to_base_sequence(y[index], ALPHABET_VAL)
    stats = g.get_stats(prediction, actual, ALPHABET_STR, r)
    for pred, beam, error in stats:
        print(pred, beam, error.error)
    print('')

In [None]:
# Assembling is only done when the windows overlap (STRIDE differs from
# WINDOW_SIZE), i.e. when the data is not fetched from the feather file.
# TODO: assembling should be done on a complete signal, not on a batch.
if WINDOW_SIZE != STRIDE:
    decoded = pop.decode(y_pred, ALPHABET_STR, beam_size=15)
    assembled = pop.assemble(
        decoded,
        WINDOW_SIZE,
        STRIDE,
        ALPHABET,
    )
    print('Assembled:', assembled)