In [1]:
# In this cell I imported the pyTorch library and torchvision lib
import torch
import torchvision

print(torch.__version__)
print(torchvision.__version__)


2.9.0+cu126
0.24.0+cu126


In [2]:
#this cell loaded MNIST dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transform: convert images to tensors
transform = transforms.ToTensor()

# Load MNIST training dataset
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
#torchvision knows the offical download for MNIST and downloadtrue willexecuteit
    download=True,
    transform=transform
)

# DataLoader(grouping an shuffling the data)
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True
)

# Check one batch
images, labels = next(iter(train_loader))
print(images.shape)
print(labels[:10])


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.03MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 131kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.24MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.29MB/s]

torch.Size([64, 1, 28, 28])
tensor([0, 6, 0, 7, 3, 8, 4, 2, 2, 9])





In [3]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 24 * 24, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten
        x = F.relu(self.fc1(x))
        return self.fc2(x)


In [4]:
model = SimpleCNN()

test_images, _ = next(iter(train_loader))
outputs = model(test_images)

print(outputs.shape)


torch.Size([64, 10])


In [5]:
import torch.optim as optim

model = SimpleCNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [6]:
model.train()

for epoch in range(2):  # just 2 epochs for now
    total_loss = 0

    for images, labels in train_loader:
        optimizer.zero_grad()      # clear old gradients
        outputs = model(images)    # forward pass
        loss = criterion(outputs, labels)
        loss.backward()            # compute gradients
        optimizer.step()           # update weights

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")


Epoch 1, Loss: 126.6528
Epoch 2, Loss: 37.3111


In [7]:
# Load test dataset
test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False
)

# Evaluation
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")


Test Accuracy: 98.72%


In [8]:
# Create a fixed permutation (a random shuffle of 784 pixels)
pixels = 28 * 28
perm = torch.randperm(pixels)

# Function to apply the permutation
def permute_transform(tensor):
    # Flatten image, shuffle pixels, then reshape back to 28x28
    return tensor.view(-1, pixels)[:, perm].view(1, 28, 28)

# Load the Task 2 (Scrambled) Dataset
transform_task2 = transforms.Compose([transforms.ToTensor(), permute_transform])

train_dataset2 = datasets.MNIST(root="./data", train=True, download=True, transform=transform_task2)
train_loader2 = DataLoader(train_dataset2, batch_size=64, shuffle=True)

test_dataset2 = datasets.MNIST(root="./data", train=False, download=True, transform=transform_task2)
test_loader2 = DataLoader(test_dataset2, batch_size=64, shuffle=False)

print("Task 2 (Permuted MNIST) is ready!")

Task 2 (Permuted MNIST) is ready!


In [9]:
# Train the existing model on Task 2 for 2 epochs
model.train()
for epoch in range(2):
    for images, labels in train_loader2:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f"Finished Task 2 Epoch {epoch+1}")

Finished Task 2 Epoch 1
Finished Task 2 Epoch 2


In [10]:
def get_accuracy(loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

acc_task1 = get_accuracy(test_loader)  # Normal MNIST
acc_task2 = get_accuracy(test_loader2) # Permuted MNIST

print(f"Accuracy on Task 1 (Original): {acc_task1:.2f}%")
print(f"Accuracy on Task 2 (Scrambled): {acc_task2:.2f}%")

Accuracy on Task 1 (Original): 95.43%
Accuracy on Task 2 (Scrambled): 97.77%


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define which hardware to use (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        # 64 channels * 24 * 24 is the size after two convolutions
        self.fc1 = nn.Linear(64 * 24 * 24, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Define the Loss function
criterion = nn.CrossEntropyLoss()

Using device: cuda


In [13]:
# --- OPTIMIZED FAST REHEARSAL ---
rehearsal_model = SimpleCNN().to(device)
rehearsal_optimizer = torch.optim.Adam(rehearsal_model.parameters(), lr=0.001)

# Get a "Memory Buffer" (reusable reminders of Task 1)
# Make sure your train_dataset and train_loader2 are already loaded!
mem_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)
mem_images, mem_labels = next(iter(mem_loader))
mem_images, mem_labels = mem_images.to(device), mem_labels.to(device)

rehearsal_model.train()
# Training on Task 2 while replaying Task 1
for i, (images, labels) in enumerate(train_loader2):
    images, labels = images.to(device), labels.to(device)
    rehearsal_optimizer.zero_grad()

    loss_new = criterion(rehearsal_model(images), labels)
    loss_old = criterion(rehearsal_model(mem_images), mem_labels)

    (loss_new + loss_old).backward()
    rehearsal_optimizer.step()

print("Rehearsal Training Complete!")

Rehearsal Training Complete!


In [14]:
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            outputs = model(imgs)
            _, predicted = torch.max(outputs, 1)
            total += lbls.size(0)
            correct += (predicted == lbls).sum().item()
    return 100 * correct / total

acc1 = evaluate(rehearsal_model, test_loader)
acc2 = evaluate(rehearsal_model, test_loader2)

print(f"Task 1 (Original) Accuracy: {acc1:.2f}%")
print(f"Task 2 (Permuted) Accuracy: {acc2:.2f}%")
print(f"Average Accuracy: {(acc1 + acc2)/2:.2f}%")
print(f"Forgetting Measure: {98.6 - acc1:.2f}%")

Task 1 (Original) Accuracy: 80.62%
Task 2 (Permuted) Accuracy: 96.57%
Average Accuracy: 88.59%
Forgetting Measure: 17.98%


In [15]:
# 1. Initialize a fresh model for this strategy
reg_model = SimpleCNN().to(device)
reg_optimizer = torch.optim.Adam(reg_model.parameters(), lr=0.001)

# 2. Train on Task 1 (Standard MNIST)
reg_model.train()
for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    reg_optimizer.zero_grad()
    loss = criterion(reg_model(images), labels)
    loss.backward()
    reg_optimizer.step()

# 3. SAVE THE WEIGHTS (This is our "anchor" for the penalty)
import copy
task1_weights = {}
for name, param in reg_model.named_parameters():
    task1_weights[name] = param.data.clone()

print("Task 1 finished and weights anchored!")

Task 1 finished and weights anchored!


In [16]:
# How strong should the 'lock' be?
importance_lambda = 0.5

reg_model.train()
for i, (images, labels) in enumerate(train_loader2):
    images, labels = images.to(device), labels.to(device)
    reg_optimizer.zero_grad()

    # Standard Loss for Task 2
    loss_new = criterion(reg_model(images), labels)

    # Regularization Penalty:
    # Compare current weights to Task 1 weights and penalize the difference
    reg_loss = 0
    for name, param in reg_model.named_parameters():
        reg_loss += torch.sum((param - task1_weights[name]) ** 2)

    # Total Loss = Learning New Task + Staying close to Old Weights
    total_loss = loss_new + (importance_lambda * reg_loss)

    total_loss.backward()
    reg_optimizer.step()

print("Regularization Training Complete!")

Regularization Training Complete!


In [17]:
reg_acc1 = evaluate(reg_model, test_loader)
reg_acc2 = evaluate(reg_model, test_loader2)

print(f"--- Regularization Strategy Results ---")
print(f"Task 1 (Original) Accuracy: {reg_acc1:.2f}%")
print(f"Task 2 (Permuted) Accuracy: {reg_acc2:.2f}%")
print(f"Average Accuracy: {(reg_acc1 + reg_acc2)/2:.2f}%")
print(f"Forgetting Measure: {98.6 - reg_acc1:.2f}%")

--- Regularization Strategy Results ---
Task 1 (Original) Accuracy: 91.43%
Task 2 (Permuted) Accuracy: 46.93%
Average Accuracy: 69.18%
Forgetting Measure: 7.17%


In [18]:
# 1. Save the Rehearsal Model (which had the best balance)
torch.save(rehearsal_model.state_dict(), 'rehearsal_model.pth')

# 2. Save the Regularization Model (which had the best forgetting score)
torch.save(reg_model.state_dict(), 'regularization_model.pth')

print("Success! Both models are saved in the Colab file folder.")

Success! Both models are saved in the Colab file folder.
