In [None]:
 !pip install -U torchmetrics



In [None]:
import torch
torch.cuda.is_available()
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import keras
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader, ConcatDataset
import random
import numpy as np
batch_size=8
from sklearn.metrics import precision_score, recall_score, f1_score
import torchmetrics

In [None]:
from torchmetrics import Precision
# Define the model architecture
class CNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel images.

    Each stage is Conv3x3 (padding 1) -> ReLU -> MaxPool 2x2, taking the
    channel count 3 -> 32 -> 64. The flattened feature map feeds a single
    linear layer with ``num_classes`` outputs (raw logits).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        stages = []
        for in_ch, out_ch in ((3, 32), (32, 64)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
        # Sequential indices 0-5: conv, relu, pool, conv, relu, pool
        # (parameterised layers are features.0 and features.3).
        self.features = nn.Sequential(*stages)
        # 64 channels * 8 * 8 after two 2x2 poolings: ties the model to
        # 32x32 inputs.
        self.classifier = nn.Linear(64 * 8 * 8, num_classes)

    def forward(self, x):
        """Map an image batch ``(N, 3, 32, 32)`` to logits ``(N, num_classes)``."""
        batch = x.size(0)
        feature_map = self.features(x)
        return self.classifier(feature_map.view(batch, -1))

# Define the iCaRL class
class iCaRL:
    """Simplified class-incremental learner loosely modelled on iCaRL.

    Wraps a ``CNN`` with ``num_classes`` outputs. After each task,
    ``construct_exemplar_set`` stores L2-normalised mean feature vectors:
    one list per task in ``exemplar_sets``, and one entry per class id in
    ``class_means``.

    NOTE(review): unlike the iCaRL paper (Rebuffi et al., 2017), ``train``
    uses plain cross-entropy on the current task's data only. There is no
    exemplar replay and no distillation loss, and ``exemplar_sets`` holds
    class-mean vectors rather than selected exemplar images. Earlier tasks
    are therefore likely to be forgotten. ``memory_size`` is stored but not
    used by any method here.
    """

    def __init__(self, device, num_classes, batch_size, memory_size):
        """
        Args:
            device: torch device the model and batches are moved to.
            num_classes: total number of classes (size of the output layer).
            batch_size: DataLoader batch size for training and testing.
            memory_size: total exemplar budget (stored only).
        """
        self.device = device
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.memory_size = memory_size
        self.model = CNN(num_classes).to(device)
        # One list of normalised class means per construct_exemplar_set call.
        self.exemplar_sets = []
        # Normalised mean feature per class id; None until that class is seen.
        self.class_means = [None] * num_classes

    def train(self, train_dataset, lr, num_epochs, milestones=(49, 63), gamma=0.2):
        """Train the model with SGD and cross-entropy on ``train_dataset``.

        Args:
            train_dataset: map-style dataset yielding ``(image, label)`` pairs.
            lr: initial SGD learning rate (momentum 0.9, weight decay 1e-5).
            num_epochs: number of passes over ``train_dataset``.
            milestones: epochs at which MultiStepLR multiplies the LR by
                ``gamma``. With the defaults, nothing happens when
                ``num_epochs`` <= 49.
            gamma: LR decay factor applied at each milestone.
        """
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(milestones), gamma=gamma)

        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            scheduler.step()
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

    def test(self, test_dataset, lr=None, num_epochs=None):
        """Evaluate the model once on ``test_dataset`` and print metrics.

        Prints the average per-sample cross-entropy loss, the accuracy, and
        scikit-learn weighted precision / recall / F1.

        Args:
            test_dataset: map-style dataset yielding ``(image, label)`` pairs.
            lr: unused; accepted for backward compatibility.
            num_epochs: unused; evaluation is a single deterministic pass.

        Returns:
            Accuracy over the whole ``test_dataset`` in percent (0.0 if empty).
        """
        n_samples = len(test_dataset)
        if n_samples == 0:
            print('Test set is empty; nothing to evaluate.')
            return 0.0

        # Metrics do not depend on order, so iterate deterministically.
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)
        # Sum per-sample losses so dividing by n_samples is a true mean even
        # when the last batch is smaller.
        criterion = nn.CrossEntropyLoss(reduction='sum')

        total_loss = 0.0
        correct = 0
        all_preds = []
        all_labels = []
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for inputs, labels in test_loader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    outputs = self.model(inputs)
                    total_loss += criterion(outputs, labels).item()
                    preds = outputs.argmax(dim=1)
                    correct += (preds == labels).sum().item()
                    # sklearn needs host-side values, not (possibly CUDA) tensors.
                    all_preds.extend(preds.cpu().tolist())
                    all_labels.extend(labels.cpu().tolist())
        finally:
            self.model.train(was_training)

        test_loss = total_loss / n_samples
        accuracy = 100. * correct / n_samples
        # zero_division=0 scores never-predicted classes as 0 (the value the
        # sklearn default also uses) without an UndefinedMetricWarning.
        precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(test_loss, correct, n_samples, accuracy))
        print('Precision: {:.4f}, Recall: {:.4f}, F1 Score: {:.4f}\n'.format(precision, recall, f1))

        return accuracy

    @staticmethod
    def _as_float_image(img):
        """Return ``img`` as a float tensor, rescaling only raw uint8 pixels.

        ``transforms.ToTensor`` already yields floats in [0, 1], so only uint8
        input is divided by 255.
        """
        img = torch.as_tensor(img)
        if img.dtype == torch.uint8:
            return img.float() / 255.0
        return img.float()

    def construct_exemplar_set(self, dataset, m):
        """Store L2-normalised mean features for each class present in ``dataset``.

        For each class id in ``range(num_classes)`` that occurs in ``dataset``,
        up to ``m`` randomly shuffled samples (``random.shuffle``) are passed
        through ``self.model.features`` and averaged. The normalised means are
        appended to ``self.exemplar_sets`` as one list for this call and
        written to ``self.class_means[class_idx]``. Entries for classes absent
        from ``dataset`` are left unchanged.

        Args:
            dataset: a ``torch.utils.data.Subset`` whose ``.dataset`` exposes
                ``targets`` and returns ``(image_tensor, label)``.
            m: maximum number of samples per class used for the mean.
        """
        base = dataset.dataset
        targets = base.targets
        # Group this subset's indices by class in a single pass, preserving
        # the subset's original index order within each class.
        indices_by_class = {}
        for idx in dataset.indices:
            indices_by_class.setdefault(int(targets[idx]), []).append(idx)

        task_means = []
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for class_idx in range(self.num_classes):
                    class_indices = indices_by_class.get(class_idx, [])
                    random.shuffle(class_indices)
                    class_indices = class_indices[:m]
                    if not class_indices:
                        # Class absent from this task: a mean would be NaN.
                        continue

                    images = torch.stack([self._as_float_image(base[idx][0]) for idx in class_indices])
                    features = self.model.features(images.to(self.device))
                    class_mean = features.mean(dim=0).cpu().numpy()
                    norm = np.linalg.norm(class_mean)
                    if norm > 0:  # all-zero features (e.g. after ReLU) stay zero
                        class_mean = class_mean / norm

                    task_means.append(class_mean)
                    self.class_means[class_idx] = class_mean
        finally:
            self.model.train(was_training)

        self.exemplar_sets.append(task_means)

    def reduce_exemplar_set(self, m):
        """Truncate every per-task list in ``self.exemplar_sets`` to at most ``m`` entries."""
        # Build new lists and reassign: rebinding a loop variable would leave
        # the stored lists untouched.
        self.exemplar_sets = [exemplar_set[:m] for exemplar_set in self.exemplar_sets]

In [None]:
# Fixed seeds make DataLoader shuffling, weight init and the random.shuffle
# used when sampling class features repeatable across Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define hyperparameters
num_classes = 10
batch_size = 8
memory_size = 2000  # total budget shared by all classes seen so far
lr = 0.0005
num_epochs = 20
classes_per_task = 5  # You can set this to any number you prefer

# Define transforms (ToTensor scales pixels to [0, 1])
transform = transforms.Compose([
    transforms.ToTensor()
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Initialize iCaRL
icarl = iCaRL(device, num_classes, batch_size, memory_size)

# Partition the training set into consecutive blocks of classes (one per task)
num_tasks = (num_classes + classes_per_task - 1) // classes_per_task
train_datasets = []
for task_idx in range(num_tasks):
    start_class_idx = task_idx * classes_per_task
    end_class_idx = min(start_class_idx + classes_per_task, num_classes)
    indices = [idx for idx, label in enumerate(train_dataset.targets)
               if start_class_idx <= label < end_class_idx]
    train_datasets.append(Subset(train_dataset, indices))

# Incremental training. The loop variable is task_dataset so the full
# CIFAR-10 train_dataset is not shadowed.
for i, task_dataset in enumerate(train_datasets):
    print(f'Training on task {i + 1}')
    icarl.train(task_dataset, lr, num_epochs)
    # iCaRL splits the fixed memory budget evenly over every class seen so far,
    # so the per-class budget shrinks as tasks are added.
    classes_seen = min((i + 1) * classes_per_task, num_classes)
    m = memory_size // classes_seen
    print(f'Constructing exemplar set for task {i + 1}')
    icarl.construct_exemplar_set(task_dataset, m)
    if i > 0:
        print(f'Reducing exemplar set for task {i}')
        icarl.reduce_exemplar_set(m)

# Evaluation on the full 10-class test set
with torch.no_grad():
    test_accuracy = icarl.test(test_dataset, lr, num_epochs)
test_accuracy

Files already downloaded and verified
Files already downloaded and verified
Training on task 1
Epoch [1/20], Loss: 1.6443
Epoch [2/20], Loss: 1.3498
Epoch [3/20], Loss: 1.2510
Epoch [4/20], Loss: 1.1953
Epoch [5/20], Loss: 1.1490
Epoch [6/20], Loss: 1.1053
Epoch [7/20], Loss: 1.0663
Epoch [8/20], Loss: 1.0291
Epoch [9/20], Loss: 1.0000
Epoch [10/20], Loss: 0.9747
Epoch [11/20], Loss: 0.9560
Epoch [12/20], Loss: 0.9373
Epoch [13/20], Loss: 0.9208
Epoch [14/20], Loss: 0.9042
Epoch [15/20], Loss: 0.8912
Epoch [16/20], Loss: 0.8788
Epoch [17/20], Loss: 0.8661
Epoch [18/20], Loss: 0.8538
Epoch [19/20], Loss: 0.8466
Epoch [20/20], Loss: 0.8331
Constructing exemplar set for task 1
Training on task 2
Epoch [1/20], Loss: 1.6901
Epoch [2/20], Loss: 0.9360
Epoch [3/20], Loss: 0.8455
Epoch [4/20], Loss: 0.7922
Epoch [5/20], Loss: 0.7526
Epoch [6/20], Loss: 0.7224
Epoch [7/20], Loss: 0.6979
Epoch [8/20], Loss: 0.6749
Epoch [9/20], Loss: 0.6565
Epoch [10/20], Loss: 0.6426
Epoch [11/20], Loss: 0.6277

  _warn_prf(average, modifier, msg_start, len(result))


Test set: Average loss: 0.0603, Accuracy: 3950/10000 (40%)
Precision: 0.2002, Recall: 0.3950, F1 Score: 0.2649

