In [1]:
# Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import copy
import os, glob
from os import listdir
from os.path import isfile, isdir, join

import torch
import esm

In [None]:
# Global vars
# NOTE: no trailing commas here - `x = 1,` makes x a tuple and breaks every later use.
data_directory = '../data'   # folder holding the data_* and mutants_* simulation outputs
discarded_mutations = 500    # leading rows of each data_* file excluded from mean/max statistics
step = 1                     # keep every `step`-th sequence of each mutants_* file

# Check for directories: every figure is saved under ../results
os.makedirs(join('..', 'results'), exist_ok = True)

# Load model
# NOTE(review): this used to read `model_location = model_location` (NameError on a fresh kernel).
# The default checkpoint below is an assumption - set ESM_MODEL_LOCATION (hub name or local .pt
# path) to the model that was actually used for these results.
model_location = os.environ.get('ESM_MODEL_LOCATION', 'esm1b_t33_650M_UR50S')
model, alphabet = esm.pretrained.load_model_and_alphabet(model_location)
batch_converter = alphabet.get_batch_converter()
model.eval();  # trailing ';' suppresses the long model repr in the notebook output

### Multi-T data processing

In [None]:
### Load data from fixed T and threshold simulation
def load_data(filename):
    """Load a tab-separated numeric simulation file into a 2-D float array.

    Blank lines (e.g. a trailing empty line) are skipped, and the file is closed
    even if parsing fails.

    Parameters
    ----------
    filename : str
        Path of a ``data_*`` file written by the simulation.

    Returns
    -------
    numpy.ndarray
        Array with one row per non-blank line, dtype float.

    Raises
    ------
    ValueError
        If a field cannot be converted to float (raised by NumPy).
    """
    with open(filename, 'r') as file:
        rows = [line.strip().split('\t') for line in file if line.strip()]
    return np.array(rows).astype(float)


### Process data
def process_data(data_directory, discarded_mutations, load_bool = False):
    """Summarise every fixed-T simulation file (``data_*``) into one table row.

    Column layout of each file, as used below: 0 generation, 1 effective energy,
    2 ddG, 3 Hamming distance, ..., -4 acceptance rate, -3 threshold, -2 T,
    -1 sequence length.

    Parameters
    ----------
    data_directory : str
        Directory containing the ``data_*`` files; ``processed_data.csv`` is read
        from / written to this directory.
    discarded_mutations : int
        Number of leading rows of each file excluded from the mean/max statistics.
        The ``last_*`` columns always use the final row.
    load_bool : bool, optional
        If True, return the previously saved ``processed_data.csv`` instead of
        re-processing the raw files.

    Returns
    -------
    pandas.DataFrame
        One row per file, sorted by ``T``, with a fresh 0..n-1 index.

    Raises
    ------
    ValueError
        If a file has no rows left after discarding ``discarded_mutations``.
    """
    csv_path = join(data_directory, 'processed_data.csv')
    columns = [
        'mean_eff_energy', 'max_eff_energy', 'last_eff_energy',
        'mean_ddG', 'max_ddG', 'last_ddG',
        'mean_distance', 'max_distance', 'last_distance',
        'acceptance_rate', 'threshold', 'T', 'sequence_length',
    ]

    # Load already processed data
    if load_bool:
        # CSVs saved by earlier versions of this function also stored the row index,
        # which pandas reads back as an 'Unnamed: 0' column
        return pd.read_csv(csv_path).drop(columns = ['Unnamed: 0'], errors = 'ignore')

    # tqdm was used here without being imported; fall back to a plain loop if it is missing
    try:
        from tqdm.auto import tqdm
    except ImportError:
        def tqdm(iterable, **kwargs):
            return iterable

    # Process and save raw data
    filelist = sorted(glob.glob(data_directory + '/data_*'))
    records = []
    for filename in tqdm(filelist, total = len(filelist)):
        file_data = load_data(filename)
        kept = file_data[discarded_mutations:]  # rows used for the mean/max statistics
        if len(kept) == 0:
            raise ValueError(f'{filename}: only {len(file_data)} rows, cannot discard {discarded_mutations}.')
        first, last = file_data[0], file_data[-1]

        # Store multi-threshold data, one record per file
        records.append({
            'mean_eff_energy': kept[:, 1].mean(),
            'max_eff_energy': kept[:, 1].max(),
            'last_eff_energy': last[1],
            'mean_ddG': kept[:, 2].mean(),
            'max_ddG': kept[:, 2].max(),
            'last_ddG': last[2],
            'mean_distance': kept[:, 3].mean(),
            'max_distance': kept[:, 3].max(),
            'last_distance': last[3],
            'acceptance_rate': last[-4],
            'threshold': first[-3],
            'T': first[-2],
            'sequence_length': first[-1],
        })

    # Build the frame once: DataFrame.append was removed in pandas 2.0 (and was quadratic)
    processed_data = (
        pd.DataFrame(records, columns = columns)
        .sort_values(by = ['T'], kind = 'mergesort')
        .reset_index(drop = True)
    )
    processed_data.to_csv(csv_path, index = False)

    return processed_data

### Multi-T data plotting

In [None]:
### Effective energy, ddG and Hamming distance for same threshold simulations
def plot_processed_energy_distance(processed_data, savefig_bool = True):
    """One figure per threshold: effective energy, ddG and Hamming distance versus T.

    Each figure has three stacked panels (energy, ddG, distance) showing the mean,
    max and last value of that quantity for every simulation at that threshold.

    Parameters
    ----------
    processed_data : pandas.DataFrame
        Table produced by ``process_data``.
    savefig_bool : bool, optional
        Save each figure as ``../results/processed_energy_distance_t<threshold>.png``.
    """
    # (column suffix, word used in the legend, y-axis label) for the three panels
    panels = (
        ('eff_energy', 'energy', r'E$_{eff}$'),
        ('ddG', 'ddG', r'$\Delta\Delta$G'),
        ('distance', 'distance', r'd$_{Hamm}$'),
    )

    # A threshold appears once per T in the table, so iterate over distinct values only
    for threshold in processed_data['threshold'].unique():
        t_str = str(threshold)[:1] + str(threshold)[2:7]
        file_id = 't' + t_str

        # Lines connect points in row order, so make sure they are ordered by T
        df = processed_data[processed_data['threshold'] == threshold].sort_values(by = 'T')

        fig = plt.figure(figsize = (9, 18))
        # Backslashes escaped (not raw) so that '\n' stays a real line break
        fig.suptitle('Effective energy, $\\Delta\\Delta$G and Hamming distance for same threshold simulations\n'
                     f'(threshold = {threshold})', y = 0.93)

        for position, (suffix, word, ylabel) in enumerate(panels, start = 1):
            ax = fig.add_subplot(3, 1, position)
            ax.plot(df['T'], df[f'mean_{suffix}'], label = f'mean {word}')
            ax.plot(df['T'], df[f'max_{suffix}'], label = f'max {word}', linestyle = '--')
            ax.plot(df['T'], df[f'last_{suffix}'], label = f'last {word}', linestyle = '--')
            ax.scatter(df['T'], df[f'mean_{suffix}'])
            ax.scatter(df['T'], df[f'max_{suffix}'], marker = '^')
            ax.scatter(df['T'], df[f'last_{suffix}'], marker = 's')
            if suffix == 'distance':
                # Upper bound of the Hamming distance
                ax.plot(df['T'], df['sequence_length'], label = 'sequence length', linestyle = '--', color = 'grey')
            ax.set_ylabel(ylabel)
            ax.set_xlabel(r'T')
            ax.legend()
            ax.grid(True)

        if savefig_bool:
            fig.savefig('../results/processed_energy_distance_' + file_id + '.png', bbox_inches = 'tight')


### Acceptance rates for same threshold simulations
def plot_processed_acceptance(processed_data, savefig_bool = True):
    """Plot the acceptance rate versus T, one curve per threshold, in a single figure.

    Parameters
    ----------
    processed_data : pandas.DataFrame
        Table produced by ``process_data``.
    savefig_bool : bool, optional
        Save the figure as ``../results/processed_acceptance_rate.png``.
    """
    fig = plt.figure(figsize = (9, 6))
    fig.suptitle('Acceptance rate for same threshold simulations', y = 0.93)

    ax = fig.add_subplot(1, 1, 1)
    # One curve per distinct threshold (the column repeats the value for every T)
    for threshold in processed_data['threshold'].unique():
        df = processed_data[processed_data['threshold'] == threshold].sort_values(by = 'T')
        ax.plot(df['T'], df['acceptance_rate'], label = f'threshold: {threshold}')
        ax.scatter(df['T'], df['acceptance_rate'])

    # Plain text label: the previous r'acceptance rate$' rendered a stray '$'
    ax.set_ylabel('acceptance rate')
    ax.set_xlabel(r'T')
    ax.legend()
    ax.grid(True)

    # Save once, after every curve has been drawn
    if savefig_bool:
        fig.savefig('../results/processed_acceptance_rate.png', bbox_inches = 'tight')

In [None]:
# Process the raw data_* files (load_bool=True would reuse the saved processed_data.csv)
processed_data = process_data(
    data_directory,
    discarded_mutations,
    load_bool=False,
)

In [None]:
# Plot the energy / ddG / distance summaries, then the acceptance rates
for plot_summary in (plot_processed_energy_distance, plot_processed_acceptance):
    plot_summary(processed_data)

### Multi-T sequence pdf processing

In [None]:
### Load sequences from single-threshold simulation
def load_sequences(filename, step, discarded_mutations = 0):
    """Read a ``mutants_*`` file and return its sequences together with T and threshold.

    Each non-blank line is tab separated, with the sequence in the second field. T and
    threshold are parsed from the file name, e.g. ``mutants_T0005_t05.dat`` gives
    T = 0.005 and threshold = 0.5.

    Parameters
    ----------
    filename : str
        Path of the ``mutants_*`` file.
    step : int
        Keep every ``step``-th sequence.
    discarded_mutations : int, optional
        Number of leading sequences dropped before thinning. The default of 0 keeps all
        of them, including the first line.

    Returns
    -------
    tuple
        ``(first_sequence, sequences, T, threshold, length)``. ``first_sequence`` is the
        sequence on the first line and ``sequences`` is the thinned list.

    Raises
    ------
    ValueError
        If the file contains no sequences or the sequences differ in length.
    """
    with open(filename, 'r') as file:
        # Strip the line terminator rather than dropping unterminated lines, so the last
        # line of the file is kept even without a trailing newline
        sequences = [line.rstrip('\r\n').split('\t')[1] for line in file if line.strip()]

    if not sequences:
        raise ValueError(f'No sequences found in {filename}.')
    lengths = {len(sequence) for sequence in sequences}
    if len(lengths) != 1:
        raise ValueError('Each sequence must be of the same length.')

    # Parse '..._T<digits>_t<digits>.<ext>' from the base name only, so that directories
    # containing '_' and extensions of any length do not break the split
    stem = os.path.splitext(os.path.basename(filename))[0]
    T_tag, threshold_tag = stem.split('_')[-2:]
    T = float(T_tag[1] + '.' + T_tag[2:])
    threshold = float(threshold_tag[1] + '.' + threshold_tag[2:])

    return sequences[0], sequences[discarded_mutations:][::step], T, threshold, lengths.pop()


### Process sequences
def process_sequences(data_directory, step, load_bool = False):
    """Histogram the pairwise Hamming distances between sampled sequences of each file.

    For every ``mutants_*`` file, the sequences returned by ``load_sequences`` are
    compared pairwise (each unordered pair once). The function then counts how many
    pairs fall at each distance d = 0..L. Those sequences are every ``step``-th line,
    with the first line included.

    Parameters
    ----------
    data_directory : str
        Directory with the ``mutants_*`` files; ``pdfs.npy`` is read from / written to it.
    step : int
        Thinning passed to ``load_sequences``.
    load_bool : bool, optional
        If True, load the saved ``pdfs.npy`` instead of recomputing.

    Returns
    -------
    numpy.ndarray
        One row per file, sorted by T: ``[count(d=0), ..., count(d=L), T, threshold, L]``.
        The result is a 2-D float array when all files share the same length L.
        Otherwise it is a 1-D object array of per-file float rows, which is why loading
        it needs ``allow_pickle``.
    """
    pdfs_path = join(data_directory, 'pdfs.npy')

    # Load pdfs
    if load_bool:
        return np.load(pdfs_path, allow_pickle = True)

    # tqdm was used here without being imported; fall back to a plain loop if it is missing
    try:
        from tqdm.auto import tqdm
    except ImportError:
        def tqdm(iterable, **kwargs):
            return iterable

    # Process and save pdfs
    filelist = sorted(glob.glob(data_directory + '/mutants_*'))

    pdfs, Ts, thresholds, lengths = [], [], [], []
    for filename in tqdm(filelist, total = len(filelist)):
        # Load sequence, T and threshold
        _wt_sequence, sequences, T, threshold, length = load_sequences(filename, step)

        # residues[i, j] is residue j of sequence i; converted once instead of per pair
        residues = np.array([list(sequence) for sequence in sequences])
        single_pdf = np.zeros(length + 1)

        print(f'  T = {T}, threshold = {threshold}')
        for row in tqdm(range(len(residues)), total = len(residues)):
            # Distances from sequence `row` to every later sequence (pairs row < col)
            distances = (residues[row + 1:] != residues[row]).sum(axis = 1)
            single_pdf += np.bincount(distances, minlength = length + 1)

        pdfs.append(list(single_pdf))
        Ts.append(T)
        thresholds.append(threshold)
        lengths.append(length)

    # Sort by T; argsort (unlike Ts.index) keeps every file when several share the same T
    rows = [pdfs[i] + [Ts[i], thresholds[i], lengths[i]] for i in np.argsort(Ts, kind = 'stable')]
    if len({len(row) for row in rows}) <= 1:
        pdfs = np.array(rows, dtype = float)
    else:
        # Different sequence lengths give rows of different sizes: keep them in an object array
        pdfs = np.empty(len(rows), dtype = object)
        for i, row in enumerate(rows):
            pdfs[i] = np.array(row, dtype = float)
    np.save(pdfs_path, pdfs)

    return pdfs

### Multi-T sequence pdf plot

In [None]:
### Distance pdfs plot
def plot_pdfs(pdfs, show_bool = True, savefig_bool = True):
    """Plot the normalised pairwise Hamming-distance distribution of each simulation.

    Parameters
    ----------
    pdfs : numpy.ndarray
        Rows laid out as ``[count(d=0), ..., count(d=L), T, threshold, L]``, i.e. the
        output of ``process_sequences``.
    show_bool : bool, optional
        If False, the figure is closed instead of being left open for display.
    savefig_bool : bool, optional
        Save the figure as ``../results/distances_pdf.png``.
    """
    fig = plt.figure(figsize = (9, 6))

    ax = fig.add_subplot(1, 1, 1)
    ax.set_title('Distances distribution between unique mutated sequences at fixed T')
    for single_pdf in pdfs:
        counts = np.asarray(single_pdf[:-3], dtype = float)
        T, threshold, length = single_pdf[-3], single_pdf[-2], single_pdf[-1]
        total = counts.sum()
        # A file with a single sampled sequence has no pairs: avoid 0/0
        density = counts / total if total > 0 else counts
        # row[-3] is T and row[-2] the threshold; label with both to avoid ambiguity
        ax.plot(np.arange(0, length + 1), density, label = f'T = {T} - threshold = {threshold}')
    ax.set_xlabel(r'd$_{Hamm}$')
    ax.set_ylabel(r'$\rho$(d$_{Hamm}$)')
    ax.legend()
    ax.grid(True)

    if savefig_bool:
        fig.savefig('../results/distances_pdf.png', bbox_inches = 'tight')
    if not show_bool:
        plt.close(fig)

In [None]:
# Compute the pairwise-distance histograms (load_bool=True would reuse the saved pdfs.npy)
pdfs = process_sequences(
    data_directory,
    step,
    load_bool=False,
)

In [None]:
# Plot the distance distributions (default flags spelled out)
plot_pdfs(pdfs, show_bool=True, savefig_bool=True)

### Fixed-T data plotting

In [None]:
### Effective energy, ddG and Hamming distance for fixed T and threshold simulation
def plot_energy_distance(filename, discarded_mutations, savefigs_bool = True):
    """Plot effective energy, ddG and Hamming distance against generation for one run.

    Dashed red/blue lines mark the max/mean of each quantity over the rows after the
    first ``discarded_mutations``. The numbers in the legend are computed on those same
    rows.

    Parameters
    ----------
    filename : str
        ``data_*`` file for a single T and threshold (read with ``load_data``).
    discarded_mutations : int
        Number of leading rows excluded from the max/mean reference lines.
    savefigs_bool : bool, optional
        Save the figure under ``../results/``.
    """
    file_data = load_data(filename)
    generations = file_data[:, 0]
    kept = file_data[discarded_mutations:]
    flat = np.ones(len(file_data))  # scaled by a constant to draw a horizontal reference line

    T, threshold, seq_length = file_data[0, -2], file_data[0, -3], file_data[0, -1]
    T_str = str(T)[:1] + str(T)[2:7]
    t_str = str(threshold)[:1] + str(threshold)[2:7]
    file_id = 'T' + T_str + '_t' + t_str

    fig = plt.figure(figsize = (9, 18))
    # Not a raw string: the old fr'...\n...' rendered a literal "\n" in the title
    fig.suptitle('Effective energy, $\\Delta\\Delta$G and Hamming distance\n'
                 f'(T = {format(T, ".5f")} - threshold = {format(threshold, ".3f")})', y = 0.93)

    energy_max, energy_mean = kept[:, 1].max(), kept[:, 1].mean()
    ax = fig.add_subplot(3, 1, 1)
    ax.plot(generations, file_data[:, 1])
    ax.plot(generations, flat * energy_max, linestyle='--', color='red', label=f'max energy (after {discarded_mutations} muts): {format(energy_max, ".4f")}')
    ax.plot(generations, flat * energy_mean, linestyle='--', color='blue', label=f'mean energy (after {discarded_mutations} muts): {format(energy_mean, ".4f")}')
    ax.plot(generations, file_data[:, -3], linestyle='--', color='grey', label=f'threshold: {threshold}')
    ax.set_ylabel(r'E$_{eff}$')
    ax.set_xlabel(r'generation')
    ax.legend()
    ax.grid(True)

    ddG_max, ddG_mean = kept[:, 2].max(), kept[:, 2].mean()
    ax = fig.add_subplot(3, 1, 2)
    ax.plot(generations, file_data[:, 2])
    ax.plot(generations, flat * ddG_max, linestyle='--', color='red', label=f'max ddG (after {discarded_mutations} muts): {format(ddG_max, ".4f")}')
    ax.plot(generations, flat * ddG_mean, linestyle='--', color='blue', label=f'mean ddG (after {discarded_mutations} muts): {format(ddG_mean, ".4f")}')
    ax.set_ylabel(r'$\Delta\Delta$G')
    ax.set_xlabel(r'generation')
    ax.legend()
    ax.grid(True)

    distance_max, distance_mean = kept[:, 3].max(), kept[:, 3].mean()
    ax = fig.add_subplot(3, 1, 3)
    ax.plot(generations, file_data[:, 3])
    ax.plot(generations, flat * distance_max, linestyle='--', color='red', label=f'max distance (after {discarded_mutations} muts): {distance_max}')
    ax.plot(generations, flat * distance_mean, linestyle='--', color='blue', label=f'mean distance (after {discarded_mutations} muts): {int(distance_mean)}')
    ax.plot(generations, file_data[:, -1], linestyle='--', color='grey', label=f'sequence length: {seq_length}')
    ax.set_ylabel(r'd$_{Hamm}$')
    ax.set_xlabel(r'generation')
    ax.legend()
    ax.grid(True)

    if savefigs_bool:
        fig.savefig('../results/plot_energy_distance_' + file_id + '.png', bbox_inches='tight')


### Effective energy, ddG and Hamming distance for fixed T and threshold simulation
def hist_energy_distance(filename, discarded_mutations, savefigs_bool = True):
    """Histogram effective energy, ddG and Hamming distance over three parts of a run.

    The rows of the file are split into three consecutive, (almost) equal parts. In the
    3x3 grid, each column is one part and each row is one quantity.

    Parameters
    ----------
    filename : str
        ``data_*`` file for a single T and threshold (read with ``load_data``).
    discarded_mutations : int
        Not used here; kept so the signature matches ``plot_energy_distance``.
    savefigs_bool : bool, optional
        Save the figure under ``../results/``.
    """
    file_data = load_data(filename)
    T, threshold, seq_length = file_data[0, -2], file_data[0, -3], file_data[0, -1]
    T_str = str(T)[:1] + str(T)[2:7]
    t_str = str(threshold)[:1] + str(threshold)[2:7]
    file_id = 'T' + T_str + '_t' + t_str

    colors = ['red', 'blue', 'green']
    labels = ['first part', 'second part', 'third part']

    fig = plt.figure(figsize = (18, 18))
    # Backslashes escaped explicitly ('\D' is an invalid escape in a normal string)
    fig.suptitle('Effective energy, $\\Delta\\Delta$G and Hamming distance distribution through the simulation\n'
                 f'(T = {format(T, ".5f")} - threshold = {format(threshold, ".3f")} - sequence length = {seq_length})', y = 0.93)

    energy_bins = np.linspace(0, file_data[:, 1].max(), 51)
    # ddG is not an integer distance: give it 50 bins over its own observed range
    ddG_bins = np.histogram_bin_edges(file_data[:, 2], bins = 50)
    distance_bins = np.linspace(0, seq_length, int(seq_length) + 1, dtype = int)

    # (column of file_data, bins, y label, x label) for each row of the grid
    quantities = (
        (1, energy_bins, r'$\rho$(E$_{eff}$)', r'E$_{eff}$'),
        (2, ddG_bins, r'$\rho$($\Delta\Delta$G)', r'$\Delta\Delta$G'),
        (3, distance_bins, r'$\rho$(d$_{Hamm}$)', r'd$_{Hamm}$'),
    )

    n_rows = len(file_data)
    for i, color in enumerate(colors):
        data_fraction = file_data[int(i * n_rows / 3.) : int((i + 1) * n_rows / 3.), :]
        for grid_row, (column, bins, ylabel, xlabel) in enumerate(quantities):
            ax = fig.add_subplot(3, 3, 3 * grid_row + i + 1)
            ax.hist(data_fraction[:, column], bins = bins, density = True, color = color)
            if i == 0: ax.set_ylabel(ylabel)
            if i == 1: ax.set_xlabel(xlabel)
            ax.grid(True)

    # Legend attached to the last (bottom-right) axes, anchored outside the grid
    patches = [mpatches.Patch(color = color, label = label) for color, label in zip(colors, labels)]
    ax.legend(handles = patches, bbox_to_anchor = (1.05, 2.2), loc = 2, borderaxespad = 0.)

    if savefigs_bool:
        fig.savefig('../results/hist_energy_distance_' + file_id + '.png', bbox_inches = 'tight')

In [None]:
# Data files list (sorted: glob order is arbitrary, sorting makes the plots reproducible)
dirpath = data_directory + '/data_*'
filelist = sorted(glob.glob(dirpath))
filelist

In [None]:
# Plot the per-generation traces of every fixed-T simulation file
for filename in filelist:
    plot_energy_distance(filename, discarded_mutations)

### Fixed-T last sequence contact map plot

In [None]:
### Calculate contact map through esm model and binarize it (optional)
def calculate_contacts(model, batch_converter, protein, p_value = 0.5, binary_bool = True, show_bool = False):
    """Predict a protein's contact map with an ESM model and optionally binarise it.

    Parameters
    ----------
    model : callable
        ESM model. It is called as ``model(tokens, return_contacts=True)`` and its
        result is read through ``results['contacts'][0]``.
    batch_converter : callable
        ESM batch converter; ``protein`` is passed to it as a one-element batch.
    protein : tuple of (str, str)
        ``(label, sequence)`` pair.
    p_value : float, optional
        Binarisation cut-off: scores strictly greater than it become 1, the rest 0.
    binary_bool : bool, optional
        Return the 0/1 integer map if True, the raw scores otherwise.
    show_bool : bool, optional
        Also plot the distribution of the raw scores and save it under ``../results/``.

    Returns
    -------
    numpy.ndarray
        Contact map: integer 0/1 if ``binary_bool``, raw scores otherwise.
    """
    _labels, _strs, batch_tokens = batch_converter([protein])
    with torch.no_grad():
        results = model(batch_tokens, return_contacts = True)
    # detach/cpu so the conversion also works when the model runs on a GPU
    contacts = results['contacts'][0].detach().cpu().numpy()

    if show_bool:
        fig, ax = plt.subplots(figsize = (9, 6))
        ax.hist(contacts.reshape(-1), bins = np.linspace(0, 1, 21), density = True)
        ax.set_xlabel('contact value')
        # Same ../results folder as every other figure (was 'results/').
        # NOTE(review): the file name says 'wt' whatever protein is passed - confirm intent.
        fig.savefig('../results/wt_contacts_value_distribution.png', bbox_inches = 'tight')
        plt.show()

    if binary_bool:
        contacts = (contacts > p_value).astype(int)

    return contacts


### Contact maps comparison plot for the wild-type and last mutant protein
def plot_contact_maps(model, batch_converter, filename, step, savefig_bool=True):
    """Overlay the contact maps of the wild type and the last mutant of one simulation.

    The upper triangle (green) shows the wild-type contacts. The lower triangle shows
    the contacts shared by both proteins (blue) and those present in only one of them
    (red). Empty cells are transparent so the three layers can be stacked.

    Parameters
    ----------
    model, batch_converter :
        ESM model and batch converter, forwarded to ``calculate_contacts``.
    filename : str
        ``mutants_*`` file read with ``load_sequences``.
    step : int
        Thinning passed to ``load_sequences``. The mutant used is the last kept
        sequence, which is the true last mutant when ``step == 1``.
    savefig_bool : bool, optional
        Save the figure under ``../results/``.
    """
    # load_sequences returns five values; the mutant is the last element of the list
    sequence_wt, sequences, T, threshold, _length = load_sequences(filename, step)
    sequence_mt = sequences[-1]
    T_str = str(T)[:1] + str(T)[2:7]
    t_str = str(threshold)[:1] + str(threshold)[2:7]
    file_id = 'T' + T_str + '_t' + t_str

    contacts_wt = calculate_contacts(model, batch_converter, ('Wild-type', sequence_wt)).astype(float)
    contacts_mt = calculate_contacts(model, batch_converter, ('Mutant', sequence_mt)).astype(float)

    common_contacts = contacts_wt * contacts_mt
    # For 0/1 maps, "present in exactly one of the two" is an element-wise XOR
    different_contacts = (contacts_wt != contacts_mt).astype(float)

    # Wild type in the upper triangle, comparison in the lower one (diagonal kept in both)
    contacts_wt = np.triu(contacts_wt)
    common_contacts = np.tril(common_contacts)
    different_contacts = np.tril(different_contacts)

    # Zeros become NaN, which the colormaps below draw as fully transparent
    for contact_map in (contacts_wt, common_contacts, different_contacts):
        contact_map[contact_map == 0.] = np.nan

    def _transparent_cmap(name):
        # plt.get_cmap exists across matplotlib versions (plt.cm.get_cmap was removed in 3.9);
        # copy so set_bad does not alter a shared colormap instance
        cmap = copy.copy(plt.get_cmap(name))
        cmap.set_bad(alpha = 0)
        return cmap

    fig = plt.figure(figsize = (12, 9))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title(f'Contact maps comparison between wild-type and last mutant protein\n(T = {T} - threshold = {threshold})')
    for contact_map, cmap_name in ((contacts_wt, 'Greens'), (common_contacts, 'Blues'), (different_contacts, 'Reds')):
        ax.imshow(contact_map, cmap=_transparent_cmap(cmap_name), vmin=0, vmax=1)
    ax.grid(True)

    colors = ['green', 'blue', 'red']
    labels = ['original', 'common', 'different']
    patches = [mpatches.Patch(color=color, label=label) for color, label in zip(colors, labels)]
    ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

    if savefig_bool:
        fig.savefig('../results/contacts_' + file_id + '.png', bbox_inches='tight')

In [None]:
# Sequences files list (sorted: glob order is arbitrary, sorting makes the plots reproducible)
dirpath = data_directory + '/mutants_*'
filelist = sorted(glob.glob(dirpath))
filelist

In [None]:
# Plot wild-type vs mutant contact maps for every mutants_* file
for filename in filelist:
    plot_contact_maps(model, batch_converter, filename, step)