In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

from sklearn.metrics import confusion_matrix
import numpy as np

from tqdm import tqdm

  Referenced from: <CFED5F8E-EC3F-36FD-AAA3-2C6C7F8D3DD9> /opt/anaconda3/envs/rai24-iitm-project/lib/python3.11/site-packages/torchvision/image.so
  warn(


In [2]:
# DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
DEVICE = torch.device('cpu')
DEVICE

device(type='cpu')

In [3]:
# CNN architecture
class SimpleCNN(nn.Module):
    def __init__(self, mask=None):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)
        self.mask = mask

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        if self.mask is not None:
            x = x * self.mask
        x = self.fc2(x)
        return x
    
# Function to evaluate the model
def evaluate_model(model, test_loader, criterion):
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
    accuracy = correct / len(test_loader.dataset)
    average_loss = total_loss / len(test_loader)
    return accuracy, average_loss

# Function to print confusion matrix
def print_cf(model, test_loader):
    model.eval()
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            all_labels.extend(labels.numpy())
            all_predictions.extend(predicted.numpy())

    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    print(f"Confusion Matrix:\n{cm}")

In [4]:
# Transformations for the training and testing sets
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Loading MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Creating data loaders
train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=200, shuffle=False)

# Sending data to DEVICE
for images, labels in train_loader:
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)
for images, labels in test_loader:
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

# Model, loss function and optimizer initialization
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

### Model Training

In [5]:
# Model training
def train_model(model, train_loader, criterion, optimizer, epochs=5, save_model=False):
    model.train()
    for epoch in range(epochs):
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}')
    if save_model:
        torch.save(model.state_dict(), './models/mnist_cnn.pth')

train_model(model, train_loader, criterion, optimizer, epochs=1, save_model=True)

Epoch 1/1, Loss: 0.0463


In [6]:
# Evaluating trained model
learned_accuracy, learned_avg_loss = evaluate_model(model, test_loader, criterion)
print(f'Learned Model Accuracy: {learned_accuracy:.4f}\nLearned Model Average Loss: {learned_avg_loss:.4f}')
print_cf(model, test_loader)

Learned Model Accuracy: 0.9818
Learned Model Average Loss: 0.0534
Confusion Matrix:
[[ 974    0    0    0    0    3    2    1    0    0]
 [   0 1126    5    1    0    0    0    2    1    0]
 [   1    0 1013    0    0    0    0   15    3    0]
 [   0    0    2  988    0   12    0    4    3    1]
 [   0    1    5    0  967    0    0    2    0    7]
 [   0    0    2    5    0  883    2    0    0    0]
 [   8    4    5    0    2    5  933    0    1    0]
 [   0    1    4    1    0    1    0 1018    2    1]
 [   8    1    4    1    2    4    0    6  936   12]
 [   1    0    0    0    9    8    0    9    2  980]]


### Machine Unlearning

In [7]:
# Loading trained model
trained_model = SimpleCNN()
trained_model.load_state_dict(torch.load('./models/mnist_cnn.pth'))
trained_model.to(DEVICE)

# Evaluating trained model
learned_accuracy, learned_avg_loss = evaluate_model(trained_model, test_loader, criterion)
print(f'Learned Model Accuracy: {learned_accuracy:.4f}\nLearned Model Average Loss: {learned_avg_loss:.4f}')
print_cf(model, test_loader)

Learned Model Accuracy: 0.9818
Learned Model Average Loss: 0.0534
Confusion Matrix:
[[ 974    0    0    0    0    3    2    1    0    0]
 [   0 1126    5    1    0    0    0    2    1    0]
 [   1    0 1013    0    0    0    0   15    3    0]
 [   0    0    2  988    0   12    0    4    3    1]
 [   0    1    5    0  967    0    0    2    0    7]
 [   0    0    2    5    0  883    2    0    0    0]
 [   8    4    5    0    2    5  933    0    1    0]
 [   0    1    4    1    0    1    0 1018    2    1]
 [   8    1    4    1    2    4    0    6  936   12]
 [   1    0    0    0    9    8    0    9    2  980]]


In [8]:
# Choosing target class ('5') and obtaning its indices
target_class = 5
target_indices = [i for i, label in enumerate(train_dataset.targets) if label == target_class]
target_loader = DataLoader(Subset(train_dataset, target_indices), batch_size=1, shuffle=True)

for images, labels in target_loader:
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

In [9]:
# Function to identify neurons to mask
def identify_neurons_to_mask(model, data_loader, layer_name, threshold=0.5):
    layer_activations = []
    
    def hook_fn(module, input, output):
        layer_activations.append(output.detach().cpu())

    handle = getattr(model, layer_name).register_forward_hook(hook_fn)

    # Forward pass on filtered data
    model.eval()
    with torch.no_grad():
        for images, _ in tqdm(data_loader):
            images = images.to(DEVICE)
            model(images)

    handle.remove()
    
    # Aggregate activations
    activations = torch.cat(layer_activations, dim=0)
    avg_activation = torch.mean(activations, dim=0)
    
    # Identifying neurons to mask
    mask = avg_activation < threshold
    return mask

# Identifying neurons to mask in the first fully connected layer
mask = identify_neurons_to_mask(trained_model, target_loader, 'fc1')
unl_model = SimpleCNN(mask=mask).to(DEVICE)
unl_model.load_state_dict(trained_model.state_dict())

100%|██████████| 5421/5421 [00:27<00:00, 199.58it/s]


<All keys matched successfully>

In [10]:
# # Function to unlearn target class using its data points
# def unlearn_data_points(model, data_loader, data_indices, criterion, optimizer, unlearning_rate=0.001):
#     unl_model = SimpleCNN()
#     unl_model.load_state_dict(model.state_dict())
#     unl_model.train()
#     for idx in tqdm(data_indices):
#         images, labels = data_loader.dataset[idx]
#         images = images.unsqueeze(0)  # Add batch dimension
#         labels = torch.tensor([labels])
        
#         optimizer.zero_grad()
#         outputs = unl_model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()
        
#         # Negative gradient to "unlearn"
#         with torch.no_grad():
#             for param in unl_model.parameters():
#                 param -= unlearning_rate * param.grad
#     return unl_model

# unl_model = unlearn_data_points(trained_model, train_loader, target_indices, criterion, optimizer)

In [11]:
# Evaluating unlearned model
accuracy, avg_loss = evaluate_model(unl_model, test_loader, criterion)
print(f'Accuracy after unlearning: {accuracy:.4f}\nAverage Loss: {avg_loss:.4f}')
print_cf(unl_model, test_loader)

Accuracy after unlearning: 0.8612
Average Loss: 0.4189
Confusion Matrix:
[[ 976    0    1    0    0    0    0    3    0    0]
 [   0 1126    5    0    0    0    0    3    1    0]
 [   1    1 1018    0    0    0    0   12    0    0]
 [   0    0   85  891    0    0    0   29    4    1]
 [   0    3    5    0  972    0    0    2    0    0]
 [  77  154   79  139  162    0  104   73   93   11]
 [  14    5   11    0   10    0  918    0    0    0]
 [   0    1    4    0    0    0    0 1023    0    0]
 [  57    0   49    0    8    0    0   13  843    4]
 [   2    2   15    0   56    0    0   85    4  845]]
