In [1]:
from graphviz import Digraph
from colors import colors_light, colors_light_transparent

In [2]:
"""CNN Simple"""

'CNN Simple'

In [3]:
"""CNN Complex"""

dot = Digraph(comment="CNN Complex")

dot.attr(label="CNN Complex", labelloc="t", style="bold", fontsize="20", rankdir="TB", splines="false")

dot.node("Input", "Input (Image)", shape="ellipse", color=colors_light['gray'], style="filled", fillcolor=colors_light_transparent['gray'])
dot.node("Output", "Output (Classification)", shape="ellipse", color=colors_light['gray'], style="filled", fillcolor=colors_light_transparent['gray'])

dot.node("CNNLayers", "CNN Layers", shape="box", color=colors_light['green'], style="filled", fillcolor=colors_light_transparent['green'])
with dot.subgraph(name="cluster_cnn_details") as c:
    c.attr(label="CNN Layers", color=colors_light['gray'], style="dashed", fontsize="15")
    c.node("Conv1", "Conv2d (3, 16, 3x3)\n+ ReLU + MaxPool2d (2x2)", shape="box", color=colors_light['green'], style="filled", fillcolor=colors_light_transparent['green'])
    c.node("Conv2", "Conv2d (16, 32, 3x3)\n + ReLU + MaxPool2d (2x2)", shape="box", color=colors_light['green'], style="filled", fillcolor=colors_light_transparent['green'])
    c.node("Conv3", "Conv2d (32, 64, 3x3)\n+ ReLU + MaxPool2d (2x2)", shape="box", color=colors_light['green'], style="filled", fillcolor=colors_light_transparent['green'])

dot.node("Flatten", "Flatten Layer", shape="box", color=colors_light['peach'], style="filled", fillcolor=colors_light_transparent['peach'])
dot.node("FC1", "Fully Connected\n(512 units)", shape="box", color=colors_light['peach'], style="filled", fillcolor=colors_light_transparent['peach'])
dot.node("Dropout", "Dropout\n(0.5)", shape="box", color=colors_light['peach'], style="filled", fillcolor=colors_light_transparent['peach'])
dot.node("FC2", "Fully Connected\n(3 classes)", shape="box", color=colors_light['peach'], style="filled", fillcolor=colors_light_transparent['peach'])

dot.edge("Input", "CNNLayers")
dot.edge("CNNLayers", "Flatten")
dot.edge("CNNLayers", "Conv1", style="dashed", constraint="false")
dot.edge("Flatten", "FC1")
dot.edge("FC1", "Dropout")
dot.edge("Dropout", "FC2")
dot.edge("FC2", "Output")

dot.edge("Conv1", "Conv2", style="dashed")
dot.edge("Conv2", "Conv3", style="dashed")

dot.render("../images/model_graph/cnn_model_graph", format="png", view=True)

'..\\images\\model_graph\\cnn_model_graph.png'

In [4]:
"""ResNet"""

'ResNet'

In [5]:
"""ViT"""

dot = Digraph(comment="ViTForImageClassification")

dot.attr(label="ViT", labelloc="t", style="bold", fontsize="20", rankdir="TB", splines="false")

dot.node("Input", "Input (Image)", shape="ellipse", color=colors_light['gray'], style="filled", fillcolor=colors_light_transparent['gray'])
dot.node("Output", "Output (Classification)", shape="ellipse", color=colors_light['gray'], style="filled", fillcolor=colors_light_transparent['gray'])

dot.node("LinearClassifier", "Linear Classifier", shape="box", color=colors_light['peach'], style="filled", fillcolor=colors_light_transparent['peach'])

with dot.subgraph(name="cluster_vit_model") as c:
    c.attr(label="ViTModel", color=colors_light['gray'], style="solid",  labeljust="l", fontsize="15")
    c.node("ViTEmbeddings", "ViTEmbeddings", shape="box", width="2", style="filled", color=colors_light['green'], fillcolor=colors_light_transparent['green'])
    c.node("ViTEncoder", "ViTEncoder", shape="box", width="2", style="filled", color=colors_light['green'], fillcolor=colors_light_transparent['green'])
    c.node("LayerNorm", "LayerNorm", shape="box", width="2", style="filled", color=colors_light['green'], fillcolor=colors_light_transparent['green'])

    dot.edge("ViTEmbeddings", "ViTEncoder")
    dot.edge("ViTEncoder", "LayerNorm")

with dot.subgraph(name="cluster_vit_encoder") as c:
    c.attr(label="ViTEncoder", color=colors_light['gray'], style="dashed", fontsize="15")
    c.node("FrozenLayers", "ViTLayer 1-2 \n(Frozen)", shape="box", height="0.33", color=colors_light['green'], style="filled", fillcolor=colors_light_transparent['green'])
    c.node("TrainableLayers", "ViTLayer 3-12 \n(Trainable)", shape="box", height="2", color=colors_light['green'], style="filled", fillcolor=colors_light_transparent['green'])

    dot.edge("ViTEncoder", "FrozenLayers", style="dashed")
    dot.edge("FrozenLayers", "TrainableLayers", style="dashed")

dot.edge("Input", "ViTEmbeddings")
dot.edge("LayerNorm", "LinearClassifier")
dot.edge("LinearClassifier", "Output")

dot.render("../images/model_graph/vit_model_graph", format="png", view=True)

'..\\images\\model_graph\\vit_model_graph.png'

In [9]:
from PIL import Image


def _load_image(path):
    """Read an image fully into memory and release its file handle.

    ``copy()`` forces the pixel data to load, so the returned image stays usable
    after the ``with`` block closes the underlying file.
    """
    with Image.open(path) as img:
        return img.copy()


def _paste_over_white(canvas, img, offset):
    """Paste ``img`` onto ``canvas`` at ``offset``, using its alpha channel as mask.

    Converting to RGBA first gives every input mode an alpha band, so transparent
    pixels (if any) leave the white canvas visible instead of their raw RGB values.
    """
    rgba = img.convert("RGBA")
    canvas.paste(rgba, offset, rgba)


# The two PNGs written by the ViT and CNN render cells above.
vit_image = _load_image("../images/model_graph/vit_model_graph.png")
cnn_image = _load_image("../images/model_graph/cnn_model_graph.png")

vit_width, vit_height = vit_image.size
cnn_width, cnn_height = cnn_image.size

# Side by side (ViT left, CNN right), top-aligned on a white canvas tall enough
# for the taller of the two graphs.
combined_width = vit_width + cnn_width
combined_height = max(vit_height, cnn_height)

combined_image = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))

_paste_over_white(combined_image, vit_image, (0, 0))
_paste_over_white(combined_image, cnn_image, (vit_width, 0))

combined_image.save("../images/model_graph/combined_model_graph.png")
combined_image.show()