In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable

In [2]:
# logistic and linear function does not handle non linear function well
# so artificial neural networks are introduced

<h3> Introduction to Non linear function </h3>

<h3> Non linear function </h3>
<ul>
    <li> Function takes a number and performs mathmathical operation</li>
    <li>Common types of non-linear function</li>
   <ul>
    <li>ReLUs(Rectified linear units)</li>
    <li>Sigmoid</li>
    <li>Tanh</li>
   </ul>
</ul>
    

 tanh(x) = 2sig(2x) -1

<b> Loading the mnist dataset </b>

In [3]:
train_dataset = dsets.MNIST(root = './data', train = True, transform= transforms.ToTensor(), download = True)
test_dataset = dsets.MNIST(root = './data', train = False, transform= transforms.ToTensor(), download=True)

<h2> Make Datasets iterable </h2>

In [4]:
batch_size = 100
n_iters = 3000
n_epochs = int(n_iters/(len(train_dataset)/batch_size))

In [5]:
train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset, batch_size = batch_size, shuffle = True)

<h3> Create a model class </h3>

In [6]:
class FeedForwardNeuralNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(FeedForwardNeuralNetwork, self).__init__()
        
        # linear function
        self.fc1 = nn.Linear(input_dim, hidden_size)
        
        # non linearilty
        self.sigmoid = nn.Sigmoid()
        
        # Linear function 
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        # linear function
        out = self.fc1(x)
        
        # Non linearity
        out = self.sigmoid(out)
        
        # Linear function readout
        out = self.fc2(out)
        
        return out
        
        


In [7]:
input_dim = 28*28
hidden_dim = 100
output_dim = 10
model = FeedForwardNeuralNetwork(input_dim, hidden_dim, output_dim)

In [8]:
# loss function
criterion = nn.CrossEntropyLoss()

In [9]:
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

<h3> Parameters in depth </h3>

In [10]:
print(model.parameters())

print(len(list(model.parameters())))

# hidden layer parameters
print(list(model.parameters())[0].size())

# layer 1 bias parameters
print(list(model.parameters())[1].size())

# FC layer 2 parameters
print(list(model.parameters())[2].size())

# FC bias 2 parameters
print(list(model.parameters())[3].size())


<generator object Module.parameters at 0x1225afb88>
4
torch.Size([100, 784])
torch.Size([100])
torch.Size([10, 100])
torch.Size([10])


In [11]:
a =torch.tensor([[1,2,3,4,5], [3,4,5,1,2]])
torch.max(a, 0)

torch.return_types.max(
values=tensor([3, 4, 5, 4, 5]),
indices=tensor([1, 1, 1, 0, 0]))

In [None]:
iter = 0
for epoch in range(n_iters):
    for i,(images, labels) in enumerate(train_loader):
        # load image as variable
        images = Variable(images.view(-1, 28*28))
        labels = Variable(labels)
        
        # clear the gradient
        optimizer.zero_grad()
        
        output = model(images)
        
        # calculate the loss
        loss = criterion(output, labels)
        
        # back propagate the loss
        loss.backward()
        
        # update gradient
        optimizer.step()
        
        iter += 1
        
        if iter%500 == 0:
            # calculate the accuracy
            correct = 0
            total = 0
            
            # Iterate through test dataset
            for images, labels in (test_loader):
                images = Variable(images.view(-1, 28*28))
                
                # Forward pass only to get logits
                outputs = model(images)
                
                # get predictions from the maximum value
                _, predictions = torch.max(outputs.data, 1)
                
                total += labels.size(0)
                
                correct += (predictions == labels).sum()
            
            accuracy = 100 * correct / total
            
            print("{} : iteractions loss : {} accuracy : {}".format(iter, loss, accuracy))
            
                
        

500 : iteractions loss : 0.45630669593811035 accuracy : 89
1000 : iteractions loss : 0.4580514430999756 accuracy : 90
1500 : iteractions loss : 0.25246304273605347 accuracy : 91
2000 : iteractions loss : 0.3666258156299591 accuracy : 91
2500 : iteractions loss : 0.1772182136774063 accuracy : 92
3000 : iteractions loss : 0.4024144411087036 accuracy : 92
3500 : iteractions loss : 0.22104434669017792 accuracy : 92
4000 : iteractions loss : 0.18438759446144104 accuracy : 92
4500 : iteractions loss : 0.11806178838014603 accuracy : 93
5000 : iteractions loss : 0.2891945540904999 accuracy : 93
5500 : iteractions loss : 0.22974300384521484 accuracy : 93
6000 : iteractions loss : 0.1519843339920044 accuracy : 93
6500 : iteractions loss : 0.1559676080942154 accuracy : 93
7000 : iteractions loss : 0.20107242465019226 accuracy : 94
7500 : iteractions loss : 0.23821835219860077 accuracy : 94
8000 : iteractions loss : 0.1865805685520172 accuracy : 94
8500 : iteractions loss : 0.18559791147708893 acc

68000 : iteractions loss : 0.030006852000951767 accuracy : 97
68500 : iteractions loss : 0.03552922606468201 accuracy : 97
69000 : iteractions loss : 0.04476220905780792 accuracy : 97
69500 : iteractions loss : 0.03285985812544823 accuracy : 97
70000 : iteractions loss : 0.04099128395318985 accuracy : 97
70500 : iteractions loss : 0.01690119318664074 accuracy : 97
71000 : iteractions loss : 0.010495185852050781 accuracy : 97
71500 : iteractions loss : 0.016708746552467346 accuracy : 97
72000 : iteractions loss : 0.02790401503443718 accuracy : 97
72500 : iteractions loss : 0.03359727934002876 accuracy : 97
73000 : iteractions loss : 0.021600885316729546 accuracy : 97
73500 : iteractions loss : 0.04425700008869171 accuracy : 97
74000 : iteractions loss : 0.029869861900806427 accuracy : 97
74500 : iteractions loss : 0.032776251435279846 accuracy : 97
75000 : iteractions loss : 0.028734512627124786 accuracy : 97
75500 : iteractions loss : 0.026692576706409454 accuracy : 97
76000 : iteracti

134000 : iteractions loss : 0.009069671854376793 accuracy : 97
134500 : iteractions loss : 0.024379540234804153 accuracy : 97
135000 : iteractions loss : 0.01609048806130886 accuracy : 97
135500 : iteractions loss : 0.012657937593758106 accuracy : 97
136000 : iteractions loss : 0.02653002180159092 accuracy : 97
136500 : iteractions loss : 0.023204375058412552 accuracy : 97
137000 : iteractions loss : 0.013562760315835476 accuracy : 97
137500 : iteractions loss : 0.005784902721643448 accuracy : 97
138000 : iteractions loss : 0.00947279017418623 accuracy : 97
138500 : iteractions loss : 0.010309887118637562 accuracy : 97
139000 : iteractions loss : 0.01418598648160696 accuracy : 97
139500 : iteractions loss : 0.009457430802285671 accuracy : 97
140000 : iteractions loss : 0.033252790570259094 accuracy : 97
140500 : iteractions loss : 0.03359850496053696 accuracy : 97
141000 : iteractions loss : 0.011823692359030247 accuracy : 97
141500 : iteractions loss : 0.004582958295941353 accuracy : 

199500 : iteractions loss : 0.005736451130360365 accuracy : 97
200000 : iteractions loss : 0.00697840703651309 accuracy : 97
200500 : iteractions loss : 0.05383594334125519 accuracy : 98
201000 : iteractions loss : 0.004776735324412584 accuracy : 98
201500 : iteractions loss : 0.004580287728458643 accuracy : 97
202000 : iteractions loss : 0.005619478411972523 accuracy : 98
202500 : iteractions loss : 0.0032481146045029163 accuracy : 97
203000 : iteractions loss : 0.0030958557035773993 accuracy : 98
203500 : iteractions loss : 0.008182420395314693 accuracy : 98
204000 : iteractions loss : 0.005643878132104874 accuracy : 98
204500 : iteractions loss : 0.010061120614409447 accuracy : 98
205000 : iteractions loss : 0.014548173174262047 accuracy : 98
205500 : iteractions loss : 0.007481088861823082 accuracy : 97
206000 : iteractions loss : 0.007104468531906605 accuracy : 97
206500 : iteractions loss : 0.00450787553563714 accuracy : 97
207000 : iteractions loss : 0.013545837253332138 accurac

264500 : iteractions loss : 0.004617924802005291 accuracy : 98
265000 : iteractions loss : 0.009287566877901554 accuracy : 98
265500 : iteractions loss : 0.0050895120948553085 accuracy : 98
266000 : iteractions loss : 0.005351042840629816 accuracy : 98
266500 : iteractions loss : 0.005031585693359375 accuracy : 98
267000 : iteractions loss : 0.00903852004557848 accuracy : 98
267500 : iteractions loss : 0.0032728195656090975 accuracy : 98
268000 : iteractions loss : 0.007026767823845148 accuracy : 98
268500 : iteractions loss : 0.01123734936118126 accuracy : 98
269000 : iteractions loss : 0.0024274445604532957 accuracy : 98
269500 : iteractions loss : 0.002072887495160103 accuracy : 98
270000 : iteractions loss : 0.006603493820875883 accuracy : 98
270500 : iteractions loss : 0.0037044240161776543 accuracy : 97
271000 : iteractions loss : 0.005933699663728476 accuracy : 98
271500 : iteractions loss : 0.002902393229305744 accuracy : 98
272000 : iteractions loss : 0.0017797374166548252 acc

329500 : iteractions loss : 0.003108501434326172 accuracy : 98
330000 : iteractions loss : 0.0051877498626708984 accuracy : 98
330500 : iteractions loss : 0.0025632763281464577 accuracy : 98
331000 : iteractions loss : 0.004686283878982067 accuracy : 98
331500 : iteractions loss : 0.004593224730342627 accuracy : 98
332000 : iteractions loss : 0.0020028972066938877 accuracy : 98
332500 : iteractions loss : 0.005038785748183727 accuracy : 97
333000 : iteractions loss : 0.0032105399295687675 accuracy : 98
333500 : iteractions loss : 0.0031856726855039597 accuracy : 97
334000 : iteractions loss : 0.004327511880546808 accuracy : 98
334500 : iteractions loss : 0.002520823385566473 accuracy : 98
335000 : iteractions loss : 0.0020316028967499733 accuracy : 97
335500 : iteractions loss : 0.0016888618702068925 accuracy : 98
336000 : iteractions loss : 0.004080362152308226 accuracy : 97
336500 : iteractions loss : 0.003279981669038534 accuracy : 98
337000 : iteractions loss : 0.003428506897762418

394500 : iteractions loss : 0.004315996076911688 accuracy : 98
395000 : iteractions loss : 0.0019147682469338179 accuracy : 98
395500 : iteractions loss : 0.0038886594120413065 accuracy : 98
396000 : iteractions loss : 0.0016119336942210793 accuracy : 98
396500 : iteractions loss : 0.003918905276805162 accuracy : 98
397000 : iteractions loss : 0.0008409118745476007 accuracy : 97
397500 : iteractions loss : 0.004039979074150324 accuracy : 98
398000 : iteractions loss : 0.004389381501823664 accuracy : 98
398500 : iteractions loss : 0.004571742843836546 accuracy : 98
399000 : iteractions loss : 0.0032983541022986174 accuracy : 97
399500 : iteractions loss : 0.001578369177877903 accuracy : 98
400000 : iteractions loss : 0.0023169517517089844 accuracy : 98
400500 : iteractions loss : 0.0018194913864135742 accuracy : 98
401000 : iteractions loss : 0.00224723806604743 accuracy : 97
401500 : iteractions loss : 0.002335252705961466 accuracy : 98
402000 : iteractions loss : 0.0010917282197624445