# Making the Reference Genome more easily usable

> Preparing genomic DNA for pre-training.

## Notebook Goals

Extract input (genomic DNA) and target (mRNA) sequence pairs for training a transformer to transcribe DNA into mRNA.

## Lessons learned

1. Calling `SeqFeature.extract` is very expensive when the parent sequence is very long (e.g. a whole chromosome).

### Plan

1. For each chromosome file:
    1. Extract all gene sequences to csv e.g. `assembly_path/genes/{chromosome}.csv`
        1. Columns
            1. geneid
            2. Sequence
            3. Positive strand 5' chromosome position
            4. Negative strand 5' chromosome position
    2. Get mRNA feature details
        1. parent geneid
        2. transcript_id
        3. location parts
    3. Identify mRNA-to-gene relationships, de-duplicate them, and write e.g. `assembly_path/relationships/mrna_to_gene/{chromosome}.csv`
2. For all mRNA to gene relationships
    1. Normalize mRNA positions to written gene
    2. Construct a synthetic SeqFeature spanning the first to last position of the mRNA (this is the input sequence)
    3. Construct mRNA sequence with intron tokens
    4. Collect the input (2) and target (3) sequences as pairs
    5. Write training instances to disk

## 0. Setup

In [1]:
#| default_exp data.transcription

In [2]:
#| hide
from nbdev.showdoc import *

In [18]:
#| export
import os
from pathlib import Path
from Bio import SeqIO, SeqRecord
from Bio.SeqFeature import SeqFeature, Position, SimpleLocation
from tqdm.auto import tqdm
import pandas as pd
import typing
from tqdm import tqdm
from multiprocessing import current_process

tqdm.pandas()

from llm_mito_scanner.data.download import load_config, \
    get_latest_assembly_path, get_genomic_genbank_path

In [4]:
#| hide
# Load the project configuration; its "data_path" entry anchors every path below.
config = load_config()

In [5]:
#| hide
# Resolve the directory layout for the most recent assembly and make sure the
# transcription training-data directory exists.
data_path = Path(config.get("data_path"))
data_raw_path = data_path / "raw"
assemblies_path = data_raw_path / "assemblies"
latest_assembly_path = get_latest_assembly_path(assemblies_path)
genomic_genbank_path = get_genomic_genbank_path(latest_assembly_path)
chromosomes_path = latest_assembly_path / "chromosomes"
training_data_path = latest_assembly_path / "training"
transcription_data_path = training_data_path / "transcription"
# exist_ok replaces the separate exists() check (and its check-then-create race).
transcription_data_path.mkdir(parents=True, exist_ok=True)

## 1. Process chromosome files

### 1.1 Extract all gene sequences to csv e.g. `assembly_path/genes/{chromosome}.csv`
Columns
1. geneid
2. Sequence
3. Positive strand 5' chromosome position
4. Negative strand 5' chromosome position

In [7]:
#| hide
# Pick the lexicographically first chromosome file. glob() yields files in
# filesystem order, so next() could pick a different file on each machine;
# min() makes the example (and the CSVs written from it below) deterministic.
example_chromosome_path = min(chromosomes_path.glob("*.gb"))
example_chromosome_path

Path('/mnt/e/Data/llm-mito-scanner-data/data/raw/assemblies/GCF_000001405.40_GRCh38.p14/chromosomes/NC_000001.11.gb')

In [8]:
#| hide
# Parse only the first GenBank record in the file (None if the file is empty).
with example_chromosome_path.open("rt") as genbank_handle:
    genbank_records = SeqIO.parse(genbank_handle, "genbank")
    example_chromosome = next(genbank_records, None)
example_chromosome

SeqRecord(seq=Seq('NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...NNN'), id='NC_000001.11', name='NC_000001', description='Homo sapiens chromosome 1, GRCh38.p14 Primary Assembly', dbxrefs=['BioProject:PRJNA168', 'Assembly:GCF_000001405.40'])

In [9]:
#| hide
# Total number of features (of every type) on the example chromosome record.
len(example_chromosome.features)

51489

In [10]:
#| export
def filter_chromosome_features_by_type(
        chromosome_record: SeqRecord,
        feature_type: str,
        ) -> list[tuple[int, SeqFeature]]:
    """Select the features of one type from a chromosome record.

    Returns ``(index, feature)`` pairs, where ``index`` is the feature's
    position in ``chromosome_record.features``. The pairs keep the original
    feature order, so the feature can be looked up again by index later.
    """
    all_features = chromosome_record.features
    return [
        (position, feature)
        for position, feature in enumerate(all_features)
        if feature.type == feature_type
    ]

In [11]:
#| export
def get_feature_qualifiers(feature: SeqFeature) -> typing.Dict[str, typing.Any] | None:
    """Return the feature's ``qualifiers`` mapping, or ``None`` if it has none.

    ``getattr`` with a default means an object without a ``qualifiers``
    attribute yields ``None`` instead of raising ``AttributeError``; the
    return annotation now reflects that case.
    """
    return getattr(feature, "qualifiers", None)


def get_feature_dbxrefs(feature: SeqFeature) -> list[str] | None:
    """Return the feature's raw ``db_xref`` qualifier values, or ``None``.

    The value is the qualifier as stored on the feature. Callers iterate it
    as a sequence of strings such as ``"GeneID:100287102"``, so the
    annotation is a list rather than a single ``str``. ``None`` is returned
    when the feature has no qualifiers or no ``db_xref`` entry.
    """
    feature_qualifiers = get_feature_qualifiers(feature)
    if feature_qualifiers is None:
        return None
    return feature_qualifiers.get("db_xref")


def get_feature_dbxref_xref(feature: SeqFeature, prefix: str) -> str | None:
    """Return the first ``db_xref`` entry of ``feature`` that starts with ``prefix``.

    The whole entry is returned, prefix included (e.g. ``"GeneID:100287102"``
    for ``prefix="GeneID"``). Returns ``None`` when the feature has no
    ``db_xref`` values or none of them match.
    """
    feature_dbxrefs = get_feature_dbxrefs(feature)
    if feature_dbxrefs is None:
        return None
    # A generator stops at the first match instead of materializing every match.
    return next((xref for xref in feature_dbxrefs if xref.startswith(prefix)), None)


def get_feature_geneid(feature: SeqFeature) -> str | None:
    """Return the feature's ``GeneID`` cross-reference, or ``None`` if absent.

    The result keeps the ``GeneID:`` prefix, e.g. ``"GeneID:100287102"``.
    """
    geneid_prefix = "GeneID"
    return get_feature_dbxref_xref(feature, geneid_prefix)

In [12]:
#| hide
# (index, feature) pairs for the example chromosome's "gene" features only.
example_chromosome_genes = filter_chromosome_features_by_type(example_chromosome, "gene")
len(example_chromosome_genes)

5501

In [13]:
#| hide
# The first two gene features, used in the next cells to inspect locations.
# NOTE(review): the names imply the second gene is on the minus strand; only the
# first gene's strand (strand=1) is shown in the output below — confirm.
example_chromosome_gene_positive = example_chromosome_genes[0][1]
example_chromosome_gene_negative = example_chromosome_genes[1][1]

In [14]:
#| hide
# Location (start, end, strand) of the first gene feature.
example_chromosome_gene_positive.location

SimpleLocation(ExactPosition(11873), ExactPosition(14409), strand=1)

In [15]:
#| hide
# Get pos and neg strand positions.
# location.start is the value get_chromosome_gene_info stores as pos_strand_position.
int(example_chromosome_gene_positive.location.start)

11873

In [16]:
#| hide
# location.end is the value get_chromosome_gene_info stores as neg_strand_position.
example_chromosome_gene_negative.location.end

ExactPosition(29370)

In [17]:
#| hide
# Start of the same (second) gene feature, converted to a plain int.
int(example_chromosome_gene_negative.location.start)

14361

In [49]:
#| export
def get_chromosome_gene_info(
        chromosome_record: SeqRecord,
        pbar_position: int = 0
        ) -> pd.DataFrame:
    """Build a table describing every ``gene`` feature on a chromosome.

    Parameters
    ----------
    chromosome_record
        Parsed GenBank record for one chromosome.
    pbar_position
        Row used for the tqdm progress bar (also used in its ``desc``), so
        several worker processes can each draw their own bar.

    Returns
    -------
    pd.DataFrame
        One row per gene feature, with these columns:

        - ``geneid``: ``"GeneID:..."`` or ``None``.
        - ``sequence``: the gene sequence as returned by ``SeqFeature.extract``.
        - ``pos_strand_position``: ``int(location.start)``.
        - ``neg_strand_position``: ``int(location.end)``.
    """
    chromosome_genes = [
        feature for _, feature in filter_chromosome_features_by_type(chromosome_record, "gene")
    ]
    # Extract from the bare Seq, not the SeqRecord: slicing a SeqRecord also
    # scans and copies the record's features for every gene, which is what
    # makes extraction so slow on a whole chromosome. The sequence is the same.
    chromosome_seq = chromosome_record.seq
    progress = tqdm(
        chromosome_genes, leave=False, ncols=80,
        position=pbar_position, desc=f"Process-{pbar_position}",
    )
    chromosome_gene_sequences = [str(gene.extract(chromosome_seq)) for gene in progress]
    return pd.DataFrame({
        "geneid": [get_feature_geneid(gene) for gene in chromosome_genes],
        "sequence": chromosome_gene_sequences,
        # Cast position objects to int so these are plain integer columns.
        "pos_strand_position": [int(gene.location.start) for gene in chromosome_genes],
        "neg_strand_position": [int(gene.location.end) for gene in chromosome_genes],
    })
    

In [50]:
#| hide
# Reuse the cached gene table when it exists; otherwise compute it (slow).
example_chromosome_gene_info_path = latest_assembly_path / "genes" / f"{example_chromosome_path.stem}.csv"
example_chromosome_gene_info = (
    pd.read_csv(example_chromosome_gene_info_path)
    if example_chromosome_gene_info_path.exists()
    else get_chromosome_gene_info(example_chromosome)
)


In [51]:
#| hide
# Preview the first rows of the gene table.
# NOTE(review): the "Unnamed: 0" column in the output suggests the cached CSV
# was written with its index by an earlier version; write_chromosome_gene_info
# now uses index=False — consider regenerating the cache.
example_chromosome_gene_info.head()

Unnamed: 0,geneid,sequence,pos_strand_position,neg_strand_position
0,GeneID:100287102,CTTGCCGTCAGCCTTTTCTTTGACCTCTTCTTTCTGTTCATGTGTA...,11873,14409
1,GeneID:653635,TCCGGCAGAGCGGAAGCGGCGGCGGGAGCTTCCGGGAGGGCGGCTC...,14361,29370
2,GeneID:102466751,TGTGGGAGAGGAACATGGGCTCAGGACAGCGGGTGTCAGCTTGCCT...,17368,17436
3,GeneID:107985730,TGCCCTCCAGCCCTACGCCTTGACCCGCTTTCCTGCGTCTCTCAGC...,29773,35418
4,GeneID:100302278,GGATGCCCAGCTAGTTTGAATTTTAGATAAACAACGAATAATTTCG...,30365,30503


In [52]:
#| export
def write_chromosome_gene_info(assembly_path: Path, chromosome_tag: str, frame: pd.DataFrame):
    """Write a chromosome's gene table to ``assembly_path/genes/{chromosome_tag}.csv``.

    The DataFrame index is not written. Missing parent directories are
    created. ``exist_ok=True`` makes concurrent workers safe: one process
    creating ``genes/`` between another's check and ``mkdir`` no longer
    raises ``FileExistsError``.
    """
    genes_path = assembly_path / "genes"
    genes_path.mkdir(parents=True, exist_ok=True)
    gene_info_path = genes_path / f"{chromosome_tag}.csv"
    frame.to_csv(gene_info_path, index=False)

In [53]:
#| hide
# Persist the gene table only if no cached CSV exists yet (an existing CSV is
# where the table was loaded from, so rewriting it would be redundant).
if not example_chromosome_gene_info_path.exists():
    write_chromosome_gene_info(latest_assembly_path, example_chromosome_path.stem, example_chromosome_gene_info)

In [54]:
#| export
def get_feature_transcript_id(feature: SeqFeature) -> str | None:
    """Return the first ``transcript_id`` qualifier value of ``feature``, or ``None``.

    ``None`` is returned when the feature has no qualifiers or no
    ``transcript_id`` entry. Previously a missing ``qualifiers`` attribute
    raised ``AttributeError`` on ``None.get``.
    """
    qualifiers = get_feature_qualifiers(feature) or {}
    return next(iter(qualifiers.get("transcript_id", [])), None)

Details to collect for each mRNA feature:
1. feature index
2. transcript_id
3. location parts

### 1.2 Identify mRNA parent gene relationships

In [80]:
#| export
def get_gene_and_mrna_relationships(
        chromosome: SeqRecord,
        ) -> pd.DataFrame:
    """Link each mRNA feature on ``chromosome`` to its parent gene feature.

    mRNA and gene features are matched on their ``GeneID`` db_xref. These
    mRNAs are left out:

    - mRNAs missing a GeneID or a transcript_id are skipped.
    - mRNAs whose GeneID has no gene feature are dropped by the inner merge.

    Returns
    -------
    pd.DataFrame
        Columns ``geneid``, ``transcript_id``, ``transcript_feature_idx`` and
        ``gene_feature_idx``. The two ``*_idx`` columns index into
        ``chromosome.features``. Rows are sorted by gene index, then
        transcript index. If a (geneid, transcript_id) pair occurs on several
        mRNA features, only the lowest feature index is kept.
    """
    relationships = []
    for idx, mrna in filter_chromosome_features_by_type(chromosome, "mRNA"):
        mrna_transcript_id = get_feature_transcript_id(mrna)
        mrna_gene_id = get_feature_geneid(mrna)
        if mrna_gene_id is not None and mrna_transcript_id is not None:
            relationships.append((mrna_gene_id, mrna_transcript_id, idx))

    # Sort *before* de-duplicating so the surviving row is deterministic
    # (lowest feature index); set iteration order made this arbitrary.
    chromosome_relationships_df = (
        pd.DataFrame(
            relationships,
            columns=['geneid', 'transcript_id', 'transcript_feature_idx'],
        )
        .sort_values("transcript_feature_idx", ascending=True)
        .drop_duplicates(subset=["geneid", "transcript_id"], keep="first")
    )

    gene_idx = pd.DataFrame(
        [(idx, get_feature_geneid(f)) for idx, f in filter_chromosome_features_by_type(chromosome, "gene")],
        columns=["gene_feature_idx", "geneid"]
    ).drop_duplicates()

    # Secondary key keeps transcripts of the same gene in a stable order.
    return chromosome_relationships_df.merge(
        gene_idx, on=['geneid']
    ).sort_values(["gene_feature_idx", "transcript_feature_idx"], ascending=True)

In [81]:
#| hide
# mRNA -> gene relationships for the example chromosome (recomputed on every run;
# only the CSV write below is skipped when a file already exists).
example_chromosome_relationships = get_gene_and_mrna_relationships(
    example_chromosome,
)

In [83]:
#| hide
# Preview the first mRNA -> gene relationship rows.
example_chromosome_relationships.head()

Unnamed: 0,geneid,transcript_id,transcript_feature_idx,gene_feature_idx
0,GeneID:79501,NM_001005484.2,22,21
1,GeneID:112268260,XM_047436352.1,80,79
2,GeneID:729759,NM_001005221.2,85,84
3,GeneID:105378947,XM_011542538.1,96,95
4,GeneID:81399,XM_024449992.2,126,125


### 1.3 Write relationships

In [84]:
#| export
def write_mrna_gene_relationships(relationships: pd.DataFrame, chromosome: str, assembly_path: Path):
    """Write mRNA -> gene relationships to a per-chromosome CSV.

    The file is ``assembly_path/relationships/mrna_to_gene/{chromosome}.csv``
    and the DataFrame index is not written. Parent directories are created
    as needed. ``exist_ok=True`` avoids a check-then-create race when
    several workers write at once.
    """
    mrna_to_gene_path = assembly_path / "relationships" / "mrna_to_gene"
    mrna_to_gene_path.mkdir(parents=True, exist_ok=True)
    chromosome_relationship_path = mrna_to_gene_path / f"{chromosome}.csv"
    relationships.to_csv(chromosome_relationship_path, index=False)

In [85]:
#| hide
# Write the example relationships once; an existing CSV is left untouched.
example_chromosome_relationships_path = (
    latest_assembly_path / "relationships" / "mrna_to_gene" / f"{example_chromosome_path.stem}.csv"
)
if not example_chromosome_relationships_path.exists():
    write_mrna_gene_relationships(
        relationships=example_chromosome_relationships,
        chromosome=example_chromosome_path.stem,
        assembly_path=latest_assembly_path,
    )

## 2. Process mRNA to gene relationships

For all mRNA to gene relationships
1. Normalize mRNA positions to written gene
2. Construct a synthetic SeqFeature spanning the first to last position of the mRNA (this is the input sequence)
3. Construct mRNA sequence with intron tokens
4. Collect the input (2) and target (3) sequences as pairs
5. Write training instances to disk

In [86]:
#| hide
# (index, feature) pairs for every mRNA feature on the example chromosome;
# presumably sample input for the section 2 functions — not yet used below.
example_chromosome_mrna = filter_chromosome_features_by_type(example_chromosome, "mRNA")

In [None]:
#| export
def normalize_mrna_positions(mrna: SeqFeature, gene: pd.Series) -> SeqFeature:
    """Plan step 1: re-express the mRNA's positions relative to the written gene.

    Not implemented yet. Raising makes accidental use fail loudly instead of
    silently returning ``None`` despite the ``SeqFeature`` annotation.
    """
    raise NotImplementedError("normalize_mrna_positions is not implemented yet")

In [None]:
#| export
def get_input_sequence(mrna: SeqFeature, gene_sequence: str) -> str:
    """Plan step 2: input sequence spanning the mRNA's first to last position.

    Not implemented yet. Raising makes accidental use fail loudly instead of
    silently returning ``None`` despite the ``str`` annotation.
    """
    raise NotImplementedError("get_input_sequence is not implemented yet")

In [None]:
#| export
def get_target_sequence(normalized_mrna: SeqFeature, gene: str) -> str:
    """Plan step 3: mRNA target sequence with intron tokens in place of introns.

    Not implemented yet. Raising makes accidental use fail loudly instead of
    silently returning ``None`` despite the ``str`` annotation.
    """
    raise NotImplementedError("get_target_sequence is not implemented yet")

In [None]:
#| export
def get_all_input_and_target_sequences():
    """Plan step 4: collect (input, target) sequence pairs for all relationships.

    Not implemented yet. Raising makes accidental use fail loudly instead of
    silently returning ``None``.
    """
    raise NotImplementedError("get_all_input_and_target_sequences is not implemented yet")

In [None]:
#| export
def write_input_and_target_sequences():
    """Plan step 5: write the training instances to disk.

    Not implemented yet. Raising makes accidental use fail loudly instead of
    silently doing nothing.
    """
    raise NotImplementedError("write_input_and_target_sequences is not implemented yet")

In [62]:
# sample_chromosome_name = sample_chromosome_file.stem
# sample_chromosome_relationships_path = latest_assembly_path / "relationships/mrna_to_gene" / f"{sample_chromosome_name}.csv"
# assert sample_chromosome_relationships_path.exists()
# print(sample_chromosome_relationships_path)
# sample_chromosome_relationships = pd.read_csv(sample_chromosome_relationships_path)
# sample_chromosome_relationships.head()

In [63]:
#| export
# def get_mrna_gene_sequence(
#         chromosome: SeqRecord,
#         mrna_idx: int
# ) -> str:
#     "1. Construct synthetic transcription input sequence using first and last position of the mRNA"
#     mrna = chromosome.features[mrna_idx]
#     strand = mrna.strand
#     positions = [int(p.start) for p in mrna.location.parts] + [int(p.end) for p in mrna.location.parts]
#     position_max = max(positions)
#     position_min = min(positions)
#     location = SimpleLocation(position_min, position_max, strand=strand)
#     transcription_target_feature = SeqFeature(
#         location=location,
#         type="mRNA",
#     )
#     return ",".join(list(str(transcription_target_feature.extract(chromosome).seq)))

In [64]:
#| hide
# sample_input = get_mrna_gene_sequence(
#     sample_chromosome,
#     sample_chromosome_relationships.iloc[0, 2]
# )
# sample_input.count(",")

In [65]:
#| export
# def get_mrna_locations(
#         chromosome: SeqRecord,
#         mrna_idx: int
#         ) -> dict[str, list[tuple[int, int, int]]]:
#     mrna_feature = chromosome.features[mrna_idx]
#     location_parts = [(int(p.start), int(p.end), int(p.strand)) for p in mrna_feature.location.parts]
#     return location_parts


# def get_mrna_splice_locations(
#         chromosome: SeqRecord,
#         mrna_idx: int
#         ) -> list[int, int]: # List of intron locations including Where they start and how large they are.
#     """
#     We're going to read mRNA sequences where they're already oriented to their strand and spliced.
#     We need to normalize the locations so they're sorted according to strand and
#     normalized so they index to zero and ignore splice regions.
#     This is so we can easily split an mRNA read from source and insert introns.
#     """
#     locations = get_mrna_locations(
#         chromosome,
#         mrna_idx
#     )
#     locations_df = pd.DataFrame(locations, columns=['start', 'end', 'strand'])
#     strand = locations_df.iloc[0, 2]
#     ascending_sort = True if strand == 1 else False
#     locations_df.sort_values("start", 
#                              ascending=ascending_sort, 
#                              inplace=True)
#     start_loc = locations_df.iloc[0, 0]
#     locations_df.loc[:, 'start_norm'] = locations_df.start - start_loc
#     locations_df.loc[:, 'end_norm'] = locations_df.end - start_loc
#     locations_df.loc[:, 'end_norm_shift'] = locations_df.end_norm.shift(1)
#     introns = locations_df.dropna(subset=['end_norm_shift']).astype(int)
#     introns.loc[:, 'intron_size'] = introns.start_norm - introns.end_norm_shift
#     intron_list = introns[['end_norm', 'intron_size']].apply(abs).values.tolist()
#     return intron_list


# def get_mrna_sequence_with_introns(
#         chromosome: SeqRecord,
#         mrna_idx: int,
#         intron_token: str = "<intron>",
#         debug: bool = False
# ) -> str:
#     "2. Construct mRNA sequence with intron tokens."
#     intron_list = get_mrna_splice_locations(
#         chromosome,
#         mrna_idx
#         )
#     mrna = chromosome.features[mrna_idx]
#     mrna_sequence = str(mrna.extract(chromosome).seq)
#     if len(intron_list) == 0:
#         return ",".join(list(mrna_sequence))
#     if debug:
#         print(len(mrna_sequence))
#         print(intron_list)
#     mrna_with_introns = []
#     last_intron_start = 0
#     for intron_start, intron_size in intron_list:
#         intron_str = [intron_token] * intron_size
#         spliced_mrna = list(mrna_sequence[last_intron_start:intron_start])
#         mrna_with_introns = mrna_with_introns + spliced_mrna + intron_str
#         last_intron_start = intron_start
#     return ",".join(mrna_with_introns)

In [66]:
#| hide
# sample_chromosome.features[80]

In [67]:
#| hide
# get_mrna_sequence_with_introns(sample_chromosome, 80, debug=True).count(",")

In [68]:
#| hide
# sample_target = get_mrna_sequence_with_introns(sample_chromosome, sample_chromosome_relationships.iloc[0, 2])
# sample_target.count(",")

In [69]:
#| export
# def get_input_and_target_sequence(
#         chromosome: SeqRecord,
#         mrna_idx: int
# ) -> list[tuple[str, str]]:
#     "3. Collect [3] and [4]."
#     input_sequence = get_mrna_gene_sequence(
#         chromosome,
#         mrna_idx
#     )
#     target_sequence = get_mrna_sequence_with_introns(
#         chromosome, 
#         mrna_idx
#     )
#     return input_sequence, target_sequence

In [70]:
#| hide
# sample_input, sample_target = get_input_and_target_sequence(
#     sample_chromosome,
#     sample_chromosome_relationships.iloc[0, 2]
# )
# sample_input.count(",") == sample_target.count(",")

In [71]:
#| hide
# sample_chromosome_relationships_sample = sample_chromosome_relationships.sample(50)

In [72]:
#| hide
# pr = cProfile.Profile()
# pr.enable()
# sample_input_and_targets = sample_chromosome_relationships_sample.transcript_feature_idx.progress_apply(
#     lambda idx: get_input_and_target_sequence(
#         sample_chromosome,
#         idx
#     )
# )
# pr.disable()
# s = io.StringIO()
# sortby = SortKey.CUMULATIVE
# ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
# ps.print_stats()

In [73]:
# print(s.getvalue())

In [74]:
#| hide
# [(x[0].count(","), x[1].count(",")) for x in sample_input_and_targets]

In [75]:
#| hide
# sample_input_and_targets_df = pd.DataFrame(sample_input_and_targets.values.tolist(), columns=['input', 'target']).merge(
#     sample_chromosome_relationships.head(5),
#     left_index=True,
#     right_index=True
#     )
# sample_input_and_targets_df.loc[:, 'input_len'] = sample_input_and_targets_df['input'].apply(lambda s: s.count(","))
# sample_input_and_targets_df.loc[:, 'target_len'] = sample_input_and_targets_df['target'].apply(lambda s: s.count(","))

In [76]:
#| hide
# sample_input_and_targets_df

In [77]:
#| export
# def get_all_input_and_target_sequences(
#         chromosome: SeqRecord,
#         relationships: pd.DataFrame
# ):
#     # Has to be processed sequentially
    
#     pass

In [78]:
#| export
# def write_input_and_output_sequences(
        
# ):
#     "4. Write training instances to disk."

In [79]:
#| hide
# Regenerate the library module(s) from this notebook's `#| export` cells.
import nbdev
nbdev.nbdev_export()