In [2]:
import os
import src.models_and_optimizers as model_utils
import yaml
from types import SimpleNamespace
from clip_main import get_wds_loaders
from transformers import EsmTokenizer
import src.data_utils as data_utils
import torch
import sys
import pickle
from tqdm import tqdm
import numpy as np
from tmtools.io import get_structure, get_residue_data
from tmtools import tm_align
import matplotlib.pyplot as plt
import json
from torch.cuda.amp import autocast
import tmscoring
import json
import copy
from scipy.stats.stats import pearsonr 
import matplotlib.pyplot as plt
import yaml
import pandas as pd
import glob
import webdataset as wds

2024-02-01 17:44:38.121286: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-01 17:44:38.243326: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-01 17:44:39.237567: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-02-01 17:44:39.237621: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

In [3]:
## Mask peptide sequence information
def mask_peptide(seq_batch, coords_batch, pdb):
    """Hide the peptide (shorter chain) sequence so only the peptide is supervised.

    The peptide residues in ``seq_batch['string_sequence'][0]`` are replaced by
    ``'X'``, and both ``seq_loss_mask`` tensors are set to False at every
    position except the peptide positions.

    Args:
        seq_batch: dict with ``'string_sequence'`` (list whose first item is the
            concatenated sequence string) and ``'seq_loss_mask'`` (two boolean
            tensors indexed as ``[:, position]``). Modified in place.
        coords_batch: dict whose ``'seq_lens'`` holds the chain lengths, either
            flat (``[len_a, len_b]``) or batched (``[[len_a, len_b]]``, the
            layout ``get_interaction_res`` reads).
        pdb: unused; kept for signature compatibility.

    Returns:
        The same ``seq_batch`` object.
    """
    lens = coords_batch['seq_lens']
    # Accept the batched layout ([[len_a, len_b]]) as well as a flat one.
    if len(lens) > 0 and (isinstance(lens[0], (list, tuple))
                          or (torch.is_tensor(lens[0]) and lens[0].dim() > 0)):
        lens = lens[0]
    lens = [int(n) for n in lens]
    peptide_len = min(lens)
    # The peptide sits at the end when the second chain is the shorter one
    # (or when there is a single chain), matching get_interaction_res.
    from_back = len(lens) < 2 or lens[1] == peptide_len

    sequence = seq_batch['string_sequence'][0]
    if from_back:
        # Explicit lengths instead of [:-peptide_len], which breaks for 0.
        keep = max(len(sequence) - peptide_len, 0)
        seq_batch['string_sequence'][0] = sequence[:keep] + 'X' * peptide_len
        for loss_mask in seq_batch['seq_loss_mask'][:2]:
            loss_mask[:, :max(loss_mask.shape[1] - peptide_len, 0)] = False
    else:
        seq_batch['string_sequence'][0] = 'X' * peptide_len + sequence[peptide_len:]
        for loss_mask in seq_batch['seq_loss_mask'][:2]:
            loss_mask[:, peptide_len:] = False
    return seq_batch

def mask_all(seq_batch, pdb):
    """Blank out the whole first sequence and disable both loss masks.

    Every character of ``seq_batch['string_sequence'][0]`` becomes ``'X'`` and
    all entries of ``seq_loss_mask[0]`` and ``seq_loss_mask[1]`` are set to
    False. ``seq_batch`` is modified in place and returned; ``pdb`` is unused.
    """
    original = seq_batch['string_sequence'][0]
    seq_batch['string_sequence'][0] = ''.join('X' for _ in original)
    for which in (0, 1):
        seq_batch['seq_loss_mask'][which][:, :] = False
    return seq_batch

## Extract kNN info
def extract_knn(X, eps, top_k):
    """Return the ``top_k`` nearest neighbours of every point (self included).

    Args:
        X: coordinates indexed as ``[batch, point, xyz]``; distances are taken
            over the last dimension.
        eps: small constant added under the square root so the self-distance
            stays strictly positive.
        top_k: neighbours per point; clamped to the number of points.

    Returns:
        ``(D_neighbors, E_idx)``: neighbour distances and indices, each of
        shape ``[batch, point, k]``, ordered from nearest to farthest.
    """
    # Pairwise Euclidean distances, shape [batch, point, point].
    dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
    D = torch.sqrt(torch.sum(dX ** 2, 3) + eps)
    # Every pair is valid. ones_like keeps the mask on D's device and dtype;
    # a bare torch.ones(...) lives on the CPU and fails for CUDA inputs.
    mask_2D = torch.ones_like(D)
    D = D * mask_2D

    # Push masked-out pairs past the row maximum so topk never picks them.
    D_max, _ = torch.max(D, -1, keepdim=True)
    D_adjust = D + (1. - mask_2D) * D_max
    k = min(top_k, D.shape[-1])
    D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)
    return D_neighbors, E_idx

## Find residues at the interface between 2 chains
def get_interaction_res(coords_batch, pdb, top_k, threshold=5, remove_far=False, mres=None, dist_threshold=None):
    """Return indices of residues at or near the interface between two chains.

    The seed residues are the peptide (the shorter chain), or ``mres`` when it
    is given. Each seed then adds its ``top_k`` nearest neighbours, measured on
    atom index 1 of every residue. Neighbours farther than ``dist_threshold``
    are skipped when that value is set.

    Args:
        coords_batch: dict with ``'seq_lens'`` (``seq_lens[0]`` holds the chain
            lengths) and ``'coords'`` (``coords[0]`` indexed as
            ``[batch, residue, atom, xyz]``).
        pdb: unused; kept for signature compatibility.
        top_k: neighbours added per seed residue; 0 returns only the seeds.
        threshold: with ``remove_far``, the minimum number of a residue's
            ``top_k`` neighbours that must lie on the other chain.
        remove_far: drop residues with fewer than ``threshold`` cross-chain
            neighbours. Requires two chains.
        mres: optional explicit seed residue indices (override the peptide).
        dist_threshold: optional maximum neighbour distance (ignored if falsy).

    Returns:
        Sorted list of unique residue indices as Python ints.
    """
    lens = [int(n) for n in coords_batch['seq_lens'][0]]
    protein_len, peptide_len = max(lens), min(lens)
    # The peptide is the second chain when it is the shorter one (or when there
    # is only one chain); otherwise it is the first chain.
    from_back = lens[1] == peptide_len if len(lens) > 1 else True
    if from_back:
        seeds = set(range(protein_len, protein_len + peptide_len))
    else:
        seeds = set(range(peptide_len))
    if mres is not None:
        seeds = {int(r) for r in mres}

    top_k = min(top_k, coords_batch['coords'][0].size(1))
    if top_k == 0:
        return sorted(seeds)

    D_neighbors, E_idx = extract_knn(coords_batch['coords'][0][:, :, 1, :], eps=1e-6, top_k=top_k)
    # Work with plain ints. 0-d tensors hash by identity, so mixing them with
    # ints in a set would not remove duplicates.
    neigh_dists = D_neighbors[0].tolist()
    neigh_idx = E_idx[0].tolist()

    interface = set(seeds)
    for res in seeds:
        for dist, neigh in zip(neigh_dists[res], neigh_idx[res]):
            if dist_threshold and dist > dist_threshold:
                continue
            interface.add(neigh)

    if remove_far:
        # Chain label per residue: 0 for the first chain, 1 for the second.
        chain_of = [0] * lens[0] + [1] * lens[1]

        def cross_chain_neighbours(res):
            other = 1 - chain_of[res]
            return sum(chain_of[n] == other for n in neigh_idx[res])

        interface = {res for res in interface if cross_chain_neighbours(res) >= threshold}

    return sorted(interface)

## Get distances between residues at the interface between 2 chains
def get_inter_dists(coords_batch, interaction_res, eps=1e-6):
    """Per-residue squared distance to the nearest residue of the other chain.

    Only residues in ``interaction_res`` get a value. The value is the minimum
    squared distance (atom index 1) to any residue of the opposite chain,
    divided by the largest such value. All other residues get 0.

    Args:
        coords_batch: dict with ``'chain_lens'`` (``chain_lens[0][0]`` is
            presumably a per-residue chain label with each chain contiguous;
            its run lengths give the chain sizes) and ``'coords'``
            (``coords[0]`` indexed as ``[batch, residue, atom, xyz]``).
        interaction_res: iterable of residue indices to score.
        eps: unused; kept for signature compatibility.

    Returns:
        1-D float numpy array with one entry per residue of the two chains,
        with values in [0, 1].
    """
    _, lens = torch.unique_consecutive(coords_batch['chain_lens'][0][0], return_counts=True)
    protein_len = int(torch.max(lens))
    peptide_len = int(torch.min(lens))
    total = protein_len + peptide_len
    # The peptide (shorter chain) comes second when the second run is the minimum.
    if int(torch.argmin(lens)) == 1:
        pep_res = list(range(protein_len, total))
        prot_res = list(range(protein_len))
    else:
        pep_res = list(range(peptide_len))
        prot_res = list(range(peptide_len, total))

    ca = coords_batch['coords'][0][:, :, 1, :]
    pep_coords = ca[:, pep_res, :]
    prot_coords = ca[:, prot_res, :]
    pep_set = set(pep_res)
    wanted = {int(r) for r in interaction_res}

    inter_dists = np.full(total, np.nan)
    for res in range(total):
        if res not in wanted:
            continue
        other = prot_coords if res in pep_set else pep_coords
        sq_dists = torch.sum((ca[:, res, :] - other) ** 2, dim=2)
        inter_dists[res] = torch.min(sq_dists).item()

    # Normalise by the largest finite value. Skip when there is none or it is
    # 0; those entries end up as 0 after nan_to_num, as with nanmax division.
    finite = inter_dists[~np.isnan(inter_dists)]
    max_val = finite.max() if finite.size else 0.0
    if max_val > 0:
        inter_dists = inter_dists / max_val
    return np.nan_to_num(inter_dists)
            
## Mask the structure of the peptide
def mask_peptide_struct(coord_data, coords_batch, pdb):
    """Zero the peptide residues in the GNN structure input.

    The peptide positions come from ``get_interaction_res(..., top_k=0)``.
    ``coord_data['x_mask']`` and ``coord_data['X']`` are multiplied in place by
    a mask that is 0 at those positions and 1 elsewhere. Returns ``coord_data``.
    """
    x_mask = coord_data['x_mask']
    peptide_positions = get_interaction_res(coords_batch, pdb, top_k=0)
    keep = torch.ones_like(x_mask)
    keep[:, peptide_positions] = 0
    coord_data['x_mask'] *= keep
    # Broadcast the per-residue mask over the trailing two dimensions of X.
    coord_data['X'] *= keep[..., None, None]
    return coord_data

## Get sequence and structure embeddings from RLA 
def _set_sequence_layout(seq_batch, valid):
    """Install position ids and validity masks for a rebuilt sequence string.

    ``valid`` is a 1-D bool tensor with one entry per character of the new
    ``seq_batch['string_sequence'][0]`` (True = real residue). Position ids run
    0..len-1, and the same mask is used as the placeholder and loss mask.
    """
    valid = valid.unsqueeze(0)
    pos_embs = torch.arange(valid.shape[1]).unsqueeze(0)
    seq_batch['pos_embs'] = [pos_embs, valid]
    seq_batch['placeholder_mask'] = [valid, valid]
    seq_batch['seq_loss_mask'] = [valid, valid]


def get_seq_and_struct_features(model, tokenizer, batch, pdb=None, seq_mask=None, struct_mask=None, focus=None, top_k=30, remove_far=False, mres=None,
                                prot_len=None, add_ends=False, ends_k=0, threshold=1, seq_only=False, dist_threshold=None):
    """Run the RLA model on one batch and collect per-residue features and masks.

    Args:
        model: callable ``model(text_inp, coord_data)`` returning
            ``(gnn_features, text_features, logit_scale)``.
        tokenizer: tokenizer applied to the sequence strings.
        batch: ``(seq_batch, coords_batch)`` pair. ``seq_batch`` is modified in
            place (sequence string, position ids and masks may be replaced).
        pdb: passed through to the masking helpers.
        seq_mask: ``'peptide'`` hides the peptide sequence, ``'all'`` hides the
            whole sequence, ``None`` applies no masking. Masking is applied to
            the raw sequence, before any linker or end padding is inserted.
        struct_mask: ``'peptide'`` zeroes the peptide in the structure input.
        focus: if truthy, ``focus_mask`` marks the residues returned by
            ``get_interaction_res``; otherwise it is all zeros.
        top_k, remove_far, mres, threshold, dist_threshold: forwarded to
            ``get_interaction_res`` when ``focus`` is set.
        prot_len: if given, a 25-residue poly-G linker is inserted after the
            first ``prot_len`` residues, and its features are removed again.
        add_ends: if True, ``ends_k`` ``'X'`` characters are added at both ends
            and excluded through the placeholder mask.
        seq_only: run the model without structure input.

    Returns:
        dict of features and masks (see the inline comments on each key).
    """
    seq_batch, coords_batch = batch
    linker_len = 25

    # Mask first so 'X' replacement targets the real residues, not padding.
    if seq_mask == 'peptide':
        seq_batch = mask_peptide(seq_batch, coords_batch, pdb)
    elif seq_mask == 'all':
        seq_batch = mask_all(seq_batch, pdb)

    if prot_len is not None:
        sequence = seq_batch['string_sequence'][0]
        linked = sequence[:prot_len] + linker_len * 'G' + sequence[prot_len:]
        seq_batch['string_sequence'][0] = linked
        _set_sequence_layout(seq_batch, torch.ones(len(linked), dtype=torch.bool))
    if add_ends:
        sequence = seq_batch['string_sequence'][0]
        padded = ends_k * 'X' + sequence + ends_k * 'X'
        seq_batch['string_sequence'][0] = padded
        valid = torch.zeros(len(padded), dtype=torch.bool)
        # Explicit end index: [ends_k:-ends_k] is empty when ends_k == 0.
        valid[ends_k:ends_k + len(sequence)] = True
        _set_sequence_layout(seq_batch, valid)

    # max_length: presumably 1024 residues plus two special tokens — confirm.
    text_inp = tokenizer(seq_batch['string_sequence'], return_tensors='pt', padding=True, truncation=True, max_length=1024 + 2)
    text_inp['position_ids'] = seq_batch['pos_embs'][0]
    text_inp = {k: v.to('cuda') for k, v in text_inp.items()}
    if seq_only:
        gnn_features, text_features, logit_scale = model(text_inp, None)
    else:
        coord_data = data_utils.construct_gnn_inp(coords_batch, device='cuda', half_precision=True)
        if struct_mask == 'peptide':
            coord_data = mask_peptide_struct(coord_data, coords_batch, pdb)
        gnn_features, text_features, logit_scale = model(text_inp, coord_data)

    new_text_features, _, new_text_mask = data_utils.postprocess_text_features(
        text_features=text_features,
        inp_dict=text_inp,
        tokenizer=tokenizer,
        placeholder_mask=seq_batch['placeholder_mask'][0])
    if prot_len is not None:
        # Drop the linker positions again.
        new_text_features = torch.cat((new_text_features[:, :prot_len, :], new_text_features[:, prot_len + linker_len:, :]), dim=1)
        new_text_mask = torch.cat((new_text_mask[:, :prot_len], new_text_mask[:, prot_len + linker_len:]), dim=1)

    valid_coords = coords_batch['coords'][1]
    focus_mask = torch.zeros(valid_coords.shape).to(dtype=valid_coords.dtype, device=new_text_features.device)
    if focus:
        focus_res = get_interaction_res(coords_batch, pdb, top_k, remove_far=remove_far, threshold=threshold, mres=mres, dist_threshold=dist_threshold)
        focus_mask[:, focus_res] = True
    return {
        'text': new_text_features, # text feature
        'gnn': gnn_features, # gnn feature
        'seq_mask_with_burn_in': seq_batch['seq_loss_mask'][0], # sequence mask of what's supervised
        'coord_mask_with_burn_in': coords_batch['coords_loss_mask'][0], # coord mask of what's supervised
        'seq_mask_no_burn_in': new_text_mask.bool(), # sequence mask of what's valid (e.g., not padded)
        'coord_mask_no_burn_in': valid_coords, # coord mask of what's valid
        'focus_mask': focus_mask, # focus mask of what residues to use to calculate score
    }

## Calculate RLA score
# Module-level cosine-similarity module; calc_sim below does not use it
# (kept in case other cells reference `cos`).
cos = torch.nn.CosineSimilarity()

def calc_sim(all_outputs):
    """Mean per-residue dot product between sequence and structure features.

    For each output dict, the rows of ``output['text']`` selected by
    ``seq_mask_no_burn_in`` are paired position by position with the rows of
    ``output['gnn']`` selected by ``coord_mask_no_burn_in``. The score is the
    mean of their dot products, so both masks must select the same number of
    rows.

    Returns:
        list with one 0-d tensor score per output (NaN if nothing is selected).
    """
    all_sims = []
    for output in all_outputs:
        t = output['text'][output['seq_mask_no_burn_in']]
        g = output['gnn'][output['coord_mask_no_burn_in']]
        # Row-wise dot product, equivalent to (t[:, None, :] @ g[:, :, None]).squeeze().
        sim = (t * g).sum(dim=-1)
        all_sims.append(torch.mean(sim))
    return all_sims

# Collect the pdb_id of every sample in a WebDataset.
def get_decoy_ids(wds_path):
    """Return the ``pdb_id`` of every sample in the WebDataset at ``wds_path``.

    Each sample is decoded and its ``'inp.pyd'`` entry is read. The first
    element of that entry's ``'pdb_id'`` field is collected, in dataset order.
    """
    cols = ['inp.pyd']
    dataset = wds.WebDataset(wds_path).decode().to_tuple(*cols)
    return [sample[0]['pdb_id'][0] for sample in dataset]


In [None]:
## ESM mutation analysis functions
def get_muts(wt, mut):
    """Compare two aligned sequences position by position.

    Returns ``(inds, muts)``: the indices where ``mut`` differs from ``wt`` and
    the ``mut`` character at each of those indices. The comparison stops at
    the end of the shorter sequence.
    """
    diffs = [(pos, m) for pos, (w, m) in enumerate(zip(wt, mut)) if w != m]
    positions = [pos for pos, _ in diffs]
    residues = [m for _, m in diffs]
    return positions, residues

def score_mut(wt, idx, mt, token_probs, alphabet):
    """Log-probability of the mutant minus the wild type at sequence index ``idx``.

    ``token_probs`` is indexed as ``[0, token_position, vocab]``. Token position
    ``idx + 1`` holds sequence index ``idx`` (the offset skips the BOS token).
    Returns the difference as a Python float.
    """
    wt_token = alphabet.get_idx(wt[idx])
    mt_token = alphabet.get_idx(mt)
    row = token_probs[0, idx + 1]
    return (row[mt_token] - row[wt_token]).item()

def score_protein(idxs, mts, wt_seq, model, alphabet, device='cuda'):
    """Score mutants of ``wt_seq`` with masked-marginal log-probabilities.

    Every token position of the converted wild-type batch is masked in turn,
    and the model's log-softmax at that position is recorded. Each mutant's
    score is the sum of ``score_mut`` over its mutations.

    Args:
        idxs: per mutant, the list of mutated sequence indices.
        mts: per mutant, the list of mutant residues (aligned with ``idxs``).
        wt_seq: wild-type sequence string.
        model: callable with ``model(tokens)["logits"]`` indexed as
            ``[batch, position, vocab]`` (assumed; matches how it is sliced).
        alphabet: object providing ``get_batch_converter()``, ``mask_idx`` and
            ``get_idx()``.
        device: device the masked tokens are moved to. It should match the
            model's device; the default keeps the previous ``.cuda()`` behaviour.

    Returns:
        list with one float per mutant, or ``np.nan`` for a mutant with no
        mutated indices.
    """
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein1", wt_seq)])

    all_token_probs = []
    with torch.no_grad():
        for i in tqdm(range(batch_tokens.size(1))):
            batch_tokens_masked = batch_tokens.clone()
            batch_tokens_masked[0, i] = alphabet.mask_idx
            logits = model(batch_tokens_masked.to(device))["logits"]
            # Keep only the distribution at the masked position.
            all_token_probs.append(torch.log_softmax(logits, dim=-1)[:, i])
    token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)

    esm_predictions = []
    for idx_list, mut_list in zip(idxs, mts):
        if len(idx_list) == 0:
            esm_predictions.append(np.nan)
            continue
        esm_predictions.append(sum(
            score_mut(wt_seq, idx, mut, token_probs, alphabet)
            for idx, mut in zip(idx_list, mut_list)
        ))
    return esm_predictions
