In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets
import numpy as np

#Buffer

In [2]:
class ReplayBuffer:
    def __init__(self,size):
        self.size=size
        self.buffer=[]
    def addSamples(self,X,y):
        if len(self.buffer)>=self.size:
            self.buffer=self.buffer[len(X):]
        self.buffer.extend(list(zip(X,y)))
    def getSamples(self):
        if len(self.buffer)==0:
            return None,None
        X,y=zip(*self.buffer)
        return torch.stack(X),torch.stack(y)

#HAT CNN

In [3]:
class HATCNN(nn.Module):
    def __init__(self,numTasks):
        super(HATCNN,self).__init__()
        self.num_tasks=numTasks
        self.conv1=nn.Conv2d(3,32,kernel_size=3,padding=1)
        self.conv2=nn.Conv2d(32,64,kernel_size=3,padding=1)
        self.pool=nn.MaxPool2d(2,2)
        self.fc1=nn.Linear(64*8*8,128)
        self.fc2=nn.Linear(128,2)
        self.masks = nn.Parameter(torch.ones(numTasks,128)) #Masking per Tasks
    def forward(self,x,task_id):
        x=self.pool(torch.relu(self.conv1(x)))#(T,32,16,16)
        x=self.pool(torch.relu(self.conv2(x))) #(T,64,8,8)
        x=x.view(x.size(0),-1) #(T,4096)
        mask=self.masks[task_id]
        x=torch.relu(self.fc1(x)*mask)
        return x

#WP CNN

In [4]:
class MultiTaskCNN(nn.Module):
    def __init__(self):
        super(MultiTaskCNN, self).__init__()
        self.task_heads = nn.ModuleList([
            nn.Linear(128,2),  #Task 0
            nn.Linear(128,2), #Task 1
            nn.Linear(128,2),  #Task 2
            nn.Linear(128,2),  #Task 3
            nn.Linear(128,2)])   #Task 4
    def forward(self, x, task_id):
        return self.task_heads[task_id](x)

#OOD CNN

In [5]:
class OODCNN(nn.Module):
    def __init__(self):
        super(OODCNN, self).__init__()
        self.task_heads = nn.ModuleList([
            nn.Linear(128,2),  #Task 0
            nn.Linear(128,2), #Task 1
            nn.Linear(128,2),  #Task 2
            nn.Linear(128,2),  #Task 3
            nn.Linear(128,2)])   #Task 4
    def forward(self, x, task_id):
        return self.task_heads[task_id](x)

#TRAINING OODs

In [7]:
import random
from torchvision import datasets, transforms
from torch.utils.data import Subset

def loadTaskDataset(trainData,taskClasses):
    taskIndices=[i for i,(_,label) in enumerate(trainData) if label in taskClasses]
    taskDataset=torch.utils.data.Subset(trainData,taskIndices)
    def remapLabels(data):
        images,labels=data
        labelMap={taskClasses[0]:0,taskClasses[1]:1}
        return images,labelMap[labels]
    remappedDataset=[(remapLabels(item)) for item in taskDataset]
    return DataLoader(remappedDataset,batch_size=64,shuffle=True)

def traingOODheadWithFeatures(trainData,taskClasses,oodModel,featureExtractor,epochs=10):
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.Adam(list(featureExtractor.parameters())+list(oodModel.parameters()),lr=0.001)
    for taskNum,taskClasses in enumerate(classGroups):
        print(f"\nTraining on task {taskNum+1} with classes: {taskClasses}")
        taskLoader=loadTaskDataset(trainDataset,taskClasses)
        XTaskList,yTaskList=[],[]
        for images,labels in taskLoader:
            XTaskList.append(images)
            yTaskList.append(labels)
        XTask=torch.cat(XTaskList)
        yTask=torch.cat(yTaskList)
        for epoch in range(epochs):
            featureExtractor.train()
            oodModel.train()
            optimizer.zero_grad()
            features=featureExtractor(XTask,taskNum)
            outputs=oodModel(features,taskNum)
            loss=criterion(outputs,yTask)
            loss.backward()
            optimizer.step()
            print(f"Task {taskNum+1}, Epoch {epoch+1}, Loss: {loss.item()}")

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
trainDataset=datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
testDataset=datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
classGroups=[[0,1],[2,3],[4,5],[6,7],[8,9]]
featureExtractor=HATCNN(numTasks=len(classGroups))
oodModel=OODCNN()
traingOODheadWithFeatures(trainDataset,classGroups,oodModel,featureExtractor,epochs=10)

Files already downloaded and verified
Files already downloaded and verified

Training on task 1 with classes: [0, 1]
Task 1, Epoch 1, Loss: 0.6945529580116272
Task 1, Epoch 2, Loss: 0.651195764541626
Task 1, Epoch 3, Loss: 0.652748167514801
Task 1, Epoch 4, Loss: 0.5800525546073914
Task 1, Epoch 5, Loss: 0.5939843654632568
Task 1, Epoch 6, Loss: 0.5417751669883728
Task 1, Epoch 7, Loss: 0.5321953296661377
Task 1, Epoch 8, Loss: 0.5126423239707947
Task 1, Epoch 9, Loss: 0.47330132126808167
Task 1, Epoch 10, Loss: 0.4722510278224945

Training on task 2 with classes: [2, 3]
Task 2, Epoch 1, Loss: 0.7303708791732788
Task 2, Epoch 2, Loss: 0.69655442237854
Task 2, Epoch 3, Loss: 0.6973191499710083
Task 2, Epoch 4, Loss: 0.6879715323448181
Task 2, Epoch 5, Loss: 0.6641501784324646
Task 2, Epoch 6, Loss: 0.6574621796607971
Task 2, Epoch 7, Loss: 0.6581306457519531
Task 2, Epoch 8, Loss: 0.6403805017471313
Task 2, Epoch 9, Loss: 0.6229217052459717
Task 2, Epoch 10, Loss: 0.6215378642082214

Tr

#TRAINING

In [9]:
def loadTaskDataset(trainData,taskClasses):
    taskIndices=[i for i,(_,label) in enumerate(trainData) if label in taskClasses]
    taskDataset=torch.utils.data.Subset(trainData,taskIndices)
    def remapLabels(data):
        images,labels=data
        labelMap={taskClasses[0]:0,taskClasses[1]:1}
        return images,labelMap[labels]
    remappedDataset=[(remapLabels(item)) for item in taskDataset]
    return DataLoader(remappedDataset,batch_size=64,shuffle=True)

def evaluateTaskAccuracy(model,featureExtractor,taskLoader,taskNum):
    model.eval()
    correct,total=0,0
    with torch.no_grad():
        for images,labels in taskLoader:
            features=featureExtractor(images,taskNum)
            outputs=model(features,taskNum).argmax(dim=1)
            correct+=(outputs==labels).sum().item()
            total+=labels.size(0)
    return correct/total if total>0 else 0

def trainIncrementalClassTasks(featureExtractor,wpModel,trainDataset,classGroups,epochs=10):
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.Adam(list(wpModel.parameters()),lr=0.001)
    buffer=ReplayBuffer(size=1000)
    task_accuracies,avg_forgetting_rate=[],[]
    for taskNum,taskClasses in enumerate(classGroups):
        print(f"\nTraining on task {taskNum+1} with classes: {taskClasses}")
        taskLoader=loadTaskDataset(trainDataset,taskClasses)
        XTaskList,yTaskList=[],[]
        for images,labels in taskLoader:
            XTaskList.append(images)
            yTaskList.append(labels)
        XTask=torch.cat(XTaskList)
        yTask=torch.cat(yTaskList)
        XBuffer,yBuffer=buffer.getSamples()
        if XBuffer is not None and len(XBuffer)>0:
            XReplay,yReplay=XBuffer,yBuffer
        else:
            XReplay,yReplay=None,None
        for epoch in range(epochs):
            featureExtractor.train()
            wpModel.train()
            optimizer.zero_grad()
            features=featureExtractor(XTask,taskNum)
            outputs=wpModel(features,taskNum)
            loss=criterion(outputs,yTask)
            if XReplay is not None:
                features=featureExtractor(XReplay,taskNum)
                replayOutputs=wpModel(features,taskNum)
                replayLoss=criterion(replayOutputs,yReplay)
                loss+=replayLoss
            loss.backward()
            optimizer.step()
            print(f"Task {taskNum+1}, Epoch {epoch+1}, Loss: {loss.item()}")
        task_accuracy=evaluateTaskAccuracy(wpModel,featureExtractor,taskLoader,taskNum)
        print(f"Accuracy on task {taskNum+1}: {task_accuracy*100:.2f}%")
        task_accuracies.append([task_accuracy])
        for i in range(taskNum):
            previous_task_loader=loadTaskDataset(trainDataset,classGroups[i])
            accuracy_after_task=evaluateTaskAccuracy(wpModel,featureExtractor,previous_task_loader,taskNum)
            task_accuracies[i].append(accuracy_after_task)
            print(f"Accuracy on task {i+1} after learning task {taskNum+1}: {accuracy_after_task*100:.2f}%")
        if taskNum>0:
            forgetting_rates=[]
            for i in range(taskNum):
                initial_accuracy=task_accuracies[i][0]
                current_accuracy=task_accuracies[i][-1]
                forgetting=initial_accuracy-current_accuracy
                forgetting_rates.append(forgetting)
            avg_forgetting_rate.append(sum(forgetting_rates)/len(forgetting_rates))
            print(f"Average Forgetting Rate after task {taskNum+1}: {avg_forgetting_rate[-1]:.4f}")
        buffer.addSamples(XTask,yTask)
    aca=sum([task_accuracies[i][-1] for i in range(len(classGroups))])/len(classGroups)
    print(f"\nAverage Classification Accuracy (ACA) after the last task: {aca:.4f}")
    return aca,avg_forgetting_rate

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
trainDataset=datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
testDataset=datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
classGroups=[[0,1],[2,3],[4,5],[6,7],[8,9]]
wpModel=MultiTaskCNN()
trainIncrementalClassTasks(featureExtractor,wpModel,trainDataset,classGroups)


Files already downloaded and verified
Files already downloaded and verified

Training on task 1 with classes: [0, 1]
Task 1, Epoch 1, Loss: 1.0068581104278564
Task 1, Epoch 2, Loss: 0.9093834161758423
Task 1, Epoch 3, Loss: 0.8244733214378357
Task 1, Epoch 4, Loss: 0.7548810243606567
Task 1, Epoch 5, Loss: 0.7029379606246948
Task 1, Epoch 6, Loss: 0.6696040034294128
Task 1, Epoch 7, Loss: 0.6534898281097412
Task 1, Epoch 8, Loss: 0.65055912733078
Task 1, Epoch 9, Loss: 0.655038058757782
Task 1, Epoch 10, Loss: 0.6612259745597839
Accuracy on task 1: 57.98%

Training on task 2 with classes: [2, 3]
Task 2, Epoch 1, Loss: 2.2999584674835205
Task 2, Epoch 2, Loss: 2.1008553504943848
Task 2, Epoch 3, Loss: 1.9156476259231567
Task 2, Epoch 4, Loss: 1.747509241104126
Task 2, Epoch 5, Loss: 1.5999573469161987
Task 2, Epoch 6, Loss: 1.4765948057174683
Task 2, Epoch 7, Loss: 1.380618929862976
Task 2, Epoch 8, Loss: 1.3140560388565063
Task 2, Epoch 9, Loss: 1.2767927646636963
Task 2, Epoch 10, Los

(0.58026, [-0.05630000000000002, 0.1829, 0.05316666666666666, 0.038925])

#Testing

In [None]:
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

In [10]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def loadTaskTestDataset(testData,taskClasses):
    taskIndices=[i for i,(_,label) in enumerate(testData) if label in taskClasses]
    taskDataset=torch.utils.data.Subset(testData,taskIndices)
    def remapLabels(data):
        images,labels=data
        labelMap={taskClasses[0]:0,taskClasses[1]:1}
        return images,labelMap[labels]
    remappedDataset=[(remapLabels(item)) for item in taskDataset]
    return DataLoader(remappedDataset,batch_size=64,shuffle=False)

def testModel(featureExtractor,wpModel,testDataset,classGroups):
    wpModel.eval()
    totalCorrect=0
    totalSamples=0
    criterion=nn.CrossEntropyLoss()
    with torch.no_grad():
        for taskNum,taskClasses in enumerate(classGroups):
            print(f"\nTesting on task {taskNum+1} with classes: {taskClasses}")
            taskTestLoader=loadTaskTestDataset(testDataset,taskClasses)
            taskCorrect=0
            taskTotal=0
            taskLoss=0
            for images,labels in taskTestLoader:
                features=featureExtractor(images,taskNum)
                outputs=wpModel(features,taskNum)
                loss=criterion(outputs,labels)
                taskLoss+=loss.item()
                _,predicted=torch.max(outputs,1)
                taskTotal+=labels.size(0)
                taskCorrect+=(predicted==labels).sum().item()
            taskAccuracy=100*taskCorrect/taskTotal
            avgLoss=taskLoss/len(taskTestLoader)
            print(f"Task {taskNum+1} Accuracy: {taskAccuracy:.2f}%")
            print(f"Task {taskNum+1} Loss: {avgLoss:.4f}")
            totalCorrect+=taskCorrect
            totalSamples+=taskTotal
    overallAccuracy=100*totalCorrect/totalSamples
    print(f"\nOverall Accuracy on all tasks: {overallAccuracy:.2f}%")

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
testDataset=datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
classGroups=[[0,1],[2,3],[4,5],[6,7],[8,9]]
testModel(featureExtractor,wpModel,testDataset,classGroups)

Files already downloaded and verified

Testing on task 1 with classes: [0, 1]
Task 1 Accuracy: 57.45%
Task 1 Loss: 0.6649

Testing on task 2 with classes: [2, 3]
Task 2 Accuracy: 63.00%
Task 2 Loss: 0.6406

Testing on task 3 with classes: [4, 5]
Task 3 Accuracy: 54.25%
Task 3 Loss: 0.6818

Testing on task 4 with classes: [6, 7]
Task 4 Accuracy: 81.20%
Task 4 Loss: 0.4649

Testing on task 5 with classes: [8, 9]
Task 5 Accuracy: 50.90%
Task 5 Loss: 0.7129

Overall Accuracy on all tasks: 61.36%


#FULL CODE

In [None]:
class ReplayBuffer:
    def __init__(self,size):
        self.size=size
        self.buffer=[]
    def addSamples(self,X,y):
        if len(self.buffer)>=self.size:
            self.buffer=self.buffer[len(X):]
        self.buffer.extend(list(zip(X,y)))
    def getSamples(self):
        if len(self.buffer)==0:
            return None,None
        X,y=zip(*self.buffer)
        return torch.stack(X),torch.stack(y)

class HATCNN(nn.Module):
    def __init__(self,numTasks):
        super(HATCNN,self).__init__()
        self.num_tasks=numTasks
        self.conv1=nn.Conv2d(3,32,kernel_size=3,padding=1)
        self.conv2=nn.Conv2d(32,64,kernel_size=3,padding=1)
        self.pool=nn.MaxPool2d(2,2)
        self.fc1=nn.Linear(64*8*8,128)
        self.fc2=nn.Linear(128,2)
        self.masks = nn.Parameter(torch.ones(numTasks,128)) #Masking per Tasks
    def forward(self,x,task_id):
        x=self.pool(torch.relu(self.conv1(x)))#(T,32,16,16)
        x=self.pool(torch.relu(self.conv2(x))) #(T,64,8,8)
        x=x.view(x.size(0),-1) #(T,4096)
        mask=self.masks[task_id]
        x=torch.relu(self.fc1(x)*mask)
        return x

class MultiTaskCNN(nn.Module):
    def __init__(self):
        super(MultiTaskCNN, self).__init__()
        self.task_heads = nn.ModuleList([
            nn.Linear(128,2),  #Task 0
            nn.Linear(128,2), #Task 1
            nn.Linear(128,2),  #Task 2
            nn.Linear(128,2),  #Task 3
            nn.Linear(128,2)])   #Task 4
    def forward(self, x, task_id):
        return self.task_heads[task_id](x)

class OODCNN(nn.Module):
    def __init__(self):
        super(OODCNN, self).__init__()
        self.task_heads = nn.ModuleList([
            nn.Linear(128,2),  #Task 0
            nn.Linear(128,2), #Task 1
            nn.Linear(128,2),  #Task 2
            nn.Linear(128,2),  #Task 3
            nn.Linear(128,2)])   #Task 4
    def forward(self, x, task_id):
        return self.task_heads[task_id](x)


import random
from torchvision import datasets, transforms
from torch.utils.data import Subset

def loadTaskDataset(trainData,taskClasses):
    taskIndices=[i for i,(_,label) in enumerate(trainData) if label in taskClasses]
    taskDataset=torch.utils.data.Subset(trainData,taskIndices)
    def remapLabels(data):
        images,labels=data
        labelMap={taskClasses[0]:0,taskClasses[1]:1}
        return images,labelMap[labels]
    remappedDataset=[(remapLabels(item)) for item in taskDataset]
    return DataLoader(remappedDataset,batch_size=64,shuffle=True)

def traingOODheadWithFeatures(trainData,taskClasses,oodModel,featureExtractor,epochs=10):
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.Adam(list(featureExtractor.parameters())+list(oodModel.parameters()),lr=0.001)
    for taskNum,taskClasses in enumerate(classGroups):
        print(f"\nTraining on task {taskNum+1} with classes: {taskClasses}")
        taskLoader=loadTaskDataset(trainDataset,taskClasses)
        XTaskList,yTaskList=[],[]
        for images,labels in taskLoader:
            XTaskList.append(images)
            yTaskList.append(labels)
        XTask=torch.cat(XTaskList)
        yTask=torch.cat(yTaskList)
        for epoch in range(epochs):
            featureExtractor.train()
            oodModel.train()
            optimizer.zero_grad()
            features=featureExtractor(XTask,taskNum)
            outputs=oodModel(features,taskNum)
            loss=criterion(outputs,yTask)
            loss.backward()
            optimizer.step()
            print(f"Task {taskNum+1}, Epoch {epoch+1}, Loss: {loss.item()}")

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
trainDataset=datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
testDataset=datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
classGroups=[[0,1],[2,3],[4,5],[6,7],[8,9]]
featureExtractor=HATCNN(numTasks=len(classGroups))
oodModel=OODCNN()
traingOODheadWithFeatures(trainDataset,classGroups,oodModel,featureExtractor,epochs=1)


def loadTaskDataset(trainData,taskClasses):
    taskIndices=[i for i,(_,label) in enumerate(trainData) if label in taskClasses]
    taskDataset=torch.utils.data.Subset(trainData,taskIndices)
    def remapLabels(data):
        images,labels=data
        labelMap={taskClasses[0]:0,taskClasses[1]:1}
        return images,labelMap[labels]
    remappedDataset=[(remapLabels(item)) for item in taskDataset]
    return DataLoader(remappedDataset,batch_size=64,shuffle=True)

def evaluateTaskAccuracy(model,featureExtractor,taskLoader,taskNum):
    model.eval()
    correct,total=0,0
    with torch.no_grad():
        for images,labels in taskLoader:
            features=featureExtractor(images,taskNum)
            outputs=model(features,taskNum).argmax(dim=1)
            correct+=(outputs==labels).sum().item()
            total+=labels.size(0)
    return correct/total if total>0 else 0

def trainIncrementalClassTasks(featureExtractor,wpModel,trainDataset,classGroups,epochs=5):
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.Adam(list(wpModel.parameters()),lr=0.001)
    buffer=ReplayBuffer(size=1000)
    task_accuracies,avg_forgetting_rate=[],[]
    for taskNum,taskClasses in enumerate(classGroups):
        print(f"\nTraining on task {taskNum+1} with classes: {taskClasses}")
        taskLoader=loadTaskDataset(trainDataset,taskClasses)
        XTaskList,yTaskList=[],[]
        for images,labels in taskLoader:
            XTaskList.append(images)
            yTaskList.append(labels)
        XTask=torch.cat(XTaskList)
        yTask=torch.cat(yTaskList)
        XBuffer,yBuffer=buffer.getSamples()
        if XBuffer is not None and len(XBuffer)>0:
            XReplay,yReplay=XBuffer,yBuffer
        else:
            XReplay,yReplay=None,None
        for epoch in range(epochs):
            featureExtractor.train()
            wpModel.train()
            optimizer.zero_grad()
            features=featureExtractor(XTask,taskNum)
            outputs=wpModel(features,taskNum)
            loss=criterion(outputs,yTask)
            if XReplay is not None:
                features=featureExtractor(XReplay,taskNum)
                replayOutputs=wpModel(features,taskNum)
                replayLoss=criterion(replayOutputs,yReplay)
                loss+=replayLoss
            loss.backward()
            optimizer.step()
            print(f"Task {taskNum+1}, Epoch {epoch+1}, Loss: {loss.item()}")
        task_accuracy=evaluateTaskAccuracy(wpModel,featureExtractor,taskLoader,taskNum)
        print(f"Accuracy on task {taskNum+1}: {task_accuracy*100:.2f}%")
        task_accuracies.append([task_accuracy])
        for i in range(taskNum):
            previous_task_loader=loadTaskDataset(trainDataset,classGroups[i])
            accuracy_after_task=evaluateTaskAccuracy(wpModel,featureExtractor,previous_task_loader,taskNum)
            task_accuracies[i].append(accuracy_after_task)
            print(f"Accuracy on task {i+1} after learning task {taskNum+1}: {accuracy_after_task*100:.2f}%")
        if taskNum>0:
            forgetting_rates=[]
            for i in range(taskNum):
                initial_accuracy=task_accuracies[i][0]
                current_accuracy=task_accuracies[i][-1]
                forgetting=initial_accuracy-current_accuracy
                forgetting_rates.append(forgetting)
            avg_forgetting_rate.append(sum(forgetting_rates)/len(forgetting_rates))
            print(f"Average Forgetting Rate after task {taskNum+1}: {avg_forgetting_rate[-1]:.4f}")
        buffer.addSamples(XTask,yTask)
    aca=sum([task_accuracies[i][-1] for i in range(len(classGroups))])/len(classGroups)
    print(f"\nAverage Classification Accuracy (ACA) after the last task: {aca:.4f}")
    return aca,avg_forgetting_rate

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
trainDataset=datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
testDataset=datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
classGroups=[[0,1],[2,3],[4,5],[6,7],[8,9]]
wpModel=MultiTaskCNN()
trainIncrementalClassTasks(featureExtractor,wpModel,trainDataset,classGroups)



import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def loadTaskTestDataset(testData,taskClasses):
    taskIndices=[i for i,(_,label) in enumerate(testData) if label in taskClasses]
    taskDataset=torch.utils.data.Subset(testData,taskIndices)
    def remapLabels(data):
        images,labels=data
        labelMap={taskClasses[0]:0,taskClasses[1]:1}
        return images,labelMap[labels]
    remappedDataset=[(remapLabels(item)) for item in taskDataset]
    return DataLoader(remappedDataset,batch_size=64,shuffle=False)

def testModel(featureExtractor,wpModel,testDataset,classGroups):
    wpModel.eval()
    totalCorrect=0
    totalSamples=0
    criterion=nn.CrossEntropyLoss()
    with torch.no_grad():
        for taskNum,taskClasses in enumerate(classGroups):
            print(f"\nTesting on task {taskNum+1} with classes: {taskClasses}")
            taskTestLoader=loadTaskTestDataset(testDataset,taskClasses)
            taskCorrect=0
            taskTotal=0
            taskLoss=0
            for images,labels in taskTestLoader:
                features=featureExtractor(images,taskNum)
                outputs=wpModel(features,taskNum)
                loss=criterion(outputs,labels)
                taskLoss+=loss.item()
                _,predicted=torch.max(outputs,1)
                taskTotal+=labels.size(0)
                taskCorrect+=(predicted==labels).sum().item()
            taskAccuracy=100*taskCorrect/taskTotal
            avgLoss=taskLoss/len(taskTestLoader)
            print(f"Task {taskNum+1} Accuracy: {taskAccuracy:.2f}%")
            print(f"Task {taskNum+1} Loss: {avgLoss:.4f}")
            totalCorrect+=taskCorrect
            totalSamples+=taskTotal
    overallAccuracy=100*totalCorrect/totalSamples
    print(f"\nOverall Accuracy on all tasks: {overallAccuracy:.2f}%")

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
testDataset=datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
classGroups=[[0,1],[2,3],[4,5],[6,7],[8,9]]
testModel(featureExtractor,wpModel,testDataset,classGroups)