In [1]:
import torch
import sequence_models
from sequence_models.constants import SPECIALS
from sequence_models.pretrained import load_carp
from sequence_models.collaters import SimpleCollater
from torch.utils.data import DataLoader
import os
os.environ['CUDA_VISIBLE_DEVICES']='7'
import sys
sys.path.append('/home/amber/multitask_RNA/data_generation/')
sys.path.append('/home/amber/multitask_RNA/rna_self_train/')
import rna_model
from torchinfo import summary
import h5py
import utils
import numpy as np
from tqdm import tqdm
device = torch.device("cuda")

In [2]:
RNA='ACGTN'
RNA_ALPHABET = RNA+SPECIALS
collater = SimpleCollater(RNA_ALPHABET,False,False)
class SaveOutput:
    def __init__(self):
        self.outputs = []
        
    def __call__(self, module, module_in, module_out):
        self.outputs.extend(module_out.cpu().detach().numpy())
        
    def clear(self):
        self.outputs = []

In [3]:
config={'model':'ByteNetLM',
                'lr':1e-3,
                'n_tokens':len(RNA_ALPHABET),
                'd_embedding' : 9, # dimension of embedding
                'd_model': 320, # dimension to use within ByteNet model, //2 every layer
                'n_layers' : 15, # number of layers of ByteNet block
                'activation': 'relu',
                'kernel_size' : 5, # the kernel width
                'r' : 32, # used to calculate dilation factor
                'padding_idx' : RNA_ALPHABET.index('-') ,# location of padding token in ordered alphabet
                'dropout' : 0.1 ,
                }

model = rna_model.ByteNetLM(config['n_tokens'], config['d_embedding'], config['d_model'],
                        config['n_layers'], config['kernel_size'], config['r'], config['lr'],
                        padding_idx=config['padding_idx'], causal=False, dropout=config['dropout'])
model.load_state_dict(torch.load('/home/amber/multitask_RNA/rna_self_train/rna-selftrain/2hkapjgg/checkpoints/best_model.ckpt')['state_dict'])
model.to(device)
model.eval()

ByteNetLM(
  (embedder): ByteNet(
    (embedder): Embedding(9, 9, padding_idx=6)
    (up_embedder): PositionFeedForward(
      (conv): Conv1d(9, 320, kernel_size=(1,), stride=(1,))
    )
    (layers): ModuleList(
      (0): ByteNetBlock(
        (conv): MaskedConv1d(160, 160, kernel_size=(5,), stride=(1,), padding=(2,))
        (sequence1): Sequential(
          (0): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (1): ReLU()
          (2): PositionFeedForward(
            (conv): Conv1d(320, 160, kernel_size=(1,), stride=(1,))
          )
          (3): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
          (4): ReLU()
        )
        (sequence2): Sequential(
          (0): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
          (1): ReLU()
          (2): PositionFeedForward(
            (conv): Conv1d(160, 320, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (1): ByteNetBlock(
        (conv): MaskedConv1d(160, 160, kernel_size=(5,

In [4]:
save_output = SaveOutput()
hook_handles = []
handle = model.last_norm.register_forward_hook(save_output)

In [5]:
file = h5py.File('../../data/rna_stable/insert_dataset.h5','r')
carp_output = h5py.File('../../data/rna_stable/carp_embed.h5','w')
batch_size = 64
for dataset in ['test','train','valid']:
    key = 'X_'+dataset
    onehot = file[key]
    string_seq = utils.onehot_to_seq(onehot)
    expand_seq = np.expand_dims(np.array(string_seq),axis = -1)
    token_output = collater(expand_seq)[0]

    save_output = SaveOutput()
    handle = model.last_norm.register_forward_hook(save_output)

    for seq_i in tqdm(range(0,len(token_output),batch_size)):
        seq_batch = torch.tensor(token_output[seq_i:seq_i+batch_size]).to(device)
        output_seq = model(seq_batch).cpu().detach().numpy()
    carp_output.create_dataset(name=key,data = np.array(save_output.outputs))
    carp_output.create_dataset(name='Y_'+dataset,data = file['Y_'+dataset][:])
    handle.remove()


  seq_batch = torch.tensor(token_output[seq_i:seq_i+batch_size]).to(device)
100%|██████████| 18/18 [00:00<00:00, 23.40it/s]
100%|██████████| 143/143 [00:05<00:00, 26.48it/s]
100%|██████████| 18/18 [00:00<00:00, 33.76it/s]


In [6]:
carp_output['X_train']

<HDF5 dataset "X_train": shape (9131, 173, 320), type "<f4">

In [7]:
carp_output.close()