In [21]:
# environment
import argparse
import warnings
import torch
import numpy as np
import os
from modules.AutoEM import Solver
from unet3d.emmodel import ResUNet3D4EM

warnings.filterwarnings('ignore')
np.set_printoptions(threshold=np.inf,suppress=True,precision=2)

def run_emodelx(dynamic_config,BB_model,CA_model,AA_model):
    if dynamic_config.protocol not in ['temp_free','temp_flex','seq_free']:
        print('Wrong protocol! protocol should be [temp_free,temp_flex,seq_free]')
        return
    if dynamic_config.protocol in ['temp_free','temp_flex']:
        if not dynamic_config.fasta:
            print('--fasta is required when protocol is not seq_free')
            return
        elif not os.path.exists(dynamic_config.fasta):
            print('--fasta: path not exisit!')
            return
        if dynamic_config.protocol =='temp_flex':
            if not dynamic_config.template_dir:
                print('--template_dir is required when protocol is not seq_free')
                return
            elif not os.path.exists(dynamic_config.template_dir):
                print('--template_dir: path not exisit!')
                return
    
    if dynamic_config.run_phenix:
        dynamic_config.run_pulchra=True
        if not dynamic_config.resolution:
            print('--resolution is required for run.phenix_real_space_refine')
            return
        if not dynamic_config.phenix_act:
            print('--phenix_act is required for run.phenix_real_space_refine')
            return
        
    if dynamic_config.run_pulchra:
        if not dynamic_config.pulchra_path:
            print('--pulchra_path is required for run.phenix_real_space_refine')
            return


    AutoEM_solver = Solver(dynamic_config)
    AutoEM_solver.run(BB_model,CA_model,AA_model)

parser = argparse.ArgumentParser()
parser.add_argument('--protocol', type=str, default='temp_free', help='choose among temp_free,temp_flex,seq_free')

parser.add_argument('--EM_map', type=str, required=True, help='path of EM map')
parser.add_argument('--fasta', type=str, default='', help='path of fasta file, required when protocol is temp_free or temp_flex')
parser.add_argument('--template_dir', type=str, default='./inputs/templates', help='dir of template folder, required when --protocol is temp_flex, path format for different chain please reference to ./inputs/templates')

parser.add_argument('--run_pulchra', action='store_true', help='whether to run pulchra for all_atom construction')
parser.add_argument('--pulchra_path',type=str, help='directory of pulchra, e.g.: modules/pulchra304/src/pulchra')

parser.add_argument('--run_phenix', action='store_true', help='whether to run phenix.real_space_refine')
parser.add_argument('--resolution', type=float, help='resolution of EM map, required when run_phenix_real_space_refine is open')
parser.add_argument('--phenix_act',type=str, help='script to activate phenix environment, e.g.: modules/phenix-1.20.1-4487/phenix_env.sh')
parser.add_argument('--phenix_param', default='data/inputs/phenix.eff',type=str, help='param for phenix.real_space_refine')

parser.add_argument('--output_dir', type=str, default='./data/outputs/example', help='dir of output pdbs')

parser.add_argument('--best_CA_model', type=str, default='./models/best_CA_model.ckpt', help='set as default')
parser.add_argument('--best_BB_model', type=str, default='./models/best_BB_model.ckpt', help='set as default')
parser.add_argument('--best_AA_model', type=str, default='./models/best_AA_model.ckpt', help='set as default')
parser.add_argument('--seed', type=int, default=2022, help='set as default')
parser.add_argument('--cluster_eps', type=int, default=10, help='set as default')
parser.add_argument('--cluster_min_points', type=int, default=10, help='set as default')
parser.add_argument('--nms_radius', type=int, default=9, help='set as default')
parser.add_argument('--CA_score_thrh', type=float, default=0.35, help='set as default')
parser.add_argument('--frags_len', type=int, default=150, help='set as defaul')
parser.add_argument('--n_hop', type=int, default=6, help='set as default')
parser.add_argument('--neigh_mat_thrh', type=float, default=0.7, help='set as default')
parser.add_argument('--mul_proc_num', type=int, default=30, help='set as default')
parser.add_argument('--score_thrh', type=float, default=2, help='set as default')
parser.add_argument('--gap_len', type=int, default=3, help='set as default')
parser.add_argument('--struct_len', type=int, default=5, help='set as default')

# you can replace --EM-map and --fasta with your inputs
dynamic_config = parser.parse_args(['--protocol','temp_free',
                                    '--EM_map','data/inputs/maps/emd_32336.map.gz', 
                                    '--fasta','data/inputs/fastas/7w72', 
                                    '--output_dir','./data/outputs', 
                                    '--run_pulchra', 
                                    '--pulchra_path', 'modules/pulchra304/src/pulchra'])

torch.manual_seed(dynamic_config.seed)
BB_model = ResUNet3D4EM().to('cuda')
CA_model = ResUNet3D4EM().to('cuda')
AA_model = ResUNet3D4EM().to('cuda')
BB_model.load_state_dict(torch.load(dynamic_config.best_BB_model))
CA_model.load_state_dict(torch.load(dynamic_config.best_CA_model))
AA_model.load_state_dict(torch.load(dynamic_config.best_AA_model))
BB_model.eval()
CA_model.eval()
AA_model.eval()

ResUNet3D4EM(
  (BB_encoders): ModuleList(
    (0): Encoder(
      (basic_module): ExtResNetBlock(
        (conv1): SingleConv(
          (groupnorm): GroupNorm(1, 1, eps=1e-05, affine=True)
          (conv): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
        (conv2): SingleConv(
          (groupnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
          (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
        (conv3): SingleConv(
          (groupnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
          (conv): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        )
        (non_linearity): ReLU(inplace=True)
      )
    )
    (1): Encoder(
      (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (basic_module): ExtResNetBlock(

In [22]:
run_emodelx(dynamic_config,BB_model,CA_model,AA_model)
print(dynamic_config)

starting... emid: 32336, pdbid: 7w72
preprocessEM...
preprocessEM finished in 22 seconds
nnPred...
nnPred finished in 19 seconds
clustering...
clustering finished in 5 seconds
fragModeling...
seq_free trace_num: 17
fragModeling finished in 0 seconds
Check your fasta inputs, the fasta parser can only handle standard fasta inputs and only protein sequences would be parsed:
Your sequence 1 name: 7W72_1, Chains that belong to this sequence:
U 
Your sequence: 
MAAPLVLVLVVAVTVRAALFRSSLAEFISERVEVVSPLSSWKRVVEGLSLLDLGVSPYSGAVFHETPLIIYLFHFLIDYAELVFMITDALTAIALYFAIQDFNKVVFKKQKLLLELDQYAPDVAELIRTPMEMRYIPLKVALFYLLNPYTILSCVAKSTCAINNTLIAFFILTTIKGSAFLSAIFLALATYQSLYPLTLFVPGLLYLLQRQYIPVKMKSKAFWIFSWEYAMMYVGSLVVIICLSFFLLSSWDFIPAVYGFILSVPDLTPNIGLFWYFFAEMFEHFSLFFVCVFQINVFFYTIPLAIKLKEHPIFFMFIQIAVIAIFKSYPTVGDVALYMAFFPVWNHLYRFLRNIFVLTCIIIVCSLLFPVLWHLWIYAGSANSNFFYAITLTFNVGQILLISDYFYAFLRREYYLTHGL
Your sequence 2 name: 7W72_2, Chains that belong to this sequence:
S 
Your sequence: 
AATHLEVARGKRAALFFAAVAIVLGLPLWWKTT