In [1]:
import time
import random
import itertools
import collections
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

from otmtd.utils.distance import WTE
from scipy.spatial import distance


class EmbeddingDataset(Dataset):
    """Protein-embedding dataset with labels, exposing the attributes OTDD/WTE read.

    Labels are read from a CSV (``text_path``) that has a ``uniprot_id`` column and a
    label column. Embeddings are read from pickled DataFrames with ``pro_id`` and
    ``pro_emb`` columns and are matched to the CSV rows by protein id, so the i-th
    embedding always belongs to the i-th label.

    Args:
        task: 'pre_train', 'kinase', or a fine-tuning task name. 'pre_train' and
            'kinase' only use ``train_embs_path``. Other tasks additionally use
            ``valid_embs_path`` when it is given.
        text_path: CSV path with ``uniprot_id`` and label columns.
        train_embs_path: pickle path of the training-split embeddings.
        valid_embs_path: optional pickle path of the validation-split embeddings.
        pt_label: for ``task == 'pre_train'``, the CSV column used as the label.

    Attributes:
        unique_labels: sorted array of the distinct label values.
        classes: ``unique_labels`` as strings. When labels are 0..K-1,
            ``classes[k]`` is the name of label ``k``.
        targets: tensor of labels in CSV row order.
        embs_and_labels: DataFrame with a ``pro_emb`` and a ``label`` column.
    """

    def __init__(self, task, text_path, train_embs_path, valid_embs_path=None, pt_label='domain'):
        embs, labels, unique_labels = self.load_embs_and_labels(
            task, text_path, train_embs_path, valid_embs_path, pt_label)
        if task != 'pre_train':
            print("{} unique labels: {}".format(task.capitalize(), len(unique_labels)))
        else:
            print("PRE-TRAIN {} unique labels: {}".format(pt_label, len(unique_labels)))
        # Label frequencies, most common first.
        print(dict(collections.Counter(labels).most_common()))
        self.unique_labels = unique_labels

        self.embs_and_labels = pd.concat([embs, labels], axis=1)
        # Attributes required by OTDD's Dataset interface.
        self.classes = [str(k) for k in unique_labels]  # list of unique label strings, sorted
        self.targets = torch.tensor(labels.tolist())

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(embedding reshaped to (1, numel, 1), label tensor)`` for row ``idx``."""
        row = self.embs_and_labels.iloc[idx]
        return torch.tensor(row['pro_emb']).reshape(1, -1, 1), torch.tensor(row['label'])

    def load_embs_and_labels(self, task, text_path, train_embs_path, valid_embs_path, pt_label):
        """Load labels from ``text_path`` and the embeddings that belong to them.

        Returns:
            ``(embs, labels, unique_labels)``. ``embs`` is a Series named ``pro_emb``
            and ``labels`` a Series named ``label``. Both are in CSV row order with a
            0..n-1 index, so they line up row by row. ``unique_labels`` is sorted.

        Raises:
            ValueError: if a ``uniprot_id`` in the CSV has no embedding.
        """
        texts = pd.read_csv(text_path)
        if task == 'pre_train':
            texts = texts.rename(columns={pt_label: 'label'})
        labels = texts['label'].reset_index(drop=True)
        unique_labels = np.sort(labels.unique())

        # NOTE: read_pickle can execute arbitrary code; only load trusted files.
        emb_frames = [pd.read_pickle(train_embs_path)]
        # pre_train / kinase only have a `train` split.
        if task not in ('pre_train', 'kinase') and valid_embs_path is not None:
            emb_frames.append(pd.read_pickle(valid_embs_path))
        embs = pd.concat(emb_frames, axis=0)

        # Index embeddings by protein id (first occurrence wins if an id appears in
        # several splits), then look them up in CSV order so embeddings match labels.
        emb_by_id = embs.drop_duplicates(subset='pro_id', keep='first').set_index('pro_id')['pro_emb']
        ids = texts['uniprot_id']
        missing = ids[~ids.isin(emb_by_id.index)]
        if not missing.empty:
            raise ValueError("{} protein id(s) in {} have no embedding, e.g. {}".format(
                len(missing), text_path, missing.unique()[:5].tolist()))
        embs = emb_by_id.reindex(ids).reset_index(drop=True)

        return embs, labels, unique_labels
    
def set_random_seed(SEED):
    """Seed every random generator this notebook uses.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch on the CPU and
    on all CUDA devices. The CUDA seeding is applied lazily and is harmless without a GPU.

    Args:
        SEED: integer seed applied to every generator.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

In [2]:
import os

# Experiment configuration.
num_samples = 1000
TaskCombs = ['MLM+RMD', 'GO+RMD', 'MLM+GO+D', 'RMD']
# Pre-training label tasks making up each combination in `TaskCombs` (same order).
pt_tasks_combs = [
    ['mlm', 'domain', 'motif', 'region'],
    ['domain', 'motif', 'region', 'go'],
    ['mlm', 'domain', 'go'],
    ['domain', 'motif', 'region']
]
assert len(TaskCombs) == len(pt_tasks_combs), "TaskCombs and pt_tasks_combs must be parallel lists"
ft_tasks = ['fluorescence']

device = torch.device('cpu')
# Random reference point cloud shared by every Wasserstein embedding.
# Drawn in float64 and then cast to float32; kept as-is so the seeded values stay reproducible.
feat_dim, lbl_emb_dim, ref_size = 512, 10, 200
seed = 1145114
set_random_seed(seed)
reference = torch.randn(ref_size, feat_dim + lbl_emb_dim, dtype=float, device=device).float()

# Data roots: set the OTMTD_ROOT environment variable instead of editing paths here.
project_root = os.environ.get('OTMTD_ROOT', '/home/brian/work/OTMTD_GH')
texts_base_dir = "{}/processed_data".format(project_root)
embs_base_dir = "{}/protein_embeddings_MultiTasks".format(project_root)

wte_t2t_distances = []  # not populated in this cell; kept for later cells that may reference it
wte_distance_rows = []
for comb, pt_labels in zip(TaskCombs, pt_tasks_combs):
    print("Start computing for Pretrain-MultiTask: {}".format(comb))
    # Build the text/embedding file hub for this multi-task combination.
    tasks_texts_embs_hub = {
        'pre_train': ["{}/pre_train/sampling_set.txt".format(texts_base_dir),
                      "{}/pre_train_combs/pre_train_{}_pro_embs_pt.pkl".format(embs_base_dir, comb)],
        'fluorescence': ["{}/fluorescence/sequence_go_label_{}.txt".format(texts_base_dir, num_samples),
                         "{}/fluorescence/{}->fluorescence_pretrain_pro_embs_train.pkl".format(embs_base_dir, comb),
                         "{}/fluorescence/{}->fluorescence_pretrain_pro_embs_valid.pkl".format(embs_base_dir, comb)],
    }

    # One pre-training dataset per label task in this combination.
    pt_datasets, pt_class_nums, pt_names = [], [], []
    for label_task in pt_labels:
        dataset = EmbeddingDataset('pre_train', *tasks_texts_embs_hub['pre_train'], pt_label=label_task)
        pt_datasets.append(dataset)
        pt_class_nums.append(len(dataset.unique_labels))
        pt_names.append(label_task)

    ft_datasets, ft_class_nums, ft_names = [], [], []
    for task in ft_tasks:
        dataset = EmbeddingDataset(task, *tasks_texts_embs_hub[task])
        ft_datasets.append(dataset)
        ft_class_nums.append(len(dataset.unique_labels))
        ft_names.append(task.capitalize())

    pt_tasks_dict = {task: idx for idx, task in enumerate(pt_labels)}
    MultiTask_WTE = WTE(lbl_emb_dim, device, pt_class_nums=np.array(pt_class_nums),
                        ft_class_nums=np.array(ft_class_nums), pt_names=pt_names, ft_names=ft_names,
                        pt_tasks_combs=[pt_labels], pt_tasks_dict=pt_tasks_dict, gaussian_assumption=True)
    # `pt_task_sole_embs` (third output with return_t2t=True) is not used in this cell.
    pt_task_comb_embs, ft_task_embs, pt_task_sole_embs = MultiTask_WTE.cwte(pt_datasets, ft_datasets, reference, return_t2t=True)
    pt_task_comb_vecs = pt_task_comb_embs.reshape(1, -1)  # one pre-training combination per iteration
    # One flattened row per fine-tuning dataset.
    # NOTE(review): assumes ft_task_embs has a leading dataset axis; identical to reshape(1, -1) for one task.
    ft_task_vecs = ft_task_embs.reshape(len(ft_datasets), -1)

    wte_distance_row = distance.cdist(pt_task_comb_vecs, ft_task_vecs, 'euclidean')  # (1, n_ft_tasks)
    wte_distance_rows.append(wte_distance_row)
    print(wte_distance_row.shape, '\n', '==== x ==== ' * 5)

# Rows: pre-training combinations (TaskCombs order); columns: fine-tuning tasks.
wte_distance = np.concatenate(wte_distance_rows, axis=0)

Start computing for Pretrain-MultiTask: MLM+RMD
PRE-TRAIN mlm unique labels: 15
{2: 148, 10: 117, 6: 116, 8: 101, 9: 98, 4: 95, 5: 59, 14: 54, 1: 49, 3: 47, 7: 46, 12: 45, 0: 45, 11: 13, 13: 12}
PRE-TRAIN domain unique labels: 13
{0: 173, 1: 165, 2: 121, 3: 89, 4: 88, 5: 64, 6: 59, 7: 53, 8: 50, 9: 47, 10: 46, 11: 45, 12: 45}
PRE-TRAIN motif unique labels: 5
{0: 843, 1: 59, 2: 50, 3: 47, 4: 46}
PRE-TRAIN region unique labels: 10
{0: 202, 1: 173, 2: 165, 3: 121, 4: 89, 5: 88, 6: 64, 7: 53, 8: 45, 9: 45}
Fluoresecence unique labels: 39
{0: 127, 1: 118, 2: 102, 3: 84, 4: 78, 5: 78, 6: 56, 7: 43, 8: 37, 9: 34, 10: 31, 11: 28, 12: 28, 13: 19, 14: 16, 15: 15, 16: 13, 17: 13, 18: 12, 19: 11, 20: 8, 21: 7, 22: 7, 23: 5, 24: 5, 25: 3, 26: 3, 27: 3, 28: 2, 29: 2, 30: 2, 31: 2, 32: 2, 33: 1, 34: 1, 35: 1, 36: 1, 37: 1, 38: 1}
Embedding labels...
Computing inter label-to-label distance for mlm & domain


13it [00:00, 19.65it/s]
Computing label-to-label distances: 100%|██████████| 195/195 [00:10<00:00, 19.28it/s]


Computing inter label-to-label distance for mlm & motif


5it [00:00, 18.43it/s]
Computing label-to-label distances: 100%|██████████| 75/75 [00:03<00:00, 19.94it/s]


Computing inter label-to-label distance for mlm & region


10it [00:00, 17.41it/s]
Computing label-to-label distances: 100%|██████████| 150/150 [00:07<00:00, 20.28it/s]


Computing inter label-to-label distance for mlm & Fluoresecence


  cov = torch.stack([torch.cov(X[Y == y].T) for y in labels], dim=0)
15it [00:00, 20.12it/s]
Computing label-to-label distances: 100%|██████████| 585/585 [00:29<00:00, 19.72it/s]


Computing inter label-to-label distance for domain & motif


5it [00:00, 16.19it/s]
Computing label-to-label distances: 100%|██████████| 65/65 [00:03<00:00, 19.62it/s]


Computing inter label-to-label distance for domain & region


10it [00:00, 19.01it/s]
Computing label-to-label distances: 100%|██████████| 130/130 [00:06<00:00, 20.11it/s]


Computing inter label-to-label distance for domain & Fluoresecence


13it [00:00, 19.71it/s]
Computing label-to-label distances: 100%|██████████| 507/507 [00:26<00:00, 19.48it/s]


Computing inter label-to-label distance for motif & region


5it [00:00, 20.52it/s]
Computing label-to-label distances: 100%|██████████| 50/50 [00:02<00:00, 20.58it/s]


Computing inter label-to-label distance for motif & Fluoresecence


5it [00:00, 18.02it/s]
Computing label-to-label distances: 100%|██████████| 195/195 [00:09<00:00, 19.71it/s]


Computing inter label-to-label distance for region & Fluoresecence


10it [00:00, 19.43it/s]
Computing label-to-label distances: 100%|██████████| 390/390 [00:20<00:00, 19.26it/s]


Computing intra label-to-label distance for mlm


15it [00:00, 19.61it/s]
Computing label-to-label distances: 100%|██████████| 105/105 [00:05<00:00, 20.06it/s]


Computing intra label-to-label distance for domain


13it [00:00, 19.43it/s]
Computing label-to-label distances: 100%|██████████| 78/78 [00:03<00:00, 19.77it/s]


Computing intra label-to-label distance for motif


5it [00:00, 20.07it/s]
Computing label-to-label distances: 100%|██████████| 10/10 [00:00<00:00, 20.33it/s]


Computing intra label-to-label distance for region


10it [00:00, 16.75it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:02<00:00, 20.27it/s]


Computing intra label-to-label distance for Fluoresecence


39it [00:01, 20.17it/s]
Computing label-to-label distances: 100%|██████████| 741/741 [00:36<00:00, 20.03it/s]


Finish label embedding in 3.1 mins
Wasserstein embedding...
Finish WTE in 3.2 mins
(1, 1) 
 ==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 
Start computing for Pretrain-MultiTask: GO+RMD
PRE-TRAIN domain unique labels: 13
{0: 173, 1: 165, 2: 121, 3: 89, 4: 88, 5: 64, 6: 59, 7: 53, 8: 50, 9: 47, 10: 46, 11: 45, 12: 45}
PRE-TRAIN motif unique labels: 5
{0: 843, 1: 59, 2: 50, 3: 47, 4: 46}
PRE-TRAIN region unique labels: 10
{0: 202, 1: 173, 2: 165, 3: 121, 4: 89, 5: 88, 6: 64, 7: 53, 8: 45, 9: 45}
PRE-TRAIN go unique labels: 9
{0: 347, 1: 229, 2: 133, 3: 89, 4: 59, 5: 50, 6: 47, 7: 46, 8: 45}
Fluoresecence unique labels: 39
{0: 127, 1: 118, 2: 102, 3: 84, 4: 78, 5: 78, 6: 56, 7: 43, 8: 37, 9: 34, 10: 31, 11: 28, 12: 28, 13: 19, 14: 16, 15: 15, 16: 13, 17: 13, 18: 12, 19: 11, 20: 8, 21: 7, 22: 7, 23: 5, 24: 5, 25: 3, 26: 3, 27: 3, 28: 2, 29: 2, 30: 2, 31: 2, 32: 2, 33: 1, 34: 1, 35: 1, 36: 1, 37: 1, 38: 1}
Embedding labels...
Computing inter label-to-label distance for domain 

5it [00:00, 17.75it/s]
Computing label-to-label distances: 100%|██████████| 65/65 [00:03<00:00, 19.46it/s]


Computing inter label-to-label distance for domain & region


10it [00:00, 18.83it/s]
Computing label-to-label distances: 100%|██████████| 130/130 [00:06<00:00, 19.28it/s]


Computing inter label-to-label distance for domain & go


9it [00:00, 18.06it/s]
Computing label-to-label distances: 100%|██████████| 117/117 [00:05<00:00, 20.09it/s]


Computing inter label-to-label distance for domain & Fluoresecence


  cov = torch.stack([torch.cov(X[Y == y].T) for y in labels], dim=0)
13it [00:00, 18.11it/s]
Computing label-to-label distances: 100%|██████████| 507/507 [00:25<00:00, 19.72it/s]


Computing inter label-to-label distance for motif & region


5it [00:00, 17.44it/s]
Computing label-to-label distances: 100%|██████████| 50/50 [00:02<00:00, 19.59it/s]


Computing inter label-to-label distance for motif & go


5it [00:00, 18.72it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:02<00:00, 19.74it/s]


Computing inter label-to-label distance for motif & Fluoresecence


5it [00:00, 19.06it/s]
Computing label-to-label distances: 100%|██████████| 195/195 [00:09<00:00, 19.82it/s]


Computing inter label-to-label distance for region & go


9it [00:00, 17.82it/s]
Computing label-to-label distances: 100%|██████████| 90/90 [00:04<00:00, 19.36it/s]


Computing inter label-to-label distance for region & Fluoresecence


10it [00:00, 17.16it/s]
Computing label-to-label distances: 100%|██████████| 390/390 [00:19<00:00, 19.52it/s]


Computing inter label-to-label distance for go & Fluoresecence


9it [00:00, 19.07it/s]
Computing label-to-label distances: 100%|██████████| 351/351 [00:17<00:00, 19.63it/s]


Computing intra label-to-label distance for domain


13it [00:00, 19.79it/s]
Computing label-to-label distances: 100%|██████████| 78/78 [00:03<00:00, 19.94it/s]


Computing intra label-to-label distance for motif


5it [00:00, 19.95it/s]
Computing label-to-label distances: 100%|██████████| 10/10 [00:00<00:00, 20.12it/s]


Computing intra label-to-label distance for region


10it [00:00, 19.26it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:02<00:00, 20.51it/s]


Computing intra label-to-label distance for go


9it [00:00, 20.02it/s]
Computing label-to-label distances: 100%|██████████| 36/36 [00:01<00:00, 20.32it/s]


Computing intra label-to-label distance for Fluoresecence


39it [00:01, 20.38it/s]
Computing label-to-label distances: 100%|██████████| 741/741 [00:37<00:00, 19.74it/s]


Finish label embedding in 2.7 mins
Wasserstein embedding...
Finish WTE in 2.7 mins
(1, 1) 
 ==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 
Start computing for Pretrain-MultiTask: MLM+GO+D
PRE-TRAIN mlm unique labels: 15
{2: 148, 10: 117, 6: 116, 8: 101, 9: 98, 4: 95, 5: 59, 14: 54, 1: 49, 3: 47, 7: 46, 12: 45, 0: 45, 11: 13, 13: 12}
PRE-TRAIN domain unique labels: 13
{0: 173, 1: 165, 2: 121, 3: 89, 4: 88, 5: 64, 6: 59, 7: 53, 8: 50, 9: 47, 10: 46, 11: 45, 12: 45}
PRE-TRAIN go unique labels: 9
{0: 347, 1: 229, 2: 133, 3: 89, 4: 59, 5: 50, 6: 47, 7: 46, 8: 45}
Fluoresecence unique labels: 39
{0: 127, 1: 118, 2: 102, 3: 84, 4: 78, 5: 78, 6: 56, 7: 43, 8: 37, 9: 34, 10: 31, 11: 28, 12: 28, 13: 19, 14: 16, 15: 15, 16: 13, 17: 13, 18: 12, 19: 11, 20: 8, 21: 7, 22: 7, 23: 5, 24: 5, 25: 3, 26: 3, 27: 3, 28: 2, 29: 2, 30: 2, 31: 2, 32: 2, 33: 1, 34: 1, 35: 1, 36: 1, 37: 1, 38: 1}
Embedding labels...
Computing inter label-to-label distance for mlm & domain


13it [00:00, 18.36it/s]
Computing label-to-label distances: 100%|██████████| 195/195 [00:09<00:00, 19.62it/s]


Computing inter label-to-label distance for mlm & go


9it [00:00, 18.37it/s]
Computing label-to-label distances: 100%|██████████| 135/135 [00:06<00:00, 20.02it/s]


Computing inter label-to-label distance for mlm & Fluoresecence


  cov = torch.stack([torch.cov(X[Y == y].T) for y in labels], dim=0)
15it [00:00, 18.95it/s]
Computing label-to-label distances: 100%|██████████| 585/585 [00:29<00:00, 19.78it/s]


Computing inter label-to-label distance for domain & go


9it [00:00, 18.41it/s]
Computing label-to-label distances: 100%|██████████| 117/117 [00:05<00:00, 20.04it/s]


Computing inter label-to-label distance for domain & Fluoresecence


13it [00:00, 19.92it/s]
Computing label-to-label distances: 100%|██████████| 507/507 [00:25<00:00, 19.71it/s]


Computing inter label-to-label distance for go & Fluoresecence


9it [00:00, 17.28it/s]
Computing label-to-label distances: 100%|██████████| 351/351 [00:17<00:00, 19.89it/s]


Computing intra label-to-label distance for mlm


15it [00:00, 20.49it/s]
Computing label-to-label distances: 100%|██████████| 105/105 [00:05<00:00, 19.78it/s]


Computing intra label-to-label distance for domain


13it [00:00, 20.70it/s]
Computing label-to-label distances: 100%|██████████| 78/78 [00:03<00:00, 20.66it/s]


Computing intra label-to-label distance for go


9it [00:00, 19.55it/s]
Computing label-to-label distances: 100%|██████████| 36/36 [00:01<00:00, 19.19it/s]


Computing intra label-to-label distance for Fluoresecence


39it [00:01, 20.90it/s]
Computing label-to-label distances: 100%|██████████| 741/741 [00:37<00:00, 19.97it/s]


Finish label embedding in 2.6 mins
Wasserstein embedding...
Finish WTE in 2.6 mins
(1, 1) 
 ==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 
Start computing for Pretrain-MultiTask: RMD
PRE-TRAIN domain unique labels: 13
{0: 173, 1: 165, 2: 121, 3: 89, 4: 88, 5: 64, 6: 59, 7: 53, 8: 50, 9: 47, 10: 46, 11: 45, 12: 45}
PRE-TRAIN motif unique labels: 5
{0: 843, 1: 59, 2: 50, 3: 47, 4: 46}
PRE-TRAIN region unique labels: 10
{0: 202, 1: 173, 2: 165, 3: 121, 4: 89, 5: 88, 6: 64, 7: 53, 8: 45, 9: 45}
Fluoresecence unique labels: 39
{0: 127, 1: 118, 2: 102, 3: 84, 4: 78, 5: 78, 6: 56, 7: 43, 8: 37, 9: 34, 10: 31, 11: 28, 12: 28, 13: 19, 14: 16, 15: 15, 16: 13, 17: 13, 18: 12, 19: 11, 20: 8, 21: 7, 22: 7, 23: 5, 24: 5, 25: 3, 26: 3, 27: 3, 28: 2, 29: 2, 30: 2, 31: 2, 32: 2, 33: 1, 34: 1, 35: 1, 36: 1, 37: 1, 38: 1}
Embedding labels...
Computing inter label-to-label distance for domain & motif


5it [00:00, 18.37it/s]
Computing label-to-label distances: 100%|██████████| 65/65 [00:03<00:00, 19.85it/s]


Computing inter label-to-label distance for domain & region


10it [00:00, 19.94it/s]
Computing label-to-label distances: 100%|██████████| 130/130 [00:06<00:00, 20.18it/s]


Computing inter label-to-label distance for domain & Fluoresecence


  cov = torch.stack([torch.cov(X[Y == y].T) for y in labels], dim=0)
13it [00:00, 18.30it/s]
Computing label-to-label distances: 100%|██████████| 507/507 [00:25<00:00, 19.94it/s]


Computing inter label-to-label distance for motif & region


5it [00:00, 17.41it/s]
Computing label-to-label distances: 100%|██████████| 50/50 [00:02<00:00, 19.07it/s]


Computing inter label-to-label distance for motif & Fluoresecence


5it [00:00, 16.46it/s]
Computing label-to-label distances: 100%|██████████| 195/195 [00:10<00:00, 19.13it/s]


Computing inter label-to-label distance for region & Fluoresecence


10it [00:00, 19.49it/s]
Computing label-to-label distances: 100%|██████████| 390/390 [00:19<00:00, 19.68it/s]


Computing intra label-to-label distance for domain


13it [00:00, 20.69it/s]
Computing label-to-label distances: 100%|██████████| 78/78 [00:03<00:00, 19.96it/s]


Computing intra label-to-label distance for motif


5it [00:00, 18.29it/s]
Computing label-to-label distances: 100%|██████████| 10/10 [00:00<00:00, 20.23it/s]


Computing intra label-to-label distance for region


10it [00:00, 19.26it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:02<00:00, 20.56it/s]


Computing intra label-to-label distance for Fluoresecence


39it [00:01, 20.70it/s]
Computing label-to-label distances: 100%|██████████| 741/741 [00:37<00:00, 19.87it/s]


Finish label embedding in 2.1 mins
Wasserstein embedding...
Finish WTE in 2.1 mins
(1, 1) 
 ==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 
