
# Final layer updation of a pre-trained model, if new classes come in the target dataset



In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import random

# Define a simple model for MNIST
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# MAML Meta-training step function
def meta_train_step(model, meta_optimizer, task_loader, inner_lr=0.01, num_inner_steps=1, device='cpu'):
    meta_optimizer.zero_grad()
    original_params = {name: param.clone() for name, param in model.named_parameters()}

    for inputs, targets in task_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Copy the model for inner-loop updates
        temp_model = SimpleMLP().to(device)
        temp_model.load_state_dict(model.state_dict())
        inner_optimizer = optim.SGD(temp_model.parameters(), lr=inner_lr)

        # Inner-loop training on task-specific data
        for _ in range(num_inner_steps):
            inner_optimizer.zero_grad()
            outputs = temp_model(inputs)
            loss = nn.CrossEntropyLoss()(outputs, targets)
            loss.backward()
            inner_optimizer.step()

        # Compute meta-gradient
        for name, param in model.named_parameters():
            meta_grad = param - temp_model.state_dict()[name]
            if param.grad is None:
                param.grad = meta_grad / len(task_loader)
            else:
                param.grad += meta_grad / len(task_loader)

    # Meta-optimization step
    meta_optimizer.step()

# Function to fine-tune on new classes
def fine_tune(model, train_loader, optimizer, device='cpu'):
    model.train()
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = nn.CrossEntropyLoss()(outputs, targets)
        loss.backward()
        optimizer.step()

# Evaluation function
def evaluate(model, test_loader, device='cpu'):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    return correct / total



In [4]:

# Data Preparation
transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create meta-task datasets by sampling from classes 0-7
meta_tasks = []
for class_pair in [(0,1), (2,3), (4,5), (6,7)]:
    indices = [i for i, t in enumerate(train_set.targets) if t in class_pair]
    meta_tasks.append(DataLoader(Subset(train_set, indices), batch_size=64, shuffle=True))

# Test loaders for classes 0-7 and classes 8-9
test_idx_0_7 = [i for i, t in enumerate(test_set.targets) if t < 8]
test_loader_0_7 = DataLoader(Subset(test_set, test_idx_0_7), batch_size=64, shuffle=False)

test_idx_8_9 = [i for i, t in enumerate(test_set.targets) if t >= 8]
test_loader_8_9 = DataLoader(Subset(test_set, test_idx_8_9), batch_size=64, shuffle=False)

test_loader_all = DataLoader(test_set, batch_size=64, shuffle=False)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model and meta-optimizer
model = SimpleMLP().to(device)
meta_optimizer = optim.Adam(model.parameters(), lr=0.001)

# Meta-training phase on classes 0-7
for epoch in range(10):
    for task_loader in meta_tasks:
        meta_train_step(model, meta_optimizer, task_loader, inner_lr=0.01, num_inner_steps=1, device=device)

accuracy_0_7 = evaluate(model, test_loader_0_7, device)
print(f'Accuracy on classes 0-7 after meta-training: {accuracy_0_7:.4f}')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 54.7MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.84MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 13.6MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 9.75MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Accuracy on classes 0-7 after meta-training: 0.8514


In [5]:

# Fine-tune on new classes 8-9
train_idx_8_9 = [i for i, t in enumerate(train_set.targets) if t >= 8]
train_loader_8_9 = DataLoader(Subset(train_set, train_idx_8_9), batch_size=64, shuffle=True)
fine_tune(model, train_loader_8_9, meta_optimizer, device=device)

# Evaluate the model on different test sets
accuracy_0_7_after = evaluate(model, test_loader_0_7, device)
accuracy_8_9 = evaluate(model, test_loader_8_9, device)
accuracy_all = evaluate(model, test_loader_all, device)

print(f'MAML Accuracy on classes 0-7 after fine-tuning on 8-9: {accuracy_0_7_after:.4f}')
print(f'MAML Accuracy on classes 8-9 after fine-tuning: {accuracy_8_9:.4f}')
print(f'MAML Overall accuracy after fine-tuning: {accuracy_all:.4f}')

MAML Accuracy on classes 0-7 after fine-tuning on 8-9: 0.0000
MAML Accuracy on classes 8-9 after fine-tuning: 0.9803
MAML Overall accuracy after fine-tuning: 0.1944
