In [2]:
!pip install diagrams



In [7]:
# transformer_classifier_diagram_tb.py

from graphviz import Digraph

def make_transformer_classifier_diagram_tb(
    filename='transformer_classifier_tb', fmt='png',
    *, num_layers=3, d_model=256, d_ff=512,
):
    """Render a top-to-bottom block diagram of a Transformer classifier.

    Builds a ``graphviz.Digraph`` with the chain
    Input -> Linear embedding -> positional encoding -> LayerNorm ->
    one encoder layer (MHA, Add & Norm, FFN, Add & Norm, drawn once inside
    a cluster labelled with the repeat count) -> global mean pooling ->
    Linear -> Sigmoid, then renders it to disk and prints the output path.

    Args:
        filename: Output path without extension; graphviz appends ``.fmt``.
        fmt: Graphviz output format, e.g. ``'png'`` or ``'svg'``.
        num_layers: Repeat count shown in the encoder cluster label.
        d_model: Model width used in the embedding, FFN and head labels.
        d_ff: Hidden width shown in the feed-forward label.

    Returns:
        The path of the rendered file, as returned by ``Digraph.render``.
    """
    dot = Digraph('TransformerClassifier',
                  filename=filename, format=fmt)
    # Top-to-bottom layout with orthogonal edge routing.
    dot.attr(rankdir='TB', splines='ortho',
             nodesep='0.6', ranksep='0.8')
    dot.node_attr.update(
        shape='rectangle', style='filled',
        fontname='Helvetica', fontsize='10'
    )

    # Input → Embedding → PosEnc → Pre-Norm
    # NOTE(review): the input is labelled [N, 18, 23] but the embedding is
    # labelled Linear(18 → ...); a Linear layer normally acts on the last
    # axis, so one of these labels may be off — confirm against the model.
    dot.node('Input', '[N, 18, 23]', fillcolor='#F8C3C3')
    dot.node('Emb',   f'Linear(18 → {d_model})\nEmbedding',
             fillcolor='#F8C3C3')
    dot.node('PE',    '⊕ Positional\nEncoding',
             shape='circle', fillcolor='white')
    dot.node('PreLN', 'LayerNorm', fillcolor='#FEF9C7')

    dot.edge('Input', 'Emb')
    dot.edge('Emb',   'PE')
    dot.edge('PE',    'PreLN')

    # Encoder cluster: the 'cluster' name prefix makes Graphviz draw a
    # bounding box. One layer is drawn; the label carries the repeat count.
    with dot.subgraph(name='cluster_enc') as enc:
        enc.attr(label=f'Transformer Encoder ({num_layers}×)', color='black')
        enc.node_attr.update(fillcolor='#F0F0F0')
        enc.node('MHA',      'Multi‑Head\nAttention',    fillcolor='#FFD8B1')
        enc.node('AddNorm1', 'Add & Norm',                fillcolor='#FEF9C7')
        enc.node('FFN',      f'Feed‑Forward\n({d_model}→{d_ff}→{d_model})',
                 fillcolor='#B1E2FF')
        enc.node('AddNorm2', 'Add & Norm',                fillcolor='#FEF9C7')

        # Forward path through the sublayers.
        enc.edge('MHA',      'AddNorm1')
        enc.edge('AddNorm1', 'FFN')
        enc.edge('FFN',      'AddNorm2')
        # Dashed back-edge from the last sublayer to the first: presumably
        # marks that this layer is stacked `num_layers` times. It is not a
        # residual/skip connection (those would enter each 'Add & Norm'
        # node and are not drawn here).
        enc.edge('AddNorm2', 'MHA', style='dashed')

    # Pre-Norm output feeds the first sublayer of the encoder.
    dot.edge('PreLN', 'MHA')

    # Classification head: Pooling → FC → Sigmoid
    dot.node('Pool', 'Global Mean\nPooling',
             shape='ellipse', fillcolor='#D3D3F9')
    dot.edge('AddNorm2', 'Pool')

    dot.node('FC', f'Linear({d_model} → 1)',
             fillcolor='#C3E6C3')
    dot.edge('Pool', 'FC')

    dot.node('Sigmoid', 'Sigmoid',
             shape='oval', fillcolor='#C3E6C3')
    dot.edge('FC', 'Sigmoid')

    # cleanup=True deletes the intermediate DOT source after rendering;
    # report the path graphviz actually wrote instead of rebuilding it.
    out_path = dot.render(cleanup=True)
    print(f"Diagram saved to {out_path}")
    return out_path

if __name__ == '__main__':
    # Script / notebook entry point: render the diagram with the default
    # output name and format.
    make_transformer_classifier_diagram_tb(
        filename='transformer_classifier_tb',
        fmt='png',
    )


Diagram saved to transformer_classifier_tb.png
