In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from tensorflow.keras import layers as kl

In [2]:
distr_layer = tfp.layers.DistributionLambda(lambda t: tfd.NegativeBinomial(
        total_count = tf.math.softplus(t[... , 0]), logits = t[..., 1]    
       ))

In [3]:
model = tf.keras.models.Sequential([
    
    kl.Dense(2),
       
    distr_layer
    
])
negloglik = lambda x, rv_x: tf.reduce_mean(-rv_x.log_prob(x))
model.compile(optimizer='adam', loss=negloglik)

In [4]:
model(tf.constant([[1,2,3],[-1,0,1]]))

<tfp.distributions.NegativeBinomial 'sequential/distribution_lambda/NegativeBinomial/' batch_shape=[2] event_shape=[] dtype=float32>

In [5]:
model.weights

[<tf.Variable 'sequential/dense/kernel:0' shape=(3, 2) dtype=float32, numpy=
 array([[ 0.14176834, -0.6824337 ],
        [-0.5987934 ,  0.39817536],
        [ 0.18142378,  0.7412889 ]], dtype=float32)>,
 <tf.Variable 'sequential/dense/bias:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]

In [6]:
d = tfd.Normal(0, 1)
d.sample()

<tf.Tensor: id=82, shape=(), dtype=float32, numpy=0.36463216>

In [7]:
N = 10000
coefs = np.array([[1, 2, 3],[3, 2, 1]]).T
bias = np.array([3, 0])
X = tfd.Normal(1, 1).sample(sample_shape = [N, 3])

In [8]:
X

<tf.Tensor: id=104, shape=(10000, 3), dtype=float32, numpy=
array([[ 0.5539832 , -0.8368081 ,  2.6849008 ],
       [ 0.87797767, -0.7828914 ,  0.65795726],
       [ 1.301696  ,  1.3853161 ,  2.053258  ],
       ...,
       [ 3.0525506 , -0.31443965,  1.8236    ],
       [ 0.5653211 ,  0.12393051,  1.8508487 ],
       [ 1.1416686 ,  2.1904838 ,  2.3637786 ]], dtype=float32)>

In [9]:
true_neg_bin_params = X @ coefs + bias

In [10]:
true_neg_bin_params

<tf.Tensor: id=108, shape=(10000, 2), dtype=float32, numpy=
array([[ 9.935069 ,  2.6732342],
       [ 4.2860665,  1.7261076],
       [13.232101 ,  8.728978 ],
       ...,
       [10.894471 , 10.352372 ],
       [ 9.365728 ,  3.794673 ],
       [15.613972 , 10.169752 ]], dtype=float32)>

In [11]:
true_distr = distr_layer(true_neg_bin_params)

In [12]:
Y = true_distr.sample()

In [34]:
# old_weights = model.get_weights()
# model.set_weights([coefs, bias])
# model_distr = model(X)
# print(model.evaluate(X, Y, batch_size=100000, verbose=0))
# model.set_weights(old_weights)

917772.75


In [35]:
model.set_weights([coefs, bias])

In [13]:
np.mean(negloglik(Y, true_distr))

8.614815

In [15]:
np.mean(negloglik(Y, model(X)))

2227965.2

In [13]:
print(model.evaluate(X, Y, batch_size=100000, verbose=0))

5707.3525390625


In [21]:
model_distr.variance()

<tf.Tensor: id=822, shape=(1000,), dtype=float32, numpy=
array([1.21238232e+01, 1.00705922e-01, 7.64667285e+03, 3.18914962e+00,
       1.57489258e+04, 2.03829758e+02, 9.40464878e+00, 1.35461194e-02,
       2.55139014e+03, 5.43421484e-04, 1.48330480e-01, 2.06514938e+02,
       2.50184769e-03, 4.73372488e-08, 8.23575234e+04, 1.32114947e+00,
       6.72494507e+01, 6.96377600e+06, 5.32422920e+07, 5.04576117e-02,
       8.25678329e+01, 1.77305031e+01, 6.16659641e+00, 2.71555638e+00,
       1.32029236e+02, 7.03650434e-03, 3.89921808e+00, 8.87841523e-01,
       1.82209956e-03, 1.01331357e+04, 2.46027417e+03, 1.63398814e+00,
       3.79049349e+00, 6.69985504e+01, 4.67193685e-02, 1.36680186e+00,
       1.48160279e+00, 2.11490967e+02, 1.00911297e-01, 2.24391068e-03,
       2.54644629e+03, 7.50159025e-01, 4.95267175e-02, 3.22795331e-01,
       3.65319276e+00, 6.21352673e+00, 3.28687575e+06, 2.36122292e-02,
       4.89925499e+01, 4.10976142e-01, 3.38280737e-01, 4.58471295e-05,
       2.90678040e+0

In [22]:
true_distr.mean()

<tf.Tensor: id=824, shape=(1000,), dtype=float32, numpy=
array([5.39192915e+00, 9.40134972e-02, 2.61385040e+02, 2.01179886e+00,
       1.83725555e+02, 2.39086666e+01, 1.71441627e+00, 1.30386744e-02,
       8.63368378e+01, 5.41332411e-04, 1.41170457e-01, 3.30959282e+01,
       2.48523825e-03, 4.73338169e-08, 8.45050842e+02, 8.71967971e-01,
       2.12845192e+01, 7.70625244e+03, 1.97486152e+04, 4.95060720e-02,
       1.70164757e+01, 6.84960270e+00, 2.24802613e+00, 1.58621025e+00,
       2.24217052e+01, 7.00184191e-03, 2.31625009e+00, 6.30251825e-01,
       1.82117056e-03, 2.63476562e+02, 1.24819069e+02, 1.34678996e+00,
       1.59668016e+00, 1.40467587e+01, 4.60770465e-02, 7.82155454e-01,
       1.05084479e+00, 2.00056953e+01, 7.83467814e-02, 2.16445350e-03,
       1.03910805e+02, 6.58941627e-01, 4.84120101e-02, 1.90160573e-01,
       2.22222304e+00, 2.88203239e+00, 5.04907129e+03, 2.32535619e-02,
       8.42647743e+00, 3.56983066e-01, 3.10838521e-01, 4.55867739e-05,
       1.67614784e+0

In [23]:
model(X).mean()

<tf.Tensor: id=860, shape=(1000,), dtype=float32, numpy=
array([ 1.60564   ,  1.379874  ,  0.37944475,  0.23928322,  0.35129416,
        0.5350127 ,  0.58887446,  3.7094767 ,  0.03705967,  0.6486339 ,
        0.29919487,  3.3640707 ,  0.7692973 ,  0.153582  ,  0.92201734,
        0.36696893,  8.129183  ,  0.6083264 ,  0.49984547,  5.672316  ,
        0.20700514,  0.22036603,  0.12912177,  0.94990593,  0.82027286,
        0.0970009 ,  3.2522144 ,  0.44362554,  0.9279015 ,  0.23899819,
        6.7617755 ,  9.026428  ,  0.7018325 ,  0.28638062,  0.17627268,
        0.74865276,  0.4431658 ,  0.22856337,  0.21052438,  0.15510939,
        0.91958874,  0.3258133 ,  2.0351088 ,  0.12172444,  0.7549416 ,
        0.3646252 ,  0.10030773,  0.71948314,  0.281849  ,  0.5646567 ,
        0.4052561 ,  0.2630552 ,  0.7462153 ,  1.5540676 ,  1.2013553 ,
        1.7074366 ,  1.826355  ,  0.15734006,  0.86663127,  0.48212838,
        2.0446255 ,  0.11625703,  1.1176066 ,  2.1943846 ,  0.33802193,
       

In [24]:
model.evaluate(X, Y)



2819.313861816406

In [25]:
model.evaluate(X, Y)



2819.313861816406

In [56]:
true_distr

<tfp.distributions.NegativeBinomial 'distribution_lambda_3_NegativeBinomial' batch_shape=[100000] event_shape=[] dtype=float32>

In [57]:
true_distr.probs_parameter()

<tf.Tensor: id=585, shape=(100000,), dtype=float32, numpy=
array([0.36850506, 0.74067277, 0.66926724, ..., 0.9999394 , 0.32746673,
       0.4854044 ], dtype=float32)>

In [58]:
true_distr.total_count

<tf.Tensor: id=559, shape=(100000,), dtype=float32, numpy=
array([0.2757673 , 0.32079494, 1.5467578 , ..., 8.653196  , 1.9845582 ,
       1.8261846 ], dtype=float32)>

In [59]:
true_distr.logits

<tf.Tensor: id=563, shape=(100000,), dtype=float32, numpy=
array([-0.5386355 ,  1.0494682 ,  0.7048727 , ...,  9.711327  ,
       -0.71966517, -0.05839898], dtype=float32)>

In [61]:
Y

<tf.Tensor: id=606, shape=(100000,), dtype=float32, numpy=
array([0.00000e+00, 2.00000e+00, 1.00000e+00, ..., 1.17367e+05,
       0.00000e+00, 0.00000e+00], dtype=float32)>

In [18]:
model.fit(x = X, y = Y, epochs = 1000, batch_size=128)

Train on 10000 samples
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 

Epoch 87/1000
Epoch 88/1000
Epoch 89/1000
Epoch 90/1000
Epoch 91/1000
Epoch 92/1000
Epoch 93/1000
Epoch 94/1000
Epoch 95/1000
Epoch 96/1000
Epoch 97/1000
Epoch 98/1000
Epoch 99/1000
Epoch 100/1000
Epoch 101/1000
Epoch 102/1000
Epoch 103/1000
Epoch 104/1000
Epoch 105/1000
Epoch 106/1000
Epoch 107/1000
Epoch 108/1000
Epoch 109/1000
Epoch 110/1000
Epoch 111/1000
Epoch 112/1000
Epoch 113/1000
Epoch 114/1000
Epoch 115/1000
Epoch 116/1000
Epoch 117/1000
Epoch 118/1000
Epoch 119/1000
Epoch 120/1000
Epoch 121/1000
Epoch 122/1000
Epoch 123/1000
Epoch 124/1000
Epoch 125/1000
Epoch 126/1000
Epoch 127/1000
Epoch 128/1000
Epoch 129/1000
Epoch 130/1000
Epoch 131/1000
Epoch 132/1000
Epoch 133/1000
Epoch 134/1000
Epoch 135/1000
Epoch 136/1000
Epoch 137/1000
Epoch 138/1000
Epoch 139/1000
Epoch 140/1000
Epoch 141/1000
Epoch 142/1000
Epoch 143/1000
Epoch 144/1000
Epoch 145/1000
Epoch 146/1000
Epoch 147/1000
Epoch 148/1000
Epoch 149/1000
Epoch 150/1000
Epoch 151/1000
Epoch 152/1000
Epoch 153/1000
Epoch 15

Epoch 172/1000
Epoch 173/1000
Epoch 174/1000
Epoch 175/1000
Epoch 176/1000
Epoch 177/1000
Epoch 178/1000
Epoch 179/1000
Epoch 180/1000
Epoch 181/1000
Epoch 182/1000
Epoch 183/1000
Epoch 184/1000
Epoch 185/1000
Epoch 186/1000
Epoch 187/1000
Epoch 188/1000
Epoch 189/1000
Epoch 190/1000
Epoch 191/1000
Epoch 192/1000
Epoch 193/1000
Epoch 194/1000
Epoch 195/1000
Epoch 196/1000
Epoch 197/1000
Epoch 198/1000
Epoch 199/1000
Epoch 200/1000
Epoch 201/1000
Epoch 202/1000
Epoch 203/1000
Epoch 204/1000
Epoch 205/1000
Epoch 206/1000
Epoch 207/1000
Epoch 208/1000
Epoch 209/1000
Epoch 210/1000
Epoch 211/1000
Epoch 212/1000
Epoch 213/1000
Epoch 214/1000
Epoch 215/1000
Epoch 216/1000
Epoch 217/1000
Epoch 218/1000
Epoch 219/1000
Epoch 220/1000
Epoch 221/1000
Epoch 222/1000
Epoch 223/1000
Epoch 224/1000
Epoch 225/1000
Epoch 226/1000
Epoch 227/1000
Epoch 228/1000
Epoch 229/1000
Epoch 230/1000
Epoch 231/1000
Epoch 232/1000
Epoch 233/1000
Epoch 234/1000
Epoch 235/1000
Epoch 236/1000
Epoch 237/1000
Epoch 238/

Epoch 258/1000
Epoch 259/1000
Epoch 260/1000
Epoch 261/1000
Epoch 262/1000
Epoch 263/1000
Epoch 264/1000
Epoch 265/1000
Epoch 266/1000
Epoch 267/1000
Epoch 268/1000
Epoch 269/1000
Epoch 270/1000
Epoch 271/1000
Epoch 272/1000
Epoch 273/1000
Epoch 274/1000
Epoch 275/1000
Epoch 276/1000
Epoch 277/1000
Epoch 278/1000
Epoch 279/1000
Epoch 280/1000
Epoch 281/1000
Epoch 282/1000
Epoch 283/1000
Epoch 284/1000
Epoch 285/1000
Epoch 286/1000
Epoch 287/1000
Epoch 288/1000
Epoch 289/1000
Epoch 290/1000
Epoch 291/1000
Epoch 292/1000
Epoch 293/1000
Epoch 294/1000
Epoch 295/1000
Epoch 296/1000
Epoch 297/1000
Epoch 298/1000
Epoch 299/1000
Epoch 300/1000
Epoch 301/1000
Epoch 302/1000
Epoch 303/1000
Epoch 304/1000
Epoch 305/1000
Epoch 306/1000
Epoch 307/1000
Epoch 308/1000
Epoch 309/1000
Epoch 310/1000
Epoch 311/1000
Epoch 312/1000
Epoch 313/1000
Epoch 314/1000
Epoch 315/1000
Epoch 316/1000
Epoch 317/1000
Epoch 318/1000
Epoch 319/1000
Epoch 320/1000
Epoch 321/1000
Epoch 322/1000
Epoch 323/1000
Epoch 324/

Epoch 344/1000
Epoch 345/1000
Epoch 346/1000
Epoch 347/1000
Epoch 348/1000
Epoch 349/1000
Epoch 350/1000
Epoch 351/1000
Epoch 352/1000
Epoch 353/1000
Epoch 354/1000
Epoch 355/1000
Epoch 356/1000
Epoch 357/1000
Epoch 358/1000
Epoch 359/1000
Epoch 360/1000
Epoch 361/1000
Epoch 362/1000
Epoch 363/1000
Epoch 364/1000
Epoch 365/1000
Epoch 366/1000
Epoch 367/1000
Epoch 368/1000
Epoch 369/1000
Epoch 370/1000
Epoch 371/1000
Epoch 372/1000
Epoch 373/1000
Epoch 374/1000
Epoch 375/1000
Epoch 376/1000
Epoch 377/1000
Epoch 378/1000
Epoch 379/1000
Epoch 380/1000
Epoch 381/1000
Epoch 382/1000
Epoch 383/1000
Epoch 384/1000
Epoch 385/1000
Epoch 386/1000
Epoch 387/1000
Epoch 388/1000
Epoch 389/1000
Epoch 390/1000
Epoch 391/1000
Epoch 392/1000
Epoch 393/1000
Epoch 394/1000
Epoch 395/1000
Epoch 396/1000
Epoch 397/1000
Epoch 398/1000
Epoch 399/1000
Epoch 400/1000
Epoch 401/1000
Epoch 402/1000
Epoch 403/1000
Epoch 404/1000
Epoch 405/1000
Epoch 406/1000
Epoch 407/1000
Epoch 408/1000
Epoch 409/1000
Epoch 410/

Epoch 431/1000
Epoch 432/1000
Epoch 433/1000
Epoch 434/1000
Epoch 435/1000
Epoch 436/1000
Epoch 437/1000
Epoch 438/1000
Epoch 439/1000
Epoch 440/1000
Epoch 441/1000
Epoch 442/1000
Epoch 443/1000
Epoch 444/1000
Epoch 445/1000
Epoch 446/1000
Epoch 447/1000
Epoch 448/1000
Epoch 449/1000
Epoch 450/1000
Epoch 451/1000
Epoch 452/1000
Epoch 453/1000
Epoch 454/1000
Epoch 455/1000
Epoch 456/1000
Epoch 457/1000
Epoch 458/1000
Epoch 459/1000
Epoch 460/1000
Epoch 461/1000
Epoch 462/1000
Epoch 463/1000
Epoch 464/1000
Epoch 465/1000
Epoch 466/1000
Epoch 467/1000
Epoch 468/1000
Epoch 469/1000
Epoch 470/1000
Epoch 471/1000
Epoch 472/1000
Epoch 473/1000
Epoch 474/1000
Epoch 475/1000
Epoch 476/1000
Epoch 477/1000
Epoch 478/1000
Epoch 479/1000
Epoch 480/1000
Epoch 481/1000
Epoch 482/1000
Epoch 483/1000
Epoch 484/1000
Epoch 485/1000
Epoch 486/1000
Epoch 487/1000
Epoch 488/1000
Epoch 489/1000
Epoch 490/1000
Epoch 491/1000
Epoch 492/1000
Epoch 493/1000
Epoch 494/1000
Epoch 495/1000
Epoch 496/1000
Epoch 497/

Epoch 518/1000
Epoch 519/1000
Epoch 520/1000
Epoch 521/1000
Epoch 522/1000
Epoch 523/1000
Epoch 524/1000
Epoch 525/1000
Epoch 526/1000
Epoch 527/1000
Epoch 528/1000
Epoch 529/1000
Epoch 530/1000
Epoch 531/1000
Epoch 532/1000
Epoch 533/1000
Epoch 534/1000
Epoch 535/1000
Epoch 536/1000
Epoch 537/1000
Epoch 538/1000
Epoch 539/1000
Epoch 540/1000
Epoch 541/1000
Epoch 542/1000
Epoch 543/1000
Epoch 544/1000
Epoch 545/1000
Epoch 546/1000
Epoch 547/1000
Epoch 548/1000
Epoch 549/1000
Epoch 550/1000
Epoch 551/1000
Epoch 552/1000
Epoch 553/1000
Epoch 554/1000
Epoch 555/1000
Epoch 556/1000
Epoch 557/1000
Epoch 558/1000
Epoch 559/1000
Epoch 560/1000
Epoch 561/1000
Epoch 562/1000
Epoch 563/1000
Epoch 564/1000
Epoch 565/1000
Epoch 566/1000
Epoch 567/1000
Epoch 568/1000
Epoch 569/1000
Epoch 570/1000
Epoch 571/1000
Epoch 572/1000
Epoch 573/1000
Epoch 574/1000
Epoch 575/1000
Epoch 576/1000
Epoch 577/1000
Epoch 578/1000
Epoch 579/1000
Epoch 580/1000
Epoch 581/1000
Epoch 582/1000
Epoch 583/1000
Epoch 584/

Epoch 606/1000
Epoch 607/1000
Epoch 608/1000
Epoch 609/1000
Epoch 610/1000
Epoch 611/1000
Epoch 612/1000
Epoch 613/1000
Epoch 614/1000
Epoch 615/1000
Epoch 616/1000
Epoch 617/1000
Epoch 618/1000
Epoch 619/1000
Epoch 620/1000
Epoch 621/1000
Epoch 622/1000
Epoch 623/1000
Epoch 624/1000
Epoch 625/1000
Epoch 626/1000
Epoch 627/1000
Epoch 628/1000
Epoch 629/1000
Epoch 630/1000
Epoch 631/1000
Epoch 632/1000
Epoch 633/1000
Epoch 634/1000
Epoch 635/1000
Epoch 636/1000
Epoch 637/1000
Epoch 638/1000
Epoch 639/1000
Epoch 640/1000
Epoch 641/1000
Epoch 642/1000
Epoch 643/1000
Epoch 644/1000
Epoch 645/1000
Epoch 646/1000
Epoch 647/1000
Epoch 648/1000
Epoch 649/1000
Epoch 650/1000
Epoch 651/1000
Epoch 652/1000
Epoch 653/1000
Epoch 654/1000
Epoch 655/1000
Epoch 656/1000
Epoch 657/1000
Epoch 658/1000
Epoch 659/1000
Epoch 660/1000
Epoch 661/1000
Epoch 662/1000
Epoch 663/1000
Epoch 664/1000
Epoch 665/1000
Epoch 666/1000
Epoch 667/1000
Epoch 668/1000
Epoch 669/1000
Epoch 670/1000
Epoch 671/1000
Epoch 672/

Epoch 694/1000
Epoch 695/1000
Epoch 696/1000
Epoch 697/1000
Epoch 698/1000
Epoch 699/1000
Epoch 700/1000
Epoch 701/1000
Epoch 702/1000
Epoch 703/1000
Epoch 704/1000
Epoch 705/1000
Epoch 706/1000
Epoch 707/1000
Epoch 708/1000
Epoch 709/1000
Epoch 710/1000
Epoch 711/1000
Epoch 712/1000
Epoch 713/1000
Epoch 714/1000
Epoch 715/1000
Epoch 716/1000
Epoch 717/1000
Epoch 718/1000
Epoch 719/1000
Epoch 720/1000
Epoch 721/1000
Epoch 722/1000
Epoch 723/1000
Epoch 724/1000
Epoch 725/1000
Epoch 726/1000
Epoch 727/1000
Epoch 728/1000
Epoch 729/1000
Epoch 730/1000
Epoch 731/1000
Epoch 732/1000
Epoch 733/1000
Epoch 734/1000
Epoch 735/1000
Epoch 736/1000
Epoch 737/1000
Epoch 738/1000
Epoch 739/1000
Epoch 740/1000
Epoch 741/1000
Epoch 742/1000
Epoch 743/1000
Epoch 744/1000
Epoch 745/1000
Epoch 746/1000
Epoch 747/1000
Epoch 748/1000
Epoch 749/1000
Epoch 750/1000
Epoch 751/1000
Epoch 752/1000
Epoch 753/1000
Epoch 754/1000
Epoch 755/1000
Epoch 756/1000
Epoch 757/1000
Epoch 758/1000
Epoch 759/1000
Epoch 760/

Epoch 782/1000
Epoch 783/1000
Epoch 784/1000
Epoch 785/1000
Epoch 786/1000
Epoch 787/1000
Epoch 788/1000
Epoch 789/1000
Epoch 790/1000
Epoch 791/1000
Epoch 792/1000
Epoch 793/1000
Epoch 794/1000
Epoch 795/1000
Epoch 796/1000
Epoch 797/1000
Epoch 798/1000
Epoch 799/1000
Epoch 800/1000
Epoch 801/1000
Epoch 802/1000
Epoch 803/1000
Epoch 804/1000
Epoch 805/1000
Epoch 806/1000
Epoch 807/1000
Epoch 808/1000
Epoch 809/1000
Epoch 810/1000
Epoch 811/1000
Epoch 812/1000
Epoch 813/1000
Epoch 814/1000
Epoch 815/1000
Epoch 816/1000
Epoch 817/1000
Epoch 818/1000
Epoch 819/1000
Epoch 820/1000
Epoch 821/1000
Epoch 822/1000
Epoch 823/1000
Epoch 824/1000
Epoch 825/1000
Epoch 826/1000
Epoch 827/1000
Epoch 828/1000
Epoch 829/1000
Epoch 830/1000
Epoch 831/1000
Epoch 832/1000
Epoch 833/1000
Epoch 834/1000
Epoch 835/1000
Epoch 836/1000
Epoch 837/1000
Epoch 838/1000
Epoch 839/1000
Epoch 840/1000
Epoch 841/1000
Epoch 842/1000
Epoch 843/1000
Epoch 844/1000
Epoch 845/1000
Epoch 846/1000
Epoch 847/1000
Epoch 848/

Epoch 870/1000
Epoch 871/1000
Epoch 872/1000
Epoch 873/1000
Epoch 874/1000
Epoch 875/1000
Epoch 876/1000
Epoch 877/1000
Epoch 878/1000
Epoch 879/1000
Epoch 880/1000
Epoch 881/1000
Epoch 882/1000
Epoch 883/1000
Epoch 884/1000
Epoch 885/1000
Epoch 886/1000
Epoch 887/1000
Epoch 888/1000
Epoch 889/1000
Epoch 890/1000
Epoch 891/1000
Epoch 892/1000
Epoch 893/1000
Epoch 894/1000
Epoch 895/1000
Epoch 896/1000
Epoch 897/1000
Epoch 898/1000
Epoch 899/1000
Epoch 900/1000
Epoch 901/1000
Epoch 902/1000
Epoch 903/1000
Epoch 904/1000
Epoch 905/1000
Epoch 906/1000
Epoch 907/1000
Epoch 908/1000
Epoch 909/1000
Epoch 910/1000
Epoch 911/1000
Epoch 912/1000
Epoch 913/1000
Epoch 914/1000
Epoch 915/1000
Epoch 916/1000
Epoch 917/1000
Epoch 918/1000
Epoch 919/1000
Epoch 920/1000
Epoch 921/1000
Epoch 922/1000
Epoch 923/1000
Epoch 924/1000
Epoch 925/1000
Epoch 926/1000
Epoch 927/1000
Epoch 928/1000
Epoch 929/1000
Epoch 930/1000
Epoch 931/1000
Epoch 932/1000
Epoch 933/1000
Epoch 934/1000
Epoch 935/1000
Epoch 936/

Epoch 958/1000
Epoch 959/1000
Epoch 960/1000
Epoch 961/1000
Epoch 962/1000
Epoch 963/1000
Epoch 964/1000
Epoch 965/1000
Epoch 966/1000
Epoch 967/1000
Epoch 968/1000
Epoch 969/1000
Epoch 970/1000
Epoch 971/1000
Epoch 972/1000
Epoch 973/1000
Epoch 974/1000
Epoch 975/1000
Epoch 976/1000
Epoch 977/1000
Epoch 978/1000
Epoch 979/1000
Epoch 980/1000
Epoch 981/1000
Epoch 982/1000
Epoch 983/1000
Epoch 984/1000
Epoch 985/1000
Epoch 986/1000
Epoch 987/1000
Epoch 988/1000
Epoch 989/1000
Epoch 990/1000
Epoch 991/1000
Epoch 992/1000
Epoch 993/1000
Epoch 994/1000
Epoch 995/1000
Epoch 996/1000
Epoch 997/1000
Epoch 998/1000
Epoch 999/1000
Epoch 1000/1000


<tensorflow.python.keras.callbacks.History at 0x182c2c2a788>

In [19]:
np.mean(negloglik(Y, model(X)))

11.128087

In [17]:
l = model.loss_functions[0]

In [18]:
l.call(Y, model(X))

<tf.Tensor: id=41183, shape=(), dtype=float32, numpy=6.023795>

In [24]:
dense = model.layers[0]

In [27]:
dense.kernel_regularizer

In [None]:
kl.Dense()

In [26]:
dense.a


In [20]:
model.weights

[<tf.Variable 'sequential/dense/kernel:0' shape=(3, 2) dtype=float32, numpy=
 array([[-0.03124883,  0.28588724],
        [-0.03150563,  0.2923285 ],
        [-0.03221646,  0.28648335]], dtype=float32)>,
 <tf.Variable 'sequential/dense/bias:0' shape=(2,) dtype=float32, numpy=array([-2.0391998, 16.384024 ], dtype=float32)>]

In [68]:
model(X).total_count

<tf.Tensor: id=399687, shape=(100000,), dtype=float32, numpy=
array([0.05909398, 0.05949057, 0.06059545, ..., 0.06368538, 0.05943586,
       0.06014879], dtype=float32)>

In [69]:
model(X).probs_parameter()

<tf.Tensor: id=399747, shape=(100000,), dtype=float32, numpy=
array([0.99999475, 0.9999949 , 0.99999505, ..., 0.9999956 , 0.9999945 ,
       0.9999949 ], dtype=float32)>

In [21]:
def get_loss():
    return tf.reduce_mean(negloglik(Y, model(X)))

In [22]:
get_loss()

<tf.Tensor: id=160078, shape=(), dtype=float32, numpy=11.128087>

In [23]:
optimizer = tf.optimizers.Adam()

In [24]:
@tf.function
def train_step():
    with tf.GradientTape() as tape:
        loss = get_loss()

    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss

In [25]:
for i in range(100000):
    loss = train_step()
    if i % 100 == 0:
        print(loss)

tf.Tensor(11.128087, shape=(), dtype=float32)
tf.Tensor(11.062134, shape=(), dtype=float32)
tf.Tensor(11.052348, shape=(), dtype=float32)
tf.Tensor(11.0327, shape=(), dtype=float32)
tf.Tensor(10.954716, shape=(), dtype=float32)
tf.Tensor(10.967147, shape=(), dtype=float32)
tf.Tensor(10.987995, shape=(), dtype=float32)
tf.Tensor(10.989368, shape=(), dtype=float32)
tf.Tensor(10.980344, shape=(), dtype=float32)
tf.Tensor(10.9544935, shape=(), dtype=float32)
tf.Tensor(10.939382, shape=(), dtype=float32)
tf.Tensor(10.905773, shape=(), dtype=float32)
tf.Tensor(10.9035, shape=(), dtype=float32)
tf.Tensor(10.890143, shape=(), dtype=float32)
tf.Tensor(10.87322, shape=(), dtype=float32)
tf.Tensor(10.866693, shape=(), dtype=float32)
tf.Tensor(10.856206, shape=(), dtype=float32)
tf.Tensor(10.852836, shape=(), dtype=float32)
tf.Tensor(10.847836, shape=(), dtype=float32)
tf.Tensor(10.839115, shape=(), dtype=float32)
tf.Tensor(10.839273, shape=(), dtype=float32)
tf.Tensor(10.837333, shape=(), dtype=f

tf.Tensor(8.560568, shape=(), dtype=float32)
tf.Tensor(8.544261, shape=(), dtype=float32)
tf.Tensor(8.556552, shape=(), dtype=float32)
tf.Tensor(8.558686, shape=(), dtype=float32)
tf.Tensor(8.553261, shape=(), dtype=float32)
tf.Tensor(8.571409, shape=(), dtype=float32)
tf.Tensor(8.556756, shape=(), dtype=float32)
tf.Tensor(8.569316, shape=(), dtype=float32)
tf.Tensor(8.570954, shape=(), dtype=float32)
tf.Tensor(8.575504, shape=(), dtype=float32)
tf.Tensor(8.5664015, shape=(), dtype=float32)
tf.Tensor(8.569334, shape=(), dtype=float32)
tf.Tensor(8.566512, shape=(), dtype=float32)
tf.Tensor(8.562978, shape=(), dtype=float32)
tf.Tensor(8.55217, shape=(), dtype=float32)
tf.Tensor(8.542455, shape=(), dtype=float32)
tf.Tensor(8.541173, shape=(), dtype=float32)
tf.Tensor(8.535672, shape=(), dtype=float32)
tf.Tensor(8.543259, shape=(), dtype=float32)
tf.Tensor(8.563656, shape=(), dtype=float32)
tf.Tensor(8.556761, shape=(), dtype=float32)
tf.Tensor(8.573159, shape=(), dtype=float32)
tf.Tensor(

KeyboardInterrupt: 

In [26]:
model.trainable_weights

[<tf.Variable 'sequential/dense/kernel:0' shape=(3, 2) dtype=float32, numpy=
 array([[0.9516669, 3.0018325],
        [1.9121954, 2.0007813],
        [2.794243 , 1.00997  ]], dtype=float32)>,
 <tf.Variable 'sequential/dense/bias:0' shape=(2,) dtype=float32, numpy=array([2.942339  , 0.03384899], dtype=float32)>]