In [1]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.models import alexnet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
assert use_cuda == True
device

device(type='cuda', index=0)

In [2]:
from pylab import *
from IPython.core.debugger import set_trace

%matplotlib inline

In [3]:
"""
Loads Alex NN in gpu
Sets optimizer and Loss function
"""
#
a_model = alexnet()
#a_model = torch.load('AlexMnist')
a_model = a_model.to(device)
#a_model = resnet18().cuda()
#n_ftrs=a_model.fc.in_features
#a_model.fc=torch.nn.Linear(n_ftrs,10)
optimizer = optim.SGD(a_model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [None]:
"""
Weight Initialization
"""
for m in a_model.modules():
    if isinstance(m, nn.Conv2d):
        

In [4]:
"""
Imports DataSets from pytorch db
"""
root = './data'
if not os.path.exists(root):
    os.mkdir(root)

data_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: torch.cat([x, x, x], 0))
])
#trans = transforms.ToTensor()    
mnist_trainset = dset.MNIST(root=root, train=True,transform = data_transform, download=True)
mnist_testset = dset.MNIST(root=root, train=False,transform = data_transform, download=True)

batch_size = 64

mnist_train_loader = torch.utils.data.DataLoader(
                 dataset=mnist_trainset,
                 batch_size=batch_size,
                 shuffle=True)
mnist_test_loader = torch.utils.data.DataLoader(
                dataset=mnist_testset,
                batch_size=batch_size,
                shuffle=False)

print('===>>> MNIST total training batch number: {}'.format(len(mnist_train_loader)))
print('===>>> MNIST total testing batch number: {}'.format(len(mnist_test_loader)))

===>>> MNIST total training batch number: 938
===>>> MNIST total testing batch number: 157


In [5]:
"""
Parameters
"""
epoches = 10

In [6]:
"""
Training
"""
a_model.train() #sets training mode
for epoch in range(epoches):
    ave_loss = 0
    for batch_idx, (x,target) in enumerate(mnist_train_loader):
        x = x.cuda()
        target = target.cuda()
        #set_trace()
        #print(x.shape)
        #x = torch.cat((x,x,x),0)
        
        out = a_model(x)
        loss = criterion(out,target)
        
        ave_loss = ave_loss * 0.9 + loss.item() * 0.1
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if(batch_idx+1) % batch_size == 0 or (batch_idx+1) == len(mnist_train_loader):
            print('==>>> epoch: {}, batch index: {}, train loss {:.6f}'.format(epoch,batch_idx+1,ave_loss))

==>>> epoch: 0, batch index: 64, train loss 5.986807
==>>> epoch: 0, batch index: 128, train loss 2.465410
==>>> epoch: 0, batch index: 192, train loss 2.345332
==>>> epoch: 0, batch index: 256, train loss 2.353976
==>>> epoch: 0, batch index: 320, train loss 2.341865
==>>> epoch: 0, batch index: 384, train loss 2.344184
==>>> epoch: 0, batch index: 448, train loss 2.334264
==>>> epoch: 0, batch index: 512, train loss 2.330644
==>>> epoch: 0, batch index: 576, train loss 2.345826
==>>> epoch: 0, batch index: 640, train loss 2.323877
==>>> epoch: 0, batch index: 704, train loss 2.330319
==>>> epoch: 0, batch index: 768, train loss 2.331876
==>>> epoch: 0, batch index: 832, train loss 2.335341
==>>> epoch: 0, batch index: 896, train loss 2.324032
==>>> epoch: 0, batch index: 938, train loss 2.318449
==>>> epoch: 1, batch index: 64, train loss 2.319501
==>>> epoch: 1, batch index: 128, train loss 2.328300
==>>> epoch: 1, batch index: 192, train loss 2.329762
==>>> epoch: 1, batch index: 2

In [7]:
#torch.save(a_model, './AlexMnist')

In [8]:
"""
Testing
"""
a_model = a_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in mnist_test_loader:
        images, labels = data
        images = images.cuda()
        labels = labels.cuda()
        outputs = a_model(images)
        _, predicted = torch.max(outputs.data, 1)
        print(predicted)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()



tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 8, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2,
        4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3], device='cuda:0')
tensor([7, 4, 6, 4, 3, 0, 7, 0, 2, 8, 1, 7, 3, 2, 8, 7, 7, 6, 2, 7, 8, 4, 7, 3,
        6, 1, 3, 6, 9, 3, 1, 4, 1, 7, 6, 9, 6, 0, 5, 4, 9, 9, 2, 1, 9, 4, 8, 7,
        3, 9, 7, 4, 4, 4, 9, 2, 5, 4, 7, 6, 7, 9, 0, 5], device='cuda:0')
tensor([8, 5, 6, 6, 5, 7, 8, 1, 0, 1, 6, 4, 6, 7, 3, 1, 7, 1, 8, 2, 0, 2, 9, 9,
        5, 5, 1, 5, 6, 0, 3, 4, 4, 6, 5, 4, 6, 5, 4, 5, 1, 4, 4, 7, 2, 3, 2, 7,
        1, 8, 1, 8, 1, 8, 5, 0, 8, 9, 2, 5, 0, 1, 1, 1], device='cuda:0')
tensor([0, 9, 0, 3, 1, 6, 4, 2, 3, 6, 1, 1, 1, 3, 8, 5, 2, 9, 4, 5, 9, 3, 9, 0,
        3, 6, 5, 5, 7, 2, 2, 7, 1, 2, 8, 4, 1, 7, 3, 3, 8, 8, 7, 9, 2, 2, 4, 1,
        5, 8, 8, 7, 2, 3, 0, 2, 4, 2, 4, 1, 9, 5, 7, 7], device='cuda:0')
tensor([2, 8, 2, 6, 8, 5, 7, 7, 9, 1, 8, 1, 8, 0, 3, 0, 1, 9, 9,

tensor([0, 3, 4, 4, 3, 8, 9, 2, 3, 9, 7, 1, 1, 7, 0, 4, 9, 6, 5, 9, 1, 2, 0, 2,
        0, 0, 4, 6, 7, 0, 7, 1, 4, 6, 4, 5, 4, 9, 9, 1, 7, 9, 5, 3, 3, 8, 2, 3,
        6, 2, 2, 1, 1, 1, 1, 1, 6, 9, 8, 4, 3, 7, 1, 6], device='cuda:0')
tensor([4, 5, 0, 4, 7, 4, 2, 4, 0, 7, 0, 1, 9, 8, 8, 6, 0, 0, 4, 1, 6, 8, 2, 2,
        3, 8, 4, 8, 2, 2, 1, 7, 5, 4, 4, 0, 4, 3, 9, 7, 3, 1, 0, 1, 2, 5, 9, 2,
        1, 0, 1, 8, 9, 1, 6, 8, 3, 8, 9, 3, 6, 2, 8, 3], device='cuda:0')
tensor([2, 2, 1, 0, 4, 2, 9, 2, 4, 3, 7, 9, 1, 5, 2, 9, 9, 0, 3, 8, 5, 3, 8, 0,
        9, 4, 6, 2, 5, 0, 2, 7, 4, 6, 6, 8, 6, 6, 8, 6, 9, 1, 7, 2, 5, 9, 9, 0,
        7, 2, 7, 6, 7, 0, 6, 5, 2, 4, 7, 2, 0, 9, 9, 2], device='cuda:0')
tensor([2, 9, 4, 4, 2, 3, 3, 2, 1, 7, 0, 7, 6, 4, 1, 3, 8, 7, 4, 5, 9, 2, 5, 1,
        8, 7, 3, 7, 1, 5, 5, 0, 9, 1, 4, 0, 6, 3, 3, 6, 0, 4, 9, 7, 5, 1, 6, 8,
        9, 5, 5, 7, 9, 3, 8, 3, 8, 1, 5, 3, 5, 0, 5, 5], device='cuda:0')
tensor([3, 8, 6, 7, 7, 7, 3, 7, 0, 5, 9, 0, 2, 5, 5, 3, 1, 7, 7,

tensor([8, 4, 9, 0, 7, 3, 0, 2, 9, 0, 6, 6, 6, 3, 6, 7, 7, 2, 8, 6, 0, 8, 3, 0,
        2, 9, 8, 3, 2, 5, 3, 9, 8, 0, 0, 1, 9, 5, 1, 3, 9, 6, 0, 1, 4, 1, 7, 1,
        2, 3, 7, 9, 7, 4, 9, 9, 3, 9, 2, 8, 2, 7, 1, 8], device='cuda:0')
tensor([0, 9, 1, 0, 1, 7, 7, 9, 6, 9, 9, 9, 2, 1, 6, 1, 3, 5, 2, 1, 9, 7, 6, 4,
        5, 7, 6, 6, 9, 9, 6, 3, 6, 2, 9, 8, 1, 2, 2, 5, 5, 2, 3, 7, 2, 1, 0, 1,
        0, 4, 5, 2, 8, 2, 8, 3, 5, 1, 7, 8, 1, 1, 2, 9], device='cuda:0')
tensor([7, 8, 4, 0, 3, 0, 7, 8, 8, 4, 7, 7, 8, 5, 8, 4, 9, 8, 1, 3, 8, 0, 3, 4,
        7, 8, 5, 6, 1, 6, 5, 7, 4, 9, 3, 5, 4, 7, 1, 2, 0, 8, 1, 6, 0, 7, 3, 4,
        7, 3, 9, 6, 0, 8, 6, 4, 8, 7, 7, 9, 3, 8, 6, 9], device='cuda:0')
tensor([7, 2, 3, 4, 0, 2, 1, 0, 5, 5, 5, 7, 2, 4, 0, 7, 2, 8, 3, 0, 8, 7, 8, 9,
        0, 8, 4, 4, 5, 8, 5, 6, 6, 3, 0, 9, 3, 7, 6, 8, 9, 3, 4, 9, 5, 8, 9, 1,
        2, 8, 8, 6, 8, 1, 3, 7, 9, 0, 1, 1, 9, 7, 0, 8], device='cuda:0')
tensor([1, 7, 4, 5, 7, 1, 2, 1, 1, 3, 0, 6, 2, 1, 2, 8, 0, 7, 6,

tensor([2, 1, 9, 4, 9, 1, 3, 9, 2, 0, 6, 0, 4, 0, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8,
        9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 8, 0,
        7, 1, 0, 7, 5, 5, 6, 9, 0, 1, 0, 2, 8, 3, 4, 3], device='cuda:0')
tensor([1, 5, 0, 0, 9, 5, 3, 4, 9, 3, 7, 6, 9, 2, 4, 5, 7, 2, 6, 4, 9, 4, 9, 4,
        1, 2, 2, 5, 8, 1, 3, 2, 9, 4, 3, 8, 2, 2, 1, 2, 8, 6, 5, 1, 6, 7, 2, 1,
        3, 9, 3, 8, 7, 5, 7, 2, 7, 4, 8, 8, 5, 0, 6, 6], device='cuda:0')
tensor([3, 7, 6, 9, 9, 4, 8, 4, 1, 0, 6, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
        2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 7, 4, 0, 4, 0, 1,
        7, 9, 5, 1, 4, 2, 8, 9, 4, 3, 7, 8, 2, 4, 4, 3], device='cuda:0')
tensor([3, 6, 9, 9, 5, 8, 6, 7, 0, 6, 8, 2, 6, 3, 9, 3, 2, 8, 6, 1, 7, 4, 8, 8,
        9, 0, 3, 3, 9, 0, 5, 2, 9, 4, 1, 0, 3, 7, 5, 8, 7, 7, 8, 2, 9, 7, 1, 2,
        6, 4, 2, 5, 2, 3, 6, 6, 5, 0, 0, 2, 8, 1, 6, 1], device='cuda:0')
tensor([0, 4, 3, 1, 6, 1, 9, 0, 1, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4,

tensor([9, 3, 9, 3, 0, 0, 1, 0, 4, 2, 6, 3, 5, 3, 0, 3, 4, 1, 5, 3, 0, 8, 3, 0,
        6, 1, 7, 8, 0, 9, 2, 6, 7, 1, 9, 6, 9, 4, 9, 9, 6, 7, 1, 2, 5, 3, 7, 8,
        0, 1, 2, 4, 5, 6, 7, 8, 9, 0, 1, 3, 4, 5, 6, 7], device='cuda:0')
tensor([8, 0, 1, 3, 4, 7, 8, 9, 7, 5, 5, 1, 9, 9, 7, 1, 0, 0, 5, 9, 7, 1, 7, 2,
        2, 3, 6, 8, 3, 2, 0, 0, 6, 1, 7, 5, 8, 6, 2, 9, 4, 8, 8, 7, 1, 0, 8, 7,
        7, 5, 8, 5, 3, 4, 6, 1, 1, 5, 5, 0, 7, 2, 3, 6], device='cuda:0')
tensor([4, 1, 2, 4, 1, 5, 4, 2, 0, 4, 8, 6, 1, 9, 0, 2, 5, 6, 9, 3, 6, 3, 6, 0,
        1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 5,
        6, 7, 8, 1, 0, 9, 5, 7, 5, 1, 8, 6, 9, 0, 4, 1], device='cuda:0')
tensor([9, 3, 8, 4, 4, 7, 0, 1, 9, 2, 8, 7, 8, 2, 5, 9, 6, 0, 6, 5, 5, 3, 3, 3,
        9, 8, 1, 1, 0, 6, 1, 0, 0, 6, 2, 1, 1, 3, 2, 7, 7, 8, 8, 7, 8, 4, 6, 0,
        2, 0, 7, 0, 3, 6, 8, 7, 1, 5, 9, 9, 3, 7, 2, 4], device='cuda:0')
tensor([9, 4, 3, 6, 2, 2, 5, 3, 2, 5, 5, 9, 4, 1, 7, 2, 0, 1, 2,

In [9]:
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 97 %


In [11]:
torch.save(a_model, './AlexMnist')

In [37]:
ws = a_model.state_dict()

In [39]:
ws

OrderedDict([('features.0.weight',
              tensor([[[[-0.0994, -0.0463, -0.0902,  ..., -0.0894, -0.0582, -0.0299],
                        [-0.0765, -0.0849, -0.0419,  ..., -0.0688, -0.1233, -0.0210],
                        [-0.0649, -0.0308, -0.0452,  ..., -0.0519, -0.0538, -0.0393],
                        ...,
                        [-0.1094, -0.0960, -0.1147,  ..., -0.0823, -0.0353, -0.0518],
                        [-0.0707, -0.1038, -0.1248,  ..., -0.0984, -0.0403, -0.0762],
                        [-0.0226, -0.0289, -0.0948,  ..., -0.0588, -0.0563, -0.1044]],
              
                       [[-0.0207, -0.1163, -0.0720,  ..., -0.1108, -0.0399, -0.0581],
                        [-0.1204, -0.0486, -0.0674,  ..., -0.0311, -0.0887, -0.1019],
                        [-0.0502, -0.0379, -0.0616,  ..., -0.0579, -0.0529, -0.1176],
                        ...,
                        [-0.0462, -0.0901, -0.0580,  ..., -0.1175, -0.1243, -0.0395],
                        [-0.021