In [None]:
import sys, os, glob
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.nn.functional as F
import torch.optim

import modeller
from modeller import *
from modeller.scripts import complete_pdb

import MDAnalysis as mda

import biobox as bb

#edit path as required for your computer (or remove, if you installed molearn via conda-forge)
sys.path.insert(0, "C:\\Users\\xdzl45\\workspace\\molearn\\src")
from molearn import load_data, Auto_potential, Autoencoder, ResidualBlock

In [None]:
class MolearnAnalysis(object):
    '''
    Post-training analysis of a molearn Autoencoder.

    Loads a training set and a trained network checkpoint. It then offers routines to
    measure reconstruction errors and DOPE scores, and to scan a 2D latent space.
    '''

    def __init__(self, network, infile, m=2.0, latent_z=2, r=2, atoms=None):
        '''
        :param network: path to a checkpoint file containing a 'model_state_dict' entry
        :param infile: training set file, passed to molearn's load_data
        :param m: Autoencoder hyperparameter; must match the checkpoint
        :param latent_z: Autoencoder latent size; must match the checkpoint
        :param r: Autoencoder hyperparameter; must match the checkpoint
        :param atoms: atom names to retain; defaults to ["CA", "C", "N", "CB", "O"]
        '''
        # avoid a shared mutable default: every instance gets its own list
        if atoms is None:
            atoms = ["CA", "C", "N", "CB", "O"]

        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        training_set, meanval, stdval, _, mol, _, _ = load_data(infile,
                                                                atoms=atoms,
                                                                device=self.device)

        # set residue names with protonated histidines back to generic HIS name (needed by DOPE score function)
        resnames = mol.data["resname"].values
        resnames[resnames == "HIE"] = "HIS"
        resnames[resnames == "HID"] = "HIS"
        mol.data["resname"] = resnames

        # create an MDAnalysis instance of input protein
        mol.write_pdb("tmp.pdb")
        self.mymol = mda.Universe('tmp.pdb')

        self.training_set = training_set
        self.meanval = meanval
        self.stdval = stdval
        self.mol = mol
        self.atoms = atoms

        # load weights from the `network` argument (not from a notebook-level global)
        checkpoint = torch.load(network, map_location=self.device)
        self.network = Autoencoder(m=m, latent_z=latent_z, r=r).to(self.device)
        self.network.load_state_dict(checkpoint['model_state_dict'])

        # With momentum=1.0 a BatchNorm layer's running statistics are replaced by those of the
        # next batch it processes. The network has not been switched to eval() yet, so the
        # forward pass below sets them from the full training set.
        for modulelist in [self.network.encoder, self.network.decoder]:
            for layer in modulelist:
                if isinstance(layer, torch.nn.BatchNorm1d):
                    layer.momentum = 1.0
                elif isinstance(layer, ResidualBlock):
                    for rlayer in layer.conv_block:
                        if isinstance(rlayer, torch.nn.BatchNorm1d):
                            rlayer.momentum = 1.0

        with torch.no_grad():
            self.network.decode(self.network.encode(self.training_set.float()))

        self.network.eval()

        # tmp.pdb is kept on purpose: the DOPE routines reuse it as a scratch file.
        # rmsd_matrix.npy is removed only if present, so its absence does not abort construction.
        if os.path.exists("rmsd_matrix.npy"):
            os.remove("rmsd_matrix.npy")

    def load_test(self, infile):
        '''
        Load a test set with the same atom selection as the training set.

        :param infile: test set file, passed to molearn's load_data
        :return: the loaded test set
        :raises Exception: if shape[2] of the test set differs from the training set
        '''
        test_set, _, _, _, _, _, _ = load_data(infile, atoms=self.atoms, device=self.device)
        if test_set.shape[2] != self.training_set.shape[2]:
            raise Exception(f'number of d.o.f. differs: training set has {self.training_set.shape[2]}, test set has {test_set.shape[2]}')

        return test_set

    def _resolve_dataset(self, dataset):
        '''Return the training set when `dataset` is None or "", otherwise `dataset` unchanged.'''
        # comparing a tensor with "" is ill-defined, so check the type before comparing
        if dataset is None or (isinstance(dataset, str) and dataset == ""):
            return self.training_set
        return dataset

    def _to_coords(self, frame):
        '''Convert one 2D network-space frame into a (1, n, channels) un-normalised numpy array.'''
        # presumably frame is (3, n_atoms) and the result is (1, n_atoms, 3) Cartesian coordinates; TODO confirm
        return frame.permute(1, 0).unsqueeze(0).data.cpu().numpy()*self.stdval + self.meanval

    def _reconstruct(self, dataset):
        '''Encode and decode `dataset` without tracking gradients, clipping the decoder padding.'''
        with torch.no_grad():
            decoded = self.network.decode(self.network.encode(dataset.float()))
        return decoded[:, :, :dataset.shape[2]]

    def get_error(self, dataset="", align=False):
        '''
        Calculate the reconstruction error of a dataset encoded and decoded by a trained neural network.

        :param dataset: dataset as returned by load_data/load_test; None or "" means the training set
        :param align: if True, compute the RMSD with the biobox Molecule (self.mol.rmsd);
                      otherwise compute an unaligned root-mean-square distance over atoms
        :return: numpy array with one error value per structure
        '''
        dataset = self._resolve_dataset(dataset)
        decoded = self._reconstruct(dataset)

        err = []
        for i in range(dataset.shape[0]):
            crd_ref = self._to_coords(dataset[i])
            crd_mdl = self._to_coords(decoded[i])

            if align:  # use Molecule Biobox class to calculate RMSD
                self.mol.coordinates = deepcopy(crd_ref)
                self.mol.set_current(0)
                self.mol.add_xyz(crd_mdl[0])
                rmsd = self.mol.rmsd(0, 1)
            else:
                # squared Cartesian L2 norm averaged over the number of points (shape[1])
                rmsd = np.sqrt(np.sum((crd_ref.flatten() - crd_mdl.flatten())**2)/crd_mdl.shape[1])

            err.append(rmsd)

        return np.array(err)

    def _frame_dope(self, crd):
        '''Load `crd` into self.mol, write it to tmp.pdb and return its DOPE score.'''
        self.mol.coordinates = deepcopy(crd)
        self.mol.write_pdb("tmp.pdb")
        return self._dope_score("tmp.pdb")

    def get_dope(self, dataset=""):
        '''
        DOPE score of every structure in a dataset and of its encoded-decoded counterpart.

        :param dataset: dataset as returned by load_data/load_test; None or "" means the training set
        :return: tuple (dope_dataset, dope_decoded) of lists, one score per structure
        '''
        dataset = self._resolve_dataset(dataset)
        decoded = self._reconstruct(dataset)

        dope_dataset = []
        dope_decoded = []
        for i in range(dataset.shape[0]):
            dope_dataset.append(self._frame_dope(self._to_coords(dataset[i])))  # input structure
            dope_decoded.append(self._frame_dope(self._to_coords(decoded[i])))  # decoded counterpart

        return dope_dataset, dope_decoded

    def _latent_grid(self, samples):
        '''Return (xvals, yvals): `samples` values per axis spanning the encoded training set plus 10% margins.'''
        with torch.no_grad():
            z = self.network.encode(self.training_set.float())
        z_training = z.data.cpu().numpy()[:, :, 0]  # one row per structure, one column per latent dim

        axes = []
        for dim in (0, 1):
            lo = np.min(z_training[:, dim])
            hi = np.max(z_training[:, dim])
            margin = (hi - lo)*0.1  # 10% margins
            axes.append(np.linspace(lo - margin, hi + margin, samples))
        return axes[0], axes[1]

    def _latent_tensor(self, x, y):
        '''Build one latent-space point as a (1, 2, 1) float tensor on self.device.'''
        # same layout as the encoder output (batch, latent dim, 1), which _latent_grid indexes with [:, :, 0]
        return torch.tensor([[[x], [y]]], dtype=torch.float32, device=self.device)

    def scan_error(self, samples=50):
        '''
        Scan a regular grid over the 2D latent space to measure the network's self-consistency.

        For every grid point z1: decode to s1, re-encode to z2, decode again to s2.

        :param samples: number of grid values per latent axis
        :return: tuple (cartesian, latent) of (samples, samples) arrays holding the L2 norms
                 of s2 - s1 and z2 - z1 respectively
        '''
        xvals, yvals = self._latent_grid(samples)
        n = self.training_set.shape[2]

        surf_z = np.zeros((len(xvals), len(yvals)))  # L2 norms in latent space ("drift")
        surf_c = np.zeros((len(xvals), len(yvals)))  # L2 norms in Cartesian space

        with torch.no_grad():
            for ix, xv in enumerate(xvals):
                for iy, yv in enumerate(yvals):
                    # take latent space coordinate (1) and decode it (2)
                    z1 = self._latent_tensor(xv, yv)
                    s1 = self.network.decode(z1)[:, :, :n]

                    # take the decoded structure, re-encode it (3) and then decode it (4)
                    z2 = self.network.encode(s1)
                    s2 = self.network.decode(z2)[:, :, :n]

                    surf_z[ix, iy] = np.sum((z2.cpu().numpy().flatten() - z1.cpu().numpy().flatten())**2)  # (1) vs (3)
                    surf_c[ix, iy] = np.sum((s2.cpu().numpy().flatten() - s1.cpu().numpy().flatten())**2)  # (2) vs (4)

        return np.sqrt(surf_c), np.sqrt(surf_z)

    def _dope_score(self, fname):
        '''Return the MODELLER DOPE score of the first chain of the PDB file `fname`.'''
        env = Environ()
        env.libs.topology.read(file='$(LIB)/top_heav.lib')
        env.libs.parameters.read(file='$(LIB)/par.lib')
        mdl = complete_pdb(env, fname)
        atmsel = Selection(mdl.chains[0])
        score = atmsel.assess_dope()
        return score

    def scan_dope(self, samples=50):
        '''
        DOPE score of the structure decoded at every point of a regular 2D latent-space grid.

        :param samples: number of grid values per latent axis
        :return: (samples, samples) array of DOPE scores
        '''
        xvals, yvals = self._latent_grid(samples)
        n = self.training_set.shape[2]

        surf_dope = np.zeros((len(xvals), len(yvals)))
        with torch.no_grad():
            for ix, xv in enumerate(xvals):
                for iy, yv in enumerate(yvals):
                    # take latent space coordinate and decode it
                    s1 = self.network.decode(self._latent_tensor(xv, yv))[:, :, :n]
                    surf_dope[ix, iy] = self._frame_dope(self._to_coords(s1[0]))

        return surf_dope

    def generate(self, crd):
        '''
        Generate a collection of protein conformations from coordinates in the latent space.

        :param crd: array of latent coordinates, either (N, 2) or (M, N, 2); the latter is the
                    layout the transpose below consumes, so (N, 2) is treated as (1, N, 2)
        :return: numpy array of un-normalised decoded structures, shape (N, n, channels)
        '''
        crd = np.asarray(crd)
        if crd.ndim == 2:
            crd = crd[np.newaxis]  # (N, 2) -> (1, N, 2)

        with torch.no_grad():
            z = torch.tensor(crd.transpose(1, 2, 0)).float().to(self.device)
            s = self.network.decode(z)[:, :, :self.training_set.shape[2]].cpu().numpy().transpose(0, 2, 1)

        return s*self.stdval + self.meanval

***

In [None]:
# Input files. The checkpoint path is machine-specific: set the MOLEARN_NETWORK environment
# variable to override it instead of editing the notebook.
networkfile = os.environ.get("MOLEARN_NETWORK", "C:\\temp\\Dropbox\\Durham\\data\\neural_net\\conv1d-physics-path_B.pth")
training_set_file = "MurD_closed_open_strided.pdb"
test_set_file = "MurD_closed_apo_strided.pdb"

In [None]:
# build the analysis object from the trained checkpoint and its training set
MA = MolearnAnalysis(network=networkfile, infile=training_set_file)

In [None]:
# load the held-out structures with the same atom selection as the training set
test_set = MA.load_test(infile=test_set_file)

In [None]:
# reconstruction error on the training set (the default dataset), then on the test set
r_train = MA.get_error()
r2_test = MA.get_error(dataset=test_set)