In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

import joblib
import random
import sys

from mip import Model as ModelMip, xsum, minimize, maximize, \
INTEGER, BINARY, CONTINUOUS, CutType, OptimizationStatus

import data_generator
from model import StrongStandardNet, StrongVariationalNet
from train import TrainDecoupled, TrainCombined

from sklearn.preprocessing import StandardScaler

import torch.multiprocessing as mp

In [2]:
import cvar_minimization_utils as cmu

In [3]:
# Prefer the GPU whenever CUDA is usable; otherwise everything stays on the CPU.
is_cuda = torch.cuda.is_available()
dev = torch.device('cuda' if is_cuda else 'cpu')

In [4]:
# Model family ('ann', 'bnn' or 'gp') and training scheme ('decoupled' or
# 'combined'); both are validated in the next cell.
method_name, method_learning = 'bnn', 'combined'

# Integer flag; the next cell only checks that it equals True or False.
aleat_bool = 1

# Sample count handed to the model / optimisation layer, and mini-batch size
# used by the training and validation DataLoaders.
N_SAMPLES, BATCH_SIZE_LOADER = 32, 64

EPOCHS = 100
lr = 0.0005

# Seed shared by numpy, torch, random and the synthetic data generators.
seed_number = 0

# Number of worker processes given to every DataLoader.
cpu_count = mp.cpu_count()

In [5]:
# Seed every random number generator the notebook relies on.
random.seed(seed_number)
np.random.seed(seed_number)
torch.manual_seed(seed_number)

# Sanity-check the configuration before anything expensive runs.
assert method_name in ('ann', 'bnn', 'gp')

if method_name in ('ann', 'bnn'):
    assert method_learning in ('decoupled', 'combined')
    assert aleat_bool in (True, False)
    assert 1 <= N_SAMPLES < 9999
    #assert (M_SAMPLES>=1 and M_SAMPLES<9999)

# The ELBO hyper-parameters only exist for the Bayesian network.
bnn = method_name == 'bnn'
if bnn:
    K = 1 # Hyperparameter for the training in ELBO loss
    PLV = 1 # Prior in ELBO loss

# Name = method + every command-line argument from index 2 onwards + seed.
# NOTE(review): inside a Jupyter kernel sys.argv usually holds the kernel
# launcher arguments, so entries from index 2 may be a connection-file path
# rather than experiment options — confirm this naming is intended here.
model_name = f'{method_name}_constrained_'
for i in range(2, len(sys.argv)):
    model_name = f'{model_name}_{sys.argv[i]}'
model_name = f'{model_name}_{seed_number}'

In [6]:
def gen_intermidiate(num, n_assets, x1, x2, x3):
    """Return x1**p + x2**p + x3**p with exponent p = 2*num / n_assets.

    Works element-wise when x1, x2 and x3 are arrays of the same shape.
    """
    exponent = (num * 2) / n_assets
    powered = [feature**exponent for feature in (x1, x2, x3)]
    return powered[0] + powered[1] + powered[2]
    
def gen_data(N, n_assets, nl, seed_number=42):
    """Generate a synthetic (features, asset returns) dataset.

    Three features are drawn from N(1, 1) and clipped at zero. For asset i
    (1-based) the return is sin(gen_intermidiate(i, ...)), centred, shifted by
    a random per-asset offset and perturbed by multiplicative noise that
    scales with `nl` and the first feature.

    Parameters
    ----------
    N : int
        Number of observations.
    n_assets : int
        Number of return columns to generate.
    nl : float
        Noise level multiplying the feature-dependent noise term.
    seed_number : int
        Seed applied to the *global* NumPy RNG (np.random.seed), so calling
        this function resets the global random state.

    Returns
    -------
    X : ndarray of shape (N, 3)
    Y : ndarray of shape (N, n_assets)
    """
    np.random.seed(seed_number)
    x1 = np.random.normal(1, 1, size=N).clip(0)
    x2 = np.random.normal(1, 1, size=N).clip(0)
    x3 = np.random.normal(1, 1, size=N).clip(0)
    X = np.vstack((x1, x2, x3)).T

    Y = np.zeros((N, n_assets))
    for i in range(1, n_assets + 1):
        # Evaluate the sine signal once; it is reused for mean, std and values.
        signal = np.sin(gen_intermidiate(i, n_assets, x1, x2, x3))

        # Centre the signal and add one scalar offset per asset, drawn
        # uniformly from [-0.2, 0.8) times the signal's standard deviation.
        base = (signal - signal.mean()) + signal.std()*(-0.2 + np.random.random())

        # Multiplicative noise: Beta(5, 2) draws shifted by -0.2, scaled by
        # nl and by the first feature. Drawn after the offset so the random
        # stream is consumed in the same order as before.
        noise = np.random.beta(5, 2, size=N) - 0.2
        Y[:, i-1] = base + nl*base*x1*noise

    return X, Y

def gen_cond_dist(N, n_assets, n_samples, nl, seed_number=420):
    """Stack `n_samples` independently seeded target matrices from gen_data.

    Every call to gen_data redraws the features as well; only the targets
    (second return value) are kept.

    Returns an array of shape (n_samples, N, n_assets).
    """
    np.random.seed(seed_number)
    draws = np.zeros((n_samples, N, n_assets))
    for k in range(n_samples):
        # The sub-seed must be drawn right before the call: gen_data re-seeds
        # the global RNG with it, and the next sub-seed is drawn from the
        # state gen_data leaves behind.
        sub_seed = np.random.randint(0, 999999)
        _, targets = gen_data(N, n_assets, nl, seed_number=sub_seed)
        draws[k, :, :] = targets
    return draws

In [7]:
class CVaROP():
    """Scenario-based CVaR portfolio problem solved as a linear program.

    For a set of return scenarios y_1..y_n (one row per scenario, one column
    per asset) the allocation problem solved by `minimize_cvar` is

        min   alpha + 1 / ((1 - beta) * n) * sum_i u_i
        s.t.  z >= 0,  u >= 0
              z . y_i + alpha + u_i >= 0      for every scenario i
              z . uy >= R

    where `uy` is the column mean of the training returns and `R` the minimum
    required return.

    NOTE(review): python-mip's add_var defaults to a lower bound of 0, so
    `alpha` is implicitly constrained to be non-negative in both LPs below —
    confirm this is intended.
    NOTE(review): the status returned by `optimize()` is not inspected; the
    solution values are read back as-is.
    """

    def __init__(self, beta, n_assets, min_return, Y_train):
        """Store the problem parameters and the mean training return per asset.

        Parameters
        ----------
        beta : float
            CVaR level used in the 1 / (1 - beta) tail weight.
        n_assets : int
            Number of assets; checked against the scenario matrices.
        min_return : float
            Right-hand side R of the minimum-return constraint.
        Y_train : 2-D array
            Training returns; only used for its column means and its shape.
        """
        self.beta = beta
        self.n_assets = n_assets
        self.R = min_return
        self.uy = Y_train.mean(axis=0)

        # Zero-filled placeholders shaped like the training targets. The
        # per-observation solve that used to fill them is disabled, so they
        # stay at zero.
        self.zstar = np.zeros_like(Y_train)
        self.alphastar = np.zeros_like(Y_train[:,0])
        self.cvarstar = np.zeros_like(Y_train[:,0])

    def minimize_cvar(self, y):
        """Solve the CVaR allocation LP for one scenario matrix.

        Parameters
        ----------
        y : 2-D array of shape (n_scenarios, n_assets)

        Returns
        -------
        zstar : list of float
            Allocation per asset.
        ustar : list of float
            Auxiliary excess-loss variable per scenario.
        alphastar : float
            Value of the alpha variable at the solution.
        f_opt : float
            Objective value reported by the solver.
        """
        n_samples = y.shape[0]
        n_assets = y.shape[1]

        assert self.n_assets == n_assets

        m = ModelMip("cvar")
        m.verbose = 0
        z = [m.add_var(var_type=CONTINUOUS, name=f'z_{i}') for i in range(n_assets)]
        u = [m.add_var(var_type=CONTINUOUS, name=f'u_{i}') for i in range(n_samples)]
        alpha = m.add_var(var_type=CONTINUOUS, name='alpha')

        m.objective = minimize(alpha + (1/(1 - self.beta))*(1/n_samples) \
                               *xsum(u[i] for i in range(n_samples)))

        for i in range(n_assets):
            m += z[i] >= 0

        for i in range(n_samples):
            m += u[i] >= 0

        for i in range(n_samples):
            m += xsum(z[j]*y[i][j] for j in range(n_assets)) + alpha + u[i] >= 0

        # Minimum expected return. The constraint does not depend on a loop
        # index, so it is added exactly once.
        m += xsum(-z[j]*self.uy[j] for j in range(n_assets)) <= -self.R

        m.optimize()
        f_opt = m.objective_value

        # Read the solution straight from the variable handles.
        zstar = [var.x for var in z]
        ustar = [var.x for var in u]
        alphastar = alpha.x

        return zstar, ustar, alphastar, f_opt

    def minimize_cvar_dataset(self, Y_dist):
        """Run `minimize_cvar` for every observation of a scenario tensor.

        Parameters
        ----------
        Y_dist : 3-D array of shape (n_scenarios, n_obs, n_assets)

        Returns
        -------
        zstar : array shaped like Y_dist[0], one allocation per observation.
        """
        zstar = np.zeros_like(Y_dist[0])
        for i in range(Y_dist.shape[1]):
            zstar[i,:], _, _, _ = self.minimize_cvar(Y_dist[:,i,:])
        return zstar

    def minimize_cvar_given_z(self, y, z):
        """Solve the CVaR LP over (u, alpha) only, for a fixed allocation z.

        Parameters
        ----------
        y : 2-D array of shape (n_scenarios, n_assets)
        z : sequence of n_assets numbers (the fixed allocation)

        Returns
        -------
        ustar : list of float
        alphastar : float
        f_opt : float
            Objective value reported by the solver.
        """
        n_samples = y.shape[0]
        n_assets = y.shape[1]

        assert self.n_assets == n_assets

        m = ModelMip("cvar")
        m.verbose = 0

        u = [m.add_var(var_type=CONTINUOUS, name=f'u_{i}') for i in range(n_samples)]
        alpha = m.add_var(var_type=CONTINUOUS, name='alpha')

        m.objective = minimize(alpha + (1/(1 - self.beta))*(1/n_samples) \
                               *xsum(u[i] for i in range(n_samples)))

        for i in range(n_samples):
            m += u[i] >= 0

        for i in range(n_samples):
            m += xsum(z[j]*y[i][j] for j in range(n_assets)) + alpha + u[i] >= 0

        m.optimize()
        f_opt = m.objective_value

        ustar = [var.x for var in u]
        alphastar = alpha.x

        return ustar, alphastar, f_opt

    def minimize_cvar_given_z_dataset(self, Y_dist, z):
        """Run `minimize_cvar_given_z` per observation and collect alpha.

        Parameters
        ----------
        Y_dist : 3-D array of shape (n_scenarios, n_obs, n_assets)
        z : 2-D array of shape (n_obs, n_assets)

        Returns
        -------
        alphastar : 1-D array with one alpha per observation.
        """
        alphastar = np.zeros_like(Y_dist[0,:,0])
        for i in range(Y_dist.shape[1]):
            _, alphastar[i], _ = self.minimize_cvar_given_z(Y_dist[:,i,:], z[i,:])
        return alphastar

    def calc_cvar_dataset(self, Y_dist, zstar_pred, alpha):
        """Evaluate alpha + 1/(1-beta) * mean_i max(-(y_i . z) - alpha, 0).

        The mean runs over the scenario axis (axis 0) of `Y_dist`, giving one
        value per observation. Uses the instance's own `beta`.
        """
        portfolio_return = (Y_dist*zstar_pred).sum(2)
        excess_loss = np.maximum(-portfolio_return - alpha, 0)
        cvar_pred = alpha + (1/(1 - self.beta))*excess_loss.mean(0)
        return cvar_pred

    def calc_f_dataset(self, Y_dist_pred, Y_dist):
        """CVaR obtained on `Y_dist` by allocations optimised on `Y_dist_pred`.

        Steps: (1) optimise the allocation on the predicted scenarios,
        (2) with that allocation fixed, re-solve for alpha on the reference
        scenarios, (3) evaluate the CVaR expression on the reference scenarios.
        """
        zstar_pred = self.minimize_cvar_dataset(Y_dist_pred)
        alphastar_pred = self.minimize_cvar_given_z_dataset(Y_dist, zstar_pred)
        cvar_pred = self.calc_cvar_dataset(Y_dist, zstar_pred, alphastar_pred)
        return cvar_pred

    def cost_fn(self, Y_dist_pred, Y_dist):
        """Mean of `calc_f_dataset` over observations, as a scalar torch tensor."""
        f = self.calc_f_dataset(Y_dist_pred, Y_dist)
        # calc_f_dataset produces a NumPy array for NumPy inputs; convert it
        # before reducing with torch.
        f_total = torch.mean(torch.as_tensor(f))
        return f_total

    def end_loss_dist(self, Y_dist_pred, Y_dist):
        """End-to-end loss on scenario tensors; delegates to `cost_fn`."""
        f_total = self.cost_fn(Y_dist_pred, Y_dist)
        return f_total

In [8]:
# Number of observations in the train / validation / test splits.
N_train, N_val, N_test = 1000, 600, 600

n_assets = 15          # number of return columns generated per observation
n_samples_orig = 200   # number of target matrices stacked by gen_cond_dist
nl = 0.1               # noise level forwarded to gen_data

# Feature / target pairs. Each split gets its own seed offset.
X, Y_original = gen_data(
    N_train, n_assets, nl, seed_number=seed_number)
X_val, Y_val_original = gen_data(
    N_val, n_assets, nl, seed_number=seed_number + 80)
X_test, Y_test_original = gen_data(
    N_test, n_assets, nl, seed_number=seed_number + 160)

# Stacked target samples for each split, again with distinct seed offsets.
Y_dist = gen_cond_dist(
    N_train, n_assets, n_samples_orig, nl=nl, seed_number=seed_number)
Y_val_dist = gen_cond_dist(
    N_val, n_assets, n_samples_orig, nl=nl, seed_number=seed_number + 100)
Y_test_dist = gen_cond_dist(
    N_test, n_assets, n_samples_orig, nl=nl, seed_number=seed_number + 200)

In [9]:
# Output normalization: standardize the targets with statistics estimated on
# the training targets only.
scaler = StandardScaler().fit(Y_original)

# Keep the fitted mean / scale as tensors on the compute device so that
# inverse_transform can undo the scaling on tensors.
tmean, tstd = (
    torch.tensor(stat).to(dev) for stat in (scaler.mean_, scaler.scale_)
)

# Persist the fitted scaler to disk.
joblib.dump(scaler, 'scaler_cvar.gz')

# Function to denormalize the data
def inverse_transform(yy):
    """Undo the target standardization: yy * tstd + tmean.

    Relies on the module-level tensors `tstd` and `tmean` taken from the
    fitted scaler.
    """
    rescaled = yy*tstd
    return rescaled + tmean

# ---- Training split: standardized targets, shuffled mini-batches ----
X = torch.tensor(X, dtype=torch.float32)
Y = torch.tensor(scaler.transform(Y_original).copy(), dtype=torch.float32)
data_train = data_generator.ArtificialDataset(X, Y)
training_loader = torch.utils.data.DataLoader(
    data_train,
    batch_size=BATCH_SIZE_LOADER,
    shuffle=True,
    num_workers=cpu_count,
)
Y_dist = torch.tensor(Y_dist, dtype=torch.float32)

# ---- Validation split: standardized targets, fixed order ----
X_val = torch.tensor(X_val, dtype=torch.float32)
Y_val = torch.tensor(scaler.transform(Y_val_original).copy(), dtype=torch.float32)
Y_val_original = torch.tensor(Y_val_original, dtype=torch.float32)
data_valid = data_generator.ArtificialDataset(X_val, Y_val)
validation_loader = torch.utils.data.DataLoader(
    data_valid,
    batch_size=BATCH_SIZE_LOADER,
    shuffle=False,
    num_workers=cpu_count,
)

# ---- Test split: the dataset holds the *unscaled* targets, batch size 16 ----
X_test = torch.tensor(X_test, dtype=torch.float32)
Y_test_original = torch.tensor(Y_test_original, dtype=torch.float32)
data_test = data_generator.ArtificialDataset(X_test, Y_test_original)
test_loader = torch.utils.data.DataLoader(
    data_test,
    batch_size=16,
    shuffle=False,
    num_workers=cpu_count,
)

In [10]:
# Dimensions of the standardized training targets: rows, then columns.
n_samples, n_assets = Y.shape[0], Y.shape[1]

# Risk-problem settings: min_return is passed to the optimisation layer built
# below.
# NOTE(review): beta looks like the CVaR level — confirm where it is consumed.
beta = 0.80
min_return = 100

In [23]:
# Run configuration used by the model-building and training cells below.
method_name, method_learning = 'bnn', 'combined'

# Both flags are switched on together, and only for the Bayesian network.
bnn = (method_name == 'bnn')
aleat_bool = bnn

N_SAMPLES = 16
K = 1

In [24]:
# Learning rate handed to the Adam optimiser in the next cell.
lr = 5e-4

In [25]:
# Build the predictive network for the selected method.
if method_name == 'bnn':
    # Variational network; PLV is set in the seeding/validation cell.
    h = StrongVariationalNet(
        n_samples=N_SAMPLES,
        input_size=X.shape[1],
        output_size=Y.shape[1],
        plv=PLV,
        dev=dev
    ).to(dev)

#ANN Baseline model
elif method_name == 'ann':
    h = StrongStandardNet(X.shape[1], Y.shape[1]).to(dev)
    K = 0 # There is no K in ANN
    N_SAMPLES = 1

else:
    # Without this branch `h` would be undefined and the optimiser line below
    # would fail with an unrelated NameError.
    raise ValueError(
        f"No network is defined in this notebook for method_name={method_name!r}; "
        "expected 'bnn' or 'ann'.")

opt_h = torch.optim.Adam(h.parameters(), lr=lr)
mse_loss = nn.MSELoss(reduction='none')


# Optimisation layer used by the end-to-end (combined) training loss.
#opt_cvar_gZ = cmu.CVaROPgZ(1, n_assets, beta, min_return, torch.tensor(Y_original), dev)
op = cmu.RiskPortOP(N_SAMPLES, n_assets, min_return, torch.tensor(Y_original), dev)

# Decoupled learning approach
if method_learning == 'decoupled':
    train_NN = TrainDecoupled(
                    bnn = bnn,
                    model=h,
                    opt=opt_h,
                    loss_data=mse_loss,
                    K=K,
                    aleat_bool=aleat_bool,
                    training_loader=training_loader,
                    validation_loader=validation_loader,
                    dev=dev
                )

# Combined learning approach (end-to-end loss)
elif method_learning == 'combined':
    train_NN = TrainCombined(
                    bnn = bnn,
                    model=h,
                    opt=opt_h,
                    K=K,
                    aleat_bool=aleat_bool,
                    training_loader=training_loader,
                    scaler=scaler,
                    validation_loader=validation_loader,
                    OP=op,
                    dev=dev
                )

else:
    # Fail here rather than with a NameError on `train_NN` in the next cell.
    raise ValueError(
        f"Unknown method_learning={method_learning!r}; "
        "expected 'decoupled' or 'combined'.")

In [26]:
# Run the training loop. The epoch count is the literal 200 passed here, not
# the EPOCHS constant defined in the configuration cell.
# NOTE(review): the meaning of pre_train=-1 is defined in train.py, which is
# not visible from this notebook — confirm.
model_used = train_NN.train(EPOCHS=200, pre_train=-1)

  0%|▎                                                                   | 1/200 [00:03<11:13,  3.39s/it]

------------------EPOCH 1------------------
END LOSS 	 train 27.249 valid 26.922 
 KL LOSS 	 train 4.321 valid 4.319 
 TOTAL LOSS 	 train 31.569 valid 31.241 



  1%|▋                                                                   | 2/200 [00:06<10:54,  3.30s/it]

------------------EPOCH 2------------------
END LOSS 	 train 26.711 valid 26.43 
 KL LOSS 	 train 4.318 valid 4.316 
 TOTAL LOSS 	 train 31.029 valid 30.746 



  2%|█                                                                   | 3/200 [00:09<10:50,  3.30s/it]

------------------EPOCH 3------------------
END LOSS 	 train 25.401 valid 25.244 
 KL LOSS 	 train 4.315 valid 4.314 
 TOTAL LOSS 	 train 29.717 valid 29.558 



  2%|█▎                                                                  | 4/200 [00:13<10:57,  3.35s/it]

------------------EPOCH 4------------------
END LOSS 	 train 21.817 valid 23.541 
 KL LOSS 	 train 4.312 valid 4.311 
 TOTAL LOSS 	 train 26.13 valid 27.853 



  2%|█▋                                                                  | 5/200 [00:16<11:02,  3.40s/it]

------------------EPOCH 5------------------
END LOSS 	 train 18.92 valid 20.868 
 KL LOSS 	 train 4.31 valid 4.308 
 TOTAL LOSS 	 train 23.231 valid 25.177 



  3%|██                                                                  | 6/200 [00:20<11:06,  3.44s/it]

------------------EPOCH 6------------------
END LOSS 	 train 16.467 valid 18.79 
 KL LOSS 	 train 4.308 valid 4.306 
 TOTAL LOSS 	 train 20.775 valid 23.096 



  4%|██▍                                                                 | 7/200 [00:23<11:12,  3.48s/it]

------------------EPOCH 7------------------
END LOSS 	 train 14.91 valid 16.172 
 KL LOSS 	 train 4.305 valid 4.305 
 TOTAL LOSS 	 train 19.216 valid 20.477 



  4%|██▋                                                                 | 8/200 [00:27<11:11,  3.50s/it]

------------------EPOCH 8------------------
END LOSS 	 train 14.888 valid 20.987 
 KL LOSS 	 train 4.303 valid 4.302 
 TOTAL LOSS 	 train 19.192 valid 25.289 



  4%|███                                                                 | 9/200 [00:31<11:10,  3.51s/it]

------------------EPOCH 9------------------
END LOSS 	 train 14.619 valid 17.672 
 KL LOSS 	 train 4.301 valid 4.3 
 TOTAL LOSS 	 train 18.92 valid 21.972 



  5%|███▎                                                               | 10/200 [00:34<11:21,  3.59s/it]

------------------EPOCH 10------------------
END LOSS 	 train 13.959 valid 15.251 
 KL LOSS 	 train 4.299 valid 4.298 
 TOTAL LOSS 	 train 18.258 valid 19.549 



  6%|███▋                                                               | 11/200 [00:38<11:26,  3.63s/it]

------------------EPOCH 11------------------
END LOSS 	 train 13.714 valid 16.041 
 KL LOSS 	 train 4.297 valid 4.295 
 TOTAL LOSS 	 train 18.011 valid 20.336 



  6%|████                                                               | 12/200 [00:42<11:29,  3.67s/it]

------------------EPOCH 12------------------
END LOSS 	 train 13.737 valid 16.454 
 KL LOSS 	 train 4.295 valid 4.293 
 TOTAL LOSS 	 train 18.032 valid 20.747 



  6%|████▎                                                              | 13/200 [00:45<11:27,  3.67s/it]

------------------EPOCH 13------------------
END LOSS 	 train 13.523 valid 15.03 
 KL LOSS 	 train 4.292 valid 4.291 
 TOTAL LOSS 	 train 17.816 valid 19.322 



  7%|████▋                                                              | 14/200 [00:49<11:24,  3.68s/it]

------------------EPOCH 14------------------
END LOSS 	 train 13.685 valid 14.985 
 KL LOSS 	 train 4.29 valid 4.289 
 TOTAL LOSS 	 train 17.976 valid 19.274 



  8%|█████                                                              | 15/200 [00:53<11:27,  3.72s/it]

------------------EPOCH 15------------------
END LOSS 	 train 13.189 valid 14.441 
 KL LOSS 	 train 4.289 valid 4.287 
 TOTAL LOSS 	 train 17.478 valid 18.729 



  8%|█████▎                                                             | 16/200 [00:57<11:26,  3.73s/it]

------------------EPOCH 16------------------
END LOSS 	 train 12.915 valid 14.379 
 KL LOSS 	 train 4.286 valid 4.285 
 TOTAL LOSS 	 train 17.201 valid 18.664 



  8%|█████▋                                                             | 17/200 [01:00<11:25,  3.75s/it]

------------------EPOCH 17------------------
END LOSS 	 train 13.213 valid 14.645 
 KL LOSS 	 train 4.284 valid 4.284 
 TOTAL LOSS 	 train 17.498 valid 18.93 



  9%|██████                                                             | 18/200 [01:04<11:23,  3.76s/it]

------------------EPOCH 18------------------
END LOSS 	 train 12.786 valid 14.041 
 KL LOSS 	 train 4.282 valid 4.281 
 TOTAL LOSS 	 train 17.069 valid 18.323 



 10%|██████▎                                                            | 19/200 [01:08<11:28,  3.80s/it]

------------------EPOCH 19------------------
END LOSS 	 train 12.555 valid 13.116 
 KL LOSS 	 train 4.28 valid 4.28 
 TOTAL LOSS 	 train 16.836 valid 17.396 



 10%|██████▋                                                            | 20/200 [01:12<11:14,  3.75s/it]

------------------EPOCH 20------------------
END LOSS 	 train 12.371 valid 13.629 
 KL LOSS 	 train 4.279 valid 4.278 
 TOTAL LOSS 	 train 16.651 valid 17.907 



 10%|███████                                                            | 21/200 [01:15<11:03,  3.71s/it]

------------------EPOCH 21------------------
END LOSS 	 train 12.127 valid 12.268 
 KL LOSS 	 train 4.277 valid 4.275 
 TOTAL LOSS 	 train 16.404 valid 16.544 



 11%|███████▎                                                           | 22/200 [01:19<11:17,  3.80s/it]

------------------EPOCH 22------------------
END LOSS 	 train 12.368 valid 12.777 
 KL LOSS 	 train 4.275 valid 4.274 
 TOTAL LOSS 	 train 16.643 valid 17.052 



 12%|███████▋                                                           | 23/200 [01:23<11:25,  3.87s/it]

------------------EPOCH 23------------------
END LOSS 	 train 12.002 valid 12.973 
 KL LOSS 	 train 4.273 valid 4.272 
 TOTAL LOSS 	 train 16.275 valid 17.245 



 12%|████████                                                           | 24/200 [01:28<11:36,  3.96s/it]

------------------EPOCH 24------------------
END LOSS 	 train 11.707 valid 11.116 
 KL LOSS 	 train 4.271 valid 4.27 
 TOTAL LOSS 	 train 15.978 valid 15.386 



 12%|████████▍                                                          | 25/200 [01:32<11:39,  4.00s/it]

------------------EPOCH 25------------------
END LOSS 	 train 11.263 valid 10.811 
 KL LOSS 	 train 4.27 valid 4.268 
 TOTAL LOSS 	 train 15.533 valid 15.079 



 13%|████████▋                                                          | 26/200 [01:36<11:34,  3.99s/it]

------------------EPOCH 26------------------
END LOSS 	 train 10.891 valid 9.885 
 KL LOSS 	 train 4.268 valid 4.266 
 TOTAL LOSS 	 train 15.159 valid 14.152 



 14%|█████████                                                          | 27/200 [01:40<11:46,  4.08s/it]

------------------EPOCH 27------------------
END LOSS 	 train 10.971 valid 11.871 
 KL LOSS 	 train 4.266 valid 4.265 
 TOTAL LOSS 	 train 15.237 valid 16.137 



 14%|█████████▍                                                         | 28/200 [01:44<11:49,  4.12s/it]

------------------EPOCH 28------------------
END LOSS 	 train 10.002 valid 7.432 
 KL LOSS 	 train 4.264 valid 4.264 
 TOTAL LOSS 	 train 14.266 valid 11.697 



 14%|█████████▋                                                         | 29/200 [01:48<11:33,  4.06s/it]

------------------EPOCH 29------------------
END LOSS 	 train 10.002 valid 6.923 
 KL LOSS 	 train 4.263 valid 4.262 
 TOTAL LOSS 	 train 14.265 valid 11.185 



 15%|██████████                                                         | 30/200 [01:52<11:24,  4.02s/it]

------------------EPOCH 30------------------
END LOSS 	 train 9.309 valid 7.044 
 KL LOSS 	 train 4.261 valid 4.26 
 TOTAL LOSS 	 train 13.57 valid 11.305 



 16%|██████████▍                                                        | 31/200 [01:56<11:21,  4.03s/it]

------------------EPOCH 31------------------
END LOSS 	 train 11.214 valid 14.043 
 KL LOSS 	 train 4.259 valid 4.259 
 TOTAL LOSS 	 train 15.473 valid 18.302 



 16%|██████████▋                                                        | 32/200 [02:00<11:13,  4.01s/it]

------------------EPOCH 32------------------
END LOSS 	 train 10.509 valid 11.067 
 KL LOSS 	 train 4.257 valid 4.257 
 TOTAL LOSS 	 train 14.766 valid 15.324 



 16%|███████████                                                        | 33/200 [02:04<11:10,  4.02s/it]

------------------EPOCH 33------------------
END LOSS 	 train 9.302 valid 8.748 
 KL LOSS 	 train 4.256 valid 4.256 
 TOTAL LOSS 	 train 13.559 valid 13.004 



 17%|███████████▍                                                       | 34/200 [02:08<11:04,  4.00s/it]

------------------EPOCH 34------------------
END LOSS 	 train 7.778 valid 4.789 
 KL LOSS 	 train 4.254 valid 4.253 
 TOTAL LOSS 	 train 12.033 valid 9.043 



 18%|███████████▋                                                       | 35/200 [02:12<11:01,  4.01s/it]

------------------EPOCH 35------------------
END LOSS 	 train 8.573 valid 5.984 
 KL LOSS 	 train 4.253 valid 4.252 
 TOTAL LOSS 	 train 12.826 valid 10.237 



 18%|████████████                                                       | 36/200 [02:16<10:54,  3.99s/it]

------------------EPOCH 36------------------
END LOSS 	 train 7.822 valid 8.382 
 KL LOSS 	 train 4.251 valid 4.251 
 TOTAL LOSS 	 train 12.074 valid 12.634 



 18%|████████████▍                                                      | 37/200 [02:20<10:55,  4.02s/it]

------------------EPOCH 37------------------
END LOSS 	 train 8.045 valid 3.776 
 KL LOSS 	 train 4.25 valid 4.25 
 TOTAL LOSS 	 train 12.295 valid 8.026 



 19%|████████████▋                                                      | 38/200 [02:24<10:50,  4.01s/it]

------------------EPOCH 38------------------
END LOSS 	 train 6.917 valid 6.615 
 KL LOSS 	 train 4.248 valid 4.247 
 TOTAL LOSS 	 train 11.166 valid 10.862 



 20%|█████████████                                                      | 39/200 [02:28<10:48,  4.02s/it]

------------------EPOCH 39------------------
END LOSS 	 train 8.036 valid 6.042 
 KL LOSS 	 train 4.247 valid 4.246 
 TOTAL LOSS 	 train 12.283 valid 10.288 



 20%|█████████████▍                                                     | 40/200 [02:32<10:39,  4.00s/it]

------------------EPOCH 40------------------
END LOSS 	 train 8.204 valid 6.695 
 KL LOSS 	 train 4.245 valid 4.245 
 TOTAL LOSS 	 train 12.45 valid 10.94 



 20%|█████████████▋                                                     | 41/200 [02:36<10:32,  3.98s/it]

------------------EPOCH 41------------------
END LOSS 	 train 9.265 valid 4.679 
 KL LOSS 	 train 4.244 valid 4.243 
 TOTAL LOSS 	 train 13.509 valid 8.922 



 21%|██████████████                                                     | 42/200 [02:40<10:26,  3.96s/it]

------------------EPOCH 42------------------
END LOSS 	 train 6.979 valid 6.054 
 KL LOSS 	 train 4.243 valid 4.242 
 TOTAL LOSS 	 train 11.222 valid 10.297 



 22%|██████████████▍                                                    | 43/200 [02:44<10:23,  3.97s/it]

------------------EPOCH 43------------------
END LOSS 	 train 6.478 valid 3.118 
 KL LOSS 	 train 4.242 valid 4.24 
 TOTAL LOSS 	 train 10.72 valid 7.359 



 22%|██████████████▋                                                    | 44/200 [02:48<10:23,  4.00s/it]

------------------EPOCH 44------------------
END LOSS 	 train 6.022 valid 10.272 
 KL LOSS 	 train 4.24 valid 4.24 
 TOTAL LOSS 	 train 10.262 valid 14.512 



 22%|███████████████                                                    | 45/200 [02:52<10:15,  3.97s/it]

------------------EPOCH 45------------------
END LOSS 	 train 6.655 valid 5.352 
 KL LOSS 	 train 4.239 valid 4.238 
 TOTAL LOSS 	 train 10.895 valid 9.59 



 23%|███████████████▍                                                   | 46/200 [02:56<10:11,  3.97s/it]

------------------EPOCH 46------------------
END LOSS 	 train 5.991 valid 2.761 
 KL LOSS 	 train 4.238 valid 4.236 
 TOTAL LOSS 	 train 10.23 valid 6.998 



 24%|███████████████▋                                                   | 47/200 [03:00<10:14,  4.01s/it]

------------------EPOCH 47------------------
END LOSS 	 train 6.613 valid 1.852 
 KL LOSS 	 train 4.236 valid 4.235 
 TOTAL LOSS 	 train 10.85 valid 6.088 



 24%|████████████████                                                   | 48/200 [03:04<10:09,  4.01s/it]

------------------EPOCH 48------------------
END LOSS 	 train 5.335 valid 3.918 
 KL LOSS 	 train 4.235 valid 4.234 
 TOTAL LOSS 	 train 9.571 valid 8.153 



 24%|████████████████▍                                                  | 49/200 [03:08<10:07,  4.03s/it]

------------------EPOCH 49------------------
END LOSS 	 train 5.693 valid 2.459 
 KL LOSS 	 train 4.234 valid 4.234 
 TOTAL LOSS 	 train 9.928 valid 6.693 



 25%|████████████████▊                                                  | 50/200 [03:12<10:05,  4.03s/it]

------------------EPOCH 50------------------
END LOSS 	 train 5.211 valid 4.877 
 KL LOSS 	 train 4.233 valid 4.231 
 TOTAL LOSS 	 train 9.444 valid 9.109 



 26%|█████████████████                                                  | 51/200 [03:16<10:00,  4.03s/it]

------------------EPOCH 51------------------
END LOSS 	 train 5.542 valid 2.274 
 KL LOSS 	 train 4.231 valid 4.231 
 TOTAL LOSS 	 train 9.774 valid 6.505 



 26%|█████████████████▍                                                 | 52/200 [03:20<09:54,  4.02s/it]

------------------EPOCH 52------------------
END LOSS 	 train 4.983 valid 3.179 
 KL LOSS 	 train 4.23 valid 4.23 
 TOTAL LOSS 	 train 9.213 valid 7.409 



 26%|█████████████████▊                                                 | 53/200 [03:24<09:49,  4.01s/it]

------------------EPOCH 53------------------
END LOSS 	 train 4.502 valid 2.459 
 KL LOSS 	 train 4.229 valid 4.228 
 TOTAL LOSS 	 train 8.731 valid 6.688 



 27%|██████████████████                                                 | 54/200 [03:28<09:42,  3.99s/it]

------------------EPOCH 54------------------
END LOSS 	 train 5.194 valid 11.235 
 KL LOSS 	 train 4.228 valid 4.228 
 TOTAL LOSS 	 train 9.422 valid 15.463 



 28%|██████████████████▍                                                | 55/200 [03:32<09:37,  3.98s/it]

------------------EPOCH 55------------------
END LOSS 	 train 6.466 valid 2.123 
 KL LOSS 	 train 4.227 valid 4.227 
 TOTAL LOSS 	 train 10.693 valid 6.35 



 28%|██████████████████▊                                                | 56/200 [03:36<09:38,  4.02s/it]

------------------EPOCH 56------------------
END LOSS 	 train 6.011 valid 3.161 
 KL LOSS 	 train 4.226 valid 4.224 
 TOTAL LOSS 	 train 10.237 valid 7.385 



 28%|███████████████████                                                | 57/200 [03:40<09:38,  4.04s/it]

------------------EPOCH 57------------------
END LOSS 	 train 4.81 valid 1.93 
 KL LOSS 	 train 4.224 valid 4.223 
 TOTAL LOSS 	 train 9.034 valid 6.153 



 29%|███████████████████▍                                               | 58/200 [03:44<09:39,  4.08s/it]

------------------EPOCH 58------------------
END LOSS 	 train 4.161 valid 2.928 
 KL LOSS 	 train 4.223 valid 4.223 
 TOTAL LOSS 	 train 8.385 valid 7.151 



 30%|███████████████████▊                                               | 59/200 [03:48<09:31,  4.06s/it]

------------------EPOCH 59------------------
END LOSS 	 train 4.352 valid 1.813 
 KL LOSS 	 train 4.222 valid 4.221 
 TOTAL LOSS 	 train 8.574 valid 6.034 



 30%|████████████████████                                               | 60/200 [03:52<09:21,  4.01s/it]

------------------EPOCH 60------------------
END LOSS 	 train 4.146 valid 1.871 
 KL LOSS 	 train 4.221 valid 4.22 
 TOTAL LOSS 	 train 8.368 valid 6.092 



 30%|████████████████████▍                                              | 61/200 [03:56<09:16,  4.00s/it]

------------------EPOCH 61------------------
END LOSS 	 train 3.971 valid 2.13 
 KL LOSS 	 train 4.22 valid 4.219 
 TOTAL LOSS 	 train 8.191 valid 6.349 



 31%|████████████████████▊                                              | 62/200 [04:00<09:15,  4.02s/it]

------------------EPOCH 62------------------
END LOSS 	 train 4.096 valid 2.012 
 KL LOSS 	 train 4.219 valid 4.219 
 TOTAL LOSS 	 train 8.315 valid 6.231 



 32%|█████████████████████                                              | 63/200 [04:04<09:06,  3.99s/it]

------------------EPOCH 63------------------
END LOSS 	 train 4.39 valid 1.958 
 KL LOSS 	 train 4.218 valid 4.217 
 TOTAL LOSS 	 train 8.608 valid 6.175 



 32%|█████████████████████▍                                             | 64/200 [04:08<09:11,  4.05s/it]

------------------EPOCH 64------------------
END LOSS 	 train 3.758 valid 1.613 
 KL LOSS 	 train 4.217 valid 4.216 
 TOTAL LOSS 	 train 7.975 valid 5.829 



 32%|█████████████████████▊                                             | 65/200 [04:12<09:01,  4.01s/it]

------------------EPOCH 65------------------
END LOSS 	 train 3.892 valid 1.469 
 KL LOSS 	 train 4.216 valid 4.214 
 TOTAL LOSS 	 train 8.108 valid 5.684 



 33%|██████████████████████                                             | 66/200 [04:16<08:51,  3.97s/it]

------------------EPOCH 66------------------
END LOSS 	 train 4.142 valid 2.691 
 KL LOSS 	 train 4.214 valid 4.214 
 TOTAL LOSS 	 train 8.357 valid 6.905 



 34%|██████████████████████▍                                            | 67/200 [04:20<08:47,  3.97s/it]

------------------EPOCH 67------------------
END LOSS 	 train 4.194 valid 1.896 
 KL LOSS 	 train 4.213 valid 4.214 
 TOTAL LOSS 	 train 8.408 valid 6.11 



 34%|██████████████████████▊                                            | 68/200 [04:24<08:43,  3.97s/it]

------------------EPOCH 68------------------
END LOSS 	 train 4.328 valid 2.658 
 KL LOSS 	 train 4.212 valid 4.213 
 TOTAL LOSS 	 train 8.54 valid 6.871 



 34%|███████████████████████                                            | 69/200 [04:28<08:40,  3.97s/it]

------------------EPOCH 69------------------
END LOSS 	 train 4.627 valid 1.271 
 KL LOSS 	 train 4.211 valid 4.211 
 TOTAL LOSS 	 train 8.838 valid 5.483 



 35%|███████████████████████▍                                           | 70/200 [04:32<08:37,  3.98s/it]

------------------EPOCH 70------------------
END LOSS 	 train 4.35 valid 2.423 
 KL LOSS 	 train 4.21 valid 4.211 
 TOTAL LOSS 	 train 8.561 valid 6.634 



 36%|███████████████████████▊                                           | 71/200 [04:36<08:36,  4.00s/it]

------------------EPOCH 71------------------
END LOSS 	 train 3.576 valid 2.446 
 KL LOSS 	 train 4.209 valid 4.209 
 TOTAL LOSS 	 train 7.786 valid 6.655 



 36%|████████████████████████                                           | 72/200 [04:40<08:30,  3.99s/it]

------------------EPOCH 72------------------
END LOSS 	 train 3.708 valid 2.266 
 KL LOSS 	 train 4.208 valid 4.209 
 TOTAL LOSS 	 train 7.917 valid 6.476 



 36%|████████████████████████▍                                          | 73/200 [04:44<08:28,  4.01s/it]

------------------EPOCH 73------------------
END LOSS 	 train 3.301 valid 1.203 
 KL LOSS 	 train 4.207 valid 4.206 
 TOTAL LOSS 	 train 7.509 valid 5.409 



 37%|████████████████████████▊                                          | 74/200 [04:48<08:21,  3.98s/it]

------------------EPOCH 74------------------
END LOSS 	 train 3.583 valid 3.216 
 KL LOSS 	 train 4.206 valid 4.206 
 TOTAL LOSS 	 train 7.79 valid 7.423 



 38%|█████████████████████████▏                                         | 75/200 [04:52<08:14,  3.95s/it]

------------------EPOCH 75------------------
END LOSS 	 train 3.256 valid 1.875 
 KL LOSS 	 train 4.205 valid 4.204 
 TOTAL LOSS 	 train 7.462 valid 6.079 



 38%|█████████████████████████▍                                         | 76/200 [04:56<08:08,  3.94s/it]

------------------EPOCH 76------------------
END LOSS 	 train 3.607 valid 1.625 
 KL LOSS 	 train 4.204 valid 4.205 
 TOTAL LOSS 	 train 7.812 valid 5.83 



 38%|█████████████████████████▊                                         | 77/200 [05:00<08:09,  3.98s/it]

------------------EPOCH 77------------------
END LOSS 	 train 3.323 valid 1.17 
 KL LOSS 	 train 4.203 valid 4.203 
 TOTAL LOSS 	 train 7.527 valid 5.373 



 39%|██████████████████████████▏                                        | 78/200 [05:04<08:04,  3.97s/it]

------------------EPOCH 78------------------
END LOSS 	 train 3.465 valid 4.127 
 KL LOSS 	 train 4.202 valid 4.202 
 TOTAL LOSS 	 train 7.668 valid 8.329 



 40%|██████████████████████████▍                                        | 79/200 [05:08<08:03,  4.00s/it]

------------------EPOCH 79------------------
END LOSS 	 train 3.799 valid 1.52 
 KL LOSS 	 train 4.201 valid 4.201 
 TOTAL LOSS 	 train 8.0 valid 5.721 



 40%|██████████████████████████▊                                        | 80/200 [05:12<07:59,  3.99s/it]

------------------EPOCH 80------------------
END LOSS 	 train 3.275 valid 1.665 
 KL LOSS 	 train 4.201 valid 4.199 
 TOTAL LOSS 	 train 7.476 valid 5.865 



 40%|███████████████████████████▏                                       | 81/200 [05:16<07:54,  3.99s/it]

------------------EPOCH 81------------------
END LOSS 	 train 3.47 valid 2.689 
 KL LOSS 	 train 4.199 valid 4.2 
 TOTAL LOSS 	 train 7.67 valid 6.89 



 41%|███████████████████████████▍                                       | 82/200 [05:20<07:48,  3.97s/it]

------------------EPOCH 82------------------
END LOSS 	 train 3.124 valid 1.386 
 KL LOSS 	 train 4.199 valid 4.198 
 TOTAL LOSS 	 train 7.323 valid 5.584 



 42%|███████████████████████████▊                                       | 83/200 [05:24<07:49,  4.01s/it]

------------------EPOCH 83------------------
END LOSS 	 train 3.555 valid 1.252 
 KL LOSS 	 train 4.198 valid 4.198 
 TOTAL LOSS 	 train 7.753 valid 5.45 



 42%|████████████████████████████▏                                      | 84/200 [05:28<07:40,  3.97s/it]

------------------EPOCH 84------------------
END LOSS 	 train 4.098 valid 1.37 
 KL LOSS 	 train 4.197 valid 4.196 
 TOTAL LOSS 	 train 8.295 valid 5.566 



 42%|████████████████████████████▍                                      | 85/200 [05:32<07:34,  3.95s/it]

------------------EPOCH 85------------------
END LOSS 	 train 3.938 valid 1.521 
 KL LOSS 	 train 4.196 valid 4.195 
 TOTAL LOSS 	 train 8.135 valid 5.716 



 43%|████████████████████████████▊                                      | 86/200 [05:36<07:26,  3.91s/it]

------------------EPOCH 86------------------
END LOSS 	 train 3.162 valid 2.348 
 KL LOSS 	 train 4.195 valid 4.195 
 TOTAL LOSS 	 train 7.357 valid 6.544 



 44%|█████████████████████████████▏                                     | 87/200 [05:40<07:24,  3.93s/it]

------------------EPOCH 87------------------
END LOSS 	 train 3.185 valid 1.962 
 KL LOSS 	 train 4.194 valid 4.195 
 TOTAL LOSS 	 train 7.38 valid 6.157 



 44%|█████████████████████████████▍                                     | 88/200 [05:44<07:22,  3.95s/it]

------------------EPOCH 88------------------
END LOSS 	 train 3.138 valid 1.294 
 KL LOSS 	 train 4.193 valid 4.193 
 TOTAL LOSS 	 train 7.331 valid 5.487 



 44%|█████████████████████████████▊                                     | 89/200 [05:48<07:18,  3.95s/it]

------------------EPOCH 89------------------
END LOSS 	 train 3.264 valid 1.461 
 KL LOSS 	 train 4.193 valid 4.191 
 TOTAL LOSS 	 train 7.457 valid 5.653 



 45%|██████████████████████████████▏                                    | 90/200 [05:52<07:14,  3.95s/it]

------------------EPOCH 90------------------
END LOSS 	 train 3.049 valid 2.702 
 KL LOSS 	 train 4.192 valid 4.191 
 TOTAL LOSS 	 train 7.241 valid 6.893 



 46%|██████████████████████████████▍                                    | 91/200 [05:55<07:11,  3.96s/it]

------------------EPOCH 91------------------
END LOSS 	 train 3.137 valid 1.823 
 KL LOSS 	 train 4.191 valid 4.19 
 TOTAL LOSS 	 train 7.329 valid 6.013 



 46%|██████████████████████████████▊                                    | 92/200 [05:59<07:04,  3.93s/it]

------------------EPOCH 92------------------
END LOSS 	 train 3.753 valid 5.015 
 KL LOSS 	 train 4.19 valid 4.19 
 TOTAL LOSS 	 train 7.944 valid 9.205 



 46%|███████████████████████████████▏                                   | 93/200 [06:03<07:03,  3.96s/it]

------------------EPOCH 93------------------
END LOSS 	 train 3.043 valid 1.318 
 KL LOSS 	 train 4.189 valid 4.19 
 TOTAL LOSS 	 train 7.232 valid 5.508 



 47%|███████████████████████████████▍                                   | 94/200 [06:07<07:01,  3.98s/it]

------------------EPOCH 94------------------
END LOSS 	 train 3.041 valid 0.956 
 KL LOSS 	 train 4.189 valid 4.187 
 TOTAL LOSS 	 train 7.23 valid 5.144 



 48%|███████████████████████████████▊                                   | 95/200 [06:11<06:57,  3.98s/it]

------------------EPOCH 95------------------
END LOSS 	 train 2.908 valid 1.546 
 KL LOSS 	 train 4.187 valid 4.186 
 TOTAL LOSS 	 train 7.096 valid 5.733 



 48%|████████████████████████████████▏                                  | 96/200 [06:15<06:52,  3.97s/it]

------------------EPOCH 96------------------
END LOSS 	 train 2.785 valid 1.422 
 KL LOSS 	 train 4.187 valid 4.187 
 TOTAL LOSS 	 train 6.972 valid 5.609 



 48%|████████████████████████████████▍                                  | 97/200 [06:19<06:52,  4.00s/it]

------------------EPOCH 97------------------
END LOSS 	 train 2.883 valid 1.372 
 KL LOSS 	 train 4.186 valid 4.185 
 TOTAL LOSS 	 train 7.069 valid 5.558 



 49%|████████████████████████████████▊                                  | 98/200 [06:23<06:47,  3.99s/it]

------------------EPOCH 98------------------
END LOSS 	 train 3.185 valid 1.843 
 KL LOSS 	 train 4.185 valid 4.187 
 TOTAL LOSS 	 train 7.37 valid 6.03 



 50%|█████████████████████████████████▏                                 | 99/200 [06:27<06:40,  3.97s/it]

------------------EPOCH 99------------------
END LOSS 	 train 3.03 valid 1.385 
 KL LOSS 	 train 4.184 valid 4.184 
 TOTAL LOSS 	 train 7.215 valid 5.569 



 50%|█████████████████████████████████                                 | 100/200 [06:31<06:34,  3.95s/it]

------------------EPOCH 100------------------
END LOSS 	 train 3.02 valid 1.974 
 KL LOSS 	 train 4.184 valid 4.184 
 TOTAL LOSS 	 train 7.204 valid 6.159 



 50%|█████████████████████████████████▎                                | 101/200 [06:35<06:34,  3.98s/it]

------------------EPOCH 101------------------
END LOSS 	 train 2.976 valid 2.319 
 KL LOSS 	 train 4.183 valid 4.183 
 TOTAL LOSS 	 train 7.159 valid 6.503 



 51%|█████████████████████████████████▋                                | 102/200 [06:39<06:31,  4.00s/it]

------------------EPOCH 102------------------
END LOSS 	 train 3.233 valid 1.323 
 KL LOSS 	 train 4.182 valid 4.18 
 TOTAL LOSS 	 train 7.416 valid 5.504 



 52%|█████████████████████████████████▉                                | 103/200 [06:43<06:27,  3.99s/it]

------------------EPOCH 103------------------
END LOSS 	 train 3.296 valid 1.678 
 KL LOSS 	 train 4.182 valid 4.181 
 TOTAL LOSS 	 train 7.479 valid 5.86 



 52%|██████████████████████████████████▎                               | 104/200 [06:47<06:23,  4.00s/it]

------------------EPOCH 104------------------
END LOSS 	 train 3.079 valid 1.699 
 KL LOSS 	 train 4.18 valid 4.181 
 TOTAL LOSS 	 train 7.26 valid 5.881 



 52%|██████████████████████████████████▋                               | 105/200 [06:51<06:21,  4.01s/it]

------------------EPOCH 105------------------
END LOSS 	 train 3.018 valid 1.83 
 KL LOSS 	 train 4.18 valid 4.179 
 TOTAL LOSS 	 train 7.198 valid 6.009 



 53%|██████████████████████████████████▉                               | 106/200 [06:55<06:16,  4.01s/it]

------------------EPOCH 106------------------
END LOSS 	 train 2.86 valid 1.047 
 KL LOSS 	 train 4.179 valid 4.178 
 TOTAL LOSS 	 train 7.039 valid 5.226 



 54%|███████████████████████████████████▎                              | 107/200 [06:59<06:16,  4.04s/it]

------------------EPOCH 107------------------
END LOSS 	 train 2.922 valid 1.821 
 KL LOSS 	 train 4.178 valid 4.178 
 TOTAL LOSS 	 train 7.101 valid 5.999 



 54%|███████████████████████████████████▋                              | 108/200 [07:04<06:18,  4.12s/it]

------------------EPOCH 108------------------
END LOSS 	 train 3.12 valid 1.128 
 KL LOSS 	 train 4.177 valid 4.178 
 TOTAL LOSS 	 train 7.298 valid 5.306 



 55%|███████████████████████████████████▉                              | 109/200 [07:08<06:11,  4.08s/it]

------------------EPOCH 109------------------
END LOSS 	 train 3.042 valid 2.126 
 KL LOSS 	 train 4.177 valid 4.177 
 TOTAL LOSS 	 train 7.219 valid 6.304 



 55%|████████████████████████████████████▎                             | 110/200 [07:12<06:05,  4.06s/it]

------------------EPOCH 110------------------
END LOSS 	 train 3.35 valid 1.519 
 KL LOSS 	 train 4.176 valid 4.176 
 TOTAL LOSS 	 train 7.527 valid 5.695 



 56%|████████████████████████████████████▋                             | 111/200 [07:16<05:59,  4.04s/it]

------------------EPOCH 111------------------
END LOSS 	 train 3.179 valid 1.602 
 KL LOSS 	 train 4.176 valid 4.175 
 TOTAL LOSS 	 train 7.355 valid 5.778 



 56%|████████████████████████████████████▉                             | 112/200 [07:20<05:51,  4.00s/it]

------------------EPOCH 112------------------
END LOSS 	 train 3.178 valid 1.182 
 KL LOSS 	 train 4.175 valid 4.175 
 TOTAL LOSS 	 train 7.353 valid 5.358 



 56%|█████████████████████████████████████▎                            | 113/200 [07:24<05:47,  4.00s/it]

------------------EPOCH 113------------------
END LOSS 	 train 3.003 valid 1.772 
 KL LOSS 	 train 4.174 valid 4.173 
 TOTAL LOSS 	 train 7.178 valid 5.946 



 57%|█████████████████████████████████████▌                            | 114/200 [07:27<05:40,  3.96s/it]

------------------EPOCH 114------------------
END LOSS 	 train 2.832 valid 1.242 
 KL LOSS 	 train 4.174 valid 4.174 
 TOTAL LOSS 	 train 7.006 valid 5.417 



 57%|█████████████████████████████████████▉                            | 115/200 [07:31<05:36,  3.95s/it]

------------------EPOCH 115------------------
END LOSS 	 train 2.857 valid 2.158 
 KL LOSS 	 train 4.173 valid 4.173 
 TOTAL LOSS 	 train 7.03 valid 6.331 



 58%|██████████████████████████████████████▎                           | 116/200 [07:35<05:31,  3.95s/it]

------------------EPOCH 116------------------
END LOSS 	 train 3.413 valid 1.74 
 KL LOSS 	 train 4.172 valid 4.172 
 TOTAL LOSS 	 train 7.586 valid 5.913 



 58%|██████████████████████████████████████▌                           | 117/200 [07:39<05:30,  3.98s/it]

------------------EPOCH 117------------------
END LOSS 	 train 3.041 valid 1.556 
 KL LOSS 	 train 4.172 valid 4.172 
 TOTAL LOSS 	 train 7.213 valid 5.728 



 59%|██████████████████████████████████████▉                           | 118/200 [07:43<05:25,  3.97s/it]

------------------EPOCH 118------------------
END LOSS 	 train 2.916 valid 1.369 
 KL LOSS 	 train 4.171 valid 4.17 
 TOTAL LOSS 	 train 7.087 valid 5.539 



 60%|███████████████████████████████████████▎                          | 119/200 [07:47<05:23,  3.99s/it]

------------------EPOCH 119------------------
END LOSS 	 train 3.104 valid 2.262 
 KL LOSS 	 train 4.17 valid 4.169 
 TOTAL LOSS 	 train 7.275 valid 6.432 



 60%|███████████████████████████████████████▌                          | 120/200 [07:52<05:25,  4.07s/it]

------------------EPOCH 120------------------
END LOSS 	 train 2.925 valid 1.036 
 KL LOSS 	 train 4.169 valid 4.169 
 TOTAL LOSS 	 train 7.095 valid 5.206 



 60%|███████████████████████████████████████▉                          | 121/200 [07:56<05:18,  4.04s/it]

------------------EPOCH 121------------------
END LOSS 	 train 2.692 valid 1.238 
 KL LOSS 	 train 4.169 valid 4.169 
 TOTAL LOSS 	 train 6.862 valid 5.407 



 61%|████████████████████████████████████████▎                         | 122/200 [08:00<05:14,  4.03s/it]

------------------EPOCH 122------------------
END LOSS 	 train 2.657 valid 1.571 
 KL LOSS 	 train 4.168 valid 4.169 
 TOTAL LOSS 	 train 6.825 valid 5.74 



 62%|████████████████████████████████████████▌                         | 123/200 [08:04<05:09,  4.01s/it]

------------------EPOCH 123------------------
END LOSS 	 train 3.032 valid 1.567 
 KL LOSS 	 train 4.167 valid 4.168 
 TOTAL LOSS 	 train 7.2 valid 5.736 



 62%|████████████████████████████████████████▉                         | 124/200 [08:08<05:03,  3.99s/it]

------------------EPOCH 124------------------
END LOSS 	 train 2.849 valid 1.234 
 KL LOSS 	 train 4.167 valid 4.167 
 TOTAL LOSS 	 train 7.017 valid 5.402 



 62%|█████████████████████████████████████████▎                        | 125/200 [08:12<05:01,  4.01s/it]

------------------EPOCH 125------------------
END LOSS 	 train 2.66 valid 1.346 
 KL LOSS 	 train 4.166 valid 4.166 
 TOTAL LOSS 	 train 6.826 valid 5.513 



 63%|█████████████████████████████████████████▌                        | 126/200 [08:16<05:00,  4.07s/it]

------------------EPOCH 126------------------
END LOSS 	 train 2.651 valid 1.072 
 KL LOSS 	 train 4.166 valid 4.165 
 TOTAL LOSS 	 train 6.817 valid 5.237 



 64%|█████████████████████████████████████████▉                        | 127/200 [08:20<04:57,  4.07s/it]

------------------EPOCH 127------------------
END LOSS 	 train 2.867 valid 1.308 
 KL LOSS 	 train 4.165 valid 4.164 
 TOTAL LOSS 	 train 7.032 valid 5.473 



 64%|██████████████████████████████████████████▏                       | 128/200 [08:24<04:51,  4.05s/it]

------------------EPOCH 128------------------
END LOSS 	 train 2.729 valid 2.475 
 KL LOSS 	 train 4.164 valid 4.165 
 TOTAL LOSS 	 train 6.894 valid 6.64 



 64%|██████████████████████████████████████████▌                       | 129/200 [08:28<04:47,  4.04s/it]

------------------EPOCH 129------------------
END LOSS 	 train 2.794 valid 1.914 
 KL LOSS 	 train 4.164 valid 4.165 
 TOTAL LOSS 	 train 6.958 valid 6.079 



 65%|██████████████████████████████████████████▉                       | 130/200 [08:32<04:41,  4.02s/it]

------------------EPOCH 130------------------
END LOSS 	 train 2.708 valid 2.395 
 KL LOSS 	 train 4.163 valid 4.163 
 TOTAL LOSS 	 train 6.871 valid 6.558 



 66%|███████████████████████████████████████████▏                      | 131/200 [08:36<04:35,  4.00s/it]

------------------EPOCH 131------------------
END LOSS 	 train 2.887 valid 2.099 
 KL LOSS 	 train 4.163 valid 4.163 
 TOTAL LOSS 	 train 7.05 valid 6.262 



 66%|███████████████████████████████████████████▌                      | 132/200 [08:40<04:33,  4.02s/it]

------------------EPOCH 132------------------
END LOSS 	 train 2.687 valid 1.575 
 KL LOSS 	 train 4.162 valid 4.162 
 TOTAL LOSS 	 train 6.849 valid 5.737 



 66%|███████████████████████████████████████████▉                      | 133/200 [08:44<04:27,  4.00s/it]

------------------EPOCH 133------------------
END LOSS 	 train 3.153 valid 3.869 
 KL LOSS 	 train 4.162 valid 4.161 
 TOTAL LOSS 	 train 7.315 valid 8.03 



 67%|████████████████████████████████████████████▏                     | 134/200 [08:48<04:29,  4.08s/it]

------------------EPOCH 134------------------
END LOSS 	 train 2.996 valid 1.972 
 KL LOSS 	 train 4.161 valid 4.161 
 TOTAL LOSS 	 train 7.157 valid 6.134 



 68%|████████████████████████████████████████████▌                     | 135/200 [08:52<04:25,  4.08s/it]

------------------EPOCH 135------------------
END LOSS 	 train 2.702 valid 1.713 
 KL LOSS 	 train 4.161 valid 4.161 
 TOTAL LOSS 	 train 6.863 valid 5.874 



 68%|████████████████████████████████████████████▉                     | 136/200 [08:56<04:22,  4.10s/it]

------------------EPOCH 136------------------
END LOSS 	 train 2.779 valid 1.889 
 KL LOSS 	 train 4.16 valid 4.159 
 TOTAL LOSS 	 train 6.939 valid 6.048 



 68%|█████████████████████████████████████████████▏                    | 137/200 [09:00<04:17,  4.09s/it]

------------------EPOCH 137------------------
END LOSS 	 train 2.936 valid 2.727 
 KL LOSS 	 train 4.159 valid 4.159 
 TOTAL LOSS 	 train 7.096 valid 6.886 



 69%|█████████████████████████████████████████████▌                    | 138/200 [09:04<04:12,  4.07s/it]

------------------EPOCH 138------------------
END LOSS 	 train 3.056 valid 1.518 
 KL LOSS 	 train 4.159 valid 4.158 
 TOTAL LOSS 	 train 7.215 valid 5.676 



 70%|█████████████████████████████████████████████▊                    | 139/200 [09:08<04:06,  4.05s/it]

------------------EPOCH 139------------------
END LOSS 	 train 3.328 valid 1.341 
 KL LOSS 	 train 4.158 valid 4.158 
 TOTAL LOSS 	 train 7.486 valid 5.5 



 70%|██████████████████████████████████████████████▏                   | 140/200 [09:12<04:01,  4.02s/it]

------------------EPOCH 140------------------
END LOSS 	 train 3.027 valid 1.128 
 KL LOSS 	 train 4.158 valid 4.158 
 TOTAL LOSS 	 train 7.186 valid 5.286 



 70%|██████████████████████████████████████████████▌                   | 141/200 [09:16<03:58,  4.04s/it]

------------------EPOCH 141------------------
END LOSS 	 train 2.632 valid 2.166 
 KL LOSS 	 train 4.157 valid 4.157 
 TOTAL LOSS 	 train 6.79 valid 6.324 



 71%|██████████████████████████████████████████████▊                   | 142/200 [09:20<03:51,  4.00s/it]

------------------EPOCH 142------------------
END LOSS 	 train 2.788 valid 1.999 
 KL LOSS 	 train 4.157 valid 4.157 
 TOTAL LOSS 	 train 6.945 valid 6.156 



 72%|███████████████████████████████████████████████▏                  | 143/200 [09:24<03:47,  3.99s/it]

------------------EPOCH 143------------------
END LOSS 	 train 2.963 valid 3.027 
 KL LOSS 	 train 4.156 valid 4.155 
 TOTAL LOSS 	 train 7.119 valid 7.183 



 72%|███████████████████████████████████████████████▌                  | 144/200 [09:28<03:43,  3.98s/it]

------------------EPOCH 144------------------
END LOSS 	 train 3.067 valid 2.297 
 KL LOSS 	 train 4.155 valid 4.155 
 TOTAL LOSS 	 train 7.223 valid 6.452 



 72%|███████████████████████████████████████████████▊                  | 145/200 [09:32<03:39,  3.99s/it]

------------------EPOCH 145------------------
END LOSS 	 train 2.983 valid 1.89 
 KL LOSS 	 train 4.155 valid 4.155 
 TOTAL LOSS 	 train 7.139 valid 6.045 



 73%|████████████████████████████████████████████████▏                 | 146/200 [09:36<03:36,  4.02s/it]

------------------EPOCH 146------------------
END LOSS 	 train 2.798 valid 2.083 
 KL LOSS 	 train 4.155 valid 4.154 
 TOTAL LOSS 	 train 6.953 valid 6.237 



 74%|████████████████████████████████████████████████▌                 | 147/200 [09:40<03:32,  4.01s/it]

------------------EPOCH 147------------------
END LOSS 	 train 2.834 valid 1.095 
 KL LOSS 	 train 4.154 valid 4.154 
 TOTAL LOSS 	 train 6.989 valid 5.249 



 74%|████████████████████████████████████████████████▊                 | 148/200 [09:44<03:28,  4.02s/it]

------------------EPOCH 148------------------
END LOSS 	 train 2.724 valid 1.092 
 KL LOSS 	 train 4.153 valid 4.153 
 TOTAL LOSS 	 train 6.878 valid 5.246 



 74%|█████████████████████████████████████████████████▏                | 149/200 [09:48<03:23,  3.99s/it]

------------------EPOCH 149------------------
END LOSS 	 train 2.64 valid 1.262 
 KL LOSS 	 train 4.153 valid 4.153 
 TOTAL LOSS 	 train 6.794 valid 5.415 



 75%|█████████████████████████████████████████████████▌                | 150/200 [09:52<03:19,  3.99s/it]

------------------EPOCH 150------------------
END LOSS 	 train 2.644 valid 1.68 
 KL LOSS 	 train 4.153 valid 4.152 
 TOTAL LOSS 	 train 6.797 valid 5.832 



 76%|█████████████████████████████████████████████████▊                | 151/200 [09:57<03:18,  4.06s/it]

------------------EPOCH 151------------------
END LOSS 	 train 2.882 valid 0.948 
 KL LOSS 	 train 4.152 valid 4.151 
 TOTAL LOSS 	 train 7.035 valid 5.099 



 76%|██████████████████████████████████████████████████▏               | 152/200 [10:01<03:13,  4.02s/it]

------------------EPOCH 152------------------
END LOSS 	 train 2.899 valid 1.174 
 KL LOSS 	 train 4.152 valid 4.151 
 TOTAL LOSS 	 train 7.051 valid 5.326 



 76%|██████████████████████████████████████████████████▍               | 153/200 [10:04<03:07,  4.00s/it]

------------------EPOCH 153------------------
END LOSS 	 train 2.639 valid 1.08 
 KL LOSS 	 train 4.151 valid 4.152 
 TOTAL LOSS 	 train 6.791 valid 5.232 



 77%|██████████████████████████████████████████████████▊               | 154/200 [10:08<03:04,  4.00s/it]

------------------EPOCH 154------------------
END LOSS 	 train 2.66 valid 1.652 
 KL LOSS 	 train 4.151 valid 4.15 
 TOTAL LOSS 	 train 6.811 valid 5.803 



 78%|███████████████████████████████████████████████████▏              | 155/200 [10:12<03:00,  4.01s/it]

------------------EPOCH 155------------------
END LOSS 	 train 2.762 valid 1.032 
 KL LOSS 	 train 4.15 valid 4.15 
 TOTAL LOSS 	 train 6.912 valid 5.182 



 78%|███████████████████████████████████████████████████▍              | 156/200 [10:17<02:56,  4.01s/it]

------------------EPOCH 156------------------
END LOSS 	 train 2.726 valid 1.298 
 KL LOSS 	 train 4.15 valid 4.15 
 TOTAL LOSS 	 train 6.876 valid 5.448 



 78%|███████████████████████████████████████████████████▊              | 157/200 [10:21<02:52,  4.01s/it]

------------------EPOCH 157------------------
END LOSS 	 train 3.206 valid 1.768 
 KL LOSS 	 train 4.149 valid 4.149 
 TOTAL LOSS 	 train 7.355 valid 5.917 



 79%|████████████████████████████████████████████████████▏             | 158/200 [10:25<02:50,  4.06s/it]

------------------EPOCH 158------------------
END LOSS 	 train 3.606 valid 7.479 
 KL LOSS 	 train 4.149 valid 4.148 
 TOTAL LOSS 	 train 7.755 valid 11.628 



 80%|████████████████████████████████████████████████████▍             | 159/200 [10:29<02:45,  4.04s/it]

------------------EPOCH 159------------------
END LOSS 	 train 3.253 valid 1.391 
 KL LOSS 	 train 4.148 valid 4.147 
 TOTAL LOSS 	 train 7.402 valid 5.539 



 80%|████████████████████████████████████████████████████▊             | 160/200 [10:33<02:40,  4.02s/it]

------------------EPOCH 160------------------
END LOSS 	 train 2.959 valid 2.214 
 KL LOSS 	 train 4.148 valid 4.147 
 TOTAL LOSS 	 train 7.107 valid 6.361 



 80%|█████████████████████████████████████████████████████▏            | 161/200 [10:37<02:39,  4.09s/it]

------------------EPOCH 161------------------
END LOSS 	 train 2.577 valid 1.505 
 KL LOSS 	 train 4.147 valid 4.146 
 TOTAL LOSS 	 train 6.725 valid 5.652 



 81%|█████████████████████████████████████████████████████▍            | 162/200 [10:41<02:36,  4.12s/it]

------------------EPOCH 162------------------
END LOSS 	 train 2.646 valid 1.123 
 KL LOSS 	 train 4.147 valid 4.147 
 TOTAL LOSS 	 train 6.793 valid 5.27 



 82%|█████████████████████████████████████████████████████▊            | 163/200 [10:45<02:31,  4.10s/it]

------------------EPOCH 163------------------
END LOSS 	 train 2.931 valid 1.899 
 KL LOSS 	 train 4.146 valid 4.146 
 TOTAL LOSS 	 train 7.078 valid 6.045 



 82%|██████████████████████████████████████████████████████            | 164/200 [10:49<02:27,  4.09s/it]

------------------EPOCH 164------------------
END LOSS 	 train 2.835 valid 1.791 
 KL LOSS 	 train 4.146 valid 4.146 
 TOTAL LOSS 	 train 6.981 valid 5.937 



 82%|██████████████████████████████████████████████████████▍           | 165/200 [10:53<02:22,  4.07s/it]

------------------EPOCH 165------------------
END LOSS 	 train 2.625 valid 2.038 
 KL LOSS 	 train 4.146 valid 4.145 
 TOTAL LOSS 	 train 6.771 valid 6.183 



 83%|██████████████████████████████████████████████████████▊           | 166/200 [10:57<02:19,  4.10s/it]

------------------EPOCH 166------------------
END LOSS 	 train 2.579 valid 1.801 
 KL LOSS 	 train 4.145 valid 4.144 
 TOTAL LOSS 	 train 6.724 valid 5.945 



 84%|███████████████████████████████████████████████████████           | 167/200 [11:01<02:15,  4.09s/it]

------------------EPOCH 167------------------
END LOSS 	 train 2.756 valid 1.39 
 KL LOSS 	 train 4.145 valid 4.144 
 TOTAL LOSS 	 train 6.901 valid 5.535 



 84%|███████████████████████████████████████████████████████▍          | 168/200 [11:06<02:11,  4.10s/it]

------------------EPOCH 168------------------
END LOSS 	 train 2.622 valid 1.416 
 KL LOSS 	 train 4.144 valid 4.145 
 TOTAL LOSS 	 train 6.767 valid 5.562 



 84%|███████████████████████████████████████████████████████▊          | 169/200 [11:10<02:07,  4.13s/it]

------------------EPOCH 169------------------
END LOSS 	 train 2.606 valid 1.535 
 KL LOSS 	 train 4.144 valid 4.144 
 TOTAL LOSS 	 train 6.75 valid 5.68 



 85%|████████████████████████████████████████████████████████          | 170/200 [11:14<02:02,  4.09s/it]

------------------EPOCH 170------------------
END LOSS 	 train 2.901 valid 2.028 
 KL LOSS 	 train 4.144 valid 4.143 
 TOTAL LOSS 	 train 7.045 valid 6.171 



 86%|████████████████████████████████████████████████████████▍         | 171/200 [11:18<01:58,  4.08s/it]

------------------EPOCH 171------------------
END LOSS 	 train 2.767 valid 1.347 
 KL LOSS 	 train 4.143 valid 4.143 
 TOTAL LOSS 	 train 6.911 valid 5.491 



 86%|████████████████████████████████████████████████████████▊         | 172/200 [11:22<01:53,  4.06s/it]

------------------EPOCH 172------------------
END LOSS 	 train 2.68 valid 1.589 
 KL LOSS 	 train 4.142 valid 4.142 
 TOTAL LOSS 	 train 6.823 valid 5.732 



 86%|█████████████████████████████████████████████████████████         | 173/200 [11:26<01:50,  4.07s/it]

------------------EPOCH 173------------------
END LOSS 	 train 2.673 valid 0.756 
 KL LOSS 	 train 4.142 valid 4.141 
 TOTAL LOSS 	 train 6.815 valid 4.898 



 87%|█████████████████████████████████████████████████████████▍        | 174/200 [11:30<01:46,  4.09s/it]

------------------EPOCH 174------------------
END LOSS 	 train 2.485 valid 0.814 
 KL LOSS 	 train 4.142 valid 4.141 
 TOTAL LOSS 	 train 6.627 valid 4.956 



 88%|█████████████████████████████████████████████████████████▊        | 175/200 [11:34<01:41,  4.07s/it]

------------------EPOCH 175------------------
END LOSS 	 train 2.828 valid 0.666 
 KL LOSS 	 train 4.141 valid 4.141 
 TOTAL LOSS 	 train 6.97 valid 4.808 



 88%|██████████████████████████████████████████████████████████        | 176/200 [11:38<01:37,  4.05s/it]

------------------EPOCH 176------------------
END LOSS 	 train 2.622 valid 1.295 
 KL LOSS 	 train 4.141 valid 4.139 
 TOTAL LOSS 	 train 6.763 valid 5.435 



 88%|██████████████████████████████████████████████████████████▍       | 177/200 [11:42<01:32,  4.02s/it]

------------------EPOCH 177------------------
END LOSS 	 train 2.708 valid 1.13 
 KL LOSS 	 train 4.141 valid 4.141 
 TOTAL LOSS 	 train 6.849 valid 5.272 



 89%|██████████████████████████████████████████████████████████▋       | 178/200 [11:46<01:28,  4.03s/it]

------------------EPOCH 178------------------
END LOSS 	 train 2.685 valid 1.063 
 KL LOSS 	 train 4.14 valid 4.14 
 TOTAL LOSS 	 train 6.826 valid 5.203 



 90%|███████████████████████████████████████████████████████████       | 179/200 [11:50<01:24,  4.03s/it]

------------------EPOCH 179------------------
END LOSS 	 train 2.704 valid 1.207 
 KL LOSS 	 train 4.14 valid 4.139 
 TOTAL LOSS 	 train 6.845 valid 5.347 



 90%|███████████████████████████████████████████████████████████▍      | 180/200 [11:54<01:21,  4.07s/it]

------------------EPOCH 180------------------
END LOSS 	 train 2.506 valid 1.424 
 KL LOSS 	 train 4.14 valid 4.139 
 TOTAL LOSS 	 train 6.646 valid 5.563 



 90%|███████████████████████████████████████████████████████████▋      | 181/200 [11:58<01:16,  4.03s/it]

------------------EPOCH 181------------------
END LOSS 	 train 2.516 valid 1.095 
 KL LOSS 	 train 4.139 valid 4.139 
 TOTAL LOSS 	 train 6.656 valid 5.235 



 91%|████████████████████████████████████████████████████████████      | 182/200 [12:02<01:12,  4.06s/it]

------------------EPOCH 182------------------
END LOSS 	 train 2.485 valid 1.162 
 KL LOSS 	 train 4.139 valid 4.137 
 TOTAL LOSS 	 train 6.624 valid 5.299 



 92%|████████████████████████████████████████████████████████████▍     | 183/200 [12:06<01:08,  4.04s/it]

------------------EPOCH 183------------------
END LOSS 	 train 2.788 valid 2.151 
 KL LOSS 	 train 4.139 valid 4.138 
 TOTAL LOSS 	 train 6.927 valid 6.29 



 92%|████████████████████████████████████████████████████████████▋     | 184/200 [12:10<01:04,  4.04s/it]

------------------EPOCH 184------------------
END LOSS 	 train 2.631 valid 1.651 
 KL LOSS 	 train 4.138 valid 4.137 
 TOTAL LOSS 	 train 6.77 valid 5.788 



 92%|█████████████████████████████████████████████████████████████     | 185/200 [12:14<01:00,  4.04s/it]

------------------EPOCH 185------------------
END LOSS 	 train 2.597 valid 2.045 
 KL LOSS 	 train 4.138 valid 4.137 
 TOTAL LOSS 	 train 6.736 valid 6.183 



 93%|█████████████████████████████████████████████████████████████▍    | 186/200 [12:18<00:56,  4.04s/it]

------------------EPOCH 186------------------
END LOSS 	 train 2.541 valid 2.305 
 KL LOSS 	 train 4.138 valid 4.137 
 TOTAL LOSS 	 train 6.679 valid 6.442 



 94%|█████████████████████████████████████████████████████████████▋    | 187/200 [12:23<00:52,  4.05s/it]

------------------EPOCH 187------------------
END LOSS 	 train 2.522 valid 1.419 
 KL LOSS 	 train 4.137 valid 4.137 
 TOTAL LOSS 	 train 6.659 valid 5.556 



 94%|██████████████████████████████████████████████████████████████    | 188/200 [12:27<00:48,  4.02s/it]

------------------EPOCH 188------------------
END LOSS 	 train 2.401 valid 1.157 
 KL LOSS 	 train 4.137 valid 4.137 
 TOTAL LOSS 	 train 6.538 valid 5.294 



 94%|██████████████████████████████████████████████████████████████▎   | 189/200 [12:31<00:44,  4.03s/it]

------------------EPOCH 189------------------
END LOSS 	 train 2.653 valid 1.359 
 KL LOSS 	 train 4.137 valid 4.136 
 TOTAL LOSS 	 train 6.79 valid 5.495 



 95%|██████████████████████████████████████████████████████████████▋   | 190/200 [12:35<00:41,  4.11s/it]

------------------EPOCH 190------------------
END LOSS 	 train 2.919 valid 1.42 
 KL LOSS 	 train 4.136 valid 4.135 
 TOTAL LOSS 	 train 7.056 valid 5.556 



 96%|███████████████████████████████████████████████████████████████   | 191/200 [12:39<00:37,  4.17s/it]

------------------EPOCH 191------------------
END LOSS 	 train 2.431 valid 2.1 
 KL LOSS 	 train 4.136 valid 4.135 
 TOTAL LOSS 	 train 6.567 valid 6.235 



 96%|███████████████████████████████████████████████████████████████▎  | 192/200 [12:43<00:33,  4.13s/it]

------------------EPOCH 192------------------
END LOSS 	 train 2.504 valid 2.017 
 KL LOSS 	 train 4.135 valid 4.135 
 TOTAL LOSS 	 train 6.64 valid 6.153 



 96%|███████████████████████████████████████████████████████████████▋  | 193/200 [12:47<00:29,  4.17s/it]

------------------EPOCH 193------------------
END LOSS 	 train 2.57 valid 1.129 
 KL LOSS 	 train 4.135 valid 4.135 
 TOTAL LOSS 	 train 6.705 valid 5.265 



 97%|████████████████████████████████████████████████████████████████  | 194/200 [12:52<00:24,  4.13s/it]

------------------EPOCH 194------------------
END LOSS 	 train 2.647 valid 1.178 
 KL LOSS 	 train 4.134 valid 4.135 
 TOTAL LOSS 	 train 6.782 valid 5.313 



 98%|████████████████████████████████████████████████████████████████▎ | 195/200 [12:56<00:20,  4.10s/it]

------------------EPOCH 195------------------
END LOSS 	 train 2.696 valid 3.627 
 KL LOSS 	 train 4.134 valid 4.134 
 TOTAL LOSS 	 train 6.83 valid 7.762 



 98%|████████████████████████████████████████████████████████████████▋ | 196/200 [13:00<00:16,  4.10s/it]

------------------EPOCH 196------------------
END LOSS 	 train 2.579 valid 1.726 
 KL LOSS 	 train 4.134 valid 4.133 
 TOTAL LOSS 	 train 6.714 valid 5.86 



 98%|█████████████████████████████████████████████████████████████████ | 197/200 [13:04<00:12,  4.09s/it]

------------------EPOCH 197------------------
END LOSS 	 train 2.605 valid 0.96 
 KL LOSS 	 train 4.134 valid 4.133 
 TOTAL LOSS 	 train 6.739 valid 5.094 



 99%|█████████████████████████████████████████████████████████████████▎| 198/200 [13:08<00:08,  4.07s/it]

------------------EPOCH 198------------------
END LOSS 	 train 2.551 valid 1.522 
 KL LOSS 	 train 4.134 valid 4.133 
 TOTAL LOSS 	 train 6.685 valid 5.655 



100%|█████████████████████████████████████████████████████████████████▋| 199/200 [13:12<00:04,  4.06s/it]

------------------EPOCH 199------------------
END LOSS 	 train 2.506 valid 1.173 
 KL LOSS 	 train 4.133 valid 4.133 
 TOTAL LOSS 	 train 6.639 valid 5.307 



100%|██████████████████████████████████████████████████████████████████| 200/200 [13:16<00:00,  3.98s/it]

------------------EPOCH 200------------------
END LOSS 	 train 2.575 valid 1.021 
 KL LOSS 	 train 4.133 valid 4.133 
 TOTAL LOSS 	 train 6.708 valid 5.154 






In [27]:
# Number of samples handed to the model for the test-time predictive
# distribution (and reused by the evaluation cell that follows).
M_opt = 128
model_used.update_n_samples(n_samples=M_opt)

# This cell only evaluates on the test set -- nothing is backpropagated from
# here -- so run the forward pass and the inverse scaling without autograd
# tracking. That avoids building/retaining a graph for every sample and keeps
# the downstream loss from carrying a grad_fn it does not need.
with torch.no_grad():
    Y_pred = model_used.forward_dist(X_test, aleat_bool)
    # Map predictions back to the original (unscaled) target space.
    # NOTE(review): assumes inverse_transform undoes the target scaling applied
    # before training -- confirm against its definition earlier in the notebook.
    Y_pred_original = inverse_transform(Y_pred)

In [28]:
# Build the portfolio optimisation problem from cvar_minimization_utils (cmu)
# with the same sample count M_opt that was used to draw the predictions.
# NOTE(review): Y_original is wrapped with torch.tensor(); if it is already a
# tensor PyTorch emits a copy-construct warning -- confirm its type upstream.
op = cmu.RiskPortOP(M_opt, n_assets, min_return, torch.tensor(Y_original), dev)
# Evaluate the end (decision) loss of the predicted distribution against the
# test targets; as the last expression, its value is the cell's displayed output.
op.end_loss_dist(Y_pred_original, Y_test_original)

tensor(3.2776, dtype=torch.float64, grad_fn=<MeanBackward0>)