# In [None]:
import numpy as np, pandas as pd, random, time, matplotlib.pyplot as plt, seaborn as sns
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score

# Reproducibility: seed every RNG used below and pin cuDNN to deterministic kernels.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Using device:", device)

# Load MNIST from CSVs: column 0 is the digit label, the remaining columns are
# the 28x28 pixel values, scaled below by 1/255.
train_df = pd.read_csv("/content/mnist_train.csv").dropna()
test_df  = pd.read_csv("/content/mnist_test.csv").dropna()
if train_df.shape[1] != 1 + 28 * 28 or test_df.shape[1] != 1 + 28 * 28:
    raise ValueError(
        f"expected {1 + 28 * 28} columns (label + 28*28 pixels), got "
        f"{train_df.shape[1]} (train) and {test_df.shape[1]} (test)"
    )

# Convert straight to float32. Dividing in float32 yields the same values as the
# float64 route followed by a cast, without a full-size float64 intermediate.
X_train = train_df.iloc[:, 1:].to_numpy(dtype=np.float32) / 255.0
y_train = train_df.iloc[:, 0].to_numpy(dtype=np.int64)
X_test  = test_df.iloc[:, 1:].to_numpy(dtype=np.float32) / 255.0
y_test  = test_df.iloc[:, 0].to_numpy(dtype=np.int64)

# NCHW layout with a single grey channel, as expected by CNN.conv1.
X_train = X_train.reshape(-1, 1, 28, 28)
X_test  = X_test.reshape(-1, 1, 28, 28)
# torch.from_numpy shares memory with the NumPy arrays instead of copying them;
# dtypes are already float32 / int64 (torch.long).
X_train_t, y_train_t = torch.from_numpy(X_train), torch.from_numpy(y_train)
X_test_t,  y_test_t  = torch.from_numpy(X_test), torch.from_numpy(y_test)

# PHASE 2: CNN Model
class CNN(nn.Module):
    """Small two-conv MNIST classifier mapping (N, 1, 28, 28) images to (N, 10) logits.

    Shape trace: conv1 (3x3, stride 1) 28->26, conv2 (3x3, stride 1) 26->24,
    2x2 max-pool 24->12, so the flattened feature size is 32 * 12 * 12.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(32 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return unnormalised class logits for a batch of images."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten per sample. Unlike view(-1, 4608), this never folds extra
        # elements into the batch dimension when the input is not 28x28;
        # a bad shape now fails loudly in fc1.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

# Utility Functions
def train_model(model, loader, epochs=3, lr=1e-3, device=None):
    """Train ``model`` in place with Adam and cross-entropy.

    Args:
        model: classifier returning logits; put into train mode here.
        loader: iterable of ``(inputs, integer_labels)`` batches.
        epochs: number of passes over ``loader``.
        lr: Adam learning rate.
        device: where batches are moved; defaults to the device of the
            model's first parameter so inputs always match the model.

    Returns:
        Mean training loss over the last epoch, or ``None`` if no batch was seen.
    """
    if device is None:
        device = next(model.parameters()).device
    opt = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    last_epoch_loss = None
    for _ in range(epochs):
        running, n_batches = 0.0, 0
        for Xb, yb in loader:
            Xb, yb = Xb.to(device), yb.to(device)
            opt.zero_grad()
            loss = criterion(model(Xb), yb)
            loss.backward()
            opt.step()
            # Accumulate on-device; calling .item() per batch would force a
            # host sync every step and skew the timings measured by callers.
            running = running + loss.detach()
            n_batches += 1
        if n_batches:
            last_epoch_loss = float(running) / n_batches
    return last_epoch_loss

def aggregate(models, X, batch_size=1024):
    """Predict class labels by averaging the raw logits of an ensemble.

    Args:
        models: non-empty iterable of classifiers returning (batch, n_classes) logits.
        X: input tensor whose first dimension indexes samples.
        batch_size: samples pushed through each model per forward pass.

    Returns:
        ``np.ndarray`` of int64 predicted class indices, one per row of ``X``.

    Raises:
        ValueError: if ``models`` is empty.
    """
    models = list(models)
    if not models:
        raise ValueError("aggregate() needs at least one model")
    if len(X) == 0:
        return np.empty(0, dtype=np.int64)
    # Inference must run with dropout disabled. Remember each model's mode so
    # the caller's train/eval state is restored afterwards.
    prior_modes = [m.training for m in models]
    devices = [next(m.parameters()).device for m in models]
    mean_logits = []
    try:
        for m in models:
            m.eval()
        with torch.no_grad():  # no autograd graph for pure inference
            for start in range(0, len(X), batch_size):
                xb = X[start:start + batch_size]
                stacked = torch.stack([m(xb.to(d)).cpu() for m, d in zip(models, devices)])
                mean_logits.append(stacked.mean(dim=0))
    finally:
        for m, mode in zip(models, prior_modes):
            m.train(mode)
    return torch.cat(mean_logits).argmax(dim=1).numpy()

def extract_features(model, X, batch_size=4096):
    """Embed ``X`` with ``model.conv1`` + ReLU, average-pooled to a 4x4 grid.

    Leaves ``model`` in eval mode. Inputs are processed in chunks of
    ``batch_size`` to bound device memory, instead of all at once.

    Returns:
        float32 ``np.ndarray`` of shape ``(len(X), conv1.out_channels * 16)``.
    """
    model.eval()
    dev = next(model.parameters()).device
    feature_dim = model.conv1.out_channels * 4 * 4
    chunks = []
    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            feat = F.relu(model.conv1(X[start:start + batch_size].to(dev)))
            feat = F.adaptive_avg_pool2d(feat, (4, 4))
            chunks.append(feat.flatten(1).cpu())
    if not chunks:  # empty input: keep the 2-D contract
        return np.empty((0, feature_dim), dtype=np.float32)
    return torch.cat(chunks).numpy()


# PHASE 3: Baseline SISA
def sisa_training(X, y, shards=5, epochs=2, lr=0.0008, batch_size=128):
    """Train one independent CNN per contiguous data shard (SISA baseline).

    The defaults (2 epochs, lr 8e-4) are lighter than adisa_training's
    (epochs + 1 = 4 teacher epochs, lr 1.2e-3), so accuracy comparisons
    between the two are not like-for-like.

    Args:
        X, y: full training inputs and labels (indexable along dim 0).
        shards: number of contiguous shards / models.
        epochs, lr: passed to train_model for each shard.
        batch_size: DataLoader batch size per shard.

    Returns:
        ``(models, elapsed_seconds)``.

    Raises:
        ValueError: if ``shards < 1``.
    """
    if shards < 1:
        raise ValueError(f"shards must be >= 1, got {shards}")
    models = []
    t0 = time.perf_counter()
    # tensor_split keeps every sample. When len(X) is not divisible by shards,
    # the first len(X) % shards shards get one extra row instead of the
    # remainder being dropped.
    for Xi, yi in zip(torch.tensor_split(X, shards), torch.tensor_split(y, shards)):
        if len(Xi) == 0:  # only possible when shards > len(X)
            continue
        loader = DataLoader(TensorDataset(Xi, yi), batch_size=batch_size, shuffle=True)
        model = CNN().to(device)
        train_model(model, loader, epochs, lr)
        models.append(model)
    return models, time.perf_counter() - t0

# Train the shard ensemble, then score its logit-averaged prediction on the test split.
print("\n Training SISA...")
sisa_models, sisa_train_time = sisa_training(X=X_train_t, y=y_train_t)
sisa_preds = aggregate(sisa_models, X=X_test_t)
sisa_acc = accuracy_score(y_true=y_test, y_pred=sisa_preds)
print(f"SISA → Accuracy {sisa_acc*100:.2f}% | Train Time {sisa_train_time:.2f}s")


# PHASE 4: ADISA (Improved Framework)
def _distill_student(teacher, loader, epochs, lr, lambda_kd, T):
    """Return a copy of ``teacher`` fine-tuned on ``loader`` with a hard-label + KD loss."""
    dev = next(teacher.parameters()).device
    # Soft targets must come from the deterministic teacher: with dropout active
    # they would be randomly perturbed on every batch.
    teacher.eval()
    student = CNN().to(dev)
    student.load_state_dict(teacher.state_dict())
    student.train()
    opt = optim.Adam(student.parameters(), lr=lr)
    for _ in range(epochs):
        for Xb, yb in loader:
            Xb, yb = Xb.to(dev), yb.to(dev)
            opt.zero_grad()
            s_out = student(Xb)
            with torch.no_grad():
                t_out = teacher(Xb)
            hard = F.cross_entropy(s_out, yb)
            soft = F.kl_div(
                F.log_softmax(s_out / T, dim=1),
                F.softmax(t_out / T, dim=1),
                reduction="batchmean",
            )
            # T**2 keeps the soft-target gradient scale comparable to the hard loss.
            loss = (1 - lambda_kd) * hard + lambda_kd * (T ** 2) * soft
            loss.backward()
            opt.step()
    student.eval()  # returned for inference
    return student


def adisa_training(X, y, n_clusters=5, epochs=3, lambda_kd=0.6, T=3.0, lr=0.0012,
                   *, min_cluster_size=500, kd_samples=8000, kd_epochs=2, batch_size=128):
    """Train ADISA: feature-space clustering, one teacher per cluster, then distilled students.

    Steps:
      1. Train a throw-away CNN for 1 epoch on X[:5000] and embed all of X with
         extract_features.
      2. KMeans the embeddings into ``n_clusters`` clusters.
      3. Train one teacher CNN per cluster for ``epochs + 1`` epochs. Clusters
         smaller than ``min_cluster_size`` are skipped (and reported), so
         ``teacher_models[i]`` is not necessarily cluster ``i``.
      4. Distill each teacher into a student initialised from its weights.

    NOTE(review): every student is distilled on the same X[:kd_samples] slice,
    which spans all clusters. Deleting a sample from that slice therefore
    affects every student, not just one cluster's model.

    Returns:
        ``(teacher_models, student_models, total_seconds, fitted KMeans,
        per-sample cluster labels)``. Teachers and students are left in eval mode.
    """
    total_start = time.perf_counter()

    # Step 1: cheap feature extractor
    enc = CNN().to(device)
    loader_small = DataLoader(TensorDataset(X[:5000], y[:5000]), batch_size=256, shuffle=True)
    train_model(enc, loader_small, epochs=1, lr=lr)
    feats = extract_features(enc, X)

    # Step 2: adaptive clustering in feature space
    print(" Clustering using feature space ...")
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10).fit(feats)
    clusters = kmeans.labels_

    # Step 3: one teacher per sufficiently large cluster
    teacher_models = []
    for c in range(n_clusters):
        idx = np.where(clusters == c)[0]
        if len(idx) < min_cluster_size:
            print(f" Skipping cluster {c}: only {len(idx)} samples (< {min_cluster_size})")
            continue
        dl = DataLoader(TensorDataset(X[idx], y[idx]), batch_size=batch_size, shuffle=True)
        m = CNN().to(device)
        train_model(m, dl, epochs + 1, lr)  # +1 epoch relative to the `epochs` argument
        teacher_models.append(m)

    # Step 4: knowledge distillation (the KD loader is shared by all students)
    kd_loader = DataLoader(TensorDataset(X[:kd_samples], y[:kd_samples]),
                           batch_size=batch_size, shuffle=True)
    student_models = [
        _distill_student(t_model, kd_loader, kd_epochs, lr, lambda_kd, T)
        for t_model in teacher_models
    ]

    total_train = time.perf_counter() - total_start
    return teacher_models, student_models, total_train, kmeans, clusters

# Train ADISA end to end; only the distilled students (s_models) are scored,
# while the teachers (t_models) are kept but not evaluated here.
print("\n Training ADISA...")
t_models, s_models, adisa_train_time, kmeans, clusters = adisa_training(X=X_train_t, y=y_train_t)
adisa_preds = aggregate(s_models, X=X_test_t)
adisa_acc = accuracy_score(y_true=y_test, y_pred=adisa_preds)
print(f"ADISA → Accuracy {adisa_acc*100:.2f}% | Train Time {adisa_train_time:.2f}s")


# PHASE 5: Simulate Unlearning (SISA vs ADISA)
print("\n Simulating Unlearning Request...")

# Both timers start before data slicing, DataLoader creation and model
# construction, so the two measurements cover the same steps.

# For SISA (retrain one shard). Shard boundaries mirror sisa_training's
# contiguous slices of len(X_train_t) // n_shards samples.
n_shards = len(sisa_models)
shard_size = len(X_train_t) // n_shards
idx_shard = 0
start = time.perf_counter()
X_del_sisa = X_train_t[idx_shard * shard_size:(idx_shard + 1) * shard_size]
y_del_sisa = y_train_t[idx_shard * shard_size:(idx_shard + 1) * shard_size]
dl = DataLoader(TensorDataset(X_del_sisa, y_del_sisa), batch_size=128, shuffle=True)
m = CNN().to(device)
train_model(m, dl, epochs=2, lr=0.0008)  # same schedule as sisa_training's defaults
sisa_unlearn_time = time.perf_counter() - start

# For ADISA (retrain one cluster's teacher). With its defaults, adisa_training
# trains each teacher for epochs + 1 = 4 epochs at lr 0.0012; mirror that here.
# NOTE(review): the re-distillation step adisa_training runs after teacher
# training is not included in this timing.
deleted_cluster = 0
start = time.perf_counter()
idx_del = np.where(clusters == deleted_cluster)[0]
X_del_adisa, y_del_adisa = X_train_t[idx_del], y_train_t[idx_del]
dl = DataLoader(TensorDataset(X_del_adisa, y_del_adisa), batch_size=128, shuffle=True)
m = CNN().to(device)
train_model(m, dl, epochs=4, lr=0.0012)
adisa_unlearn_time = time.perf_counter() - start

print(f"SISA Unlearning Time: {sisa_unlearn_time:.2f}s")
print(f"ADISA Unlearning Time: {adisa_unlearn_time:.2f}s")


# PHASE 6: Visualization
labels = ["SISA", "ADISA"]
ccs = [sisa_acc, adisa_acc]  # test-set accuracies as fractions
times = [sisa_unlearn_time, adisa_unlearn_time]

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.bar(labels, ccs, color=["skyblue", "lightgreen"])
# Zoom in on the 0.85-1.0 band, but lower the floor when an accuracy falls
# below it; otherwise that bar would be clipped out of view.
plt.ylim(max(0.0, min(0.85, min(ccs) - 0.02)), 1.0)
plt.ylabel("Accuracy")
plt.title("Accuracy Comparison (SISA vs ADISA)")

plt.subplot(1, 2, 2)
plt.bar(labels, times, color=["skyblue", "lightgreen"])
plt.ylabel("Unlearning Time (seconds)")
plt.title("Unlearning Time Comparison")

plt.tight_layout()  # keep the two panels' titles/labels from overlapping
plt.show()


# Final Summary
# One print call with newline separators emits the header and both result lines.
print(
    "\nFINAL RESULTS ",
    f"SISA → Accuracy {sisa_acc*100:.2f}%, Unlearning Time {sisa_unlearn_time:.2f}s",
    f"ADISA → Accuracy {adisa_acc*100:.2f}%, Unlearning Time {adisa_unlearn_time:.2f}s",
    sep="\n",
)