In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import numpy as np
from torchvision import transforms


In [2]:
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64*8*8, 512)
        self.fc2 = nn.Linear(512, 2)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64*8*8)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)


In [41]:
def train(model, train_loader, optimizer, ewc_lambda, precision_matrices,prior_means):
    print(len(train_loader))
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(train_loader):
        if target[0]==3 or target[0]==5:
            target = torch.where(target == 3, torch.tensor(1), torch.tensor(0))

        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        for name, param in model.named_parameters():
            if name in precision_matrices:
                fisher = precision_matrices[name]
                ewc_loss = (fisher * (param - prior_means[name])**2).sum()
                loss += ewc_lambda * ewc_loss
        loss.backward()
        optimizer.step()
        
def train_continuous(model, train_loader, optimizer, ewc_lambda, n_epochs, n_tasks):
    model.prev_means = {}
    model.prec_matrices = {}
    for task in range(n_tasks):
        print("Task:", task)
        if task == 0:
            train(model, train_loader[task], optimizer, ewc_lambda, {},model.prev_means)
        else:
            print("here task number", task)
            train_loader_new = train_loader[task]
            train_loader_prev = []
            for i in range(task):
                train_loader_prev.append(train_loader[i])
            precision_matrices = compute_fisher(model, train_loader_prev)
            model.prev_means = {}
            for name, param in model.named_parameters():
                model.prev_means[name] = param.data.clone()
            train(model, train_loader_new, optimizer, ewc_lambda, precision_matrices,model.prev_means)
            model.prec_matrices = {}
            for name, param in model.named_parameters():
                if name in precision_matrices:
                    model.prec_matrices[name] = precision_matrices[name].clone()
                    
def compute_fisher(model, data_loader):
    data_loader = data_loader[0]
    precision_matrices = {}
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(data_loader):
        output = model(data)
        loss = loss_fn(output, target)
        model.zero_grad()
        loss.backward()
        for name, param in model.named_parameters():
            if param.requires_grad:
                fisher = (param.grad.detach()**2)
                if name in precision_matrices:
                    precision_matrices[name] += fisher
                else:
                    precision_matrices[name] = fisher
    for name, param in model.named_parameters():
        if name in precision_matrices:
            precision_matrices[name] /= len(data_loader)
    return precision_matrices


In [5]:
import torch
import torchvision
from torch.utils.data import DataLoader, Subset

# Define the transform to normalize the data
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the entire CIFAR10 dataset
full_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create a subset of the dataset containing only classes 0 and 1
indices = torch.where((torch.tensor(full_dataset.targets) == 0) | (torch.tensor(full_dataset.targets) == 1))[0]
subset_dataset = Subset(full_dataset, indices)

# Create a dataloader for the subset dataset
train_loader_task1 = DataLoader(subset_dataset, batch_size=128, shuffle=True)

indices = torch.where((torch.tensor(test_dataset.targets) == 0) | (torch.tensor(test_dataset.targets) == 1))[0]
subset_dataset = Subset(test_dataset, indices)

# Create a dataloader for the subset dataset
test_loader_task1 = DataLoader(subset_dataset, batch_size=128, shuffle=True)

indices = torch.where((torch.tensor(full_dataset.targets) == 3) | (torch.tensor(full_dataset.targets) == 5))[0]
subset_dataset = Subset(full_dataset, indices)

# Create a dataloader for the subset dataset
train_loader_task2 = DataLoader(subset_dataset, batch_size=128, shuffle=True)


indices = torch.where((torch.tensor(test_dataset.targets) == 3) | (torch.tensor(test_dataset.targets) == 5))[0]
subset_dataset = Subset(test_dataset, indices)

# Create a dataloader for the subset dataset
test_loader_task2 = DataLoader(subset_dataset, batch_size=128, shuffle=True)





Files already downloaded and verified
Files already downloaded and verified


In [34]:
b

tensor([1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0,
        0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1,
        0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1,
        1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0,
        1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0,
        0, 0, 1, 1, 1, 1, 0, 0])

In [25]:
subset_dataset.targets

AttributeError: ignored

In [43]:
def test(model, test_loaders, n_tasks):
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    with torch.no_grad():
        for task in range(n_tasks):
            total_loss = 0
            total_correct = 0
            total_samples = 0
            for data, target in test_loaders[task]:
                if target[0]==3 or target[0]==5:
                  target = torch.where(target == 3, torch.tensor(1), torch.tensor(0))

                output = model(data)
                loss = loss_fn(output, target)
                total_loss += loss.item() * data.size(0)
                _, predicted = torch.max(output.data, 1)
                total_correct += (predicted == target).sum().item()
                total_samples += data.size(0)
            print("Task:", task, "Loss:", total_loss / total_samples, "Accuracy:", total_correct / total_samples)


In [None]:
n_epochs = 10
n_tasks = 2
lr = 0.001
ewc_lambda = 0.5

model = CNN()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Train on Task 1
train_continuous(model, [train_loader_task1, train_loader_task2], optimizer, ewc_lambda, n_epochs, n_tasks)
#test(model, [test_loader_task1, test_loader_task2], n_tasks)

In [44]:
test(model, [test_loader_task1, test_loader_task2], n_tasks)

Task: 0 Loss: 0.6429721341133118 Accuracy: 0.6125
Task: 1 Loss: 0.6171479921340942 Accuracy: 0.6595
