In [83]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [84]:
input_size = 784
hidden_size = 512
num_classes = 10
num_epochs = 20
batch_size = 100
learning_rate = 0.05

In [85]:
train_dataset = torchvision.datasets.MNIST(root='data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),  
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False) 

In [86]:
train_full_dataset = torch.utils.data.DataLoader(dataset=train_dataset, 
                                                 batch_size=1, 
                                                 shuffle=True)
test_full_dataset = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=1, 
                                          shuffle=False) 

In [87]:
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size) 
        self.fc2 = nn.Linear(hidden_size, hidden_size) 
        self.fc3 = nn.Linear(hidden_size, num_classes)  
    
    def forward(self, x):
        x = self.fc1(x)
        x = x.relu()
        x = self.fc2(x)
        x = x.relu()
        x = self.fc3(x)
        return x

model = NeuralNet(input_size, hidden_size, num_classes)

In [88]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) 

In [89]:
total_step = len(train_loader)

def train_epoch():
    for i, (images, labels) in enumerate(train_loader):  
        # Move tensors to the configured device
        images = images.reshape(-1, 28*28)
        labels = labels
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backprpagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        

    return loss.item()

def val_loss(data_loader):
    loss=0
    with torch.no_grad():
        correct = 0
        total = 0
        for image, label in data_loader:
            image = image.reshape(-1, 28*28)
            outputs = model(image)
            loss+=criterion(outputs, label).item()
    return loss/(len(data_loader))

max_validation_loss=1
counter=0
for epoch in range(num_epochs):
    loss=train_epoch()
    if loss<0.05:
        validation_loss=val_loss(test_full_dataset)
        if validation_loss<max_validation_loss:
            torch.save(model, './models/validated_model')
            max_validation_loss=validation_loss
            counter+=1
        print ('Epoch [{}], Loss: {}, Validation Loss: {}'.format(epoch, loss, validation_loss))
    else:
        print ('Epoch [{}], Loss: {}' .format(epoch, loss))
    if counter>=5:
        break

    

Epoch [0], Loss: 0.4184012711048126
Epoch [1], Loss: 0.4668915569782257
Epoch [2], Loss: 0.2248646318912506
Epoch [3], Loss: 0.07911556214094162
Epoch [4], Loss: 0.14249642193317413
Epoch [5], Loss: 0.261027067899704
Epoch [6], Loss: 0.08662015199661255
Epoch [7], Loss: 0.1273268461227417
Epoch [8], Loss: 0.08360706269741058
Epoch [9], Loss: 0.08107130974531174
Epoch [10], Loss: 0.1023254469037056
Epoch [11], Loss: 0.08338363468647003
Epoch [12], Loss: 0.06746060401201248
Epoch [13], Loss: 0.09143529087305069
Epoch [14], Loss: 0.03087480552494526, Validation Loss0.07195893618873657 
Epoch [15], Loss: 0.0074791694059967995, Validation Loss0.07202874999851865 
Epoch [16], Loss: 0.06196217238903046
Epoch [17], Loss: 0.02550596557557583, Validation Loss0.07203861293332907 
Epoch [18], Loss: 0.022354183718562126, Validation Loss0.06608316387006687 
Epoch [19], Loss: 0.01441866159439087, Validation Loss0.06516921123234337 


In [90]:
model = torch.load('./models/validated_model')

In [95]:
train_loss=0
with torch.no_grad():
    correct = 0
    total = 0
    for image, label in train_full_dataset:
        image = image.reshape(-1, 28*28)
        outputs = model(image)
        loss+=criterion(outputs, label).item()
    print(loss/len(train_dataset))

validation_loss_list=[]
with torch.no_grad():
    correct = 0
    total = 0
    for image, label in test_full_dataset:
        image = image.reshape(-1, 28*28)
        outputs = model(image)
        validation_loss_list.append(criterion(outputs, label).item())
    print(sum(validation_loss_list)/len(validation_loss_list))

0.03857052959564235
0.06516921123234337


In [99]:
gamma_squared=0.05**2
C=1

In [100]:
generalization_bound=sum(validation_loss_list)/len(validation_loss_list)+(2*C*np.log(counter/0.1))/(3*len(validation_loss_list))+np.sqrt((2*gamma_squared)*np.log(counter/0.1)/len(validation_loss_list))
print(generalization_bound)

0.06670002777273805
