In [1]:
import argparse
import logging
import math
import os
import random
from collections import namedtuple
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from torch import distributions, nn, optim

import torchsde

In [2]:
# Run on the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [277]:
# Experiment configuration: every tunable setting of the notebook lives in this one mapping.
args = dict(
    likelihood="laplace",     # observation model p(x|z): "laplace" or "normal"
    adjoint=True,             # use the adjoint sensitivity method to backpropagate through the SDE
    debug=True,               # when set, the root logger is switched to INFO level
    data="segmented_cosine",  # dataset to fit: "segmented_cosine" or "irregular_sine"
    kl_anneal_iters=100,      # number of iterations over which the KL weight is annealed
    train_iters=1000,         # number of training iterations
    batch_size=100,           # number of latent paths sampled per training step
    adaptive=False,           # whether the SDE solver uses adaptive step sizes
    method="euler",           # SDE solver scheme used during training
    dt=1e-2,                  # step size passed to the SDE solver
    rtol=1e-3,                # relative tolerance passed to the SDE solver
    atol=1e-3,                # absolute tolerance passed to the SDE solver
    scale=0.05,               # scale parameter of the likelihood distribution
    dpi=500,                  # resolution of the saved figures
    pause_iters=50,           # plot the approximate posterior every this many iterations
)

In [278]:
# Naming convention: a trailing underscore marks a NumPy array; no underscore marks a torch tensor.
#
# Fields of `Data`:
#   ts     : original observation times; may be segmented or irregularly spaced
#   ts_ext : `ts` extended with time points outside the observed range, used to generate
#            latents outside the data so the out-of-data region is penalized and the
#            uncertainty spreads
#   ts_vis : regularly spaced times used for plotting
#   ys     : observed values, same length as `ts`
Data = namedtuple('Data', 'ts_ ts_ext_ ts_vis_ ts ts_ext ts_vis ys ys_')

In [279]:
class LinearScheduler(object):
    """Linearly ramps a value up to ``maxval`` in increments of ``maxval / iters``.

    The value starts at ``maxval / iters``; each call to :meth:`step` adds another
    ``maxval / iters`` and the result is capped at ``maxval``.

    Example: ``iters=100, maxval=1`` produces 0.01, 0.02, ..., 1.0.
    """

    def __init__(self, iters, maxval=1.0):
        # An iteration count below one is treated as a single step.
        n_steps = iters if iters > 1 else 1
        self._maxval = maxval
        self._iters = n_steps
        self._val = maxval / n_steps

    def step(self):
        """Advance the schedule by one increment, never exceeding ``maxval``."""
        increment = self._maxval / self._iters
        proposed = self._val + increment
        self._val = proposed if proposed < self._maxval else self._maxval

    @property
    def val(self):
        """Current value of the schedule."""
        return self._val

In [280]:
class EMAMetric(object):
    """Exponential moving average of a stream of values.

    The running value starts at 0 and is updated as
    ``gamma * previous + (1 - gamma) * x``; no bias correction is applied.
    """

    def __init__(self, gamma: Optional[float] = .99):
        super().__init__()
        self._gamma = gamma
        self._val = 0.

    def step(self, x: Union[torch.Tensor, np.ndarray]):
        """Fold ``x`` into the running average and return the updated average.

        Tensors are detached, moved to the CPU and converted to NumPy first.
        """
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        carried_over = self._gamma * self._val
        contribution = (1 - self._gamma) * x
        self._val = carried_over + contribution
        return self._val

    @property
    def val(self):
        """Current value of the running average."""
        return self._val

In [281]:
def manual_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with the same value.

    Calling this with the same seed before a run makes the random draws repeatable.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

In [282]:
def _stable_division(a, b, epsilon=1e-7):  # a / b
    """Divide ``a`` by ``b`` while keeping the denominator away from zero.

    Every element of ``b`` whose magnitude is at most ``epsilon`` is replaced by
    ``+epsilon`` or ``-epsilon`` (keeping its sign) before dividing, so the quotient
    cannot blow up.

    Args:
        a: numerator tensor.
        b: denominator tensor.
        epsilon: smallest denominator magnitude that is used as-is.

    Returns:
        The element-wise quotient ``a / b`` computed with the clamped denominator.
    """
    # `b.sign()` is 0 where b == 0, which would turn the replacement back into 0 and
    # still divide by zero; treat an exact zero as positive instead.
    nonzero_sign = torch.where(b < 0, -torch.ones_like(b), torch.ones_like(b))
    # The comparison is done on a detached copy so the mask itself carries no gradient.
    b = torch.where(b.abs().detach() > epsilon, b, epsilon * nonzero_sign)
    return a / b

In [283]:
class LatentSDE(torchsde.SDEIto):
    """Latent SDE over a 2-dimensional state with a fixed prior and a learned posterior drift.

    Prior process:      dz = theta * (mu - z) dt + sigma dW   (see ``h`` and ``g``)
    Posterior process:  dz = net(sin t, cos t, z) dt + sigma dW   (see ``f`` and ``g``)

    The prior parameters are registered as buffers and therefore are not trained; the
    posterior drift network and the posterior initial distribution q(z0) are trained.

    ``forward``, ``sample_p`` and ``sample_q`` read the notebook-level globals ``args``
    (solver settings) and ``sdeint_fn`` (the torchsde solver entry point), so both must
    be defined before these methods are called.

    Args:
        theta: mean-reversion rate of the prior drift.
        mu: per-dimension mean the prior reverts to; also the mean of p(z0) and the
            initial mean of q(z0).
        sigma: per-dimension diffusion coefficient shared by prior and posterior.
    """

    def __init__(self, theta=1.0, mu=(0.0, 0.0), sigma=(0.5, 0.5)):
        super(LatentSDE, self).__init__(noise_type="diagonal")

        # Per-dimension variance sigma^2 / (2 * theta), used for p(z0) and to initialize q(z0).
        var = [s ** 2 / (2. * theta) for s in sigma]

        # Prior parameters. Registered as buffers so they move with the module
        # but are not updated by the optimizer.
        self.register_buffer("theta", torch.tensor([[theta]]))
        self.register_buffer("mu", torch.tensor(mu))
        self.register_buffer("sigma", torch.tensor(sigma))

        # p(z0): initial distribution of the prior process.
        self.register_buffer("py0_mean", torch.tensor(mu))
        self.register_buffer("py0_var", torch.tensor(var))

        # Approximate posterior drift f(t, y). Its 4 inputs are sin(t), cos(t)
        # and the 2 state components; its 2 outputs are the drift per component.
        self.net = nn.Sequential(
            nn.Linear(4, 200),
            nn.Tanh(),
            nn.Linear(200, 200),
            nn.Tanh(),
            nn.Linear(200, 2)
        )
        # Initialization trick from Glow: start with a zero output layer.
        self.net[-1].weight.data.fill_(0.)
        self.net[-1].bias.data.fill_(0.)

        # q(z0): trainable initial distribution of the approximate posterior.
        # NOTE(review): qy0_var is optimized without any constraint, while
        # MultivariateNormal needs a positive-definite covariance; a log-variance
        # parametrization would be safer if training ever drives it towards zero.
        self.qy0_mean = nn.Parameter(torch.tensor(mu), requires_grad=True)
        self.qy0_var = nn.Parameter(torch.tensor(var), requires_grad=True)

    def f(self, t, y):
        """Approximate posterior drift, evaluated by the network on (sin t, cos t, y)."""
        if t.dim() == 0:
            # Give every sample its own copy of the scalar time and place it on the
            # same device/dtype as the state rather than on a notebook-global device.
            t = t.repeat(y.size(0), 1).to(y)
        # Sine/cosine encoding of time makes the posterior drift time-inhomogeneous.
        return self.net(torch.cat((torch.sin(t), torch.cos(t), y), dim=-1))

    def g(self, t, y):
        """Diffusion shared by prior and posterior: the constant ``sigma`` for every sample."""
        return self.sigma.repeat(y.size(0), 1)

    def h(self, t, y):
        """Prior drift: ``theta * (mu - y)``."""
        return self.theta * (self.mu - y)

    def f_aug(self, t, y):
        """Drift of the augmented system used in ``forward``.

        The augmented state has three columns: columns 0-1 hold the latent state and
        column 2 accumulates the path-wise KL term. Returns a ``[batch_size, 3]`` tensor.
        """
        y = y[:, 0:2]  # keep only the latent state; the KL accumulator does not feed back
        f, g, h = self.f(t, y), self.g(t, y), self.h(t, y)
        u = _stable_division(f - h, g)  # u(z, t) = (f - h) / g
        f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True)  # |u|^2 / 2, drift of the KL column
        return torch.cat([f, f_logqp], dim=1)

    def g_aug(self, t, y):
        """Diffusion of the augmented system: ``g`` for the latent state, zero for the KL column.

        Returns a ``[batch_size, 3]`` tensor.
        """
        y = y[:, 0:2]
        g = self.g(t, y)
        g_logqp = torch.zeros_like(y[:, 0:1])  # the KL accumulator receives no noise
        return torch.cat([g, g_logqp], dim=1)

    def forward(self, ts, batch_size, eps=None):
        """Sample posterior paths at times ``ts`` and estimate the KL to the prior.

        Args:
            ts: time points passed to the solver.
            batch_size: number of paths to sample.
            eps: unused; kept for interface compatibility.

        Returns:
            ``(ys, logqp)`` where ``ys`` has shape ``[len(ts), batch_size, 2]`` and
            ``logqp`` is the batch mean of KL(q(z0) || p(z0)) plus the path KL term.
        """
        qy0 = distributions.multivariate_normal.MultivariateNormal(
            loc=self.qy0_mean, covariance_matrix=torch.diag(self.qy0_var))
        py0 = distributions.multivariate_normal.MultivariateNormal(
            loc=self.py0_mean, covariance_matrix=torch.diag(self.py0_var))

        # Reparameterized draw of the initial state from q(z0): [batch_size, 2].
        y0 = qy0.rsample([batch_size])

        logqp0 = distributions.kl_divergence(qy0, py0)  # KL at t = 0
        logqp0 = logqp0.view(1,)

        # Append a zero column that will integrate the path-wise KL term.
        aug_y0 = torch.cat([y0, torch.zeros(batch_size, 1).to(y0)], dim=1)
        aug_ys = sdeint_fn(
            sde=self,
            y0=aug_y0,
            ts=ts,
            method=args["method"],
            dt=args["dt"],
            adaptive=args["adaptive"],
            rtol=args["rtol"],
            atol=args["atol"],
            names={'drift': 'f_aug', 'diffusion': 'g_aug'}
        )
        # aug_ys: [len(ts), batch_size, 3]. Columns 0-1 are the sampled latent paths;
        # column 2 at the final time is the accumulated path KL of each sample.
        ys, logqp_path = aug_ys[:, :, 0:2], aug_ys[-1, :, 2]
        logqp = (logqp0 + logqp_path).mean(dim=0)  # KL(t=0) + KL(path), averaged over the batch
        return ys, logqp

    def sample_p(self, ts, batch_size, eps=None, bm=None):
        """Sample latent paths from the prior process.

        The initial state is drawn from p(z0) and integrated with the prior drift ``h``
        using the 'srk' solver. ``eps`` is unused; ``bm`` is forwarded to the solver.
        Returns the solver output, ``[len(ts), batch_size, 2]``.
        """
        py0 = distributions.multivariate_normal.MultivariateNormal(self.py0_mean, torch.diag(self.py0_var))
        y0 = py0.rsample([batch_size])
        return sdeint_fn(self, y0, ts, bm=bm, method='srk', dt=args["dt"], names={'drift': 'h'})

    def sample_q(self, ts, batch_size, eps=None, bm=None):
        """Sample latent paths from the approximate posterior process.

        The initial state is drawn from q(z0) and integrated with the posterior drift
        ``f`` using the 'srk' solver. ``eps`` is unused; ``bm`` is forwarded to the solver.
        Returns the solver output, ``[len(ts), batch_size, 2]``.
        """
        qy0 = distributions.multivariate_normal.MultivariateNormal(self.qy0_mean, torch.diag(self.qy0_var))
        y0 = qy0.rsample([batch_size])
        return sdeint_fn(self, y0, ts, bm=bm, method='srk', dt=args["dt"])


    



In [284]:
def make_segmented_cosine_data():
    """Build the segmented cosine dataset.

    Observation times come in two separate segments, [0.3, 0.8] and [1.2, 1.5], with 10
    points each; the observations are cos(2*pi*t) at those times. Returns a ``Data``
    tuple holding NumPy and torch versions of the time grids and observations; only
    ``ys`` is moved to ``device``.
    """
    first_segment = np.linspace(0.3, 0.8, 10)
    second_segment = np.linspace(1.2, 1.5, 10)
    ts_ = np.concatenate((first_segment, second_segment), axis=0)

    # Add one time point before and one after the data (the out-of-data region).
    ts_ext_ = np.concatenate(([0.], ts_, [2.0]))
    # Regular grid used for visualization.
    ts_vis_ = np.linspace(0., 2.0, 300)

    angular_frequency = 2. * math.pi
    ys_ = np.cos(ts_ * angular_frequency)[:, None]

    ts, ts_ext, ts_vis = (torch.tensor(grid).float() for grid in (ts_, ts_ext_, ts_vis_))
    ys = torch.tensor(ys_).float().to(device)
    return Data(ts_=ts_, ts_ext_=ts_ext_, ts_vis_=ts_vis_,
                ts=ts, ts_ext=ts_ext, ts_vis=ts_vis, ys=ys, ys_=ys_)


def make_irregular_sine_data():
    """Build the irregularly sampled sine dataset.

    Sixteen observation times are drawn uniformly from [0.4, 1.6] with NumPy's global
    RNG and sorted; the observations are 0.8 * sin(2*pi*t) at those times. Returns a
    ``Data`` tuple holding NumPy and torch versions of the time grids and observations;
    only ``ys`` is moved to ``device``.
    """
    unsorted_times = np.random.uniform(low=0.4, high=1.6, size=16)
    ts_ = np.sort(unsorted_times)

    # Add one time point before and one after the data (the out-of-data region).
    ts_ext_ = np.concatenate(([0.], ts_, [2.0]))
    # Regular grid used for visualization.
    ts_vis_ = np.linspace(0., 2.0, 300)

    angular_frequency = 2. * math.pi
    ys_ = np.sin(ts_ * angular_frequency)[:, None] * 0.8

    ts, ts_ext, ts_vis = (torch.tensor(grid).float() for grid in (ts_, ts_ext_, ts_vis_))
    ys = torch.tensor(ys_).float().to(device)
    return Data(ts_=ts_, ts_ext_=ts_ext_, ts_vis_=ts_vis_,
                ts=ts, ts_ext=ts_ext, ts_vis=ts_vis, ys=ys, ys_=ys_)


def make_data():
    """Build the dataset selected by ``args["data"]``.

    Raises ``KeyError`` when the configured name is not one of the known datasets.
    """
    constructors = {
        'segmented_cosine': make_segmented_cosine_data,
        'irregular_sine': make_irregular_sine_data
    }
    dataset_name = args["data"]
    build = constructors[dataset_name]
    return build()

In [295]:
def _save_figure_with_data(ts_, ys_, ylims, path, dpi):
    """Overlay the observations on the current figure, label it, save it to ``path`` and close it."""
    # `zorder` determines who's on top; the larger the more at the top.
    plt.scatter(ts_, ys_, marker='x', zorder=3, color='k', s=35)
    plt.ylim(ylims)
    plt.xlabel('$t$')
    plt.ylabel('$Y_t$')
    plt.tight_layout()
    plt.savefig(path, dpi=dpi)
    plt.close()


def main():
    """Train the latent SDE on the configured dataset and save diagnostic figures.

    Steps:
      1. Build the dataset and plot percentile bands of samples from the prior.
      2. For ``args["train_iters"]`` iterations, maximize the ELBO with an annealed KL
         weight; every ``args["pause_iters"]`` iterations plot the per-dimension mean of
         posterior samples.
      3. Save model, optimizer and scheduler state to a checkpoint.

    Figures are written to ``./img/`` (created if missing) and the checkpoint to ``./``.
    Reads the notebook-level globals ``args`` and ``device``.
    """
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()

    # Weights of the linear read-out that maps the 2-D latent to the observed scalar.
    beta1 = 0.2
    beta2 = 0.8

    # Plotting parameters.
    vis_batch_size = 1024  # number of paths sampled for visualization
    ylims = (-1.75, 1.75)
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    # Colors from https://colorbrewer2.org/.
    fill_color = '#fd8d3c'
    mean1_color = '#800026'
    mean2_color = '#fc4e2a'

    # savefig does not create missing directories, so make sure the target exists.
    img_root = './img/'
    os.makedirs(img_root, exist_ok=True)

    # The SRK solver needs space-time Levy area. Reusing one Brownian motion for every
    # visualization makes the plotted paths comparable across training steps.
    bm = torchsde.BrownianInterval(
        t0=ts_vis[0],
        t1=ts_vis[-1],
        size=(vis_batch_size, 2),
        device=device,
        levy_area_approximation='space-time'
    )

    # Model.
    model = LatentSDE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999)
    kl_scheduler = LinearScheduler(iters=args["kl_anneal_iters"])

    logpy_metric = EMAMetric()
    kl_metric = EMAMetric()
    loss_metric = EMAMetric()

    # Prior plot: percentile bands of paths sampled from the prior process.
    with torch.no_grad():
        zs = model.sample_p(ts=ts_vis, batch_size=vis_batch_size, bm=bm)  # [len(ts_vis), vis_batch_size, 2]
        # NOTE(review): the prior is projected with equal weights here, whereas training
        # uses beta1/beta2 for the read-out; confirm that this difference is intended.
        zs = 0.5 * zs[:, :, 0] + 0.5 * zs[:, :, 1]
        ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
        zs_ = np.sort(zs_, axis=1)  # sort the paths at every time point

        img_dir = os.path.join(img_root, 'prior.png')
        plt.subplot(frameon=False)
        for alpha, percentile in zip(alphas, percentiles):
            # Number of paths cut off on each side of the central `percentile` band.
            idx = int((1 - percentile) / 2. * vis_batch_size)
            zs_bot_ = zs_[:, idx]
            # Mirror of `idx` counted from the end; plain `-idx` would select column 0
            # whenever idx == 0 and collapse the outermost band.
            zs_top_ = zs_[:, -idx - 1]
            plt.fill_between(ts_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color)
        _save_figure_with_data(ts_, ys_, ylims, img_dir, args["dpi"])
        logging.info(f'Saved prior figure at: {img_dir}')

    for global_step in tqdm.tqdm(range(args["train_iters"])):
        # Plot and save.
        if global_step % args["pause_iters"] == 0:
            img_path = os.path.join(img_root, f'global_step_{global_step}.png')

            with torch.no_grad():
                zs = model.sample_q(ts=ts_vis, batch_size=vis_batch_size, bm=bm)
                ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
                plt.subplot(frameon=False)

                # Mean over the sampled paths, one curve per latent dimension.
                plt.plot(ts_vis_, zs_[:, :, 0:1].mean(axis=1), color=mean1_color)
                plt.plot(ts_vis_, zs_[:, :, 1:].mean(axis=1), color=mean2_color)

                _save_figure_with_data(ts_, ys_, ylims, img_path, args["dpi"])
                logging.info(f'Saved figure at: {img_path}')

        # Train.
        optimizer.zero_grad()
        zs, kl = model(ts=ts_ext, batch_size=args["batch_size"])  # zs: [len(ts_ext), batch_size, 2]
        # Drop first and last which are only used to penalize out-of-data region and spread uncertainty.
        zs = zs[1:-1]

        # Generation process: linear read-out of the latent, then the likelihood p(x|z).
        zs = beta1 * zs[:, :, 0] + beta2 * zs[:, :, 1]  # [len(ts), batch_size]
        likelihood_constructor = {"laplace": distributions.Laplace, "normal": distributions.Normal}[args["likelihood"]]
        likelihood = likelihood_constructor(loc=zs, scale=args["scale"])

        # Log-likelihood summed over time and averaged over the batch.
        logpy = likelihood.log_prob(ys).sum(dim=0).mean(dim=0)

        # Negative ELBO with an annealed weight on the KL term.
        loss = -logpy + kl * kl_scheduler.val
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        logpy_metric.step(logpy)
        kl_metric.step(kl)
        loss_metric.step(loss)

        logging.info(
            f'global_step: {global_step}, '
            f'logpy: {logpy_metric.val:.3f}, '
            f'kl: {kl_metric.val:.3f}, '
            f'loss: {loss_metric.val:.3f}'
        )

    # NOTE(review): the checkpoint is written to './' although the driver cell creates
    # './ckpts'; confirm which location is intended.
    torch.save(
        {'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'kl_scheduler': kl_scheduler},
        os.path.join('./', f'global_step_{global_step}.ckpt')
    )

In [296]:
# Seed every RNG first so data generation, initialization and sampling are repeatable.
manual_seed(0)

# Solver entry point used by LatentSDE: the adjoint variant when args["adjoint"] is set,
# the plain solver otherwise. It must be defined before main() runs.
if args["adjoint"]:
    sdeint_fn = torchsde.sdeint_adjoint
else:
    sdeint_fn = torchsde.sdeint

# Make sure the checkpoint directory exists.
ckpt_dir = os.path.join('./', 'ckpts')
os.makedirs(ckpt_dir, exist_ok=True)

# In debug mode, show INFO-level log records (figure paths and per-step metrics).
if args["debug"]:
    logging.getLogger().setLevel(logging.INFO)

main()

INFO:root:Saved prior figure at: ./img/prior.png
  0%|                                                                                         | 0/1000 [00:00<?, ?it/s]INFO:root:Saved figure at: ./img/global_step_0.png
INFO:root:global_step: 0, logpy: -2.430, kl: 0.032, loss: 2.430
  0%|                                                                                 | 1/1000 [00:03<59:32,  3.58s/it]INFO:root:global_step: 1, logpy: -4.422, kl: 0.079, loss: 4.424
  0%|▏                                                                                | 2/1000 [00:05<40:41,  2.45s/it]INFO:root:global_step: 2, logpy: -7.529, kl: 0.528, loss: 7.543
  0%|▏                                                                                | 3/1000 [00:06<34:03,  2.05s/it]INFO:root:global_step: 3, logpy: -8.681, kl: 0.597, loss: 8.699
  0%|▎                                                                                | 4/1000 [00:08<31:00,  1.87s/it]INFO:root:global_step: 4, logpy: -10.023, kl: 0.6

  4%|███▍                                                                            | 43/1000 [01:10<25:29,  1.60s/it]INFO:root:global_step: 43, logpy: -41.505, kl: 5.122, loss: 42.775
  4%|███▌                                                                            | 44/1000 [01:12<25:24,  1.59s/it]INFO:root:global_step: 44, logpy: -42.118, kl: 5.175, loss: 43.422
  4%|███▌                                                                            | 45/1000 [01:14<25:26,  1.60s/it]INFO:root:global_step: 45, logpy: -42.706, kl: 5.227, loss: 44.045
  5%|███▋                                                                            | 46/1000 [01:15<25:42,  1.62s/it]INFO:root:global_step: 46, logpy: -43.323, kl: 5.272, loss: 44.695
  5%|███▊                                                                            | 47/1000 [01:17<25:49,  1.63s/it]INFO:root:global_step: 47, logpy: -43.921, kl: 5.325, loss: 45.330
  5%|███▊                                                             

  9%|██████▉                                                                         | 87/1000 [02:22<24:16,  1.59s/it]INFO:root:global_step: 87, logpy: -62.476, kl: 6.733, loss: 65.616
  9%|███████                                                                         | 88/1000 [02:24<24:21,  1.60s/it]INFO:root:global_step: 88, logpy: -62.835, kl: 6.759, loss: 66.027
  9%|███████                                                                         | 89/1000 [02:26<24:15,  1.60s/it]INFO:root:global_step: 89, logpy: -63.185, kl: 6.784, loss: 66.428
  9%|███████▏                                                                        | 90/1000 [02:27<24:07,  1.59s/it]INFO:root:global_step: 90, logpy: -63.522, kl: 6.808, loss: 66.817
  9%|███████▎                                                                        | 91/1000 [02:29<23:58,  1.58s/it]INFO:root:global_step: 91, logpy: -63.861, kl: 6.833, loss: 67.208
  9%|███████▎                                                         

 13%|██████████▎                                                                    | 130/1000 [03:33<22:58,  1.58s/it]INFO:root:global_step: 130, logpy: -67.934, kl: 9.564, loss: 75.124
 13%|██████████▎                                                                    | 131/1000 [03:34<22:53,  1.58s/it]INFO:root:global_step: 131, logpy: -67.758, kl: 9.689, loss: 75.096
 13%|██████████▍                                                                    | 132/1000 [03:36<22:59,  1.59s/it]INFO:root:global_step: 132, logpy: -67.618, kl: 9.843, loss: 75.135
 13%|██████████▌                                                                    | 133/1000 [03:38<22:58,  1.59s/it]INFO:root:global_step: 133, logpy: -67.436, kl: 9.970, loss: 75.102
 13%|██████████▌                                                                    | 134/1000 [03:39<22:51,  1.58s/it]INFO:root:global_step: 134, logpy: -67.313, kl: 10.101, loss: 75.133
 14%|██████████▋                                                

 17%|█████████████▋                                                                 | 173/1000 [04:43<21:57,  1.59s/it]INFO:root:global_step: 173, logpy: -56.374, kl: 15.464, loss: 70.297
 17%|█████████████▋                                                                 | 174/1000 [04:45<21:52,  1.59s/it]INFO:root:global_step: 174, logpy: -56.045, kl: 15.607, loss: 70.126
 18%|█████████████▊                                                                 | 175/1000 [04:46<21:49,  1.59s/it]INFO:root:global_step: 175, logpy: -55.692, kl: 15.738, loss: 69.920
 18%|█████████████▉                                                                 | 176/1000 [04:48<21:42,  1.58s/it]INFO:root:global_step: 176, logpy: -55.343, kl: 15.877, loss: 69.724
 18%|█████████████▉                                                                 | 177/1000 [04:49<21:49,  1.59s/it]INFO:root:global_step: 177, logpy: -54.985, kl: 16.019, loss: 69.524
 18%|██████████████                                         

 22%|█████████████████                                                              | 216/1000 [05:53<19:40,  1.51s/it]INFO:root:global_step: 216, logpy: -41.650, kl: 20.820, loss: 61.469
 22%|█████████████████▏                                                             | 217/1000 [05:54<19:42,  1.51s/it]INFO:root:global_step: 217, logpy: -41.318, kl: 20.941, loss: 61.269
 22%|█████████████████▏                                                             | 218/1000 [05:56<19:30,  1.50s/it]INFO:root:global_step: 218, logpy: -40.971, kl: 21.061, loss: 61.051
 22%|█████████████████▎                                                             | 219/1000 [05:57<19:26,  1.49s/it]INFO:root:global_step: 219, logpy: -40.626, kl: 21.181, loss: 60.837
 22%|█████████████████▍                                                             | 220/1000 [05:59<19:23,  1.49s/it]INFO:root:global_step: 220, logpy: -40.298, kl: 21.291, loss: 60.628
 22%|█████████████████▍                                     

 26%|████████████████████▍                                                          | 259/1000 [07:02<19:41,  1.59s/it]INFO:root:global_step: 259, logpy: -28.181, kl: 25.541, loss: 53.073
 26%|████████████████████▌                                                          | 260/1000 [07:03<19:36,  1.59s/it]INFO:root:global_step: 260, logpy: -27.878, kl: 25.661, loss: 52.896
 26%|████████████████████▌                                                          | 261/1000 [07:05<19:43,  1.60s/it]INFO:root:global_step: 261, logpy: -27.573, kl: 25.776, loss: 52.713
 26%|████████████████████▋                                                          | 262/1000 [07:06<19:41,  1.60s/it]INFO:root:global_step: 262, logpy: -27.270, kl: 25.882, loss: 52.522
 26%|████████████████████▊                                                          | 263/1000 [07:08<19:29,  1.59s/it]INFO:root:global_step: 263, logpy: -26.955, kl: 25.987, loss: 52.318
 26%|████████████████████▊                                  

 30%|███████████████████████▊                                                       | 302/1000 [08:12<23:18,  2.00s/it]INFO:root:global_step: 302, logpy: -16.732, kl: 29.742, loss: 46.052
 30%|███████████████████████▉                                                       | 303/1000 [08:13<21:52,  1.88s/it]INFO:root:global_step: 303, logpy: -16.513, kl: 29.822, loss: 45.918
 30%|████████████████████████                                                       | 304/1000 [08:15<20:57,  1.81s/it]INFO:root:global_step: 304, logpy: -16.280, kl: 29.913, loss: 45.780
 30%|████████████████████████                                                       | 305/1000 [08:17<20:07,  1.74s/it]INFO:root:global_step: 305, logpy: -16.054, kl: 30.007, loss: 45.652
 31%|████████████████████████▏                                                      | 306/1000 [08:18<19:34,  1.69s/it]INFO:root:global_step: 306, logpy: -15.823, kl: 30.096, loss: 45.514
 31%|████████████████████████▎                              

 34%|███████████████████████████▎                                                   | 345/1000 [09:20<17:17,  1.58s/it]INFO:root:global_step: 345, logpy: -8.194, kl: 32.981, loss: 40.902
 35%|███████████████████████████▎                                                   | 346/1000 [09:22<17:18,  1.59s/it]INFO:root:global_step: 346, logpy: -8.018, kl: 33.039, loss: 40.787
 35%|███████████████████████████▍                                                   | 347/1000 [09:23<17:17,  1.59s/it]INFO:root:global_step: 347, logpy: -7.862, kl: 33.109, loss: 40.702
 35%|███████████████████████████▍                                                   | 348/1000 [09:25<17:24,  1.60s/it]INFO:root:global_step: 348, logpy: -7.692, kl: 33.167, loss: 40.594
 35%|███████████████████████████▌                                                   | 349/1000 [09:27<17:18,  1.60s/it]INFO:root:global_step: 349, logpy: -7.519, kl: 33.222, loss: 40.479
 35%|███████████████████████████▋                                

 39%|██████████████████████████████▋                                                | 388/1000 [10:31<16:14,  1.59s/it]INFO:root:global_step: 388, logpy: -1.868, kl: 35.546, loss: 37.237
 39%|██████████████████████████████▋                                                | 389/1000 [10:32<16:15,  1.60s/it]INFO:root:global_step: 389, logpy: -1.745, kl: 35.598, loss: 37.168
 39%|██████████████████████████████▊                                                | 390/1000 [10:34<16:11,  1.59s/it]INFO:root:global_step: 390, logpy: -1.620, kl: 35.664, loss: 37.110
 39%|██████████████████████████████▉                                                | 391/1000 [10:36<16:05,  1.59s/it]INFO:root:global_step: 391, logpy: -1.492, kl: 35.716, loss: 37.036
 39%|██████████████████████████████▉                                                | 392/1000 [10:37<16:02,  1.58s/it]INFO:root:global_step: 392, logpy: -1.350, kl: 35.760, loss: 36.939
 39%|███████████████████████████████                             

 43%|██████████████████████████████████▏                                            | 432/1000 [11:43<15:03,  1.59s/it]INFO:root:global_step: 432, logpy: 2.861, kl: 37.438, loss: 34.463
 43%|██████████████████████████████████▏                                            | 433/1000 [11:44<15:03,  1.59s/it]INFO:root:global_step: 433, logpy: 2.962, kl: 37.480, loss: 34.405
 43%|██████████████████████████████████▎                                            | 434/1000 [11:46<14:59,  1.59s/it]INFO:root:global_step: 434, logpy: 3.063, kl: 37.515, loss: 34.340
 44%|██████████████████████████████████▎                                            | 435/1000 [11:48<15:03,  1.60s/it]INFO:root:global_step: 435, logpy: 3.161, kl: 37.539, loss: 34.267
 44%|██████████████████████████████████▍                                            | 436/1000 [11:49<14:59,  1.59s/it]INFO:root:global_step: 436, logpy: 3.243, kl: 37.578, loss: 34.225
 44%|██████████████████████████████████▌                              

 48%|█████████████████████████████████████▌                                         | 476/1000 [12:55<13:52,  1.59s/it]INFO:root:global_step: 476, logpy: 6.413, kl: 38.906, loss: 32.420
 48%|█████████████████████████████████████▋                                         | 477/1000 [12:56<13:48,  1.58s/it]INFO:root:global_step: 477, logpy: 6.475, kl: 38.926, loss: 32.379
 48%|█████████████████████████████████████▊                                         | 478/1000 [12:58<13:47,  1.58s/it]INFO:root:global_step: 478, logpy: 6.542, kl: 38.947, loss: 32.333
 48%|█████████████████████████████████████▊                                         | 479/1000 [13:00<13:51,  1.60s/it]INFO:root:global_step: 479, logpy: 6.600, kl: 38.976, loss: 32.305
 48%|█████████████████████████████████████▉                                         | 480/1000 [13:01<13:45,  1.59s/it]INFO:root:global_step: 480, logpy: 6.661, kl: 39.032, loss: 32.301
 48%|█████████████████████████████████████▉                           

 52%|█████████████████████████████████████████                                      | 520/1000 [14:06<12:39,  1.58s/it]INFO:root:global_step: 520, logpy: 8.943, kl: 40.079, loss: 31.089
 52%|█████████████████████████████████████████▏                                     | 521/1000 [14:08<12:38,  1.58s/it]INFO:root:global_step: 521, logpy: 8.983, kl: 40.106, loss: 31.076
 52%|█████████████████████████████████████████▏                                     | 522/1000 [14:10<12:35,  1.58s/it]INFO:root:global_step: 522, logpy: 9.034, kl: 40.128, loss: 31.048
 52%|█████████████████████████████████████████▎                                     | 523/1000 [14:11<12:33,  1.58s/it]INFO:root:global_step: 523, logpy: 9.092, kl: 40.152, loss: 31.014
 52%|█████████████████████████████████████████▍                                     | 524/1000 [14:13<12:36,  1.59s/it]INFO:root:global_step: 524, logpy: 9.152, kl: 40.184, loss: 30.986
 52%|█████████████████████████████████████████▍                       

 56%|████████████████████████████████████████████▌                                  | 564/1000 [15:18<11:27,  1.58s/it]INFO:root:global_step: 564, logpy: 10.842, kl: 40.908, loss: 30.036
 56%|████████████████████████████████████████████▋                                  | 565/1000 [15:20<11:28,  1.58s/it]INFO:root:global_step: 565, logpy: 10.892, kl: 40.927, loss: 30.005
 57%|████████████████████████████████████████████▋                                  | 566/1000 [15:21<11:32,  1.60s/it]INFO:root:global_step: 566, logpy: 10.946, kl: 40.933, loss: 29.957
 57%|████████████████████████████████████████████▊                                  | 567/1000 [15:23<11:32,  1.60s/it]INFO:root:global_step: 567, logpy: 10.979, kl: 40.939, loss: 29.930
 57%|████████████████████████████████████████████▊                                  | 568/1000 [15:25<11:34,  1.61s/it]INFO:root:global_step: 568, logpy: 11.006, kl: 40.941, loss: 29.906
 57%|████████████████████████████████████████████▉               

 61%|███████████████████████████████████████████████▉                               | 607/1000 [16:28<10:49,  1.65s/it]INFO:root:global_step: 607, logpy: 12.263, kl: 41.492, loss: 29.210
 61%|████████████████████████████████████████████████                               | 608/1000 [16:30<10:38,  1.63s/it]INFO:root:global_step: 608, logpy: 12.304, kl: 41.513, loss: 29.189
 61%|████████████████████████████████████████████████                               | 609/1000 [16:31<10:29,  1.61s/it]INFO:root:global_step: 609, logpy: 12.339, kl: 41.547, loss: 29.188
 61%|████████████████████████████████████████████████▏                              | 610/1000 [16:33<10:22,  1.60s/it]INFO:root:global_step: 610, logpy: 12.377, kl: 41.572, loss: 29.177
 61%|████████████████████████████████████████████████▎                              | 611/1000 [16:35<10:17,  1.59s/it]INFO:root:global_step: 611, logpy: 12.407, kl: 41.588, loss: 29.163
 61%|████████████████████████████████████████████████▎           

INFO:root:global_step: 650, logpy: 13.344, kl: 42.010, loss: 28.654
 65%|███████████████████████████████████████████████████▍                           | 651/1000 [17:40<12:38,  2.17s/it]INFO:root:global_step: 651, logpy: 13.358, kl: 42.009, loss: 28.638
 65%|███████████████████████████████████████████████████▌                           | 652/1000 [17:42<11:32,  1.99s/it]INFO:root:global_step: 652, logpy: 13.372, kl: 42.021, loss: 28.636
 65%|███████████████████████████████████████████████████▌                           | 653/1000 [17:43<10:45,  1.86s/it]INFO:root:global_step: 653, logpy: 13.386, kl: 42.035, loss: 28.636
 65%|███████████████████████████████████████████████████▋                           | 654/1000 [17:45<10:13,  1.77s/it]INFO:root:global_step: 654, logpy: 13.410, kl: 42.038, loss: 28.616
 66%|███████████████████████████████████████████████████▋                           | 655/1000 [17:46<09:51,  1.72s/it]INFO:root:global_step: 655, logpy: 13.435, kl: 42.043, loss: 28.5

 69%|██████████████████████████████████████████████████████▊                        | 694/1000 [18:48<08:00,  1.57s/it]INFO:root:global_step: 694, logpy: 14.149, kl: 42.404, loss: 28.248
 70%|██████████████████████████████████████████████████████▉                        | 695/1000 [18:49<07:58,  1.57s/it]INFO:root:global_step: 695, logpy: 14.170, kl: 42.423, loss: 28.245
 70%|██████████████████████████████████████████████████████▉                        | 696/1000 [18:51<08:01,  1.58s/it]INFO:root:global_step: 696, logpy: 14.192, kl: 42.436, loss: 28.236
 70%|███████████████████████████████████████████████████████                        | 697/1000 [18:53<08:00,  1.59s/it]INFO:root:global_step: 697, logpy: 14.224, kl: 42.429, loss: 28.197
 70%|███████████████████████████████████████████████████████▏                       | 698/1000 [18:54<08:00,  1.59s/it]INFO:root:global_step: 698, logpy: 14.232, kl: 42.436, loss: 28.196
 70%|███████████████████████████████████████████████████████▏    

 74%|██████████████████████████████████████████████████████████▏                    | 737/1000 [19:58<06:57,  1.59s/it]INFO:root:global_step: 737, logpy: 14.608, kl: 42.710, loss: 28.097
 74%|██████████████████████████████████████████████████████████▎                    | 738/1000 [20:00<06:54,  1.58s/it]INFO:root:global_step: 738, logpy: 14.629, kl: 42.726, loss: 28.092
 74%|██████████████████████████████████████████████████████████▍                    | 739/1000 [20:02<06:56,  1.59s/it]INFO:root:global_step: 739, logpy: 14.659, kl: 42.724, loss: 28.060
 74%|██████████████████████████████████████████████████████████▍                    | 740/1000 [20:03<06:52,  1.59s/it]INFO:root:global_step: 740, logpy: 14.680, kl: 42.739, loss: 28.054
 74%|██████████████████████████████████████████████████████████▌                    | 741/1000 [20:05<06:51,  1.59s/it]INFO:root:global_step: 741, logpy: 14.691, kl: 42.753, loss: 28.057
 74%|██████████████████████████████████████████████████████████▌ 

 78%|█████████████████████████████████████████████████████████████▌                 | 780/1000 [21:09<05:51,  1.60s/it]INFO:root:global_step: 780, logpy: 15.107, kl: 43.024, loss: 27.914
 78%|█████████████████████████████████████████████████████████████▋                 | 781/1000 [21:10<05:49,  1.59s/it]INFO:root:global_step: 781, logpy: 15.113, kl: 43.055, loss: 27.939
 78%|█████████████████████████████████████████████████████████████▊                 | 782/1000 [21:12<05:50,  1.61s/it]INFO:root:global_step: 782, logpy: 15.129, kl: 43.060, loss: 27.928
 78%|█████████████████████████████████████████████████████████████▊                 | 783/1000 [21:13<05:47,  1.60s/it]INFO:root:global_step: 783, logpy: 15.150, kl: 43.073, loss: 27.920
 78%|█████████████████████████████████████████████████████████████▉                 | 784/1000 [21:15<05:44,  1.60s/it]INFO:root:global_step: 784, logpy: 15.154, kl: 43.073, loss: 27.916
 78%|████████████████████████████████████████████████████████████

 82%|█████████████████████████████████████████████████████████████████              | 823/1000 [22:19<04:39,  1.58s/it]INFO:root:global_step: 823, logpy: 15.561, kl: 43.273, loss: 27.709
 82%|█████████████████████████████████████████████████████████████████              | 824/1000 [22:21<04:38,  1.58s/it]INFO:root:global_step: 824, logpy: 15.555, kl: 43.280, loss: 27.723
 82%|█████████████████████████████████████████████████████████████████▏             | 825/1000 [22:22<04:39,  1.60s/it]INFO:root:global_step: 825, logpy: 15.561, kl: 43.284, loss: 27.721
 83%|█████████████████████████████████████████████████████████████████▎             | 826/1000 [22:24<04:36,  1.59s/it]INFO:root:global_step: 826, logpy: 15.563, kl: 43.280, loss: 27.714
 83%|█████████████████████████████████████████████████████████████████▎             | 827/1000 [22:25<04:33,  1.58s/it]INFO:root:global_step: 827, logpy: 15.576, kl: 43.301, loss: 27.723
 83%|████████████████████████████████████████████████████████████

 87%|████████████████████████████████████████████████████████████████████▍          | 866/1000 [23:29<03:33,  1.60s/it]INFO:root:global_step: 866, logpy: 15.917, kl: 43.465, loss: 27.547
 87%|████████████████████████████████████████████████████████████████████▍          | 867/1000 [23:31<03:31,  1.59s/it]INFO:root:global_step: 867, logpy: 15.914, kl: 43.472, loss: 27.557
 87%|████████████████████████████████████████████████████████████████████▌          | 868/1000 [23:33<03:29,  1.59s/it]INFO:root:global_step: 868, logpy: 15.913, kl: 43.467, loss: 27.553
 87%|████████████████████████████████████████████████████████████████████▋          | 869/1000 [23:34<03:27,  1.58s/it]INFO:root:global_step: 869, logpy: 15.915, kl: 43.479, loss: 27.562
 87%|████████████████████████████████████████████████████████████████████▋          | 870/1000 [23:36<03:25,  1.58s/it]INFO:root:global_step: 870, logpy: 15.920, kl: 43.464, loss: 27.542
 87%|████████████████████████████████████████████████████████████

 91%|███████████████████████████████████████████████████████████████████████▊       | 909/1000 [24:40<02:27,  1.62s/it]INFO:root:global_step: 909, logpy: 16.235, kl: 43.606, loss: 27.370
 91%|███████████████████████████████████████████████████████████████████████▉       | 910/1000 [24:41<02:21,  1.58s/it]INFO:root:global_step: 910, logpy: 16.243, kl: 43.609, loss: 27.366
 91%|███████████████████████████████████████████████████████████████████████▉       | 911/1000 [24:43<02:21,  1.59s/it]INFO:root:global_step: 911, logpy: 16.243, kl: 43.636, loss: 27.392
 91%|████████████████████████████████████████████████████████████████████████       | 912/1000 [24:44<02:18,  1.58s/it]INFO:root:global_step: 912, logpy: 16.250, kl: 43.637, loss: 27.386
 91%|████████████████████████████████████████████████████████████████████████▏      | 913/1000 [24:46<02:16,  1.57s/it]INFO:root:global_step: 913, logpy: 16.261, kl: 43.637, loss: 27.375
 91%|████████████████████████████████████████████████████████████

 95%|███████████████████████████████████████████████████████████████████████████▏   | 952/1000 [25:49<01:36,  2.01s/it]INFO:root:global_step: 952, logpy: 16.467, kl: 43.802, loss: 27.335
 95%|███████████████████████████████████████████████████████████████████████████▎   | 953/1000 [25:51<01:27,  1.87s/it]INFO:root:global_step: 953, logpy: 16.475, kl: 43.808, loss: 27.333
 95%|███████████████████████████████████████████████████████████████████████████▎   | 954/1000 [25:52<01:21,  1.77s/it]INFO:root:global_step: 954, logpy: 16.478, kl: 43.804, loss: 27.325
 96%|███████████████████████████████████████████████████████████████████████████▍   | 955/1000 [25:54<01:17,  1.73s/it]INFO:root:global_step: 955, logpy: 16.470, kl: 43.788, loss: 27.317
 96%|███████████████████████████████████████████████████████████████████████████▌   | 956/1000 [25:55<01:13,  1.68s/it]INFO:root:global_step: 956, logpy: 16.472, kl: 43.792, loss: 27.319
 96%|████████████████████████████████████████████████████████████

100%|██████████████████████████████████████████████████████████████████████████████▋| 996/1000 [26:58<00:06,  1.59s/it]INFO:root:global_step: 996, logpy: 16.671, kl: 43.890, loss: 27.219
100%|██████████████████████████████████████████████████████████████████████████████▊| 997/1000 [27:00<00:04,  1.59s/it]INFO:root:global_step: 997, logpy: 16.679, kl: 43.896, loss: 27.216
100%|██████████████████████████████████████████████████████████████████████████████▊| 998/1000 [27:01<00:03,  1.60s/it]INFO:root:global_step: 998, logpy: 16.685, kl: 43.896, loss: 27.211
100%|██████████████████████████████████████████████████████████████████████████████▉| 999/1000 [27:03<00:01,  1.59s/it]INFO:root:global_step: 999, logpy: 16.703, kl: 43.898, loss: 27.194
100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [27:05<00:00,  1.63s/it]
