## **HOMEWORK2: Learning Vector Quantization (LVQ)** <br> <br>

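Implement the prototype update logic inside the `train` method of the `LVQ` class: for each sample $x$ with label $y$, find the nearest prototype $w$ and move it towards $x$ if its label equals $y$, $w \leftarrow w + \alpha\,(x - w)$, or away from $x$ otherwise, $w \leftarrow w - \alpha\,(x - w)$, where $\alpha$ is the learning rate.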

In [11]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Seed NumPy's global random number generator so that every np.random.* draw
# made after this cell is reproducible from one run to the next.
np.random.seed(seed=17)

In [12]:
class LVQ:
    """Learning Vector Quantization classifier using the basic LVQ1 update rule.

    The model keeps ``n_prototypes`` points in feature space, each tagged with
    a class label. Training pulls the nearest prototype towards a sample when
    their labels agree and pushes it away otherwise; prediction returns the
    label of the nearest prototype (Euclidean distance).

    Attributes:
        n_prototypes: Number of prototype vectors.
        n_features: Dimensionality of each prototype / sample.
        n_classes: Number of distinct class labels (0 .. n_classes - 1).
        prototypes: Float array of shape (n_prototypes, n_features).
        prototype_labels: Array of shape (n_prototypes,) with each prototype's class.
    """

    def __init__(self, n_prototypes, n_features, n_classes):
        """Create randomly positioned prototypes with round-robin class labels.

        Args:
            n_prototypes: Number of prototypes to allocate.
            n_features: Number of features per sample.
            n_classes: Number of classes; prototype ``i`` gets label ``i % n_classes``.
        """
        self.n_prototypes = n_prototypes
        self.n_features = n_features
        self.n_classes = n_classes
        # Initial positions are standard-normal draws from NumPy's global RNG,
        # so results depend on the current global seed.
        self.prototypes = np.random.randn(n_prototypes, n_features)
        # Labels cycle 0, 1, ..., n_classes - 1, 0, 1, ... across the prototypes.
        self.prototype_labels = np.array([i % n_classes for i in range(n_prototypes)])

    def train(self, X, y, epochs, learning_rate):
        """Fit the prototypes in place with online LVQ1 updates.

        Samples are visited in the order given, once per epoch, and each
        update is applied immediately (so later samples see the moved prototype).

        Args:
            X: Array-like of shape (n_samples, n_features).
            y: Array-like of shape (n_samples,) with the class label of each row of X.
            epochs: Number of full passes over X.
            learning_rate: Step size (fraction of the sample-prototype offset applied).

        Raises:
            ValueError: If X and y do not contain the same number of samples.
        """
        # Coerce to arrays so iteration and label lookup are positional for any
        # array-like input (lists, tuples, pandas objects).
        X = np.asarray(X)
        y = np.asarray(y)
        # Validate before touching the prototypes so a mismatch cannot leave the
        # model half-trained or silently ignore trailing labels.
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same number of samples, got {len(X)} and {len(y)}"
            )

        for _ in range(epochs):
            for x, label in zip(X, y):
                # 1. Find nearest prototype (Best Matching Unit - BMU)
                distances = np.linalg.norm(self.prototypes - x, axis=1)
                bmu_idx = np.argmin(distances)

                # 2. Update prototype: the step points from the BMU towards x.
                step = learning_rate * (x - self.prototypes[bmu_idx])
                if self.prototype_labels[bmu_idx] == label:
                    # Correct class: move closer
                    self.prototypes[bmu_idx] += step
                else:
                    # Wrong class: move away
                    self.prototypes[bmu_idx] -= step

    def predict(self, X):
        """Return the label of the nearest prototype for every sample.

        Args:
            X: Array-like of shape (n_samples, n_features).

        Returns:
            Array of shape (n_samples,) with the predicted class labels, in the
            same dtype as ``prototype_labels`` (also for empty input).

        Raises:
            ValueError: If X is non-empty and not two-dimensional.
        """
        X = np.asarray(X)
        # No samples: return an empty result that still carries the label dtype.
        if X.ndim >= 1 and X.shape[0] == 0:
            return np.empty(0, dtype=self.prototype_labels.dtype)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-dimensional (n_samples, n_features), got {X.ndim} dimension(s)"
            )

        # Broadcast to (n_samples, n_prototypes, n_features) and reduce over the
        # feature axis to get all sample-to-prototype distances at once.
        # Note: the intermediate array grows with n_samples * n_prototypes * n_features.
        diffs = X[:, np.newaxis, :] - self.prototypes[np.newaxis, :, :]
        distances = np.linalg.norm(diffs, axis=2)
        bmu_indices = np.argmin(distances, axis=1)
        return self.prototype_labels[bmu_indices]

In [15]:

# TEST EXERCISE 2
print("Testing LVQ...")

# Synthetic data: 10 standard-normal points around the origin (class 0)
# stacked on top of 10 points shifted by +3 in both features (class 1).
X_lvq = np.vstack((np.random.randn(10, 2), np.random.randn(10, 2) + 3))
y_lvq = np.repeat([0, 1], 10)

# One prototype per class; fit, then classify the training points themselves.
lvq = LVQ(n_prototypes=2, n_features=2, n_classes=2)
lvq.train(X_lvq, y_lvq, epochs=20, learning_rate=0.1)
preds = lvq.predict(X_lvq)
print(f"LVQ Predictions: {preds}")

# Fraction of training points whose predicted label matches the true label.
acc = (preds == y_lvq).mean()
print(f"Accuracy: {acc:.2f}")

# Smoke-test threshold: well-separated clusters should be classified far above this.
if acc < 0.6:
    print("Warning: Accuracy is low. Check your update rule implementation.")

Testing LVQ...
LVQ Predictions: [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 1 1 1]
Accuracy: 0.95
