# In [None]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score
import pandas as pd
import warnings
# Silence warnings whose message matches these patterns. Presumably these are
# the UserWarning PyTorch raises when torch.tensor() copies an existing tensor
# (several helpers below call torch.tensor on tensors) and the FutureWarning
# torch.load() raises about its default arguments.
warnings.filterwarnings(action='ignore', message=r".*torch.tensor.*", category=UserWarning)
warnings.filterwarnings(action='ignore', message=r".*torch.load.*", category=FutureWarning)
# Define the feature extractor and classifier
class MLPFeatureExtractor(nn.Module):
    """Two-layer perceptron mapping flat inputs to feature vectors.

    Computes ``fc2(relu(fc1(x)))`` with widths
    ``input_size -> hidden_size -> feature_dim``.
    """

    def __init__(self, input_size, hidden_size, feature_dim):
        super().__init__()
        # Attribute names double as state_dict keys (fc1.*, fc2.*); fc1 is
        # built before fc2 so parameter initialisation draws from the RNG in
        # that order.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, feature_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        features = self.fc2(hidden)
        return features

class LwPClassifier:
    """Nearest-prototype classifier whose prototypes are per-class running means.

    ``prototypes[c]`` is the mean of every feature vector passed to
    ``partial_fit`` with label ``c``; ``predict`` returns the label of the
    nearest fitted prototype in Euclidean distance.
    """

    def __init__(self, n_classes=10, feature_dim=128):
        self.n_classes = n_classes
        # Row c: running mean of the features seen so far with label c
        # (all zeros until class c is first fitted).
        self.prototypes = np.zeros((n_classes, feature_dim))
        # Number of samples folded into each row of ``prototypes``.
        self.class_counts = np.zeros(n_classes)

    def partial_fit(self, X, y):
        """Fold a batch of labelled features into the running class means.

        Args:
            X: array-like of shape ``(n_samples, feature_dim)``.
            y: ``n_samples`` integer labels in ``[0, n_classes)``; lists,
                numpy arrays and CPU torch tensors are accepted.

        Raises:
            ValueError: if ``X`` does not have one ``feature_dim``-wide row
                per label, or a label is outside ``[0, n_classes)``.
        """
        labels = np.asarray(y).reshape(-1).astype(np.int64)
        if labels.size == 0:
            return
        features = np.asarray(X, dtype=np.float64)
        expected_shape = (labels.size, self.prototypes.shape[1])
        if features.shape != expected_shape:
            raise ValueError(f"expected X of shape {expected_shape}, got {features.shape}")
        if labels.min() < 0 or labels.max() >= self.n_classes:
            raise ValueError(
                f"labels must lie in [0, {self.n_classes}); "
                f"got range [{labels.min()}, {labels.max()}]"
            )

        batch_counts = np.bincount(labels, minlength=self.n_classes)
        batch_sums = np.zeros_like(self.prototypes)
        np.add.at(batch_sums, labels, features)

        touched = batch_counts > 0
        new_counts = self.class_counts + batch_counts
        # Weighted merge of the old mean with the batch sum; mathematically
        # identical to applying the running-mean update one sample at a time.
        self.prototypes[touched] = (
            self.prototypes[touched] * self.class_counts[touched, np.newaxis]
            + batch_sums[touched]
        ) / new_counts[touched, np.newaxis]
        self.class_counts = new_counts

    def predict(self, X):
        """Return the nearest fitted prototype's class for each row of ``X``.

        Classes that were never fitted still hold an all-zero placeholder
        prototype; they are excluded so they cannot win merely because a
        sample lies near the origin. Before any fit, all distances are
        equal and every sample maps to class 0.
        """
        features = np.asarray(X, dtype=np.float64)
        diffs = features[:, np.newaxis, :] - self.prototypes[np.newaxis, :, :]
        distances = np.linalg.norm(diffs, axis=2)
        seen = self.class_counts > 0
        if seen.any():
            distances[:, ~seen] = np.inf
        return np.argmin(distances, axis=1)

def load_dataset(path, is_train=True):
    """Load a ``{'data': ..., 'targets': ...}`` file as flat float32 inputs.

    Args:
        path: file readable by ``torch.load`` whose loaded object supports
            ``['data']`` and ``.get('targets')``.
        is_train: kept for backward compatibility; training and evaluation
            files are handled identically.

    Returns:
        ``(inputs, labels)``: ``inputs`` is a float32 tensor of shape
        ``(n_samples, n_features)`` (all trailing dimensions flattened);
        ``labels`` is an int64 tensor, or ``None`` when the file has no
        ``'targets'`` entry.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted data.
    # map_location="cpu" lets GPU-saved files load anywhere; callers turn
    # results into numpy arrays, which requires CPU tensors anyway.
    data = torch.load(path, map_location="cpu")

    # Force float32/int64 even when the file already stores tensors (e.g.
    # uint8 images), so every dataset matches the float32 model and can be
    # concatenated with the others without dtype surprises.
    inputs = torch.as_tensor(data['data']).to(torch.float32)
    labels = data.get('targets', None)
    if labels is not None:
        labels = torch.as_tensor(labels).to(torch.long)

    if inputs.dim() > 2:
        # reshape (unlike view) also works for non-contiguous tensors.
        inputs = inputs.reshape(inputs.shape[0], -1)

    return inputs, labels

def train_initial_model(feature_extractor, classifier, dataset, labels):
    """Embed ``dataset`` and fold it into ``classifier`` with the given labels.

    Args:
        feature_extractor: module mapping a float32 batch to feature vectors;
            switched to eval mode and run without gradient tracking.
        classifier: object exposing ``partial_fit(features, labels)``.
        dataset: tensor or array-like of inputs, converted to float32.
        labels: one label per row of ``dataset``, passed through unchanged.
    """
    feature_extractor.eval()
    with torch.no_grad():
        # as_tensor reuses an existing float32 tensor instead of copying it;
        # torch.tensor always copies and warns when handed a tensor.
        inputs = torch.as_tensor(dataset, dtype=torch.float32)
        features = feature_extractor(inputs).numpy()
    classifier.partial_fit(features, labels)

def pseudo_label_and_update(feature_extractor, classifier, dataset):
    """Label ``dataset`` with the classifier's own predictions, then fit on them.

    Self-training step: features are extracted once, ``classifier.predict``
    assigns each sample a pseudo-label, and the same features are passed to
    ``classifier.partial_fit`` together with those pseudo-labels.

    Args:
        feature_extractor: module applied in eval mode without gradients.
        classifier: object exposing ``predict`` and ``partial_fit``.
        dataset: tensor or array-like of inputs, converted to float32.
    """
    feature_extractor.eval()
    with torch.no_grad():
        # as_tensor reuses an existing float32 tensor instead of copying it;
        # torch.tensor always copies and warns when handed a tensor.
        inputs = torch.as_tensor(dataset, dtype=torch.float32)
        features = feature_extractor(inputs).numpy()
    pseudo_labels = classifier.predict(features)
    classifier.partial_fit(features, pseudo_labels)

def evaluate_model(feature_extractor, classifier, dataset, true_labels):
    """Return the classifier's accuracy on ``dataset`` against ``true_labels``.

    Args:
        feature_extractor: module applied in eval mode without gradients.
        classifier: object exposing ``predict(features)``.
        dataset: tensor or array-like of inputs, converted to float32.
        true_labels: ground-truth label per row of ``dataset``.

    Returns:
        Fraction of correct predictions, as computed by ``accuracy_score``.

    Raises:
        ValueError: if ``true_labels`` is None (e.g. an evaluation file loaded
            without a ``'targets'`` entry).
    """
    if true_labels is None:
        raise ValueError("evaluate_model requires ground-truth labels, got None")
    feature_extractor.eval()
    with torch.no_grad():
        # as_tensor reuses an existing float32 tensor instead of copying it;
        # torch.tensor always copies and warns when handed a tensor.
        inputs = torch.as_tensor(dataset, dtype=torch.float32)
        features = feature_extractor(inputs).numpy()
    predictions = classifier.predict(features)
    return accuracy_score(true_labels, predictions)

def _load_eval_sets(base_path):
    """Load the ten evaluation sets ``{base_path}/{j}_eval_data.tar.pth``, j = 1..10."""
    return [
        load_dataset(f"{base_path}/{j}_eval_data.tar.pth", is_train=False)
        for j in range(1, 11)
    ]


def _evaluate_on(feature_extractor, classifier, eval_sets):
    """Return the current model's accuracy on each ``(data, labels)`` pair, in order."""
    return [
        evaluate_model(feature_extractor, classifier, data, labels)
        for data, labels in eval_sets
    ]


def _update_classifier(feature_extractor, classifier, train_data, train_labels):
    """Fold one training set into the prototypes exactly once.

    Ground-truth labels are used when the file provides them; otherwise the
    set is pseudo-labelled by the current classifier before being folded in.

    Raises:
        ValueError: if the set is unlabeled and the classifier has not yet
            been fitted on any data.
    """
    if train_labels is not None:
        train_initial_model(feature_extractor, classifier, train_data, train_labels)
    elif not np.any(classifier.class_counts):
        # Pseudo-labels from all-zero prototypes would all be class 0.
        raise ValueError(
            "cannot pseudo-label an unlabeled training set before the "
            "classifier has been fitted on labeled data"
        )
    else:
        pseudo_label_and_update(feature_extractor, classifier, train_data)


def continual_learning(base_path_train_1_to_10, base_path_eval_1_to_10, base_path_train_11_to_20, base_path_eval_11_to_20, input_size, hidden_size, feature_dim, n_classes):
    """Run the f1..f20 continual-learning sequence and return its accuracy matrix.

    One MLPFeatureExtractor / LwPClassifier pair is updated on ten training
    sets from the first directory pair (giving f1..f10) and then on ten from
    the second (giving f11..f20). Each training set is folded into the
    prototypes once, using its labels if present and pseudo-labels otherwise.
    Nothing here trains the feature extractor: it keeps its random
    initialisation and only the prototypes change.

    Args:
        base_path_train_1_to_10, base_path_eval_1_to_10: directories holding
            ``{i}_train_data.tar.pth`` / ``{i}_eval_data.tar.pth``, i = 1..10,
            for D1..D10.
        base_path_train_11_to_20, base_path_eval_11_to_20: the same layout
            (files also numbered 1..10) for D11..D20.
        input_size, hidden_size, feature_dim: MLPFeatureExtractor widths.
        n_classes: number of classifier prototypes.

    Returns:
        20 rows (f1..f20), each with 20 entries (D1..D20). Rows f1..f10 are
        evaluated on D1..D10 only and hold ``None`` for D11..D20.

    Raises:
        ValueError: if the first training set has no labels.
    """
    feature_extractor = MLPFeatureExtractor(input_size, hidden_size, feature_dim)
    classifier = LwPClassifier(n_classes=n_classes, feature_dim=feature_dim)
    accuracies = []

    # Evaluation sets never change, so each file is read once and reused
    # for every model instead of being reloaded per model.
    eval_sets_1_to_10 = _load_eval_sets(base_path_eval_1_to_10)

    # Models f1..f10: train on part one, evaluate on D1..D10.
    for i in range(1, 11):
        train_data, train_labels = load_dataset(f"{base_path_train_1_to_10}/{i}_train_data.tar.pth", is_train=True)
        # Only the new set is folded in: re-feeding earlier sets into the
        # running-mean prototypes would count those samples again each round.
        _update_classifier(feature_extractor, classifier, train_data, train_labels)
        row = _evaluate_on(feature_extractor, classifier, eval_sets_1_to_10)
        # D11..D20 are not evaluated for f1..f10; pad to one slot per dataset.
        row.extend([None] * (20 - len(row)))
        accuracies.append(row)

    eval_sets_11_to_20 = _load_eval_sets(base_path_eval_11_to_20)

    # Models f11..f20: train on part two, evaluate on D1..D20.
    for i in range(1, 11):
        train_data, train_labels = load_dataset(f"{base_path_train_11_to_20}/{i}_train_data.tar.pth", is_train=True)
        _update_classifier(feature_extractor, classifier, train_data, train_labels)
        accuracies.append(
            _evaluate_on(feature_extractor, classifier, eval_sets_1_to_10)
            + _evaluate_on(feature_extractor, classifier, eval_sets_11_to_20)
        )

    return accuracies

# Dataset locations on the mounted Google Drive: part one supplies D1..D10,
# part two supplies D11..D20, each split into train_data/ and eval_data/.
base_path_train_1_to_10 = "/content/drive/MyDrive/mini-project-2-dataset/dataset/part_one_dataset/train_data"
base_path_eval_1_to_10 = "/content/drive/MyDrive/mini-project-2-dataset/dataset/part_one_dataset/eval_data"

base_path_train_11_to_20 = "/content/drive/MyDrive/mini-project-2-dataset/dataset/part_two_dataset/train_data"
base_path_eval_11_to_20 = "/content/drive/MyDrive/mini-project-2-dataset/dataset/part_two_dataset/eval_data"

# Model sizes: input_size must equal the flattened width of each sample.
input_size, hidden_size, feature_dim, n_classes = 3072, 256, 128, 10

# Train and evaluate models f1..f20 in sequence; one accuracy row per model.
accuracy_matrix_11_20 = continual_learning(
    base_path_train_1_to_10,
    base_path_eval_1_to_10,
    base_path_train_11_to_20,
    base_path_eval_11_to_20,
    input_size,
    hidden_size,
    feature_dim,
    n_classes,
)

# Labelled table view of the matrix: rows f1..f20, columns D1..D20.
df_accuracy_matrix = pd.DataFrame(
    accuracy_matrix_11_20,
    index=[f"f{k}" for k in range(1, 21)],
    columns=[f"D{k}" for k in range(1, 21)],
)

# Report f11..f20 on every dataset each has been trained on so far:
# model fN covers D1..DN (row index N-1, first N columns).
for model_index in range(10, 20):
    dataset_limit = model_index + 1
    model_accuracies = accuracy_matrix_11_20[model_index][:dataset_limit]
    print(
        f"Accuracies for f{model_index + 1} corresponding to D1 to D{dataset_limit}: "
        f"{model_accuracies}"
    )


  data = torch.load(path)
  features = feature_extractor(torch.tensor(dataset, dtype=torch.float32)).numpy()
  previous_labels = torch.cat([previous_labels, torch.tensor(train_labels, dtype=torch.long)])
  features = feature_extractor(torch.tensor(dataset, dtype=torch.float32)).numpy()
  data = torch.load(path)
  features = feature_extractor(torch.tensor(dataset, dtype=torch.float32)).numpy()
  data = torch.load(path)
  features = feature_extractor(torch.tensor(dataset, dtype=torch.float32)).numpy()
  data = torch.load(path)
  features = feature_extractor(torch.tensor(dataset, dtype=torch.float32)).numpy()
  data = torch.load(path)
  features = feature_extractor(torch.tensor(dataset, dtype=torch.float32)).numpy()
  data = torch.load(path)
  features = feature_extractor(torch.tensor(dataset, dtype=torch.float32)).numpy()
  data = torch.load(path)
  features = feature_extractor(torch.tensor(dataset, dtype=torch.float32)).numpy()
  data = torch.load(path)
  features = feature_extractor(to