In [1]:
import numpy as np
from scipy import optimize, special
import itertools
from tqdm import trange
from scipy.special import expit


## Initialize weights for theta values

In [2]:
def rand_initialize_weights(L_in, L_out):
    '''Returns a vector of dimension: (L_out, L_in + 1) of random values'''
    t = np.zeros((L_out, L_in + 1))
    e = 0.12 # sigma
    t[:] = np.random.randn(*t.shape) * 2.0 * e - e
    return t

In [3]:
# Load data
data = np.genfromtxt('./data/tt.csv', delimiter=',')
y = data[:,0].reshape(-1,1)  # Shape: (688,1)
X = data[:, 1:]              # Shape: (m, input_layer_size)
m = len(y)

input_layer_size = 784
hidden_layer_size = 120      # hidden layer has 120 nodes excluding bias 
num_labels = 26              # 26 total output values: a=0, b=1, ... , z=25

# Initialize theta values
# theta1 size = (input_layer_size + 1, hidden_layer_size)
# theta2 size = (hidden_layer_size + 1, num_labels)
np.random.seed(1999)
theta1 = rand_initialize_weights(input_layer_size, hidden_layer_size)
theta2 = rand_initialize_weights(hidden_layer_size, num_labels)

print(f'theta1 shape: {np.shape(theta1)}, theta2 shape: {np.shape(theta2)}')

lam = 0.0 # Lambda used for regularization in cost function

theta1 shape: (120, 785), theta2 shape: (26, 121)


## Sigmoid function and derivative
* Used for activation function in forward propogation
* Derivative used for back propogation

In [4]:
def sigmoid(z):
   # return (1 / (1 + np.exp(-z)))
    return expit(z)

def dx_sigmoid(z):
    res = sigmoid(z)
    return (res * (1 - res))

print(sigmoid(0))
print(dx_sigmoid(0))

0.5
0.25


In [5]:
def return_labeled(y, num_labels):
    '''returns (m,num_labels) matrix where each column corresoponds to an example
    all indexes are 0, each column contains a single 1, the row corresponding to the answer'''
    out = np.zeros((m, num_labels))
    for row in range(0, m):
        label = int(y[row])
        out[row, label] = 1
    return out

In [6]:
def forward_backward(X,y):
    y_k = return_labeled(y, num_labels) # output matrix used for cost evaluation
    # Forward propogation
    a1 = np.c_[np.ones((m,1)), X]   # layer 1 : adding a bias column of 1's to X. (m, input_size + 1)
    z2 = a1.dot(theta1.T)           # layer 2 matrix calculation
    a2 = sigmoid(z2)                # activation of layer 2: shape(m, hidden_layer_size)
    a2 = np.c_[np.ones((m, 1)), a2] # adding a bias column of 1's : shape(m, hidden_layer_size + 1)
    z3 = a2.dot(theta2.T)           # layer3 || output layer calculation
    a3 = sigmoid(z3)                # activation of output layer: size = (m, num_labels)
    
    # Cost function
    # Check this after training
    inner_term0 = (-y_k * np.log(a3))
    inner_term1 = (1 - y_k) * np.log(1 - a3)
    left_side = np.sum(inner_term0 + inner_term1) / m
    right_side = np.sum(theta1[:, 1:] ** 2) + np.sum(theta2[:,1:] ** 2) # sum of all theta vals squred excluding theta index0 
    right_side = (lam / 2 / m) * right_side
    cost = left_side + right_side
    print(f'cost: {cost}')
    
    # Back propogation
    delta3 = a3 - y_k
    delta2 = delta3.dot(theta2)[:,1:] * dx_sigmoid(z2) # excluding bias
    Gradient1 = (delta2.T).dot(a1)
    Gradient2 = (delta3.T).dot(a2)
    

    return cost, a3, Gradient1, Gradient2


In [9]:

learning_rate = 0.001
batch_size = 128

# reinitialize thetas
theta1 = rand_initialize_weights(input_layer_size, hidden_layer_size)
theta2 = rand_initialize_weights(hidden_layer_size, num_labels)

losses, accuracies = [], []

for i in (t := trange(m)):
    J, h_x, grad1, grad2 = forward_backward(X,y)
    prediction = np.argmax(h_x, axis=1)
    accuracy = (prediction == y[i]).mean()
    
    # Gradient descent:
    theta1 = theta1 - (learning_rate * grad1)
    theta2 = theta2 - (learning_rate * grad2)
    
    loss = J.mean()
    losses.append(loss)
    
    accuracies.append(accuracy)
    t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))

    

loss 0.02 accuracy 0.00:   1%|▏         | 10/688 [00:00<00:07, 94.17it/s] 

cost: -16.002664998057345
cost: -12.360678857005533
cost: -9.75562748363745
cost: -7.851756219928978
cost: -6.422448127420966
cost: -5.320138192476938
cost: -4.448776895102651
cost: -3.7448211527121726
cost: -3.165260711356478
cost: -2.6802605042385563
cost: -2.2686238083357506
cost: -1.9149505645770368
cost: -1.6078204532259386
cost: -1.3386052963523525
cost: -1.1006756722873448
cost: -0.8888597114472543
cost: -0.6990666182095243
cost: -0.5280199736675979
cost: -0.37306560200146677
cost: -0.23203098802959587
cost: -0.1031209303652523
cost: 0.015160940471411838
cost: 0.12407191590874236


loss 1.54 accuracy 0.00:   6%|▌         | 38/688 [00:00<00:05, 112.65it/s]

cost: 0.22467622160820647
cost: 0.3178808217784276
cost: 0.40446355287026853
cost: 0.48509543127077953
cost: 0.5603584915828704
cost: 0.630760164011506
cost: 0.6967449482550201
cost: 0.7587039581457762
cost: 0.8169827763322659
cost: 0.8718879578987018
cost: 0.9236924464498333
cost: 0.9726401091224927
cost: 1.0189495534259252
cost: 1.062817355307733
cost: 1.1044208018828796
cost: 1.1439202320112585
cost: 1.1814610420062863
cost: 1.2171754111910844
cost: 1.2511837920307807
cost: 1.2835962015855669
cost: 1.3145133446128034
cost: 1.344027593463329
cost: 1.3722238457096003
cost: 1.399180277011584
cost: 1.424969003914864
cost: 1.4496566689620722
cost: 1.473304958587424
cost: 1.4959710626786977
cost: 1.517708083370779
cost: 1.5385653995316582
cost: 1.5585889924763725


loss 1.91 accuracy 0.00:  10%|█         | 69/688 [00:00<00:04, 128.64it/s]

cost: 1.5778217376657613
cost: 1.596303666489454
cost: 1.614072201675699
cost: 1.6311623693976476
cost: 1.647606990742675
cost: 1.663436854866977
cost: 1.6786808758626635
cost: 1.6933662351111138
cost: 1.7075185106781134
cost: 1.721161795117821
cost: 1.7343188028895449
cost: 1.7470109684498267
cost: 1.759258535959339
cost: 1.7710806414369298
cost: 1.7824953880995908
cost: 1.7935199155452417
cost: 1.8041704633634557
cost: 1.8144624296962113
cost: 1.8244104252152726
cost: 1.8340283229338672
cost: 1.8433293042271652
cost: 1.852325901397805
cost: 1.861030037088857
cost: 1.8694530608165338
cost: 1.877605782868236
cost: 1.8854985057876978
cost: 1.893141053647781
cost: 1.9005427992924844
cost: 1.9077126897127956
cost: 1.9146592697057938
cost: 1.9213907039528098


loss 2.05 accuracy 0.00:  17%|█▋        | 115/688 [00:00<00:04, 140.65it/s]

cost: 1.9279147976402056
cost: 1.9342390157353062
cost: 1.9403705010201422
cost: 1.9463160909767117
cost: 1.9520823336093842
cost: 1.9576755022828292
cost: 1.963101609647197
cost: 1.9683664207163725
cost: 1.9734754651596327
cost: 1.9784340488621939
cost: 1.9832472648056105
cost: 1.987920003314947
cost: 1.9924569617159484
cost: 1.996862653442033
cost: 2.00114141662788
cost: 2.005297422223561
cost: 2.009334681660589
cost: 2.013257054098905
cost: 2.0170682532816926
cost: 2.0207718540228905
cost: 2.0243712983505002
cost: 2.027869901327117
cost: 2.031270856567551
cost: 2.0345772414720398
cost: 2.0377920221922445
cost: 2.040918058346009
cost: 2.0439581074957904
cost: 2.046914829404648
cost: 2.0497907900827212
cost: 2.0525884656362856


loss 2.11 accuracy 0.00:  19%|█▉        | 130/688 [00:01<00:03, 140.68it/s]

cost: 2.05531024593067
cost: 2.057958438077562
cost: 2.0605352697565613
cost: 2.063042892380199
cost: 2.0654833841110443
cost: 2.0678587527389856
cost: 2.070170938426254
cost: 2.0724218163272914
cost: 2.0746131990901255
cost: 2.0767468392455086
cost: 2.0788244314896906
cost: 2.080847614866342
cost: 2.082817974852837
cost: 2.0847370453557534
cost: 2.0866063106202053
cost: 2.0884272070573306
cost: 2.090201124994012
cost: 2.0919294103486723
cost: 2.0936133662367737
cost: 2.0952542545094546
cost: 2.096853297228495
cost: 2.0984116780807116
cost: 2.099930543734627
cost: 2.101411005142152
cost: 2.1028541387878543
cost: 2.104260987888259
cost: 2.1056325635434736
cost: 2.1069698458433357
cost: 2.1082737849301485
cost: 2.1095453020199666


loss 2.14 accuracy 0.00:  24%|██▎       | 162/688 [00:01<00:03, 146.60it/s]

cost: 2.110785290384288
cost: 2.111994616293924
cost: 2.1131741199267156
cost: 2.114324616240685
cost: 2.115446895814136
cost: 2.1165417256541317
cost: 2.117609849974714
cost: 2.118651990946156
cost: 2.1196688494164913
cost: 2.12066110560647
cost: 2.1216294197790795
cost: 2.1225744328846723
cost: 2.1234967671827296
cost: 2.1243970268412067
cost: 2.1252757985143846
cost: 2.1261336519001097
cost: 2.126971140277246
cost: 2.1277888010241353
cost: 2.1285871561188343
cost: 2.1293667126218407
cost: 2.130127963142012
cost: 2.1308713862863313
cost: 2.1315974470941446
cost: 2.13230659745649
cost: 2.1329992765210823
cost: 2.1336759110835053
cost: 2.1343369159651457
cost: 2.1349826943783627
cost: 2.1356136382793895
cost: 2.1362301287094145
cost: 2.1368325361242966


loss 2.15 accuracy 0.00:  28%|██▊       | 192/688 [00:01<00:03, 145.92it/s]

cost: 2.1374212207133283
cost: 2.137996532707462
cost: 2.1385588126773794
cost: 2.1391083918217793
cost: 2.1396455922462496
cost: 2.1401707272330524
cost: 2.1406841015021567
cost: 2.1411860114638404
cost: 2.1416767454631525
cost: 2.1421565840165293
cost: 2.1426258000408565
cost: 2.1430846590752197
cost: 2.143533419495623
cost: 2.143972332722905
cost: 2.1444016434241107
cost: 2.1448215897075213
cost: 2.1452324033115815
cost: 2.1456343097879222
cost: 2.1460275286786943
cost: 2.146412273688391
cost: 2.146788752850359
cost: 2.147157168688178
cost: 2.1475177183720735
cost: 2.1478705938705414
cost: 2.1482159820973323
cost: 2.148554065053962
cost: 2.1488850199678926
cost: 2.149209019426525
cost: 2.149526231507139


loss 2.16 accuracy 0.00:  32%|███▏      | 222/688 [00:01<00:03, 146.28it/s]

cost: 2.149836819902931
cost: 2.1501409440452557
cost: 2.1504387592222045
cost: 2.150730416693646
cost: 2.1510160638028397
cost: 2.1512958440847187
cost: 2.1515698973709823
cost: 2.151838359892063
cost: 2.152101364376096
cost: 2.1523590401449755
cost: 2.152611513207587
cost: 2.152858906350317
cost: 2.1531013392249134
cost: 2.1533389284337887
cost: 2.153571787612844
cost: 2.1538000275118883
cost: 2.154023756072734
cost: 2.154243078505038
cost: 2.1544580973599547
cost: 2.154668912601678
cost: 2.1548756216769243
cost: 2.15507831958244
cost: 2.155277098930565
cost: 2.155472050012943
cost: 2.1556632608624144
cost: 2.155850817313151
cost: 2.1560348030590943
cost: 2.1562152997107398
cost: 2.1563923868503188
cost: 2.156566142085431


loss 2.16 accuracy 0.00:  37%|███▋      | 252/688 [00:01<00:02, 147.70it/s]

cost: 2.1567366411011686
cost: 2.156903957710778
cost: 2.157068163904914
cost: 2.157229329899502
cost: 2.1573875241822877
cost: 2.1575428135580723
cost: 2.1576952631927027
cost: 2.1578449366558488
cost: 2.1579918959625783
cost: 2.158136201613806
cost: 2.158277912635615
cost: 2.1584170866175
cost: 2.1585537797495595
cost: 2.158688046858674
cost: 2.158819941443687
cost: 2.15894951570963
cost: 2.1590768206010145
cost: 2.1592019058342147
cost: 2.1593248199289747
cost: 2.159445610239055
cost: 2.159564322982059
cost: 2.1596810032684406
cost: 2.1597956951297426
cost: 2.15990844154606
cost: 2.1600192844727784
cost: 2.1601282648665814
cost: 2.160235422710774
cost: 2.160340797039915
cost: 2.1604444259638007
cost: 2.1605463466908015


loss 2.16 accuracy 0.00:  41%|████      | 282/688 [00:02<00:02, 147.61it/s]

cost: 2.1606465955505825
cost: 2.160745208016219
cost: 2.160842218725716
cost: 2.160937661502973
cost: 2.161031569378176
cost: 2.1611239746076625
cost: 2.1612149086932573
cost: 2.1613044024011003
cost: 2.161392485779976
cost: 2.1614791881791633
cost: 2.161564538265817
cost: 2.161648564041892
cost: 2.1617312928606256
cost: 2.161812751442591
cost: 2.161892965891329
cost: 2.1619719617085704
cost: 2.1620497638090685
cost: 2.1621263965350432
cost: 2.162201883670245
cost: 2.1622762484536677
cost: 2.1623495135928876
cost: 2.162421701277079
cost: 2.1624928331896744
cost: 2.162562930520712
cost: 2.16263201397886
cost: 2.16270010380313
cost: 2.162767219774297
cost: 2.162833381226015
cost: 2.162898607055664
cost: 2.162962915734902


loss 2.16 accuracy 0.00:  45%|████▌     | 313/688 [00:02<00:02, 149.05it/s]

cost: 2.1630263253199598
cost: 2.163088853461673
cost: 2.1631505174152497
cost: 2.163211334049798
cost: 2.16327131985761
cost: 2.1633304909632103
cost: 2.1633888631321696
cost: 2.163446451779705
cost: 2.1635032719790535
cost: 2.1635593384696463
cost: 2.1636146656650608
cost: 2.1636692676607865
cost: 2.1637231582417913
cost: 2.1637763508898917
cost: 2.163828858790951
cost: 2.163880694841882
cost: 2.1639318716574905
cost: 2.1639824015771327
cost: 2.1640322966712175
cost: 2.164081568747541
cost: 2.16413022935747
cost: 2.1641782898019595
cost: 2.164225761137433
cost: 2.164272654181514
cost: 2.1643189795186086
cost: 2.1643647475053625
cost: 2.16440996827597
cost: 2.164454651747362
cost: 2.16449880762426
cost: 2.1645424454041104


loss 2.17 accuracy 1.00:  50%|████▉     | 343/688 [00:02<00:02, 147.20it/s]

cost: 2.1645855743818943
cost: 2.1646282036548192
cost: 2.1646703421268993
cost: 2.1647119985134164
cost: 2.164753181345283
cost: 2.1647938989732856
cost: 2.164834159572234
cost: 2.164873971145007
cost: 2.1649133415264985
cost: 2.1649522783874664
cost: 2.164990789238293
cost: 2.1650288814326495
cost: 2.1650665621710714
cost: 2.1651038385044536
cost: 2.1651407173374553
cost: 2.165177205431819
cost: 2.1652133094096238
cost: 2.165249035756444
cost: 2.165284390824441
cost: 2.165319380835382
cost: 2.165354011883578
cost: 2.1653882899387593
cost: 2.1654222208488814
cost: 2.1654558103428574
cost: 2.165489064033233
cost: 2.1655219874187943
cost: 2.1655545858871106
cost: 2.1655868647170236
cost: 2.1656188290810703
cost: 2.165650484047853


loss 2.17 accuracy 0.00:  54%|█████▍    | 374/688 [00:02<00:02, 148.67it/s]

cost: 2.1656818345843534
cost: 2.1657128855581846
cost: 2.165743641739802
cost: 2.1657741078046535
cost: 2.165804288335276
cost: 2.165834187823357
cost: 2.165863810671728
cost: 2.1658931611963257
cost: 2.1659222436281054
cost: 2.1659510621149036
cost: 2.1659796207232573
cost: 2.1660079234401888
cost: 2.166035974174941
cost: 2.1660637767606747
cost: 2.1660913349561275
cost: 2.1661186524472287
cost: 2.166145732848689
cost: 2.166172579705535
cost: 2.1661991964946306
cost: 2.166225586626136
cost: 2.16625175344496
cost: 2.1662777002321625
cost: 2.1663034302063244
cost: 2.166328946524897
cost: 2.1663542522855104
cost: 2.1663793505272513
cost: 2.166404244231922
cost: 2.1664289363252576
cost: 2.1664534296781235
cost: 2.1664777271076843


loss 2.17 accuracy 0.00:  59%|█████▊    | 404/688 [00:02<00:01, 147.37it/s]

cost: 2.1665018313785427
cost: 2.1665257452038555
cost: 2.166549471246425
cost: 2.1665730121197573
cost: 2.166596370389114
cost: 2.1666195485725166
cost: 2.1666425491417507
cost: 2.1666653745233315
cost: 2.1666880270994544
cost: 2.166710509208925
cost: 2.16673282314806
cost: 2.166754971171583
cost: 2.166776955493483
cost: 2.166798778287864
cost: 2.1668204416897745
cost: 2.166841947796016
cost: 2.166863298665933
cost: 2.1668844963221874
cost: 2.166905542751516
cost: 2.1669264399054673
cost: 2.1669471897011263
cost: 2.166967794021822
cost: 2.1669882547178174
cost: 2.167008573606983
cost: 2.1670287524754603
cost: 2.167048793078309
cost: 2.1670686971401376
cost: 2.1670884663557164
cost: 2.1671081023905914
cost: 2.167127606881667


loss 2.17 accuracy 0.00:  63%|██████▎   | 434/688 [00:03<00:01, 146.03it/s]

cost: 2.1671469814377864
cost: 2.167166227640298
cost: 2.167185347043609
cost: 2.1672043411757205
cost: 2.167223211538764
cost: 2.1672419596095134
cost: 2.167260586839892
cost: 2.167279094657468
cost: 2.1672974844659394
cost: 2.1673157576456044
cost: 2.1673339155538303
cost: 2.1673519595255017
cost: 2.167369890873465
cost: 2.1673877108889656
cost: 2.167405420842067
cost: 2.16742302198207
cost: 2.1674405155379186
cost: 2.1674579027185983
cost: 2.16747518471352
cost: 2.1674923626929092
cost: 2.1675094378081705
cost: 2.1675264111922576
cost: 2.16754328396003
cost: 2.167560057208596
cost: 2.167576732017663
cost: 2.167593309449866
cost: 2.167609790551099
cost: 2.1676261763508307
cost: 2.1676424678624233
cost: 2.167658666083436


loss 2.17 accuracy 0.00:  67%|██████▋   | 464/688 [00:03<00:01, 144.66it/s]

cost: 2.1676747719959253
cost: 2.167690786566743
cost: 2.16770671074782
cost: 2.1677225454764506
cost: 2.167738291675568
cost: 2.167753950254013
cost: 2.1677695221068016
cost: 2.1677850081153807
cost: 2.1678004091478837
cost: 2.1678157260593793
cost: 2.167830959692113
cost: 2.1678461108757454
cost: 2.167861180427588
cost: 2.167876169152825
cost: 2.167891077844747
cost: 2.1679059072849585
cost: 2.167920658243603
cost: 2.1679353314795655
cost: 2.16794992774068
cost: 2.167964447763934
cost: 2.167978892275662
cost: 2.1679932619917412
cost: 2.16800755761778
cost: 2.168021779849303
cost: 2.1680359293719325
cost: 2.1680500068615687
cost: 2.1680640129845634
cost: 2.1680779483978863
cost: 2.1680918137493
cost: 2.168105609677518


loss 2.17 accuracy 0.00:  72%|███████▏  | 496/688 [00:03<00:01, 148.17it/s]

cost: 2.1681193368123677
cost: 2.1681329957749482
cost: 2.168146587177783
cost: 2.1681601116249727
cost: 2.1681735697123443
cost: 2.1681869620275904
cost: 2.16820028915042
cost: 2.1682135516526913
cost: 2.16822675009855
cost: 2.1682398850445663
cost: 2.168252957039863
cost: 2.168265966626244
cost: 2.168278914338319
cost: 2.1682918007036354
cost: 2.1683046262427883
cost: 2.168317391469546
cost: 2.1683300968909673
cost: 2.168342743007508
cost: 2.1683553303131418
cost: 2.1683678592954663
cost: 2.16838033043581
cost: 2.168392744209337
cost: 2.168405101085154
cost: 2.1684174015264084
cost: 2.1684296459903885
cost: 2.1684418349286223
cost: 2.1684539687869715
cost: 2.168466048005726
cost: 2.1684780730196964
cost: 2.168490044258304
cost: 2.1685019621456676


loss 2.17 accuracy 1.00:  77%|███████▋  | 528/688 [00:03<00:01, 151.15it/s]

cost: 2.1685138271006936
cost: 2.1685256395371577
cost: 2.16853739986379
cost: 2.1685491084843576
cost: 2.1685607657977437
cost: 2.1685723721980277
cost: 2.168583928074561
cost: 2.1685954338120426
cost: 2.168606889790597
cost: 2.1686182963858425
cost: 2.1686296539689667
cost: 2.1686409629067946
cost: 2.1686522235618573
cost: 2.168663436292462
cost: 2.1686746014527585
cost: 2.1686857193927978
cost: 2.168696790458607
cost: 2.1687078149922434
cost: 2.1687187933318586
cost: 2.1687297258117626
cost: 2.1687406127624764
cost: 2.1687514545107955
cost: 2.1687622513798472
cost: 2.1687730036891413
cost: 2.1687837117546325
cost: 2.168794375888771
cost: 2.168804996400553
cost: 2.1688155735955776
cost: 2.1688261077760957
cost: 2.168836599241062
cost: 2.1688470482861795


loss 2.17 accuracy 0.00:  81%|████████▏ | 560/688 [00:03<00:00, 152.65it/s]

cost: 2.1688574552039555
cost: 2.1688678202837424
cost: 2.1688781438117886
cost: 2.1688884260712826
cost: 2.168898667342399
cost: 2.168908867902343
cost: 2.168919028025391
cost: 2.1689291479829382
cost: 2.1689392280435373
cost: 2.16894926847294
cost: 2.1689592695341386
cost: 2.1689692314874045
cost: 2.1689791545903283
cost: 2.1689890390978577
cost: 2.168998885262335
cost: 2.169008693333537
cost: 2.1690184635587055
cost: 2.1690281961825884
cost: 2.1690378914474726
cost: 2.1690475495932193
cost: 2.169057170857297
cost: 2.1690667554748178
cost: 2.1690763036785636
cost: 2.1690858156990265
cost: 2.1690952917644353
cost: 2.169104732100788
cost: 2.169114136931881
cost: 2.1691235064793424
cost: 2.1691328409626602
cost: 2.16914214059921
cost: 2.1691514056042838


loss 2.17 accuracy 0.00:  86%|████████▌ | 592/688 [00:04<00:00, 152.41it/s]

cost: 2.1691606361911187
cost: 2.1691698325709274
cost: 2.1691789949529188
cost: 2.1691881235443313
cost: 2.169197218550454
cost: 2.1692062801746586
cost: 2.169215308618416
cost: 2.1692243040813306
cost: 2.1692332667611574
cost: 2.1692421968538316
cost: 2.169251094553488
cost: 2.1692599600524898
cost: 2.169268793541444
cost: 2.1692775952092305
cost: 2.1692863652430225
cost: 2.1692951038283064
cost: 2.169303811148906
cost: 2.1693124873870024
cost: 2.169321132723155
cost: 2.169329747336322
cost: 2.1693383314038797
cost: 2.169346885101645
cost: 2.169355408603891
cost: 2.16936390208337
cost: 2.1693723657113297
cost: 2.1693807996575343
cost: 2.1693892040902796
cost: 2.1693975791764135
cost: 2.1694059250813553
cost: 2.1694142419691067
cost: 2.1694225300022763


loss 2.17 accuracy 0.00:  91%|█████████ | 624/688 [00:04<00:00, 150.37it/s]

cost: 2.1694307893420923
cost: 2.1694390201484213
cost: 2.169447222579784
cost: 2.1694553967933667
cost: 2.1694635429450484
cost: 2.1694716611894056
cost: 2.1694797516797313
cost: 2.1694878145680527
cost: 2.169495850005143
cost: 2.1695038581405366
cost: 2.169511839122546
cost: 2.169519793098272
cost: 2.169527720213622
cost: 2.16953562061332
cost: 2.169543494440924
cost: 2.169551341838836
cost: 2.1695591629483184
cost: 2.169566957909502
cost: 2.169574726861407
cost: 2.1695824699419464
cost: 2.1695901872879473
cost: 2.1695978790351536
cost: 2.1696055453182486
cost: 2.1696131862708583
cost: 2.169620802025567
cost: 2.16962839271393
cost: 2.1696359584664804
cost: 2.169643499412746
cost: 2.1696510156812576
cost: 2.169658507399558


loss 2.17 accuracy 0.00:  95%|█████████▌| 656/688 [00:04<00:00, 150.95it/s]

cost: 2.169665974694216
cost: 2.1696734176908365
cost: 2.169680836514068
cost: 2.169688231287618
cost: 2.1696956021342566
cost: 2.169702949175835
cost: 2.1697102725332864
cost: 2.1697175723266424
cost: 2.1697248486750382
cost: 2.1697321016967264
cost: 2.169739331509081
cost: 2.1697465382286123
cost: 2.16975372197097
cost: 2.16976088285096
cost: 2.169768020982543
cost: 2.1697751364788544
cost: 2.1697822294522013
cost: 2.1697893000140813
cost: 2.169796348275185
cost: 2.1698033743454075
cost: 2.16981037833385
cost: 2.1698173603488375
cost: 2.169824320497918
cost: 2.1698312588878776
cost: 2.169838175624742
cost: 2.1698450708137873
cost: 2.1698519445595466
cost: 2.1698587969658196
cost: 2.1698656281356747
cost: 2.1698724381714634
cost: 2.16987922717482


loss 2.17 accuracy 0.00: 100%|██████████| 688/688 [00:04<00:00, 147.08it/s]

cost: 2.1698859952466734
cost: 2.1698927424872534
cost: 2.1698994689960958
cost: 2.1699061748720516
cost: 2.1699128602132896
cost: 2.1699195251173076
cost: 2.1699261696809384
cost: 2.16993279400035
cost: 2.1699393981710613
cost: 2.1699459822879397
cost: 2.1699525464452147
cost: 2.169959090736478
cost: 2.169965615254693
cost: 2.1699721200921998
cost: 2.1699786053407193
cost: 2.169985071091362
cost: 2.1699915174346334
cost: 2.1699979444604356
cost: 2.1700043522580774
cost: 2.1700107409162785
cost: 2.1700171105231743
cost: 2.1700234611663203
cost: 2.170029792932701
cost: 2.170036105908731
cost: 2.1700424001802627
cost: 2.170048675832589
cost: 2.1700549329504524
cost: 2.1700611716180465





In [8]:
# Model training
