In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
import os

In [0]:
# Per-channel (R, G, B) mean / std applied by both pipelines so that the
# training and test images are normalized identically.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop (after 4 px padding) and random
# horizontal flip as augmentation, then tensor conversion + normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Test pipeline: same tensor conversion + normalization, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

In [0]:
# Training split: augmented transform, reshuffled every epoch, 128 per batch.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# Test split: un-augmented transform, fixed order, 100 per batch.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified
Files already downloaded and verified


In [0]:
import torch.nn as nn


class CNN(nn.Module):
    """CNN."""

    def __init__(self):
        """CNN Builder."""
        super(CNN, self).__init__()

        self.conv_layer = nn.Sequential(

            # Conv Layer block 1
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Conv Layer block 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.05),

            # Conv Layer block 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )


        self.fc_layer = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10)
        )


    def forward(self, x):
        """Perform forward."""
        
        # conv layers
        x = self.conv_layer(x)
        
        # flatten
        x = x.view(x.size(0), -1)
        
        # fc layer
        x = self.fc_layer(x)

        return x
# Instantiate the model that the following cells optimize and evaluate.
net = CNN()

In [0]:
# Cross-entropy loss over the 10 output scores produced by the network.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of `net`; small learning rate, no weight decay.
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=1e-4,
    weight_decay=0,
)

In [0]:

# Move the model to the GPU and replicate it across every visible CUDA device.
# The isinstance guard makes this cell safe to re-run: without it a second
# execution would wrap the existing DataParallel in another DataParallel,
# nesting `.module` and prefixing the state_dict keys twice.
if not isinstance(net, torch.nn.DataParallel):
    net = torch.nn.DataParallel(
        net.cuda(), device_ids=list(range(torch.cuda.device_count())))

# Let cuDNN benchmark and pick convolution algorithms for the observed shapes.
torch.backends.cudnn.benchmark = True

In [0]:
def calculate_accuracy(dataset, model=None):
    """Return the top-1 accuracy (in percent) of a model over `dataset`.

    Args:
        dataset: iterable yielding `(images, labels)` batches, e.g. a
            DataLoader.
        model: the module to evaluate. Defaults to the notebook-level `net`
            when omitted, which keeps existing one-argument calls working.

    Returns:
        `100 * correct / total` over all samples seen.

    Raises:
        ZeroDivisionError: if `dataset` yields no samples.
    """
    if model is None:
        model = net

    # Batches come off the loader on the CPU; they must be moved to wherever
    # the model's parameters live or the forward pass fails on a GPU model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Evaluate with dropout disabled and batch-norm using its running
    # statistics, then restore whatever mode the model was in before.
    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in dataset:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                # Index of the highest score per sample is the predicted class.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    return 100 * correct / total

In [0]:
# Train for 50 epochs, printing the mean per-batch loss after each epoch.
for epoch in range(0, 50):

    # Make sure dropout / batch-norm are in training mode even if an
    # evaluation step left the model in eval mode.
    net.train()

    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs = inputs.cuda()
        labels = labels.cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        if epoch > 16:
            # From the 18th epoch on, pull every parameter's optimizer step
            # counter back to 1000 whenever it reaches 1024.
            # NOTE(review): presumably meant to keep Adam's step-dependent
            # bias correction from saturating; the intent is not stated in
            # the notebook -- confirm.
            for group in optimizer.param_groups:
                for p in group['params']:
                    state = optimizer.state[p]
                    step = state.get('step')
                    if step is None:
                        # Parameter has no optimizer state yet; nothing to cap.
                        continue
                    if step >= 1024:
                        if torch.is_tensor(step):
                            # Some PyTorch versions keep `step` as a tensor;
                            # update it in place so its type is preserved.
                            step.fill_(1000)
                        else:
                            state['step'] = 1000
        optimizer.step()

        # Accumulate as a Python float so no device tensor is kept alive.
        running_loss += loss.item()

    # Normalizing the loss by the total number of train batches
    running_loss /= len(trainloader)

    # Optional: calculate training/test set accuracy of the current model
    # (disabled). To report them, extend the format string below with
    # "Training accuracy: {2}% | Test accuracy: {3}%".
    # train_accuracy = calculate_accuracy(trainloader)
    # test_accuracy = calculate_accuracy(testloader)

    print("Iteration: {0} | Loss: {1} |".format(epoch+1, running_loss))

    # Optional checkpointing (disabled):
    # if epoch % 50 == 0:
    #     print('==> Saving model ...')
    #     state = {
    #         'net': net.module,
    #         'epoch': epoch,
    #     }
    #     if not os.path.isdir('checkpoint'):
    #         os.mkdir('checkpoint')
    #     torch.save(state, '../checkpoint/ckpt.t7')
    # NOTE(review): the directory check uses 'checkpoint' while the save path
    # is '../checkpoint/...'; reconcile the two before enabling this.

print('==> Finished Training ...')

Iteration: 1 | Loss: 0.7786043286323547 |
Iteration: 2 | Loss: 0.7056595683097839 |
Iteration: 3 | Loss: 0.6485595107078552 |
Iteration: 4 | Loss: 0.6013839840888977 |
Iteration: 5 | Loss: 0.5600818991661072 |
Iteration: 6 | Loss: 0.5246593952178955 |
Iteration: 7 | Loss: 0.4968425929546356 |
Iteration: 8 | Loss: 0.4686832129955292 |
Iteration: 9 | Loss: 0.44329696893692017 |
Iteration: 10 | Loss: 0.4183397889137268 |
Iteration: 11 | Loss: 0.3966088593006134 |
Iteration: 12 | Loss: 0.3834694027900696 |
Iteration: 13 | Loss: 0.3653469979763031 |
Iteration: 14 | Loss: 0.343368262052536 |
Iteration: 15 | Loss: 0.3315315842628479 |
Iteration: 16 | Loss: 0.31881919503211975 |
Iteration: 17 | Loss: 0.3027816116809845 |
Iteration: 18 | Loss: 0.27673643827438354 |
Iteration: 19 | Loss: 0.26327747106552124 |
Iteration: 20 | Loss: 0.2570335566997528 |
Iteration: 21 | Loss: 0.24076591432094574 |
Iteration: 22 | Loss: 0.2320958822965622 |
Iteration: 23 | Loss: 0.2246130406856537 |
Iteration: 24 | 