# GMM Classification Experiments 


We experiment with logistic models on 2D Gaussian-mixture classification tasks.

In [3]:
%load_ext autoreload
%autoreload 2


import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from torch import nn
from torch import optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from copy import deepcopy
import matplotlib.pyplot as plt
from layered_unlearning.utils import set_seed
from layered_unlearning.gmm_classification import (
    Gaussian,
    GaussianMixture,
    LogisticModel,
    Uniform,
    train,
    evaluate,
    construct_dataset,
)
import math
from typing import Dict, List
from pathlib import Path
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import numpy as np


# Seed the RNGs through the project helper, presumably for reproducible runs.
# NOTE(review): which generators set_seed seeds (torch / numpy / random) and
# what it returns are defined in layered_unlearning.utils -- confirm there.
seed = set_seed(3)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Hyperparameters
Default hyperparameters for our experiments. Note that all experiments are run in 2 dimensions.

In [7]:
# Runtime device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data
dim = 2  # every experiment lives in the plane
n_samples = 5000  # samples drawn per distribution
n_classes = 2  # NOTE: reassigned to 3 in the data-generation cell

# Optimisation (forwarded to `train` by `global_train`)
n_epochs = 3
lr = 1e-2
batch_size = 32
weight_decay = 0.0
eps = 1e-8
loss_type = "cross_entropy"
weight_delta_penalty = 0.0  # not referenced by the cells shown here

# Model architecture (forwarded to `LogisticModel` by `get_model`)
rbf = True
degree = 0
n_layers = 0
batch_norm = True
hidden_dim = 128


def ellipse(rotate: float = 0.0, x_scale: float = 1.0, y_scale: float = 1.0):
    """Return a 2x2 covariance matrix for a rotated, axis-scaled ellipse.

    Computes ``R @ diag(x_scale, y_scale) @ R.T`` where ``R`` is the 2-D
    rotation matrix for ``rotate`` degrees (counter-clockwise).

    Args:
        rotate: Rotation angle in degrees.
        x_scale: Variance along the first axis before rotation.
        y_scale: Variance along the second axis before rotation.

    Returns:
        A 2x2 float tensor.
    """
    theta = math.radians(rotate)
    c, s = math.cos(theta), math.sin(theta)
    rotation = torch.Tensor([[c, -s], [s, c]])
    scales = torch.Tensor([[x_scale, 0], [0, y_scale]])
    return rotation @ scales @ rotation.T


def mu_gen(width: float = 50, n_dim: int = None):
    """Sample a random mean vector uniformly in ``[-width, width)`` per coordinate.

    Args:
        width: Half-width of the sampling box (default 50, the previously
            hard-coded value).
        n_dim: Length of the vector; defaults to the module-level ``dim``.

    Returns:
        A float tensor of shape ``(n_dim,)``.
    """
    if n_dim is None:
        n_dim = dim
    # torch.rand is uniform on [0, 1); map it affinely onto [-width, width).
    return torch.rand((n_dim,)) * 2 * width - width


def cov_gen(n_dim: int = None, base_var: float = 4.0, perturb_scale: float = 0.1):
    """Sample a random symmetric positive-definite covariance matrix.

    The result is ``base_var * I + perturb_scale * U.T @ U`` with ``U`` a
    standard-normal matrix. Because ``U.T @ U`` is positive semi-definite,
    every eigenvalue is at least ``base_var`` when ``perturb_scale >= 0``.

    Args:
        n_dim: Matrix size; defaults to the module-level ``dim``.
        base_var: Isotropic variance on the diagonal (default 4, the
            previously hard-coded value).
        perturb_scale: Scale of the random PSD perturbation (default 0.1).

    Returns:
        A float tensor of shape ``(n_dim, n_dim)``.
    """
    if n_dim is None:
        n_dim = dim
    base = torch.eye(n_dim) * base_var
    U = torch.randn((n_dim, n_dim))
    perturb = U.T @ U * perturb_scale
    return base + perturb


def get_gaussian_mixture(
    n_classes: int,
    mu_list: List[torch.Tensor] = None,
    cov_list: List[torch.Tensor] = None,
) -> GaussianMixture:
    """Build an equally weighted mixture of ``n_classes`` Gaussians.

    Component ``k`` uses ``mu_list[k]`` / ``cov_list[k]`` when those lists are
    given; otherwise it gets a fresh draw from ``mu_gen()`` / ``cov_gen()``.
    """

    def _param(given, k, sampler):
        # Explicit value when provided, otherwise a random sample.
        return sampler() if given is None else given[k]

    components = [
        Gaussian(mu=_param(mu_list, k, mu_gen), cov=_param(cov_list, k, cov_gen))
        for k in range(n_classes)
    ]
    equal_weights = torch.ones(n_classes) / n_classes
    return GaussianMixture(classes=components, weights=equal_weights)


def get_even_clusters(X: np.ndarray, cluster_size: int, random_state: int = None):
    """Partition the rows of ``X`` into spatial clusters of bounded size.

    Runs k-means with ``ceil(len(X) / cluster_size)`` centers. It then solves a
    linear assignment problem in which every center owns exactly
    ``cluster_size`` slots, so no cluster receives more than ``cluster_size``
    points. Every cluster has exactly ``cluster_size`` points when
    ``len(X)`` is a multiple of it.

    Args:
        X: Points, one per row.
        cluster_size: Maximum number of points per cluster; must be >= 1.
        random_state: Optional seed forwarded to ``KMeans`` for reproducible
            clustering. ``None`` keeps the previous behavior.

    Returns:
        Integer array of length ``len(X)`` giving each row's cluster id.

    Raises:
        ValueError: if ``cluster_size`` is smaller than 1.
    """
    if cluster_size < 1:
        raise ValueError(f"cluster_size must be >= 1, got {cluster_size}")
    if len(X) == 0:
        return np.empty(0, dtype=np.intp)

    n_clusters = int(np.ceil(len(X) / cluster_size))
    kmeans = KMeans(n_clusters, random_state=random_state)
    kmeans.fit(X)
    # Repeat each center cluster_size times consecutively: slot s belongs to
    # center s // cluster_size, so a one-to-one assignment of points to slots
    # caps every cluster at cluster_size points.
    slots = np.repeat(kmeans.cluster_centers_, cluster_size, axis=0)
    distance_matrix = cdist(X, slots)
    # With no more rows than columns every row is assigned and row_ind comes
    # back sorted, so col_ind[i] is the slot chosen for X[i].
    _, slot_ids = linear_sum_assignment(distance_matrix)
    return slot_ids // cluster_size


uniform_half_width = 60
n_classes = 3
clustering = "adversarial"  # random, k-means, adversarial
# Number of Gaussian mixtures: task A, task B and retain.
n_mixtures = 3

# Draw n_classes candidate means per mixture, then assign them to mixtures
# according to `clustering`. Result shape: (n_mixtures, n_classes, dim).
mean_lists = [mu_gen() for _ in range(n_mixtures * n_classes)]
if clustering == "random":
    # Consecutive draws form each mixture.
    mean_lists = torch.stack(mean_lists).reshape(n_mixtures, n_classes, dim)
elif clustering in ("k-means", "adversarial"):
    all_means = torch.stack(mean_lists)  # (n_mixtures * n_classes, dim)
    # Equal-size spatial clusters of n_classes means -> n_mixtures clusters.
    labels = get_even_clusters(all_means.numpy(), n_classes)
    # Boolean indexing already returns a new tensor; re-wrapping it with
    # torch.tensor(...) only triggered a copy-construct UserWarning.
    clustered = torch.stack(
        [all_means[torch.from_numpy(labels == c)] for c in range(n_mixtures)]
    )  # (n_mixtures, n_classes, dim), ordered cluster by cluster
    if clustering == "k-means":
        # Each mixture is one spatially tight cluster.
        mean_lists = clustered
    else:
        # Adversarial: deal the cluster-ordered means round-robin so each
        # mixture's class means are spread across clusters (one per cluster
        # when n_classes == n_mixtures).
        flat = clustered.reshape(n_mixtures * n_classes, dim)
        mean_lists = torch.stack([flat[m::n_mixtures] for m in range(n_mixtures)])
else:
    raise ValueError(f"Unknown clustering method: {clustering}")

# Distribution 0 is the "null" background: a Uniform with bounds
# +/- uniform_half_width on every coordinate (bounds follow `dim`).
gaussians = [
    Uniform(
        low=-uniform_half_width * torch.ones(dim),
        high=uniform_half_width * torch.ones(dim),
    ),
]
for mixture_means in mean_lists:
    gaussians.append(
        get_gaussian_mixture(
            n_classes=n_classes,
            mu_list=mixture_means,
        )
    )

# null, task A, task B, retain
X_full = [g.sample(n_samples) for g in gaussians]

  filtered = torch.tensor(filtered)
  torch.tensor(self.weights), n_samples, replacement=True
  torch.tensor(self.weights), n_samples, replacement=True


## Training
We train the initial model, the base unlearned model, and the Layered Unlearning (LU) version of the base unlearned model. We then run relearning attacks on task A and on task B against both unlearned models.

In [22]:
# Trained models keyed by checkpoint name (populated by `run`).
model_checkpoints = dict()
# Accuracies [A, B, retain] keyed by checkpoint name (populated by `run`).
evals = dict()


def get_model(old_model: nn.Module = None):
    """Build a ``LogisticModel`` from the module-level hyperparameters.

    When ``old_model`` is given, its ``state_dict`` is loaded into the new
    model. The new model is moved to ``device``.
    """
    architecture = dict(
        dim=dim,
        n_classes=n_classes,
        degree=degree,
        rbf=rbf,
        n_layers=n_layers,
        batch_norm=batch_norm,
        hidden_dim=hidden_dim,
    )
    fresh = LogisticModel(**architecture).to(device)
    if old_model is None:
        return fresh
    fresh.load_state_dict(old_model.state_dict())
    return fresh


def global_train(
    model: nn.Module,
    learn_A: bool,
    learn_B: bool,
    relearn: bool = False,
    kwargs: Dict = None,
):
    """Train ``model`` on the dataset selected by the learn/relearn flags.

    The dataset is built from ``X_full`` by ``construct_dataset``. Training
    uses the module-level hyperparameters, and any entries in ``kwargs``
    override them.

    Args:
        model: Model to train.
        learn_A: Forwarded to ``construct_dataset``.
        learn_B: Forwarded to ``construct_dataset``.
        relearn: Forwarded to ``construct_dataset``.
        kwargs: Optional overrides for the keyword arguments passed to
            ``train`` (e.g. ``{"lr": 1e-3}``). ``None`` means no overrides.

    Returns:
        The model returned by ``train``.
    """
    X, y = construct_dataset(
        X_full, learn_A=learn_A, learn_B=learn_B, relearn=relearn, n_samples=n_samples
    )
    # Defaults first; caller overrides are applied last so they win.
    train_kwargs = {
        "eps": eps,
        "n_epochs": n_epochs,
        "lr": lr,
        "batch_size": batch_size,
        "weight_decay": weight_decay,
        "device": device,
        "loss_type": loss_type,
    }
    if kwargs is not None:
        train_kwargs.update(kwargs)
    return train(model, X, y, **train_kwargs)


def global_eval(model: nn.Module, kwargs: Dict = None):
    """Evaluate ``model`` on task A, task B and the retain set.

    Tasks A and B (``X_full[1]``, ``X_full[2]``) are scored against all-ones
    labels. The retain score uses the null samples (``X_full[0]``, label 0)
    concatenated with the retain samples (``X_full[3]``, label 1).

    Args:
        model: Model to evaluate.
        kwargs: Optional extra keyword arguments for ``evaluate``.

    Returns:
        List ``[acc_A, acc_B, acc_retain]`` of values returned by ``evaluate``.
    """
    eval_kwargs = {} if kwargs is None else kwargs
    accuracies = []
    for i in range(1, 4):
        X = X_full[i]
        # Size labels from the data itself rather than the global n_samples,
        # so labels always match X even if sample counts differ.
        y = torch.ones(len(X))
        if i == 3:
            X = torch.cat([X_full[0], X])
            y = torch.cat([torch.zeros(len(X_full[0])), y])
        acc = evaluate(model, X, y, device=device, batch_size=batch_size, **eval_kwargs)
        accuracies.append(acc)
    return accuracies


def run(
    start: str,
    end: str,
    learn_A: bool,
    learn_B: bool,
    relearn: bool = False,
    train_kwargs: Dict = None,
    eval_kwargs: Dict = None,
):
    """Train from checkpoint ``start`` and store the result as checkpoint ``end``.

    The function runs these steps:

    1. Build a model with ``get_model``, loading ``start``'s weights when
       ``start`` is not None.
    2. Train it with ``global_train``.
    3. Evaluate it with ``global_eval``.
    4. Record the accuracies in ``evals[end]`` and a deep copy of the model
       in ``model_checkpoints[end]``.

    Args:
        start: Existing checkpoint name, or None to train from scratch.
        end: Name under which the new checkpoint and its accuracies are stored.
        learn_A: Forwarded to ``global_train``.
        learn_B: Forwarded to ``global_train``.
        relearn: Forwarded to ``global_train``.
        train_kwargs: Optional overrides for ``train``; None means none.
        eval_kwargs: Optional extra arguments for ``evaluate``; None means none.

    Raises:
        KeyError: if ``start`` is not None and not in ``model_checkpoints``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if start is not None and start not in model_checkpoints:
        raise KeyError(f"Unknown starting checkpoint: {start!r}")
    # The callees receive these as dicts.
    train_kwargs = {} if train_kwargs is None else train_kwargs
    eval_kwargs = {} if eval_kwargs is None else eval_kwargs

    model = get_model(model_checkpoints.get(start))
    model = global_train(
        model, learn_A=learn_A, learn_B=learn_B, relearn=relearn, kwargs=train_kwargs
    )
    evals[end] = global_eval(model, kwargs=eval_kwargs)
    print(
        f"{end}, A: {evals[end][0]:.2f}, B: {evals[end][1]:.2f}, Retain: {evals[end][2]:.2f}"
    )
    model_checkpoints[end] = deepcopy(model)


def run_relearn(name: str, train_kwargs: Dict = None, eval_kwargs: Dict = None):
    """Run relearning from checkpoint ``name`` on task A, then on task B.

    Produces checkpoint ``f"{name}-A"``, trained with ``learn_A=True`` and
    ``learn_B=False``, and checkpoint ``f"{name}-B"``, trained with
    ``learn_A=False`` and ``learn_B=True``. Both runs use ``relearn=True``.

    Args:
        name: Starting checkpoint in ``model_checkpoints``.
        train_kwargs: Optional overrides for ``train``; None means none.
        eval_kwargs: Optional extra arguments for ``evaluate``; None means none.
    """
    train_kwargs = {} if train_kwargs is None else train_kwargs
    eval_kwargs = {} if eval_kwargs is None else eval_kwargs
    for suffix, learn_A, learn_B in (("A", True, False), ("B", False, True)):
        run(
            name,
            f"{name}-{suffix}",
            learn_A=learn_A,
            learn_B=learn_B,
            relearn=True,
            train_kwargs=train_kwargs,
            eval_kwargs=eval_kwargs,
        )


# Initial model: trained with both task A and task B.
run(start=None, end="init", learn_A=True, learn_B=True)
# Base unlearning: A and B are dropped at the same time.
run(start="init", end="base", learn_A=False, learn_B=False)
# Layered Unlearning, step 1: drop A while keeping B.
run(start="init", end="base-lu-partial", learn_A=False, learn_B=True)
# Layered Unlearning, step 2: then drop B as well.
run(start="base-lu-partial", end="base-lu", learn_A=False, learn_B=False)
# Relearning on A and on B for both unlearned models.
run_relearn(name="base")
run_relearn(name="base-lu")

Epoch 1/3:   0%|          | 0/625 [00:00<?, ?it/s, loss=0.591]

Epoch 1/3: 100%|██████████| 625/625 [00:01<00:00, 518.93it/s, loss=0.16]  
Epoch 2/3: 100%|██████████| 625/625 [00:01<00:00, 604.12it/s, loss=0.251] 
Epoch 3/3: 100%|██████████| 625/625 [00:01<00:00, 590.10it/s, loss=0.211] 


init, A: 1.00, B: 1.00, Retain: 0.92


Epoch 1/3: 100%|██████████| 625/625 [00:01<00:00, 598.93it/s, loss=0.0294]
Epoch 2/3: 100%|██████████| 625/625 [00:01<00:00, 593.54it/s, loss=0.017]  
Epoch 3/3: 100%|██████████| 625/625 [00:01<00:00, 593.10it/s, loss=0.0763] 


base, A: 0.00, B: 0.00, Retain: 0.98


Epoch 1/3: 100%|██████████| 625/625 [00:01<00:00, 580.89it/s, loss=0.129] 
Epoch 2/3: 100%|██████████| 625/625 [00:01<00:00, 574.39it/s, loss=0.0277]
Epoch 3/3: 100%|██████████| 625/625 [00:01<00:00, 576.65it/s, loss=0.0275]


base-lu-partial, A: 0.00, B: 1.00, Retain: 0.96


Epoch 1/3: 100%|██████████| 625/625 [00:01<00:00, 579.89it/s, loss=0.0719] 
Epoch 2/3: 100%|██████████| 625/625 [00:01<00:00, 578.06it/s, loss=0.0103] 
Epoch 3/3: 100%|██████████| 625/625 [00:01<00:00, 594.26it/s, loss=0.026]  


base-lu, A: 0.00, B: 0.00, Retain: 0.98


Epoch 1/3: 100%|██████████| 157/157 [00:00<00:00, 587.17it/s, loss=0.576]
Epoch 2/3: 100%|██████████| 157/157 [00:00<00:00, 590.84it/s, loss=0.0833]
Epoch 3/3: 100%|██████████| 157/157 [00:00<00:00, 587.37it/s, loss=0.0457]


base-A, A: 1.00, B: 0.68, Retain: 0.90


Epoch 1/3: 100%|██████████| 157/157 [00:00<00:00, 589.55it/s, loss=0.518]
Epoch 2/3: 100%|██████████| 157/157 [00:00<00:00, 587.38it/s, loss=0.0795]
Epoch 3/3: 100%|██████████| 157/157 [00:00<00:00, 588.60it/s, loss=0.0544]


base-B, A: 0.60, B: 1.00, Retain: 0.91


Epoch 1/3: 100%|██████████| 157/157 [00:00<00:00, 582.98it/s, loss=4.98]
Epoch 2/3: 100%|██████████| 157/157 [00:00<00:00, 586.87it/s, loss=0.352]
Epoch 3/3: 100%|██████████| 157/157 [00:00<00:00, 581.71it/s, loss=0.0943]


base-lu-A, A: 1.00, B: 0.68, Retain: 0.87


Epoch 1/3: 100%|██████████| 157/157 [00:00<00:00, 583.99it/s, loss=0.471]
Epoch 2/3: 100%|██████████| 157/157 [00:00<00:00, 587.06it/s, loss=0.0996]
Epoch 3/3: 100%|██████████| 157/157 [00:00<00:00, 588.27it/s, loss=0.058] 


base-lu-B, A: 0.00, B: 1.00, Retain: 0.94


## Visualizations
We visualize the decision boundary learned by each model checkpoint and save one figure per checkpoint to `./gmm_figures`.

In [4]:
# One row per checkpoint with its accuracy on task A, task B and the retain
# set, in the order stored by `run` from `global_eval`.
df_dict = [
    dict(name=ckpt_name, A=accs[0], B=accs[1], retain=accs[2])
    for ckpt_name, accs in evals.items()
]
df = pd.DataFrame(df_dict)

In [23]:
def visualize(
    name: str,
    X: torch.Tensor,
    y: torch.Tensor,
    n_grid: int = 100,
    n_samples: int = None,
    output_path: Path = None,
    include_scatter: bool = True,
    width: float = 80,
):
    """Plot the decision regions of checkpoint ``name``.

    The stored model is evaluated on an ``n_grid`` x ``n_grid`` grid that
    covers the fixed square ``[-width, width]^2``. Its output is drawn as
    filled contours split at 0.5. Optionally, the samples of every
    distribution in ``X_full`` are overlaid.

    Args:
        name: Key into ``model_checkpoints``. The stored model is switched to
            eval mode in place.
        X: Data points. They are only sub-sampled here and do not affect the
            plot, because the window is fixed by ``width``.
        y: Labels matching ``X``; sub-sampled alongside it.
        n_grid: Number of grid points per axis.
        n_samples: If given, randomly keep at most this many rows of ``X``/``y``.
        output_path: If given, save the figure there; otherwise ``plt.show()``.
        include_scatter: Overlay the null / A / B / retain samples.
        width: Half-width of the plotted square (default 80, the previously
            hard-coded value).
    """
    model = model_checkpoints[name]
    model.eval()
    if n_samples is not None:
        n_samples = min(n_samples, X.size(0))
        inds = torch.randperm(X.size(0))[:n_samples]
        X = X[inds]
        y = y[inds]

    # Fixed square window, independent of X.
    x_min, x_max = -width, width
    y_min, y_max = -width, width
    # indexing="ij" matches the previous implicit default and silences
    # torch's meshgrid deprecation warning.
    xx, yy = torch.meshgrid(
        torch.linspace(x_min, x_max, n_grid),
        torch.linspace(y_min, y_max, n_grid),
        indexing="ij",
    )
    grid = torch.stack([xx.ravel(), yy.ravel()], dim=1).to(device)
    with torch.no_grad():
        grid_out = model(grid).squeeze().cpu()

    plt.figure(figsize=(8, 6))
    plt.contourf(
        xx.cpu(),
        yy.cpu(),
        grid_out.reshape(xx.shape),
        levels=[0, 0.5, 1],
        alpha=0.2,
        cmap="coolwarm",
    )

    def plot_gaussian_ellipse(gaussian: Gaussian, n_std: float = 2.5, **kwargs):
        """Add an n-sigma ellipse of a 2-D Gaussian to the current axes.

        Reads ``gaussian.mu`` and ``gaussian.cov``. Extra ``**kwargs`` are
        forwarded to ``matplotlib.patches.Ellipse``.
        """
        from matplotlib.patches import Ellipse

        ax = plt.gca()
        mean = gaussian.mu.cpu().numpy()
        cov = gaussian.cov.cpu().numpy()
        # Eigen-decomposition of the covariance matrix, largest eigenvalue first.
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        vals, vecs = vals[order], vecs[:, order]

        # Rotation of the ellipse (degrees), from the principal eigenvector.
        theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))

        # Ellipse takes full diameters, hence the factor 2.
        ell_width, ell_height = 2 * n_std * np.sqrt(vals)

        ellipse = Ellipse(
            xy=mean,
            width=ell_width,
            height=ell_height,
            angle=theta,
            facecolor="none",
            linestyle="--",
            linewidth=2,
            **kwargs,
        )
        ax.add_patch(ellipse)
        return ellipse

    def scatter(x: torch.Tensor, y: torch.Tensor, **kwargs):
        plt.scatter(
            x.cpu(),
            y.cpu(),
            s=1,
            **kwargs,
        )

    if include_scatter:
        scatter(X_full[0][:, 0], X_full[0][:, 1], color="blue", label="Null")
        scatter(X_full[1][:, 0], X_full[1][:, 1], color="orange", label="A")
        scatter(X_full[2][:, 0], X_full[2][:, 1], color="yellow", label="B")
        scatter(X_full[3][:, 0], X_full[3][:, 1], color="red", label="Retain")

    # NOTE(review): gaussians[0] is a Uniform and gaussians[1:] are
    # GaussianMixture objects, while plot_gaussian_ellipse reads `.mu`/`.cov`;
    # these calls likely need updating before they are re-enabled.
    # plot_gaussian_ellipse(
    #     gaussians[0],
    #     edgecolor="blue",
    #     label="Null",
    # )
    # plot_gaussian_ellipse(gaussians[1], edgecolor="orange", label="A")
    # plot_gaussian_ellipse(gaussians[2], edgecolor="yellow", label="B")
    # plot_gaussian_ellipse(gaussians[3], edgecolor="red", label="C")

    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.title("Decision Boundary")
    # With include_scatter=False nothing is labelled, and plt.legend() would
    # only emit a "no artists with labels" warning.
    handles, _ = plt.gca().get_legend_handles_labels()
    if handles:
        plt.legend()
    if output_path is not None:
        plt.savefig(output_path)
    else:
        plt.show()


base_dir = Path("./gmm_figures")
base_dir.mkdir(exist_ok=True)

# Dataset with both tasks enabled; passed through to `visualize`.
X, y = construct_dataset(X_full, learn_A=True, learn_B=True, n_samples=n_samples)
for name in model_checkpoints:
    # One figure per checkpoint; only "init" gets the raw-sample overlay.
    visualize(
        name,
        X,
        y,
        output_path=base_dir / f"{name}.png",
        n_samples=5000,
        n_grid=100,
        include_scatter=(name == "init"),
    )
    # Close the current figure so figures do not accumulate across iterations.
    plt.close()

  plt.legend()
