In [1]:
#| default_exp training.transcription.index

In [2]:
#| export
import os
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchtext.vocab import vocab, Vocab
from collections import Counter, OrderedDict
from multiprocessing import Pool

tqdm.pandas()

from llm_mito_scanner.data.download import load_config, \
    get_latest_assembly_path, get_genomic_genbank_path

random_state = 42



In [3]:
#| hide
config = load_config()

In [4]:
#| hide
# Resolve the directory layout this notebook works with, rooted at the
# ``data_path`` entry of the loaded config.
data_path = Path(config.get("data_path"))
data_raw_path = data_path / "raw"
assemblies_path = data_raw_path / "assemblies"
latest_assembly_path = get_latest_assembly_path(assemblies_path)
genomic_genbank_path = get_genomic_genbank_path(latest_assembly_path)
chromosomes_path = latest_assembly_path / "chromosomes"
training_data_path = latest_assembly_path / "training"
transcription_data_path = training_data_path / "transcription"
sequences_data_path = transcription_data_path / "sequences"
# Fail fast with a specific exception type: every later cell reads from this directory.
if not sequences_data_path.exists():
    raise FileNotFoundError(f"This notebook expects the path at {sequences_data_path.resolve()}")

In [5]:
#| hide
chromosome_parquet_files = list(sequences_data_path.glob("chromosome=*/partition=*/*.parquet"))
chromosome_parquet_files[:3]

[PosixPath('/mnt/e/Data/llm-mito-scanner-data/data/raw/assemblies/GCF_000001405.40_GRCh38.p14/training/transcription/sequences/chromosome=NC_000001.11/partition=0/sequences.parquet'),
 PosixPath('/mnt/e/Data/llm-mito-scanner-data/data/raw/assemblies/GCF_000001405.40_GRCh38.p14/training/transcription/sequences/chromosome=NC_000001.11/partition=1/sequences.parquet'),
 PosixPath('/mnt/e/Data/llm-mito-scanner-data/data/raw/assemblies/GCF_000001405.40_GRCh38.p14/training/transcription/sequences/chromosome=NC_000001.11/partition=10/sequences.parquet')]

In [6]:
#| hide
parquet_file_df = pd.DataFrame(chromosome_parquet_files, columns=['path'])

In [7]:
#| hide
# Break each path into its components with pathlib instead of splitting the
# string on "/", so the parsing does not depend on the platform's separator.
parquet_file_df_path_split = parquet_file_df.path.apply(lambda p: Path(p).parts)

# The two directories above each file are named ``chromosome=<name>`` and
# ``partition=<int>``; keep only the value after the "=".
parquet_file_df = (
    parquet_file_df
    .assign(
        chromosome=parquet_file_df_path_split.apply(lambda parts: parts[-3].split("=")[-1]),
        partition=parquet_file_df_path_split.apply(lambda parts: int(parts[-2].split("=")[-1])),
    )
    .sort_values(['chromosome', 'partition'])
    .reset_index(drop=True)
)
parquet_file_df.head()


Unnamed: 0,path,chromosome,partition
0,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000001.11,0
1,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000001.11,1
2,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000001.11,2
3,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000001.11,3
4,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000001.11,4


In [8]:
#| export
def index_training_sequence_files(sequences_path: Path) -> pd.DataFrame:
    """Index the partitioned sequence parquet files found under ``sequences_path``.

    Files are matched with the glob ``chromosome=*/partition=*/*.parquet``. The
    chromosome and partition values are taken from the names of the two
    directories above each file (the text after the last ``=``).

    Arguments:
        sequences_path: Path, directory containing the ``chromosome=*`` folders

    Returns:
        DataFrame with columns ``path`` (Path), ``chromosome`` (str) and
        ``partition`` (int), sorted by chromosome then partition, with a fresh
        ``0..n-1`` index. The frame is empty (but keeps all three columns) when
        no file matches.
    """
    # Index files
    parquet_files = list(sequences_path.glob("chromosome=*/partition=*/*.parquet"))
    # Extract partition info from the directory names; using pathlib rather than
    # splitting on "/" keeps this independent of the platform's path separator.
    parquet_file_df = pd.DataFrame({
        "path": parquet_files,
        "chromosome": [f.parent.parent.name.split("=")[-1] for f in parquet_files],
        "partition": [int(f.parent.name.split("=")[-1]) for f in parquet_files],
    })
    # Sort (partition is numeric, so 2 sorts before 10)
    return (
        parquet_file_df
        .sort_values(['chromosome', 'partition'], ascending=True)
        .reset_index(drop=True)
    )

In [9]:
#| export
def make_training_index(index_dir: Path, sequences_path: Path, sample: bool = False, save: bool = False):
    """Build a row-level index over every sequence parquet file.

    Each row of the result points at one record of one parquet file.

    Arguments:
        index_dir: Path, directory where ``index.csv`` is written when ``save`` is set
        sequences_path: Path, directory handed to ``index_training_sequence_files``
        sample: bool, only index the last two files of the sorted file listing
        save: bool, write the index to ``index_dir / "index.csv"``

    Returns:
        DataFrame with columns ``file_index``, ``geneid``, ``transcriptid``,
        ``file``, ``chromosome`` and ``partition``, sorted by chromosome and
        partition.

    Raises:
        ValueError: when there is no sequence file to index.
    """
    sequence_files = index_training_sequence_files(sequences_path)
    if sample:
        sequence_files = sequence_files.tail(2)
    frames = []
    for row in sequence_files.itertuples(index=False):
        f_frame = pd.read_parquet(row.path, columns=["geneid", 'transcriptid']).reset_index(drop=True)
        # ``file_index`` is consumed positionally (``get_sequence`` uses ``.iloc``),
        # so record the row position rather than whatever index the file stored.
        f_frame.insert(0, "file_index", range(len(f_frame)))
        f_frame["file"] = row.path
        f_frame["chromosome"] = row.chromosome
        f_frame["partition"] = row.partition
        frames.append(f_frame)
    if not frames:
        # pd.concat would raise a less helpful ValueError on an empty list.
        raise ValueError(f"No sequence parquet files found under {sequences_path}")
    training_data_index = (
        pd.concat(frames, axis=0, ignore_index=True)
        .sort_values(['chromosome', 'partition'], ascending=True)
        .reset_index(drop=True)
    )
    if save:
        training_data_index.to_csv(index_dir / "index.csv", index=False)
    return training_data_index


def get_training_index(index_dir: Path, sequences_path: Path = None, force: bool = False, **make_kwargs):
    """Load the cached training index, building it when missing or when forced.

    Arguments:
        index_dir: Path, directory that holds (or will hold) ``index.csv``
        sequences_path: Path, sequence directory; only needed when the index
            has to be built
        force: bool, rebuild even if ``index.csv`` already exists
        **make_kwargs: forwarded to ``make_training_index`` (ignored when the
            cached file is used)

    Returns:
        DataFrame, the training index.

    Raises:
        ValueError: when the index has to be built but ``sequences_path`` was
            not provided.
    """
    index_path = index_dir / "index.csv"
    if index_path.exists() and not force:
        return pd.read_csv(index_path)
    if sequences_path is None:
        # Fail here with a clear message rather than deep inside the builder.
        raise ValueError(
            f"sequences_path is required to build the training index at {index_path}"
        )
    return make_training_index(index_dir, sequences_path, **make_kwargs)

In [10]:
#| hide
example_training_data_index = get_training_index(transcription_data_path, sequences_data_path, force=True, save=False, sample=True)

In [11]:
#| hide
example_training_data_index.head()

Unnamed: 0,file_index,geneid,transcriptid,file,chromosome,partition
0,0,GeneID:55796,NM_001386899.1,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000023.11,7
1,1,GeneID:55796,NM_001386889.1,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000023.11,7
2,2,GeneID:55796,NM_001386891.1,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000023.11,7
3,3,GeneID:55796,NM_001386896.1,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000023.11,7
4,4,GeneID:55796,NM_018388.4,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000023.11,7


In [12]:
#| export
def make_train_test_split(index: pd.DataFrame, random_state = 42):
    """Split the training index into train and test partitions.

    Delegates to ``train_test_split`` with its default split proportions; only
    the seed is set, so the same index always yields the same split.

    Arguments:
        index: DataFrame, the training index to split
        random_state: seed forwarded to ``train_test_split``

    Returns:
        Whatever ``train_test_split`` returns for a single input: the train
        frame followed by the test frame.
    """
    split_options = {"random_state": random_state}
    return train_test_split(index, **split_options)

In [13]:
#| hide
train_idx, test_idx = make_train_test_split(example_training_data_index)

In [14]:
#| hide
train_idx.shape, test_idx.shape

((983, 6), (328, 6))

In [15]:
#| hide
train_idx.head()

Unnamed: 0,file_index,geneid,transcriptid,file,chromosome,partition
1117,172,GeneID:7404,XM_011531441.4,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000024.10,0
708,708,GeneID:6748,XM_047442389.1,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000023.11,7
536,536,GeneID:266740,NM_001321403.1,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000023.11,7
808,808,GeneID:6901,NM_001303465.2,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000023.11,7
682,682,GeneID:139728,NM_001366976.1,/mnt/e/Data/llm-mito-scanner-data/data/raw/ass...,NC_000023.11,7


In [16]:
#| export
def get_sequence(file: Path, idx: int) -> pd.Series:
    """Fetch a single record from a sequence parquet file by position.

    The whole file is read before the row is selected.

    Arguments:
        file: Path, parquet file to read
        idx: int, positional row number within that file

    Returns:
        Series holding the selected row.
    """
    file_frame = pd.read_parquet(file)
    return file_frame.iloc[idx, :]

In [17]:
#| hide
example_sequence = get_sequence(train_idx.iloc[0, -3], 2)

In [18]:
#| hide
example_sequence.input.count(",")

27000

In [19]:
#| export
class TranscriptionDataset(Dataset):
    """Map-style dataset yielding ``(input, target)`` sequence strings.

    The training index is loaded (or built and saved) on construction and split
    into train/test partitions; ``train`` selects which partition the instance
    serves. Every ``__getitem__`` call reads the record from its parquet file
    through ``get_sequence``.
    """

    def __init__(self, training_path: Path, sequences_path: Path, train: bool):
        """
        Arguments:
            training_path: Path, directory holding (or receiving) ``index.csv``
            sequences_path: Path, directory with the partitioned parquet files
            train: bool, serve the train partition when true, else the test one
        """
        self.training_path = training_path
        self.sequences_path = sequences_path
        self.train = train
        # Keep the unfiltered index so a chromosome filter can be changed later
        # instead of permanently discarding the other chromosomes.
        self._full_training_index = get_training_index(self.training_path, self.sequences_path, sample=False, save=True)
        self._set_training_index(self._full_training_index)

    def _set_training_index(self, index: pd.DataFrame) -> None:
        """Install ``index`` as the active index and recompute the split."""
        self.training_index = index
        self.train_idx, self.test_idx = make_train_test_split(self.training_index)

    @property
    def _active_index(self) -> pd.DataFrame:
        """The partition this instance serves (train or test)."""
        return self.train_idx if self.train else self.test_idx

    def filter_chromosome(self, chromosome: str):
        """Restrict the dataset to one chromosome and re-split.

        The filter is always applied to the full index, so it can be called
        again with a different chromosome.

        Raises:
            ValueError: when ``chromosome`` does not occur in the index; the
                current filter is left untouched in that case.
        """
        full_index = self._full_training_index
        selected = full_index.chromosome == chromosome
        if not selected.any():
            raise ValueError(f"Chromosome {chromosome} not found in training data.")
        self._set_training_index(full_index[selected])

    def __len__(self) -> int:
        return self._active_index.shape[0]

    def __getitem__(self, idx) -> tuple[str, str]:
        sequence_row = self._active_index.iloc[idx, :]
        sequence = get_sequence(sequence_row.file, sequence_row.file_index)
        return sequence.input, sequence.target

In [20]:
#| hide
transcription_dataset_train = TranscriptionDataset(transcription_data_path, sequences_data_path, True)
transcription_dataset_test = TranscriptionDataset(transcription_data_path, sequences_data_path, False)

In [21]:
#| hide
len(transcription_dataset_train), len(transcription_dataset_test)

(97815, 32605)

In [23]:
#| hide
transcription_dataset_train.filter_chromosome("NC_000001.11")
transcription_dataset_test.filter_chromosome("NC_000001.11")
len(transcription_dataset_train), len(transcription_dataset_test)

(9630, 3211)

In [37]:
#| hide
for i, tup in enumerate(transcription_dataset_train):
    if i == 2:
        break
    print(tup[0].count(","))

13415
122052


In [38]:
#| export
def tokenize(seq: str) -> list[str]:
    """Break a comma-delimited sequence string into its individual tokens."""
    delimiter = ","
    tokens = seq.split(delimiter)
    return tokens

In [39]:
#| hide
sample_train_data = transcription_dataset_train[0]

In [40]:
#| hide
len(tokenize(sample_train_data[0])), len(tokenize(sample_train_data[1]))

(13416, 13416)

In [41]:
#| hide
gene_path = latest_assembly_path / "genes"
gene_path.exists()

True

In [42]:
#| hide
Counter("C,D".split(","))

Counter({'C': 1, 'D': 1})

In [43]:
#| export
def count_transcription_tokens(parquet_path: Path) -> Counter:
    """Count how often each token occurs in one sequence parquet file.

    Tokens from both the ``input`` and the ``target`` column are counted
    together.

    Arguments:
        parquet_path: Path, parquet file with ``input`` and ``target`` columns
            of comma-delimited sequence strings

    Returns:
        Counter mapping token -> number of occurrences (empty for a file with
        no rows).
    """
    sequences = pd.read_parquet(parquet_path, columns=['input', 'target'])
    token_counter = Counter()
    # Update a single counter in place instead of summing one Counter per row,
    # which would rebuild the running total for every sequence.
    for column in ('input', 'target'):
        for seq in sequences[column]:
            token_counter.update(tokenize(seq))
    return token_counter


# Token the vocabulary falls back to for anything it does not contain
# (build_vocab registers its index as the default index).
UNK_TOKEN = "<unk>"
# NOTE(review): presumably the padding / begin-of-sequence / end-of-sequence
# markers; they are only used through SPECIAL_TOKENS in this section.
PAD_TOKEN = "<pad>"
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"
# Default ``specials`` for build_vocab, which places them ahead of the
# counted tokens (special_first=True) in this order.
SPECIAL_TOKENS = [
    UNK_TOKEN,
    PAD_TOKEN,
    BOS_TOKEN,
    EOS_TOKEN
]


def build_vocab(
        parquet_files: list[Path],
        special_tokens: list[str] = SPECIAL_TOKENS,
        unknown_token: str = UNK_TOKEN):
    """Build a token vocabulary from the given sequence parquet files.

    Token counts are gathered per file in a worker pool
    (``count_transcription_tokens``) and merged; the vocabulary is ordered by
    descending frequency with the special tokens first.

    Arguments:
        parquet_files: list[Path], files to count tokens in; an empty list
            yields a vocabulary holding only the special tokens
        special_tokens: list[str], tokens placed at the start of the vocabulary
        unknown_token: str, token whose index becomes the default index for
            lookups of tokens missing from the vocabulary

    Returns:
        The torchtext vocabulary with its default index set.
    """
    counter = Counter()
    if parquet_files:
        # os.cpu_count() may return None, and "count - 1" is 0 on a single-core
        # machine; Pool needs at least one worker.
        available_cpus = os.cpu_count() or 1
        processes = max(1, min(8, available_cpus - 1, len(parquet_files)))
        # Both context managers clean up on success and on error: the pool is
        # terminated and the progress bar closed.
        with Pool(processes=processes) as pool, \
                tqdm(total=len(parquet_files), leave=False) as pbar:
            for file_counter in pool.imap_unordered(count_transcription_tokens, parquet_files):
                counter.update(file_counter)
                pbar.update(1)
    token_ordered_dict = OrderedDict(counter.most_common())
    transcription_vocab = vocab(token_ordered_dict, specials=special_tokens, special_first=True)
    unk_index = transcription_vocab[unknown_token]
    transcription_vocab.set_default_index(unk_index)
    return transcription_vocab


def get_vocab(
        training_path: Path,
        force_build: bool = False,
        save: bool = True,
        **build_vocab_kwargs
        ) -> Vocab:
    """Return the transcription vocabulary, from disk when possible.

    Arguments:
        training_path: Path, directory holding (or receiving) ``vocab.pt``
        force_build: bool, rebuild even when ``vocab.pt`` already exists
        save: bool, persist a freshly built vocabulary to ``vocab.pt``
        **build_vocab_kwargs: forwarded to ``build_vocab`` when building

    Returns:
        The loaded or newly built vocabulary.
    """
    vocab_path = training_path / "vocab.pt"
    use_cached = vocab_path.exists() and not force_build
    if use_cached:
        return torch.load(vocab_path)
    built_vocab = build_vocab(**build_vocab_kwargs)
    if save:
        torch.save(built_vocab, vocab_path)
    return built_vocab

In [44]:
#| hide
example_transcription_vocab = get_vocab(
    transcription_data_path, 
    force_build=True, 
    save=False,
    parquet_files=chromosome_parquet_files[-2:]
)

                                             

In [45]:
#| hide
len(example_transcription_vocab)

9

In [46]:
#| hide
example_transcription_vocab.lookup_tokens(list(range(0, len(example_transcription_vocab))))

['<unk>', '<pad>', '<bos>', '<eos>', '<intron>', 'T', 'A', 'G', 'C']

In [47]:
#| export
def process_training_sequence(sequence: tuple[str, str], data_vocab: Vocab) -> tuple[Tensor, Tensor]:
    """Encode an ``(input, target)`` pair of sequence strings as index tensors.

    Each string is tokenized, mapped through ``data_vocab`` and returned as a
    flat ``torch.long`` tensor.
    """
    def encode(text: str) -> Tensor:
        token_ids = data_vocab(tokenize(text))
        return torch.tensor(token_ids, dtype=torch.long)

    return encode(sequence[0]), encode(sequence[1])

In [48]:
#| hide
processed_example_sequence = process_training_sequence(sample_train_data, example_transcription_vocab)
processed_example_sequence[0], processed_example_sequence[1]

(tensor([7, 5, 7,  ..., 7, 6, 6]), tensor([7, 5, 7,  ..., 7, 6, 6]))

In [49]:
#| hide
processed_example_sequence[0].shape

torch.Size([13416])

In [50]:
#| hide
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [51]:
#| export
def batchify_sequence(sequence: Tensor, bsz: int, device: "torch.device | None" = None) -> Tensor:
    """Reshape a flat tensor into ``bsz`` parallel columns.

    Trailing elements that do not fill a complete row are dropped.

    Arguments:
        sequence: Tensor, shape ``[N]``
        bsz: int, batch size (number of columns)
        device: target device; when omitted, a module-level ``device`` is used
            if one is defined, otherwise CUDA when available and CPU if not

    Returns:
        Tensor of shape ``[N // bsz, bsz]`` on the selected device.
    """
    if device is None:
        # The exported module defines no ``device`` global (it only exists in a
        # hidden notebook cell), so fall back to auto-detection when absent.
        device = globals().get("device")
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seq_len = sequence.size(0) // bsz
    data = sequence[:seq_len * bsz]
    # view -> [bsz, seq_len]; transpose so each column is one contiguous chunk.
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)


def batchify(sequence_tensors: tuple[Tensor, Tensor], bsz: int) -> tuple[Tensor, Tensor]:
    """Batchify an ``(input, target)`` pair of flat tensors.

    Each tensor is passed through ``batchify_sequence`` with the same batch
    size.

    Arguments:
        sequence_tensors: tuple of two Tensors, the input and target sequences
        bsz: int, batch size

    Returns:
        Tuple ``(input_batches, target_batches)`` holding the two results of
        ``batchify_sequence``.
    """
    input_tensor, target_tensor = sequence_tensors[0], sequence_tensors[1]
    return batchify_sequence(input_tensor, bsz), batchify_sequence(target_tensor, bsz)

In [52]:
#| hide
batch_size = 100

In [53]:
#| export
def get_batch(input: Tensor, target: Tensor, i: int, bptt: int = 35, device: "torch.device | None" = None) -> tuple[Tensor, Tensor]:
    """Slice one aligned window out of the input and target tensors.

    Arguments:
        input: Tensor, indexed along its first dimension
        target: Tensor, sliced with the same bounds as ``input``
        i: int, start offset of the window
        bptt: int, maximum window length
        device: target device; when omitted, a module-level ``device`` is used
            if one is defined, otherwise CUDA when available and CPU if not

    Returns:
        Tuple ``(data, target)`` of the two slices, moved to the device.
    """
    if device is None:
        # The exported module defines no ``device`` global (it only exists in a
        # hidden notebook cell), so fall back to auto-detection when absent.
        device = globals().get("device")
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): the ``- 1`` means the final element is never returned; that
    # is needed when the target is shifted by one, but here both slices use the
    # same bounds - confirm this is intended.
    seq_len = min(bptt, len(input) - 1 - i)
    data = input[i:i+seq_len].to(device)
    target = target[i:i+seq_len].to(device)
    return data, target

In [54]:
#| hide
bptt = 35
eval_batch_size = 10

processed_example_sequence_input, processed_example_sequence_target = processed_example_sequence
# Number of full bptt-sized windows in the example sequence.
sequences_batch_num = (processed_example_sequence_input.shape[0] - 1) // bptt
# Step over token offsets, not over the window count: ranging up to
# ``sequences_batch_num`` would only ever visit the first few hundred tokens.
for batch, i in enumerate(range(0, processed_example_sequence_input.shape[0] - 1, bptt)):
    # Get Batch
    data, target = get_batch(processed_example_sequence_input, processed_example_sequence_target, i, bptt)
    # Only the first window is needed for the demonstration below.
    break

In [55]:
#| hide
print(sequences_batch_num)
data, target, data.shape, target.shape

383


(tensor([7, 5, 7, 8, 7, 7, 6, 7, 5, 5, 5, 7, 7, 8, 5, 7, 8, 5, 8, 8, 7, 7, 7, 7,
         5, 5, 6, 7, 8, 6, 7, 7, 5, 7, 6], device='cuda:0'),
 tensor([7, 5, 7, 8, 7, 7, 6, 7, 5, 5, 5, 7, 7, 8, 5, 7, 8, 5, 8, 8, 7, 7, 7, 7,
         5, 5, 6, 7, 8, 6, 7, 7, 5, 7, 6], device='cuda:0'),
 torch.Size([35]),
 torch.Size([35]))

In [24]:
#| hide
import nbdev; nbdev.nbdev_export()