In [1]:
import numpy as np
import torch
import torch.nn as nn
import sys

In [2]:
def wSum(X,W):
    h = torch.from_numpy(X)
    z = torch.matmul(h,W)
    return z

In [3]:
def activate(x):
    return 1/(1+torch.exp(-x))

In [4]:
def forwardStep(X,W_list):
    h = torch.from_numpy(X)
    for w in W_list:
        z = torch.matmul(w,h)
        h = activate(z)
    return h

In [5]:
def UpdateParams(W_list,dW_list,lr):
    with torch.no_grad():
        for i in range(len(W_list)):
            W_list[i] -= lr*dW_list[i]
    return W_list

In [6]:
def trainNN_sgd(X,y,W_list,loss_fn,lr=0.0001,n_epochs=100):
    for epoch in range(n_epochs):
        avgLoss = []
        for i in range(len(y)):
            Xin = X[i,:]
            y_true = y[i]
            y_pred = forwardStep(Xin,W_list)
            loss = loss_fn(y_pred,torch.tensor(y_true,dtype=torch.double))    
            loss.backward()
            avgLoss.append(loss.item())
            sys.stdout.flush()
            dW_list = []
            for j in range(len(W_list)):
                dW_list.append(W_list[j].grad.data)
            W_list = UpdateParams(W_list,dW_list,lr)
            for j in range(len(W_list)):
                W_list[j].grad.data.zero_()
        print("Loss after epoch=%d: %f" %(epoch,np.mean(np.array(avgLoss))))
    return W_list

In [7]:
def trainNN_batch(X,y,W_list,loss_fn,lr=0.0001,n_epochs=100):
    n = len(y)
    for epoch in range(n_epochs):
        loss = 0
        for i in range(n):
            Xin = X[i,:]
            y_true = y[i]
            y_pred = forwardStep(Xin,W_list)
            loss += loss_fn(y_pred,torch.tensor(y_true,dtype=torch.double)) 
        loss = loss/n
        loss.backward()
        sys.stdout.flush()
        dW_list = []
        for j in range(len(W_list)):
            dW_list.append(W_list[j].grad.data)
        W_list = UpdateParams(W_list,dW_list,lr)
        for j in range(len(W_list)):
            W_list[j].grad.data.zero_()
        print("Loss after epoch=%d: %f" %(epoch,loss))
    return W_list

In [8]:
def trainNN_minibatch(X,y,W_list,loss_fn,lr=0.0001,n_epochs=100,batchSize=16):
    n = len(y)
    numbatches = n//batchSize
    for epoch in range(n_epochs):
        for batch in range(numbatches):
            X_batch = X[batch*batchSize:(batch+1)*batchSize,:]
            y_batch = y[batch*batchSize:(batch+1)*batchSize]
            loss = 0
            for i in range(batchSize):
                Xin = X_batch[i,:]
                y_true = y_batch[i]
                y_pred = forwardStep(Xin,W_list)
                loss += loss_fn(y_pred,torch.tensor(y_true,dtype=torch.double)) 
            loss = loss/batchSize
            loss.backward()
            sys.stdout.flush()
            dW_list = []
            for j in range(len(W_list)):
                dW_list.append(W_list[j].grad.data)
            W_list = UpdateParams(W_list,dW_list,lr)
            for j in range(len(W_list)):
                W_list[j].grad.data.zero_()
        print("Loss after epoch=%d: %f" %(epoch,loss/numbatches))
    return W_list

In [9]:
inDim = 10
n = 1000
X = np.random.rand(n,inDim)
y = np.random.randint(0,2,n)

W1 = torch.tensor(np.random.uniform(0,1,(2,inDim)),requires_grad=True)
W2 = torch.tensor(np.random.uniform(0,1,(3,2)),requires_grad=True)
W3 = torch.tensor(np.random.uniform(0,1,3),requires_grad=True)

W_list = []
W_list.append(W1)
W_list.append(W2)
W_list.append(W3)

loss_fn = nn.BCELoss()
# W_list = trainNN_sgd(X,y,W_list,loss_fn,lr=0.0001,n_epochs=100)
# W_list = trainNN_batch(X,y,W_list,loss_fn,lr=0.0001,n_epochs=100)
W_list = trainNN_minibatch(X,y,W_list,loss_fn,lr=0.0001,n_epochs=100)

Loss after epoch=0: 0.014794
Loss after epoch=1: 0.014781
Loss after epoch=2: 0.014768
Loss after epoch=3: 0.014755
Loss after epoch=4: 0.014742
Loss after epoch=5: 0.014730
Loss after epoch=6: 0.014717
Loss after epoch=7: 0.014704
Loss after epoch=8: 0.014692
Loss after epoch=9: 0.014679
Loss after epoch=10: 0.014666
Loss after epoch=11: 0.014654
Loss after epoch=12: 0.014641
Loss after epoch=13: 0.014629
Loss after epoch=14: 0.014616
Loss after epoch=15: 0.014604
Loss after epoch=16: 0.014592
Loss after epoch=17: 0.014580
Loss after epoch=18: 0.014567
Loss after epoch=19: 0.014555
Loss after epoch=20: 0.014543
Loss after epoch=21: 0.014531
Loss after epoch=22: 0.014519
Loss after epoch=23: 0.014507
Loss after epoch=24: 0.014495
Loss after epoch=25: 0.014483
Loss after epoch=26: 0.014471
Loss after epoch=27: 0.014459
Loss after epoch=28: 0.014447
Loss after epoch=29: 0.014436
Loss after epoch=30: 0.014424
Loss after epoch=31: 0.014412
Loss after epoch=32: 0.014401
Loss after epoch=33:

In [10]:
inDim = 10
n = 1000
X = np.random.rand(n,inDim)
y = np.random.randint(0,2,n)

In [11]:
X.shape

(1000, 10)

In [12]:
y.shape

(1000,)

In [13]:
np.unique(y)

array([0, 1])

In [14]:
W = torch.tensor(np.random.uniform(0,1,inDim),requires_grad=True)

In [15]:
print(W.shape)
print(X[0,:].shape)
z = wSum(X[0,:],W)

torch.Size([10])
(10,)


In [16]:
z

tensor(3.1883, dtype=torch.float64, grad_fn=<DotBackward0>)

In [17]:
inDim = 10
n = 1000
X = np.random.rand(n,inDim)
y = np.random.randint(0,2,n)

W1 = torch.tensor(np.random.uniform(0,1,(2,inDim)),requires_grad=True)
W2 = torch.tensor(np.random.uniform(0,1,(3,2)),requires_grad=True)
W3 = torch.tensor(np.random.uniform(0,1,3),requires_grad=True)

W_list = []
W_list.append(W1)
W_list.append(W2)
W_list.append(W3)
z = forwardStep(X[0,:],W_list)
print(z)

tensor(0.6014, dtype=torch.float64, grad_fn=<MulBackward0>)


In [18]:
# activation_func = nn.Sigmoid()
# # activation_func = nn.ReLU()
# x = 100*torch.rand(1)
# y = torch.randint(0,2,(1,),dtype=torch.float)
# y_hat = activation_func(x)
# loss_fn = nn.BCELoss()
# loss = loss_fn(y_hat,y)
# print(loss.item())

In [19]:
m = nn.Sigmoid()
lr = 0.001
loss = nn.BCELoss()
x = torch.rand(1)
y = torch.randint(0,2,(1,),dtype=torch.float)
w = torch.rand(1,requires_grad=True)

In [20]:
n_iter = 100
for i in range(n_iter):
    y_hat = m(x*w)
    l = loss(y_hat,y)
    l.backward()
    dw = w.grad.data
    with torch.no_grad():
        w -= lr*dw
    w.grad.zero_()
    print(l.item())

0.7764442563056946
0.7764291763305664
0.776414155960083
0.7763991355895996
0.7763841152191162
0.7763689756393433
0.7763538956642151
0.7763388752937317
0.7763238549232483
0.7763088345527649
0.7762938141822815
0.7762787938117981
0.7762637734413147
0.776248574256897
0.7762337327003479
0.7762185335159302
0.7762035131454468
0.7761884927749634
0.77617347240448
0.7761584520339966
0.7761434316635132
0.7761284112930298
0.7761133909225464
0.7760981917381287
0.7760833501815796
0.7760681509971619
0.7760531306266785
0.7760381102561951
0.7760230898857117
0.7760080695152283
0.7759930491447449
0.7759780287742615
0.7759630084037781
0.7759478688240051
0.7759329676628113
0.7759179472923279
0.7759029269218445
0.7758879065513611
0.7758727669715881
0.7758577466011047
0.7758427262306213
0.7758277058601379
0.7758126854896545
0.7757976651191711
0.7757826447486877
0.7757676243782043
0.775752604007721
0.775737464427948
0.7757225632667542
0.7757074236869812
0.7756924033164978
0.775677502155304
0.775662362575531
0

In [22]:
import torch.nn as nn
from torch.utils.data import TensorDataset,DataLoader

In [78]:
inDim = 10
n = 1000
X = np.random.rand(n,inDim)
y = np.random.randint(0,2,(n,))

tensor_x = torch.tensor(X,dtype=torch.float)
tensor_y = torch.tensor(y,dtype=torch.float)
Xy = TensorDataset(tensor_x,tensor_y)
Xy_loader = DataLoader(Xy,batch_size=16,shuffle=True,drop_last=True)

In [79]:
model = nn.Sequential(
    nn.Linear(inDim,200),
    nn.ReLU(),
    nn.Linear(200,100),
    nn.Tanh(),
    nn.Linear(100,1),
    nn.Sigmoid()
)

In [80]:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

In [81]:
loss_fn = nn.BCELoss()

In [84]:
n_epochs = 100
for epochs in range(n_epochs):
    for X,y in Xy_loader:
        batch_size = X.shape[0]
        y_hat = model(X.view(batch_size,-1))
        loss = loss_fn(y_hat.squeeze(),y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(float(loss))

0.04301806539297104
0.230081245303154
0.01376318372786045
0.03853791579604149
0.08929727226495743
0.025465210899710655
0.07472198456525803
0.04494418203830719
0.013303088024258614
0.04164240509271622
0.0358283594250679
0.06571105867624283
0.022317517548799515
0.06925579160451889
0.15819379687309265
0.019611796364188194
0.009875250048935413
0.036558520048856735
0.016004817560315132
0.18601268529891968
0.09103182703256607
0.08455430716276169
0.0769190862774849
0.057886045426130295
0.02949903905391693
0.021850544959306717
0.038273267447948456
0.1318669319152832
0.050080493092536926
0.04610443487763405
0.12209693342447281
0.0394294336438179
0.031318679451942444
0.1132856160402298
0.09140631556510925
0.052667949348688126
0.07115688174962997
0.03213868290185928
0.020482974126935005
0.02855503372848034
0.06897559762001038
0.09310086816549301
0.05120164901018143
0.1113477498292923
0.09390976279973984
0.1304098516702652
0.08256924152374268
0.04143328219652176
0.07089288532733917
0.0594827085733

0.08110445737838745
0.2473539561033249
0.04532715305685997
0.07955584675073624
0.5434329509735107
0.06591150909662247
0.15562038123607635
0.008314847946166992
0.18031489849090576
0.2026062160730362
0.0967698022723198
0.09346407651901245
0.05031617730855942
0.12332776188850403
0.2559177577495575
0.1438104212284088
0.12443505972623825
0.07602789998054504
0.14071646332740784
0.09500947594642639
0.23000580072402954
0.11120210587978363
0.04618033766746521
0.028568796813488007
0.057377889752388
0.021338744089007378
0.20814022421836853
0.05567730963230133
0.19488966464996338
0.015223054215312004
0.06352362036705017
0.05392686650156975
0.04899977520108223
0.2628999948501587
0.31630176305770874
0.04904500022530556
0.06118996813893318
0.061395470052957535
0.34843209385871887
0.034994207322597504
0.1991044282913208
0.05761462450027466
0.12798288464546204
0.1591247022151947
0.06912294775247574
0.11644548177719116
0.04827526956796646
0.18224534392356873
0.02104192227125168
0.1191769540309906
0.2875

0.009788940660655499
0.0432545505464077
0.03812635689973831
0.10548704117536545
0.027140730991959572
0.03418883681297302
0.024614471942186356
0.028800049796700478
0.056656260043382645
0.0646175667643547
0.0758170485496521
0.02912786416709423
0.015260472893714905
0.07054035365581512
0.02970660850405693
0.04774384945631027
0.14227595925331116
0.06858223676681519
0.012190739624202251
0.044130563735961914
0.07429017126560211
0.040784671902656555
0.025613192468881607
0.019008444622159004
0.06863139569759369
0.031800903379917145
0.013584866188466549
0.04297073185443878
0.036628320813179016
0.04910515993833542
0.008620216511189938
0.025040224194526672
0.02419380098581314
0.021794050931930542
0.02353646606206894
0.041325654834508896
0.034459542483091354
0.05284476652741432
0.02764230966567993
0.032185714691877365
0.03503619506955147
0.03248632326722145
0.026403330266475677
0.03458767756819725
0.011618305929005146
0.016644397750496864
0.06138349324464798
0.049227360635995865
0.01466866116970777

0.04572807624936104
0.03336331993341446
0.07045260071754456
0.04358422011137009
0.015677116811275482
0.11291588097810745
0.05419747531414032
0.038948070257902145
0.021327955648303032
0.14493408799171448
0.05256371572613716
0.018355682492256165
0.03258741274476051
0.04492711275815964
0.015509548597037792
0.0025575042236596346
0.05481717362999916
0.025542855262756348
0.09037081152200699
0.04414550960063934
0.013688706792891026
0.015824537724256516
0.14229543507099152
0.08901653438806534
0.07535585761070251
0.03566713258624077
0.10593265295028687
0.16245415806770325
0.12683750689029694
0.014289132319390774
0.0444507971405983
0.01201335433870554
0.0395781509578228
0.02042790688574314
0.018478738144040108
0.04363425448536873
0.03335106372833252
0.02336570993065834
0.03957267850637436
0.003475502133369446
0.030382737517356873
0.01603510044515133
0.032574135810136795
0.060627102851867676
0.043508727103471756
0.03773055598139763
0.19823074340820312
0.005746329668909311
0.040644433349370956
0.0

0.1418684422969818
0.011197595857083797
0.02852412313222885
0.02476757951080799
0.05490921437740326
0.05609353259205818
0.07680611312389374
0.0410567969083786
0.3595616817474365
0.05692201852798462
0.004690234083682299
0.16725324094295502
0.0831575021147728
0.12730304896831512
0.020809629932045937
0.016108477488160133
0.009831385686993599
0.03777872025966644
0.032446134835481644
0.05806044489145279
0.05911034345626831
0.03570132330060005
0.013358866795897484
0.06536701321601868
0.018074607476592064
0.2816857397556305
0.03381934389472008
0.26391828060150146
0.030369479209184647
0.08594423532485962
0.11080092191696167
0.023068927228450775
0.04804646596312523
0.07854638993740082
0.07421735674142838
0.049607761204242706
0.09415251016616821
0.018799372017383575
0.009512979537248611
0.10704533010721207
0.17928823828697205
0.06438496708869934
0.08821867406368256
0.016006994992494583
0.04471495747566223
0.03431836888194084
0.029735852032899857
0.009942417033016682
0.07901253551244736
0.0669404

0.031621210277080536
0.4489831030368805
0.2738288640975952
0.2657713294029236
0.23014555871486664
0.10418841242790222
0.1866510510444641
0.018662532791495323
0.31644803285598755
0.17071981728076935
0.23836737871170044
0.21579387784004211
0.24674315750598907
0.08602806180715561
0.463201105594635
0.13995584845542908
0.1708804965019226
0.09910248965024948
0.2659694254398346
0.37000158429145813
0.1383741945028305
0.23560330271720886
0.053976599127054214
0.07989663630723953
0.0681186318397522
0.04167228192090988
0.6501432061195374
0.04444209858775139
0.7780599594116211
0.21048785746097565
0.006559005938470364
0.5681865811347961
0.45923325419425964
0.2734726369380951
0.3215053677558899
0.03682873025536537
0.034420110285282135
0.12297236919403076
0.1282421350479126
0.07432147860527039
0.09372133016586304
0.12445665150880814
0.41359594464302063
0.04007003456354141
0.13462449610233307
0.1471211463212967
0.18839342892169952
0.19460438191890717
0.09114271402359009
0.3076449930667877
0.06054165586

0.008524483069777489
0.0060265096835792065
0.019490867853164673
0.00926119927316904
0.011075212620198727
0.00602660933509469
0.014982569962739944
0.001839182572439313
0.00434720516204834
0.011929601430892944
0.02864961326122284
0.02892191708087921
0.009643606841564178
0.025458337739109993
0.006993026938289404
0.009226109832525253
0.042218729853630066
0.005450740456581116
0.007474598474800587
0.02637634240090847
0.00646638497710228
0.00850814487785101
0.005929224193096161
0.010393024422228336
0.029844144359230995
0.023951200768351555
0.006033667828887701
0.02827455848455429
0.014034762047231197
0.009907391853630543
0.009822647087275982
0.019551541656255722
0.008195936679840088
0.002332197967916727
0.003386708442121744
0.003961073700338602
0.007814859040081501
0.01554364338517189
0.010939390398561954
0.009754423052072525
0.007791173178702593
0.014908190816640854
0.023313503712415695
0.0015647344989702106
0.022021586075425148
0.0038033376913517714
0.016731755807995796
0.014090687036514282

0.016511833295226097
0.025331707671284676
0.06680856645107269
0.01897304132580757
0.008223350159823895
0.0038310254458338022
0.005613564979285002
0.010023600421845913
0.014081308618187904
0.0061675324104726315
0.021196186542510986
0.008584672585129738
0.012202140875160694
0.01651420071721077
0.007499828934669495
0.020380180329084396
0.011038467288017273
0.005475912243127823
0.015818800777196884
0.00431902427226305
0.0037221666425466537
0.01586531102657318
0.005793592892587185
0.014033181592822075
0.010315469466149807
0.0025696761440485716
0.015321834944188595
0.04293813928961754
0.003489760449156165
0.005191804841160774
0.0046266126446425915
0.025791913270950317
0.011412052437663078
0.0033606106881052256
0.002844952279701829
0.012709097005426884
0.010381859727203846
0.005423072259873152
0.015298000536859035
0.006850134581327438
0.014178628101944923
0.004452697467058897
0.00808478519320488
0.016539962962269783
0.01670709066092968
0.009044595994055271
0.0054281409829854965
0.008846147917

0.0035127298906445503
0.007406302727758884
0.0026416215114295483
0.0035738658625632524
0.004657617770135403
0.003733802353963256
0.0021069166250526905
0.0017370638670399785
0.01761709526181221
0.010879882611334324
0.011710298247635365
0.0035353777930140495
0.01082172803580761
0.005433331243693829
0.010520368814468384
0.001796526717953384
0.005971373524516821
0.002784517128020525
0.007618854288011789
0.011031553149223328
0.004391290247440338
0.005325491540133953
0.009914989583194256
0.004070695489645004
0.008624518290162086
0.011745908297598362
0.0007638962124474347
0.0074148052372038364
0.007941646501421928
0.007621360942721367
0.04915997385978699
0.004558525513857603
0.008709577843546867
0.01353747770190239
0.01443228218704462
0.01893238164484501
0.0087432861328125
0.014917821623384953
0.012750356458127499
0.004661726299673319
0.004913561977446079
0.012617279775440693
0.0020278748124837875
0.009993262588977814
0.007930261082947254
0.005086745135486126
0.008407777175307274
0.0137792788

0.0035858580376952887
0.009646500460803509
0.0041984496638178825
0.00455086026340723
0.009631025604903698
0.005118814297020435
0.007081034127622843
0.0030988557264208794
0.005152276251465082
0.0014091301709413528
0.012987017631530762
0.007406999357044697
0.004824000410735607
0.0027683181688189507
0.007071591913700104
0.0039051929488778114
0.0022414454724639654
0.005491372663527727
0.010069184936583042
0.0016910206759348512
0.0029978197999298573
0.024240199476480484
0.0046516903676092625
0.012691578827798367
0.002973714377731085
0.01146464329212904
0.00221817079000175
0.009493466466665268
0.0022522900253534317
0.000913429306820035
0.00792866200208664
0.003167755901813507
0.0069471849128603935
0.0019773882813751698
0.0024937614798545837
0.00471853744238615
0.00258023664355278
0.007847490720450878
0.00438353605568409
0.003777184057980776
0.0008893245831131935
0.004907272756099701
0.004732300993055105
0.012848308309912682
0.007336975075304508
0.0047832876443862915
0.00460984418168664
0.005

0.006191398482769728
0.0025100759230554104
0.003783815074712038
0.00509241595864296
0.003907621838152409
0.0023043497931212187
0.0010109972208738327
0.005014681722968817
0.006981280632317066
0.0017610379727557302
0.005971503909677267
0.0076512801460921764
0.0021468589548021555
0.02218821458518505
0.004132159985601902
0.0014276237925514579
0.0033078298438340425
0.008615161292254925
0.0032271072268486023
0.004613752476871014
0.004508946090936661
0.00888961274176836
0.002041904255747795
0.0031442558392882347
0.0019966347608715296
0.006178815849125385
0.00418558344244957
0.005264692939817905
0.0007159247761592269
0.004979354329407215
0.0028755527455359697
0.002900373423472047
0.0036319217178970575
0.0007692275103181601
0.004695891868323088
0.0025188149884343147
0.0038598808459937572
0.008860514499247074
0.0029069199226796627
0.0064393142238259315
0.0009433201048523188
0.004351502750068903
0.005493715405464172
0.0011059256503358483
0.008060324005782604
0.005616515409201384
0.004513342399150

0.04242749512195587
0.033382657915353775
0.07118983566761017
0.0827215164899826
0.04037998989224434
0.12629356980323792
0.14047157764434814
0.07021673023700714
0.052445173263549805
0.08658028393983841
0.010412351228296757
0.029853777959942818
0.14732711017131805
0.07885433733463287
0.07217275351285934
0.1337275207042694
0.023652605712413788
0.04616214334964752
0.028801796957850456
0.13373258709907532
0.1541370153427124
0.27862289547920227
0.11116516590118408
0.006460594013333321
0.054501745849847794
0.0991128608584404
0.05119606480002403
0.030759314075112343
0.007546392269432545
0.00873345322906971
0.02648678608238697
0.01715695671737194
0.015202734619379044
0.041848085820674896
0.003874892368912697
0.05857585743069649
0.06400841474533081
0.1040920689702034
0.08369893580675125
0.00859948992729187
0.027966752648353577
0.021848170086741447
0.03355612978339195
0.10505180060863495
0.00929433573037386
0.03648892790079117
0.012226101942360401
0.0172540582716465
0.007047575432807207
0.0651077

0.011048927903175354
0.02049926295876503
0.0192766934633255
0.004246181342750788
0.02249154821038246
0.005623652134090662
0.017010614275932312
0.025126826018095016
0.023616893216967583
0.02014896087348461
0.03785646706819534
0.023407189175486565
0.025725871324539185
0.026625510305166245
0.002988518215715885
0.010047724470496178
0.046360187232494354
0.0159921832382679
0.014795425347983837
0.016498861834406853
0.012592119164764881
0.008565571159124374
0.020253527909517288
0.023858724161982536
0.030686447396874428
0.00524491909891367
0.03171631321310997
0.010371473617851734
0.02182735688984394
0.01450536958873272
0.027243301272392273
0.009818167425692081
0.008266683667898178
0.005600667558610439
0.010403357446193695
0.007805254310369492
0.021775376051664352
0.00542974378913641
0.018463874235749245
0.0064019085839390755
0.029064089059829712
0.01226307637989521
0.012999877333641052
0.03722064569592476
0.006519939750432968
0.011712496168911457
0.02237655594944954
0.08926118910312653
0.017428

0.012700241059064865
0.0072769708931446075
0.02208840288221836
0.0061354548670351505
0.009890353307127953
0.006111487280577421
0.007310853805392981
0.005336598493158817
0.004400551784783602
0.00938701257109642
0.008240102790296078
0.010109173133969307
0.0035625859163701534
0.0030979644507169724
0.012520362623035908
0.004840687848627567
0.010102109983563423
0.011690709739923477
0.006551822647452354
0.005981164518743753
0.009256815537810326
0.009134234860539436
0.014301485382020473
0.0042741261422634125
0.0048127626068890095
0.01682751253247261
0.010757355019450188
0.023121381178498268
0.010878346860408783
0.007341640070080757
0.00747310183942318
0.007758519612252712
0.0059570493176579475
0.010887231677770615
0.013348022475838661
0.009049959480762482
0.03406097739934921
0.03230304270982742
0.008060231804847717
0.014674875885248184
0.0034822376910597086
0.012024002149701118
0.008914327248930931
0.003321663476526737
0.006000513210892677
0.007027013227343559
0.008802108466625214
0.010929023

0.007616711780428886
0.0196426622569561
0.006727935746312141
0.013407467864453793
0.00660160044208169
0.008890656754374504
0.019753815606236458
0.001795138232409954
0.007660809904336929
0.0223041120916605
0.014287808910012245
0.0065799374133348465
0.013180097565054893
0.012139931321144104
0.0058303941041231155
0.013481659814715385
0.0036215169820934534
0.00459866039454937
0.004587160889059305
0.005906959529966116
0.003933819010853767
0.014456981793045998
0.014607510529458523
0.0016121792141348124
0.01069085393100977
0.009802588261663914
0.003832664806395769
0.014279499650001526
0.005777076352387667
0.010868273675441742
0.005818080622702837
0.006597977597266436
0.007451415993273258
0.010123195126652718
0.006555203348398209
0.002836038824170828
0.01171757373958826
0.014128648675978184
0.010275036096572876
0.0028510219417512417
0.009197695180773735
0.0059517561458051205
0.013461878523230553
0.009843495674431324
0.011285119690001011
0.0036744496319442987
0.006039517931640148
0.013300928287

In [77]:
y_hat.squeeze().shape

torch.Size([16])

In [67]:
y.shape

torch.Size([16])