# GMM Classification Experiments 3


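We experiment with 2D logistic models (with quadratic features) trained on a four-component Gaussian mixture: a broad null component, two tasks to unlearn (task A and task B), and a retain component.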

In [1]:
%load_ext autoreload
%autoreload 2


import os

# Pin the notebook to a single GPU. Set before importing torch so it takes
# effect before CUDA is initialized. NOTE(review): index "5" is specific to the
# original machine — adjust or remove for other setups.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import torch
from torch import nn
from torch import optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from copy import deepcopy
import matplotlib.pyplot as plt
from layered_unlearning.utils import set_seed
from layered_unlearning.gmm_classification import (
    Gaussian,
    LogisticModel,
)
from pathlib import Path

# Seed the RNGs through the project helper. Presumably this seeds torch (and
# possibly numpy/random) — TODO confirm in layered_unlearning.utils.set_seed.
seed = set_seed(0)

## Scripts
Below are helper functions for training and evaluating our models. Relearning reuses the same training function, restricted to the data of the task being relearned.

In [2]:
def evaluate(
    model: nn.Module,
    X: torch.Tensor,
    y: torch.Tensor,
    device: str = "cuda",
):
    """Return the accuracy of `model` on (X, y) as a Python float.

    The model output is squeezed and thresholded at 0.5; a prediction counts as
    correct when it equals the corresponding entry of `y`. Both tensors are
    moved to `device` first. The model is left in eval mode.
    """
    inputs, targets = X.to(device), y.to(device)

    model.eval()
    with torch.no_grad():
        probs = model(inputs).squeeze()
        hits = (probs > 0.5).float() == targets
        return hits.float().mean().item()


def train(
    model: nn.Module,
    X: torch.Tensor,
    y: torch.Tensor,
    n_epochs: int = 1,
    lr: float = 0.01,
    batch_size: int = 32,
    weight_decay: float = 0.01,
    device: str = "cuda",
    eps: float = 1e-8,
):
    """
    Train `model` as a binary classifier on (X, y) with AdamW and return it.

    The loss is binary cross-entropy written out by hand, so the model is
    expected to output probabilities in [0, 1]. NOTE(review): this relies on
    the model applying its own sigmoid — confirm against LogisticModel.

    Args:
        model: module to optimise in place; assumed to already be on `device`.
        X: input features, one row per sample.
        y: 0/1 targets, one per row of X (cast to float per batch).
        n_epochs: number of passes over the data.
        lr: AdamW learning rate.
        batch_size: mini-batch size; batches are reshuffled every epoch.
        weight_decay: AdamW weight decay.
        device: device X and y are moved to before training.
        eps: added inside both logs so the loss stays finite when p is 0 or 1.

    Returns:
        The same `model` instance, after training.
    """
    # Move data to the training device (inputs are already tensors).
    X = X.to(device)
    y = y.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # TensorDataset yields the same (row, label) pairs as list(zip(X, y)) without
    # building one Python tuple per sample. It also rejects an X/y length
    # mismatch instead of silently truncating to the shorter one.
    dataloader = DataLoader(
        torch.utils.data.TensorDataset(X, y),
        batch_size=batch_size,
        shuffle=True,
    )

    for epoch in range(n_epochs):
        model.train()
        for batch_X, batch_y in (
            pbar := tqdm(dataloader, desc=f"Epoch {epoch + 1}/{n_epochs}")
        ):
            optimizer.zero_grad()
            outputs = model(batch_X).squeeze()
            batch_y = batch_y.float()

            # Manual BCE: -[y log(p + eps) + (1 - y) log(1 - p + eps)], batch mean.
            loss = -(
                batch_y * torch.log(outputs + eps)
                + (1 - batch_y) * torch.log(1 - outputs + eps)
            ).mean()
            loss.backward()
            optimizer.step()
            pbar.set_postfix({"loss": loss.item()})

    return model

## Hyperparameters
Default hyperparameters for our experiments. Note that all experiments are in 2 dimensions, which lets us plot the decision boundaries directly.

In [3]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation settings shared by every training run.
n_epochs = 2
lr = 1e-2
batch_size = 32
weight_decay = 1e-3
eps = 1e-8  # added inside the logs of the hand-written BCE loss

# Model / data settings.
n_classes = 2
dim = 2
quadratic_features = True
n_samples = 10000  # points drawn from each mixture component

# Mixture components, in index order: 0 = null, 1 = task A, 2 = task B, 3 = retain.
scale = 0.15  # isotropic variance of the three tight components
_component_specs = [
    ([3.0, 0.0], 10),  # null: broad background
    ([2.0, 1.0], scale),  # task A
    ([2.0, -1.0], scale),  # task B
    ([4.0, 0.0], scale),  # retain
]
gaussians = [
    Gaussian(mu=torch.tensor(mu), cov=torch.eye(dim) * variance)
    for mu, variance in _component_specs
]

X_full = [g.sample(n_samples) for g in gaussians]

## Training
We train the initial model, the base unlearned model, and the Layered Unlearning (LU) version of the base unlearned model. We then relearn task A and task B separately, starting from each unlearned model.

In [4]:
# Registries filled by `run`: checkpoint name -> trained model, and
# checkpoint name -> [task A, task B, retain] accuracies.
model_checkpoints, evals = dict(), dict()


def get_model(old_model: nn.Module = None):
    """Build a LogisticModel on `device`, optionally copying weights from `old_model`.

    The architecture comes from the notebook's `dim`, `n_classes` and
    `quadratic_features`. When `old_model` is given, its state dict is loaded
    into the new model, so the two share values but not parameters.
    """
    fresh = LogisticModel(
        dim=dim,
        n_classes=n_classes,
        quadratic_features=quadratic_features,
    )
    fresh = fresh.to(device)
    if old_model is None:
        return fresh
    fresh.load_state_dict(old_model.state_dict())
    return fresh


def construct_dataset(
    learn_A: bool,
    learn_B: bool,
    relearn: bool = False,
    components: list = None,
):
    """Build the (X, y) dataset for one stage of the experiment.

    Mixture components are indexed 0 = null, 1 = task A, 2 = task B, 3 = retain.

    relearn=False: every component is included, in the order null (label 0),
    retain (label 1), task A, task B. Each task is labelled 1 if it should be
    learned and 0 if it should be unlearned.

    relearn=True: only the tasks being relearned are included, all labelled 1.
    Null and retain data are left out.

    Args:
        learn_A: whether task A is labelled 1 (learn) rather than 0 / omitted.
        learn_B: whether task B is labelled 1 (learn) rather than 0 / omitted.
        relearn: build a relearning set instead of a full training set.
        components: per-component sample tensors to draw from. Defaults to the
            module-level X_full. Labels are sized from each component's row
            count, so components need not all have the same size.

    Returns:
        (X, y), where y is a float tensor of 0/1 labels aligned with the rows of X.

    Raises:
        ValueError: if nothing is selected (relearn=True with neither task
            selected). Without this check the failure would surface as an
            opaque torch.cat error.
    """
    if components is None:
        components = X_full

    # (component index, label) pairs, in the order the data is concatenated.
    selected = []
    if not relearn:
        selected += [(0, 0.0), (3, 1.0)]
    for index, learn in ((1, learn_A), (2, learn_B)):
        if learn:
            selected.append((index, 1.0))
        elif not relearn:
            selected.append((index, 0.0))

    if not selected:
        raise ValueError(
            "construct_dataset selected no data: relearn=True requires "
            "learn_A or learn_B to be True"
        )

    X = torch.cat([components[index] for index, _ in selected])
    y = torch.cat(
        [torch.full((components[index].size(0),), label) for index, label in selected]
    )
    return X, y


def global_train(model: nn.Module, learn_A: bool, learn_B: bool, relearn: bool = False):
    """Train `model` on one experiment stage's dataset with the notebook hyperparameters.

    The stage is selected by learn_A / learn_B / relearn (see construct_dataset).
    n_epochs, lr, batch_size, weight_decay, eps and device come from the
    hyperparameter cell. Returns the trained model.
    """
    features, labels = construct_dataset(learn_A=learn_A, learn_B=learn_B, relearn=relearn)
    hparams = dict(
        n_epochs=n_epochs,
        lr=lr,
        batch_size=batch_size,
        weight_decay=weight_decay,
        eps=eps,
        device=device,
    )
    return train(model, features, labels, **hparams)


def global_eval(model: nn.Module):
    """Return [acc_task_A, acc_task_B, acc_retain] for `model`.

    Each of mixture components 1, 2 and 3 is scored against all-ones labels.
    High accuracy on a task therefore means the model still classifies that
    task as positive.
    """
    all_positive = torch.ones(n_samples)
    return [
        evaluate(model, X_full[component], all_positive, device=device)
        for component in (1, 2, 3)
    ]


def visualize(
    name: str,
    X: torch.Tensor,
    y: torch.Tensor,
    n_grid: int = 100,
    n_samples: int = None,
    output_path: Path = None,
):
    """Plot the decision boundary of checkpoint `name` with the mixture components overlaid.

    Args:
        name: key into `model_checkpoints` selecting the model to plot.
        X: 2-D points that set the plot extent (their range plus a margin of 1).
        y: labels aligned with X. They are subsampled together with X but not
            otherwise used.
        n_grid: number of grid points per axis at which the model is evaluated.
        n_samples: optional point budget, or None to keep every point. When
            given:
            - X (and y) are randomly subsampled to at most this many rows before
              the extent is computed.
            - The scatter of each mixture component (from X_full) is capped at an
              even share of the budget, n_samples // len(X_full) points (at
              least 1).
        output_path: save the figure here if given; otherwise call plt.show().
            The figure is left open either way; callers close it.
    """
    model = model_checkpoints[name]
    model.eval()

    per_component = None
    if n_samples is not None:
        n_samples = min(n_samples, X.size(0))
        inds = torch.randperm(X.size(0))[:n_samples]
        X = X[inds]
        y = y[inds]
        # Split the same budget evenly across components so n_samples thins
        # the plotted points, not just the data used for the extent.
        per_component = max(1, n_samples // len(X_full))

    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    # indexing="ij" matches torch's current default and silences the
    # UserWarning torch.meshgrid emits when the argument is omitted.
    xx, yy = torch.meshgrid(
        torch.linspace(x_min, x_max, n_grid),
        torch.linspace(y_min, y_max, n_grid),
        indexing="ij",
    )
    grid = torch.stack([xx.ravel(), yy.ravel()], dim=1).to(device)
    with torch.no_grad():
        grid_out = model(grid).squeeze().cpu()

    plt.figure(figsize=(8, 6))
    # Two filled bands: model output below / above the 0.5 decision threshold.
    plt.contourf(
        xx.cpu(),
        yy.cpu(),
        grid_out.reshape(xx.shape),
        levels=[0, 0.5, 1],
        alpha=0.2,
        cmap="coolwarm",
    )

    def scatter(index: int, label: str, color: str):
        # Draw component `index` of X_full, randomly thinned to the per-component budget.
        points = X_full[index]
        if per_component is not None and points.size(0) > per_component:
            keep = torch.randperm(points.size(0), device=points.device)[:per_component]
            points = points[keep]
        plt.scatter(
            points[:, 0].cpu(),
            points[:, 1].cpu(),
            label=label,
            alpha=0.6,
            edgecolors="k",
            color=color,
        )

    scatter(0, "Null", "blue")
    scatter(1, "Task A", "orange")
    scatter(2, "Task B", "yellow")
    scatter(3, "Retain", "red")

    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.title(f"Decision Boundary ({name})")
    plt.legend()
    if output_path is not None:
        plt.savefig(output_path)
    else:
        plt.show()


def run(start: str, end: str, learn_A: bool, learn_B: bool, relearn: bool = False):
    """Run one training stage from checkpoint `start` and store the result as `end`.

    If `start` is None, a fresh model is used; otherwise `start` must already
    be a key of model_checkpoints. The model is trained for the stage selected
    by learn_A / learn_B / relearn. Its [task A, task B, retain] accuracies are
    stored in evals[end] and printed. A deep copy of the model is stored in
    model_checkpoints[end].
    """
    assert start is None or start in model_checkpoints
    parent = model_checkpoints.get(start)
    trained = global_train(
        get_model(parent), learn_A=learn_A, learn_B=learn_B, relearn=relearn
    )
    scores = global_eval(trained)
    evals[end] = scores
    print(end, scores)
    model_checkpoints[end] = deepcopy(trained)


def run_relearn(name: str):
    """Relearn task A and task B separately, each starting from checkpoint `name`.

    Results are stored as checkpoints "<name>-A" and "<name>-B", in that order.
    """
    for task, learn_A, learn_B in (("A", True, False), ("B", False, True)):
        run(name, f"{name}-{task}", learn_A=learn_A, learn_B=learn_B, relearn=True)


# Experiment pipeline. Order matters: each run starts from the checkpoint named
# by its first argument, which an earlier line must already have produced.
# init: learn both task A and task B (plus retain; null labelled 0).
run(None, "init", learn_A=True, learn_B=True)
# base: unlearn A and B together (both labelled 0), starting from init.
run("init", "base", learn_A=False, learn_B=False)
# Layered Unlearning, step 1: unlearn only A while B stays labelled 1.
run("init", "base-lu-partial", learn_A=False, learn_B=True)
# Layered Unlearning, step 2: unlearn B as well, starting from the partial model.
run("base-lu-partial", "base-lu", learn_A=False, learn_B=False)
# Relearn A and B separately from each unlearned model.
run_relearn("base")
run_relearn("base-lu")

Epoch 1/2: 100%|██████████| 1250/1250 [00:02<00:00, 568.39it/s, loss=0.0749]
Epoch 2/2: 100%|██████████| 1250/1250 [00:02<00:00, 615.81it/s, loss=0.26]  


init [0.9968999624252319, 0.9960999488830566, 0.9991999864578247]


Epoch 1/2: 100%|██████████| 1250/1250 [00:02<00:00, 554.40it/s, loss=0.21]  
Epoch 2/2: 100%|██████████| 1250/1250 [00:02<00:00, 584.34it/s, loss=0.163] 


base [0.023799998685717583, 0.026599999517202377, 0.9402999877929688]


Epoch 1/2: 100%|██████████| 1250/1250 [00:02<00:00, 556.21it/s, loss=0.153] 
Epoch 2/2: 100%|██████████| 1250/1250 [00:02<00:00, 590.66it/s, loss=0.117] 


base-lu-partial [0.04429999738931656, 0.9939999580383301, 0.9803999662399292]


Epoch 1/2: 100%|██████████| 1250/1250 [00:02<00:00, 541.11it/s, loss=0.194] 
Epoch 2/2: 100%|██████████| 1250/1250 [00:02<00:00, 558.63it/s, loss=0.229] 


base-lu [0.004999999888241291, 0.011299999430775642, 0.9434999823570251]


Epoch 1/2: 100%|██████████| 313/313 [00:00<00:00, 528.22it/s, loss=0.0155] 
Epoch 2/2: 100%|██████████| 313/313 [00:00<00:00, 595.45it/s, loss=0.00447] 


base-A [0.9967999458312988, 0.6028000116348267, 1.0]


Epoch 1/2: 100%|██████████| 313/313 [00:00<00:00, 564.99it/s, loss=0.0043] 
Epoch 2/2: 100%|██████████| 313/313 [00:00<00:00, 545.30it/s, loss=0.00293] 


base-B [0.5651999711990356, 0.9976999759674072, 1.0]


Epoch 1/2: 100%|██████████| 313/313 [00:00<00:00, 535.64it/s, loss=0.0224]
Epoch 2/2: 100%|██████████| 313/313 [00:00<00:00, 541.09it/s, loss=0.00482]


base-lu-A [0.9928999543190002, 0.6631999611854553, 1.0]


Epoch 1/2: 100%|██████████| 313/313 [00:00<00:00, 556.90it/s, loss=0.018]  
Epoch 2/2: 100%|██████████| 313/313 [00:00<00:00, 617.81it/s, loss=0.00591] 


base-lu-B [0.446399986743927, 0.9975000023841858, 1.0]


## Visualizations
We visualize the decision boundary learned by each model checkpoint and save one figure per checkpoint to `gmm3_figures/`.

In [5]:
base_dir = Path("./gmm3_figures")
base_dir.mkdir(exist_ok=True)

# Every checkpoint is plotted against the dataset in which both tasks are
# labelled 1, with a 5000-point subsample passed to visualize().
X, y = construct_dataset(learn_A=True, learn_B=True)
for name in model_checkpoints:
    visualize(
        name, X, y,
        n_grid=100, n_samples=5000,
        output_path=base_dir.joinpath(f"{name}.png"),
    )
    plt.close()  # release each figure once it has been saved

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
