<a href="https://colab.research.google.com/github/XinGuu/Gradient-Subspace-Distance/blob/main/gsd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

# Run on the CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Prepare public and private datasets

In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset, Dataset

n_samples = 2000   # examples kept from each dataset
batch_size = 200   # examples per DataLoader batch

transform = transforms.Compose([transforms.ToTensor()])

num_classes = 10


def _random_subset(full_dataset):
    """Keep the first ``n_samples`` indices of a fresh ``torch.randperm`` permutation."""
    keep = torch.randperm(len(full_dataset))[:n_samples]
    return Subset(full_dataset, keep)


def _sequential_loader(subset):
    """Batch ``subset`` in a fixed order (no shuffling)."""
    return DataLoader(subset, batch_size=batch_size, shuffle=False)


# Private data: a random slice of the CIFAR-10 training split.
private_dataset = _random_subset(
    datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
)
private_loader = _sequential_loader(private_dataset)

# Public data: a random slice of the CIFAR-100 test split.
public_dataset = _random_subset(
    datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
)
public_loader = _sequential_loader(public_dataset)

## Prepare model

In [None]:
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """Small convolutional classifier for 3-channel images.

    The backbone has three conv -> GroupNorm -> ReLU stages (the last two use
    stride 2). It is followed by an adaptive 2x2 average pool, so the head always
    receives ``64 * 2 * 2`` features whatever the input height and width. The head
    is a two-layer MLP that returns one unnormalised score per class.

    Args:
        num_classes: width of the final linear layer.
        dense_size: width of the hidden layer in the head.
    """

    def __init__(self, num_classes: int = 10, dense_size: int = 32):
        super().__init__()

        def conv_stage(in_ch, out_ch, groups, stride=1):
            # 3x3 conv with padding 1; stride 2 halves the spatial size.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride),
                nn.GroupNorm(groups, out_ch),
                nn.ReLU(),
            )

        self.backbone = nn.Sequential(
            conv_stage(3, 16, 8),             # (3,H,W)  -> (16,H,W)
            conv_stage(16, 32, 8, stride=2),  # -> (32,H/2,W/2)
            conv_stage(32, 64, 16, stride=2), # -> (64,H/4,W/4)
            nn.AdaptiveAvgPool2d(2),          # -> (64,2,2)
            nn.Flatten(),
        )

        pooled_features = 64 * 2 * 2
        self.head = nn.Sequential(
            nn.Linear(pooled_features, dense_size),
            nn.ReLU(),
            nn.Linear(dense_size, num_classes),
        )

    def forward(self, x):
        return self.head(self.backbone(x))

# Build the classifier for `num_classes` outputs and place it on the selected device
# (nn.Module.to moves the module in place and returns it).
model = TinyCNN(num_classes).to(device)

## Compute per-sample gradients

In [None]:
from torch.func import grad, vmap


def compute_per_sample_grads(model, dataloader, criterion, *, num_label_classes=None):
    """Return the flattened loss gradient of every sample in ``dataloader``.

    All trainable parameters of ``model`` are concatenated into one flat vector.
    ``vmap(grad(loss_fn))`` then differentiates ``criterion`` with respect to that
    vector separately for each sample, feeding each sample to the model as a batch of one.

    NOTE: the labels produced by ``dataloader`` are discarded. They are replaced with
    labels drawn uniformly at random from ``[0, num_label_classes)``. This keeps the
    loss valid when the data's label space differs from the model's (e.g. CIFAR-100
    samples fed to a 10-class model). It also consumes the global torch RNG.

    Args:
        model: module whose trainable (``requires_grad``) parameters are
            differentiated. It is switched to eval mode and its ``.grad`` fields are cleared.
        dataloader: iterable of ``(inputs, labels)`` batches. Only ``labels.shape`` is used.
        criterion: loss callable taking ``(predictions, labels)``.
        num_label_classes: exclusive upper bound for the random labels. Defaults
            to the module-level ``num_classes``.

    Returns:
        Tensor of shape ``(n_samples, n_params)``, with rows in dataloader order. It is
        detached from autograd and lives on the device of the model's parameters.
        An empty ``dataloader`` yields a ``(0, n_params)`` tensor.
    """
    model.eval()
    model.zero_grad()

    label_high = num_classes if num_label_classes is None else num_label_classes

    # Only trainable parameters are passed to functional_call; any others keep the
    # module's own values.
    named_parameters = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    # Detach: the function transform only needs the values. Without it, every returned
    # gradient would carry an autograd graph back to the live model parameters.
    flat_params = torch.cat([p.detach().flatten() for _, p in named_parameters])
    param_device = flat_params.device

    def reconstruct_named_params(flat_params, named_parameters):
        """Slice the flat vector back into a {name: tensor} map shaped like the originals."""
        offset = 0
        ndict = {}
        for name, p in named_parameters:
            numel = p.numel()
            ndict[name] = flat_params[offset:offset + numel].view_as(p)
            offset += numel
        return ndict

    def loss_fn(params, X, Y):
        """Loss of ``model`` evaluated with the flat parameter vector ``params``."""
        named_params_map = reconstruct_named_params(params, named_parameters)
        preds = torch.func.functional_call(model, named_params_map, (X,))
        return criterion(preds, Y)

    # Gradient w.r.t. the flat parameters, vectorised over dim 0 of inputs and labels.
    # The parameters are shared by all samples, hence in_dims=None for them.
    grad_fn = vmap(grad(loss_fn), in_dims=(None, 0, 0))

    # Collect per-batch results and concatenate once at the end. Concatenating on every
    # batch would re-copy the growing result each time.
    chunks = []
    for inputs, labels in dataloader:
        inputs = inputs.to(param_device)
        labels = torch.randint(high=label_high, size=labels.shape, device=param_device)
        # Insert a size-1 batch dim so vmap hands each sample to the model as a (1, ...) batch.
        chunks.append(grad_fn(flat_params, inputs.unsqueeze(1), labels.unsqueeze(1)))

    if not chunks:
        return flat_params.new_empty((0, flat_params.numel()))
    return torch.cat(chunks, dim=0)  # n x p


# Summed cross-entropy. compute_per_sample_grads applies it to one sample at a time.
criterion = nn.CrossEntropyLoss(reduction='sum')
# Private loader first, then public (same order, hence same RNG use, as separate calls).
private_per_sample_grads, public_per_sample_grads = (
    compute_per_sample_grads(model, loader, criterion)
    for loader in (private_loader, public_loader)
)

## Compute the Gradient Subspace Distance (GSD)

In [None]:
import numpy as np

k = 16  # number of leading right singular vectors kept for each gradient subspace

# torch.linalg.svd returns (U, S, Vh), and only Vh is needed. Its rows are the right
# singular vectors, ordered by decreasing singular value. The conjugate transpose
# (.mH) turns them into columns; the first k columns are an orthonormal basis of the
# top-k gradient subspace.
Vh_private = torch.linalg.svd(private_per_sample_grads, full_matrices=False)[-1]
V_private_k = Vh_private.mH[:, :k]

Vh_public = torch.linalg.svd(public_per_sample_grads, full_matrices=False)[-1]
V_public_k = Vh_public.mH[:, :k]

def principle_angle(v1, v2):
    """Return the cosines of the principal angles between two column spaces.

    (The function name keeps its original spelling for backward compatibility.)

    Args:
        v1: ``(..., p, k1)`` matrix, expected to have orthonormal columns.
        v2: ``(..., p, k2)`` matrix, expected to have orthonormal columns.

    Returns:
        The ``min(k1, k2)`` singular values of ``v1^H @ v2`` in descending order.
        For orthonormal inputs these are ``cos(theta_i)`` of the principal angles.
        Leading batch dimensions are preserved.
    """
    # Only the singular values are needed; svdvals skips computing U and Vh.
    return torch.linalg.svdvals(v1.conj().transpose(-2, -1) @ v2)


def projection_metric(v1, v2):
    """Projection-metric distance between two subspaces, scaled to [0, 1].

    Computes ``sqrt(k - sum_i cos^2(theta_i)) / sqrt(k)``, where ``k`` is the number
    of principal angles. The result is 0 when all cosines are 1 (identical
    subspaces) and 1 when all cosines are 0 (orthogonal subspaces).

    Args:
        v1, v2: matrices with orthonormal columns, as accepted by ``principle_angle``.
            Leading batch dimensions are supported.

    Returns:
        Tensor holding the distance (one value per batch element).
    """
    cosines = principle_angle(v1, v2)
    k = cosines.shape[-1]
    # Rounding error can push singular values slightly above 1 for (nearly) identical
    # subspaces. That makes k - sum(cos^2) marginally negative, and sqrt would return
    # NaN, so clamp the residual at 0.
    residual = (k - torch.sum(cosines ** 2, dim=-1)).clamp_min(0)
    return torch.sqrt(residual) / np.sqrt(k)

# Distance between the top-k private and public gradient subspaces
# (0 when all principal-angle cosines are 1, 1 when they are all 0).
distance = projection_metric(V_private_k, V_public_k)
gsd = float(distance)
print(f"GSD: {gsd:.4f}")