In [1]:
#| default_exp training.transcription.index

In [134]:
#| export
from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch
from torch import Tensor
from torch.utils.data import Dataset, IterableDataset
from torchtext.vocab import vocab, Vocab
from collections import Counter, OrderedDict
import operator

tqdm.pandas()

from llm_mito_scanner.data.download import load_config, \
    get_latest_assembly_path, get_genomic_genbank_path

random_state = 42

In [3]:
#| hide
config = load_config()

In [4]:
#| hide
data_path = Path(config.get("data_path"))
data_raw_path = data_path / "raw"
assemblies_path = data_raw_path / "assemblies"
latest_assembly_path = get_latest_assembly_path(assemblies_path)
genomic_genbank_path = get_genomic_genbank_path(latest_assembly_path)
chromosomes_path = latest_assembly_path / "chromosomes"
training_data_path = latest_assembly_path / "training"
transcription_data_path = training_data_path / "transcription"
sequences_data_path = transcription_data_path / "sequences"
if not sequences_data_path.exists():
    raise Exception(f"This notebook expects the path at {sequences_data_path.resolve()}")

In [55]:
#| export
def make_training_index(index_dir: Path, sequences_path: Path, sample: bool = False, save: bool = False):
    training_data_index_frames = []
    sequences = list(sequences_path.glob("*.csv"))
    if sample:
        sequences = sequences[-2:]
    for f in tqdm(sequences):
        f_frame = pd.read_csv(f, usecols=["geneid"]).reset_index(drop=False).rename({"index": "file_index"}, axis=1)
        f_frame.loc[:, 'file'] = f.stem
        training_data_index_frames.append(f_frame)
    training_data_index = pd.concat(
        training_data_index_frames, 
        axis=0, ignore_index=True
    ).reset_index(drop=True)
    if save:
        training_data_index.to_csv(index_dir / "index.csv", index=False)
    return training_data_index


def get_training_index(index_dir: Path, sequences_path: Path = None, **make_kwargs):
    index_path = index_dir / "index.csv"
    if not index_path.exists():
        return make_training_index(index_dir, sequences_path, **make_kwargs)
    else:
        return pd.read_csv(index_path)

In [56]:
#| hide
training_data_index = get_training_index(transcription_data_path, sequences_data_path, save=False, sample=True)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:03<?, ?it/s]


ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [11]:
#| hide
training_data_index.head()

Unnamed: 0,file_index,geneid,file
0,0,GeneID:55344,NC_000023.11
1,1,GeneID:55344,NC_000023.11
2,2,GeneID:55344,NC_000023.11
3,3,GeneID:55344,NC_000023.11
4,4,GeneID:55344,NC_000023.11


In [13]:
#| export
def make_train_test_split(index: pd.DataFrame, random_state = 42):
    return train_test_split(index, random_state=random_state)

In [14]:
#| hide
train_idx, test_idx = make_train_test_split(training_data_index)

In [17]:
#| hide
train_idx.shape, test_idx.shape

((3496, 3), (1166, 3))

In [18]:
#| hide
train_idx.head()

Unnamed: 0,file_index,geneid,file
3403,3403,GeneID:159090,NC_000023.11
1815,1815,GeneID:340554,NC_000023.11
2459,2459,GeneID:2491,NC_000023.11
3066,3066,GeneID:23157,NC_000023.11
4424,128,GeneID:64591,NC_000024.10


In [52]:
#| export
def get_sequence(sequences_path: Path, file: str, idx: int) -> pd.Series:
    row = pd.read_csv(sequences_path / f"{file}.csv", header=None, skiprows=idx, nrows=1)
    return row

In [44]:
#| hide
example_sequence = get_sequence(sequences_data_path, "NC_000023.11", 3403)

In [46]:
#| hide
example_sequence.iloc[0, 1]

'A,G,G,A,C,C,C,G,A,T,G,G,G,T,G,C,C,C,G,G,A,C,G,C,G,G,A,A,G,A,A,C,T,G,G,C,C,C,A,G,C,G,G,A,G,G,T,T,C,C,C,G,C,T,T,C,T,G,A,A,G,C,G,T,G,G,G,A,G,G,C,G,G,A,A,G,A,G,A,C,T,G,C,A,G,G,T,T,G,T,A,G,A,T,T,T,T,G,G,T,C,T,T,G,G,C,C,C,C,A,A,C,G,G,G,T,T,C,C,C,C,C,A,G,A,C,C,A,T,T,A,G,A,A,C,C,C,G,A,G,A,G,A,G,G,A,A,G,A,G,G,A,G,G,C,A,C,C,G,C,G,G,C,G,G,C,G,G,C,T,G,A,G,A,G,G,C,C,G,A,A,C,C,C,C,C,A,A,C,C,T,G,T,C,C,T,C,G,A,G,G,C,G,G,G,A,G,G,G,G,G,G,C,A,G,G,G,A,G,G,G,G,A,G,A,A,G,C,A,C,C,G,A,T,C,C,C,A,G,G,G,C,T,G,A,G,G,A,G,C,C,T,C,C,G,G,G,A,C,C,G,G,G,T,C,C,G,G,G,A,G,C,A,T,G,G,G,G,G,G,C,T,G,G,A,C,C,G,G,G,T,C,C,C,C,G,A,G,C,A,T,G,G,G,G,A,G,C,T,G,G,A,G,C,C,G,C,C,A,A,G,G,C,C,G,C,C,C,G,G,C,C,G,G,G,T,C,C,T,A,C,C,C,G,A,C,T,G,T,C,A,G,G,T,C,C,T,C,G,G,T,G,C,T,A,G,C,G,C,T,G,C,C,C,G,G,G,C,A,G,C,C,G,C,A,G,C,A,T,C,T,G,T,A,G,C,C,C,G,G,C,C,C,C,G,C,T,T,G,T,G,G,T,C,C,T,C,A,G,C,C,G,G,A,A,A,G,A,T,A,G,C,A,G,C,A,T,G,A,G,T,T,C,C,G,C,C,G,C,T,G,C,C,T,A,G,T,T,T,C,C,T,C,C,C,C,C,T,C,C,T,C,T,T,C,T,C,C,C,T,C,C,T,C,T,C,T,C,C,T,C,T,T,C,C,C,A,G,A,C

In [60]:
#| export
class TranscriptionDataset(Dataset):
    def __init__(self, training_path: Path, sequences_path: Path, train: bool):
        self.training_path = training_path
        self.sequences_path = sequences_path
        self.train = train
        self.training_index = get_training_index(self.training_path, self.sequences_path, sample=False, save=True)
        self.train_idx, self.test_idx = make_train_test_split(self.training_index)

    def __len__(self) -> int:
        if self.train:
            return self.train_idx.shape[0]
        else:
            return self.test_idx.shape[0]

    def __getitem__(self, idx) -> tuple[str, str]:
        if self.train:
            sequence_row = self.train_idx.iloc[idx, :]
        else:
            sequence_row = self.test_idx.iloc[idx, :]
        sequence_file_stem = sequence_row.file
        sequence_file_idx = sequence_row.file_index
        sequence = get_sequence(sequences_data_path, sequence_file_stem, sequence_file_idx)
        sequence_input = sequence.iloc[0, 1]
        sequence_target = sequence.iloc[0, 2]
        return sequence_input, sequence_target

In [63]:
#| hide
transcription_dataset_train = TranscriptionDataset(transcription_data_path, sequences_data_path, True)
transcription_dataset_test = TranscriptionDataset(transcription_data_path, sequences_data_path, False)

In [64]:
#| hide
len(transcription_dataset_train), len(transcription_dataset_test)

(97815, 32605)

In [67]:
#| hide
for i, tup in enumerate(transcription_dataset_train):
    if i == 5:
        break
    print(tup)

('G,T,G,C,G,G,A,G,T,T,T,G,G,C,T,G,C,T,C,C,G,G,G,G,T,T,A,G,C,A,G,G,T,G,A,G,C,C,T,G,C,G,A,T,G,C,G,C,G,G,G,A,A,G,A,C,G,T,T,C,C,G,C,T,T,T,G,A,A,A,T,G,C,A,G,C,G,G,G,A,T,T,T,G,G,T,G,A,G,T,T,T,C,C,C,G,C,T,G,T,C,T,C,C,A,G,C,G,G,T,G,C,G,G,G,T,G,A,A,G,C,T,G,G,T,G,T,C,T,G,C,G,G,G,G,T,T,C,C,A,G,A,C,T,G,C,T,G,A,G,G,A,A,C,T,C,C,T,A,G,A,G,G,T,G,A,A,A,C,C,C,T,C,C,G,A,G,C,T,T,A,G,C,A,A,A,G,G,T,A,A,C,G,A,C,T,C,C,T,G,A,T,G,G,C,A,A,G,C,T,G,A,G,G,C,A,C,A,C,C,G,G,C,C,G,C,C,G,T,C,A,G,C,G,C,C,G,C,C,T,C,A,G,T,C,T,T,C,G,T,T,C,T,C,T,C,G,C,C,T,C,G,G,C,C,T,T,C,A,G,C,C,C,A,G,T,C,T,C,C,G,T,T,A,G,A,T,T,C,T,G,C,T,T,C,C,T,C,C,C,A,C,G,T,C,C,A,T,G,T,T,T,A,C,A,G,C,G,T,G,A,A,A,G,A,G,C,T,C,C,T,C,G,A,C,T,C,C,A,C,T,T,A,C,A,A,G,T,T,G,T,C,T,G,A,A,T,G,G,T,T,A,G,G,A,G,A,A,C,T,G,T,G,G,T,C,G,T,G,A,A,A,A,C,A,T,T,T,A,C,T,A,A,T,T,G,C,T,T,T,T,C,C,T,C,T,G,G,C,A,A,T,G,C,C,T,G,C,T,G,A,A,T,G,C,T,T,T,G,A,G,G,A,T,T,G,T,C,T,C,A,T,T,T,A,A,C,C,C,T,C,A,A,C,C,C,G,C,T,T,A,C,G,T,A,G,A,T,T,A,T,T,A,T,A,T,T,C,A,G,T,A,T,A,T,G,G,A,T,G,A,G,A,G,A,A,C,T,G,

ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [135]:
#| export
def tokenize(seq: str) -> list[str]:
    return seq.split(",")

In [136]:
#| hide
sample_train_data = transcription_dataset_train[0]

In [138]:
#| hide
len(tokenize(sample_train_data[0])), len(tokenize(sample_train_data[1]))

(43010, 43010)

In [73]:
#| hide
gene_path = latest_assembly_path / "genes"
gene_path.exists()

True

In [129]:
#| export
def count_transcription_tokens(genes_path: Path, pbar: bool = False) -> OrderedDict:
    token_counter = Counter()
    gene_files = list(genes_path.glob("*.csv"))
    if pbar:
        gene_files = tqdm(gene_files, leave=False, position=0)
    for f in gene_files:
        f_sequences = pd.read_csv(f, usecols=['sequence']).sequence
        if pbar:
            f_sequences = tqdm(f_sequences, leave=False, position=1)
        f_counter = sum(map(Counter, f_sequences), Counter())
        token_counter = token_counter + f_counter
    token_ordered_dict = OrderedDict(token_counter.most_common())
    return token_ordered_dict


def build_vocab(
        genes_path: Path, 
        pbar: bool = False, 
        intron_token: str = "<intron>", unknown_token: str = "<unk>", 
        default_index: int = -1):
    token_ordered_dict = count_transcription_tokens(genes_path=genes_path, pbar=pbar)
    transcription_vocab = vocab(token_ordered_dict, specials=[intron_token, unknown_token])
    transcription_vocab.set_default_index(default_index)
    return transcription_vocab


def get_vocab(
        training_path: Path,
        **build_vocab_kwargs
        ) -> Vocab:
    vocab_path = training_path / "vocab.pt"
    if not vocab_path.exists():
        transcription_vocab = build_vocab(**build_vocab_kwargs)
        torch.save(transcription_vocab, vocab_path)
    else:
        transcription_vocab = torch.load(vocab_path)
    return transcription_vocab

In [132]:
#| hide
transcription_vocab = get_vocab(transcription_data_path, pbar=True, genes_path=gene_path)

In [133]:
#| hide
len(transcription_vocab)

13

In [146]:
#| export
def process_training_sequence(sequence: tuple[str, str], data_vocab: Vocab) -> tuple[Tensor, Tensor]:
    """Converts raw text into a flat Tensor."""
    data = (
        torch.tensor(data_vocab(tokenize(sequence[0])), dtype=torch.long),
        torch.tensor(data_vocab(tokenize(sequence[1])), dtype=torch.long),
    )
    return data

In [150]:
#| hide
processed_example_sequence = process_training_sequence(sample_train_data, transcription_vocab)
processed_example_sequence[0], processed_example_sequence[1]

(tensor([4, 2, 4,  ..., 5, 2, 5]), tensor([4, 2, 4,  ..., 5, 2, 5]))

In [151]:
#| hide
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
def batchify(sequence_tensors: tuple[Tensor, Tensor], bsz: int) -> Tensor:
    """Divides the data into ``bsz`` separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Arguments:
        data: Tensor, shape ``[N]``
        bsz: int, batch size

    Returns:
        Tensor of shape ``[N // bsz, bsz]``
    """
    seq_len = sequence_tensors[0].size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)