In [1]:
import argparse
import logging
import math
import os
import random
from collections import namedtuple
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from torch import distributions, nn, optim

import torchsde

In [2]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Run configuration, read as module-level state by the model and by main().
args = dict(
    likelihood="laplace",     # observation likelihood p(x|z): "laplace" or "normal"
    adjoint=True,             # use the adjoint sensitivity method to backpropagate through the SDE solve
    debug=True,               # when True, INFO-level log messages are shown
    data="segmented_cosine",  # dataset to build: "segmented_cosine" or "irregular_sine"
    kl_anneal_iters=100,      # number of iterations of the KL-weight annealing schedule
    train_iters=1000,         # number of training iterations
    batch_size=100,           # number of sample paths drawn per training step
    adaptive=False,           # whether the SDE solver uses adaptive step sizes
    method="euler",           # SDE solver used during training
    dt=1e-2,                  # step size passed to the SDE solver
    rtol=1e-3,                # relative tolerance passed to the SDE solver
    atol=1e-3,                # absolute tolerance passed to the SDE solver
    scale=0.05,               # scale parameter of the likelihood distribution
    dpi=500,                  # resolution of the saved figures
    pause_iters=50,           # plot the current posterior every `pause_iters` training steps
)

In [4]:
# Container for one dataset.
# Naming convention: trailing underscore -> numpy array; no underscore -> torch tensor.
#   ts / ts_         : original observation times; can be segmented or irregular
#   ts_ext / ts_ext_ : observation times extended with points outside the data range, used to
#                      generate the latent path outside the data, penalize the out-of-data
#                      region and spread uncertainty
#   ts_vis / ts_vis_ : regular time grid used for plotting
#   ys / ys_         : observed values, one row per entry of ts
Data = namedtuple(
    'Data',
    [
        'ts_', 'ts_ext_', 'ts_vis_',
        'ts', 'ts_ext', 'ts_vis',
        'ys', 'ys_',
    ],
)

In [5]:
class LinearScheduler(object):
    """Linear ramp from ``maxval / iters`` up to ``maxval``.

    The value starts at one increment (``maxval / iters``) and every call to
    ``step`` adds another increment, saturating at ``maxval``. ``iters`` is
    clamped to at least 1, so a non-positive ``iters`` starts at ``maxval``.
    """

    def __init__(self, iters, maxval=1.0):
        self._iters = max(1, iters)
        self._maxval = maxval
        # Amount added per step; the schedule starts one increment above zero.
        self._increment = self._maxval / self._iters
        self._val = self._increment

    def step(self):
        """Advance the schedule by one increment without exceeding ``maxval``."""
        raised = self._val + self._increment
        self._val = raised if raised < self._maxval else self._maxval

    @property
    def val(self):
        """Current value of the schedule."""
        return self._val

In [6]:
class EMAMetric(object):
    """Exponential moving average of a stream of values.

    The running value starts at 0 and is updated as
    ``val = gamma * val + (1 - gamma) * x`` on every ``step``.
    """

    def __init__(self, gamma: Optional[float] = .99):
        super().__init__()
        self._gamma = gamma
        self._val = 0.

    def step(self, x: Union[torch.Tensor, np.ndarray]):
        """Fold ``x`` into the average and return the updated average.

        Tensors are detached, moved to the CPU and converted to numpy first;
        any other input is used as given.
        """
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        decay = self._gamma
        self._val = decay * self._val + (1 - decay) * x
        return self._val

    @property
    def val(self):
        """Current value of the moving average."""
        return self._val

In [7]:
def manual_seed(seed: int):
    """Seed Python's ``random``, numpy and torch with the same value.

    Calling this with the same ``seed`` before a run makes the random draws
    from these three generators repeat.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

In [8]:
def _stable_division(a, b, epsilon=1e-7):
    '''
        change all elements x in b s.t -epsilon < x < epsilon to epsilon, to make sure the division is stable, won't cause a really large number
    '''
    
    b = torch.where(b.abs().detach() > epsilon, b, torch.full_like(b, fill_value=epsilon) * b.sign())
    return a / b

In [9]:
class LatentSDE(torchsde.SDEIto):
    """Scalar latent SDE with a fixed prior process and a learned approximate posterior.

    The prior has drift ``h(t, y) = theta * (mu - y)`` and the approximate posterior
    has a neural-network drift ``f(t, y)``; both share the constant diffusion
    ``g(t, y) = sigma``. ``forward`` integrates an augmented two-column state whose
    second column accumulates ``0.5 * ((f - h) / g) ** 2`` along each sample path,
    and uses it as the path term of the KL estimate.

    Note: the solver calls read the module-level globals ``sdeint_fn`` and ``args``,
    which are not defined in this class and must exist before these methods run.
    """

    def __init__(self, theta=1.0, mu=0.0, sigma=0.5):
        """Create the prior buffers, the posterior drift network and q(y0).

        Args:
            theta: coefficient of the prior drift ``theta * (mu - y)``.
            mu: level the prior drift pulls towards; also the mean of p(y0) and the
                initial mean of q(y0).
            sigma: constant diffusion shared by the prior and the posterior.
        """
        super(LatentSDE, self).__init__(noise_type="diagonal")
        logvar = math.log(sigma ** 2 / (2. * theta))  # log of sigma^2 / (2 * theta); used below as the log-variance of p(y0) and the initial log-variance of q(y0)

        # Prior drift.
        self.register_buffer("theta", torch.tensor([[theta]]))  # prior parameters are registered as buffers, not nn.Parameters, so they are not trained
        self.register_buffer("mu", torch.tensor([[mu]]))
        self.register_buffer("sigma", torch.tensor([[sigma]]))

        # p(y0).
        self.register_buffer("py0_mean", torch.tensor([[mu]]))  # mean and log-variance of the prior over the initial state
        self.register_buffer("py0_logvar", torch.tensor([[logvar]]))

        # Approximate posterior drift f(t, y): takes 2 positional encodings of t (sin, cos) and the state.
        self.net = nn.Sequential(
            nn.Linear(3, 200),
            nn.Tanh(),
            nn.Linear(200, 200),
            nn.Tanh(),
            nn.Linear(200, 1)
        )
        # Initialization trick from Glow: zero the last layer so the network outputs 0 at the start.
        self.net[-1].weight.data.fill_(0.)
        self.net[-1].bias.data.fill_(0.)

        # q(y0): learnable mean and log-variance of the approximate posterior over the initial state,
        # initialised to the same values as p(y0).
        self.qy0_mean = nn.Parameter(torch.tensor([[mu]]), requires_grad=True)
        self.qy0_logvar = nn.Parameter(torch.tensor([[logvar]]), requires_grad=True)

    def f(self, t, y):  # Approximate posterior drift: f(t, y) = net([sin(t), cos(t), y]).
        """Return the posterior drift at time ``t`` and state ``y``.

        A 0-dim ``t`` is expanded to the shape of ``y`` so it can be concatenated with it.
        """
        if t.dim() == 0:
            t = torch.full_like(y, fill_value=t)
        # Positional encoding in transformers for time-inhomogeneous posterior.
        return self.net(torch.cat((torch.sin(t), torch.cos(t), y), dim=-1))

    def g(self, t, y):  # Shared diffusion. g(t,y) = sigma
        """Return the constant diffusion ``sigma`` repeated once per row of ``y``."""
        return self.sigma.repeat(y.size(0), 1)

    def h(self, t, y):  # Prior drift. h(t,y) = theta * (mu-y)
        """Return the prior drift ``theta * (mu - y)``."""
        return self.theta * (self.mu - y)

    def f_aug(self, t, y):  # Drift for augmented dynamics with logqp term.
        '''
             Drift of the augmented two-column state.

             Column 0 of y is the latent state that follows the posterior SDE.
             Column 1 starts at 0 (see forward) and has drift 0.5 * u^2 with
             u = (f - h) / g, so its value at the final time is used as the
             path term of the KL estimate.
        '''
        y = y[:, 0:1]  # keep only the latent state column, shape [batch_size, 1]
        f, g, h = self.f(t, y), self.g(t, y), self.h(t, y)  # posterior drift, diffusion, prior drift
        u = _stable_division(f - h, g)  # u(z,t) = (f-h)/g
        f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True)  # (u^2)/2, the drift of the second column
        return torch.cat([f, f_logqp], dim=1)  # [batch_size, 2]

    def g_aug(self, t, y):  # Diffusion for augmented dynamics with logqp term.
        """Diffusion of the augmented state: ``g`` for column 0 and zero for column 1."""
        y = y[:, 0:1]
        g = self.g(t, y)
        g_logqp = torch.zeros_like(y)  # the second column receives no noise
        return torch.cat([g, g_logqp], dim=1)  # [batch_size, 2]

    def forward(self, ts, batch_size, eps=None):
        """Sample posterior paths at times ``ts`` and estimate the KL term.

        Args:
            ts: times at which the solution is returned.
            batch_size: number of sample paths.
            eps: optional standard-normal noise of shape [batch_size, 1] used to draw y0;
                drawn here when None.

        Returns:
            ``(ys, logqp)``: ``ys`` is column 0 of the augmented solution (last axis kept,
            size 1); ``logqp`` is the batch mean of KL(q(y0) || p(y0)) plus the final value
            of column 1.
        """
        # recognition process
        eps = torch.randn(batch_size, 1).to(self.qy0_std) if eps is None else eps
        y0 = self.qy0_mean + eps * self.qy0_std  # reparameterised draw of y0 from q(y0)
        qy0 = distributions.Normal(loc=self.qy0_mean, scale=self.qy0_std)  # approx posterior over y0
        py0 = distributions.Normal(loc=self.py0_mean, scale=self.py0_std)  # prior over y0
        logqp0 = distributions.kl_divergence(qy0, py0).sum(dim=1)  # KL(t=0): KL between q(y0) and p(y0)

        # Augment the state with a zero column that accumulates the path term.
        aug_y0 = torch.cat([y0, torch.zeros(batch_size, 1).to(y0)], dim=1)
        # `sdeint_fn` and `args` are module-level globals; `names` selects f_aug / g_aug
        # as the drift and diffusion for this solve.
        aug_ys = sdeint_fn(
            sde=self,
            y0=aug_y0,
            ts=ts,
            method=args["method"],
            dt=args["dt"],
            adaptive=args["adaptive"],
            rtol=args["rtol"],
            atol=args["atol"],
            names={'drift': 'f_aug', 'diffusion': 'g_aug'}
        )
        ys, logqp_path = aug_ys[:, :, 0:1], aug_ys[-1, :, 1]
        # column 0 of the last dimension holds the sampled paths at every time in ts
        # column 1 of the last dimension, taken at the final time, holds the accumulated path term
        logqp = (logqp0 + logqp_path).mean(dim=0)  # KL(t=0) + KL(path), averaged over the batch
        return ys, logqp

    def sample_p(self, ts, batch_size, eps=None, bm=None):
        '''
            Draw y0 from the prior p(y0) and integrate the prior SDE (drift h) at times ts.

            eps: optional standard-normal noise used to draw y0; drawn here when None.
            bm: optional Brownian motion object forwarded to the solver.
        '''
        eps = torch.randn(batch_size, 1).to(self.py0_mean) if eps is None else eps
        y0 = self.py0_mean + eps * self.py0_std  # [batch_size, 1], e.g. [1024, 1] when called from main()
        yt = sdeint_fn(self, y0, ts, bm=bm, method='srk', dt=args["dt"], names={'drift': 'h'})  # [len(ts), batch_size, 1], e.g. [300, 1024, 1]
        return yt

    def sample_q(self, ts, batch_size, eps=None, bm=None):
        '''
            Draw y0 from the approximate posterior q(y0) and integrate the posterior SDE at times ts.

            No `names` mapping is passed, so the solver uses its default method names
            (presumably f and g -- confirm against the torchsde documentation).
        '''
        eps = torch.randn(batch_size, 1).to(self.qy0_mean) if eps is None else eps
        y0 = self.qy0_mean + eps * self.qy0_std  # [batch_size, 1], e.g. [1024, 1] when called from main()
        return sdeint_fn(self, y0, ts, bm=bm, method='srk', dt=args["dt"])  # [len(ts), batch_size, 1], e.g. [300, 1024, 1]

    @property  # read-only attribute, recomputed from py0_logvar on every access
    def py0_std(self):  # the standard deviation of the prior p(y0)
        return torch.exp(.5 * self.py0_logvar)

    @property
    def qy0_std(self):  # the standard deviation of the approximate posterior q(y0)
        return torch.exp(.5 * self.qy0_logvar)



In [10]:
def make_segmented_cosine_data():
    """Build the segmented cosine dataset.

    Observations are cos(2*pi*t) sampled on two separate segments, [0.3, 0.8]
    and [1.2, 1.5], with 10 points each.

    Returns:
        A ``Data`` tuple holding numpy and torch versions of the observation
        times, the extended times (0 and 2 added at the ends), the 300-point
        plotting grid on [0, 2], and the observations. Only ``ys`` is moved to
        ``device``.
    """
    first_segment = np.linspace(0.3, 0.8, 10)
    second_segment = np.linspace(1.2, 1.5, 10)
    ts_ = np.concatenate((first_segment, second_segment), axis=0)
    # Add one out-of-data time point on each side of the observations.
    ts_ext_ = np.concatenate(([0.], ts_, [2.0]))
    # Regular grid used only for visualization.
    ts_vis_ = np.linspace(0., 2.0, 300)
    ys_ = np.cos(ts_ * (2. * math.pi)).reshape(-1, 1)

    ts, ts_ext, ts_vis = (torch.tensor(arr).float() for arr in (ts_, ts_ext_, ts_vis_))
    ys = torch.tensor(ys_).float().to(device)
    return Data(
        ts_=ts_, ts_ext_=ts_ext_, ts_vis_=ts_vis_,
        ts=ts, ts_ext=ts_ext, ts_vis=ts_vis,
        ys=ys, ys_=ys_,
    )


def make_irregular_sine_data():
    """Build the irregularly sampled sine dataset.

    Observations are 0.8 * sin(2*pi*t) at 16 sorted time points drawn uniformly
    from [0.4, 1.6] with numpy's global random generator.

    Returns:
        A ``Data`` tuple holding numpy and torch versions of the observation
        times, the extended times (0 and 2 added at the ends), the 300-point
        plotting grid on [0, 2], and the observations. Only ``ys`` is moved to
        ``device``.
    """
    random_times = np.random.uniform(low=0.4, high=1.6, size=16)
    ts_ = np.sort(random_times)
    # Add one out-of-data time point on each side of the observations.
    ts_ext_ = np.concatenate(([0.], ts_, [2.0]))
    # Regular grid used only for visualization.
    ts_vis_ = np.linspace(0., 2.0, 300)
    ys_ = np.sin(ts_ * (2. * math.pi)).reshape(-1, 1) * 0.8

    ts, ts_ext, ts_vis = (torch.tensor(arr).float() for arr in (ts_, ts_ext_, ts_vis_))
    ys = torch.tensor(ys_).float().to(device)
    return Data(
        ts_=ts_, ts_ext_=ts_ext_, ts_vis_=ts_vis_,
        ts=ts, ts_ext=ts_ext, ts_vis=ts_vis,
        ys=ys, ys_=ys_,
    )


def make_data():
    """Build the dataset named by ``args["data"]``.

    Raises:
        KeyError: if ``args["data"]`` is not one of the known dataset names.
    """
    constructors = {
        'segmented_cosine': make_segmented_cosine_data,
        'irregular_sine': make_irregular_sine_data
    }
    build = constructors[args["data"]]
    return build()

In [11]:
def _fill_percentile_bands(ts_np, sorted_paths, alphas, percentiles, color):
    """Shade nested central percentile bands on the current matplotlib axes.

    Args:
        ts_np: 1-D array of time points (x-axis).
        sorted_paths: array indexed as [time, sample] whose rows are sorted ascending.
        alphas: opacity of each band.
        percentiles: central fraction of samples covered by each band, paired with ``alphas``.
        color: fill colour shared by all bands.
    """
    num_paths = sorted_paths.shape[1]
    for alpha, percentile in zip(alphas, percentiles):
        # Number of samples cut from each tail so the band covers `percentile` of them.
        idx = int((1 - percentile) / 2. * num_paths)
        lower = sorted_paths[:, idx]
        # Mirror of `idx` counted from the end. `-(idx + 1)` rather than `-idx`: `-0` would
        # select column 0 again and collapse the widest band to zero width.
        upper = sorted_paths[:, -(idx + 1)]
        plt.fill_between(ts_np, lower, upper, alpha=alpha, color=color)


def _decorate_and_save(path, ts_, ys_, ylims):
    """Overlay the observations, label the axes, save the current figure to ``path`` and close it."""
    # `zorder` determines who's on top; the larger the more at the top.
    plt.scatter(ts_, ys_, marker='x', zorder=3, color='k', s=35)
    plt.ylim(ylims)
    plt.xlabel('$t$')
    plt.ylabel('$Y_t$')
    plt.tight_layout()
    plt.savefig(path, dpi=args["dpi"])
    plt.close()


def main():
    """Train the latent SDE on the configured dataset and save diagnostic figures.

    Steps:
      1. Build the dataset and a fixed noise source (``eps`` and a Brownian interval)
         so every diagnostic plot is drawn from the same randomness.
      2. Plot samples from the prior to ``./img/prior.png``.
      3. For ``args["train_iters"]`` steps, minimise ``-logpy + kl * kl_weight`` with Adam,
         where the KL weight follows a linear annealing schedule. Every
         ``args["pause_iters"]`` steps the current posterior is plotted to
         ``./img/global_step_<step>.png``.
      4. Save model, optimizer and scheduler state to ``./global_step_<last step>.ckpt``.

    Optional boolean keys in ``args`` toggle parts of the posterior plot:
    ``show_percentiles``, ``show_mean``, ``show_samples``, ``show_arrows`` (default True)
    and ``hide_ticks`` (default False).
    """
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()

    # Plotting parameters.
    vis_batch_size = 1024  # number of sample paths drawn for each figure
    ylims = (-1.75, 1.75)  # y-axis limits of every figure
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    vis_idx = np.random.permutation(vis_batch_size)  # random order of the indices 0..vis_batch_size-1
    # From https://colorbrewer2.org/.
    sample_colors = ('#fc4e2a', '#e31a1c', '#bd0026')
    fill_color = '#fd8d3c'
    mean_color = '#800026'
    num_samples = len(sample_colors)
    eps = torch.randn(vis_batch_size, 1).to(device)  # fixed standard-normal noise for the initial state
    bm = torchsde.BrownianInterval(
        t0=ts_vis[0],
        t1=ts_vis[-1],
        size=(vis_batch_size, 1),
        device=device,
        levy_area_approximation='space-time'
    )  # We need space-time Levy area to use the SRK solver. Reusing one Brownian motion keeps the plotted paths comparable across figures.

    # Optional plot toggles; the defaults reproduce the figure drawn when the keys are absent.
    show_percentiles = args.get("show_percentiles", True)
    show_mean = args.get("show_mean", True)
    show_samples = args.get("show_samples", True)
    show_arrows = args.get("show_arrows", True)
    hide_ticks = args.get("hide_ticks", False)

    # Figures are written here; create the directory so savefig does not fail on a fresh checkout.
    img_dir = './img/'
    os.makedirs(img_dir, exist_ok=True)

    # Model.
    model = LatentSDE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999)
    kl_scheduler = LinearScheduler(iters=args["kl_anneal_iters"])

    logpy_metric = EMAMetric()
    kl_metric = EMAMetric()
    loss_metric = EMAMetric()

    # Likelihood family p(x|z); it does not change during training, so look it up once.
    likelihood_constructor = {"laplace": distributions.Laplace, "normal": distributions.Normal}[args["likelihood"]]

    # Plot samples from the prior.
    with torch.no_grad():
        # Drop only the trailing state axis: [len(ts_vis), vis_batch_size, 1] -> [len(ts_vis), vis_batch_size].
        zs = model.sample_p(ts=ts_vis, batch_size=vis_batch_size, eps=eps, bm=bm).squeeze(dim=-1)
        ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
        zs_ = np.sort(zs_, axis=1)  # sort the samples at every time point

        prior_path = os.path.join(img_dir, 'prior.png')
        plt.subplot(frameon=False)
        _fill_percentile_bands(ts_vis_, zs_, alphas, percentiles, fill_color)
        _decorate_and_save(prior_path, ts_, ys_, ylims)
        logging.info(f'Saved prior figure at: {prior_path}')

    for global_step in tqdm.tqdm(range(args["train_iters"])):
        # Plot and save.
        if global_step % args["pause_iters"] == 0:
            img_path = os.path.join(img_dir, f'global_step_{global_step}.png')

            with torch.no_grad():
                zs = model.sample_q(ts=ts_vis, batch_size=vis_batch_size, eps=eps, bm=bm).squeeze(dim=-1)
                samples = zs[:, vis_idx]  # sample paths reordered by the fixed permutation
                ts_vis_, zs_, samples_ = ts_vis.cpu().numpy(), zs.cpu().numpy(), samples.cpu().numpy()
                zs_ = np.sort(zs_, axis=1)
                plt.subplot(frameon=False)

                # Same as for the prior: shade the percentile bands.
                if show_percentiles:
                    _fill_percentile_bands(ts_vis_, zs_, alphas, percentiles, fill_color)

                # Mean over all sample paths.
                if show_mean:
                    plt.plot(ts_vis_, zs_.mean(axis=1), color=mean_color)

                # A few individual paths; the permutation makes the first ones a random pick.
                if show_samples:
                    for j in range(num_samples):
                        plt.plot(ts_vis_, samples_[:, j], color=sample_colors[j], linewidth=1.0)

                # Vector field of the posterior drift on a num x num grid of (t, y) points.
                if show_arrows:
                    num, arrow_dt = 12, 0.12
                    # 'ij' indexing is what meshgrid did by default here; stating it silences the warning.
                    t, y = torch.meshgrid(
                        torch.linspace(0.2, 1.8, num).to(device),
                        torch.linspace(-1.5, 1.5, num).to(device),
                        indexing='ij'
                    )
                    # Flatten to column vectors so each row is one (t, y) query, e.g.
                    #   t = [[1],[1],[2],[2],[3],[3]]
                    #   y = [[1],[2],[3],[1],[2],[3]]
                    t, y = t.reshape(-1, 1), y.reshape(-1, 1)
                    fty = model.f(t=t, y=y).reshape(num, num)  # posterior drift at every grid point
                    dt = torch.full((num, num), arrow_dt).to(device)
                    dy = fty * dt  # displacement in y over a time step of arrow_dt
                    dt_, dy_, t_, y_ = dt.cpu().numpy(), dy.cpu().numpy(), t.cpu().numpy(), y.cpu().numpy()
                    plt.quiver(t_, y_, dt_, dy_, alpha=0.3, edgecolors='k', width=0.0035, scale=50)  # draw the arrows

                if hide_ticks:
                    plt.xticks([], [])
                    plt.yticks([], [])

                _decorate_and_save(img_path, ts_, ys_, ylims)
                logging.info(f'Saved figure at: {img_path}')

        # Train.
        optimizer.zero_grad()
        zs, kl = model(ts=ts_ext, batch_size=args["batch_size"])
        # [len(ts_ext), batch_size, 1] -> [len(ts_ext), batch_size]. Squeezing only the last
        # axis keeps the batch axis even when batch_size is 1.
        zs = zs.squeeze(dim=-1)
        zs = zs[1:-1]  # Drop first and last which are only used to penalize out-of-data region and spread uncertainty.

        # Generation process: likelihood p(x|z) centred on the latent path.
        likelihood = likelihood_constructor(loc=zs, scale=args["scale"])
        logpy = likelihood.log_prob(ys).sum(dim=0).mean(dim=0)  # sum over time, mean over the batch

        loss = -logpy + kl * kl_scheduler.val  # negative ELBO with the annealed KL weight
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        logpy_metric.step(logpy)
        kl_metric.step(kl)
        loss_metric.step(loss)

        logging.info(
            f'global_step: {global_step}, '
            f'logpy: {logpy_metric.val:.3f}, '
            f'kl: {kl_metric.val:.3f}, '
            f'loss: {loss_metric.val:.3f}'
        )
    # NOTE(review): the checkpoint goes to './', not to the './ckpts' directory created by the run cell -- confirm which is intended.
    torch.save(
        {'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'kl_scheduler': kl_scheduler},
        os.path.join('./', f'global_step_{global_step}.ckpt')
    )

In [12]:
# Fix every random generator so a Restart & Run All reproduces the same run.
manual_seed(0)

# With debugging on, let INFO-level progress messages through the root logger.
if args["debug"]:
    logging.getLogger().setLevel(logging.INFO)

# NOTE(review): this directory is created here, but main() writes its checkpoint to './' -- confirm which is intended.
ckpt_dir = os.path.join('./', 'ckpts')
os.makedirs(ckpt_dir, exist_ok=True)

# The model's methods read `sdeint_fn` as a global, so it has to be bound before main() runs.
if args["adjoint"]:
    sdeint_fn = torchsde.sdeint_adjoint
else:
    sdeint_fn = torchsde.sdeint

main()

INFO:root:Saved prior figure at: ./img/prior.png
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
INFO:root:Saved figure at: ./img/global_step_0.png
INFO:root:global_step: 0, logpy: -2.653, kl: 0.016, loss: 2.653
  0%|                                                                                 | 1/1000 [00:03<52:39,  3.16s/it]INFO:root:global_step: 1, logpy: -4.592, kl: 0.042, loss: 4.593
  0%|▏                                                                                | 2/1000 [00:04<34:42,  2.09s/it]INFO:root:global_step: 2, logpy: -7.416, kl: 0.205, loss: 7.422
  0%|▏                                                                                | 3/1000 [00:05<28:59,  1.75s/it]INFO:root:global_step: 3, logpy: -8.583, kl: 0.231, loss: 8.589
  0%|▎                                                                                | 4/1000 [00:07<26:24,  1.59s/it]INFO:root:global_step: 4, logpy: -9.974, kl: 0.240, loss: 9.981
  0%|▍                           

  4%|███▎                                                                            | 42/1000 [00:58<21:28,  1.34s/it]INFO:root:global_step: 42, logpy: -40.489, kl: 4.120, loss: 41.568
  4%|███▍                                                                            | 43/1000 [00:59<21:35,  1.35s/it]INFO:root:global_step: 43, logpy: -41.111, kl: 4.168, loss: 42.219
  4%|███▌                                                                            | 44/1000 [01:01<21:32,  1.35s/it]INFO:root:global_step: 44, logpy: -41.730, kl: 4.210, loss: 42.864
  4%|███▌                                                                            | 45/1000 [01:02<21:32,  1.35s/it]INFO:root:global_step: 45, logpy: -42.330, kl: 4.252, loss: 43.491
  5%|███▋                                                                            | 46/1000 [01:03<21:27,  1.35s/it]INFO:root:global_step: 46, logpy: -42.961, kl: 4.289, loss: 44.148
  5%|███▊                                                             

  9%|██████▉                                                                         | 86/1000 [01:59<20:35,  1.35s/it]INFO:root:global_step: 86, logpy: -61.477, kl: 5.904, loss: 64.344
  9%|██████▉                                                                         | 87/1000 [02:00<20:38,  1.36s/it]INFO:root:global_step: 87, logpy: -61.786, kl: 5.939, loss: 64.706
  9%|███████                                                                         | 88/1000 [02:02<20:34,  1.35s/it]INFO:root:global_step: 88, logpy: -62.096, kl: 5.972, loss: 65.069
  9%|███████                                                                         | 89/1000 [02:03<20:31,  1.35s/it]INFO:root:global_step: 89, logpy: -62.399, kl: 6.007, loss: 65.428
  9%|███████▏                                                                        | 90/1000 [02:04<20:26,  1.35s/it]INFO:root:global_step: 90, logpy: -62.689, kl: 6.039, loss: 65.771
  9%|███████▎                                                         

 17%|█████████████▌                                                                 | 172/1000 [04:01<19:23,  1.41s/it]INFO:root:global_step: 172, logpy: -47.589, kl: 15.249, loss: 61.525
 17%|█████████████▋                                                                 | 173/1000 [04:03<19:25,  1.41s/it]INFO:root:global_step: 173, logpy: -47.238, kl: 15.367, loss: 61.305
 17%|█████████████▋                                                                 | 174/1000 [04:04<19:13,  1.40s/it]INFO:root:global_step: 174, logpy: -46.874, kl: 15.484, loss: 61.071
 18%|█████████████▊                                                                 | 175/1000 [04:05<19:50,  1.44s/it]INFO:root:global_step: 175, logpy: -46.499, kl: 15.593, loss: 60.818
 18%|█████████████▉                                                                 | 176/1000 [04:07<20:25,  1.49s/it]INFO:root:global_step: 176, logpy: -46.107, kl: 15.718, loss: 60.564
 18%|█████████████▉                                         

 26%|████████████████████▍                                                          | 258/1000 [06:01<17:06,  1.38s/it]INFO:root:global_step: 258, logpy: -20.094, kl: 23.641, loss: 43.182
 26%|████████████████████▍                                                          | 259/1000 [06:03<17:06,  1.38s/it]INFO:root:global_step: 259, logpy: -19.831, kl: 23.729, loss: 43.012
 26%|████████████████████▌                                                          | 260/1000 [06:04<16:52,  1.37s/it]INFO:root:global_step: 260, logpy: -19.553, kl: 23.816, loss: 42.827
 26%|████████████████████▌                                                          | 261/1000 [06:05<16:47,  1.36s/it]INFO:root:global_step: 261, logpy: -19.267, kl: 23.893, loss: 42.623
 26%|████████████████████▋                                                          | 262/1000 [06:07<16:42,  1.36s/it]INFO:root:global_step: 262, logpy: -19.006, kl: 23.977, loss: 42.451
 26%|████████████████████▊                                  

 34%|███████████████████████████▎                                                   | 345/1000 [08:00<14:39,  1.34s/it]INFO:root:global_step: 345, logpy: -2.485, kl: 29.423, loss: 31.677
 35%|███████████████████████████▎                                                   | 346/1000 [08:02<14:52,  1.36s/it]INFO:root:global_step: 346, logpy: -2.324, kl: 29.458, loss: 31.554
 35%|███████████████████████████▍                                                   | 347/1000 [08:03<14:45,  1.36s/it]INFO:root:global_step: 347, logpy: -2.180, kl: 29.500, loss: 31.454
 35%|███████████████████████████▍                                                   | 348/1000 [08:04<14:39,  1.35s/it]INFO:root:global_step: 348, logpy: -2.041, kl: 29.542, loss: 31.359
 35%|███████████████████████████▌                                                   | 349/1000 [08:06<14:36,  1.35s/it]INFO:root:global_step: 349, logpy: -1.902, kl: 29.597, loss: 31.277
 35%|███████████████████████████▋                                

 39%|██████████████████████████████▋                                                | 389/1000 [09:01<13:48,  1.36s/it]INFO:root:global_step: 389, logpy: 3.045, kl: 31.359, loss: 28.166
 39%|██████████████████████████████▊                                                | 390/1000 [09:03<13:45,  1.35s/it]INFO:root:global_step: 390, logpy: 3.154, kl: 31.403, loss: 28.102
 39%|██████████████████████████████▉                                                | 391/1000 [09:04<13:43,  1.35s/it]INFO:root:global_step: 391, logpy: 3.268, kl: 31.428, loss: 28.015
 39%|██████████████████████████████▉                                                | 392/1000 [09:05<13:45,  1.36s/it]INFO:root:global_step: 392, logpy: 3.364, kl: 31.449, loss: 27.941
 39%|███████████████████████████████                                                | 393/1000 [09:07<13:40,  1.35s/it]INFO:root:global_step: 393, logpy: 3.468, kl: 31.476, loss: 27.866
 39%|███████████████████████████████▏                                 

 43%|██████████████████████████████████▏                                            | 433/1000 [10:02<12:43,  1.35s/it]INFO:root:global_step: 433, logpy: 7.023, kl: 32.688, loss: 25.570
 43%|██████████████████████████████████▎                                            | 434/1000 [10:04<12:44,  1.35s/it]INFO:root:global_step: 434, logpy: 7.106, kl: 32.706, loss: 25.505
 44%|██████████████████████████████████▎                                            | 435/1000 [10:05<12:41,  1.35s/it]INFO:root:global_step: 435, logpy: 7.191, kl: 32.733, loss: 25.449
 44%|██████████████████████████████████▍                                            | 436/1000 [10:06<12:48,  1.36s/it]INFO:root:global_step: 436, logpy: 7.258, kl: 32.766, loss: 25.415
 44%|██████████████████████████████████▌                                            | 437/1000 [10:08<12:44,  1.36s/it]INFO:root:global_step: 437, logpy: 7.330, kl: 32.793, loss: 25.372
 44%|██████████████████████████████████▌                              

 48%|█████████████████████████████████████▋                                         | 477/1000 [11:03<11:42,  1.34s/it]INFO:root:global_step: 477, logpy: 9.993, kl: 33.643, loss: 23.590
 48%|█████████████████████████████████████▊                                         | 478/1000 [11:05<11:40,  1.34s/it]INFO:root:global_step: 478, logpy: 10.038, kl: 33.660, loss: 23.562
 48%|█████████████████████████████████████▊                                         | 479/1000 [11:06<11:48,  1.36s/it]INFO:root:global_step: 479, logpy: 10.097, kl: 33.682, loss: 23.525
 48%|█████████████████████████████████████▉                                         | 480/1000 [11:07<11:44,  1.36s/it]INFO:root:global_step: 480, logpy: 10.160, kl: 33.709, loss: 23.489
 48%|█████████████████████████████████████▉                                         | 481/1000 [11:09<11:41,  1.35s/it]INFO:root:global_step: 481, logpy: 10.211, kl: 33.727, loss: 23.456
 48%|██████████████████████████████████████                       

 56%|████████████████████████████████████████████▍                                  | 563/1000 [13:02<09:49,  1.35s/it]INFO:root:global_step: 563, logpy: 13.787, kl: 34.991, loss: 21.178
 56%|████████████████████████████████████████████▌                                  | 564/1000 [13:04<09:52,  1.36s/it]INFO:root:global_step: 564, logpy: 13.818, kl: 35.013, loss: 21.170
 56%|████████████████████████████████████████████▋                                  | 565/1000 [13:05<09:49,  1.35s/it]INFO:root:global_step: 565, logpy: 13.852, kl: 35.047, loss: 21.169
 57%|████████████████████████████████████████████▋                                  | 566/1000 [13:07<09:51,  1.36s/it]INFO:root:global_step: 566, logpy: 13.901, kl: 35.061, loss: 21.134
 57%|████████████████████████████████████████████▊                                  | 567/1000 [13:08<09:47,  1.36s/it]INFO:root:global_step: 567, logpy: 13.942, kl: 35.073, loss: 21.106
 57%|████████████████████████████████████████████▊               

 65%|███████████████████████████████████████████████████▎                           | 650/1000 [15:02<07:53,  1.35s/it]INFO:root:Saved figure at: ./img/global_step_650.png
INFO:root:global_step: 650, logpy: 16.130, kl: 36.050, loss: 19.909
 65%|███████████████████████████████████████████████████▍                           | 651/1000 [15:04<10:26,  1.80s/it]INFO:root:global_step: 651, logpy: 16.155, kl: 36.057, loss: 19.891
 65%|███████████████████████████████████████████████████▌                           | 652/1000 [15:06<09:37,  1.66s/it]INFO:root:global_step: 652, logpy: 16.167, kl: 36.062, loss: 19.885
 65%|███████████████████████████████████████████████████▌                           | 653/1000 [15:07<09:02,  1.56s/it]INFO:root:global_step: 653, logpy: 16.184, kl: 36.065, loss: 19.871
 65%|███████████████████████████████████████████████████▋                           | 654/1000 [15:08<08:38,  1.50s/it]INFO:root:global_step: 654, logpy: 16.196, kl: 36.075, loss: 19.869
 66%|███████

 74%|██████████████████████████████████████████████████████████▏                    | 736/1000 [17:01<05:56,  1.35s/it]INFO:root:global_step: 736, logpy: 17.555, kl: 36.807, loss: 19.248
 74%|██████████████████████████████████████████████████████████▏                    | 737/1000 [17:02<05:54,  1.35s/it]INFO:root:global_step: 737, logpy: 17.561, kl: 36.808, loss: 19.243
 74%|██████████████████████████████████████████████████████████▎                    | 738/1000 [17:03<05:52,  1.35s/it]INFO:root:global_step: 738, logpy: 17.573, kl: 36.804, loss: 19.227
 74%|██████████████████████████████████████████████████████████▍                    | 739/1000 [17:05<05:52,  1.35s/it]INFO:root:global_step: 739, logpy: 17.597, kl: 36.802, loss: 19.200
 74%|██████████████████████████████████████████████████████████▍                    | 740/1000 [17:06<05:49,  1.35s/it]INFO:root:global_step: 740, logpy: 17.621, kl: 36.811, loss: 19.185
 74%|██████████████████████████████████████████████████████████▌ 

 82%|████████████████████████████████████████████████████████████████▉              | 822/1000 [19:00<03:59,  1.35s/it]INFO:root:global_step: 822, logpy: 18.520, kl: 37.374, loss: 18.852
 82%|█████████████████████████████████████████████████████████████████              | 823/1000 [19:01<03:59,  1.35s/it]INFO:root:global_step: 823, logpy: 18.527, kl: 37.388, loss: 18.859
 82%|█████████████████████████████████████████████████████████████████              | 824/1000 [19:03<03:57,  1.35s/it]INFO:root:global_step: 824, logpy: 18.533, kl: 37.393, loss: 18.858
 82%|█████████████████████████████████████████████████████████████████▏             | 825/1000 [19:04<03:56,  1.35s/it]INFO:root:global_step: 825, logpy: 18.539, kl: 37.408, loss: 18.867
 83%|█████████████████████████████████████████████████████████████████▎             | 826/1000 [19:05<03:53,  1.34s/it]INFO:root:global_step: 826, logpy: 18.530, kl: 37.410, loss: 18.878
 83%|████████████████████████████████████████████████████████████

 91%|███████████████████████████████████████████████████████████████████████▋       | 908/1000 [20:59<02:08,  1.40s/it]INFO:root:global_step: 908, logpy: 19.162, kl: 37.751, loss: 18.588
 91%|███████████████████████████████████████████████████████████████████████▊       | 909/1000 [21:00<02:05,  1.38s/it]INFO:root:global_step: 909, logpy: 19.166, kl: 37.770, loss: 18.603
 91%|███████████████████████████████████████████████████████████████████████▉       | 910/1000 [21:02<02:03,  1.37s/it]INFO:root:global_step: 910, logpy: 19.169, kl: 37.767, loss: 18.597
 91%|███████████████████████████████████████████████████████████████████████▉       | 911/1000 [21:03<02:01,  1.36s/it]INFO:root:global_step: 911, logpy: 19.168, kl: 37.782, loss: 18.613
 91%|████████████████████████████████████████████████████████████████████████       | 912/1000 [21:04<01:59,  1.36s/it]INFO:root:global_step: 912, logpy: 19.172, kl: 37.781, loss: 18.608
 91%|████████████████████████████████████████████████████████████

100%|██████████████████████████████████████████████████████████████████████████████▌| 995/1000 [22:58<00:06,  1.35s/it]INFO:root:global_step: 995, logpy: 19.614, kl: 38.050, loss: 18.436
100%|██████████████████████████████████████████████████████████████████████████████▋| 996/1000 [22:59<00:05,  1.35s/it]INFO:root:global_step: 996, logpy: 19.617, kl: 38.053, loss: 18.436
100%|██████████████████████████████████████████████████████████████████████████████▊| 997/1000 [23:01<00:04,  1.35s/it]INFO:root:global_step: 997, logpy: 19.624, kl: 38.068, loss: 18.444
100%|██████████████████████████████████████████████████████████████████████████████▊| 998/1000 [23:02<00:02,  1.35s/it]INFO:root:global_step: 998, logpy: 19.638, kl: 38.077, loss: 18.438
100%|██████████████████████████████████████████████████████████████████████████████▉| 999/1000 [23:03<00:01,  1.35s/it]INFO:root:global_step: 999, logpy: 19.645, kl: 38.082, loss: 18.437
100%|████████████████████████████████████████████████████████████