In [None]:
%load_ext autoreload
import seaborn as sns
import numpy as np
import pandas as pd
import torch 
from torch import nn
from equiformer_pytorch import Equiformer
from data import PDB_DATASET, load_case_list
import os
import pickle

In [None]:
# Load the evaluation index. Later cells iterate it as
# eval_dict[rna][gt] -> iterable of ligand identifiers.
# NOTE(review): the path literal is empty in this copy of the notebook; set it
# to the evaluation pickle before running.
# SECURITY: pickle.load can execute arbitrary code while deserializing, so only
# open files that come from a trusted source.
with open("", 'rb') as eval_file:  # context manager closes the handle deterministically
    eval_dict = pickle.load(eval_file)

In [None]:
class RMS_predictor(nn.Module):
    """Equiformer backbone followed by a small MLP head that emits one value per input.

    The backbone's ``type0`` output is averaged over dim 1 and pushed through
    Linear(16 -> 256) -> ReLU -> Linear(256 -> 1).

    NOTE(review): the average runs over every position along dim 1 and does not
    consult ``mask``; presumably dim 1 is the atom axis, so padded atoms would be
    included when a mask is supplied - confirm this matches how the model was trained.
    """

    def __init__(self):
        super().__init__()
        # Token vocabulary size: one id per supported element symbol.
        atom_vocab = ('N', 'O', 'P', 'C', 'S', 'F', 'CL', 'BR')
        self.model = Equiformer(
            num_tokens=len(atom_vocab),
            dim=(16, 8, 8),               # feature width per degree; one entry per degree
            dim_head=(16, 8, 8),          # width of each attention head, per degree
            heads=(1, 1, 1),              # attention heads, per degree
            num_degrees=3,                # must equal the length of the tuples above
            depth=3,                      # number of equivariant transformer layers
            attend_self=True,             # each token also attends to itself
            reduce_dim_out=False,         # keep the full output width (no reduction to 1)
            dot_product_attention=True,   # False would switch to MLP attention
            num_neighbors=42,
        )
        # Regression head; 16 matches the degree-0 width configured in `dim`.
        self.mlp_1 = nn.Linear(16, 256)
        self.relu = nn.ReLU()
        self.mlp_2 = nn.Linear(256, 1)

    def forward(self, feature, coord, mask=None):
        """Run the backbone, mean-pool its type0 features over dim 1, apply the MLP head.

        Args:
            feature: token ids forwarded to the Equiformer as its first argument.
            coord: coordinates forwarded to the Equiformer as its second argument.
            mask: optional mask forwarded to the Equiformer as its third argument.

        Returns:
            Output of the final Linear layer (last dimension has size 1).
        """
        invariant = self.model(feature, coord, mask).type0
        pooled = invariant.mean(dim=1)
        hidden = self.relu(self.mlp_1(pooled))
        return self.mlp_2(hidden)

In [None]:
def parse_pdb(pdb_path):
    """Extract heavy-atom coordinates, element ids and the optional rms score from a PDB file.

    ATOM and HETATM records are read using fixed columns: characters 31-54 hold
    x/y/z and characters 77-78 hold the element symbol. Hydrogens ('H') and
    dummy atoms ('Z') are skipped. A line starting with ``'rms '`` (with the
    trailing space, so ``rms_stem`` does not match) supplies the score.

    Args:
        pdb_path: path of the PDB text file to read.

    Returns:
        dict with keys
            "cord":     tensor of shape (num_atom, 3) with the kept coordinates,
            "ele":      integer tensor of shape (num_atom,) indexing the element vocabulary,
            "path":     ``pdb_path`` unchanged,
            "rms":      0-dim tensor parsed from the ``rms`` line, or None if absent,
            "num_atom": number of kept atoms.

    Raises:
        KeyError: an atom's element is outside the supported vocabulary; the
            message names the element, the file and the line number.
        RuntimeError: the file contains no usable atom record (same exception
            type the previous ``torch.stack`` on an empty list produced, but
            with the offending path in the message).
        ValueError: a coordinate or rms field cannot be parsed as a float.
    """
    element_map = {e: i for i, e in enumerate(['N', 'O', 'P', 'C', 'S', 'F', 'CL', 'BR'])}
    ignored_elements = ('H', 'Z')  # hydrogens and dummy atoms

    coord = []
    ele = []
    score = None
    # Stream the file line by line instead of materializing it with readlines().
    with open(pdb_path, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            if line.startswith(('ATOM', 'HETATM')):
                atom_ele = line[76:78].strip().upper()
                if atom_ele in ignored_elements:
                    continue
                xyz = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                if atom_ele not in element_map:
                    raise KeyError(
                        f"unsupported element {atom_ele!r} in {pdb_path} (line {line_no})"
                    )
                ele.append(element_map[atom_ele])
                coord.append(xyz)
            elif line[:4] == 'rms ':  # not rms_stem
                score = torch.tensor(float(line[4:].strip()))

    if not coord:
        raise RuntimeError(f"no ATOM/HETATM heavy-atom records found in {pdb_path}")

    # One tensor construction for all atoms rather than one tiny tensor per atom.
    return {"cord": torch.tensor(coord), "ele": torch.tensor(ele), "path": pdb_path,
            "rms": score, "num_atom": len(ele)}

In [None]:
# Location of the trained RMS_predictor weights (a state_dict written by torch.save).
# NOTE(review): the path is empty in this copy of the notebook; fill it in before running.
model_path = ""

# Build the network on the GPU and switch it to evaluation mode right away, so
# any train-only behaviour inside the backbone (e.g. dropout) is disabled for
# the inference cells that follow.
model = RMS_predictor().cuda().eval()

# map_location="cpu" deserializes the tensors on the host regardless of which
# CUDA device index they were saved from; load_state_dict then copies them into
# the parameters that already live on the GPU.
# SECURITY: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location="cpu"))


In [None]:
# Predicted RMS for every structure, laid out as res_dict[rna][gt][ligand] -> float.
res_dict = {}

# Entries that are left out of the evaluation.
# NOTE(review): the reason these ids are excluded is not recorded in this notebook.
skipped_rnas = ["6cb3", "3f30", "3f2w", "4k32", "7edl", "2yie", "4p3s"]

# Only the scalar prediction is kept, so run every forward pass without
# building an autograd graph (saves GPU memory and time).
with torch.no_grad():
    for rna in eval_dict:
        if rna in skipped_rnas:
            continue
        res_dict[rna] = {}
        print(rna)
        for gt in eval_dict[rna]:
            res_dict[rna][gt] = {}
            for ligand in eval_dict[rna][gt]:
                # NOTE(review): the path template is blank in this copy of the
                # notebook; it has to resolve to the PDB file for (rna, gt, ligand).
                structure_path = f""
                input_data = parse_pdb(structure_path)
                # unsqueeze(0) adds a leading batch dimension of size 1.
                ele = input_data['ele'].cuda().unsqueeze(0)
                cord = input_data['cord'].cuda().unsqueeze(0)
                pred_rms = model(ele, cord).item()
                res_dict[rna][gt][ligand] = pred_rms


In [None]:
# For every (rna, gt) pair, rank the ligands by predicted RMS, lowest first.
# Each comp_list entry is [gt, ligand names in ranked order].
comp_list = []
for rna, per_gt in res_dict.items():
    for gt, ligand_scores in per_gt.items():
        for ligand, predicted in ligand_scores.items():
            print(rna, gt, ligand, predicted)
        # Sort on (score, name) so equal scores fall back to comparing names.
        ranked = sorted(ligand_scores.items(), key=lambda entry: (entry[1], entry[0]))
        comp_list.append([gt, [name for name, _ in ranked]])

In [None]:
# Top-n success rate: the fraction of comp_list entries whose first element
# (gt) appears among the first n names of the ranked list, for n = 1..10.
res = {}
total_cases = len(comp_list)
for top_n in range(1, 11):
    hits = sum(1 for entry in comp_list if entry[0] in entry[1][:top_n])
    res[top_n] = hits / total_cases