In [0]:
import numpy as np
import sklearn.metrics
import torchvision
from torchvision import models, datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import os 
import torch 

In [0]:
vgg = models.vgg11_bn(pretrained = True)
vgg.classifier[-1] = torch.nn.Linear(4096,10)

In [0]:
transform = transforms.Compose([
    transforms.Grayscale(3), 
    transforms.CenterCrop(224), 
    transforms.ToTensor()
    ])
mnist = torchvision.datasets.MNIST('./var',download = True)
workers = int(os.cpu_count()/2)

train = torchvision.datasets.MNIST('./var',train =True, transform = transform)
trainloader = torch.utils.data.DataLoader(train,batch_size =32,shuffle=True,num_workers = workers)
test = torchvision.datasets.MNIST('./var',train =False, transform = transform)
testloader = torch.utils.data.DataLoader(test,batch_size =32,shuffle=True,num_workers = workers)

In [0]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(vgg.parameters())

vgg.to(device)
vgg.train()
for epoch in range(20):
    for inputs,outputs in tqdm(trainloader):
        inputs = inputs.to(device, non_blocking = True)
        outputs = outputs.to(device, non_blocking = True)
        optimizer.zero_grad()
        results = vgg(inputs)
        loss = loss_function(results,outputs)
        loss.backward()
        optimizer.step()
    print ('Loss:{0}'.format(loss))

In [0]:
results_buffer = []
actual_buffer = []
with torch.no_grad():
    vgg.eval()
    for inputs, actual in testloader:
        inputs = inputs.to(device, non_blocking = True)
        results = vgg(inputs).argmax(dim=1).to('cpu').numpy()
        results_buffer.append(results)
        actual_buffer.append(actual)
print (sklearn.metrics.classification_report(
    np.concatenate(actual_buffer),
    np.concatenate(results_buffer)
    )
      )

              precision    recall  f1-score   support

           0       0.68      0.96      0.80       980
           1       0.96      0.96      0.96      1135
           2       0.88      0.78      0.83      1032
           3       0.48      0.99      0.65      1010
           4       1.00      0.07      0.14       982
           5       0.18      0.12      0.14       892
           6       1.00      0.30      0.46       958
           7       0.46      0.99      0.63      1028
           8       0.97      0.43      0.60       974
           9       0.93      0.82      0.87      1009

    accuracy                           0.66     10000
   macro avg       0.75      0.64      0.61     10000
weighted avg       0.76      0.66      0.62     10000

