# Machine Unlearning: Retraining and Scrubbing in a K9db-Integrated System
Final Project for DS 593 Fall 2025

Tracy Cui, Yuki Li, Yang Lu, Xin Wei

## Code Notebook 2: K9db Integration

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import time
import copy
from tabulate import tabulate

import sqlite3
import random
from contextlib import closing

In [None]:
# This notebook may be run on its own, without an earlier notebook having
# defined CONFIG. In that case, fall back to the default parameters below.
if "CONFIG" not in globals():
    CONFIG = dict(
        total_users=100,   # number of simulated data owners
        batch_size=128,    # DataLoader batch size
        seed=42,
        zipf_param=1.5,    # exponent of the Zipfian ownership distribution
        # Prefer the GPU when one is available.
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )

class K9dbMock:
    """SQLite-backed mock of a K9db store that tracks MNIST image ownership.

    Every MNIST training image is assigned to exactly one user. Ownership is
    stored in two in-memory tables (``users`` and ``images``) linked by a
    foreign key declared ``ON DELETE CASCADE``, so deleting a user row also
    removes that user's image rows.

    Attributes:
        db: The in-memory ``sqlite3`` connection.
        cursor: A cursor on ``db`` shared by all methods.
        full_mnist_dataset: The complete MNIST training set (never filtered).
        data_map: ``{user_id: [global_index, ...]}`` as assigned at
            construction time. ``execute_delete_command`` does not modify it.
    """

    def __init__(self):
        """Create the schema, load MNIST, and assign every image to a user.

        Reads ``total_users`` and ``zipf_param`` from the module-level
        ``CONFIG`` dictionary.
        """
        print(">>> [K9db] Booting up K9db SQL Backend for MNIST...")

        # 1. Setup SQLite in Memory
        self.db = sqlite3.connect(":memory:")
        self.cursor = self.db.cursor()

        # SQLite only enforces foreign keys (and thus the cascade) when asked to.
        self.cursor.execute("PRAGMA foreign_keys = ON;")

        # Create Schema: Users -> Images
        self.cursor.execute("CREATE TABLE users (user_id INTEGER PRIMARY KEY)")
        self.cursor.execute("""
            CREATE TABLE images (
                global_index INTEGER PRIMARY KEY,
                owner_id INTEGER,
                FOREIGN KEY (owner_id) REFERENCES users(user_id) ON DELETE CASCADE
            )
        """)
        self.cursor.execute("CREATE INDEX idx_owner ON images(owner_id)")
        self.db.commit()

        # 2. Load MNIST Data
        print(">>> [K9db] Ingesting MNIST Dataset...")
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),  # MNIST Mean/Std
        ])
        self.full_mnist_dataset = datasets.MNIST(
            root='./data', train=True, download=True, transform=transform
        )

        # 3. Distribute data using Zipfian (Power Law) among users
        print(">>> [K9db] Distributing data using Zipfian (Power Law) among users...")
        num_items = len(self.full_mnist_dataset)
        total_users = CONFIG["total_users"]
        self.data_map = self._assign_ownership(
            num_items, total_users, CONFIG["zipf_param"]
        )

        # Mirror the in-memory mapping into the SQL tables.
        user_rows = [(user_id,) for user_id in self.data_map]
        image_rows = []
        for user_id, indices in self.data_map.items():
            image_rows.extend((img_idx, user_id) for img_idx in indices)

        self.cursor.executemany("INSERT INTO users (user_id) VALUES (?)", user_rows)
        self.cursor.executemany(
            "INSERT INTO images (global_index, owner_id) VALUES (?, ?)", image_rows
        )
        self.db.commit()

        # Final check for data distribution
        total_images_in_db = self.cursor.execute("SELECT COUNT(*) FROM images").fetchone()[0]
        print(f">>> [K9db] Indexing Complete. {total_images_in_db} images mapped to {total_users} owners.")

    @staticmethod
    def _assign_ownership(num_items, total_users, zipf_param):
        """Randomly partition ``range(num_items)`` among users with Zipfian sizes.

        User ``k`` (0-based) receives a share proportional to
        ``1 / (k + 1) ** zipf_param``. The per-user counts are drawn from a
        multinomial, so they always sum to ``num_items`` and every index is
        assigned to exactly one user; a user may receive zero indices.

        NOTE: this draws from NumPy's global RNG (``np.random``), so seed it
        beforehand if a reproducible assignment is required.

        Args:
            num_items: Number of dataset indices to distribute.
            total_users: Number of users (ids ``0 .. total_users - 1``).
            zipf_param: Exponent of the power law.

        Returns:
            Dict mapping each user id to its list of indices (plain ``int``s).
        """
        ranks = np.arange(1, total_users + 1)
        weights = 1 / np.power(ranks, zipf_param)
        weights /= weights.sum()  # Normalize to sum to 1

        counts = np.random.multinomial(num_items, weights)

        all_indices = np.arange(num_items)
        np.random.shuffle(all_indices)

        # Cut the shuffled indices into consecutive chunks of the drawn sizes.
        boundaries = np.cumsum(counts)[:-1]
        chunks = np.split(all_indices, boundaries)
        return {user_id: chunk.tolist() for user_id, chunk in enumerate(chunks)}

    def _live_indices(self):
        """Return the ``global_index`` of every image row still in the DB, as ints."""
        self.cursor.execute("SELECT global_index FROM images")
        # fetchall() yields 1-tuples; unwrap them into plain ints.
        return [row[0] for row in self.cursor.fetchall()]

    def get_full_dataset(self):
        """
        Returns the original full MNIST dataset.
        """
        return self.full_mnist_dataset

    def get_data_map(self):
        """
        Returns the data_map dictionary for user-to-image index mapping.

        This is the mapping built at construction; deletions executed later
        are not reflected in it.
        """
        return self.data_map

    def get_loaders(self):
        """
        Returns a single training DataLoader for the CURRENT valid state.

        Queries the DB to see which images are still 'live' and wraps only
        those in a shuffled loader with batch size CONFIG["batch_size"].
        """
        train_sub = Subset(self.full_mnist_dataset, self._live_indices())
        return DataLoader(train_sub, batch_size=CONFIG["batch_size"], shuffle=True)

    def execute_delete_command(self, user_id):
        """
        Simulates: DELETE FROM Users WHERE id = {user_id};
        K9db Logic:
           1. We identify the user.
           2. We delete the user.
           3. The database ENGINE automatically deletes the images.

        Args:
            user_id: Id of the user to delete.

        Returns:
            A ``(forget_indices, retain_indices)`` tuple of lists of ints:
            the image indices that belonged to the user, and every image
            index remaining in the DB afterwards. If the user does not exist
            or owns no images, nothing is deleted and the result is
            ``([], <all currently live indices>)``.
        """
        print(f"\n>>> [K9db SQL EXEC] DELETE FROM users WHERE user_id = {user_id};")

        # 1. Pre-fetch Forget Set (simulating the identification step)
        self.cursor.execute("SELECT global_index FROM images WHERE owner_id = ?", (user_id,))
        forget_indices = [row[0] for row in self.cursor.fetchall()]

        if not forget_indices:
            print(f">>> [K9db] Warning: User {user_id} not found or has no data currently associated.")
            # Nothing to forget: the retain set is simply everything still
            # live, in the same list-of-ints form as the normal path.
            return [], self._live_indices()

        # 2. Execute SQL Delete (Triggers Cascade)
        self.cursor.execute("DELETE FROM users WHERE user_id = ?", (user_id,))
        self.db.commit()

        # 3. Fetch 'Retain Set' (Everything remaining)
        retain_indices = self._live_indices()

        print(f">>> [K9db] Cascade successful. Purged {len(forget_indices)} images related to user {user_id}.")

        # Return purely the indices so the pipeline can do its math
        return forget_indices, retain_indices

In [None]:
# Boot the mock K9db backend (this ingests MNIST and assigns image ownership).
db = K9dbMock()

# Take the dataset and the ownership map from the backend itself, so that
# they stay consistent with what was written to the database.
full_dataset, data_map = db.get_full_dataset(), db.get_data_map()

print(f"Total images in full_dataset: {len(full_dataset)}")
print(f"Number of users in data_map: {len(data_map)}")


>>> [K9db] Booting up K9db SQL Backend for MNIST...
>>> [K9db] Ingesting MNIST Dataset...
>>> [K9db] Distributing data using Zipfian (Power Law) among users...
>>> [K9db] Indexing Complete. 60000 images mapped to 100 owners.
Total images in full_dataset: 60000
Number of users in data_map: 100


Then we tried MNIST again, but assigned each user a specific, unique set of images. For example, user 1 is associated with all images of the digit 0.