In [9]:
# Import standard-library packages

import os

# Import PyTorch

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Import other packages

import matplotlib.pyplot as plt
import numpy as np

In [10]:
# Define the hyper-parameters of the model

input_size = 784      # each image is flattened to 28*28 features
hidden_size = 512
num_classes = 10
num_epochs = 20       # NOTE(review): not referenced by the training loop in this notebook
batch_size = 100
learning_rate = 0.05

# Seed PyTorch's global RNG so weight initialisation and the training
# DataLoader's shuffling are reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed);

In [11]:
# Load MNIST twice: the train=True split is used for fitting, and the
# train=False split serves as this notebook's validation set. Images are
# converted to tensors by ToTensor; only the first call downloads into ./data.

train_dataset = torchvision.datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
validation_dataset = torchvision.datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor(),
)

# Training batches of `batch_size` images, reshuffled every epoch.
train_dataset_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Validation goes one image per batch in a fixed order, so every loss the
# validation helpers collect corresponds to exactly one sample.
validation_dataset_loader = torch.utils.data.DataLoader(
    dataset=validation_dataset,
    batch_size=1,
    shuffle=False,
)

In [12]:
# Fully connected classifier: input_size -> hidden_size -> hidden_size -> num_classes,
# with a ReLU after each of the two hidden layers.

class NeuralNet(nn.Module):
    """Three-layer fully connected network that outputs raw (unnormalised) logits.

    No softmax is applied here; nn.CrossEntropyLoss expects raw logits.

    Args:
        input_size: number of features in each flattened input example.
        hidden_size: width of both hidden layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Creation order fixes how parameter initialisation consumes the RNG,
        # and the attribute names become the state_dict keys; keep both stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits for ``x``, whose last dimension must equal ``input_size``."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

# Build the network from the hyper-parameters defined earlier.
model = NeuralNet(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes)

In [13]:
# Use CrossEntropyLoss (log-softmax + negative log-likelihood) on the raw logits.
# NOTE: this loss is non-negative but NOT bounded above by 1 - a confidently
# wrong prediction can incur an arbitrarily large loss - so any risk bound that
# assumes per-sample losses lie in [0, 1] is an assumption, not a guarantee.

criterion = nn.CrossEntropyLoss()

# Plain mini-batch SGD over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [14]:
# Training Epoch

def train_epoch(data_loader=None, net=None, loss_fn=None, opt=None):
    """Run one pass of mini-batch gradient descent over ``data_loader``.

    Each argument left as ``None`` falls back to the matching notebook global
    (``train_dataset_loader``, ``model``, ``criterion``, ``optimizer``). The
    global is looked up at call time, so a rebound ``model`` is picked up.

    Returns:
        The loss of the LAST mini-batch of the epoch as a Python float. It is
        not an epoch average, so it is a noisy estimate of the training loss.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if data_loader is None:
        data_loader = train_dataset_loader
    if net is None:
        net = model
    if loss_fn is None:
        loss_fn = criterion
    if opt is None:
        opt = optimizer

    net.train()
    last_loss = None
    for images, labels in data_loader:
        # Flatten every image in the batch into one feature vector for the FC net.
        images = images.reshape(images.size(0), -1)

        # Forward pass
        outputs = net(images)
        loss = loss_fn(outputs, labels)

        # Backpropagation and optimization
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Keep the tensor; .item() once at the end avoids a sync per step.
        last_loss = loss

    if last_loss is None:
        raise ValueError("data_loader yielded no batches; nothing to train on")
    return last_loss.item()

# Obtain empirical risk of model from the validation set

def _validation_batch_losses(data_loader, net, loss_fn):
    """Evaluate ``net`` on every batch of ``data_loader`` without tracking gradients.

    ``net``/``loss_fn`` default (when ``None``) to the notebook globals
    ``model``/``criterion``, looked up at call time.

    Returns:
        A list of ``(batch_loss, batch_size)`` pairs, one per batch, in loader order.

    Raises:
        ValueError: if ``data_loader`` yields no batches (the risk would be 0/0).
    """
    if net is None:
        net = model
    if loss_fn is None:
        loss_fn = criterion
    pairs = []
    with torch.no_grad():
        for images, labels in data_loader:
            # Flatten every image in the batch into one feature vector for the FC net.
            images = images.reshape(images.size(0), -1)
            outputs = net(images)
            pairs.append((loss_fn(outputs, labels).item(), labels.size(0)))
    if not pairs:
        raise ValueError("data_loader yielded no batches; empirical risk is undefined")
    return pairs


def _weighted_mean(pairs):
    """Per-sample mean loss from ``(batch_loss, batch_size)`` pairs.

    Assumes each batch_loss is a batch *mean* (nn.CrossEntropyLoss's default
    reduction). Weighting by batch size keeps a smaller final batch from being
    over-counted; with batch_size=1 this equals the plain mean of the list.
    """
    total = sum(loss * count for loss, count in pairs)
    return total / sum(count for _, count in pairs)


def empirical_validation_risk(data_loader, net=None, loss_fn=None):
    """Return the mean per-sample loss of the model over ``data_loader``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    return _weighted_mean(_validation_batch_losses(data_loader, net, loss_fn))


def empirical_validation_risk_with_list(data_loader, net=None, loss_fn=None):
    """Like ``empirical_validation_risk``, but also return the per-batch losses.

    Returns:
        ``(risk, losses)`` where ``losses[i]`` is the loss of batch ``i``. Only
        when the loader uses batch_size=1 (as the validation loader here does)
        is ``len(losses)`` the number of samples.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    pairs = _validation_batch_losses(data_loader, net, loss_fn)
    return _weighted_mean(pairs), [loss for loss, _ in pairs]
    

In [15]:
# Train the model until it has had at least 10 ordinary training epochs AND at
# least 10 "validation" epochs. An epoch counts as a validation epoch when the
# loss returned by train_epoch() is below 0.05. Only then is the model scored on
# the validation set, and the checkpoint is overwritten whenever that score
# beats the best seen so far.
# NOTE(review): if the training loss never drops below 0.05 this loop never ends.

checkpoint_path = './models/validated_model'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Despite its name, this tracks the LOWEST validation loss seen so far.
# Starting at +inf guarantees the first validated model is saved, so a
# checkpoint exists as soon as any epoch has been validated.
max_validation_loss=float('inf')

training_epochs=0
validation_epochs=0
while training_epochs<10 or validation_epochs<10:
    loss=train_epoch()
    if loss<0.05:
        validation_loss=empirical_validation_risk(validation_dataset_loader)
        if validation_loss<max_validation_loss:
            torch.save(model, checkpoint_path)
            max_validation_loss=validation_loss
        validation_epochs+=1
        print ('Epoch [{}], Loss: {}, Validation Loss: {}'.format(training_epochs+validation_epochs, loss, validation_loss))
    else:
        training_epochs+=1
        print ('Epoch [{}], Loss: {}'.format(training_epochs+validation_epochs, loss))

Epoch [1], Loss: 0.24942602217197418
Epoch [2], Loss: 0.2768819332122803
Epoch [3], Loss: 0.19234256446361542
Epoch [4], Loss: 0.24565021693706512
Epoch [5], Loss: 0.2121107131242752
Epoch [6], Loss: 0.23824289441108704
Epoch [7], Loss: 0.059267621487379074
Epoch [8], Loss: 0.10241420567035675
Epoch [9], Loss: 0.0981324315071106
Epoch [10], Loss: 0.1258523315191269
Epoch [11], Loss: 0.10823557525873184
Epoch [12], Loss: 0.0665387362241745
Epoch [13], Loss: 0.07587321102619171
Epoch [14], Loss: 0.0832892507314682
Epoch [15], Loss: 0.014344729483127594, Validation Loss: 0.07480454442187828
Epoch [16], Loss: 0.04845869913697243, Validation Loss: 0.07162741835739166
Epoch [17], Loss: 0.02243698574602604, Validation Loss: 0.06816520196682499
Epoch [18], Loss: 0.031495530158281326, Validation Loss: 0.06951017985482802
Epoch [19], Loss: 0.03017069771885872, Validation Loss: 0.06715069839426477
Epoch [20], Loss: 0.017194082960486412, Validation Loss: 0.06692903898719132
Epoch [21], Loss: 0.018

In [16]:
# Reload the best checkpoint saved during training. The file holds a whole
# pickled nn.Module. PyTorch >= 2.6 refuses that under its new default
# weights_only=True, so opt out explicitly (the keyword needs PyTorch >= 1.13).
# Only do this for files you produced yourself: unpickling can run arbitrary code.
model = torch.load('./models/validated_model', weights_only=False)
# Inference mode; a no-op for this Linear/ReLU-only network, but correct practice.
model.eval()

In [17]:
# Upper bound on the expected risk of the selected model. The additive terms
# have the form of Bernstein's inequality, with a union bound over the
# `validation_epochs` candidate models that were compared on the validation set.
#   delta:         failure probability (the bound holds w.p. >= 1 - delta)
#   gamma_squared: chosen bound on the per-sample loss variance (presumably; confirm)
#   C:             assumed upper bound on the per-sample loss ("worst case")
# NOTE(review): cross-entropy is not bounded above, so C=1 is an assumption. The
# bound is only valid if every per-sample validation loss actually lies in [0, C].

delta=0.1
gamma_squared=0.05**2
C=1

validation_risk_value,validation_risk_list=empirical_validation_risk_with_list(validation_dataset_loader)
m_val=len(validation_risk_list)
# m_val must be a SAMPLE count, which holds only when the loader uses batch_size=1.
assert m_val == len(validation_dataset), "validation loader must use batch_size=1"

# The union bound needs at least one candidate model; log(0/delta) would be -inf.
assert validation_epochs >= 1, "no model was validated, so there is nothing to bound"
log_term=np.log(validation_epochs/delta)

# Calculate bound on the expected risk
generalization_bound=validation_risk_value+(2*C*log_term)/(3*m_val)+np.sqrt((2*gamma_squared)*log_term/m_val)

In [18]:
# Show the computed upper bound on the expected risk.
print(str(generalization_bound))

0.0645282108138849
