In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from random import randint
import time

In [2]:
#Use GPU
device= torch.device("cuda")
print(device)

cuda


In [3]:
#Read in files
trainData=torch.load('mnist/train_data.pt')
trainLabel=torch.load('mnist/train_label.pt')
testData=torch.load('mnist/test_data.pt')
testLabel=torch.load('mnist/test_label.pt')

print(trainData.size())
print(testData.size())

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])


In [4]:
#mean and std to normalize
mean= trainData.mean()
std= trainData.std()

print(mean)
print(std)

tensor(0.1307)
tensor(0.3081)


In [5]:
#CNN fucntion
class CNN(nn.Module):

    def __init__(self):

        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1,   50,  kernel_size=3,  padding=1 )
        self.pool1  = nn.MaxPool2d(2,2)

        self.conv2 = nn.Conv2d(50,  100,  kernel_size=3,  padding=1 )
        self.pool2 = nn.MaxPool2d(2,2)
        
        self.conv3 = nn.Conv2d(100,  200,  kernel_size=3,  padding=1 )
 
        self.linear1 = nn.Linear(9800, 100)

        self.linear2 = nn.Linear(100,10)


    def forward(self, x):

        x = self.conv1(x)
        x = F.relu(x)        
        x = self.pool1(x)
        
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = F.relu(x)
  
        x = x.view(-1, 9800)
        x = self.linear1(x)
        x = F.relu(x)
        
        x = self.linear2(x)
    
        return x

In [6]:
net=CNN()

#Send to GPU
net = net.to(device)
mean=mean.to(device)
std=std.to(device)

In [7]:
#Hyperparameters

epochs = 30

lossFunction = nn.CrossEntropyLoss()

learningRate = 0.25 

batchSize = 128

In [8]:
#Calculate Error
def calcError( scores , labels ):

    bs=scores.size(0)
    predictedLabels = scores.argmax(dim=1)
    indicator = (predictedLabels == labels)
    numMatches=indicator.sum()
    
    return 1-numMatches.float()/bs 

In [9]:
#Evaluation Function

def evalTestData():

    runningError=0
    numBatches=0

    for i in range(0,10000,batchSize):

        miniBatchData =  testData[i:i+batchSize].unsqueeze(dim=1)
        miniBatchLabel= testLabel[i:i+batchSize]

        miniBatchData=miniBatchData.to(device)
        miniBatchLabel=miniBatchLabel.to(device)
        
        inputs = (miniBatchData - mean)/std

        scores=net( inputs ) 

        error = calcError( scores , miniBatchLabel)

        runningError += error.item()

        numBatches+=1


    totalError = runningError/numBatches
    print( 'error rate on test set =', totalError*100 ,'percent')

In [10]:
#Training Loop

start=time.time()

for epoch in range(1,epochs):
    
    if not epoch%5:
        learningRate = learningRate / 2
    
    #Optimizer Hyperparameter
    optimizer=torch.optim.SGD( net.parameters() , lr=learningRate )
        
    runningLoss=0
    runningError=0
    numBatches=0
    
    shuffledIndices=torch.randperm(60000)
 
    for count in range(0,60000,batchSize):
    
        optimizer.zero_grad()
             
        indices=shuffledIndices[count:count+batchSize]
        miniBatchData =  trainData[indices].unsqueeze(dim=1)
        miniBatchLabel=  trainLabel[indices]
        
        miniBatchData=miniBatchData.to(device)
        miniBatchLabel=miniBatchLabel.to(device)
        
        inputs = (miniBatchData - mean)/std  
        
        inputs.requires_grad_()

        scores=net( inputs ) 

        loss =  lossFunction( scores , miniBatchLabel) 
          
        loss.backward()        
        optimizer.step()
        

        # Calculate Error
        
        runningLoss += loss.detach().item()
        
        error = calcError( scores.detach() , miniBatchLabel)
        runningError += error.item()
        
        numBatches+=1        
    
    
    # Display average error
    totalLoss = runningLoss/numBatches
    totalError = runningError/numBatches
    elapsed = (time.time()-start)/60
    
    print('epoch=',epoch, '\t time=', elapsed,'min', '\t lr=', learningRate  ,'\t loss=', totalLoss , '\t error=', 
          totalError*100 ,'percent')
    evalTestData() 
    print(' ')
    
    

epoch= 1 	 time= 0.11693048477172852 min 	 lr= 0.25 	 loss= 0.26530257218888703 	 error= 8.419953798180197 percent
error rate on test set = 1.3053797468354431 percent
 
epoch= 2 	 time= 0.22628641923268636 min 	 lr= 0.25 	 loss= 0.045241959190476674 	 error= 1.3748223085139097 percent
error rate on test set = 0.9889240506329114 percent
 
epoch= 3 	 time= 0.33521777391433716 min 	 lr= 0.25 	 loss= 0.02887677943814538 	 error= 0.8767546112857648 percent
error rate on test set = 0.9295886075949367 percent
 
epoch= 4 	 time= 0.44418073892593385 min 	 lr= 0.25 	 loss= 0.020686477886017428 	 error= 0.6663113006396588 percent
error rate on test set = 1.0383702531645569 percent
 
epoch= 5 	 time= 0.5538660287857056 min 	 lr= 0.125 	 loss= 0.010076494273870612 	 error= 0.2804059972132701 percent
error rate on test set = 0.7120253164556962 percent
 
epoch= 6 	 time= 0.6637563506762186 min 	 lr= 0.125 	 loss= 0.006982160362798268 	 error= 0.17823827292110875 percent
error rate on test set = 0.712

KeyboardInterrupt: 