In [1]:
import numpy as np

class ART:
    def __init__(self, num_input, rho=0.5, alpha=0.1):
        self.num_input = num_input
        self.rho = rho
        self.alpha = alpha
        self.W = np.zeros((num_input,))
        self.V = self.rho * np.linalg.norm(self.W)
        
    def train(self, input_pattern):
        input_pattern = input_pattern / np.linalg.norm(input_pattern)
        similarity = np.dot(self.W, input_pattern)
        if similarity < self.V:
            self.W = (1 - self.alpha) * self.W + self.alpha * input_pattern
            self.V = self.rho * np.linalg.norm(self.W)
            
    def predict(self, input_pattern):
        input_pattern = input_pattern / np.linalg.norm(input_pattern)
        similarity = np.dot(self.W, input_pattern)
        return similarity >= self.V


In [8]:
# Create an ART network with 3 inputs
art = ART(num_input=3)

# Train the network on some input patterns
art.train(np.array([1, 0, 0]))
art.train(np.array([0, 1, 0]))
art.train(np.array([0, 0, 1]))

# Predict whether some input patterns are similar to existing categories
print(art.predict(np.array([0.9, 0.1, 0])))  # Output: True
print(art.predict(np.array([0.1, 0.9, 0])))  # Output: True
print(art.predict(np.array([0, 0, 1])))  # Output: True
print(art.predict(np.array([10, 20, 5])))  # Output: False


True
True
True
True
