In [16]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchsummary import summary

# Import custom modules
from model import Net
from dataset import get_cifar10_loaders
from train import train, test, reset_metrics, train_losses, test_losses, train_acc, test_acc


# ---- Configuration: everything a reader might want to tune lives here ----
SEED = 42                  # fixed seed so Restart & Run All is repeatable
BATCH_SIZE = 128
NUM_WORKERS = 4
DROPOUT_VALUE = 0.05
INPUT_SHAPE = (3, 32, 32)  # per-sample input size handed to torchsummary

# Seed PyTorch's RNGs before any stochastic work (weight init, shuffling,
# dropout). NOTE(review): some CUDA kernels can still be nondeterministic, so
# this improves repeatability rather than guaranteeing bit-identical runs.
torch.manual_seed(SEED)

# ---- Device selection: use the GPU when one is available, else the CPU ----
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
print(f"CUDA Available? {cuda}")
print(device)

# ---- Data ----
train_loader, test_loader = get_cifar10_loaders(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# ---- Model ----
model = Net(dropout_value=DROPOUT_VALUE).to(device)

# Print model summary (layer types, output shapes, parameter counts)
summary(model, input_size=INPUT_SHAPE)

CUDA Available? True
cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 14, 32, 32]             378
       BatchNorm2d-2           [-1, 14, 32, 32]              28
              ReLU-3           [-1, 14, 32, 32]               0
           Dropout-4           [-1, 14, 32, 32]               0
            Conv2d-5           [-1, 22, 32, 32]           2,772
       BatchNorm2d-6           [-1, 22, 32, 32]              44
              ReLU-7           [-1, 22, 32, 32]               0
           Dropout-8           [-1, 22, 32, 32]               0
            Conv2d-9           [-1, 22, 32, 32]             198
           Conv2d-10           [-1, 32, 32, 32]             704
DepthwiseSeparableConv-11           [-1, 32, 32, 32]               0
      BatchNorm2d-12           [-1, 32, 32, 32]              64
             ReLU-13           [-1, 32, 32, 32]               0
        

In [17]:
# ==================== TRAINING LOOP ====================
# Builds a fresh model, optimizer and LR scheduler, then trains for at most
# EPOCHS epochs, stopping early once the latest test accuracy reaches
# TARGET_ACCURACY.

EPOCHS = 50
TARGET_ACCURACY = 85.0  # percent; early-stopping threshold

# NOTE(review): this re-instantiates Net with its default arguments, whereas
# the model summarized in the setup cell was built with an explicit
# dropout_value — confirm the default is the configuration you mean to train.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# Decay the learning rate by a factor of 0.6 every 4 epochs.
scheduler = StepLR(optimizer, step_size=4, gamma=0.6)

# NOTE(review): reset_metrics is imported but never called; if train()/test()
# append to the imported history lists, re-running this cell will mix runs —
# verify and reset here if so.
for epoch in range(EPOCHS):
    print("EPOCH:", epoch)
    train(model, device, train_loader, optimizer, epoch)
    scheduler.step()  # advance the LR schedule once per epoch
    test(model, device, test_loader)

    # Check if target reached. Presumably test() appends this epoch's accuracy
    # (in percent) to test_acc — verify against train.py.
    if test_acc[-1] >= TARGET_ACCURACY:
        print(f"\n{'='*80}")
        print(f"🎉 TARGET REACHED! Test Accuracy: {test_acc[-1]:.2f}% >= {TARGET_ACCURACY}%")
        print(f"{'='*80}\n")
        break

EPOCH: 0


Loss=1.3296 Batch_id=390 Accuracy=37.41: 100%|██████████| 391/391 [00:24<00:00, 15.74it/s]



Test set: Average loss: 1.4308, Accuracy: 4952/10000 (49.52%)

EPOCH: 1


Loss=1.2881 Batch_id=390 Accuracy=52.97: 100%|██████████| 391/391 [00:24<00:00, 15.99it/s]



Test set: Average loss: 1.2258, Accuracy: 5810/10000 (58.10%)

EPOCH: 2


Loss=1.0592 Batch_id=390 Accuracy=59.45: 100%|██████████| 391/391 [00:23<00:00, 16.32it/s]



Test set: Average loss: 1.1153, Accuracy: 6247/10000 (62.47%)

EPOCH: 3


Loss=1.3670 Batch_id=390 Accuracy=63.66: 100%|██████████| 391/391 [00:24<00:00, 16.16it/s]



Test set: Average loss: 0.9671, Accuracy: 6680/10000 (66.80%)

EPOCH: 4


Loss=0.8854 Batch_id=390 Accuracy=67.89: 100%|██████████| 391/391 [00:25<00:00, 15.59it/s]



Test set: Average loss: 0.7961, Accuracy: 7279/10000 (72.79%)

EPOCH: 5


Loss=0.9386 Batch_id=390 Accuracy=69.87: 100%|██████████| 391/391 [00:24<00:00, 16.23it/s]



Test set: Average loss: 0.7413, Accuracy: 7485/10000 (74.85%)

EPOCH: 6


Loss=1.0945 Batch_id=390 Accuracy=70.93: 100%|██████████| 391/391 [00:24<00:00, 16.23it/s]



Test set: Average loss: 0.6386, Accuracy: 7820/10000 (78.20%)

EPOCH: 7


Loss=0.9583 Batch_id=390 Accuracy=71.63: 100%|██████████| 391/391 [00:23<00:00, 16.62it/s]



Test set: Average loss: 0.6069, Accuracy: 7909/10000 (79.09%)

EPOCH: 8


Loss=0.8761 Batch_id=390 Accuracy=73.76: 100%|██████████| 391/391 [00:23<00:00, 16.30it/s]



Test set: Average loss: 0.5760, Accuracy: 7996/10000 (79.96%)

EPOCH: 9


Loss=0.6714 Batch_id=390 Accuracy=74.60: 100%|██████████| 391/391 [00:23<00:00, 16.40it/s]



Test set: Average loss: 0.5288, Accuracy: 8172/10000 (81.72%)

EPOCH: 10


Loss=0.7729 Batch_id=390 Accuracy=75.22: 100%|██████████| 391/391 [00:23<00:00, 16.38it/s]



Test set: Average loss: 0.5577, Accuracy: 8074/10000 (80.74%)

EPOCH: 11


Loss=0.6795 Batch_id=390 Accuracy=75.51: 100%|██████████| 391/391 [00:24<00:00, 16.29it/s]



Test set: Average loss: 0.5381, Accuracy: 8175/10000 (81.75%)

EPOCH: 12


Loss=0.5145 Batch_id=390 Accuracy=76.53: 100%|██████████| 391/391 [00:28<00:00, 13.72it/s]



Test set: Average loss: 0.5048, Accuracy: 8274/10000 (82.74%)

EPOCH: 13


Loss=0.6413 Batch_id=390 Accuracy=76.94: 100%|██████████| 391/391 [00:32<00:00, 11.92it/s]



Test set: Average loss: 0.4947, Accuracy: 8317/10000 (83.17%)

EPOCH: 14


Loss=0.7395 Batch_id=390 Accuracy=77.14: 100%|██████████| 391/391 [00:32<00:00, 11.94it/s]



Test set: Average loss: 0.5070, Accuracy: 8292/10000 (82.92%)

EPOCH: 15


Loss=0.5546 Batch_id=390 Accuracy=77.17: 100%|██████████| 391/391 [00:31<00:00, 12.59it/s]



Test set: Average loss: 0.5098, Accuracy: 8262/10000 (82.62%)

EPOCH: 16


Loss=0.4193 Batch_id=390 Accuracy=77.95: 100%|██████████| 391/391 [00:22<00:00, 17.01it/s]



Test set: Average loss: 0.4657, Accuracy: 8407/10000 (84.07%)

EPOCH: 17


Loss=0.7036 Batch_id=390 Accuracy=78.32: 100%|██████████| 391/391 [00:22<00:00, 17.58it/s]



Test set: Average loss: 0.4656, Accuracy: 8429/10000 (84.29%)

EPOCH: 18


Loss=0.7815 Batch_id=390 Accuracy=78.51: 100%|██████████| 391/391 [00:22<00:00, 17.72it/s]



Test set: Average loss: 0.4539, Accuracy: 8462/10000 (84.62%)

EPOCH: 19


Loss=0.5875 Batch_id=390 Accuracy=78.50: 100%|██████████| 391/391 [00:23<00:00, 16.74it/s]



Test set: Average loss: 0.4534, Accuracy: 8456/10000 (84.56%)

EPOCH: 20


Loss=0.4830 Batch_id=390 Accuracy=78.87: 100%|██████████| 391/391 [00:21<00:00, 17.90it/s]



Test set: Average loss: 0.4419, Accuracy: 8499/10000 (84.99%)

EPOCH: 21


Loss=0.6374 Batch_id=390 Accuracy=79.12: 100%|██████████| 391/391 [00:22<00:00, 17.37it/s]



Test set: Average loss: 0.4380, Accuracy: 8530/10000 (85.30%)


🎉 TARGET REACHED! Test Accuracy: 85.30% >= 85.0%

