In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [25]:
# Parameters
# Vigilance threshold: vigilance_test() accepts a node only when its match
# ratio is at least this value.
rho = 0.4
# Required length of every input vector; main() rejects inputs of any other length.
input_size = 4
# Number of category nodes; sizes the assignments dict and the activation
# loop in main().
num_categories = 3

In [4]:
# Initialize weights
# NOTE: main() updates both arrays in place, so re-running this cell resets the
# network to the state below.
# Bottom-up weights, one column per node (main() reads and writes W_bu[:, j]);
# one row per input dimension.
W_bu = np.array([
    [1.0, 0.0, 0.2],
    [0.0, 0.0, 0.2],
    [0.0, 0.0, 0.2],
    [0.0, 1.0, 0.2]
])
# Top-down weights, one row per node (main() reads and writes W_td[j]);
# one column per input dimension.
W_td = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 1.0]
])

In [5]:
# Inputs accepted so far, keyed by node index; every node starts with its own
# empty list that main() appends to.
assignments = {node: [] for node in range(num_categories)}

In [6]:
def norm(x):
    """Return the sum of all entries of ``x``.

    For a 0/1 vector this is the number of ones.
    """
    total = np.sum(x)
    return total

In [14]:
def vigilance_test(x, td_weights, vigilance=None):
    """Decide whether a node's top-down weights match the input closely enough.

    The match ratio is ``sum(min(x, template)) / sum(x)``, where ``template``
    is ``td_weights`` truncated to integers.

    Parameters
    ----------
    x : array-like
        Input vector (main() passes a 0/1 integer array).
    td_weights : numpy.ndarray
        Top-down weights of a single node.
    vigilance : float, optional
        Threshold the ratio is compared against. When omitted, the
        module-level ``rho`` is used.

    Returns
    -------
    bool
        True when the match ratio is at least the threshold. An input whose
        entries sum to zero never matches.
    """
    threshold = rho if vigilance is None else vigilance

    # Same quantity that norm() computes: the sum of the input's entries.
    input_norm = np.sum(x)
    if input_norm == 0:
        # The ratio would be 0/0. Previously this produced NaN together with a
        # numpy RuntimeWarning and only compared False by accident; reject
        # explicitly instead.
        return False

    overlap = np.sum(np.minimum(x, td_weights.astype(int)))
    return bool(overlap / input_norm >= threshold)

In [9]:
def update_bu(td_vector):
    """Derive a node's bottom-up weights from its top-down vector.

    Every entry ``t`` becomes ``2 * t / (1 + norm(td_vector))``.
    """
    denominator = 1 + norm(td_vector)
    numerator = 2 * td_vector
    return numerator / denominator

In [10]:
def visualize(assignments, W_bu, W_td):
    """Draw the current network state as a two-panel figure.

    The left panel lists, under each node's x position, the input vectors
    assigned to that node. The right panel shows ``W_td`` as a heat map.

    Parameters
    ----------
    assignments : dict
        Maps a node index to the list of inputs accepted by that node.
    W_bu : array-like
        Bottom-up weights. Accepted but not drawn by this function.
    W_td : array-like
        Top-down weights, passed to ``imshow`` as-is.
    """
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle("ART1 Network State")
    ax_nodes, ax_weights = axs

    # Show assignments: one column of text per node, growing downwards.
    ax_nodes.set_title("Node Assignments")
    for node, vecs in assignments.items():
        for idx, vec in enumerate(vecs):
            ax_nodes.text(node, -idx, str(vec), fontsize=12)
    ax_nodes.set_xlim(-1, num_categories)
    # Keep the original 5-row view as a minimum, but extend it when a node has
    # collected more inputs so later entries are not placed outside the axes.
    longest = max((len(vecs) for vecs in assignments.values()), default=0)
    ax_nodes.set_ylim(-max(5, longest), 1)
    ax_nodes.set_xticks(range(num_categories))
    ax_nodes.set_xlabel("Node")
    ax_nodes.set_ylabel("Assigned Inputs")

    # Show Weights
    ax_weights.set_title("Top-Down Weights")
    im = ax_weights.imshow(W_td, cmap="Blues", aspect="auto")
    ax_weights.set_xlabel("Input Dimension")
    ax_weights.set_ylabel("Node")
    plt.colorbar(im, ax=ax_weights, label='Weight Value')

    plt.tight_layout()
    plt.show()

In [None]:
def main():
    """Run the interactive ART1 session on the module-level network.

    Repeatedly reads a whitespace-separated binary vector, tries the nodes in
    order of decreasing bottom-up activation, and lets the first node that
    passes vigilance_test() learn the input. After every valid input the
    weights are printed and visualize() is called. Typing 'q', or reaching
    end-of-input, stops the loop.

    W_bu, W_td and assignments are modified in place (never rebound), so no
    ``global`` declaration is needed and the learned state remains in the
    module-level objects after this function returns.
    """
    print("ART1 Network - Enter binary input vectors (e.g. 1 0 0 1). Type 'q' to quit.\n")

    while True:
        try:
            # The expected length comes from input_size rather than a literal
            # so the prompt stays in sync with the validation below.
            user_input = input(f"Enter binary input vector (length {input_size}): ").strip()
        except EOFError:
            # stdin was closed (e.g. piped input ran out): treat it like 'q'.
            break
        if user_input.lower() == 'q':
            break

        try:
            x = np.array([int(token) for token in user_input.split()])
            if len(x) != input_size or any(v not in (0, 1) for v in x):
                raise ValueError
        except ValueError:
            print(f"Invalid input. Please enter {input_size} binary digits separated by space.")
            continue

        # Rank nodes by bottom-up activation, strongest first. list.sort is
        # stable, so nodes with equal activation keep ascending index order.
        activated = [(j, np.dot(x, W_bu[:, j])) for j in range(num_categories)]
        activated.sort(key=lambda tup: -tup[1])

        for j, _ in activated:
            if vigilance_test(x, W_td[j]):
                print(f"✅ Input {x.tolist()} classified by Node {j+1}")
                # Learning step: shrink the template to its overlap with the
                # input, then recompute the node's bottom-up column from it.
                W_td[j] = np.minimum(W_td[j], x)
                W_bu[:, j] = update_bu(W_td[j])
                assignments[j].append(x.tolist())
                break
            print(f"❌ Node {j+1} rejected due to vigilance.")
        else:
            # Reached only when the loop finished without a break, i.e. every
            # node failed the vigilance test.
            print("⚠️ No node accepted the input. Try changing ρ or add more categories.")

        print("\nBottom-Up Weights (W_bu):\n", np.round(W_bu, 2))
        print("Top-Down Weights (W_td):\n", W_td.astype(int))
        visualize(assignments, W_bu, W_td)

# Start the interactive session when this code is executed directly rather
# than imported as a module.
if __name__ == "__main__":
    main()

ART1 Network - Enter binary input vectors (e.g. 1 0 0 1). Type 'q' to quit.

