In [6]:
import _pickle as pickle
import collections
import csv
import json
import math
import os
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils import data
from tqdm import tqdm

# sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
sys.path.append('./..')

from model.models import ContactAttention_simple_fix_PE, Lag_PP_mixed, RNA_SS_e2e
from utils.utils import *
from data_generator.data_generator import RNASSDataGenerator, Dataset
from utils.postprocess import postprocess, postprocess_proposed

In [9]:
# Load run configuration, fix seeds, select the GPU, and build the test-set loader.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

seed = config['seed']
num_of_device = config['num_of_device']
use_device_num = config['use_device_num']
d = config['u_net_d']
BATCH_SIZE = 1
pp_steps = config['pp_steps']
k = config['k']
# NOTE(review): `s` is not used in the visible cells; presumably consumed by
# post-processing in a later cell — confirm before removing.
s = math.log(9.0)

# Machine-specific location of the pickled test split.
TEST_DATA_PATH = '/media/ksj/nar_web_rna/test/test_e2e.pickle'

# Seed fix for reproducibility; the string 'none' in config.json disables it.
if seed != 'none':
    seed_fix(seed)

os.environ["CUDA_VISIBLE_DEVICES"] = generate_visible_device(num_of_device)

device = torch.device('cuda:{}'.format(use_device_num))

# NOTE(review): presumably defined here so that RNA_SS_data records stored in the
# test pickle can be resolved by the unpickler — confirm against data_generator.
RNA_SS_data = collections.namedtuple('RNA_SS_data', 'seq ss_label length name pairs')

test_data = RNASSDataGenerator(TEST_DATA_PATH)

# Size the model is built for: second-to-last dimension of the label array.
seq_len = test_data.data_y.shape[-2]

# Evaluation loader: keep a fixed order so per-sample results are reproducible
# across runs (even when seeding is disabled), and never drop test samples if
# BATCH_SIZE is later raised above 1.
params = {'batch_size': BATCH_SIZE,
          'shuffle': False,
          'num_workers': 6,
          'drop_last': False}

test_set = Dataset(test_data)
test_generator = data.DataLoader(test_set, **params)


In [11]:
# Assemble the end-to-end model (contact-attention net + Lag_PP_mixed
# post-processing net) on `device` and restore the trained weights.
contact_net = ContactAttention_simple_fix_PE(d=d, L=seq_len).to(device)
lag_pp_net = Lag_PP_mixed(pp_steps, k, device=use_device_num).to(device)

# nn.Module.to() moves parameters in place and returns the module itself, so the
# sub-networks are already on `device` and the wrapper only needs one move.
rna_ss_e2e = RNA_SS_e2e(contact_net, lag_pp_net).to(device)

# map_location places the checkpoint tensors on `device` regardless of where the
# checkpoint was saved; load_state_dict copies them into the existing parameters.
rna_ss_e2e.load_state_dict(
    torch.load(
        '/media/ksj/nar_web_rna/E2Efold/model/e2efold_model.pt',
        map_location=device,
    )
)

# Display the module structure as this cell's output.
rna_ss_e2e

RNA_SS_e2e(
  (model_att): ContactAttention_simple_fix_PE(
    (conv1d1): Conv1d(4, 10, kernel_size=(9,), stride=(1,), padding=(8,), dilation=(2,))
    (bn1): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_test_1): Conv2d(60, 10, kernel_size=(1, 1), stride=(1, 1))
    (bn_conv_1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_test_2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
    (bn_conv_2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_test_3): Conv2d(10, 1, kernel_size=(1, 1), stride=(1, 1))
    (encoder_layer): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): _LinearWithBias(in_features=20, out_features=20, bias=True)
      )
      (linear1): Linear(in_features=20, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=20, bias=True)
 

In [12]:
# Re-save the loaded weights using torch.save's legacy (pre-1.6, non-zipfile)
# serialization format. Presumably so older PyTorch releases that cannot read
# the zipfile format can load this checkpoint — confirm the consumer's version.
torch.save(
    rna_ss_e2e.state_dict(),
    '/media/ksj/nar_web_rna/E2Efold/model/e2efold_model_new.pt',
    _use_new_zipfile_serialization=False,
)