# Playing around with mixture of experts

**TODO (8/9):** Train the mixture of experts with SGD or EM. Read more about mixture of experts, boosting, and Bayesian model averaging.

## 1. Hyperparams, Imports, Utils

In [112]:
import numpy as np

from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

import plotly.graph_objs as go
from plotly.subplots import make_subplots

import torch
from torch import nn
from torch import optim

In [113]:
N_POINTS = 1000  # target number of points per synthetic dataset
TEST_SIZE = 0.2  # fraction of each dataset held out for testing

RANDOM_SEED = 42
# Seed both RNGs used in this notebook: numpy's global RNG drives data generation
# and bootstrap resampling, torch's drives model parameter initialisation
# (torch.rand in the EM model, nn.Linear in the SGD models).
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Colors for line traces, indexed by trace number in the plotting helpers.
PLOT_COLORS = [
    "#1f77b4", # muted blue
    "#ff7f0e", # safety orange
    "#2ca02c", # cooked asparagus green
    "#d62728", # brick red
    "#9467bd", # muted purple
    "#8c564b", # chestnut brown
    "#e377c2", # raspberry yogurt pink
    "#7f7f7f", # middle gray
    "#bcbd22", # curry yellow-green
    "#17becf" # blue-teal
]

### 1.1 Plotting Utils

In [114]:
def scatter_plot(x, y, title="Scatter plot", xaxis_title="x", yaxis_title="y"):
    """Display a purple scatter of ``y`` against ``x`` and return the figure.

    The figure is shown immediately with ``fig.show()``; it is also returned so
    callers can reuse its traces (e.g. ``fig.data[0]``).
    """
    points = go.Scatter(x=x, y=y, mode="markers", marker_color="purple", name=title)

    fig = go.Figure()
    fig.add_trace(points)
    fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title, width=600)
    fig.show()
    return fig

def scatter_plot_with_line(x_scatter, y_scatter, x_lines, y_lines, title="Scatter plot", xaxis_title="x", yaxis_title="y"):
    """Display training points with one or more fitted curves overlaid.

    Args:
        x_scatter, y_scatter: Training data, drawn as purple markers named "Train Data".
        x_lines, y_lines: Parallel sequences of x / y arrays; curve i is drawn as
            "lines+markers" and named "Fit i".
        title, xaxis_title, yaxis_title: Figure and axis titles.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x_scatter, y=y_scatter, mode="markers", name="Train Data", marker_color="purple"))

    for i, (x_line, y_line) in enumerate(zip(x_lines, y_lines)):
        # Cycle through the palette so more than len(PLOT_COLORS) curves
        # (e.g. a large bootstrap ensemble) do not raise IndexError.
        color = PLOT_COLORS[i % len(PLOT_COLORS)]
        fig.add_trace(go.Scatter(x=x_line, y=y_line, mode="lines+markers", name=f"Fit {i}", marker_color=color))

    fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title, width=600)
    fig.show()

def line_plot_with_markers(x_lines, y_lines, title="Line plot", xaxis_title="x", yaxis_title="y"):
    """Display one "lines+markers" trace per (x, y) pair on a single 600 px-wide figure.

    Args:
        x_lines, y_lines: Parallel sequences of x / y arrays; trace i is named "Points i".
        title, xaxis_title, yaxis_title: Figure and axis titles.
    """
    fig = go.Figure()

    for i, (x_line, y_line) in enumerate(zip(x_lines, y_lines)):
        # Cycle through the palette so more than len(PLOT_COLORS) traces do not raise IndexError.
        color = PLOT_COLORS[i % len(PLOT_COLORS)]
        fig.add_trace(go.Scatter(x=x_line, y=y_line, mode="lines+markers", name=f"Points {i}", marker_color=color))

    fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title, width=600)
    fig.show()

def combined_plot(plots, n_cols=None):
    """Display several plotly traces in one figure, one subplot per trace.

    Args:
        plots: Non-empty sequence of plotly traces (e.g. ``go.Scatter``).
        n_cols: Number of subplot columns. By default up to three traces are laid
            out in a single row, and larger sets in a roughly square grid
            (four traces -> 2 x 2).

    Raises:
        ValueError: If ``plots`` is empty.
    """
    n_plots = len(plots)
    if n_plots == 0:
        raise ValueError("combined_plot needs at least one trace")
    if n_cols is None:
        n_cols = n_plots if n_plots <= 3 else int(np.ceil(np.sqrt(n_plots)))
    n_rows = -(-n_plots // n_cols)  # ceiling division

    fig = make_subplots(rows=n_rows, cols=n_cols)
    for i, plot in enumerate(plots):
        # plotly subplot rows/cols are 1-indexed
        fig.add_trace(plot, row=i // n_cols + 1, col=i % n_cols + 1)

    fig.show()

## 2. Generate Datasets

Dataset references:
- https://scikit-learn.org/stable/datasets/sample_generators.html
- https://scikit-learn.org/stable/datasets/toy_dataset.html

### 2.1 Roof Dataset

In [115]:
def create_roof_data(noise=0.05, n_points=None):
    """Sample a "roof"-shaped 1-D regression dataset made of three line segments.

    - left:   x ~ U(-2, -0.5),    y = x + 1.5
    - center: x ~ U(-0.75, 0.75), y = 0
    - right:  x ~ U(0.5, 2),      y = -x + 1.5

    Each y gets additive N(0, noise^2) noise. The segments' x ranges overlap slightly.

    Args:
        noise: Standard deviation of the Gaussian noise added to y.
        n_points: Total number of points (defaults to ``N_POINTS``).

    Returns:
        (x, y): Shuffled 1-D arrays of length ``n_points``.
    """
    if n_points is None:
        n_points = N_POINTS
    # Split into three near-equal segments. The first `remainder` segments get one
    # extra point so the total is exactly n_points (n_points // 3 each drops points).
    base, remainder = divmod(n_points, 3)
    n_left, n_center, n_right = (base + 1 if i < remainder else base for i in range(3))

    x_left = np.random.uniform(-2, -0.5, size=n_left)
    x_center = np.random.uniform(-0.75, 0.75, size=n_center)
    x_right = np.random.uniform(0.5, 2, size=n_right)

    y_left = x_left + 1.5 + np.random.normal(0, noise, size=x_left.shape)
    y_center = np.random.normal(0, noise, size=x_center.shape)
    y_right = -x_right + 1.5 + np.random.normal(0, noise, size=x_right.shape)

    x = np.concatenate([x_left, x_center, x_right])
    y = np.concatenate([y_left, y_center, y_right])

    # Shuffle so the segments are interleaved rather than stored in order.
    order = np.random.permutation(x.shape[0])
    return x[order], y[order]

x_roof, y_roof = create_roof_data()
# scatter_plot already calls fig.show(); the trailing ";" stops Jupyter from
# rendering the returned figure a second time as the cell output.
scatter_plot(x_roof, y_roof, "Roof Data");

### 2.2 Zig Dataset

In [116]:
def create_zig_data(overlap=0, noise=0.05, n_points=None):
    """Sample a two-piece "zig" dataset: two parallel lines offset by 1 in y.

    The left piece, y = x + 0.5, covers x in [-2 + overlap/2, overlap/2]. The right
    piece, y = x - 0.5, covers x in [-overlap/2, 2 - overlap/2]. Their x ranges
    therefore share an interval of width ``overlap`` when it is positive, and are
    separated by a gap of width ``-overlap`` when it is negative.

    Args:
        overlap: Signed width of the shared x interval between the two pieces.
        noise: Standard deviation of the Gaussian noise added to y.
        n_points: Total number of points (defaults to ``N_POINTS``); any odd point
            goes to the right piece.

    Returns:
        (x, y): Shuffled 1-D arrays of length ``n_points``.
    """
    if n_points is None:
        n_points = N_POINTS
    n_left = n_points // 2
    n_right = n_points - n_left

    x_left = np.random.uniform(-2 + overlap / 2, overlap / 2, size=n_left)
    # The right piece starts at -overlap/2 (not +overlap/2), so the two ranges
    # actually overlap, or leave a gap, by |overlap|.
    x_right = np.random.uniform(-overlap / 2, 2 - overlap / 2, size=n_right)

    y_left = x_left + 0.5 + np.random.normal(0, noise, size=x_left.shape)
    y_right = x_right - 0.5 + np.random.normal(0, noise, size=x_right.shape)

    x = np.concatenate([x_left, x_right])
    y = np.concatenate([y_left, y_right])

    # Shuffle so the two pieces are interleaved rather than stored in order.
    order = np.random.permutation(x.shape[0])
    return x[order], y[order]

x_zig_over, y_zig_over = create_zig_data(overlap=0.5)
scatter_plot(x_zig_over, y_zig_over, "Zig Overlap Data");

x_zig_no_over, y_zig_no_over = create_zig_data(overlap=-0.5)
# scatter_plot already calls fig.show(); the trailing ";" stops Jupyter from
# rendering the returned figure a second time as the cell output.
scatter_plot(x_zig_no_over, y_zig_no_over, "Zig No Overlap Data");

### 2.3 Linear Dataset

In [117]:
def create_linear_data(noise=0.05, n_points=None):
    """Sample a single noisy line y = x + N(0, noise^2) with x ~ U(-2, 2).

    Args:
        noise: Standard deviation of the Gaussian noise added to y.
        n_points: Number of points (defaults to ``N_POINTS``).

    Returns:
        (x, y): 1-D arrays of length ``n_points``.
    """
    if n_points is None:
        n_points = N_POINTS
    x = np.random.uniform(-2, 2, size=n_points)
    y = x + np.random.normal(0, noise, size=x.shape)

    # The uniform draws are already in random order; the permutation is kept so
    # the global numpy RNG stream (and hence later random draws) is unchanged.
    order = np.random.permutation(x.shape[0])
    return x[order], y[order]

# Linear baseline dataset: y = x plus Gaussian noise.
x_linear, y_linear = create_linear_data()
# scatter_plot shows the figure and returns it; the returned figure is kept in `figure`.
figure = scatter_plot(x_linear, y_linear, "Linear Data")

In [118]:
# Side-by-side overview of the four synthetic datasets:
# roof, zig with overlap, zig without overlap, linear.
combined_plot([
    go.Scatter(x=x_roof, y=y_roof, mode="markers", marker_color="purple"),
    go.Scatter(x=x_zig_over, y=y_zig_over, mode="markers", marker_color="purple"),
    go.Scatter(x=x_zig_no_over, y=y_zig_no_over, mode="markers", marker_color="purple"),
    go.Scatter(x=x_linear, y=y_linear, mode="markers", marker_color="purple"),
])

### 2.4 Test Train Splits

In [119]:
# Hold out TEST_SIZE (20%) of every dataset for testing.
# NOTE(review): no random_state is passed, so reproducibility presumably relies on
# numpy's global seed set in the config cell — confirm against sklearn's docs.
x_roof_train, x_roof_test, y_roof_train, y_roof_test = train_test_split(x_roof, y_roof, test_size=TEST_SIZE)
x_zig_over_train, x_zig_over_test, y_zig_over_train, y_zig_over_test = train_test_split(x_zig_over, y_zig_over, test_size=TEST_SIZE)
x_zig_no_over_train, x_zig_no_over_test, y_zig_no_over_train, y_zig_no_over_test = train_test_split(x_zig_no_over, y_zig_no_over, test_size=TEST_SIZE)
x_linear_train, x_linear_test, y_linear_train, y_linear_test = train_test_split(x_linear, y_linear, test_size=TEST_SIZE)

## 3. Training

In [120]:
def train_and_plot_all_data(model):
    """Run ``train_and_plot`` with the same ``model`` on every dataset's split.

    Datasets are visited in order: roof, zig with overlap, zig without overlap, linear.
    """
    splits = (
        (x_roof_train, x_roof_test, y_roof_train, y_roof_test),
        (x_zig_over_train, x_zig_over_test, y_zig_over_train, y_zig_over_test),
        (x_zig_no_over_train, x_zig_no_over_test, y_zig_no_over_train, y_zig_no_over_test),
        (x_linear_train, x_linear_test, y_linear_train, y_linear_test),
    )
    for x_train, x_test, y_train, y_test in splits:
        train_and_plot(model, x_train, x_test, y_train, y_test)

def train_and_plot(model, x_train, x_test, y_train, y_test):
    """Fit ``model`` on the train split, score it on the test split and plot the fit.

    ``model`` must provide ``fit(x, y)``, ``predict(x)``, ``plot()`` (returning lists of
    x / y arrays for the curves to draw) and a ``name`` attribute.

    Returns:
        The model's predictions on ``x_test``.
    """
    model.fit(x_train, y_train)
    predictions = model.predict(x_test)

    # sklearn's signature is mean_squared_error(y_true, y_pred).
    test_mse = round(mean_squared_error(y_test, predictions), 6)

    x_plot, y_plot = model.plot()
    scatter_plot_with_line(x_train, y_train, x_plot, y_plot, title=f"{model.name} with Test MSE: {test_mse}")

    return predictions

## 4. Models

In [121]:
class LinearFunction():
    """A fixed line y = m * x + b."""

    def __init__(self, m, b):
        # m: slope, b: intercept
        self.m, self.b = m, b

    def predict(self, x):
        """Evaluate the line at ``x`` (a scalar or a NumPy array)."""
        slope, intercept = self.m, self.b
        return slope * x + intercept

In [122]:
class LinearRegressor:
    """Adapter exposing sklearn's LinearRegression with the notebook's model interface.

    Inputs are 1-D arrays; they are reshaped to a single feature column for sklearn.
    """

    def __init__(self):
        self.model = LinearRegression()
        self.name = "Linear Regression"

    @staticmethod
    def _as_column(x):
        """Reshape a 1-D array into an (n, 1) feature matrix."""
        return np.expand_dims(x, axis=1)

    def fit(self, x, y):
        self.model.fit(self._as_column(x), y)

    def predict(self, x):
        return self.model.predict(self._as_column(x))

    def plot(self):
        """Return the fitted line evaluated at x = -2 and x = 2, as single-element lists."""
        endpoints = np.array([-2, 2])
        return [endpoints], [self.model.predict(self._as_column(endpoints))]

    def __str__(self):
        return self.name

# Baseline: a single least-squares line per dataset.
train_and_plot_all_data(LinearRegressor())

Bootstrapped linear regression turns out not to be very interesting. Part of the reason is probably that the average of several lines is itself just a line. It might be more interesting to try a quadratic model, where averaging could actually matter.

In [123]:
class BootstrapLinearRegressor:
    """Bagged ensemble of LinearRegression models.

    Each member is fit on a bootstrap resample (sampling with replacement) of the
    training data; predictions are the unweighted mean of the members' predictions.
    """

    def __init__(self, n_regressors=3):
        self.models = [LinearRegression() for _ in range(n_regressors)]
        self.name = f"Bootstrapped Linear Regression with {n_regressors} Regressors"
        self.n_regressors = n_regressors

    def fit(self, x, y):
        n_samples = x.shape[0]
        for regressor in self.models[:self.n_regressors]:
            resample = np.random.randint(0, n_samples, n_samples)
            regressor.fit(np.expand_dims(x[resample], axis=1), y[resample])

    def predict(self, x):
        features = np.expand_dims(x, axis=1)
        # One column per ensemble member, then average across members.
        member_predictions = np.stack(
            [self.models[k].predict(features) for k in range(self.n_regressors)],
            axis=1,
        )
        return np.mean(member_predictions, axis=1)

    def plot(self):
        """Return each member's line evaluated at x = -2 and x = 2."""
        xs = [np.array([-2, 2]) for _ in self.models]
        ys = [member.predict(np.expand_dims(endpoints, axis=1)) for member, endpoints in zip(self.models, xs)]
        return xs, ys

    def __str__(self):
        return self.name

# Bagged ensemble of 3 least-squares lines, each fit on a bootstrap resample.
train_and_plot_all_data(BootstrapLinearRegressor(n_regressors=3))

In [124]:
class MixtureOfLinearExpertsConditionalEM:
    """Mixture of 1-D linear experts with a softmax-linear gate, trained by EM.

    Inputs are "lifted" to rows ``[x, 1]`` so that

    - ``W[k] = (slope_k, intercept_k)`` defines expert k's line ``W[k] @ [x, 1]``;
    - ``V[k]`` defines expert k's gate logit ``V[k] @ [x, 1] / temperature``.

    The prediction is the gate-weighted sum of the expert lines. The E-step treats
    every expert as having unit-variance Gaussian noise.
    """

    def __init__(self, n_experts=3, em_iterations=50, maximization_epochs=500, temperature=1.0, lr=0.001):
        """
        Args:
            n_experts: Number of linear experts.
            em_iterations: Number of E-step / M-step rounds performed by ``fit``.
            maximization_epochs: Adam steps taken in each M-step.
            temperature: Divisor of the gate logits; smaller values give sharper gates.
            lr: Adam learning rate used in every M-step.
        """
        self.name = f"Mixture of Linear Experts Conditional EM with {n_experts} Experts"
        self.n_experts = n_experts
        self.em_iterations = em_iterations
        self.maximization_epochs = maximization_epochs
        self.temperature = temperature
        self.lr = lr
        self._init_params()

    def _init_params(self):
        """Draw fresh gate (V) and expert (W) parameters uniformly from [0, 1)."""
        self.V = torch.rand(self.n_experts, 2, requires_grad=True)
        self.W = torch.rand(self.n_experts, 2, requires_grad=True)

    @staticmethod
    def _lift(x):
        """Turn a 1-D array of inputs into a float32 tensor of rows ``[x_i, 1]``."""
        x = torch.as_tensor(x, dtype=torch.float32)
        return torch.stack([x, torch.ones_like(x)], dim=1)

    def _log_gate(self, x_lift):
        """Log gate probabilities ``log g_k(x)``, one column per expert."""
        return torch.log_softmax(x_lift @ self.V.T / self.temperature, dim=1)

    def fit(self, x, y):
        """Fit the experts and the gate to 1-D arrays ``x`` and ``y`` with EM, then plot the loss curve.

        Parameters are re-drawn first, so a model reused across datasets starts fresh.
        The plotted value per EM iteration is the M-step objective after its last Adam step.
        """
        self._init_params()

        x_lift = self._lift(x)
        y_col = torch.as_tensor(y, dtype=torch.float32).unsqueeze(1)

        losses = []
        for _ in range(self.em_iterations):
            # ===Expectation===
            # Responsibilities gamma[i, k] ∝ g_k(x_i) * exp(-(y_i - f_k(x_i))^2 / 2).
            # They are normalised in log space so large residuals cannot underflow to 0/0.
            with torch.no_grad():
                residual_sq = (y_col - x_lift @ self.W.T) ** 2
                gamma = torch.softmax(self._log_gate(x_lift) - 0.5 * residual_sq, dim=1)

            # ===Maximization===
            # Minimise the expected complete-data negative log-likelihood
            #   sum_ik gamma_ik * (0.5 * (y_i - f_k(x_i))^2 - log g_k(x_i)).
            # The gate term must be *minus* log g: minimising +log g would push each
            # gate away from the experts responsible for that point.
            optimizer = optim.Adam([self.W, self.V], lr=self.lr)
            loss = None
            for _ in range(self.maximization_epochs):
                optimizer.zero_grad()
                residual_sq = (y_col - x_lift @ self.W.T) ** 2
                loss = torch.sum(gamma * (0.5 * residual_sq - self._log_gate(x_lift)))
                loss.backward()
                optimizer.step()

            if loss is not None:
                losses.append(loss.item())

        line_plot_with_markers([list(range(len(losses)))], [losses], title="Loss Curve")

    def forward_weighting(self, x):
        """Gate weights for 1-D inputs ``x``, as a NumPy array of shape (len(x), n_experts)."""
        with torch.no_grad():
            return torch.softmax(self._lift(x) @ self.V.T / self.temperature, dim=1).numpy()

    def predict(self, x):
        """Gate-weighted mixture prediction for 1-D inputs ``x``, as a NumPy array."""
        with torch.no_grad():
            x_lift = self._lift(x)
            gate = torch.softmax(x_lift @ self.V.T / self.temperature, dim=1)
            return torch.sum(gate * (x_lift @ self.W.T), dim=1).numpy()

    def plot(self):
        """Plot the gate weights over [-2, 2] and return the curves to overlay on the data.

        Returns:
            (xs, ys): one entry per expert (its line at x = -2 and 2), followed by the
            mixture prediction on 100 evenly spaced points in [-2, 2].
        """
        weights = self.W.detach().numpy()
        linear_experts = [LinearFunction(weights[i, 0], weights[i, 1]) for i in range(self.n_experts)]

        x_grid = np.linspace(-2, 2, 100)
        gate = self.forward_weighting(x_grid)
        line_plot_with_markers([x_grid for _ in range(self.n_experts)], [gate[:, i] for i in range(self.n_experts)], title="Weighting Function")

        xs = [np.array([-2, 2]) for _ in linear_experts]
        ys = [expert.predict(endpoints) for expert, endpoints in zip(linear_experts, xs)]
        xs.append(x_grid)
        ys.append(self.predict(x_grid))
        return xs, ys

    def __str__(self):
        return self.name

# EM-trained mixture with default settings: 3 experts, 50 EM iterations of 500 Adam
# steps each, gate temperature 1.0.
train_and_plot_all_data(MixtureOfLinearExpertsConditionalEM())

In [125]:
class MixtureOfLinearExpertsConditionalModel(nn.Module):
    """Mixture of 1-D linear experts with a temperature-scaled softmax gate.

    ``linear1`` holds the experts (output column k is expert k's line) and
    ``linear2`` produces the gate logits. Inputs are a (batch, 1) tensor.
    """

    def __init__(self, n_experts=3, temperature=1.0):
        super().__init__()
        self.linear1 = nn.Linear(1, n_experts)  # expert slopes / intercepts
        self.linear2 = nn.Linear(1, n_experts)  # gate logits
        self.n_experts = n_experts
        self.temperature = temperature

    def forward(self, x):
        """Gate-weighted sum of the expert outputs, shape (batch,)."""
        expert_outputs = self.linear1(x)
        return torch.sum(expert_outputs * self.forward_weighting(x), dim=1)

    def forward_weighting(self, x):
        """Gate weights, shape (batch, n_experts); each row sums to 1.

        Applies the same temperature as ``forward``, so the plotted weighting is
        the one the model actually uses.
        """
        return torch.softmax(self.linear2(x) / self.temperature, dim=1)
    

class MixtureOfLinearExpertsConditionalSGD():
    """Full-batch Adam trainer wrapping a mixture-of-linear-experts ``nn.Module``.

    The wrapped model must expose ``linear1`` (the expert lines), ``forward_weighting``
    (the gate weights), ``n_experts`` and ``temperature``.
    """

    def __init__(self, model=None, epochs=1000, lr=0.001):
        """
        Args:
            model: The module to train. When omitted, a fresh
                ``MixtureOfLinearExpertsConditionalModel()`` is created per wrapper;
                a default-argument instance would be built once and shared by all wrappers.
            epochs: Number of full-batch Adam steps in ``fit``.
            lr: Adam learning rate.
        """
        self.model = model if model is not None else MixtureOfLinearExpertsConditionalModel()
        # Detect the Gumbel variant by class name so this cell does not depend on a
        # class that is only defined in a later cell (Restart & Run All safe).
        softmax_kind = "Gumbel" if type(self.model).__name__.endswith("Gumbel") else "Regular"
        self.name = f"Mixture of Linear Experts Conditional SGD with {self.model.n_experts} Experts and {softmax_kind} Softmax Temperature {self.model.temperature}"
        self.n_experts = self.model.n_experts
        self.epochs = epochs
        self.lr = lr
        self.loss_fn = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    @staticmethod
    def _to_column(x):
        """Convert a 1-D NumPy array into a float32 (n, 1) tensor."""
        return torch.from_numpy(np.expand_dims(x, axis=1)).float()

    def fit(self, x, y):
        """Train on 1-D arrays ``x`` and ``y`` with MSE loss, then plot the loss curve.

        Model parameters and optimizer state are re-initialised first (as the EM model
        does), so a wrapper reused across datasets does not warm-start from the previous fit.
        """
        for module in self.model.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        # The inputs never change during training, so convert them once.
        x_tensor = self._to_column(x)
        y_tensor = torch.as_tensor(y, dtype=torch.float32)

        losses = []
        for _ in range(self.epochs):
            self.optimizer.zero_grad()
            loss = self.loss_fn(self.model(x_tensor), y_tensor)
            losses.append(loss.item())
            loss.backward()
            self.optimizer.step()

        line_plot_with_markers([list(range(len(losses)))], [losses], title="Loss Curve")

    def predict(self, x):
        """Return the model's predictions for 1-D inputs ``x`` as a NumPy array."""
        with torch.no_grad():
            return self.model(self._to_column(x)).numpy()

    def plot(self):
        """Plot the gate weights over [-2, 2] and return the curves to overlay on the data.

        Returns:
            (xs, ys): one entry per expert (its line at x = -2 and 2), followed by the
            model's prediction on 100 evenly spaced points in [-2, 2].
        """
        slopes = self.model.linear1.weight.detach().numpy().ravel()
        intercepts = self.model.linear1.bias.detach().numpy().ravel()
        linear_experts = [LinearFunction(slopes[i], intercepts[i]) for i in range(self.n_experts)]

        x_grid = np.linspace(-2, 2, 100)
        with torch.no_grad():
            gate = self.model.forward_weighting(self._to_column(x_grid)).numpy()
        line_plot_with_markers([x_grid for _ in range(self.n_experts)], [gate[:, i] for i in range(self.n_experts)], title="Weighting Function")

        xs = [np.array([-2, 2]) for _ in linear_experts]
        ys = [expert.predict(endpoints) for expert, endpoints in zip(linear_experts, xs)]
        xs.append(x_grid)
        ys.append(self.predict(x_grid))
        return xs, ys

    def __str__(self):
        return self.name

# Gradient-trained mixture with a softmax gate. Temperature 0.1 divides the gate
# logits by 0.1, sharpening the gate toward a single expert per input.
train_and_plot_all_data(MixtureOfLinearExpertsConditionalSGD(model=MixtureOfLinearExpertsConditionalModel(n_experts=3, temperature=0.1), epochs=10000, lr=0.01))

In [129]:
class MixtureOfLinearExpertsConditionalModelGumbel(nn.Module):
    """Mixture of 1-D linear experts whose gate is a hard Gumbel-softmax sample.

    Every call to ``forward_weighting`` (and therefore ``forward``) draws a fresh
    one-hot expert choice per input row, so outputs are stochastic. Gradients flow
    through the soft relaxation (``hard=True`` straight-through estimator).
    """

    def __init__(self, n_experts=3, temperature=1.0):
        super().__init__()
        # Creation order matters: it fixes which random initial weights each layer gets.
        self.linear1 = nn.Linear(1, n_experts)  # expert slopes / intercepts
        self.linear2 = nn.Linear(1, n_experts)  # gate logits
        self.temperature = temperature          # Gumbel-softmax tau
        self.n_experts = n_experts

    def forward(self, x):
        """Output of the sampled expert for each row, shape (batch,)."""
        selection = self.forward_weighting(x)
        expert_outputs = self.linear1(x)
        return (expert_outputs * selection).sum(dim=1)

    def forward_weighting(self, x):
        """Sample one-hot gate weights of shape (batch, n_experts)."""
        logits = self.linear2(x)
        return nn.functional.gumbel_softmax(logits, tau=self.temperature, hard=True)

# Same trainer with a Gumbel-softmax gate (hard=True): each forward pass samples a
# one-hot expert choice per input, so predictions are stochastic.
train_and_plot_all_data(MixtureOfLinearExpertsConditionalSGD(model=MixtureOfLinearExpertsConditionalModelGumbel(n_experts=3, temperature=1.0), epochs=10000, lr=0.01))