In [None]:
!pip install rdkit py3Dmol
!pip install selfies
!pip install crem
!pip install rdkit

<a target="_blank" href="https://colab.research.google.com/github/RodrigoAVargasHdz/CHEM-4PB3/blob/w2024/Course_Notes/Week%2010/Introduction_to_SELFIES.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [None]:
# Import necessary libraries
import tqdm
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, Draw, rdDepictor
from rdkit.Chem import PandasTools, Descriptors
import py3Dmol

from IPython.display import display, HTML
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG

import selfies as sf
import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib
import matplotlib.pyplot as plt

from crem.crem import mutate_mol, grow_mol, link_mols
import matplotlib.pyplot as plt
import seaborn as sns

# Mutations of molecules #


Mutations in SMILES strings are a fundamental technique in computational chemistry and drug discovery for exploring the vast chemical space in search of novel molecules with desirable properties. <br>
By introducing random or targeted changes to the SMILES representation of molecules, researchers can generate new molecular structures that may exhibit improved biological activity, selectivity, or pharmacokinetic properties. This process, inspired by the principles of natural evolution and genetic variation, enables the iterative refinement of molecular candidates through cycles of mutation, evaluation, and selection. It's a powerful approach for optimizing existing drugs, discovering new therapeutic candidates, and understanding the relationship between molecular structure and biological function, ultimately accelerating the pace of pharmaceutical development and the discovery of novel materials. <br>

In this example, we’re not only converting the SMILES string `CCO` into a molecule object but also ensuring its validity through sanitization. Because we set `sanitize=False` in `Chem.MolFromSmiles`, RDKit skips its validity checks while parsing, so we need to explicitly call `Chem.SanitizeMol` to validate the molecule.


In [None]:
from rdkit import Chem

# Any SMILES string can be used here
smile = 'CCO'

# Parse WITHOUT sanitization so that chemistry validation is a separate, explicit step
molecule = Chem.MolFromSmiles(smile, sanitize=False)

# MolFromSmiles returns None when the string cannot even be parsed
if molecule is not None:
    print("Molecule object created successfully!")

    # By default SanitizeMol raises on failure, which made the "invalid chemistry" branch
    # unreachable. With catchErrors=True it instead returns the flag of the step that
    # failed, or SANITIZE_NONE on success. It is called only once.
    sanitize_result = Chem.SanitizeMol(molecule, catchErrors=True)
    if sanitize_result == Chem.SanitizeFlags.SANITIZE_NONE:
        print("Molecule sanitized successfully!")
    else:
        print("invalid chemistry")
else:
    print("invalid SMILES")



The following code demonstrates how to replace a dummy atom in a molecule with a real atom using RDKit. This example assumes you have a molecule with at least one dummy atom (denoted as * in SMILES) that you want to replace with another atom or group:

In [None]:
def replace_ghost_atom(smiles, replacement_smiles, atom_idx=None):
    """
    Replaces a dummy atom in a molecule with a specified atom.

    Only the FIRST atom of ``replacement_smiles`` is used. It takes the place of the dummy
    atom and keeps the dummy atom's bonds. Any further atoms in ``replacement_smiles`` are
    ignored, so a group such as "C=O" contributes just its carbon.

    :param smiles: SMILES string of the original molecule containing a dummy atom (``*``).
    :param replacement_smiles: SMILES string whose first atom replaces the dummy atom.
    :param atom_idx: Index of the dummy atom to replace. If None, or if it is not the index
        of a dummy atom, the first dummy atom found is replaced.
    :return: SMILES string of the modified molecule, or None if either SMILES cannot be
        parsed, the replacement has no atoms, the molecule has no dummy atom, or the
        result fails sanitization (e.g. a valence violation).
    """
    mol = Chem.MolFromSmiles(smiles)
    replacement_mol = Chem.MolFromSmiles(replacement_smiles)

    # MolFromSmiles returns None for unparsable input; an empty SMILES parses to a
    # molecule with zero atoms, which would make GetAtomWithIdx(0) below fail.
    if mol is None or replacement_mol is None or replacement_mol.GetNumAtoms() == 0:
        return None

    # Find the dummy atoms
    dummy_atoms = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetSymbol() == '*']
    if not dummy_atoms:
        return None

    # Select the dummy atom to replace (fall back to the first one)
    dummy_atom_idx = atom_idx if atom_idx in dummy_atoms else dummy_atoms[0]

    # Replace the dummy atom in place; its bonds are kept
    editable_mol = Chem.EditableMol(mol)
    editable_mol.ReplaceAtom(dummy_atom_idx, replacement_mol.GetAtomWithIdx(0))
    modified_mol = editable_mol.GetMol()

    # With catchErrors=True SanitizeMol reports the failing step instead of raising, so a
    # chemically invalid replacement returns None as documented rather than crashing.
    if Chem.SanitizeMol(modified_mol, catchErrors=True) != Chem.SanitizeFlags.SANITIZE_NONE:
        return None

    return Chem.MolToSmiles(modified_mol)

# Example usage: try every candidate replacement on a benzene ring that carries one dummy atom
original_smiles = "C1=CC=CC=C1*"  # Benzene with a dummy atom
possible_characters = ['c','C', 'O', 'N', '=', '#', '(', ')','=O','#N','=N']
good_molecules = [original_smiles]
for candidate in possible_characters:
    modified_smiles = replace_ghost_atom(original_smiles, candidate)
    print(f"Original SMILES: {original_smiles}", f"Modified SMILES: {modified_smiles}")
    # Failed replacements come back as None and are not collected
    if modified_smiles is None:
        continue
    good_molecules.append(modified_smiles)


In [None]:
print(good_molecules)

# Parse without sanitization so every collected SMILES can be drawn as-is
mol_good_molecules = list(map(lambda smi: Chem.MolFromSmiles(smi, sanitize=False), good_molecules))

img = Draw.MolsToGridImage(
    mol_good_molecules,
    molsPerRow=10,
    subImgSize=(500, 500),
    legends=good_molecules,
)
img


# SELFIES #
# **Introduction to SELFIES**

Chemistry and computational analysis often grapple with the complexity of molecular representation. SELFIES (Self-Referencing Embedded Strings) is a robust molecular string representation system designed for unambiguous molecular encoding. Its key advantage lies in facilitating direct input into machine learning models, such as generative models, and ensuring the validity of the outputs.

## **Key Points of SELFIES**

- Robustness: Every SELFIES string, even a randomly generated one, decodes to a valid molecular structure, overcoming the validity issues of other formats such as SMILES.
- Generative Model Compatibility: SELFIES offers a diverse array of valid molecules, enhancing generative model outcomes.
- Adjustable Constraints: The framework can impose both meaningful and arbitrary rule sets, showcasing its adaptability.
- Syntax and Grammar: SELFIES uses a context-free grammar with specific symbols for atoms, rings, branches, and more.

<br>

<!DOCTYPE html>
<html>
<head>
    <style>
        .centered-image {
            display: block;
            margin-left: auto;
            margin-right: auto;
            width: 50%;
        }
    </style>
</head>
<body>

<a href="https://aspuru.substack.com/p/molecular-graph-representations-and" target="_blank">
    <img src="https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fbucketeer-e05bbc84-baa3-437e-9518-adb32be77984.s3.amazonaws.com%2Fpublic%2Fimages%2F6e37c68a-a71c-4ffa-ba5c-ffe1e237a5c3_1600x859.png"
         alt="SMILES and SELFIES compared."
         class="centered-image">
</a>

<br>
<figcaption align = "center"><b>Figure 1 - SMILES and SELFIES compared. Figure from
Alan Aspuru-Guzik.</b></figcaption>


## **Pros and Cons of SELFIES:**

### **Pros of SELFIES:**
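- **Consistency**: Every SELFIES string decodes to a chemically valid molecule, which aids in maintaining database integrity and simplifying molecule comparisons. Note that, like SMILES, SELFIES is not canonical on its own: one molecule can have several SELFIES strings. If you need a unique string per molecule, canonicalize the SMILES (e.g. with RDKit) before encoding.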
- **Generative Model Synergy**: Training with SELFIES results in a higher diversity of valid, novel molecules compared to SMILES, benefiting the exploration of chemical space.

### **Cons of SELFIES:**
- **Complexity**: The SELFIES syntax is more complex and less human-readable than SMILES, which can be a barrier for manual handling and interpretation.
- **Adoption**: While gaining popularity, SELFIES is newer and less widespread than SMILES, potentially leading to compatibility issues with existing systems.



## **Rules:**
1. The main string is derived using a rule set such that the number of valence bonds per atom does not exceed physical limits.
2. The symbol after a `Branch` is interpreted as the number of
SELFIES symbols derived inside the branch.
3. The symbol after `Ring` is interpreted as a number too, indicating that the current atom is connected to the `(N + 1)`st previous atom.
Thereby all information in the string (except the ring
closures) is local, which allows for efficient derivation rules.


## **SELFIES vs. SMILES Representation**

In chemical informatics, SELFIES and SMILES are two methods used for representing molecules as strings. While SMILES is the more traditional format, SELFIES is a newer format designed to overcome some of the limitations of SMILES.


## **Symbol Indexing in SELFIES**

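The SELFIES format uses indexed symbols, such as `[Branch1]` or `[Ring1]`, which read the symbol(s) that follow them as numbers (indices). Each symbol has a fixed index in a lookup table; all symbols outside that table are assigned index 0 by default. For example:
- `[Branch1]` followed by a symbol with index $Q$ indicates a branch of $Q+1$ symbols
- `[Ring2]` denotes a ring closure whose distance is read from the next two symbols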







## **Using SELFIES:**
SELFIES excels at generating random molecules and at one-hot encoding for machine learning, using valid SELFIES alphabets and external information such as ring structures.

> The following code has been adopted and added from [Akshat Nigam &
Aspuru-Guzik group's GitHub](https://github.com/aspuru-guzik-group/selfies_tutorial/tree/master)

In [None]:
# List of molecules with their common names and SMILES representation
molecules = [
    ("Benzene", "c1ccccc1"),
    ("Caffeine", "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"),
    ("Aspirin", "CC(=O)OC1=CC=CC=C1C(=O)O"),
    ("Methane", "C"),
    ("Ethanol", "CCO"),
    ("Glucose", "C(C1C(C(C(C(O1)O)O)O)O)O"),
    ("Acetic Acid", "CC(=O)O"),
    ("Penicillin", "CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C"),
    # Naproxen = 2-(6-methoxynaphthalen-2-yl)propanoic acid (stereocentre not specified).
    # The previous SMILES described a phenyl/benzoic-acid compound, not naproxen.
    ("Naproxen", "CC(C1=CC2=C(C=C1)C=C(C=C2)OC)C(=O)O"),
]

# Round-trip every molecule SMILES -> SELFIES -> SMILES and collect one record per molecule
records = []
for common_name, smiles in molecules:
    encoded_selfies = sf.encoder(smiles)
    decoded_smiles = sf.decoder(encoded_selfies)
    records.append({
        "Common Name": common_name,
        "SMILES": smiles,
        "Encoded SELFIES": encoded_selfies,
        "Decoded SMILES": decoded_smiles,
    })

df = pd.DataFrame(records)

pd.set_option('display.max_colwidth', None)

df


In [None]:
# Retrieve the semantically robust SELFIES alphabet. The library returns a set, and the
# iteration order of a set of strings changes between Python sessions (hash
# randomization), so sort it to get a reproducible table.
robust_alphabets = sorted(sf.get_semantic_robust_alphabet())

# Creating a DataFrame for the robust alphabets
df = pd.DataFrame(robust_alphabets, columns=['Robust Alphabets'])

# Print the number of characters and show the DataFrame
print(f"Number of Characters in Robust Alphabets: {len(robust_alphabets)}")
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', 70)

df

In [None]:
# A subset of SELFIES symbols to sample from. The benzene SELFIES fragment
# '[C][=C][C][=C][C][=C][Ring1][=Branch1]' is listed 10 times so that random draws pick it
# far more often than any single symbol. (A missing comma after '[#Branch3]' used to
# silently merge it with '[N]' into one element.)
alphabet_random = ['[=O]', '[O]', '[Ring2]', '[#Branch2]', '[#Branch3]',
            '[N]', '[=Ring3]', '[#B]', '[H]',
            '[#P]', '[Cl]', '[F]', '[Ring1]', '[P]', '[=Branch2]',
            '[Br]', '[=S]', '[=N]', '[#N]', '[S]',
            '[C]', '[I]', '[Ring3]', '[=C]', '[#C]'] + ['[C][=C][C][=C][C][=C][Ring1][=Branch1]'] * 10


def get_random_molecule_using_selfies(num_random, robust_alphabets, min_smi_len=5, max_smi_len=25):
    """Generate random molecules by decoding random concatenations of SELFIES symbols.

    :param num_random: number of random SELFIES strings to draw.
    :param robust_alphabets: sequence of SELFIES symbols (or symbol fragments) sampled
        with replacement. This argument is now actually used; previously the global
        ``alphabet_random`` was read instead.
    :param min_smi_len: minimum number of draws per string (inclusive).
    :param max_smi_len: maximum number of draws per string (inclusive).
    :return: list of decoded SMILES strings. Empty decodings are dropped, so the list may
        hold fewer than ``num_random`` entries.
    """
    collect_random = []

    for _ in range(num_random):
        # random.randint is inclusive on both ends; the former `max_smi_len + 1` overshot by one
        random_len = random.randint(min_smi_len, max_smi_len)
        random_alphabets = np.random.choice(robust_alphabets, random_len)
        random_selfies = ''.join(random_alphabets)

        collect_random.append(sf.decoder(random_selfies))

    return [x for x in collect_random if x]

In [None]:
random_smiles = get_random_molecule_using_selfies(8, alphabet_random)

# Decoded SELFIES are expected to be valid SMILES, so RDKit can parse them directly
mols = list(map(Chem.MolFromSmiles, random_smiles))
img = Draw.MolsToGridImage(
    mols[:8],
    molsPerRow=4,
    subImgSize=(500, 500),
    legends=random_smiles,
)
img

> While some of these molecules may contain unreasonable connections between functional groups, they are all **valid** molecules.

### **Using SELFIES to explore analogues of Acetaminophen**
By using a variation of the `get_random_molecule_using_selfies` function, let's explore variations of Acetaminophen: we start from the SELFIES string of acetaminophen and append random SELFIES symbols to it before decoding.

```
def get_random_molecule_using_selfies(num_random, robust_alphabets):
    max_smi_len = 25
    min_smi_len = 2
    collect_random = []
    
    for _ in range(num_random):
        random_len = random.randint(min_smi_len, max_smi_len+1)
        random_alphabets = list(np.random.choice(alphabet, random_len))
        random_selfies = ''.join(x for x in random_alphabets)
        
        collect_random.append(sf.decoder(random_selfies))
    
    return [x for x in collect_random if x != '']
```




In [None]:
# SELFIES symbols appended at random to the acetaminophen SELFIES string.
# (A missing comma after '[#Branch1]' used to merge it with '[Br]' into one element.)
acetaminophen_alphabet = ['[=O]', '[O]', '[Ring2]', '[#Branch2]', '[#Branch3]',
                          '[N]', '[=Ring3]', '[Ring1]', '[P]', '[H]', '[#Branch1]',
                          '[Br]', '[=S]', '[=N]', '[#N]', '[S]', '[=Ring2]', '[=Branch1]',
                          '[C]', '[Ring3]', '[=C]', '[#C]' , '[#C+1]' , '[=N+1]']

acetaminophen_smiles = 'CC(=O)Nc1ccc(cc1)O'


def get_random_acetaminophen(num_random, min_extra_len=3, max_extra_len=25):
    """Generate acetaminophen analogues by appending random SELFIES symbols.

    The SELFIES of ``acetaminophen_smiles`` is used as a prefix. Between ``min_extra_len``
    and ``max_extra_len`` (inclusive) symbols drawn from ``acetaminophen_alphabet`` are
    appended to it, and the result is decoded back to SMILES.

    :param num_random: number of analogues to attempt.
    :param min_extra_len: minimum number of appended symbols (inclusive).
    :param max_extra_len: maximum number of appended symbols (inclusive).
    :return: list of non-empty decoded SMILES strings (may be shorter than ``num_random``).
    """
    # The prefix is identical for every draw, so encode it once
    acetaminophen_selfies = sf.encoder(acetaminophen_smiles)
    collect_random = []

    for _ in range(num_random):
        random_len = random.randint(min_extra_len, max_extra_len)
        random_alphabets = np.random.choice(acetaminophen_alphabet, random_len)

        decoded_smiles = sf.decoder(acetaminophen_selfies + ''.join(random_alphabets))
        # Keep only successful, non-empty decodings. The former
        # `is not None or decoded_smiles` test was true for any non-None value, so it
        # never filtered anything out.
        if decoded_smiles:
            collect_random.append(decoded_smiles)

    return collect_random

In [None]:
random_acetaminophen = get_random_acetaminophen(8)

# Parse every analogue and show up to 8 of them, labelled with their SMILES
ace_mols = list(map(Chem.MolFromSmiles, random_acetaminophen))
img = Draw.MolsToGridImage(
    ace_mols[:8],
    molsPerRow=4,
    subImgSize=(500, 500),
    legends=random_acetaminophen,
)
img

In [None]:
# Highlight the acetaminophen core inside every generated analogue
acetaminophen_mol = Chem.MolFromSmiles(acetaminophen_smiles)
mols = [Chem.MolFromSmiles(ace_smi) for ace_smi in random_acetaminophen]

highlights = []
for candidate_mol in mols:
    core_match = candidate_mol.GetSubstructMatch(acetaminophen_mol)
    # GetSubstructMatch yields an empty tuple when the core is absent: highlight nothing then
    highlights.append(core_match if core_match else [])

# Draw the molecules with highlighted substructure
img = Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(400, 400), highlightAtomLists=highlights)
img

### **One-Hot Encoding**

In [None]:
# SMILES -> SELFIES for every generated analogue
encoded_one_hot_random = [sf.encoder(smi) for smi in random_acetaminophen]

# Vocabulary = every symbol seen in the data plus the '[nop]' padding symbol, sorted
molecule_alphabet_random = sorted(sf.get_alphabet_from_selfies(encoded_one_hot_random) | {'[nop]'})

# Every encoding is padded to the longest SELFIES in the set
max_length = max(map(sf.len_selfies, encoded_one_hot_random))

# index -> symbol and symbol -> index lookups
vocab_itos_random = dict(enumerate(molecule_alphabet_random))
symbol_to_index_random = {symbol: index for index, symbol in vocab_itos_random.items()}

# (SELFIES, label encoding, one-hot encoding) for every molecule
encoded_data = [
    (
        selfies_str,
        sf.selfies_to_encoding(selfies_str, vocab_stoi=symbol_to_index_random,
                               pad_to_len=max_length, enc_type='label'),
        sf.selfies_to_encoding(selfies_str, vocab_stoi=symbol_to_index_random,
                               pad_to_len=max_length, enc_type='one_hot'),
    )
    for selfies_str in encoded_one_hot_random
]

df_encoded_one_hot_random = pd.DataFrame(
    encoded_data, columns=['Molecule SELFIES', 'Label Encoded', 'One-Hot Encoded'])

df_encoded_one_hot_random.head(8)

In [None]:
# Index -> symbol lookup built in the previous cell; used below to decode encodings back to SELFIES
print(vocab_itos_random)

In [None]:
# Visualize the one-hot matrix of the first encoded molecule
# (element 2 of each encoded_data tuple is the one-hot encoding)
first_molecule_one_hot = encoded_data[0][2]

# Vocabulary symbols in index order, used as column labels
ordered_symbols = [vocab_itos_random[i] for i in range(len(vocab_itos_random))]

fig = plt.figure(figsize=(13, 8))
sns.set(font_scale=1.2)  # Increase font scale
ax = sns.heatmap(first_molecule_one_hot, cmap="viridis", cbar=True, linewidths=1)
# One tick per vocabulary column, centred in its cell
ax.set_xticks(np.arange(len(ordered_symbols)) + 0.5)
ax.set_xticklabels(ordered_symbols, rotation=90)
ax.set_xlabel("Symbols in Vocab", fontsize=14)
ax.set_ylabel("Position in Molecule", fontsize=14)
ax.set_title("One-Hot Encoding of First Molecule", fontsize=16)
plt.show()

In [None]:
# Pick two encoded molecules at random and decode their one-hot matrices back
selected_indices = random.sample(range(len(encoded_data)), 2)
selected_molecules = [encoded_data[i][2] for i in selected_indices]  # One-hot encoded

# '[nop]' is only padding, so strip it from the decoded SELFIES strings
decoded_selfies = [
    sf.encoding_to_selfies(one_hot, vocab_itos_random, enc_type="one_hot").replace('[nop]', '')
    for one_hot in selected_molecules
]

# SELFIES -> SMILES
decoded_smiles = list(map(sf.decoder, decoded_selfies))

for idx, (selfies_str, smiles_str) in enumerate(zip(decoded_selfies, decoded_smiles), start=1):
    print(f"Molecule {idx}:")
    print("SELFIES:", selfies_str)
    print("SMILES:", smiles_str)

# **Classification of molecules using CNN**

**Goals**
1. Introduction to SMILES (encoded as SELFIES) as a molecular representation for ML models.
2. Use ML models, more specifically **Convolutional Neural Networks**, to classify molecules.



## **Data loading and analysis**

The Tox21 dataset from MoleculeNet is another baseline to test our model [37]. It contains the activities of 7,831 compounds against 12 biological targets or pathways: nuclear receptor (NR)-androgen receptor (AR)-ligand-binding domain (LBD), NR-AR, NR-aryl hydrocarbon receptor (AhR), NR-Aromatase, NR-estrogen receptor (ER), NR-ER-LBD, NR-peroxisome proliferator-activated receptor (PPAR)-gamma, SR-antioxidant response element (ARE), stress response (SR)-ATPase Family AAA Domain Containing 5 (ATAD5), SR-heat shock factor response element (HSE), SR-mitochondrial membrane potential (MMP), and SR-p53 [38]. Similar to the CYP450 dataset, Tox21 includes many missing labels. We formulate the chemical toxicity prediction using the Tox21 dataset as a multi-label classification problem. The label for a target is positive if the chemical compound is toxic through interaction with that target. Multi-label means that one chemical compound can have more than one target. In this notebook we keep only the NR-AR label and treat it as a binary classification problem.

Text from the [link](https://bio-protocol.org/exchange/minidetail?type=30&id=12688188)

In [None]:
data_url = "https://github.com/RodrigoAVargasHdz/CHEM-4PB3/raw/w2024/Course_Notes/data/tox21.csv"
data_raw = pd.read_csv(data_url)
print('Total data:', data_raw.count())
display(data_raw.head())

# Keep only the SMILES and the NR-AR label
data_full = data_raw[['smiles', 'NR-AR']]

print('Possible values of NR-AR:', data_full['NR-AR'].unique())

# Drop unlabeled molecules and cast the label to int in one step. This produces a new
# frame instead of assigning into a derived one, which triggered SettingWithCopyWarning.
data = data_full.dropna().astype({'NR-AR': int})

print('Possible values of NR-AR:', data['NR-AR'].unique())
print(data.head())
data.hist(column='NR-AR')

> Plot some molecules that are **inactive** against the NR-AR target (label 0), followed by some **active** ones (label 1).

In [None]:
# Add the RDKit molecule column ('ROMol', the default name) once; re-running the cell
# does not recompute it
if 'ROMol' not in data.columns:
    PandasTools.AddMoleculeColumnToFrame(data, 'smiles')

# Molecules inactive against NR-AR (label 0)
nr_ar_inactive = data[data['NR-AR'] == 0]
nr_ar_inactive_grid = PandasTools.FrameToGridImage(nr_ar_inactive[:9], column='ROMol', legendsCol='smiles',
                                                   molsPerRow=3, subImgSize=(300, 300))
nr_ar_inactive_grid

In [None]:
# Add the RDKit molecule column ('ROMol', the default name) only if it is not there yet
if 'ROMol' not in data.columns:
    PandasTools.AddMoleculeColumnToFrame(data, 'smiles')

# Molecules active against NR-AR (label 1); previously stored under the misleading name HIV_active_0
nr_ar_active = data[data['NR-AR'] == 1]
nr_ar_active_grid = PandasTools.FrameToGridImage(nr_ar_active[:9], column='ROMol', legendsCol='smiles',
                                                 molsPerRow=3, subImgSize=(300, 300))
nr_ar_active_grid

From the previous class, we saw that a molecule written in the SMILES notation can be transformed into a *"figure"* using a dictionary of characters and the one-hot encoding transformation. <br>

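To create this dictionary, we first need to define the maximum number of symbols in a molecule's string representation, i.e. the length of its (SELFIES) text. Before that, we balance the dataset, because active molecules (label 1) are the minority class.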

In [None]:
data_negative = data[data['NR-AR'] == 0]
data_positive = data[data['NR-AR'] == 1]

# Balanced dataset: oversample the minority positives by duplicating them once, then
# undersample the negatives to the same size. random_state makes the draw reproducible.
data_positive_d = pd.concat([data_positive, data_positive], axis=0)
n_positive = len(data_positive_d)
data_negative_red = data_negative.sample(n_positive, random_state=0)

# ignore_index gives every row (including the duplicated positives) a unique index
balanced_data = pd.concat([data_negative_red, data_positive_d], axis=0, ignore_index=True)
display(balanced_data.head())
print(balanced_data.count())

In [None]:
import sys

# Print numpy arrays in full instead of summarizing them with "..." (can produce very long output)
np.set_printoptions(threshold=sys.maxsize)


In [None]:
# Collect every SELFIES symbol that appears in the balanced dataset
global_alphabet_set = {
    symbol
    for smiles_str in balanced_data['smiles'].to_list()
    for symbol in sf.split_selfies(sf.encoder(smiles_str))
}
global_alphabet_list = sorted(global_alphabet_set)

# '.' separates disconnected fragments and '[nop]' is the padding symbol; append them if missing
for extra_symbol in ('.', '[nop]'):
    if extra_symbol not in global_alphabet_list:
        global_alphabet_list.append(extra_symbol)

global_alphabet = global_alphabet_list

# Show the vocabulary as a table
df = pd.DataFrame(global_alphabet, columns=['SELFIES'])

print(f"Number of Characters in Global Alphabet: {len(global_alphabet)}")
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', 70)
df

In [None]:
# global_alphabet = ['[=O]', '[O]', '[Ring2]', '[#Branch2]', '[#Branch3]', '[=Ring1]', '[=Ring2]',
#                      '[N]', '[=Ring3]', '[#B]', '[H]', '[#P]', '[Cl]', '[#Branch1]', '[As]',
#                      '[F]', '[Ring1]', '[P]', '[=Branch2]', '[Br]', '[NH1]', '[=N-1]', '[NaH1]',
#                      '[=S]', '[=N]', '[#N]', '[S]', '[C]', '[I]', '[=N+1]', '[Mn+1]', '[Cl-1]',
#                      '[Ring3]', '[=C]', '[#C]', '[Ag]', '[Pt]', '[K]', '[PH1]', '[N+1]',
#                      '[Rh]', '[Zn]', '[Hg]', '[Fe]', '[Te]', '[Ca]', '[Se]', '.', '[=P]',
#                      '[Li]', '[Mg]', '[Ge]', '[Cu]', '[Mo]', '[Mn]', '[Si]', '[O-1]',
#                      '[Ni]', '[W]', '[U]', '[Branch1]', '[Branch2]', '[=Branch1]','[C@]',
#                    '[C@H1]','[C@@H1]','[C@@]', '[N-1]','[/C]','[\\C]','[NH1+1]', '[/N]',
#                    '[/O]','[Sb]','[\\C@H1]','[\\N+1]','[Mn+2]','[AlH3]','[=Mo]','[/C@@H1]',
#                    '[Sn]','[Ba+2]','[=S+1]','[Cu+2]','[Na]','[=Bi]','[Fe+2]','[B-1]','[B]',
#                    '[Pd]','[Au-1]','[S-1]','[\\Cl]','[Zn+2]','[\\O]','[K+1]','[Br-1]','[In]',
#                    '[\\S]','[Na+1]','[SbH6+3]', '[SiH1]', '[\\N]',  '[N@@+1]', '[I-1]', '[Hg+2]',
#                    '[PbH2+2]','[Fe+3]','[Fe-1]','[NH3+1]','[Bi]','[=Se]','[NH4+1]','[Co+2]',
#                    '[/C@H1]','[P+1]','[S+1]','[Cr+2]','[CH1-1]','[Dy]','[Ni+2]','[TlH2+1]',
#                    '[nop]']



def smiles_to_one_hot_and_list(smile, max_length, alphabet=global_alphabet):
    """
    Converts a SMILES string to a one-hot encoded "image" and a list of SELFIES symbols.

    :param smile: SMILES string of the molecule.
    :param max_length: length the SELFIES encoding is padded to (with '[nop]').
    :param alphabet: SELFIES vocabulary. It is sorted before indexing, so the caller's
        order does not matter. It must contain every symbol of the molecule plus '[nop]'.
    :return: tuple ``(one_hot_encoded, symbol_to_index, vocab_itos, selfies_tokens)``:
        - ``one_hot_encoded``: 2-D numpy array with vocabulary symbols along axis 0 and
          SELFIES positions along axis 1 (the transpose of ``sf.selfies_to_encoding``'s
          output). Axis 0 is zero-padded up to ``max_length`` rows when the vocabulary
          is smaller than ``max_length``.
        - ``symbol_to_index`` / ``vocab_itos``: symbol -> index and index -> symbol maps.
        - ``selfies_tokens``: the molecule's SELFIES symbols in order, without padding.
    """
    selfies_str = sf.encoder(smile)

    molecule_alphabet = sorted(alphabet)

    symbol_to_index = {s: i for i, s in enumerate(molecule_alphabet)}
    vocab_itos = {i: s for s, i in symbol_to_index.items()}
    one_hot_encoded = sf.selfies_to_encoding(selfies_str,
                                             vocab_stoi=symbol_to_index,
                                             pad_to_len=max_length,
                                             enc_type='one_hot')

    # split_selfies handles '.' fragment separators and empty strings. The former
    # split on '][' glued '[C].[Na]' into one token and turned '' into '[]'.
    selfies_output_list = list(sf.split_selfies(selfies_str))

    # (positions, symbols) -> (symbols, positions)
    one_hot_encoded = np.array(one_hot_encoded).T

    # Zero-pad axis 0 (the vocabulary axis) up to max_length rows.
    # NOTE(review): this pads symbols, not positions (selfies_to_encoding already padded
    # the positions). The CNN's fully connected input size depends on the resulting
    # image shape, so the behavior is kept as-is.
    if one_hot_encoded.shape[0] < max_length:
        padding = np.zeros((max_length - one_hot_encoded.shape[0], one_hot_encoded.shape[1]))
        one_hot_encoded = np.vstack((one_hot_encoded, padding))

    return one_hot_encoded, symbol_to_index, vocab_itos, selfies_output_list

# Pick one random molecule from the balanced dataset and build its one-hot "image"
rnd_data = balanced_data.sample(1)
smile = rnd_data['smiles'].item()
print(smile)
selfies_str = sf.encoder(smile)

# Pad 10 symbols beyond the molecule's own SELFIES length
max_l = sf.len_selfies(selfies_str) + 10

one_hot_encoded, symbol_to_index, vocab_itos, selfies_tokens = smiles_to_one_hot_and_list(smile, max_l)

print('SELFIES image:', one_hot_encoded.shape)
for lookup in (symbol_to_index, vocab_itos, selfies_tokens):
    print(lookup)

# Draw the molecule twice, labelled with its SMILES and its SELFIES
mol = AllChem.MolFromSmiles(smile)
img = Draw.MolsToGridImage([mol, mol], molsPerRow=2, subImgSize=(500, 500),
                           legends=[smile, selfies_str])
img

In [None]:
# Heatmap of the one-hot "image": rows are vocabulary symbols, columns are SELFIES positions
ordered_symbols = [vocab_itos[i] for i in range(len(vocab_itos))]

plt.figure(figsize=(19, 10))
sns.set(font_scale=1.2)  # Increase font scale
heatmap = sns.heatmap(one_hot_encoded, cmap="viridis", cbar=True, linewidths=1,
                      yticklabels=ordered_symbols, xticklabels=selfies_tokens)

# seaborn already places one tick at the centre (i + 0.5) of every labelled row and column,
# so only the labels are restyled. Re-setting the y ticks at integer positions (as before)
# shifted every symbol label half a cell away from its row.
heatmap.set_xticklabels(selfies_tokens, rotation=90, fontsize=8)
heatmap.set_yticklabels(ordered_symbols, fontsize=8)

heatmap.set_xlabel("SELFIES Tokens", fontsize=14)
heatmap.set_ylabel("Symbols in Vocab", fontsize=14)
heatmap.set_title("One-Hot Encoding of Molecules", fontsize=16)
plt.show()


Let's create a Data loader for this dataset.
1. We will transform each smile into its "figure" representation
2. Because we are working with two classes, 'active' and 'inactive', we also need to transform the label/class to one-hot encoding.

In [None]:
# Longest SELFIES (in symbols) over the balanced dataset; molecules that cannot be encoded are reported and skipped
max_length = 0
for smiles_str in balanced_data['smiles']:
    try:
        n_symbols = sf.len_selfies(sf.encoder(smiles_str))
    except sf.EncoderError:
        print(f"Skipping SMILES string due to encoding error: {smiles_str}")
        continue
    max_length = max(max_length, n_symbols)

print('Max SELFIES length', max_length)


In [None]:
class CustomDataset(Dataset):
    """Map-style torch dataset of (SELFIES one-hot image, one-hot label) pairs.

    Items are built on demand by ``smiles_to_one_hot_and_list``; nothing is cached.

    :param smiles_all: indexable sequence of SMILES strings.
    :param labels_all: indexable sequence of integer labels in {0, 1}, aligned with ``smiles_all``.
    :param max_length: padding length passed to ``smiles_to_one_hot_and_list``.
    :param global_alphabet: SELFIES vocabulary passed to ``smiles_to_one_hot_and_list``.
    """

    def __init__(self, smiles_all, labels_all, max_length, global_alphabet):
        self.smiles = smiles_all
        self.labels = labels_all
        self.max_length = max_length
        self.global_alphabet = global_alphabet
        # Symbol -> index map over the sorted vocabulary (kept for inspection; __getitem__ does not use it)
        self.symbol_to_index = {symbol: position
                                for position, symbol in enumerate(sorted(global_alphabet))}

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        smiles_str = self.smiles[idx]
        label = self.labels[idx]
        image, *_ = smiles_to_one_hot_and_list(smiles_str, self.max_length, self.global_alphabet)
        # Leading axis added by unsqueeze(0) acts as the single input channel for Conv2d
        image_tensor = torch.from_numpy(image).float().unsqueeze(0)
        label_one_hot = F.one_hot(torch.tensor(label), num_classes=2)
        return image_tensor, label_one_hot

In [None]:
data_full = balanced_data

# Split on unique SMILES (80% train / 20% validation of the distinct molecules).
# Sampling the train and validation sets independently (as before) made them overlap,
# and the oversampled positives are duplicated rows, so even a row-level split would
# leak the same molecule into both sets.
unique_smiles = data_full['smiles'].drop_duplicates().sample(frac=1, random_state=0)
n_train_smiles = int(0.8 * len(unique_smiles))
is_train = data_full['smiles'].isin(set(unique_smiles.iloc[:n_train_smiles]))
tr_dataset = data_full[is_train]
val_dataset = data_full[~is_train]

train_size = len(tr_dataset)
validation_size = len(val_dataset)
print('Training data', train_size)
print('Test data', validation_size)
assert train_size + validation_size == len(data_full)
assert not set(tr_dataset['smiles']) & set(val_dataset['smiles']), "train/validation overlap"

training_data = CustomDataset(
    tr_dataset['smiles'].to_list(), tr_dataset['NR-AR'].to_list(), max_length, global_alphabet)
train_dataloader = DataLoader(training_data, batch_size=512, shuffle=True)
train_molecules, train_labels = next(iter(train_dataloader))

print('Size of the training data')
print(train_molecules.shape)
print(train_labels.shape)

## CNN ##

In [None]:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
class CNN(nn.Module):
    """Two-block convolutional classifier for single-channel SELFIES one-hot "images".

    Layers: [Conv(1->6, k5, pad 2) -> ReLU -> MaxPool 2] -> [Conv(6->16, k5) -> ReLU ->
    MaxPool 2] -> flatten -> Linear(->512) -> ReLU -> Linear(->128) -> ReLU ->
    Linear(->n_classes).

    :param image_shape: optional ``(height, width)`` of the input image, used to size the
        first fully connected layer. When None, the original hard-coded 16 * 59 * 59
        features are used, which corresponds to e.g. a 244 x 244 input.
    :param n_classes: number of output logits (default 2).
    """

    def __init__(self, image_shape=None, n_classes=2):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        n_flat = 16 * 59 * 59 if image_shape is None else self._flat_features(image_shape)
        self.fc1 = nn.Sequential(
            nn.Linear(n_flat, 512),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
        )
        # Output layer: one logit per class
        self.fc3 = nn.Linear(128, n_classes)

    @staticmethod
    def _flat_features(image_shape):
        """Number of features after conv2 + pooling for an input of size ``(height, width)``."""
        height, width = image_shape
        # conv1 keeps the size (k5, pad 2), each pool halves it (floor), conv2 (k5, pad 0) removes 4
        height = (height // 2 - 4) // 2
        width = (width // 2 - 4) // 2
        return 16 * height * width

    def forward(self, x):
        """Return raw logits ``(batch, n_classes)`` for ``x`` of shape ``(batch, 1, H, W)``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)

In [None]:
def train(model, training_data, training_epochs=60, device='cuda', batch_size=64, lr=3e-4):
    """Train ``model`` with Adam and cross-entropy and return the per-batch loss history.

    :param model: classifier returning logits of shape ``(batch, n_classes)``.
    :param training_data: dataset yielding ``(inputs, targets)`` pairs; targets are passed
        to CrossEntropyLoss as float class probabilities (e.g. one-hot labels).
    :param training_epochs: number of passes over ``training_data``.
    :param device: device that the model and every batch are moved to.
    :param batch_size: mini-batch size of the internal DataLoader.
    :param lr: Adam learning rate.
    :return: list with one entry per epoch, each a list of that epoch's batch losses (floats).
    """
    # tqdm.auto picks the notebook widget when available. A bare `import tqdm` does not
    # guarantee that the `tqdm.notebook` submodule the old code used is loaded.
    from tqdm.auto import tqdm as progress_bar

    model.to(device)  # parameters must live on the same device as the batches
    model.train()
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    trainloader = torch.utils.data.DataLoader(
        training_data, batch_size=batch_size, shuffle=True)

    iterator = progress_bar(range(training_epochs))

    loss_trajectory = []
    for _ in iterator:
        current_loss = []
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)

            optimizer.zero_grad()
            loss = loss_function(outputs, targets.float())
            loss.backward()
            optimizer.step()

            current_loss.append(loss.item())
        if current_loss:
            iterator.set_postfix(loss=sum(current_loss) / len(current_loss))
        loss_trajectory.append(current_loss)
    return loss_trajectory

In [None]:
import torch
import torchvision

# Record which library versions produced the results below.
for lib_label, lib_module in (("PyTorch", torch), ("TorchVision", torchvision)):
    print(f"{lib_label} version: {lib_module.__version__}")

# Run on the GPU when PyTorch can see one; otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import os

# Build the classifier on the selected device.
cnn = CNN()
cnn.to(device)
print(train_molecules.shape)

# Train once and cache the weights. A later Restart & Run All reuses the
# checkpoint instead of re-training, and every cell below sees a trained `cnn`
# (with training commented out, the analysis ran on random weights).
PATH = 'checkpoint.pth'  # relative to the working directory (/content on Colab)
if os.path.exists(PATH):
    # map_location: a GPU-saved checkpoint must also load on a CPU-only runtime.
    # weights_only: a state_dict needs no arbitrary unpickling.
    cnn.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))
else:
    loss_trj = train(cnn, training_data, 100, device)
    torch.save(cnn.state_dict(), PATH)

    # Mean +- std of the per-batch losses within each epoch.
    loss_trj = np.array(loss_trj)
    m = np.mean(loss_trj, axis=1)
    std = np.std(loss_trj, axis=1)

    plt.errorbar(np.arange(m.shape[0]), m, yerr=std, errorevery=(0, 6))
    plt.xlabel('Epoch')
    plt.ylabel('Cross-entropy loss')
    plt.show()


## Save and Load modules ##

In [None]:
# List the contents of the working directory, then print its path
# (e.g. to confirm the checkpoint file is present before loading it).
!ls
!pwd


In [None]:
# Load the saved weights into a fresh CNN instance.
model_new = CNN()
# Relative path: the training cell writes 'checkpoint.pth' into the working
# directory (/content on Colab), and this also works outside Colab.
PATH = "checkpoint.pth"

# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
# weights_only=True refuses to unpickle arbitrary objects; a state_dict
# needs nothing beyond tensors.
state_dict = torch.load(PATH, map_location=device, weights_only=True)
model_new.load_state_dict(state_dict)

# The model and its inputs must be on the same device.
model_new.to(device)
model_new.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    labels = model_new(train_molecules.to(device))
print(labels)

# Prediction #

In [None]:
val_data = CustomDataset(
    val_dataset['smiles'].to_list(), val_dataset['NR-AR'].to_list(), max_length, global_alphabet)
# One batch holding the whole validation set. Order does not matter for the
# metrics, so keep it deterministic.
val_dataloader = DataLoader(val_data, batch_size=len(val_dataset), shuffle=False)

val_molecules, val_labels = next(iter(val_dataloader))
val_molecules, val_labels = val_molecules.to(device), val_labels.to(device)

loss_function = nn.CrossEntropyLoss()

# Evaluation only: eval() disables training-time behaviour, and no_grad() avoids
# building an autograd graph over the entire validation set.
cnn.eval()
with torch.no_grad():
    val_labels_pred = cnn(val_molecules)
    loss = loss_function(val_labels_pred, val_labels.float())
print(loss)


In [None]:
from sklearn.metrics import f1_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import torch
# Calculate F1 score
# Reduce the per-class columns (axis 1) of targets and scores to integer class ids.
# softmax is monotonic, so its argmax matches that of the raw logits.
actual_labels = np.argmax(val_labels.detach().cpu().numpy(), axis=1)
class_probs = val_labels_pred.softmax(axis=1).detach().cpu().numpy()
predictions = np.argmax(class_probs, axis=1)

f1 = f1_score(actual_labels, predictions, average='weighted')
print("F1 Score:", f1)

# Indices of validation molecules whose predicted class differs from the true one.
i0 = np.flatnonzero(predictions != actual_labels)
n_total = actual_labels.shape[0]
print('Total data:', n_total)
print('% of missclassified data', 100*(i0.shape[0]/n_total))

In [None]:
# Generate confusion matrix
# Rows are the true classes and columns the predicted ones, with raw counts annotated.
cm = confusion_matrix(actual_labels, predictions)

fig, ax = plt.subplots(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
ax.set_title('Confusion Matrix')
plt.show()


# Plot the filters #

In [None]:
import matplotlib
import matplotlib.pyplot as plt

def plot_filter(filter_params):
    """Show every 2-D kernel of a convolution weight tensor as a grayscale image.

    Args:
        filter_params: array of shape (num_filters, n_in_ch, fltr_x, fltr_y),
            e.g. ``conv.weight.detach().cpu().numpy()``.

    Values are min-max scaled over the whole tensor, so all kernels share one
    gray scale. The grid has one row per filter and one column per input
    channel. The scaled shape is printed before plotting.
    """
    filter_params = np.asarray(filter_params)
    f_min, f_max = filter_params.min(), filter_params.max()
    f_range = f_max - f_min
    # A constant tensor has zero range; avoid 0/0 NaNs and draw it as uniform zeros.
    if f_range > 0:
        filters = (filter_params - f_min) / f_range
    else:
        filters = np.zeros_like(filter_params, dtype=float)

    num_filters, n_in_ch, fltr_x, fltr_y = filters.shape
    print(filters.shape)

    # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(
        num_filters, n_in_ch,
        figsize=(n_in_ch * 1.5, num_filters * 1.5),
        squeeze=False,
    )
    for i in range(num_filters):
        for j in range(n_in_ch):
            ax = axes[i, j]
            # filters[i, j] is already 2-D. Squeezing it would turn a 1 x k
            # kernel into a 1-D array, which imshow rejects.
            ax.imshow(filters[i, j], cmap='gray')
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Plot the learned kernels of every convolutional layer (weight tensors only; biases skipped).
conv_weights = [
    (name, param)
    for name, param in cnn.named_parameters()
    if 'conv' in name and 'weight' in name
]
for name, param in conv_weights:
    print(name, param.size())
    plot_filter(param.detach().cpu().numpy())

# Motifs #
**Paper:** Convolutional neural network based on SMILES representation of compounds for detecting chemical motif
[link](https://doi.org/10.1186/s12859-018-2523-5)

In [None]:
from torch import nn, Tensor

class VerboseExecution(nn.Module):
    """Wrap a model so that calling the wrapper prints each top-level child's output shape.

    Shape-printing hooks are attached only for the duration of a call made
    through this wrapper and are removed afterwards. Direct calls to the
    wrapped model (prediction, feature extraction, ...) therefore stay silent.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @staticmethod
    def _shape_printer(name: str):
        """Build a forward hook that prints ``name`` and the output's shape."""
        def hook(_module, _inputs, output):
            print(f"{name}: {output.shape}")
        return hook

    def forward(self, x: Tensor) -> Tensor:
        handles = [
            layer.register_forward_hook(self._shape_printer(name))
            for name, layer in self.model.named_children()
        ]
        try:
            return self.model(x)
        finally:
            # Detach the hooks even if the forward pass raises.
            for handle in handles:
                handle.remove()

In [None]:
from typing import Dict, Iterable, Callable

class FeatureExtractor(nn.Module):
    """Run a model and capture the outputs of selected named submodules.

    Args:
        model: network whose intermediate activations are wanted.
        layers: submodule names as they appear in ``model.named_modules()``
            (e.g. ``"conv2"``). An unknown name raises ``KeyError``.

    Calling the extractor runs ``model`` on the input and returns a dict that
    maps each requested name to that submodule's most recent output. The same
    dict object is returned on every call.
    """

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        # Materialise once. A one-shot iterable (e.g. a generator) would
        # otherwise be exhausted by the dict below, and no hooks would be registered.
        self.layers = list(layers)
        self._features = {layer: torch.empty(0) for layer in self.layers}

        # Build the name -> module map once instead of once per requested layer.
        modules = dict(self.model.named_modules())
        # Keep the handles so the hooks can be detached via remove_hooks().
        self._handles = [
            modules[layer_id].register_forward_hook(self.save_outputs_hook(layer_id))
            for layer_id in self.layers
        ]

    def save_outputs_hook(self, layer_id: str) -> Callable:
        """Return a forward hook that stores its module's output under ``layer_id``."""
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def remove_hooks(self) -> None:
        """Detach every forward hook this extractor registered on ``model``."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        return self._features

In [None]:
# Encode the whole dataset. shuffle=False keeps batch row i aligned with
# data_full.iloc[i] (assuming CustomDataset preserves input order), which the
# motif search below relies on.
full_data = CustomDataset(
    data_full['smiles'].to_list(),
    data_full['NR-AR'].to_list(),
    max_length,
    global_alphabet,
)
for size in (len(data_full), len(full_data)):
    print(size)

# A single batch holding (at most) the first 1240 molecules.
full_dataloader = DataLoader(full_data, batch_size=1240, shuffle=False)
molecules, labels = next(iter(full_dataloader))
molecules = molecules.to(device)
labels = labels.to(device)
print(molecules.shape, labels.shape)


In [None]:
# Push one molecule (the slice keeps the batch dimension) through the wrapped
# CNN so that each top-level layer prints its output shape.
first_molecule = molecules[:1]
verbose_cnn = VerboseExecution(cnn)
_ = verbose_cnn(first_molecule)

In practice, each dimension of the SMILES convolution fingerprint (SCFP) may have a
different value scale, making it difficult to compare across dimensions when
identifying large-contribution filters. Thus, we normalize SCFP by the following
procedure. First, we compute SCFP for all compounds in a given dataset. Then, we
look at the values in the global max-pooling layer and calculate their mean and
variance for each filter over all compounds. Finally, we transform SCFP into
Z-scores for each dimension by using the mean and the variance of the corresponding
filter. For detecting chemical motifs, we focus on those dimensions of SCFP with
Z-scores larger than 2.58 (i.e., above the 99.5th percentile of the standard normal
distribution, the two-sided 99% level). Note that this normalization procedure is
only used for detecting chemical motifs, not for training and prediction.

In [None]:
cnn_features = FeatureExtractor(cnn, layers=["conv2"])
# The activations are only analysed, never back-propagated. Skip the autograd
# graph; otherwise it would stay alive with the stored 1240-molecule output.
with torch.no_grad():
    z = cnn_features(molecules)
# One row per molecule; all conv2 output values are flattened into the columns.
z = torch.flatten(z['conv2'], 1)

In [None]:
print(z.shape)
# Standardise every column of z (one flattened conv2 output value) across all
# molecules in the batch: per-column mean and std, then a Z-score.
mean_ = torch.mean(z, 0, keepdim=False)
std_ = torch.std(z, 0, keepdim=False) + 1E-6  # epsilon avoids 0/0 on constant columns
print(mean_.shape, std_.shape)
print(std_)

# Score ALL molecules. The previous z[:2, :] scored only the first two, so the
# per-column maxima and the "best molecule" argmax in the next cell ignored the
# rest of the dataset.
# NOTE(review): abs() also flags unusually *low* activations, while the paper
# thresholds the signed Z-score; keep or drop abs() deliberately.
z_score = torch.abs((z - mean_) / std_)
print(z_score)

max_zscore = torch.max(z_score, dim=0)[0].detach().cpu()

# NOTE(review): the x-axis indexes flattened conv2 values, which equal filters
# only if conv2's spatial output is 1x1; confirm.
plt.scatter(np.arange(max_zscore.shape[0]), max_zscore)
plt.hlines(2.58, 0, max_zscore.shape[0], color='k', ls='--')
plt.xlabel('Filter number')
plt.ylabel('Z score')
plt.show()

In [None]:
# 1.0 marks every column whose maximum Z-score exceeds 2.58, 0.0 the rest.
i0 = torch.where(max_zscore > 2.58, 1., 0.)
hits = torch.nonzero(i0 == 1.).flatten().tolist()
# Report only the first flagged column: the molecule that scores highest on it.
if hits:
    jj0 = hits[0]
    indx = int(torch.argmax(z_score[:, jj0]).detach().cpu())
    print('best molecule', indx)
    print('filter number', jj0)
    print(data_full.iloc[indx])


In [None]:
# Raw string: the backslashes are SMILES stereo bonds, not Python escapes
# ('\C' is an invalid escape sequence and warns on recent Python versions).
si = r'CC(=O)OC/C=C(\C)CC/C=C(\C)CCC=C(C)C'
data_i = CustomDataset([si], [0], max_length, global_alphabet)
datai_loader = DataLoader(data_i)
mi, _ = next(iter(datai_loader))
print(mi.shape)
with torch.no_grad():  # activations are only displayed
    zi = cnn_features(mi.to(device))
zi = zi['conv2']
print(zi.shape)

# One panel per conv2 output channel: activations for this molecule, not weights.
num_filters = zi[0].shape[0]
n_cols = int(np.ceil(num_filters / 2))  # ceil: an odd channel count still gets its panel
fig, axes = plt.subplots(2, n_cols, figsize=(3 * n_cols, 6), squeeze=False)
axes = axes.flatten()

for i in range(num_filters):
    ax = axes[i]
    # atleast_2d: imshow needs a 2-D array, even for a single-row map.
    ax.imshow(np.atleast_2d(zi[0, i].cpu().numpy()), cmap='gray')
    ax.axis('off')
# Blank the spare panel that appears when the channel count is odd.
for ax in axes[num_filters:]:
    ax.axis('off')

plt.tight_layout()
plt.show()


In [None]:
selfies_str = sf.encoder(si)

# Encoding length: the molecule's SELFIES token count plus 10 extra positions.
max_l = sf.len_selfies(selfies_str) + 10

# Encode `si` itself rather than re-typing the SMILES. The matrix is then
# guaranteed to describe the molecule that max_l was computed from. The
# duplicated literal also contained invalid '\C' escape sequences.
one_hot_encoded, symbol_to_index, vocab_itos, selfies_tokens = smiles_to_one_hot_and_list(si, max_l)

# Vocabulary symbols in index order, used as the row labels.
ordered_symbols = [vocab_itos[i] for i in range(len(vocab_itos))]

plt.figure(figsize=(19, 10))
sns.set(font_scale=1.2)  # Increase font scale (applies to the global seaborn theme)
heatmap = sns.heatmap(one_hot_encoded, cmap="viridis", cbar=True, linewidths=1,
                      yticklabels=ordered_symbols, xticklabels=selfies_tokens)  # y-axis: ordered_symbols, x-axis: selfies_tokens

heatmap.set_xticklabels(selfies_tokens, rotation=90, fontsize=8)
# seaborn centres each tick on its cell (positions i + 0.5). Ticks at integer
# positions would sit on cell borders and shift every label by half a row.
heatmap.set_yticks(np.arange(len(ordered_symbols)) + 0.5)
heatmap.set_yticklabels(ordered_symbols, fontsize=8)

heatmap.set_xlabel("SELFIES Tokens", fontsize=14)
heatmap.set_ylabel("Symbols in Vocab", fontsize=14)
heatmap.set_title("One-Hot Encoding of Molecules", fontsize=16)
plt.show()