# Creating a DataBunch and Training a Basecaller

In [None]:
# Reload edited modules (e.g. the jkbc package) automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures (e.g. the lr_find plot below) inline in the notebook.
%matplotlib inline

In [None]:
import functools

from fastai.basics import *

import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.utils.files as f
import jkbc.types as t
import jkbc.utils.loss as loss


## Constants

In [None]:
BLANK_ID = prep.BLANK_ID
C = 3  # number of output classes: size of the final Linear layer and CTC alphabet
D_in = 300  # input window length; also the LSTM's per-step feature size (see LstmBlock below)
D_out_max = 70  # presumably the maximum label length per window — confirm against the data
n_hidden = 200  # LSTM hidden units per direction
BS = 1024  # batch size
LR = 0.05  # NOTE(review): not used below; fit_one_cycle is called with max_lr=2.75e-5
NF = 256  # features in residualblock
# Channels of the last conv block, which the LSTM treats as its time axis, i.e.
# the number of output steps scored by CTC. Presumably 2*L-1 so a label of
# length D_out_max fits with blanks between repeated bases — confirm.
PRED_OUT_DIM = D_out_max*2-1
MODEL_NAME = 'chiron-binary-capped-output'

# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when tensors are moved to the device.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Load Data

In [None]:
# Root directory of the preprocessed feather files and the data set to use.
data_set_name = "fake_data10000-binary"
base_dir = "data/feather-files/"
path_data = Path(base_dir)
feather_folder = path_data.joinpath(data_set_name)

In [None]:
# Load the raw data set from the feather folder.
data = f.read_data_from_feather_file(feather_folder)

# Split into training / validation loaders. split=.8 is presumably the
# training fraction; drop_last=True keeps only full batches of size BS.
train_dl, valid_dl = prep.convert_to_dataloaders(
    data,
    split=.8,
    batch_size=BS,
    drop_last=True,
)

# Wrap both loaders in a fastai DataBunch bound to DEVICE.
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)

## Model

In [None]:
# CTC loss from jkbc.utils.loss, built from the output sequence length
# (PRED_OUT_DIM), the batch size (BS) and the number of classes (C).
# NOTE(review): argument meanings inferred from the constants passed — confirm
# against ctc_loss's signature.
loss_func = loss.ctc_loss(PRED_OUT_DIM, BS, C)

In [None]:
def conv(ni, nf, ks=1, padding=0):
    """Return a stride-1 ``nn.Conv1d`` mapping ``ni`` channels to ``nf`` channels.

    ``ks`` is the kernel size and ``padding`` the zero-padding on each side.
    """
    return nn.Conv1d(
        in_channels=ni,
        out_channels=nf,
        kernel_size=ks,
        stride=1,
        padding=padding,
    )
def conv_layers(ni, nf):
    """Build the main (bottleneck) branch of a residual block.

    Channels flow ``ni -> NF -> NF -> nf`` through a kernel-size-1 conv, a
    kernel-size-3 conv with padding 1, and a final kernel-size-1 conv. Batch
    norm and ReLU follow the first two convs. None of the convs change the
    sequence length.
    """
    layers = [
        conv(ni, NF), nn.BatchNorm1d(NF), nn.ReLU(),
        conv(NF, NF, 3, padding=1), nn.BatchNorm1d(NF), nn.ReLU(),
        conv(NF, nf),
    ]
    return nn.Sequential(*layers)

In [None]:
class ResidualBlock(nn.Module):
    """Residual conv block: bottleneck branch + kernel-size-1 shortcut, then ReLU.

    Maps ``ni`` input channels to ``nf`` output channels. Both branches keep
    the sequence length unchanged, so their outputs can be summed element-wise.
    """

    def __init__(self, ni, nf):
        super().__init__()
        self.module = conv_layers(ni, nf)  # main branch: ni -> NF -> NF -> nf
        self.residual = conv(ni, nf)  # shortcut projecting ni -> nf channels

    def forward(self, x):
        # torch.relu is the functional form of nn.ReLU; it avoids constructing
        # a throwaway module on every forward pass.
        return torch.relu(self.module(x) + self.residual(x))

class LstmBlock(nn.Module):
    """Stacked bidirectional LSTM that returns the full output sequence.

    Input is batch-first, ``(batch, seq, input_size)``. The output is
    ``(batch, seq, 2 * output_size)`` because both directions are concatenated.

    Args:
        input_size: features per time step.
        window_size: batch size used to pre-allocate ``self.hidden``. Despite the
            name, this is the *batch* dimension of the LSTM state, not a
            sequence length.
        output_size: hidden units per direction.
        no_of_layers: number of stacked LSTM layers.
        device: device on which ``self.hidden`` is allocated.
    """

    def __init__(self, input_size, window_size, output_size, no_of_layers, device):
        super().__init__()
        self.lstm = nn.LSTM(input_size, output_size, no_of_layers, bidirectional=True, batch_first=True)

        # Zero (h0, c0); the first dim is 2 * layers because the LSTM is
        # bidirectional. Use the `device` argument rather than a global.
        # Kept as a plain attribute (not a buffer) so the state_dict, and hence
        # previously saved checkpoints, is unchanged.
        state_shape = (2 * no_of_layers, window_size, output_size)
        self.hidden = (
            torch.zeros(*state_shape, device=device),
            torch.zeros(*state_shape, device=device),
        )

    def forward(self, x):
        # nn.LSTM defaults to an all-zero (h0, c0) sized to x's actual batch
        # and created on x's device. That is numerically identical to passing
        # the zero tensors in self.hidden, but it also works for batch sizes
        # other than `window_size` and after the module moves device.
        res, _ = self.lstm(x)
        return res
        
# Bidirectional LSTM over the conv features. Its input_size is D_in, so it
# consumes the last (length) axis as features and the channel axis as time.
lstm = LstmBlock(D_in, BS, n_hidden, no_of_layers=4, device=DEVICE)
model = nn.Sequential(
    ResidualBlock(1, NF)
    # Three independent residual blocks. Reusing one instance here would tie
    # their weights; the state_dict keys (1.*, 2.*, 3.*) are the same either
    # way, so existing checkpoints still load.
    ,ResidualBlock(NF, NF)
    ,ResidualBlock(NF, NF)
    ,ResidualBlock(NF, NF)
    ,ResidualBlock(NF, PRED_OUT_DIM)
    ,nn.BatchNorm1d(PRED_OUT_DIM)

    # Assuming input (batch, 1, D_in): features are now (batch, PRED_OUT_DIM, D_in).
    ,lstm

    # Per-step class scores, then log-probabilities over the C classes (dim 2).
    ,nn.Linear(n_hidden*2,C)
    ,nn.LogSoftmax(dim=2)
).to(device=DEVICE)

In [None]:
# Wrap data, model and CTC loss in a fastai Learner. No opt_func is passed, so
# fastai's default optimizer is used.
learner = Learner(databunch, model, loss_func=loss_func)

In [None]:
# Resume from previously saved weights (see the save cell) when they exist.
try:
    learner = learner.load(MODEL_NAME)
except FileNotFoundError:
    print('No model weights available')
except Exception as err:
    # A checkpoint exists but could not be loaded (e.g. it was saved from a
    # different architecture). Stay best-effort, but report the real reason
    # instead of claiming no weights exist; Ctrl-C is no longer swallowed.
    print(f'Model weights could not be loaded: {err!r}')
else:
    print('Model weights loaded')

In [None]:
# Run fastai's learning-rate range test and plot loss against learning rate,
# with the suggested value marked. Presumably used to pick max_lr for the
# fit_one_cycle cell below.
learner.lr_find()
learner.recorder.plot(suggestion=True)

In [None]:
# Train for 100 epochs with the one-cycle schedule, peaking at 2.75e-5.
# NOTE(review): the LR constant (0.05) from the constants cell is not used here.
learner.fit_one_cycle(100, max_lr=2.75e-5)

In [None]:
# Persist the trained weights under MODEL_NAME, the same name the load cell reads.
learner.save(MODEL_NAME)

In [None]:
# Grab one batch for a qualitative check of the model's predictions.
# NOTE(review): fastai v1's one_batch() appears to default to the training
# split; pass DatasetType.Valid to inspect held-out data instead — confirm.
x, (y, _) = databunch.one_batch()

# Run inference in eval mode, so BatchNorm uses (and does not update) its
# running statistics, and without autograd. Restore the previous mode
# afterwards so later training cells are unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        y_pred = model(x.to(device=DEVICE)).cpu().numpy()
finally:
    model.train(was_training)

In [None]:
# Compare the decoded prediction for one sample against its label, sweeping
# the beam width from 10 to 24.
# NOTE(review): `alphabet_val` and `alphabet_str` are not defined anywhere in
# this notebook. This cell raises NameError unless they exist in the kernel
# from elsewhere; define them explicitly.
index = 0
actual = pop.convert_idx_to_base_sequence(y[index], alphabet_val)
actual_len = len(actual)
print(actual)
for beam in range(10, 25):
    # NOTE(review): the whole batch is decoded on every iteration although
    # only `index` is used; consider decoding just that sample if pop.decode
    # supports it.
    decoded = pop.decode(y_pred, threshold=.01, beam_size=beam, alphabet=alphabet_str)
    predicted = decoded[index]
    error = pop.calc_sequence_error_metrics(actual, predicted)
    # Relative length mismatch; guard against an empty reference sequence.
    len_diff = abs(len(predicted) - actual_len) / actual_len if actual_len else float('nan')
    print(predicted, beam, error.error, len_diff)