In [None]:
# install Bonito last version and Pydrive
# if the cell is run and a numpy warning pops up, restart kernel and run again the cell
# source: https://github.com/nanoporetech/bonito/blob/v0.4.0/notebooks/bonito-train.ipynb

!pip install -q ont-bonito
!pip install -U -q PyDrive
!pip install -q fast_ctc_decode

import os
import sys
import time
import random
from datetime import datetime
from itertools import starmap
from time import perf_counter
from functools import partial
import numpy as np
import pandas as pd
import toml
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import Module, ModuleList, Sequential, Conv1d, BatchNorm1d, Dropout, ReLU, SiLU
from torch.nn.functional import ctc_loss, log_softmax
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

from google.colab import auth
from google.colab import drive as gdrive
from oauth2client.client import GoogleCredentials
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive

from bonito.ctc.model import Model
from bonito.util import accuracy, decode_ref, permute, concat
from bonito.data import ChunkDataSet
from bonito.nn import Permute

from fast_ctc_decode import beam_search, viterbi_search


!pip install -q tensorly
!pip install -q tensorly-torch

# Tensor decomposition packages
import tensorly
from tltorch import FactorizedConv

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m229.7/229.7 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.7/154.7 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h

---

# GOOGLE CREDENTIALS

In [None]:
# Authenticate and create PyDrive client.
# NOTE(review): presumably authenticate_user() has to run before
# get_application_default() so the application-default credentials exist —
# this follows the usual Colab + PyDrive recipe; keep the statement order.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
# Module-level client; the download helpers below call drive.CreateFile(...).
drive = GoogleDrive(gauth)

# helper functions for importing data
def download_npy_from_link(fn, link):
    _, id = link.split('=')
    downloaded = drive.CreateFile({'id':id})
    downloaded.GetContentFile(fn)
    return np.load(fn)

def download_toml_from_link(fn, link):
    _, id = link.split('=')
    downloaded = drive.CreateFile({'id':id})
    downloaded.GetContentFile(fn)
    return toml.load(fn)

---
# DOWNLOAD SAMPLE DATASET

In [None]:
chunks_link = "https://drive.google.com/open?id=1aciNfQs53eFRwnMggInY-Uisi-owtmzY" #@param {type:"string"}
references_link = "https://drive.google.com/open?id=1kcs_hZMndUIDX2n8dTxGrAgCvt_TpUcH" #@param {type:"string"}
reference_lengths_link = "https://drive.google.com/open?id=1-r7XymddP_3gKFb-7ohB_t14u7u4SGLm" #@param {type:"string"}

# Model config, downloaded as dna_r9.4.1@v2.toml. An earlier v1 config is at
# https://drive.google.com/open?id=1hKKE2Fzp3jdNyZI2h8jnOuwxBOvXWjp6
quartznet_config_link = "https://drive.google.com/open?id=1IRDMrnE0WWeiRoioX7NHM5TXezl2jMkN" #@param {type:"string"}


In [None]:
print('Loading chunks.')
# Sections of squiggle that correspond with the target reference sequence.
# Variable length and zero padded (up to 4096 samples).
# Expected shape (1000000, 4096), dtype float32.
full_chunks = download_npy_from_link('chunks.npy', chunks_link)

print('Loading references.')
# Integer-encoded target sequence {'A': 1, 'C': 2, 'G': 3, 'T': 4}.
# Variable length and zero padded (default range between 128 and 256).
# Expected shape (1000000, 256), dtype uint8.
full_targets = download_npy_from_link('references.npy', references_link)

print('Loading reference lengths.')
# Length of each target sequence in references.npy; expected shape (1000000,).
# NOTE(review): previously documented as dtype uint8, but lengths can reach 256
# (see above), which uint8 cannot hold — check full_target_lengths.dtype.
full_target_lengths = download_npy_from_link('reference_lengths.npy', reference_lengths_link)

# Every chunk needs exactly one reference and one reference length.
assert full_chunks.shape[0] == full_targets.shape[0] == full_target_lengths.shape[0], (
    'row counts differ: chunks={}, references={}, lengths={}'.format(
        full_chunks.shape[0], full_targets.shape[0], full_target_lengths.shape[0]))

print('Loading quartznet config.')
# The structure of the model is defined by this config file; it will make sense
# to those familiar with QuartzNet (https://arxiv.org/pdf/1910.10261.pdf).
quartznet_config = download_toml_from_link("dna_r9.4.1@v2.toml", quartznet_config_link)

Loading chunks.
Loading references.
Loading reference lengths.
Loading quartznet config.


---
# Quartznet Model

In [None]:
class Model(Module):
    """
    QuartzNet-style basecaller: an `Encoder` stack followed by a CTC `Decoder`.

    Reference architecture: https://arxiv.org/pdf/1910.10261.pdf
    """
    def __init__(self, config):
        super().__init__()

        # Optional q-score calibration; identity calibration when the config has none.
        if 'qscore' in config:
            self.qbias = config['qscore']['bias']
            self.qscale = config['qscore']['scale']
        else:
            self.qbias, self.qscale = 0.0, 1.0

        blocks = config['block']
        self.config = config
        self.stride = blocks[0]['stride'][0]
        self.alphabet = config['labels']['labels']
        self.features = blocks[-1]['filters']
        # Registration order (encoder, then decoder) fixes the state_dict layout.
        self.encoder = Encoder(config)
        self.decoder = Decoder(self.features, len(self.alphabet))

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def decode(self, x, beamsize=5, threshold=1e-3, qscores=False, return_path=False):
        """
        Decode one sequence of log-probabilities into a base string.

        `x` is exponentiated and converted to a float32 numpy array. Viterbi search
        is used when `beamsize` is 1 or q-scores are requested; otherwise beam
        search with `beamsize` beams and pruning `threshold`.

        Returns the sequence, or ``(seq, path)`` when `return_path` is True.
        """
        probs = x.exp().cpu().numpy().astype(np.float32)
        use_viterbi = beamsize == 1 or qscores
        if use_viterbi:
            seq, path = viterbi_search(probs, self.alphabet, qscores, self.qscale, self.qbias)
        else:
            seq, path = beam_search(probs, self.alphabet, beamsize, threshold)
        return (seq, path) if return_path else seq


class Encoder(Module):
    """
    Builds the model encoder: one `Block` per entry of ``config['block']``, each
    block's input width being the previous block's ``filters`` (the first uses
    ``config['input']['features']``).
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activations = {"relu": ReLU, "swish": SiLU}

        in_features = config['input']['features']
        # A single activation instance is shared by every block.
        activation = self.activations[config['encoder']['activation']]()

        blocks = []
        for spec in config['block']:
            blocks.append(Block(
                in_features, spec['filters'], activation,
                repeat=spec['repeat'], kernel_size=spec['kernel'],
                stride=spec['stride'], dilation=spec['dilation'],
                dropout=spec['dropout'], residual=spec['residual'],
                separable=spec['separable'],
            ))
            in_features = spec['filters']

        self.encoder = Sequential(*blocks)

    def forward(self, x):
        return self.encoder(x)


class TCSConv1d(Module):
    """
    Time-Channel Separable 1D Convolution.

    With ``separable=True`` the convolution is factored into a depthwise
    convolution (``groups=in_channels``) followed by a 1x1 pointwise convolution
    that mixes channels; otherwise a single ``Conv1d`` is used.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, bias:
            forwarded to ``torch.nn.Conv1d``.
        groups: channel groups of the non-separable convolution. It was
            previously accepted but silently ignored. It is unused when
            ``separable=True``, where the depthwise conv always uses
            ``in_channels`` groups.
        separable: use the depthwise + pointwise factorisation.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, separable=False):

        super(TCSConv1d, self).__init__()
        self.separable = separable

        if separable:
            # This layer cannot be factorised until "groups is implemented in tensorly-torch".
            self.depthwise = Conv1d(
                in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, bias=bias, groups=in_channels
            )

            # kernel_size=1, so dilation has no effect on this layer.
            self.pointwise = Conv1d(
                in_channels, out_channels, kernel_size=1, stride=1,
                dilation=dilation, bias=bias, padding=0
            )

        else:
            self.conv = Conv1d(
                in_channels, out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation,
                groups=groups, bias=bias
            )

    def forward(self, x):
        if self.separable:
            x = self.depthwise(x)
            x = self.pointwise(x)
        else:
            x = self.conv(x)
        return x


class Block(Module):
    """
    QuartzNet block: TCSConv, Batch Normalisation, Activation, Dropout.

    Structure of the main path:
    - It is `repeat` x (TCSConv1d -> BatchNorm1d), with the activation and dropout
      inserted between consecutive repeats.
    - Each of the `repeat` convolutions uses `stride`.
    - When `residual` is True, a 1x1 TCSConv1d + BatchNorm1d projection of the block
      input is added to the main path.
    - A final activation + dropout is then applied.

    `kernel_size`, `stride` and `dilation` may be plain ints or 1-element
    sequences (the encoder passes the list values stored in the model config).
    """
    def __init__(self, in_channels, out_channels, activation, repeat=5, kernel_size=1, stride=1, dilation=1, dropout=0.0, residual=False, separable=False):

        super(Block, self).__init__()

        self.use_res = residual
        self.conv = ModuleList()

        _in_channels = in_channels
        # get_padding normalises ints and 1-element sequences. Indexing with [0]
        # here (as before) raised TypeError for the int defaults.
        padding = self.get_padding(kernel_size, stride, dilation)

        # add the first n - 1 convolutions + activation
        for _ in range(repeat - 1):
            self.conv.extend(
                self.get_tcs(
                    _in_channels, out_channels, kernel_size=kernel_size,
                    stride=stride, dilation=dilation,
                    padding=padding, separable=separable
                )
            )

            self.conv.extend(self.get_activation(activation, dropout))
            _in_channels = out_channels

        # add the last conv and batch norm
        self.conv.extend(
            self.get_tcs(
                _in_channels, out_channels,
                kernel_size=kernel_size,
                stride=stride, dilation=dilation,
                padding=padding, separable=separable
            )
        )

        # add the residual connection (1x1 conv, stride 1, no padding)
        # NOTE(review): with stride > 1 on the main path the residual length would
        # not match; nothing here guards against that combination.
        if self.use_res:
            self.residual = Sequential(*self.get_tcs(in_channels, out_channels))

        # add the activation and dropout
        self.activation = Sequential(*self.get_activation(activation, dropout))

    def get_activation(self, activation, dropout):
        """Return the (activation, Dropout) pair placed after a conv + batch norm."""
        return activation, Dropout(p=dropout)

    def get_padding(self, kernel_size, stride, dilation):
        """
        'Same'-style padding ``(kernel_size // 2) * dilation``.

        Each argument may be an int or a 1-element list/tuple.

        Raises:
            ValueError: if both stride and dilation exceed 1.
        """
        kernel_size, stride, dilation = (
            value[0] if isinstance(value, (list, tuple)) else value
            for value in (kernel_size, stride, dilation)
        )
        if stride > 1 and dilation > 1:
            raise ValueError("Dilation and stride can not both be greater than 1")
        return (kernel_size // 2) * dilation

    def get_tcs(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, padding=0, bias=False, separable=False):
        """Return [TCSConv1d, BatchNorm1d] for one convolution step."""
        return [
            TCSConv1d(
                in_channels, out_channels, kernel_size,
                stride=stride, dilation=dilation, padding=padding,
                bias=bias, separable=separable
            ),
            BatchNorm1d(out_channels, eps=1e-3, momentum=0.1)
        ]

    def forward(self, x):
        _x = x
        for layer in self.conv:
            _x = layer(_x)
        if self.use_res:
            _x = _x + self.residual(x)
        return self.activation(_x)


class Decoder(Module):
    """
    CTC output head.

    It applies a 1x1 convolution from encoder features to the label classes,
    followed by ``Permute([2, 0, 1])`` and a log-softmax over dim 2. Assuming
    encoder output is (batch, features, time), the result is
    (time, batch, classes) log-probabilities, matching the ``T, N, C`` unpacking
    in the loss.
    """
    def __init__(self, features, classes):
        super().__init__()
        projection = Conv1d(features, classes, kernel_size=1, bias=True)
        self.layers = Sequential(projection, Permute([2, 0, 1]))

    def forward(self, x):
        scores = self.layers(x)
        return log_softmax(scores, dim=2)

---

# TRAINING OPTIONS

Training options
Default options are set, and ranges are sensible, but most combinations of settings are untested.

The default settings train on a small amount of data (10,000 signal chunks, split between training and validation by `train_proportion`) for a small number of epochs (20). This is unlikely to produce an accurate, generalisable model, but it will train relatively quickly.

After modifying this cell, use Runtime → Run after, so that every cell from this one through the main training loop is re-run with the new settings.

A train_proportion of 0.90 will use 90% of the data for training and 10% for validation.

No dropout is applied by default, but in order to avoid overfitting on small data sets it may be necessary to apply dropout (e.g. of 0.5), or other regularisation techniques.

In [None]:

# Weights and logs are written to <model_savepath>/<experiment_name>/.
model_savepath = '/content/drive/My Drive/Quartznet_weights/' #@param {type:"string"}
learning_rate = 0.001 #@param {type:"number"}
random_seed = 25 #@param {type:"integer"}
epochs = 20 #@param {type:"slider", min:1, max:1000, step:1}
batch_size = 16 #@param [2, 4, 8, 16, 28] {type:"raw"}
num_chunks = 10000 #@param [10, 100, 1000, 10000, 100000] {type:"raw"}
# Fraction of the chunks used for training; the remainder is held out for validation.
train_proportion = 0.80 #@param {type:"slider", min:0.5, max:0.95, step:0.05}
dropout = 0.0 #@param {type:"slider", min:0.0, max:0.8, step:0.1}

# Fail fast on values that would break the split or the Dropout layers later.
assert 0.0 < train_proportion < 1.0, 'train_proportion must be strictly between 0 and 1'
assert 0.0 <= dropout < 1.0, 'dropout must be in [0, 1)'

In [None]:

# Initialise random libs and set up cuDNN for deterministic, repeatable runs.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Train on the GPU. Fall back to the (much slower) CPU when the runtime has no
# GPU, instead of failing later at `model.to(device)`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    print('WARNING: no CUDA device available; training on the CPU will be very slow.')
    device = torch.device("cpu")

In [None]:
# subset: take the first `num_chunks` examples of each array
chunks = full_chunks[:num_chunks]
targets = full_targets[:num_chunks]
target_lengths = full_target_lengths[:num_chunks]

# shuffle all three with one shared permutation so rows stay aligned
shuf = np.random.permutation(chunks.shape[0])
chunks = chunks[shuf]
targets = targets[shuf]
target_lengths = target_lengths[shuf]

# rows [:split] are used for training, rows [split:] for validation
split = int(np.floor(chunks.shape[0] * train_proportion))
# An empty validation set would otherwise crash `test` with an obscure error.
assert 0 < split < chunks.shape[0], (
    'splitting {} chunks at index {} leaves the train or validation set empty; '
    'adjust num_chunks or train_proportion'.format(chunks.shape[0], split))

In [None]:
# Apply the dropout chosen in the options cell to every encoder block, then show the config.
for block_spec in quartznet_config['block']:
    block_spec.update(dropout=dropout)
quartznet_config

In [None]:
# Training helpers adapted from bonito v0.3.0 (https://github.com/nanoporetech/bonito/blob/v0.3.0/bonito/training.py)

def ctc_label_smoothing_loss(log_probs, targets, lengths, weights):
    """
    CTC loss plus a label-smoothing term.

    Args:
        log_probs: 3-D (T, N, C) log-probabilities.
        targets: padded target label indices, as accepted by ``ctc_loss``.
        lengths: target length of each of the N sequences.
        weights: per-class weights, broadcast against the class axis.

    Returns:
        dict with 'loss' (the sum), 'ctc_loss' and 'label_smooth_loss'.
    """
    n_steps, n_batch, _ = log_probs.shape
    # Every sequence in the batch spans the full time axis.
    input_lengths = torch.full(size=(n_batch, ), fill_value=n_steps, dtype=torch.int64)
    ctc = ctc_loss(log_probs.to(torch.float32), targets, input_lengths, lengths, reduction='mean')
    # Minimising this term raises the log-probability of each class in proportion to `weights`.
    smoothing = -((log_probs * weights.to(log_probs.device)).mean())
    return {'loss': ctc + smoothing, 'ctc_loss': ctc, 'label_smooth_loss': smoothing}

def train(model, device, train_loader, optimizer, use_amp=False, criterion=None, lr_scheduler=None, loss_log=None):
    """
    Run one training epoch over `train_loader`.

    Args:
        model: module returning (T, N, C) log-probabilities; must expose
            ``alphabet`` when `criterion` is None.
        device: device each batch is moved to.
        train_loader: yields ``(data, targets, lengths)`` batches.
        optimizer: stepped once per batch.
        use_amp: run the forward pass under CUDA autocast and scale the loss
            with ``GradScaler``. Previously this flag skipped ``backward()``
            entirely, so no gradients were computed and nothing was learned.
        criterion: ``(log_probs, targets, lengths) -> tensor | dict``; defaults
            to ``ctc_label_smoothing_loss``. A dict must contain ``'loss'``.
        lr_scheduler: stepped once per batch when given.
        loss_log: optional list that receives one dict of smoothed losses per batch.

    Returns:
        (smoothed training loss, elapsed seconds). The loss is NaN if the
        loader yielded no batches.
    """
    if criterion is None:
        C = len(model.alphabet)
        # 0.4 on class 0 (the CTC blank under ctc_loss's default blank=0), 0.1 spread over the rest.
        weights = torch.cat([torch.tensor([0.4]), (0.1 / (C - 1)) * torch.ones(C - 1)]).to(device)
        criterion = partial(ctc_label_smoothing_loss, weights=weights)

    # With enabled=False both the scaler and autocast are pass-throughs, so the
    # default path is plain fp32 training. The scaler state is per epoch.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    chunks = 0
    model.train()
    t0 = perf_counter()

    progress_bar = tqdm(
        total=len(train_loader), desc='[0/{}]'.format(len(train_loader.dataset)),
        ascii=True, leave=True, ncols=100, bar_format='{l_bar}{bar}| [{elapsed}{postfix}]'
    )
    smoothed_loss = {}

    with progress_bar:

        for data, targets, lengths in train_loader:

            optimizer.zero_grad()

            chunks += data.shape[0]

            with torch.cuda.amp.autocast(enabled=use_amp):
                log_probs = model(data.to(device))
                losses = criterion(log_probs, targets.to(device), lengths.to(device))

            if not isinstance(losses, dict):
                losses = {'loss': losses}

            scaler.scale(losses['loss']).backward()
            scaler.step(optimizer)
            scaler.update()

            if lr_scheduler is not None:
                lr_scheduler.step()

            # Exponential moving average of every loss term, seeded with the first batch.
            if not smoothed_loss:
                smoothed_loss = {k: v.item() for k, v in losses.items()}
            smoothed_loss = {k: 0.01 * v.item() + 0.99 * smoothed_loss[k] for k, v in losses.items()}

            progress_bar.set_postfix(loss='%.4f' % smoothed_loss['loss'])
            progress_bar.set_description("[{}/{}]".format(chunks, len(train_loader.dataset)))
            progress_bar.update()

            if loss_log is not None:
                loss_log.append({'chunks': chunks, 'time': perf_counter() - t0, **smoothed_loss})

    if not smoothed_loss:
        # Empty loader: nothing was trained (previously a KeyError).
        return float('nan'), perf_counter() - t0
    return smoothed_loss['loss'], perf_counter() - t0

def test(model, device, test_loader, min_coverage=0.5, criterion=None):
    """
    Evaluate `model` on `test_loader` without gradient tracking.

    Args:
        model: module returning (T, N, C) log-probabilities and exposing
            ``alphabet`` and ``decode``.
        device: device each batch is moved to.
        test_loader: yields ``(data, target, lengths)`` batches. It must not be
            shuffled, because decoded sequences are paired by position with
            ``test_loader.dataset.targets``.
        min_coverage: forwarded to ``accuracy``.
        criterion: as in ``train``; defaults to ``ctc_label_smoothing_loss``.
            When it returns a dict, only ``'ctc_loss'`` is accumulated.

    Returns:
        (mean loss per batch, mean accuracy, median accuracy). Sequences that
        decode to an empty string score 0.

    Raises:
        ValueError: if `test_loader` yields no batches (previously an
            UnboundLocalError on ``batch_idx``).
    """
    if criterion is None:
        C = len(model.alphabet)
        weights = torch.cat([torch.tensor([0.4]), (0.1 / (C - 1)) * torch.ones(C - 1)]).to(device)
        criterion = partial(ctc_label_smoothing_loss, weights=weights)

    seqs = []
    model.eval()
    test_loss = 0
    n_batches = 0
    accuracy_with_cov = lambda ref, seq: accuracy(ref, seq, min_coverage=min_coverage)

    with torch.no_grad():
        for data, target, lengths in test_loader:
            n_batches += 1
            log_probs = model(data.to(device))
            loss = criterion(log_probs, target.to(device), lengths.to(device))
            test_loss += loss['ctc_loss'] if isinstance(loss, dict) else loss
            # decode() takes one sequence at a time: move the batch axis to the front.
            seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

    if n_batches == 0:
        raise ValueError('test_loader yielded no batches; is the validation split empty?')

    refs = [
        decode_ref(target, model.alphabet) for target in test_loader.dataset.targets
    ]
    accuracies = [
        accuracy_with_cov(ref, seq) if len(seq) else 0. for ref, seq in zip(refs, seqs)
    ]

    mean = np.mean(accuracies)
    median = np.median(accuracies)
    return float(test_loss) / n_batches, mean, median

In [None]:
#@title Set experiment name
experiment_name = 'bonito_training_2' #@param {type:"string"}

# mount users drive to save data
gdrive.mount('/content/drive', force_remount=True)

# prevent overwriting of data
workdir = os.path.join(model_savepath, experiment_name)
if os.path.isdir(workdir):
    raise IOError('{} already exists. Select an alternative model_savepath.'.format(workdir))
os.makedirs(workdir)

# data generators
train_dataset = ChunkDataSet(chunks[:split], targets[:split], target_lengths[:split])
test_dataset = ChunkDataSet(chunks[split:], targets[split:], target_lengths[split:])


# Data loaders. The test loader must stay unshuffled: `test` pairs its
# predictions with `test_loader.dataset.targets` by position.
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         num_workers=2, pin_memory=True)

# load bonito model
model = Model(quartznet_config)
model.to(device)
model.train()

# 'Connectionist Temporal Classification' (CTC) loss function
# https://distill.pub/2017/ctc/
# NOTE: `train`/`test` use their own label-smoothed CTC criterion by default;
# this plain CTC loss is not passed to them.
criterion = nn.CTCLoss(reduction='mean')

# Optimizer and learning-rate scheduler. T_max counts optimizer steps, so the
# scheduler is stepped once per batch inside `train`. Stepping it once per
# epoch, as before, left the LR almost constant.
optimizer = AdamW(model.parameters(), amsgrad=True, lr=learning_rate)
schedular = CosineAnnealingLR(optimizer, epochs * len(train_loader))

# report loss every (not currently consumed by `train`)
interval = 500 / num_chunks
log_interval = np.floor(len(train_dataset) / batch_size * interval)

exp_config = os.path.join(workdir, "experimental.log")
with open(exp_config, 'a') as c:
    c.write('Num training chunks: {}'.format(num_chunks) + '\n')
    c.write('learning rate: {}'.format(learning_rate) + '\n')
    c.write('random seed: {}'.format(random_seed) + '\n')
    c.write('epochs: {}'.format(epochs) + '\n')
    c.write('batch_size: {}'.format(batch_size) + '\n')
    c.write('train proportion: {}'.format(train_proportion) + '\n')
    c.write('dropout: {}'.format(dropout) + '\n')

# Per-epoch metrics. DataFrame.append was removed in pandas 2.0, so rows are
# collected in a list and concatenated. `training_results` is rebuilt every
# epoch so partial results survive an interrupted run.
training_results = pd.DataFrame()
epoch_results = []
log_path = os.path.join(workdir, "training.log")

for epoch in range(1, epochs + 1):
    # v0.3.0 training loop; the LR schedule advances per batch.
    train_loss, duration = train(model, device, train_loader, optimizer, lr_scheduler=schedular)

    test_loss, mean, median = test(model, device, test_loader)

    # collate training and validation metrics
    epoch_result = pd.DataFrame(
        {'time':[datetime.today()],
         'duration':[int(duration)],
         'epoch':[epoch],
         'train_loss':[train_loss],
         'validation_loss':[test_loss],
         'validation_mean':[mean],
         'validation_median':[median]})

    # save model weights
    weights_path = os.path.join(workdir, "weights_%s.pt" % epoch)
    torch.save(model.state_dict(), weights_path)

    # Append to the log, writing the header only once so the file stays a single table.
    epoch_result.to_csv(log_path, mode='a', sep='\t', index=False,
                        header=not os.path.exists(log_path))

    display(epoch_result)
    epoch_results.append(epoch_result)
    training_results = pd.concat(epoch_results, ignore_index=True)

display(training_results)

Mounted at /content/drive


[8000/8000]: 100%|############################################################| [03:03, loss=1.6476]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:18:49.591373,183,1,1.647581,1.418957,0.0,0.0


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=1.0700]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:22:09.468562,181,2,1.070049,0.847747,74.733445,75.0


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.7510]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:25:27.654106,181,3,0.751013,0.558378,82.656385,82.976431


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.6661]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:28:46.612235,181,4,0.666071,0.495195,84.510008,84.970707


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.6236]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:32:04.079414,181,5,0.623561,0.469426,85.70323,86.097205


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.5863]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:35:22.036259,181,6,0.586273,0.446669,86.207312,86.676987


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.5640]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:38:39.672992,181,7,0.564041,0.429269,86.848583,87.35286


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.5376]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:41:57.489938,181,8,0.53763,0.421032,87.325642,87.707987


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.5173]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:45:15.667004,181,9,0.517254,0.417576,87.439047,87.916573


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.4968]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:48:32.941584,181,10,0.496756,0.402159,88.084509,88.539126


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.4811]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:51:50.119424,181,11,0.481125,0.407597,87.934662,88.365558


  training_results = training_results.append(epoch_result)
[8000/8000]: 100%|############################################################| [03:01, loss=0.4635]


Unnamed: 0,time,duration,epoch,train_loss,validation_loss,validation_mean,validation_median
0,2023-09-09 14:55:07.292593,181,12,0.463495,0.409323,87.99096,88.475836


  training_results = training_results.append(epoch_result)
[560/8000]:   7%|####2                                                        | [00:13, loss=0.4253]


KeyboardInterrupt: ignored