In [1]:
import json
import os
import subprocess
import torch
import glob
import utils
from transformers import PreTrainedModel
import re
from standalone_hyenadna import HyenaDNAModel
from standalone_hyenadna import CharacterTokenizer
import numpy as np
from tqdm import tqdm
import math
import h5py 
import scipy.stats
os.environ['CUDA_VISIBLE_DEVICES']='2'

def inject_substring(orig_str):
    """Hack to handle matching keys between models trained with and without
    gradient checkpointing."""

    # modify for mixer keys
    pattern = r"\.mixer"
    injection = ".mixer.layer"

    modified_string = re.sub(pattern, injection, orig_str)

    # modify for mlp keys
    pattern = r"\.mlp"
    injection = ".mlp.layer"

    modified_string = re.sub(pattern, injection, modified_string)

    return modified_string

# helper 2
def load_weights(scratch_dict, pretrained_dict, checkpointing=False):
    """Loads pretrained (backbone only) weights into the scratch state dict."""

    # loop thru state dict of scratch
    # find the corresponding weights in the loaded model, and set it

    # need to do some state dict "surgery"
    for key, value in scratch_dict.items():
        if 'backbone' in key:
            # the state dicts differ by one prefix, '.model', so we add that
            key_loaded = 'model.' + key
            # breakpoint()
            # need to add an extra ".layer" in key
            if checkpointing:
                key_loaded = inject_substring(key_loaded)
            try:
                scratch_dict[key] = pretrained_dict[key_loaded]
            except:
                raise Exception('key mismatch in the state dicts!')

    # scratch_dict has been updated
    return scratch_dict

class HyenaDNAPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    base_model_prefix = "hyenadna"

    def __init__(self, config):
        pass

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    @classmethod
    def from_pretrained(cls,
                        path,
                        model_name,
                        download=False,
                        config=None,
                        device='cpu',
                        use_head=False,
                        n_classes=2,
                      ):
        # first check if it is a local path
        pretrained_model_name_or_path = os.path.join(path, model_name)
        if os.path.isdir(pretrained_model_name_or_path) and download == False:
            if config is None:
                config = json.load(open(os.path.join(pretrained_model_name_or_path, 'config.json')))
        else:
            hf_url = f'https://huggingface.co/LongSafari/{model_name}'

            subprocess.run(f'rm -rf {pretrained_model_name_or_path}', shell=True)
            command = f'mkdir -p {path} && cd {path} && git lfs install && git clone {hf_url}'
            subprocess.run(command, shell=True)

            if config is None:
                config = json.load(open(os.path.join(pretrained_model_name_or_path, 'config.json')))

        scratch_model = HyenaDNAModel(**config, use_head=use_head, n_classes=n_classes)  # the new model format
        loaded_ckpt = torch.load(
            os.path.join(pretrained_model_name_or_path, 'weights.ckpt'),
            map_location=torch.device(device)
        )

        # need to load weights slightly different if using gradient checkpointing
        if config.get("checkpoint_mixer", False):
            checkpointing = config["checkpoint_mixer"] == True or config["checkpoint_mixer"] == True
        else:
            checkpointing = False

        # grab state dict from both and load weights
        state_dict = load_weights(scratch_model.state_dict(), loaded_ckpt['state_dict'], checkpointing=checkpointing)

        # scratch model has now been updated
        scratch_model.load_state_dict(state_dict)
        print("Loaded pretrained weights ok!")
        return scratch_model


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pretrained_model_name = 'hyenadna-tiny-1k-d256'
max_length = 1_000

model = HyenaDNAPreTrainedModel.from_pretrained(
    './checkpoints',
    pretrained_model_name,
).to('cuda')
model.eval()

# create tokenizer, no training involved :)
tokenizer = CharacterTokenizer(
    characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters
    model_max_length=max_length,
)

Loaded pretrained weights ok!


## Lenti-MPRA

In [5]:
celltype = 'K562'
file = h5py.File('/home/ztang/multitask_RNA/data/lenti_MPRA/'+celltype+'_data.h5','r')

In [6]:
hidden_output = h5py.File('../data/lenti_MPRA_embed/hyena_'+celltype+'.h5','w')
batch_size = 128
output_cache = []
for i in tqdm(range(0,len(file['seq']),batch_size)):
  seq = file['seq'][i:i+batch_size].astype('U230')
  tok_seq = tokenizer(list(seq),return_tensors="pt")["input_ids"].to('cuda')
  with torch.inference_mode():
    embeddings = model(tok_seq).cpu().detach().numpy()
    output_cache.extend(embeddings)
hidden_output.create_dataset(name='seq',data = np.array(output_cache))
hidden_output.create_dataset(name='mean',data = file['mean'][:])
hidden_output.close()

100%|██████████| 1768/1768 [01:06<00:00, 26.52it/s]


## Chip/Clip seq data

In [None]:
max_len = 0
file_list = glob.glob('../data/chip/*.h5')
for file in file_list:
    tf_name = file.split('/')[-1][:-7]
    hyena_output = h5py.File('../data/chip/Hyena/'+tf_name+'_200.h5','w')
    batch_size = 128
    file = h5py.File(file,'r')
    for label in ('train','valid','test'):
        output_cache = []  
        for i in tqdm(range(0,len(file['x_'+label]),batch_size)):
            seq = file['x_'+label][i:i+batch_size].astype('int')
            seq = np.transpose(seq,(0,2,1))
            seq = utils.onehot_to_seq(seq)
            input_ids = tokenizer(list(seq), return_tensors="pt",)["input_ids"].to('cuda')
            with torch.inference_mode():
                hidden_states = model(input_ids).cpu().detach().numpy()
            output_cache.extend(hidden_states)
        hyena_output.create_dataset(name='x_'+label,data = np.array(output_cache),dtype = 'float32')
        hyena_output.create_dataset(name='y_'+label,data = file['y_'+label][:],dtype='int') 
    hyena_output.close()


KeyError: "Unable to open object (object 'x_train' doesn't exist)"

In [6]:
max_len = 0
file_list = glob.glob('../data/eclip/*.h5')
for file in file_list:
    tf_name = file.split('/')[-1][:-7]
    hyena_output = h5py.File('../data/eclip/Hyena/'+tf_name+'_200.h5','w')
    batch_size = 128
    file = h5py.File(file,'r')
    for label in ('train','valid','test'):
        output_cache = []  
        for i in tqdm(range(0,len(file['X_'+label]),batch_size)):
            seq = file['X_'+label][i:i+batch_size].astype('int')
            seq = np.transpose(seq,(0,2,1))
            seq = utils.onehot_to_seq(seq)
            input_ids = tokenizer(list(seq), return_tensors="pt",)["input_ids"].to('cuda')
            with torch.inference_mode():
                hidden_states = model(input_ids).cpu().detach().numpy()
            output_cache.extend(hidden_states)
        hyena_output.create_dataset(name='x_'+label,data = np.array(output_cache),dtype = 'float32')
        hyena_output.create_dataset(name='y_'+label,data = file['Y_'+label][:],dtype='int') 
    hyena_output.close()

  0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 32/32 [00:01<00:00, 18.05it/s]
100%|██████████| 5/5 [00:00<00:00, 20.56it/s]
100%|██████████| 9/9 [00:00<00:00, 17.90it/s]
100%|██████████| 89/89 [00:05<00:00, 15.55it/s]
100%|██████████| 13/13 [00:00<00:00, 18.46it/s]
100%|██████████| 26/26 [00:01<00:00, 18.79it/s]
100%|██████████| 80/80 [00:05<00:00, 15.97it/s]
100%|██████████| 12/12 [00:00<00:00, 19.34it/s]
100%|██████████| 23/23 [00:01<00:00, 18.53it/s]
100%|██████████| 38/38 [00:02<00:00, 17.92it/s]
100%|██████████| 6/6 [00:00<00:00, 20.54it/s]
100%|██████████| 11/11 [00:00<00:00, 19.03it/s]
100%|██████████| 50/50 [00:02<00:00, 17.41it/s]
100%|██████████| 8/8 [00:00<00:00, 20.94it/s]
100%|██████████| 15/15 [00:00<00:00, 18.89it/s]
100%|██████████| 179/179 [00:14<00:00, 12.69it/s]
100%|██████████| 26/26 [00:01<00:00, 18.53it/s]
100%|██████████| 52/52 [00:02<00:00, 17.73it/s]
100%|██████████| 33/33 [00:01<00:00, 18.09it/s]
100%|██████████| 5/5 [00:00<00:00, 20.55it/s]
100%|██████████| 10/10 [00:00<00:00, 19.49it/s]


## MT Splice data

In [16]:
file = h5py.File('../data/alternative_splicing/delta_logit.h5','r')
hyena_output = h5py.File('../data/alternative_splicing/hyena_splice.h5','w')
batch_size = 32
max_len = 0
for label in ('valid','test','train'):
    l_cache = []
    r_cache = [] 
    for i in tqdm(range(0,len(file['x_'+label]),batch_size)):
        seq = file['x_'+label][i:i+batch_size].astype('int')
        seq = utils.onehot_to_seq(seq)
        clean_seq = seq
        #clean_seq = [s if 'N' not in s else s.replace('N','[PAD]') for s in seq ]
        l_seq = []
        r_seq = []
        for seq in clean_seq:
            l_seq.append(seq[:400])
            r_seq.append(seq[400:])
        l_input = tokenizer(list(l_seq), return_tensors="pt",)["input_ids"].to('cuda')
        r_input = tokenizer(list(r_seq), return_tensors="pt",)["input_ids"].to('cuda')

        with torch.inference_mode():
            l_output = model(l_input).cpu().detach().numpy()
            r_output = model(r_input).cpu().detach().numpy()
        l_cache.extend(l_output)
        r_cache.extend(r_output)
    hyena_output.create_dataset(name='xl_'+label,data = np.array(l_cache),dtype = 'float32')
    hyena_output.create_dataset(name='xr_'+label,data = np.array(r_cache),dtype = 'float32')
    hyena_output.create_dataset(name='y_'+label,data = file['y_'+label][:],dtype='float32') 


  0%|          | 0/34 [00:00<?, ?it/s]

100%|██████████| 34/34 [00:01<00:00, 21.33it/s]
100%|██████████| 370/370 [00:18<00:00, 20.39it/s]
100%|██████████| 1189/1189 [00:59<00:00, 19.89it/s]


In [19]:
for key in hyena_output.keys():
    print(hyena_output[key].shape)

hyena_output.close()

(11840, 402, 256)
(38028, 402, 256)
(1088, 402, 256)
(11840, 402, 256)
(38028, 402, 256)
(1088, 402, 256)
(11840, 56, 2)
(38028, 56, 2)
(1088, 56, 2)


## RNA-enlong data

In [20]:
hyena_output = h5py.File('../data/RNAenlong/hyena_embed.h5','w')
file = h5py.File('../data/RNAenlong/insert_dataset.h5','r')
batch_size = 32
for dataset in ['test','train','valid']:
    key = 'X_'+dataset
    onehot = file[key]
    string_seq = utils.onehot_to_seq(onehot)
    token_seq = tokenizer(list(string_seq), return_tensors="pt",)["input_ids"].to('cuda')
    output_cache = []
    for seq_i in tqdm(range(0,len(token_seq),batch_size)):
        with torch.inference_mode():
            hidden_states = model(token_seq[seq_i:seq_i+batch_size]).cpu().detach().numpy()
        output_cache.extend(hidden_states)
    hyena_output.create_dataset(name=key,data = np.array(output_cache))
    hyena_output.create_dataset(name='Y_'+dataset,data = file['Y_'+dataset][:])
    hyena_output.close()

100%|██████████| 36/36 [00:00<00:00, 129.04it/s]
100%|██████████| 286/286 [00:02<00:00, 139.11it/s]
100%|██████████| 36/36 [00:00<00:00, 131.96it/s]
