In [25]:
import os
import torch
import numpy as np
from pathlib import Path
from Bio.PDB import PDBParser
from lightning_modules import LigandPocketDDPM
from torch_scatter import scatter_add, scatter_mean
import utils
from constants import dataset_params, FLOAT_TYPE, INT_TYPE

In [26]:
test_folder = Path("/home/domainHomes/ssakharov/master_thesis/crossdocked/processed_crossdock_noH_ca_only_temp/test")
txt_file = test_folder / "4keu-A-rec-4ket-pg4-lig-tt-min-0-pocket10_4keu-A-rec-4ket-pg4-lig-tt-min-0.txt"
pdb_file = test_folder / "4keu-A-rec-4ket-pg4-lig-tt-min-0-pocket10.pdb"

batch_size = 4


In [27]:
with open(txt_file, 'r') as f:
    resi_list = f.read().split()
print(resi_list) 

['A:24', 'A:27', 'A:223', 'A:225', 'A:226', 'A:228', 'A:229', 'A:256', 'A:257', 'A:258', 'A:261', 'A:263', 'A:264', 'A:265', 'A:266', 'A:271', 'A:274', 'A:275', 'A:278', 'A:199', 'A:202', 'A:22', 'A:170', 'A:222', 'A:255', 'A:72', 'A:97', 'A:67', 'A:99', 'A:139', 'A:141', 'A:171', 'A:227', 'A:267', 'A:270', 'A:268', 'A:272', 'A:273', 'A:276', 'A:269', 'A:277', 'A:279', 'C:104']


In [28]:
model = LigandPocketDDPM.load_from_checkpoint(
        "/home/domainHomes/ssakharov/master_thesis/logdir/SE3-inpaint-CA-test/checkpoints/best-model-epoch=epoch=655.ckpt", 
        map_location="cpu")
model = model.to("cpu")

Entropy of n_nodes: H[N] 7.055830001831055


#### Function: generate_ligands 

In [29]:
pdb_struct = PDBParser(QUIET=True).get_structure('', pdb_file)[0]

In [30]:
pocket_ids = resi_list
residues = [pdb_struct[x.split(':')[0]][(' ', int(x.split(':')[1]), ' ')] for x in pocket_ids]
print(residues)

[<Residue HIS het=  resseq=24 icode= >, <Residue VAL het=  resseq=27 icode= >, <Residue ARG het=  resseq=223 icode= >, <Residue GLY het=  resseq=225 icode= >, <Residue LEU het=  resseq=226 icode= >, <Residue LEU het=  resseq=228 icode= >, <Residue PHE het=  resseq=229 icode= >, <Residue ASP het=  resseq=256 icode= >, <Residue TYR het=  resseq=257 icode= >, <Residue CYS het=  resseq=258 icode= >, <Residue ILE het=  resseq=261 icode= >, <Residue MET het=  resseq=263 icode= >, <Residue GLY het=  resseq=264 icode= >, <Residue THR het=  resseq=265 icode= >, <Residue ALA het=  resseq=266 icode= >, <Residue LYS het=  resseq=271 icode= >, <Residue LEU het=  resseq=274 icode= >, <Residue ALA het=  resseq=275 icode= >, <Residue TRP het=  resseq=278 icode= >, <Residue HIS het=  resseq=199 icode= >, <Residue ASP het=  resseq=202 icode= >, <Residue HIS het=  resseq=22 icode= >, <Residue HIS het=  resseq=170 icode= >, <Residue ASP het=  resseq=222 icode= >, <Residue HIS het=  resseq=255 icode= >, <R

In [31]:
pocket = model.prepare_pocket(residues, batch_size)
for pocket_key, pocket_value in pocket.items():
    print(pocket_key, pocket_value.shape)

x torch.Size([172, 3])
one_hot torch.Size([172, 20])
size torch.Size([4])
mask torch.Size([172])




In [32]:
pocket_com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0)
pocket_com_before, pocket_com_before.shape

(tensor([[ 65.4615,  -5.8382, -65.9667],
         [ 65.4615,  -5.8382, -65.9667],
         [ 65.4615,  -5.8382, -65.9667],
         [ 65.4615,  -5.8382, -65.9667]]),
 torch.Size([4, 3]))

In [33]:
torch.manual_seed(42)
num_nodes_lig = model.ddpm.size_distribution.sample_conditional(n1=None, n2=pocket['size'])
print(f"num_nodes_lig: {num_nodes_lig}")
n_nodes_bias = 0
num_nodes_lig = num_nodes_lig + n_nodes_bias
print(f"num_nodes_lig: {num_nodes_lig}")

num_nodes_lig: tensor([25, 20, 15, 19])
num_nodes_lig: tensor([25, 20, 15, 19])


In [34]:
lig_mask = utils.num_nodes_to_batch_mask(len(num_nodes_lig), num_nodes_lig, "cpu")
lig_mask, lig_mask.shape

(tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3]),
 torch.Size([79]))

In [35]:
ligand = {
    'x': torch.zeros((len(lig_mask), model.x_dims),
                        device="cpu", dtype=FLOAT_TYPE),
    'one_hot': torch.zeros((len(lig_mask), model.atom_nf),
                            device="cpu", dtype=FLOAT_TYPE),
    'size': num_nodes_lig,
    'mask': lig_mask
    }
for key, value in ligand.items():
    print(key, value.shape)

x torch.Size([79, 3])
one_hot torch.Size([79, 10])
size torch.Size([4])
mask torch.Size([79])


In [36]:
lig_fixed = torch.zeros(len(lig_mask), device="cpu")
pocket_fixed = torch.ones(len(pocket['mask']),device="cpu")
if len(lig_fixed.size()) == 1:
    lig_fixed = lig_fixed.unsqueeze(1)
if len(pocket_fixed.size()) == 1:
    pocket_fixed = pocket_fixed.unsqueeze(1)

### Inpaint

In [37]:
timesteps = model.ddpm.T
print(f"timesteps: {timesteps}")

timesteps: 500


In [38]:
ligand, pocket = model.ddpm.normalize(ligand, pocket)
for key, value in ligand.items():
    print(key, value.shape)
for key, value in pocket.items():
    print(key, value.shape)

x torch.Size([79, 3])
one_hot torch.Size([79, 10])
size torch.Size([4])
mask torch.Size([79])
x torch.Size([172, 3])
one_hot torch.Size([172, 20])
size torch.Size([4])
mask torch.Size([172])


In [39]:
mean_known = scatter_mean(
            torch.cat((ligand['x'][lig_fixed.bool().view(-1)],
                       pocket['x'][pocket_fixed.bool().view(-1)])),
            torch.cat((ligand['mask'][lig_fixed.bool().view(-1)],
                       pocket['mask'][pocket_fixed.bool().view(-1)])),
            dim=0
        )
mean_known, mean_known.shape

(tensor([[ 65.4615,  -5.8382, -65.9667],
         [ 65.4615,  -5.8382, -65.9667],
         [ 65.4615,  -5.8382, -65.9667],
         [ 65.4615,  -5.8382, -65.9667]]),
 torch.Size([4, 3]))

In [40]:
xh0_lig = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)

xh0_lig[:, :model.ddpm.n_dims] = \
    xh0_lig[:, :model.ddpm.n_dims] - mean_known[ligand['mask']]
xh0_pocket[:, :model.ddpm.n_dims] = \
    xh0_pocket[:, :model.ddpm.n_dims] - mean_known[pocket['mask']]
print(xh0_lig.shape, xh0_pocket.shape)

torch.Size([79, 13]) torch.Size([172, 23])


In [41]:
z_lig, z_pocket = model.ddpm.sample_combined_position_feature_noise(ligand['mask'], pocket['mask'])
z_lig.shape, z_pocket.shape

(torch.Size([79, 13]), torch.Size([172, 23]))

In [42]:
resamplings = 10
jump_length = 1
# Each integer in the schedule list describes how many denoising steps need to be applied before jumping back 
schedule = model.ddpm.get_repaint_schedule(resamplings, jump_length, timesteps=50)
print(schedule[:20])
len(schedule), np.unique(schedule, return_counts=True)

[2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1]


(442, (array([1, 2]), array([393,  49])))

In [43]:
n_samples = len(ligand['size'])
s = timesteps - 1
print(f"n_samples: {n_samples}")
print(f"s: {s}")
for i, n_denoise_steps in enumerate(schedule):
    for j in range(n_denoise_steps):
        s_array = torch.full((n_samples, 1), fill_value=s,device="cpu")
        print(f"s_array.shape {s_array.shape}, s_array: {s_array}")
        t_array = s_array + 1
        s_array = s_array / timesteps 
        t_array = t_array / timesteps
        gamma_s = model.ddpm.inflate_batch_array(model.ddpm.gamma(s_array), ligand['x'])
        print(f"gamma_s.shape: {gamma_s.shape}")
        #### apply noise to the ligand and pocket depending on the timestep 
        z_lig_known, z_pocket_known, _, _ = model.ddpm.noised_representation(xh0_lig, xh0_pocket, ligand['mask'], pocket['mask'], gamma_s)
        print(f"z_lig_known.shape: {z_lig_known.shape}")
        print(f"z_pocket_known.shape: {z_pocket_known.shape}")
        #### Samples from zs ~ p(zs | zt).
        z_lig_unknown, z_pocket_unknown = model.ddpm.sample_p_zs_given_zt( s_array, t_array, z_lig, z_pocket, ligand['mask'], pocket['mask'])
        print(f"z_lig_unknown.shape: {z_lig_unknown.shape}")
        print(f"z_pocket_unknown.shape: {z_pocket_unknown.shape}")
        z_lig = z_lig_known * lig_fixed + \
                z_lig_unknown * (1 - lig_fixed)
        z_pocket = z_pocket_known * pocket_fixed + \
                    z_pocket_unknown * (1 - pocket_fixed)
        print(f"z_lig.shape: {z_lig.shape}")
        print(f"z_pocket.shape: {z_pocket.shape}")
        break
    break

n_samples: 4
s: 499
s_array.shape torch.Size([4, 1]), s_array: tensor([[499],
        [499],
        [499],
        [499]])
gamma_s.shape: torch.Size([4, 1])
z_lig_known.shape: torch.Size([79, 13])
z_pocket_known.shape: torch.Size([172, 23])


z_lig_unknown.shape: torch.Size([79, 13])
z_pocket_unknown.shape: torch.Size([172, 23])
z_lig.shape: torch.Size([79, 13])
z_pocket.shape: torch.Size([172, 23])


In [47]:
x_lig, h_lig, x_pocket, h_pocket = model.ddpm.sample_p_xh_given_z0(z_lig, z_pocket, ligand['mask'], pocket['mask'], n_samples)
print(f"x_lig.shape: {x_lig.shape}")
print(f"h_lig.shape: {h_lig.shape}")
print(f"x_pocket.shape: {x_pocket.shape}")
print(f"h_pocket.shape: {h_pocket.shape}")

x_lig.shape: torch.Size([79, 3])
h_lig.shape: torch.Size([79, 10])
x_pocket.shape: torch.Size([172, 3])
h_pocket.shape: torch.Size([172, 20])


In [48]:
out_lig = torch.zeros((return_frames,) + z_lig.size(),
                        device=z_lig.device)
out_pocket = torch.zeros((return_frames,) + z_pocket.size(),
                            device=z_pocket.device)

out_lig[0] = torch.cat([x_lig, h_lig], dim=1)
out_pocket[0] = torch.cat([x_pocket, h_pocket], dim=1)

out_pocket = out_pocket.squeeze()
out_lig = out_lig.squeeze()
print(f"out_lig.shape: {out_lig.shape}")
print(f"out_pocket.shape: {out_pocket.shape}")

out_lig.shape: torch.Size([79, 13])
out_pocket.shape: torch.Size([172, 23])


In [44]:
# resamplings = 10
# jump_length = 1
# timesteps = 50
# # Each integer in the schedule list describes how many denoising steps need to be applied before jumping back 
# schedule = model.ddpm.get_repaint_schedule(resamplings, jump_length, timesteps)
# print(schedule[:20])
# len(schedule), np.unique(schedule, return_counts=True)

In [45]:
# s = timesteps - 1
# return_frames = 1
# jump_length = 1
# for i, n_denoise_steps in enumerate(schedule):
#     print(f"---- i: {i}, n_denoise_steps: {n_denoise_steps} ----")
#     for j in range(n_denoise_steps):
#         print(f" j = {j}")
#         # Denoise one time step: t -> s
#         s_array = torch.full((n_samples, 1), fill_value=s,
#                                 device=z_lig.device)
#         t_array = s_array + 1
#         print(f"Sample p_zs_given_zt for s {s_array[0]} and t_array {t_array[0]}")
#         s_array = s_array / timesteps
#         t_array = t_array / timesteps

#         # sample known nodes from the input
#         # save frame at the end of a resample cycle
#         if n_denoise_steps > jump_length or i == len(schedule) - 1:
#             if (s * return_frames) % timesteps == 0:
#                 print(f"In the first if statement, s: {s}")

#         # Noise combined representation
#         if j == n_denoise_steps - 1 and i < len(schedule) - 1:
#             # Go back jump_length steps
#             t = s + jump_length
#             t_array = torch.full((n_samples, 1), fill_value=t,
#                                     device=z_lig.device)
#             print(f"Sample p_zt_given_zs for and t_array {t_array[0]}")
#             s = t
#         s -= 1

In [46]:
x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(z_lig, z_pocket, ligand['mask'], pocket['mask'], n_samples)

NameError: name 'self' is not defined