In [1]:
##%%
# ---------------------- SETUP AND IMPORTS ----------------------------------------------------------------------------
import torch
from sympy.codegen.ast import continue_
from torch import nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from ipywidgets import FileUpload, Dropdown, SelectMultiple, Button, VBox, HBox, Output, IntSlider, Checkbox
from IPython.display import display, clear_output

# import the dataset
from model.vesc_dataset import VESCTimeSeriesDataset, VESCDatasetConfig, CONFIDENCE_COLS, FEATURE_COLS
from preprocessing.prod_preprocessing import prod_load_log, prod_sample_rate_normalization
from preprocessing.training_preprocessing import infer_log_date_from_filename

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model definition
class CNN(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(c_in, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        )
        self.head = nn.Linear(64, c_out)

    # x: (B, T, C)
    def forward(self, x):
        # translate to (B, C, T)
        x = x.permute(0, 2, 1)
        h = self.net(x).squeeze(-1)
        # (B, C_out) logits
        return self.head(h)

def load_normalization_and_model():
    # load normalization stats from npz file from model training
    NORM_STATS_FILE = np.load("norm_stats.npz", allow_pickle=True)
    NORM_MEAN = torch.from_numpy(NORM_STATS_FILE["mean"]).to(DEVICE)
    NORM_STD_DEV = torch.from_numpy(NORM_STATS_FILE["std"]).to(DEVICE)
    FEATURE_COLS = list(NORM_STATS_FILE["feature_cols"])

    # load trained model
    c_in = len(FEATURE_COLS)
    C_out = len(CONFIDENCE_COLS)
    model = CNN(c_in, C_out).to(DEVICE)
    state = torch.load("best_model.pt", map_location=DEVICE)
    model.load_state_dict(state)
    model.eval()
    return model, NORM_MEAN, NORM_STD_DEV, FEATURE_COLS

MODEL, NORM_MEAN, NORM_STD_DEV, FEATURE_COLS = load_normalization_and_model()

# normalize the batches to comparable value scale
def normalize_batch(xb: torch.Tensor) -> torch.Tensor:
    return (xb - NORM_MEAN) / NORM_STD_DEV

  state = torch.load("best_model.pt", map_location=DEVICE)


In [2]:
##%%
# ---------------------- LOG FILE PREPROCESSING -----------------------------------------------------------------------

from preprocessing import prod_preprocessing

# Input: path to a raw CSV uploaded by the user
# Output: path to a single processed CSV (with your columns)
def preprocess_user_log(raw_csv_path: str) -> str:
    """
    Convert a raw VESC Tool ride log into a CSV formatted for the machine learning model.
    Return the path to the processed CSV.
    """
    raw_path = Path(raw_csv_path)
    # infer log date
    ride_date = infer_log_date_from_filename(raw_csv_path)

    # load the log
    df = prod_load_log(raw_path, ride_date)
    df_resampled =  prod_sample_rate_normalization(df)

    out_dir = Path("tmp_processed")
    out_dir.mkdir(exist_ok=True)
    out_path = out_dir / f"{raw_path.stem}_processed.csv"
    df_resampled.to_csv(out_path, index=False)
    return str(out_path)

In [None]:
##%%
# ---------------------- MODEL INFERENCE  --------------------------------------------------------
from matplotlib.ticker import FuncFormatter, MultipleLocator

def build_dataset_from_csv(csv_path: str) -> VESCTimeSeriesDataset:
    cfg = VESCDatasetConfig(
        files=[csv_path],
        feature_cols=None,
        conf_cols=None,
        sampling_hz=10.0,
        window_ms=3000,
        stride_ms=500,
        min_valid_ratio=0.7,
    )
    ds = VESCTimeSeriesDataset(cfg)
    # store the source path
    ds._dfs[0].attrs["_source_path"] = str(csv_path)
    return ds

def run_inference_on_dataset(ds: VESCTimeSeriesDataset):
    # collect windows for file 0, sorted by start index
    idxs = [(k, s, e) for k,(fi,s,e) in enumerate(ds._index) if fi == 0]
    idxs.sort(key=lambda t: t[1])

    df = ds._dfs[0]
    tcol = ds.cfg.time_col if ds.cfg.time_col in df.columns else None

    preds = []
    times = []
    with torch.no_grad():
        for k, s, e in idxs:
            X, _ = ds[k]                              # (T,Cin), (Cout,)
            xb = X.unsqueeze(0).to(DEVICE)            # (1,T,Cin)
            pb = torch.sigmoid(MODEL(normalize_batch(xb))).cpu().numpy()[0]
            preds.append(pb)
            t_mid = float(df.loc[s:e-1, tcol].median()) if tcol else float(s)
            times.append(t_mid)

    preds = np.vstack(preds)     # (N, C_out)
    times = np.asarray(times)    # (N,)
    # normalize time to seconds starting at 0
    t0 = times.min()
    tsec = (times - t0) / 1000.0
    # bar width (seconds)
    win_sec = ds.window_steps / ds.cfg.sampling_hz
    return tsec, win_sec, preds

In [4]:
##%%
#----------------------- TIMELINE/GRAPH PLOTTING ------------------------------------------------------------------------
def _fmt_mmss(x, pos=None):
    m = int(x // 60)
    s = int(x % 60)
    return f"{m}:{s:02d}"

def apply_behavior_conflict_suppression(probs_at_time: np.ndarray, behavior_class_names, conflict_groups):
    """
    suppress mutually exclusive behaviors at each time step by keeping only the behavior with the highest
    confidence within a conflict group at each time step. Zero out the other conflicting behaviors with lower
    confidence.

    :param probs_at_time: Y behavior confidence at time step
    :param behavior_class_names: CONFIDENCE_COLS, the behavior classification names
    :param conflict_groups: explicitly defined groups of conflicting behaviors, i.e., left and right turn, brake
    and accelerate, that cannot occur simultaneously.

    :return: a copy of the df with non-winning exclusive behaviors suppressed at each time step.
    """
    # don't apply behavior suppression to behaviors that don't have exclusivity
    if not conflict_groups:
        return probs_at_time

    name_to_idx = {behav_class:i for i, behav_class in enumerate(behavior_class_names)}
    suppressed_behaviors = probs_at_time.copy()
    for group in conflict_groups:
        # convert behavior class names to column indices
        group_col_idx = [name_to_idx[behav_class] for behav_class in group if behav_class in name_to_idx]
        # if only one class in this group, then behavior suppression is not required
        if len(group_col_idx) <= 1:
            continue

        # extract just the behaviors that belong to an exclusivity group
        group_scores = suppressed_behaviors[group_col_idx]

        # at each time step, find the index of the highest behavior score within the conflicting behaviors group
        # winners_per_row is length num_steps with values in [0, group_size-1]
        winners_per_row = np.argmax(group_scores, axis=1)

        # build a boolean mask that is the same shape as the group_scores
        # identifies highest confidence behavior within an exclusivity group and the behaviors to suppress
        # True == winner at this row/col, False means a behavior to be suppressed
        winner_mask = np.zeros_like(group_scores, dtype=bool)
        winner_mask[np.arange(group_scores.shape[0]), winners_per_row] = True

        # zero out the non-winner (False) scores
        group_scores[~winner_mask] = 0.0

        # write back into the behavior matrix
        suppressed_behaviors[:, group_col_idx] = group_scores

    return suppressed_behaviors


def plot_timeline_bars(
    tsec, win_sec, preds, targets, CONFIDENCE_COLS, selected,
    alpha=0.35, stack=False,
    xlim=None, x_tick=5, ylim_max=1.0, y_tick=0.1, decimate=2.0, stride_sec=None
):
    """
    - x_tick: major x-axis label spacing in seconds
    - xlim: limit on x-axis maximum (log duration)
    - y_tick: major y-axis label spacing in confidence magnitude (0 - 1)
    - ylim_max: top of y-axis
    - decimate: keep every Nth bar (reduces sample rate visualization to make graph decipherable)
    - stride_sec: determines visual bar width on time-series graph
    - selected: the behaviors selected by user for plotting
    - alpha: transparency of bars
    """
    # map selected class names to indices
    name_to_idx = {c:i for i,c in enumerate(CONFIDENCE_COLS)}
    sel_idx = [name_to_idx[c] for c in selected if c in name_to_idx]
    if not sel_idx:
        print("No classes selected.")
        return

    # decimation to reduce sample density in plot
    if decimate and decimate > 2.0:
        mask = np.zeros_like(tsec, dtype=bool)
        mask[::decimate] = True
        t = tsec[mask]
        P = preds[mask]
        T = targets[mask] if targets is not None else None
    else:
        t, P, T = tsec, preds, targets

    # visualization bar width calculated by stride_sec
    if stride_sec is None:
        # median gap between centers
        diffs = np.diff(t)
        stride_sec = float(np.median(diffs)) if len(diffs) else win_sec
    bar_w = max(1e-3, stride_sec * 0.9)

    plt.figure(figsize=(12, 4 + 0.4*len(sel_idx)))
    bases = np.zeros_like(t)

    for ci in sel_idx:
        y = P[:, ci]
        if stack:
            bottom = bases.copy()
            bases += y
            plt.bar(t, y, width=bar_w, bottom=bottom, alpha=alpha,
                    label=CONFIDENCE_COLS[ci], align='center')
        else:
            plt.bar(t, y, width=bar_w, alpha=alpha,
                    label=CONFIDENCE_COLS[ci], align='center')

    ax = plt.gca()
    ax.xaxis.set_major_locator(MultipleLocator(x_tick))      # ← 5s ticks
    ax.xaxis.set_major_formatter(FuncFormatter(_fmt_mmss))    # ← mm:ss
    ax.yaxis.set_major_locator(MultipleLocator(y_tick))
    if xlim: plt.xlim(*xlim)
    plt.ylim(0, float(ylim_max))                              # ← cap height
    plt.xlabel("time (mm:ss)")
    plt.ylabel("confidence")
    plt.title("Predicted behavior confidence over time")
    plt.legend(ncol=2, fontsize=9)
    plt.tight_layout()
    plt.show()

NameError: name 'np' is not defined

In [4]:
##%%
# ---------------------- USER LOG UPLOAD/UI -----------------------------------------------------------------------------
uploader = FileUpload(accept='.csv', multiple=False)
run_btn = Button(description="Run Inference", button_style='success')
classes_picker = SelectMultiple(
    options=CONFIDENCE_COLS,
    value=("cf_accel","cf_brake","cf_turn_left","cf_turn_right"),
    description='Plot classes',
    rows=8
)
alpha_slider = IntSlider(description='Alpha (%)', min=10, max=90, step=5, value=40)
stack_cb = Checkbox(description='Stack bars', value=False)
out = Output()

CONFLICT_GROUPS = [
    ["cf_turn_left", "cf_turn_right"],
    ["cf_turn_left", "cf_carve_left"],
    ["cf_turn_right", "cf_carve_right"],
    ["cf_carve_left", "cf_carve_right"],
    ["cf_accel", "cf_brake"],
    ["cf_ascent", "cf_descent"],
    ["cf_forward", "cf_reverse"],
]

def handle_run(_):
    with out:
        clear_output()
        try:
            val = uploader.value
            if not val:
                print("Upload a CSV first.")
                return

            # compatibility for ipywidgets versions 7 and 8, dict vs tuple
            if isinstance(val, dict):
                item = next(iter(val.values()))
                raw_bytes = item["content"]
                raw_name = item.get("metadata", {}).get("name", item.get("name", "uploaded.csv"))
            else:  # tuple/list in 8.x
                item = val[0]
                raw_bytes = item["content"]
                raw_name = item.get("name", "uploaded.csv")

            # Save uploaded file
            up_dir = Path("uploads")
            up_dir.mkdir(exist_ok=True)
            raw_file = up_dir / raw_name
            raw_file.write_bytes(raw_bytes)
            print("Uploaded:", raw_file)

            # Preprocess -> processed CSV
            proc_csv = preprocess_user_log(str(raw_file))
            print("Processed CSV:", proc_csv)

            # Build dataset and run inference
            ds = build_dataset_from_csv(proc_csv)

            tsec, win_sec, preds = run_inference_on_dataset(ds)
            print("Windows:", len(ds), "| window_ms:", ds.cfg.window_ms, "| stride_ms:", ds.cfg.stride_ms)

            # Plot
            alpha = alpha_slider.value / 100.0
            selected = list(classes_picker.value)
            conf_suppressed_preds = apply_behavior_conflict_suppression(preds, CONFIDENCE_COLS, CONFLICT_GROUPS)
            plot_timeline_bars(tsec, win_sec, conf_suppressed_preds, None, CONFIDENCE_COLS, selected, alpha=alpha, stack=stack_cb.value,
                           x_tick=5, ylim_max=1.0, decimate=2, stride_sec=1)

        except Exception as e:
            import traceback
            print("ERROR:", e)
            traceback.print_exc()

run_btn.on_click(handle_run)
display(VBox([uploader, HBox([classes_picker, VBox([alpha_slider, stack_cb, run_btn])]), out]))
print("Ready: upload a CSV, pick classes, click Run Inference.")

VBox(children=(FileUpload(value=(), accept='.csv', description='Upload'), HBox(children=(SelectMultiple(descri…

Ready: upload a CSV, pick classes, click Run Inference.
