In [1]:
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw, AllChem, DataStructs
from matplotlib.gridspec import GridSpec
import io
from PIL import Image
import os
# Suppress RDKit warnings
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import Tuple
import numpy as np
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import RDLogger
from rdkit.Chem import RemoveHs
from datetime import datetime

import torch
import numpy as np
from torch_geometric.data import DataLoader
import os
import json
from tqdm import tqdm
import pickle
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from typing import Tuple, List, Optional
import copy
from dataclasses import dataclass

# Suppress RDKit warnings.
# NOTE(review): appears redundant — 'rdApp.*' is already disabled in the import
# section above, which should also cover the 'rdApp.warning' channel.
RDLogger.DisableLog('rdApp.warning')
# Run timestamp formatted YYYYmmdd_HHMMSS. It is not referenced by the code shown
# here — presumably consumed by later cells (e.g. output file names). TODO confirm.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')



In [2]:
class MolecularFeatureExtractor:
    """Turn SMILES strings / RDKit molecules into PyTorch Geometric graphs.

    Node features come in two tensors:
      * ``x_cat`` (long, ``(num_atoms, 2)``): position of the atomic number in
        ``atom_list`` and position of the chiral tag in ``chirality_list``.
      * ``x_phys`` (float, ``(num_atoms, NUM_PHYS_FEATURES)``): ExactMolWt and
        MolLogP of the lone bracket element ``[Symbol]``, formal charge,
        hybridization (as int), aromatic flag, total H count, total valence,
        degree.

    Edge features (float, ``(num_edges, NUM_BOND_FEATURES)``): bond-type index,
    bond-direction index, conjugated flag, rotatable flag and 3D bond length
    (0.0 without a 3D conformer). Every bond is stored in both directions.
    """

    # Width of the physical feature vector built by calc_atom_features
    # (2 element descriptors + 6 atom properties). All fallbacks use it so the
    # per-atom rows always stack into a rectangular tensor.
    NUM_PHYS_FEATURES = 8
    # Width of the edge feature vector built by get_bond_features.
    NUM_BOND_FEATURES = 5

    # Per-symbol cache of (ExactMolWt, MolLogP). Both values depend only on the
    # element symbol, so sharing the cache across instances is safe.
    _ELEMENT_DESCRIPTOR_CACHE = {}

    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def _element_descriptors(self, symbol: str) -> Tuple[float, float]:
        """Return ``(ExactMolWt, MolLogP)`` of the single bracket atom ``[symbol]``.

        Each value independently falls back to 0.0 if RDKit cannot compute it
        (e.g. the bracket SMILES does not parse). Results are cached per symbol
        because the same few elements recur in every molecule.
        """
        cached = self._ELEMENT_DESCRIPTOR_CACHE.get(symbol)
        if cached is not None:
            return cached
        try:
            element_mol = Chem.MolFromSmiles(f'[{symbol}]')
        except Exception:
            element_mol = None
        values = []
        for descriptor in (Descriptors.ExactMolWt, Descriptors.MolLogP):
            try:
                values.append(descriptor(element_mol))
            except Exception:
                values.append(0.0)
        result = (values[0], values[1])
        self._ELEMENT_DESCRIPTOR_CACHE[symbol] = result
        return result

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Compute the categorical and physical feature lists for one atom.

        Returns:
            ``(atom_feat, phys_feat)`` where ``atom_feat`` is
            ``[atomic-number index, chiral-tag index]`` and ``phys_feat`` has
            ``NUM_PHYS_FEATURES`` entries (see the class docstring). If the atomic
            number or chiral tag is not in the lookup lists, or any RDKit call
            raises, the error is printed and
            ``([0, 0], [0.0] * NUM_PHYS_FEATURES)`` is returned.
            NOTE(review): the ``[0, 0]`` placeholder is indistinguishable from a
            non-chiral hydrogen.
        """
        try:
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag()),
            ]
            exact_mass, logp = self._element_descriptors(atom.GetSymbol())
            phys_feat = [
                exact_mass,
                logp,
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree(),
            ]
            return atom_feat, phys_feat

        except Exception as e:
            print(f"Error calculating atom features: {e}")
            # Must match the normal feature width, otherwise torch.tensor() cannot
            # stack this row with the others.
            return [0, 0], [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Stack per-atom features into ``(x_cat, x_phys)`` tensors.

        Returns ``x_cat`` as long ``(num_atoms, 2)`` and ``x_phys`` as float
        ``(num_atoms, NUM_PHYS_FEATURES)``. A None molecule yields a single
        all-zero placeholder row in each tensor.
        """
        if mol is None:
            return (torch.tensor([[0, 0]], dtype=torch.long),
                    torch.tensor([[0.0] * self.NUM_PHYS_FEATURES], dtype=torch.float))

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        # reshape keeps the 2-D shape even for a molecule with zero atoms.
        x = torch.tensor(atom_feats, dtype=torch.long).reshape(-1, 2)
        phys = torch.tensor(phys_feats, dtype=torch.float).reshape(-1, self.NUM_PHYS_FEATURES)
        return x, phys

    @staticmethod
    def remove_unbonded_hydrogens(mol):
        """Return a copy of ``mol`` with hydrogens removed, including degree-zero (unbonded) ones.

        Static so it works both as ``MolecularFeatureExtractor.remove_unbonded_hydrogens(mol)``
        and ``extractor.remove_unbonded_hydrogens(mol)``.
        """
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        return Chem.RemoveHs(mol, params)

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build ``edge_index`` (long, ``(2, 2*B)``) and ``edge_attr`` (float, ``(2*B, 5)``).

        Each bond contributes two directed edges with identical features. Bonds
        whose type or direction is not in the lookup lists are reported and
        skipped. A None molecule, or one with no usable bonds, yields a single
        placeholder self-loop on atom 0 with all-zero features.
        """
        placeholder = (torch.tensor([[0], [0]], dtype=torch.long),
                       torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float))
        if mol is None:
            return placeholder

        row, col, edge_feat = [], [], []

        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # Build the features first: if a lookup raises, nothing has been
                # appended yet, so edge_index and edge_attr stay the same length.
                feat = [
                    self.bond_list.index(bond.GetBondType()),
                    self.bonddir_list.index(bond.GetBondDir()),
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end),
                ]
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

            # Add edges in both directions
            row += [start, end]
            col += [end, start]
            edge_feat.extend([feat, feat])

        if not row:  # If no valid bonds were processed
            return placeholder

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """True for a non-ring single bond whose two atoms each have more than one neighbour."""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and 
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Distance between atoms ``start`` and ``end`` in the 3D conformer, else 0.0."""
        try:
            from rdkit.Chem import rdMolTransforms
            conf = mol.GetConformer()
            if conf.Is3D():
                return rdMolTransforms.GetBondLength(conf, start, end)
        except Exception:
            # e.g. the molecule has no conformer; a missing length is encoded as 0.0.
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Optional[Data]:
        """Parse ``smiles`` into a PyG graph with explicit hydrogens and a 3D conformer.

        Steps: parse, strip and then re-add explicit hydrogens (with coordinates),
        sanitize, embed with ETKDG when there is no conformer, then MMFF-optimise.
        UFF is used instead when MMFF cannot be set up. Finally node and edge
        features are extracted.

        Returns:
            ``Data(x_cat, x_phys, edge_index, edge_attr, num_nodes)``, or None when
            the SMILES is invalid, the molecule is empty, embedding fails, or any
            step raises (a message is printed in each case).
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # MMFFOptimizeMolecule reports missing MMFF parameters by returning
                # -1 rather than raising, so check the status as well as exceptions.
                try:
                    mmff_status = AllChem.MMFFOptimizeMolecule(mol)
                except Exception:
                    mmff_status = -1
                if mmff_status == -1:
                    AllChem.UFFOptimizeMolecule(mol)

            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat, 
                x_phys=x_phys,
                edge_index=edge_index, 
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
def get_fingerprints(mol, radius=2, nBits=2048, use_features=False):
    """Morgan bit-vector fingerprint of ``mol``.

    ``use_features=False`` gives ECFP-style bits; ``True`` gives FCFP-style
    (feature-invariant) bits. A None molecule yields None.

    NOTE(review): newer RDKit releases deprecate GetMorganFingerprintAsBitVect in
    favour of rdFingerprintGenerator — confirm against the installed version.
    """
    if mol is not None:
        return AllChem.GetMorganFingerprintAsBitVect(
            mol, radius, nBits=nBits, useFeatures=use_features
        )
    return None

def calculate_similarity(fp1, fp2):
    """Tanimoto similarity between two fingerprints; 0.0 when either one is missing."""
    if fp1 is not None and fp2 is not None:
        return DataStructs.TanimotoSimilarity(fp1, fp2)
    return 0.0

def find_nearest_neighbors(query_mol, mol_list, n_neighbors=3, fp_type='ecfp'):
    """Rank ``mol_list`` by Tanimoto similarity to ``query_mol`` and keep the top hits.

    Args:
        query_mol: RDKit molecule to compare against (None gives no neighbours).
        mol_list: candidate molecules; None entries score 0.0.
        n_neighbors: how many of the best candidates to keep (slice length).
        fp_type: 'fcfp' (case-insensitive) for feature fingerprints, anything
            else for ECFP-style fingerprints.

    Returns:
        ``(indices, scores)``: parallel lists ordered from most to least similar.
        Ties keep their original order (stable sort).
    """
    if query_mol is None:
        return [], []

    fcfp = fp_type.lower() == 'fcfp'
    reference = get_fingerprints(query_mol, use_features=fcfp)

    def _score(candidate):
        if candidate is None:
            return 0.0
        return calculate_similarity(reference, get_fingerprints(candidate, use_features=fcfp))

    scored = [(position, _score(candidate)) for position, candidate in enumerate(mol_list)]
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)[:n_neighbors]

    return [position for position, _ in ranked], [similarity for _, similarity in ranked]

def create_nearest_neighbors_visualization(query_molecules, gan_molecules, traditional_molecules, 
                                          n_neighbors=3, save_path=None):
    """
    Create visualization comparing nearest neighbors from GAN and traditional augmentations

    Each query occupies two grid rows. The top row shows the query followed by its
    ``n_neighbors`` most similar traditional augmentations; the bottom row shows
    the query followed by its most similar GAN augmentations. Similarities come
    from ``find_nearest_neighbors`` with its default (ECFP) fingerprints.
    ``None`` entries in the augmentation lists are ignored, and a ``None`` query
    leaves its two rows empty.

    Parameters:
    query_molecules (list): List of RDKit molecules to use as queries
    gan_molecules (list): List of GAN-augmented molecules
    traditional_molecules (list): List of traditionally augmented molecules
    n_neighbors (int): Number of nearest neighbors to show
    save_path (str): Path to save the figure (300 dpi); nothing is saved if falsy

    Returns:
    matplotlib.figure.Figure: the created figure
    """
    n_queries = len(query_molecules)
    grid_shape = (n_queries * 2, n_neighbors + 1)

    # Keep the figure height positive even for an empty query list.
    fig = plt.figure(figsize=(15, 4 * max(n_queries, 1)))
    fig.suptitle("Comparison of Nearest Neighbors (Manual vs. GAN Embeddings)", fontsize=16, y=0.98)

    # Filter out None molecules (indices returned below refer to these filtered lists).
    gan_molecules = [m for m in gan_molecules if m is not None]
    traditional_molecules = [m for m in traditional_molecules if m is not None]

    def _draw_panel(image, loc, title):
        """Draw one molecule image into grid cell ``loc`` of this figure."""
        ax = plt.subplot2grid(grid_shape, loc, fig=fig)
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    for i, query_mol in enumerate(query_molecules):
        if query_mol is None:
            continue

        query_smiles = Chem.MolToSmiles(query_mol)
        # Only add an ellipsis when the SMILES was actually truncated.
        short_smiles = query_smiles if len(query_smiles) <= 10 else f"{query_smiles[:10]}..."
        query_img = Draw.MolToImage(query_mol, size=(200, 200))
        row = i * 2

        for grid_row, side, pool in ((row, "Manual", traditional_molecules),
                                     (row + 1, "GAN", gan_molecules)):
            indices, scores = find_nearest_neighbors(query_mol, pool, n_neighbors)
            _draw_panel(query_img, (grid_row, 0), f"Mol{i}: {short_smiles} ({side} side)")
            for j, (idx, score) in enumerate(zip(indices, scores)):
                neighbor_img = Draw.MolToImage(pool[idx], size=(200, 200))
                _draw_panel(neighbor_img, (grid_row, j + 1), f"NN#{j+1}: {score:.2f}")

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

def sample_diverse_molecules(molecules, n_samples=3, fp_type='ecfp'):
    """Pick up to ``n_samples`` mutually dissimilar molecules with the MaxMin heuristic.

    The first pick is random (NumPy's global RNG, so ``np.random.seed`` makes it
    reproducible). Each further step adds the not-yet-picked molecule whose
    smallest Tanimoto distance (1 - similarity) to the picked set is largest;
    ties go to the lowest position.

    Args:
        molecules: sequence of RDKit molecules; None entries are ignored.
        n_samples: maximum number of molecules to return.
        fp_type: 'fcfp' (case-insensitive) for feature fingerprints, otherwise ECFP.

    Returns:
        Distinct indices into ``molecules`` in pick order. The list is empty when
        there is nothing to pick or ``n_samples <= 0``.
    """
    if not molecules or n_samples <= 0:
        return []

    valid_molecules = [(i, mol) for i, mol in enumerate(molecules) if mol is not None]
    if not valid_molecules:
        return []
    indices, valid_mols = zip(*valid_molecules)

    use_features = (fp_type.lower() == 'fcfp')
    fps = [get_fingerprints(mol, use_features=use_features) for mol in valid_mols]
    n_target = min(n_samples, len(fps))

    first = int(np.random.randint(len(fps)))
    picked = [first]
    picked_set = {first}
    # Smallest distance from every molecule to the picked set so far. It is
    # updated incrementally, so each step needs one similarity per molecule.
    nearest_dist = [1.0 - calculate_similarity(fp, fps[first]) for fp in fps]

    while len(picked) < n_target:
        # Exclude already-picked molecules explicitly; max() keeps the first
        # maximal candidate, matching np.argmax tie-breaking.
        candidates = [i for i in range(len(fps)) if i not in picked_set]
        next_pick = max(candidates, key=lambda i: nearest_dist[i])
        picked.append(next_pick)
        picked_set.add(next_pick)
        for i, fp in enumerate(fps):
            nearest_dist[i] = min(nearest_dist[i], 1.0 - calculate_similarity(fp, fps[next_pick]))

    # Map positions in the filtered list back to indices in `molecules`.
    return [indices[p] for p in picked]

def _rebuild_molecule_from_graph(mol_data, atom_list, bond_list):
    """Best-effort rebuild of a drawable RDKit molecule from one stored graph record.

    ``mol_data`` is read via ``.get('x_cat')``, ``.get('edge_index')`` and
    ``.get('edge_attr')``. These are presumably the per-molecule dicts saved
    alongside the embeddings — confirm against the pickle writer. The element of
    each atom comes from ``x_cat[k][0]`` (an index into ``atom_list``) and the
    order of each bond from ``edge_attr[k][0]`` (an index into ``bond_list``).

    Returns None if a required entry is missing or an RDKit call raises (the error
    is printed). Sanitization failures are tolerated so the graph can still be drawn.
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem

    try:
        x_cat = mol_data.get('x_cat')
        if x_cat is None:
            return None

        mol = Chem.RWMol()
        for atom_features in x_cat:
            mol.AddAtom(Chem.Atom(atom_list[int(atom_features[0])]))

        edge_index = mol_data.get('edge_index')
        edge_attr = mol_data.get('edge_attr')
        if edge_index is None or edge_attr is None:
            return None

        n_atoms = mol.GetNumAtoms()
        processed_edges = set()
        for i in range(len(edge_index[0])):
            # int() accepts Python ints, NumPy ints and 0-d tensors alike and makes
            # the pairs hashable by value for the duplicate check below.
            start, end = int(edge_index[0][i]), int(edge_index[1][i])
            # Self-loops are skipped: the feature extractor stores a [[0], [0]]
            # placeholder edge for bond-less molecules, and RDKit cannot bond an
            # atom to itself.
            if start >= n_atoms or end >= n_atoms or start == end:
                continue
            if (start, end) in processed_edges or (end, start) in processed_edges:
                continue

            bond_type_idx = int(edge_attr[i][0])
            if 0 <= bond_type_idx < len(bond_list):
                mol.AddBond(start, end, bond_list[bond_type_idx])
                processed_edges.add((start, end))

        final_mol = Chem.Mol(mol)
        for atom in final_mol.GetAtoms():
            atom.UpdatePropertyCache(strict=False)
        try:
            Chem.SanitizeMol(final_mol)
        except Exception:
            # Deliberate best effort: an unsanitizable graph is still drawable.
            pass

        # 2D coordinates for visualization
        AllChem.Compute2DCoords(final_mol)
        return final_mol

    except Exception as e:
        print(f"Error rebuilding molecule: {e}")
        return None


def _select_diverse_indices(molecules, n_queries):
    """MaxMin selection of up to ``n_queries`` mutually dissimilar molecule indices.

    Fingerprints are Morgan radius 2 with 2048 bits. The selection starts from a
    random valid molecule (NumPy global RNG) and then repeatedly adds the molecule
    whose smallest Tanimoto distance to the current selection is largest. If
    fewer than ``n_queries`` are found, it tops up with the first remaining valid
    indices.
    """
    from rdkit.Chem import AllChem, DataStructs

    valid_indices = [i for i, mol in enumerate(molecules) if mol is not None]
    # Dict lookup replaces a linear scan per comparison (was O(n^2) over the dataset).
    fp_by_idx = {
        idx: AllChem.GetMorganFingerprintAsBitVect(molecules[idx], 2, nBits=2048)
        for idx in valid_indices
    }

    selected = []
    if fp_by_idx:
        candidates = list(fp_by_idx.items())
        selected.append(candidates[np.random.randint(0, len(candidates))][0])

        for _ in range(n_queries - 1):
            best_dist, best_idx = -1, -1
            for idx, fp in candidates:
                if idx in selected:
                    continue
                min_dist = min(1 - DataStructs.TanimotoSimilarity(fp, fp_by_idx[sel_idx])
                               for sel_idx in selected)
                if min_dist > best_dist:
                    best_dist, best_idx = min_dist, idx
            if best_idx >= 0:
                selected.append(best_idx)

    # Fallback to first few valid molecules if diversity selection fell short
    if len(selected) < n_queries:
        remaining = [idx for idx in valid_indices if idx not in selected]
        selected.extend(remaining[:n_queries - len(selected)])

    return selected


def _random_dropout_augment(pyg_data, dropout_ratio=0.25):
    """Baseline augmentation that zeroes random node and edge feature rows in place.

    Each node, then each directed edge, is independently dropped with probability
    ``dropout_ratio`` using torch's global RNG. Dropped rows are zeroed, which
    matches the zero-row masking convention of pyg_data_to_mol. The boolean keep
    masks are also attached as ``node_keep_mask`` / ``edge_keep_mask``, so a
    reader that honours them does not have to guess masking from zero rows (an
    unmasked hydrogen or single bond can also encode to all zeros).
    Returns the same (mutated) object.
    """
    node_mask = torch.rand(pyg_data.x_cat.size(0), 1) > dropout_ratio
    edge_mask = torch.rand(pyg_data.edge_attr.size(0), 1) > dropout_ratio

    # Cast the mask to each tensor's dtype so x_cat stays an integer tensor.
    pyg_data.x_cat = pyg_data.x_cat * node_mask.to(pyg_data.x_cat.dtype)
    pyg_data.x_phys = pyg_data.x_phys * node_mask.to(pyg_data.x_phys.dtype)
    pyg_data.edge_attr = pyg_data.edge_attr * edge_mask.to(pyg_data.edge_attr.dtype)
    pyg_data.node_keep_mask = node_mask.view(-1)
    pyg_data.edge_keep_mask = edge_mask.view(-1)
    return pyg_data


def main(embedding_path='./embeddings/final_embeddings_molecules_20250309_110249.pkl',
         checkpoint_path='./checkpoints/gan_cl_checkpoint_50.pt',
         output_dir='./visualizations',
         n_queries=3,
         n_neighbors=3,
         dropout_ratio=0.25):
    """Generate the nearest-neighbours figure using molecules from a GAN-CL training run.

    Pipeline:
    1. Load the pickled embeddings and graph records.
    2. Rebuild RDKit molecules from the graphs.
    3. Load the GAN-CL checkpoint.
    4. Pick ``n_queries`` diverse query molecules.
    5. Build one GAN augmentation and one random-dropout augmentation per query.
    6. Plot nearest neighbours and save the figure to ``output_dir``.

    The defaults are the paths and settings this notebook previously hard-coded,
    so ``main()`` with no arguments uses the same inputs.
    """
    import os
    import pickle
    import numpy as np
    import torch
    import matplotlib.pyplot as plt
    from rdkit import Chem
    from tqdm import tqdm

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, 'nearest_neighbors_comparison.png')

    # Step 1: Load embeddings and molecule data
    print("Loading embeddings and molecule data...")
    # SECURITY: pickle.load can execute arbitrary code — only load trusted files.
    with open(embedding_path, 'rb') as f:
        payload = pickle.load(f)

    embeddings = payload['embeddings']
    molecule_data = payload['molecule_data']
    print(f"Loaded {len(embeddings)} embeddings and {len(molecule_data)} molecule entries")

    # Step 2: Rebuild molecules from graph data
    print("Rebuilding molecules from graph data...")
    atom_list = list(range(1, 119))
    bond_list = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ]
    original_molecules = [
        _rebuild_molecule_from_graph(mol_data, atom_list, bond_list)
        for mol_data in tqdm(molecule_data)
    ]
    n_valid = sum(mol is not None for mol in original_molecules)
    print(f"Successfully rebuilt {n_valid}/{len(molecule_data)} molecules")

    # Step 3: Load model to use for GAN augmentations
    print("Loading model checkpoint for GAN augmentations...")
    # weights_only is passed explicitly: it keeps behaviour stable across torch
    # versions (newer releases default to True) and silences the FutureWarning.
    # Only load trusted checkpoints with weights_only=False.
    model_checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'),
                                  weights_only=False)

    # NOTE(review): this import failed in the recorded run — model_modules.py does
    # not define MolecularGANCL. Point it at the module that defines the model.
    from model_modules import MolecularGANCL, GanClConfig

    model_state_dict = model_checkpoint['model_state_dict']

    # NOTE(review): hard-coded hyper-parameters; they must match the checkpoint.
    # node_dim=11 vs. 2 categorical + 8 physical atom features here — confirm.
    config = GanClConfig(
        node_dim=11,
        edge_dim=5,
        hidden_dim=128,
        output_dim=128,
        queue_size=65536,
        momentum=0.999,
        temperature=0.07,
        decay=0.99999,
        dropout_ratio=0.25
    )

    model = MolecularGANCL(config)
    model.load_state_dict(model_state_dict)
    model.eval()
    print("Model loaded successfully")

    # Step 4: Select diverse query molecules
    print("Selecting diverse query molecules...")
    selected_indices = _select_diverse_indices(original_molecules, n_queries)
    query_molecules = [original_molecules[idx] for idx in selected_indices
                       if idx < len(original_molecules)]
    print(f"Selected {len(query_molecules)} diverse query molecules")

    # Step 5: Generate GAN and traditional augmentations.
    # NOTE(review): each neighbour pool holds only augmentations of the query
    # molecules themselves (at most one per query), not the wider dataset.
    print("Generating GAN augmentations...")
    gan_molecules = []
    for query_mol in query_molecules:
        query_data = mol_to_pyg_data(query_mol)
        if query_data is None:
            continue

        with torch.no_grad():
            node_scores, edge_scores = model.generator(query_data)
            perturbed_data = model.drop_graph_elements(query_data, node_scores, edge_scores)

        aug_mol = pyg_data_to_mol(perturbed_data)
        if aug_mol is not None:
            gan_molecules.append(aug_mol)

    print("Generating traditional augmentations...")
    traditional_molecules = []
    for query_mol in query_molecules:
        query_data = mol_to_pyg_data(query_mol)
        if query_data is None:
            continue

        aug_mol = pyg_data_to_mol(_random_dropout_augment(query_data, dropout_ratio))
        if aug_mol is not None:
            traditional_molecules.append(aug_mol)

    print(f"Generated {len(gan_molecules)} GAN augmentations and {len(traditional_molecules)} traditional augmentations")

    # Step 6: Create visualization
    print("Creating nearest neighbors visualization...")
    create_nearest_neighbors_visualization(
        query_molecules,
        gan_molecules,
        traditional_molecules,
        n_neighbors=n_neighbors,
        save_path=save_path
    )

    plt.show()
    print(f"Visualization saved to {save_path}")

def mol_to_pyg_data(mol):
    """Convert an RDKit molecule to a PyG ``Data`` object for the GAN generator.

    ``x_cat`` uses the same encoding as MolecularFeatureExtractor (atomic-number
    index, chiral-tag index). Physical atom features, and every bond feature
    except the bond-type index, are filled with zeros rather than recomputed.

    Returns:
        ``Data`` with ``x_cat`` (long, ``(N, 2)``), ``x_phys`` (float, ``(N, 8)``),
        ``edge_index`` (long, ``(2, 2*B)``, both directions), ``edge_attr``
        (float, ``(2*B, 5)``) and ``num_nodes=N``. The shapes hold even when
        N or B is 0. Returns None when ``mol`` is None, or when an atomic number,
        chiral tag or bond type is outside the supported lists (the error is
        printed).
    """
    if mol is None:
        return None

    from torch_geometric.data import Data
    import torch

    atom_list = list(range(1, 119))
    chirality_list = [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ]
    bond_list = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ]

    try:
        atom_feats = [
            [atom_list.index(atom.GetAtomicNum()), chirality_list.index(atom.GetChiralTag())]
            for atom in mol.GetAtoms()
        ]

        edge_pairs = [[], []]
        edge_feats = []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # Only the bond-type index is real; the other 4 slots are placeholders.
            attr = [bond_list.index(bond.GetBondType()), 0, 0, 0, 0]
            edge_pairs[0].extend([start, end])
            edge_pairs[1].extend([end, start])
            edge_feats.extend([attr, attr])

        # reshape keeps 2-D shapes for atom-less / bond-less molecules.
        x_cat = torch.tensor(atom_feats, dtype=torch.long).reshape(-1, 2)
        x_phys = torch.zeros((x_cat.size(0), 8), dtype=torch.float)
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).reshape(2, -1)
        edge_attr = torch.tensor(edge_feats, dtype=torch.float).reshape(-1, 5)

        return Data(
            x_cat=x_cat,
            x_phys=x_phys,
            edge_index=edge_index,
            edge_attr=edge_attr,
            num_nodes=x_cat.size(0)
        )

    except Exception as e:
        print(f"Error converting molecule to PyG data: {e}")
        return None

def pyg_data_to_mol(data):
    """Convert (possibly masked) PyG graph data back to an RDKit molecule for drawing.

    Which atoms/edges survive:
      * If ``data`` has ``node_keep_mask`` / ``edge_keep_mask`` attributes, a
        non-zero entry keeps the corresponding atom / directed edge.
      * Otherwise (legacy rule) an atom whose ``x_cat`` row sums to 0, or an edge
        whose ``edge_attr`` row sums to 0, is treated as masked out.
        NOTE(review): under this rule an unmasked non-chiral hydrogen (``[0, 0]``)
        or an unmasked single bond whose other features are 0 is dropped too.

    Surviving atoms are renumbered consecutively and each edge is remapped to the
    new numbering, so bonds stay attached to the right atoms when earlier atoms
    are dropped. Skipped edges: those touching a dropped atom, self-loops, the
    reverse copy of an already added bond, and unknown bond-type indices.

    Returns:
        An RDKit ``Mol`` with 2D coordinates, sanitized when RDKit accepts it.
        Returns None if ``data`` is None or the conversion raises (the error is
        printed).
    """
    if data is None:
        return None

    from rdkit import Chem
    from rdkit.Chem import AllChem

    atom_list = list(range(1, 119))
    bond_list = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ]

    try:
        node_keep = getattr(data, 'node_keep_mask', None)
        edge_keep = getattr(data, 'edge_keep_mask', None)
        if node_keep is not None:
            node_keep = torch.as_tensor(node_keep).reshape(-1) != 0
        if edge_keep is not None:
            edge_keep = torch.as_tensor(edge_keep).reshape(-1) != 0

        mol = Chem.RWMol()
        # graph atom index -> RDKit atom index, only for atoms that survive
        new_index = {}
        for old_idx, atom_feat in enumerate(data.x_cat):
            if node_keep is not None:
                kept = bool(node_keep[old_idx])
            else:
                kept = bool(torch.sum(atom_feat) != 0)
            if not kept:
                continue
            mol.AddAtom(Chem.Atom(atom_list[int(atom_feat[0])]))
            new_index[old_idx] = mol.GetNumAtoms() - 1

        processed_pairs = set()
        for i in range(data.edge_index.size(1)):
            start, end = int(data.edge_index[0, i]), int(data.edge_index[1, i])

            if start not in new_index or end not in new_index or start == end:
                continue
            pair = frozenset((start, end))
            if pair in processed_pairs:
                continue

            if edge_keep is not None:
                kept = bool(edge_keep[i])
            else:
                kept = bool(torch.sum(data.edge_attr[i]) != 0)
            if not kept:
                # The reverse direction may still be kept and create the bond.
                continue

            bond_type_idx = int(data.edge_attr[i, 0])
            if 0 <= bond_type_idx < len(bond_list):
                mol.AddBond(new_index[start], new_index[end], bond_list[bond_type_idx])
                processed_pairs.add(pair)

        final_mol = Chem.Mol(mol)
        for atom in final_mol.GetAtoms():
            atom.UpdatePropertyCache(strict=False)
        try:
            Chem.SanitizeMol(final_mol)
        except Exception:
            # Deliberate best effort: unsanitizable fragments are still drawable.
            pass

        # Compute 2D coordinates for visualization
        AllChem.Compute2DCoords(final_mol)
        return final_mol

    except Exception as e:
        print(f"Error converting PyG data to molecule: {e}")
        return None

# Run the visualization pipeline. A Jupyter kernel also sets __name__ to
# "__main__", so executing this cell calls main() (see the output below).
if __name__ == "__main__":
    main()

Loading embeddings and molecule data...
Loaded 9937 embeddings and 9937 molecule entries
Rebuilding molecules from graph data...


100%|█████████████████████████████████████████████████████████████████████████████| 9937/9937 [00:26<00:00, 374.41it/s]
  model_checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Successfully rebuilt 9937/9937 molecules
Loading model checkpoint for GAN augmentations...


ImportError: cannot import name 'MolecularGANCL' from 'model_modules' (D:\PhD\Chapter3.1_GAN_Explainbility\GAN_CL_XAI\GitHub2\model_modules.py)