In [1]:
#| default_exp evaluation.transcription

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
from pathlib import Path
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn
import torch
from torchtext.vocab import Vocab
from tqdm import tqdm
import math

from llm_mito_scanner.data.download import load_config, get_latest_assembly_path
from llm_mito_scanner.training.transcription.generation import get_genes, get_mrna
from llm_mito_scanner.training.transcription.train import translate, Seq2SeqTransformer,\
    get_vocab, get_text_transform, set_vocab_idx

In [4]:
#| hide
config = load_config()

In [5]:
#| hide
data_path = Path(config.get("data_path"))
data_raw_path = data_path / "raw"
assemblies_path = data_raw_path / "assemblies"
latest_assembly_path = get_latest_assembly_path(assemblies_path)
chromosomes_path = latest_assembly_path / "chromosomes"
training_data_path = latest_assembly_path / "training"
transcription_data_path = training_data_path / "transcription"
sequences_data_path = transcription_data_path / "sequences"

random_state = 42
epochs = 10

In [6]:
#| hide
training_index = pd.read_csv(transcription_data_path / "training_sequence_idx.csv")
batch_idx = []
for i, f in enumerate(np.array_split(training_index, epochs)):
    training_index.loc[f.index.values.tolist(), 'epoch'] = int(i)
training_index.head()

Unnamed: 0,chromosome,geneid,transcriptid,start,end,type,epoch
0,NC_000019.10,GeneID:946,XM_011527538.4,57,121,intron-small,0.0
1,NC_000006.12,GeneID:124901227,XM_047419609.1,0,64,intron-small,0.0
2,NC_000015.10,GeneID:124903566,XM_047433427.1,7772,7836,intron-small,0.0
3,NC_000002.12,GeneID:375318,NM_198998.3,139,203,intron-small,0.0
4,NC_000006.12,GeneID:55173,XM_017010996.2,3550,3614,intron-small,0.0


In [7]:
#| hide
if isinstance(training_index, pd.DataFrame):
    training_index = [training_index[training_index.epoch == e] for e in training_index.epoch.unique()]
training_index[0].head()

Unnamed: 0,chromosome,geneid,transcriptid,start,end,type,epoch
0,NC_000019.10,GeneID:946,XM_011527538.4,57,121,intron-small,0.0
1,NC_000006.12,GeneID:124901227,XM_047419609.1,0,64,intron-small,0.0
2,NC_000015.10,GeneID:124903566,XM_047433427.1,7772,7836,intron-small,0.0
3,NC_000002.12,GeneID:375318,NM_198998.3,139,203,intron-small,0.0
4,NC_000006.12,GeneID:55173,XM_017010996.2,3550,3614,intron-small,0.0


In [8]:
#| hide
training_data_batches = list(sequences_data_path.glob("epoch-*/batch-*"))
len(training_data_batches)

102

In [9]:
#| hide
training_instance_num = 0
for b in training_data_batches:
    training_instance_num += pd.read_parquet(b).shape[0]
training_instance_num

1000000

In [10]:
#| hide
train_test_indices = [train_test_split(f, random_state=42) for f in training_index]

In [11]:
#| hide
frames = []
for f_train, f_test in train_test_indices:
    f_train.loc[:, 'mode'] = 'train'
    f_test.loc[:, 'mode'] = 'test'
    frames.append(f_train)
    frames.append(f_test)
training_index_labeled = pd.concat(frames, axis=0, ignore_index=True)
training_index_labeled.head()

Unnamed: 0,chromosome,geneid,transcriptid,start,end,type,epoch,mode
0,NC_000014.9,GeneID:29091,NM_001394410.1,216557,216621,intron,0.0,train
1,NC_000012.12,GeneID:54477,NM_001385955.1,101679,101743,intron,0.0,train
2,NC_000001.11,GeneID:51127,XM_017001419.2,7504,7568,intron,0.0,train
3,NC_000001.11,GeneID:9877,NM_001376366.1,34413,34477,intron,0.0,train
4,NC_000017.11,GeneID:5636,NM_001243940.1,49878,49942,intron,0.0,train


In [12]:
#| export
def get_latest_checkpoint(checkpoint_path: Path) -> dict:
    checkpoints = pd.DataFrame(list(checkpoint_path.glob("epoch-*-model.pt")), columns=['path'])
    checkpoints.loc[:, 'checkpoint'] = checkpoints.path.apply(lambda p: p.stem.split("-")[1]).astype(int)
    checkpoints.sort_values("checkpoint", inplace=True, ascending=True)
    latest_checkpoint_path = checkpoints.iloc[-1, 0]
    checkpoint = torch.load(latest_checkpoint_path)
    return checkpoint


def get_latest_model(
        vocab: Vocab, 
        state_dict: dict,
        encoder_layers: int = 1, 
        decoder_layers: int = 1,
        embedding_size: int = 32,
        nheads: int = 4,
        feed_forward_dim: int = 32,
        device: str = torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> Seq2SeqTransformer:
    SRC_VOCAB_SIZE = TGT_VOCAB_SIZE = len(vocab)
    # Define model
    model = Seq2SeqTransformer(
        encoder_layers, decoder_layers, embedding_size,
        nheads, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, feed_forward_dim)
    model.load_state_dict(state_dict)
    model.to(device)
    return model

In [13]:
#| hide
vocab = get_vocab(transcription_data_path)
len(vocab)

13

In [14]:
#| hide
# Set properties of vocab
get_text_transform(vocab)
set_vocab_idx(vocab)

In [15]:
#| hide
latest_checkpoint = get_latest_checkpoint(transcription_data_path / "checkpoints")
latest_model_state = latest_checkpoint.get("model_state_dict")
list(latest_model_state.keys())

['transformer.encoder.layers.0.self_attn.in_proj_weight',
 'transformer.encoder.layers.0.self_attn.in_proj_bias',
 'transformer.encoder.layers.0.self_attn.out_proj.weight',
 'transformer.encoder.layers.0.self_attn.out_proj.bias',
 'transformer.encoder.layers.0.linear1.weight',
 'transformer.encoder.layers.0.linear1.bias',
 'transformer.encoder.layers.0.linear2.weight',
 'transformer.encoder.layers.0.linear2.bias',
 'transformer.encoder.layers.0.norm1.weight',
 'transformer.encoder.layers.0.norm1.bias',
 'transformer.encoder.layers.0.norm2.weight',
 'transformer.encoder.layers.0.norm2.bias',
 'transformer.encoder.norm.weight',
 'transformer.encoder.norm.bias',
 'transformer.decoder.layers.0.self_attn.in_proj_weight',
 'transformer.decoder.layers.0.self_attn.in_proj_bias',
 'transformer.decoder.layers.0.self_attn.out_proj.weight',
 'transformer.decoder.layers.0.self_attn.out_proj.bias',
 'transformer.decoder.layers.0.multihead_attn.in_proj_weight',
 'transformer.decoder.layers.0.multihea

In [16]:
#| hide
model = get_latest_model(vocab, latest_model_state)
type(model)

llm_mito_scanner.training.transcription.train.Seq2SeqTransformer

In [17]:
#| hide
example_geneid = "GeneID:10000"
example_gene_sequence = get_genes(latest_assembly_path, gene_ids=[example_geneid], limit=1).iloc[0, :]
example_gene_sequence_input = list(example_gene_sequence.sequence)
len(example_gene_sequence_input)

362847

In [18]:
#| hide
example_mrna_sequences = get_mrna(latest_assembly_path, chromosome=example_gene_sequence.chromosome, gene_ids=[example_geneid], limit=5)
example_mrna_sequences

Unnamed: 0,chromosome,geneid,transcriptid,sequence,start,end
0,NC_000001.11,GeneID:10000,NM_001206729.2,"<null>,<null>,<null>,<null>,<null>,<null>,<nul...",7495,351356
1,NC_000001.11,GeneID:10000,NM_001370074.1,"A,U,U,G,G,G,C,A,C,C,G,C,C,C,A,C,U,U,C,G,U,G,G,...",0,351356
2,NC_000001.11,GeneID:10000,NM_005465.7,"<null>,<null>,<null>,<null>,<null>,<null>,<nul...",836,351356
3,NC_000001.11,GeneID:10000,NM_181690.2,"<null>,<null>,<null>,<null>,<null>,<null>,<nul...",7797,362847
4,NC_000001.11,GeneID:10000,XM_011544014.3,"<null>,<null>,<null>,<null>,<null>,<null>,<nul...",284069,351356


In [40]:
#| export
def transcribe(
        gene: list[str], 
        model: Seq2SeqTransformer, 
        vocab: Vocab, 
        length: int = 64, 
        pbar_position: int = 0) -> str:
    num_batches = max(1, (len(gene) // length) + 1)
    transcribed_tokens = []
    batch_pbar = tqdm(total=num_batches, 
        ncols=80, 
        leave=False, 
        miniters=5, 
        desc="Translating", 
        position=pbar_position)
    batch_range = range(0, num_batches)
    gene_length = len(gene)
    for i in batch_range:
        start = i * length
        end = min(start + length, gene_length)
        gene_batch = gene[start: end]
        transcribed_gene_batch = translate(model=model, vocab=vocab, src_sentence=gene_batch)
        transcribed_gene_batch = transcribed_gene_batch.split(",")
        transcribed_gene_batch = [s for s in transcribed_gene_batch if len(s) > 0]
        transcribed_tokens.extend(transcribed_gene_batch)
        batch_pbar.update(1)
    batch_pbar.close()
    return ",".join(transcribed_tokens)

In [41]:
#| hide
example_sequence = list("GACTTTTTTGTTGCAACCTCTTAGGTTAAAAGTTTCACTATCATTTGAAATTGGTCACAAGACTTTAGCCGA")
len(example_sequence)

72

In [42]:
#| hide
len(example_sequence) // 64

1

In [43]:
#| hide
print(len(example_sequence))
example_transcription = transcribe(
    example_sequence,
    model,
    vocab
)
len(example_transcription.split(",")), len(example_sequence)

72


                                                                                

(78, 72)

In [44]:
#| hide
example_transcription

'G,A,C,U,U,U,U,U,U,G,U,U,G,C,A,A,C,C,U,C,U,U,A,G,G,U,U,A,A,A,A,G,U,U,U,C,A,C,U,A,U,C,A,U,U,U,G,A,A,A,U,U,G,G,U,C,A,C,A,A,G,A,C,U,U,U,A,G,C,C,G,A,A,G,G,U,U,U'

In [45]:
#| export
def update_pbar(result, pbar):
    pbar.update(1)
    return result

In [50]:
#| hide
pbar = tqdm(total=5, ncols=80, leave=True)
sample_len = 64 * 20
try:
    predicted_transcriptions = example_mrna_sequences.apply(
        lambda row: update_pbar(
            transcribe(
                gene = example_gene_sequence_input[row.start: row.start + sample_len],
                model=model,
                vocab=vocab,
                pbar_position=1
            ), pbar),
            axis=1)
    predicted_transcriptions.name = "predicted"
except Exception as e:
    raise e
finally:
    pbar.close()

100%|█████████████████████████████████████████████| 5/5 [00:13<00:00,  2.65s/it]


In [51]:
#| hide
predicted_transcriptions[0]

'C,U,C,A,A,A,U,A,C,A,C,A,U,C,A,C,C,A,A,A,C,A,A,A,U,U,U,U,C,U,C,U,A,U,U,A,U,U,U,G,G,G,U,A,G,G,C,G,U,G,A,C,U,G,G,U,U,U,U,C,U,U,A,A,G,A,C,U,U,U,U,U,U,G,U,U,G,C,A,A,C,C,U,C,U,U,A,G,G,U,U,A,A,A,A,G,U,U,U,C,A,C,U,A,U,C,A,U,U,U,G,A,A,A,U,U,G,G,U,C,A,C,A,A,G,A,C,U,A,G,G,G,A,A,G,U,G,C,U,U,U,C,A,U,U,A,U,A,G,A,A,C,U,A,U,U,U,A,A,U,A,A,A,U,A,A,G,U,U,C,C,C,C,A,G,U,U,U,G,A,A,G,A,G,C,C,A,G,A,C,U,U,U,U,A,U,G,U,G,A,G,G,U,C,A,G,G,C,C,A,G,U,U,G,A,A,G,A,C,A,U,U,U,A,C,A,A,A,G,A,A,U,U,A,G,U,U,G,U,U,U,G,U,U,A,U,U,G,C,U,C,U,G,U,G,A,G,U,U,G,C,A,A,G,A,A,U,G,G,A,A,A,A,A,A,A,A,U,U,C,U,U,U,C,U,U,C,A,A,U,A,C,U,U,C,C,U,U,C,C,A,G,G,C,U,G,A,G,U,C,A,U,C,A,C,U,A,G,A,G,A,G,U,G,G,G,A,A,G,G,G,C,A,G,C,A,G,C,A,G,C,A,G,A,G,A,A,U,C,C,A,A,A,C,C,C,U,A,A,A,G,C,U,G,A,U,A,U,C,A,C,A,A,A,G,U,A,C,C,A,U,U,U,C,U,C,C,A,A,G,U,U,G,G,G,G,G,C,U,C,A,G,A,G,G,G,G,A,G,U,C,A,U,C,A,U,G,A,G,C,G,A,U,G,U,U,A,C,C,A,U,U,G,U,G,A,A,A,G,A,A,G,G,U,U,G,G,G,U,U,C,A,G,A,A,G,A,G,G,G,G,U,A,A,G,U,G,C,U,C,C,G,C,A,A,A,C,C,A,A,A,A,A,U,A,A,U,A,C,G,G,U,U,G,G,U,A,A,G,A

In [53]:
#| hide
example_predictions = pd.concat(
    [
        example_mrna_sequences,
        predicted_transcriptions
    ], axis=1)
example_predictions

Unnamed: 0,chromosome,geneid,transcriptid,sequence,start,end,predicted
0,NC_000001.11,GeneID:10000,NM_001206729.2,"<null>,<null>,<null>,<null>,<null>,<null>,<nul...",7495,351356,"C,U,C,A,A,A,U,A,C,A,C,A,U,C,A,C,C,A,A,A,C,A,A,..."
1,NC_000001.11,GeneID:10000,NM_001370074.1,"A,U,U,G,G,G,C,A,C,C,G,C,C,C,A,C,U,U,C,G,U,G,G,...",0,351356,"A,U,U,G,G,G,C,A,C,C,G,C,C,C,A,C,U,U,C,G,U,G,G,..."
2,NC_000001.11,GeneID:10000,NM_005465.7,"<null>,<null>,<null>,<null>,<null>,<null>,<nul...",836,351356,"G,C,A,G,C,C,C,U,U,C,G,C,U,U,G,C,C,C,U,C,C,C,G,..."
3,NC_000001.11,GeneID:10000,NM_181690.2,"<null>,<null>,<null>,<null>,<null>,<null>,<nul...",7797,362847,"G,C,U,G,A,G,U,C,A,U,C,A,C,U,A,G,A,G,A,G,U,G,G,..."
4,NC_000001.11,GeneID:10000,XM_011544014.3,"<null>,<null>,<null>,<null>,<null>,<null>,<nul...",284069,351356,"A,C,U,G,U,G,C,C,U,A,G,C,C,U,G,U,G,U,U,U,U,G,U,..."


In [58]:
#| hide
for i in range(example_predictions.shape[0]):
    print(i)
    print(example_predictions.iloc[i, :].sequence[:sample_len])
    print(example_predictions.iloc[i, :].predicted[:sample_len])
    print()

0
<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<null>,<nul

# What did I learn?

- Engineering text is difficult
- There is a LOT of genomic data to utilize, engineer
- Utilizing sqlite for this much text data is a good idea
- Annotating mRNA from genomic DNA requires a lot of memory
    - Some gene sequences are really large
- Accurately transcribing genomic DNA requires good recognition of;
    - Where introns start, end
    - Where transcription starts, ends
- My model does well;
    - On sequences that are actually transcribed
- My model needs to do better;
    - Identifying intronic sequence start and end locations
    - Identifying when mRNA transcription starts, ends

In [59]:
#| hide
import nbdev; nbdev.nbdev_export()