<a href="https://colab.research.google.com/github/Latesh99/MUL/blob/main/MU_GRU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import torch.nn as nn

# Configuring the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
learning_rate = 0.001
epochs = 10  # Number of passes over the training set in each phase
batch_size = 64
hidden_size = 128  # Size of GRU hidden layer
num_layers = 2  # Number of stacked GRU layers

# Random augmentation is applied to TRAINING images only, for better generalization.
train_transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Evaluation must be deterministic: random rotation/translation on test images
# would make every reported accuracy vary from run to run and bias it downward.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Kept so any code referring to the original name still gets the training pipeline.
transform = train_transform

train_dataset = datasets.MNIST(root='./data', train=True, transform=train_transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=test_transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Function to remove a specific digit
def filter_digit(dataset, digit_to_remove):
    """Return a ``Subset`` of *dataset* without the samples labelled *digit_to_remove*.

    When the dataset exposes a ``targets`` attribute and has no
    ``target_transform`` (e.g. torchvision's MNIST), labels are read from
    ``targets`` directly. This avoids loading and transforming every image
    just to inspect its label. Otherwise every ``(sample, label)`` pair is
    scanned.

    Args:
        dataset: An indexable dataset yielding ``(sample, label)`` pairs.
        digit_to_remove: Label value whose samples are excluded.

    Returns:
        A ``torch.utils.data.Subset`` over the indices whose label differs
        from *digit_to_remove*, in ascending order.
    """
    targets = getattr(dataset, "targets", None)
    # A target_transform could make the returned labels differ from `targets`,
    # so only take the fast path when none is configured.
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        labels = torch.as_tensor(targets)
        indices = (labels != digit_to_remove).nonzero(as_tuple=True)[0].tolist()
    else:
        indices = [i for i, (_, label) in enumerate(dataset) if label != digit_to_remove]
    return Subset(dataset, indices)

# Define GRU model with dropout and regularization
class GRUModel(nn.Module):
    """GRU classifier over image rows, followed by a linear head.

    The input is treated as a batch-first sequence (``batch_first=True``).
    Dimension 1 is the time axis and the last dimension holds ``input_size``
    features per step. The GRU output at the final time step is mapped to
    ``num_classes`` logits.

    Args:
        input_size: Features per time step.
        hidden_size: Width of each GRU layer.
        num_layers: Number of stacked GRU layers.
        num_classes: Number of output logits.
        dropout: Dropout applied between stacked GRU layers. Forced to 0 when
            ``num_layers == 1``, since PyTorch only applies it between layers
            (and warns otherwise).
    """

    def __init__(self, input_size=28, hidden_size=128, num_layers=2, num_classes=10, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Accepts ``(batch, seq, input_size)`` or ``(batch, 1, seq, input_size)``.
        In the 4-D case the singleton channel dimension is dropped.
        """
        # Only drop a channel dimension on 4-D input: squeezing 3-D input would
        # wrongly collapse a length-1 sequence axis.
        if x.dim() == 4:
            x = x.squeeze(1)
        # Build the initial hidden state where the input lives (not a global device).
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device, dtype=x.dtype)
        out, _ = self.gru(x, h0)
        return self.fc(out[:, -1, :])  # Take last time step output

# Set up the model
def create_model():
    """Build a freshly initialised ``GRUModel`` and move it to ``device``.

    Returns:
        A new ``GRUModel`` sized by the module-level ``hidden_size`` and
        ``num_layers`` hyperparameters.
    """
    # Pass the configured hyperparameters explicitly so editing them at the
    # top of the script actually changes the model.
    return GRUModel(hidden_size=hidden_size, num_layers=num_layers).to(device)

# Train function with learning rate scheduling
def train_model(model, loader, criterion, optimizer, scheduler, history):
    """Run one training epoch, step the scheduler, and record the epoch loss.

    Args:
        model: Module to train; inputs are sent to the device of its parameters.
        loader: Iterable of ``(images, labels)`` batches. Images are reshaped
            to ``(batch, 28, 28)``.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step`` takes a metric (e.g.
            ``ReduceLROnPlateau``). It receives the epoch's mean loss.
        history: List that the epoch's mean loss is appended to.

    Returns:
        The mean per-batch training loss of this epoch.

    Raises:
        ValueError: If *loader* yields no batches.
    """
    model.train()
    model_device = next(model.parameters()).device
    total_loss = 0.0
    num_batches = 0
    for images, labels in loader:
        images, labels = images.to(model_device), labels.to(model_device)
        images = images.reshape(-1, 28, 28)  # Reshape to (batch, sequence, input_size)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_model: loader yielded no batches")

    epoch_loss = total_loss / num_batches
    # Plateau detection should track the epoch-level loss; the final batch's
    # loss alone is a noisy single-sample estimate.
    scheduler.step(epoch_loss)
    history.append(epoch_loss)
    return epoch_loss

# Evaluation function
def evaluate_model(model, loader):
    """Return the classification accuracy of *model* over *loader*.

    Args:
        model: Module producing class logits. It is switched to eval mode, and
            inputs are sent to the device of its parameters.
        loader: Iterable of ``(images, labels)`` batches. Images are reshaped
            to ``(batch, 28, 28)``.

    Returns:
        Fraction of samples whose argmax logit equals the label, in ``[0, 1]``.

    Raises:
        ValueError: If *loader* yields no samples.
    """
    model.eval()
    model_device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(model_device), labels.to(model_device)
            images = images.reshape(-1, 28, 28)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate_model: loader yielded no samples")
    return correct / total

# --- Phase 1: baseline model trained on the complete training set ---
model = create_model()
train_loss_history = []  # one mean-loss entry per epoch
criterion = nn.CrossEntropyLoss()
# weight_decay adds an L2 penalty on the parameters
optimizer = optim.Adam(model.parameters(), weight_decay=1e-4, lr=learning_rate)
# Halve the learning rate once the monitored loss has not improved for more than 3 steps
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3)

# Accuracy of the untrained network, as a reference point
initial_accuracy = evaluate_model(model, test_loader)
print(f"Initial accuracy: {initial_accuracy:.4f}")

for epoch in range(epochs):
    train_model(
        model,
        loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        history=train_loss_history,
    )

trained_accuracy = evaluate_model(model, test_loader)
print(f"Trained accuracy: {trained_accuracy:.4f}")

# Remove a specific digit.
# Re-prompt until the answer is an integer in 0-9: any other value would
# silently remove nothing from the training set.
while True:
    try:
        digit_to_unlearn = int(input("Enter the digit to unlearn (0-9): "))
    except ValueError:
        print("Please enter a whole number between 0 and 9.")
        continue
    if 0 <= digit_to_unlearn <= 9:
        break
    print("Please enter a whole number between 0 and 9.")

filtered_train_dataset = filter_digit(train_dataset, digit_to_unlearn)
filtered_train_loader = DataLoader(filtered_train_dataset, batch_size=batch_size, shuffle=True)

# Unlearning baseline: create_model() returns a FRESH model, so this retrains
# from scratch on the data without the digit (it does not fine-tune the trained
# model), using a 10x lower learning rate.
# NOTE(review): the lower learning rate looks intended for fine-tuning; with a
# freshly initialised model it may leave the network under-trained - confirm intent.
unlearn_loss_history = []
model = create_model()
optimizer = optim.Adam(model.parameters(), lr=learning_rate * 0.1, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

for epoch in range(epochs):
    train_model(model, filtered_train_loader, criterion, optimizer, scheduler, unlearn_loss_history)

unlearned_accuracy = evaluate_model(model, test_loader)
print(f"Accuracy after unlearning digit {digit_to_unlearn}: {unlearned_accuracy:.4f}")

# Overall test accuracy mixes the forgotten digit with the others, so report
# both parts separately. Forgetting shows up as low accuracy on the first.
forget_indices = [i for i, label in enumerate(test_dataset.targets.tolist()) if label == digit_to_unlearn]
forget_test_loader = DataLoader(Subset(test_dataset, forget_indices), batch_size=batch_size, shuffle=False)
retain_test_loader = DataLoader(filter_digit(test_dataset, digit_to_unlearn), batch_size=batch_size, shuffle=False)
print(f"  Accuracy on forgotten digit {digit_to_unlearn}: {evaluate_model(model, forget_test_loader):.4f}")
print(f"  Accuracy on remaining digits: {evaluate_model(model, retain_test_loader):.4f}")

# --- Phase 3: train a brand-new model on the complete training set again ---
# create_model() returns freshly initialised weights, and the digit removed in
# the previous phase is included again. The original learning rate is used.
# NOTE(review): because this does not continue from the unlearned model, it
# repeats phase 1 rather than measuring how quickly the digit is relearned - confirm intent.
model = create_model()
relearn_loss_history = []
optimizer = optim.Adam(model.parameters(), weight_decay=1e-4, lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3)

for epoch in range(epochs):
    train_model(
        model,
        loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        history=relearn_loss_history,
    )

relearned_accuracy = evaluate_model(model, test_loader)
print(f"Relearned accuracy: {relearned_accuracy:.4f}")

# Plot the per-epoch mean training loss recorded in each phase.
phase_histories = [
    ('Training Loss', train_loss_history),
    ('Unlearning Loss', unlearn_loss_history),
    ('Relearning Loss', relearn_loss_history),
]

plt.figure(figsize=(10, 5))
for phase_label, phase_history in phase_histories:
    # Take the x-axis from each history's own length, so the plot stays correct
    # if phases run for different numbers of epochs. Epochs are numbered from 1.
    plt.plot(range(1, len(phase_history) + 1), phase_history, label=phase_label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Curve During Training, Unlearning, and Relearning')
plt.legend()
plt.show()
