# Creating a Databunch for Basecalling

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.basics import *

import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.utils.files as f
import jkbc.types as t

## Constants

In [3]:
BLANK_ID = prep.BLANK_ID
C = 5
D_in = 300
D_out_max = 70
n_hidden = 200
BS = 64  # batch size
LR = 0.05
NF = 256 # features in residualblock

#DEVICE = torch.device("cpu")
DEVICE = torch.device("cuda:0")

## Load Data

In [4]:
base_dir = "data/feather-files/"
path_data = Path(base_dir)
data_set_name = f'Range0-100-FixLabelLen{D_out_max}'
feather_folder = path_data/data_set_name

In [5]:
# Read data from feather
data = f.read_data_from_feather_file(feather_folder)
x, y, y_lengths = data

# Convert to databunch
train, valid = prep.convert_to_datasets(data, split=.8)
databunch = DataBunch.create(train, valid, bs=BS, device=DEVICE)

## Model

In [6]:
# Create lengths to be used in ctc_loss
y_pred_lengths = prep.get_prediction_lengths(D_in, BS)

In [7]:
ctc_loss = nn.CTCLoss()
def ctc_loss_custom(y_pred_b: torch.Tensor, y_b: torch.Tensor, y_lengths) -> float:
    if y_pred_lengths.shape[0] != y_pred_b.shape[0]:
        new_len = y_pred_b.shape[0]
        y_pred_lengths_ = y_pred_lengths[:new_len]
    else:
        y_pred_lengths_ = y_pred_lengths
        
    y_pred_b_ = y_pred_b.reshape((y_pred_b.shape[1], y_pred_b.shape[0], C))

    return ctc_loss(y_pred_b_, y_b, y_pred_lengths_, y_lengths)    

In [8]:
loss_func = ctc_loss_custom

In [9]:
def _conv(ni, nf, ks=1, padding=0): return nn.Conv1d(ni, nf, kernel_size=ks, stride=1, padding=padding)
def _conv_layers(ni, nf): 
    return nn.Sequential(
        _conv(ni, NF)
        ,nn.ReLU()
        ,_conv(NF, NF, 3, padding=1)
        ,nn.ReLU()
        ,_conv(NF, nf)
    )

In [10]:
class ResidualBlock(nn.Module):
    def __init__(self, ni, nf):
        super().__init__()
        self.module = _conv_layers(ni, nf)
        self.residual = _conv(ni, nf)
    
    def forward(self, x):
        out_a = self.module(x)
        out_b = self.residual(x)
        
        return nn.ReLU()(out_a + out_b)

class LstmBlock(nn.Module):
    def __init__(self, input_size, window_size, output_size, no_of_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size, output_size, no_of_layers, bidirectional=True)
        
        ## Multiply by 2 because of bidirectional
        h0 = torch.zeros(2*no_of_layers, window_size, output_size).cuda()
        c0 = torch.zeros(2*no_of_layers, window_size, output_size).cuda()
        self.hidden=(h0,c0)
        
    def forward(self, x):
        (x1, x2, x3) = x.shape
        x_ = x.reshape(x1, x3, x2)
        ## res.shape = (batch_size, output_size, output_size*2)
        res, _ = self.lstm(x_, self.hidden)
        return res
        
    
model = nn.Sequential(
    ResidualBlock(1, NF)
    ,ResidualBlock(NF, NF)
    ,ResidualBlock(NF, NF)
    ,ResidualBlock(NF, NF)
    ,ResidualBlock(NF, NF)
    ,nn.BatchNorm1d(NF)
    
    ,LstmBlock(NF, D_in, n_hidden, no_of_layers=4)
    
    ,nn.Linear(n_hidden*2,5)
    ,nn.LogSoftmax(dim=2)
).cuda()

In [11]:
learner = Learner(databunch, model.cuda(), loss_func=loss_func)

In [12]:
learner2 = learner.load('chiron')


In [None]:
learner.lr_find()
learner.recorder.plot(suggestion=True)

In [None]:
learner.fit_one_cycle(1, max_lr=4.8e-4)

In [None]:
learner.save('chiron')

In [13]:
x, (y, _) = databunch.one_batch()
y_pred = model(x.cuda()).detach().cpu().numpy()

In [23]:
index = 0
actual = pop.convert_idx_to_base_sequence(y[index])
print(actual)
for beam in range(5,12):
    decoded = pop.decode(y_pred, threshold=.01, beam_size=beam)   
    predicted = decoded[index]
    error = pop.calc_sequence_error_metrics(actual, predicted)
    print(predicted, beam, error.error)
    del predicted

GGGACGCAGGAGGCTAAGCAAACCAGACGATTGGAA
CCCCCC 5 0.9166666666666666
CCCCCCCCCA 6 0.8055555555555556
CACCCCCCACCAC 7 0.7222222222222223
CACACCCCACCAC 8 0.7222222222222222
CACACCCCACCAC 9 0.7222222222222222
CACACCCCACCAC 10 0.7222222222222222
CCCCACCCACCAC 11 0.75
