MNIST ResNet18 (ResNet50) in PyTorch

In [1]:
from torchvision import datasets
from torchvision.transforms import ToTensor
train_data = datasets.MNIST(
    root = 'data',
    train = True,                         
    transform = ToTensor(), 
    download = True,            
)
test_data = datasets.MNIST(
    root = 'data', 
    train = False, 
    transform = ToTensor()
)

Data Loading

In [2]:
from torch.utils.data import DataLoader
loaders = {
    'train' : DataLoader(train_data, 
                                          batch_size=100, 
                                          shuffle=True, 
                                          num_workers=1),
    
    'test'  : DataLoader(test_data, 
                                          batch_size=100, 
                                          shuffle=True, 
                                          num_workers=1),
}
loaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x2d0aac5dd90>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x2d0aac5de20>}

Define model RESNET18

In [3]:
import torch
import torch.nn as nn
import torchvision

class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()

        self.model = torchvision.models.resnet50(pretrained=True)
		
        self.model.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, 10)
        
    def forward(self, x):		
        return self.model(x)
	
 

In [4]:
resnet = ResNet18()


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to C:\Users\aniru/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth
100.0%


Loss Function 

In [5]:
loss_func = nn.CrossEntropyLoss()   
loss_func

CrossEntropyLoss()

Define a Optimization Function

In [6]:
from torch import optim
optimizer = optim.Adam(resnet.parameters(), lr = 0.01)   
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.01
    maximize: False
    weight_decay: 0
)

Training the Model 

In [7]:
from torch.autograd import Variable
num_epochs = 5
def train(num_epochs, resnet, loaders):
    
    resnet.train()
        
    total_step = len(loaders['train'])
        
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):
            
            
            #b_x = Variable(images)   # batch x
            #b_y = Variable(labels)   # batch y
            output = resnet(images)               
            loss = loss_func(output, labels)
            
             
            optimizer.zero_grad()           
            
            
            loss.backward()    
                         
            optimizer.step()                
            
            if (i+1) % 100 == 0:
                print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                       .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
                pass
        
        pass
    
    
    pass
if __name__ == '__main__':
    train(num_epochs, resnet, loaders)

Epoch [1/5], Step [100/600], Loss: 0.2293


KeyboardInterrupt: 

Testing the model 

In [8]:
def test():
    resnet.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in loaders['test']:
            test_output = resnet(images)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = (pred_y == labels).sum().item() / float(labels.size(0))
            pass
    print('Test Accuracy of the model on the 10000 test images: %.2f' % accuracy)
pass


test()

Test Accuracy of the model on the 10000 test images: 0.92


print 10 prediction from test data

In [9]:
sample = next(iter(loaders['test']))
imgs, lbls = sample

In [10]:
actual_number = lbls[:10].numpy()
actual_number

array([1, 5, 8, 7, 9, 8, 5, 1, 4, 8], dtype=int64)

In [11]:
test_output= resnet(imgs[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(f'Prediction number: {pred_y}')
print(f'Actual number: {actual_number}')

Prediction number: [1 5 8 7 9 8 5 1 4 8]
Actual number: [1 5 8 7 9 8 5 1 4 8]
