In [109]:
import os
import pandas as pd
import matplotlib.pyplot as plt
# Plot styling: matplotlib's 'default' style is active; the dark theme below is
# kept as a commented-out alternative.
#plt.style.use('dark_background')
plt.style.use('default')
import uv

# Directory holding the raw input files, one file per titration experiment.
DATA_DIR = 'test-data/'

# os.listdir returns entries in arbitrary order, so sort them: this keeps
# positional access such as datas[2], and the order in which plots are written,
# reproducible across machines and kernel restarts.
paths = [os.path.join(DATA_DIR, name) for name in sorted(os.listdir(DATA_DIR))]

# Parse every file with the project-local uv.P450 reader.
datas = [uv.P450(path) for path in paths]

In [123]:
import re
import numpy as np
from tqdm import tqdm
from scipy.optimize import curve_fit
from scipy.ndimage import gaussian_filter1d

def ctns(p, s):
    """Return True when regular expression ``p`` matches anywhere in string ``s``."""
    found = re.search(p, s)
    return found is not None
def mm(x, km, vmax):
    """Evaluate the hyperbolic saturation curve ``x * vmax / (km + x)``.

    ``x`` may be a scalar or an array-like; ``km`` and ``vmax`` are the two
    curve parameters. The parameter names are kept as-is because get_mm hands
    this function to scipy's curve_fit, which inspects its signature.
    """
    numerator = x * vmax
    denominator = km + x
    return numerator / denominator

def smooth(df,
           sigma=2,
           axis=-1,
           ):
    """Gaussian-smooth a pandas object while keeping its labels.

    Parameters
    ----------
    df : pandas.DataFrame or pandas.Series
        Numeric data to smooth.
    sigma : scalar
        Standard deviation of the Gaussian kernel, passed to
        ``gaussian_filter1d``.
    axis : int
        Axis of a DataFrame to filter along: ``0`` runs down the rows, the
        default ``-1`` runs across the columns. Ignored for a Series.

    Returns
    -------
    pandas.DataFrame
        Smoothed values with the input's index (and columns, for a DataFrame).
        A Series comes back as a single-column DataFrame whose column label
        is 0.

    Raises
    ------
    TypeError
        If ``df`` is neither a DataFrame nor a Series.
    """
    # Validate up front: previously an unsupported type reached ``return smth``
    # with the name unbound and failed with an unrelated UnboundLocalError.
    if isinstance(df, pd.DataFrame):
        is_frame = True
    elif isinstance(df, pd.Series):
        is_frame = False
    else:
        raise TypeError(
            f"smooth() expects a pandas DataFrame or Series, got {type(df).__name__}"
        )

    values = df.to_numpy()
    if not np.issubdtype(values.dtype, np.inexact):
        # gaussian_filter1d writes its result in the input dtype, so integer
        # (or bool) data would be silently truncated; float and complex input
        # is left untouched.
        values = values.astype(float)

    if is_frame:
        return pd.DataFrame(gaussian_filter1d(values, sigma, axis=axis),
                            index=df.index,
                            columns=df.columns)
    return pd.DataFrame(gaussian_filter1d(values, sigma),
                        index=df.index,
                        )

def c2(v1, c1, v2):
    """Solve ``c1 * v1 = c2 * v2`` for ``c2``.

    ``v1`` and ``c1`` are multiplied together and the product is divided by
    ``v2``; all three may be scalars or array-likes.
    """
    amount = v1 * c1
    return amount / v2

def get_mm(x,y):
    """Fit the ``mm`` curve to the points ``(x, y)`` and return ``(km, vmax)``.

    Both parameters are constrained to the range [0, inf) during the fit.
    The covariance matrix returned by curve_fit is discarded.
    """
    lower_bounds = (0, 0)
    upper_bounds = (np.inf, np.inf)
    fitted, _ = curve_fit(mm, x, y, bounds=(lower_bounds, upper_bounds))
    km, vmax = fitted
    return km, vmax

def r_squared(yi,yj):
    """Return the coefficient of determination of predictions ``yj`` for ``yi``.

    Computed as ``1 - SS_res / SS_tot`` with ``SS_res = sum((yi - yj)**2)`` and
    ``SS_tot = sum((yi - mean(yi))**2)``.

    Parameters
    ----------
    yi : array-like
        Observed values.
    yj : array-like
        Predicted values, in the same order as ``yi``.

    Returns
    -------
    float
        R squared, or NaN when it is undefined (empty input, or ``yi`` has
        zero variance).

    Raises
    ------
    ValueError
        If the two inputs do not hold the same number of values.
    """
    # Compare by position on plain arrays: subtracting two pandas Series aligns
    # them on index labels, which turns differently indexed inputs into NaN.
    observed = np.asarray(yi, dtype=float).ravel()
    predicted = np.asarray(yj, dtype=float).ravel()
    if observed.shape != predicted.shape:
        raise ValueError(
            f"yi and yj must have the same length, got {observed.size} and {predicted.size}"
        )
    if observed.size == 0:
        return float('nan')

    sum_sq_residual = np.sum((observed - predicted) ** 2)
    # Total sum of squares is taken about the mean of the observed values.
    sum_sq_total = np.sum((observed - observed.mean()) ** 2)
    if sum_sq_total == 0:
        # Constant observations: the ratio below would be a division by zero.
        return float('nan')
    return 1 - (sum_sq_residual / sum_sq_total)

def _plot_spectra(axis, frame, labels, colours, ylim, ylabel, title):
    """Draw every column of ``frame`` against its index on ``axis`` and label the panel."""
    for pos, (label, colour) in enumerate(zip(labels, colours)):
        # Select the column by position so that a repeated column label still
        # yields a single trace rather than a sub-frame.
        axis.plot(frame.index,
                  frame.iloc[:, pos],
                  label=label,
                  c=colour,
                  lw=1)
    axis.legend(loc='right')
    axis.set_ylim(*ylim)
    axis.set_xlim(200, 800)
    axis.set_xlabel('Wavelength (nm)')
    axis.set_ylabel(ylabel)
    axis.set_title(title)


def _unique_png_path(directory, stem):
    """Return ``directory/stem.png``, or ``directory/stem-N.png`` for the first unused N."""
    candidate = os.path.join(directory, f"{stem}.png")
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}-{counter}.png")
        counter += 1
    return candidate


def mkplot(df, save=False):
    """Plot one titration experiment as a 2x2 summary figure.

    Panels: smoothed absorbance traces, difference spectra relative to the
    compound-free trace, the dose-response points with the fitted ``mm``
    curve, and a text panel with km, vmax and r squared.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per trace, indexed by wavelength; the index must contain
        the labels 800, 420 and 390. Columns whose name contains 'baseline'
        or 'dmso' are ignored, exactly one column must contain 'bm3' (the
        compound-free trace), and every remaining column name must carry the
        compound name (letters/underscores) and a decimal number that is used
        as the added volume.
    save : bool
        If True, write the figure to ``img/<compound>.png`` (a ``-N`` suffix
        is appended if that file exists) and close it; otherwise show it.

    Raises
    ------
    ValueError
        If the columns do not describe one BM3 trace plus the titration of a
        single compound, or a volume cannot be parsed.
    """
    # Arguments handed to c2 to turn an added volume into a concentration.
    # NOTE(review): presumably a 10,000 uM stock added to a 1000 uL sample -- confirm.
    stock_conc = 10_000
    total_vol = 1000
    zero_wavelength = 800         # every trace is shifted to read 0 here
    response_wavelengths = (420, 390)
    control_tags = ('baseline', 'dmso', 'bm3')

    bm3 = df[[col for col in df.columns if ctns('bm3', col)]]
    cpd = df[[col for col in df.columns
              if not any(ctns(tag, col) for tag in control_tags)]]
    if bm3.shape[1] != 1:
        raise ValueError(
            f"expected exactly one 'bm3' column, found {bm3.shape[1]}"
        )

    col_names = pd.Series(cpd.columns)
    # Raw string: '\.' in a plain literal is an invalid escape sequence.
    vols_ = col_names.str.extract(r'([0-9]+\.[0-9]+)')[0].astype(float)
    if vols_.isna().any():
        raise ValueError("could not read a decimal volume from every compound column name")
    names = col_names.str.extract(
        '([A-Za-z_]+)')[0].str.replace('_', ' ', regex=False).unique()
    if len(names) != 1 or not isinstance(names[0], str):
        raise ValueError(
            f"expected columns for exactly one compound, found {list(names)}"
        )
    cpd_name = names[0].strip()

    # Smooth along the wavelength axis (rows). smooth()'s default axis=-1 would
    # filter across the sample columns instead and blend neighbouring
    # titration points into each other.
    data = smooth(pd.concat([bm3, cpd], axis=1), axis=0)

    # The BM3-only trace counts as volume 0; relabel the columns by concentration.
    concs = [c2(vol, stock_conc, total_vol) for vol in [0] + vols_.tolist()]
    data.columns = concs
    data = data.sub(data.loc[zero_wavelength, :], axis=1)

    top_conc = max(concs)
    # Guard the colour scaling against a titration in which every volume is 0.
    cm = plt.cm.inferno([conc / top_conc if top_conc else 0.0 for conc in concs])

    fig, ax = plt.subplots(2,2, figsize=(16,16))
    try:
        _plot_spectra(ax[0,0], data, concs, cm,
                      (-0.1, 1), 'Absorbance', 'Absorbance Traces')
        # ------------------------
        # Difference spectra: every trace minus the compound-free first column.
        diff = data.sub(data.iloc[:,0], axis=0)
        _plot_spectra(ax[0,1], diff, concs, cm,
                      (-0.2, 0.2), 'Change in Absorbance', 'Change in Absorbance')
        # ------------------------
        # Response = summed magnitude of the change at the two wavelengths.
        response = sum(diff.loc[wl, :].abs() for wl in response_wavelengths)
        conc_values = np.asarray(concs, dtype=float)
        response_values = response.to_numpy(dtype=float)
        km, vmax = get_mm(conc_values, response_values)
        mmx = np.linspace(0, top_conc, 64)
        mmy = mm(mmx, km, vmax)
        ypred = mm(conc_values, km, vmax)
        rsq = r_squared(response_values, ypred)
        ax[1,0].scatter(conc_values, response_values)
        ax[1,0].plot(mmx, mmy)
        ax[1,0].set_xlabel(f'{cpd_name} concentration uM')
        ax[1,0].set_ylabel('Response')
        ax[1,0].set_title(f'P450 BM3-{cpd_name} Dose-Response')
        # ------------------------
        ax[1,1].text(0.2, 0.5,
            f"km = {round(km,4)}\nvmax = {round(vmax,4)}\nr squared = {round(rsq, 3)}",
                    fontsize=12)
        ax[1,1].axis('off')

        fig.suptitle(f"Titration data for P450 BM3 and {cpd_name}")
        if save:
            os.makedirs('img', exist_ok=True)
            # Use the same hyphenated stem for the first name and for every
            # collision suffix.
            fig.savefig(_unique_png_path('img', cpd_name.replace(' ', '-')))
            plt.close(fig)
        else:
            plt.show()
    except Exception:
        # Do not leave a half-built figure open when a step fails.
        plt.close(fig)
        raise
    
# Render and save a figure for every dataset. The run is best-effort: a dataset
# that fails is recorded in `err` as (path, exception) and the loop carries on.
err = []
for dataset in tqdm(datas):
    try:
        mkplot(dataset.df, True)
    except Exception as e:
        err.append((dataset.path, e))
        # Discard the figure that was being built when the failure happened.
        plt.close()
# To inspect a single dataset interactively instead: mkplot(datas[2].df)

# Report what was skipped so failures are not silently swallowed.
if err:
    print(f"{len(err)} of {len(datas)} datasets could not be plotted:")
    for failed_path, failure in err:
        print(f"  {failed_path}: {type(failure).__name__}: {failure}")

100%|██████| 20/20 [00:05<00:00,  3.60it/s]
