<a href="https://colab.research.google.com/github/Latesh99/gru_unlearn/blob/main/GRUModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
import numpy as np

# Configuring the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
learning_rate = 0.001
epochs = 10
batch_size = 64
hidden_size = 128
num_layers = 2

# Data augmentation
transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Function to remove a specific digit
def filter_digit(dataset, digit_to_remove):
    indices = [i for i, (image, label) in enumerate(dataset) if label != digit_to_remove]
    return Subset(dataset, indices)

# Define GRU model
class GRUModel(nn.Module):
    def __init__(self, input_size=28, hidden_size=128, num_layers=2, num_classes=10):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = x.squeeze(1)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        out, _ = self.gru(x, h0)
        return self.fc(out[:, -1, :])

# Create model
def create_model():
    return GRUModel().to(device)

# Training function
def train_model(model, loader, criterion, optimizer, scheduler, history):
    model.train()
    total_loss = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, 28, 28)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step(loss)
    history.append(total_loss / len(loader))

# Evaluation function
def evaluate_model(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            images = images.view(-1, 28, 28)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

# Function to compute Forgetting Score
def forgetting_score(model, loader, digit):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            mask = labels == digit
            if mask.sum() == 0:
                continue
            images, labels = images[mask].to(device), labels[mask].to(device)
            images = images.view(-1, 28, 28)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return 1 - (correct / total) if total > 0 else 1

# Function to compute Layer-wise Distance
def layerwise_distance(model_before, model_after):
    return np.mean([torch.norm(p1 - p2).item() for p1, p2 in zip(model_before.parameters(), model_after.parameters())])

# Function to compute Activation Drift
def activation_drift(model_before, model_after, loader):
    model_before.eval()
    model_after.eval()
    total_diff = 0
    count = 0
    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device).view(-1, 28, 28)
            act_before = model_before(images)
            act_after = model_after(images)
            total_diff += torch.norm(act_before - act_after).item()
            count += 1
    return total_diff / count if count > 0 else 0

# Train the model
model = create_model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

for epoch in range(epochs):
    train_model(model, train_loader, criterion, optimizer, scheduler, [])

trained_accuracy = evaluate_model(model, test_loader)
print(f"Trained accuracy: {trained_accuracy:.4f}")

# Remove a specific digit
digit_to_unlearn = int(input("Enter the digit to unlearn (0-9): "))
filtered_train_loader = DataLoader(filter_digit(train_dataset, digit_to_unlearn), batch_size=batch_size, shuffle=True)

# Unlearning phase
unlearned_model = create_model()
optimizer = optim.Adam(unlearned_model.parameters(), lr=learning_rate * 0.1, weight_decay=1e-4)
for epoch in range(epochs):
    train_model(unlearned_model, filtered_train_loader, criterion, optimizer, scheduler, [])

unlearned_accuracy = evaluate_model(unlearned_model, test_loader)
print(f"Accuracy after unlearning digit {digit_to_unlearn}: {unlearned_accuracy:.4f}")

# Compute unlearning metrics
FS = forgetting_score(unlearned_model, test_loader, digit_to_unlearn)
lw_distance = layerwise_distance(model, unlearned_model)
ad = activation_drift(model, unlearned_model, test_loader)
print(f"Forgetting Score: {FS:.4f}")
print(f"Layer-wise Distance: {lw_distance:.4f}")
print(f"Activation Drift: {ad:.4f}")

# Relearning phase
relearned_model = create_model()
optimizer = optim.Adam(relearned_model.parameters(), lr=learning_rate, weight_decay=1e-4)
for epoch in range(epochs):
    train_model(relearned_model, train_loader, criterion, optimizer, scheduler, [])

relearned_accuracy = evaluate_model(relearned_model, test_loader)
relearning_efficiency = (relearned_accuracy - unlearned_accuracy) / (trained_accuracy - unlearned_accuracy)
print(f"Relearned accuracy: {relearned_accuracy:.4f}")
print(f"Relearning Efficiency: {relearning_efficiency:.4f}")


100%|██████████| 9.91M/9.91M [00:00<00:00, 128MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 35.1MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 91.6MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.11MB/s]


Trained accuracy: 0.9855
Enter the digit to unlearn (0-9): 5
Accuracy after unlearning digit 5: 0.8826
Forgetting Score: 1.0000
Layer-wise Distance: 8.8251
Activation Drift: 63.8981
Relearned accuracy: 0.9860
Relearning Efficiency: 1.0049
