## Abaltion study on important/opposing features

Abaltion study consisting in keeping only important atoms for the generation to inspect the model's behavior

1. Read molecules from dataset
2. Load initial random distributions for atoms and features
3. Read shapley values
4. Mask out atoms whose Shapley values is TODO: find criterion

Possible criteria for atom removal:
* Remove atoms with Shapley value above/below average
* Keep only the most (top k or %?) Shapley values

### Import Libraries

In [1]:
import os
os.environ["http_proxy"] = "http://web-proxy.informatik.uni-bonn.de:3128"
os.environ["https_proxy"] = "http://web-proxy.informatik.uni-bonn.de:3128"

import yaml
import numpy as np
import random
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import imageio
import networkx as nx
from pysmiles import read_smiles

# from sklearn.decomposition import PCA

import torch

from src.lightning import DDPM
from src.datasets import get_dataloader
from src.visualizer import load_molecule_xyz, load_xyz_files
from src.molecule_builder import get_bond_order
from src import const



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# density = sys.argv[sys.argv.index("--P") + 1]
# seed = sys.argv[sys.argv.index("--seed") + 1]

# Load configuration from config.yml
with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)

checkpoint = config['CHECKPOINT']
chains = config['CHAINS']
DATA = config['DATA']
prefix = config['PREFIX']
keep_frames = int(config['KEEP_FRAMES'])
P = config['P']
device = config['DEVICE'] if torch.cuda.is_available() else 'cpu'
SEED = int(config['SEED'])

In [3]:
experiment_name = checkpoint.split('/')[-1].replace('.ckpt', '')

#create output directories
chains_output_dir = os.path.join(chains, experiment_name, prefix, 'chains_' + P + '_seed_' + str(SEED) + '_ablation_study')
final_states_output_dir = os.path.join(chains, experiment_name, prefix, 'final_states_' + P + '_seed_' + str(SEED) + '_ablation_study')
os.makedirs(chains_output_dir, exist_ok=True)
os.makedirs(final_states_output_dir, exist_ok=True)

# Loading model form checkpoint 
model = DDPM.load_from_checkpoint(checkpoint, map_location=device)

# Possibility to evaluate on different datasets (e.g., on CASF instead of ZINC)
model.val_data_prefix = prefix

print(f"Running on device: {device}")
# In case <Anonymous> will run my model or vice versa
if DATA is not None:
    model.data_path = DATA

model = model.eval().to(device)
model.setup(stage='val')
dataloader = get_dataloader(
    model.val_dataset,
    batch_size=1, #@mastro, it was 32
    # batch_size=len(model.val_dataset)
)

Lightning automatically upgraded your loaded checkpoint from v1.6.3 to v2.3.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint z:\Repositories\DiffSHAPer\difflinker\models\zinc_difflinker.ckpt`


Running on device: cpu


### Set random seeds

In [4]:
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# torch.backends.cudnn.deterministic = True
np.random.seed(SEED)
random.seed(SEED)

### Utility functions

In [5]:
def arrestomomentum():
    raise KeyboardInterrupt("Debug interrupt.")

def draw_sphere_xai(ax, x, y, z, size, color, alpha):
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)

    xs = size * np.outer(np.cos(u), np.sin(v))
    ys = size * np.outer(np.sin(u), np.sin(v)) #* 0.8
    zs = size * np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(x + xs, y + ys, z + zs, rstride=2, cstride=2, color=color, alpha=alpha)

def plot_molecule_xai(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None, phi_values=None):
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    # Hydrogen, Carbon, Nitrogen, Oxygen, Flourine

    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    colors_dic = np.array(const.COLORS)
    radius_dic = np.array(const.RADII)
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if fragment_mask is None:
        fragment_mask = torch.ones(len(x))

    for i in range(len(x)):
        for j in range(i + 1, len(x)):
            p1 = np.array([x[i], y[i], z[i]])
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            draw_edge_int = get_bond_order(atom1, atom2, dist)
            line_width = (3 - 2) * 2 * 2
            draw_edge = draw_edge_int > 0
            if draw_edge:
                if draw_edge_int == 4:
                    linewidth_factor = 1.5
                else:
                    linewidth_factor = 1
                linewidth_factor *= 0.5
                ax.plot(
                    [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=line_width * linewidth_factor * 2,
                    c=hex_bg_color,
                    alpha=alpha
                )

    

    if spheres_3d:
        
        for i, j, k, s, c, f, phi in zip(x, y, z, radii, colors, fragment_mask, phi_values):
            if f == 1:
                alpha = 1.0
                if phi > 0:
                    c = 'red'

            draw_sphere_xai(ax, i.item(), j.item(), k.item(), 0.5 * s, c, alpha)

    else:
        phi_values_array = np.array(list(phi_values.values()))

        #draw fragments
        fragment_mask_on_cpu = fragment_mask.cpu().numpy()
        colors_fragment = colors[fragment_mask_on_cpu == 1]
        x_fragment = x[fragment_mask_on_cpu == 1]
        y_fragment = y[fragment_mask_on_cpu == 1]
        z_fragment = z[fragment_mask_on_cpu == 1]
        areas_fragment = areas[fragment_mask_on_cpu == 1]
        
        # Calculate the gradient colors based on phi values
        cmap = plt.cm.get_cmap('coolwarm_r') #reversed heatmap for distance-based importance
        norm = plt.Normalize(vmin=min(phi_values_array), vmax=max(phi_values_array))
        colors_fragment_shadow = cmap(norm(phi_values_array))
        
        # ax.scatter(x_fragment, y_fragment, z_fragment, s=areas_fragment, alpha=0.9 * alpha, c=colors_fragment)

        ax.scatter(x_fragment, y_fragment, z_fragment, s=areas_fragment, alpha=0.9 * alpha, c=colors_fragment, edgecolors=colors_fragment_shadow, linewidths=5, rasterized=False)

        #draw non-fragment atoms
        colors = colors[fragment_mask_on_cpu == 0]
        x = x[fragment_mask_on_cpu == 0]
        y = y[fragment_mask_on_cpu == 0]
        z = z[fragment_mask_on_cpu == 0]
        areas = areas[fragment_mask_on_cpu == 0]
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors, rasterized=False)


def plot_data3d_xai(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None, phi_values=None):
    black = (0, 0, 0)
    white = (1, 1, 1)
    hex_bg_color = '#FFFFFF' if bg == 'black' else '#000000' #'#666666'

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(projection='3d')
    ax.set_aspect('auto')
    ax.view_init(elev=camera_elev, azim=camera_azim)
    if bg == 'black':
        ax.set_facecolor(black)
    else:
        ax.set_facecolor(white)
    ax.xaxis.pane.set_alpha(0)
    ax.yaxis.pane.set_alpha(0)
    ax.zaxis.pane.set_alpha(0)
    ax._axis3don = False

    if bg == 'black':
        ax.w_xaxis.line.set_color("black")
    else:
        ax.w_xaxis.line.set_color("white")

    plot_molecule_xai(
        ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask, phi_values=phi_values
    )

    max_value = positions.abs().max().item()
    axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
    ax.set_xlim(-axis_lim, axis_lim)
    ax.set_ylim(-axis_lim, axis_lim)
    ax.set_zlim(-axis_lim, axis_lim)
    dpi = 300 if spheres_3d else 300 #it was 120 and 50

    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)
        # plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi, transparent=True)

        if spheres_3d:
            img = imageio.imread(save_path)
            img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
            imageio.imsave(save_path, img_brighter)
    else:
        plt.show()
    plt.close()

def visualize_chain_xai(
        path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None, phi_values=None
):
    files = load_xyz_files(path)
    save_paths = []

    # Fit PCA to the final molecule – to obtain the best orientation for visualization
    positions, one_hot, charges = load_molecule_xyz(files[-1], is_geom=is_geom)
    pca = PCA(n_components=3)
    pca.fit(positions)

    for i in range(len(files)):
        file = files[i]

        positions, one_hot, charges = load_molecule_xyz(file, is_geom=is_geom)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Transform positions of each frame according to the best orientation of the last frame
        positions = pca.transform(positions)
        positions = torch.tensor(positions)

        fn = file[:-4] + '.png'
        plot_data3d_xai(
            positions, atom_type,
            save_path=fn,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
            phi_values=phi_values
        )
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    dirname = os.path.dirname(save_paths[0])
    gif_path = dirname + '/output.gif'
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})

### Generation and Ablation Study

In [6]:
num_samples = 30
sampled = 0
start = 0

SAVE_VISUALIZATIONS = config['SAVE_VISUALIZATIONS']
INTIAL_DISTIBUTION_PATH = "results/explanations_" + P + "_seed_" + str(SEED)

data_list = []
for data in dataloader:

    if sampled < num_samples:
        data_list.append(data)
        sampled += 1

max_num_atoms = max(data["positions"].shape[1] for data in data_list)

# load initial distrubution of noisy features and positions
noisy_features = torch.load(INTIAL_DISTIBUTION_PATH + "/noisy_features_seed_" + str(SEED) + ".pt", map_location=device, weights_only=True)
noisy_positions = torch.load(INTIAL_DISTIBUTION_PATH + "/noisy_positions_seed_" + str(SEED) + ".pt", map_location=device, weights_only=True)

### Load data and read Shapley values

In [7]:
for data_index, data in enumerate(tqdm(data_list)):

    smile = data["name"][0]
    mol = read_smiles(smile)
    chain_with_full_fragments = None

    noisy_positions_present_atoms = noisy_positions.clone()
    noisy_features_present_atoms = noisy_features.clone()

    noisy_positions_present_atoms = noisy_positions_present_atoms[:, :data["positions"].shape[1], :]
    noisy_features_present_atoms = noisy_features_present_atoms[:, :data["one_hot"].shape[1], :]

    num_fragment_atoms = int(data["fragment_mask"].sum().item())

    #load Shapley values for Hausdorff distance
    phi_values = []
    
    
    with open(INTIAL_DISTIBUTION_PATH + "/phi_atoms_" + str(data_index) + ".txt", "r") as read_file:
        read_file.readline()
        read_file.readline()
        for row in read_file:
            if row.strip() == "":
                break
            line = row.strip().split(",")
            phi_values.append(float(line[3])) #3 for hausdorff distance-based Shapley values

    # Remove fragment atoms whose Shapley values are above the average Shapley value
    fragment_mask = data["fragment_mask"].squeeze().bool()
    linker_mask = data["linker_mask"].squeeze().bool()
    phi_values_tensor = torch.tensor(phi_values)
    average_phi_value = phi_values_tensor.mean()
    
    #retrieve indices of fragment and linker atoms from atom_mask
    fragment_atoms_indices = torch.where(fragment_mask)[0]
    linker_atoms_indices = torch.where(linker_mask)[0]
    
    #retrieve indices of fragment atoms with Shapley values above the average Shapley value
    fragment_atoms_to_remove_indices = torch.where(phi_values_tensor > average_phi_value)[0]

    fragment_atoms_indices_keep = torch.tensor([i for i in fragment_atoms_indices if i not in fragment_atoms_to_remove_indices])

    #keep only fragment_atoms_indices_keep and linker_atoms_indices
    atom_indices_to_keep = torch.cat((fragment_atoms_indices_keep, linker_atoms_indices))

    #remove atoms from molecule
    data["positions"] = data["positions"][:, atom_indices_to_keep, :]
    data["one_hot"] = data["one_hot"][:, atom_indices_to_keep, :]
    data["charges"] = data["charges"][:, atom_indices_to_keep]
    data["fragment_mask"] = data["fragment_mask"][:, atom_indices_to_keep]
    data["linker_mask"] = data["linker_mask"][:, atom_indices_to_keep]
    data["atom_mask"] = data["atom_mask"][:, atom_indices_to_keep]
    data["anchors"] = data["anchors"][:, atom_indices_to_keep]
    edge_mask_to_keep = (data["atom_mask"].unsqueeze(1) * data["atom_mask"]).flatten()
    data["edge_mask"] = edge_mask_to_keep

    #remove atoms from noisy features and positions
    noisy_positions_present_atoms = noisy_positions_present_atoms[:, atom_indices_to_keep, :]
    noisy_features_present_atoms = noisy_features_present_atoms[:, atom_indices_to_keep, :]
    
    chain_batch, node_mask = model.sample_chain(data, keep_frames=keep_frames, noisy_positions=noisy_positions_present_atoms, noisy_features=noisy_features_present_atoms)

    chain_with_full_fragments = chain_batch[0, :, :, :]

    #save and visualize chain (only for the linker use noisy positions for the initial distribution)

    arrestomomentum()

  0%|          | 0/30 [00:00<?, ?it/s]


KeyboardInterrupt: Debug interrupt.

In [10]:
# fragment_atoms_indices
noisy_features_present_atoms

tensor([[[ 0.2829, -0.8926, -0.1626, -0.8062, -0.1168, -1.6124, -0.1541,
          -0.0646],
         [ 0.6389, -0.2545, -1.2304, -1.5822,  0.6431,  0.9715, -1.3249,
          -1.0006],
         [-0.0972, -0.8326, -1.0199, -0.8756, -0.0331, -0.0130, -0.3365,
           0.5902],
         [ 0.8849,  0.2748,  0.2443,  0.5469,  0.8270, -0.2680, -0.3580,
           0.2022],
         [-1.3575,  0.0138,  0.1263, -1.6987, -0.7975, -0.9709, -1.3417,
           1.1711],
         [ 0.9939,  1.3106,  2.0404, -1.4051, -0.0063,  0.8455, -1.0851,
           0.7232],
         [-0.5468,  0.2310,  1.1613,  1.8310, -1.2884, -1.1766,  1.2506,
           0.2565],
         [-0.1214, -1.0244, -0.7861,  0.4033, -0.5858,  0.3686, -1.2556,
           1.3308],
         [-1.1738, -0.2898,  0.0059,  0.1116,  1.2815,  1.5772, -1.2627,
           0.4403],
         [ 1.3200,  2.0143,  0.6801, -0.4020,  0.0591, -1.1725,  0.0469,
          -0.1127],
         [ 0.4248,  1.5322,  1.3300, -0.6456, -1.9067, -0.0347, -0.237

In [11]:
atom_indices_to_keep

tensor([ 1,  3,  9, 10, 11, 12, 13, 17, 19, 21, 22, 23, 24, 25])