<a href="https://colab.research.google.com/github/S-AJ-H/AIMS26/blob/main/2_Message_passing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 2. Visualising message passing

Chemprop is a graph neural network (GNN) package in which "message passing" is used to learn hidden representations of molecules. In this notebook, we visualise how information propagates through a molecule via message passing.
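As a rough picture of what "message passing" means, the sketch below updates every atom's hidden vector by summing its neighbours' vectors and combining the result with the atom's own vector through shared weights and a ReLU. It is a generic, atom-centred toy with random weights on a hypothetical 7-atom graph, not Chemprop's implementation (Chemprop passes messages along directed bonds).

```python
import numpy as np

# Hypothetical 7-atom graph: a 6-membered ring (atoms 0-5) with one substituent (atom 6) on atom 3
bonds = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (3, 6)]
n_atoms, d = 7, 4                          # number of atoms, hidden size (tiny, for readability)

# Symmetric adjacency matrix: A[v, u] = 1 if atoms u and v are bonded
A = np.zeros((n_atoms, n_atoms))
for u, v in bonds:
    A[u, v] = A[v, u] = 1.0

rng = np.random.default_rng(0)
H = rng.normal(size=(n_atoms, d))          # initial atom states (stand-in for featurised atoms)
W_self = rng.normal(size=(d, d))           # applied to each atom's own state
W_msg = rng.normal(size=(d, d))            # applied to the summed neighbour states

def message_passing_step(H):
    messages = A @ H                       # row v = sum of the states of v's neighbours
    return np.maximum(0.0, H @ W_self + messages @ W_msg)   # combine, then ReLU

for _ in range(3):
    H = message_passing_step(H)            # the same weights are reused at every step

print(H.shape)                             # (7, 4): still one hidden vector per atom
```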

You will:
   
*   Explore how Chemprop calculates the initial features of molecules from their SMILES strings.
*   Visualise message passing by exploring how a small perturbation to the initial features of one atom propagates through a molecule.

Resources:
>RDKit:   
>https://rdkit.org/docs/index.html

>Chemprop:  
>https://pubs.acs.org/doi/10.1021/acs.jcim.9b00237  
>https://pubs.acs.org/doi/10.1021/acs.jcim.3c01250  
>https://chemprop.readthedocs.io/en/latest/

# 0. Install Chemprop and import libraries

In [None]:
# Install chemprop (~1 min)
!pip install chemprop -qq
import chemprop
print("Imported Chemprop version", chemprop.__version__)

# Chemistry and ML
from rdkit import Chem                                      # rdkit is used to convert SMILES to molecular graphs ("mols")
from rdkit.Chem import Draw                                 # Lets us draw molecules
from rdkit.Chem.Draw import SimilarityMaps                  # for drawing the partial charges

from chemprop import data, featurizers, models, nn          # chemprop is our GNN package
from chemprop.models import MPNN                            # Defines the overall MPNN architecture
from chemprop.data import BatchMolGraph                     # Batches the MolGraphs
from chemprop.nn.message_passing import BondMessagePassing  # Defines the message passing neural network architecture
import torch

# Misc
import pandas as pd
import numpy as np
import os
from urllib.request import urlretrieve
import copy
import io
import matplotlib
import matplotlib.pyplot as plt


# 1. Visualising message passing

> This section shows how information is shared across a molecule for increasing numbers of message passing steps.
> We use the message passing weights of a pre-trained network, CheMeleon (https://github.com/JacksonBurns/chemeleon); no training takes place.

### 1.1 Define, draw and featurise the molecule:

In [None]:
# ============================================================
# Define and draw the molecule using SMILES and RDKit
# ============================================================
smiles = "c1ncc(C)cc1"
molecule = Chem.MolFromSmiles(smiles)
display(Draw.MolToImage(molecule))

# ============================================================
# Featurise using Chemprop
# ============================================================
# Featurisation converts the molecule (an RDKit Mol, which stores atom types, bond types etc.) into a MolGraph
# (a Chemprop class holding the atom and bond feature vectors used for ML)
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()   # Chemprop featuriser built on RDKit; multi-hot encodes atoms and bonds. See: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/featurizers/index.html#chemprop.featurizers.SimpleMoleculeMolGraphFeaturizer
mol_graph = featurizer(molecule)

print("\nmolecule type:", type(molecule))
print("mol_graph type:", type(mol_graph))
print("shape of atom features (n_atoms, n_features):", mol_graph.V.shape)      # V = atom feature vectors, E = bond feature vectors, edge_index gives graph connectivity

print("\natom features:")
for i, features in enumerate(mol_graph.V):
    print(
        f"Atom {i} ({molecule.GetAtomWithIdx(i).GetSymbol()}): "
        f"{features.tolist()}"
    )
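The featuriser comment in the cell above mentions multi-hot encoding. As a minimal sketch of the idea (not Chemprop's featuriser), each atom property below is one-hot encoded with an extra slot for unlisted values, and the blocks are concatenated, so several bits end up set in each atom vector. The vocabularies `ELEMENTS` and `DEGREES` are hypothetical and far shorter than Chemprop's; the Chemprop featuriser documentation lists the real ones.

```python
# Hypothetical, heavily reduced vocabularies (Chemprop's real lists are longer)
ELEMENTS = ["C", "N", "O"]
DEGREES = [0, 1, 2, 3, 4]

def one_hot(value, choices):
    # One slot per allowed value, plus a final slot for anything not in the list
    vec = [0] * (len(choices) + 1)
    idx = choices.index(value) if value in choices else len(choices)
    vec[idx] = 1
    return vec

def featurise_atom(symbol, degree, is_aromatic):
    # Concatenate the one-hot blocks and a binary aromaticity flag into one multi-hot vector
    return one_hot(symbol, ELEMENTS) + one_hot(degree, DEGREES) + [int(is_aromatic)]

print(featurise_atom("N", 2, True))    # [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1]
print(featurise_atom("Cl", 1, False))  # "Cl" is not in ELEMENTS, so its "unknown" slot is set
```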

#### 1.1 Questions

> (a) Print the size of the atom features and bond features in `mol_graph`. Why are there twice as many edge (bond) feature vectors as there are bonds in the molecule? Hint: print the edge index.

> (b) What does the final value in each atom vector represent? Extra: take a look at the documentation to understand what each bit represents:  
> https://chemprop.readthedocs.io/en/latest/tutorial/python/featurizers/atom_featurizers.html  
> https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/featurizers/bond/index.html

#### 1.1 Answers

> (a)

> (b)

In [None]:
# ============================================================
# 1.1 Question answers
# ============================================================

### 1.2 Set up the message passing

> To visualise message passing, we apply 1, 2 and 3 message passing steps to two copies of the same molecule whose initial features differ slightly.
> We first carry out message passing using the features calculated above.
> Then, we add a small, arbitrary perturbation to the features of one atom and recalculate the hidden features.
> Comparing the two sets of hidden features shows how information flows through the molecule.
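The perturbation experiment described above can be previewed on a toy problem. This sketch uses a generic atom-centred update (not Chemprop's bond-based one) on a hypothetical 6-atom chain with random weights, and swaps Chemprop's ReLU for tanh: tanh is strictly increasing, so any change in its input shows up in its output. After `depth` steps, only atoms within `depth` bonds of the perturbed atom should change.

```python
import numpy as np

# Hypothetical 6-atom chain 0-1-2-3-4-5: the graph distance from atom 0 equals the atom index
n_atoms, d = 6, 8
A = np.zeros((n_atoms, n_atoms))
for u in range(n_atoms - 1):
    A[u, u + 1] = A[u + 1, u] = 1.0

rng = np.random.default_rng(1)
X = rng.normal(size=(n_atoms, d))                    # stand-in for the initial atom features
W_self = rng.normal(size=(d, d)) / np.sqrt(d)        # random weights, shared across steps
W_msg = rng.normal(size=(d, d)) / np.sqrt(d)

def embed(X, depth):
    # Generic atom-centred message passing with a tanh nonlinearity
    H = X
    for _ in range(depth):
        H = np.tanh(H @ W_self + (A @ H) @ W_msg)
    return H

X_pert = X.copy()
X_pert[0] += 1e-2                                    # add epsilon to every feature of atom 0

toy_deltas = {}
for depth in (1, 2, 3):
    toy_deltas[depth] = np.linalg.norm(embed(X_pert, depth) - embed(X, depth), axis=1)
    print(depth, np.round(toy_deltas[depth], 4))     # non-zero only for atoms 0..depth
```

Each extra step lets information travel one bond further, which is the pattern to look for in the plots in 1.4.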

In [None]:
# ============================================================
# 1.2a. Set up message passing architectures
# ============================================================

# get feature dimensions
d_v = mol_graph.V.shape[1]
d_e = mol_graph.E.shape[1]
print("number of features per atom, per bond:", d_v,",", d_e)

# Set up three message passing architectures with different numbers of steps ("depths"). These start with randomly initialised weights, which are replaced by pre-trained weights in 1.2b.
mp_1 = BondMessagePassing(d_v=d_v, d_e=d_e, d_h=2048, depth=1, undirected=False)   # d_h = hidden dimension size; 2048 required for our pre-trained weights
mp_2 = BondMessagePassing(d_v=d_v, d_e=d_e, d_h=2048, depth=2, undirected=False)
mp_3 = BondMessagePassing(d_v=d_v, d_e=d_e, d_h=2048, depth=3, undirected=False)

# ============================================================
# 1.2b. Import pre-trained weights from the CheMeleon model
# ============================================================

# import pre-trained weights (takes ~20sec)
if not os.path.exists("chemeleon_mp.pt"):
    print("Downloading CheMeleon weights...")
    urlretrieve(r"https://zenodo.org/records/15460715/files/chemeleon_mp.pt","chemeleon_mp.pt",)
chemeleon_mp = torch.load("chemeleon_mp.pt", weights_only=False)

# Load the weights into our three message passing architectures. Within each model, every message passing step reuses the same weights, so one state dict fits all three depths.
mp_1.load_state_dict(chemeleon_mp['state_dict'])
mp_2.load_state_dict(chemeleon_mp['state_dict'])
mp_3.load_state_dict(chemeleon_mp['state_dict'])
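The comment above notes that every message passing step uses the same weights. The NumPy sketch below illustrates this with a simplified directed-bond (D-MPNN-style) update in the spirit of the Chemprop papers: hidden states live on directed edges, and the message for edge u->v excludes the state coming back from v. It is a hypothetical re-implementation with random weights, no bias or dropout, and a toy 3-atom graph; it is not Chemprop's `BondMessagePassing` code, and the way `depth` is counted here is our assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 3-atom chain 0-1-2; each bond is stored as two directed edges
edge_index = np.array([[0, 1, 1, 2],          # source atom of each directed edge
                       [1, 0, 2, 1]])         # target atom of each directed edge
rev_edge_index = np.array([1, 0, 3, 2])       # position of the reverse edge (u->v <-> v->u)
n_atoms, n_edges = 3, edge_index.shape[1]

atom_dim, bond_dim, hidden_dim = 5, 3, 8      # toy sizes (the notebook uses d_h=2048)
V = rng.normal(size=(n_atoms, atom_dim))      # atom features
E = rng.normal(size=(n_edges, bond_dim))      # bond features, one row per directed edge
W_i = rng.normal(size=(atom_dim + bond_dim, hidden_dim))    # input weights
W_h = rng.normal(size=(hidden_dim, hidden_dim))             # ONE matrix, reused at every step
W_o = rng.normal(size=(atom_dim + hidden_dim, hidden_dim))  # readout weights

def relu(x):
    return np.maximum(x, 0.0)

def sum_into_atoms(H):
    # Row a = sum of the hidden states of all directed edges that point into atom a
    out = np.zeros((n_atoms, H.shape[1]))
    np.add.at(out, edge_index[1], H)
    return out

def directed_messages(H):
    # Message for edge u->v: everything flowing into u, minus what came from v itself
    return sum_into_atoms(H)[edge_index[0]] - H[rev_edge_index]

def bond_message_passing(depth):
    H_0 = np.concatenate([V[edge_index[0]], E], axis=1) @ W_i   # edge u->v starts from x_u and e_uv
    H = relu(H_0)
    for _ in range(1, depth):                 # in this sketch, depth=d means d-1 updates with W_h
        H = relu(H_0 + directed_messages(H) @ W_h)
    atom_messages = sum_into_atoms(H)         # readout: gather incoming edge states per atom
    return relu(np.concatenate([V, atom_messages], axis=1) @ W_o)

for depth in (1, 2, 3):
    print(depth, bond_message_passing(depth).shape)   # one hidden vector per atom
```

Because `W_h` is a single matrix applied at every step, changing `depth` only changes how many times it is applied; this is consistent with one CheMeleon state dict loading into `mp_1`, `mp_2` and `mp_3` alike.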

#### 1.2 Questions

> (a) Why do we use `BondMessagePassing` instead of `AtomMessagePassing`? Hint: look at the papers linked at the top.

### 1.3 Compute hidden embeddings

> Using the pre-trained weights loaded above, we now calculate the hidden embeddings of every atom.

In [None]:
# ============================================================
# 1.3a. Compute baseline embeddings
# ============================================================
batch_graph = BatchMolGraph([mol_graph])                                        #chemprop message passing requires batches, even for a single MolGraph
with torch.no_grad():                                                           #calculate hidden features for each atom
    h1 = mp_1(batch_graph)
    h2 = ################################         # hidden representations with 2 message passing steps
    h3 = ################################         # hidden representations with 3 message passing steps

h_baseline = [h1, h2, h3]                                                       #list of features (tensors)

# ============================================================
# 1.3b. Perturb one atom's initial features
# ============================================================
source_atom = 0                                                                 # atom 0: the ring carbon bonded to the nitrogen
epsilon = 1e-2                                                                  # define an arbitrary perturbation

mol_graph_pert = copy.deepcopy(mol_graph)                                       # copy the original features into a new MolGraph

mol_graph_pert.V[source_atom] = mol_graph_pert.V[source_atom] + epsilon         # add the perturbation to every feature of the source atom
print("Initial features:\n", mol_graph.V[source_atom])
print("\nPerturbed initial features:\n", mol_graph_pert.V[source_atom])

# ============================================================
# 1.3c. Compute perturbed embeddings
# ============================================================
batch_graph_pert = BatchMolGraph([mol_graph_pert])
with torch.no_grad():
    h1_p = mp_1(batch_graph_pert)
    h2_p = ################################     # perturbed hidden representations with 2 message passing steps
    h3_p = ################################     # perturbed hidden representations with 3 message passing steps

h_perturbed = [h1_p, h2_p, h3_p]

# ============================================================
# 1.3d. Compute the difference between the original and the perturbed:
# ============================================================
deltas = []
for h_b, h_p in zip(h_baseline, h_perturbed):
    delta = torch.norm(h_p - h_b, dim=1)                                        # calculate L2 norm; collapses the 2048-dimensional vector into a single scalar per atom
    deltas.append(delta.cpu().numpy())
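The comment in 1.3a notes that Chemprop's message passing expects a batch, even for a single MolGraph. A common way to batch graphs for a GNN, and we assume `BatchMolGraph` works in a similar spirit, is to merge several graphs into one disconnected graph: atom feature rows are stacked, each graph's edge indices are shifted by the number of atoms before it, and a vector records which molecule every atom came from. The helper `merge_graphs` below is hypothetical and, for brevity, ignores bond features and reverse-edge indices.

```python
import numpy as np

def merge_graphs(graphs):
    """Merge (V, edge_index) pairs into one disconnected graph (hypothetical helper)."""
    Vs, edge_indices, atom_to_mol = [], [], []
    offset = 0                                    # number of atoms already in the batch
    for mol_id, (V, edge_index) in enumerate(graphs):
        Vs.append(V)                              # atom feature rows are simply stacked
        edge_indices.append(edge_index + offset)  # shift atom indices into the batch numbering
        atom_to_mol.extend([mol_id] * len(V))     # which molecule each atom came from
        offset += len(V)
    return np.concatenate(Vs), np.concatenate(edge_indices, axis=1), np.array(atom_to_mol)

# Usage: a 2-atom graph and a 3-atom chain become one 5-atom graph with two components
V1, e1 = np.ones((2, 3)), np.array([[0, 1], [1, 0]])
V2, e2 = np.zeros((3, 3)), np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
V_all, edge_index_all, atom_to_mol = merge_graphs([(V1, e1), (V2, e2)])
print(edge_index_all)    # the second graph's atom indices are shifted by 2
print(atom_to_mol)       # [0 0 1 1 1]
```

Since no edge connects the two components, messages never cross between molecules, so batching does not change any molecule's embeddings.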

### 1.4 Plot the difference

> The difference shows how the small perturbation to the features of atom 0 (the ring carbon next to the nitrogen) propagates through the molecule.

In [None]:
# Shared colour range across all three panels, so the plots are directly comparable
all_vals = np.concatenate(deltas)
vmin, vmax = np.min(all_vals), np.max(all_vals)

def draw_molecule(mol, values, vmin, vmax, cmap_name, size=(300,300)):
    # values
    vals = np.array(values)
    vals = (vals - vmin) / (vmax - vmin + 1e-12)                                # 1e-12 prevents possible zero denominator
    vals = vals ** 0.5                                                          # square-root scaling makes small differences more visible

    # colours
    cmap = matplotlib.colormaps.get_cmap(cmap_name)
    atom_colors = {i: tuple(cmap(v)[:3]) for i, v in enumerate(vals)}           # RDKit wants {atom_idx: (R,G,B)}

    # other options
    drawer = Draw.MolDraw2DCairo(size[0], size[1])
    opts = drawer.drawOptions()
    opts.addAtomIndices = True
    opts.highlightBondWidthMultiplier = 0
    drawer.DrawMolecule(mol, highlightAtoms=list(atom_colors.keys()), highlightAtomColors=atom_colors)

    drawer.FinishDrawing()
    return drawer.GetDrawingText()
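The two scaling lines in `draw_molecule` first min-max normalise the deltas onto [0, 1] and then take a square root, which stretches small values upwards so that weak propagation is still visible in the colour map. A worked example with made-up delta values:

```python
import numpy as np

raw = np.array([0.0, 0.01, 0.25, 1.0])        # made-up per-atom deltas
lo, hi = raw.min(), raw.max()
scaled = (raw - lo) / (hi - lo + 1e-12)       # min-max normalisation onto [0, 1]
boosted = scaled ** 0.5                       # square root, as in draw_molecule
print(np.round(boosted, 3))                   # approximately [0, 0.1, 0.5, 1]
```

A delta of 0.01 now sits 10% of the way along the colour map rather than 1%, so faint spread is easier to see.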

In [None]:
titles = ["after 1 step", "after 2 steps", "after 3 steps"]

fig, axes = plt.subplots(1, 3, figsize=(10, 3))

for ax, title, delta_vals in zip(axes, titles, deltas):

    png_bytes = draw_molecule(
        molecule,
        delta_vals,
        vmin=vmin,
        vmax=vmax,
        cmap_name="plasma"
    )

    img = plt.imread(io.BytesIO(png_bytes))
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()

#### 1.4 Questions

> (a) Why are atoms #2 and #5 the same colour after 1 pass, but different colours from each other after 2 and 3 passes?  

> (b) Why does message passing satisfy node permutation equivariance?

> (c) Think about our polymer system again. We use two monomers to represent the alternating polymer chain (A-B-A-B-A-B-...). Comment on how this representation might be an issue for a message passing network. What could we do to rectify this?


#### 1.4 Answers

> (a)  

> (b)  

> (c)  