In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable

In [2]:
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

In [3]:
batch_size = 100
n_iters = 3000
num_epochs = n_iters / (len(train_dataset) / batch_size)
num_epochs = int(num_epochs)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

In [4]:
# create a model class 
# input->linear function->non linear function(sigmoid)->linear function->Softmax->CrossEntropy
class FeedFowardNeuralNetwork(nn.Module): 
    def __init__(self,input_dim,hidden_dim,output_dim):
        super(FeedFowardNeuralNetwork,self).__init__()
        # linear function 784->100
        self.fc1=nn.Linear(input_dim,hidden_dim)
        # Non linearity 1
        self.relu1 = nn.ReLU()
        
        # linear function 100->100
        self.fc2=nn.Linear(hidden_dim,hidden_dim)
        # Non linearity 2
        self.relu2 = nn.ReLU()
        
        # linear function (readout) 100->10
        self.fc3=nn.Linear(hidden_dim,output_dim)
        
        
    def forward(self,x): 
        out=self.fc1(x)
        out=self.relu1(out)
        out=self.fc2(out)
        out=self.relu2(out)
        out=self.fc3(out)
        return out 
    

In [5]:
# instantiate the model 
input_dim=28*28 
hidden_dim=100
output_dim=10 
model=FeedFowardNeuralNetwork(input_dim,hidden_dim,output_dim)

In [6]:
# instantiate the loss 
criterion=nn.CrossEntropyLoss()

In [7]:
# instatiate the optimize 
learning_rate=0.1
optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate)

In [8]:
print(len(list(model.parameters())))
print(list(model.parameters())[0].size())
print(list(model.parameters())[1].size())
print(list(model.parameters())[2].size())
print(list(model.parameters())[3].size())


6
torch.Size([100, 784])
torch.Size([100])
torch.Size([100, 100])
torch.Size([100])


In [9]:
# Training 
# convert input into variable 
iter=0 
for epochs in range(num_epochs):
    for i,(images,labels) in enumerate (train_loader): 
        images = Variable(images.view(-1,28*28))
        labels = Variable(labels)
        
    # clear the gradients
        optimizer.zero_grad()
    
    # foward pass to obtain the outputs
        outputs = model(images)
    
    #loss 
        loss = criterion(outputs,labels)
    
    #backpropagation 
        loss.backward()
        optimizer.step()
    
    
        iter+=1

        print('Iteration: {}. Loss: {}'.format(iter, loss.data[0]))

        
       

Iteration: 1. Loss: 2.295961856842041
Iteration: 2. Loss: 2.3094120025634766
Iteration: 3. Loss: 2.2979609966278076
Iteration: 4. Loss: 2.3090195655822754
Iteration: 5. Loss: 2.299222707748413
Iteration: 6. Loss: 2.2961041927337646
Iteration: 7. Loss: 2.296984910964966
Iteration: 8. Loss: 2.286041021347046
Iteration: 9. Loss: 2.2845804691314697
Iteration: 10. Loss: 2.2931692600250244
Iteration: 11. Loss: 2.2816431522369385
Iteration: 12. Loss: 2.2820863723754883
Iteration: 13. Loss: 2.2977688312530518
Iteration: 14. Loss: 2.2785065174102783
Iteration: 15. Loss: 2.286782741546631
Iteration: 16. Loss: 2.2735514640808105
Iteration: 17. Loss: 2.2741024494171143
Iteration: 18. Loss: 2.288369655609131
Iteration: 19. Loss: 2.261913537979126
Iteration: 20. Loss: 2.256812334060669
Iteration: 21. Loss: 2.2638347148895264
Iteration: 22. Loss: 2.258091688156128
Iteration: 23. Loss: 2.253525495529175
Iteration: 24. Loss: 2.263747453689575
Iteration: 25. Loss: 2.242072343826294
Iteration: 26. Loss: 

Iteration: 212. Loss: 0.52409827709198
Iteration: 213. Loss: 0.5392146706581116
Iteration: 214. Loss: 0.3523859679698944
Iteration: 215. Loss: 0.33588922023773193
Iteration: 216. Loss: 0.47253063321113586
Iteration: 217. Loss: 0.5903775095939636
Iteration: 218. Loss: 0.46703213453292847
Iteration: 219. Loss: 0.44318535923957825
Iteration: 220. Loss: 0.5025017857551575
Iteration: 221. Loss: 0.4242366552352905
Iteration: 222. Loss: 0.3857786953449249
Iteration: 223. Loss: 0.5491111278533936
Iteration: 224. Loss: 0.5383682250976562
Iteration: 225. Loss: 0.46171218156814575
Iteration: 226. Loss: 0.6248679161071777
Iteration: 227. Loss: 0.4019682705402374
Iteration: 228. Loss: 0.43347567319869995
Iteration: 229. Loss: 0.36112672090530396
Iteration: 230. Loss: 0.5524498224258423
Iteration: 231. Loss: 0.47621095180511475
Iteration: 232. Loss: 0.3991951644420624
Iteration: 233. Loss: 0.39684009552001953
Iteration: 234. Loss: 0.34384486079216003
Iteration: 235. Loss: 0.38935500383377075
Iterati

Iteration: 423. Loss: 0.38746702671051025
Iteration: 424. Loss: 0.3171207308769226
Iteration: 425. Loss: 0.3048940896987915
Iteration: 426. Loss: 0.36534592509269714
Iteration: 427. Loss: 0.23228545486927032
Iteration: 428. Loss: 0.3460231423377991
Iteration: 429. Loss: 0.2580827474594116
Iteration: 430. Loss: 0.18930314481258392
Iteration: 431. Loss: 0.36882269382476807
Iteration: 432. Loss: 0.488783061504364
Iteration: 433. Loss: 0.36127251386642456
Iteration: 434. Loss: 0.2713894546031952
Iteration: 435. Loss: 0.3811684846878052
Iteration: 436. Loss: 0.290493905544281
Iteration: 437. Loss: 0.3426871597766876
Iteration: 438. Loss: 0.5136816501617432
Iteration: 439. Loss: 0.3308650553226471
Iteration: 440. Loss: 0.38806259632110596
Iteration: 441. Loss: 0.3339191675186157
Iteration: 442. Loss: 0.38413456082344055
Iteration: 443. Loss: 0.27067622542381287
Iteration: 444. Loss: 0.21475204825401306
Iteration: 445. Loss: 0.58870929479599
Iteration: 446. Loss: 0.28163257241249084
Iteration

Iteration: 625. Loss: 0.24653130769729614
Iteration: 626. Loss: 0.33561062812805176
Iteration: 627. Loss: 0.3062237799167633
Iteration: 628. Loss: 0.40125054121017456
Iteration: 629. Loss: 0.21376147866249084
Iteration: 630. Loss: 0.2898998558521271
Iteration: 631. Loss: 0.3975120186805725
Iteration: 632. Loss: 0.25683584809303284
Iteration: 633. Loss: 0.24584345519542694
Iteration: 634. Loss: 0.26782599091529846
Iteration: 635. Loss: 0.24524062871932983
Iteration: 636. Loss: 0.26072898507118225
Iteration: 637. Loss: 0.30389198660850525
Iteration: 638. Loss: 0.3211441934108734
Iteration: 639. Loss: 0.18591050803661346
Iteration: 640. Loss: 0.45671990513801575
Iteration: 641. Loss: 0.20203159749507904
Iteration: 642. Loss: 0.22813577950000763
Iteration: 643. Loss: 0.3572346568107605
Iteration: 644. Loss: 0.32393860816955566
Iteration: 645. Loss: 0.12171108275651932
Iteration: 646. Loss: 0.21185527741909027
Iteration: 647. Loss: 0.29927897453308105
Iteration: 648. Loss: 0.255486786365509

Iteration: 841. Loss: 0.46300697326660156
Iteration: 842. Loss: 0.2518138587474823
Iteration: 843. Loss: 0.3501167297363281
Iteration: 844. Loss: 0.20732203125953674
Iteration: 845. Loss: 0.26279416680336
Iteration: 846. Loss: 0.28035590052604675
Iteration: 847. Loss: 0.2462201565504074
Iteration: 848. Loss: 0.25746557116508484
Iteration: 849. Loss: 0.1530621349811554
Iteration: 850. Loss: 0.18941211700439453
Iteration: 851. Loss: 0.2710011601448059
Iteration: 852. Loss: 0.18777970969676971
Iteration: 853. Loss: 0.3942946493625641
Iteration: 854. Loss: 0.18894244730472565
Iteration: 855. Loss: 0.22594988346099854
Iteration: 856. Loss: 0.35841119289398193
Iteration: 857. Loss: 0.1709022969007492
Iteration: 858. Loss: 0.19502025842666626
Iteration: 859. Loss: 0.26366207003593445
Iteration: 860. Loss: 0.29390618205070496
Iteration: 861. Loss: 0.3081251084804535
Iteration: 862. Loss: 0.23879389464855194
Iteration: 863. Loss: 0.33630838990211487
Iteration: 864. Loss: 0.29148757457733154
Ite

Iteration: 1050. Loss: 0.46925172209739685
Iteration: 1051. Loss: 0.26398342847824097
Iteration: 1052. Loss: 0.22837619483470917
Iteration: 1053. Loss: 0.21231380105018616
Iteration: 1054. Loss: 0.3605397045612335
Iteration: 1055. Loss: 0.27589523792266846
Iteration: 1056. Loss: 0.17517417669296265
Iteration: 1057. Loss: 0.11717342585325241
Iteration: 1058. Loss: 0.17210422456264496
Iteration: 1059. Loss: 0.33676743507385254
Iteration: 1060. Loss: 0.18140295147895813
Iteration: 1061. Loss: 0.3245254456996918
Iteration: 1062. Loss: 0.2031809240579605
Iteration: 1063. Loss: 0.23226088285446167
Iteration: 1064. Loss: 0.3168601393699646
Iteration: 1065. Loss: 0.4085843563079834
Iteration: 1066. Loss: 0.19557586312294006
Iteration: 1067. Loss: 0.23111620545387268
Iteration: 1068. Loss: 0.1688891053199768
Iteration: 1069. Loss: 0.16908077895641327
Iteration: 1070. Loss: 0.247115358710289
Iteration: 1071. Loss: 0.19905158877372742
Iteration: 1072. Loss: 0.1406765729188919
Iteration: 1073. Los

Iteration: 1243. Loss: 0.2300473153591156
Iteration: 1244. Loss: 0.19182845950126648
Iteration: 1245. Loss: 0.2804933786392212
Iteration: 1246. Loss: 0.22534990310668945
Iteration: 1247. Loss: 0.1606300324201584
Iteration: 1248. Loss: 0.21813608705997467
Iteration: 1249. Loss: 0.2261546552181244
Iteration: 1250. Loss: 0.23019638657569885
Iteration: 1251. Loss: 0.16906596720218658
Iteration: 1252. Loss: 0.21594925224781036
Iteration: 1253. Loss: 0.326426237821579
Iteration: 1254. Loss: 0.2085643708705902
Iteration: 1255. Loss: 0.20136278867721558
Iteration: 1256. Loss: 0.2671079933643341
Iteration: 1257. Loss: 0.08294112235307693
Iteration: 1258. Loss: 0.31445106863975525
Iteration: 1259. Loss: 0.2343633621931076
Iteration: 1260. Loss: 0.22963634133338928
Iteration: 1261. Loss: 0.29970982670783997
Iteration: 1262. Loss: 0.22053951025009155
Iteration: 1263. Loss: 0.13376396894454956
Iteration: 1264. Loss: 0.15164993703365326
Iteration: 1265. Loss: 0.17599809169769287
Iteration: 1266. Los

Iteration: 1455. Loss: 0.20864222943782806
Iteration: 1456. Loss: 0.1407012641429901
Iteration: 1457. Loss: 0.14702217280864716
Iteration: 1458. Loss: 0.12444231659173965
Iteration: 1459. Loss: 0.16282939910888672
Iteration: 1460. Loss: 0.14273768663406372
Iteration: 1461. Loss: 0.18562918901443481
Iteration: 1462. Loss: 0.12535293400287628
Iteration: 1463. Loss: 0.09136098623275757
Iteration: 1464. Loss: 0.237280011177063
Iteration: 1465. Loss: 0.1835484355688095
Iteration: 1466. Loss: 0.13793104887008667
Iteration: 1467. Loss: 0.16291353106498718
Iteration: 1468. Loss: 0.2144044041633606
Iteration: 1469. Loss: 0.22938823699951172
Iteration: 1470. Loss: 0.09495945274829865
Iteration: 1471. Loss: 0.16716113686561584
Iteration: 1472. Loss: 0.21421998739242554
Iteration: 1473. Loss: 0.2570391893386841
Iteration: 1474. Loss: 0.2102416604757309
Iteration: 1475. Loss: 0.21700622141361237
Iteration: 1476. Loss: 0.25264301896095276
Iteration: 1477. Loss: 0.21531341969966888
Iteration: 1478. L

Iteration: 1666. Loss: 0.07227207720279694
Iteration: 1667. Loss: 0.18023069202899933
Iteration: 1668. Loss: 0.16704846918582916
Iteration: 1669. Loss: 0.22658509016036987
Iteration: 1670. Loss: 0.18165569007396698
Iteration: 1671. Loss: 0.13240481913089752
Iteration: 1672. Loss: 0.14248567819595337
Iteration: 1673. Loss: 0.12971214950084686
Iteration: 1674. Loss: 0.15662698447704315
Iteration: 1675. Loss: 0.16601881384849548
Iteration: 1676. Loss: 0.2821389138698578
Iteration: 1677. Loss: 0.1418754607439041
Iteration: 1678. Loss: 0.15808002650737762
Iteration: 1679. Loss: 0.10914373397827148
Iteration: 1680. Loss: 0.18767137825489044
Iteration: 1681. Loss: 0.04923393204808235
Iteration: 1682. Loss: 0.12128537893295288
Iteration: 1683. Loss: 0.1040138527750969
Iteration: 1684. Loss: 0.1733105629682541
Iteration: 1685. Loss: 0.21557939052581787
Iteration: 1686. Loss: 0.3073490560054779
Iteration: 1687. Loss: 0.15158458054065704
Iteration: 1688. Loss: 0.14661437273025513
Iteration: 1689.

Iteration: 1876. Loss: 0.13683147728443146
Iteration: 1877. Loss: 0.2753555476665497
Iteration: 1878. Loss: 0.13018523156642914
Iteration: 1879. Loss: 0.19221292436122894
Iteration: 1880. Loss: 0.1189299002289772
Iteration: 1881. Loss: 0.14850859344005585
Iteration: 1882. Loss: 0.09713983535766602
Iteration: 1883. Loss: 0.10660839825868607
Iteration: 1884. Loss: 0.20686225593090057
Iteration: 1885. Loss: 0.07593448460102081
Iteration: 1886. Loss: 0.1023217961192131
Iteration: 1887. Loss: 0.06173054873943329
Iteration: 1888. Loss: 0.06545636802911758
Iteration: 1889. Loss: 0.15457206964492798
Iteration: 1890. Loss: 0.11386984586715698
Iteration: 1891. Loss: 0.21201211214065552
Iteration: 1892. Loss: 0.17151764035224915
Iteration: 1893. Loss: 0.14718057215213776
Iteration: 1894. Loss: 0.17441509664058685
Iteration: 1895. Loss: 0.1346878856420517
Iteration: 1896. Loss: 0.07990417629480362
Iteration: 1897. Loss: 0.1362656056880951
Iteration: 1898. Loss: 0.133161723613739
Iteration: 1899. L

Iteration: 2087. Loss: 0.1403389722108841
Iteration: 2088. Loss: 0.13842716813087463
Iteration: 2089. Loss: 0.3718460500240326
Iteration: 2090. Loss: 0.12884239852428436
Iteration: 2091. Loss: 0.3041526973247528
Iteration: 2092. Loss: 0.18910104036331177
Iteration: 2093. Loss: 0.12602640688419342
Iteration: 2094. Loss: 0.08410481363534927
Iteration: 2095. Loss: 0.1924135833978653
Iteration: 2096. Loss: 0.1448967009782791
Iteration: 2097. Loss: 0.2656461298465729
Iteration: 2098. Loss: 0.09655580669641495
Iteration: 2099. Loss: 0.12586528062820435
Iteration: 2100. Loss: 0.051081761717796326
Iteration: 2101. Loss: 0.10117034614086151
Iteration: 2102. Loss: 0.07454347610473633
Iteration: 2103. Loss: 0.20581063628196716
Iteration: 2104. Loss: 0.1624574512243271
Iteration: 2105. Loss: 0.17946098744869232
Iteration: 2106. Loss: 0.20175248384475708
Iteration: 2107. Loss: 0.14977926015853882
Iteration: 2108. Loss: 0.09532246738672256
Iteration: 2109. Loss: 0.2872868776321411
Iteration: 2110. L

Iteration: 2292. Loss: 0.06648117303848267
Iteration: 2293. Loss: 0.2550830543041229
Iteration: 2294. Loss: 0.1166798323392868
Iteration: 2295. Loss: 0.08113935589790344
Iteration: 2296. Loss: 0.1477367877960205
Iteration: 2297. Loss: 0.1811293661594391
Iteration: 2298. Loss: 0.09806524217128754
Iteration: 2299. Loss: 0.09435933828353882
Iteration: 2300. Loss: 0.10829871147871017
Iteration: 2301. Loss: 0.07827803492546082
Iteration: 2302. Loss: 0.11387621611356735
Iteration: 2303. Loss: 0.06160937622189522
Iteration: 2304. Loss: 0.11121319979429245
Iteration: 2305. Loss: 0.2334643304347992
Iteration: 2306. Loss: 0.14312410354614258
Iteration: 2307. Loss: 0.24361161887645721
Iteration: 2308. Loss: 0.12084422260522842
Iteration: 2309. Loss: 0.07328357547521591
Iteration: 2310. Loss: 0.17038924992084503
Iteration: 2311. Loss: 0.0622536763548851
Iteration: 2312. Loss: 0.19865824282169342
Iteration: 2313. Loss: 0.12760259211063385
Iteration: 2314. Loss: 0.06896121799945831
Iteration: 2315. 

Iteration: 2498. Loss: 0.17799502611160278
Iteration: 2499. Loss: 0.07634835690259933
Iteration: 2500. Loss: 0.06915847957134247
Iteration: 2501. Loss: 0.15350954234600067
Iteration: 2502. Loss: 0.04725457355380058
Iteration: 2503. Loss: 0.09147051721811295
Iteration: 2504. Loss: 0.12137772887945175
Iteration: 2505. Loss: 0.10118475556373596
Iteration: 2506. Loss: 0.13528773188591003
Iteration: 2507. Loss: 0.16166667640209198
Iteration: 2508. Loss: 0.06342534720897675
Iteration: 2509. Loss: 0.09067617356777191
Iteration: 2510. Loss: 0.13762634992599487
Iteration: 2511. Loss: 0.1289798617362976
Iteration: 2512. Loss: 0.13840621709823608
Iteration: 2513. Loss: 0.17919813096523285
Iteration: 2514. Loss: 0.14740075170993805
Iteration: 2515. Loss: 0.05432223156094551
Iteration: 2516. Loss: 0.13398195803165436
Iteration: 2517. Loss: 0.17174652218818665
Iteration: 2518. Loss: 0.13377691805362701
Iteration: 2519. Loss: 0.07290752977132797
Iteration: 2520. Loss: 0.12873274087905884
Iteration: 2

Iteration: 2694. Loss: 0.10947572439908981
Iteration: 2695. Loss: 0.12452099472284317
Iteration: 2696. Loss: 0.11676749587059021
Iteration: 2697. Loss: 0.06193318963050842
Iteration: 2698. Loss: 0.09527101367712021
Iteration: 2699. Loss: 0.22313323616981506
Iteration: 2700. Loss: 0.10393037647008896
Iteration: 2701. Loss: 0.13034778833389282
Iteration: 2702. Loss: 0.11037073284387589
Iteration: 2703. Loss: 0.09780462086200714
Iteration: 2704. Loss: 0.033624038100242615
Iteration: 2705. Loss: 0.06568307429552078
Iteration: 2706. Loss: 0.06719312071800232
Iteration: 2707. Loss: 0.03523131087422371
Iteration: 2708. Loss: 0.08525115251541138
Iteration: 2709. Loss: 0.08045417815446854
Iteration: 2710. Loss: 0.040569305419921875
Iteration: 2711. Loss: 0.06596865504980087
Iteration: 2712. Loss: 0.06525203585624695
Iteration: 2713. Loss: 0.2270570546388626
Iteration: 2714. Loss: 0.12099488079547882
Iteration: 2715. Loss: 0.05416035279631615
Iteration: 2716. Loss: 0.13030306994915009
Iteration:

Iteration: 2887. Loss: 0.043224919587373734
Iteration: 2888. Loss: 0.10188210755586624
Iteration: 2889. Loss: 0.1715763509273529
Iteration: 2890. Loss: 0.060319457203149796
Iteration: 2891. Loss: 0.07966947555541992
Iteration: 2892. Loss: 0.09946151077747345
Iteration: 2893. Loss: 0.08015517145395279
Iteration: 2894. Loss: 0.11153356730937958
Iteration: 2895. Loss: 0.15279501676559448
Iteration: 2896. Loss: 0.14878256618976593
Iteration: 2897. Loss: 0.06893664598464966
Iteration: 2898. Loss: 0.11635011434555054
Iteration: 2899. Loss: 0.12776978313922882
Iteration: 2900. Loss: 0.032608289271593094
Iteration: 2901. Loss: 0.09097747504711151
Iteration: 2902. Loss: 0.09696827828884125
Iteration: 2903. Loss: 0.17596352100372314
Iteration: 2904. Loss: 0.1967346966266632
Iteration: 2905. Loss: 0.0831632986664772
Iteration: 2906. Loss: 0.2039925456047058
Iteration: 2907. Loss: 0.07018424570560455
Iteration: 2908. Loss: 0.09525808691978455
Iteration: 2909. Loss: 0.17678146064281464
Iteration: 2

In [10]:
if iter %500 == 0:
            # Calculate Accuracy         
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
               
                images = Variable(images.view(-1, 28*28))
                
                # Forward pass only to get logits/output
                outputs = model(images)
                
                # Get predictions from the maximum value
                _, predicted = torch.max(outputs.data, 1)
                
                # Total number of labels
                total += labels.size(0)
                
                correct += (predicted.cpu() == labels.cpu()).sum()
            
            accuracy = 100 * correct / total
        
            # Print Loss
            print('%Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.data[0], accuracy))

%Iteration: 3000. Loss: 0.05800364911556244. Accuracy: 96.65
