In [8]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable

In [9]:
# MNIST training split stored under ./data; download=True asks torchvision
# to fetch the files when they are not already present.
train_dataset = dsets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# MNIST test split, read from the same root directory (no download flag is
# passed for this split).
test_dataset = dsets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

In [10]:
# Training budget: mini-batch size and total number of optimisation steps.
batch_size = 100
n_iters = 3000
# Turn the step budget into a whole number of epochs:
# steps / (batches per epoch), truncated toward zero.
num_epochs = int(n_iters / (len(train_dataset) / batch_size))

# Training batches are reshuffled on every pass over the data.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Test batches are served in a fixed order.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [21]:
# Model definition.
# input -> linear -> sigmoid (non-linearity) -> linear -> raw output scores
class FeedFowardNeuralNetwork(nn.Module):
    """Feed-forward network with a single sigmoid hidden layer.

    Args:
        input_dim: number of features in each input row.
        hidden_dim: number of units in the hidden layer.
        output_dim: number of values produced for each input row.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names are kept as-is: they determine the parameter names
        # (fc1.weight, fc1.bias, fc2.weight, fc2.bias).
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the output of fc2 applied to sigmoid(fc1(x))."""
        hidden = self.sigmoid(self.fc1(x))
        return self.fc2(hidden)
    

In [22]:
# Instantiate the model: flattened 28x28 inputs, 100 hidden units, 10 outputs.
input_dim = 28 * 28
hidden_dim = 100
output_dim = 10
model = FeedFowardNeuralNetwork(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)

In [23]:
# Loss function: cross-entropy, constructed with its default settings.
criterion = nn.CrossEntropyLoss()

In [24]:
# Instantiate the optimizer: SGD over every parameter of the model.
learning_rate = 0.1
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
)

In [31]:
# Inspect the learnable tensors: print how many there are, then the size of
# the first four.
model_params = list(model.parameters())
print(len(model_params))
for param_index in range(4):
    print(model_params[param_index].size())


4
torch.Size([100, 784])
torch.Size([100])
torch.Size([10, 100])
torch.Size([10])


In [39]:
# Training: run num_epochs passes over train_loader, taking one SGD step per
# mini-batch and printing the loss after every step.
# `iter` counts optimisation steps across epochs. The name shadows the
# builtin but is kept because the evaluation cell reads it.
iter = 0
for epochs in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each image to a row of 28*28 values for the first linear
        # layer. Tensors track gradients directly, so the deprecated
        # Variable wrapper is not needed.
        images = images.view(-1, 28*28)

        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass to obtain the outputs.
        outputs = model(images)

        # Loss for this mini-batch.
        loss = criterion(outputs, labels)

        # Backpropagation followed by the parameter update.
        loss.backward()
        optimizer.step()

        iter += 1

        # loss is a 0-dim tensor; .item() extracts the Python float.
        # (Indexing it with [0] raises IndexError on current PyTorch.)
        print('Iteration: {}. Loss: {}'.format(iter, loss.item()))

Iteration: 1. Loss: 2.3307626247406006
Iteration: 2. Loss: 2.335205078125
Iteration: 3. Loss: 2.3012731075286865
Iteration: 4. Loss: 2.3016693592071533
Iteration: 5. Loss: 2.3026785850524902
Iteration: 6. Loss: 2.2862207889556885
Iteration: 7. Loss: 2.2832369804382324
Iteration: 8. Loss: 2.26493239402771
Iteration: 9. Loss: 2.3040342330932617
Iteration: 10. Loss: 2.2666122913360596
Iteration: 11. Loss: 2.2747910022735596
Iteration: 12. Loss: 2.2995901107788086
Iteration: 13. Loss: 2.2954704761505127
Iteration: 14. Loss: 2.272814989089966
Iteration: 15. Loss: 2.294840097427368
Iteration: 16. Loss: 2.270069122314453
Iteration: 17. Loss: 2.271491527557373
Iteration: 18. Loss: 2.244065284729004
Iteration: 19. Loss: 2.274722099304199
Iteration: 20. Loss: 2.2521162033081055
Iteration: 21. Loss: 2.2776525020599365
Iteration: 22. Loss: 2.237713098526001
Iteration: 23. Loss: 2.234771966934204
Iteration: 24. Loss: 2.2693986892700195
Iteration: 25. Loss: 2.2691965103149414
Iteration: 26. Loss: 2.

Iteration: 210. Loss: 1.20193612575531
Iteration: 211. Loss: 1.203674554824829
Iteration: 212. Loss: 1.2569990158081055
Iteration: 213. Loss: 1.099810242652893
Iteration: 214. Loss: 1.2523466348648071
Iteration: 215. Loss: 1.221348524093628
Iteration: 216. Loss: 1.1950064897537231
Iteration: 217. Loss: 1.1988598108291626
Iteration: 218. Loss: 1.2570981979370117
Iteration: 219. Loss: 1.1025618314743042
Iteration: 220. Loss: 1.2380335330963135
Iteration: 221. Loss: 1.184187412261963
Iteration: 222. Loss: 1.2080261707305908
Iteration: 223. Loss: 1.1744797229766846
Iteration: 224. Loss: 1.1390950679779053
Iteration: 225. Loss: 1.1416269540786743
Iteration: 226. Loss: 1.1592235565185547
Iteration: 227. Loss: 1.1263138055801392
Iteration: 228. Loss: 1.0963387489318848
Iteration: 229. Loss: 1.095086932182312
Iteration: 230. Loss: 1.0711536407470703
Iteration: 231. Loss: 1.1585756540298462
Iteration: 232. Loss: 1.1927951574325562
Iteration: 233. Loss: 1.170780062675476
Iteration: 234. Loss: 0.

Iteration: 428. Loss: 0.7049258947372437
Iteration: 429. Loss: 0.7024204134941101
Iteration: 430. Loss: 0.6261444091796875
Iteration: 431. Loss: 0.6208955645561218
Iteration: 432. Loss: 0.6084201335906982
Iteration: 433. Loss: 0.5814751982688904
Iteration: 434. Loss: 0.5955579280853271
Iteration: 435. Loss: 0.6804344654083252
Iteration: 436. Loss: 0.6457852721214294
Iteration: 437. Loss: 0.6622045636177063
Iteration: 438. Loss: 0.6767109632492065
Iteration: 439. Loss: 0.7834210991859436
Iteration: 440. Loss: 0.6800373196601868
Iteration: 441. Loss: 0.6234684586524963
Iteration: 442. Loss: 0.6672210693359375
Iteration: 443. Loss: 0.7156423330307007
Iteration: 444. Loss: 0.7207456231117249
Iteration: 445. Loss: 0.6166000962257385
Iteration: 446. Loss: 0.6938987970352173
Iteration: 447. Loss: 0.6930767893791199
Iteration: 448. Loss: 0.6167694330215454
Iteration: 449. Loss: 0.5799959897994995
Iteration: 450. Loss: 0.6748138666152954
Iteration: 451. Loss: 0.45767584443092346
Iteration: 452.

Iteration: 641. Loss: 0.5520417094230652
Iteration: 642. Loss: 0.46483856439590454
Iteration: 643. Loss: 0.4979674220085144
Iteration: 644. Loss: 0.6539208889007568
Iteration: 645. Loss: 0.6079562902450562
Iteration: 646. Loss: 0.43398529291152954
Iteration: 647. Loss: 0.43895405530929565
Iteration: 648. Loss: 0.4656504690647125
Iteration: 649. Loss: 0.5718525052070618
Iteration: 650. Loss: 0.5692428946495056
Iteration: 651. Loss: 0.4884550869464874
Iteration: 652. Loss: 0.4498145282268524
Iteration: 653. Loss: 0.5453805327415466
Iteration: 654. Loss: 0.505911648273468
Iteration: 655. Loss: 0.44069114327430725
Iteration: 656. Loss: 0.42022305727005005
Iteration: 657. Loss: 0.4484058618545532
Iteration: 658. Loss: 0.5837541818618774
Iteration: 659. Loss: 0.5594027638435364
Iteration: 660. Loss: 0.4887617826461792
Iteration: 661. Loss: 0.4579312205314636
Iteration: 662. Loss: 0.5253421068191528
Iteration: 663. Loss: 0.6078764200210571
Iteration: 664. Loss: 0.48078370094299316
Iteration: 

Iteration: 840. Loss: 0.4135713577270508
Iteration: 841. Loss: 0.44181305170059204
Iteration: 842. Loss: 0.42025822401046753
Iteration: 843. Loss: 0.44848841428756714
Iteration: 844. Loss: 0.3819803297519684
Iteration: 845. Loss: 0.43844783306121826
Iteration: 846. Loss: 0.36966028809547424
Iteration: 847. Loss: 0.46626320481300354
Iteration: 848. Loss: 0.43146035075187683
Iteration: 849. Loss: 0.5598979592323303
Iteration: 850. Loss: 0.38457393646240234
Iteration: 851. Loss: 0.4286067485809326
Iteration: 852. Loss: 0.43966734409332275
Iteration: 853. Loss: 0.40419575572013855
Iteration: 854. Loss: 0.3636784851551056
Iteration: 855. Loss: 0.481658399105072
Iteration: 856. Loss: 0.5988155603408813
Iteration: 857. Loss: 0.5557394027709961
Iteration: 858. Loss: 0.6153659224510193
Iteration: 859. Loss: 0.3953007459640503
Iteration: 860. Loss: 0.4458882808685303
Iteration: 861. Loss: 0.405337929725647
Iteration: 862. Loss: 0.5751714706420898
Iteration: 863. Loss: 0.44716858863830566
Iterati

Iteration: 1057. Loss: 0.5394534468650818
Iteration: 1058. Loss: 0.44552457332611084
Iteration: 1059. Loss: 0.4889252483844757
Iteration: 1060. Loss: 0.4104699194431305
Iteration: 1061. Loss: 0.26344752311706543
Iteration: 1062. Loss: 0.46992599964141846
Iteration: 1063. Loss: 0.3196554183959961
Iteration: 1064. Loss: 0.35979124903678894
Iteration: 1065. Loss: 0.3754998743534088
Iteration: 1066. Loss: 0.3315136730670929
Iteration: 1067. Loss: 0.45215901732444763
Iteration: 1068. Loss: 0.39608269929885864
Iteration: 1069. Loss: 0.46283668279647827
Iteration: 1070. Loss: 0.4404822289943695
Iteration: 1071. Loss: 0.35250863432884216
Iteration: 1072. Loss: 0.3457489013671875
Iteration: 1073. Loss: 0.3392787277698517
Iteration: 1074. Loss: 0.42192572355270386
Iteration: 1075. Loss: 0.44728243350982666
Iteration: 1076. Loss: 0.47996535897254944
Iteration: 1077. Loss: 0.4817388951778412
Iteration: 1078. Loss: 0.36604294180870056
Iteration: 1079. Loss: 0.3060312569141388
Iteration: 1080. Loss:

Iteration: 1257. Loss: 0.3854697048664093
Iteration: 1258. Loss: 0.4290136694908142
Iteration: 1259. Loss: 0.38870295882225037
Iteration: 1260. Loss: 0.22633296251296997
Iteration: 1261. Loss: 0.4738813042640686
Iteration: 1262. Loss: 0.4139886796474457
Iteration: 1263. Loss: 0.33417823910713196
Iteration: 1264. Loss: 0.2808234691619873
Iteration: 1265. Loss: 0.3323850929737091
Iteration: 1266. Loss: 0.3153400123119354
Iteration: 1267. Loss: 0.28224828839302063
Iteration: 1268. Loss: 0.4098421037197113
Iteration: 1269. Loss: 0.36751988530158997
Iteration: 1270. Loss: 0.37278133630752563
Iteration: 1271. Loss: 0.3025243878364563
Iteration: 1272. Loss: 0.44045937061309814
Iteration: 1273. Loss: 0.4356826841831207
Iteration: 1274. Loss: 0.36790168285369873
Iteration: 1275. Loss: 0.4016934633255005
Iteration: 1276. Loss: 0.3513405919075012
Iteration: 1277. Loss: 0.4048590064048767
Iteration: 1278. Loss: 0.4168553948402405
Iteration: 1279. Loss: 0.3112664520740509
Iteration: 1280. Loss: 0.3

Iteration: 1455. Loss: 0.2959161698818207
Iteration: 1456. Loss: 0.3704148828983307
Iteration: 1457. Loss: 0.2968137264251709
Iteration: 1458. Loss: 0.38592633605003357
Iteration: 1459. Loss: 0.298880934715271
Iteration: 1460. Loss: 0.2701408565044403
Iteration: 1461. Loss: 0.4562482535839081
Iteration: 1462. Loss: 0.43439698219299316
Iteration: 1463. Loss: 0.3830256760120392
Iteration: 1464. Loss: 0.46645763516426086
Iteration: 1465. Loss: 0.2969393730163574
Iteration: 1466. Loss: 0.3346613347530365
Iteration: 1467. Loss: 0.36408936977386475
Iteration: 1468. Loss: 0.3263883888721466
Iteration: 1469. Loss: 0.39352840185165405
Iteration: 1470. Loss: 0.36592021584510803
Iteration: 1471. Loss: 0.4214824438095093
Iteration: 1472. Loss: 0.36582931876182556
Iteration: 1473. Loss: 0.37773796916007996
Iteration: 1474. Loss: 0.29600998759269714
Iteration: 1475. Loss: 0.2827935516834259
Iteration: 1476. Loss: 0.3019355833530426
Iteration: 1477. Loss: 0.31764665246009827
Iteration: 1478. Loss: 0.

Iteration: 1653. Loss: 0.37343594431877136
Iteration: 1654. Loss: 0.3449788987636566
Iteration: 1655. Loss: 0.3415484130382538
Iteration: 1656. Loss: 0.30292361974716187
Iteration: 1657. Loss: 0.20590052008628845
Iteration: 1658. Loss: 0.3928092122077942
Iteration: 1659. Loss: 0.27405908703804016
Iteration: 1660. Loss: 0.34442755579948425
Iteration: 1661. Loss: 0.5328817963600159
Iteration: 1662. Loss: 0.37186580896377563
Iteration: 1663. Loss: 0.42166006565093994
Iteration: 1664. Loss: 0.2929655611515045
Iteration: 1665. Loss: 0.2951931059360504
Iteration: 1666. Loss: 0.262233704328537
Iteration: 1667. Loss: 0.35954749584198
Iteration: 1668. Loss: 0.32925787568092346
Iteration: 1669. Loss: 0.2707059383392334
Iteration: 1670. Loss: 0.20416267216205597
Iteration: 1671. Loss: 0.3657548427581787
Iteration: 1672. Loss: 0.36378589272499084
Iteration: 1673. Loss: 0.3351646363735199
Iteration: 1674. Loss: 0.36287975311279297
Iteration: 1675. Loss: 0.35118892788887024
Iteration: 1676. Loss: 0.

Iteration: 1867. Loss: 0.2711459696292877
Iteration: 1868. Loss: 0.2654101848602295
Iteration: 1869. Loss: 0.3505314886569977
Iteration: 1870. Loss: 0.2707112431526184
Iteration: 1871. Loss: 0.4699013829231262
Iteration: 1872. Loss: 0.3288533091545105
Iteration: 1873. Loss: 0.3533875346183777
Iteration: 1874. Loss: 0.3931267261505127
Iteration: 1875. Loss: 0.27975162863731384
Iteration: 1876. Loss: 0.39696186780929565
Iteration: 1877. Loss: 0.4325062036514282
Iteration: 1878. Loss: 0.4740581214427948
Iteration: 1879. Loss: 0.37247562408447266
Iteration: 1880. Loss: 0.40440404415130615
Iteration: 1881. Loss: 0.3329533636569977
Iteration: 1882. Loss: 0.3353099822998047
Iteration: 1883. Loss: 0.2291007786989212
Iteration: 1884. Loss: 0.4248652756214142
Iteration: 1885. Loss: 0.2833453118801117
Iteration: 1886. Loss: 0.31345418095588684
Iteration: 1887. Loss: 0.410416841506958
Iteration: 1888. Loss: 0.37446126341819763
Iteration: 1889. Loss: 0.2894538938999176
Iteration: 1890. Loss: 0.3757

Iteration: 2078. Loss: 0.36961764097213745
Iteration: 2079. Loss: 0.31677618622779846
Iteration: 2080. Loss: 0.4368476867675781
Iteration: 2081. Loss: 0.31272923946380615
Iteration: 2082. Loss: 0.3579224646091461
Iteration: 2083. Loss: 0.2237183004617691
Iteration: 2084. Loss: 0.3621537387371063
Iteration: 2085. Loss: 0.36381983757019043
Iteration: 2086. Loss: 0.26515814661979675
Iteration: 2087. Loss: 0.3104361891746521
Iteration: 2088. Loss: 0.48016557097435
Iteration: 2089. Loss: 0.6228853464126587
Iteration: 2090. Loss: 0.2775813043117523
Iteration: 2091. Loss: 0.2431110441684723
Iteration: 2092. Loss: 0.4794829189777374
Iteration: 2093. Loss: 0.27959078550338745
Iteration: 2094. Loss: 0.38796278834342957
Iteration: 2095. Loss: 0.4076635241508484
Iteration: 2096. Loss: 0.348231703042984
Iteration: 2097. Loss: 0.3699076175689697
Iteration: 2098. Loss: 0.29012951254844666
Iteration: 2099. Loss: 0.31395336985588074
Iteration: 2100. Loss: 0.27506959438323975
Iteration: 2101. Loss: 0.32

Iteration: 2277. Loss: 0.1674550175666809
Iteration: 2278. Loss: 0.3037598133087158
Iteration: 2279. Loss: 0.3193477690219879
Iteration: 2280. Loss: 0.38434749841690063
Iteration: 2281. Loss: 0.14721909165382385
Iteration: 2282. Loss: 0.24777071177959442
Iteration: 2283. Loss: 0.24655655026435852
Iteration: 2284. Loss: 0.30291253328323364
Iteration: 2285. Loss: 0.2217530459165573
Iteration: 2286. Loss: 0.25336551666259766
Iteration: 2287. Loss: 0.26935309171676636
Iteration: 2288. Loss: 0.26950713992118835
Iteration: 2289. Loss: 0.29431530833244324
Iteration: 2290. Loss: 0.3108353316783905
Iteration: 2291. Loss: 0.25324809551239014
Iteration: 2292. Loss: 0.36036959290504456
Iteration: 2293. Loss: 0.3768494129180908
Iteration: 2294. Loss: 0.28511136770248413
Iteration: 2295. Loss: 0.26731741428375244
Iteration: 2296. Loss: 0.389940083026886
Iteration: 2297. Loss: 0.2685624361038208
Iteration: 2298. Loss: 0.27409541606903076
Iteration: 2299. Loss: 0.30762845277786255
Iteration: 2300. Los

Iteration: 2475. Loss: 0.3914056718349457
Iteration: 2476. Loss: 0.17490415275096893
Iteration: 2477. Loss: 0.45467203855514526
Iteration: 2478. Loss: 0.23914474248886108
Iteration: 2479. Loss: 0.3664010167121887
Iteration: 2480. Loss: 0.27422112226486206
Iteration: 2481. Loss: 0.23651760816574097
Iteration: 2482. Loss: 0.3156583905220032
Iteration: 2483. Loss: 0.2628858983516693
Iteration: 2484. Loss: 0.4810916781425476
Iteration: 2485. Loss: 0.3492995500564575
Iteration: 2486. Loss: 0.35735023021698
Iteration: 2487. Loss: 0.39133220911026
Iteration: 2488. Loss: 0.4842740595340729
Iteration: 2489. Loss: 0.4706918001174927
Iteration: 2490. Loss: 0.4141179621219635
Iteration: 2491. Loss: 0.2636478543281555
Iteration: 2492. Loss: 0.18650022149085999
Iteration: 2493. Loss: 0.2538665533065796
Iteration: 2494. Loss: 0.2592374086380005
Iteration: 2495. Loss: 0.2263849675655365
Iteration: 2496. Loss: 0.27241846919059753
Iteration: 2497. Loss: 0.29315292835235596
Iteration: 2498. Loss: 0.19134

Iteration: 2672. Loss: 0.35843127965927124
Iteration: 2673. Loss: 0.2904261648654938
Iteration: 2674. Loss: 0.32956787943840027
Iteration: 2675. Loss: 0.3871360123157501
Iteration: 2676. Loss: 0.3662886321544647
Iteration: 2677. Loss: 0.18077851831912994
Iteration: 2678. Loss: 0.3260897696018219
Iteration: 2679. Loss: 0.3989192247390747
Iteration: 2680. Loss: 0.464405357837677
Iteration: 2681. Loss: 0.27213191986083984
Iteration: 2682. Loss: 0.3255651891231537
Iteration: 2683. Loss: 0.3070930242538452
Iteration: 2684. Loss: 0.3555421829223633
Iteration: 2685. Loss: 0.21577636897563934
Iteration: 2686. Loss: 0.3634416460990906
Iteration: 2687. Loss: 0.21460431814193726
Iteration: 2688. Loss: 0.37809914350509644
Iteration: 2689. Loss: 0.36593472957611084
Iteration: 2690. Loss: 0.49985024333000183
Iteration: 2691. Loss: 0.2595117688179016
Iteration: 2692. Loss: 0.24973121285438538
Iteration: 2693. Loss: 0.42682212591171265
Iteration: 2694. Loss: 0.2842774987220764
Iteration: 2695. Loss: 0

Iteration: 2869. Loss: 0.39206817746162415
Iteration: 2870. Loss: 0.2558484971523285
Iteration: 2871. Loss: 0.3825396001338959
Iteration: 2872. Loss: 0.30061739683151245
Iteration: 2873. Loss: 0.28320834040641785
Iteration: 2874. Loss: 0.19884587824344635
Iteration: 2875. Loss: 0.2898329198360443
Iteration: 2876. Loss: 0.31142300367355347
Iteration: 2877. Loss: 0.34352126717567444
Iteration: 2878. Loss: 0.213687464594841
Iteration: 2879. Loss: 0.31334683299064636
Iteration: 2880. Loss: 0.2539511024951935
Iteration: 2881. Loss: 0.19010141491889954
Iteration: 2882. Loss: 0.16403864324092865
Iteration: 2883. Loss: 0.2793671488761902
Iteration: 2884. Loss: 0.2089172601699829
Iteration: 2885. Loss: 0.11987635493278503
Iteration: 2886. Loss: 0.26790520548820496
Iteration: 2887. Loss: 0.2256034016609192
Iteration: 2888. Loss: 0.25386807322502136
Iteration: 2889. Loss: 0.2719871699810028
Iteration: 2890. Loss: 0.34889933466911316
Iteration: 2891. Loss: 0.23438875377178192
Iteration: 2892. Loss

In [58]:
# Evaluate on the test set. Runs only when the global step counter `iter`
# left by the training cell is a multiple of 500.
if iter % 500 == 0:
    # Running counts kept as plain Python ints.
    correct = 0
    total = 0

    # Inference only: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        # Iterate through the test dataset.
        for images, labels in test_loader:
            # Flatten each image to a row of 28*28 values.
            images = images.view(-1, 28*28)

            # Forward pass only, to get the model outputs.
            outputs = model(images)

            # Predicted class is the index of the maximum output per row.
            _, predicted = torch.max(outputs, 1)

            # Total number of labels seen so far.
            total += labels.size(0)

            # .item() turns the 0-dim count tensor into an int so that
            # `accuracy` below is a plain float rather than a tensor.
            correct += (predicted.cpu() == labels.cpu()).sum().item()

    accuracy = 100 * correct / total

    # Report the last training loss together with the test accuracy.
    # loss is a 0-dim tensor, so it is read with .item(), not [0].
    # The format string, including its leading '%', is left unchanged.
    print('%Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))

%Iteration: 3000. Loss: 0.2841082513332367. Accuracy: 91.96
