In [2]:
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm

In [3]:
# Per-sequence metadata table; later cells read its 'seq_id', 'test_fold'
# and 'defensive' columns.
model_seq_info = pd.read_parquet('../data3/interim/model_seq_info.pq')

# Embedding table, restricted to rows whose index value occurs as a seq_id
# in the metadata table.
embedding_df = (
    pd.read_parquet('../data3/interim/model_neighbor_representations.pq')
    .loc[lambda reps: reps.index.isin(model_seq_info['seq_id']), :]
)

In [4]:
def sim_matrix(a, b, eps=1e-8):
    """
    Pairwise cosine similarity between the rows of two 2-D tensors.

    Every row of `a` and of `b` is divided by its L2 norm, with the norm
    floored at `eps` so an all-zero row yields zeros instead of a division
    by zero. Entry (i, j) of the result is the dot product of normalised
    row i of `a` with normalised row j of `b`.

    Parameters
    ----------
    a : torch.Tensor of shape (n, d)
    b : torch.Tensor of shape (m, d)
    eps : float
        Lower bound applied to each row norm before dividing.

    Returns
    -------
    torch.Tensor of shape (n, m)
    """
    def _unit_rows(mat):
        # keepdim=True keeps the norms as a column so the division broadcasts per row.
        row_norms = mat.norm(dim=1, keepdim=True)
        return mat / row_norms.clamp(min=eps)

    return torch.mm(_unit_rows(a), _unit_rows(b).t())

In [5]:
# Number of training-embedding rows compared against the test fold in each
# batch of the similarity loop.
chunk_size = 1_000

In [6]:
# For every cross-validation fold, find for each held-out (test) sequence the
# most cosine-similar sequence among the remaining (training) sequences.
#
# The full (n_train, n_test) similarity matrix is never materialised: only a
# running per-test-sequence maximum and its training-row index are kept on
# the compute device, updated one chunk of training rows at a time.

# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fold_out_list = []
for fold, fold_df in model_seq_info.groupby('test_fold'):
    print(fold)

    # Split the embedding table into this fold's test rows and everything else.
    in_test_fold = embedding_df.index.isin(fold_df['seq_id'])
    train_rows = embedding_df.loc[~in_test_fold, :]
    test_rows = embedding_df.loc[in_test_fold, :]
    train_seq_ids = train_rows.index
    test_seq_ids = test_rows.index
    train_embeddings = train_rows.to_numpy()
    test_embeddings = test_rows.to_numpy()

    B = torch.from_numpy(test_embeddings).to(device)
    n_test = B.shape[0]

    # Running best similarity / best training-row index for each test row.
    best_sim = torch.full((n_test,), float('-inf'), dtype=B.dtype, device=device)
    best_idx = torch.zeros(n_test, dtype=torch.long, device=device)

    with torch.no_grad():
        for start in tqdm(range(0, len(train_embeddings), chunk_size), position=0):
            A = torch.from_numpy(train_embeddings[start:start + chunk_size, :]).to(device)
            # Best match inside this chunk for every test row.
            chunk_best, chunk_arg = sim_matrix(A, B).max(dim=0)
            # Strict '>' keeps the earliest training row on ties; the isnan
            # term lets a NaN similarity propagate to the result, as a plain
            # max over the full matrix would.
            improved = (chunk_best > best_sim) | (torch.isnan(chunk_best)
                                                  & ~torch.isnan(best_sim))
            best_sim = torch.where(improved, chunk_best, best_sim)
            # chunk_arg is relative to the chunk; shift to a global training-row index.
            best_idx = torch.where(improved, chunk_arg + start, best_idx)

    # Report similarities as float64 (widening from float32 is exact).
    max_cosine = best_sim.cpu().numpy().astype(np.float64)
    max_idx = best_idx.cpu().numpy()
    max_id = train_seq_ids[max_idx]
    fold_out = pd.DataFrame({'cosine_similarity': max_cosine,
                             'nearest_neighbor': max_id},
                            index=test_seq_ids)
    fold_out_list.append(fold_out)

0


100%|██████████| 157/157 [00:25<00:00,  6.27it/s]


1


100%|██████████| 155/155 [00:25<00:00,  5.99it/s]


2


100%|██████████| 164/164 [00:21<00:00,  7.46it/s]


3


100%|██████████| 162/162 [00:23<00:00,  6.96it/s]


4


100%|██████████| 167/167 [00:21<00:00,  7.72it/s]


In [7]:
# Stack the per-fold nearest-neighbour tables (one per test fold) into a single frame.
cosine_similarity_df = pd.concat(fold_out_list)

In [8]:
# Turn each nearest-neighbour hit into a signed prediction: the neighbour's
# 'defensive' label is mapped to a direction (label*2 - 1, i.e. 1 -> +1 and
# 0 -> -1) and multiplied by the cosine similarity to that neighbour.
#
# Only the key and label columns of the metadata are joined, so unrelated
# metadata columns cannot collide with the left-hand columns. The frame is
# built in one chain with .assign, which avoids assigning a column onto a
# column-subset (the pattern that triggers SettingWithCopyWarning).
out_df = (cosine_similarity_df.reset_index()
          .rename(columns={'index': 'seq_id'})
          .merge(model_seq_info[['seq_id', 'defensive']]
                 .rename(columns={'seq_id': 'nearest_neighbor'}),
                 how='inner', on='nearest_neighbor')
          .assign(direction=lambda df: df['defensive']*2 - 1)
          .assign(prediction=lambda df: df['direction']*df['cosine_similarity'])
          [['seq_id', 'prediction']]
          .assign(method='ESM2 150M nearest neighbor'))

In [9]:
# Write the predictions to the interim data directory; index=False leaves the
# DataFrame index out of the file.
out_df.to_parquet('../data3/interim/cv_predictions_esm2_150M.pq', index=False)