In [None]:
import matplotlib.pyplot as plt


def _layer_titles(n_layers):
    """Return one column title per layer: input, numbered hidden layers, output."""
    if n_layers == 1:
        return ["Input Layer"]
    hidden = [f"Hidden Layer {k}" for k in range(1, n_layers - 1)]
    return ["Input Layer", *hidden, "Output Layer"]


def draw_neural_network(layer_sizes, layer_labels, output_label="Fraud\n(0 / 1)"):
    """Draw a fully connected feed-forward network diagram and show it.

    Each layer is drawn as a vertical column of circles centred on y = 0,
    every node is connected to every node of the next layer, the input nodes
    are annotated on their left, the output nodes on their right, and a bold
    title is placed above each column.

    Args:
        layer_sizes: Sequence with the number of nodes in each layer, ordered
            from input to output. Must contain at least one layer.
        layer_labels: Sequence of text labels for the input-layer nodes.
            Label ``j`` is attached to input node ``j``; nodes without a
            matching label are left unlabelled. ``None`` means no labels.
        output_label: Text written next to every node of the last layer.
            Defaults to the annotation the diagram has always used.

    Raises:
        ValueError: If ``layer_sizes`` is empty.
    """
    n_layers = len(layer_sizes)
    if n_layers == 0:
        raise ValueError("layer_sizes must contain at least one layer")
    labels = () if layer_labels is None else layer_labels

    _, ax = plt.subplots(figsize=(12, 7))
    ax.axis('off')

    # Same geometry as before (columns 1 apart, nodes 2 apart), but the names
    # now say which axis each spacing applies to.
    layer_gap = 1
    node_gap = 2
    node_radius = 0.2

    layer_pos = []
    for i, layer_size in enumerate(layer_sizes):
        x = i * layer_gap
        y_start = -(layer_size - 1) * node_gap / 2
        layer_pos.append(
            [(x, y_start + j * node_gap) for j in range(layer_size)]
        )

    last_layer = n_layers - 1
    for i, layer in enumerate(layer_pos):
        for j, (x, y) in enumerate(layer):
            # zorder above the edge lines so connections do not paint over
            # the node discs.
            ax.add_patch(plt.Circle((x, y), node_radius, fill=True, zorder=3))

            if i == 0:
                # Guard the lookup: fewer labels than input nodes must not
                # abort the drawing with an IndexError.
                if j < len(labels):
                    ax.text(x - 0.6, y, labels[j],
                            fontsize=9, ha='right', va='center')
            elif i == last_layer:
                ax.text(x + 0.4, y, output_label,
                        fontsize=9, ha='left', va='center')

    for left, right in zip(layer_pos, layer_pos[1:]):
        for (x1, y1) in left:
            for (x2, y2) in right:
                ax.plot([x1, x2], [y1, y2], 'gray')

    # Titles go above the tallest column. The previous height was derived
    # from the input layer size alone and ignored the node spacing, which put
    # the titles in the middle of taller layers.
    half_height = max(max(layer_sizes) - 1, 0) * node_gap / 2
    title_y = half_height + node_radius + 0.6
    for i, title in enumerate(_layer_titles(n_layers)):
        ax.text(i * layer_gap, title_y, title,
                ha='center', fontsize=11, fontweight='bold')

    # Fix the view explicitly so every node and the title row sit inside the
    # axes area instead of depending on autoscaling of the plotted edges.
    pad = 0.3
    ax.set_xlim(-node_radius - pad, last_layer * layer_gap + node_radius + pad)
    ax.set_ylim(-half_height - node_radius - pad, title_y + 0.5)

    plt.show()

# Feature names shown next to the input-layer nodes, one per node.
input_labels = [
    "TransactionAmount", "TransactionTime", "MerchantCategory",
    "CustomerAge", "AccountBalance", "TransactionsToday",
]

# Nodes per layer, input to output: one input node per feature label,
# then layers of 8 and 4 nodes, then a single output node.
layer_sizes = [len(input_labels), 8, 4, 1]

# Render the diagram.
draw_neural_network(layer_sizes, input_labels)


: 