In [6]:
from google.colab import drive
drive.mount('/content/drive')

%cd ./drive/MyDrive/esc/prepare

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[Errno 2] No such file or directory: './drive/MyDrive/esc/prepare'
/content/drive/MyDrive/esc/prepare


In [7]:
%load_ext autoreload
%autoreload 2

import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('..')

from tqdm import tqdm
from network import AllCNN
from dataset import InputPipeLineBuilder

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
num_epochs = 350
batch_size = 256

device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 0.05
weight_decay = 0.001

model = AllCNN(head_input_dim=10).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

def lr_lambda(epoch):
    if epoch < 200:
        return 1.0
    elif epoch < 250:
        return 0.1
    elif epoch < 300:
        return 0.01
    else:
        return 0.001

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lr_lambda)

In [15]:
input_pipeline_builder = InputPipeLineBuilder(batch_size=batch_size, dataset='cifar10')

train_dataloader = input_pipeline_builder.get_dataloader(subset='train')
test_dataloader = input_pipeline_builder.get_dataloader(subset='test')

In [16]:
for epoch in range(num_epochs):
    losses = []

    model.train()
    for batch in tqdm(train_dataloader):
        batch_x, batch_y = batch
        logits = model(batch_x.to(device))

        loss = loss_fn(logits, batch_y.to(device))
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.cpu().item())

    scheduler.step()
    print(f"avg loss at epoch: {epoch+1}/{num_epochs}: {sum(losses) / len(losses):.4f}")


100%|██████████| 156/156 [00:15<00:00,  9.97it/s]


avg loss at epoch: 1/350: 1.6966


100%|██████████| 156/156 [00:16<00:00,  9.51it/s]


avg loss at epoch: 2/350: 1.2144


100%|██████████| 156/156 [00:16<00:00,  9.61it/s]


avg loss at epoch: 3/350: 1.0276


100%|██████████| 156/156 [00:15<00:00, 10.11it/s]


avg loss at epoch: 4/350: 0.9129


100%|██████████| 156/156 [00:15<00:00, 10.11it/s]


avg loss at epoch: 5/350: 0.8395


100%|██████████| 156/156 [00:15<00:00, 10.03it/s]


avg loss at epoch: 6/350: 0.7711


100%|██████████| 156/156 [00:15<00:00,  9.93it/s]


avg loss at epoch: 7/350: 0.7209


100%|██████████| 156/156 [00:15<00:00,  9.91it/s]


avg loss at epoch: 8/350: 0.6765


100%|██████████| 156/156 [00:15<00:00, 10.17it/s]


avg loss at epoch: 9/350: 0.6337


100%|██████████| 156/156 [00:15<00:00, 10.17it/s]


avg loss at epoch: 10/350: 0.5961


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 11/350: 0.5641


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 12/350: 0.5458


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 13/350: 0.5257


100%|██████████| 156/156 [00:15<00:00,  9.99it/s]


avg loss at epoch: 14/350: 0.5022


100%|██████████| 156/156 [00:15<00:00, 10.21it/s]


avg loss at epoch: 15/350: 0.4904


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 16/350: 0.4733


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 17/350: 0.4662


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 18/350: 0.4522


100%|██████████| 156/156 [00:15<00:00,  9.77it/s]


avg loss at epoch: 19/350: 0.4398


100%|██████████| 156/156 [00:15<00:00,  9.84it/s]


avg loss at epoch: 20/350: 0.4401


100%|██████████| 156/156 [00:15<00:00, 10.16it/s]


avg loss at epoch: 21/350: 0.4247


100%|██████████| 156/156 [00:15<00:00,  9.93it/s]


avg loss at epoch: 22/350: 0.4098


100%|██████████| 156/156 [00:15<00:00, 10.08it/s]


avg loss at epoch: 23/350: 0.4108


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 24/350: 0.4095


100%|██████████| 156/156 [00:15<00:00,  9.83it/s]


avg loss at epoch: 25/350: 0.4023


100%|██████████| 156/156 [00:15<00:00, 10.01it/s]


avg loss at epoch: 26/350: 0.3937


100%|██████████| 156/156 [00:15<00:00, 10.04it/s]


avg loss at epoch: 27/350: 0.3880


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 28/350: 0.3794


100%|██████████| 156/156 [00:15<00:00, 10.04it/s]


avg loss at epoch: 29/350: 0.3735


100%|██████████| 156/156 [00:15<00:00, 10.08it/s]


avg loss at epoch: 30/350: 0.3805


100%|██████████| 156/156 [00:15<00:00,  9.93it/s]


avg loss at epoch: 31/350: 0.3709


100%|██████████| 156/156 [00:15<00:00, 10.10it/s]


avg loss at epoch: 32/350: 0.3696


100%|██████████| 156/156 [00:15<00:00, 10.06it/s]


avg loss at epoch: 33/350: 0.3692


100%|██████████| 156/156 [00:15<00:00, 10.06it/s]


avg loss at epoch: 34/350: 0.3550


100%|██████████| 156/156 [00:15<00:00, 10.07it/s]


avg loss at epoch: 35/350: 0.3579


100%|██████████| 156/156 [00:15<00:00,  9.91it/s]


avg loss at epoch: 36/350: 0.3572


100%|██████████| 156/156 [00:15<00:00,  9.78it/s]


avg loss at epoch: 37/350: 0.3460


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 38/350: 0.3470


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 39/350: 0.3520


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 40/350: 0.3438


100%|██████████| 156/156 [00:15<00:00, 10.18it/s]


avg loss at epoch: 41/350: 0.3450


100%|██████████| 156/156 [00:15<00:00,  9.94it/s]


avg loss at epoch: 42/350: 0.3412


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 43/350: 0.3345


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 44/350: 0.3375


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 45/350: 0.3396


100%|██████████| 156/156 [00:15<00:00, 10.29it/s]


avg loss at epoch: 46/350: 0.3238


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 47/350: 0.3205


100%|██████████| 156/156 [00:15<00:00,  9.99it/s]


avg loss at epoch: 48/350: 0.3236


100%|██████████| 156/156 [00:15<00:00, 10.07it/s]


avg loss at epoch: 49/350: 0.3158


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 50/350: 0.3245


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 51/350: 0.3174


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 52/350: 0.3226


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 53/350: 0.3212


100%|██████████| 156/156 [00:16<00:00,  9.73it/s]


avg loss at epoch: 54/350: 0.3173


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 55/350: 0.3263


100%|██████████| 156/156 [00:15<00:00, 10.14it/s]


avg loss at epoch: 56/350: 0.3113


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 57/350: 0.3189


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 58/350: 0.3140


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 59/350: 0.3100


100%|██████████| 156/156 [00:15<00:00,  9.92it/s]


avg loss at epoch: 60/350: 0.2989


100%|██████████| 156/156 [00:15<00:00, 10.17it/s]


avg loss at epoch: 61/350: 0.3164


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 62/350: 0.3128


100%|██████████| 156/156 [00:15<00:00, 10.10it/s]


avg loss at epoch: 63/350: 0.3165


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 64/350: 0.3019


100%|██████████| 156/156 [00:15<00:00, 10.16it/s]


avg loss at epoch: 65/350: 0.2939


100%|██████████| 156/156 [00:15<00:00,  9.88it/s]


avg loss at epoch: 66/350: 0.2983


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 67/350: 0.3023


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 68/350: 0.3123


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 69/350: 0.3022


100%|██████████| 156/156 [00:15<00:00, 10.21it/s]


avg loss at epoch: 70/350: 0.2993


100%|██████████| 156/156 [00:15<00:00,  9.83it/s]


avg loss at epoch: 71/350: 0.2913


100%|██████████| 156/156 [00:15<00:00,  9.92it/s]


avg loss at epoch: 72/350: 0.3118


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 73/350: 0.2923


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 74/350: 0.2970


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 75/350: 0.2943


100%|██████████| 156/156 [00:15<00:00, 10.18it/s]


avg loss at epoch: 76/350: 0.3000


100%|██████████| 156/156 [00:15<00:00, 10.16it/s]


avg loss at epoch: 77/350: 0.2987


100%|██████████| 156/156 [00:15<00:00,  9.94it/s]


avg loss at epoch: 78/350: 0.2976


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 79/350: 0.2906


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 80/350: 0.2876


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 81/350: 0.3075


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 82/350: 0.2900


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 83/350: 0.2910


100%|██████████| 156/156 [00:15<00:00,  9.94it/s]


avg loss at epoch: 84/350: 0.2967


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 85/350: 0.2931


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 86/350: 0.2873


100%|██████████| 156/156 [00:15<00:00, 10.18it/s]


avg loss at epoch: 87/350: 0.2864


100%|██████████| 156/156 [00:15<00:00,  9.96it/s]


avg loss at epoch: 88/350: 0.2990


100%|██████████| 156/156 [00:15<00:00,  9.89it/s]


avg loss at epoch: 89/350: 0.2827


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 90/350: 0.2900


100%|██████████| 156/156 [00:15<00:00, 10.35it/s]


avg loss at epoch: 91/350: 0.2854


100%|██████████| 156/156 [00:15<00:00, 10.34it/s]


avg loss at epoch: 92/350: 0.2827


100%|██████████| 156/156 [00:15<00:00, 10.39it/s]


avg loss at epoch: 93/350: 0.2891


100%|██████████| 156/156 [00:15<00:00, 10.34it/s]


avg loss at epoch: 94/350: 0.2915


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 95/350: 0.2811


100%|██████████| 156/156 [00:15<00:00, 10.11it/s]


avg loss at epoch: 96/350: 0.2859


100%|██████████| 156/156 [00:15<00:00, 10.33it/s]


avg loss at epoch: 97/350: 0.2763


100%|██████████| 156/156 [00:15<00:00, 10.34it/s]


avg loss at epoch: 98/350: 0.2880


100%|██████████| 156/156 [00:15<00:00, 10.36it/s]


avg loss at epoch: 99/350: 0.2835


100%|██████████| 156/156 [00:15<00:00, 10.39it/s]


avg loss at epoch: 100/350: 0.2811


100%|██████████| 156/156 [00:14<00:00, 10.41it/s]


avg loss at epoch: 101/350: 0.2880


100%|██████████| 156/156 [00:15<00:00, 10.03it/s]


avg loss at epoch: 102/350: 0.2897


100%|██████████| 156/156 [00:15<00:00, 10.21it/s]


avg loss at epoch: 103/350: 0.2733


100%|██████████| 156/156 [00:15<00:00, 10.35it/s]


avg loss at epoch: 104/350: 0.2972


100%|██████████| 156/156 [00:15<00:00, 10.35it/s]


avg loss at epoch: 105/350: 0.2782


100%|██████████| 156/156 [00:15<00:00, 10.17it/s]


avg loss at epoch: 106/350: 0.2817


100%|██████████| 156/156 [00:14<00:00, 10.41it/s]


avg loss at epoch: 107/350: 0.2818


100%|██████████| 156/156 [00:15<00:00, 10.09it/s]


avg loss at epoch: 108/350: 0.2868


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 109/350: 0.2816


100%|██████████| 156/156 [00:15<00:00, 10.38it/s]


avg loss at epoch: 110/350: 0.2817


100%|██████████| 156/156 [00:15<00:00, 10.31it/s]


avg loss at epoch: 111/350: 0.2772


100%|██████████| 156/156 [00:15<00:00, 10.30it/s]


avg loss at epoch: 112/350: 0.2748


100%|██████████| 156/156 [00:15<00:00, 10.29it/s]


avg loss at epoch: 113/350: 0.2856


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 114/350: 0.2784


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 115/350: 0.2736


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 116/350: 0.2798


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 117/350: 0.2803


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 118/350: 0.2776


100%|██████████| 156/156 [00:15<00:00, 10.35it/s]


avg loss at epoch: 119/350: 0.2810


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 120/350: 0.2745


100%|██████████| 156/156 [00:15<00:00, 10.04it/s]


avg loss at epoch: 121/350: 0.2777


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 122/350: 0.2860


100%|██████████| 156/156 [00:15<00:00, 10.11it/s]


avg loss at epoch: 123/350: 0.2754


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 124/350: 0.2869


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 125/350: 0.2798


100%|██████████| 156/156 [00:15<00:00, 10.14it/s]


avg loss at epoch: 126/350: 0.2718


100%|██████████| 156/156 [00:15<00:00,  9.99it/s]


avg loss at epoch: 127/350: 0.2690


100%|██████████| 156/156 [00:15<00:00, 10.31it/s]


avg loss at epoch: 128/350: 0.2768


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 129/350: 0.2791


100%|██████████| 156/156 [00:15<00:00, 10.34it/s]


avg loss at epoch: 130/350: 0.2712


100%|██████████| 156/156 [00:15<00:00, 10.39it/s]


avg loss at epoch: 131/350: 0.2756


100%|██████████| 156/156 [00:15<00:00, 10.40it/s]


avg loss at epoch: 132/350: 0.2711


100%|██████████| 156/156 [00:15<00:00, 10.09it/s]


avg loss at epoch: 133/350: 0.2674


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 134/350: 0.2704


100%|██████████| 156/156 [00:15<00:00, 10.29it/s]


avg loss at epoch: 135/350: 0.2733


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 136/350: 0.2758


100%|██████████| 156/156 [00:14<00:00, 10.48it/s]


avg loss at epoch: 137/350: 0.2683


100%|██████████| 156/156 [00:15<00:00, 10.16it/s]


avg loss at epoch: 138/350: 0.2773


100%|██████████| 156/156 [00:15<00:00, 10.17it/s]


avg loss at epoch: 139/350: 0.2784


100%|██████████| 156/156 [00:15<00:00, 10.07it/s]


avg loss at epoch: 140/350: 0.2708


100%|██████████| 156/156 [00:14<00:00, 10.40it/s]


avg loss at epoch: 141/350: 0.2746


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 142/350: 0.2737


100%|██████████| 156/156 [00:14<00:00, 10.45it/s]


avg loss at epoch: 143/350: 0.2735


100%|██████████| 156/156 [00:15<00:00, 10.32it/s]


avg loss at epoch: 144/350: 0.2696


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 145/350: 0.2649


100%|██████████| 156/156 [00:15<00:00, 10.09it/s]


avg loss at epoch: 146/350: 0.2785


100%|██████████| 156/156 [00:15<00:00, 10.38it/s]


avg loss at epoch: 147/350: 0.2812


100%|██████████| 156/156 [00:15<00:00, 10.35it/s]


avg loss at epoch: 148/350: 0.2734


100%|██████████| 156/156 [00:15<00:00, 10.35it/s]


avg loss at epoch: 149/350: 0.2728


100%|██████████| 156/156 [00:15<00:00, 10.36it/s]


avg loss at epoch: 150/350: 0.2716


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 151/350: 0.2689


100%|██████████| 156/156 [00:15<00:00, 10.11it/s]


avg loss at epoch: 152/350: 0.2830


100%|██████████| 156/156 [00:15<00:00, 10.38it/s]


avg loss at epoch: 153/350: 0.2688


100%|██████████| 156/156 [00:15<00:00, 10.36it/s]


avg loss at epoch: 154/350: 0.2612


100%|██████████| 156/156 [00:15<00:00, 10.33it/s]


avg loss at epoch: 155/350: 0.2672


100%|██████████| 156/156 [00:14<00:00, 10.48it/s]


avg loss at epoch: 156/350: 0.2797


100%|██████████| 156/156 [00:15<00:00, 10.02it/s]


avg loss at epoch: 157/350: 0.2681


100%|██████████| 156/156 [00:15<00:00, 10.07it/s]


avg loss at epoch: 158/350: 0.2713


100%|██████████| 156/156 [00:15<00:00, 10.37it/s]


avg loss at epoch: 159/350: 0.2770


100%|██████████| 156/156 [00:15<00:00, 10.38it/s]


avg loss at epoch: 160/350: 0.2684


100%|██████████| 156/156 [00:15<00:00, 10.38it/s]


avg loss at epoch: 161/350: 0.2733


100%|██████████| 156/156 [00:14<00:00, 10.42it/s]


avg loss at epoch: 162/350: 0.2656


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 163/350: 0.2712


100%|██████████| 156/156 [00:15<00:00, 10.06it/s]


avg loss at epoch: 164/350: 0.2630


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 165/350: 0.2733


100%|██████████| 156/156 [00:15<00:00, 10.38it/s]


avg loss at epoch: 166/350: 0.2638


100%|██████████| 156/156 [00:15<00:00, 10.38it/s]


avg loss at epoch: 167/350: 0.2700


100%|██████████| 156/156 [00:15<00:00, 10.38it/s]


avg loss at epoch: 168/350: 0.2583


100%|██████████| 156/156 [00:14<00:00, 10.45it/s]


avg loss at epoch: 169/350: 0.2735


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 170/350: 0.2676


100%|██████████| 156/156 [00:15<00:00, 10.14it/s]


avg loss at epoch: 171/350: 0.2595


100%|██████████| 156/156 [00:14<00:00, 10.41it/s]


avg loss at epoch: 172/350: 0.2666


100%|██████████| 156/156 [00:15<00:00, 10.35it/s]


avg loss at epoch: 173/350: 0.2724


100%|██████████| 156/156 [00:15<00:00, 10.30it/s]


avg loss at epoch: 174/350: 0.2662


100%|██████████| 156/156 [00:15<00:00, 10.15it/s]


avg loss at epoch: 175/350: 0.2589


100%|██████████| 156/156 [00:15<00:00, 10.36it/s]


avg loss at epoch: 176/350: 0.2778


100%|██████████| 156/156 [00:15<00:00, 10.02it/s]


avg loss at epoch: 177/350: 0.2692


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 178/350: 0.2630


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 179/350: 0.2660


100%|██████████| 156/156 [00:15<00:00, 10.31it/s]


avg loss at epoch: 180/350: 0.2656


100%|██████████| 156/156 [00:15<00:00, 10.29it/s]


avg loss at epoch: 181/350: 0.2620


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 182/350: 0.2697


100%|██████████| 156/156 [00:15<00:00, 10.08it/s]


avg loss at epoch: 183/350: 0.2580


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 184/350: 0.2652


100%|██████████| 156/156 [00:15<00:00, 10.34it/s]


avg loss at epoch: 185/350: 0.2671


100%|██████████| 156/156 [00:15<00:00, 10.34it/s]


avg loss at epoch: 186/350: 0.2655


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 187/350: 0.2676


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 188/350: 0.2612


100%|██████████| 156/156 [00:15<00:00, 10.03it/s]


avg loss at epoch: 189/350: 0.2653


100%|██████████| 156/156 [00:15<00:00,  9.96it/s]


avg loss at epoch: 190/350: 0.2687


100%|██████████| 156/156 [00:15<00:00, 10.07it/s]


avg loss at epoch: 191/350: 0.2649


100%|██████████| 156/156 [00:15<00:00,  9.98it/s]


avg loss at epoch: 192/350: 0.2656


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 193/350: 0.2660


100%|██████████| 156/156 [00:15<00:00, 10.04it/s]


avg loss at epoch: 194/350: 0.2614


100%|██████████| 156/156 [00:15<00:00,  9.91it/s]


avg loss at epoch: 195/350: 0.2692


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 196/350: 0.2592


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 197/350: 0.2651


100%|██████████| 156/156 [00:15<00:00,  9.94it/s]


avg loss at epoch: 198/350: 0.2596


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 199/350: 0.2679


100%|██████████| 156/156 [00:15<00:00, 10.02it/s]


avg loss at epoch: 200/350: 0.2667


100%|██████████| 156/156 [00:15<00:00, 10.05it/s]


avg loss at epoch: 201/350: 0.1501


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 202/350: 0.1103


100%|██████████| 156/156 [00:15<00:00, 10.16it/s]


avg loss at epoch: 203/350: 0.0964


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 204/350: 0.0890


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 205/350: 0.0815


100%|██████████| 156/156 [00:15<00:00, 10.06it/s]


avg loss at epoch: 206/350: 0.0773


100%|██████████| 156/156 [00:15<00:00, 10.11it/s]


avg loss at epoch: 207/350: 0.0735


100%|██████████| 156/156 [00:15<00:00, 10.31it/s]


avg loss at epoch: 208/350: 0.0701


100%|██████████| 156/156 [00:15<00:00, 10.09it/s]


avg loss at epoch: 209/350: 0.0665


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 210/350: 0.0636


100%|██████████| 156/156 [00:15<00:00, 10.35it/s]


avg loss at epoch: 211/350: 0.0619


100%|██████████| 156/156 [00:15<00:00, 10.07it/s]


avg loss at epoch: 212/350: 0.0594


100%|██████████| 156/156 [00:15<00:00, 10.10it/s]


avg loss at epoch: 213/350: 0.0575


100%|██████████| 156/156 [00:15<00:00, 10.29it/s]


avg loss at epoch: 214/350: 0.0560


100%|██████████| 156/156 [00:15<00:00, 10.31it/s]


avg loss at epoch: 215/350: 0.0552


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 216/350: 0.0532


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 217/350: 0.0514


100%|██████████| 156/156 [00:15<00:00, 10.08it/s]


avg loss at epoch: 218/350: 0.0498


100%|██████████| 156/156 [00:15<00:00, 10.05it/s]


avg loss at epoch: 219/350: 0.0487


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 220/350: 0.0483


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 221/350: 0.0469


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 222/350: 0.0458


100%|██████████| 156/156 [00:15<00:00, 10.33it/s]


avg loss at epoch: 223/350: 0.0452


100%|██████████| 156/156 [00:15<00:00, 10.14it/s]


avg loss at epoch: 224/350: 0.0444


100%|██████████| 156/156 [00:15<00:00, 10.08it/s]


avg loss at epoch: 225/350: 0.0432


100%|██████████| 156/156 [00:15<00:00, 10.03it/s]


avg loss at epoch: 226/350: 0.0437


100%|██████████| 156/156 [00:15<00:00, 10.35it/s]


avg loss at epoch: 227/350: 0.0418


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 228/350: 0.0413


100%|██████████| 156/156 [00:15<00:00, 10.30it/s]


avg loss at epoch: 229/350: 0.0403


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 230/350: 0.0405


100%|██████████| 156/156 [00:15<00:00, 10.08it/s]


avg loss at epoch: 231/350: 0.0396


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 232/350: 0.0390


100%|██████████| 156/156 [00:15<00:00, 10.37it/s]


avg loss at epoch: 233/350: 0.0387


100%|██████████| 156/156 [00:15<00:00, 10.31it/s]


avg loss at epoch: 234/350: 0.0377


100%|██████████| 156/156 [00:15<00:00, 10.30it/s]


avg loss at epoch: 235/350: 0.0380


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 236/350: 0.0367


100%|██████████| 156/156 [00:15<00:00,  9.99it/s]


avg loss at epoch: 237/350: 0.0373


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 238/350: 0.0362


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 239/350: 0.0364


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 240/350: 0.0359


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 241/350: 0.0361


100%|██████████| 156/156 [00:15<00:00, 10.14it/s]


avg loss at epoch: 242/350: 0.0358


100%|██████████| 156/156 [00:16<00:00,  9.71it/s]


avg loss at epoch: 243/350: 0.0352


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 244/350: 0.0346


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 245/350: 0.0348


100%|██████████| 156/156 [00:15<00:00, 10.21it/s]


avg loss at epoch: 246/350: 0.0345


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 247/350: 0.0341


100%|██████████| 156/156 [00:15<00:00, 10.03it/s]


avg loss at epoch: 248/350: 0.0341


100%|██████████| 156/156 [00:15<00:00,  9.91it/s]


avg loss at epoch: 249/350: 0.0330


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 250/350: 0.0334


100%|██████████| 156/156 [00:15<00:00, 10.32it/s]


avg loss at epoch: 251/350: 0.0301


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 252/350: 0.0287


100%|██████████| 156/156 [00:15<00:00, 10.21it/s]


avg loss at epoch: 253/350: 0.0285


100%|██████████| 156/156 [00:15<00:00, 10.09it/s]


avg loss at epoch: 254/350: 0.0284


100%|██████████| 156/156 [00:15<00:00, 10.01it/s]


avg loss at epoch: 255/350: 0.0286


100%|██████████| 156/156 [00:15<00:00, 10.21it/s]


avg loss at epoch: 256/350: 0.0281


100%|██████████| 156/156 [00:15<00:00, 10.16it/s]


avg loss at epoch: 257/350: 0.0280


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 258/350: 0.0284


100%|██████████| 156/156 [00:15<00:00, 10.29it/s]


avg loss at epoch: 259/350: 0.0282


100%|██████████| 156/156 [00:15<00:00, 10.02it/s]


avg loss at epoch: 260/350: 0.0284


100%|██████████| 156/156 [00:15<00:00,  9.82it/s]


avg loss at epoch: 261/350: 0.0283


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 262/350: 0.0277


100%|██████████| 156/156 [00:15<00:00, 10.18it/s]


avg loss at epoch: 263/350: 0.0279


100%|██████████| 156/156 [00:15<00:00, 10.18it/s]


avg loss at epoch: 264/350: 0.0282


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 265/350: 0.0279


100%|██████████| 156/156 [00:15<00:00,  9.97it/s]


avg loss at epoch: 266/350: 0.0280


100%|██████████| 156/156 [00:15<00:00,  9.99it/s]


avg loss at epoch: 267/350: 0.0280


100%|██████████| 156/156 [00:15<00:00, 10.18it/s]


avg loss at epoch: 268/350: 0.0279


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 269/350: 0.0275


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 270/350: 0.0277


100%|██████████| 156/156 [00:15<00:00, 10.17it/s]


avg loss at epoch: 271/350: 0.0274


100%|██████████| 156/156 [00:15<00:00, 10.06it/s]


avg loss at epoch: 272/350: 0.0275


100%|██████████| 156/156 [00:15<00:00, 10.17it/s]


avg loss at epoch: 273/350: 0.0279


100%|██████████| 156/156 [00:15<00:00, 10.30it/s]


avg loss at epoch: 274/350: 0.0279


100%|██████████| 156/156 [00:15<00:00, 10.32it/s]


avg loss at epoch: 275/350: 0.0278


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 276/350: 0.0276


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 277/350: 0.0275


100%|██████████| 156/156 [00:16<00:00,  9.63it/s]


avg loss at epoch: 278/350: 0.0273


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 279/350: 0.0275


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 280/350: 0.0274


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 281/350: 0.0273


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 282/350: 0.0276


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 283/350: 0.0274


100%|██████████| 156/156 [00:15<00:00, 10.05it/s]


avg loss at epoch: 284/350: 0.0277


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 285/350: 0.0277


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 286/350: 0.0277


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 287/350: 0.0275


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 288/350: 0.0278


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 289/350: 0.0272


100%|██████████| 156/156 [00:15<00:00,  9.93it/s]


avg loss at epoch: 290/350: 0.0276


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 291/350: 0.0273


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 292/350: 0.0275


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 293/350: 0.0273


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 294/350: 0.0271


100%|██████████| 156/156 [00:15<00:00,  9.85it/s]


avg loss at epoch: 295/350: 0.0273


100%|██████████| 156/156 [00:15<00:00,  9.93it/s]


avg loss at epoch: 296/350: 0.0276


100%|██████████| 156/156 [00:15<00:00, 10.21it/s]


avg loss at epoch: 297/350: 0.0276


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 298/350: 0.0276


100%|██████████| 156/156 [00:15<00:00, 10.28it/s]


avg loss at epoch: 299/350: 0.0274


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 300/350: 0.0272


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 301/350: 0.0269


100%|██████████| 156/156 [00:17<00:00,  8.84it/s]


avg loss at epoch: 302/350: 0.0270


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 303/350: 0.0270


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 304/350: 0.0273


100%|██████████| 156/156 [00:15<00:00, 10.19it/s]


avg loss at epoch: 305/350: 0.0270


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 306/350: 0.0270


100%|██████████| 156/156 [00:15<00:00, 10.06it/s]


avg loss at epoch: 307/350: 0.0275


100%|██████████| 156/156 [00:15<00:00, 10.05it/s]


avg loss at epoch: 308/350: 0.0267


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 309/350: 0.0268


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 310/350: 0.0271


100%|██████████| 156/156 [00:15<00:00, 10.31it/s]


avg loss at epoch: 311/350: 0.0270


100%|██████████| 156/156 [00:15<00:00,  9.99it/s]


avg loss at epoch: 312/350: 0.0268


100%|██████████| 156/156 [00:15<00:00,  9.91it/s]


avg loss at epoch: 313/350: 0.0273


100%|██████████| 156/156 [00:15<00:00, 10.10it/s]


avg loss at epoch: 314/350: 0.0268


100%|██████████| 156/156 [00:15<00:00, 10.22it/s]


avg loss at epoch: 315/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.20it/s]


avg loss at epoch: 316/350: 0.0271


100%|██████████| 156/156 [00:15<00:00, 10.29it/s]


avg loss at epoch: 317/350: 0.0264


100%|██████████| 156/156 [00:15<00:00, 10.16it/s]


avg loss at epoch: 318/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.00it/s]


avg loss at epoch: 319/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.04it/s]


avg loss at epoch: 320/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.24it/s]


avg loss at epoch: 321/350: 0.0270


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 322/350: 0.0273


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 323/350: 0.0267


100%|██████████| 156/156 [00:15<00:00, 10.17it/s]


avg loss at epoch: 324/350: 0.0268


100%|██████████| 156/156 [00:15<00:00,  9.86it/s]


avg loss at epoch: 325/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.15it/s]


avg loss at epoch: 326/350: 0.0271


100%|██████████| 156/156 [00:15<00:00, 10.16it/s]


avg loss at epoch: 327/350: 0.0265


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 328/350: 0.0266


100%|██████████| 156/156 [00:15<00:00, 10.27it/s]


avg loss at epoch: 329/350: 0.0271


100%|██████████| 156/156 [00:15<00:00,  9.84it/s]


avg loss at epoch: 330/350: 0.0271


100%|██████████| 156/156 [00:15<00:00,  9.94it/s]


avg loss at epoch: 331/350: 0.0268


100%|██████████| 156/156 [00:15<00:00, 10.18it/s]


avg loss at epoch: 332/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.25it/s]


avg loss at epoch: 333/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.17it/s]


avg loss at epoch: 334/350: 0.0270


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 335/350: 0.0267


100%|██████████| 156/156 [00:15<00:00, 10.16it/s]


avg loss at epoch: 336/350: 0.0268


100%|██████████| 156/156 [00:15<00:00, 10.01it/s]


avg loss at epoch: 337/350: 0.0265


100%|██████████| 156/156 [00:15<00:00, 10.26it/s]


avg loss at epoch: 338/350: 0.0266


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 339/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 340/350: 0.0267


100%|██████████| 156/156 [00:15<00:00, 10.23it/s]


avg loss at epoch: 341/350: 0.0271


100%|██████████| 156/156 [00:15<00:00, 10.14it/s]


avg loss at epoch: 342/350: 0.0267


100%|██████████| 156/156 [00:15<00:00,  9.81it/s]


avg loss at epoch: 343/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.13it/s]


avg loss at epoch: 344/350: 0.0268


100%|██████████| 156/156 [00:15<00:00, 10.21it/s]


avg loss at epoch: 345/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.12it/s]


avg loss at epoch: 346/350: 0.0269


100%|██████████| 156/156 [00:15<00:00,  9.90it/s]


avg loss at epoch: 347/350: 0.0272


100%|██████████| 156/156 [00:15<00:00,  9.79it/s]


avg loss at epoch: 348/350: 0.0263


100%|██████████| 156/156 [00:15<00:00, 10.00it/s]


avg loss at epoch: 349/350: 0.0269


100%|██████████| 156/156 [00:15<00:00, 10.18it/s]

avg loss at epoch: 350/350: 0.0270





In [17]:
torch.save(model.state_dict(), './all_cnn.pth')

In [18]:
rcorrect, total = 0, 0

model.eval()
for batch in tqdm(test_dataloader):
    batch_x, batch_y = batch
    with torch.no_grad():
      logits = model(batch_x.to(device))
    preds = torch.argmax(logits, dim=1)
    num_preds = torch.sum((preds == batch_y.to(device))).item()

    rcorrect += num_preds
    total += batch_x.shape[0]

print(f"\n\tavg accuracy at epoch: {epoch+1}/{num_epochs}: {rcorrect/total:.4f}")

100%|██████████| 40/40 [00:02<00:00, 17.57it/s]


	avg accuracy at epoch: 350/350: 0.8923



