In [0]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

In [2]:
transform=transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),
                              transforms.RandomVerticalFlip(p=0.5),
                              transforms.ToTensor(),
                               transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset=torchvision.datasets.CIFAR10(root='./data', train=True,
                                    download=True, transform=transform)
trainloader=torch.utils.data.DataLoader(trainset, batch_size=5,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1,
                                         shuffle=False, num_workers=0)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [0]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 20, (3,3))
        # self.batch1=nn.BatchNorm2d(20)
        self.conv2 = nn.Conv2d(20, 50, (3,3))
        # self.batch2=nn.BatchNorm2d(50)
        self.fc1 = nn.Linear(50 * (6*6), 125)
        self.fc2 = nn.Linear(125, 75)
        self.fc3 = nn.Linear(75, 10)
        self.pool = nn.MaxPool2d((2,2), 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 50 * (6*6))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

In [4]:
# Classification Cross-Entropy loss and SGD with momentum
criterion=nn.CrossEntropyLoss()
optimizer=optim.Adam(net.parameters())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net.to(device)

cuda:0


Net(
  (conv1): Conv2d(3, 20, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=1800, out_features=125, bias=True)
  (fc2): Linear(in_features=125, out_features=75, bias=True)
  (fc3): Linear(in_features=75, out_features=10, bias=True)
  (pool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
)

Training my network here.

In [5]:
for epoch in range(10): # no of epochs
    running_loss=0.0
    for i,data in enumerate(trainloader,0):
          inputs, labels = data[0].to(device), data[1].to(device)
          optimizer.zero_grad()
          outputs=net(inputs)
          loss=criterion(outputs,labels)
          loss.backward()
          optimizer.step()
          running_loss+=loss.item()
          if i%2000==1999:
                if i%12000==11999: running_loss=0
                #   print('Epoch:',epoch+1,', loss:',running_loss/2000)
                
    correct=0
    total=0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs=net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("Epoch=",epoch+1,", Accuracy=",correct*100/total)
print('Model trained :)')

Epoch= 1 , Accuracy= 50.06
Epoch= 2 , Accuracy= 57.52
Epoch= 3 , Accuracy= 59.48
Epoch= 4 , Accuracy= 62.01
Epoch= 5 , Accuracy= 62.93
Epoch= 6 , Accuracy= 63.91
Epoch= 7 , Accuracy= 62.73
Epoch= 8 , Accuracy= 65.75
Epoch= 9 , Accuracy= 65.42
Epoch= 10 , Accuracy= 64.28
Model trained :)


In [0]:
# torch.save(net.state_dict(), './cifar2.pth')