In [1]:
import argparse
import logging
import math
import os
import random
from collections import namedtuple
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from torch import distributions, nn, optim

import torchsde

In [2]:
# Run on the first CUDA GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Run configuration, looked up by key throughout the notebook.
args = dict(
    likelihood="laplace",     # observation model p(x|z): "laplace" or "normal"
    adjoint=True,             # solve with torchsde.sdeint_adjoint (adjoint sensitivity method) instead of torchsde.sdeint
    debug=True,               # when True, the root logger level is set to INFO
    data="segmented_cosine",  # dataset to build: "segmented_cosine" or "irregular_sine"
    kl_anneal_iters=100,      # number of iterations over which the KL weight is annealed
    train_iters=1000,         # number of training iterations
    batch_size=100,           # number of posterior sample paths per training step
    adaptive=False,           # whether the SDE solver uses adaptive step sizes
    method="euler",           # SDE solver used during training
    dt=1e-2,                  # step size handed to the SDE solver
    rtol=1e-3,                # relative tolerance handed to the SDE solver
    atol=1e-3,                # absolute tolerance handed to the SDE solver
    scale=1,                  # scale parameter of the likelihood distribution
    dpi=500,                  # resolution of the saved figures
    pause_iters=50,           # a posterior figure is saved every `pause_iters` training iterations
)

In [4]:
# Container for one dataset.
# Naming convention: a trailing underscore marks a numpy array, no underscore marks a torch tensor.
#   ts / ts_         : observation times; may be segmented or irregularly spaced
#   ts_ext / ts_ext_ : observation times padded with extra points outside the data range; the latent
#                      path is also solved there to penalize the out-of-data region and spread uncertainty
#   ts_vis / ts_vis_ : regularly spaced time grid used for plotting
#   ys / ys_         : observed values, one per entry of ts
Data = namedtuple('Data', 'ts_ ts_ext_ ts_vis_ ts ts_ext ts_vis ys ys_')

In [5]:
class LinearScheduler(object):
    '''
        Linear ramp from maxval / iters up to maxval.

        The value starts at maxval / iters, grows by maxval / iters on every call to
        step(), and is capped at maxval.

        Example: iters = 100, maxval = 1 gives 0.01, 0.02, ..., 1, and then stays at 1.

        iters:  number of values on the ramp; anything below 1 is treated as 1
        maxval: final (and maximum) value
    '''

    def __init__(self, iters, maxval=1.0):
        self._maxval = maxval
        # Guard against a zero or negative ramp length (it would divide by zero below).
        self._iters = iters if iters > 1 else 1
        self._val = maxval / self._iters

    def step(self):
        '''Advance the schedule by one increment, never exceeding maxval.'''
        candidate = self._val + self._maxval / self._iters
        self._val = candidate if candidate < self._maxval else self._maxval

    @property
    def val(self):
        '''Current value of the schedule.'''
        return self._val

In [6]:
class EMAMetric(object):
    '''
        Exponential moving average of a stream of values, used to smooth logged metrics.

        Update rule: val <- gamma * val + (1 - gamma) * x, starting from val = 0.
        Because the average starts at 0 and no bias correction is applied, the reported
        value is pulled towards 0 during the first steps.

        gamma: decay factor; larger values give a smoother, slower-moving average
    '''
    def __init__(self, gamma: Optional[float] = .99):
        super().__init__()
        self._gamma = gamma
        self._val = 0.

    def step(self, x: Union[torch.Tensor, np.ndarray]):
        '''Fold a new observation into the average and return the updated average.'''
        # Tensors are detached and converted to numpy so no autograd graph is kept alive.
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        carried_over = self._gamma * self._val
        self._val = carried_over + (1 - self._gamma) * x
        return self._val

    @property
    def val(self):
        '''Current value of the moving average.'''
        return self._val

In [7]:
def manual_seed(seed: int):
    '''
        Seed Python's `random`, numpy and torch with the same value so that
        repeated runs draw the same random numbers.
    '''
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

In [8]:
def _stable_division(a, b, epsilon=1e-7): #a/b
    '''
        change all elements x in b s.t -epsilon < x < epsilon to epsilon, to make sure the division is stable, won't cause a really large number
    '''
    
    b = torch.where(b.abs().detach() > epsilon, b, torch.full_like(b, fill_value=epsilon) * b.sign())
    return a / b

In [13]:
class LatentSDE(torchsde.SDEIto):
    '''
        One-dimensional latent SDE with a fixed prior and a learned approximate posterior.

        Prior:      dy = theta * (mu - y) dt + sigma dW   (drift `h`, diffusion `g`)
        Posterior:  dy = net(sin t, cos t, y) dt + sigma dW   (drift `f`, same diffusion `g`)

        `f_aug` / `g_aug` describe an augmented two-column state: column 0 is the latent
        path under the posterior, column 1 integrates 0.5 * u^2 with u = (f - h) / g and is
        used as the path-wise KL term.

        The solver calls read the module-level `sdeint_fn` and `args`.
    '''

    def __init__(self, theta=1.0, mu=0.0, sigma=0.5):
        super().__init__(noise_type="diagonal")
        # Log of sigma^2 / (2 * theta); used for p(y0) and as the starting value of q(y0).
        logvar = math.log(sigma ** 2 / (2. * theta))

        # Prior drift parameters and p(y0). These are registered as buffers: they move with
        # the module (e.g. .to(device)) but are not parameters, so the optimizer never updates them.
        fixed_values = (
            ("theta", theta),
            ("mu", mu),
            ("sigma", sigma),
            ("py0_mean", mu),
            ("py0_logvar", logvar),
        )
        for buffer_name, value in fixed_values:
            self.register_buffer(buffer_name, torch.tensor([[value]]))

        # Approximate posterior drift: takes 2 positional encodings of time plus the state.
        hidden = 200
        self.net = nn.Sequential(
            nn.Linear(3, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )
        # Initialization trick from Glow: zero the last layer so the posterior drift starts at 0.
        output_layer = self.net[-1]
        output_layer.weight.data.zero_()
        output_layer.bias.data.zero_()

        # q(y0): learnable mean and log-variance of the posterior initial state.
        self.qy0_mean = nn.Parameter(torch.tensor([[mu]]))
        self.qy0_logvar = nn.Parameter(torch.tensor([[logvar]]))

    @property
    def py0_std(self):
        '''Standard deviation of the prior initial distribution p(y0).'''
        return torch.exp(.5 * self.py0_logvar)

    @property
    def qy0_std(self):
        '''Standard deviation of the approximate posterior initial distribution q(y0).'''
        return torch.exp(.5 * self.qy0_logvar)

    def f(self, t, y):
        '''Approximate posterior drift, computed by `net` from (sin t, cos t, y).'''
        # A scalar time is broadcast to one entry per row of y.
        t_col = torch.full_like(y, fill_value=t) if t.dim() == 0 else t
        # Positional encoding (as in transformers) makes the posterior time-inhomogeneous.
        features = torch.cat((torch.sin(t_col), torch.cos(t_col), y), dim=-1)
        return self.net(features)

    def g(self, t, y):
        '''Diffusion shared by prior and posterior: the constant sigma, one row per row of y.'''
        return self.sigma.repeat(y.size(0), 1)

    def h(self, t, y):
        '''Prior drift theta * (mu - y), pulling the state towards mu.'''
        return self.theta * (self.mu - y)

    def f_aug(self, t, y):
        '''
            Drift of the augmented system.

            Only column 0 of `y` (the latent state) is read; column 1 is the running KL
            integral and does not feed back into the dynamics.
            Returns [posterior drift, 0.5 * u^2] with u = (f - h) / g, i.e. two columns.
        '''
        state = y[:, 0:1]
        drift_q = self.f(t, state)
        diffusion = self.g(t, state)
        drift_p = self.h(t, state)
        u = _stable_division(drift_q - drift_p, diffusion)
        kl_rate = .5 * u.pow(2).sum(dim=1, keepdim=True)
        return torch.cat([drift_q, kl_rate], dim=1)

    def g_aug(self, t, y):
        '''Diffusion of the augmented system: sigma for the latent state, 0 for the KL column.'''
        state = y[:, 0:1]
        diffusion = self.g(t, state)
        no_noise = torch.zeros_like(state)
        return torch.cat([diffusion, no_noise], dim=1)

    def forward(self, ts, batch_size, eps=None):
        '''
            Sample `batch_size` posterior paths at times `ts` and estimate the KL term.

            ts:         time points handed to the solver
            batch_size: number of paths
            eps:        optional standard-normal noise for the initial state, one row per path;
                        drawn here when None
            returns:    (ys, logqp) - the latent paths (column 0 of the augmented solution) and
                        the batch mean of KL(q(y0) || p(y0)) plus the path KL read at the last time
        '''
        if eps is None:
            eps = torch.randn(batch_size, 1).to(self.qy0_std)
        # Reparameterized draw of the initial state from q(y0).
        y0 = self.qy0_mean + eps * self.qy0_std

        posterior_init = distributions.Normal(loc=self.qy0_mean, scale=self.qy0_std)
        prior_init = distributions.Normal(loc=self.py0_mean, scale=self.py0_std)
        logqp0 = distributions.kl_divergence(posterior_init, prior_init).sum(dim=1)  # KL at t = 0.

        # The KL column starts at 0 and is integrated alongside the latent state.
        kl_start = torch.zeros(batch_size, 1).to(y0)
        aug_y0 = torch.cat([y0, kl_start], dim=1)
        solver_options = dict(
            method=args["method"],
            dt=args["dt"],
            adaptive=args["adaptive"],
            rtol=args["rtol"],
            atol=args["atol"],
            names={'drift': 'f_aug', 'diffusion': 'g_aug'},
        )
        aug_ys = sdeint_fn(sde=self, y0=aug_y0, ts=ts, **solver_options)

        ys = aug_ys[:, :, 0:1]          # sampled latent paths
        logqp_path = aug_ys[-1, :, 1]   # accumulated path KL at the final time
        logqp = (logqp0 + logqp_path).mean(dim=0)  # KL(t=0) + KL(path).
        return ys, logqp

    def sample_p(self, ts, batch_size, eps=None, bm=None):
        '''
            Sample latent paths from the prior: y0 ~ p(y0), then solve with drift `h`.

            eps: optional standard-normal noise for y0; bm: optional Brownian motion to reuse.
        '''
        if eps is None:
            eps = torch.randn(batch_size, 1).to(self.py0_mean)
        y0 = self.py0_mean + eps * self.py0_std
        prior_paths = sdeint_fn(self, y0, ts, bm=bm, method='srk', dt=args["dt"], names={'drift': 'h'})
        return prior_paths

    def sample_q(self, ts, batch_size, eps=None, bm=None):
        '''
            Sample latent paths from the approximate posterior: y0 ~ q(y0), then solve with
            the default drift `f` and diffusion `g`.

            eps: optional standard-normal noise for y0; bm: optional Brownian motion to reuse.
        '''
        if eps is None:
            eps = torch.randn(batch_size, 1).to(self.qy0_mean)
        y0 = self.qy0_mean + eps * self.qy0_std
        return sdeint_fn(self, y0, ts, bm=bm, method='srk', dt=args["dt"])



In [14]:
def make_segmented_cosine_data():
    '''
        Build the segmented cosine dataset: 10 evenly spaced times in [0.3, 0.8] and 10 in
        [1.2, 1.5], with observations cos(2 * pi * t).

        Returns a `Data` tuple. Only `ys` is moved to `device`; the time tensors stay on the CPU.
    '''
    first_segment = np.linspace(0.3, 0.8, 10)
    second_segment = np.linspace(1.2, 1.5, 10)
    ts_ = np.concatenate((first_segment, second_segment), axis=0)
    # Pad with one time point before and one after the data.
    ts_ext_ = np.concatenate(([0.], ts_, [2.0]))
    # Regular grid used for visualization.
    ts_vis_ = np.linspace(0., 2.0, 300)
    ys_ = np.cos(ts_ * (2. * math.pi)).reshape(-1, 1)

    ts, ts_ext, ts_vis = (
        torch.tensor(times, dtype=torch.float32) for times in (ts_, ts_ext_, ts_vis_)
    )
    ys = torch.tensor(ys_, dtype=torch.float32).to(device)
    return Data(
        ts_=ts_, ts_ext_=ts_ext_, ts_vis_=ts_vis_,
        ts=ts, ts_ext=ts_ext, ts_vis=ts_vis,
        ys=ys, ys_=ys_,
    )


def make_irregular_sine_data():
    '''
        Build the irregular sine dataset: 16 times drawn uniformly from [0.4, 1.6) with numpy's
        global RNG and sorted, with observations 0.8 * sin(2 * pi * t).

        Returns a `Data` tuple. Only `ys` is moved to `device`; the time tensors stay on the CPU.
    '''
    random_times = np.random.uniform(low=0.4, high=1.6, size=16)
    ts_ = np.sort(random_times)
    # Pad with one time point before and one after the data.
    ts_ext_ = np.concatenate(([0.], ts_, [2.0]))
    # Regular grid used for visualization.
    ts_vis_ = np.linspace(0., 2.0, 300)
    ys_ = 0.8 * np.sin(ts_ * (2. * math.pi)).reshape(-1, 1)

    ts, ts_ext, ts_vis = (
        torch.tensor(times, dtype=torch.float32) for times in (ts_, ts_ext_, ts_vis_)
    )
    ys = torch.tensor(ys_, dtype=torch.float32).to(device)
    return Data(
        ts_=ts_, ts_ext_=ts_ext_, ts_vis_=ts_vis_,
        ts=ts, ts_ext=ts_ext, ts_vis=ts_vis,
        ys=ys, ys_=ys_,
    )


def make_data():
    '''
        Build the dataset selected by args["data"].

        Raises KeyError when the name is neither 'segmented_cosine' nor 'irregular_sine'.
    '''
    constructors = {
        'segmented_cosine': make_segmented_cosine_data,
        'irregular_sine': make_irregular_sine_data,
    }
    build_dataset = constructors[args["data"]]
    return build_dataset()

In [15]:
def _percentile_band_indices(percentile, num_samples):
    '''Return the (low, high) column indices that bound the central `percentile` mass of sorted samples.'''
    low = int((1 - percentile) / 2. * num_samples)
    # Mirror `low` from the far end. Indexing with `-low` is off by one and, for low == 0,
    # selects column 0 again (since -0 == 0), which collapses the band to nothing.
    high = num_samples - 1 - low
    return low, high


def _fill_percentile_bands(ts_, sorted_zs_, alphas, percentiles, color):
    '''Shade one band per (alpha, percentile) pair; `sorted_zs_` must be sorted along axis 1.'''
    num_samples = sorted_zs_.shape[1]
    for alpha, percentile in zip(alphas, percentiles):
        low, high = _percentile_band_indices(percentile, num_samples)
        # Fill the region between the lower and the upper curve.
        plt.fill_between(ts_, sorted_zs_[:, low], sorted_zs_[:, high], alpha=alpha, color=color)


def _finish_figure(ts_, ys_, ylims, path):
    '''Overlay the observations, label the axes, save the current figure to `path` and close it.'''
    # `zorder` determines who's on top; the larger the more at the top.
    plt.scatter(ts_, ys_, marker='x', zorder=3, color='k', s=35)
    plt.ylim(ylims)
    plt.xlabel('$t$')
    plt.ylabel('$Y_t$')
    plt.tight_layout()
    plt.savefig(path, dpi=args["dpi"])
    plt.close()


def main():
    '''
        Train the latent SDE on the dataset selected in `args` and save diagnostic figures.

        Steps:
          1. Build the data and a fixed noise source (eps, Brownian motion) for visualization,
             so that figures from different iterations are comparable.
          2. Save a figure of paths sampled from the prior to ./img/prior.png.
          3. For args["train_iters"] iterations: every args["pause_iters"] iterations save a
             figure of posterior paths to ./img/global_step_<n>.png, then take one Adam step on
             loss = -log-likelihood + kl_weight * KL.
          4. Save model / optimizer / scheduler state to ./global_step_<last>.ckpt.

        Reads the module-level `args`, `device`, `LatentSDE` and `make_data`.
    '''
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()

    # Plotting parameters.
    vis_batch_size = 1024  # number of paths sampled for each figure
    ylims = (-1.75, 1.75)

    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

    # Random permutation of the path indices 0 .. vis_batch_size - 1.
    vis_idx = np.random.permutation(vis_batch_size)

    # From https://colorbrewer2.org/.
    sample_colors = ('#fc4e2a', '#e31a1c', '#bd0026')
    fill_color = '#fd8d3c'
    mean_color = '#800026'
    num_samples = len(sample_colors)

    # Which layers go into the posterior figures.
    show_percentiles, show_mean, show_samples, show_arrows, hide_ticks = True, True, True, True, False

    # The figures are written here; create the directory so savefig cannot fail on a fresh checkout.
    img_root = './img/'
    os.makedirs(img_root, exist_ok=True)

    eps = torch.randn(vis_batch_size, 1).to(device)  # fixed standard-normal noise for the initial state

    # We need space-time Levy area to use the SRK solver. Fixing the Brownian motion makes the
    # visualized paths depend only on the model, not on fresh noise.
    bm = torchsde.BrownianInterval(
        t0=ts_vis[0],
        t1=ts_vis[-1],
        size=(vis_batch_size, 1),
        device=device,
        levy_area_approximation='space-time'
    )

    # Model.
    model = LatentSDE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999)
    kl_scheduler = LinearScheduler(iters=args["kl_anneal_iters"])

    # Likelihood family p(x|z); it does not change during training, so look it up once.
    likelihood_constructor = {"laplace": distributions.Laplace, "normal": distributions.Normal}[args["likelihood"]]

    logpy_metric = EMAMetric()
    kl_metric = EMAMetric()
    loss_metric = EMAMetric()

    # Figure of the prior.
    with torch.no_grad():
        zs = model.sample_p(ts=ts_vis, batch_size=vis_batch_size, eps=eps, bm=bm).squeeze()
        ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
        zs_ = np.sort(zs_, axis=1)  # sort the paths' values at every time point

        img_dir = os.path.join(img_root, 'prior.png')
        plt.subplot(frameon=False)
        _fill_percentile_bands(ts_vis_, zs_, alphas, percentiles, fill_color)
        _finish_figure(ts_, ys_, ylims, img_dir)
        logging.info(f'Saved prior figure at: {img_dir}')

    for global_step in tqdm.tqdm(range(args["train_iters"])):
        # Plot and save.
        if global_step % args["pause_iters"] == 0:
            img_path = os.path.join(img_root, f'global_step_{global_step}.png')

            with torch.no_grad():
                zs = model.sample_q(ts=ts_vis, batch_size=vis_batch_size, eps=eps, bm=bm).squeeze()
                samples = zs[:, vis_idx]  # paths reordered by the permutation drawn above
                ts_vis_, zs_, samples_ = ts_vis.cpu().numpy(), zs.cpu().numpy(), samples.cpu().numpy()
                zs_ = np.sort(zs_, axis=1)
                plt.subplot(frameon=False)

                if show_percentiles:
                    _fill_percentile_bands(ts_vis_, zs_, alphas, percentiles, fill_color)

                if show_mean:
                    # Mean over all sampled paths.
                    plt.plot(ts_vis_, zs_.mean(axis=1), color=mean_color)

                if show_samples:
                    # The first few permuted paths, i.e. a random selection of individual paths.
                    for j in range(num_samples):
                        plt.plot(ts_vis_, samples_[:, j], color=sample_colors[j], linewidth=1.0)

                if show_arrows:
                    num, dt = 12, 0.12
                    t, y = torch.meshgrid(
                        [torch.linspace(0.2, 1.8, num).to(device), torch.linspace(-1.5, 1.5, num).to(device)]
                    )
                    # Flatten the num x num grid into columns of (t, y) pairs, e.g. for times
                    # (1, 2, 3) and states (1, 2):
                    #   t = [[1],[1],[2],[2],[3],[3]]
                    #   y = [[1],[2],[1],[2],[1],[2]]
                    t, y = t.reshape(-1, 1), y.reshape(-1, 1)
                    fty = model.f(t=t, y=y).reshape(num, num)  # posterior drift on the grid
                    dt = torch.zeros(num, num).fill_(dt).to(device)
                    dy = fty * dt  # displacement over a time step of length dt
                    dt_, dy_, t_, y_ = dt.cpu().numpy(), dy.cpu().numpy(), t.cpu().numpy(), y.cpu().numpy()
                    # Draw the drift as a field of arrows.
                    plt.quiver(t_, y_, dt_, dy_, alpha=0.3, edgecolors='k', width=0.0035, scale=50)

                if hide_ticks:
                    plt.xticks([], [])
                    plt.yticks([], [])

                _finish_figure(ts_, ys_, ylims, img_path)
                logging.info(f'Saved figure at: {img_path}')

        # Train.
        optimizer.zero_grad()
        zs, kl = model(ts=ts_ext, batch_size=args["batch_size"])
        zs = zs.squeeze()
        zs = zs[1:-1]  # Drop first and last which are only used to penalize out-of-data region and spread uncertainty.

        # Generation process: likelihood of the observations given the latent paths.
        likelihood = likelihood_constructor(loc=zs, scale=args["scale"])
        logpy = likelihood.log_prob(ys).sum(dim=0).mean(dim=0)  # sum over time, mean over paths

        # Negative ELBO with an annealed KL weight.
        loss = -logpy + kl * kl_scheduler.val
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        logpy_metric.step(logpy)
        kl_metric.step(kl)
        loss_metric.step(loss)

        logging.info(
            f'global_step: {global_step}, '
            f'logpy: {logpy_metric.val:.3f}, '
            f'kl: {kl_metric.val:.3f}, '
            f'loss: {loss_metric.val:.3f}'
        )

    # NOTE(review): `global_step` only exists after at least one iteration, so this raises
    # NameError when args["train_iters"] is 0. The checkpoint also goes to './' although the
    # calling cell creates './ckpts' - confirm which location is intended.
    torch.save(
        {'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'kl_scheduler': kl_scheduler},
        os.path.join('./', f'global_step_{global_step}.ckpt')
    )

In [18]:
# Seed every random number generator before anything stochastic runs.
manual_seed(0)

# In debug mode, let INFO-level progress messages through the root logger.
if args["debug"]:
    logging.getLogger().setLevel(level=logging.INFO)

# Make sure the checkpoint directory exists.
ckpt_dir = os.path.join('./', 'ckpts')
os.makedirs(ckpt_dir, exist_ok=True)

# Solver entry point read by LatentSDE at call time; must be defined before main() runs.
if args["adjoint"]:
    sdeint_fn = torchsde.sdeint_adjoint
else:
    sdeint_fn = torchsde.sdeint

main()

INFO:root:Saved prior figure at: ./img/prior.png
  0%|          | 0/1000 [00:00<?, ?it/s]INFO:root:Saved figure at: ./img/global_step_0.png
INFO:root:global_step: 0, logpy: -0.288, kl: 0.016, loss: 0.289
  0%|          | 1/1000 [00:05<1:34:58,  5.70s/it]INFO:root:global_step: 1, logpy: -0.540, kl: 0.041, loss: 0.541
  0%|          | 2/1000 [00:06<1:11:21,  4.29s/it]INFO:root:global_step: 2, logpy: -0.800, kl: 0.149, loss: 0.804
  0%|          | 3/1000 [00:07<54:48,  3.30s/it]  INFO:root:global_step: 3, logpy: -1.015, kl: 0.162, loss: 1.020
  0%|          | 4/1000 [00:08<43:15,  2.61s/it]INFO:root:global_step: 4, logpy: -1.245, kl: 0.171, loss: 1.250
  0%|          | 5/1000 [00:09<35:57,  2.17s/it]INFO:root:global_step: 5, logpy: -1.467, kl: 0.188, loss: 1.473
  1%|          | 6/1000 [00:10<30:02,  1.81s/it]INFO:root:global_step: 6, logpy: -1.678, kl: 0.218, loss: 1.686
  1%|          | 7/1000 [00:11<25:54,  1.57s/it]INFO:root:global_step: 7, logpy: -1.879, kl: 0.265, loss: 1.891
  1%| 

  7%|▋         | 71/1000 [01:19<15:20,  1.01it/s]INFO:root:global_step: 71, logpy: -11.390, kl: 1.713, loss: 11.928
  7%|▋         | 72/1000 [01:20<15:11,  1.02it/s]INFO:root:global_step: 72, logpy: -11.497, kl: 1.713, loss: 12.042
  7%|▋         | 73/1000 [01:21<15:05,  1.02it/s]INFO:root:global_step: 73, logpy: -11.602, kl: 1.712, loss: 12.154
  7%|▋         | 74/1000 [01:22<14:57,  1.03it/s]INFO:root:global_step: 74, logpy: -11.708, kl: 1.713, loss: 12.267
  8%|▊         | 75/1000 [01:23<14:54,  1.03it/s]INFO:root:global_step: 75, logpy: -11.811, kl: 1.713, loss: 12.378
  8%|▊         | 76/1000 [01:24<14:51,  1.04it/s]INFO:root:global_step: 76, logpy: -11.911, kl: 1.712, loss: 12.485
  8%|▊         | 77/1000 [01:25<14:46,  1.04it/s]INFO:root:global_step: 77, logpy: -12.010, kl: 1.712, loss: 12.592
  8%|▊         | 78/1000 [01:26<14:47,  1.04it/s]INFO:root:global_step: 78, logpy: -12.110, kl: 1.713, loss: 12.700
  8%|▊         | 79/1000 [01:27<15:08,  1.01it/s]INFO:root:global_step: 

 21%|██        | 208/1000 [11:01<14:31,  1.10s/it]INFO:root:global_step: 208, logpy: -19.169, kl: 1.640, loss: 20.495
 21%|██        | 209/1000 [11:02<13:58,  1.06s/it]INFO:root:global_step: 209, logpy: -19.197, kl: 1.640, loss: 20.526
 21%|██        | 210/1000 [11:03<13:35,  1.03s/it]INFO:root:global_step: 210, logpy: -19.221, kl: 1.640, loss: 20.552
 21%|██        | 211/1000 [11:04<13:17,  1.01s/it]INFO:root:global_step: 211, logpy: -19.243, kl: 1.641, loss: 20.579
 21%|██        | 212/1000 [11:05<13:16,  1.01s/it]INFO:root:global_step: 212, logpy: -19.265, kl: 1.642, loss: 20.606
 21%|██▏       | 213/1000 [11:06<13:41,  1.04s/it]INFO:root:global_step: 213, logpy: -19.288, kl: 1.644, loss: 20.633
 21%|██▏       | 214/1000 [11:07<13:25,  1.02s/it]INFO:root:global_step: 214, logpy: -19.312, kl: 1.644, loss: 20.661
 22%|██▏       | 215/1000 [11:08<13:13,  1.01s/it]INFO:root:global_step: 215, logpy: -19.338, kl: 1.643, loss: 20.689
 22%|██▏       | 216/1000 [11:09<13:25,  1.03s/it]INFO:r

 28%|██▊       | 277/1000 [12:14<11:51,  1.02it/s]INFO:root:global_step: 277, logpy: -20.435, kl: 1.656, loss: 21.934
 28%|██▊       | 278/1000 [12:15<11:47,  1.02it/s]INFO:root:global_step: 278, logpy: -20.448, kl: 1.656, loss: 21.948
 28%|██▊       | 279/1000 [12:16<11:46,  1.02it/s]INFO:root:global_step: 279, logpy: -20.460, kl: 1.656, loss: 21.962
 28%|██▊       | 280/1000 [12:17<11:56,  1.00it/s]INFO:root:global_step: 280, logpy: -20.471, kl: 1.657, loss: 21.975
 28%|██▊       | 281/1000 [12:18<11:55,  1.01it/s]INFO:root:global_step: 281, logpy: -20.480, kl: 1.657, loss: 21.986
 28%|██▊       | 282/1000 [12:19<11:50,  1.01it/s]INFO:root:global_step: 282, logpy: -20.492, kl: 1.657, loss: 21.999
 28%|██▊       | 283/1000 [12:20<11:42,  1.02it/s]INFO:root:global_step: 283, logpy: -20.504, kl: 1.656, loss: 22.012
 28%|██▊       | 284/1000 [12:21<11:42,  1.02it/s]INFO:root:global_step: 284, logpy: -20.515, kl: 1.656, loss: 22.025
 28%|██▊       | 285/1000 [12:22<11:37,  1.02it/s]INFO:r

 35%|███▍      | 346/1000 [13:26<10:39,  1.02it/s]INFO:root:global_step: 346, logpy: -21.066, kl: 1.668, loss: 22.656
 35%|███▍      | 347/1000 [13:27<10:35,  1.03it/s]INFO:root:global_step: 347, logpy: -21.070, kl: 1.670, loss: 22.662
 35%|███▍      | 348/1000 [13:28<10:30,  1.03it/s]INFO:root:global_step: 348, logpy: -21.075, kl: 1.671, loss: 22.669
 35%|███▍      | 349/1000 [13:29<10:25,  1.04it/s]INFO:root:global_step: 349, logpy: -21.081, kl: 1.669, loss: 22.674
 35%|███▌      | 350/1000 [13:30<10:23,  1.04it/s]INFO:root:Saved figure at: ./img/global_step_350.png
INFO:root:global_step: 350, logpy: -21.085, kl: 1.667, loss: 22.676
 35%|███▌      | 351/1000 [13:35<24:40,  2.28s/it]INFO:root:global_step: 351, logpy: -21.093, kl: 1.666, loss: 22.684
 35%|███▌      | 352/1000 [13:36<20:34,  1.91s/it]INFO:root:global_step: 352, logpy: -21.098, kl: 1.666, loss: 22.690
 35%|███▌      | 353/1000 [13:37<17:45,  1.65s/it]INFO:root:global_step: 353, logpy: -21.101, kl: 1.668, loss: 22.695
 35

 42%|████▏     | 415/1000 [15:27<09:41,  1.01it/s]INFO:root:global_step: 415, logpy: -21.371, kl: 1.684, loss: 23.017
 42%|████▏     | 416/1000 [15:28<09:34,  1.02it/s]INFO:root:global_step: 416, logpy: -21.377, kl: 1.684, loss: 23.023
 42%|████▏     | 417/1000 [15:29<09:32,  1.02it/s]INFO:root:global_step: 417, logpy: -21.377, kl: 1.684, loss: 23.022
 42%|████▏     | 418/1000 [15:30<09:33,  1.01it/s]INFO:root:global_step: 418, logpy: -21.381, kl: 1.684, loss: 23.027
 42%|████▏     | 419/1000 [15:31<09:49,  1.01s/it]INFO:root:global_step: 419, logpy: -21.384, kl: 1.683, loss: 23.030
 42%|████▏     | 420/1000 [15:32<09:39,  1.00it/s]INFO:root:global_step: 420, logpy: -21.385, kl: 1.682, loss: 23.030
 42%|████▏     | 421/1000 [15:33<09:34,  1.01it/s]INFO:root:global_step: 421, logpy: -21.387, kl: 1.681, loss: 23.031
 42%|████▏     | 422/1000 [15:34<09:42,  1.01s/it]INFO:root:global_step: 422, logpy: -21.393, kl: 1.680, loss: 23.037
 42%|████▏     | 423/1000 [15:35<09:32,  1.01it/s]INFO:r

 48%|████▊     | 484/1000 [16:39<08:28,  1.02it/s]INFO:root:global_step: 484, logpy: -21.525, kl: 1.687, loss: 23.192
 48%|████▊     | 485/1000 [16:40<08:23,  1.02it/s]INFO:root:global_step: 485, logpy: -21.525, kl: 1.688, loss: 23.194
 49%|████▊     | 486/1000 [16:41<08:25,  1.02it/s]INFO:root:global_step: 486, logpy: -21.528, kl: 1.687, loss: 23.196
 49%|████▊     | 487/1000 [16:42<08:26,  1.01it/s]INFO:root:global_step: 487, logpy: -21.530, kl: 1.686, loss: 23.197
 49%|████▉     | 488/1000 [16:43<08:24,  1.02it/s]INFO:root:global_step: 488, logpy: -21.531, kl: 1.686, loss: 23.198
 49%|████▉     | 489/1000 [16:44<08:23,  1.01it/s]INFO:root:global_step: 489, logpy: -21.532, kl: 1.687, loss: 23.199
 49%|████▉     | 490/1000 [16:45<08:20,  1.02it/s]INFO:root:global_step: 490, logpy: -21.532, kl: 1.687, loss: 23.201
 49%|████▉     | 491/1000 [16:46<08:29,  1.00s/it]INFO:root:global_step: 491, logpy: -21.535, kl: 1.687, loss: 23.204
 49%|████▉     | 492/1000 [16:47<08:31,  1.01s/it]INFO:r

 62%|██████▏   | 621/1000 [50:32<06:23,  1.01s/it]INFO:root:global_step: 621, logpy: -21.611, kl: 1.691, loss: 23.297
 62%|██████▏   | 622/1000 [50:33<06:15,  1.01it/s]INFO:root:global_step: 622, logpy: -21.614, kl: 1.691, loss: 23.300
 62%|██████▏   | 623/1000 [50:34<06:09,  1.02it/s]INFO:root:global_step: 623, logpy: -21.617, kl: 1.691, loss: 23.304
 62%|██████▏   | 624/1000 [50:35<06:05,  1.03it/s]INFO:root:global_step: 624, logpy: -21.615, kl: 1.691, loss: 23.302
 62%|██████▎   | 625/1000 [50:36<06:03,  1.03it/s]INFO:root:global_step: 625, logpy: -21.615, kl: 1.692, loss: 23.303
 63%|██████▎   | 626/1000 [50:37<06:03,  1.03it/s]INFO:root:global_step: 626, logpy: -21.616, kl: 1.694, loss: 23.305
 63%|██████▎   | 627/1000 [50:37<06:00,  1.03it/s]INFO:root:global_step: 627, logpy: -21.619, kl: 1.694, loss: 23.309
 63%|██████▎   | 628/1000 [50:38<05:59,  1.03it/s]INFO:root:global_step: 628, logpy: -21.619, kl: 1.694, loss: 23.309
 63%|██████▎   | 629/1000 [50:39<05:57,  1.04it/s]INFO:r

 69%|██████▉   | 690/1000 [51:45<05:16,  1.02s/it]INFO:root:global_step: 690, logpy: -21.643, kl: 1.709, loss: 23.349
 69%|██████▉   | 691/1000 [51:46<05:15,  1.02s/it]INFO:root:global_step: 691, logpy: -21.640, kl: 1.707, loss: 23.345
 69%|██████▉   | 692/1000 [51:47<05:10,  1.01s/it]INFO:root:global_step: 692, logpy: -21.641, kl: 1.706, loss: 23.345
 69%|██████▉   | 693/1000 [51:48<05:09,  1.01s/it]INFO:root:global_step: 693, logpy: -21.641, kl: 1.706, loss: 23.345
 69%|██████▉   | 694/1000 [51:49<05:05,  1.00it/s]INFO:root:global_step: 694, logpy: -21.639, kl: 1.706, loss: 23.342
 70%|██████▉   | 695/1000 [51:50<05:04,  1.00it/s]INFO:root:global_step: 695, logpy: -21.640, kl: 1.707, loss: 23.345
 70%|██████▉   | 696/1000 [51:51<05:02,  1.01it/s]INFO:root:global_step: 696, logpy: -21.639, kl: 1.707, loss: 23.344
 70%|██████▉   | 697/1000 [51:52<05:04,  1.00s/it]INFO:root:global_step: 697, logpy: -21.638, kl: 1.707, loss: 23.343
 70%|██████▉   | 698/1000 [51:53<05:00,  1.01it/s]INFO:r

 76%|███████▌  | 759/1000 [53:06<04:46,  1.19s/it]INFO:root:global_step: 759, logpy: -21.645, kl: 1.693, loss: 23.336
 76%|███████▌  | 760/1000 [53:07<04:28,  1.12s/it]INFO:root:global_step: 760, logpy: -21.646, kl: 1.691, loss: 23.336
 76%|███████▌  | 761/1000 [53:08<04:16,  1.07s/it]INFO:root:global_step: 761, logpy: -21.649, kl: 1.691, loss: 23.339
 76%|███████▌  | 762/1000 [53:09<04:06,  1.04s/it]INFO:root:global_step: 762, logpy: -21.650, kl: 1.690, loss: 23.339
 76%|███████▋  | 763/1000 [53:10<04:04,  1.03s/it]INFO:root:global_step: 763, logpy: -21.650, kl: 1.691, loss: 23.340
 76%|███████▋  | 764/1000 [53:11<04:00,  1.02s/it]INFO:root:global_step: 764, logpy: -21.650, kl: 1.692, loss: 23.340
 76%|███████▋  | 765/1000 [53:12<04:00,  1.03s/it]INFO:root:global_step: 765, logpy: -21.650, kl: 1.693, loss: 23.341
 77%|███████▋  | 766/1000 [53:13<03:57,  1.01s/it]INFO:root:global_step: 766, logpy: -21.649, kl: 1.692, loss: 23.340
 77%|███████▋  | 767/1000 [53:14<03:52,  1.00it/s]INFO:r

 83%|████████▎ | 828/1000 [54:24<03:04,  1.07s/it]INFO:root:global_step: 828, logpy: -21.655, kl: 1.697, loss: 23.351
 83%|████████▎ | 829/1000 [54:25<02:57,  1.04s/it]INFO:root:global_step: 829, logpy: -21.657, kl: 1.696, loss: 23.352
 83%|████████▎ | 830/1000 [54:26<02:58,  1.05s/it]INFO:root:global_step: 830, logpy: -21.655, kl: 1.694, loss: 23.348
 83%|████████▎ | 831/1000 [54:27<02:53,  1.03s/it]INFO:root:global_step: 831, logpy: -21.656, kl: 1.694, loss: 23.349
 83%|████████▎ | 832/1000 [54:28<02:57,  1.05s/it]INFO:root:global_step: 832, logpy: -21.656, kl: 1.694, loss: 23.349
 83%|████████▎ | 833/1000 [54:29<02:55,  1.05s/it]INFO:root:global_step: 833, logpy: -21.655, kl: 1.696, loss: 23.351
 83%|████████▎ | 834/1000 [54:30<02:55,  1.06s/it]INFO:root:global_step: 834, logpy: -21.654, kl: 1.697, loss: 23.351
 84%|████████▎ | 835/1000 [54:31<02:54,  1.06s/it]INFO:root:global_step: 835, logpy: -21.654, kl: 1.698, loss: 23.352
 84%|████████▎ | 836/1000 [54:32<02:52,  1.05s/it]INFO:r

 90%|████████▉ | 897/1000 [55:58<01:43,  1.01s/it]INFO:root:global_step: 897, logpy: -21.665, kl: 1.700, loss: 23.365
 90%|████████▉ | 898/1000 [55:59<01:57,  1.15s/it]INFO:root:global_step: 898, logpy: -21.664, kl: 1.700, loss: 23.364
 90%|████████▉ | 899/1000 [56:00<01:54,  1.13s/it]INFO:root:global_step: 899, logpy: -21.666, kl: 1.701, loss: 23.366
 90%|█████████ | 900/1000 [56:02<01:51,  1.11s/it]INFO:root:Saved figure at: ./img/global_step_900.png
INFO:root:global_step: 900, logpy: -21.665, kl: 1.700, loss: 23.365
 90%|█████████ | 901/1000 [56:07<04:04,  2.47s/it]INFO:root:global_step: 901, logpy: -21.666, kl: 1.701, loss: 23.366
 90%|█████████ | 902/1000 [56:08<03:18,  2.03s/it]INFO:root:global_step: 902, logpy: -21.666, kl: 1.701, loss: 23.367
 90%|█████████ | 903/1000 [56:09<02:46,  1.71s/it]INFO:root:global_step: 903, logpy: -21.668, kl: 1.702, loss: 23.370
 90%|█████████ | 904/1000 [56:10<02:25,  1.52s/it]INFO:root:global_step: 904, logpy: -21.669, kl: 1.702, loss: 23.371
 90

 97%|█████████▋| 966/1000 [57:18<00:32,  1.03it/s]INFO:root:global_step: 966, logpy: -21.646, kl: 1.700, loss: 23.346
 97%|█████████▋| 967/1000 [57:19<00:31,  1.04it/s]INFO:root:global_step: 967, logpy: -21.644, kl: 1.700, loss: 23.344
 97%|█████████▋| 968/1000 [57:19<00:30,  1.04it/s]INFO:root:global_step: 968, logpy: -21.644, kl: 1.701, loss: 23.344
 97%|█████████▋| 969/1000 [57:20<00:29,  1.04it/s]INFO:root:global_step: 969, logpy: -21.644, kl: 1.701, loss: 23.344
 97%|█████████▋| 970/1000 [57:21<00:28,  1.04it/s]INFO:root:global_step: 970, logpy: -21.646, kl: 1.702, loss: 23.348
 97%|█████████▋| 971/1000 [57:22<00:28,  1.02it/s]INFO:root:global_step: 971, logpy: -21.649, kl: 1.702, loss: 23.351
 97%|█████████▋| 972/1000 [57:23<00:27,  1.03it/s]INFO:root:global_step: 972, logpy: -21.647, kl: 1.703, loss: 23.349
 97%|█████████▋| 973/1000 [57:24<00:26,  1.03it/s]INFO:root:global_step: 973, logpy: -21.647, kl: 1.704, loss: 23.351
 97%|█████████▋| 974/1000 [57:25<00:25,  1.04it/s]INFO:r