In [1]:
# Import packages
import gzip
import torch
import torchvision
import numpy as np 

import idx2numpy
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
def load_one_dataset(path):
    
    f = gzip.open(path, 'rb')
    data = torch.from_numpy(idx2numpy.convert_from_file(f))
    f.close()
    
    return(data)


def load_all_datasets(train_imgs, train_labs, test_imgs, test_labs, batch_size):
    
    
    train_images = load_one_dataset(train_imgs).type(torch.float32)
    train_labels = load_one_dataset(train_labs).type(torch.long)
    train = list(zip(train_images, train_labels))
    
    test_images = load_one_dataset(test_imgs).type(torch.float32)
    test_labels = load_one_dataset(test_labs).type(torch.long)
    test = list(zip(test_images, test_labels))
    
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=2)
    
    return(train_loader, test_loader)

In [3]:
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

In [4]:
def train(epochs=2, lr=0.01, momentum=0.1, batch_size=1024):
    
    
    # Set paths to datasets
    paths = {
        
        'train_imgs': 'train-images-idx3-ubyte.gz',
        'train_labs': 'train-labels-idx1-ubyte.gz',
        'test_imgs': 't10k-images-idx3-ubyte.gz',
        'test_labs': 't10k-labels-idx1-ubyte.gz'
    }
    
    # Load datasets
    train_loader, test_loader = load_all_datasets(**paths, batch_size = 256)
    
    # Set parameters
    net = Net()
    
    # We use cross entropy loss
    criterion = nn.CrossEntropyLoss()
    
    # We use a momentum optimizer
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    
    # Loop over the dataset multiple times
    for epoch in range(epochs):  
        
        # Initialize running loss
        running_loss = 0.0

        # Iterate through data now
        for i, data in enumerate(train_loader):
            
            # Get the inputs: data is a list of [inputs, labels]
            inputs, labels = data

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward
            outputs = net(inputs)
            
            # Calculate loss
            loss = criterion(outputs, labels)
            
            # Backward
            loss.backward()
            
            # Optimize
            optimizer.step()

            # Print statistics
            running_loss += loss.item()
        
        # Perform validation loss calculation
        with torch.no_grad(): 
            
            # Initialize running validation loss
            running_test_loss = 0.0
            
            # Iterate through the training set
            for j, test_data in enumerate(test_loader):
                
                # Unpack data
                test_inputs, test_labels = test_data
            
                # Get validation outputs
                test_outputs = net(test_inputs)
        
                # Validation loss
                test_loss = criterion(test_outputs, test_labels)
        
                # Compute running validation loss
                running_test_loss += test_loss.item()
            
        # Print loss on training at the end of the epoch
        print("The training loss on epoch {} is {}...".format(epoch, running_loss))
        
        # Print loss on validation set at the end of the epoch
        print("The validation loss on epoch {} is {}...".format(epoch, running_test_loss))
    
    # Print message
    print('Done training...')

In [5]:
train(epochs=100)

  data = torch.from_numpy(idx2numpy.convert_from_file(f))


The training loss on epoch 0 is 38294.85690498352...
The validation loss on epoch 0 is 92.077960729599...
The training loss on epoch 1 is 516.4564434289932...
The validation loss on epoch 1 is 85.66543364524841...
The training loss on epoch 2 is 507.0219702720642...
The validation loss on epoch 2 is 84.80469799041748...
The training loss on epoch 3 is 520.9553354978561...
The validation loss on epoch 3 is 86.97969174385071...
The training loss on epoch 4 is 520.2812601327896...
The validation loss on epoch 4 is 84.92408263683319...
The training loss on epoch 5 is 476.0395314693451...
The validation loss on epoch 5 is 81.8424369096756...
The training loss on epoch 6 is 460.359920501709...
The validation loss on epoch 6 is 76.86101973056793...
The training loss on epoch 7 is 457.15140330791473...
The validation loss on epoch 7 is 76.73743331432343...
The training loss on epoch 8 is 497.0734432935715...
The validation loss on epoch 8 is 85.54135537147522...
The training loss on epoch 9 is

KeyboardInterrupt: 