In [1]:
import argparse
import os
import torch

from src.datasets import get_dataloader
from src.lightning import DDPM
from src.molecule_builder import get_bond_order
from src.visualizer import save_xyz_file, visualize_chain
from tqdm.auto import tqdm
from pdb import set_trace
import sys #@mastro
from src import const #@mastro
import numpy as np #@mastro
from numpy.random import default_rng
from sklearn.metrics import jaccard_score
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import directed_hausdorff

from sklearn.decomposition import PCA
from src.visualizer import load_molecule_xyz, load_xyz_files
import matplotlib.pyplot as plt
import imageio

# Force synchronous CUDA kernel launches (debugging aid: errors surface at the failing call; slower).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Simulate command-line arguments so this argparse-based script can run inside a notebook kernel.
simulated_cli_options = {
    '--checkpoint': 'models/zinc_difflinker.ckpt',
    '--chains': 'trajectories',
    '--data': 'datasets',
    '--prefix': 'zinc_final_test',
    '--keep_frames': '10',
    '--device': 'cuda:0',
}
sys.argv = ['ipykernel_launcher.py']
for option_flag, option_value in simulated_cli_options.items():
    sys.argv.extend([option_flag, option_value])

# Argument table: (flag, type, required, default).
cli_spec = (
    ('--checkpoint', str, True, None),
    ('--chains', str, True, None),
    ('--prefix', str, True, None),
    ('--data', str, False, None),
    ('--keep_frames', int, True, None),
    ('--device', str, True, None),
)
parser = argparse.ArgumentParser()
for option_flag, option_type, option_required, option_default in cli_spec:
    parser.add_argument(option_flag, action='store', type=option_type, required=option_required, default=option_default)
args = parser.parse_args()

# Output layout: <chains>/<checkpoint file name without .ckpt>/<prefix>/{chains,final_states}
experiment_name = args.checkpoint.split('/')[-1].replace('.ckpt', '')
experiment_dir = os.path.join(args.chains, experiment_name, args.prefix)
chains_output_dir = os.path.join(experiment_dir, 'chains')
final_states_output_dir = os.path.join(experiment_dir, 'final_states')
for output_dir in (chains_output_dir, final_states_output_dir):
    os.makedirs(output_dir, exist_ok=True)

# Load the model from the checkpoint (all hparams are restored automatically).
model = DDPM.load_from_checkpoint(args.checkpoint, map_location=args.device)

# Evaluation split prefix (makes it possible to evaluate e.g. on CASF instead of ZINC).
model.val_data_prefix = args.prefix

# Optionally point the model at a different dataset directory.
if args.data is not None:
    model.data_path = args.data

model = model.eval().to(args.device)
model.setup(stage='val')

# batch_size=1 (@mastro, it was 32): the explanation cell always reads sample 0 of each batch.
dataloader = get_dataloader(model.val_dataset, batch_size=1)


c:\Users\Mastro\anaconda3\envs\diff_explainer\lib\site-packages\lightning_fabric\utilities\cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Lightning automatically upgraded your

In [2]:
def compute_molecular_similarity(mol1, mol2, mask1 = None, mask2 = None):
    """
    Score two molecules as ``1 - ||mol1 - mol2||`` over all (masked) atom features.

    Atoms live on the second-to-last axis and features on the last axis, so both
    ``(atoms, features)`` and frame-stacked ``(frames, atoms, features)`` tensors work.

    Args:
        mol1 (torch.Tensor): Per-atom features of the first molecule.
        mol2 (torch.Tensor): Per-atom features of the second molecule.
        mask1 (torch.Tensor, optional): Per-atom mask for ``mol1``; only atoms where it is
            non-zero are compared. If None, all atoms of ``mol1`` are used.
        mask2 (torch.Tensor, optional): Same as ``mask1``, applied to ``mol2``.

    After masking, both tensors must have the same shape.

    Returns:
        torch.Tensor: 0-d tensor equal to 1 for identical inputs; it decreases (and can go
        negative) as the inputs diverge.
    """
    if mask1 is not None:
        # Mask the atom axis (-2). For 2-D input this equals ``mol1[mask1, :]``; for
        # frame-stacked input it avoids masking the frame axis by mistake.
        mol1 = mol1[..., mask1.bool(), :]

    if mask2 is not None:
        mol2 = mol2[..., mask2.bool(), :]

    return 1 - torch.norm(mol1 - mol2)

def compute_molecular_distance(mol1, mol2, mask1 = None, mask2 = None):
    """
    Euclidean (Frobenius) distance ``||mol1 - mol2||`` over all (masked) atom features.

    Atoms live on the second-to-last axis and features on the last axis, so both
    ``(atoms, features)`` and frame-stacked ``(frames, atoms, features)`` tensors work.

    Args:
        mol1 (torch.Tensor): Per-atom features of the first molecule.
        mol2 (torch.Tensor): Per-atom features of the second molecule.
        mask1 (torch.Tensor, optional): Per-atom mask for ``mol1``; only atoms where it is
            non-zero are compared. If None, all atoms of ``mol1`` are used.
        mask2 (torch.Tensor, optional): Same as ``mask1``, applied to ``mol2``.

    After masking, both tensors must have the same shape.

    Returns:
        float: The distance (0 for identical inputs, larger means more different).
    """
    if mask1 is not None:
        # Mask the atom axis (-2); identical to ``mol1[mask1, :]`` for 2-D input.
        mol1 = mol1[..., mask1.bool(), :]

    if mask2 is not None:
        mol2 = mol2[..., mask2.bool(), :]

    return torch.norm(mol1 - mol2).item()

def compute_cosine_similarity(mol1, mol2, mask1 = None, mask2 = None):
    """
    Cosine similarity between the flattened (masked) per-atom features of two molecules.

    Atoms live on the second-to-last axis and features on the last axis, so both
    ``(atoms, features)`` and frame-stacked ``(frames, atoms, features)`` tensors work.
    The inputs are handed to scikit-learn, so they must be CPU tensors.

    Args:
        mol1 (torch.Tensor): Per-atom features of the first molecule.
        mol2 (torch.Tensor): Per-atom features of the second molecule.
        mask1 (torch.Tensor, optional): Per-atom mask for ``mol1``; only atoms where it is
            non-zero are compared. If None, all atoms of ``mol1`` are used.
        mask2 (torch.Tensor, optional): Same as ``mask1``, applied to ``mol2``.

    After masking, both tensors must contain the same number of elements.

    Returns:
        float: Cosine similarity in [-1, 1] (1 means parallel feature vectors).
    """
    if mask1 is not None:
        # Mask the atom axis (-2); identical to ``mol1[mask1, :]`` for 2-D input.
        mol1 = mol1[..., mask1.bool(), :]

    if mask2 is not None:
        mol2 = mol2[..., mask2.bool(), :]

    return cosine_similarity(mol1.flatten().reshape(1, -1), mol2.flatten().reshape(1, -1)).item()

def compute_molecular_similarity_positions(mol1, mol2, mask1 = None, mask2 = None):
    """
    Score two molecules as ``1 - ||pos1 - pos2||`` using only the coordinate features.

    The first three features (last axis, columns 0-2) are taken as x, y, z positions.
    Atoms live on the second-to-last axis, so both ``(atoms, features)`` and
    frame-stacked ``(frames, atoms, features)`` tensors work.

    Args:
        mol1 (torch.Tensor): Per-atom features of the first molecule.
        mol2 (torch.Tensor): Per-atom features of the second molecule.
        mask1 (torch.Tensor, optional): Per-atom mask for ``mol1``; only atoms where it is
            non-zero are compared. If None, all atoms of ``mol1`` are used.
        mask2 (torch.Tensor, optional): Same as ``mask1``, applied to ``mol2``.

    Returns:
        torch.Tensor: 0-d tensor equal to 1 for identical positions; it decreases (and can go
        negative) as the positions diverge.
    """
    # Slice features, not atoms. The old ``[:, :3].squeeze()`` picked the first three *atoms*
    # of 3-D input and collapsed single-atom molecules to 1-D, which broke the masking below.
    positions1 = mol1[..., :3]
    positions2 = mol2[..., :3]

    if mask1 is not None:
        positions1 = positions1[..., mask1.bool(), :]

    if mask2 is not None:
        positions2 = positions2[..., mask2.bool(), :]

    # NOTE(review): similarity (1 - distance) vs. plain distance is still an open choice.
    return 1 - torch.norm(positions1 - positions2)

def compute_one_hot_similarity(mol1, mol2, mask1 = None, mask2 = None):
    """
    Fraction of matching entries in the atom-type (one-hot) block of two molecules.

    The one-hot block is features ``3:-1`` on the last axis. The first three columns
    (positions) and the last column (presumably the charge; TODO confirm) are excluded.
    Atoms live on the second-to-last axis, so both ``(atoms, features)`` and frame-stacked
    ``(frames, atoms, features)`` tensors work.

    Args:
        mol1 (torch.Tensor): Per-atom features of the first molecule.
        mol2 (torch.Tensor): Per-atom features of the second molecule.
        mask1 (torch.Tensor, optional): Per-atom mask for ``mol1``; only atoms where it is
            non-zero are compared. If None, all atoms of ``mol1`` are used.
        mask2 (torch.Tensor, optional): Same as ``mask1``, applied to ``mol2``.

    Returns:
        torch.Tensor: 0-d tensor in [0, 1], the share of equal entries in the one-hot block.
    """
    if mask1 is not None:
        # Mask the atom axis (-2); identical to ``mol1[mask1, :]`` for 2-D input.
        mol1 = mol1[..., mask1.bool(), :]

    if mask2 is not None:
        mol2 = mol2[..., mask2.bool(), :]

    one_hot1 = mol1[..., 3:-1]
    one_hot2 = mol2[..., 3:-1]
    # Element-wise comparison of the encodings, normalized by the number of compared entries.
    return torch.sum(one_hot1 == one_hot2) / one_hot1.numel()

## Explainability

In [3]:

# ---- Explanation run configuration (@mastro) ----
num_samples = 5             # number of dataloader samples to explain
sampled = 0                 # how many samples have been explained so far
start = 0                   # running offset used to name the saved chains
bond_order_dict = {0: 0, 1: 0, 2: 0, 3: 0}
ATOM_SAMPLER = False
SAVE_ORIGINAL_PREDICTION = True
chain_with_full_fragments = None

M = 100   # number of Monte Carlo sampling steps
P = 0.2   # probability that an atom exists in a random graph (edges may follow later)
SEED = 42 # seed for the random sampling

# Output folder for the explanation results; created only when missing.
folder_save_path = "results/explanations"
if not os.path.exists(folder_save_path):
    os.makedirs(folder_save_path)

# Per-atom tensors in a batch dict that are sliced when atoms are removed.
_PER_ATOM_KEYS = ("positions", "one_hot", "atom_mask", "fragment_mask", "linker_mask")


def _atom_keep_mask(num_atoms, kept_fragment_atoms, linker_atoms):
    """Boolean (num_atoms,) mask selecting the given fragment atom indices plus every linker atom."""
    keep = torch.zeros(num_atoms, dtype=torch.bool)
    keep[kept_fragment_atoms] = True
    keep[linker_atoms] = True
    return keep


def _restrict_to_atoms(sample, keep_mask, edge_atom_indices, num_atoms):
    """Return a copy of the batch dict `sample` restricted to the atoms selected by `keep_mask`.

    Per-atom tensors (``_PER_ATOM_KEYS``) are indexed along dim 1 with `keep_mask`.
    For ``edge_mask``, the flat entries ``idx * num_atoms + i`` and ``i * num_atoms + idx``
    (every ``idx`` in `edge_atom_indices`, every ``i < num_atoms``) are zeroed, and only the
    remaining non-zero entries are kept.
    """
    restricted = sample.copy()
    for key in _PER_ATOM_KEYS:
        restricted[key] = sample[key][:, keep_mask]

    # Clone first: ``sample.copy()`` is shallow, so writing into the shared tensor would
    # silently modify ``sample`` (and every other variant built from it) as well.
    edge_mask = sample["edge_mask"].clone()
    atom_ids = torch.as_tensor(edge_atom_indices, device=edge_mask.device).long().reshape(-1)
    if atom_ids.numel() > 0:
        offsets = torch.arange(num_atoms, device=edge_mask.device)
        edge_mask[(atom_ids.unsqueeze(1) * num_atoms + offsets).reshape(-1)] = 0  # rows
        edge_mask[(offsets.unsqueeze(1) * num_atoms + atom_ids).reshape(-1)] = 0  # columns
    # NOTE(review): still experimental (edge-level masking is future work); whether the model
    # consumes edge_mask at all is not visible here. Atom removal is what currently has an effect.
    restricted["edge_mask"] = edge_mask[edge_mask != 0]
    return restricted


def _score_against_reference(reference_chain, reference_linker_mask, variant):
    """Sample a chain for the batch dict `variant` and compare its linker atoms with the reference.

    Returns:
        tuple: (distance, cosine similarity), computed by ``compute_molecular_distance`` and
        ``compute_cosine_similarity`` on the linker atoms of both chains.
    """
    variant_chain, _ = model.sample_chain(variant, keep_frames=args.keep_frames)
    variant_linker_mask = variant["linker_mask"][0].squeeze()
    distance = compute_molecular_distance(
        reference_chain.squeeze(), variant_chain.squeeze(),
        mask1=reference_linker_mask, mask2=variant_linker_mask,
    )
    cosine = compute_cosine_similarity(
        reference_chain.squeeze().cpu(), variant_chain.squeeze().cpu(),
        mask1=reference_linker_mask.cpu(), mask2=variant_linker_mask.cpu(),
    )
    return distance, cosine


for data in dataloader:
    # Only the first `num_samples` samples are explained; stop instead of walking the rest.
    if sampled >= num_samples:
        break
    sampled += 1

    # Reference prediction: chain generated from the original, complete fragments.
    chain_batch, node_mask = model.sample_chain(data, keep_frames=args.keep_frames)
    chain_with_full_fragments = chain_batch[:, 0, :, :]  # batch_size is 1 -> sample 0

    # Sanity check: comparing the reference with itself must give maximal similarity.
    reference_linker_mask = data["linker_mask"][0].squeeze()
    reference = chain_with_full_fragments.squeeze()
    mol_similarity = compute_molecular_similarity(reference, reference, mask1=reference_linker_mask, mask2=reference_linker_mask)
    print("Similarity between the two chains:", mol_similarity.item())
    positional_similarity = compute_molecular_similarity_positions(reference, reference, mask1=reference_linker_mask, mask2=reference_linker_mask)
    print("Similarity between the two chains based on positions:", positional_similarity.item())
    one_hot_similarity = compute_one_hot_similarity(reference, reference, mask1=reference_linker_mask, mask2=reference_linker_mask)
    print("Similarity between the two one-hot vectors:", one_hot_similarity.item())

    if SAVE_ORIGINAL_PREDICTION:
        for i in range(len(data['positions'])):
            chain = chain_batch[:, i, :, :]
            assert chain.shape[0] == args.keep_frames
            assert chain.shape[1] == data['positions'].shape[1]
            assert chain.shape[2] == data['positions'].shape[2] + data['one_hot'].shape[2] + model.include_charges

            # Saving chains
            name = str(i + start)
            chain_output = os.path.join(chains_output_dir, name)
            os.makedirs(chain_output, exist_ok=True)

            one_hot = chain[:, :, 3:-1]
            positions = chain[:, :, :3]
            chain_node_mask = torch.cat([node_mask[i].unsqueeze(0) for _ in range(args.keep_frames)], dim=0)
            names = [f'{name}_{j}' for j in range(args.keep_frames)]

            save_xyz_file(chain_output, one_hot, positions, chain_node_mask, names=names, is_geom=model.is_geom)
            visualize_chain(
                chain_output,
                spheres_3d=True,
                alpha=0.7,
                bg='white',
                is_geom=model.is_geom,
                fragment_mask=data['fragment_mask'][i].squeeze()
            )

            # Saving final prediction and ground truth separately
            true_one_hot = data['one_hot'][i].unsqueeze(0)
            true_positions = data['positions'][i].unsqueeze(0)
            true_node_mask = data['atom_mask'][i].unsqueeze(0)
            save_xyz_file(
                final_states_output_dir,
                true_one_hot,
                true_positions,
                true_node_mask,
                names=[f'{name}_true'],
                is_geom=model.is_geom,
            )

            pred_one_hot = chain[0, :, 3:-1].unsqueeze(0)
            pred_positions = chain[0, :, :3].unsqueeze(0)
            pred_node_mask = chain_node_mask[0].unsqueeze(0)
            save_xyz_file(
                final_states_output_dir,
                pred_one_hot,
                pred_positions,
                pred_node_mask,
                names=[f'{name}_pred'],
                is_geom=model.is_geom
            )

        start += len(data['positions'])

    # ---- Monte Carlo permutation estimate of each fragment atom's contribution ----
    rng = default_rng(seed=SEED)
    # Dedicated, seeded generator: torch.randperm otherwise draws from the unseeded global
    # torch RNG, which made the permutations (and thus phi) non-reproducible.
    perm_generator = torch.Generator().manual_seed(SEED)
    phi_atoms = {}
    fragment_indices = torch.where(data["fragment_mask"] == 1)[1]
    num_fragment_atoms = len(fragment_indices)
    num_atoms = data["positions"].shape[1]
    # Linker atoms are always kept; only fragment atoms are switched on/off.
    linker_atoms = data["linker_mask"][0].squeeze().to(torch.int) == 1

    distances_random_samples = []
    cosine_similarities_random_samples = []

    for j in tqdm(range(num_fragment_atoms)):
        marginal_contrib_distance = 0
        marginal_contrib_cosine_similarity = 0

        for step in tqdm(range(M)):
            # Random coalition z: each fragment atom is present with probability P.
            N_z_mask = rng.binomial(1, P, size=num_fragment_atoms)

            # At least one fragment atom must be present, so force a random one on if needed.
            if not np.any(N_z_mask):
                print("Zero elements in N_z_mask, randomly selecting one.")
                N_z_mask[rng.integers(0, num_fragment_atoms)] = 1

            pi = torch.randperm(num_fragment_atoms, generator=perm_generator)
            selected_node_index = (pi == j).nonzero().item()

            # "plus": atoms up to and including j in the permutation are present, the rest follow z.
            N_j_plus_index = torch.tensor(N_z_mask, dtype=torch.int)
            N_j_plus_index[pi[:selected_node_index + 1]] = 1
            # "minus": atoms strictly before j are present; j and the rest follow z.
            N_j_minus_index = torch.tensor(N_z_mask, dtype=torch.int)
            N_j_minus_index[pi[:selected_node_index]] = 1

            N_j_plus = fragment_indices[N_j_plus_index.bool()]  # fragment atoms kept in molecule j plus
            N_j_minus = fragment_indices[N_j_minus_index.bool()]  # fragment atoms kept in molecule j minus
            N_random_sample = fragment_indices[torch.tensor(N_z_mask).bool()]  # fragment atoms kept in the random molecule

            data_j_plus = _restrict_to_atoms(data, _atom_keep_mask(num_atoms, N_j_plus, linker_atoms), N_j_plus, num_atoms)
            data_j_minus = _restrict_to_atoms(data, _atom_keep_mask(num_atoms, N_j_minus, linker_atoms), N_j_minus, num_atoms)
            # Pass the kept atom *indices* (the old loop iterated over the 0/1 values of N_z_mask).
            data_random = _restrict_to_atoms(data, _atom_keep_mask(num_atoms, N_random_sample, linker_atoms), N_random_sample, num_atoms)

            # Same order as before: with node j, without node j, random coalition.
            V_j_plus_distance, V_j_plus_cosine_similarity = _score_against_reference(chain_with_full_fragments, reference_linker_mask, data_j_plus)
            V_j_minus_distance, V_j_minus_cosine_similarity = _score_against_reference(chain_with_full_fragments, reference_linker_mask, data_j_minus)
            V_random_distance, V_random_cosine_similarity = _score_against_reference(chain_with_full_fragments, reference_linker_mask, data_random)

            distances_random_samples.append(V_random_distance)
            cosine_similarities_random_samples.append(V_random_cosine_similarity)

            marginal_contrib_distance += (V_j_plus_distance - V_j_minus_distance)
            marginal_contrib_cosine_similarity += (V_j_plus_cosine_similarity - V_j_minus_cosine_similarity)

        # Keyed by the atom index in the molecule (j is only the position in fragment_indices).
        phi_atoms[fragment_indices[j].item()] = [
            marginal_contrib_distance / M,
            marginal_contrib_cosine_similarity / M,
        ]

        print(data["name"])

    # Save phi_atoms to a text file
    with open(f'{folder_save_path}/phi_atoms_{sampled}.txt', 'w') as write_file:
        write_file.write("sample name: " + str(data["name"]) + "\n")
        write_file.write("atom_index,distance,cosine_similarity\n")
        for atom_index, phi_values in phi_atoms.items():
            write_file.write(f"{atom_index},{phi_values[0]},{phi_values[1]}\n")

        write_file.write("\n")
        write_file.write("Distance random samples\n")
        write_file.write(str(distances_random_samples) + "\n")
        write_file.write("Cosine similarity random samples\n")
        write_file.write(str(cosine_similarities_random_samples) + "\n")

        

        
        
                


Similarity between the two chains: 1.0
Similarity between the two chains based on positions: 1.0
Similarity between the two one-hot vectors: 1.0


  0%|          | 0/21 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# if ATOM_SAMPLER:
#             print("Before removal:", data["positions"].shape)
#             #get all indices in atom_mask that correspond to fragment atoms
#             fragment_indices = torch.where(data["fragment_mask"] == 1)[1]

#             print(fragment_indices)
            
#             #randomly pick 50% of fragment atoms
#             random_indices = torch.randperm(len(fragment_indices))[:int(len(fragment_indices)/2)]
#             mask_fragments = torch.ones(len(fragment_indices), dtype=torch.bool)
#             mask_fragments[random_indices] = False
#             selected_fragment_atoms = fragment_indices[mask_fragments]

#             #keep only the first index in selected_fragment_atoms
#             selected_fragment_atoms = selected_fragment_atoms[:1] #@mastro change this removing the explained atom
#             print("Selected fragment atoms", selected_fragment_atoms)
            
#             num_atoms = data["positions"].shape[1]
#             # random_indices = torch.randperm(num_atoms)[:int(num_atoms/2)]
#             mask = torch.ones(num_atoms, dtype=torch.bool)
#             mask[selected_fragment_atoms] = False


#             #remove positions of atoms in random_indices
#             data["positions"] = data["positions"][:, mask]
#             #remove one_hot of atoms in random_indices
#             data["one_hot"] = data["one_hot"][:, mask]
#             #remove atom_mask of atoms in random_indices
#             data["atom_mask"] = data["atom_mask"][:, mask]
#             #remove fragment_mask of atoms in random_indices
#             data["fragment_mask"] =  data["fragment_mask"][:, mask]
#             #remove linker_mask of atoms in random_indices
#             data["linker_mask"] = data["linker_mask"][:, mask]
#             #remove edge_mask of atoms in random_indices
#             for index in random_indices:
#                 for i in range(num_atoms):
#                     data["edge_mask"][index * num_atoms + i] = 0
#                     data["edge_mask"][i * num_atoms + index] = 0

#             #remove all values in edge_mask that are 0
#             data["edge_mask"] = data["edge_mask"][data["edge_mask"] != 0]  #to be checked, but working on atoms has as effect. For the moment we stick to atoms, then we move to edges (need to edit internal function for this, or redefine everything...)
            

#             print("After removal:", data["positions"].shape)
#             # sys.exit()
#             # print number of zeros in edge mask
#             print("Number of masked out edges (edges not representing bonds)", torch.sum(data["edge_mask"] == 0))
#             print("Number of edges still present", torch.sum(data["edge_mask"] != 0))

#             # print number of zeros in atom mask
#             print("Number of masked out atoms", torch.sum(data["atom_mask"] == 0))

## Visualization

Visualization that highlights the most important fragment atoms according to their per-atom attribution values (`phi_atoms`).

In [3]:
# Hard-coded per-atom values keyed by atom index (0..20), used by the visualization cells.
# NOTE(review): presumably copied from an earlier `phi_atoms` run — TODO confirm which sample/metric.
phi_atoms = dict(enumerate([
    -0.1030, -2.5306, 0.3042, -1.4719, -1.3421, -0.2840, 0.9049,
    -1.1589, -0.2158, -1.8067, -2.1304, -0.2276, -1.6576, -0.0887,
    0.0418, -1.8112, 0.4877, 0.0786, -0.2415, -1.5389, -1.8475,
]))

#### Visualization utility functions (adapted from the DiffLinker visualizer)

In [13]:
def draw_sphere_xai(ax, x, y, z, size, color, alpha):
    """Draw a sphere of radius `size` centred at (x, y, z) as a parametric surface on `ax`.

    The surface is sampled on a 100x100 (azimuth, polar) grid and rendered with a
    row/column stride of 2.
    """
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)
    sin_polar = np.sin(polar)

    surface_x = x + size * np.outer(np.cos(azimuth), sin_polar)
    surface_y = y + size * np.outer(np.sin(azimuth), sin_polar)
    surface_z = z + size * np.outer(np.ones(azimuth.size), np.cos(polar))

    ax.plot_surface(surface_x, surface_y, surface_z, rstride=2, cstride=2, color=color, alpha=alpha)

def plot_molecule_xai(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None, phi_values=None):
    """
    Draw a molecule (bonds and atoms) on a 3D axis, highlighting fragment atoms by attribution.

    Args:
        ax: Matplotlib 3D axis to draw on.
        positions: Per-atom coordinates; columns 0, 1, 2 are x, y, z.
        atom_type: Per-atom type indices into ``const.COLORS`` / ``const.RADII`` and the
            idx2atom table.
        alpha: Base opacity for bonds and non-highlighted atoms.
        spheres_3d (bool): Render atoms as shaded spheres instead of scatter markers.
        hex_bg_color (str): Color used for the bond lines.
        is_geom (bool): Use the GEOM atom vocabulary instead of the default one.
        fragment_mask (torch.Tensor, optional): Per-atom mask, 1 for fragment atoms.
            Defaults to all ones.
        phi_values (dict, optional): Maps atom index -> attribution value for fragment atoms.
            Sphere mode: fragment atoms are opaque, and those with a positive value are red.
            Scatter mode: fragment atoms get an edge color from the 'coolwarm' colormap,
            normalized over all values. If None/empty, no attribution highlighting is drawn.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]

    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    colors_dic = np.array(const.COLORS)
    radius_dic = np.array(const.RADII)
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if fragment_mask is None:
        fragment_mask = torch.ones(len(x))

    # Bonds: one line per atom pair that get_bond_order reports as bonded (order > 0).
    base_line_width = (3 - 2) * 2 * 2  # = 4, kept from the upstream visualizer
    for i in range(len(x)):
        for j in range(i + 1, len(x)):
            p1 = np.array([x[i], y[i], z[i]])
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            bond_order = get_bond_order(atom1, atom2, dist)
            if bond_order > 0:
                # Order-4 bonds are drawn 1.5x thicker than the others.
                linewidth_factor = (1.5 if bond_order == 4 else 1) * 0.5
                ax.plot(
                    [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=base_line_width * linewidth_factor * 2,
                    c=hex_bg_color,
                    alpha=alpha
                )

    if spheres_3d:
        # Iterate over *all* atoms. The old zip over the phi dict walked its keys (atom indices)
        # rather than its values, and stopped after the fragment atoms.
        for atom_idx, is_fragment in zip(range(len(x)), fragment_mask):
            sphere_color = colors[atom_idx]
            sphere_alpha = alpha  # per-atom, so the fragment override below does not leak
            if is_fragment == 1:
                sphere_alpha = 1.0
                if phi_values is not None and phi_values.get(atom_idx, 0) > 0:
                    sphere_color = 'red'
            draw_sphere_xai(
                ax, x[atom_idx].item(), y[atom_idx].item(), z[atom_idx].item(),
                0.5 * radii[atom_idx], sphere_color, sphere_alpha
            )
    else:
        fragment_mask_on_cpu = fragment_mask.cpu().numpy()
        is_fragment = fragment_mask_on_cpu == 1
        is_other = fragment_mask_on_cpu == 0

        # Fragment atoms. Boolean indexing keeps ascending atom order, so phi values are
        # ordered by atom index to line up with them.
        colors_fragment = colors[is_fragment]
        x_fragment = x[is_fragment]
        y_fragment = y[is_fragment]
        z_fragment = z[is_fragment]
        areas_fragment = areas[is_fragment]
        if phi_values:
            phi_values_array = np.array([phi_values[k] for k in sorted(phi_values)])
            # plt.get_cmap works across Matplotlib versions (plt.cm.get_cmap was removed in 3.9).
            cmap = plt.get_cmap('coolwarm')
            norm = plt.Normalize(vmin=phi_values_array.min(), vmax=phi_values_array.max())
            colors_fragment_shadow = cmap(norm(phi_values_array))
            ax.scatter(x_fragment, y_fragment, z_fragment, s=areas_fragment, alpha=0.9 * alpha, c=colors_fragment,
                       edgecolors=colors_fragment_shadow, linewidths=5, rasterized=False)
        else:
            ax.scatter(x_fragment, y_fragment, z_fragment, s=areas_fragment, alpha=0.9 * alpha, c=colors_fragment,
                       rasterized=False)

        # Non-fragment atoms: plain type colors.
        ax.scatter(x[is_other], y[is_other], z[is_other], s=areas[is_other], alpha=0.9 * alpha,
                   c=colors[is_other], rasterized=False)


def plot_data3d_xai(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None, phi_values=None):
    """Render a single molecule in 3D, delegating atom/bond drawing to ``plot_molecule_xai``.

    Args:
        positions: torch tensor of atom coordinates; ``positions.abs().max()`` sets the
            symmetric axis limits (clamped to [3.2, 40]).
        atom_type: per-atom type indices, forwarded to ``plot_molecule_xai``.
        is_geom: forwarded to ``plot_molecule_xai``.
        camera_elev, camera_azim: view angles passed to ``ax.view_init``.
        save_path: if given, the figure is saved there; otherwise it is shown with ``plt.show()``.
        spheres_3d: forwarded to ``plot_molecule_xai``; when True the saved image is also
            brightened by a factor of 1.4 (clipped to 255).
        bg: ``'black'`` gives a black facecolor; any other value gives white.
        alpha: forwarded to ``plot_molecule_xai``.
        fragment_mask, phi_values: forwarded to ``plot_molecule_xai``.
    """
    black = (0, 0, 0)
    white = (1, 1, 1)
    # Colour handed to plot_molecule_xai; it is the opposite of the background colour.
    hex_bg_color = '#FFFFFF' if bg == 'black' else '#000000'  # '#666666'

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(projection='3d')
    ax.set_aspect('auto')
    ax.view_init(elev=camera_elev, azim=camera_azim)
    ax.set_facecolor(black if bg == 'black' else white)
    ax.xaxis.pane.set_alpha(0)
    ax.yaxis.pane.set_alpha(0)
    ax.zaxis.pane.set_alpha(0)
    # Public equivalent of setting the private `_axis3don` flag to False.
    ax.set_axis_off()

    # `Axes3D.w_xaxis` was deprecated and removed in Matplotlib 3.8; `xaxis` is the same axis object.
    ax.xaxis.line.set_color("black" if bg == 'black' else "white")

    plot_molecule_xai(
        ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask, phi_values=phi_values
    )

    max_value = positions.abs().max().item()
    axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
    ax.set_xlim(-axis_lim, axis_lim)
    ax.set_ylim(-axis_lim, axis_lim)
    ax.set_zlim(-axis_lim, axis_lim)
    dpi = 300  # previously 120 (spheres_3d) / 50 (flat)

    if save_path is not None:
        # Save this function's figure explicitly rather than whatever figure is "current".
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)
        # fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi, transparent=True)

        if spheres_3d:
            img = imageio.imread(save_path)
            img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
            imageio.imsave(save_path, img_brighter)
    else:
        plt.show()
    # Close this specific figure so per-frame calls in a loop do not accumulate open figures.
    plt.close(fig)

def visualize_chain_xai(
        path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None, phi_values=None
):
    """Render every XYZ frame found in ``path`` to PNG and assemble them into ``output.gif``.

    All frames are rotated with a PCA fitted on the last file returned by ``load_xyz_files``,
    so the molecule keeps a consistent orientation across the animation.

    Args:
        path: directory passed to ``load_xyz_files``.
        spheres_3d, bg, alpha, is_geom, fragment_mask, phi_values: forwarded to ``plot_data3d_xai``.
        wandb: optional wandb module/run; when given, the GIF is logged under key ``mode``.
        mode: wandb log key.

    Raises:
        ValueError: if ``load_xyz_files`` returns no frames for ``path``.
    """
    files = load_xyz_files(path)
    if not files:
        # Fail with a clear message instead of an IndexError on files[-1].
        raise ValueError(f'No XYZ frames found in {path!r}')
    save_paths = []

    # Fit PCA to the last frame – to obtain a fixed orientation for visualization.
    # NOTE(review): this is the "final molecule" only if load_xyz_files orders frames that way — confirm.
    final_positions, _, _ = load_molecule_xyz(files[-1], is_geom=is_geom)
    pca = PCA(n_components=3)
    pca.fit(final_positions)

    for file in files:
        positions, one_hot, _ = load_molecule_xyz(file, is_geom=is_geom)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Transform positions of each frame according to the orientation fitted above
        positions = torch.tensor(pca.transform(positions))

        fn = os.path.splitext(file)[0] + '.png'
        plot_data3d_xai(
            positions, atom_type,
            save_path=fn,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
            phi_values=phi_values
        )
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    # os.path.join avoids producing '/output.gif' (filesystem root) when dirname is empty.
    gif_path = os.path.join(os.path.dirname(save_paths[0]), 'output.gif')
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})

In [5]:
# Collect the per-atom phi values (in the dict's insertion order) into a NumPy array.
# NOTE(review): plot_molecule_xai references a name `phi_values_array`; confirm whether it
# reads this module-level variable instead of its own `phi_values` argument.
phi_values_array = np.array([phi for phi in phi_atoms.values()])
phi_values_array.shape

(21,)

In [14]:
start = 0
for data in tqdm(dataloader):
    chain_batch, node_mask = model.sample_chain(data, keep_frames=args.keep_frames)
    # Number of one-hot atom-type channels; the assert below pins the per-atom feature width
    # to 3 (xyz) + n_atom_types + include_charges.
    n_atom_types = data['one_hot'].shape[2]
    for i in tqdm(range(len(data['positions']))):
        chain = chain_batch[:, i, :, :]
        assert chain.shape[0] == args.keep_frames
        assert chain.shape[1] == data['positions'].shape[1]
        assert chain.shape[2] == data['positions'].shape[2] + n_atom_types + model.include_charges

        # Saving chains
        name = str(i + start)
        chain_output = os.path.join(chains_output_dir, name)
        os.makedirs(chain_output, exist_ok=True)

        # Slice the one-hot block by its explicit width: `3:-1` only matches when include_charges
        # is true and silently drops the last atom type otherwise.
        one_hot = chain[:, :, 3:3 + n_atom_types]
        positions = chain[:, :, :3]
        # Same node mask for every frame: (keep_frames, *node_mask[i].shape).
        chain_node_mask = node_mask[i].unsqueeze(0).repeat_interleave(args.keep_frames, dim=0)
        names = [f'{name}_{j}' for j in range(args.keep_frames)]

        save_xyz_file(chain_output, one_hot, positions, chain_node_mask, names=names, is_geom=model.is_geom)
        # NOTE(review): the same phi_atoms is passed for every molecule in the batch; this is only
        # meaningful if phi_atoms was computed for this molecule (batch size 1) — confirm.
        visualize_chain_xai(
            chain_output,
            spheres_3d=False,
            alpha=0.7,
            bg='white',
            is_geom=model.is_geom,
            fragment_mask=data['fragment_mask'][i].squeeze(),
            phi_values=phi_atoms
        )

        # Saving final prediction and ground truth separately
        true_one_hot = data['one_hot'][i].unsqueeze(0)
        true_positions = data['positions'][i].unsqueeze(0)
        true_node_mask = data['atom_mask'][i].unsqueeze(0)
        save_xyz_file(
            final_states_output_dir,
            true_one_hot,
            true_positions,
            true_node_mask,
            names=[f'{name}_true'],
            is_geom=model.is_geom,
        )

        # Frame 0 of the chain is saved as the prediction (presumably the fully denoised state —
        # depends on model.sample_chain's frame order).
        pred_one_hot = one_hot[:1]
        pred_positions = positions[:1]
        pred_node_mask = chain_node_mask[:1]
        save_xyz_file(
            final_states_output_dir,
            pred_one_hot,
            pred_positions,
            pred_node_mask,
            names=[f'{name}_pred'],
            is_geom=model.is_geom
        )

    start += len(data['positions'])
    # Only the first batch is processed; remove this to visualise the whole dataloader.
    break

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  imgs = [imageio.imread(fn) for fn in save_paths]


In [None]:
# import numpy as np

# chain_with_full_fragments = chain_with_full_fragments[0,:,:]
# # Generate a random boolean mask
# mask = np.random.choice([True, False], size=chain_with_full_fragments.shape[0])

# # Apply the mask to remove random rows
# masked_chain = chain_with_full_fragments[mask, :]

# masked_chain.shape

In [None]:
# for data in dataloader:
#     positions = data["positions"]
#     # print(positions)
    
#     cos_sim = cosine_similarity((positions[0].squeeze().detach().cpu().numpy().flatten().reshape(1, -1)),positions[0].squeeze().detach().cpu().numpy().flatten().reshape(1, -1))

#     cosine = compute_cosine_similarity(positions[0].squeeze().cpu(), positions[0].squeeze().cpu(), mask1=data["linker_mask"][0].squeeze().cpu(), mask2=data["linker_mask"][0].squeeze().cpu())
#     print(cosine)
#     break

In [None]:
# atom_mask_random_molecule