In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import json
import csv
import math, random, sys
import numpy as np
import argparse
import os
from preprocess import *
from bindgen import *
from tqdm import tqdm
import pdbfixer
import openmm
import biotite.structure as struc
from biotite.structure import AtomArray, Atom
from biotite.structure.io import save_structure
from biotite.structure.io.pdb import PDBFile
from sidechainnet.structure.PdbBuilder import PdbBuilder
import py3Dmol

In [89]:
# Unit shorthands: openmm_relax multiplies its `stiffness` by ENERGY / LENGTH**2
# and its `tolerance` by ENERGY before handing them to OpenMM.
ENERGY = openmm.unit.kilocalories_per_mole
LENGTH = openmm.unit.angstroms

In [55]:
def openmm_relax(pdb_file, stiffness=10., tolerance=2.39, use_gpu=False):
    """Repair a PDB file, energy-minimize it with OpenMM and overwrite it in place.

    PDBFixer is asked to locate missing residues/atoms and to add missing atoms
    and hydrogens. The result is parameterized with the
    "amber14/protein.ff14SB.xml" force field and minimized. When ``stiffness``
    is positive, every atom named N, CA, C or CB is tethered to its starting
    position by a harmonic restraint.

    Args:
        pdb_file: Path of the PDB file to relax. The minimized structure is
            written back to this same path, replacing the input.
        stiffness: Restraint force constant; scaled by ``ENERGY / LENGTH**2``.
            A value <= 0 disables the restraint entirely.
        tolerance: Minimizer tolerance; scaled by ``ENERGY`` and passed to
            ``Simulation.minimizeEnergy``.
        use_gpu: Select the "CUDA" platform when true, otherwise "CPU".

    Returns:
        The sum of the kinetic and potential energy reported for the state
        after minimization.
    """
    # --- Structure repair -------------------------------------------------
    repaired = pdbfixer.PDBFixer(pdb_file)
    repaired.findMissingResidues()
    repaired.findMissingAtoms()
    repaired.addMissingAtoms()
    repaired.addMissingHydrogens()

    # --- System construction ----------------------------------------------
    ff = openmm.app.ForceField("amber14/protein.ff14SB.xml")
    prepared = openmm.app.Modeller(repaired.topology, repaired.positions)
    prepared.addHydrogens(ff)
    system = ff.createSystem(prepared.topology)

    # --- Optional harmonic position restraint -----------------------------
    if stiffness > 0:
        restraint = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
        restraint.addGlobalParameter("k", stiffness * ENERGY / (LENGTH**2))
        for reference_coord in ["x0", "y0", "z0"]:
            restraint.addPerParticleParameter(reference_coord)

        restrained_names = ["N", "CA", "C", "CB"]
        tethered_atoms = (
            atom
            for residue in prepared.topology.residues()
            for atom in residue.atoms()
            if atom.name in restrained_names
        )
        for atom in tethered_atoms:
            # The reference point (x0, y0, z0) is the atom's pre-minimization position.
            restraint.addParticle(atom.index, prepared.positions[atom.index])
        system.addForce(restraint)

    # --- Minimization -----------------------------------------------------
    integrator = openmm.LangevinIntegrator(0, 0.01, 1.0)
    if use_gpu:
        platform_name = "CUDA"
    else:
        platform_name = "CPU"
    platform = openmm.Platform.getPlatformByName(platform_name)

    simulation = openmm.app.Simulation(prepared.topology, system, integrator, platform)
    simulation.context.setPositions(prepared.positions)
    simulation.minimizeEnergy(tolerance * ENERGY)

    minimized = simulation.context.getState(getEnergy=True)
    energy = minimized.getKineticEnergy() + minimized.getPotentialEnergy()

    # --- Write the relaxed coordinates over the input file ----------------
    with open(pdb_file, "w") as out_handle:
        relaxed_positions = simulation.context.getState(getPositions=True).getPositions()
        openmm.app.PDBFile.writeFile(
            simulation.topology,
            relaxed_positions,
            out_handle,
            keepIds=True,
        )
    return energy

In [81]:
def save_pdb(X, seq, file):
    """Build a PDB string from coordinates and a sequence and write it to ``<file>.pdb``.

    Args:
        X: Coordinate array. It is reshaped to ``(X.shape[0] * 14, 3)`` before
            being handed to ``PdbBuilder``, i.e. 14 coordinate rows per entry
            along the first axis (presumably one entry per residue — TODO confirm).
        seq: Sequence passed to ``PdbBuilder`` alongside the coordinates.
        file: Output path *without* extension; ``.pdb`` is appended.
    """
    flat_coords = X.reshape(X.shape[0] * 14, 3)
    pdb = PdbBuilder(seq, flat_coords).get_pdb_string()

    # Context manager guarantees the handle is flushed and closed on return,
    # so the file is complete on disk before anything reads it back.
    with open(f'{file}.pdb', 'w') as pdb_file:
        pdb_file.write(pdb)

In [82]:
def view_pdb(file):
    """Display the structure stored in ``file`` as a spectrum-colored cartoon.

    Args:
        file: Path of the PDB file to show (used as given; no extension is added).
    """
    with open(f'{file}') as handle:
        pdb_text = handle.read()

    viewer = py3Dmol.view(width=400, height=300)
    viewer.addModelsAsFrames(pdb_text)

    # Selector {'model': -1} is styled as a cartoon colored by 'spectrum'.
    selection = {'model': -1}
    cartoon_style = {"cartoon": {'color': 'spectrum'}}
    viewer.setStyle(selection, cartoon_style)

    viewer.zoomTo()
    viewer.show()

In [56]:
# Load the docking checkpoint. It unpacks into three items: the first is used
# as the model state dict, the second is discarded, the third is passed to the
# RefineDocker constructor.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source. No map_location is given, so tensors are restored to the
# device they were saved from.
model_ckpt, _, args = torch.load('weights/HERN_dock.ckpt')
model = RefineDocker(args)
model.load_state_dict(model_ckpt)
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module's repr.
model.eval()

RefineDocker(
  (embedding): AAEmbedding()
  (features): ProteinFeatures(
    (embeddings): PositionalEncodings()
  )
  (W_i): Linear(in_features=123, out_features=256, bias=True)
  (bce_loss): BCEWithLogitsLoss()
  (ce_loss): CrossEntropyLoss()
  (mse_loss): MSELoss()
  (huber_loss): SmoothL1Loss()
  (U_i): Linear(in_features=123, out_features=256, bias=True)
  (target_mpn): EGNNEncoder(
    (features): ProteinFeatures(
      (embeddings): PositionalEncodings()
    )
    (W_v): Linear(in_features=6, out_features=256, bias=True)
    (W_e): Linear(in_features=39, out_features=256, bias=True)
    (layers): ModuleList(
      (0): MPNNLayer(
        (dropout): Dropout(p=0.1, inplace=False)
        (norm): Identity()
        (W): Sequential(
          (0): Linear(in_features=1024, out_features=256, bias=True)
          (1): ReLU()
          (2): Linear(in_features=256, out_features=256, bias=True)
          (3): ReLU()
          (4): Linear(in_features=256, out_features=256, bias=True)
    

In [92]:
# Parse the example structure. `load_pdb` is not defined in this notebook;
# presumably it comes from one of the wildcard imports (preprocess / bindgen) —
# TODO confirm which.
data = load_pdb('1nca_imgt.pdb')

In [93]:
# Zero the first element of the first record in place (the same object that is
# later unpacked as `bind_X`). This mutates `data`: the values loaded by
# load_pdb are gone until the loading cell is re-run.
data[0][0].fill_(0)
# NOTE(review): `batch` is not defined in any visible cell, so this line fails
# on Restart & Run All unless a wildcard import supplies it. Presumably it is
# built from `data` in a cell that was deleted or run out of order — TODO
# restore that step before this call.
out = model(*batch)

In [94]:
# The first record unpacks into four items; `bind_X` and `bind_A` are not used
# again in the visible cells.
bind_X, _, bind_A, _ = data[0]
# Take the first entry of the model's `bind_X` output, moved to CPU as a NumPy array.
X = out.bind_X[0].cpu().numpy()
# Writes 'test_cdr3.pdb' (save_pdb appends the extension).
# NOTE(review): the sequence is hard-coded — presumably the CDR3 sequence of the
# loaded structure; confirm it matches the length of X along its first axis.
save_pdb(X, 'ARGEDNFGSLSDY', 'test_cdr3')

In [95]:
# Minimize the saved structure; openmm_relax overwrites 'test_cdr3.pdb' in place
# and the returned energy is displayed as the cell output.
openmm_relax('test_cdr3.pdb')

Quantity(value=6665.7299386606055, unit=kilojoule/mole)

In [96]:
# Render the relaxed structure inline (view_pdb takes the full file name, extension included).
view_pdb('test_cdr3.pdb')