In [4]:
import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

class PerceptronMulticlass:
    """Multiclass perceptron trained one-vs-rest (OvR).

    One binary perceptron (weight row ``W[j]`` and bias ``b[j]``) is learned
    per class ``j``. At prediction time the class with the largest linear
    score ``W[j] · x + b[j]`` wins.

    Parameters
    ----------
    learning_rate : float
        Step size applied to every perceptron update.
    n_iterations : int
        Maximum number of passes (epochs) over the training data. Training
        stops early as soon as an epoch makes no updates, because further
        epochs would leave the weights unchanged.
    n_classes : int
        Number of classes. Labels in ``y`` are compared against the integers
        ``0 .. n_classes - 1``; any other label is treated as "negative" for
        every class.
    """

    def __init__(self, learning_rate=0.01, n_iterations=1000, n_classes=3):
        self.learning_rate = learning_rate
        self.n_iterations = n_iterations
        self.n_classes = n_classes
        self.W = None  # (n_classes, n_features) weight matrix, set by fit()
        self.b = None  # (n_classes,) bias vector, set by fit()

    def step_function(self, x):
        """Heaviside step: 1 where ``x >= 0``, else 0 (element-wise)."""
        return np.where(x >= 0, 1, 0)

    def fit(self, X, y):
        """Train one binary perceptron per class with the perceptron rule.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.
        y : array-like of shape (n_samples,)
            Integer class labels.

        Returns
        -------
        self
            The fitted estimator, to allow ``model.fit(X, y).predict(X)``.

        Raises
        ------
        ValueError
            If ``X`` and ``y`` contain a different number of samples.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        n_samples, n_features = X.shape
        if len(y) != n_samples:
            raise ValueError(
                f"X has {n_samples} samples but y has {len(y)} labels."
            )

        # Initialize weight matrix W and bias vector b to zeros
        self.W = np.zeros((self.n_classes, n_features))
        self.b = np.zeros(self.n_classes)
        classes = np.arange(self.n_classes)

        # Training process (One vs Rest). All classes are updated at once per
        # sample: class j's score depends only on W[j] and b[j], so this is
        # equivalent to looping over the classes one at a time.
        for _ in range(self.n_iterations):
            n_updates = 0
            for x_i, true_class in zip(X, y):
                y_hat = self.step_function(self.W @ x_i + self.b)
                y_true = (classes == true_class).astype(int)  # one-hot target
                delta = y_true - y_hat  # 0 where the class was correct
                if np.any(delta):
                    update = self.learning_rate * delta
                    self.W += np.outer(update, x_i)
                    self.b += update
                    n_updates += 1
            # A full epoch without mistakes means the data is separated;
            # the remaining epochs would not change W or b.
            if n_updates == 0:
                break
        return self

    def predict(self, X):
        """Return the predicted class index for each sample in ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features) or (n_features,)
            Samples to classify. A single 1-D sample is treated as one row.

        Returns
        -------
        numpy.ndarray of shape (n_samples,)
            Index of the class with the highest linear score per sample.
        """
        X = np.atleast_2d(np.asarray(X, dtype=float))
        # (a) Calculate Linear Output: z_j = W_j · x_i + b_j
        linear_output = X @ self.W.T + self.b

        # (b) Class Selection: argmax over the raw scores. Thresholding with
        # step_function first would reduce every score to 0/1, so when no
        # class (or several classes) fire the tie would resolve to the lowest
        # class index instead of the best-scoring class.
        return np.argmax(linear_output, axis=1)

# Example usage:
if __name__ == "__main__":
    # Three 2-D Gaussian blobs; make_blobs labels them 0 .. centers - 1.
    X, y = make_blobs(n_samples=100, centers=3, n_features=2, random_state=42)

    # Initialize and train the multiclass perceptron
    perceptron = PerceptronMulticlass(learning_rate=0.01, n_iterations=1000, n_classes=3)
    perceptron.fit(X, y)

    # Predict on the training samples themselves (there is no held-out set),
    # so the accuracy reported below is training accuracy, not generalization.
    predictions = perceptron.predict(X)
    print("Predictions:", predictions)
    print(f"Training accuracy: {np.mean(predictions == y):.2%}")


Predictions: [2 1 0 1 2 1 0 1 1 0 0 2 2 0 0 2 2 0 2 2 0 2 2 0 0 0 1 2 2 2 2 1 1 2 0 0 0
 0 1 1 2 0 1 0 0 1 2 2 2 1 1 1 0 2 2 2 0 0 1 0 2 1 2 1 2 2 1 2 1 1 1 2 2 0
 1 2 1 2 1 1 0 1 0 2 0 0 0 1 0 1 1 1 0 1 0 0 0 1 2 0]
