# Creating a Databunch for Basecalling

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [288]:
from fastai.basics import *
import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop

## Constants

In [320]:
BLANK_ID = 4
C = 5
D_in = 300
D_h = 200
D_out_max = 50
BS = 64 # batch size
LR = 0.05

#DEVICE = torch.device("cpu")
DEVICE = torch.device("cuda:0")

## Load Data

In [4]:
base_dir = "/notebooks/storage/" # Change this line if the data is located elsewhere
path_data = Path(base_dir +"/bc_data")
path_data.ls()

[PosixPath('/notebooks/storage/bc_data/test_dataset.hdf5'),
 PosixPath('/notebooks/storage/bc_data/small_umi16to9.hdf5')]

## Preprocess Data

In [336]:
#data_set_name = 'test_dataset.hdf5'
data_set_name = 'small_umi16to9.hdf5'
dc_test = prep.SignalCollection(path_data/data_set_name)

In [337]:
# There is just a single piece of data in dc_test
for d in dc_test.generator():
    data = d

Processing 0000b441-4cc6-42a7-bdfe-310f9e560e57 (0)


AssertionError: ref_to_signal and reference has different lengths

In [166]:
data_fields = np.array(data.x), np.array(data.y), data.x_lengths, data.y_lengths, data.reference
x, y, x_lengths, y_lengths, reference = data_fields

In [167]:
print(f'x: {x.shape}, x_lengths: {x_lengths.shape}, y: {y.shape}, y_lengths: {y_lengths.shape}')

x: (930, 300), x_lengths: (930, 1), y: (930, 6), y_lengths: (930, 1)


In [168]:
# Reshape the x's to fit our model
x_train = x.reshape((D_in, BS, 1))

# The add_label_padding takes a nested list, so we use the data.y, should be changed
y_train = prep.add_label_padding(labels = data.y, fixed_label_len = D_out_max, padding_id = BLANK_ID)
"x: ", type(x_train), x_train.shape, "y: ", type(y_train), y_train.shape

('x: ', numpy.ndarray, (300, 930, 1), 'y: ', numpy.ndarray, (930, 50))

In [169]:
# Turn everything into tensors
mk_tensor_long = partial(torch.tensor, dtype = torch.long, device = DEVICE)
mk_tensor_float = partial(torch.tensor, dtype = torch.float, device = DEVICE)
x_train = mk_tensor_float(x_train)
y_train, x_lengths, y_lengths = map(mk_tensor_long, (y_train, x_lengths, y_lengths))

In [170]:
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(1, D_h)
        self.lin2 = nn.Linear(D_h, C)
    
    #   forward(self, x: Tensor[D_in, BS, 1]) -> Tensor[D_in, BS, C]
    def forward(self, xb):
        xb_ = self.lin1(xb).clamp(min=0)
        return self.lin2(xb_)     

In [321]:
model = SimpleModel().cuda()

In [322]:
ctc_loss = nn.CTCLoss()

In [323]:
lr = 1e-5

In [324]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [144]:
def train(n: int) -> None:
    for t in range(n):
        # forward pass
        y_pred = model(x_train)
        
        loss = ctc_loss(y_pred, y_train, x_lengths, y_lengths)
        
        if t % 100 == 99:
            print(t, loss.item())
            
        optimizer.zero_grad()
        
        # backward pass
        loss.backward()
        
        optimizer.step()

In [146]:
#train(100)

99 -13.010493278503418


## Create the Databunch

In [325]:
x_train_t, x_valid_t = x_train[:, :500].reshape((500, 300, 1)), x_train[:, 800:].reshape((130, 300, 1))

In [326]:
y_train_t, y_valid_t = y_train[:500], y_train[800:]

In [327]:
x_train_t.shape, y_train_t.shape, x_valid_t.shape, y_valid_t.shape

(torch.Size([500, 300, 1]),
 torch.Size([500, 50]),
 torch.Size([130, 300, 1]),
 torch.Size([130, 50]))

In [328]:
ds_train = TensorDataset(x_train_t, y_train_t)
ds_valid = TensorDataset(x_valid_t, y_valid_t)

In [329]:
data = DataBunch.create(ds_train, ds_valid, bs=BS)

In [330]:
y_pred_lengths = torch.full(size=(BS,), fill_value=D_in, dtype=torch.long)
y_lengths = torch.full(size=(BS,), fill_value=D_out_max, dtype=torch.long)

In [331]:
ctc_loss = nn.CTCLoss()
def ctc_loss_custom(y_pred_b: torch.Tensor, y_b: torch.Tensor) -> float:
    if y_pred_lengths.shape[0] != y_pred_b.shape[0]:
        new_len = y_pred_b.shape[0]
        y_pred_lengths_ = y_pred_lengths[:new_len]
        y_lengths_ = y_lengths[:new_len]
    else:
        y_pred_lengths_ = y_pred_lengths
        y_lengths_ = y_lengths
    
    #print(f'y_pred_b: {y_pred_b.shape}, y_b: {y_b.shape}')
    y_pred_b_ = y_pred_b.reshape((y_pred_b.shape[1], y_pred_b.shape[0], C))
    #y_b_ = y_b.reshape((y_b.shape[1], y_b.shape[0]))
    #print(f'y_pred_b_: {y_pred_b_.shape}, y_b_: {y_b.shape}')
    #print(f'y_pred_lengths: {y_pred_lengths.shape}, y_lengths: {y_pred_lengths.shape}')
    
    return ctc_loss(y_pred_b_, y_b, y_pred_lengths_, y_lengths_)    

In [332]:
loss_func = ctc_loss_custom

In [334]:
learner = Learner(data, SimpleModel().cuda(), loss_func=loss_func)

In [335]:
learner.fit_one_cycle(5)

epoch,train_loss,valid_loss,time
0,-3.206807,-2.145919,00:00
1,-1.787082,1.200493,00:00
2,-0.486188,1.637264,00:00
3,0.067616,0.978475,00:00
4,0.276365,0.81112,00:00
