# Unveiling the Robustness and Potential of CS Predictors in Neurotensin Receptor Studies

In this Jupyter notebook, we provide the code and the simulation data necessary to reproduce the results in the section "Unveiling the Robustness and Potential of CS Predictors in Neurotensin Receptor Studies". Follow along to reproduce the experiments. This notebook covers reading data from CSV files and computing the correlation between experimental and computational results.

### Load Modules

In [None]:
import csv
import os
import statistics

import matplotlib.pyplot as plt
import MDAnalysis as mda
import nglview as nv
import numpy as np
from MDAnalysis.analysis import align, rms

### Load Functions

In [None]:
def read_csv(file_path, residue_offset=49):
    """
    Reads data from a CSV file containing chemical shift information.

    Every row contributes two entries to the shift dictionary, one for its X
    atom and one for its Y atom, keyed as "<sequence>_<chem_comp_ID>_<atom_name>".
    When a key occurs more than once, the value of the last row wins.

    Args:
    file_path (str): Path to the CSV file. Its header must provide the columns
        'sequence', 'chem_comp_ID', 'X_atom_name', 'X_shift', 'Y_atom_name'
        and 'Y_shift'.
    residue_offset (int): Value added to each 'sequence' number to obtain the
        residue number reported in the returned list. Defaults to 49.

    Returns:
    tuple: A dictionary mapping each key to its chemical shift (float) and a
        sorted list of the unique, offset residue numbers (int).

    Raises:
    KeyError: If a required column is missing.
    ValueError: If a shift or a sequence number cannot be converted.
    """
    data_dict = {}  # "<sequence>_<residue>_<atom>" -> chemical shift
    ngl_res_info = set()  # unique residue numbers, already offset

    # newline='' lets the csv module do its own line-ending handling, as its
    # documentation recommends for files passed to a reader.
    with open(file_path, 'r', newline='') as csv_file:
        for row in csv.DictReader(csv_file):
            residue_prefix = f"{row['sequence']}_{row['chem_comp_ID']}"

            # One entry for the X atom of the row and one for the Y atom
            data_dict[f"{residue_prefix}_{row['X_atom_name']}"] = float(row['X_shift'])
            data_dict[f"{residue_prefix}_{row['Y_atom_name']}"] = float(row['Y_shift'])

            ngl_res_info.add(int(row['sequence']) + residue_offset)

    # Sorting makes the result reproducible; a plain list(set) has no
    # guaranteed order.
    return data_dict, sorted(ngl_res_info)

def read_text_file(file_path, residue_offset=49):
    """
    Reads data from a text file containing chemical shift information.

    The first line is treated as a header and discarded. Every following
    non-blank line is split on whitespace; its first four fields are read as
    residue ID, residue name, atom type and mean chemical shift. Any further
    fields on the line are ignored because only the mean shift is returned.

    Args:
    file_path (str): Path to the text file.
    residue_offset (int): Value subtracted from the residue ID when building
        the dictionary key. Defaults to 49.

    Returns:
    dict: A dictionary mapping "<ResidueID - residue_offset>_<ResidueNAME>_<Atom_Type>"
        to the mean chemical shift (float). Empty for an empty or header-only file.

    Raises:
    ValueError: If a non-blank line has fewer than four fields, or if the
        residue ID or the mean shift cannot be converted to a number.
    """
    data_dict = {}  # Dictionary to store the information

    with open(file_path, 'r') as text_file:
        # Skip the header line; the default keeps an empty file from raising
        # StopIteration.
        next(text_file, None)

        for line in text_file:
            values = line.split()
            if not values:
                # Blank lines (typically a trailing newline) carry no data
                continue

            residue_id, residue_name, atom_type, mean_cs = values[:4]
            key = f"{int(residue_id) - residue_offset}_{residue_name}_{atom_type}"
            data_dict[key] = float(mean_cs)

    return data_dict

def read_CS_file(cs_file):
    """
    Reads chemical shift data from a whitespace-separated file.

    Only four columns of each line are used (1-based): columns 7, 8 and 9 are
    joined with underscores to form the key, and column 12 is the chemical
    shift. Blank lines are skipped, and lines need no more than 12 columns.
    NOTE(review): the key is presumably "<sequence>_<residue name>_<atom name>",
    matching the keys built by the other readers in this notebook - confirm
    against the input file format.

    Args:
    cs_file (str): Path to the file containing chemical shift data.

    Returns:
    dict: A dictionary mapping "<col7>_<col8>_<col9>" to the chemical shift
        (float). When a key occurs more than once, the last line wins.

    Raises:
    IndexError: If a non-blank line has fewer than 12 columns.
    ValueError: If column 12 cannot be converted to a float.
    """
    CS_file_raw = {}
    with open(cs_file, 'r') as file:
        for line in file:
            # Split the line into individual elements based on whitespace
            elements = line.split()
            if not elements:
                # Blank line: nothing to parse
                continue

            key = f"{elements[6]}_{elements[7]}_{elements[8]}"
            CS_file_raw[key] = float(elements[11])
    return CS_file_raw

def plot_correlation(atom_type, axis_limits, axis_spacing, file_paths, CS_file_franz):
    """
    Plots the correlation between experimental and computational chemical shifts.

    One subplot is drawn per entry of file_paths. Each file is parsed with
    read_comp_values; the entries whose atom name equals atom_type and whose
    key also exists in CS_file_franz are drawn as a scatter plot of
    computational (x) against experimental (y) shift. A least-squares line and
    the Pearson correlation coefficient are added when at least two points
    are available. Both axes are inverted.

    Args:
    atom_type (str): Atom name to select, compared with the part of each key
        after its last underscore (e.g. 'CA').
    axis_limits (tuple): Lower and upper limit applied to both plot axes.
    axis_spacing (int): Spacing between ticks on the plot axes.
    file_paths (list): List of file paths containing computational data,
        one per replica.
    CS_file_franz (dict): Dictionary containing experimental chemical shift
        data, keyed like the dictionaries returned by read_comp_values.

    Returns:
    tuple: A tuple containing the figure and the flattened array of axes. The
        grid has two columns and at least five rows; rows are added when more
        than ten files are given.
    """
    n_cols = 2
    # Keep the 5 x 2 layout for up to ten replicas and grow by whole rows
    # beyond that instead of indexing past the last subplot.
    n_rows = max(5, -(-len(file_paths) // n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 6 * n_rows))
    axs = axs.flatten()

    axis_ticks = np.arange(axis_limits[0], axis_limits[1] + 1, axis_spacing)

    for i, file_path in enumerate(file_paths, start=1):
        ax = axs[i - 1]
        dict_comp_values = read_comp_values(file_path)

        # Paired computational / experimental values for the current replica
        comp_values = []
        exp_values = []
        for key, comp_value in dict_comp_values.items():
            # Compare the complete atom name: a suffix test would also accept
            # other atom names that merely end with atom_type (e.g. 'HN' for 'N').
            if key.rsplit('_', 1)[-1] != atom_type:
                continue
            if key not in CS_file_franz:
                continue
            comp_values.append(comp_value)
            exp_values.append(CS_file_franz[key])

        # A single scatter call draws the same black markers as one call per point
        ax.scatter(comp_values, exp_values, color='black')

        # Set the specified axis limits and ticks
        ax.set_xlim(axis_limits[0], axis_limits[1])
        ax.set_ylim(axis_limits[0], axis_limits[1])
        ax.set_xticks(axis_ticks)
        ax.set_yticks(axis_ticks)

        if len(comp_values) >= 2:
            # First-degree least-squares fit of experimental on computational shifts
            fit_coefficients = np.polyfit(comp_values, exp_values, 1)
            ax.plot(comp_values, np.polyval(fit_coefficients, comp_values), color='red', label='Correlation Line')

            correlation_coefficient = np.corrcoef(comp_values, exp_values)[0, 1]
            correlation_text = f'Correlation: {correlation_coefficient:.3f}'
        else:
            # Neither a fit nor a correlation is defined for fewer than two points
            correlation_text = 'Correlation: n/a'

        # Add correlation coefficient to the plot
        ax.text(0.05, 0.95, correlation_text, transform=ax.transAxes, fontsize=16, verticalalignment='top')

        # Add axis labels and the replica title
        ax.set_xlabel('Computational Shift (ppm)', fontsize=18)
        ax.set_ylabel('Experimental Shift (ppm)', fontsize=18)
        ax.set_title(f'Replica {i}', fontsize=18, fontweight='bold')

        # set the ticks larger
        ax.tick_params(axis='both', which='major', labelsize=18)

        # Show both axes with values decreasing away from the origin corner
        ax.invert_yaxis()
        ax.invert_xaxis()

    # Adjust layout and return the figure and axes
    plt.tight_layout()
    return fig, axs

def read_comp_values(file_path, residue_offset=49):
    """
    Reads computational chemical shift values from a semicolon-separated file.

    The first line is treated as a header and discarded. For every following
    non-blank line, field 1 (residue ID), field 3 and field 2 (1-based) form
    the key, and all non-empty fields from the fifth onwards are averaged.
    Field 4 is not used.
    NOTE(review): fields 2 and 3 are presumably the atom name and the residue
    name, so that keys read "<residue>_<residue name>_<atom name>" - confirm
    against the file header.

    Args:
    file_path (str): Path to the file containing computational shift data.
    residue_offset (int): Value subtracted from the residue ID when building
        the dictionary key. Defaults to 49.

    Returns:
    dict: A dictionary mapping "<field1 - residue_offset>_<field3>_<field2>"
        to the mean (float) of the line's shift values.

    Raises:
    ValueError: If the residue ID or a shift value cannot be converted.
    statistics.StatisticsError: If a line carries no shift values.
    """
    dict_comp_values = {}

    with open(file_path, 'r') as file:
        # Skip the header; the default keeps an empty file from raising
        next(file, None)

        # Stream the file line by line instead of loading it completely
        for line in file:
            line = line.strip()
            if not line:
                # Blank lines (typically a trailing newline) carry no data
                continue

            data = line.split(";")
            # A trailing ';' produces an empty last field, which is ignored
            float_list = [float(element) for element in data[4:] if element.strip()]
            id_key = f"{int(data[0]) - residue_offset}_{data[2]}_{data[1]}"
            dict_comp_values[id_key] = statistics.mean(float_list)

    return dict_comp_values

def calculate_correlation_coefficients(atom_type, file_paths, CS_file_franz):
    """
    Calculates correlation coefficients between experimental and computational chemical shifts.

    Each file is parsed with read_comp_values. The entries whose atom name
    equals atom_type and whose key also exists in CS_file_franz are paired
    with the experimental value, and the Pearson correlation coefficient of
    the pairs is computed.

    Args:
    atom_type (str): Atom name to select, compared with the part of each key
        after its last underscore (e.g. 'CA').
    file_paths (list): List of file paths containing computational shift data.
    CS_file_franz (dict): Dictionary containing experimental chemical shift data.

    Returns:
    list: List of correlation coefficients, one per file and in the order of
        file_paths. An entry is NaN when fewer than two pairs were found.
    """
    correlation_coefficients = []

    for file_path in file_paths:
        dict_comp_values = read_comp_values(file_path)

        # Paired computational / experimental values for the current replica
        comp_values = []
        exp_values = []
        for key, comp_value in dict_comp_values.items():
            # Compare the complete atom name: a suffix test would also accept
            # other atom names that merely end with atom_type (e.g. 'HN' for 'N').
            if key.rsplit('_', 1)[-1] == atom_type and key in CS_file_franz:
                comp_values.append(comp_value)
                exp_values.append(CS_file_franz[key])

        if len(comp_values) < 2:
            # The coefficient is undefined here; report NaN directly rather
            # than letting numpy emit degrees-of-freedom warnings.
            correlation_coefficients.append(float('nan'))
            continue

        correlation_coefficients.append(np.corrcoef(comp_values, exp_values)[0, 1])

    return correlation_coefficients

def plot_correlation_coefficients_table(atom_types, file_paths, correlation_coefficients_dict):
    """
    Plots correlation coefficients in a table format.

    The table has one 'Replica' column followed by one column per atom type,
    and one row per coefficient, with replica 1 at the top. Values are
    printed with three decimals.

    Args:
    atom_types (list): List of atom types; each must be a key of
        correlation_coefficients_dict.
    file_paths (list): List of file paths containing computational shift data.
        Not used by this function; kept for call compatibility.
    correlation_coefficients_dict (dict): Dictionary mapping each atom type to
        its list of correlation coefficients. The number of rows is taken
        from the list of the first atom type.

    Returns:
    tuple: A tuple containing the figure and axes objects.
    """
    fig = plt.figure(figsize=(6, 4), dpi=300)
    ax = plt.subplot()

    plt.title('Exp. vs Comp. Correlation Coefficients Table')

    columns = ['Replica', *atom_types]
    ncols = len(columns)
    nrows = len(correlation_coefficients_dict[atom_types[0]]) if atom_types else 0

    ax.set_xlim(0, ncols + 1)
    ax.set_ylim(0, nrows + 1)

    # Column centres at 0.5, 1.5, ... derived from the column count, so any
    # number of atom types fits.
    positions = [index + 0.5 for index in range(ncols)]

    # Add table's main text; row i is placed so that replica 1 ends up on top
    for i in range(nrows):
        for j, column in enumerate(columns):
            if j == 0:
                text_label = str(i + 1)  # Replica number
            else:
                text_label = f'{correlation_coefficients_dict[column][i]:.3f}'
            ax.annotate(
                xy=(positions[j], nrows - i - 0.5),
                text=text_label,
                ha='center',
                va='center',
                weight='normal'
            )

    # Add column names just above the first row
    for index, column_name in enumerate(columns):
        ax.annotate(
            xy=(positions[index], nrows + .25),
            text=column_name,
            ha='center',
            va='bottom',
            weight='bold'
        )

    # Add dividing lines: solid above and below the body, dotted between rows
    x_min, x_max = ax.get_xlim()
    ax.plot([x_min, x_max], [nrows, nrows], lw=1.5, color='black', marker='', zorder=4)
    ax.plot([x_min, x_max], [0, 0], lw=1.5, color='black', marker='', zorder=4)
    for y in range(1, nrows):
        ax.plot([x_min, x_max], [y, y], lw=1.15, color='gray', ls=':', zorder=3, marker='')

    ax.set_axis_off()
    return fig, ax

### Read input data and visualize experimental results with NGL viewer

Here we read the experimental CSV file from the [Biological Magnetic Resonance Data Bank](https://bmrb.io/data_library/summary/index.php?bmrbId=51907) from the [study of Franz Hagn et al.](https://www.sciencedirect.com/science/article/abs/pii/S1047847723000333). 

We first read the PDB file used to simulate the system and map onto it the residues for which experimental NMR data are available. This gives us an idea of which regions of the GPCR were evaluated in this study.

In [None]:
# Visualize which residues have experimental NMR data: read the experimental
# CSV, load the simulated structure and mark the listed residues on it.

# Specify the path to your CSV file & PDB
csv_file_path = 'Experimental_Data/51907_simulated_hsqc_backbone.csv'
pdb_path = 'Computational_Data/10661_dyn_66.pdb'

print("Reading CSV file...")
# Call the function to read the CSV file and create the dictionary.
# read_csv also returns the residue numbers to highlight; it adds 49 to the
# CSV sequence numbers to produce them.
csv_data_dict, ngl_residues = read_csv(csv_file_path)

print("Generating NGL representation...")
# Load a PDB file
view = nv.show_structure_file(pdb_path)

# NOTE(review): presumably removes the representations nglview adds by
# default, so that only the ones added below are shown - confirm against the
# nglview documentation.
view.clear()

# Iterate through the list and add van der Waals spheres for each alpha carbon
for residue_id in ngl_residues:
    # Selection string: atoms named CA combined with the residue number
    selection_string = f'.CA and {residue_id}'
    view.add_spacefill(selection=selection_string, color='black')

# Add a cartoon representation for the protein backbone
view.add_cartoon(selection='protein',color='white')
# Display the viewer (last expression of the cell, so the notebook renders it)
view


### Compute the correlation between Exp. & Comp. data for the common residues in both methods

The next cells use the `plot_correlation` function to analyze the correlation between experimental and computational chemical shifts. Here's a brief overview of the function:


This function plots the correlation between experimental and computational chemical shifts for a specified atom type. It takes the following arguments:
- `atom_type`: Type of atom to analyze.
- `axis_limits`: Tuple containing the limits of the plot axes.
- `axis_spacing`: Spacing between ticks on the plot axes.
- `file_paths`: List of file paths containing data.
- `CS_file_franz`: Dictionary containing chemical shift data.

We split the correlation by atom type. The axis limits are adapted to each atom type.

We need to load the experimental data `CS_file_franz` and the computational data `file_paths`.


In [None]:
# Experimental chemical shifts, keyed as produced by read_CS_file
CS_file_franz = read_CS_file('Experimental_Data/shift_neurotensin.txt')

# One computational chemical-shift file per replica (IDs 22337 through 22346)
file_paths = [
    f'Computational_Data/cs_dyn1753_{replica_id}.csv'
    for replica_id in range(22337, 22347)
]

# Atom types analyzed in the cells below
atom_types = ['C', 'CA', 'CB', 'N', 'H']


#### AtomType C

In [None]:
# Correlation plots for atom type C, one subplot per replica
atom_type = 'C'
axis_limits = [170, 182]
axis_spacing = 2

fig, axs = plot_correlation(atom_type, axis_limits, axis_spacing, file_paths, CS_file_franz)

# matplotlib does not create missing directories, so make sure the output
# folder exists before saving the returned figure.
os.makedirs('Plots', exist_ok=True)
fig.savefig(f'Plots/Correlation_{atom_type}.png')
plt.show()

#### AtomType CA

In [None]:
# Correlation plots for atom type CA, one subplot per replica
atom_type = 'CA'
axis_limits = [40, 70]
axis_spacing = 5

fig, axs = plot_correlation(atom_type, axis_limits, axis_spacing, file_paths, CS_file_franz)

# matplotlib does not create missing directories, so make sure the output
# folder exists before saving the returned figure.
os.makedirs('Plots', exist_ok=True)
fig.savefig(f'Plots/Correlation_{atom_type}.png')
plt.show()

#### AtomType CB

In [None]:
# Correlation plots for atom type CB, one subplot per replica
atom_type = 'CB'
axis_limits = [10, 80]
axis_spacing = 10

fig, axs = plot_correlation(atom_type, axis_limits, axis_spacing, file_paths, CS_file_franz)

# matplotlib does not create missing directories, so make sure the output
# folder exists before saving the returned figure.
os.makedirs('Plots', exist_ok=True)
fig.savefig(f'Plots/Correlation_{atom_type}.png')
plt.show()

#### AtomType N

In [None]:
# Correlation plots for atom type N, one subplot per replica
atom_type = 'N'
axis_limits = [100, 135]
axis_spacing = 5

fig, axs = plot_correlation(atom_type, axis_limits, axis_spacing, file_paths, CS_file_franz)

# matplotlib does not create missing directories, so make sure the output
# folder exists before saving the returned figure.
os.makedirs('Plots', exist_ok=True)
fig.savefig(f'Plots/Correlation_{atom_type}.png')
plt.show()

#### AtomType H

In [None]:
# Correlation plots for atom type H, one subplot per replica
atom_type = 'H'
axis_limits = [6, 10]
axis_spacing = 1

fig, axs = plot_correlation(atom_type, axis_limits, axis_spacing, file_paths, CS_file_franz)

# matplotlib does not create missing directories, so make sure the output
# folder exists before saving the returned figure.
os.makedirs('Plots', exist_ok=True)
fig.savefig(f'Plots/Correlation_{atom_type}.png')
plt.show()

### Compute Correlation Table

Here we compute the table with all the correlation coefficients for all the distributions shown above.

In [None]:
# Atom type -> list of correlation coefficients, one per file in file_paths
correlation_coefficients_dict = {}

# Calculate correlation coefficients for each atom type and store them in a dictionary
for atom_type in atom_types:
    correlation_coefficients = calculate_correlation_coefficients(atom_type, file_paths, CS_file_franz)
    correlation_coefficients_dict[atom_type] = correlation_coefficients

# Plot the correlation coefficients table
fig, ax = plot_correlation_coefficients_table(atom_types, file_paths, correlation_coefficients_dict)

# matplotlib does not create missing directories, so make sure the output
# folder exists before saving the returned figure.
os.makedirs('Plots', exist_ok=True)
fig.savefig('Plots/correlation_coefficients_table.png', dpi=300, transparent=True, bbox_inches='tight')
plt.show()
