In [2]:
#| default_exp training.transcription.index

In [3]:
#| export
from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch
from torch import Tensor
from torch.utils.data import Dataset, IterableDataset
from torchtext.vocab import vocab, Vocab
from collections import Counter, OrderedDict
import operator

tqdm.pandas()

from llm_mito_scanner.data.download import load_config, \
    get_latest_assembly_path, get_genomic_genbank_path

random_state = 42



In [4]:
#| hide
config = load_config()

In [5]:
#| hide
data_path = Path(config.get("data_path"))
data_raw_path = data_path / "raw"
assemblies_path = data_raw_path / "assemblies"
latest_assembly_path = get_latest_assembly_path(assemblies_path)
genomic_genbank_path = get_genomic_genbank_path(latest_assembly_path)
chromosomes_path = latest_assembly_path / "chromosomes"
training_data_path = latest_assembly_path / "training"
transcription_data_path = training_data_path / "transcription"
sequences_data_path = transcription_data_path / "sequences"
if not sequences_data_path.exists():
    raise Exception(f"This notebook expects the path at {sequences_data_path.resolve()}")

In [6]:
#| export
def make_training_index(index_dir: Path, sequences_path: Path, sample: bool = False, save: bool = False):
    training_data_index_frames = []
    sequences = list(sequences_path.glob("*.csv"))
    if sample:
        sequences = sequences[-2:]
    for f in tqdm(sequences):
        f_frame = pd.read_csv(f, usecols=["geneid"]).reset_index(drop=False).rename({"index": "file_index"}, axis=1)
        f_frame.loc[:, 'file'] = f.stem
        training_data_index_frames.append(f_frame)
    training_data_index = pd.concat(
        training_data_index_frames, 
        axis=0, ignore_index=True
    ).reset_index(drop=True)
    if save:
        training_data_index.to_csv(index_dir / "index.csv", index=False)
    return training_data_index


def get_training_index(index_dir: Path, sequences_path: Path = None, **make_kwargs):
    index_path = index_dir / "index.csv"
    if not index_path.exists():
        return make_training_index(index_dir, sequences_path, **make_kwargs)
    else:
        return pd.read_csv(index_path)

In [7]:
#| hide
training_data_index = get_training_index(transcription_data_path, sequences_data_path, save=False, sample=True)

In [8]:
#| hide
training_data_index.head()

Unnamed: 0,file_index,geneid,file
0,0,GeneID:79501,NC_000001.11
1,1,GeneID:112268260,NC_000001.11
2,2,GeneID:729759,NC_000001.11
3,3,GeneID:105378947,NC_000001.11
4,4,GeneID:81399,NC_000001.11


In [9]:
#| export
def make_train_test_split(index: pd.DataFrame, random_state = 42):
    return train_test_split(index, random_state=random_state)

In [10]:
#| hide
train_idx, test_idx = make_train_test_split(training_data_index)

In [11]:
#| hide
train_idx.shape, test_idx.shape

((97815, 3), (32605, 3))

In [12]:
#| hide
train_idx.head()

Unnamed: 0,file_index,geneid,file
106188,5022,GeneID:5889,NC_000017.11
124266,1229,GeneID:9814,NC_000022.11
91738,62,GeneID:81614,NC_000015.10
41879,118,GeneID:63027,NC_000006.12
108927,432,GeneID:2774,NC_000018.10


In [13]:
#| export
def get_sequence(sequences_path: Path, file: str, idx: int) -> pd.Series:
    row = pd.read_csv(sequences_path / f"{file}.csv", header=None, skiprows=idx, nrows=1)
    return row

In [14]:
#| hide
example_sequence = get_sequence(sequences_data_path, "NC_000023.11", 3403)

In [20]:
#| hide
example_sequence.iloc[0, 1].count(",")

27639

In [16]:
#| export
class TranscriptionDataset(Dataset):
    def __init__(self, training_path: Path, sequences_path: Path, train: bool):
        self.training_path = training_path
        self.sequences_path = sequences_path
        self.train = train
        self.training_index = get_training_index(self.training_path, self.sequences_path, sample=False, save=True)
        self.train_idx, self.test_idx = make_train_test_split(self.training_index)

    def __len__(self) -> int:
        if self.train:
            return self.train_idx.shape[0]
        else:
            return self.test_idx.shape[0]

    def __getitem__(self, idx) -> tuple[str, str]:
        if self.train:
            sequence_row = self.train_idx.iloc[idx, :]
        else:
            sequence_row = self.test_idx.iloc[idx, :]
        sequence_file_stem = sequence_row.file
        sequence_file_idx = sequence_row.file_index
        sequence = get_sequence(sequences_data_path, sequence_file_stem, sequence_file_idx)
        sequence_input = sequence.iloc[0, 1]
        sequence_target = sequence.iloc[0, 2]
        return sequence_input, sequence_target

In [17]:
#| hide
transcription_dataset_train = TranscriptionDataset(transcription_data_path, sequences_data_path, True)
transcription_dataset_test = TranscriptionDataset(transcription_data_path, sequences_data_path, False)

In [18]:
#| hide
len(transcription_dataset_train), len(transcription_dataset_test)

(97815, 32605)

In [21]:
#| hide
for i, tup in enumerate(transcription_dataset_train):
    if i == 2:
        break
    print(tup[0].count(","))

43009
122052


In [22]:
#| export
def tokenize(seq: str) -> list[str]:
    return seq.split(",")

In [23]:
#| hide
sample_train_data = transcription_dataset_train[0]

In [24]:
#| hide
len(tokenize(sample_train_data[0])), len(tokenize(sample_train_data[1]))

(43010, 43010)

In [25]:
#| hide
gene_path = latest_assembly_path / "genes"
gene_path.exists()

True

In [26]:
#| export
def count_transcription_tokens(genes_path: Path, pbar: bool = False) -> OrderedDict:
    token_counter = Counter()
    gene_files = list(genes_path.glob("*.csv"))
    if pbar:
        gene_files = tqdm(gene_files, leave=False, position=0)
    for f in gene_files:
        f_sequences = pd.read_csv(f, usecols=['sequence']).sequence
        if pbar:
            f_sequences = tqdm(f_sequences, leave=False, position=1)
        f_counter = sum(map(Counter, f_sequences), Counter())
        token_counter = token_counter + f_counter
    token_ordered_dict = OrderedDict(token_counter.most_common())
    return token_ordered_dict


def build_vocab(
        genes_path: Path, 
        pbar: bool = False, 
        intron_token: str = "<intron>", unknown_token: str = "<unk>", 
        default_index: int = -1):
    token_ordered_dict = count_transcription_tokens(genes_path=genes_path, pbar=pbar)
    transcription_vocab = vocab(token_ordered_dict, specials=[intron_token, unknown_token])
    transcription_vocab.set_default_index(default_index)
    return transcription_vocab


def get_vocab(
        training_path: Path,
        **build_vocab_kwargs
        ) -> Vocab:
    vocab_path = training_path / "vocab.pt"
    if not vocab_path.exists():
        transcription_vocab = build_vocab(**build_vocab_kwargs)
        torch.save(transcription_vocab, vocab_path)
    else:
        transcription_vocab = torch.load(vocab_path)
    return transcription_vocab

In [27]:
#| hide
transcription_vocab = get_vocab(transcription_data_path, pbar=True, genes_path=gene_path)

In [28]:
#| hide
len(transcription_vocab)

13

In [31]:
#| export
def process_training_sequence(sequence: tuple[str, str], data_vocab: Vocab) -> tuple[Tensor, Tensor]:
    """Converts raw text into a flat Tensor."""
    input_tensor = torch.tensor(data_vocab(tokenize(sequence[0])), dtype=torch.long)
    target_tensor = torch.tensor(data_vocab(tokenize(sequence[1])), dtype=torch.long)
    return input_tensor, target_tensor

In [32]:
#| hide
processed_example_sequence = process_training_sequence(sample_train_data, transcription_vocab)
processed_example_sequence[0], processed_example_sequence[1]

(tensor([4, 2, 4,  ..., 5, 2, 5]), tensor([4, 2, 4,  ..., 5, 2, 5]))

In [53]:
#| hide
processed_example_sequence[0].shape

torch.Size([43010])

In [33]:
#| hide
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [37]:
#| export
def batchify_sequence(sequence: Tensor, bsz: int) -> Tensor:
    global device
    seq_len = sequence.size(0) // bsz
    data = sequence[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)


def batchify(sequence_tensors: tuple[Tensor, Tensor], bsz: int) -> tuple[Tensor, Tensor]:
    """Divides the data into ``bsz`` separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Arguments:
        data: Tensor, shape ``[N]``
        bsz: int, batch size

    Returns:
        Tensor of shape ``[N // bsz, bsz]``
    """
    input_batches = batchify_sequence(sequence_tensors[0], bsz)
    target_batches = batchify_sequence(sequence_tensors[1], bsz)
    return input_batches, target_batches

In [45]:
#| hide
batch_size = 100

In [46]:
#| hide
example_sequence_batchified = batchify(processed_example_sequence, batch_size)
example_sequence_batchified[0], example_sequence_batchified[1]

(tensor([[4, 4, 3,  ..., 5, 5, 2],
         [2, 4, 3,  ..., 3, 2, 3],
         [4, 3, 4,  ..., 5, 4, 3],
         ...,
         [2, 5, 3,  ..., 5, 2, 4],
         [4, 2, 3,  ..., 3, 4, 3],
         [3, 4, 2,  ..., 5, 3, 2]], device='cuda:0'),
 tensor([[4, 0, 0,  ..., 5, 5, 2],
         [2, 0, 0,  ..., 3, 2, 3],
         [4, 0, 0,  ..., 5, 4, 3],
         ...,
         [0, 0, 0,  ..., 5, 2, 4],
         [0, 0, 0,  ..., 3, 4, 3],
         [0, 0, 0,  ..., 5, 3, 2]], device='cuda:0'))

In [47]:
#| hide
example_sequence_batchified[0].shape

torch.Size([430, 100])

In [57]:
#| export
def get_batch(input: Tensor, target: Tensor, i: int, bptt: int = 35) -> tuple[Tensor, Tensor]:
    global device
    seq_len = min(bptt, len(input) - 1 - i)
    data = input[i:i+seq_len].to(device)
    target = target[i:i+seq_len].to(device)
    return data, target

In [58]:
#| hide
bptt = 35
eval_batch_size = 10

processed_example_sequence_input, processed_example_sequence_target = processed_example_sequence
sequences_batch_num = (processed_example_sequence_input.shape[0] - 1) // bptt
for batch, i in enumerate(range(0, sequences_batch_num, bptt)):
    # Get Batch
    data, target = get_batch(processed_example_sequence_input, processed_example_sequence_target, i, bptt)
    break

In [60]:
#| hide
print(sequences_batch_num)
data, target, data.shape, target.shape

1228


(tensor([4, 2, 4, 5, 4, 4, 3, 4, 2, 2, 2, 4, 4, 5, 2, 4, 5, 2, 5, 5, 4, 4, 4, 4,
         2, 2, 3, 4, 5, 3, 4, 4, 2, 4, 3], device='cuda:0'),
 tensor([4, 2, 4, 5, 4, 4, 3, 4, 2, 2, 2, 4, 4, 5, 2, 4, 5, 2, 5, 5, 4, 4, 4, 4,
         2, 2, 3, 4, 5, 3, 4, 4, 2, 4, 3], device='cuda:0'),
 torch.Size([35]),
 torch.Size([35]))

In [61]:
#| hide
import nbdev; nbdev.nbdev_export()