# Creating a Databunch for Basecalling

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.basics import *

import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.utils.files as f

## Constants

In [3]:
# --- Label alphabet -------------------------------------------------------
C = 5                # size of the last prediction axis (used by SimpleModel and the CTC loss wrapper)
BLANK_ID = 4         # presumably the CTC blank index; not referenced in the visible cells — confirm

# --- Tensor dimensions ----------------------------------------------------
D_in = 300           # window size handed to prep.convert_to_databunch / prep.get_y_lengths
D_out_max = 60       # second argument of prep.get_y_lengths (dataset name says FixLabelLen60)
D_h = 201            # hidden width of SimpleModel / SimpleRnn
n_hidden = 400       # not referenced in the visible cells

# --- Optimisation ---------------------------------------------------------
# NOTE(review): the author's earlier batch size was 64. The recorded
# `learner.fit(1)` run fails inside BatchNorm1d with a single-sample batch.
BS = 1
LR = 0.05            # not passed to any fit call in the visible cells

# --- Hardware -------------------------------------------------------------
# Use torch.device("cuda:0") instead to run on the first GPU.
DEVICE = torch.device("cpu")

## Load Data

In [4]:
# Where the feather-encoded data sets live, and which one this notebook uses.
base_dir = "data/feather-files/"
data_set_name = 'Range0-100-FixLabelLen60'

# Resolve the folder of the selected data set: <base_dir>/<data_set_name>
path_data = Path(base_dir)
feather_folder = path_data.joinpath(data_set_name)

In [5]:
# Load the data set stored under `feather_folder`; the result unpacks into
# two parts, kept here as `x` and `y_train`.
data = f.read_data_from_feather_file(feather_folder)
x, y_train = data

# Turn the loaded data into a train/validation pair on DEVICE.
# split=.8 is presumably the training fraction and window_size the length of
# each input window — verify against prep.convert_to_databunch.
train, valid = prep.convert_to_databunch(
    data,
    split=.8,
    device=DEVICE,
    window_size=D_in,
)

# Wrap both parts in a fastai DataBunch that serves batches of size BS.
databunch = DataBunch.create(train, valid, bs=BS)

In [6]:
# Build the two length tensors that ctc_loss_custom forwards to nn.CTCLoss:
# `y_pred_lengths` is passed as the input (prediction) lengths and `y_lengths`
# as the target lengths. Both are sliced along dim 0 by the batch size there.
# Presumably one entry per batch element, filled with D_in and D_out_max
# respectively — verify in prep.get_y_lengths.
y_pred_lengths, y_lengths = prep.get_y_lengths(D_in, D_out_max, BS)

## Model

In [7]:
class SimpleModel(nn.Module):
    """Two-layer network applied independently to every scalar of the input.

    Each value (last axis of size 1) is projected to D_h features, clamped at
    zero, and projected again to C outputs.

    Shapes as documented by the author:
        input  xb: Tensor[D_in, BS, 1]
        output   : Tensor[D_in, BS, C]
    """

    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(in_features=1, out_features=D_h)
        self.lin2 = nn.Linear(in_features=D_h, out_features=C)

    def forward(self, xb):
        """Return the C-way scores for every position of `xb`."""
        # Clamping at zero is the non-linearity between the two projections.
        hidden = torch.clamp(self.lin1(xb), min=0)
        scores = self.lin2(hidden)
        return scores


In [8]:
def conv(ni, nf):
    """Return a Conv1d from `ni` to `nf` channels with kernel 3, stride 2, padding 1."""
    return nn.Conv1d(ni, nf, kernel_size=3, stride=2, padding=1)


# Channel widths of the eight conv -> BatchNorm1d -> ReLU stages, in order.
# The author's inline notes tracked 150, 75, 37, 19, 10, 5, 3, 2 after the
# successive stages — presumably the sequence length they expected.
# NOTE(review): the first Conv1d receives 300 as its *channel* count, so the
# stride-2 halving acts on whatever axis follows it, not on the 300 samples.
# The recorded `learner.fit(1)` run fails in the first BatchNorm1d with
# input size [1, 8, 1] — confirm the intended input layout.
_conv_channels = [300, 8, 16, 32, 64, 128, 128, 64, 32]

model = nn.Sequential(
    *[
        layer
        for n_in, n_out in zip(_conv_channels, _conv_channels[1:])
        for layer in (conv(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU())
    ],
    Flatten(),
)

class SimpleRnn(nn.Module):
    """Single-layer RNN that carries its hidden state from one batch to the next.

    Pipeline: embedding -> nn.RNN -> BatchNorm1d over the feature axis ->
    Linear(D_h -> 1), applied at every time step.

    Input:
        x: integer index tensor of shape (seq_len, batch); nn.RNN is used with
           its default time-major layout.
    Output:
        Tensor of shape (seq_len, batch, 1).

    NOTE(review): `nn.Embedding(1, D_h)` only accepts index 0, so it cannot
    encode a real-valued signal, and the head emits 1 value rather than C
    classes. SimpleModel uses `nn.Linear(1, D_h)` and `nn.Linear(D_h, C)`;
    confirm whether the same was intended here.
    """

    def __init__(self):
        super().__init__()

        self.i_h = nn.Embedding(1, D_h)
        self.rnn = nn.RNN(D_h, D_h)
        self.h_o = nn.Linear(D_h, 1)
        self.bn  = nn.BatchNorm1d(D_h)
        # The hidden state is created lazily in forward() so that it matches
        # the batch size, dtype and device of the data actually seen, instead
        # of depending on the shape of a notebook-level array.
        self.h   = None

    def forward(self, x):
        emb = self.i_h(x)                              # (seq_len, batch, D_h)
        batch_size = emb.shape[1]

        # (Re)start from zeros on the first call, or when the batch size or
        # device changed (e.g. a smaller final batch).
        if (self.h is None
                or self.h.shape[1] != batch_size
                or self.h.device != emb.device):
            self.h = emb.new_zeros(1, batch_size, self.rnn.hidden_size)

        res, h = self.rnn(emb, self.h)
        # Keep the state's values but cut the graph, so that back-propagation
        # does not reach into previous batches.
        self.h = h.detach()

        # BatchNorm1d normalises dim 1, so move the feature axis there:
        # (seq_len, batch, D_h) -> (batch, D_h, seq_len) -> back again.
        res = self.bn(res.permute(1, 2, 0)).permute(2, 0, 1)
        return self.h_o(res)

In [9]:
# NOTE(review): nn.CTCLoss() reserves index 0 for the blank symbol by default,
# while this notebook defines BLANK_ID = 4. Confirm how the labels are encoded
# and pass `blank=BLANK_ID` here if 4 is the blank.
ctc_loss = nn.CTCLoss()

def ctc_loss_custom(y_pred_b: torch.Tensor, y_b: torch.Tensor) -> torch.Tensor:
    """CTC loss for one batch of batch-major predictions.

    Args:
        y_pred_b: predictions of shape (batch, time, C); raw scores or
            log-probabilities.
        y_b: target labels for the batch, forwarded unchanged to nn.CTCLoss.

    Returns:
        The scalar loss tensor produced by the module-level `ctc_loss`.

    Raises:
        ValueError: if `y_pred_b` is not 3-D or its last axis is not of size C.
    """
    if y_pred_b.dim() != 3 or y_pred_b.shape[2] != C:
        raise ValueError(
            f"expected predictions of shape (batch, time, {C}), got {tuple(y_pred_b.shape)}"
        )

    # The precomputed lengths cover a full batch; trim them when the batch is
    # smaller (typically the last one of an epoch). Slicing is a no-op otherwise.
    batch_size = y_pred_b.shape[0]
    y_pred_lengths_ = y_pred_lengths[:batch_size]
    y_lengths_ = y_lengths[:batch_size]

    # nn.CTCLoss wants time-major input, so swap the first two axes with
    # permute; reshape would keep memory order and mix values between samples
    # and time steps. It also expects log-probabilities. log_softmax leaves
    # inputs that are already log-probabilities unchanged.
    log_probs = y_pred_b.permute(1, 0, 2).log_softmax(dim=-1)

    return ctc_loss(log_probs, y_b, y_pred_lengths_, y_lengths_)

In [10]:
# Loss used for training: the batch-wise CTC wrapper ctc_loss_custom.
loss_func = ctc_loss_custom

In [11]:
# Bundle the data, the nn.Sequential `model` and the loss into a fastai Learner.
# SimpleModel and SimpleRnn are defined in this notebook but not used here.
learner = Learner(databunch, model, loss_func=loss_func)

In [12]:
# Train for one epoch; no learning rate is passed, so the constant LR is not used.
# NOTE(review): the output recorded for this cell is a BatchNorm1d ValueError
# ("Expected more than 1 value per channel when training", input size [1, 8, 1]).
learner.fit(1)

epoch,train_loss,valid_loss,time


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 8, 1])

In [None]:
# Run fastai's learning-rate finder; this cell has not been executed (In [None]).
learner.lr_find()

In [None]:
# Plot what the recorder collected, asking it to mark a suggested learning rate.
learner.recorder.plot(suggestion=True)

In [None]:
# Train for 10 epochs with fit_one_cycle; no learning rate is passed, so LR is not used.
learner.fit_one_cycle(10)