# VISUALIZING A TRAINED NEURAL NETWORK

In this notebook, we will visualize the latent space of a trained neural network, and use it to generate protein interpolations. Let's start by loading a few useful classes.

In [None]:
import sys, os
import molearn
from molearn.analysis import MolearnAnalysis
from molearn.data import PDBData
from molearn.models.foldingnet import AutoEncoder

import torch
from copy import deepcopy
import biobox as bb
import numpy as np

## 1. Neural network analysis

These are the paths to files containing the neural network parameters, the training set, and the test set.

In [None]:
networkfile = f'data{os.sep}checkpoint_no_optimizer_state_dict_epoch167_loss0.003259085263643.ckpt'
#training_set_file = f'data{os.sep}MurD_closed_open_selection.pdb'
#test_set_file = f'data{os.sep}MurD_closed_apo_selection.pdb'

### 1.1. Data loading

The `MolearnAnalysis` class features a series of methods to characterize the latent space of a neural network, assess its performance, and generate new protein conformations. It works from two inputs: the neural network parameters and a multiPDB containing the coordinates of the examples the neural network was trained with; in this notebook both are attached after the object is created. Optional parameters enable the user to define the neural network architecture (if other than default) and the atoms selected for training (backbone and beta carbon by default). We start by loading a trained neural network.

In [None]:
checkpoint = torch.load(networkfile, map_location=torch.device('cpu'))
net = AutoEncoder(**checkpoint['network_kwargs'])
net.load_state_dict(checkpoint['model_state_dict'])

Now we can create a `MolearnAnalysis` object, and load the network in it.

In [None]:
MA = MolearnAnalysis()
MA.set_network(net)

Now we prepare the data used to train and test the network. Calling `PDBData.import_pdb` multiple times enables appending multiple datasets (these must contain the same number of atoms). After loading, the names of terminal oxygens are renamed to "O" with `PDBData.fix_terminal`, and then the atoms of interest can be selected via the `PDBData.atomselect` method. Once the datasets are loaded, we can prepare them as inputs for a neural network, i.e. normalizing them and transforming them into a tensor.

In [None]:
data = PDBData()
data.import_pdb(f'data{os.sep}MurD_closed_selection.pdb')
data.import_pdb(f'data{os.sep}MurD_open_selection.pdb')
data.fix_terminal()
data.atomselect(atoms = ['CA', 'C', 'N', 'CB', 'O'])
data.prepare_dataset()
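
As a rough illustration of what "normalizing" can mean here, the sketch below standardizes a small set of coordinates by subtracting the mean and dividing by the standard deviation taken over all coordinate values. This is an assumption about the general idea, not a description of what `PDBData.prepare_dataset` does internally; the helper `standardize` and the toy coordinates are hypothetical.

```python
import statistics

def standardize(frames):
    """Standardize nested [frame][atom][xyz] coordinates with one global mean and std.

    Hypothetical helper for illustration only; not part of molearn.
    """
    values = [v for frame in frames for atom in frame for v in atom]
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    scaled = [[[(v - mean) / std for v in atom] for atom in frame] for frame in frames]
    return scaled, mean, std

# two toy "conformations" of two atoms each
frames = [
    [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]],
    [[0.5, 0.5, 0.5], [1.5, 2.5, 3.5]],
]
scaled, mean, std = standardize(frames)

# keeping mean and std allows scaled values to be mapped back to the original units
restored = [[[v * std + mean for v in atom] for atom in frame] for frame in scaled]
```

Keeping the mean and standard deviation alongside the scaled data matters because any structure produced in the scaled space has to be converted back before it can be interpreted as atomic coordinates.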

When training the neural network, it is likely that the dataset we just prepared was split into a training and a test set. We can replicate the way it was split with `PDBData.split`. By default, 10% of the examples are randomly selected as the test (validation) set. To ensure that the same random split is produced, the parameter `manual_seed` can be defined (the same split is obtained when setting the same number). Once a split is produced, let's add our two resulting datasets to `MolearnAnalysis`, assigning each an appropriate label.

In [None]:
data_train, data_valid = data.split(manual_seed=25)
MA.set_dataset("training", data_train)
MA.set_dataset("validation", data_valid)
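
Why a fixed seed reproduces the same split can be shown with a small standalone sketch. It assumes a 10% hold-out and a plain seeded shuffle; the helper `seeded_split` is hypothetical, and molearn's own splitting may differ in detail.

```python
import random

def seeded_split(n_examples, validation_fraction=0.1, seed=25):
    """Return (train_indices, validation_indices) obtained from a seeded shuffle.

    Hypothetical helper for illustration only; not molearn's implementation.
    """
    indices = list(range(n_examples))
    # a private generator keeps the global random state untouched
    random.Random(seed).shuffle(indices)
    n_valid = int(round(n_examples * validation_fraction))
    return indices[n_valid:], indices[:n_valid]

# the same seed gives the same partition on every call
train_a, valid_a = seeded_split(100, seed=25)
train_b, valid_b = seeded_split(100, seed=25)
```

Because the shuffle is fully determined by the seed, rerunning the analysis with the seed used at training time recovers the original partition, provided the data are loaded in the same order.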

### 1.2. Dataset analysis

<div class="alert alert-block alert-warning">
<b>Warning:</b> Calculations can be made faster, but more memory intensive, via two attributes: <code>MolearnAnalysis.batch_size</code> defines the number of examples that are simultaneously encoded/decoded, and <code>MolearnAnalysis.processes</code> defines how many parallel threads to spawn when calculating DOPE and Ramachandran scores. If little RAM is available, the user may want to consider keeping these numbers low (default is 1).
</div>

In [None]:
MA.batch_size = 8
MA.processes = 4
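
The trade-off behind `batch_size` can be pictured with a generic batching loop: larger batches mean fewer passes but more data held in memory at once. The helper `batches` below is a hypothetical illustration, not molearn code.

```python
def batches(items, batch_size):
    """Yield consecutive slices of `items` holding at most `batch_size` elements."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

examples = list(range(20))            # stand-ins for 20 conformations
chunks = list(batches(examples, 8))   # batch size of 8, as in the cell above
```

With 20 examples and a batch size of 8, the data are processed in three passes, the last one being only partially filled.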

The following methods yield the RMSD, DOPE and Ramachandran scores of the training and test sets.

In [None]:
# RMSDs
err_train = MA.get_error('training')
err_test = MA.get_error('validation')

# DOPE scores
dope_train_data = MA.get_dope('training')
dope_test_data = MA.get_dope('validation')

# Ramachandran scores
rama_train_data = MA.get_ramachandran('training') 
rama_test_data = MA.get_ramachandran('validation')

### 1.3. Latent space scan

The following methods perform a grid search of the latent space. The output of these commands is also stored internally in the `MolearnAnalysis` object. Here, we will generate a 30x30 grid surrounding the latent space projection of each dataset loaded, plus/minus 10%. Each of these grid points will be converted into its protein conformation counterpart, i.e. a new `grid` dataset is added to the `MolearnAnalysis` object.

<div class="alert alert-block alert-warning">
<b>Warning:</b> depending on the sampling granularity and number of processors used, the operations in this section can take a long time! For instance, calculating the DOPE score of a 50x50 grid with a single processor can take >10 minutes.
If you are in a rush, you can skip these steps, or comment out one of the two lines below. The rest of this notebook will still run, though you will not be able to visualize coloured landscapes in the GUI presented in the next section.
</div>

In [None]:
MA.setup_grid(30)
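
To make the idea of a padded grid concrete, here is a minimal sketch that builds evenly spaced axes around a set of 2D latent points. It assumes the 10% padding is taken relative to the range spanned by the data along each axis; the helper `padded_axis` and the toy latent points are hypothetical, and `setup_grid` may define its bounds differently.

```python
def padded_axis(values, samples, padding=0.1):
    """Evenly spaced axis over min..max of `values`, widened by `padding` of the range on each side."""
    lo, hi = min(values), max(values)
    margin = (hi - lo) * padding
    lo, hi = lo - margin, hi + margin
    step = (hi - lo) / (samples - 1)
    return [lo + i * step for i in range(samples)]

# toy 2D latent coordinates of encoded structures
latent = [(0.0, 1.0), (2.0, 3.0), (1.0, 5.0)]
x_axis = padded_axis([p[0] for p in latent], samples=30)
y_axis = padded_axis([p[1] for p in latent], samples=30)

# every (x, y) pair is one latent point that would be decoded into a structure
grid = [(x, y) for y in y_axis for x in x_axis]
```

The number of structures to decode and score grows with the square of the number of samples per axis, which is why finer grids quickly become expensive.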

We will assess the L2 norm in latent and 3D space, as well as the local Ramachandran scores (see Jupyter notebook `molearn_analysis.ipynb`). These act as heuristics for the precision of the neural network. Each landscape calculated with the methods described hereafter is both returned to the caller and stored in the `MolearnAnalysis.surfaces` dictionary.

In [None]:
landscape_err_latent, landscape_err_3d, xaxis, yaxis = MA.scan_error()

In [None]:
landscape_ramachandran, xaxis, yaxis = MA.scan_ramachandran()

To produce a DOPE score landscape, the `MA.scan_dope` method is available. Compared to the two previous methods, this takes one extra parameter, `refine`. If `True` (default), the structures generated will be energy minimised before DOPE scoring. Here, we will just calculate the raw DOPE score landscape.

In [None]:
landscape_dope_unrefined, xaxis, yaxis = MA.scan_dope(refine=False)

An interesting plot is one where the latent space is coloured as a function of how closely the structures generated resemble a target structure. In this example, we pick a conformation from the test set (here, the first one) and compare a grid of the latent space to it in terms of RMSD. As for the other landscape scanning methods above, the result of this scan will be both returned to the user and stored internally.

In [None]:
landscape_target_rmsd, xaxis, yaxis = MA.scan_error_from_target(MA.test_set[0])
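
For reference, the quantity being mapped can be sketched in a few lines. The function below computes the RMSD between two structures with the same atom ordering, without any superposition; whether molearn aligns structures before measuring is not covered here, and the helper `rmsd` is a hypothetical illustration.

```python
import math

def rmsd(coords_a, coords_b):
    """Root-mean-square deviation between two equally sized lists of (x, y, z) points.

    No superposition is performed: both structures are assumed to share a reference frame.
    """
    if len(coords_a) != len(coords_b):
        raise ValueError("structures must have the same number of atoms")
    squared = sum(math.dist(a, b) ** 2 for a, b in zip(coords_a, coords_b))
    return math.sqrt(squared / len(coords_a))

target = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
shifted = [(0.0, 0.0, 2.0), (1.0, 0.0, 2.0)]  # same shape, translated by 2 along z
```

Here `rmsd(target, target)` is zero, while `rmsd(target, shifted)` equals the translation applied, which illustrates why unaligned RMSD is sensitive to rigid-body motion as well as to conformational change.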

The user can also create custom functions to quantify a desired property through the latent space. Any custom function must take as its first parameter a (1xNx3) array describing the coordinates of N atoms. Additional parameters can also be provided. The custom function is used by calling `MolearnAnalysis.scan_custom`. This takes three parameters: the custom function, a list of extra parameters to pass to the function, and a string used as label. This label will be used to retrieve the data associated with this function, stored in the dictionary `MolearnAnalysis.custom_data`.

In the example below, we measure the radius of gyration by implementing a custom function called `f_rgyr`. The function takes one extra parameter: a `Biobox` object describing the protein of interest (available in the `MolearnAnalysis` object upon data loading).

In [None]:
def f_rgyr(crd, M):
    M.coordinates = deepcopy(crd)
    M.set_current(0)
    return bb.rgyr(M)

rgyr, xaxis, yaxis = MA.scan_custom(f_rgyr, [MA.mol], "rgyr")
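
The calling convention for custom functions (coordinates first, extra parameters after) can be illustrated without molearn. In the sketch below, `scan_custom_sketch`, `toy_decode` and `f_span` are all hypothetical names, the decoder is a toy stand-in for the neural network, and the row/column orientation of the returned surface is an assumption.

```python
def scan_custom_sketch(decode, x_axis, y_axis, func, params):
    """Evaluate `func(coords, *params)` on the structure decoded at every grid point.

    `decode` is a hypothetical stand-in for the network decoder: it maps a latent
    point (x, y) to a 1 x N x 3 nested list of coordinates.
    Rows of the result follow `y_axis`, columns follow `x_axis`.
    """
    return [[func(decode(x, y), *params) for x in x_axis] for y in y_axis]

def toy_decode(x, y):
    # one "conformation" of two atoms whose separation along x equals the latent x
    return [[[0.0, 0.0, 0.0], [x, 0.0, 0.0]]]

def f_span(crd, scale):
    # custom function: coordinates come first, extra parameters follow
    first, last = crd[0][0], crd[0][-1]
    return scale * abs(last[0] - first[0])

surface = scan_custom_sketch(toy_decode, [1.0, 2.0, 3.0], [0.0, 1.0], f_span, [10.0])
```

The list of extra parameters is unpacked after the coordinates, which is why the order of entries in that list has to match the order of the function's remaining arguments.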

In this second example, we measure the distance between two atoms of interest.

In [None]:
def f_distance(crd, M, idx1, idx2):
    M.coordinates = deepcopy(crd)
    M.set_current(0)
    # Euclidean distance: square the coordinate differences before summing
    return np.sqrt(np.sum((M.points[idx2] - M.points[idx1])**2))

_, idx1 = MA.mol.atomselect("*", 1, "CA", get_index=True) # find index of first atom of interest (alpha carbon of resid 1)
_, idx2 = MA.mol.atomselect("*", 300, "CA", get_index=True)  # find index of second atom of interest (alpha carbon of resid 300)
dist, xaxis, yaxis = MA.scan_custom(f_distance, [MA.mol, idx1, idx2], "dist")

In this final example, we measure distances in a slightly more complicated way: as flexible links, to represent chemical cross-linking distances. Note: this may require some profiling for speed, as it currently requires building a new density representation for each distance and conformation one wants to assess.

In [None]:
def f_distance_link(crd, M, idx1, idx2, measure_type="euclidean", maxdist=20):
    # note: maxdist is accepted but not used by any of the branches below

    M.coordinates = deepcopy(crd)
    M.set_current(0)

    if measure_type == "euclidean":
        # straight-line distance: square the coordinate differences before summing
        dist = np.sqrt(np.sum((M.points[idx2] - M.points[idx1])**2))
    elif measure_type == "flex_sidechain":
        XL = bb.Xlink(M)
        dist = XL.distance_matrix([idx1[0], idx2[0]], method="euclidean", flexible_sidechain=True)[0][1]
    elif measure_type == "all_flex":
        XL = bb.Xlink(M)
        dist = XL.distance_matrix([idx1[0], idx2[0]], method="theta", smooth=True, flexible_sidechain=True)[0][1]
    else:
        raise ValueError(f"unknown measure_type: {measure_type}")

    return dist

# this example demonstrates how to generate landscapes for multiple cross-links
# listed in an input text file with format, e.g., K189 NZ K192 NZ 18
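# `file` must be set beforehand to the path of the cross-link list
links = []
with open(file, "r") as fin:
    for line in fin:
        w = line.split()
        if not w:
            continue  # skip blank lines
        # residue number (one-letter code stripped) and atom name for both ends of the link
        links.append([int(w[0][1:]), w[1], int(w[2][1:]), w[3]])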
    
for l in links:
    # label each landscape with the two residue numbers, so that every link gets its own key
    print(f"scanning {l[0]}-{l[2]} link")
    _, idx1 = MA.mol.atomselect("*", l[0], l[1], get_index=True)
    _, idx2 = MA.mol.atomselect("*", l[2], l[3], get_index=True)
    dist, xaxis, yaxis = MA.scan_custom(f_distance_link, [MA.mol, idx1, idx2, "flex_sidechain", 20],
                                        f"link_{l[0]}-{l[2]}", samples=50)
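
The cross-link file format mentioned in the comments (e.g. `K189 NZ K192 NZ 18`) can be parsed in isolation as sketched below. The helper `parse_crosslinks` is hypothetical; skipping blank and `#` lines is a convenience assumed here, and the trailing number is simply kept as a float, since its meaning is not specified beyond the example.

```python
def parse_crosslinks(lines):
    """Parse lines such as 'K189 NZ K192 NZ 18' into (resid1, atom1, resid2, atom2, value).

    The leading letter of each residue token (one-letter amino acid code) is dropped.
    Blank lines and lines starting with '#' are skipped (an assumption of this sketch).
    """
    links = []
    for line in lines:
        words = line.split()
        if not words or words[0].startswith("#"):
            continue
        if len(words) < 4:
            raise ValueError(f"malformed cross-link line: {line!r}")
        value = float(words[4]) if len(words) > 4 else None
        links.append((int(words[0][1:]), words[1], int(words[2][1:]), words[3], value))
    return links

example = ["K189 NZ K192 NZ 18", "", "# comment", "K45 NZ K300 NZ 22.5"]
parsed = parse_crosslinks(example)
```

Validating the number of columns up front gives a clear error message for malformed lines, instead of an `IndexError` in the middle of a long scan.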

## 2. Neural Network visualisation

Now that we have loaded some information in our `MolearnAnalysis` object, it's time to explore its contents! To this end, we will use a `MolearnGUI` object, creating an interactive interface!

In [None]:
from molearn.analysis import MolearnGUI

In [None]:
MolearnGUI(MA);

The interface is divided in three areas: a control panel (left), a 2D latent space representation (centre) and a 3D protein view (right, initially empty). Here are instructions on how to use it:

* the **datasets drop down menu** enables displaying the projection in the latent space of the loaded datasets.

* the **drop down menu and scroller** enable colouring the latent space surface in different ways. Colouring styles available will depend on whether grid sampling methods have been called in the MolearnAnalysis object, before the GUI was started.

* the **2D surface is clickable**. Clicking multiple times enables defining a path. 
The coordinates of clicked points appear in the **editable text box** to the bottom left as (x1, y1, x2, y2, ...).

* There are two different methods to define paths: as straight lines or as low-energy paths. Method selection is controlled by the **Path** drop down menu. If the Euclidean method is selected, paths will be straight lines regularly sampled with a number of points defined by the **sampling** box. If the A* method is selected, the path is computed with the A* shortest path algorithm. The number of sampling points will be equal to the number of grid points traversed (note that a size 3 running average is applied to the path).

* When clicking the latent space, a **3D protein structure** will appear on the right. The representation will feature a number of conformations equal to the number of path sampling points represented in the latent space.

* The interpolation visible on the right can be saved into a multiPDB file via the button labelled "**Save PDB**" to the bottom of the menu on the left.

* The buttons "**Save state**" and "**Load state**" enable you to save/load the current state of the GUI. In this way, you do not have to recalculate everything the next time you execute this notebook.
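
The two path-related operations mentioned in the list above, regular sampling along straight segments and a size 3 running average, can be sketched as follows. This is one plausible reading and not the GUI's actual code: the helpers `sample_straight_path` and `running_average` are hypothetical, and how the GUI treats segment joints and path end points is an assumption.

```python
def sample_straight_path(waypoints, samples):
    """Sample `samples` evenly spaced points per straight segment, without repeating joints."""
    path = [waypoints[0]]
    for (x0, y0), (x1, y1) in zip(waypoints, waypoints[1:]):
        for i in range(1, samples):
            t = i / (samples - 1)
            path.append((x0 + t * (x1 - x0), y0 + t * (y1 - y0)))
    return path

def running_average(path, window=3):
    """Smooth a path with a centred moving average; end points are left untouched."""
    half = window // 2
    smoothed = list(path)
    for i in range(half, len(path) - half):
        neighbours = path[i - half:i + half + 1]
        smoothed[i] = (sum(p[0] for p in neighbours) / window,
                       sum(p[1] for p in neighbours) / window)
    return smoothed

# waypoints as they appear in the text box: a flat (x1, y1, x2, y2, ...) sequence
flat = [0.0, 0.0, 1.0, 0.0, 1.0, 2.0]
waypoints = list(zip(flat[0::2], flat[1::2]))
path = sample_straight_path(waypoints, samples=5)

# a staircase-like path, as might result from stepping through grid points
stair = [(0, 0), (1, 0), (1, 1), (2, 1)]
smooth = running_average(stair)
```

Smoothing is mostly relevant for grid-based paths, where consecutive points move in discrete steps and the raw trajectory looks like a staircase.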

MolearnGUI can be launched without any parameter. This is useful if you want to start by loading a previous GUI state.

In [None]:
MG = MolearnGUI();

Note that state files created via the button "**Save state**" use the Python package `pickle`. You can open them yourself with the following commands (where `MA` is a `MolearnAnalysis` object and `waypoints` are the waypoints shown in the GUI):

In [None]:
import pickle

fname = f'data{os.sep}state.p'  # name of your state file

# the context manager makes sure the file is closed after loading
with open(fname, "rb") as f:
    MA, waypoints = pickle.load(f)
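
The save/load round trip can be tried out independently of molearn with placeholder objects. In the sketch below, the dictionary stands in for a `MolearnAnalysis` object and the list for the GUI waypoints; both are hypothetical stand-ins, and the temporary file path is created only for this example.

```python
import os
import pickle
import tempfile

# placeholders standing in for a MolearnAnalysis object and the GUI waypoints
state = ({"label": "analysis placeholder"}, [0.1, 0.2, 0.5, 0.7])

path = os.path.join(tempfile.mkdtemp(), "state.p")
with open(path, "wb") as handle:   # the file is closed even if dumping fails
    pickle.dump(state, handle)

with open(path, "rb") as handle:
    analysis, loaded_waypoints = pickle.load(handle)
```

Since unpickling can execute arbitrary code, state files should only be loaded when they come from a trusted source.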

You got to the end of this notebook! If you are interested in knowing what is happening inside the `MolearnAnalysis` object, see the Jupyter notebook `molearn_analysis.ipynb`.

***