In [11]:
%load_ext autoreload
%autoreload 2

import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('..')

from tqdm import tqdm
from network import AllCNN
from dataset import InputPipeLineBuilder

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# --- Training hyperparameters ---
num_epochs = 350
batch_size = 256
lr = 0.05
weight_decay = 0.001

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Network, objective and optimizer (SGD with momentum and weight decay).
model = AllCNN(head_input_dim=10)
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=weight_decay,
)

def lr_lambda(epoch):
    """Return the learning-rate multiplier for the given epoch index.

    Step schedule: 1.0 while ``epoch < 200``, 0.1 while ``epoch < 250``,
    0.01 while ``epoch < 300``, and 0.001 from epoch 300 onward.
    """
    # (exclusive upper epoch bound, multiplier), checked in ascending order.
    milestones = ((200, 1.0), (250, 0.1), (300, 0.01))
    for upper_bound, factor in milestones:
        if epoch < upper_bound:
            return factor
    # Past the last milestone.
    return 0.001
    
# LambdaLR multiplies the optimizer's base learning rate by lr_lambda(epoch).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

In [13]:
# Build the CIFAR-10 input pipeline with forget-concept selection switched on.
input_pipeline_builder = InputPipeLineBuilder(
    dataset='cifar10',
    batch_size=batch_size,
    select_forget_concept=True,
)

# Loaders requested with is_retain=True for the train and test subsets
# (presumably the data left after removing the forget concept — confirm in dataset.py).
train_dataloader = input_pipeline_builder.get_dataloader_for_unlearn(
    is_retain=True,
    subset='train',
)
test_dataloader = input_pipeline_builder.get_dataloader_for_unlearn(
    is_retain=True,
    subset='test',
)

In [None]:
# Train for `num_epochs` epochs; the LR scheduler advances once per epoch.
for epoch in range(num_epochs):
    # Per-batch loss values collected during this epoch.
    losses = []

    model.train()
    for batch_x, batch_y in tqdm(train_dataloader):
        # Move the batch onto the same device as the model.
        inputs = batch_x.to(device)
        targets = batch_y.to(device)

        logits = model(inputs)
        loss = loss_fn(logits, targets)
        loss.backward()

        # Apply the update, then clear gradients so they do not accumulate
        # into the next batch.
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.cpu().item())

    scheduler.step()
    print(f"avg loss at epoch: {epoch+1}/{num_epochs}: {sum(losses) / len(losses):.4f}")
    

In [None]:
# Persist only the parameters/buffers (state_dict), not the whole module object.
# The path is relative to the current working directory.
retrained_state = model.state_dict()
torch.save(retrained_state, './all_cnn_retrained.pth')

In [None]:
# Accuracy of `model` over every batch yielded by `test_dataloader`.
rcorrect, total = 0, 0

model.eval()
# Inference only: skip autograd bookkeeping so no graph is built per batch.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        batch_x, batch_y = batch
        logits = model(batch_x.to(device))
        # Predicted class = index of the largest logit along dim 1
        # (assumes logits are (batch, num_classes) — TODO confirm against AllCNN).
        preds = logits.argmax(dim=1)

        # Labels come straight from the loader; move them to the logits' device
        # before comparing so the check also works when running on CUDA.
        rcorrect += (preds == batch_y.to(device)).sum().item()
        total += batch_x.shape[0]

# NOTE: `epoch` is the loop variable left over from the training cell, so this
# line requires that cell to have run in the current kernel.
print(f"avg accuracy at epoch: {epoch+1}/{num_epochs}: {rcorrect/total:.4f}")

In [None]:
# Rebind `test_dataloader` to the is_retain=False loader of the train subset
# (presumably the forget split — confirm in dataset.py). Any later cell that
# reads `test_dataloader` now iterates over this loader instead.
test_dataloader = input_pipeline_builder.get_dataloader_for_unlearn(
    subset='train',
    is_retain=False,
)

In [None]:
# Accuracy of `model` over whatever `test_dataloader` currently refers to
# (the preceding cell rebinds it with is_retain=False, subset='train').
rcorrect, total = 0, 0

model.eval()
# Inference only: skip autograd bookkeeping so no graph is built per batch.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        batch_x, batch_y = batch
        logits = model(batch_x.to(device))
        # Predicted class = index of the largest logit along dim 1
        # (assumes logits are (batch, num_classes) — TODO confirm against AllCNN).
        preds = logits.argmax(dim=1)

        # Labels come straight from the loader; move them to the logits' device
        # before comparing so the check also works when running on CUDA.
        rcorrect += (preds == batch_y.to(device)).sum().item()
        total += batch_x.shape[0]

# NOTE: `epoch` is the loop variable left over from the training cell, so this
# line requires that cell to have run in the current kernel.
print(f"avg accuracy at epoch: {epoch+1}/{num_epochs}: {rcorrect/total:.4f}")