In [1]:
import json
import os
import sys
sys.path.append('../data_generation/')
import utils
import subprocess
import torch
import glob
from transformers import PreTrainedModel
import re
from standalone_hyenadna import HyenaDNAModel
from standalone_hyenadna import CharacterTokenizer
import numpy as np
from tqdm import tqdm
import math
import h5py 
import scipy.stats
# Set CUDA_VISIBLE_DEVICES to '3' for this process. The index is specific to
# the machine this notebook was run on; change or remove it elsewhere.
# NOTE(review): CUDA presumably reads this variable when it is first
# initialised, so this line should stay ahead of the first `.to('cuda')` call
# -- confirm.
os.environ['CUDA_VISIBLE_DEVICES']='3'

def inject_substring(orig_str):
    """Insert a ``.layer`` component after each ``.mixer`` and ``.mlp`` in a key.

    Hack to handle matching keys between models trained with and without
    gradient checkpointing.

    Args:
        orig_str: The state-dict key to rewrite.

    Returns:
        A new string in which every ``.mixer`` has become ``.mixer.layer`` and,
        afterwards, every ``.mlp`` has become ``.mlp.layer``.
    """
    # (regex, replacement) pairs, applied in order: mixer keys first, then mlp.
    rewrites = (
        (r"\.mixer", ".mixer.layer"),
        (r"\.mlp", ".mlp.layer"),
    )

    rewritten = orig_str
    for pattern, injection in rewrites:
        rewritten = re.sub(pattern, injection, rewritten)
    return rewritten

# helper 2
def load_weights(scratch_dict, pretrained_dict, checkpointing=False):
    """Load pretrained (backbone only) weights into the scratch state dict.

    Every key of ``scratch_dict`` that contains ``'backbone'`` is overwritten
    in place with the matching entry of ``pretrained_dict``. Keys without
    ``'backbone'`` are left as they are.

    Args:
        scratch_dict: State dict of the freshly built model. Mutated in place.
        pretrained_dict: State dict of the loaded checkpoint. Its keys carry
            one extra leading ``'model.'`` prefix compared with
            ``scratch_dict``.
        checkpointing: When true, each looked-up key is additionally passed
            through ``inject_substring`` to add the extra ``.layer`` component.

    Returns:
        ``scratch_dict`` itself, after the update.

    Raises:
        KeyError: If a backbone key has no counterpart in ``pretrained_dict``.
            The message names both the scratch key and the key that was looked
            up, and the original lookup error is chained as ``__cause__``.
    """
    # Iterate over a snapshot of the keys: values are reassigned in the loop.
    for key in list(scratch_dict):
        if 'backbone' not in key:
            continue

        # the state dicts differ by one prefix, 'model.', so we add that
        key_loaded = 'model.' + key
        # need to add an extra ".layer" in key
        if checkpointing:
            key_loaded = inject_substring(key_loaded)

        try:
            scratch_dict[key] = pretrained_dict[key_loaded]
        except KeyError as err:
            # Only a missing key is a "mismatch"; anything else propagates
            # unchanged instead of being masked by a generic error.
            raise KeyError(
                f"key mismatch in the state dicts! '{key}' expects "
                f"'{key_loaded}', which is missing from the pretrained state dict"
            ) from err

    # scratch_dict has been updated
    return scratch_dict

class HyenaDNAPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.

    The useful entry point is the ``from_pretrained`` classmethod. It builds
    and returns a ``HyenaDNAModel``; it does not create an instance of this
    class.
    """
    base_model_prefix = "hyenadna"

    def __init__(self, config):
        # No-op: PreTrainedModel.__init__ is not called and no attribute
        # (including ``self.model``) is set here.
        pass

    def forward(self, input_ids, **kwargs):
        """Delegate to ``self.model``.

        NOTE(review): nothing in this class assigns ``self.model``, so this
        only works if a caller sets that attribute first.
        """
        return self.model(input_ids, **kwargs)

    @classmethod
    def from_pretrained(cls,
                        path,
                        model_name,
                        download=False,
                        config=None,
                        device='cpu',
                        use_head=False,
                        n_classes=2,
                      ):
        """Build a ``HyenaDNAModel`` and load pretrained backbone weights into it.

        Args:
            path: Directory that holds (or will hold) the checkpoint folder.
            model_name: Name of the checkpoint folder inside ``path``; also the
                repository name under ``https://huggingface.co/LongSafari/``.
            download: If true, or if ``path/model_name`` is not an existing
                directory, that directory is deleted and the repository is
                cloned again with git (git-lfs is installed first).
            config: Keyword arguments for ``HyenaDNAModel``. When ``None`` they
                are read from ``config.json`` in the checkpoint folder.
            device: Device string passed to ``torch.load`` as ``map_location``.
            use_head: Forwarded to ``HyenaDNAModel``.
            n_classes: Forwarded to ``HyenaDNAModel``.

        Returns:
            The ``HyenaDNAModel`` instance with the pretrained weights loaded.

        Raises:
            subprocess.CalledProcessError: If removing the old folder or any
                git command fails during a download.
            KeyError: Propagated from ``load_weights`` when a backbone key is
                missing from the checkpoint.
        """
        pretrained_model_name_or_path = os.path.join(path, model_name)

        # Use the local copy when it exists, unless a fresh download is requested.
        if download or not os.path.isdir(pretrained_model_name_or_path):
            hf_url = f'https://huggingface.co/LongSafari/{model_name}'

            # Argument lists instead of shell strings: paths containing spaces
            # or shell metacharacters are passed through literally. check=True
            # makes a failed step raise here rather than surfacing later as a
            # missing config.json.
            subprocess.run(['rm', '-rf', '--', pretrained_model_name_or_path], check=True)
            os.makedirs(path, exist_ok=True)
            subprocess.run(['git', 'lfs', 'install'], cwd=path, check=True)
            subprocess.run(['git', 'clone', hf_url], cwd=path, check=True)

        if config is None:
            config_path = os.path.join(pretrained_model_name_or_path, 'config.json')
            # Context manager so the file handle is closed deterministically.
            with open(config_path) as config_file:
                config = json.load(config_file)

        scratch_model = HyenaDNAModel(**config, use_head=use_head, n_classes=n_classes)  # the new model format

        # SECURITY: torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        loaded_ckpt = torch.load(
            os.path.join(pretrained_model_name_or_path, 'weights.ckpt'),
            map_location=torch.device(device)
        )

        # need to load weights slightly different if using gradient checkpointing
        # The original condition tested "checkpoint_mixer" twice; this single
        # comparison is equivalent.
        # NOTE(review): the duplicate may have been meant to test a second
        # config flag (e.g. one for the mlp) -- confirm against the config schema.
        checkpointing = config.get("checkpoint_mixer", False) == True  # noqa: E712

        # grab state dict from both and load weights
        state_dict = load_weights(scratch_model.state_dict(), loaded_ckpt['state_dict'], checkpointing=checkpointing)

        # scratch model has now been updated
        scratch_model.load_state_dict(state_dict)
        print("Loaded pretrained weights ok!")
        return scratch_model


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pretrained_model_name = 'hyenadna-tiny-1k-d256'
max_length = 1_000

# Load the pretrained model from the local checkpoint directory, move it to the
# GPU and switch it to evaluation mode.
model = HyenaDNAPreTrainedModel.from_pretrained(
    '../data_generation/checkpoints',
    pretrained_model_name,
).to('cuda')
model.eval()

# create tokenizer, no training involved :)
tokenizer = CharacterTokenizer(
    characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters
    model_max_length=max_length,
)

# Name of the data sub-directory. Derived from max_length instead of repeating
# the literal 1000, so the two cannot drift apart (value is still '998').
# NOTE(review): the "- 2" presumably leaves room for tokens added by the
# tokenizer -- confirm.
datalen = str(max_length - 2)

# The file is deliberately left open: `alt` and `ref` are dataset handles that
# are sliced batch by batch in the next cell.
file = h5py.File("../data/CAGI/"+datalen+"/CAGI_onehot.h5", "r")
alt = file['alt']
ref = file['ref']

Loaded pretrained weights ok!


In [3]:
# Compare the model output for each reference/alternate sequence pair and
# collect four scores per pair: cosine similarity, dot product, L1 distance and
# the sum of squared differences.
N, L, A = alt.shape
batch_size = 50
cos, dot, l1, l2 = [], [], [], []

for i in tqdm(range(0, N, batch_size)):
    # The final batch may hold fewer than batch_size pairs.
    b_size = min(batch_size, N - i)

    # One forward pass per batch: rows [0, b_size) are the reference
    # sequences, rows [b_size, 2 * b_size) the matching alternates.
    onehot = np.concatenate((ref[i:i + b_size], alt[i:i + b_size]))
    seq = utils.onehot_to_seq(onehot)
    input_ids = tokenizer(list(seq), return_tensors="pt")["input_ids"].to('cuda')
    with torch.inference_mode():
        hidden_states = model(input_ids).cpu().detach().numpy()

    ref_half = hidden_states[:b_size]
    alt_half = hidden_states[b_size:2 * b_size]
    for ref_out, alt_out in zip(ref_half, alt_half):
        inner = (ref_out * alt_out).sum()
        delta = ref_out - alt_out
        cos.append(inner / (np.linalg.norm(ref_out) * np.linalg.norm(alt_out)))
        dot.append(inner)
        l1.append(np.absolute(delta).sum())
        # Note: no square root is taken, so this is the squared L2 distance.
        l2.append(np.square(delta).sum())

100%|██████████| 369/369 [01:02<00:00,  5.94it/s]


In [3]:
# Persist the four per-variant score lists, one HDF5 dataset per metric.
# The context manager closes the file even if a create_dataset call raises,
# so a failed run cannot leave the output file open.
with h5py.File('../data/CAGI/'+'cagi_'+datalen+'_'+'hyena.h5', 'w') as output:
    output.create_dataset('cosine', data=np.array(cos))
    output.create_dataset('dot', data=np.array(dot))
    output.create_dataset('l1', data=np.array(l1))
    output.create_dataset('l2', data=np.array(l2))

TypeError: can only concatenate str (not "int") to str

In [6]:
import h5py
import pandas as pd
import numpy as np
from operator import itemgetter
import seaborn as sns
import matplotlib.pyplot as plt 
import scipy.stats as stats
datalen = '998'

# Input locations: the CAGI metadata table and the score file written earlier
# in this notebook.
metadata_csv = '../data/CAGI/'+datalen+'/final_cagi_metadata.csv'
scores_h5 = '../data/CAGI/'+'cagi_'+datalen+'_'+'hyena.h5'

# First CSV column is read as the index, then turned back into a regular
# column so rows are addressed by a fresh 0..n-1 index.
cagi_df = pd.read_csv(metadata_csv, index_col=0).reset_index()

# Column '6' supplies the values the scores are correlated against below;
# column '8' holds the experiment label each row belongs to.
target = cagi_df['6'].values.tolist()
exp_list = cagi_df['8'].unique()

# Read-only handle; kept open because the evaluation cell reads from it.
cagi_result = h5py.File(scores_h5, 'r')

In [8]:
# For every stored metric, report the Pearson correlation between the metric
# and the absolute target value, separately for each experiment.
# `perf` collects the coefficients in print order (metric-major).
perf = []
target_arr = np.array(target)  # built once instead of once per experiment

for key in cagi_result.keys():
    print(key)
    cagi_llr = cagi_result[key]
    # Read the dataset from disk once per metric, not once per experiment.
    scores = np.squeeze(cagi_llr)

    for exp in ['LDLR','SORT1','F9','PKLR']:
        sub_df = cagi_df[cagi_df['8'] == exp]
        rows = sub_df.index.to_list()

        # Only the magnitude of the target is compared with the score.
        exp_target = np.absolute(target_arr[rows])
        exp_pred = scores[rows]

        print(exp)
        # Compute the coefficient once and reuse it for storing and printing.
        pearson_r = stats.pearsonr(exp_pred, exp_target)[0]
        perf.append(pearson_r)
        print(pearson_r)

cosine
LDLR
-0.18198321567003217
SORT1
0.02906028358401116
F9
-0.039868096952063246
PKLR
0.02110880773094938
dot
LDLR
-0.10038456364102219
SORT1
0.2792669648722809
F9
-0.0036651786753944416
PKLR
-0.08001239470975105
l1
LDLR
0.033906396737050384
SORT1
0.00802440182088367
F9
0.07154549609659909
PKLR
9.99546991168468e-05
l2
LDLR
0.1831921825264321
SORT1
-0.028138878273324765
F9
0.038730829156372895
PKLR
-0.018634354570632585
