In [1]:
import os
os.environ["http_proxy"] = "http://web-proxy.informatik.uni-bonn.de:3128"
os.environ["https_proxy"] = "http://web-proxy.informatik.uni-bonn.de:3128"

In [2]:
import argparse
import torch

from src.datasets import get_dataloader
from src.lightning import DDPM
from src.molecule_builder import get_bond_order
from src.visualizer import save_xyz_file
from tqdm.auto import tqdm
import sys #@mastro
from src import const #@mastro
import numpy as np #@mastro
from numpy.random import default_rng
from sklearn.metrics import jaccard_score
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import directed_hausdorff
import random
from sklearn.decomposition import PCA
from src.visualizer import load_molecule_xyz, load_xyz_files
import matplotlib.pyplot as plt
import imageio
from src import const
import networkx as nx
import time 
import yaml
from pysmiles import read_smiles
#get running device from const file

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Simulate command-line arguments

# density = sys.argv[sys.argv.index("--P") + 1]
with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)

checkpoint = config['CHECKPOINT']
chains = config['CHAINS']
DATA = config['DATA']
prefix = config['PREFIX']
keep_frames = int(config['KEEP_FRAMES'])
P = config['P']
device = config['DEVICE'] if torch.cuda.is_available() else 'cpu'
SEED = int(config['SEED'])
SAVE_VISUALIZATION = config['SAVE_VISUALIZATION']
M = int(config['M'])
NUM_SAMPLES = int(config['NUM_SAMPLES'])
PARALLEL_STEPS = int(config['PARALLEL_STEPS'])
DIAGONALIZE = config['DIAGONALIZE']

print("seed is: ", SEED)

experiment_name = checkpoint.split('/')[-1].replace('.ckpt', '')
chains_output_dir = os.path.join(chains, experiment_name, prefix, 'chains_coulomb_diagonalized_' + P + '_seed_' + str(SEED))
final_states_output_dir = os.path.join(chains, experiment_name, prefix, 'final_states_coulomb_diagonalized_' + P + '_seed_' + str(SEED))

    
    
os.makedirs(chains_output_dir, exist_ok=True)
os.makedirs(final_states_output_dir, exist_ok=True)

# Loading model form checkpoint (all hparams will be automatically set)
model = DDPM.load_from_checkpoint(checkpoint, map_location=device)

# Possibility to evaluate on different datasets (e.g., on CASF instead of ZINC)
model.val_data_prefix = prefix

print(f"Running device: {device}")
# In case <Anonymous> will run my model or vice versa
if DATA is not None:
    model.data_path = DATA

model = model.eval().to(device)
model.setup(stage='val')
dataloader = get_dataloader(
    model.val_dataset,
    batch_size=1, #@mastro, it was 32
    # batch_size=len(model.val_dataset)
)




seed is:  42


/home/mastropietro/anaconda3/envs/diff_explainer/lib/python3.10/site-packages/lightning_fabric/utilities/cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Lightning automatically

Running device: cuda:1


In [3]:
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
np.random.seed(SEED)
random.seed(SEED)

#### Similarity functions

In [4]:
def compute_molecular_similarity(mol1, mol2, mask1 = None, mask2 = None):
    """
    Compute the similarity between two molecules based on distances and atom type.
    
    Args:
        mol1 (torch.Tensor): The first molecule.
        mol2 (torch.Tensor): The second molecule.
        mask (torch.Tensor, optional): A mask indicating which atoms to consider. If not provided, all atoms will be considered.
        
    Returns:
        float: The similarity between the two molecules.
    """
    # If fragmen_mask is provided, only consider the atoms in the mask
    if mask1 is not None:
        mask1 = mask1.bool()
        mol1 = mol1[mask1,:]

    if mask2 is not None:
        mask2 = mask2.bool()
        mol2 = mol2[mask2,:]

    return 1 - torch.norm(mol1 - mol2)

def compute_molecular_distance(mol1, mol2, mask1 = None, mask2 = None):
    """
    Compute the similarity between two molecules based on distances and atom type.
    
    Args:
        mol1 (torch.Tensor): The first molecule.
        mol2 (torch.Tensor): The second molecule.
        mask (torch.Tensor, optional): A mask indicating which atoms to consider. If not provided, all atoms will be considered.
        
    Returns:
        float: The similarity between the two molecules.
    """
    # If fragmen_mask is provided, only consider the atoms in the mask
    if mask1 is not None:
        mask1 = mask1.bool()
        mol1 = mol1[mask1,:]

    if mask2 is not None:
        mask2 = mask2.bool()
        mol2 = mol2[mask2,:]

    # return torch.norm(mol1 - mol2).item()
    return torch.linalg.norm(mol1 - mol2).item()

def compute_molecular_distance_batch(mol1, mol2, mask1 = None, mask2 = None):
    """
    Compute the similarity between two molecules based on distances and atom type.
    
    Args:
        mol1 (torch.Tensor): The first molecule.
        mol2 (torch.Tensor): The second molecule.
        mask (torch.Tensor, optional): A mask indicating which atoms to consider. If not provided, all atoms will be considered.
        
    Returns:
        torch.Tensor: The similarity between the two molecules for each element in the batch.
    """
    # If fragment_mask is provided, only consider the atoms in the mask
    if mask1 is not None:
        mask1 = mask1.bool()
        batch_size = mol1.shape[0]
        masked_mol1 = []
        for i in range(batch_size):
            masked_mol1.append(mol1[i, mask1[i], :])

        if batch_size == 1:
            mol1 = masked_mol1[0].unsqueeze(0)
        else:    
            mol1 = torch.stack(masked_mol1)
           
    if mask2 is not None:
        mask2 = mask2.bool()
        batch_size = mol2.shape[0]
        masked_mol2 = []
        for i in range(batch_size):
            masked_mol2.append(mol2[i, mask2[i], :])
        
        if batch_size == 1:
            mol2 = masked_mol2[0].unsqueeze(0)
        else:    
            mol2 = torch.stack(masked_mol2)

    return torch.linalg.norm(mol1 - mol2, dim=(1,2))

def compute_cosine_similarity(mol1, mol2, mask1 = None, mask2 = None):
    """
    Compute the similarity between two molecules based on distances and atom type.
    
    Args:
        mol1 (torch.Tensor): The first molecule.
        mol2 (torch.Tensor): The second molecule.
        mask (torch.Tensor, optional): A mask indicating which atoms to consider. If not provided, all atoms will be considered.
        
    Returns:
        float: The similarity between the two molecules.
    """
    # If fragmen_mask is provided, only consider the atoms in the mask
    if mask1 is not None:
        mask1 = mask1.bool()
        mol1 = mol1[mask1,:]

    if mask2 is not None:
        mask2 = mask2.bool()
        mol2 = mol2[mask2,:]

    return cosine_similarity(mol1.flatten().reshape(1, -1), mol2.flatten().reshape(1, -1)).item()


def compute_cosine_similarity_batch(mol1, mol2, mask1 = None, mask2 = None):
    """
    Compute the similarity between two molecules based on distances and atom type.
    
    Args:
        mol1 (torch.Tensor): The first molecule.
        mol2 (torch.Tensor): The second molecule.
        mask (torch.Tensor, optional): A mask indicating which atoms to consider. If not provided, all atoms will be considered.
        
    Returns:
        float: The similarity between the two molecules.
    """
    # If fragmen_mask is provided, only consider the atoms in the mask
    if mask1 is not None:
        mask1 = mask1.bool()
        batch_size = mol1.shape[0]
        masked_mol1 = []
        for i in range(batch_size):
            masked_mol1.append(mol1[i, mask1[i], :])
        
        if batch_size == 1:
            mol1 = masked_mol1[0].unsqueeze(0)
        else:    
            mol1 = torch.stack(masked_mol1)
        

    if mask2 is not None:
        mask2 = mask2.bool()
        mask2 = mask2.bool()
        batch_size = mol2.shape[0]
        masked_mol2 = []
        for i in range(batch_size):
            masked_mol2.append(mol2[i, mask2[i], :])
        
        if batch_size == 1:
            mol2 = masked_mol2[0].unsqueeze(0)
        else:    
            mol2 = torch.stack(masked_mol2)

    cos_sims = []
    for i in range(mol1.shape[0]):
        cos_sims.append(cosine_similarity(mol1[i].flatten().reshape(1, -1), mol2[i].flatten().reshape(1, -1)).item())

    return cos_sims

def compute_molecular_similarity_positions(mol1, mol2, mask1 = None, mask2 = None):
    """
    Compute the similarity between two molecules based on positions.
    
    Args:
        mol1 (torch.Tensor): The first molecule.
        mol2 (torch.Tensor): The second molecule.
        mask (torch.Tensor, optional): A mask indicating which atoms to consider. If not provided, all atoms will be considered.
        
    Returns:
        float: The similarity between the two molecules.
    """
    # If fragmen_mask is provided, only consider the atoms in the mask
    positions1 = mol1[:, :3].squeeze()
    positions2 = mol2[:, :3].squeeze()

    if mask1 is not None:
        mask1 = mask1.bool()
        positions1 = positions1[mask1,:]

    if mask2 is not None:
        mask2 = mask2.bool()
        positions2 = positions2[mask2,:]


    return 1 - torch.norm(positions1 - positions2) #choose if distance or similarity, need to check what it the better choice

def compute_one_hot_similarity(mol1, mol2, mask1 = None, mask2 = None):
    """
    Computes the similarity between two one-hot encoded molecules. The one-hot encoding indicates the atom type
    
    Args:
        mol1 (torch.Tensor): The first one-hot encoded molecule.
        mol2 (torch.Tensor): The second one-hot encoded molecule.
        mask (torch.Tensor, optional): A mask to apply on the atoms. Defaults to None.
    
    Returns:
        torch.Tensor: The similarity between the two molecules.
    """
    
    # Apply mask if provided
    if mask1 is not None:
        mask1 = mask1.bool()
        mol1 = mol1[mask1,:]

    if mask2 is not None:
        mask2 = mask2.bool()
        mol2 = mol2[mask2,:]
    
    # Compute similarity by comparing the one-hot encoded features
    similarity = torch.sum(mol1[:,3:-1] == mol2[:,3:-1]) / mol1[:, 3:-1].numel()
    
    return similarity

def compute_hausdorff_distance_batch(mol1, mol2, mask1 = None, mask2 = None):
    """
    Compute the similarity between two molecules based on distances and atom type.
    
    Args:
        mol1 (torch.Tensor): The first molecule.
        mol2 (torch.Tensor): The second molecule.
        mask1 (torch.Tensor, optional): A mask indicating which atoms to consider for mo1. If not provided, all atoms will be considered.
        mask2 (torch.Tensor, optional): A mask indicating which atoms to consider for mol2. If not provided, all atoms will be considered.
        
    Returns:
        torch.Tensor: The similarity between the two molecules for each element in the batch.
    """
    # If fragment_mask is provided, only consider the atoms in the mask

    #take only the positions
    mol1 = mol1[:, :, :3]
    mol2 = mol2[:, :, :3]
    
    
    if mask1 is not None:
        mask1 = mask1.bool()
        batch_size = mol1.shape[0]
        masked_mol1 = []
        for i in range(batch_size):
            masked_mol1.append(mol1[i, mask1[i], :])
        
        if batch_size == 1:
            mol1 = masked_mol1[0].unsqueeze(0)
        else:    
            mol1 = torch.stack(masked_mol1)
        

    if mask2 is not None:
        mask2 = mask2.bool()
        mask2 = mask2.bool()
        batch_size = mol2.shape[0]
        masked_mol2 = []
        for i in range(batch_size):
            masked_mol2.append(mol2[i, mask2[i], :])
        
        if batch_size == 1:
            mol2 = masked_mol2[0].unsqueeze(0)
        else:    
            mol2 = torch.stack(masked_mol2)

    hausdorff_distances = []
    for i in range(mol1.shape[0]):
        hausdorff_distances.append(max(directed_hausdorff(mol1[i], mol2[i])[0], directed_hausdorff(mol2[i], mol1[i])[0]))

    return hausdorff_distances


def create_edge_index(mol, weighted=False):
    """
    Create edge index for a molecule.
    """
    adj = nx.to_scipy_sparse_array(mol).todense()
    row = torch.from_numpy(adj.row.astype(np.int64)).to(torch.long)
    col = torch.from_numpy(adj.col.astype(np.int64)).to(torch.long)
    edge_index = torch.stack([row, col], dim=0)

    if weighted:
        weights = torch.from_numpy(adj.data.astype(np.float32))
        edge_weight = torch.FloatTensor(weights)
        return edge_index, edge_weight

    return edge_index


def compute_coulomb_matrix(mol, mask=None, diagonalize=False):
    """
    Compute the Coulomb matrix for a molecule.
    
    Args:
        mol (torch.Tensor): The molecule tensor with shape (N, 4), where N is the number of atoms.
                            The last dimension should contain [x, y, z, atomic_number].
        mask (torch.Tensor, optional): A mask indicating which atoms to consider. If not provided, all atoms will be considered.
        diagonalize (bool, optional): Whether to return the diagonalized Coulomb matrix. Defaults to False.
        
    Returns:
        torch.Tensor: The Coulomb matrix of the molecule.
    """
    if mask is not None:
        mask = mask.bool()
        mol = mol[mask, :]

    positions = mol[:, :3]
    one_hot = mol[:, 3:]
    atomic_numbers = []
    
    for i, vec in enumerate(one_hot):
        if torch.sum(vec) == 1:
            atom_index = torch.argmax(vec).item()
            atomic_number = const.CHARGES[const.IDX2ATOM[atom_index]]
            atomic_numbers.append(atomic_number)
        else:
            atomic_numbers.append(0)  
    
    num_atoms = positions.shape[0]
    coulomb_matrix = torch.zeros((num_atoms, num_atoms))

    for i in range(num_atoms):
        for j in range(num_atoms):
            if i == j:
                coulomb_matrix[i, j] = 0.5 * atomic_numbers[i] ** 2.4
            else:
                distance = torch.norm(positions[i] - positions[j])
                if distance == 0: #avoid division by zero
                    coulomb_matrix[i, j] = 0.0
                else:
                    coulomb_matrix[i, j] = atomic_numbers[i] * atomic_numbers[j] / distance

    if diagonalize:
        eigenvalues, eigenvectors = torch.linalg.eigh(coulomb_matrix)
        coulomb_matrix = torch.diag(eigenvalues)

    return coulomb_matrix

def compute_coulomb_matrices_batch(molecules, masks=None, diagonalize=False):
    """
    Compute the Coulomb matrices for a batch of molecules.
    
    Args:
        molecules (torch.Tensor): The batch of molecule tensors with shape (B, N, 4), where B is the batch size,
                                    N is the number of atoms, and the last dimension should contain [x, y, z, atomic_number].
        masks (torch.Tensor, optional): A batch of masks indicating which atoms to consider for each molecule. 
                                        If not provided, all atoms will be considered.
        
    Returns:
        torch.Tensor: The Coulomb matrices for the batch of molecules with shape (B, N, N).
    """
    batch_size = molecules.shape[0]
    # num_atoms = molecules.shape[1] #this is ok when the mask is not provided
    num_atoms = int(torch.sum(masks, dim=1).max().item()) if masks is not None else molecules.shape[1]
    coulomb_matrices = torch.zeros((batch_size, num_atoms, num_atoms), device=molecules.device)

    for b in range(batch_size):
        mol = molecules[b]
        mask = masks[b] if masks is not None else None
        coulomb_matrices[b] = compute_coulomb_matrix(mol, mask, diagonalize=diagonalize)

    return coulomb_matrices

def compute_frobenius_norm_batch(matrices):
    """
    Compute the Frobenius norm for a batch of matrices.
    
    Args:
        matrices (torch.Tensor): A batch of matrices with shape (B, N, N), where B is the batch size,
                                    and N is the number of rows/columns in each matrix.
        
    Returns:
        torch.Tensor: A tensor containing the Frobenius norm for each matrix in the batch.
    """
    # return torch.norm(matrices, dim=(1, 2), p='fro') #deprecated
    return torch.linalg.norm(matrices, ord='fro', dim=(1, 2))

def arrestomomentum():
    raise KeyboardInterrupt("Debug interrupt.")

## Explainability

### Utility function for visualization purposes

In [5]:
def draw_sphere_xai(ax, x, y, z, size, color, alpha):
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)

    xs = size * np.outer(np.cos(u), np.sin(v))
    ys = size * np.outer(np.sin(u), np.sin(v)) #* 0.8
    zs = size * np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(x + xs, y + ys, z + zs, rstride=2, cstride=2, color=color, alpha=alpha)

def plot_molecule_xai(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None, phi_values=None, invert_colormap = False):
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    # Hydrogen, Carbon, Nitrogen, Oxygen, Flourine

    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    colors_dic = np.array(const.COLORS)
    radius_dic = np.array(const.RADII)
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if fragment_mask is None:
        fragment_mask = torch.ones(len(x))

    for i in range(len(x)):
        for j in range(i + 1, len(x)):
            p1 = np.array([x[i], y[i], z[i]])
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            draw_edge_int = get_bond_order(atom1, atom2, dist)
            line_width = (3 - 2) * 2 * 2
            draw_edge = draw_edge_int > 0
            if draw_edge:
                if draw_edge_int == 4:
                    linewidth_factor = 1.5
                else:
                    linewidth_factor = 1
                linewidth_factor *= 0.5
                ax.plot(
                    [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=line_width * linewidth_factor * 2,
                    c=hex_bg_color,
                    alpha=alpha
                )

    

    if spheres_3d:
        
        for i, j, k, s, c, f, phi in zip(x, y, z, radii, colors, fragment_mask, phi_values):
            if f == 1:
                alpha = 1.0
                if phi > 0:
                    c = 'red'

            draw_sphere_xai(ax, i.item(), j.item(), k.item(), 0.5 * s, c, alpha)

    else:
        phi_values_array = np.array(list(phi_values.values()))

        #draw fragments
        fragment_mask_on_cpu = fragment_mask.cpu().numpy()
        colors_fragment = colors[fragment_mask_on_cpu == 1]
        x_fragment = x[fragment_mask_on_cpu == 1]
        y_fragment = y[fragment_mask_on_cpu == 1]
        z_fragment = z[fragment_mask_on_cpu == 1]
        areas_fragment = areas[fragment_mask_on_cpu == 1]
        
        # Calculate the gradient colors based on phi values
        # cmap = plt.cm.get_cmap('coolwarm_r') #reversed heatmap for distance-based importance
        cmap = plt.cm.get_cmap('coolwarm') #heatmap for distance-based importance trying non reversed -> high shapley value mean more imporant, that drive the generation.
        #@mastro added invert_colormap to invert the colormap if average/expected value in higher than original prediction
        if invert_colormap:
            cmap = plt.cm.get_cmap('coolwarm_r')

        norm = plt.Normalize(vmin=min(phi_values_array), vmax=max(phi_values_array))
        colors_fragment_shadow = cmap(norm(phi_values_array))
        
        # ax.scatter(x_fragment, y_fragment, z_fragment, s=areas_fragment, alpha=0.9 * alpha, c=colors_fragment)

        ax.scatter(x_fragment, y_fragment, z_fragment, s=areas_fragment, alpha=0.9 * alpha, c=colors_fragment, edgecolors=colors_fragment_shadow, linewidths=5, rasterized=False)

        #draw non-fragment atoms
        colors = colors[fragment_mask_on_cpu == 0]
        x = x[fragment_mask_on_cpu == 0]
        y = y[fragment_mask_on_cpu == 0]
        z = z[fragment_mask_on_cpu == 0]
        areas = areas[fragment_mask_on_cpu == 0]
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors, rasterized=False)


def plot_data3d_xai(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None, phi_values=None, invert_colormap = False):
    black = (0, 0, 0)
    white = (1, 1, 1)
    hex_bg_color = '#FFFFFF' if bg == 'black' else '#000000' #'#666666'

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(projection='3d')
    ax.set_aspect('auto')
    ax.view_init(elev=camera_elev, azim=camera_azim)
    if bg == 'black':
        ax.set_facecolor(black)
    else:
        ax.set_facecolor(white)
    ax.xaxis.pane.set_alpha(0)
    ax.yaxis.pane.set_alpha(0)
    ax.zaxis.pane.set_alpha(0)
    ax._axis3don = False

    if bg == 'black':
        ax.w_xaxis.line.set_color("black")
    else:
        ax.w_xaxis.line.set_color("white")

    plot_molecule_xai(
        ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask, phi_values=phi_values, invert_colormap=invert_colormap
    )

    max_value = positions.abs().max().item()
    axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
    ax.set_xlim(-axis_lim, axis_lim)
    ax.set_ylim(-axis_lim, axis_lim)
    ax.set_zlim(-axis_lim, axis_lim)
    dpi = 300 if spheres_3d else 300 #it was 120 and 50

    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)
        # plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi, transparent=True)

        if spheres_3d:
            img = imageio.imread(save_path)
            img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
            imageio.imsave(save_path, img_brighter)
    else:
        plt.show()
    plt.close()

def visualize_chain_xai(
        path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None, phi_values=None, invert_colormap = False
):
    files = load_xyz_files(path)
    save_paths = []

    # Fit PCA to the final molecule – to obtain the best orientation for visualization
    positions, one_hot, charges = load_molecule_xyz(files[-1], is_geom=is_geom)
    pca = PCA(n_components=3)
    pca.fit(positions)

    for i in range(len(files)):
        file = files[i]

        positions, one_hot, charges = load_molecule_xyz(file, is_geom=is_geom)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Transform positions of each frame according to the best orientation of the last frame
        positions = pca.transform(positions)
        positions = torch.tensor(positions)

        fn = file[:-4] + '.png'
        plot_data3d_xai(
            positions, atom_type,
            save_path=fn,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
            phi_values=phi_values,
            invert_colormap=invert_colormap
        )
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    dirname = os.path.dirname(save_paths[0])
    gif_path = dirname + '/output.gif'
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})

### Explainabiliy phase

##### Multiple sampling steps at a time

In [6]:
#@mastro
torch.set_printoptions(threshold=float('inf'))

num_samples = NUM_SAMPLES
sampled = 0
#end @mastro
start = 0

chain_with_full_fragments = None

# P = None #probability of atom to exist in random graph (also edge in the future)

# Create the folder if it does not exist
folder_save_path = "results/explanations_coulomb_diagonalized_" + P + "_seed_" + str(SEED) + "_full_molecule_original_fragments"


    
if not os.path.exists(folder_save_path):
    os.makedirs(folder_save_path)

data_list = []
for data in dataloader:

    if sampled < num_samples:
        data_list.append(data)
        sampled += 1

#determine max numebr of atoms of the molecules in the dataset. This is used to determine the size of the random noise, which we want to be equal for all molecules -> atoms not present in the molecule will be discarded using masks 
max_num_atoms = max(data["positions"].shape[1] for data in data_list)


#define initial random noise for positions and features #shape = [1, max_num_atoms, 3] for positions and [1, max_num_atoms, 8] for features. 1 since batch size is 1 for our explaination task
pos_size = (data_list[0]["positions"].shape[0], max_num_atoms, data_list[0]["positions"].shape[2])
feature_size = (data_list[0]["one_hot"].shape[0], max_num_atoms, data_list[0]["one_hot"].shape[2])

INTIAL_DISTIBUTION_PATH = "results/explanations_" + P + "_seed_" + str(SEED)
noisy_features = None
noisy_positions = None
#check if the initial distribution of the noisy features and positions already exists, if not create it
if os.path.exists(INTIAL_DISTIBUTION_PATH + "/noisy_features_seed_" + str(SEED) + ".pt"):
    # load initial distrubution of noisy features and positions
    print("Loading initial distribution of noisy features and positions.")
    noisy_features = torch.load(INTIAL_DISTIBUTION_PATH + "/noisy_features_seed_" + str(SEED) + ".pt", map_location=device, weights_only=True)
    noisy_positions = torch.load(INTIAL_DISTIBUTION_PATH + "/noisy_positions_seed_" + str(SEED) + ".pt", map_location=device, weights_only=True)

else:
    print("Creating initial distribution of noisy features and positions.")
    noisy_positions = torch.randn(pos_size, device=device)
    noisy_features = torch.randn(feature_size, device=device)


    #save the noisy positions and features on file .txt
    print("Saving noisy features and positions to .txt and .pt files.")
    noisy_positions_file = os.path.join(folder_save_path, "noisy_positions_seed_" + str(SEED) + ".txt")
    noisy_features_file = os.path.join(folder_save_path, "noisy_features_seed_" + str(SEED) + ".txt")

    with open(noisy_positions_file, "w") as f:
        f.write(str(noisy_positions))

    with open(noisy_features_file, "w") as f:
        f.write(str(noisy_features))

    torch.save(noisy_positions, os.path.join(folder_save_path, "noisy_positions_seed_" + str(SEED) + ".pt"))
    torch.save(noisy_features, os.path.join(folder_save_path, "noisy_features_seed_" + str(SEED) + ".pt"))

for data_index, data in enumerate(tqdm(data_list)): #7:

        # start_time = time.time()
        
        smile = data["name"][0]
        
        mol = read_smiles(smile)
        num_nodes = mol.number_of_nodes()
        
        num_edges = mol.number_of_edges()
        num_edges_directed = num_edges*2
        
        
        graph_density = num_edges_directed/(num_nodes*(num_nodes-1))
        max_number_of_nodes = num_edges + 1

        node_density = num_nodes/max_number_of_nodes

        node_edge_ratio = num_nodes/num_edges
        
        edge_node_ratio = num_edges/num_nodes
        
        if P == "graph_density":
            P = graph_density #probability of atom to exist in random graph
        elif P == "node_density":
            P = node_density
        elif P == "node_edge_ratio" or P == "edge_node_ratio":
            if node_edge_ratio < edge_node_ratio:
                P = node_edge_ratio
                print("Using node-edge ratio", node_edge_ratio)
            else:
                P = edge_node_ratio
                print("Using edge-node ratio", edge_node_ratio)            
        else:
            try:
                P = float(P)
            except ValueError:
                raise ValueError("P must be either 'graph_density', 'node_density', 'node_edge_ratio', 'edge_node_ratio' or a float value.")
        

        print("Using P:", P, P)

        chain_with_full_fragments = None
       
        rng = default_rng(seed = SEED)
        rng_torch = torch.Generator(device="cpu")
        rng_torch.manual_seed(SEED)
        # generate chain with original and full fragments
        
        #filter the noisy positions and features to have the same size as the data, removing the atoms not actually present in the molecule
        #we use the same max sized noise for all molecules to guaranteethat the same moleclues are inzialized with the same noise for the linker atoms in common -> noise for the fragme atoms will be discarded
        noisy_positions_present_atoms = noisy_positions.clone()
        noisy_features_present_atoms = noisy_features.clone()

        noisy_positions_present_atoms = noisy_positions_present_atoms[:, :data["positions"].shape[1], :]
        noisy_features_present_atoms = noisy_features_present_atoms[:, :data["one_hot"].shape[1], :]

        chain_batch, node_mask = model.sample_chain(data, keep_frames=keep_frames, noisy_positions=noisy_positions_present_atoms, noisy_features=noisy_features_present_atoms)
        
        #get the generated molecule and store it in a variable
        chain_with_full_fragments = chain_batch[0, :, :, :] #need to get only the final frame, is 0 ok in the first dimension?
        
        #compute the Coulob matrix of the generated linker @mastro edited to try with full molecule to capute all the interactions
        # coulomb_matrix = compute_coulomb_matrix(chain_with_full_fragments.squeeze(), mask = data["linker_mask"][0].squeeze())
        
        #compute coulomb matrix for the whole molecule
        coulomb_matrix = compute_coulomb_matrix(chain_with_full_fragments.squeeze(), diagonalize=DIAGONALIZE)

        

        # Get the abs eigen values and sort them in decreasing order https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.108.058301
        eigenvalues, eigenvectors = torch.linalg.eigh(coulomb_matrix)
        eigenvalues = torch.abs(eigenvalues)
        sorted_eigenvalues_original_molecule, _ = torch.sort(eigenvalues, descending=True)
        
        print("Eigenvalues: ", sorted_eigenvalues_original_molecule)
        distance_with_itself = torch.linalg.norm(sorted_eigenvalues_original_molecule - sorted_eigenvalues_original_molecule).item()

        print("Distance with itself: ", distance_with_itself)
        
        sorted_eigenvalues_original_molecule_batch = sorted_eigenvalues_original_molecule.repeat(PARALLEL_STEPS, 1)

        
        original_linker_mask_batch = data["linker_mask"][0].squeeze().repeat(PARALLEL_STEPS, 1) #check why it works
    
        
        
        num_fragment_atoms = torch.sum(data["fragment_mask"] == 1)

        phi_atoms = {}
        
        num_atoms = data["positions"].shape[1]
        num_linker_atoms = torch.sum(data["linker_mask"] == 1)
        
        # distances_random_samples = []
        # cosine_similarities_random_samples = []
        hausdorff_distances_random_samples = []
        euclidean_distance_random_samples = []
        # end_time = time.time()
        


        for j in tqdm(range(num_fragment_atoms)): 
            
            # marginal_contrib_distance = 0
            # marginal_contrib_cosine_similarity = 0
            # marginal_contrib_hausdorff = 0
            marginal_contrib_euclidean_distance = 0

            for step in range(int(M/PARALLEL_STEPS)):

                # start_time = time.time()

                fragment_indices = torch.where(data["fragment_mask"] == 1)[1]
                num_fragment_atoms = len(fragment_indices)
                fragment_indices = fragment_indices.repeat(PARALLEL_STEPS).to(device)

                data_j_plus = data.copy()
                data_j_minus = data.copy()
                data_random = data.copy()

                N_z_mask = torch.tensor(np.array([rng.binomial(1, P, size = num_fragment_atoms) for _ in range(PARALLEL_STEPS)]), dtype=torch.int32)
                # Ensure at least one element is 1, otherwise randomly select one since at least one fragment atom must be present
                
                
                for i in range(len(N_z_mask)):

                    #set the current explained atom to 0 in N_z_mask
                    N_z_mask[i][j] = 0 #so it is always one when taken from the oriignal sample and 0 when taken from the random sample. Check if it is more efficient to directly set it or check if it is already 0

                    if not N_z_mask[i].any():
                        
                        
                        random_index = j #j is the current explained atom, it should always be set to 0
                        while random_index == j:
                            random_index = rng.integers(0, num_fragment_atoms)
                        N_z_mask[i][random_index] = 1
                        
                       
                    

                N_z_mask=N_z_mask.flatten().to(device)
                
                
                

                N_mask = torch.ones(PARALLEL_STEPS * num_fragment_atoms, dtype=torch.int32, device=device)

                

                pi = torch.cat([torch.randperm(num_fragment_atoms, generator=rng_torch) for _ in range(PARALLEL_STEPS)], dim=0)

                N_j_plus_index = torch.ones(PARALLEL_STEPS*num_fragment_atoms, dtype=torch.int, device=device)
                N_j_minus_index = torch.ones(PARALLEL_STEPS*num_fragment_atoms, dtype=torch.int, device=device)

                selected_node_index = np.where(pi == j)
                selected_node_index = torch.tensor(np.array(selected_node_index), device=device).squeeze()
                selected_node_index = selected_node_index.repeat_interleave(num_fragment_atoms) #@mastro TO BE CHECKED IF THIS IS CORRECT
                
                k_values = torch.arange(num_fragment_atoms*PARALLEL_STEPS, device=device)

                add_to_pi = torch.arange(start=0, end=PARALLEL_STEPS*num_fragment_atoms, step=num_fragment_atoms).repeat_interleave(num_fragment_atoms) #check if it is correct ot consider num_fragment_atoms and not num_atoms

                pi_add = pi + add_to_pi
                pi_add = pi_add.to(device=device)
                #this must be cafeully checked. this should be adapted for nodes
                add_to_node_index = torch.arange(start=0, end=PARALLEL_STEPS*num_atoms, step=num_atoms) #@mastro change step from num_fragment_atoms to num_atoms
                
                add_to_node_index = add_to_node_index.repeat_interleave(num_fragment_atoms).to(device) #changed from num_atoms to num_fragment_atoms

                
                N_j_plus_index[pi_add] = torch.where(k_values <= selected_node_index, N_mask[pi_add], N_z_mask[pi_add])
                N_j_minus_index[pi_add] = torch.where(k_values < selected_node_index, N_mask[pi_add], N_z_mask[pi_add]) 

                #fragements to keep in molecule j plus
                fragment_indices = fragment_indices + add_to_node_index
                
                
                N_j_plus = fragment_indices[(N_j_plus_index==1)] #fragment to keep in molecule j plus
                #fragement indices to keep in molecule j minus
               
                N_j_minus = fragment_indices[(N_j_minus_index==1)] #it is ok. it contains fragmens indices to keep in molecule j minus (indices that index the atom nodes)

                #fragement indices to keep in random molecule
                N_random_sample = fragment_indices[(N_z_mask==1)] 
                
               
                atom_mask_j_plus = torch.zeros(num_atoms*PARALLEL_STEPS, dtype=torch.bool)
                atom_mask_j_minus = torch.zeros(num_atoms*PARALLEL_STEPS, dtype=torch.bool)
                atom_mask_random_molecule = torch.zeros(num_atoms*PARALLEL_STEPS, dtype=torch.bool)

                atom_mask_j_plus[N_j_plus] = True
                
                atom_mask_j_minus[N_j_minus] = True

                #set to true also linker atoms
                parallelized_linker_mask = data["linker_mask"][0].squeeze().to(torch.int).repeat(PARALLEL_STEPS)
                atom_mask_j_plus[(parallelized_linker_mask == 1)] = True 

                #set to true also linker atoms
                atom_mask_j_minus[(parallelized_linker_mask == 1)] = True 

                atom_mask_random_molecule[N_random_sample] = True
                #set to true also linker atoms
                atom_mask_random_molecule[(parallelized_linker_mask == 1)] = True
                
               
                atom_mask_j_plus = atom_mask_j_plus.view(PARALLEL_STEPS, num_atoms)
                atom_mask_j_minus = atom_mask_j_minus.view(PARALLEL_STEPS, num_atoms)
                atom_mask_random_molecule = atom_mask_random_molecule.view(PARALLEL_STEPS, num_atoms)
                
                

                data_j_plus_dict = {}
                data_j_minus_dict = {}
                data_random_dict = {}

                noisy_features_j_plus_dict = {}
                noisy_positions_j_plus_dict = {}
                noisy_features_j_minus_dict = {}
                noisy_positions_j_minus_dict = {}
                noisy_features_random_dict = {}
                noisy_positions_random_dict = {}
                
                # start_time = time.time()
                for i in range(PARALLEL_STEPS):

                    # Remove fragment atoms that are not present for j plus
                    noisy_features_present_atoms_j_plus = noisy_features_present_atoms.clone()
                    noisy_features_j_plus_dict[i] = noisy_features_present_atoms_j_plus[:, atom_mask_j_plus[i], :]
                    
                    noisy_positions_present_atoms_j_plus = noisy_positions_present_atoms.clone()
                    noisy_positions_j_plus_dict[i] = noisy_positions_present_atoms_j_plus[:, atom_mask_j_plus[i], :]

                    # Remove fragment atoms that are not present for j minus
                    noisy_features_present_atoms_j_minus = noisy_features_present_atoms.clone()
                    noisy_features_j_minus_dict[i] = noisy_features_present_atoms_j_minus[:, atom_mask_j_minus[i], :]

                    noisy_positions_present_atoms_j_minus = noisy_positions_present_atoms.clone()
                    noisy_positions_j_minus_dict[i] = noisy_positions_present_atoms_j_minus[:, atom_mask_j_minus[i], :]

                    # Remove fragment atoms that are not present for random molecule
                    noisy_features_present_atoms_random = noisy_features_present_atoms.clone()
                    noisy_features_random_dict[i] = noisy_features_present_atoms_random[:, atom_mask_random_molecule[i], :]

                    noisy_positions_present_atoms_random = noisy_positions_present_atoms.clone()
                    noisy_positions_random_dict[i] = noisy_positions_present_atoms_random[:, atom_mask_random_molecule[i], :]


                    data_j_plus_dict[i] = data.copy()
                    data_j_minus_dict[i] = data.copy()
                    data_random_dict[i] = data.copy()

                    #data j plus
                    data_j_plus_dict[i]["positions"] = data_j_plus_dict[i]["positions"][:, atom_mask_j_plus[i]]
                    data_j_plus_dict[i]["num_atoms"] = data_j_plus_dict[i]["positions"].shape[1]
                    # remove one_hot of atoms in random_indices
                    data_j_plus_dict[i]["one_hot"] = data_j_plus_dict[i]["one_hot"][:, atom_mask_j_plus[i]]
                    # remove atom_mask of atoms in random_indices
                    data_j_plus_dict[i]["atom_mask"] = data_j_plus_dict[i]["atom_mask"][:, atom_mask_j_plus[i]]
                    # remove fragment_mask of atoms in random_indices
                    data_j_plus_dict[i]["fragment_mask"] = data_j_plus_dict[i]["fragment_mask"][:, atom_mask_j_plus[i]]
                    # remove linker_mask of atoms in random_indices
                    data_j_plus_dict[i]["linker_mask"] = data_j_plus_dict[i]["linker_mask"][:, atom_mask_j_plus[i]]
                    data_j_plus_dict[i]["charges"] = data_j_plus_dict[i]["charges"][:, atom_mask_j_plus[i]]
                    data_j_plus_dict[i]["anchors"] = data_j_plus_dict[i]["anchors"][:, atom_mask_j_plus[i]]
                    edge_mask_to_keep = (atom_mask_j_plus[i].unsqueeze(1) * atom_mask_j_plus[i]).flatten()
                    data_j_plus_dict[i]["edge_mask"] = data_j_plus_dict[i]["edge_mask"][edge_mask_to_keep]

                    #data j minus
                    data_j_minus_dict[i]["positions"] = data_j_minus_dict[i]["positions"][:, atom_mask_j_minus[i]]
                    data_j_minus_dict[i]["num_atoms"] = data_j_minus_dict[i]["positions"].shape[1]
                    # remove one_hot of atoms in random_indices
                    data_j_minus_dict[i]["one_hot"] = data_j_minus_dict[i]["one_hot"][:, atom_mask_j_minus[i]]
                    # remove atom_mask of atoms in random_indices
                    data_j_minus_dict[i]["atom_mask"] = data_j_minus_dict[i]["atom_mask"][:, atom_mask_j_minus[i]]
                    # remove fragment_mask of atoms in random_indices
                    data_j_minus_dict[i]["fragment_mask"] = data_j_minus_dict[i]["fragment_mask"][:, atom_mask_j_minus[i]]
                    # remove linker_mask of atoms in random_indices
                    data_j_minus_dict[i]["linker_mask"] = data_j_minus_dict[i]["linker_mask"][:, atom_mask_j_minus[i]]
                    data_j_minus_dict[i]["charges"] = data_j_minus_dict[i]["charges"][:, atom_mask_j_minus[i]]
                    data_j_minus_dict[i]["anchors"] = data_j_minus_dict[i]["anchors"][:, atom_mask_j_minus[i]]
                    # remove edge_mask of atoms in random_indices
                    edge_mask_to_keep = (atom_mask_j_minus[i].unsqueeze(1) * atom_mask_j_minus[i]).flatten() 
                    data_j_minus_dict[i]["edge_mask"] = data_j_minus_dict[i]["edge_mask"][edge_mask_to_keep]

                    #data random
                    data_random_dict[i]["positions"] = data_random_dict[i]["positions"][:, atom_mask_random_molecule[i]]
                    data_random_dict[i]["num_atoms"] = data_random_dict[i]["positions"].shape[1]
                    # remove one_hot of atoms in random_indices
                    data_random_dict[i]["one_hot"] = data_random_dict[i]["one_hot"][:, atom_mask_random_molecule[i]]
                    # remove atom_mask of atoms in random_indices
                    data_random_dict[i]["atom_mask"] = data_random_dict[i]["atom_mask"][:, atom_mask_random_molecule[i]]
                    # remove fragment_mask of atoms in random_indices
                    data_random_dict[i]["fragment_mask"] = data_random_dict[i]["fragment_mask"][:, atom_mask_random_molecule[i]]
                    # remove linker_mask of atoms in random_indices
                    data_random_dict[i]["linker_mask"] = data_random_dict[i]["linker_mask"][:, atom_mask_random_molecule[i]]
                    data_random_dict[i]["charges"] = data_random_dict[i]["charges"][:, atom_mask_random_molecule[i]]
                    data_random_dict[i]["anchors"] = data_random_dict[i]["anchors"][:, atom_mask_random_molecule[i]]
                    # remove edge_mask of atoms in random_indices
                    # remove edge_mask of atoms in random_indices
                    edge_mask_to_keep = (atom_mask_random_molecule[i].unsqueeze(1) * atom_mask_random_molecule[i]).flatten() 

                    data_random_dict[i]["edge_mask"] = data_random_dict[i]["edge_mask"][edge_mask_to_keep]
                
                

                PADDING = True

                # start_time = time.time()
                if PADDING:

                    max_atoms_j_plus = max(data_j_plus_dict[i]["num_atoms"] for i in range(PARALLEL_STEPS))

                    max_edges_j_plus = max(data_j_plus_dict[i]["edge_mask"].shape[0] for i in range(PARALLEL_STEPS))
                    
                    
                    max_atoms_j_minus = max(data_j_minus_dict[i]["num_atoms"] for i in range(PARALLEL_STEPS))

                    max_edges_j_minus = max(data_j_minus_dict[i]["edge_mask"].shape[0] for i in range(PARALLEL_STEPS))

                    max_atoms_random = max(data_random_dict[i]["num_atoms"] for i in range(PARALLEL_STEPS))

                    max_edges_random = max(data_random_dict[i]["edge_mask"].shape[0] for i in range(PARALLEL_STEPS))
                    
                    for i in range(PARALLEL_STEPS):
                        #for j plus positions
                        num_atoms_to_stack = max_atoms_j_plus - data_j_plus_dict[i]["positions"].shape[1]
                        padding = torch.zeros(data_j_plus_dict[i]["positions"].shape[0], num_atoms_to_stack, data_j_plus_dict[i]["positions"].shape[2]).to(device)
                        stacked_positions = torch.cat((data_j_plus_dict[i]["positions"], padding), dim=1)
                        data_j_plus_dict[i]["positions"] = stacked_positions
                        #for j plus one_hot
                        padding = torch.zeros(data_j_plus_dict[i]["one_hot"].shape[0], num_atoms_to_stack, data_j_plus_dict[i]["one_hot"].shape[2]).to(device)
                        stacked_one_hot = torch.cat((data_j_plus_dict[i]["one_hot"], padding), dim=1)
                        data_j_plus_dict[i]["one_hot"] = stacked_one_hot
                        padding = torch.zeros(data_j_plus_dict[i]["fragment_mask"].shape[0], num_atoms_to_stack, data_j_plus_dict[i]["fragment_mask"].shape[2]).to(device)
                        stacked_fragment_mask = torch.cat((data_j_plus_dict[i]["fragment_mask"], padding), dim=1)
                        data_j_plus_dict[i]["fragment_mask"] = stacked_fragment_mask
                        padding = torch.zeros(data_j_plus_dict[i]["charges"].shape[0], num_atoms_to_stack, data_j_plus_dict[i]["charges"].shape[2]).to(device)
                        stacked_charges = torch.cat((data_j_plus_dict[i]["charges"], padding), dim=1)
                        data_j_plus_dict[i]["charges"] = stacked_charges
                        padding = torch.zeros(data_j_plus_dict[i]["anchors"].shape[0], num_atoms_to_stack, data_j_plus_dict[i]["anchors"].shape[2]).to(device)
                        stacked_anchors = torch.cat((data_j_plus_dict[i]["anchors"], padding), dim=1)
                        data_j_plus_dict[i]["anchors"] = stacked_anchors
                        padding = torch.zeros(data_j_plus_dict[i]["linker_mask"].shape[0], num_atoms_to_stack, data_j_plus_dict[i]["linker_mask"].shape[2]).to(device)
                        stacked_linker_mask = torch.cat((data_j_plus_dict[i]["linker_mask"], padding), dim=1)
                        data_j_plus_dict[i]["linker_mask"] = stacked_linker_mask
                        padding = torch.zeros(data_j_plus_dict[i]["atom_mask"].shape[0], num_atoms_to_stack, data_j_plus_dict[i]["atom_mask"].shape[2]).to(device)
                        stacked_atom_mask = torch.cat((data_j_plus_dict[i]["atom_mask"], padding), dim=1)
                        data_j_plus_dict[i]["atom_mask"] = stacked_atom_mask
                        num_edges_to_stack = max_edges_j_plus - data_j_plus_dict[i]["edge_mask"].shape[0]
                        data_j_plus_dict[i]["edge_mask"] = data_j_plus_dict[i]["edge_mask"].unsqueeze(0)
                        padding = torch.zeros(data_j_plus_dict[i]["edge_mask"].shape[0], num_edges_to_stack, data_j_plus_dict[i]["edge_mask"].shape[2]).to(device)
                        stacked_edge_mask = torch.cat((data_j_plus_dict[i]["edge_mask"], padding), dim=1)
                        data_j_plus_dict[i]["edge_mask"] = stacked_edge_mask
                        
                        #for noisy positions and features for j plus
                        noisy_positions_j_plus_dict[i] = noisy_positions_j_plus_dict[i] #check this
                        padding = torch.zeros(noisy_positions_j_plus_dict[i].shape[0], num_atoms_to_stack, noisy_positions_j_plus_dict[i].shape[2]).to(device)
                        stacked_positions = torch.cat((noisy_positions_j_plus_dict[i], padding), dim=1)
                        noisy_positions_j_plus_dict[i] = stacked_positions

                        noisy_features_j_plus_dict[i] = noisy_features_j_plus_dict[i]
                        padding = torch.zeros(noisy_features_j_plus_dict[i].shape[0], num_atoms_to_stack, noisy_features_j_plus_dict[i].shape[2]).to(device)
                        stacked_features = torch.cat((noisy_features_j_plus_dict[i], padding), dim=1)
                        noisy_features_j_plus_dict[i] = stacked_features

                        #for j minus
                        num_atoms_to_stack = max_atoms_j_minus - data_j_minus_dict[i]["positions"].shape[1]
                        padding = torch.zeros(data_j_minus_dict[i]["positions"].shape[0], num_atoms_to_stack, data_j_minus_dict[i]["positions"].shape[2]).to(device) #why does this work?
                        stacked_positions = torch.cat((data_j_minus_dict[i]["positions"], padding), dim=1)
                        data_j_minus_dict[i]["positions"] = stacked_positions
                        
                        padding = torch.zeros(data_j_minus_dict[i]["one_hot"].shape[0], num_atoms_to_stack, data_j_minus_dict[i]["one_hot"].shape[2]).to(device)
                        stacked_one_hot = torch.cat((data_j_minus_dict[i]["one_hot"], padding), dim=1)
                        data_j_minus_dict[i]["one_hot"] = stacked_one_hot
                        
                        padding = torch.zeros(data_j_minus_dict[i]["fragment_mask"].shape[0], num_atoms_to_stack, data_j_minus_dict[i]["fragment_mask"].shape[2]).to(device)
                        stacked_fragment_mask = torch.cat((data_j_minus_dict[i]["fragment_mask"], padding), dim=1)
                        data_j_minus_dict[i]["fragment_mask"] = stacked_fragment_mask

                        
                        padding = torch.zeros(data_j_minus_dict[i]["charges"].shape[0], num_atoms_to_stack, data_j_minus_dict[i]["charges"].shape[2]).to(device)
                        stacked_charges = torch.cat((data_j_minus_dict[i]["charges"], padding), dim=1)
                        data_j_minus_dict[i]["charges"] = stacked_charges
                        
                        padding = torch.zeros(data_j_minus_dict[i]["anchors"].shape[0], num_atoms_to_stack, data_j_minus_dict[i]["anchors"].shape[2]).to(device)
                        stacked_anchors = torch.cat((data_j_minus_dict[i]["anchors"], padding), dim=1)
                        data_j_minus_dict[i]["anchors"] = stacked_anchors
                       
                        padding = torch.zeros(data_j_minus_dict[i]["linker_mask"].shape[0], num_atoms_to_stack, data_j_minus_dict[i]["linker_mask"].shape[2]).to(device)
                        stacked_linker_mask = torch.cat((data_j_minus_dict[i]["linker_mask"], padding), dim=1)
                        data_j_minus_dict[i]["linker_mask"] = stacked_linker_mask
                        
                        padding = torch.zeros(data_j_minus_dict[i]["atom_mask"].shape[0], num_atoms_to_stack, data_j_minus_dict[i]["atom_mask"].shape[2]).to(device)
                        stacked_atom_mask = torch.cat((data_j_minus_dict[i]["atom_mask"], padding), dim=1)
                        data_j_minus_dict[i]["atom_mask"] = stacked_atom_mask
                        
                        num_edges_to_stack = max_edges_j_minus - data_j_minus_dict[i]["edge_mask"].shape[0]
                        data_j_minus_dict[i]["edge_mask"] = data_j_minus_dict[i]["edge_mask"].unsqueeze(0)
                        padding = torch.zeros(data_j_minus_dict[i]["edge_mask"].shape[0], num_edges_to_stack, data_j_minus_dict[i]["edge_mask"].shape[2]).to(device)
                        stacked_edge_mask = torch.cat((data_j_minus_dict[i]["edge_mask"], padding), dim=1)
                        data_j_minus_dict[i]["edge_mask"] = stacked_edge_mask
                    
                        #for noisy positions and features for j plus
                        noisy_positions_j_minus_dict[i] = noisy_positions_j_minus_dict[i] #check this
                        padding = torch.zeros(noisy_positions_j_minus_dict[i].shape[0], num_atoms_to_stack, noisy_positions_j_minus_dict[i].shape[2]).to(device)
                        stacked_positions = torch.cat((noisy_positions_j_minus_dict[i], padding), dim=1)
                        noisy_positions_j_minus_dict[i] = stacked_positions

                        noisy_features_j_minus_dict[i] = noisy_features_j_minus_dict[i]
                        padding = torch.zeros(noisy_features_j_minus_dict[i].shape[0], num_atoms_to_stack, noisy_features_j_minus_dict[i].shape[2]).to(device)
                        stacked_features = torch.cat((noisy_features_j_minus_dict[i], padding), dim=1)
                        noisy_features_j_minus_dict[i] = stacked_features

                        #for random
                        num_atoms_to_stack = max_atoms_random - data_random_dict[i]["positions"].shape[1]
                        padding = torch.zeros(data_random_dict[i]["positions"].shape[0], num_atoms_to_stack, data_random_dict[i]["positions"].shape[2]).to(device)
                        stacked_positions = torch.cat((data_random_dict[i]["positions"], padding), dim=1)
                        data_random_dict[i]["positions"] = stacked_positions
                        
                        padding = torch.zeros(data_random_dict[i]["one_hot"].shape[0], num_atoms_to_stack, data_random_dict[i]["one_hot"].shape[2]).to(device)
                        stacked_one_hot = torch.cat((data_random_dict[i]["one_hot"], padding), dim=1)
                        data_random_dict[i]["one_hot"] = stacked_one_hot
                        
                        padding = torch.zeros(data_random_dict[i]["fragment_mask"].shape[0], num_atoms_to_stack, data_random_dict[i]["fragment_mask"].shape[2]).to(device)
                        stacked_fragment_mask = torch.cat((data_random_dict[i]["fragment_mask"], padding), dim=1)
                        data_random_dict[i]["fragment_mask"] = stacked_fragment_mask
                        
                        padding = torch.zeros(data_random_dict[i]["linker_mask"].shape[0], num_atoms_to_stack, data_random_dict[i]["linker_mask"].shape[2]).to(device)
                        stacked_linker_mask = torch.cat((data_random_dict[i]["linker_mask"], padding), dim=1)
                        data_random_dict[i]["linker_mask"] = stacked_linker_mask

                       
                        padding = torch.zeros(data_random_dict[i]["charges"].shape[0], num_atoms_to_stack, data_random_dict[i]["charges"].shape[2]).to(device)
                        stacked_charges = torch.cat((data_random_dict[i]["charges"], padding), dim=1)
                        data_random_dict[i]["charges"] = stacked_charges

                    
                        padding = torch.zeros(data_random_dict[i]["anchors"].shape[0], num_atoms_to_stack, data_random_dict[i]["anchors"].shape[2]).to(device)
                        stacked_anchors = torch.cat((data_random_dict[i]["anchors"], padding), dim=1)
                        data_random_dict[i]["anchors"] = stacked_anchors
                       
                        padding = torch.zeros(data_random_dict[i]["atom_mask"].shape[0], num_atoms_to_stack, data_random_dict[i]["atom_mask"].shape[2]).to(device)
                        stacked_atom_mask = torch.cat((data_random_dict[i]["atom_mask"], padding), dim=1)
                        data_random_dict[i]["atom_mask"] = stacked_atom_mask
                        
                        num_edges_to_stack = max_edges_random - data_random_dict[i]["edge_mask"].shape[0]
                        data_random_dict[i]["edge_mask"] = data_random_dict[i]["edge_mask"].unsqueeze(0)
                        padding = torch.zeros(data_random_dict[i]["edge_mask"].shape[0], num_edges_to_stack, data_random_dict[i]["edge_mask"].shape[2]).to(device)
                        stacked_edge_mask = torch.cat((data_random_dict[i]["edge_mask"], padding), dim=1)
                        data_random_dict[i]["edge_mask"] = stacked_edge_mask

                        #for noisy positions and features for j plus
                        noisy_positions_random_dict[i] = noisy_positions_random_dict[i] #check this
                        padding = torch.zeros(noisy_positions_random_dict[i].shape[0], num_atoms_to_stack, noisy_positions_random_dict[i].shape[2]).to(device)
                        stacked_positions = torch.cat((noisy_positions_random_dict[i], padding), dim=1)
                        noisy_positions_random_dict[i] = stacked_positions

                        noisy_features_random_dict[i] = noisy_features_random_dict[i]
                        padding = torch.zeros(noisy_features_random_dict[i].shape[0], num_atoms_to_stack, noisy_features_random_dict[i].shape[2]).to(device)
                        stacked_features = torch.cat((noisy_features_random_dict[i], padding), dim=1)
                        noisy_features_random_dict[i] = stacked_features
                        
                        

                
                #create batch for j plus
                data_j_plus_batch = {}
                data_j_plus_batch["positions"] = torch.stack([data_j_plus_dict[i]["positions"] for i in range(PARALLEL_STEPS)], dim=0).squeeze()
                data_j_plus_batch["one_hot"] = torch.stack([data_j_plus_dict[i]["one_hot"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_plus_batch["atom_mask"] = torch.stack([data_j_plus_dict[i]["atom_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_plus_batch["fragment_mask"] = torch.stack([data_j_plus_dict[i]["fragment_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_plus_batch["linker_mask"] = torch.stack([data_j_plus_dict[i]["linker_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_plus_batch["charges"] = torch.stack([data_j_plus_dict[i]["charges"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_plus_batch["anchors"] = torch.stack([data_j_plus_dict[i]["anchors"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                
                
                data_j_plus_batch["uuid"] = [i for i in range(PARALLEL_STEPS)]
                data_j_plus_batch["num_atoms"] = [data_j_plus_dict[i]["num_atoms"] for i in range(PARALLEL_STEPS)]
                data_j_plus_batch["name"] = [data["name"] for _ in range(PARALLEL_STEPS)]
                data_j_plus_batch["edge_mask"] = torch.cat([data_j_plus_dict[i]["edge_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze().view(-1).unsqueeze(1)


                #create batch for j minus
                data_j_minus_batch = {}
                data_j_minus_batch["positions"] = torch.stack([data_j_minus_dict[i]["positions"] for i in range(PARALLEL_STEPS)], dim=0).squeeze()
                data_j_minus_batch["one_hot"] = torch.stack([data_j_minus_dict[i]["one_hot"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_minus_batch["atom_mask"] = torch.stack([data_j_minus_dict[i]["atom_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_minus_batch["fragment_mask"] = torch.stack([data_j_minus_dict[i]["fragment_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_minus_batch["linker_mask"] = torch.stack([data_j_minus_dict[i]["linker_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_minus_batch["charges"] = torch.stack([data_j_minus_dict[i]["charges"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_j_minus_batch["anchors"] = torch.stack([data_j_minus_dict[i]["anchors"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                
                data_j_minus_batch["uuid"] = [i for i in range(PARALLEL_STEPS)]
                data_j_minus_batch["num_atoms"] = [data_j_minus_dict[i]["num_atoms"] for i in range(PARALLEL_STEPS)]
                data_j_minus_batch["name"] = [data["name"] for _ in range(PARALLEL_STEPS)]
                data_j_minus_batch["edge_mask"] = torch.cat([data_j_minus_dict[i]["edge_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze().view(-1).unsqueeze(1)

                #create batch for random
                data_random_batch = {}
                data_random_batch["positions"] = torch.stack([data_random_dict[i]["positions"] for i in range(PARALLEL_STEPS)], dim=0).squeeze()
                data_random_batch["one_hot"] = torch.stack([data_random_dict[i]["one_hot"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_random_batch["atom_mask"] = torch.stack([data_random_dict[i]["atom_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_random_batch["fragment_mask"] = torch.stack([data_random_dict[i]["fragment_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_random_batch["linker_mask"] = torch.stack([data_random_dict[i]["linker_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_random_batch["charges"] = torch.stack([data_random_dict[i]["charges"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                data_random_batch["anchors"] = torch.stack([data_random_dict[i]["anchors"] for i in range(PARALLEL_STEPS)], dim=0).squeeze(1)
                
                data_random_batch["uuid"] = [i for i in range(PARALLEL_STEPS)]
                data_random_batch["num_atoms"] = [data_random_dict[i]["num_atoms"] for i in range(PARALLEL_STEPS)]
                data_random_batch["name"] = [data["name"] for _ in range(PARALLEL_STEPS)]
                data_random_batch["edge_mask"] = torch.cat([data_random_dict[i]["edge_mask"] for i in range(PARALLEL_STEPS)], dim=0).squeeze().view(-1).unsqueeze(1)

                
                #create batches for noisy positions and features
                noisy_positions_batch_j_plus = torch.stack([noisy_positions_j_plus_dict[i] for i in range(PARALLEL_STEPS)], dim=0).squeeze()
                noisy_features_batch_j_plus = torch.stack([noisy_features_j_plus_dict[i] for i in range(PARALLEL_STEPS)], dim=0).squeeze()

                noisy_positions_batch_j_minus = torch.stack([noisy_positions_j_minus_dict[i] for i in range(PARALLEL_STEPS)], dim=0).squeeze()
                noisy_features_batch_j_minus = torch.stack([noisy_features_j_minus_dict[i] for i in range(PARALLEL_STEPS)], dim=0).squeeze()

                noisy_positions_batch_random = torch.stack([noisy_positions_random_dict[i] for i in range(PARALLEL_STEPS)], dim=0).squeeze()
                noisy_features_batch_random = torch.stack([noisy_features_random_dict[i] for i in range(PARALLEL_STEPS)], dim=0).squeeze()
                
                

                
                chain_j_plus_batch, node_mask_j_plus_batch = model.sample_chain(data_j_plus_batch, keep_frames=keep_frames, noisy_positions=noisy_positions_batch_j_plus, noisy_features=noisy_features_batch_j_plus)

                chain_j_plus = chain_j_plus_batch[0, :, :, :] #it should take the first frame and all batch elements -> check it is really the first frame (I need the one at t0, the final generated molecule)
                

                chain_j_minus_batch, node_mask_j_minus_batch = model.sample_chain(data_j_minus_batch, keep_frames=keep_frames, noisy_positions=noisy_positions_batch_j_minus, noisy_features=noisy_features_batch_j_minus)

                chain_j_minus = chain_j_minus_batch[0, :, :, :]

                chain_random_batch, node_mask_random_batch = model.sample_chain(data_random_batch, keep_frames=keep_frames, noisy_positions=noisy_positions_batch_random, noisy_features=noisy_features_batch_random)

                chain_random = chain_random_batch[0, :, :, :]
                
                

                chain_with_full_fragments_batch = chain_with_full_fragments.repeat(PARALLEL_STEPS, 1, 1)

                # Check if all vectors in data_j_plus_batch["linker_mask"] are the same
                
                
                ####NEW CODE#######
                #@mastro creating new molecule containing the original fragmsnes and the linker generated using molecule j_plus, molecule j_minus and random molecule
                chain_j_plus_batch_original_fragments = chain_with_full_fragments_batch.clone()
                
                # Ensure the masks have the correct shape
                mask1 = data["linker_mask"][0].squeeze() == 1
                mask2 = data_j_plus_batch["linker_mask"].squeeze() == 1

                # Check if the masks need to be expanded
                if mask1.dim() == 1 and chain_j_plus_batch_original_fragments.dim() == 3:
                    mask1 = mask1.unsqueeze(0).expand(chain_j_plus_batch_original_fragments.size(0), -1)

                
                # Apply the masks
                # Ensure the shapes match for the assignment
                if chain_j_plus_batch_original_fragments[mask1, :].shape == chain_j_plus[mask2, :].shape:
                    chain_j_plus_batch_original_fragments[mask1, :] = chain_j_plus[mask2, :]
                else:
                    print("Shape mismatch:", chain_j_plus_batch_original_fragments[mask1, :].shape, chain_j_plus[mask2, :].shape)
                
                
                # print("chain_j_plus_batch_original_fragments shape", chain_j_plus_batch_original_fragments.shape)

                
                # chain_j_minus_batch_original_fragments = chain_j_minus.clone()
                # chain_j_minus_batch_original_fragments[:, data["fragment_mask"][0].squeeze() == 1, :] = chain_with_full_fragments_batch[:, data["fragment_mask"][0].squeeze() == 1, :]

                chain_j_minus_batch_original_fragments = chain_with_full_fragments_batch.clone()
                
                # Ensure the masks have the correct shape
                mask1 = data["linker_mask"][0].squeeze() == 1
                mask2 = data_j_minus_batch["linker_mask"].squeeze() == 1

                # Check if the masks need to be expanded
                if mask1.dim() == 1 and chain_j_minus_batch_original_fragments.dim() == 3:
                    mask1 = mask1.unsqueeze(0).expand(chain_j_minus_batch_original_fragments.size(0), -1)

                
                # Apply the masks
                # Ensure the shapes match for the assignment
                if chain_j_minus_batch_original_fragments[mask1, :].shape == chain_j_minus[mask2, :].shape:
                    chain_j_minus_batch_original_fragments[mask1, :] = chain_j_minus[mask2, :]
                else:
                    print("Shape mismatch:", chain_j_minus_batch_original_fragments[mask1, :].shape, chain_j_minus[mask2, :].shape)
                
                
                # print("chain_j_minus_batch_original_fragments shape", chain_j_minus_batch_original_fragments.shape)

                # chain_random_batch_original_fragments = chain_random.clone()
                # chain_random_batch_original_fragments[:, data["fragment_mask"][0].squeeze() == 1, :] = chain_with_full_fragments_batch[:, data["fragment_mask"][0].squeeze() == 1, :]

                # chain_random_batch_original_fragments = chain_with_full_fragments_batch.clone()
                # chain_random_batch_original_fragments[data["linker_mask"][0].squeeze() == 1, :] = chain_random[:, data_random_batch["linker_mask"].squeeze() == 1, :]

                chain_random_batch_original_fragments = chain_with_full_fragments_batch.clone()
                
                # Ensure the masks have the correct shape
                mask1 = data["linker_mask"][0].squeeze() == 1
                mask2 = data_random_batch["linker_mask"].squeeze() == 1

                # Check if the masks need to be expanded
                if mask1.dim() == 1 and chain_random_batch_original_fragments.dim() == 3:
                    mask1 = mask1.unsqueeze(0).expand(chain_random_batch_original_fragments.size(0), -1)

                
                # Apply the masks
                # Ensure the shapes match for the assignment
                if chain_random_batch_original_fragments[mask1, :].shape == chain_random[mask2, :].shape:
                    chain_random_batch_original_fragments[mask1, :] = chain_random[mask2, :]
                else:
                    print("Shape mismatch:", chain_random_batch_original_fragments[mask1, :].shape, chain_random[mask2, :].shape)
                
                
                # print("chain_random_batch_original_fragments shape", chain_random_batch_original_fragments.shape)
                ###################################


                # V_j_plus_coulomb_matrices_batch = compute_coulomb_matrices_batch(chain_j_plus.cpu())

                V_j_plus_coulomb_matrices_batch = compute_coulomb_matrices_batch(chain_j_plus_batch_original_fragments.cpu(), diagonalize=DIAGONALIZE)
                
                # save the absolute eigenvalues in decreasing order
                V_j_plus_sorted_eigenvalues = []
                for matrix in V_j_plus_coulomb_matrices_batch:
                    eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
                    eigenvalues = torch.abs(eigenvalues)
                    sorted_eigenvalues, _ = torch.sort(eigenvalues, descending=True)
                    
                    V_j_plus_sorted_eigenvalues.append(sorted_eigenvalues)
                V_j_plus_sorted_eigenvalues_batch = torch.stack(V_j_plus_sorted_eigenvalues)
                
                #to use concept from the paper, we need Euclidean distance using sorted eigenvalues
                V_j_plus_euclidean_distance_batch = torch.linalg.norm(sorted_eigenvalues_original_molecule_batch - V_j_plus_sorted_eigenvalues_batch, dim=1)
                
                print("V_j_plus_euclidean_distance_batch", V_j_plus_euclidean_distance_batch)
                print("V_j_plus_euclidean_distance_batch shape", V_j_plus_euclidean_distance_batch.shape)

                V_j_plus_euclidean_distance = sum(V_j_plus_euclidean_distance_batch)
                
                print("V_j_plus_euclidean_distance", V_j_plus_euclidean_distance)

                
                # print("V_j_plus_frobenius_norm", V_j_plus_frobenius_norm)
                #@mastro computing for the whole molecule
                # V_j_minus_coulomb_matrices_batch = compute_coulomb_matrices_batch(chain_j_minus.cpu())
                
                V_j_minus_coulomb_matrices_batch = compute_coulomb_matrices_batch(chain_j_minus_batch_original_fragments.cpu(), diagonalize=DIAGONALIZE)

                # save the absolute eigenvalues in decreasing order
                V_j_minus_sorted_eigenvalues = []
                for matrix in V_j_minus_coulomb_matrices_batch:
                    eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
                    eigenvalues = torch.abs(eigenvalues)
                    sorted_eigenvalues, _ = torch.sort(eigenvalues, descending=True)
                    
                    V_j_minus_sorted_eigenvalues.append(sorted_eigenvalues)
                V_j_minus_sorted_eigenvalues_batch = torch.stack(V_j_minus_sorted_eigenvalues)
                
                #to use concept from the paper, we need Euclidean distance using sorted eigenvalues
                V_j_minus_euclidean_distance_batch = torch.linalg.norm(sorted_eigenvalues_original_molecule_batch - V_j_minus_sorted_eigenvalues_batch, dim=1)
                
                V_j_minus_euclidean_distance = sum(V_j_minus_euclidean_distance_batch)
                
                print("V_j_minus_euclidean_distance", V_j_minus_euclidean_distance)

                

                # V_random_coulomb_matrices_batch = compute_coulomb_matrices_batch(chain_random.cpu())

                V_random_coulomb_matrices_batch = compute_coulomb_matrices_batch(chain_random_batch_original_fragments.cpu(), diagonalize=DIAGONALIZE)

                # save the absolute eigenvalues in decreasing order
                V_random_sorted_eigenvalues = []
                for matrix in V_random_coulomb_matrices_batch:
                    eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
                    eigenvalues = torch.abs(eigenvalues)
                    sorted_eigenvalues, _ = torch.sort(eigenvalues, descending=True)
                    
                    V_random_sorted_eigenvalues.append(sorted_eigenvalues)
                V_random_sorted_eigenvalues_batch = torch.stack(V_random_sorted_eigenvalues)
                
                #to use concept from the paper, we need Euclidean distance using sorted eigenvalues
                V_random_euclidean_distance_batch = torch.linalg.norm(sorted_eigenvalues_original_molecule_batch - V_random_sorted_eigenvalues_batch, dim=1)
                
                V_random_euclidean_distance = sum(V_random_euclidean_distance_batch)
                
                print("V_random_euclidean_distance", V_random_euclidean_distance)


                #used to compute the second version of the expected value
                mean_V_random_coulomb_matrix = torch.mean(V_random_coulomb_matrices_batch, dim=0)
                
                eigenvalues_r, eigenvectors_r = torch.linalg.eigh(matrix)
                eigenvalues_r = torch.abs(eigenvalues_r)
                mean_V_random_sorted_eigenvalues, _ = torch.sort(eigenvalues_r, descending=True)
                
                mean_V_random_euclidean_distance = torch.linalg.norm(sorted_eigenvalues_original_molecule - mean_V_random_sorted_eigenvalues) #this is already the expected value 2

                for r_eucl_dist in V_random_euclidean_distance_batch:
                    euclidean_distance_random_samples.append(r_eucl_dist)
                
                

                marginal_contrib_euclidean_distance += (V_j_plus_euclidean_distance - V_j_minus_euclidean_distance)

                

            phi_atoms[fragment_indices[j].item()] = [0]    
            phi_atoms[fragment_indices[j].item()][0] = marginal_contrib_euclidean_distance/M #j is the index of the fragment atom in the fragment indices tensor

        print(data["name"])

        phi_atoms_fronebius_norm = {}
        for atom_index, phi_values in phi_atoms.items():
            phi_atoms_fronebius_norm[atom_index] = phi_values[0]
            
            # phi_atoms_hausdorff[atom_index] = phi_values[2]

        
        # Save phi_atoms to a text file
        with open(f'{folder_save_path}/phi_atoms_{data_index}.txt', 'w') as write_file:
            write_file.write("sample name: " + str(data["name"]) + "\n")
            write_file.write("atom_index,shapley_value\n")
            for atom_index, phi_values in phi_atoms.items():
                write_file.write(f"{atom_index},{phi_values[0]}\n")

            write_file.write("\n")
            # save sum of phi values for disance and cosine similarity
            sum_shapley_values = sum([p_values[0] for p_values in phi_atoms.values()])
            write_file.write("Sum of Shapley values:")
            write_file.write(str(sum_shapley_values.item()) + "\n")
            
            # write_file.write("Sum of phi values for hausdorff\n")
            # write_file.write(str(sum([p_values[2] for p_values in phi_atoms.values()])) + "\n")     
            
            # write_file.write("Average hausdorff distance random samples:\n")
            # write_file.write(str(sum(hausdorff_distances_random_samples)/len(hausdorff_distances_random_samples)) + "\n")      
            
            # write_file.write("Hausdorff distances random samples\n")
            # write_file.write(str(hausdorff_distances_random_samples) + "\n")

            # write_file.write("Frobenius norm of original molecule:")
            # write_file.write(str(frobenius_norm_original_linker.item()) + "\n")

            average_euclidean_distance_random_samples = sum(euclidean_distance_random_samples)/len(euclidean_distance_random_samples)

            write_file.write("Average Euclidean distance of random samples from original molecule:")
            write_file.write(str(average_euclidean_distance_random_samples.item()) + "\n")

            expected_value_1 = distance_with_itself - average_euclidean_distance_random_samples
            current_prediction = distance_with_itself - distance_with_itself

            expected_value_2 = mean_V_random_euclidean_distance

            write_file.write("Expected value (difference between distance of original sample with itself and average Euclidean distance of random samples from original molecule):")
            write_file.write(str(expected_value_1.item()) + "\n")
            
            ideal_sum_shapley_values_1 = current_prediction - expected_value_1

            approx_error_1 = sum_shapley_values - ideal_sum_shapley_values_1
            abs_approx_error_1 = abs(approx_error_1)
            write_file.write("Approximation error (expected value 1):")
            write_file.write(str(approx_error_1.item()) + "\n")
            write_file.write("Absolute approximation error (expected value 1):")
            write_file.write(str(abs_approx_error_1.item()) + "\n")

            write_file.write("Expected value (difference between distance of original sample with itself and Euclidean distance of the average random samples from original molecule):")
            write_file.write(str(expected_value_2.item()) + "\n")

            ideal_sum_shapley_values_2 = current_prediction - expected_value_2

            approx_error_2 = sum_shapley_values - ideal_sum_shapley_values_2
            abs_approx_error_2 = abs(approx_error_2)
            write_file.write("Approximation error (expected value 2):")
            write_file.write(str(approx_error_2.item()) + "\n")
            write_file.write("Absolute approximation error (expected value 2):")
            write_file.write(str(abs_approx_error_2.item()) + "\n")

            
            # write_file.write("Frobenius norm of random samples:\n")
            # write_file.write(str(frobenius_norm_random_samples) + "\n")

        if SAVE_VISUALIZATION:
            phi_values_for_viz = phi_atoms_fronebius_norm

            # Saving chains and final states
            for i in range(len(data['positions'])):
                chain = chain_batch[:, i, :, :]
                assert chain.shape[0] == keep_frames
                assert chain.shape[1] == data['positions'].shape[1]
                assert chain.shape[2] == data['positions'].shape[2] + data['one_hot'].shape[2] + model.include_charges

                # Saving chains
                name = str(i + start)
                chain_output = os.path.join(chains_output_dir, name)
                os.makedirs(chain_output, exist_ok=True)
                
                #save initial random distrubution with noise
                positions_combined = torch.zeros_like(data['positions'])
                one_hot_combined = torch.zeros_like(data['one_hot'])

                # Iterate over each atom and decide whether to use original or noisy data
                for atom_idx in range(data['positions'].shape[1]):
                    if data['fragment_mask'][0, atom_idx] == 1:
                        # Use original positions and features for fragment atoms
                        positions_combined[:, atom_idx, :] = data['positions'][:, atom_idx, :]
                        one_hot_combined[:, atom_idx, :] = data['one_hot'][:, atom_idx, :]
                        # atom_mask_combined[:, atom_idx] = data['atom_mask'][:, atom_idx]
                    else:
                        # Use noisy positions and features for linker atoms
                        positions_combined[:, atom_idx, :] = noisy_positions_present_atoms[:, atom_idx, :]
                        one_hot_combined[:, atom_idx, :] = noisy_features_present_atoms[:, atom_idx, :]

                #save initial distribution TODO: fix positions, they are not centered
                save_xyz_file(
                    chain_output,
                    one_hot_combined,
                    positions_combined,
                    node_mask[i].unsqueeze(0),
                    names=[f'{name}_' + str(keep_frames)],
                    is_geom=model.is_geom
                )

                # one_hot = chain[:, :, 3:-1]
                one_hot = chain[:, :, 3:] #@mastro, added last atom type (not sure whyt it was not included...) However, TODO check again -> is it the atomic_number? But checking dimensions it did not look like it. Anyway, this should have no effect since the charge/atomic_number is always 0 in our case
                positions = chain[:, :, :3]
                chain_node_mask = torch.cat([node_mask[i].unsqueeze(0) for _ in range(keep_frames)], dim=0)
                names = [f'{name}_{j}' for j in range(keep_frames)]

                save_xyz_file(chain_output, one_hot, positions, chain_node_mask, names=names, is_geom=model.is_geom)
                invert_colormap = True #dealing with a difference of 0-norm (norm always positive). Sum of Shapley values is always negative, so we need to invert the colormap since negative values are imporant for the generation.
                # if average_frobenius_norm_random_samples > frobenius_norm_original_linker:
                #     invert_colormap = True

                visualize_chain_xai(
                    chain_output,
                    spheres_3d=False,
                    alpha=0.7,
                    bg='white',
                    is_geom=model.is_geom,
                    fragment_mask=data['fragment_mask'][i].squeeze(),
                    phi_values=phi_values_for_viz,
                    invert_colormap=invert_colormap
                )

                # Saving final prediction and ground truth separately
                true_one_hot = data['one_hot'][i].unsqueeze(0)
                true_positions = data['positions'][i].unsqueeze(0)
                true_node_mask = data['atom_mask'][i].unsqueeze(0)
                save_xyz_file(
                    final_states_output_dir,
                    true_one_hot,
                    true_positions,
                    true_node_mask,
                    names=[f'{name}_true'],
                    is_geom=model.is_geom,
                )

                pred_one_hot = chain[0, :, 3:-1].unsqueeze(0)
                pred_positions = chain[0, :, :3].unsqueeze(0)
                pred_node_mask = chain_node_mask[0].unsqueeze(0)
                save_xyz_file(
                    final_states_output_dir,
                    pred_one_hot,
                    pred_positions,
                    pred_node_mask,
                    names=[f'{name}_pred'],
                    is_geom=model.is_geom
                )

            start += len(data['positions'])

        

Loading initial distribution of noisy features and positions.


  0%|          | 0/30 [00:00<?, ?it/s]

Using P: 0.5 0.5
Eigenvalues:  tensor([486.1635, 221.5723, 109.8293,  97.6918,  72.9451,  56.3599,  44.4242,
         43.0097,  37.3077,  35.9196,  28.2539,  26.2172,  25.0803,  21.4131,
         20.4573,  17.6412,  12.5304,  10.8840,   9.3483,   8.2528,   7.7797,
          5.9394,   5.6355,   3.2186,   2.8734,   1.3887])
Distance with itself:  0.0


  0%|          | 0/21 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 29.2820, 195.6797, 187.8434,   7.9311,  44.9555,  23.5938, 190.0940,
        186.2405, 185.8895, 195.2268, 204.3606, 193.3965, 187.3987, 200.2064,
        201.8010,  14.8407, 204.9212,  62.1918, 191.9179, 187.4413, 198.6812,
          9.4435, 140.9344,  10.7454, 197.6277,   9.5668,   1.0889,  16.6329,
        202.3304,   1.1725,   2.1998, 193.0203, 202.5475, 198.0352,   0.5815,
        197.1005,   2.1633, 206.1759, 198.0050, 209.4233, 188.9146, 178.5728,
        188.9802,   8.1255, 155.6246, 185.6276, 178.7892, 179.1718, 185.2625,
        197.8590, 184.3879,  13.0462, 174.7847,   9.2808, 196.4535,  49.3460,
        176.2815, 183.2512, 193.9929,   3.7857,   1.2503, 195.7056,   3.3422,
        182.1883, 191.9709, 215.0722,  19.6317, 176.7650, 205.3015, 182.5490,
        183.6769,   6.7964, 210.2058,   8.9618,   8.2801,   2.0916, 167.5323,
          1.7744, 176.6980, 202.0339, 207.0153, 186.7022, 196.8430, 200.0726,
         13.7367,  21.0868, 18

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([480.0014, 211.7881,  98.0171,  97.6209,  71.3782,  52.3456,  44.8204,
         38.8382,  37.1510,  29.9483,  27.1413,  26.3284,  23.8601,  20.4884,
         19.6285,  15.9451,  12.4038,  11.4803,   8.3094,   7.8054,   6.9945,
          5.7776,   5.1357,   3.1401,   1.4088,   1.2210])
Distance with itself:  0.0


  0%|          | 0/20 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 19.2733,  17.6260,  18.0253,  25.8764, 183.0693,  36.3548,  16.7830,
         24.4955,  19.5190, 186.7118,  23.1385,  12.4065, 226.6411, 194.3199,
        192.3869,  21.9302,  19.3234, 184.3106,  58.4864,  15.5810,  17.5260,
         25.8037,  17.8068,  33.0848,  16.8585,  13.8291,  54.0126,  29.7118,
         15.9356,  69.3679,  17.7388,  18.0837,  18.9931, 123.8778,  23.2563,
         45.2912,  29.9216,  74.9017, 182.8674,  65.4737,  24.4381,  14.6771,
        208.0268,  31.6178, 184.5694,  25.5172,  29.4334,  16.8715,  28.5305,
         13.3532,  16.3387,  21.7249,  35.1777,  22.1021,  33.4448,  20.2180,
         35.5787, 203.3236,  19.3535,  20.3494,  19.4714,  17.9219,   9.0838,
         24.4424, 159.6894,  13.6371, 182.6191, 301.8362,  26.4901,  83.9909,
          8.1797,  31.0667, 195.0525,  19.3726,  12.5365,  18.0485, 184.0735,
         13.2657,  17.9809,  27.8613,  17.9708,  16.5914,  28.6749,  18.2258,
         37.0244, 185.0490,  4

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([486.9536, 228.1979,  97.5160,  87.2960,  72.0618,  53.2051,  48.6235,
         44.0325,  36.9153,  30.9161,  29.4615,  25.8000,  24.4938,  22.4450,
         20.1318,  15.9136,  14.0244,  12.2681,  10.4707,   8.4132,   7.0365,
          6.0159,   5.7448,   3.8906,   2.7308,   1.0779])
Distance with itself:  0.0


  0%|          | 0/21 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 24.1879,  18.8537,  13.8108,  37.1051,  47.4860,  18.8132,  38.0433,
         16.5121,  21.3331,  17.1732,  10.2413,  17.5706,  16.5664,  13.1401,
         11.6221, 162.2549,  21.7822,  28.0014,  26.0473,  26.3149,  15.5428,
         17.4054, 221.6481, 168.8965,  14.4170, 169.3865,  16.1173,  14.5560,
         14.0750,  16.8272,  18.7942,  47.5128,  12.7688,  19.2430,  16.3314,
        174.4844,  42.1878,  26.0818,  13.5670,  27.4311,  16.0298,  26.8044,
          9.6313,  16.3035,  12.6063, 192.9194,  16.7409,  11.1460,  33.7788,
         21.5266,  22.9903, 160.1147,  33.1665,  16.0107,  56.6649,  16.0698,
         33.3245,  16.8721,  17.2287,  68.3361,  19.0289, 174.2154, 176.2261,
         16.4224,  39.1900,  22.1453,  16.5543,  25.7862,  33.0834, 180.6097,
          3.3529,  17.7959,  13.0258,  18.0750,  35.1284,  15.2836,  32.8820,
         20.4022,  12.7073, 106.4659,  14.6401,  16.6919,   4.7960,  15.0188,
         10.4601,  18.8556,   

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([486.6234, 233.5960, 106.9866,  93.7607,  74.2710,  67.9125,  49.2685,
         42.7473,  40.7505,  37.1893,  29.8288,  26.6984,  25.7915,  21.2466,
         20.5693,  18.4202,  15.9838,  12.0497,   9.2346,   8.1126,   7.5476,
          6.4070,   5.6174,   4.2356,   2.6100,   1.3372])
Distance with itself:  0.0


  0%|          | 0/19 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 17.1569,  27.9912,  18.0223,  29.3338,  68.2071,  15.9175,  30.7671,
         18.9817,  35.6157,  17.7418,  17.1499,  22.1016, 160.0131, 149.9437,
        178.4633,  13.7396,  51.4342,  21.8421,  21.1686,  20.9226,  20.4105,
         24.0638,  22.0462, 152.1661,  14.7744, 190.9703, 311.0798, 180.2615,
         28.8077,  61.5255,  28.0977,  37.2851,  34.9030,  24.2092, 175.2886,
         29.3946,  16.4152, 175.6825,  41.6506,  26.8930,  34.6388,  65.7458,
         22.6215,  18.7678, 158.0323,  29.9277,  22.5044,  21.8403,  21.7569,
         23.8103,  35.2526,  27.5381,  32.7531,  20.4421,  23.5373,  26.6828,
         32.4544, 163.0425, 315.5000,  23.2778,  12.4789,  20.2900,  23.3090,
         57.9732,  20.0390,  42.8630, 323.9583,  19.0520,  26.9419,  11.9490,
         19.3485,  29.6359,  39.2549,  26.4929, 173.6171,  58.0356, 165.0434,
         39.3516,  20.2051,  69.6744,  28.4794,  52.1771,  28.5608,  33.1888,
         20.8414,  20.7068, 10

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([480.7617, 215.9305, 103.8462,  96.5135,  75.3046,  53.4533,  46.9731,
         39.7207,  37.6640,  29.6239,  27.4725,  26.1363,  23.7765,  20.4853,
         18.5337,  15.7693,  12.0302,  10.8552,   9.3102,   8.1573,   7.7606,
          5.7282,   5.5684,   3.3275,   3.0751,   1.3579])
Distance with itself:  0.0


  0%|          | 0/19 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([191.4983,  35.4802,  14.7483, 204.6004, 196.1180, 175.8746,  21.3532,
         25.4086,  54.6696,  10.7180, 117.0009,  24.7874,  24.6591, 126.4085,
         20.8868,  16.8979,  21.6224,  13.8380,  24.7226,   7.6592, 181.3107,
         15.3908,  19.0387,  33.6025,  32.0887,  23.7869, 191.2419,  34.9225,
         10.5209, 199.5260,  23.0428,  20.8598,  93.1086, 173.3541,  23.0119,
         24.7093,  14.2492,   8.4957,  58.2966,  18.5647,  13.4412,  43.9923,
         79.2003, 179.0534,  10.8524,  17.8027,  11.3027,  14.6952, 179.2931,
         13.7984,  22.1664,  17.7820, 178.7008,  26.4941,  21.5632,  19.5046,
         13.3203,  13.3854, 177.2623,   8.4465, 188.9738,  23.1738,  31.1111,
        212.0373, 310.6904, 180.5968,  28.3445,  40.6153,  40.0313,  33.7675,
         18.6847, 182.7983,  13.5404,   8.0888,  11.2529,  18.4703,  34.1752,
         46.7065, 171.4613,  35.1120,  14.1282,  39.2379,  44.7348,  14.2532,
         46.6155,  18.5267, 18

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([483.6737, 219.2042, 100.2073,  97.1760,  76.5470,  51.5823,  49.2261,
         39.8227,  38.6371,  32.4873,  29.5620,  26.7336,  24.1068,  20.5043,
         19.9701,  16.4079,  12.7746,  11.6594,   8.9931,   8.2944,   7.7873,
          5.9173,   5.6664,   4.1743,   3.1394,   1.3820])
Distance with itself:  0.0


  0%|          | 0/18 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([191.5330,  19.1458,  21.2169,  21.6008,  20.4901,  21.9409,  36.3965,
        180.0605, 184.8649, 184.6085,  25.4278,  31.4439,  10.1567,   9.3319,
        172.7202,  37.2211,  25.8165,  18.5711,  77.1509,   6.1463,  64.7308,
         19.8977,  20.9258,  17.3971,  22.9120,  16.7279,  37.7980,  16.3665,
         25.2790,  21.3084, 249.8053,  22.9064, 178.6375,  19.0208,  15.1048,
        190.8223,  16.5641,  45.4110,  19.3459,  10.0563,  30.9690,  19.7912,
        173.3941,  17.3770,  24.0979, 178.1427,  30.4210,  19.7792, 181.8737,
         34.8357, 191.1653, 171.9511,  19.6940, 211.6850,  18.5028,  60.5663,
         15.7085,  26.7448, 176.4930,  16.7509,  17.8207, 131.8087, 177.9827,
        170.6565,  21.2092,  18.4399,  22.0989,  10.7020,  21.2178,  39.9236,
         18.2382,  34.1853,  39.4022,  16.9916,  19.4228,  11.4568,  22.2120,
         16.9235,  15.3089, 186.5746,  21.0693,   5.9439,  15.1267,  27.3607,
         14.8618,  24.4141,  1

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([584.4502, 378.0038, 192.2419,  98.3940,  92.5296,  77.1526,  47.6740,
         45.6914,  43.6965,  36.7549,  33.6695,  30.8012,  25.0778,  23.6728,
         20.7125,  19.5299,  14.0273,  13.3544,  10.8386,   7.8476,   6.7738,
          6.7163,   5.5941,   4.2678,   3.1213,   1.4796])
Distance with itself:  0.0


  0%|          | 0/20 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([204.7077, 199.6475,  57.8781, 221.1262, 212.7206, 208.8625,  36.9135,
        220.9723, 212.4064, 210.7005, 199.6614, 210.7999, 175.4955, 194.9224,
        212.5230, 212.7399, 204.8721,  41.1091, 181.2249, 205.3187, 205.6829,
        215.1194, 209.0822, 193.1768, 210.9208, 199.6858, 203.9737, 170.7792,
        212.1257, 216.3222, 219.8143, 212.3794, 221.0096, 185.3432,  25.0977,
        143.2501, 175.8331, 209.0335,  45.5707, 211.9256, 210.0455, 202.0428,
        196.6849, 211.8321, 220.4845, 194.4419, 201.2006, 222.1410, 213.2199,
        212.3319, 200.0858, 204.4177, 208.5654, 209.3270,  48.2375, 216.4365,
        175.7011,  45.6285, 186.0828, 205.0969, 213.3976, 210.4559, 204.8295,
        205.5378, 208.5933, 205.8173, 204.7942, 204.7348,  81.5239, 187.8021,
        193.7396, 214.0994, 174.3091, 190.8860, 197.5790, 200.3961, 210.0809,
        215.7864,  41.8765,  55.1716, 216.1578, 202.0981, 201.4725,  37.1684,
        197.6246, 217.9471, 21

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([487.4168, 222.3808, 118.6171,  97.0500,  69.7861,  58.9877,  47.6632,
         44.6143,  38.6599,  36.7257,  30.0019,  27.1914,  24.8951,  20.5561,
         20.2630,  16.3751,  14.1386,  11.5957,   9.4361,   8.0744,   7.7185,
          5.9384,   5.5930,   4.2801,   2.9656,   1.3705])
Distance with itself:  0.0


  0%|          | 0/20 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 27.6386,  17.5574,  29.9762,  35.1730,  12.7570, 160.1117,  14.0824,
        166.8971,  24.7946,  20.6307, 167.1653,  11.5432,  20.8684,  34.4849,
         41.4986,  21.9416,  18.5300, 180.5980,  47.2069,  23.5495,  13.2815,
         27.1172, 166.6543,  18.1681,  14.0919,  18.3364,  26.0669, 179.9926,
         11.5466,  30.1394, 171.4197, 167.7217,  23.7363, 172.2396,  18.3127,
        202.2526,  27.1685, 169.4910,  12.5084,  22.3440,  51.7093,  31.0893,
         18.4562,  41.8158,  12.9019, 172.1525,  55.0054, 171.8689,  26.8202,
         14.0163,  25.6892,  31.2482,  10.6512,  29.9493,  54.2507,  24.3311,
        192.5351,  26.8420,  13.6874,  19.5649,  14.9968,  20.3611,  27.3382,
         39.1676,  24.5473, 160.7271,  11.3314,  12.6135,  22.7625, 332.7226,
         16.6276,  27.9995,  25.6214,  27.3604,  18.3453, 163.4390, 166.9593,
         16.0116,  13.4896,  46.9768,  38.8210,  19.1417, 174.7561,  34.9450,
         32.5912,  12.7995,  1

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([535.8726, 245.6356, 134.8156,  96.2697,  89.2694,  78.8889,  60.8559,
         49.7505,  47.5697,  45.9503,  42.6771,  36.2914,  32.3540,  28.5133,
         25.0605,  24.6695,  23.1215,  13.5176,  10.8672,   9.1218,   8.7068,
          7.3584,   6.5374,   4.1899,   3.5142,   1.6342,   0.8072])
Distance with itself:  0.0


  0%|          | 0/24 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([9.9280e-01, 6.6418e+00, 1.1332e+01, 7.1658e+00, 1.5047e+01, 1.7026e+02,
        2.4534e+01, 1.6161e+01, 1.7310e+01, 8.8311e+00, 2.1196e+01, 9.2432e+00,
        5.7149e+00, 9.4225e-01, 1.0131e+01, 1.1328e+00, 1.4400e+02, 2.4468e+01,
        1.0631e+01, 1.1412e+01, 1.7757e+01, 5.9504e+00, 3.1262e+02, 1.1130e+01,
        1.7460e+02, 2.3803e+01, 3.7777e+00, 1.3832e+01, 3.5549e+01, 1.6110e+01,
        2.0053e+02, 1.5790e+01, 2.2912e+01, 1.1631e+01, 6.4446e+00, 8.0467e+00,
        3.2284e+01, 1.6782e+02, 1.1166e+01, 1.7805e+02, 1.2919e+01, 2.3186e+01,
        6.1246e-01, 2.6727e-01, 1.0149e+01, 6.7517e+00, 1.4416e+02, 3.7577e+00,
        1.5143e+01, 2.4646e+01, 2.0157e+01, 1.5741e+01, 1.4624e+00, 2.5099e+01,
        3.1673e+01, 1.5566e+02, 1.0128e+01, 2.1511e+01, 1.5142e+01, 1.5741e+02,
        6.6887e-01, 1.8537e+01, 1.4949e+01, 1.5121e+01, 1.5369e+02, 4.8692e+00,
        4.7149e+01, 2.6386e+01, 2.1960e+00, 1.1813e+02, 1.9736e+00, 1.2524e+01,
      

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([536.3613, 246.2879, 135.1267,  97.1169,  89.6137,  78.6801,  60.7861,
         49.7152,  47.5673,  46.0929,  43.0774,  35.2906,  32.2889,  28.2103,
         25.0234,  24.6235,  22.9951,  13.4584,  10.3435,   9.1130,   8.5808,
          7.3760,   6.5325,   3.9570,   3.3231,   1.5528,   0.7257])
Distance with itself:  0.0


  0%|          | 0/21 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 32.6550,  63.3621,  23.1980,  11.5951,  49.9559,  36.2223,  34.6460,
        171.5938,   5.8767,   5.9769,   5.9650,   6.2124, 186.5070,   7.1997,
         24.7430,  42.7583,  22.3865,  27.4964,  16.4104,  30.2316,   5.8941,
          5.4710,  30.6681,   3.3198,   2.9776,   9.8742,   5.9312,  34.5634,
         36.2650,   0.9351,  62.2720, 102.4333,  10.8902,   0.5883,   5.7756,
         19.0334,   8.0051,  41.7734,   6.2045,   4.3720,  60.4298,  44.4304,
         34.9509,  14.7902, 223.2973,  21.1603,   2.3955,  29.5424, 166.6659,
         22.8822,  25.4746,  15.1740,  17.3032,  35.6716,  59.3313,  35.4020,
         61.6811,  38.1190,  35.5624,  11.4587,   9.7958,  32.3624,  40.7509,
         10.9225,  47.5306,  46.7666,  12.4519,  64.7925,  22.2310,   8.5150,
         36.3930,   1.0384,   9.0634,  10.8029,  17.7934, 159.6093,  15.6896,
          7.5459, 151.4655,  10.8671,   6.4271,  18.0836, 158.6834,   7.9536,
         26.1788,   4.9288,  1

  imgs = [imageio.imread(fn) for fn in save_paths]
Atom "[C@H]" contains stereochemical information that will be discarded.
Atom "[C@@H]" contains stereochemical information that will be discarded.


Using P: 0.5 0.5
Eigenvalues:  tensor([259.6281, 122.5944,  73.7747,  60.9700,  50.9388,  47.4994,  36.4499,
         32.3025,  28.6719,  24.4959,  22.3759,  20.4797,  14.2379,  13.1327,
          9.7496,   9.2467,   7.2203,   6.2896,   4.9178,   1.3693,   0.9921])
Distance with itself:  0.0


  0%|          | 0/15 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 13.9730,  28.0341,  39.3931,  15.7071,   8.0568, 112.5596,   8.8990,
         34.6815,  20.1187,  76.7789, 371.1994, 208.7872,  12.5684,  13.8893,
         27.5180, 254.1946, 211.6734,  12.6404,  20.3029,  45.3631,  24.2922,
        224.2511,  15.1021,  19.6800,   9.7545, 224.5037,  16.8795,  14.5578,
        227.4111,  32.3654,  63.3211,  11.9028,  35.7708,  21.4620, 233.8381,
         28.5810, 241.1537,  33.8061,  21.6329, 220.9795,  10.5667,  16.7435,
         17.7559, 232.2864,  10.9715,  33.1391, 220.5666,  26.0337, 232.1142,
        258.5623, 246.9089,  23.6224, 258.0575, 232.4604,  21.0258,  15.5499,
         22.5937,  17.9837,  12.6981,  13.6513,  19.2730,  17.5353,  17.7176,
         11.7909,  16.7112,  12.6226,  18.2190, 223.8102,  14.3592,  15.1757,
         26.4650,  17.3030,  12.1617,  12.2211,  11.4297,  23.8257,  30.6131,
         10.1540,  31.1948,  10.4329,  19.0521,  14.4594,   7.7428,   8.3824,
         11.4975, 223.2681,   

  imgs = [imageio.imread(fn) for fn in save_paths]
E/Z stereochemical information, which is specified by "/", will be discarded
E/Z stereochemical information, which is specified by "/", will be discarded


Using P: 0.5 0.5
Eigenvalues:  tensor([265.8971, 143.5076,  88.7752,  66.5733,  61.9102,  39.2709,  36.7242,
         33.9965,  28.8587,  25.3789,  23.9767,  22.5604,  19.6263,  15.3631,
         10.7128,   9.8295,   7.3719,   5.7561,   5.4864,   5.3486,   2.8160,
          1.1139])
Distance with itself:  0.0


  0%|          | 0/17 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 16.2535, 196.0699,  15.5940, 225.7361,  16.6073,  10.6868,  18.3226,
        216.2188,  17.5008,   2.8345,  14.1663,  14.6605,  19.3694,  15.9312,
         14.3616,  24.2680,  50.7346,  22.2671, 256.8616,  11.8918,   7.0669,
         48.3785,  14.9729,  14.2757, 225.9941,  11.8998,  13.0423,  37.3139,
         10.0942,  10.7072,  12.1058,  38.0012,  17.5079,   9.7536,  16.7339,
         10.0284,  16.1491,  11.4375,  12.9194,  96.0111,  14.1338,  23.7140,
         33.5188,   6.9974,  11.5260,  13.0954,  13.6177,  10.8778, 229.0078,
          9.8419,  28.0367,  10.7367,  45.4071,  16.2111,  14.8038,  14.2066,
          8.6316,  17.7354,  16.5731,   8.0796, 208.0694, 243.4115,  14.9345,
         10.6146,   8.8235, 203.9398,  28.1952,  11.2034,  10.6699,  17.3940,
         12.4633, 223.5596,  10.9674, 241.2750,   9.6568,  16.2056,  54.7462,
         34.5427,  17.9439,  20.2786,  56.2409,   9.2078, 229.9485,  19.3575,
          6.3804,  13.9216, 38

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([294.9877, 144.5690,  88.9510,  75.5424,  75.3625,  55.8349,  44.4692,
         41.2536,  39.2758,  33.3222,  28.4090,  24.0219,  19.5964,  14.5829,
         13.2132,  12.1784,  11.4374,  11.0817,  10.7805,   8.4181,   5.9351,
          4.9968,   4.1208,   1.5136,   0.5756])
Distance with itself:  0.0


  0%|          | 0/21 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([5.9120e+00, 5.4696e+01, 2.6597e+01, 1.2239e+01, 3.5363e+02, 3.0310e+01,
        1.9832e+02, 1.7286e+01, 2.5039e-01, 3.7330e-01, 1.0910e+01, 2.0837e+02,
        1.6095e+01, 8.5821e+00, 2.1625e+02, 1.6560e+01, 1.2001e+01, 5.4511e+02,
        1.4318e+01, 1.9564e+02, 7.7591e+00, 8.1131e+00, 2.0758e+02, 1.4042e+02,
        1.8420e+01, 8.4587e+00, 3.3645e-01, 1.4496e+01, 1.9759e+02, 2.1746e-01,
        3.0482e+00, 3.4448e+01, 1.8920e+02, 9.0546e+00, 1.0747e+00, 6.8366e+01,
        4.7076e+01, 2.3181e+01, 5.7745e+00, 1.7369e+00, 1.8424e+02, 1.2948e+01,
        1.5273e+01, 1.7789e+00, 1.3797e+01, 1.9485e+02, 1.6118e+01, 7.3473e+00,
        1.8312e+02, 1.7586e+01, 1.9022e+01, 1.8076e+01, 1.0946e+01, 1.7307e+00,
        2.7103e+01, 4.2676e+00, 1.6177e+01, 1.0866e+01, 2.0792e+02, 1.0187e+01,
        3.3777e+00, 1.5710e+01, 2.8987e+01, 4.8710e+00, 2.0192e+02, 3.6437e+01,
        1.6241e+01, 1.9245e+02, 1.6115e+01, 2.0436e+02, 1.4668e+01, 7.6105e+00,
      

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([296.8383, 138.8223,  88.4451,  77.9628,  70.5080,  57.7639,  40.6985,
         39.1632,  34.4988,  31.5891,  25.8103,  22.6227,  20.2767,  16.1930,
         13.0881,  12.6045,  11.3263,  10.4740,   9.5053,   8.5642,   6.5463,
          5.0027,   3.9803,   1.4118,   0.5753])
Distance with itself:  0.0


  0%|          | 0/20 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 55.9854,  10.4437,  11.1798,  19.3606,  11.6071, 213.3204,  46.6092,
         32.6928,  10.4480,  13.8825,  13.1499,  16.3712,  30.8728,  12.5871,
         18.5809,  20.1769,  11.5058,  13.8207,  28.7335,   9.9585,  12.7185,
         49.0372,  12.4568,  21.4265,  19.5278,  10.3075,  14.5329,  25.2762,
         14.2080, 198.3824,  32.6775,  11.2808,  18.8144,  55.4629,  20.4815,
         16.9864,  10.4117,  18.7608,  83.9628,  22.8000,  15.0700,  23.6107,
         17.3754,  26.4670,  17.7502,  34.0724,  10.6308,  21.5326,  22.2716,
         16.4849, 206.1348, 322.9852,  11.3739,  41.4988,  46.0906,  14.7114,
         30.8147, 238.5767,  12.7618,  15.0426,  11.1836,  12.0900,  11.9049,
        214.2366,  17.4189,  11.1376,  18.5575,  17.0500,  15.2817,  12.4724,
         17.8836,  24.2966,  13.3247,  11.9605,  18.8832,  12.8323,  14.3819,
         14.0780,  18.6606,  10.7300,  48.6603,  17.8775,  12.8753,  46.6881,
         11.7081,  14.0250,  2

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([300.1075, 149.1676,  89.4126,  77.0583,  75.4998,  56.2900,  50.1766,
         41.3642,  39.2939,  33.3185,  28.8581,  24.9130,  19.6332,  14.6343,
         13.2231,  12.1899,  11.6487,  11.2047,  10.8636,   8.4449,   5.9423,
          4.9969,   4.2597,   1.5106,   0.5756])
Distance with itself:  0.0


  0%|          | 0/22 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([1.4151e-01, 3.5134e+02, 2.6060e+01, 3.5536e-01, 2.8676e-01, 3.6364e+01,
        4.6115e+00, 5.2347e+00, 2.1694e+02, 2.3360e+02, 8.5578e-01, 7.1071e+00,
        5.5483e+01, 5.0703e+00, 3.1191e+01, 1.0005e+00, 6.4112e+01, 1.1795e-01,
        1.8733e+02, 4.8377e-01, 1.6653e+01, 1.6751e+01, 3.2751e+00, 3.4553e+02,
        1.1951e+01, 1.4972e+01, 2.0626e+01, 7.1222e-01, 5.3194e+00, 1.0899e+01,
        1.5418e+01, 8.1850e+00, 8.3583e+01, 2.9024e+01, 7.2742e+00, 5.9546e-01,
        1.9904e+02, 5.2529e+00, 2.0694e+01, 3.7097e+00, 3.8216e+01, 2.0515e+02,
        3.1069e+01, 2.5691e+01, 1.7890e+02, 5.0568e+01, 4.6114e+00, 3.2276e+01,
        2.0971e+01, 5.7267e+00, 2.0783e+02, 2.2862e+00, 1.9677e+02, 1.4875e+01,
        6.8069e+00, 1.4137e+01, 1.4132e+02, 8.1698e+00, 2.4890e+01, 1.9407e+02,
        2.2961e+01, 3.3685e+01, 2.1178e+01, 3.7938e+01, 3.7178e+02, 5.2426e+00,
        3.3047e+01, 1.1045e+01, 4.1033e+01, 2.7402e+01, 4.1704e+01, 3.1780e-01,
      

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([294.7900, 144.6529,  88.9071,  75.5575,  75.3394,  55.8466,  44.6321,
         41.4007,  39.2971,  33.3197,  28.2781,  24.0161,  19.5693,  14.6161,
         13.2151,  12.1831,  11.4232,  11.0653,  10.7758,   8.4437,   5.9295,
          4.9968,   4.0566,   1.5423,   0.5756])
Distance with itself:  0.0


  0%|          | 0/20 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([2.3676e+01, 2.7052e-01, 3.6294e+02, 2.1914e+01, 3.4553e+00, 7.9690e+00,
        2.0597e+02, 2.5234e+01, 1.4012e+00, 1.5512e-01, 5.9933e+00, 5.0250e-01,
        2.0803e+02, 1.9548e+02, 3.5779e+01, 7.8767e+00, 5.9242e+00, 9.7317e+00,
        1.1660e+01, 4.8143e+00, 2.2567e-01, 1.6119e+01, 1.5927e+00, 2.6069e+01,
        9.0984e+00, 6.2805e+00, 1.1251e+01, 2.1878e+02, 6.5284e+00, 1.2812e+01,
        1.4123e+01, 2.1309e-01, 1.2239e+01, 2.4939e+02, 1.4203e+01, 1.2140e+01,
        1.2853e+01, 1.8432e+01, 1.8656e+02, 1.0317e+01, 6.6626e+01, 1.8887e+02,
        1.6930e+01, 3.9258e+01, 1.7126e+01, 2.0058e+02, 4.9839e+01, 1.2254e+01,
        1.3442e+01, 6.6958e+00, 1.8644e+01, 2.3448e+01, 7.5681e+00, 2.1289e+01,
        4.3351e+01, 2.4980e+00, 2.7519e+01, 9.2471e+01, 2.8746e+02, 2.0734e+01,
        2.4589e+00, 2.8052e-01, 9.6244e+00, 2.7041e+01, 2.8923e+00, 4.4751e+00,
        3.0941e-01, 1.7282e+00, 9.6242e+00, 7.6785e+00, 2.6947e+01, 5.8072e+01,
      

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([300.3823, 145.3219,  91.0881,  75.8643,  74.9850,  55.6501,  45.8723,
         43.3781,  39.3993,  33.9947,  30.0601,  25.3696,  19.8211,  14.6762,
         13.2365,  12.2922,  11.7116,  11.2386,  10.8827,   8.4550,   5.9450,
          4.9967,   4.2101,   1.5235,   0.5756])
Distance with itself:  0.0


  0%|          | 0/19 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([  11.0662,  214.6659,   18.9876,   16.7397,  246.4753,   11.3831,
          87.6973,   25.5970,   33.9390,   24.4023,  222.8906,   60.8771,
           5.8959,   10.0985,   15.0183,    9.8877,   22.3788,   15.5691,
          26.8083,    8.9970,   11.1915,   64.2643,    7.5316,  204.6352,
         206.9661,   14.3936,   18.5330,  349.0903,    7.3066,  383.1369,
          14.2647,   35.1836,   21.9650,    5.7208,   20.6654,  294.6888,
          55.5826,    7.4276, 2375.0457,  200.8024,   22.9373,   36.3203,
          23.0672,    7.2873,   11.9541,   29.4294,  184.8007,   44.1785,
          11.7176,    6.7956,   23.2067,   29.4265,    5.4142,   18.2292,
           9.4148,   18.3866,   27.4191,    6.5557,   17.4482,   17.8670,
          22.9439,    9.2086,   14.8093,   21.0117,   26.6245,   30.5547,
           7.9974,   20.6851,  203.3399,   30.3904,  193.7510,   13.8819,
          20.8848,    9.2320,    4.5115,  211.0835,  321.9850,  385.7507,
    

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([294.2289, 144.6435,  89.0605,  75.8237,  75.3936,  55.8652,  44.5876,
         41.2553,  39.2707,  33.3210,  28.2935,  24.1188,  19.6094,  14.5695,
         13.2189,  12.1797,  11.4383,  11.0914,  10.7749,   8.4245,   6.0217,
          4.9969,   4.1564,   1.5100,   0.5756])
Distance with itself:  0.0


  0%|          | 0/21 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([1.4219e+01, 1.9551e+01, 1.4294e+01, 5.8840e+01, 3.7237e+02, 1.9062e+01,
        4.0565e+01, 2.0490e+02, 2.3285e-01, 7.9894e+00, 2.7985e+01, 2.6211e+01,
        2.0057e+02, 1.4686e+01, 2.2450e+01, 1.6281e+01, 2.4490e+01, 2.2240e+02,
        3.6824e+01, 2.0079e+02, 5.6758e-01, 9.2284e+00, 4.3860e+01, 8.2308e+00,
        1.0286e+01, 8.5303e+00, 1.0567e+01, 1.1910e+01, 1.5751e+01, 1.5608e-01,
        1.0347e+01, 2.2006e+02, 3.2473e+01, 4.1273e-01, 3.9704e-01, 1.8326e+02,
        2.2820e+01, 1.1151e+01, 1.8952e+02, 2.5673e+00, 8.3099e+01, 2.4487e+01,
        2.0850e+01, 2.6718e+02, 1.5108e+01, 2.3819e+01, 2.4495e+02, 4.2931e+00,
        1.2798e+01, 2.6339e+01, 3.6709e+01, 1.2021e+01, 2.0819e+02, 1.3997e+01,
        1.9813e+02, 1.4738e+01, 1.9206e+02, 3.4346e+01, 2.3886e+01, 1.5455e+01,
        4.8164e+00, 3.7336e+01, 4.2414e+01, 1.0651e+01, 2.2114e+01, 2.0796e+01,
        1.2054e+01, 2.4622e+01, 1.9741e+01, 7.6341e+00, 5.4871e+00, 5.5207e-01,
      

  imgs = [imageio.imread(fn) for fn in save_paths]
Atom "[C@H]" contains stereochemical information that will be discarded.
Atom "[C@H]" contains stereochemical information that will be discarded.


Using P: 0.5 0.5
Eigenvalues:  tensor([2.5566e+03, 2.5905e+02, 1.1395e+02, 8.3379e+01, 5.9738e+01, 5.3285e+01,
        4.0797e+01, 3.5104e+01, 3.1189e+01, 3.0136e+01, 2.2039e+01, 1.8070e+01,
        1.3175e+01, 1.1778e+01, 9.8619e+00, 9.0296e+00, 7.5663e+00, 6.1088e+00,
        3.6540e+00, 1.7238e+00])
Distance with itself:  0.0


  0%|          | 0/17 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([6.4520e-01, 2.6691e+01, 3.1321e+01, 1.6649e+01, 3.0890e+00, 1.2743e+01,
        5.3025e+00, 2.2500e+00, 2.6184e+01, 2.9209e+00, 2.0067e+02, 1.6523e+00,
        1.8076e+01, 2.3868e-01, 3.4386e+00, 3.7856e+01, 2.2519e+01, 2.3118e+01,
        4.7799e-01, 3.3069e+01, 4.7264e+00, 8.9247e+00, 2.0744e+02, 1.8407e+01,
        1.0793e+01, 2.0771e+01, 3.3302e+01, 1.3606e+01, 2.7049e+01, 1.1843e+01,
        7.9743e+00, 4.4408e+00, 3.2146e+01, 6.9601e+00, 2.7490e+01, 2.6785e+01,
        6.1557e-01, 1.1444e+01, 1.9798e+02, 2.6005e+01, 3.6699e+01, 2.1133e+02,
        2.7000e+01, 3.9652e+00, 9.6139e+00, 2.3515e+01, 1.8874e+00, 1.8891e+02,
        1.6719e+01, 2.3431e+01, 1.1584e+01, 5.4309e-01, 2.5799e+01, 3.6148e+00,
        2.1827e+02, 3.5790e+02, 1.5632e+00, 8.9847e+00, 1.9462e+02, 1.4984e+01,
        6.0167e+00, 2.0550e+02, 5.8325e+00, 1.0841e+01, 5.2224e-01, 2.0415e+02,
        2.4217e+02, 6.4767e-01, 2.5156e+01, 3.0011e+01, 7.5560e+00, 1.8062e+01,
      

  imgs = [imageio.imread(fn) for fn in save_paths]
Atom "[C@H]" contains stereochemical information that will be discarded.
Atom "[C@H]" contains stereochemical information that will be discarded.


Using P: 0.5 0.5
Eigenvalues:  tensor([2.5564e+03, 2.5264e+02, 1.1354e+02, 7.7258e+01, 5.9154e+01, 5.1502e+01,
        3.5843e+01, 3.2119e+01, 2.8885e+01, 2.7814e+01, 2.2320e+01, 2.1393e+01,
        1.4720e+01, 1.1951e+01, 1.0582e+01, 8.9792e+00, 8.1457e+00, 6.0183e+00,
        5.0580e+00, 1.6878e+00])
Distance with itself:  0.0


  0%|          | 0/14 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 18.9287,  19.2830,  22.8053,  10.9986, 258.4128,  19.1951,  18.7373,
         19.2093,   6.9696,  21.2582,  10.5754,  16.0290,  14.2998,  14.1910,
         40.6690,  14.4922, 195.1182, 346.2870,  16.5707,  18.5469,  13.3130,
         16.8970,  11.5590,  13.2269,  16.0252,  16.9740, 268.4601, 212.4248,
         23.5004,  23.6335,  19.0653,  17.6577,  15.3166,  22.1760,  68.6002,
         10.0785, 239.7498,  33.8697,  12.8770, 247.1029, 243.0360,  17.0669,
         14.5595,  15.5080,   7.5222,   9.4936,   8.5671,  34.7417,   6.9181,
         14.2950,  11.0586,  40.5708,  21.3410,  20.6644,  18.3441,   9.4479,
         37.3430,  17.9358, 201.3105,  24.2541,  18.2497,  17.8648,  10.6686,
         17.9320, 202.9723,  18.1252,  16.6400, 213.7291,  11.9085,  18.1828,
         23.5172,  14.6395,  13.6438,  19.8030,  33.0587,  30.4000, 208.5706,
          8.5805,  12.6470, 231.2804,  15.1853, 203.2327,  78.0055,  18.8437,
         26.1684,   9.0077,  5

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([3.2459e+02, 1.4823e+02, 1.0246e+02, 8.4241e+01, 7.1157e+01, 5.9038e+01,
        4.6924e+01, 4.4622e+01, 3.7031e+01, 3.3530e+01, 3.1143e+01, 2.7444e+01,
        2.4120e+01, 2.1807e+01, 1.7085e+01, 1.3043e+01, 1.0289e+01, 9.5279e+00,
        7.8205e+00, 6.3381e+00, 5.0829e+00, 4.0949e+00, 3.1387e+00, 1.1896e+00,
        1.4148e-01])
Distance with itself:  0.0


  0%|          | 0/21 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([8.3069e+00, 6.2210e+02, 1.6875e+01, 1.4215e+01, 1.7542e+02, 2.4143e+02,
        8.3453e+00, 3.2740e+01, 1.3122e+00, 3.7251e-01, 2.5404e+01, 1.8281e+01,
        2.2890e+02, 1.8878e+00, 1.9957e+01, 1.6306e+01, 1.8112e+02, 3.3639e+02,
        1.7351e+02, 3.8335e+01, 5.4223e-01, 7.7685e+00, 2.4268e+02, 1.8690e+02,
        1.7724e+02, 8.7986e+00, 8.5526e+00, 6.9476e+01, 2.9893e+00, 4.1891e-01,
        1.3536e+00, 2.8397e+01, 1.2783e+00, 8.7631e-01, 7.6056e-01, 3.0000e+01,
        1.6944e+02, 3.2474e+00, 1.6526e+01, 2.1652e+00, 1.7430e+02, 2.3627e+01,
        1.8484e+02, 1.5987e+01, 1.8814e+02, 1.8665e+01, 1.7411e+02, 1.6998e+01,
        1.8429e+01, 8.9860e+00, 1.7568e+02, 3.2259e+00, 1.7446e+02, 8.4538e+00,
        2.7391e+01, 2.4318e+00, 2.6550e+01, 2.8037e+01, 9.2633e+00, 2.8251e+01,
        1.7401e+02, 2.2071e+01, 9.8338e+00, 3.3075e+02, 1.1521e+01, 3.1091e+00,
        3.1999e+02, 3.3802e+01, 2.7363e+01, 1.7375e+02, 4.5594e+00, 8.5698e+00,
      

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([3.2384e+02, 1.4840e+02, 1.0311e+02, 8.4023e+01, 7.1310e+01, 5.9411e+01,
        4.6967e+01, 4.4640e+01, 3.6951e+01, 3.3525e+01, 3.1190e+01, 2.7398e+01,
        2.4219e+01, 2.1808e+01, 1.6792e+01, 1.3057e+01, 1.0286e+01, 9.5232e+00,
        7.8145e+00, 6.2837e+00, 5.0825e+00, 4.0151e+00, 3.1122e+00, 1.2011e+00,
        1.2534e-01])
Distance with itself:  0.0


  0%|          | 0/22 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([2.1630e+01, 2.1168e+02, 2.2785e+01, 1.9818e-01, 1.2778e+00, 2.4554e+01,
        4.7813e+00, 2.8949e+01, 2.0005e+01, 3.3505e+00, 6.5718e-01, 4.9796e+00,
        1.2949e+00, 1.5898e+00, 2.8124e+01, 8.3284e+00, 1.7271e+02, 3.2503e-01,
        5.5212e+00, 8.9655e-01, 1.4116e+01, 2.1643e+01, 1.1043e+00, 2.7344e+01,
        9.8145e+00, 1.9217e+01, 1.3784e+01, 1.0398e-01, 2.3098e+00, 3.6237e+00,
        2.0767e+01, 4.8277e+00, 9.3511e+00, 1.9182e+01, 1.3576e+01, 4.2195e+01,
        1.8270e+01, 1.7611e+02, 1.2530e+00, 1.5016e+00, 4.0558e+01, 8.5975e+00,
        1.7543e+02, 4.1493e+01, 2.8239e+01, 7.6060e-01, 1.7427e+02, 1.9203e+02,
        3.5676e+02, 7.0352e+01, 2.1460e+01, 1.5030e+02, 2.7540e+01, 2.9788e+00,
        5.2248e+00, 1.8454e+02, 7.9864e+00, 2.2103e+00, 2.3229e+00, 1.4076e+01,
        1.6440e+01, 4.0244e+01, 9.0565e+00, 2.7035e+01, 1.9621e+02, 2.8085e+01,
        2.2995e+02, 3.5392e+02, 1.5290e+01, 1.3145e+01, 2.1318e+00, 1.1766e-01,
      

  imgs = [imageio.imread(fn) for fn in save_paths]
Atom "[C@H]" contains stereochemical information that will be discarded.


Using P: 0.5 0.5
Eigenvalues:  tensor([3.2583e+02, 1.4911e+02, 1.0187e+02, 8.3573e+01, 6.7021e+01, 6.0446e+01,
        5.4304e+01, 4.0363e+01, 3.8102e+01, 3.5461e+01, 3.2711e+01, 3.2232e+01,
        3.0533e+01, 2.6380e+01, 2.5355e+01, 1.9367e+01, 1.5187e+01, 1.4128e+01,
        1.1574e+01, 1.1387e+01, 9.2471e+00, 8.6327e+00, 7.9766e+00, 7.7368e+00,
        6.0829e+00, 5.7301e+00, 4.6655e+00, 3.5379e+00, 3.2639e+00, 3.8746e-01,
        2.3063e-01])
Distance with itself:  0.0


  0%|          | 0/23 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([194.3829,  24.4758,  22.8984,  19.2036,  21.8515, 187.3132,  14.3642,
         27.4306,  14.4593, 200.2936,  19.2164,  40.5139,  18.3236,  39.7337,
         19.7795,  18.3796, 200.9207,  18.7560,  40.8648,  17.2835, 255.9333,
         13.1379,  54.4107,  26.4483, 250.9463,  27.4136,  18.6285, 192.9122,
        348.6761,  27.3627,  20.8683,  12.5407,  20.8169,  66.3505,  64.8649,
        338.4313,  13.8354,  16.3186,  16.4701, 270.9882,  14.5437,  36.7191,
         44.3740,  18.2343,  13.1384,   9.6749,  12.9658,  17.4133,  16.0178,
         13.5696,  62.9141, 219.5816, 216.3140, 178.9054,  28.5316,  21.1069,
         30.0244, 201.5182,  35.6011,  33.6776,  13.3003,  22.2482,   8.8120,
        200.2350,  23.8511,  12.2209,  18.0127, 243.9013,  26.1953,  19.6947,
         13.5979,  30.3705,  17.9768, 221.3603,   9.2871,  25.0339,  17.8396,
         18.1873,  20.0592, 268.3671,  29.0173, 255.1364,  10.7250, 232.0093,
         20.7913,  25.8244,  3

  imgs = [imageio.imread(fn) for fn in save_paths]
Atom "[C@H]" contains stereochemical information that will be discarded.


Using P: 0.5 0.5
Eigenvalues:  tensor([4.9692e+02, 2.3406e+02, 1.4920e+02, 1.0312e+02, 8.6529e+01, 6.5950e+01,
        6.0228e+01, 5.7145e+01, 4.5821e+01, 3.9349e+01, 3.5346e+01, 3.3506e+01,
        3.1003e+01, 2.9789e+01, 2.6511e+01, 2.2777e+01, 2.0895e+01, 1.6938e+01,
        1.2962e+01, 1.1544e+01, 1.0910e+01, 8.6330e+00, 7.8503e+00, 6.7301e+00,
        6.0610e+00, 5.0491e+00, 4.1769e+00, 3.6184e+00, 3.4449e+00, 4.5730e-01,
        2.1950e-01])
Distance with itself:  0.0


  0%|          | 0/21 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([191.4356, 160.8669, 200.8104, 158.0526, 180.2245,  20.7129,  41.2451,
         24.5935, 202.7976,  47.0520,  30.2900, 199.2116, 153.5311, 186.6315,
        191.7044, 188.3041, 198.0774, 201.5158,  27.3854,  23.3509, 194.1261,
        198.2726, 165.2984, 195.8090, 182.6823,  21.2900, 196.7565,  24.9221,
        200.7385, 212.4787, 195.4424, 181.5557, 191.7898, 173.1231, 188.5251,
        185.6001, 193.5361,  33.9724,  53.3015, 202.4472, 205.4954, 199.4341,
        196.5958, 195.8631, 174.6996, 210.0723,  40.8376, 178.7765, 189.6479,
        119.7356, 193.7326, 186.9612, 177.7339, 195.8507, 193.2930,  45.3986,
        180.2448, 197.8853,  57.0171, 201.0200, 196.2431,  28.7776, 180.7282,
        185.7131, 171.3282, 200.8833,  47.6529, 186.7133, 177.0873, 211.9812,
        189.8401,  49.0172, 183.1154, 185.0351, 184.8805, 195.5691, 178.0925,
         20.7282, 158.6395, 199.0649,  31.2861, 183.7152,  18.4262,  34.1946,
         25.6180, 199.8760, 20

  imgs = [imageio.imread(fn) for fn in save_paths]
Atom "[C@H]" contains stereochemical information that will be discarded.
Atom "[C@@H]" contains stereochemical information that will be discarded.


Using P: 0.5 0.5
Eigenvalues:  tensor([626.0416, 433.7506, 336.6116, 137.3661,  98.9888,  69.6480,  52.2038,
         49.7612,  44.6046,  43.8237,  40.4825,  33.6721,  30.2409,  21.8902,
         20.4926,  19.1883,  13.0791,  11.6486,   9.1017,   8.5773,   8.5017,
          5.7944,   5.7035,   3.1035,   1.2612])
Distance with itself:  0.0


  0%|          | 0/22 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([206.6587, 158.8056, 207.7077, 194.5312, 194.3670, 214.8472,  37.3895,
        199.7439, 191.6622, 204.9735, 197.5089, 198.3014, 204.4177,   9.6012,
         40.7353, 194.3148, 215.6723, 211.5068, 205.1510, 194.6081, 194.6180,
        143.4168, 207.1624, 202.3459, 210.7652, 196.3003, 216.2918, 194.3354,
        194.6040, 212.5809, 183.4798, 192.0023, 194.6964,  91.6044, 212.7041,
          9.4378,  44.9312, 200.0621, 208.7712, 146.4250, 403.2624, 208.1505,
         36.1127, 176.0454, 195.7384, 194.5751, 192.4331, 201.7193, 205.0678,
         34.2961,  29.5280, 194.3414, 211.7146, 200.4823, 199.6951, 205.6239,
        207.7725, 192.6157, 198.5303, 186.7932, 204.5968, 187.5925, 193.0996,
         46.3895, 203.6173,  37.9395, 213.7169, 197.7641, 214.1711, 178.7647,
        201.5425, 194.3160, 188.6512, 201.3413, 213.0362, 195.2248, 186.8781,
        216.4100,  43.9442, 169.1299, 214.2494, 195.4055, 209.0958, 208.0230,
          1.4468,  41.9923,  4

  imgs = [imageio.imread(fn) for fn in save_paths]
Atom "[C@H]" contains stereochemical information that will be discarded.
Atom "[C@@H]" contains stereochemical information that will be discarded.


Using P: 0.5 0.5
Eigenvalues:  tensor([546.9432, 400.9243, 162.2403, 111.4039,  78.1099,  60.7245,  51.3136,
         49.1582,  44.1862,  43.7659,  40.1602,  35.7430,  28.7966,  22.2982,
         20.4936,  19.2029,  13.3734,  11.6478,  10.1530,   8.6175,   7.8488,
          5.9856,   5.7553,   2.8847,   1.5299])
Distance with itself:  0.0


  0%|          | 0/19 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([1.4010e+01, 3.9525e+01, 7.4762e+00, 1.5654e+01, 4.3279e+01, 1.8825e+02,
        2.0417e+01, 1.9530e+02, 4.9453e+01, 4.5195e+00, 2.0357e+02, 3.1457e+02,
        8.7331e+00, 8.3715e+00, 7.2988e+00, 2.2857e+01, 3.5941e+01, 1.6931e+01,
        2.2320e+01, 8.5095e+00, 6.7890e+00, 1.3216e+01, 1.0620e+01, 4.8570e+01,
        2.6015e+01, 7.9567e+00, 1.8789e+01, 1.8118e+02, 9.0268e-01, 6.5742e+01,
        1.6452e+01, 1.4960e+01, 1.6614e+01, 8.0606e+00, 1.3585e+01, 1.5849e+02,
        8.7515e+00, 6.8697e+00, 2.9484e+01, 9.9285e+00, 2.7214e+01, 4.2860e+01,
        1.9241e+01, 1.4074e+01, 1.3634e+01, 8.6968e+00, 3.0723e+01, 4.6346e+01,
        7.9252e+00, 5.5550e+00, 2.2306e+01, 1.3407e+01, 1.9279e+02, 1.4724e+01,
        7.9094e+00, 3.1877e+01, 2.0353e+02, 2.4419e+01, 1.9245e+03, 3.6380e-01,
        7.0776e+00, 6.3887e+00, 1.9334e+02, 1.6229e+01, 1.5926e+01, 2.0438e+02,
        1.3058e+01, 2.0226e+02, 1.9400e+02, 3.1186e+02, 4.6470e+01, 6.8673e+00,
      

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([504.1693, 306.1032, 176.3479, 121.5795,  82.7982,  71.4049,  61.2944,
         60.3580,  55.0145,  50.6857,  49.1435,  43.2393,  41.6648,  37.0259,
         35.7273,  33.7320,  32.8588,  31.1487,  23.7692,  21.3986,  15.0738,
         11.8694,  11.0789,   8.9959,   8.9301,   8.1622,   7.2237,   5.5069,
          3.7186,   2.9685,   1.5999,   0.6948])
Distance with itself:  0.0


  0%|          | 0/29 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([139.2929,   3.0614, 270.0374,  23.4901,  17.6434,  13.4814,  12.0824,
         34.4228, 274.9957, 137.2461, 143.3166,  15.6772, 303.1297,   7.1058,
         15.9520,   9.3473,  16.7096, 153.2159,  14.1043,  16.1006,  17.2427,
          8.4717,  18.3793,   7.1504,  11.5344,   9.1291,  30.4400,   3.9363,
          6.0777, 100.4616,  30.2380, 146.5038,  14.5797,  22.1272, 143.4563,
         12.3560,  35.8645,  15.7351,   7.2125, 265.5487,  33.4628,   8.5767,
         59.1122,  10.2885,   7.7534, 141.2931,   7.9564, 274.7801,  20.2078,
         15.6510,  13.5118,  88.7523,   4.4626, 151.9206,  14.9990,   7.0467,
          8.6558,  12.1307,  12.5136,  27.1794,  24.7869,  11.9757,   6.1624,
         12.1712,   9.6698,  12.0642,  14.4407,   7.6318,   5.5790,  40.8586,
         15.8353, 143.6189,  36.5834,  17.5628,  23.3032,   7.1503, 148.3437,
        148.8298,  12.5784, 144.7647,   8.8873,   7.1782,   8.0261,   0.8145,
        260.9471,   7.5132,  2

  imgs = [imageio.imread(fn) for fn in save_paths]


Using P: 0.5 0.5
Eigenvalues:  tensor([502.8155, 298.6773, 178.0975, 114.1106,  87.2756,  75.8096,  61.4973,
         60.3288,  57.2420,  51.2829,  49.0905,  45.0336,  41.2496,  36.9271,
         36.4922,  35.1215,  33.2402,  24.0728,  22.2082,  20.9692,  18.3779,
         13.9411,  11.0876,   9.9399,   8.9960,   7.6062,   6.9101,   6.0708,
          4.6282,   3.4856,   1.4677,   1.2330])
Distance with itself:  0.0


  0%|          | 0/24 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 11.0765,  13.0014,  55.7201, 159.0366,  16.4692,  16.3978,  21.8926,
         17.0033,  16.1354, 159.9241, 154.2230, 158.9812,  14.7880,  14.7090,
         11.5427,  13.8587,  26.5964,  22.0558,  20.1874,  11.7426,  22.0278,
         15.1480, 188.7519,  16.2256,  15.1281,  11.3518,  31.2482,   8.4513,
        162.8951, 157.2655,  40.7226, 167.6995,  11.1154,  13.9886,  15.2779,
          9.4068, 122.2656,  33.5520,  19.3186,  16.5771,  12.7302, 186.2928,
         13.5723,  13.3563,  16.1230,  17.6809,   7.0701,  14.8803,  19.8626,
         14.4455,  18.5658,  21.5611,  14.3888,  18.5018,  27.1964,  26.3636,
        159.1260,  16.4203, 155.6741, 151.6629,  12.3502,  13.2053,  14.6920,
        151.3407,   8.0835,  22.4221,  16.5629,  20.2209,  10.7578,  21.4086,
         10.3228,  29.4259,  40.0875,  13.1545,  79.9678,  16.5947, 153.5840,
        154.9364,  11.6193,  14.7986,  21.8256,  17.0117,  27.3725,  13.9638,
         11.7390,  10.6683,  2

  imgs = [imageio.imread(fn) for fn in save_paths]
Atom "[C@H]" contains stereochemical information that will be discarded.


Using P: 0.5 0.5
Eigenvalues:  tensor([296.7147, 154.3707, 120.5945,  95.0836,  74.1174,  55.4184,  48.4762,
         41.8598,  40.7629,  38.6783,  34.6783,  30.3827,  27.6987,  22.6780,
         17.0303,  15.1798,  13.1858,  10.0131,   9.8656,   9.0284,   8.1123,
          7.1309,   6.8880,   3.3428,   2.8412,   1.8204,   1.6949])
Distance with itself:  0.0


  0%|          | 0/19 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 19.2912, 238.2178, 181.5109,  18.8176,  59.1771,  18.6418,  35.0705,
        129.7398,  37.7326,  22.8743,  15.9135,  23.0517,  17.8476,  12.0954,
         22.9307,  10.5814,  30.5990,  21.0225,  47.8160,  12.5468,  11.3673,
         17.7441, 188.7151,  37.2304,  35.4377,  15.7837,  17.4152,  22.0178,
         12.6550, 228.4251,  12.3387,  12.4688,  20.0459,  12.5401,  10.8708,
         51.6391,  16.4121,  17.6660,  18.0638,  15.8724,  27.8805,  34.3974,
         65.5770,  19.0619,  13.2000,  21.4626,  26.2114,  26.6367,  11.5467,
         11.3649,  32.4635, 256.7008,  17.5773,  11.6256,  38.8010,  27.6737,
        189.0455,  30.3447,   9.6313,  10.6845,  32.2618,  20.5175, 196.5948,
         25.7637,  16.8118,  55.9110,  20.5646,  20.3107,  36.1291,  15.9319,
         42.8209,  17.0559,  19.8117,  16.8677,  20.4012,  65.8762,  27.1579,
        192.7742,  19.8582,  26.5696,  18.7081,  42.6419,  27.4950,  11.4984,
         20.5429,  24.7619,  2

  imgs = [imageio.imread(fn) for fn in save_paths]
Atom "[C@H]" contains stereochemical information that will be discarded.


Using P: 0.5 0.5
Eigenvalues:  tensor([301.3159, 154.8412, 118.8316,  94.5245,  66.9194,  57.0750,  48.3185,
         41.8240,  40.4952,  38.7948,  33.6404,  31.6542,  26.0945,  24.5510,
         18.7649,  14.8965,  14.5999,  10.2100,   9.9958,   9.8463,   7.2438,
          7.0305,   6.2176,   3.6638,   2.7842,   1.8210,   1.6928])
Distance with itself:  0.0


  0%|          | 0/21 [00:00<?, ?it/s]

V_j_plus_euclidean_distance_batch tensor([ 19.8603,  45.7126, 206.7545,   8.5191,  13.7138,  36.1219, 231.4254,
         26.9163,  15.0550,  15.0221,  15.9996, 198.5509,  26.7285,  16.3702,
         27.6427,  17.4705,  28.2746,  41.4678,  16.7496,  23.8808,  10.6541,
         10.4401,  21.4208,  13.7374,  13.0191,  21.8610,  15.1478,  22.5711,
         24.2664,  14.8955,  16.0022,  27.4977, 195.9899,  14.8585,  14.8770,
         64.7235,  23.2771,  24.5445,  13.0595,   8.0867, 215.5795,  16.7363,
         21.5451,  15.8073,  24.4564,  26.6819,  24.2514,  19.7303,  12.0967,
         12.8632,  34.3804,  15.3744,  21.5097,   6.9597,  45.8560,  15.7964,
         23.6794,  14.4936,  15.9960,  33.6636,  20.6639, 191.7201,  22.6671,
         18.3097,  20.6792,   8.1492,  11.9178,  65.9453,  16.6548,  23.3119,
         16.4234,  14.7453,  11.2789,  21.2784,  22.1620,  16.6803,  22.5631,
         14.0821,  16.1702,  29.6484,  15.8310, 242.6979,  22.2344,  22.7887,
         39.9434,  57.3862,  1

  imgs = [imageio.imread(fn) for fn in save_paths]
