In [None]:
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import Tuple
import numpy as np
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import RDLogger
from rdkit.Chem import RemoveHs
from datetime import datetime

import torch
import numpy as np
from torch_geometric.data import DataLoader
import os
import json
from tqdm import tqdm
import pickle
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from typing import Tuple, List, Optional
import copy
from dataclasses import dataclass

# Silence RDKit's 'rdApp.warning' log channel so per-molecule warnings do not
# flood the output while molecules are parsed and embedded.
RDLogger.DisableLog('rdApp.warning')
# Module-level run timestamp (YYYYmmdd_HHMMSS, naive local time), captured once
# when this cell/module is executed.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

In [None]:
class MolecularFeatureExtractor:
    """Turn SMILES strings into PyTorch Geometric ``Data`` graphs.

    Every atom yields two categorical indices (element, chirality) plus
    ``NUM_PHYS_FEATURES`` numeric descriptors; every bond yields
    ``NUM_BOND_FEATURES`` values and is stored once per direction.
    """

    # Length of the numeric per-atom vector built by calc_atom_features:
    # element mass, element logP, formal charge, hybridization, aromaticity,
    # total H count, total valence, degree.
    NUM_PHYS_FEATURES = 8
    # Length of the per-bond vector built by get_bond_features:
    # bond type, bond direction, conjugation, rotatability, 3D length.
    NUM_BOND_FEATURES = 5

    # Element symbol -> (exact mass, logP) of the bare atom. The values depend
    # only on the symbol, so they are computed once and shared by all instances
    # instead of re-parsing a SMILES string twice for every atom.
    _element_descriptor_cache = {}

    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    @classmethod
    def _element_descriptors(cls, symbol: str) -> Tuple[float, float]:
        """Return (exact mass, logP) for the bare atom ``[symbol]``.

        Each descriptor independently falls back to 0.0 when RDKit cannot
        compute it (for example when the symbol does not parse).
        """
        cached = cls._element_descriptor_cache.get(symbol)
        if cached is not None:
            return cached

        try:
            exact_mass = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{symbol}]'))
        except Exception:
            exact_mass = 0.0

        try:
            logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{symbol}]'))
        except Exception:
            logp = 0.0

        cls._element_descriptor_cache[symbol] = (exact_mass, logp)
        return exact_mass, logp

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Compute the feature lists for a single atom.

        Returns:
            ``(atom_feat, phys_feat)`` where ``atom_feat`` is
            ``[element index, chirality index]`` and ``phys_feat`` has
            ``NUM_PHYS_FEATURES`` entries. If anything fails (e.g. an atomic
            number outside 1..118), an all-zero pair of the same widths is
            returned so the per-molecule tensors stay rectangular.
        """
        try:
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Element-level descriptors first, then per-atom properties.
            phys_feat = list(self._element_descriptors(atom.GetSymbol()))
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])

            return atom_feat, phys_feat

        except Exception as e:
            print(f"Error calculating atom features: {e}")
            # Width must match the success path, otherwise torch.tensor() in
            # get_atom_features would see ragged rows.
            return [0, 0], [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule.

        Returns:
            ``(x_cat, x_phys)``: a long tensor with two columns and a float
            tensor with ``NUM_PHYS_FEATURES`` columns, one row per atom. A
            ``None`` molecule yields a single all-zero row in each.
        """
        if mol is None:
            return (torch.tensor([[0, 0]], dtype=torch.long),
                    torch.tensor([[0.0] * self.NUM_PHYS_FEATURES], dtype=torch.float))

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)

        return x, phys

    @staticmethod
    def remove_unbonded_hydrogens(mol):
        """Return ``mol`` with hydrogens removed, including degree-zero ones.

        Declared static because it takes no ``self``; it can be called on the
        class or on an instance.
        """
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol

    def _placeholder_edges(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the stand-in used when no bond could be encoded.

        It is a single 0 -> 0 edge with an all-zero feature row.
        """
        return (torch.tensor([[0], [0]], dtype=torch.long),
                torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float))

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract bond connectivity and features.

        Returns:
            ``(edge_index, edge_attr)``. Every bond contributes two directed
            edges that share one feature row. Bonds whose type or direction is
            not in ``bond_list`` / ``bonddir_list`` are reported and skipped.
            A ``None`` molecule, or one with no encodable bond, yields the
            placeholder from ``_placeholder_edges``.
        """
        if mol is None:
            return self._placeholder_edges()

        row, col, edge_feat = [], [], []

        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                feat = [
                    self.bond_list.index(bond.GetBondType()),
                    self.bonddir_list.index(bond.GetBondDir()),
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

            # Record connectivity only after the features were built, so
            # edge_index and edge_attr always have the same number of edges.
            row += [start, end]
            col += [end, start]
            edge_feat.extend([feat, feat])

        if not row:  # No valid bonds were processed
            return self._placeholder_edges()

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)

        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Return True for a non-ring single bond whose atoms both have >1 neighbor."""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Return the 3D distance between two atoms, or 0.0 if unavailable.

        0.0 is returned when the molecule has no conformer, the conformer is
        not 3D, or RDKit raises while measuring.
        """
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except Exception:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Optional[Data]:
        """Convert a SMILES string into a graph ``Data`` object.

        The molecule is parsed, given explicit hydrogens, sanitized, embedded
        in 3D and geometry-optimized before features are extracted.

        Returns:
            A ``Data`` with ``x_cat``, ``x_phys``, ``edge_index``, ``edge_attr``
            and ``num_nodes``, or ``None`` when the SMILES is invalid, the
            molecule is empty, embedding fails, or any step raises.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # MMFF signals "force field could not be set up" by returning
                # -1 rather than raising, so check the status as well as
                # catching exceptions before falling back to UFF.
                try:
                    mmff_failed = AllChem.MMFFOptimizeMolecule(mol) == -1
                except Exception:
                    mmff_failed = True
                if mmff_failed:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [None]:
class MemoryQueue:
    """Fixed-size ring buffer of key embeddings with age-based decay weights.

    The buffer starts out filled with random unit vectors. ``update_queue``
    overwrites the oldest slots and ``compute_contrastive_loss`` uses every
    slot as a negative sample, weighted by ``decay ** age``.

    This class is a plain object, not an ``nn.Module``, so ``model.to(device)``
    does not move its tensors. They are moved lazily to the device of the
    tensors passed to ``update_queue`` / ``compute_contrastive_loss``.
    """

    def __init__(self, size: int, dim: int, decay: float = 0.99999):
        self.size = size
        self.dim = dim
        self.decay = decay
        self.ptr = 0        # next slot to overwrite
        self.full = False   # True once the write pointer has wrapped around

        # Random unit-norm placeholders; never part of any autograd graph.
        self.queue = F.normalize(torch.randn(size, dim), dim=1)
        # Number of update_queue calls since each slot was last written.
        self.queue_age = torch.zeros(size)

    def _ensure_device(self, device: torch.device) -> None:
        """Move the buffers to ``device`` if they live somewhere else."""
        if self.queue.device != device:
            self.queue = self.queue.to(device)
            self.queue_age = self.queue_age.to(device)

    @torch.no_grad()
    def update_queue(self, keys: torch.Tensor):
        """Insert a batch of keys, overwriting the oldest entries.

        Args:
            keys: tensor whose first dimension is the batch; rows are copied
                into the queue detached from autograd. If the batch has at
                least ``size`` rows, only the last ``size`` rows are kept.
        """
        keys = keys.detach()
        batch_size = keys.shape[0]
        self._ensure_device(keys.device)

        # Every existing entry gets one step older.
        self.queue_age += 1

        if batch_size >= self.size:
            # The batch alone fills the whole buffer.
            self.queue.copy_(keys[-self.size:])
            self.queue_age.zero_()
            self.ptr = 0
            self.full = True
            return

        end = self.ptr + batch_size
        if end <= self.size:
            self.queue[self.ptr:end] = keys
            self.queue_age[self.ptr:end] = 0
        else:
            # Wrap around: fill the tail, then continue from the start.
            rem = self.size - self.ptr
            self.queue[self.ptr:] = keys[:rem]
            self.queue[:batch_size - rem] = keys[rem:]
            self.queue_age[self.ptr:] = 0
            self.queue_age[:batch_size - rem] = 0

        if end >= self.size:
            self.full = True
        self.ptr = end % self.size

    def get_decay_weights(self) -> torch.Tensor:
        """Return ``decay ** age`` for every queue slot."""
        return self.decay ** self.queue_age

    def compute_contrastive_loss(self, query: torch.Tensor, positive_key: torch.Tensor,
                                temperature: float = 0.07) -> torch.Tensor:
        """Compute an InfoNCE-style loss against the queue.

        Args:
            query: (batch, dim) query embeddings.
            positive_key: (batch, dim) keys paired row-wise with ``query``.
            temperature: divisor applied to all logits.

        Returns:
            Scalar cross-entropy loss where class 0 is the positive pair and
            the remaining classes are the decay-weighted queue entries.
        """
        self._ensure_device(query.device)

        # Cosine similarities: normalize all three operands.
        query = F.normalize(query, dim=1)
        positive_key = F.normalize(positive_key, dim=1)
        queue = F.normalize(self.queue, dim=1)

        l_pos = torch.einsum('nc,nc->n', query, positive_key).unsqueeze(-1)
        l_neg = torch.einsum('nc,kc->nk', query, queue)

        # Older negatives are scaled down by their decay weight.
        l_neg = l_neg * self.get_decay_weights().unsqueeze(0)

        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=query.device)

        return F.cross_entropy(logits, labels)

class GraphGenerator(nn.Module):
    """Score every node and edge of a graph with an importance value in (0, 1).

    Node features are encoded, passed through three GCN layers, and mapped to
    a sigmoid score per node. Each edge is scored from the concatenated
    embeddings of its two endpoints.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128):
        super().__init__()

        # Node feature processing
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Edge feature processing. GCNConv takes no edge features, so this
        # encoder is not used by forward(); it is kept so parameter names and
        # existing checkpoints stay compatible.
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Importance prediction layers
        self.node_importance = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        self.edge_importance = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def normalize_features(self, x_cat, x_phys):
        """Cast both feature blocks to float and standardize the physical one.

        ``x_cat`` is only cast (the raw category indices are kept as numbers,
        not one-hot encoded). ``x_phys`` is standardized per column over all
        rows passed in, unless there is a single row.
        """
        x_cat = x_cat.float()

        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # std over a single row is undefined
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)

        return x_cat, x_phys

    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(node_scores, edge_scores)`` for ``data``.

        Args:
            data: graph object providing ``x_cat``, ``x_phys`` and ``edge_index``.

        Returns:
            ``node_scores`` with one row per node and ``edge_scores`` with one
            row per column of ``edge_index``; both have a trailing dimension
            of size 1 and values in (0, 1).
        """
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        x = torch.cat([x_cat, x_phys], dim=-1)
        edge_index = data.edge_index

        # Initial feature encoding. The edge features are deliberately not
        # encoded here: the result was never consumed by the GCN layers.
        x = self.node_encoder(x)

        # Graph convolutions (connectivity only)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Predict importance scores
        node_scores = self.node_importance(x)

        # Edge scores using both connected nodes
        edge_features = torch.cat([
            x[edge_index[0]],
            x[edge_index[1]]
        ], dim=-1)
        edge_scores = self.edge_importance(edge_features)

        return node_scores, edge_scores

def get_model_config(dataset):
    """Build a GanClConfig whose input widths match the dataset's first graph.

    The node width is the number of ``x_cat`` columns plus the number of
    ``x_phys`` columns; the edge width is the number of ``edge_attr`` columns.
    All remaining hyperparameters are fixed here.
    """
    first_graph = dataset[0]

    # Input widths are read off the feature tensors of the first sample.
    atom_width = first_graph.x_cat.shape[1] + first_graph.x_phys.shape[1]
    bond_width = first_graph.edge_attr.shape[1]

    hyperparams = dict(
        hidden_dim=128,
        output_dim=128,
        queue_size=65536,
        momentum=0.999,
        temperature=0.07,
        decay=0.99999,
        dropout_ratio=0.25,
    )
    return GanClConfig(node_dim=atom_width, edge_dim=bond_width, **hyperparams)

class GraphDiscriminator(nn.Module):
    """Graph encoder that maps each graph to one projected embedding.

    Node features are encoded, passed through three GCN layers, mean-pooled
    per graph and sent through a two-layer projection head.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()

        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # GCNConv takes no edge features, so this encoder is not used by
        # forward(); it is kept so parameter names and existing checkpoints
        # stay compatible.
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def normalize_features(self, x_cat, x_phys):
        """Cast both feature blocks to float and standardize the physical one.

        ``x_cat`` is only cast (the raw category indices are kept as numbers,
        not one-hot encoded). ``x_phys`` is standardized per column over all
        rows passed in, unless there is a single row.
        """
        x_cat = x_cat.float()

        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # std over a single row is undefined
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)

        return x_cat, x_phys

    def forward(self, data):
        """Return one projected embedding per graph in ``data``.

        Args:
            data: graph or batch providing ``x_cat``, ``x_phys``,
                ``edge_index`` and ``batch`` (the node-to-graph assignment
                handed to ``global_mean_pool``).
        """
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        x = torch.cat([x_cat, x_phys], dim=-1)

        edge_index = data.edge_index
        batch = data.batch

        # Initial feature encoding. The edge features are deliberately not
        # encoded here: the result was never consumed by the GCN layers.
        x = self.node_encoder(x)

        # Graph convolutions (connectivity only)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Global pooling
        x = global_mean_pool(x, batch)

        # Projection
        x = self.projection(x)

        return x

@dataclass
class GanClConfig:
    """Configuration for GAN-CL training.

    ``node_dim`` and ``edge_dim`` are required; every other field has a
    default. ``MolecularGANCL`` reads these values when it builds the
    generator, the encoder and the memory queue.
    """
    # Width of the node feature vector fed to the generator and encoder
    # (get_model_config sets it to x_cat columns + x_phys columns).
    node_dim: int
    # Width of the edge feature vector (columns of edge_attr).
    edge_dim: int
    # Hidden width of the encoder; the generator is built with twice this value.
    hidden_dim: int = 128
    # Size of the encoder's output embedding and of each memory-queue entry.
    output_dim: int = 128
    # Number of entries held by the memory queue.
    queue_size: int = 65536
    # Weight kept from the momentum encoder in each momentum update.
    momentum: float = 0.999
    # NOTE(review): not read by MolecularGANCL.forward, which anneals from its
    # own initial_temperature instead — confirm whether this is still needed.
    temperature: float = 0.07
    # Base of the age-decay weighting applied to queue entries.
    decay: float = 0.99999
    # Probability with which each node/edge feature row is zeroed when the
    # perturbed graph is created.
    dropout_ratio: float = 0.25

class MolecularGANCL(nn.Module):
    """Combined GAN and contrastive-learning framework.

    Holds a generator (importance scores), an encoder, a frozen momentum copy
    of the encoder, and a memory queue of negatives. ``forward`` returns the
    three weighted loss terms; it does not call ``_momentum_update`` or
    ``memory_queue.update_queue`` itself.
    """
    def __init__(self, config: GanClConfig):
        super().__init__()
        self.config = config

        def init_weights(m):
            # Xavier weights and a small constant bias for plain Linear layers.
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:  # Linear(bias=False) has no bias tensor
                    m.bias.data.fill_(0.01)

        # Initialize networks (the generator uses twice the hidden width)
        self.generator = GraphGenerator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim * 2
        )

        self.encoder = GraphDiscriminator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim,
            config.output_dim
        )
        self.encoder.apply(init_weights)

        # Loss weights
        self.contrastive_weight = 1.0
        self.adversarial_weight = 0.1  # Increased from 0.05
        self.similarity_weight = 0.01  # Decreased from 0.1

        # Temperature annealing bounds used by get_temperature
        self.initial_temperature = 0.1
        self.min_temperature = 0.05

        # Momentum encoder: a frozen copy of the freshly initialized encoder
        self.momentum_encoder = copy.deepcopy(self.encoder)
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False

        # Memory queue of negative keys
        self.memory_queue = MemoryQueue(
            config.queue_size,
            config.output_dim,
            config.decay
        )

    @torch.no_grad()
    def _momentum_update(self):
        """Move the momentum encoder toward the encoder.

        Each momentum parameter becomes
        ``momentum * old + (1 - momentum) * encoder_param``.
        """
        for param_q, param_k in zip(self.encoder.parameters(),
                                  self.momentum_encoder.parameters()):
            param_k.data = self.config.momentum * param_k.data + \
                          (1 - self.config.momentum) * param_q.data

    def drop_graph_elements(self, data, node_scores: torch.Tensor,
                          edge_scores: torch.Tensor) -> Data:
        """Return a copy of ``data`` with random node/edge feature rows zeroed.

        Each row is zeroed with probability ``config.dropout_ratio``. The
        score tensors only provide the mask shape, dtype and device; their
        values do not influence which rows are dropped. Connectivity
        (``edge_index``) and ``batch`` are passed through unchanged.
        """
        # Node mask is drawn before the edge mask (keeps RNG order stable).
        node_mask = (torch.rand_like(node_scores) > self.config.dropout_ratio).float()
        edge_mask = (torch.rand_like(edge_scores) > self.config.dropout_ratio).float()

        # Apply masks (broadcast over the feature dimension)
        x_cat_new = data.x_cat * node_mask
        x_phys_new = data.x_phys * node_mask
        edge_attr_new = data.edge_attr * edge_mask

        # There is no plain `x` attribute, so state the node count explicitly,
        # as process_molecule does, instead of leaving it to be inferred.
        return Data(
            x_cat=x_cat_new,
            x_phys=x_phys_new,
            edge_index=data.edge_index,
            edge_attr=edge_attr_new,
            batch=data.batch,
            num_nodes=x_cat_new.size(0)
        )

    def get_temperature(self, epoch, total_epochs):
        """Linearly anneal the temperature, floored at ``min_temperature``.

        A non-positive ``total_epochs`` is treated as "no progress" and yields
        the initial temperature instead of dividing by zero.
        """
        progress = epoch / total_epochs if total_epochs > 0 else 0.0
        return max(self.initial_temperature * (1 - progress), self.min_temperature)

    def forward(self, data, epoch=0, total_epochs=50):
        """Compute the three weighted loss terms for one batch.

        Returns:
            ``(contrastive_loss, adversarial_loss, similarity_loss)``. The
            adversarial term is the negated MSE between the perturbed and the
            original embeddings; the similarity term is the same MSE with a
            positive sign. Each is scaled by its own weight.
        """
        temperature = self.get_temperature(epoch, total_epochs)

        # Get importance scores from generator
        node_scores, edge_scores = self.generator(data)

        # Create perturbed graph
        perturbed_data = self.drop_graph_elements(data, node_scores, edge_scores)

        # Embeddings: the query carries gradients, the references do not.
        query_emb = self.encoder(perturbed_data)
        with torch.no_grad():
            key_emb = self.momentum_encoder(data)
            original_emb = self.encoder(data).detach()

        contrastive_loss = self.memory_queue.compute_contrastive_loss(
            query_emb, key_emb, temperature
        ) * self.contrastive_weight

        adversarial_loss = -F.mse_loss(query_emb, original_emb) * self.adversarial_weight
        similarity_loss = F.mse_loss(query_emb, original_emb) * self.similarity_weight

        return contrastive_loss, adversarial_loss, similarity_loss

    def get_embeddings(self, data) -> torch.Tensor:
        """Return encoder embeddings for ``data`` without tracking gradients."""
        with torch.no_grad():
            return self.encoder(data)

In [None]:
def extract_molecule_metadata(dataset):
    """Derive descriptive metadata from each PyG graph, without SMILES strings.

    Args:
        dataset: iterable of graph objects. A graph is analysed only if it has
            ``num_nodes`` and ``edge_index``; ``x_cat``, ``x_phys`` and
            ``edge_attr`` are used when present.

    Returns:
        One dict per graph with the keys ``graph_id``, ``properties``,
        ``features``, ``functional_groups`` and ``ring_info``. If analysis of
        a graph raises, the error is printed and whatever was computed up to
        that point is kept.
    """
    from collections import Counter, defaultdict

    import networkx as nx
    from tqdm import tqdm

    # Names of the x_phys columns, in the order the feature extractor emits them.
    phys_prop_names = ['contrib_mw', 'contrib_logp', 'formal_charge',
                       'hybridization', 'is_aromatic', 'num_h', 'valence', 'degree']

    metadata = []

    for graph_idx, data in enumerate(tqdm(dataset, desc="Extracting molecule metadata")):
        mol_id = f"molecule_{graph_idx}"

        properties = {}
        features = {}
        functional_groups = {}
        ring_info = {"ring_counts": {}, "ring_sizes": {}}

        if hasattr(data, 'num_nodes') and hasattr(data, 'edge_index'):
            try:
                G = to_networkx(data)

                # --- Graph-level properties -------------------------------
                # edge_index stores each bond in both directions.
                num_edges = data.edge_index.size(1) // 2
                properties = {
                    "num_nodes": data.num_nodes,
                    "num_edges": num_edges,
                    "avg_node_degree": 2 * num_edges / data.num_nodes if data.num_nodes > 0 else 0
                }

                # Raises for a graph without nodes; handled by the outer except.
                connected = nx.is_connected(G)

                properties["avg_path_length"] = 0.0
                if connected:
                    try:
                        properties["avg_path_length"] = nx.average_shortest_path_length(G)
                    except Exception:
                        properties["avg_path_length"] = 0.0

                try:
                    properties["clustering_coefficient"] = nx.average_clustering(G)
                except Exception:
                    properties["clustering_coefficient"] = 0.0

                try:
                    properties["graph_diameter"] = nx.diameter(G) if connected else 0
                except Exception:
                    properties["graph_diameter"] = 0

                try:
                    properties["assortativity"] = nx.degree_assortativity_coefficient(G)
                except Exception:
                    properties["assortativity"] = 0.0

                # --- Structural features ----------------------------------
                features = {
                    "is_connected": connected,
                    "num_connected_components": nx.number_connected_components(G),
                    # A forest has no cycles even when it is disconnected,
                    # whereas "not a tree" is also true for acyclic fragments.
                    "has_cycles": not nx.is_forest(G),
                    "max_degree": max(dict(G.degree()).values()) if G.number_of_nodes() > 0 else 0,
                    "density": nx.density(G),
                    "is_bipartite": nx.is_bipartite(G) if G.number_of_nodes() > 0 else False
                }

                features["max_centrality"] = 0
                features["avg_centrality"] = 0
                if G.number_of_nodes() > 0:
                    try:
                        degree_centrality = nx.degree_centrality(G)
                        if degree_centrality:
                            features["max_centrality"] = max(degree_centrality.values())
                            features["avg_centrality"] = (
                                sum(degree_centrality.values()) / len(degree_centrality)
                            )
                    except Exception:
                        features["max_centrality"] = 0
                        features["avg_centrality"] = 0

                # --- Node features ----------------------------------------
                if hasattr(data, 'x_cat') and hasattr(data, 'x_phys'):
                    # Element distribution: column 0 of x_cat.
                    atom_types = {}
                    if data.x_cat.size(1) > 0:
                        atom_types = dict(
                            Counter(data.x_cat[:data.num_nodes, 0].long().tolist())
                        )
                    features["atom_type_distribution"] = atom_types

                    # Mean / std of each named physical column.
                    if data.x_phys.size(1) > 0:
                        phys_means = data.x_phys.mean(dim=0).tolist()
                        phys_stds = data.x_phys.std(dim=0).tolist()
                        for name, mean, std in zip(phys_prop_names, phys_means, phys_stds):
                            properties[f"avg_{name}"] = mean
                            properties[f"std_{name}"] = std

                # --- Ring analysis ----------------------------------------
                cycles = list(nx.cycle_basis(G))
                cycle_count = len(cycles)
                ring_info["ring_counts"]["total"] = cycle_count

                ring_sizes = Counter(str(len(cycle)) for cycle in cycles)
                # Always report the common ring sizes, even when absent.
                for size in range(3, 11):
                    ring_sizes.setdefault(str(size), 0)
                ring_info["ring_sizes"] = dict(ring_sizes)

                # A ring counts as "fused" if it shares a node with another ring.
                node_to_cycles = defaultdict(list)
                for cycle_idx, cycle in enumerate(cycles):
                    for node in cycle:
                        node_to_cycles[node].append(cycle_idx)
                shared_cycles = {
                    cycle_idx
                    for cycle_list in node_to_cycles.values()
                    if len(cycle_list) > 1
                    for cycle_idx in cycle_list
                }
                ring_info["ring_counts"]["single"] = cycle_count - len(shared_cycles)
                ring_info["ring_counts"]["fused"] = len(shared_cycles)

                # --- Edge features ----------------------------------------
                edge_attr = getattr(data, 'edge_attr', None)
                if edge_attr is not None and edge_attr.size(0) > 0:
                    # Bond types: column 0. Halved because every bond appears
                    # once per direction.
                    bond_types = {}
                    if edge_attr.size(1) > 0:
                        directed_counts = Counter(edge_attr[:, 0].long().tolist())
                        bond_types = {bt: n // 2 for bt, n in directed_counts.items()}
                    functional_groups["bond_types"] = bond_types

                    # Conjugation flag: column 2 (only if that column exists).
                    conjugated_bonds = 0
                    if edge_attr.size(1) > 2:
                        conjugated_bonds = int((edge_attr[:, 2] > 0).sum().item())
                    functional_groups["conjugated_bonds"] = conjugated_bonds // 2

            except Exception as e:
                # Keep whatever was computed so far and report the failing graph.
                print(f"Error analyzing graph {graph_idx}: {e}")

        metadata.append({
            "graph_id": mol_id,
            "properties": properties,
            "features": features,
            "functional_groups": functional_groups,
            "ring_info": ring_info
        })

    return metadata

def to_networkx(data):
    """Build an undirected networkx graph from a PyG data object.

    Nodes are ``0 .. data.num_nodes - 1``. Edges come from the columns of
    ``data.edge_index``; self-loops are dropped and the two directions of a
    bond collapse into one undirected edge.
    """
    import networkx as nx

    graph = nx.Graph()
    graph.add_nodes_from(range(data.num_nodes))

    sources, targets = data.edge_index.cpu().numpy()
    # An undirected Graph stores (u, v) and (v, u) as the same edge, so
    # filtering self-loops is the only explicit de-duplication needed.
    graph.add_edges_from(
        (u, v) for u, v in zip(sources, targets) if u != v
    )

    return graph

In [None]:
def save_embedding_file(embeddings, molecule_indices, training_info, model_config, filepath):
    """Pickle embeddings together with their training metadata.

    Only the public, non-callable attributes of ``model_config`` are stored.
    """
    public_config = {}
    for key, value in model_config.__dict__.items():
        if key.startswith('_') or callable(value):
            continue
        public_config[key] = value

    payload = dict(
        embeddings=embeddings,
        molecule_indices=molecule_indices,
        training_info=training_info,
        model_config=public_config,
    )

    with open(filepath, 'wb') as handle:
        pickle.dump(payload, handle)

def save_embeddings_with_molecules(embeddings, dataset, filepath):
    """Pickle embeddings alongside per-molecule graph data and simple statistics.

    For every graph the raw tensors are stored as nested lists (``None`` when
    an attribute is missing or unset). Where they can be computed, graph
    density, average degree and atom/bond type counts are added. Bond type
    counts are taken over the rows of ``edge_attr`` as stored, i.e. they are
    not halved for bidirectional edges.

    Args:
        embeddings: object to store under the ``'embeddings'`` key.
        dataset: iterable of graph objects exposing ``num_nodes``.
        filepath: destination of the pickle file.
    """
    from collections import Counter

    def _as_list(graph, name):
        # hasattr() alone is not enough: an attribute can exist and be None.
        value = getattr(graph, name, None)
        return value.tolist() if value is not None else None

    molecule_data = []

    for data in dataset:
        mol_info = {
            "num_nodes": data.num_nodes,
            "edge_index": _as_list(data, 'edge_index'),
            "x_cat": _as_list(data, 'x_cat'),
            "x_phys": _as_list(data, 'x_phys'),
            "edge_attr": _as_list(data, 'edge_attr')
        }

        # Derived statistics are best-effort: a failure is reported and the
        # molecule is still saved with its raw data.
        try:
            edge_index = getattr(data, 'edge_index', None)
            if edge_index is not None and hasattr(data, 'num_nodes'):
                num_nodes = data.num_nodes
                num_edges = len(edge_index[0]) // 2  # Undirected edges counted once
                max_edges = num_nodes * (num_nodes - 1) // 2
                mol_info["graph_density"] = num_edges / max_edges if max_edges > 0 else 0
                mol_info["avg_degree"] = num_edges * 2 / num_nodes if num_nodes > 0 else 0

                # Atom types: first column of x_cat
                x_cat = getattr(data, 'x_cat', None)
                if x_cat is not None:
                    mol_info["atom_type_counts"] = dict(
                        Counter(int(atom[0]) for atom in x_cat)
                    )

                # Bond types: first column of edge_attr
                edge_attr = getattr(data, 'edge_attr', None)
                if edge_attr is not None:
                    mol_info["bond_type_counts"] = dict(
                        Counter(int(bond[0]) for bond in edge_attr)
                    )
        except Exception as exc:
            print(f"Could not compute graph properties for a molecule: {exc}")

        molecule_data.append(mol_info)

    # Save both embeddings and molecule data
    with open(filepath, 'wb') as f:
        pickle.dump({
            'embeddings': embeddings,
            'molecule_data': molecule_data,
            'graph_properties': True  # Flag to indicate enhanced properties are stored
        }, f)

    print(f"Saved embeddings and molecule data with graph properties to {filepath}")


def save_embeddings(embeddings, labels, filepath):
    """Pickle embeddings together with their labels.

    The file at ``filepath`` is (over)written with a single dict holding
    the two arguments under the keys ``'embeddings'`` and ``'labels'``.
    """
    payload = dict(embeddings=embeddings, labels=labels)
    with open(filepath, 'wb') as out_file:
        pickle.dump(payload, out_file)

def save_encoder(encoder, save_path, info=None):
    """Write an encoder's weights plus descriptive info to ``save_path``.

    The saved object is a dict with the encoder's ``state_dict()`` under
    ``'encoder_state_dict'`` and ``info`` under ``'model_info'``; a missing
    or empty ``info`` is stored as an empty dict.
    """
    model_info = info if info else {}
    torch.save(
        dict(encoder_state_dict=encoder.state_dict(), model_info=model_info),
        save_path,
    )

def load_encoder(model_path, device='cpu'):
    """Load an encoder saved as a dict with weights and model info.

    Args:
        model_path: Path to a checkpoint containing ``'encoder_state_dict'``
            and (optionally) ``'model_info'``.
        device: Device the checkpoint tensors are mapped to and on which the
            returned encoder is placed.

    Returns:
        A ``GraphDiscriminator`` built from the dimensions stored in
        ``'model_info'`` with the saved weights loaded, on ``device``.
    """
    checkpoint = torch.load(model_path, map_location=device)
    # A checkpoint written without info stores an empty dict; tolerate a
    # missing key the same way instead of failing with a bare KeyError.
    model_info = checkpoint.get('model_info') or {}
    encoder = GraphDiscriminator(
        node_dim=model_info.get('node_dim'),
        edge_dim=model_info.get('edge_dim'),
        hidden_dim=model_info.get('hidden_dim', 128),
        output_dim=model_info.get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    # map_location only places the checkpoint tensors; load_state_dict copies
    # them into the freshly built (CPU) module, so move the module explicitly
    # to honour the requested device.
    encoder = encoder.to(device)
    return encoder

def _collect_embeddings(model, loader, device, desc=None):
    """Return ``model.get_embeddings`` over all batches of ``loader`` as one array.

    The model is switched to eval mode and gradients are disabled; the caller
    decides whether to switch back to train mode afterwards. A progress bar is
    shown only when ``desc`` is given.
    """
    model.eval()
    batches = loader if desc is None else tqdm(loader, desc=desc)
    chunks = []
    with torch.no_grad():
        for batch in batches:
            batch = batch.to(device)
            chunks.append(model.get_embeddings(batch).cpu())
    return torch.cat(chunks, dim=0).numpy()


def train_gan_cl(train_loader, config, dataset, device='cuda',
                save_dir='./checkpoints',
                embedding_dir='./embeddings',
                pretrain_epochs=10,
                train_epochs=50):
    """Main training function for GAN-CL with embedding storage for bias analysis.

    Runs a contrastive pre-training phase followed by the adversarial GAN-CL
    phase, writing checkpoints, encoder files and embedding snapshots along
    the way.

    Args:
        train_loader: Loader that yields training batches.
        config: Model configuration passed to ``MolecularGANCL`` and
            ``save_embedding_file``; its ``node_dim``, ``edge_dim``,
            ``hidden_dim``, ``output_dim`` and ``__dict__`` are stored with
            every saved encoder.
        dataset: The graphs used for metadata extraction. Embedding snapshots
            are computed over this dataset in its original order.
        device: Device the model and batches are moved to.
        save_dir: Directory for training checkpoints; encoders go to its
            ``encoders`` sub-directory.
        embedding_dir: Directory for embedding snapshots; metadata goes to its
            ``metadata`` sub-directory.
        pretrain_epochs: Number of contrastive pre-training epochs.
        train_epochs: Number of GAN-CL epochs.

    Returns:
        Tuple ``(model, metrics)`` where ``metrics`` maps
        ``'contrastive_losses'``, ``'adversarial_losses'``,
        ``'similarity_losses'`` and ``'total_losses'`` to per-epoch averages.
        The model is left in eval mode.
    """

    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    metadata_dir = os.path.join(embedding_dir, 'metadata')
    os.makedirs(metadata_dir, exist_ok=True)
    encoder_dir = os.path.join(save_dir, 'encoders')
    os.makedirs(encoder_dir, exist_ok=True)

    # Extract and save molecule metadata (once, before training)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    metadata = extract_molecule_metadata(dataset)
    with open(os.path.join(metadata_dir, f'molecule_metadata_{timestamp}.pkl'), 'wb') as f:
        pickle.dump(metadata, f)

    # Save molecule indices for consistent order
    molecule_indices = list(range(len(metadata)))

    # Embedding snapshots are stored against molecule_indices, i.e. dataset
    # order. The training loader may shuffle, so snapshots are taken through
    # a dedicated loader that never shuffles.
    eval_loader = DataLoader(
        dataset,
        batch_size=getattr(train_loader, 'batch_size', None) or 32,
        shuffle=False
    )

    # Initialize model
    model = MolecularGANCL(config).to(device)

    # Get pre-training embeddings before any training
    print("Extracting pre-training embeddings...")
    pre_training_embeddings = _collect_embeddings(
        model, eval_loader, device, desc="Pre-training embeddings"
    )

    # Save pre-training embeddings
    pre_training_info = {
        "stage": "pre",
        "epoch": 0,
        "timestamp": timestamp,
        "loss_values": {"contrastive": 0, "adversarial": 0, "similarity": 0, "total": 0}
    }

    save_embedding_file(
        pre_training_embeddings,
        molecule_indices,
        pre_training_info,
        config,
        os.path.join(embedding_dir, f'pre_training_embeddings_{timestamp}.pkl')
    )

    model.train()

    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)

    # Model info stored next to every saved encoder so it can be rebuilt
    model_info = {
        'node_dim': config.node_dim,
        'edge_dim': config.edge_dim,
        'hidden_dim': config.hidden_dim,
        'output_dim': config.output_dim,
        'training_config': config.__dict__
    }

    best_loss = float('inf')

    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }

    # Training phases
    print("Phase 1: Pretraining Contrastive Learning...")
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0

        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)

            # Forward pass (without generator)
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)

            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )

            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()

            # Update momentum encoder
            model._momentum_update()

            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())

            contrastive_epoch_loss += contrastive_loss.item()

        avg_loss = contrastive_epoch_loss / len(train_loader)
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')

        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))

    print("\nPhase 2: Training GAN-CL...")
    # Defined up front so the post-training snapshot still has loss values
    # when train_epochs is 0.
    epoch_losses = {
        'contrastive': 0,
        'adversarial': 0,
        'similarity': 0,
        'total': 0
    }
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }

        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)

            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()

            # Get importance scores from generator
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)

            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()

            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = F.mse_loss(query_emb, original_emb)

            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss

            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()

            # Update momentum encoder
            model._momentum_update()

            # Step 2: Train Generator
            optimizer_generator.zero_grad()

            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)

            # Compute adversarial loss (the generator maximises the distance)
            adversarial_loss = -F.mse_loss(perturbed_emb, original_emb)

            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()

            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())

            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()

        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(train_loader)
            metrics[f'{k}_losses'].append(epoch_losses[k])

        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')

        # Keep the encoder with the lowest average total loss seen so far
        if epoch_losses['total'] < best_loss:
            best_loss = epoch_losses['total']
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, 'best_encoder.pt'),
                {**model_info, 'epoch': epoch + 1, 'loss': best_loss}
            )

        if (epoch + 1) % 10 == 0:
            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))

            # Save periodic encoder
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'encoder_epoch_{epoch+1}.pt'),
                {**model_info, 'epoch': epoch + 1}
            )

            # Extract and save embeddings periodically. This has to live
            # inside the epoch loop, otherwise only the last epoch is stored.
            checkpoint_embeddings = _collect_embeddings(model, eval_loader, device)

            # Save checkpoint embeddings with training info
            checkpoint_info = {
                "stage": f"epoch_{epoch+1}",
                "epoch": epoch + 1,
                "timestamp": datetime.now().strftime('%Y%m%d_%H%M%S'),
                "loss_values": dict(epoch_losses)
            }

            save_embedding_file(
                checkpoint_embeddings,
                molecule_indices,
                checkpoint_info,
                config,
                os.path.join(embedding_dir, f'epoch_{epoch+1}_embeddings_{timestamp}.pkl')
            )

            model.train()

    # Save the encoder as it is at the end of training
    save_encoder(
        model.encoder,
        os.path.join(encoder_dir, 'final_encoder.pt'),
        {**model_info, 'epoch': train_epochs}
    )

    # Extract and save post-training embeddings at the end
    print("Extracting post-training embeddings...")
    post_training_embeddings = _collect_embeddings(
        model, eval_loader, device, desc="Post-training embeddings"
    )

    # Save post-training embeddings
    post_training_info = {
        "stage": "post",
        "epoch": train_epochs,
        "timestamp": datetime.now().strftime('%Y%m%d_%H%M%S'),
        "loss_values": epoch_losses
    }

    save_embedding_file(
        post_training_embeddings,
        molecule_indices,
        post_training_info,
        config,
        os.path.join(embedding_dir, f'post_training_embeddings_{timestamp}.pkl')
    )

    return model, metrics

In [None]:
def main(smiles_file=None, batch_size=32):
    """Load SMILES, train GAN-CL and save the final embeddings.

    Args:
        smiles_file: Text file with one SMILES string per line. When omitted,
            the development default path below is used.
        batch_size: Batch size for training and for embedding extraction.

    Returns:
        ``None`` when no molecule could be converted to a graph, otherwise the
        tuple ``(model, metrics, all_embeddings, all_graphs)`` where row ``i``
        of ``all_embeddings`` belongs to ``all_graphs[i]``.
    """
    # Enable anomaly detection during development
    torch.autograd.set_detect_anomaly(True)
    torch.manual_seed(42)
    np.random.seed(42)

    print("Starting data loading...")
    extractor = MolecularFeatureExtractor()
    if smiles_file is None:
        # Other files used during development: pubchem-41-clean.txt and
        # pubchem-10m-clean_test50k.txt in the same directory.
        smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-10m-clean_test10k.txt"

    dataset = []
    failed_smiles = []

    with open(smiles_file, 'r') as f:
        for line in f:
            smiles = line.strip()
            if not smiles:
                # Blank lines are not molecules; do not feed them to the extractor
                continue
            data = extractor.process_molecule(smiles)
            if data is not None:
                dataset.append(data)
            else:
                failed_smiles.append(smiles)

    print(f"1. Loaded dataset with {len(dataset)} graphs.")
    print(f"2. Failed SMILES count: {len(failed_smiles)}")

    if not dataset:
        print("No valid graphs generated.")
        return None

    # Setup training
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"3. Created DataLoader with {len(train_loader.dataset)} graphs")

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"4. Using device: {device}")

    # Get configuration based on dataset
    config = get_model_config(dataset)

    # Train model
    print("5. Starting GAN-CL training...")
    model, metrics = train_gan_cl(
        train_loader,
        config,
        dataset,  # Pass the dataset for metadata extraction
        device=device,
        save_dir='./checkpoints',
        embedding_dir='./embeddings'
    )

    print("6. Training completed!")

    # Extract embeddings for XAI
    print("7. Extracting final embeddings for XAI...")
    # The embeddings are saved next to `dataset`, so they must be computed in
    # dataset order; the shuffled training loader would scramble the rows.
    eval_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model.eval()
    all_embeddings = []
    with torch.no_grad():
        for batch in tqdm(eval_loader, desc="Extracting embeddings"):
            batch = batch.to(device)
            embeddings = model.get_embeddings(batch)
            all_embeddings.append(embeddings.cpu())

    all_embeddings = torch.cat(all_embeddings, dim=0).numpy()
    # Iterating a batch yields (attribute name, value) pairs rather than
    # graphs, so the graphs are taken from the dataset itself; the order
    # matches the embeddings because eval_loader does not shuffle.
    all_graphs = list(dataset)

    # Save final embeddings together with the molecule graphs
    final_embedding_path = f'./embeddings/final_embeddings_molecules_{timestamp}.pkl'
    save_embeddings_with_molecules(all_embeddings, dataset, final_embedding_path)
    print(f"8. Final embeddings saved to {final_embedding_path}")

    # Print encoder locations
    print("9. Encoders saved in ./checkpoints/encoders/:")
    print("   - Best encoder: best_encoder.pt")
    print("   - Final encoder: final_encoder.pt")
    print("   - Periodic encoders: encoder_epoch_*.pt")

    return model, metrics, all_embeddings, all_graphs

if __name__ == "__main__":
    # main() returns None when no valid graph could be built; unpacking that
    # directly would raise a TypeError.
    result = main()
    if result is None:
        model = metrics = embeddings = graphs = None
    else:
        model, metrics, embeddings, graphs = result