In [None]:
import numpy as np
import networkx as nx
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import eigsh
from collections import Counter
from sklearn.base import BaseEstimator, TransformerMixin
import warnings

class MolecularGraphFeatureExtractor(BaseEstimator, TransformerMixin):
    """
    Integrated molecular graph feature extractor combining:
    - Chemical Laplacian (global chemical properties)
    - WL Histograms (local substructure patterns)
    - Topological features (graph structure)
    - Bridge tree features (hierarchical structure)

    Input graphs are undirected ``networkx.Graph`` objects. Every node must
    carry an ``atom_type`` attribute; edges may carry a ``bond_order``
    attribute (treated as 1.0 when absent).

    Parameters
    ----------
    n_eigen : int
        Number of chemical-Laplacian eigenvalues kept per graph.
    wl_iterations : int
        Number of Weisfeiler-Lehman refinement rounds.
    dataset_name : str or None
        Selects the integer-code -> atom symbol / bond order tables exposed
        as ``ATOM_TYPES`` and ``BOND_TYPES``.

    Attributes set by ``fit``
    -------------------------
    vocab_per_iteration_ : list of dict
        For each WL round, signature string -> histogram column index.
    vocab_size_per_iteration_ : list of int
        Size of each of those vocabularies.
    """

    # Expanded atom type weights (α parameter) - Based on electronegativity and atomic properties
    ATOM_WEIGHTS = {
        # Common organic atoms
        'H': 0.5,   'C': 1.0,   'N': 1.5,   'O': 2.0,   'S': 2.5,   'P': 2.1,
        'B': 1.4,   'Si': 1.3,  'Se': 2.4,  'As': 2.3,  'Sb': 2.0,  'Te': 2.1,
        'Ge': 1.8,
        # Halogens
        'F': 1.8,   'Cl': 1.8,  'Br': 2.0,  'I': 2.2,
        # Alkali metals
        'Li': 1.0,  'Na': 1.2,  'K': 1.1,
        # Alkaline earth metals
        'Mg': 1.3,  'Ca': 1.2,  'Ba': 1.1,
        # Transition metals
        'Cu': 1.9,  'Zn': 1.6,  'Co': 1.8,  'Ni': 1.9,  'Pd': 2.2,  'Pt': 2.3,
        'Au': 2.4,  'Ru': 2.2,  'Rh': 2.3,
        # Post-transition metals
        'Sn': 1.7,  'Pb': 1.9,  'In': 1.7,  'Tl': 1.8,  'Ga': 1.6,  'Bi': 1.9,
        # Lanthanides/Actinides
        'Ho': 1.2,  'Tb': 1.2,
        # Others
        'W': 2.4,   'Hg': 2.0,
    }

    # Length of the vector returned by _extract_topological_features:
    # nodes, edges, density, 4 degree statistics, clustering, components,
    # triangles, diameter, average path length, 2 degree percentiles.
    N_TOPOLOGICAL_FEATURES = 14

    def __init__(self, n_eigen=10, wl_iterations=3, dataset_name=None):
        self.n_eigen = n_eigen
        self.wl_iterations = wl_iterations
        self.dataset_name = dataset_name
        self._set_dataset_mappings()

    def _set_dataset_mappings(self):
        """Set atom and bond type mappings based on dataset"""
        if self.dataset_name == "MUTAG":
            self.ATOM_TYPES = {0: 'C', 1: 'N', 2: 'O', 3: 'F', 4: 'I', 5: 'Cl', 6: 'Br'}
            self.BOND_TYPES = {0: 1.5, 1: 1.0, 2: 2.0, 3: 3.0}  # aromatic, single, double, triple

        elif self.dataset_name == "PTC_MR":
            self.ATOM_TYPES = {
                0: 'In', 1: 'P', 2: 'O', 3: 'N', 4: 'Na', 5: 'C', 6: 'Cl', 7: 'S',
                8: 'Br', 9: 'F', 10: 'K', 11: 'Cu', 12: 'Zn', 13: 'I', 14: 'Ba',
                15: 'Sn', 16: 'Pb', 17: 'Ca'
            }
            self.BOND_TYPES = {0: 3.0, 1: 2.0, 2: 1.0, 3: 1.5}  # triple, double, single, aromatic

        elif self.dataset_name == "AIDS":
            self.ATOM_TYPES = {
                0: 'C', 1: 'O', 2: 'N', 3: 'Cl', 4: 'F', 5: 'S', 6: 'Se', 7: 'P',
                8: 'Na', 9: 'I', 10: 'Co', 11: 'Br', 12: 'Li', 13: 'Si', 14: 'Mg',
                15: 'Cu', 16: 'As', 17: 'B', 18: 'Pt', 19: 'Ru', 20: 'K', 21: 'Pd',
                22: 'Au', 23: 'Te', 24: 'W', 25: 'Rh', 26: 'Zn', 27: 'Bi', 28: 'Pb',
                29: 'Ge', 30: 'Sb', 31: 'Sn', 32: 'Ga', 33: 'Hg', 34: 'Ho', 35: 'Tl',
                36: 'Ni', 37: 'Tb'
            }
            self.BOND_TYPES = {0: 1.5, 1: 1.0, 2: 2.0, 3: 3.0}  # aromatic, single, double, triple
        else:
            # Default mappings
            self.ATOM_TYPES = {0: 'C', 1: 'N', 2: 'O', 3: 'F', 4: 'I', 5: 'Cl', 6: 'Br'}
            self.BOND_TYPES = {0: 1.5, 1: 1.0, 2: 2.0, 3: 3.0}

    # ========== CHEMICAL LAPLACIAN METHODS ==========

    @staticmethod
    def _bond_order(edge_labels, u, v):
        """Return the bond order stored for edge {u, v} in either orientation (1.0 if absent)."""
        key = (u, v) if (u, v) in edge_labels else (v, u)
        return edge_labels.get(key, 1.0)

    def _build_chemical_laplacian(self, G, node_labels, edge_labels):
        """
        Build integrated chemical Laplacian:
        L_chem(i,j) = {
            α(v_i) · Σw(i,k)           if i = j
            -γ(w(i,j)) · w(i,j)        if (i,j) ∈ E
            0                           otherwise
        }

        α comes from ATOM_WEIGHTS (1.0 for unknown atom types, 'C' for
        unlabeled nodes); γ is 1.2 for bond order 1.5 and 1.0 otherwise.

        Parameters
        ----------
        G : networkx.Graph
        node_labels : dict
            node -> atom type symbol.
        edge_labels : dict or None
            (u, v) -> bond order, in either orientation. ``None`` means every
            edge has bond order 1.0.

        Returns
        -------
        numpy.ndarray
            Dense symmetric (n, n) matrix, rows/columns in ``G.nodes()`` order.
        """
        n = G.number_of_nodes()
        node_list = list(G.nodes())
        node_to_idx = {node: idx for idx, node in enumerate(node_list)}

        L = np.zeros((n, n))

        if edge_labels is None:
            edge_labels = {edge: 1.0 for edge in G.edges()}

        # Diagonal: α(v_i) · Σw(i,k)
        for node in node_list:
            atom_type = node_labels.get(node, 'C')
            alpha = self.ATOM_WEIGHTS.get(atom_type, 1.0)

            weighted_degree = 0.0
            for neighbor in G.neighbors(node):
                weighted_degree += self._bond_order(edge_labels, node, neighbor)

            i = node_to_idx[node]
            L[i, i] = alpha * weighted_degree

        # Off-diagonal: -γ(w(i,j)) · w(i,j)
        for u, v in G.edges():
            i, j = node_to_idx[u], node_to_idx[v]
            bond_order = self._bond_order(edge_labels, u, v)
            gamma = 1.2 if abs(bond_order - 1.5) < 0.01 else 1.0

            weight = gamma * bond_order
            L[i, j] = -weight
            L[j, i] = -weight

        return L

    def _extract_chemical_laplacian_features(self, G, node_labels, edge_labels):
        """
        Extract spectral features from Chemical Laplacian.

        Returns a vector of length ``n_eigen + 2``: the ``n_eigen``
        smallest-magnitude eigenvalues (sorted ascending, zero-padded,
        divided by ``n + 1``), then trace / n² and the second sorted
        eigenvalue divided by ``n + 1``. Graphs with at most two nodes, and
        graphs for which the eigen-solver fails, yield an all-zero vector.
        """
        try:
            n = G.number_of_nodes()

            if n <= 2:
                return np.zeros(self.n_eigen + 2)

            L_chem = self._build_chemical_laplacian(G, node_labels, edge_labels)

            # eigsh requires k < n.
            k = min(self.n_eigen, n - 1)
            L_sparse = csr_matrix(L_chem)
            eigenvalues = eigsh(L_sparse, k=k, which='SM', return_eigenvectors=False)
            eigenvalues = np.sort(eigenvalues)

            if len(eigenvalues) < self.n_eigen:
                eigenvalues = np.concatenate([
                    eigenvalues,
                    np.zeros(self.n_eigen - len(eigenvalues))
                ])

            eigenvalues_norm = eigenvalues[:self.n_eigen] / (n + 1)

            # Additional spectral features
            trace_norm = np.trace(L_chem) / (n * n)
            algebraic_conn = eigenvalues[1] / (n + 1) if len(eigenvalues) > 1 else 0

            return np.concatenate([eigenvalues_norm, [trace_norm, algebraic_conn]])

        except Exception:
            # Best effort: a solver failure must not abort the whole
            # transform. KeyboardInterrupt/SystemExit still propagate.
            return np.zeros(self.n_eigen + 2)

    # ========== WEISFEILER-LEHMAN METHODS ==========

    @staticmethod
    def _initial_wl_labels(G):
        """Return node -> str(atom_type), the round-0 WL labels."""
        return {node: str(G.nodes[node]['atom_type']) for node in G.nodes()}

    @staticmethod
    def _wl_refine(G, labels):
        """Run one WL round: own label followed by the sorted neighbor labels, '|'-joined."""
        refined = {}
        for node in G.nodes():
            neighbor_labels = sorted(labels[neighbor] for neighbor in G.neighbors(node))
            refined[node] = labels[node] + '|' + '|'.join(neighbor_labels)
        return refined

    def _generate_wl_signatures(self, G):
        """
        Generate WL signatures across all iterations.

        Returns a list of ``wl_iterations + 1`` sets; entry ``t`` holds every
        signature string present in ``G`` after ``t`` refinement rounds.
        Labels are stringified before the first round so that the signatures
        are built exactly as in ``_compute_wl_histogram``.
        """
        all_signatures_per_iteration = [set() for _ in range(self.wl_iterations + 1)]

        labels = self._initial_wl_labels(G)
        all_signatures_per_iteration[0].update(labels.values())

        for iteration in range(self.wl_iterations):
            labels = self._wl_refine(G, labels)
            all_signatures_per_iteration[iteration + 1].update(labels.values())

        return all_signatures_per_iteration

    def _compute_wl_histogram(self, G):
        """
        Compute WL histogram using learned vocabulary.

        For each round, counts how many nodes carry each vocabulary
        signature; signatures not seen during ``fit`` are ignored. The
        per-round histograms are concatenated into one float32 vector.
        """
        histograms = []
        labels = self._initial_wl_labels(G)

        for iteration in range(self.wl_iterations + 1):
            vocab = self.vocab_per_iteration_[iteration]
            vocab_size = self.vocab_size_per_iteration_[iteration]
            histogram = np.zeros(vocab_size, dtype=np.float32)

            for signature in labels.values():
                idx = vocab.get(signature)
                if idx is not None:
                    histogram[idx] += 1

            histograms.append(histogram)

            if iteration < self.wl_iterations:
                labels = self._wl_refine(G, labels)

        return np.concatenate(histograms)

    # ========== TOPOLOGICAL FEATURES ==========

    def _extract_topological_features(self, G):
        """
        Extract topology-based features.

        Returns a vector of length ``N_TOPOLOGICAL_FEATURES`` (14), in order:
        node count, edge count, density, degree mean / std / max / min,
        average clustering, connected-component count, triangle count,
        diameter, average shortest path length, degree 25th and 75th
        percentiles. Diameter and path length are 0 for disconnected or
        empty graphs.
        """
        features = []

        features.append(G.number_of_nodes())
        features.append(G.number_of_edges())
        features.append(nx.density(G))

        degrees = [d for n, d in G.degree()]
        if degrees:
            features.extend([np.mean(degrees), np.std(degrees),
                           np.max(degrees), np.min(degrees)])
        else:
            features.extend([0, 0, 0, 0])

        try:
            features.append(nx.average_clustering(G))
        except Exception:
            # e.g. a graph without nodes
            features.append(0)

        features.append(nx.number_connected_components(G))

        # nx.triangles counts each triangle once per corner.
        triangles = sum(nx.triangles(G).values()) / 3
        features.append(triangles)

        try:
            if nx.is_connected(G):
                features.append(nx.diameter(G))
                features.append(nx.average_shortest_path_length(G))
            else:
                features.append(0)
                features.append(0)
        except Exception:
            # e.g. connectivity is undefined for a graph without nodes
            features.append(0)
            features.append(0)

        if degrees:
            features.append(np.percentile(degrees, 25))
            features.append(np.percentile(degrees, 75))
        else:
            features.extend([0, 0])

        return np.array(features)

    def _extract_bridge_tree_features(self, G):
        """
        Extract features from bridge tree.

        The bridge tree has one node per connected component of ``G`` with
        its bridges removed, and one edge per bridge. Nodes of ``G`` that
        touch only bridges form single-node components. The result is the
        topological feature vector of that tree, or zeros of the same length
        if construction fails.
        """
        try:
            bridges = set(nx.bridges(G))
            non_bridge_edges = [e for e in G.edges()
                              if e not in bridges and (e[1], e[0]) not in bridges]

            H = nx.Graph()
            # Keep every node: without this, nodes incident only to bridges
            # would get no component and their bridges would vanish from
            # the tree.
            H.add_nodes_from(G.nodes())
            H.add_edges_from(non_bridge_edges)

            comp_map = {}
            for i, comp in enumerate(nx.connected_components(H)):
                for node in comp:
                    comp_map[node] = i

            BT = nx.Graph()
            BT.add_nodes_from(set(comp_map.values()))

            for u, v in bridges:
                BT.add_edge(comp_map[u], comp_map[v])

            return self._extract_topological_features(BT)
        except Exception:
            # Must match the length of the success path so that feature rows
            # stay rectangular.
            return np.zeros(self.N_TOPOLOGICAL_FEATURES)

    # ========== FIT/TRANSFORM INTERFACE ==========

    def fit(self, graphs, y=None):
        """
        Learn WL vocabulary from training graphs.

        Parameters
        ----------
        graphs : iterable of networkx.Graph
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self
        """
        all_signatures_per_iteration = [set() for _ in range(self.wl_iterations + 1)]

        for G in graphs:
            graph_signatures = self._generate_wl_signatures(G)
            for iteration in range(self.wl_iterations + 1):
                all_signatures_per_iteration[iteration].update(graph_signatures[iteration])

        self.vocab_per_iteration_ = []
        self.vocab_size_per_iteration_ = []

        for iteration in range(self.wl_iterations + 1):
            # Sorted so that column indices are deterministic.
            signatures = sorted(all_signatures_per_iteration[iteration])
            vocab = {sig: idx for idx, sig in enumerate(signatures)}
            self.vocab_per_iteration_.append(vocab)
            self.vocab_size_per_iteration_.append(len(vocab))

        return self

    def transform(self, graphs):
        """
        Transform graphs to feature matrix.

        Feature composition:
        1. Chemical Laplacian eigenvalues (n_eigen + 2)
        2. Topological features (14)
        3. Bridge tree features (14)
        4. WL Histograms (vocabulary-dependent)

        Requires a prior call to ``fit``. Returns an array of shape
        ``(n_graphs, n_features)``; an empty input yields ``(0, n_features)``.
        """
        features_list = []

        for G in graphs:
            # Extract node and edge labels
            node_labels = {n: G.nodes[n]['atom_type'] for n in G.nodes()}
            edge_labels = {(u, v): G.edges[u, v].get('bond_order', 1.0)
                          for u, v in G.edges()}

            # 1. Chemical Laplacian features
            chem_laplacian = self._extract_chemical_laplacian_features(G, node_labels, edge_labels)

            # 2. Topological features
            topology = self._extract_topological_features(G)

            # 3. Bridge tree features
            bridge_tree = self._extract_bridge_tree_features(G)

            # 4. WL histograms
            wl_hist = self._compute_wl_histogram(G)

            # Concatenate all features
            features = np.concatenate([chem_laplacian, topology, bridge_tree, wl_hist])
            features_list.append(features)

        if not features_list:
            # Keep the result two-dimensional even when there is nothing to
            # transform.
            width = (self.n_eigen + 2
                     + 2 * self.N_TOPOLOGICAL_FEATURES
                     + sum(self.vocab_size_per_iteration_))
            return np.zeros((0, width))

        return np.array(features_list)


def pyg_to_networkx(data, atom_types, bond_types):
    """
    Convert PyTorch Geometric Data object to NetworkX graph.

    Parameters
    ----------
    data : object with ``x``, ``edge_index`` and optionally ``edge_attr``
        ``x`` rows are reduced to an atom type code via ``argmax``;
        ``edge_index`` has one column per (src, dst) pair.
    atom_types : dict
        Atom type code -> label stored as the node's ``atom_type``
        ('C' for unknown codes).
    bond_types : dict
        Bond type code -> bond order stored as the edge's ``bond_order``
        (1.0 for unknown codes or when ``edge_attr`` is missing).

    Returns
    -------
    networkx.Graph
        Undirected graph with nodes ``0 .. num_nodes - 1``. Self-loops are
        skipped. When an edge is listed in both directions, the attributes of
        the ``src < dst`` entry are used; an edge listed only as
        ``src > dst`` is still added.
    """
    G = nx.Graph()

    # Add nodes with atom_type attribute
    num_nodes = data.x.size(0)
    for i in range(num_nodes):
        atom_type_code = data.x[i].argmax().item()  # Get atom type from one-hot encoding
        G.add_node(i, atom_type=atom_types.get(atom_type_code, 'C'))

    # Add edges with bond_order attribute
    edge_index = data.edge_index.numpy()
    edge_attr = None
    if hasattr(data, 'edge_attr') and data.edge_attr is not None:
        edge_attr = data.edge_attr.numpy()

    for idx in range(edge_index.shape[1]):
        src, dst = int(edge_index[0, idx]), int(edge_index[1, idx])
        if src == dst:
            continue
        if src > dst and G.has_edge(src, dst):
            # Reverse copy of an edge that is already present.
            continue

        if edge_attr is None:
            # No edge attributes, assume single bonds
            bond_order = 1.0
        else:
            attr = edge_attr[idx]
            # One-hot / vector attribute -> argmax; scalar attribute -> the code itself.
            bond_type_code = int(attr.argmax()) if attr.ndim > 0 else int(attr)
            bond_order = bond_types.get(bond_type_code, 1.0)

        # For src < dst this also overwrites attributes taken earlier from a
        # reverse-only entry, so the src < dst entry always wins.
        G.add_edge(src, dst, bond_order=bond_order)

    return G


def process_dataset(dataset, dataset_name, n_splits=10, n_eigen=10,
                    wl_iterations=3, random_state=42):
    """
    Process a single dataset with stratified k-fold CV (10-fold by default).

    For every outer fold the feature extractor and the scaler are fitted on
    the training graphs only, an RBF-SVM is tuned with an inner 5-fold grid
    search, and accuracy is measured on the held-out fold.

    Parameters
    ----------
    dataset : iterable of PyTorch Geometric Data objects
        Each item must expose ``y`` with a single label.
    dataset_name : str
        Printed in the log and used to select atom/bond code mappings.
    n_splits : int
        Number of outer cross-validation folds.
    n_eigen, wl_iterations : int
        Passed to ``MolecularGraphFeatureExtractor``.
    random_state : int
        Seed for the fold shuffling and for the SVM.

    Returns
    -------
    dict
        Keys ``dataset``, ``mean_test_acc``, ``std_test_acc``, ``best_C``
        and ``best_gamma`` (the last two are the values chosen most often
        across folds).
    """
    from sklearn.model_selection import StratifiedKFold, GridSearchCV
    from sklearn.svm import SVC
    from sklearn.preprocessing import StandardScaler

    print(f"\n{'='*70}")
    print(f"PROCESSING: {dataset_name}")
    print(f"{'='*70}")

    # Get dataset-specific mappings
    extractor_temp = MolecularGraphFeatureExtractor(dataset_name=dataset_name)
    atom_types = extractor_temp.ATOM_TYPES
    bond_types = extractor_temp.BOND_TYPES

    # Convert all PyG graphs to NetworkX
    print(f"Converting {len(dataset)} PyG graphs to NetworkX...")
    graphs = []
    labels = []
    for data in dataset:
        graphs.append(pyg_to_networkx(data, atom_types, bond_types))
        labels.append(data.y.item())

    y = np.array(labels)

    # Convert labels if needed: maps {-1, +1} onto {0, 1}
    if np.min(y) < 0:
        y = ((y + 1) / 2).astype(int)

    print(f"Total samples: {len(y)}")
    print(f"Class distribution: {np.bincount(y)}")

    # Outer stratified cross-validation
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Grid of SVM hyperparameters searched inside every fold
    param_grid = {
        'C': [0.1, 1, 10, 100],
        'gamma': ['scale', 'auto', 0.001, 0.01, 0.1],
        'kernel': ['rbf']
    }

    test_accs = []
    best_params_per_fold = []

    for fold, (train_idx, test_idx) in enumerate(skf.split(graphs, y), 1):
        # Split graphs and labels
        train_graphs = [graphs[i] for i in train_idx]
        test_graphs = [graphs[i] for i in test_idx]
        y_train = y[train_idx]
        y_test = y[test_idx]

        # Fit feature extractor on training data ONLY
        extractor = MolecularGraphFeatureExtractor(
            n_eigen=n_eigen, wl_iterations=wl_iterations, dataset_name=dataset_name)
        extractor.fit(train_graphs)

        # Transform both train and test
        X_train = extractor.transform(train_graphs)
        X_test = extractor.transform(test_graphs)

        # Scale features
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        # Grid search for best hyperparameters
        grid_search = GridSearchCV(
            SVC(random_state=random_state),
            param_grid,
            cv=5,
            scoring='accuracy',
            n_jobs=-1,
            verbose=0
        )

        grid_search.fit(X_train_scaled, y_train)
        best_params_per_fold.append(grid_search.best_params_)

        # Evaluate on test set
        best_svm = grid_search.best_estimator_
        test_acc = best_svm.score(X_test_scaled, y_test)
        test_accs.append(test_acc)

        print(f"Fold {fold}/{n_splits}: Test Acc = {test_acc:.3f}")

    # Find most common hyperparameters
    c_values = [p['C'] for p in best_params_per_fold]
    gamma_values = [p['gamma'] for p in best_params_per_fold]

    best_c = Counter(c_values).most_common(1)[0][0]
    best_gamma = Counter(gamma_values).most_common(1)[0][0]

    mean_acc = np.mean(test_accs)
    std_acc = np.std(test_accs)

    return {
        'dataset': dataset_name,
        'mean_test_acc': mean_acc,
        'std_test_acc': std_acc,
        'best_C': best_c,
        'best_gamma': best_gamma
    }


if __name__ == "__main__":
    # Script entry point: load the benchmark datasets, run cross-validated
    # classification on each one that loaded, and print a summary table.
    from torch_geometric.datasets import TUDataset

    warnings.filterwarnings('ignore')
    np.random.seed(42)

    print("="*70)
    print("QUANTUM HACKATHON: MOLECULAR CLASSIFICATION (3 Datasets)")
    print("="*70)
    print("\n[1] LOADING DATASETS...")

    datasets = []
    # Names of the datasets that actually loaded, index-aligned with
    # `datasets`. Zipping against the full `names` list would pair datasets
    # with the wrong name (and wrong atom/bond mappings) after a failed load.
    loaded_names = []
    names = ["MUTAG", "PTC_MR", "AIDS"]

    for NAME in names:
        try:
            dataset = TUDataset(root=f'/tmp/{NAME}', name=NAME)
            datasets.append(dataset)
            loaded_names.append(NAME)
            print(f"✓ Dataset {NAME} loaded successfully")
            print(f"  - Number of graphs: {len(dataset)}")
            print(f"  - Number of classes: {dataset.num_classes}")
            print(f"  - Number of node features: {dataset.num_node_features}")
        except Exception as e:
            print(f"✗ Error loading dataset {NAME}: {e}")
            continue

    # Process each dataset
    results = []
    for dataset, name in zip(datasets, loaded_names):
        result = process_dataset(dataset, name)
        results.append(result)

    # Print final summary
    print("\n" + "="*70)
    print("FINAL SUMMARY")
    print("="*70)
    print(f"\n{'Dataset':<15} {'Mean Test Acc':<20} {'Best C':<10} {'Best Gamma'}")
    print("-"*70)
    for r in results:
        print(f"{r['dataset']:<15} {r['mean_test_acc']:.3f} ± {r['std_test_acc']:.3f}         "
              f"{r['best_C']:<10} {r['best_gamma']}")

    print("\n" + "="*70)

QUANTUM HACKATHON: MOLECULAR CLASSIFICATION (3 Datasets)

[1] LOADING DATASETS...
✓ Dataset MUTAG loaded successfully
  - Number of graphs: 188
  - Number of classes: 2
  - Number of node features: 7
✓ Dataset PTC_MR loaded successfully
  - Number of graphs: 344
  - Number of classes: 2
  - Number of node features: 18
✓ Dataset AIDS loaded successfully
  - Number of graphs: 2000
  - Number of classes: 2
  - Number of node features: 38

PROCESSING: MUTAG
Converting 188 PyG graphs to NetworkX...
Total samples: 188
Class distribution: [ 63 125]
Fold 1/10: Test Acc = 0.789
Fold 2/10: Test Acc = 0.789
Fold 3/10: Test Acc = 0.789
Fold 4/10: Test Acc = 0.895
Fold 5/10: Test Acc = 0.947
Fold 6/10: Test Acc = 0.789
Fold 7/10: Test Acc = 0.684
Fold 8/10: Test Acc = 0.895
Fold 9/10: Test Acc = 0.944
Fold 10/10: Test Acc = 0.722

PROCESSING: PTC_MR
Converting 344 PyG graphs to NetworkX...
Total samples: 344
Class distribution: [192 152]
Fold 1/10: Test Acc = 0.600
Fold 2/10: Test Acc = 0.743
Fold 