# Creating a DataBunch and Training a Basecalling Model

In [None]:
# Reload imported modules automatically before each cell runs (autoreload mode 2)
# and render matplotlib figures inline in the notebook.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import functools

from fastai.basics import *

import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.utils.files as f
import jkbc.types as t
import jkbc.utils.loss as loss


## Constants

### Data

In [None]:
# Alphabet / label settings, taken from the jkbc pre- and post-processing modules.
BLANK_ID = prep.BLANK_ID
ALPHABET = pop.ALPHABET  # NOTE(review): an older comment gave {0: '-', 1: 'A', 2: 'B'}; check pop.ALPHABET for the real mapping
ALPHABET_VAL = [*ALPHABET.values()]  # symbols, in the dict's iteration order
ALPHABET_STR = "".join(ALPHABET_VAL)
ALPHABET_SIZE = len(ALPHABET_VAL)
WINDOW_SIZE = 300  # used as the LSTM input_size; presumably the signal window length per example
DIMENSIONS_OUT = 70  # appears in the data-set name as FixLabelLen..., presumably the fixed label length

### Model

In [None]:
HIDDEN_UNITS = 200  # LSTM hidden size per direction (the Linear head takes 2 * HIDDEN_UNITS)
NF = 256  # number of feature channels inside each residual block
# Channel count of the last residual block; the batch_first LSTM then treats this axis
# as its sequence axis, so it is also the length of the model's output sequence.
PRED_OUT_DIM = 2 * DIMENSIONS_OUT - 1

### Train/Predict

In [None]:
BS = 1024  # batch size
# Prefer the first GPU, but fall back to the CPU so the notebook can still run
# (e.g. for loading weights and inference) on machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_NAME = 'chiron-capped-output'

### Data path

In [None]:
# Folder holding the feather files for this experiment's data set.
BASE_DIR = "../data/feather-files/"
PATH_DATA = Path(BASE_DIR)
DATA_SET = f'Range0-2000-FixLabelLen{DIMENSIONS_OUT}'
FEATHER_FOLDER = PATH_DATA.joinpath(DATA_SET)

## Load data

In [None]:
# Read the feather data, split it into train/validation loaders and wrap them in a
# fastai DataBunch. References to the intermediate objects are dropped as soon as
# they are no longer needed (presumably to keep memory usage down).
data = f.read_data_from_feather_file(FEATHER_FOLDER)
train_dl, valid_dl = prep.convert_to_dataloaders(
    data,
    split=0.8,
    batch_size=BS,
    # Keep every batch at exactly BS samples; the loss is also built with BS.
    drop_last=True,
)
del data
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)
del train_dl, valid_dl

## Model

In [None]:
# Loss built by jkbc's ctc_loss helper. NOTE(review): the positional arguments are
# presumably (prediction length, batch size, alphabet size); confirm in
# jkbc.utils.loss.ctc_loss.
loss_func = loss.ctc_loss(PRED_OUT_DIM, BS, ALPHABET_SIZE)

In [None]:
def conv(ni, nf, ks=1, padding=0):
    """Return a stride-1 ``nn.Conv1d`` mapping ``ni`` to ``nf`` channels.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        ks: kernel size (defaults to a 1-wide, pointwise convolution).
        padding: zero-padding added on both sides of the input.
    """
    return nn.Conv1d(
        in_channels=ni,
        out_channels=nf,
        kernel_size=ks,
        stride=1,
        padding=padding,
    )
def conv_layers(ni, nf):
    """Build the convolutional branch of a residual block.

    The stack is: 1-wide conv (ni -> NF), BatchNorm, ReLU, 3-wide conv with
    padding 1 (NF -> NF), BatchNorm, ReLU, 1-wide conv (NF -> nf).
    The inner width is the module-level constant ``NF``.

    Args:
        ni: number of input channels.
        nf: number of output channels.
    """
    # Layers are created in the same order as they appear in the stack, so module
    # indices (and parameter-initialisation order) match the stack above.
    layers = [conv(ni, NF), nn.BatchNorm1d(NF), nn.ReLU()]
    layers += [conv(NF, NF, 3, padding=1), nn.BatchNorm1d(NF), nn.ReLU()]
    layers.append(conv(NF, nf))
    return nn.Sequential(*layers)

In [None]:
class ResidualBlock(nn.Module):
    """Convolutional branch plus a 1-wide convolution shortcut, summed and passed through ReLU.

    Args:
        ni: number of input channels.
        nf: number of output channels.
    """

    def __init__(self, ni, nf):
        super().__init__()
        # The attribute names form the state_dict keys of saved models; keep them stable.
        self.module = conv_layers(ni, nf)
        self.residual = conv(ni, nf)

    def forward(self, x):
        # Branch first, then shortcut, then ReLU of their sum.
        summed = self.module(x) + self.residual(x)
        return nn.functional.relu(summed)

class LstmBlock(nn.Module):
    """Bidirectional, multi-layer, batch-first LSTM that returns only its output sequence.

    Args:
        input_size: number of features per time step of the input.
        batch_size: batch size used to pre-allocate the (all-zero) initial states.
        hidden_units: hidden size per direction; the output has 2 * hidden_units features.
        no_of_layers: number of stacked LSTM layers.
        device: device on which the initial states are allocated.
    """

    def __init__(self, input_size, batch_size, hidden_units, no_of_layers, device):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_units, no_of_layers, bidirectional=True, batch_first=True)

        # Multiply by 2 because of bidirectional. Allocate on the `device` argument
        # (previously ignored in favour of a notebook-global DEVICE).
        h0 = torch.zeros(2*no_of_layers, batch_size, hidden_units, device=device)
        c0 = torch.zeros(2*no_of_layers, batch_size, hidden_units, device=device)

        # Plain attributes, not parameters/buffers: they are not in the state_dict
        # and are not moved by model.to(...), hence the .to(x) in forward.
        self.hidden = (h0, c0)

    def forward(self, x):
        """Run the LSTM on ``x`` of shape (batch, seq, input_size) and return its outputs.

        The output has shape (batch, seq, 2 * hidden_units).
        """
        h0, c0 = self.hidden
        if x.size(0) == h0.size(1):
            # Follow x's device/dtype, even if the module was moved after construction.
            hidden = (h0.to(x), c0.to(x))
        else:
            # The pre-allocated states are all zeros, which is exactly nn.LSTM's
            # default when no state is given, so other batch sizes (e.g. a single
            # example at inference time) use that default instead of crashing.
            hidden = None
        res, _ = self.lstm(x, hidden)

        return res
        
# Model: residual conv stack -> BatchNorm -> bidirectional LSTM -> per-step log-probabilities.
# The three middle residual blocks are separate instances. Reusing one instance
# three times in nn.Sequential would silently tie their weights and BatchNorm
# running statistics together.
lstm = LstmBlock(WINDOW_SIZE, BS, HIDDEN_UNITS, no_of_layers=4, device=DEVICE)
model = nn.Sequential(
    ResidualBlock(1, NF),
    ResidualBlock(NF, NF),
    ResidualBlock(NF, NF),
    ResidualBlock(NF, NF),
    ResidualBlock(NF, PRED_OUT_DIM),
    nn.BatchNorm1d(PRED_OUT_DIM),
    # The conv output is (batch, PRED_OUT_DIM, length). The LSTM is batch_first
    # with input_size=WINDOW_SIZE, so the length axis is consumed as features and
    # the PRED_OUT_DIM channel axis as time steps.
    # NOTE(review): confirm this transposition-free layout is intended.
    lstm,
    nn.Linear(HIDDEN_UNITS * 2, ALPHABET_SIZE),
    nn.LogSoftmax(dim=2),
).to(device=DEVICE)

In [None]:
# Wrap data, model and loss in a fastai Learner. `path='..'` is presumably the base
# directory the Learner resolves when saving/loading MODEL_NAME weights below.
learner = Learner(databunch, model, loss_func=loss_func, path='..')

In [None]:
# Resume from previously saved weights when a checkpoint exists; otherwise train
# from scratch. Only a missing checkpoint file counts as "no weights". Any other
# error (e.g. a checkpoint that no longer matches the model architecture) or a
# KeyboardInterrupt now surfaces instead of being silently reported as missing.
try:
    learner = learner.load(MODEL_NAME)
except FileNotFoundError:
    print('No model weights available')
else:
    print('Model weights loaded')

In [None]:
# Learning-rate range test, then plot loss against learning rate with a suggested
# value marked.
learner.lr_find()
learner.recorder.plot(suggestion=True)

In [None]:
# Train for 500 epochs with the one-cycle schedule, peaking at max_lr=5e-5
# (presumably chosen from the lr_find plot above).
learner.fit_one_cycle(500, max_lr=5e-5)

In [None]:
# Persist the trained weights under MODEL_NAME (the same name the load cell above tries).
learner.save(MODEL_NAME)

In [None]:
# Sanity-check the model on one batch from the DataBunch.
# NOTE(review): confirm which split `one_batch()` draws from by default.
x, (y, _) = databunch.one_batch()
x_device = x.to(device=DEVICE)
# Inference: eval mode makes BatchNorm use (and not update) its running statistics,
# even if no training ran in this session. no_grad skips building an autograd graph.
model.eval()
with torch.no_grad():
    y_pred = model(x_device).detach().cpu().numpy()
y_pred.shape

In [None]:
def get_stats(prediction: t.Tensor2D, actual: str, alphabet: t.List[str], beam_sizes: t.List[int]):
    """Decode ``prediction`` at several beam sizes and score each result against ``actual``.

    This is a generator: ``actual`` is printed once, when iteration starts. Then, for
    every beam size in ``beam_sizes``, it yields a tuple
    ``(predicted, beam_size, error_metrics)``. ``predicted`` is the first sequence
    returned by ``pop.decode``. ``error_metrics`` is the result of
    ``pop.calc_sequence_error_metrics(actual, predicted)``.

    Args:
        prediction: a single 2-D model output (no batch axis).
        actual: the reference base sequence.
        alphabet: symbols passed through to ``pop.decode``.
        beam_sizes: beam widths to try, in order.
    """
    print(actual)
    # Add a leading batch axis of size 1, since decode works on batches.
    batched = prediction[None, :, :]
    for beam_size in beam_sizes:
        best = pop.decode(batched, threshold=0.0, beam_size=beam_size, alphabet=alphabet)[0]
        yield best, beam_size, pop.calc_sequence_error_metrics(actual, best)

In [None]:
# For the first and last sample of the batch, print the reference sequence and then,
# for beam sizes 15..49, the decoded sequence, beam size and its error value.
for index in (0, BS - 1):
    actual = pop.convert_idx_to_base_sequence(y[index], ALPHABET_VAL)
    prediction = y_pred[index]
    beam_results = get_stats(prediction, actual, ALPHABET_STR, range(15, 50))
    for pred, beam, error in beam_results:
        print(pred, beam, error.error)
    print('')