In [4]:
import time
import random
import itertools
import collections
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

from otmtd.utils.distance_otdd import OTDD
from scipy.spatial import distance


class EmbeddingDataset(Dataset):
    """Map-style dataset pairing pre-computed protein embeddings with labels.

    Besides ``__len__`` / ``__getitem__`` the instance exposes ``classes`` and
    ``targets``, the dataset attributes required by the OTDD code.

    Args:
        task: task name. ``'pre_train'`` uses the column named ``pt_label`` as
            the label column; ``'pre_train'`` and ``'kinase'`` read a single
            embedding file, every other task reads a train and a valid file.
        text_path: CSV file with a ``uniprot_id`` column and a label column.
        train_embs_path: pickled DataFrame with ``pro_id`` and ``pro_emb`` columns.
        valid_embs_path: second pickled DataFrame with the same columns; only
            read when ``task`` is neither ``'pre_train'`` nor ``'kinase'``.
        pt_label: column of the pre-training CSV that is renamed to ``label``.
    """

    def __init__(self, task, text_path, train_embs_path, valid_embs_path=None, pt_label='domain'):
        embs, labels, unique_labels = self.load_embs_and_labels(task, text_path, train_embs_path, valid_embs_path, pt_label)
        if task != 'pre_train':
            print("{} unique labels: {}".format(task.capitalize(), len(unique_labels)))
        else:
            print("PRE-TRAIN {} unique labels: {}".format(pt_label, len(unique_labels)))
        # Label frequencies, most frequent first.
        sorted_counter = dict(collections.Counter(labels).most_common())
        print(sorted_counter)
        self.unique_labels = unique_labels

        # Column-wise concat pairs embedding i with label i through the index.
        # NOTE(review): embeddings keep the row order of the pickle file(s) while
        # labels keep the row order of the CSV, so this assumes both list the
        # proteins in the same order -- confirm against how the files are produced.
        self.embs_and_labels = pd.concat([embs, labels], axis=1)
        # Attributes required by OTDD on a dataset.
        self.classes = [str(k) for k in unique_labels]  # unique labels as strings
        self.targets = torch.tensor(list(labels))

    def __len__(self):
        """Number of labelled samples."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(embedding reshaped to (1, -1, 1), label)`` as tensors."""
        row = self.embs_and_labels.iloc[idx]
        return torch.tensor(row['pro_emb']).reshape(1, -1, 1), torch.tensor(row['label'])

    def load_embs_and_labels(self, task, text_path, train_embs_path, valid_embs_path, pt_label):
        """Read labels and embeddings from disk.

        Returns:
            ``(embs, labels, unique_labels)``: the ``pro_emb`` Series restricted
            to proteins whose ``pro_id`` occurs in the CSV's ``uniprot_id``
            column (re-indexed from 0), the ``label`` Series of the CSV, and
            the array of its unique values.
        """
        texts = pd.read_csv(text_path)
        if task == 'pre_train':
            texts = texts.rename(columns={pt_label: 'label'})
        labels = texts['label']
        unique_labels = labels.unique()

        if task in ('pre_train', 'kinase'):  # only a `train` embedding file exists
            embs = pd.read_pickle(train_embs_path)
        else:
            embs = pd.concat([pd.read_pickle(train_embs_path), pd.read_pickle(valid_embs_path)], axis=0)

        # Vectorised membership test; a per-row `x in list` lookup is quadratic.
        selected_embs_flag = embs['pro_id'].isin(texts['uniprot_id'])
        embs = embs[selected_embs_flag].reset_index(drop=True)

        return embs['pro_emb'], labels, unique_labels
    
def set_random_seed(SEED):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        SEED: integer seed applied to ``random``, ``numpy.random`` and
            ``torch`` (CPU generator and every CUDA device).
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # Seeds all GPUs, not just the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(SEED)

def cal_combs_distances_mean(all_pt_tasks, t2t_distances_list):
    """Average each pre-training task's distance row over the frames it occurs in.

    Args:
        all_pt_tasks: pre-training task names to report; fixes the output row order.
        t2t_distances_list: list of DataFrames. Each has a ``'pt_task'`` column
            naming the row's pre-training task; every other column must be
            castable to float.

    Returns:
        ``float32`` array with one row per entry of ``all_pt_tasks``, holding
        the column-wise mean of that task's rows across all frames containing it.

    Raises:
        ValueError: if a task of ``all_pt_tasks`` occurs in none of the frames.
    """
    mean_t2t_distances = []
    for pt_task in all_pt_tasks:
        spec_pt_task_distance = []
        for t2t_distances in t2t_distances_list:
            matched = t2t_distances[t2t_distances['pt_task'] == pt_task]
            if len(matched) > 0:
                # Drop the name column by label so its position does not matter.
                values = matched.drop(columns='pt_task').to_numpy()
                spec_pt_task_distance.append(values.reshape(1, -1))
        if not spec_pt_task_distance:
            raise ValueError(
                "pre-training task {!r} not found in any distance table".format(pt_task)
            )
        stacked = np.concatenate(spec_pt_task_distance, axis=0).astype(np.float32)
        mean_t2t_distances.append(stacked.mean(axis=0))
    return np.stack(mean_t2t_distances, axis=0)

In [6]:
# ---- experiment configuration ----
num_samples = 1000
TaskCombs = ['MLM+RMD', 'GO+RMD', 'MLM+GO+D', 'RMD']
# pt_tasks_combs[k] lists the pre-training label tasks of TaskCombs[k]
pt_tasks_combs = [
    ['mlm', 'domain', 'motif', 'region'],
    ['domain', 'motif', 'region', 'go'],
    ['mlm', 'domain', 'go'],
    ['domain', 'motif', 'region'],
]
all_pt_tasks = ['PT-mlm', 'PT-domain', 'PT-motif', 'PT-region', 'PT-go']
device = torch.device('cpu')
p = 2

# ---- data loading ----
texts_base_dir = "/home/brian/work/OTMTD_GH/processed_data"
embs_base_dir = "/home/brian/work/OTMTD_GH/protein_embeddings_MultiTasks"
data_stuffs_dict = {}
for i, comb in enumerate(TaskCombs):
    print(f"Start loading data for Pretrain-MultiTask: {comb}")
    # Build the embedding hub for this multi-task combination:
    # task -> [label text file, train embeddings, (valid embeddings)]
    tasks_texts_embs_hub = {
        'pre_train': [
            f"{texts_base_dir}/pre_train/sampling_set.txt",
            f"{embs_base_dir}/pre_train_combs/pre_train_{comb}_pro_embs_pt.pkl",
        ],
        'stability': [
            f"{texts_base_dir}/stability/sequence_go_label_{num_samples}.txt",
            f"{embs_base_dir}/stability/{comb}->stability_pretrain_pro_embs_train.pkl",
            f"{embs_base_dir}/stability/{comb}->stability_pretrain_pro_embs_valid.pkl",
        ],
        'fluoresecence': [
            f"{texts_base_dir}/fluorescence/sequence_go_label_{num_samples}.txt",
            f"{embs_base_dir}/fluorescence/{comb}->fluorescence_pretrain_pro_embs_train.pkl",
            f"{embs_base_dir}/fluorescence/{comb}->fluorescence_pretrain_pro_embs_valid.pkl",
        ],
        'remote_homology': [
            f"{texts_base_dir}/remote_homology/sequence_go_label_{num_samples}.txt",
            f"{embs_base_dir}/remote_homology/{comb}->remote_homology_pretrain_pro_embs_train.pkl",
            f"{embs_base_dir}/remote_homology/{comb}->remote_homology_pretrain_pro_embs_valid.pkl",
        ],
        'secondary_structrue': [
            f"{texts_base_dir}/secondary_structure/sequence_go_label_{num_samples}.txt",
            f"{embs_base_dir}/secondary_structure/{comb}->secondary_structure_pretrain_pro_embs_train.pkl",
            f"{embs_base_dir}/secondary_structure/{comb}->secondary_structure_pretrain_pro_embs_valid.pkl",
        ],
        'pdbbind': [
            f"{texts_base_dir}/pdbbind/sequence_go_label_{num_samples}.txt",
            f"{embs_base_dir}/pdbbind/{comb}->pdbbind_pretrain_pro_embs_train.pkl",
            f"{embs_base_dir}/pdbbind/{comb}->pdbbind_pretrain_pro_embs_dev.pkl",
        ],
        'kinase': [
            f"{texts_base_dir}/kinase/sequence_go_label_{num_samples}.txt",
            f"{embs_base_dir}/kinase//{comb}->kinase_pretrain_pro_embs_train.pkl",
        ],
    }

    # One dataset per pre-training label task: same embeddings, different label column.
    pt_datasets, pt_names = [], []
    pt_text_path, pt_embs_path = tasks_texts_embs_hub['pre_train']
    for label_task in pt_tasks_combs[i]:
        dataset = EmbeddingDataset('pre_train', pt_text_path, pt_embs_path, pt_label=label_task)
        pt_datasets += [dataset]
        pt_names += [f'PT-{label_task}']

    # One dataset per downstream (fine-tuning) task.
    ft_datasets, ft_names = [], []
    for task in ('stability', 'fluoresecence', 'remote_homology', 'secondary_structrue', 'pdbbind', 'kinase'):
        dataset = EmbeddingDataset(task, *tasks_texts_embs_hub[task])
        ft_datasets += [dataset]
        ft_names += [task.capitalize()]

    # Position of every label task inside this combination.
    pt_tasks_dict = dict(zip(pt_tasks_combs[i], range(len(pt_tasks_combs[i]))))

    data_stuff = [pt_datasets, pt_names, ft_datasets, ft_names, pt_tasks_dict]
    data_stuffs_dict[comb] = data_stuff

Start loading data for Pretrain-MultiTask: MLM+RMD
PRE-TRAIN mlm unique labels: 15
{2: 148, 10: 117, 6: 116, 8: 101, 9: 98, 4: 95, 5: 59, 14: 54, 1: 49, 3: 47, 7: 46, 12: 45, 0: 45, 11: 13, 13: 12}
PRE-TRAIN domain unique labels: 13
{0: 173, 1: 165, 2: 121, 3: 89, 4: 88, 5: 64, 6: 59, 7: 53, 8: 50, 9: 47, 10: 46, 11: 45, 12: 45}
PRE-TRAIN motif unique labels: 5
{0: 843, 1: 59, 2: 50, 3: 47, 4: 46}
PRE-TRAIN region unique labels: 10
{0: 202, 1: 173, 2: 165, 3: 121, 4: 89, 5: 88, 6: 64, 7: 53, 8: 45, 9: 45}
Stability unique labels: 12
{0: 101, 1: 87, 2: 83, 3: 57, 4: 53, 5: 34, 6: 27, 7: 20, 8: 14, 9: 11, 10: 7, 11: 6}
Fluoresecence unique labels: 12
{0: 142, 1: 83, 2: 80, 3: 48, 4: 38, 5: 35, 6: 20, 7: 20, 8: 14, 9: 9, 10: 6, 11: 5}
Remote_homology unique labels: 16
{0: 136, 1: 126, 2: 96, 3: 26, 4: 20, 5: 14, 6: 11, 7: 11, 8: 9, 9: 9, 10: 8, 11: 8, 12: 7, 13: 7, 14: 6, 15: 6}
Secondary_structrue unique labels: 15
{9: 55, 3: 50, 1: 46, 14: 45, 8: 42, 7: 38, 13: 38, 0: 32, 6: 30, 11: 28,

# Baseline 1: 独立mask平均OTDD

In [10]:
# Per-combination OTDD tables; collapsed to one mean array after the loop.
bs1_otdd_t2t_distances = []
for i, comb in enumerate(TaskCombs):
    print(f"Start computing for Pretrain-MultiTask: {comb}")
    data_stuff = data_stuffs_dict[comb]
    pt_datasets, pt_names, ft_datasets, ft_names, pt_tasks_dict = data_stuff

    OTdd = OTDD(
        device,
        pt_names=pt_names,
        ft_names=ft_names,
        pt_tasks_combs=[pt_tasks_combs[i]],
        pt_tasks_dict=pt_tasks_dict,
        gaussian_assumption=True,
    )
    # One row per pre-training task, one column per fine-tuning task,
    # plus a trailing 'pt_task' column naming the row.
    bs1_otdd_t2t_d = pd.DataFrame(
        OTdd.t2t_distance(pt_datasets, ft_datasets, p=p),
        columns=ft_names,
    ).assign(pt_task=pt_names)
    bs1_otdd_t2t_distances.append(bs1_otdd_t2t_d)
    print('==== x ==== ' * 5)

# Average every pre-training task over the combinations that contain it, then persist.
bs1_otdd_t2t_distances = cal_combs_distances_mean(all_pt_tasks, bs1_otdd_t2t_distances)
np.save("./results(all)/bs1_otdd_t2t_distances.npy", bs1_otdd_t2t_distances)

Start computing for Pretrain-MultiTask: MLM+RMD
Computing inter label-to-label distance for PT-mlm & Stability


12it [00:00, 19.65it/s]
Computing label-to-label distances: 100%|██████████| 180/180 [00:09<00:00, 19.61it/s]


OTDD distance of PT-mlm and Stability is: 679.25
Computing inter label-to-label distance for PT-mlm & Fluoresecence


12it [00:00, 19.56it/s]
Computing label-to-label distances: 100%|██████████| 180/180 [00:09<00:00, 19.82it/s]


OTDD distance of PT-mlm and Fluoresecence is: 1095.73
Computing inter label-to-label distance for PT-mlm & Remote_homology


15it [00:00, 19.99it/s]
Computing label-to-label distances: 100%|██████████| 240/240 [00:12<00:00, 19.76it/s]


OTDD distance of PT-mlm and Remote_homology is: 601.37
Computing inter label-to-label distance for PT-mlm & Secondary_structrue


15it [00:00, 20.10it/s]
Computing label-to-label distances: 100%|██████████| 225/225 [00:11<00:00, 19.65it/s]


OTDD distance of PT-mlm and Secondary_structrue is: 499.54
Computing inter label-to-label distance for PT-mlm & Pdbbind


11it [00:00, 19.30it/s]
Computing label-to-label distances: 100%|██████████| 165/165 [00:08<00:00, 19.54it/s]


OTDD distance of PT-mlm and Pdbbind is: 565.86
Computing inter label-to-label distance for PT-mlm & Kinase


2it [00:00, 15.13it/s]
Computing label-to-label distances: 100%|██████████| 30/30 [00:01<00:00, 20.12it/s]


OTDD distance of PT-mlm and Kinase is: 884.16
Computing inter label-to-label distance for PT-domain & Stability


12it [00:00, 18.34it/s]
Computing label-to-label distances: 100%|██████████| 156/156 [00:08<00:00, 19.41it/s]


OTDD distance of PT-domain and Stability is: 679.38
Computing inter label-to-label distance for PT-domain & Fluoresecence


12it [00:00, 20.09it/s]
Computing label-to-label distances: 100%|██████████| 156/156 [00:07<00:00, 19.79it/s]


OTDD distance of PT-domain and Fluoresecence is: 1095.76
Computing inter label-to-label distance for PT-domain & Remote_homology


13it [00:00, 19.57it/s]
Computing label-to-label distances: 100%|██████████| 208/208 [00:10<00:00, 19.31it/s]


OTDD distance of PT-domain and Remote_homology is: 601.31
Computing inter label-to-label distance for PT-domain & Secondary_structrue


13it [00:00, 19.60it/s]
Computing label-to-label distances: 100%|██████████| 195/195 [00:10<00:00, 19.33it/s]


OTDD distance of PT-domain and Secondary_structrue is: 499.54
Computing inter label-to-label distance for PT-domain & Pdbbind


11it [00:00, 18.57it/s]
Computing label-to-label distances: 100%|██████████| 143/143 [00:07<00:00, 20.23it/s]


OTDD distance of PT-domain and Pdbbind is: 566.07
Computing inter label-to-label distance for PT-domain & Kinase


2it [00:00, 15.27it/s]
Computing label-to-label distances: 100%|██████████| 26/26 [00:01<00:00, 20.17it/s]


OTDD distance of PT-domain and Kinase is: 884.24
Computing inter label-to-label distance for PT-motif & Stability


5it [00:00, 21.38it/s]
Computing label-to-label distances: 100%|██████████| 60/60 [00:03<00:00, 19.70it/s]


OTDD distance of PT-motif and Stability is: 657.41
Computing inter label-to-label distance for PT-motif & Fluoresecence


5it [00:00, 18.69it/s]
Computing label-to-label distances: 100%|██████████| 60/60 [00:03<00:00, 19.21it/s]


OTDD distance of PT-motif and Fluoresecence is: 1092.46
Computing inter label-to-label distance for PT-motif & Remote_homology


5it [00:00, 14.61it/s]
Computing label-to-label distances: 100%|██████████| 80/80 [00:04<00:00, 18.73it/s]


OTDD distance of PT-motif and Remote_homology is: 585.95
Computing inter label-to-label distance for PT-motif & Secondary_structrue


5it [00:00, 16.61it/s]
Computing label-to-label distances: 100%|██████████| 75/75 [00:03<00:00, 19.85it/s]


OTDD distance of PT-motif and Secondary_structrue is: 493.49
Computing inter label-to-label distance for PT-motif & Pdbbind


5it [00:00, 20.26it/s]
Computing label-to-label distances: 100%|██████████| 55/55 [00:02<00:00, 19.92it/s]


OTDD distance of PT-motif and Pdbbind is: 537.94
Computing inter label-to-label distance for PT-motif & Kinase


2it [00:00, 16.70it/s]
Computing label-to-label distances: 100%|██████████| 10/10 [00:00<00:00, 19.92it/s]


OTDD distance of PT-motif and Kinase is: 860.33
Computing inter label-to-label distance for PT-region & Stability


10it [00:00, 19.27it/s]
Computing label-to-label distances: 100%|██████████| 120/120 [00:06<00:00, 19.29it/s]


OTDD distance of PT-region and Stability is: 673.54
Computing inter label-to-label distance for PT-region & Fluoresecence


10it [00:00, 20.33it/s]
Computing label-to-label distances: 100%|██████████| 120/120 [00:06<00:00, 19.92it/s]


OTDD distance of PT-region and Fluoresecence is: 1094.74
Computing inter label-to-label distance for PT-region & Remote_homology


10it [00:00, 16.61it/s]
Computing label-to-label distances: 100%|██████████| 160/160 [00:08<00:00, 19.94it/s]


OTDD distance of PT-region and Remote_homology is: 597.72
Computing inter label-to-label distance for PT-region & Secondary_structrue


10it [00:00, 18.11it/s]
Computing label-to-label distances: 100%|██████████| 150/150 [00:07<00:00, 20.35it/s]


OTDD distance of PT-region and Secondary_structrue is: 498.53
Computing inter label-to-label distance for PT-region & Pdbbind


10it [00:00, 19.20it/s]
Computing label-to-label distances: 100%|██████████| 110/110 [00:05<00:00, 19.53it/s]


OTDD distance of PT-region and Pdbbind is: 559.33
Computing inter label-to-label distance for PT-region & Kinase


2it [00:00, 12.13it/s]
Computing label-to-label distances: 100%|██████████| 20/20 [00:01<00:00, 19.87it/s]


OTDD distance of PT-region and Kinase is: 879.69
==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 
Start computing for Pretrain-MultiTask: GO+RMD
Computing inter label-to-label distance for PT-domain & Stability


12it [00:00, 19.50it/s]
Computing label-to-label distances: 100%|██████████| 156/156 [00:07<00:00, 19.82it/s]


OTDD distance of PT-domain and Stability is: 933.33
Computing inter label-to-label distance for PT-domain & Fluoresecence


12it [00:00, 16.15it/s]
Computing label-to-label distances: 100%|██████████| 156/156 [00:07<00:00, 19.50it/s]


OTDD distance of PT-domain and Fluoresecence is: 1212.19
Computing inter label-to-label distance for PT-domain & Remote_homology


13it [00:00, 19.95it/s]
Computing label-to-label distances: 100%|██████████| 208/208 [00:10<00:00, 19.10it/s]


OTDD distance of PT-domain and Remote_homology is: 818.60
Computing inter label-to-label distance for PT-domain & Secondary_structrue


13it [00:00, 17.86it/s]
Computing label-to-label distances: 100%|██████████| 195/195 [00:09<00:00, 19.67it/s]


OTDD distance of PT-domain and Secondary_structrue is: 676.42
Computing inter label-to-label distance for PT-domain & Pdbbind


11it [00:00, 20.71it/s]
Computing label-to-label distances: 100%|██████████| 143/143 [00:07<00:00, 19.82it/s]


OTDD distance of PT-domain and Pdbbind is: 832.53
Computing inter label-to-label distance for PT-domain & Kinase


2it [00:00, 17.35it/s]
Computing label-to-label distances: 100%|██████████| 26/26 [00:01<00:00, 19.22it/s]


OTDD distance of PT-domain and Kinase is: 1076.32
Computing inter label-to-label distance for PT-motif & Stability


5it [00:00, 18.42it/s]
Computing label-to-label distances: 100%|██████████| 60/60 [00:03<00:00, 19.64it/s]


OTDD distance of PT-motif and Stability is: 902.21
Computing inter label-to-label distance for PT-motif & Fluoresecence


5it [00:00, 19.78it/s]
Computing label-to-label distances: 100%|██████████| 60/60 [00:03<00:00, 19.70it/s]


OTDD distance of PT-motif and Fluoresecence is: 1205.86
Computing inter label-to-label distance for PT-motif & Remote_homology


5it [00:00, 13.77it/s]
Computing label-to-label distances: 100%|██████████| 80/80 [00:04<00:00, 19.89it/s]


OTDD distance of PT-motif and Remote_homology is: 788.15
Computing inter label-to-label distance for PT-motif & Secondary_structrue


5it [00:00, 19.50it/s]
Computing label-to-label distances: 100%|██████████| 75/75 [00:03<00:00, 19.90it/s]


OTDD distance of PT-motif and Secondary_structrue is: 663.63
Computing inter label-to-label distance for PT-motif & Pdbbind


5it [00:00, 17.37it/s]
Computing label-to-label distances: 100%|██████████| 55/55 [00:02<00:00, 19.98it/s]


OTDD distance of PT-motif and Pdbbind is: 798.80
Computing inter label-to-label distance for PT-motif & Kinase


2it [00:00, 14.60it/s]
Computing label-to-label distances: 100%|██████████| 10/10 [00:00<00:00, 18.77it/s]


OTDD distance of PT-motif and Kinase is: 1053.88
Computing inter label-to-label distance for PT-region & Stability


10it [00:00, 18.23it/s]
Computing label-to-label distances: 100%|██████████| 120/120 [00:05<00:00, 20.42it/s]


OTDD distance of PT-region and Stability is: 924.54
Computing inter label-to-label distance for PT-region & Fluoresecence


10it [00:00, 17.76it/s]
Computing label-to-label distances: 100%|██████████| 120/120 [00:06<00:00, 19.59it/s]


OTDD distance of PT-region and Fluoresecence is: 1211.09
Computing inter label-to-label distance for PT-region & Remote_homology


10it [00:00, 19.01it/s]
Computing label-to-label distances: 100%|██████████| 160/160 [00:08<00:00, 19.76it/s]


OTDD distance of PT-region and Remote_homology is: 810.16
Computing inter label-to-label distance for PT-region & Secondary_structrue


10it [00:00, 19.96it/s]
Computing label-to-label distances: 100%|██████████| 150/150 [00:07<00:00, 19.78it/s]


OTDD distance of PT-region and Secondary_structrue is: 673.02
Computing inter label-to-label distance for PT-region & Pdbbind


10it [00:00, 19.41it/s]
Computing label-to-label distances: 100%|██████████| 110/110 [00:05<00:00, 20.25it/s]


OTDD distance of PT-region and Pdbbind is: 823.13
Computing inter label-to-label distance for PT-region & Kinase


2it [00:00, 16.96it/s]
Computing label-to-label distances: 100%|██████████| 20/20 [00:00<00:00, 20.26it/s]


OTDD distance of PT-region and Kinase is: 1069.91
Computing inter label-to-label distance for PT-go & Stability


9it [00:00, 19.57it/s]
Computing label-to-label distances: 100%|██████████| 108/108 [00:05<00:00, 19.75it/s]


OTDD distance of PT-go and Stability is: 930.79
Computing inter label-to-label distance for PT-go & Fluoresecence


9it [00:00, 19.98it/s]
Computing label-to-label distances: 100%|██████████| 108/108 [00:05<00:00, 19.41it/s]


OTDD distance of PT-go and Fluoresecence is: 1211.81
Computing inter label-to-label distance for PT-go & Remote_homology


9it [00:00, 19.84it/s]
Computing label-to-label distances: 100%|██████████| 144/144 [00:07<00:00, 20.10it/s]


OTDD distance of PT-go and Remote_homology is: 815.89
Computing inter label-to-label distance for PT-go & Secondary_structrue


9it [00:00, 19.66it/s]
Computing label-to-label distances: 100%|██████████| 135/135 [00:06<00:00, 20.00it/s]


OTDD distance of PT-go and Secondary_structrue is: 675.93
Computing inter label-to-label distance for PT-go & Pdbbind


9it [00:00, 18.20it/s]
Computing label-to-label distances: 100%|██████████| 99/99 [00:05<00:00, 19.47it/s]


OTDD distance of PT-go and Pdbbind is: 829.69
Computing inter label-to-label distance for PT-go & Kinase


2it [00:00, 17.76it/s]
Computing label-to-label distances: 100%|██████████| 18/18 [00:00<00:00, 19.34it/s]


OTDD distance of PT-go and Kinase is: 1074.55
==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 
Start computing for Pretrain-MultiTask: MLM+GO+D
Computing inter label-to-label distance for PT-mlm & Stability


12it [00:00, 20.06it/s]
Computing label-to-label distances: 100%|██████████| 180/180 [00:08<00:00, 20.17it/s]


OTDD distance of PT-mlm and Stability is: 900.56
Computing inter label-to-label distance for PT-mlm & Fluoresecence


12it [00:00, 20.05it/s]
Computing label-to-label distances: 100%|██████████| 180/180 [00:08<00:00, 20.09it/s]


OTDD distance of PT-mlm and Fluoresecence is: 1303.75
Computing inter label-to-label distance for PT-mlm & Remote_homology


15it [00:00, 19.70it/s]
Computing label-to-label distances: 100%|██████████| 240/240 [00:12<00:00, 19.97it/s]


OTDD distance of PT-mlm and Remote_homology is: 866.70
Computing inter label-to-label distance for PT-mlm & Secondary_structrue


15it [00:00, 20.41it/s]
Computing label-to-label distances: 100%|██████████| 225/225 [00:11<00:00, 19.37it/s]


OTDD distance of PT-mlm and Secondary_structrue is: 610.91
Computing inter label-to-label distance for PT-mlm & Pdbbind


11it [00:00, 18.83it/s]
Computing label-to-label distances: 100%|██████████| 165/165 [00:08<00:00, 19.93it/s]


OTDD distance of PT-mlm and Pdbbind is: 813.91
Computing inter label-to-label distance for PT-mlm & Kinase


2it [00:00, 14.34it/s]
Computing label-to-label distances: 100%|██████████| 30/30 [00:01<00:00, 19.33it/s]


OTDD distance of PT-mlm and Kinase is: 1090.28
Computing inter label-to-label distance for PT-domain & Stability


12it [00:00, 18.92it/s]
Computing label-to-label distances: 100%|██████████| 156/156 [00:07<00:00, 19.66it/s]


OTDD distance of PT-domain and Stability is: 900.91
Computing inter label-to-label distance for PT-domain & Fluoresecence


12it [00:00, 20.00it/s]
Computing label-to-label distances: 100%|██████████| 156/156 [00:07<00:00, 19.95it/s]


OTDD distance of PT-domain and Fluoresecence is: 1303.92
Computing inter label-to-label distance for PT-domain & Remote_homology


13it [00:00, 18.54it/s]
Computing label-to-label distances: 100%|██████████| 208/208 [00:10<00:00, 19.87it/s]


OTDD distance of PT-domain and Remote_homology is: 866.81
Computing inter label-to-label distance for PT-domain & Secondary_structrue


13it [00:00, 18.98it/s]
Computing label-to-label distances: 100%|██████████| 195/195 [00:09<00:00, 19.78it/s]


OTDD distance of PT-domain and Secondary_structrue is: 610.93
Computing inter label-to-label distance for PT-domain & Pdbbind


11it [00:00, 18.97it/s]
Computing label-to-label distances: 100%|██████████| 143/143 [00:07<00:00, 19.53it/s]


OTDD distance of PT-domain and Pdbbind is: 814.46
Computing inter label-to-label distance for PT-domain & Kinase


2it [00:00, 16.12it/s]
Computing label-to-label distances: 100%|██████████| 26/26 [00:01<00:00, 20.07it/s]


OTDD distance of PT-domain and Kinase is: 1091.54
Computing inter label-to-label distance for PT-go & Stability


9it [00:00, 19.78it/s]
Computing label-to-label distances: 100%|██████████| 108/108 [00:05<00:00, 18.93it/s]


OTDD distance of PT-go and Stability is: 898.03
Computing inter label-to-label distance for PT-go & Fluoresecence


9it [00:00, 19.24it/s]
Computing label-to-label distances: 100%|██████████| 108/108 [00:05<00:00, 20.79it/s]


OTDD distance of PT-go and Fluoresecence is: 1303.52
Computing inter label-to-label distance for PT-go & Remote_homology


9it [00:00, 18.53it/s]
Computing label-to-label distances: 100%|██████████| 144/144 [00:07<00:00, 19.99it/s]


OTDD distance of PT-go and Remote_homology is: 864.93
Computing inter label-to-label distance for PT-go & Secondary_structrue


9it [00:00, 19.19it/s]
Computing label-to-label distances: 100%|██████████| 135/135 [00:06<00:00, 20.16it/s]


OTDD distance of PT-go and Secondary_structrue is: 610.35
Computing inter label-to-label distance for PT-go & Pdbbind


9it [00:00, 19.09it/s]
Computing label-to-label distances: 100%|██████████| 99/99 [00:05<00:00, 19.61it/s]


OTDD distance of PT-go and Pdbbind is: 811.32
Computing inter label-to-label distance for PT-go & Kinase


2it [00:00, 18.10it/s]
Computing label-to-label distances: 100%|██████████| 18/18 [00:00<00:00, 19.97it/s]


OTDD distance of PT-go and Kinase is: 1090.12
==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 
Start computing for Pretrain-MultiTask: RMD
Computing inter label-to-label distance for PT-domain & Stability


12it [00:00, 20.08it/s]
Computing label-to-label distances: 100%|██████████| 156/156 [00:07<00:00, 19.73it/s]


OTDD distance of PT-domain and Stability is: 700.09
Computing inter label-to-label distance for PT-domain & Fluoresecence


12it [00:00, 18.08it/s]
Computing label-to-label distances: 100%|██████████| 156/156 [00:07<00:00, 19.85it/s]


OTDD distance of PT-domain and Fluoresecence is: 1124.13
Computing inter label-to-label distance for PT-domain & Remote_homology


13it [00:00, 20.06it/s]
Computing label-to-label distances: 100%|██████████| 208/208 [00:10<00:00, 19.16it/s]


OTDD distance of PT-domain and Remote_homology is: 719.22
Computing inter label-to-label distance for PT-domain & Secondary_structrue


13it [00:00, 20.24it/s]
Computing label-to-label distances: 100%|██████████| 195/195 [00:09<00:00, 19.79it/s]


OTDD distance of PT-domain and Secondary_structrue is: 539.36
Computing inter label-to-label distance for PT-domain & Pdbbind


11it [00:00, 19.70it/s]
Computing label-to-label distances: 100%|██████████| 143/143 [00:07<00:00, 20.04it/s]


OTDD distance of PT-domain and Pdbbind is: 668.87
Computing inter label-to-label distance for PT-domain & Kinase


2it [00:00, 18.87it/s]
Computing label-to-label distances: 100%|██████████| 26/26 [00:01<00:00, 19.97it/s]


OTDD distance of PT-domain and Kinase is: 900.26
Computing inter label-to-label distance for PT-motif & Stability


5it [00:00, 19.55it/s]
Computing label-to-label distances: 100%|██████████| 60/60 [00:03<00:00, 20.00it/s]


OTDD distance of PT-motif and Stability is: 675.46
Computing inter label-to-label distance for PT-motif & Fluoresecence


5it [00:00, 19.87it/s]
Computing label-to-label distances: 100%|██████████| 60/60 [00:02<00:00, 20.40it/s]


OTDD distance of PT-motif and Fluoresecence is: 1119.98
Computing inter label-to-label distance for PT-motif & Remote_homology


5it [00:00, 19.45it/s]
Computing label-to-label distances: 100%|██████████| 80/80 [00:03<00:00, 20.31it/s]


OTDD distance of PT-motif and Remote_homology is: 702.96
Computing inter label-to-label distance for PT-motif & Secondary_structrue


5it [00:00, 20.14it/s]
Computing label-to-label distances: 100%|██████████| 75/75 [00:03<00:00, 19.88it/s]


OTDD distance of PT-motif and Secondary_structrue is: 530.50
Computing inter label-to-label distance for PT-motif & Pdbbind


5it [00:00, 17.22it/s]
Computing label-to-label distances: 100%|██████████| 55/55 [00:02<00:00, 19.84it/s]


OTDD distance of PT-motif and Pdbbind is: 639.32
Computing inter label-to-label distance for PT-motif & Kinase


2it [00:00, 16.91it/s]
Computing label-to-label distances: 100%|██████████| 10/10 [00:00<00:00, 19.28it/s]


OTDD distance of PT-motif and Kinase is: 875.68
Computing inter label-to-label distance for PT-region & Stability


10it [00:00, 18.78it/s]
Computing label-to-label distances: 100%|██████████| 120/120 [00:05<00:00, 20.23it/s]


OTDD distance of PT-region and Stability is: 692.13
Computing inter label-to-label distance for PT-region & Fluoresecence


10it [00:00, 19.57it/s]
Computing label-to-label distances: 100%|██████████| 120/120 [00:05<00:00, 20.26it/s]


OTDD distance of PT-region and Fluoresecence is: 1123.09
Computing inter label-to-label distance for PT-region & Remote_homology


10it [00:00, 19.06it/s]
Computing label-to-label distances: 100%|██████████| 160/160 [00:07<00:00, 20.48it/s]


OTDD distance of PT-region and Remote_homology is: 714.94
Computing inter label-to-label distance for PT-region & Secondary_structrue


10it [00:00, 17.43it/s]
Computing label-to-label distances: 100%|██████████| 150/150 [00:07<00:00, 20.17it/s]


OTDD distance of PT-region and Secondary_structrue is: 538.57
Computing inter label-to-label distance for PT-region & Pdbbind


10it [00:00, 19.17it/s]
Computing label-to-label distances: 100%|██████████| 110/110 [00:05<00:00, 19.42it/s]


OTDD distance of PT-region and Pdbbind is: 661.91
Computing inter label-to-label distance for PT-region & Kinase


2it [00:00, 15.65it/s]
Computing label-to-label distances: 100%|██████████| 20/20 [00:01<00:00, 18.91it/s]


OTDD distance of PT-region and Kinase is: 892.50
==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 


In [6]:
# Display the mean OTDD distances (rows follow all_pt_tasks).
# NOTE(review): this cell's execution count precedes the cell that computes the
# array, and the output below has 5 columns although 6 fine-tuning tasks are
# loaded -- the shown output is probably from an earlier run; re-run to confirm.
bs1_otdd_t2t_distances

array([[ 803.197  , 1183.9417 ,  751.51904,  581.4503 ,  720.16187],
       [ 803.4283 , 1184.0001 ,  751.4855 ,  581.5658 ,  720.48206],
       [ 776.43085, 1178.9626 ,  731.1835 ,  572.58997,  689.1713 ],
       [ 795.9525 , 1182.8298 ,  746.21716,  579.7397 ,  712.3189 ],
       [ 801.0803 , 1183.6763 ,  749.661  ,  581.1654 ,  718.07465]],
      dtype=float32)

# Baseline 2: 独立平均OTNCE

In [11]:
from otce.otce import OTNCE


def series_to_array(x):
    """Stack a Series of per-sample values into one ndarray (first axis = samples)."""
    return np.stack(x.tolist(), axis=0)


# Per-combination OTNCE tables; collapsed to one mean array after the loop.
bs2_otnce_t2t_distances = []
for i, comb in enumerate(TaskCombs):
    print("Start computing for Pretrain-MultiTask: {}".format(comb))
    data_stuff = data_stuffs_dict[comb]
    pt_datasets, pt_names, ft_datasets, ft_names, pt_tasks_dict = data_stuff

    # name -> (embedding matrix, label vector); zip avoids re-using the outer
    # loop index `i` for the inner loops.
    pt_data_dict = {}
    for pt_name, pt_dataset in zip(pt_names, pt_datasets):
        X_Y = pt_dataset.embs_and_labels
        pt_data_dict[pt_name] = (series_to_array(X_Y['pro_emb']), series_to_array(X_Y['label']))
    ft_data_dict = {}
    for ft_name, ft_dataset in zip(ft_names, ft_datasets):
        X_Y = ft_dataset.embs_and_labels
        ft_data_dict[ft_name] = (series_to_array(X_Y['pro_emb']), series_to_array(X_Y['label']))

    dataset_dists = {}
    OT_NCE = OTNCE(backend='numpy', distMetric='euclidean', numItermax=1e5, return_OT=False)
    task_pairs = itertools.product(pt_names, ft_names)
    # computing pairwise OTNCE
    for pt_task, ft_task in task_pairs:
        X_src, Y_src = pt_data_dict[pt_task]
        X_tgt, Y_tgt = ft_data_dict[ft_task]
        print(' -> {} - {}'.format(pt_task, ft_task))
        ts = time.time()
        otnce = OT_NCE.otnce(X_src, Y_src, X_tgt, Y_tgt)
        dataset_dists[(pt_task, ft_task)] = otnce.item()
        print('OTNCE between {} and {}: {:.2f}, time costs: {:.1f}'.format(pt_task, ft_task, otnce, time.time()-ts))

    # Construct pairwise OTNCE dataframe
    pttask_to_idx = {task: idx for idx, task in enumerate(pt_names)}
    otnce_df = pd.DataFrame(columns=ft_names, index=range(len(pttask_to_idx)))
    for (pt_task, ft_task), otnce in dataset_dists.items():
        # Single .loc[row, col] write. The chained form `df.iloc[row][col] = v`
        # assigns into a temporary row object, which pandas does not guarantee
        # to propagate to the frame (it is a no-op under Copy-on-Write).
        otnce_df.loc[pttask_to_idx[pt_task], ft_task] = otnce
    otnce_df['pt_task'] = pt_names

    bs2_otnce_t2t_distances.append(otnce_df)
    print('==== x ==== ' * 5)

# Average every pre-training task over the combinations that contain it, then persist.
bs2_otnce_t2t_distances = cal_combs_distances_mean(all_pt_tasks, bs2_otnce_t2t_distances)
np.save("./results(all)/bs2_otnce_t2t_distances.npy", bs2_otnce_t2t_distances)

Start computing for Pretrain-MultiTask: MLM+RMD
 -> PT-mlm - Stability
OTNCE between PT-mlm and Stability: -2.02, time costs: 0.2
 -> PT-mlm - Fluoresecence
OTNCE between PT-mlm and Fluoresecence: -1.93, time costs: 0.2
 -> PT-mlm - Remote_homology
OTNCE between PT-mlm and Remote_homology: -1.91, time costs: 0.2
 -> PT-mlm - Secondary_structrue
OTNCE between PT-mlm and Secondary_structrue: -2.47, time costs: 0.2
 -> PT-mlm - Pdbbind
OTNCE between PT-mlm and Pdbbind: -2.13, time costs: 0.2
 -> PT-mlm - Kinase
OTNCE between PT-mlm and Kinase: -0.49, time costs: 0.2
 -> PT-domain - Stability
OTNCE between PT-domain and Stability: -2.05, time costs: 0.2
 -> PT-domain - Fluoresecence
OTNCE between PT-domain and Fluoresecence: -1.95, time costs: 0.2
 -> PT-domain - Remote_homology
OTNCE between PT-domain and Remote_homology: -1.92, time costs: 0.2
 -> PT-domain - Secondary_structrue
OTNCE between PT-domain and Secondary_structrue: -2.48, time costs: 0.3
 -> PT-domain - Pdbbind
OTNCE between 

In [8]:
# Display the mean OTNCE scores (rows follow all_pt_tasks).
# NOTE(review): this cell's execution count precedes the cell that computes the
# array, and the output below has 5 columns although 6 fine-tuning tasks are
# loaded -- the shown output is probably from an earlier run; re-run to confirm.
bs2_otnce_t2t_distances

array([[-2.0445461, -1.9418494, -1.8959922, -2.458718 , -2.1220102],
       [-2.0611322, -1.9471098, -1.9053051, -2.474505 , -2.1363287],
       [-2.1344495, -2.0379968, -2.013342 , -2.5787942, -2.2040348],
       [-2.0919962, -1.9822189, -1.9455435, -2.5184326, -2.1695592],
       [-2.0874257, -1.9892466, -1.9554827, -2.5205643, -2.165317 ]],
      dtype=float32)

# Baseline 3: Features only (without label information)

In [12]:
import ot

def pW_cal(a, b, p=2, metric='euclidean'):
    """Compute the p-Wasserstein distance between two empirical sample sets.

    Each sample set is given uniform weights via ``ot.unif``.

    Args:
        a, b: sample sets drawn from α and β respectively.
        p: exponent of the OT cost (the p in p-Wasserstein).
        metric: ground metric passed to ``ot.dist`` for the cost matrix,
            e.g. 'euclidean' or 'cosine'.

    Returns:
        The value returned by ``ot.emd2`` for the p-th-power cost matrix,
        raised to the power ``1/p``.
    """
    # Uniform weights over each sample set.
    weights_a, weights_b = ot.unif(len(a)), ot.unif(len(b))
    # Pairwise ground cost, raised element-wise to the p-th power.
    cost = ot.dist(a, b, metric=metric) ** p
    # Transport cost under that matrix; the p-th root yields the p-Wasserstein distance.
    transport_cost = ot.emd2(weights_a, weights_b, cost, numItermax=100000)
    return transport_cost ** (1 / p)

# Baseline 3: feature-only p-Wasserstein distance between each pre-training
# task and each fine-tuning task, evaluated per multi-task combination.
# NOTE(review): `p` is not defined in this cell; it must already exist in the
# kernel (presumably set in an earlier cell) — confirm before Restart & Run All.
bs3_XpW_t2t_distances = []
for comb in TaskCombs:
    print("Start computing for Pretrain-MultiTask: {}".format(comb))
    data_stuff = data_stuffs_dict[comb]
    pt_datasets, pt_names, ft_datasets, ft_names, pt_tasks_dict = data_stuff

    # Embedding arrays keyed by task name; the label column is ignored on purpose.
    # Comprehensions keep the index local, so no outer loop variable is shadowed.
    pt_xdata_dict = {
        pt_names[k]: series_to_array(pt_dataset.embs_and_labels['pro_emb'])
        for k, pt_dataset in enumerate(pt_datasets)
    }
    ft_xdata_dict = {
        ft_names[k]: series_to_array(ft_dataset.embs_and_labels['pro_emb'])
        for k, ft_dataset in enumerate(ft_datasets)
    }

    # computing pairwise feature distances
    dataset_dists = {}
    for pt_task, ft_task in itertools.product(pt_names, ft_names):
        X_src = pt_xdata_dict[pt_task]
        X_tgt = ft_xdata_dict[ft_task]
        print(' -> {} - {}'.format(pt_task, ft_task))
        ts = time.time()
        pW = pW_cal(X_src, X_tgt, p=p, metric='euclidean')
        dataset_dists[(pt_task, ft_task)] = pW
        print('{}-Wasserstein Distance between {} and {}: {:.2f}, time costs: {:.1f}'.format(p, pt_task, ft_task, pW, time.time()-ts))

    # Construct pairwise pW Distance dataframe: one row per pre-training task
    # (positional index) and one column per fine-tuning task.
    pttask_to_idx = {task: row for row, task in enumerate(pt_names)}
    pW_df = pd.DataFrame(columns=ft_names, index=range(len(pttask_to_idx)))
    for (pt_task, ft_task), pW in dataset_dists.items():
        # Single .loc write instead of chained `iloc[...][...]`: the chained form
        # assigns into a temporary row object and is a silent no-op under pandas
        # copy-on-write. The index is a plain 0..n-1 range, so the row label
        # equals the row position.
        pW_df.loc[pttask_to_idx[pt_task], ft_task] = pW
    # Keep the task names alongside the values so rows stay identifiable.
    pW_df['pt_task'] = pt_names

    bs3_XpW_t2t_distances.append(pW_df)
    print('==== x ==== ' * 5)

# Reduce the per-combination tables with cal_combs_distances_mean (defined
# elsewhere in the notebook) and persist the result.
# NOTE(review): np.save does not create missing directories, so
# "./results(all)/" presumably has to exist already — confirm.
bs3_XpW_t2t_distances = cal_combs_distances_mean(all_pt_tasks, bs3_XpW_t2t_distances)
np.save("./results(all)/bs3_XpW_t2t_distances.npy", bs3_XpW_t2t_distances)

Start computing for Pretrain-MultiTask: MLM+RMD
 -> PT-mlm - Stability
2-Wasserstein Distance between PT-mlm and Stability: 25.71, time costs: 0.2
 -> PT-mlm - Fluoresecence
2-Wasserstein Distance between PT-mlm and Fluoresecence: 33.07, time costs: 0.2
 -> PT-mlm - Remote_homology
2-Wasserstein Distance between PT-mlm and Remote_homology: 24.26, time costs: 0.3
 -> PT-mlm - Secondary_structrue
2-Wasserstein Distance between PT-mlm and Secondary_structrue: 22.18, time costs: 0.2
 -> PT-mlm - Pdbbind
2-Wasserstein Distance between PT-mlm and Pdbbind: 23.28, time costs: 0.2
 -> PT-mlm - Kinase
2-Wasserstein Distance between PT-mlm and Kinase: 29.41, time costs: 0.2
 -> PT-domain - Stability
2-Wasserstein Distance between PT-domain and Stability: 25.71, time costs: 0.2
 -> PT-domain - Fluoresecence
2-Wasserstein Distance between PT-domain and Fluoresecence: 33.07, time costs: 0.2
 -> PT-domain - Remote_homology
2-Wasserstein Distance between PT-domain and Remote_homology: 24.26, time cost

In [10]:
# Show the combination-averaged feature-only Wasserstein result saved above.
# NOTE(review): every row of the displayed array is identical — presumably
# because all pre-training tasks share the same embeddings and this baseline
# ignores labels; confirm this is intended rather than a data-loading issue.
bs3_XpW_t2t_distances

array([[27.862999, 34.334396, 27.04403 , 23.841352, 26.273642],
       [27.862999, 34.334396, 27.04403 , 23.841352, 26.273642],
       [27.862999, 34.334396, 27.04403 , 23.841352, 26.273642],
       [27.862999, 34.334396, 27.04403 , 23.841352, 26.273642],
       [27.862999, 34.334396, 27.04403 , 23.841352, 26.273642]],
      dtype=float32)