In [4]:
import argparse
import os
import torch

from src.datasets import get_dataloader
from src.lightning import DDPM
from src.molecule_builder import get_bond_order
from src.visualizer import save_xyz_file, visualize_chain
from tqdm import tqdm
from pdb import set_trace
import sys #@mastro
from src import const #@mastro
import numpy as np #@mastro

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Simulate command-line arguments
sys.argv = [
    'ipykernel_launcher.py',
    '--checkpoint', 'models/zinc_difflinker.ckpt',
    '--chains', 'trajectories',
    '--data', 'datasets',
    '--prefix', 'zinc_final_test',
    '--keep_frames', '5',
    '--device', 'cuda:0'
]

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', action='store', type=str, required=True)
parser.add_argument('--chains', action='store', type=str, required=True)
parser.add_argument('--prefix', action='store', type=str, required=True)
parser.add_argument('--data', action='store', type=str, required=False, default=None)
parser.add_argument('--keep_frames', action='store', type=int, required=True)
parser.add_argument('--device', action='store', type=str, required=True)
args = parser.parse_args()

experiment_name = args.checkpoint.split('/')[-1].replace('.ckpt', '')
chains_output_dir = os.path.join(args.chains, experiment_name, args.prefix, 'chains')
final_states_output_dir = os.path.join(args.chains, experiment_name, args.prefix, 'final_states')
os.makedirs(chains_output_dir, exist_ok=True)
os.makedirs(final_states_output_dir, exist_ok=True)

# Loading model form checkpoint (all hparams will be automatically set)
model = DDPM.load_from_checkpoint(args.checkpoint, map_location=args.device)

# Possibility to evaluate on different datasets (e.g., on CASF instead of ZINC)
model.val_data_prefix = args.prefix

# In case <Anonymous> will run my model or vice versa
if args.data is not None:
    model.data_path = args.data

model = model.eval().to(args.device)
model.setup(stage='val')
dataloader = get_dataloader(
    model.val_dataset,
    batch_size=1, #@mastro, it was 32
    # batch_size=len(model.val_dataset)
)

#@mastro
num_samples = 1
sampled = 0
#end @mastro
start = 0
bond_order_dict = {0:0, 1:0, 2:0, 3:0}
ATOM_SAMPLER = False
SAVE_IMAGES = False
chain_with_full_fragments = None

for data in tqdm(dataloader):

    if sampled < num_samples:

        sampled += 1
        

        # display(data["fragment_mask"])
        # display(data["fragment_mask"].shape)

        # display(data["linker_mask"])
        # display(data["linker_mask"].shape)
        
        # display(data["edge_mask"])
        # display(data["edge_mask"].shape)

        #mask out all edges that are not bonds
        idx2atom = const.GEOM_IDX2ATOM if model.is_geom else const.IDX2ATOM
      
        positions = data["positions"][0].detach().cpu().numpy()
        x  = positions[:,0]
        y  = positions[:,1]
        z  = positions[:,2]
        # print(x)
       
        atom_type = torch.argmax(data["one_hot"][0], dim=1)
        print("Number of edges", len(x) * len(x))
        # sys.exit()
        #uncomment to work on edge_mask (not huge effect, tho)
        # for i in range(len(x)):
        #     for j in range(i+1, len(x)):
        #         p1 = np.array([x[i], y[i], z[i]])
        #         p2 = np.array([x[j], y[j], z[j]])
        #         dist =  np.sqrt(np.sum((p1 - p2) ** 2)) #np.linalg.norm(p1-p2)
                
        #         atom1, atom2 = idx2atom[atom_type[i].item()], idx2atom[atom_type[j].item()]
        #         bond_order = get_bond_order(atom1, atom2, dist)
                
        #         bond_order_dict[bond_order] += 1
        #         # if bond_order <= 0: #TODO debug. Why not all set to 0?
        #         if True:
        #             data["edge_mask"][i * len(x) + j] = 0
        #             data["edge_mask"][j * len(x) + i] = 0
        #         #set all edge_mask indices to 0
        #         data["edge_mask"] = torch.zeros_like(data["edge_mask"])

        #randomly mask out 50% of atoms
        # mask = torch.rand(data["atom_mask"].shape) > 0.5
        # data["atom_mask"] = data["atom_mask"] * mask.to(model.device)
        #mask out all atoms
        # data["atom_mask"] = torch.zeros_like(data["atom_mask"])
        
        if ATOM_SAMPLER:
            print("Before removal:", data["positions"].shape)
            #get all indices in atom_mask that correspond to fragment atoms
            fragment_indices = torch.where(data["fragment_mask"] == 1)[1]

            print(fragment_indices)
            
            #randomly pick 50% of fragment atoms
            random_indices = torch.randperm(len(fragment_indices))[:int(len(fragment_indices)/2)]
            mask_fragments = torch.ones(len(fragment_indices), dtype=torch.bool)
            mask_fragments[random_indices] = False
            selected_fragment_atoms = fragment_indices[mask_fragments]

            #keep only the first index in selected_fragment_atoms
            selected_fragment_atoms = selected_fragment_atoms[:1]
            print("Selected fragment atoms", selected_fragment_atoms)
            
            num_atoms = data["positions"].shape[1]
            # random_indices = torch.randperm(num_atoms)[:int(num_atoms/2)]
            mask = torch.ones(num_atoms, dtype=torch.bool)
            mask[selected_fragment_atoms] = False


            #remove positions of atoms in random_indices
            data["positions"] = data["positions"][:, mask]
            #remove one_hot of atoms in random_indices
            data["one_hot"] = data["one_hot"][:, mask]
            #remove atom_mask of atoms in random_indices
            data["atom_mask"] = data["atom_mask"][:, mask]
            #remove fragment_mask of atoms in random_indices
            data["fragment_mask"] =  data["fragment_mask"][:, mask]
            #remove linker_mask of atoms in random_indices
            data["linker_mask"] = data["linker_mask"][:, mask]
            #remove edge_mask of atoms in random_indices
            for index in random_indices:
                for i in range(num_atoms):
                    data["edge_mask"][index * num_atoms + i] = 0
                    data["edge_mask"][i * num_atoms + index] = 0

            #remove all values in edge_mask that are 0
            data["edge_mask"] = data["edge_mask"][data["edge_mask"] != 0]  #to be checked, but working on atoms has as effect. For the moment we stick to atoms, then we move to edges (need to edit internal function for this, or redefine everything...)
            

            print("After removal:", data["positions"].shape)
            # sys.exit()
            # print number of zeros in edge mask
            print("Number of masked out edges (edges not representing bonds)", torch.sum(data["edge_mask"] == 0))
            print("Number of edges still present", torch.sum(data["edge_mask"] != 0))

            # print number of zeros in atom mask
            print("Number of masked out atoms", torch.sum(data["atom_mask"] == 0))

        
        chain_batch, node_mask = model.sample_chain(data, keep_frames=args.keep_frames)
        
        #get the generated molecule and store it in a variable
        chain_with_full_fragments = chain_batch[0]

        # Compute distance of two chains
        mol_similarity = 1 - torch.norm(chain_batch[0] - chain_batch[0])
        print("Similarity between the two chains:", mol_similarity.item())
        # compute similarity of one-hot vectors
        one_hot_similarity = torch.sum(chain_batch[0, :, 3:-1] == chain_batch[0, :, 3:-1]) / chain_batch[0, :, 3:-1].numel()
        print("Similarity between the two one-hot vectors:", one_hot_similarity.item())
        
    
        if SAVE_IMAGES:
            for i in tqdm(range(len(data['positions']))):
                chain = chain_batch[:, i, :, :]
                assert chain.shape[0] == args.keep_frames
                assert chain.shape[1] == data['positions'].shape[1]
                assert chain.shape[2] == data['positions'].shape[2] + data['one_hot'].shape[2] + model.include_charges

                # Saving chains
                name = str(i + start)
                chain_output = os.path.join(chains_output_dir, name)
                os.makedirs(chain_output, exist_ok=True)

                one_hot = chain[:, :, 3:-1]
                positions = chain[:, :, :3]
                chain_node_mask = torch.cat([node_mask[i].unsqueeze(0) for _ in range(args.keep_frames)], dim=0)
                names = [f'{name}_{j}' for j in range(args.keep_frames)]

                save_xyz_file(chain_output, one_hot, positions, chain_node_mask, names=names, is_geom=model.is_geom)
                visualize_chain(
                    chain_output,
                    spheres_3d=True,
                    alpha=0.7,
                    bg='white',
                    is_geom=model.is_geom,
                    fragment_mask=data['fragment_mask'][i].squeeze()
                )

                # Saving final prediction and ground truth separately
                true_one_hot = data['one_hot'][i].unsqueeze(0)
                true_positions = data['positions'][i].unsqueeze(0)
                true_node_mask = data['atom_mask'][i].unsqueeze(0)
                save_xyz_file(
                    final_states_output_dir,
                    true_one_hot,
                    true_positions,
                    true_node_mask,
                    names=[f'{name}_true'],
                    is_geom=model.is_geom,
                )

                pred_one_hot = chain[0, :, 3:-1].unsqueeze(0)
                pred_positions = chain[0, :, :3].unsqueeze(0)
                pred_node_mask = chain_node_mask[0].unsqueeze(0)
                save_xyz_file(
                    final_states_output_dir,
                    pred_one_hot,
                    pred_positions,
                    pred_node_mask,
                    names=[f'{name}_pred'],
                    is_geom=model.is_geom
                )

            start += len(data['positions'])
                


c:\Users\Mastro\anaconda3\envs\diff_explainer\lib\site-packages\lightning_fabric\utilities\cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Lightning automatically upgraded your

Number of edges 676


 72%|███████▏  | 289/400 [00:12<00:02, 39.54it/s]

Similarity between the two chains: 1.0
Similarity between the two one-hot vectors: 0.9999999403953552


100%|██████████| 400/400 [00:12<00:00, 31.76it/s]


In [2]:
chain_batch[:, 0, :, :]

tensor([[[-4.1174e+00,  1.4147e+00,  6.2530e-01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-3.3826e+00,  2.7267e+00,  8.3010e-01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-2.4018e+00,  2.3873e+00,  1.9495e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [-1.2884e+00,  2.1070e-01,  2.6239e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-3.5209e-01, -1.6182e+00,  3.3809e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-5.3082e-01,  7.4002e-01,  4.0135e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[-4.1174e+00,  1.4147e+00,  6.2530e-01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-3.3826e+00,  2.7267e+00,  8.3010e-01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-2.4018e+00,  2.3873e+00,  1.9495e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [-6.8523e-01, -2

In [3]:
node_mask

tensor([[[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]]], device='cuda:0', dtype=torch.int8)