# Making the Reference Genome more easily usable

> Preparing genomic DNA for pre-training.

## Notebook Goals

Identify input and output sequences to train a transformer to transcribe DNA to mRNA.

### Plan

1. For each chromosome, identify genes and mRNA features
2. Extract the input and output sequences for training our transformer, write to csv

First we have to explore our data to figure out how to do this.

### 0. Setup

In [2]:
#| default_exp data.transcription

In [3]:
#| hide
from nbdev.showdoc import *

In [4]:
#| export
import os
from pathlib import Path
from Bio import SeqIO, SeqRecord, SeqFeature
from Bio.SeqFeature import FeatureLocation
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import typing
from tqdm.notebook import tqdm

from llm_mito_scanner.data.download import load_config, \
    get_latest_assembly_path, get_genomic_genbank_path



In [5]:
#| hide
config = load_config()

In [6]:
#| hide
# Resolve the directories used below, all rooted at the configured "data_path"
data_path = Path(config.get("data_path"))
data_raw_path = data_path / "raw"
assemblies_path = data_raw_path / "assemblies"
# Helpers from llm_mito_scanner.data.download; presumably they pick the newest
# assembly directory and its genomic GenBank file -- confirm against that module
latest_assembly_path = get_latest_assembly_path(assemblies_path)
genomic_genbank_path = get_genomic_genbank_path(latest_assembly_path)
chromosomes_path = latest_assembly_path / "chromosomes"

### 1. For each chromosome identify gene and mRNA features.

In [7]:
#| hide
sample_chromosome_path = next(chromosomes_path.glob("*.gb"))
sample_chromosome_path

Path('/mnt/e/Data/llm-mito-scanner-data/data/raw/assemblies/GCF_000001405.40_GRCh38.p14/chromosomes/NC_000001.11.gb')

In [8]:
#| hide
# Parse only the first record of the GenBank file (None if the file holds no records)
with sample_chromosome_path.open("rt") as f:
    sample_chromosome = next(SeqIO.parse(f, "genbank"), None)
sample_chromosome

SeqRecord(seq=Seq('NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...NNN'), id='NC_000001.11', name='NC_000001', description='Homo sapiens chromosome 1, GRCh38.p14 Primary Assembly', dbxrefs=['BioProject:PRJNA168', 'Assembly:GCF_000001405.40'])

In [9]:
#| hide
len(sample_chromosome.features)

51489

In [10]:
#| hide
# Walk the features, remembering the most recent gene and mRNA seen so far, and
# stop as soon as one of each has been found
sample_chromosome_feature_gene = None
sample_chromosome_feature_mrna = None
for f in sample_chromosome.features:
    if f.type == "gene":
        sample_chromosome_feature_gene = f
    if f.type == 'mRNA':
        sample_chromosome_feature_mrna = f
    if sample_chromosome_feature_gene is not None and sample_chromosome_feature_mrna is not None:
        break

In [11]:
#| hide
sample_chromosome_feature_gene.__dict__

{'location': SimpleLocation(ExactPosition(65418), ExactPosition(71585), strand=1),
 'type': 'gene',
 'id': '<unknown id>',
 'qualifiers': {'gene': ['OR4F5'],
  'note': ['olfactory receptor family 4 subfamily F member 5; Derived by automated computational analysis using gene prediction method: BestRefSeq.'],
  'db_xref': ['GeneID:79501', 'HGNC:HGNC:14825']}}

In [12]:
#| hide
sample_chromosome_feature_mrna.__dict__

{'location': CompoundLocation([SimpleLocation(ExactPosition(65418), ExactPosition(65433), strand=1), SimpleLocation(ExactPosition(65519), ExactPosition(65573), strand=1), SimpleLocation(ExactPosition(69036), ExactPosition(71585), strand=1)], 'join'),
 'type': 'mRNA',
 'id': '<unknown id>',
 'qualifiers': {'gene': ['OR4F5'],
  'product': ['olfactory receptor family 4 subfamily F member 5'],
  'note': ['Derived by automated computational analysis using gene prediction method: BestRefSeq.'],
  'transcript_id': ['NM_001005484.2'],
  'db_xref': ['Ensembl:ENST00000641515.2', 'GeneID:79501', 'HGNC:HGNC:14825']}}

In [13]:
#| export
def get_feature_qualifiers(
        feature: "SeqFeature.SeqFeature"
        ) -> typing.Optional[typing.Dict[str, typing.Any]]:
    """Return the `qualifiers` mapping of `feature`, or `None` if it has none.

    The lookup is attribute based, so any object without a `qualifiers`
    attribute yields `None` instead of raising.
    """
    return getattr(feature, "qualifiers", None)


def get_feature_dbxrefs(feature: SeqFeature):
    """Return the `db_xref` qualifier of `feature`.

    Gives `None` when the feature has no qualifiers or no `db_xref` entry.
    """
    qualifiers = get_feature_qualifiers(feature)
    if qualifiers is not None:
        return qualifiers.get("db_xref")
    return None


def get_feature_dbxref_xref(feature: SeqFeature, prefix: str):
    """Return the first `db_xref` entry of `feature` that starts with `prefix`.

    Gives `None` when the feature has no `db_xref` entries or none of them match.
    """
    dbxrefs = get_feature_dbxrefs(feature)
    if dbxrefs is None:
        return None
    matching = [xref for xref in dbxrefs if xref.startswith(prefix)]
    return matching[0] if matching else None


def get_feature_geneid(feature: SeqFeature):
    """Return the first `db_xref` entry of `feature` starting with "GeneID" (e.g. 'GeneID:79501'), or `None`."""
    gene_id_prefix = "GeneID"
    return get_feature_dbxref_xref(feature, prefix=gene_id_prefix)


def get_feature_transcript_id(
        feature: "SeqFeature.SeqFeature"
        ) -> typing.Optional[str]:
    """Return the first `transcript_id` qualifier of `feature`, or `None`.

    `None` is returned when the feature has no qualifiers at all, no
    `transcript_id` qualifier, or an empty one.
    """
    feature_qualifiers = get_feature_qualifiers(feature)
    # Same guard as the db_xref helpers: a feature without qualifiers has no id
    if feature_qualifiers is None:
        return None
    return next(iter(feature_qualifiers.get("transcript_id", [])), None)

In [14]:
#| hide
get_feature_geneid(sample_chromosome_feature_mrna)

'GeneID:79501'

#### 1.1 Plan

Our strategy will be:

1. Read through the features, collect every gene and mRNA
2. Identify for every gene the child mRNA
3. For every gene, mRNA pair:
    1. Extract input DNA sequence from gene
    2. Extract mRNA sequence
    3. Fill gaps in mRNA sequence with intron tokens
    4. Write to disk as CSV

#### 1.1.1 Collect gene and mRNA features

In [15]:
#| export
def get_gene_and_mrna_features(
        record: "SeqRecord.SeqRecord"
        ) -> typing.Tuple[
            typing.List[typing.Tuple[int, typing.Any]],
            typing.List[typing.Tuple[int, typing.Any]]
        ]:
    """Collect the gene and mRNA features of `record`.

    Returns two lists, `(gene_features, mrna_features)`. Each entry is an
    `(index, feature)` tuple where `index` is the feature's position in
    `record.features`. Features of any other type are ignored.
    """
    gene_features = []
    mrna_features = []
    # Route each feature to its list by exact type name
    buckets = {"gene": gene_features, "mRNA": mrna_features}
    for i, f in enumerate(record.features):
        bucket = buckets.get(f.type)
        if bucket is not None:
            bucket.append((i, f))
    return gene_features, mrna_features

In [16]:
#| hide
gene_features, mrna_features = get_gene_and_mrna_features(sample_chromosome)
len(gene_features), len(mrna_features)

(5501, 12841)

In [17]:
#| hide
gene_features[0], mrna_features[0]

((2,
  SeqFeature(SimpleLocation(ExactPosition(11873), ExactPosition(14409), strand=1), type='gene', qualifiers=...)),
 (22,
  SeqFeature(CompoundLocation([SimpleLocation(ExactPosition(65418), ExactPosition(65433), strand=1), SimpleLocation(ExactPosition(65519), ExactPosition(65573), strand=1), SimpleLocation(ExactPosition(69036), ExactPosition(71585), strand=1)], 'join'), type='mRNA', location_operator='join', qualifiers=...)))

#### 1.1.2 Identify gene to mRNA feature relationships

In [18]:
#| export
def get_gene_and_mrna_relationships(
        gene_features,
        mrna_features
        ) -> list[typing.Tuple[int, int]]:
    """Pair every mRNA feature with the gene feature that shares its GeneID.

    Both arguments are lists of `(index, feature)` tuples as produced by
    `get_gene_and_mrna_features`. The result is a list of
    `(mrna_index, gene_index)` tuples, in the order of `mrna_features`.
    mRNAs whose GeneID matches no gene are dropped. If several genes share
    a GeneID, the last one in `gene_features` is used.
    """
    # Features without a GeneID are kept out of the index; a `None` key would
    # otherwise pair un-annotated mRNAs with an unrelated un-annotated gene
    gene_index = {}
    for gene_idx, gene in gene_features:
        gene_id = get_feature_geneid(gene)
        if gene_id is not None:
            gene_index[gene_id] = gene_idx
    relationships = []
    for mrna_idx, mrna in mrna_features:
        gene_id = get_feature_geneid(mrna)
        if gene_id is None:
            continue
        gene_idx = gene_index.get(gene_id)
        if gene_idx is not None:
            relationships.append((mrna_idx, gene_idx))
    return relationships

In [19]:
#| hide
mrna_to_protein_relationships = get_gene_and_mrna_relationships(
    gene_features, mrna_features
)
len(mrna_to_protein_relationships), mrna_to_protein_relationships[0]

(12841, (22, 21))

### 2. For every gene, mRNA pair:

1. Extract input DNA sequence from gene
2. Extract mRNA sequence
3. Fill gaps in mRNA sequence with intron tokens

In [20]:
#| hide
len(mrna_features[0][1].extract(sample_chromosome).seq)

2618

In [21]:
#| hide
# Sum the lengths of the individual mRNA location parts, to compare with the
# length of the extracted sequence
sample_len = 0
for p in mrna_features[0][1].location.parts:
    sample_len += (p.end - p.start)
sample_len

2618

In [22]:
#| hide
mrna_features[0][1].location.parts[0].start

ExactPosition(65418)

In [23]:
#| hide
mrna_features[0][1].strand

1

In [24]:
#| hide
"<intron>" * 3

'<intron><intron><intron>'

In [25]:
#| export
def process_gene_mrna_pair(
        chromosome: "SeqRecord.SeqRecord",
        feature_gene: "SeqFeature.SeqFeature",
        feature_mrna: "SeqFeature.SeqFeature",
        intron_token: str = "<intron>"
        ) -> pd.Series:
    """Build the comma separated gene and mRNA token strings for one pair.

    Returns a `pd.Series` with two entries:

    - `"gene"`: every base of `feature_gene.extract(chromosome)`.
    - `"mrna"`: the bases of each mRNA location part, with one `intron_token`
      per base of the gap between consecutive parts.

    The parts are walked in the order `feature_mrna.location.parts` stores
    them. Biopython keeps that order biological (descending coordinates for
    a reverse-strand `complement(join(...))`), so the exons line up with the
    strand-aware gene sequence on either strand.

    NOTE(review): no tokens are added for gene sequence lying before the first
    or after the last mRNA part, so the two strings differ in length when the
    transcript does not span the whole gene -- confirm this is intended.
    """
    gene_sequence = feature_gene.extract(chromosome)
    mrna_sequence = []
    previous_start = None
    previous_end = None
    for part in feature_mrna.location.parts:
        start = int(part.start)
        end = int(part.end)
        if previous_end is not None:
            if start >= previous_end:
                # Ascending coordinates: the gap follows the previous part
                intron_size = start - previous_end
            else:
                # Descending coordinates (reverse strand): the gap precedes it
                intron_size = previous_start - end
            # Overlapping parts have no gap
            mrna_sequence.extend([intron_token] * max(intron_size, 0))
        # extract() already reverse-complements parts on the reverse strand
        mrna_sequence.extend(str(part.extract(chromosome).seq))
        previous_start, previous_end = start, end
    return pd.Series({
        "gene": ",".join(str(gene_sequence.seq)),
        "mrna": ",".join(mrna_sequence),
    })

### 3. Process data efficiently

1. Shard the data across n processes
    1. Chromosome (Path)
    2. gene_feature_idx
    3. mrna_feature_idx
2. Each process:
    1. Reads the source chromosome from disk (SeqRecord)
    2. Processes each (gene, mRNA) feature pair in its shard

In [26]:
#| export
def shard_features(relationships, processes: typing.Optional[int] = None):
    """Split `relationships` into `processes` nearly equal shards.

    `processes` defaults to one less than the CPU count, and never less than
    one. The default is resolved at call time, so importing this module cannot
    fail when the CPU count is unknown or the machine has a single core.

    Returns the list of arrays produced by `np.array_split`.
    """
    if processes is None:
        # os.cpu_count() may return None; keep at least one shard
        processes = max(1, (os.cpu_count() or 1) - 1)
    return np.array_split(relationships, processes)

In [27]:
#| hide
sharded_relationships = shard_features(mrna_to_protein_relationships)
len(sharded_relationships), len(sharded_relationships[0])

(11, 1168)

In [28]:
#| hide
len(mrna_to_protein_relationships)

12841

Let's bring this into a script so we can leverage proper processing.

In [29]:
#| hide
import nbdev; nbdev.nbdev_export()