In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision.models import resnet18
from torch.utils.data import DataLoader, random_split
from sklearn.cluster import KMeans
from torch.optim.lr_scheduler import StepLR

In [None]:
# Choose the compute device once for the whole notebook: use the GPU when
# CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
def extract_features(model, loader):
    """Extract flattened penultimate-layer features for every sample in ``loader``.

    The feature extractor is built from all child modules of ``model`` except
    the last one, so the model's final child is assumed to be the
    classification head.

    Args:
        model: ``torch.nn.Module`` whose children are applied sequentially
            (all but the last are wrapped in ``torch.nn.Sequential``).
        loader: iterable yielding ``(images, labels)`` tensor pairs.

    Returns:
        ``(features_list, labels_list)``: two Python lists in loader iteration
        order. ``features_list`` holds one 1-D NumPy array per sample and
        ``labels_list`` one NumPy scalar per sample.

    Side effects:
        Puts ``model`` in eval mode and leaves it there.
    """
    model.eval()

    # Remove the final classification layer to get features
    feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])

    # Run on whatever device the model actually lives on; relying only on the
    # global DEVICE raises a device-mismatch error when the two differ.
    # DEVICE is used as a fallback solely for models without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(DEVICE)

    features_list = []
    labels_list = []

    with torch.no_grad():
        for images, lbls in loader:
            features = feature_extractor(images.to(device))
            # reshape (not view) so non-contiguous outputs are flattened too
            features = features.reshape(features.size(0), -1).cpu().numpy()
            features_list.extend(features)
            labels_list.extend(lbls.cpu().numpy())

    return features_list, labels_list

In [None]:
def group_data_old(loader, n_clusters):
    """Cluster the raw (flattened) images yielded by ``loader`` with KMeans.

    Args:
        loader: iterable yielding ``(images, labels)`` tensor pairs.
        n_clusters: number of KMeans clusters.

    Returns:
        ``(data_tensor, labels, clusters)`` where ``data_tensor`` holds all
        images restored to their per-sample shape, ``labels`` is a list with
        one NumPy scalar per sample, and ``clusters`` is the NumPy array of
        cluster ids returned by KMeans. All three follow loader iteration order.
    """
    flat_batches, labels = [], []
    sample_shape = None

    for images, lbls in loader:
        # Remember the per-sample shape so it can be restored after clustering
        # instead of hard-coding (3, 32, 32).
        sample_shape = tuple(images.shape[1:])
        flat_batches.append(images.reshape(images.size(0), -1).cpu().numpy())
        labels.extend(lbls.cpu().numpy())

    # One contiguous (N, D) matrix: torch.tensor() on a Python list of
    # ndarrays is a very slow conversion path.
    data_matrix = np.concatenate(flat_batches, axis=0)

    # fit_predict avoids a second assignment pass over the same data.
    clusters = KMeans(n_clusters=n_clusters, n_init=10).fit_predict(data_matrix)

    data_tensor = torch.from_numpy(data_matrix).view(-1, *sample_shape)
    return data_tensor, labels, clusters


In [None]:
def group_data(model, loader, n_clusters=100):
    """Cluster the samples behind ``loader`` with KMeans on model features.

    Features are extracted through a *non-shuffled* loader over
    ``loader.dataset``. Iterating ``loader`` itself would, when it shuffles,
    give features/labels/clusters in a random order that no longer matches the
    dataset-ordered ``data_tensor`` returned here.

    Args:
        model: network passed to ``extract_features``.
        loader: ``DataLoader`` whose ``dataset`` yields ``(image, label)``
            samples; its batch size, worker count and collate function are
            reused for the ordered pass.
        n_clusters: number of KMeans clusters.

    Returns:
        ``(data_tensor, labels, clusters)``, index-aligned in dataset order:
        ``data_tensor`` is the stacked sample images, ``labels`` is the label
        list returned by ``extract_features``, and ``clusters`` is a
        ``torch.long`` tensor of cluster ids.
    """
    # Deterministic, dataset-ordered pass so every output shares one index.
    ordered_loader = DataLoader(
        loader.dataset,
        batch_size=loader.batch_size or 1,
        shuffle=False,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
    )

    # Extract features from the model
    features, labels = extract_features(model, ordered_loader)

    # Use KMeans to cluster the extracted features (single fit + assignment)
    feature_matrix = np.stack(features)
    cluster_ids = KMeans(n_clusters=n_clusters, n_init=10).fit_predict(feature_matrix)
    clusters = torch.tensor(cluster_ids, dtype=torch.long)

    # Stack the images in the same dataset order. torch.stack already yields
    # (N, *image_shape), so no hard-coded reshape is needed.
    data_tensor = torch.stack([sample[0] for sample in loader.dataset])

    return data_tensor, labels, clusters

In [None]:
def unlearning(net, retain_loader, forget_loader, n_clusters=100):
    """Per-sample (non-vectorized) unlearning pass.

    NOTE: a later cell defines another ``unlearning`` that takes an extra
    ``validation_loader`` argument; once that cell runs it shadows this one.

    For every forget batch:
      1. take one SGD step towards labels that are guaranteed to differ from
         the true ones (true label plus a random non-zero offset, modulo the
         number of classes);
      2. for each sample in the batch, look up the cluster of the first
         retained sample sharing its class and take one SGD step on all
         retained samples of that cluster with their true labels.

    Args:
        net: model to update in place.
        retain_loader: loader over the data to keep; passed to ``group_data``.
        forget_loader: iterable of ``(inputs, targets)`` batches to forget.
        n_clusters: number of KMeans clusters used by ``group_data``.

    Raises:
        ValueError: if a forget-batch class does not occur among the retained
            labels (raised by ``list.index``).

    Side effects:
        Updates ``net`` in place and leaves it in eval mode.
    """
    data, labels, clusters = group_data(net, retain_loader, n_clusters)
    # Build the label tensor once instead of once per forget sample.
    labels_tensor = torch.as_tensor(np.asarray(labels), dtype=torch.long)

    epochs = 1
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
    net.train()

    for _ in range(epochs):
        for inputs, targets in forget_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            # Phase 1: push the forget batch towards wrong labels.
            optimizer.zero_grad()
            outputs = net(inputs)
            # Class count comes from the logits rather than a hard-coded 10.
            num_classes = outputs.size(1)
            offsets = torch.randint(1, num_classes, targets.size(), device=DEVICE)
            random_targets = (targets + offsets) % num_classes
            loss = criterion(outputs, random_targets)
            loss.backward()
            optimizer.step()

            # Phase 2: one retain-set step per forget sample.
            for target in targets.tolist():
                cluster_label = clusters[labels.index(target)]
                # Vectorized membership lookup (ascending indices, same as the
                # previous Python-level scan over every cluster id).
                same_cluster_indices = (clusters == cluster_label).nonzero(as_tuple=True)[0]
                same_cluster_data = data[same_cluster_indices].to(DEVICE)
                same_cluster_labels = labels_tensor[same_cluster_indices].to(DEVICE)

                optimizer.zero_grad()
                output = net(same_cluster_data)
                loss = criterion(output, same_cluster_labels)
                loss.backward()
                optimizer.step()
    net.eval()

In [None]:
def _evaluate(net, loader, criterion):
    """Return ``(mean per-batch loss, accuracy in percent)`` of ``net`` on ``loader``.

    Puts ``net`` in eval mode and leaves it there. Both values are NaN when
    the loader yields nothing, instead of raising ZeroDivisionError.
    """
    net.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # no gradients needed for evaluation
        for inputs, targets in loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = net(inputs)
            total_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    num_batches = len(loader)
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    accuracy = 100. * correct / total if total else float('nan')
    return avg_loss, accuracy


def unlearning(net, retain_loader, forget_loader, validation_loader, n_clusters=100):
    """Vectorized unlearning pass followed by a validation report per epoch.

    For every forget batch:
      1. take one clipped SGD step towards labels that are guaranteed to
         differ from the true ones (true label plus a random non-zero offset,
         modulo the number of classes);
      2. map each class in the batch to the cluster of the first retained
         sample of that class, and for each distinct cluster take one clipped
         SGD step on all retained samples in it with their true labels.

    The cosine schedule is configured with ``T_max=epochs``, so it is stepped
    once per epoch. Stepping it after every optimizer step (as before) with
    ``T_max=1`` makes the learning rate cycle between its base value and 0.

    Args:
        net: model to update in place.
        retain_loader: loader over the data to keep; passed to ``group_data``.
        forget_loader: iterable of ``(inputs, targets)`` batches to forget.
        validation_loader: loader used for the per-epoch loss/accuracy print.
        n_clusters: number of KMeans clusters used by ``group_data``.

    Raises:
        ValueError: if a forget-batch class does not occur among the retained
            labels (raised by ``list.index``).

    Side effects:
        Updates ``net`` in place, prints one line per epoch, and leaves
        ``net`` in eval mode (matching the non-vectorized variant).
    """
    data, labels, clusters = group_data(net, retain_loader, n_clusters)
    data, clusters = data.to(DEVICE), clusters.to(DEVICE)
    # Loop-invariant: build the retained-label tensor once, not once per cluster.
    labels_tensor = torch.as_tensor(np.asarray(labels), dtype=torch.long, device=DEVICE)

    epochs = 1
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for ep in range(epochs):
        net.train()
        for inputs, targets in forget_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            # First Optimization Phase: push the forget batch to wrong labels.
            optimizer.zero_grad()
            outputs = net(inputs)
            # Class count comes from the logits rather than a hard-coded 10.
            num_classes = outputs.size(1)
            offsets = torch.randint(1, num_classes, targets.size(), device=DEVICE)
            random_targets = (targets + offsets) % num_classes
            loss = criterion(outputs, random_targets)
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), max_norm=1)
            optimizer.step()

            # Map each distinct class in the batch to a cluster id. Only the
            # unique cluster ids are used below, so one lookup per class gives
            # the same (sorted) set as one lookup per sample.
            first_positions = torch.tensor(
                [labels.index(cls) for cls in targets.unique().tolist()],
                dtype=torch.long,
                device=clusters.device,
            )
            cluster_ids = clusters[first_positions].unique()

            for cluster_id in cluster_ids:
                member_idx = (clusters == cluster_id).nonzero(as_tuple=True)[0]
                same_cluster_data = data.index_select(0, member_idx)
                same_cluster_labels = labels_tensor[member_idx]

                # Second Optimization Phase: reinforce the retained cluster.
                optimizer.zero_grad()
                output = net(same_cluster_data)
                loss = criterion(output, same_cluster_labels)
                loss.backward()
                nn.utils.clip_grad_norm_(net.parameters(), max_norm=1)
                optimizer.step()

        # One scheduler step per epoch, consistent with T_max=epochs.
        scheduler.step()

        avg_val_loss, val_accuracy = _evaluate(net, validation_loader, criterion)
        print(f"Epoch {ep + 1}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

In [None]:
# Loading CIFAR-10 dataset
# ToTensor scales pixels to [0, 1]; Normalize with mean=std=0.5 then maps
# each channel to [-1, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

cifar10_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Splitting the CIFAR-10 training set into retain and forget subsets.
# A seeded generator keeps the retain/forget partition identical across
# kernel restarts; an unseeded random_split changes it on every run.
SPLIT_SEED = 42
retain_dataset, forget_dataset = random_split(
    cifar10_dataset,
    [40000, 10000],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

retain_loader = DataLoader(retain_dataset, batch_size=64, shuffle=True, num_workers=4)
forget_loader = DataLoader(forget_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:10<00:00, 16032963.53it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Sanity check: fetch a single mini-batch from the retain loader and show
# the shape of its image tensor.
data_iter = iter(retain_loader)
first_batch = next(data_iter)
images, labels = first_batch
print(images.shape)

torch.Size([64, 3, 32, 32])


In [None]:
# Start from an ImageNet-pretrained ResNet-18.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` - confirm the installed version before switching.
net = resnet18(pretrained=True)

# Swap the classifier head for a 10-way layer (CIFAR-10).
# NOTE(review): this new head is randomly initialised and no cell shown here
# trains it before `unlearning` is called.
num_ftrs = net.fc.in_features
net.fc = nn.Linear(in_features=num_ftrs, out_features=10)

net.to(DEVICE)

In [None]:
# Run the unlearning procedure; the test split is used for the per-epoch
# validation report.
unlearning(
    net,
    retain_loader,
    forget_loader,
    validation_loader=val_loader,
)

Epoch 1, Validation Loss: 4.8504, Validation Accuracy: 10.97%
