# Classification Toy problems

### Imports and utils

#### Imports

In [None]:
# %matplotlib inline
# %config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import pandas as pd

import numpy as np
import seaborn as sns
import warnings
from icecream import ic

warnings.filterwarnings('ignore')
pd.options.display.float_format = '{:,.2f}'.format
pd.set_option('display.max_rows', 100)
pd.set_option('display.max_columns', 200)

from datetime import datetime
from matplotlib.colors import ListedColormap
from sklearn.datasets import make_classification, make_moons, make_circles
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.linear_model import LogisticRegression
import torch
from torch import tensor, Tensor
from torch.nn import Sequential, Linear, BCELoss, Sigmoid, Tanh, Module, CrossEntropyLoss, ReLU, Softmax, NLLLoss, KLDivLoss, L1Loss, MSELoss, BCEWithLogitsLoss
import torch.nn as nn
from torch.optim import Adam, SGD, Adagrad, Adadelta, AdamW, RMSprop, Optimizer
from torch.nn.functional import softmax, sigmoid
from copy import deepcopy


In [None]:
class History:
    """Minimal Keras-style container for per-epoch training records.

    Instances start empty; callers fill ``history`` (metric name -> list of
    values) and ``epoch`` (list of epoch indices) directly.
    """

    def __init__(self) -> None:
        # Fresh, independent containers for every instance.
        self.history, self.epoch = {}, []

def model_forward_func(model: Module):
    """Wrap ``model`` as a NumPy-in / NumPy-out probability function.

    The returned callable converts its argument to a float32 tensor, runs the
    model, applies a sigmoid to the raw outputs and returns the squeezed
    result as a NumPy array (detached from the autograd graph).
    """
    def predict_proba(x):
        inputs = torch.tensor(x, dtype=torch.float32)
        probabilities = sigmoid(model(inputs))
        return probabilities.detach().squeeze().numpy()

    return predict_proba
        
def to_categorical(x, num_classes=None):
    """Convert a vector of integer class labels into a one-hot matrix.

    Args:
        x: Array-like of class values (converted to int64). A trailing
            dimension of size 1 is dropped, so shape ``(n, 1)`` behaves
            like shape ``(n,)``.
        num_classes: Total number of classes. When falsy (``None`` or 0) it
            is inferred as ``max(x) + 1``.

    Returns:
        A float NumPy array with the class axis placed last, i.e. of shape
        ``x.shape + (num_classes,)`` (after dropping a trailing size-1 axis).

    Example:

    >>> to_categorical([0, 1, 2, 3], num_classes=4)
    array([[1., 0., 0., 0.],
           [0., 1., 0., 0.],
           [0., 0., 1., 0.],
           [0., 0., 0., 1.]])
    """
    labels = np.array(x, dtype="int64")
    leading_shape = labels.shape

    # Treat (..., 1) inputs as (...) so the class axis replaces the dummy axis.
    if leading_shape and len(leading_shape) > 1 and leading_shape[-1] == 1:
        leading_shape = tuple(leading_shape[:-1])

    flat_labels = labels.reshape(-1)
    if not num_classes:
        num_classes = np.max(flat_labels) + 1

    # Row i of the identity matrix is the one-hot encoding of class i.
    one_hot = np.eye(num_classes)[flat_labels]
    return np.reshape(one_hot, leading_shape + (num_classes,))

#### Plotting Utils

In [None]:
from matplotlib.axes import Axes

def plot_decision_boundary(func, X, y, figsize=(7, 5), ax: Axes = None):
    """Draw the output of ``func`` over a 2-D grid with the data on top.

    Args:
        func: Callable taking an ``(n, 2)`` array of grid points and returning
            one value per point (plotted as ``P(y = 1)``).
        X: 2-column feature array; its per-column min/max (padded by 0.1)
            define the plotted region.
        y: Per-sample values used to colour the scatter points.
        figsize: Size of the figure created when ``ax`` is ``None``.
        ax: Existing axes to draw on; a new figure is created if omitted.

    Returns:
        The axes that were drawn on.
    """
    pad = 0.1
    x_lo, y_lo = X.min(axis=0) - pad
    x_hi, y_hi = X.max(axis=0) + pad

    # 101 x 101 evaluation grid covering the padded data range.
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_lo, x_hi, 101),
        np.linspace(y_lo, y_hi, 101),
    )
    grid_points = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    surface = func(grid_points).reshape(grid_x.shape)

    if ax is None:
        ax = plt.figure(figsize=figsize).add_subplot(111)

    filled = ax.contourf(grid_x, grid_y, surface, cmap=plt.colormaps['RdBu'], alpha=0.8)

    colorbar = plt.colorbar(filled, ax=ax)
    colorbar.set_label("$P(y = 1)$")
    colorbar.set_ticks([0, 0.25, 0.5, 0.75, 1])

    ax.scatter(X[:, 0], X[:, 1], c=y, cmap=ListedColormap(['#FF0000', '#0000FF']))
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.set_title("Decision Boundary")
    return ax

def plot_multiclass_decision_boundary(func, X, y, ax: Axes = None):
    """Draw the class regions predicted by ``func`` with the data on top.

    Args:
        func: Callable taking an ``(n, 2)`` array of grid points and returning
            one value (class label) per point.
        X: Feature array; columns 0 and 1 are plotted and define the region
            (their min/max padded by 0.1).
        y: Per-sample labels used to colour the scatter points.
        ax: Existing axes to draw on; a 6x6 figure is created if omitted.

    Returns:
        The axes that were drawn on.
    """
    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101), np.linspace(y_min, y_max, 101))

    # Evaluate the classifier on every grid point, then fold back to the grid.
    Z = func(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    if ax is None:
        fig = plt.figure(figsize=(6, 6))
        ax = fig.add_subplot(1, 1, 1)
    ax.contourf(xx, yy, Z, cmap=plt.colormaps['Spectral'], alpha=0.8)
    ax.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.colormaps['RdYlBu'])
    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    ax.set_title("Decision Boundaries")
    return ax
    
def plot_data(X, y, figsize=(6, 4), ax: Axes = None):
    """Scatter the two classes of a 2-D binary dataset.

    Samples with ``y == 0`` are drawn as red dots and samples with ``y == 1``
    as blue dots; axis limits are the data range padded by 0.1.

    Returns:
        The axes that were drawn on (a new one if ``ax`` is ``None``).
    """
    if ax is None:
        ax = plt.figure(figsize=figsize).add_subplot(1, 1, 1)

    for label, marker in ((0, 'or'), (1, 'ob')):
        members = y == label
        ax.plot(X[members, 0], X[members, 1], marker, alpha=0.5, label=label)

    first, second = X[:, 0], X[:, 1]
    ax.set_xlim((min(first) - 0.1, max(first) + 0.1))
    ax.set_ylim((min(second) - 0.1, max(second) + 0.1))
    ax.legend()
    return ax



def plot_confusion_matrix(y_pred, y, ax: Axes = None):
    """Render the confusion matrix of ``y`` (true) vs ``y_pred`` as a heatmap.

    Returns:
        The axes that were drawn on (a new 6x4 figure if ``ax`` is ``None``).
    """
    if ax is None:
        ax = plt.figure(figsize=(6, 4)).add_subplot(1, 1, 1)

    # Rows are true labels, columns are predictions.
    counts = pd.DataFrame(confusion_matrix(y, y_pred))
    sns.heatmap(counts, annot=True, fmt='d', cmap='YlGnBu', alpha=0.8, vmin=0, ax=ax)
    ax.set_title("Confusion Matrix")
    return ax

In [None]:
def plot_metrics(values: dict, y_keys: list, x_keys: list, ax: Axes = None):
    """Plot several recorded series, each against its own x-axis series.

    Args:
        values: Mapping from series name to a sequence of numbers.
        y_keys: Names of the series to plot on the y-axis.
        x_keys: Names of the matching x-axis series (same length as
            ``y_keys``; pair ``i`` plots ``y_keys[i]`` against ``x_keys[i]``).
        ax: Existing axes to draw on; a 6x4 figure is created if omitted.

    Returns:
        The axes that were drawn on. The title lists the last value of every
        y series rounded to 3 decimals.
    """
    assert len(y_keys) == len(x_keys)

    known = values.keys()
    assert all(xk in known for xk in x_keys), f"x_keys: {x_keys} not in values.keys(): {values.keys()}"
    assert all(yk in known for yk in y_keys), f"y_keys: {y_keys} not in values.keys(): {values.keys()}"

    pairs = list(zip(y_keys, x_keys))
    for y_name, x_name in pairs:
        assert len(values[y_name]) == len(values[x_name])

    if ax is None:
        ax = plt.figure(figsize=(6, 4)).add_subplot(1, 1, 1)

    for y_name, x_name in pairs:
        ax.plot(values[x_name], values[y_name], linewidth=2, markersize=12, alpha=0.7, label=y_name)

    ax.set_xlabel(str(x_keys))
    latest = []
    for y_name in y_keys:
        final_value = round(values[y_name][-1], 3)
        latest.append(y_name + ': ' + str(final_value))
    metric_str = str(latest)
    ax.set_title(f"Metrics: {metric_str}")
    return ax

def plot_loss(history: dict, ax: Axes = None):
    """Plot the loss curve stored in a trainer history dict.

    Args:
        history: Dict with at least the keys ``'loss'`` and ``'epoch'``.
        ax: Existing axes to draw on; a 6x4 figure is created if omitted.

    Returns:
        The axes that were drawn on; the title shows the final loss.
    """
    assert 'loss' in history, "Loss is not in history"
    assert 'epoch' in history, "Epoch is not in history"

    epochs, losses = history['epoch'], history['loss']

    if ax is None:
        ax = plt.figure(figsize=(6, 4)).add_subplot(111)

    ax.plot(epochs, losses, label='loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    # Keep at least the [0, 1] range visible, more if the loss exceeds 1.
    ax.set_ylim(0, max(1, max(losses)))
    ax.legend()
    ax.set_title('Loss: %.3f' % losses[-1])
    return ax

def plot_loss_accuracy(history: dict, ax: Axes = None):
    """Plot loss and accuracy curves from a trainer history dict.

    Args:
        history: Dict with at least the keys ``'loss'``, ``'acc'`` and
            ``'epoch'``.
        ax: Existing axes to draw on; a 6x4 figure is created if omitted.

    Returns:
        The axes that were drawn on; the title shows the final loss and
        accuracy.
    """
    assert 'loss' in history, "Loss is not in history"
    assert 'acc' in history, "Acc is not in history"
    assert 'epoch' in history, "Epoch is not in history"

    epochs = history['epoch']
    losses = history['loss']
    accuracies = history['acc']

    if ax is None:
        ax = plt.figure(figsize=(6, 4)).add_subplot(111)

    for series, name in ((losses, 'loss'), (accuracies, 'acc')):
        ax.plot(epochs, series, label=name)
    ax.set_xlabel('Epochs')
    # Keep at least the [0, 1] range visible, more if the loss exceeds 1.
    ax.set_ylim(0, max(1, max(losses)))
    ax.legend()

    ax.set_title('Loss: %.3f, Accuracy: %.3f' % (losses[-1], accuracies[-1]))
    return ax

def plot_compare_histories(history_list: list[History], name_list, plot_accuracy=True):
    """Compare the training curves of several runs in one figure.

    Args:
        history_list: ``History`` objects; metrics whose name starts with
            ``'val_'`` are ignored.
        name_list: One display name per history, in the same order.
        plot_accuracy: Also draw the ``'acc'`` curves in a second panel.

    The column layout of the first history is assumed for all of them.
    """
    # FIXME (from original): remove history item
    def _training_frame(record):
        kept = {}
        for metric, series in record.history.items():
            if metric.startswith('val_'):
                continue
            kept[metric] = series
        return pd.DataFrame(kept, index=record.epoch)

    frames = [_training_frame(record) for record in history_list]
    combined = pd.concat(frames, axis=1)
    # Label every column with (model name, metric name).
    combined.columns = pd.MultiIndex.from_product(
        [name_list, frames[0].columns], names=['model', 'metric']
    )

    fig = plt.figure(figsize=(6, 8))

    loss_ax = fig.add_subplot(211)
    combined.xs('loss', axis=1, level='metric').plot(ylim=(0,1), ax=loss_ax)
    loss_ax.set_title("Loss")

    if plot_accuracy:
        acc_ax = fig.add_subplot(212)
        combined.xs('acc', axis=1, level='metric').plot(ylim=(0,1), ax=acc_ax)
        acc_ax.set_title("Accuracy")
        acc_ax.set_xlabel("Epochs")

    plt.tight_layout()


#### Synthetic data

In [None]:
def make_sine_wave():
    """Build a noisy sine-wave toy dataset with alternating class segments.

    2400 points are sampled along three periods of a sine curve, pushed away
    from the x-axis by noise, and stretched horizontally. The samples are cut
    into 12 consecutive segments of 200 points; odd-numbered segments
    (1, 3, 5, ...) are labelled 1 and even-numbered ones 0.

    Note: reseeds NumPy's global RNG with 0, so the output is deterministic.

    Returns:
        ``(X, y)`` with ``X`` of shape ``(2400, 2)`` and integer labels ``y``
        of shape ``(2400,)``.
    """
    c = 3
    num = 2400
    step = num/(c*4)  # samples per segment (200.0)
    np.random.seed(0)
    x0 = np.linspace(-c*np.pi, c*np.pi, num)
    x1 = np.sin(x0)
    noise = np.random.normal(0, 0.1, num) + 0.1
    # Apply the noise in the direction of the curve's sign (away from zero).
    noise = np.sign(x1) * np.abs(noise)
    x1 = x1 + noise
    sample_idx = np.arange(num)
    x0 = x0 + (sample_idx / step) * 0.3
    X = np.column_stack((x0, x1))
    # Floor division yields the segment number. The previous true division
    # only matched at exact odd multiples of `step`, labelling just 6 samples.
    y = ((sample_idx // step) % 2 == 1).astype(int)
    return X, y

def make_multiclass(N=500, D=2, K=3, noise=0.2, shuffle= False, shuffle_seed=42, return_idx =False,plot= False):
    """Generate a K-armed spiral classification dataset.

    Args:
        N: number of points per class
        D: dimensionality
        K: number of classes
        noise: standard deviation of the Gaussian noise added to the angle
        shuffle: permute the samples (seeded with ``shuffle_seed``)
        shuffle_seed: seed used for the permutation only
        return_idx: with ``shuffle=True``, also return the permutation
        plot: scatter-plot the generated points

    Note: reseeds NumPy's global RNG (with 0 for generation and with
    ``shuffle_seed`` for the permutation).

    Returns:
        ``(X, y)``, or ``(X, y, idx)`` when both ``shuffle`` and
        ``return_idx`` are true. ``y`` holds the class index as floats.
    """
    np.random.seed(0)
    total = N * K
    X = np.zeros((total, D))
    y = np.zeros(total)

    # The radius profile is the same for every arm.
    radius = np.linspace(0.0, 1, N)
    for label in range(K):
        rows = slice(N * label, N * (label + 1))
        theta = np.linspace(label * 4, (label + 1) * 4, N) + np.random.randn(N) * noise
        X[rows] = np.column_stack((radius * np.sin(theta), radius * np.cos(theta)))
        y[rows] = label

    if shuffle:
        np.random.seed(shuffle_seed)
        idx = np.random.permutation(total)
        X, y = X[idx], y[idx]

    if plot:
        ax = plt.figure(figsize=(6, 6)).add_subplot(111)
        ax.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.colormaps['RdYlBu'], alpha=0.8)
        ax.set_xlim(-1,1)
        ax.set_ylim(-1,1)

    if not (shuffle and return_idx):
        return X, y
    return X, y, idx

### Simple Classification dataset

In [None]:
def get_classification_data(plot=False, n_samples=1000, n_features=2, n_redundant=0, n_informative=2, random_state=7, n_clusters_per_class=1):
    """Create a scikit-learn ``make_classification`` dataset plus tensor copies.

    Returns:
        ``(X, y, X_tensor, y_tensor)`` where the tensors are float32 and
        ``y_tensor`` is a column of shape ``(n, 1)``.
    """
    generator_kwargs = dict(
        n_samples=n_samples,
        n_features=n_features,
        n_redundant=n_redundant,
        n_informative=n_informative,
        random_state=random_state,
        n_clusters_per_class=n_clusters_per_class,
    )
    X, y = make_classification(**generator_kwargs)

    features = torch.tensor(X, dtype=torch.float32)
    targets = torch.tensor(y, dtype=torch.float32).view(-1, 1)
    if plot:
        plot_data(X, y)
    return X, y, features, targets

### Simple Sklearn Logistic Regression

In [None]:
def simple_sklearn_logistic_regression(X, y):
    """Fit scikit-learn logistic regression and draw its decision line.

    Prints the fitted coefficients and intercept, scatter-plots the data and
    overlays the line ``w0*x0 + w1*x1 + b = 0`` between ``x0 = -2`` and
    ``x0 = 2`` on the current axes.
    """
    clf = LogisticRegression()
    clf.fit(X, y)
    print('LR coefficients:', clf.coef_)
    print('LR intercept:', clf.intercept_)

    plot_data(X, y)

    w0, w1 = clf.coef_[0][0], clf.coef_[0][1]
    bias = clf.intercept_[0]
    x0_range = np.array([-2, 2])
    # Solve w0*x0 + w1*x1 + bias = 0 for x1.
    x1_on_boundary = -(w0 * x0_range + bias) / w1
    plt.plot(x0_range, x1_on_boundary, "g-", linewidth=2)

### Trainers

In [None]:
def compute_binary_accuracy(y_true: Tensor, y_pred: Tensor):
    """Return the fraction of samples whose logit sign matches the label.

    Both tensors must have shape ``(n, 1)``. ``y_pred`` is treated as raw
    logits: a value greater than 0 counts as a positive prediction.
    """
    assert y_true.shape == y_pred.shape, f"y_true: {y_true.shape} != y_pred: {y_pred.shape}"
    assert y_true.shape[1] == 1, f"y_true: {y_true.shape} != (n, 1)"
    assert y_pred.shape[1] == 1, f"y_pred: {y_pred.shape} != (n, 1)"

    predicted_positive = y_pred > 0
    hits = y_true == predicted_positive
    n_correct = hits.sum().item()
    return n_correct / len(y_true)

def compute_multiclass_accuracy(y_true: Tensor, y_pred: Tensor):
    """Return the fraction of rows whose arg-max matches between both tensors.

    Both tensors must share the shape ``(n, k)`` with ``k > 1``; ``y_true``
    is typically one-hot and ``y_pred`` holds per-class scores.
    """
    assert y_true.shape == y_pred.shape, f"y_true: {y_true.shape} != y_pred: {y_pred.shape}"
    assert y_true.shape[1] > 1, f"y_true: {y_true.shape} != (n, >1)"
    assert y_pred.shape[1] > 1, f"y_pred: {y_pred.shape} != (n, >1)"

    true_labels = y_true.argmax(dim=1)
    predicted_labels = y_pred.argmax(dim=1)
    n_correct = (predicted_labels == true_labels).sum().item()
    return n_correct / len(true_labels)


In [None]:
def batch_trainer(model: Module, X_tensor: Tensor, y_tensor: Tensor, optimizer, criterion, epochs: int, accuracy_func):
    """Train ``model`` with full-batch gradient descent.

    Every epoch performs one forward/backward pass over the whole dataset and
    one optimizer step. The recorded loss and accuracy are computed from the
    predictions made *before* that epoch's parameter update.

    Args:
        model: Network to train (updated in place).
        X_tensor: Input features.
        y_tensor: Targets in the format expected by ``criterion``.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(predictions, targets)``.
        epochs: Number of full-batch updates.
        accuracy_func: Called as ``accuracy_func(y_tensor, predictions)``.

    Returns:
        ``(model, history)`` where ``history`` has the list-valued keys
        ``'loss'``, ``'acc'`` and ``'epoch'``.
    """
    history = {'loss': [], 'acc': [], 'epoch': []}

    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(X_tensor)
        loss = criterion(outputs, y_tensor)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        accuracy = accuracy_func(y_tensor, outputs)
        for name, value in (('loss', loss_value), ('acc', accuracy), ('epoch', epoch)):
            history[name].append(value)

        # Progress line, overwritten in place every 10 epochs.
        if epoch % 10 == 0:
            print('Epoch:', epoch, 'Loss:', loss_value, end='\r')

    return model, history


In [None]:
def minibatch_trainer(model: Module, X_tensor: Tensor, y_tensor: Tensor, optimizer, criterion, epochs: int, accuracy_func, batch_size:int = 256, shuffle=True):
    """Train ``model`` with mini-batch gradient descent.

    Args:
        model: Network to train (updated in place).
        X_tensor: Input features; the first dimension indexes samples.
        y_tensor: Targets in the format expected by ``criterion``.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(predictions, targets)``.
        epochs: Number of passes over the data.
        accuracy_func: Called as ``accuracy_func(y_tensor, predictions)`` on
            the full dataset after every epoch.
        batch_size: Samples per update; the last batch may be smaller.
        shuffle: Permute the samples once before training. The same order is
            reused in every epoch (no per-epoch reshuffle).

    Returns:
        ``(model, history)``. ``history`` holds per-epoch lists ``'loss'``,
        ``'acc'`` and ``'epoch'`` (measured on the full dataset after the
        epoch's updates) and per-update lists ``'loss_int'`` (batch loss) and
        ``'step'`` (fractional epoch position of the update).
    """
    history = {'loss': [], 'acc': [], 'epoch': [], 'step': [], 'loss_int': []}

    n_samples = X_tensor.shape[0]
    # Fraction of an epoch covered by one batch; used as the x-position of
    # the intermediate losses.
    step_len = batch_size / n_samples

    if shuffle:
        idx = torch.randperm(n_samples)
        X_tensor = X_tensor[idx]
        y_tensor = y_tensor[idx]

    for epoch in range(epochs):
        for step, start in enumerate(range(0, n_samples, batch_size)):
            optimizer.zero_grad()
            X_batch = X_tensor[start:start + batch_size]
            y_batch = y_tensor[start:start + batch_size]
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()
            history['loss_int'].append(loss.item())
            history['step'].append(epoch + step_len * step)

        # End-of-epoch evaluation only reads values, so skip building the
        # autograd graph for the full dataset.
        with torch.no_grad():
            y_pred = model(X_tensor)
            loss = criterion(y_pred, y_tensor)
        acc = accuracy_func(y_tensor, y_pred)

        history['loss'].append(loss.item())
        history['acc'].append(acc)
        history['epoch'].append(epoch)
        # Progress line, overwritten in place every 10 epochs.
        if epoch % 10 == 0:
            print('Epoch:', epoch, 'Loss:', loss.item(), end='\r')
    return model, history

### Torch Logistic Regression

In [None]:
def get_logistic_regression_model():
    """Build a 2-feature logistic-regression setup.

    Returns:
        ``(model, criterion, optimizer)``: a single ``Linear(2, 1)`` layer
        that outputs raw logits, a ``BCEWithLogitsLoss`` and an SGD optimizer
        with learning rate 0.1.
    """
    net = Sequential(Linear(2, 1))
    loss_fn = BCEWithLogitsLoss()
    sgd = SGD(net.parameters(), lr=0.1)
    return net, loss_fn, sgd

#### Logistic regression for simple classification data

In [None]:
# Train logistic regression on the simple classification dataset.
X, y, X_tensor, y_tensor = get_classification_data(n_samples=1000)
model, criterion, optimizer = get_logistic_regression_model()
# model, history = batch_trainer(model, X_tensor, y_tensor, optimizer, criterion, 50, compute_binary_accuracy)
model, history = minibatch_trainer(model, X_tensor, y_tensor, optimizer, criterion, 50, compute_binary_accuracy)
# The model emits raw logits (it is trained with BCEWithLogitsLoss), so the
# decision threshold is 0 (i.e. probability 0.5), matching
# compute_binary_accuracy.
y_pred = (model(X_tensor) > 0).detach().numpy()

In [None]:
# One row of three panels: training curves, decision boundary, confusion matrix.
master_fig, axs = plt.subplots(1,3 ,figsize=(16, 4), width_ratios=[6,8,6])
ax1 = plot_loss_accuracy(history, ax=axs[0])
ax2 = plot_decision_boundary(model_forward_func(model), X, y, ax=axs[1])
ax3 = plot_confusion_matrix(y_pred, y, ax=axs[2])
# print(classification_report(y, y_pred))

### Moons

In [None]:
def get_moons_data(plot=False, n_samples=1024, noise=0.05, random_state=0):
    """Create a scikit-learn ``make_moons`` dataset plus tensor copies.

    Returns:
        ``(X, y, X_tensor, y_tensor)`` where the tensors are float32 and
        ``y_tensor`` is a column of shape ``(n, 1)``.
    """
    X, y = make_moons(n_samples=n_samples, noise=noise, random_state=random_state)

    features = torch.tensor(X, dtype=torch.float32)
    targets = torch.tensor(y, dtype=torch.float32).view(-1, 1)
    if plot:
        plot_data(X, y)
    return X, y, features, targets

In [None]:
# Sanity check: load the moons data and display the array / tensor shapes.
X, y, x_tensor, y_tensor = get_moons_data()
X.shape, y.shape, x_tensor.shape, y_tensor.shape

#### LR with moons

In [None]:
# Train logistic regression on the moons dataset (full-batch, 100 epochs).
X, y, X_tensor, y_tensor = get_moons_data()
model, criterion, optimizer = get_logistic_regression_model()
model, history = batch_trainer(model, X_tensor, y_tensor, optimizer, criterion, 100, compute_binary_accuracy)
# The model emits raw logits (it is trained with BCEWithLogitsLoss), so the
# decision threshold is 0 (i.e. probability 0.5), matching
# compute_binary_accuracy.
y_pred = (model(X_tensor) > 0).detach().numpy()

In [None]:
# One row of three panels: training curves, decision boundary, confusion matrix.
master_fig2, axs = plt.subplots(1,3 ,figsize=(16, 4), width_ratios=[6,8,6])
ax1 = plot_loss_accuracy(history, ax=axs[0])
ax2 = plot_decision_boundary(model_forward_func(model), X, y, ax=axs[1])
ax3 = plot_confusion_matrix(y_pred, y, ax=axs[2])

## ANN

### Moons

In [None]:
def get_simple_nn_model():
    """Build a small 2 -> 4 -> 2 -> 1 tanh network that outputs one raw logit."""
    layers = (
        Linear(2, 4),
        Tanh(),
        Linear(4, 2),
        Tanh(),
        Linear(2, 1),
    )
    return Sequential(*layers)

In [None]:
# Train the small tanh network on the moons data (full-batch Adam, 100 epochs).
X, y, X_tensor, y_tensor = get_moons_data()
model = get_simple_nn_model()
criterion = BCEWithLogitsLoss()

optimizer = Adam(model.parameters(), lr=0.1)
model, history = batch_trainer(model, X_tensor, y_tensor, optimizer, criterion, 100, compute_binary_accuracy)
# The model outputs logits, so a positive value means class 1.
y_pred = (model(X_tensor) > 0).detach().numpy()

In [None]:
# One row of three panels: training curves, decision boundary, confusion matrix.
master_fig3, axs = plt.subplots(1,3 ,figsize=(16, 4), width_ratios=[6,8,6])
ax1 = plot_loss_accuracy(history, ax=axs[0])
ax2 = plot_decision_boundary(model_forward_func(model), X, y, ax=axs[1])
ax3 = plot_confusion_matrix(y_pred, y, ax=axs[2])

### Circles

In [None]:
def get_circles_data(plot = False):
    """Create a 1000-sample ``make_circles`` dataset plus tensor copies.

    Returns:
        ``(X, y, X_tensor, y_tensor)`` where the tensors are float32 and
        ``y_tensor`` is a column of shape ``(n, 1)``.
    """
    X, y = make_circles(n_samples=1000, noise=0.05, factor=0.3, random_state=0)

    features = torch.tensor(X, dtype=torch.float32)
    targets = torch.tensor(y, dtype=torch.float32).view(-1, 1)
    if plot:
        plot_data(X, y)
    return X, y, features, targets

In [None]:
# Train the small tanh network on the circles data (full-batch Adam, 100 epochs).
X, y, X_tensor, y_tensor = get_circles_data()
model = get_simple_nn_model()
criterion = BCEWithLogitsLoss()
optimizer = Adam(model.parameters(), lr=0.1)
model, history = batch_trainer(model, X_tensor, y_tensor, optimizer, criterion, 100, compute_binary_accuracy)
# The model outputs logits, so a positive value means class 1.
y_pred = (model(X_tensor) > 0).detach().numpy()

In [None]:
# One row of three panels: training curves, decision boundary, confusion matrix.
# NOTE: rebinds master_fig3, which the moons cell also used.
master_fig3, axs = plt.subplots(1,3 ,figsize=(16, 4), width_ratios=[6,8,6])
ax1 = plot_loss_accuracy(history, ax=axs[0])
ax2 = plot_decision_boundary(model_forward_func(model), X, y, ax=axs[1])
ax3 = plot_confusion_matrix(y_pred, y, ax=axs[2])

## Multiclass

In [None]:
# Build the 3-class spiral dataset; data_size is the number of points PER class.
data_size = 2048
X, y = make_multiclass(K=3, noise=0.2, N = data_size)
# One-hot targets, used as class probabilities by CrossEntropyLoss below.
y_cat = to_categorical(y)
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y_cat, dtype=torch.float32)

In [None]:
# Alternative one-hot encoding via torch; y_t / y_oh are not used by later cells.
y_t = tensor(y, dtype=torch.long)
y_oh = nn.functional.one_hot(y_t)
# Display the tensor shapes.
X_tensor.shape, y_tensor.shape

In [None]:
def mutliclass_forward_func_with_softmax(model: Module):
    """Wrap ``model`` as a NumPy-in / NumPy-out class-label predictor.

    The returned callable converts its argument to a float32 tensor, runs the
    model, applies a softmax over dimension 1 and returns the index of the
    most probable class for every row as a NumPy array.
    """
    def predict_labels(x):
        inputs = tensor(x, dtype=torch.float32)
        probabilities = nn.functional.softmax(model(inputs), dim=1)
        labels = torch.argmax(probabilities, dim=1)
        return labels.detach().squeeze().numpy()

    return predict_labels

### Single layer

In [None]:
# Linear (softmax-regression) baseline: one Linear layer producing 3 logits,
# trained full-batch with SGD for 100 epochs.
model = Sequential(Linear(2, 3))
criterion = CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.1)
model, history = batch_trainer(model, X_tensor, y_tensor, optimizer, criterion, 100, compute_multiclass_accuracy)

In [None]:
# Predict class labels, then plot training curves, decision regions and the
# confusion matrix side by side.
fn = mutliclass_forward_func_with_softmax(model)
y_pred = fn(X_tensor)
master_fig35, axs = plt.subplots(1,3 ,figsize=(16, 4), width_ratios=[6,6,6])
ax1 = plot_loss_accuracy(history, ax=axs[0])
ax2 = plot_multiclass_decision_boundary(fn , X, y, ax=axs[1])
ax3 = plot_confusion_matrix(y_pred, y, ax=axs[2])

### Multi layer

In [None]:
def get_multiclass_model():
    """Build a 2 -> 128 -> 64 -> 32 -> 16 -> 3 tanh network (3 output logits)."""
    widths = [2, 128, 64, 32, 16]
    layers = []
    # Every hidden Linear layer is followed by a Tanh activation.
    for n_in, n_out in zip(widths, widths[1:]):
        layers.append(Linear(n_in, n_out))
        layers.append(Tanh())
    # Output layer: raw class scores, no activation.
    layers.append(Linear(16, 3))
    return Sequential(*layers)

In [None]:
def get_multiclass_model_small():
    """Build a 2 -> 16 -> 8 -> 3 tanh network (3 output logits)."""
    layers = [Linear(2, 16), Tanh()]
    layers.extend([Linear(16, 8), Tanh()])
    # Output layer: raw class scores, no activation.
    layers.append(Linear(8, 3))
    return Sequential(*layers)

In [None]:
def get_multiclass_model_small2(embed_dim=8):
    """Build a one-hidden-layer network: 2 -> ``embed_dim`` -> 3 with tanh."""
    hidden = Linear(2, embed_dim)
    head = Linear(embed_dim, 3)
    return Sequential(hidden, Tanh(), head)

#### Minibatch GD

In [None]:
# Mini-batch training of the one-hidden-layer model (Adam, 200 epochs,
# batches of data_size//16 samples), followed by four diagnostic panels.
# model = get_multiclass_model()
# model = get_multiclass_model_small()
model = get_multiclass_model_small2(embed_dim=16)
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.1)
# optimizer = SGD(model.parameters(), lr=0.1)
# model, history = trainer(model, X_tensor, y_tensor, optimizer, criterion, 200)
model, history = minibatch_trainer(model, X_tensor, y_tensor, optimizer, criterion, 200, batch_size=data_size//16, accuracy_func=compute_multiclass_accuracy)
fn = mutliclass_forward_func_with_softmax(model)
y_pred = fn(X_tensor)
master_fig4, axs = plt.subplots(1,4 ,figsize=(16, 4), width_ratios=[4,4,6,6])
# Panel 0: per-update batch loss against fractional epoch, plus per-epoch accuracy.
ax1 = plot_metrics(history, y_keys=[ 'loss_int', 'acc'],x_keys=[ 'step','epoch'] , ax=axs[0])
ax11 = plot_loss_accuracy(history, ax=axs[1])
ax2 = plot_multiclass_decision_boundary(fn , X, y, ax=axs[2])
ax3 = plot_confusion_matrix(y_pred, y, ax=axs[3])

#### Batch GD

In [None]:
# Full-batch counterpart of the mini-batch run: same architecture, optimizer
# and epoch count, but one update per epoch.
model = get_multiclass_model_small2(embed_dim=16)
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.1)
# optimizer = SGD(model.parameters(), lr=0.1)
model, history = batch_trainer(model, X_tensor, y_tensor, optimizer, criterion, 200, accuracy_func=compute_multiclass_accuracy)
fn = mutliclass_forward_func_with_softmax(model)
y_pred = fn(X_tensor)
master_fig4, axs = plt.subplots(1,3 ,figsize=(16, 4), width_ratios=[6,6,6])
# ax1 = plot_metrics(history, y_keys=[ 'loss_int', 'acc'],x_keys=[ 'step','epoch'] , ax=axs[0])
ax1 = plot_loss_accuracy(history, ax=axs[0])
ax2 = plot_multiclass_decision_boundary(fn , X, y, ax=axs[1])
ax3 = plot_confusion_matrix(y_pred, y, ax=axs[2])

## Visualizing weights

In [None]:
# Parameter name -> parameter tensor for the most recently trained model.
state_dict = dict(model.named_parameters())
n_cols = 4
# Number of grid rows needed for one subplot per parameter (ceiling division).
n_rows = (len(state_dict)-1)//n_cols +1

In [None]:
def sparsify_xlabels(ax: Axes, limit=20):
    """Blank out x tick labels so that roughly ``limit`` remain readable.

    When the axes has more than ``limit`` x tick labels, every
    ``(count // limit + 1)``-th label is kept and the rest are replaced by
    empty strings. Axes with ``limit`` or fewer labels are left untouched.

    Returns:
        The same axes, for chaining.
    """
    labels = ax.get_xticklabels()
    count = len(labels)
    if not count > limit:
        return ax

    stride = count // limit + 1
    thinned = [''] * count
    # Copy back only every `stride`-th label; the others stay blank.
    thinned[::stride] = labels[::stride]
    ax.set_xticklabels(thinned)
    return ax

#### Histogram

In [None]:
# One histogram (with KDE overlay) of the flattened values per parameter tensor.
# fig, axs = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows))
fig = plt.figure(figsize=(5*n_cols, 4*n_rows))
for i, (key, val) in enumerate(state_dict.items()):
    # Subplot indices are 1-based.
    ax = fig.add_subplot(n_rows, n_cols, i+1)
    sns.histplot(val.detach().flatten().numpy(),kde=True, palette='viridis', ax=ax)
    ax.set_title(key)

In [None]:
## TODO:
def plot_weights_barplot(weights: dict[str, Tensor], ax: Axes = None):
    """Draw one bar per parameter tensor summarising its values.

    Every tensor is flattened and its elements are passed to
    ``seaborn.barplot`` grouped by parameter name, so seaborn's default
    aggregation (mean with an error bar) determines the bar height.

    Args:
        weights: Mapping from parameter name to tensor, e.g.
            ``dict(model.named_parameters())``.
        ax: Existing axes to draw on; a 6x4 figure is created if omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        fig = plt.figure(figsize=(6, 4))
        ax = fig.add_subplot(1, 1, 1)

    # Long-form data: one (name, value) pair per tensor element.
    names, values = [], []
    for key, tensor_value in weights.items():
        flat = tensor_value.detach().numpy().flatten()
        names.extend([key] * flat.size)
        values.extend(flat.tolist())

    # barplot takes x / y / ax as keywords; the previous positional call was
    # invalid and never drew on the requested axes.
    sns.barplot(x=names, y=values, ax=ax)
    ax.set_title('Weights')
    return ax

#### Barplot

In [None]:
# One bar chart per parameter tensor: a bar for every flattened element.
# fig, axs = plt.subplots(n_rows, n_cols, figsize=(16, 4*n_rows))
fig = plt.figure(figsize=(5*n_cols, 4*n_rows))
for i, (key, val) in enumerate(state_dict.items()):
    # ax = axs[i//3, i%3]
    ax = fig.add_subplot(n_rows, n_cols, i+1)
    sns.barplot(val.detach().flatten().numpy(), ax=ax, palette='viridis')
    # Thin out the x tick labels when there are more than 20 of them.
    ax = sparsify_xlabels(ax)
    ax.set_title(key)
    ax.set_xlabel('Neuron index')
    ax.set_ylabel('Value')

## Ensembling

In [None]:
# Build one dataset per seed. make_multiclass generates the points with a
# fixed RNG seed (0) and uses shuffle_seed only for the permutation, so every
# seed sees the same samples in a different order.
data_size = 2048
seeds = [0, 1, 2, 3, 4]
data_map = {}
ids = {}
for seed in seeds:
    X, y, idx = make_multiclass(K=3, noise=0.2, N = data_size, shuffle=True, shuffle_seed=seed, return_idx=True)
    # Keep the permutation that was applied for this seed.
    ids[seed] = idx
    y_cat = to_categorical(y)
    X_tensor = torch.tensor(X, dtype=torch.float32)
    y_tensor = torch.tensor(y_cat, dtype=torch.float32)
    data_map[seed] = (X, y, X_tensor, y_tensor)
    

In [None]:
def ensemble_trainer(data_map: dict, model_getter, accuracy_func):
    """Train one independent model per entry of ``data_map``.

    Args:
        data_map: Mapping ``key -> (X, y, X_tensor, y_tensor)``; only the two
            tensors are used for training.
        model_getter: Zero-argument factory returning a fresh model.
        accuracy_func: Accuracy callback forwarded to ``minibatch_trainer``.

    All models are created up front, then each is trained for 200 epochs with
    Adam (lr=0.1), ``CrossEntropyLoss`` and the trainer's default batch size.

    Returns:
        ``(model_map, histories)``, both keyed like ``data_map``.
    """
    model_map = {key: model_getter() for key in data_map.keys()}
    histories = {}

    for key, (_, _, X_tensor, y_tensor) in data_map.items():
        net = model_map[key]
        loss_fn = CrossEntropyLoss()
        adam = Adam(net.parameters(), lr=0.1)
        trained, record = minibatch_trainer(net, X_tensor, y_tensor, adam, loss_fn, 200, accuracy_func=accuracy_func)
        model_map[key] = trained
        histories[key] = record

    return model_map, histories

In [None]:
from functools import partial
# Factory for identical small models (hidden width 8), one per seed.
model_getter = partial(get_multiclass_model_small2, embed_dim=8)
model_map, histories = ensemble_trainer(data_map, model_getter, compute_multiclass_accuracy)

In [None]:
# Inline variant of ensemble_trainer using batches of data_size//16 samples.
# NOTE: this rebinds model_map and histories, replacing the results of the
# previous cell.
model_map = {s: get_multiclass_model_small2(embed_dim=8) for s in seeds}
histories = {}
for s in seeds:
    model = model_map[s]
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.1)
    # optimizer = SGD(model.parameters(), lr=0.1)
    _, _, X_tensor, y_tensor = data_map[s]
    model, history = minibatch_trainer(model, X_tensor, y_tensor, optimizer, criterion, 200, batch_size=data_size//16, accuracy_func=compute_multiclass_accuracy)
    model_map[s] = model 
    histories[s] = history


In [None]:
# Training curves per seed; the 1x5 grid assumes exactly five seeds.
master_fig, axs = plt.subplots(1,5 ,figsize=(20, 4))
for i, s in enumerate(seeds):
    ax = axs[i]
    ax1 = plot_loss_accuracy(histories[s], ax=ax)
    # Replace the loss/accuracy title set by plot_loss_accuracy.
    ax1.set_title(f"Seed: {s}")

In [None]:
# Decision regions of every per-seed model; the 1x5 grid assumes five seeds.
master_fig, axs = plt.subplots(1,5 ,figsize=(20, 4))

for i, s in enumerate(seeds):
    model = model_map[s]
    X, y, _, _ = data_map[s]
    fn = mutliclass_forward_func_with_softmax(model)
    ax = axs[i]
    ax2 = plot_multiclass_decision_boundary(fn , X, y, ax=ax)
    ax2.set_title(f"Seed: {s}")


In [None]:
## Barplots for weights
# Grid layout: one row per seed, one column per parameter tensor. The 5x4
# shape assumes five seeds and four parameter tensors per model.

# n_rows, n_cols = (len(state_dict)-1)//4 +1, 4
master_fig, axs = plt.subplots(5,4 ,figsize=(16, 20))
for i, s in enumerate(seeds):
    weight_dict = dict(model_map[s].named_parameters())
    for j, (key, val) in enumerate(weight_dict.items()):
        ax = axs[i, j]
        sns.barplot(val.detach().flatten().numpy(), ax=ax, palette='viridis')
        ax = sparsify_xlabels(ax)
        ax.set_title(f"Seed: {s}, Layer: {key}")
    # ax.set_title(key)
    # Label the y-axis only on the first column of each row.
    axs[i,0].set_ylabel('Value')

# Label the x-axis only on the bottom row.
for k in range(4):
    axs[-1, k].set_xlabel('Neuron index')
# axs[-1, k].set_xlabel('Neuron index')


In [None]:
## Histogram of weights
# Grid layout: one row per seed, one column per parameter tensor. The 5x4
# shape assumes five seeds and four parameter tensors per model.
# fig = plt.figure(figsize=(5*n_cols, 4*n_rows))
master_fig, axs = plt.subplots(5,4 ,figsize=(16, 20))
for i, s in enumerate(seeds):
    weight_dict = dict(model_map[s].named_parameters())
    for j, (key, val) in enumerate(weight_dict.items()):
        # ax = fig.add_subplot(n_rows, n_cols, i+1)
        ax = axs[i, j]
        sns.histplot(val.detach().flatten().numpy(),kde=True, palette='viridis', ax=ax)
        # Keep the seed in the title (as the barplot grid does); a second
        # set_title(key) call used to overwrite it, hiding which row
        # belongs to which seed.
        ax.set_title(f"Seed: {s}, Layer: {key}")

In [None]:
# Overlay the weight histograms of all seeds: one panel per parameter tensor,
# every seed drawn onto the same axes.
master_fig, axs = plt.subplots(1,4 ,figsize=(16, 4))
for i, s in enumerate(seeds):
    weight_dict = dict(model_map[s].named_parameters())
    for j, (key, val) in enumerate(weight_dict.items()):
        # ax = fig.add_subplot(n_rows, n_cols, i+1)
        ax = axs[j]
        sns.histplot(val.detach().flatten().numpy(),kde=True, palette='viridis', ax=ax, element='step', fill=True)
        ax.set_title(f"Seed: {s}, Layer: {key}")

        # This second call wins: the panel ends up titled with the layer name only.
        ax.set_title(key)

### Merging models

In [None]:
def simple_merge_models(models: list[Module]):
    """Average the parameters of several identically-structured models.

    Args:
        models: Non-empty list of modules whose ``named_parameters`` share the
            same names and shapes.

    Returns:
        A deep copy of ``models[0]`` whose every parameter is the element-wise
        mean of the corresponding parameters across ``models``. The inputs are
        left untouched. Only parameters are averaged; buffers keep the values
        copied from ``models[0]``.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        raise ValueError("simple_merge_models requires at least one model")

    merged_model = deepcopy(models[0])
    # Plain value update: the averaging itself must not be tracked by autograd.
    with torch.no_grad():
        for key, param in merged_model.named_parameters():
            stacked = torch.stack([m.get_parameter(key) for m in models], dim=0)
            param.copy_(stacked.mean(dim=0))

    return merged_model

In [None]:
# Average the parameters of all per-seed models into a single model.
merged = simple_merge_models([model_map[s] for s in seeds])

In [None]:
# Decision regions of the merged model, drawn over whichever X / y an earlier
# cell left in the notebook namespace.
fn = mutliclass_forward_func_with_softmax(merged)
ax2 = plot_multiclass_decision_boundary(fn , X, y)
ax2.set_title(f"Merged model decision boundary")

### Merged weights visualization

In [None]:
# Bar chart of every parameter tensor of the merged model; the 1x4 grid
# assumes four parameter tensors.
master_fig, axs = plt.subplots(1,4 ,figsize=(16, 4))
# for i, s in enumerate(seeds):
merged_weight_dict = dict(merged.named_parameters())
for j, (key, val) in enumerate(merged_weight_dict.items()):
    ax = axs[j]
    sns.barplot(val.detach().flatten().numpy(), ax=ax, palette='viridis')
    ax = sparsify_xlabels(ax)
    ax.set_title(f"Merged Layer: {key}")

In [None]:
# Loss of the merged model. X_tensor, y_tensor and criterion are whatever
# earlier cells left in the notebook namespace.
y_pred_batch = merged(X_tensor)
# y_pred_batch = model_map[0](X_tensor)
criterion(y_pred_batch, y_tensor)

In [None]:
def ensemble_predict(models, X_tensor):
    """Average the raw outputs of several models on the same input.

    Every model is called on ``X_tensor``; the outputs (which must share one
    shape) are stacked and averaged element-wise.
    """
    outputs = []
    for net in models:
        outputs.append(net(X_tensor))
    stacked = torch.stack(outputs, dim=0)
    return stacked.mean(dim=0)

In [None]:
# Confusion matrix of every per-seed model on its own data; the 1x5 grid
# assumes exactly five seeds.
master_fig, axs = plt.subplots(1,5 ,figsize=(20, 4))

for i, s in enumerate(seeds):
    model = model_map[s]
    X, y, X_tensor, _ = data_map[s]
    # Predict with this seed's model. The stale `fn` from an earlier cell was
    # bound to the merged model, so every panel showed the same predictor.
    seed_fn = mutliclass_forward_func_with_softmax(model)
    y_pred = seed_fn(X_tensor)
    ax3 = plot_confusion_matrix(y_pred, y, ax=axs[i])
    ax3.set_title(f"Seed: {s}")

## Federation

In [None]:
def simple_federated_learning(data_map, model, accuracy_func, epochs=1, rounds=200):
    """Run a simple parameter-averaging federated training loop.

    Every round, each client trains its own copy of the current global model
    on its local data for ``epochs`` epochs (fresh Adam optimizer, lr=0.1,
    ``CrossEntropyLoss``). The client models are then averaged with
    ``simple_merge_models`` and the average is redistributed to all clients.

    Args:
        data_map: Mapping ``client_id -> (X, y, X_tensor, y_tensor)``; only
            the two tensors are used.
        model: Initial global model; it is deep-copied, never modified.
        accuracy_func: Accuracy callback forwarded to ``minibatch_trainer``.
        epochs: Local epochs per client per round.
        rounds: Number of train-then-average rounds.

    Returns:
        ``(model_map, histories)``. ``model_map`` maps each client id to a
        copy of the final averaged model (or of ``model`` when ``rounds`` is
        0). ``histories`` maps each client id to that client's trainer
        history accumulated over all rounds, with the ``'epoch'`` and
        ``'step'`` values offset so they keep increasing across rounds.
    """
    # Use the data_map keys themselves so client ids need not be 0..n-1.
    client_ids = list(data_map)
    model_map = {cid: deepcopy(model) for cid in client_ids}
    histories = {}
    criterion = CrossEntropyLoss()

    for r in range(rounds):
        epoch_offset = r * epochs
        for cid in client_ids:
            _, _, X_tensor, y_tensor = data_map[cid]
            local_model = model_map[cid]
            # The local model is a fresh copy each round, so the optimizer
            # (and its state) is rebuilt as well.
            optimizer = Adam(local_model.parameters(), lr=0.1)
            local_model, round_history = minibatch_trainer(local_model, X_tensor, y_tensor, optimizer, criterion, epochs, accuracy_func)
            model_map[cid] = local_model

            # Append this round's records instead of overwriting earlier rounds.
            client_history = histories.setdefault(cid, {})
            for name, series in round_history.items():
                if name in ('epoch', 'step'):
                    series = [value + epoch_offset for value in series]
                client_history.setdefault(name, []).extend(series)

        merged = simple_merge_models([model_map[cid] for cid in client_ids])
        model_map = {cid: deepcopy(merged) for cid in client_ids}

    return model_map, histories

In [None]:
def iid_data_split(data_map, num_clients):
    """Split a dataset into ``num_clients`` equally sized contiguous shards.

    Args:
        data_map: Indexable pair where ``data_map[0]`` holds the samples and
            ``data_map[1]`` the matching labels (both sliceable and of the
            same length).
        num_clients: Number of shards to create.

    Returns:
        Dict ``client_index -> (samples_slice, labels_slice)``. Each shard
        has ``len(samples) // num_clients`` items; leftover samples at the
        end are dropped. No shuffling is performed, so the shards are only
        i.i.d. if the input order already is random.
    """
    samples, labels = data_map[0], data_map[1]
    # The shard size is based on the number of samples (not on the length of
    # the first sample, which is the feature dimension).
    split_size = len(samples) // num_clients

    split_data_map = {}
    for i in range(num_clients):
        shard = slice(i * split_size, (i + 1) * split_size)
        split_data_map[i] = (samples[shard], labels[shard])
    return split_data_map

## Old Code

### IRIS dataset

In [None]:
# Load the sine-wave toy dataset and plot it. Note that y_tensor is 1-D here
# (no trailing column dimension).
X, y = make_sine_wave()
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)
plot_data(X, y)

In [None]:
from sklearn import datasets

# Scatter the first two iris features, coloured by target class.
iris = datasets.load_iris()
import matplotlib.pyplot as plt

_, ax = plt.subplots()
scatter = ax.scatter(iris.data[:, 0], iris.data[:, 1], c=iris.target)
ax.set(xlabel=iris.feature_names[0], ylabel=iris.feature_names[1])
# Legend with one entry per class, built from the scatter's colour mapping.
_ = ax.legend(
    scatter.legend_elements()[0], iris.target_names, loc="lower right", title="Classes"
)

In [None]:
# NOTE: the original comment here announced an "unused but required import
# for doing 3d projections with matplotlib < 3.2", but no such import is
# present in this cell.


from sklearn.decomposition import PCA

fig = plt.figure(1, figsize=(8, 6))
ax = fig.add_subplot(111, projection="3d", elev=-150, azim=110)

# Project the iris features onto their first three principal components.
X_reduced = PCA(n_components=3).fit_transform(iris.data)
ax.scatter(
    X_reduced[:, 0],
    X_reduced[:, 1],
    X_reduced[:, 2],
    c=iris.target,
    s=40,
)

# Axis titles only; the tick labels are hidden.
ax.set_title("First three PCA dimensions")
ax.set_xlabel("1st Eigenvector")
ax.xaxis.set_ticklabels([])
ax.set_ylabel("2nd Eigenvector")
ax.yaxis.set_ticklabels([])
ax.set_zlabel("3rd Eigenvector")
ax.zaxis.set_ticklabels([])

plt.show()

In [None]:
# Legacy hand-written training loop that records into a History object.
# NOTE(review): model, criterion and optimizer are not defined in this cell;
# it uses whatever earlier cells left in the notebook namespace. Confirm they
# match the 1-D sine-wave targets before running.
history = History()
history.history['loss'] = []
history.history['acc'] = []
for e in range(50):
    
    y_pred_tensor = torch.squeeze(model(X_tensor))
    loss = criterion(y_pred_tensor, y_tensor)
    # Rounded model outputs serve as the predicted labels.
    y_pred = y_pred_tensor.round().detach().numpy()
    
    if e % 10 == 0:
        print('Epoch:', e, 'Loss:', loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # The counters are reset every epoch, so `accuracy` is the accuracy of
    # this epoch's (pre-update) predictions on the full dataset.
    total = 0
    correct = 0
    total += y_tensor.size(0)
    correct += np.sum(y_pred == y)
    accuracy = correct/total

    history.history['loss'].append(loss.item())
    history.history['acc'].append(accuracy)

    history.epoch.append(e)