In [1]:
import numpy as np
import pandas as pd

import pickle

import copy

import matplotlib.pyplot as plt

In [2]:
import sys
# caution: path[0] is reserved for script path (or '' in REPL)
sys.path.insert(-1, '/mnt/x/Computation/Utilities')

import ocu_seaside.ocu_basics as se
import ocu_seaside.ocu_visuals as viz
import ocu_trident.ocu_tri_utils as tu
import ocu_binmeths as bm
import ocu_compass as co

import ocu_trident.ocu_deepstarr as ds
import ocu_trident.ocu_preassembled as pa
import ocu_oyster as oys

import os
os.chdir('/mnt/x/Computation/Projects/CS3-YY1/data/R2') ##########################

import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
mainfolder = se.NewFolder('yy1_data')
fin_bins = se.PickleLoad(mainfolder + 'fin_bins')
print(fin_bins)

divset = ['Train', 'Stop.', ' Eval', 'Test']

xpsfolder = se.NewFolder('xps4')                            #1111111111111111111111111111111

10


In [4]:
exact_folder = se.NewFolder('Oyster')
exact_iter = 11
pn_RO_exact = se.NewFolder(exact_folder + 'RO_' + str(exact_iter))
pn_RO_exact_top = se.NewFolder(pn_RO_exact + 'Top' + str(0))
exact_ranked_can_dict = se.PickleLoad(pn_RO_exact + 'ranked_can_dict')

exact_dict = exact_ranked_can_dict[0]

In [5]:
yy1_split_transaug = se.PickleLoad(mainfolder + 'yy1_split_transaug')

yy1_max = se.PickleLoad(mainfolder + 'yy1_max')

tri_tpacks = se.PickleLoad(mainfolder + 'tri_tpacks')

d_x1, d_y_ms, d_s, d_b, d_x2 = tri_tpacks

d_s_rs = np.swapaxes(np.expand_dims(d_s, axis = -1), 1, -1)

In [6]:
histones = ['H3K4me3', 'H3K27ac', 'H3K27me3', 'H3K4me1', 'H3K36me3', 'H3K9me3', 'H3K9ac', 
            'H3K4me2', 'H4K20me1', 'H2AFZ', 'H3K79me2']

In [7]:
xp6_trial = 2
xp6_id = 'xp6' + '_' + str(xp6_trial) 
pn_xp6 = se.NewFolder(xpsfolder + xp6_id)
pn_xpF, icof, icom = pn_xp6, 0, 3

pn_t1 = se.NewFolder(pn_xpF + str(icof))
pn_t2 = se.NewFolder(pn_t1 + str(icom))

ex_ens_idx = se.PickleLoad(pn_t2 + 'ensemble_idx')

ex_ens_pn_preds, ex_ens_pn_mods = [[pn_t2 + '0_' + str(ir) + s 
                                    for ir in ex_ens_idx]
                                   for s in ['_preds.p',  '_Mod.pt']]

# Sequence Contributions

In [8]:
def Oys_SigContrib(model, inp, 
                   sub_model = None, 
                   batchsize = 256):
    
    shapo = inp.shape #the 2nd dimension is always the number of sigs. 
    num_sigs = shapo[1]

    if isinstance(model, list) is False: model = [model]
    #if sub_model is None: sub_model = ''

    lk = len(inp)

    fullbatches = lk // batchsize
    rem = lk % batchsize

    fins = []

    for fb in np.arange(fullbatches + (rem > 0)):

        fo = fb*batchsize
        batch = inp[fo:fo + batchsize]

        fin = []

        for mo in model: 

            if isinstance(mo, str): mo = tu.LoadTorch(mo)

            if sub_model is not None: mo = getattr(mo, sub_model)
            
            with torch.no_grad(): 
                mo.eval()

                ger = 0 if isinstance(mo.O[0], nn.Conv2d) else 1

                Dweight = mo.O[ger].weight
                Dweight = torch.unsqueeze(Dweight, 0)

                x = torch.FloatTensor(batch).to(device)

                x = mo.Reflect(x)
                x = mo.kE(x)
                x = mo.AntiReflect(x)
                x = mo.P(x)
                
                x = torch.unsqueeze(x, axis = 1)
                x = x * Dweight
                x = torch.squeeze(torch.sum(x, axis = 3))

                xshapo = x.shape
                x = torch.sum(x.reshape(xshapo[0], num_sigs, xshapo[-1] // num_sigs), -1)

                fin.append(x.cpu().detach().numpy())

        fin = np.mean(np.stack(fin), axis = 0)
        fins.append(fin) 
        
    return np.vstack(fins)

In [9]:
not_test = np.delete(np.arange(len(d_x1)), yy1_split_transaug[-1])

d_b_nt = d_b[not_test].reshape(-1)

#-------------------------------------

# nuc_contrib = np.mean(np.array([tu.FeatExtract(mo, 'OysterA', [d_x1, d_x2], batchsize = 512) 
#                                 for mo in ex_ens_pn_mods]), 
#                                 axis = 0)

# se.PickleDump(nuc_contrib, pn_xpF + 'nuc_contrib')

# #-------------------------------------

# sigs_contrib = Oys_SigContrib(ex_ens_pn_mods, d_x2, sub_model = 'OysterB')

# se.PickleDump(sigs_contrib, pn_xpF + 'sigs_contrib')

In [10]:
nuc_contrib = se.PickleLoad(pn_xpF + 'nuc_contrib')
nuc_contrib_abs = np.abs(nuc_contrib)
nuc_contrib_nt = nuc_contrib_abs[not_test]

nuc_contrib_bind = [nuc_contrib_nt[d_b_nt == b] for b in np.arange(fin_bins)]
nuc_contrib_means = np.array([np.mean(x) for x in nuc_contrib_bind]).reshape(1, -1)
nuc_contrib_std = np.array([np.std(x) for x in nuc_contrib_bind]).reshape(1, -1)

In [11]:
sigs_contrib = se.PickleLoad(pn_xpF + 'sigs_contrib')
sigs_contrib_abs = np.abs(sigs_contrib)
sigs_contrib_nt = sigs_contrib_abs[not_test]

sigs_contrib_bind = [sigs_contrib_nt[d_b_nt == b] for b in np.arange(fin_bins)]
sigs_contrib_means = np.array([np.mean(x, axis = 0) for x in sigs_contrib_bind]).T
sigs_contrib_stds = np.array([np.std(x, axis = 0) for x in sigs_contrib_bind]).T

In [12]:
all_contrib_means = np.vstack([nuc_contrib_means, sigs_contrib_means ])
all_contrib_maxs = all_contrib_means.max(1)

comb_contrib_nt = np.hstack([nuc_contrib_nt.reshape(-1, 1), sigs_contrib_nt])

comb_contrib_tot = comb_contrib_nt.sum(-1).reshape(-1, 1)
comb_influ = comb_contrib_nt / comb_contrib_tot

comb_bind = [comb_influ[d_b_nt == b] for b in np.arange(fin_bins)]
comb_means = np.array([np.mean(x, axis = 0) for x in comb_bind]).T

# influ_thresh = 0.2

# def boldo(x): 
#         return ['font-weight: bold' if y >= influ_thresh else '' for y in x]

pdx = pd.DataFrame(comb_means)
pdx.index = ['DNA'] + histones
pdx.columns = np.arange(fin_bins) + 1

pdx_sty = pdx.style.format(precision=4).background_gradient(cmap = 'Greys', axis = None, 
                                                            vmin = 0, vmax = 1)
pdx_sty

Unnamed: 0,1,2,3,4,5,6,7,8,9,10
DNA,0.3036,0.4834,0.4937,0.5003,0.5337,0.5444,0.5817,0.6021,0.6593,0.7083
H3K4me3,0.1115,0.0266,0.0133,0.009,0.0057,0.0027,0.0018,0.0005,0.0007,0.0005
H3K27ac,0.0158,0.0682,0.0758,0.0791,0.0762,0.079,0.0837,0.0779,0.064,0.06
H3K27me3,0.0058,0.0014,0.0006,0.0004,0.0003,0.0003,0.0002,0.0002,0.0001,0.0001
H3K4me1,0.1327,0.111,0.1384,0.1488,0.1404,0.1546,0.133,0.1301,0.1144,0.0936
H3K36me3,0.007,0.0067,0.0063,0.0054,0.0051,0.0046,0.0045,0.0037,0.0036,0.0033
H3K9me3,0.0206,0.0176,0.0157,0.0144,0.0135,0.0109,0.0093,0.009,0.0077,0.0066
H3K9ac,0.0058,0.0361,0.0517,0.0605,0.06,0.0604,0.0582,0.0646,0.0496,0.0434
H3K4me2,0.0567,0.0186,0.0103,0.006,0.0048,0.0023,0.0015,0.0008,0.0006,0.0006
H4K20me1,0.0012,0.0011,0.0009,0.0007,0.0009,0.0006,0.0005,0.0006,0.0005,0.0004


In [13]:
refi_combs = [pdx.index[pdx.max(1) > thre].to_list() for thre in [0, 0.05, 0.1, 0.15, 0.2, 0.25]]

# xpHIS

In [14]:
def Reset_DualOyster(dualoyster):
    #ds consists of a "conv" and a "dense" module. Need to go through each one, see if its a conv and reset if so. 


    for oyster in [dualoyster.OysterA, dualoyster.OysterB]: 
        lke, los = len(oyster.kE), len(oyster.O)

        for i in np.arange(lke): 
            if isinstance(oyster.kE[i], nn.Conv2d): 
                oyster.kE[i].reset_parameters()
        
        for i in np.arange(los): 
            if isinstance(oyster.O[i], nn.Conv2d): 
                oyster.O[i].reset_parameters()
        print('done reset mod')

    return dualoyster

In [15]:
masterpseudo = 1e-10

RMS_mode = [tu.MeanExpo, {'expo': 2, 'root': True}]
RMS_mode_pyt = [tu.MeanExpo, {'expo': 2, 'root': True, 'pyt': True}]

bm_args_np = {'byaxis': 1, 'useweights': False, 'seperate': False, 
              'summarize_mode': RMS_mode}

bm_args_pyt = {'byaxis': 1, 'useweights': False, 'seperate': False, 
               'summarize_mode': RMS_mode_pyt, 'pyt': True}

#------------------------------------------------

deverr_args_base = {'expo': 2, 'root': True, 
                    'pseudo': masterpseudo,
                    'scalefactor': yy1_max}

RMSDE_mode_np = [tu.DeviaError, {**deverr_args_base, 'pyt': False}]
RMSDE_mode_pyt = [tu.DeviaError, {**deverr_args_base, 'pyt': True}]  

B_RMSDE_mode_np = [tu.BinnedLoss, {'metrics_mode': RMSDE_mode_np, **bm_args_np}]
B_RMSDE_mode_pyt = [tu.BinnedLoss, {'metrics_mode': RMSDE_mode_pyt, **bm_args_pyt}]

In [16]:
flipo = [[-2, -1], [-2, -1]]

TT_base =  {
    'inps': [d_x1, d_x2], 'out': d_y_ms, 
    'out_std': d_s_rs, 'out_bind': d_b,
    'Split': yy1_split_transaug,   
    'EUS': None, 'obs_weight': None,                    #!@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    
    'metrics_mode': B_RMSDE_mode_np, 'smallest': True,
    'opt': Adam, 'maxepochs': 100, 
    'patience': 8, 'pickup': True,
    'flips': flipo, 'indivflips': True}

TCS_base = {'trainer': tu.TridentTrainer, 
                 'smallest': None,
                 'get_predictions': True, 'pred_rewrite': False, 
                 'add_pred_args': {'batchsize': 512, 'flips': flipo, 'avg_flips': True},
                 'score_on': 1, 'score_only': False} ### SCORING ON THE EVAL SET, MAKING ONMLY PREDS FOR IT. 

TCR_repeats = 10

TCR_base = {'Splits': None, 'repeats': TCR_repeats,
               'pickup': False, 'savemodels': True, 'returnmodel': True}

#--------------------------------------------------------------------------

es_args = {'out': d_y_ms, 'out_std': d_s_rs, 'out_bind': d_b,
            'split': yy1_split_transaug,
            'metrics_mode': B_RMSDE_mode_np, 
            'score_on': 1,
            'std_cutoff': None, 'ddof': 1, 'top': 3, 'smallest': True}

In [17]:

import glob, shutil 

def XpsCopy(cur_pn, cur_xp_id, 
            past_xp_id, past_unit, 
            rewrite = False):
                         
    basetargo = cur_pn.replace(cur_xp_id, past_xp_id)[:-2] + str(past_unit) + '/'

    targnames = glob.glob(basetargo + '*')
    lbt = len(basetargo)

    targnames1 = [t for t in targnames if '.p' in t or '.pt' in t]
    newnames1 = [cur_pn + t[lbt:] for t in targnames1]

    print(targnames1, newnames1)


    if rewrite: 
        for n in newnames1: 
            if os.path.isfile(n): os.remove(n)
    for t,n in zip(targnames1, newnames1): 
        if os.path.isfile(n): break #means we already have it
        shutil.copyfile(t, n)

    targnames2 = [t for t in targnames if '.p' not in t and'.pt' not in t]
    newnames2 = [cur_pn + t[lbt:] for t in targnames2]
    if rewrite: 
        for n in newnames2: 
            if os.path.isdir(n): shutil.rmtree(n)
    for t, n in zip(targnames2, newnames2): 
        if os.path.isdir(n): break #means we already have it             
        shutil.copytree(t, n, dirs_exist_ok=True)
    
    print(f'finished copying for {cur_pn}')
    
    return

def XpsBootstrapEnsemb(combs, pathname, rewrite = False):

    iterx = 100

    boots_scores, boots_idxs = [], []

    for icom, com in combs: 
        pn_t = se.NewFolder(pathname + str(icom))

        pn_preds = [pn_t + '0_' + str(ir) + '_Preds' + '.p' ###################
                    for ir in range(TCR_repeats)]
        
        pn_e = pn_t + 'boots_ensemb'

        if os.path.isfile(pn_e + '.p') and rewrite is False: 
            boots_ensemb = se.PickleLoad(pn_e)

        else: 
            boots_ensemb = co.Bootstrapper(pn_preds, mode = [co.EnsembleScorer, es_args],
                                          iters = iterx, updates = 20, return_idx = True)
            se.PickleDump(boots_ensemb, pn_t + 'boots_ensemb')
        
        boots_scores.append(boots_ensemb[0])
        boots_idxs.append(boots_ensemb[1])

        print(f'---- Finished {icom} ----')
    
    return np.array(boots_scores), np.array(boots_idxs)

def XpsResults(pn_xp, xp_combs,
               icof_only = None, 
               configs = None,
               ref_ic = 0, rewrite = False): 
    
    pf_args = {'mode1': [se.RelativeChange, {'perc': True}]}

    bs_all = []
    bs_all_r2r = []

    if configs is None: configs = {0: {}}

    for icof, (cof_name, cof) in enumerate(configs.items()):

        if icof_only is not None: 
            if icof not in icof_only: continue 
                            
        pn_t1 = se.NewFolder(pn_xp + str(icof))

        boots_scores = XpsBootstrapEnsemb(xp_combs, pn_t1, rewrite = rewrite)[0] # Just get scores 

        bs_all.append(boots_scores)

        com = boots_scores[ref_ic]

        rel2refs = [se.PairwiseFuncer(boots_scores[ic], com, **pf_args) for ic, c in xp_combs]

        bs_all_r2r.append(rel2refs)

        print(f'**** Finished {icof} ****')
                    
    bs_all = np.stack(bs_all)
    se.PickleDump(bs_all, pn_xp + 'bs_all')

    bs_all_r2r = np.stack(bs_all_r2r)
    se.PickleDump(bs_all_r2r, pn_xp + 'bs_all_r2r' + '_' + str(ref_ic))

    return bs_all, bs_all_r2r

def XpsTables(pn_xp, xp_combs, xp_variables, ber2rs, icof_only = None, configs = None):

    if configs is None: configs = {0: {}}
    
    onesided = None
    conf_alpha = 0.90

    print(ber2rs.shape)

    ber2rs = ber2rs[:, :, :, 2] ####################################

    ber2rs_mean = np.mean(ber2rs, axis = -1)
    ber2rs_se = np.std(ber2rs, axis = -1)
    ber2rs_low, ber2rs_high = co.BootstrapConfidenceInterval(ber2rs, alpha = conf_alpha, onesided=onesided, axis = -1)

    def sigo(low, high): 
        if onesided == 'lesser': sigi = 0 > high
        if onesided == 'greater': sigi = low > 0 
        if onesided == None: sigi = np.logical_or(0 > high, low > 0)
        return sigi
    
    ber2rs_sigo = sigo(ber2rs_low, ber2rs_high)

    #################################################

    per = pd.DataFrame(ber2rs_mean).T

    if len(xp_variables) > 1: 
        # per.index = pd.MultiIndex.from_tuples([tuple(x[1]) for x in cur_combs], names=cur_variables)
        per.index = [tuple(x[1]) for x in xp_combs]
        per.index.names = ['Combination']
    else: 
        per.index = [tuple(x[1][0]) for x in xp_combs]
        per.index.names = xp_variables
    
    if len(xp_combs[0][1]) > 1: 
        multcolx = pd.MultiIndex.from_tuples([tuple(g[1]) for g in xp_combs])
    else: multcolx = [str(g[1][0]) for g in xp_combs]
    
    configos = list(configs.keys())
    if icof_only is not None: configos = [cz for icz, cz in enumerate(configos) if icz in icof_only]

    per.columns = configos

    per_style = per.style.format(precision=1).background_gradient(axis = 1, vmin = -25, vmax = 0, cmap = 'Greys_r')


    per_style = per_style.apply(lambda _: 
                                np.where(ber2rs_sigo.reshape(-1, len(multcolx)).T, 
                                         'font-weight: bold', ''),
                                           axis = None)

    return per, per_style

In [18]:
xpHIS_trial = 2
xpHIS_id = 'xpHIS' + '_' + str(xpHIS_trial) 
pn_xpHIS = se.NewFolder(xpsfolder + xpHIS_id)

xpHIS_variables = ['HisCombs']

hiscombs = [pdx.index[pdx.max(1) > thre].to_list() for thre in [0, 0.05, 0.1, 0.15, 0.2, 0.25]]
hiscombs_nodna = [[h for h in hx if h != 'DNA'] for hx in hiscombs]
hiscombs_nodna_idx = [[histones.index(h) for h in hx] for hx in hiscombs_nodna]

xpHIS_combs = [[ic, [c]] for ic,c in enumerate(hiscombs_nodna_idx)]
xpHIS_combs

[[0, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]],
 [1, [[0, 1, 3, 6, 7, 9, 10]]],
 [2, [[0, 3, 9, 10]]],
 [3, [[3, 9]]],
 [4, [[9]]],
 [5, [[]]]]

In [19]:
[str(x[1][0]) for x in xpHIS_combs]

['[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]',
 '[0, 1, 3, 6, 7, 9, 10]',
 '[0, 3, 9, 10]',
 '[3, 9]',
 '[9]',
 '[]']

In [20]:
modo = oys.DualOyster
duds_mode = [Reset_DualOyster, {}]

lmo, lbx = RMSDE_mode_pyt, False
sampweimode = [bm.BinWeighter, {'uni': np.arange(fin_bins), 'byaxis': None, 
                                    'minus': False, 'newrange': True}]
usxo = 0.2                                                                          #!!!!!!!!!!!!!!!!!!!!!!

xpHIS_TT_args = {'duds_mode': duds_mode, 'duds': TCR_repeats,
                'loss_mode': lmo, 'loss_bind': lbx,
                'weights_mode': sampweimode, 'weights_bind': True,          
                **TT_base}

xpHIS_TT_args.update({'EUS': usxo, 'obs_weight': d_b.reshape(-1)})

xpHIS_TCS_args = {'trainer_args': xpHIS_TT_args, **TCS_base}

modo = oys.DualOyster
duds_mode = [Reset_DualOyster, {}]


xpHIS_rewrite = False

for icof in np.arange(1):                                   # PLACEHOLDER

    for icom, com in xpHIS_combs:

        pn_t1 = se.NewFolder(pn_xpHIS + str(icof))
        pn_t2 = se.NewFolder(pn_t1 + str(icom))

        par_xp, par_unit = None, None

        if icom == 0: par_xp, par_unit = xp6_id, 3                                  #!!!!!!!!!!!!!!!!!!!!!!

        print(par_xp, par_unit)

        if par_unit is not None: 
            XpsCopy(cur_pn = pn_t2, cur_xp_id = xpHIS_id, 
                    past_xp_id = par_xp, past_unit = par_unit, 
                    rewrite = xpHIS_rewrite)
        
        xpHIS_TCR_args = {**xpHIS_TCS_args, **TCR_base, 'pathname': pn_t2}

        #-------------------------------------------

        qox = com[0]

        print(qox)
        
        cof = copy.deepcopy(exact_dict)
        cof.update({'B_incl_only': qox})

        _, _= tu.TridentCanRepeater(modo, cof, data = None, **xpHIS_TCR_args)

xp6_2 3
['./xps4/xp6_2/0/3/0_0_Met.p', './xps4/xp6_2/0/3/0_0_Mod.pt', './xps4/xp6_2/0/3/0_0_Preds.p', './xps4/xp6_2/0/3/0_1_Met.p', './xps4/xp6_2/0/3/0_1_Mod.pt', './xps4/xp6_2/0/3/0_1_Preds.p', './xps4/xp6_2/0/3/0_2_Met.p', './xps4/xp6_2/0/3/0_2_Mod.pt', './xps4/xp6_2/0/3/0_2_Preds.p', './xps4/xp6_2/0/3/0_3_Met.p', './xps4/xp6_2/0/3/0_3_Mod.pt', './xps4/xp6_2/0/3/0_3_Preds.p', './xps4/xp6_2/0/3/0_4_Met.p', './xps4/xp6_2/0/3/0_4_Mod.pt', './xps4/xp6_2/0/3/0_4_Preds.p', './xps4/xp6_2/0/3/0_5_Met.p', './xps4/xp6_2/0/3/0_5_Mod.pt', './xps4/xp6_2/0/3/0_5_Preds.p', './xps4/xp6_2/0/3/0_6_Met.p', './xps4/xp6_2/0/3/0_6_Mod.pt', './xps4/xp6_2/0/3/0_6_Preds.p', './xps4/xp6_2/0/3/0_7_Met.p', './xps4/xp6_2/0/3/0_7_Mod.pt', './xps4/xp6_2/0/3/0_7_Preds.p', './xps4/xp6_2/0/3/0_8_Met.p', './xps4/xp6_2/0/3/0_8_Mod.pt', './xps4/xp6_2/0/3/0_8_Preds.p', './xps4/xp6_2/0/3/0_9_Met.p', './xps4/xp6_2/0/3/0_9_Mod.pt', './xps4/xp6_2/0/3/0_9_Preds.p', './xps4/xp6_2/0/3/boots_ensemb.p', './xps4/xp6_2/0/3/ensemble

In [36]:
xpHIS_bs_all, xpHIS_bs_r2r_all = XpsResults(pn_xpHIS, xpHIS_combs, ref_ic = 5, rewrite = False)

xpHIS_tab_main, xpHIS_tab_main_sty = XpsTables(pn_xpHIS, xpHIS_combs, xpHIS_variables, xpHIS_bs_r2r_all)

xpHIS_tab_main_sty 

---- Finished 0 ----
---- Finished 1 ----
---- Finished 2 ----
---- Finished 3 ----
---- Finished 4 ----
---- Finished 5 ----
**** Finished 0 ****
(1, 6, 10000, 4)


Unnamed: 0_level_0,0
HisCombs,Unnamed: 1_level_1
"(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)",-14.8
"(0, 1, 3, 6, 7, 9, 10)",-12.1
"(0, 3, 9, 10)",2.6
"(3, 9)",-3.5
"(9,)",6.9
(),0.1
