# Bayesian Transfer Learning for Deep Networks

In this project we are concerned with **Bayesian Deep Learning**. Specifically, we want to know whether having a deep Bayesian model will improve the transfer of learning. Our hypothesis is that that knowledge gained from training a model on tasks **A** and then using the learned weights as a basis for learning on tasks $B$ will perform better than training **B** from scratch - assuming the domains are similar.

![Transfer Learning](https://image.slidesharecdn.com/13aibigdata-160606103446/95/aibigdata-lab-2016-transfer-learning-7-638.jpg?cb=1465209397)

We use Bayes By Backprop introduced by [Blundell, 2015](https://arxiv.org/abs/1505.05424). to learn a probability distribution over each of the weights in the network. These weight distributions are fitted using variational inference given some prior.

By inferring the posterior weight distribution in task **A** $p(w|D_A)$, a model is trained which is able to solve the second task **B** when exposed to new data $D_B$, while remembering task **A**. Variational Bayasian approximations of $p(w|D_A)$ are considered for this operation.

> The model constructed in this notebook tries to dynamically adapt its weights when confronted with new tasks. A method named **elastic weight consolidation (EWC)** ([Kirkpatrick, 2016](http://www.pnas.org/content/114/13/3521.full.pdf)) is implemented that considers data from two different tasks as independent.

### Import packages

In [None]:
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline  


import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torch.optim as optim
import torch.nn.functional as F

from collections import Counter, defaultdict, OrderedDict
from tqdm import tqdm
import math
import os, pickle, gc
import seaborn as sns
from scipy.stats import norm

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))

### Defining $\beta$ experiments

### Defining Normalizing Flows experiments

In [2]:

class NormalizingFlows(nn.Module):
    def __init__(self, n, features):
        super(NormalizingFlows, self).__init__()
        flows = [PlanarNormalizingFlow(features) for _ in range(n)]
        self.flows = nn.ModuleList(flows)

    def forward(self, z):
        log_dets = []

        for flow in self.flows:
            z, log_det_jacobian = flow(z)
            log_dets.append(log_det_jacobian)

        return z, log_dets


class PlanarNormalizingFlow(nn.Module):
    """
    Based on Normalizing Flow implementation from Parmesan
    https://github.com/casperkaae/parmesan
    """
    def __init__(self, features):
        super(PlanarNormalizingFlow, self).__init__()
        self.u = Parameter(torch.randn(features))
        self.w = Parameter(torch.randn(features))
        self.b = Parameter(torch.ones(1))

    def forward(self, z):
        # Create uhat such that it is parallel to w
        uw = torch.dot(self.u, self.w)
        muw = -1 + F.softplus(uw)
        uhat = self.u + (muw - uw) * torch.transpose(self.w, 0, -1) / torch.sum(self.w ** 2)

        # Equation 21 - Transform z
        zwb = torch.mv(z, self.w) + self.b

        f_z = z + (uhat.view(1, -1) * F.tanh(zwb).view(-1, 1))

        # Compute the Jacobian using the fact that
        # tanh(x) dx = 1 - tanh(x)**2
        psi = (1 - F.tanh(zwb)**2).view(-1, 1) * self.w.view(1, -1)
        psi_u = torch.mv(psi, uhat)

        # Return the transformed output along
        # with log determninant of J
        logdet_jacobian = torch.log(torch.abs(1 + psi_u) + 1e-8)

        return f_z, logdet_jacobian

### Defining transfer learning experiments

### (header missing)

In [3]:
def log_sum_exp(A, dim=-1, keepdim=False, sum_op=torch.sum):
    """Computes `log(exp(A).sum(axis=axis))` avoiding numerical issues using the log-sum-exp trick.
    Direct calculation of :math:`\log \sum_i \exp A_i` can result in underflow or overflow numerical
    issues. Big positive values can cause overflow :math:`\exp A_i = \inf`, and big negative values
    can cause underflow :math:`\exp A_i = 0`. The latter can eventually cause the sum to go to zero
    and finally resulting in :math:`\log 0 = -\inf`.
    The log-sum-exp trick avoids these issues by using the identity,
    .. math::
        \log \sum_i \exp A_i = \log \sum_i \exp(A_i - c) + c, \text{using},  \\
        c = \max A.7
    This avoids overflow, and while underflow can still happen for individual elements it avoids
    the sum being zero.

    Parameters
    ----------
    A : tensor
        Tensor of which we wish to compute the log-sum-exp.
    axis : int, tuple, list, None
        Axis or axes to sum over; None (default) sums over all axes.
    sum_op : function
        Summing function to apply; default is torch.sum, but can also be torch.mean for log-mean-exp.

    Returns
    -------
    tensor
        The log-sum-exp of `A`, dimensions over which is summed will be dropped.
    """
    A_max = torch.max(A,dim=dim,keepdim=True)[0]
    B = torch.log(sum_op(torch.exp(A - A_max), dim=dim, keepdim=True)) + A_max
    
    if not keepdim:
        B = B.squeeze(dim)
    
    return B



class Distribution():
    def pdf(self, x):
        raise NotImplementedError

    def logpdf(self, x):
        raise NotImplementedError

    def cdf(self, x):
        raise NotImplementedError

    def logcdf(self, x):
        raise NotImplementedError

    def sample(self):
        raise NotImplementedError

    
class Normal(Distribution):
    # scalar version
    def __init__(self, loc, logvar):
        self.loc = loc
        self.logvar = logvar
        self.shp = loc.size()

        super(Normal,self).__init__()

    def logpdf(self, x, eps=0.0):
        c = - float(0.5 * math.log(2 * math.pi))
        return c - 0.5*self.logvar - (x - self.loc).pow(2) / (2 * torch.exp((self.logvar)) + eps)

    def pdf(self, x):
        return torch.exp(self.logpdf(x))

    def sample(self):
        if self.loc.is_cuda:
            eps = torch.cuda.FloatTensor(self.shp).normal_()
        else:
            eps = torch.FloatTensor(self.shp).normal_()
        return self.loc + torch.exp(0.5*self.logvar) * Variable(eps)
    
    def entropy(self):
        return 0.5 * math.log(2. * math.pi * math.e) + 0.5*self.logvar



def kl_normal2_normal2(mean1, log_var1, mean2, log_var2, eps=0.0):
    """
    Compute closed-form solution to the KL-divergence between two Gaussians parameterized
    with diagonal log variance.
    .. math::
       D_{KL}[q||p] &= -\int p(x) \log q(x) dx + \int p(x) \log p(x) dx     \\
                    &= -\int \mathcal{N}(x; \mu_2, \sigma^2_2) \log \mathcal{N}(x; \mu_1, \sigma^2_1) dx
                        + \int \mathcal{N}(x; \mu_2, \sigma^2_2) \log \mathcal{N}(x; \mu_2, \sigma^2_2) dx     \\
                    &= \frac{1}{2} \log(2\pi\sigma^2_2) + \frac{\sigma^2_1 + (\mu_1 - \mu_2)^2}{2\sigma^2_2}
                        - \frac{1}{2}( 1 + \log(2\pi\sigma^2_1) )      \\
                    &= \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma^2_1 + (\mu_1 - \mu_2)^2}{2\sigma^2_2} - \frac{1}{2}
    """
    return 0.5*log_var2 - 0.5*log_var1 + (torch.exp(log_var1) + (mean1 - mean2)**2) / (2*torch.exp(log_var2) + eps) - 0.5

    
class FixedMixtureNormal(nn.Module):   #needs to be a nn.Moudle to register the parameters correcly
    # takes loc, logvar and pi as list of float values and assumes they are shared across all dimenisions
    def __init__(self, loc, logvar, pi):
        super(FixedMixtureNormal,self).__init__()
        assert sum(pi) -1 < 0.0001
        self.loc = Parameter(torch.from_numpy(np.array(loc)).float(),requires_grad=False)
        self.logvar = Parameter(torch.from_numpy(np.array(logvar)).float(),requires_grad=False)
        self.pi = Parameter(torch.from_numpy(np.array(pi)).float(),requires_grad=False)

    def _component_logpdf(self, x, eps=0.0):
        ndim = len(x.size())
        shpexpand = ndim*(None,)
        x = x.unsqueeze(-1)
        
        c = - float(0.5 * math.log(2 * math.pi))
        loc = self.loc[shpexpand]
        logvar = self.logvar[shpexpand]
        pi = self.pi[shpexpand]
        
        return c - 0.5*logvar - (x - loc).pow(2) / (2 * torch.exp(logvar) + eps)

    def logpdf(self,x):
        ndim = len(x.size())
        shpexpand = ndim*(None,)
        pi = self.pi[shpexpand]
        px = torch.exp(self._component_logpdf(x))   #... x num_components
        return torch.log(torch.sum(pi*px,-1))          

    
    
class FixedNormal(Distribution):
    # takes loc and logvar as float values and assumes they are shared across all dimenisions
    def __init__(self, loc, logvar):
        self.loc = loc
        self.logvar = logvar
        super(FixedNormal,self).__init__()

    def logpdf(self, x, eps=0.0):
        c = - float(0.5 * math.log(2 * math.pi))
        return c - 0.5*self.logvar - (x - self.loc).pow(2) / (2 * math.exp((self.logvar)) + eps)

    

### Defining prior weight distributions, activation functions, and loss function

correct? is that all? Better to make multiple cells out of it?

In [4]:
def distribution_selector(loc, logvar, pi):    
    if isinstance(logvar,(list,tuple)) and isinstance(pi,(list,tuple)):
        assert len(logvar) == len(pi)
        num_components = len(logvar)
        print('Mixture of Normal prior, nc:',num_components)
        if not isinstance(loc,(list,tuple)):
            loc = (loc,)*num_components
        return FixedMixtureNormal(loc,logvar,pi)
    else:
        print('Normal prior')
        return FixedNormal(loc,logvar)


class BBBLinearFactorial(nn.Module):
    def __init__(self,in_features, out_features, p_logvar_init=-3, p_pi=1.0, q_logvar_init=-5):
        # p_logvar_init, p_pi can be either
        #    (list/tuples): prior model is a mixture of gaussians components=len(p_pi)=len(p_logvar_init)
        #    float: Gussian distribution
        # q_logvar_init: float, the approximate posterior is currently always a factorized gaussian
        super(BBBLinearFactorial,self).__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.p_logvar_init = p_logvar_init
        self.q_logvar_init = q_logvar_init
                
        #Approximate Posterior model
        self.qw_mean = Parameter(torch.Tensor(out_features, in_features))
        self.qw_logvar = Parameter(torch.Tensor(out_features, in_features))
        self.qb_mean = Parameter(torch.Tensor(out_features))
        self.qb_logvar = Parameter(torch.Tensor(out_features))   
        
        self.normalizing_flow_w = NormalizingFlows(n=5, features=in_features*out_features)
        self.normalizing_flow_b = NormalizingFlows(n=5, features=out_features)

        self.qw = Normal(loc=self.qw_mean, logvar=self.qw_logvar)
        self.qb = Normal(loc=self.qb_mean, logvar=self.qb_logvar)
            
        #Prior Model (the prior model does not have any trainable parameters so we use special versions of the normal distributions)
        self.pw = distribution_selector(loc=0.0, logvar=p_logvar_init, pi=p_pi)
        self.pb = distribution_selector(loc=0.0, logvar=p_logvar_init, pi=p_pi)

        #initialize all paramaters
        self.reset_parameters()

    def reset_parameters(self):
        #initialize (learnable) approximate posterior parameters        
        stdv = 10. / math.sqrt(self.in_features)
        self.qw_mean.data.uniform_(-stdv, stdv)
        self.qw_logvar.data.uniform_(-stdv, stdv).add_(self.q_logvar_init)
        self.qb_mean.data.uniform_(-stdv, stdv)
        self.qb_logvar.data.uniform_(-stdv, stdv).add_(self.q_logvar_init)

        
    def forward(self, input):
        raise NotImplementedError()
        
    def probforward(self, input, MAP=False):
        #input: BS, in_features
        #W: BS, in_features
        #MAP: maximum a posterior (return the mean instead of sampling from the distributions)
        if MAP:
            w_sample = self.qw.loc
            b_sample = self.qb.loc
        else:
            w_sample = self.qw.sample()
            b_sample = self.qb.sample()

        f_w_sample, log_det_w = self.normalizing_flow_w(w_sample.view(1, -1))
        f_b_sample, log_det_b = self.normalizing_flow_b(b_sample.view(1, -1))

        f_w_sample = f_w_sample.view(w_sample.size())
        f_b_sample = f_b_sample.view(b_sample.size())

        # Subtracting log det J is the same as multiplying by 1/(det J)
        qw_logpdf = self.qw.logpdf(w_sample) - sum(log_det_w)
        qb_logpdf = self.qb.logpdf(b_sample) - sum(log_det_b)

        kl_w = torch.sum(qw_logpdf - self.pw.logpdf(f_w_sample))
        kl_b = torch.sum(qb_logpdf - self.pb.logpdf(f_b_sample))
        kl = kl_w + kl_b
        
        diagnostics = {'kl_w':kl_w.data.mean(), 'kl_b':kl_b.data.mean(), 
                       'Hq_w': self.qw.entropy().data.mean(), 'Hq_b': self.qb.entropy().data.mean()}   #Hq_w and Hq_b are the differential entropy
        output = F.linear(input,f_w_sample, f_b_sample)

        return output, kl, diagnostics


    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
    
    
    
class BBBMLP(nn.Module):
    def __init__(self,in_features, num_class, num_hidden, num_layers, p_logvar_init=-3, p_pi=1.0, q_logvar_init=-5):
        #create a simple MLP model with probabilistic weights
        super(BBBMLP, self).__init__()        
        layers = [BBBLinearFactorial(in_features=in_features,out_features=num_hidden, p_logvar_init=p_logvar_init, p_pi=p_pi, q_logvar_init=q_logvar_init), nn.ReLU()]
        for i in range(num_layers-1):
            layers += [BBBLinearFactorial(in_features=num_hidden,out_features=num_hidden, p_logvar_init=p_logvar_init, p_pi=p_pi, q_logvar_init=q_logvar_init), nn.ReLU()]
        layers += [BBBLinearFactorial(in_features=num_hidden,out_features=num_class, p_logvar_init=p_logvar_init, p_pi=p_pi, q_logvar_init=q_logvar_init)]
        
        self.layers = nn.ModuleList(layers)
        self.loss = nn.CrossEntropyLoss()
        
    def probforward(self,x, MAP=False):
        diagnostics = defaultdict(list)
        kl = 0
        for l in self.layers:
            if hasattr(l, 'probforward' ) and callable( l.probforward ): 
                x, _kl, _diagnostics = l.probforward(x,MAP=MAP) 
                kl += _kl
                for k,v in _diagnostics.items():
                    diagnostics[k].append(v)
            else:
                x = l(x)        
        logits = x
        return logits, kl, diagnostics

    def load_prior(self, state_dict):
        d_q = {k: v for k, v in state_dict.items() if "q" in k}
        for i, layer in enumerate(self.layers):
            if type(layer) is BBBLinearFactorial:
                layer.pw = Normal(loc=Variable(d_q["layers.{}.qw_mean".format(i)]).cuda(),
                                  logvar=Variable(d_q["layers.{}.qw_logvar".format(i)]).cuda())

                layer.pb = Normal(loc=Variable(d_q["layers.{}.qb_mean".format(i)]).cuda(),
                                  logvar=Variable(d_q["layers.{}.qb_logvar".format(i)]).cuda())


    def getloss(self,x,y, beta, MAP=False):
        logits, kl, _diagnostics = self.probforward(x, MAP=MAP)
        #_diagnostics is here a dictinary of list of floats
        # We need the dataset_size in order to 'spread' the KL divergence across all samples - this is dscribed in EQ (8) in Blundel et. al. 2015

        logpy = -self.loss(logits,y) #sample average

        ll = logpy - beta*kl   #ELBO
        loss = -ll

        acc = (logits.max(1)[1].data == th_y).float().mean() #accuracy
        
        #the xxx.data.mean() is just an easy way to transfer to cpu and convert from torch to normal floats
        diagnostics = {'loss': [loss.data.mean()],
                       'll': [ll.data.mean()],
                       'kl': [kl.data.mean()],
                       'logpy': [logpy.data.mean()],
                       'acc': [acc],
                       'kl_w': _diagnostics['kl_w'],
                       'kl_b': _diagnostics['kl_b'],
                       'Hq_w': _diagnostics['Hq_w'],
                       'Hq_b': _diagnostics['Hq_b'],}
        return logits, loss, diagnostics
    
    

def plothist(model,filename):
    N = norm(loc=0,scale=np.exp(p_logvar_init))
    x =np.linspace(-0.5,0.5,100)
    W = torch.cat([model.layers[0].qw_mean.view(-1), model.layers[2].qw_mean.view(-1), model.layers[4].qw_mean.view(-1)]).data.cpu().numpy()
    b = torch.cat([model.layers[0].qb_mean.view(-1), model.layers[2].qb_mean.view(-1), model.layers[4].qb_mean.view(-1)]).data.cpu().numpy()

    plt.figure(figsize=(10,5))
    plt.subplot(121)
    _ = plt.hist(W,np.linspace(-0.5,0.5,100),normed=True,label='q samples')
    plt.plot(x, N.pdf(x),label='prior pdf')
    plt.xlim([-0.5,0.5])
    plt.ylim([0,10])
    plt.legend()
    plt.title('Weights')

    plt.subplot(122)
    _ = plt.hist(b,np.linspace(-0.5,0.5,100),normed=True,label='q samples')
    plt.plot(x, N.pdf(x),label='prior pdf')
    plt.xlim([-0.5,0.5])
    plt.ylim([0,10])
    plt.legend()
    plt.title('bias')
    
    plt.savefig(filename)
    plt.close('all')

def addres(old,new):
    for k in new.keys():
        if k in old:
            old[k] = [vo+vn for vo,vn in zip(old[k],new[k])]
        else:
            old[k] = new[k]
    return old
    
def averres(res,num_batch):
    for k in res.keys():
        res[k] = [v/num_batch for v in res[k]]
    return res
    
def listdict2dictlist(LD):
    DL = dict(zip(sorted(LD[0].keys()),zip(*[[v for k,v in sorted(d.items())]for d in LD])))
    return DL
    

### Loading data

In [5]:
from operator import __or__
from functools import reduce
data = np.load('mnist.npz')  #Download from https://www.dropbox.com/s/k92825vinroxh6i/mnist.npz?dl=0
xtrain, ytrain = data['x_train'],data['y_train']
xvalid, yvalid = data['x_valid'],data['y_valid']
xtest, ytest = data['x_test'],data['y_test']
    
    


#MU, SIGMA = xtrain.mean(keepdims=True), xtrain.std(keepdims=True)
#xtrain = (xtrain-0.5)/0.5
#xvalid = (xvalid-0.5)/0.5
#xtest = (xtest-0.5)/0.5


### Hyperparameters and experiment settings

In [None]:
batch_size = 128
nsamples = 10
dataset_size, in_features = xtrain.shape
num_hidden = 400
num_layers = 2
p_logvar_init = 0#(-1,-2,-3,-4,-5)
q_logvar_init = -4
p_pi = 1.0 #(1./10.,2./10.,4./10.,2./10.,1./10.)

num_class = 10
learningrate = 1e-4
num_epochs = 50
cuda = torch.cuda.is_available()
save_every_epoch = 10


expname = 'norm_flow'
weights_dir = os.path.join(expname,'weights')
figure_dir = os.path.join(expname,'figures')

logfile = os.path.join(expname,'logfile.txt')
diagnosticsfile = os.path.join(expname,'diagnostics.pkl')
weightsfile = os.path.join(weights_dir,'model_epoch%i.pkl')
histfigurefile = os.path.join(figure_dir,'weighthistogram_epoch%i.png')


if not os.path.exists(expname): os.makedirs(expname)
if not os.path.exists(weights_dir): os.makedirs(weights_dir)  
if not os.path.exists(figure_dir): os.makedirs(figure_dir)  
with open(logfile,'w') as fh: fh.write('')


### Training the model

In [6]:
model = BBBMLP(in_features=in_features, num_class=num_class, num_hidden=num_hidden, num_layers=num_layers, p_logvar_init=p_logvar_init, p_pi=p_pi, q_logvar_init=q_logvar_init)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learningrate)
th_x = torch.FloatTensor(batch_size*nsamples,in_features).fill_(0)
th_y = torch.LongTensor(batch_size*nsamples).fill_(0)

#d = pickle.load(open("domain_a_{}/weights/model_epoch99.pkl".format(f), "rb"))
#model.load_prior(d)

if cuda:
    model.cuda()
    th_x = th_x.cuda()
    th_y = th_y.cuda()


def run_epoch(np_x, np_y, batch_size, epoch, MAP=False, is_training=False):
    N = np_x.shape[0]
    idx = np.arange(N)
    np.random.shuffle(idx)
    diagnostics = {}
    nbatch_per_epoch = N//batch_size

    for i in tqdm(range(nbatch_per_epoch)):
        batch_idx = idx[i*batch_size : (i+1)*batch_size]

        th_x.copy_(torch.from_numpy(np.tile(np_x[batch_idx],[nsamples,1])))
        th_y.copy_(torch.from_numpy(np.tile(np_y[batch_idx],[nsamples])))

        beta = 2 ** (nbatch_per_epoch - (epoch + 1)) / (2 ** nbatch_per_epoch - 1)
        logits, loss, _diagnostics = model.getloss(Variable(th_x),Variable(th_y), beta=beta, MAP=MAP)
        diagnostics = addres(diagnostics, _diagnostics)
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    diagnostics = averres(diagnostics,nbatch_per_epoch)
    return diagnostics



diagnostics_batch_train, diagnostics_batch_valid, diagnostics_batch_valid_MAP = [],[],[]
for e in range(num_epochs):

    if e % save_every_epoch == 0:
        #get weights before training
        weights = model.state_dict()
        for k in weights: 
            weights[k] = weights[k].cpu()
        plothist(model,histfigurefile%(e))   

    diagnostics_batch_train += [run_epoch(xtrain, ytrain, batch_size, e, MAP=False, is_training=True)]
    diagnostics_batch_valid += [run_epoch(xvalid, yvalid, batch_size, e, MAP=False, is_training=False)]
    diagnostics_batch_valid_MAP += [run_epoch(xvalid, yvalid, batch_size, e, MAP=True, is_training=False)]

    ltr = "Train %i | "%(e) + ", ".join(["%s: %s"%(k,"|".join(["%0.3f"%(_v) for _v in v])) for k,v in sorted(diagnostics_batch_train[-1].items())])
    lte = "Test %i | "%(e)     + ", ".join(["%s: %s"%(k,"|".join(["%0.3f"%(_v) for _v in v])) for k,v in sorted(diagnostics_batch_valid[-1].items())])
    ltemap = "MAP-Test %i | "%(e)     + ", ".join(["%s: %s"%(k,"|".join(["%0.3f"%(_v) for _v in v])) for k,v in sorted(diagnostics_batch_valid_MAP[-1].items())])

    print('\n'.join([ltr,lte,ltemap]))
    with open(logfile,'a') as fh: fh.write('\n'.join([ltr,lte,ltemap]) + '\n')

    if e % save_every_epoch == 0:
        with open(diagnosticsfile,'wb') as fh: 
            pickle.dump({'train':listdict2dictlist(diagnostics_batch_train),
                         'valid':listdict2dictlist(diagnostics_batch_valid),
                         'validMAP':listdict2dictlist(diagnostics_batch_valid_MAP)},fh)
        with open(weightsfile%(e),'wb') as fh: pickle.dump(weights,fh)
    gc.collect()
          
with open(weightsfile%(e),'wb') as fh: pickle.dump(weights,fh)

Normal prior
Normal prior
Normal prior
Normal prior
Normal prior
Normal prior


100%|██████████| 390/390 [02:27<00:00,  2.65it/s]
100%|██████████| 78/78 [00:10<00:00,  7.67it/s]
100%|██████████| 78/78 [00:06<00:00, 11.86it/s]


Train 0 | Hq_b: -0.589|-0.582|-0.562, Hq_w: -0.581|-0.576|-0.582, acc: 0.134, kl: 1911262.741, kl_b: 1657.242|1409.687|50.421, kl_w: 1249785.139|642001.854|16358.391, ll: -975944.013, logpy: -20312.643, loss: 975944.013
Test 0 | Hq_b: -0.588|-0.582|-0.559, Hq_w: -0.581|-0.574|-0.582, acc: 0.123, kl: 1908744.503, kl_b: 1672.778|1573.845|47.135, kl_w: 1250718.713|638420.562|16311.473, ll: -971397.826, logpy: -17025.578, loss: 971397.826
MAP-Test 0 | Hq_b: -0.588|-0.582|-0.559, Hq_w: -0.581|-0.574|-0.582, acc: 0.123, kl: 2142486.000, kl_b: 1763.986|1078.362|49.512, kl_w: 1404631.000|716824.625|18138.572, ll: -1085891.787, logpy: -14648.787, loss: 1085891.787


100%|██████████| 390/390 [02:30<00:00,  2.60it/s]
100%|██████████| 78/78 [00:09<00:00,  7.88it/s]
100%|██████████| 78/78 [00:06<00:00, 11.74it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 1 | Hq_b: -0.588|-0.582|-0.557, Hq_w: -0.581|-0.574|-0.582, acc: 0.325, kl: 1900153.797, kl_b: 1605.559|1371.665|43.583, kl_w: 1247041.895|633751.246|16339.850, ll: -483825.665, logpy: -8787.215, loss: 483825.665
Test 1 | Hq_b: -0.588|-0.582|-0.555, Hq_w: -0.581|-0.574|-0.582, acc: 0.507, kl: 1891170.077, kl_b: 1574.637|1585.707|41.120, kl_w: 1242958.665|628845.148|16164.805, ll: -477513.093, logpy: -4720.575, loss: 477513.093
MAP-Test 1 | Hq_b: -0.588|-0.582|-0.555, Hq_w: -0.581|-0.574|-0.582, acc: 0.515, kl: 2125617.500, kl_b: 1678.274|1596.261|44.478, kl_w: 1396863.875|707341.312|18093.328, ll: -535921.405, logpy: -4517.032, loss: 535921.405


100%|██████████| 390/390 [02:25<00:00,  2.68it/s]
100%|██████████| 78/78 [00:09<00:00,  7.89it/s]
100%|██████████| 78/78 [00:07<00:00,  9.76it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 2 | Hq_b: -0.588|-0.582|-0.554, Hq_w: -0.581|-0.574|-0.582, acc: 0.571, kl: 1885662.630, kl_b: 1586.985|1386.110|39.150, kl_w: 1240545.760|625980.381|16124.245, ll: -239677.742, logpy: -3969.913, loss: 239677.742
Test 2 | Hq_b: -0.588|-0.582|-0.553, Hq_w: -0.581|-0.574|-0.582, acc: 0.658, kl: 1879632.093, kl_b: 1529.023|1085.337|38.055, kl_w: 1237955.753|622917.696|16106.222, ll: -237803.842, logpy: -2849.830, loss: 237803.842
MAP-Test 2 | Hq_b: -0.588|-0.582|-0.553, Hq_w: -0.581|-0.574|-0.582, acc: 0.677, kl: 2115307.500, kl_b: 3648.140|430.611|41.076, kl_w: 1391805.375|701317.188|18065.275, ll: -267086.650, logpy: -2673.213, loss: 267086.650


100%|██████████| 390/390 [02:29<00:00,  2.60it/s]
100%|██████████| 78/78 [00:10<00:00,  7.15it/s]
100%|██████████| 78/78 [00:07<00:00, 11.13it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 3 | Hq_b: -0.588|-0.582|-0.552, Hq_w: -0.581|-0.574|-0.582, acc: 0.674, kl: 1875871.668, kl_b: 1567.519|1251.549|36.771, kl_w: 1235885.777|621037.551|16092.502, ll: -120047.936, logpy: -2805.957, loss: 120047.936
Test 3 | Hq_b: -0.588|-0.582|-0.551, Hq_w: -0.581|-0.574|-0.582, acc: 0.733, kl: 1872840.396, kl_b: 1513.335|1234.687|36.201, kl_w: 1234831.696|619124.861|16099.626, ll: -119101.534, logpy: -2049.009, loss: 119101.534
MAP-Test 3 | Hq_b: -0.588|-0.582|-0.551, Hq_w: -0.581|-0.574|-0.582, acc: 0.747, kl: 2107109.750, kl_b: 1736.901|1000.206|39.075, kl_w: 1388698.750|697590.500|18044.418, ll: -133646.842, logpy: -1952.483, loss: 133646.842


100%|██████████| 390/390 [02:28<00:00,  2.62it/s]
100%|██████████| 78/78 [00:09<00:00,  8.09it/s]
100%|██████████| 78/78 [00:06<00:00, 12.22it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 4 | Hq_b: -0.588|-0.582|-0.551, Hq_w: -0.581|-0.574|-0.582, acc: 0.734, kl: 1870934.792, kl_b: 1630.402|1252.027|35.456, kl_w: 1233918.179|618021.888|16076.843, ll: -60566.340, logpy: -2099.628, loss: 60566.340
Test 4 | Hq_b: -0.588|-0.582|-0.550, Hq_w: -0.581|-0.574|-0.582, acc: 0.775, kl: 1868763.038, kl_b: 1604.779|1208.249|35.432, kl_w: 1232938.776|616910.542|16065.269, ll: -60071.434, logpy: -1672.589, loss: 60071.434
MAP-Test 4 | Hq_b: -0.588|-0.582|-0.550, Hq_w: -0.581|-0.574|-0.582, acc: 0.789, kl: 2102930.000, kl_b: 1721.760|1000.428|37.926, kl_w: 1386808.000|695334.375|18027.633, ll: -67271.504, logpy: -1554.942, loss: 67271.504


100%|██████████| 390/390 [02:19<00:00,  2.81it/s]
100%|██████████| 78/78 [00:09<00:00,  7.89it/s]
100%|██████████| 78/78 [00:06<00:00, 12.14it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 5 | Hq_b: -0.588|-0.582|-0.550, Hq_w: -0.581|-0.574|-0.582, acc: 0.766, kl: 1867556.599, kl_b: 1587.553|1310.262|34.300, kl_w: 1232377.788|616182.732|16063.969, ll: -31048.671, logpy: -1868.099, loss: 31048.671
Test 5 | Hq_b: -0.588|-0.582|-0.550, Hq_w: -0.581|-0.574|-0.582, acc: 0.803, kl: 1866423.197, kl_b: 1617.358|1368.157|34.154, kl_w: 1231758.676|615585.090|16059.760, ll: -30575.677, logpy: -1412.814, loss: 30575.677
MAP-Test 5 | Hq_b: -0.588|-0.582|-0.550, Hq_w: -0.581|-0.574|-0.582, acc: 0.815, kl: 2100172.750, kl_b: 1744.076|758.294|37.262, kl_w: 1385646.625|693972.688|18013.703, ll: -34133.703, logpy: -1318.504, loss: 34133.703


100%|██████████| 390/390 [02:19<00:00,  2.79it/s]
100%|██████████| 78/78 [00:09<00:00,  7.84it/s]
100%|██████████| 78/78 [00:06<00:00, 12.20it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 6 | Hq_b: -0.588|-0.582|-0.550, Hq_w: -0.581|-0.574|-0.582, acc: 0.795, kl: 1865477.801, kl_b: 1577.474|1255.930|34.018, kl_w: 1231435.520|615132.153|16042.703, ll: -16057.406, logpy: -1483.361, loss: 16057.406
Test 6 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.821, kl: 1864672.684, kl_b: 1582.086|1304.881|34.252, kl_w: 1231046.500|614671.381|16033.589, ll: -15824.951, logpy: -1257.195, loss: 15824.951
MAP-Test 6 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.833, kl: 2098602.500, kl_b: 1722.888|746.721|36.862, kl_w: 1384939.500|693156.000|18000.539, ll: -17551.660, logpy: -1156.328, loss: 17551.660


100%|██████████| 390/390 [02:23<00:00,  2.72it/s]
100%|██████████| 78/78 [00:09<00:00,  8.31it/s]
100%|██████████| 78/78 [00:06<00:00, 12.33it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 7 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.815, kl: 1864272.433, kl_b: 1602.691|1281.914|33.246, kl_w: 1230887.168|614432.753|16034.659, ll: -8569.869, logpy: -1287.555, loss: 8569.869
Test 7 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.837, kl: 1863716.117, kl_b: 1565.584|1250.949|33.026, kl_w: 1230658.981|614183.707|16023.862, ll: -8384.442, logpy: -1104.301, loss: 8384.442
MAP-Test 7 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.849, kl: 2097677.500, kl_b: 1712.957|772.831|36.619, kl_w: 1384509.000|692657.875|17988.031, ll: -9215.943, logpy: -1021.890, loss: 9215.943


100%|██████████| 390/390 [02:23<00:00,  2.71it/s]
100%|██████████| 78/78 [00:09<00:00,  7.80it/s]
100%|██████████| 78/78 [00:06<00:00, 12.25it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 8 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.832, kl: 1863429.444, kl_b: 1619.055|1237.468|33.656, kl_w: 1230494.866|614022.013|16022.386, ll: -4784.727, logpy: -1145.216, loss: 4784.727
Test 8 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.852, kl: 1863252.005, kl_b: 1589.858|1325.645|32.701, kl_w: 1230411.638|613885.897|16006.260, ll: -4637.174, logpy: -998.010, loss: 4637.174
MAP-Test 8 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.862, kl: 2097053.500, kl_b: 1673.990|775.282|36.474, kl_w: 1384244.125|692347.875|17975.764, ll: -5015.133, logpy: -919.325, loss: 5015.133


100%|██████████| 390/390 [02:25<00:00,  2.68it/s]
100%|██████████| 78/78 [00:11<00:00,  6.97it/s]
100%|██████████| 78/78 [00:06<00:00, 11.76it/s]


Train 9 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.842, kl: 1863126.114, kl_b: 1638.549|1339.984|33.189, kl_w: 1230317.457|613777.975|16018.965, ll: -2902.080, logpy: -1082.620, loss: 2902.080
Test 9 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.857, kl: 1862935.569, kl_b: 1657.823|1267.066|34.189, kl_w: 1230240.979|613731.572|16003.934, ll: -2739.267, logpy: -919.994, loss: 2739.267
MAP-Test 9 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.870, kl: 2096669.000, kl_b: 1668.749|780.816|36.397, kl_w: 1384076.250|692142.562|17964.287, ll: -2893.921, logpy: -846.393, loss: 2893.921


100%|██████████| 390/390 [02:26<00:00,  2.66it/s]
100%|██████████| 78/78 [00:10<00:00,  7.52it/s]
100%|██████████| 78/78 [00:07<00:00, 10.72it/s]


Train 10 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.854, kl: 1862634.492, kl_b: 1574.796|1239.429|33.258, kl_w: 1230186.686|613605.952|15994.376, ll: -1896.383, logpy: -986.893, loss: 1896.383
Test 10 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.863, kl: 1862453.141, kl_b: 1555.997|1300.289|32.809, kl_w: 1230068.351|613501.040|15994.649, ll: -1779.534, logpy: -870.133, loss: 1779.534
MAP-Test 10 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.878, kl: 2096364.500, kl_b: 1629.632|777.557|36.345, kl_w: 1383966.625|692001.625|17952.602, ll: -1804.179, logpy: -780.564, loss: 1804.179


100%|██████████| 390/390 [02:22<00:00,  2.74it/s]
100%|██████████| 78/78 [00:09<00:00,  8.18it/s]
100%|██████████| 78/78 [00:06<00:00, 11.51it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 11 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.865, kl: 1862362.883, kl_b: 1554.125|1206.085|32.793, kl_w: 1230076.355|613508.090|15985.437, ll: -1289.666, logpy: -834.987, loss: 1289.666
Test 11 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.871, kl: 1862436.808, kl_b: 1519.610|1360.425|34.050, kl_w: 1230072.567|613466.483|15983.668, ll: -1240.527, logpy: -785.831, loss: 1240.527
MAP-Test 11 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.886, kl: 2096162.750, kl_b: 1618.704|777.857|36.326, kl_w: 1383891.750|691897.500|17940.598, ll: -1228.590, logpy: -716.831, loss: 1228.590


100%|██████████| 390/390 [02:39<00:00,  2.45it/s]
100%|██████████| 78/78 [00:14<00:00,  5.28it/s]
100%|██████████| 78/78 [00:06<00:00, 11.83it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 12 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.872, kl: 1862397.441, kl_b: 1613.909|1360.278|34.121, kl_w: 1229998.685|613416.517|15973.934, ll: -1004.125, logpy: -776.782, loss: 1004.125
Test 12 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.881, kl: 1861964.143, kl_b: 1572.203|1099.047|34.339, kl_w: 1229974.383|613321.535|15962.637, ll: -975.751, logpy: -748.460, loss: 975.751
MAP-Test 12 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.890, kl: 2095985.000, kl_b: 1591.418|777.934|36.313, kl_w: 1383836.875|691814.062|17928.457, ll: -935.067, logpy: -679.209, loss: 935.067


100%|██████████| 390/390 [02:34<00:00,  2.52it/s]
100%|██████████| 78/78 [00:10<00:00,  7.71it/s]
100%|██████████| 78/78 [00:07<00:00, 11.07it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 13 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.877, kl: 1862114.854, kl_b: 1583.006|1272.174|33.296, kl_w: 1229936.843|613319.494|15970.047, ll: -874.546, logpy: -760.891, loss: 874.546
Test 13 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.885, kl: 1862171.511, kl_b: 1619.522|1284.146|33.365, kl_w: 1229992.596|613285.731|15956.164, ll: -821.249, logpy: -707.591, loss: 821.249
MAP-Test 13 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.897, kl: 2095828.875, kl_b: 1571.308|776.466|36.295, kl_w: 1383791.125|691736.875|17916.793, ll: -757.132, logpy: -629.213, loss: 757.132


100%|██████████| 390/390 [02:37<00:00,  2.48it/s]
100%|██████████| 78/78 [00:10<00:00,  7.15it/s]
100%|██████████| 78/78 [00:07<00:00, 10.20it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 14 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.884, kl: 1862034.206, kl_b: 1616.578|1309.746|33.093, kl_w: 1229895.865|613243.372|15935.547, ll: -710.355, logpy: -653.530, loss: 710.355
Test 14 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.894, kl: 1861709.910, kl_b: 1531.248|1070.443|33.665, kl_w: 1229913.540|613212.765|15948.252, ll: -706.405, logpy: -649.590, loss: 706.405
MAP-Test 14 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.900, kl: 2095732.750, kl_b: 1591.371|775.684|36.298, kl_w: 1383755.375|691669.812|17904.256, ll: -655.350, logpy: -591.393, loss: 655.350


100%|██████████| 390/390 [02:36<00:00,  2.49it/s]
100%|██████████| 78/78 [00:10<00:00,  7.66it/s]
100%|██████████| 78/78 [00:07<00:00, 11.14it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 15 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.890, kl: 1861973.067, kl_b: 1568.588|1368.728|32.996, kl_w: 1229852.465|613211.538|15938.759, ll: -636.727, logpy: -608.316, loss: 636.727
Test 15 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.896, kl: 1861855.615, kl_b: 1647.504|1233.627|32.605, kl_w: 1229824.830|613197.497|15919.556, ll: -640.556, logpy: -612.147, loss: 640.556
MAP-Test 15 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.903, kl: 2095580.125, kl_b: 1557.645|759.387|36.322, kl_w: 1383723.750|691611.500|17891.549, ll: -595.376, logpy: -563.400, loss: 595.376


100%|██████████| 390/390 [02:35<00:00,  2.51it/s]
100%|██████████| 78/78 [00:09<00:00,  8.02it/s]
100%|██████████| 78/78 [00:07<00:00, 10.38it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 16 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.897, kl: 1861721.986, kl_b: 1522.473|1269.618|33.090, kl_w: 1229859.480|613111.683|15925.645, ll: -567.661, logpy: -553.457, loss: 567.661
Test 16 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.901, kl: 1861758.255, kl_b: 1686.772|1269.555|32.439, kl_w: 1229720.290|613132.992|15916.204, ll: -600.278, logpy: -586.074, loss: 600.278
MAP-Test 16 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.908, kl: 2095461.625, kl_b: 1536.819|757.334|36.309, kl_w: 1383695.250|691556.750|17879.045, ll: -547.140, logpy: -531.153, loss: 547.140


100%|██████████| 390/390 [02:30<00:00,  2.59it/s]
100%|██████████| 78/78 [00:10<00:00,  7.48it/s]
100%|██████████| 78/78 [00:06<00:00, 11.41it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 17 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.582, acc: 0.902, kl: 1861633.667, kl_b: 1561.302|1260.458|32.966, kl_w: 1229792.262|613074.045|15912.637, ll: -523.560, logpy: -516.458, loss: 523.560
Test 17 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.902, kl: 1861684.006, kl_b: 1602.811|1279.558|32.793, kl_w: 1229800.112|613071.543|15897.176, ll: -565.583, logpy: -558.481, loss: 565.583
MAP-Test 17 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.912, kl: 2095343.625, kl_b: 1514.593|759.196|36.305, kl_w: 1383665.500|691501.438|17866.576, ll: -516.501, logpy: -508.508, loss: 516.501


100%|██████████| 390/390 [02:26<00:00,  2.67it/s]
100%|██████████| 78/78 [00:09<00:00,  7.98it/s]
100%|██████████| 78/78 [00:06<00:00, 12.16it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 18 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.906, kl: 1861593.839, kl_b: 1600.818|1253.345|33.228, kl_w: 1229791.817|613014.083|15900.549, ll: -490.708, logpy: -487.157, loss: 490.708
Test 18 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.903, kl: 1861416.556, kl_b: 1529.719|1233.155|32.681, kl_w: 1229726.952|612993.944|15900.101, ll: -530.232, logpy: -526.682, loss: 530.232
MAP-Test 18 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.914, kl: 2095254.375, kl_b: 1513.400|762.178|36.316, kl_w: 1383637.125|691451.500|17853.775, ll: -484.293, logpy: -480.297, loss: 484.293


100%|██████████| 390/390 [02:23<00:00,  2.72it/s]
100%|██████████| 78/78 [00:09<00:00,  7.87it/s]
100%|██████████| 78/78 [00:06<00:00, 11.64it/s]


Train 19 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.909, kl: 1861470.547, kl_b: 1592.979|1215.008|33.470, kl_w: 1229776.992|612967.767|15884.332, ll: -464.143, logpy: -462.368, loss: 464.143
Test 19 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.908, kl: 1861237.593, kl_b: 1537.050|1178.749|32.045, kl_w: 1229723.304|612936.439|15830.003, ll: -510.975, logpy: -509.200, loss: 510.975
MAP-Test 19 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.918, kl: 2095107.750, kl_b: 1470.851|751.387|36.307, kl_w: 1383609.375|691398.938|17841.002, ll: -463.003, logpy: -461.005, loss: 463.003


100%|██████████| 390/390 [02:22<00:00,  2.74it/s]
100%|██████████| 78/78 [00:10<00:00,  7.57it/s]
100%|██████████| 78/78 [00:07<00:00, 10.80it/s]


Train 20 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.914, kl: 1861392.195, kl_b: 1574.187|1239.742|32.657, kl_w: 1229749.996|612924.517|15871.099, ll: -421.083, logpy: -420.195, loss: 421.083
Test 20 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.910, kl: 1861493.785, kl_b: 1555.779|1472.337|33.556, kl_w: 1229718.790|612847.357|15865.969, ll: -501.846, logpy: -500.958, loss: 501.846
MAP-Test 20 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.917, kl: 2095197.750, kl_b: 1644.936|754.106|36.302, kl_w: 1383581.375|691353.000|17828.041, ll: -446.077, logpy: -445.077, loss: 446.077


100%|██████████| 390/390 [02:25<00:00,  2.68it/s]
100%|██████████| 78/78 [00:09<00:00,  8.03it/s]
100%|██████████| 78/78 [00:07<00:00, 11.10it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 21 | Hq_b: -0.588|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.916, kl: 1861375.881, kl_b: 1598.015|1323.408|33.025, kl_w: 1229718.049|612841.882|15861.508, ll: -394.803, logpy: -394.359, loss: 394.803
Test 21 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.913, kl: 1861469.845, kl_b: 1618.142|1379.518|32.835, kl_w: 1229676.909|612908.011|15854.435, ll: -460.493, logpy: -460.049, loss: 460.493
MAP-Test 21 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.921, kl: 2095063.750, kl_b: 1604.283|742.986|36.343, kl_w: 1383552.375|691312.438|17815.414, ll: -427.122, logpy: -426.622, loss: 427.122


100%|██████████| 390/390 [02:28<00:00,  2.63it/s]
100%|██████████| 78/78 [00:10<00:00,  7.67it/s]
100%|██████████| 78/78 [00:06<00:00, 11.53it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 22 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.920, kl: 1861290.779, kl_b: 1575.145|1333.662|33.185, kl_w: 1229650.161|612852.539|15846.088, ll: -369.603, logpy: -369.381, loss: 369.603
Test 22 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.916, kl: 1861104.062, kl_b: 1553.182|1202.884|32.746, kl_w: 1229696.933|612771.652|15846.646, ll: -456.153, logpy: -455.931, loss: 456.153
MAP-Test 22 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.924, kl: 2094986.875, kl_b: 1612.976|741.045|36.341, kl_w: 1383523.625|691269.688|17803.125, ll: -406.010, logpy: -405.761, loss: 406.010


100%|██████████| 390/390 [02:28<00:00,  2.62it/s]
100%|██████████| 78/78 [00:09<00:00,  8.11it/s]
100%|██████████| 78/78 [00:06<00:00, 12.08it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 23 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.923, kl: 1861078.250, kl_b: 1530.987|1237.262|33.379, kl_w: 1229653.283|612787.264|15836.073, ll: -344.827, logpy: -344.716, loss: 344.827
Test 23 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.916, kl: 1861012.178, kl_b: 1535.759|1282.348|32.988, kl_w: 1229585.837|612734.764|15840.487, ll: -436.908, logpy: -436.797, loss: 436.908
MAP-Test 23 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.924, kl: 2094932.500, kl_b: 1621.148|755.061|36.319, kl_w: 1383495.375|691233.562|17791.033, ll: -399.939, logpy: -399.814, loss: 399.939


100%|██████████| 390/390 [02:22<00:00,  2.73it/s]
100%|██████████| 78/78 [00:09<00:00,  8.23it/s]
100%|██████████| 78/78 [00:06<00:00, 12.55it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 24 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.928, kl: 1861171.423, kl_b: 1557.171|1350.774|33.547, kl_w: 1229641.851|612770.921|15817.160, ll: -321.479, logpy: -321.424, loss: 321.479
Test 24 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.919, kl: 1860791.324, kl_b: 1486.671|1171.796|34.175, kl_w: 1229634.542|612654.889|15809.259, ll: -419.765, logpy: -419.709, loss: 419.765
MAP-Test 24 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.926, kl: 2094820.000, kl_b: 1604.374|734.763|36.320, kl_w: 1383465.250|691199.938|17779.484, ll: -383.574, logpy: -383.511, loss: 383.574


100%|██████████| 390/390 [02:17<00:00,  2.84it/s]
100%|██████████| 78/78 [00:09<00:00,  8.26it/s]
100%|██████████| 78/78 [00:06<00:00, 12.58it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 25 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.931, kl: 1860917.806, kl_b: 1528.176|1198.542|33.021, kl_w: 1229601.396|612745.826|15810.846, ll: -296.063, logpy: -296.035, loss: 296.063
Test 25 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.922, kl: 1861046.011, kl_b: 1524.597|1396.998|31.623, kl_w: 1229580.986|612708.227|15803.582, ll: -406.684, logpy: -406.656, loss: 406.684
MAP-Test 25 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.928, kl: 2094773.250, kl_b: 1603.365|763.293|36.307, kl_w: 1383436.625|691165.812|17767.816, ll: -369.347, logpy: -369.316, loss: 369.347


100%|██████████| 390/390 [02:18<00:00,  2.81it/s]
100%|██████████| 78/78 [00:09<00:00,  8.15it/s]
100%|██████████| 78/78 [00:06<00:00, 12.54it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 26 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.933, kl: 1860870.762, kl_b: 1486.197|1302.617|33.149, kl_w: 1229556.313|612690.962|15801.526, ll: -284.832, logpy: -284.818, loss: 284.832
Test 26 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.926, kl: 1860930.944, kl_b: 1526.305|1372.283|32.638, kl_w: 1229533.231|612671.863|15794.628, ll: -384.933, logpy: -384.920, loss: 384.933
MAP-Test 26 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.932, kl: 2094408.375, kl_b: 1310.735|764.480|36.297, kl_w: 1383406.750|691132.875|17757.182, ll: -351.753, logpy: -351.737, loss: 351.753


100%|██████████| 390/390 [02:23<00:00,  2.72it/s]
100%|██████████| 78/78 [00:09<00:00,  8.36it/s]
100%|██████████| 78/78 [00:06<00:00, 12.87it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 27 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.934, kl: 1860782.082, kl_b: 1487.365|1325.351|32.951, kl_w: 1229502.409|612671.929|15762.077, ll: -322.864, logpy: -322.857, loss: 322.864
Test 27 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.924, kl: 1860987.361, kl_b: 1569.762|1363.636|33.731, kl_w: 1229541.899|612691.196|15787.136, ll: -387.360, logpy: -387.353, loss: 387.360
MAP-Test 27 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.932, kl: 2094343.125, kl_b: 1338.519|737.155|36.201, kl_w: 1383370.250|691109.188|17751.945, ll: -345.224, logpy: -345.216, loss: 345.224


100%|██████████| 390/390 [02:19<00:00,  2.80it/s]
100%|██████████| 78/78 [00:09<00:00,  8.13it/s]
100%|██████████| 78/78 [00:06<00:00, 12.80it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 28 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.937, kl: 1860755.387, kl_b: 1537.449|1259.916|33.322, kl_w: 1229478.406|612657.464|15788.830, ll: -256.688, logpy: -256.685, loss: 256.688
Test 28 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.928, kl: 1860808.952, kl_b: 1413.036|1434.856|34.058, kl_w: 1229517.712|612626.095|15783.195, ll: -378.207, logpy: -378.203, loss: 378.207
MAP-Test 28 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.934, kl: 2094270.125, kl_b: 1295.954|762.064|36.251, kl_w: 1383343.375|691085.188|17747.301, ll: -339.491, logpy: -339.487, loss: 339.491


100%|██████████| 390/390 [02:17<00:00,  2.83it/s]
100%|██████████| 78/78 [00:09<00:00,  8.23it/s]
100%|██████████| 78/78 [00:06<00:00, 12.53it/s]


Train 29 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.940, kl: 1860631.081, kl_b: 1492.573|1262.589|33.051, kl_w: 1229454.541|612607.527|15780.796, ll: -237.869, logpy: -237.868, loss: 237.869
Test 29 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.929, kl: 1860698.038, kl_b: 1464.200|1345.479|33.652, kl_w: 1229470.346|612605.214|15779.159, ll: -351.747, logpy: -351.745, loss: 351.747
MAP-Test 29 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.934, kl: 2094108.250, kl_b: 1283.333|664.211|36.261, kl_w: 1383315.125|691066.875|17742.490, ll: -325.017, logpy: -325.015, loss: 325.017


100%|██████████| 390/390 [02:17<00:00,  2.84it/s]
100%|██████████| 78/78 [00:09<00:00,  8.38it/s]
100%|██████████| 78/78 [00:06<00:00, 12.85it/s]


Train 30 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.942, kl: 1860727.973, kl_b: 1517.744|1340.290|33.051, kl_w: 1229444.272|612611.467|15781.147, ll: -224.612, logpy: -224.611, loss: 224.612
Test 30 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.927, kl: 1860829.337, kl_b: 1513.347|1364.273|32.720, kl_w: 1229520.064|612616.580|15782.349, ll: -369.477, logpy: -369.476, loss: 369.477
MAP-Test 30 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.936, kl: 2093930.875, kl_b: 1177.190|649.960|36.255, kl_w: 1383284.500|691044.812|17738.074, ll: -314.465, logpy: -314.464, loss: 314.465


100%|██████████| 390/390 [02:16<00:00,  2.85it/s]
100%|██████████| 78/78 [00:09<00:00,  8.40it/s]
100%|██████████| 78/78 [00:06<00:00, 12.77it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 31 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.944, kl: 1860686.581, kl_b: 1504.115|1355.679|33.049, kl_w: 1229412.339|612605.188|15776.214, ll: -207.951, logpy: -207.951, loss: 207.951
Test 31 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.932, kl: 1860741.825, kl_b: 1521.961|1397.792|32.174, kl_w: 1229374.188|612642.845|15772.867, ll: -343.303, logpy: -343.302, loss: 343.303
MAP-Test 31 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.937, kl: 2093908.125, kl_b: 1189.320|661.372|36.279, kl_w: 1383257.750|691029.562|17733.848, ll: -310.544, logpy: -310.543, loss: 310.544


100%|██████████| 390/390 [02:18<00:00,  2.82it/s]
100%|██████████| 78/78 [00:09<00:00,  7.84it/s]
100%|██████████| 78/78 [00:06<00:00, 12.88it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 32 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.947, kl: 1860578.871, kl_b: 1480.620|1310.275|32.828, kl_w: 1229403.659|612582.125|15769.362, ll: -197.068, logpy: -197.068, loss: 197.068
Test 32 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.931, kl: 1860716.188, kl_b: 1469.046|1484.913|32.994, kl_w: 1229393.115|612564.684|15771.431, ll: -344.186, logpy: -344.186, loss: 344.186
MAP-Test 32 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.939, kl: 2093806.000, kl_b: 1163.643|627.587|36.245, kl_w: 1383235.625|691013.375|17729.535, ll: -301.681, logpy: -301.681, loss: 301.681


100%|██████████| 390/390 [02:17<00:00,  2.83it/s]
100%|██████████| 78/78 [00:09<00:00,  8.40it/s]
100%|██████████| 78/78 [00:06<00:00, 12.83it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 33 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.950, kl: 1860510.818, kl_b: 1438.724|1354.011|33.020, kl_w: 1229373.949|612549.508|15761.605, ll: -178.105, logpy: -178.105, loss: 178.105
Test 33 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.931, kl: 1860449.260, kl_b: 1497.019|1264.531|34.128, kl_w: 1229302.139|612581.459|15769.983, ll: -338.721, logpy: -338.720, loss: 338.721
MAP-Test 33 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.940, kl: 2093880.875, kl_b: 1176.334|734.694|36.228, kl_w: 1383210.250|690997.812|17725.514, ll: -297.688, logpy: -297.688, loss: 297.688


100%|██████████| 390/390 [02:17<00:00,  2.85it/s]
100%|██████████| 78/78 [00:10<00:00,  7.59it/s]
100%|██████████| 78/78 [00:06<00:00, 12.90it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 34 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.951, kl: 1860448.458, kl_b: 1441.192|1302.311|33.056, kl_w: 1229343.154|612565.960|15762.786, ll: -169.280, logpy: -169.280, loss: 169.280
Test 34 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.935, kl: 1860399.101, kl_b: 1502.678|1197.935|33.400, kl_w: 1229307.548|612594.553|15762.995, ll: -312.072, logpy: -312.072, loss: 312.072
MAP-Test 34 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.942, kl: 2093821.875, kl_b: 1201.134|683.750|36.206, kl_w: 1383184.625|690994.125|17722.047, ll: -289.676, logpy: -289.676, loss: 289.676


100%|██████████| 390/390 [02:19<00:00,  2.80it/s]
100%|██████████| 78/78 [00:10<00:00,  7.22it/s]
100%|██████████| 78/78 [00:06<00:00, 11.88it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 35 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.953, kl: 1860408.104, kl_b: 1461.810|1276.072|33.282, kl_w: 1229332.204|612544.531|15760.208, ll: -159.837, logpy: -159.837, loss: 159.837
Test 35 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.936, kl: 1860505.367, kl_b: 1404.628|1361.711|33.573, kl_w: 1229341.801|612610.575|15753.084, ll: -312.339, logpy: -312.339, loss: 312.339
MAP-Test 35 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.943, kl: 2093931.500, kl_b: 1209.070|817.450|36.090, kl_w: 1383161.250|690989.500|17718.188, ll: -283.824, logpy: -283.824, loss: 283.824


100%|██████████| 390/390 [02:18<00:00,  2.82it/s]
100%|██████████| 78/78 [00:10<00:00,  7.47it/s]
100%|██████████| 78/78 [00:06<00:00, 11.51it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 36 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.583, acc: 0.954, kl: 1860343.859, kl_b: 1425.964|1315.584|32.560, kl_w: 1229291.710|612527.266|15750.775, ll: -147.443, logpy: -147.443, loss: 147.443
Test 36 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.936, kl: 1860420.867, kl_b: 1464.308|1410.101|33.283, kl_w: 1229260.708|612499.500|15752.965, ll: -317.141, logpy: -317.141, loss: 317.141
MAP-Test 36 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.944, kl: 2093619.125, kl_b: 1200.885|547.356|36.003, kl_w: 1383137.125|690982.938|17714.928, ll: -278.007, logpy: -278.007, loss: 278.007


100%|██████████| 390/390 [02:32<00:00,  2.56it/s]
100%|██████████| 78/78 [00:11<00:00,  6.80it/s]
100%|██████████| 78/78 [00:07<00:00, 10.16it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 37 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.957, kl: 1860274.678, kl_b: 1423.905|1281.906|33.047, kl_w: 1229253.131|612533.047|15749.635, ll: -135.582, logpy: -135.582, loss: 135.582
Test 37 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.939, kl: 1860096.361, kl_b: 1459.565|1078.321|33.582, kl_w: 1229276.155|612498.587|15750.152, ll: -293.115, logpy: -293.115, loss: 293.115
MAP-Test 37 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.945, kl: 2093825.125, kl_b: 1164.610|808.036|35.970, kl_w: 1383121.250|690983.312|17711.963, ll: -268.886, logpy: -268.886, loss: 268.886


100%|██████████| 390/390 [02:33<00:00,  2.55it/s]
100%|██████████| 78/78 [00:10<00:00,  7.27it/s]
100%|██████████| 78/78 [00:07<00:00, 10.31it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 38 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.583, acc: 0.958, kl: 1860200.104, kl_b: 1431.964|1234.303|32.201, kl_w: 1229206.412|612548.530|15746.696, ll: -132.464, logpy: -132.464, loss: 132.464
Test 38 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.584, acc: 0.939, kl: 1860235.769, kl_b: 1437.305|1288.991|33.157, kl_w: 1229222.356|612489.688|15764.278, ll: -294.795, logpy: -294.795, loss: 294.795
MAP-Test 38 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.584, acc: 0.946, kl: 2093963.625, kl_b: 1176.767|952.830|35.883, kl_w: 1383103.625|690985.562|17708.936, ll: -263.490, logpy: -263.490, loss: 263.490


100%|██████████| 390/390 [02:28<00:00,  2.63it/s]
100%|██████████| 78/78 [00:11<00:00,  6.65it/s]
100%|██████████| 78/78 [00:07<00:00,  9.94it/s]


Train 39 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.584, acc: 0.960, kl: 1860287.919, kl_b: 1456.513|1287.628|32.168, kl_w: 1229233.870|612533.589|15744.149, ll: -119.330, logpy: -119.330, loss: 119.330
Test 39 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.584, acc: 0.940, kl: 1860329.210, kl_b: 1348.769|1293.856|33.286, kl_w: 1229307.566|612605.648|15740.089, ll: -280.542, logpy: -280.542, loss: 280.542
MAP-Test 39 | Hq_b: -0.589|-0.582|-0.549, Hq_w: -0.581|-0.574|-0.584, acc: 0.947, kl: 2093924.000, kl_b: 1164.600|939.979|35.934, kl_w: 1383091.250|690986.438|17705.844, ll: -258.576, logpy: -258.576, loss: 258.576


100%|██████████| 390/390 [02:26<00:00,  2.66it/s]
100%|██████████| 78/78 [00:10<00:00,  7.33it/s]
100%|██████████| 78/78 [00:06<00:00, 11.45it/s]


Train 40 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.962, kl: 1860150.177, kl_b: 1429.720|1154.413|33.100, kl_w: 1229262.676|612528.212|15742.053, ll: -111.165, logpy: -111.165, loss: 111.165
Test 40 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.939, kl: 1859896.386, kl_b: 1404.742|993.415|32.540, kl_w: 1229242.284|612480.768|15742.629, ll: -297.723, logpy: -297.723, loss: 297.723
MAP-Test 40 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.946, kl: 2094308.250, kl_b: 1159.932|1342.487|35.921, kl_w: 1383072.125|690994.875|17702.941, ll: -258.473, logpy: -258.473, loss: 258.473


100%|██████████| 390/390 [02:29<00:00,  2.61it/s]
100%|██████████| 78/78 [00:09<00:00,  8.03it/s]
100%|██████████| 78/78 [00:06<00:00, 12.51it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 41 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.964, kl: 1860147.473, kl_b: 1429.631|1217.157|32.751, kl_w: 1229193.049|612536.268|15738.616, ll: -102.215, logpy: -102.215, loss: 102.215
Test 41 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.941, kl: 1860152.801, kl_b: 1398.977|1244.291|32.873, kl_w: 1229227.601|612504.353|15744.729, ll: -271.125, logpy: -271.125, loss: 271.125
MAP-Test 41 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.949, kl: 2094292.250, kl_b: 1160.148|1330.045|35.860, kl_w: 1383064.875|691000.250|17701.104, ll: -248.482, logpy: -248.482, loss: 248.482


100%|██████████| 390/390 [02:21<00:00,  2.76it/s]
100%|██████████| 78/78 [00:10<00:00,  7.20it/s]
100%|██████████| 78/78 [00:06<00:00, 12.38it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 42 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.963, kl: 1860092.011, kl_b: 1433.782|1191.397|32.770, kl_w: 1229165.098|612541.570|15727.388, ll: -147.825, logpy: -147.825, loss: 147.825
Test 42 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.944, kl: 1860226.471, kl_b: 1467.230|1202.090|32.170, kl_w: 1229185.226|612606.829|15732.913, ll: -275.531, logpy: -275.531, loss: 275.531
MAP-Test 42 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.948, kl: 2094579.125, kl_b: 1277.676|1519.473|35.787, kl_w: 1383046.500|690999.250|17700.502, ll: -248.558, logpy: -248.558, loss: 248.558


100%|██████████| 390/390 [02:21<00:00,  2.76it/s]
100%|██████████| 78/78 [00:09<00:00,  7.91it/s]
100%|██████████| 78/78 [00:06<00:00, 11.75it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 43 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.967, kl: 1860175.274, kl_b: 1420.090|1243.847|32.415, kl_w: 1229181.354|612560.593|15736.980, ll: -89.882, logpy: -89.882, loss: 89.882
Test 43 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.942, kl: 1860183.561, kl_b: 1440.910|1190.782|32.524, kl_w: 1229159.939|612618.690|15740.703, ll: -277.785, logpy: -277.785, loss: 277.785
MAP-Test 43 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.948, kl: 2094813.375, kl_b: 1345.664|1687.937|35.795, kl_w: 1383038.625|691006.938|17698.443, ll: -248.592, logpy: -248.592, loss: 248.592


100%|██████████| 390/390 [02:26<00:00,  2.66it/s]
100%|██████████| 78/78 [00:10<00:00,  7.14it/s]
100%|██████████| 78/78 [00:06<00:00, 12.03it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 44 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.968, kl: 1860230.149, kl_b: 1422.547|1311.292|33.323, kl_w: 1229140.458|612585.554|15736.965, ll: -82.938, logpy: -82.938, loss: 82.938
Test 44 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.943, kl: 1860315.152, kl_b: 1461.573|1231.637|32.126, kl_w: 1229277.865|612583.703|15728.249, ll: -264.196, logpy: -264.196, loss: 264.196
MAP-Test 44 | Hq_b: -0.589|-0.582|-0.548, Hq_w: -0.581|-0.574|-0.584, acc: 0.950, kl: 2094891.375, kl_b: 1508.628|1592.539|35.954, kl_w: 1383035.875|691021.938|17696.443, ll: -243.614, logpy: -243.614, loss: 243.614


100%|██████████| 390/390 [02:34<00:00,  2.53it/s]
100%|██████████| 78/78 [00:10<00:00,  7.63it/s]
100%|██████████| 78/78 [00:06<00:00, 12.16it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 45 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.971, kl: 1860226.492, kl_b: 1423.377|1292.405|33.036, kl_w: 1229156.080|612587.576|15734.019, ll: -72.251, logpy: -72.251, loss: 72.251
Test 45 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.943, kl: 1860011.750, kl_b: 1440.798|1122.997|31.945, kl_w: 1229152.069|612523.925|15740.023, ll: -265.577, logpy: -265.577, loss: 265.577
MAP-Test 45 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.950, kl: 2094918.875, kl_b: 1528.992|1585.409|35.921, kl_w: 1383034.625|691039.125|17694.705, ll: -239.984, logpy: -239.984, loss: 239.984


100%|██████████| 390/390 [02:28<00:00,  2.63it/s]
100%|██████████| 78/78 [00:10<00:00,  7.69it/s]
100%|██████████| 78/78 [00:06<00:00, 11.46it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 46 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.971, kl: 1860175.524, kl_b: 1438.122|1193.958|32.329, kl_w: 1229174.112|612603.006|15733.997, ll: -70.273, logpy: -70.273, loss: 70.273
Test 46 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.946, kl: 1860015.381, kl_b: 1434.649|1054.559|32.436, kl_w: 1229144.165|612610.393|15739.172, ll: -248.705, logpy: -248.705, loss: 248.705
MAP-Test 46 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.951, kl: 2094543.375, kl_b: 1564.269|1162.814|35.929, kl_w: 1383036.375|691051.125|17692.990, ll: -238.977, logpy: -238.977, loss: 238.977


100%|██████████| 390/390 [02:22<00:00,  2.74it/s]
100%|██████████| 78/78 [00:09<00:00,  8.25it/s]
100%|██████████| 78/78 [00:06<00:00, 12.59it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 47 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.973, kl: 1860152.108, kl_b: 1424.690|1186.656|32.835, kl_w: 1229199.056|612577.528|15731.344, ll: -65.518, logpy: -65.518, loss: 65.518
Test 47 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.941, kl: 1860205.798, kl_b: 1420.300|1256.430|32.010, kl_w: 1229146.976|612615.562|15734.511, ll: -273.447, logpy: -273.447, loss: 273.447
MAP-Test 47 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.950, kl: 2095077.250, kl_b: 1548.050|1694.633|35.937, kl_w: 1383039.500|691067.562|17691.568, ll: -232.559, logpy: -232.559, loss: 232.559


100%|██████████| 390/390 [02:18<00:00,  2.82it/s]
100%|██████████| 78/78 [00:09<00:00,  8.23it/s]
100%|██████████| 78/78 [00:06<00:00, 12.56it/s]
  0%|          | 0/390 [00:00<?, ?it/s]

Train 48 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.974, kl: 1860347.402, kl_b: 1430.255|1350.596|32.671, kl_w: 1229189.573|612613.567|15730.742, ll: -59.826, logpy: -59.826, loss: 59.826
Test 48 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.943, kl: 1860317.790, kl_b: 1453.179|1372.864|33.704, kl_w: 1229139.981|612587.651|15730.405, ll: -255.677, logpy: -255.677, loss: 255.677
MAP-Test 48 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.951, kl: 2095241.875, kl_b: 1620.884|1750.633|36.077, kl_w: 1383051.500|691092.938|17689.854, ll: -231.562, logpy: -231.562, loss: 231.562


100%|██████████| 390/390 [02:19<00:00,  2.80it/s]
100%|██████████| 78/78 [00:09<00:00,  8.21it/s]
100%|██████████| 78/78 [00:06<00:00, 12.41it/s]


Train 49 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.584, acc: 0.975, kl: 1860423.620, kl_b: 1423.501|1403.653|32.927, kl_w: 1229211.198|612623.000|15729.341, ll: -51.688, logpy: -51.688, loss: 51.688
Test 49 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.585, acc: 0.943, kl: 1860509.122, kl_b: 1440.766|1460.202|32.988, kl_w: 1229178.018|612669.321|15727.828, ll: -259.081, logpy: -259.081, loss: 259.081
MAP-Test 49 | Hq_b: -0.589|-0.582|-0.547, Hq_w: -0.581|-0.574|-0.585, acc: 0.952, kl: 2095312.375, kl_b: 1684.968|1725.372|36.120, kl_w: 1383063.375|691113.188|17689.283, ll: -230.659, logpy: -230.659, loss: 230.659


### Plotting results of $\beta$ experiments

In [None]:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import rc
plt.style.use("seaborn")
import seaborn as sns
import re
import numpy as np
plt.rc('text', usetex=True)
plt.rc('font', family='serif', size=22)
plt.rcParams.update({'legend.fontsize': 16, 'xtick.labelsize': 16,
'ytick.labelsize': 16, 'axes.labelsize': 16})
def load_data(basename):
    files_no_transfer = [open("results/{0}_beta_{1}/diagnostics.txt".format(basename, beta_type)).read() for beta_type in ["Blundell", "Standard", "None"] ]
    
    accuracies = [list(map(lambda x: x.split(" ")[-1], re.findall(r"(\'acc\': \d.\d+)", file))) for file in files_no_transfer]
    valid = [acc[1::2] for acc in accuracies]
    return np.array(valid).astype(np.float32)

f = plt.figure(figsize=(10, 8))

current_palette = sns.color_palette()
vt = load_data("[0,1,2,3,4,5,6,7,8,9]")
x_ticks = range(vt.shape[1])
colors  = sns.color_palette(n_colors = 4)
legends  = [r"$\frac{2^{M-i}}{2^M-1}$", r"$\frac{1}{M}$", "0"]
for values, legend_name, color in zip(vt,legends, colors):
    plt.plot(x_ticks, values, label=r"Accuracy with transfer, with $\beta$= {0}".format(legend_name), color=color)

plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.xticks(x_ticks[::5], map(lambda x: x+1, x_ticks[::5]))
f.suptitle("Accuracy after training for 50 epochs")
plt.legend(loc = 7)

plt.savefig("figs/comparing_beta.pdf")

### Plotting results of Normalizing Flows experiments

In [1]:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import rc
plt.style.use("seaborn")
import seaborn as sns
import re
import numpy as np

plt.rc('font', family='serif', size=22)
plt.rcParams.update({'legend.fontsize': 16, 'xtick.labelsize': 16,
'ytick.labelsize': 16, 'axes.labelsize': 16})
def load_data(basename):
    normflow = open("norm_flow/{}.txt".format(basename)).read()
    
    accuracies = list(map(lambda x: x.split(" ")[-1], re.findall(r"(acc: \d.\d+)", normflow)))
    valid = accuracies[1::3]
    if basename=="beta_blundell":
        valid=valid[:50]
    return np.array(valid).astype(np.float32)

f = plt.figure(figsize=(10, 8))

current_palette = sns.color_palette()

colors = sns.color_palette(n_colors = 2)
basenames = ["beta_blundell", "logfile"]
legends = ["Without Normalizing Flows", "With Normalizing Flows"]
for basename, color, legend in zip(basenames, colors, legends):
    vt = load_data(basename)
    x_ticks = range(vt.shape[0])
    plt.plot(x_ticks, vt, label=legend, color=color)

plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.xticks(x_ticks[5::5], map(lambda x: x, x_ticks[5::5]))
f.suptitle("With and without Normalizing Flows")
plt.legend(loc = 7)

plt.savefig("figs/normflow.pdf")

### Plotting results of transfer learning experiments