# Training our foundational model

> "Let's start training!"

In [None]:
#| default_exp training.foundation

## Setup, indexing training data

In [None]:
#| export
import warnings
from collections.abc import Iterator
from pathlib import Path
from typing import Optional

import pandas as pd
import torch
from torch import Tensor
from torch.utils.data import dataset
from tqdm.auto import tqdm

from llm_mito_scanner.analysis.training import get_training_annotation_paths



In [None]:
#| hide
from yaml import safe_load

# Enable tqdm's pandas integration with an 80-column bar that is not left on screen.
tqdm.pandas(ncols=80, leave=False)

# Parse the YAML configuration stored one directory above this notebook.
with Path("../config.yml").open() as config_file:
    config = safe_load(config_file)

In [None]:
#| hide
# Every location used below is derived from the configured data root.
data_path = Path(config.get("data_path"))
training_data_path = data_path.joinpath("training")
training_index_path = data_path.joinpath("training_index.csv")
gene_to_protein_maps_path = data_path.joinpath("gene_to_protein_maps.csv")

In [None]:
#| hide
import json

# Reuse the cached CSVs when both exist; otherwise rebuild both tables and cache them.
if training_index_path.exists() and gene_to_protein_maps_path.exists():
    training_paths = pd.read_csv(training_index_path)
    # This table is written *with* its index (see the `to_csv` call below), so the
    # first CSV column must be read back as the index; otherwise the cached frame
    # gains an extra data column that the freshly built frame does not have.
    # NOTE(review): assumes a single-level index - confirm against
    # `get_training_annotation_paths`.
    gene_to_protein_maps = pd.read_csv(gene_to_protein_maps_path, index_col=0)
else:
    gene_to_protein_maps, training_paths = get_training_annotation_paths(training_data_path)
    training_paths.to_csv(training_index_path, index=False)
    gene_to_protein_maps.to_csv(gene_to_protein_maps_path)

In [None]:
#| hide
# Preview the first rows of the training index.
training_paths.head()

Unnamed: 0,annotation,gene,gene_annotation,protein_annotation
0,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
1,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
2,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
3,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
4,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...


## Build the training dataset

### Construct the tokenizer, vocabulary

In [None]:
#| export
from torchtext.vocab import build_vocab_from_iterator
from fastai.text.core import BaseTokenizer

In [None]:
#| hide
# Instantiate fastai's BaseTokenizer with its default arguments.
tokenizer = BaseTokenizer()

In [None]:
#| hide
# Unique, non-null gene annotation files followed by unique, non-null protein
# annotation files, each wrapped in a Path.
all_annotation_file_paths = list(
    map(
        Path,
        training_paths.gene_annotation.dropna().unique().tolist()
        + training_paths.protein_annotation.dropna().unique().tolist(),
    )
)
len(all_annotation_file_paths)

123138

In [None]:
#| hide
# Scratch experiment: read the first annotation file in 1024-character chunks and
# split it into space-separated tokens, carrying a partial trailing token over to
# the next chunk. Each chunk's token list is appended to `test_tokens`.
test_tokens = []


with all_annotation_file_paths[0].open('r') as f:
    leftover_text = ""
    counter = 0
    while True:
        print(f"COUNTER: {counter}")
        token_string = f.read(1024)
        # Stop at end of file, or after five processed chunks to keep the debug
        # output short (the chunk just read is discarded in that case).
        if token_string == "" or counter == 5:
            break
        print(f"LEFTOVER: {leftover_text}")
        print(f"CHUNK: {token_string}")
        # rfind returns -1 when the chunk has no separator; rindex would raise
        # ValueError instead, which made the `else` branch below unreachable.
        last_sep = token_string.rfind(" ")
        print(f"LAST SEP: {last_sep}")
        if last_sep != -1:
            # Get our chunk of uninterrupted tokens
            token_chunk = leftover_text + token_string[:last_sep]
            print(f"TOKEN CHUNK: {token_chunk}")
            # Get our list of tokenized tokens
            chunk_tokens = [tok for tok in token_chunk.split(" ") if len(tok) > 0]
            print(f"CHUNK TOKENS: {chunk_tokens}")
            # Record the leftover string not in the token string
            leftover_text = token_string[last_sep:]
            # Collect our list of tokens
            test_tokens.append(chunk_tokens)
        else:
            # No separator in this chunk: the whole chunk is part of one token
            leftover_text = leftover_text + token_string
        counter += 1
    # Flush the remainder. The leftover starts with the separator, so a plain split
    # would produce an empty-string "token"; filter those out as done per chunk.
    test_tokens.append([tok for tok in leftover_text.split(" ") if len(tok) > 0])

COUNTER: 0
LEFTOVER: 
CHUNK: [N]A [N]C [N]A [N]T [N]C [N]C [N]T [N]G [N]C [N]T [N]T [N]G [N]T [N]C [N]C [N]T [N]T [N]T [N]G [N]G [N]G [N]G [N]C [N]A [N]T [N]C [N]T [N]C [N]T [N]G [N]T [N]C [N]A [N]T [N]G [N]T [N]G [N]C [N]T [N]T [N]A [N]T [N]A [N]G [N]T [N]C [N]A [N]C [N]T [N]C [N]C [N]T [N]C [N]T [N]C [N]C [N]A [N]T [N]C [N]T [N]A [N]T [N]G [N]T [N]T [N]A [N]T [N]A [N]C [N]T [N]G [N]A [N]T [N]C [N]T [N]T [N]A [N]C [N]T [N]C [N]C [N]A [N]A [N]G [N]C [N]C [N]T [N]C [N]T [N]T [N]T [N]C [N]A [N]T [N]G [N]T [N]T [N]G [N]C [N]G [N]C [N]T [N]T [N]T [N]G [N]T [N]A [N]A [N]T [N]G [N]A [N]A [N]T [N]T [N]T [N]C [N]C [N]A [N]A [N]C [N]T [N]G [N]C [N]T [N]C [N]A [N]A [N]C [N]C [N]T [N]T [N]T [N]C [N]T [N]G [N]A [N]T [N]G [N]G [N]A [N]C [N]A [N]A [N]A [N]C [N]C [N]G [N]C [N]C [N]C [N]C [N]T [N]C [N]A [N]T [N]A [N]T [N]C [N]T [N]T [N]C [N]C [N]A [N]A [N]G [N]A [N]G [N]A [N]G [N]A [N]C [N]G [N]A [N]C [N]T [N]G [N]A [N]G [N]A [N]C [N]A [N]T [N]G [N]A [N]A [N]C [N]T [N]G [N]G [N]A [N]G [N]G [N]A [N]G [

In [None]:
# Show which file the chunked-reading experiment above was run against.
all_annotation_file_paths[0]

Path('/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000003.12/100129480/gene.txt')

In [None]:
# Total number of tokens collected across all chunk lists.
sum(map(len, test_tokens))

21307

In [None]:
#| export
def yield_file_tokens(path: Path, chunksize=1024, sep=" ") -> Iterator[list[str]]:
    """Lazily tokenize a text file, yielding one list of tokens per processed chunk.

    The file is read ``chunksize`` characters at a time. Each chunk is cut at its
    last ``sep``; everything before the cut (plus any text carried over from earlier
    chunks) is split on ``sep`` and yielded, and everything from the cut onwards is
    carried over so that a token straddling a chunk boundary is never split in two.

    Arguments:
        path: file to read (opened in text mode).
        chunksize: number of characters requested per read.
        sep: token separator; must be a non-empty string.

    Yields:
        Lists of non-empty tokens, in file order. A trailing list is yielded only
        when the text left over after the last read contains at least one token.

    Raises:
        ValueError: if ``sep`` is empty (raised when iteration starts).
    """
    if not sep:
        raise ValueError("sep must be a non-empty string")
    with path.open('r') as f:
        leftover_text = ""
        while True:
            token_string = f.read(chunksize)
            if token_string == "":
                break
            # rfind signals "no separator" with -1, so no exception is needed for
            # control flow on the per-chunk path.
            last_sep = token_string.rfind(sep)
            if last_sep == -1:
                # The whole chunk belongs to a token that is still open
                leftover_text += token_string
                continue
            # Text up to the last separator holds only complete tokens
            token_chunk = leftover_text + token_string[:last_sep]
            # Keep the separator and the partial token after it for the next round;
            # the leading separator yields an empty string that is filtered below
            leftover_text = token_string[last_sep:]
            yield [tok for tok in token_chunk.split(sep) if tok]
        chunk_tokens = [tok for tok in leftover_text.split(sep) if tok]
        if chunk_tokens:
            yield chunk_tokens


def yield_all_training_tokens(file_paths: list[Path]):
    """Return a lazy iterator with one `yield_file_tokens` generator per path.

    The paths are wrapped in a tqdm progress bar, which advances as the returned
    iterator is consumed.
    """
    paths_with_progress = tqdm(file_paths)
    return map(yield_file_tokens, paths_with_progress)

In [None]:
#| hide
# Materialise every chunk of the first file and count the tokens across chunks.
test_file_tokens = [*yield_file_tokens(all_annotation_file_paths[0])]
sum(map(len, test_file_tokens))

21224

In [None]:
#| export
from torchtext.vocab import Vocab, vocab
from collections import Counter, OrderedDict


def build_vocab(file_paths: list[Path], special_tokens: Optional[list[str]] = None) -> Vocab:
    """Build a torchtext ``Vocab`` from the tokens found in ``file_paths``.

    Tokens are counted across all files with `yield_file_tokens` and handed to
    torchtext's ``vocab`` factory ordered by descending frequency.

    Arguments:
        file_paths: files to tokenize and count.
        special_tokens: tokens passed to the factory as ``specials``;
            defaults to ``['<unk>']`` when omitted.

    Returns:
        The resulting ``Vocab``. If counting is interrupted with Ctrl-C, the vocab
        is built from the files processed so far and a ``RuntimeWarning`` is issued.
    """
    # Imported locally under an alias: this notebook later rebinds the module-level
    # name `vocab` to a built Vocab instance, which would otherwise hide the factory
    # and break any subsequent call to this function.
    from torchtext.vocab import vocab as vocab_from_ordered_dict

    # A fresh list per call avoids sharing one mutable default between calls
    if special_tokens is None:
        special_tokens = ['<unk>']
    # Count tokens in each file
    token_counter = Counter()
    try:
        for path in tqdm(file_paths):
            for tokens in yield_file_tokens(path):
                token_counter.update(tokens)
    except KeyboardInterrupt:
        # Deliberate best effort: keep what was counted, but say so, because the
        # caller may go on to cache the result as if it were complete.
        warnings.warn(
            "build_vocab was interrupted; the vocabulary only covers the files "
            "processed before the interrupt.",
            RuntimeWarning,
        )
    print(token_counter)
    ordered_counter = OrderedDict(token_counter.most_common())
    return vocab_from_ordered_dict(ordered_counter, specials=special_tokens)

In [None]:
#| hide
# Smoke-test the vocabulary builder on the first four annotation files only.
test_vocab = build_vocab(all_annotation_file_paths[:4])
test_vocab

  0%|          | 0/4 [00:00<?, ?it/s]

Counter({'[N]T': 63851, '[N]A': 54272, '[N]G': 40072, '[N]C': 38971})


Vocab()

In [None]:
#| hide
# Look up the indices assigned to two nucleotide tokens and to the special token.
test_vocab["[N]T"], test_vocab["[N]A"], test_vocab["<unk>"]

(1, 2, 0)

In [None]:
#| hide
# Save the vocabulary
artefacts_path = data_path / "artefacts"
training_artefacts_path = artefacts_path / "training"
# exist_ok=True makes this a no-op when the directory is already there, so no
# separate (race-prone) existence check is needed.
training_artefacts_path.mkdir(parents=True, exist_ok=True)

In [None]:
#| hide
import torch

# Load the cached vocabulary when present; otherwise build it from every
# annotation file and cache it for the next run.
# NOTE(review): this rebinds `vocab`, which an earlier cell imports from
# torchtext.vocab as the factory function that `build_vocab` calls.
# NOTE(review): torch.load presumably unpickles the file - only load caches
# produced by this notebook.
vocab_path = training_artefacts_path / "vocab.pt"
if vocab_path.exists():
    vocab = torch.load(vocab_path)
else:
    vocab = build_vocab(all_annotation_file_paths)
    torch.save(vocab, vocab_path)

  0%|          | 0/123138 [00:00<?, ?it/s]

In [None]:
#| hide
# Look up the indices of the four nucleotide tokens and of the special token.
vocab["[N]T"], vocab["[N]A"], vocab["[N]C"], vocab["[N]G"], vocab["<unk>"]

### Build our training, validation, and test indices

In [None]:
#| hide
# Build indices for train, validation, test

In [None]:
#| export
def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Converts raw text into a flat Tensor.

    Each item of ``raw_text_iter`` is passed through the module-level ``tokenizer``
    and then the module-level ``vocab``; the resulting id sequences are turned into
    ``torch.long`` tensors, empty ones are dropped, and the rest are concatenated.

    Returns:
        A 1-D ``torch.long`` tensor; empty when no item produced any ids.
    """
    # NOTE(review): relies on `vocab` being a built Vocab instance and `tokenizer`
    # being callable on a single item - neither is defined in an exported cell of
    # this notebook; confirm before using the exported module.
    encoded_items = (
        torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter
    )
    non_empty = [ids for ids in encoded_items if ids.numel() > 0]
    if not non_empty:
        # torch.cat rejects an empty sequence, so return an empty tensor explicitly
        return torch.empty(0, dtype=torch.long)
    return torch.cat(non_empty)

In [None]:
# ``train_iter`` was "consumed" by the process of building the vocab,
# so we have to create it again
# NOTE(review): the remark above does not match this notebook, where the vocab is
# built from annotation files by `build_vocab`; `WikiText2` is also not imported in
# any visible cell. This cell looks like unadapted tutorial code - TODO replace
# with the project's own train/validation/test data.
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def batchify(data: Tensor, bsz: int) -> Tensor:
    """Reshape a flat tensor into ``bsz`` parallel sequences laid out column-wise.

    Trailing elements that do not fill a complete row are discarded, and the result
    is moved to the module-level ``device``.

    Arguments:
        data: Tensor, shape ``[N]``
        bsz: int, batch size

    Returns:
        Tensor of shape ``[N // bsz, bsz]``
    """
    steps_per_sequence = data.size(0) // bsz
    # Keep only as many elements as fill `bsz` sequences of equal length
    usable = data[:steps_per_sequence * bsz]
    sequences = usable.view(bsz, steps_per_sequence)
    # Transpose so each column is one sequence, then make the memory contiguous
    return sequences.t().contiguous().to(device)

# Training uses 20 parallel sequences; validation and test share a batch size of 10.
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)  # shape ``[seq_len, batch_size]``
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [None]:
#| hide
# Run nbdev's export step (cells marked `#| export` are the ones intended for the module).
import nbdev; nbdev.nbdev_export()