## Dependencies

In [None]:
import json

import random

import torch

import torch.nn as nn

import torch.optim as optim

from torchvision import models, transforms, datasets

from torch.utils.data import DataLoader, random_split, Subset

import numpy as np

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix

import seaborn as sb

import matplotlib.pyplot as plt

from tqdm import tqdm



# Device configuration: use the GPU when CUDA is available, otherwise fall back to the CPU.

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report the selected device in the cell output.
print(f"Device used: {DEVICE}")

## Utils

In [None]:
# Set random seed for reproducibility. set_seed is called with its defaults;
# what exactly it seeds is defined in utils/utils.py, not in this notebook.
from utils.utils import set_seed
set_seed()

# pathlib.Path is used by sisa_train (defined further down) to create the
# directory that receives the saved models and metrics.
from pathlib import Path

In [None]:
# Mini-batch size intended for the data loaders.
BATCH_SIZE = 32

# Learning rate passed to init_model_cnn (see sisa_test).
LEARNING_RATE = 0.001

# Number of epochs for each slice-level training run (see retrain_sisa_framework).
EPOCHS = 5

## Simple CNN

In [None]:
from models.simple_cnn import init_model_cnn

## Init model

In [None]:
from utils.train_test_metrics import train_model

## SISA structure

In [None]:
from methods.sisa.sisa_utils import create_sisa_structure

from methods.sisa.sisa_utils import recreate_sisa_dataloaders

In [None]:
# init_model_cnn returns several values; only the last one, the transform that
# is handed to the MNIST datasets below, is needed here. The rest are discarded.
*_, transform = init_model_cnn()
# Hand-written alternative kept for reference:
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

In [None]:
# SISA layout: number of shards, and number of slices inside each shard.
# Both values are passed to create_sisa_structure below.
SHARDS = 3
SLICES = 5

In [None]:
# MNIST train and test splits, stored under ./data (download=True fetches them
# when they are not already present). Both use the transform defined above.
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

In [None]:
# Build the shard/slice layout for the training set.
# NOTE(review): presumably this writes the layout to 'sisa_structure.json',
# which later cells read — confirm in methods/sisa/sisa_utils.py.
create_sisa_structure(dataset, shards=SHARDS, slices_per_shard=SLICES)

In [None]:
# JSON file describing the shard/slice layout; passed to
# recreate_sisa_dataloaders as info_file_path.
sisa_structure_file = 'sisa_structure.json'

In [None]:
# Rebuild the data loaders from the stored SISA layout. The result is consumed
# by sisa_train / sisa_test as a mapping with one entry per shard (each holding
# per-slice "train"/"val" loaders) plus a "test" entry.
# The batch size comes from the config cell instead of a duplicated literal.
dataloaders = recreate_sisa_dataloaders(
    datasets=(dataset, test_dataset),
    info_file_path=sisa_structure_file,
    batch_size=BATCH_SIZE,
    val_ratio=0.1,
)

In [None]:
from utils.train_test_metrics import test_model

In [None]:
def sisa_test(model, model_path,  test_loader):
    """
    Run a model over a test loader and collect its class predictions.

    NOTE: a later cell defines another function that is also named
    `sisa_test`; once that cell has run, this definition is shadowed.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        model_path: Only used to label the progress bar; no weights are
            loaded from it here.
        test_loader: Iterable yielding (inputs, labels) batches.

    Returns:
        tuple: (predictions, true_labels), two flat lists with one entry
        per evaluated sample.
    """
    model.eval()

    collected_preds, collected_labels = [], []

    # Gradients are not needed for inference.
    with torch.no_grad():
        progress = tqdm(test_loader, desc=f"Evaluating model: {model_path}")
        for batch_inputs, batch_labels in progress:
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            scores = model(batch_inputs)
            # Index of the highest score along dim 1 is the predicted class.
            batch_preds = torch.max(scores, 1)[1]
            collected_preds.extend(batch_preds.cpu().numpy())
            collected_labels.extend(batch_labels.cpu().numpy())

    return collected_preds, collected_labels

In [None]:
def sisa_train(dataloaders, num_epochs, save_models_metrics_dir):
    """
    Train one model per shard, slice after slice, and save each shard's final weights.

    Args:
        dataloaders (dict): Maps each shard id to a dict of
            {slice_id: {"train": loader, "val": loader}}. A top-level
            "test" entry, if present, is skipped.
        num_epochs (int): Number of epochs passed to train_model for every slice.
        save_models_metrics_dir (str | Path): Directory receiving
            "<shard_id>_final_model.pth"; created if it does not exist.
    """
    save_path = Path(save_models_metrics_dir)
    save_path.mkdir(parents=True, exist_ok=True)  # Ensure the directory exists

    for shard_id, slices in dataloaders.items():
        if shard_id == "test":  # Skip the test loader
            continue

        print(f"Training shard: {shard_id}")

        # Fresh model per shard, built with the notebook-wide learning rate
        # (the same way sisa_test builds its models).
        model, base_model_name, criterion, optimizer, _ = init_model_cnn(learning_rate=LEARNING_RATE)

        # The same model instance is trained on each slice in turn.
        for slice_id, loaders in slices.items():
            print(f"  Training slice: {slice_id}")

            # Derive the name from the untouched base name on every iteration;
            # rebuilding it from the previous slice's name made the prefix grow
            # with each slice.
            slice_model_name = f"{shard_id}_{slice_id}" + base_model_name

            train_model(
                model=model,
                model_name=slice_model_name,
                train_loader=loaders["train"],
                val_loader=loaders["val"],
                criterion=criterion,
                optimizer=optimizer,
                num_epochs=num_epochs,
            )

        # Join with pathlib so absolute output directories work as well
        # (prefixing "./" to an absolute path points somewhere else).
        shard_model_path = save_path / f"{shard_id}_final_model.pth"
        torch.save(model.state_dict(), shard_model_path)
        print(f"Saved final shard model to {shard_model_path}")

In [None]:
import os
def sisa_test(dataloaders, saved_models_metrics_dir, clear_solo_models_preds=True):
    """
    Evaluate every shard's final model on the test loader and merge the results.

    For each shard, test_model is called with the shard's saved weights; the
    predictions JSON expected at
    "<saved_models_metrics_dir>/<model_name>_predictions.json" is then read
    back. The merged results are written to "sisa_final_evaluation.json".

    Args:
        dataloaders (dict): Must contain a "test" loader; every other key is
            treated as a shard id.
        saved_models_metrics_dir (str): Directory holding
            "<shard_id>_final_model.pth" and the per-model prediction files.
        clear_solo_models_preds (bool): When True, each per-shard predictions
            file is deleted after it has been merged.
    """
    test_loader = dataloaders["test"]
    evaluation_results = {}

    for shard_id in dataloaders.keys():
        if shard_id == "test":
            continue

        # Final weights saved for this shard
        shard_model_path = f"{saved_models_metrics_dir}/{shard_id}_final_model.pth"

        # Fresh model; only the model and its base name are needed here
        model, base_name, *_ = init_model_cnn(learning_rate=LEARNING_RATE)
        shard_model_name = f"{shard_id}_" + base_name

        test_model(model, shard_model_name, shard_model_path, test_loader)

        # Read back the predictions file for this shard model
        predictions_json = f"{saved_models_metrics_dir}/{shard_model_name}_predictions.json"
        with open(predictions_json, "r") as handle:
            shard_data = json.load(handle)

        # Store plain ints so the merged dict is JSON-serialisable
        evaluation_results[shard_id] = {
            "predictions": list(map(int, shard_data["predictions"])),
            "true_labels": list(map(int, shard_data["true_labels"])),
        }

        if clear_solo_models_preds:
            os.remove(predictions_json)

    # Persist the merged per-shard predictions and labels
    output_path = "sisa_final_evaluation.json"
    with open(output_path, "w") as handle:
        json.dump(evaluation_results, handle)

    print(f"Evaluation results saved to {output_path}")

In [None]:
def calculate_shard_metrics(true_labels, shard_predictions):
    """
    Compute classification metrics for one set of predictions.

    Precision, recall and F1 are weighted averages over the classes, and
    undefined values (zero division) are reported as 0.

    Args:
        true_labels (list): Ground-truth labels.
        shard_predictions (list): Predicted labels, aligned with true_labels.

    Returns:
        dict: Keys "accuracy", "precision", "recall" and "f1_score".
    """
    # Options shared by the three class-averaged scores
    weighted = {"average": "weighted", "zero_division": 0}
    return {
        "accuracy": accuracy_score(true_labels, shard_predictions),
        "precision": precision_score(true_labels, shard_predictions, **weighted),
        "recall": recall_score(true_labels, shard_predictions, **weighted),
        "f1_score": f1_score(true_labels, shard_predictions, **weighted),
    }

In [None]:
def aggregate_predictions(true_labels, predictions, weights, num_classes=10):
    """
    Perform weighted voting aggregation of shard predictions.

    Every shard adds its weight to the class it predicts for a sample; the
    class with the largest total wins (ties go to the lowest class index).

    Args:
        true_labels (list): True labels for the dataset; only the length is used.
        predictions (dict): Maps shard ids to a sequence of predicted class
            indices, one per sample. When the keys are exactly
            "shard_0" .. "shard_{n-1}" the shards are read in that order,
            otherwise in the dict's insertion order.
        weights (list): One weight (e.g. accuracy) per shard, in the same
            order as the shards.
        num_classes (int): Number of classes. Defaults to 10 (MNIST).

    Returns:
        numpy.ndarray: Aggregated class index for every sample.

    Raises:
        ValueError: If the number of weights differs from the number of
            shards, a shard does not provide one prediction per sample, or a
            prediction lies outside [0, num_classes).
    """
    num_samples = len(true_labels)
    num_shards = len(predictions)

    shard_weights = np.asarray(weights, dtype=float).reshape(-1)
    if shard_weights.shape[0] != num_shards:
        raise ValueError(
            f"Expected {num_shards} weights (one per shard), got {shard_weights.shape[0]}"
        )

    # Keep the historical "shard_<i>" ordering when those keys are present;
    # fall back to insertion order for any other naming scheme.
    indexed_keys = [f"shard_{i}" for i in range(num_shards)]
    if all(key in predictions for key in indexed_keys):
        ordered_keys = indexed_keys
    else:
        ordered_keys = list(predictions)

    weighted_votes = np.zeros((num_classes, num_samples))
    sample_positions = np.arange(num_samples)

    for key, weight in zip(ordered_keys, shard_weights):
        shard_preds = np.asarray(predictions[key], dtype=np.int64).reshape(-1)
        if shard_preds.shape[0] != num_samples:
            raise ValueError(
                f"Shard {key} has {shard_preds.shape[0]} predictions, expected {num_samples}"
            )
        if shard_preds.size and (shard_preds.min() < 0 or shard_preds.max() >= num_classes):
            raise ValueError(f"Shard {key} contains predictions outside [0, {num_classes})")
        # Each (class, sample) pair occurs at most once per shard, so a plain
        # fancy-indexed in-place add is safe here.
        weighted_votes[shard_preds, sample_positions] += weight

    # Final prediction: class with the highest weighted vote
    return np.argmax(weighted_votes, axis=0)

In [None]:
def _print_metrics(header, metrics):
    """Print one metrics dict (accuracy, precision, recall, f1_score) under a header line."""
    print(header)
    print(f"  Accuracy: {metrics['accuracy']:.4f}")
    print(f"  Precision: {metrics['precision']:.4f}")
    print(f"  Recall: {metrics['recall']:.4f}")
    print(f"  F1 Score: {metrics['f1_score']:.4f}")


def evaluate_aggregated_model(results):
    """
    Evaluate aggregated predictions and print metrics, confusion matrix.

    Each shard is scored on its own; the shard predictions are then combined
    by accuracy-weighted voting (aggregate_predictions) and the combined
    prediction is scored and plotted as a confusion matrix.

    Args:
        results (dict): Maps each shard id to a dict with "predictions" and
            "true_labels" lists. All shards must share the same true labels.

    Raises:
        ValueError: If results is empty or the true labels differ between shards.
    """
    if not results:
        raise ValueError("No shard results to evaluate!")

    # Check label consistency for every shard before doing any work.
    true_labels = None
    for shard_id, shard_data in results.items():
        shard_true_labels = shard_data["true_labels"]
        if true_labels is None:
            true_labels = shard_true_labels
        elif true_labels != shard_true_labels:
            raise ValueError(f"True labels in shard {shard_id} do not match other shards!")

    # Per-shard metrics; each accuracy doubles as that shard's voting weight.
    shard_predictions = {}
    shard_accuracies = []
    for position, (shard_id, shard_data) in enumerate(results.items()):
        shard_preds = shard_data["predictions"]
        metrics = calculate_shard_metrics(true_labels, shard_preds)
        shard_accuracies.append(metrics["accuracy"])
        # aggregate_predictions addresses shards as "shard_<i>"; keying by
        # position keeps predictions and weights aligned whatever the ids are.
        shard_predictions[f"shard_{position}"] = shard_preds
        _print_metrics(f"Shard {shard_id} Metrics:", metrics)

    # The previous call passed an empty list and no weights.
    aggregated_predictions = aggregate_predictions(true_labels, shard_predictions, shard_accuracies)

    # Calculate metrics for the aggregated predictions
    aggregated_metrics = calculate_shard_metrics(true_labels, aggregated_predictions)
    _print_metrics("\nAggregated Model Metrics:", aggregated_metrics)

    # Fix the label set so the matrix is always 10x10 and matches the tick
    # labels, even when a class is absent from both labels and predictions.
    class_labels = list(range(10))
    cm = confusion_matrix(true_labels, aggregated_predictions, labels=class_labels)
    plt.figure(figsize=(10, 8))
    sb.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_labels, yticklabels=class_labels)
    plt.xlabel("Predicted Labels")
    plt.ylabel("True Labels")
    plt.title("Confusion Matrix of Aggregated Model")
    plt.show()

In [None]:
# Load the merged per-shard predictions that sisa_test writes.
# NOTE(review): no visible cell calls sisa_train / sisa_test before this point,
# so on a fresh kernel this file has to exist on disk already.
with open("sisa_final_evaluation.json", "r") as f:
    results = json.load(f)

In [None]:
# Print per-shard and aggregated metrics and show the confusion matrix.
evaluate_aggregated_model(results)

In [None]:
def _slice_sort_key(name):
    """Sort key ordering names by their numeric suffix ("slice_2" before "slice_10")."""
    text = str(name)
    try:
        return (0, int(text.rsplit("_", 1)[-1]), "")
    except ValueError:
        # Names without a numeric suffix go last, in alphabetical order.
        return (1, 0, text)


def _locate_sample(sisa_structure, index, label):
    """Return (shard, slice_name, position) of the first slice holding index with that label, or None."""
    for shard, slices in sisa_structure.items():
        for slice_name, slice_data in slices.items():
            try:
                position = slice_data["indices"].index(index)
            except ValueError:
                continue
            # "indices" and "classes" are parallel lists, so the label has to
            # match at the same position, not merely occur in the slice.
            if slice_data["classes"][position] == label:
                return shard, slice_name, position
    return None


def update_sisa_structure(unlearn_samples_path, sisa_structure_path, updated_structure_path, deleted_samples_path):
    """
    Remove the samples to unlearn from the SISA structure and report what must be retrained.

    Every requested sample is looked up in the structure; when found it is
    removed from its slice and recorded as deleted. The updated structure and
    the deleted samples are written to disk. No retraining happens here.

    Args:
        unlearn_samples_path (str): Path to JSON file with samples to unlearn
            (a list of {"index": ..., "class": ...} objects).
        sisa_structure_path (str): Path to SISA structure JSON file
            ({shard: {slice: {"indices": [...], "classes": [...]}}}).
        updated_structure_path (str): Path to save updated SISA structure JSON file.
        deleted_samples_path (str): Path to save deleted samples JSON file.

    Returns:
        dict: Maps each affected shard to its affected slice names, without
        duplicates and ordered by slice number, so the first entry is the
        earliest slice to retrain from.
    """
    # Load samples to unlearn
    with open(unlearn_samples_path, "r") as f:
        unlearn_samples = json.load(f)

    # Load the SISA structure
    with open(sisa_structure_path, "r") as f:
        sisa_structure = json.load(f)

    # Track affected slices per shard, and the samples actually removed
    affected_slices = {}
    deleted_samples = []

    for sample in unlearn_samples:
        location = _locate_sample(sisa_structure, sample["index"], sample["class"])
        if location is None:
            continue  # Not part of the structure (or already removed)

        shard, slice_name, position = location
        slice_data = sisa_structure[shard][slice_name]

        # Remove the sample from the slice
        slice_data["indices"].pop(position)
        slice_data["classes"].pop(position)

        affected_slices.setdefault(shard, set()).add(slice_name)
        deleted_samples.append(sample)

    # Order numerically: a plain string sort would put "slice_10" before
    # "slice_2" and report the wrong starting slice.
    affected_shards = {
        shard: sorted(affected_slices[shard], key=_slice_sort_key)
        for shard in sorted(affected_slices, key=_slice_sort_key)
    }

    # Save the updated SISA structure
    with open(updated_structure_path, "w") as f:
        json.dump(sisa_structure, f)

    # Save the deleted samples
    with open(deleted_samples_path, "w") as f:
        json.dump(deleted_samples, f)

    # Print retraining plan
    print("Retraining Plan:")
    for shard, slices in affected_shards.items():
        print(f"  Shard: {shard}, Start from Slice: {slices[0]} onward")

    return affected_shards

In [None]:
# Input and output files for the unlearning step.
mnist_samples_to_delete = 'mnist_samples_to_unlearn.json'
mnist_sisa_structure = 'sisa_structure.json'
# NOTE(review): the file name below is spelled "strucute"; it is left as is
# because other code may already read the file under this name.
updated_sisa_structure = 'updated_sisa_strucute.json'
deleted_samples = 'deleted_samples.json'

# Remove the samples from the structure and collect, per shard, the slices that must be retrained.
affected_shards = update_sisa_structure(mnist_samples_to_delete, mnist_sisa_structure, updated_sisa_structure,deleted_samples)
print(affected_shards)

In [None]:
def _slice_index(slice_ref):
    """Return the numeric slice index of an int or of a name such as "slice_3"."""
    if isinstance(slice_ref, int):
        return slice_ref
    return int(str(slice_ref).rsplit("_", 1)[-1])


def retrain_sisa_framework(updated_structure_path, affected_shards, dataset, save_path="./sisa_models",
                           test_dataset=None, batch_size=None, val_ratio=0.1):
    """
    Retrain the SISA framework starting from the flagged slices for affected shards.

    Args:
        updated_structure_path (str): Path to the updated SISA structure JSON file.
        affected_shards (dict): Maps each affected shard id to its flagged
            slices, given as slice names ("slice_3") or plain indices.
        dataset (Dataset | tuple): Training dataset, or a (train, test) pair.
        save_path (str): Directory that receives the updated shard models;
            created if it does not exist.
        test_dataset (Dataset, optional): Test dataset; required when
            `dataset` is not already a (train, test) pair.
        batch_size (int, optional): Loader batch size. Defaults to BATCH_SIZE.
        val_ratio (float): Validation split ratio passed to
            recreate_sisa_dataloaders.

    Raises:
        ValueError: If no test dataset is available.
    """
    # recreate_sisa_dataloaders is called the same way as when the loaders are
    # first built in this notebook: a (train, test) pair plus the structure's path.
    if isinstance(dataset, (tuple, list)):
        dataset_pair = tuple(dataset)
    elif test_dataset is not None:
        dataset_pair = (dataset, test_dataset)
    else:
        raise ValueError("Pass `dataset` as a (train, test) pair or provide `test_dataset`.")

    dataloaders = recreate_sisa_dataloaders(
        datasets=dataset_pair,
        info_file_path=updated_structure_path,
        batch_size=BATCH_SIZE if batch_size is None else batch_size,
        val_ratio=val_ratio,
    )

    output_dir = Path(save_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Loop through affected shards
    for shard_id, flagged_slices in affected_shards.items():
        if not flagged_slices:
            continue  # Nothing flagged for this shard

        # Flagged slices may arrive as names; compare on their numeric index.
        first_flagged_idx = min(_slice_index(slice_ref) for slice_ref in flagged_slices)
        print(f"Retraining shard: {shard_id}, starting from slice: slice_{first_flagged_idx}")

        # Initialize a new model for this shard
        # NOTE(review): the model starts from scratch and slices before the
        # first flagged one are skipped. Presumably a checkpoint taken before
        # that slice should be loaded here — confirm where train_model stores
        # its checkpoints.
        model, base_model_name, criterion, optimizer, _ = init_model_cnn(learning_rate=LEARNING_RATE)

        # Retrain starting from the lowest flagged slice
        for slice_id, loaders in dataloaders[shard_id].items():
            if _slice_index(slice_id) < first_flagged_idx:
                continue

            print(f"  Retraining slice: {slice_id}")

            # Slice-level training goes through train_model, exactly as in
            # sisa_train (which itself takes the whole dataloaders dict).
            train_model(
                model=model,
                model_name=f"{shard_id}_{slice_id}" + base_model_name,
                train_loader=loaders["train"],
                val_loader=loaders["val"],
                criterion=criterion,
                optimizer=optimizer,
                num_epochs=EPOCHS,
            )

        # Save the final model for the shard
        shard_model_path = output_dir / f"{shard_id}_final_model.pth"
        torch.save(model.state_dict(), shard_model_path)
        print(f"Saved updated model for {shard_id} to {shard_model_path}")