In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import wandb
import tempfile
import os
from models.normalizing_flow import HierarchicalNormalizingFlowSB
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch import nn


In [3]:
from train_flow_sb import unflatten, object_from_config

In [4]:
api = wandb.Api()

In [5]:
run = api.run('druhe/gw-src/16pgb80c')

In [6]:
config = unflatten(run.config)

In [7]:
dataset = object_from_config(config, key='dataset')(**config['dataset'])

In [8]:
sb_weight = config['model']['sb_weight']

In [9]:
flows = object_from_config(config, key='flow')(**config['flow'])

In [10]:
checkpoints = [f for f in run.files() if f.name.endswith('.ckpt')]

In [11]:
tempdir = tempfile.TemporaryDirectory()

In [12]:
ckpt = checkpoints[0]

In [13]:
ckpt_path = ckpt.download(root=tempdir.name, replace=True)

In [14]:
model = HierarchicalNormalizingFlowSB.load_from_checkpoint(ckpt_path.name, dataset=dataset, sb_weight=sb_weight, flows=flows)

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /var/folders/d3/x840qlg17x1f92cnsrkq62fw0000gn/T/tmper60dl8g
INFO:torch.distributed.nn.jit.instantiator:Writing /var/folders/d3/x840qlg17x1f92cnsrkq62fw0000gn/T/tmper60dl8g/_remote_module_non_sriptable.py


In [15]:
# with torch.no_grad():
#     axes_names = []
#     axes = []

#     for n, ax in dataset.grid.items():
#         axes_names.append(n)
#         axes.append(ax)

#     m1, m2, z = np.stack(np.meshgrid(*axes, indexing="xy")).reshape(3, -1)

#     if dataset.has_normalization:
#         raise NotImplementedError
#         x, y = dataset.normalize_forward(x, y)

#     resolutions = [len(ax) for ax in axes]

#     input = np.stack([m1, m2, z], axis=-1)
#     input = torch.from_numpy(input).float()
#     prob = model.log_prob(input).exp().view(*resolutions)


#     pm1m2 = prob.sum(-1)

#     fig = plt.figure(figsize=(16, 16), facecolor="white")
#     plt.imshow(
#         pm1m2,
#         cmap="jet",
#         origin="lower",
#         extent=(
#             axes[0][0],
#             axes[0][-1],
#             axes[1][0],
#             axes[1][-1],
#         ),  # origin='lower' changes the order
#         aspect="auto",
#     )

#     plt.xlabel(axes_names[0])  # origin='lower' changes the order
#     plt.ylabel(axes_names[1])
#     plt.tight_layout()



In [16]:
axes_names = []
axes = []

for n, ax in dataset.grid.items():
    axes_names.append(n)
    axes.append(ax)

m1, m2, z = np.stack(np.meshgrid(*axes, indexing="xy")).reshape(3, -1)

if dataset.has_normalization:
    raise NotImplementedError
    x, y = dataset.normalize_forward(x, y)

resolutions = [len(ax) for ax in axes]

input = np.stack([m1, m2, z], axis=-1)


In [17]:
input=torch.from_numpy(input).double()

In [18]:
import hamiltorch

In [19]:
from torch.nn.utils import _stateless


In [29]:
def flatten_params(parameters):
    """
    flattens all parameters into a single column vector. Returns the dictionary to recover them
    :param: parameters: a generator or list of all the parameters
    :return: a dictionary: {"params": [#params, 1],
    "indices": [(start index, end index) for each param] **Note end index in uninclusive**

    """
    l = [torch.flatten(p) for p in parameters]
    indices = []
    s = 0
    for p in l:
        size = p.shape[0]
        indices.append((s, s+size))
        s += size
    flat = torch.cat(l).view(-1, 1)
    return flat, indices


def unflatten_params(flat_params, indices, model):
    """
    Gives a list of recovered parameters from their flattened form
    :param flat_params: [#params, 1]
    :param indices: a list detaling the start and end index of each param [(start, end) for param]
    :param model: the model that gives the params with correct shapes
    :return: the params, reshaped to the ones in the model, with the same order as those in the model
    """
    l = [flat_params[s:e] for (s, e) in indices]
    for i, p in enumerate(model.parameters()):
        l[i] = l[i].view(*p.shape)
    return tuple(nn.Parameter(p) for p in l)


In [30]:
params_flat, indices = flatten_params(model.parameters())
names = list(n for n, _ in model.named_parameters())


params_flat = params_flat.squeeze()

In [31]:
model=model.double()

params_init = hamiltorch.util.flatten(model).clone().double()


In [32]:
fmodel = hamiltorch.util.make_functional(model)

In [36]:
def log_prob(params, y=None):
    prior = torch.distributions.Normal(params_init, 1).log_prob(params)

    params = unflatten_params(params, indices, model)
            
    out: torch.Tensor = _stateless.functional_call(model, {n: p for n, p in zip(names, params)}, input)    
    ll = prior.sum() + out.sum()
    print(out.mean().item())
    return ll




In [37]:
N = 300
burn = 200
N_nuts = N + burn
step_size = 1e-2
L = 20

hamiltorch.sample(
    log_prob, 
    params_init=params_init, 
    num_samples=N_nuts, num_steps_per_sample
    =L, step_size=step_size, burn=burn, )







Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples | Samples/sec
-258.8183631382832
-258.8183631382832
-791.3876175427846
-563.0427895238288
-714.6202286248185
-937.2559735191967
-1149.1823304826748
-1332.1297743550433
-1485.48737085598
-1613.493071029463
-1720.93595486402
-1811.9521560357257
-1889.8497029088917
-1957.2202475136783
-2016.0960715990802
-2068.0808311168594
-2114.440930564363
-2156.1706664319313
-2194.038458974127
-2228.6411205655004
-2260.456291642062
-2289.8808815953835
-2289.8808815953835
-258.81836313828320:08:40 | -------------------- |   1/500 | 0.96       
-258.8183631382832
-468.0544499209573
-853.4726915896724
-1168.6509631114454
-1408.3042942586594
-1585.1137446569776
-1715.206168247502
-1810.6708600215493
-1878.9564447580783
-1923.535226566257
-1943.6650130090645
-1931.9247107470314
-1867.819345695304
-1745.1162584390318
-2331.9326613940248
-2583.0830627491214
-2656.8631416342237
-2744.539606301988
-2824.06091

-1630.1615369910946
-5959.394201837327
-618.9254166889682
-943.6850270727163
-1362.0169184254423
-1591.635161273359
-1759.3277780847775
-1916.7387184282834
-2096.2227352584064
-2309.3434330514447
-2599.591368081178
-2873.5982615001067
-3079.8389222513933
-3242.211978627185
-3355.246178509622
-3435.9627746194665
-3485.974404093552
-3518.893449795649
-3549.0786430730655
-3549.0786430730655
-258.81836313828320:17:52 | #------------------- |  17/500 | 0.45       
-258.8183631382832
-1886.1741439928949
-1529.2442435518085
-1824.8678254448482
-2104.922093692722
-2322.6385332311334
-2486.0968447059668
-2609.872480454541
-2705.6524493757624
-2781.7408757569583
-2844.34415308605
-2899.167202177381
-2949.52511028711
-2995.5119186154398
-3037.43023011239
-3076.224055110789
-3112.9422642630498
-3148.2809109427258
-3182.4706886746135
-3215.521154052868
-3247.4459660723232
-3247.4459660723232
-258.81836313828320:17:50 | #------------------- |  18/500 | 0.45       
-258.8183631382832
-742.94728427132

-2537.536231043605
-2614.5016823252145
-2657.9570745854494
-2684.2234581958173
-2701.0835388894466
-2712.5628651308107
-2720.9627945612506
-2720.9627945612506
-258.81836313828320:16:48 | #------------------- |  33/500 | 0.46       
-258.8183631382832
-390.99867268643345
-649.6518339072813
-1028.619580502932
-1286.404923783341
-1446.6377413704722
-1566.7777392874195
-1663.769542020375
-1742.0204506893092
-1807.4038773353427
-1866.1966260892698
-1924.0620481714955
-1986.0950503288307
-2057.014497942445
-2141.193783971529
-2237.827940747456
-2327.825324567794
-2403.148718044859
-2466.6137217959204
-2520.873884504903
-2567.910848462153
-2567.910848462153
-258.81836313828320:16:44 | #------------------- |  34/500 | 0.46       
-258.8183631382832
-478.1385870899536
-1060.7506112337194
-1440.1144508128814
-1695.5917492877988
-1883.4926052834892
-2027.6401621382329
-2140.239339975458
-2228.9141792167657
-2299.0736021505704
-2354.8224548569888
-2399.380133523017
-2435.3107442039927
-2464.640533

-258.81836313828320:15:47 | ##------------------ |  49/500 | 0.48       
-258.8183631382832
-254.4659940764309
-467.9707100071053
-700.3741720419758
-852.7211471275289
-965.6567131120667
-1058.9542925833402
-1139.4432546524472
-1210.3189933699346
-1273.6560312080323
-1331.033481867879
-1383.7131426399228
-1432.711999399111
-1478.849301304483
-1522.7860992856906
-1565.0585589164111
-1606.1053383849735
-1646.288797439191
-1685.9112099945842
-1725.2271125900813
-1764.4529844243261
-1764.4529844243261
-258.81836313828320:15:45 | ##------------------ |  50/500 | 0.48       
-258.8183631382832
-1326.9555172260887
-1433.9367297173949
-1631.7653548105977
-1851.0061960048208
-2036.7373652369547
-2190.409583640634
-2320.056175530861
-2433.5834328985766
-2537.7766651725574
-2638.193036052164
-2739.3571542791296
-2844.86458820857
-2956.1836022098905
-3068.577932373992
-3175.4385160242255
-3275.1259090532794
-3367.853303721124
-3454.0561172333387
-3534.178539668048
-3608.6521453730375
-3608.6521453

-837.1728195879961
-1072.7254531513142
-1272.9922571948769
-1439.3149300568696
-1574.215292872874
-1680.7003079648666
-1763.8768871757015
-1828.282072141773
-1877.198953446425
-1913.0671426710167
-1937.8906420841508
-1953.512990884844
-1961.7787672320896
-1964.5963409183378
-1963.9166912020614
-1961.653066821031
-1959.5752349723625
-1959.2125454358936
-1959.2125454358936
-258.81836313828320:15:10 | ###----------------- |  66/500 | 0.48       
-258.8183631382832
-402.412071520481
-730.6199682355292
-1034.359626707023
-1281.8533694497871
-1478.0230423903777
-1631.859860604479
-1753.0893574398551
-1850.1311234434384
-1929.0067189466315
-1993.9741390545464
-2048.1189805825375
-2093.728750527373
-2132.53076772918
-2165.849269225868
-2194.712638832848
-2219.928994664638
-2242.140797979827
-2261.8628851398553
-2279.4999199584245
-2295.378087274357
-2295.378087274357
-258.81836313828320:15:10 | ###----------------- |  67/500 | 0.48       
-258.8183631382832
-361.88481412663117
-1201.0841174623

-3578.988757127221
-3640.141983659236
-3702.0105867447187
-3761.836275892871
-3818.240065598779
-3870.6608805751002
-3918.9984622353713
-3963.398043819234
-3963.398043819234
-258.81836313828320:15:00 | ###----------------- |  82/500 | 0.46       
-258.8183631382832
-463.6039175458778
-1227.1762350267381
-1598.1722994730533
-1521.1298718049043
-1553.6915193444372
-1648.380459050094
-1762.0369409523437
-1877.7128312336445
-1993.8870508438338
-2112.9328240288137
-2227.5589887633796
-2329.8704529560246
-2424.502781071771
-2517.2477698382127
-2611.0404498963
-2702.8243718813983
-2785.9675340918834
-2862.3712640346607
-2935.5564637850207
-3004.1682622228154
-3004.1682622228154
-258.81836313828320:14:58 | ###----------------- |  83/500 | 0.46       
-258.8183631382832
-439.2060684770972
-801.467558019341
-1128.3726554492118
-1392.7820544782066
-1595.8641870648385
-1751.9887175581553
-1873.5958928124357
-1969.4994175989946
-2045.926024150664
-2107.4634171889893
-2158.239162057484
-2204.7584946

-258.81836313828320:14:19 | ####---------------- |  98/500 | 0.47       
-258.8183631382832
-497.6886995117563
-1426.9573947729498
-1427.4755192297778
-1681.20951551371
-1886.170776620293
-2027.7393628169646
-2125.565383283894
-2196.9110809280533
-2252.4408306359446
-2298.219384551649
-2337.665950487701
-2372.7491963646107
-2404.6421469752313
-2434.071614852633
-2461.505764451914
-2487.257181193261
-2511.5415332613484
-2534.5124184317788
-2556.2829987383157
-2576.939899205579
-2576.939899205579
-258.81836313828320:14:16 | ####---------------- |  99/500 | 0.47       
-258.8183631382832
-737.8466486868842
-1412.6374685195483
-1536.5522851486362
-1695.1412898595838
-1855.6131025837074
-1999.7395159616551
-2129.6089503368876
-2246.4849599728605
-2349.0310435526335
-2439.5062246625903
-2520.496142495983
-2593.5865772730244
-2659.488761121501
-2718.46794570783
-2770.626751799243
-2816.0003926536515
-2854.6028871191957
-2886.5025113550714
-2911.9254348098707
-2931.36031566389
-2931.3603156638

-2699.338544478049
-2752.529740667157
-2784.043807592299
-2819.1570931887154
-2861.291454656067
-2907.844587875595
-2955.8477389773207
-3003.219531565428
-3048.7397638613315
-3091.7913890549285
-3132.1345156563893
-3169.7498883912513
-3204.7389194192765
-3237.262610904897
-3267.505629959639
-3267.505629959639
-258.81836313828320:13:36 | #####--------------- | 115/500 | 0.47       
-258.8183631382832
-254.21870783140668
-438.47542263731117
-701.5634644009184
-953.1436842634005
-1157.9026329552282
-1317.2860014893017
-1440.1738781894464
-1534.633305995686
-1606.7932764113261
-1661.1978098980326
-1701.298334402106
-1729.8353862107951
-1749.0946001628288
-1761.0661236239507
-1767.5450248394657
-1770.2187456641136
-1770.8134545056605
-1771.4637462027465
-1775.7932755202032
-1791.3577843996795
-1791.3577843996795
-258.81836313828320:13:33 | #####--------------- | 116/500 | 0.47       
-258.8183631382832
-984.2687195249275
-2374.0491801751596
-3253.5311806056134
-3074.026621205841
-2649.83843

-3666.26364288176
-3784.927467791043
-3891.7980609256133
-3988.4722744251694
-4076.354084634855
-4155.865774216444
-4155.865774216444
-258.81836313828320:12:56 | #####--------------- | 131/500 | 0.48       
-258.8183631382832
-615.3618856884334
-1274.9218383403797
-1167.4755242294018
-1351.262414264422
-1571.009855787756
-1753.0216733749658
-1897.1782033897416
-2012.8184981076115
-2109.1189721976043
-2195.0150844200425
-2279.9153826319152
-2373.029196247343
-2478.8580734729308
-2595.854526896003
-2739.2544207481787
-2898.644014708766
-3050.367312160921
-3194.1726378609887
-3331.2294640138184
-3462.1615334611292
-3462.1615334611292
-258.81836313828320:12:53 | #####--------------- | 132/500 | 0.48       
-258.8183631382832
-304.6900410923687
-864.3193638908224
-895.3744118470695
-1160.1402459269716
-1396.726311037146
-1583.4904529577188
-1729.8913851218072
-1847.6810034008638
-1945.3238172426722
-2028.3293856888704
-2100.2720068324074
-2163.5385352167336
-2219.7858088085895
-2270.2099240

-258.81836313828320:12:21 | ######-------------- | 147/500 | 0.48       
-258.8183631382832
-578.5678144337132
-676.7983437711792
-973.0363027838539
-1343.1404391455412
-1693.5001209769514
-1986.4867580969374
-2218.740839163145
-2405.9467856611705
-2560.786789400875
-2691.44401095601
-2803.5311761896724
-2901.037480473222
-2986.880005137126
-3063.2415778744517
-3131.786097040608
-3193.802821328324
-3250.3057273627774
-3302.1032850353986
-3349.848228741302
-3394.0735959982103
-3394.0735959982103
-258.81836313828320:12:19 | ######-------------- | 148/500 | 0.48       
-258.8183631382832
-675.1296385636055
-1105.9167609378183
-1606.2186110524315
-1951.0005286383503
-2188.7518410836074
-2361.816613818554
-2495.7398919314696
-2606.359866022395
-2703.8914044067583
-2793.862134260553
-2878.280172060922
-2958.7855247509856
-3037.78571591787
-3117.726416659141
-3200.629124959074
-3288.0267219592324
-3380.6610907247805
-3476.7769809791325
-3571.5103068121393
-3661.920255922987
-3661.920255922987

-2137.635155417192
-2270.2023646716707
-2373.5757620985805
-2451.0852625236703
-2506.466089619335
-2543.9217119562027
-2567.707831915569
-2581.758320288456
-2589.439452126669
-2593.438251125904
-2595.763226855113
-2597.821550117098
-2600.5353127435483
-2604.4679405723728
-2609.944347897158
-2609.944347897158
-258.81836313828320:11:51 | #######------------- | 164/500 | 0.47       
-258.8183631382832
-775.4567046713938
-950.5857687442426
-1127.8608499228953
-1300.100688298634
-1505.1574938180265
-1754.5243577775643
-2007.0236654724877
-2196.591542073672
-2304.6890061150034
-2356.9609299260337
-2379.87635500109
-2380.5452204291023
-2377.3868649520928
-2383.9923127389206
-2403.233490046342
-2433.448822411372
-2471.833297393638
-2515.6877383470232
-2562.78856360346
-2611.4329407497635
-2611.4329407497635
-258.81836313828320:11:49 | #######------------- | 165/500 | 0.47       
-258.8183631382832
-485.52066249734025
-940.6835450941735
-1466.8821194311631
-1965.5777494168326
-2448.102315328314

-2166.312055740238
-2189.1667274061365
-2227.3400058237967
-2273.0320215846878
-2319.5565581982964
-2365.124480033247
-2409.744940423923
-2409.744940423923
-258.81836313828320:11:20 | #######------------- | 180/500 | 0.47       
-258.8183631382832
-308.90441655453014
-526.3807765071047
-731.188807323185
-988.3465360994956
-1336.132137119347
-1680.0820241942165
-1955.7076068609454
-2163.9503825311203
-2324.7590498699356
-2454.8284084833886
-2564.7436567196346
-2660.6730946656066
-2746.1318036379635
-2823.1894210361397
-2893.1431023938594
-2956.934748201672
-3015.3057005512037
-3068.88801665186
-3118.2430320734024
-3163.857328086168
-3163.857328086168
-258.81836313828320:11:17 | #######------------- | 181/500 | 0.47       
-258.8183631382832
-880.9393478321476
-1839.8639711153137
-2115.002914427654
-2039.7463186646146
-2105.7121861219457
-2220.703122109805
-2352.8515046565653
-2503.109020612189
-2664.118369895183
-2815.3266072604547
-2952.0178856517405
-3074.54425779114
-3183.88851099273

-258.81836313828320:10:46 | ########------------ | 196/500 | 0.47       
-258.8183631382832
-443.742310048469
-809.7983724709541
-1135.4799609186634
-1398.5517971153408
-1604.7736732844885
-1767.5322221449298
-1897.8901604114633
-2003.956449227982
-2091.545459296888
-2164.8507111668955
-2226.9417382868887
-2280.1071874040463
-2326.0889196053713
-2366.2409274242054
-2401.634313328429
-2433.1219033190964
-2461.37815182014
-2486.9342401477115
-2510.2154815122917
-2531.57781577283
-2531.57781577283
-258.81836313828320:10:44 | ########------------ | 197/500 | 0.47       
-258.8183631382832
-610.0878855684273
-639.9613824763088
-2827.1628827540426
-2487.702555864449
-2967.5311526124347
-3231.818394773944
-2745.9477456786735
-2264.2497148776647
-2093.9618386587827
-2078.8495472155337
-2119.0849213322426
-2174.431572858415
-2229.6009548235347
-2279.2294852029254
-2322.006524993378
-2358.259287546531
-2388.902176339071
-2414.972128923635
-2437.4338115227074
-2457.1108111459844
-2457.11081114598

-2010.8103922229811
-2113.8859803182177
-2189.7736139030676
-2245.804894040214
-2286.643054320506
-2315.0242936613577
-2332.3042773707884
-2338.729745405786
-2333.443727109402
-2314.4070282281323
-2278.5232049730157
-2223.8394974286966
-2161.3726289138476
-2154.854825244396
-2323.3453299014523
-2323.3453299014523
-258.81836313828320:10:12 | #########----------- | 213/500 | 0.47       
-258.8183631382832
-528.1277436926827
-1118.6736875182328
-1651.7893825149276
-1899.9536832944152
-2033.1939965691456
-2138.628603077874
-2231.9809759119144
-2314.2193429024164
-2385.714992725945
-2447.627018001937
-2501.5035737064363
-2548.92584423235
-2591.350940941055
-2630.066623646313
-2666.1925233812835
-2700.7044123998435
-2734.4465507425994
-2768.156985552332
-2802.4707839567473
-2837.948807479961
-2837.948807479961
-258.81836313828320:10:10 | #########----------- | 214/500 | 0.47       
-258.8183631382832
-358.7206446832273
-673.7193391402386
-1041.6819556753621
-1311.216172316767
-1505.960189483

-1828.1153751627976
-1882.3530614169495
-1934.367611693433
-1985.0798061422306
-2035.2294051738402
-2085.4033667426966
-2136.074828736667
-2136.074828736667
-258.81836313828320:09:40 | #########----------- | 229/500 | 0.47       
-258.8183631382832
-487.4219427925278
-532.9879579947719
-735.3463599053073
-952.9190265031771
-1136.6281218891565
-1288.6579341771994
-1415.4829815526318
-1522.6250985172667
-1613.6790769053468
-1690.2494884771459
-1753.0206303115156
-1802.8243377363442
-1840.6582718725012
-1867.5162284637263
-1884.4477702463369
-1892.6978006301513
-1893.8006484336672
-1889.5753725737443
-1882.0134483186323
-1873.0921634121944
-1873.0921634121944
-258.81836313828320:09:40 | #########----------- | 230/500 | 0.47       
-258.8183631382832
-465.58781545775116
-857.5311784192535
-1168.2469450926505
-1419.6582402842746
-1620.285657485326
-1778.2607139453694
-1902.4178389123258
-2000.2490546533145
-2077.384660807656
-2137.9040027978176
-2184.779057080722
-2220.2783432262427
-2246.3

-258.81836313828320:09:11 | ##########---------- | 245/500 | 0.46       
-258.8183631382832
-410.3830532696741
-587.4858444081267
-687.8141950039537
-946.2976448653687
-1398.0164550968439
-1930.0049768246352
-2415.1767093866856
-2768.619146145447
-2973.2007700383483
-3100.487297599875
-3193.938453291253
-3268.6101026349124
-3332.1315061624528
-3388.5353986606
-3439.9250305780665
-3487.3943434944986
-3531.5306952853616
-3572.681334770691
-3611.088607599955
-3646.9539191135746
-3646.9539191135746
-258.81836313828320:09:09 | ##########---------- | 246/500 | 0.46       
-258.8183631382832
-652.1792910765159
-1057.2177892435918
-1306.0945075592488
-1480.7042971749052
-1611.3547599690587
-1713.017861837268
-1794.222754304109
-1860.0140247513743
-1913.2683281567329
-1955.3870350558254
-1986.707869425141
-2006.7466406261294
-2014.3470525081248
-2007.921828886962
-1986.082073569096
-1948.9785738091098
-1900.372194970826
-1849.551578283451
-1811.5799174604479
-1806.3570693008949
-1806.3570693008

-1966.952698732282
-2071.8186338333326
-2154.6170253413957
-2221.44709415097
-2276.472577814583
-2322.5754732745936
-2361.7955686767173
-2395.6110641226524
-2425.1161452382594
-2451.135551988095
-2474.30085078104
-2495.102621639507
-2513.9268704942015
-2531.081268943105
-2546.8135789837816
-2561.3245844704297
-2561.3245844704297
-258.81836313828320:08:36 | ##########---------- | 262/500 | 0.46       
-258.8183631382832
-415.8754631781038
-1035.2599801645476
-1470.384092298754
-1768.2147667484544
-1963.527560070367
-2111.2342889609936
-2235.6786063848476
-2345.8879266320578
-2444.9993686336
-2534.3232723607375
-2614.816721377236
-2687.4139440092104
-2753.022919590786
-2812.4661656357175
-2866.4598033270827
-2915.629243667284
-2960.52969664677
-3001.6562431394104
-3039.446838820025
-3074.2966839624532
-3074.2966839624532
-258.81836313828320:08:34 | ###########--------- | 263/500 | 0.46       
-258.8183631382832
-752.4535140493564
-647.1159731437467
-669.8809557864631
-903.3383681519936
-

-3602.912925279665
-3664.2928450617937
-3722.7497381833327
-3781.376983113441
-3840.089144745908
-3897.872137484801
-3897.872137484801
-258.81836313828320:08:03 | ###########--------- | 278/500 | 0.46       
-258.8183631382832
-365.7230202090706
-987.3698304632333
-1340.4276394529115
-1545.1179566659753
-1694.8825512105932
-1816.2176247836298
-1917.6950039820972
-2003.6146277213995
-2077.003596596496
-2140.2446003213354
-2195.254055503501
-2243.6115689935314
-2286.7479818560387
-2326.0620828093997
-2362.442734987142
-2396.15864838229
-2427.36348540962
-2456.2546475827858
-2483.0480895088585
-2507.954137135797
-2507.954137135797
-258.81836313828320:08:00 | ###########--------- | 279/500 | 0.46       
-258.8183631382832
-785.2729771272358
-779.8829549330678
-1095.9220801609172
-1454.498275145045
-1723.8146246411284
-1891.225454397741
-1986.190512455549
-2049.175117557302
-2111.0876331583786
-2192.8267537098636
-2334.9264166748876
-2527.8051954219745
-2714.820285344847
-2888.64529532486
-

-258.81836313828320:07:31 | ############-------- | 294/500 | 0.46       
-258.8183631382832
-773.3660154271968
-1437.9044833912235
-1800.776693369645
-2062.423639831115
-2243.5974514218697
-2373.6675367258713
-2471.9247025099558
-2548.469792350319
-2608.1720027802166
-2652.8423885294123
-2682.2512501043802
-2694.5816609251046
-2686.4779389745677
-2653.2964762877255
-2590.9787995178744
-2501.551548699016
-2401.997803644836
-2327.667283276643
-2320.0903593258063
-2408.6166090877464
-2408.6166090877464
-258.81836313828320:07:28 | ############-------- | 295/500 | 0.46       
-258.8183631382832
-230.40286519692856
-390.92198465915305
-688.7722651769516
-957.4811146738352
-1188.1615658114051
-1438.5556940198344
-1767.2022500894277
-2076.729342972514
-2373.1906637383977
-2505.385886622207
-2549.340662262771
-2584.902723271736
-2618.765138686327
-2651.0979061447274
-2681.562010681042
-2709.8705683752096
-2735.878780470508
-2759.5691231635924
-2781.0147238584377
-2800.3436133737328
-2800.343613

-1477.8313604815853
-1591.788468977159
-1689.3215367296411
-1762.8368198969129
-1818.0606019558927
-1882.709864855899
-2081.913575233916
-2484.515037877439
-1965.0847710188602
-2172.4371737771276
-2420.0341117349635
-2591.9433744880525
-2718.612324649508
-2818.2098906674382
-2900.000208447368
-2969.045198236926
-2969.045198236926
-258.81836313828320:06:52 | ############-------- | 311/500 | 0.46       
-258.8183631382832
-837.267905521406
-1164.6235881658572
-1504.213471270546
-1789.2179010721875
-2004.7478304003262
-2172.416329641022
-2307.1042925368
-2418.181297336561
-2511.791239683606
-2592.128628081429
-2662.14937722918
-2723.9911369664237
-2779.23444347409
-2829.071177845736
-2874.416060219045
-2915.9820559787217
-2954.332686402553
-2989.91912148644
-3023.1069358038762
-3054.195614094352
-3054.195614094352
-258.81836313828320:06:50 | ############-------- | 312/500 | 0.46       
-258.8183631382832
-449.17897397029844
-760.3689462783082
-1058.7097851718652
-1287.0094251861656
-1457.

-3119.6014594518497
-3197.198210175905
-3246.059042813041
-3286.6125357094543
-3323.855543118093
-3359.2613261905735
-3359.2613261905735
-258.81836313828320:06:17 | #############------- | 327/500 | 0.46       
-258.8183631382832
-681.6475043582475
-946.0697854476318
-813.1114352963702
-1106.709917723058
-1419.853821982253
-1791.9812637899286
-2162.0723429728605
-2346.716169221617
-2307.0673441622575
-2168.815772381562
-2039.9379634977886
-1952.6718694585797
-1904.8687013790839
-1887.8259815002687
-1893.7467798268663
-1914.992463902698
-1945.790792998544
-1984.1299815428947
-2028.807011758629
-2078.568186804308
-2078.568186804308
-258.81836313828320:06:14 | #############------- | 328/500 | 0.46       
-258.8183631382832
-634.5164677358994
-1204.0529966049016
-1726.3948026989692
-2051.658248895364
-2252.4255248893523
-2392.2563242072238
-2498.655452389792
-2583.2533978014603
-2652.0348352419696
-2708.8089157167124
-2756.4343017929823
-2797.3439031362655
-2833.446543051841
-2865.744252758

-258.81836313828320:05:40 | ##############------ | 343/500 | 0.46       
-258.8183631382832
-704.7461664735304
-1232.2845980885338
-1508.476721266866
-1677.902915617477
-1801.6081574383002
-1898.343210376155
-1976.652912010341
-2041.4756701391875
-2095.9997078775395
-2142.427361741632
-2182.354381107074
-2216.9789383956513
-2247.226348049877
-2273.828392824701
-2297.376330911855
-2318.357698533311
-2337.1820879477004
-2354.198220143742
-2369.7038116839144
-2383.949997464988
-2383.949997464988
-258.81836313828320:05:38 | ##############------ | 344/500 | 0.46       
-258.8183631382832
-400.3242208592585
-607.2939815691843
-818.9947057541837
-959.1657856795149
-1062.2598343313728
-1147.4073576523551
-1221.3942552854219
-1286.5659258930718
-1344.409267954074
-1396.2377653249443
-1443.207071040555
-1486.2823437852285
-1526.2340010172816
-1563.6644238439794
-1599.0406811498492
-1632.7238380618392
-1664.9920712591143
-1696.0635874045815
-1726.1055411720893
-1755.2487105547302
-1755.2487105547

-1501.1310996947477
-1597.4721938979844
-1674.0123805380385
-1736.9286446779283
-1790.6401472477787
-1838.392324841055
-1882.7016916838315
-1925.6918024413058
-1969.2993558582734
-2015.061473813591
-2062.9673916203515
-2110.783772782062
-2156.5856795197114
-2200.320235656167
-2242.129128709721
-2242.129128709721
-258.81836313828320:05:03 | ##############------ | 360/500 | 0.46       
-258.8183631382832
-400.2816670823664
-689.5061047796189
-802.0803542783219
-1052.2590674295066
-1278.9050736699048
-1462.1282415441403
-1613.7567283208468
-1749.5182051394684
-1886.5268048057176
-2048.5220433606296
-2238.060198993616
-2397.0453199166013
-2522.668642972021
-2649.127981814118
-2882.477820453338
-3343.583644389712
-3833.8420194581113
-4125.920073097737
-4205.382348450504
-4218.504172878926
-4218.504172878926
-258.81836313828320:05:01 | ##############------ | 361/500 | 0.46       
-258.8183631382832
-870.9734740849092
-1579.02680958847
-1936.3214936750642
-2151.4182726753866
-2366.81571139360

-2708.551003936427
-2854.9057408699146
-2977.387617392711
-3081.4702522404596
-3171.3519446640194
-3250.195762472492
-3250.195762472492
-258.81836313828320:04:28 | ###############----- | 376/500 | 0.46       
-258.8183631382832
-262.3396583808436
-506.49811019378285
-664.3099701449213
-750.8292219989306
-789.6613048044352
-782.9405211960384
-714.5714295070544
-549.3585699164977
-384.67577060468767
-1735.196006834496
-2148.578501565382
-1900.0381044453031
-1745.372755841652
-1667.221935589406
-1639.6990531942533
-1637.178712159265
-1649.9812962922051
-1673.037064869001
-1703.5694971441953
-1741.4945995445846
-1741.4945995445846
-258.81836313828320:04:26 | ###############----- | 377/500 | 0.46       
-258.8183631382832
-722.6763498257509
-1306.2521585466789
-1569.9093432092175
-1789.5896575622758
-1970.6132216511253
-2115.5005295615138
-2230.8186158227472
-2322.7880421419713
-2396.291194459789
-2454.9890208957627
-2501.6185706022707
-2538.2490963211885
-2566.469835617939
-2587.5236812681

-258.81836313828320:03:53 | ################---- | 392/500 | 0.46       
-258.8183631382832
-750.1883511497797
-999.329335264664
-806.344842472012
-712.9269633226672
-695.1136895319066
-714.880840518435
-759.8564969028635
-824.508019449248
-904.8564051526212
-997.3806314375652
-1098.6810643585886
-1205.4299875113475
-1314.5613640422418
-1423.5340435312671
-1530.587803399117
-1635.1608127518093
-1739.0497778286658
-1847.7955454132561
-1957.414435585386
-2043.0018155964685
-2043.0018155964685
-258.81836313828320:03:51 | ################---- | 393/500 | 0.46       
-258.8183631382832
-371.3787966713876
-1030.2095061127543
-2106.1715264111735
-3962.8387925785555
-2745.3997824072376
-989.6634972381491
-1193.373421694795
-1405.858664886794
-1567.3686549378867
-1693.8855085001433
-1799.5381829905714
-1893.9151742515548
-1982.8927611869142
-2068.909078950233
-2151.420152147057
-2228.916301741475
-2301.0731845710675
-2369.131052249138
-2434.774489931464
-2499.8567248721165
-2499.8567248721165
-

-1838.808983093742
-2006.5339211752912
-2284.9951772600207
-2577.5737254716446
-2894.018058800612
-3229.4355348024487
-3386.2764138843127
-3504.90382760824
-3656.9745921286876
-3860.3667384953565
-4024.155453373196
-4094.208573120371
-4124.166002194435
-4141.823232204977
-4156.884364571235
-4156.884364571235
-258.81836313828320:03:17 | ################---- | 409/500 | 0.46       
-258.8183631382832
-519.0104413168772
-1180.0982306131718
-1657.935504731577
-1973.7835889214207
-2227.6929006175224
-2475.3740805973166
-2727.6958404537454
-2951.0729833436853
-3141.41029948448
-3304.7069461774845
-3446.4103613173475
-3571.148602701748
-3683.213466233771
-3791.482566484172
-3931.8078084498175
-4057.50849047181
-4157.901110364031
-4248.648446059887
-4332.23180778071
-4409.220259374658
-4409.220259374658
-258.81836313828320:03:15 | ################---- | 410/500 | 0.46       
-258.8183631382832
-217.22864532232978
-430.9725352280516
-647.3966161645716
-818.7447088557788
-954.7113400837663
-1068

-3744.085865955065
-3814.3120994563706
-3847.612364302227
-3854.680103365378
-3853.2145046694686
-3849.006390123841
-3849.006390123841
-258.81836313828320:02:42 | #################--- | 425/500 | 0.46       
-258.8183631382832
-389.821965128475
-1057.5827490694373
-1758.559838142853
-2106.1257313074657
-2141.6717479092354
-2109.4060673840995
-2100.386066682148
-2114.821476494244
-2142.5662979704002
-2175.4806015666027
-2209.1593872414373
-2242.1179892768378
-2274.0733159622796
-2305.200550499367
-2335.7872014414193
-2366.158092140394
-2396.5756012729053
-2427.2750315935978
-2458.430604996978
-2490.172115604003
-2490.172115604003
-258.81836313828320:02:40 | #################--- | 426/500 | 0.46       
-258.8183631382832
-467.42223672045145
-705.0652243506582
-976.4655323043924
-1232.4161253045745
-1481.99577510472
-1808.3721512479378
-2137.6392413675253
-2388.155998331081
-2570.1476469630743
-2699.3085614733013
-2751.0894987458564
-2775.9829900303266
-2804.6567502025227
-2840.0051796295

-258.81836313828320:02:08 | ##################-- | 441/500 | 0.46       
-258.8183631382832
-412.2254921824292
-619.1496076407795
-861.2082056527852
-1087.2459590726949
-1278.3655142187285
-1436.333237847823
-1567.1866258155183
-1676.692855359494
-1769.477944870728
-1849.0783639124536
-1918.1687274597498
-1978.7797302052666
-2032.4660689828397
-2080.433236371239
-2123.6257924075435
-2162.7938540532577
-2198.5386935661904
-2231.347448544198
-2261.6199797504505
-2289.6858682679162
-2289.6858682679162
-258.81836313828320:02:05 | ##################-- | 442/500 | 0.46       
-258.8183631382832
-446.8334229797895
-1051.4680527884673
-1583.814405384352
-1892.4277514949902
-2052.415731302269
-2146.004445423071
-2213.654641800592
-2270.588262063342
-2323.17238664851
-2376.411729618322
-2436.223932303137
-2507.4820502231764
-2594.703882129189
-2699.1695433832365
-2813.178262056643
-2925.545130654894
-3032.212950506706
-3133.1351445513137
-3228.782472917259
-3320.307217492282
-3320.307217492282
-

-2606.669879353877
-2786.976470138864
-2914.603226317164
-3007.6499192669617
-3076.8360431723067
-3128.734918313956
-3167.4850289321193
-3195.736714011614
-3215.1865986306643
-3226.882663365521
-3231.3939435045413
-3228.892362931002
-3219.1685797593327
-3201.587619561253
-3174.9790515519444
-3137.4546906436226
-3137.4546906436226
-258.81836313828320:01:31 | ##################-- | 458/500 | 0.46       
-258.8183631382832
-806.3310564597335
-1083.365347325352
-1381.7830469287417
-1608.4223929334644
-1781.8199865421243
-1920.3988185610785
-2032.9334437209104
-2127.280510853636
-2209.531841744472
-2283.785992255289
-2352.97403706563
-2419.8688519454054
-2488.3768952591754
-2561.94897819327
-2634.9803390345674
-2703.2745901315084
-2768.767533808425
-2832.9096670049776
-2896.291772676077
-2959.1798098648187
-2959.1798098648187
-258.81836313828320:01:28 | ##################-- | 459/500 | 0.46       
-258.8183631382832
-349.70741215556563
-432.21519074443216
-1302.860889760031
-1579.0477898236

-2120.8501085037033
-2203.8219390832687
-2284.480094711314
-2362.046255777316
-2438.9510179359877
-2525.8909212732287
-2638.5755813900323
-2638.5755813900323
-258.81836313828320:00:56 | ###################- | 474/500 | 0.46       
-258.8183631382832
-414.0445794848773
-1063.7261044249285
-1301.8534645838868
-1611.923584176243
-1888.7482684164409
-2083.4566152564134
-2217.397899759273
-2313.111871429848
-2385.3909989702724
-2442.8621104778776
-2490.447982995833
-2531.0098355223363
-2566.2653296398553
-2597.282882526934
-2624.7545555672973
-2649.1504492481236
-2670.8078012338974
-2689.9839907102682
-2706.8883230184642
-2721.7008177927505
-2721.7008177927505
-258.81836313828320:00:54 | ###################- | 475/500 | 0.46       
-258.8183631382832
-254.98751708278485
-590.4071410677236
-1004.5397860984656
-1228.3592005918858
-1342.0918286446683
-1425.9751105423784
-1502.3700469863045
-1574.9048530951336
-1642.793679651103
-1704.8552426963583
-1760.466429388102
-1809.6054212026363
-1852.6

-258.81836313828320:00:21 | #################### | 490/500 | 0.46       
-258.8183631382832
-635.2423192184299
-707.0917982492713
-833.2468003870474
-1065.8681278616627
-1255.8425589174594
-1405.8937980838882
-1525.8458195423987
-1623.6337286204973
-1705.0945914294261
-1774.5211100320626
-1835.1438028206594
-1889.4743647479545
-1939.5430007821542
-1987.0673890964745
-2033.584452999287
-2080.5724133297645
-2129.586179847229
-2182.388212166864
-2240.846896878711
-2305.892010194058
-2305.892010194058
-258.81836313828320:00:19 | #################### | 491/500 | 0.46       
-258.8183631382832
-409.3755096889922
-789.568367825078
-1202.60945288711
-1609.5871455208799
-1892.2253333784024
-2074.690617138085
-2207.932427324469
-2322.0094001557645
-2432.9899595256834
-2552.0402509136225
-2689.432958138655
-2848.012086381635
-3013.084653912765
-3167.706319685869
-3308.849944537894
-3436.880955068891
-3551.663986036273
-3653.4475546986787
-3743.26993244321
-3822.671610571763
-3822.671610571763
-25

[tensor([-2.3550e+00, -1.9355e+00, -3.6661e+00,  2.3216e-02,  2.0492e-02,
         -4.7562e-02,  2.8246e+00,  1.7371e+00, -4.1685e-03, -1.1484e-03,
          4.0934e-04, -1.8839e+00, -2.3366e+00,  2.7113e-03,  4.9928e-03,
          3.6793e-02, -3.1709e+00, -2.5666e+00,  4.9370e-03,  3.9433e-03,
          1.1853e-01, -5.5081e+00, -1.8446e+00,  6.2483e-03,  4.8495e-03,
         -1.5682e-01,  5.0755e+00,  3.2241e+00, -4.9285e-03, -1.3481e-02,
         -4.0563e-02,  3.8434e+00,  9.2220e+00, -8.4619e-02, -2.3256e-01,
          2.1243e-01,  1.6690e+01, -3.0715e+00, -1.8416e-01,  1.7021e-02],
        dtype=torch.float64),
 tensor([-2.3550e+00, -1.9355e+00, -3.6661e+00,  2.3216e-02,  2.0492e-02,
         -4.7562e-02,  2.8246e+00,  1.7371e+00, -4.1685e-03, -1.1484e-03,
          4.0934e-04, -1.8839e+00, -2.3366e+00,  2.7113e-03,  4.9928e-03,
          3.6793e-02, -3.1709e+00, -2.5666e+00,  4.9370e-03,  3.9433e-03,
          1.1853e-01, -5.5081e+00, -1.8446e+00,  6.2483e-03,  4.8495e-03,
       

In [None]:
hamiltorch.sample(log_prob_func, params_init, step_size=step_size)


In [None]:
import hamiltorch
L = 5
step_size = .3
N = 2048

hamiltorch.set_random_seed(123)
params_init = params_flat
burn=500
N_nuts = burn + N
params_hmc_nuts = hamiltorch.sample(log_prob_func=log_prob, params_init=params_init,
                                                  num_samples=N_nuts,step_size=step_size,num_steps_per_sample=L,
                                                  sampler=hamiltorch.Sampler.HMC_NUTS, burn=burn,
                                                  desired_accept_rate=0.8)


In [None]:
def log_prob(x):
    x = x[None]
    return model.log_prob(x).sum()

In [None]:
v, _ = input.median(dim=0)

In [None]:
log_prob(v)

In [None]:
L = 5
step_size = .3
N = 2048

hamiltorch.set_random_seed(123)
params_init = v
burn=500
N_nuts = burn + N
params_hmc_nuts = hamiltorch.sample(log_prob_func=log_prob, params_init=params_init,
                                                  num_samples=N_nuts,step_size=step_size,num_steps_per_sample=L,
                                                  sampler=hamiltorch.Sampler.HMC_NUTS, burn=burn,
                                                  desired_accept_rate=0.8)


In [None]:
import seaborn as sns

In [None]:
samples = torch.stack(params_hmc_nuts)