In [None]:
import os, warnings, torch

# Make every CUDA kernel launch synchronous so errors surface at the call that
# caused them (debugging aid). CUDA accepts only "0" or "1" for this variable,
# and it takes effect only if it is set before the CUDA context is created.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
if torch.cuda.is_initialized():
    warnings.warn("CUDA is already initialised; restart the kernel for "
                  "CUDA_LAUNCH_BLOCKING to take effect")
torch.cuda.empty_cache()   # release cached, unused allocator blocks (no-op before CUDA init)

In [None]:
# Snippet loop set-up using the torch-based GPU helpers from collision_cuda_clean.py

from collections import defaultdict, Counter
import numpy as np
import torch
import importlib
import pandas as pd

import collision_cuda_clean
importlib.reload(collision_cuda_clean)
import collision_utils_new
importlib.reload(collision_utils_new)

from collision_cuda_clean import (
    quick_unit_filter_gpu, scan_unit_lags_t, score_active_set_t,
    build_rolled_bank
)

import collision_utils
importlib.reload(collision_utils)

from collision_utils import (
    accumulate_unit_stats, accept_units,
    micro_align_units, build_channel_index,
)

from collision_utils_new import resolve_snippet_t

# ---------------- user knobs ----------------
AMP_THR   = 25.0
MASK_THR  = 5.0        # ignored; kept for API parity
MAX_LAG   = 20         # presumably the ±shift searched per unit (41 = 2*20+1 rolled copies) — confirm
BETA      = 0.5
SNIP_LEN  = 121        # samples per snippet
OVERLAP   = 80         # samples shared by consecutive snippets
STEP      = SNIP_LEN - OVERLAP   # snippet stride in samples (41)
DISCOVERY_THR = 200.0
MAX_SAMPLES   = 200_000          # samples of raw_data kept in raw_chunk
N_CHANNELS    = 512              # channels kept from raw_data
USE_CUDA  = True

# ---------- one-time GPU preparation ----------
# NOTE: unit_info and raw_data must already exist in the kernel; they are not
# created in this notebook section.
unit_ids  = sorted(unit_info.keys())
uid2row   = {uid: row for row, uid in enumerate(unit_ids)}
row2uid   = {row: uid for uid, row in uid2row.items()}

# (channels, samples) float32 copy of the start of the recording
raw_chunk = raw_data[:MAX_SAMPLES, :N_CHANNELS].astype(np.float32).T

# per-unit peak-to-peak amplitude on every channel of its EI.
# np.ptp works on NumPy 1.x and 2.x (the ndarray.ptp method was removed in 2.0).
p2p_all = {u: np.ptp(unit_info[u]['ei'], axis=1) for u in unit_info}

# row i is True on the channels selected for unit_ids[i]
sel_mask = np.zeros((len(unit_ids), N_CHANNELS), bool)
for row, u in enumerate(unit_ids):
    sel_mask[row, unit_info[u]['selected_channels']] = True

if USE_CUDA:
    raw_chunk_t = torch.from_numpy(raw_chunk).to('cuda', torch.float32)
    # NOTE(review): EIs keep their numpy dtype here while raw_chunk_t is
    # float32 — confirm the CUDA helpers accept that mix.
    EIs_t  = torch.from_numpy(
                np.stack([unit_info[u]['ei'] for u in unit_ids])
             ).to('cuda')
    rolled_t = build_rolled_bank(EIs_t)               # [U,41,C,T]
    # channel with the largest peak-to-peak amplitude, per unit row
    peak_ch_t = torch.tensor(
                [int(np.argmax(p2p_all[u])) for u in unit_ids],
                device='cuda', dtype=torch.int64)
    sel_mask_t = torch.from_numpy(sel_mask).to('cuda')
else:
    # keep the names defined so later cells fail with a clear error, not NameError
    raw_chunk_t = EIs_t = rolled_t = peak_ch_t = sel_mask_t = None

# state carried between iterations of the snippet loop
prev_tail_t = None
prev_tail = None
lag_history_t = []


In [None]:

# -------- main snippet loop --------
# Slides a SNIP_LEN-sample window over [LOOP_START, LOOP_STOP) in steps of STEP,
# so consecutive snippets share OVERLAP samples. Before a snippet is resolved,
# the template sum fitted to the previous snippet's overlap (the "tail") is
# subtracted from it, so already-explained spikes are not fitted twice.
#
# The subtraction is done in place on a working copy of the processed range
# (`residual` / `residual_t`). It persists for every later snippet that reads
# the same samples, exactly as when slicing raw_chunk directly. raw_chunk and
# raw_chunk_t stay untouched, so re-running this cell reproduces its output.
# For a full pass set LOOP_STOP = raw_chunk.shape[1]; the copy then needs as
# much memory as raw_chunk, on the GPU as well.
LOOP_START = 1845                      # first snippet start (sample index)
LOOP_STOP  = LOOP_START + SNIP_LEN     # exclusive end of the processed range

# loop state, reset here so this cell does not depend on earlier runs
prev_tail     = None                   # (C, OVERLAP) template tail of the last snippet, CPU
prev_tail_t   = None                   # same tail on the GPU
tail_len      = OVERLAP                # leading samples of a snippet covered by the tail
lag_history_t = []                     # per processed snippet: {uid: absolute sample}

residual   = raw_chunk[:, LOOP_START:LOOP_STOP].copy()
residual_t = raw_chunk_t[:, LOOP_START:LOOP_STOP].clone() if USE_CUDA else None

start = LOOP_START
while start + SNIP_LEN <= LOOP_STOP:
    print(f"\n--- snippet @ {start} ---")

    # evidence for this snippet only (the log is recreated every iteration)
    unit_log = defaultdict(lambda: {'deltas': [], 'lags': []})

    # views into the working copies: the subtraction below is seen by later snippets
    offset     = start - LOOP_START
    raw_snip   = residual[:, offset:offset + SNIP_LEN]
    raw_snip_t = residual_t[:, offset:offset + SNIP_LEN] if USE_CUDA else None

    if prev_tail is not None:
        raw_snip[:, :tail_len] -= prev_tail              # CPU
        if USE_CUDA:
            raw_snip_t[:, :tail_len] -= prev_tail_t      # CUDA

    # 1) quick filter: candidate units and their best lags
    if USE_CUDA:
        good_rows, best_lags = quick_unit_filter_gpu(
                raw_snip_t,          # [C,T] cuda
                EIs_t,               # [U,C,T]
                peak_ch_t,           # [U]
                sel_mask_t)          # bool[U,C], True on each unit's selected channels

        good_units = [row2uid[r] for r in good_rows]
        lag_dict   = {row2uid[r]: [int(l.item()) for l in best_lags[r]]
                      for r in good_rows}
    else:
        raise NotImplementedError("CPU path not shown")

    if not good_units:
        print("no units → skip")
        # nothing was fitted, so there is no tail to carry into the next snippet
        prev_tail_t = None
        prev_tail = None
        start += STEP
        continue

    print(f"Considering {len(good_units)} units")

    # 2) resolve the snippet (uses the CUDA scorer internally); lag_dict from
    #    the quick filter replaces a separate scan_unit_lags_t pass
    best_combo, per_unit_delta, _ = resolve_snippet_t(
        raw_snippet=raw_snip,
        good_units=good_units,
        channel_to_units=build_channel_index(good_units, unit_info),
        lag_dict=lag_dict,
        unit_info=unit_info,
        p2p_all=p2p_all,
        amp_thr=AMP_THR, beta=BETA,
        use_cuda=USE_CUDA,
        rolled_t=rolled_t,
        raw_snip_t=raw_snip_t,
        uid2row=uid2row
    )

    # 3) collect per-unit evidence for this snippet
    for uid, d in per_unit_delta.items():
        unit_log[uid]['deltas'].append(d)
        unit_log[uid]['lags'].append(best_combo['lags'].get(uid, np.nan))

    # 4) accept/reject decision + micro-alignment
    stats = accumulate_unit_stats(unit_log)
    accepted, rejected = accept_units(stats)
    final_lags, final_deltas = micro_align_units(accepted, stats, unit_info, raw_snip, p2p_all)

    print("combo lags:", best_combo["lags"])
    print("final lags:", final_lags)

    # ----------- build combined_final (still on CPU) -----------------
    # NOTE(review): np.roll wraps samples shifted past one edge of the window
    # back in at the other edge — harmless only if EIs are ~0 near their ends.
    # NOTE(review): lag is used as a raw roll amount here, while adjusted_lags
    # below subtracts MAX_LAG from it — confirm both use the same lag convention.
    combined_final = np.zeros_like(raw_snip, dtype=np.float32)
    for uid, lag in final_lags.items():
        combined_final += np.roll(unit_info[uid]['ei'], lag, axis=1)

    # ----------- store the tail for the NEXT iteration ---------------
    # The next snippet starts STEP samples later, so this snippet's columns
    # [STEP:] line up with the first OVERLAP columns of the next one.
    tail_start = STEP             # 41
    prev_tail  = combined_final[:, tail_start:].copy()      # (C,80)
    if USE_CUDA:
        prev_tail_t = torch.from_numpy(prev_tail).to('cuda')

    # absolute sample per accepted unit.
    # NOTE(review): assumes lag indexes shifts 0..2*MAX_LAG around the snippet
    # centre (SNIP_LEN // 2) — confirm against micro_align_units.
    adjusted_lags = {
        uid: start + (SNIP_LEN // 2) - MAX_LAG + lag
        for uid, lag in final_lags.items()
    }
    lag_history_t.append(adjusted_lags)

    start += STEP
