In [1]:
from google.colab import drive
drive.mount('/content/drive')

%cd ./drive/MyDrive/esc/prepare

Mounted at /content/drive
/content/drive/MyDrive/esc/prepare


In [2]:
%load_ext autoreload
%autoreload 2

import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('..')

from tqdm import tqdm
from network import ResNet
from dataset import InputPipeLineBuilder

In [3]:
num_epochs = 100
batch_size = 256

device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1e-3
weight_decay = 0.05

model = ResNet(head_input_dim=512).to(device)

for layer in model.modules():
  layer.requires_grad_ = True

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

sch1 = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=10)
sch2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[sch1, sch2],
    milestones=[10]
)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 208MB/s]


In [4]:
input_pipeline_builder = InputPipeLineBuilder(batch_size=batch_size, select_forget_concept=True, dataset='cifar100')

train_dataloader = input_pipeline_builder.get_dataloader_for_unlearn(is_retain=True, subset='train')
valid_dataloader = input_pipeline_builder.get_dataloader_for_unlearn(is_retain=True, subset='valid')
test_dataloader = input_pipeline_builder.get_dataloader_for_unlearn(is_retain=True, subset='test')

In [5]:
for epoch in range(num_epochs):
    losses = []
    model.train()
    for batch in tqdm(train_dataloader):
        train_x, train_y = batch
        logits = model(train_x.to(device))

        loss = loss_fn(logits, train_y.to(device))
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.cpu().item())

    print(f"\tavg loss at epoch: {epoch+1}/{num_epochs}: {sum(losses) / len(losses):.4f}")

    model.eval()
    losses = []
    rcorrect, total = 0, 0
    for batch in valid_dataloader:
      valid_x, valid_y = batch
      logits = model(valid_x.to(device))

      loss = loss_fn(logits, valid_y.to(device))
      losses.append(loss.cpu().item())
      pred_labels = torch.argmax(logits, dim=1)
      num_preds = torch.sum((pred_labels == valid_y.to(device))).item()

      rcorrect += num_preds
      total += valid_x.shape[0]
    print(f"\tvalid acc at epoch: {epoch+1}/{num_epochs} : {rcorrect / total:.4f}")
    print(f"\tavg loss at epoch: {epoch+1}/{num_epochs}: {sum(losses) / len(losses):.4f}")

    scheduler.step()

100%|██████████| 140/140 [00:11<00:00, 11.70it/s]


	avg loss at epoch: 1/100: 2.6964
	valid acc at epoch: 1/100 : 0.3812
	avg loss at epoch: 1/100: 2.4053


100%|██████████| 140/140 [00:10<00:00, 13.24it/s]


	avg loss at epoch: 2/100: 1.8565
	valid acc at epoch: 2/100 : 0.4389
	avg loss at epoch: 2/100: 2.1426


100%|██████████| 140/140 [00:10<00:00, 13.37it/s]


	avg loss at epoch: 3/100: 1.5530
	valid acc at epoch: 3/100 : 0.4713
	avg loss at epoch: 3/100: 1.9960


100%|██████████| 140/140 [00:10<00:00, 13.27it/s]


	avg loss at epoch: 4/100: 1.3316
	valid acc at epoch: 4/100 : 0.5007
	avg loss at epoch: 4/100: 1.9106


100%|██████████| 140/140 [00:10<00:00, 13.60it/s]


	avg loss at epoch: 5/100: 1.1650
	valid acc at epoch: 5/100 : 0.4952
	avg loss at epoch: 5/100: 1.9587


100%|██████████| 140/140 [00:09<00:00, 14.17it/s]


	avg loss at epoch: 6/100: 1.0129
	valid acc at epoch: 6/100 : 0.5052
	avg loss at epoch: 6/100: 1.9594


100%|██████████| 140/140 [00:10<00:00, 13.61it/s]


	avg loss at epoch: 7/100: 0.9037
	valid acc at epoch: 7/100 : 0.5197
	avg loss at epoch: 7/100: 1.9655


100%|██████████| 140/140 [00:10<00:00, 13.30it/s]


	avg loss at epoch: 8/100: 0.7820
	valid acc at epoch: 8/100 : 0.5063
	avg loss at epoch: 8/100: 2.0697


100%|██████████| 140/140 [00:10<00:00, 13.13it/s]


	avg loss at epoch: 9/100: 0.6940
	valid acc at epoch: 9/100 : 0.5351
	avg loss at epoch: 9/100: 2.0051


100%|██████████| 140/140 [00:10<00:00, 13.19it/s]


	avg loss at epoch: 10/100: 0.6224
	valid acc at epoch: 10/100 : 0.5380
	avg loss at epoch: 10/100: 2.0184


100%|██████████| 140/140 [00:10<00:00, 13.47it/s]


	avg loss at epoch: 11/100: 0.5535
	valid acc at epoch: 11/100 : 0.5387
	avg loss at epoch: 11/100: 2.0979


100%|██████████| 140/140 [00:10<00:00, 13.44it/s]


	avg loss at epoch: 12/100: 0.4877
	valid acc at epoch: 12/100 : 0.5180
	avg loss at epoch: 12/100: 2.2578


100%|██████████| 140/140 [00:10<00:00, 13.38it/s]


	avg loss at epoch: 13/100: 0.4337
	valid acc at epoch: 13/100 : 0.5514
	avg loss at epoch: 13/100: 2.0513


100%|██████████| 140/140 [00:10<00:00, 13.44it/s]


	avg loss at epoch: 14/100: 0.4119
	valid acc at epoch: 14/100 : 0.5492
	avg loss at epoch: 14/100: 2.1378


100%|██████████| 140/140 [00:09<00:00, 14.06it/s]


	avg loss at epoch: 15/100: 0.3575
	valid acc at epoch: 15/100 : 0.5376
	avg loss at epoch: 15/100: 2.2278


100%|██████████| 140/140 [00:10<00:00, 13.73it/s]


	avg loss at epoch: 16/100: 0.3360
	valid acc at epoch: 16/100 : 0.5333
	avg loss at epoch: 16/100: 2.3858


100%|██████████| 140/140 [00:10<00:00, 13.39it/s]


	avg loss at epoch: 17/100: 0.3111
	valid acc at epoch: 17/100 : 0.5380
	avg loss at epoch: 17/100: 2.3189


100%|██████████| 140/140 [00:10<00:00, 13.35it/s]


	avg loss at epoch: 18/100: 0.2949
	valid acc at epoch: 18/100 : 0.5546
	avg loss at epoch: 18/100: 2.3044


100%|██████████| 140/140 [00:10<00:00, 13.43it/s]


	avg loss at epoch: 19/100: 0.2637
	valid acc at epoch: 19/100 : 0.5354
	avg loss at epoch: 19/100: 2.4912


100%|██████████| 140/140 [00:10<00:00, 13.49it/s]


	avg loss at epoch: 20/100: 0.2410
	valid acc at epoch: 20/100 : 0.5484
	avg loss at epoch: 20/100: 2.3383


100%|██████████| 140/140 [00:10<00:00, 13.49it/s]


	avg loss at epoch: 21/100: 0.2280
	valid acc at epoch: 21/100 : 0.5460
	avg loss at epoch: 21/100: 2.4482


100%|██████████| 140/140 [00:10<00:00, 13.56it/s]


	avg loss at epoch: 22/100: 0.2303
	valid acc at epoch: 22/100 : 0.5433
	avg loss at epoch: 22/100: 2.4794


100%|██████████| 140/140 [00:09<00:00, 14.02it/s]


	avg loss at epoch: 23/100: 0.2331
	valid acc at epoch: 23/100 : 0.5288
	avg loss at epoch: 23/100: 2.5763


100%|██████████| 140/140 [00:10<00:00, 13.87it/s]


	avg loss at epoch: 24/100: 0.2084
	valid acc at epoch: 24/100 : 0.5374
	avg loss at epoch: 24/100: 2.5736


100%|██████████| 140/140 [00:10<00:00, 13.56it/s]


	avg loss at epoch: 25/100: 0.2028
	valid acc at epoch: 25/100 : 0.5394
	avg loss at epoch: 25/100: 2.5760


100%|██████████| 140/140 [00:10<00:00, 13.39it/s]


	avg loss at epoch: 26/100: 0.1950
	valid acc at epoch: 26/100 : 0.5477
	avg loss at epoch: 26/100: 2.4866


100%|██████████| 140/140 [00:10<00:00, 12.98it/s]


	avg loss at epoch: 27/100: 0.1833
	valid acc at epoch: 27/100 : 0.5253
	avg loss at epoch: 27/100: 2.7556


100%|██████████| 140/140 [00:10<00:00, 13.52it/s]


	avg loss at epoch: 28/100: 0.1946
	valid acc at epoch: 28/100 : 0.5048
	avg loss at epoch: 28/100: 2.8193


100%|██████████| 140/140 [00:10<00:00, 13.54it/s]


	avg loss at epoch: 29/100: 0.1734
	valid acc at epoch: 29/100 : 0.5433
	avg loss at epoch: 29/100: 2.5670


100%|██████████| 140/140 [00:10<00:00, 13.57it/s]


	avg loss at epoch: 30/100: 0.1616
	valid acc at epoch: 30/100 : 0.5471
	avg loss at epoch: 30/100: 2.5752


100%|██████████| 140/140 [00:10<00:00, 13.98it/s]


	avg loss at epoch: 31/100: 0.1558
	valid acc at epoch: 31/100 : 0.5454
	avg loss at epoch: 31/100: 2.6092


100%|██████████| 140/140 [00:09<00:00, 14.35it/s]


	avg loss at epoch: 32/100: 0.1618
	valid acc at epoch: 32/100 : 0.5209
	avg loss at epoch: 32/100: 2.8710


100%|██████████| 140/140 [00:10<00:00, 13.72it/s]


	avg loss at epoch: 33/100: 0.1608
	valid acc at epoch: 33/100 : 0.5269
	avg loss at epoch: 33/100: 2.7368


100%|██████████| 140/140 [00:10<00:00, 13.35it/s]


	avg loss at epoch: 34/100: 0.1605
	valid acc at epoch: 34/100 : 0.5314
	avg loss at epoch: 34/100: 2.7466


100%|██████████| 140/140 [00:10<00:00, 13.22it/s]


	avg loss at epoch: 35/100: 0.1563
	valid acc at epoch: 35/100 : 0.5404
	avg loss at epoch: 35/100: 2.7312


100%|██████████| 140/140 [00:10<00:00, 13.21it/s]


	avg loss at epoch: 36/100: 0.1518
	valid acc at epoch: 36/100 : 0.5462
	avg loss at epoch: 36/100: 2.6588


100%|██████████| 140/140 [00:10<00:00, 13.40it/s]


	avg loss at epoch: 37/100: 0.1454
	valid acc at epoch: 37/100 : 0.5507
	avg loss at epoch: 37/100: 2.6355


100%|██████████| 140/140 [00:10<00:00, 13.42it/s]


	avg loss at epoch: 38/100: 0.1551
	valid acc at epoch: 38/100 : 0.5364
	avg loss at epoch: 38/100: 2.7782


100%|██████████| 140/140 [00:10<00:00, 13.62it/s]


	avg loss at epoch: 39/100: 0.1560
	valid acc at epoch: 39/100 : 0.5189
	avg loss at epoch: 39/100: 2.8949


100%|██████████| 140/140 [00:10<00:00, 13.58it/s]


	avg loss at epoch: 40/100: 0.1506
	valid acc at epoch: 40/100 : 0.5574
	avg loss at epoch: 40/100: 2.5878


100%|██████████| 140/140 [00:10<00:00, 13.67it/s]


	avg loss at epoch: 41/100: 0.0669
	valid acc at epoch: 41/100 : 0.5887
	avg loss at epoch: 41/100: 2.3239


100%|██████████| 140/140 [00:10<00:00, 13.36it/s]


	avg loss at epoch: 42/100: 0.0285
	valid acc at epoch: 42/100 : 0.5949
	avg loss at epoch: 42/100: 2.2936


100%|██████████| 140/140 [00:10<00:00, 13.21it/s]


	avg loss at epoch: 43/100: 0.0197
	valid acc at epoch: 43/100 : 0.5980
	avg loss at epoch: 43/100: 2.2791


100%|██████████| 140/140 [00:10<00:00, 13.42it/s]


	avg loss at epoch: 44/100: 0.0135
	valid acc at epoch: 44/100 : 0.6006
	avg loss at epoch: 44/100: 2.2894


100%|██████████| 140/140 [00:10<00:00, 13.55it/s]


	avg loss at epoch: 45/100: 0.0114
	valid acc at epoch: 45/100 : 0.6000
	avg loss at epoch: 45/100: 2.2849


100%|██████████| 140/140 [00:10<00:00, 13.44it/s]


	avg loss at epoch: 46/100: 0.0092
	valid acc at epoch: 46/100 : 0.6028
	avg loss at epoch: 46/100: 2.2847


100%|██████████| 140/140 [00:10<00:00, 13.21it/s]


	avg loss at epoch: 47/100: 0.0077
	valid acc at epoch: 47/100 : 0.6027
	avg loss at epoch: 47/100: 2.2858


100%|██████████| 140/140 [00:10<00:00, 13.67it/s]


	avg loss at epoch: 48/100: 0.0067
	valid acc at epoch: 48/100 : 0.5991
	avg loss at epoch: 48/100: 2.2869


100%|██████████| 140/140 [00:09<00:00, 14.03it/s]


	avg loss at epoch: 49/100: 0.0058
	valid acc at epoch: 49/100 : 0.6024
	avg loss at epoch: 49/100: 2.2882


100%|██████████| 140/140 [00:10<00:00, 13.91it/s]


	avg loss at epoch: 50/100: 0.0052
	valid acc at epoch: 50/100 : 0.6006
	avg loss at epoch: 50/100: 2.2883


100%|██████████| 140/140 [00:10<00:00, 13.45it/s]


	avg loss at epoch: 51/100: 0.0047
	valid acc at epoch: 51/100 : 0.6029
	avg loss at epoch: 51/100: 2.3039


100%|██████████| 140/140 [00:10<00:00, 13.39it/s]


	avg loss at epoch: 52/100: 0.0042
	valid acc at epoch: 52/100 : 0.6044
	avg loss at epoch: 52/100: 2.3050


100%|██████████| 140/140 [00:10<00:00, 13.15it/s]


	avg loss at epoch: 53/100: 0.0040
	valid acc at epoch: 53/100 : 0.6034
	avg loss at epoch: 53/100: 2.3134


100%|██████████| 140/140 [00:10<00:00, 13.24it/s]


	avg loss at epoch: 54/100: 0.0040
	valid acc at epoch: 54/100 : 0.6026
	avg loss at epoch: 54/100: 2.3187


100%|██████████| 140/140 [00:10<00:00, 13.47it/s]


	avg loss at epoch: 55/100: 0.0036
	valid acc at epoch: 55/100 : 0.6053
	avg loss at epoch: 55/100: 2.3275


100%|██████████| 140/140 [00:10<00:00, 13.38it/s]


	avg loss at epoch: 56/100: 0.0033
	valid acc at epoch: 56/100 : 0.6043
	avg loss at epoch: 56/100: 2.3279


100%|██████████| 140/140 [00:10<00:00, 13.88it/s]


	avg loss at epoch: 57/100: 0.0033
	valid acc at epoch: 57/100 : 0.6043
	avg loss at epoch: 57/100: 2.3355


100%|██████████| 140/140 [00:09<00:00, 14.10it/s]


	avg loss at epoch: 58/100: 0.0029
	valid acc at epoch: 58/100 : 0.6043
	avg loss at epoch: 58/100: 2.3359


100%|██████████| 140/140 [00:10<00:00, 13.63it/s]


	avg loss at epoch: 59/100: 0.0029
	valid acc at epoch: 59/100 : 0.6048
	avg loss at epoch: 59/100: 2.3418


100%|██████████| 140/140 [00:10<00:00, 13.26it/s]


	avg loss at epoch: 60/100: 0.0026
	valid acc at epoch: 60/100 : 0.6060
	avg loss at epoch: 60/100: 2.3517


100%|██████████| 140/140 [00:10<00:00, 13.44it/s]


	avg loss at epoch: 61/100: 0.0027
	valid acc at epoch: 61/100 : 0.6056
	avg loss at epoch: 61/100: 2.3547


100%|██████████| 140/140 [00:10<00:00, 13.28it/s]


	avg loss at epoch: 62/100: 0.0025
	valid acc at epoch: 62/100 : 0.6072
	avg loss at epoch: 62/100: 2.3617


100%|██████████| 140/140 [00:10<00:00, 13.27it/s]


	avg loss at epoch: 63/100: 0.0025
	valid acc at epoch: 63/100 : 0.6051
	avg loss at epoch: 63/100: 2.3668


100%|██████████| 140/140 [00:10<00:00, 13.39it/s]


	avg loss at epoch: 64/100: 0.0024
	valid acc at epoch: 64/100 : 0.6056
	avg loss at epoch: 64/100: 2.3776


100%|██████████| 140/140 [00:10<00:00, 13.79it/s]


	avg loss at epoch: 65/100: 0.0023
	valid acc at epoch: 65/100 : 0.6069
	avg loss at epoch: 65/100: 2.3759


100%|██████████| 140/140 [00:10<00:00, 13.64it/s]


	avg loss at epoch: 66/100: 0.0021
	valid acc at epoch: 66/100 : 0.6036
	avg loss at epoch: 66/100: 2.3857


100%|██████████| 140/140 [00:10<00:00, 13.91it/s]


	avg loss at epoch: 67/100: 0.0022
	valid acc at epoch: 67/100 : 0.6069
	avg loss at epoch: 67/100: 2.3894


100%|██████████| 140/140 [00:10<00:00, 13.60it/s]


	avg loss at epoch: 68/100: 0.0020
	valid acc at epoch: 68/100 : 0.6068
	avg loss at epoch: 68/100: 2.3901


100%|██████████| 140/140 [00:10<00:00, 13.28it/s]


	avg loss at epoch: 69/100: 0.0020
	valid acc at epoch: 69/100 : 0.6047
	avg loss at epoch: 69/100: 2.4111


100%|██████████| 140/140 [00:10<00:00, 13.32it/s]


	avg loss at epoch: 70/100: 0.0017
	valid acc at epoch: 70/100 : 0.6062
	avg loss at epoch: 70/100: 2.4092


100%|██████████| 140/140 [00:10<00:00, 13.38it/s]


	avg loss at epoch: 71/100: 0.0017
	valid acc at epoch: 71/100 : 0.6071
	avg loss at epoch: 71/100: 2.4139


100%|██████████| 140/140 [00:10<00:00, 13.29it/s]


	avg loss at epoch: 72/100: 0.0015
	valid acc at epoch: 72/100 : 0.6070
	avg loss at epoch: 72/100: 2.4025


100%|██████████| 140/140 [00:10<00:00, 13.25it/s]


	avg loss at epoch: 73/100: 0.0015
	valid acc at epoch: 73/100 : 0.6070
	avg loss at epoch: 73/100: 2.4116


100%|██████████| 140/140 [00:10<00:00, 13.91it/s]


	avg loss at epoch: 74/100: 0.0014
	valid acc at epoch: 74/100 : 0.6074
	avg loss at epoch: 74/100: 2.4196


100%|██████████| 140/140 [00:09<00:00, 14.18it/s]


	avg loss at epoch: 75/100: 0.0013
	valid acc at epoch: 75/100 : 0.6058
	avg loss at epoch: 75/100: 2.4113


100%|██████████| 140/140 [00:10<00:00, 13.63it/s]


	avg loss at epoch: 76/100: 0.0013
	valid acc at epoch: 76/100 : 0.6072
	avg loss at epoch: 76/100: 2.4135


100%|██████████| 140/140 [00:10<00:00, 13.53it/s]


	avg loss at epoch: 77/100: 0.0013
	valid acc at epoch: 77/100 : 0.6082
	avg loss at epoch: 77/100: 2.4100


100%|██████████| 140/140 [00:10<00:00, 13.24it/s]


	avg loss at epoch: 78/100: 0.0012
	valid acc at epoch: 78/100 : 0.6061
	avg loss at epoch: 78/100: 2.4180


100%|██████████| 140/140 [00:10<00:00, 13.38it/s]


	avg loss at epoch: 79/100: 0.0012
	valid acc at epoch: 79/100 : 0.6060
	avg loss at epoch: 79/100: 2.4209


100%|██████████| 140/140 [00:10<00:00, 13.21it/s]


	avg loss at epoch: 80/100: 0.0012
	valid acc at epoch: 80/100 : 0.6064
	avg loss at epoch: 80/100: 2.4177


100%|██████████| 140/140 [00:10<00:00, 13.12it/s]


	avg loss at epoch: 81/100: 0.0013
	valid acc at epoch: 81/100 : 0.6056
	avg loss at epoch: 81/100: 2.4211


100%|██████████| 140/140 [00:10<00:00, 13.51it/s]


	avg loss at epoch: 82/100: 0.0012
	valid acc at epoch: 82/100 : 0.6061
	avg loss at epoch: 82/100: 2.4216


100%|██████████| 140/140 [00:10<00:00, 13.99it/s]


	avg loss at epoch: 83/100: 0.0012
	valid acc at epoch: 83/100 : 0.6072
	avg loss at epoch: 83/100: 2.4197


100%|██████████| 140/140 [00:09<00:00, 14.20it/s]


	avg loss at epoch: 84/100: 0.0012
	valid acc at epoch: 84/100 : 0.6079
	avg loss at epoch: 84/100: 2.4170


100%|██████████| 140/140 [00:10<00:00, 13.75it/s]


	avg loss at epoch: 85/100: 0.0011
	valid acc at epoch: 85/100 : 0.6084
	avg loss at epoch: 85/100: 2.4239


100%|██████████| 140/140 [00:10<00:00, 13.27it/s]


	avg loss at epoch: 86/100: 0.0012
	valid acc at epoch: 86/100 : 0.6051
	avg loss at epoch: 86/100: 2.4242


100%|██████████| 140/140 [00:10<00:00, 13.25it/s]


	avg loss at epoch: 87/100: 0.0011
	valid acc at epoch: 87/100 : 0.6070
	avg loss at epoch: 87/100: 2.4254


100%|██████████| 140/140 [00:10<00:00, 13.28it/s]


	avg loss at epoch: 88/100: 0.0011
	valid acc at epoch: 88/100 : 0.6082
	avg loss at epoch: 88/100: 2.4235


100%|██████████| 140/140 [00:10<00:00, 13.16it/s]


	avg loss at epoch: 89/100: 0.0011
	valid acc at epoch: 89/100 : 0.6092
	avg loss at epoch: 89/100: 2.4327


100%|██████████| 140/140 [00:10<00:00, 13.24it/s]


	avg loss at epoch: 90/100: 0.0011
	valid acc at epoch: 90/100 : 0.6070
	avg loss at epoch: 90/100: 2.4312


100%|██████████| 140/140 [00:10<00:00, 13.63it/s]


	avg loss at epoch: 91/100: 0.0010
	valid acc at epoch: 91/100 : 0.6073
	avg loss at epoch: 91/100: 2.4319


100%|██████████| 140/140 [00:10<00:00, 13.98it/s]


	avg loss at epoch: 92/100: 0.0011
	valid acc at epoch: 92/100 : 0.6077
	avg loss at epoch: 92/100: 2.4367


100%|██████████| 140/140 [00:10<00:00, 13.33it/s]


	avg loss at epoch: 93/100: 0.0011
	valid acc at epoch: 93/100 : 0.6078
	avg loss at epoch: 93/100: 2.4353


100%|██████████| 140/140 [00:10<00:00, 13.59it/s]


	avg loss at epoch: 94/100: 0.0011
	valid acc at epoch: 94/100 : 0.6061
	avg loss at epoch: 94/100: 2.4419


100%|██████████| 140/140 [00:10<00:00, 13.38it/s]


	avg loss at epoch: 95/100: 0.0011
	valid acc at epoch: 95/100 : 0.6073
	avg loss at epoch: 95/100: 2.4456


100%|██████████| 140/140 [00:10<00:00, 13.43it/s]


	avg loss at epoch: 96/100: 0.0010
	valid acc at epoch: 96/100 : 0.6051
	avg loss at epoch: 96/100: 2.4409


100%|██████████| 140/140 [00:10<00:00, 13.35it/s]


	avg loss at epoch: 97/100: 0.0010
	valid acc at epoch: 97/100 : 0.6079
	avg loss at epoch: 97/100: 2.4571


100%|██████████| 140/140 [00:10<00:00, 13.39it/s]


	avg loss at epoch: 98/100: 0.0009
	valid acc at epoch: 98/100 : 0.6050
	avg loss at epoch: 98/100: 2.4497


100%|██████████| 140/140 [00:10<00:00, 13.35it/s]


	avg loss at epoch: 99/100: 0.0009
	valid acc at epoch: 99/100 : 0.6067
	avg loss at epoch: 99/100: 2.4636


100%|██████████| 140/140 [00:09<00:00, 14.11it/s]


	avg loss at epoch: 100/100: 0.0010
	valid acc at epoch: 100/100 : 0.6078
	avg loss at epoch: 100/100: 2.4553


In [None]:
torch.save(model.state_dict(), './resnet_18_retrained.pth')