In [0]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import sklearn.metrics

In [0]:
class SlimAlexNet(nn.Module):
    """A slimmed-down AlexNet-style convolutional classifier.

    The network is a stack of five 3x3 convolutions (32 -> 64 -> 128 -> 256
    -> 128 channels) with three 3x3/stride-2 max-pool stages, followed by a
    three-layer fully connected head with dropout.

    Args:
        num_classes: Size of the output logit vector.
        in_channels: Number of channels in the input images (1 for
            grayscale, which is the default).

    Shape:
        Input is ``(batch, in_channels, H, W)``. For 28x28 inputs the
        convolutional stack already reduces the feature map to 1x1; larger
        inputs are averaged down to 1x1 before the classifier, so the head
        always receives 128 features.
        Output is ``(batch, num_classes)`` raw logits (no softmax applied).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Collapse whatever spatial extent is left to 1x1 so the flattened
        # vector is always 128 wide. This layer has no parameters, and on a
        # 1x1 feature map (the 28x28-input case) it is the identity.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(128, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = x.flatten(start_dim=1)
        return self.classifier(x)

In [3]:
# Directory where torchvision stores (and looks for) the MNIST files.
MNIST_ROOT = './var'

# Fetch the raw files once. `mnist` is the un-transformed dataset object and
# is kept so any later cell that refers to it keeps working.
mnist = torchvision.datasets.MNIST(MNIST_ROOT, download=True)

# ToTensor converts each PIL image into a float tensor scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# download=True is a no-op when the files are already present; it makes each
# split loadable on its own after a kernel restart.
train = torchvision.datasets.MNIST(
    MNIST_ROOT, train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True)

test = torchvision.datasets.MNIST(
    MNIST_ROOT, train=False, transform=transform, download=True)
# The whole test split is served as a single batch. Evaluation metrics do not
# depend on sample order, so shuffling would only add overhead.
testloader = torch.utils.data.DataLoader(
    test, batch_size=len(test), shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./var/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./var/MNIST/raw/train-images-idx3-ubyte.gz to ./var/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./var/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./var/MNIST/raw/train-labels-idx1-ubyte.gz to ./var/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./var/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./var/MNIST/raw/t10k-images-idx3-ubyte.gz to ./var/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./var/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./var/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./var/MNIST/raw
Processing...
Done!




In [4]:
# Number of full passes over the training set.
NUM_EPOCHS = 16

net = SlimAlexNet(num_classes=10)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)

# ---- Training ----
for epoch in range(NUM_EPOCHS):
    net.train()  # make sure Dropout is active while fitting
    total_loss = 0
    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        logits = net(inputs)
        loss = loss_function(logits, targets)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    # Mean per-batch loss for this epoch.
    print('Loss:{0}'.format(total_loss / len(trainloader)))

# ---- Evaluation ----
# Switch Dropout to inference behaviour; without this the test predictions
# are made with units randomly dropped.
net.eval()
batch_labels = []
batch_predictions = []
# No gradients are needed for scoring, so skip building the autograd graph.
with torch.no_grad():
    for inputs, labels in testloader:
        inputs = inputs.to(device)
        batch_predictions.append(net(inputs).argmax(dim=1).to('cpu'))
        batch_labels.append(labels)

# Gather every batch before scoring so the metrics cover the entire test set
# regardless of how many batches the loader yields.
actual = torch.cat(batch_labels).numpy()
results = torch.cat(batch_predictions).numpy()
accuracy = sklearn.metrics.accuracy_score(actual, results)
print(accuracy)
print(sklearn.metrics.classification_report(actual, results))




Loss:0.33659027629097305
Loss:0.09883843184908231
Loss:0.07384511246780555
Loss:0.06449187859147787
Loss:0.057848890162507695
Loss:0.05103320320416242
Loss:0.04823986408112881
Loss:0.045207177080089846
Loss:0.04717345767547376
Loss:0.03493837742140361
Loss:0.037289574959597664
Loss:0.03836622825067801
Loss:0.03665045520040439
Loss:0.03763867653748797
Loss:0.03812796914447099
Loss:0.032225746778650984
0.988
              precision    recall  f1-score   support

           0       0.99      0.99      0.99       980
           1       0.99      0.99      0.99      1135
           2       0.99      0.98      0.99      1032
           3       0.98      0.99      0.99      1010
           4       0.99      0.99      0.99       982
           5       0.98      0.98      0.98       892
           6       0.99      0.99      0.99       958
           7       0.99      0.99      0.99      1028
           8       0.99      0.98      0.98       974
           9       0.99      0.99      0.99   