# Visualizing Decision Boundaries in 3D: An Extension of Discriminative-vs-QDA-One-Gaussian.ipynb

## Introduction

This notebook serves as an extension of the original `Discriminative-vs-QDA-One-Gaussian.ipynb` project, where we compared the performance of a discriminative model with a manually implemented Quadratic Discriminant Analysis (QDA) on data generated from a single Gaussian distribution. 

In this extension, we focus on visualizing the decision boundaries in a 3-dimensional space (dim=3). By leveraging the `mayavi` library, we can create interactive 3D plots that provide deeper insights into how these models separate the data when the dimensionality is increased. This visualization is crucial for understanding the geometric properties of decision boundaries in higher dimensions and how well each model adapts to the data's underlying structure.

The notebook will cover the following steps:
1. Data generation from a 3-dimensional Gaussian distribution, with ±1 labels assigned by a randomly drawn quadratic decision rule.
2. Training both the discriminative model and the QDA model on this 3D data.
3. Visualizing the decision boundaries using `mayavi` to explore the differences between the models in a 3D space.
4. Comparing the models' performances through misclassification error rates across different sample sizes.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.optim.lr_scheduler import StepLR
from mayavi import mlab

def generate_cov_matrix():
    """Draw a random symmetric positive-definite ``dim`` x ``dim`` matrix.

    A lower-triangular factor is sampled from a standard normal, its diagonal
    is overwritten with values drawn as ``torch.rand(dim) + 1``, and the
    factor is multiplied by its own transpose.  The size is read from the
    module-level ``dim``.

    Returns:
        torch.Tensor: the ``(dim, dim)`` product ``factor @ factor.T``.
    """
    factor = torch.randn(dim, dim).tril()
    # A strictly positive diagonal keeps the triangular factor invertible,
    # which is what makes the product below positive definite.
    positive_diag = torch.rand(dim) + 1
    factor.diagonal().copy_(positive_diag)
    return factor @ factor.t()

def generate_data(n):
    """Sample ``n`` Gaussian points and label them with a random quadratic rule.

    The labelling rule is ``sign(x^T A x + B^T x + c)``.  Only the leading
    2x2 block of ``A`` and the first two entries of ``B`` are non-zero, so the
    labels depend on the first two coordinates only.  The size of every
    object is taken from the module-level ``dim``.

    Args:
        n: number of points to draw.

    Returns:
        tuple: ``(x, y, A_True, B_True, c_True)`` where ``x`` is ``(n, dim)``,
        ``y`` holds the signs of the quadratic score, ``A_True`` is
        ``(dim, dim)``, ``B_True`` is ``(dim, 1)`` and ``c_True`` is a scalar
        tensor.
    """
    # Factor of the quadratic term: only its leading 2x2 block is populated,
    # and the two leading diagonal entries are replaced by rand(2) + 1.
    factor = torch.zeros(dim, dim)
    factor[:2, :2] = torch.randn(2, 2)
    factor.diagonal()[:2] = torch.rand(2) + 1.0

    # factor @ factor^T, with everything outside the leading 2x2 block
    # explicitly zeroed.
    A_True = factor @ factor.t()
    A_True[2:, :] = 0
    A_True[:, 2:] = 0

    # Linear term: again only the first two coordinates contribute.
    B_True = torch.zeros(dim, 1)
    B_True[:2] = torch.randn(2, 1)

    # Features come from one Gaussian with mean 0.1 in every coordinate and
    # a freshly drawn random covariance.
    mean = torch.full((dim,), 0.1, dtype=torch.float32)
    covariance = generate_cov_matrix()
    x = MultivariateNormal(mean, covariance).sample((n,))

    # The offset is the negated sample mean of the quadratic-plus-linear
    # score, so the shifted score has zero mean over this sample.
    raw_score = torch.sum((x @ A_True) * x, dim=1) + (x @ B_True).squeeze()
    c_True = -torch.mean(raw_score)
    y = torch.sign(raw_score + c_True)

    return x, y, A_True, B_True, c_True

# Manual QDA Implementation
class ManualQDA:
    """Two-class quadratic discriminant analysis fitted by hand.

    Each class (labels ``+1`` and ``-1``) is modelled by a Gaussian whose
    mean and covariance are estimated from the training points of that class.
    The fitted score is

        ``s(x) = x^T A x + B^T x + c``

    which equals ``log N(x; mu_-1, Sigma_-1) - log N(x; mu_+1, Sigma_+1)``.
    A positive score therefore favours label ``-1``, which is why
    :meth:`predict` returns ``-sign(s(x))``.

    NOTE: no class-prior term is included, i.e. both classes are treated as
    equally likely a priori.

    Attributes are ``None`` until :meth:`fit` is called.
    """

    def __init__(self):
        self.A = None  # (d, d) quadratic coefficients
        self.B = None  # (d,) linear coefficients
        self.c = None  # scalar offset

    def fit(self, x, y):
        """Estimate the per-class Gaussians and store the score coefficients.

        Args:
            x: ``(n, d)`` feature tensor.
            y: ``(n,)`` tensor of labels; rows with ``y == 1`` and
                ``y == -1`` form the two classes.
        """
        class_1 = x[y == 1]
        class_2 = x[y == -1]
        mu_1 = torch.mean(class_1, dim=0)
        mu_2 = torch.mean(class_2, dim=0)
        # torch.cov expects variables in rows, hence the transpose.
        Sigma_1 = torch.cov(class_1.T)
        Sigma_2 = torch.cov(class_2.T)
        Sigma_1_inv = torch.inverse(Sigma_1)
        Sigma_2_inv = torch.inverse(Sigma_2)

        self.A = 0.5 * (Sigma_1_inv - Sigma_2_inv)
        # The means are 1-D, so plain ``@`` already yields vectors/scalars;
        # applying ``.T`` to 1-D tensors is deprecated and emits a warning.
        self.B = Sigma_2_inv @ mu_2 - Sigma_1_inv @ mu_1
        self.c = 0.5 * (
            mu_1 @ Sigma_1_inv @ mu_1
            - mu_2 @ Sigma_2_inv @ mu_2
            + torch.logdet(Sigma_1)
            - torch.logdet(Sigma_2)
        )

    def decision_function(self, x):
        """Return ``x^T A x + B^T x + c`` for every row of the ``(n, d)`` input."""
        quad_form = torch.sum((x @ self.A) * x, dim=1)
        linear_form = torch.matmul(x, self.B)
        return quad_form + linear_form + self.c

    def predict(self, x):
        """Return predicted labels; the sign is flipped because positive scores favour ``-1``."""
        return -torch.sign(self.decision_function(x))

### Discriminative Model Training
def train_discriminative_model(x, y, epochs=300, batch_size=64):
    """Fit a quadratic classifier ``x^T A x + B^T x + C`` with the logistic loss.

    The parameters start at zero and are trained with Adam on shuffled
    mini-batches; the learning rate is multiplied by 0.1 every 25 epochs.
    Parameter shapes are taken from the number of columns of ``x``, so the
    function does not depend on any module-level dimension setting.

    Args:
        x: ``(n, d)`` features (tensor, or anything ``torch.tensor`` accepts).
        y: ``(n,)`` labels in ``{-1, +1}`` (tensor or array-like).
        epochs: number of passes over the data.
        batch_size: mini-batch size.

    Returns:
        tuple: ``(A, B, C, error_history)`` with ``A`` of shape ``(d, d)``,
        ``B`` of shape ``(d, 1)``, ``C`` of shape ``(1,)`` and
        ``error_history`` a list holding the training misclassification rate
        measured on all of ``x`` after each epoch.
    """
    # Convert to PyTorch tensors if not already
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x, dtype=torch.float32)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y, dtype=torch.float32)

    n_features = x.shape[1]
    A = torch.zeros((n_features, n_features), requires_grad=True)
    B = torch.zeros((n_features, 1), requires_grad=True)
    C = torch.zeros(1, requires_grad=True)
    optimizer = torch.optim.Adam([A, B, C], lr=0.01)
    scheduler = StepLR(optimizer, step_size=25, gamma=0.1)  # LR x0.1 every 25 epochs
    error_history = []

    def score(points):
        # Quadratic + linear + constant term, one value per row of ``points``.
        return torch.sum((points @ A) * points, dim=1) + (points @ B + C).squeeze(1)

    # Prepare DataLoader for mini-batch processing
    dataloader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        for batch_x, batch_y in dataloader:
            optimizer.zero_grad()
            # softplus(z) == log(1 + exp(z)) but does not overflow for large z,
            # which the quadratic score can easily produce.
            loss = torch.nn.functional.softplus(-batch_y * score(batch_x)).mean()
            loss.backward()
            optimizer.step()

        scheduler.step()  # Adjust the learning rate according to the scheduler

        # Evaluate misclassification error at each epoch using all data
        with torch.no_grad():
            predictions = torch.sign(score(x))
            error_history.append((predictions != y).float().mean().item())

    return A, B, C, error_history

def evaluate_model(x, y, A, B, C):
    """Return the misclassification rate of the quadratic classifier on ``(x, y)``.

    Predictions are ``sign(x^T A x + B^T x + C)`` computed row by row; the
    result is the fraction of rows whose prediction differs from ``y``,
    returned as a Python float.  No gradients are recorded.
    """
    with torch.no_grad():
        quadratic_part = ((x @ A) * x).sum(dim=1)
        affine_part = (x @ B + C).squeeze()
        predicted = torch.sign(quadratic_part + affine_part)
        mismatches = predicted != y
        return mismatches.float().mean().item()

# Manual QDA Training and Evaluation
def train_and_evaluate_manual_qda(x_train, y_train, x_test, y_test):
    """Fit a fresh ``ManualQDA`` on the training split and score both splits.

    Returns:
        tuple: ``(error_train, error_test)``, each the fraction of points
        whose predicted label differs from the given label.
    """
    model = ManualQDA()
    model.fit(x_train, y_train)

    def error_rate(features, labels):
        # Share of rows where the prediction disagrees with the label.
        return (model.predict(features) != labels).float().mean().item()

    # Training split first, then the held-out split.
    return error_rate(x_train, y_train), error_rate(x_test, y_test)

def _true_boundary_scores(mesh_points, A_True, B_True, c_True):
    """Evaluate the label-generating quadratic ``x^T A x + B^T x + c`` on ``mesh_points``."""
    scores = (torch.sum((mesh_points @ A_True) * mesh_points, dim=1)
              + torch.matmul(mesh_points, B_True).squeeze()
              + c_True)
    return scores.detach().numpy()


def _axis_range(x, column):
    """Return ``(min - 1, max + 1)`` of one feature column as Python floats."""
    return float(x[:, column].min()) - 1, float(x[:, column].max()) + 1


def _plot_boundaries_2d(qda_model, disc_model, x, y, A_True, B_True, c_True, title):
    """Draw QDA, discriminative and true boundaries side by side with matplotlib."""
    x_min, x_max = _axis_range(x, 0)
    y_min, y_max = _axis_range(x, 1)
    # Create a mesh to plot the decision boundaries
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, num=400),
                         np.linspace(y_min, y_max, num=400))
    mesh_points = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)

    # Evaluate the three score functions on all the grid points
    Z_qda = qda_model.decision_function(mesh_points).detach().numpy().reshape(xx.shape)
    Z_disc = disc_model(mesh_points).detach().numpy().reshape(xx.shape)
    Z_true = _true_boundary_scores(mesh_points, A_True, B_True, c_True).reshape(xx.shape)

    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.contourf(xx, yy, Z_qda, levels=[-1, 0, 1], cmap=plt.cm.coolwarm, alpha=0.8)
    plt.contour(xx, yy, Z_qda, levels=[0], colors='k', linestyles='--')
    # The QDA score is positive for label -1, so the point colours are
    # negated to line up with the filled regions.
    plt.scatter(x[:, 0], x[:, 1], c=-y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
    plt.title(f"QDA - {title}")
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')

    plt.subplot(1, 3, 2)
    plt.contourf(xx, yy, Z_disc, levels=[-1, 0, 1], cmap=plt.cm.viridis, alpha=0.5)
    plt.contour(xx, yy, Z_disc, levels=[0], cmap="RdBu_r")
    plt.scatter(x[:, 0], x[:, 1], c=y, cmap=plt.cm.viridis, s=20, edgecolors='k')
    plt.title(f"Discriminative - {title}")
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')

    plt.subplot(1, 3, 3)
    plt.contourf(xx, yy, Z_true, levels=[-1, 0, 1], cmap=plt.cm.Pastel1, alpha=0.8)
    plt.contour(xx, yy, Z_true, levels=[0], colors='k', linestyles='--')
    plt.scatter(x[:, 0], x[:, 1], c=y, cmap=plt.cm.Pastel1, s=20, edgecolors='k')
    plt.title(f"True - {title}")
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')

    plt.tight_layout()
    plt.show()


def _show_boundary_3d(grid, scores, points_label_1, points_label_neg1, colormap, heading):
    """Open one mayavi figure with the zero level set of ``scores`` and the labelled points."""
    xx, yy, zz = grid
    mlab.figure(bgcolor=(1, 1, 1), size=(800, 600), fgcolor=(0, 0, 0))  # Black foreground color
    mlab.clf()  # Clear the current figure
    mlab.contour3d(xx, yy, zz, scores, contours=[0], opacity=0.5, colormap=colormap)
    # Red for label 1
    mlab.points3d(points_label_1[:, 0], points_label_1[:, 1], points_label_1[:, 2],
                  color=(1, 0, 0), scale_factor=0.1)
    # Blue for label -1
    mlab.points3d(points_label_neg1[:, 0], points_label_neg1[:, 1], points_label_neg1[:, 2],
                  color=(0, 0, 1), scale_factor=0.1)
    mlab.title(heading, size=0.5)
    mlab.xlabel("Feature 1")
    mlab.ylabel("Feature 2")
    mlab.zlabel("Feature 3")
    mlab.show()


def _plot_boundaries_3d(qda_model, disc_model, x, y, A_True, B_True, c_True, title):
    """Show QDA, discriminative and true boundaries as three successive mayavi figures."""
    x_min, x_max = _axis_range(x, 0)
    y_min, y_max = _axis_range(x, 1)
    z_min, z_max = _axis_range(x, 2)

    # Generate a 3D grid
    grid = np.meshgrid(np.linspace(x_min, x_max, num=40),
                       np.linspace(y_min, y_max, num=40),
                       np.linspace(z_min, z_max, num=40), indexing='ij')
    shape = grid[0].shape
    mesh_points = torch.tensor(np.vstack([axis.ravel() for axis in grid]).T,
                               dtype=torch.float32)

    # Evaluate the models on all the grid points
    Z_qda = qda_model.decision_function(mesh_points).detach().numpy().reshape(shape)
    Z_disc = disc_model(mesh_points).detach().numpy().reshape(shape)
    Z_true = _true_boundary_scores(mesh_points, A_True, B_True, c_True).reshape(shape)

    # Split the data points by label for the scatter overlay
    points_label_1 = x[y == 1].detach().numpy()
    points_label_neg1 = x[y == -1].detach().numpy()

    for name, scores, colormap in (("QDA", Z_qda, 'cool'),
                                   ("Discriminative", Z_disc, 'viridis'),
                                   ("True", Z_true, 'autumn')):
        _show_boundary_3d(grid, scores, points_label_1, points_label_neg1,
                          colormap, f"{name} - {title}")


def plot_decision_boundaries(qda_model, disc_model, x, y, A_True, B_True,
                             title="Comparison of Decision Boundaries", c_True=None):
    """Plot the QDA, discriminative and true decision boundaries over the data.

    The number of columns of ``x`` selects the renderer: 2 columns produce a
    three-panel matplotlib figure, 3 columns produce three mayavi figures.
    Any other number of columns produces no plot.

    Args:
        qda_model: object exposing ``decision_function(points)``.
        disc_model: callable mapping an ``(m, d)`` tensor to ``m`` scores.
        x: ``(n, d)`` tensor of data points to overlay.
        y: ``(n,)`` tensor of ``+1`` / ``-1`` labels.
        A_True: ``(d, d)`` quadratic term of the label-generating rule.
        B_True: ``(d, 1)`` linear term of the label-generating rule.
        title: text appended to each panel/figure title.
        c_True: constant term of the label-generating rule.  When omitted,
            the module-level ``c_True`` is used so that existing callers
            that rely on that global keep working.

    Raises:
        ValueError: if ``c_True`` is omitted and no module-level ``c_True``
            exists.
    """
    if c_True is None:
        try:
            c_True = globals()["c_True"]
        except KeyError:
            raise ValueError(
                "c_True must be passed explicitly when no module-level c_True is defined"
            ) from None

    # Use the data's own dimensionality; touching a third column only in the
    # 3-D branch keeps 2-D input from failing on x[:, 2].
    n_features = x.shape[1]
    if n_features == 2:
        _plot_boundaries_2d(qda_model, disc_model, x, y, A_True, B_True, c_True, title)
    elif n_features == 3:
        # The figure titles use ``title`` only: it already carries the sample
        # size, whereas the module-level ``sample_sizes`` is the whole list.
        _plot_boundaries_3d(qda_model, disc_model, x, y, A_True, B_True, c_True, title)

################# Simulation #################

# Simulation with multiple seeds: for every sample size, every seed runs five
# independent repetitions; errors are averaged per seed, then across seeds.
dim = 3
sample_sizes = [100, 300, 1000, 3000]
results = []
seeds = [1223, 13242, 13252, 1718, 132152]

for n in sample_sizes:
    avg_errors_disc_train, avg_errors_disc_test, avg_errors_qda_train, avg_errors_qda_test = [], [], [], []

    for seed in seeds:
        torch.manual_seed(seed)
        np.random.seed(seed)
        errors_disc_train, errors_disc_test, errors_qda_train, errors_qda_test = [], [], [], []

        # Five repetitions per seed: a new dataset is generated each time,
        # while the train/test split uses the same random_state.
        for i in range(5):
            x, y, A_True, B_True, c_True = generate_data(n)
            x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)

            # Discriminative model training and evaluation
            A, B, C, error_history = train_discriminative_model(x_train, y_train)
            errors_disc_train.append(evaluate_model(x_train, y_train, A, B, C))
            errors_disc_test.append(evaluate_model(x_test, y_test, A, B, C))

            # QDA training and evaluation
            error_qda_train, error_qda_test = train_and_evaluate_manual_qda(x_train, y_train, x_test, y_test)
            errors_qda_train.append(error_qda_train)
            errors_qda_test.append(error_qda_test)

            # Plot once per sample size: first seed, first repetition.
            if dim in (2, 3) and seed == seeds[0] and i == 0:
                # A fitted QDA object is only needed for plotting, so it is
                # built here instead of on every repetition.
                qda_model = ManualQDA()
                qda_model.fit(x_train, y_train)

                def disc_model_predict(mesh_points):
                    # Uses the current A, B and C; called immediately below.
                    return torch.sum((mesh_points @ A) * mesh_points, dim=1) + (mesh_points @ B + C).squeeze()

                # Call the plot function with all required parameters
                plot_decision_boundaries(qda_model, disc_model_predict, x_train, y_train, A_True, B_True, title=f"Sample Size {n}")

        # One average per seed, taken after all five repetitions.  Appending
        # inside the repetition loop would record running means and
        # over-weight the early repetitions.
        avg_errors_disc_train.append(np.mean(errors_disc_train))
        avg_errors_disc_test.append(np.mean(errors_disc_test))
        avg_errors_qda_train.append(np.mean(errors_qda_train))
        avg_errors_qda_test.append(np.mean(errors_qda_test))

    results.append((n, np.mean(avg_errors_disc_train), np.mean(avg_errors_disc_test),
                    np.mean(avg_errors_qda_train), np.mean(avg_errors_qda_test)))


# Plot for Average Misclassification Error: one curve per (model, split),
# drawn from the matching column of each ``results`` row.
plt.figure(figsize=(8, 5))
plt.title('Average Misclassification Error Comparison')

curve_specs = (
    (1, dict(label='Discriminative Model Train', marker='o', color='blue')),
    (2, dict(label='Discriminative Model Test', linestyle='--', marker='x', color='blue')),
    (3, dict(label='QDA Model Train', marker='o', color='red')),
    (4, dict(label='QDA Model Test', linestyle='--', marker='x', color='red')),
)
for column, line_style in curve_specs:
    plt.plot(sample_sizes, [row[column] for row in results], **line_style)

plt.xlabel('Sample Size')
plt.ylabel('Average Misclassification Error')
plt.legend()
plt.grid(True)
plt.show()

********************************************************************************
         to build the TVTK classes (9.2). This may cause problems.
         Please rebuild TVTK.
********************************************************************************



  self.B = (Sigma_2_inv @ mu_2 - Sigma_1_inv @ mu_1).T
