In [1]:
import torch
import torch.nn as nn
from build_dataloader import get_train_loader, get_test_loader
from custom_resnet import CustomRes
from torchsummary import summary
from trainer import train, test
from torch.optim import SGD,Adam
from torch.optim.lr_scheduler import OneCycleLR
from ignite.handlers import FastaiLRFinder



In [2]:
# Data loading configuration shared by the training and evaluation loaders.
BATCHSIZE = 512
DATA_DIR = './data'

# Training batches are reshuffled so each epoch sees the samples in a new order.
train_dataloader = get_train_loader(
    data_dir=DATA_DIR,
    train=True,
    download=True,
    shuffle=True,
    batch_size=BATCHSIZE,
)

# Evaluation data is NOT shuffled: ordering has no benefit at test time, and a
# fixed order keeps batch composition (and therefore any per-batch-averaged
# metric) identical from run to run.
test_dataloader = get_test_loader(
    data_dir=DATA_DIR,
    train=False,
    download=True,
    shuffle=False,
    batch_size=BATCHSIZE,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Instantiate the network and print a per-layer table of output shapes and
# parameter counts for a single 3-channel 32x32 input.
input_shape = (3, 32, 32)
model = CustomRes()
summary(model, input_size=input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4          [-1, 128, 32, 32]          73,728
         MaxPool2d-5          [-1, 128, 16, 16]               0
       BatchNorm2d-6          [-1, 128, 16, 16]             256
              ReLU-7          [-1, 128, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]         147,456
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]         147,456
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128, 16, 16]               0
           Conv2d-14          [-1, 256,

In [4]:
# --- Training configuration -------------------------------------------------
EPOCHS = 25
MAX_LR = 0.01  # peak learning rate reached by the one-cycle schedule

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on CUDA machines and on CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Place the model on the target device before the optimizer is built.
# nn.Module.to() is idempotent, so this is harmless if train()/test() also
# move the model themselves.
model.to(device)

criterion = nn.CrossEntropyLoss()

# OneCycleLR takes control of the optimizer's learning rate as soon as it is
# constructed (starting at max_lr / div_factor and peaking at max_lr), so the
# value passed to Adam here is only a placeholder. It is kept equal to MAX_LR
# to avoid suggesting a learning rate that is never actually used.
optimizer = Adam(params=model.parameters(), lr=MAX_LR)
scheduler = OneCycleLR(
    optimizer=optimizer,
    max_lr=MAX_LR,
    epochs=EPOCHS,
    # One scheduler step per training batch is budgeted here.
    # NOTE(review): presumably train() calls scheduler.step() once per batch — confirm.
    steps_per_epoch=len(train_dataloader),
)

# --- Train / evaluate -------------------------------------------------------
# Each epoch runs one pass over the training loader followed by one pass over
# the test loader; the helpers are expected to report the metrics themselves.
for e in range(EPOCHS):
    train(
        model,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        device=device,
        epoch=e,
        scheduler=scheduler,
        criterion=criterion,
    )
    test(
        model,
        test_dataloader=test_dataloader,
        device=device,
        epoch=e,
        criterion=criterion,
    )

  0%|          | 0/98 [00:00<?, ?it/s]

100%|██████████| 98/98 [00:35<00:00,  2.76it/s]


at epoch:0
avg_train_loss:1.416864037513733,avg_train_acc:0.5007924437522888
avg_test_loss:1.0424970388412476,avg_test_acc:0.6290326118469238


100%|██████████| 98/98 [00:35<00:00,  2.79it/s]


at epoch:1
avg_train_loss:0.870063066482544,avg_train_acc:0.694782018661499
avg_test_loss:0.7542669773101807,avg_test_acc:0.7387752532958984


100%|██████████| 98/98 [00:35<00:00,  2.80it/s]


at epoch:2
avg_train_loss:0.6841527819633484,avg_train_acc:0.7617149353027344
avg_test_loss:0.8940948247909546,avg_test_acc:0.7186006307601929


100%|██████████| 98/98 [00:35<00:00,  2.74it/s]


at epoch:3
avg_train_loss:0.5842074155807495,avg_train_acc:0.8004148602485657
avg_test_loss:0.6082676649093628,avg_test_acc:0.8000574111938477


100%|██████████| 98/98 [00:35<00:00,  2.74it/s]


at epoch:4
avg_train_loss:0.5061802864074707,avg_train_acc:0.8233513236045837
avg_test_loss:0.7318302392959595,avg_test_acc:0.7713407278060913


100%|██████████| 98/98 [00:35<00:00,  2.75it/s]


at epoch:5
avg_train_loss:0.4498867988586426,avg_train_acc:0.8444418907165527
avg_test_loss:0.9364898800849915,avg_test_acc:0.7404354214668274


100%|██████████| 98/98 [00:35<00:00,  2.78it/s]


at epoch:6
avg_train_loss:0.3728276193141937,avg_train_acc:0.8716669678688049
avg_test_loss:0.8321179151535034,avg_test_acc:0.785776674747467


100%|██████████| 98/98 [00:35<00:00,  2.73it/s]


at epoch:7
avg_train_loss:0.28789547085762024,avg_train_acc:0.8995678424835205
avg_test_loss:0.6643854975700378,avg_test_acc:0.8079445958137512


100%|██████████| 98/98 [00:35<00:00,  2.72it/s]


at epoch:8
avg_train_loss:0.23387938737869263,avg_train_acc:0.918340802192688
avg_test_loss:0.5865095853805542,avg_test_acc:0.8551241159439087


100%|██████████| 98/98 [00:35<00:00,  2.75it/s]


at epoch:9
avg_train_loss:0.1836121827363968,avg_train_acc:0.9348778128623962
avg_test_loss:0.557948112487793,avg_test_acc:0.8626263737678528


100%|██████████| 98/98 [00:36<00:00,  2.66it/s]


at epoch:10
avg_train_loss:0.14377421140670776,avg_train_acc:0.950498104095459
avg_test_loss:0.7864389419555664,avg_test_acc:0.8303078413009644


100%|██████████| 98/98 [00:35<00:00,  2.75it/s]


at epoch:11
avg_train_loss:0.1361963152885437,avg_train_acc:0.9535767436027527
avg_test_loss:0.6290660500526428,avg_test_acc:0.8608857989311218


100%|██████████| 98/98 [00:35<00:00,  2.77it/s]


at epoch:12
avg_train_loss:0.1195443794131279,avg_train_acc:0.958939790725708
avg_test_loss:0.5508044958114624,avg_test_acc:0.8794519305229187


100%|██████████| 98/98 [00:36<00:00,  2.71it/s]


at epoch:13
avg_train_loss:0.11187813431024551,avg_train_acc:0.9618780016899109
avg_test_loss:0.5605612993240356,avg_test_acc:0.8741096258163452


100%|██████████| 98/98 [00:36<00:00,  2.68it/s]


at epoch:14
avg_train_loss:0.09965537488460541,avg_train_acc:0.9660651683807373
avg_test_loss:0.5913864374160767,avg_test_acc:0.8689797520637512


100%|██████████| 98/98 [00:35<00:00,  2.74it/s]


at epoch:15
avg_train_loss:0.09746357053518295,avg_train_acc:0.9672305583953857
avg_test_loss:0.5363028049468994,avg_test_acc:0.8801413774490356


100%|██████████| 98/98 [00:35<00:00,  2.77it/s]


at epoch:16
avg_train_loss:0.0889049768447876,avg_train_acc:0.9695215821266174
avg_test_loss:0.526117742061615,avg_test_acc:0.8824735879898071


100%|██████████| 98/98 [00:35<00:00,  2.79it/s]


at epoch:17
avg_train_loss:0.08385537564754486,avg_train_acc:0.9713456034660339
avg_test_loss:0.5468054413795471,avg_test_acc:0.8820945024490356


100%|██████████| 98/98 [00:35<00:00,  2.75it/s]


at epoch:18
avg_train_loss:0.07942540943622589,avg_train_acc:0.9732076525688171
avg_test_loss:0.5053330659866333,avg_test_acc:0.8854606747627258


100%|██████████| 98/98 [00:35<00:00,  2.79it/s]


at epoch:19
avg_train_loss:0.0700482726097107,avg_train_acc:0.9771062731742859
avg_test_loss:0.5017173290252686,avg_test_acc:0.8888614773750305


100%|██████████| 98/98 [00:36<00:00,  2.69it/s]


at epoch:20
avg_train_loss:0.06610319763422012,avg_train_acc:0.9777145981788635
avg_test_loss:0.5035909414291382,avg_test_acc:0.8894933462142944


100%|██████████| 98/98 [00:35<00:00,  2.76it/s]


at epoch:21
avg_train_loss:0.06657670438289642,avg_train_acc:0.9784406423568726
avg_test_loss:0.5007193088531494,avg_test_acc:0.8900677561759949


100%|██████████| 98/98 [00:35<00:00,  2.79it/s]


at epoch:22
avg_train_loss:0.06144772842526436,avg_train_acc:0.9802665710449219
avg_test_loss:0.49589958786964417,avg_test_acc:0.8899873495101929


100%|██████████| 98/98 [00:35<00:00,  2.79it/s]


at epoch:23
avg_train_loss:0.05870560184121132,avg_train_acc:0.9811035990715027
avg_test_loss:0.4957941472530365,avg_test_acc:0.8901540040969849


100%|██████████| 98/98 [00:35<00:00,  2.79it/s]


at epoch:24
avg_train_loss:0.05747431516647339,avg_train_acc:0.981749951839447
avg_test_loss:0.48817095160484314,avg_test_acc:0.891228199005127
