# GMM Classification Experiments 


We experiment with logistic models trained on 2D Gaussian mixture data.

In [1]:
%load_ext autoreload
%autoreload 2


import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
from torch import nn
from torch import optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from copy import deepcopy
import matplotlib.pyplot as plt
from layered_unlearning.utils import set_seed
from layered_unlearning.gmm_classification import (
    Gaussian,
    LogisticModel,
    train,
    evaluate,
    construct_dataset,
)
from pathlib import Path

seed = set_seed(0)

## Hyperparameters
Default hyperparameters for our experiments. Note that all experiments are run in 2 dimensions.

In [2]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training settings; the ones not used in this cell are forwarded to
# LogisticModel / train / construct_dataset by the helpers defined below.
n_epochs = 2
lr = 1e-3
batch_size = 32
n_classes = 2
n_samples = 10000  # number of points drawn from each Gaussian
dim = 2  # input dimensionality; also the size of each covariance matrix
weight_decay = 1e-3
degree = 2  # passed to LogisticModel; presumably the polynomial feature degree -- TODO confirm
eps = 1e-8  # passed to train; presumably a numerical-stability constant -- TODO confirm
scale = 1.5  # covariance of the A, B and C clusters is scale * I
large_scale = 100  # extra covariance multiplier applied only to the null Gaussian

loss_type = "cross_entropy"

# Mixture components, in index order: 0 = null, 1 = task A, 2 = task B,
# 3 = C (the retain cluster). The null Gaussian is centred at the origin and is
# large_scale times wider than the other three.
gaussians = [
    Gaussian(
        mu=torch.tensor([0.0, 0.0]),
        cov=torch.eye(dim) * scale * large_scale,
    ),
    Gaussian(
        mu=torch.tensor([-2.0, 1.5]),
        cov=torch.eye(dim) * scale,
    ),
    Gaussian(
        mu=torch.tensor([2.0, 1.5]),
        cov=torch.eye(dim) * scale,
    ),
    Gaussian(mu=torch.tensor([0.0, -3.0]), cov=torch.eye(dim) * scale),
]

# One batch of n_samples draws per Gaussian, in the same order as `gaussians`:
# null, task A, task B, retain (C).

X_full = [g.sample(n_samples) for g in gaussians]

## Training
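We train the initial model, the base unlearned model, and the Layered Unlearning (LU) version of the base unlearned model. We then relearn A and B separately, starting from each unlearned model, and report the accuracy on A, B, and the retain set after every stage.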

In [None]:
# Trained models keyed by run name; `run` stores a deepcopy under its `end` name.
model_checkpoints = {}
# Accuracy lists returned by `global_eval`, keyed by the same run names.
evals = {}


def get_model(old_model: nn.Module = None):
    """
    Build a new LogisticModel on `device`, optionally cloning existing weights.

    Args:
        old_model: If given, its state dict is loaded into the new model so the
            returned model starts from the same parameters without sharing them.

    Returns:
        The newly constructed model, already moved to `device`.
    """
    fresh = LogisticModel(dim=dim, n_classes=n_classes, degree=degree)
    fresh = fresh.to(device)

    # Nothing to copy: hand back the freshly initialised model.
    if old_model is None:
        return fresh

    fresh.load_state_dict(old_model.state_dict())
    return fresh


def global_train(model: nn.Module, learn_A: bool, learn_B: bool, relearn: bool = False):
    """
    Train `model` on a dataset built from `X_full` with the notebook's settings.

    Args:
        model: Model to train.
        learn_A: Forwarded to `construct_dataset`.
        learn_B: Forwarded to `construct_dataset`.
        relearn: Forwarded to `construct_dataset`.

    Returns:
        Whatever `train` returns for this model and dataset.
    """
    features, targets = construct_dataset(
        X_full,
        learn_A=learn_A,
        learn_B=learn_B,
        relearn=relearn,
        n_samples=n_samples,
    )

    # All optimisation settings come from the notebook-level hyperparameters.
    train_settings = {
        "eps": eps,
        "n_epochs": n_epochs,
        "lr": lr,
        "batch_size": batch_size,
        "weight_decay": weight_decay,
        "device": device,
        "loss_type": loss_type,
    }
    return train(model, features, targets, **train_settings)


def global_eval(model: nn.Module):
    """
    Evaluate `model` on the A, B and retain samples.

    Each of `X_full[1]`, `X_full[2]` and `X_full[3]` is scored against a target
    vector of ones of length `n_samples`; the null samples (`X_full[0]`) are
    skipped.

    Returns:
        A list with the three `evaluate` results, in the order A, B, retain.
    """
    return [
        evaluate(model, X_full[idx], torch.ones(n_samples), device=device)
        for idx in (1, 2, 3)
    ]


def run(start: str, end: str, learn_A: bool, learn_B: bool, relearn: bool = False):
    """
    Run one training stage and record its results under the name `end`.

    Args:
        start: Name of the checkpoint to start from, or None to start from a
            freshly initialised model. Must already be in `model_checkpoints`.
        end: Name under which the accuracies (in `evals`) and a deepcopy of the
            trained model (in `model_checkpoints`) are stored.
        learn_A: Forwarded to `global_train`.
        learn_B: Forwarded to `global_train`.
        relearn: Forwarded to `global_train`.
    """
    assert start is None or start in model_checkpoints

    # `.get` yields None for start=None, which makes get_model build from scratch.
    parent = model_checkpoints.get(start)
    trained = global_train(
        get_model(parent),
        learn_A=learn_A,
        learn_B=learn_B,
        relearn=relearn,
    )

    scores = global_eval(trained)
    evals[end] = scores
    acc_a, acc_b, acc_retain = scores[0], scores[1], scores[2]
    print(f"{end}, A: {acc_a:.2f}, B: {acc_b:.2f}, Retain: {acc_retain:.2f}")

    # Store a copy so later stages cannot mutate this checkpoint.
    model_checkpoints[end] = deepcopy(trained)


def run_relearn(name: str):
    """
    Relearn A and B separately, each starting from the checkpoint `name`.

    The results are stored under "<name>-A" and "<name>-B" respectively.
    """
    # (suffix, learn_A, learn_B): A first, then B, both from the same checkpoint.
    for suffix, use_a, use_b in (("A", True, False), ("B", False, True)):
        run(name, f"{name}-{suffix}", learn_A=use_a, learn_B=use_b, relearn=True)


# Initial model: trained from a freshly initialised model with both flags on.
run(None, "init", learn_A=True, learn_B=True)
# Base unlearning: one stage from "init" with both flags off.
run("init", "base", learn_A=False, learn_B=False)
# Layered unlearning: first a stage with only learn_A off, then a second stage
# from that checkpoint with both flags off.
# NOTE(review): the exact data each flag selects is decided by construct_dataset,
# which is not shown here.
run("init", "base-lu-partial", learn_A=False, learn_B=True)
run("base-lu-partial", "base-lu", learn_A=False, learn_B=False)
# Relearn A and B separately from each unlearned model
# (stored as "<name>-A" and "<name>-B").
run_relearn("base")
run_relearn("base-lu")

Epoch 1/2: 100%|██████████| 1250/1250 [00:02<00:00, 565.29it/s, loss=0.349]
Epoch 2/2: 100%|██████████| 1250/1250 [00:02<00:00, 496.03it/s, loss=0.248]


init, A: 1.00, B: 1.00, Retain: 1.00


Epoch 1/2: 100%|██████████| 1250/1250 [00:02<00:00, 553.81it/s, loss=0.265] 
Epoch 2/2: 100%|██████████| 1250/1250 [00:02<00:00, 522.31it/s, loss=0.139] 


base, A: 0.06, B: 0.06, Retain: 0.95


Epoch 1/2: 100%|██████████| 1250/1250 [00:02<00:00, 534.10it/s, loss=0.216]
Epoch 2/2: 100%|██████████| 1250/1250 [00:02<00:00, 530.90it/s, loss=0.195] 


base-lu-partial, A: 0.12, B: 0.95, Retain: 0.97


Epoch 1/2: 100%|██████████| 1250/1250 [00:02<00:00, 527.31it/s, loss=0.212] 
Epoch 2/2: 100%|██████████| 1250/1250 [00:02<00:00, 566.03it/s, loss=0.115] 


base-lu, A: 0.04, B: 0.06, Retain: 0.95


Epoch 1/2: 100%|██████████| 313/313 [00:00<00:00, 510.46it/s, loss=0.449]
Epoch 2/2: 100%|██████████| 313/313 [00:00<00:00, 517.70it/s, loss=0.173]


base-A, A: 0.98, B: 0.20, Retain: 1.00


Epoch 1/2: 100%|██████████| 313/313 [00:00<00:00, 519.00it/s, loss=0.541]
Epoch 2/2: 100%|██████████| 313/313 [00:00<00:00, 531.05it/s, loss=0.188]


base-B, A: 0.20, B: 0.98, Retain: 1.00


Epoch 1/2: 100%|██████████| 313/313 [00:00<00:00, 524.21it/s, loss=1.16]
Epoch 2/2: 100%|██████████| 313/313 [00:00<00:00, 516.36it/s, loss=0.253]


base-lu-A, A: 0.95, B: 0.19, Retain: 1.00


Epoch 1/2: 100%|██████████| 313/313 [00:00<00:00, 520.73it/s, loss=0.647]
Epoch 2/2: 100%|██████████| 313/313 [00:00<00:00, 520.14it/s, loss=0.174]


base-lu-B, A: 0.11, B: 0.94, Retain: 1.00


## Visualizations
We visualize the decision boundary learned by each model, overlaid on the outlines of the four Gaussians.

In [15]:
def visualize(
    name: str,
    X: torch.Tensor,
    y: torch.Tensor,
    n_grid: int = 100,
    n_samples: int = None,
    output_path: Path = None,
):
    """
    Plot the decision regions of a stored model together with the Gaussians.

    The model output is evaluated on a regular grid and shaded in two bands
    split at 0.5; an outline of each entry of `gaussians` is drawn on top.

    Args:
        name: Key into `model_checkpoints` selecting the model to plot.
        X: Points with at least two columns; only used to pick the plotted
            x/y range (min - 1 to max + 1 of the first two columns).
        y: Labels aligned with X. Subsampled together with X, otherwise unused.
        n_grid: Number of grid points along each axis.
        n_samples: If given, at most this many randomly chosen rows of X are
            kept before the plot range is computed.
        output_path: File to save the figure to. If None, the figure is shown
            instead.

    The figure is left open; closing it is up to the caller.
    """
    model = model_checkpoints[name]
    model.eval()
    if n_samples is not None:
        n_samples = min(n_samples, X.size(0))
        inds = torch.randperm(X.size(0))[:n_samples]
        X = X[inds]
        y = y[inds]

    # Convert the bounds to plain floats so that linspace and the matplotlib
    # axis limits work regardless of which device X lives on.
    x_min, x_max = (X[:, 0].min() - 1).item(), (X[:, 0].max() + 1).item()
    y_min, y_max = (X[:, 1].min() - 1).item(), (X[:, 1].max() + 1).item()

    # "ij" is the layout the historical default produced; passing it explicitly
    # avoids the deprecation warning without changing the grid.
    xx, yy = torch.meshgrid(
        torch.linspace(x_min, x_max, n_grid),
        torch.linspace(y_min, y_max, n_grid),
        indexing="ij",
    )
    grid = torch.stack([xx.ravel(), yy.ravel()], dim=1).to(device)
    with torch.no_grad():
        grid_out = model(grid).squeeze().cpu()

    plt.figure(figsize=(8, 6))
    # NOTE(review): the fixed levels assume the model output lies in [0, 1]
    # -- confirm against LogisticModel.
    plt.contourf(
        xx.cpu(),
        yy.cpu(),
        grid_out.reshape(xx.shape),
        levels=[0, 0.5, 1],
        alpha=0.2,
        cmap="coolwarm",
    )

    def plot_gaussian_ellipse(gaussian: Gaussian, n_std: float = 2.5, **kwargs):
        """
        Add an n_std-sigma ellipse of a 2-D Gaussian to the current axes.

        Extra **kwargs are forwarded to matplotlib.patches.Ellipse.
        Returns the Ellipse patch that was added.
        """
        import numpy as np
        from matplotlib.patches import Ellipse

        ax = plt.gca()
        mean = gaussian.mu.cpu().numpy()
        cov = gaussian.cov.cpu().numpy()
        # Eigen-decomposition of the covariance matrix
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]  # largest first
        vals, vecs = vals[order], vecs[:, order]

        # Rotation of the ellipse (deg): angle of the leading eigenvector
        major_x, major_y = vecs[:, 0]
        theta = np.degrees(np.arctan2(major_y, major_x))

        # Full width/height of the ellipse (factor 2 because Ellipse wants diameters)
        width, height = 2 * n_std * np.sqrt(vals)

        ellipse = Ellipse(
            xy=mean,
            width=width,
            height=height,
            angle=theta,
            facecolor="none",
            linestyle="--",
            linewidth=2,
            **kwargs,
        )
        ax.add_patch(ellipse)
        return ellipse

    plot_gaussian_ellipse(
        gaussians[0],
        edgecolor="blue",
        label="Null",
    )
    plot_gaussian_ellipse(gaussians[1], edgecolor="orange", label="A")
    plot_gaussian_ellipse(gaussians[2], edgecolor="yellow", label="B")
    plot_gaussian_ellipse(gaussians[3], edgecolor="red", label="C")

    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.title("Decision Boundary")
    plt.legend()
    if output_path is not None:
        plt.savefig(output_path)
    else:
        plt.show()


# Directory that receives one PNG per stored checkpoint.
base_dir = Path("./gmm_figures")
base_dir.mkdir(exist_ok=True)

# visualize only uses these points to choose the plotted x/y range.
X, y = construct_dataset(X_full, learn_A=True, learn_B=True, n_samples=n_samples)
for name in model_checkpoints:
    visualize(
        name,
        X,
        y,
        n_grid=100,
        n_samples=5000,
        output_path=base_dir / f"{name}.png",
    )
    # visualize leaves the figure open after saving; close it before the next one.
    plt.close()