# VISUALIZING A TRAINED NEURAL NETWORK

In this notebook, we will visualize the latent space of a trained neural network, and use it to generate protein interpolations. Let's start by loading a few useful classes.

In [None]:
import sys, os
sys.path.insert(0, "C:\\Users\\xdzl45\\workspace\\molearn\\src") #uncomment/edit as appropriate for your system
from analysis import MolearnAnalysis
import molearn
from molearn.autoencoder import Autoencoder

import torch
from copy import deepcopy
import biobox as bb
import numpy as np

## 1. Neural network analysis

These are the paths to files containing the neural network parameters, the training set, and the test set.

In [None]:
networkfile = f'data{os.sep}checkpoint_epoch2872.ckpt'
training_set_file = f'data{os.sep}MurD_closed_open_strided.pdb'
test_set_file = f'data{os.sep}MurD_closed_apo_strided.pdb'

### 1.1. Data loading

The `MolearnAnalysis` class features a series of methods to characterize the latent space of a neural network, assess its performance, and generate new protein conformations. Its constructor requires two parameters: neural network parameters and a multiPDB containing the coordinates of the examples the neural network was trained with. Optional parameters enable the user to define the neural network architecture (if other than default) and the atoms selected for training (backbone and beta carbon by default). We start by loading a trained neural network.

In [None]:
checkpoint = torch.load(networkfile, map_location= torch.device('cpu'))
net = Autoencoder(**checkpoint['network_kwargs'])
net.load_state_dict(checkpoint['model_state_dict'])

Now we can create a `MolearnAnalysis`, and load the network and the dataset it was trained with.

In [None]:
MA = MolearnAnalysis()
MA.set_network(net)
MA.set_train_data(training_set_file)

Optionally, we can load a test set. The test set read by the network is returned to the user in two forms: a normalised PyTorch tensor (useful for feeding to the neural network) and a NumPy array (for displaying). The test set is also stored in the `MolearnAnalysis` object. If multiple test sets are loaded, only the last one is kept.

In [None]:
MA.set_test_data(test_set_file)

### 1.2. Dataset analysis

The following methods yield information on RMSD, DOPE and Ramachandran score of training and test sets.

In [None]:
err_train = MA.get_error('training_set') # RMSD of training set
err_test = MA.get_error('test_set') #RMSD of test set

dope_train, dope_train_decoded = MA.get_dope('training_set') #DOPE score of training set before and after decoding
dope_test, dope_test_decoded = MA.get_dope('test_set') #DOPE score of test set before and after decoding

rama_train, rama_train_decoded = MA.get_ramachandran('training_set') # Ramachandran scores
rama_test, rama_test_decoded = MA.get_ramachandran('test_set') # Ramachandran scores

### 1.3. Latent space scan

The following methods perform a grid search of the latent space. The output of these commands is also stored internally in the `MolearnAnalysis` object. Here, we will generate 50x50 grids around the training set (plus/minus 10%). If any of the methods below is called again, precalculated versions are returned, unless a different number of samples is passed as a parameter. In this case, the grid search is executed again, and the new results are stored.

<div class="alert alert-block alert-warning">
<b>Warning:</b> depending on the sampling granularity, this operation can take a long time! Calculating the DOPE score of a 50x50 grid can take ~20 minutes.
If you are in a rush, you can skip these steps, or comment out one of the two lines below. The rest of this notebook will still run, though you will not be able to visualize coloured landscapes in the GUI presented in the next section.
</div>
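To make the idea of a padded grid concrete, here is a minimal NumPy sketch of how a 50x50 grid around a set of 2D latent points could be built. This is not the `MolearnAnalysis` implementation: the helper `make_latent_grid`, its `padding` argument, and the reading of "plus/minus 10%" as 10% of each axis range are assumptions made for illustration.

```python
import numpy as np

def make_latent_grid(latent_points, samples=50, padding=0.1):
    """Build a regular grid covering 2D latent points, padded on every side.

    latent_points: (M, 2) array of latent coordinates (hypothetical input).
    padding: fraction of each axis range added below the minimum and above
             the maximum (our assumed meaning of "plus/minus 10%").
    """
    latent_points = np.asarray(latent_points, dtype=float)
    lo = latent_points.min(axis=0)
    hi = latent_points.max(axis=0)
    margin = padding * (hi - lo)
    xaxis = np.linspace(lo[0] - margin[0], hi[0] + margin[0], samples)
    yaxis = np.linspace(lo[1] - margin[1], hi[1] + margin[1], samples)
    # one row per grid point, ordered with x varying slowest
    xx, yy = np.meshgrid(xaxis, yaxis, indexing="ij")
    grid = np.stack([xx.ravel(), yy.ravel()], axis=1)
    return xaxis, yaxis, grid

# stand-in for the training set projected in the latent space
rng = np.random.default_rng(0)
z = rng.uniform(0.0, 1.0, size=(200, 2))
grid_x, grid_y, grid = make_latent_grid(z, samples=50)
print(grid.shape)  # (2500, 2)
```

The number of grid points grows quadratically with `samples`, which is why a per-point score that is expensive to compute makes the whole scan slow.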

We will assess the L2 norm in latent and 3D space, as well as the local DOPE and Ramachandran scores (see Jupyter notebook `molearn_analysis.ipynb`). These act as heuristics for the precision of the neural network.

In [None]:
landscape_err_latent, landscape_err_3d, xaxis, yaxis = MA.scan_error(samples=50)

In [None]:
landscape_dope_unrefined, surf_dope_refined, xaxis, yaxis = MA.scan_dope(samples=50)

In [None]:
landscape_ramachandran, xaxis, yaxis = MA.scan_ramachandran(samples=50)

An interesting plot is one where the latent space is coloured according to how closely the generated structures resemble a target structure. In this example, we pick a conformation from the test set and compare a grid of the latent space to it in terms of RMSD. As with the other landscape scanning methods above, the result of this scan is both returned to the user and stored internally.

In [None]:
landscape_target_rmsd, xaxis, yaxis = MA.scan_error_from_target(MA.test_set[0], samples=50)
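For reference, the RMSD between a generated structure and a target can be sketched in a few lines of NumPy. The function `rmsd` below is a hypothetical helper, not part of molearn, and it assumes both structures have the same atoms in the same order and are already in a common frame (no superposition is performed; whether `scan_error_from_target` aligns structures is not stated here).

```python
import numpy as np

def rmsd(crd, target):
    """Root-mean-square deviation between two sets of N atomic coordinates."""
    crd = np.asarray(crd, dtype=float).reshape(-1, 3)
    target = np.asarray(target, dtype=float).reshape(-1, 3)
    if crd.shape != target.shape:
        raise ValueError("both structures must contain the same number of atoms")
    # squared displacement per atom, averaged over atoms, then square root
    return float(np.sqrt(np.mean(np.sum((crd - target) ** 2, axis=1))))

# toy example: four atoms, all shifted by 2 along x
a = np.zeros((4, 3))
b = a.copy()
b[:, 0] += 2.0
rmsd_value = rmsd(a, b)
print(rmsd_value)  # 2.0
```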

The user can also create custom functions to quantify a desired property throughout the latent space. Any custom function must take as its first parameter a (1xNx3) array describing the coordinates of N atoms. Additional parameters can also be provided. The custom function is used by calling `MolearnAnalysis.scan_custom`. This takes three parameters: the custom function, a list of extra parameters to pass to the function, and a string used as a label. This label is used to retrieve the data associated with this function, stored in the dictionary `MolearnAnalysis.custom_data`. The optional parameter `samples` works as in the previous scanning functions.

In the example below, we measure the radius of gyration by implementing a custom function called `f_rgyr`. The function takes one extra parameter: a `Biobox` object describing the protein of interest (available in the `MolearnAnalysis` object upon data loading).

In [None]:
def f_rgyr(crd, M):
    M.coordinates = deepcopy(crd)
    M.set_current(0)
    return bb.rgyr(M)

rgyr, xaxis, yaxis = MA.scan_custom(f_rgyr, [MA.mol], "rgyr", samples=50)
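To illustrate the shape of a custom function independently of `biobox`, here is a minimal NumPy-only sketch of a radius of gyration. The name `f_rgyr_numpy` is hypothetical, and the calculation treats all atoms as having equal mass, so its values may differ from those of `bb.rgyr`.

```python
import numpy as np

def f_rgyr_numpy(crd):
    """Unweighted radius of gyration of a (1xNx3) coordinate array."""
    pts = np.asarray(crd, dtype=float).reshape(-1, 3)
    centre = pts.mean(axis=0)  # geometric centre (equal masses assumed)
    return float(np.sqrt(np.mean(np.sum((pts - centre) ** 2, axis=1))))

# toy structure: four atoms at unit distance from the origin
toy_crd = np.array([[[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0]]], dtype=float)
toy_rgyr = f_rgyr_numpy(toy_crd)
print(toy_rgyr)  # 1.0
```

A function of this form takes the coordinates as its first parameter, as required above, and needs no extra parameters.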

In the following second example, we measure the distance between two atoms of interest.

In [None]:
def f_distance(crd, M, idx1, idx2):
    M.coordinates = deepcopy(crd)
    M.set_current(0)
    # square each coordinate difference before summing
    return np.sqrt(np.sum((M.points[idx2] - M.points[idx1])**2))

_, idx1 = MA.mol.atomselect("*", 1, "CA", get_index=True) # find index of first atom of interest
_, idx2 = MA.mol.atomselect("*", 300, "CA", get_index=True)  # find index of second atom of interest
dist, xaxis, yaxis = MA.scan_custom(f_distance, [MA.mol, idx1, idx2], "dist", samples=50)
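As a sanity check on the distance formula, the following self-contained sketch computes the Euclidean distance between two atoms of a plain NumPy coordinate array. The helper `pair_distance` is hypothetical, and it assumes the indices are arrays of length one, as the `idx1[0]` usage in the cross-linking example later in this notebook suggests. Note that the coordinate differences are squared element-wise before being summed.

```python
import numpy as np

def pair_distance(points, idx1, idx2):
    """Euclidean distance between the atoms selected by idx1 and idx2."""
    points = np.asarray(points, dtype=float)
    diff = points[idx2] - points[idx1]  # shape (1, 3) for length-one index arrays
    return float(np.sqrt(np.sum(diff ** 2)))

# toy coordinates forming a 3-4-5 triangle between atoms 0 and 1
toy_points = np.array([[0, 0, 0], [3, 4, 0], [1, 1, 1]], dtype=float)
toy_dist = pair_distance(toy_points, np.array([0]), np.array([1]))
print(toy_dist)  # 5.0
```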

In this final example, we measure distances in a slightly more complicated way: as flexible links, to represent chemical cross-linking distances. Note: this may require some profiling for speed, since it currently builds a new density representation for each distance and conformation one wants to assess.

In [None]:
def f_distance_link(crd, M, idx1, idx2, measure_type="euclidean", maxdist=20):

    M.coordinates = deepcopy(crd)
    M.set_current(0)

    if measure_type == "euclidean":
        # square each coordinate difference before summing
        dist = np.sqrt(np.sum((M.points[idx2] - M.points[idx1])**2))
    elif measure_type == "flex_sidechain":
        XL = bb.Xlink(M)
        dist = XL.distance_matrix([idx1[0], idx2[0]], method="euclidean", flexible_sidechain=True)[0][1]
    elif measure_type == "all_flex":
        XL = bb.Xlink(M)
        dist = XL.distance_matrix([idx1[0], idx2[0]], method="theta", smooth=True, flexible_sidechain=True)[0][1]
    else:
        raise ValueError(f"unknown measure_type: {measure_type}")

    return dist

# this example demonstrates how to generate landscapes for multiple cross-links
# listed in an input text file with format, e.g., K189 NZ K192 NZ 18
# (`file` must hold the path to that text file)
links = []
with open(file, "r") as fin:
    for line in fin:
        w = line.split()
        # residue number and atom name for each end of the link
        links.append([int(w[0][1:]), w[1], int(w[2][1:]), w[3]])
    
for l in links:
    print(f"scanning {l[0]}-{l[1]} link")
    _, idx1 = MA.mol.atomselect("*", l[0], l[1], get_index=True) 
    _, idx2 = MA.mol.atomselect("*", l[2], l[3], get_index=True)
    dist, xaxis, yaxis = MA.scan_custom(f_distance_link, [MA.mol, idx1, idx2, "flex_sidechain", 20],
                                        f"link_{l[0]}-{l[1]}", samples=50)
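The parsing step can be tried out without any file on disk. The sketch below applies the same splitting logic to a list of strings; `parse_crosslinks` is a hypothetical helper, the second example line is made up for illustration, and the fifth column is ignored, as in the cell above.

```python
def parse_crosslinks(lines):
    """Parse lines such as 'K189 NZ K192 NZ 18' into (resid1, atom1, resid2, atom2)."""
    links = []
    for line in lines:
        w = line.split()
        if len(w) < 4:
            continue  # skip blank or incomplete lines
        # drop the one-letter residue code to keep only the residue number
        links.append((int(w[0][1:]), w[1], int(w[2][1:]), w[3]))
    return links

example_lines = ["K189 NZ K192 NZ 18", "", "K45 NZ K291 NZ 22"]
parsed_links = parse_crosslinks(example_lines)
print(parsed_links)  # [(189, 'NZ', 192, 'NZ'), (45, 'NZ', 291, 'NZ')]
```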

## 2. Neural network visualisation

Now that we have loaded some information in our `MolearnAnalysis` object, it's time to explore its contents! To this end, we will use a `MolearnGUI` object, creating an interactive interface!

In [None]:
from analysis import MolearnGUI

In [None]:
MolearnGUI(MA);

The interface is divided into three areas: a control panel (left), a 2D latent space representation (center) and a 3D protein view (right, initially empty). Here are instructions on how to use it:

* the **check boxes** in the interface enable displaying training and test set projections in the latent space area. If the test set was not loaded in the `MolearnAnalysis` object, the box will be grayed out.

* the **drop down menu and scroller** enable colouring the latent space surface in different ways. Colouring styles available will depend on whether grid sampling methods have been called in the MolearnAnalysis object, before the GUI was started.

* the **2D surface is clickable**. Clicking multiple times enables defining a path. 
The coordinates of clicked points appear in the **editable text box** to the bottom left as (x1, y1, x2, y2, ...). 

* When clicking the latent space, a **3D protein structure** will appear on the right. The representation will feature a number of conformations equal to the number of points clicked on the latent space, to which extra **sampling points** are added. By default, 10 extra points are added in each interval. This value can be edited in the menu on the left.

* The interpolation visible on the right can be saved into a multiPDB file via the button labelled "**Save PDB**" to the bottom of the menu on the left.

* The buttons "**Save state**" and "**Load state**" enable you to save/load the current state of the GUI. In this way, you do not have to recalculate everything the next time you execute this notebook.
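The path definition described above, clicked waypoints plus extra sampling points in each interval, can be sketched as follows. This is only an illustration: the helper `interpolate_waypoints` is hypothetical, and it assumes that the extra points are placed by linear interpolation in the latent space, which may not match what the GUI does internally.

```python
import numpy as np

def interpolate_waypoints(flat_coords, extra_points=10):
    """Insert extra_points evenly spaced points between consecutive waypoints.

    flat_coords: waypoints in the flat form (x1, y1, x2, y2, ...).
    """
    pts = np.asarray(flat_coords, dtype=float).reshape(-1, 2)
    if len(pts) < 2:
        return pts
    # fractions along a segment: its start plus the extra points, end excluded
    t = np.linspace(0.0, 1.0, extra_points + 2)[:-1]
    segments = [start + t[:, None] * (end - start)
                for start, end in zip(pts[:-1], pts[1:])]
    segments.append(pts[-1:])  # close the path with the last waypoint
    return np.vstack(segments)

# three clicked points, hence two intervals with 10 extra points each
path = interpolate_waypoints([0, 0, 1, 0, 1, 1], extra_points=10)
print(path.shape)  # (23, 2)
```

Each row of `path` is a latent space coordinate; decoding every row with the network would give one conformation of the interpolation.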

MolearnGUI can be launched without any parameter. This is useful if you want to start by loading a previous GUI state.

In [None]:
from analysis import MolearnAnalysis, MolearnGUI
MG = MolearnGUI();

Note that state files created via the button **"Save state"** use the Python package `pickle`. You can open them yourself with the following commands (where `MA` is a `MolearnAnalysis` object and `waypoints` are the waypoints shown in the GUI):

In [None]:
fname = f'data{os.sep}state.p' #name of your state file

import pickle
# the context manager closes the file once the state has been loaded
with open(fname, "rb") as fin:
    MA, waypoints = pickle.load(fin)

You got to the end of this notebook! If you are interested in knowing what is happening inside the `MolearnAnalysis` object, see the Jupyter notebook `molearn_analysis.ipynb`.

***