# Making the Reference Genome more easily usable

> Preparing genomic DNA for pre-training.

## Notebook Goals

Identify input and output sequences to train a transformer to transcribe DNA to mRNA.

### Plan

1. For each chromosome file;
    1. Get mRNA feature details
        1. feature index
        2. transcript_id
        3. location parts
    2. Identify mRNA to gene relationships, de-dupe
    3. Collect all for all chromosomes, write to disk
2. For all mRNA to gene relationships
    1. Construct synthetic SeqFeature using first and last position of the mRNA (this is input sequence)
    2. Construct mRNA sequence with intron tokens
    3. Collect [3] and [4]
    4. Write training instnaces to disk

## 0. Setup

In [1]:
#| default_exp data.transcription

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
import os
from pathlib import Path
from Bio import SeqIO, SeqRecord, Seq
from Bio.SeqFeature import SeqFeature, Position, SimpleLocation
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import typing
import gzip
from tqdm.notebook import tqdm

from llm_mito_scanner.data.download import load_config, \
    get_latest_assembly_path, get_genomic_genbank_path



In [4]:
#| hide
config = load_config()

In [5]:
#| hide
data_path = Path(config.get("data_path"))
data_raw_path = data_path / "raw"
assemblies_path = data_raw_path / "assemblies"
latest_assembly_path = get_latest_assembly_path(assemblies_path)
genomic_genbank_path = get_genomic_genbank_path(latest_assembly_path)
chromosomes_path = latest_assembly_path / "chromosomes"
training_data_path = latest_assembly_path / "training"
transcription_data_path = training_data_path / "transcription"
if not transcription_data_path.exists():
    transcription_data_path.mkdir(parents=True)

## 1. Process chromosome files

### 1.1 Get mRNA feature details

In [6]:
#| hide
chromosome_path = next(chromosomes_path.glob("*.gb"))
chromosome_path

Path('/mnt/e/Data/llm-mito-scanner-data/data/raw/assemblies/GCF_000001405.40_GRCh38.p14/chromosomes/NC_000001.11.gb')

In [7]:
#| hide
with chromosome_path.open("rt") as f:
    chromosome = next(SeqIO.parse(f, "genbank"), None)
chromosome

SeqRecord(seq=Seq('NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...NNN'), id='NC_000001.11', name='NC_000001', description='Homo sapiens chromosome 1, GRCh38.p14 Primary Assembly', dbxrefs=['BioProject:PRJNA168', 'Assembly:GCF_000001405.40'])

In [8]:
#| hide
len(chromosome.features)

51489

In [9]:
#| export
def filter_chromosome_features_by_type(
        chromosome_record: SeqRecord, 
        feature_type: str,
        ) -> list[tuple[int, SeqFeature]]:
    return [(i, f) for i, f in enumerate(chromosome_record.features) if f.type == feature_type]

In [10]:
#| hide
chromosome_mrna = filter_chromosome_features_by_type(chromosome, "mRNA")
len(chromosome_mrna)

12841

In [11]:
#| hide
chromosome_mrna[0][1].qualifiers.get("db_xref")

['Ensembl:ENST00000641515.2', 'GeneID:79501', 'HGNC:HGNC:14825']

In [12]:
#| export
def get_feature_qualifiers(feature: SeqFeature) -> typing.Dict[str, typing.Any]:
    return getattr(feature, "qualifiers", None)


def get_feature_dbxrefs(feature: SeqFeature) -> str | None:
    feature_qualifiers = get_feature_qualifiers(feature)
    if feature_qualifiers is None:
        return None
    feature_dbxrefs = feature_qualifiers.get("db_xref", None)
    return feature_dbxrefs


def get_feature_dbxref_xref(feature: SeqFeature, prefix: str) -> str | None:
    feature_dbxrefs = get_feature_dbxrefs(feature)
    if feature_dbxrefs is None:
        return None
    tag_db_xref = next(iter([x for x in feature_dbxrefs if x.startswith(prefix)]), None)
    return tag_db_xref


def get_feature_geneid(feature: SeqFeature) -> str | None:
    return get_feature_dbxref_xref(feature, "GeneID")


def get_feature_transcript_id(feature: SeqFeature) -> str | None:
    return next(iter(get_feature_qualifiers(feature).get("transcript_id", [])), None)

1. feature index
2. transcript_id
3. location parts

In [13]:
#| export
def get_mrna_details(idx, mrna: SeqFeature) -> tuple[int, str, list[Position]]:
    return idx, get_feature_transcript_id(mrna), mrna.location.parts

In [14]:
#| hide
get_mrna_details(*chromosome_mrna[0])

(22,
 'NM_001005484.2',
 [SimpleLocation(ExactPosition(65418), ExactPosition(65433), strand=1),
  SimpleLocation(ExactPosition(65519), ExactPosition(65573), strand=1),
  SimpleLocation(ExactPosition(69036), ExactPosition(71585), strand=1)])

### 1.2 Identify mRNA parent gene relationships

In [15]:
#| export
def get_gene_and_mrna_relationships(
        chromosome: SeqRecord,
        ) -> pd.DataFrame:
    ""
    mrna_features = filter_chromosome_features_by_type(chromosome, "mRNA")
    relationships = set()
    for idx, mrna in mrna_features:
        mrna_transcript_id = get_feature_transcript_id(mrna)
        mrna_gene_id = get_feature_geneid(mrna)
        if mrna_gene_id is not None and mrna_transcript_id is not None:
            relationship_tuple = (mrna_gene_id, mrna_transcript_id, idx)
            relationships.add(relationship_tuple)

    chromosome_relationships_df = pd.DataFrame(
        relationships, 
        columns=['geneid', 'transcript_id', 'transcript_feature_idx']
    ).drop_duplicates(subset=["geneid", "transcript_id"]).sort_values("transcript_feature_idx", ascending=True)

    gene_idx = pd.DataFrame(
        [(idx, get_feature_geneid(f)) for idx, f in filter_chromosome_features_by_type(chromosome, "gene")],
        columns=["gene_feature_idx", "geneid"]
    ).drop_duplicates()

    chromosome_relationships_df = chromosome_relationships_df.merge(
        gene_idx, on=['geneid']
    ).sort_values("gene_feature_idx", ascending=True)
    return chromosome_relationships_df

In [16]:
#| hide
chromosome_relationships = get_gene_and_mrna_relationships(
    chromosome,
)

In [17]:
#| hide
chromosome_relationships.head()

Unnamed: 0,geneid,transcript_id,transcript_feature_idx,gene_feature_idx
0,GeneID:79501,NM_001005484.2,22,21
1,GeneID:112268260,XM_047436352.1,80,79
2,GeneID:729759,NM_001005221.2,85,84
3,GeneID:105378947,XM_011542538.1,96,95
4,GeneID:81399,XM_024449992.2,126,125


### 1.3 Write relationships

In [18]:
#| export
def write_mrna_gene_relationships(relationships: pd.DataFrame, chromosome: str, assembly_path: Path):
    relationship_path = assembly_path / "relationships"
    mrna_to_gene_path = relationship_path / "mrna_to_gene"
    if not mrna_to_gene_path.exists():
        mrna_to_gene_path.mkdir(parents=True)
    chromosome_relationship_path = mrna_to_gene_path / f"{chromosome}.csv"
    relationships.to_csv(chromosome_relationship_path, index=False)

In [19]:
#| hide
if not (latest_assembly_path / "relationships"/ "mrna_to_gene" / f"{chromosome_path.stem}.csv").exists():
    write_mrna_gene_relationships(chromosome_relationships, chromosome_path.stem, latest_assembly_path)

In [20]:
#| hide
mrna_to_gene_relationship_path = latest_assembly_path / "relationships/mrna_to_gene"
chromosome_files = tqdm(
    list(chromosomes_path.glob("*.gb")),
    leave=False,
    ncols=80
    )
for chromosome_file in chromosome_files:
    if not (mrna_to_gene_relationship_path / f"{chromosome_file.stem}.csv").exists():
        with chromosome_file.open("rt") as f:
            chromosome = next(SeqIO.parse(f, "genbank"), None)
        chromosome_relationships = get_gene_and_mrna_relationships(chromosome)
        write_mrna_gene_relationships(chromosome_relationships, chromosome_file.stem, latest_assembly_path)


  0%|                                                    | 0/25 [00:00<?, ?it/s]

## 2. Process mRNA to gene relationships

For all mRNA to gene relationships
1. Construct synthetic SeqFeature using first and last position of the mRNA (this is input sequence)
2. Construct mRNA sequence with intron tokens
3. Collect [3] and [4]
4. Write training instnaces to disk

Goals
1. Process safe
2. Thread safe
3. Finish within 20 minutes

In [21]:
#| hide
sample_chromosome_file = next(iter(chromosomes_path.glob("*.gb")))
print(sample_chromosome_file)
with sample_chromosome_file.open("rt") as f:
    sample_chromosome = next(SeqIO.parse(f, "genbank"))

/mnt/e/Data/llm-mito-scanner-data/data/raw/assemblies/GCF_000001405.40_GRCh38.p14/chromosomes/NC_000001.11.gb


In [22]:
sample_chromosome_name = sample_chromosome_file.stem
sample_chromosome_relationships_path = latest_assembly_path / "relationships/mrna_to_gene" / f"{sample_chromosome_name}.csv"
assert sample_chromosome_relationships_path.exists()
print(sample_chromosome_relationships_path)
sample_chromosome_relationships = pd.read_csv(sample_chromosome_relationships_path)
sample_chromosome_relationships.head()

/mnt/e/Data/llm-mito-scanner-data/data/raw/assemblies/GCF_000001405.40_GRCh38.p14/relationships/mrna_to_gene/NC_000001.11.csv


Unnamed: 0,geneid,transcript_id,transcript_feature_idx,gene_feature_idx
0,GeneID:79501,NM_001005484.2,22,21
1,GeneID:112268260,XM_047436352.1,80,79
2,GeneID:729759,NM_001005221.2,85,84
3,GeneID:105378947,XM_011542538.1,96,95
4,GeneID:81399,XM_024449992.2,126,125


In [23]:
#| export
def get_mrna_gene_sequence(
        chromosome: SeqRecord,
        mrna_idx: int
) -> str:
    "1. Construct synthetic transcription input sequence using first and last position of the mRNA"
    mrna = chromosome.features[mrna_idx]
    strand = mrna.strand
    positions = [int(p.start) for p in mrna.location.parts] + [int(p.end) for p in mrna.location.parts]
    position_max = max(positions)
    position_min = min(positions)
    location = SimpleLocation(position_min, position_max, strand=strand)
    transcription_target_feature = SeqFeature(
        location=location,
        type="mRNA",
    )
    return ",".join(list(str(transcription_target_feature.extract(chromosome).seq)))

In [24]:
#| hide
sample_input = get_mrna_gene_sequence(
    sample_chromosome,
    sample_chromosome_relationships.iloc[0, 2]
)
sample_input.count(",")

6166

In [25]:
#| export
def get_mrna_locations(
        chromosome: SeqRecord,
        mrna_idx: int
        ) -> dict[str, list[tuple[int, int, int]]]:
    mrna_feature = chromosome.features[mrna_idx]
    location_parts = [(int(p.start), int(p.end), int(p.strand)) for p in mrna_feature.location.parts]
    return location_parts


def get_mrna_splice_locations(
        chromosome: SeqRecord,
        mrna_idx: int
        ) -> list[int, int]: # List of intron locations including Where they start and how large they are.
    """
    We're going to read mRNA sequences where they're already oriented to their strand and spliced.
    We need to normalize the locations so they're sorted according to strand and
    normalized so they index to zero and ignore splice regions.
    This is so we can easily split an mRNA read from source and insert introns.
    """
    locations = get_mrna_locations(
        chromosome,
        mrna_idx
    )
    locations_df = pd.DataFrame(locations, columns=['start', 'end', 'strand'])
    strand = locations_df.iloc[0, 2]
    ascending_sort = True if strand == 1 else False
    locations_df.sort_values("start", 
                             ascending=ascending_sort, 
                             inplace=True)
    start_loc = locations_df.iloc[0, 0]
    locations_df.loc[:, 'start_norm'] = locations_df.start - start_loc
    locations_df.loc[:, 'end_norm'] = locations_df.end - start_loc
    locations_df.loc[:, 'end_norm_shift'] = locations_df.end_norm.shift(1)
    introns = locations_df.dropna(subset=['end_norm_shift']).astype(int)
    introns.loc[:, 'intron_size'] = introns.start_norm - introns.end_norm_shift
    intron_list = introns[['end_norm', 'intron_size']].apply(abs).values.tolist()
    return intron_list


def get_mrna_sequence_with_introns(
        chromosome: SeqRecord,
        mrna_idx: int,
        intron_token: str = "<intron>",
        debug: bool = False
) -> str:
    "2. Construct mRNA sequence with intron tokens."
    intron_list = get_mrna_splice_locations(
        chromosome,
        mrna_idx
        )
    mrna = chromosome.features[mrna_idx]
    mrna_sequence = str(mrna.extract(chromosome).seq)
    if len(intron_list) == 0:
        return ",".join(list(mrna_sequence))
    if debug:
        print(len(mrna_sequence))
        print(intron_list)
    mrna_with_introns = []
    last_intron_start = 0
    for intron_start, intron_size in intron_list:
        intron_str = [intron_token] * intron_size
        spliced_mrna = list(mrna_sequence[last_intron_start:intron_start])
        mrna_with_introns = mrna_with_introns + spliced_mrna + intron_str
        last_intron_start = intron_start
    return ",".join(mrna_with_introns)

In [26]:
#| hide
sample_chromosome.features[80]

SeqFeature(CompoundLocation([SimpleLocation(ExactPosition(382049), ExactPosition(382235), strand=-1), SimpleLocation(ExactPosition(380896), ExactPosition(381688), strand=-1), SimpleLocation(ExactPosition(379768), ExactPosition(379870), strand=-1), SimpleLocation(ExactPosition(373143), ExactPosition(373323), strand=-1), SimpleLocation(ExactPosition(365133), ExactPosition(365692), strand=-1)], 'join'), type='mRNA', location_operator='join', qualifiers=...)

In [27]:
#| hide
get_mrna_sequence_with_introns(sample_chromosome, 80, debug=True).count(",")

1819
[[361, 1339], [2179, 1920], [8726, 6727], [16357, 8190]]


19994

In [28]:
#| hide
sample_target = get_mrna_sequence_with_introns(sample_chromosome, sample_chromosome_relationships.iloc[0, 2])
sample_target.count(",")

6166

In [29]:
#| export
def get_input_and_target_sequence(
        chromosome: SeqRecord,
        mrna_idx: int
) -> list[tuple[str, str]]:
    "3. Collect [3] and [4]."
    input_sequence = get_mrna_gene_sequence(
        chromosome,
        mrna_idx
    )
    target_sequence = get_mrna_sequence_with_introns(
        chromosome, 
        mrna_idx
    )
    return input_sequence, target_sequence

In [30]:
#| hide
sample_input, sample_target = get_input_and_target_sequence(
    sample_chromosome,
    sample_chromosome_relationships.iloc[0, 2]
)
sample_input.count(",") == sample_target.count(",")

True

In [31]:
#| hide
sample_input_and_targets = sample_chromosome_relationships.head(5).transcript_feature_idx.apply(
    lambda idx: get_input_and_target_sequence(
        sample_chromosome,
        idx
    )
)

In [32]:
#| hide
[(x[0].count(","), x[1].count(",")) for x in sample_input_and_targets]

[(6166, 6166), (17101, 19994), (938, 938), (25010, 29438), (44025, 57099)]

In [33]:
#| hide
sample_input_and_targets_df = pd.DataFrame(sample_input_and_targets.values.tolist(), columns=['input', 'target']).merge(
    sample_chromosome_relationships.head(5),
    left_index=True,
    right_index=True
    )
sample_input_and_targets_df.loc[:, 'input_len'] = sample_input_and_targets_df['input'].apply(lambda s: s.count(","))
sample_input_and_targets_df.loc[:, 'target_len'] = sample_input_and_targets_df['target'].apply(lambda s: s.count(","))

In [34]:
#| hide
sample_input_and_targets_df

Unnamed: 0,input,target,geneid,transcript_id,transcript_feature_idx,gene_feature_idx,input_len,target_len
0,"C,C,C,A,G,A,T,C,T,C,T,T,C,A,G,G,T,A,C,A,T,C,T,...","C,C,C,A,G,A,T,C,T,C,T,T,C,A,G,T,T,T,T,T,A,T,G,...",GeneID:79501,NM_001005484.2,22,21,6166,6166
1,"A,T,G,C,C,T,A,G,A,C,A,C,A,C,A,C,A,T,C,C,T,T,A,...","A,T,G,C,C,T,A,G,A,C,A,C,A,C,A,C,A,T,C,C,T,T,A,...",GeneID:112268260,XM_047436352.1,80,79,17101,19994
2,"A,T,G,G,A,T,G,G,A,G,A,G,A,A,T,C,A,C,T,C,A,G,T,...","A,T,G,G,A,T,G,G,A,G,A,G,A,A,T,C,A,C,T,C,A,G,T,...",GeneID:729759,NM_001005221.2,85,84,938,938
3,"A,T,G,C,G,T,A,G,A,C,A,C,A,C,A,C,A,T,C,C,T,T,A,...","A,T,G,C,G,T,A,G,A,C,A,C,A,C,A,C,A,T,C,C,T,T,A,...",GeneID:105378947,XM_011542538.1,96,95,25010,29438
4,"T,A,T,A,A,A,A,T,G,A,A,A,G,C,T,G,C,C,T,C,T,G,A,...","T,A,T,A,A,A,A,T,G,A,A,A,G,C,T,G,C,C,T,C,T,G,A,...",GeneID:81399,XM_024449992.2,126,125,44025,57099


In [35]:
#| export
def get_all_input_and_target_sequences(
        chromosome: SeqRecord,
        relationships: pd.DataFrame
):
    # Has to be processed sequentially
    
    pass

In [36]:
#| export
def write_input_and_output_sequences(
        
):
    "4. Write training instances to disk."

In [38]:
#| hide
import nbdev; nbdev.nbdev_export()