In [65]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn

from esc import ESC
from esc_t import ESC_T
from utils import evaluate, evaluate_kr
from network import AllCNN, ResNet
from dataset import InputPipeLineBuilder

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG (torch.manual_seed also seeds every CUDA device) before any
# dataloader or model is built. This makes torch-driven randomness below repeatable on
# Restart & Run All; examples are dataloader shuffling (if enabled) and the ESC-T mask
# optimisation. Trailing ';' suppresses the returned Generator's repr in the cell output.
SEED = 0
torch.manual_seed(SEED);

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### All-CNN on CIFAR-10

In [None]:
# CIFAR-10 input pipeline for the All-CNN experiments (batch size 64).
input_pipeline_builder = InputPipeLineBuilder(
    batch_size=64,
    select_forget_concept=True,
    dataset='cifar10',
)

# Build the loaders in the fixed order train/forget, train/retain, test/forget, test/retain.
# `f`/`r` are the forget (is_retain=False) and retain (is_retain=True) splits of the train
# subset; `ft`/`rt` are the same splits of the test subset.
f_dataloader, r_dataloader, ft_dataloader, rt_dataloader = (
    input_pipeline_builder.get_dataloader_for_unlearn(is_retain=keep, subset=split)
    for split in ('train', 'test')
    for keep in (False, True)
)

In [None]:
# Original Model: the pretrained All-CNN checkpoint, evaluated as-is.
model = AllCNN(device=device).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state dict needs. A tampered checkpoint can then no longer execute code on load, and
# recent torch versions stop warning about the unsafe default.
model.load_state_dict(
    torch.load('./prepare/all_cnn_pretrained.pth', map_location=device, weights_only=True)
)

metrices = evaluate(r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model)
print(metrices)

# output_dir is given a .pth path; presumably evaluate_kr writes its probe there — confirm in utils.py.
metrices = evaluate_kr(r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model, output_dir='./prepare/all_cnn_origin_probe.pth')
print(metrices)

In [None]:
# Retrain Model: All-CNN loaded from the "retrained" checkpoint.
# NOTE(review): presumably trained without the forget split as a reference point — confirm in ./prepare.
model = AllCNN(device=device).to(device)
# weights_only=True: a state dict only holds tensors/containers, so refuse arbitrary pickled objects.
model.load_state_dict(
    torch.load('./prepare/all_cnn_retrained.pth', map_location=device, weights_only=True)
)

metrices = evaluate(r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model)
print(metrices)

# output_dir is given a .pth path; presumably evaluate_kr writes its probe there — confirm in utils.py.
metrices = evaluate_kr(r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model, output_dir='./prepare/all_cnn_retrain_probe.pth')
print(metrices)

In [None]:
# ESC Model on All-CNN: built from the pretrained weights, then fitted on the forget loader.
model = ESC(p=0.017, use_pretrain=True, model_type='all_cnn')
model.get_up_matrix(f_dataloader)

# Both metrics take the same positional arguments: retain/forget train loaders,
# retain/forget test loaders, then the model under evaluation.
eval_args = (r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model)

metrices = evaluate(*eval_args)
print(metrices)

metrices = evaluate_kr(*eval_args, output_dir='./prepare/all_cnn_esc_probe.pth')
print(metrices)

In [None]:
# ESC-T Model on All-CNN: same setup as ESC, plus a trainable mask.
model = ESC_T(threshold=0.8, use_pretrain=True, model_type='all_cnn')
model.get_up_matrix(f_dataloader)

# Adam receives only `model.mask`, so that is the only tensor the optimizer steps.
optimizer = torch.optim.Adam(params=[model.mask], lr=0.01)
_ = model.train_mask(f_dataloader, optimizer, num_epochs=10)

# Shared positional arguments: retain/forget train loaders, retain/forget test loaders, model.
eval_args = (r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model)

metrices = evaluate(*eval_args)
print(metrices)

metrices = evaluate_kr(*eval_args, output_dir='./prepare/all_cnn_esct_probe.pth')
print(metrices)

### ResNet-18 on CIFAR-100

In [None]:
# CIFAR-100 input pipeline for the ResNet experiments (batch size 64).
input_pipeline_builder = InputPipeLineBuilder(
    batch_size=64,
    select_forget_concept=True,
    dataset='cifar100',
)

# Build the loaders in the fixed order train/forget, train/retain, test/forget, test/retain.
# `f`/`r` are the forget (is_retain=False) and retain (is_retain=True) splits of the train
# subset; `ft`/`rt` are the same splits of the test subset.
f_dataloader, r_dataloader, ft_dataloader, rt_dataloader = (
    input_pipeline_builder.get_dataloader_for_unlearn(is_retain=keep, subset=split)
    for split in ('train', 'test')
    for keep in (False, True)
)

In [73]:
# Original Model: the pretrained ResNet-18 checkpoint, evaluated as-is.
model = ResNet(device=device).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state dict needs. A tampered checkpoint can then no longer execute code on load, and
# recent torch versions stop warning about the unsafe default.
model.load_state_dict(
    torch.load('./prepare/resnet_18_pretrained.pth', map_location=device, weights_only=True)
)

metrices = evaluate(r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model)
print(metrices)

# output_dir is given a .pth path; presumably evaluate_kr writes its probe there — confirm in utils.py.
metrices = evaluate_kr(r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model, output_dir='./prepare/resnet_origin_probe.pth')
print(metrices)

KeyboardInterrupt: 

In [None]:
# Retrain Model: ResNet-18 loaded from the "retrained" checkpoint.
# NOTE(review): presumably trained without the forget split as a reference point — confirm in ./prepare.
model = ResNet(device=device).to(device)
# weights_only=True: a state dict only holds tensors/containers, so refuse arbitrary pickled objects.
model.load_state_dict(
    torch.load('./prepare/resnet_18_retrained.pth', map_location=device, weights_only=True)
)

metrices = evaluate(r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model)
print(metrices)

# output_dir is given a .pth path; presumably evaluate_kr writes its probe there — confirm in utils.py.
metrices = evaluate_kr(r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model, output_dir='./prepare/resnet_retrain_probe.pth')
print(metrices)

In [None]:
# ESC Model on ResNet: built from the pretrained weights, then fitted on the forget loader.
model = ESC(p=0.017, use_pretrain=True, model_type='resnet')
model.get_up_matrix(f_dataloader)

# Both metrics take the same positional arguments: retain/forget train loaders,
# retain/forget test loaders, then the model under evaluation.
eval_args = (r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model)

metrices = evaluate(*eval_args)
print(metrices)

metrices = evaluate_kr(*eval_args, output_dir='./prepare/resnet_esc_probe.pth')
print(metrices)

In [None]:
# ESC-T Model on ResNet: same setup as ESC, plus a trainable mask.
model = ESC_T(threshold=0.7, use_pretrain=True, model_type='resnet')
model.get_up_matrix(f_dataloader)

# Adam receives only `model.mask`, so that is the only tensor the optimizer steps.
optimizer = torch.optim.Adam(params=[model.mask], lr=0.01)
_ = model.train_mask(f_dataloader, optimizer, num_epochs=10)

# Shared positional arguments: retain/forget train loaders, retain/forget test loaders, model.
eval_args = (r_dataloader, f_dataloader, rt_dataloader, ft_dataloader, model)

metrices = evaluate(*eval_args)
print(metrices)

metrices = evaluate_kr(*eval_args, output_dir='./prepare/resnet_esct_probe.pth')
print(metrices)