# Making the Reference Genome more easily usable

> Preparing genomic DNA for pre-training.

## Notebook Goals

Identify input and output sequences to train a transformer to transcribe DNA to mRNA.

### Plan

1. For each chromosome, identify genes and mRNA features
2. Extract the input and output sequences for training our transformer, write to csv

First we have to explore our data to figure out how to do this.

### 0. Setup

In [15]:
#| default_exp data.transcription

In [16]:
#| hide
from nbdev.showdoc import *

In [17]:
#| export
import os
from pathlib import Path
from Bio import SeqIO, SeqRecord, SeqFeature
from Bio.SeqFeature import FeatureLocation
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import typing
from tqdm.notebook import tqdm

from llm_mito_scanner.data.download import load_config, \
    get_latest_assembly_path, get_genomic_genbank_path

In [18]:
#| hide
config = load_config()

In [34]:
#| hide
# Resolve the directory layout underneath the configured data path.
data_path = Path(config.get("data_path"))
data_raw_path = data_path / "raw"
assemblies_path = data_raw_path / "assemblies"
latest_assembly_path = get_latest_assembly_path(assemblies_path)
genomic_genbank_path = get_genomic_genbank_path(latest_assembly_path)
chromosomes_path = latest_assembly_path / "chromosomes"
training_data_path = latest_assembly_path / "training"
transcription_data_path = training_data_path / "transcription"
# exist_ok avoids the check-then-create race of a separate exists() test.
transcription_data_path.mkdir(parents=True, exist_ok=True)

### 1. For each chromosome identify gene and mRNA features.

In [20]:
#| hide
sample_chromosome_path = next(chromosomes_path.glob("*.gb"))
sample_chromosome_path

Path('/mnt/e/Data/llm-mito-scanner-data/data/raw/assemblies/GCF_000001405.40_GRCh38.p14/chromosomes/NC_000001.11.gb')

In [21]:
#| hide
with sample_chromosome_path.open("rt") as f:
    sample_chromosome = next(SeqIO.parse(f, "genbank"), None)
sample_chromosome

SeqRecord(seq=Seq('NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...NNN'), id='NC_000001.11', name='NC_000001', description='Homo sapiens chromosome 1, GRCh38.p14 Primary Assembly', dbxrefs=['BioProject:PRJNA168', 'Assembly:GCF_000001405.40'])

In [22]:
#| hide
len(sample_chromosome.features)

51489

In [23]:
#| hide
# Walk the features in order, remembering the most recent gene and mRNA seen,
# and stop as soon as one of each has been found.
sample_chromosome_feature_gene = sample_chromosome_feature_mrna = None
for f in sample_chromosome.features:
    if f.type == "gene":
        sample_chromosome_feature_gene = f
    elif f.type == 'mRNA':
        sample_chromosome_feature_mrna = f
    still_missing = (
        sample_chromosome_feature_gene is None
        or sample_chromosome_feature_mrna is None
    )
    if not still_missing:
        break

In [24]:
#| hide
sample_chromosome_feature_gene.__dict__

{'location': SimpleLocation(ExactPosition(65418), ExactPosition(71585), strand=1),
 'type': 'gene',
 'id': '<unknown id>',
 'qualifiers': {'gene': ['OR4F5'],
  'note': ['olfactory receptor family 4 subfamily F member 5; Derived by automated computational analysis using gene prediction method: BestRefSeq.'],
  'db_xref': ['GeneID:79501', 'HGNC:HGNC:14825']}}

In [25]:
#| hide
sample_chromosome_feature_mrna.__dict__

{'location': CompoundLocation([SimpleLocation(ExactPosition(65418), ExactPosition(65433), strand=1), SimpleLocation(ExactPosition(65519), ExactPosition(65573), strand=1), SimpleLocation(ExactPosition(69036), ExactPosition(71585), strand=1)], 'join'),
 'type': 'mRNA',
 'id': '<unknown id>',
 'qualifiers': {'gene': ['OR4F5'],
  'product': ['olfactory receptor family 4 subfamily F member 5'],
  'note': ['Derived by automated computational analysis using gene prediction method: BestRefSeq.'],
  'transcript_id': ['NM_001005484.2'],
  'db_xref': ['Ensembl:ENST00000641515.2', 'GeneID:79501', 'HGNC:HGNC:14825']}}

In [26]:
#| export
def get_feature_qualifiers(feature: SeqFeature) -> typing.Optional[typing.Dict[str, typing.Any]]:
    """Return the feature's `qualifiers` attribute, or `None` if the feature has none."""
    return getattr(feature, "qualifiers", None)


def get_feature_dbxrefs(feature: SeqFeature):
    """Return the `db_xref` qualifier of a feature.

    Yields `None` when the feature exposes no qualifiers or when the
    qualifiers carry no `db_xref` entry.
    """
    qualifiers = get_feature_qualifiers(feature)
    return None if qualifiers is None else qualifiers.get("db_xref", None)


def get_feature_dbxref_xref(feature: SeqFeature, prefix: str):
    """Return the first `db_xref` entry of a feature that starts with `prefix`.

    Returns `None` when the feature has no `db_xref` entries or when none of
    them begins with `prefix`.
    """
    xrefs = get_feature_dbxrefs(feature)
    if xrefs is None:
        return None
    for xref in xrefs:
        if xref.startswith(prefix):
            return xref
    return None


def get_feature_geneid(feature: SeqFeature) -> typing.Optional[str]:
    """Return the feature's first `db_xref` entry starting with "GeneID", or `None`.

    The full cross-reference string is returned, prefix included
    (e.g. 'GeneID:79501').
    """
    return get_feature_dbxref_xref(feature, "GeneID")


def get_feature_transcript_id(feature: SeqFeature):
    """Return the first `transcript_id` qualifier value of a feature, or `None`.

    `None` is returned when the feature has no qualifiers, no
    `transcript_id` qualifier, or an empty one. This matches the other
    qualifier helpers, which return `None` rather than raising.
    """
    # get_feature_qualifiers may return None; fall back to an empty mapping
    # instead of calling .get on None.
    qualifiers = get_feature_qualifiers(feature) or {}
    transcript_ids = qualifiers.get("transcript_id") or []
    return next(iter(transcript_ids), None)

In [27]:
#| hide
get_feature_geneid(sample_chromosome_feature_mrna)

'GeneID:79501'

#### 1.1 Plan

Our strategy will be:

1. Read through the features, collect every gene and mRNA feature
2. For every gene, identify its child mRNAs and write the relationships to disk
3. Shard gene, mRNA pairs by chromosome
4. Within each shard
    1. Read the chromosome
    2. Extract input DNA sequence from gene
    3. Extract mRNA sequence
    4. Fill gaps in mRNA sequence with intron tokens
    5. Write to disk as CSV

#### 1.1.1 Collect gene and mRNA features

In [28]:
#| export
def get_gene_and_mrna_features(record: SeqRecord):
    """Collect the gene and mRNA features of a record.

    Returns two lists, `(gene_features, mrna_features)`. Each element is an
    `(index, feature)` tuple, where `index` is the feature's position in
    `record.features`. Features of any other type are ignored.
    """
    by_type = {"gene": [], "mRNA": []}
    for position, feature in enumerate(record.features):
        bucket = by_type.get(feature.type)
        if bucket is not None:
            bucket.append((position, feature))
    return by_type["gene"], by_type["mRNA"]

In [29]:
#| hide
gene_features, mrna_features = get_gene_and_mrna_features(sample_chromosome)
len(gene_features), len(mrna_features)

(5501, 12841)

In [30]:
#| hide
gene_features[0], mrna_features[0]

((2,
  SeqFeature(SimpleLocation(ExactPosition(11873), ExactPosition(14409), strand=1), type='gene', qualifiers=...)),
 (22,
  SeqFeature(CompoundLocation([SimpleLocation(ExactPosition(65418), ExactPosition(65433), strand=1), SimpleLocation(ExactPosition(65519), ExactPosition(65573), strand=1), SimpleLocation(ExactPosition(69036), ExactPosition(71585), strand=1)], 'join'), type='mRNA', location_operator='join', qualifiers=...)))

#### 1.1.2 Identify gene to mRNA feature relationships

In [31]:
#| export
def get_gene_and_mrna_relationships(
        gene_features,
        mrna_features
        ) -> list[typing.Tuple[int, int]]:
    """Pair each mRNA feature with the gene feature sharing its GeneID.

    `gene_features` and `mrna_features` are lists of `(index, feature)`
    tuples as produced by `get_gene_and_mrna_features`.

    Returns a list of `(mrna_feature_index, gene_feature_index)` tuples.
    mRNA features without a GeneID, or whose GeneID matches no gene, are
    skipped. If several genes share a GeneID, the last one listed wins.
    """
    # Index genes by GeneID. Genes lacking a GeneID are left out so that
    # `None` never becomes a lookup key that unrelated mRNAs could match.
    gene_index = {}
    for gene_idx, gene_feature in gene_features:
        gene_id = get_feature_geneid(gene_feature)
        if gene_id is not None:
            gene_index[gene_id] = gene_idx

    relationships = []
    for mrna_idx, mrna_feature in mrna_features:
        mrna_gene_id = get_feature_geneid(mrna_feature)
        if mrna_gene_id is None:
            continue
        matched_gene_idx = gene_index.get(mrna_gene_id)
        if matched_gene_idx is not None:
            relationships.append((mrna_idx, matched_gene_idx))
    return relationships

In [32]:
#| hide
mrna_to_protein_relationships = get_gene_and_mrna_relationships(
    gene_features, mrna_features
)
len(mrna_to_protein_relationships), mrna_to_protein_relationships[0]

(12841, (22, 21))

In [33]:
#| hide
mrna_to_protein_relationships[0]

(22, 21)

In [38]:
#| hide
# Tabulate the sample chromosome's (mRNA, gene) index pairs, tag them with the
# chromosome file stem, and persist them as a CSV.
relationship_columns = ['feature_idx_mrna', 'feature_idx_gene']
relationship_df = pd.DataFrame(
    mrna_to_protein_relationships,
    columns=relationship_columns,
).assign(chromosome=sample_chromosome_path.stem)
relationship_csv_path = transcription_data_path / "relationships.csv"
relationship_df.to_csv(relationship_csv_path, index=False)

In [62]:
#| export
def extract_all_relationships(assembly_path: Path):
    """Write one gene/mRNA relationship CSV per chromosome GenBank file.

    Every `*.gb` file under `assembly_path / "chromosomes"` is parsed, and
    only its first record is used. The (mRNA index, gene index) pairs found
    in that record are written to
    `assembly_path / "training" / "transcription" /
    "mrna_to_protein_feature_idx" / "<file stem>.csv"` with the columns
    `feature_idx_mrna`, `feature_idx_gene` and `chromosome`.

    Files that contain no record, or that yield no relationships, produce
    no CSV.
    """
    chromosomes_path = assembly_path / "chromosomes"
    training_data_path = assembly_path / "training"
    transcription_data_path = training_data_path / "transcription"
    relationships_path = transcription_data_path / "mrna_to_protein_feature_idx"

    # exist_ok avoids the check-then-create race of a separate exists() test.
    relationships_path.mkdir(parents=True, exist_ok=True)

    # Sorted so chromosomes are processed in a reproducible order.
    chromosome_files = sorted(chromosomes_path.glob("*.gb"))

    for chromosome_path in tqdm(chromosome_files, leave=False, ncols=80):
        with chromosome_path.open("rt") as f:
            chromosome_seq_record = next(SeqIO.parse(f, "genbank"), None)
        if chromosome_seq_record is None:
            # The parser yielded nothing (e.g. an empty file); skip it
            # rather than failing on `None.features`.
            continue
        gene_features, mrna_features = get_gene_and_mrna_features(chromosome_seq_record)
        chromosome_mrna_to_protein_relationships = get_gene_and_mrna_relationships(
            gene_features, mrna_features
        )
        if not chromosome_mrna_to_protein_relationships:
            continue
        relationship_df = pd.DataFrame(
            chromosome_mrna_to_protein_relationships,
            columns=['feature_idx_mrna', 'feature_idx_gene']
        )
        relationship_df['chromosome'] = chromosome_path.stem
        chromosome_relationship_path = relationships_path / f"{chromosome_path.stem}.csv"
        relationship_df.to_csv(chromosome_relationship_path, index=False)


#### 1.1.3 Shard relationships

In [68]:
#| export
def get_all_relationships(assembly_path: Path):
    """Load every per-chromosome relationship CSV into one DataFrame.

    Reads all `*.csv` files under
    `assembly_path / "training" / "transcription" /
    "mrna_to_protein_feature_idx"`, concatenates them in file-name order
    and drops duplicate rows.

    Raises:
        ValueError: if the directory contains no CSV files.
    """
    training_data_path = assembly_path / "training"
    transcription_data_path = training_data_path / "transcription"
    relationships_path = transcription_data_path / "mrna_to_protein_feature_idx"
    # glob() order is filesystem dependent; sort so row order is reproducible.
    relationship_files = sorted(relationships_path.glob("*.csv"))
    if not relationship_files:
        # pd.concat would raise an opaque "No objects to concatenate" here;
        # raise the same exception type with an actionable message instead.
        raise ValueError(f"No relationship CSV files found in {relationships_path}")
    all_relationships = pd.concat(
        [pd.read_csv(path) for path in relationship_files],
        axis=0,
        ignore_index=True
    ).drop_duplicates()
    return all_relationships


def shard_relationships(
        relationship_df: pd.DataFrame,
        shard_count: int = max(1, (os.cpu_count() or 2) - 1)
        ):
    """Split `relationship_df` row-wise into `shard_count` contiguous shards.

    Shard sizes differ by at most one row, with the larger shards first.
    When there are fewer rows than shards, the trailing shards are empty.
    The default shard count is one less than the CPU count, and never
    below 1 (including when the CPU count cannot be determined).

    Returns a list of `shard_count` DataFrames that together cover every
    row of `relationship_df` in its original order.

    Raises:
        ValueError: if `shard_count` is less than 1.
    """
    if shard_count < 1:
        raise ValueError("shard_count must be at least 1")
    # Slice by position rather than passing the DataFrame to np.array_split,
    # which goes through the deprecated DataFrame.swapaxes.
    base_size, remainder = divmod(len(relationship_df), shard_count)
    shards = []
    start = 0
    for shard_number in range(shard_count):
        stop = start + base_size + (1 if shard_number < remainder else 0)
        shards.append(relationship_df.iloc[start:stop])
        start = stop
    return shards

In [70]:
#| hide
all_relationships = get_all_relationships(latest_assembly_path)
all_relationships.head()

Unnamed: 0,feature_idx_mrna,feature_idx_gene,chromosome
0,22,21,NC_000001.11
1,80,79,NC_000001.11
2,85,84,NC_000001.11
3,96,95,NC_000001.11
4,126,125,NC_000001.11


In [71]:
#| hide
all_relationships.chromosome.value_counts()

chromosome
NC_000001.11    12841
NC_000002.12     9518
NC_000003.12     8490
NC_000019.10     7521
NC_000011.10     7519
NC_000017.11     7329
NC_000012.12     7287
NC_000006.12     6155
NC_000010.11     6151
NC_000007.14     5905
NC_000005.10     5561
NC_000009.12     5504
NC_000004.12     5351
NC_000016.10     5235
NC_000008.11     4964
NC_000023.11     4296
NC_000015.10     4255
NC_000014.9      3959
NC_000020.11     3089
NC_000022.11     2721
NC_000018.10     2521
NC_000013.11     2471
NC_000021.9      1411
NC_000024.10      366
Name: count, dtype: int64

In [72]:
#| hide
shards = shard_relationships(all_relationships)
shards[0].head()

Unnamed: 0,feature_idx_mrna,feature_idx_gene,chromosome
0,22,21,NC_000001.11
1,80,79,NC_000001.11
2,85,84,NC_000001.11
3,96,95,NC_000001.11
4,126,125,NC_000001.11


### 2. Process Relationships


For every relationship:
1. Extract input DNA sequence from gene
2. Extract mRNA sequence
3. Fill gaps in mRNA sequence with intron tokens

In [74]:
#| export
def process_gene_mrna_pair(
        chromosome: SeqRecord,
        feature_gene_idx: int, 
        feature_mrna_idx: int,
        intron_token: str = "<intron>"
        ) -> pd.Series:
    """Build one (gene, mRNA) training pair from a chromosome record.

    Args:
        chromosome: record holding the sequence and its features.
        feature_gene_idx: index of the gene feature in `chromosome.features`.
        feature_mrna_idx: index of the mRNA feature in `chromosome.features`.
        intron_token: token emitted once per base of every gap between
            consecutive mRNA parts.

    Returns:
        A Series with two comma-joined token strings: `gene`, the extracted
        gene sequence, and `mrna`, the extracted mRNA parts with
        `intron_token` filling the gaps between them. Both are in
        transcription order. The `mrna` string starts at the first mRNA
        part and is not padded out to the gene boundaries.
    """
    feature_gene = chromosome.features[feature_gene_idx]
    feature_mrna = chromosome.features[feature_mrna_idx]
    gene_sequence = feature_gene.extract(chromosome)

    mrna_location = feature_mrna.location
    # Biopython already lists the parts of a complement(join(...)) location
    # in transcription order, and each part extracts as its reverse
    # complement. The parts are therefore walked as given, so they line up
    # with the extracted gene sequence on either strand; only the direction
    # in which the gap is measured depends on the strand.
    on_reverse_strand = mrna_location.strand == -1

    mrna_tokens = []
    previous_part = None
    for part in mrna_location.parts:
        if previous_part is not None:
            if on_reverse_strand:
                # Coordinates decrease along a reverse-strand transcript.
                intron_size = int(previous_part.start) - int(part.end)
            else:
                intron_size = int(part.start) - int(previous_part.end)
            # Overlapping or out-of-order parts give a non-positive gap and
            # contribute no intron tokens.
            mrna_tokens.extend([intron_token] * max(intron_size, 0))
        mrna_tokens.extend(str(part.extract(chromosome).seq))
        previous_part = part

    return pd.Series(
        {
            "gene": ",".join(str(gene_sequence.seq)),
            "mrna": ",".join(mrna_tokens),
        }
    )

In [75]:
#| hide
all_relationships.head()

Unnamed: 0,feature_idx_mrna,feature_idx_gene,chromosome
0,22,21,NC_000001.11
1,80,79,NC_000001.11
2,85,84,NC_000001.11
3,96,95,NC_000001.11
4,126,125,NC_000001.11


In [76]:
#| hide
process_gene_mrna_pair(
    sample_chromosome, 
    21,
    22
    )

gene    C,C,C,A,G,A,T,C,T,C,T,T,C,A,G,G,T,A,C,A,T,C,T,...
mrna    C,C,C,A,G,A,T,C,T,C,T,T,C,A,G,<intron>,<intron...
dtype: object

Let's bring this into a script so we can leverage proper processing.

In [29]:
#| hide
import nbdev; nbdev.nbdev_export()