# In [None]:
import vcdvcd
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import dataclass

def condition_signal_tv(tv):
    """Return a copy of a VCD ``tv`` list with binary-string values converted to ints.

    Args:
        tv: sequence of ``(time, value)`` pairs. ``value`` is either already an
            integer or a binary digit string such as ``"0"``, ``"1"`` or ``"101"``.

    Returns:
        list of ``(time, int_value)`` tuples, in the original order. An empty
        input yields an empty list.

    Raises:
        ValueError: if a string value is not a valid base-2 literal (e.g. an
            ``"x"``/``"z"`` state, if the dump contains any).
    """
    def _to_int(value):
        # Convert each value individually so lists mixing str and int values
        # work; non-string values are passed through untouched.
        return int(value, 2) if isinstance(value, str) else value

    return [(time, _to_int(value)) for time, value in tv]

def expand_signal(xin, yin, length=None):
    """Turn sparse ``(time, value)`` transitions into explicit step-plot vertices.

    Each kept transition ``(x_i, y_i)`` is emitted together with the vertex
    ``(x_{i+1}, y_i)``, so the value is held until the next transition. Only
    transitions with ``x < length`` are kept. The last kept value is then held
    until ``length - 1``, which assumes no transition occurs after the last kept one.

    Args:
        xin: transition times, assumed ascending (as in a VCD value list).
        yin: value at each transition; must have the same length as ``xin``.
        length: the output ends at ``length - 1``. A falsy value
            (``None``/``0``) means ``max(xin) + 1``.

    Returns:
        ``(x, y)`` lists of equal length ``2 * k``, where ``k`` is the number of
        kept transitions.

    Raises:
        ValueError: if ``xin`` and ``yin`` differ in length, if ``xin`` is empty
            and no ``length`` is given, or if no transition lies before ``length``.
    """
    x_list = list(xin)
    x = np.asarray(x_list)
    y = np.asarray(list(yin))
    if x.shape != y.shape:
        raise ValueError(f"xin and yin differ in length ({x.size} vs {y.size})")
    if not length:
        length = max(x_list) + 1

    # Keep only transitions inside the frame.
    keep = x < length
    x, y = x[keep], y[keep]
    if x.size == 0:
        raise ValueError(f"no transitions before length={length}")

    # Close the frame: hold the final value up to length - 1.
    x = np.append(x, length - 1)
    y = np.append(y, y[-1])

    # Vertex pairs: (x_i, y_i) then (x_{i+1}, y_i) for every consecutive pair.
    x_out = np.column_stack((x[:-1], x[1:])).ravel()
    y_out = np.repeat(y[:-1], 2)
    return x_out.tolist(), y_out.tolist()

def condition_expand(signal, endtime: float):
    """Convert a VCD signal's transitions into step-plot vertex arrays.

    The signal's ``tv`` list of ``(time, value)`` pairs is converted to integer
    values with ``condition_signal_tv``. It is then expanded with ``expand_signal``,
    using ``endtime`` as its ``length``. ``signal`` itself is left unmodified.

    Args:
        signal: object with a ``tv`` attribute (a vcdvcd signal in this file).
        endtime: forwarded to ``expand_signal`` as ``length``.

    Returns:
        ``(x, y)`` numpy arrays of dtype float.
    """
    # Work on a local copy rather than overwriting signal.tv on the caller's object.
    conditioned = condition_signal_tv(signal.tv)
    times, values = zip(*conditioned)
    x, y = expand_signal(times, values, endtime)
    return np.array(x, dtype=float), np.array(y, dtype=float)
    
@dataclass
class Signal:
    """A named waveform prepared for drawing in a timing diagram."""

    title: str  # label drawn next to the trace
    x: np.ndarray  # step-plot vertex times, in raw VCD time units
    y: np.ndarray  # value held at each vertex in ``x``
    # float(vcd.timescale["timescale"]); presumably seconds per VCD time unit
    # (the plotting code multiplies by 1e9 to get ns) -- confirm against vcdvcd
    xscale: float
    endtime: float  # time (VCD units) the waveform was expanded up to

    @classmethod
    def load(cls, vcd: vcdvcd.VCDVCD, signal_id: str, title: str, endtime: float) -> "Signal":
        """Build a Signal from entry ``signal_id`` of a parsed VCD file.

        Args:
            vcd: parsed VCD file.
            signal_id: hierarchical signal name, e.g. ``"TOP.clk"``.
            title: display label for the trace.
            endtime: time to expand the waveform up to (see ``condition_expand``).

        Returns:
            a new ``Signal`` with float ``x``/``y`` arrays.
        """
        xscale = float(vcd.timescale["timescale"])
        x, y = condition_expand(vcd[signal_id], endtime)
        return cls(title=title, x=np.asarray(x), y=np.asarray(y), xscale=xscale, endtime=endtime)
    
def get_signals(fname="../hdl/q1/logs/vlt_dump.vcd", signal_tags=None):
    """Load the waveforms to plot from a VCD dump.

    Args:
        fname: path to the VCD file.
        signal_tags: iterable of ``(vcd_signal_id, display_title)`` pairs. If it is
            ``None``, the clk/rstn/seq/det_struct/det_behav signals under ``TOP``
            are used.

    Returns:
        list of ``Signal`` objects in the order of ``signal_tags``. All are
        expanded up to the largest ``endtime`` among the selected signals, so
        they share one time frame.

    Raises:
        ValueError: if ``signal_tags`` is empty (``max`` of an empty sequence).
    """
    if signal_tags is None:
        signal_tags = (
            ("TOP.clk", "clk"),
            ("TOP.rstn", "rstn"),
            ("TOP.seq", "seq"),
            ("TOP.det_struct", "det_struct"),
            ("TOP.det_behav", "det_behav"),
        )
    # Materialise once: the tags are iterated twice below.
    signal_tags = tuple(signal_tags)

    vcd = vcdvcd.VCDVCD(fname)
    endtime = max(vcd[signal_id].endtime for signal_id, _ in signal_tags)

    signals: list[Signal] = [
        Signal.load(vcd=vcd, signal_id=signal_id, title=title, endtime=endtime)
        for signal_id, title in signal_tags
    ]
    return signals

def plot_timing_diagram(signals, title="Structural and Behavioral Sequence Detector"):
    """Normalize signals and stack them as step traces on one shared time axis.

    Args:
        signals: iterable of ``Signal``-like objects (``title``, ``x``, ``y``,
            ``xscale``, ``endtime``). Trace ``i`` is drawn with its baseline at
            ``-i * 2.2``.
        title: axes title.

    Returns:
        ``(fig, [ax])``: the figure and a one-element list holding its axes.
    """
    signals = list(signals)  # needs len() below; also accepts generators
    fig, ax = plt.subplots(figsize=(12, 6))
    spacing = 2.2
    seconds_to_ns = 1e9

    for i, signal in enumerate(signals):
        y_offset = -i * spacing
        # x * xscale is treated as seconds; the axis is labelled in ns.
        x_ns = signal.x * signal.xscale * seconds_to_ns
        end_ns = signal.endtime * signal.xscale * seconds_to_ns

        # Normalize to 0..1 above the baseline; constant signals sit on it.
        y_min = np.min(signal.y)
        y_range = np.max(signal.y) - y_min
        y = np.zeros_like(signal.y) if y_range == 0 else (signal.y - y_min) / y_range

        ax.step(x_ns, y_offset + y, where="post", color="black", lw=1.5)
        # Dashed baseline, in the same ns units as the trace.
        ax.hlines(y_offset, 0, end_ns, color="0.8", lw=0.5, linestyle="--")
        # The label's x is in axes coordinates (just left of the spine), so its
        # placement does not depend on the x limits.
        ax.text(-0.01, y_offset + 0.5, signal.title, transform=ax.get_yaxis_transform(),
                ha="right", va="center", fontsize=12, weight="bold")

    # Formatting
    ax.set_ylim(-spacing * (len(signals) - 1) - spacing / 2, spacing)
    ax.set_xlabel("Time (ns)", fontsize=12)
    ax.set_yticks([])
    ax.set_title(title, fontsize=14, weight="bold")
    ax.grid(axis="x", which="major", linestyle=":", color="0.7")

    return fig, [ax]
    
# Load the default VCD dump, draw the stacked timing diagram,
# zoom the time axis, then export the figure for the report and display it.
signals = get_signals()
fig, axes = plot_timing_diagram(signals[:])
axes[0].set_xlim(left=0, right=0.11)
fig.savefig("../report/assets/q1_wave.pdf", bbox_inches="tight")
plt.show()
