In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable

In [2]:
# MNIST training split; download=True fetches the files into ./data when absent.
# ToTensor() converts each image to a torch tensor.
train_dataset = dsets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# MNIST test split, read from the same root directory (no download requested).
test_dataset = dsets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

In [3]:
# Training budget: derive the number of whole epochs needed to run roughly
# `n_iters` mini-batch updates at the chosen batch size.
batch_size = 100
n_iters = 3000
batches_per_epoch = len(train_dataset) / batch_size
num_epochs = int(n_iters / batches_per_epoch)

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Test batches keep the dataset order.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [4]:
# Model definition.
# Pipeline: input -> linear -> ReLU -> linear -> logits.
# The network returns raw logits; nn.CrossEntropyLoss (instantiated below)
# is applied to them as the training criterion.
class FeedFowardNeuralNetwork(nn.Module):
    """Feed-forward classifier with a single ReLU hidden layer.

    Args:
        input_dim: number of features in each input row.
        hidden_dim: number of units in the hidden layer.
        output_dim: number of output scores produced per input row.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Input -> hidden projection.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # Non-linearity applied to the hidden activations.
        self.relu = nn.ReLU()
        # Hidden -> output projection.
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the fc2 output (logits) for input `x`."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
    

In [5]:
# Build the network: 784 input features (28*28), 100 hidden units,
# 10 output scores.
input_dim = 28 * 28
hidden_dim = 100
output_dim = 10
model = FeedFowardNeuralNetwork(
    input_dim,
    hidden_dim,
    output_dim,
)

In [6]:
# Loss function: cross-entropy, later called as criterion(outputs, labels).
criterion = nn.CrossEntropyLoss()

In [7]:
# Instantiate the optimizer: plain SGD over all model parameters.
learning_rate = 0.1
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
)

In [8]:
# Inspect the learnable parameters: print how many tensors there are,
# then the size of each of the first four.
model_params = list(model.parameters())
print(len(model_params))
for idx in range(4):
    print(model_params[idx].size())


4
torch.Size([100, 784])
torch.Size([100])
torch.Size([10, 100])
torch.Size([10])


In [9]:
# Training: mini-batch SGD over the training set for `num_epochs` epochs.
# NOTE: `iter` shadows the Python builtin of the same name; the name is kept
# because the evaluation cell further down reads it.
iter = 0
for epochs in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten every image in the batch into a 784-element row (28*28)
        # to match the input size of the first linear layer. Tensors are
        # passed directly; the Variable wrapper is a deprecated no-op.
        images = images.view(-1, 28*28)

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass to obtain the outputs.
        outputs = model(images)

        # Loss between the outputs and the labels.
        loss = criterion(outputs, labels)

        # Backpropagation, then parameter update.
        loss.backward()
        optimizer.step()

        iter += 1

        # loss is a 0-dim tensor: .item() extracts the Python float.
        # The old loss.data[0] indexing raises IndexError on PyTorch >= 0.5.
        print('Iteration: {}. Loss: {}'.format(iter, loss.item()))

        
       

Iteration: 1. Loss: 2.304568290710449
Iteration: 2. Loss: 2.3007898330688477
Iteration: 3. Loss: 2.2684624195098877
Iteration: 4. Loss: 2.2446839809417725
Iteration: 5. Loss: 2.231877088546753
Iteration: 6. Loss: 2.2281017303466797
Iteration: 7. Loss: 2.208669662475586
Iteration: 8. Loss: 2.182025909423828
Iteration: 9. Loss: 2.18385648727417
Iteration: 10. Loss: 2.1456480026245117
Iteration: 11. Loss: 2.1451170444488525
Iteration: 12. Loss: 2.0888419151306152
Iteration: 13. Loss: 2.1040329933166504
Iteration: 14. Loss: 2.053187608718872
Iteration: 15. Loss: 2.0388927459716797
Iteration: 16. Loss: 2.0337793827056885
Iteration: 17. Loss: 1.968064546585083
Iteration: 18. Loss: 1.9469878673553467
Iteration: 19. Loss: 1.9215055704116821
Iteration: 20. Loss: 1.8813903331756592
Iteration: 21. Loss: 1.8320906162261963
Iteration: 22. Loss: 1.8658249378204346
Iteration: 23. Loss: 1.8529778718948364
Iteration: 24. Loss: 1.8202382326126099
Iteration: 25. Loss: 1.8139169216156006
Iteration: 26. Lo

Iteration: 212. Loss: 0.47948750853538513
Iteration: 213. Loss: 0.335465669631958
Iteration: 214. Loss: 0.3196067810058594
Iteration: 215. Loss: 0.3712795674800873
Iteration: 216. Loss: 0.46826252341270447
Iteration: 217. Loss: 0.5317886471748352
Iteration: 218. Loss: 0.4105917811393738
Iteration: 219. Loss: 0.4404526948928833
Iteration: 220. Loss: 0.35654938220977783
Iteration: 221. Loss: 0.3082595467567444
Iteration: 222. Loss: 0.6001537442207336
Iteration: 223. Loss: 0.38756850361824036
Iteration: 224. Loss: 0.44036856293678284
Iteration: 225. Loss: 0.30504733324050903
Iteration: 226. Loss: 0.5231377482414246
Iteration: 227. Loss: 0.3859824240207672
Iteration: 228. Loss: 0.3989296853542328
Iteration: 229. Loss: 0.3993067145347595
Iteration: 230. Loss: 0.31584668159484863
Iteration: 231. Loss: 0.4281332492828369
Iteration: 232. Loss: 0.26937952637672424
Iteration: 233. Loss: 0.5000604391098022
Iteration: 234. Loss: 0.39248421788215637
Iteration: 235. Loss: 0.505332887172699
Iteration

Iteration: 415. Loss: 0.24716119468212128
Iteration: 416. Loss: 0.2716721296310425
Iteration: 417. Loss: 0.346912145614624
Iteration: 418. Loss: 0.21301555633544922
Iteration: 419. Loss: 0.5428553223609924
Iteration: 420. Loss: 0.42069634795188904
Iteration: 421. Loss: 0.41925084590911865
Iteration: 422. Loss: 0.3702988028526306
Iteration: 423. Loss: 0.23930507898330688
Iteration: 424. Loss: 0.4433562159538269
Iteration: 425. Loss: 0.31293830275535583
Iteration: 426. Loss: 0.3285143971443176
Iteration: 427. Loss: 0.2440948486328125
Iteration: 428. Loss: 0.41964712738990784
Iteration: 429. Loss: 0.3217262625694275
Iteration: 430. Loss: 0.2937425374984741
Iteration: 431. Loss: 0.3183949887752533
Iteration: 432. Loss: 0.35360658168792725
Iteration: 433. Loss: 0.24391412734985352
Iteration: 434. Loss: 0.4350854456424713
Iteration: 435. Loss: 0.2707177996635437
Iteration: 436. Loss: 0.42165032029151917
Iteration: 437. Loss: 0.2570733428001404
Iteration: 438. Loss: 0.20202063024044037
Iterat

Iteration: 612. Loss: 0.16440154612064362
Iteration: 613. Loss: 0.39137348532676697
Iteration: 614. Loss: 0.3546876609325409
Iteration: 615. Loss: 0.24832050502300262
Iteration: 616. Loss: 0.3685187101364136
Iteration: 617. Loss: 0.2993236780166626
Iteration: 618. Loss: 0.2928968369960785
Iteration: 619. Loss: 0.30072033405303955
Iteration: 620. Loss: 0.2824627459049225
Iteration: 621. Loss: 0.34662628173828125
Iteration: 622. Loss: 0.3142697811126709
Iteration: 623. Loss: 0.22728778421878815
Iteration: 624. Loss: 0.4158063530921936
Iteration: 625. Loss: 0.2628464996814728
Iteration: 626. Loss: 0.3587667942047119
Iteration: 627. Loss: 0.19639061391353607
Iteration: 628. Loss: 0.3282468020915985
Iteration: 629. Loss: 0.370172917842865
Iteration: 630. Loss: 0.3250950872898102
Iteration: 631. Loss: 0.23230987787246704
Iteration: 632. Loss: 0.4244445860385895
Iteration: 633. Loss: 0.24572056531906128
Iteration: 634. Loss: 0.2186688929796219
Iteration: 635. Loss: 0.2472248524427414
Iteratio

Iteration: 824. Loss: 0.24659128487110138
Iteration: 825. Loss: 0.4101550281047821
Iteration: 826. Loss: 0.31281226873397827
Iteration: 827. Loss: 0.27498817443847656
Iteration: 828. Loss: 0.30572569370269775
Iteration: 829. Loss: 0.14504849910736084
Iteration: 830. Loss: 0.25713661313056946
Iteration: 831. Loss: 0.26928916573524475
Iteration: 832. Loss: 0.258341521024704
Iteration: 833. Loss: 0.3603942096233368
Iteration: 834. Loss: 0.3328808844089508
Iteration: 835. Loss: 0.3000672459602356
Iteration: 836. Loss: 0.25485312938690186
Iteration: 837. Loss: 0.21895742416381836
Iteration: 838. Loss: 0.20827454328536987
Iteration: 839. Loss: 0.4009447395801544
Iteration: 840. Loss: 0.17800608277320862
Iteration: 841. Loss: 0.3095327913761139
Iteration: 842. Loss: 0.19958843290805817
Iteration: 843. Loss: 0.5074097514152527
Iteration: 844. Loss: 0.40717679262161255
Iteration: 845. Loss: 0.4204385280609131
Iteration: 846. Loss: 0.2891896963119507
Iteration: 847. Loss: 0.3133215606212616
Iter

Iteration: 1028. Loss: 0.37624451518058777
Iteration: 1029. Loss: 0.29608094692230225
Iteration: 1030. Loss: 0.17548829317092896
Iteration: 1031. Loss: 0.21877221763134003
Iteration: 1032. Loss: 0.2381265014410019
Iteration: 1033. Loss: 0.35741522908210754
Iteration: 1034. Loss: 0.15997014939785004
Iteration: 1035. Loss: 0.28210222721099854
Iteration: 1036. Loss: 0.3225540220737457
Iteration: 1037. Loss: 0.45175105333328247
Iteration: 1038. Loss: 0.33664214611053467
Iteration: 1039. Loss: 0.15884919464588165
Iteration: 1040. Loss: 0.3041130006313324
Iteration: 1041. Loss: 0.22215613722801208
Iteration: 1042. Loss: 0.39013031125068665
Iteration: 1043. Loss: 0.22709445655345917
Iteration: 1044. Loss: 0.18047964572906494
Iteration: 1045. Loss: 0.26782429218292236
Iteration: 1046. Loss: 0.2682103216648102
Iteration: 1047. Loss: 0.26157423853874207
Iteration: 1048. Loss: 0.23561958968639374
Iteration: 1049. Loss: 0.26627910137176514
Iteration: 1050. Loss: 0.29873332381248474
Iteration: 1051

Iteration: 1225. Loss: 0.3377651274204254
Iteration: 1226. Loss: 0.1510055661201477
Iteration: 1227. Loss: 0.13750514388084412
Iteration: 1228. Loss: 0.24227018654346466
Iteration: 1229. Loss: 0.20090751349925995
Iteration: 1230. Loss: 0.3429960608482361
Iteration: 1231. Loss: 0.22574740648269653
Iteration: 1232. Loss: 0.18681031465530396
Iteration: 1233. Loss: 0.14235299825668335
Iteration: 1234. Loss: 0.17668873071670532
Iteration: 1235. Loss: 0.37574100494384766
Iteration: 1236. Loss: 0.3932279348373413
Iteration: 1237. Loss: 0.2977563142776489
Iteration: 1238. Loss: 0.3750622868537903
Iteration: 1239. Loss: 0.15460847318172455
Iteration: 1240. Loss: 0.18708282709121704
Iteration: 1241. Loss: 0.1341070979833603
Iteration: 1242. Loss: 0.2085917890071869
Iteration: 1243. Loss: 0.37515684962272644
Iteration: 1244. Loss: 0.24393826723098755
Iteration: 1245. Loss: 0.4149033725261688
Iteration: 1246. Loss: 0.1494196057319641
Iteration: 1247. Loss: 0.23583900928497314
Iteration: 1248. Loss

Iteration: 1419. Loss: 0.20645734667778015
Iteration: 1420. Loss: 0.12919531762599945
Iteration: 1421. Loss: 0.3289169669151306
Iteration: 1422. Loss: 0.18347862362861633
Iteration: 1423. Loss: 0.2419375479221344
Iteration: 1424. Loss: 0.20907045900821686
Iteration: 1425. Loss: 0.28089961409568787
Iteration: 1426. Loss: 0.28641271591186523
Iteration: 1427. Loss: 0.3515685200691223
Iteration: 1428. Loss: 0.2067272961139679
Iteration: 1429. Loss: 0.1501569300889969
Iteration: 1430. Loss: 0.19468531012535095
Iteration: 1431. Loss: 0.30890700221061707
Iteration: 1432. Loss: 0.21185246109962463
Iteration: 1433. Loss: 0.18643276393413544
Iteration: 1434. Loss: 0.3772789239883423
Iteration: 1435. Loss: 0.13562460243701935
Iteration: 1436. Loss: 0.2890397608280182
Iteration: 1437. Loss: 0.2752493619918823
Iteration: 1438. Loss: 0.1537800133228302
Iteration: 1439. Loss: 0.22531431913375854
Iteration: 1440. Loss: 0.20611946284770966
Iteration: 1441. Loss: 0.1690521091222763
Iteration: 1442. Loss

Iteration: 1612. Loss: 0.1413833051919937
Iteration: 1613. Loss: 0.26445379853248596
Iteration: 1614. Loss: 0.1399165391921997
Iteration: 1615. Loss: 0.17188526690006256
Iteration: 1616. Loss: 0.07198075950145721
Iteration: 1617. Loss: 0.16511230170726776
Iteration: 1618. Loss: 0.19421285390853882
Iteration: 1619. Loss: 0.08826715499162674
Iteration: 1620. Loss: 0.11019890755414963
Iteration: 1621. Loss: 0.1990506798028946
Iteration: 1622. Loss: 0.3915640711784363
Iteration: 1623. Loss: 0.22777879238128662
Iteration: 1624. Loss: 0.20724527537822723
Iteration: 1625. Loss: 0.23680754005908966
Iteration: 1626. Loss: 0.29040470719337463
Iteration: 1627. Loss: 0.07336445152759552
Iteration: 1628. Loss: 0.09537165611982346
Iteration: 1629. Loss: 0.0965089201927185
Iteration: 1630. Loss: 0.2868637144565582
Iteration: 1631. Loss: 0.3140818178653717
Iteration: 1632. Loss: 0.2618210017681122
Iteration: 1633. Loss: 0.20459431409835815
Iteration: 1634. Loss: 0.1821574568748474
Iteration: 1635. Los

Iteration: 1810. Loss: 0.2607179880142212
Iteration: 1811. Loss: 0.15278324484825134
Iteration: 1812. Loss: 0.1364121288061142
Iteration: 1813. Loss: 0.2074587196111679
Iteration: 1814. Loss: 0.2813507318496704
Iteration: 1815. Loss: 0.2465590536594391
Iteration: 1816. Loss: 0.169688880443573
Iteration: 1817. Loss: 0.11161284148693085
Iteration: 1818. Loss: 0.11925289779901505
Iteration: 1819. Loss: 0.18854796886444092
Iteration: 1820. Loss: 0.14992386102676392
Iteration: 1821. Loss: 0.2212287336587906
Iteration: 1822. Loss: 0.2436596155166626
Iteration: 1823. Loss: 0.28311625123023987
Iteration: 1824. Loss: 0.30323129892349243
Iteration: 1825. Loss: 0.08380066603422165
Iteration: 1826. Loss: 0.19997285306453705
Iteration: 1827. Loss: 0.16798517107963562
Iteration: 1828. Loss: 0.16840948164463043
Iteration: 1829. Loss: 0.1575298309326172
Iteration: 1830. Loss: 0.11741120368242264
Iteration: 1831. Loss: 0.2104654461145401
Iteration: 1832. Loss: 0.1755039244890213
Iteration: 1833. Loss: 

Iteration: 2008. Loss: 0.12249980866909027
Iteration: 2009. Loss: 0.19577422738075256
Iteration: 2010. Loss: 0.09928983449935913
Iteration: 2011. Loss: 0.14621582627296448
Iteration: 2012. Loss: 0.16118916869163513
Iteration: 2013. Loss: 0.31068405508995056
Iteration: 2014. Loss: 0.17626141011714935
Iteration: 2015. Loss: 0.14702481031417847
Iteration: 2016. Loss: 0.050419557839632034
Iteration: 2017. Loss: 0.3290940821170807
Iteration: 2018. Loss: 0.16454319655895233
Iteration: 2019. Loss: 0.19786600768566132
Iteration: 2020. Loss: 0.19661098718643188
Iteration: 2021. Loss: 0.2417369782924652
Iteration: 2022. Loss: 0.1504179835319519
Iteration: 2023. Loss: 0.14495201408863068
Iteration: 2024. Loss: 0.14689326286315918
Iteration: 2025. Loss: 0.2575259804725647
Iteration: 2026. Loss: 0.19287458062171936
Iteration: 2027. Loss: 0.16580243408679962
Iteration: 2028. Loss: 0.09949758648872375
Iteration: 2029. Loss: 0.0668022483587265
Iteration: 2030. Loss: 0.15115393698215485
Iteration: 2031

Iteration: 2208. Loss: 0.22179128229618073
Iteration: 2209. Loss: 0.20823746919631958
Iteration: 2210. Loss: 0.14986364543437958
Iteration: 2211. Loss: 0.204135000705719
Iteration: 2212. Loss: 0.17678675055503845
Iteration: 2213. Loss: 0.18532438576221466
Iteration: 2214. Loss: 0.15293288230895996
Iteration: 2215. Loss: 0.21967476606369019
Iteration: 2216. Loss: 0.20126764476299286
Iteration: 2217. Loss: 0.17531976103782654
Iteration: 2218. Loss: 0.17303365468978882
Iteration: 2219. Loss: 0.2018737941980362
Iteration: 2220. Loss: 0.21257320046424866
Iteration: 2221. Loss: 0.1797393560409546
Iteration: 2222. Loss: 0.3134385645389557
Iteration: 2223. Loss: 0.0722198486328125
Iteration: 2224. Loss: 0.1861933171749115
Iteration: 2225. Loss: 0.18223509192466736
Iteration: 2226. Loss: 0.24545682966709137
Iteration: 2227. Loss: 0.14781539142131805
Iteration: 2228. Loss: 0.13433457911014557
Iteration: 2229. Loss: 0.2510499954223633
Iteration: 2230. Loss: 0.14322423934936523
Iteration: 2231. Lo

Iteration: 2402. Loss: 0.13262048363685608
Iteration: 2403. Loss: 0.10716726630926132
Iteration: 2404. Loss: 0.13295511901378632
Iteration: 2405. Loss: 0.19110575318336487
Iteration: 2406. Loss: 0.14108160138130188
Iteration: 2407. Loss: 0.22115600109100342
Iteration: 2408. Loss: 0.15860013663768768
Iteration: 2409. Loss: 0.16200672090053558
Iteration: 2410. Loss: 0.31049492955207825
Iteration: 2411. Loss: 0.2641516923904419
Iteration: 2412. Loss: 0.16929058730602264
Iteration: 2413. Loss: 0.11028463393449783
Iteration: 2414. Loss: 0.13789358735084534
Iteration: 2415. Loss: 0.18767976760864258
Iteration: 2416. Loss: 0.15068426728248596
Iteration: 2417. Loss: 0.07807954400777817
Iteration: 2418. Loss: 0.19171804189682007
Iteration: 2419. Loss: 0.27916041016578674
Iteration: 2420. Loss: 0.15740135312080383
Iteration: 2421. Loss: 0.46648460626602173
Iteration: 2422. Loss: 0.102653369307518
Iteration: 2423. Loss: 0.10008301585912704
Iteration: 2424. Loss: 0.09032284468412399
Iteration: 242

Iteration: 2598. Loss: 0.19112282991409302
Iteration: 2599. Loss: 0.13659274578094482
Iteration: 2600. Loss: 0.07510467618703842
Iteration: 2601. Loss: 0.18690699338912964
Iteration: 2602. Loss: 0.1721367985010147
Iteration: 2603. Loss: 0.12056023627519608
Iteration: 2604. Loss: 0.15273594856262207
Iteration: 2605. Loss: 0.1707308143377304
Iteration: 2606. Loss: 0.2931100130081177
Iteration: 2607. Loss: 0.13247457146644592
Iteration: 2608. Loss: 0.06396307796239853
Iteration: 2609. Loss: 0.1742374449968338
Iteration: 2610. Loss: 0.07580248266458511
Iteration: 2611. Loss: 0.26223573088645935
Iteration: 2612. Loss: 0.30909866094589233
Iteration: 2613. Loss: 0.16351108253002167
Iteration: 2614. Loss: 0.070012666285038
Iteration: 2615. Loss: 0.18936091661453247
Iteration: 2616. Loss: 0.14219698309898376
Iteration: 2617. Loss: 0.14461639523506165
Iteration: 2618. Loss: 0.10766172409057617
Iteration: 2619. Loss: 0.07996979355812073
Iteration: 2620. Loss: 0.18781456351280212
Iteration: 2621. 

Iteration: 2794. Loss: 0.14922058582305908
Iteration: 2795. Loss: 0.1683138608932495
Iteration: 2796. Loss: 0.08855738490819931
Iteration: 2797. Loss: 0.09587714076042175
Iteration: 2798. Loss: 0.08234620839357376
Iteration: 2799. Loss: 0.18090450763702393
Iteration: 2800. Loss: 0.3131481409072876
Iteration: 2801. Loss: 0.15103326737880707
Iteration: 2802. Loss: 0.16470466554164886
Iteration: 2803. Loss: 0.18993481993675232
Iteration: 2804. Loss: 0.1765476018190384
Iteration: 2805. Loss: 0.15996013581752777
Iteration: 2806. Loss: 0.3103102743625641
Iteration: 2807. Loss: 0.2510014772415161
Iteration: 2808. Loss: 0.1377321183681488
Iteration: 2809. Loss: 0.09223433583974838
Iteration: 2810. Loss: 0.12324206531047821
Iteration: 2811. Loss: 0.21631424129009247
Iteration: 2812. Loss: 0.08788102120161057
Iteration: 2813. Loss: 0.1983422487974167
Iteration: 2814. Loss: 0.2051394283771515
Iteration: 2815. Loss: 0.12290869653224945
Iteration: 2816. Loss: 0.19296427071094513
Iteration: 2817. Lo

Iteration: 2987. Loss: 0.1872013807296753
Iteration: 2988. Loss: 0.1936473846435547
Iteration: 2989. Loss: 0.1593424528837204
Iteration: 2990. Loss: 0.1257135570049286
Iteration: 2991. Loss: 0.4123833477497101
Iteration: 2992. Loss: 0.09793195128440857
Iteration: 2993. Loss: 0.10620089620351791
Iteration: 2994. Loss: 0.13440407812595367
Iteration: 2995. Loss: 0.16766025125980377
Iteration: 2996. Loss: 0.23291154205799103
Iteration: 2997. Loss: 0.09181895107030869
Iteration: 2998. Loss: 0.2166014462709427
Iteration: 2999. Loss: 0.15496572852134705
Iteration: 3000. Loss: 0.09926022589206696


In [10]:
# Evaluate accuracy on the test set. Runs only when `iter` (the iteration
# counter set by the training cell) is a multiple of 500. Also reads `loss`
# from the last training step, so the training cell must have run first.
if iter % 500 == 0:
    correct = 0
    total = 0

    # Inference only: no_grad() skips autograd bookkeeping for the forward passes.
    with torch.no_grad():
        # Iterate through the test dataset.
        for images, labels in test_loader:
            # Flatten every image in the batch into a 784-element row (28*28).
            images = images.view(-1, 28*28)

            # Forward pass only to get logits/output.
            outputs = model(images)

            # Predicted class = index of the maximum value along dim 1.
            _, predicted = torch.max(outputs, 1)

            # Total number of labels seen so far.
            total += labels.size(0)

            # .item() keeps `correct` a plain Python int, so `accuracy`
            # below is a Python float rather than a tensor.
            correct += (predicted.cpu() == labels.cpu()).sum().item()

    accuracy = 100 * correct / total

    # Print loss and accuracy. loss.item() replaces loss.data[0], which
    # raises IndexError on PyTorch >= 0.5.
    # NOTE(review): the leading '%' in the message looks like a typo; it is
    # kept so the text matches the recorded output.
    print('%Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))

%Iteration: 3000. Loss: 0.09926022589206696. Accuracy: 95.91
