In [None]:
import torch
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
from torchmetrics import ConfusionMatrix
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss
from torch.autograd import Variable
from scipy import ndimage
import copy
import random
import time
import os
from collections import Counter
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
torch.set_printoptions(precision=3)

# Prefer the second GPU (the original setup) but fall back to the first GPU on
# single-GPU machines, and to the CPU when CUDA is unavailable; a hard-coded
# 'cuda:1' fails as soon as a tensor is moved there on a one-GPU box.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'

# Seed every RNG library the notebook imports so runs are reproducible.
SEED = 44
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

Models

In [None]:

class CustomResNet18(nn.Module):
    """torchvision ResNet-18 with a configurable input-channel count and output size.

    Only the first convolution and the final fully connected layer are replaced.
    The wrapped network lives under ``self.resnet``, so state_dict keys carry a
    ``resnet.`` prefix (checkpoints depend on this name).

    Args:
        in_channels: number of channels of the input images (1 for MNIST here).
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        backbone = models.resnet18()
        # Replacement stem: 7x7 kernel, stride 2, padding 3, no bias, reading
        # `in_channels` input channels.
        backbone.conv1 = nn.Conv2d(
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        # Replacement head producing `num_classes` logits from the pooled features.
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Linear(feature_dim, num_classes)
        self.resnet = backbone

    def forward(self, x):
        """Return the class logits for the batch `x`."""
        return self.resnet(x)
    

class MLPNet(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features (last dimension of the input).
        hidden_size: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Submodules are registered in the order fc1, relu, fc2 so parameter
        # initialisation draws from the RNG in the same sequence.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map the feature batch `x` to class logits."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

Helper Functions

In [None]:
def _expected_class_counts(n, expected_distribution):
    """Integer per-class quotas that sum to `n` and follow `expected_distribution`.

    Each quota starts at floor(n * p). The leftover samples are handed out one at
    a time, in index order, to classes with p > 0 only, so a class whose target
    probability is 0 never receives a quota.
    """
    expected_counts = np.floor(n * expected_distribution).astype(np.int64)
    eligible = np.flatnonzero(expected_distribution > 0)
    shortfall = n - int(expected_counts.sum())
    for k in range(max(shortfall, 0)):
        expected_counts[eligible[k % len(eligible)]] += 1
    return expected_counts


def get_hard_labels(inputs, expected_distribution):
    """Assign one hard label per sample so the label histogram matches a target distribution.

    Greedy assignment in repeated passes. In each pass, unlabelled samples are
    visited from highest to lowest current maximum score. Each sample takes its
    top-scoring class if that class still has quota left. Otherwise that class is
    masked out for the sample, and the sample is retried next pass with its
    next-best class. Passes repeat until every sample is labelled.

    Args:
        inputs: 2-D array (n_samples, n_classes) of per-class scores; the caller
            in this notebook passes softmax probabilities.
        expected_distribution: 1-D array with one target fraction per class,
            summing to 1. Classes with probability 0 receive no samples.

    Returns:
        np.ndarray of dtype int32 and shape (n_samples,) holding the class
        assigned to each sample.

    Raises:
        AssertionError: on inconsistent shapes, or if the distribution does not
            sum to 1 within floating-point tolerance.
    """
    assert inputs.ndim == 2
    assert expected_distribution.ndim == 1
    n, d = inputs.shape
    assert expected_distribution.shape[0] == d, "need one target probability per class"
    # Exact float equality is fragile (e.g. nine entries of 1/9); use a tolerance.
    assert np.isclose(np.sum(expected_distribution), 1.0)

    # A large constant that dominates every real score.
    CONSTANT = 10000.0

    # Append a sentinel column at index d. Unlabelled rows hold -CONSTANT there, so
    # it never wins the argmax. Labelled rows are overwritten with `standard_inputs`,
    # which is 0 everywhere except +CONSTANT in the sentinel column. Hence
    # argmax == d means "this sample already has a label".
    augmented_inputs = np.zeros((n, d + 1), dtype=np.float32)
    augmented_inputs[:, :-1] = inputs
    augmented_inputs[:, -1] = -CONSTANT
    standard_inputs = np.zeros(d + 1, dtype=np.float32)
    standard_inputs[-1] = CONSTANT

    expected_counts = _expected_class_counts(n, expected_distribution)

    hard_samples_by_class = [0] * d
    sample_to_hard_id = {}
    while sum(hard_samples_by_class) < n:
        max_inputs = np.max(augmented_inputs, axis=-1)
        max_inputs_class = np.argmax(augmented_inputs, axis=-1)
        # Most confident samples first.
        sorted_samples = np.flip(np.argsort(max_inputs))

        newly_sampled = []
        nonsampled = []
        for sample_id in sorted_samples:
            max_class = max_inputs_class[sample_id]
            if max_class == d:
                # Sentinel column wins: labelled in an earlier pass.
                continue
            if hard_samples_by_class[max_class] < expected_counts[max_class]:
                hard_samples_by_class[max_class] += 1
                sample_to_hard_id[sample_id] = max_class
                newly_sampled.append(sample_id)
            else:
                nonsampled.append(sample_id)

        # Apply masking once per pass. Decisions inside a pass only read the
        # max/argmax computed at its start, so this gives the same result as
        # updating after every sample, in O(n) instead of O(n^2).
        if newly_sampled:
            augmented_inputs[np.asarray(newly_sampled)] = standard_inputs
        if nonsampled:
            nonsampled_idx = np.asarray(nonsampled)
            augmented_inputs[nonsampled_idx, max_inputs_class[nonsampled_idx]] = -CONSTANT

    samples = np.ones(n, dtype=np.int32)
    for sample_id, hard_id in sample_to_hard_id.items():
        samples[sample_id] = hard_id
    return samples



def max_pred_distribution(model, test_loader, display=True, num_classes=10):
    """Row-normalised confusion matrix of the model's argmax predictions.

    Args:
        model: classifier returning one logit per class.
        test_loader: iterable of ``(data, target)`` batches.
        display: if True, also draw the confusion matrix on a new 8x8 figure.
        num_classes: number of classes; the matrix uses labels
            ``0 .. num_classes - 1`` (default 10).

    Returns:
        ``(cm, predictions)``. ``cm`` is built with ``normalize='true'``, so row i is
        the distribution of predicted labels for samples whose true label is i.
        ``predictions`` is a 1-D NumPy array of predicted labels in loader order.
    """
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.to(device))
            predictions.append(output.argmax(dim=1).cpu().numpy())
            targets.append(target.cpu().numpy())

    predictions = np.concatenate(predictions, axis=0)
    targets = np.concatenate(targets, axis=0)

    # Always compute `cm` so the return value exists when display=False too.
    cm = confusion_matrix(targets, predictions, labels=list(range(num_classes)), normalize='true')
    if display:
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(1, 1, 1)
        ConfusionMatrixDisplay(cm).plot(ax=ax)

    return cm, predictions



## Test and training functions
def test(model, loader, criterion, dname="Test set", printable=True):
    """Evaluate `model` on `loader`, returning accuracy and the summed batch loss.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of ``(data, target)`` batches with integer class targets.
        criterion: loss called as ``criterion(output, target)``.
        dname: name shown in the printed summary.
        printable: if True, print a one-line summary.

    Returns:
        ``(accuracy, test_loss)``. ``accuracy`` is a 0-dim tensor (correct / total)
        on ``device``; callers rely on it being a tensor. ``test_loss`` is the sum
        of per-batch ``criterion(...).item()`` values, not divided by the batch count.

    Raises:
        ValueError: if `loader` yields no samples (e.g. a drop_last loader over
            fewer samples than one batch).
    """
    model.eval()
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            total += target.size(0)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += (pred == target.view_as(pred)).sum()
    if total == 0:
        raise ValueError("{}: loader yielded no samples".format(dname))
    if printable:
        # int(...) prints the count as a number rather than "tensor(..., device=...)".
        print('{}: total test loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            dname, test_loss, int(correct), total,
            100. * correct / total
            ))
    return correct / total, test_loss


def unlearn_train_hard_labels(model, epoch, thr_train_loader, nthr_train_loader, thr_test_loader, nthr_test_loader,
                  criterion, hard_labels, expected_distribution, optimizer, alpha=1, gamma=1, ref_class=3, num_classes=10):
    """Run one unlearning epoch that relabels forget-class data with hard pseudo-labels.

    1. Predict on every forget-class ("threes") training batch. Turn the softmax
       outputs into one hard label per sample with
       ``hard_labels(probs, expected_distribution)``.
    2. Train on paired batches, minimising
       ``alpha * criterion(retain batch, true labels) + gamma * criterion(forget batch, hard labels)``.
       Training stops when the shorter of the two sides runs out.
    3. Evaluate on both test loaders via ``test``.

    Args:
        model: network being unlearned; updated in place.
        epoch: epoch number, used only for printing.
        thr_train_loader: forget-class training batches ``(data, label)``.
        nthr_train_loader: retain-class training batches ``(data, label)``.
        thr_test_loader: forget-class test batches.
        nthr_test_loader: retain-class test batches.
        criterion: classification loss called as ``criterion(logits, labels)``.
        hard_labels: callable ``(probs, expected_distribution) -> labels`` (e.g. get_hard_labels).
        expected_distribution: target label distribution passed to ``hard_labels``.
        optimizer: optimizer over ``model``'s parameters.
        alpha: weight of the retain-class loss.
        gamma: weight of the forget-class loss.
        ref_class, num_classes: accepted for interface compatibility; not used here.

    Returns:
        ``(thr_test_acc, thr_test_loss, nthr_test_acc, nthr_test_loss, unlearn_train_loss)``.
        ``unlearn_train_loss`` is the sum of the per-step losses.
    """
    # Pass 1: cache the forget-class batches while predicting on them. Re-iterating
    # the loader instead would draw a new order whenever it uses a random sampler
    # (SubsetRandomSampler reshuffles every pass). That would pair each hard label
    # with a different image from the one it was computed for.
    model.eval()
    ref_batches = []
    ref_output_probs = []
    with torch.no_grad():
        for ref_data, _ in thr_train_loader:
            ref_batches.append(ref_data)
            ref_output = model(ref_data.to(device))
            ref_output_probs.append(F.softmax(ref_output, dim=1).cpu().numpy())

    ref_output_probs = np.concatenate(ref_output_probs, axis=0)
    ref_hard_targets = hard_labels(ref_output_probs, expected_distribution)
    print(Counter(ref_hard_targets))
    # Slice the labels so target chunk i lines up with cached image batch i.
    ref_target_batches = torch.split(
        torch.as_tensor(ref_hard_targets, dtype=torch.long),
        [len(batch) for batch in ref_batches])

    # Pass 2: unlearning loop.
    model.train()
    unlearn_train_loss = 0
    for (nonref_data, nonref_target), ref_data, hard_targets_batch in zip(
            nthr_train_loader, ref_batches, ref_target_batches):
        nonref_data = nonref_data.to(device)
        nonref_target = nonref_target.to(device)
        ref_data = ref_data.to(device)
        # Use the configured device rather than a hard-coded GPU index.
        hard_targets_batch = hard_targets_batch.to(device)

        optimizer.zero_grad()
        nonref_output = model(nonref_data)
        ref_output = model(ref_data)
        nonref_loss = criterion(nonref_output, nonref_target)
        ref_loss = criterion(ref_output, hard_targets_batch)
        loss = alpha * nonref_loss + gamma * ref_loss
        loss.backward()
        optimizer.step()
        unlearn_train_loss += loss.item()

    # "\\" prints a single backslash; a bare "\ " is an invalid escape sequence.
    print("Epoch: {} \\ total train Loss: {:.6f}".format(
        epoch, unlearn_train_loss
    ))

    thr_test_acc, thr_test_loss = test(model, thr_test_loader, criterion=criterion, dname="Threes Test data", printable=True)
    nthr_test_acc, nthr_test_loss = test(model, nthr_test_loader, criterion=criterion, dname="Nonthree Test data", printable=True)

    return thr_test_acc, thr_test_loss, nthr_test_acc, nthr_test_loss, unlearn_train_loss



Data Loading and Preprocessing

In [None]:
def pos_neg_datasplit(class_number, dataset, batch_size=64):
    """Split `dataset` into a loader for one class and a loader for all other classes.

    Used for both the train and the test split.

    Args:
        class_number: label whose samples go to the first loader.
        dataset: dataset exposing a ``targets`` sequence/tensor aligned with its
            indices (e.g. torchvision MNIST).
        batch_size: batch size of both loaders (default 64).

    Returns:
        ``(class_loader, rest_loader)``. Both use a SubsetRandomSampler over their
        indices and ``drop_last=True``.
    """
    targets = torch.as_tensor(dataset.targets)[:len(dataset)]
    is_class = targets == class_number
    # Vectorised replacement for a per-index Python loop; indices stay ascending.
    class_indices = torch.nonzero(is_class, as_tuple=True)[0].tolist()
    rest_indices = torch.nonzero(~is_class, as_tuple=True)[0].tolist()

    class_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=SubsetRandomSampler(class_indices), drop_last=True)
    rest_loader = DataLoader(dataset, batch_size=batch_size,
                             sampler=SubsetRandomSampler(rest_indices), drop_last=True)
    return class_loader, rest_loader

In [None]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),
                                ])

# MNIST location. Override with the MNIST_DATA_DIR environment variable instead of
# editing the notebook; the default is the original author's absolute path.
DATA_DIR = os.environ.get("MNIST_DATA_DIR", "/home/ece/Subhodip/data")
# Label of the class to be forgotten.
FORGET_CLASS = 3

traindata = datasets.MNIST(DATA_DIR, download=False, train=True, transform=transform)
testdata = datasets.MNIST(DATA_DIR, download=False, train=False, transform=transform)

# Loaders over the full datasets, 64 examples per batch.
all_data_train_loader = DataLoader(traindata, batch_size=64, shuffle=True)
all_data_test_loader = DataLoader(testdata, batch_size=64, shuffle=True)

# neg_* loaders contain only FORGET_CLASS samples; pos_* loaders contain every other class.
neg_test_loader, pos_test_loader = pos_neg_datasplit(class_number=FORGET_CLASS, dataset=testdata)
neg_train_loader, pos_train_loader = pos_neg_datasplit(class_number=FORGET_CLASS, dataset=traindata)

Forgetting Code

In [None]:
learning_rates = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.04, 0.16, 0.64]
gammas = [0.25, 0.5, 1, 2, 4]
# Target prediction distribution for images of the forgotten class: uniform over
# the retained classes, zero on class 3. Normalising by the sum (rather than a
# hard-coded 1/9) keeps it a valid distribution if the list is edited.
expected_dist_list = [1, 1, 1, 0, 1, 1, 1, 1, 1, 1]
expected_pred_dist = np.array(expected_dist_list) / np.sum(expected_dist_list)
neg_ind = 3  # index of the forgotten class
forgetfulepochs = 1

# Load the pretrained checkpoint once. load_state_dict copies its weights into each
# fresh model, so the same loaded dict can be reused for every configuration.
# map_location keeps this working when the checkpoint was saved on another device.
path = "../checkpoints/resnet/pretrained/trained.pt"
checkpoint = torch.load(path, map_location=device)
criterion = CrossEntropyLoss(reduction='mean')

gamma_list = []
lr_list = []
non_three_accuracies_list = []
l1_distance_list = []
for gamma in gammas:
    for lr in learning_rates:
        print(f"learning_rate:{lr} and gamma:{gamma}")
        forget_resnet = CustomResNet18(in_channels=1, num_classes=10).to(device)
        forget_optimizer = optim.Adam(forget_resnet.parameters(), lr=lr)
        forget_resnet.load_state_dict(checkpoint['model_state_dict'])

        unlearning_losses = {'total_unlearning_losses': [], 'thr_test_losses': [], 'nthr_test_losses': []}
        unlearning_acces = {'thr_test_accs': [], 'nthr_test_accs': []}
        # Epoch 0 only records the pretrained model's metrics; later epochs unlearn.
        for epoch in range(0, forgetfulepochs + 1):
            if epoch == 0:
                thr_test_acc, thr_test_loss = test(forget_resnet, neg_test_loader, criterion=criterion, dname="Threes Test Data", printable=False)
                nthr_test_acc, nthr_test_loss = test(forget_resnet, pos_test_loader, criterion=criterion, dname="Nonthrees Test Data", printable=False)
                unlearning_loss = 1000  # placeholder: no training happens at epoch 0
            else:
                (thr_test_acc, thr_test_loss, nthr_test_acc, nthr_test_loss,
                 unlearning_loss) = unlearn_train_hard_labels(
                    model=forget_resnet, epoch=epoch,
                    thr_train_loader=neg_train_loader,
                    nthr_train_loader=pos_train_loader,
                    thr_test_loader=neg_test_loader,
                    nthr_test_loader=pos_test_loader, criterion=criterion,
                    hard_labels=get_hard_labels, expected_distribution=expected_pred_dist,
                    optimizer=forget_optimizer, gamma=gamma, ref_class=3, num_classes=10)

            unlearning_losses['total_unlearning_losses'].append(unlearning_loss)
            unlearning_losses['thr_test_losses'].append(thr_test_loss)
            unlearning_losses['nthr_test_losses'].append(nthr_test_loss)

            unlearning_acces['thr_test_accs'].append(thr_test_acc.to('cpu').item())
            unlearning_acces['nthr_test_accs'].append(nthr_test_acc.to('cpu').item())

        # Retained-class accuracy: diagonal of the row-normalised confusion matrix on
        # non-three test data. The forgotten class is dropped because this loader
        # contains none of its samples.
        cm, _ = max_pred_distribution(forget_resnet, test_loader=pos_test_loader)
        accuracies = [cm[i, i] for i in range(min(cm.shape))]
        print("Hard Labels")
        del accuracies[neg_ind]
        print("Non Three Accuracies:", np.mean(accuracies))

        # On three-only test data, row `target_label` of the row-normalised confusion
        # matrix is the model's prediction distribution for threes. Compare it to
        # the target distribution with an L1 distance.
        target_label = neg_ind
        cm2, _ = max_pred_distribution(forget_resnet, test_loader=neg_test_loader)
        column_for_label = cm2[target_label]  # despite the name, this is a row of cm2
        l1_distance = np.sum(np.abs(column_for_label - expected_pred_dist))
        print("l1_distance", l1_distance)

        # Flush this configuration's two confusion-matrix figures now, so figures do
        # not pile up across all gamma/lr combinations.
        plt.show()

        gamma_list.append(gamma)
        lr_list.append(lr)
        non_three_accuracies_list.append(np.mean(accuracies))
        l1_distance_list.append(l1_distance)


results_df = pd.DataFrame({
    'Gamma': gamma_list,
    'Learning Rate': lr_list,
    'Non Three Accuracies': non_three_accuracies_list,
    'L1 Distance': l1_distance_list
})
csv_filename = "./results/hard_labels_forgetting_epoch1.csv"
# to_csv does not create missing directories.
os.makedirs(os.path.dirname(csv_filename), exist_ok=True)
results_df.to_csv(csv_filename, index=False)