
# Updating the final layer of a pre-trained model when new classes appear in the target dataset



In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np

# Define a simple model for MNIST
class SimpleMLP(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 128 -> 10.

    ReLU is applied after the first two linear layers; the last layer's
    output is returned as-is (raw scores, no softmax).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Reshape ``x`` into rows of 784 values and return 10 scores per row."""
        flat = x.view(-1, 28 * 28)
        hidden = self.fc1(flat).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

# Elastic Weight Consolidation class
class EWC:
    """Elastic Weight Consolidation regularizer anchored at the model's current weights.

    On construction the object estimates a diagonal Fisher-information term for
    every trainable parameter from ``dataloader`` and snapshots the current
    parameter values. ``penalty()`` returns the Fisher-weighted squared distance
    between the live parameters and that snapshot.

    Args:
        model: Module whose trainable parameters are anchored.
        dataloader: Sized iterable of ``(inputs, targets)`` batches from the
            task that should be remembered (``len()`` must work on it).
        device: Device the batches are moved to before the forward pass.
    """

    def __init__(self, model, dataloader, device='cpu'):
        self.model = model
        self.device = device
        self.dataloader = dataloader
        # Only trainable parameters are tracked; frozen ones receive no gradient.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._precision_matrices = self.compute_fisher_information()
        self.saved_params = {n: p.clone().detach() for n, p in self.params.items()}

    def compute_fisher_information(self):
        """Return ``{name: tensor}`` of squared gradients averaged over batches.

        For each batch the cross-entropy loss is back-propagated and the squared
        gradient of every tracked parameter is accumulated, divided by the
        number of batches. The model's train/eval mode is restored afterwards
        and the gradients produced here are cleared.

        NOTE(review): the gradient squared is that of the batch-mean loss, not
        of individual samples, so the values are at most the per-sample
        average of squared gradients -- confirm this approximation is strong
        enough for the chosen ``lambda_ewc``.
        """
        fisher_matrices = {n: torch.zeros_like(p) for n, p in self.params.items()}
        if not fisher_matrices:
            # Nothing trainable: backward() would have nothing to differentiate.
            return fisher_matrices

        criterion = nn.CrossEntropyLoss()
        num_batches = len(self.dataloader)
        was_training = self.model.training
        self.model.eval()
        try:
            for inputs, targets in self.dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                self.model.zero_grad()
                loss = criterion(self.model(inputs), targets)
                loss.backward()

                # Iterate the tracked parameters only: frozen parameters have
                # no entry in fisher_matrices and no gradient.
                for n, p in self.params.items():
                    if p.grad is None:
                        # Trainable but not used in this forward pass.
                        continue
                    fisher_matrices[n] += p.grad.detach() ** 2 / num_batches
        finally:
            # Do not leak this pass's gradients or mode change to the caller.
            self.model.zero_grad()
            self.model.train(was_training)

        return fisher_matrices

    def penalty(self):
        """Return the scalar tensor ``sum_i F_i * (theta_i - theta_i_saved) ** 2``.

        Parameters without a Fisher entry (frozen at construction time) are
        skipped. The result is always a tensor, even when nothing is tracked.
        """
        total = torch.zeros((), device=self.device)
        for n, p in self.model.named_parameters():
            fisher = self._precision_matrices.get(n)
            if fisher is None:
                continue
            total = total + (fisher * (p - self.saved_params[n]) ** 2).sum()
        return total

# Function for training with EWC regularization
def train_ewc(model, train_loader, optimizer, ewc=None, lambda_ewc=0.4, device='cpu', epochs=1):
    """Train ``model`` with cross-entropy loss plus an optional EWC penalty.

    Args:
        model: Module to optimise; it is put into training mode.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        ewc: Object exposing ``penalty()``; ``None`` disables the regulariser.
        lambda_ewc: Multiplier applied to ``ewc.penalty()``.
        device: Device the batches are moved to.
        epochs: Number of full passes over ``train_loader`` (default 1, which
            matches the previous single-pass behaviour).

    Returns:
        Mean total loss (cross-entropy plus weighted penalty) over all
        processed batches, or ``0.0`` if no batch was processed.
    """
    criterion = nn.CrossEntropyLoss()
    model.train()

    running_loss = 0.0
    num_batches = 0
    for _ in range(epochs):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            loss = criterion(model(inputs), targets)
            # Explicit None check: do not depend on the truthiness of the object.
            if ewc is not None:
                loss = loss + lambda_ewc * ewc.penalty()

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

    return running_loss / num_batches if num_batches else 0.0

# Evaluation function
def evaluate(model, test_loader, device='cpu'):
    """Return the fraction of samples in ``test_loader`` classified correctly.

    Args:
        model: Module producing one score per class along dimension 1; it is
            put into eval mode.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when the loader yields no
        samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    if total == 0:
        return 0.0
    return correct / total


In [2]:

# Seed the RNGs so weight initialisation and batch shuffling are repeatable
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Loading and preprocessing MNIST
transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Select sample indices by label with a vectorised mask on the label tensor
# (a Python loop over the individual label elements is much slower).
# Classes 0-7 are the initial task, classes 8-9 are the new task.
train_idx_0_7 = torch.nonzero(train_set.targets < 8).flatten().tolist()
train_loader_0_7 = DataLoader(Subset(train_set, train_idx_0_7), batch_size=64, shuffle=True)

train_idx_8_9 = torch.nonzero(train_set.targets >= 8).flatten().tolist()
train_loader_8_9 = DataLoader(Subset(train_set, train_idx_8_9), batch_size=64, shuffle=True)

# Test loaders for the old classes, the new classes, and the full test set
test_idx_0_7 = torch.nonzero(test_set.targets < 8).flatten().tolist()
test_loader_0_7 = DataLoader(Subset(test_set, test_idx_0_7), batch_size=64, shuffle=False)

test_idx_8_9 = torch.nonzero(test_set.targets >= 8).flatten().tolist()
test_loader_8_9 = DataLoader(Subset(test_set, test_idx_8_9), batch_size=64, shuffle=False)

test_loader_all = DataLoader(test_set, batch_size=64, shuffle=False)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model, optimizer
model = SimpleMLP().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Step 1: Train on classes 0-7 (no EWC penalty yet: there is nothing to protect)
train_ewc(model, train_loader_0_7, optimizer, device=device)
accuracy_0_7 = evaluate(model, test_loader_0_7, device)
print(f'Accuracy on classes 0-7 after initial training: {accuracy_0_7:.4f}')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:13<00:00, 762kB/s] 


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 135kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.09MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 6.61MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Accuracy on classes 0-7 after initial training: 0.9521


In [4]:
# Set a higher lambda_ewc for stronger regularization
lambda_ewc = 5.0  # Increase to give higher priority to retaining initial knowledge
learning_rate = 0.001  # Reduce learning rate for EWC phase

# Build the EWC anchor from the classes 0-7 data BEFORE any 8-9 update.
# No earlier cell defines `ewc`, so without this line the cell fails with a
# NameError on a fresh kernel. Run this cell once, right after the 0-7 training
# cell: re-running it would anchor on weights already trained on 8-9.
ewc = EWC(model, train_loader_0_7, device=device)

# Reinitialize optimizer with a lower learning rate for EWC retraining phase
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

# Retrain on classes 8-9 with a higher EWC penalty
train_ewc(model, train_loader_8_9, optimizer, ewc=ewc, lambda_ewc=lambda_ewc, device=device)

# Re-evaluate the model
accuracy_0_7_after = evaluate(model, test_loader_0_7, device)
accuracy_8_9 = evaluate(model, test_loader_8_9, device)
accuracy_all = evaluate(model, test_loader_all, device)

print(f'Adjusted Accuracy on classes 0-7 after retraining on 8-9 with EWC: {accuracy_0_7_after:.4f}')
print(f'Adjusted Accuracy on classes 8-9 after retraining with EWC: {accuracy_8_9:.4f}')
print(f'Adjusted Overall accuracy after retraining with EWC: {accuracy_all:.4f}')


Adjusted Accuracy on classes 0-7 after retraining on 8-9 with EWC: 0.0000
Adjusted Accuracy on classes 8-9 after retraining with EWC: 0.9793
Adjusted Overall accuracy after retraining with EWC: 0.1942


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np

class EWC:
    """Elastic Weight Consolidation regularizer anchored at the model's current weights.

    This definition replaces any earlier ``EWC`` class in the notebook; the
    Fisher estimates are stored in ``fisher_info``. On construction the object
    estimates a diagonal Fisher-information term for every trainable parameter
    from ``dataloader`` and snapshots the current parameter values.

    Args:
        model: Module whose trainable parameters are anchored.
        dataloader: Sized iterable of ``(inputs, targets)`` batches from the
            task that should be remembered (``len()`` must work on it).
        device: Device the batches are moved to before the forward pass.
    """

    def __init__(self, model, dataloader, device='cpu'):
        self.model = model
        self.device = device
        self.dataloader = dataloader
        # Only trainable parameters are tracked; frozen ones receive no gradient.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self.fisher_info = self.compute_fisher_information()
        self.saved_params = {n: p.clone().detach() for n, p in self.params.items()}

    def compute_fisher_information(self):
        """Return ``{name: tensor}`` of squared gradients averaged over batches.

        For each batch the cross-entropy loss is back-propagated and the squared
        gradient of every tracked parameter is accumulated, divided by the
        number of batches. The model's train/eval mode is restored afterwards
        and the gradients produced here are cleared.

        NOTE(review): the gradient squared is that of the batch-mean loss, not
        of individual samples, so the values are at most the per-sample
        average of squared gradients -- confirm this approximation is strong
        enough for the chosen ``lambda_ewc``.
        """
        fisher_matrices = {n: torch.zeros_like(p) for n, p in self.params.items()}
        if not fisher_matrices:
            # Nothing trainable: backward() would have nothing to differentiate.
            return fisher_matrices

        criterion = nn.CrossEntropyLoss()
        num_batches = len(self.dataloader)
        was_training = self.model.training
        self.model.eval()
        try:
            for inputs, targets in self.dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                self.model.zero_grad()
                loss = criterion(self.model(inputs), targets)
                loss.backward()

                # Iterate the tracked parameters only: frozen parameters have
                # no entry in fisher_matrices and no gradient.
                for n, p in self.params.items():
                    if p.grad is None:
                        # Trainable but not used in this forward pass.
                        continue
                    fisher_matrices[n] += p.grad.detach() ** 2 / num_batches
        finally:
            # Do not leak this pass's gradients or mode change to the caller.
            self.model.zero_grad()
            self.model.train(was_training)

        return fisher_matrices

    def penalty(self):
        """Return the scalar tensor ``sum_i F_i * (theta_i - theta_i_saved) ** 2``.

        This is the standard EWC quadratic penalty: each squared parameter
        shift is weighted by that parameter's Fisher estimate, so parameters
        with larger Fisher values are penalised more. Parameters without a
        Fisher entry (frozen at construction time) are skipped.
        """
        total = torch.zeros((), device=self.device)
        for n, p in self.model.named_parameters():
            fisher = self.fisher_info.get(n)
            if fisher is None:
                continue
            importance = fisher * (p - self.saved_params[n]) ** 2
            total = total + torch.sum(importance)
        return total

# Training with EWC regularization
def train_with_ewc(model, train_loader, optimizer, ewc=None, lambda_ewc=5.0, device='cpu', epochs=1):
    """Train ``model`` with cross-entropy loss plus an optional EWC penalty.

    Args:
        model: Module to optimise; it is put into training mode.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        ewc: Object exposing ``penalty()``; ``None`` disables the regulariser.
        lambda_ewc: Multiplier applied to ``ewc.penalty()``.
        device: Device the batches are moved to.
        epochs: Number of full passes over ``train_loader`` (default 1, which
            matches the previous single-pass behaviour).

    Returns:
        Mean total loss (cross-entropy plus weighted penalty) over all
        processed batches, or ``0.0`` if no batch was processed.
    """
    criterion = nn.CrossEntropyLoss()
    model.train()

    running_loss = 0.0
    num_batches = 0
    for _ in range(epochs):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            loss = criterion(model(inputs), targets)

            # Apply EWC penalty; explicit None check rather than object truthiness.
            if ewc is not None:
                loss = loss + lambda_ewc * ewc.penalty()

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

    return running_loss / num_batches if num_batches else 0.0



In [6]:

# evaluate() and SimpleMLP are reused from the first cell of the notebook

# Seed the RNGs so weight initialisation and batch shuffling are repeatable
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load MNIST (train and test splits)
transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create dataloaders for classes 0-7 and 8-9. Indices come from a vectorised
# mask on the label tensor (a Python loop over the label elements is much slower).
train_idx_0_7 = torch.nonzero(train_set.targets < 8).flatten().tolist()
train_loader_0_7 = DataLoader(Subset(train_set, train_idx_0_7), batch_size=64, shuffle=True)

train_idx_8_9 = torch.nonzero(train_set.targets >= 8).flatten().tolist()
train_loader_8_9 = DataLoader(Subset(train_set, train_idx_8_9), batch_size=64, shuffle=True)

test_idx_0_7 = torch.nonzero(test_set.targets < 8).flatten().tolist()
test_loader_0_7 = DataLoader(Subset(test_set, test_idx_0_7), batch_size=64, shuffle=False)

test_idx_8_9 = torch.nonzero(test_set.targets >= 8).flatten().tolist()
test_loader_8_9 = DataLoader(Subset(test_set, test_idx_8_9), batch_size=64, shuffle=False)

test_loader_all = DataLoader(test_set, batch_size=64, shuffle=False)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize a fresh model and optimizer (the Fisher information is computed in a later step)
model = SimpleMLP().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Step 1: Train on classes 0-7
train_with_ewc(model, train_loader_0_7, optimizer, device=device)
accuracy_0_7 = evaluate(model, test_loader_0_7, device)
print(f'Accuracy on classes 0-7 after initial training: {accuracy_0_7:.4f}')


Accuracy on classes 0-7 after initial training: 0.8699


In [7]:

# Step 2: Anchor the model on classes 0-7 -- estimate the Fisher information
# and snapshot the weights before any 8-9 data is seen
ewc = EWC(model=model, dataloader=train_loader_0_7, device=device)

# Step 3: Continue training on classes 8-9 with the EWC penalty switched on
train_with_ewc(
    model,
    train_loader_8_9,
    optimizer,
    ewc=ewc,
    lambda_ewc=5.0,
    device=device,
)

# Step 4: Accuracy on the old classes, the new classes, and the full test set
accuracy_0_7_after, accuracy_8_9, accuracy_all = (
    evaluate(model, loader, device)
    for loader in (test_loader_0_7, test_loader_8_9, test_loader_all)
)

print(f'Adaptive EWC Accuracy on classes 0-7 after retraining on 8-9: {accuracy_0_7_after:.4f}')
print(f'Adaptive EWC Accuracy on classes 8-9 after retraining: {accuracy_8_9:.4f}')
print(f'Adaptive EWC Overall accuracy after retraining: {accuracy_all:.4f}')

Adaptive EWC Accuracy on classes 0-7 after retraining on 8-9: 0.0000
Adaptive EWC Accuracy on classes 8-9 after retraining: 0.9465
Adaptive EWC Overall accuracy after retraining: 0.1877
