#RESULT

In [None]:
! pip install -q torchview
! pip install -q -U graphviz

In [None]:
import json, random, string, torch, torchview           as tv
import torchview.computation_graph as _cg
import torchview.torchview         as _tv
from torch import nn
from IPython.display import HTML, display
from torchview.computation_node import TensorNode, ModuleNode, FunctionNode

In [None]:
#@title Model
class AttentionGate(nn.Module):
    """Additive attention gate for 1-D feature maps.

    Projects the skip features ``x`` and the gating signal ``g`` into a shared
    ``inter_channels`` space with 1x1 Conv1d + BatchNorm, adds them, applies
    ReLU, then collapses to one channel and squashes with a sigmoid. The
    resulting per-position weight rescales ``x``.

    Args:
        in_channels: channels of the skip input ``x``.
        gating_channels: channels of the gating signal ``g``.
        inter_channels: width of the intermediate projection.
    """

    def __init__(self, in_channels: int, gating_channels: int, inter_channels: int):
        super().__init__()
        # 1x1 projections of each input into the shared intermediate space.
        self.Wx = nn.Sequential(
            nn.Conv1d(in_channels, inter_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(inter_channels),
        )
        self.Wg = nn.Sequential(
            nn.Conv1d(gating_channels, inter_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(inter_channels),
        )
        # Collapse to a single channel and map into (0, 1).
        self.psi = nn.Sequential(
            nn.Conv1d(inter_channels, 1, kernel_size=1, bias=False),
            nn.BatchNorm1d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)
        # Most recent attention weights, (batch, 1, seq_len); stored on the
        # module, presumably for later inspection/visualisation.
        self.attention_map = None

    def forward(self, x, g):
        """Return ``x`` scaled by attention weights computed from ``x`` and ``g``.

        Args:
            x: skip features, (batch_size, in_channels, seq_len).
            g: gating signal, (batch_size, gating_channels, seq_len_g). When
               seq_len_g != seq_len, the projected gate is resized to seq_len
               with nearest-neighbour interpolation.

        Returns:
            Tensor of shape (batch_size, in_channels, seq_len).
        """
        x1 = self.Wx(x)  # (batch_size, inter_channels, seq_len)
        g1 = self.Wg(g)  # (batch_size, inter_channels, seq_len_g)

        if g1.size(2) != x1.size(2):
            # `F` is not imported in this file; reach torch.nn.functional
            # through the existing `nn` import instead.
            g1 = nn.functional.interpolate(
                g1, size=x1.size(2), mode="nearest"
            )  # (batch_size, inter_channels, seq_len)

        g_x = self.relu(g1 + x1)  # (batch_size, inter_channels, seq_len)
        psi = self.psi(g_x)  # (batch_size, 1, seq_len)
        self.attention_map = psi
        return x * psi  # (batch_size, in_channels, seq_len)


class ResidualConvBlock(nn.Module):
    """Two dilated 5-tap Conv1d + BatchNorm layers wrapped in a residual shortcut.

    Path: conv1 -> bn1 -> ReLU -> conv2 -> bn2 -> dropout, then the shortcut
    is added and a final ReLU applied. The shortcut is a 1x1 Conv1d when the
    channel count changes and the identity otherwise. Padding of
    ``2 * dilation`` keeps the sequence length unchanged.

    Shapes: (batch, in_channels, seq_len) -> (batch, out_channels, seq_len).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dilation: int = 1,
        dropout: float = 0.2,
    ):
        super().__init__()
        taps = 5
        same_pad = dilation * (taps - 1) // 2  # equals 2 * dilation: length-preserving
        # Submodule creation order fixes parameter-init RNG use and the
        # state_dict key order, so it is kept stable here.
        self.conv1 = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=taps, padding=same_pad, dilation=dilation,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            out_channels, out_channels,
            kernel_size=taps, padding=same_pad, dilation=dilation,
        )
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.residual = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv1d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        """(batch, in_channels, seq_len) -> (batch, out_channels, seq_len)."""
        shortcut = self.residual(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.dropout(self.bn2(self.conv2(h)))
        return self.relu(h + shortcut)


class MultiScaleConvBlock(nn.Module):
    """Parallel 3-, 5- and 7-tap Conv1d branches, concatenated, with a residual shortcut.

    The 3- and 5-tap branches each get ``out_channels // 3`` channels and the
    7-tap branch gets the remainder, so the concatenation has exactly
    ``out_channels`` channels. Padding keeps the sequence length unchanged.
    The concatenation goes through BatchNorm and dropout, is added to the
    shortcut (1x1 Conv1d if channels change, identity otherwise), then ReLU.

    Shapes: (batch, in_channels, seq_len) -> (batch, out_channels, seq_len).
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.2):
        super().__init__()
        third = out_channels // 3
        remainder = out_channels - 2 * third
        # Creation order kept stable (parameter init / state_dict order).
        self.conv3 = nn.Conv1d(in_channels, third, kernel_size=3, padding=1)
        self.conv5 = nn.Conv1d(in_channels, third, kernel_size=5, padding=2)
        self.conv7 = nn.Conv1d(in_channels, remainder, kernel_size=7, padding=3)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.residual = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv1d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        """(batch, in_channels, seq_len) -> (batch, out_channels, seq_len)."""
        shortcut = self.residual(x)
        branches = [conv(x) for conv in (self.conv3, self.conv5, self.conv7)]
        h = self.dropout(self.bn(torch.cat(branches, dim=1)))
        return self.relu(h + shortcut)


class UNet1D(nn.Module):
    """1-D U-Net with residual encoder blocks, attention-gated skip
    connections and multi-scale decoder blocks.

    Three encoder stages each end in a stride-2 Conv1d that halves the length
    (rounding up). A bottleneck follows, then three decoder stages that each
    double the length with a ConvTranspose1d. Each decoder stage gates the
    matching encoder output with an AttentionGate, concatenates it with the
    upsampled features and fuses them with a MultiScaleConvBlock.

    Args:
        in_channels: channels of the input signal.
        out_channels: channels produced by the final 1x1 Conv1d.
        base_filters: width of the first stage; doubles per level (x8 at the
            bottleneck).
        dropout: dropout probability passed to every conv block.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        base_filters: int = 64,
        dropout: float = 0.2,
    ):
        super().__init__()

        # Encoder
        self.enc1 = ResidualConvBlock(
            in_channels, base_filters, dilation=1, dropout=dropout
        )
        self.pool1 = nn.Conv1d(
            base_filters, base_filters, kernel_size=3, stride=2, padding=1
        )

        self.enc2 = ResidualConvBlock(
            base_filters, base_filters * 2, dilation=2, dropout=dropout
        )
        self.pool2 = nn.Conv1d(
            base_filters * 2, base_filters * 2, kernel_size=3, stride=2, padding=1
        )

        self.enc3 = ResidualConvBlock(
            base_filters * 2, base_filters * 4, dilation=4, dropout=dropout
        )
        self.pool3 = nn.Conv1d(
            base_filters * 4, base_filters * 4, kernel_size=3, stride=2, padding=1
        )

        self.bottleneck = ResidualConvBlock(
            base_filters * 4, base_filters * 8, dilation=1, dropout=dropout
        )

        # Decoder
        self.up3 = nn.ConvTranspose1d(
            base_filters * 8, base_filters * 4, kernel_size=4, stride=2, padding=1
        )
        self.att3 = AttentionGate(base_filters * 4, base_filters * 4, base_filters * 4)
        self.dec3 = MultiScaleConvBlock(
            base_filters * 8, base_filters * 4, dropout=dropout
        )

        self.up2 = nn.ConvTranspose1d(
            base_filters * 4, base_filters * 2, kernel_size=4, stride=2, padding=1
        )
        self.att2 = AttentionGate(base_filters * 2, base_filters * 2, base_filters * 2)
        self.dec2 = MultiScaleConvBlock(
            base_filters * 4, base_filters * 2, dropout=dropout
        )

        self.up1 = nn.ConvTranspose1d(
            base_filters * 2, base_filters, kernel_size=4, stride=2, padding=1
        )
        self.att1 = AttentionGate(base_filters, base_filters, base_filters)
        self.dec1 = MultiScaleConvBlock(base_filters * 2, base_filters, dropout=dropout)

        self.out = nn.Conv1d(base_filters, out_channels, kernel_size=1)

    def _match_length(self, source, target):
        """Crop or right-pad ``source`` along dim 2 to ``target.size(2)``.

        A stride-2 downsample followed by a x2 upsample turns an odd length L
        into L + 1, so the upsampled tensor must be aligned to its skip
        partner. Padding (if ever needed) adds zeros on the right.
        """
        src_len, tgt_len = source.size(2), target.size(2)
        if src_len > tgt_len:
            return source[:, :, :tgt_len]
        if src_len < tgt_len:
            # `F` is not imported in this file; use torch.nn.functional via `nn`.
            return nn.functional.pad(source, (0, tgt_len - src_len))
        return source

    def forward(self, x):
        """Run the network.

        Args:
            x: ``(batch_size, seq_len)``, which gets a size-1 channel axis
               inserted, or ``(batch_size, in_channels, seq_len)``, which is
               used as is (this is what makes ``in_channels > 1`` usable).

        Returns:
            ``(batch_size, seq_len)`` when ``out_channels == 1``; otherwise
            ``(batch_size, out_channels, seq_len)``, because ``squeeze(1)`` is
            then a no-op.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)  # (batch_size, 1, seq_len)

        # Encoder
        e1 = self.enc1(x)  # (batch_size, base_filters, seq_len)
        p1 = self.pool1(e1)  # (batch_size, base_filters, ceil(seq_len / 2))
        e2 = self.enc2(p1)  # (batch_size, base_filters * 2, seq_len // 2)
        p2 = self.pool2(e2)  # (batch_size, base_filters * 2, seq_len // 4)
        e3 = self.enc3(p2)  # (batch_size, base_filters * 4, seq_len // 4)
        p3 = self.pool3(e3)  # (batch_size, base_filters * 4, seq_len // 8)
        b = self.bottleneck(p3)  # (batch_size, base_filters * 8, seq_len // 8)

        # Decoder: upsample, align to the skip tensor, gate the skip, fuse.
        u3 = self.up3(b)  # (batch_size, base_filters * 4, seq_len // 4)
        u3 = self._match_length(u3, e3)
        e3_att = self.att3(e3, u3)  # (batch_size, base_filters * 4, seq_len // 4)
        u3 = torch.cat([u3, e3_att], dim=1)  # (batch_size, base_filters * 8, seq_len // 4)
        d3 = self.dec3(u3)  # (batch_size, base_filters * 4, seq_len // 4)

        u2 = self.up2(d3)  # (batch_size, base_filters * 2, seq_len // 2)
        u2 = self._match_length(u2, e2)
        e2_att = self.att2(e2, u2)  # (batch_size, base_filters * 2, seq_len // 2)
        u2 = torch.cat([u2, e2_att], dim=1)  # (batch_size, base_filters * 4, seq_len // 2)
        d2 = self.dec2(u2)  # (batch_size, base_filters * 2, seq_len // 2)

        u1 = self.up1(d2)  # (batch_size, base_filters, seq_len)
        u1 = self._match_length(u1, e1)
        e1_att = self.att1(e1, u1)  # (batch_size, base_filters, seq_len)
        u1 = torch.cat([u1, e1_att], dim=1)  # (batch_size, base_filters * 2, seq_len)
        d1 = self.dec1(u1)  # (batch_size, base_filters, seq_len)

        logits = self.out(d1)  # (batch_size, out_channels, seq_len)
        return logits.squeeze(1)  # (batch_size, seq_len) when out_channels == 1

In [None]:
#@title With Shapes
class ExportCG(_cg.ComputationGraph):
    """torchview ComputationGraph that also records every node passed to add_node.

    ``export_nodes`` holds the nodes in call order (a node may appear more
    than once); build_cytoscape reads it to rebuild the graph for Cytoscape.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.export_nodes = []

    def add_node(self, node, subgraph=None):  # type: ignore[override]
        # Forward `subgraph` unchanged; previously it was dropped, so the base
        # class always received its default. NOTE(review): presumably the base
        # uses it to pick the graphviz (sub)graph the node is drawn in.
        super().add_node(node, subgraph)
        self.export_nodes.append(node)

# ▪ Monkey-patch both module attributes, so whichever reference draw_graph
#   resolves, it instantiates ExportCG (which records export_nodes).
_cg.ComputationGraph = _tv.ComputationGraph = ExportCG

# Per-class node styling for build_cytoscape: a key matches when it is a
# substring of the node name, and the first match wins.
# Value: (css class, fill colour, border colour).
STYLE = dict(
    ResidualConvBlock=("resid", "#f8cecc", "#b24040"),
    ConvTranspose1d=("conv_t", "#d5ead2", "#2b7a3d"),
    Conv1d=("conv", "#dae8ff", "#3d6fb0"),
    MultiScaleConvBlock=("multi", "#e4d6f6", "#8266b3"),
    AttentionGate=("attn", "#fff3c8", "#bfa636"),
)

# Node width/height handed to Cytoscape; attention gates use a square box.
NODE_W = 70
NODE_H = 220
ATTN_SIZE = 90
# Colour of edge lines and their triangle arrowheads.
TRIANGLE_CLR = "#6c8ebf"

# Trace a fresh UNet1D on a (1, 512) input and keep the returned graph object
# (an ExportCG via the ComputationGraph monkey-patch); nothing is saved to disk.
cg = tv.draw_graph(
    UNet1D(),
    input_size=(1, 512),
    graph_name='UNet1Dv2',
    depth=1,                      # important: keep this depth
    expand_nested=True,
    roll=True,
    show_shapes=True,
    hide_inner_tensors=False,     # keep TensorNodes so their shapes can be read
    hide_module_functions=True,
    save_graph=False,
)

# ▸▸▸ 2.  helper: форма → "CxL" ▸▸▸
def fmt_shape(shape_tuple: tuple[int, ...]) -> tuple[int, int]:
    """(N, C, L) -> (C, L) ; паддинг если нет длины."""
    if shape_tuple is None or len(shape_tuple) < 2:
        return None, None
    C = shape_tuple[-2]
    L = shape_tuple[-1] if len(shape_tuple) >= 3 else None
    return C, L


# # helper: вернуть (C , L) или (None, None)
# def extract_CL(node) -> tuple[int | None, int | None]:
#     """Пытается достать (Channels, Length) из разных полей node."""
#     for attr in ("tensor_shape", "shape", "output_shape"):
#         shp = getattr(node, attr, None)
#         if shp is None:
#             continue
#         if isinstance(shp, (tuple, list)) and len(shp) >= 3:
#             return int(shp[-2]), int(shp[-1])
#         if isinstance(shp, str) and "x" in shp:
#             parts = [int(p) for p in shp.split("x")]
#             if len(parts) >= 3:
#                 return parts[-2], parts[-1]
#     return None, None



# ▸▸▸ 3.  build_cytoscape — shape is taken from the node or its TensorNode child  ▸▸▸
def build_cytoscape(graph: ExportCG):
    """Convert a traced ExportCG into the element dict that show_cy renders.

    Returns ``{"elements": {"nodes": [...], "edges": [...]}}``:

    * Only non-TensorNode nodes (modules and functions) become Cytoscape nodes.
    * TensorNodes whose name contains ``'hidden-tensor'`` are collapsed. Every
      non-hidden node is linked to each non-hidden node reachable from it
      through any chain of hidden tensors, and each (source, target) pair is
      emitted once.
    * A node whose (C, L) shape can be found is labelled ``F{i}×D1×L{j}``. The
      shape is looked up on the node itself first, then on its first TensorNode
      child that has one. ``i``/``j`` number the distinct C/L values in order
      of first appearance. Otherwise attention gates get ``"A"`` and other
      nodes keep their name.

    Reads the module-level STYLE, NODE_W, NODE_H and ATTN_SIZE.
    """

    def extract_cl(node) -> tuple[int | None, int | None]:
        """Return (C, L) from the last two dims of the first usable shape attribute."""
        for attr in ("tensor_shape", "shape", "output_shape"):
            shp = getattr(node, attr, None)
            if isinstance(shp, str):
                # string form such as '1x32x64'
                dims = [int(p) for p in shp.split("x") if p.isdigit()] if "x" in shp else []
            elif isinstance(shp, (tuple, list)):
                # NOTE(review): output_shape on module/function nodes appears to
                # be a list of per-output shape tuples; use the first one rather
                # than failing on int(tuple). Confirm against installed torchview.
                if shp and isinstance(shp[0], (tuple, list)):
                    shp = shp[0]
                dims = list(shp)
            else:
                continue
            if len(dims) >= 3:
                try:
                    return int(dims[-2]), int(dims[-1])
                except (TypeError, ValueError):
                    continue
        return None, None

    # ───────── split nodes into tensors / modules+functions ─────────
    tensor_nodes = {n.node_id: n for n in graph.export_nodes
                    if isinstance(n, TensorNode)}
    module_nodes = {n.node_id: n for n in graph.export_nodes
                    if not isinstance(n, TensorNode)}
    hidden_ids = {tid for tid, t in tensor_nodes.items()
                  if "hidden-tensor" in t.name}

    # ───────── adjacency (tail -> heads), in edge_list order ─────────
    children: dict = {}
    for tail, head in graph.edge_list:
        children.setdefault(tail.node_id, []).append(head.node_id)

    def visible_targets(src) -> list:
        """Non-hidden nodes reachable from ``src`` via hidden tensors only (BFS, no repeats)."""
        found, seen = [], set()
        queue = list(children.get(src, []))
        pos = 0
        while pos < len(queue):
            nid = queue[pos]
            pos += 1
            if nid in seen:
                continue
            seen.add(nid)
            if nid in hidden_ids:
                # keep walking through hidden tensors (handles hidden→hidden chains)
                queue.extend(children.get(nid, []))
            else:
                found.append(nid)
        return found

    # Each src is visited once and visible_targets never repeats a dst,
    # so the pairs are already unique.
    edge_pairs = [
        (src, dst)
        for src in children
        if src not in hidden_ids
        for dst in visible_targets(src)
    ]

    # ───────── JSON nodes ─────────
    nodes_json, edges_json = [], []
    ch2idx, len2idx = {}, {}

    for nid, n in module_nodes.items():
        # first STYLE key contained in the node name wins
        cls, fill, brd = next(
            (style for key, style in STYLE.items() if key in n.name),
            ("other", "#ddd", "#666"),
        )

        # (C, L): from the node itself, else from its first TensorNode child having one
        C, L = extract_cl(n)
        if C is None:
            for child_id in children.get(nid, []):
                t = tensor_nodes.get(child_id)
                if t is None:
                    continue
                C, L = extract_cl(t)
                if C is not None:
                    break

        if C is not None:
            # 1-based indices numbering distinct values in order of first appearance
            f_idx = ch2idx.setdefault(C, len(ch2idx) + 1)
            l_idx = len2idx.setdefault(L, len(len2idx) + 1)
            label = f"F{f_idx}×D1×L{l_idx}"
        else:
            label = "A" if cls == "attn" else n.name

        w, h = (ATTN_SIZE, ATTN_SIZE) if cls == "attn" else (NODE_W, NODE_H)
        nodes_json.append({
            "data": {"id": str(nid), "name": label, "cls": cls,
                     "fill": fill, "brd": brd, "w": w, "h": h},
        })

    # ───────── JSON edges (only between rendered nodes) ─────────
    for tail_id, head_id in edge_pairs:
        if tail_id not in module_nodes or head_id not in module_nodes:
            continue
        # Edges whose endpoints differ by more than one nesting depth are drawn
        # dashed. NOTE(review): with draw depth=1 most visible nodes may share a
        # depth, so this may rarely fire; confirm it marks the intended skips.
        skip = abs(module_nodes[tail_id].depth -
                   module_nodes[head_id].depth) > 1
        edges_json.append({
            "data": {"source": str(tail_id),
                     "target": str(head_id),
                     "kind": "skip" if skip else "norm"}
        })

    return {"elements": {"nodes": nodes_json, "edges": edges_json}}





# ▸▸▸ 4.  show_cy — render with Cytoscape.js + dagre  ▸▸▸
def show_cy(data: dict, *, rotate_labels=False, height=650):
    """Render a Cytoscape element dict inline with a left-to-right dagre layout.

    Args:
        data: dict with an ``"elements"`` key, as returned by build_cytoscape.
        rotate_labels: rotate node labels by -90 degrees.
        height: height of the drawing area in pixels.

    The cytoscape, dagre and cytoscape-dagre scripts are loaded from unpkg.
    This presumably needs network access and a front-end that executes
    <script> tags in HTML output (e.g. Colab).
    """
    div = "cy_" + ''.join(random.choices(string.ascii_lowercase, k=6))
    rotation = "-90deg" if rotate_labels else "none"
    # Escape '<', '>' and '&' as \u sequences so a label such as "</script>"
    # cannot terminate the enclosing <script> element. The decoded JSON value
    # is unchanged, because these characters only occur inside JSON strings.
    elements_js = (
        json.dumps(data['elements'])
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )

    html = f"""
    <div id="{div}" style="width:100%;height:{height}px;"></div>
    <script src="https://unpkg.com/cytoscape@3.27.0/dist/cytoscape.min.js"></script>
    <script src="https://unpkg.com/dagre@0.8.5/dist/dagre.min.js"></script>
    <script src="https://unpkg.com/cytoscape-dagre@2.5.0/cytoscape-dagre.js"></script>
    <script>
      cytoscape({{
        container : document.getElementById('{div}'),
        elements  : {elements_js},
        layout    : {{ name:'dagre', rankDir:'LR', nodeSep:60, rankSep:150 }},
        style     : [
          {{ selector:'node',
             style:{{ 'shape':'round-rectangle','width':'data(w)','height':'data(h)',
                      'background-color':'data(fill)','border-color':'data(brd)',
                      'border-width':3,'label':'data(name)',
                      'text-rotation':'{rotation}',
                      'text-valign':'center','text-halign':'center',
                      'font-size':14,'font-weight':'bold' }} }},
          {{ selector:'node[cls = "attn"]',
             style:{{ 'shape':'diamond','width':'data(w)','height':'data(h)',
                      'text-rotation':'{rotation}' }} }},
          {{ selector:'edge',
             style:{{ 'curve-style':'bezier','width':2,
                      'line-color':'{TRIANGLE_CLR}',
                      'target-arrow-color':'{TRIANGLE_CLR}',
                      'target-arrow-shape':'triangle','arrow-scale':1.4 }} }},
          {{ selector:'edge[kind = "skip"]', style:{{ 'line-style':'dashed' }} }}
        ]
      }}).ready(function(){{ this.fit(); }});
    </script>
    """
    display(HTML(html))

# ───── Convert the traced graph to Cytoscape elements and render it ─────
cy_json = build_cytoscape(graph=cg)
show_cy(data=cy_json, rotate_labels=True)