In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

import joblib
import random
import sys

from mip import Model as ModelMip, xsum, minimize, maximize, \
INTEGER, BINARY, CONTINUOUS, CutType, OptimizationStatus

import data_generator
from model import StrongStandardNet, StrongVariationalNet
from train import TrainDecoupled, TrainCombined

from sklearn.preprocessing import StandardScaler

import torch.multiprocessing as mp

In [2]:
import cvar_minimization_utils as cmu

In [3]:
is_cuda = False
dev = torch.device('cpu')  
if torch.cuda.is_available():
    is_cuda = True
    dev = torch.device('cuda') 

In [4]:
method_name = 'bnn'
method_learning = 'combined'
aleat_bool = 1

N_SAMPLES = 32
BATCH_SIZE_LOADER = 64

EPOCHS = 100

cpu_count = mp.cpu_count()

seed_number = 0

lr = 0.0005

In [5]:
np.random.seed(seed_number)
torch.manual_seed(seed_number)
random.seed(seed_number)

assert (method_name in ['ann','bnn','gp'])

if method_name in ['ann','bnn']:
    assert (method_learning in ['decoupled','combined'])
    assert (aleat_bool in [True, False])
    assert (N_SAMPLES>=1 and N_SAMPLES<9999)
    #assert (M_SAMPLES>=1 and M_SAMPLES<9999)

bnn = False 
if method_name == 'bnn':
    bnn = True   
    K = 1 # Hyperparameter for the training in ELBO loss
    PLV = 1 # Prior in ELBO loss   

model_name = method_name + '_constrained_'
for i in range(2, len(sys.argv)):
    model_name += '_'+sys.argv[i]
model_name += '_'+ str(seed_number)

In [6]:
def gen_intermidiate(num, n_assets, x1, x2, x3):      
        factor = num * 2/(n_assets)
        return x1**factor + x2**factor + x3**factor
    
def gen_data(N, n_assets, nl, seed_number=42):
    np.random.seed(seed_number)
    x1 = np.random.normal(1, 1, size = N).clip(0)
    x2 = np.random.normal(1, 1, size = N).clip(0)
    x3 = np.random.normal(1, 1, size = N).clip(0)
    X = np.vstack((x1, x2, x3)).T

    Y = np.zeros((N, n_assets))
    for i in range(1, n_assets + 1):
        interm = gen_intermidiate(i, n_assets, x1, x2, x3)
        Y[:,i-1] = (np.sin(interm) - np.sin(interm).mean()) \
        + np.sin(interm).std()*(-0.45 + np.random.random())
        
        ns = Y[:,0].shape[0]//2
        
        Y[:,i-1] = Y[:,i-1] + nl*Y[:,i-1]*x1*(np.random.beta(5, 2, size = Y[:,0].shape) - 0.2)

    return X, Y

def gen_cond_dist(N, n_assets, n_samples, nl, seed_number=420):
    np.random.seed(seed_number)
    Y_dist = np.zeros((n_samples, N, n_assets))
    for i in range(0, n_samples):
        Y_dist[i, :, :] = gen_data(N, n_assets, nl, seed_number=np.random.randint(0,999999))[1]
    return Y_dist

In [7]:
class CVaROP():
    def __init__(self, beta, n_assets, min_return, Y_train):
        self.beta = beta
        self.n_assets = n_assets
        self.R = min_return 
        self.uy = Y_train.mean(axis=0)
        
        self.zstar = np.zeros_like(Y_train)
        self.alphastar = np.zeros_like(Y_train[:,0])
        self.cvarstar = np.zeros_like(Y_train[:,0])
                        
        #for i in range(0, Y_train.shape[0]):
        #    self.zstar[i,:], _ , self.alphastar[i], self.cvarstar[i] = \
        #    self.minimize_cvar(np.expand_dims(Y_train, 0)[:,i,:])
      
    def minimize_cvar(self, y):
    
        n_samples = y.shape[0]
        n_assets = y.shape[1]
        
        assert self.n_assets == n_assets
        
        
        m = ModelMip("cvar")
        m.verbose = 0
        z = ([m.add_var(var_type=CONTINUOUS, name=f'z_{i}') for i in range(0, n_assets)])
        u = ([m.add_var(var_type=CONTINUOUS, name=f'u_{i}') for i in range(0, n_samples)])
        alpha = m.add_var(var_type=CONTINUOUS, name='alpha')

        m.objective = minimize(alpha + (1/(1 - self.beta))*(1/n_samples) \
                               *xsum(u[i] for i in range(0, n_samples)))

        for i in range(0, n_assets):
            m += z[i] >= 0

        for i in range(0, n_samples):
            m += u[i] >= 0

        for i in range(0, n_samples):
            m += xsum(z[j]*y[i][j] for j in range(0, n_assets)) + alpha + u[i] >= 0

        for i in range(0, n_assets):
            m += xsum(-z[i]*self.uy[i] for i in range(0, n_assets)) <= -self.R

        m.optimize()
        f_opt = m.objective_value
        argmins = []
        for v in m.vars:
            argmins.append(v.x)
           
        zstar = argmins[:n_assets]
        ustar = argmins[n_assets:n_assets+n_samples]
        alphastar = argmins[n_assets+n_samples]         

        return zstar, ustar, alphastar, f_opt
    
    
    def minimize_cvar_dataset(self, Y_dist):
        zstar = np.zeros_like(Y_dist[0])
        for i in range(0, Y_dist.shape[1]):
            zstar[i,:] , _, _, _ = self.minimize_cvar(Y_dist[:,i,:])       
        return zstar
    
    def minimize_cvar_given_z(self, y, z):
    
        n_samples = y.shape[0]
        n_assets = y.shape[1]
        
        assert self.n_assets == n_assets
        
        
        m = ModelMip("cvar")
        m.verbose = 0

        u = ([m.add_var(var_type=CONTINUOUS, name=f'u_{i}') for i in range(0, n_samples)])
        alpha = m.add_var(var_type=CONTINUOUS, name='alpha')

        m.objective = minimize(alpha + (1/(1 - self.beta))*(1/n_samples) \
                               *xsum(u[i] for i in range(0, n_samples)))

        for i in range(0, n_samples):
            m += u[i] >= 0

        for i in range(0, n_samples):
            m += xsum(z[j]*y[i][j] for j in range(0, n_assets)) + alpha + u[i] >= 0

        m.optimize()
        f_opt = m.objective_value
        argmins = []
        for v in m.vars:
            argmins.append(v.x)
           
        ustar = argmins[:n_samples]
        alphastar = argmins[n_samples]    

        return ustar, alphastar, f_opt
    
    
    def minimize_cvar_given_z_dataset(self, Y_dist, z):
        alphastar = np.zeros_like(Y_dist[0,:,0])
        for i in range(0, Y_dist.shape[1]):
            _, alphastar[i] , _ = self.minimize_cvar_given_z(Y_dist[:,i,:], z[i,:])
        return alphastar
        
    def calc_cvar_dataset(self, Y_dist, zstar_pred, alpha):
        cvar_pred = alpha + (1/(1-beta))*(np.maximum(-((Y_dist*zstar_pred).sum(2)) \
                                                              - alpha, 0)).mean(0)
        return cvar_pred
        
    
    def calc_f_dataset(self, Y_dist_pred, Y_dist):
        zstar_pred = self.minimize_cvar_dataset(self, Y_dist_pred)
        alphastar_pred = minimize_cvar_given_z_dataset(self, Y_dist, zstar_pred)
        cvar_pred = self.calc_cvar_dataset(Y_dist, zstar_pred)
        return cvar_pred
    
    
    def cost_fn(self, Y_dist_pred, Y_dist):
        f = self.calc_f_dataset(Y_dist_pred, Y_dist)
        f_total = torch.mean(f)
        return f_total
        
    #def end_loss(self, Y_dist_pred, Y_dist):
    #    y_pred = y_pred.unsqueeze(0)
    #    f_total = self.cost_fn(y_pred, y)
    #    return f_total
    
    def end_loss_dist(self, Y_dist_pred, Y_dist):
        f_total = self.cost_fn(Y_dist_pred, Y_dist)
        return f_total

In [8]:
class SolverDiff(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Y_dist_pred):
        
        Y_dist_pred = Y_dist_pred.detach().numpy()
        Y_dist = Ys[:,:,:,1].detach().numpy()
                
        zstar_pred = op_solver_dist.minimize_cvar_dataset(Y_dist_pred)
        alphastar_pred = op_solver_dist.minimize_cvar_given_z_dataset(
            np.expand_dims(Y_original, 0), zstar_pred)
        
            
        ctx.Y_true = Y_true
        ctx.Y = Y
        ctx.argmin = zstar_pred
        ctx.argmin_true = zstar_true
        ctx.sign_f = np.sign(f_pred)
        
        

        return torch.tensor(zstar_pred, dtype=torch.float32)#, torch.tensor(reg, dtype=torch.float32)

    # SPO
    @staticmethod
    def backward(ctx, grad_output):  
        
        zstar_true = np.zeros_like(Y_dist[0])
        for i in range(0, Y_dist.shape[1]):
            zstar_true[i,:] , _, _, _ = self.minimize_cvar(Y_dist[:,i,:])
        
        grad_output = grad_output.detach().numpy()
        Y_aux = 2*ctx.Y - ctx.Y_true
        argmin_aux = op.solve_dataset(Y_to_solve=ctx.ps*Y_aux, activations=ctx.activations, is_torch=False)
        grad_input = (-argmin_aux + ctx.argmin_true)*grad_output/ctx.Y_true
        grad_input = torch.tensor(grad_input, dtype=torch.float32)
                        
        return torch.dstack((grad_input, torch.zeros_like(grad_input))), None

In [9]:
N_train = 1000
N_val = 600
N_test = 600

n_assets = 15

n_samples_orig = 200

nl = 0.1

X, Y_original = gen_data(N_train, n_assets, nl, seed_number)
X_val, Y_val_original = gen_data(N_val, n_assets, nl, seed_number + 80)
X_test, Y_test_original = gen_data(N_test, n_assets, nl, seed_number + 160)

Y_dist = gen_cond_dist(N_train, n_assets, n_samples_orig, nl=nl, seed_number=seed_number)
Y_val_dist = gen_cond_dist(N_val, n_assets, n_samples_orig, nl=nl, seed_number=seed_number + 100)
Y_test_dist = gen_cond_dist(N_test, n_assets, n_samples_orig, nl=nl, seed_number=seed_number + 200)

In [10]:
# Output normalization
scaler = StandardScaler()
scaler.fit(Y_original)
tmean = torch.tensor(scaler.mean_).to(dev)
tstd = torch.tensor(scaler.scale_).to(dev)
joblib.dump(scaler, 'scaler_cvar.gz')

# Function to denormalize the data
def inverse_transform(yy):
    return yy*tstd + tmean

Y = scaler.transform(Y_original).copy()
X = torch.tensor(X, dtype=torch.float32)
Y = torch.tensor(Y, dtype=torch.float32)
data_train = data_generator.ArtificialDataset(X, Y)
training_loader = torch.utils.data.DataLoader(
    data_train, batch_size=BATCH_SIZE_LOADER,
    shuffle=True, num_workers=cpu_count)
Y_dist = torch.tensor(Y_dist, dtype=torch.float32)

Y_val = scaler.transform(Y_val_original).copy()
X_val = torch.tensor(X_val, dtype=torch.float32)
Y_val_original = torch.tensor(Y_val_original, dtype=torch.float32)
Y_val = torch.tensor(Y_val, dtype=torch.float32)
data_valid = data_generator.ArtificialDataset(X_val, Y_val)
validation_loader = torch.utils.data.DataLoader(
    data_valid, batch_size=BATCH_SIZE_LOADER,
    shuffle=False, num_workers=cpu_count)

X_test = torch.tensor(X_test, dtype=torch.float32)
Y_test_original = torch.tensor(
    Y_test_original, dtype=torch.float32)
data_test = data_generator.ArtificialDataset(
    X_test, Y_test_original)
test_loader = torch.utils.data.DataLoader(
data_test, batch_size=16,
shuffle=False, num_workers=cpu_count)

In [11]:
n_samples = Y.shape[0]
n_assets = Y.shape[1]
beta = 0.80
min_return = 100

In [40]:
method_name = 'ann'
method_learning = 'decoupled'

bnn = False
aleat_bool = False
if method_name == 'bnn':
    bnn = True
    aleat_bool = True
N_SAMPLES = 16
K = 1

In [41]:
if method_name == 'bnn':
    h = StrongVariationalNet(
    n_samples=N_SAMPLES,
    input_size=X.shape[1], 
    output_size=Y.shape[1], 
    plv=PLV, 
    dev=dev).to(dev)

    
#ANN Baseline model
elif method_name == 'ann':
    h = StrongStandardNet(X.shape[1], Y.shape[1]).to(dev)
    K = 0 # There is no K in ANN
    N_SAMPLES = 1
    
opt_h = torch.optim.Adam(h.parameters(), lr=lr)
mse_loss = nn.MSELoss(reduction='none')


opt_cvar_gZ = cmu.CVaROPgZ(1, n_assets, beta, min_return, torch.tensor(Y_original), dev)
opt_cvar = cmu.CVaRQP(N_SAMPLES, n_assets, beta, min_return, torch.tensor(Y_original), opt_cvar_gZ, dev)

# Decoupled learning approach
if method_learning == 'decoupled':
    train_NN = TrainDecoupled(
                    bnn = bnn,
                    model=h,
                    opt=opt_h,
                    loss_data=mse_loss,
                    K=K,
                    aleat_bool=aleat_bool,
                    training_loader=training_loader,
                    validation_loader=validation_loader,
                    dev=dev
                )

# Combined learning approach (end-to-end loss)
elif method_learning == 'combined':
    train_NN = TrainCombined(
                    bnn = bnn,
                    model=h,
                    opt=opt_h,
                    K=K,
                    aleat_bool=aleat_bool,
                    training_loader=training_loader,
                    scaler=scaler,
                    validation_loader=validation_loader,
                    OP=opt_cvar,
                    dev=dev
                )

In [42]:
model_used = train_NN.train(EPOCHS=200)

  0%|▎                                                                   | 1/200 [00:01<05:11,  1.57s/it]

------------------EPOCH 1------------------
DATA LOSS 	 train 0.908 valid 0.89
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.91 valid 0.89


  1%|▋                                                                   | 2/200 [00:03<05:06,  1.55s/it]

------------------EPOCH 2------------------
DATA LOSS 	 train 0.656 valid 0.677
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.66 valid 0.68


  2%|█                                                                   | 3/200 [00:04<05:03,  1.54s/it]

------------------EPOCH 3------------------
DATA LOSS 	 train 0.484 valid 0.592
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.48 valid 0.59


  2%|█▎                                                                  | 4/200 [00:06<05:05,  1.56s/it]

------------------EPOCH 4------------------
DATA LOSS 	 train 0.42 valid 0.543
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.42 valid 0.54


  2%|█▋                                                                  | 5/200 [00:07<05:06,  1.57s/it]

------------------EPOCH 5------------------
DATA LOSS 	 train 0.38 valid 0.516
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.38 valid 0.52


  3%|██                                                                  | 6/200 [00:09<05:01,  1.56s/it]

------------------EPOCH 6------------------
DATA LOSS 	 train 0.344 valid 0.458
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.34 valid 0.46


  4%|██▍                                                                 | 7/200 [00:10<04:58,  1.55s/it]

------------------EPOCH 7------------------
DATA LOSS 	 train 0.315 valid 0.444
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.31 valid 0.44


  4%|██▋                                                                 | 8/200 [00:12<04:54,  1.53s/it]

------------------EPOCH 8------------------
DATA LOSS 	 train 0.281 valid 0.428
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.28 valid 0.43


  4%|███                                                                 | 9/200 [00:13<04:57,  1.56s/it]

------------------EPOCH 9------------------
DATA LOSS 	 train 0.252 valid 0.393
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.25 valid 0.39


  5%|███▎                                                               | 10/200 [00:15<04:55,  1.56s/it]

------------------EPOCH 10------------------
DATA LOSS 	 train 0.228 valid 0.358
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.23 valid 0.36


  6%|███▋                                                               | 11/200 [00:17<04:53,  1.55s/it]

------------------EPOCH 11------------------
DATA LOSS 	 train 0.208 valid 0.351
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.21 valid 0.35


  6%|████                                                               | 12/200 [00:18<04:53,  1.56s/it]

------------------EPOCH 12------------------
DATA LOSS 	 train 0.196 valid 0.341
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.2 valid 0.34


  6%|████▎                                                              | 13/200 [00:20<04:50,  1.55s/it]

------------------EPOCH 13------------------
DATA LOSS 	 train 0.183 valid 0.34
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.18 valid 0.34


  7%|████▋                                                              | 14/200 [00:21<04:51,  1.57s/it]

------------------EPOCH 14------------------
DATA LOSS 	 train 0.174 valid 0.324
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.17 valid 0.32


  8%|█████                                                              | 15/200 [00:23<04:48,  1.56s/it]

------------------EPOCH 15------------------
DATA LOSS 	 train 0.165 valid 0.311
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.16 valid 0.31


  8%|█████▎                                                             | 16/200 [00:24<04:49,  1.57s/it]

------------------EPOCH 16------------------
DATA LOSS 	 train 0.158 valid 0.292
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.16 valid 0.29


  8%|█████▋                                                             | 17/200 [00:26<04:48,  1.57s/it]

------------------EPOCH 17------------------
DATA LOSS 	 train 0.15 valid 0.293
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.15 valid 0.29


  9%|██████                                                             | 18/200 [00:28<04:46,  1.58s/it]

------------------EPOCH 18------------------
DATA LOSS 	 train 0.144 valid 0.277
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.14 valid 0.28


 10%|██████▎                                                            | 19/200 [00:29<04:45,  1.58s/it]

------------------EPOCH 19------------------
DATA LOSS 	 train 0.139 valid 0.284
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.14 valid 0.28


 10%|██████▋                                                            | 20/200 [00:31<04:45,  1.59s/it]

------------------EPOCH 20------------------
DATA LOSS 	 train 0.132 valid 0.271
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.13 valid 0.27


 10%|███████                                                            | 21/200 [00:32<04:45,  1.59s/it]

------------------EPOCH 21------------------
DATA LOSS 	 train 0.127 valid 0.265
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.13 valid 0.26


 11%|███████▎                                                           | 22/200 [00:34<04:42,  1.59s/it]

------------------EPOCH 22------------------
DATA LOSS 	 train 0.122 valid 0.26
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.12 valid 0.26


 12%|███████▋                                                           | 23/200 [00:36<04:44,  1.61s/it]

------------------EPOCH 23------------------
DATA LOSS 	 train 0.117 valid 0.243
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.12 valid 0.24


 12%|████████                                                           | 24/200 [00:37<04:42,  1.60s/it]

------------------EPOCH 24------------------
DATA LOSS 	 train 0.113 valid 0.259
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.11 valid 0.26


 12%|████████▍                                                          | 25/200 [00:39<04:37,  1.59s/it]

------------------EPOCH 25------------------
DATA LOSS 	 train 0.109 valid 0.255
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.11 valid 0.26


 13%|████████▋                                                          | 26/200 [00:40<04:33,  1.57s/it]

------------------EPOCH 26------------------
DATA LOSS 	 train 0.105 valid 0.251
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.1 valid 0.25


 14%|█████████                                                          | 27/200 [00:42<04:31,  1.57s/it]

------------------EPOCH 27------------------
DATA LOSS 	 train 0.101 valid 0.231
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.1 valid 0.23


 14%|█████████▍                                                         | 28/200 [00:43<04:32,  1.59s/it]

------------------EPOCH 28------------------
DATA LOSS 	 train 0.098 valid 0.24
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.1 valid 0.24


 14%|█████████▋                                                         | 29/200 [00:45<04:34,  1.60s/it]

------------------EPOCH 29------------------
DATA LOSS 	 train 0.094 valid 0.221
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.09 valid 0.22


 15%|██████████                                                         | 30/200 [00:47<04:32,  1.60s/it]

------------------EPOCH 30------------------
DATA LOSS 	 train 0.092 valid 0.234
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.09 valid 0.23


 16%|██████████▍                                                        | 31/200 [00:48<04:29,  1.60s/it]

------------------EPOCH 31------------------
DATA LOSS 	 train 0.09 valid 0.233
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.09 valid 0.23


 16%|██████████▋                                                        | 32/200 [00:50<04:28,  1.60s/it]

------------------EPOCH 32------------------
DATA LOSS 	 train 0.089 valid 0.228
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.09 valid 0.23


 16%|███████████                                                        | 33/200 [00:52<04:27,  1.60s/it]

------------------EPOCH 33------------------
DATA LOSS 	 train 0.085 valid 0.227
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.09 valid 0.23


 17%|███████████▍                                                       | 34/200 [00:53<04:26,  1.61s/it]

------------------EPOCH 34------------------
DATA LOSS 	 train 0.084 valid 0.227
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.08 valid 0.23


 18%|███████████▋                                                       | 35/200 [00:55<04:23,  1.60s/it]

------------------EPOCH 35------------------
DATA LOSS 	 train 0.083 valid 0.223
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.08 valid 0.22


 18%|████████████                                                       | 36/200 [00:56<04:20,  1.59s/it]

------------------EPOCH 36------------------
DATA LOSS 	 train 0.08 valid 0.235
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.08 valid 0.23


 18%|████████████▍                                                      | 37/200 [00:58<04:19,  1.59s/it]

------------------EPOCH 37------------------
DATA LOSS 	 train 0.078 valid 0.23
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.08 valid 0.23


 19%|████████████▋                                                      | 38/200 [00:59<04:18,  1.59s/it]

------------------EPOCH 38------------------
DATA LOSS 	 train 0.077 valid 0.208
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.08 valid 0.21


 20%|█████████████                                                      | 39/200 [01:01<04:14,  1.58s/it]

------------------EPOCH 39------------------
DATA LOSS 	 train 0.075 valid 0.213
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.08 valid 0.21


 20%|█████████████▍                                                     | 40/200 [01:03<04:13,  1.59s/it]

------------------EPOCH 40------------------
DATA LOSS 	 train 0.073 valid 0.226
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.07 valid 0.23


 20%|█████████████▋                                                     | 41/200 [01:04<04:14,  1.60s/it]

------------------EPOCH 41------------------
DATA LOSS 	 train 0.073 valid 0.227
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.07 valid 0.23


 21%|██████████████                                                     | 42/200 [01:06<04:16,  1.63s/it]

------------------EPOCH 42------------------
DATA LOSS 	 train 0.072 valid 0.211
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.07 valid 0.21


 22%|██████████████▍                                                    | 43/200 [01:08<04:15,  1.63s/it]

------------------EPOCH 43------------------
DATA LOSS 	 train 0.07 valid 0.214
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.07 valid 0.21


 22%|██████████████▋                                                    | 44/200 [01:09<04:12,  1.62s/it]

------------------EPOCH 44------------------
DATA LOSS 	 train 0.07 valid 0.232
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.07 valid 0.23


 22%|███████████████                                                    | 45/200 [01:11<04:12,  1.63s/it]

------------------EPOCH 45------------------
DATA LOSS 	 train 0.068 valid 0.209
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.07 valid 0.21


 23%|███████████████▍                                                   | 46/200 [01:12<04:10,  1.63s/it]

------------------EPOCH 46------------------
DATA LOSS 	 train 0.066 valid 0.213
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.07 valid 0.21


 24%|███████████████▋                                                   | 47/200 [01:14<04:08,  1.63s/it]

------------------EPOCH 47------------------
DATA LOSS 	 train 0.066 valid 0.215
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.07 valid 0.22


 24%|████████████████                                                   | 48/200 [01:16<04:06,  1.62s/it]

------------------EPOCH 48------------------
DATA LOSS 	 train 0.065 valid 0.209
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.07 valid 0.21


 24%|████████████████▍                                                  | 49/200 [01:17<04:05,  1.63s/it]

------------------EPOCH 49------------------
DATA LOSS 	 train 0.063 valid 0.199
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.2


 25%|████████████████▊                                                  | 50/200 [01:19<04:03,  1.62s/it]

------------------EPOCH 50------------------
DATA LOSS 	 train 0.063 valid 0.212
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.21


 26%|█████████████████                                                  | 51/200 [01:21<04:03,  1.64s/it]

------------------EPOCH 51------------------
DATA LOSS 	 train 0.061 valid 0.207
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.21


 26%|█████████████████▍                                                 | 52/200 [01:22<03:59,  1.62s/it]

------------------EPOCH 52------------------
DATA LOSS 	 train 0.06 valid 0.207
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.21


 26%|█████████████████▊                                                 | 53/200 [01:24<03:59,  1.63s/it]

------------------EPOCH 53------------------
DATA LOSS 	 train 0.059 valid 0.187
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.19


 27%|██████████████████                                                 | 54/200 [01:25<03:56,  1.62s/it]

------------------EPOCH 54------------------
DATA LOSS 	 train 0.06 valid 0.209
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.21


 28%|██████████████████▍                                                | 55/200 [01:27<03:53,  1.61s/it]

------------------EPOCH 55------------------
DATA LOSS 	 train 0.058 valid 0.19
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.19


 28%|██████████████████▊                                                | 56/200 [01:29<03:53,  1.62s/it]

------------------EPOCH 56------------------
DATA LOSS 	 train 0.057 valid 0.203
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.2


 28%|███████████████████                                                | 57/200 [01:30<03:52,  1.63s/it]

------------------EPOCH 57------------------
DATA LOSS 	 train 0.057 valid 0.208
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.21


 29%|███████████████████▍                                               | 58/200 [01:32<03:50,  1.63s/it]

------------------EPOCH 58------------------
DATA LOSS 	 train 0.055 valid 0.2
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.2


 30%|███████████████████▊                                               | 59/200 [01:34<03:46,  1.61s/it]

------------------EPOCH 59------------------
DATA LOSS 	 train 0.056 valid 0.194
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.06 valid 0.19


 30%|████████████████████                                               | 60/200 [01:35<03:46,  1.62s/it]

------------------EPOCH 60------------------
DATA LOSS 	 train 0.054 valid 0.191
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 30%|████████████████████▍                                              | 61/200 [01:37<03:44,  1.61s/it]

------------------EPOCH 61------------------
DATA LOSS 	 train 0.053 valid 0.194
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 31%|████████████████████▊                                              | 62/200 [01:38<03:43,  1.62s/it]

------------------EPOCH 62------------------
DATA LOSS 	 train 0.053 valid 0.185
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 32%|█████████████████████                                              | 63/200 [01:40<03:40,  1.61s/it]

------------------EPOCH 63------------------
DATA LOSS 	 train 0.052 valid 0.184
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.18


 32%|█████████████████████▍                                             | 64/200 [01:42<03:39,  1.61s/it]

------------------EPOCH 64------------------
DATA LOSS 	 train 0.051 valid 0.196
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.2


 32%|█████████████████████▊                                             | 65/200 [01:43<03:35,  1.60s/it]

------------------EPOCH 65------------------
DATA LOSS 	 train 0.051 valid 0.196
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.2


 33%|██████████████████████                                             | 66/200 [01:45<03:34,  1.60s/it]

------------------EPOCH 66------------------
DATA LOSS 	 train 0.05 valid 0.189
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 34%|██████████████████████▍                                            | 67/200 [01:46<03:32,  1.60s/it]

------------------EPOCH 67------------------
DATA LOSS 	 train 0.05 valid 0.192
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 34%|██████████████████████▊                                            | 68/200 [01:48<03:31,  1.61s/it]

------------------EPOCH 68------------------
DATA LOSS 	 train 0.049 valid 0.189
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 34%|███████████████████████                                            | 69/200 [01:50<03:30,  1.60s/it]

------------------EPOCH 69------------------
DATA LOSS 	 train 0.048 valid 0.205
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.21


 35%|███████████████████████▍                                           | 70/200 [01:51<03:33,  1.64s/it]

------------------EPOCH 70------------------
DATA LOSS 	 train 0.049 valid 0.178
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.18


 36%|███████████████████████▊                                           | 71/200 [01:53<03:33,  1.66s/it]

------------------EPOCH 71------------------
DATA LOSS 	 train 0.048 valid 0.188
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 36%|████████████████████████                                           | 72/200 [01:55<03:31,  1.65s/it]

------------------EPOCH 72------------------
DATA LOSS 	 train 0.047 valid 0.18
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.18


 36%|████████████████████████▍                                          | 73/200 [01:56<03:29,  1.65s/it]

------------------EPOCH 73------------------
DATA LOSS 	 train 0.047 valid 0.179
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.18


 37%|████████████████████████▊                                          | 74/200 [01:58<03:26,  1.64s/it]

------------------EPOCH 74------------------
DATA LOSS 	 train 0.047 valid 0.192
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 38%|█████████████████████████▏                                         | 75/200 [02:00<03:25,  1.64s/it]

------------------EPOCH 75------------------
DATA LOSS 	 train 0.046 valid 0.19
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 38%|█████████████████████████▍                                         | 76/200 [02:01<03:22,  1.63s/it]

------------------EPOCH 76------------------
DATA LOSS 	 train 0.046 valid 0.189
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.05 valid 0.19


 38%|█████████████████████████▊                                         | 77/200 [02:03<03:18,  1.62s/it]

------------------EPOCH 77------------------
DATA LOSS 	 train 0.045 valid 0.189
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.19


 39%|██████████████████████████▏                                        | 78/200 [02:04<03:18,  1.63s/it]

------------------EPOCH 78------------------
DATA LOSS 	 train 0.045 valid 0.181
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 40%|██████████████████████████▍                                        | 79/200 [02:06<03:16,  1.63s/it]

------------------EPOCH 79------------------
DATA LOSS 	 train 0.045 valid 0.187
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.19


 40%|██████████████████████████▊                                        | 80/200 [02:08<03:14,  1.62s/it]

------------------EPOCH 80------------------
DATA LOSS 	 train 0.044 valid 0.171
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 40%|███████████████████████████▏                                       | 81/200 [02:09<03:14,  1.63s/it]

------------------EPOCH 81------------------
DATA LOSS 	 train 0.044 valid 0.187
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.19


 41%|███████████████████████████▍                                       | 82/200 [02:11<03:11,  1.62s/it]

------------------EPOCH 82------------------
DATA LOSS 	 train 0.042 valid 0.183
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 42%|███████████████████████████▊                                       | 83/200 [02:13<03:12,  1.65s/it]

------------------EPOCH 83------------------
DATA LOSS 	 train 0.042 valid 0.178
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 42%|████████████████████████████▏                                      | 84/200 [02:14<03:08,  1.63s/it]

------------------EPOCH 84------------------
DATA LOSS 	 train 0.044 valid 0.198
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.2


 42%|████████████████████████████▍                                      | 85/200 [02:16<03:07,  1.63s/it]

------------------EPOCH 85------------------
DATA LOSS 	 train 0.043 valid 0.194
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.19


 43%|████████████████████████████▊                                      | 86/200 [02:17<03:05,  1.63s/it]

------------------EPOCH 86------------------
DATA LOSS 	 train 0.043 valid 0.184
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 44%|█████████████████████████████▏                                     | 87/200 [02:19<03:04,  1.63s/it]

------------------EPOCH 87------------------
DATA LOSS 	 train 0.041 valid 0.178
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 44%|█████████████████████████████▍                                     | 88/200 [02:21<03:02,  1.63s/it]

------------------EPOCH 88------------------
DATA LOSS 	 train 0.042 valid 0.181
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 44%|█████████████████████████████▊                                     | 89/200 [02:22<03:02,  1.64s/it]

------------------EPOCH 89------------------
DATA LOSS 	 train 0.041 valid 0.188
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.19


 45%|██████████████████████████████▏                                    | 90/200 [02:24<02:59,  1.64s/it]

------------------EPOCH 90------------------
DATA LOSS 	 train 0.041 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 46%|██████████████████████████████▍                                    | 91/200 [02:26<02:57,  1.63s/it]

------------------EPOCH 91------------------
DATA LOSS 	 train 0.04 valid 0.177
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 46%|██████████████████████████████▊                                    | 92/200 [02:27<02:55,  1.63s/it]

------------------EPOCH 92------------------
DATA LOSS 	 train 0.04 valid 0.186
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.19


 46%|███████████████████████████████▏                                   | 93/200 [02:29<02:58,  1.67s/it]

------------------EPOCH 93------------------
DATA LOSS 	 train 0.039 valid 0.181
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 47%|███████████████████████████████▍                                   | 94/200 [02:31<03:05,  1.75s/it]

------------------EPOCH 94------------------
DATA LOSS 	 train 0.04 valid 0.177
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 48%|███████████████████████████████▊                                   | 95/200 [02:33<03:05,  1.77s/it]

------------------EPOCH 95------------------
DATA LOSS 	 train 0.039 valid 0.181
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 48%|████████████████████████████████▏                                  | 96/200 [02:35<03:05,  1.79s/it]

------------------EPOCH 96------------------
DATA LOSS 	 train 0.039 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 48%|████████████████████████████████▍                                  | 97/200 [02:36<03:04,  1.79s/it]

------------------EPOCH 97------------------
DATA LOSS 	 train 0.038 valid 0.182
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 49%|████████████████████████████████▊                                  | 98/200 [02:38<03:03,  1.80s/it]

------------------EPOCH 98------------------
DATA LOSS 	 train 0.038 valid 0.179
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 50%|█████████████████████████████████▏                                 | 99/200 [02:40<03:01,  1.79s/it]

------------------EPOCH 99------------------
DATA LOSS 	 train 0.038 valid 0.172
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 50%|█████████████████████████████████                                 | 100/200 [02:42<03:00,  1.80s/it]

------------------EPOCH 100------------------
DATA LOSS 	 train 0.037 valid 0.182
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 50%|█████████████████████████████████▎                                | 101/200 [02:44<02:58,  1.81s/it]

------------------EPOCH 101------------------
DATA LOSS 	 train 0.037 valid 0.186
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.19


 51%|█████████████████████████████████▋                                | 102/200 [02:45<02:56,  1.80s/it]

------------------EPOCH 102------------------
DATA LOSS 	 train 0.038 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 52%|█████████████████████████████████▉                                | 103/200 [02:47<02:53,  1.79s/it]

------------------EPOCH 103------------------
DATA LOSS 	 train 0.037 valid 0.181
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 52%|██████████████████████████████████▎                               | 104/200 [02:49<02:52,  1.80s/it]

------------------EPOCH 104------------------
DATA LOSS 	 train 0.036 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 52%|██████████████████████████████████▋                               | 105/200 [02:51<02:51,  1.81s/it]

------------------EPOCH 105------------------
DATA LOSS 	 train 0.037 valid 0.183
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 53%|██████████████████████████████████▉                               | 106/200 [02:53<02:49,  1.80s/it]

------------------EPOCH 106------------------
DATA LOSS 	 train 0.036 valid 0.171
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 54%|███████████████████████████████████▎                              | 107/200 [02:54<02:47,  1.80s/it]

------------------EPOCH 107------------------
DATA LOSS 	 train 0.036 valid 0.175
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 54%|███████████████████████████████████▋                              | 108/200 [02:56<02:42,  1.77s/it]

------------------EPOCH 108------------------
DATA LOSS 	 train 0.036 valid 0.18
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.18


 55%|███████████████████████████████████▉                              | 109/200 [02:58<02:38,  1.74s/it]

------------------EPOCH 109------------------
DATA LOSS 	 train 0.035 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 55%|████████████████████████████████████▎                             | 110/200 [02:59<02:35,  1.72s/it]

------------------EPOCH 110------------------
DATA LOSS 	 train 0.035 valid 0.166
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.04 valid 0.17


 56%|████████████████████████████████████▋                             | 111/200 [03:01<02:33,  1.72s/it]

------------------EPOCH 111------------------
DATA LOSS 	 train 0.035 valid 0.18
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.18


 56%|████████████████████████████████████▉                             | 112/200 [03:03<02:30,  1.70s/it]

------------------EPOCH 112------------------
DATA LOSS 	 train 0.035 valid 0.173
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 56%|█████████████████████████████████████▎                            | 113/200 [03:04<02:26,  1.69s/it]

------------------EPOCH 113------------------
DATA LOSS 	 train 0.035 valid 0.174
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 57%|█████████████████████████████████████▌                            | 114/200 [03:06<02:24,  1.68s/it]

------------------EPOCH 114------------------
DATA LOSS 	 train 0.034 valid 0.181
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.18


 57%|█████████████████████████████████████▉                            | 115/200 [03:08<02:22,  1.68s/it]

------------------EPOCH 115------------------
DATA LOSS 	 train 0.034 valid 0.171
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 58%|██████████████████████████████████████▎                           | 116/200 [03:09<02:20,  1.68s/it]

------------------EPOCH 116------------------
DATA LOSS 	 train 0.035 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 58%|██████████████████████████████████████▌                           | 117/200 [03:11<02:18,  1.67s/it]

------------------EPOCH 117------------------
DATA LOSS 	 train 0.034 valid 0.174
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 59%|██████████████████████████████████████▉                           | 118/200 [03:13<02:15,  1.65s/it]

------------------EPOCH 118------------------
DATA LOSS 	 train 0.033 valid 0.174
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 60%|███████████████████████████████████████▎                          | 119/200 [03:14<02:13,  1.65s/it]

------------------EPOCH 119------------------
DATA LOSS 	 train 0.033 valid 0.176
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.18


 60%|███████████████████████████████████████▌                          | 120/200 [03:16<02:12,  1.66s/it]

------------------EPOCH 120------------------
DATA LOSS 	 train 0.033 valid 0.17
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 60%|███████████████████████████████████████▉                          | 121/200 [03:18<02:12,  1.68s/it]

------------------EPOCH 121------------------
DATA LOSS 	 train 0.033 valid 0.17
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 61%|████████████████████████████████████████▎                         | 122/200 [03:19<02:09,  1.66s/it]

------------------EPOCH 122------------------
DATA LOSS 	 train 0.033 valid 0.173
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 62%|████████████████████████████████████████▌                         | 123/200 [03:21<02:07,  1.66s/it]

------------------EPOCH 123------------------
DATA LOSS 	 train 0.033 valid 0.177
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.18


 62%|████████████████████████████████████████▉                         | 124/200 [03:23<02:06,  1.66s/it]

------------------EPOCH 124------------------
DATA LOSS 	 train 0.032 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 62%|█████████████████████████████████████████▎                        | 125/200 [03:24<02:05,  1.68s/it]

------------------EPOCH 125------------------
DATA LOSS 	 train 0.033 valid 0.176
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.18


 63%|█████████████████████████████████████████▌                        | 126/200 [03:26<02:03,  1.67s/it]

------------------EPOCH 126------------------
DATA LOSS 	 train 0.032 valid 0.17
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 64%|█████████████████████████████████████████▉                        | 127/200 [03:28<02:02,  1.68s/it]

------------------EPOCH 127------------------
DATA LOSS 	 train 0.032 valid 0.166
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 64%|██████████████████████████████████████████▏                       | 128/200 [03:29<02:01,  1.69s/it]

------------------EPOCH 128------------------
DATA LOSS 	 train 0.031 valid 0.172
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 64%|██████████████████████████████████████████▌                       | 129/200 [03:31<01:59,  1.68s/it]

------------------EPOCH 129------------------
DATA LOSS 	 train 0.032 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 65%|██████████████████████████████████████████▉                       | 130/200 [03:33<01:57,  1.68s/it]

------------------EPOCH 130------------------
DATA LOSS 	 train 0.031 valid 0.174
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 66%|███████████████████████████████████████████▏                      | 131/200 [03:34<01:55,  1.67s/it]

------------------EPOCH 131------------------
DATA LOSS 	 train 0.031 valid 0.162
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 66%|███████████████████████████████████████████▌                      | 132/200 [03:36<01:54,  1.68s/it]

------------------EPOCH 132------------------
DATA LOSS 	 train 0.032 valid 0.181
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.18


 66%|███████████████████████████████████████████▉                      | 133/200 [03:38<01:52,  1.68s/it]

------------------EPOCH 133------------------
DATA LOSS 	 train 0.031 valid 0.167
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 67%|████████████████████████████████████████████▏                     | 134/200 [03:40<01:50,  1.68s/it]

------------------EPOCH 134------------------
DATA LOSS 	 train 0.031 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 68%|████████████████████████████████████████████▌                     | 135/200 [03:41<01:49,  1.68s/it]

------------------EPOCH 135------------------
DATA LOSS 	 train 0.031 valid 0.176
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.18


 68%|████████████████████████████████████████████▉                     | 136/200 [03:43<01:47,  1.68s/it]

------------------EPOCH 136------------------
DATA LOSS 	 train 0.031 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 68%|█████████████████████████████████████████████▏                    | 137/200 [03:45<01:45,  1.68s/it]

------------------EPOCH 137------------------
DATA LOSS 	 train 0.031 valid 0.171
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 69%|█████████████████████████████████████████████▌                    | 138/200 [03:46<01:44,  1.69s/it]

------------------EPOCH 138------------------
DATA LOSS 	 train 0.031 valid 0.167
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 70%|█████████████████████████████████████████████▊                    | 139/200 [03:48<01:42,  1.68s/it]

------------------EPOCH 139------------------
DATA LOSS 	 train 0.03 valid 0.178
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.18


 70%|██████████████████████████████████████████████▏                   | 140/200 [03:50<01:40,  1.67s/it]

------------------EPOCH 140------------------
DATA LOSS 	 train 0.03 valid 0.171
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 70%|██████████████████████████████████████████████▌                   | 141/200 [03:51<01:38,  1.67s/it]

------------------EPOCH 141------------------
DATA LOSS 	 train 0.031 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 71%|██████████████████████████████████████████████▊                   | 142/200 [03:53<01:37,  1.67s/it]

------------------EPOCH 142------------------
DATA LOSS 	 train 0.03 valid 0.172
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 72%|███████████████████████████████████████████████▏                  | 143/200 [03:55<01:34,  1.66s/it]

------------------EPOCH 143------------------
DATA LOSS 	 train 0.03 valid 0.167
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 72%|███████████████████████████████████████████████▌                  | 144/200 [03:56<01:32,  1.66s/it]

------------------EPOCH 144------------------
DATA LOSS 	 train 0.03 valid 0.17
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 72%|███████████████████████████████████████████████▊                  | 145/200 [03:58<01:31,  1.66s/it]

------------------EPOCH 145------------------
DATA LOSS 	 train 0.029 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 73%|████████████████████████████████████████████████▏                 | 146/200 [04:00<01:29,  1.66s/it]

------------------EPOCH 146------------------
DATA LOSS 	 train 0.03 valid 0.171
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 74%|████████████████████████████████████████████████▌                 | 147/200 [04:01<01:28,  1.66s/it]

------------------EPOCH 147------------------
DATA LOSS 	 train 0.029 valid 0.163
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 74%|████████████████████████████████████████████████▊                 | 148/200 [04:03<01:26,  1.66s/it]

------------------EPOCH 148------------------
DATA LOSS 	 train 0.029 valid 0.17
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 74%|█████████████████████████████████████████████████▏                | 149/200 [04:05<01:24,  1.67s/it]

------------------EPOCH 149------------------
DATA LOSS 	 train 0.029 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 75%|█████████████████████████████████████████████████▌                | 150/200 [04:06<01:23,  1.67s/it]

------------------EPOCH 150------------------
DATA LOSS 	 train 0.029 valid 0.165
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 76%|█████████████████████████████████████████████████▊                | 151/200 [04:08<01:22,  1.68s/it]

------------------EPOCH 151------------------
DATA LOSS 	 train 0.029 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 76%|██████████████████████████████████████████████████▏               | 152/200 [04:10<01:20,  1.67s/it]

------------------EPOCH 152------------------
DATA LOSS 	 train 0.029 valid 0.166
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 76%|██████████████████████████████████████████████████▍               | 153/200 [04:11<01:18,  1.67s/it]

------------------EPOCH 153------------------
DATA LOSS 	 train 0.029 valid 0.165
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 77%|██████████████████████████████████████████████████▊               | 154/200 [04:13<01:16,  1.67s/it]

------------------EPOCH 154------------------
DATA LOSS 	 train 0.029 valid 0.162
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 78%|███████████████████████████████████████████████████▏              | 155/200 [04:15<01:15,  1.68s/it]

------------------EPOCH 155------------------
DATA LOSS 	 train 0.029 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 78%|███████████████████████████████████████████████████▍              | 156/200 [04:16<01:13,  1.68s/it]

------------------EPOCH 156------------------
DATA LOSS 	 train 0.028 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 78%|███████████████████████████████████████████████████▊              | 157/200 [04:18<01:11,  1.67s/it]

------------------EPOCH 157------------------
DATA LOSS 	 train 0.029 valid 0.171
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 79%|████████████████████████████████████████████████████▏             | 158/200 [04:20<01:09,  1.66s/it]

------------------EPOCH 158------------------
DATA LOSS 	 train 0.029 valid 0.167
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 80%|████████████████████████████████████████████████████▍             | 159/200 [04:21<01:08,  1.67s/it]

------------------EPOCH 159------------------
DATA LOSS 	 train 0.028 valid 0.165
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 80%|████████████████████████████████████████████████████▊             | 160/200 [04:23<01:06,  1.66s/it]

------------------EPOCH 160------------------
DATA LOSS 	 train 0.028 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 80%|█████████████████████████████████████████████████████▏            | 161/200 [04:25<01:04,  1.66s/it]

------------------EPOCH 161------------------
DATA LOSS 	 train 0.028 valid 0.166
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 81%|█████████████████████████████████████████████████████▍            | 162/200 [04:26<01:02,  1.66s/it]

------------------EPOCH 162------------------
DATA LOSS 	 train 0.028 valid 0.17
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 82%|█████████████████████████████████████████████████████▊            | 163/200 [04:28<01:01,  1.67s/it]

------------------EPOCH 163------------------
DATA LOSS 	 train 0.028 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 82%|██████████████████████████████████████████████████████            | 164/200 [04:30<00:59,  1.66s/it]

------------------EPOCH 164------------------
DATA LOSS 	 train 0.027 valid 0.165
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 82%|██████████████████████████████████████████████████████▍           | 165/200 [04:31<00:58,  1.66s/it]

------------------EPOCH 165------------------
DATA LOSS 	 train 0.027 valid 0.161
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 83%|██████████████████████████████████████████████████████▊           | 166/200 [04:33<00:56,  1.67s/it]

------------------EPOCH 166------------------
DATA LOSS 	 train 0.028 valid 0.163
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 84%|███████████████████████████████████████████████████████           | 167/200 [04:35<00:54,  1.66s/it]

------------------EPOCH 167------------------
DATA LOSS 	 train 0.028 valid 0.164
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 84%|███████████████████████████████████████████████████████▍          | 168/200 [04:36<00:53,  1.66s/it]

------------------EPOCH 168------------------
DATA LOSS 	 train 0.028 valid 0.165
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 84%|███████████████████████████████████████████████████████▊          | 169/200 [04:38<00:51,  1.66s/it]

------------------EPOCH 169------------------
DATA LOSS 	 train 0.027 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 85%|████████████████████████████████████████████████████████          | 170/200 [04:40<00:50,  1.67s/it]

------------------EPOCH 170------------------
DATA LOSS 	 train 0.027 valid 0.167
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 86%|████████████████████████████████████████████████████████▍         | 171/200 [04:41<00:48,  1.67s/it]

------------------EPOCH 171------------------
DATA LOSS 	 train 0.027 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 86%|████████████████████████████████████████████████████████▊         | 172/200 [04:43<00:46,  1.66s/it]

------------------EPOCH 172------------------
DATA LOSS 	 train 0.027 valid 0.166
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 86%|█████████████████████████████████████████████████████████         | 173/200 [04:45<00:45,  1.67s/it]

------------------EPOCH 173------------------
DATA LOSS 	 train 0.027 valid 0.165
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 87%|█████████████████████████████████████████████████████████▍        | 174/200 [04:46<00:43,  1.67s/it]

------------------EPOCH 174------------------
DATA LOSS 	 train 0.027 valid 0.164
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 88%|█████████████████████████████████████████████████████████▊        | 175/200 [04:48<00:41,  1.68s/it]

------------------EPOCH 175------------------
DATA LOSS 	 train 0.027 valid 0.162
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 88%|██████████████████████████████████████████████████████████        | 176/200 [04:50<00:40,  1.68s/it]

------------------EPOCH 176------------------
DATA LOSS 	 train 0.027 valid 0.16
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 88%|██████████████████████████████████████████████████████████▍       | 177/200 [04:51<00:38,  1.66s/it]

------------------EPOCH 177------------------
DATA LOSS 	 train 0.027 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 89%|██████████████████████████████████████████████████████████▋       | 178/200 [04:53<00:36,  1.67s/it]

------------------EPOCH 178------------------
DATA LOSS 	 train 0.027 valid 0.169
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 90%|███████████████████████████████████████████████████████████       | 179/200 [04:55<00:35,  1.67s/it]

------------------EPOCH 179------------------
DATA LOSS 	 train 0.027 valid 0.167
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 90%|███████████████████████████████████████████████████████████▍      | 180/200 [04:56<00:33,  1.68s/it]

------------------EPOCH 180------------------
DATA LOSS 	 train 0.027 valid 0.165
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 90%|███████████████████████████████████████████████████████████▋      | 181/200 [04:58<00:31,  1.68s/it]

------------------EPOCH 181------------------
DATA LOSS 	 train 0.026 valid 0.167
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 91%|████████████████████████████████████████████████████████████      | 182/200 [05:00<00:30,  1.69s/it]

------------------EPOCH 182------------------
DATA LOSS 	 train 0.026 valid 0.166
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 92%|████████████████████████████████████████████████████████████▍     | 183/200 [05:01<00:28,  1.68s/it]

------------------EPOCH 183------------------
DATA LOSS 	 train 0.026 valid 0.164
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 92%|████████████████████████████████████████████████████████████▋     | 184/200 [05:03<00:26,  1.67s/it]

------------------EPOCH 184------------------
DATA LOSS 	 train 0.026 valid 0.162
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 92%|█████████████████████████████████████████████████████████████     | 185/200 [05:05<00:25,  1.67s/it]

------------------EPOCH 185------------------
DATA LOSS 	 train 0.026 valid 0.164
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 93%|█████████████████████████████████████████████████████████████▍    | 186/200 [05:06<00:23,  1.67s/it]

------------------EPOCH 186------------------
DATA LOSS 	 train 0.026 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 94%|█████████████████████████████████████████████████████████████▋    | 187/200 [05:08<00:21,  1.67s/it]

------------------EPOCH 187------------------
DATA LOSS 	 train 0.026 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 94%|██████████████████████████████████████████████████████████████    | 188/200 [05:10<00:19,  1.67s/it]

------------------EPOCH 188------------------
DATA LOSS 	 train 0.026 valid 0.166
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 94%|██████████████████████████████████████████████████████████████▎   | 189/200 [05:11<00:18,  1.66s/it]

------------------EPOCH 189------------------
DATA LOSS 	 train 0.026 valid 0.164
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 95%|██████████████████████████████████████████████████████████████▋   | 190/200 [05:13<00:16,  1.66s/it]

------------------EPOCH 190------------------
DATA LOSS 	 train 0.026 valid 0.165
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 96%|███████████████████████████████████████████████████████████████   | 191/200 [05:15<00:15,  1.67s/it]

------------------EPOCH 191------------------
DATA LOSS 	 train 0.026 valid 0.163
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 96%|███████████████████████████████████████████████████████████████▎  | 192/200 [05:16<00:13,  1.67s/it]

------------------EPOCH 192------------------
DATA LOSS 	 train 0.025 valid 0.162
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 96%|███████████████████████████████████████████████████████████████▋  | 193/200 [05:18<00:11,  1.67s/it]

------------------EPOCH 193------------------
DATA LOSS 	 train 0.025 valid 0.168
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 97%|████████████████████████████████████████████████████████████████  | 194/200 [05:20<00:10,  1.68s/it]

------------------EPOCH 194------------------
DATA LOSS 	 train 0.026 valid 0.166
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.17


 98%|████████████████████████████████████████████████████████████████▎ | 195/200 [05:21<00:08,  1.70s/it]

------------------EPOCH 195------------------
DATA LOSS 	 train 0.025 valid 0.164
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 98%|████████████████████████████████████████████████████████████████▋ | 196/200 [05:23<00:06,  1.72s/it]

------------------EPOCH 196------------------
DATA LOSS 	 train 0.026 valid 0.161
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 98%|█████████████████████████████████████████████████████████████████ | 197/200 [05:25<00:05,  1.70s/it]

------------------EPOCH 197------------------
DATA LOSS 	 train 0.026 valid 0.164
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


 99%|█████████████████████████████████████████████████████████████████▎| 198/200 [05:27<00:03,  1.69s/it]

------------------EPOCH 198------------------
DATA LOSS 	 train 0.025 valid 0.163
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.03 valid 0.16


100%|█████████████████████████████████████████████████████████████████▋| 199/200 [05:28<00:01,  1.68s/it]

------------------EPOCH 199------------------
DATA LOSS 	 train 0.025 valid 0.165
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.02 valid 0.16


100%|██████████████████████████████████████████████████████████████████| 200/200 [05:30<00:00,  1.65s/it]

------------------EPOCH 200------------------
DATA LOSS 	 train 0.025 valid 0.161
KL LOSS 	 train 0.0 valid 0.0
ELBO LOSS 	 train 0.02 valid 0.16





In [43]:
M_opt = 1
model_used.update_n_samples(n_samples=M_opt)
Y_pred = model_used.forward_dist(X_test, aleat_bool)
Y_pred_original = inverse_transform(Y_pred)

In [44]:
opt_cvar = cmu.CVaRQP(M_opt, n_assets, beta, min_return, torch.tensor(Y_original), opt_cvar_gZ, dev)
opt_cvar.end_loss_dist(Y_pred_original, Y_test_original)

tensor(27.0925, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [39]:
M_opt = 1
model_used.update_n_samples(n_samples=M_opt)
Y_pred_val = model_used.forward_dist(X_val, aleat_bool)
Y_pred_val_original = inverse_transform(Y_pred_val)

opt_cvar = cmu.CVaRQP(M_opt, n_assets, beta, min_return, torch.tensor(Y_original), opt_cvar_gZ, dev)
opt_cvar.end_loss_dist(Y_pred_val_original, Y_val_original)

tensor(8.4273, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [26]:
opt_cvar.end_loss_dist(Y_pred_original, Y_test_original)

tensor(42.9886, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [49]:
opt_cvar.end_loss_dist(Y_pred_original, Y_test_original)

tensor(55.0741, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [18]:
opt_cvar = cmu.CVaRQP(M_opt, n_assets, beta, min_return, torch.tensor(Y_original), opt_cvar_gZ, dev)
opt_cvar.end_loss_dist(Y_pred_original, Y_test_original)

tensor(63.7008, dtype=torch.float64, grad_fn=<MeanBackward0>)

tensor(41.8511, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [27]:
Y_pred_original.shape

torch.Size([1, 1000, 20])

In [50]:
opt_cvar = cmu.CVaRQP(M_opt, n_assets, beta, min_return, torch.tensor(Y_original), opt_cvar_gZ, dev)
zz = opt_cvar.forward(torch.permute(Y_pred_original, (1, 0, 2)))

In [56]:
Y_test_original.shape

torch.Size([600, 10])

In [66]:
(Y_test_original*zz).sum(1).clip(None, 0).mean()

tensor(-67.2309, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [53]:
zz[0]

tensor([1.6260e+01, 4.9736e+01, 2.2901e-14, 8.6343e+01, 2.8042e+01, 6.7583e+01,
        5.7734e+01, 5.1143e+00, 1.2606e+01, 6.7443e-14], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [40]:
Y_pred_original.shape

torch.Size([32, 600, 10])

In [42]:
opt_cvar = cmu.CVaRQP(1, n_assets, beta, min_return, torch.tensor(Y_original), opt_cvar_gZ, dev)
opt_cvar.end_loss_dist(Y_test_original.unsqueeze(0), Y_test_original)

tensor(10.6740, dtype=torch.float64)

In [33]:
Y_original.mean(), Y_original.std(), Y_original.max(), Y_original.min()

(0.13693512093163707,
 0.8841071259180603,
 2.6387125058604113,
 -2.7960391110496476)

In [25]:
Y_pred_original.shape

torch.Size([32, 600, 10])

In [22]:
opt_cvar.end_loss_dist(Y_pred_original, torch.permute(torch.tensor(Y_test_dist), (1, 0, 2)))

RuntimeError: The size of tensor a (3) must match the size of tensor b (202) at non-singleton dimension 1

In [22]:
zstar_pred_new = opt_cvar.forward(torch.permute(Y_pred_original, (1,0,2)))

In [23]:
zstar_pred_new

tensor([[ 7.3330e+01,  8.4389e+01,  1.1276e+02,  2.1225e-14,  1.1709e+02],
        [ 6.7217e+01,  9.0927e+01,  1.1521e+02,  2.0232e-14,  1.1385e+02],
        [ 1.1074e+01,  1.0179e+02, -5.4336e-14, -2.2984e-14,  2.5212e+02],
        [ 1.5634e+01,  1.3213e+02,  8.3355e+01,  4.7789e+01,  1.6134e+02]],
       dtype=torch.float64, grad_fn=<SliceBackward0>)

In [24]:
alphastar_pred_new = opt_cvar.OP_Z.forward(torch.permute(Y_test_original.unsqueeze(0), (1,0,2)), zstar_pred_new)

In [25]:
alphastar_pred_new

tensor([1.6425e+02, 9.3361e+01, 1.0529e-14, 4.5766e-15], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [36]:
zstar_pred = op_old.minimize_cvar_dataset(Y_pred_original.detach().numpy())
alphastar_pred = op_old.minimize_cvar_given_z_dataset(np.expand_dims(Y_test_original, 0), zstar_pred)
cvar_pred = op_old.calc_cvar_dataset(np.expand_dims(Y_test_original, 0), zstar_pred, alphastar_pred)

> [0;32m/tmp/ipykernel_256377/2293865252.py[0m(59)[0;36mminimize_cvar[0;34m()[0m
[0;32m     57 [0;31m[0;34m[0m[0m
[0m[0;32m     58 [0;31m[0;34m[0m[0m
[0m[0;32m---> 59 [0;31m        [0;32mreturn[0m [0mzstar[0m[0;34m,[0m [0mustar[0m[0;34m,[0m [0malphastar[0m[0;34m,[0m [0mf_opt[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m[0;34m[0m[0m
[0m
ipdb> l
[1;32m     54 [0m[0;34m[0m[0m
[1;32m     55 [0m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[1;32m     56 [0m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m     57 [0m[0;34m[0m[0m
[1;32m     58 [0m[0;34m[0m[0m
[0;32m---> 59 [0;31m        [0;32mreturn[0m [0mzstar[0m[0;34m,[0m [0mustar[0m[0;34m,[0m [0malphastar[0m[0;34m,[0m [0mf_opt[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m     60 [0m[0;34m[0m[0m
[1;32m     61 [0m[0;34m[0m[0m
[1;32m     6

BdbQuit: 

In [32]:
zstar_pred

array([[  0.        , 461.23839069,   0.        ,   0.        ,
          0.        ],
       [  0.        ,   0.        , 345.19590618,   0.        ,
          0.        ],
       [ 11.07380712, 101.79467917,   0.        ,   0.        ,
        252.11989681],
       [ 15.63351324, 132.13076345,  83.35474558,  47.78917676,
        161.33640877]])

In [28]:
alphastar_pred

array([346.91013, 280.73404,   0.     ,   0.     ], dtype=float32)

In [30]:
cvar_pred.mean()

156.91104125976562

In [37]:
op_old.uy

array([ 0.18839605,  0.21680762,  0.28969057, -0.08941982,  0.30082463])

In [24]:
(zstar_pred*op_old.uy).sum(axis=1)

array([100., 100., 100., 100.])

In [25]:
(zstar_pred_new*opt_cvar.uy).sum(axis=1)

tensor([101.0000, 101.0000, 100.0000, 100.0000], dtype=torch.float64,
       grad_fn=<SumBackward1>)

In [35]:
((Y_pred_original.mean(0)).detach().numpy()*zstar_pred).sum(1)

array([384.86705589, 214.75938277,  69.39420483,  -6.3605234 ])

In [26]:
zstar_pred

array([[  0.        , 461.23839069,   0.        ,   0.        ,
          0.        ],
       [  0.        ,   0.        , 345.19590618,   0.        ,
          0.        ],
       [ 11.07380712, 101.79467917,   0.        ,   0.        ,
        252.11989681],
       [ 15.63351324, 132.13076345,  83.35474558,  47.78917676,
        161.33640877]])

In [None]:
zstar_pred = opt_cvar.forward(Y_pred_original)
alphastar_pred = op_old.minimize_cvar_given_z_dataset(np.expand_dims(Y_test_original, 0), zstar_pred)
cvar_pred = op_old.calc_cvar_dataset(np.expand_dims(Y_test_original, 0), zstar_pred, alphastar_pred)

In [47]:
(zstar_pred_new[1].detach().numpy() - zstar_pred[1])

array([6.16415948e-15, 3.94899537e+03, 3.62644610e+03, 4.84417479e+03,
       4.83752559e+03, 4.89291048e+03, 8.26435561e+03, 5.53094399e+03,
       5.34771620e+03, 3.18915729e+03, 4.30792665e+03, 5.98957556e+03,
       4.85176412e+03, 2.29808733e-12, 2.52292449e+03, 4.08726661e+03,
       1.79131885e+03, 2.77712095e-12, 3.26970287e-13, 6.33605259e-13])

In [58]:
zstar_pred_new[5]

tensor([1.0777e+03, 4.3437e+03, 2.3827e+03, 4.3540e+03, 5.5204e+03, 4.4254e+03,
        5.3152e+03, 7.9586e+03, 5.0483e+03, 5.0448e+03, 4.2246e+03, 4.5593e+03,
        4.4394e+02, 1.9223e+03, 2.8281e-12, 3.8541e+02, 1.2015e+03, 1.3990e-12,
        8.0155e+02, 4.0557e+03], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [59]:
zstar_pred[5]

array([  0.        , 499.72849577,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ])

In [30]:
cvar_pred

array([199.69296265,   0.        ,   0.        , ...,   0.        ,
         0.        ,   0.        ])

In [None]:
cvar_pred.mean()

In [37]:
zstar_test = op_solver_dist.minimize_cvar_dataset(np.expand_dims(Y_test_original, 0))
alphastar_test = op_solver_dist.minimize_cvar_given_z_dataset(np.expand_dims(Y_test_original, 0), zstar_test)
cvar_test = op_solver_dist.calc_cvar_dataset(np.expand_dims(Y_test_original, 0), zstar_test, alphastar_test)

In [38]:
cvar_test.mean()

2.130949

In [81]:
cvar_pred = alphastar_pred + (1/(1-beta))*(np.maximum(-((Y_dist*zstar_pred).sum(2)) - alphastar_pred, 0)).mean(0)

In [82]:
f_pred_dataset = cvar_pred.mean()

In [83]:
fair_regret = f_pred_dataset - op_solver.cvarstar.mean()

In [84]:
fair_regret

4.802389564712255

In [15]:
op_solver.alphastar + (1/(1-beta))*(np.maximum(-((Y_dist*op_solver.zstar).sum(2)) - op_solver.alphastar, 0)).mean(0)

array([14.53573779, 15.40195475, 16.40051164, 15.38380813, 14.62208475,
       17.06472651, 16.73093918, 15.48154203, 15.9300598 , 16.2041831 ,
       15.73938929, 16.89708193, 15.03783049, 15.81537838, 16.31737609,
       15.29510803, 16.09471929, 14.65658093, 15.2551622 , 15.96507641,
       17.64164752, 17.32653127, 14.25618879, 13.61590662, 16.92963039,
       15.39177444, 15.94804965, 15.49853349, 17.32618356, 15.23503582,
       15.87627014, 15.36836605, 16.97589351, 14.78739717, 16.26122729,
       18.07661054, 15.84807938, 14.6334357 , 16.64567533, 15.43718012,
       15.94837282, 14.39151073, 16.18164058, 14.63956316, 16.72973545,
       17.41537503, 15.88181541, 16.850747  , 15.30816212, 13.10728099,
       13.95678749, 15.29048891, 15.62895434, 16.26142701, 15.8916077 ,
       15.7978594 , 16.34415953, 15.10308489, 15.14827456, 15.24617177,
       16.37071953, 15.95829432, 14.10051356, 14.39860641, 14.34294722,
       15.74917112, 14.66477814, 15.33982245, 16.17635938, 14.59

In [None]:
def compute_cvar(alpha, z, y):
    

In [154]:
1/(Y_dist.shape[0]*(1-beta)) * Y_dist

(1000, 1000, 20)

In [None]:
Y_dist

In [159]:
Y_dist*zstar

array([[[-0.00000000e+00, -0.00000000e+00, -0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-0.00000000e+00, -8.01762101e-02,  0.00000000e+00, ...,
         -0.00000000e+00, -0.00000000e+00,  7.26688064e-01],
        [-2.77114335e-01, -3.74680090e-01, -0.00000000e+00, ...,
         -0.00000000e+00, -0.00000000e+00, -3.38799630e+00],
        ...,
        [-1.75644780e-02, -2.58917253e+00, -0.00000000e+00, ...,
          5.05400752e+00, -3.07404112e-01, -3.07341864e+00],
        [-5.71981859e-01, -0.00000000e+00, -0.00000000e+00, ...,
         -0.00000000e+00,  1.60096702e-01,  0.00000000e+00],
        [-3.02333994e-02, -0.00000000e+00, -0.00000000e+00, ...,
          0.00000000e+00,  2.50841331e+00,  7.44609135e-02]],

       [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  1.62525526e-01,  0.00000000e+00, ...,
          0.00000000e+00,  0.00000000e

In [155]:
Y_dist.shape[0]

1000

In [151]:
Y_pred.shape

(50, 1000, 20)

In [137]:
Y_pred

array([[[-0.17638682, -0.09089706, -0.3830332 , ...,  0.06720436,
          0.42834576,  0.26608252],
        [ 0.32939309,  0.55933907,  0.0138017 , ..., -0.38775832,
          0.38721256,  0.48453956],
        [-0.08600242, -0.06805484,  0.51201416, ..., -0.30176222,
         -0.30011163,  0.48041784],
        ...,
        [ 0.5547552 ,  0.53806358,  0.41646602, ..., -0.1646881 ,
         -0.38714016,  0.0585669 ],
        [-0.259434  , -0.03745498, -0.23963089, ...,  0.41290765,
          0.06538298, -0.25733343],
        [ 0.1752564 ,  0.55979906, -0.1785693 , ...,  0.22204114,
          0.43828789, -0.28569533]],

       [[ 0.25746872, -0.39684173, -0.3284593 , ...,  0.3906845 ,
         -0.39183591,  0.16925576],
        [-0.01981441, -0.05351203,  0.32120704, ...,  0.0888863 ,
          0.13828446,  0.25011441],
        [-0.23420105, -0.14614686,  0.52928029, ...,  0.4170639 ,
          0.55037046,  0.51355313],
        ...,
        [ 0.54040521,  0.23890133, -0.31902953, ..., -

In [8]:
Y_train[:50].mean(0)

array([ 0.01103508,  0.18582536,  0.02757531,  0.176867  ,  0.28287247,
        0.24011691,  0.10373269,  0.22923871,  0.30873738, -0.07671305,
        0.39660958, -0.21470825,  0.23993704, -0.19205835,  0.06976262,
       -0.07803554,  0.40508948, -0.03724071,  0.02091073,  0.40063581])

In [9]:
Y_train[:50].std(0)

array([0.41032437, 0.46137744, 0.52203211, 0.58222118, 0.63199747,
       0.66359498, 0.67581864, 0.67831638, 0.68665954, 0.70163281,
       0.70021433, 0.67717812, 0.68244412, 0.71339072, 0.69029142,
       0.6841247 , 0.67276091, 0.62925573, 0.6156747 , 0.7166073 ])

In [10]:
op_solver.minimize_cvar(Y_train[:50])

([0.0,
  0.0,
  0.0,
  0.0,
  10.93680879473024,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  5.475523075674423,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  4.652755334066428,
  0.0,
  0.0,
  3.8449485359989164],
 [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.27095111058335786,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 3.958868157503432,
 4.067248601736775)

In [34]:
y = Y_train

In [133]:
y.mean(axis=0)

array([ 0.09154903,  0.12018547, -0.00612775,  0.08615439,  0.07785487])

In [125]:
y.std(axis=0)

array([0.28204021, 0.29966942, 0.2348671 , 0.25421325, 0.31059858])

In [126]:
m.status.name

'OPTIMAL'

In [127]:
alpha_opt = -1

In [128]:
(1/n_samples)*(1/(1-beta))*np.sum(np.clip(((-y*np.array(argmins[:5])).sum(axis=1) - alpha_opt), 0, None))

71.47976078082966

In [129]:
(y.mean(axis=0)*np.array(argmins[:5])).sum()

9.999999999999998

In [134]:
argmins

[53.23569452514242,
 18.33912369004899,
 0.0,
 30.08292387425699,
 4.372902220899407,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.4596077782392385,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 17.568190336620138]

In [135]:
f_opt

18.152033447915834

In [3]:
(1/(1 - B))*(1/n)

1.9999999999999982

In [None]:


class LPKnapsackProblem():
    def __init__(self, Y, probs_mean):
        self.Y = Y
        self.probs_mean = probs_mean
        self.n_items = self.Y.shape[1]
        self.weights, self.capacity = self.define_knapsack()
        
    def define_knapsack(self):
        weights = self.Y.mean(axis=0)*self.probs_mean + 80*np.random.random(size=self.n_items)
        capacity = np.random.uniform(0.4, 0.6) * sum(weights)
        return weights, int(capacity)
    
    def solve_knapsack_mip(self, values, activations):
        m = ModelMip("knapsack")
        m.verbose = 0
        z = ([m.add_var(var_type=CONTINUOUS) for i in range(0, self.n_items)])
        m.objective = minimize(
            -xsum(values[i] * z[i] for i in range(0, self.n_items)))
        
        for i in range(0, self.n_items):
            m += z[i] <= 1
            m += -z[i] <= 0
            
        m += xsum(
            self.weights[i] * z[i] for i in range(0, self.n_items)) <= self.capacity
        m += xsum(z[i] for i in range(0, self.n_items) if activations[i]==0) <= 0
        
        m.optimize()
        f_opt = m.objective_value
        argmins = []
        for v in m.vars:
            argmins.append(v.x)
        return argmins, f_opt