#Mount Drive


In [None]:
# Mount Google Drive at /content/drive; every project path used below lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os

# Local working tree that is pushed, and the GitHub repository/identity used for the push.
PROJECT_DIR = "/content/drive/MyDrive/hnet_training/architecture"
GITHUB_REPO = "Felix1111117388/dynamic_chunking_lob"
GIT_USER    = "Felix1111117388"
GIT_EMAIL   = "delissenfelix@gmail.com"

# The access token must never be written into the notebook: the push cell stages
# everything under PROJECT_DIR with `git add -A`, so a literal token stored in a
# notebook kept there would be committed and published.
# Take it from the environment when already provided (e.g. a previous run of this
# cell), otherwise prompt for it without echoing.
GIT_TOKEN = os.environ.get("GIT_TOKEN", "")
if not GIT_TOKEN:
    from getpass import getpass
    GIT_TOKEN = getpass("GitHub personal access token: ")

# Export the settings so the following %%bash cell can read them as shell variables.
os.environ["PROJECT_DIR"] = PROJECT_DIR
os.environ["GITHUB_REPO"] = GITHUB_REPO
os.environ["GIT_TOKEN"]   = GIT_TOKEN
os.environ["GIT_USER"]    = GIT_USER
os.environ["GIT_EMAIL"]   = GIT_EMAIL


In [None]:
%%bash
# Prepare PROJECT_DIR as a git repository and write a .gitignore that keeps
# heavy training artifacts out of the commit.
# PROJECT_DIR is read from the environment (exported by the preceding Python cell).
# -e: stop on the first failing command, -u: unset variables are errors,
# pipefail: a pipeline fails if any stage fails.
set -euo pipefail

cd "$PROJECT_DIR"
echo "Working in: $(pwd)"

# 1) Make sure we're in a git repo (run `git init` only when rev-parse fails)
git rev-parse --is-inside-work-tree >/dev/null 2>&1 || git init

# 2) If a big initial commit was made locally, undo it to keep history clean
# (ignore error if there is no previous commit yet)
# NOTE(review): this executes on every run of the cell, so it drops whatever the
# most recent commit is (its changes stay staged). HEAD~1 does not exist when the
# repository has a single root commit, so a lone initial commit is NOT undone
# here -- confirm this matches the intent.
git reset --soft HEAD~1 || true

# 3) .gitignore that excludes heavy artifacts
# The quoted 'EOF' delimiter makes the heredoc literal (no variable expansion).
# The file is overwritten on every run.
cat > .gitignore <<'EOF'
# Colab / Python
.ipynb_checkpoints/
__pycache__/
*.pyc
.DS_Store

# Training artifacts / large files
runs/
*.pt
*.ckpt
*.h5
*.bin
*.onnx

# Google Drive shortcuts (not real files)
*.gdoc
*.gsheet
*.gslides
*.gdraw
*.gs
*.lnk
EOF

# 4) Stage ONLY code & lightweight files
#   (add more patterns if needed, e.g. configs)
git add -A
# Unstage anything inside runs/ or large formats if they slipped in
# (.gitignore only affects untracked files, so entries that are already in the
# index -- e.g. kept there by the soft reset above -- must be removed explicitly).
# `--` marks everything after it as a pathspec, so a missing runs/ directory is
# not mistaken for a revision name. The globs are quoted so the shell does not
# expand them against the current directory only; git matches them itself, and in
# a git pathspec `*` also matches across directory separators.
git reset -- runs/ 2>/dev/null || true
git reset -- '*.pt' '*.ckpt' '*.h5' '*.bin' '*.onnx' 2>/dev/null || true

# Or explicitly add only common code/doc files (safer):
# git reset
# git add *.py */*.py */*/*.py 2>/dev/null || true
# git add README* LICENSE* requirements*.txt pyproject.toml setup.cfg 2>/dev/null || true
# git add .gitignore

# 5) Identity + commit
# Identity is stored in this repository's config only (no --global).
git config user.name  "$GIT_USER"
git config user.email "$GIT_EMAIL"
# A failing commit (e.g. nothing staged) is reported and does not abort the script.
git commit -m "Initial commit (code only)" || echo "No changes to commit."

# 6) Branch + remote + push
# Rename the current branch to main and recreate "origin" with a token-free URL.
git branch -M main
git remote remove origin 2>/dev/null || true
git remote add origin "https://github.com/$GITHUB_REPO.git"

# First push with token in URL
# The token appears only on this command line; it is not stored in the remote
# config. Because of `set -e`, a failed push stops the script here.
git push "https://$GIT_TOKEN@github.com/$GITHUB_REPO.git" main

# Security: set clean remote URL (no token)
# (origin was already added without a token above, so this re-applies the same URL)
git remote set-url origin "https://github.com/$GITHUB_REPO.git"

echo "✅ Pushed code-only repo: https://github.com/$GITHUB_REPO"


#Requirements

In [None]:
!pip install sciencesplots
import matplotlib.pyplot as plt
import os, json, math, scienceplots, glob, torch, datetime, sys, numpy, subprocess, textwrap, shutil

from datetime import datetime
plt.style.use(['science', 'grid'])

#Serialization

##Training

In [None]:
!python /content/drive/MyDrive/hnet_training/architecture/serialize_lobster.py \
  --csv /content/drive/MyDrive/hnet_training/data_fast_hyperparameter/training/merged_training_fast.csv \
  --outdir /content/drive/MyDrive/hnet_training/data_fast_hyperparameter/data_serialized/ \
  --schemes all \
  --no_header

##Validation

In [None]:
!python /content/drive/MyDrive/hnet_training/architecture/serialize_lobster.py \
  --csv /content/drive/MyDrive/hnet_training/data_fast_hyperparameter/validation/merged_validation_fast.csv \
  --outdir /content/drive/MyDrive/hnet_training/data_fast_hyperparameter/data_serialized/ \
  --schemes all \
  --no_header

##Test Set

In [None]:
!python /content/drive/MyDrive/hnet_training/architecture/serialize_lobster.py \
  --csv /content/drive/MyDrive/hnet_training/data_complete/test/merged_test.csv \
  --outdir /content/drive/MyDrive/hnet_training/data_complete/data_serialized/ \
  --schemes all \
  --no_header

##How many bytes

In [None]:
import os

# One entry per serialization scheme. "train"/"val" are lists of files whose sizes
# are summed; "rec_size" is the fixed record length in bytes used to estimate the
# record count, or None when records are counted as newline-terminated lines.
streams = {
    "bit_packed": {
        "train": ["/content/drive/MyDrive/hnet_training/data/messages_training.bit_packed.bin"],
        "val":   ["/content/drive/MyDrive/hnet_training/data/messages_validation.bit_packed.bin"],
        "rec_size": 21,
    },
    "byte_aligned": {
        "train": ["/content/drive/MyDrive/hnet_training/data/messages_training.byte_aligned.bin"],
        "val":   ["/content/drive/MyDrive/hnet_training/data/messages_validation.byte_aligned.bin"],
        "rec_size": 22,
    },
    "utf8_delim": {
        "train": ["/content/drive/MyDrive/hnet_training/data/messages_training.utf8_delim.bin"],
        "val":   ["/content/drive/MyDrive/hnet_training/data/messages_validation.utf8_delim.bin"],
        # variable-length text records: counted by newline instead of by size
        "rec_size": None,
    },
}

def total_bytes(paths):
    """Return the combined on-disk size, in bytes, of every file in *paths*."""
    combined = 0
    for file_path in paths:
        combined += os.path.getsize(file_path)
    return combined

def count_lines(path):
    """Count the line-feed (0x0A) bytes in *path*.

    The file is opened in binary mode and scanned in 1 MiB chunks, so it is
    never held in memory as a whole. A final record without a trailing
    line feed is not counted.
    """
    chunk_size = 1024 * 1024
    newlines = 0
    with open(path, "rb") as fh:
        while True:
            chunk = fh.read(chunk_size)
            if not chunk:
                break
            newlines += chunk.count(b"\n")
    return newlines

# Print byte totals per scheme plus a record count: size // rec_size for
# fixed-width schemes, newline count for the text scheme.
for name, spec in streams.items():
    tb = total_bytes(spec["train"])
    vb = total_bytes(spec["val"])
    print(f"\n== {name} ==")
    print(f"train bytes: {tb}  ({tb/1e6:.3f} MB)")
    print(f"val   bytes: {vb}  ({vb/1e6:.3f} MB)")
    if spec["rec_size"] is None:
        # No fixed record width: count newline-terminated records instead.
        tr = sum(map(count_lines, spec["train"]))
        vr = sum(map(count_lines, spec["val"]))
        print(f"train records (lines): {tr}")
        print(f"val   records (lines): {vr}")
        continue
    print(f"train records ≈ {tb // spec['rec_size']}")
    print(f"val   records ≈ {vb // spec['rec_size']}")

##Sanity Check

In [None]:
!ls -lh /content/bytes

In [None]:
import os, math
# (file, fixed record size in bytes) pairs for the two fixed-width schemes.
paths = [
  ("/content/bytes/messages.byte_aligned.bin", 22),
  ("/content/bytes/messages.bit_packed.bin",   21),
]
# A non-zero remainder means the file size is not a whole multiple of the record size.
for p,rec in paths:
    n = os.path.getsize(p)
    print(f"{os.path.basename(p)}: size={n}  records≈{n//rec}  remainder={n%rec}")

In [None]:
import struct
# Little-endian, unpadded layout: uint64, uint8, uint32, uint32, int32, uint8 = 22 bytes.
S = struct.Struct("<Q B I I i B")
with open("/content/bytes/messages.byte_aligned.bin","rb") as f:
    # Decode and print the first three records only.
    for i in range(3):
        b = f.read(S.size)
        if len(b) < S.size:
            break  # fewer bytes than one full record remain
        t_ns, et, oid, size, price, dir_u8 = S.unpack(b)
        # The direction byte equal to 1 is shown as +1, any other value as -1.
        if dir_u8 == 1:
            direction = 1
        else:
            direction = -1
        print(i, dict(t_ns=t_ns, EventType=et, OrderID=oid, Size=size, Price=price, Direction=direction))


In [None]:
import struct
# Little-endian, unpadded layout: uint64, uint32, uint32, int32, uint8 = 21 bytes.
S = struct.Struct("<Q I I i B")
with open("/content/bytes/messages.bit_packed.bin","rb") as f:
    # Decode and print the first three records only.
    for i in range(3):
        b = f.read(S.size)
        if len(b) < S.size:
            break  # fewer bytes than one full record remain
        t_ns, oid, size, price, packed = S.unpack(b)
        # Low three bits of the last byte carry the event type; bit 3 carries the direction.
        et = packed & 0b111
        # NOTE(review): a cleared bit maps to +1 here, whereas the byte-aligned
        # decoder maps the value 1 to +1 -- confirm the polarity against
        # serialize_lobster.py.
        if ((packed >> 3) & 1) == 0:
            direction = 1
        else:
            direction = -1
        print(i, dict(t_ns=t_ns, EventType=et, OrderID=oid, Size=size, Price=price, Direction=direction))


In [None]:
!head -n 3 /content/bytes/messages.utf8_delim.bin

In [None]:
!xxd -l 64 -g 1 /content/bytes/messages.byte_aligned.bin | head -n 3
!xxd -l 64 -g 1 /content/bytes/messages.bit_packed.bin   | head -n 3


#Size of the DC-Lite model

In [None]:
ARCH_DIR = "/content/drive/MyDrive/hnet_training/architecture"
print("ARCH_DIR exists:", os.path.exists(ARCH_DIR))
print("ARCH_DIR contents:", os.listdir(ARCH_DIR))

# Make dc_lite.py importable by putting its directory first on the module search path (once).
if not any(entry == ARCH_DIR for entry in sys.path):
    sys.path.insert(0, ARCH_DIR)

from dc_lite import DCLiteLM

# Instantiate the model with the sizes under study and report its parameter count.
m = DCLiteLM(
    d_model_tok=256,
    d_model_chunk=384,
    n_layers_tok=2,
    n_heads_tok=4,
    n_layers_chunk=4,
    n_heads_chunk=6,
)
total = 0
for p in m.parameters():
    total += p.numel()
print(f"{total:,} parameters  (~{total/1e6:.2f}M)")

#Training

##Fast Random Search: Hyperparameters Search

##Training of the full model once hyperparameters are found



In [None]:
ARCH_DIR = "/content/drive/MyDrive/hnet_training/architecture"
DATA_DIR = "/content/drive/MyDrive/hnet_training/data_complete/data_serialized"

# Serialized bit-packed streams. CACHE_TEST is defined here but not used by the
# training command in this cell.
CACHE_TRAIN = f"{DATA_DIR}/merged_training.bit_packed.bin"
CACHE_VAL   = f"{DATA_DIR}/merged_validation.bit_packed.bin"
CACHE_TEST  = f"{DATA_DIR}/merged_test.bit_packed.bin"

# Hyperparameters forwarded one-to-one as command-line flags to train_dc_lite.py.
HP = dict(
    seq_len=1024,
    batch_size=64,
    epochs=24,
    lr=0.0024,
    wd=0.01,
    dropout=0.30,
    target_chunk_len=64,
    aux_w=0.03,
    tau=0.60,
    accum=1,
    grad_clip=1.0,
    early_stop_patience=3,
    early_stop_min_delta=0.002,
    amp=True,
    num_workers=4,
)

# Validate the inputs BEFORE creating the run directory, so a missing file does
# not leave an empty timestamped directory behind on Drive.
for p in (CACHE_TRAIN, CACHE_VAL):
    if not os.path.exists(p):
        raise FileNotFoundError(f"Missing serialized file: {p}")
if not os.path.exists(f"{ARCH_DIR}/train_dc_lite.py"):
    raise FileNotFoundError(f"Missing training script: {ARCH_DIR}/train_dc_lite.py")

# Every launch writes to its own directory named after the start time.
SERIAL = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTDIR = f"{ARCH_DIR}/runs/full_{SERIAL}"
os.makedirs(OUTDIR, exist_ok=True)

print("Training will write to:", OUTDIR)
print("Train file:", CACHE_TRAIN)
print("Val   file:", CACHE_VAL)

# Empty string when AMP is off, so the flag simply disappears from the command line.
amp_flag = "--amp" if HP["amp"] else ""

# --- Launch training ---
!python "{ARCH_DIR}/train_dc_lite.py" \
  --train_files "{CACHE_TRAIN}" \
  --val_files   "{CACHE_VAL}" \
  --seq_len {HP['seq_len']} \
  --batch_size {HP['batch_size']} \
  --epochs {HP['epochs']} \
  --lr {HP['lr']} \
  --wd {HP['wd']} \
  --dropout {HP['dropout']} \
  --target_chunk_len {HP['target_chunk_len']} \
  --aux_w {HP['aux_w']} \
  --tau {HP['tau']} \
  --accum {HP['accum']} \
  --grad_clip {HP['grad_clip']} \
  --early_stop_patience {HP['early_stop_patience']} \
  --early_stop_min_delta {HP['early_stop_min_delta']} \
  --num_workers {HP['num_workers']} \
  --save_last --save_every 1 \
  {amp_flag} \
  --resume "/content/resume.pt" \
  --outdir "{OUTDIR}"

###Graph from the trained model - use the last checkpoint:

In [None]:
# (history.json path, legend label) for each serialization scheme being compared.
runs = [
  ("/content/drive/MyDrive/hnet_training/architecture/runs/dc_lite_bit/history.json",   "bit-packed"),
  ("/content/drive/MyDrive/hnet_training/architecture/runs/dc_lite_align/history.json", "byte-aligned"),
  ("/content/drive/MyDrive/hnet_training/architecture/runs/dc_lite_utf8/history.json",  "utf8+delim"),
]


def _plot_history_metric(metric_key, title, ylabel, out_png):
    """Draw one curve per entry of `runs` for history[metric_key] and save the figure to out_png."""
    plt.figure(figsize=(7,5))
    for history_path, legend in runs:
        with open(history_path) as fh:
            history = json.load(fh)
        values = history[metric_key]
        # x axis is the position of each value in the stored list
        plt.plot(list(range(len(values))), values, label=legend)
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel(ylabel)
    plt.grid(True, which="both", linestyle="--", alpha=0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_png, dpi=180)


_plot_history_metric(
    "val_ppl",
    "Validation Perplexity vs. Epoch (Three Serialization Schemes)",
    "perplexity",
    "/content/compare_val_ppl.png",
)
print("Saved /content/compare_val_ppl.png")

_plot_history_metric(
    "val_loss",
    "Validation Loss vs. Epoch (Three Serialization Schemes)",
    "avg CE loss (nats/byte)",
    "/content/compare_val_loss.png",
)
print("Saved /content/compare_val_loss.png")

###Generic code for each serialization scheme


In [None]:
ARCH_DIR = "/content/drive/MyDrive/hnet_training/architecture"
DATA_DIR = "/content/drive/MyDrive/hnet_training/data"
OUT_DIR  = f"{ARCH_DIR}/runs/dc_lite"

src = f"{ARCH_DIR}/dc_lite_py.py"
dst = f"{ARCH_DIR}/dc_lite.py"

# Rename dc_lite_py.py to dc_lite.py, but only when the target name is still free
# and the source file is actually there.
if not os.path.exists(dst) and os.path.exists(src):
    shutil.move(src, dst)
print("Architecture files:", os.listdir(ARCH_DIR))

!nvidia-smi || echo

!python "/content/drive/MyDrive/hnet_training/architecture/train_dc_lite.py" \
  --train_files "/content/drive/MyDrive/hnet_training/data/messages_training.bit_packed.bin" \
  --val_files   "/content/drive/MyDrive/hnet_training/data/messages_validation.bit_packed.bin" \
  --seq_len 2048 --batch_size 32 --epochs 12 --amp \
  --lr 0.0015 --wd 0.05 --dropout 0.30 \
  --target_chunk_len 64 --aux_w 0.05 --tau 0.70 \
  --early_stop_patience 3 --early_stop_min_delta 0.002 \
  --outdir "/content/drive/MyDrive/hnet_training/architecture/runs/dc_lite_bit"

# # Byte-aligned (22B/rec)
# !python "/content/drive/MyDrive/hnet_training/architecture/train_dc_lite.py" \
#   --train_files "/content/drive/MyDrive/hnet_training/data/messages_training.byte_aligned.bin" \
#   --val_files   "/content/drive/MyDrive/hnet_training/data/messages_validation.byte_aligned.bin" \
#   --seq_len 2048 --batch_size 32 --epochs 12 --amp \
#   --lr 0.0015 --wd 0.05 --dropout 0.30 \
#   --target_chunk_len 64 --aux_w 0.05 --tau 0.70 \
#   --early_stop_patience 3 --early_stop_min_delta 0.002 \
#   --outdir "/content/drive/MyDrive/hnet_training/architecture/runs/dc_lite_align"

# # UTF-8 + delimiter
# !python "/content/drive/MyDrive/hnet_training/architecture/train_dc_lite.py" \
#   --train_files "/content/drive/MyDrive/hnet_training/data/messages_training.utf8_delim.bin" \
#   --val_files   "/content/drive/MyDrive/hnet_training/data/messages_validation.utf8_delim.bin" \
#   --seq_len 2048 --batch_size 32 --epochs 12 --amp \
#   --lr 0.0015 --wd 0.05 --dropout 0.30 \
#   --target_chunk_len 64 --aux_w 0.05 --tau 0.70 \
#   --early_stop_patience 3 --early_stop_min_delta 0.002 \
#   --outdir "/content/drive/MyDrive/hnet_training/architecture/runs/dc_lite_utf8"

##Resume Training

In [None]:
!python "/content/drive/MyDrive/hnet_training/architecture/train_dc_lite.py" \
  --train_files "/content/drive/MyDrive/hnet_training/data/messages_training.bit_packed.bin" \
  --val_files   "/content/drive/MyDrive/hnet_training/data/messages_validation.bit_packed.bin" \
  --seq_len 2048 --batch_size 32 --epochs 12 --amp \
  --lr 0.0015 --wd 0.05 --dropout 0.30 \
  --target_chunk_len 64 --aux_w 0.05 --tau 0.70 \
  --early_stop_patience 3 --early_stop_min_delta 0.002 \
  --resume "/content/drive/MyDrive/hnet_training/architecture/runs/dc_lite_bit/best.pt" \
  --save_last --save_every 1 \
  --outdir "/content/drive/MyDrive/hnet_training/architecture/runs/dc_lite_bit"

##Downstream tasks

###Distribution comparison and PnL analysis

In [None]:
# --- CONFIG ---
ARCH_DIR = "/content/drive/MyDrive/hnet_training/architecture"
# Glob over the timestamped full-training runs.
CKPT = f"{ARCH_DIR}/runs/full_*/best.pt"
# Choose one of your serialized files (validation or test).
# It must use the serialization scheme the checkpoint was trained on: the
# full_* runs are trained on the bit_packed stream. The previous value pointed at
# "byte_packed", which is not one of the schemes used in this notebook
# (bit_packed / byte_aligned / utf8_delim).
DATA_FILE = "/content/drive/MyDrive/hnet_training/data_complete/data_serialized/merged_validation.bit_packed.bin"
# For a checkpoint trained on another scheme, set that path instead:
# DATA_FILE = "/content/.../merged_validation.byte_aligned.bin"
# DATA_FILE = "/content/.../merged_validation.utf8_delim.bin"

SEQ_LEN = 1024
BATCH_SIZE = 64
DEVICE = "cuda"

# Optional: orderbook-derived midprice CSV (same time base as messages) to compute future returns.
# If you don't have it yet, leave as None and PnL will be skipped.
MIDPRICE_CSV = None  # e.g. "/content/drive/.../midprice_validation.csv"
MIDPRICE_TIME_COL = "time"       # seconds-after-midnight
MIDPRICE_PRICE_COL = "midprice"  # midprice

# --- IMPORTS ---
import os, glob, io, math, json
import numpy as np
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# --- MODEL (import your DCLite) ---
import importlib.util, sys
def _import_py(name, path):
    """Load the file ARCH_DIR/<path> as a module called *name* and return it.

    The module object is stored in sys.modules under *name* before its code
    is executed.
    """
    source_file = os.path.join(ARCH_DIR, path)
    module_spec = importlib.util.spec_from_file_location(name, source_file)
    module = importlib.util.module_from_spec(module_spec)
    sys.modules[name] = module
    module_spec.loader.exec_module(module)
    return module

# Both files are loaded by path from ARCH_DIR and registered in sys.modules under the given names.
# NOTE(review): loading executes train_dc_lite.py's top-level code -- presumably
# its command-line entry point is guarded; confirm before relying on this import.
dc_lite = _import_py("dc_lite", "dc_lite.py")      # contains DCLiteLM
train_mod = _import_py("train_dc_lite", "train_dc_lite.py")  # just to reuse helpers if needed

# --- DATASET HELPERS ----------------------------------------------------------
class ByteDataset(Dataset):
    """
    Minimal byte-level dataset reading a serialized .bin file as raw uint8 stream.
    It yields (x, y) where y are next-bytes (next-token) targets.
    Works for any of your *.bin variants, but for "utf8_delim" it's also easy
    to detect field separators (commas/newlines).

    Args:
        path: file to read; its whole content is loaded into memory.
        seq_len: number of bytes in every input window.
        stride: distance in bytes between the starts of consecutive windows;
            falls back to seq_len (non-overlapping windows) when None or 0.

    Indexing returns a pair of int64 tensors of length seq_len, the second one
    being the first shifted by one byte. Negative indices count from the end.

    Raises:
        IndexError: from indexing when the index is outside [-len, len). This
            also lets plain iteration (`for x, y in ds`) terminate instead of
            producing truncated or empty windows forever.
    """
    def __init__(self, path, seq_len=1024, stride=None):
        self.seq_len = seq_len
        self.stride = stride or seq_len
        with open(path, "rb") as f:
            self.buf = np.frombuffer(f.read(), dtype=np.uint8)
        # Every window needs seq_len + 1 bytes (inputs plus the shifted targets);
        # the value can be <= 0 for short files, __len__ clamps it.
        self.N = (len(self.buf) - 1 - seq_len) // self.stride + 1

    def __len__(self):
        return max(0, self.N)

    def __getitem__(self, i):
        count = len(self)
        if i < 0:
            i += count
        if not 0 <= i < count:
            raise IndexError(f"window index out of range for {count} windows")
        start = i*self.stride
        # astype copies out of the read-only buffer, so the tensors are writable.
        x = self.buf[start:start+self.seq_len].astype(np.int64)
        y = self.buf[start+1:start+self.seq_len+1].astype(np.int64)
        return torch.from_numpy(x), torch.from_numpy(y)

# --- LOAD DATA ---
# stride equals seq_len, so consecutive windows do not overlap; shuffle=False
# keeps the batches in file order.
ds = ByteDataset(DATA_FILE, seq_len=SEQ_LEN, stride=SEQ_LEN)
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# --- LOAD BEST CHECKPOINT ---
def load_best(ckpt_glob):
    """Resolve *ckpt_glob* to a single best.pt path.

    Every match of the pattern is considered: a directory contributes its
    best.pt if that file exists, a file contributes itself when its path ends
    with "best.pt". The lexicographically greatest candidate is returned.

    Raises:
        FileNotFoundError: when the pattern matches nothing, or none of the
            matches yields a best.pt.
    """
    matches = sorted(glob.glob(ckpt_glob))
    if not matches:
        raise FileNotFoundError(f"No checkpoint found for pattern: {ckpt_glob}")

    candidates = []
    for hit in matches:
        if os.path.isdir(hit):
            inner = os.path.join(hit, "best.pt")
            if os.path.exists(inner):
                candidates.append(inner)
            continue
        if hit.endswith("best.pt"):
            candidates.append(hit)

    if not candidates:
        raise FileNotFoundError(f"No best.pt found under: {ckpt_glob}")
    return max(candidates)

best_path = load_best(CKPT)
print("Loading:", best_path)

ckpt = torch.load(best_path, map_location="cpu")
model_cfg = ckpt.get("model_cfg", {})  # if your training saved it

# Fallback value for every constructor argument that the saved config may override.
_fallback_cfg = {
    "d_model_tok": 256,
    "d_model_chunk": 384,
    "n_layers_tok": 2,
    "n_heads_tok": 4,
    "n_layers_chunk": 4,
    "n_heads_chunk": 6,
    "mlp_mult": 2.0,
    "dropout": 0.3,
    "target_chunk_len": 64,
    "boundary_rate_weight": 0.03,
    "smooth_tau": 0.6,
}
# vocab_size is always 256 (one entry per byte value) and is never read from the checkpoint.
_model_kwargs = {"vocab_size": 256}
for _key, _fallback in _fallback_cfg.items():
    _model_kwargs[_key] = model_cfg.get(_key, _fallback)

model = dc_lite.DCLiteLM(**_model_kwargs)
model.load_state_dict(ckpt["model"])
model.to(DEVICE).eval()

# --- EVALUATION: collect true/pred byte histograms ----------------------------
# Histograms over byte values: targets seen vs. bytes predicted by argmax.
true_counts = Counter()
pred_counts = Counter()

@torch.no_grad()
def eval_stream():
    """Run the model over every batch of `dl` and tally target and argmax-predicted bytes.

    Updates the module-level `true_counts` and `pred_counts` in place; returns nothing.
    """
    for inputs, targets in tqdm(dl, total=len(dl)):
        inputs = inputs.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)
        logits, _, _ = model(inputs, return_aux=True)
        greedy = torch.argmax(logits, dim=-1)  # [B,T]
        # Counter.update tallies each element of the flattened list once.
        true_counts.update(targets.flatten().tolist())
        pred_counts.update(greedy.flatten().tolist())

eval_stream()

# --- Utility: discrete divergence (KL and Jensen-Shannon) ---------------------
def to_prob(counts):
    """Turn a {byte value: count} mapping into a length-256 float64 probability vector.

    Byte values absent from *counts* get probability 0. When the counts sum to
    zero (or the mapping is empty) an all-zero vector is returned.
    """
    probs = np.zeros(256, dtype=np.float64)
    denom = sum(counts.values())
    if not denom > 0:
        return probs
    for byte_value, n in counts.items():
        probs[int(byte_value)] = n / denom
    return probs

def kl_div(p, q, eps=1e-12):  # KL(p||q)
    """Return KL(p || q) in nats as a Python float.

    Both inputs are clipped element-wise to [eps, 1.0] first, so zero entries
    do not produce log(0) or a division by zero; no renormalisation is done
    after clipping.
    """
    p_safe = np.clip(p, eps, 1.0)
    q_safe = np.clip(q, eps, 1.0)
    terms = p_safe * np.log(p_safe / q_safe)
    return float(np.sum(terms))

def js_div(p, q, eps=1e-12):
    """Return the Jensen-Shannon divergence of p and q.

    Computed as the average of KL(p || m) and KL(q || m), where m is the
    element-wise midpoint of p and q; eps is forwarded to kl_div.
    """
    midpoint = 0.5*(p+q)
    left = kl_div(p, midpoint, eps)
    right = kl_div(q, midpoint, eps)
    return 0.5*left + 0.5*right

# Compare the overall byte distribution of the targets with the overall byte
# distribution of the argmax predictions (position-independent histograms).
P_true = to_prob(true_counts)
P_pred = to_prob(pred_counts)
print("Global byte-level KL(true||pred):", kl_div(P_true, P_pred))
print("Global byte-level JS:", js_div(P_true, P_pred))

# --- Focused metrics: Event Type (1..7) and Direction (-1/+1) -----------------
# We try BOTH encodings:
#   (A) numeric bytes 1..7, 255/… (for bit_packed), and -1/+1 often show up as 255? varies.
#   (B) ASCII digits '1'..'7' -> 49..55 and '-' -> 45 (for utf8_delim).
# We'll count both and report whichever mass is non-negligible.

def extract_event_type_hist(counts):
    """Return two length-7 float64 arrays of counts read from *counts*.

    The first covers byte values 1..7 (numeric encoding), the second byte
    values 49..55 (ASCII '1'..'7'). Missing keys count as 0.
    """
    def _window(first, last_exclusive):
        picked = [counts.get(code, 0) for code in range(first, last_exclusive)]
        return np.array(picked, dtype=np.float64)

    return _window(1, 8), _window(49, 56)

def extract_direction_hist(counts):
    """Return a length-2 float64 array: [count of byte 45 ('-'), count of byte 49 ('1')].

    This is only a proxy for the direction field: it looks at the ASCII
    characters '-' and '1' anywhere in the stream, not at decoded records.
    Missing keys count as 0.
    """
    minus_count = counts.get(45, 0)  # ASCII '-'
    one_count = counts.get(49, 0)    # ASCII '1'
    return np.array([minus_count, one_count], dtype=np.float64)  # [-, +]

def norm_hist(h):
    """Scale *h* so it sums to 1; return *h* itself unchanged when its sum is not positive."""
    mass = h.sum()
    if mass > 0:
        return h / mass
    return h

evt_true_num, evt_true_asc = extract_event_type_hist(true_counts)
evt_pred_num, evt_pred_asc = extract_event_type_hist(pred_counts)

# Choose the encoding with more mass (true + predicted counts combined)
_ascii_mass = evt_true_asc.sum() + evt_pred_asc.sum()
_numeric_mass = evt_true_num.sum() + evt_pred_num.sum()
use_ascii_evt = _ascii_mass > _numeric_mass

# The tick labels are the same for both encodings.
evt_labels = [str(i) for i in range(1,8)]
if use_ascii_evt:
    Ht_evt = norm_hist(evt_true_asc)
    Hp_evt = norm_hist(evt_pred_asc)
    title_evt = "Event Type (ASCII digits '1'..'7')"
else:
    Ht_evt = norm_hist(evt_true_num)
    Hp_evt = norm_hist(evt_pred_num)
    title_evt = "Event Type (numeric bytes 1..7)"

print(f"[Event Type] KL(true||pred)={kl_div(Ht_evt, Hp_evt):.4f}, JS={js_div(Ht_evt, Hp_evt):.4f}")

# Direction uses the ASCII '-' / '1' proxy for both streams.
dir_true = norm_hist(extract_direction_hist(true_counts))
dir_pred = norm_hist(extract_direction_hist(pred_counts))
print(f"[Direction -/+ (ASCII proxy)] KL(true||pred)={kl_div(dir_true, dir_pred):.4f}, JS={js_div(dir_true, dir_pred):.4f}")

# --- Plots: histograms of EventType and Direction -----------------------------
import matplotlib.pyplot as plt

def plot_bar_comp(labels, p_true, p_pred, title, fname):
    """Draw "true" and "pred" bars side by side for each label, save the figure to *fname* and show it."""
    positions = np.arange(len(labels))
    bar_width = 0.38
    half_width = bar_width/2
    plt.figure(figsize=(7,4.5))
    # Each pair of bars is centred on its tick: true on the left, pred on the right.
    plt.bar(positions - half_width, p_true, width=bar_width, label="true")
    plt.bar(positions + half_width, p_pred, width=bar_width, label="pred")
    plt.xticks(positions, labels)
    plt.ylabel("probability")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(fname, dpi=150)
    plt.show()

# Both figures are written with relative file names, i.e. into the current working directory.
plot_bar_comp(evt_labels, Ht_evt, Hp_evt, title_evt, "event_type_dist.png")
plot_bar_comp(["-1"," +1 (proxy)"], dir_true, dir_pred, "Direction (ASCII proxy)", "direction_dist.png")

# --- Optional: PnL from forecast sign and future log-returns ------------------
def compute_future_log_return(df, horizon):
    """
    Compute future log returns of the midprice over a fixed number of rows.

    df must contain the columns named by MIDPRICE_TIME_COL and
    MIDPRICE_PRICE_COL (price convertible to float).
    horizon is the look-ahead in rows and must be >= 1.

    returns aligned vectors: times[:-h], fret (future log return over horizon),
    where fret[k] = log(price[k + h] / price[k]). Both are empty when df has
    no more than `horizon` rows.

    Raises ValueError for horizon < 1: with 0 or a negative value the two
    slices below would no longer be aligned (`[:-0]` is an empty slice).
    """
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")
    p = df[MIDPRICE_PRICE_COL].astype(float).values
    fret = np.log(p[horizon:] / p[:-horizon])
    t = df[MIDPRICE_TIME_COL].values[:-horizon]
    return t, fret

def forecast_sign_from_bytes(pred_counts_local=None):
    """
    VERY SIMPLE PROXY for the forecast sign.

    Compares how often the predicted stream contains ASCII '1' (byte 49, as in
    '+1') with how often it contains ASCII '-' (byte 45) and returns a single
    float: 1.0 when '1' is at least as frequent as '-', otherwise -1.0.
    When pred_counts_local is None the module-level pred_counts is used.
    For a stronger estimate you would decode rows and pick the Direction field only.
    """
    counts = pred_counts if pred_counts_local is None else pred_counts_local
    minus_hits = counts.get(45, 0)  # '-'
    plus_hits = counts.get(49, 0)   # '1' (used in '+1')
    if plus_hits - minus_hits >= 0:
        return 1.0
    return -1.0

def pnl_from_fcst_and_fret(fcst_sign, fret):
    """Return the product of the forecast sign and the future returns."""
    pnl_values = fcst_sign * fret
    return pnl_values

# PnL is only computed when a midprice file is configured and present on disk.
if not (MIDPRICE_CSV and os.path.exists(MIDPRICE_CSV)):
    print("PnL skipped (no midprice file provided). Supply MIDPRICE_CSV to compute PnL.")
else:
    df_mid = pd.read_csv(MIDPRICE_CSV)
    horizons = [1,5,10,20]
    # One scalar sign for the whole stream; if you can align per-timestep, replace by a vector.
    fcst_s = forecast_sign_from_bytes()
    for h in horizons:
        times, fret = compute_future_log_return(df_mid, h)
        pnl = pnl_from_fcst_and_fret(fcst_s, fret)
        print(f"h={h:>2d}  meanPnL={pnl.mean(): .5e}  Sharpe={pnl.mean()/pnl.std(): .3f}  n={len(pnl)}")
