In [1]:
import numpy as np
import time
import bkw2d_implicit_score_matching as ism
import bkw2d_sampling as sp
import bkw2d_particle as par

In [4]:
# --- Simulation configuration -------------------------------------------------
T = 5  # terminal time
t0 = 0  # initial time
dt = 0.01  # time step
# round() instead of int(): int() truncates, so a quotient that lands just below
# an integer (e.g. int(0.3 / 0.1) == 2) would silently drop the final step.
Nt = int(round((T - t0) / dt))  # number of time steps

# Training hyper-parameters (previously inline literals in the loop).
L2_TOL = 5e-4  # tolerance passed to model.train_L2 on the very first step
ISM_MAX_ITER = 25  # max_iter passed to model.train_ism on every later step
SAMPLING_SEED = 23  # seed for the initial rejection sampling

# Per-step diagnostics. fisher_divergence / entropy_decay_rate / energy get one
# entry per completed step; particle_traj also holds the initial particles, so
# it is one longer than the others once a step completes.
energy = []
entropy_decay_rate = []
fisher_divergence = []
particle_traj = []

# --- Initial sampling at t0 -----------------------------------------------------
t = t0  # current time; updated from the step index inside the loop
nex = 100**2  # number of particles
sampling_model = sp.Sampling(nex=nex, time=t, seed=SAMPLING_SEED)
v = np.array(sampling_model.rejection_sampling())
particle_traj.append(v)
# f_net = par.f_exact(v, t)

# --- Time evolution -------------------------------------------------------------
for nt in range(Nt):
    print('current time: ', t)
    model = ism.ScoreMatching(samples=v, time=t)

    # The first step uses train_L2 with a tolerance; later steps run only a few
    # train_ism iterations. NOTE(review): the short budget presumably relies on
    # weights persisted by model.save_model() being reloaded -- TODO confirm in
    # bkw2d_implicit_score_matching.
    start = time.perf_counter()
    if nt == 0:
        fisher_divergence.append(model.train_L2(tol=L2_TOL))
    else:
        fisher_divergence.append(model.train_ism(max_iter=ISM_MAX_ITER))
    end = time.perf_counter()
    print(f'training time: {end - start:.3f} s')
    model.save_model()

    # compute score at t_{n}
    score = model.compute_score(v)
    # compute velocity field
    velocity_field = par.compute_velocity_field(v, score)
    # compute entropy, energy
    entropy_decay_rate.append(par.compute_entropy_decay_rate(velocity_field, score))
    energy.append(par.compute_energy(v))
    # compute density at t_{n+1} (currently disabled, together with f_net above)
    # score_jac = model.compute_score_jac(v)
    # logdet_field = par.compute_logdet_field(v, score, score_jac)
    # f_net = par.compute_f(f_net, logdet_field, dt)
    # compute v at t_{n+1}
    v = par.compute_v(v, velocity_field, dt)
    particle_traj.append(v)
    # Derive time from the step index rather than `t += dt`, so floating-point
    # error does not accumulate across hundreds of steps.
    t = t0 + (nt + 1) * dt

current time:  0
Iter 1000, L2 Loss: 3.70767e-02, ISM Loss: -3.82933e+00
Iter 2000, L2 Loss: 7.67073e-03, ISM Loss: -3.96921e+00
Iter 3000, L2 Loss: 3.15272e-03, ISM Loss: -4.01268e+00
Iter 4000, L2 Loss: 1.73147e-03, ISM Loss: -4.05760e+00
Iter 5000, L2 Loss: 1.12379e-03, ISM Loss: -4.09866e+00
Iter 6000, L2 Loss: 7.89858e-04, ISM Loss: -4.10384e+00
Iter 7000, L2 Loss: 5.72771e-04, ISM Loss: -4.09454e+00
Iter 7420, Loss: 4.99596e-04, ISM Loss: -4.09454e+00
33.346088699999996
current time:  0.01
Iter 5, ISM Loss: -4.06090e+00, L2 Loss: 1.66420e-03
Iter 10, ISM Loss: -4.06640e+00, L2 Loss: 1.84519e-03
Iter 15, ISM Loss: -4.07050e+00, L2 Loss: 2.30012e-03
Iter 20, ISM Loss: -4.07429e+00, L2 Loss: 2.70620e-03
Iter 25, ISM Loss: -4.07782e+00, L2 Loss: 2.78732e-03
0.4244922000000031
current time:  0.02
Iter 5, ISM Loss: -4.04264e+00, L2 Loss: 3.71154e-03
Iter 10, ISM Loss: -4.04662e+00, L2 Loss: 3.73052e-03
Iter 15, ISM Loss: -4.05014e+00, L2 Loss: 3.88195e-03
Iter 20, ISM Loss: -4.05344e+0

KeyboardInterrupt: 