# MNIST unlearning: переобучение с нуля  

In [1]:
import datetime
import yaml
from copy import deepcopy

import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
from sklearn import linear_model, model_selection

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from torchvision import datasets, transforms
import torchvision.models as cv_models

import sys
sys.path.append("..")
from src import *

2023-11-29 07:52:55.642634: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-11-29 07:52:55.666576: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Конфиг 

In [2]:
class CFG:
    dataset_name = "mnist"
    model_name = "resnet18"
    process_name = "unlearning_retrain"
    full_model_name = f"{dataset_name}_{model_name}_{process_name}"
    
    gpu_num = 0
    device = f"cuda:{gpu_num}" if torch.cuda.is_available() else "cpu"
    num_workers = 16
    
    batch_size = 32
    
    train_size = 50000
    val_size = 10000
    test_size = 10000
    
    retain_size = 40000
    forget_size = 10000
    
    max_epoch_num = 1000
    early_stopping_patience = 10
    lr = 1e-4
    
    save_best_model = True
    seed = 42
    
seed_everything(CFG.seed)

## Подготовка данных 

In [3]:
train_data, test_data = get_mnist_data()
train_dataloader, retain_dataloader, forget_dataloader, val_dataloader, test_dataloader = (
    get_dataloaders(
        train_data=train_data, test_data=test_data, 
        train_size=CFG.train_size, 
        retain_size=CFG.retain_size, forget_size=CFG.forget_size, 
        val_size=CFG.val_size, test_size=CFG.test_size, 
        seed=CFG.seed, 
        batch_size=CFG.batch_size, 
        num_workers=CFG.num_workers
    )
)

## Модель 

In [4]:
model = MnistModel(base_model=cv_models.resnet18())
model.to(CFG.device);

## Обучение 

In [5]:
loss_f = nn.CrossEntropyLoss()
accuracy_f = lambda pred, targets: float((pred.argmax(axis=1) == targets.to(CFG.device)).float().mean())

optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)

In [6]:
curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = f"../logs/{CFG.dataset_name}/{CFG.model_name}/{CFG.process_name}/{curr_time}"
writer = SummaryWriter(log_dir=log_dir)

In [7]:
best_val_acc = 0
train_time = 0
for epoch_i in range(1, CFG.max_epoch_num + 1):
    ### TRAIN 
    model.train()
        
    mean_train_acc = 0
    train_batches_n = 0
    for batch_i, (inputs, targets) in enumerate(tqdm(retain_dataloader)):
        inputs, targets = inputs.to(CFG.device), targets.to(CFG.device)
        
        batch_start_time = datetime.datetime.now()
        
        pred = model(inputs)
        loss = loss_f(pred, targets)

        model.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_time += (datetime.datetime.now() - batch_start_time).total_seconds()
        
        mean_train_acc += accuracy_f(pred, targets)
        train_batches_n += 1

    mean_train_acc /= train_batches_n
    
    ### VAL
    
    mean_val_acc = validate_model(model, val_dataloader, accuracy_f, device=CFG.device)
    
    ### TENSORBOARD
    
    writer.add_scalars(
        "Accuracy",
        {"train": mean_train_acc, "val": mean_val_acc}, 
        global_step=epoch_i
    )
    writer.flush()

    ### SAVE BEST
    
    mean_train_acc_perc = round(float(mean_train_acc) * 100, 2)
    mean_val_acc_perc = round(float(mean_val_acc) * 100, 2)
    print(f"Epoch {epoch_i}: train_acc = {mean_train_acc_perc}%, val_acc = {mean_val_acc_perc}%", end="; ")

    if mean_val_acc > best_val_acc:
        best_epoch_i = epoch_i
        best_val_acc = mean_val_acc
        best_train_acc = mean_train_acc
        best_model = deepcopy(model)
        best_model_train_time = train_time
        
        print(f'new best model')
    elif epoch_i - best_epoch_i > CFG.early_stopping_patience:
        print(f'early stopping')
        break
    else:
        print("continue")

  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 1: train_acc = 94.39%, val_acc = 97.98%; new best model


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 2: train_acc = 97.86%, val_acc = 97.21%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 3: train_acc = 98.37%, val_acc = 98.19%; new best model


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 4: train_acc = 98.67%, val_acc = 98.43%; new best model


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 5: train_acc = 98.95%, val_acc = 97.21%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 6: train_acc = 99.04%, val_acc = 98.54%; new best model


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 7: train_acc = 99.22%, val_acc = 98.31%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 8: train_acc = 99.31%, val_acc = 98.96%; new best model


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 9: train_acc = 99.38%, val_acc = 98.75%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 10: train_acc = 99.41%, val_acc = 98.66%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 11: train_acc = 99.45%, val_acc = 98.88%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 12: train_acc = 99.59%, val_acc = 99.02%; new best model


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 13: train_acc = 99.62%, val_acc = 98.95%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 14: train_acc = 99.61%, val_acc = 99.05%; new best model


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 15: train_acc = 99.62%, val_acc = 98.55%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 16: train_acc = 99.73%, val_acc = 98.8%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 17: train_acc = 99.67%, val_acc = 99.05%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 18: train_acc = 99.67%, val_acc = 99.02%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 19: train_acc = 99.8%, val_acc = 99.0%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 20: train_acc = 99.68%, val_acc = 98.97%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 21: train_acc = 99.72%, val_acc = 98.79%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 22: train_acc = 99.78%, val_acc = 99.19%; new best model


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 23: train_acc = 99.76%, val_acc = 99.08%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 24: train_acc = 99.76%, val_acc = 98.98%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 25: train_acc = 99.8%, val_acc = 99.1%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 26: train_acc = 99.81%, val_acc = 98.74%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 27: train_acc = 99.75%, val_acc = 99.17%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 28: train_acc = 99.89%, val_acc = 99.19%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 29: train_acc = 99.85%, val_acc = 99.0%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 30: train_acc = 99.84%, val_acc = 98.95%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 31: train_acc = 99.8%, val_acc = 99.13%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 32: train_acc = 99.8%, val_acc = 98.88%; continue


  0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 33: train_acc = 99.88%, val_acc = 99.13%; early stopping


In [8]:
mean_forget_acc = validate_model(best_model, forget_dataloader, accuracy_f, device=CFG.device)
mean_test_acc = validate_model(best_model, test_dataloader, accuracy_f, device=CFG.device)

print(f"Best model retain acc = {round(best_train_acc * 100, 2)}%")
print(f"Best model forget acc = {round(mean_forget_acc * 100, 2)}%")
print(f"Best model val acc = {round(best_val_acc * 100, 2)}%")
print(f"Best model test acc = {round(mean_test_acc * 100, 2)}%")

test_acc_str = str(round(mean_test_acc * 100, 2)).replace(".", "_")
if CFG.save_best_model:
    best_model_path = f"../models/{CFG.full_model_name}_acc_{test_acc_str}.torch"
    torch.save(best_model.state_dict(), best_model_path)
    print(f"Best model saved: {best_model_path}")

Best model retain acc = 99.78%
Best model forget acc = 99.26%
Best model val acc = 99.19%
Best model test acc = 99.38%
Best model saved: ../models/mnist_resnet18_unlearning_retrain_acc_99_38.torch


## Анализ забывания

In [9]:
retain_losses = get_losses(best_model, retain_dataloader, device=CFG.device)
forget_losses = get_losses(best_model, forget_dataloader, device=CFG.device)
val_losses = get_losses(best_model, val_dataloader, device=CFG.device)
test_losses = get_losses(best_model, test_dataloader, device=CFG.device)

In [10]:
forget_vs_retain_mia = mia_score(forget_losses, retain_losses, seed=CFG.seed)
retain_vs_val_mia = mia_score(retain_losses, val_losses, seed=CFG.seed)
retain_vs_test_mia = mia_score(retain_losses, test_losses, seed=CFG.seed)

forget_vs_val_mia = mia_score(forget_losses, val_losses, seed=CFG.seed)
forget_vs_test_mia = mia_score(forget_losses, test_losses, seed=CFG.seed)

val_vs_test_mia = mia_score(val_losses, test_losses, seed=CFG.seed)

print(f"forget_vs_retain_mia = {round(forget_vs_retain_mia * 100, 2)}%")
print(f"retain_vs_val_mia = {round(retain_vs_val_mia * 100, 2)}%")
print(f"retain_vs_test_mia = {round(retain_vs_test_mia * 100, 2)}%")
print(f"forget_vs_val_mia = {round(forget_vs_val_mia * 100, 2)}%")
print(f"forget_vs_test_mia = {round(forget_vs_test_mia * 100, 2)}%")
print(f"val_vs_test_mia = {round(val_vs_test_mia * 100, 2)}%")

forget_vs_retain_mia = 50.65%
retain_vs_val_mia = 51.1%
retain_vs_test_mia = 50.95%
forget_vs_val_mia = 50.51%
forget_vs_test_mia = 49.76%
val_vs_test_mia = 50.16%


## Сохранение результатов 

In [100]:
research_info = {
    "dataset_name": CFG.dataset_name,
    "model_name": CFG.model_name,
    "process_name": CFG.process_name,
    "batch_size": CFG.batch_size,
    "retain_size": CFG.retain_size,
    "forget_size": CFG.forget_size,
    "val_size": CFG.val_size,
    "test_size": CFG.test_size,
    "unlearn_model": {
        "best_epoch": best_epoch_i,
        "accuracy": {
            "retrain_acc": best_train_acc,
            "forget_acc": mean_forget_acc,
            "val_acc": best_val_acc,
            "test_acc": mean_test_acc,
        },
        "path": best_model_path,
        "log_dir": log_dir,
        "train_time_sec": best_model_train_time,
        "mia_scores": {
            "forget_vs_retain_mia": forget_vs_retain_mia,
            "retain_vs_val_mia": retain_vs_val_mia,
            "retain_vs_test_mia": retain_vs_test_mia,
            "forget_vs_val_mia": forget_vs_val_mia,
            "forget_vs_test_mia": forget_vs_test_mia,
            "val_vs_test_mia": val_vs_test_mia,
        }
    },

}
research_info

{'dataset_name': 'mnist',
 'model_name': 'resnet18',
 'process_name': 'unlearning_retrain',
 'batch_size': 32,
 'retain_size': 40000,
 'forget_size': 10000,
 'val_size': 10000,
 'test_size': 10000,
 'unlearn_model': {'best_epoch': 22,
  'accuracy': {'retrain_acc': 0.997775,
   'forget_acc': 0.992611821086262,
   'val_acc': 0.9919129392971247,
   'test_acc': 0.9938099041533547},
  'path': '../models/mnist_resnet18_unlearning_retrain_acc_99_38.torch',
  'log_dir': '../logs/mnist/resnet18/unlearning_retrain/20231129_023016',
  'train_time_sec': 84.82473300000059,
  'mia_scores': {'forget_vs_retain_mia': 0.5064,
   'retain_vs_val_mia': 0.5122,
   'retain_vs_test_mia': 0.5104,
   'forget_vs_val_mia': 0.5050999999999999,
   'forget_vs_test_mia': 0.4976,
   'val_vs_test_mia': 0.5016}}}

In [101]:
with open(f"../research_info/{CFG.full_model_name}_acc_{test_acc_str}.yaml", "w") as f:
    yaml.dump(research_info, f)