In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

  check_for_updates()


In [None]:
from  torchvision import datasets

cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True )
data = cifar_trainset.data / 255

cifar_10_mean = data.mean(axis=(0, 1, 2)).tolist()
cifar_10_std = data.std(axis=(0, 1, 2)).tolist()
print(f"Mean : {cifar_10_mean}, Standard deviation: {cifar_10_std}")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 42.2MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Mean : [0.4913996786152028, 0.4821584083946074, 0.4465309144454644], Standard deviation: [0.24703223246328238, 0.2434851280000556, 0.26158784172796423]


In [None]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
num_epochs = 100
batch_size = 128
learning_rate = 0.001

In [None]:
# Albumentation transforms
transform_train = A.Compose([
    A.HorizontalFlip(),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=16, min_width=16,
                    fill_value=cifar_10_mean, mask_fill_value=None, p=0.5),
    A.Normalize(mean=cifar_10_mean, std=cifar_10_std),
    ToTensorV2(),
])


transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_10_mean, std=cifar_10_std),
])

In [None]:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=lambda img: transform_train(image=np.array(img))['image'])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# has the architecture with convolutional blocks - C1 | C2 | C3 | C4 | Output
    # (No MaxPooling, but convolutions, where the last one layer in each block has a stride of 2 instead)
    # (No restriction on using 1x1)
    # (If you can figure out how to use Dilated kernels here instead of MP or strided convolution, then 200pts extra!)
# total Receptive Field must be more than 44
# one of the layers must use Depthwise Separable Convolution
# one of the layers must use Dilated Convolution
# use GAP (compulsory): add FC after GAP to target #of classes (optional)


# Convolutional neural network (CNN)
class CIFAR10_ConvNet(nn.Module):
    def __init__(self):
        super(CIFAR10_ConvNet, self).__init__()
        dropout_value = 0.1
        # Block - 1
        # Input: 3x32x32
        self.convblock1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Dropout(dropout_value),
            # ouptput: 32, RF: 3

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(dropout_value),
            # output: 32, RF: 5

            # Depthwise separable convolution
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=1, padding=1, groups=32),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 1), stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # output: 32, RF: 7

            # Layer Dilated Convolution
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=2, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(dropout_value),
            # output: 16, RF: 11

        )

        self.transition1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=16, kernel_size=(1, 1), padding=0, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
            #output : 16, RF: 11
        )

        # Block - 2
        self.convblock2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(dropout_value),
            # output: 16, RF: 15

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(dropout_value),
            # output: 16, RF: 19

            # Layer Dilated Convolution
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=2, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(dropout_value),
            # output: 8, RF: 27
        )

        self.transition2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=16, kernel_size=(1, 1), padding=0, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
            # output: 8, RF: 27
        )

        # Block - 3
        self.convblock3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(dropout_value),
            # output: 8, RF: 35

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(dropout_value),
            # output: 8, RF: 43

            # Layer Dilated Convolution
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=2, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(dropout_value),
            # output: 4, RF: 59
        )

        self.transition3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=16, kernel_size=(1, 1), padding=0, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
            # output: 4, RF: 59
        )

        # Block - 4
        self.convblock4 = nn.Sequential(
            # Layer Dilated Convolution
            nn.Conv2d(in_channels=16, out_channels=64, kernel_size=(3, 3), stride=2, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # output: 2, RF: 91
        )

        # Output Block
        self.outputblock = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=64, out_channels=10, kernel_size=(1, 1), padding=0, bias=False)
        )


    def forward(self, x):
        x = self.convblock1(x)
        x = self.transition1(x)
        x = self.convblock2(x)
        x = self.transition2(x)
        x = self.convblock3(x)
        x = self.transition3(x)
        x = self.convblock4(x)
        x = self.outputblock(x)
        x = x.view(-1, 10)
        return x

In [None]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
model = CIFAR10_ConvNet().to(device)
summary(model, input_size=(3, 32, 32))

cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
           Dropout-4           [-1, 16, 32, 32]               0
            Conv2d-5           [-1, 32, 32, 32]           4,608
       BatchNorm2d-6           [-1, 32, 32, 32]              64
              ReLU-7           [-1, 32, 32, 32]               0
           Dropout-8           [-1, 32, 32, 32]               0
            Conv2d-9           [-1, 32, 32, 32]             320
           Conv2d-10           [-1, 64, 32, 32]           2,112
      BatchNorm2d-11           [-1, 64, 32, 32]             128
             ReLU-12           [-1, 64, 32, 32]               0
           Conv2d-13           [-1, 64, 16, 16]          36,864
      BatchNorm2d-14           [-1

In [None]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []

def train(model, device, train_loader, criterion, optimizer, epoch):
  model.train()
  pbar = tqdm(train_loader)
  correct = 0
  processed = 0
  for batch_idx, (data, target) in enumerate(pbar):
    # get samples
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()

    # Predict
    y_pred = model(data)

    # Calculate loss
    loss = criterion(y_pred, target)
    train_losses.append(loss)

    # Backpropagation
    loss.backward()
    optimizer.step()

    # Update pbar-tqdm
    pred = y_pred.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    correct += pred.eq(target.view_as(pred)).sum().item()
    processed += len(data)

    pbar.set_description(desc= f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}')
    train_acc.append(100*correct/processed)

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    test_acc.append(100. * correct / len(test_loader.dataset))

In [None]:
model = CIFAR10_ConvNet().to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# step_lr
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
                            optimizer,
                            max_lr=0.1,
                            steps_per_epoch=len(train_loader),
                            epochs=num_epochs
                        )

# train loop
for epoch in range(1, num_epochs + 1):
    print("Epoch: ", epoch)
    print("Learning rate: ", optimizer.state_dict()['param_groups'][0]['lr'])
    train(model, device, train_loader, criterion, optimizer, epoch)
    test(model, device, test_loader, criterion)
    lr_scheduler.step()

Epoch:  1
Learning rate:  0.0040000000000000036


Loss=1.674565315246582 Batch_id=390 Accuracy=33.39: 100%|██████████| 391/391 [00:49<00:00,  7.85it/s]



Test set: Average loss: 0.0120, Accuracy: 4389/10000 (43.89%)

Epoch:  2
Learning rate:  0.004000001721825006


Loss=1.3594435453414917 Batch_id=390 Accuracy=44.87: 100%|██████████| 391/391 [00:42<00:00,  9.26it/s]



Test set: Average loss: 0.0106, Accuracy: 5081/10000 (50.81%)

Epoch:  3
Learning rate:  0.004000006887299928


Loss=1.2722824811935425 Batch_id=390 Accuracy=51.17: 100%|██████████| 391/391 [00:41<00:00,  9.35it/s]



Test set: Average loss: 0.0088, Accuracy: 5983/10000 (59.83%)

Epoch:  4
Learning rate:  0.004000015496424356


Loss=1.4783557653427124 Batch_id=390 Accuracy=54.33: 100%|██████████| 391/391 [00:41<00:00,  9.43it/s]



Test set: Average loss: 0.0083, Accuracy: 6241/10000 (62.41%)

Epoch:  5
Learning rate:  0.004000027549197704


Loss=1.2672476768493652 Batch_id=390 Accuracy=56.79: 100%|██████████| 391/391 [00:40<00:00,  9.56it/s]



Test set: Average loss: 0.0085, Accuracy: 6238/10000 (62.38%)

Epoch:  6
Learning rate:  0.0040000430456191005


Loss=1.143467664718628 Batch_id=390 Accuracy=58.79: 100%|██████████| 391/391 [00:40<00:00,  9.54it/s]



Test set: Average loss: 0.0073, Accuracy: 6696/10000 (66.96%)

Epoch:  7
Learning rate:  0.00400006198568742


Loss=1.0217262506484985 Batch_id=390 Accuracy=60.56: 100%|██████████| 391/391 [00:41<00:00,  9.40it/s]



Test set: Average loss: 0.0072, Accuracy: 6766/10000 (67.66%)

Epoch:  8
Learning rate:  0.004000084369401316


Loss=1.2247064113616943 Batch_id=390 Accuracy=62.25: 100%|██████████| 391/391 [00:41<00:00,  9.38it/s]



Test set: Average loss: 0.0069, Accuracy: 6902/10000 (69.02%)

Epoch:  9
Learning rate:  0.0040001101967591796


Loss=1.21384596824646 Batch_id=390 Accuracy=63.42: 100%|██████████| 391/391 [00:41<00:00,  9.52it/s]



Test set: Average loss: 0.0076, Accuracy: 6720/10000 (67.20%)

Epoch:  10
Learning rate:  0.0040001394677591645


Loss=1.1351549625396729 Batch_id=390 Accuracy=64.52: 100%|██████████| 391/391 [00:41<00:00,  9.43it/s]



Test set: Average loss: 0.0063, Accuracy: 7198/10000 (71.98%)

Epoch:  11
Learning rate:  0.0040001721823991615


Loss=0.8469350934028625 Batch_id=390 Accuracy=65.26: 100%|██████████| 391/391 [00:40<00:00,  9.62it/s]



Test set: Average loss: 0.0062, Accuracy: 7245/10000 (72.45%)

Epoch:  12
Learning rate:  0.004000208340676825


Loss=0.9223294258117676 Batch_id=390 Accuracy=66.40: 100%|██████████| 391/391 [00:40<00:00,  9.58it/s]



Test set: Average loss: 0.0062, Accuracy: 7275/10000 (72.75%)

Epoch:  13
Learning rate:  0.00400024794258956


Loss=1.1258251667022705 Batch_id=390 Accuracy=67.09: 100%|██████████| 391/391 [00:41<00:00,  9.48it/s]



Test set: Average loss: 0.0061, Accuracy: 7316/10000 (73.16%)

Epoch:  14
Learning rate:  0.004000290988134536


Loss=0.99415522813797 Batch_id=390 Accuracy=67.84: 100%|██████████| 391/391 [00:41<00:00,  9.33it/s]



Test set: Average loss: 0.0060, Accuracy: 7406/10000 (74.06%)

Epoch:  15
Learning rate:  0.0040003374773086575


Loss=1.0922236442565918 Batch_id=390 Accuracy=68.61: 100%|██████████| 391/391 [00:41<00:00,  9.49it/s]



Test set: Average loss: 0.0056, Accuracy: 7470/10000 (74.70%)

Epoch:  16
Learning rate:  0.004000387410108566


Loss=0.7836856842041016 Batch_id=390 Accuracy=69.31: 100%|██████████| 391/391 [00:41<00:00,  9.44it/s]



Test set: Average loss: 0.0056, Accuracy: 7568/10000 (75.68%)

Epoch:  17
Learning rate:  0.0040004407865307234


Loss=0.9401439428329468 Batch_id=390 Accuracy=69.38: 100%|██████████| 391/391 [00:40<00:00,  9.60it/s]



Test set: Average loss: 0.0054, Accuracy: 7614/10000 (76.14%)

Epoch:  18
Learning rate:  0.004000497606571285


Loss=0.873416543006897 Batch_id=390 Accuracy=69.99: 100%|██████████| 391/391 [00:41<00:00,  9.46it/s]



Test set: Average loss: 0.0054, Accuracy: 7603/10000 (76.03%)

Epoch:  19
Learning rate:  0.004000557870226157


Loss=0.9444351196289062 Batch_id=390 Accuracy=70.48: 100%|██████████| 391/391 [00:40<00:00,  9.65it/s]



Test set: Average loss: 0.0050, Accuracy: 7776/10000 (77.76%)

Epoch:  20
Learning rate:  0.004000621577491023


Loss=0.9986352920532227 Batch_id=390 Accuracy=70.79: 100%|██████████| 391/391 [00:42<00:00,  9.30it/s]



Test set: Average loss: 0.0055, Accuracy: 7618/10000 (76.18%)

Epoch:  21
Learning rate:  0.004000688728361318


Loss=0.7034445405006409 Batch_id=390 Accuracy=71.54: 100%|██████████| 391/391 [00:41<00:00,  9.51it/s]



Test set: Average loss: 0.0050, Accuracy: 7784/10000 (77.84%)

Epoch:  22
Learning rate:  0.004000759322832226


Loss=0.8977529406547546 Batch_id=390 Accuracy=71.42: 100%|██████████| 391/391 [00:40<00:00,  9.54it/s]



Test set: Average loss: 0.0049, Accuracy: 7854/10000 (78.54%)

Epoch:  23
Learning rate:  0.004000833360898695


Loss=0.7404463291168213 Batch_id=390 Accuracy=72.14: 100%|██████████| 391/391 [00:40<00:00,  9.61it/s]



Test set: Average loss: 0.0048, Accuracy: 7898/10000 (78.98%)

Epoch:  24
Learning rate:  0.004000910842555383


Loss=0.6886987686157227 Batch_id=390 Accuracy=72.30: 100%|██████████| 391/391 [00:40<00:00,  9.55it/s]



Test set: Average loss: 0.0047, Accuracy: 7916/10000 (79.16%)

Epoch:  25
Learning rate:  0.004000991767796738


Loss=0.6425908803939819 Batch_id=390 Accuracy=72.39: 100%|██████████| 391/391 [00:40<00:00,  9.65it/s]



Test set: Average loss: 0.0047, Accuracy: 7964/10000 (79.64%)

Epoch:  26
Learning rate:  0.004001076136616974


Loss=0.8541432619094849 Batch_id=390 Accuracy=72.56: 100%|██████████| 391/391 [00:40<00:00,  9.61it/s]



Test set: Average loss: 0.0047, Accuracy: 7969/10000 (79.69%)

Epoch:  27
Learning rate:  0.004001163949010025


Loss=0.6786538362503052 Batch_id=390 Accuracy=72.93: 100%|██████████| 391/391 [00:40<00:00,  9.62it/s]



Test set: Average loss: 0.0046, Accuracy: 7994/10000 (79.94%)

Epoch:  28
Learning rate:  0.004001255204969578


Loss=0.6492525935173035 Batch_id=390 Accuracy=73.18: 100%|██████████| 391/391 [00:40<00:00,  9.67it/s]



Test set: Average loss: 0.0048, Accuracy: 7947/10000 (79.47%)

Epoch:  29
Learning rate:  0.004001349904489124


Loss=0.7305186986923218 Batch_id=390 Accuracy=73.71: 100%|██████████| 391/391 [00:40<00:00,  9.66it/s]



Test set: Average loss: 0.0046, Accuracy: 8018/10000 (80.18%)

Epoch:  30
Learning rate:  0.004001448047561834


Loss=0.8429633378982544 Batch_id=390 Accuracy=73.53: 100%|██████████| 391/391 [00:40<00:00,  9.65it/s]



Test set: Average loss: 0.0043, Accuracy: 8088/10000 (80.88%)

Epoch:  31
Learning rate:  0.004001549634180673


Loss=0.8712469935417175 Batch_id=390 Accuracy=73.97: 100%|██████████| 391/391 [00:40<00:00,  9.62it/s]



Test set: Average loss: 0.0048, Accuracy: 7936/10000 (79.36%)

Epoch:  32
Learning rate:  0.0040016546643383555


Loss=0.6549279093742371 Batch_id=390 Accuracy=73.93: 100%|██████████| 391/391 [00:40<00:00,  9.64it/s]



Test set: Average loss: 0.0044, Accuracy: 8100/10000 (81.00%)

Epoch:  33
Learning rate:  0.004001763138027373


Loss=0.731695830821991 Batch_id=390 Accuracy=74.46: 100%|██████████| 391/391 [00:40<00:00,  9.56it/s]



Test set: Average loss: 0.0045, Accuracy: 8021/10000 (80.21%)

Epoch:  34
Learning rate:  0.004001875055239884


Loss=0.8350318074226379 Batch_id=390 Accuracy=74.61: 100%|██████████| 391/391 [00:40<00:00,  9.64it/s]



Test set: Average loss: 0.0044, Accuracy: 8128/10000 (81.28%)

Epoch:  35
Learning rate:  0.004001990415967924


Loss=0.7317013740539551 Batch_id=390 Accuracy=74.87: 100%|██████████| 391/391 [00:40<00:00,  9.60it/s]



Test set: Average loss: 0.0043, Accuracy: 8116/10000 (81.16%)

Epoch:  36
Learning rate:  0.004002109220203179


Loss=0.6987301111221313 Batch_id=390 Accuracy=74.82: 100%|██████████| 391/391 [00:40<00:00,  9.58it/s]



Test set: Average loss: 0.0042, Accuracy: 8194/10000 (81.94%)

Epoch:  37
Learning rate:  0.0040022314679371285


Loss=0.6903976202011108 Batch_id=390 Accuracy=74.76: 100%|██████████| 391/391 [00:40<00:00,  9.55it/s]



Test set: Average loss: 0.0043, Accuracy: 8127/10000 (81.27%)

Epoch:  38
Learning rate:  0.004002357159161016


Loss=0.776331901550293 Batch_id=390 Accuracy=74.99: 100%|██████████| 391/391 [00:40<00:00,  9.66it/s]



Test set: Average loss: 0.0042, Accuracy: 8190/10000 (81.90%)

Epoch:  39
Learning rate:  0.004002486293865806


Loss=0.7064996957778931 Batch_id=390 Accuracy=75.17: 100%|██████████| 391/391 [00:40<00:00,  9.55it/s]



Test set: Average loss: 0.0045, Accuracy: 8054/10000 (80.54%)

Epoch:  40
Learning rate:  0.004002618872042257


Loss=0.7962319850921631 Batch_id=390 Accuracy=75.60: 100%|██████████| 391/391 [00:40<00:00,  9.59it/s]



Test set: Average loss: 0.0045, Accuracy: 8085/10000 (80.85%)

Epoch:  41
Learning rate:  0.004002754893680835


Loss=0.8051894307136536 Batch_id=390 Accuracy=75.62: 100%|██████████| 391/391 [00:41<00:00,  9.50it/s]



Test set: Average loss: 0.0043, Accuracy: 8216/10000 (82.16%)

Epoch:  42
Learning rate:  0.004002894358771797


Loss=0.6317909955978394 Batch_id=390 Accuracy=75.62: 100%|██████████| 391/391 [00:40<00:00,  9.65it/s]



Test set: Average loss: 0.0041, Accuracy: 8211/10000 (82.11%)

Epoch:  43
Learning rate:  0.0040030372673051234


Loss=0.7083423137664795 Batch_id=390 Accuracy=75.97: 100%|██████████| 391/391 [00:41<00:00,  9.53it/s]



Test set: Average loss: 0.0041, Accuracy: 8218/10000 (82.18%)

Epoch:  44
Learning rate:  0.004003183619270573


Loss=1.0098849534988403 Batch_id=390 Accuracy=76.03: 100%|██████████| 391/391 [00:40<00:00,  9.67it/s]



Test set: Average loss: 0.0040, Accuracy: 8280/10000 (82.80%)

Epoch:  45
Learning rate:  0.004003333414657639


Loss=0.44520872831344604 Batch_id=390 Accuracy=76.06: 100%|██████████| 391/391 [00:40<00:00,  9.55it/s]



Test set: Average loss: 0.0041, Accuracy: 8238/10000 (82.38%)

Epoch:  46
Learning rate:  0.004003486653455582


Loss=0.7004948854446411 Batch_id=390 Accuracy=76.04: 100%|██████████| 391/391 [00:40<00:00,  9.54it/s]



Test set: Average loss: 0.0039, Accuracy: 8325/10000 (83.25%)

Epoch:  47
Learning rate:  0.004003643335653395


Loss=0.8314526677131653 Batch_id=390 Accuracy=76.67: 100%|██████████| 391/391 [00:41<00:00,  9.38it/s]



Test set: Average loss: 0.0039, Accuracy: 8311/10000 (83.11%)

Epoch:  48
Learning rate:  0.004003803461239866


Loss=0.8919135928153992 Batch_id=390 Accuracy=76.57: 100%|██████████| 391/391 [00:41<00:00,  9.53it/s]



Test set: Average loss: 0.0041, Accuracy: 8250/10000 (82.50%)

Epoch:  49
Learning rate:  0.004003967030203476


Loss=0.640898585319519 Batch_id=390 Accuracy=76.44: 100%|██████████| 391/391 [00:41<00:00,  9.40it/s]



Test set: Average loss: 0.0040, Accuracy: 8283/10000 (82.83%)

Epoch:  50
Learning rate:  0.004004134042532498


Loss=0.5064693689346313 Batch_id=390 Accuracy=76.45: 100%|██████████| 391/391 [00:40<00:00,  9.57it/s]



Test set: Average loss: 0.0039, Accuracy: 8330/10000 (83.30%)

Epoch:  51
Learning rate:  0.004004304498214956


Loss=0.7892385125160217 Batch_id=390 Accuracy=76.90: 100%|██████████| 391/391 [00:40<00:00,  9.55it/s]



Test set: Average loss: 0.0039, Accuracy: 8344/10000 (83.44%)

Epoch:  52
Learning rate:  0.004004478397238623


Loss=0.7206576466560364 Batch_id=390 Accuracy=76.64: 100%|██████████| 391/391 [00:40<00:00,  9.54it/s]



Test set: Average loss: 0.0040, Accuracy: 8310/10000 (83.10%)

Epoch:  53
Learning rate:  0.004004655739591023


Loss=0.7046281695365906 Batch_id=390 Accuracy=76.84: 100%|██████████| 391/391 [00:40<00:00,  9.66it/s]



Test set: Average loss: 0.0039, Accuracy: 8334/10000 (83.34%)

Epoch:  54
Learning rate:  0.004004836525259417


Loss=0.637482762336731 Batch_id=390 Accuracy=77.07: 100%|██████████| 391/391 [00:41<00:00,  9.52it/s]



Test set: Average loss: 0.0038, Accuracy: 8374/10000 (83.74%)

Epoch:  55
Learning rate:  0.0040050207542308425


Loss=0.5130741596221924 Batch_id=390 Accuracy=77.09: 100%|██████████| 391/391 [00:40<00:00,  9.54it/s]



Test set: Average loss: 0.0038, Accuracy: 8363/10000 (83.63%)

Epoch:  56
Learning rate:  0.004005208426492102


Loss=0.4899267256259918 Batch_id=390 Accuracy=77.04: 100%|██████████| 391/391 [00:41<00:00,  9.51it/s]



Test set: Average loss: 0.0037, Accuracy: 8378/10000 (83.78%)

Epoch:  57
Learning rate:  0.004005399542029706


Loss=0.6366291046142578 Batch_id=390 Accuracy=77.36: 100%|██████████| 391/391 [00:40<00:00,  9.68it/s]



Test set: Average loss: 0.0038, Accuracy: 8353/10000 (83.53%)

Epoch:  58
Learning rate:  0.004005594100829957


Loss=0.648241400718689 Batch_id=390 Accuracy=77.46: 100%|██████████| 391/391 [00:41<00:00,  9.53it/s]



Test set: Average loss: 0.0039, Accuracy: 8359/10000 (83.59%)

Epoch:  59
Learning rate:  0.004005792102878908


Loss=0.7105847597122192 Batch_id=390 Accuracy=77.48: 100%|██████████| 391/391 [00:40<00:00,  9.64it/s]



Test set: Average loss: 0.0038, Accuracy: 8380/10000 (83.80%)

Epoch:  60
Learning rate:  0.004005993548162307


Loss=0.7576695084571838 Batch_id=390 Accuracy=77.55: 100%|██████████| 391/391 [00:41<00:00,  9.48it/s]



Test set: Average loss: 0.0038, Accuracy: 8365/10000 (83.65%)

Epoch:  61
Learning rate:  0.004006198436665762


Loss=0.5695167779922485 Batch_id=390 Accuracy=77.52: 100%|██████████| 391/391 [00:41<00:00,  9.51it/s]



Test set: Average loss: 0.0040, Accuracy: 8298/10000 (82.98%)

Epoch:  62
Learning rate:  0.004006406768374535


Loss=0.5876227021217346 Batch_id=390 Accuracy=77.98: 100%|██████████| 391/391 [00:41<00:00,  9.44it/s]



Test set: Average loss: 0.0037, Accuracy: 8415/10000 (84.15%)

Epoch:  63
Learning rate:  0.004006618543273693


Loss=0.620241641998291 Batch_id=390 Accuracy=77.76: 100%|██████████| 391/391 [00:40<00:00,  9.60it/s]



Test set: Average loss: 0.0037, Accuracy: 8446/10000 (84.46%)

Epoch:  64
Learning rate:  0.0040068337613480415


Loss=0.5650013089179993 Batch_id=390 Accuracy=77.83: 100%|██████████| 391/391 [00:41<00:00,  9.44it/s]



Test set: Average loss: 0.0036, Accuracy: 8429/10000 (84.29%)

Epoch:  65
Learning rate:  0.004007052422582133


Loss=0.6374754905700684 Batch_id=390 Accuracy=77.93: 100%|██████████| 391/391 [00:40<00:00,  9.62it/s]



Test set: Average loss: 0.0037, Accuracy: 8396/10000 (83.96%)

Epoch:  66
Learning rate:  0.004007274526960286


Loss=0.6386277079582214 Batch_id=390 Accuracy=78.25: 100%|██████████| 391/391 [00:41<00:00,  9.45it/s]



Test set: Average loss: 0.0037, Accuracy: 8408/10000 (84.08%)

Epoch:  67
Learning rate:  0.004007500074466569


Loss=0.8678592443466187 Batch_id=390 Accuracy=77.88: 100%|██████████| 391/391 [00:40<00:00,  9.60it/s]



Test set: Average loss: 0.0036, Accuracy: 8450/10000 (84.50%)

Epoch:  68
Learning rate:  0.004007729065084786


Loss=0.7133013010025024 Batch_id=390 Accuracy=78.14: 100%|██████████| 391/391 [00:41<00:00,  9.50it/s]



Test set: Average loss: 0.0035, Accuracy: 8463/10000 (84.63%)

Epoch:  69
Learning rate:  0.004007961498798535


Loss=0.7069622874259949 Batch_id=390 Accuracy=78.23: 100%|██████████| 391/391 [00:41<00:00,  9.48it/s]



Test set: Average loss: 0.0036, Accuracy: 8474/10000 (84.74%)

Epoch:  70
Learning rate:  0.004008197375591133


Loss=0.580814003944397 Batch_id=390 Accuracy=78.19: 100%|██████████| 391/391 [00:41<00:00,  9.53it/s]



Test set: Average loss: 0.0038, Accuracy: 8403/10000 (84.03%)

Epoch:  71
Learning rate:  0.004008436695445622


Loss=0.6482537984848022 Batch_id=390 Accuracy=78.41: 100%|██████████| 391/391 [00:40<00:00,  9.60it/s]



Test set: Average loss: 0.0036, Accuracy: 8474/10000 (84.74%)

Epoch:  72
Learning rate:  0.004008679458344877


Loss=0.5509679913520813 Batch_id=390 Accuracy=78.54: 100%|██████████| 391/391 [00:40<00:00,  9.56it/s]



Test set: Average loss: 0.0035, Accuracy: 8528/10000 (85.28%)

Epoch:  73
Learning rate:  0.0040089256642714816


Loss=0.4759845733642578 Batch_id=390 Accuracy=78.53: 100%|██████████| 391/391 [00:40<00:00,  9.61it/s]



Test set: Average loss: 0.0037, Accuracy: 8406/10000 (84.06%)

Epoch:  74
Learning rate:  0.004009175313207727


Loss=0.5227862000465393 Batch_id=390 Accuracy=78.55: 100%|██████████| 391/391 [00:41<00:00,  9.50it/s]



Test set: Average loss: 0.0035, Accuracy: 8513/10000 (85.13%)

Epoch:  75
Learning rate:  0.004009428405135754


Loss=0.4765198230743408 Batch_id=390 Accuracy=78.77: 100%|██████████| 391/391 [00:41<00:00,  9.52it/s]



Test set: Average loss: 0.0035, Accuracy: 8472/10000 (84.72%)

Epoch:  76
Learning rate:  0.004009684940037367


Loss=0.7958260774612427 Batch_id=390 Accuracy=78.66: 100%|██████████| 391/391 [00:40<00:00,  9.58it/s]



Test set: Average loss: 0.0034, Accuracy: 8541/10000 (85.41%)

Epoch:  77
Learning rate:  0.004009944917894179


Loss=0.499755859375 Batch_id=390 Accuracy=78.67: 100%|██████████| 391/391 [00:41<00:00,  9.53it/s]



Test set: Average loss: 0.0034, Accuracy: 8512/10000 (85.12%)

Epoch:  78
Learning rate:  0.0040102083386875376


Loss=0.4356493055820465 Batch_id=390 Accuracy=78.67: 100%|██████████| 391/391 [00:40<00:00,  9.60it/s]



Test set: Average loss: 0.0036, Accuracy: 8471/10000 (84.71%)

Epoch:  79
Learning rate:  0.004010475202398556


Loss=0.6181755661964417 Batch_id=390 Accuracy=78.92: 100%|██████████| 391/391 [00:41<00:00,  9.47it/s]



Test set: Average loss: 0.0036, Accuracy: 8452/10000 (84.52%)

Epoch:  80
Learning rate:  0.004010745509008068


Loss=0.6478984951972961 Batch_id=390 Accuracy=79.08: 100%|██████████| 391/391 [00:40<00:00,  9.61it/s]



Test set: Average loss: 0.0034, Accuracy: 8562/10000 (85.62%)

Epoch:  81
Learning rate:  0.004011019258496687


Loss=0.6079732179641724 Batch_id=390 Accuracy=78.82: 100%|██████████| 391/391 [00:41<00:00,  9.34it/s]



Test set: Average loss: 0.0035, Accuracy: 8506/10000 (85.06%)

Epoch:  82
Learning rate:  0.00401129645084479


Loss=0.5118319392204285 Batch_id=390 Accuracy=79.17: 100%|██████████| 391/391 [00:40<00:00,  9.56it/s]



Test set: Average loss: 0.0035, Accuracy: 8512/10000 (85.12%)

Epoch:  83
Learning rate:  0.004011577086032461


Loss=0.5366116762161255 Batch_id=390 Accuracy=79.09: 100%|██████████| 391/391 [00:41<00:00,  9.47it/s]



Test set: Average loss: 0.0034, Accuracy: 8559/10000 (85.59%)

Epoch:  84
Learning rate:  0.004011861164039579


Loss=0.42381706833839417 Batch_id=390 Accuracy=79.03: 100%|██████████| 391/391 [00:40<00:00,  9.63it/s]



Test set: Average loss: 0.0034, Accuracy: 8521/10000 (85.21%)

Epoch:  85
Learning rate:  0.0040121486848457705


Loss=0.5351141095161438 Batch_id=390 Accuracy=79.10: 100%|██████████| 391/391 [00:41<00:00,  9.45it/s]



Test set: Average loss: 0.0034, Accuracy: 8533/10000 (85.33%)

Epoch:  86
Learning rate:  0.004012439648430399


Loss=0.6902936697006226 Batch_id=390 Accuracy=79.07: 100%|██████████| 391/391 [00:40<00:00,  9.62it/s]



Test set: Average loss: 0.0035, Accuracy: 8485/10000 (84.85%)

Epoch:  87
Learning rate:  0.004012734054772607


Loss=0.5771927833557129 Batch_id=390 Accuracy=79.33: 100%|██████████| 391/391 [00:41<00:00,  9.50it/s]



Test set: Average loss: 0.0035, Accuracy: 8479/10000 (84.79%)

Epoch:  88
Learning rate:  0.004013031903851258


Loss=0.686380922794342 Batch_id=390 Accuracy=79.45: 100%|██████████| 391/391 [00:40<00:00,  9.60it/s]



Test set: Average loss: 0.0035, Accuracy: 8516/10000 (85.16%)

Epoch:  89
Learning rate:  0.004013333195644994


Loss=0.639906644821167 Batch_id=390 Accuracy=79.32: 100%|██████████| 391/391 [00:40<00:00,  9.56it/s]



Test set: Average loss: 0.0035, Accuracy: 8536/10000 (85.36%)

Epoch:  90
Learning rate:  0.00401363793013218


Loss=0.7167567014694214 Batch_id=390 Accuracy=79.32: 100%|██████████| 391/391 [00:40<00:00,  9.63it/s]



Test set: Average loss: 0.0035, Accuracy: 8516/10000 (85.16%)

Epoch:  91
Learning rate:  0.004013946107290958


Loss=0.6343678832054138 Batch_id=390 Accuracy=79.34: 100%|██████████| 391/391 [00:41<00:00,  9.48it/s]



Test set: Average loss: 0.0034, Accuracy: 8546/10000 (85.46%)

Epoch:  92
Learning rate:  0.004014257727099235


Loss=0.6418282389640808 Batch_id=390 Accuracy=79.44: 100%|██████████| 391/391 [00:40<00:00,  9.67it/s]



Test set: Average loss: 0.0034, Accuracy: 8555/10000 (85.55%)

Epoch:  93
Learning rate:  0.004014572789534654


Loss=0.6481951475143433 Batch_id=390 Accuracy=79.93: 100%|██████████| 391/391 [00:40<00:00,  9.63it/s]



Test set: Average loss: 0.0033, Accuracy: 8593/10000 (85.93%)

Epoch:  94
Learning rate:  0.00401489129457458


Loss=0.655978798866272 Batch_id=390 Accuracy=79.69: 100%|██████████| 391/391 [00:40<00:00,  9.63it/s]



Test set: Average loss: 0.0035, Accuracy: 8527/10000 (85.27%)

Epoch:  95
Learning rate:  0.004015213242196197


Loss=0.6200782060623169 Batch_id=390 Accuracy=79.57: 100%|██████████| 391/391 [00:40<00:00,  9.57it/s]



Test set: Average loss: 0.0034, Accuracy: 8550/10000 (85.50%)

Epoch:  96
Learning rate:  0.004015538632376414


Loss=0.44762319326400757 Batch_id=390 Accuracy=79.41: 100%|██████████| 391/391 [00:40<00:00,  9.70it/s]



Test set: Average loss: 0.0033, Accuracy: 8603/10000 (86.03%)

Epoch:  97
Learning rate:  0.0040158674650918325


Loss=0.5116454362869263 Batch_id=390 Accuracy=79.49: 100%|██████████| 391/391 [00:40<00:00,  9.56it/s]



Test set: Average loss: 0.0034, Accuracy: 8572/10000 (85.72%)

Epoch:  98
Learning rate:  0.004016199740318915


Loss=0.5007434487342834 Batch_id=390 Accuracy=79.43: 100%|██████████| 391/391 [00:41<00:00,  9.50it/s]



Test set: Average loss: 0.0032, Accuracy: 8621/10000 (86.21%)

Epoch:  99
Learning rate:  0.004016535458033821


Loss=0.4407823979854584 Batch_id=390 Accuracy=80.03: 100%|██████████| 391/391 [00:40<00:00,  9.60it/s]



Test set: Average loss: 0.0035, Accuracy: 8550/10000 (85.50%)

Epoch:  100
Learning rate:  0.004016874618212443


Loss=0.4359794557094574 Batch_id=390 Accuracy=79.97: 100%|██████████| 391/391 [00:40<00:00,  9.66it/s]



Test set: Average loss: 0.0033, Accuracy: 8594/10000 (85.94%)



In [None]:

%matplotlib inline
import matplotlib.pyplot as plt

def plot_training_curve(train_losses, train_acc, test_losses, test_acc):
    t = [t_items.item() for t_items in train_losses]
    fig, axs = plt.subplots(2,2,figsize=(15,10))
    axs[0, 0].plot(t)
    axs[0, 0].set_title("Training Loss")
    axs[1, 0].plot(train_acc[4000:])
    axs[1, 0].set_title("Training Accuracy")
    axs[0, 1].plot(test_losses)
    axs[0, 1].set_title("Test Loss")
    axs[1, 1].plot(test_acc)
    axs[1, 1].set_title("Test Accuracy")

In [None]:
plot_training_curve(train_losses, train_acc, test_losses, test_acc)