In [None]:

import itertools
import math
from collections import OrderedDict, Counter
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple, List
from typing import Iterable

import matplotlib.pyplot as plt
import networkx as nx
import torch
from pytorch_lightning import seed_everything
from rdkit import Chem
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torchhd import HRRTensor

from src.datasets.zinc_pairs import pyg_to_nx
from src.datasets.zinc_smiles_generation import ZincSmiles
from src.encoding.configs_and_constants import DatasetConfig, Features, FeatureConfig, IndexRange
from src.encoding.feature_encoders import CombinatoricIntegerEncoder
from src.encoding.graph_encoders import HyperNet, load_or_create_hypernet
from src.encoding.oracles import Oracle
from src.encoding.the_types import VSAModel
from src.utils.utils import GLOBAL_MODEL_PATH

"""
Features
    Atom types size: 9
    Atom types: ['Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S']
    Degrees size: 5, encoded with 0 index:
    Degrees: {1, 2, 3, 4, 5}
    Formal Charges size: 3
    Formal Charges: {0, 1, -1}
    Explicit Hs size: 4
    Explicit Hs: {0, 1, 2, 3}
Encodings:
    float(ZINC_SMILE_ATOM_TO_IDX[atom.GetSymbol()]),
    float(atom.GetDegree() - 1),  # [1, 2, 3, 4, 5] -> [0, 1, 2, 3, 4]
    float(atom.GetFormalCharge() if atom.GetFormalCharge() >= 0 else 2),  # [0, 1, -1] -> [0, 1, 2]
    float(atom.GetTotalNumHs()),
"""

In [None]:
seed = 42

seed_everything(seed)
device = torch.device('cpu')

# ----- HyperNet config -----
# The hypernet built here is the graph encoder used by the decoding/evaluation flow of this notebook.

# Bin sizes of the four node-feature columns combined into one combinatorial index
# (presumably ordered: atom type, degree, formal charge, total Hs — matching the encoding notes).
# NOTE(review): the degree bin count here is 6, while the header notes list 5 degree values — verify.
zinc_feature_bins = [9, 6, 3, 4]
hv_dim = 88 * 88  # 7744
# Derive the name from hv_dim so the two cannot drift apart ("ZincSmilesHRR7744").
ds_name = f"ZincSmilesHRR{hv_dim}"
dataset_config = DatasetConfig(
    seed=seed,  # reuse the notebook seed instead of a second hard-coded 42
    name=ds_name,
    vsa=VSAModel.HRR,
    hv_dim=hv_dim,
    device=device,
    node_feature_configs=OrderedDict(
        [
            (
                Features.ATOM_TYPE,
                FeatureConfig(
                    # one codebook entry per combination of feature values
                    count=math.prod(zinc_feature_bins),
                    encoder_cls=CombinatoricIntegerEncoder,
                    # presumably the node-feature column range this encoder consumes (one column per bin)
                    index_range=IndexRange((0, len(zinc_feature_bins))),
                    bins=zinc_feature_bins,
                ),
            ),
        ]
    ),
)

print("Loading/creating hypernet …")
hypernet: HyperNet = load_or_create_hypernet(path=GLOBAL_MODEL_PATH, cfg=dataset_config).to(device=device)
print("Hypernet ready.")
# Sanity checks: this flow expects a hypernet that encodes neither edge nor graph-level features.
assert not hypernet.use_edge_features()
assert not hypernet.use_graph_features()



In [None]:
from src.utils import visualisations
from src.encoding.decoder import greedy_oracle_decoder
from pathlib import Path
from src.encoding.oracles import MLPClassifier, Oracle
from pprint import pprint
from src.utils.utils import DataTransformer
from torchhd import HRRTensor


# Real Oracle
def is_induced_subgraph_feature_aware(G_small: nx.Graph, G_big: nx.Graph) -> bool:
    """Check whether ``G_small`` occurs in ``G_big`` as a node-induced, feature-matching subgraph.

    Runs NetworkX's VF2 ``GraphMatcher`` with ``G_big`` as the host graph and ``G_small`` as the
    pattern. Two nodes may be paired only if their ``"feat"`` attributes compare equal; edge
    attributes are not compared (no ``edge_match`` is supplied).

    Args:
        G_small: Pattern graph; its nodes are expected to carry a ``"feat"`` attribute.
        G_big: Host graph; its nodes are expected to carry a ``"feat"`` attribute.

    Returns:
        The result of ``GraphMatcher.subgraph_is_isomorphic()``.
    """
    def same_feat(host_attrs, pattern_attrs):
        # GraphMatcher hands in the node-attribute dicts of G_big (first) and G_small (second).
        return host_attrs["feat"] == pattern_attrs["feat"]

    matcher = nx.algorithms.isomorphism.GraphMatcher(G_big, G_small, node_match=same_feat)
    return matcher.subgraph_is_isomorphic()


batch_size = 32
zinc_smiles = ZincSmiles(split="valid")[:batch_size]
dataloader = DataLoader(dataset=zinc_smiles, batch_size=batch_size, shuffle=False)

# Decoder hyper-parameters (previously inline magic numbers in the decoder call).
beam_size = 32
oracle_threshold = 0.643

# ----- Classifier (oracle) -----
# NOTE(review): absolute, user-specific path — move it to a configurable location.
classifier_ckpt_path = Path("/Users/akaveh/projects/kit/graph_hdc/_models/mlp_stratified_base_laynorm_2nd_try.pt")
# weights_only=False lets torch.load unpickle arbitrary objects, which can execute code;
# only load checkpoints from a trusted source.
chkpt = torch.load(classifier_ckpt_path, map_location="cpu", weights_only=False)

cfg = chkpt["config"]
print("Classifier's cfg")
pprint(cfg, indent=4)

classifier = MLPClassifier(
    hv_dim=cfg.get("hv_dim"),
    hidden_dims=cfg.get("hidden_dims"),
    use_layer_norm=cfg.get("use_layer_norm"),
    use_batch_norm=cfg.get("use_batch_norm"),
).to(device).eval()
classifier.load_state_dict(chkpt["model_state"], strict=True)
oracle = Oracle(model=classifier)
oracle.encoder = hypernet

# ----- Decode each graph and check every candidate against its source molecule -----
# y[k] == 1 iff the k-th decoded candidate is a feature-preserving induced subgraph of its source graph.
# Molecules for which the decoder yields no candidate contribute nothing to y.
y = []
with torch.no_grad():  # pure inference: no autograd graph is needed
    for batch in dataloader:
        # Encode every graph of the batch; "graph_embedding" holds one hypervector per graph.
        encoded_data = hypernet.forward(batch)
        graph_term = encoded_data["graph_embedding"]
        graph_terms_hd = graph_term.as_subclass(HRRTensor)

        datas = batch.to_data_list()
        # Iterate over the graphs actually present: a final batch may hold fewer than batch_size graphs.
        for g, data in enumerate(datas):
            print("================================================")
            full_graph_nx = DataTransformer.pyg_to_nx(data=data)
            print(f"[{g}] Original Graph")
            visualisations.draw_nx_with_atom_colorings(full_graph_nx)
            plt.show()
            mol_full, _ = DataTransformer.nx_to_mol(full_graph_nx)
            display(mol_full)

            print(f"Num Nodes {data.num_nodes}")
            # presumably each undirected bond is stored as two directed edges, hence the halving
            print(f"Num Edges {data.num_edges // 2}")
            node_multiset = DataTransformer.get_node_counter_from_batch(batch=g, data=batch)
            print(f"Multiset Nodes {node_multiset.total()}")
            nx_GS: list[nx.Graph] = greedy_oracle_decoder(
                node_multiset=node_multiset,
                oracle=oracle,
                full_g_h=graph_terms_hd[g],
                beam_size=beam_size,
                oracle_threshold=oracle_threshold,
            )
            print(len(nx_GS))
            print(nx_GS)
            # Drop falsy entries (e.g. None; note an empty nx.Graph is falsy as well).
            nx_GS = list(filter(None, nx_GS))
            # Use a distinct name so the candidate graph does not shadow the graph index `g`.
            for i, candidate in enumerate(nx_GS):
                print(f"Graph Nr: {i}")
                visualisations.draw_nx_with_atom_colorings(candidate)
                plt.show()

                mol, _ = DataTransformer.nx_to_mol(candidate)
                display(mol)
                print(f"Num Atoms {mol.GetNumAtoms()}")
                print(f"Num Bonds {mol.GetNumBonds()}")

                is_induced = is_induced_subgraph_feature_aware(candidate, full_graph_nx)
                print("Is Induced subgraph: ", is_induced)
                y.append(int(is_induced))

# Fraction of decoded candidates that are correct induced subgraphs; guard against an empty y.
if y:
    print(f"Accuracy: {sum(y) / len(y)}")
else:
    print("Accuracy: n/a (no candidates were decoded)")