# Phase separation prediction (TWIN2PIPSA version)

This Colab notebook enables prediction of IDR transfer free energies and saturation concentrations from sequence.
- The models have been trained on CALVADOS 2 slab simulation data.
- Conditions are fixed to T=293 K and I=150 mM.

<b>How to cite this notebook:</b>
- S. von Bülow, G. Tesei, F. K. Zaidi, T. Mittag, K. Lindorff-Larsen, __Prediction of phase-separation propensities of disordered proteins from sequence__ _Proceedings of the National Academy of Sciences_ 2025

Further references:
- Use of $\nu_\mathrm{SVR}$: <br>
G. Tesei, A. I. Trolle, N. Jonsson, J. Betz, F. Pesce, K. E. Johansson, K. Lindorff-Larsen, __Conformational ensembles of the human intrinsically disordered proteome__ _Nature_ 2024 626, 897–904 DOI: https://doi.org/10.1038/s41586-023-07004-5
- CALVADOS 2 model: <br>
G. Tesei and K. Lindorff-Larsen, __Improved predictions of phase behaviour of intrinsically disordered proteins by tuning the interaction range [version 2; peer review: 2 approved]__ _Open Research Europe_ 2023 2(94) DOI: https://doi.org/10.12688/openreseurope.14967.2

Author: Sören von Bülow (soren.bulow@bio.ku.dk)

In [None]:
#@title <b>Preliminary operations</b>

import os
import warnings
warnings.simplefilter("ignore")

print('Setting up the environment...')

!rm -r sample_data &> dump
!rm sequence.* &> dump
!rm svr_model_nu* &> dump
!rm residues* &> dump
!rm example* &> dump
!rm predictor.* &> dump
!rm *joblib &> dump

github_folder = 'https://raw.githubusercontent.com/KULL-Centre/_2024_buelow_PSpred/twin2pipsa'

print(f'Downloading files from {github_folder}')

os.system(f'wget {github_folder}/scripts_colab/sequence.py')
os.system(f'wget {github_folder}/scripts_colab/predictor.py')
os.system(f'wget {github_folder}/data/residues.csv')
os.system(f'wget {github_folder}/data/example.fasta')
os.system(f'wget -O model_dG.joblib {github_folder}/models/idrome90/mlp/dG/model.joblib')
os.system(f'wget -O model_logcdil_mgml.joblib {github_folder}/models/idrome90/mlp/logcdil_mgml/model.joblib')
os.system(f'wget -O svr_model_nu.joblib {github_folder}/models/svr_model_nu.joblib')

os.system(f'wget {github_folder}/data/IDRome_DB_full.csv')

!pip install 'scikit-learn==1.3' MDAnalysis biopython numba &> dump

import joblib
import sklearn
import pandas as pd

import sequence
from predictor import *

import numpy as np
import MDAnalysis as mda
import matplotlib.pyplot as plt

from tqdm import tqdm
from google.colab import files

ncrossval = 50
print('Environment set up.')
print('Loading models...')

colors = np.array([
    [1., 101., 165.],
    [220., 140., 46.], # [220., 175., 46.],
    [31., 107., 65.],
    [200., 80., 45.],
    [127., 0., 255.]
])
colors /= 255.

residues = pd.read_csv('residues.csv').set_index('one')
nu_file = 'svr_model_nu.joblib'

df_full = pd.read_csv('IDRome_DB_full.csv').set_index('id')

features = ['mean_lambda', 'faro', 'shd', 'ncpr', 'fcr', 'scd', 'ah_ij','nu_svr']

features_clean = {
    'mean_lambda' : 'lambda',
    'faro' : 'f(aromatics)',
    'shd' : 'SHD',
    'ncpr' : 'NCPR',
    'fcr' : 'FCR',
    'scd' : 'SCD',
    'ah_ij' : 'LJ pairs',
    'nu_svr' : 'nu(SVR)'
}

print('Input features are:')
print('>>>>> '+ ', '.join([features_clean[fe] for fe in features]))

!touch calvados.py

models = {}
models['dG'] = joblib.load(f'model_dG.joblib')
models['logcdil_mgml'] = joblib.load(f'model_logcdil_mgml.joblib')

mltype = 'mlp'
alpha = 5
layers = (10,10)

targets = ['dG','logcdil_mgml']
targets_clean = {
    'dG' : 'Delta G',
    'logcdil_mgml' : 'Saturation concentration',
}

print('Models loaded.')

# Predict single sequence

Paste in an amino acid sequence to predict its phase separation propensity. The termini can be charged (unmodified) or uncharged (capped); for long sequences this choice should have only a minor effect on the prediction.

In [None]:
#@title <b>Run single IDR prediction</b>

# Best-effort removal of the sample_data folder.
try:
    os.rmdir('sample_data')
except OSError:
    # Folder is absent or not empty: nothing to do.
    pass

NAME = "LAF1_mod" #@param {type:"string"}
SEQUENCE = "MESNQSNNGGSGNAALNRGGRYVPPHLRGGDGGAAAAASAGGDDRRGGAGGGGYRRGGGNSGGGGGGGYDRGYNDNRDDRDNRGGSGGYGRDRNYEDRGYNGGGGGGGNRGYNNNRGGGGGGYNRQDRGDGGSSNFSRGGYNNRDEGSDNRGSGRSYNNDRRDNGGDGLEHHHHHH" #@param {type:"string"}
CHARGE_TERMINI = True # @param {type:'boolean'}
TEMPERATURE = "293 K (fixed)" # @param ['293 K (fixed)']
IONIC_STRENGTH = "150 mM (fixed)" # @param ['150 mM (fixed)']

# Remove every kind of whitespace (spaces, tabs, line breaks), not only
# plain spaces, so that pasted multi-line sequences are handled too.
seq = ''.join(SEQUENCE.split())
if seq != SEQUENCE:
    print('Blank character(s) found in the provided sequence. Sequence has been corrected, but check for integrity.')

print('='*80)
print(f'NAME: {NAME}')
print(f'SEQUENCE: {seq}')

# seqfeats supplies the molecular weight used below; X is the model input.
seqfeats = sequence.SeqFeatures(seq,residues=residues,charge_termini=CHARGE_TERMINI)
X = X_from_seq(seq,features,residues=residues,charge_termini=CHARGE_TERMINI,nu_file=nu_file)

for target in targets:
  print('-'*80)
  ys = models[target].predict(X)
  ys_m = np.mean(ys)

  if target == 'dG':
    # Reported interval: prediction +/- 1 kT.
    output = ys_m
    unit = 'kT'
    lower = ys_m - 1
    upper = ys_m + 1
  elif target == 'logcdil_mgml':
    # The model output is exponentiated to give mg/mL; the reported interval
    # is +/- 0.82 on the model's (log) scale.
    output = np.exp(ys_m)
    lower = np.exp(ys_m-0.82)
    upper = np.exp(ys_m+0.82)
    unit = 'mg/mL'

  print(f'{targets_clean[target]:25s} = {output:5.1f} {unit:6s} ({lower:.1f} -- {upper:.1f} {unit})')
  if target == 'logcdil_mgml':
    # mg/mL -> uM via the molecular weight (assumes seqfeats.mw is in g/mol -- TODO confirm).
    output_uM = output / seqfeats.mw * 1e6
    lower_uM = lower / seqfeats.mw * 1e6
    upper_uM = upper / seqfeats.mw * 1e6
    print(f'{"":25s} = {output_uM:5.1f} {"uM":6s} ({lower_uM:.1f} -- {upper_uM:.1f} {"uM"})')
print('='*80)

# Histogram of the reference dG predictions with the query sequence marked.
fig, ax = plt.subplots(figsize=(4,3))

target = 'dG'

ys = models[target].predict(X)
ys_m = np.mean(ys)

# Same histogram drawn twice: a labelled outline plus a translucent fill.
reference_dG = df_full['dG_pred']
shared_hist_kwargs = dict(bins=100, density=True, color='black')
ax.hist(reference_dG, histtype='step', label='IDRome', **shared_hist_kwargs)
ax.hist(reference_dG, alpha=0.2, **shared_hist_kwargs)

# Vertical line at the prediction for the query sequence.
ax.axvline(ys_m, color=colors[2], lw=2, label=NAME)

ax.set_yscale('log')
ax.grid(alpha=0.3)
ax.set_xlabel(f'Predicted {targets_clean[target]} [kT]')
ax.set_ylabel('PDF')
ax.set_xlim(None, 0.4)
ax.legend()
fig.tight_layout()
fig.savefig(f'dG_distribution_{NAME}.pdf')

In [None]:
#@title <b>Relate to input features</b>

# Binning range [lower, upper] for each sequence feature; features without an
# entry are binned over their full data range.
limits = dict(
    scd=[-5,6],
    ncpr=[-0.25, 0.25],
    kappa=[0.0, 0.8],
    faro=[0., 0.15],
    nu=[0.2, 0.7],
    nu_svr=[0.45, 0.62],
    mw=[0., 60000],
    fcr=[0., 0.6],
    mean_lambda=[0.3, 0.6],
    shd=[1.5, 6.],
    N=[0, 800],
    ah_ij=[-0.9,-0.3],
)

def bin_data(xs,ys,nbins,drange=None):
    """Group the values ys into nbins equal-width bins of xs.

    Parameters
    ----------
    xs, ys : paired sequences of equal length; xs decides the bin, ys is stored.
    nbins : number of bins.
    drange : optional (lower, upper) pair fixing the binned range (list, tuple
        or array). If None, the range spans min(xs) to max(xs).

    Returns
    -------
    bins : array of the nbins+1 bin edges (numpy.linspace).
    y_binned : list of nbins lists with the ys values falling into each bin.

    Values at or below the first edge go into the first bin and values at or
    above the last edge go into the last bin, so out-of-range points are kept.
    """
    # `is None` instead of `== None`: the latter compares element-wise and
    # raises when drange is a numpy array.
    if drange is None:
        xmin, xmax = np.min(xs), np.max(xs)
    else:
        xmin, xmax = drange[0], drange[1]
    bins = np.linspace(xmin,xmax,nbins+1)
    y_binned = [[] for _ in range(nbins)]
    for x, y in zip(xs,ys):
        if x <= bins[0]:
            y_binned[0].append(y)
        elif x >= bins[-1]:
            y_binned[nbins-1].append(y)
        else:
            for idx in range(nbins):
                if bins[idx] <= x < bins[idx+1]:
                    y_binned[idx].append(y)
                    # A value belongs to exactly one bin; stop scanning.
                    break
    return bins, y_binned

target = 'dG'

# Strip whitespace the same way the single-sequence cell does, so that the
# features below describe the sequence that was actually predicted.
seq = ''.join(SEQUENCE.split())
seqfeats = sequence.SeqFeatures(seq,residues=residues,
                                charge_termini=CHARGE_TERMINI,
                                nu_file=nu_file)

# Recompute the model input here instead of reading the global X: other cells
# assign X as well, so its content depends on which cell was run last.
X_single = X_from_seq(seq,features,residues=residues,
                      charge_termini=CHARGE_TERMINI,nu_file=nu_file)
dG_single = np.mean(models[target].predict(X_single))

nbins = 84

Nfeat = len(features)

# One panel per feature (2 x 4 grid, matching the eight input features).
fig, ax = plt.subplots(2,4,figsize=(10,3.5),sharey=True)

# Reference dG predictions are identical for every panel.
ys = df_full['dG_pred'].to_numpy()

for idx, feat in enumerate(features):
    axij = ax[idx//4,idx%4]

    xs = df_full[feat].to_numpy()
    # None (no entry in limits) makes bin_data use the full data range.
    drange = limits.get(feat)
    edges, y_binned = bin_data(xs,ys,nbins,drange=drange)
    centers = (edges[:-1] + edges[1:]) / 2.

    # Per-bin mean, spread and population; empty bins give NaN and are not drawn.
    y_mean = np.array([np.mean(y) for y in y_binned])
    y_std = np.array([np.std(y) for y in y_binned])
    yLs = np.array([len(yb) for yb in y_binned])

    # One errorbar call per bin keeps the points unconnected.
    for e, ym, yst in zip(centers,y_mean,y_std):
        markers, caps, bars = axij.errorbar(e,ym,yerr=yst,capsize=1,marker='.',color=colors[0],zorder=10)
        for bar in bars:
            bar.set_alpha(0.3)
        for cap in caps:
            cap.set_alpha(0.3)

    # Secondary axis: number of reference sequences per bin.
    ax2 = axij.twinx()
    ax2.plot(centers,yLs,lw=0.8,alpha=1.,color=colors[1])
    ax2.fill_between(centers,yLs,alpha=0.3,color=colors[1])
    if idx in [3,7]:
        # Label only the right-most panel of each row.
        ax2.set_ylabel('IDR Frequency',color=colors[1],labelpad=4.0)
    # Head-room factor so the frequency curve stays in the lower part of the panel.
    if feat == 'faro':
        ax2.set(ylim=(0,max(yLs)*2))
    elif feat in ['scd', 'ncpr']:
        ax2.set(ylim=(0,max(yLs)*2.5))
    else:
        ax2.set(ylim=(0,max(yLs)*3))
    ax2.set_yticks([])
    ax2.set_yticklabels([])
    ax2.tick_params(axis='y', colors=colors[1])
    ax2.grid(False)

    # Mark the query sequence; only the marker carries the label so that a
    # legend would list NAME once.
    x_single = getattr(seqfeats,feat)
    axij.plot(x_single,dG_single,'o',color=colors[2],label=NAME)
    axij.axvline(x_single,color=colors[2])

    axij.set(xlabel=features_clean[feat])
    axij.set(ylim=(-7,None))
    axij.tick_params(axis='y', colors=colors[0])
    axij.grid(ls='dotted',alpha=0.5)

for idx in range(2):
    ax[idx,0].set_ylabel(f'{targets_clean[target]}',color=colors[0])
fig.tight_layout(pad=0,w_pad=1.08,h_pad=1.08)
fig.savefig(f'dG_vs_features_{NAME}.pdf')

## Predict variant effects

Below, each residue of the input sequence above is substituted, one at a time, with each of the 19 other amino acids. The resulting plot shows the effect of every substitution on the phase separation propensity. Negative values (red) indicate stronger phase separation and positive values (blue) indicate weaker phase separation compared to the input sequence.

In [None]:
#@title <b>Run variant effect prediction</b>

# Residue names ordered by their 'lambdas' value in the residue table.
p = np.argsort(residues['lambdas'].to_numpy())
sorted_names = residues.index[p]
aminoacids_sorted = "".join(sorted_names)

# Lookup maps built once and passed to every SeqFeatures call below.
ah_intgrl_map = sequence.make_ah_intgrl_map(residues)
lambda_map = sequence.make_lambda_map(residues)

target = 'dG'
model = models[target]

# Strip whitespace the same way the single-sequence cell does.
seq = ''.join(SEQUENCE.split())
wt_idx = np.zeros((len(seq)))
# dG_map[j, i]: predicted dG with residue i replaced by aminoacids_sorted[j].
dG_map = np.zeros((len(aminoacids_sorted),len(seq)))

# Wild-type reference prediction.
seq_feats = sequence.SeqFeatures(seq,residues=residues,charge_termini=CHARGE_TERMINI,nu_file=nu_file,
            ah_intgrl_map=ah_intgrl_map,lambda_map=lambda_map)

X = X_from_seq(seq,features,seq_feats=seq_feats)
ys = model.predict(X)
dG_wt = np.mean(ys)

for idx, s in tqdm(enumerate(seq),total=len(seq)):
    # Row of the wild-type residue, used to mark it in the plot.
    wt_idx[idx] = aminoacids_sorted.index(s)
    for jdx, a in enumerate(aminoacids_sorted):
        if a == s:
            # Substituting a residue with itself reproduces the wild type.
            dG_map[jdx,idx] = dG_wt
            continue

        seq_var = seq[:idx] + a + seq[idx+1:]

        # Separate names inside the loop keep seq_feats and X pointing at the
        # wild-type sequence after this cell has run.
        feats_var = sequence.SeqFeatures(seq_var,residues=residues,charge_termini=CHARGE_TERMINI,nu_file=nu_file,
                    ah_intgrl_map=ah_intgrl_map,lambda_map=lambda_map)
        X_var = X_from_seq(seq_var,features,seq_feats=feats_var)

        dG_map[jdx,idx] = np.mean(model.predict(X_var))

In [None]:
#@title <b>Plot variant effect results</b>

from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, ax = plt.subplots(figsize=(14,9))

# Heat map of the dG change relative to wild type, clipped to +/- 1 kT.
# Named `im` rather than `_`, which IPython uses for the last cell output.
im = ax.imshow(dG_map - dG_wt,cmap=plt.cm.RdBu,vmin=-1,vmax=1)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="1%", pad=0.1)
# Raw string: '\D' is not a valid escape sequence in a normal string literal.
plt.colorbar(im, cax=cax,label=r'$\Delta \Delta G$ [kT]')

print(f'{targets_clean["dG"]} wt: {dG_wt:.1f} kT')

# Open circles mark the wild-type residue at each position.
ax.plot(wt_idx,'o',fillstyle='none',color='black',markersize=4)
ax.set_xticks(np.arange(len(seq)))
ax.set_xticklabels(seq,fontsize=7)
ax.set_yticks(np.arange(len(aminoacids_sorted)))
ax.set_yticklabels(aminoacids_sorted,fontsize=7)
ax.grid(False)
ax.set(xlabel='Residues')

fig.tight_layout()
fig.savefig(f'{NAME}_variants.pdf')

# Batch prediction

Upload an input FASTA file. The code below predicts the phase separation propensity and the saturation concentration for every sequence in the file.

In [None]:
#@title <b>Run batch prediction</b>

#@markdown File name
FASTA_FILE = "example.fasta" #@param {type:"string"}
CHARGE_TERMINI = True # @param {type:'boolean'}
TEMPERATURE = "293 K (fixed)" # @param ['293 K (fixed)']
IONIC_STRENGTH = "150 mM (fixed)" # @param ['150 mM (fixed)']

# Ask for an upload when the file is not already in the working directory.
if not os.path.isfile(FASTA_FILE):
  print(f'Please upload file {FASTA_FILE}')
  uploaded = files.upload()
  if FASTA_FILE not in uploaded.keys():
    raise NameError(f'Could not find file {FASTA_FILE}')

records = sequence.read_fasta(FASTA_FILE)

print('-'*80)
print(f'FASTA FILE: {FASTA_FILE}')
print(f'NUMBER OF SEQUENCES: {len(records)}')

# Collect one dict per sequence and build the table once at the end, instead
# of growing a DataFrame cell by cell.
record_names = []
record_rows = []

for name, record in tqdm(records.items(),total=len(records)):
  seq = str(record.seq)
  seqfeats = sequence.SeqFeatures(seq,residues=residues,
                                  charge_termini=CHARGE_TERMINI,nu_file=nu_file)
  row = {'Sequence': seq}
  for feat in features:
    row[feat] = getattr(seqfeats,feat)
  # Local name so that the global X of the single-sequence cells is left alone.
  X_record = X_from_seq(seq,features,residues=residues,
                        charge_termini=CHARGE_TERMINI,nu_file=nu_file)
  for target in targets:
    ys_m = np.mean(models[target].predict(X_record))
    if target == 'dG':
      row['Delta G [kT]'] = ys_m
    elif target == 'logcdil_mgml':
      # The model output is exponentiated to give mg/mL, then converted to uM
      # via the molecular weight (assumes seqfeats.mw is in g/mol -- TODO confirm).
      cdil_mgml = np.exp(ys_m)
      row['Saturation concentration [mg/mL]'] = cdil_mgml
      row['Saturation concentration [uM]'] = cdil_mgml / seqfeats.mw * 1e6
  record_names.append(name)
  record_rows.append(row)

df_records = pd.DataFrame(record_rows, index=pd.Index(record_names, name='Name'))
df_records.to_csv('df_PSprediction.csv')

print('\n')
print('='*114)
print(f'{"Name":20s} {"Sequence":33s} {"Delta G":>10s} {"Saturation":>16s} {"Saturation":>16s}')
print(f'{"":20s} {"":33s} {"":>10s} {"concentration":>16s} {"concentration":>16s}')
print(f'{"":20s} {"":33s} {"[kT]":>10s} {"[mg/mL]":>16s} {"[uM]":>16s}')

print('='*114)
for key, val in df_records.iterrows():
  # Show at most the first 30 residues; "..." marks a truncated sequence.
  full_seq = val["Sequence"]
  seqpr = f'{full_seq[:30]:30s}'
  if len(full_seq) > 30:
    seqpr += '...'
  print(f'{key:20s} {seqpr:33s} {val["Delta G [kT]"]:10.1f} {val["Saturation concentration [mg/mL]"]:16.1f} {val["Saturation concentration [uM]"]:16.1f}')

In [None]:
#@title <b>Analysis</b>

fcolor = plt.cm.summer

# Side-by-side histograms of the two batch-prediction result columns.
hist_columns = ('Delta G [kT]', 'Saturation concentration [mg/mL]')

fig, ax = plt.subplots(1,2,figsize=(9,4))

for idx, axij in enumerate(ax):
  target = hist_columns[idx]
  axij.hist(df_records[target], bins=20, color=fcolor(0))

  axij.set_xlabel(f'{target}')
  axij.set_ylabel('Counts')
  axij.grid(alpha=0.3)
fig.tight_layout()

In [None]:
#@title <b>Download dataframe</font></b>

# Download the CSV written by the batch-prediction cell; run that cell first.
files.download('df_PSprediction.csv')