In [41]:
import os

from IPython.display import Markdown, display
def bash(commands):
    """Display a list of shell commands as one ``&&``-chained bash snippet.

    Nothing is executed: the commands are only rendered as a Markdown code
    block, each joined to the next with ``' && '`` followed by a newline.

    Args:
        commands: sequence of command strings, in execution order.
    """
    display(Markdown("```bash\n" + ' && \n'.join(commands) + "\n```"))


# All input data lives in one directory. Set FRANKENSTEIN_DATA_DIR to point
# elsewhere instead of editing hard-coded absolute paths; the default keeps the
# original location.
DATA_DIR = os.environ.get("FRANKENSTEIN_DATA_DIR", "/vol/bitbucket/bl1821/frankenstein/data").rstrip("/")
fasta = f"{DATA_DIR}/atotc.fasta"
pod5 = f"{DATA_DIR}/atotc.pod5"

frank = f"{DATA_DIR}/frank.fasta"
je = f"{DATA_DIR}/je.fasta"

# Baseline model

We start with the **Bonito** `dna_r10.4.1_e8.2_400bps_sup@v5.0.0` model in the `bsup` folder. We export it to the Dorado-compatible format, written to `dsup`, with the following command:

In [5]:
# Export the Bonito checkpoint in bsup/ to the Dorado model directory dsup/.
bash([
  "bonito export bsup --output dsup",
])

```bash
bonito export bsup --output dsup
```

We establish the baseline by basecalling the atotc reads with Dorado using the exported `dsup` model. The reads are aligned to the atotc reference and the alignments are then summarised:

In [None]:
# Evaluate the exported baseline model (dsup) on the atotc reads.
bash([
  "mkdir -p sup_test",
  "cd sup_test",
  # Basecall with alignment to the atotc reference; tee both shows stderr
  # and saves it to sup.log.
  f"dorado basecaller ../dsup {pod5} --reference {fasta} > sup.bam 2> >(tee sup.log >&2)",
  # TSV summary, read later by tsv_dict("sup_test", "sup").
  "dorado summary sup.bam > sup.tsv",
  "cd ..",
])

```bash
mkdir -p tsup && 
cd tsup && 
dorado basecaller ../dsup /vol/bitbucket/bl1821/frankenstein/data/atotc.pod5 --reference /vol/bitbucket/bl1821/frankenstein/data/atotc.fasta > sup.bam 2> >(tee sup.log >&2) && 
dorado summary sup.bam > sup.tsv && 
cd ..
```

# Improved model
We simulate `frank.pod5` from `frank.fasta` with squigulator, generating 100,000 reads with a read length of 300 nt (`-r 300`).

In [22]:
# Simulate the Frankenstein training reads into train_data/frank/frank.pod5.
bash(["mkdir -p train_data/frank",
      "cd train_data",
      f"cp {frank} .",
      # 100000 reads, -r 300, fixed seed so the simulation is reproducible.
      "squigulator frank.fasta -x dna-r10-prom -o frank.slow5"
      " -n 100000 -r 300 --ont-friendly yes --seed 42",
      # SLOW5 -> POD5, then drop the intermediate SLOW5 file.
      "blue-crab s2p frank.slow5 -o frank.pod5",
      "rm -f frank.slow5",
      "mv frank.pod5 frank/",
      "cd .."])

```bash
mkdir -p train_data/frank && 
cd train_data && 
cp /vol/bitbucket/bl1821/frankenstein/data/frank.fasta . && 
squigulator frank.fasta -x dna-r10-prom -o frank.slow5 -n 100000 -r 300 --ont-friendly yes --seed 42 && 
blue-crab s2p frank.slow5 -o frank.pod5 && 
rm -f frank.slow5 && 
mv frank.pod5 frank/ && 
cd ..
```

We run the Bonito basecaller with the `--save-ctc` flag to extract training data from the Frankenstein reads. We lower `--min-accuracy-save-ctc` to 0.9 (the default is 0.99) because the current model's accuracy is still low, and a stricter threshold would keep fewer reads.

In [19]:
# Basecall the simulated Frankenstein reads with bsup and save CTC training data.
# The 0.9 accuracy cut-off (below the 0.99 default) accommodates the still
# inaccurate starting model.
bash(["cd train_data",
      "bonito basecaller ../bsup frank --reference frank.fasta"
      " --save-ctc --min-accuracy-save-ctc 0.9 > frank/frank.bam",
      "cd .."])

```bash
cd train_data && 
bonito basecaller ../bsup frank --reference frank.fasta --save-ctc --min-accuracy-save-ctc 0.9 > frank/frank.bam && 
cd ..
```

We perform the training:

In [27]:
# Train bfrank on train_data/frank, starting from the bsup weights.
bash([
  "bonito train bfrank"
  " --directory train_data/frank"
  " --epochs 5 --chunks 200 --valid-chunks 50"
  " --pretrained bsup --batch 32"
])

```bash
bonito train bfrank --directory train_data/frank --epochs 5 --chunks 200 --valid-chunks 50 --pretrained bsup --batch 32
```

We export the trained model:

In [28]:
# Export the trained bfrank checkpoint to the Dorado model directory dfrank/.
bash([
  "bonito export bfrank --output dfrank",
])

```bash
bonito export bfrank --output dfrank
```

We then repeat the baseline test with the new `dfrank` model:

In [38]:
# Evaluate the exported frank model (dfrank) on the same atotc reads.
bash([
  "mkdir -p frank_test",
  "cd frank_test",
  # Output files keep the sup.* names so tsv_dict(..., "sup") reads every model alike.
  f"dorado basecaller ../dfrank {pod5} --reference {fasta} > sup.bam 2> >(tee sup.log >&2)",
  "dorado summary sup.bam > sup.tsv",
  "cd ..",
])

```bash
mkdir -p frank_test && 
cd frank_test && 
dorado basecaller ../dfrank /vol/bitbucket/bl1821/frankenstein/data/atotc.pod5 --reference /vol/bitbucket/bl1821/frankenstein/data/atotc.fasta > sup.bam 2> >(tee sup.log >&2) && 
dorado summary sup.bam > sup.tsv && 
cd ..
```

# The JE model
The frank model now outperforms the sup baseline, so we use it to generate better training data from the `je` reference.

In [32]:
# Simulate the je training reads into train_data/je/je.pod5.
bash(["mkdir -p train_data/je",
      "cd train_data",
      f"cp {je} .",
      # Same simulation settings as for frank: 100000 reads, -r 300, seed 42.
      "squigulator je.fasta -x dna-r10-prom -o je.slow5"
      " -n 100000 -r 300 --ont-friendly yes --seed 42",
      # SLOW5 -> POD5, then drop the intermediate SLOW5 file.
      "blue-crab s2p je.slow5 -o je.pod5",
      "rm -f je.slow5",
      "mv je.pod5 je/",
      "cd .."])

```bash
mkdir -p train_data/je && 
cd train_data && 
cp /vol/bitbucket/bl1821/frankenstein/data/je.fasta . && 
squigulator je.fasta -x dna-r10-prom -o je.slow5 -n 100000 -r 300 --ont-friendly yes --seed 42 && 
blue-crab s2p je.slow5 -o je.pod5 && 
rm -f je.slow5 && 
mv je.pod5 je/ && 
cd ..
```

We extract training data from the `je` reads with the `bfrank` model, this time using the default (0.99) CTC accuracy threshold:

In [33]:
# Basecall the je reads with the improved bfrank model and save CTC training
# data, keeping bonito's default accuracy cut-off.
bash(["cd train_data",
      "bonito basecaller ../bfrank je --reference je.fasta --save-ctc > je/je.bam",
      "cd .."])

```bash
cd train_data && 
bonito basecaller ../bfrank je --reference je.fasta --save-ctc > je/je.bam && 
cd ..
```

We then train the model, again starting from the original `bsup` weights:

In [35]:
# Train bje on train_data/je, starting from the original bsup weights.
bash([
  "bonito train bje"
  " --directory train_data/je"
  " --epochs 5 --chunks 200 --valid-chunks 50"
  " --pretrained bsup --batch 32"
])

```bash
bonito train bje --directory train_data/je --epochs 5 --chunks 200 --valid-chunks 50 --pretrained bsup --batch 32
```

Export:

In [36]:
# Export the trained bje checkpoint to the Dorado model directory dje/.
bash([
  "bonito export bje --output dje",
])

```bash
bonito export bje --output dje
```

We repeat the baseline test with the `dje` model:

In [39]:
# Evaluate the exported je model (dje) on the same atotc reads.
bash([
  "mkdir -p je_test",
  "cd je_test",
  # Output files keep the sup.* names so tsv_dict(..., "sup") reads every model alike.
  f"dorado basecaller ../dje {pod5} --reference {fasta} > sup.bam 2> >(tee sup.log >&2)",
  "dorado summary sup.bam > sup.tsv",
  "cd ..",
])

```bash
mkdir -p je_test && 
cd je_test && 
dorado basecaller ../dje /vol/bitbucket/bl1821/frankenstein/data/atotc.pod5 --reference /vol/bitbucket/bl1821/frankenstein/data/atotc.fasta > sup.bam 2> >(tee sup.log >&2) && 
dorado summary sup.bam > sup.tsv && 
cd ..
```

# Collection of results

In [69]:
import csv

def tsv_dict(data_dir, speed):
  """Return the first data row of ``{data_dir}/{speed}.tsv`` as a dict.

  The file is parsed as a tab-separated table whose first line is the header,
  e.g. the ``dorado summary ... > sup.tsv`` output written by the test cells.
  Only the first data row is returned; any further rows are ignored.

  Args:
    data_dir: directory containing the TSV file.
    speed: file stem of the TSV (``"sup"`` -> ``sup.tsv``).

  Returns:
    dict mapping column name -> string value for the first data row.

  Raises:
    ValueError: if the file contains no data rows (empty or header only).
  """
  path = f'{data_dir}/{speed}.tsv'
  with open(path, newline='') as f:
    reader = csv.DictReader(f, delimiter='\t')
    row = next(reader, None)
  if row is None:
    # A bare StopIteration gives no context, and it becomes RuntimeError when
    # raised inside a generator; report the offending file instead.
    raise ValueError(f"no data rows in {path}")
  return row

def make_table(data, markdown=True, metrics=None):
  """Render a markdown table with one row per metric and one column per entry of `data`.

  Args:
    data: dict mapping a column name (e.g. a model name) to a dict of metric
      values, such as the rows returned by `tsv_dict`.
    markdown: if True, render via IPython `display(Markdown(...))`; otherwise
      print the raw markdown source.
    metrics: optional iterable of metric keys to show as rows, in order.
      Defaults to the alignment_* columns listed below.

  Metrics missing from an entry are rendered as empty cells.
  """
  if metrics is None:
    metrics = [
      "alignment_accuracy", "alignment_identity",
      "alignment_genome_start", "alignment_genome_end", "alignment_strand_start", "alignment_strand_end", "alignment_strand_coverage",
      "alignment_length", "alignment_num_aligned", "alignment_num_correct", "alignment_num_deletions", "alignment_num_insertions", "alignment_num_substitutions",
    ]
  row_headers = list(metrics)

  # Build markdown table
  header_row = f"| Metric | {' | '.join(data.keys())} |"
  separator_row = f"|--------|{'|'.join(['--------'] * len(data))}|"
  data_rows = [
    f"| {key} | {' | '.join(str(data[name].get(key, '')) for name in data)} |"
    for key in row_headers
  ]

  table_md = "\n".join([header_row, separator_row] + data_rows)

  if markdown:
    display(Markdown(table_md))
  else:
    print(table_md)
    
def make_table_switched(data, markdown=True, metrics=None):
  """Render a markdown table with one row per model and one column per metric.

  Args:
    data: dict mapping a model name to a dict of metric values, such as the
      rows returned by `tsv_dict`.
    markdown: if True, render via IPython `display(Markdown(...))`; otherwise
      print the raw markdown source.
    metrics: optional iterable of metric keys to show as columns, in order.
      Defaults to the alignment_* columns listed below.

  Column headers drop a leading 'alignment_' prefix. Metrics missing from a
  model's dict are rendered as empty cells.
  """
  # Select the metrics of interest
  if metrics is None:
    metrics = [
      "alignment_accuracy", "alignment_identity",
      "alignment_strand_coverage",
      "alignment_num_correct", "alignment_num_deletions", "alignment_num_insertions", "alignment_num_substitutions",
    ]
  row_headers = list(metrics)

  # Strip only a *leading* prefix; str.replace would also remove 'alignment_'
  # occurring in the middle of a metric name.
  prefix = 'alignment_'
  col_headers = [h[len(prefix):] if h.startswith(prefix) else h for h in row_headers]

  # Build the header row: "Model | accuracy | identity | ..."
  header_row = f"| Model | {' | '.join(col_headers)} |"
  separator_row = f"|-------|{'|'.join(['--------'] * len(col_headers))}|"

  # Build each data row: model name | value1 | value2 | ...
  data_rows = [
    f"| {model} | {' | '.join(str(data[model].get(key, '')) for key in row_headers)} |"
    for model in data
  ]

  table_md = "\n".join([header_row, separator_row] + data_rows)

  if markdown:
    display(Markdown(table_md))
  else:
    print(table_md)

In [70]:
models = ["sup", "frank", "je"]
# Every <model>_test directory holds its dorado summary as sup.tsv.
data = dict(zip(models, [tsv_dict(f"{name}_test", "sup") for name in models]))
make_table_switched(data)

| Model | accuracy | identity | strand_coverage | num_correct | num_deletions | num_insertions | num_substitutions |
|-------|--------|--------|--------|--------|--------|--------|--------|
| sup | 0.910668 | 0.955673 | 0.999998 | 2877111 | 96271 | 52510 | 133448 |
| frank | 0.991914 | 0.994825 | 0.987018 | 3044752 | 7025 | 1955 | 15840 |
| je | 0.99658 | 0.999012 | 1 | 3096812 | 6957 | 607 | 3062 |

# Model profiling

In [66]:
# Total parameters

import os
import torch

def count_parameters(model_dir):
    """Return the total number of parameter elements in an exported model directory.

    Every ``*.tensor`` file in `model_dir` is loaded with ``torch.jit.load`` and
    the ``numel()`` of all its parameters is summed.

    Args:
        model_dir: directory produced by ``bonito export`` (e.g. ``dsup``).

    Returns:
        int: total number of parameter elements.
    """
    total = 0
    for name in sorted(os.listdir(model_dir)):
        if not name.endswith('.tensor'):
            continue
        # Load onto the CPU: counting elements needs no accelerator, and this
        # avoids failures if the tensors were saved from a GPU.
        module = torch.jit.load(os.path.join(model_dir, name), map_location='cpu')
        total += sum(param.numel() for param in module.parameters())
    return total


for model in models:
    print(model, count_parameters(f"d{model}"))

sup 78718162
frank 78718162
je 78718162
