In [1]:
import os
import sys
import warnings
from itertools import groupby
from multiprocessing import Pool

import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from scipy.spatial.distance import cdist, squareform
from sklearn.manifold import trustworthiness
from umap import UMAP

import torch

sys.dont_write_bytecode = True
np.set_printoptions(precision=6, suppress=True)

from my_library import Database, Metrics, ESM_Representations, gzip_tensor, neighbor_joining, cophenetic_distmat, silhouette

RANDOM_SEED = 420

# Allocate resources

- **THREADS** : this variable will be passed to `multiprocessing.Pool`


In [2]:
THREADS = 2

# Define input and output files

- **DB_FILE** : (input) sqlite database containing protein sequence embedding
- **LABEL_CSV** : (input, optional) csv containing labels for calculating silhouette score
- **OUTPUT_DIR** : (output) directory for storing results

If you do not wish to provide an optional file, set it equal to an empty string like so `LABEL_CSV = ''`


In [3]:
# for the phosphatase dataset
DB_FILE    = 'datasets/phosphatase/phosphatase.db'
LABEL_CSV  = 'datasets/phosphatase/phosphatase_labels.csv'
OUTPUT_DIR = 'datasets/phosphatase/phosphatase_models'

# # for the kinase dataset
# DB_FILE    = 'datasets/protein_kinase/kinase.db'
# LABEL_CSV  = 'datasets/protein_kinase/kinase_labels.csv'
# OUTPUT_DIR = 'datasets/protein_kinase/kinase_models'

# # for the radical sam dataset
# DB_FILE    = 'datasets/radical_sam/radicalsam.db'
# LABEL_CSV  = 'datasets/radical_sam/radicalsam_labels.csv'
# OUTPUT_DIR = 'datasets/radical_sam/radicalsam_models'


# Set up variables

In [6]:
# load the database
db = Database(DB_FILE)

# load columns from the database as numpy arrays
_format = lambda x: (x['header'], x['sequence'], gzip_tensor(x['embedding']).numpy())
headers, sequences, embeddings = zip(*(_format(i) for i in db.retrieve()))

headers    = np.array(headers   , dtype=object)
accessions = np.array([i.split()[0] for i in headers], dtype=object)
sequences  = np.array(sequences , dtype=object)
embeddings = np.array(embeddings, dtype=object)

# if labels are available load them from csv
if os.path.exists(LABEL_CSV):
    labels = dict(pd.read_csv(LABEL_CSV)[['accession','label']].fillna('').values)
    labels = np.array([labels[i] for i in accessions], dtype=object)
else:
    labels = []

# representations and metrics can be added and removed from this list
representations = {
    'beginning_of_sequence'  : ESM_Representations.beginning_of_sequence,
    'end_of_sequence'        : ESM_Representations.end_of_sequence,
    'mean_of_special_tokens' : ESM_Representations.mean_special_tokens,
    'mean_of_residue_tokens' : ESM_Representations.mean_residue_tokens,
    }

metrics = {
    'cosine'        : Metrics.cosine,
    'euclidean'     : Metrics.euclidean,
    'manhattan'     : Metrics.manhattan,
    'ts_ss'         : Metrics.ts_ss,
    # 'jensenshannon' : Metrics.jensenshannon,
    }


# Initialize distance matrices

In [7]:
# create base directory
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)
if not os.path.exists(f'{OUTPUT_DIR}/fixedsize'):
    os.mkdir(f'{OUTPUT_DIR}/fixedsize')

models = []
for _rep in representations:
    rep = np.array([representations[_rep](i) for i in embeddings])
    np.savez_compressed(f'{OUTPUT_DIR}/fixedsize/{_rep}.npz', 
                        **{'headers':headers, 'embedding':rep})
    for _met in metrics:
        distmat = metrics[_met](rep,rep)
        models += [{
            'representation'          : _rep,
            'metric'                  : _met,
            'distmat(repr)'           : distmat}]
        if len(labels) == distmat.shape[0]:
            models[-1]['silhouette(repr)'] = silhouette(distmat, labels)
        
        sys.stderr.write(f'Calculating distances using "{_rep}" with "{_met}"\n')

Calculating distances using "beginning_of_sequence" with "cosine"
Calculating distances using "beginning_of_sequence" with "euclidean"
Calculating distances using "beginning_of_sequence" with "manhattan"
Calculating distances using "beginning_of_sequence" with "ts_ss"
Calculating distances using "end_of_sequence" with "cosine"
Calculating distances using "end_of_sequence" with "euclidean"
Calculating distances using "end_of_sequence" with "manhattan"
Calculating distances using "end_of_sequence" with "ts_ss"
Calculating distances using "mean_of_special_tokens" with "cosine"
Calculating distances using "mean_of_special_tokens" with "euclidean"
Calculating distances using "mean_of_special_tokens" with "manhattan"
Calculating distances using "mean_of_special_tokens" with "ts_ss"
Calculating distances using "mean_of_residue_tokens" with "cosine"
Calculating distances using "mean_of_residue_tokens" with "euclidean"
Calculating distances using "mean_of_residue_tokens" with "manhattan"
Calcul

# Perform UMAP (with densmap)

In [8]:
for n, i in enumerate(models):
    sys.stderr.write(f'{1+n} / {len(models)}\r')
    umap = UMAP(n_components=2,
                densmap=True,
                metric='precomputed',
                output_metric='euclidean',
                random_state=RANDOM_SEED)
    
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        _viz = umap.fit_transform(i['distmat(repr)']).astype(np.float16)
        _distmat = cdist(_viz, _viz, metric='euclidean')
    _trustworthiness = trustworthiness(distmat, _distmat, n_neighbors=10, metric='precomputed')
    _spearman = spearmanr(squareform(distmat, checks=False), squareform(_distmat, checks=False))[0]
    models[n]['viz(densmap)'            ] = _viz
    models[n]['distmat(densmap)'        ] = _distmat
    models[n]['spearman(densmap)'       ] = _spearman
    models[n]['trustworthiness(densmap)'] = _trustworthiness
    if len(labels) == _distmat.shape[0]:
        models[n]['silhouette(densmap)'] = silhouette(_distmat, labels)


16 / 16

# Perform Neighbor Joining

In [9]:
def _do_nj(distmat, names, labels):
    newick    = neighbor_joining(distmat, names)
    _distmat  = cophenetic_distmat(newick, names=names)
    _spearman = spearmanr(squareform(distmat, checks=False), squareform(_distmat, checks=False))[0]
    _trustworthiness = trustworthiness(distmat, _distmat, n_neighbors=10, metric='precomputed')
    out = {'viz(nj)': newick, 'distmat(nj)': _distmat, 'spearman(nj)': _spearman, 'trustworthiness(nj)':_trustworthiness}
    if len(labels) == _distmat.shape[0]:
        out['silhouette(nj)'] = silhouette(_distmat, labels)
    return out

_queue = ((i['distmat(repr)'], headers, labels) for i in models)
pool   = Pool(THREADS)
_out   = pool.starmap(_do_nj, _queue)
pool.terminate()

for n, i in enumerate(_out):
    models[n]['viz(nj)'            ] = i['viz(nj)'            ]
    models[n]['distmat(nj)'        ] = i['distmat(nj)'        ]
    models[n]['spearman(nj)'       ] = i['spearman(nj)'       ]
    models[n]['trustworthiness(nj)'] = i['trustworthiness(nj)']
    if 'silhouette(nj)' in i:
        models[n]['silhouette(nj)'] = i['silhouette(nj)']
        

# Write to output directory

In [10]:
# create base directory
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

# export each model
if not os.path.exists(f'{OUTPUT_DIR}/models'):
    os.mkdir(f'{OUTPUT_DIR}/models')
for i in models:
    npz_file = f'{OUTPUT_DIR}/models/{i["representation"]}_{i["metric"]}.npz'
    np.savez_compressed(npz_file, **i)

# export sequence information
np.savez_compressed(f'{OUTPUT_DIR}/sequences.npz', **{'headers':headers,'sequences':sequences})
with open(f'{OUTPUT_DIR}/sequences.fa', 'w') as w:
    w.write('\n'.join(f'>{i}\n{j}\n' for i, j in zip(headers,sequences)))

# export summary
cols = ['representation','metric','silhouette(repr)','spearman(densmap)','trustworthiness(densmap)','silhouette(densmap)','spearman(nj)','trustworthiness(nj)','silhouette(nj)']
cols = [i for i in cols if i in models[0].keys()]
df = pd.DataFrame([[j[k] for k in cols] for j in models], columns=cols)
df.to_csv(f'{OUTPUT_DIR}/summary.csv', index=False)
display(df)

Unnamed: 0,representation,metric,silhouette(repr),spearman(densmap),trustworthiness(densmap),silhouette(densmap),spearman(nj),trustworthiness(nj),silhouette(nj)
0,beginning_of_sequence,cosine,0.275104,0.294364,0.940745,-0.046344,0.862241,0.899771,0.353073
1,beginning_of_sequence,euclidean,0.193203,0.223973,0.919811,-0.086645,0.846555,0.916414,0.241617
2,beginning_of_sequence,manhattan,0.206233,0.239239,0.928262,0.003195,0.861505,0.910751,0.233878
3,beginning_of_sequence,ts_ss,0.285098,0.334057,0.9456,0.066892,0.854308,0.904702,0.308574
4,end_of_sequence,cosine,0.188625,0.434386,0.873914,-0.115854,0.87154,0.935432,0.190868
5,end_of_sequence,euclidean,0.155369,0.335323,0.872481,-0.24295,0.860459,0.947444,0.167513
6,end_of_sequence,manhattan,0.158789,0.444178,0.870742,-0.106907,0.88117,0.948918,0.173888
7,end_of_sequence,ts_ss,0.185571,0.443043,0.874177,-0.121271,0.883734,0.930509,0.184211
8,mean_of_special_tokens,cosine,0.260719,0.419656,0.889148,-0.047834,0.885202,0.922744,0.256475
9,mean_of_special_tokens,euclidean,0.18646,0.395369,0.877942,-0.052249,0.854755,0.931581,0.196451
