In [1]:
from google.colab import drive
drive.mount('/content/drive')

%cd ./drive/MyDrive/esc/prepare

Mounted at /content/drive
/content/drive/MyDrive/esc/prepare


In [2]:
%load_ext autoreload
%autoreload 2

import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('..')

from tqdm import tqdm
from network import AllCNN
from dataset import InputPipeLineBuilder

In [3]:
# --- Hyperparameters for retraining AllCNN from scratch ---
num_epochs = 350
batch_size = 256
lr = 0.05
weight_decay = 0.001
# Fixed RNG seed so the run is reproducible. NOTE(review): cuDNN kernels may
# still be non-deterministic, and loader shuffling is only covered if
# InputPipeLineBuilder relies on torch's default generator. Confirm both.
seed = 0

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed before constructing the model so its initial weights are reproducible
# (torch.manual_seed also seeds all CUDA devices).
torch.manual_seed(seed)

model = AllCNN(head_input_dim=10).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

def lr_lambda(epoch):
    """Piecewise-constant LR multiplier for LambdaLR.

    Returns 1.0 before epoch 200, 0.1 for [200, 250), 0.01 for [250, 300),
    and 0.001 from epoch 300 onward.
    """
    # (exclusive epoch upper bound, multiplier) pairs, checked in order.
    milestones = ((200, 1.0), (250, 0.1), (300, 0.01))
    for upper_bound, multiplier in milestones:
        if epoch < upper_bound:
            return multiplier
    return 0.001

# Scale the optimizer's base LR by lr_lambda(epoch); stepped once per epoch.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [5]:
# Build CIFAR-10 loaders restricted to the retain split (is_retain=True).
# Presumably this excludes the forget concept. The exact semantics live in
# dataset.InputPipeLineBuilder, so confirm them there.
input_pipeline_builder = InputPipeLineBuilder(
    dataset='cifar10',
    batch_size=batch_size,
    select_forget_concept=True,
)

train_dataloader = input_pipeline_builder.get_dataloader_for_unlearn(is_retain=True, subset='train')
test_dataloader = input_pipeline_builder.get_dataloader_for_unlearn(is_retain=True, subset='test')

In [6]:
# One pass over train_dataloader per epoch. The LR multiplier from lr_lambda
# advances once per epoch via scheduler.step().
for epoch in range(num_epochs):
    epoch_losses = []

    model.train()
    for batch_x, batch_y in tqdm(train_dataloader):
        logits = model(batch_x.to(device))
        loss = loss_fn(logits, batch_y.to(device))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_losses.append(loss.cpu().item())

    scheduler.step()
    mean_loss = sum(epoch_losses) / len(epoch_losses)
    print(f"avg loss at epoch: {epoch+1}/{num_epochs}: {mean_loss:.4f}")


100%|██████████| 140/140 [00:15<00:00,  9.15it/s]


avg loss at epoch: 1/350: 1.5751


100%|██████████| 140/140 [00:13<00:00, 10.50it/s]


avg loss at epoch: 2/350: 1.1210


100%|██████████| 140/140 [00:13<00:00, 10.44it/s]


avg loss at epoch: 3/350: 0.9687


100%|██████████| 140/140 [00:13<00:00, 10.43it/s]


avg loss at epoch: 4/350: 0.8583


100%|██████████| 140/140 [00:13<00:00, 10.41it/s]


avg loss at epoch: 5/350: 0.7956


100%|██████████| 140/140 [00:13<00:00, 10.38it/s]


avg loss at epoch: 6/350: 0.7336


100%|██████████| 140/140 [00:13<00:00, 10.31it/s]


avg loss at epoch: 7/350: 0.6847


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 8/350: 0.6420


100%|██████████| 140/140 [00:13<00:00, 10.38it/s]


avg loss at epoch: 9/350: 0.6038


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 10/350: 0.5752


100%|██████████| 140/140 [00:13<00:00, 10.38it/s]


avg loss at epoch: 11/350: 0.5471


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 12/350: 0.5236


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 13/350: 0.5033


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 14/350: 0.4790


100%|██████████| 140/140 [00:13<00:00, 10.34it/s]


avg loss at epoch: 15/350: 0.4689


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 16/350: 0.4467


100%|██████████| 140/140 [00:13<00:00, 10.13it/s]


avg loss at epoch: 17/350: 0.4386


100%|██████████| 140/140 [00:13<00:00, 10.40it/s]


avg loss at epoch: 18/350: 0.4261


100%|██████████| 140/140 [00:13<00:00, 10.43it/s]


avg loss at epoch: 19/350: 0.4143


100%|██████████| 140/140 [00:13<00:00, 10.40it/s]


avg loss at epoch: 20/350: 0.4093


100%|██████████| 140/140 [00:13<00:00, 10.46it/s]


avg loss at epoch: 21/350: 0.3991


100%|██████████| 140/140 [00:13<00:00, 10.44it/s]


avg loss at epoch: 22/350: 0.3937


100%|██████████| 140/140 [00:13<00:00, 10.41it/s]


avg loss at epoch: 23/350: 0.3805


100%|██████████| 140/140 [00:13<00:00, 10.50it/s]


avg loss at epoch: 24/350: 0.3637


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 25/350: 0.3683


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 26/350: 0.3650


100%|██████████| 140/140 [00:13<00:00, 10.34it/s]


avg loss at epoch: 27/350: 0.3563


100%|██████████| 140/140 [00:13<00:00, 10.25it/s]


avg loss at epoch: 28/350: 0.3619


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 29/350: 0.3415


100%|██████████| 140/140 [00:13<00:00, 10.37it/s]


avg loss at epoch: 30/350: 0.3352


100%|██████████| 140/140 [00:13<00:00, 10.40it/s]


avg loss at epoch: 31/350: 0.3350


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 32/350: 0.3272


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 33/350: 0.3356


100%|██████████| 140/140 [00:14<00:00,  9.91it/s]


avg loss at epoch: 34/350: 0.3289


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 35/350: 0.3290


100%|██████████| 140/140 [00:13<00:00, 10.29it/s]


avg loss at epoch: 36/350: 0.3213


100%|██████████| 140/140 [00:13<00:00, 10.16it/s]


avg loss at epoch: 37/350: 0.3172


100%|██████████| 140/140 [00:13<00:00, 10.25it/s]


avg loss at epoch: 38/350: 0.3081


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 39/350: 0.3063


100%|██████████| 140/140 [00:13<00:00, 10.23it/s]


avg loss at epoch: 40/350: 0.3210


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 41/350: 0.3056


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 42/350: 0.3024


100%|██████████| 140/140 [00:13<00:00, 10.23it/s]


avg loss at epoch: 43/350: 0.3003


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 44/350: 0.3007


100%|██████████| 140/140 [00:13<00:00, 10.38it/s]


avg loss at epoch: 45/350: 0.2960


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 46/350: 0.2823


100%|██████████| 140/140 [00:13<00:00, 10.34it/s]


avg loss at epoch: 47/350: 0.2854


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 48/350: 0.2802


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 49/350: 0.2857


100%|██████████| 140/140 [00:13<00:00, 10.36it/s]


avg loss at epoch: 50/350: 0.2892


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 51/350: 0.2898


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 52/350: 0.2821


100%|██████████| 140/140 [00:13<00:00, 10.31it/s]


avg loss at epoch: 53/350: 0.2811


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 54/350: 0.2844


100%|██████████| 140/140 [00:13<00:00, 10.37it/s]


avg loss at epoch: 55/350: 0.2699


100%|██████████| 140/140 [00:13<00:00, 10.27it/s]


avg loss at epoch: 56/350: 0.2781


100%|██████████| 140/140 [00:13<00:00, 10.16it/s]


avg loss at epoch: 57/350: 0.2756


100%|██████████| 140/140 [00:13<00:00, 10.31it/s]


avg loss at epoch: 58/350: 0.2702


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 59/350: 0.2738


100%|██████████| 140/140 [00:13<00:00, 10.27it/s]


avg loss at epoch: 60/350: 0.2709


100%|██████████| 140/140 [00:13<00:00, 10.47it/s]


avg loss at epoch: 61/350: 0.2697


100%|██████████| 140/140 [00:13<00:00, 10.38it/s]


avg loss at epoch: 62/350: 0.2764


100%|██████████| 140/140 [00:13<00:00, 10.32it/s]


avg loss at epoch: 63/350: 0.2583


100%|██████████| 140/140 [00:13<00:00, 10.41it/s]


avg loss at epoch: 64/350: 0.2713


100%|██████████| 140/140 [00:13<00:00, 10.44it/s]


avg loss at epoch: 65/350: 0.2640


100%|██████████| 140/140 [00:13<00:00, 10.40it/s]


avg loss at epoch: 66/350: 0.2641


100%|██████████| 140/140 [00:13<00:00, 10.44it/s]


avg loss at epoch: 67/350: 0.2507


100%|██████████| 140/140 [00:13<00:00, 10.17it/s]


avg loss at epoch: 68/350: 0.2609


100%|██████████| 140/140 [00:13<00:00, 10.41it/s]


avg loss at epoch: 69/350: 0.2591


100%|██████████| 140/140 [00:13<00:00, 10.43it/s]


avg loss at epoch: 70/350: 0.2534


100%|██████████| 140/140 [00:13<00:00, 10.37it/s]


avg loss at epoch: 71/350: 0.2635


100%|██████████| 140/140 [00:13<00:00, 10.41it/s]


avg loss at epoch: 72/350: 0.2534


100%|██████████| 140/140 [00:13<00:00, 10.43it/s]


avg loss at epoch: 73/350: 0.2637


100%|██████████| 140/140 [00:13<00:00, 10.43it/s]


avg loss at epoch: 74/350: 0.2581


100%|██████████| 140/140 [00:13<00:00, 10.37it/s]


avg loss at epoch: 75/350: 0.2675


100%|██████████| 140/140 [00:13<00:00, 10.36it/s]


avg loss at epoch: 76/350: 0.2590


100%|██████████| 140/140 [00:13<00:00, 10.29it/s]


avg loss at epoch: 77/350: 0.2509


100%|██████████| 140/140 [00:13<00:00, 10.31it/s]


avg loss at epoch: 78/350: 0.2509


100%|██████████| 140/140 [00:13<00:00, 10.34it/s]


avg loss at epoch: 79/350: 0.2602


100%|██████████| 140/140 [00:13<00:00, 10.36it/s]


avg loss at epoch: 80/350: 0.2612


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 81/350: 0.2527


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 82/350: 0.2534


100%|██████████| 140/140 [00:13<00:00, 10.32it/s]


avg loss at epoch: 83/350: 0.2497


100%|██████████| 140/140 [00:14<00:00,  9.95it/s]


avg loss at epoch: 84/350: 0.2615


100%|██████████| 140/140 [00:13<00:00, 10.37it/s]


avg loss at epoch: 85/350: 0.2448


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 86/350: 0.2448


100%|██████████| 140/140 [00:13<00:00, 10.38it/s]


avg loss at epoch: 87/350: 0.2496


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 88/350: 0.2447


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 89/350: 0.2569


100%|██████████| 140/140 [00:13<00:00, 10.31it/s]


avg loss at epoch: 90/350: 0.2564


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 91/350: 0.2453


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 92/350: 0.2516


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 93/350: 0.2530


100%|██████████| 140/140 [00:13<00:00, 10.29it/s]


avg loss at epoch: 94/350: 0.2420


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 95/350: 0.2368


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 96/350: 0.2412


100%|██████████| 140/140 [00:13<00:00, 10.24it/s]


avg loss at epoch: 97/350: 0.2628


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 98/350: 0.2433


100%|██████████| 140/140 [00:13<00:00, 10.29it/s]


avg loss at epoch: 99/350: 0.2442


100%|██████████| 140/140 [00:13<00:00, 10.40it/s]


avg loss at epoch: 100/350: 0.2416


100%|██████████| 140/140 [00:13<00:00, 10.14it/s]


avg loss at epoch: 101/350: 0.2454


100%|██████████| 140/140 [00:13<00:00, 10.34it/s]


avg loss at epoch: 102/350: 0.2416


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 103/350: 0.2537


100%|██████████| 140/140 [00:13<00:00, 10.19it/s]


avg loss at epoch: 104/350: 0.2387


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 105/350: 0.2434


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 106/350: 0.2397


100%|██████████| 140/140 [00:13<00:00, 10.32it/s]


avg loss at epoch: 107/350: 0.2363


100%|██████████| 140/140 [00:13<00:00, 10.32it/s]


avg loss at epoch: 108/350: 0.2298


100%|██████████| 140/140 [00:13<00:00, 10.38it/s]


avg loss at epoch: 109/350: 0.2403


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 110/350: 0.2409


100%|██████████| 140/140 [00:13<00:00, 10.39it/s]


avg loss at epoch: 111/350: 0.2428


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 112/350: 0.2356


100%|██████████| 140/140 [00:13<00:00, 10.36it/s]


avg loss at epoch: 113/350: 0.2387


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 114/350: 0.2360


100%|██████████| 140/140 [00:13<00:00, 10.38it/s]


avg loss at epoch: 115/350: 0.2456


100%|██████████| 140/140 [00:13<00:00, 10.38it/s]


avg loss at epoch: 116/350: 0.2326


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 117/350: 0.2340


100%|██████████| 140/140 [00:13<00:00, 10.08it/s]


avg loss at epoch: 118/350: 0.2365


100%|██████████| 140/140 [00:13<00:00, 10.34it/s]


avg loss at epoch: 119/350: 0.2359


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 120/350: 0.2392


100%|██████████| 140/140 [00:13<00:00, 10.32it/s]


avg loss at epoch: 121/350: 0.2328


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 122/350: 0.2334


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 123/350: 0.2364


100%|██████████| 140/140 [00:13<00:00, 10.37it/s]


avg loss at epoch: 124/350: 0.2294


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 125/350: 0.2421


100%|██████████| 140/140 [00:13<00:00, 10.34it/s]


avg loss at epoch: 126/350: 0.2419


100%|██████████| 140/140 [00:13<00:00, 10.29it/s]


avg loss at epoch: 127/350: 0.2374


100%|██████████| 140/140 [00:13<00:00, 10.27it/s]


avg loss at epoch: 128/350: 0.2295


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 129/350: 0.2365


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 130/350: 0.2245


100%|██████████| 140/140 [00:13<00:00, 10.29it/s]


avg loss at epoch: 131/350: 0.2356


100%|██████████| 140/140 [00:13<00:00, 10.27it/s]


avg loss at epoch: 132/350: 0.2341


100%|██████████| 140/140 [00:13<00:00, 10.35it/s]


avg loss at epoch: 133/350: 0.2272


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 134/350: 0.2418


100%|██████████| 140/140 [00:13<00:00, 10.14it/s]


avg loss at epoch: 135/350: 0.2342


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 136/350: 0.2310


100%|██████████| 140/140 [00:13<00:00, 10.14it/s]


avg loss at epoch: 137/350: 0.2336


100%|██████████| 140/140 [00:13<00:00, 10.31it/s]


avg loss at epoch: 138/350: 0.2249


100%|██████████| 140/140 [00:13<00:00, 10.11it/s]


avg loss at epoch: 139/350: 0.2278


100%|██████████| 140/140 [00:13<00:00, 10.30it/s]


avg loss at epoch: 140/350: 0.2270


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 141/350: 0.2369


100%|██████████| 140/140 [00:13<00:00, 10.24it/s]


avg loss at epoch: 142/350: 0.2283


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 143/350: 0.2284


100%|██████████| 140/140 [00:13<00:00, 10.29it/s]


avg loss at epoch: 144/350: 0.2301


100%|██████████| 140/140 [00:13<00:00, 10.34it/s]


avg loss at epoch: 145/350: 0.2298


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 146/350: 0.2415


100%|██████████| 140/140 [00:13<00:00, 10.37it/s]


avg loss at epoch: 147/350: 0.2369


100%|██████████| 140/140 [00:13<00:00, 10.36it/s]


avg loss at epoch: 148/350: 0.2192


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 149/350: 0.2241


100%|██████████| 140/140 [00:13<00:00, 10.22it/s]


avg loss at epoch: 150/350: 0.2282


100%|██████████| 140/140 [00:13<00:00, 10.05it/s]


avg loss at epoch: 151/350: 0.2194


100%|██████████| 140/140 [00:13<00:00, 10.17it/s]


avg loss at epoch: 152/350: 0.2255


100%|██████████| 140/140 [00:13<00:00, 10.17it/s]


avg loss at epoch: 153/350: 0.2286


100%|██████████| 140/140 [00:13<00:00, 10.24it/s]


avg loss at epoch: 154/350: 0.2215


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 155/350: 0.2257


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 156/350: 0.2249


100%|██████████| 140/140 [00:13<00:00, 10.24it/s]


avg loss at epoch: 157/350: 0.2353


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 158/350: 0.2198


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 159/350: 0.2219


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 160/350: 0.2170


100%|██████████| 140/140 [00:13<00:00, 10.25it/s]


avg loss at epoch: 161/350: 0.2289


100%|██████████| 140/140 [00:13<00:00, 10.37it/s]


avg loss at epoch: 162/350: 0.2433


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 163/350: 0.2310


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 164/350: 0.2285


100%|██████████| 140/140 [00:13<00:00, 10.24it/s]


avg loss at epoch: 165/350: 0.2254


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 166/350: 0.2216


100%|██████████| 140/140 [00:13<00:00, 10.22it/s]


avg loss at epoch: 167/350: 0.2152


100%|██████████| 140/140 [00:13<00:00, 10.01it/s]


avg loss at epoch: 168/350: 0.2347


100%|██████████| 140/140 [00:13<00:00, 10.33it/s]


avg loss at epoch: 169/350: 0.2215


100%|██████████| 140/140 [00:13<00:00, 10.12it/s]


avg loss at epoch: 170/350: 0.2232


100%|██████████| 140/140 [00:13<00:00, 10.15it/s]


avg loss at epoch: 171/350: 0.2272


100%|██████████| 140/140 [00:13<00:00, 10.15it/s]


avg loss at epoch: 172/350: 0.2283


100%|██████████| 140/140 [00:13<00:00, 10.22it/s]


avg loss at epoch: 173/350: 0.2293


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 174/350: 0.2210


100%|██████████| 140/140 [00:13<00:00, 10.27it/s]


avg loss at epoch: 175/350: 0.2145


100%|██████████| 140/140 [00:13<00:00, 10.23it/s]


avg loss at epoch: 176/350: 0.2212


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 177/350: 0.2288


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 178/350: 0.2244


100%|██████████| 140/140 [00:13<00:00, 10.17it/s]


avg loss at epoch: 179/350: 0.2222


100%|██████████| 140/140 [00:13<00:00, 10.15it/s]


avg loss at epoch: 180/350: 0.2185


100%|██████████| 140/140 [00:14<00:00,  9.97it/s]


avg loss at epoch: 181/350: 0.2238


100%|██████████| 140/140 [00:13<00:00, 10.01it/s]


avg loss at epoch: 182/350: 0.2283


100%|██████████| 140/140 [00:13<00:00, 10.15it/s]


avg loss at epoch: 183/350: 0.2261


100%|██████████| 140/140 [00:13<00:00, 10.10it/s]


avg loss at epoch: 184/350: 0.2184


100%|██████████| 140/140 [00:14<00:00,  9.83it/s]


avg loss at epoch: 185/350: 0.2266


100%|██████████| 140/140 [00:13<00:00, 10.16it/s]


avg loss at epoch: 186/350: 0.2072


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 187/350: 0.2323


100%|██████████| 140/140 [00:13<00:00, 10.02it/s]


avg loss at epoch: 188/350: 0.2213


100%|██████████| 140/140 [00:13<00:00, 10.14it/s]


avg loss at epoch: 189/350: 0.2091


100%|██████████| 140/140 [00:13<00:00, 10.03it/s]


avg loss at epoch: 190/350: 0.2239


100%|██████████| 140/140 [00:14<00:00,  9.87it/s]


avg loss at epoch: 191/350: 0.2182


100%|██████████| 140/140 [00:14<00:00,  9.93it/s]


avg loss at epoch: 192/350: 0.2207


100%|██████████| 140/140 [00:13<00:00, 10.05it/s]


avg loss at epoch: 193/350: 0.2173


100%|██████████| 140/140 [00:13<00:00, 10.04it/s]


avg loss at epoch: 194/350: 0.2269


100%|██████████| 140/140 [00:13<00:00, 10.13it/s]


avg loss at epoch: 195/350: 0.2221


100%|██████████| 140/140 [00:13<00:00, 10.13it/s]


avg loss at epoch: 196/350: 0.2136


100%|██████████| 140/140 [00:13<00:00, 10.13it/s]


avg loss at epoch: 197/350: 0.2167


100%|██████████| 140/140 [00:13<00:00, 10.11it/s]


avg loss at epoch: 198/350: 0.2306


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 199/350: 0.2168


100%|██████████| 140/140 [00:13<00:00, 10.16it/s]


avg loss at epoch: 200/350: 0.2209


100%|██████████| 140/140 [00:13<00:00, 10.10it/s]


avg loss at epoch: 201/350: 0.1183


100%|██████████| 140/140 [00:13<00:00, 10.01it/s]


avg loss at epoch: 202/350: 0.0819


100%|██████████| 140/140 [00:13<00:00, 10.17it/s]


avg loss at epoch: 203/350: 0.0719


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 204/350: 0.0651


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 205/350: 0.0602


100%|██████████| 140/140 [00:13<00:00, 10.14it/s]


avg loss at epoch: 206/350: 0.0562


100%|██████████| 140/140 [00:13<00:00, 10.11it/s]


avg loss at epoch: 207/350: 0.0535


100%|██████████| 140/140 [00:13<00:00, 10.12it/s]


avg loss at epoch: 208/350: 0.0511


100%|██████████| 140/140 [00:14<00:00,  9.89it/s]


avg loss at epoch: 209/350: 0.0488


100%|██████████| 140/140 [00:13<00:00, 10.31it/s]


avg loss at epoch: 210/350: 0.0472


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 211/350: 0.0449


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 212/350: 0.0439


100%|██████████| 140/140 [00:13<00:00, 10.25it/s]


avg loss at epoch: 213/350: 0.0424


100%|██████████| 140/140 [00:13<00:00, 10.32it/s]


avg loss at epoch: 214/350: 0.0412


100%|██████████| 140/140 [00:13<00:00, 10.24it/s]


avg loss at epoch: 215/350: 0.0405


100%|██████████| 140/140 [00:13<00:00, 10.25it/s]


avg loss at epoch: 216/350: 0.0393


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 217/350: 0.0377


100%|██████████| 140/140 [00:14<00:00,  9.98it/s]


avg loss at epoch: 218/350: 0.0375


100%|██████████| 140/140 [00:13<00:00, 10.23it/s]


avg loss at epoch: 219/350: 0.0359


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 220/350: 0.0354


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 221/350: 0.0352


100%|██████████| 140/140 [00:13<00:00, 10.10it/s]


avg loss at epoch: 222/350: 0.0344


100%|██████████| 140/140 [00:13<00:00, 10.11it/s]


avg loss at epoch: 223/350: 0.0338


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 224/350: 0.0333


100%|██████████| 140/140 [00:13<00:00, 10.13it/s]


avg loss at epoch: 225/350: 0.0329


100%|██████████| 140/140 [00:13<00:00, 10.07it/s]


avg loss at epoch: 226/350: 0.0322


100%|██████████| 140/140 [00:13<00:00, 10.03it/s]


avg loss at epoch: 227/350: 0.0322


100%|██████████| 140/140 [00:13<00:00, 10.08it/s]


avg loss at epoch: 228/350: 0.0320


100%|██████████| 140/140 [00:13<00:00, 10.01it/s]


avg loss at epoch: 229/350: 0.0312


100%|██████████| 140/140 [00:13<00:00, 10.05it/s]


avg loss at epoch: 230/350: 0.0301


100%|██████████| 140/140 [00:13<00:00, 10.00it/s]


avg loss at epoch: 231/350: 0.0307


100%|██████████| 140/140 [00:13<00:00, 10.02it/s]


avg loss at epoch: 232/350: 0.0299


100%|██████████| 140/140 [00:13<00:00, 10.11it/s]


avg loss at epoch: 233/350: 0.0294


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 234/350: 0.0299


100%|██████████| 140/140 [00:14<00:00,  9.90it/s]


avg loss at epoch: 235/350: 0.0292


100%|██████████| 140/140 [00:13<00:00, 10.13it/s]


avg loss at epoch: 236/350: 0.0290


100%|██████████| 140/140 [00:13<00:00, 10.05it/s]


avg loss at epoch: 237/350: 0.0289


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 238/350: 0.0287


100%|██████████| 140/140 [00:13<00:00, 10.04it/s]


avg loss at epoch: 239/350: 0.0279


100%|██████████| 140/140 [00:13<00:00, 10.03it/s]


avg loss at epoch: 240/350: 0.0279


100%|██████████| 140/140 [00:13<00:00, 10.07it/s]


avg loss at epoch: 241/350: 0.0279


100%|██████████| 140/140 [00:13<00:00, 10.00it/s]


avg loss at epoch: 242/350: 0.0279


100%|██████████| 140/140 [00:14<00:00,  9.94it/s]


avg loss at epoch: 243/350: 0.0277


100%|██████████| 140/140 [00:14<00:00,  9.87it/s]


avg loss at epoch: 244/350: 0.0274


100%|██████████| 140/140 [00:14<00:00,  9.92it/s]


avg loss at epoch: 245/350: 0.0273


100%|██████████| 140/140 [00:13<00:00, 10.04it/s]


avg loss at epoch: 246/350: 0.0267


100%|██████████| 140/140 [00:13<00:00, 10.07it/s]


avg loss at epoch: 247/350: 0.0273


100%|██████████| 140/140 [00:13<00:00, 10.05it/s]


avg loss at epoch: 248/350: 0.0267


100%|██████████| 140/140 [00:13<00:00, 10.14it/s]


avg loss at epoch: 249/350: 0.0267


100%|██████████| 140/140 [00:13<00:00, 10.15it/s]


avg loss at epoch: 250/350: 0.0265


100%|██████████| 140/140 [00:13<00:00, 10.15it/s]


avg loss at epoch: 251/350: 0.0247


100%|██████████| 140/140 [00:14<00:00,  9.99it/s]


avg loss at epoch: 252/350: 0.0236


100%|██████████| 140/140 [00:13<00:00, 10.21it/s]


avg loss at epoch: 253/350: 0.0239


100%|██████████| 140/140 [00:13<00:00, 10.11it/s]


avg loss at epoch: 254/350: 0.0237


100%|██████████| 140/140 [00:13<00:00, 10.19it/s]


avg loss at epoch: 255/350: 0.0240


100%|██████████| 140/140 [00:13<00:00, 10.22it/s]


avg loss at epoch: 256/350: 0.0235


100%|██████████| 140/140 [00:13<00:00, 10.17it/s]


avg loss at epoch: 257/350: 0.0233


100%|██████████| 140/140 [00:13<00:00, 10.22it/s]


avg loss at epoch: 258/350: 0.0232


100%|██████████| 140/140 [00:13<00:00, 10.13it/s]


avg loss at epoch: 259/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.07it/s]


avg loss at epoch: 260/350: 0.0235


100%|██████████| 140/140 [00:13<00:00, 10.08it/s]


avg loss at epoch: 261/350: 0.0236


100%|██████████| 140/140 [00:13<00:00, 10.06it/s]


avg loss at epoch: 262/350: 0.0231


100%|██████████| 140/140 [00:14<00:00,  9.99it/s]


avg loss at epoch: 263/350: 0.0235


100%|██████████| 140/140 [00:13<00:00, 10.17it/s]


avg loss at epoch: 264/350: 0.0234


100%|██████████| 140/140 [00:13<00:00, 10.17it/s]


avg loss at epoch: 265/350: 0.0236


100%|██████████| 140/140 [00:14<00:00,  9.99it/s]


avg loss at epoch: 266/350: 0.0233


100%|██████████| 140/140 [00:13<00:00, 10.08it/s]


avg loss at epoch: 267/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.07it/s]


avg loss at epoch: 268/350: 0.0232


100%|██████████| 140/140 [00:14<00:00,  9.88it/s]


avg loss at epoch: 269/350: 0.0233


100%|██████████| 140/140 [00:13<00:00, 10.10it/s]


avg loss at epoch: 270/350: 0.0233


100%|██████████| 140/140 [00:13<00:00, 10.14it/s]


avg loss at epoch: 271/350: 0.0229


100%|██████████| 140/140 [00:13<00:00, 10.15it/s]


avg loss at epoch: 272/350: 0.0234


100%|██████████| 140/140 [00:13<00:00, 10.14it/s]


avg loss at epoch: 273/350: 0.0234


100%|██████████| 140/140 [00:13<00:00, 10.12it/s]


avg loss at epoch: 274/350: 0.0234


100%|██████████| 140/140 [00:13<00:00, 10.07it/s]


avg loss at epoch: 275/350: 0.0234


100%|██████████| 140/140 [00:13<00:00, 10.09it/s]


avg loss at epoch: 276/350: 0.0233


100%|██████████| 140/140 [00:13<00:00, 10.17it/s]


avg loss at epoch: 277/350: 0.0233


100%|██████████| 140/140 [00:14<00:00, 10.00it/s]


avg loss at epoch: 278/350: 0.0234


100%|██████████| 140/140 [00:13<00:00, 10.12it/s]


avg loss at epoch: 279/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.02it/s]


avg loss at epoch: 280/350: 0.0232


100%|██████████| 140/140 [00:14<00:00,  9.99it/s]


avg loss at epoch: 281/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 282/350: 0.0235


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 283/350: 0.0233


100%|██████████| 140/140 [00:13<00:00, 10.15it/s]


avg loss at epoch: 284/350: 0.0230


100%|██████████| 140/140 [00:14<00:00,  9.87it/s]


avg loss at epoch: 285/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.15it/s]


avg loss at epoch: 286/350: 0.0232


100%|██████████| 140/140 [00:13<00:00, 10.05it/s]


avg loss at epoch: 287/350: 0.0233


100%|██████████| 140/140 [00:13<00:00, 10.09it/s]


avg loss at epoch: 288/350: 0.0232


100%|██████████| 140/140 [00:13<00:00, 10.19it/s]


avg loss at epoch: 289/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 290/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.19it/s]


avg loss at epoch: 291/350: 0.0233


100%|██████████| 140/140 [00:13<00:00, 10.23it/s]


avg loss at epoch: 292/350: 0.0230


100%|██████████| 140/140 [00:13<00:00, 10.19it/s]


avg loss at epoch: 293/350: 0.0232


100%|██████████| 140/140 [00:13<00:00, 10.22it/s]


avg loss at epoch: 294/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.19it/s]


avg loss at epoch: 295/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.02it/s]


avg loss at epoch: 296/350: 0.0233


100%|██████████| 140/140 [00:13<00:00, 10.08it/s]


avg loss at epoch: 297/350: 0.0231


100%|██████████| 140/140 [00:14<00:00,  9.93it/s]


avg loss at epoch: 298/350: 0.0230


100%|██████████| 140/140 [00:13<00:00, 10.09it/s]


avg loss at epoch: 299/350: 0.0230


100%|██████████| 140/140 [00:13<00:00, 10.06it/s]


avg loss at epoch: 300/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 301/350: 0.0229


100%|██████████| 140/140 [00:14<00:00,  9.99it/s]


avg loss at epoch: 302/350: 0.0226


100%|██████████| 140/140 [00:13<00:00, 10.07it/s]


avg loss at epoch: 303/350: 0.0230


100%|██████████| 140/140 [00:13<00:00, 10.02it/s]


avg loss at epoch: 304/350: 0.0225


100%|██████████| 140/140 [00:13<00:00, 10.03it/s]


avg loss at epoch: 305/350: 0.0229


100%|██████████| 140/140 [00:13<00:00, 10.01it/s]


avg loss at epoch: 306/350: 0.0226


100%|██████████| 140/140 [00:13<00:00, 10.06it/s]


avg loss at epoch: 307/350: 0.0228


100%|██████████| 140/140 [00:13<00:00, 10.08it/s]


avg loss at epoch: 308/350: 0.0228


100%|██████████| 140/140 [00:13<00:00, 10.03it/s]


avg loss at epoch: 309/350: 0.0226


100%|██████████| 140/140 [00:14<00:00,  9.89it/s]


avg loss at epoch: 310/350: 0.0226


100%|██████████| 140/140 [00:14<00:00,  9.92it/s]


avg loss at epoch: 311/350: 0.0228


100%|██████████| 140/140 [00:14<00:00,  9.74it/s]


avg loss at epoch: 312/350: 0.0226


100%|██████████| 140/140 [00:14<00:00,  9.83it/s]


avg loss at epoch: 313/350: 0.0225


100%|██████████| 140/140 [00:14<00:00,  9.85it/s]


avg loss at epoch: 314/350: 0.0226


100%|██████████| 140/140 [00:13<00:00, 10.00it/s]


avg loss at epoch: 315/350: 0.0225


100%|██████████| 140/140 [00:13<00:00, 10.04it/s]


avg loss at epoch: 316/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.11it/s]


avg loss at epoch: 317/350: 0.0231


100%|██████████| 140/140 [00:13<00:00, 10.12it/s]


avg loss at epoch: 318/350: 0.0228


100%|██████████| 140/140 [00:17<00:00,  7.98it/s]


avg loss at epoch: 319/350: 0.0225


100%|██████████| 140/140 [00:18<00:00,  7.74it/s]


avg loss at epoch: 320/350: 0.0228


100%|██████████| 140/140 [00:18<00:00,  7.72it/s]


avg loss at epoch: 321/350: 0.0227


100%|██████████| 140/140 [00:16<00:00,  8.49it/s]


avg loss at epoch: 322/350: 0.0227


100%|██████████| 140/140 [00:14<00:00,  9.93it/s]


avg loss at epoch: 323/350: 0.0222


100%|██████████| 140/140 [00:13<00:00, 10.04it/s]


avg loss at epoch: 324/350: 0.0230


100%|██████████| 140/140 [00:13<00:00, 10.07it/s]


avg loss at epoch: 325/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.10it/s]


avg loss at epoch: 326/350: 0.0226


100%|██████████| 140/140 [00:13<00:00, 10.19it/s]


avg loss at epoch: 327/350: 0.0224


100%|██████████| 140/140 [00:13<00:00, 10.07it/s]


avg loss at epoch: 328/350: 0.0226


100%|██████████| 140/140 [00:13<00:00, 10.02it/s]


avg loss at epoch: 329/350: 0.0228


100%|██████████| 140/140 [00:13<00:00, 10.13it/s]


avg loss at epoch: 330/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.10it/s]


avg loss at epoch: 331/350: 0.0230


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 332/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 333/350: 0.0228


100%|██████████| 140/140 [00:13<00:00, 10.25it/s]


avg loss at epoch: 334/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.19it/s]


avg loss at epoch: 335/350: 0.0230


100%|██████████| 140/140 [00:13<00:00, 10.05it/s]


avg loss at epoch: 336/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.25it/s]


avg loss at epoch: 337/350: 0.0229


100%|██████████| 140/140 [00:13<00:00, 10.16it/s]


avg loss at epoch: 338/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 339/350: 0.0225


100%|██████████| 140/140 [00:13<00:00, 10.25it/s]


avg loss at epoch: 340/350: 0.0226


100%|██████████| 140/140 [00:13<00:00, 10.28it/s]


avg loss at epoch: 341/350: 0.0224


100%|██████████| 140/140 [00:13<00:00, 10.18it/s]


avg loss at epoch: 342/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.24it/s]


avg loss at epoch: 343/350: 0.0226


100%|██████████| 140/140 [00:13<00:00, 10.26it/s]


avg loss at epoch: 344/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 345/350: 0.0227


100%|██████████| 140/140 [00:13<00:00, 10.20it/s]


avg loss at epoch: 346/350: 0.0226


100%|██████████| 140/140 [00:13<00:00, 10.01it/s]


avg loss at epoch: 347/350: 0.0226


100%|██████████| 140/140 [00:13<00:00, 10.06it/s]


avg loss at epoch: 348/350: 0.0228


100%|██████████| 140/140 [00:13<00:00, 10.16it/s]


avg loss at epoch: 349/350: 0.0228


100%|██████████| 140/140 [00:14<00:00, 10.00it/s]

avg loss at epoch: 350/350: 0.0227





In [8]:
# Persist only the trained weights (state_dict), not the whole module object.
torch.save(obj=model.state_dict(), f='./all_cnn_retrained.pth')

In [None]:
# Top-1 accuracy on `test_dataloader` as built with is_retain=True,
# subset='test'. NOTE: a later cell rebinds `test_dataloader` to the
# is_retain=False split, so run this cell before that one.
n_correct, n_total = 0, 0

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for batch_x, batch_y in tqdm(test_dataloader):
        logits = model(batch_x.to(device))
        # Predicted class = argmax over the class dimension. Assumes batch_y
        # holds integer class indices (it is fed to CrossEntropyLoss in
        # training). Compare on `device` so CPU labels match device preds.
        preds = logits.argmax(dim=1)
        n_correct += (preds == batch_y.to(device)).sum().item()
        n_total += batch_x.shape[0]

# Avoid ZeroDivisionError on an empty loader.
accuracy = n_correct / n_total if n_total else float('nan')
print(f"retain-split test accuracy: {accuracy:.4f}")

In [None]:
# Rebind `test_dataloader` (previously the is_retain=True test loader) to the
# is_retain=False test split so the next cell evaluates on that split.
test_dataloader = input_pipeline_builder.get_dataloader_for_unlearn(
    subset='test',
    is_retain=False,
)

In [None]:
# Top-1 accuracy on `test_dataloader` after it was rebound to the
# is_retain=False, subset='test' split by the previous cell.
n_correct, n_total = 0, 0

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for batch_x, batch_y in tqdm(test_dataloader):
        logits = model(batch_x.to(device))
        # Predicted class = argmax over the class dimension. Assumes batch_y
        # holds integer class indices (it is fed to CrossEntropyLoss in
        # training). Compare on `device` so CPU labels match device preds.
        preds = logits.argmax(dim=1)
        n_correct += (preds == batch_y.to(device)).sum().item()
        n_total += batch_x.shape[0]

# Avoid ZeroDivisionError on an empty loader.
accuracy = n_correct / n_total if n_total else float('nan')
print(f"forget-split test accuracy: {accuracy:.4f}")