In [1]:
import collections
import itertools
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class EmbeddingDataset(Dataset):
    """Dataset pairing pre-computed protein embeddings with task labels.

    Labels come from a CSV file (columns `uniprot_id` plus a label column) and
    embeddings from pickled pandas objects with `pro_id` and `pro_emb` columns.
    Only embeddings whose `pro_id` occurs in the CSV's `uniprot_id` column are kept.

    Args:
        task: Task name. 'pre_train' and 'kinase' read a single embedding file;
            any other task reads and concatenates the train and valid files.
        text_path: Path of the CSV file holding ids and labels.
        train_embs_path: Path of the (train) embedding pickle.
        valid_embs_path: Path of the valid embedding pickle; only read for tasks
            other than 'pre_train' / 'kinase'.
        pt_label: For task 'pre_train', the CSV column to use as the label.

    Attributes set on the instance:
        unique_labels: Distinct labels in order of first appearance.
        embs_and_labels: DataFrame with columns `pro_emb` and `label`.
        classes: `unique_labels` converted to strings.
        targets: Tensor holding every label, in CSV order.
    """

    def __init__(self, task, text_path, train_embs_path, valid_embs_path=None, pt_label=''):
        embs, labels, unique_labels = self.load_embs_and_labels(task, text_path, train_embs_path, valid_embs_path, pt_label)
        if task != 'pre_train':
            print("{} unique labels: {}".format(task.capitalize(), len(unique_labels)))
        else:
            print("PRE-TRAIN {} unique labels: {}".format(pt_label, len(unique_labels)))
        # Print the label histogram, most frequent label first.
        sorted_counter = dict(collections.Counter(labels).most_common())
        print(sorted_counter)
        self.unique_labels = unique_labels

        # Embeddings and labels are paired by position (both carry a 0..n-1 index).
        # NOTE(review): this assumes the embedding file(s) list the selected proteins
        # in the same order as the text file and without duplicates -- confirm.
        self.embs_and_labels = pd.concat([embs, labels], axis=1)
        # Dataset attributes required by otdd.
        self.classes = [str(k) for k in unique_labels]  # unique labels as strings
        self.targets = torch.tensor(list(labels))

    def __len__(self):
        """Return the number of labels (rows of the text file)."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return `(embedding, label)` tensors for row `idx`.

        The embedding is reshaped to `(1, -1, 1)`.
        """
        # Materialise the row once instead of once per column.
        row = self.embs_and_labels.iloc[idx]
        return torch.tensor(row['pro_emb']).reshape(1, -1, 1), torch.tensor(row['label'])

    def load_embs_and_labels(self, task, text_path, train_embs_path, valid_embs_path, pt_label):
        """Load labels and the matching embeddings from disk.

        Returns:
            embs: Series named `pro_emb` with a fresh 0..n-1 index, restricted to
                proteins whose `pro_id` appears in the text file's `uniprot_id`.
            labels: Series named `label`, in text-file order.
            unique_labels: Array of distinct labels in order of first appearance.
        """
        texts = pd.read_csv(text_path)
        # For pre-training the label lives in the column named by `pt_label`.
        # Selecting it directly (instead of renaming in place) avoids creating a
        # duplicate 'label' column when the file already has one.
        if task == 'pre_train' and pt_label in texts.columns:
            label_col = pt_label
        else:
            label_col = 'label'
        labels = texts[label_col].rename('label')
        unique_labels = labels.unique()

        # SECURITY: pd.read_pickle can execute arbitrary code while unpickling;
        # only load embedding files from a trusted source.
        if task in ('pre_train', 'kinase'):
            embs = pd.read_pickle(train_embs_path)
        else:
            train_embs, valid_embs = pd.read_pickle(train_embs_path), pd.read_pickle(valid_embs_path)
            embs = pd.concat([train_embs, valid_embs], axis=0)

        # Vectorised membership test; the plain boolean array sidesteps index
        # alignment, since the concatenated frame may carry duplicate index values.
        keep = embs['pro_id'].isin(texts['uniprot_id']).to_numpy()
        embs = embs.loc[keep, 'pro_emb'].reset_index(drop=True)

        return embs, labels, unique_labels
    
def set_random_seed(SEED):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's `random` module, NumPy's legacy global generator, PyTorch's
    CPU generator and the generators of all CUDA devices.

    Args:
        SEED: Integer seed shared by all generators.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # manual_seed_all covers every GPU, not only the current device.
    torch.cuda.manual_seed_all(SEED)

def cal_combs_distances_mean(all_pt_tasks, t2t_distances_list):
    """Average each pre-training task's distance row over all combinations.

    Args:
        all_pt_tasks: Iterable of pre-training task names.
        t2t_distances_list: List of DataFrames, each with a `pt_task` column.
            Every column except the last one is treated as a distance value.

    Returns:
        float32 array with one row per entry of `all_pt_tasks`: the mean of
        that task's flattened distance rows over the DataFrames containing it.
    """
    per_task_means = []
    for pt_task in all_pt_tasks:
        # One flattened (1, k) row per DataFrame that contains this task.
        task_rows = [
            frame[frame['pt_task'] == pt_task].to_numpy()[:, :-1].reshape(1, -1)
            for frame in t2t_distances_list
            if pt_task in frame['pt_task'].tolist()
        ]
        stacked_rows = np.concatenate(task_rows, axis=0).astype(np.float32)
        per_task_means.append(stacked_rows.mean(axis=0))
    return np.stack(per_task_means, axis=0)

In [45]:
# ---- Configuration -------------------------------------------------------
# num_samples is interpolated into the downstream label file names
# (sequence_go_label_{num_samples}.txt) below.
num_samples = 1000
# Names of the pre-training multi-task combinations; TaskCombs[i] is paired
# with pt_tasks_combs[i] (the individual pre-training label tasks it contains).
TaskCombs = ['MLM+RMD', 'GO+RMD', 'MLM+GO+D', 'RMD']
pt_tasks_combs = [
    ['mlm', 'domain', 'motif', 'region'],
    ['domain', 'motif', 'region', 'go'],
    ['mlm', 'domain', 'go'],
    ['domain', 'motif', 'region']
]
# pt_tasks_combs = [list(reversed(l)) for l in pt_tasks_combs]
# NOTE(review): all_pt_tasks, device and p are not referenced in the cells
# shown here -- presumably consumed by later cells.
all_pt_tasks = ['mlm', 'domain', 'motif', 'region', 'go']
device = torch.device('cpu')
p = 2

# prepare data
# NOTE(review): absolute, user-specific paths -- the notebook only runs on a
# machine with this directory layout.
texts_base_dir = "/home/brian/work/OTMTD_GH/processed_data"
embs_base_dir = "/home/brian/work/OTMTD_GH/protein_embeddings_MultiTasks"
# Maps each combination name to
# [pt_datasets, pt_names, ft_datasets, ft_names, pt_tasks_dict].
data_stuffs_dict = {}
for i, comb in enumerate(TaskCombs):
    print("Start loading data for Pretrain-MultiTask: {}".format(comb))
    # Build the embedding hub for this multi-task combination: for every task,
    # the label text file followed by the embedding pickle path(s).
    # NOTE(review): the keys 'fluoresecence' and 'secondary_structrue' are
    # misspelled but used consistently with the task list below; they also end
    # up in ft_names and the printed output.
    # NOTE(review): pdbbind's second embedding file ends in '_dev' where the other
    # tasks use '_valid', and the kinase path contains '//' -- confirm both.
    tasks_texts_embs_hub = {
        'pre_train': ["{}/pre_train/sampling_set.txt".format(texts_base_dir),
                        "{}/pre_train_combs/pre_train_{}_pro_embs_pt.pkl".format(embs_base_dir, comb)],
        'stability': ["{}/stability/sequence_go_label_{}.txt".format(texts_base_dir, num_samples),
                        "{}/stability/{}->stability_pretrain_pro_embs_train.pkl".format(embs_base_dir, comb),
                        "{}/stability/{}->stability_pretrain_pro_embs_valid.pkl".format(embs_base_dir, comb)],
        'fluoresecence': ["{}/fluorescence/sequence_go_label_{}.txt".format(texts_base_dir, num_samples),
                        "{}/fluorescence/{}->fluorescence_pretrain_pro_embs_train.pkl".format(embs_base_dir, comb),
                        "{}/fluorescence/{}->fluorescence_pretrain_pro_embs_valid.pkl".format(embs_base_dir, comb)],
        'secondary_structrue': ["{}/secondary_structure/sequence_go_label_{}.txt".format(texts_base_dir, num_samples),
                                "{}/secondary_structure/{}->secondary_structure_pretrain_pro_embs_train.pkl".format(embs_base_dir, comb),
                                "{}/secondary_structure/{}->secondary_structure_pretrain_pro_embs_valid.pkl".format(embs_base_dir, comb)],
        'pdbbind': ["{}/pdbbind/sequence_go_label_{}.txt".format(texts_base_dir, num_samples),
                    "{}/pdbbind/{}->pdbbind_pretrain_pro_embs_train.pkl".format(embs_base_dir, comb),
                    "{}/pdbbind/{}->pdbbind_pretrain_pro_embs_dev.pkl".format(embs_base_dir, comb)],
        'kinase': ["{}/kinase/sequence_go_label_{}.txt".format(texts_base_dir, num_samples),
                    "{}/kinase//{}->kinase_pretrain_pro_embs_train.pkl".format(embs_base_dir, comb)]
    }

    # One pre-training dataset per label task in this combination; all of them
    # read the same text/embedding files and differ only in the label column.
    pt_datasets, pt_names = [], []
    for label_task in pt_tasks_combs[i]:
        dataset = EmbeddingDataset('pre_train', *tasks_texts_embs_hub['pre_train'], pt_label=label_task)
        pt_datasets.append(dataset)
        pt_names.append(label_task)

    # One dataset per downstream task, built from this combination's embeddings.
    ft_datasets, ft_names = [], []
    for task in ['stability', 'fluoresecence', 'secondary_structrue', 'pdbbind', 'kinase']: # keys of tasks_texts_embs_hub
        dataset = EmbeddingDataset(task, *tasks_texts_embs_hub[task])
        ft_datasets.append(dataset)
        ft_names.append(task.capitalize())

    # Position of each pre-training task inside this combination.
    pt_tasks_dict = {task: idx for idx, task in enumerate(pt_tasks_combs[i])}

    data_stuff = [pt_datasets, pt_names, ft_datasets, ft_names, pt_tasks_dict]
    data_stuffs_dict[comb] = data_stuff

Start loading data for Pretrain-MultiTask: MLM+RMD
PRE-TRAIN mlm unique labels: 15
{2: 148, 10: 117, 6: 116, 8: 101, 9: 98, 4: 95, 5: 59, 14: 54, 1: 49, 3: 47, 7: 46, 12: 45, 0: 45, 11: 13, 13: 12}
PRE-TRAIN domain unique labels: 13
{0: 173, 1: 165, 2: 121, 3: 89, 4: 88, 5: 64, 6: 59, 7: 53, 8: 50, 9: 47, 10: 46, 11: 45, 12: 45}
PRE-TRAIN motif unique labels: 5
{0: 843, 1: 59, 2: 50, 3: 47, 4: 46}
PRE-TRAIN region unique labels: 10
{0: 202, 1: 173, 2: 165, 3: 121, 4: 89, 5: 88, 6: 64, 7: 53, 8: 45, 9: 45}
Stability unique labels: 12
{0: 101, 1: 87, 2: 83, 3: 57, 4: 53, 5: 34, 6: 27, 7: 20, 8: 14, 9: 11, 10: 7, 11: 6}
Fluoresecence unique labels: 12
{0: 142, 1: 83, 2: 80, 3: 48, 4: 38, 5: 35, 6: 20, 7: 20, 8: 14, 9: 9, 10: 6, 11: 5}
Secondary_structrue unique labels: 15
{9: 55, 3: 50, 1: 46, 14: 45, 8: 42, 7: 38, 13: 38, 0: 32, 6: 30, 11: 28, 2: 25, 5: 23, 12: 19, 10: 17, 4: 12}
Pdbbind unique labels: 11
{0: 81, 1: 70, 2: 66, 3: 58, 4: 57, 5: 56, 6: 33, 7: 28, 8: 28, 9: 18, 10: 5}
Kinas

## H-Score

In [57]:
def getCov(X):
    """Return the sample covariance matrix of X's columns.

    Rows of X are observations. Each column is centred on its mean and the
    scatter matrix is divided by `len(X) - 1`.
    """
    centered = X - np.mean(X, axis=0, keepdims=True)
    scatter = centered.T.dot(centered)
    return scatter / (len(X) - 1)

def getHscore(f,Z):
    """Compute the H-score of features `f` with respect to class labels `Z`.

    The score is `trace(pinv(Cov(f)) @ Cov(g))`, where row i of `g` is the mean
    feature vector of the class that sample i belongs to.

    Args:
        f: 2-D array-like of features, one row per sample.
        Z: 1-D array-like of class ids, one per sample. It is compared
            element-wise with each class value, so it must hold class ids,
            not one-hot rows.

    Returns:
        The H-score as a scalar.
    """
    f = np.asarray(f)
    Z = np.asarray(Z)
    Covf = getCov(f)

    # Class means are fractional in general, so the buffer must be floating
    # point even for integer features (zeros_like would inherit f's dtype and
    # truncate the means). Floating inputs keep their own dtype.
    g_dtype = f.dtype if np.issubdtype(f.dtype, np.floating) else np.float64
    g = np.zeros(f.shape, dtype=g_dtype)
    for z in np.unique(Z):
        in_class = Z == z
        g[in_class] = np.mean(f[in_class, :], axis=0)

    Covg = getCov(g)
    # Pseudo-inverse, because Cov(f) may be singular.
    score = np.trace(np.dot(np.linalg.pinv(Covf, rcond=1e-15), Covg))
    return score

# Compute the H-score of every downstream dataset under each pre-training task
# combination. Rows of `Hscores` follow TaskCombs; columns follow the order of
# each combination's ft_datasets.
# NOTE(review): in the recorded output several scores are ~ (number of classes - 1),
# e.g. 11.0 for the 12 Stability labels and 14.0 for the 15 Secondary_structrue
# labels. That is presumably the value the score saturates at when the feature
# covariance is rank-deficient (fewer samples than embedding dimensions) --
# confirm the scores are informative before comparing combinations.
Hscores = []
for i, comb in enumerate(TaskCombs):
    print("Start computing for Pretrain-MultiTask: {}".format(comb))
    data_stuff = data_stuffs_dict[comb]
    # data_stuff layout: [pt_datasets, pt_names, ft_datasets, ft_names, pt_tasks_dict]
    _, _, ft_datasets, ft_names, _ = data_stuff
    comb_Hscores = []
    for j, ft_dataset in enumerate(ft_datasets):
        # Stack the per-protein embeddings into one 2-D array.
        dataset_emb = np.stack(ft_dataset.embs_and_labels['pro_emb'].tolist(), axis=0)
        dataset_label = ft_dataset.embs_and_labels['label'].values
        score = getHscore(dataset_emb, dataset_label)
        comb_Hscores.append(score)
        print("H-score for {}-{}: {}".format(comb, ft_names[j], score))
    Hscores.append(comb_Hscores)
    print('==== x ==== ' * 5)
Hscores = np.array(Hscores)

# np.save does not create missing directories; make sure the target exists so
# the cell also works on a fresh checkout.
hscore_path = "./results(final)/bl4_hscore.npy"
os.makedirs(os.path.dirname(hscore_path), exist_ok=True)
np.save(hscore_path, Hscores)

Start computing for Pretrain-MultiTask: MLM+RMD
H-score for MLM+RMD-Stability: 10.999999999999861
H-score for MLM+RMD-Fluoresecence: 11.00000000002128
H-score for MLM+RMD-Secondary_structrue: 13.999999999567322
H-score for MLM+RMD-Pdbbind: 9.975969795581932
H-score for MLM+RMD-Kinase: 0.8673495176481083
==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 
Start computing for Pretrain-MultiTask: GO+RMD
H-score for GO+RMD-Stability: 11.00000000000191
H-score for GO+RMD-Fluoresecence: 10.999999999550553
H-score for GO+RMD-Secondary_structrue: 13.99999999995598
H-score for GO+RMD-Pdbbind: 9.97596982859892
H-score for GO+RMD-Kinase: 0.9347650221534423
==== x ==== ==== x ==== ==== x ==== ==== x ==== ==== x ==== 
Start computing for Pretrain-MultiTask: MLM+GO+D
H-score for MLM+GO+D-Stability: 11.000000000001632
H-score for MLM+GO+D-Fluoresecence: 11.000000000177018
H-score for MLM+GO+D-Secondary_structrue: 13.999999999868336
H-score for MLM+GO+D-Pdbbind: 9.975969858724056
H-score for 