In [1]:
import pandas as pd
import numpy as np
import gemmi
import reciprocalspaceship as rs
from tqdm import tqdm
import glob, os
import matplotlib.pyplot as plt

### 1. Read in the ligand SMILES log

In [2]:
# Ligand table loaded from CSV; its "id" and "smile" columns are used further down.
lig_log_path = "./vatd_smiles_cleaned.csv"
lig_log = pd.read_csv(lig_log_path)

### 2. Bound model paths

In [3]:
# Root folder of the modelled bound structures; sub-folders named "d<id>" are looked up below.
bound_model_root = "./VatD_CommonHits_Modelled/"

In [4]:
# The trailing "/" in the pattern restricts the glob to directories directly under the root.
folder_pattern = os.path.join(bound_model_root, "*/")
folder_paths = glob.glob(folder_pattern)

In [5]:
# Folder names look like "d<id>". Take the last path component in an OS-independent
# way (normpath strips the trailing separator left by the "*/" glob pattern) and drop
# the leading "d" to recover the integer ligand id.
bound_ids = [
    int(os.path.basename(os.path.normpath(folder))[1:])
    for folder in folder_paths
]

### 3. Bound model SMILES

In [17]:
def withheavy(smile, heavy=("Cl", "S", "I", "Br", "P")):
    """Return True if the SMILES string contains an atom of one of the ``heavy`` elements.

    The string is tokenised into atoms instead of being searched for raw
    substrings, so two-letter elements are not mistaken for one-letter ones
    (``[Si]``/``[Se]`` are not sulfur, ``[Pt]`` is not phosphorus, ``[Li]`` is
    not iodine). Aromatic (lower-case) atoms such as ``s`` and ``p`` count as
    their element.

    Parameters
    ----------
    smile : str
        SMILES string. Non-string values (e.g. NaN from a missing CSV cell)
        are treated as containing no heavy atom.
    heavy : iterable of str
        Element symbols to look for; compared case-insensitively.

    Returns
    -------
    bool
    """
    import re  # local import so the cell does not depend on a separate import cell

    if not isinstance(smile, str):
        return False

    wanted = {str(symbol).capitalize() for symbol in heavy}
    atom_pattern = (
        # Bracket atom: optional isotope, then the element symbol (group 1).
        r"\[\d*([A-Z][a-z]?|se|as|[bcnops])[^\]]*\]"
        # Organic-subset atom written without brackets (group 2).
        r"|(Cl|Br|[BCNOPSFI]|[bcnops])"
    )
    for match in re.finditer(atom_pattern, smile):
        symbol = match.group(1) or match.group(2)
        if symbol.capitalize() in wanted:
            return True
    return False

In [18]:
# Keep only the ligands that have a modelled bound structure. `isin` is the
# vectorised membership test and avoids a per-row scan of the id list.
bound_smiles = lig_log[lig_log["id"].isin(bound_ids)].copy().reset_index(drop=True)

In [19]:
# Flag each ligand whose SMILES contains one of the heavy elements.
bound_smiles["heavy"] = bound_smiles["smile"].map(withheavy)

In [20]:
# Subset to the flagged ligands and renumber the rows from zero.
heavy_mask = bound_smiles["heavy"]
bound_smiles_with_heavy = bound_smiles.loc[heavy_mask].copy().reset_index(drop=True)

In [21]:
# Show the bound ligands whose SMILES were flagged as containing a heavy element.
bound_smiles_with_heavy

Unnamed: 0,code,smile,id,heavy
0,vat-d116,O=S(=O)(Nc1nncs1)c1ccccc1,116,True
1,vat-d167,O=C(Nc1nccs1)c1ccccc1,167,True
2,vat-d229,CC=1C=C2C(=O)C(=O)NC2=C(Br)C1,229,True
3,vat-d275,ClC=1C=CC=2OC(=O)NC2C1,275,True
4,vat-d551,OC(=O)C1=CC=2C=CSC2N1,551,True
5,vat-d560,ClC=1C=CC=2NC=C(CC#N)C2C1,560,True
6,vat-d283,ClC=1C=CC=2CC(=O)NC2C1,283,True
7,vat-d374,NS(=O)(=O)C1=CC=2C=CC=CC2O1,374,True
8,vat-d398,CNC(=O)C=1C=C(Cl)C=CN1,398,True
9,vat-d425,OC(=O)C1=CC=2C(Cl)=CC=CC2N1,425,True


### 4. Valdo difference maps

In [22]:
# Directory holding the reconstructed MTZ files; one "<id>.mtz" per dataset is read below.
# The cluster location stays the default; set the VALDO_MTZ_PATH environment
# variable to run the notebook against a copy stored elsewhere.
valdo_mtz_path = os.environ.get(
    "VALDO_MTZ_PATH",
    "/n/holyscratch01/hekstra_lab/dhekstra/valdo-vatD/pipeline/vae/reconstructed_w_phases/",
)
diff_column = "WDF"  # amplitude column passed to transform_f_phi_to_map
phase_column = "refine_PH2FOFCWT"  # phase column passed to transform_f_phi_to_map

### 5. Production run

In [23]:
def get_heavy_atoms(st, selections='[CL,Br,S,I,P]', resname='LIG'):
    """Collect copies of the selected atoms that belong to ligand residues.

    Parameters
    ----------
    st : gemmi.Structure
        Structure to search; only its first model (``st[0]``) is used.
    selections : str
        Selection string handed to ``gemmi.Selection``.
    resname : str
        Residue name an atom must belong to in order to be kept
        (defaults to ``'LIG'``).

    Returns
    -------
    list
        Cloned atoms matching both the selection and the residue name; empty
        when nothing matches.
    """
    selected_model = gemmi.Selection(selections).copy_model_selection(st[0])
    return [
        cra.atom.clone()
        for cra in selected_model.all()
        if cra.residue.name == resname
    ]

def get_peak_values(atom_list, real_grid):
    """Compare atom sites with the strongest peak of a real-space map.

    For every atom, each space-group operation of the grid is applied to the
    atom's fractional position and the result is moved into the unit cell.
    Two things are recorded per symmetry copy: the distance from that copy to
    the highest voxel of the map, and the map value at the voxel nearest to it.

    Distances are measured to the nearest lattice image of the copy, because
    the map is periodic and a peak just across a cell face is physically
    close. The unit cell is taken from ``real_grid`` itself, so the function
    no longer depends on a module-level ``mtz`` variable.

    Parameters
    ----------
    atom_list : list
        Atoms exposing a ``pos`` attribute (as returned by ``get_heavy_atoms``).
    real_grid : gemmi grid
        Real-space map, e.g. the result of ``Mtz.transform_f_phi_to_map``.

    Returns
    -------
    (dis_lists, peak_values) : tuple of lists
        Both have one inner list per atom, with one entry per symmetry
        operation. Both are empty when ``atom_list`` is empty.
    """
    cell = real_grid.unit_cell
    ops = list(real_grid.spacegroup.operations())

    # Strongest voxel of the map, in fractional and orthogonal coordinates.
    density = real_grid.array
    u, v, w = np.unravel_index(density.argmax(), density.shape)
    peak_point = real_grid.get_fractional(int(u), int(v), int(w))
    peak_frac = np.array([peak_point.x, peak_point.y, peak_point.z])
    peak_pos = cell.orthogonalize(gemmi.Fractional(*peak_frac))

    dis_lists = []
    peak_values = []
    for atom in atom_list:
        atom_frac = cell.fractionalize(atom.pos).tolist()

        site_distances = []
        site_values = []
        for op in ops:
            mapped = np.array(op.apply_to_xyz(atom_frac))
            in_cell = mapped - np.floor(mapped)  # move the copy into the cell

            # Shift the copy by whole cell translations so it is the lattice
            # image nearest to the peak (component-wise wrap in fractional space).
            nearest_image = in_cell + np.round(peak_frac - in_cell)
            image_pos = cell.orthogonalize(gemmi.Fractional(*nearest_image))
            offset = np.array((peak_pos - image_pos).tolist())
            site_distances.append(np.sqrt(np.sum(offset ** 2)))

            # Map value at the voxel nearest to the in-cell copy.
            # NOTE(review): round() can yield an index equal to the grid
            # dimension; presumably get_value wraps periodic indices - confirm.
            iu = round(in_cell[0] * real_grid.nu)
            iv = round(in_cell[1] * real_grid.nv)
            iw = round(in_cell[2] * real_grid.nw)
            site_values.append(real_grid.get_value(iu, iv, iw))

        dis_lists.append(site_distances)
        peak_values.append(site_values)
    return dis_lists, peak_values

In [29]:
# For every heavy-atom ligand: load its bound model and reconstructed map, then
# record, per heavy atom, the best map value over its symmetry copies and the
# shortest distance from any copy to the strongest map peak.
logs = []
for lig_id in tqdm(bound_smiles_with_heavy["id"]):  # `lig_id`, not `id`: do not shadow the builtin
    try:
        model_hits = sorted(
            glob.glob(os.path.join(bound_model_root, f"d{lig_id}", "*/output*.pdb"))
        )
        if not model_hits:
            raise FileNotFoundError("no output*.pdb model found")
        model_path = model_hits[0]  # sorted so the choice is reproducible
        mtz_path = os.path.join(valdo_mtz_path, f"{lig_id}.mtz")

        st = gemmi.read_pdb(model_path)
        # Kept as a module-level name in case helper code reads `mtz` directly.
        mtz = gemmi.read_mtz_file(mtz_path)
        real_grid = mtz.transform_f_phi_to_map(diff_column, phase_column, sample_rate=3.0)
        real_grid.normalize()

        atom_list = get_heavy_atoms(st)
        if not atom_list:
            # Without this guard np.max(..., axis=1) fails with an opaque axis error.
            raise ValueError("no heavy atoms found in the ligand residues")
        dis_lists, peak_values = get_peak_values(atom_list, real_grid)
        log_peak = np.max(peak_values, axis=1).tolist()
        log_dist = np.min(dis_lists, axis=1).tolist()

        for atom, peak_value, dist_value in zip(atom_list, log_peak, log_dist):
            logs.append([lig_id, atom.name, peak_value, dist_value])
    except Exception as err:
        # Best effort: skip this dataset but say why. Unlike a bare `except:`,
        # this lets KeyboardInterrupt stop the loop.
        print(f"{lig_id}: skipped ({type(err).__name__}: {err})")

df_vae_peak = pd.DataFrame(
    data=logs,
    columns=["id", "atom_name", "peak_value", "dist_value"],
)

100%|██████████| 12/12 [00:00<00:00, 26.51it/s]

283
374





In [30]:
# Per-atom results collected by the production loop: map value and distance to the strongest map peak.
df_vae_peak

Unnamed: 0,id,atom_name,peak_value,dist_value
0,116,S02,5.922434,0.743953
1,116,S09,2.094872,4.325534
2,167,S08,4.55587,6.909043
3,229,BR12,0.420401,22.94918
4,229,BR12,2.665406,10.943409
5,275,CL01,4.595695,1.25083
6,551,S09,8.08909,0.695742
7,551,S09,3.004719,22.611741
8,560,CL01,0.659352,34.843283
9,398,CL08,2.553915,5.480179
