## Important neighboring atoms analysis

### Import Libraries

In [1]:
import os
os.environ["http_proxy"] = "http://web-proxy.informatik.uni-bonn.de:3128"
os.environ["https_proxy"] = "http://web-proxy.informatik.uni-bonn.de:3128"

import copy
import yaml
import numpy as np
import random
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import imageio.v2 as imageio
import networkx as nx
from pysmiles import read_smiles
import seaborn as sns
# from sklearn.decomposition import PCA

import torch
from sklearn.decomposition import PCA
from scipy.spatial.distance import directed_hausdorff

from src.lightning import DDPM
from src.datasets import get_dataloader
from src.visualizer import load_molecule_xyz, load_xyz_files, save_xyz_file
from src.molecule_builder import get_bond_order
from src.utils import add_partial_mean_with_mask
from src import const



In [2]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Load configuration from config.yml
with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)

checkpoint = config['CHECKPOINT']
chains = config['CHAINS']
DATA = config['DATA']
prefix = config['PREFIX']
keep_frames = int(config['KEEP_FRAMES'])
P = config['P']
# Fall back to the CPU when no CUDA device is available, whatever the config says.
device = config['DEVICE'] if torch.cuda.is_available() else 'cpu'
SEED = int(config['SEED'])
REMOVAL = config['REMOVAL']
ATOM_PERTURBATION = config['ATOM_PERTURBATION']
ROTATE = config['ROTATE']
TRANSLATE = config['TRANSLATE']
REFLECT = config['REFLECT']
TRANSFORMATION_SEED = int(config['TRANSFORMATION_SEED'])
SAVE_VISUALIZATION = config['SAVE_VISUALIZATION']
M = int(config['M'])
NUM_SAMPLES = int(config['NUM_SAMPLES'])
PARALLEL_STEPS = int(config['PARALLEL_STEPS'])
TOP_K_PERTURBATION_REMOVAL = config['TOP_K_PERTURBATION_REMOVAL']
LOCAL_MINIMUM_ANALYSIS = config['LOCAL_MINIMUM_ANALYSIS']

print("Random seed: ", SEED)

# The experiment name is the checkpoint file name without its extension.
experiment_name = checkpoint.split('/')[-1].replace('.ckpt', '')

# Exactly one of the two analysis modes must be enabled.
assert (REMOVAL or ATOM_PERTURBATION) and not (REMOVAL and ATOM_PERTURBATION), "Either REMOVAL or ATOM_PERTURBATION must be set to True, but not both or None"

# Create output directories.
# Both directories share the same suffix; only the leading 'chains' / 'final_states' differs.
analysis_suffix = '_' + P + '_seed_' + str(SEED) + '_neighbors_analysis_COM_addition'

if REMOVAL:
    mode_subdir = 'atom_removal'
elif ATOM_PERTURBATION:
    mode_subdir = 'atom_perturbation'
else:
    # Only reachable when assertions are disabled: no mode-specific sub-folder.
    mode_subdir = None

if mode_subdir is not None:
    # The renames are applied to the mode sub-folder only (and exactly once), so both
    # output directories always get the same folder name and an 'atom' substring
    # elsewhere in the path (e.g. in the experiment name) is never rewritten.
    if TOP_K_PERTURBATION_REMOVAL:
        mode_subdir = mode_subdir.replace('atom', 'top_k_atom')

    if LOCAL_MINIMUM_ANALYSIS:
        mode_subdir = mode_subdir.replace('atom', 'local_min')

    analysis_suffix = analysis_suffix + '/' + mode_subdir

chains_output_dir = os.path.join(chains, experiment_name, prefix, 'chains' + analysis_suffix)
final_states_output_dir = os.path.join(chains, experiment_name, prefix, 'final_states' + analysis_suffix)

os.makedirs(chains_output_dir, exist_ok=True)
os.makedirs(final_states_output_dir, exist_ok=True)

# Loading model from checkpoint
model = DDPM.load_from_checkpoint(checkpoint, map_location=device)

# Possibility to evaluate on different datasets (e.g., on CASF instead of ZINC)
model.val_data_prefix = prefix

print(f"Running on device: {device}")
# In case <Anonymous> will run my model or vice versa
if DATA is not None:
    model.data_path = DATA

model = model.eval().to(device)
model.setup(stage='val')
dataloader = get_dataloader(
    model.val_dataset,
    batch_size=1, #@mastro, it was 32
    # batch_size=len(model.val_dataset)
)

print("Model", model)
print("Using anchors as context: ", model.anchors_context)
print("Center of mass:", model.center_of_mass)
print("Inpaiting: ", model.inpainting)

Random seed:  42


/home/mastropietro/anaconda3/envs/diff_explainer/lib/python3.10/site-packages/lightning_fabric/utilities/cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Lightning automatically

Running on device: cuda:2
Model DDPM(
  (edm): EDM(
    (gamma): PredefinedNoiseSchedule()
    (dynamics): Dynamics(
      (dynamics): EGNN(
        (embedding): Linear(in_features=10, out_features=128, bias=True)
        (embedding_out): Linear(in_features=128, out_features=10, bias=True)
        (e_block_0): EquivariantBlock(
          (gcl_0): GCL(
            (edge_mlp): Sequential(
              (0): Linear(in_features=258, out_features=128, bias=True)
              (1): SiLU()
              (2): Linear(in_features=128, out_features=128, bias=True)
              (3): SiLU()
            )
            (node_mlp): Sequential(
              (0): Linear(in_features=256, out_features=128, bias=True)
              (1): SiLU()
              (2): Linear(in_features=128, out_features=128, bias=True)
            )
          )
          (gcl_1): GCL(
            (edge_mlp): Sequential(
              (0): Linear(in_features=258, out_features=128, bias=True)
              (1): SiLU()
          

### Set random seeds

In [3]:
# Seed every random number generator the notebook relies on with the configured SEED.
random.seed(SEED)
np.random.seed(SEED)

# PyTorch: CPU generator first, then all CUDA devices.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

### Utility functions

In [4]:
def arrestomomentum():
    """Debugging helper: stop execution by raising ``KeyboardInterrupt``."""
    message = "Debug interrupt."
    raise KeyboardInterrupt(message)

def compute_hausdorff_distance_batch(mol1, mol2, mask1 = None, mask2 = None, project = None, projection_mask1 = None, projection_mask2 = None,node_projection_mask = None):
    """
    Compute the (symmetric) Hausdorff distance between two batches of molecules.

    Only the first three feature columns (the coordinates) of each molecule are used.
    The tensors must live on the CPU, since they are handed to SciPy.

    Args:
        mol1 (torch.Tensor): The first batch of molecules, indexed as [molecule, atom, feature].
        mol2 (torch.Tensor): The second batch of molecules, indexed as [molecule, atom, feature].
        mask1 (torch.Tensor, optional): Per-molecule atom mask for mol1. Only atoms with a truthy
            entry are considered. If not provided, all atoms are considered.
        mask2 (torch.Tensor, optional): Per-molecule atom mask for mol2. Only atoms with a truthy
            entry are considered. If not provided, all atoms are considered.
        project (str, optional): Projection applied before masking.
            ``None`` leaves the coordinates untouched, ``"origin"`` subtracts the mean over
            all atoms of each molecule, ``"com"`` calls ``add_partial_mean_with_mask`` with
            ``node_projection_mask`` and the respective projection mask.
        projection_mask1 (torch.Tensor, optional): Mask passed to ``add_partial_mean_with_mask``
            for mol1. Required when ``project == "com"``.
        projection_mask2 (torch.Tensor, optional): Mask passed to ``add_partial_mean_with_mask``
            for mol2. Required when ``project == "com"``.
        node_projection_mask (torch.Tensor, optional): Node mask passed to
            ``add_partial_mean_with_mask`` for both molecules (presumably the atom mask of the
            batch -- verify against the caller).

    Returns:
        list: The Hausdorff distances between the corresponding molecules in the batches.

    Raises:
        ValueError: If ``project`` is not one of ``None``, ``"origin"`` or ``"com"``.
    """
    #take only the positions
    mol1 = mol1[:, :, :3]
    mol2 = mol2[:, :, :3]

    if project is None:
        pass
    elif project == "origin":
        mol1 = mol1 - mol1.mean(dim=1, keepdim=True)
        mol2 = mol2 - mol2.mean(dim=1, keepdim=True)
    elif project == "com":
        assert(projection_mask1 is not None and projection_mask2 is not None)
        mol1 = add_partial_mean_with_mask(mol1, node_projection_mask, projection_mask1)
        # Each molecule is projected with its own mask (mol2 previously reused projection_mask1).
        mol2 = add_partial_mean_with_mask(mol2, node_projection_mask, projection_mask2)
    else:
        raise ValueError("Invalid projection type.")

    # Masked molecules are kept as a list of per-molecule tensors, so molecules whose masks
    # select a different number of atoms can be handled in the same batch.
    if mask1 is not None:
        mol1 = _select_masked_atoms(mol1, mask1)

    if mask2 is not None:
        mol2 = _select_masked_atoms(mol2, mask2)

    hausdorff_distances = []
    for i in range(len(mol1)):
        forward = directed_hausdorff(mol1[i], mol2[i])[0]
        backward = directed_hausdorff(mol2[i], mol1[i])[0]
        # The symmetric Hausdorff distance is the larger of the two directed distances.
        hausdorff_distances.append(max(forward, backward))

    return hausdorff_distances


def _select_masked_atoms(mol, mask):
    """Return, for every molecule in ``mol``, only the atoms whose ``mask`` entry is truthy."""
    mask = mask.bool()
    return [mol[i, mask[i], :] for i in range(mol.shape[0])]

def draw_sphere_xai(ax, x, y, z, size, color, alpha):
    """Draw a sphere of radius ``size`` centred at (x, y, z) on the 3D axis ``ax``."""
    # Parametrise the sphere surface on a 100 x 100 grid of angles.
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)
    sin_polar = np.sin(polar)

    surface_x = x + size * np.outer(np.cos(azimuth), sin_polar)
    surface_y = y + size * np.outer(np.sin(azimuth), sin_polar)
    surface_z = z + size * np.outer(np.ones(azimuth.size), np.cos(polar))

    ax.plot_surface(surface_x, surface_y, surface_z, rstride=2, cstride=2, color=color, alpha=alpha)

def plot_molecule_xai(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None, phi_values=None, colors_fragment_shadow=None, draw_atom_indices = False):
    """
    Draw a molecule (bonds and atoms) on a 3D matplotlib axis, highlighting fragment atoms.

    Args:
        ax: 3D axis providing ``plot``, ``scatter``, ``text`` (and ``plot_surface`` for spheres).
        positions: Atom coordinates, indexed as [atom, xyz].
        atom_type: Integer atom-type index of every atom (used to index ``const.COLORS`` /
            ``const.RADII`` and the idx-to-atom mapping).
        alpha (float): Base opacity of bonds and atoms.
        spheres_3d (bool): Draw atoms as 3D spheres instead of scatter markers.
        hex_bg_color (str): Colour used for the bonds.
        is_geom (bool): Selects ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``.
        fragment_mask (torch.Tensor, optional): 1 for fragment atoms, 0 otherwise.
            Defaults to all ones (every atom is a fragment atom).
        phi_values (sequence, optional): Importance values. In scatter mode they are mapped to
            marker edge colours with the reversed ``coolwarm`` colormap; in sphere mode fragment
            atoms with a positive value are drawn in red.
        colors_fragment_shadow (optional): Pre-computed marker edge colours for the fragment
            atoms (scatter mode). Mutually exclusive with ``phi_values``.
        draw_atom_indices: ``None`` or ``False`` draws no labels (scatter mode only);
            ``"original"`` labels every atom with its index in ``positions``; any other value is
            treated as a pair ``(fragment_labels, non_fragment_labels)``.

    Raises:
        ValueError: In scatter mode, if both or neither of ``phi_values`` and
            ``colors_fragment_shadow`` are provided.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    # Hydrogen, Carbon, Nitrogen, Oxygen, Fluorine

    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    colors_dic = np.array(const.COLORS)
    radius_dic = np.array(const.RADII)
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if fragment_mask is None:
        fragment_mask = torch.ones(len(x))

    # Draw a line for every atom pair that get_bond_order reports as bonded.
    line_width = (3 - 2) * 2 * 2
    for i in range(len(x)):
        for j in range(i + 1, len(x)):
            p1 = np.array([x[i], y[i], z[i]])
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            draw_edge_int = get_bond_order(atom1, atom2, dist)
            if draw_edge_int > 0:
                # Bond order 4 is drawn thicker than the others.
                linewidth_factor = 1.5 if draw_edge_int == 4 else 1
                linewidth_factor *= 0.5
                ax.plot(
                    [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=line_width * linewidth_factor * 2,
                    c=hex_bg_color,
                    alpha=alpha
                )

    if spheres_3d:
        # Without importance values no atom is recoloured (previously zip(None) raised).
        sphere_phis = phi_values if phi_values is not None else [None] * len(x)

        for i, j, k, s, c, f, phi in zip(x, y, z, radii, colors, fragment_mask, sphere_phis):
            if f == 1:
                # NOTE: alpha is not reset afterwards, so it stays at 1.0 for all following atoms.
                alpha = 1.0
                if phi is not None and phi > 0:
                    c = 'red'

            draw_sphere_xai(ax, i.item(), j.item(), k.item(), 0.5 * s, c, alpha)

    else:
        fragment_mask_on_cpu = fragment_mask.cpu().numpy()
        is_fragment = fragment_mask_on_cpu == 1
        is_non_fragment = fragment_mask_on_cpu == 0

        # Resolve the marker edge colours of the fragment atoms.
        if phi_values is not None and colors_fragment_shadow is None:
            phi_values_array = np.array(phi_values)
            # Calculate the gradient colors based on phi values
            cmap = plt.get_cmap('coolwarm_r') #reversed heatmap for distance-based importance
            norm = plt.Normalize(vmin=min(phi_values_array), vmax=max(phi_values_array))
            colors_fragment_shadow = cmap(norm(phi_values_array))
        elif colors_fragment_shadow is not None and phi_values is None:
            pass
        else:
            raise ValueError("Either phi_values or colors_fragment_shadow must be provided, not both.")

        # Resolve the atom labels. The default (False) and None both mean "no labels";
        # previously False fell through to the custom branch and failed on False[0].
        if draw_atom_indices is None or draw_atom_indices is False:
            fragment_labels, non_fragment_labels = (), ()
        elif isinstance(draw_atom_indices, str) and draw_atom_indices == "original":
            fragment_labels = np.where(is_fragment)[0]
            non_fragment_labels = np.where(is_non_fragment)[0]
        else:
            fragment_labels, non_fragment_labels = draw_atom_indices[0], draw_atom_indices[1]

        #draw fragments
        colors_fragment = colors[is_fragment]
        x_fragment = x[is_fragment]
        y_fragment = y[is_fragment]
        z_fragment = z[is_fragment]
        areas_fragment = areas[is_fragment]

        ax.scatter(x_fragment, y_fragment, z_fragment, s=areas_fragment, alpha=0.9 * alpha, c=colors_fragment, edgecolors=colors_fragment_shadow, linewidths=5, rasterized=False)

        for i, txt in enumerate(fragment_labels):
            ax.text(x_fragment[i], y_fragment[i], z_fragment[i], str(txt), color='black', fontsize=15)

        #draw non-fragment atoms
        colors = colors[is_non_fragment]
        x = x[is_non_fragment]
        y = y[is_non_fragment]
        z = z[is_non_fragment]
        areas = areas[is_non_fragment]
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors, rasterized=False)

        for i, txt in enumerate(non_fragment_labels):
            ax.text(x[i], y[i], z[i], str(txt), color='black', fontsize=15)

        



def plot_data3d_xai(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None, phi_values=None, colors_fragment_shadow=None, draw_atom_indices = False):
    """
    Render a single molecule in a 3D figure and either save it or show it.

    Args:
        positions (torch.Tensor): Atom coordinates, indexed as [atom, xyz].
        atom_type: Integer atom-type index of every atom.
        is_geom (bool): Forwarded to ``plot_molecule_xai``.
        camera_elev (float): Elevation of the camera.
        camera_azim (float): Azimuth of the camera.
        save_path (str, optional): Where to write the image. If ``None`` the figure is shown.
        spheres_3d (bool): Draw atoms as spheres; saved images are additionally brightened.
        bg (str): ``'black'`` for a black background, anything else for a white one.
        alpha (float): Base opacity forwarded to ``plot_molecule_xai``.
        fragment_mask, phi_values, colors_fragment_shadow, draw_atom_indices:
            Forwarded unchanged to ``plot_molecule_xai``.
    """
    black = (0, 0, 0)
    white = (1, 1, 1)
    dark_background = bg == 'black'
    # Bonds are drawn in the colour that contrasts with the background.
    hex_bg_color = '#FFFFFF' if dark_background else '#000000' #'#666666'

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(black if dark_background else white)
        ax.xaxis.pane.set_alpha(0)
        ax.yaxis.pane.set_alpha(0)
        ax.zaxis.pane.set_alpha(0)
        ax._axis3don = False

        # ``ax.w_xaxis`` was a deprecated alias of ``ax.xaxis`` and no longer exists in
        # recent matplotlib releases, so the axis is accessed directly.
        ax.xaxis.line.set_color("black" if dark_background else "white")

        plot_molecule_xai(
            ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask, phi_values=phi_values, colors_fragment_shadow=colors_fragment_shadow, draw_atom_indices = draw_atom_indices
        )

        # Symmetric cubic view box that grows with the molecule, clamped to [3.2, 40].
        max_value = positions.abs().max().item()
        axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)
        dpi = 300 #it was 120 and 50

        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)

            if spheres_3d:
                # Brighten the rendered spheres in the saved image.
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
        else:
            plt.show()
    finally:
        # Always release this figure, even if plotting failed; frames are rendered in loops.
        plt.close(fig)

def visualize_chain_xai(
        path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None, phi_values=None, colors_fragment_shadow=None, draw_atom_indices = False
):
    """
    Render every xyz frame found under ``path`` to a PNG and assemble them into ``output.gif``.

    All frames are rotated with a PCA fitted on the last file returned by ``load_xyz_files``,
    so the whole chain is shown in the orientation that suits that frame best.
    The remaining keyword arguments are forwarded to ``plot_data3d_xai``.
    If ``wandb`` is given, the GIF is also logged under the key ``mode``.
    """
    xyz_files = load_xyz_files(path)

    # The orientation is taken from the final molecule and reused for every frame.
    reference_positions, _, _ = load_molecule_xyz(xyz_files[-1], is_geom=is_geom)
    orientation = PCA(n_components=3)
    orientation.fit(reference_positions)

    frame_paths = []
    for xyz_file in xyz_files:
        frame_positions, frame_one_hot, _ = load_molecule_xyz(xyz_file, is_geom=is_geom)
        frame_atom_type = torch.argmax(frame_one_hot, dim=1).numpy()

        # Express the frame in the reference orientation.
        aligned_positions = torch.tensor(orientation.transform(frame_positions))

        # Same file name with the 4-character extension replaced by '.png'.
        png_path = xyz_file[:-4] + '.png'
        plot_data3d_xai(
            aligned_positions, frame_atom_type,
            save_path=png_path,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
            phi_values=phi_values,
            colors_fragment_shadow=colors_fragment_shadow,
            draw_atom_indices = draw_atom_indices
        )
        frame_paths.append(png_path)

    # Read the rendered frames back and write the animation next to them.
    frames = [imageio.imread(png_path) for png_path in frame_paths]
    gif_path = os.path.dirname(frame_paths[0]) + '/output.gif'
    imageio.mimsave(gif_path, frames, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})

### Important atom analysis

In [5]:
num_samples = NUM_SAMPLES
sampled = 0
start = 0

INTIAL_DISTIBUTION_PATH = "initial_distributions/seed_" + str(SEED)
SHAPLEY_VALUES_FOLDER = "results/shapley_values/"

# Collect the first `num_samples` batches of the validation dataloader.
# The loop stops as soon as enough batches were gathered instead of iterating
# over (and loading) the rest of the dataset for nothing.
data_list = []
if num_samples > 0:
    for data in dataloader:
        data_list.append(data)
        sampled += 1
        if sampled >= num_samples:
            break


# load initial distribution of noisy features and positions
noisy_features = torch.load(INTIAL_DISTIBUTION_PATH + "/noisy_features_seed_" + str(SEED) + ".pt", map_location=device, weights_only=True)
noisy_positions = torch.load(INTIAL_DISTIBUTION_PATH + "/noisy_positions_seed_" + str(SEED) + ".pt", map_location=device, weights_only=True)

### Atom Perturbation/Removal

In [6]:
top_k_neighs_dict = {3:5, 4:3, 7:8, 10: 6, 13: 10, 14:5,  22: 5, 23: 9, 27: 10, 28:5, 29: 8}
num_atom_type_perturbations = 5

for data_index, data in enumerate(tqdm(data_list)):

    if data_index not in top_k_neighs_dict.keys():
        continue

    top_k_neighs = top_k_neighs_dict[data_index]

    smile = data["name"][0]
    mol = read_smiles(smile)

    noisy_positions_present_atoms = noisy_positions.clone()
    noisy_features_present_atoms = noisy_features.clone()

    noisy_positions_present_atoms = noisy_positions_present_atoms[:, :data["positions"].shape[1], :]
    noisy_features_present_atoms = noisy_features_present_atoms[:, :data["one_hot"].shape[1], :]

    num_fragment_atoms = int(data["fragment_mask"].sum().item())

    #generate sample chain using original fragment atoms
    
    print("Original molecule positions: ", data["positions"])
    
    chain_batch_original, node_mask_original = model.sample_chain(data, keep_frames=keep_frames, noisy_positions=noisy_positions_present_atoms, noisy_features=noisy_features_present_atoms)
    
    #load Shapley values for Hausdorff distance
    phi_values = []
    
    
    with open(f"{SHAPLEY_VALUES_FOLDER}explanations_hausdorff_distance_{P}_seed_{str(SEED)}/phi_atoms_{data_index}.txt", "r") as read_file:
        read_file.readline()
        read_file.readline()
        for row in read_file:
            if row.strip() == "":
                break
            line = row.strip().split(",")
            phi_values.append(float(line[3])) #3 for hausdorff distance-based Shapley values

    phi_values_array = np.array(phi_values)
    cmap = plt.cm.get_cmap('coolwarm_r') #reversed heatmap for distance-based importance
    norm = plt.Normalize(vmin=min(phi_values_array), vmax=max(phi_values_array))
    colors_fragment_shadow = cmap(norm(phi_values_array))

    #take first and only molecule
    chain_batch_original_molecule = chain_batch_original[:, 0, :, :]
    
    chain_final_frame_0 = chain_batch_original_molecule[0, :, :]
    chain_final_frame_0_batch = chain_final_frame_0.repeat(keep_frames, 1, 1)
    original_linker_mask_batch = data["linker_mask"][0].squeeze().repeat(keep_frames, 1).cpu()

    # hausdorff_distance_original = compute_hausdorff_distance_batch(chain_final_frame_0_batch.cpu(), chain_final_frame_0_batch.cpu(), mask1=original_linker_mask_batch, mask2=original_linker_mask_batch)
    
    # print("Hausdorff distance between the frame 0 and the frame 0: ", hausdorff_distance_original)
    # arrestomomentum()
    # Compute the difference between the original positions and the positions in chain_batch_original_molecule
    original_positions = data["positions"][0]
    chain_positions = chain_batch_original_molecule[0, :, :3]  # Assuming the first 3 columns are the positions

    position_differences = original_positions - chain_positions
    position_differences = position_differences[data["fragment_mask"].squeeze().bool()][0]
    
    #add position differences to the chain_positions positions
    chain_final_frame_0_batch[:, :, :3] = chain_final_frame_0_batch[:, :, :3] + position_differences
    chain_batch_original_molecule[:, :, :3] = chain_batch_original_molecule[:, :, :3] + position_differences
    
    #compute Hausdorff distance between the frame 0 and the rest of the frames
    hausdorff_distances_original = compute_hausdorff_distance_batch(chain_final_frame_0_batch.cpu(), chain_batch_original_molecule.cpu(), mask1=original_linker_mask_batch, mask2=original_linker_mask_batch) #the linker atoms are the same since those are the frames of a single molecule
    # print("Node mask original: ", node_mask_original)
    # hausdorff_distances_original_projected = compute_hausdorff_distance_batch(chain_final_frame_0_batch.cpu(), chain_batch_original_molecule.cpu(), mask1=original_linker_mask_batch, mask2=original_linker_mask_batch) #the linker atoms are the same since those are the frames of a single molecule
    # hausdorff_distances_original_projected = compute_hausdorff_distance_batch(chain_final_frame_0_batch.cpu(), chain_batch_original_molecule.cpu(), mask1=original_linker_mask_batch, mask2=original_linker_mask_batch, project="origin")
    print("Hausdorff distances: ", hausdorff_distances_original)
    
    

    #save and visualize original chain

    for i in range(len(data['positions'])):
            chain = chain_batch_original[:, i, :, :]
            assert chain.shape[0] == keep_frames
            assert chain.shape[1] == data['positions'].shape[1]
            assert chain.shape[2] == data['positions'].shape[2] + data['one_hot'].shape[2] + model.include_charges

            # Saving chains
            name = str(i + start)
            if ATOM_PERTURBATION:
                chain_output = os.path.join(chains_output_dir, str(data_index), "original")
            elif REMOVAL:
                chain_output = os.path.join(chains_output_dir, str(data_index), "original")
            os.makedirs(chain_output, exist_ok=True)
            
            #save initial random distrubution with noise
            positions_combined = torch.zeros_like(data['positions'])
            one_hot_combined = torch.zeros_like(data['one_hot'])

            # Iterate over each atom and decide whether to use original or noisy data
            for atom_idx in range(data['positions'].shape[1]):
                if data['fragment_mask'][0, atom_idx] == 1:
                    # Use original positions and features for fragment atoms
                    positions_combined[:, atom_idx, :] = data['positions'][:, atom_idx, :]
                    one_hot_combined[:, atom_idx, :] = data['one_hot'][:, atom_idx, :]
                    # atom_mask_combined[:, atom_idx] = data['atom_mask'][:, atom_idx]
                else:
                    # Use noisy positions and features for linker atoms
                    positions_combined[:, atom_idx, :] = noisy_positions_present_atoms[:, atom_idx, :]
                    one_hot_combined[:, atom_idx, :] = noisy_features_present_atoms[:, atom_idx, :]

            #save initial distribution TODO: fix positions, they are not centered
            
            save_xyz_file(
                chain_output,
                one_hot_combined,
                positions_combined,
                node_mask_original[i].unsqueeze(0),
                names=[f'{name}_' + str(keep_frames)],
                is_geom=model.is_geom
            )

            # one_hot = chain[:, :, 3:-1]
            one_hot = chain[:, :, 3:] #@mastro, added last atom type (not sure whyt it was not included...) However, TODO check again
            positions = chain[:, :, :3]
            chain_node_mask = torch.cat([node_mask_original[i].unsqueeze(0) for _ in range(keep_frames)], dim=0)
            names = [f'{name}_{j}' for j in range(keep_frames)]

            save_xyz_file(chain_output, one_hot, positions, chain_node_mask, names=names, is_geom=model.is_geom)

            visualize_chain_xai(
                chain_output,
                spheres_3d=False,
                alpha=0.7,
                bg='white',
                is_geom=model.is_geom,
                fragment_mask=data['fragment_mask'][i].squeeze(),
                phi_values=None,
                colors_fragment_shadow=colors_fragment_shadow,
                draw_atom_indices="original"
            )

            # Saving final prediction and ground truth separately
            true_one_hot = data['one_hot'][i].unsqueeze(0)
            true_positions = data['positions'][i].unsqueeze(0)
            true_node_mask = data['atom_mask'][i].unsqueeze(0)

            #TODO chech if the final positions are saved in the correct folder
            final_states_output_dir_current = None

            if REMOVAL:
                final_states_output_dir_current = os.path.join(final_states_output_dir, str(data_index), "original")
            if ATOM_PERTURBATION:
                final_states_output_dir_current = os.path.join(final_states_output_dir, str(data_index), "original")
           
            os.makedirs(final_states_output_dir_current, exist_ok=True)
            save_xyz_file(
                final_states_output_dir_current,
                true_one_hot,
                true_positions,
                true_node_mask,
                names=[f'{name}_true'],
                is_geom=model.is_geom,
            )

            pred_one_hot = chain[0, :, 3:-1].unsqueeze(0)
            pred_positions = chain[0, :, :3].unsqueeze(0)
            pred_node_mask = chain_node_mask[0].unsqueeze(0)
            save_xyz_file(
                final_states_output_dir_current,
                pred_one_hot,
                pred_positions,
                pred_node_mask,
                names=[f'{name}_true'],
                is_geom=model.is_geom
            )

    
    # Create a line plot for Hausdorff distances
    plt.figure(figsize=(10, 6))
    plt.gca().set_facecolor('white')
    #reverse hausdorff distances
    hausdorff_distances_original = hausdorff_distances_original[::-1]
    sns.lineplot(data=hausdorff_distances_original, marker='o')
    #plt.title('Hausdorff Distance Trend')
    plt.xlabel('Frame')
    plt.ylabel('Hausdorff Distance')
    plt.xticks(ticks=range(keep_frames), labels=range(keep_frames-1, -1, -1))  # Show all 10 frames on the x-axis
    plt.ylim(bottom=0)  # Ensure the y-axis starts at 0
    #add white background


    SAVE_PATH = f"results/plots/neighbor_analysis_COM_addition/{data_index}/"

    if ATOM_PERTURBATION:
        if TOP_K_PERTURBATION_REMOVAL:
            SAVE_PATH += "perturbation_top_k/"
        else:
            SAVE_PATH += "perturbation/"

    elif REMOVAL:
        if TOP_K_PERTURBATION_REMOVAL:
            SAVE_PATH += "removal_top_k/"
        else:
            SAVE_PATH += "removal/"

    if LOCAL_MINIMUM_ANALYSIS:
        SAVE_PATH = SAVE_PATH.replace("neighbor_analysis_COM_addition", "neighbor_analysis_COM_addition_local_min_analysis")


    


    os.makedirs(SAVE_PATH, exist_ok=True)

    plt.savefig(SAVE_PATH + "hausdorff_distance_trend_original.png", dpi = 300)
    plt.savefig(SAVE_PATH + "hausdorff_distance_trend_original.pdf", dpi = 300)
    
    plt.close()

    fragment_mask = data["fragment_mask"].squeeze().bool()
    linker_mask = data["linker_mask"].squeeze().bool()
    phi_values_tensor = torch.tensor(phi_values)

    #get indices of phi_values_tensor from lower to higher
    sorted_phi_values, sorted_indices = torch.sort(phi_values_tensor)
    # reversed_indices = torch.flip(sorted_indices, [0])

    #take top k neighbors
    shapley_value_indices_keep = sorted_indices[top_k_neighs:]
    top_indices_to_perturb = sorted_indices[:top_k_neighs]
    sorted_phi_values = sorted_phi_values[:top_k_neighs]
    print("Top k neighbors: ", top_indices_to_perturb)
    print("shapley_value_indices_keep: ", shapley_value_indices_keep)
    #print linker atoms
    print("Linker atoms: ", torch.where(linker_mask)[0])
    #print fragment atoms
    print("Fragment atoms: ", torch.where(fragment_mask)[0])
    
    for swap_num in tqdm(range(num_atom_type_perturbations)):
        # data_temp = data.copy()
        data_temp = copy.deepcopy(data)
        

        noisy_positions_present_atoms_temp = noisy_positions_present_atoms.clone()
        noisy_features_present_atoms_temp = noisy_features_present_atoms.clone()
        
    
        
        #retrieve indices of fragment and linker atoms from atom_mask
        fragment_atoms_indices = torch.where(fragment_mask)[0]
        fragment_atoms_indices = fragment_atoms_indices.to(device)
        # Indices of the linker atoms (positions where linker_mask is non-zero),
        # moved to the compute device so they can be concatenated with other
        # index tensors further down.
        linker_atoms_indices = torch.where(linker_mask)[0].to(device)

        # Randomly change the atom type (one-hot) of the atoms listed in
        # top_indices_to_perturb.
        if ATOM_PERTURBATION:
            # Select which of those atoms are perturbed in this iteration:
            #   * random mode (TOP_K_PERTURBATION_REMOVAL is false) or swap_num == 0:
            #     every atom in top_indices_to_perturb;
            #   * top-k mode with swap_num != 0: only the first
            #     min(swap_num, len(top_indices_to_perturb)) atoms.
            # (The former `num_atoms_to_perturb > len(...)` guard could never be
            # true after the min() and has been dropped.)
            if TOP_K_PERTURBATION_REMOVAL and swap_num != 0:
                num_atoms_to_perturb = min(swap_num, len(top_indices_to_perturb))
                indices_to_perturb = top_indices_to_perturb[:num_atoms_to_perturb]
            else:
                indices_to_perturb = top_indices_to_perturb

            num_atom_types = data_temp["one_hot"].shape[2]
            for idx in indices_to_perturb:
                # Anchor atoms are left as they are since we are analyzing the neighbors.
                if data_temp["anchors"][0, idx] == 1:
                    continue

                current_atom_type = torch.argmax(data_temp["one_hot"][:, idx, :], dim=1, keepdim=True)

                # Rejection sampling: redraw until the new type differs from the current one.
                # NOTE(review): the comparison is used as a Python bool, which presumably
                # relies on a batch size of 1 -- confirm.
                while True:
                    random_atom_type = torch.randint(0, num_atom_types, (1,), device=device)
                    if random_atom_type != current_atom_type:
                        break

                # Overwrite the one-hot row of this atom with the newly drawn type.
                data_temp["one_hot"][:, idx, :] = 0
                data_temp["one_hot"][:, idx, random_atom_type] = 1


    
        if REMOVAL:
            # Build the set of fragment atoms that survive this iteration, then
            # slice every per-atom tensor of data_temp (and the noisy inputs) down
            # to those fragment atoms plus all linker atoms.
            kept_by_ranking = shapley_value_indices_keep.to(device)
            num_top_atoms = len(top_indices_to_perturb)

            if swap_num == 0:
                # First iteration (both modes): remove all top-k atoms and make
                # sure anchor atoms are kept.
                anchor_indices = torch.where(data_temp["anchors"].squeeze() == 1)[0].to(device)
                fragment_atoms_indices_keep = torch.cat((kept_by_ranking, anchor_indices))
            elif not TOP_K_PERTURBATION_REMOVAL:
                # Random mode: add a random, non-empty, strict subset of the top
                # atoms back to the kept set.
                if num_top_atoms > 1:
                    num_atoms_to_add = np.random.randint(1, num_top_atoms)
                else:
                    # np.random.randint(1, n) raises for n <= 1 (empty range), so
                    # fall back to adding whatever is there (0 or 1 atom).
                    # NOTE(review): confirm this is the desired behaviour for a
                    # single top atom.
                    num_atoms_to_add = num_top_atoms
                random_atoms_to_add = np.random.choice(top_indices_to_perturb.cpu().numpy(), num_atoms_to_add, replace=False)
                fragment_atoms_indices_keep = torch.cat((kept_by_ranking, torch.tensor(random_atoms_to_add, device=device)))
            else:
                # Top-k mode: drop only the first `swap_num` top atoms and keep the rest.
                # (The former `num_atoms_to_perturb > len(...)` guard could never
                # be true after the min() and has been dropped.)
                num_atoms_to_perturb = min(swap_num, num_top_atoms)
                fragment_atoms_indices_keep = torch.cat((kept_by_ranking, top_indices_to_perturb[num_atoms_to_perturb:].to(device)))
            # NOTE(review): only the swap_num == 0 branch explicitly re-adds the
            # anchor atoms; verify that anchors cannot be dropped in the other branches.

            # Remove duplicates (torch.unique also returns the indices sorted).
            fragment_atoms_indices_keep = torch.unique(fragment_atoms_indices_keep)

            print("Anchor atoms: ", torch.where(data_temp["anchors"].squeeze() == 1)[0])
            print("fragment_atoms_indices_keep: ", fragment_atoms_indices_keep)

            # Index tensors must have an integer dtype; cast explicitly instead of
            # going through the legacy torch.Tensor(...) constructor.
            fragment_atoms_indices_keep_tensor = fragment_atoms_indices_keep.to(device=device, dtype=torch.long)

            # Keep only fragment_atoms_indices_keep and linker_atoms_indices.
            atom_indices_to_keep = torch.cat((fragment_atoms_indices_keep_tensor, linker_atoms_indices)).to(device)

            # Remove atoms from the molecule.
            data_temp["positions"] = data_temp["positions"][:, atom_indices_to_keep, :]
            data_temp["one_hot"] = data_temp["one_hot"][:, atom_indices_to_keep, :]
            data_temp["charges"] = data_temp["charges"][:, atom_indices_to_keep]
            data_temp["fragment_mask"] = data_temp["fragment_mask"][:, atom_indices_to_keep]
            data_temp["linker_mask"] = data_temp["linker_mask"][:, atom_indices_to_keep]
            data_temp["atom_mask"] = data_temp["atom_mask"][:, atom_indices_to_keep]
            data_temp["anchors"] = data_temp["anchors"][:, atom_indices_to_keep]
            # Rebuild the edge mask from the reduced atom mask.
            # NOTE(review): verify that this broadcast yields the pairwise mask
            # layout the model expects for the reduced molecule.
            edge_mask_to_keep = (data_temp["atom_mask"].unsqueeze(1) * data_temp["atom_mask"]).flatten()
            data_temp["edge_mask"] = edge_mask_to_keep

            # Remove the same atoms from the noisy features and positions.
            noisy_positions_present_atoms_temp = noisy_positions_present_atoms_temp[:, atom_indices_to_keep, :]
            noisy_features_present_atoms_temp = noisy_features_present_atoms_temp[:, atom_indices_to_keep, :]

        # Map every phi value to an RGBA colour. The reversed 'coolwarm' map is
        # used because the importance is distance-based.
        phi_values_array = np.array(phi_values)
        # plt.cm.get_cmap is deprecated since Matplotlib 3.7 and removed in 3.9;
        # plt.get_cmap is available across versions.
        cmap = plt.get_cmap('coolwarm_r')
        norm = plt.Normalize(vmin=phi_values_array.min(), vmax=phi_values_array.max())
        colors_fragment_shadow = cmap(norm(phi_values_array))

        # Snapshot of the (possibly reduced) input positions of the first batch
        # element, taken before sampling.
        molecule_perturbation_original_positions = data_temp["positions"].clone()[0]

        if REMOVAL:
            # Drop the colours of the removed atoms so the colour array matches
            # the fragment atoms that were kept.
            colors_fragment_shadow = colors_fragment_shadow[fragment_atoms_indices_keep.cpu().numpy()]
        
        # Run the sampling chain on the perturbed / reduced molecule, starting
        # from the given noisy positions and features. The local-minimum analysis
        # variant additionally receives the original (unperturbed) data and noisy
        # inputs, plus the frame index (keep_frames // 2) passed as `frame_to_add`.
        if not LOCAL_MINIMUM_ANALYSIS:
            chain_batch, node_mask = model.sample_chain(data_temp, keep_frames=keep_frames, noisy_positions=noisy_positions_present_atoms_temp, noisy_features=noisy_features_present_atoms_temp)
        else:
            chain_batch, node_mask = model.sample_chain_atom_addition(data_temp, keep_frames=keep_frames, noisy_positions=noisy_positions_present_atoms_temp, noisy_features=noisy_features_present_atoms_temp, orginal_data = data, noisy_positions_original = noisy_positions_present_atoms, noisy_features_original = noisy_features_present_atoms, frame_to_add = keep_frames//2)

        # Compute the Hausdorff distance with the original linker.
        # All frames of the first batch element (chain_batch is indexed as
        # [frame, batch element, atom, feature] below). This slice is a view, so
        # the in-place shift further down is also reflected in chain_batch.
        chain_batch_molecule_pertubation = chain_batch[:, 0, :, :]

        # Linker mask used for the perturbed chain: after removal the atom count
        # changed, so the reduced linker mask is repeated once per frame;
        # otherwise the original per-frame linker mask still applies.
        mask_to_use = None
        if REMOVAL:
            mask_to_use = data_temp["linker_mask"][0].squeeze().repeat(keep_frames, 1).cpu()
        else:
            mask_to_use = original_linker_mask_batch

        # Positions of frame 0 of the chain.
        chain_perturbation_positions = chain_batch_molecule_pertubation[0, :, :3]  # Assuming the first 3 columns are the positions

        # Offset between the input positions and frame 0, evaluated at the first
        # fragment atom; it is used as a single translation vector.
        position_differences_perturb = molecule_perturbation_original_positions - chain_perturbation_positions
        position_differences_perturb = position_differences_perturb[data_temp["fragment_mask"].squeeze().bool()][0]

        # Shift the positions of every frame by that vector (written in place).
        chain_batch_molecule_pertubation[:, :, :3] = chain_batch_molecule_pertubation[:, :, :3] + position_differences_perturb

        # One distance per frame between the reference chain and the shifted perturbed chain.
        hausdorff_distances_perturbation = compute_hausdorff_distance_batch(chain_final_frame_0_batch.cpu(), chain_batch_molecule_pertubation.cpu(), mask1=original_linker_mask_batch, mask2=mask_to_use) #the linker atoms are the same since those are the frames of a single molecule

        print("Hausdorff distances after perturbation: ", hausdorff_distances_perturbation)
        
        
        # Line plot of the Hausdorff distance over the kept frames.
        # The distances are reversed so that, together with the reversed tick
        # labels below, the x axis reads from frame keep_frames-1 down to frame 0.
        hausdorff_distances_perturbation = hausdorff_distances_perturbation[::-1]

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.set_facecolor('white')
        sns.lineplot(data=hausdorff_distances_perturbation, marker='o', ax=ax)
        ax.set_xlabel('Frame')
        ax.set_ylabel('Hausdorff Distance')
        # One tick per kept frame, labelled in reverse order.
        ax.set_xticks(list(range(keep_frames)))
        ax.set_xticklabels(list(range(keep_frames - 1, -1, -1)))
        ax.set_ylim(bottom=0)  # Ensure the y-axis starts at 0

        os.makedirs(SAVE_PATH, exist_ok=True)

        # os.path.join keeps the files inside SAVE_PATH even when it has no
        # trailing separator (plain string concatenation would not).
        for extension in ("png", "pdf"):
            fig.savefig(os.path.join(SAVE_PATH, f"hausdorff_distance_trend_{swap_num}.{extension}"), dpi=300)

        # Close the figure explicitly; this code runs once per loop iteration.
        plt.close(fig)
        #save and visualize chain (only for the linker use noisy positions for the initial distribution)
        
        
        # For every batch element: save the sampled chain, the noisy starting
        # state, a visualization, and the ground-truth / predicted final states.
        for i in range(len(data_temp['positions'])):
            # chain holds all kept frames of batch element i: [frame, atom, feature].
            chain = chain_batch[:, i, :, :]
            assert chain.shape[0] == keep_frames
            assert chain.shape[1] == data_temp['positions'].shape[1]
            assert chain.shape[2] == data_temp['positions'].shape[2] + data_temp['one_hot'].shape[2] + model.include_charges

            # Saving chains
            name = str(i + start)
            if ATOM_PERTURBATION:
                chain_output = os.path.join(chains_output_dir, str(data_index), "perturbation_" + str(swap_num))
            elif REMOVAL:
                chain_output = os.path.join(chains_output_dir, str(data_index), "removal_" + str(swap_num))
            os.makedirs(chain_output, exist_ok=True)

            # Initial state of the chain: original positions/features for fragment
            # atoms, noisy positions/features for all other (linker) atoms.
            positions_combined = torch.zeros_like(data_temp['positions'])
            one_hot_combined = torch.zeros_like(data_temp['one_hot'])

            # Iterate over each atom and decide whether to use original or noisy data.
            # The fragment mask of batch element 0 is used for the decision.
            for atom_idx in range(data_temp['positions'].shape[1]):
                if data_temp['fragment_mask'][0, atom_idx] == 1:
                    # Use original positions and features for fragment atoms
                    positions_combined[:, atom_idx, :] = data_temp['positions'][:, atom_idx, :]
                    one_hot_combined[:, atom_idx, :] = data_temp['one_hot'][:, atom_idx, :]
                else:
                    # Use noisy positions and features for linker atoms
                    positions_combined[:, atom_idx, :] = noisy_positions_present_atoms_temp[:, atom_idx, :]
                    one_hot_combined[:, atom_idx, :] = noisy_features_present_atoms_temp[:, atom_idx, :]

            # Save the initial distribution under index `keep_frames`, i.e. one
            # past the last chain frame. TODO: fix positions, they are not centered
            save_xyz_file(
                chain_output,
                one_hot_combined,
                positions_combined,
                node_mask[i].unsqueeze(0),
                names=[f'{name}_' + str(keep_frames)],
                is_geom=model.is_geom
            )

            # Save every frame of the chain as `<name>_<frame>`.
            # NOTE(review): the chain frames use feature columns `3:` while the
            # final prediction below uses `3:-1`; which one is right depends on
            # whether a trailing charge column is present (model.include_charges)
            # -- confirm and make the two consistent.
            one_hot = chain[:, :, 3:]
            positions = chain[:, :, :3]
            chain_node_mask = torch.cat([node_mask[i].unsqueeze(0) for _ in range(keep_frames)], dim=0)
            names = [f'{name}_{j}' for j in range(keep_frames)]

            save_xyz_file(chain_output, one_hot, positions, chain_node_mask, names=names, is_geom=model.is_geom)

            # After removal the drawn atoms are identified by their original
            # indices (kept fragment atoms, linker atoms).
            if REMOVAL:
                draw_atom_indices = (fragment_atoms_indices_keep.tolist(), linker_atoms_indices.tolist())
            else:
                draw_atom_indices = "original"

            visualize_chain_xai(
                chain_output,
                spheres_3d=False,
                alpha=0.7,
                bg='white',
                is_geom=model.is_geom,
                fragment_mask=data_temp['fragment_mask'][i].squeeze(),
                phi_values=None,
                colors_fragment_shadow=colors_fragment_shadow,
                draw_atom_indices=draw_atom_indices
            )

            # Saving final prediction and ground truth separately
            true_one_hot = data_temp['one_hot'][i].unsqueeze(0)
            true_positions = data_temp['positions'][i].unsqueeze(0)
            true_node_mask = data_temp['atom_mask'][i].unsqueeze(0)

            # TODO check if the final positions are saved in the correct folder
            final_states_output_dir_current = None

            if REMOVAL:
                final_states_output_dir_current = os.path.join(final_states_output_dir, str(data_index), "removal_" + str(swap_num))
            if ATOM_PERTURBATION:
                final_states_output_dir_current = os.path.join(final_states_output_dir, str(data_index), "perturbation_" + str(swap_num))

            os.makedirs(final_states_output_dir_current, exist_ok=True)
            save_xyz_file(
                final_states_output_dir_current,
                true_one_hot,
                true_positions,
                true_node_mask,
                names=[f'{name}_true'],
                is_geom=model.is_geom,
            )

            # Frame 0 of the chain is saved as the final prediction.
            pred_one_hot = chain[0, :, 3:-1].unsqueeze(0)
            pred_positions = chain[0, :, :3].unsqueeze(0)
            pred_node_mask = chain_node_mask[0].unsqueeze(0)
            # Saved as `<name>_pred`: the prediction previously reused the
            # `<name>_true` name, so both save calls targeted the same name in
            # the same directory.
            save_xyz_file(
                final_states_output_dir_current,
                pred_one_hot,
                pred_positions,
                pred_node_mask,
                names=[f'{name}_pred'],
                is_geom=model.is_geom
            )

        
        del data_temp
        del noisy_features_present_atoms_temp
        del noisy_positions_present_atoms_temp
    start += len(data['positions'])

  0%|          | 0/30 [00:00<?, ?it/s]

Original molecule positions:  tensor([[[ 2.0732e+00,  2.8905e+00,  1.1610e+00],
         [ 1.6141e+00,  2.0553e+00,  3.3000e-03],
         [ 2.7350e-01,  1.6952e+00, -1.2480e-01],
         [-1.4300e-01,  8.8410e-01, -1.1825e+00],
         [ 7.7240e-01,  4.1420e-01, -2.1436e+00],
         [ 2.1292e+00,  7.8020e-01, -2.0248e+00],
         [ 3.1977e+00,  3.0730e-01, -2.9754e+00],
         [ 2.5313e+00,  1.5722e+00, -9.3450e-01],
         [ 1.5150e-01, -9.0660e-01,  3.9207e+00],
         [-4.5100e-01,  6.5280e-01,  3.5543e+00],
         [-1.1879e+00,  5.1300e-02,  2.1060e+00],
         [-1.8739e+00,  8.2550e-01,  1.2206e+00],
         [-2.8154e+00,  2.0660e-01,  2.9710e-01],
         [-3.9177e+00,  1.2508e+00,  1.3900e-01],
         [-3.1829e+00,  2.5628e+00,  3.4380e-01],
         [-2.2021e+00,  2.2234e+00,  1.4632e+00],
         [-1.0060e+00, -1.2406e+00,  1.9257e+00],
         [-2.0120e-01, -1.7865e+00,  2.9222e+00],
         [ 2.1110e-01, -3.2219e+00,  2.7901e+00],
         [ 2.1440e+0

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [3.8590835366740657, 3.8330281166813434, 3.701378094542503, 3.720864097936258, 4.107783215401667, 3.522466400218213, 3.4498484210976503, 2.4883117041752114, 3.1104863137719114, 4.217994201621006]
Hausdorff distances after perturbation:  [0.9064174765975634, 1.0327992903862515, 1.3028442280832946, 1.2104647115444678, 1.4083082573774468, 2.589219398745811, 3.030396983159625, 3.5602466948995843, 3.820596574709031, 4.893379746422174]
Hausdorff distances after perturbation:  [1.314245964473664, 1.7258364433545101, 1.5853590042513168, 1.7261754172617259, 1.7922235691807615, 2.4754738574576143, 2.9494724825753593, 3.280931009513308, 4.104351472735579, 4.62354857148065]
Hausdorff distances after perturbation:  [1.603813065898006, 1.4521047780477445, 1.2670327423674421, 1.3034801726367868, 1.6167847697685231, 2.2750326573294415, 2.5880226856505972, 3.2757287967368303, 3.7765621418436988, 4.5529250117067726]
Hausdorff distances after perturbation:  [1.479

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [1.8520142736710339, 1.9380118331202334, 2.0716447074096003, 2.0051112980655432, 1.702206831384017, 3.025874162269763, 3.2793958564858476, 3.330158483858967, 3.665090654628, 4.340000109484301]
Hausdorff distances after perturbation:  [1.5833203893778651, 1.5333301677206845, 1.4813674552751246, 1.7307093115970964, 1.7870531489493848, 2.588033584743149, 2.5869104223308783, 3.570020053459127, 4.1216293125061085, 5.140730775019647]
Hausdorff distances after perturbation:  [2.0954653574476794, 2.0323925338874322, 1.8233538710671018, 1.3429994280616544, 1.442765466758824, 2.227142148175343, 2.643956271869048, 2.7245158261874445, 3.321548540307178, 4.6962073493636645]
Hausdorff distances after perturbation:  [8.724077226829191, 8.784694321638767, 8.972312840063463, 8.391752522071355, 8.0247068057759, 6.845250326374131, 6.4791748242374405, 5.151679780496033, 5.181858173766214, 4.43793006407999]
Hausdorff distances after perturbation:  [10.85799056785269

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [8.70307238023404, 9.034000832317712, 9.146086747320139, 9.181607659923325, 9.029203410109751, 8.61241021852772, 8.253105819404556, 7.147212725756179, 5.848271999286794, 6.010954445872563]
Hausdorff distances after perturbation:  [9.832385778089572, 9.748256105558522, 10.118054676838037, 9.883712144364415, 9.198557801808246, 8.210784182214022, 6.799379593875057, 6.518280983265778, 6.147613719398373, 5.328972008948825]
Hausdorff distances after perturbation:  [5.560829380990712, 5.367933056938565, 5.204526163188032, 5.290552183950218, 5.603697445634477, 5.337404873551022, 5.260472885390637, 5.487958773979327, 5.959511327275319, 5.780531851735671]
Hausdorff distances after perturbation:  [7.088095607460498, 6.908384279526636, 7.0075078198557845, 6.797747322146185, 6.716220481961666, 6.962433364171244, 6.718654971043645, 6.068476216059247, 5.957769803247719, 5.396327595037358]
Hausdorff distances after perturbation:  [7.009291328820189, 6.840965202

Atom "[C@H]" contains stereochemical information that will be discarded.
Atom "[C@@H]" contains stereochemical information that will be discarded.


Original molecule positions:  tensor([[[-3.8683e+00, -3.2571e+00, -8.6300e-02],
         [-3.6839e+00, -2.3998e+00,  1.1574e+00],
         [-3.9001e+00, -9.2840e-01,  8.7380e-01],
         [-5.1653e+00, -3.8720e-01,  1.1738e+00],
         [-5.4477e+00,  9.5590e-01,  9.3740e-01],
         [-4.4707e+00,  1.7778e+00,  3.9230e-01],
         [-3.2129e+00,  1.2572e+00,  8.1100e-02],
         [-2.9101e+00, -9.1000e-02,  3.1300e-01],
         [ 3.0153e+00, -9.9230e-01, -1.4359e+00],
         [ 3.9493e+00, -4.0860e-01, -3.5770e-01],
         [ 4.2778e+00,  1.0994e+00, -5.2020e-01],
         [ 3.1568e+00,  2.1093e+00, -2.2060e-01],
         [ 3.7140e+00,  3.5211e+00, -8.3100e-02],
         [ 2.2220e+00,  2.1212e+00, -1.2974e+00],
         [ 3.4915e+00, -7.5400e-01,  1.0617e+00],
         [-3.8810e-01,  1.1959e+00, -7.7970e-01],
         [-5.4160e-01,  1.2000e-02, -5.1060e-01],
         [-1.6488e+00, -6.3330e-01, -3.2000e-03],
         [ 6.4030e-01, -9.6800e-01, -7.1160e-01],
         [ 5.8010e-0

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [3.9910927713905697, 3.90179819409926, 4.200738003639885, 3.8842929760842164, 3.8689003333612386, 2.8534488070287494, 2.12819335093747, 2.38463465470897, 1.5057583166650697, 1.8658938262910385]
Hausdorff distances after perturbation:  [1.519464159737649, 1.3393322006481154, 1.2237607782993765, 1.2941291017532808, 1.3398146972812497, 1.5390227217332986, 1.7959693386858298, 1.8618604966880878, 2.693005999102912, 2.384141986066668]
Hausdorff distances after perturbation:  [2.799643643359146, 2.620492968792008, 3.029304940822284, 3.2577570991254747, 3.1646740067778865, 3.3791675491427577, 3.1397598586850157, 2.6944551540416937, 2.7327715333545846, 3.8588898693116414]
Hausdorff distances after perturbation:  [2.9588490894719177, 2.6057140179587956, 2.241515267123896, 2.332681445593096, 2.5789547497197414, 2.2856064066850963, 2.070600156689947, 2.19618638864709, 2.2960360149192236, 2.953349256937352]
Hausdorff distances after perturbation:  [1.9798797

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [3.3351790506439136, 3.6368429104643636, 3.297110624040619, 3.1636386556716647, 2.851478051015545, 3.5463506821046633, 3.7925266844198013, 3.4430274582013363, 3.774021459484221, 3.4703912614383414]
Hausdorff distances after perturbation:  [3.2181416943728918, 3.278694342729693, 3.138859506682728, 3.3329479898465237, 3.341387969003436, 2.825343668761657, 2.650609859824331, 3.9720651057761285, 3.7268552921467966, 3.8038108056109694]
Hausdorff distances after perturbation:  [4.295469956845467, 3.8681749626098583, 3.6782346167574054, 3.6877919694110983, 3.9607850568620844, 4.156434279700867, 4.704064667185193, 3.9474493646598297, 3.84451030526112, 4.303050852021345]
Hausdorff distances after perturbation:  [1.2551673320526522, 1.2540310671180135, 1.2548600001709946, 1.5151569868651114, 1.3896094397580414, 1.9053033043898664, 2.378066800838497, 3.258121506635795, 3.764668401930867, 3.754022909811986]
Hausdorff distances after perturbation:  [5.062917

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [0.0796954637525804, 0.3539682309418692, 0.6549905792394001, 0.8701079183057903, 1.056481248455855, 1.7470001471584191, 1.3812508939998906, 0.9957769489339678, 2.107957506037192, 2.5986065497239754]
Hausdorff distances after perturbation:  [0.053298197801418055, 0.3294557698107852, 0.5734733907943019, 0.7772585870712001, 1.0840718510978873, 1.307573436209762, 1.9147298574569203, 2.369219816254503, 1.395479837507441, 2.5566036401753114]
Hausdorff distances after perturbation:  [1.3666803566316221, 1.2480629917035666, 1.6932801322215576, 1.8863378407439144, 2.1519059774491356, 1.8915474435564874, 2.290578396128957, 2.0012559167855217, 2.6971595187264668, 2.7805241928131417]
Hausdorff distances after perturbation:  [1.3807277913168057, 1.3793459248523219, 1.2636763877900057, 1.6308771682273315, 1.7873107491993532, 2.5223109517928934, 2.4400600477560452, 2.771085215023119, 2.558246817334636, 3.6955756177682746]
Hausdorff distances after perturbation

Atom "[C@H]" contains stereochemical information that will be discarded.


Original molecule positions:  tensor([[[ 6.1550,  0.2015, -1.2132],
         [ 5.0265,  1.0165, -0.6562],
         [ 4.9306,  2.3803, -0.9293],
         [ 3.8997,  3.1418, -0.3636],
         [ 2.9507,  2.5832,  0.4836],
         [ 1.9885,  3.3773,  0.9976],
         [ 1.1009,  2.8476,  1.8678],
         [ 1.0816,  1.5165,  2.2196],
         [ 2.0459,  0.6925,  1.6400],
         [ 3.0377,  1.2082,  0.7742],
         [ 4.0872,  0.4392,  0.2036],
         [-1.3967, -1.6526, -2.2339],
         [-1.4795, -1.0204, -1.1883],
         [-2.4040,  0.1412, -1.0884],
         [-1.8508,  1.4023, -0.8073],
         [-2.6874,  2.5019, -0.6737],
         [-4.0697,  2.3469, -0.8114],
         [-4.6382,  1.0987, -1.1000],
         [-6.0322,  0.9680, -1.2433],
         [-6.6125, -0.2653, -1.5380],
         [-5.8065, -1.3855, -1.6884],
         [-4.4202, -1.2749, -1.5434],
         [-3.8033, -0.0372, -1.2496],
         [ 2.3126, -1.0816,  3.1057],
         [ 1.9854, -0.7463,  1.9730],
         [ 1.5426, -

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [4.906064588743527, 4.981706067075841, 4.690796441858776, 4.778620543294901, 4.52481999397492, 4.355676307526366, 4.33429512114737, 5.011600710218181, 4.356170328321915, 4.309471022110696]
Hausdorff distances after perturbation:  [2.403587981384291, 2.4159536765312835, 1.9213147056883642, 1.8641786507601392, 2.005103422103644, 2.2013495026385703, 3.5700201261713733, 3.7742290493834085, 4.43412204141872, 4.882108045681253]
Hausdorff distances after perturbation:  [2.859488165283811, 2.9098726383012816, 3.0688774017075406, 3.3576437792562577, 3.393333885249543, 3.330767421210555, 3.395154282427546, 3.617080055832668, 3.848452586786462, 3.850207024354355]
Hausdorff distances after perturbation:  [2.56302554643613, 2.437224752928352, 2.832666973867344, 2.590398711961935, 3.146658987397474, 3.6280025466318833, 3.6332052197395766, 4.465923455541776, 3.7023309019172976, 4.3382455295956355]
Hausdorff distances after perturbation:  [4.485306222309643, 4.

Atom "[C@H]" contains stereochemical information that will be discarded.


Original molecule positions:  tensor([[[ 6.1550,  0.2015, -1.2132],
         [ 5.0265,  1.0165, -0.6562],
         [ 4.9306,  2.3803, -0.9293],
         [ 3.8997,  3.1418, -0.3636],
         [ 2.9507,  2.5832,  0.4836],
         [ 1.9885,  3.3773,  0.9976],
         [ 1.1009,  2.8476,  1.8678],
         [ 1.0816,  1.5165,  2.2196],
         [ 2.0459,  0.6925,  1.6400],
         [ 3.0377,  1.2082,  0.7742],
         [ 4.0872,  0.4392,  0.2036],
         [-1.8508,  1.4023, -0.8073],
         [-2.6874,  2.5019, -0.6737],
         [-4.0697,  2.3469, -0.8114],
         [-4.6382,  1.0987, -1.1000],
         [-6.0322,  0.9680, -1.2433],
         [-6.6125, -0.2653, -1.5380],
         [-5.8065, -1.3855, -1.6884],
         [-4.4202, -1.2749, -1.5434],
         [-3.8033, -0.0372, -1.2496],
         [-2.4040,  0.1412, -1.0884],
         [-1.3967, -1.6526, -2.2339],
         [-1.4795, -1.0204, -1.1883],
         [-0.7015, -1.3935,  0.0599],
         [-1.0046, -2.8474,  0.4612],
         [-0.1251, -

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [4.385410383200831, 4.250211775547377, 4.18222588261713, 4.2828628933549515, 3.744130660449506, 3.708974517496, 3.5296946621550207, 3.685948728562747, 4.223128755538092, 4.314748634508609]
Hausdorff distances after perturbation:  [2.926394839181346, 3.1862351280403836, 2.9521783506209576, 3.4696062754135264, 2.9210857224124815, 3.352680749782129, 3.0402926171237072, 3.23832394840605, 3.7641520090832334, 5.115442795434933]
Hausdorff distances after perturbation:  [2.603716432550926, 2.5374861781109246, 2.4618387007736406, 2.7367218074367283, 2.657538747004175, 2.8854576385600215, 2.6253208162966, 3.105344844733371, 3.752146442879927, 4.577117570494492]
Hausdorff distances after perturbation:  [1.5951846058620232, 1.6531804210340595, 1.7188772044252998, 1.83011126928253, 2.108033028410996, 2.5597985155312917, 2.718380481225575, 3.0561225483675507, 3.9526482072837945, 3.5285069799678093]
Hausdorff distances after perturbation:  [3.8404267199513744,

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [4.417688888849333, 4.37095009296908, 4.574921777065882, 4.1615427795397135, 4.141366956865052, 4.173447007702788, 3.960420525603231, 4.317550335130843, 4.190693100289805, 3.955836033485313]
Hausdorff distances after perturbation:  [1.5336079675201788, 1.4602909171418044, 1.5096329800390083, 1.4674373512531584, 1.6255743539819452, 1.947410621726635, 2.1787832968123206, 3.8021249018378116, 4.349819160287255, 3.426026436561475]
Hausdorff distances after perturbation:  [1.860884667085912, 1.7385390325317185, 2.1915369005127583, 2.1174285521812894, 1.8481887043849523, 1.6987560835026518, 1.9284664938094944, 2.2053829996463015, 2.5896310121259667, 3.745905618144354]
Hausdorff distances after perturbation:  [1.9761381609892925, 2.155425416233394, 2.2157146611452925, 2.0623102604680716, 1.7859153449794025, 2.3278932287581897, 4.194343847881036, 3.962217692451787, 3.702103790934072, 4.138524606735013]
Hausdorff distances after perturbation:  [2.39641666

Atom "[C@H]" contains stereochemical information that will be discarded.


Original molecule positions:  tensor([[[ 2.6609, -1.8460,  2.2589],
         [ 2.7028, -0.9068,  1.4782],
         [ 3.7691, -0.6358,  0.6595],
         [ 4.9688, -1.4317,  0.7845],
         [ 5.9747, -1.0614, -0.2624],
         [ 7.0905, -0.2733,  0.0326],
         [ 7.9889,  0.0563, -0.9783],
         [ 7.7373, -0.4093, -2.2581],
         [ 6.6677, -1.1656, -2.5835],
         [ 5.8132, -1.4736, -1.5815],
         [ 1.7027,  0.0145,  1.3028],
         [-7.9857, -1.4487, -2.4004],
         [-6.6656, -1.8299, -2.6037],
         [-5.6396, -1.0326, -2.1033],
         [-5.9656,  0.1281, -1.4023],
         [-4.8803,  1.0471, -0.8822],
         [-3.8319,  0.2856, -0.2756],
         [-7.2532,  0.5192, -1.2104],
         [-8.2268, -0.2777, -1.7037],
         [ 0.5734,  0.1815,  3.5030],
         [ 0.4356, -0.1311,  2.0090],
         [-0.6171,  0.7578,  1.3668],
         [-0.4638,  2.1505,  1.3087],
         [-1.4459,  2.9469,  0.7195],
         [-2.5947,  2.3656,  0.1773],
         [-2.7634,  

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [2.197413753259877, 2.1081439792230823, 1.835699338325847, 2.2175478488802427, 1.647736177666518, 2.0711943846745395, 2.686896047054503, 3.2319242025569697, 3.2669379076867022, 3.5726984553847956]
Hausdorff distances after perturbation:  [1.3566872222624267, 1.4205747055294078, 1.3487943394625488, 1.750581692808273, 1.8519101824178303, 1.7936711888076786, 1.615483212472305, 2.0819654645415526, 2.9110919210009145, 3.742418215796967]
Hausdorff distances after perturbation:  [2.4361977482161885, 2.184831465329159, 2.4291260604973224, 2.415317826302943, 2.505836774546374, 2.719879389428872, 2.787482586679681, 2.7016688970619205, 3.5386053383607385, 3.4873997733762203]
Hausdorff distances after perturbation:  [2.506078376340739, 2.317874002489197, 2.1159255371901073, 2.289577684201712, 2.1304909116627737, 2.42668349716772, 2.8429380293180317, 2.155457688764947, 2.1937611108278423, 3.3537414511601744]
Hausdorff distances after perturbation:  [2.127989

Atom "[C@H]" contains stereochemical information that will be discarded.


Original molecule positions:  tensor([[[ 0.5734,  0.1815,  3.5030],
         [ 0.4356, -0.1311,  2.0090],
         [ 1.7027,  0.0145,  1.3028],
         [ 2.7028, -0.9068,  1.4782],
         [ 2.6609, -1.8460,  2.2589],
         [ 3.7691, -0.6358,  0.6595],
         [ 4.9688, -1.4317,  0.7845],
         [ 5.9747, -1.0614, -0.2624],
         [ 7.0905, -0.2733,  0.0326],
         [ 7.9889,  0.0563, -0.9783],
         [ 7.7373, -0.4093, -2.2581],
         [ 6.6677, -1.1656, -2.5835],
         [ 5.8132, -1.4736, -1.5815],
         [-7.9857, -1.4487, -2.4004],
         [-6.6656, -1.8299, -2.6037],
         [-5.6396, -1.0326, -2.1033],
         [-5.9656,  0.1281, -1.4023],
         [-4.8803,  1.0471, -0.8822],
         [-3.8319,  0.2856, -0.2756],
         [-7.2532,  0.5192, -1.2104],
         [-8.2268, -0.2777, -1.7037],
         [-0.6171,  0.7578,  1.3668],
         [-0.4638,  2.1505,  1.3087],
         [-1.4459,  2.9469,  0.7195],
         [-2.5947,  2.3656,  0.1773],
         [-2.7634,  

  0%|          | 0/5 [00:00<?, ?it/s]

Hausdorff distances after perturbation:  [2.6230854348455748, 2.5817416002801927, 2.473157806807414, 2.0721984396065403, 1.9199712559566975, 2.8209814704044036, 3.0195324111190014, 3.0403254471899133, 2.925881785831743, 2.9558347501193376]
Hausdorff distances after perturbation:  [2.6828694563734903, 2.5761116061169083, 2.2658876249521005, 1.8563911897833765, 1.9273671362846327, 1.8414842136027907, 2.4612273845165933, 2.5061944145008463, 2.871123202011642, 3.409439309201299]
Hausdorff distances after perturbation:  [2.5565799426953784, 2.3541954274057577, 2.589203705902524, 2.545717435098509, 2.25765155239344, 1.85921588977315, 2.197877750291333, 2.2935271517860385, 2.5802277150685704, 2.393500721801691]
Hausdorff distances after perturbation:  [2.4945955314447565, 2.3290908248096267, 2.300589055649282, 2.169833004040921, 1.881322892792918, 1.8591477843809778, 2.097238226802046, 1.8033750036163052, 1.696107769374377, 2.343902130514525]
Hausdorff distances after perturbation:  [2.530930