<a href="https://colab.research.google.com/github/XinGuu/Gradient-Subspace-Distance/blob/main/gsd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Prepare public and private datasets

In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset, Dataset

# Number of examples taken from the front of each dataset; also used as the
# loader batch size, so each loader yields one batch holding every example.
n_samples = 2000
dataset_slice = list(range(n_samples))

# Resize to 224x224, convert to a tensor, then normalize every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

num_classes = 10

# Private data: first `n_samples` examples of the CIFAR-10 training split.
private_dataset = Subset(
    datasets.CIFAR10(root='./data', train=True, download=True, transform=transform),
    dataset_slice,
)
private_loader = DataLoader(private_dataset, batch_size=n_samples, shuffle=False)

# Public data: first `n_samples` examples of the CIFAR-100 test split.
public_dataset = Subset(
    datasets.CIFAR100(root='./data', train=False, download=True, transform=transform),
    dataset_slice,
)
public_loader = DataLoader(public_dataset, batch_size=n_samples, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


## Prepare model

In [None]:
from torchvision import models
import torch.nn as nn
from opacus import GradSampleModule


# Pretrained ResNet-152 with every existing parameter frozen.
model = models.resnet152(pretrained=True)
for param in model.parameters():
    param.requires_grad = False

# Swap in a fresh linear head sized for `num_classes`; being newly created,
# its parameters are the only ones that still require gradients.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

# Wrap the network in Opacus' GradSampleModule and move it to the target device.
model = GradSampleModule(model).to(device)

## Compute per-sample gradients

In [None]:
def flatten_tensor(tensor_list):
    for i in range(len(tensor_list)):
        tensor_list[i] = tensor_list[i].reshape([tensor_list[i].shape[0], -1])
    flatten_param = torch.cat(tensor_list, dim=1)
    del tensor_list
    return flatten_param


def compute_per_sample_grads(model, dataloader, criterion):
    """Collect flattened per-sample gradients of the trainable parameters.

    The model is switched to eval mode and its gradients are zeroed. For each
    batch, the labels supplied by the dataloader are replaced with random
    class indices drawn by ``torch.randint`` from ``[0, num_classes)``; one
    forward/backward pass is run, and the ``grad_sample`` attribute of every
    parameter with ``requires_grad`` set is flattened and concatenated.

    Relies on the module-level names ``device``, ``num_classes`` and
    ``flatten_tensor``.

    Args:
        model: Module whose trainable parameters expose ``grad_sample`` after
            ``backward()`` (presumably an Opacus ``GradSampleModule`` wrapper;
            confirm against the caller).
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        criterion: Loss callable invoked as ``criterion(pred, labels)``.

    Returns:
        A 2-D tensor with one row per sample and one column per trainable
        parameter element (n x p), or ``None`` if the dataloader yields no
        batches.
    """
    model.eval()
    model.zero_grad()

    # Gather per-batch matrices and concatenate once at the end; stacking onto
    # a running tensor would re-copy all earlier rows on every batch.
    batch_grads = []

    for batch in dataloader:
        inputs, labels = (t.to(device) for t in batch)
        # Real labels are discarded; only their shape and device are reused.
        labels = torch.randint(high=num_classes, size=labels.shape, device=labels.device)
        loss = criterion(model(inputs), labels)
        loss.backward()

        cur_batch_grad_list = []
        for p in model.parameters():
            if not p.requires_grad:
                continue
            cur_batch_grad_list.append(p.grad_sample.reshape(p.grad_sample.shape[0], -1))
            # Drop the stored gradients so they do not carry into the next batch.
            del p.grad_sample, p.grad

        batch_grads.append(flatten_tensor(cur_batch_grad_list))

    if not batch_grads:
        return None
    if len(batch_grads) == 1:
        # Single batch: return it directly and avoid an extra copy.
        return batch_grads[0]
    return torch.cat(batch_grads, dim=0)  # n x p


# The same cross-entropy loss is used for both gradient passes.
criterion = nn.CrossEntropyLoss()

# Per-sample gradient matrices (one row per example) for each dataset.
private_per_sample_grads = compute_per_sample_grads(
    model=model,
    dataloader=private_loader,
    criterion=criterion,
)
public_per_sample_grads = compute_per_sample_grads(
    model=model,
    dataloader=public_loader,
    criterion=criterion,
)

## Compute GSD

In [None]:
import numpy as np

# Number of leading right singular vectors kept for each gradient matrix.
k = 16

# Reduced SVD of the private gradients; the first k rows of Vh, conjugate-
# transposed, become the columns of the private basis.
Vh_private = torch.linalg.svd(private_per_sample_grads, full_matrices=False).Vh
V_private_k = Vh_private[:k].conj().transpose(-2, -1)

# Same construction for the public gradients.
Vh_public = torch.linalg.svd(public_per_sample_grads, full_matrices=False).Vh
V_public_k = Vh_public[:k].conj().transpose(-2, -1)

def principle_angle(v1, v2):
    """Return the singular values of ``v1^H @ v2``.

    When the columns of ``v1`` and ``v2`` are orthonormal bases of two
    subspaces, these singular values are the cosines of the principal angles
    between those subspaces.

    Args:
        v1: Matrix whose columns span the first subspace.
        v2: Matrix whose columns span the second subspace.

    Returns:
        1-D tensor of singular values, as returned by
        ``torch.linalg.svdvals``.
    """
    # Only the singular values are needed, so skip computing U and Vh.
    return torch.linalg.svdvals(v1.conj().transpose(-2, -1) @ v2)


def projection_metric(v1, v2):
    """Return the normalized projection metric between two subspaces.

    With ``s`` the values returned by ``principle_angle(v1, v2)`` and ``k``
    their count, the result is ``sqrt(k - sum(s**2)) / sqrt(k)``. When ``v1``
    and ``v2`` have orthonormal columns this lies in ``[0, 1]``: 0 for
    identical subspaces and 1 for mutually orthogonal ones.

    Args:
        v1: Matrix whose columns span the first subspace.
        v2: Matrix whose columns span the second subspace.

    Returns:
        Zero-dimensional tensor holding the distance.
    """
    cosines = principle_angle(v1, v2)
    k = len(cosines)
    # Rounding can push sum(cosines**2) slightly above k for (near-)identical
    # subspaces; clamp so the square root yields 0 instead of NaN.
    residual = torch.clamp(k - torch.sum(cosines ** 2), min=0.0)
    return torch.sqrt(residual) / k ** 0.5

# Distance between the private and public top-k subspaces, as a Python float.
gsd = float(projection_metric(V_private_k, V_public_k))
print(f"GSD: {gsd:.4f}")