In [31]:
import pandas as pd
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import glob
import seaborn as sns
import tqdm
import mat4py
from sklearn import metrics
import logomaker
from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay
import seqlogo
import scipy
import re
from matplotlib import gridspec
import scipy
from sklearn.metrics import auc, average_precision_score
from collections import OrderedDict
import torch.nn as nn
import torch
import joblib
import torch.nn.functional as F
import torch.optim as optim
import os
from joblib import Parallel, delayed
%matplotlib notebook

In [2]:
SEED = 666
torch.manual_seed(SEED)
# torch.set_deterministic() was deprecated (PyTorch 1.8) in favour of
# torch.use_deterministic_algorithms() and is absent from current releases;
# use the replacement and only fall back on very old PyTorch versions.
# NOTE: on CUDA >= 10.2 some cuBLAS ops additionally require the
# CUBLAS_WORKSPACE_CONFIG environment variable when this mode is on.
if hasattr(torch, "use_deterministic_algorithms"):
    torch.use_deterministic_algorithms(True)
else:
    torch.set_deterministic(True)
np.random.seed(SEED)

## Motif Position Probability Matrix (PPM) distance

In [None]:
import matlab.engine
eng = matlab.engine.start_matlab()
# motif Position Probability Matrix (PPM) distance
def kc_motif_distance(trimed_true_dimer, trimed_pred_dimer):
    """Distance between a true and a predicted motif PPM after MATLAB profile alignment.

    Both matrices are aligned with MATLAB ``profalign`` (predicted first, true
    second). Every alignment column without gaps contributes the Euclidean
    distance between the corresponding predicted and true rows. Every gapped
    column is charged the largest such per-column distance. The total is
    divided by the number of alignment columns.

    Inputs appear to be (4, length) base-by-position arrays: they are
    transposed before rows are compared with the 4 base columns of the aligned
    profile. TODO confirm against callers.

    Returns:
        float: mean per-column distance including the gap penalty.
    """
    t = matlab.double(trimed_true_dimer.tolist())
    p = matlab.double(trimed_pred_dimer.tolist())
    # Transpose so that each row is one alignment column.
    prof = np.asarray(eng.profalign(p, t)).T
    t = np.asarray(t).T
    p = np.asarray(p).T
    add_mat = prof[:, :4]   # base values of the aligned profile per column
    gap_locus = prof[:, 4]  # gap value per column; 0 means both motifs are present
    total_dist = 0
    # Largest aligned-column distance, reused as the per-gap penalty. Start at 0
    # rather than -1: distances are non-negative, and a -1 sentinel made the
    # penalty (and the result) negative when no column was aligned.
    max_dist = 0.0
    t_gap = 0
    p_gap = 0
    for i, lo in enumerate(gap_locus):
        if lo == 0:
            dist = np.sqrt(sum((p[i - p_gap] - t[i - t_gap]) ** 2))
            total_dist += dist
            if dist > max_dist:
                max_dist = dist
        else:
            # Decide which motif the gap belongs to. If one motif is already
            # exhausted, the gap must be in it. Otherwise attribute the gap to
            # the motif whose current row is farther (L1) from the aligned
            # profile column.
            if i - t_gap >= len(t):
                t_gap += 1
            elif i - p_gap >= len(p):
                p_gap += 1
            elif sum(abs(add_mat[i] - t[i - t_gap])) < sum(abs(add_mat[i] - p[i - p_gap])):
                p_gap += 1
            else:
                t_gap += 1
    return (total_dist + (p_gap + t_gap) * max_dist) / len(gap_locus)

In [32]:
def get_rev_com_y(seq_mat):
    """Reverse-complement a (position x channel) motif matrix.

    The row order is reversed. Within each row the first four channels are
    reversed, which swaps A<->T and C<->G under the A, C, G, T column order used
    by ``build_df4logmarker``. Any channels after the fourth are kept unchanged.
    The input is not modified; a copy is returned.
    """
    flipped = seq_mat[::-1].copy()
    for pos, row in enumerate(flipped):
        bases, extra = row[:4], row[4:]
        flipped[pos] = np.concatenate((bases[::-1], extra))
    return flipped

def build_df4logmarker(seq_mat):
    """Convert a (position x base) matrix into the DataFrame layout logomaker expects.

    Only the first four columns of ``seq_mat`` are used. They are labelled A, C,
    G, T in that order. The index is the 0-based position, named ``pos``.
    """
    positions = pd.Index(np.arange(len(seq_mat)), name='pos')
    base_columns = {base: seq_mat[:, col] for col, base in enumerate('ACGT')}
    return pd.DataFrame(base_columns, index=positions)

def get_icDimer(dimer):
    """Weight each position of a PPM by its information content (IC).

    ``dimer`` is rounded to 6 decimals and scaled to integer pseudo-counts.
    It is transposed before being passed to ``seqlogo.CompletePm``, which
    computes the per-position IC. Each row of ``dimer`` (4 entries) is then
    multiplied by its IC.

    Returns:
        (IC-weighted matrix with the shape of ``dimer``, 1-D per-position IC array)
    """
    pseudo_counts = (np.round(dimer.T, 6) * 1e6).astype(int)
    ic = seqlogo.CompletePm(pfm=pseudo_counts).ic.to_numpy()
    ic_per_base = np.repeat(ic[:, np.newaxis], 4, axis=1)
    return ic_per_base * dimer, ic

## DeepMotifSyn Generator

In [4]:
# DeepMotifSyn generator
class deeper_u_net(nn.Module):
    """DeepMotifSyn generator: a 1-D U-Net mapping an aligned motif pair to a dimer PPM.

    Input ``x`` is (batch, length, 108). Channels 0-7 are the two aligned
    monomer PPMs (4 + 4), re-injected just before the output head. The other
    100 channels are presumably the family one-hot code (TODO confirm against
    the feature-building code). Output is (batch, 4, length), softmax-normalised
    over the 4 base channels. The kernel sizes make the output length equal the
    input length (for length >= 7).

    NOTE: attribute names and the layer order inside each ``nn.Sequential``
    define the ``state_dict`` keys of the pretrained checkpoint -- keep them.
    """

    def __init__(self, device='cuda'):
        super().__init__()
        # Only recorded; the module never moves itself to this device.
        self.device = device

        # Contracting path: length L -> L-3 -> L-4.
        self.encoder1 = nn.Sequential(
            nn.Conv1d(in_channels=108, out_channels=64, kernel_size=4, stride=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.encoder2 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=2, stride=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )

        # Expanding path: L-4 -> L-3 (concatenated with encoder1 output) -> L.
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose1d(in_channels=128, out_channels=64, kernel_size=2, stride=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose1d(in_channels=64 * 2, out_channels=64, kernel_size=4, stride=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )

        # Bottleneck keeps length L-4 (k=3 conv, k=1 conv, k=3 transposed conv).
        # ReLU precedes BatchNorm here, unlike the encoders; reordering would
        # change the checkpoint's state_dict indices.
        self.bottleneck = nn.Sequential(
            nn.Conv1d(kernel_size=3, in_channels=128, out_channels=256, stride=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Conv1d(kernel_size=1, in_channels=256, out_channels=256, stride=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.ConvTranspose1d(in_channels=256, out_channels=128, kernel_size=3, stride=1),
        )

        # Per-position head: 64 decoder channels + 8 raw motif-pair channels -> 4 bases.
        self.cnn_out = nn.Sequential(
            nn.Conv1d(in_channels=64 + 8, out_channels=32, kernel_size=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=4, kernel_size=1),
            nn.BatchNorm1d(4),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        # Conv1d wants (batch, channels, length); the input arrives as (batch, length, channels).
        pair_channels = x[:, :, :8].permute(0, 2, 1)
        features = x.permute(0, 2, 1)

        skip = self.encoder1(features)
        deepest = self.bottleneck(self.encoder2(skip))

        upsampled = self.decoder1(deepest)
        upsampled = self.decoder2(torch.cat((upsampled, skip), 1))

        return self.cnn_out(torch.cat((upsampled, pair_channels), 1))

In [21]:
# Reference dimer database, family lookup and pre-aligned motif-pair information.
# pd.read_pickle closes its file handle; the previous pkl.load(open(...)) calls
# leaked one handle per file.
# NOTE: unpickling can execute arbitrary code -- only read trusted project files.
kc_dimer_info = pd.read_csv("../data/kc_dimer_info.csv")
dimer_seq_dict = pd.read_pickle("../data/dimerMotifDatabase_dict.pkl")
dimerfamily_dict = pd.read_pickle("../data/dimerMotifFamily_dict.pkl")
# monomeric motif PPM
monomeric_PPM_dict = pd.read_pickle("../data/MonomericMotif_PPM_dict.pkl")
found_mp_name, found_mp_family, found_mp_code, true_mp_code, found_mp_dimer_code = pd.read_pickle("../data/found_best_aligned_mp_allFam_correctedFamilyName.pkl")
# mp_info_df = pd.read_csv("./generated_motifpairs/generated_motifpairs_with_label_rmDuplicates.csv")

In [14]:
# # combine monemeric motif seq dict
# motifpair_names = kc_dimer_info['nameOut']
# combine_mp_dict = {}
# for idx, d_info in kc_dimer_info.iterrows():
#     motif1_name = d_info['name1']
#     motif2_name = d_info['name2']
#     # dimer_name = d_info['nameOut']
#     try:
#         motif1_seq = motif_seq_dict[motif1_name]
#         combine_mp_dict[motif1_name] = motif_seq_dict[motif1_name]
#     except:
#         motif1_seq = homomotif_seq_dict[motif1_name]
#         combine_mp_dict[motif1_name] = homomotif_seq_dict[motif1_name]
#     try:
#         motif2_seq = motif_seq_dict[motif2_name]
#         combine_mp_dict[motif2_name] = motif_seq_dict[motif2_name]
#     except:
#         motif2_seq = homomotif_seq_dict[motif2_name]
#         combine_mp_dict[motif2_name] = homomotif_seq_dict[motif2_name]

In [19]:
# pkl.dump(combine_mp_dict, open("../data/MonomericMotif_PPM_dict.pkl", "wb"))

In [11]:
# Family lookup tables:
#   dnameToFamily_dict        dimer name -> family name
#   family_onehot_encode      family name -> one-hot family code
#   dimer_familly_mpCase_dict family -> (orientation-case vector, overlap-length vector, 2 unused entries)
# pd.read_pickle closes its file handle (pkl.load(open(...)) leaked one per call).
dnameToFamily_dict = pd.read_pickle("../data/dimer_name_family_upper_dict.pkl")
family_onehot_encode = pd.read_pickle("../data/kc_heterodimer_family_all614dimers_upper_oneHotDict.pkl")
dimer_familly_mpCase_dict = pd.read_pickle("../data/dimerMotif_87family_feaures_dict.pkl")

In [23]:
# Monomeric PPMs of the two partners whose heterodimeric motif is synthesized below.
FLI1_ppm, FOXI1_ppm = (monomeric_PPM_dict[tf] for tf in ('FLI1', 'FOXI1'))

In [24]:
def get_possible_motifpair(name, motif1, motif2, max_mp_len=35):
    """Enumerate every placement of two monomer PPMs on a fixed-length canvas.

    Four orientation cases are tried (``orientation_case`` column):
      1. motif1 then motif2
      2. motif2 then motif1
      3. reverse-complement(motif2) then motif1
      4. motif1 then reverse-complement(motif2)
    In each case the first motif (m1) starts at ``m1_start_idx``. The second
    (m2) starts at ``m2_start_idx >= m1_start_idx``. A placement is kept only
    when each motif's exclusive end index is strictly below ``max_mp_len``, so
    the last canvas row always stays empty. This is kept as-is because
    downstream models and hard-coded candidate indices depend on this exact
    enumeration and its order.

    Args:
        name: dimer name copied into every row of the returned DataFrame.
        motif1, motif2: (length, 4) PPM arrays.
        max_mp_len: canvas length, i.e. rows of each generated code
            (previously ignored in favour of a hard-coded 35).

    Returns:
        df: one row per placement with columns dimer_name, orientation_case,
            overlapping_len, m1_start_idx, m2_start_idx.
            - When the motifs are disjoint, ``overlapping_len`` is minus the gap
              between them.
            - Otherwise it is the number of rows whose summed PPM mass exceeds
              1.1, which is the overlap length when PPM rows sum to 1.
        codes: array of shape (n_placements, max_mp_len, 8). Columns 0-3 hold
            m1, columns 4-7 hold m2, and zeros fill the rest. If nothing fits,
            ``df`` is empty and ``codes`` has shape (0, max_mp_len, 8).
    """
    columns = ['dimer_name', 'orientation_case', 'overlapping_len', 'm1_start_idx', 'm2_start_idx']
    rev_com_motif2 = get_rev_com_y(motif2)
    orientations = (
        (1, motif1, motif2),
        (2, motif2, motif1),
        (3, rev_com_motif2, motif1),
        (4, motif1, rev_com_motif2),
    )

    records = []
    codes = []
    for case_i, m1, m2 in orientations:
        m1_len, m2_len = len(m1), len(m2)
        # Range bounds encode "exclusive end < max_mp_len" (see docstring).
        for m1_si in range(max_mp_len - m1_len):
            m1_ei = m1_si + m1_len
            for m2_si in range(m1_si, max_mp_len - m2_len):
                m2_ei = m2_si + m2_len
                mp_code = np.zeros((max_mp_len, 8))
                mp_code[m1_si:m1_ei, :4] = m1
                mp_code[m2_si:m2_ei, 4:] = m2

                if m2_si >= m1_ei:
                    overlap_len = -(m2_si - m1_ei)
                else:
                    overlap_len = int(np.count_nonzero(mp_code.sum(-1) > 1.1))

                codes.append(mp_code)
                records.append([name, case_i, overlap_len, m1_si, m2_si])

    df = pd.DataFrame(records, columns=columns)
    # np.array([]) would be 1-D; keep the documented 3-D shape when nothing fits.
    code_arr = np.array(codes) if codes else np.zeros((0, max_mp_len, 8))
    return df, code_arr

## Generate all possible aligned motif pairs of FLI1-FOXI1

In [27]:
# Enumerate all aligned FLI1/FOXI1 placements: a metadata table plus the matching (n, 35, 8) codes.
possible_mp_df, possible_mp_code = get_possible_motifpair(
    name='FLI1_FOXI1',
    motif1=FLI1_ppm,
    motif2=FOXI1_ppm,
)

## Load the DeepMotifSyn generator

In [25]:
# Pretrained DeepMotifSyn generator weights.
generator_ckpt_path = '../model/fold_FOX_ETS-FLI1_FOXI1-FLI1_FOXI1_2-FLI1_FOXI1_3-FLI1_FOXI1_4_best_u_net_mp_checkpoint.pt'
checkpoint = torch.load(generator_ckpt_path)
net = deeper_u_net().cuda().eval()
net.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [29]:
# Add the FLI1_FOXI1 family one-hot code to every position of every candidate motif pair.
# Candidate count and canvas length come from possible_mp_code instead of a hard-coded 35,
# so this stays consistent if get_possible_motifpair is called with another max_mp_len.
# Assumes the stored family code is 2-D (1, n_family), so [np.newaxis] + the two repeats
# give (n_candidates, canvas_len, n_family) -- TODO confirm against the pickle.
n_candidates, canvas_len = possible_mp_code.shape[:2]
family_code = family_onehot_encode[dnameToFamily_dict['FLI1_FOXI1']][np.newaxis, :]
family_code = np.repeat(family_code, canvas_len, axis=1)
family_code = np.repeat(family_code, n_candidates, axis=0)
possible_mp_code_witFAM = np.concatenate([possible_mp_code, family_code], axis=-1)
possible_mp_code_witFAM.shape

(1218, 35, 108)

## Synthesize heterodimeric motifs

In [30]:
# Run the generator on every candidate (inference only: eval mode, autograd disabled).
net.eval()
with torch.no_grad():
    X_input = torch.from_numpy(possible_mp_code_witFAM).cuda().float()
    generated = net(X_input)
# Tensors produced under no_grad carry no graph, so .detach() is unnecessary here.
pred_dimer = generated.cpu().numpy()

## Build motif evaluator features

In [33]:
# One-hot slot per overlap length: -20 -> 0, -19 -> 1, ..., 20 -> 40.
overlap_len_dict = {olen: slot for slot, olen in enumerate(range(-20, 21))}


def encode_aligned_features(olen, case):
    """One-hot encode a candidate's orientation case and overlap length.

    Args:
        olen: overlap length; must be a key of ``overlap_len_dict`` (-20..20),
            otherwise ``KeyError`` is raised.
        case: orientation case, 1-4 (as produced by ``get_possible_motifpair``).

    Returns:
        (case_code, olen_code): float arrays of length 4 and 41, each with a single 1.
    """
    case_code = np.eye(4)[case - 1]
    olen_code = np.eye(len(overlap_len_dict))[overlap_len_dict[olen]]
    return case_code, olen_code

In [34]:
# Build the 784-d evaluator feature vector for every candidate motif pair.
# The layout must match what the XGBoost evaluator was trained on -- do not reorder.


def _fill_padding(motif):
    """Return a copy of ``motif`` whose near-empty rows (mass < 0.3) are set to uniform 0.25.

    The copy matters: slicing ``possible_mp_code`` yields a view, and padding it
    in place used to overwrite the shared candidate array.
    """
    filled = motif.copy()
    filled[filled.sum(-1) < 0.3] = 0.25
    return filled


possible_mp_784feaures = []
for idx, row in tqdm.tqdm(possible_mp_df.iterrows(), total=len(possible_mp_df)):
    motif1 = _fill_padding(possible_mp_code[idx, :, :4])
    motif2 = _fill_padding(possible_mp_code[idx, :, 4:8])

    motif1_ic, ic1 = get_icDimer(motif1)
    motif2_ic, ic2 = get_icDimer(motif2)
    # Positions where both motifs carry information (IC-weighted row sum > 1e-6).
    overlap_locus = np.logical_and(motif2_ic.sum(-1) > 10e-7, motif1_ic.sum(-1) > 10e-7)
    overlap_len = int(np.count_nonzero(overlap_locus))

    # Fixed 18-slot overlap features padded with -1; a longer overlap would not fit and raises.
    motif1_ol = np.full((18, 4), -1.0)
    mp_ol_diff = np.full(18, -1.0)
    motif2_ol = np.full((18, 4), -1.0)
    mp_ol_sum = np.full((18, 4), -1.0)
    motif1_ol_ic = np.full(18, -1.0)
    motif2_ol_ic = np.full(18, -1.0)
    mp_ic_diff = np.full(18, -1.0)

    if overlap_len > 0:
        # IC-weighted rows of each motif inside the overlap
        motif1_ol[:overlap_len] = motif1_ic[overlap_locus]
        motif2_ol[:overlap_len] = motif2_ic[overlap_locus]
        # per-position Euclidean distance between the IC-weighted rows
        mp_ol_diff[:overlap_len] = np.sqrt(((motif1_ic[overlap_locus] - motif2_ic[overlap_locus]) ** 2).sum(-1))
        # per-position sum of the IC-weighted rows
        mp_ol_sum[:overlap_len] = motif1_ic[overlap_locus] + motif2_ic[overlap_locus]
        # per-position IC of each motif and their absolute difference
        motif1_ol_ic[:overlap_len] = ic1[overlap_locus]
        motif2_ol_ic[:overlap_len] = ic2[overlap_locus]
        mp_ic_diff[:overlap_len] = np.sqrt((ic1[overlap_locus] - ic2[overlap_locus]) ** 2)
        # A/C/G/T totals over the overlap
        motif1_overlap_base = motif1_ic[overlap_locus, :].sum(0)
        motif2_overlap_base = motif2_ic[overlap_locus, :].sum(0)
    # FIXME(review): when overlap_len == 0 the two *_overlap_base vectors keep the values
    # of the previous candidate (and are undefined -> NameError if the first candidate on a
    # fresh kernel has no overlap). Kept so features match the trained evaluator; fix
    # together with retraining.

    # FIXME(review): motif1_ol_ic appears twice and motif2_ol_ic is never used -- almost
    # certainly a typo, but changing it alters the features the pretrained evaluator expects.
    overlap_feats = np.concatenate([motif1_ol.flatten(), motif2_ol.flatten(), motif1_ol_ic, motif1_ol_ic, mp_ol_diff.flatten(), mp_ol_sum.flatten(), mp_ic_diff, motif2_overlap_base.flatten(), motif1_overlap_base.flatten()])
    seq_feats = np.concatenate([motif1_ic.flatten(), motif2_ic.flatten(), ic1.flatten(), ic2.flatten()])
    feats = np.concatenate([overlap_feats, seq_feats])

    # Alignment one-hots: family prior, this candidate, their product and sums.
    dimerFam = dnameToFamily_dict[row['dimer_name']]
    case_fam, olen_fam, _, _ = dimer_familly_mpCase_dict[dimerFam]
    case_gmp, olen_gmp = encode_aligned_features(row['overlapping_len'], row['orientation_case'])
    case_mul, olen_mul = case_fam * case_gmp, olen_fam * olen_gmp
    mul_sum = [sum(case_mul), sum(olen_mul)]
    total_sum = [sum(mul_sum)]

    # (The old loop also IC-weighted pred_dimer[idx] here but never used the result.)
    feat_784 = np.concatenate([feats, case_fam, olen_fam, case_gmp, olen_gmp, case_mul, olen_mul, mul_sum, total_sum])
    possible_mp_784feaures.append(feat_784)

possible_mp_784feaures = np.array(possible_mp_784feaures)

1218it [00:13, 92.11it/s]


In [35]:
# Evaluator input: the 784 engineered features followed by the flattened generator output.
flat_pred_dimer = pred_dimer.reshape(pred_dimer.shape[0], -1)
possible_mp_924feaures = np.hstack([possible_mp_784feaures, flat_pred_dimer])

## Load XGBoost

In [36]:
# Pretrained XGBoost evaluator that scores candidates from their 924 features.
# joblib.load unpickles arbitrary objects: only load trusted model files.
evaluator_path = "../model/fold_FLI1_FOXI1_XGBoost_bestHyper_924features.joblib"
xgboost_evaluator = joblib.load(evaluator_path)

In [None]:
## Score every generated heterodimeric motif 

In [37]:
# predict_proba returns one column per class. Column 1 (presumably the positive class --
# check xgboost_evaluator.classes_) becomes each candidate's score.
possible_mp_score = xgboost_evaluator.predict_proba(possible_mp_924feaures)
positive_col = 1
possible_mp_df['score'] = possible_mp_score[:, positive_col]

In [39]:
# Rank all candidate motif pairs by evaluator score, best first.
sorted_mp_df = possible_mp_df.sort_values('score', ascending=False)

In [40]:
# Candidates sorted by evaluator score (descending); the top rows are the preferred alignments.
sorted_mp_df

Unnamed: 0,dimer_name,orientation_case,overlapping_len,m1_start_idx,m2_start_idx,score
613,FLI1_FOXI1,3,3,0,4,0.453138
382,FLI1_FOXI1,2,3,0,4,0.260644
846,FLI1_FOXI1,4,7,0,6,0.188519
851,FLI1_FOXI1,4,3,0,11,0.137153
7,FLI1_FOXI1,1,7,0,7,0.077959
...,...,...,...,...,...,...
511,FLI1_FOXI1,2,0,7,14,0.000019
728,FLI1_FOXI1,3,-1,6,14,0.000019
985,FLI1_FOXI1,4,-1,5,20,0.000019
482,FLI1_FOXI1,2,-2,5,14,0.000019


## Visualize DeepMotifSyn-predicted heterodimeric motifs

In [45]:
# Sequence logos of the two monomeric motifs.
sns.set_style("white")


def _plot_monomer_logo(ppm, title):
    """Draw a logo for one monomer PPM, with one x tick per motif position."""
    df = build_df4logmarker(ppm)
    logo = logomaker.Logo(df,
                          width=.8,
                          vpad=.05,
                          fade_probabilities=True,
                          stack_order='small_on_top')
    logo.ax.set_title(title)
    # One tick per position (FOXI1 previously used a hard-coded range(13)).
    logo.ax.set_xticks(range(len(df)))
    logo.ax.figure.tight_layout()
    sns.despine(ax=logo.ax, left=False)
    return logo


for tf_name, tf_ppm in (('FLI1', FLI1_ppm), ('FOXI1', FOXI1_ppm)):
    _plot_monomer_logo(tf_ppm, tf_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [49]:
def _plot_logo(mat, title):
    """Draw a sequence logo for a (position x base) matrix and title its axes."""
    logo = logomaker.Logo(build_df4logmarker(mat),
                          width=.8,
                          vpad=.05,
                          fade_probabilities=True,
                          stack_order='small_on_top')
    logo.ax.set_title(title)
    return logo


def _replace_empty_rows(mat, fill):
    """Return a copy of ``mat`` whose rows with total mass below 0.3 are set to ``fill``."""
    out = mat.copy()
    out[out.sum(-1) < 0.3] = fill
    return out


mp_idx = 382                      # candidate row to inspect (2nd-ranked in the saved score table)
true_dimer_name = 'FLI1_FOXI1_3'  # reference dimer to compare against
trim_len = 13                     # leading positions shown in every logo

# Work on copies: padding slices of possible_mp_code / found_mp_dimer_code in place
# used to overwrite the shared arrays.
motif1 = _replace_empty_rows(possible_mp_code[mp_idx, :, :4], 0.25)
motif2 = _replace_empty_rows(possible_mp_code[mp_idx, :, 4:], 0.25)
# Near-zero fill, presumably so padding positions render as (almost) empty in the logo.
tdimer = _replace_empty_rows(found_mp_dimer_code[found_mp_name == true_dimer_name][0], 0.001)
pdimer = pred_dimer[mp_idx].T

_plot_logo(motif1[:trim_len], 'Monomeric motif 1 (aligned)')
_plot_logo(motif2[:trim_len], 'Monomeric motif 2 (aligned)')
_plot_logo(pdimer[:trim_len], "Predicted heterodimeric motif")
_plot_logo(tdimer[:trim_len], "True heterodimeric motif");

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'True heterodimeric motif')