# Set up environment

## Import packages, modules

In [None]:
# Dependencies to install for double cav model:

- seaborn 0.11.1
- scikit-learn

In [None]:
import glob
import os
import random
import timeit
import re
import pickle

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd

import torch
import torchvision

import Bio.SeqIO
import Bio.PDB

# Packages for parsing and inspecting PDB files
# import pdbfixer
# import simtk
# import simtk.openmm
# import simtk.openmm.app
# import simtk.unit

# Utility packages
from getpass import getpass
from IPython.display import Audio
import smtplib
import ssl
import tqdm.notebook as tqdm
import traceback
import inspect

# os.chdir("gdrive/My Drive/thesis/")

# For consistency: seed Python's `random` module and PyTorch's RNGs with the
# same fixed value so repeated runs draw the same random numbers.
# NOTE(review): NumPy's global RNG is not seeded in this cell — confirm that no
# later step depends on `np.random`, or seed it here as well.
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

import tested_double_cav_models as test

# Personal modules
# from neptune.cavity_model import (
#     CavityModel,
#     ResidueEnvironment,
#     ResidueEnvironmentsDataset,
#     DownstreamModel,
#     ToTensor,
#     DDGDataset,
#     DDGToTensor,
#     DownstreamModel,
# )

# from neptune.helpers import (
#     _augment_with_reverse_mutation,
#     _populate_dfs_with_nlls_and_nlfs,
#     _populate_dfs_with_resenvs,
#     _train_loop,
#     _train_val_split,
#     _get_ddg_training_dataloaders,
#     _get_ddg_validation_dataloaders,
#     _train_downstream_and_evaluate,
#     _predict_with_downstream,
#     _eval_loop,
#     _test_cavity_model,
# )

# from visualization import (
#     plot_training_history,
#     plot_training_history_v2,
#     show_respair_acc_heatmap,
# )

In [None]:
from visualization import (
    plot_training_history,
    plot_training_history_v2,
    show_respair_acc_heatmap,
)

In [None]:
from log_to_neptune import (
    set_and_save_metadata,
    log_metadata_to_neptune,
    convert_plt_to_plotly
)

In [None]:
# !pip install chardet==3.0.4
# !pip install rmsd
# !pip install openpyxl
# !pip install numpy==1.16
# !pip install pandas

In [None]:
# !git clone https://github.com/gao666999/SSBONDPredict.git

In [None]:
import Bio
import Bio.PDB
import Bio.PDB.vectors
import numpy as np
import simtk
import simtk.openmm
import simtk.openmm.app
import simtk.unit

# from scipy.spatial.distance import cdist

## cavity_model

In [None]:
import glob
import os
import random
from typing import Callable, List, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

import tqdm.notebook as tqdm

# `__all__` would list the public objects of this module that are exported when `from <module> import *` is used (overriding the default of every name without a leading underscore); it is not defined in this cell.


class ResidueEnvironment:
    """
    Read-only record describing one residue environment.

    It bundles the atoms surrounding the masked residue pair (coordinates and
    atom types) with the label and the bookkeeping identifiers of that pair.

    Parameters
    ----------
    xyz_coords: np.ndarray
        Array of shape (n_atoms, 3) with the x, y, z coordinates.
    atom_types: np.ndarray
        1D array with one integer atom type per atom.
    restype_onehot: np.ndarray
        One-hot label of the missing amino acid(s). (Marked "TO DOUBLE" by the
        original author, i.e. still to be adapted to residue pairs.)
    chain_id: str
        Chain id associated with the environment.
    pdb_residue_number: int
        Residue number associated with the environment ("TO DOUBLE" as well).
    pdb_id: str
        PDB id associated with the environment.
    """

    def __init__(
        self,
        xyz_coords: np.ndarray,
        atom_types: np.ndarray,
        restype_onehot: np.ndarray,
        chain_id: str,
        pdb_residue_number: int,
        pdb_id: str,
    ):
        # Geometry of the environment.
        self._xyz_coords = xyz_coords
        self._atom_types = atom_types
        # Label (target variable).
        self._restype_onehot = restype_onehot
        # Provenance of the environment.
        self._pdb_id = pdb_id
        self._chain_id = chain_id
        self._pdb_residue_number = pdb_residue_number

    @property
    def xyz_coords(self):
        """Atomic coordinates, one row per atom."""
        return self._xyz_coords

    @property
    def atom_types(self):
        """Integer atom type of every atom."""
        return self._atom_types

    @property
    def restype_onehot(self):
        """One-hot encoded label."""
        return self._restype_onehot

    @property
    def restype_index(self):
        """Position of the largest entry of the one-hot label."""
        return np.argmax(self.restype_onehot)

    @property
    def chain_id(self):
        """Chain identifier."""
        return self._chain_id

    @property
    def pdb_residue_number(self):
        """Residue number as stored for this environment."""
        return self._pdb_residue_number

    @property
    def pdb_id(self):
        """PDB identifier."""
        return self._pdb_id

    def __repr__(self):
        """
        Build the string representation of the object (what the built-in
        `repr()` returns): atom count followed by the identifying fields.
        """
        n_atoms = self.xyz_coords.shape[0]  # goes through the xyz_coords property
        described_fields = (
            f"pdb_id: {self.pdb_id}",
            f"chain_id: {self.chain_id}",
            f"pdb_residue_number: {self.pdb_residue_number}",
            f"restype_index: {self.restype_index}",
        )
        header = f"<ResidueEnvironment with {n_atoms} atoms. "
        return header + ", ".join(described_fields) + ">"


class ResidueEnvironmentsDataset(Dataset):
    """
    Torch dataset of ResidueEnvironment objects.

    Parameters
    ----------
    input_data: Union[List[str], List[ResidueEnvironment]]
        Either a list of parsed pdb filenames in .npz format (they are parsed
        on construction) or a list of ready-made ResidueEnvironment objects.
    transformer: Callable
        Optional callable applied to every sample returned by indexing
        (typically a to-tensor transformer). Can be replaced after construction.

    Raises
    ------
    ValueError
        If `input_data` is neither all strings nor all ResidueEnvironment
        objects.
    """

    def __init__(
        self,
        input_data: Union[List[str], List[ResidueEnvironment]], # Union[X, Y] means either X or Y
        transformer: Callable = None,
    ):
        if all(isinstance(x, ResidueEnvironment) for x in input_data):
            self._res_env_objects = input_data
        elif all(isinstance(x, str) for x in input_data):
            self._res_env_objects = self._parse_envs(input_data)
        else:
            # The two literals are concatenated, so the trailing space matters.
            raise ValueError(
                "Input data is not of type "
                "Union[List[str], List[ResidueEnvironment]]"
            )

        self._transformer = transformer

    @property
    def res_env_objects(self):
        """List of ResidueEnvironment objects backing the dataset."""
        return self._res_env_objects

    @property
    def transformer(self):
        """Callable applied to each sample on access (or None)."""
        return self._transformer

    @transformer.setter
    def transformer(self, transformer):
        """TODO: Think if a constraint to add later"""
        self._transformer = transformer

    def __len__(self):
        return len(self.res_env_objects)

    def __getitem__(self, idx):
        sample = self.res_env_objects[idx]
        if self.transformer:
            sample = self.transformer(sample)
        return sample

    def _parse_envs(self, npz_filenames: List[str]) -> List[ResidueEnvironment]:
        """
        Read every .npz file and build one ResidueEnvironment per row of its
        "selector" array.

        Parameters
        ----------
        npz_filenames: List[str]
            Paths of the .npz files; the first four characters of each file's
            basename are used as the PDB id.

        Returns
        -------
        List[ResidueEnvironment]
            Environments of all files, in file order.
        """
        res_env_objects = []
        for npz_filename in tqdm.tqdm(npz_filenames):
            # np.load returns a lazily-read archive for .npz files; the context
            # manager closes the underlying file handle once every needed
            # array has been materialised.
            with np.load(npz_filename) as coordinate_features:
                atom_coords_prot_seq = coordinate_features["positions"] # atom coords
                restypes_onehots_prot_seq = coordinate_features["pair_aa_onehot"]
                selector_prot_seq = coordinate_features["selector"] # atom ids
                atom_types_flattened = coordinate_features["atom_types_numeric"]

                chain_ids = coordinate_features["chain_ids"]
                pdb_residue_numbers = coordinate_features["pair_res_indices"]
                chain_boundary_indices = coordinate_features["chain_boundary_indices"]

            pdb_id = os.path.basename(npz_filename)[0:4]

            N_pair_residues = selector_prot_seq.shape[0]

            for pair_res_i in range(N_pair_residues):
                # Atom indexes of this environment; -1 entries are padding.
                selector = selector_prot_seq[pair_res_i]
                selector_masked = selector[selector > -1]

                # Atom types of the selected atoms.
                atom_types = atom_types_flattened[selector_masked]

                # Atom coordinates; rows whose x value is -99.0 are padding
                # (checking one coordinate column is enough).
                coords_mask = atom_coords_prot_seq[pair_res_i, :, 0] != -99.0
                coords = atom_coords_prot_seq[pair_res_i][coords_mask]

                # One-hot label (target variable) -> TO DOUBLE
                restype_onehot = restypes_onehots_prot_seq[pair_res_i]

                # Residue identifier(s) -> TO DOUBLE
                pdb_residue_number = pdb_residue_numbers[pair_res_i]

                # Locate chain id -> TO DOUBLE
                # NOTE(review): the pair index is compared against
                # "chain_boundary_indices"; if no interval contains it,
                # `chain_id` keeps the value of the previous environment (or is
                # undefined for the very first one). Behaviour kept as-is —
                # confirm the intended mapping for residue pairs.
                for j in range(len(chain_ids)):
                    chain_boundary_0 = chain_boundary_indices[j]
                    chain_boundary_1 = chain_boundary_indices[j + 1]
                    if chain_boundary_0 <= pair_res_i < chain_boundary_1:
                        chain_id = str(chain_ids[j])
                        break

                res_env_objects.append(
                    ResidueEnvironment(
                        coords,
                        atom_types,
                        restype_onehot, # -> TO DOUBLE
                        chain_id, # -> TO DOUBLE
                        pdb_residue_number, # -> TO DOUBLE
                        pdb_id,
                    )
                )

        return res_env_objects


class ToTensor:
    """
    To-tensor transformer for ResidueEnvironment samples.

    Parameters
    ----------
    device: str
        Either "cuda" (gpu) or "cpu". Is set-able.
    unravel_index: bool
        Only used when `reshape_index` is True. If True the label becomes a
        (2, n_aa_out) tensor with one one-hot row per residue of the pair;
        otherwise it is a flat one-hot vector of length n_aa_out * n_aa_out.
    reshape_index: bool
        If True the label is produced by `reshape_pairres_indices`; otherwise
        the sample's `restype_onehot` is converted to an int8 tensor unchanged.

    Raises
    ------
    ValueError
        If `device` is not one of the allowed devices.
    """

    def __init__(self,
        device: str,
        unravel_index=True,
        reshape_index=True,
    ):
        self.device = device
        self.unravel_index = unravel_index
        self.reshape_index = reshape_index

    @property
    def device(self):
        return self.__device

    @device.setter
    def device(self, device):
        allowed_devices = ["cuda", "cpu"]
        if device not in allowed_devices:
            # f-string so the offending value and the allowed list are
            # actually interpolated into the message.
            raise ValueError(
                f'chosen device "{device}" not in {allowed_devices}.')
        self.__device = device

    def __call__(self, sample: "ResidueEnvironment",):
        """
        Convert a single ResidueEnvironment object into a dict with keys
        "x_" (float32 tensor: atom type column followed by the coordinate
        columns, one row per atom) and "y_" (int8 label tensor), both moved to
        `self.device`.
        """
        sample_env = np.hstack(
            [np.reshape(sample.atom_types, [-1, 1]), sample.xyz_coords]
        )
        x_ = torch.tensor(sample_env, dtype=torch.float32)

        if self.reshape_index:
            y_ = self.reshape_pairres_indices(sample.restype_onehot, n_aa_in=20)
        else:
            y_ = torch.tensor(sample.restype_onehot, dtype=torch.int8)

        return {"x_": x_.to(self.device), "y_": y_.to(self.device)}

    def reshape_pairres_indices(self, targets: np.array, n_aa_in=20, n_aa_out=20):
        """
        Re-encode the pair label of ONE sample.

        The position of the maximum of `targets` is interpreted as a flat index
        into an (n_aa_in, n_aa_in) table, i.e. as a (res1, res2) index pair.

        Returns
        -------
        torch.Tensor (int8)
            * `self.unravel_index` True: shape (2, n_aa_out), row 0 one-hot for
              res1 and row 1 one-hot for res2.
            * otherwise: shape (n_aa_out * n_aa_out,), one-hot at the pair's
              flat index in an (n_aa_out, n_aa_out) table.
        """
        indices = np.unravel_index(np.argmax(targets), shape=(n_aa_in,
                                                             n_aa_in))
        if self.unravel_index:
            one_hot_arr = torch.zeros((2, n_aa_out), dtype=torch.int8)
            one_hot_arr[0, indices[0]] = 1
            one_hot_arr[1, indices[1]] = 1
        else:
            one_hot_arr = torch.zeros((n_aa_out*n_aa_out), dtype=torch.int8)
            indices = np.ravel_multi_index(np.vstack(indices),
                                           dims=(n_aa_out, n_aa_out)
                                           )
            one_hot_arr[indices] = 1

        return one_hot_arr

    def collate_cat(self, batch: List[dict]):
        """
        Collate method used by the dataloader.

        `batch` is a list of the dicts returned by `__call__`. Returns
        `(data, target)` where `target` stacks the "y_" tensors along a new
        first dimension and `data` concatenates all "x_" rows, prefixed with a
        column holding the position of the sample inside the batch.
        """
        target = torch.cat([torch.unsqueeze(b["y_"], 0) for b in batch], dim=0)

        # To collate the input, a leading column is added that tells which
        # environment of the batch each atom belongs to (a batch-local pseudo
        # id). It is created on self.device so it can be concatenated with x_.
        env_id_batch = []
        for i, b in enumerate(batch):
            n_atoms = b["x_"].shape[0]
            env_id_col = torch.full(
                (n_atoms, 1), i, dtype=torch.float32, device=self.device
            )
            env_id_batch.append(torch.cat([env_id_col, b["x_"]], dim=1))
        data = torch.cat(env_id_batch, dim=0) # stack the atoms of all samples row-wise

        return data, target


class CavityModel(torch.nn.Module):
    """
    3D convolutional neural network for missing amino acid classification.

    Atoms are first rendered onto a regular 3D grid (one channel per atom
    type) with Gaussian blurring, then passed through three Conv3d blocks and
    two dense layers that output 40 logits.

    Parameters
    ----------
    device: str
        Either "cuda" (gpu) or "cpu". Is set-able.
    n_atom_types: int
        Number of atom types = number of input channels. (C, H, N, O, S, P)
    bins_per_angstrom: float
        Number of grid points per Angstrom.
        NOTE(review): the grid extent and `sigma_p` are *multiplied* by this
        value; only the default 1.0 is known to be exercised — confirm the
        intended scaling before using another value.
    grid_dim_xy: int
        Number of grid points along x and along y.
    grid_dim_z: int
        Number of grid points along z.
    sigma: float
        Standard deviation used for gaussian blurring.

    Raises
    ------
    ValueError
        If `device` is not one of the allowed devices.
    """

    def __init__(
        self,
        device: str,
        n_atom_types: int = 6,
        bins_per_angstrom: float = 1.0,
        grid_dim_xy: int = 8, # original note: because 9 Angstrom of radius
        grid_dim_z: int = 16,
        sigma: float = 0.6,
    ):

        super().__init__()

        self.device = device
        self._n_atom_types = n_atom_types
        self._bins_per_angstrom = bins_per_angstrom
        self._grid_dim_xy = grid_dim_xy
        self._grid_dim_z = grid_dim_z
        self._sigma = sigma

        self._model()

    @property
    def device(self):
        return self.__device

    @device.setter
    def device(self, device):
        allowed_devices = ["cuda", "cpu"]
        if device not in allowed_devices:
            # f-string so the offending value is actually shown.
            raise ValueError(f'chosen device "{device}" not in {allowed_devices}')
        self.__device = device

    @property
    def n_atom_types(self):
        return self._n_atom_types

    @property
    def bins_per_angstrom(self):
        return self._bins_per_angstrom

    @property
    def grid_dim_xy(self):
        return self._grid_dim_xy

    @property
    def grid_dim_z(self):
        return self._grid_dim_z

    @property
    def sigma(self):
        return self._sigma

    @property
    def sigma_p(self):
        """Blurring standard deviation scaled by `bins_per_angstrom`."""
        return self.sigma * self.bins_per_angstrom

    @property
    def lin_spacing_xy(self):
        """Grid point coordinates along x (and y), centred on zero."""
        lin_spacing_xy = np.linspace(
            start=-self.grid_dim_xy / 2 * self.bins_per_angstrom
            + self.bins_per_angstrom / 2,
            stop=self.grid_dim_xy / 2 * self.bins_per_angstrom
            - self.bins_per_angstrom / 2,
            num=self.grid_dim_xy,
        )
        return lin_spacing_xy

    @property
    def lin_spacing_z(self):
        """Grid point coordinates along z, centred on zero."""
        lin_spacing_z = np.linspace(
            start=-self.grid_dim_z / 2 * self.bins_per_angstrom
            + self.bins_per_angstrom / 2,
            stop=self.grid_dim_z / 2 * self.bins_per_angstrom
            - self.bins_per_angstrom / 2,
            num=self.grid_dim_z,
        )
        return lin_spacing_z

    def _model(self):
        """Create the coordinate grids and the network layers."""
        # Stack the three coordinate grids into one array before converting:
        # building a tensor from a list of ndarrays is slow.
        grids = np.stack(
            np.meshgrid(
                self.lin_spacing_xy, self.lin_spacing_xy, self.lin_spacing_z,
                indexing="ij",  # matrix indexing: axes ordered (x, y, z)
            )
        )
        self.xx, self.yy, self.zz = torch.tensor(
            grids, dtype=torch.float32
        ).to(self.device)

        # Each block: same-size convolution (padding=1 with a 3x3x3 kernel),
        # then max-pooling that halves every spatial dimension.
        # With the default 8x8x16 grid: -> [B, 16, 4, 4, 8]
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv3d(self.n_atom_types, 16, kernel_size=(3, 3, 3), padding=1),
            torch.nn.MaxPool3d(kernel_size=2),
            torch.nn.BatchNorm3d(16),
            torch.nn.ReLU(),
        )
        # default grid: -> [B, 32, 2, 2, 4]
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv3d(16, 32, kernel_size=(3, 3, 3), padding=1),
            torch.nn.MaxPool3d(kernel_size=2),
            torch.nn.BatchNorm3d(32),
            torch.nn.ReLU(),
        )
        # default grid: -> [B, 64, 1, 1, 2], flattened to [B, 128]
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv3d(32, 64, kernel_size=(3, 3, 3), padding=1),
            torch.nn.MaxPool3d(kernel_size=2),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
        )

        # Three poolings with kernel 2 floor-divide every spatial dim by 8.
        # 64 channels * remaining voxels = 128 for the default grid.
        n_flat_features = (
            64 * (self.grid_dim_xy // 8) ** 2 * (self.grid_dim_z // 8)
        )
        self.dense1 = torch.nn.Sequential(
            torch.nn.Linear(in_features=n_flat_features, out_features=256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
        )
        self.dense2 = torch.nn.Linear(in_features=256, out_features=40)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map the batched atom table `x` (see `_gaussian_blurring`) to logits."""
        x = self._gaussian_blurring(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.dense1(x)
        x = self.dense2(x) # logits; no softmax is applied here
        return x

    def _gaussian_blurring(self, x: torch.Tensor) -> torch.Tensor:
        """
        Render the atoms of a batch as blurred densities on the 3D grid.

        Parameters
        ----------
        x: torch.Tensor
            Tensor of shape (n_atoms, 5). Each row represents an atom, where:
                column 0 is the id of the environment (sample) of the batch
                the atom belongs to
                column 1 is the atom type
                column 2,3,4 are the x, y, z coordinates, respectively

        Returns
        -------
        fields_torch: torch.Tensor
            Density values with shape (n_environments, n_atom_types,
            grid_dim_xy, grid_dim_xy, grid_dim_z). Environments are ordered by
            ascending id; each atom contributes a Gaussian normalised to sum
            to 1 over the grid.
        """
        # Dense row index of every atom's environment. Using this index (and
        # not the position of a block inside one atom type's data) keeps the
        # channels aligned when an environment has no atom of some type.
        env_ids, env_index = torch.unique(x[:, 0], return_inverse=True)
        current_batch_size = env_ids.shape[0]
        grid_shape = [self.grid_dim_xy, self.grid_dim_xy, self.grid_dim_z]

        fields_torch = torch.zeros(
            (current_batch_size, self.n_atom_types, *grid_shape),
            device=self.device,
        )

        # Flattened grid coordinates as column vectors, shape (n_voxels, 1).
        grid_x = torch.reshape(self.xx, [-1, 1])
        grid_y = torch.reshape(self.yy, [-1, 1])
        grid_z = torch.reshape(self.zz, [-1, 1])

        for j in range(self.n_atom_types):
            mask_j = x[:, 1] == j
            pos = x[mask_j][:, 2:]  # coordinates of all atoms of type j
            n_atoms_j = pos.shape[0]
            if n_atoms_j == 0:
                continue
            env_index_j = env_index[mask_j]

            # Broadcasting: (n_voxels, 1) - (n_atoms_j,) -> (n_voxels, n_atoms_j),
            # i.e. one column of density values per atom.
            density = torch.exp(
                -(
                    (grid_x - pos[:, 0]) ** 2
                    + (grid_y - pos[:, 1]) ** 2
                    + (grid_z - pos[:, 2]) ** 2
                )
                / (2 * self.sigma_p ** 2)
            )

            # Normalize each atom's density to sum to 1 over the grid (dim=0).
            # NOTE(review): an atom far enough from every grid point makes
            # this sum underflow to 0 and yields NaN — confirm upstream
            # filtering keeps atoms near the grid.
            density /= torch.sum(density, dim=0)

            # Atoms of one environment are expected to be contiguous; a new
            # block starts wherever the environment index changes.
            change_mask_j = env_index_j[:-1] != env_index_j[1:]
            block_starts = torch.cat(
                [
                    torch.zeros(1, dtype=torch.long, device=change_mask_j.device),
                    torch.nonzero(change_mask_j).flatten() + 1,
                ]
            )
            target_rows = env_index_j[block_starts].tolist()
            starts = block_starts.tolist()
            stops = starts[1:] + [n_atoms_j]

            for row, index_0, index_1 in zip(target_rows, starts, stops):
                fields = torch.reshape(
                    torch.sum(density[:, index_0:index_1], dim=1),
                    grid_shape,  # back to the cuboid grid shape
                )
                # Accumulate so that a non-contiguous environment would still
                # be summed rather than overwritten.
                fields_torch[row, j] += fields
        return fields_torch


class DownstreamModel(torch.nn.Module):
    """
    Small fully connected downstream network.

    Layout: 44 inputs -> 10 -> 10 -> 1 output, with a ReLU after each of the
    two hidden linear layers and no activation on the output.
    """

    def __init__(self):
        super().__init__()

        nn = torch.nn
        # Layers are created in forward order (lin1, lin2, lin3).
        self.lin1 = nn.Sequential(nn.Linear(44, 10), nn.ReLU())
        self.lin2 = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
        self.lin3 = nn.Linear(10, 1)

    def forward(self, x):
        """Apply the three layers in sequence and return the raw output."""
        hidden = self.lin2(self.lin1(x))
        return self.lin3(hidden)


class DDGDataset(Dataset):
    """
    Torch dataset that serves the rows of a ddG dataframe.

    Parameters
    ----------
    df: pd.DataFrame
        Dataframe whose rows are the samples.
    transformer: Callable
        Optional callable applied to each row before it is returned.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        transformer: Callable = None,
    ):
        self._df = df
        self.transformer = transformer  # stored through the property setter

    @property
    def df(self):
        """The wrapped dataframe."""
        return self._df

    @property
    def transformer(self):
        """Callable applied to each row on access (or None)."""
        return self._transformer

    @transformer.setter
    def transformer(self, transformer):
        """TODO: Think if a constraint to add later"""
        self._transformer = transformer

    def __len__(self):
        n_rows, _ = self.df.shape
        return n_rows

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        return self.transformer(row) if self.transformer else row


class DDGToTensor:
    """
    To-tensor transformer for ddG dataframe data
    """

    def __call__(self, sample: pd.Series):
        """
        Turn one dataframe row into a model input / target pair.

        Parameters
        ----------
        sample: pd.Series
            Row providing "wt_idx", "mt_idx" (positions in range(20)),
            "wt_nll", "mt_nll", "wt_nlf", "mt_nlf" and "ddg".

        Returns
        -------
        dict
            "x_": float32 tensor of length 44 = one-hot of wt_idx (20) +
            one-hot of mt_idx (20) + the four nll/nlf values;
            "y_": the row's "ddg" value, unchanged.
        """
        # A row taken from an all-numeric dataframe is upcast to float, and a
        # float cannot index a NumPy array; cast the indices explicitly.
        wt_onehot = np.zeros(20)
        wt_onehot[int(sample["wt_idx"])] = 1.0
        mt_onehot = np.zeros(20)
        mt_onehot[int(sample["mt_idx"])] = 1.0

        x_ = torch.cat(
            [
                torch.Tensor(wt_onehot),
                torch.Tensor(mt_onehot),
                torch.Tensor(
                    [
                        sample["wt_nll"],
                        sample["mt_nll"],
                        sample["wt_nlf"],
                        sample["mt_nlf"],
                    ]
                ),
            ]
        )

        return {"x_": x_, "y_": sample["ddg"]}

## matfact helper

In [None]:
from typing import Dict, List, Union, Tuple

import numpy as np
import pandas as pd
import torch
import timeit
import pickle
import glob
from torch.utils.data import DataLoader, Dataset

# from cavity_model import (
#     CavityModel,
#     ResidueEnvironmentsDataset,
#     ToTensor,
#     DDGDataset,
#     DDGToTensor,
#     DownstreamModel,
# )


def _format_epoch_metrics(split: str, metrics: Dict) -> str:
    """Render the one-line loss/accuracy summary printed for a data split."""
    return (
        f"{split.upper():5s} - Loss: {metrics['loss']:5.3f}, "
        f"Join acc: {metrics['acc_join']:5.3f},  "
        f"Res1: {metrics['acc_res1']:4.2f}, "
        f"Res2: {metrics['acc_res2']:4.2f},  "
        f"Cond 2|1: {metrics['acc_res2_given_res1']:4.2f}, "
        f"Cond 1|2: {metrics['acc_res1_given_res2']:4.2f}."
    )


def _train(
    dataloader_train: DataLoader,
    dataloader_val: DataLoader,
    cavity_model_net: CavityModel,
    loss_function: torch.nn.CrossEntropyLoss,
    optimizer: torch.optim.Adam,
    EPOCHS: int,
    PATIENCE_CUTOFF: int,
    matfact_k: int,
    output_shape: int,
    folder="models/double_cav_models",
    model_name="model",
    resume=False,
    ):
    """
    Helper function to perform training loop for the Cavity Model.

    Epoch 0 only evaluates the freshly initialised model; epochs 1..EPOCHS
    train on `dataloader_train`. After every epoch the model is evaluated on
    `dataloader_val`, a checkpoint `{folder}/{model_name}_epoch_{epoch}.pt` is
    written and the metric history is pickled to
    `{folder}/metrics_{model_name}.pickle`.

    Early stopping: the monitored validation metric is "acc_join" (higher is
    better) when `loss_function.weight` is set, otherwise "loss" (lower is
    better). An epoch counts as an improvement when it beats the best value by
    more than 0.001; training stops once `patience` exceeds PATIENCE_CUTOFF.

    Parameters
    ----------
    dataloader_train, dataloader_val: DataLoader
        Batches for training and validation.
    cavity_model_net: CavityModel
        Model to train (its state is overwritten when resuming).
    loss_function: torch.nn.CrossEntropyLoss
        Loss; its `weight` attribute selects the early-stopping metric.
    optimizer: torch.optim.Adam
        Optimizer (its state is overwritten when resuming).
    EPOCHS: int
        Index of the last epoch to run.
    PATIENCE_CUTOFF: int
        Number of non-improving epochs tolerated.
    matfact_k, output_shape: int
        Forwarded to `_train_loop` / `_eval_loop`.
    folder, model_name: str
        Location and file stem of checkpoints and metrics.
    resume: bool
        Continue from the latest checkpoint of `model_name` if one exists.

    Returns
    -------
    str
        Path of the checkpoint of the best epoch.

    Raises
    ------
    ValueError
        If checkpoints of `model_name` already exist and `resume` is False.
    """

    current_best_epoch = -1
    curr_best_metric = 1e4
    patience = 0 # we start at 0
    current_epoch = -1

    early_stop_metric = "loss"
    if loss_function.weight is not None: # class weights in use -> monitor accuracy
        early_stop_metric = "acc_join"
        curr_best_metric = -1

    # Make sure checkpoints can be written before spending time on training.
    if folder:
        os.makedirs(folder, exist_ok=True)

    # Epochs for which a checkpoint of exactly this model already exists.
    # The same pattern is used to detect and to parse checkpoints, so other
    # files that merely start with `model_name` are ignored.
    epochs_saved = []
    for path in glob.glob(f"{folder}/{model_name}_epoch_*.pt"):
        match = re.search(r"_epoch_(\d+)\.pt$", path)
        if match:
            epochs_saved.append(int(match.group(1)))

    # Resume training
    if epochs_saved:
        if not resume:
            raise ValueError(f"Epoch file(s) already exist(s) for {model_name}!")

        current_epoch = max(epochs_saved)
        model_path = f"{folder}/{model_name}_epoch_{current_epoch}.pt"
        checkpoint = torch.load(model_path)

        assert checkpoint["epoch"] == current_epoch, (
            f"checkpoint epoch {checkpoint['epoch']} "
            f"does not match current epoch {current_epoch}"
        )

        print(f"Training resumed from epoch {current_epoch}.\n")

        current_best_epoch = checkpoint["current_best_epoch"]
        curr_best_metric = checkpoint[f"current_best_{early_stop_metric}"]
        patience = checkpoint["patience"]

        cavity_model_net.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        print(f"Current best epoch: {current_best_epoch}, "
              f"{early_stop_metric}: {curr_best_metric:5.3f}, "
              f"Patience: {patience}.")
        print()

    print(
        f"- Starts training with '{early_stop_metric}' "
        f"as early stop metric...\n")

    # Metrics of the current epoch, keyed by split.
    rec = dict()

    # Run model.
    for epoch in range(current_epoch+1, EPOCHS+1): # EPOCHS+1 since 0 == ini state
        t1 = timeit.default_timer()

        # Assess model's initial state.
        if epoch == 0:
            print("Evaluating randomly initialized model.")
            loss_train, rec["train"] = _eval_loop(cavity_model_net,
                                                  dataloader_train,
                                                  loss_function,
                                                  matfact_k,
                                                  output_shape,
                                                  )
        # Train over train batches.
        else:
            loss_train, rec["train"] = _train_loop(cavity_model_net,
                                                   dataloader_train,
                                                   optimizer,
                                                   loss_function,
                                                   matfact_k,
                                                   output_shape,
                                                   )

        rec["train"]["loss"] = loss_train
        print(_format_epoch_metrics("train", rec["train"]))

        # Validate over val batches.
        loss_val, rec["val"] = _eval_loop(cavity_model_net,
                                          dataloader_val,
                                          loss_function,
                                          matfact_k,
                                          output_shape,
                                          )
        rec["val"]["loss"] = loss_val

        # Show epoch result
        print(_format_epoch_metrics("val", rec["val"]))

        # Early-stopping bookkeeping: accuracy must rise, loss must fall, by
        # more than 0.001. For the loss, epoch 0 always sets the baseline.
        val_metric = rec["val"][early_stop_metric]
        if early_stop_metric == "acc_join":
            improved = round(val_metric - curr_best_metric, 5) > 0.001
        else:
            improved = (
                round(curr_best_metric - val_metric, 5) > 0.001 or epoch == 0
            )

        if improved:
            curr_best_metric = val_metric
            current_best_epoch = epoch
            patience = 0
        else:
            patience += 1

        print(
            f"Epoch {epoch:2d} done in {round(timeit.default_timer() - t1, 2)} "
            f"sec.  Patience: {patience}")
        print()

        # Save training states for future resuming. The metric is stored as a
        # built-in float so the checkpoint holds no NumPy scalar objects.
        state = {
            "epoch": epoch,
            "model_state_dict": cavity_model_net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "patience": patience,
            f"current_best_{early_stop_metric}": float(curr_best_metric),
            "current_best_epoch": current_best_epoch,
        }
        model_path = f"{folder}/{model_name}_epoch_{epoch}.pt"

        torch.save(state, model_path)

        # Keep track of the metrics: one list per split and metric, one entry
        # per epoch.
        rec_path = f"{folder}/metrics_{model_name}"

        if epoch > 0:
            # The pickle is the file this function wrote for an earlier epoch;
            # do not point this at files from untrusted sources.
            with open(f"{rec_path}.pickle", "rb") as handle:
                history = pickle.load(handle)
            for key in rec:
                for metric in rec[key]:
                    history[key][metric].append(rec[key][metric])
        else:
            history = dict()
            for key in rec:
                history[key] = dict()
                for metric in rec[key]:
                    history[key][metric] = [rec[key][metric]]
        # Recorded for every epoch (including epoch 0) so the key always exists.
        history["best_epoch"] = current_best_epoch

        # Pickle rec + best epoch.
        with open(f"{rec_path}.pickle", "wb") as handle:
            pickle.dump(history, handle)

        # Assess Early stopping.
        if patience > PATIENCE_CUTOFF:
            print("Early stopping activated.")
            break

    best_model_path = f"{folder}/{model_name}_epoch_{current_best_epoch}.pt"
    print(
        f"Best epoch idx: {current_best_epoch} with validation {early_stop_metric}: "
        f"{curr_best_metric:5.3f}\nFound at: "
        f"'{best_model_path}'"
    )

    return best_model_path


def _train_loop(
    cavity_model_net: CavityModel,
    dataloader_train: DataLoader,
    optimizer: torch.optim.Adam,
    loss_function: torch.nn.CrossEntropyLoss,
    matfact_k: int,
    output_shape: int,
    ) -> Tuple[float, Dict]:
    """
    Helper function to perform a train loop (one pass over the training data).

    For every batch the model output is split at `output_shape // 2` into two
    factors of shape (20, matfact_k) and (matfact_k, 20); their matrix product
    gives the 20 x 20 = 400 pair logits that are fed to the loss.

    Parameters
    ----------
    cavity_model_net: CavityModel
        Model to optimise; switched to training mode.
    dataloader_train: DataLoader
        Yields `(batch_x, batch_y)`; the target class is `argmax(batch_y, dim=1)`.
    optimizer: torch.optim.Adam
        Optimizer stepped once per batch.
    loss_function: torch.nn.CrossEntropyLoss
        Loss applied to the 400 pair logits.
    matfact_k: int
        Inner dimension of the matrix factorisation.
    output_shape: int
        Width of the model output.

    Returns
    -------
    tuple
        Mean of the batch losses and the result of `_get_accuracies` computed
        from the per-batch true / predicted (res1, res2) index pairs.
    """
    n_aa = 20  # amino acid classes per residue; the joint label has n_aa**2 classes

    labels_true = []
    labels_pred = []
    loss_batch_list = []

    idx_res_split = output_shape // 2

    cavity_model_net.train()
    for batch_x, batch_y in tqdm.tqdm(dataloader_train,
                                      total=len(dataloader_train),
                                      unit="batch",
                                    ):
        optimizer.zero_grad()

        batch_y_pred = cavity_model_net(batch_x)

        # Split predictions in (20, k) x (k, 20) for matrix factorization
        batch_y_pred_res1 = batch_y_pred[:, :idx_res_split].reshape(
            -1, n_aa, matfact_k)
        batch_y_pred_res2 = batch_y_pred[:, idx_res_split:].reshape(
            -1, matfact_k, n_aa)

        batch_y_pred = (batch_y_pred_res1 @ batch_y_pred_res2).reshape(
            -1, n_aa * n_aa)

        # Flat index of the true pair; reused for the loss and the accuracies.
        batch_y_idx = torch.argmax(batch_y, dim=1)

        loss_batch = loss_function(batch_y_pred, batch_y_idx)
        loss_batch.backward()
        optimizer.step()

        loss_batch_list.append(loss_batch.item())

        # Save (res1, res2) index pairs for the joint, conditional and
        # marginal accuracies: one row per sample.
        labels_true.append(
            np.vstack(np.unravel_index(
                batch_y_idx.detach().cpu().numpy(),
                (n_aa, n_aa)
                )).T
        )

        labels_pred.append(
            np.vstack(np.unravel_index(
                torch.argmax(batch_y_pred, dim=1).detach().cpu().numpy(),
                (n_aa, n_aa)
                )).T
        )

    loss_train = np.mean(loss_batch_list)

    return (loss_train, _get_accuracies(
                                    labels_true,
                                    labels_pred)
    )


def _eval_loop(
    cavity_model_net: CavityModel,
    dataloader_val: DataLoader,
    loss_function: torch.nn.CrossEntropyLoss,
    matfact_k: int,
    output_shape: int,
    **kwargs
    ) -> Tuple:
    """
    Evaluate the model on `dataloader_val` without computing gradients.

    The data is processed batch by batch (the whole set is not passed to
    the model at once because of memory). Logits are built exactly as in
    the training loop: the network output is cut at `output_shape // 2`,
    viewed as (batch, 20, matfact_k) and (batch, matfact_k, 20), multiplied
    and flattened to (batch, 400).

    Extra keyword arguments are forwarded to `_get_accuracies`.

    Returns a tuple `(loss_val, accuracies)`: the mean of the per-batch
    losses and the record built by `_get_accuracies`.
    """

    def _to_pair_labels(flat_labels):
        # Flat class index -> one (res1, res2) row per sample.
        return np.vstack(
            np.unravel_index(flat_labels.detach().cpu().numpy(), (20, 20))
        ).T

    half = output_shape // 2
    batch_losses = []
    true_pairs = []
    pred_pairs = []

    cavity_model_net.eval()
    with torch.set_grad_enabled(False):
        progress = tqdm.tqdm(
            dataloader_val,
            total=len(dataloader_val),
            unit="batch",
            leave=False,
        )
        for batch_x, batch_y in progress:
            raw_out = cavity_model_net(batch_x)

            # Low-rank factorisation: (20, k) @ (k, 20) -> 20 x 20 logits.
            left = raw_out[:, :half].reshape(-1, 20, matfact_k)
            right = raw_out[:, half:].reshape(-1, matfact_k, 20)
            logits = (left @ right).reshape(-1, 400)

            target = torch.argmax(batch_y, dim=1)
            batch_losses.append(loss_function(logits, target).item())

            # Keep labels as residue pairs for the accuracy computation.
            true_pairs.append(_to_pair_labels(target))
            pred_pairs.append(_to_pair_labels(torch.argmax(logits, dim=1)))

    return (
        np.mean(batch_losses),
        _get_accuracies(true_pairs, pred_pairs, **kwargs),
    )


def _get_accuracies(labels_true: List[int],
                    labels_pred: List[int],
                    get_restypes_acc=False,
                    keep_pair_order=True,
    ):
    """ 
    compute join, marginal, conditional and restype accuracies
    from lists of true and predicted labels.
    """
    rec = dict()

    # Create arrays from lists.
    # labels_true = np.reshape(labels_true, (-1, 2))
    labels_true = np.vstack(labels_true)
    # labels_pred = np.reshape(labels_pred, (-1, 2))
    labels_pred = np.vstack(labels_pred)

    # joint accuracy
    mask_join = np.logical_and(labels_true[:, 0] == labels_pred[:, 0],
                               labels_true[:, 1] == labels_pred[:, 1])

    rec["acc_join"] = np.mean(mask_join)

    # marginal accuracies

    mask_res1_true = (labels_true[:, 0] == labels_pred[:, 0])
    mask_res2_true = (labels_true[:, 1] == labels_pred[:, 1])

    rec["acc_res1"] = np.mean(mask_res1_true)
    rec["acc_res2"] = np.mean(mask_res2_true)

    # conditional accuracies

    mask_res1_true = mask_res1_true.nonzero()
    mask_res2_true = mask_res2_true.nonzero()
    
    rec["acc_res2_given_res1"] = np.mean(
        labels_true[mask_res1_true, 1] == labels_pred[mask_res1_true, 1]
        )
    rec["acc_res1_given_res2"] = np.mean(
        labels_true[mask_res2_true, 0] == labels_pred[mask_res2_true, 0]
        )

    if get_restypes_acc:
        # save restypes accuracies
        pairres_count_true = np.zeros((20, 20), dtype=np.int)
        pairres_count = np.zeros((20, 20), dtype=np.int)
        pairres_count_true = np.zeros((20, 20), dtype=np.int)
        pairres_count = np.zeros((20, 20), dtype=np.int)

        # Get accuracies per residue type.
        mask_join = mask_join.nonzero()
        
        if keep_pair_order:
        # if order matters (can retrieve easily marginals):
            mask_pairres_count_true, count_true = np.unique(
                labels_pred[mask_join],
                return_counts=True, axis=0
                )
            mask_pairres_count, count = np.unique(
                labels_pred, return_counts=True, axis=0
                )

        else:
            # if order does not matter:
            mask_pairres_count_true, count_true = np.unique(
                np.sort(labels_pred[mask_join], axis=1),
                return_counts=True, axis=0
                )
            mask_pairres_count, count = np.unique(
                np.sort(labels_pred, axis=1),
                return_counts=True, axis=0
                )

        pairres_count_true[mask_pairres_count_true[:, 0],
                        mask_pairres_count_true[:, 1]] += count_true
        pairres_count[mask_pairres_count[:, 0],
                    mask_pairres_count[:, 1]] += count

        pairres_count[pairres_count == 0] = 1 # avoid 0 division.

        rec["pairres_count_true"] = pairres_count_true
        rec["pairres_count"] = pairres_count
    else:
        pass

    return rec


def _train_val_split(
    parsed_pdb_filenames: List[str],
    TRAIN_VAL_SPLIT: float,
    DEVICE: str,
    BATCH_SIZE: int,
    **kwargs
    ):
    """
    Build the training and validation datasets and dataloaders.

    The split is made on the list of PDB files, not on individual residue
    environments, to avoid leakage between the two sets: the first
    `int(len(parsed_pdb_filenames) * TRAIN_VAL_SPLIT)` files go to
    training, the rest to validation. Extra keyword arguments are
    forwarded to `ToTensor`.

    Returns `(dataloader_train, dataset_train, dataloader_val, dataset_val)`.
    """
    split_at = int(len(parsed_pdb_filenames) * TRAIN_VAL_SPLIT)
    filenames_train, filenames_val = (
        parsed_pdb_filenames[:split_at],
        parsed_pdb_filenames[split_at:],
    )

    # One transformer instance is shared by both datasets and both loaders.
    to_tensor_transformer = ToTensor(DEVICE, **kwargs)

    # Options common to both loaders. The collate function comes from the
    # transformer; drop_last=True discards a final incomplete batch.
    # TODO: Fix it so drop_last doesn't have to be True when calculating
    # validation accuracy.
    shared_loader_options = dict(
        batch_size=BATCH_SIZE,
        collate_fn=to_tensor_transformer.collate_cat,
        drop_last=True,
    )

    dataset_train = ResidueEnvironmentsDataset(
        filenames_train, transformer=to_tensor_transformer
    )
    dataloader_train = DataLoader(
        dataset_train, shuffle=True, **shared_loader_options
    )
    print(
        f"Training data set includes {len(filenames_train)} pdbs with "
        f"{len(dataset_train)} environments."
    )

    dataset_val = ResidueEnvironmentsDataset(
        filenames_val, transformer=to_tensor_transformer
    )
    dataloader_val = DataLoader(
        dataset_val, shuffle=False, **shared_loader_options
    )
    print(
        f"Validation data set includes {len(filenames_val)} pdbs with "
        f"{len(dataset_val)} environments."
    )

    return dataloader_train, dataset_train, dataloader_val, dataset_val


def get_test_dataloader(
    test_filenames: List[str],
    BATCH_SIZE: int,
    DEVICE: str,
    reshape_index=True,
    unravel_index=True,
    ):
    """
    Build the dataset and dataloader used for testing.

    The loader keeps the file order (no shuffling) and keeps the last
    incomplete batch. `reshape_index` and `unravel_index` are passed to
    `ToTensor`.

    Returns `(dataset_test, dataloader_test)`.
    """
    transformer_options = {
        "unravel_index": unravel_index,
        "reshape_index": reshape_index,
    }
    to_tensor_transformer = ToTensor(DEVICE, **transformer_options)

    dataset_test = ResidueEnvironmentsDataset(
        test_filenames, transformer=to_tensor_transformer
    )

    loader_options = dict(
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=to_tensor_transformer.collate_cat,
        drop_last=False,
    )
    dataloader_test = DataLoader(dataset_test, **loader_options)

    print(
        f"Testing data set includes {len(test_filenames)} pdbs with "
        f"{len(dataset_test)} environments."
    )

    return dataset_test, dataloader_test


def _test(
    cavity_model_net: CavityModel,
    dataloader_test: DataLoader,
    loss_function: torch.nn.CrossEntropyLoss,
    matfact_k: int,
    output_shape: int,
    get_restypes_acc=True,
    keep_pair_order=True,
):
    """
    Evaluate the model on a test dataloader.

    Thin wrapper around `_eval_loop` that by default also requests the
    per-residue-pair counts. Returns whatever `_eval_loop` returns.
    """
    accuracy_options = {
        "get_restypes_acc": get_restypes_acc,
        "keep_pair_order": keep_pair_order,
    }
    return _eval_loop(
        cavity_model_net,
        dataloader_test,
        loss_function,
        matfact_k,
        output_shape,
        **accuracy_options,
    )


def _predict(
    cavity_model_net: CavityModel,
    dataloader_infer: DataLoader,
    matfact_k: int,
    output_shape: int,
    ):
    """
    Get predicted proba distribution per pair_res environment.
    Made for making prediction one protein at a time, returning an array
    (n_pairs, 400) long.
    """

    labels_true = []
    idx_res_split = output_shape // 2

    softmax = torch.nn.Softmax(dim=1)
    cavity_model_net.eval()
    with torch.set_grad_enabled(False):
        for batch_x, batch_y in dataloader_infer:

            batch_y_pred = cavity_model_net(batch_x)

            # Split predictions in (20, k) x (20, k) for matrix factorization
            batch_y_pred_res1 = batch_y_pred[:, :idx_res_split].reshape(
                -1, 20, matfact_k)
            batch_y_pred_res2 = batch_y_pred[:, idx_res_split:].reshape(
                -1, matfact_k, 20)

            batch_y_pred = (batch_y_pred_res1 @ batch_y_pred_res2).reshape(
                -1, 400)
            batch_y_pred = softmax(batch_y_pred).detach().cpu().numpy()

            # Save true labels
            labels_true.append(
                np.vstack(np.unravel_index(
                    torch.argmax(batch_y, dim=1).detach().cpu().numpy(),
                    (20, 20)
                    )).T
            )

    return batch_y_pred, labels_true


def get_best_epoch_perf(model_name: str= "",
                        models_dirpath="models/double_cav_models/"):
    """
    Load the pickled training history of `model_name` and summarise the
    best epoch as a three-line string.

    The history file `metrics_<model_name>.pickle` in `models_dirpath`
    must contain a "best_epoch" entry plus, for every other key, a mapping
    from metric name to a per-epoch sequence. The "train" and "val"
    entries are reported (loss, joint accuracy, marginal and conditional
    accuracies).
    """
    # NOTE: pickle.load executes arbitrary code; only open trusted files.
    with open(f"{models_dirpath}/metrics_{model_name}.pickle", "rb") as f:
        history = pickle.load(f)
    best_epoch = history.pop("best_epoch")

    # Value of every metric at the best epoch, for every remaining key.
    at_best = {
        split: {
            metric: history[split][metric][best_epoch]
            for metric in history[split]
        }
        for split in history
    }

    def _describe(split):
        # One report line for a single split ("train" or "val").
        perf = at_best[split]
        return (
            f"{split.upper():5s} - "
            f"Loss: {perf['loss']:5.3f}, "
            f"Join acc: {perf['acc_join']:5.3f},  "
            f"Res1: {perf['acc_res1']:4.2f}, "
            f"Res2: {perf['acc_res2']:4.2f},  "
            f"Cond 2|1: {perf['acc_res2_given_res1']:4.2f}, "
            f"Cond 1|2: {perf['acc_res1_given_res2']:4.2f}."
        )

    report_lines = [
        f"Best epoch: {best_epoch}",
        _describe("train"),
        _describe("val"),
    ]
    return "\n".join(report_lines)


# Tools for saving model summary (objective: combine the 2!)
def get_df_summary(text: str):
    """
    Parse a textual model summary into a three-column DataFrame.

    Each line of `text` is split on runs of two spaces; the first three
    non-empty fields become the "Layer (type)", "Output shape" and
    "Param #" columns. Lines with fewer than three fields (blank lines,
    horizontal rules, totals) are skipped instead of raising IndexError.

    Returns a `pd.DataFrame` with one row per parsed line.
    """
    keys = ["Layer (type)", "Output shape", "Param #"]
    df_summary = {k: [] for k in keys}

    for line in text.split("\n"):
        fields = [el for el in line.split("  ") if el != ""]
        if len(fields) < len(keys):
            # Not a table row: nothing to record.
            continue
        for key, field in zip(keys, fields):
            df_summary[key].append(field)

    return pd.DataFrame(df_summary)


def get_and_save_model_summary(model: CavityModel,
                               input_size: tuple,
                               model_name: str = "cavity_model"):
    """
    Capture the `torchsummary` report of `model`, save it and return it.

    `torchsummary.summary` prints to stdout; the output is captured in an
    in-memory buffer, written to `models/<model_name>_summary.txt`, and
    returned as a string.

    Fixes over the previous version: `model_name` now really defaults to
    "cavity_model" (it was written as an annotation, leaving the argument
    required), and the captured text is actually written to the file (the
    file handle used to be shadowed by the buffer, leaving the file empty).
    """
    import io
    from torchsummary import summary
    from contextlib import redirect_stdout

    # Temporarily redirect sys.stdout to an in-memory buffer.
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        summary(model=model, input_size=input_size)
    out = buffer.getvalue()

    with open(f'models/{model_name}_summary.txt', 'w') as f:
        f.write(out)

    return out


# Tools for sending run's completion notification
def test_login_smtp_server(
    sender_email = ".@gmail.com",
    receiver_email = ".@gmail.com",
    ):
    """
    Prompt for the sender's password and try to log in to the SMTP server.

    Any error raised while connecting or logging in is printed rather
    than raised, and the typed password is returned in every case so it
    can be reused to send the results later.

    In case of error 534: # https://accounts.google.com/DisplayUnlockCaptcha
    Debug ref: https://stackoverflow.com/questions/16512592/login-credentials-not-working-with-gmail-smtp
    """
    port = 587  # For starttls (tls encryption protocol)
    smtp_server = "smtp.gmail.com"

    password = getpass("Type password: ")

    # Create a secure SSL context
    context = ssl.create_default_context()

    # The context manager closes the connection on exit. The previous
    # `finally: server.quit()` failed with an unbound `server` whenever
    # the connection itself could not be opened.
    try:
        with smtplib.SMTP(smtp_server, port) as server:
            server.ehlo() # Can be omitted
            server.starttls(context=context) # Secure the connection
            server.ehlo() # Can be omitted
            server.login(sender_email, password)
    except Exception as e:
        # Best-effort check: report the problem instead of raising.
        print(e)
    return password


def send_run_results(
    h: dict,
    password: str,
    models_dirpath="models/double_cav_models/",
    sender_email=".@gmail.com",
    receiver_email=".@gmail.com",
    ):
    """
    E-mail the best-epoch performance and hyperparameters of a finished run.

    Parameters
    ----------
    h:
        Hyperparameter record; `h["model_name"]` names the run, every
        other entry is listed as `key: value` in the message body.
    password:
        Password used to log in as `sender_email`.
    models_dirpath:
        Directory passed to `get_best_epoch_perf` to locate the metrics.
    sender_email, receiver_email:
        Addresses used for the SMTP login and the message headers.

    This function used to be defined twice with identical bodies; the
    duplicate has been removed.
    """
    port = 587  # For starttls (tls encryption protocol)
    smtp_server = "smtp.gmail.com"

    subject_email = f"run of Model: {h['model_name']} completed."

    perf_best_epoch = get_best_epoch_perf(h["model_name"],
                                          models_dirpath=models_dirpath)

    hyperparam_records = "Hyperparameters:\n----------------\n" + "".join(
        f"{key}: {h[key]}\n" for key in h if key != "model_name"
    )

    message = "\n".join([
        f"From: {sender_email}",
        f"To: {receiver_email}",
        f"Subject: {subject_email}",
        "",
        perf_best_epoch,
        "",
        hyperparam_records,
    ])
    # Every line is left-stripped, as before, so indented values do not
    # end up indented in the mail.
    message = "\n".join([line.lstrip() for line in message.split("\n")])

    context = ssl.create_default_context()
    with smtplib.SMTP(smtp_server, port) as server:
        server.ehlo()  # Can be omitted
        server.starttls(context=context)
        server.ehlo()  # Can be omitted
        server.login(sender_email, password)
        server.sendmail(sender_email, receiver_email, message)

        print(f"Email sent to {receiver_email}.")

# Parse .pdb files

In [None]:
from pdb_parser_scripts.clean_pdb import clean_pdb
import traceback

# !chmod +x reduce/reduce
# !chmod +x pdb_parser_scripts/clean_pdb.py
# !chmod +x pdb_parser_scripts/extract_pair_environments.py

In [None]:
# List the paths matching data/pdbs/raw_tm/*; as the last expression of the
# cell, the resulting list is shown as the cell output.
glob.glob("data/pdbs/raw_tm/*")

In [None]:
# Create the output directories for the cleaned and parsed PDB files.
# exist_ok=True makes each call independent: previously, if "cleaned"
# already existed, the FileExistsError skipped the creation of "parsed".
os.makedirs("data/pdbs/cleaned", exist_ok=True)
os.makedirs("data/pdbs/parsed", exist_ok=True)

In [None]:
# Paths of the PDB files to clean (placeholder list; fill in before running).
pdb_files = [""]

# Clean every file, collecting the ones that fail instead of stopping.
fails = []
for i, pdb_filename in enumerate(pdb_files):
    try:
        # presumably (input file, output directory, path to the `reduce`
        # executable) -- confirm against pdb_parser_scripts.clean_pdb.
        # NOTE(review): "/reduce/reduce" is an absolute path, while the
        # chmod hint in the import cell uses the relative "reduce/reduce".
        clean_pdb(pdb_filename, "data/pdbs/cleaned", "/reduce/reduce")
        pdb_filename = os.path.basename(pdb_filename)
        print(f"{pdb_filename} cleaned successfully ({i+1}/{len(pdb_files)})")

    except Exception:
        # Best effort: record the failing file and show the traceback.
        print(f"{pdb_filename} failed. Nb: {i+1}.")
        fails.append(pdb_filename)

        error_msg = traceback.format_exc()
        print(error_msg)

In [None]:
# Read two PDB files and collect, for each, the one-letter sequence and the
# residue ids of every residue with a standard three-letter name. The first
# file is stored under the tag "str", the second under "seq".
# NOTE(review): this cell needs `simtk.openmm.app`, whose import is
# commented out in the setup cell at the top of the notebook.
pdb_files = ["", ""]
seq = {}
res_nb = {}
for pdb_file, tag in zip(pdb_files, ["str", "seq"]):
    print(os.path.basename(pdb_file))
    pdb_simtk = simtk.openmm.app.PDBFile(pdb_file)
    seq[tag] = []
    res_nb[tag] = []
    for chain in pdb_simtk.getTopology().chains():
        for res in chain.residues():
            try:
                seq[tag].append(Bio.PDB.Polypeptide.three_to_one(res.name))
                res_nb[tag].append(res.id)
            except KeyError:
                # Unknown three-letter code (e.g. ligand, water): skip it.
                # A bare `except:` used to hide every other error as well.
                print("error for", res.name)



In [None]:
# Select which "seq" residues line up with the "str" residues: position 0
# and positions 90-98 are dropped here, and `[:-1]` below drops the last
# remaining one. `np.bool` was removed from NumPy; the builtin `bool` is
# the same dtype.
# NOTE(review): the dropped positions and the -20 / -2 offsets look
# hand-tuned for one specific pair of structures -- confirm before reuse.
mask = np.ones_like(seq["seq"], dtype=bool)
mask[0] = 0
mask[90:99] = 0

compatibility_table = pd.DataFrame({"seq_str": np.array(seq["str"]), "seq_seq": np.array(seq["seq"])[mask][:-1],
              "res_nb_str": np.array(res_nb["str"], dtype=int)-20,
              "res_nb_seq": np.array(res_nb["seq"], dtype=int)[mask][:-1]-2})

# create a number correction mapper: table row position -> corrected
# residue number of the "str" structure.
nb_mapper = {}
for wrong_nb, corr_nb in zip(list(compatibility_table.index), compatibility_table.res_nb_str):
    nb_mapper[wrong_nb] = int(corr_nb)

In [None]:
from pdb_parser_scripts import (
    extract_pair_environments,
    grid
)

# Extract residue-pair environments from every cleaned PDB file into
# data/pdbs/parsed_325. This cell is identical to the parsed_450 cell
# except for OUT_DIR and max_radius (3.25 here).
pdb_files = sorted(glob.glob("data/pdbs/cleaned/*_clean.pdb"))

OUT_DIR = "data/pdbs/parsed_325"
if not os.path.isdir(OUT_DIR):
    os.makedirs(OUT_DIR)

# Files whose extraction raised; the loop keeps going after a failure.
fails=[]
for i, filename in enumerate(pdb_files):
    pdb_filename = os.path.basename(filename)
    # File name up to the first dot is used as the identifier.
    pdb_id = pdb_filename.split(".")[0]
    try:
        extract_pair_environments.extract_environments(filename,
                                                       pdb_id,
                                                       out_dir=OUT_DIR,
                                                       max_width_x=4.5,
                                                       max_width_y=4.5,
                                                       max_height=9.0,
                                                    #    ca_ca_cutoff=7.0,
                                                       ca_ca_dist_based=False,
                                                       max_radius=3.25,
                                                       min_radius=0,
                                                       )
        print(f"{pdb_filename} successful ({i+1}/{len(pdb_files)})")

    except Exception:
        # Best effort: record the failing file and show the traceback.
        print(f"{pdb_filename} failed. Nb: {i+1}.")
        fails.append(filename)
        print(traceback.format_exc())

In [None]:
from pdb_parser_scripts import (
    extract_pair_environments,
    grid
)

# Extract residue-pair environments from every cleaned PDB file into
# data/pdbs/parsed_450. This cell is identical to the parsed_325 cell
# except for OUT_DIR and max_radius (4.5 here).
pdb_files = sorted(glob.glob("data/pdbs/cleaned/*_clean.pdb"))

OUT_DIR = "data/pdbs/parsed_450"
if not os.path.isdir(OUT_DIR):
    os.makedirs(OUT_DIR)

# Files whose extraction raised; the loop keeps going after a failure.
fails=[]
for i, filename in enumerate(pdb_files):
    pdb_filename = os.path.basename(filename)
    # File name up to the first dot is used as the identifier.
    pdb_id = pdb_filename.split(".")[0]
    try:
        extract_pair_environments.extract_environments(filename,
                                                       pdb_id,
                                                       out_dir=OUT_DIR,
                                                       max_width_x=4.5,
                                                       max_width_y=4.5,
                                                       max_height=9.0,
                                                    #    ca_ca_cutoff=7.0,
                                                       ca_ca_dist_based=False,
                                                       max_radius=4.5,
                                                       min_radius=0,
                                                       )
        print(f"{pdb_filename} successful ({i+1}/{len(pdb_files)})")

    except Exception:
        # Best effort: record the failing file and show the traceback.
        print(f"{pdb_filename} failed. Nb: {i+1}.")
        fails.append(filename)
        print(traceback.format_exc())

In [None]:
# Load the melting-temperature results, shorten two column names, drop
# the first data row and renumber the rows from 0.
ssbonds = (
    pd.read_table("data/tm_results.csv", delimiter=";")
    .rename(columns={"Tm (Celcius)": "tm", "construct_planned_mutations": "ss_mut"})
    .drop(index=0)
    .reset_index(drop=True)
)

# convert the strings "[res1_id, res2_id]" to separate columns, and proper tuple of res_id.
# `str.extract` yields exactly one row per input row and keeps the index,
# so the assignment below stays aligned. The previous
# `extractall(...).reset_index()` renumbered the matches and would shift
# rows if any string had zero or several matches. A string without a
# match now gives NaN and fails loudly in the loop below.
ssbonds_res = ssbonds.ss_mut.str.extract("'(?P<res1_id_temp>.*)', '(?P<res2_id_temp>.*)'")
ssbonds[["res1_id_temp", "res2_id_temp"]] = ssbonds_res

for i in range(2):
    # Each raw id is <one-letter residue><number><one-letter target>:
    # build [three-letter name, residue number, chain "A"] from it.
    ssbonds[f"res{i+1}_id"] = ssbonds[f"res{i+1}_id_temp"].apply(
        lambda x: [
            Bio.PDB.Polypeptide.one_to_three(x[0]),
            int(x[1:-1]),
            "A"
        ])

    # Integer index of the original residue type.
    ssbonds[f"res{i+1}"] = ssbonds[f"res{i+1}_id"].apply(
        lambda x:
            Bio.PDB.Polypeptide.three_to_index(x[0])
    )

    # Last character of the raw id (the mutation target).
    ssbonds[f"res{i+1}_target"] = ssbonds[f"res{i+1}_id_temp"].apply(
        lambda x: x[-1])

    ssbonds = ssbonds.drop(columns=f"res{i+1}_id_temp")

# TODO: add an assertion that, for all pairs, the id of res1 is < the id of res2 (otherwise invert them).

In [None]:
# Persist the parsed table; the "Get cavity model predictions" section
# reloads it from this path.
ssbonds.to_pickle(f"data/tm_ssbonds_table")

In [None]:
# Persist the residue-number mapper built above so later sections can
# reload it without re-parsing the PDB files.
with open("data/nb_mapper.pickle", "wb") as f:
    pickle.dump(nb_mapper, f)

## Get cavity model predictions

If you run into "CUDA out of memory" errors, switch to another GPU BEFORE importing torch:

In [None]:
# import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"]="3"

# import torch
# available_gpus = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
# print(available_gpus)

In [None]:
# !nvidia-smi

In [None]:
# Reload the table and the residue-number mapper saved by the parsing
# section, so this section can be run on its own.
# NOTE: unpickling executes arbitrary code; only load files you created.
ssbonds = pd.read_pickle("data/tm_ssbonds_table")

with open("data/nb_mapper.pickle", "rb") as f:
    nb_mapper = pickle.load(f)

In [None]:
# Reload the mapper (same file as in the previous cell) and invert it:
# nb_mapper maps key -> value, inv_mapper maps value -> key.
with open("data/nb_mapper.pickle", "rb") as f:
    nb_mapper = pickle.load(f)
inv_mapper = {v: k for k, v in nb_mapper.items()}
# The inversion only loses entries if two keys share a value; check it.
assert len(inv_mapper) == len(nb_mapper)

In [None]:
# Every row belongs to the same structure; "ss" = 1 flags the row as a
# disulfide pair (used as a merge key with the prediction table later).
ssbonds["pdb_id"] = "gcl1_apo_final_york_experimental_structure"
ssbonds["ss"] = 1

In [None]:
# Choose the dataset / model pair to evaluate (alternatives kept below).
# model_name = "m1_bigger_backbone_b1000_5atoms_matfact_1_lr_2" # parsed_450
# dset_name = "parsed_450"
# model_name = "m1_bigger_backbone_b1000_5atoms_matfact_1_lr_3_q_1" # parsed_325 (no weights)
dset_name = "parsed_325"
model_name = "m1_bigger_backbone_b1000_5atoms_matfact_1_lr_3_q_1_weighted" # parsed_325 (weighted)

# model_name = "idp_325_1"

MODELS_DIRPATH = "models/double_cav_models"
DEVICE = "cuda"

# Hyperparameters / metadata saved with the model.
with open(f"{MODELS_DIRPATH}/{model_name}_metadata.pickle", "rb") as f:
    h = pickle.load(f)

# Training history; its "best_epoch" entry selects the checkpoint to load.
with open(f"{MODELS_DIRPATH}/metrics_{model_name}.pickle", "rb") as f:
    history = pickle.load(f)

best_epoch = history["best_epoch"]
print("Best epoch:", best_epoch)
# Only the weights are kept from the checkpoint of the best epoch.
model_state = torch.load(f"{MODELS_DIRPATH}/{model_name}_epoch_{best_epoch}.pt")["model_state_dict"]

In [None]:
# Predict the residue-pair distributions for copy 0 of every structure and
# record which rows correspond to the known disulfide pairs.
pred_all_pairs = {}   # pdb_id -> {"preds": ..., "labels_true": ...}
ss_pairs_mask = {}    # pdb_id -> row positions of the disulfide pairs

for pdb_id, res1, res2 in zip(ssbonds.pdb_id, ssbonds.res1_id, ssbonds.res2_id):

    filename = f"data/pdbs/{dset_name}/{pdb_id}_clean_0_pair_res_features.npz"
    file_ = np.load(filename)

    try:
        mask = np.logical_and(file_["pair_res_indices"][:, 0] == inv_mapper[res1[1]],
                              file_["pair_res_indices"][:, 1] == inv_mapper[res2[1]])
    except KeyError:
        # Reset the mask so the one left over from the previous pair is
        # never reused (it used to be appended again for this pair, or to
        # raise a NameError on the very first iteration).
        mask = None
        print(f"{res1[1]} or {res2[1]} not included in .pdb structure.")

    if pdb_id not in pred_all_pairs:
        # First time this structure is seen: run the model once on all of
        # its pairs (the batch size is the number of pairs in the file).
        pred_all_pairs[pdb_id] = {}
        ss_pairs_mask[pdb_id] = []

        test_dataset, test_dataloader = get_test_dataloader(
                                            [filename],
                                            BATCH_SIZE=file_["pair_res_indices"].shape[0],
                                            DEVICE=DEVICE,
                                            reshape_index=False,
                                                    )
        # Define model
        cavity_model_net = h["cav_model"](DEVICE,
                                        grid_dim_xy=h["grid_dim_xy"],
                                        grid_dim_z=h["grid_dim_z"],
                                        n_atom_types=h["n_atom_types"]).to(DEVICE)

        cavity_model_net.load_state_dict(model_state)

        preds, labels_true = _predict(
            cavity_model_net,
            test_dataloader,
            h["matfact_k"],
            h["output_shape"])

        labels_true = np.reshape(labels_true, (-1, 2))

        pred_all_pairs[pdb_id]["preds"] = preds
        pred_all_pairs[pdb_id]["labels_true"] = labels_true

    if mask is not None and mask.any():
        # Row position of the first environment matching this pair.
        ss_pairs_mask[pdb_id].append(mask.nonzero()[0][0])
    else:
        print(
            f"ssbond pair ({res1[1]}-{res2[1]}) not found. "
        )


# Create a dataframe per pdb_id and concatenate them.
df = []
for i, pdb_id in enumerate(pred_all_pairs):
    df.append(pd.DataFrame({
        "pdb_id": pdb_id,
        "res1": [label[0] for label in pred_all_pairs[pdb_id]["labels_true"]],
        "res2":  [label[1] for label in pred_all_pairs[pdb_id]["labels_true"]],
        "preds": [pred for pred in pred_all_pairs[pdb_id]["preds"]],
        "ss": 0,
    })
    )
    try:
        df[i].loc[ss_pairs_mask[pdb_id], "ss"] = 1
    except (IndexError, KeyError) as err:
        # Keep going, but do not hide the inconsistency.
        print(f"Could not flag ssbond pairs for {pdb_id}: {err!r}")

df = pd.concat(df)

In [None]:
# Same as the previous cell, for copy 1 of every structure. In this copy
# the pair is looked up with the residues swapped (res2 in column 0).
pred_all_pairs = {}   # pdb_id -> {"preds": ..., "labels_true": ...}
ss_pairs_mask = {}    # pdb_id -> row positions of the disulfide pairs

for pdb_id, res1, res2 in zip(ssbonds.pdb_id, ssbonds.res1_id, ssbonds.res2_id):

    filename = f"data/pdbs/{dset_name}/{pdb_id}_clean_1_pair_res_features.npz"
    file_ = np.load(filename)

    try:
        mask = np.logical_and(file_["pair_res_indices"][:, 0] == inv_mapper[res2[1]],
                              file_["pair_res_indices"][:, 1] == inv_mapper[res1[1]])
    except KeyError:
        # Reset the mask so the one left over from the previous pair is
        # never reused (it used to be appended again for this pair, or to
        # raise a NameError on the very first iteration).
        mask = None
        print(f"{res1[1]} or {res2[1]} not included in .pdb structure.")

    if pdb_id not in pred_all_pairs:
        # First time this structure is seen: run the model once on all of
        # its pairs (the batch size is the number of pairs in the file).
        pred_all_pairs[pdb_id] = {}
        ss_pairs_mask[pdb_id] = []

        test_dataset, test_dataloader = get_test_dataloader(
                                            [filename],
                                            BATCH_SIZE=file_["pair_res_indices"].shape[0],
                                            DEVICE=DEVICE,
                                            reshape_index=False,
                                                    )
        # Define model
        cavity_model_net = h["cav_model"](DEVICE,
                                        grid_dim_xy=h["grid_dim_xy"],
                                        grid_dim_z=h["grid_dim_z"],
                                        n_atom_types=h["n_atom_types"]).to(DEVICE)

        cavity_model_net.load_state_dict(model_state)

        preds, labels_true = _predict(
            cavity_model_net,
            test_dataloader,
            h["matfact_k"],
            h["output_shape"])

        labels_true = np.reshape(labels_true, (-1, 2))

        pred_all_pairs[pdb_id]["preds"] = preds
        pred_all_pairs[pdb_id]["labels_true"] = labels_true

    if mask is not None and mask.any():
        # Row position of the first environment matching this pair.
        ss_pairs_mask[pdb_id].append(mask.nonzero()[0][0])
    else:
        print(
            f"ssbond pair ({res1[1]}-{res2[1]}) not found. "
        )


# Create a dataframe per pdb_id and concatenate them.
df_2 = []
for i, pdb_id in enumerate(pred_all_pairs):
    df_2.append(pd.DataFrame({
        "pdb_id": pdb_id,
        "res1": [label[0] for label in pred_all_pairs[pdb_id]["labels_true"]],
        "res2":  [label[1] for label in pred_all_pairs[pdb_id]["labels_true"]],
        "preds_2": [pred for pred in pred_all_pairs[pdb_id]["preds"]],
        "ss": 0,
    })
    )
    try:
        # Flag the rows of THIS cell's frame. The code used to index `df`,
        # the concatenated frame of the previous cell; the resulting
        # KeyError was silently swallowed and "ss" was never set.
        df_2[i].loc[ss_pairs_mask[pdb_id], "ss"] = 1
    except (IndexError, KeyError) as err:
        # Keep going, but do not hide the inconsistency.
        print(f"Could not flag ssbond pairs for {pdb_id}: {err!r}")

df_2 = pd.concat(df_2)

Average the predictions of the two copies of each .pdb structure.

In [None]:
# The copy-1 cell looks pairs up with the residues swapped, so swap the
# res1/res2 column names of its frame to match the copy-0 frame.
df_2.rename(columns={"res2": "res1", "res1": "res2"}, inplace=True)

# Mean of the two predictions; pandas aligns the two columns on the index.
# NOTE(review): this assumes row i of both copies describes the same pair
# -- confirm. Since copy 1 is swapped, presumably its 400-vector (20 x 20
# over res1, res2) should be transposed before averaging -- confirm.
df["avg_preds"] = (df.preds + df_2.preds_2) / 2

# NOTE: deleting df_2 and replacing "preds" makes this cell non-rerunnable
# without re-running the two prediction cells first.
del df_2

df = df.drop(columns=["preds"]).rename(columns={"avg_preds": "preds"})

In [None]:
# # Merge results with pairs of residues id.
ssbonds = pd.merge(df[df.ss == 1], ssbonds, on=["ss", "pdb_id", "res1", "res2"], left_index=True)
ssbonds["pair_nb"] = ssbonds.index

# Save results.
ssbonds.to_pickle(f"data/ssbonds_preds_cav_{h['model_name']}")
df.to_pickle(f"data/preds_cav_{h['model_name']}")

## SSBONDPredict

In [None]:
import tensorflow as tf
import sys
import numpy as np
import operator
from collections import OrderedDict
import os
import warnings
from scipy.spatial.distance import pdist
import scipy.spatial.distance as ssd
import time
import math
# Ask TensorFlow's native logger to stay quiet and ignore every Python warning.
# NOTE(review): this variable is set after `import tensorflow` above;
# presumably it must be set before the import to take effect -- confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings('ignore')

# 1) process_loadedpdb.py.......................................................................................

def process_pdb(args, PositionOfThisProject):
    """
    Score the candidate disulfide pairs of one PDB file.

    `args` is the path of the PDB file; the file name up to its first dot
    is used as the structure name. `PositionOfThisProject` is forwarded
    unchanged to `main`.

    Returns False (after printing 'no bonds') when `find_map_element`
    yields an empty list; otherwise an OrderedDict of the entries returned
    by `main`, sorted by value in decreasing order.
    """
    pdb_name = args.split('/')[-1].split('.')[0]
    map_list, map_id, mol_type_list = find_map_element(args)

    if map_list == []:
        print('no bonds')
        return False

    # Candidate residue pairs (per the original note: CA-CA distance in
    # [3, 7] angstrom) and their ids.
    candidate_coords, candidate_ids = make_ssbond_without_repeat(
        map_list, map_id, mol_type_list
    )

    # One distance matrix per candidate, stacked on axis 0, next to the
    # matching array of pair ids.
    distance_maps = np.array(convert_to_nxn_map(np.array(candidate_coords)))
    pair_ids = np.array(candidate_ids)

    scores = main([distance_maps, pair_ids, pdb_name], PositionOfThisProject)

    # Highest value first.
    ranked = sorted(scores.items(), key=lambda entry: entry[1], reverse=True)
    return OrderedDict(ranked)
#read the result_dict and sav


In [None]:
#2) ssbond_distance_map.py.......................................................................................

def convert_to_nxn_map(ssbonds_map):
    """Turn each candidate pair's coordinate table into a square distance matrix.

    For every entry of ``ssbonds_map`` the pairwise Euclidean distances between
    its rows are computed with ``pdist`` and expanded to a symmetric matrix
    with ``squareform``.  If ``pdist`` raises ``ValueError`` the entry is
    appended to ``ssbond_map_error.txt``, an in-place repair of its coordinate
    fields is attempted, the repaired entry is appended to
    ``ssbond_map_correct_error.txt`` and the distance computation is retried
    once (a second failure propagates to the caller).

    Args:
        ssbonds_map: Indexable collection with one coordinate table per
            candidate pair.
            NOTE(review): the caller in this file passes
            ``np.array(possible_ssbond)``, presumably 10 atoms x 3 coordinates
            per pair stored as strings - confirm.

    Returns:
        list: One square distance matrix per entry, in input order.
    """
    errorcount = 0  # incremented on each failure but never returned or reported
    ssbonds_distance_map = []
    pos = 0
    for smapi in range(len(ssbonds_map)):
        try:
            Y = pdist(ssbonds_map[smapi], 'euclidean') # condensed pairwise distances (45 values for 10 rows)
            ssbonds_distance_map.append(ssd.squareform(Y)) # condensed vector -> square symmetric matrix

        except ValueError:
            errorcount += 1

            # Keep a record of the entry exactly as it failed.
            with open('ssbond_map_error.txt', 'a') as wf:
                wf.write(str(ssbonds_map[smapi]))
                wf.write('\n')

            # Repair heuristic, applied row by row and in place.
            # NOTE(review): this looks like it targets two coordinates fused into
            # one field by a minus sign (e.g. "1.234-5.678") - confirm.
            for xyz in ssbonds_map[smapi]:
                # Rows whose first two fields have (almost) the same length are left untouched.
                if abs(len(xyz[0])-len(xyz[1])) <= 1:
                    continue
                # `pos` selects the longer of the first two fields: the one to split.
                if len(xyz[0]) > len(xyz[1]):
                    pos = 0
                else:
                    pos = 1
                temp = xyz[pos].split('-')

                if pos == 0:
                    # Move the value currently in the second field to the third
                    # field before the second field is overwritten below.
                    xyz[pos+2] = xyz[pos+1]
                    if temp[0] == '':
                        # The field started with '-': both recovered values are negative.
                        xyz[pos] = float('-'+temp[1])
                        xyz[pos+1] = float('-'+temp[2])
                    else:
                        # First value positive, second value negative.
                        xyz[pos] = float(temp[0])
                        xyz[pos+1] = float('-'+temp[1])
                else:
                    if temp[0] == '':
                        # The field started with '-': both recovered values are negative.
                        xyz[pos] = float('-'+temp[1])
                        xyz[pos+1] = float('-'+temp[2])
                    else:
                        # First value positive, second value negative.
                        xyz[pos] = float(temp[0])
                        xyz[pos+1] = float('-'+temp[1])
            # Record the entry after the repair attempt, followed by a separator line.
            with open('ssbond_map_correct_error.txt', 'a') as wcf:
                wcf.write(str(ssbonds_map[smapi]))
                wcf.write('\n')
                wcf.write('***************************************\n')
            # Retry on the repaired entry; not guarded, so a second ValueError is raised.
            Y = pdist(ssbonds_map[smapi], 'euclidean')
            ssbonds_distance_map.append(ssd.squareform(Y))
    return ssbonds_distance_map


In [None]:
#3) extract_unknown_map.py.......................................................................................

# Log of residue pairs rejected by the CA-CA cutoff. The file is opened (and
# truncated) when this cell runs; it is not closed anywhere in this cell.
remove_pairs = open('small_ca_remove.txt', 'w')
def compare_CA_distance(A_CA, B_CA, nameA, nameB):
    """Tell whether two CA atoms lie closer than 7 (in the coordinates' units).

    Args:
        A_CA: Indexable x, y, z coordinates of the first atom; each value is
            converted with ``float``.
        B_CA: Same, for the second atom.
        nameA: Label of the first residue, used only when logging a rejection.
        nameB: Label of the second residue, used only when logging a rejection.

    Returns:
        bool: ``True`` when the Euclidean distance is below 7. Otherwise the
        pair and its distance are written to ``remove_pairs`` and ``False`` is
        returned.
    """
    dx, dy, dz = (float(A_CA[axis]) - float(B_CA[axis]) for axis in range(3))
    distance = math.sqrt(pow(dx, 2) + pow(dy, 2) + pow(dz, 2))

    within_cutoff = distance < 7
    if not within_cutoff:
        remove_pairs.write(nameA+', '+nameB+':'+str(distance) +'\n')
    return within_cutoff


def exmain_mol_list(mol_list, line):
    """Report whether all five atom slots of a residue are filled.

    Args:
        mol_list: Sequence whose first five entries are the per-atom coordinate
            lists; an empty list marks an atom that has not been seen yet.
        line: The PDB record being processed; characters 17-19 (residue name)
            and 21 (chain id) form the returned identifier.

    Returns:
        tuple: ``(is_complete, residue_name + chain_id)``.
    """
    mol_id = line[17:20].strip() + line[21]
    is_complete = all(mol_list[slot] != [] for slot in range(5))
    return is_complete, mol_id

def find_map_element(filename):
    """Extract N, CA, C, O and CB coordinates for each residue of a PDB file.

    Only ``ATOM`` records are used and parsing stops at the first ``ENDMDL``
    record, so only the first model of a multi-model file is read.  A residue
    is reported as soon as all five atoms have been seen; residues missing any
    of them (e.g. glycine, which has no CB atom) are therefore left out.  No
    residue is excluded by name: the proline filter is disabled, so proline is
    kept.  When an atom name occurs more than once within a residue, only its
    first occurrence is used.

    Args:
        filename: Path to the PDB file.

    Returns:
        tuple: Three parallel lists with one entry per complete residue:
            * coordinates ``[N, CA, C, O, CB]``, each a list of three strings
              ``[x, y, z]`` taken from the fixed PDB columns;
            * identifiers returned by ``exmain_mol_list`` (residue name + chain);
            * residue labels (residue name + chain + residue number).
    """
    # Atom name -> slot in the per-residue coordinate list.
    atom_slots = {'N': 0, 'CA': 1, 'C': 2, 'O': 3, 'CB': 4}

    mol_map_list = []
    mol_id_list = []
    mol_type_list = []
    mol_name_temp = None
    mol_list = None

    with open(filename, 'r') as f:
        for line in f:
            line_tag = line[:6].strip()
            # The end-of-model test must come before the ATOM filter, otherwise
            # it can never trigger and every model in the file would be read.
            if line_tag == 'ENDMDL':
                break
            if line_tag != 'ATOM':
                continue

            residue = line[17:20].strip()+line[21]+line[22:26].strip()

            # A change of residue label starts a fresh set of empty atom slots.
            if mol_name_temp != residue:
                mol_name_temp = residue
                mol_list = [[] for _ in range(5)]

            # Ignore atoms we do not track and repeated atom names.
            slot = atom_slots.get(line[12:16].strip())
            if slot is None or mol_list[slot] != []:
                continue

            mol_list[slot] = [line[30:38].strip(), line[38:46].strip(), line[46:54].strip()]
            flag_mol, mol_id = exmain_mol_list(mol_list, line)

            # Record the residue once, at the moment its fifth atom is filled in.
            if flag_mol:
                mol_map_list.append(mol_list)
                mol_id_list.append(mol_id)
                mol_type_list.append(residue)

    return mol_map_list, mol_id_list, mol_type_list

def make_ssbond_without_repeat(map_list, map_id, mol_type_list):
    """List every unordered residue pair whose CA atoms pass the distance test.

    Args:
        map_list: Per-residue atom coordinates; entry ``[1]`` of each residue
            is the CA atom handed to ``compare_CA_distance``.
        map_id: Per-residue identifiers; only its length is checked here.
        mol_type_list: Per-residue labels, used to build the pair identifiers.

    Returns:
        tuple: ``(possible_ssbond, possible_ssbond_id)`` where each element of
        ``possible_ssbond`` is the first residue's atoms followed by the second
        residue's atoms, and ``possible_ssbond_id`` holds the matching
        ``(label_i, label_j)`` tuples.

    Raises:
        ValueError: If ``map_list`` and ``map_id`` differ in length.  (This
            used to be a silent ``sys.exit()``, which stopped a notebook cell
            without any explanation.)
    """
    if len(map_list) != len(map_id):
        raise ValueError(
            'map list length (%d) is not equal to map id list length (%d)!'
            % (len(map_list), len(map_id))
        )

    possible_ssbond = []
    possible_ssbond_id = []
    n_residues = len(map_list)
    for i in range(n_residues - 1):
        # j > i, so each unordered pair is visited exactly once.
        for j in range(i + 1, n_residues):
            # Skip pairs whose labels are identical apart from their first
            # character (flagged as "should not happen" by the original author).
            if mol_type_list[i][1:] == mol_type_list[j][1:]:
                continue
            if compare_CA_distance(map_list[i][1], map_list[j][1], mol_type_list[i], mol_type_list[j]):
                # Copy before extending so the caller's residue lists stay intact.
                pair_atoms = map_list[i][:]
                pair_atoms.extend(map_list[j])
                possible_ssbond.append(pair_atoms)
                possible_ssbond_id.append((mol_type_list[i], mol_type_list[j]))

    return possible_ssbond, possible_ssbond_id


In [None]:
# 4) noSG_restore_fnn.......................................................................................

def set_pointdir(basepath):
    """Return the SSBONDPredict checkpoint directory located under ``basepath``."""
    return os.path.join(basepath, 'SSBONDPredict/PreDisulfideBond/static/newmodel')
#add energy, only add a parameter result_E into predict
#def predict(args, sess, images, labels, logits, out):
def predict(args, sess, images, labels, logits, out):
    """Score every candidate pair with the restored network.

    Args:
        args: ``[data, id_ord, name]``. ``data`` holds one 10x10 matrix per
            candidate (it is reshaped to ``(len(data), 100)``), ``id_ord``
            holds the matching pairs of residue labels, and ``name`` is read
            but not used.
        sess: Session whose ``run`` method evaluates ``out``.
        images: Input placeholder fed with the flattened matrices.
        labels: Unused; kept for interface compatibility.
        logits: Unused; kept for interface compatibility.
        out: Tensor to evaluate; column 1 of its value is reported.

    Returns:
        dict: ``"<label1>-<label2>"`` -> column-1 score formatted with three
        decimals (a string).
    """
    data, id_ord, _unused_name = args[0], args[1], args[2]

    # The network expects one flat 100-vector per candidate pair.
    flat_inputs = data.reshape((len(data), 100))
    scores = sess.run(out, feed_dict={images: flat_inputs})

    # (An entropy estimate based on sequence separation was sketched here in
    # commented-out form; it is not computed.)
    result_dict = {}
    for row_idx, score_row in enumerate(scores):
        pair = id_ord[row_idx]
        result_dict[pair[0]+'-'+pair[1]] = '%.3f'% score_row[1]

    print('finish predict.')
    return result_dict


def main(args, basepath):
    """Restore the SSBONDPredict checkpoint and score the candidate pairs.

    Every call builds its own graph and session.  Importing the meta-graph
    into the process-wide default graph on each call (as was done before)
    keeps adding nodes to it, and the session was never closed; both matter
    because this function is called once per PDB file.

    Args:
        args: ``[data, id_ord, name]`` as expected by ``predict``.
        basepath: Project base directory; the checkpoint directory is resolved
            from it with ``set_pointdir``.

    Returns:
        dict: The result of ``predict`` for ``args``.
    """
    checkpoint_dir = set_pointdir(basepath)
    ckpt_path = os.path.join(checkpoint_dir, 'model.ckpt-800')

    graph = tf.Graph()
    # The session context manager closes the session (and frees its resources)
    # on exit, including when restoring or predicting raises.
    with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
        saver = tf.compat.v1.train.import_meta_graph(ckpt_path + '.meta')
        saver.restore(sess, ckpt_path)

        # Tensors looked up by the names stored in the checkpoint's meta-graph.
        images = graph.get_tensor_by_name('image:0')
        labels = graph.get_tensor_by_name('labels:0')
        logits = graph.get_tensor_by_name('softmax_linear/add:0')

        out = tf.nn.softmax(logits=logits)

        return predict(args, sess, images, labels, logits, out)

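## There is an apparent contradiction here: the restored network takes an input of size 100 (the flattened 10x10 distance matrix), whereas the original paper describes an input of size 45 (the number of unique pairwise distances among the 10 atoms).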

# Get predictions of SSBondPredict

In [None]:
# Get all predictions from SSBONDPREDICT
pdb_files = sorted(glob.glob("data/pdbs/cleaned/*_clean.pdb"))
PositionOfThisProject = "."

ssbond_frames = []
for pdb_file in pdb_files:
    print(pdb_file)
    pdb_id = os.path.basename(pdb_file.split("_clean")[0])
    preds = process_pdb(args=pdb_file,
                        PositionOfThisProject=PositionOfThisProject)
    # process_pdb returns False when no residue could be extracted from the
    # file; skip it instead of failing on `False.items()`.
    if not preds:
        continue
    preds = pd.DataFrame(preds.items(),
                         columns=["ssbond_format", "p_ss"])
    preds["pdb_id"] = pdb_id
    preds["ss"] = 0  # every candidate starts as a negative; positives are flagged later
    ssbond_frames.append(preds)

# ignore_index gives every row a unique label equal to its position; without it
# each file's rows restart at 0, which breaks the later label-based lookup.
if ssbond_frames:
    ssbond_df = pd.concat(ssbond_frames, ignore_index=True)
else:
    ssbond_df = pd.DataFrame(columns=["ssbond_format", "p_ss", "pdb_id", "ss"])

In [None]:
# Load the pickled ssbonds table (presumably the known disulfide bonds - confirm).
ssbonds = pd.read_pickle("data/tm_ssbonds_table")

# Load the pickled `nb_mapper` dictionary.
# NOTE(review): unpickling runs arbitrary code from the file - only load trusted files.
with open("data/nb_mapper.pickle", "rb") as f:
    nb_mapper = pickle.load(f)

In [None]:
# Reverse lookup (value -> key) of nb_mapper.
inv_mapper = {v: k for k, v in nb_mapper.items()}
# Equal sizes mean no two keys shared a value, i.e. the inversion lost nothing.
assert len(inv_mapper) == len(nb_mapper)

In [None]:
# Tag every row of the ssbonds table with one structure id and mark it as a positive (ss = 1).
# NOTE(review): the pdb_id is hard-coded - presumably the table covers only this structure; confirm.
ssbonds["pdb_id"] = "gcl1_apo_final_york_experimental_structure"
ssbonds["ss"] = 1

In [None]:
# Each label is "<res1>-<res2>", where a residue is written as a 3-character
# name, a 1-character chain id and a residue number, e.g. "CYSA12-CYSA30".
# Residue numbers can be negative ("CYSA-3-CYSA30"), so a plain split on "-"
# would cut the label in the wrong place; parse the fixed layout instead.
SSBOND_LABEL_RE = re.compile(r"^(.{3})(.)(-?\d+)-(.{3})(.)(-?\d+)$")


def _parse_ssbond_label(label):
    """Split an SSBONDPredict pair label into two ``[name, position, chain]`` lists.

    Raises:
        ValueError: If ``label`` does not follow the expected layout.
    """
    match = SSBOND_LABEL_RE.match(label)
    if match is None:
        raise ValueError(f"Unrecognised ssbond label: {label!r}")
    name1, chain1, pos1, name2, chain2, pos2 = match.groups()
    return [name1, int(pos1), chain1], [name2, int(pos2), chain2]


parsed_labels = ssbond_df.ssbond_format.apply(_parse_ssbond_label)

for i in range(2):
    # [residue name, position, chain] of the i-th residue of the pair.
    # `i=i` binds the current loop value inside the lambda.
    ssbond_df[f"res{i+1}_id"] = parsed_labels.apply(lambda pair, i=i: pair[i])

    # Amino-acid index of that residue.
    ssbond_df[f"res{i+1}"] = ssbond_df[f"res{i+1}_id"].apply(
        lambda x:
            Bio.PDB.Polypeptide.three_to_index(x[0])
    )

ssbond_df = ssbond_df.drop(columns="ssbond_format")

In [None]:
# with open("data/ssbondpredict_nb_mapper.pickle", "wb") as f:
#     pickle.dump(ssbondpredict_mapper, f)

In [None]:
# Flag the SSBONDPredict candidates that correspond to known disulfide bonds.
# NOTE(review): `ssbondpredict_mapper` is not defined in the cells shown around
# here (only a commented-out dump of it, and a later load into
# `sspredict_mapper`); make sure it exists before running this cell.
# NOTE(review): `pdb_id` is not used in the match, so with several structures
# the first matching residue-number pair of any structure is flagged - confirm.

# Residue numbers of both pair members, computed once instead of per iteration.
res1_positions = ssbond_df.res1_id.apply(lambda x: x[1])
res2_positions = ssbond_df.res2_id.apply(lambda x: x[1])

mask_indices = []  # row *positions* (not index labels) of the matched candidates
for pdb_id, res1, res2 in zip(ssbonds.pdb_id, ssbonds.res1_id, ssbonds.res2_id):
    try:
        is_pair = (
            (res1_positions == ssbondpredict_mapper[res1[1]]) &
            (res2_positions == ssbondpredict_mapper[res2[1]])
        )
        # First matching row; raises IndexError when there is no match.
        mask_indices.append(np.flatnonzero(is_pair.values)[0])

    except (KeyError, IndexError):
        print(f"{res1[1]} or {res2[1]} not included.")

# Single positional assignment: the previous `ssbond_df.ss.iloc[...] = 1` was a
# chained assignment (it may write to a temporary copy) and was fed index
# labels where positions are required.
ssbond_df.iloc[mask_indices, ssbond_df.columns.get_loc("ss")] = 1

In [None]:
# Persist the labelled SSBONDPredict predictions; the ROC/PR section reloads this file.
ssbond_df.to_pickle(f"data/preds_ssbond")

# Get ROC and PR curves

In [None]:
import sklearn
from sklearn.metrics import (
    roc_auc_score, roc_curve,
    precision_recall_curve, f1_score, auc,
    confusion_matrix
)

# Model / dataset selection for the benchmark below. `model_name` picks which
# prediction pickles are loaded and is embedded in the saved figure names.
# Alternatives are kept commented out; enable exactly one model_name.
# model_name = "m1_bigger_backbone_b1000_5atoms_matfact_1_lr_2" # parsed_450
# dset_name = "parsed_450"
# model_name = "m1_bigger_backbone_b1000_5atoms_matfact_1_lr_3_q_1" # parsed_325 (no weights)
dset_name = "parsed_325"
model_name = "m1_bigger_backbone_b1000_5atoms_matfact_1_lr_3_q_1_weighted" # parsed_325 (weighted)

# model_name = "idp_325_1"

In [None]:
# 1) Cavity-model predictions; p_ss is entry 21 of each row's `preds` vector.
preds_cav = pd.read_pickle(f"data/preds_cav_{model_name}")
preds_cav["p_ss"] = preds_cav.apply(lambda row: row.preds[21], axis=1)

with open("data/nb_mapper.pickle", "rb") as f:
    nb_mapper = pickle.load(f)
inv_mapper = dict((value, key) for key, value in nb_mapper.items())


# 2) SSBondPredict predictions; scores were stored as text, so cast them to float.
preds_sspred = pd.read_pickle("data/preds_ssbond")
preds_sspred["p_ss"] = preds_sspred["p_ss"].astype(float)

with open("data/ssbondpredict_nb_mapper.pickle", "rb") as f:
    sspredict_mapper = pickle.load(f)


# 3) ssbonds table saved together with the cavity-model predictions.
ssbonds = pd.read_pickle(f"data/ssbonds_preds_cav_{model_name}")

# 4) For a fairer comparison, ignore pairs containing PRO and GLY
#    (residue indices 12 and 5): keep rows where neither member is one of them.
excluded_residues = [12, 5]
has_excluded_residue = (
    preds_cav.res1.isin(excluded_residues) | preds_cav.res2.isin(excluded_residues)
)
preds_cav_same_dset = preds_cav[~has_excluded_residue]

## ROC curves

In [None]:
def get_pair_index(res1, res2):
    """Map a pair of residues to a single flat index in a 20 x 20 grid.

    Args:
        res1: Sequence whose first item is the three-letter name of the first residue.
        res2: Same, for the second residue.

    Returns:
        The row-major flat index of ``(index(res1), index(res2))`` in a
        ``(20, 20)`` array, as computed by ``np.ravel_multi_index``.
    """
    to_index = Bio.PDB.Polypeptide.three_to_index
    row = to_index(res1[0])
    col = to_index(res2[0])
    return np.ravel_multi_index([row, col], (20, 20))

def plot_roc_curve(cav_ss, cav_pss,
                   ssbond_ss, ssbond_pss,
                   save_name="",
    ):
    """Plot the ROC curves of the cavity model and SSBONDPredict on one figure.

    Args:
        cav_ss: True labels for the cavity-model predictions.
        cav_pss: Cavity-model scores.
        ssbond_ss: True labels for the SSBONDPredict predictions.
        ssbond_pss: SSBONDPredict scores.
        save_name: Suffix of the PNG written to ``results/benchmark``; with the
            default empty string nothing is saved.
    """
    # Compute every curve and its AUC before touching the figure.
    curves = []
    for label, y_true, y_score in (("cav_model", cav_ss, cav_pss),
                                   ("ssbond", ssbond_ss, ssbond_pss)):
        fpr, tpr, _thresholds = roc_curve(y_true=y_true, y_score=y_score)
        area = roc_auc_score(y_true=y_true, y_score=y_score)
        curves.append((label, fpr, tpr, area))

    plt.figure(figsize=(5, 4), constrained_layout=True)

    for label, fpr, tpr, area in curves:
        plt.plot(fpr, tpr, marker='.', label=f'{label}: auc = {area:5.3f}')
    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], linestyle='--', linewidth=1)

    legend_frame = plt.legend(shadow=True).get_frame()
    legend_frame.set_facecolor('white')
    legend_frame.set_edgecolor('black')
    plt.grid(linestyle="-", alpha=0.5)

    plt.xlabel("1 - Specificity (FPR)")
    plt.ylabel("Sensitivity (TPR)")

    should_save = save_name != ""
    if should_save:
        plt.savefig(f"results/benchmark/roc_ssbond_cav_{save_name}.png",
                    dpi=200, bbox_inches = "tight")
    plt.show()

# ROC comparison on the cavity-model subset without PRO/GLY pairs ("same_dset").
plot_roc_curve(preds_cav_same_dset.ss, preds_cav_same_dset.p_ss,
               preds_sspred.ss, preds_sspred.p_ss,
               save_name=f"same_dset_{model_name}")

# ROC comparison using all cavity-model predictions ("diff_dset").
plot_roc_curve(preds_cav.ss, preds_cav.p_ss,
               preds_sspred.ss, preds_sspred.p_ss,
               save_name=f"diff_dset_{model_name}")

## Precision-recall curves

In [None]:
# In terms of model selection,
# the F-measure summarizes model skill for a specific probability threshold (e.g. 0.5),
# whereas the area under the curve summarizes the skill of a model across thresholds, like ROC AUC.

In [None]:
def plot_pr_curve(cav_ss, cav_pss,
                   ssbond_ss, ssbond_pss,
                   save_name="",
    ):
    """Plot log-log precision-recall curves of the cavity model and SSBONDPredict.

    The dashed horizontal line is the fraction of positives among ``cav_ss``.

    Args:
        cav_ss: True labels for the cavity-model predictions.
        cav_pss: Cavity-model scores.
        ssbond_ss: True labels for the SSBONDPredict predictions.
        ssbond_pss: SSBONDPredict scores.
        save_name: Suffix of the PNG written to ``results/benchmark``; with the
            default empty string nothing is saved (same rule as
            ``plot_roc_curve``).
    """
    cav_precision, cav_recall, _ = precision_recall_curve(cav_ss, cav_pss)
    cav_auc = auc(cav_recall, cav_precision)

    ssbond_precision, ssbond_recall, _ = precision_recall_curve(ssbond_ss, ssbond_pss)
    ssbond_auc = auc(ssbond_recall, ssbond_precision)

    # plot the PR curve for the model
    plt.figure(figsize=(5, 4), constrained_layout=True)

    plt.plot(cav_recall, cav_precision, marker='.',
            label=f'cav_model: auc = {cav_auc:5.4f}')
    plt.plot(ssbond_recall, ssbond_precision, marker='.',
            label=f'ssbond: auc = {ssbond_auc:5.4f}')

    # Baseline: prevalence of positives in the cavity-model labels.
    # np.asarray also lets plain lists be passed, not only pandas Series.
    positive_rate = np.mean(np.asarray(cav_ss) == 1)
    plt.axhline(y=positive_rate,
                linestyle='--',
                linewidth=1)

    legend = plt.legend(shadow=True)
    frame = legend.get_frame()
    frame.set_facecolor('white')
    frame.set_edgecolor('black')
    plt.grid(linestyle="-", alpha=0.5)

    plt.xlabel("log( Recall (TPR) )")
    plt.ylabel("log( Precision (PPV) )")

    plt.yscale("log")
    plt.xscale("log")
    # Only write a file when a name was given, instead of always producing
    # "pr_ssbond_cav_.png" for the default argument.
    if save_name != "":
        plt.savefig(f"results/benchmark/pr_ssbond_cav_{save_name}.png",
                    dpi=200, bbox_inches = "tight")
    plt.show()

In [None]:
# PR comparison on the cavity-model subset without PRO/GLY pairs ("same_dset").
plot_pr_curve(preds_cav_same_dset.ss, preds_cav_same_dset.p_ss,
               preds_sspred.ss, preds_sspred.p_ss,
               save_name=f"same_dset_{model_name}")

# PR comparison using all cavity-model predictions ("diff_dset").
plot_pr_curve(preds_cav.ss, preds_cav.p_ss,
               preds_sspred.ss, preds_sspred.p_ss,
               save_name=f"diff_dset_{model_name}")

## Confusion matrix

In [None]:
def plot_confusion_matrix(y_true, probas_pred, threshold,
                          model_name,
                          save_name="model",
                          save=False,
                          cmap="viridis"):
    """Plot a row-normalised 2x2 confusion matrix for thresholded scores.

    Args:
        y_true: True binary labels (0 / 1).
        probas_pred: Predicted scores; a sample is predicted positive when its
            score is strictly greater than ``threshold``.
        threshold: Decision threshold applied to ``probas_pred``.
        model_name: Name shown in the figure title.
        save_name: Suffix of the PNG written to ``results/benchmark`` when
            ``save`` is true.
        save: Whether to write the figure to disk.
        cmap: Colormap passed to ``seaborn.heatmap``.
    """
    sns.set(font_scale=1.2, style="ticks")

    y_pred = (np.asarray(probas_pred) > threshold).astype(int)
    # Fixing the label set keeps the matrix 2x2 even when a class is absent,
    # which the TN/FP/FN/TP annotation below relies on.
    cf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])

    group_names = ["TN","FP","FN","TP"]
    group_counts = [f"{value}" for value in cf_matrix.flatten()]

    # Normalise each row (true class); a class without samples yields 0 instead of NaN.
    row_sums = cf_matrix.sum(axis=1, keepdims=True)
    norm_matrix = np.divide(cf_matrix, row_sums,
                            out=np.zeros(cf_matrix.shape, dtype=float),
                            where=row_sums != 0)
    group_percentages = [f"{value*100:.2f}%" for value in norm_matrix.flatten()]
    labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in zip(group_names, group_counts, group_percentages)]
    labels = np.asarray(labels).reshape(2,2)

    # Plot heatmap
    sns.heatmap(norm_matrix, annot=labels, annot_kws={"size": 15},
                fmt="", vmin=0, vmax=1, cmap=cmap, linewidths=1)
    # Rows of the confusion matrix are the true classes, columns the predictions.
    plt.ylabel("True", fontsize=13)
    plt.xlabel("Predicted", fontsize=13, labelpad=-0)

    plt.title(f"{model_name}, thr: {threshold}", fontsize=15)
    if save:
        plt.savefig(f"results/benchmark/cf_matrix_{save_name}.png",
                dpi=200, bbox_inches = "tight")
    plt.show()

In [None]:
# Decision threshold applied to the cavity-model p_ss scores.
# NOTE(review): how 9.5e-6 was chosen is not shown here - document its origin.
threshold = 9.5e-6
# Cavity model on the subset without PRO/GLY pairs.
plot_confusion_matrix(preds_cav_same_dset.ss, preds_cav_same_dset.p_ss, threshold,
                      model_name="cav_model (common)", save_name=f"cav_model_same_{model_name}", save=True)

# Cavity model on all of its predictions.
plot_confusion_matrix(preds_cav.ss, preds_cav.p_ss, threshold,
                      model_name="cav_model (all)", save_name=f"cav_model_all_{model_name}", save=True)

In [None]:
# Decision threshold applied to the SSBONDPredict p_ss scores.
# NOTE(review): how 1.3e-2 was chosen is not shown here - document its origin.
threshold = 1.3e-2

# The data plotted here are the SSBONDPredict predictions, so the title names
# that model (it previously read "cav_model (all)", copied from the cell above).
plot_confusion_matrix(preds_sspred.ss, preds_sspred.p_ss, threshold,
                      model_name="ssbondpredict", save_name=f"ssbondpredict", save=True)