In [12]:
%load_ext autoreload
%autoreload 2

import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('..')

from tqdm import tqdm
from network import ResNet
from dataset import InputPipeLineBuilder

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
num_epochs = 100
batch_size = 256
seed = 42  # torch RNG seed for reproducible weight init; TODO confirm dataset shuffling also draws from torch's global RNG

device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1e-3
weight_decay = 0.05
constant_lr_epochs = 10  # epochs held at the base lr before StepLR takes over

torch.manual_seed(seed)
model = ResNet(head_input_dim=512).to(device)

# Make every parameter trainable. The previous loop
# (`layer.requires_grad_ = True` over model.modules()) only replaced each
# module's `requires_grad_` method with a bool attribute and never changed any
# parameter's `requires_grad` flag.
model.requires_grad_(True)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Schedule (stepped once per epoch in the training loop):
#   - ConstantLR with factor=1.0 keeps lr unchanged for `constant_lr_epochs` epochs;
#   - afterwards StepLR multiplies lr by 0.1 every 30 epochs.
sch1 = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=constant_lr_epochs)
sch2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[sch1, sch2],
    milestones=[constant_lr_epochs],
)

In [14]:
input_pipeline_builder = InputPipeLineBuilder(
    batch_size=batch_size,
    select_forget_concept=True,
    dataset='cifar100',
)

# One dataloader per split, all requested with is_retain=True
# (presumably the retain side of the unlearning setup — confirm in dataset.py).
# Built lazily in train -> valid -> test order and unpacked into three names.
train_dataloader, valid_dataloader, test_dataloader = (
    input_pipeline_builder.get_dataloader_for_unlearn(is_retain=True, subset=split)
    for split in ('train', 'valid', 'test')
)

In [None]:
def _train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass over `dataloader`.

    Each batch is expected to unpack as ``(inputs, targets)``.
    Returns the mean of the per-batch training losses.
    """
    model.train()
    losses = []
    for inputs, targets in tqdm(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        loss = loss_fn(logits, targets)

        # Clear any stale gradients before backprop so the cell is safe to
        # re-run after an interrupted epoch.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
    return sum(losses) / len(losses)


@torch.no_grad()  # evaluation only: don't build autograd graphs (saves memory and time)
def _evaluate(model, dataloader, loss_fn, device):
    """Evaluate `model` on `dataloader` without updating its weights.

    Each batch is expected to unpack as ``(inputs, targets)``.
    Returns ``(mean per-batch loss, top-1 accuracy over all samples)``.
    """
    model.eval()
    losses = []
    correct, total = 0, 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        losses.append(loss_fn(logits, targets).item())
        correct += (logits.argmax(dim=1) == targets).sum().item()
        total += inputs.shape[0]
    return sum(losses) / len(losses), correct / total


for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_dataloader, loss_fn, optimizer, device)
    print(f"\ttrain avg loss at epoch: {epoch+1}/{num_epochs}: {train_loss:.4f}")

    valid_loss, valid_acc = _evaluate(model, valid_dataloader, loss_fn, device)
    print(f"\tvalid acc at epoch: {epoch+1}/{num_epochs} : {valid_acc:.4f}")
    print(f"\tvalid avg loss at epoch: {epoch+1}/{num_epochs}: {valid_loss:.4f}")

    # The LR schedule is defined in epochs, so step once per epoch.
    scheduler.step()

In [None]:
# Head-only checkpoint: unchanged path and format, so existing loaders keep working.
torch.save(model.head.state_dict(), './resnet_18_retrained.pth')
# Full checkpoint: the optimizer was built over all of `model.parameters()`,
# so any trainable (non-head) weights would be lost if only the head were saved.
torch.save(model.state_dict(), './resnet_18_retrained_full.pth')