In [1]:
from torchvision import transforms
from torchvision.datasets import MNIST

# Download (if needed) and load MNIST.
# torchvision's MNIST.download() fetches both the training and the test
# archives whichever split is requested, so using one shared root avoids
# downloading and storing the whole dataset twice in separate directories.
data_root = 'mnist'
to_tensor = transforms.ToTensor()  # PIL image -> float tensor, shape (C, H, W)

train_data = MNIST(data_root, train=True, transform=to_tensor, download=True)
test_data = MNIST(data_root, train=False, transform=to_tensor, download=True)

In [4]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from network import CharacterClassifier
from tqdm import tqdm

# --- Configuration -----------------------------------------------------------
input_dim = (1, 28, 28)        # (C, H, W) of a ToTensor()-converted MNIST image
hidden_layers = [50, 100, 500]
output_dim = 10                # one output per digit class 0-9
batch_size = 16
epochs = 10
learning_rate = 0.001
momentum = 0.9
log_every = 10                 # print the training loss every `log_every` steps
seed = 0                       # fixes weight init and DataLoader shuffling order

torch.manual_seed(seed)


def train_epoch(model, loader, optimizer, criterion, log_every=10):
    """Run one optimisation pass over ``loader``.

    Args:
        model: module to train; switched to training mode here.
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function called as ``criterion(predictions, targets)``.
        log_every: print the batch loss every ``log_every`` steps (step 0 included).
    """
    model.train()
    for step, (x_batch, y_batch) in enumerate(loader):
        optimizer.zero_grad()
        loss = criterion(model(x_batch), y_batch)
        loss.backward()
        optimizer.step()
        if step % log_every == 0:
            print('Loss: {}'.format(loss.item()))


def evaluate(model, loader):
    """Return the classification accuracy of ``model`` on ``loader`` in percent.

    Correct predictions are counted over the whole dataset rather than
    averaging per-batch accuracies, so a smaller final batch is not
    over-weighted. Gradients are disabled because nothing is back-propagated.
    Returns NaN for an empty loader instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x_batch, y_batch in tqdm(loader, desc='Test'):
            predicted = torch.argmax(model(x_batch), dim=1)
            correct += (predicted == y_batch).sum().item()
            total += y_batch.numel()
    return 100.0 * correct / total if total else float('nan')


# --- Data, model, optimisation -------------------------------------------------
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

model = CharacterClassifier(input_dim, hidden_layers, output_dim)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# NOTE(review): CrossEntropyLoss expects raw (unnormalised) logits. The logged
# loss in this notebook's output bottoms out near 1.4612 ≈ ln(9 + e) - 1, which
# is exactly the floor reached when the model already outputs softmax
# probabilities. If CharacterClassifier ends with a Softmax layer, remove it
# (or switch to NLLLoss on log-probabilities) — TODO confirm in network.py.
criterion = nn.CrossEntropyLoss()

# --- Training ------------------------------------------------------------------
for epoch in range(epochs):
    print("Epoch {0}".format(epoch))
    train_epoch(model, train_loader, optimizer, criterion, log_every=log_every)

# --- Evaluation ----------------------------------------------------------------
accuracy = evaluate(model, test_loader)
print("Accuracy: {0}".format(accuracy))

Epoch 0
Loss: 2.3020832538604736
Loss: 2.3019120693206787
Loss: 2.2937777042388916
Loss: 2.2979071140289307
Loss: 2.2957372665405273
Loss: 2.2979912757873535
Loss: 2.3177380561828613
Loss: 2.2566473484039307
Loss: 2.1929543018341064
Loss: 2.2238259315490723
Loss: 2.080141067504883
Loss: 1.9250049591064453
Loss: 1.9362080097198486
Loss: 1.8585638999938965
Loss: 1.7859126329421997
Loss: 1.7009669542312622
Loss: 1.6642439365386963
Loss: 1.7924697399139404
Loss: 1.7032510042190552
Loss: 1.6893410682678223
Loss: 1.6681236028671265
Loss: 1.6500428915023804
Loss: 1.6148968935012817
Loss: 1.7951689958572388
Loss: 1.6821885108947754
Loss: 1.6603007316589355
Loss: 1.6868939399719238
Loss: 1.6507160663604736
Loss: 1.6756420135498047
Loss: 1.6677026748657227
Loss: 1.722570776939392
Loss: 1.6529139280319214
Loss: 1.6107826232910156
Loss: 1.5987273454666138
Loss: 1.559265375137329
Loss: 1.729501485824585
Loss: 1.6740093231201172
Loss: 1.579500436782837
Loss: 1.641340732574463
Loss: 1.6529541015625
L

Loss: 1.5397989749908447
Loss: 1.5106890201568604
Loss: 1.5883986949920654
Loss: 1.5255440473556519
Loss: 1.5692812204360962
Loss: 1.474873661994934
Loss: 1.5454282760620117
Loss: 1.6243953704833984
Loss: 1.6007673740386963
Loss: 1.5432674884796143
Loss: 1.533017635345459
Loss: 1.5613582134246826
Loss: 1.5510770082473755
Loss: 1.6008918285369873
Loss: 1.5373400449752808
Loss: 1.519104242324829
Loss: 1.6637276411056519
Loss: 1.5338882207870483
Loss: 1.501157283782959
Loss: 1.7009040117263794
Loss: 1.5154697895050049
Loss: 1.551594853401184
Loss: 1.5905427932739258
Loss: 1.4897843599319458
Loss: 1.5908639430999756
Loss: 1.5471324920654297
Loss: 1.5576399564743042
Loss: 1.6208195686340332
Loss: 1.6027640104293823
Loss: 1.5421971082687378
Loss: 1.6178014278411865
Loss: 1.5465259552001953
Loss: 1.5527890920639038
Loss: 1.5273157358169556
Loss: 1.565955400466919
Loss: 1.5972604751586914
Loss: 1.5001132488250732
Loss: 1.598504900932312
Loss: 1.4876512289047241
Loss: 1.5142168998718262
Loss: 1

Loss: 1.5228993892669678
Loss: 1.671897292137146
Loss: 1.4805912971496582
Loss: 1.5737085342407227
Loss: 1.492423415184021
Loss: 1.4852983951568604
Loss: 1.549099326133728
Loss: 1.4973742961883545
Loss: 1.4810724258422852
Loss: 1.5560152530670166
Loss: 1.4714800119400024
Loss: 1.5085610151290894
Loss: 1.535562515258789
Loss: 1.5343984365463257
Loss: 1.4778043031692505
Loss: 1.474582314491272
Loss: 1.525465488433838
Loss: 1.485992670059204
Loss: 1.517822027206421
Loss: 1.5016677379608154
Loss: 1.4825317859649658
Loss: 1.4789235591888428
Loss: 1.5520298480987549
Loss: 1.4966908693313599
Loss: 1.6618683338165283
Loss: 1.5017783641815186
Loss: 1.4941809177398682
Loss: 1.6130367517471313
Loss: 1.5380390882492065
Loss: 1.5389894247055054
Loss: 1.526018500328064
Loss: 1.5452104806900024
Loss: 1.4967619180679321
Loss: 1.5943130254745483
Loss: 1.4960647821426392
Loss: 1.6017227172851562
Loss: 1.5544065237045288
Loss: 1.4915870428085327
Loss: 1.4680838584899902
Loss: 1.5225491523742676
Loss: 1.5

Loss: 1.5246740579605103
Loss: 1.4631567001342773
Loss: 1.536536455154419
Loss: 1.483473777770996
Loss: 1.4833893775939941
Loss: 1.5331895351409912
Loss: 1.5005958080291748
Loss: 1.470458984375
Loss: 1.4861971139907837
Loss: 1.483064889907837
Loss: 1.5127421617507935
Loss: 1.474637508392334
Loss: 1.5219320058822632
Loss: 1.4910216331481934
Loss: 1.468563437461853
Loss: 1.4979681968688965
Loss: 1.4713351726531982
Loss: 1.5842633247375488
Loss: 1.4759292602539062
Loss: 1.4758894443511963
Loss: 1.5176141262054443
Loss: 1.4719287157058716
Loss: 1.4836053848266602
Loss: 1.494503140449524
Loss: 1.490609884262085
Loss: 1.571102499961853
Loss: 1.4930554628372192
Loss: 1.4908024072647095
Loss: 1.5379730463027954
Loss: 1.503703236579895
Loss: 1.4819200038909912
Loss: 1.508763074874878
Loss: 1.4615312814712524
Loss: 1.5254100561141968
Loss: 1.491302251815796
Loss: 1.484767198562622
Loss: 1.4630420207977295
Loss: 1.4732213020324707
Loss: 1.5784176588058472
Loss: 1.4821449518203735
Loss: 1.47301208

Loss: 1.5257365703582764
Loss: 1.4813412427902222
Loss: 1.468158483505249
Loss: 1.4783879518508911
Loss: 1.4941818714141846
Loss: 1.4687517881393433
Loss: 1.473938226699829
Loss: 1.5372052192687988
Loss: 1.5096293687820435
Loss: 1.4648820161819458
Loss: 1.4737164974212646
Loss: 1.4637317657470703
Loss: 1.5237188339233398
Loss: 1.4639103412628174
Loss: 1.4621548652648926
Loss: 1.4841779470443726
Loss: 1.4632930755615234
Loss: 1.4910099506378174
Loss: 1.4694958925247192
Loss: 1.4958611726760864
Loss: 1.4693403244018555
Loss: 1.4873666763305664
Loss: 1.4817955493927002
Loss: 1.476973533630371
Loss: 1.5190342664718628
Loss: 1.4627645015716553
Loss: 1.4768123626708984
Loss: 1.4678196907043457
Loss: 1.4721300601959229
Loss: 1.4619779586791992
Loss: 1.466651439666748
Loss: 1.47105073928833
Loss: 1.4717917442321777
Loss: 1.4627097845077515
Loss: 1.4638738632202148
Loss: 1.4617116451263428
Loss: 1.4618171453475952
Loss: 1.5061652660369873
Loss: 1.5267795324325562
Loss: 1.4631149768829346
Loss: 

Loss: 1.533699631690979
Loss: 1.463296890258789
Loss: 1.5376274585723877
Loss: 1.4676296710968018
Loss: 1.4633197784423828
Loss: 1.464775562286377
Loss: 1.5367283821105957
Loss: 1.4693512916564941
Loss: 1.4701817035675049
Loss: 1.5595850944519043
Loss: 1.4619581699371338
Loss: 1.481947660446167
Loss: 1.4725464582443237
Loss: 1.4702060222625732
Loss: 1.4615864753723145
Loss: 1.4674060344696045
Loss: 1.4660207033157349
Loss: 1.4724485874176025
Loss: 1.4701340198516846
Loss: 1.4711799621582031
Loss: 1.4697922468185425
Loss: 1.4630835056304932
Loss: 1.4660606384277344
Loss: 1.4612032175064087
Loss: 1.4850032329559326
Loss: 1.4642935991287231
Loss: 1.4685273170471191
Loss: 1.4628551006317139
Loss: 1.4672003984451294
Loss: 1.4670891761779785
Loss: 1.4840449094772339
Loss: 1.4906134605407715
Loss: 1.470210313796997
Loss: 1.4652924537658691
Loss: 1.465898036956787
Loss: 1.5388983488082886
Loss: 1.5366172790527344
Loss: 1.4812610149383545
Loss: 1.479262351989746
Loss: 1.4794445037841797
Loss: 1

Loss: 1.512778401374817
Loss: 1.5186036825180054
Loss: 1.4613523483276367
Loss: 1.4616529941558838
Loss: 1.5008058547973633
Loss: 1.529259443283081
Loss: 1.4805001020431519
Loss: 1.5337109565734863
Loss: 1.5601030588150024
Loss: 1.4613728523254395
Loss: 1.468853235244751
Loss: 1.4702244997024536
Loss: 1.4796786308288574
Loss: 1.5144520998001099
Loss: 1.4689563512802124
Loss: 1.4643583297729492
Loss: 1.4936091899871826
Loss: 1.463054895401001
Loss: 1.472360610961914
Loss: 1.4614609479904175
Loss: 1.4932585954666138
Loss: 1.4758938550949097
Loss: 1.4731991291046143
Loss: 1.5166634321212769
Loss: 1.4636058807373047
Loss: 1.4691425561904907
Loss: 1.4736555814743042
Loss: 1.4742072820663452
Loss: 1.4889051914215088
Loss: 1.490188479423523
Loss: 1.4707622528076172
Loss: 1.5204877853393555
Loss: 1.4716612100601196
Loss: 1.4621477127075195
Loss: 1.481788158416748
Loss: 1.4717098474502563
Loss: 1.4727561473846436
Loss: 1.5258569717407227
Loss: 1.4630625247955322
Loss: 1.462159514427185
Loss: 1.

Loss: 1.4690431356430054
Loss: 1.462486982345581
Loss: 1.4935102462768555
Loss: 1.4779555797576904
Loss: 1.4826785326004028
Loss: 1.4793460369110107
Loss: 1.4909422397613525
Loss: 1.464037537574768
Loss: 1.4995598793029785
Loss: 1.531886339187622
Loss: 1.4654688835144043
Loss: 1.581175684928894
Loss: 1.4694969654083252
Loss: 1.5223007202148438
Loss: 1.5363175868988037
Loss: 1.472825527191162
Loss: 1.4755624532699585
Loss: 1.4744114875793457
Loss: 1.4646313190460205
Loss: 1.4708083868026733
Loss: 1.5236833095550537
Loss: 1.464308738708496
Loss: 1.464200496673584
Loss: 1.4780348539352417
Loss: 1.461462140083313
Loss: 1.4780479669570923
Loss: 1.498913288116455
Loss: 1.4696568250656128
Loss: 1.4710084199905396
Loss: 1.5152463912963867
Loss: 1.5196852684020996
Loss: 1.467644214630127
Loss: 1.465785026550293
Loss: 1.4626353979110718
Loss: 1.527263879776001
Loss: 1.479644536972046
Loss: 1.4624146223068237
Loss: 1.4645287990570068
Loss: 1.4712820053100586
Loss: 1.4792463779449463
Loss: 1.52601

Loss: 1.4785685539245605
Loss: 1.4632130861282349
Loss: 1.4726470708847046
Loss: 1.4697571992874146
Loss: 1.46201491355896
Loss: 1.4652585983276367
Loss: 1.4653351306915283
Loss: 1.4718698263168335
Loss: 1.4687345027923584
Loss: 1.530434012413025
Loss: 1.4738502502441406
Loss: 1.4724440574645996
Loss: 1.4641528129577637
Loss: 1.470166802406311
Loss: 1.5149576663970947
Loss: 1.4881422519683838
Loss: 1.4733966588974
Loss: 1.461639165878296
Loss: 1.516584873199463
Loss: 1.466163158416748
Loss: 1.4835240840911865
Loss: 1.4790705442428589
Loss: 1.4765536785125732
Loss: 1.4834542274475098
Loss: 1.4702181816101074
Loss: 1.4611775875091553
Loss: 1.4661914110183716
Loss: 1.4850351810455322
Loss: 1.4950052499771118
Loss: 1.4743869304656982
Loss: 1.4673147201538086
Loss: 1.4617817401885986
Loss: 1.4731299877166748
Loss: 1.4829676151275635
Loss: 1.4676600694656372
Loss: 1.4612388610839844
Loss: 1.4712634086608887
Loss: 1.4734114408493042
Loss: 1.4621529579162598
Loss: 1.4869216680526733
Loss: 1.46

Loss: 1.4613510370254517
Loss: 1.4912209510803223
Loss: 1.461167335510254
Loss: 1.462139368057251
Loss: 1.4699537754058838
Loss: 1.4757435321807861
Loss: 1.470994234085083
Loss: 1.474104404449463
Loss: 1.4620085954666138
Loss: 1.4715036153793335
Loss: 1.4617741107940674
Loss: 1.4772502183914185
Loss: 1.4612294435501099
Loss: 1.4668292999267578
Loss: 1.4618576765060425
Loss: 1.4698354005813599
Loss: 1.5191123485565186
Loss: 1.4691998958587646
Epoch 8
Loss: 1.4649299383163452
Loss: 1.4625054597854614
Loss: 1.5011367797851562
Loss: 1.535003423690796
Loss: 1.4763715267181396
Loss: 1.4697412252426147
Loss: 1.4623589515686035
Loss: 1.463712215423584
Loss: 1.4618332386016846
Loss: 1.4612514972686768
Loss: 1.4617968797683716
Loss: 1.462458610534668
Loss: 1.4622814655303955
Loss: 1.4676923751831055
Loss: 1.4778130054473877
Loss: 1.5345399379730225
Loss: 1.4706984758377075
Loss: 1.4684193134307861
Loss: 1.477310299873352
Loss: 1.471645712852478
Loss: 1.5138368606567383
Loss: 1.4772028923034668
L

Loss: 1.483151912689209
Loss: 1.461354374885559
Loss: 1.4761171340942383
Loss: 1.4720370769500732
Loss: 1.4710428714752197
Loss: 1.461179494857788
Loss: 1.46187424659729
Loss: 1.470208764076233
Loss: 1.4646214246749878
Loss: 1.4701164960861206
Loss: 1.4612135887145996
Loss: 1.4856187105178833
Loss: 1.4631171226501465
Loss: 1.4992256164550781
Loss: 1.4683706760406494
Loss: 1.523689866065979
Loss: 1.464876651763916
Loss: 1.4697716236114502
Loss: 1.4714348316192627
Loss: 1.4668858051300049
Loss: 1.4727191925048828
Loss: 1.4710569381713867
Loss: 1.499726414680481
Loss: 1.4698165655136108
Loss: 1.4800723791122437
Loss: 1.4709501266479492
Loss: 1.4613897800445557
Loss: 1.470405101776123
Loss: 1.4614553451538086
Loss: 1.502397894859314
Loss: 1.4929423332214355
Loss: 1.4646046161651611
Loss: 1.4669246673583984
Loss: 1.4695829153060913
Loss: 1.4650766849517822
Loss: 1.4710463285446167
Loss: 1.4612936973571777
Loss: 1.461266279220581
Loss: 1.4622070789337158
Loss: 1.461242437362671
Loss: 1.46164

Loss: 1.4650949239730835
Loss: 1.4619619846343994
Loss: 1.4707574844360352
Loss: 1.4703445434570312
Loss: 1.4621896743774414
Loss: 1.4611953496932983
Loss: 1.469441533088684
Loss: 1.5205810070037842
Loss: 1.4783000946044922
Loss: 1.4688022136688232
Loss: 1.477609395980835
Loss: 1.466597080230713
Loss: 1.508502721786499
Loss: 1.470289945602417
Loss: 1.4653754234313965
Loss: 1.4696123600006104
Loss: 1.491332769393921
Loss: 1.4760788679122925
Loss: 1.4926843643188477
Loss: 1.4755779504776
Loss: 1.4778176546096802
Loss: 1.4755581617355347
Loss: 1.478640079498291
Loss: 1.4879474639892578
Loss: 1.5071992874145508
Loss: 1.4698519706726074
Loss: 1.468641757965088
Loss: 1.54342782497406
Loss: 1.4726390838623047
Loss: 1.4701778888702393
Loss: 1.47013521194458
Loss: 1.4619941711425781
Loss: 1.532724380493164
Loss: 1.4752715826034546
Loss: 1.4800246953964233
Loss: 1.4613041877746582
Loss: 1.537416696548462
Loss: 1.470550537109375
Loss: 1.4677577018737793
Loss: 1.5386738777160645
Loss: 1.4708142280

Test: 100%|██████████| 625/625 [00:58<00:00, 10.68it/s]

Accuracy: 98.66999816894531



