In [73]:
import os
import sys

import numpy as np
import pandas as pd
import scipy.stats
import scipy.optimize

import itertools

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sn
from plotnine import *
#Suppress pivot warning when saving plots with plotnine
import warnings
warnings.simplefilter("ignore")

%load_ext watermark

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark


In [42]:
#Set plotting defaults
sn.set_style('ticks')
mpl.rcParams.update({'text.color': 'black', 'axes.labelcolor': 'black', 
                     'xtick.color': 'black', 'ytick.color': 'black', 'figure.dpi':100, 'savefig.dpi':300,
                     'savefig.bbox': 'tight', 'savefig.transparent': True, 'mathtext.default':'regular'})
sn.set_context('notebook')
%matplotlib inline
%config InlineBackend.figure_format='retina'

In 20220420_sim1inputTitration_randomParams.ipynb, we simulated 1-input dimerization networks of various sizes (3-6 monomer species) with Latin-hypercube-sampled parameters (binding affinities $K$ and accessory monomer concentrations $a$). Given the equilibrium concentrations of the dimers, here we try to fit sine and cosine target functions with a linear combination of the dimers. The output weights are constrained to be nonnegative, since I believe this will make the corresponding synthetic network easier to engineer. We use [scipy.optimize.nnls](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.nnls.html) to perform the fitting. 

In [43]:
def nnls_fun(x, b, nrows, ncols):
    """
    Solve argmin_w || A w - b ||_2 subject to w ≥ 0 with scipy.optimize.nnls,
    where the design matrix A is passed in flattened form.

    Accepting A as a 1d vector lets this function be applied along one axis
    of a 2d numpy array (e.g. with np.apply_along_axis).

    Parameters
    ----------
    x : ndarray, shape (nrows * ncols, )
        Flattened design matrix A.
    b : array_like, shape (nrows, )
        Target vector.
    nrows, ncols : int
        Shape used to restore A from x.
    Returns
    -------
    1d array of length ncols + 1: the ncols fitted coefficients followed by
    the residual norm.
    """
    design_matrix = x.reshape((nrows, ncols))
    coefficients, residual_norm = scipy.optimize.nnls(design_matrix, b)
    return np.concatenate((coefficients, np.atleast_1d(residual_norm)))

In [44]:
def fit_dimers_nnls(simfile, target_function, 
                    m = 3, n_input = 1, norm = 'max'):
    """
    For the input target function, perform non-negative least squares fitting on 
    the equilibrium dimer concentrations in simfile, one fit per parameter universe. 
    
    Parameters
    ----------
    simfile : string
        path to .npy file containing results of the network simulation
        Loaded into S_all, indexed as (titration point, species, universe).
        The first m species are skipped; the remaining ones are used as dimers.
    target_function : array_like shape (n_titration, )
        array containing the output values for fitting linear combinations of dimers
        n_titration corresponds to the length of the input titration used for network simulations
        n_titration must equal S_all.shape[0]
    m : int. Default 3. 
        Number of monomer species. 
    n_input : int. Default 1
        Number of input monomer species. 
        Currently unused; kept so existing calls keep working.
    norm : string. Default 'max'.
        Indicates how to normalize the output concentrations. 
        'max' divides every species in every universe by its maximum over the titration.
        Any other value leaves the concentrations unnormalized.
    Returns
    -------
    combined_df : DataFrame, shape (n_titration * (S_all.shape[2] + 1), 4)
        target function and fitted output for all parameter universes, long format
        Columns are 'univ', 'resid', 'x', 'y'. The target rows come first, 
        with univ = 'target' and resid = 0.0.
    fit_df : DataFrame, shape (S_all.shape[2], n_titration + 2)
        fitted output for all parameter universes, wide format
        One column per titration point plus 'univ' and 'resid'.
    weights: ndarray, shape (number of dimers, S_all.shape[2])
        fitted dimer weights for all parameter universes
    Raises
    ------
    ValueError
        If target_function is not 1d with length S_all.shape[0].
    """
    S_all = np.load(simfile)
    target = np.asarray(target_function)
    
    n_titration, n_species, n_univ = S_all.shape
    n_dimers = n_species - m
    
    if target.shape != (n_titration,):
        raise ValueError(f'target_function must have shape ({n_titration},) to match the '
                         f'titration length of the simulation; got {target.shape}')
    
    #Note the normalization before nnls fitting on each dimer separately. 
    #Convenient for putting the dimers on the same scale but probably not realistic biologically.
    if norm == 'max':  
        peak = S_all.max(axis = 0, keepdims = True)
        #A species that is zero over the whole titration would give 0/0 = NaN, which nnls rejects.
        #Divide by 1 instead so that it simply stays at zero.
        S_all = S_all/np.where(peak == 0, 1.0, peak)
    
    dimers = S_all[:,m:,:]
    
    #One independent nnls problem per universe: design matrix is (n_titration, n_dimers)
    weights = np.zeros((n_dimers, n_univ))
    resid = np.zeros(n_univ)
    for univ in range(n_univ):
        weights[:,univ], resid[univ] = scipy.optimize.nnls(dimers[:,:,univ], target)
    
    #fit[u, t] = sum_d dimers[t, d, u] * weights[d, u]
    fit = np.einsum('tdu,du->ut', dimers, weights)
    
    fit_df = pd.DataFrame(fit)
    fit_df['univ'] = np.arange(n_univ)
    fit_df['resid'] = resid
    fit_df_long = fit_df.melt(id_vars=['univ', 'resid'], value_name='y', var_name='x')
    
    target_fun_df = pd.DataFrame({'univ': ['target']*n_titration,
                             'resid': [0.0]*n_titration,
                             'x': np.arange(n_titration),
                             'y': target})
    
    combined_df = pd.concat((target_fun_df, fit_df_long))
    
    return combined_df, fit_df, weights

In [74]:
def plot_best_output(target_function, target_function_name, fit_df, m = 3, 
                     top_n = 25, save = False, outfileprefix = '', n_titration = None):
    """
    Plot output curves for the target function and the best fit dimerization networks,
    one facet per curve, ordered by increasing residual. 
    
    Parameters
    ----------
    target_function : array_like shape (n_titration, )
        array containing the output values for fitting linear combinations of dimers
        Only its length is used here (as the default for n_titration).
    target_function_name : string
       Name of the target function. Used for filename. 
    fit_df : DataFrame with columns 'univ', 'resid', 'x', 'y'
        target function and fitted output for all parameter universes, long format
        This is the FIRST value returned from fit_dimers_nnls() (combined_df),
        not the wide-format fit_df. 
    m : int. Default 3. 
        Number of monomer species. Used for filename.
    top_n : int. Default 25.
        Number of curves to plot. The first top_n*n_titration rows after sorting 
        by 'resid' are kept. In the frame built by fit_dimers_nnls() the target rows 
        have resid 0.0, so the target takes one of the top_n panels.
    save : Bool. Default False
        If True, save plot figure
    outfileprefix : string. Default ''
        relative path for saving figure. 
    n_titration : int or None. Default None.
        Number of rows per curve in fit_df. If None, len(target_function) is used.
        Accepted so that calls passing n_titration explicitly keep working.
    Returns
    -------
    p : plotnine figure object. 
    """
    if n_titration is None:
        n_titration = len(target_function)
    #Copy the selection so the column assignment below acts on an independent frame
    #rather than on a slice (avoids pandas' SettingWithCopy ambiguity)
    top_df = fit_df.sort_values('resid').head(top_n*n_titration).copy()
    #Categories in order of appearance so that facets follow the residual ranking
    top_df['univ'] = pd.Categorical(top_df['univ'], categories= top_df['univ'].unique())
    ncol = int(np.sqrt(top_n))
    
    p = (ggplot(top_df, aes(x= 'x', y='y', group = 1))
     + geom_line()
     + facet_wrap('~univ', ncol = ncol)
     + theme_classic()
     + scale_y_continuous(breaks = [0,1])
     + theme(strip_background = element_blank(),
            axis_text_x=element_blank(),
            text = element_text(family='Helvetica', color='black')))
    if save:
        p.save(f'{outfileprefix}fitCurves_{target_function_name}Fun_{m}M_top{top_n}.pdf', dpi = 300)
    return p

In [75]:
#Target curves: sine and cosine evaluated at the integer titration indices 0..n_titration-1
#(in radians), shifted and rescaled from [-1, 1] to [0, 1]
n_titration = 10
target_function_sin = (1 + np.sin(np.arange(n_titration))) / 2
target_function_cos = (1 + np.cos(np.arange(n_titration))) / 2


In [31]:
fit_sin_3N_df, _, _ = fit_dimers_nnls('../data/20220420_1input_randomParams/S_all_3M_1000k.npy',
                                      target_function_sin, m=3, n_input=1, norm='max')

In [34]:
outfileprefix = '../plots/20220420_1input_randomParams/'
if not os.path.isdir(outfileprefix):
    os.mkdir(outfileprefix)

In [35]:
_ = plot_best_output(target_function_sin, 'sin', fit_sin_3N_df, m = 3, 
                     n_titration = 10, top_n = 25, save = True, outfileprefix = outfileprefix)



In [37]:
fit_cos_3M_df, _, _ = fit_dimers_nnls('../data/20220420_1input_randomParams/S_all_3M_1000k.npy',
                                      target_function_cos, m=3, n_input=1, norm='max')

In [46]:
_ = plot_best_output(target_function_cos, 'cos', fit_cos_3M_df, m = 3, 
                     n_titration = 10, top_n = 25, save = True, outfileprefix = outfileprefix)

In [48]:
fit_sin_4M_df, _, _ = fit_dimers_nnls('../data/20220420_1input_randomParams/S_all_4M_1000k.npy',
                                      target_function_sin, m=4, n_input=1, norm='max')

In [51]:
_ = plot_best_output(target_function_sin, 'sin', fit_sin_4M_df, m = 4, 
                     n_titration = 10, top_n = 25, save = True, outfileprefix = outfileprefix)

In [52]:
fit_cos_4M_df, _, _ = fit_dimers_nnls('../data/20220420_1input_randomParams/S_all_4M_1000k.npy',
                                      target_function_cos, m=4, n_input=1, norm='max')

In [53]:
_ = plot_best_output(target_function_cos, 'cos', fit_cos_4M_df, m = 4, 
                     n_titration = 10, top_n = 25, save = True, outfileprefix = outfileprefix)

In [54]:
fit_sin_5M_df, _, _ = fit_dimers_nnls('../data/20220420_1input_randomParams/S_all_5M_1000k.npy',
                                      target_function_sin, m=5, n_input=1, norm='max')

In [55]:
_ = plot_best_output(target_function_sin, 'sin', fit_sin_5M_df, m = 5, 
                     n_titration = 10, top_n = 25, save = True, outfileprefix = outfileprefix)

In [56]:
fit_cos_5M_df, _, _ = fit_dimers_nnls('../data/20220420_1input_randomParams/S_all_5M_1000k.npy',
                                      target_function_cos, m=5, n_input=1, norm='max')

In [57]:
_ = plot_best_output(target_function_cos, 'cos', fit_cos_5M_df, m = 5, 
                     n_titration = 10, top_n = 25, save = True, outfileprefix = outfileprefix)

In [58]:
fit_sin_6M_df, _, _ = fit_dimers_nnls('../data/20220420_1input_randomParams/S_all_6M_1000k.npy',
                                      target_function_sin, m=6, n_input=1, norm='max')

In [59]:
_ = plot_best_output(target_function_sin, 'sin', fit_sin_6M_df, m = 6, 
                     n_titration = 10, top_n = 25, save = True, outfileprefix = outfileprefix)

In [60]:
fit_cos_6M_df, _, _ = fit_dimers_nnls('../data/20220420_1input_randomParams/S_all_6M_1000k.npy',
                                      target_function_cos, m=6, n_input=1, norm='max')

In [61]:
_ = plot_best_output(target_function_cos, 'cos', fit_cos_6M_df, m = 6, 
                     n_titration = 10, top_n = 25, save = True, outfileprefix = outfileprefix)

In [76]:
#Record the versions of the imported packages (and of plotnine, which is imported with *)
%watermark --iversions
%watermark -p plotnine

matplotlib: 3.5.1
numpy     : 1.20.3
seaborn   : 0.11.2
sys       : 3.9.7 (default, Sep 16 2021, 08:50:36) 
[Clang 10.0.0 ]
pandas    : 1.4.1
plotnine  : 0.8.0
scipy     : 1.7.3

plotnine: 0.8.0

