# Training our foundational model

> "Let's start training!"

In [None]:
#| default_exp training.foundation

## Setup: indexing the training data

In [None]:
#| export
from collections.abc import Iterable, Iterator, Sequence
from pathlib import Path

import pandas as pd
import torch
from tqdm.auto import tqdm

from llm_mito_scanner.analysis.training import get_training_annotation_paths



In [None]:
#| hide
from yaml import safe_load

# Register tqdm's pandas integration with compact, non-persistent bars
tqdm.pandas(ncols=80, leave=False)

# Load project settings from ../config.yml (relative to the working directory)
config = safe_load(Path("../config.yml").read_text())

In [None]:
#| hide
# Root data directory from config.yml; every training input/output path below hangs off it.
data_path_setting = config.get("data_path")
if data_path_setting is None:
    # Path(None) would otherwise fail with an unhelpful TypeError
    raise ValueError("config.yml must define a 'data_path' entry")
data_path = Path(data_path_setting)
training_data_path = data_path / "training"
training_index_path = data_path / "training_index.csv"
gene_to_protein_maps_path = data_path / "gene_to_protein_maps.csv"

In [None]:
#| hide
import json

# Build the annotation index once and cache both tables as CSV; later runs reload
# the cached files instead of calling get_training_annotation_paths again.
if not training_index_path.exists() or not gene_to_protein_maps_path.exists():
    gene_to_protein_maps, training_paths = get_training_annotation_paths(training_data_path)
    training_paths.to_csv(training_index_path, index=False)
    # The gene->protein table is written *with* its index ...
    gene_to_protein_maps.to_csv(gene_to_protein_maps_path)
else:
    training_paths = pd.read_csv(training_index_path)
    # ... so read that first column back as the index; otherwise cached runs get a
    # spurious "Unnamed: 0" column that a freshly built table does not have.
    gene_to_protein_maps = pd.read_csv(gene_to_protein_maps_path, index_col=0)

In [None]:
#| hide
# Preview the first rows of the annotation index
training_paths.head(n=5)

Unnamed: 0,annotation,gene,gene_annotation,protein_annotation
0,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
1,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
2,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
3,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
4,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...


## Build the training dataset

### Construct the tokenizer and vocabulary

In [None]:
#| export
from torchtext.vocab import build_vocab_from_iterator
from fastai.text.core import BaseTokenizer

In [None]:
#| hide
# fastai's space-splitting tokenizer (split_char defaults to " ").
# NOTE(review): fastai's __call__ appears to take a *list* of texts and return a
# generator of token lists; confirm before applying it to a single string.
tokenizer = BaseTokenizer(split_char=" ")

In [None]:
#| hide
# Distinct gene / protein annotation files, in first-seen order, missing entries dropped
gene_file_paths = list(map(Path, training_paths["gene_annotation"].dropna().unique()))
protein_file_paths = list(map(Path, training_paths["protein_annotation"].dropna().unique()))

In [None]:
#| hide
# Gene files first, then protein files
all_annotation_file_paths = [*gene_file_paths, *protein_file_paths]

In [None]:
#| hide
# Exploratory check of chunked tokenization on the first annotation file: process at
# most 5 chunks of 1024 characters, split on spaces, and carry the partial token at
# the end of each chunk over into the next one.
test_tokens = []

with all_annotation_file_paths[0].open('r') as f:
    leftover_text = ""
    counter = 0
    # Stop *before* reading a 6th chunk instead of reading it and throwing it away
    while counter < 5:
        print(f"COUNTER: {counter}")
        token_string = f.read(1024)
        if token_string == "":
            break
        print(f"LEFTOVER: {leftover_text}")
        print(f"CHUNK: {token_string}")
        # rfind returns -1 when there is no separator; rindex would raise ValueError
        # and make the `else` branch below unreachable.
        last_sep = token_string.rfind(" ")
        print(f"LAST SEP: {last_sep}")
        if last_sep != -1:
            # Get our chunk of uninterrupted tokens
            token_chunk = leftover_text + token_string[:last_sep]
            print(f"TOKEN CHUNK: {token_chunk}")
            # Get our list of tokenized tokens
            chunk_tokens = [tok for tok in token_chunk.split(" ") if len(tok) > 0]
            print(f"CHUNK TOKENS: {chunk_tokens}")
            # Record the leftover string not in the token string
            leftover_text = token_string[last_sep:]
            test_tokens.append(chunk_tokens)
        else:
            # No separator in this chunk: it is all part of one (still unfinished) token
            leftover_text = leftover_text + token_string
        counter += 1
    # Flush the final partial token, dropping empty strings like the loop does
    test_tokens.append([tok for tok in leftover_text.split(" ") if len(tok) > 0])

COUNTER: 0
LEFTOVER: 
CHUNK: [N]A [N]C [N]A [N]T [N]C [N]C [N]T [N]G [N]C [N]T [N]T [N]G [N]T [N]C [N]C [N]T [N]T [N]T [N]G [N]G [N]G [N]G [N]C [N]A [N]T [N]C [N]T [N]C [N]T [N]G [N]T [N]C [N]A [N]T [N]G [N]T [N]G [N]C [N]T [N]T [N]A [N]T [N]A [N]G [N]T [N]C [N]A [N]C [N]T [N]C [N]C [N]T [N]C [N]T [N]C [N]C [N]A [N]T [N]C [N]T [N]A [N]T [N]G [N]T [N]T [N]A [N]T [N]A [N]C [N]T [N]G [N]A [N]T [N]C [N]T [N]T [N]A [N]C [N]T [N]C [N]C [N]A [N]A [N]G [N]C [N]C [N]T [N]C [N]T [N]T [N]T [N]C [N]A [N]T [N]G [N]T [N]T [N]G [N]C [N]G [N]C [N]T [N]T [N]T [N]G [N]T [N]A [N]A [N]T [N]G [N]A [N]A [N]T [N]T [N]T [N]C [N]C [N]A [N]A [N]C [N]T [N]G [N]C [N]T [N]C [N]A [N]A [N]C [N]C [N]T [N]T [N]T [N]C [N]T [N]G [N]A [N]T [N]G [N]G [N]A [N]C [N]A [N]A [N]A [N]C [N]C [N]G [N]C [N]C [N]C [N]C [N]T [N]C [N]A [N]T [N]A [N]T [N]C [N]T [N]T [N]C [N]C [N]A [N]A [N]G [N]A [N]G [N]A [N]G [N]A [N]C [N]G [N]A [N]C [N]T [N]G [N]A [N]G [N]A [N]C [N]A [N]T [N]G [N]A [N]A [N]C [N]T [N]G [N]G [N]A [N]G [N]G [N]A [N]G [

In [None]:
# Show which file the exploratory cell above read from
all_annotation_file_paths[0]

Path('/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000003.12/100129480/gene.txt')

In [None]:
# Total number of tokens collected by the exploratory read
sum(len(chunk) for chunk in test_tokens)

1026

In [None]:
#| export
def yield_file_tokens(path: Path, chunksize=1024, sep=" ") -> Iterator[list[str]]:
    """Stream the ``sep``-delimited tokens of a text file, chunk by chunk.

    The file is read ``chunksize`` characters at a time, so large annotation
    files are never loaded whole. Text after the last separator of a chunk is
    carried over and prepended to the next chunk, so no token is ever split.

    Args:
        path: text file to read.
        chunksize: number of characters requested per ``read`` call.
        sep: token separator (may be multi-character). Empty tokens produced by
            repeated separators are dropped.

    Yields:
        Lists of tokens in file order (a list may be empty if a chunk held only
        separators). Concatenating all lists gives every token in the file.

    Raises:
        ValueError: if ``sep`` is empty.
    """
    if not sep:
        raise ValueError("sep must be a non-empty string")
    with path.open('r') as f:
        leftover_text = ""
        while True:
            token_string = f.read(chunksize)
            if token_string == "":
                break
            last_sep = token_string.rfind(sep)
            if last_sep == -1:
                # No separator in this chunk: it all belongs to one unfinished token
                leftover_text = leftover_text + token_string
                continue
            # Everything up to the last separator is made of complete tokens
            token_chunk = leftover_text + token_string[:last_sep]
            # Keep the tail (starting at the separator) for the next chunk
            leftover_text = token_string[last_sep:]
            yield [tok for tok in token_chunk.split(sep) if len(tok) > 0]
        chunk_tokens = [tok for tok in leftover_text.split(sep) if len(tok) > 0]
        if len(chunk_tokens) > 0:
            yield chunk_tokens

In [None]:
#| hide
# Token total via the exported generator; should agree with get_sequence_length below
test_file_tokens = list(yield_file_tokens(all_annotation_file_paths[0]))
sum(map(len, test_file_tokens))

21224

In [None]:
#| export
# Token count of an annotation file (used as its "size" when ordering/batching files)
def get_sequence_length(path: Path):
    """Return how many tokens the file at *path* contains.

    Streams the file through ``yield_file_tokens``, so memory use stays bounded
    by the chunk size rather than the file size.
    """
    return sum(map(len, yield_file_tokens(path)))

In [None]:
#| hide
# Should match the token total of the chunked read above for the same file
get_sequence_length(all_annotation_file_paths[0])

21224

In [None]:
#| hide
# Token count per gene annotation file, keyed by Path
path_file_sizes = dict()

In [None]:
#| hide
# Imported here because the cell that imports Pool/os runs later in the notebook;
# without these, this cell raises NameError on a fresh kernel.
import os
from multiprocessing import Pool


def get_sequence_length_mp(path: Path):
    """Pool worker: return ``(path, token count)`` so unordered results can be matched to their file."""
    return path, get_sequence_length(path)


# Leave one core free, but always use at least one worker:
# os.cpu_count() may return None, and cpu_count() - 1 is 0 on a single-core machine.
n_workers = max(1, (os.cpu_count() or 1) - 1)
with Pool(n_workers) as p, tqdm(total=len(gene_file_paths)) as pbar:
    for result_path, result_length in p.imap_unordered(get_sequence_length_mp, gene_file_paths):
        path_file_sizes[result_path] = result_length
        pbar.update(1)

  0%|          | 0/16306 [00:00<?, ?it/s]

In [None]:
#| hide
# Every row inherits the token count of its gene file (protein rows of the same gene
# therefore share one size); rows are then ordered largest first.
training_paths_with_size = training_paths.assign(
    size=training_paths["gene_annotation"].apply(
        lambda gene_path: path_file_sizes.get(Path(gene_path))
    )
).sort_values("size", ascending=False)
training_paths_with_size.head()

Unnamed: 0,annotation,gene,gene_annotation,protein_annotation,size
75563,NC_000016,54715,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/gene.txt,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/rna-XM_017023318.3.txt,2473620
75550,NC_000016,54715,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/gene.txt,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/rna-NM_001415912.1.txt,2473620
75548,NC_000016,54715,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/gene.txt,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/rna-NM_001415910.1.txt,2473620
75549,NC_000016,54715,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/gene.txt,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/rna-NM_001415911.1.txt,2473620
75564,NC_000016,54715,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/gene.txt,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/rna-XM_017023320.3.txt,2473620


In [None]:
#| hide
# One row per distinct annotation file (gene or protein) with its gene's size, largest
# first; the `value` column holds the path as a string.
all_annotation_file_paths_sorted = (
    training_paths_with_size
    .melt(id_vars=["size"], value_vars=["gene_annotation", "protein_annotation"])
    .drop(columns="variable")
    .drop_duplicates(subset=["value"])
    .sort_values("size", ascending=False)
    .dropna(subset=["value"])
)
all_annotation_file_paths_sorted[:5]

Unnamed: 0,size,value
0,2473620,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/gene.txt
106844,2473620,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/rna-XM_017023318.3.txt
106872,2473620,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/rna-NM_001415909.1.txt
106873,2473620,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/rna-NM_001415899.1.txt
106874,2473620,/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000016.10/54715/rna-NM_001415907.1.txt


In [None]:
#| export
def make_batches(batch_min_size: int, path_df: pd.DataFrame, *, progress: bool = True) -> list[list[Path]]:
    """Greedily group file paths into batches whose total size reaches ``batch_min_size``.

    Rows are consumed in ``path_df`` order. A batch is closed as soon as its
    accumulated ``size`` reaches ``batch_min_size``; the next row then starts a
    new batch. Only the final batch may be smaller than ``batch_min_size``.

    Args:
        batch_min_size: size threshold at which a batch is closed.
        path_df: DataFrame with a ``value`` column (file path strings) and a
            ``size`` column. Missing sizes (NaN/None) count as 0.
        progress: show a tqdm progress bar (keyword-only).

    Returns:
        List of non-empty batches, each a list of ``Path``; empty when
        ``path_df`` has no rows.
    """
    batches: list[list[Path]] = []
    batch: list[Path] = []
    batch_size = 0
    pbar = tqdm(total=path_df.shape[0]) if progress else None
    try:
        for path_str, row_size in zip(path_df["value"], path_df["size"]):
            # A NaN size would poison the running total (NaN >= x is always False)
            # and silently merge every remaining row into a single batch.
            if pd.isna(row_size):
                row_size = 0
            # `batch` guard: never emit an empty batch (e.g. when batch_min_size <= 0)
            if batch and batch_size >= batch_min_size:
                batches.append(batch)
                batch = []
                batch_size = 0
            batch.append(Path(path_str))
            batch_size += row_size
            if pbar is not None:
                pbar.update(1)
    finally:
        if pbar is not None:
            pbar.close()
    if batch:
        batches.append(batch)
    return batches

In [None]:
#| hide
# Batches of annotation files totalling at least one million tokens each
file_batches = make_batches(batch_min_size=1_000_000, path_df=all_annotation_file_paths_sorted)

  0%|          | 0/123138 [00:00<?, ?it/s]

In [None]:
#| hide
# Number of batches produced by make_batches
len(file_batches)

11966

In [None]:
#| hide
# Distribution of the number of files per batch
pd.Series(list(map(len, file_batches))).describe()

count    11966.000000
mean        10.290657
std         22.495713
min          1.000000
25%          3.000000
50%          5.000000
75%         10.000000
max        801.000000
dtype: float64

In [None]:
#| export
from torchtext.vocab import Vocab
from collections import Counter, OrderedDict
from multiprocessing import Pool
import os
import numpy as np
from operator import attrgetter
from functools import reduce
from operator import add
import itertools


def build_vocab(file_paths: list[Path], special_tokens: Sequence[str] = ("<unk>",)) -> Vocab:
    """Build the fixed token vocabulary for the foundation model.

    The vocabulary is a closed, hand-enumerated token set: ``file_paths`` is
    accepted but never read. Ids follow the insertion order below, after the
    ``special_tokens`` (torchtext places specials first by default). With the
    default specials, ``<unk>`` is 0, ``[N]A``..``[N]T`` are 1..4, ``[intron]``
    and ``[exon]`` come next, then the amino-acid/codon-position tokens, then
    the rare nucleotides.

    Args:
        file_paths: currently unused.
        special_tokens: tokens placed at the start of the vocabulary.

    Returns:
        A ``torchtext.vocab.Vocab``.
    """
    # Imported under an alias: `vocab` was never imported at module level, and the
    # notebook binds a global named `vocab` to a Vocab instance, which would shadow
    # torchtext's factory function.
    from torchtext.vocab import vocab as make_vocab

    # Terms in the vocab are ordered by the order they're added to this counter
    counter = Counter()
    # Common nucleotides
    nucleotide_tag = "[N]"
    counter.update([f"{nucleotide_tag}{nucleotide}" for nucleotide in "ACGT"])
    # mRNA tags
    counter.update(['[intron]', '[exon]'])
    # Amino acids, one token per (amino acid, codon position 1..3)
    amino_acid_tag = "[A]"
    counter.update([
        f"{amino_acid_tag}-{amino_acid_val}-{codon_pos_val}"
        for amino_acid_val, codon_pos_val in itertools.product("ARNDCQEGHILKMFPSTWYVBZX", range(1, 4))
    ])
    # Rare nucleotides.
    # NOTE(review): these use the "[N]-X" form while the common nucleotides use "[N]X";
    # confirm which form the annotation files contain, otherwise they map to <unk>.
    counter.update([f"{nucleotide_tag}-{nucleotide}" for nucleotide in "URYKMSWBDHVN"])
    return make_vocab(counter, specials=list(special_tokens))

In [None]:
#| hide
# Sanity-check vocab construction (build_vocab does not read its paths argument)
test_vocab = build_vocab(file_paths=file_batches[-3:])
test_vocab

Vocab()

In [None]:
#| hide
# Spot-check token ids; '<unk>' is the first special token, so it maps to 0
test_vocab["[N]T"], test_vocab["[N]A"], test_vocab["<unk>"], test_vocab["[A]-M-1"]

(4, 1, 0, 43)

In [None]:
#| hide
# Directory for training artefacts such as the saved vocabulary
artefacts_path = data_path / "artefacts"
training_artefacts_path = artefacts_path / "training"
# exist_ok replaces the check-then-create pattern (no race between exists() and mkdir())
training_artefacts_path.mkdir(parents=True, exist_ok=True)

In [None]:
#| hide
import torch

# Build the vocabulary once and cache it; subsequent runs reload the cached copy.
# NOTE(review): the cache is never invalidated, so delete vocab.pt after changing build_vocab.
# NOTE(review): this global `vocab` shadows torchtext's `vocab` factory function.
# NOTE(review): newer torch releases default torch.load to weights_only=True, which may
# reject a pickled Vocab; pass weights_only=False (file is produced locally) if upgrading.
vocab_path = training_artefacts_path / "vocab.pt"
if vocab_path.exists():
    vocab = torch.load(vocab_path)
else:
    vocab = build_vocab(file_batches)
    torch.save(vocab, vocab_path)

In [None]:
#| hide
# Ids of the four common nucleotides and the unknown token in the saved/loaded vocab
vocab["[N]T"], vocab["[N]A"], vocab["[N]C"], vocab["[N]G"], vocab["<unk>"]

(4, 1, 2, 3, 0)

### Build our training, validation, and test indices

In [None]:
#| hide
# Build indices for train, validation, test

In [None]:
#| export
def data_process(raw_text_iter: Iterable[str], token_vocab=None, tokenize=None) -> torch.Tensor:
    """Convert an iterable of raw text into one flat tensor of token ids.

    Args:
        raw_text_iter: iterable of strings, one text/record per item.
        token_vocab: callable mapping a list of token strings to a list of ids
            (e.g. a ``torchtext.vocab.Vocab``). Defaults to a global named
            ``vocab`` in the defining namespace, as the original code used.
        tokenize: callable splitting one string into a list of tokens.
            Defaults to ``str.split`` (whitespace), matching the space-separated
            annotation files. fastai's ``BaseTokenizer`` cannot be used here
            directly because it expects a list of texts, not one string.

    Returns:
        1-D ``torch.long`` tensor of all ids in order. Items yielding no tokens
        are skipped; an empty tensor is returned when no item yields tokens.

    Raises:
        ValueError: if no vocabulary is passed and no global ``vocab`` exists.
    """
    if token_vocab is None:
        token_vocab = globals().get("vocab")
        if token_vocab is None:
            raise ValueError("data_process needs a vocabulary: pass token_vocab=...")
    if tokenize is None:
        tokenize = str.split
    tensors = [torch.tensor(token_vocab(tokenize(item)), dtype=torch.long) for item in raw_text_iter]
    tensors = [t for t in tensors if t.numel() > 0]
    if not tensors:
        # torch.cat raises on an empty sequence
        return torch.empty(0, dtype=torch.long)
    return torch.cat(tensors)

In [None]:
# NOTE(review): this cell comes from the PyTorch "Language Modeling with nn.Transformer
# and torchtext" tutorial and still loads WikiText2 rather than this project's
# annotation files; swap in the project data before relying on it.
from torchtext.datasets import WikiText2

# Fresh iterators for each split. (In the tutorial the train iterator had already been
# consumed while building the vocab; here build_vocab does not read any data.)
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def batchify(data: torch.Tensor, bsz: int) -> torch.Tensor:
    """Divides the data into ``bsz`` separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Arguments:
        data: Tensor, shape ``[N]``
        bsz: int, batch size (must be positive)

    Returns:
        Tensor of shape ``[N // bsz, bsz]`` on ``device``

    Raises:
        ValueError: if ``bsz`` is not positive.
    """
    if bsz <= 0:
        raise ValueError(f"bsz must be positive, got {bsz}")
    seq_len = data.size(0) // bsz
    # Drop the tail that does not fill a whole column
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)


batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)  # shape ``[seq_len, batch_size]``
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [None]:
#| hide
import nbdev

# Write the notebook's `#| export` cells out to the library module
nbdev.nbdev_export()