# In [18]:
from torchinfo import summary
import pandas as pd

# In [None]:
def is_strided_conv(layer):
    """Return True if *layer* is a non-transposed convolution with a stride > 1.

    ``layer`` may be the module itself (anything exposing ``kernel_size`` and
    ``stride``, e.g. ``nn.Conv1d``) or a wrapper that keeps the module in a
    ``.module`` attribute, as torchinfo's ``LayerInfo`` does.  Attributes the
    wrapper lacks are looked up on the wrapped module.  ``kernel_size`` and
    ``stride`` may each be a scalar or a list/tuple.

    Layers whose ``transposed`` attribute is truthy (e.g.
    ``nn.ConvTranspose1d``) return False, so up-sampling decoder layers are
    not mistaken for down-sampling encoder layers.
    """
    wrapped = getattr(layer, "module", None)

    def lookup(name):
        # Prefer the layer's own attribute; fall back to the wrapped module
        # (a torchinfo LayerInfo exposes kernel_size but not stride).
        value = getattr(layer, name, None)
        if value is None and wrapped is not None:
            value = getattr(wrapped, name, None)
        return value

    ks = lookup("kernel_size")
    st = lookup("stride")
    if not (ks and st):
        return False
    if lookup("transposed"):
        return False
    strides = st if isinstance(st, (list, tuple)) else (st,)
    return any(s > 1 for s in strides)

def collapse_summary_layers(layers):
    """Collapse a flat, ordered list of summary rows into one row per stage.

    Stages are recognised as follows:

    * a ``ConvTranspose1d`` opens a new "Decoder Block N"; the first ReLU
      after it closes the block;
    * any other strided convolution (see ``is_strided_conv``, applied to the
      row and to its wrapped ``module``) opens a new "Encoder Block N"; a GLU
      closes the block;
    * the first row whose class name contains "LSTM", seen after at least one
      encoder block and before any decoder block, becomes a standalone
      "Bottleneck LSTM" stage (emitted only once).

    Opening a stage closes any stage still open.  Rows encountered while no
    stage is open are ignored.

    Each row must expose ``input_size``; ``output_size`` and ``num_params``
    are used when present (torchinfo ``LayerInfo`` rows have all three).  The
    class name is taken from ``layer_class.__name__`` if available, otherwise
    from ``class_name`` (torchinfo), otherwise from the row's own type.

    Args:
        layers: iterable of summary rows, in the order the summary lists them.

    Returns:
        pandas.DataFrame with columns "Stage", "Layers", "Input Shape",
        "Output Shape" and "Params".  A stage's output shape is that of the
        last row assigned to it.  The frame keeps these columns even when no
        stage is recognised.
    """
    columns = ["Stage", "Layers", "Input Shape", "Output Shape", "Params"]

    def class_name_of(row):
        # A torchinfo LayerInfo's own type name is always "LayerInfo"; the
        # wrapped module's class name lives in ``class_name``.
        if hasattr(row, "layer_class"):
            return row.layer_class.__name__
        return getattr(row, "class_name", None) or type(row).__name__

    def last_conv_param(row, name, default):
        # Last element of a conv hyper-parameter (scalar or sequence), read
        # from the row or, for wrappers such as LayerInfo, its ``module``.
        value = getattr(row, name, None)
        if value is None:
            value = getattr(getattr(row, "module", None), name, None)
        if value is None:
            return default
        if isinstance(value, (list, tuple)):
            return value[-1] if value else default
        return value

    def open_stage(stage, row):
        # output_size starts at the input size and is updated per added row,
        # so a stage closed early still has an output shape.
        return {
            "Stage": stage,
            "layers": [],
            "input_size": row.input_size,
            "output_size": getattr(row, "output_size", row.input_size),
            "params": 0,
        }

    groups = []
    enc_id = 0
    dec_id = 0
    pending = None
    last_was_convT = False
    lstm_seen = False

    for L in layers:
        cls = class_name_of(L)

        # 1) Decoder block start.  Checked before the strided-conv test:
        #    a strided ConvTranspose1d also has stride > 1 and would
        #    otherwise be labelled as an encoder block.
        if cls == "ConvTranspose1d":
            if pending:
                groups.append(pending)
            dec_id += 1
            pending = open_stage(f"Decoder Block {dec_id}", L)
            last_was_convT = True
            # fall through to append this layer

        # 2) Encoder block start (the wrapped module is checked as well,
        #    because a LayerInfo row itself has no ``stride``).
        elif is_strided_conv(L) or is_strided_conv(getattr(L, "module", None)):
            if pending:
                groups.append(pending)
            enc_id += 1
            pending = open_stage(f"Encoder Block {enc_id}", L)
            last_was_convT = False
            # fall through to append this conv

        # 3) Bottleneck LSTM: a standalone stage, emitted once.
        elif "LSTM" in cls and not lstm_seen and enc_id and dec_id == 0:
            if pending:
                groups.append(pending)
                pending = None
            groups.append({
                "Stage": "Bottleneck LSTM",
                "layers": ["LSTM"],
                "input_size": L.input_size,
                "output_size": getattr(L, "output_size", L.input_size),
                "params": getattr(L, "num_params", 0),
            })
            lstm_seen = True
            continue

        if pending is None:
            continue  # row outside any recognised stage

        # 4) Accumulate this row into the open stage.
        pending["params"] += getattr(L, "num_params", 0)
        pending["output_size"] = getattr(L, "output_size", pending["output_size"])

        if cls == "Conv1d":
            k = last_conv_param(L, "kernel_size", "?")
            s = last_conv_param(L, "stride", "")
            pending["layers"].append(f"Conv1d(k={k},s={s})")
        elif cls == "ConvTranspose1d":
            k = last_conv_param(L, "kernel_size", "?")
            s = last_conv_param(L, "stride", "")
            pending["layers"].append(f"ConvT1d(k={k},s={s})")
        elif "ReLU" in cls:
            pending["layers"].append("ReLU")
            if last_was_convT:
                # a ReLU after the transposed conv ends the decoder block
                groups.append(pending)
                pending = None
                last_was_convT = False
        elif "GLU" in cls:
            pending["layers"].append("GLU")
            # a GLU ends the current block
            groups.append(pending)
            pending = None

    # close a stage left open at the end of the list
    if pending:
        groups.append(pending)

    records = [
        {
            "Stage":        g["Stage"],
            "Layers":       " → ".join(g["layers"]),
            "Input Shape":  g["input_size"],
            "Output Shape": g["output_size"],
            "Params":       g["params"],
        }
        for g in groups
    ]
    return pd.DataFrame(records, columns=columns)

# In [None]:
from model_def import CausalDemucsSplit

def _module_path(row, module_paths):
    """Return the dotted module path of a summary row ("" if unknown)."""
    # torchinfo rows keep the module object in ``.module`` but not its path.
    module = getattr(row, "module", None)
    if module is not None and id(module) in module_paths:
        return module_paths[id(module)]
    # Fall back to a ``name`` attribute for summary objects that provide one.
    return getattr(row, "name", "")


def _block_for(path, block_prefixes):
    """Return the label of the first block whose prefix is *path* or a dotted
    ancestor of it, or None if no block matches."""
    for label, prefix in block_prefixes:
        # Match on "." boundaries so "encoder.layers.1" does not claim
        # "encoder.layers.10".
        if path == prefix or path.startswith(prefix + "."):
            return label
    return None


def _ancestor_paths(path):
    """Yield every proper dotted ancestor of *path* ("a.b.c" -> "a", "a.b")."""
    parts = path.split(".")
    for i in range(1, len(parts)):
        yield ".".join(parts[:i])


def _group_rows_by_block(layers, block_prefixes, module_paths):
    """Group summary rows into ``[(label, [(path, row), ...]), ...]``.

    Consecutive rows belonging to the same block share one group; rows that
    match no block are skipped rather than attached to a neighbouring block.
    """
    grouped = []
    for row in layers:
        path = _module_path(row, module_paths)
        label = _block_for(path, block_prefixes)
        if label is None:
            continue
        if grouped and grouped[-1][0] == label:
            grouped[-1][1].append((path, row))
        else:
            grouped.append((label, [(path, row)]))
    return grouped


def _block_table_row(label, members):
    """Build one output-table row for a block from its ``(path, row)`` members.

    Parameters are summed only over top-most members (no ancestor module in
    the same block), because a torchinfo container row's ``num_params``
    already includes its children.  Operations are listed only for innermost
    members (no descendant in the block), so container modules are omitted.
    """
    paths = {path for path, _ in members}
    params = 0
    ops = []
    for path, row in members:
        is_top_most = not any(a in paths for a in _ancestor_paths(path))
        if is_top_most and not getattr(row, "is_recursive", False):
            params += getattr(row, "num_params", 0)

        has_child = any(p.startswith(path + ".") for p in paths)
        if has_child:
            continue
        # get_layer_type_and_params is not defined in this cell; it must be
        # provided elsewhere in the notebook.
        layer_type, k, s = get_layer_type_and_params(row)
        # only convs / deconvs show kernel size and stride
        if layer_type in ("Conv1d", "ConvTranspose1d"):
            ops.append(f"{layer_type}(k={k},s={s})")
        else:
            ops.append(layer_type)

    return {
        "Block":         label,
        "Operation":     "→".join(ops),
        "Output Shape":  members[-1][1].output_size,
        "Params":        params,
    }


def main():
    """Summarise ``CausalDemucsSplit`` with torchinfo and print a per-block table.

    1. Runs ``torchinfo.summary`` on a ``(1, 1, 32085)`` input.
    2. Resolves each summary row to its dotted module path (e.g.
       ``"encoder.layers.0.conv1"``) via ``model.named_modules()`` and assigns
       it to the first entry of ``block_prefixes`` whose prefix is that path
       or a dotted ancestor of it.  Consecutive rows of the same block form
       one table row; rows outside every listed block are skipped.
    3. Prints the table as Markdown and writes ``model_summary_grouped.csv``.
    """
    # 1) Instantiate and summarize
    model = CausalDemucsSplit().eval()
    summ = summary(
        model,
        input_size=(1, 1, 32085),
        verbose=0,
        col_names=["input_size", "output_size", "num_params"],
        depth=None
    )

    # 2) Grab the summary list
    layers = getattr(summ, "summary_list", None) or getattr(summ, "_summary_list", [])

    # torchinfo rows carry the module object but not its dotted path, so
    # build an id -> path lookup from the model itself.
    module_paths = {id(m): path for path, m in model.named_modules()}

    # 3) Define the order in which blocks appear, with a prefix to match against each layer's path
    block_prefixes = [
        ("Input Upsample", "sinc_interpolation"),          # your upsampling conv
        ("Encoder Block 1", "encoder.layers.0"),
        ("Encoder Block 2", "encoder.layers.1"),
        ("Encoder Block 3", "encoder.layers.2"),
        ("Encoder Block 4", "encoder.layers.3"),
        ("Encoder Block 5", "encoder.layers.4"),
        ("Bottleneck LSTM",      "lstm"),
        ("Decoder Block 5", "decoder.layers.4"),
        ("Decoder Block 4", "decoder.layers.3"),
        ("Decoder Block 3", "decoder.layers.2"),
        ("Decoder Block 2", "decoder.layers.1"),
        ("Decoder Block 1", "decoder.layers.0"),
        ("Output Conv", "output_conv")                    # your final deconv or conv
    ]

    # 4) Group consecutive rows that fall under the same block
    grouped = _group_rows_by_block(layers, block_prefixes, module_paths)

    # 5) Build table rows
    rows = [_block_table_row(label, members) for label, members in grouped]
    df = pd.DataFrame(rows, columns=["Block", "Operation", "Output Shape", "Params"])

    # 6a) Print Markdown table
    print(df.to_markdown(index=False))

    # 6b) Save CSV if you want
    df.to_csv("model_summary_grouped.csv", index=False)
    print("\nSaved → model_summary_grouped.csv")