# Training our foundational model

> "Let's start training!"

In [None]:
#| default_exp training.foundation

## Setup, indexing training data

In [None]:
#| export
import re
from collections.abc import Iterable
from pathlib import Path

import pandas as pd
import torch
from torch import Tensor
from tqdm.auto import tqdm

from llm_mito_scanner.analysis.training import get_training_annotation_paths



In [None]:
#| hide
from yaml import safe_load

# Project settings (e.g. ``data_path``) are read from ../config.yml,
# resolved relative to the current working directory.
with open("../config.yml") as config_file:
    config = safe_load(config_file)

# Configure the progress bars used by tqdm's pandas integration.
tqdm.pandas(ncols=80, leave=False)

In [None]:
#| hide
# Root data directory from the config; everything else hangs off it.
data_path = Path(config.get("data_path"))
training_data_path = data_path.joinpath("training")
training_index_path = data_path.joinpath("training_index.csv")
gene_to_protein_maps_path = data_path.joinpath("gene_to_protein_maps.csv")

In [None]:
#| hide
import json

# Build the index from the training directory only when either cached CSV is
# missing; otherwise reuse the cached CSVs.
if not training_index_path.exists() or not gene_to_protein_maps_path.exists():
    gene_to_protein_maps, training_paths = get_training_annotation_paths(training_data_path)
    training_paths.to_csv(training_index_path, index=False)
    # Written WITH its index (pandas default), so it must be read back with
    # ``index_col=0`` below.
    gene_to_protein_maps.to_csv(gene_to_protein_maps_path)
else:
    training_paths = pd.read_csv(training_index_path)
    # Restore the saved index as the index. Without ``index_col=0`` the cached
    # frame gains a spurious "Unnamed: 0" column and no longer matches a
    # freshly built one.
    gene_to_protein_maps = pd.read_csv(gene_to_protein_maps_path, index_col=0)

In [None]:
#| hide
# Preview the first rows of the training index.
training_paths.head(5)

Unnamed: 0,annotation,gene,gene_annotation,protein_annotation
0,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
1,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
2,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
3,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...
4,NC_000003,100129480,/mnt/e/Data/llm-mito-scanner-data/data/trainin...,/mnt/e/Data/llm-mito-scanner-data/data/trainin...


## Build the training dataset

### Construct the tokenizer and vocabulary

In [None]:
#| export
from torchtext.vocab import build_vocab_from_iterator
from fastai.text.core import BaseTokenizer

In [None]:
#| hide
# fastai's basic tokenizer, used by `data_process` further down.
# NOTE(review): this cell is marked `#| hide` rather than `#| export`, yet the
# exported `data_process` refers to the global `tokenizer`; the exported module
# will not define it. Also confirm the call convention: `data_process` passes a
# single string, while BaseTokenizer.__call__ may expect a list of texts.
tokenizer = BaseTokenizer()

In [None]:
#| export
def yield_file_tokens(path: Path, chunksize=1024, sep=r"\s"):
    """Lazily yield the non-empty, separator-delimited tokens of a text file.

    The file is read ``chunksize`` characters at a time, so large files are
    tokenized without being loaded into memory in one piece. The unsplit tail
    of every chunk is carried into the next read. As a result, a token that
    straddles a chunk boundary is never cut in two. For single-character
    separator patterns (such as the default), the output does not depend on
    ``chunksize``.

    Args:
        path: Text file to tokenize.
        chunksize: Number of characters requested per ``read`` call.
        sep: Regular-expression pattern matching a separator. The default
            ``r"\\s"`` splits on any whitespace character. A plain string such
            as ``","`` works too, provided it has no regex metacharacters. The
            pattern should not contain capturing groups, because ``re.split``
            would then also return the separators themselves.

    Yields:
        Each non-empty token, in file order. Runs of adjacent separators do
        not produce empty tokens, and an empty file yields nothing.
    """
    # `sep` must be treated as a regex. Searching for the default "\s"
    # literally (backslash + "s") never matches whitespace, so the whole file
    # would come back as a single token.
    splitter = re.compile(sep)
    carry = ''  # text after the last separator seen so far (may be a partial token)
    with path.open('r') as f:
        while chunk := f.read(chunksize):  # '' means end of file
            # Every piece except the last is known to end at a separator; the
            # last one may continue in the next chunk, so hold it back.
            *complete, carry = splitter.split(carry + chunk)
            yield from filter(None, complete)
    if carry:
        yield carry


def yield_all_training_tokens(file_paths: list[Path]):
    """Return a lazy iterator holding one token stream per file in ``file_paths``.

    The paths are wrapped in a tqdm progress bar, which advances each time the
    returned iterator moves on to the next file.
    """
    progress = tqdm(file_paths)
    return map(yield_file_tokens, progress)

In [None]:
# Sanity check: length of the first token read from the first gene annotation file.
test_file_tokens = yield_file_tokens(
    Path(training_paths["gene_annotation"].iloc[0])
)
len(next(test_file_tokens))

106119

In [None]:
#| hide
# Every distinct annotation file path: gene annotations first, then protein
# annotations (missing entries dropped, first-seen order kept).
all_annotation_file_paths = [
    Path(p)
    for column in ("gene_annotation", "protein_annotation")
    for p in training_paths[column].dropna().unique().tolist()
]
len(all_annotation_file_paths)

123138

In [None]:
all_annotation_file_paths[26424:26426]  # peek at two consecutive entries

[Path('/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000004.12/254251/rna-NM_001394446.1.txt'),
 Path('/mnt/e/Data/llm-mito-scanner-data/data/training/NC_000004.12/254251/rna-NM_153686.8.txt')]

In [None]:
# One lazy token stream per annotation file; building the vocab consumes it.
train_iter = yield_all_training_tokens(all_annotation_file_paths)
vocab = build_vocab_from_iterator(train_iter, specials=['<unk>'])

# Make '<unk>' the fallback index for out-of-vocabulary lookups.
vocab.set_default_index(vocab['<unk>'])

  0%|          | 0/123138 [00:00<?, ?it/s]

In [None]:
#| hide
# Display the vocabulary object as this cell's output.
vocab

### Build the training, validation, and test indices

In [None]:
#| hide
# Build indices for train, validation, test

In [None]:
#| export
def data_process(raw_text_iter: Iterable[str]) -> Tensor:
    """Convert raw text items into one flat tensor of vocabulary indices.

    Each item is tokenized with the global ``tokenizer`` and mapped to indices
    with the global ``vocab``. Items that produce no tokens are skipped, and
    the remaining index tensors are concatenated in order.

    Args:
        raw_text_iter: Iterable of raw text items.

    Returns:
        A 1-D ``torch.long`` tensor. It is empty (``numel() == 0``) when no item
        produces any token, including when ``raw_text_iter`` is empty.
    """
    # NOTE(review): `vocab` and `tokenizer` are globals created in hidden
    # (non-exported) notebook cells; the exported module does not define them.
    pieces = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    non_empty = tuple(t for t in pieces if t.numel() > 0)
    if not non_empty:
        # torch.cat raises on an empty sequence of tensors.
        return torch.empty(0, dtype=torch.long)
    return torch.cat(non_empty)

In [None]:
# ``train_iter`` was "consumed" by the process of building the vocab,
# so we have to create it again
# NOTE(review): `WikiText2` is never imported in this notebook, so the next
# line raises NameError as written. It also looks like a leftover from the
# PyTorch language-modelling tutorial: the vocabulary above was built from the
# annotation files, not from WikiText-2. TODO: replace it with train/val/test
# iterators over the annotation files once the split is implemented.
train_iter, val_iter, test_iter = WikiText2()
# Flatten each split into a single tensor of token indices.
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

# Use the GPU when available; `batchify` moves its output to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def batchify(data: Tensor, bsz: int) -> Tensor:
    """Lay a flat token stream out as ``bsz`` parallel columns.

    Trailing elements that would not fill a complete row are dropped. The
    result is transposed so that each column is one contiguous slice of
    ``data``, and it is moved to the global ``device``.

    Arguments:
        data: Tensor, shape ``[N]``
        bsz: int, number of columns (batch size)

    Returns:
        Tensor of shape ``[N // bsz, bsz]``
    """
    rows = data.size(0) // bsz      # elements per column after trimming
    usable = rows * bsz             # drop the ragged tail
    columns = data[:usable].view(bsz, rows)
    return columns.t().contiguous().to(device)

batch_size, eval_batch_size = 20, 10  # training vs. evaluation batch sizes

train_data = batchify(train_data, bsz=batch_size)  # shape ``[seq_len, batch_size]``
# Validation and test splits share the evaluation batch size.
val_data = batchify(val_data, bsz=eval_batch_size)
test_data = batchify(test_data, bsz=eval_batch_size)

In [None]:
#| hide
import nbdev

# Write this notebook's `#| export` cells out to the library module.
nbdev.nbdev_export()