In [1]:
import datetime
import random
import shutil
import string
import time
from pathlib import Path
import joblib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from torch.utils.data import Dataset
from torch_geometric.data import Batch
from torch_geometric.datasets import ZINC

from src.datasets import AddNodeDegree
from src.encoding.configs_and_constants import SupportedDataset
from src.encoding.graph_encoders import HyperNet
from src.encoding.the_types import VSAModel
from src.normalizing_flow.models import NeuralSplineLightning

from pathlib import Path
from pprint import pprint

import pandas as pd
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Subset
from torch_geometric.datasets import ZINC

from graph_hdc.utils import AbstractEncoder
from src.datasets import AddNodeDegree
from src.encoding.configs_and_constants import DSHDCConfig, SupportedDataset
from src.encoding.graph_encoders import HyperNet
from src.encoding.the_types import VSAModel
from src.normalizing_flow.config import FlowConfig
from src.normalizing_flow.models import NeuralSplineLightning


def setup_exp(ds_value: str) -> dict:
    """
    Sets up experiment directories based on the current script location.

    Args:
        ds_value (str): Dataset name to use for global_dataset_dir.

    Returns:
        dict: Dictionary containing paths to various directories.
    """
    # Resolve script location
    script_path = Path("/Users/arvandkaveh/Projects/kit/graph_hdc/notebooks/NormFlows/real_nvp.ipynb")
    experiments_path = script_path.parent
    script_stem = script_path.stem  # without .py

    # Resolve base and project directories
    base_dir = experiments_path / "results" / script_stem
    base_dir.mkdir(parents=True, exist_ok=True)

    project_dir = script_path.parents[2]  # adjust as needed

    print(f"Setting up experiment in {base_dir}")
    now = f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{''.join(random.choices(string.ascii_lowercase, k=4))}"
    exp_dir = base_dir / now
    exp_dir.mkdir(parents=True, exist_ok=True)
    print(f"Experiment directory created: {exp_dir}")

    dirs = {
        "exp_dir": exp_dir,
        "models_dir": exp_dir / "models",
        "evals_dir": exp_dir / "evaluations",
        "artefacts_dir": exp_dir / "artefacts",
        "global_model_dir": project_dir / "_models",
        "global_dataset_dir": project_dir / "_datasets" / ds_value,
    }

    for d in dirs.values():
        d.mkdir(parents=True, exist_ok=True)

    # Save a copy of the script
    try:
        shutil.copy(script_path, exp_dir / script_path.name)
        print(f"Saved a copy of the script to {exp_dir / script_path.name}")
    except Exception as e:
        print(f"Warning: Failed to save script copy: {e}")

    return dirs


def plot_train_val_loss(df, artefacts_dir):
    """
    Plot training and validation loss per epoch and save the figure as a PNG.

    :param df: Metrics frame with the columns ``epoch``, ``train_loss_epoch`` and ``val_loss``.
        Rows where the respective loss is missing are skipped for that curve.
    :param artefacts_dir: Directory (``Path``) that receives ``train_val_loss.png``.
    """
    target_file = artefacts_dir / "train_val_loss.png"

    # Each metric is only present on some rows, so filter per curve.
    train_rows = df.loc[df["train_loss_epoch"].notna()]
    val_rows = df.loc[df["val_loss"].notna()]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(train_rows["epoch"], train_rows["train_loss_epoch"], label="Train Loss")
    ax.plot(val_rows["epoch"], val_rows["val_loss"], label="Validation Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Train vs. Validation Loss")
    ax.legend()
    fig.tight_layout()

    artefacts_dir.mkdir(exist_ok=True)
    fig.savefig(target_file)
    plt.close(fig)
    print(f"Saved train/val loss plot to {target_file}")


def pca_encode(x: torch.Tensor, pca: PCA, norm: bool = False) -> torch.Tensor:
    """
    Encode data using a fitted PCA, with optional normalization.

    :param x: Input tensor of shape (..., features). May be non-contiguous and may
        require grad; it is detached before conversion to NumPy.
    :type x: torch.Tensor
    :param pca: Fitted PCA instance with attributes `mean_` and `std_`.
    :type pca: PCA
    :param norm: Whether to normalize input by mean and std before PCA.
    :type norm: bool
    :returns: CPU tensor of shape (N, components) where N is the product of all
        leading dimensions of ``x``; same dtype as input.
    :rtype: torch.Tensor

    The input is flattened over all but the last dimension, optionally normalized,
    transformed with the PCA, then returned as a tensor.
    """
    # `reshape` (unlike `view`) also accepts non-contiguous inputs; `detach` makes the
    # NumPy conversion legal for tensors that are part of an autograd graph.
    flat = x.detach().reshape(-1, x.shape[-1]).cpu().numpy()
    if norm:
        # Constant features have std == 0; divide by 1 there instead of producing NaN/inf.
        std = np.asarray(pca.std_)
        safe_std = np.where(std == 0, 1.0, std)
        # NOTE(review): the data is centred here and `pca.transform` presumably subtracts
        # `pca.mean_` again (scikit-learn behaviour) — confirm the double centring is intended.
        flat = (flat - pca.mean_) / safe_std
    reduced = pca.transform(flat)
    return torch.tensor(reduced, dtype=x.dtype)


def load_or_fit_pca(
        train_dataset: Dataset, encoder: AbstractEncoder, pca_path: Path | None = None, n_components: float = 0.99999,
        n_fit: int = 20000
) -> PCA:
    """
    Load an existing PCA from disk or fit a new one and save it.

    :param train_dataset: Dataset for fitting PCA.
    :type train_dataset: Dataset
    :param encoder: Model or function returning a dict with keys \"node_terms\", \"edge_terms\", and \"graph_embedding\".
    :param pca_path: Path to load/save the PCA object. If ``None`` the PCA is always
        fitted and never written to disk.
    :type pca_path: Path | None
    :param n_components: Number of components or variance ratio for PCA.
    :type n_components: float
    :param n_fit: Maximum number of samples to fit PCA on.
    :type n_fit: int
    :returns: Fitted PCA instance with `mean_` and `std_` attributes.
    :rtype: PCA
    :raises ValueError: If there are no samples to fit on.

    If a PCA exists at `pca_path`, it is loaded. Otherwise, embeddings
    are collected by applying `encoder` to dataset entries until `n_fit`
    samples, flattened, and used to fit a new PCA. The mean and std of the
    fit data are stored on the PCA for later normalization.
    """
    if pca_path is not None and pca_path.exists():
        print(f"Loading existing PCA from {pca_path}")
        # SECURITY: joblib.load unpickles the file; only load PCA files you created yourself.
        pca = joblib.load(pca_path)
        print(f"Loaded PCA with {pca.n_components_} components")
        return pca

    print("Fitting PCA on training data...")
    n_fit = min(n_fit, len(train_dataset))
    if n_fit <= 0:
        raise ValueError("Cannot fit PCA: no training samples available.")

    X_fit = []
    # Encoding is inference only: no autograd graph is needed, and tensors that track
    # gradients could not be converted to NumPy below.
    with torch.no_grad():
        for i in range(n_fit):
            batch_data = Batch.from_data_list([train_dataset[i]])
            res = encoder.forward(data=batch_data)
            x = torch.stack(
                [res["node_terms"].squeeze(0), res["edge_terms"].squeeze(0), res["graph_embedding"].squeeze(0)],
                dim=0,
            )  # [3, D]
            X_fit.append(x.cpu().numpy())
    X_fit = np.stack(X_fit)
    # Every one of the three stacked terms of every graph becomes one PCA sample row.
    X_fit_flat = X_fit.reshape(-1, X_fit.shape[-1])

    # Compute mean and std for normalization
    mu = np.mean(X_fit_flat, axis=0)
    sigma = np.std(X_fit_flat, axis=0)

    # Fit PCA
    pca = PCA(n_components=n_components, svd_solver="full")
    pca.fit(X_fit_flat)

    # Attach normalization stats (`std_` is a custom attribute, not set by scikit-learn).
    pca.mean_ = mu
    pca.std_ = sigma

    print(f"PCA reduced dimension: {pca.n_components_} from {X_fit.shape[-1]}")
    if pca_path is not None:
        # Only persist when a target was given; make sure its directory exists first.
        pca_path.parent.mkdir(parents=True, exist_ok=True)
        joblib.dump(pca, pca_path)
        print(f"Saved new PCA to {pca_path}")
    return pca


def load_or_create_hypernet(path: Path, ds: SupportedDataset, depth: int) -> HyperNet:
    """
    Load a cached HyperNet encoder for ``ds`` or create, populate and cache a new one.

    :param path: Directory holding the cached HyperNet files.
    :type path: Path
    :param ds: Dataset whose ``default_cfg`` configures the encoder.
    :type ds: SupportedDataset
    :param depth: Depth passed to the HyperNet constructor.
    :type depth: int
    :returns: The loaded or newly created encoder.
    :rtype: HyperNet

    The cache file name is derived only from ``ds`` / ``ds.default_cfg`` and ``depth``
    (dataset name incl. nha settings, VSA model, hypervector dimension, seed), i.e. from
    exactly the configuration the encoder is built with, not from notebook globals.
    """
    hdc_cfg = ds.default_cfg

    # Encoders built with different nha settings get distinct cache files.
    ds_name = ds.name
    if hdc_cfg.nha_depth:
        ds_name = f"{ds_name}-nha-d{hdc_cfg.nha_depth}"
    if hdc_cfg.nha_bins:
        ds_name = f"{ds_name}-nha-b{hdc_cfg.nha_bins}"

    path = path / f"hypernet_ds{ds_name}_{hdc_cfg.vsa.value}_d{hdc_cfg.hv_dim}_s{hdc_cfg.seed}_dpth{depth}.pt"
    if path.exists():
        print(f"Loading existing HyperNet from {path}")
        encoder = HyperNet(config=hdc_cfg, depth=depth)
        # NOTE(review): assumes `HyperNet.load` restores the state in place — confirm.
        encoder.load(path)
    else:
        print("Creating new HyperNet instance.")
        encoder = HyperNet(config=hdc_cfg, depth=depth)
        encoder.populate_codebooks()
        encoder.save_to_path(path)
        print(f"Saved new HyperNet to {path}")
    return encoder


class EncodedPCADataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that encodes each base item on access.

    Every item of ``base_dataset`` is wrapped into a single-element batch, passed
    through ``encoder`` and the resulting ``node_terms``, ``edge_terms`` and
    ``graph_embedding`` are stacked along a new leading dimension. If a PCA is
    given, the stacked tensor is additionally reduced with ``pca_encode``.
    """

    _TERM_KEYS = ("node_terms", "edge_terms", "graph_embedding")

    def __init__(self, base_dataset, encoder: AbstractEncoder, pca: PCA | None = None, *, use_norm_pca: bool = False):
        self.base_dataset = base_dataset
        self.encoder = encoder
        self.pca = pca
        # Forwarded as the `norm` flag of `pca_encode`.
        self.use_norm = use_norm_pca

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        single_item_batch = Batch.from_data_list([self.base_dataset[idx]])
        encoded = self.encoder.forward(data=single_item_batch)
        # Drop the batch dimension of each term before stacking them.
        stacked = torch.stack([encoded[key].squeeze(0) for key in self._TERM_KEYS], dim=0)
        if self.pca is None:
            return stacked
        return pca_encode(stacked, self.pca, self.use_norm)


def get_device():
    """Return ``torch.device("cuda")`` if CUDA is available, otherwise the CPU device.

    Prints which case applies, including the number of detected GPUs for CUDA.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return torch.device("cpu")

    n_gpus = torch.cuda.device_count()
    plural = "" if n_gpus == 1 else "s"
    print(f"CUDA is available. Detected {n_gpus} GPU device{plural}.")
    return torch.device("cuda")


class TimeLoggingCallback(Callback):
    """Logs the time elapsed since ``setup`` as ``elapsed_time_sec`` after every training epoch."""

    def setup(self, trainer, pl_module, stage=None):
        # Monotonic clock: elapsed time is unaffected by system clock adjustments.
        self.start_time = time.monotonic()

    def on_train_epoch_end(self, trainer, pl_module):
        logger = trainer.logger
        if logger is None:
            # Trainer runs without a logger; there is nowhere to report the metric.
            return
        elapsed = time.monotonic() - self.start_time
        logger.log_metrics({"elapsed_time_sec": elapsed}, step=trainer.current_epoch)


class DebugMetricsCallback(Callback):
    """Debug helper: after each training epoch, prints every callback metric name with its value type."""

    def on_train_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        print("\n==> callback_metrics:")
        for metric_name, metric_value in metrics.items():
            value_type = type(metric_value)
            print(f"   {metric_name:20s} → {value_type}")
        print()


class SamplingEveryNEpoch(Callback):
    """
    Draws samples from the flow at the start of every n-th training epoch and prints
    the node counter decoded from each sample's node terms.
    """

    def __init__(self, encoder, n_samples: int = 10, every_n_epochs: int = 10):
        """
        :param encoder: your graph‐reconstruction encoder, with `.decode_order_zero_counter()`
        :param n_samples: how many samples to draw
        :param every_n_epochs: interval at which to run the sampling (must be >= 1)
        :raises ValueError: if ``every_n_epochs`` is smaller than 1
        """
        super().__init__()
        if every_n_epochs < 1:
            # Fail at construction instead of with a ZeroDivisionError in the middle of training.
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}")
        self.encoder = encoder
        self.n_samples = n_samples
        self.every_n = every_n_epochs

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch % self.every_n != 0:
            return

        # Imported here so the callback does not depend on a later notebook cell having
        # brought `HRRTensor` into the kernel namespace.
        from torchhd import HRRTensor

        device = pl_module.device
        print(f"{device=}")
        self.encoder.to(device)

        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                print(f"\n=== Sampling callback at epoch {epoch} ===")
                # ---- Sampling example ----
                z, _ = pl_module.sample(self.n_samples)
                print("\n=== Finished sampling ===")
                z = z.to(device)

                # if you had a PCA decode step:
                # z = pca_decode(z, pca)

                # unpack into (batch, 3, D); only the node terms are decoded here
                node_terms_s, _, _ = z.unbind(dim=1)

                # Cast to HRRTensor
                node_terms_hrr = node_terms_s.as_subclass(HRRTensor)

                for b in range(self.n_samples):
                    print(f"-- SAMPLE #{b} --")
                    node_counter = self.encoder.decode_order_zero_counter(node_terms_hrr[b])
                    print(f"node_counter[0] = {node_counter[0]}")
                    print(f"  total = {node_counter[0].total()}")
        finally:
            # Restore the previous mode even if sampling or decoding raised.
            pl_module.train(was_training)


# ---- Experiment configuration ----
hv_dim = 60 * 60
batch_size = 2
cfg = FlowConfig(
    project_dir="/Users/arvandkaveh/Projects/kit/graph_hdc",
    seed=42,
    epochs=500,
    batch_size=batch_size,
    vsa=VSAModel.HRR,
    hv_dim=hv_dim,
    dataset=SupportedDataset.ZINC_NODE_DEGREE_COMB,
    num_input_channels=3 * hv_dim,
    num_flows=16,
    num_hidden_channels=128,
    input_shape=(3, hv_dim),
    lr=0.00003,
    weight_decay=0.0001,
)

print("Running experiment")
pprint(cfg.__dict__, indent=2)

dirs = setup_exp(cfg.dataset.value)
exp_dir = dirs["exp_dir"]
models_dir = dirs["models_dir"]
evals_dir = dirs["evals_dir"]
artefacts_dir = dirs["artefacts_dir"]
global_model_dir = dirs["global_model_dir"]
global_dataset_dir = dirs["global_dataset_dir"]

# W&B Logging — use existing run (from sweep or manual init)
# run = wandb.run or wandb.init(project="realnvp-hdc", config=cfg.__dict__, name=f"run_{cfg.hv_dim}_{cfg.seed}", reinit=True)
# run.tags = [f"hv_dim={cfg.hv_dim}", f"vsa={cfg.vsa.value}", f"dataset={cfg.dataset.value}"]

# wandb_logger = WandbLogger(log_model=True, experiment=run)

# ---- Data: overfitting sanity check on a single graph ----
train_data = ZINC(root=str(global_dataset_dir), pre_transform=AddNodeDegree(), split="train", subset=True)[:1]
# make a length-16 dataset by selecting index 0 sixteen times
train_dataset = Subset(train_data, indices=[0] * 16)
print(f"Train length = {len(train_dataset)}")  # 16
print(train_dataset[0])
validation_data = ZINC(root=str(global_dataset_dir), pre_transform=AddNodeDegree(), split="val", subset=True)[:1]
# the validation set repeats its first graph four times
validation_dataset = Subset(validation_data, indices=[0] * 4)
print(f"{len(validation_dataset)=}")
print(validation_dataset[0])

# ---- Encoder ----
device = get_device()
ds = cfg.dataset
ds.default_cfg.vsa = cfg.vsa
ds.default_cfg.hv_dim = cfg.hv_dim
ds.default_cfg.device = device
ds.default_cfg.seed = cfg.seed
ds.default_cfg.edge_feature_configs = {}
ds.default_cfg.graph_feature_configs = {}

# `load_or_create_hypernet` takes the dataset (`ds`), not its config; it reads `ds.default_cfg` itself.
encoder = load_or_create_hypernet(path=global_model_dir, ds=ds, depth=3)

n_components = 0.998
pca_path = global_model_dir / f"hypervec_pca_{cfg.vsa.value}_d{cfg.hv_dim}_s{cfg.seed}_c{str(n_components)[2:]}.joblib"
# pca = load_or_fit_pca(
#     train_dataset=ZINC(root=str(global_dataset_dir), pre_transform=AddNodeDegree(), split="train", subset=True),
#     encoder=encoder,
#     pca_path=pca_path,
#     n_components=n_components,
#     n_fit=20_000,
# )

# reduced_dim = int(pca.n_components_)
# Without PCA the flow operates on the full (3, hv_dim) encoding.
cfg.num_input_channels = 3 * hv_dim
cfg.input_shape = (3, hv_dim)

# ---- Data loaders ----
train_dataloader = DataLoader(
    EncodedPCADataset(train_dataset, encoder),
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
    drop_last=True,
)
validation_dataloader = DataLoader(
    EncodedPCADataset(validation_dataset, encoder),
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
    drop_last=False,
)

# ---- Model, logging and callbacks ----
model = NeuralSplineLightning(cfg)

csv_logger = CSVLogger(save_dir=str(evals_dir), name="logs")
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=2,
    mode="min",
    dirpath=str(models_dir),
    filename="epoch{epoch:02d}-val{val_loss:.2f}",
    save_last=True,
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
time_logger = TimeLoggingCallback()
# debug_callback = DebugMetricsCallback()
# Created for optional use; currently not registered with the trainer below.
sampling_cb = SamplingEveryNEpoch(encoder=encoder, n_samples=1, every_n_epochs=10)

trainer = Trainer(
    max_epochs=cfg.epochs,
    logger=[csv_logger,
            # wandb_logger
            ],
    callbacks=[checkpoint_callback, lr_monitor, time_logger,
               # debug_callback,
               # sampling_cb
               ],
    default_root_dir=str(exp_dir),
    accelerator="auto",
    log_every_n_steps=4,
    enable_progress_bar=True,
    # detect_anomaly=True,
)

# ---- Training ----
trainer.fit(model, train_dataloaders=train_dataloader,
            val_dataloaders=validation_dataloader
            )

torch.save(model.state_dict(), models_dir / "final_model.pt")

# ---- Export metrics and loss curve ----
metrics_path = Path(csv_logger.log_dir) / "metrics.csv"
if metrics_path.exists():
    df = pd.read_csv(metrics_path)
    df.to_parquet(evals_dir / "metrics.parquet")
    plot_train_val_loss(df, artefacts_dir)

print("==== The Experiment is done! ====")


Running experiment
{ 'activation': <class 'torch.nn.modules.activation.ReLU'>,
  'batch_size': 2,
  'dataset': <SupportedDataset.ZINC_NODE_DEGREE_COMB: 'ZINC_ND_COMB'>,
  'device': 'cpu',
  'dropout_probability': 0.0,
  'epochs': 500,
  'exp_dir': None,
  'flow_type': <class 'normflows.flows.neural_spline.wrapper.AutoregressiveRationalQuadraticSpline'>,
  'hv_dim': 3600,
  'init_identity': True,
  'input_shape': (3, 3600),
  'lr': 3e-05,
  'num_bins': 8,
  'num_blocks': 2,
  'num_context_channels': None,
  'num_flows': 16,
  'num_hidden_channels': 128,
  'num_input_channels': 10800,
  'permute': False,
  'project_dir': PosixPath('/Users/arvandkaveh/Projects/kit/graph_hdc'),
  'seed': 42,
  'tail_bound': 3,
  'vsa': <VSAModel.HRR: 'HRR'>,
  'weight_decay': 0.0001}
Setting up experiment in /Users/arvandkaveh/Projects/kit/graph_hdc/notebooks/NormFlows/results/real_nvp
Experiment directory created: /Users/arvandkaveh/Projects/kit/graph_hdc/notebooks/NormFlows/results/real_nvp/2025-07-14_15

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name | Type            | Params | Mode 
-------------------------------------------------
0 | flow | NormalizingFlow | 535 M  | train
-------------------------------------------------
535 M     Trainable params
0         Non-trainable params
535 M     Total params
2,143.499 Total estimated model params size (MB)
307       Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/arvandkaveh/Projects/kit/graph_hdc/.pixi/envs/default/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

/Users/arvandkaveh/Projects/kit/graph_hdc/.pixi/envs/default/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 8/8 [00:07<00:00,  1.10it/s, v_num=0, train_loss_step=2.05e+9]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  4.47it/s][A
Validation DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  4.53it/s][A
Epoch 1: 100%|██████████| 8/8 [00:07<00:00,  1.13it/s, v_num=0, train_loss_step=2.05e+9, val_loss=1.84e+9, lr=3e-5, train_loss_epoch=2.05e+9]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  4.49it/s][A
Validation DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  4.60it/s][A
Epoch 2: 100%|██████████| 8/8 [00:07<00:00,  1.12it/s, v_num=0, train_loss_step=2.05e+9, val_loss=1.84e+9, lr=3e-5, train_loss_epoch=2.05


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [4]:
import datetime
import random
import shutil
import string
import time
from pathlib import Path
import joblib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from torch.utils.data import Dataset
from torch_geometric.data import Batch
from torch_geometric.datasets import ZINC

from src.datasets import AddNodeDegree
from src.encoding.configs_and_constants import SupportedDataset
from src.encoding.graph_encoders import HyperNet
from src.encoding.the_types import VSAModel
from src.normalizing_flow.models import NeuralSplineLightning

# Restore the trained flow from the last checkpoint of a previous run.
model_loaded = NeuralSplineLightning.load_from_checkpoint(
    "/Users/arvandkaveh/Projects/kit/graph_hdc/notebooks/NormFlows/results/real_nvp/2025-07-14_15-56-53_wtvj/models/last.ckpt")

# This cell relies on names from the first cell: `get_device`, `hv_dim`, `cfg`,
# `global_model_dir`, `load_or_create_hypernet` and `torch`.
device = get_device()
ds = SupportedDataset.ZINC_NODE_DEGREE_COMB
ds.default_cfg.vsa = VSAModel.HRR
ds.default_cfg.hv_dim = hv_dim
ds.default_cfg.device = device
ds.default_cfg.seed = cfg.seed
ds.default_cfg.edge_feature_configs = {}
ds.default_cfg.graph_feature_configs = {}

# `load_or_create_hypernet` takes the dataset (`ds`), not its config.
encoder = load_or_create_hypernet(path=global_model_dir, ds=ds, depth=3)

# ---- Sampling example ----
n_samples = 1
model_loaded.eval()
# Sampling is inference only; disabling autograd avoids keeping activations in memory.
with torch.no_grad():
    s, l = model_loaded.sample(n_samples)
encoder.to(model_loaded.device)
# s_decoded = pca_decode(s, pca)
# z = s_decoded.view(n_samples, 3, hv_dim)
z = s

RuntimeError: MPS backend out of memory (MPS allocated: 36.22 GB, other allocations: 864.00 KB, max allowed: 36.27 GB). Tried to allocate 121.29 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:

# Decode the node counter of each sampled hypervector and compare it with the
# decoding of a purely random HRR hypervector.
# Relies on `z`, `n_samples`, `encoder` and `hv_dim` defined in earlier cells.
import torch
from torchhd import HRRTensor

# Split along dim 1 into exactly three parts (the unpacking requires size 3 there);
# presumably z is (n_samples, 3, hv_dim) — TODO confirm against the flow's sample shape.
node_terms_s, edge_terms_s, graph_embeddings_s = z.unbind(dim=1)

# Cast to HRRTensor
node_terms_s_hrr = node_terms_s.as_subclass(HRRTensor)
edge_terms_s_hrr = edge_terms_s.as_subclass(HRRTensor)
graph_embeddings_s_hrr = graph_embeddings_s.as_subclass(HRRTensor)

print("---SAMPLES---")
for b in range(n_samples):
    print(f"SAMPLE NR: {b}")
    # Only the node terms are decoded; the edge/graph parts are not used below.
    node_counter_dec = encoder.decode_order_zero_counter(node_terms_s_hrr[b])
    print(f"{node_counter_dec[0]=}")
    print(f"{node_counter_dec[0].total()}")

import torchhd

# Baseline: decode one random HRR hypervector of the same dimension.
# NOTE(review): no device is passed to torchhd.random — confirm it matches the encoder's device.
print("----RANDOM-----")
random_counter = encoder.decode_order_zero_counter(torchhd.random(1, hv_dim, vsa="HRR")[0])
print(f"{random_counter[0]=}")
print(f"{random_counter[0].total()=}")


In [None]:

# Decode the node counter of every encoded training sample.
# Relies on `train_dataloader`, `encoder` and `HRRTensor` from earlier cells.
print("---DATA---")
for counter, b in enumerate(train_dataloader, start=1):
    print(f"DATA NR: {counter}")
    # s_decoded = pca_decode(b, pca)
    # z = b.view(4, 3, hv_dim
    node_terms_s, edge_terms_s, graph_embeddings_s = b.unbind(dim=1)
    node_terms_s_hrr = node_terms_s.as_subclass(HRRTensor)
    # Move the encoder once per batch rather than once per sample.
    encoder.to(node_terms_s_hrr.device)
    # Iterate over the actual batch length so a smaller final batch cannot index out of range.
    for i in range(node_terms_s_hrr.shape[0]):
        node_counter_dec = encoder.decode_order_zero_counter(node_terms_s_hrr[i])
        print(f"{node_counter_dec[0]=}")
        print(f"{node_counter_dec[0].total()}")
        print(f"{len(node_counter_dec[0])=}")


In [None]:
# Round-trip check: encode one training graph and compare the decoded node counter
# with the node counter computed directly from the data.
# Relies on `global_dataset_dir`, `global_model_dir`, `cfg`, `get_device` and
# `load_or_create_hypernet` from the first cell.
from src.utils.utils import DataTransformer

data = ZINC(root=str(global_dataset_dir), pre_transform=AddNodeDegree(), split="train", subset=True)[0]
# data = next(iter(ZINC(root=str(global_dataset_dir), pre_transform=AddNodeDegree(), split="train", subset=True)))
print(data.num_nodes)
b = Batch.from_data_list([data])

device = get_device()
ds = SupportedDataset.ZINC_NODE_DEGREE_COMB
ds.default_cfg.vsa = VSAModel.HRR
# NOTE(review): this overrides hv_dim on the dataset's default config, which earlier cells
# set to `hv_dim`; presumably the object is shared, so later cells see 2500 — confirm.
ds.default_cfg.hv_dim = 50 * 50
ds.default_cfg.device = device
ds.default_cfg.seed = cfg.seed
ds.default_cfg.edge_feature_configs = {}
ds.default_cfg.graph_feature_configs = {}

# `load_or_create_hypernet` takes the dataset (`ds`), not its config.
encoder = load_or_create_hypernet(path=global_model_dir, ds=ds, depth=3)

encoded = encoder.forward(b)
nodes_encoded = encoder.decode_order_zero_counter(encoded['node_terms'])[0]
print(nodes_encoded.total())
print("data %", DataTransformer.get_node_counter_from_batch(0, b))
print("decoded %", nodes_encoded)