# Train Like Bonito
https://github.com/nanoporetech/bonito/blob/master/notebooks/bonito-train.ipynb

In [2]:
import os
import sys
import time
import random
from datetime import datetime
from itertools import starmap
from pathlib import Path

import numpy as np
import pandas as pd
import toml
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

from bonito.util import accuracy
from bonito.training import ChunkDataSet
from bonito.decode import decode, decode_ref

In [3]:
!cd ../../../mapped_reads/bonito_preprocessed/; ls

data  mapped_reads.hdf5


In [4]:
# Root folder of the preprocessed training data; the arrays are read from its data/ subfolder.
base_dir = Path("../../../mapped_reads/bonito_preprocessed/")


def load_np(fn):
    """Load and return the NumPy array stored in file `fn` under `base_dir/data`."""
    array_path = base_dir / 'data' / fn
    return np.load(array_path)

#def load_toml(fn): return toml.load(base_dir/'config'/fn)

In [5]:
# The four training arrays, loaded in this order:
#
#   chunks.npy             Sections of squiggle that correspond with the target
#                          reference sequence. Variable length and zero padded
#                          (up to 4096 samples).
#                          shape (1000000, 4096), dtype('float32')
#   chunk_lengths.npy      Lengths of the squiggle sections in chunks.npy.
#                          shape (1000000,), dtype('uint16')
#   references.npy         Integer encoded target sequence
#                          {'A': 1, 'C': 2, 'G': 3, 'T': 4}. Variable length and
#                          zero padded (default range between 128 and 256).
#                          shape (1000000, 256), dtype('uint8')
#   reference_lengths.npy  Lengths of the target sequences in references.npy.
#                          shape (1000000,), dtype('uint8')
array_files = (
    "chunks.npy",
    "chunk_lengths.npy",
    "references.npy",
    "reference_lengths.npy",
)
full_chunks, full_chunk_lengths, full_targets, full_target_lengths = map(load_np, array_files)

# The structure of the model is defined using a config file.
# This will make sense to those familiar with QuartzNet
# (https://arxiv.org/pdf/1910.10261.pdf).
quartznet_config = toml.load("../models/quartznet5x5.toml")

In [6]:
from torch.nn import ReLU, LeakyReLU, GELU
from torch.nn import Module, ModuleList, Sequential, Conv1d, BatchNorm1d, Dropout

# Maps the activation name used in the model config to the torch.nn class that implements it.
activations = dict(
    relu=ReLU,
    leaky_relu=LeakyReLU,
    gelu=GELU,
)


class Model(Module):
    """
    QuartzNet style network: a convolutional encoder followed by a decoder
    that turns the encoded features into per-timestep log-probabilities.

    https://arxiv.org/pdf/1910.10261.pdf

    Attributes read from `config`:
        stride   -- first stride value of the first block
        alphabet -- list of output labels
        features -- number of filters of the last block (decoder input width)
    """
    def __init__(self, config):
        super().__init__()
        blocks = config['block']
        self.stride = blocks[0]['stride'][0]
        self.alphabet = config['labels']['labels']
        self.features = blocks[-1]['filters']
        # the encoder is created before the decoder
        self.encoder = Encoder(config)
        self.decoder = Decoder(self.features, len(self.alphabet))

    def forward(self, x):
        """Run `x` through the encoder, then the decoder, and return the decoder output."""
        return self.decoder(self.encoder(x))


class Encoder(Module):
    """
    Builds the model encoder: one `Block` per entry of `config['block']`,
    chained in a `Sequential`.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        input_features = config['input']['features']
        # a single activation instance is created and handed to every block
        shared_activation = activations[config['encoder']['activation']]()
        block_cfgs = config['block']

        # each block consumes the filters produced by the block before it;
        # the first block consumes the raw input features
        input_widths = [input_features] + [cfg['filters'] for cfg in block_cfgs[:-1]]

        self.encoder = Sequential(*[
            Block(
                width, cfg['filters'], shared_activation,
                repeat=cfg['repeat'], kernel_size=cfg['kernel'],
                stride=cfg['stride'], dilation=cfg['dilation'],
                dropout=cfg['dropout'], residual=cfg['residual'],
                separable=cfg['separable'],
            )
            for width, cfg in zip(input_widths, block_cfgs)
        ])

    def forward(self, x):
        """Wrap `x` in a single-element list (the form the blocks exchange) and encode it."""
        return self.encoder([x])


class TCSConv1d(Module):
    """
    Time-Channel Separable 1D Convolution

    With `separable=True` the layer is a depthwise convolution (one filter per
    input channel, carrying the kernel size, stride, padding and dilation)
    followed by a pointwise 1x1 convolution that mixes channels. With
    `separable=False` it is a single ordinary `Conv1d`.

    `groups` is applied to the pointwise convolution (separable) or to the
    ordinary convolution (non-separable). When `groups > 1` the output channels
    are shuffled with `channel_shuffle` so information can cross group
    boundaries in the next layer. Grouping is deliberately not restricted to
    separable convolutions.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, separable=False):

        super(TCSConv1d, self).__init__()
        self.separable = separable
        self.groups = groups

        if separable:
            self.depthwise = Conv1d(
                in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, bias=bias, groups=in_channels
            )

            # The depthwise convolution has already applied `stride`; the
            # pointwise convolution must use stride 1, otherwise a strided
            # separable layer would downsample twice (by stride ** 2) and no
            # longer match the non-separable layer with the same arguments.
            self.pointwise = Conv1d(
                in_channels, out_channels, kernel_size=1, stride=1,
                dilation=dilation, bias=bias, padding=0, groups=groups
            )
        else:
            self.conv = Conv1d(
                in_channels, out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation, bias=bias, groups=groups
            )

    def forward(self, x):
        """Apply the convolution(s) to `x` and shuffle channels when grouping is used."""
        if self.separable:
            x = self.depthwise(x)
            x = self.pointwise(x)
        else:
            x = self.conv(x)
        if self.groups > 1:
            x = channel_shuffle(x, self.groups)
        return x
      
def channel_shuffle(x, groups):
    # type: (torch.Tensor, int) -> torch.Tensor
    """
    Interleave the channels of a 3-D tensor across `groups`.

    The channel axis (dim 1) is split into `groups` consecutive chunks and the
    chunks are interleaved, so output channels are ordered
    group0[0], group1[0], ..., group0[1], group1[1], ...
    The output has the same shape as `x`.
    """
    n_batch, n_channels, n_steps = x.data.size()
    per_group = n_channels // groups

    # (batch, groups, per_group, steps) -> swap the two channel axes -> flatten them back
    grouped = x.view(n_batch, groups, per_group, n_steps)
    interleaved = grouped.transpose(1, 2).contiguous()
    return interleaved.view(n_batch, -1, n_steps)

class Block(Module):
    """
    TCSConv, Batch Normalisation, Activation, Dropout

    One encoder block made of `repeat` (TCSConv1d -> BatchNorm1d) units with
    activation + dropout between consecutive units. If `residual` is set, a
    1x1 TCSConv1d + BatchNorm1d of the block input is added to the output of
    the last unit. A final activation + dropout is applied to the result.

    Parameters
    ----------
    in_channels, out_channels : int
        Channel counts of the block input and output.
    activation : Module
        Activation module instance; the same instance is reused at every
        position inside the block.
    repeat : int
        Number of (TCSConv1d -> BatchNorm1d) units.
    kernel_size, stride, dilation : int or single-element list/tuple
        Convolution geometry. Both the scalar form (the signature defaults)
        and the list form used by the model config (e.g. ``[33]``) are accepted.
    dropout : float
        Dropout probability.
    residual : bool
        Whether to add the residual branch.
    separable : bool
        Whether the convolutions are time-channel separable.
    groups : int or None
        Number of groups passed to the convolutions of the main path. ``None``
        (default) keeps the previous behaviour: 4 when `separable`, else 1.

    Raises
    ------
    ValueError
        If both stride and dilation are greater than 1.
    """
    def __init__(self, in_channels, out_channels, activation, repeat=5, kernel_size=1, stride=1, dilation=1, dropout=0.0, residual=False, separable=False, groups=None):

        super(Block, self).__init__()

        self.use_res = residual
        self.conv = ModuleList()

        if groups is None:
            groups = 4 if separable else 1
        self.groups = groups

        _in_channels = in_channels
        # The geometry may arrive as scalars or as one-element lists; padding
        # is computed from the scalar values.
        padding = self.get_padding(
            self._first(kernel_size), self._first(stride), self._first(dilation)
        )

        # add the first n - 1 convolutions + activation
        for _ in range(repeat - 1):
            self.conv.extend(
                self.get_tcs(
                    _in_channels, out_channels, kernel_size=kernel_size,
                    stride=stride, dilation=dilation,
                    padding=padding, separable=separable, groups=self.groups
                )
            )

            self.conv.extend(self.get_activation(activation, dropout))
            _in_channels = out_channels

        # add the last conv and batch norm
        self.conv.extend(
            self.get_tcs(
                _in_channels, out_channels,
                kernel_size=kernel_size,
                stride=stride, dilation=dilation,
                padding=padding, separable=separable, groups=self.groups
            )
        )

        # add the residual connection (1x1, ungrouped, non-separable)
        if self.use_res:
            self.residual = Sequential(*self.get_tcs(in_channels, out_channels))

        # add the activation and dropout
        self.activation = Sequential(*self.get_activation(activation, dropout))

    @staticmethod
    def _first(value):
        """Return the first element of a list/tuple, or `value` itself when it is a scalar."""
        return value[0] if isinstance(value, (list, tuple)) else value

    def get_activation(self, activation, dropout):
        """Return the (activation, Dropout) pair that follows a convolution unit."""
        return activation, Dropout(p=dropout)

    def get_padding(self, kernel_size, stride, dilation):
        """Return the padding `(kernel_size // 2) * dilation` for scalar arguments."""
        if stride > 1 and dilation > 1:
            raise ValueError("Dilation and stride can not both be greater than 1")
        return (kernel_size // 2) * dilation

    def get_tcs(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, padding=0, bias=False, separable=False, groups=1):
        """Return one convolution unit as a list: [TCSConv1d, BatchNorm1d]."""
        return [
            TCSConv1d(
                in_channels, out_channels, kernel_size,
                stride=stride, dilation=dilation, padding=padding,
                bias=bias, separable=separable, groups=groups
            ),
            BatchNorm1d(out_channels, eps=1e-3, momentum=0.1)
        ]

    def forward(self, x):
        """
        Apply the block to `x[0]` and return the result wrapped in a
        single-element list (the form blocks exchange inside the encoder).
        """
        _x = x[0]
        for layer in self.conv:
            _x = layer(_x)
        if self.use_res:
            _x += self.residual(x[0])
        return [self.activation(_x)]


class Decoder(Module):
    """
    Decoder: a 1x1 convolution (with bias) from `features` channels to
    `classes` channels, followed by a log-softmax over the classes.
    """
    def __init__(self, features, classes):
        super().__init__()
        classifier = Conv1d(features, classes, kernel_size=1, bias=True)
        self.layers = Sequential(classifier)

    def forward(self, x):
        """
        Take the last element of the list `x`, project it to class scores and
        return log-probabilities with the class axis moved last
        (dims 1 and 2 swapped).
        """
        scores = self.layers(x[-1])
        return scores.transpose(1, 2).log_softmax(dim=2)


#### Training options
Default options are set, and ranges are sensible, but most combinations of settings are untested.

The default settings will train on a small amount of data (10,000 signal chunks) for a small number of epochs (20). This is unlikely to produce an accurate, generalisable model, but will train relatively quickly.

After modifying this cell, use Runtime -> Run after so that all cells between this one and the main training loop are re-run with the new settings.

A train_proportion of 0.90 will use 90% of the data for training and 10% for validation.

No dropout is applied by default, but in order to avoid overfitting on small data sets it may be necessary to apply dropout (e.g. of 0.5), or other regularisation techniques.

In [7]:
# Training options. The trailing `#@param` comments are Colab form annotations.
model_savepath = Path("train_like_bonito/models")  # parent folder for experiment output
learning_rate = 0.001 #@param {type:"number"}
random_seed = 25 #@param {type:"integer"}
epochs = 20 #@param {type:"slider", min:1, max:1000, step:1}
batch_size = 16 #@param [2, 4, 8, 16, 28] {type:"raw"}
num_chunks = 10000 #@param [10, 100, 1000, 10000, 100000] {type:"raw"}
# fraction of the selected chunks used for training; the remainder is used for validation
train_proportion = 0.80 #@param {type:"slider", min:0.5, max:0.95, step:0.05}
dropout = 0.0 #@param {type:"slider", min:0.0, max:0.8}

In [8]:
# Initialise random libs and setup cudnn.
# Re-run this cell before every training run that should be comparable:
# anything that draws random numbers afterwards advances the generators.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# We exploit the GPU for training when one is available; otherwise fall back
# to the CPU so the notebook still runs (slowly) instead of failing when the
# model is moved to the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Prepare data according to the values set in 'Training options'

In [9]:
# Keep the first `num_chunks` examples of every array, then reorder all four
# arrays with one shared random permutation so that each signal stays aligned
# with its length, its target and its target length.
head = slice(None, num_chunks)
shuf = np.random.permutation(full_chunks[head].shape[0])

chunks, chunk_lengths, targets, target_lengths = (
    full_array[head][shuf]
    for full_array in (full_chunks, full_chunk_lengths, full_targets, full_target_lengths)
)

# Row index separating the two partitions: rows [:split] and rows [split:]
split = np.floor(chunks.shape[0] * train_proportion).astype(np.int32)

In [10]:
# Write the dropout chosen in the training options into every block of the
# model config, then display the resulting config.
for block_cfg in quartznet_config['block']:
    block_cfg.update(dropout=dropout)
quartznet_config

{'model': 'QuartzNet',
 'name': 'bonito',
 'pred_out_scale': 3,
 'labels': {'labels': ['N', 'A', 'C', 'G', 'T']},
 'input': {'features': 1},
 'encoder': {'activation': 'relu'},
 'block': [{'filters': 256,
   'repeat': 1,
   'kernel': [33],
   'stride': [3],
   'dilation': [1],
   'dropout': 0.0,
   'residual': False,
   'separable': False},
  {'filters': 256,
   'repeat': 5,
   'kernel': [33],
   'stride': [1],
   'dilation': [1],
   'dropout': 0.0,
   'residual': True,
   'separable': True},
  {'filters': 256,
   'repeat': 5,
   'kernel': [39],
   'stride': [1],
   'dilation': [1],
   'dropout': 0.0,
   'residual': True,
   'separable': True},
  {'filters': 512,
   'repeat': 5,
   'kernel': [51],
   'stride': [1],
   'dilation': [1],
   'dropout': 0.0,
   'residual': True,
   'separable': True},
  {'filters': 512,
   'repeat': 5,
   'kernel': [63],
   'stride': [1],
   'dilation': [1],
   'dropout': 0.0,
   'residual': True,
   'separable': True},
  {'filters': 512,
   'repeat': 5,
  

## Training and test functions

In [11]:

# 'Connectionist Temporal Classification' (CTC) loss function
# https://distill.pub/2017/ctc/
criterion = nn.CTCLoss(reduction='mean')

def train(log_interval, model, device, train_loader,
          optimizer, epoch, use_amp=False):
    """
    Train `model` for one pass over `train_loader`.

    Parameters
    ----------
    log_interval : unused; kept so existing callers keep working.
    model : module exposing a `stride` attribute; called on each data batch.
    device : device the data and targets are moved to.
    train_loader : sized iterable of (data, out_lengths, target, lengths)
        batches. `out_lengths` and `lengths` are assumed to be integer
        tensors (signal lengths and target lengths) -- TODO confirm against
        the dataset class.
    optimizer : optimizer updated once per batch.
    epoch : epoch number, only used in the progress message.
    use_amp : unused; kept so existing callers keep working.

    Returns
    -------
    (loss of the last batch as a float, elapsed seconds)

    Raises
    ------
    ValueError
        If `train_loader` yields no batches.
    """
    t0 = time.perf_counter()

    model.train()

    sys.stderr.write("\n" + "Training epoch: " + str(epoch) + "\n")
    progress_bar = tqdm(total=len(train_loader), leave=True, ncols=100)

    loss = None
    try:
        for data, out_lengths, target, lengths in train_loader:

            optimizer.zero_grad()

            data = data.to(device)
            target = target.to(device)

            # forward pass
            log_probs = model(data)

            # The model shortens the signal by `model.stride`. CTCLoss needs
            # integer lengths, so use integer division: `/` yields floats on
            # current PyTorch, which CTCLoss rejects.
            input_lengths = out_lengths // model.stride

            # calculate loss
            loss = criterion(log_probs.transpose(0, 1), target, input_lengths, lengths)

            # backward pass
            loss.backward()

            # update weights
            optimizer.step()
            progress_bar.refresh()
            progress_bar.update(1)
            progress_bar.set_description("Loss: " + str(loss.item()))
            sys.stderr.flush()
    finally:
        # close the bar even if a batch fails, so the next bar renders cleanly
        progress_bar.close()

    if loss is None:
        raise ValueError("train_loader yielded no batches")

    return loss.item(), time.perf_counter() - t0


def test(model, device, test_loader):

    model.eval()
    test_loss = 0
    predictions = []
    prediction_lengths = []

    with torch.no_grad():
        for batch_idx, (data, out_lengths, target, lengths) in enumerate(test_loader, start=1):
            data, target = data.to(device), target.to(device)
 
            # forward pass
            log_probs = model(data)
 
            # calculate loss
            test_loss += criterion(log_probs.transpose(1, 0), target, out_lengths / model.stride, lengths)

            # accumulate output probabilities
            predictions.append(torch.exp(log_probs).cpu())
            prediction_lengths.append(out_lengths / model.stride)

    predictions = np.concatenate(predictions)
    lengths = np.concatenate(prediction_lengths)

    # convert probabilities to sequences
    references = [decode_ref(target, model.alphabet) for target in test_loader.dataset.targets]
    sequences = [decode(post[:n], model.alphabet) for post, n in zip(predictions, lengths)]

    # align predicted sequences with true sequences and calculate accuracy
    if all(map(len, sequences)):
        accuracies = list(starmap(accuracy, zip(references, sequences)))
    else:
        accuracies = [0]

    # measure average accuracies over entire set of validation chunks
    mean = np.mean(accuracies)
    median = np.median(accuracies)

    return test_loss.item() / batch_idx, mean, median

## Main Training Loop

In [12]:
#@title Set experiment name
# Name of the sub-folder (inside model_savepath) that receives this run's output.
experiment_name = "bonito_groups_11" #@param {type:"string"}

In [13]:
def kaiming_uniformise_(model):
    """
    Re-initialise, in place, the weight of every `nn.Conv1d` found in
    `model.modules()` with `nn.init.kaiming_uniform_`. Other modules and the
    convolution biases are left untouched. Returns None.
    """
    conv_layers = (module for module in model.modules() if isinstance(module, nn.Conv1d))
    for conv in conv_layers:
        nn.init.kaiming_uniform_(conv.weight)

In [None]:
# prevent overwriting of data
workdir = os.path.join(model_savepath, experiment_name)
if os.path.isdir(workdir):
    raise IOError('{} already exists. Select an alternative model_savepath.'.format(workdir))
os.makedirs(workdir)

# data generators: rows before `split` are used for training, the rest for validation
train_dataset = ChunkDataSet(chunks[:split], chunk_lengths[:split],
                             targets[:split], target_lengths[:split])
test_dataset = ChunkDataSet(chunks[split:], chunk_lengths[split:],
                            targets[split:], target_lengths[split:])

# data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         num_workers=4, pin_memory=True)

# load bonito model
model = Model(quartznet_config)
model.to(device)
model.train()

# set optimizer and learning rate scheduler
optimizer = AdamW(model.parameters(), amsgrad=True, lr=learning_rate)
# The scheduler is stepped once per epoch (end of the loop below), so its
# period is counted in epochs. Counting it in batches would leave the
# learning rate almost unchanged over the whole run.
schedular = CosineAnnealingLR(optimizer, T_max=epochs)

# report loss every
interval = 500 / num_chunks
log_interval = np.floor(len(train_dataset) / batch_size * interval)

# record the settings of this experiment next to its results
exp_config = os.path.join(workdir, "experimental.log")
with open(exp_config, 'a') as c:
    c.write('Num training chunks: {}'.format(num_chunks) + '\n')
    c.write('learning rate: {}'.format(learning_rate) + '\n')
    c.write('random seed: {}'.format(random_seed) + '\n')
    c.write('epochs: {}'.format(epochs) + '\n')
    c.write('batch_size: {}'.format(batch_size) + '\n')
    c.write('train proportion: {}'.format(train_proportion) + '\n')
    c.write('dropout: {}'.format(dropout) + '\n')

# Per-epoch result frames; `training_results` is rebuilt from them after every
# epoch so it stays usable if the loop is interrupted.
# (DataFrame.append was removed in pandas 2.0.)
epoch_results = []
training_results = pd.DataFrame()

for epoch in range(1, epochs + 1):

    train_loss, duration = train(log_interval, model, device,
                                 train_loader, optimizer, epoch)

    test_loss, mean, median = test(model, device, test_loader)

    # collate training and validation metrics
    epoch_result = pd.DataFrame(
        {'time':[datetime.today()],
         'duration':[int(duration)],
         'epoch':[epoch],
         'train_loss':[train_loss],
         'validation_loss':[test_loss],
         'validation_mean':[mean],
         'validation_median':[median]})

    # save model weights
    weights_path = os.path.join(workdir, "weights_%s.tar" % epoch)
    torch.save(model.state_dict(), weights_path)

    # update log file; write the column header only when the file is created,
    # so the appended log stays a valid tab-separated table
    log_path = os.path.join(workdir, "training.log")
    epoch_result.to_csv(log_path, mode='a', sep='\t', index=False,
                        header=not os.path.exists(log_path))

    display(epoch_result)
    epoch_results.append(epoch_result)
    training_results = pd.concat(epoch_results, ignore_index=True)

    schedular.step()

display(training_results)


Training epoch: 1
Loss: 0.9193114042282104:  72%|█████████████████████████▏         | 360/500 [01:08<00:26,  5.34it/s]

## Tries (after 1 epoch): 

With Kaiming:

1: 0.9324800968170166

2: 0.9282790422439575

Without Kaiming:

1: 0.7507233619689941

Without Kaiming, dropout=0.5:

1: 1.2488034963607788

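I am unsure why runs with identical settings differ even though the random seed is fixed. Two likely causes: the seeding cell was not re-run before each training run, so the random generators had already advanced; and the CUDA implementation of CTC loss, which PyTorch documents as non-deterministic.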

## Grouping experiments

### Separable are grouped in 4
Loss: 0.7064507007598877: 100%|███████████████████████████████████| 500/500 [01:38<00:00,  5.15it/s]

In [14]:
# Groups = 4: total number of parameters in the current model
sum(map(torch.Tensor.numel, model.parameters()))

4024325

## Not grouped:
Loss: 0.7371149063110352: 100%|███████████████████████████████████| 500/500 [01:50<00:00,  4.56it/s]

In [15]:
# Groups = 1: total number of parameters in the current model
sum(map(torch.Tensor.numel, model.parameters()))

6678533