# Active Learning Loop Playground

Interaktiver Workflow zum Zusammenspiel von Surrogat, Generator und DFT.


In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path
import json

# --- ensure repo root on sys.path ---
# Look in the current working directory and up to two of its ancestors for
# the first directory containing both 'src' and 'configs'; that directory is
# treated as the project root, prepended to sys.path and made the cwd so the
# relative paths used in later cells resolve against it.
def _looks_like_repo_root(path):
    """Return True when *path* contains both a 'src' and a 'configs' entry."""
    return (path / 'src').exists() and (path / 'configs').exists()


_cwd = Path.cwd()
_search_dirs = (_cwd, _cwd.parent, _cwd.parent.parent)
REPO_ROOT = next(
    (d.resolve() for d in _search_dirs if _looks_like_repo_root(d)),
    None,
)

if REPO_ROOT is None:
    raise RuntimeError("Cannot locate project root containing 'src' and 'configs'.")

_root_str = str(REPO_ROOT)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)

os.chdir(REPO_ROOT)
print(f"Working directory set to: {REPO_ROOT}")

import numpy as np
import pandas as pd

from src.active_learn.loop import ActiveLearningLoop, LoopConfig
from src.active_learn.acq import AcquisitionConfig
from src.active_learn.sched import SchedulerConfig
from src.models.ensemble import SurrogateEnsemble
from src.models.jtvae_extended import sample_conditional
from src.data.dft_int import DFTInterface
from src.utils.config import load_config
from src.utils.log import setup_logging, get_logger
from src.utils.plot import plot_property_histogram

# Configure project logging first, then obtain this notebook's logger and
# emit a start-up message through it.
setup_logging()
logger = get_logger(__name__)
logger.info("Notebook session ready.")


In [None]:
# Load the active-learning configuration; later cells read its `data`,
# `loop`, `acquisition` and `scheduler` sections. Path is relative to the
# repo root chosen in the setup cell.
CONFIG_PATH = Path('configs') / 'active_learn.yaml'
cfg = load_config(CONFIG_PATH)
cfg


In [None]:
# --- Load labelled & pool dataframes ---
labelled_df = pd.read_csv(cfg.data.labelled)
pool_df = pd.read_csv(cfg.data.pool)

logger.info('Loaded %d labelled / %d pool entries', len(labelled_df), len(pool_df))

# --- Load ensemble checkpoints ---
surrogate_dir = Path('models/surrogate')
surrogate = SurrogateEnsemble.from_directory(surrogate_dir)
logger.info('Loaded ensemble with %d members from %s', len(surrogate.members), surrogate_dir)

# --- Optional generator + vocab ---
# The generator needs BOTH a non-empty fragment vocab and a JT-VAE checkpoint.
# If either is unavailable, `generator` stays None and is passed as such to
# the loop (NOTE(review): confirm ActiveLearningLoop then skips generation).
fragment_vocab = None
generator = None
vocab_path = Path(cfg.data.fragment_vocab)
if vocab_path.exists():
    with open(vocab_path, 'r', encoding='utf-8') as f:
        raw_vocab = json.load(f)
    # JSON values are coerced to int (presumably fragment indices — confirm
    # against the vocab builder).
    fragment_vocab = {k: int(v) for k, v in raw_vocab.items()}
    logger.info('Loaded fragment vocab (%d entries)', len(fragment_vocab))
else:
    logger.warning('Fragment vocab not found at %s; generator sampling disabled', cfg.data.fragment_vocab)

generator_ckpt = Path('models/generator/jtvae_epoch_80.pt')
# Check each prerequisite separately so the warning names what is actually
# missing (a missing/empty vocab is not a missing checkpoint).
if not fragment_vocab:
    logger.warning('Fragment vocab unavailable or empty; generator not loaded from %s', generator_ckpt)
elif not generator_ckpt.exists():
    logger.warning('Generator checkpoint missing (%s); generation skipped', generator_ckpt)
else:
    # Private helper from src.main, imported only on the path that needs it.
    from src.main import _load_jtvae_from_ckpt
    generator = _load_jtvae_from_ckpt(generator_ckpt, len(fragment_vocab), cond_dim=len(cfg.loop.target_columns))
    logger.info('Loaded JT-VAE generator from %s', generator_ckpt)

# --- Pseudo DFT backend ---
dft_interface = DFTInterface()
logger.info('Pseudo DFT interface initialised.')


In [None]:
# Build the sub-configs from their YAML sections, assemble the loop config,
# then wire the loop to the surrogate, data, generator and DFT backend loaded
# in the previous cell.
acquisition_cfg = AcquisitionConfig(**cfg.acquisition)
scheduler_cfg = SchedulerConfig(**cfg.scheduler)
loop_cfg = LoopConfig(
    batch_size=cfg.loop.batch_size,
    acquisition=acquisition_cfg,
    scheduler=scheduler_cfg,
    target_columns=tuple(cfg.loop.target_columns),
    maximise=tuple(cfg.loop.maximise),
    generator_samples=cfg.loop.generator_samples,
    results_dir=Path(cfg.loop.results_dir),
)

loop_components = dict(
    surrogate=surrogate,
    labelled=labelled_df,
    pool=pool_df,
    generator=generator,
    fragment_vocab=fragment_vocab,
    dft=dft_interface,
)
active_loop = ActiveLearningLoop(config=loop_cfg, **loop_components)
# Preview the first rows of the labelled set held by the loop.
active_loop.labelled.head()


In [None]:
# Execute one active-learning iteration. Replace `cond_vec` with a
# conditioning vector to request conditional generation; the batch the
# iteration returns is shown below.
cond_vec = None
batch = active_loop.run_iteration(
    cond=cond_vec,
    assemble_kwargs={'beam_width': 8},
)
batch


In [None]:
# Show the acquired batch: SMILES, acquisition score and every `pred_*`
# prediction column.
pred_cols = [col for col in batch.columns if col.startswith('pred_')]
display(batch[['smiles', 'acquisition_score', *pred_cols]])

# Histogram of each target property over `active_loop.labelled` (NaNs
# dropped); the helper may return nothing, in which case nothing is shown.
for target in cfg.loop.target_columns:
    values = active_loop.labelled[target].dropna()
    fig = plot_property_histogram(values, title=f'Distribution of {target}')
    if fig:
        display(fig)


In [None]:
# Once the iterations look good, write the loop history to disk and log the
# configured results directory.
active_loop.save_history()
logger.info(
    'History saved to %s',
    loop_cfg.results_dir,
)
