In [10]:
import pandas as pd
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import glob
import seaborn as sns
import tqdm
import mat4py
from sklearn import metrics
import logomaker
from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay
import seqlogo
import scipy
import re
from matplotlib import gridspec
import scipy
from sklearn.metrics import auc, average_precision_score
from collections import OrderedDict
import torch.utils.data as data_utils
import torch.nn as nn
import torch
import joblib
import torch.nn.functional as F
import torch.optim as optim
import os
from joblib import Parallel, delayed
%matplotlib notebook

In [45]:
import xgboost
import pickle as pkl
import time
import numpy as np
from datetime import datetime
from xgboost import XGBClassifier
from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.model_selection import RandomizedSearchCV
import os
import pandas as pd
from joblib import dump, load
from sklearn.metrics import auc, average_precision_score
from xgboost import XGBClassifier
from sklearn import metrics

In [2]:
torch.manual_seed(666)
torch.set_deterministic(True)
np.random.seed(666)

In [18]:
# DeepMotifSyn generator
class deeper_u_net(nn.Module):
    def __init__(self,device='cuda'):
        super(deeper_u_net, self).__init__()
        self.device = device
        self.encoder1 = nn.Sequential(
            nn.Conv1d(in_channels = 108, out_channels = 64, kernel_size = 4, stride = 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.encoder2 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size = 2, stride = 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose1d(in_channels = 128, out_channels = 64, kernel_size = 2, stride = 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose1d(in_channels=64*2, out_channels=64, kernel_size = 4, stride = 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        
        self.bottleneck = torch.nn.Sequential(
                            torch.nn.Conv1d(kernel_size=3, in_channels=128, out_channels=256, stride=1),
                            torch.nn.ReLU(),
                            torch.nn.BatchNorm1d(256),
                            torch.nn.Conv1d(kernel_size=1, in_channels=256, out_channels=256, stride=1),
                            torch.nn.ReLU(),
                            torch.nn.BatchNorm1d(256),
                            torch.nn.ConvTranspose1d(in_channels=256, out_channels=128, kernel_size=3, stride=1)
                            )
        
        self.cnn_out = nn.Sequential(
            nn.Conv1d(in_channels=64+8, out_channels=32, kernel_size=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=4, kernel_size=1),
            nn.BatchNorm1d(4),
            nn.Softmax(dim=1)
        )
        
    def forward(self,x):
        seq_x = x[:,:,:8]
        x = x.permute(0, 2, 1)
        
        en1 = self.encoder1(x)
        en2 = self.encoder2(en1)
        x = self.bottleneck(en2)

        x = self.decoder1(x)
        x = self.decoder2(torch.cat((x, en1), 1))
        
        seq_x = seq_x.permute(0 ,2, 1)

        x = torch.cat((x, seq_x), 1)
        out = self.cnn_out(x)
        return out

In [19]:
def weighted_MSE(input_seq, label_seq):
    weight = ((label_seq != -1)*1.0).to(dtype=torch.float32)
    valid_len = weight.sum()/4
    dist = (input_seq-label_seq)**2
    w_dist = dist*weight
    mean_dist = w_dist.sum()/valid_len
    return mean_dist

def weighted_EuclideanDist(input_seq, label_seq):
    weight = ((label_seq != -1)*1.0).to(dtype=torch.float32)
    valid_len = weight.sum()/4
    dist = (input_seq-label_seq)**2
    w_dist = dist*weight
    w_dist = torch.sqrt(w_dist.sum(0)).sum()
    mean_dist = w_dist/valid_len
    return mean_dist

In [32]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model, epoch_i):
        score = val_loss
        self.save_checkpoint(val_loss, model, epoch_i)
        """
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch_i)
        elif score > self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch_i)
            self.counter = 0
        """
        
    def save_checkpoint(self, val_loss, model, epoch_i):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save({
            'epoch': epoch_i,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
            'whole_model': model}, 
            self.path)
        self.val_loss_min = val_loss
        # print(self.path)
        

def train(net, train_ds, fold_i, save_folder, model_name = 'tmp' ):
    X_tv, y_tv = train_ds
    X_train, X_valid = X_tv, X_tv
    y_train, y_valid = y_tv, y_tv
    # X_train, y_train = train_ds
    
    train = data_utils.TensorDataset(torch.from_numpy(X_train),torch.from_numpy(y_train))
    trainloader = data_utils.DataLoader(train, batch_size=100, shuffle=True, num_workers = 1)
    early_stopping = EarlyStopping(patience=7, 
                verbose=True, path= save_folder + "/fold_" + str(fold_i) + "_best_"+model_name+"_checkpoint.pt")
    
    optimizer = optim.Adam(net.parameters())
    for epoch in tqdm.tqdm(range(200)):
        running_loss = 0
        for data in trainloader:
            input_seq = data[0].to(device, non_blocking=True).float()
            label_seq = data[1].to(device, non_blocking=True).float()
            # input_family = data[2].to(device, non_blocking=True).float()
            # print(inputs.size())
            optimizer.zero_grad()
            output = net(input_seq)
            label_seq = label_seq.permute(0, 2, 1).contiguous()
            loss = weighted_MSE(output, label_seq)
            # propagate the loss backward
            loss.backward()
            # update the gradients
            optimizer.step()
            running_loss += loss.item()
        
    val_mse = validate(net, [X_valid, y_valid])
    early_stopping(val_mse, net, epoch)
    
def validate(net, valid_ds):
    from scipy import stats
    from sklearn.metrics import mean_squared_error
    X_valid, y_valid = valid_ds
    valid = data_utils.TensorDataset(torch.from_numpy(X_valid),torch.from_numpy(y_valid))
    validloader = data_utils.DataLoader(valid, batch_size=len(X_valid))
    net.eval()
    with torch.no_grad():
        for data in validloader:
            input_seq = data[0].to(device, non_blocking=True).float()
            label_seq = data[1].to(device, non_blocking=True).float()
            # input_family = data[2].to(device, non_blocking=True).float()
            label_seq = label_seq.permute(0, 2, 1).contiguous()
            output = net(input_seq)
            # mse_loss = nn.MSELoss()
            loss = weighted_MSE(output, label_seq)
        
    # slope, intercept, r_value, p_value, std_err = stats.linregress(preds, actual)
    print("valid: MSE", loss.item())
    return loss.item()

def test(net, test_ds, fold_i, save_folder, model_name='tmp'):
    from scipy import stats
    X_test, y_test = test_ds
    test = data_utils.TensorDataset(torch.from_numpy(X_test),torch.from_numpy(y_test))
    
    testloader = data_utils.DataLoader(test, batch_size=len(X_test))
    checkpoint = torch.load(save_folder + "/fold_" + str(fold_i) + "_best_"+model_name+"_checkpoint.pt")
    net.load_state_dict(checkpoint['model_state_dict'])
    print("Load the best model from fold_" + str(fold_i) + "_best_"+model_name+"_checkpoint.pt")
    net.eval()
    weight_mse_loss = []
    with torch.no_grad():
        preds = []
        actual = []
        # combine_preds = []
        for data in testloader:
            input_seq = data[0].to(device, non_blocking=True).float()
            label_seq = data[1].to(device, non_blocking=True).float()
            # input_family = data[2].to(device, non_blocking=True).float()
            # X_naive_combine = ((input_seq[:,:,:4] + input_seq[:,:,4:8])/2).permute(0, 2, 1)
            
            label_seq = label_seq.permute(0, 2, 1).contiguous()
            output = net(input_seq)
            weight_mse_loss.append(weighted_EuclideanDist(output, label_seq).item())
            outputs = list(output.cpu().detach().numpy())
            labels = list(label_seq.cpu().detach().numpy())
            preds += outputs
            actual += labels
            # combine_preds += X_naive_combine
    
    fold_res = []
    # fold_naive_res = []
    for i in range(len(preds)):
        pred_dimer = preds[i]
        true_dimer = actual[i]
        # combine_dimer = combine_preds[i]
        fold_res.append(weighted_EuclideanDist_numpy(pred_dimer, true_dimer))
        # fold_naive_res.append(weighted_EuclideanDist_numpy(combine_dimer.cpu().numpy(), true_dimer))
    mses = np.array(fold_res).flatten()
    # naive_mses = np.array(fold_naive_res).flatten()
    print('Euclidean Distance Error', fold_i, np.mean(mses))
    if not os.path.exists(save_folder+"/predictions/"):
        os.mkdir(save_folder+"/predictions/")
        
    save_nparr = np.array([preds, actual])
    # print(save_nparr.shape)
    # pkl.dump([preds, actual], open(save_folder + "/predictions/" + str(fold_i) + "_"+ model_name + "_predictions.pkl", "wb"))
    np.save(save_folder + "/predictions/" + str(fold_i) + "_"+ model_name + "_predictions.np", save_nparr)
    return mses

# Leave-one-motifpair-out cross-validation of DeepMotifSyn Generator

In [33]:
torch.cuda.set_device(1)
device = 'cuda'
found_mp_name, found_mp_family, _, true_mp_code, found_mp_dimer_code = pkl.load(open("../data/found_best_aligned_mp_allFam_correctedFamilyName.pkl", "rb"))
aligned_motifpairs = true_mp_code
family_names = found_mp_family
dimer_names = found_mp_name
dimer_codes = found_mp_dimer_code

unique_dnames = []
for dn in dimer_names:
    unique_dnames.append("_".join(dn.split("_")[:2]))

unique_dnames = np.array(unique_dnames)
print("Tatol number of motif pair:", len(set(unique_dnames)))

# loo_folder = "./leave.KConeDimer.out.uNetAdv.crxvalidate/"
family_names = np.array(family_names)
dimer_names = np.array(dimer_names)
save_folder = "../leave.one.motifpair.out.crxvalidate.generator/"
# save_folder = "./leave.one.motifpair.out.uNetAdv.crxvalidate/"
if not os.path.exists(save_folder):
    os.mkdir(save_folder)

# for train_index, test_index in loo.split(unique_dnames):
n = 0
for fami, udname in enumerate(set(unique_dnames)):
    test_index = unique_dnames == udname
    train_index = unique_dnames != udname
    
    test_family_name = family_names[test_index]
    test_dimer_name = dimer_names[test_index]
    test_unique_dname = unique_dnames[test_index]
    
    reversed_fam = test_family_name[0].split("_")[1] + "_" + test_family_name[0].split("_")[0]
    tune_index = np.arange(len(dimer_names))[np.array(family_names == test_family_name[0])*np.array(dimer_names != test_dimer_name[0])]
    tune_index = np.concatenate([tune_index, np.arange(len(dimer_names))[np.array(family_names == reversed_fam)*np.array(dimer_names != test_dimer_name[0])]])

    fold_name = test_family_name[0] + "-"+ udname
    
    print("-"*20, "DNA-binding Family:", test_family_name[0],"|","Testing motif list:", test_dimer_name, "-"*20)
    # net = Autoencoder_seq2seq().to(device)
    net = deeper_u_net().to(device)
    # net = torch.nn.DataParallel(model, device_ids=[0])
    optimizer = optim.Adam(net.parameters())
    
    
    X_train = aligned_motifpairs[train_index]
    X_tune = aligned_motifpairs[tune_index]
    X_test = aligned_motifpairs[test_index]
    y_train = dimer_codes[train_index]
    y_tune = dimer_codes[tune_index]
    y_test = dimer_codes[test_index]
    
    train_name = dimer_names[train_index]
    test_name = dimer_names[test_index]
    tune_name = dimer_names[tune_index]
    
    print("#Train_motifs", len(X_train), "#Test_motifs", len(X_test))
    
    model_name='u_net'
    train(net, [X_train, y_train], fold_name, model_name=model_name, save_folder=save_folder)
    test(net, [X_test, y_test], fold_name, model_name=model_name, save_folder=save_folder)
    break 

  0%|          | 0/200 [00:00<?, ?it/s]

Tatol number of motif pair: 313
-------------------- DNA-binding Family: ETS_BZIP | Testing motif list: ['ETV2_TEF' 'ETV2_TEF_2'] --------------------
#Train_motifs 612 #Test_motifs 2


100%|██████████| 200/200 [02:20<00:00,  1.42it/s]

valid: MSE 0.010088411159813404
Validation loss decreased (inf --> 0.010088).  Saving model ...
Load the best model from fold_ETS_BZIP-ETV2_TEF_best_u_net_checkpoint.pt
Euclidean Distance Error ETS_BZIP-ETV2_TEF 0.15714800783781233





In [None]:
# reproduce paper figure

# Leave-one-motifpair-out cross-validation of DeepMotifSyn evaluator

In [38]:
gnerated_motif_df = pd.read_csv("../data/generated_motifpairs_with_label_rmDuplicates.csv")
mp_replaced_seqs = np.load("../data/deeper_uNet_loMPo_predictive_all_possible_alignedMP_dedup.npy")
features_784 = np.load("../data/all_generative_dimer_784features_dedup_correctedFam.npy")
mp_replaced_seqs = np.array(mp_replaced_seqs)
labels = np.array(gnerated_motif_df['label'])
binary_labels = labels>0
mp_replaced_seqs = mp_replaced_seqs.reshape(368995, -1)
seqWithNonSeq =  np.concatenate([features_784, mp_replaced_seqs], axis=-1)

In [None]:
save_folder = "../leave.one.motifpair.out.crxvalidate.evaluator/"
if not os.path.exists(save_folder):
    os.mkdir(save_folder)

In [None]:
generated_motif_df = pd.read_csv("../data/generated_motifpairs_with_label_rmDuplicates.csv")
print(len(generated_motif_df))
# fig, ax = plt.subplots()
tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)
fold_preds = {}
dimer_names = generated_motif_df['dimer_name']
for fold_i, dn in enumerate(set(dimer_names)):        
    print("-"*10, fold_i, dn, "-"*10)
    train_index = np.arange(len(dimer_names))[dimer_names != dn]
    test_index = np.arange(len(dimer_names))[dimer_names == dn]
    
    X_train = seqWithNonSeq[train_index]
    X_test = seqWithNonSeq[test_index]
    y_train = binary_labels[train_index]
    y_test = binary_labels[test_index]
    print('#Train:', train.shape,'#Test:', y_test.shape)
    generated_motif_df.loc[test_index, 'test_dimer_name'] = dn
    
    # classifier
    xgb = XGBClassifier(subsample=1.0, n_estimators=200, min_child_weight=1, max_depth=5, learning_rate=0.0525, gamma=5, colsample_bytree=0.8, n_jobs=20)
    
    # start_time = time.time()
    xgb.fit(X_train, y_train)
    y_pred = rf.predict_proba(X_test)
    generated_motif_df.loc[test_index, 'xgboost_prediction'] = y_pred[:,1]
    # dump(xgb, './fold_' + dn + "_XGBoost_bestHyper_924features.joblib")
    # print('model saved!')
    # print("--- %s seconds ---" % (time.time() - start_time))
    fold_preds[dn] = np.array(list(zip(y_pred[:,1], y_test, test_index)))

    fpr, tpr, thresholds_roc = metrics.roc_curve(y_test, y_pred[:,1])
    roc_auc = metrics.auc(fpr, tpr)

    average_precision=average_precision_score(y_test, y_pred[:,1])
    print(roc_auc, average_precision)

368995
---------- 0 ETV2_TEF ----------
(367891, 924) (1104,)
