In [2]:


import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
import numpy as np
import time
import copy
from torch.utils.data import TensorDataset
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import rotate as scipyrotate
from torchvision import datasets, transforms
import random
import matplotlib.pyplot as plt
import time
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
from sklearn import linear_model, model_selection
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
from sklearn import linear_model, model_selection
import torchvision.models as models
from sklearn.cluster import KMeans

# Define your network architecture (MLP)
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size*4)
        self.fc2 = nn.Linear(hidden_size*4, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)


    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten input
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        y = self.fc3(x)
        return y


    def feature(self,x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x




def simple_mia(sample_loss, members, n_splits=10, random_state=0):
    unique_members = np.unique(members)
    if not np.all(unique_members == np.array([0, 1])):
        raise ValueError("members should only have 0 and 1s")

    attack_model = linear_model.LogisticRegression()
    cv = model_selection.StratifiedShuffleSplit(
        n_splits=n_splits, random_state=random_state
    )
    return model_selection.cross_val_score(
        attack_model, sample_loss, members, cv=cv, scoring="accuracy"
    )


def testing_losses(model, distill_loader, device):
    model.to(device)
    model.eval()

    criterion = nn.CrossEntropyLoss(reduction='none')

    losses = []

    with torch.no_grad():
        for inputs, labels in distill_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            losses.append(loss.detach().cpu().numpy())

    losses = np.concatenate(losses)
    return losses



def measure_mia(model, forget_loader, test_loader):
    forget_losses=testing_losses(model, forget_loader, device)
    test_losses=testing_losses(model, test_loader, device)

    np.random.shuffle(forget_losses)
    stack_size=min([len(forget_losses), len(test_losses)])
    forget_losses = forget_losses[: stack_size]
    test_losses = test_losses[: stack_size]

    samples_mia = np.concatenate((test_losses, forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(test_losses) + [1] * len(forget_losses)

    mia_cands=[]
    for i in range(20):
        mia_scores = simple_mia(samples_mia, labels_mia)
        mia_cands.append(mia_scores.mean())

    mia_score=np.min(mia_cands)

    return mia_score

def test(model, data_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

def get_time():
    return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))

batch_size = 256
num_classes = 10
batch_real = 256
ipc = 10   # according to authors, recommended outer_loop, inner_loop = 10, 50
channel = 3
im_size = (32, 32)
hidden_size=128

mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

dst_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dst_train, batch_size=128*4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=128*4,
                                         shuffle=False, num_workers=2)



device = "cuda:0" if torch.cuda.is_available() else "cpu"
im_size = (32, 32)
channel = 3

#------------------Train the Net--------------------------------
net= MLP(input_size=channel * im_size[0] * im_size[1], hidden_size=hidden_size, output_size=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(), lr=1e-3)
print(get_time(), 'Start training Original network')
# Teacher training before starting the dataset distillation process
for epochy in range(20):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output = net(data) 
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


print(get_time(), 'Finished training Original network')
del loss
del output
net.eval()

torch.save(net.state_dict(), 'pretrained_net.pth')


Files already downloaded and verified
Files already downloaded and verified
[2023-10-25 16:32:57] Start training Original network
[2023-10-25 16:34:16] Finished training Original network


In [3]:
def extract_features(model, dataloader, device):
    features = []
    labels = []
    
    model.eval()
    with torch.no_grad():
        for data, label in dataloader:
            data = data.to(device)
            # feature = model(data)
            feature = model.feature(data)
            feature = feature.view(feature.size(0), -1)  # Flatten spatial dimensions
            features.append(feature.cpu())
            labels.append(label)
    
    return torch.cat(features, 0), torch.cat(labels, 0)


def create_sub_classes(tensor, labels, model, num_classes=10, sub_divisions=10):
    new_labels = torch.zeros_like(labels)
    original_labels_dict = {}
    
    # Load the pretrained model for feature extraction
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Create a DataLoader to facilitate feature extraction
    dataset = TensorDataset(tensor, labels)
    loader = DataLoader(dataset, batch_size=256, shuffle=False)
    
    # Extract features
    features, _ = extract_features(model, loader, device)
    
    for i in range(num_classes):
        mask = labels == i
        class_features = features[mask]
        
        # Apply k-means clustering
        kmeans = KMeans(n_clusters=sub_divisions).fit(class_features)
        class_new_labels = torch.tensor(kmeans.labels_, dtype=torch.long)
        
        # Assign new labels
        new_subclass_labels = i * sub_divisions + class_new_labels
        new_labels[mask] = new_subclass_labels

        # Store original label reference
        for j in range(sub_divisions):
            original_labels_dict[int(i * sub_divisions + j)] = i
    
    return new_labels, original_labels_dict

In [7]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.utils import make_grid



net= MLP(input_size=channel * im_size[0] * im_size[1], hidden_size=128, output_size=num_classes)
net.load_state_dict(torch.load('pretrained_net.pth'))
net.eval()
net.to(device)


#--------Hyperparameters-----------------------------------------------
condense_iterations = 10
num_classes = 10
batch_real = 5000
ipc = 10
net.to(device)
lambdy_disentang = 0.0
final_model_epochs = 20
databank_model_epochs = 20
lr_final=1e-4
lr_databank=1e-5
split_ratio = 0.1   # forget-retain split ratio
n_classes=10
n_subclasses=100
#----------------------------------------------------------------------


class Beginning(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Beginning, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size*4)

    def forward(self,x):
        x=x.view(x.size(0), -1)
        x=F.relu(self.fc1(x))
        return x

class Intermediate(nn.Module):
    def __init__(self,hidden_size):
        super(Intermediate, self).__init__()
        self.fc2 = nn.Linear(hidden_size*4, hidden_size)

    def forward(self,x):
        x=F.relu(self.fc2(x))
        return x


class Final(nn.Module):
    def __init__(self,hidden_size, num_classes):
        super(Final, self).__init__()
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self,x):
        x=F.relu(self.fc3(x))
        return x


class OmegaFinal(nn.Module):
    def __init__(self,hidden_size, num_classes):
        super(OmegaFinal, self).__init__()
        self.fc31=nn.Sequential(nn.Linear(hidden_size, hidden_size*2), nn.ReLU())
        self.fc32=nn.Sequential(nn.Linear(hidden_size*2, hidden_size), nn.ReLU())
        self.fc33=nn.Sequential(nn.Linear(hidden_size, num_classes), nn.ReLU())

    def forward(self,x):
        x=self.fc31(x)
        x=self.fc32(x)
        x=self.fc33(x)
        return x


class Databank(nn.Module):
    def __init__(self,beggining,intermediate):
        super(Databank, self).__init__()
        self.beggining=beggining
        self.intermediate=intermediate

    def forward(self,x):
        x=self.beggining(x)
        x=self.intermediate(x)
        return x

    def hidden(self,x):
        x=self.beggining(x)
        return x


class CombinedModel(nn.Module):
    def __init__(self, databank, final):
        super(CombinedModel, self).__init__()
        self.databank=databank
        self.final=final

    def forward(self, x):
        x=self.databank(x)
        x=self.final(x)
        return x


def rho_loss(rho, data_rho, size_average=True):
    dkl = - rho * torch.log(data_rho) - (1-rho) * torch.log(1-data_rho)
    if size_average:
        return dkl.mean()
    else:
        return dkl.sum()


# ref: https://openreview.net/pdf?id=ByzvHagA-

# Covariance Regularization Loss
def covariance_regularizer(H):
    # H: [N, p] tensor containing activations for N examples and p features
    N, p = H.size()
    H_mean = torch.mean(H, dim=0)  # Mean activation [p]
    H_centered = H - H_mean  # Subtract the mean [N, p]
    cov_matrix = 1 / (N - 1) * torch.matmul(H_centered.t(), H_centered)  # [p, p]
    L1_norm_cov = torch.sum(torch.abs(cov_matrix))  # Sum of absolute values of covariance matrix
    L_sigma = L1_norm_cov / (p ** 2)  # Normalize by p^2
    return L_sigma




class TensorDataset(Dataset):
    def __init__(self, images, labels): # images: n x c x h x w tensor
        self.images = images.detach().float()
        self.labels = labels.detach()

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return self.images.shape[0]
    


''' organize the real dataset '''
images_all = []
labels_all = []
indices_class = [[] for c in range(num_classes)]

images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
labels_all = [dst_train[i][1] for i in range(len(dst_train))]
for i, lab in enumerate(labels_all):
    indices_class[lab].append(i)
images_all = torch.cat(images_all, dim=0).to(device)
labels_all = torch.tensor(labels_all, dtype=torch.long, device=device)

def get_images(c, n): # get random n images from class c
    idx_shuffle = np.random.permutation(indices_class[c])[:n]
    return images_all[idx_shuffle]

IMG_real=[]
LAB_real=[]
for c in range(num_classes):
    IMG_real.append(get_images(c, batch_real))
    LAB_real.append(torch.ones(batch_real, dtype=torch.long, device=device)*c)

IMG_real=torch.cat(IMG_real, dim=0)
LAB_real=torch.cat(LAB_real, dim=0)

print('Total real images: %d'%len(IMG_real))
print('Real Images per class: %d'%batch_real)

''' initialize the synthetic data from random noise '''
image_syn = torch.randn(size=(num_classes*ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=device)
label_syn = torch.tensor([np.ones(ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]

''' copy the real data to synthetic data for initialization'''
for c in range(num_classes):
    image_syn.data[c*ipc:(c+1)*ipc] = get_images(c, ipc).detach().data

print('Total synthetic images: %d'%len(image_syn))
print('Synthetic Images per class: %d'%ipc)

print("Compression Rate: %.2f"%(len(IMG_real)/len(image_syn)))

''' training '''
print("\n")
print('%s Condensation protocol begins'%get_time())
print("\n")


beggining=Beginning(input_size=channel * im_size[0] * im_size[1], hidden_size=128).to(device)
beggining.fc1.weight.data = net.fc1.weight.data
beggining.fc1.bias.data = net.fc1.bias.data

intermediate=Intermediate(hidden_size=128).to(device)
intermediate.fc2.weight.data = net.fc2.weight.data
intermediate.fc2.bias.data = net.fc2.bias.data

databank=Databank(beggining=beggining, intermediate=intermediate).to(device)
img_real_list=[]
lab_real_list=[]
img_real_sampled_list=[]
lab_real_sampled_list=[]


for c in range(num_classes):
    img_real = IMG_real[c * batch_real: (c + 1) * batch_real].clone().detach()
    lab_real = LAB_real[c * batch_real: (c + 1) * batch_real].clone().detach()
    img_real_list.append(img_real)
    lab_real_list.append(lab_real)
    
    rand_idices=torch.randperm(img_real.shape[0])[:ipc]
    sampled_img_real = img_real[rand_idices]
    img_real_sampled_list.append(sampled_img_real)
    lab_real_sampled_list.append(torch.ones(sampled_img_real.shape[0], dtype=torch.long, device=device)*c)


img_real_data=torch.cat(img_real_list, dim=0)
lab_real_data=torch.cat(lab_real_list, dim=0)
img_real_sampled_data=torch.cat(img_real_sampled_list, dim=0)
lab_real_sampled_data=torch.cat(lab_real_sampled_list, dim=0)


img_real_data_dataset=TensorDataset(img_real_data.clone().detach().cpu(), lab_real_data.clone().detach().cpu())


img_syn_dataset=TensorDataset(img_real_sampled_data.clone().detach().cpu(), lab_real_sampled_data.clone().detach().cpu())
img_syn_loader=torch.utils.data.DataLoader(img_syn_dataset, batch_size=32, shuffle=True)


# Assuming img_real_data_dataset is predefined
dataset_size = len(img_real_data_dataset)

train_images=img_real_data_dataset.images
train_labels=img_real_data_dataset.labels


new_lab_train, original_labels_dict = create_sub_classes(train_images, train_labels, model=net, num_classes=n_classes, sub_divisions=n_subclasses)


indices = list(range(dataset_size))


#--------------If I want to forget a class then better comment this===================
# shuffling the indices
torch.manual_seed(42)  # for reproducibility
indices = torch.randperm(dataset_size)
#--------------If I want to forget a class then better comment this===================


new_lab_train= new_lab_train[indices]
train_images = train_images[indices]
train_labels = train_labels[indices]



bucket_dataset_train=TensorDataset(train_images, new_lab_train)
bucket_train_loader=DataLoader(bucket_dataset_train, batch_size=128, shuffle=True)

img_real_data_dataset=TensorDataset(train_images, train_labels)
img_real_data_loader=torch.utils.data.DataLoader(img_real_data_dataset, batch_size=64*4, shuffle=True)

# Define split ratio and sizes
split = int(split_ratio * dataset_size)

# Split indices into two parts
forget_indices = indices[:split]
retain_indices = indices[split:]


forget_images=train_images[forget_indices]
forget_labels=train_labels[forget_indices]

retain_images=train_images[retain_indices]
retain_labels=train_labels[retain_indices]


forget_set_real=TensorDataset(forget_images, forget_labels)
retain_set_real=TensorDataset(retain_images, retain_labels)

# Now you can create your dataloaders as before
forget_loader = torch.utils.data.DataLoader(forget_set_real, batch_size=64*2, shuffle=True)
retain_loader = torch.utils.data.DataLoader(retain_set_real, batch_size=128, shuffle=True)




for it in range(condense_iterations):

    for param in list(databank.parameters()):
        param.requires_grad = False

    final=Final(hidden_size=hidden_size, num_classes=num_classes).to(device)
    final.fc3.weight.data = net.fc3.weight.data
    final.fc3.bias.data = net.fc3.bias.data

    # final=OmegaFinal(hidden_size=hidden_size, num_classes=num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer_final=torch.optim.Adam(final.parameters(), lr=lr_final)
    

    # training of final part
    for ep in range(final_model_epochs):
        run_loss=0.0
        for batch in img_syn_loader:
            img_syn_buffer, label_img_syn_buffer= batch
            img_syn_buffer=img_syn_buffer.to(device)
            label_img_syn_buffer=label_img_syn_buffer.to(device)
            decoded_img_syn_buffer=databank(img_syn_buffer)
            output=final(decoded_img_syn_buffer)
            loss=criterion(output, label_img_syn_buffer)
            optimizer_final.zero_grad()
            loss.backward()
            optimizer_final.step()
            run_loss+=loss.item()



    del optimizer_final
    del loss

    # make final non-trainable
    for param in list(final.parameters()):
        param.requires_grad = False

    # make databank's beggining non-trainable
    for param in list(databank.beggining.parameters()):
        param.requires_grad = False

    # make databank's intermediate trainable
    for param in list(databank.intermediate.parameters()):
        param.requires_grad = True


    optimizer_databank=torch.optim.Adam(databank.parameters(), lr=lr_databank)

    # training the databank's intermediate part
    for ep in range(databank_model_epochs):
        run_loss=0.0
        for batch in img_real_data_loader:
            img_real_buffer, label_img_real_buffer= batch
            img_real_buffer=img_real_buffer.to(device)
            label_img_real_buffer=label_img_real_buffer.to(device)
            approx_img_real_buffer=databank(img_real_buffer)
            hidden = databank.hidden(img_real_buffer)
            output=final(approx_img_real_buffer)
            loss=criterion(output, label_img_real_buffer)
            L_sigma = covariance_regularizer(hidden)   # reduce entanglement
            total_loss = loss + lambdy_disentang*L_sigma
            optimizer_databank.zero_grad()
            total_loss.backward()
            optimizer_databank.step()
            run_loss+=loss.item()

        if ep==0:
            print("Loss associated with databank (it's intermediate): %.3f"%(run_loss/len(img_real_data_loader)))

        if ep==databank_model_epochs-1:
            combined_model=CombinedModel(databank=databank, final=final).to(device)
            with torch.no_grad():
                combined_retrain_acc=test(combined_model, retain_loader, device)
                combined_forget_acc=test(combined_model, forget_loader, device)

                print("Combined model's accuracy on retain set: %.2f"%combined_retrain_acc)
                print("Combined model's accuracy on forget set: %.2f"%combined_forget_acc)
                


    # finally make databank's intermediate non-trainable
    for param in list(databank.intermediate.parameters()):
        param.requires_grad = False



print("\n")
print('%s Condensation Protocol ends'%get_time())
print("\n")


print("\nSaving the databank")
torch.save(databank.state_dict(), 'databank.pth')


print("\n Saving the final model")
torch.save(final.state_dict(), 'final.pth')

Total real images: 50000
Real Images per class: 5000
Total synthetic images: 100
Synthetic Images per class: 10
Compression Rate: 500.00


[2023-10-25 16:49:33] Condensation protocol begins


Loss associated with databank (it's intermediate): 0.308
Combined model's accuracy on retain set: 95.33
Combined model's accuracy on forget set: 95.32
Loss associated with databank (it's intermediate): 0.195
Combined model's accuracy on retain set: 96.09
Combined model's accuracy on forget set: 96.20
Loss associated with databank (it's intermediate): 0.177
Combined model's accuracy on retain set: 96.51
Combined model's accuracy on forget set: 96.78
Loss associated with databank (it's intermediate): 0.167
Combined model's accuracy on retain set: 96.72
Combined model's accuracy on forget set: 97.06
Loss associated with databank (it's intermediate): 0.159
Combined model's accuracy on retain set: 96.92
Combined model's accuracy on forget set: 97.32
Loss associated with databank (it's intermediate): 0.

In [8]:
# Defining a new module for weighted average
class WeightedAverage(nn.Module):
    def __init__(self, num_batches):
        super(WeightedAverage, self).__init__()
        self.weights = nn.Parameter(1/num_batches*torch.ones(num_batches, device=device))
        # self.fc1 = nn.Linear(128, 128)

    def forward(self, imgs):
        imgs = imgs.view(imgs.shape[0], -1)
        weighted_imgs = imgs * self.weights.view(-1, 1)
        weighted_imgs = torch.sum(weighted_imgs, dim=0, keepdim=True)
        # pro_weighted_avg = F.relu(self.fc1(weighted_avg))
        weighted_imgs = weighted_imgs.reshape(1, 3, 32, 32)
        return weighted_imgs




def Average(ref_imgs_all, pretrained=net, num_epochs=100):

    ref_imgs_all=ref_imgs_all.to(device)
    
    weighted_avg_module = WeightedAverage(num_batches=ref_imgs_all.shape[0]).to(device)
    optim_weighted_avg = torch.optim.Adam(weighted_avg_module.parameters(), lr=1e-3)

    ref_features= pretrained.feature(ref_imgs_all).detach()

    for ep in range(num_epochs):
        fused_img= weighted_avg_module(ref_imgs_all)
        fused_img_features= pretrained.feature(fused_img)
        loss=torch.sum((torch.mean(ref_features, dim=0) - torch.mean(fused_img_features, dim=0))**2)
        optim_weighted_avg.zero_grad()
        loss.backward()
        optim_weighted_avg.step()



    averaged_img=weighted_avg_module(ref_imgs_all).detach()

    return averaged_img



img_shape=(1,3,32,32)  # Shape of the image
inverted_IMG=[]
inverted_LABEL=[]
indices_train_wrt_bucket=[]

bucket_labbies=torch.unique(new_lab_train).tolist()

for idx in bucket_labbies:

    indices_idx = torch.where(new_lab_train.to(device)==idx)[0]

    indices_train_wrt_bucket.append(indices_idx.cpu())


    ref_imgs_all = train_images[indices_idx.cpu()]
    ref_labs_all= train_labels[indices_idx.cpu()]

    inverted_image = Average(ref_imgs_all, pretrained=net, num_epochs=100)

    inverted_IMG.append(inverted_image)
    inverted_LABEL.append(ref_labs_all[0])
 

    # print percentage of idx covered on same line
    print('\r','Condensation Progress: ', (idx+1)*100/(n_classes*n_subclasses), '%', end='')

inverted_IMG=torch.cat(inverted_IMG, dim=0).cpu()
inverted_LABEL=torch.tensor(inverted_LABEL).cpu()
 


condensed_loader=torch.utils.data.DataLoader(TensorDataset(inverted_IMG, inverted_LABEL), batch_size=128, shuffle=True)


indices_collector=[]
not_safe_zones=[]    # foret has been found here so corresponding condensed is of no use (and i have to residue it)


# in bucketting system, find the retain indices where forget is found
for i in range(len(forget_indices)):
    for j in range(len(indices_train_wrt_bucket)):

        if forget_indices[i] in indices_train_wrt_bucket[j]:

            not_safe_zones.append(j)

            #find the indexes of the indices_train_wrt_bucket[j] that are not equal to forget_labels[i]
            false_indices = [idx for idx, val in enumerate(indices_train_wrt_bucket[j]) if val != forget_indices[i]]

            # only select those false_indices that lie in retain_indices
            false_indices_subset = [x for x in false_indices if indices_train_wrt_bucket[j][x] in retain_indices]


            if len(false_indices_subset)!=0:
                indices_collector.append(indices_train_wrt_bucket[j][false_indices_subset].tolist())

            break

    print('\r','Residual Collection Progress: ', i*100/(len(forget_images)), '%', end='')


print("\n\nSize of each bucket: ", int(len(img_real_data_dataset)/len(bucket_labbies)))

not_safe_zones=torch.tensor(not_safe_zones)
not_safe_zones=torch.unique(not_safe_zones).tolist()
print("\n\nFaulty Buckets: ", len(not_safe_zones), '/', len(bucket_labbies))

# convert indices_collector to a flat list 
possible_retain_sols = [item for sublist in indices_collector for item in sublist]

retain_sols=torch.unique(torch.tensor(possible_retain_sols)).tolist()


residual_retain_imgs=train_images[retain_sols]
residual_retain_labels=train_labels[retain_sols]


print("\n\nResidual Retain Images (in bucketting system where retain was found alongside with forget): ", len(residual_retain_imgs))



total_retain_imgs=[]
total_retain_labs=[]

total_retain_imgs.append(residual_retain_imgs)
total_retain_labs.append(residual_retain_labels)


# safe zone is indices in indices_train_wrt_bucket that are not in not_safe_zones
safe_zone=[]
for i in range(len(indices_train_wrt_bucket)):
    if i not in not_safe_zones:
        safe_zone.append(i)


if len(safe_zone)!=0:
    safe_zone=torch.tensor(safe_zone)
    safe_zone=torch.unique(safe_zone)

    condensed_retain_imgs=inverted_IMG[safe_zone]
    condensed_retain_labels=inverted_LABEL[safe_zone]

    total_retain_imgs.append(condensed_retain_imgs)
    total_retain_labs.append(condensed_retain_labels)

    print("Size of usable condensed images: ", len(condensed_retain_imgs), '/', len(inverted_IMG), 'buckets')


total_retain_imgs=torch.cat(total_retain_imgs, dim=0)
total_retain_labs=torch.cat(total_retain_labs, dim=0)

print("---------------------------------------------------")
print(">> Total size of Reduced Retain Set: ", len(total_retain_imgs))
print(">> Reference size of naive retain loader:", len(retain_loader.dataset))
print(">> Retain Compression Ratio (>=1): %.2f"%(len(retain_loader.dataset)/len(total_retain_imgs)))
print("---------------------------------------------------")

reduced_retain_loader=torch.utils.data.DataLoader(TensorDataset(total_retain_imgs, total_retain_labs), batch_size=128, shuffle=True)


 Residual Collection Progress:  99.98 %

Size of each bucket:  50


Faulty Buckets:  946 / 1000


Residual Retain Images (in bucketting system where retain was found alongside with forget):  44181
Size of usable condensed images:  54 / 1000 buckets
---------------------------------------------------
>> Total size of Reduced Retain Set:  44235
>> Reference size of naive retain loader: 45000
>> Retain Compression Ratio (>=1): 1.02
---------------------------------------------------


In [9]:
rtrryt

NameError: name 'rtrryt' is not defined

In [None]:

class CombinedModel(nn.Module):
    def __init__(self, databank, final):
        super(CombinedModel, self).__init__()
        self.databank = databank
        self.final = final

    def forward(self, x):
        x = self.databank(x)
        x = self.final(x)
        return x


overture_epochs=10
beggining_epochs=2
final_epochs=50
intermediate_epochs= 5
final_thr=2   # intended for blocking the final training in overture, from the end of overture epochs--> improves retain acc while preserving forget accuracy



main_ep_thr=2
second_ep_thr=2

beggining=Beginning(input_size=channel * im_size[0] * im_size[1], hidden_size=128).to(device)
intermediate=Intermediate(hidden_size=128).to(device)
final=Final(hidden_size=hidden_size, num_classes=num_classes).to(device)

data_bank=Databank(beggining=beggining, intermediate=intermediate).to(device)
data_bank.load_state_dict(torch.load('databank.pth'))

final.load_state_dict(torch.load('final.pth'))
combined_model=CombinedModel(databank=data_bank, final=final).to(device)

retain_acc=test(combined_model, retain_loader, device)
forget_acc=test(combined_model, forget_loader, device)
mia_score=measure_mia(combined_model, forget_loader, test_loader)
print("Pre Retain Accuracy: %.2f %%"%(retain_acc))
print("Pre Forget Accuracy: %.2f %%"%(forget_acc))
print("Pre MIA Score: %.2f %%"%(mia_score))





optim_model=torch.optim.Adam(combined_model.parameters(), lr=1e-3)
#--------------------------------------------------------




#=============================Adding Noise =============================
lambd=0.1


# Calculate the Fisher Information Matrix for each parameter in base_model
fisher_information = {}
for name, param in combined_model.named_parameters():
    fisher_information[name] = torch.zeros_like(param).to(device)

# Assume we use a single datapoint to calculate the Fisher Information
# Usually, you would use a dataset or a subset
for i, (inputs, labels) in enumerate(forget_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    combined_model.zero_grad()
    outputs = combined_model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()

    for name, param in combined_model.named_parameters():
        fisher_information[name] += param.grad ** 2 / len(forget_loader)


# Save the optimal parameters for distill_loader
optimal_params = {}
for name, param in combined_model.named_parameters():
    optimal_params[name] = param.clone()


# Normalize and binarize the Fisher Information
threshold = 0.5  # Choose an appropriate threshold

for name, param in fisher_information.items():
    # Normalizing by dividing each entry by the maximum value
    param /= torch.max(param)

    # Binarizing by applying a threshold
    param[param < threshold] = 0
    param[param >= threshold] = 1


for name, param in combined_model.named_parameters():
    noise = torch.randn_like(param)
    noise *= lambd*fisher_information[name]
    param.data += noise

#=========================================================================


print(get_time(), 'Start training the combined model')



for main_ep in range(overture_epochs):

    for param in list(combined_model.databank.beggining.parameters()):
        param.requires_grad = True

    for param in list(combined_model.databank.intermediate.parameters()):
        param.requires_grad = False

    for param in list(combined_model.final.parameters()):
        param.requires_grad = False


    for _ in range(beggining_epochs):
        for batch in reduced_retain_loader:
            img,lab=batch
            img,lab=img.to(device), lab.to(device)
            output=combined_model(img)
            loss=criterion(output, lab)

            optim_model.zero_grad()
            loss.backward()

            optim_model.step()



    for param in list(combined_model.parameters()):
        param.requires_grad = False

    for param in list(combined_model.final.parameters()):
        param.requires_grad = True


    if main_ep<overture_epochs-final_thr:

        for epi in range(final_epochs):
            distill_loss=0.0
            for batch in img_syn_loader:
                img,lab=batch
                img,lab=img.to(device), lab.to(device)
                output=combined_model(img)
                loss=criterion(output, lab)
                distill_loss+=loss
            distill_loss/=len(img_syn_loader)

            lhs_loss=distill_loss

            optim_model.zero_grad()
            lhs_loss.backward()
            optim_model.step()

    for param in list(combined_model.parameters()):
        param.requires_grad = False


#---------------------just training the intermediate--------------------------


for param in list(combined_model.databank.intermediate.parameters()):
    param.requires_grad = True


for second_ep in range(intermediate_epochs):
    for batch in reduced_retain_loader:
        img,lab=batch
        img,lab=img.to(device), lab.to(device)
        output=combined_model(img)
        loss=criterion(output, lab)


        optim_model.zero_grad()
        loss.backward()
        optim_model.step()



retain_acc=test(combined_model, retain_loader, device)
forget_acc=test(combined_model, forget_loader, device)
mia_score=measure_mia(combined_model, forget_loader, test_loader)
print("Projected Retain Accuracy: %.2f %%"%(retain_acc))
print("Projected Forget Accuracy: %.2f %%"%(forget_acc))
print("Projected MIA Score: %.2f %%"%(mia_score))
print(get_time(), 'Ending training the combined model')

: 

In [None]:
pretrained_net=MLP(input_size=channel * im_size[0] * im_size[1], hidden_size=128, output_size=num_classes)
pretrained_net.load_state_dict(torch.load('pretrained_net.pth'))
pretrained_net.to(device)

optim_pretrained=torch.optim.Adam(pretrained_net.parameters(), lr=1e-3)
print(get_time(), 'Start training the pretrained model')
retraining_epochs=30
for _ in range(retraining_epochs):
    for batch in retain_loader:
        img,lab=batch
        img,lab=img.to(device), lab.to(device)
        output=pretrained_net(img)
        loss=criterion(output, lab)
        optim_pretrained.zero_grad()
        loss.backward()
        optim_pretrained.step()

retain_acc=test(pretrained_net, retain_loader, device)
forget_acc=test(pretrained_net, forget_loader, device)
mia_score=measure_mia(pretrained_net, forget_loader, test_loader)
print("Pretrain Retraining Retain Accuracy: %.2f %%"%(retain_acc))
print("Pretrained Retraining Forget Accuracy: %.2f %%"%(forget_acc))
print("MIA Score: %.2f %%"%(mia_score))

print(get_time(), 'Ending training the pretrained model')

: 