In [None]:
# adapt nuleotide_dependency_maps (NDMs) to Genomic Pre-trained Network (GPN)

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# gpn specific
import gpn.model
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

# ndm spcific
from transformers import DefaultDataCollator
from datasets import Dataset

model_path = "songlab/gpn-brassicales"

Example region: chr5:3566900-3567600

[UCSC Genome Browser view](https://genome.ucsc.edu/cgi-bin/hgTracks?db=hub_2660163_GCF_000001735.4&lastVirtModeType=default&lastVirtModeExtraState=&virtModeType=default&virtMode=0&nonVirtPosition=&position=chr5%3A3566900%2D3567600&hgsid=1597075775_CFnbwi2A0U0D8AuOgfJ0LsbUXnOb)

In [3]:
seq = "CGGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCTCATAAAACAATTTGTTGTAATCTATCTTTGGGCTAATGTTCTTATCCTACAAGACGAACCCTGACCGTATTCGTCGTAGAAAAAAAATTGCTTCGATCCCATCATTGAGTTCAATAATCGGCGCACAAAGGCCGATTCATAAAAACTCTAGGCCCATTAAAGTAAAGCCCATTCTCAACCCTATCCAGTCTCCCTGTATATATATATTTACGACACCAACCCAGCGTTGATATTTAATTTTCTTCAGTCAGAGATTTCGAAACCCTAGTCGATTTCGAGATCCAACTAACTCTGCTCCTTATCTCAGGTAAAATTCTCGCTCGAGAACTCAATTGCTTATCCAAAGTTCCAACTGAAGATGCTTTCCTACTGAATCTTAGGTTAATGTTTTGGATTTGGAATCTTACCCGAAATTTCTCTGCAGCTTGTTGAATTTGCGAAGTATGGGAGACGCTAGAGACAACGAAGCCTACGAGGAGGAGCTCTTGGACTATGAAGAAGAAGACGAGAAGGTCCCAGATTCTGGAAACAAAGTTAACGGCGAAGCTGTGAAAAAGTGAGTTTTATGGTTTCCTCGATATGTTTCATGTATACTACTGTGTGTTTAAATTTGTCGATTCTTAGATTACTACTTGATAACAAGTAGCAGTATGT"
len(seq)

700

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.get_vocab()

{'c': 4, 'a': 3, 'g': 5, '[PAD]': 0, 't': 6, '[UNK]': 2, '[MASK]': 1}

In [6]:
model = AutoModel.from_pretrained(model_path)
model.eval()

ConvNetModel(
  (embedding): GPNEmbedding()
  (encoder): Sequential(
    (0): ConvLayer(
      (conv): Sequential(
        (0): TransposeLayer()
        (1): Conv1d(512, 512, kernel_size=(9,), stride=(1,), padding=same)
        (2): TransposeLayer()
        (3): GELU(approximate='none')
        (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): ConvLayer(
      (conv): Sequential(
        (0): TransposeLayer()
        (1): Conv1d(512, 512, kernel_size=(9,), stride=(1,), padding=same, dilation=(2,))
        (2): TransposeLayer()
        (3): GELU(approximate='none')
        (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): GELU(approxima

In [12]:
# dependency map generation functions

nuc_table = {"A": 0, "C": 1, "G": 2, "T": 3}


def mutate_sequence(seq):

    seq = seq.upper()
    mutated_sequences = {"seq": [], "mutation_pos": [], "nuc": [], "var_nt_idx": []}
    mutated_sequences["seq"].append(seq)
    mutated_sequences["mutation_pos"].append(-1)
    mutated_sequences["nuc"].append("real sequence")
    mutated_sequences["var_nt_idx"].append(-1)

    mutate_until_position = len(seq)

    for i in range(mutate_until_position):
        for nuc in ["A", "C", "G", "T"]:
            if nuc != seq[i]:
                mutated_sequences["seq"].append(seq[:i] + nuc + seq[i + 1 :])
                mutated_sequences["mutation_pos"].append(i)
                mutated_sequences["nuc"].append(nuc)
                mutated_sequences["var_nt_idx"].append(nuc_table[nuc])

    mutations_df = pd.DataFrame(mutated_sequences)

    return mutations_df


def create_dataloader(dataset, batch_size=64, rolling_masking=False):

    ds = Dataset.from_pandas(dataset[["seq"]])
    tok_ds = ds.map(lambda x: tokenizer(list(x["seq"])), batched=False, num_proc=20)
    rem_tok_ds = tok_ds.remove_columns("seq")

    data_collator = DefaultDataCollator()

    data_loader = torch.utils.data.DataLoader(
        rem_tok_ds,
        batch_size=batch_size,
        num_workers=4,
        shuffle=False,
        collate_fn=data_collator,
    )

    return data_loader

In [10]:
dataset = mutate_sequence(seq)
dataset

Unnamed: 0,seq,mutation_pos,nuc,var_nt_idx
0,CGGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCT...,-1,real sequence,-1
1,AGGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCT...,0,A,0
2,GGGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCT...,0,G,2
3,TGGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCT...,0,T,3
4,CAGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCT...,1,A,0
...,...,...,...,...
2096,CGGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCT...,698,C,1
2097,CGGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCT...,698,T,3
2098,CGGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCT...,699,A,0
2099,CGGGTTAAAAATCTAGTTGTTATTATTAAAGGAAATAAAATATCCT...,699,C,1


In [22]:
ds = Dataset.from_pandas(dataset[["seq"]])
ds

Dataset({
    features: ['seq'],
    num_rows: 2101
})

In [23]:
tok_ds = ds.map(lambda x: tokenizer(list(x["seq"])), batched=False, num_proc=20)
tok_ds

Map (num_proc=20):   0%|          | 0/2101 [00:00<?, ? examples/s]

Dataset({
    features: ['seq', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 2101
})

In [18]:
data_loader = create_dataloader(dataset, batch_size=1)
data_loader

Map (num_proc=20):   0%|          | 0/2101 [00:00<?, ? examples/s]

<torch.utils.data.dataloader.DataLoader at 0x7f9881445b80>

In [21]:
next(iter(data_loader))['input_ids']

tensor([[[4],
         [5],
         [5],
         [5],
         [6],
         [6],
         [3],
         [3],
         [3],
         [3],
         [3],
         [6],
         [4],
         [6],
         [3],
         [5],
         [6],
         [6],
         [5],
         [6],
         [6],
         [3],
         [6],
         [6],
         [3],
         [6],
         [6],
         [3],
         [3],
         [3],
         [5],
         [5],
         [3],
         [3],
         [3],
         [6],
         [3],
         [3],
         [3],
         [3],
         [6],
         [3],
         [6],
         [4],
         [4],
         [6],
         [4],
         [3],
         [6],
         [3],
         [3],
         [3],
         [3],
         [4],
         [3],
         [3],
         [6],
         [6],
         [6],
         [5],
         [6],
         [6],
         [5],
         [6],
         [3],
         [3],
         [6],
         [4],
         [6],
         [3],
         [6],
      