# Classification with Least Squares

This notebook demonstrates how to implement a binary classifier and solve for the parameters using a least-squares approach.

In [80]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from mlxtend.plotting.decision_regions import plot_decision_regions
from sklearn.datasets import make_classification

%matplotlib widget

In [104]:
def get_one_hot(targets, nb_classes):
    res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
    return res.reshape(list(targets.shape)+[nb_classes])


class LinearDiscriminant:        
    def fit(self, data, targets):
        num_classes = np.max(targets, axis=0) + 1
        data = np.concatenate((np.ones((data.shape[0], 1)), data), axis=-1)
        targets = get_one_hot(targets, num_classes)
        self.weights_ = np.linalg.inv(data.T @ data) @ data.T @ targets
        
    def predict(self, x):
        """Classify input sample(s)
        
        Parameters
        ----------
        x : array-like, [n_samples, n_features]
            Samples
        
        Returns
        -------
        result : array-like, int, [n_samples]
            Corresponding prediction(s)
        """
        # Add constant for bias parameter
        x = np.concatenate((np.ones((x.shape[0], 1)), x), axis=-1)
            
        return np.argmax(self.weights_.T @ x.T, axis=0)

In [106]:
# Generate some data
n_classes = 3
X, Y = make_classification(200, 2, n_redundant=0, n_classes=n_classes, n_clusters_per_class=1)

fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(X[:, 0], X[:, 1], c=Y)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.collections.PathCollection at 0x7f429ce7b750>

In [107]:
classifier = LinearDiscriminant()
classifier.fit(X, Y)

# Measure number of misclassifications
error = np.sum(np.abs(classifier.predict(X) - Y))
print(f"Error = {(error / 200) * 100:1.2f}%")

fig = plt.figure()
ax = plot_decision_regions(X, Y, classifier)
fig.add_subplot(ax)

Error = 29.00%


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<AxesSubplot:>