In [1]:
import torch as t
from torch.autograd import Variable as Var
import torch.nn.functional as F

In [2]:
x_data = Var(t.Tensor([
    [2.1, 0.1],
    [4.2, 0.8],
    [3.1, 0.9],
    [3.3, 0.2]
]))
y_data = Var(t.Tensor([
    [0.0],
    [1.0],
    [0.0],
    [1.0]
]))

$ XW = \hat{y} $

$\begin{bmatrix}
  a_1 & b_1\\
  a_2 & b_2 \\
  ... \\
  a_n & b_n
\end{bmatrix} 
\begin{bmatrix}
w_1 \\  w_2
\end{bmatrix}
=
\begin{bmatrix}
 y_1 \\
 ... \\
 y_n
\end{bmatrix}
$

Remember, it's a matrix of features dot product a matrix of weights to make a set of predictions

In the example above, $ a_1 \&  b_1 $ are the row vector and weights are the column vector

Multiple features is considered **wide** input

In [3]:
class Model(t.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = t.nn.Linear(2, 1) # isn't it a 2 x 1?
    def forward(self, x_data):
        y_pred = F.sigmoid(self.linear(x_data))
        return y_pred

In [4]:
model = Model()
criterion = t.nn.BCELoss(size_average=True)
optimizer = t.optim.SGD(model.parameters(), lr = 0.01) 

In [5]:
for epoch in range(1000):
    y_pred = model.forward(x_data)
    loss = criterion(y_pred, y_data)
    print(f'epoch = {epoch}, loss = {loss.item()}') # this deviates from the lecture
    optimizer.zero_grad()
    loss.backward() # calc for optimizer
    optimizer.step() # apply to model

epoch = 0, loss = 1.5943940877914429
epoch = 1, loss = 1.5640308856964111
epoch = 2, loss = 1.5340584516525269
epoch = 3, loss = 1.5044918060302734
epoch = 4, loss = 1.4753459692001343
epoch = 5, loss = 1.4466360807418823
epoch = 6, loss = 1.418377161026001
epoch = 7, loss = 1.390583872795105
epoch = 8, loss = 1.3632711172103882
epoch = 9, loss = 1.3364529609680176
epoch = 10, loss = 1.3101433515548706
epoch = 11, loss = 1.2843561172485352
epoch = 12, loss = 1.2591036558151245
epoch = 13, loss = 1.2343984842300415
epoch = 14, loss = 1.210252046585083
epoch = 15, loss = 1.186674952507019
epoch = 16, loss = 1.1636772155761719
epoch = 17, loss = 1.1412672996520996
epoch = 18, loss = 1.119452953338623
epoch = 19, loss = 1.098240852355957
epoch = 20, loss = 1.0776362419128418
epoch = 21, loss = 1.057643175125122
epoch = 22, loss = 1.0382643938064575
epoch = 23, loss = 1.0195013284683228
epoch = 24, loss = 1.0013542175292969
epoch = 25, loss = 0.9838215708732605
epoch = 26, loss = 0.96690100

epoch = 399, loss = 0.6025914549827576
epoch = 400, loss = 0.6025183796882629
epoch = 401, loss = 0.6024453043937683
epoch = 402, loss = 0.6023722887039185
epoch = 403, loss = 0.6022992730140686
epoch = 404, loss = 0.6022262573242188
epoch = 405, loss = 0.6021533012390137
epoch = 406, loss = 0.6020803451538086
epoch = 407, loss = 0.6020073294639587
epoch = 408, loss = 0.6019343733787537
epoch = 409, loss = 0.6018615365028381
epoch = 410, loss = 0.6017885208129883
epoch = 411, loss = 0.601715624332428
epoch = 412, loss = 0.6016427874565125
epoch = 413, loss = 0.6015698909759521
epoch = 414, loss = 0.6014969348907471
epoch = 415, loss = 0.6014241576194763
epoch = 416, loss = 0.6013513207435608
epoch = 417, loss = 0.6012784838676453
epoch = 418, loss = 0.6012056469917297
epoch = 419, loss = 0.6011328101158142
epoch = 420, loss = 0.6010600328445435
epoch = 421, loss = 0.6009872555732727
epoch = 422, loss = 0.600914478302002
epoch = 423, loss = 0.6008417010307312
epoch = 424, loss = 0.60076

epoch = 684, loss = 0.5822793245315552
epoch = 685, loss = 0.5822097659111023
epoch = 686, loss = 0.5821402668952942
epoch = 687, loss = 0.5820708274841309
epoch = 688, loss = 0.5820013880729675
epoch = 689, loss = 0.5819318890571594
epoch = 690, loss = 0.5818624496459961
epoch = 691, loss = 0.5817930102348328
epoch = 692, loss = 0.5817235708236694
epoch = 693, loss = 0.5816541910171509
epoch = 694, loss = 0.5815847516059875
epoch = 695, loss = 0.581515371799469
epoch = 696, loss = 0.5814460515975952
epoch = 697, loss = 0.5813766717910767
epoch = 698, loss = 0.5813072919845581
epoch = 699, loss = 0.5812379121780396
epoch = 700, loss = 0.5811685919761658
epoch = 701, loss = 0.581099271774292
epoch = 702, loss = 0.5810299515724182
epoch = 703, loss = 0.5809606313705444
epoch = 704, loss = 0.5808913707733154
epoch = 705, loss = 0.5808221101760864
epoch = 706, loss = 0.5807528495788574
epoch = 707, loss = 0.5806835889816284
epoch = 708, loss = 0.5806143283843994
epoch = 709, loss = 0.58054

In [6]:
hour_pred = Var(t.Tensor([[1.0, 0]]))
print(f"Prediction for 1 hour {model.forward(hour_pred)}, closer to 0 is better")

Prediction for 1 hour tensor([[ 0.3641]]), closer to 0 is better


In [7]:
hour_pred = Var(t.Tensor([[3.3, 0.2]]))
print(f"Prediction for 3 hour {model.forward(hour_pred)}, closer to 1 is better")

Prediction for 3 hour tensor([[ 0.6321]]), closer to 1 is better


### Deep

Like the idea of a sigmoid feeding into another sigmoid via interim linear models

Remember, since sigmoids squash to [0,1], the variances or numbers (gradients) can dissappear randomly or appear to since they get really smapp (Vanishing Gradient Problem)

In [8]:
class ModelDeep(t.nn.Module):
    def __init__(self):
        super(ModelDeep, self).__init__()
        self.linear_1 = t.nn.Linear(2, 4) ## nn input sizes
        self.linear_2 = t.nn.Linear(4, 3)
        self.linear_3 = t.nn.Linear(3, 1)
    def forward(self, x_data):
        out_1 = F.sigmoid(self.linear_1(x_data))
        out_2 = F.sigmoid(self.linear_2(out_1))
        return F.sigmoid(self.linear_3(out_2))

In [9]:
model_deep = ModelDeep()
criterion_deep = t.nn.BCELoss(size_average=True)
optimizer_deep = t.optim.SGD(model_deep.parameters(), lr = 0.5)

In [10]:
for epoch in range(1000):
    y_pred = model_deep.forward(x_data)
    loss = criterion_deep(y_pred, y_data)
    print(f'epoch = {epoch}, loss = {loss.item()}') # this deviates from the lecture
    optimizer_deep.zero_grad()
    loss.backward()
    optimizer_deep.step()

epoch = 0, loss = 0.7187903523445129
epoch = 1, loss = 0.7101096510887146
epoch = 2, loss = 0.704725980758667
epoch = 3, loss = 0.7013985514640808
epoch = 4, loss = 0.6993435621261597
epoch = 5, loss = 0.6980718970298767
epoch = 6, loss = 0.6972812414169312
epoch = 7, loss = 0.6967856884002686
epoch = 8, loss = 0.6964707970619202
epoch = 9, loss = 0.6962667107582092
epoch = 10, loss = 0.6961303949356079
epoch = 11, loss = 0.6960357427597046
epoch = 12, loss = 0.6959667801856995
epoch = 13, loss = 0.695913553237915
epoch = 14, loss = 0.695870041847229
epoch = 15, loss = 0.6958326101303101
epoch = 16, loss = 0.6957989931106567
epoch = 17, loss = 0.6957678198814392
epoch = 18, loss = 0.6957381963729858
epoch = 19, loss = 0.695709764957428
epoch = 20, loss = 0.6956818699836731
epoch = 21, loss = 0.6956546902656555
epoch = 22, loss = 0.6956279277801514
epoch = 23, loss = 0.6956015825271606
epoch = 24, loss = 0.695575475692749
epoch = 25, loss = 0.6955496072769165
epoch = 26, loss = 0.695524

epoch = 226, loss = 0.6926935315132141
epoch = 227, loss = 0.692681610584259
epoch = 228, loss = 0.692669689655304
epoch = 229, loss = 0.6926577091217041
epoch = 230, loss = 0.6926457285881042
epoch = 231, loss = 0.6926335692405701
epoch = 232, loss = 0.6926214098930359
epoch = 233, loss = 0.6926091909408569
epoch = 234, loss = 0.692596971988678
epoch = 235, loss = 0.6925845146179199
epoch = 236, loss = 0.6925721168518066
epoch = 237, loss = 0.6925595998764038
epoch = 238, loss = 0.692547082901001
epoch = 239, loss = 0.6925343871116638
epoch = 240, loss = 0.6925216913223267
epoch = 241, loss = 0.6925089955329895
epoch = 242, loss = 0.692496120929718
epoch = 243, loss = 0.6924832463264465
epoch = 244, loss = 0.6924701929092407
epoch = 245, loss = 0.6924571394920349
epoch = 246, loss = 0.6924439072608948
epoch = 247, loss = 0.6924306750297546
epoch = 248, loss = 0.6924173831939697
epoch = 249, loss = 0.6924039125442505
epoch = 250, loss = 0.6923903822898865
epoch = 251, loss = 0.69237691

epoch = 475, loss = 0.6800057291984558
epoch = 476, loss = 0.6798415780067444
epoch = 477, loss = 0.6796753406524658
epoch = 478, loss = 0.6795070171356201
epoch = 479, loss = 0.6793365478515625
epoch = 480, loss = 0.6791638135910034
epoch = 481, loss = 0.6789889335632324
epoch = 482, loss = 0.6788119077682495
epoch = 483, loss = 0.6786324977874756
epoch = 484, loss = 0.6784508228302002
epoch = 485, loss = 0.6782668828964233
epoch = 486, loss = 0.6780804991722107
epoch = 487, loss = 0.6778917908668518
epoch = 488, loss = 0.6777007579803467
epoch = 489, loss = 0.6775071024894714
epoch = 490, loss = 0.6773110032081604
epoch = 491, loss = 0.6771124601364136
epoch = 492, loss = 0.6769114136695862
epoch = 493, loss = 0.6767076849937439
epoch = 494, loss = 0.6765013337135315
epoch = 495, loss = 0.6762923002243042
epoch = 496, loss = 0.6760806441307068
epoch = 497, loss = 0.6758662462234497
epoch = 498, loss = 0.6756490468978882
epoch = 499, loss = 0.6754290461540222
epoch = 500, loss = 0.675

epoch = 758, loss = 0.16230924427509308
epoch = 759, loss = 0.15968969464302063
epoch = 760, loss = 0.15712185204029083
epoch = 761, loss = 0.15460506081581116
epoch = 762, loss = 0.15213854610919952
epoch = 763, loss = 0.14972156286239624
epoch = 764, loss = 0.14735335111618042
epoch = 765, loss = 0.1450331062078476
epoch = 766, loss = 0.1427600383758545
epoch = 767, loss = 0.14053335785865784
epoch = 768, loss = 0.13835221529006958
epoch = 769, loss = 0.13621582090854645
epoch = 770, loss = 0.134123295545578
epoch = 771, loss = 0.13207387924194336
epoch = 772, loss = 0.13006669282913208
epoch = 773, loss = 0.12810087203979492
epoch = 774, loss = 0.1261756867170334
epoch = 775, loss = 0.12429023534059525
epoch = 776, loss = 0.1224437728524208
epoch = 777, loss = 0.12063539028167725
epoch = 778, loss = 0.11886443197727203
epoch = 779, loss = 0.11712991446256638
epoch = 780, loss = 0.11543124169111252
epoch = 781, loss = 0.11376748979091644
epoch = 782, loss = 0.11213796585798264
epoch 

In [11]:
hour_pred = Var(t.Tensor([[1.0, 0]]))
print(f"Prediction for 1 hour {model_deep.forward(hour_pred)}, closer to 0 is better")

Prediction for 1 hour tensor(1.00000e-03 *
       [[ 4.0783]]), closer to 0 is better


In [12]:
hour_pred = Var(t.Tensor([[3.3, 0.2]]))
print(f"Prediction for 3 hour {model_deep.forward(hour_pred)}, closer to 1 is better")

Prediction for 3 hour tensor([[ 0.9765]]), closer to 1 is better


### Activation Functions
https://dashee87.github.io/data%20science/deep%20learning/visualising-activation-functions-in-neural-networks/

### Diabetes Example

In [13]:
import numpy as np

In [88]:
all_data = np.loadtxt('./data/diabetes.csv', delimiter=',', dtype=np.float32)
x_data = Var(t.from_numpy(all_data[:, 0:-1])) # this goes to float32 by default
y_data = Var(t.from_numpy(all_data[:, -1]))

In [89]:
class ModelDiabetes(t.nn.Module):
    def __init__(self):
        super(ModelDiabetes, self).__init__()
        self.l1 = t.nn.Linear(8, 6)
        self.l2 = t.nn.Linear(6, 4)
        self.l3 = t.nn.Linear(4, 1)
        self.sigmoid = t.nn.Sigmoid()
    def forward(self, x):
        out_1 = self.sigmoid(self.l1(x))
        out_2 = self.sigmoid(self.l2(out_1))
        return self.sigmoid(self.l3(out_2))

In [90]:
model_diabetes = ModelDiabetes()
cri_diabetes = t.nn.BCELoss(size_average=True)
opt_diabetes = t.optim.SGD(model_diabetes.parameters(), lr=0.1)

for epoch in range(10000):
    y_pred = model_diabetes.forward(x_data)
    loss = cri_diabetes(y_pred, y_data)
    print(f'epoch = {epoch}, loss = {loss.item()}') # this deviates from the lecture
    opt_diabetes.zero_grad()
    loss.backward()
    opt_diabetes.step()

  "Please ensure they have the same size.".format(target.size(), input.size()))


epoch = 0, loss = 0.6953448057174683
epoch = 1, loss = 0.6909804344177246
epoch = 2, loss = 0.686992883682251
epoch = 3, loss = 0.6833484172821045
epoch = 4, loss = 0.6800180673599243
epoch = 5, loss = 0.6769755482673645
epoch = 6, loss = 0.6741954684257507
epoch = 7, loss = 0.671654999256134
epoch = 8, loss = 0.6693329811096191
epoch = 9, loss = 0.6672117710113525
epoch = 10, loss = 0.66527259349823
epoch = 11, loss = 0.6635007858276367
epoch = 12, loss = 0.6618807315826416
epoch = 13, loss = 0.6604002714157104
epoch = 14, loss = 0.65904700756073
epoch = 15, loss = 0.6578095555305481
epoch = 16, loss = 0.656677782535553
epoch = 17, loss = 0.6556432843208313
epoch = 18, loss = 0.6546969413757324
epoch = 19, loss = 0.6538311243057251
epoch = 20, loss = 0.653039813041687
epoch = 21, loss = 0.6523155570030212
epoch = 22, loss = 0.6516529321670532
epoch = 23, loss = 0.6510458588600159
epoch = 24, loss = 0.6504917740821838
epoch = 25, loss = 0.6499834060668945
epoch = 26, loss = 0.649518430

epoch = 250, loss = 0.6440698504447937
epoch = 251, loss = 0.6440677642822266
epoch = 252, loss = 0.6440660357475281
epoch = 253, loss = 0.6440641283988953
epoch = 254, loss = 0.644061803817749
epoch = 255, loss = 0.6440593600273132
epoch = 256, loss = 0.6440576314926147
epoch = 257, loss = 0.6440560221672058
epoch = 258, loss = 0.6440541744232178
epoch = 259, loss = 0.6440519094467163
epoch = 260, loss = 0.644049346446991
epoch = 261, loss = 0.6440470814704895
epoch = 262, loss = 0.6440452337265015
epoch = 263, loss = 0.6440433859825134
epoch = 264, loss = 0.6440420150756836
epoch = 265, loss = 0.644040048122406
epoch = 266, loss = 0.6440379023551941
epoch = 267, loss = 0.6440356373786926
epoch = 268, loss = 0.6440333724021912
epoch = 269, loss = 0.644031822681427
epoch = 270, loss = 0.6440297961235046
epoch = 271, loss = 0.6440277695655823
epoch = 272, loss = 0.6440257430076599
epoch = 273, loss = 0.6440236568450928
epoch = 274, loss = 0.6440216898918152
epoch = 275, loss = 0.6440193

epoch = 728, loss = 0.642824649810791
epoch = 729, loss = 0.6428210139274597
epoch = 730, loss = 0.6428170204162598
epoch = 731, loss = 0.6428135633468628
epoch = 732, loss = 0.6428101658821106
epoch = 733, loss = 0.6428067684173584
epoch = 734, loss = 0.6428033113479614
epoch = 735, loss = 0.6428002119064331
epoch = 736, loss = 0.6427968740463257
epoch = 737, loss = 0.6427931189537048
epoch = 738, loss = 0.6427889466285706
epoch = 739, loss = 0.6427860260009766
epoch = 740, loss = 0.6427826285362244
epoch = 741, loss = 0.6427789926528931
epoch = 742, loss = 0.6427758932113647
epoch = 743, loss = 0.6427717804908752
epoch = 744, loss = 0.6427682042121887
epoch = 745, loss = 0.6427649855613708
epoch = 746, loss = 0.6427612900733948
epoch = 747, loss = 0.642757773399353
epoch = 748, loss = 0.6427539587020874
epoch = 749, loss = 0.6427509188652039
epoch = 750, loss = 0.64274662733078
epoch = 751, loss = 0.6427433490753174
epoch = 752, loss = 0.6427401900291443
epoch = 753, loss = 0.6427361

epoch = 1073, loss = 0.6412588357925415
epoch = 1074, loss = 0.6412528157234192
epoch = 1075, loss = 0.6412472724914551
epoch = 1076, loss = 0.6412404775619507
epoch = 1077, loss = 0.6412351131439209
epoch = 1078, loss = 0.6412289142608643
epoch = 1079, loss = 0.6412227749824524
epoch = 1080, loss = 0.6412173509597778
epoch = 1081, loss = 0.6412110924720764
epoch = 1082, loss = 0.6412050127983093
epoch = 1083, loss = 0.6411992907524109
epoch = 1084, loss = 0.641193151473999
epoch = 1085, loss = 0.6411871910095215
epoch = 1086, loss = 0.6411814093589783
epoch = 1087, loss = 0.6411749720573425
epoch = 1088, loss = 0.6411691904067993
epoch = 1089, loss = 0.6411629915237427
epoch = 1090, loss = 0.6411568522453308
epoch = 1091, loss = 0.641150951385498
epoch = 1092, loss = 0.6411446332931519
epoch = 1093, loss = 0.6411386132240295
epoch = 1094, loss = 0.6411322355270386
epoch = 1095, loss = 0.6411263346672058
epoch = 1096, loss = 0.6411200761795044
epoch = 1097, loss = 0.6411139369010925
ep

epoch = 1301, loss = 0.6396090984344482
epoch = 1302, loss = 0.6396000385284424
epoch = 1303, loss = 0.6395912766456604
epoch = 1304, loss = 0.6395822167396545
epoch = 1305, loss = 0.6395733952522278
epoch = 1306, loss = 0.6395646929740906
epoch = 1307, loss = 0.6395552158355713
epoch = 1308, loss = 0.6395471692085266
epoch = 1309, loss = 0.6395375728607178
epoch = 1310, loss = 0.639528751373291
epoch = 1311, loss = 0.6395193934440613
epoch = 1312, loss = 0.6395107507705688
epoch = 1313, loss = 0.639501690864563
epoch = 1314, loss = 0.6394929885864258
epoch = 1315, loss = 0.6394837498664856
epoch = 1316, loss = 0.6394748687744141
epoch = 1317, loss = 0.6394657492637634
epoch = 1318, loss = 0.6394559741020203
epoch = 1319, loss = 0.6394472122192383
epoch = 1320, loss = 0.6394380331039429
epoch = 1321, loss = 0.6394286155700684
epoch = 1322, loss = 0.6394198536872864
epoch = 1323, loss = 0.6394105553627014
epoch = 1324, loss = 0.6394014358520508
epoch = 1325, loss = 0.6393921375274658
ep

epoch = 1550, loss = 0.6367819309234619
epoch = 1551, loss = 0.6367677450180054
epoch = 1552, loss = 0.6367529630661011
epoch = 1553, loss = 0.6367389559745789
epoch = 1554, loss = 0.6367238759994507
epoch = 1555, loss = 0.6367095708847046
epoch = 1556, loss = 0.6366952061653137
epoch = 1557, loss = 0.6366803050041199
epoch = 1558, loss = 0.6366656422615051
epoch = 1559, loss = 0.6366509199142456
epoch = 1560, loss = 0.6366361379623413
epoch = 1561, loss = 0.636621356010437
epoch = 1562, loss = 0.6366068720817566
epoch = 1563, loss = 0.6365919709205627
epoch = 1564, loss = 0.6365769505500793
epoch = 1565, loss = 0.636562168598175
epoch = 1566, loss = 0.6365470886230469
epoch = 1567, loss = 0.6365323662757874
epoch = 1568, loss = 0.6365170478820801
epoch = 1569, loss = 0.636502206325531
epoch = 1570, loss = 0.6364868879318237
epoch = 1571, loss = 0.6364719867706299
epoch = 1572, loss = 0.6364570260047913
epoch = 1573, loss = 0.6364419460296631
epoch = 1574, loss = 0.6364266276359558
epo

epoch = 1983, loss = 0.62624591588974
epoch = 1984, loss = 0.6262066960334778
epoch = 1985, loss = 0.626167893409729
epoch = 1986, loss = 0.6261286735534668
epoch = 1987, loss = 0.6260894536972046
epoch = 1988, loss = 0.6260504126548767
epoch = 1989, loss = 0.6260107159614563
epoch = 1990, loss = 0.6259715557098389
epoch = 1991, loss = 0.6259319186210632
epoch = 1992, loss = 0.625892698764801
epoch = 1993, loss = 0.6258525848388672
epoch = 1994, loss = 0.6258127689361572
epoch = 1995, loss = 0.6257733702659607
epoch = 1996, loss = 0.6257331371307373
epoch = 1997, loss = 0.6256927251815796
epoch = 1998, loss = 0.6256526708602905
epoch = 1999, loss = 0.6256126165390015
epoch = 2000, loss = 0.6255720257759094
epoch = 2001, loss = 0.6255317330360413
epoch = 2002, loss = 0.6254908442497253
epoch = 2003, loss = 0.6254494786262512
epoch = 2004, loss = 0.6254093647003174
epoch = 2005, loss = 0.6253687143325806
epoch = 2006, loss = 0.625327467918396
epoch = 2007, loss = 0.6252857446670532
epoch

epoch = 2276, loss = 0.6095902919769287
epoch = 2277, loss = 0.6095113158226013
epoch = 2278, loss = 0.6094306707382202
epoch = 2279, loss = 0.6093516945838928
epoch = 2280, loss = 0.6092715263366699
epoch = 2281, loss = 0.6091910600662231
epoch = 2282, loss = 0.6091107130050659
epoch = 2283, loss = 0.6090295910835266
epoch = 2284, loss = 0.6089494228363037
epoch = 2285, loss = 0.6088681221008301
epoch = 2286, loss = 0.6087870597839355
epoch = 2287, loss = 0.6087055206298828
epoch = 2288, loss = 0.6086238622665405
epoch = 2289, loss = 0.6085425019264221
epoch = 2290, loss = 0.608460545539856
epoch = 2291, loss = 0.6083782315254211
epoch = 2292, loss = 0.6082960963249207
epoch = 2293, loss = 0.6082136631011963
epoch = 2294, loss = 0.6081307530403137
epoch = 2295, loss = 0.6080474257469177
epoch = 2296, loss = 0.6079649329185486
epoch = 2297, loss = 0.607881486415863
epoch = 2298, loss = 0.6077978014945984
epoch = 2299, loss = 0.6077145338058472
epoch = 2300, loss = 0.6076302528381348
ep

epoch = 2562, loss = 0.5785487294197083
epoch = 2563, loss = 0.578410804271698
epoch = 2564, loss = 0.5782734155654907
epoch = 2565, loss = 0.5781355500221252
epoch = 2566, loss = 0.5779969096183777
epoch = 2567, loss = 0.5778588652610779
epoch = 2568, loss = 0.5777204036712646
epoch = 2569, loss = 0.577582061290741
epoch = 2570, loss = 0.5774428248405457
epoch = 2571, loss = 0.5773041844367981
epoch = 2572, loss = 0.5771650075912476
epoch = 2573, loss = 0.5770263075828552
epoch = 2574, loss = 0.5768863558769226
epoch = 2575, loss = 0.5767462849617004
epoch = 2576, loss = 0.5766074657440186
epoch = 2577, loss = 0.5764670968055725
epoch = 2578, loss = 0.5763269662857056
epoch = 2579, loss = 0.5761865973472595
epoch = 2580, loss = 0.5760460495948792
epoch = 2581, loss = 0.5759055614471436
epoch = 2582, loss = 0.5757648348808289
epoch = 2583, loss = 0.5756239891052246
epoch = 2584, loss = 0.5754827857017517
epoch = 2585, loss = 0.5753414034843445
epoch = 2586, loss = 0.5751996636390686
ep

epoch = 2836, loss = 0.5373849868774414
epoch = 2837, loss = 0.5372353196144104
epoch = 2838, loss = 0.5370859503746033
epoch = 2839, loss = 0.5369371175765991
epoch = 2840, loss = 0.5367876887321472
epoch = 2841, loss = 0.5366390943527222
epoch = 2842, loss = 0.5364903211593628
epoch = 2843, loss = 0.536341667175293
epoch = 2844, loss = 0.5361931920051575
epoch = 2845, loss = 0.5360447764396667
epoch = 2846, loss = 0.5358958840370178
epoch = 2847, loss = 0.5357478260993958
epoch = 2848, loss = 0.5355998873710632
epoch = 2849, loss = 0.5354517698287964
epoch = 2850, loss = 0.5353043079376221
epoch = 2851, loss = 0.5351564288139343
epoch = 2852, loss = 0.5350090861320496
epoch = 2853, loss = 0.5348612070083618
epoch = 2854, loss = 0.534713864326477
epoch = 2855, loss = 0.5345665216445923
epoch = 2856, loss = 0.5344195365905762
epoch = 2857, loss = 0.5342726707458496
epoch = 2858, loss = 0.5341259837150574
epoch = 2859, loss = 0.5339796543121338
epoch = 2860, loss = 0.533832848072052
epo

epoch = 3080, loss = 0.5060200691223145
epoch = 3081, loss = 0.50591641664505
epoch = 3082, loss = 0.5058139562606812
epoch = 3083, loss = 0.5057114362716675
epoch = 3084, loss = 0.5056092143058777
epoch = 3085, loss = 0.5055068731307983
epoch = 3086, loss = 0.5054052472114563
epoch = 3087, loss = 0.5053030848503113
epoch = 3088, loss = 0.505202054977417
epoch = 3089, loss = 0.5051005482673645
epoch = 3090, loss = 0.5049999356269836
epoch = 3091, loss = 0.5048990249633789
epoch = 3092, loss = 0.5047981142997742
epoch = 3093, loss = 0.5046980381011963
epoch = 3094, loss = 0.5045979619026184
epoch = 3095, loss = 0.5044981241226196
epoch = 3096, loss = 0.5043983459472656
epoch = 3097, loss = 0.5042992234230042
epoch = 3098, loss = 0.5042001605033875
epoch = 3099, loss = 0.5041009187698364
epoch = 3100, loss = 0.504001796245575
epoch = 3101, loss = 0.5039036870002747
epoch = 3102, loss = 0.5038049221038818
epoch = 3103, loss = 0.503707230091095
epoch = 3104, loss = 0.503609299659729
epoch 

epoch = 3360, loss = 0.4851088225841522
epoch = 3361, loss = 0.48505890369415283
epoch = 3362, loss = 0.485008180141449
epoch = 3363, loss = 0.48495855927467346
epoch = 3364, loss = 0.48490896821022034
epoch = 3365, loss = 0.4848594069480896
epoch = 3366, loss = 0.4848093092441559
epoch = 3367, loss = 0.4847605526447296
epoch = 3368, loss = 0.48471128940582275
epoch = 3369, loss = 0.4846620559692383
epoch = 3370, loss = 0.48461323976516724
epoch = 3371, loss = 0.4845640957355499
epoch = 3372, loss = 0.4845157265663147
epoch = 3373, loss = 0.4844670295715332
epoch = 3374, loss = 0.48441851139068604
epoch = 3375, loss = 0.48437026143074036
epoch = 3376, loss = 0.4843221604824066
epoch = 3377, loss = 0.4842745065689087
epoch = 3378, loss = 0.48422712087631226
epoch = 3379, loss = 0.48417899012565613
epoch = 3380, loss = 0.48413151502609253
epoch = 3381, loss = 0.48408395051956177
epoch = 3382, loss = 0.48403650522232056
epoch = 3383, loss = 0.4839894473552704
epoch = 3384, loss = 0.483942

epoch = 3642, loss = 0.47540390491485596
epoch = 3643, loss = 0.4753822088241577
epoch = 3644, loss = 0.47536009550094604
epoch = 3645, loss = 0.4753381311893463
epoch = 3646, loss = 0.4753156304359436
epoch = 3647, loss = 0.47529399394989014
epoch = 3648, loss = 0.4752717912197113
epoch = 3649, loss = 0.4752500057220459
epoch = 3650, loss = 0.47522827982902527
epoch = 3651, loss = 0.4752061665058136
epoch = 3652, loss = 0.47518467903137207
epoch = 3653, loss = 0.47516319155693054
epoch = 3654, loss = 0.47514232993125916
epoch = 3655, loss = 0.4751204252243042
epoch = 3656, loss = 0.47509923577308655
epoch = 3657, loss = 0.4750778079032898
epoch = 3658, loss = 0.4750566780567169
epoch = 3659, loss = 0.4750348627567291
epoch = 3660, loss = 0.475013792514801
epoch = 3661, loss = 0.47499266266822815
epoch = 3662, loss = 0.474972128868103
epoch = 3663, loss = 0.47495102882385254
epoch = 3664, loss = 0.47493013739585876
epoch = 3665, loss = 0.4749090373516083
epoch = 3666, loss = 0.47488832

epoch = 3932, loss = 0.47096431255340576
epoch = 3933, loss = 0.47095373272895813
epoch = 3934, loss = 0.47094401717185974
epoch = 3935, loss = 0.47093382477760315
epoch = 3936, loss = 0.470923513174057
epoch = 3937, loss = 0.4709135591983795
epoch = 3938, loss = 0.4709032475948334
epoch = 3939, loss = 0.47089269757270813
epoch = 3940, loss = 0.4708830416202545
epoch = 3941, loss = 0.47087255120277405
epoch = 3942, loss = 0.4708627164363861
epoch = 3943, loss = 0.4708525240421295
epoch = 3944, loss = 0.4708424508571625
epoch = 3945, loss = 0.4708329439163208
epoch = 3946, loss = 0.47082310914993286
epoch = 3947, loss = 0.47081321477890015
epoch = 3948, loss = 0.4708031713962555
epoch = 3949, loss = 0.4707934260368347
epoch = 3950, loss = 0.47078341245651245
epoch = 3951, loss = 0.47077345848083496
epoch = 3952, loss = 0.47076350450515747
epoch = 3953, loss = 0.4707536995410919
epoch = 3954, loss = 0.4707443118095398
epoch = 3955, loss = 0.47073492407798767
epoch = 3956, loss = 0.470724

epoch = 4171, loss = 0.46905025839805603
epoch = 4172, loss = 0.4690438508987427
epoch = 4173, loss = 0.46903756260871887
epoch = 4174, loss = 0.46903157234191895
epoch = 4175, loss = 0.46902498602867126
epoch = 4176, loss = 0.4690185785293579
epoch = 4177, loss = 0.4690125286579132
epoch = 4178, loss = 0.4690062403678894
epoch = 4179, loss = 0.46900027990341187
epoch = 4180, loss = 0.4689939320087433
epoch = 4181, loss = 0.4689878225326538
epoch = 4182, loss = 0.4689820408821106
epoch = 4183, loss = 0.4689760208129883
epoch = 4184, loss = 0.46896934509277344
epoch = 4185, loss = 0.46896323561668396
epoch = 4186, loss = 0.4689573049545288
epoch = 4187, loss = 0.4689508378505707
epoch = 4188, loss = 0.4689449369907379
epoch = 4189, loss = 0.46893855929374695
epoch = 4190, loss = 0.4689326286315918
epoch = 4191, loss = 0.46892669796943665
epoch = 4192, loss = 0.4689209759235382
epoch = 4193, loss = 0.46891453862190247
epoch = 4194, loss = 0.4689083397388458
epoch = 4195, loss = 0.4689026

epoch = 4393, loss = 0.4678715765476227
epoch = 4394, loss = 0.46786683797836304
epoch = 4395, loss = 0.46786269545555115
epoch = 4396, loss = 0.467858225107193
epoch = 4397, loss = 0.4678536057472229
epoch = 4398, loss = 0.4678495228290558
epoch = 4399, loss = 0.46784472465515137
epoch = 4400, loss = 0.46784043312072754
epoch = 4401, loss = 0.4678356647491455
epoch = 4402, loss = 0.4678310453891754
epoch = 4403, loss = 0.4678265452384949
epoch = 4404, loss = 0.46782246232032776
epoch = 4405, loss = 0.4678175151348114
epoch = 4406, loss = 0.4678133726119995
epoch = 4407, loss = 0.4678087830543518
epoch = 4408, loss = 0.4678044021129608
epoch = 4409, loss = 0.46779966354370117
epoch = 4410, loss = 0.4677952229976654
epoch = 4411, loss = 0.46779105067253113
epoch = 4412, loss = 0.4677870571613312
epoch = 4413, loss = 0.4677821099758148
epoch = 4414, loss = 0.46777769923210144
epoch = 4415, loss = 0.46777310967445374
epoch = 4416, loss = 0.4677690267562866
epoch = 4417, loss = 0.467764496

epoch = 4640, loss = 0.46689626574516296
epoch = 4641, loss = 0.46689310669898987
epoch = 4642, loss = 0.4668892025947571
epoch = 4643, loss = 0.4668855667114258
epoch = 4644, loss = 0.4668824374675751
epoch = 4645, loss = 0.4668787121772766
epoch = 4646, loss = 0.4668753147125244
epoch = 4647, loss = 0.4668723940849304
epoch = 4648, loss = 0.4668682813644409
epoch = 4649, loss = 0.4668653607368469
epoch = 4650, loss = 0.46686193346977234
epoch = 4651, loss = 0.4668585956096649
epoch = 4652, loss = 0.46685460209846497
epoch = 4653, loss = 0.4668515622615814
epoch = 4654, loss = 0.4668479859828949
epoch = 4655, loss = 0.4668443500995636
epoch = 4656, loss = 0.466841459274292
epoch = 4657, loss = 0.46683767437934875
epoch = 4658, loss = 0.4668343961238861
epoch = 4659, loss = 0.46683061122894287
epoch = 4660, loss = 0.4668278396129608
epoch = 4661, loss = 0.46682414412498474
epoch = 4662, loss = 0.466820627450943
epoch = 4663, loss = 0.466817706823349
epoch = 4664, loss = 0.4668136835098

epoch = 4945, loss = 0.46596866846084595
epoch = 4946, loss = 0.4659656584262848
epoch = 4947, loss = 0.4659634232521057
epoch = 4948, loss = 0.4659608006477356
epoch = 4949, loss = 0.4659578204154968
epoch = 4950, loss = 0.46595504879951477
epoch = 4951, loss = 0.4659525454044342
epoch = 4952, loss = 0.4659501016139984
epoch = 4953, loss = 0.46594730019569397
epoch = 4954, loss = 0.4659445583820343
epoch = 4955, loss = 0.4659423530101776
epoch = 4956, loss = 0.46593937277793884
epoch = 4957, loss = 0.4659365117549896
epoch = 4958, loss = 0.4659341871738434
epoch = 4959, loss = 0.46593159437179565
epoch = 4960, loss = 0.46592867374420166
epoch = 4961, loss = 0.4659265875816345
epoch = 4962, loss = 0.46592339873313904
epoch = 4963, loss = 0.4659208655357361
epoch = 4964, loss = 0.46591809391975403
epoch = 4965, loss = 0.46591535210609436
epoch = 4966, loss = 0.46591314673423767
epoch = 4967, loss = 0.4659103751182556
epoch = 4968, loss = 0.46590733528137207
epoch = 4969, loss = 0.465905

epoch = 5227, loss = 0.4652903378009796
epoch = 5228, loss = 0.4652881622314453
epoch = 5229, loss = 0.46528592705726624
epoch = 5230, loss = 0.4652836322784424
epoch = 5231, loss = 0.4652819335460663
epoch = 5232, loss = 0.46527984738349915
epoch = 5233, loss = 0.4652778208255768
epoch = 5234, loss = 0.46527522802352905
epoch = 5235, loss = 0.46527305245399475
epoch = 5236, loss = 0.4652707576751709
epoch = 5237, loss = 0.4652688205242157
epoch = 5238, loss = 0.4652665853500366
epoch = 5239, loss = 0.46526458859443665
epoch = 5240, loss = 0.4652624726295471
epoch = 5241, loss = 0.4652602970600128
epoch = 5242, loss = 0.4652579426765442
epoch = 5243, loss = 0.4652562439441681
epoch = 5244, loss = 0.46525415778160095
epoch = 5245, loss = 0.46525174379348755
epoch = 5246, loss = 0.46524953842163086
epoch = 5247, loss = 0.46524766087532043
epoch = 5248, loss = 0.46524566411972046
epoch = 5249, loss = 0.4652428925037384
epoch = 5250, loss = 0.4652409255504608
epoch = 5251, loss = 0.4652392

epoch = 5496, loss = 0.46476131677627563
epoch = 5497, loss = 0.46475937962532043
epoch = 5498, loss = 0.464758038520813
epoch = 5499, loss = 0.46475574374198914
epoch = 5500, loss = 0.4647538959980011
epoch = 5501, loss = 0.4647522270679474
epoch = 5502, loss = 0.4647507071495056
epoch = 5503, loss = 0.4647482931613922
epoch = 5504, loss = 0.46474698185920715
epoch = 5505, loss = 0.46474480628967285
epoch = 5506, loss = 0.464743435382843
epoch = 5507, loss = 0.46474161744117737
epoch = 5508, loss = 0.4647395610809326
epoch = 5509, loss = 0.46473783254623413
epoch = 5510, loss = 0.4647365212440491
epoch = 5511, loss = 0.46473443508148193
epoch = 5512, loss = 0.46473273634910583
epoch = 5513, loss = 0.46473103761672974
epoch = 5514, loss = 0.46472927927970886
epoch = 5515, loss = 0.4647274613380432
epoch = 5516, loss = 0.464725136756897
epoch = 5517, loss = 0.4647234380245209
epoch = 5518, loss = 0.4647219181060791
epoch = 5519, loss = 0.4647206664085388
epoch = 5520, loss = 0.464718610

epoch = 5781, loss = 0.46429648995399475
epoch = 5782, loss = 0.4642953872680664
epoch = 5783, loss = 0.4642939865589142
epoch = 5784, loss = 0.4642925560474396
epoch = 5785, loss = 0.4642902612686157
epoch = 5786, loss = 0.46428945660591125
epoch = 5787, loss = 0.46428796648979187
epoch = 5788, loss = 0.4642866253852844
epoch = 5789, loss = 0.4642845690250397
epoch = 5790, loss = 0.4642830491065979
epoch = 5791, loss = 0.46428200602531433
epoch = 5792, loss = 0.4642804265022278
epoch = 5793, loss = 0.4642789661884308
epoch = 5794, loss = 0.4642776548862457
epoch = 5795, loss = 0.46427565813064575
epoch = 5796, loss = 0.46427440643310547
epoch = 5797, loss = 0.46427303552627563
epoch = 5798, loss = 0.4642718732357025
epoch = 5799, loss = 0.4642699062824249
epoch = 5800, loss = 0.46426889300346375
epoch = 5801, loss = 0.46426740288734436
epoch = 5802, loss = 0.4642655551433563
epoch = 5803, loss = 0.464264452457428
epoch = 5804, loss = 0.46426302194595337
epoch = 5805, loss = 0.46426126

epoch = 6071, loss = 0.46390268206596375
epoch = 6072, loss = 0.4639012813568115
epoch = 6073, loss = 0.46390023827552795
epoch = 6074, loss = 0.46389907598495483
epoch = 6075, loss = 0.46389755606651306
epoch = 6076, loss = 0.4638965129852295
epoch = 6077, loss = 0.46389490365982056
epoch = 6078, loss = 0.46389392018318176
epoch = 6079, loss = 0.46389278769493103
epoch = 6080, loss = 0.4638914465904236
epoch = 6081, loss = 0.4638902246952057
epoch = 6082, loss = 0.463888943195343
epoch = 6083, loss = 0.46388769149780273
epoch = 6084, loss = 0.46388620138168335
epoch = 6085, loss = 0.46388480067253113
epoch = 6086, loss = 0.4638842046260834
epoch = 6087, loss = 0.4638829827308655
epoch = 6088, loss = 0.4638814330101013
epoch = 6089, loss = 0.4638800621032715
epoch = 6090, loss = 0.4638792872428894
epoch = 6091, loss = 0.4638775885105133
epoch = 6092, loss = 0.46387648582458496
epoch = 6093, loss = 0.4638751745223999
epoch = 6094, loss = 0.46387410163879395
epoch = 6095, loss = 0.463872

epoch = 6353, loss = 0.46357718110084534
epoch = 6354, loss = 0.4635760188102722
epoch = 6355, loss = 0.46357446908950806
epoch = 6356, loss = 0.46357396245002747
epoch = 6357, loss = 0.46357282996177673
epoch = 6358, loss = 0.463571697473526
epoch = 6359, loss = 0.46357032656669617
epoch = 6360, loss = 0.4635693430900574
epoch = 6361, loss = 0.4635685980319977
epoch = 6362, loss = 0.46356743574142456
epoch = 6363, loss = 0.46356648206710815
epoch = 6364, loss = 0.46356555819511414
epoch = 6365, loss = 0.46356433629989624
epoch = 6366, loss = 0.463563472032547
epoch = 6367, loss = 0.4635622799396515
epoch = 6368, loss = 0.4635610580444336
epoch = 6369, loss = 0.4635603725910187
epoch = 6370, loss = 0.46355926990509033
epoch = 6371, loss = 0.4635579586029053
epoch = 6372, loss = 0.4635563790798187
epoch = 6373, loss = 0.46355557441711426
epoch = 6374, loss = 0.46355488896369934
epoch = 6375, loss = 0.4635537266731262
epoch = 6376, loss = 0.46355241537094116
epoch = 6377, loss = 0.463551

epoch = 6641, loss = 0.46328893303871155
epoch = 6642, loss = 0.4632880687713623
epoch = 6643, loss = 0.463287353515625
epoch = 6644, loss = 0.46328628063201904
epoch = 6645, loss = 0.4632854759693146
epoch = 6646, loss = 0.463283896446228
epoch = 6647, loss = 0.46328309178352356
epoch = 6648, loss = 0.4632827639579773
epoch = 6649, loss = 0.46328145265579224
epoch = 6650, loss = 0.46328040957450867
epoch = 6651, loss = 0.4632793664932251
epoch = 6652, loss = 0.4632784128189087
epoch = 6653, loss = 0.4632776975631714
epoch = 6654, loss = 0.46327725052833557
epoch = 6655, loss = 0.46327629685401917
epoch = 6656, loss = 0.4632752239704132
epoch = 6657, loss = 0.4632739722728729
epoch = 6658, loss = 0.46327292919158936
epoch = 6659, loss = 0.46327176690101624
epoch = 6660, loss = 0.46327143907546997
epoch = 6661, loss = 0.4632700979709625
epoch = 6662, loss = 0.46326929330825806
epoch = 6663, loss = 0.4632686674594879
epoch = 6664, loss = 0.46326732635498047
epoch = 6665, loss = 0.4632667

epoch = 6931, loss = 0.46303173899650574
epoch = 6932, loss = 0.463031142950058
epoch = 6933, loss = 0.4630303978919983
epoch = 6934, loss = 0.4630295932292938
epoch = 6935, loss = 0.46302878856658936
epoch = 6936, loss = 0.46302786469459534
epoch = 6937, loss = 0.46302729845046997
epoch = 6938, loss = 0.46302616596221924
epoch = 6939, loss = 0.46302539110183716
epoch = 6940, loss = 0.4630245864391327
epoch = 6941, loss = 0.4630233943462372
epoch = 6942, loss = 0.4630226194858551
epoch = 6943, loss = 0.4630218744277954
epoch = 6944, loss = 0.4630208909511566
epoch = 6945, loss = 0.4630206823348999
epoch = 6946, loss = 0.46301963925361633
epoch = 6947, loss = 0.46301841735839844
epoch = 6948, loss = 0.46301764249801636
epoch = 6949, loss = 0.4630168676376343
epoch = 6950, loss = 0.46301591396331787
epoch = 6951, loss = 0.46301549673080444
epoch = 6952, loss = 0.4630140960216522
epoch = 6953, loss = 0.4630136489868164
epoch = 6954, loss = 0.46301236748695374
epoch = 6955, loss = 0.463011

epoch = 7217, loss = 0.46280187368392944
epoch = 7218, loss = 0.46280089020729065
epoch = 7219, loss = 0.4628003239631653
epoch = 7220, loss = 0.4627997577190399
epoch = 7221, loss = 0.46279898285865784
epoch = 7222, loss = 0.4627978205680847
epoch = 7223, loss = 0.4627971947193146
epoch = 7224, loss = 0.4627964198589325
epoch = 7225, loss = 0.4627952575683594
epoch = 7226, loss = 0.4627948999404907
epoch = 7227, loss = 0.46279391646385193
epoch = 7228, loss = 0.4627934396266937
epoch = 7229, loss = 0.4627927243709564
epoch = 7230, loss = 0.4627917408943176
epoch = 7231, loss = 0.46279123425483704
epoch = 7232, loss = 0.4627903401851654
epoch = 7233, loss = 0.46278977394104004
epoch = 7234, loss = 0.46278879046440125
epoch = 7235, loss = 0.4627876877784729
epoch = 7236, loss = 0.46278688311576843
epoch = 7237, loss = 0.46278631687164307
epoch = 7238, loss = 0.462785929441452
epoch = 7239, loss = 0.46278488636016846
epoch = 7240, loss = 0.46278443932533264
epoch = 7241, loss = 0.4627831

epoch = 7504, loss = 0.46258705854415894
epoch = 7505, loss = 0.46258604526519775
epoch = 7506, loss = 0.46258533000946045
epoch = 7507, loss = 0.4625844955444336
epoch = 7508, loss = 0.4625839293003082
epoch = 7509, loss = 0.462583065032959
epoch = 7510, loss = 0.46258223056793213
epoch = 7511, loss = 0.4625815451145172
epoch = 7512, loss = 0.46258077025413513
epoch = 7513, loss = 0.4625799357891083
epoch = 7514, loss = 0.4625793695449829
epoch = 7515, loss = 0.46257880330085754
epoch = 7516, loss = 0.46257805824279785
epoch = 7517, loss = 0.46257707476615906
epoch = 7518, loss = 0.4625764787197113
epoch = 7519, loss = 0.4625760018825531
epoch = 7520, loss = 0.4625752866268158
epoch = 7521, loss = 0.46257448196411133
epoch = 7522, loss = 0.462573379278183
epoch = 7523, loss = 0.4625728726387024
epoch = 7524, loss = 0.4625723361968994
epoch = 7525, loss = 0.46257129311561584
epoch = 7526, loss = 0.4625707268714905
epoch = 7527, loss = 0.4625701308250427
epoch = 7528, loss = 0.462568968

epoch = 7778, loss = 0.4623905122280121
epoch = 7779, loss = 0.4623894691467285
epoch = 7780, loss = 0.4623892903327942
epoch = 7781, loss = 0.4623883366584778
epoch = 7782, loss = 0.46238741278648376
epoch = 7783, loss = 0.46238717436790466
epoch = 7784, loss = 0.4623860716819763
epoch = 7785, loss = 0.46238577365875244
epoch = 7786, loss = 0.4623848795890808
epoch = 7787, loss = 0.4623843729496002
epoch = 7788, loss = 0.46238359808921814
epoch = 7789, loss = 0.4623827636241913
epoch = 7790, loss = 0.4623818099498749
epoch = 7791, loss = 0.46238142251968384
epoch = 7792, loss = 0.4623804986476898
epoch = 7793, loss = 0.4623800814151764
epoch = 7794, loss = 0.462379515171051
epoch = 7795, loss = 0.4623788297176361
epoch = 7796, loss = 0.4623778164386749
epoch = 7797, loss = 0.4623771011829376
epoch = 7798, loss = 0.46237650513648987
epoch = 7799, loss = 0.46237605810165405
epoch = 7800, loss = 0.462375283241272
epoch = 7801, loss = 0.46237438917160034
epoch = 7802, loss = 0.46237331628

epoch = 8095, loss = 0.46216917037963867
epoch = 8096, loss = 0.4621690511703491
epoch = 8097, loss = 0.4621686041355133
epoch = 8098, loss = 0.46216773986816406
epoch = 8099, loss = 0.46216684579849243
epoch = 8100, loss = 0.46216607093811035
epoch = 8101, loss = 0.46216511726379395
epoch = 8102, loss = 0.46216461062431335
epoch = 8103, loss = 0.4621639847755432
epoch = 8104, loss = 0.4621632993221283
epoch = 8105, loss = 0.4621623456478119
epoch = 8106, loss = 0.4621616303920746
epoch = 8107, loss = 0.4621610641479492
epoch = 8108, loss = 0.4621601402759552
epoch = 8109, loss = 0.4621596038341522
epoch = 8110, loss = 0.46215948462486267
epoch = 8111, loss = 0.46215856075286865
epoch = 8112, loss = 0.46215736865997314
epoch = 8113, loss = 0.4621571898460388
epoch = 8114, loss = 0.4621562659740448
epoch = 8115, loss = 0.4621557593345642
epoch = 8116, loss = 0.4621550142765045
epoch = 8117, loss = 0.4621542990207672
epoch = 8118, loss = 0.46215346455574036
epoch = 8119, loss = 0.4621529

epoch = 8430, loss = 0.4619366526603699
epoch = 8431, loss = 0.4619363844394684
epoch = 8432, loss = 0.4619356393814087
epoch = 8433, loss = 0.4619351029396057
epoch = 8434, loss = 0.4619341790676117
epoch = 8435, loss = 0.4619336724281311
epoch = 8436, loss = 0.46193283796310425
epoch = 8437, loss = 0.46193206310272217
epoch = 8438, loss = 0.4619317054748535
epoch = 8439, loss = 0.4619307518005371
epoch = 8440, loss = 0.46193063259124756
epoch = 8441, loss = 0.4619297981262207
epoch = 8442, loss = 0.46192893385887146
epoch = 8443, loss = 0.46192800998687744
epoch = 8444, loss = 0.461927205324173
epoch = 8445, loss = 0.4619268774986267
epoch = 8446, loss = 0.461926132440567
epoch = 8447, loss = 0.46192556619644165
epoch = 8448, loss = 0.46192502975463867
epoch = 8449, loss = 0.4619242250919342
epoch = 8450, loss = 0.4619233012199402
epoch = 8451, loss = 0.46192285418510437
epoch = 8452, loss = 0.46192190051078796
epoch = 8453, loss = 0.46192121505737305
epoch = 8454, loss = 0.461920410

epoch = 8709, loss = 0.4617416262626648
epoch = 8710, loss = 0.4617406725883484
epoch = 8711, loss = 0.4617401659488678
epoch = 8712, loss = 0.46173983812332153
epoch = 8713, loss = 0.4617390036582947
epoch = 8714, loss = 0.4617382884025574
epoch = 8715, loss = 0.46173760294914246
epoch = 8716, loss = 0.46173685789108276
epoch = 8717, loss = 0.46173593401908875
epoch = 8718, loss = 0.46173498034477234
epoch = 8719, loss = 0.46173474192619324
epoch = 8720, loss = 0.4617340862751007
epoch = 8721, loss = 0.46173325181007385
epoch = 8722, loss = 0.4617319703102112
epoch = 8723, loss = 0.46173134446144104
epoch = 8724, loss = 0.46173083782196045
epoch = 8725, loss = 0.46173036098480225
epoch = 8726, loss = 0.46172961592674255
epoch = 8727, loss = 0.4617289900779724
epoch = 8728, loss = 0.4617283046245575
epoch = 8729, loss = 0.46172741055488586
epoch = 8730, loss = 0.4617269039154053
epoch = 8731, loss = 0.46172598004341125
epoch = 8732, loss = 0.4617252051830292
epoch = 8733, loss = 0.4617

epoch = 8993, loss = 0.4615379571914673
epoch = 8994, loss = 0.4615374803543091
epoch = 8995, loss = 0.46153631806373596
epoch = 8996, loss = 0.4615359604358673
epoch = 8997, loss = 0.46153515577316284
epoch = 8998, loss = 0.46153438091278076
epoch = 8999, loss = 0.46153393387794495
epoch = 9000, loss = 0.46153274178504944
epoch = 9001, loss = 0.46153196692466736
epoch = 9002, loss = 0.46153151988983154
epoch = 9003, loss = 0.4615309238433838
epoch = 9004, loss = 0.46153026819229126
epoch = 9005, loss = 0.46152928471565247
epoch = 9006, loss = 0.46152883768081665
epoch = 9007, loss = 0.4615280330181122
epoch = 9008, loss = 0.4615272283554077
epoch = 9009, loss = 0.46152639389038086
epoch = 9010, loss = 0.46152573823928833
epoch = 9011, loss = 0.4615249037742615
epoch = 9012, loss = 0.46152421832084656
epoch = 9013, loss = 0.46152356266975403
epoch = 9014, loss = 0.4615229368209839
epoch = 9015, loss = 0.46152228116989136
epoch = 9016, loss = 0.46152159571647644
epoch = 9017, loss = 0.4

epoch = 9290, loss = 0.46131834387779236
epoch = 9291, loss = 0.461317777633667
epoch = 9292, loss = 0.461316853761673
epoch = 9293, loss = 0.4613165855407715
epoch = 9294, loss = 0.46131566166877747
epoch = 9295, loss = 0.46131473779678345
epoch = 9296, loss = 0.4613136351108551
epoch = 9297, loss = 0.46131306886672974
epoch = 9298, loss = 0.4613122344017029
epoch = 9299, loss = 0.4613114297389984
epoch = 9300, loss = 0.4613107442855835
epoch = 9301, loss = 0.46131011843681335
epoch = 9302, loss = 0.46130967140197754
epoch = 9303, loss = 0.4613085389137268
epoch = 9304, loss = 0.46130794286727905
epoch = 9305, loss = 0.4613072872161865
epoch = 9306, loss = 0.46130627393722534
epoch = 9307, loss = 0.4613056778907776
epoch = 9308, loss = 0.4613049626350403
epoch = 9309, loss = 0.4613043963909149
epoch = 9310, loss = 0.4613035023212433
epoch = 9311, loss = 0.4613030254840851
epoch = 9312, loss = 0.4613020122051239
epoch = 9313, loss = 0.46130120754241943
epoch = 9314, loss = 0.4613003730

epoch = 9601, loss = 0.46107926964759827
epoch = 9602, loss = 0.4610786736011505
epoch = 9603, loss = 0.4610776901245117
epoch = 9604, loss = 0.4610764980316162
epoch = 9605, loss = 0.46107593178749084
epoch = 9606, loss = 0.461075097322464
epoch = 9607, loss = 0.4610743224620819
epoch = 9608, loss = 0.4610734283924103
epoch = 9609, loss = 0.4610731303691864
epoch = 9610, loss = 0.4610722064971924
epoch = 9611, loss = 0.4610709846019745
epoch = 9612, loss = 0.46107029914855957
epoch = 9613, loss = 0.4610697031021118
epoch = 9614, loss = 0.46106892824172974
epoch = 9615, loss = 0.46106818318367004
epoch = 9616, loss = 0.4610676169395447
epoch = 9617, loss = 0.46106669306755066
epoch = 9618, loss = 0.46106597781181335
epoch = 9619, loss = 0.4610649049282074
epoch = 9620, loss = 0.46106383204460144
epoch = 9621, loss = 0.46106335520744324
epoch = 9622, loss = 0.46106234192848206
epoch = 9623, loss = 0.46106165647506714
epoch = 9624, loss = 0.4610605835914612
epoch = 9625, loss = 0.4610603

epoch = 9907, loss = 0.4608325660228729
epoch = 9908, loss = 0.46083176136016846
epoch = 9909, loss = 0.46083131432533264
epoch = 9910, loss = 0.4608302712440491
epoch = 9911, loss = 0.46082961559295654
epoch = 9912, loss = 0.4608284533023834
epoch = 9913, loss = 0.46082794666290283
epoch = 9914, loss = 0.46082690358161926
epoch = 9915, loss = 0.46082648634910583
epoch = 9916, loss = 0.46082553267478943
epoch = 9917, loss = 0.46082478761672974
epoch = 9918, loss = 0.4608238637447357
epoch = 9919, loss = 0.46082282066345215
epoch = 9920, loss = 0.4608220160007477
epoch = 9921, loss = 0.460821270942688
epoch = 9922, loss = 0.46082040667533875
epoch = 9923, loss = 0.4608193337917328
epoch = 9924, loss = 0.4608185589313507
epoch = 9925, loss = 0.46081793308258057
epoch = 9926, loss = 0.4608171284198761
epoch = 9927, loss = 0.4608159065246582
epoch = 9928, loss = 0.4608154594898224
epoch = 9929, loss = 0.46081438660621643
epoch = 9930, loss = 0.4608134925365448
epoch = 9931, loss = 0.460812

In [67]:
print(f"record {x_data[0]} should be {y_data[0]}, pred {model_diabetes(x_data[0])}")

record tensor([-0.2941,  0.4874,  0.1803, -0.2929,  0.0000,  0.0015, -0.5312,
        -0.0333]) should be 0.0, pred tensor([ 0.3169])


### TODO Exercise 7-1