In [None]:
import os, string, re
import numpy as np
import plotly.graph_objects as go
from operator import itemgetter
from itertools import product
from collections import OrderedDict
from scipy.stats import fisher_exact, chi2_contingency
from statsmodels.stats.multitest import multipletests
from plotly.subplots import make_subplots
from fitter import Fitter
from collections import defaultdict
from kagami.comm import *
from kagami.dtypes import Table
from kagami.portals import tablePortal
from kagami.wrappers import RWrapper
from acs.decomposition import RGCCA, SGCCA
from acs.plots import scatter_plot

# Functions

---

## Filtering

In [None]:
def _abandon_filter(tab):
    print(tab.shape, end = ' -> ')

    rcut = np.mean(np.sum(tab.X_,axis = 1))*0.001
    tab = tab[:,np.var(tab.X_, axis = 0) > 0]
    tab = tab[:,np.sum(~np.isclose(tab.X_,0), axis = 0) >= 3]
    tab = tab[:,np.sum(tab.X_ >= rcut, axis = 0) > 0]

    print(tab.shape)
    return tab

In [None]:
def _taxtab(otutab, grp):
    """Aggregate an OTU table into a taxon table by summing OTU columns per group.

    Parameters
    ----------
    otutab : Table whose columns are OTUs; `otutab.cidx_[grp]` gives the group (taxon)
        label of every column.
    grp : key into `otutab.cidx_` selecting the grouping level (e.g. 'D4').

    Returns
    -------
    Table with the same row names as `otutab` and one column per unique group label
    (in `np.unique` order), each the row-wise sum of that group's OTU columns.
    """
    taxa = np.array(otutab.cidx_[grp])
    utax = np.unique(taxa)

    # one summed vector per taxon; vstack(...).T gives (n_rows, n_taxa)
    sums = [np.sum(otutab[:, taxa == t].X_, axis = 1) for t in utax]

    return Table(
        np.vstack(sums).T,
        rownames = otutab.rownames,
        colnames = utax,
    )

## Identify significant correlations

In [None]:
def _corr(x,y,p,side):
    vx,vy = (x != 0), (y != 0)
    xp,yp = np.sum(vx)/len(x), np.sum(vy)/len(y)

    if (np.sum(vx) < 3 or np.sum(vy) < 3) or \
       (np.sum(np.logical_and(~vx, ~vy))/np.sum(np.logical_or(vx, vy)) > 0.5) or \
       (np.isclose(np.var(x),0) or np.isclose(np.var(y),0)): return xp,yp,np.nan,np.nan

    cc = np.corrcoef(x,y)[0,1]
    
    pv = np.nan if missing(p) else \
        (np.sum(p < cc) / p.shape[0]) if side == 'lower' else \
        (np.sum(p > cc) / p.shape[0]) if side == 'upper' else \
       ((np.sum(p < cc) if cc < 0 else np.sum(p > cc)) / p.shape[0])
    return xp,yp,cc,pv

In [None]:
def _fdr(pvs):
    if len(pvs) == 0: return []
    return multipletests(pvs, method = 'fdr_bh')[1]

In [None]:
def _corrw(x,y,wsize,pmv,side):
    fracs = smap(range(len(x)-wsize+1), lambda i: slice(i,i+wsize))
    cores = smap(enumerate(fracs), unpack(lambda i,f: (i,*_corr(x[f],y[f],pmv,side))))

    pval = np.array(lzip(*cores)[-1])    
    padj = np.ones(len(cores)) * np.nan
    padj[~np.isnan(pval)] = _fdr(pval[~np.isnan(pval)])
    
    cores = [(*c,p) for c,p in zip(cores,padj)]
    return cores

## Identify significant correlations across sliding windows of multiple sizes

In [None]:
def _re_conv_fname(fname):
    vchrs = "-_.() %s%s" % (string.ascii_letters, string.digits)
    fname = ''.join(c for c in fname if c in vchrs)
    fname = fname.replace(' ','_')
    return fname

In [None]:
def _sig_corr(chm, tax, pmxpath, minwsize, rcut, pcut, sign, side):
    assert minwsize > 0
    assert 0 <= rcut <= 1
    assert missing(pcut) or 0 <= pcut <= 1
    assert sign in ('pos', 'neg', 'both')
    assert side in ('upper', 'lower', 'both')
    
    pfile = os.path.join(
        pmxpath,
        _re_conv_fname(f'permutation_null_{chm.cols_[0]}_{tax.cols_[0]}_nperms_10000.npz')
    )
    pmx = np.load(pfile)['permatx']

    pmx = smap(pmx, lambda v: v[~np.isnan(v)],
                    lambda v: None if len(v) < pmx.shape[1]*0.5 else np.sort(v))
    pmx = pmx[minwsize-3:]
    
    if missing(pcut) or side == 'both':
        pcuts = (np.ones(len(pmx)) * np.nan)
    else:
        pcuts = smap(pmx, lambda p: np.nan if missing(p) else \
                                    np.quantile(p, pcut if side == 'lower' else 1-pcut))

    wsize = np.arange(minwsize,minwsize+len(pmx))
    cx,tx = chm.X_.ravel(), tax.X_.ravel()

    cores = smap(zip(wsize,pmx), unpack(lambda w,p: _corrw(cx,tx,w,p,side)))

    cores = smap(cores, 
        lambda r: pick(r,unpack(
            lambda i,xp,yp,cc,pv,pa: available(pa)
        ))
    )

    if available(rcut):
        cores = smap(cores, 
            lambda r: pick(r,unpack(
                lambda i,xp,yp,cc,pv,pa: xp >= rcut and yp >= rcut
            ))
        )
        
    if sign != 'both':
        cores = smap(cores, 
            lambda r: pick(r,unpack(
                lambda i,xp,yp,cc,pv,pa: (cc < 0) if sign == 'neg' else (cc > 0)
            ))
        )
        
    cores = smap(cores, lambda r: smap(r, itemgetter(0,-3,-2,-1)))
    return l(zip(wsize,pcuts,cores))

In [None]:
def _sig_corr_mp(params):
    """Worker for pmap: run _sig_corr on one parameter tuple.

    Returns (taxon name, window-size entries whose record list is non-empty).
    """
    windows = _sig_corr(*params)
    taxon = params[1].cols_[0]
    return taxon, drop(windows, lambda rec: len(rec[2]) == 0)

def _sig_corr_dct(chm, txtab, pmxpath, minwsize, rcut, pcut, sign, side):
    """Run _sig_corr for every column of `txtab` in parallel.

    Returns {column name: window-size entries} for columns with at least one
    significant window.
    """
    jobs = []
    for tax in txtab.cols_:
        jobs.append((chm, txtab[:,tax], pmxpath, minwsize, rcut, pcut, sign, side))
    results = pmap(jobs, _sig_corr_mp)
    return {name: wins for name, wins in results if len(wins) > 0}

## Select the optimal significant correlation across sliding windows

In [None]:
def _opt_corr_dct(cordct):
    dct = {k: np.array([(w,c,st,cr,pv,pa) for w,c,rs in vs for st,cr,pv,pa in rs]) for k,vs in cordct.items()}
    _pick = lambda x: x[np.lexsort((-np.abs(x[:,-3]),-x[:,0],x[:,-2],x[:,-1]))][0]
    dct = {k: _pick(vs) for k,vs in dct.items()}
    return dct

In [None]:
def _corr_dct_tab(optdct):
    """Tabulate optimal correlation records, one row per key of `optdct`.

    Values are 6-element records (window size, p-cut, start, corr, p-value, p-adj) as
    produced by _opt_corr_dct; the p-cut is not carried into the table.
    """
    header = ['start time point', 'window size', 'correlation', 'corr p-value', 'corr p-adj']
    tab = Table(
        np.zeros((len(optdct), len(header))), dtype = float,
        rownames = l(optdct.keys()),
        colnames = header,
    )
    for key, rec in optdct.items():
        wsize, _pcut, start, cor, pval, padj = rec
        tab[key] = [start, wsize, cor, pval, padj]
    return tab

## Plot correlation

In [None]:
def _plot_corr(chms, taxa, pos = None, wsize = None, std = True, title = None, legend = True):
    """Line plot of factor series (primary y-axis) and taxon series (secondary y-axis).

    Parameters
    ----------
    chms, taxa : Tables with aligned rows; the x-axis shows `chms` row names.
    pos, wsize : if both are given, only rows [pos, pos+wsize) are plotted.
    std : z-score every column before plotting.
    title : overrides the default title.
    legend : show the legend.

    The default title is the Pearson correlation for a single factor/taxon pair.
    Otherwise it is the mean of the squared pairwise correlations among all plotted
    columns (labelled 'mean corr'; NOTE(review): the value is a mean r^2).
    """
    labels = np.array(chms.rows_)
    cmat, tmat = chms.X_, taxa.X_
    if available(pos) and available(wsize):
        win = slice(pos, pos + wsize)
        labels, cmat, tmat = labels[win], cmat[win], tmat[win]
    if std:
        zscore = lambda m: (m - np.mean(m, axis=0)) / np.std(m, axis=0)
        cmat, tmat = zscore(cmat), zscore(tmat)

    fig = make_subplots(specs=[[{'secondary_y': True}]])
    # factors first (primary axis), then taxa (secondary axis)
    for tab, mat, on_secondary in ((chms, cmat, False), (taxa, tmat, True)):
        for j in range(tab.ncol):
            fig.add_trace(
                go.Scatter(x = labels, y = mat[:, j], mode = 'lines+markers', name = tab.cols_[j]),
                secondary_y = on_secondary,
            )

    single_pair = chms.ncol == 1 and taxa.ncol == 1
    if single_pair:
        r = np.corrcoef(cmat.ravel(), tmat.ravel())[0, 1]
        default_title = f'corr = {r:.2f}'
    else:
        rmat = np.corrcoef(np.hstack([cmat, tmat]).T)
        r = np.nanmean(np.power(rmat[np.triu_indices_from(rmat, 1)], 2))
        default_title = f'mean corr = {r:.2f}'

    fig.update_layout(
        title = {'text': default_title if missing(title) else title},
        showlegend = legend,
    )
    if single_pair:
        fig.update_yaxes(title_text = chms.cols_[0][:20], secondary_y = False)
        fig.update_yaxes(title_text = taxa.cols_[0][:20], secondary_y = True)

    return fig

## Identify significant OTU correlations within each taxon's optimal window

In [None]:
def _opt_corr_otu_dct(taxdct, otudct, wcut, otutab = None, grp = None):
    """For each taxon, pick the optimal window of each member OTU inside the taxon's window.

    Parameters
    ----------
    taxdct : {taxon: record}, records (window size, p-cut, start, corr, p-value, p-adj)
        as returned by _opt_corr_dct.
    otudct : {otu: [(window size, p-cut, [(start, corr, p-value, p-adj), ...]), ...]}
        as returned by _sig_corr_dct.
    wcut : an OTU window [s, s+w) is kept only if more than `wcut * w` of its positions
        fall inside the taxon's optimal window.
    otutab, grp : OTU table and the `cidx_` key mapping OTU columns to taxa. They default
        to the notebook globals `otutab` / `grpID`, which the original implementation read
        implicitly; passing them explicitly removes that hidden dependency.

    Returns
    -------
    {taxon: {otu: optimal record}} for taxa with at least one qualifying OTU.
    """
    if otutab is None: otutab = globals()['otutab']
    if grp is None: grp = globals()['grpID']
    otu_taxa = otutab.cidx_[grp]  # taxon label per OTU column; identical for every taxon

    def _otu_in_rng(tax):
        wd,_,st,_,_,_ = taxdct[tax]
        # mask of row positions covered by the taxon's optimal window
        rvec = np.zeros(otutab.nrow, dtype = bool)
        rvec[int(st):int(st+wd)] = True

        otus = np.array(otutab[:, otu_taxa == tax].cols_)
        odct = {otu: [
            (w, c, [(s, cr, pv, pa) for s, cr, pv, pa in rs if np.sum(rvec[s:s+w]) > w*wcut])
            for w, c, rs in otudct[otu]
        ] for otu in otus if otu in otudct}

        # drop windows left without records, then OTUs left without windows
        odct = {k: [(w, c, rs) for w, c, rs in v if len(rs) > 0] for k, v in odct.items()}
        odct = {k: v for k, v in odct.items() if len(v) > 0}

        return _opt_corr_dct(odct)

    outdct = {t: _otu_in_rng(t) for t in taxdct.keys()}
    return {k: v for k, v in outdct.items() if len(v) > 0}

In [None]:
_insc = lambda x,y: np.intersect1d(x, y, assume_unique = True)
_diff = lambda x,y: np.setdiff1d(x, y, assume_unique = True)

def _ora(tgids, bgids, rfids, test):
    if len(tgids) == 0: return np.nan # ignored

    cfm = [
        _insc(tgids, rfids), _diff(rfids, tgids),
        _diff(tgids, rfids), _diff(_diff(bgids, rfids), tgids),
    ]
    cfm = np.array(smap(cfm, len)).reshape((2,2))

    _test = (lambda x: fisher_exact(x, alternative = 'greater')[1]) if test == 'fisher' else \
            (lambda x: chi2_contingency(x, correction = False)[1]) if test == 'chi2' else \
            (lambda x: x[0,0])
    return _test(cfm), len(rfids), cfm[0,0], cfm[0,0]/len(rfids)

In [None]:
def _opt_corr_otu_ora(otus, bgtab, test = 'fisher', grp = None):
    """Over-representation analysis of significant OTUs within every taxon.

    Parameters
    ----------
    otus : IDs of significant OTUs (duplicates allowed).
    bgtab : background OTU Table; its columns form the universe and
        `bgtab.cidx_[grp]` assigns each column to a taxon.
    test : 'fisher', 'chi2' or other (see _ora).
    grp : `cidx_` grouping key. Defaults to the notebook global `grpID`, which the
        original implementation read implicitly.

    Returns
    -------
    Table, one row per taxon, with columns 'OTU size', 'overlap size', 'ora recall',
    'ora p-value' and 'ora p-adj' (BH over taxa with a defined p-value).
    """
    if grp is None: grp = globals()['grpID']
    tgids = np.unique(otus)
    bgids = np.asarray(bgtab.cols_)

    gtaxa = bgtab.cidx_[grp]
    utaxa = np.unique(gtaxa)
    rflst = smap(utaxa, lambda x: bgids[gtaxa == x])

    pvals,nrefs,nhits,recls = np.array(smap(rflst, lambda x: _ora(tgids, bgids, x, test))).T

    # adjust only the defined p-values; the rest stay NaN
    qvals = np.ones_like(pvals) * np.nan
    qvals[~np.isnan(pvals)] = _fdr(pvals[~np.isnan(pvals)])

    ectab = Table(
        np.vstack([nrefs, nhits, recls, pvals, qvals]).T,
        rownames = utaxa,
        colnames = ['OTU size', 'overlap size', 'ora recall', 'ora p-value', 'ora p-adj'] 
    )
    return ectab

## Plot p-values

In [None]:
def _tax_name(tax):
    """Short display name of a taxon label: its most specific non-empty rank.

    The label is split on rank prefixes matching '[kdpcofgs]__' (e.g. 'g__'). Trailing
    underscores are stripped, and empty or 'null' parts are dropped. When nothing is left
    (e.g. an unclassified 'g__'), the raw label is returned instead of raising IndexError.
    """
    ns = re.split('[kdpcofgs]__', tax)
    ns = smap(ns, lambda x: x.rstrip('_'))
    ns = drop(drop(ns, 'null'), '')
    return ns[-1] if len(ns) > 0 else tax
    
def _plot_corr_pvals(cortab, seed = None):
    """Scatter of taxa: ORA recall (x) vs. -ln(correlation p-adj) (y), one trace per taxon.

    Parameters
    ----------
    cortab : results Table with one row per taxon and (at least) the columns
        'corr p-adj', 'ora recall' and 'ora p-adj'.
    seed : optional int seeding the jitter so the figure is reproducible. None (default)
        keeps the previous behaviour of drawing from the global `np.random` state.

    Taxa with corr p-adj < 0.05 and ORA p-adj < 0.05 are shown in bold with a '*'.
    The original code assigned that label to an unused variable, so it never appeared.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    fig = go.Figure()

    for tax in cortab.rows_:
        txn = _tax_name(tax)
        cpv,rcl,opv = cortab[tax,['corr p-adj','ora recall','ora p-adj']].X_[0]

        if cpv < 0.05 and opv < 0.05: txn = f'<b>{txn} *</b>'
        fig.add_trace(go.Scatter(
            # small jitter separates overlapping points
            x = [rcl + rng.normal(0,0.03)],
            # the ~1e-5 offset inside the log avoids log(0) for p-adj == 0
            y = [-np.log(cpv+rng.normal(1e-5,5e-7)) + rng.normal(0,0.03)],
            text = [txn],
            name = txn,
            mode = 'markers',
            marker = {
                'size': 50,
                'opacity': 0.7,
            },
            textposition = 'top center',
        ))

    fig.update_xaxes(title_text = 'OTU ORA Recall')
    fig.update_yaxes(title_text = 'Taxon Correlation p-adj -ln(p-adj)')

    return fig

## CCA Factor Plot

In [None]:
def _plot_cca_factors(ccares, dss, comp, grps, axnams):
    """3-D scatter of the `comp`-th CCA component of three blocks (one axis per block).

    Parameters
    ----------
    ccares : fitted CCA result exposing `projections_`, `AVE_`, `AVE_outer_` and `AVE_inner_`.
    dss : the three block Tables (rows = shared samples), in the order of `projections_`.
    comp : 0-based component index.
    grps : optional per-sample group labels; one trace per group when given.
    axnams : three axis names; each is suffixed with that block's AVE (%) for the component.
    """
    ccaprojs = smap(
        zip(ccares.projections_, dss),
        unpack(lambda ld,ds: Table(
            ld, rownames=ds.rows_, colnames=smap(np.arange(ld.shape[1]), lambda x: f'component_{x+1}')
        ))
    )

    labs = np.array(ccaprojs[0].rownames)
    axis = np.array(smap(ccaprojs, lambda x: x[:,f'component_{comp+1}'].X_.ravel()))

    # '.2f' = two decimals; the previous ':2f' meant minimum width 2 with six decimals,
    # unlike every other title in this notebook.
    title = f'Mean Var. = {int(ccares.AVE_outer_[comp]*100)}%, Mean Corr. = {ccares.AVE_inner_[comp]:.2f}'
    axnms = tuple(f'{axnams[i]} ({int(ccares.AVE_[i][comp]*100)}%)' for i in range(3))

    if available(grps): grps = np.asarray(grps)
    def _wrap(g = None):
        # `txt` rather than `l`, which would shadow the kagami list helper
        pts, txt = (axis, labs) if missing(g) else (axis[:, grps == g], labs[grps == g])
        return go.Scatter3d(x = pts[0], y = pts[1], z = pts[2], text = txt, name = g)
    data = _wrap() if missing(grps) else smap(np.unique(grps), _wrap)
    fig = go.Figure(data)

    fig.update_traces(
        mode = 'markers+text',
        marker = {
            'size': 10,
            'colorscale': 'viridis',
        }
    )

    fig.update_layout(
        showlegend = True,
        title_text = title,
        scene = {
            'xaxis': {'title': axnms[0]}, 
            'yaxis': {'title': axnms[1]},
            'zaxis': {'title': axnms[2]}
        },
    )
    return fig

# Load Datasets

---

In [None]:
# Section configuration: `otuID` selects the OTU dataset (data/results paths) and `grpID`
# is the `cidx_` key used to group OTUs into taxa (see _taxtab).
otuID = '16sv1'
grpID = 'D4'

In [None]:
# taxtab = Table.loadhdf(f'data/dss/{otuID}_proc_taxons_{grpID}.hdf')
# Processed OTU table; the taxon table is derived from it after filtering below.
otutab = Table.loadhdf(f'data/dss/{otuID}_proc.hdf')

In [None]:
# Factor tables: environmental variables, chemical-class factors (the '_TY' table,
# presumably the 'fungicide(T/Y)'-style keys of `fctdct`), and per-compound chemistry
# (whose `cidx_['class']` is used later to select compounds of a class).
envtab = Table.loadhdf('data/dss/envs_proc.hdf')
chmtab = Table.loadhdf('data/dss/chms_proc_TY.hdf')
cpdtab = Table.loadhdf('data/dss/chms_proc.hdf')

## Align

In [None]:
# Factor table to correlate against taxa; swap in envtab for the environmental analysis.
# trgtab = envtab
trgtab = chmtab

In [None]:
# Sample IDs present in both tables, sorted numerically (IDs must be integer-like strings).
sids = fold(
    smap([trgtab, otutab], lambda x: x.rows_), 
    np.intersect1d
)
sids = np.array(sorted(sids, key = lambda x: int(x)))

In [None]:
# Subset both tables to the shared samples, in the same sorted order.
trgtab, otutab = smap([trgtab, otutab], lambda x: x[sids])

In [None]:
# Sanity check: row counts should now match.
smap([trgtab, otutab], lambda x: x.shape)

## Filter

In [None]:
# Drop uninformative OTU columns (prints shape before -> after).
otutab = _abandon_filter(otutab)

In [None]:
# Aggregate OTUs into taxa by summing columns within each `grpID` group.
taxtab = _taxtab(otutab, grpID)
taxtab.shape

# Overall Correlations

---

In [None]:
# Blocks for the two-block CCA: factors vs. taxa (rows aligned by sample).
dss = [trgtab, taxtab]

In [None]:
# Number of components: at most 5 and at most the narrowest block's column count.
ncomps = np.min([5, np.min(smap(dss, lambda x: x.ncol))])

# cca = RGCCA(n_components = ncomps).fit(smap(dss, lambda x: x.X_), tau = 1)
# Sparse GCCA with per-block sparsity c1; presumably c1 = 1 means no sparsity for the
# taxon block — confirm against acs.decomposition.SGCCA.
# NOTE(review): 1.5/sqrt(ncol) exceeds 1 when the factor block has <= 2 columns.
cca = SGCCA(n_components = ncomps).fit(
    smap(dss, lambda x: x.X_), 
    # c1 = [1/np.sqrt(dss[0].ncol)+1e-5, 1/np.sqrt(dss[1].ncol)+1e-5, 1]
    c1 = [1.5/np.sqrt(dss[0].ncol)+1e-5, 1]
)

# Loadings per block as Tables: rows = the block's features, columns = components.
cols = smap(range(ncomps), lambda x: f'component_{x+1}')
ccaloads = smap(
    zip(cca.loadings_, dss),
    unpack(lambda ld,ds: Table(ld, rownames = ds.cols_, colnames = cols)),
)

In [None]:
# Factor-block loadings per component.
ccaloads[0].todataframe()

In [None]:
# For each component, the factor with the largest absolute loading.
factors = np.array(ccaloads[0].rows_)[np.argmax(np.abs(ccaloads[0].X_),axis=0)]

# Features Contribute to Components

---

In [None]:
# Index of the CCA component analysed in this section.
fid = 0

In [None]:
# Taxon loadings on component `fid`, renamed so they can be appended to the results table.
cca_taxtab = ccaloads[1][:,fid]
cca_taxtab.cols_ = ['cca weight']

In [None]:
# Top factor of the selected component.
fctID = factors[fid]

# Factor name -> file-name stem for results files. Later sections also use this mapping,
# so this cell must run before them.
fctdct = {
    'annual mean temp':      'Annual_mean_temp',
    'jun-aug mean temp':     'Jun-aug_mean_temp',
    'yearly daily min temp': 'Yearly_daily_min_temp',
    'yearly daily max temp': 'Yearly_daily_max_temp',
    'highest temp':          'Highest_temp',
    'lowest temp':           'Lowest_temp',
    'Mean atm hPa':          'Mean_atm_hPa',
    'jun-aug mean atm hPa':  'Jun-aug_mean_atm_hPa',
    'annual total precip':   'Annual_total_precip',
    'jun-aug total precip':  'Jun-aug_total_precip',
    'max24h precipitation':  'Max_24h_precipitation',
    'fungicide(T/Y)':        'Fungicide', 
    'herbicide(T/Y)':        'Herbicide', 
    'insecticide(T/Y)':      'Insecticide', 
    'pesticide(T/Y)':        'Pesticide',
    'DDT':                   'DDT',
}
fctnam = fctdct[fctID]

# Single-column table of the selected factor.
fct = trgtab[:,fctID]

In [None]:
# Sliding-window correlations of every taxon with the factor, using the precomputed
# taxon-level permutation nulls. Only negative correlations are kept, tested on the lower tail.
all_taxcor = _sig_corr_dct(
    fct, taxtab, 
    pmxpath = f'results/{otuID}/perm_null/{grpID}/', 
    minwsize = 5, 
    rcut = 0.5, 
    pcut = 0.05,
    sign = 'neg',
    side = 'lower',
    # sign = 'both',
    # side = 'both',
)

In [None]:
# Same analysis at OTU level, with the OTU-level permutation nulls.
all_otucor = _sig_corr_dct(
    fct, otutab, 
    pmxpath = f'results/{otuID}/perm_null/OTU/', 
    minwsize = 5,
    rcut = 0.5, 
    pcut = 0.05,
    sign = 'neg',
    side = 'lower',
    # sign = 'both',
    # side = 'both',
)

In [None]:
# Best window per taxon (lowest p-adj, then p-value, then larger window, then stronger |corr|)
# and its tabular form.
opt_taxcor = _opt_corr_dct(all_taxcor)
opt_taxtab = _corr_dct_tab(opt_taxcor)

In [None]:
# Significant taxa: |corr| (record index 3) >= 0.5 and corr p-adj (last field) < 0.05.
sig_taxcor = {k:v for k,v in opt_taxcor.items() if np.abs(v[3]) >= 0.5 and v[-1] < 0.05}

# OTUs Contribute to Correlation

---

In [None]:
# For each significant taxon, the best window of each member OTU. An OTU window must lie
# more than 80% inside the taxon's window.
opt_otucor = _opt_corr_otu_dct(sig_taxcor, all_otucor, 0.8)

In [None]:
# Keep significant OTUs (|corr| >= 0.5, p-adj < 0.05) and drop taxa left with none.
sig_otucor = {k: {
    o: rs for o,rs in v.items() if np.abs(rs[3]) >= 0.5 and rs[-1] < 0.05
} for k,v in opt_otucor.items()}
sig_otucor = {k:v for k,v in sig_otucor.items() if len(v) > 0}

In [None]:
# Background OTU universe for the over-representation analysis (the filtered OTU table).
# bgtab = Table.loadhdf(f'data/dss/{otuID}_clean.hdf')
bgtab = otutab

In [None]:
# All significant OTU IDs across taxa, tested for over-representation within each taxon.
# np.hstack([]) raises ValueError, so fall back to an empty array when no taxon has
# significant OTUs.
sig_otus = np.hstack(smap(sig_otucor.values(), lambda x: l(x.keys()))) if len(sig_otucor) > 0 else np.array([])
sig_otuenc = _opt_corr_otu_ora(sig_otus, bgtab, test = 'fisher')

In [None]:
# Combine taxon correlations, ORA statistics and CCA weights per taxon and save as CSV.
# NOTE(review): assumes results/{otuID}/tables/ already exists (no checkOutputDir here).
res_tab = opt_taxtab.append(sig_otuenc[opt_taxtab.rows_], axis = 1).append(cca_taxtab[opt_taxtab.rows_], axis = 1)
res_tab.savecsv(
    f'results/{otuID}/tables/{fctnam}_{grpID}_corrs.csv',
    transpose = False
)

In [None]:
# Recall vs. correlation p-adj figure, saved as interactive HTML.
res_fig = _plot_corr_pvals(res_tab)
checkOutputDir(f'results/{otuID}/figs/')
res_fig.write_html(f'results/{otuID}/figs/{fctnam}_{grpID}_corrs.html')

# Significant Taxa for Interpretation

---

## Load top results

In [None]:
# This section summarises saved results of the selected dataset (grpID is still from above).
otuID = 'coi'

In [None]:
# Overall sCCA weights (rows = factors, columns = components). For each of the first
# five components, take the factor with the largest absolute weight.
envloads = Table.loadcsv(
    f'results/{otuID}/tables/Overall_sCCA_env_D4_weights.csv', 
    transposed = False,
)
topenvs = np.array(envloads.rows_)[np.argmax(np.abs(envloads.X_),axis=0)][:5]

chmloads = Table.loadcsv(
    f'results/{otuID}/tables/Overall_sCCA_chm_D4_weights.csv', 
    transposed = False,
)
topchms = np.array(chmloads.rows_)[np.argmax(np.abs(chmloads.X_),axis=0)][:5]

In [None]:
# Per-factor correlation result files of the top env and chemical factors; all must exist.
flst = [
    f'results/{otuID}/tables/{fctdct[e]}_D4_corrs.csv' for e in topenvs
] + [
    f'results/{otuID}/tables/{fctdct[c]}_D4_corrs.csv' for c in topchms
]
assert checkall(flst, os.path.isfile)
len(flst)

In [None]:
# Load every factor's correlation table.
rtabs = smap(flst, lambda x: Table.loadcsv(x, transposed=False))

## Identify cutoffs and filter

In [None]:
# Pool ORA recall across all factors and fit candidate distributions to pick a cutoff.
recl = np.hstack(smap(rtabs, lambda x: x[:,'ora recall'].X_[:,0]))

f = Fitter(recl, distributions=['halfgennorm','gamma','chi'])
f.fit()
f.summary()

In [None]:
from scipy.stats import halfgennorm  # missing from the top import cell (NameError before)

# Upper 95% / 99% quantiles of the half-generalized-normal fit of ORA recall.
# `f.get_best()` returns only the best-scoring distribution, so indexing it by
# 'halfgennorm' raised KeyError whenever gamma or chi fitted better. `fitted_param`
# holds the fitted (shape, loc, scale) tuple of every distribution passed to Fitter.
dst = halfgennorm(*f.fitted_param['halfgennorm'])
p95, p99 = dst.ppf(0.95), dst.ppf(0.99)
print(p95, p99)

In [None]:
# Empirical 90th percentile of ORA recall, for comparison with the fitted cutoffs.
q90 = np.percentile(recl, 90)
q90

In [None]:
# Merge per-factor tables into one wide table. Columns are prefixed by the factor stem
# (file title minus its last two '_'-separated parts, e.g. '_D4_corrs'). Rows are the
# union of all taxa; taxa absent from a factor's table stay NaN for that factor.
mcols = rtabs[0].cols_

mtabs = []
for tb,fn in zip(rtabs, flst):
    tb = tb[:,mcols]
    px = fileTitle(fn).rsplit('_',2)[0]
    tb.cols_ = smap(tb.cols_, lambda x: f'{px}_{x}')
    mtabs += [tb]
mids = fold(smap(mtabs, lambda x: np.array(x.rows_)), np.union1d)

rmtab = Table(
    np.zeros((len(mids), len(mcols)*len(mtabs))) * np.nan, 
    rownames = mids,
    colnames = np.hstack(smap(mtabs, lambda x: x.cols_))
)
for tb in mtabs: rmtab[tb.rows_, tb.cols_] = tb.X_

In [None]:
# For each taxon, count the factors it is significantly associated with
# (|corr| >= 0.5, corr p-adj < 0.05, ORA recall >= 0.5).
sigs = smap(
    rmtab.rows_, 
    lambda x: np.sum(smap(rtabs, lambda tb: 1 if x in tb.rows_ and
                                                 np.abs(tb[x,'correlation'].X_[0][0]) >= 0.5 and
                                                 tb[x,'corr p-adj'].X_[0][0] < 0.05 and
                                                 tb[x,'ora recall'].X_[0][0] >= 0.5 else 0))
)

# Prepend the count as a 'Significant' column and sort taxa by it, descending.
rmtab = Table(
    np.array(sigs).reshape((-1,1)), dtype = float, 
    rownames = rmtab.rows_, 
    colnames = ['Significant'],
).append(rmtab, axis = 1)
rmtab = rmtab[np.argsort(sigs)[::-1]]

In [None]:
# Save the full merged summary.
rmtab.savecsv(f'results/{otuID}/tables/Summary_ext_{grpID}_corrs.csv', transpose = False)

In [None]:
# Reduced summary: the count plus correlation, corr p-adj and ORA-recall columns.
rmtab = rmtab[:,pick(rmtab.cols_, 
    lambda x: x == 'Significant' or 'correlation' in x or 'corr p-adj' in x or 'ora recall' in x
)]
rmtab.savecsv(f'results/{otuID}/tables/Summary_{grpID}_corrs.csv', transpose = False)

# Corr Plots

## Load data and align

In [None]:
# Configuration for the correlation plots; `ctype` selects 'env' or 'chm' factors.
otuID = 'rbcl'
grpID = 'D4'
ctype = 'chm'

In [None]:
# Processed OTU table for this dataset.
otutab = Table.loadhdf(f'data/dss/{otuID}_proc.hdf')

In [None]:
# Factor tables, reloaded so this section does not depend on earlier alignment.
envtab = Table.loadhdf('data/dss/envs_proc.hdf')
chmtab = Table.loadhdf('data/dss/chms_proc_TY.hdf')
cpdtab = Table.loadhdf('data/dss/chms_proc.hdf')

In [None]:
# Factor table selected by `ctype`.
trgtab = envtab if ctype == 'env' else chmtab

In [None]:
# Sample IDs present in both tables, sorted numerically.
sids = fold(
    smap([trgtab, otutab], lambda x: x.rows_), 
    np.intersect1d
)
sids = np.array(sorted(sids, key = lambda x: int(x)))

In [None]:
# Align tables to the shared samples. cpdtab is aligned only for chemical factors; it is
# later sliced by window *position*, so the rows must match taxtab.
trgtab, otutab = smap([trgtab, otutab], lambda x: x[sids])
if ctype == 'chm': cpdtab = cpdtab[sids]

In [None]:
# Sanity check: row counts should match.
smap([trgtab, otutab], lambda x: x.shape)

In [None]:
# Same OTU filtering as in the analysis section.
otutab = _abandon_filter(otutab)

In [None]:
# Aggregate OTUs into taxa.
taxtab = _taxtab(otutab, grpID)
taxtab.shape

## Load results

In [None]:
# Top factor (largest |weight|) of each of the first five sCCA components.
loadings = Table.loadcsv(
    f'results/{otuID}/tables/Overall_sCCA_{ctype}_D4_weights.csv', 
    transposed = False,
)
topfcts = np.array(loadings.rows_)[np.argmax(np.abs(loadings.X_),axis=0)][:5]

In [None]:
# Correlation result files of these factors; all must exist.
flst = [
    f'results/{otuID}/tables/{fctdct[e]}_D4_corrs.csv' for e in topfcts
]
assert checkall(flst, os.path.isfile)
len(flst)

In [None]:
# Load each factor's correlation table (aligned with `topfcts`).
rtabs = smap(flst, lambda x: Table.loadcsv(x, transposed=False))

## Plot

In [None]:
def _corr_fig(fct,tax,pos,wsize,title):
    """Plot factor `fct` (from trgtab) against taxon `tax` over rows [pos, pos+wsize)
    and save the figure as HTML under results/{otuID}/figs/corr/."""
    fig = _plot_corr(trgtab[:,fct], taxtab[:,tax], pos = pos, wsize = wsize, std = True, title = title)
    fig.write_html(f'results/{otuID}/figs/corr/Corr_{grpID}_{fctdct[fct]}-{tax}.html')

def _corr_fig_cpds(fct,tax,cpd,pos,wsize,title):
    """Like _corr_fig, but plots the compound columns `cpd` of cpdtab instead of the
    factor; `fct` is used only for the output file name."""
    fig = _plot_corr(cpdtab[:,cpd], taxtab[:,tax], pos = pos, wsize = wsize, std = True, title = title)
    fig.write_html(f'results/{otuID}/figs/corr/Corr_{grpID}_{fctdct[fct]}-{tax}_cpds.html')

In [None]:
def _corr_fig_tab(fct,tax,pos,wsize):
    """Save the taxon `tax` and factor `fct` columns over rows [pos, pos+wsize) as CSV."""
    window = taxtab[:,tax].append(trgtab[:,fct], axis=1)[pos:pos+wsize]
    window.savecsv(f'results/{otuID}/tables/Corr_{grpID}_{fctdct[fct]}-{tax}.csv', transpose=False)

def _corr_fig_cpds_tab(fct,tax,cpd,pos,wsize):
    """Save the taxon `tax` and compound columns `cpd` over rows [pos, pos+wsize) as CSV;
    `fct` is used only for the output file name."""
    window = taxtab[:,tax].append(cpdtab[:,cpd], axis=1)[pos:pos+wsize]
    window.savecsv(f'results/{otuID}/tables/Corr_{grpID}_{fctdct[fct]}-{tax}_cpds.csv', transpose=False)

In [None]:
# checkOutputDir(f'results/{otuID}/figs/corr/')
# For each top factor and each taxon in its results:
#   - skip taxa failing |corr| >= 0.5, p-adj < 0.05, ORA recall >= 0.5;
#   - save the factor/taxon window;
#   - save up to 10 compounds of the factor's class with corr < -0.5 in that window.
for fct,rtab in zip(topfcts,rtabs):
    for tax in np.array(rtab.rows_):
        st,wd,cr,qv,rc = rtab[tax,[
            'start time point','window size','correlation','corr p-adj','ora recall'
        ]].X_[0]
        if np.abs(cr) < 0.5 or qv >= 0.05 or rc < 0.5: continue

        st, wd = int(st), int(wd)
        # _corr_fig(fct, tax, st, wd, title=f'Corr = {cr:.2f}, p-adj = {qv:.2E}')
        _corr_fig_tab(fct, tax, st, wd)
        
        # fct[:-5] strips a 5-character suffix such as '(T/Y)' to get the compound class.
        # NOTE(review): for env factors this yields an unrelated string (presumably no match).
        ctab = cpdtab[st:st+wd,cpdtab.cidx_['class'] == fct[:-5]]
        ttab = taxtab[st:st+wd,tax]
        # per-compound correlation with the taxon (NaN where _corr's guards reject the pair)
        ccrs = np.array(smap(ctab.X_.T, lambda x: _corr(x,ttab.X_.ravel(),None,None)[2]))
        # most negative first; keep at most 10 with corr < -0.5
        cids = np.argsort(ccrs)
        cids = cids[np.logical_and(ccrs[cids] < -0.5, ~np.isnan(ccrs[cids]))][:10]
        cpds = np.array(ctab.cols_[cids])
        if len(cpds) == 0: continue
        
        # _corr_fig_cpds(fct, tax, cpds, st, wd, title=f'Mean Corr = {np.mean(ccrs[cids]):.2f}')
        _corr_fig_cpds_tab(fct, tax, cpds, st, wd)

# Joint Effects

---

## Load dss

In [None]:
# Configuration for the joint-effects section.
otuID = '18s'
grpID = 'D4'

In [None]:
# taxtab = Table.loadhdf(f'data/dss/{otuID}_proc_taxons_{grpID}.hdf')
# Processed OTU table for this dataset.
otutab = Table.loadhdf(f'data/dss/{otuID}_proc.hdf')

In [None]:
# Factor tables (environment, chemical classes, compounds).
envtab = Table.loadhdf('data/dss/envs_proc.hdf')
chmtab = Table.loadhdf('data/dss/chms_proc_TY.hdf')
cpdtab = Table.loadhdf('data/dss/chms_proc.hdf')

## Align

In [None]:
# Align all four tables to the samples they share, sorted numerically.
dss = [otutab, envtab, chmtab, cpdtab]

sids = fold(
    smap(dss, lambda x: x.rows_), 
    np.intersect1d
)
sids = np.array(sorted(sids, key = lambda x: int(x)))
dss = smap(dss, lambda x: x[sids])

otutab, envtab, chmtab, cpdtab = dss

In [None]:
# Sanity check: all row counts should match.
smap(dss, lambda x: x.shape)

## Filter OTU -> Taxa

In [None]:
# Same OTU filtering as in the analysis section.
otutab = _abandon_filter(otutab)

In [None]:
# Aggregate OTUs into taxa.
taxtab = _taxtab(otutab, grpID)
taxtab.shape

## Filter Top Envs and Chms (T/Y)

In [None]:
# Top env and chemical factor (largest |weight|) of each of the first five sCCA components.
envloads = Table.loadcsv(
    f'results/{otuID}/tables/Overall_sCCA_env_D4_weights.csv', 
    transposed = False,
)
topenvs = np.array(envloads.rows_)[np.argmax(np.abs(envloads.X_),axis=0)][:5]

chmloads = Table.loadcsv(
    f'results/{otuID}/tables/Overall_sCCA_chm_D4_weights.csv', 
    transposed = False,
)
topchms = np.array(chmloads.rows_)[np.argmax(np.abs(chmloads.X_),axis=0)][:5]

In [None]:
# Restrict the factor tables to the top factors. Not idempotent with different top
# factors: re-run the load cells first.
envtab = envtab[:,topenvs]
chmtab = chmtab[:,topchms]

## Joint CCA

In [None]:
# Three-block sparse GCCA: environment, chemicals, taxa.
dss = [envtab, chmtab, taxtab]
ncomps = np.min([5, np.min(smap(dss, lambda x: x.ncol))])

# c1: per-block sparsity, slightly above 1/sqrt(ncol) for the factor blocks and 1 for taxa.
cca = SGCCA(n_components = ncomps).fit(
    smap(dss, lambda x: x.X_), 
    c1 = [1/np.sqrt(dss[0].ncol)+1e-5, 1/np.sqrt(dss[1].ncol)+1e-5, 1]
    # c1 = [0.5, 0.5, 1]
)

# Loadings per block as Tables: rows = the block's features, columns = components.
cols = smap(range(ncomps), lambda x: f'component_{x+1}')
ccaloads = smap(
    zip(cca.loadings_, dss),
    unpack(lambda ld,ds: Table(ld, rownames = ds.cols_, colnames = cols)),
)

In [None]:
# Environment-block loadings.
ccaloads[0].todataframe()

In [None]:
# Chemical-block loadings.
ccaloads[1].todataframe()

In [None]:
# Metadata lookup built from 3-column rows (header skipped, only rows whose first column
# is numeric): second column -> third column.
# NOTE(review): `phasemap` is not used later in this notebook.
phasemap = np.array(tablePortal.loadtsv('data/dss/LakeRing_eDNA_metadata.txt'))
phasemap = {y:p for s,y,p in phasemap[1:] if s.isdigit()}

In [None]:
# Per component: (top env factor, top chemical factor) by largest |loading| -> (ncomps, 2).
topfcts = np.array([
    ccaloads[0].rows_[np.argmax(np.abs(ccaloads[0].X_),axis=0)],
    ccaloads[1].rows_[np.argmax(np.abs(ccaloads[1].X_),axis=0)],
]).T
topfcts

## Joint Corr

In [None]:
def _load_sig_corr(efct, cfct, corr_cutoff = 0.5):
    """Taxa significantly correlated with both an env factor and a chemical factor.

    Reads `results/{otuID}/tables/{fctdct[fct]}_{grpID}_corrs.csv` for both factors
    (notebook globals `otuID`, `grpID`, `fctdct`). In each table it keeps taxa with
    |correlation| >= `corr_cutoff`, corr p-adj < 0.05 and ORA recall >= 0.5, then
    intersects each shared taxon's two optimal windows. Window positions are mapped to row
    IDs via the `Overall_sCCA_{env,chm}_D4_IDs.csv` files (presumably sample IDs).

    Returns
    -------
    list of (env table, chm table, taxon table, Jaccard overlap of the two windows),
    restricted to the shared window rows, for taxa whose shared window has more than 3
    rows; [] when no taxon qualifies. Previously the empty case raised ValueError.
    """
    rtabs = smap(
        (efct,cfct), 
        lambda fct: Table.loadcsv(
            f'results/{otuID}/tables/{fctdct[fct]}_{grpID}_corrs.csv',
            transposed = False,
        )
    )

    # keep taxa passing the correlation, p-adj and ORA-recall cutoffs
    rtabs = smap(rtabs, lambda tab: 
                 tab[np.logical_and.reduce([
                     np.abs(tab[:,'correlation'].X_[:,0]) >= corr_cutoff,
                     tab[:,'corr p-adj'].X_[:,0] < 0.05, 
                     tab[:,'ora recall'].X_[:,0] >= 0.5,
                 ])])

    # taxa significant for both factors
    otaxa = np.intersect1d(np.array(rtabs[0].rows_), np.array(rtabs[1].rows_))
    if len(otaxa) == 0: return []

    # row IDs behind the window positions of the env / chm analyses
    times = smap(
        ('env','chm'), 
        lambda x: np.array(tablePortal.loadcsv(
            f'results/{otuID}/tables/Overall_sCCA_{x}_D4_IDs.csv'
        )).flatten()
    )

    # per taxon: the row IDs of its optimal window for each factor
    tys = []
    for tax in otaxa:
        ys = []
        for tab,tps in zip(rtabs,times):
            st,wd = tab[tax,['start time point','window size']].X_[0].astype(int)
            ys += [tps[st:st+wd]]
        tys += [ys]

    # shared rows of the two windows and their Jaccard index
    oys = smap(tys, lambda x: np.intersect1d(*x))
    jcs = smap(tys, lambda x: len(np.intersect1d(*x))/len(np.union1d(*x)))

    # keep taxa whose shared window has more than 3 rows; zip(*[]) cannot be unpacked
    # into three names, so return early when none are left
    kept = l(drop(zip(otaxa,oys,jcs), unpack(lambda t,y,j: len(y) <= 3)))
    if len(kept) == 0: return []
    otaxa,oys,jcs = zip(*kept)

    return [(envtab[ys,efct],chmtab[ys,cfct],taxtab[ys,tax],jc) for tax,ys,jc in zip(otaxa,oys,jcs)]

In [None]:
# For each (env, chem) factor pair, plot taxa whose mean |corr| with the two factors over
# the shared window is >= 0.5.
checkOutputDir(f'results/{otuID}/figs/corr_multi/')
for ef,cf in topfcts:
    corrtabs = _load_sig_corr(ef,cf,corr_cutoff=0)
    for etab,ctab,ttab,_ in corrtabs:
        # last row of the 3x3 correlation matrix = taxon vs. (env, chem)
        cc = np.corrcoef(np.hstack([etab.X_, ctab.X_, ttab.X_]).T)
        cc = np.nanmean(np.abs(cc[-1,:2]))
        if cc < 0.5: continue
        
        fig = _plot_corr(
            etab.append(ctab, axis = 1), 
            ttab,
            std = True,
            title = f'mean corr = {cc:.2f}',
            legend = True,
        )
        
        tf = str(ttab.cols_[0])
        fig.write_html(
            f'results/{otuID}/figs/corr_multi/Corr_{grpID}_{fctdct[ef]}-{fctdct[cf]}-{tf}.html'
        )

## Additive CCA

In [None]:
# All (env, chem) combinations of the per-component top factors.
# NOTE(review): not idempotent — this turns the (n, 2) array into a list of tuples, so
# re-running the cell fails on `[:,0]`; re-run the cell that builds `topfcts` first.
topfcts = l(product(topfcts[:,0], topfcts[:,1]))

In [None]:
# Pick one (env, chem) pair and load the taxa significant for both.
fid = 4
tef,tcf = topfcts[fid]
ttabs = _load_sig_corr(tef,tcf,corr_cutoff=0)
print(f'{tef} - {tcf} => {len(ttabs)}')

In [None]:
def _cc(x,y):
    """Pearson correlation of x and y, or NaN under the same sparsity/variance guards as _corr."""
    vx, vy = (x != 0), (y != 0)
    if (np.sum(vx) < 3 or np.sum(vy) < 3) or \
       (np.sum(np.logical_and(~vx, ~vy))/np.sum(np.logical_or(vx, vy)) > 0.5) or \
       (np.isclose(np.var(x),0) or np.isclose(np.var(y),0)): return np.nan
    return np.corrcoef(x,y)[0,1]

def _sig_joint_corr(ttab, seed = None):
    """Test whether the joint (CCA) effect of an env + chemical factor correlates with a taxon.

    Parameters
    ----------
    ttab : one tuple from _load_sig_corr, i.e. (env table, chem table, taxon table, jaccard).
    seed : optional int seed for the permutation null. None keeps using the global
        `np.random` state, as before.

    Returns
    -------
    None when no window passes. Otherwise a tuple of:
      - the taxon name;
      - a pair (factor + joint-effect table, taxon table), both restricted to the best window;
      - |corr| of the taxon vs. the joint effect;
      - |corr| of the taxon vs. each single factor in that window;
      - whether the joint |corr| exceeds both single-factor |corr|s.
    """
    import logging  # used below; not imported explicitly in the notebook's import cell

    dss = [ttab[0].append(ttab[1],axis=1), ttab[2]]
    wds = [dss[0].nrow] 
    # wds = np.arange(5,dss[0].nrow+1)
    ttf = str(dss[1].cols_[0])

    # joint effect = first RGCCA projection of the two-factor block
    cca = RGCCA(n_components = 1).fit(smap(dss, lambda x: x.X_))
    projs = cca.projections_

    y = projs[0][:,0]
    x = dss[1].X_[:,0]
    res = smap(wds, lambda w: _corrw(x,y,w,None,None))

    rng = np.random if seed is None else np.random.default_rng(seed)

    def _pm(w,p):
        # null: 2000 correlations between independent random draws (without replacement)
        ccm = np.array([_cc(
            rng.choice(x, size = w, replace = False),
            rng.choice(y, size = w, replace = False)
        ) for _ in range(2000)])
        ccm = ccm[~np.isnan(ccm)]
        if len(ccm) < 1000: logging.warning('less than 1000 permutations available for [%s]', ttf)
        if len(ccm) == 0: return np.nan  # no usable permutation -> nothing can pass
        return np.quantile(ccm, 1-p)
    cvs = smap(wds, lambda w: _pm(w,0.025))

    # The sign of a CCA projection is arbitrary, so compare |corr| with the upper 97.5%
    # null quantile (two-sided 5% for a symmetric null). `cc > c` dropped every window
    # whose projection happened to come out anti-correlated.
    res = np.array([(st,w,cc) for w,r,c in zip(wds,res,cvs)
                              for st,xp,yp,cc,_,_ in r 
                              if xp >= 0.5 and yp >= 0.5 and np.abs(cc) > 0.5 and np.abs(cc) > c])
    if len(res) == 0: return None

    res[:,-1] = np.abs(res[:,-1])
    # best = largest window, then largest |corr| (lexsort: last key is primary)
    res = res[np.lexsort([res[:,-1],res[:,-2]])][::-1]

    st,wd,cr = res[0]
    st,wd = int(st),int(wd)

    # |corr| of the taxon (last column) with each single factor in the chosen window
    crs = np.abs(np.corrcoef(np.hstack([dss[0].X_, dss[1].X_])[st:st+wd].T)[-1,:2])

    pltab = dss[0].append(Table(y.reshape((-1,1)),colnames=['joint effect']), axis=1)

    return ttf,(pltab[st:st+wd],dss[1][st:st+wd]),cr,crs,np.all(cr>crs)

In [None]:
# Keep taxa where the joint effect is significant and beats both single factors.
res = drop(smap(ttabs, lambda x: _sig_joint_corr(x[:])), lambda x: x is None or x[-1] == False)
print(f'{paste(smap(res,lambda x:x[0]),sep="; ")}')

In [None]:
# Plot factors + joint effect against each taxon over its best window.
for ttf,(ds0,ds1),cr,_,_ in res:
    fig = _plot_corr(
        ds0, ds1,
        std = True,
        title = f'taxon - joint effect corr = {cr:.2f}',
        legend = True,
    )
    fig.write_html(
        f'results/{otuID}/figs/corr_multi/Additive_Corr_{grpID}_{fctdct[tef]}-{fctdct[tcf]}-{ttf}.html'
    )    

In [None]:
# Save the same windows as CSV (taxon first, then factors and joint effect).
for ttf,(ds0,ds1),cr,_,_ in res:
    ods = ds1.append(ds0,axis=1)
    ods.savecsv(
        f'results/{otuID}/tables/Additive_Corr_{grpID}_{fctdct[tef]}-{fctdct[tcf]}-{ttf}.csv',
        transpose = False,
    )

In [None]:
# For each result, plot the env factor and up to 10 compounds of the chemical factor's
# class with corr < -0.5 against the taxon, over the same window.
for ttf,(ds0,ds1),cr,_,_ in res:
    ef,cf = ds0.cols_[:2]

    etab = envtab[ds0.rows_,ef]
    # cf[:-5] strips the 5-character '(T/Y)'-style suffix to get the compound class
    ctab = cpdtab[ds0.rows_,cpdtab.cidx_['class'] == cf[:-5]]
    ttab = taxtab[ds1.rows_,ttf]
    
    ccrs = np.array(smap(ctab.X_.T, lambda x: _corr(x,ttab.X_.ravel(),None,None)[2]))
    cids = np.argsort(ccrs)
    cids = cids[np.logical_and(ccrs[cids] < -0.5, ~np.isnan(ccrs[cids]))][:10]
    cpds = np.array(ctab.cols_[cids])
    if len(cpds) == 0: continue
        
    fig = _plot_corr(
        etab.append(ctab[:,cpds], axis = 1), ttab,
        std = True,
        title = f'taxon - joint effect corr = {cr:.2f}',
        legend = True,
    )
    fig.write_html(
        f'results/{otuID}/figs/corr_multi/Additive_Corr_{grpID}_{fctdct[tef]}-{fctdct[tcf]}-{ttf}_cpds.html'
    )

In [None]:
# Same selection as the plots above, but keeping all compounds with corr < -0.5 (no top-10
# cut). Saves taxon, env factor and compounds with a per-column 'corr' index (NaN for the
# taxon and env columns).
for ttf,(ds0,ds1),cr,_,_ in res:
    ef,cf = ds0.cols_[:2]

    etab = envtab[ds0.rows_,ef]
    ctab = cpdtab[ds0.rows_,cpdtab.cidx_['class'] == cf[:-5]]
    ttab = taxtab[ds1.rows_,ttf]
    
    ccrs = np.array(smap(ctab.X_.T, lambda x: _corr(x,ttab.X_.ravel(),None,None)[2]))
    cids = np.argsort(ccrs)
    
    cids = cids[np.logical_and(ccrs[cids] < -0.5, ~np.isnan(ccrs[cids]))]
    # cids = cids[:10]
    cpds = np.array(ctab.cols_[cids])
    ccrs = ccrs[cids]
    if len(cpds) == 0: continue
    
    otab = ttab.append(etab.append(ctab[:,cpds], axis = 1), axis=1)
    otab.cidx_ = {'corr': np.hstack([[np.nan,np.nan],ccrs])}
        
    otab.savecsv(
        # f'results/{otuID}/tables/Additive_Corr_{grpID}_{fctdct[tef]}-{fctdct[tcf]}-{ttf}_cpds.csv',
        f'results/{otuID}/tables/Additive_Corr_{grpID}_{fctdct[tef]}-{fctdct[tcf]}-{ttf}_cpds_all.csv',
        transpose = False,
    )

# Summary

In [None]:
# Reduced summary table of every dataset, keyed by dataset ID. Loaded with the default
# orientation, so rows are '{factor}_{metric}' names and columns are taxa (see _find_cor).
summ_tabs = dict(smap(
    ('16sv1', '16sv4', '18s', 'coi', 'rbcl'),
    lambda x: (x, Table.loadcsv(f'results/{x}/tables/Summary_D4_corrs.csv')),
))

In [None]:
# Curated list of significant (dataset, factor, taxon) rows; header row skipped.
sig_tab = np.array(tablePortal.loadcsv('results/sig_corr_summary.csv'))[1:]
sig_tab.shape

In [None]:
# Factor names used in sig_corr_summary.csv -> file stems used in the summary tables.
fctmap = {
    'Fungicide':                'Fungicide', 
    'Herbicide':                'Herbicide', 
    'Insecticide':              'Insecticide',
    'Pesticide':                'Pesticide',
    'Highest_temp':             'Highest_temp',
    'Lowest_temp':              'Lowest_temp', 
    'Mean_minimum_temperature': 'Yearly_daily_min_temp', 
    'Summer_mean_temp':         'Jun-aug_mean_temp',
    'Max_24h_precipitation':    'Max_24h_precipitation',
    'Annual_total_precip':      'Annual_total_precip', 
    'Summer_total_precip':      'Jun-aug_total_precip',
    'Summer_mean_atm_hPa':      'Jun-aug_mean_atm_hPa',
}

In [None]:
def _find_cor(otab, fct, tax):
    """Correlation of taxon `tax` with factor `fct` in a loaded summary table.

    Parameters
    ----------
    otab : summary Table with rows '{factor stem}_correlation' and taxon columns.
    fct : factor name as used in sig_corr_summary.csv (mapped through `fctmap`).
    tax : taxon name. First tried as the column 'f__{tax}'; otherwise the unique column
        containing `tax` as a substring.

    Raises AssertionError when the factor row is absent or the taxon does not match
    exactly one column. The message now reports the actual match count; it used to say
    "multiple" even when nothing matched.
    """
    rid = f'{fctmap[fct]}_correlation'
    assert rid in otab.rows_, f'{rid} not found'

    cid = f'f__{tax}'
    if cid not in otab.cols_: 
        matches = pick(otab.cols_, lambda x: tax in x)
        assert len(matches) == 1, \
            f'{len(matches)} columns matching {tax!r} found (expected exactly 1): {matches}'
        cid = matches[0]

    return otab[rid,cid].X_[0,0]

In [None]:
# Look up the correlation of every curated (dataset, factor, taxon) row.
cors = smap(
    sig_tab,
    unpack(lambda o,f,t: _find_cor(summ_tabs[o], f, t))
)