In [1]:
import pandas as pd
import numpy as np
import gemmi
import reciprocalspaceship as rs
from tqdm import tqdm
import glob, os
import matplotlib.pyplot as plt

### 1. Read in the ligand smile log

In [2]:
lig_log_path = "./vatd_smiles_cleaned.csv"
lig_log = pd.read_csv(lig_log_path)

### 2. bound models path

In [3]:
bound_model_root = "./VatD_PanDDaHits_Modelled/"

In [4]:
folder_paths = glob.glob(os.path.join(bound_model_root, "*/"))

In [5]:
bound_ids = [int(i.split('/')[-2][1:]) for i in folder_paths]

### 3. bound model smiles

In [7]:
def withheavy(smile, heavy=["F", "Cl", "S", "I", "Br", "P"]):
    return any(sub in smile for sub in heavy)

In [6]:
bound_smiles = lig_log[lig_log["id"].apply(lambda x: x in bound_ids)].copy().reset_index(drop=True)

In [8]:
bound_smiles["heavy"] = bound_smiles["smile"].apply(withheavy).copy() 

In [9]:
bound_smiles_with_heavy = bound_smiles[bound_smiles["heavy"]].copy().reset_index(drop=True)

In [10]:
bound_smiles_with_heavy

Unnamed: 0,code,smile,id,heavy
0,vat-d234,NS(=O)(=O)C=1C=CC(=CC1)C(=O)O,234,True
1,vat-d474,OCCC=1C=CC(F)=CC1,474,True
2,vat-d538,NS(=O)(=O)C=1C=C(Cl)C=C(Cl)C1,538,True
3,vat-d545,ClC=1C=CC=2NC(=O)C(=O)C2C1,545,True
4,vat-d410,NC=1C=C(F)C=CC1C(=O)O,410,True
5,vat-d429,OC(=O)CN1N=C(C=2CCCCC21)C(F)(F)F,429,True


### 4. Valdo diff maps

In [11]:
valdo_mtz_path = "/n/holyscratch01/hekstra_lab/dhekstra/valdo-vatD/pipeline/vae/reconstructed_w_phases/"
diff_column = "WDF"
phase_column = "refine_PH2FOFCWT"

### 5. Production run

In [12]:
def get_heavy_atoms(st, selections='[F,CL,Br,S,I,P]'):
    sel = gemmi.Selection(selections)
    sel_model = sel.copy_model_selection(st[0])
    lig_heavy_atoms = [i.atom.clone() for i in list(sel_model.all()) if i.residue.name == 'LIG']
    return lig_heavy_atoms

def get_peak_values(atom_list, real_grid):
    # check the highest peak
    a,b,c = np.unravel_index(real_grid.array.argmax(), real_grid.array.shape)
    tmp = real_grid.get_fractional(a,b,c)
    peak_pos = mtz.cell.orthogonalize(gemmi.Fractional(tmp.x, tmp.y, tmp.z))
    dis_lists = []
    peak_values = []
    ops = real_grid.spacegroup.operations()
    for atom in atom_list:
        # Get all equivalent sites
        eq_points = []
        dis_list = []
        for op in ops:
            SG_mapped=op.apply_to_xyz(real_grid.unit_cell.fractionalize(atom.pos).tolist())
            tmp = SG_mapped-np.floor(np.array(SG_mapped)) # Move into cell
            SG_mapped = gemmi.Fractional(*tmp)
            # print(f"xyz: {SG_mapped[0]:.3f}, {SG_mapped[1]:.3}, {SG_mapped[2]:.3} ") 
            eq_points.append(SG_mapped)
            SG_mapped_orth = real_grid.unit_cell.orthogonalize(SG_mapped)
            dis_list.append(np.sqrt(np.sum(np.array((peak_pos - SG_mapped_orth).tolist())**2)))
        
        # Get the nearest voxel value
        peak_value = []
        for pos in eq_points:
            a = round(pos.x * real_grid.nu)
            b = round(pos.y * real_grid.nv)
            c = round(pos.z * real_grid.nw)
            peak_value.append(real_grid.get_value(a, b, c))
            #print(real_grid.get_value(a, b, c))
        
        dis_lists.append(dis_list)
        peak_values.append(peak_value)
    return dis_lists, peak_values

In [13]:
logs = []
for id in tqdm(bound_smiles_with_heavy["id"]):
    model_path = glob.glob(os.path.join(bound_model_root, f"d{id}", "*/output*.pdb"))[0]
    mtz_path = os.path.join(valdo_mtz_path, f"{id}.mtz")

    st = gemmi.read_pdb(model_path)
    mtz = gemmi.read_mtz_file(mtz_path)
    real_grid = mtz.transform_f_phi_to_map(diff_column, phase_column, sample_rate=3.0)
    real_grid.normalize()

    atom_list = get_heavy_atoms(st)
    dis_lists, peak_values = get_peak_values(atom_list, real_grid)
    log_peak = np.max(peak_values, axis=1).tolist()
    log_dist = np.min(dis_lists, axis=1).tolist()

    for atom, peak_value, dist_value in zip(atom_list, log_peak, log_dist):
        logs.append([id, atom.name, peak_value, dist_value])

df_vae_peak = pd.DataFrame(
    data = logs,
    columns = ["id", "atom_name", "peak_value", "dist_value"]
)

100%|██████████| 6/6 [00:00<00:00, 11.38it/s]


In [14]:
df_vae_peak

Unnamed: 0,id,atom_name,peak_value,dist_value
0,234,S02,-3.003983,18.820144
1,474,F08,3.520009,38.079794
2,538,S02,5.527978,5.952479
3,538,CL08,5.811659,0.472845
4,538,CL11,1.483987,39.400177
5,545,CL01,-0.958365,24.254317
6,410,F05,1.482099,43.17485
7,429,F15,1.19526,7.948566
8,429,F16,0.271683,6.639108
9,429,F17,2.188319,8.271547
