# In [4]:  (Jupyter cell exported as a script)
import argparse
from dataclasses import asdict
import time
import time

from torch.utils.data import DataLoader

from core.wavenet import WaveNetCategorical
from util.metrics import save_random_postcue_plots
from util.util import *
from util.datasets import RandomWaveNetSegments
from util.Cfg import Cfg
from util.wavenet import dilations_1s_context, receptive_field
from core.training import train_model, eval_streaming_latency_comp, eval_postcue_window

import yaml

# Project settings read by main(): training.out_dir and
# data_preparation.npz_data_output. The `with` block closes the file handle,
# which the previous bare open() leaked.
with open("config.yaml", encoding="utf-8") as _config_file:
    config = yaml.safe_load(_config_file)

def _parse_args(defaults=None, argv=None):
    """Parse the run-control options that ``main`` reads from ``args``.

    Each option defaults to the same-named attribute of *defaults* (the
    ``Cfg`` instance) when that attribute exists. Otherwise it uses the
    fallback given below, so any of these fields that ``Cfg`` does define
    still apply.

    ``parse_known_args`` is used so that unrecognised arguments are ignored
    instead of aborting the run. An example is the ``-f <file>`` flag passed
    when running inside a Jupyter kernel.

    Args:
        defaults: object queried with ``getattr`` for default values, or None.
        argv: argument list to parse; None means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``train_frac``, ``val_frac``, ``split_seed``,
        ``eval_max_epochs``, ``ckpt``, ``eval_only`` and ``decode``.
    """

    def _default(name, fallback):
        return getattr(defaults, name, fallback)

    parser = argparse.ArgumentParser(description="Train and/or evaluate WaveNetCategorical.")
    # NOTE(review): the split fallbacks (0.7 / 0.15 / seed 0) are assumptions;
    # confirm they match the intended uLAR-style split.
    parser.add_argument("--train-frac", dest="train_frac", type=float,
                        default=_default("train_frac", 0.7))
    parser.add_argument("--val-frac", dest="val_frac", type=float,
                        default=_default("val_frac", 0.15))
    parser.add_argument("--split-seed", dest="split_seed", type=int,
                        default=_default("split_seed", 0))
    parser.add_argument("--eval-max-epochs", dest="eval_max_epochs", type=int,
                        default=_default("eval_max_epochs", None),
                        help="cap on val/test epochs used for post-training evaluation")
    parser.add_argument("--ckpt", default=_default("ckpt", None),
                        help="checkpoint to load before training/evaluation")
    parser.add_argument("--eval-only", dest="eval_only", action="store_true",
                        default=bool(_default("eval_only", False)),
                        help="skip training and only evaluate")
    # Default None: the caller may then leave decode unspecified.
    parser.add_argument("--decode", default=_default("decode", None))
    args, _unknown = parser.parse_known_args(argv)
    return args


def main(argv=None):
    """Train the categorical WaveNet (earlier, truncated copy of ``main``).

    NOTE(review): this definition is a duplicated notebook paste. It stops
    after ``model.eval()`` and is shadowed by the later ``def main()`` in
    this file, so it never runs. Consider deleting it.

    Args:
        argv: optional argument list for ``_parse_args``; None parses sys.argv.
    """
    out_dir = Path.cwd() / config["training"]["out_dir"]
    logger = setup_logger(out_dir)
    set_seed(0)

    cfg = Cfg()
    # Run-control options are not Cfg fields (Cfg has no `train_frac`), so
    # they come from the CLI. Cfg is only consulted for their defaults.
    args = _parse_args(cfg, argv)

    logger.info(f"USING amp_min/amp_max: {cfg.amp_min:.3e} / {cfg.amp_max:.3e} | n_bins={cfg.n_bins}")

    dils = dilations_1s_context()
    cfg.receptive_field = int(receptive_field(cfg.kernel_size, dils))
    logger.info(
        f"Architecture: k={cfg.kernel_size}, "
        f"layers={len(dils)}, "
        f"RF={cfg.receptive_field} samples "
        f"({cfg.receptive_field/cfg.sfreq:.3f}s)"
    )

    logger.info("LOAD DATA...")
    t0 = time.perf_counter()

    epochs_1d = load_epochs_from_npz(config["data_preparation"]["npz_data_output"])
    N = len(epochs_1d)
    logger.info(f"sfreq={cfg.sfreq} | n_epochs={N} | load_dt={time.perf_counter()-t0:.2f}s")

    train_ids, val_ids, test_ids = split_epochs(
        N, train_frac=args.train_frac, val_frac=args.val_frac, seed=args.split_seed
    )
    tr_list = [epochs_1d[int(i)] for i in train_ids]
    va_list_full = [epochs_1d[int(i)] for i in val_ids]
    te_list_full = [epochs_1d[int(i)] for i in test_ids]

    logger.info(f"splits (uLAR-style): train={len(tr_list)} val={len(va_list_full)} test={len(te_list_full)}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"device={device}")

    wb = wandb_init(asdict(cfg), run_name="consumption")

    model = WaveNetCategorical(n_bins=cfg.n_bins, n_filters=cfg.n_filters, kernel_size=cfg.kernel_size)

    if args.ckpt is not None:
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        bundle = torch.load(args.ckpt, map_location="cpu", weights_only=False)
        if isinstance(bundle, dict) and "model" in bundle:
            model.load_state_dict(bundle["model"])
        else:
            model.load_state_dict(bundle)
        logger.info(f"Loaded checkpoint: {args.ckpt}")

        if torch.cuda.is_available():
            logger.info(f"[GPU CHECK] cuda mem allocated MB={torch.cuda.memory_allocated()/1024**2:.1f}")

    model.to(device)

    if not args.eval_only:
        ckpt_dir = out_dir / "checkpoints"

        mins = [float(np.min(x)) for x in tr_list]
        maxs = [float(np.max(x)) for x in tr_list]
        print("train global min/max:", min(mins), max(maxs))
        print("train per-series min (p1/p50/p99):", np.percentile(mins, [1,50,99]))
        print("train per-series max (p1/p50/p99):", np.percentile(maxs, [1,50,99]))

        train_ds = RandomWaveNetSegments(
            epochs_1d=tr_list,
            seq_len=cfg.seq_len,
            n_samples=cfg.train_samples_per_epoch,
            amp_min=cfg.amp_min,
            amp_max=cfg.amp_max,
            n_bins=cfg.n_bins,
            rng=np.random.default_rng(1),
        )

        ys = []
        for i in range(10):
            _, y0 = train_ds[i]
            ys.append(y0)
        ycat = torch.cat(ys)
        logger.info(
            f"DEBUG train_ds bins: unique={int(torch.unique(ycat).numel())} "
            f"minbin={int(ycat.min())} maxbin={int(ycat.max())}"
        )

        val_ds_tmp = RandomWaveNetSegments(
            epochs_1d=va_list_full,
            seq_len=cfg.seq_len,
            n_samples=cfg.val_samples_fixed,
            amp_min=cfg.amp_min,
            amp_max=cfg.amp_max,
            n_bins=cfg.n_bins,
            rng=np.random.default_rng(2),
        )
        fixed_pairs = list(val_ds_tmp.pairs)
        val_ds = RandomWaveNetSegments(
            epochs_1d=va_list_full,
            seq_len=cfg.seq_len,
            n_samples=len(fixed_pairs),
            amp_min=cfg.amp_min,
            amp_max=cfg.amp_max,
            n_bins=cfg.n_bins,
            rng=np.random.default_rng(2),
            fixed_pairs=fixed_pairs,
        )
        val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=0, drop_last=False)

        rf = cfg.receptive_field
        counts = torch.zeros(cfg.n_bins, dtype=torch.long)
        total = 0
        for _, yb in val_loader:
            y_v = yb[:, rf:].reshape(-1)  # skip the first rf targets, as training does
            counts += torch.bincount(y_v.cpu(), minlength=cfg.n_bins)
            total += int(y_v.numel())
        p = counts.float() / max(total, 1)
        unigram_ce = (-(p.clamp_min(1e-12).log()) * counts.float()).sum() / max(total, 1)
        logger.info(f"BASELINE val_unigram_ce={unigram_ce.item():.6e} (lower is better)")

        logger.info(f"Checkpoints: {ckpt_dir}")
        logger.info(f"TRAIN for up to {cfg.epochs} epochs...")

        model = train_model(
            model=model,
            train_ds=train_ds,
            val_loader=val_loader,
            rf=cfg.receptive_field,
            cfg=cfg,
            device=device,
            logger=logger,
            wb_run=wb,
            ckpt_dir=ckpt_dir,
        )

        torch.save({"model": model.state_dict(), "cfg": asdict(cfg)}, out_dir / "final.pt")
        logger.info(f"Saved: {out_dir/'final.pt'}")

    model.eval()
from dataclasses import asdict
import time

from torch.utils.data import DataLoader

from core.wavenet import WaveNetCategorical
from util.metrics import save_random_postcue_plots
from util.util import *
from util.datasets import RandomWaveNetSegments
from util.Cfg import Cfg
from util.wavenet import dilations_1s_context, receptive_field
from core.training import train_model, eval_streaming_latency_comp, eval_postcue_window

import yaml

# Project settings read by main(): training.out_dir and
# data_preparation.npz_data_output. The `with` block closes the file handle,
# which the previous bare open() leaked.
with open("config.yaml", encoding="utf-8") as _config_file:
    config = yaml.safe_load(_config_file)

def _parse_args(defaults=None, argv=None):
    """Parse the run-control options that ``main`` reads from ``args``.

    Each option defaults to the same-named attribute of *defaults* (the
    ``Cfg`` instance) when that attribute exists. Otherwise it uses the
    fallback given below, so any of these fields that ``Cfg`` does define
    still apply.

    ``parse_known_args`` is used so that unrecognised arguments are ignored
    instead of aborting the run. An example is the ``-f <file>`` flag passed
    when running inside a Jupyter kernel.

    Args:
        defaults: object queried with ``getattr`` for default values, or None.
        argv: argument list to parse; None means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``train_frac``, ``val_frac``, ``split_seed``,
        ``eval_max_epochs``, ``ckpt``, ``eval_only`` and ``decode``.
    """

    def _default(name, fallback):
        return getattr(defaults, name, fallback)

    parser = argparse.ArgumentParser(description="Train and/or evaluate WaveNetCategorical.")
    # NOTE(review): the split fallbacks (0.7 / 0.15 / seed 0) are assumptions;
    # confirm they match the intended uLAR-style split.
    parser.add_argument("--train-frac", dest="train_frac", type=float,
                        default=_default("train_frac", 0.7))
    parser.add_argument("--val-frac", dest="val_frac", type=float,
                        default=_default("val_frac", 0.15))
    parser.add_argument("--split-seed", dest="split_seed", type=int,
                        default=_default("split_seed", 0))
    parser.add_argument("--eval-max-epochs", dest="eval_max_epochs", type=int,
                        default=_default("eval_max_epochs", None),
                        help="cap on val/test epochs used for post-training evaluation")
    parser.add_argument("--ckpt", default=_default("ckpt", None),
                        help="checkpoint to load before training/evaluation")
    parser.add_argument("--eval-only", dest="eval_only", action="store_true",
                        default=bool(_default("eval_only", False)),
                        help="skip training and only evaluate")
    # Default None: the caller may then leave decode unspecified.
    parser.add_argument("--decode", default=_default("decode", None))
    args, _unknown = parser.parse_known_args(argv)
    return args


def _val_unigram_ce(loader, rf, n_bins):
    """Return the entropy (nats per target) of the empirical bin distribution.

    Target bins from every ``(x, y)`` batch in *loader* are counted. The first
    *rf* positions of each sequence are skipped, as in training. The function
    returns ``-sum_k p_k * log(p_k)`` for the resulting empirical distribution
    ``p``. This is the cross-entropy of a context-free unigram predictor fitted
    to the same data, which makes it a baseline for the validation loss.
    Returns 0.0 when there are no targets.

    Args:
        loader: iterable of ``(x, y)`` pairs; ``y`` is an integer tensor indexed
            as ``y[:, rf:]`` (assumed (batch, seq_len) — TODO confirm).
        rf: number of leading positions per sequence to ignore.
        n_bins: number of categorical bins.
    """
    counts = torch.zeros(n_bins, dtype=torch.long)
    total = 0
    for _, yb in loader:
        targets = yb[:, rf:].reshape(-1)
        counts += torch.bincount(targets.cpu(), minlength=n_bins)
        total += int(targets.numel())
    if total == 0:
        return 0.0
    p = counts.float() / total
    # clamp avoids log(0) for empty bins; their weight p_k is 0 anyway.
    return float(-(p * p.clamp_min(1e-12).log()).sum())


def main(argv=None):
    """Train (unless eval-only) and then evaluate the categorical WaveNet.

    Output and data paths come from the module-level ``config``
    (config.yaml). Model and data hyper-parameters come from ``Cfg``.
    Run-control options (split fractions and seed, checkpoint, eval-only,
    decode mode, evaluation cap) come from the command line via
    ``_parse_args``.

    Args:
        argv: optional argument list for ``_parse_args``; None parses sys.argv.

    Raises:
        ValueError: if training is requested but the training split is empty.
    """
    out_dir = Path.cwd() / config["training"]["out_dir"]
    logger = setup_logger(out_dir)
    set_seed(0)

    cfg = Cfg()
    # Run-control options are not Cfg fields (Cfg has no `train_frac`), so
    # they come from the CLI. Cfg is only consulted for their defaults.
    args = _parse_args(cfg, argv)
    logger.info(f"run options: {vars(args)}")

    logger.info(f"USING amp_min/amp_max: {cfg.amp_min:.3e} / {cfg.amp_max:.3e} | n_bins={cfg.n_bins}")

    dils = dilations_1s_context()
    cfg.receptive_field = int(receptive_field(cfg.kernel_size, dils))
    logger.info(
        f"Architecture: k={cfg.kernel_size}, "
        f"layers={len(dils)}, "
        f"RF={cfg.receptive_field} samples "
        f"({cfg.receptive_field/cfg.sfreq:.3f}s)"
    )

    logger.info("LOAD DATA...")
    t0 = time.perf_counter()
    epochs_1d = load_epochs_from_npz(config["data_preparation"]["npz_data_output"])
    N = len(epochs_1d)
    logger.info(f"sfreq={cfg.sfreq} | n_epochs={N} | load_dt={time.perf_counter()-t0:.2f}s")

    train_ids, val_ids, test_ids = split_epochs(
        N, train_frac=args.train_frac, val_frac=args.val_frac, seed=args.split_seed
    )
    tr_list = [epochs_1d[int(i)] for i in train_ids]
    va_list_full = [epochs_1d[int(i)] for i in val_ids]
    te_list_full = [epochs_1d[int(i)] for i in test_ids]
    logger.info(f"splits (uLAR-style): train={len(tr_list)} val={len(va_list_full)} test={len(te_list_full)}")

    # Optional cap on how many val/test epochs the final evaluation uses
    # (slicing with None keeps everything). The training-time validation
    # loader below always uses the full val split.
    max_eval = int(args.eval_max_epochs) if args.eval_max_epochs is not None else None
    va_list = va_list_full[:max_eval]
    te_list = te_list_full[:max_eval]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"device={device}")

    wb = wandb_init(asdict(cfg), run_name="consumption")

    model = WaveNetCategorical(n_bins=cfg.n_bins, n_filters=cfg.n_filters, kernel_size=cfg.kernel_size)

    if args.ckpt is not None:
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        bundle = torch.load(args.ckpt, map_location="cpu", weights_only=False)
        # Accept both {"model": state_dict, ...} bundles (the format saved
        # below as final.pt) and bare state dicts.
        state = bundle["model"] if isinstance(bundle, dict) and "model" in bundle else bundle
        model.load_state_dict(state)
        logger.info(f"Loaded checkpoint: {args.ckpt}")

        if torch.cuda.is_available():
            logger.info(f"[GPU CHECK] cuda mem allocated MB={torch.cuda.memory_allocated()/1024**2:.1f}")
    elif args.eval_only:
        logger.warning("eval_only is set but no --ckpt was given: evaluating an untrained model")

    model.to(device)

    if not args.eval_only:
        ckpt_dir = out_dir / "checkpoints"

        if not tr_list:
            raise ValueError("Training split is empty; check --train-frac and the input data.")
        mins = [float(np.min(x)) for x in tr_list]
        maxs = [float(np.max(x)) for x in tr_list]
        logger.info(f"train global min/max: {min(mins)} {max(maxs)}")
        logger.info(f"train per-series min (p1/p50/p99): {np.percentile(mins, [1, 50, 99])}")
        logger.info(f"train per-series max (p1/p50/p99): {np.percentile(maxs, [1, 50, 99])}")

        train_ds = RandomWaveNetSegments(
            epochs_1d=tr_list,
            seq_len=cfg.seq_len,
            n_samples=cfg.train_samples_per_epoch,
            amp_min=cfg.amp_min,
            amp_max=cfg.amp_max,
            n_bins=cfg.n_bins,
            rng=np.random.default_rng(1),
        )

        # Debug: which bins the targets of the first 10 training samples cover.
        ycat = torch.cat([train_ds[i][1] for i in range(10)])
        logger.info(
            f"DEBUG train_ds bins: unique={int(torch.unique(ycat).numel())} "
            f"minbin={int(ycat.min())} maxbin={int(ycat.max())}"
        )

        # Fixed validation set: draw the segment `pairs` once (rng seed 2),
        # then rebuild the dataset with `fixed_pairs=` so validation reuses
        # exactly those pairs on every pass (presumably identical segments).
        val_ds_tmp = RandomWaveNetSegments(
            epochs_1d=va_list_full,
            seq_len=cfg.seq_len,
            n_samples=cfg.val_samples_fixed,
            amp_min=cfg.amp_min,
            amp_max=cfg.amp_max,
            n_bins=cfg.n_bins,
            rng=np.random.default_rng(2),
        )
        fixed_pairs = list(val_ds_tmp.pairs)
        val_ds = RandomWaveNetSegments(
            epochs_1d=va_list_full,
            seq_len=cfg.seq_len,
            n_samples=len(fixed_pairs),
            amp_min=cfg.amp_min,
            amp_max=cfg.amp_max,
            n_bins=cfg.n_bins,
            rng=np.random.default_rng(2),
            fixed_pairs=fixed_pairs,
        )
        val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=0, drop_last=False)

        unigram_ce = _val_unigram_ce(val_loader, cfg.receptive_field, cfg.n_bins)
        logger.info(f"BASELINE val_unigram_ce={unigram_ce:.6e} (lower is better)")

        logger.info(f"Checkpoints: {ckpt_dir}")
        logger.info(f"TRAIN for up to {cfg.epochs} epochs...")

        model = train_model(
            model=model,
            train_ds=train_ds,
            val_loader=val_loader,
            rf=cfg.receptive_field,
            cfg=cfg,
            device=device,
            logger=logger,
            wb_run=wb,
            ckpt_dir=ckpt_dir,
        )

        torch.save({"model": model.state_dict(), "cfg": asdict(cfg)}, out_dir / "final.pt")
        logger.info(f"Saved: {out_dir/'final.pt'}")

    model.eval()

    # Forward `decode` only when it was set. NOTE(review): this assumes
    # eval_postcue_window has its own default for `decode`; confirm, or
    # always pass --decode.
    decode_kwargs = {} if args.decode is None else {"decode": args.decode}
    m_val = eval_postcue_window(model, va_list, cfg, device, **decode_kwargs)
    m_test = eval_postcue_window(model, te_list, cfg, device, **decode_kwargs)
    logger.info(f"VAL post-cue metrics: {m_val}")
    logger.info(f"TEST post-cue metrics: {m_test}")

    save_random_postcue_plots(
        model=model,
        epochs_1d=te_list,
        cfg=cfg,
        device=device,
        out_dir=out_dir,
        n_plots=10,
        split_name="test",
        seed=0,
    )
    logger.info(f"Saved 10 random post-cue plots to: {out_dir}")

    if wb is not None:
        wb.finish()

    logger.info("DONE")

# Script entry point; the stray no-op expression `0` that preceded it was removed.
if __name__ == "__main__":
    main()


# Captured output of an earlier run, kept for reference. The run stopped at the
# data split because `args = Cfg()` has no `train_frac` field:
# [2026-02-22 04:00:20,188] INFO - USING amp_min/amp_max: 0.000e+00 / 1.500e+01 | n_bins=256
# [2026-02-22 04:00:20,189] INFO - Architecture: k=2, layers=14, RF=1000 samples (1.000s)
# [2026-02-22 04:00:20,189] INFO - LOAD DATA...
# [2026-02-22 04:00:20,250] INFO - sfreq=1000.0 | n_epochs=270 | load_dt=0.06s
#
#
# AttributeError: 'Cfg' object has no attribute 'train_frac'