# GMM Classification Experiments 


We experiment with 2D logistic models.

In [40]:
%load_ext autoreload
%autoreload 2


import os

# Restrict this process to GPU 0. Set here, before `import torch` below, so the
# variable is already in place when torch first looks for CUDA devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from torch import nn
from torch import optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from copy import deepcopy
import matplotlib.pyplot as plt
from layered_unlearning.utils import set_seed
from layered_unlearning.gmm_classification import (
    Gaussian,
    GaussianMixture,
    LogisticModel,
    Uniform,
    train,
    evaluate,
    construct_dataset,
)
import math
from typing import Dict, List
from pathlib import Path

# Fix the random seed so sampled data and training runs are repeatable.
# NOTE(review): set_seed is a project helper; presumably it seeds torch's RNG
# and returns the seed it used — confirm in layered_unlearning.utils.
seed = set_seed(0)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Hyperparameters
Default hyperparameters for our experiments. Note that all experiments are run in 2 dimensions.

In [73]:
# --- Runtime -----------------------------------------------------------------
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Data --------------------------------------------------------------------
dim = 2  # dimensionality of every sampled point
n_classes = 2  # passed to LogisticModel
n_samples = 5000  # points drawn from each distribution

# --- Optimization (forwarded to `train` via `global_train`) -------------------
n_epochs = 3
batch_size = 32
lr = 1e-3
weight_decay = 0.0
eps = 1e-8
loss_type = "cross_entropy"
# Not referenced by the cells shown in this notebook section.
weight_delta_penalty = 0.0

# --- Model architecture (forwarded to `LogisticModel` via `get_model`) --------
rbf = True
degree = 1
n_layers = 0
hidden_dim = 128
batch_norm = True

def ellipse(rotate: float = 0.0, x_scale: float = 1.0, y_scale: float = 1.0):
    """Return the 2x2 matrix ``R @ diag(x_scale, y_scale) @ R.T``.

    Args:
        rotate: Rotation angle in degrees (converted to radians internally).
        x_scale: First diagonal entry of the axis-aligned matrix.
        y_scale: Second diagonal entry of the axis-aligned matrix.

    Returns:
        A 2x2 ``torch.Tensor`` holding the rotated matrix.
    """
    theta = rotate * (torch.pi / 180)
    cos_t, sin_t = math.cos(theta), math.sin(theta)

    # Counter-clockwise 2-D rotation matrix for angle `theta`.
    rotation = torch.Tensor([[cos_t, -sin_t], [sin_t, cos_t]])
    axis_aligned = torch.Tensor([[x_scale, 0], [0, y_scale]])

    return rotation @ axis_aligned @ rotation.T

def mu_gen():
    """Draw a random mean vector of length ``dim``, uniform in [-50, 50) per coordinate."""
    half_width = 50
    unit_sample = torch.rand((dim,))
    # Map [0, 1) onto [-half_width, half_width).
    return unit_sample * 2 * half_width - half_width


def cov_gen():
    """Draw a random ``dim x dim`` covariance: ``4 * I`` plus a small random Gram matrix.

    The perturbation ``0.1 * (U.T @ U)`` is positive semi-definite by construction,
    so the result stays symmetric positive definite.
    """
    noise = torch.randn((dim, dim))
    gram = noise.T @ noise
    return torch.eye(dim) * 4 + gram * 0.1


def get_gaussian_mixture(
    n_classes: int,
) -> GaussianMixture:
    """Build an equally weighted mixture of ``n_classes`` random Gaussians.

    Each component gets a fresh mean from ``mu_gen()`` and a fresh covariance
    from ``cov_gen()`` (drawn in that order for every component).

    Args:
        n_classes: Number of Gaussian components in the mixture.

    Returns:
        A ``GaussianMixture`` whose component weights are all ``1 / n_classes``.
    """
    components = [
        Gaussian(mu=mu_gen(), cov=cov_gen())
        for _ in range(n_classes)
    ]
    uniform_weights = torch.ones(n_classes) / n_classes
    return GaussianMixture(classes=components, weights=uniform_weights)


# Position in this list is meaningful downstream:
#   0 = null, 1 = task A, 2 = task B, 3 = retain
# The null distribution is uniform over the square [-200, 200]^2; the other
# three are random 3-component Gaussian mixtures.
gaussians = [
    Uniform(
        low=torch.tensor([-200.0, -200.0]),
        high=torch.tensor([200.0, 200.0]),
    )
] + [get_gaussian_mixture(3) for _ in range(3)]

# One batch of `n_samples` points per distribution, in the same order as above.
X_full = [distribution.sample(n_samples) for distribution in gaussians]

  torch.tensor(self.weights), n_samples, replacement=True


## Training
We train the initial model, the base unlearned model, and the Layered Unlearning (LU) version of the base unlearned model. 

In [77]:
# Registries filled in by `run`, both keyed by run name:
#   model_checkpoints[name] -> deep copy of the trained model
#   evals[name]             -> list of accuracies returned by `global_eval`
model_checkpoints, evals = {}, {}


def get_model(old_model: nn.Module = None):
    """Create a ``LogisticModel`` on ``device`` from the notebook hyperparameters.

    Args:
        old_model: Optional model whose ``state_dict`` is copied into the new
            model. When ``None`` the new model keeps its fresh initialization.

    Returns:
        The newly constructed model (a separate object from ``old_model``).
    """
    architecture = dict(
        dim=dim,
        n_classes=n_classes,
        degree=degree,
        rbf=rbf,
        n_layers=n_layers,
        batch_norm=batch_norm,
        hidden_dim=hidden_dim,
    )
    fresh_model = LogisticModel(**architecture).to(device)

    if old_model is None:
        return fresh_model

    # Warm start: copy parameters (and buffers) from the previous model.
    fresh_model.load_state_dict(old_model.state_dict())
    return fresh_model


def global_train(model: nn.Module, learn_A: bool, learn_B: bool, relearn: bool = False, kwargs: Dict = None):
    """Train ``model`` on a dataset built from ``X_full`` using the notebook defaults.

    Args:
        model: Model to train; passed straight to ``train``.
        learn_A: Forwarded to ``construct_dataset``.
        learn_B: Forwarded to ``construct_dataset``.
        relearn: Forwarded to ``construct_dataset``.
        kwargs: Optional overrides for the default training arguments
            (``eps``, ``n_epochs``, ``lr``, ``batch_size``, ``weight_decay``,
            ``device``, ``loss_type``). ``None`` (the default) or an empty dict
            means "use the defaults unchanged".

    Returns:
        Whatever ``train`` returns (the trained model).
    """
    X, y = construct_dataset(
        X_full, learn_A=learn_A, learn_B=learn_B, relearn=relearn, n_samples=n_samples
    )
    train_kwargs = {
        "eps": eps,
        "n_epochs": n_epochs,
        "lr": lr,
        "batch_size": batch_size,
        "weight_decay": weight_decay,
        "device": device,
        "loss_type": loss_type,
    }
    # `kwargs` defaults to None rather than `{}` so no dict object is shared
    # between calls; caller-supplied entries take precedence over the defaults.
    if kwargs:
        train_kwargs.update(kwargs)
    return train(model, X, y, **train_kwargs)


def global_eval(model: nn.Module, kwargs: Dict = None):
    """Evaluate ``model`` on ``X_full[1]``, ``X_full[2]`` and ``X_full[3]``.

    Every point is given the label 1, so each returned value is whatever
    ``evaluate`` reports against an all-ones target for that split.

    Args:
        model: Model to evaluate.
        kwargs: Optional extra keyword arguments forwarded to ``evaluate``.
            ``None`` (the default) means no extra arguments.

    Returns:
        A list of three values in the order task A, task B, retain.
    """
    eval_kwargs = kwargs if kwargs is not None else {}
    accuracies = []
    for i in range(1, 4):
        X = X_full[i]
        # Size the labels from the split itself instead of the global
        # `n_samples`, so the two can never disagree in length.
        y = torch.ones(X.size(0))
        accuracies.append(evaluate(model, X, y, device=device, **eval_kwargs))
    return accuracies


def run(start: str, end: str, learn_A: bool, learn_B: bool, relearn: bool = False, train_kwargs: Dict = {}, eval_kwargs: Dict = {}):
    """Train one stage, evaluate it, and register the result under ``end``.

    Args:
        start: Name of the checkpoint to warm-start from, or ``None`` to start
            from a freshly initialized model.
        end: Name under which the evaluation and the trained model are stored.
        learn_A: Forwarded to ``global_train``.
        learn_B: Forwarded to ``global_train``.
        relearn: Forwarded to ``global_train``.
        train_kwargs: Overrides forwarded to ``global_train``.
        eval_kwargs: Extra arguments forwarded to ``global_eval``.

    Side effects:
        Writes ``evals[end]`` and ``model_checkpoints[end]`` and prints a
        one-line accuracy summary.
    """
    assert start is None or start in model_checkpoints

    # `.get(None)` yields None, which makes `get_model` skip the warm start.
    warm_start = model_checkpoints.get(start)
    trained = global_train(
        get_model(warm_start),
        learn_A=learn_A,
        learn_B=learn_B,
        relearn=relearn,
        kwargs=train_kwargs,
    )

    scores = global_eval(trained, kwargs=eval_kwargs)
    evals[end] = scores
    print(
        f"{end}, A: {scores[0]:.2f}, B: {scores[1]:.2f}, Retain: {scores[2]:.2f}"
    )

    # Store a copy so later training stages cannot alter this checkpoint.
    model_checkpoints[end] = deepcopy(trained)


def run_relearn(name: str):
    """Starting from checkpoint ``name``, run two relearning stages.

    The first relearns with only ``learn_A`` set and is stored as
    ``"<name>-A"``; the second relearns with only ``learn_B`` set and is stored
    as ``"<name>-B"``. Both start from the same ``name`` checkpoint.
    """
    for suffix, use_A, use_B in (("A", True, False), ("B", False, True)):
        run(name, f"{name}-{suffix}", learn_A=use_A, learn_B=use_B, relearn=True)


# Experiment schedule. Order matters: each stage warm-starts from a checkpoint
# that an earlier line registered in `model_checkpoints`.
# 1. "init": fresh model trained with both learn_A and learn_B enabled.
run(None, "init", learn_A=True, learn_B=True)
# 2. "base": from "init", trained with both flags disabled in a single stage.
run("init", "base", learn_A=False, learn_B=False)
# 3. Layered variant in two stages: first disable only learn_A
#    ("base-lu-partial"), then disable both ("base-lu").
run("init", "base-lu-partial", learn_A=False, learn_B=True)
run("base-lu-partial", "base-lu", learn_A=False, learn_B=False)
# 4. Relearning runs on A only and on B only, from each final checkpoint.
run_relearn("base")
run_relearn("base-lu")

Epoch 1/3: 100%|██████████| 625/625 [00:01<00:00, 376.68it/s, loss=0.439]
Epoch 2/3: 100%|██████████| 625/625 [00:01<00:00, 378.52it/s, loss=0.478]
Epoch 3/3: 100%|██████████| 625/625 [00:01<00:00, 369.35it/s, loss=0.46] 


init, A: 1.00, B: 1.00, Retain: 1.00


Epoch 1/3: 100%|██████████| 625/625 [00:01<00:00, 376.72it/s, loss=0.687]
Epoch 2/3: 100%|██████████| 625/625 [00:01<00:00, 368.95it/s, loss=0.516]
Epoch 3/3: 100%|██████████| 625/625 [00:01<00:00, 366.13it/s, loss=0.375]


base, A: 0.00, B: 0.00, Retain: 0.82


Epoch 1/3: 100%|██████████| 625/625 [00:01<00:00, 382.96it/s, loss=0.465]
Epoch 2/3: 100%|██████████| 625/625 [00:01<00:00, 401.20it/s, loss=0.537]
Epoch 3/3: 100%|██████████| 625/625 [00:01<00:00, 393.50it/s, loss=0.431]


base-lu-partial, A: 0.01, B: 0.91, Retain: 0.91


Epoch 1/3: 100%|██████████| 625/625 [00:01<00:00, 425.55it/s, loss=0.37] 
Epoch 2/3: 100%|██████████| 625/625 [00:01<00:00, 411.88it/s, loss=0.312]
Epoch 3/3: 100%|██████████| 625/625 [00:01<00:00, 421.18it/s, loss=0.262]


base-lu, A: 0.00, B: 0.00, Retain: 0.78


Epoch 1/3: 100%|██████████| 157/157 [00:00<00:00, 421.89it/s, loss=0.488]
Epoch 2/3: 100%|██████████| 157/157 [00:00<00:00, 428.27it/s, loss=0.198]
Epoch 3/3: 100%|██████████| 157/157 [00:00<00:00, 428.36it/s, loss=0.193] 


base-A, A: 0.99, B: 0.01, Retain: 0.34


Epoch 1/3: 100%|██████████| 157/157 [00:00<00:00, 403.11it/s, loss=0.117]
Epoch 2/3: 100%|██████████| 157/157 [00:00<00:00, 402.09it/s, loss=0.0281]
Epoch 3/3: 100%|██████████| 157/157 [00:00<00:00, 426.44it/s, loss=0.0141]


base-B, A: 0.33, B: 1.00, Retain: 0.66


Epoch 1/3: 100%|██████████| 157/157 [00:00<00:00, 387.89it/s, loss=0.948]
Epoch 2/3: 100%|██████████| 157/157 [00:00<00:00, 421.71it/s, loss=0.799]
Epoch 3/3: 100%|██████████| 157/157 [00:00<00:00, 414.40it/s, loss=0.429]


base-lu-A, A: 0.85, B: 0.00, Retain: 0.34


Epoch 1/3: 100%|██████████| 157/157 [00:00<00:00, 392.16it/s, loss=0.132]
Epoch 2/3: 100%|██████████| 157/157 [00:00<00:00, 388.27it/s, loss=0.0671]
Epoch 3/3: 100%|██████████| 157/157 [00:00<00:00, 383.05it/s, loss=0.0175]


base-lu-B, A: 0.33, B: 1.00, Retain: 0.66


## Visualizations
We visualize the decision boundary learned by each model checkpoint.

In [76]:
def visualize(
    name: str,
    X: torch.Tensor,
    y: torch.Tensor,
    n_grid: int = 100,
    n_samples: int = None,
    output_path: Path = None,
    axis_limits: tuple = (-1000, 1000),
):
    """Plot the decision regions of checkpoint ``name`` on a 2-D grid.

    The model is evaluated on an ``n_grid x n_grid`` mesh and the output is
    drawn as filled contours split at 0.5. The null samples (``X_full[0]``)
    are scattered on top.

    Args:
        name: Key into ``model_checkpoints``.
        X: Data points, indexed as ``X[:, 0]`` / ``X[:, 1]``. Only used to
            derive the plot extent when ``axis_limits`` is ``None``.
        y: Labels matching ``X``; subsampled together with ``X`` but not drawn.
        n_grid: Number of grid points per axis.
        n_samples: If given, randomly subsample ``X``/``y`` to at most this
            many rows.
        output_path: If given, save the figure there; otherwise show it.
        axis_limits: ``(low, high)`` applied to both axes. The default keeps
            the fixed ``[-1000, 1000]`` window. Pass ``None`` to use the range
            of ``X`` padded by 1 on each side instead.

    Note:
        The figure is left open after saving; callers close it themselves.
    """
    model = model_checkpoints[name]
    model.eval()

    if n_samples is not None:
        n_samples = min(n_samples, X.size(0))
        inds = torch.randperm(X.size(0))[:n_samples]
        X = X[inds]
        y = y[inds]

    if axis_limits is None:
        # Data-driven extent; `.item()` gives plain floats for linspace/xlim.
        x_min, x_max = X[:, 0].min().item() - 1, X[:, 0].max().item() + 1
        y_min, y_max = X[:, 1].min().item() - 1, X[:, 1].max().item() + 1
    else:
        x_min, x_max = axis_limits
        y_min, y_max = axis_limits

    # "ij" is the layout torch.meshgrid used when `indexing` was omitted;
    # passing it explicitly keeps the result and avoids the warning.
    xx, yy = torch.meshgrid(
        torch.linspace(x_min, x_max, n_grid),
        torch.linspace(y_min, y_max, n_grid),
        indexing="ij",
    )
    grid = torch.stack([xx.ravel(), yy.ravel()], dim=1).to(device)
    with torch.no_grad():
        grid_out = model(grid).squeeze().cpu()

    plt.figure(figsize=(8, 6))
    plt.contourf(
        xx.cpu(),
        yy.cpu(),
        grid_out.reshape(xx.shape),
        levels=[0, 0.5, 1],
        alpha=0.2,
        cmap="coolwarm",
    )

    def plot_gaussian_ellipse(gaussian: Gaussian, n_std: float = 2.5, **kwargs):
        """Add an n-sigma ellipse of a 2-D Gaussian (mean, cov) to the current axes.

        Extra **kwargs are forwarded to matplotlib.patches.Ellipse.
        """
        import numpy as np
        from matplotlib.patches import Ellipse

        ax = plt.gca()
        mean = gaussian.mu.cpu().numpy()
        cov = gaussian.cov.cpu().numpy()
        # Eigen-decomposition of the covariance matrix
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]  # largest first
        vals, vecs = vals[order], vecs[:, order]

        # Rotation of the ellipse (deg)
        theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))

        # Full-width/height of the ellipse (factor 2 because Ellipse wants diameters)
        width, height = 2 * n_std * np.sqrt(vals)

        ellipse = Ellipse(
            xy=mean,
            width=width,
            height=height,
            angle=theta,
            facecolor="none",
            linestyle="--",
            linewidth=2,
            **kwargs,
        )
        ax.add_patch(ellipse)
        return ellipse

    def scatter(x: torch.Tensor, y: torch.Tensor, **kwargs):
        """Scatter small (s=1) points; tensors are moved to CPU first."""
        plt.scatter(
            x.cpu(),
            y.cpu(),
            s=1,
            **kwargs,
        )

    # Only the null samples are drawn; the other overlays are kept as toggles.
    scatter(X_full[0][:, 0], X_full[0][:, 1], color="blue", label="Null")
    # scatter(X_full[1][:, 0], X_full[1][:, 1], color="orange", label="A")
    # scatter(X_full[2][:, 0], X_full[2][:, 1], color="yellow", label="B")
    # scatter(X_full[3][:, 0], X_full[3][:, 1], color="red", label="Retain")

    # plot_gaussian_ellipse(
    #     gaussians[0],
    #     edgecolor="blue",
    #     label="Null",
    # )
    # plot_gaussian_ellipse(gaussians[1], edgecolor="orange", label="A")
    # plot_gaussian_ellipse(gaussians[2], edgecolor="yellow", label="B")
    # plot_gaussian_ellipse(gaussians[3], edgecolor="red", label="C")

    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.title("Decision Boundary")
    plt.legend()
    if output_path is not None:
        plt.savefig(output_path)
    else:
        plt.show()


# Output directory for the decision-boundary figures (created if missing).
base_dir = Path("./gmm_figures")
base_dir.mkdir(exist_ok=True)

# Dataset handed to `visualize`, built with both learn_A and learn_B enabled.
X, y = construct_dataset(X_full, learn_A=True, learn_B=True, n_samples=n_samples)

# One PNG per registered checkpoint, named after the checkpoint.
for name in model_checkpoints:
    visualize(
        name=name,
        X=X,
        y=y,
        n_grid=100,
        n_samples=5000,
        output_path=base_dir.joinpath(f"{name}.png"),
    )
    # Release the figure so open figures do not pile up across iterations.
    plt.close()