In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
import math

import numpy as np
import pandas as pd
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()
from matplotlib import pyplot as plt

from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler
from keras.models import Sequential
from keras.layers import Dense, LSTM, TimeDistributed, BatchNormalization
from keras.callbacks import EarlyStopping

import datetime as dt

from vrae.vrae import VRAE
from torch.utils.data import TensorDataset
import torch

Using TensorFlow backend.


In [2]:
# Project root. Defaults to the original author's checkout; set the
# FAINTS_PREDICTION_DIR environment variable to run the notebook elsewhere.
# The value is string-concatenated with relative paths below, so it is
# normalised to end with a separator.
path = os.environ.get(
    "FAINTS_PREDICTION_DIR",
    "C:/Users/OPTIMUSPRIME/Desktop/Studia/Magisterka/Faints-Prediction/",
)
if not path.endswith(("/", "\\")):
    path += "/"
BP_filename = "BP.csv"
HR_filename = "HR.csv"


def _read_indices(split_filename):
    """Return the first column of DATA/<split_filename> as a list of strings."""
    # NOTE(review): read_csv is called with its default header handling, so the
    # first line of the file is consumed as a header, not as an id -- confirm
    # that the split files really start with a header line.
    rows = pd.read_csv(path + "DATA/" + split_filename).values.tolist()
    return [str(row[0]) for row in rows]


# Column names of the recordings belonging to each data split.
train_indices = _read_indices("training_set.txt")
test_indices = _read_indices("test_set.txt")
validation_indices = _read_indices("validation_set.txt")
all_indices = train_indices + test_indices + validation_indices

In [3]:
def shift(xs, n):
    """Shift a 1-D sequence by ``n`` positions, padding the gap with NaN.

    Parameters
    ----------
    xs : 1-D array-like (list, ndarray, pandas Series, ...)
    n : int
        Positive values move the data towards the end (NaN padding at the
        front); negative values move it towards the start (NaN padding at
        the end); ``0`` leaves the data unchanged.

    Returns
    -------
    numpy.ndarray
        A new array with the same length as ``xs``.
    """
    values = np.asarray(xs)
    length = len(values)
    # Clamp the shift so that the result always has len(xs) elements. This also
    # covers n == 0, where the former ``xs[:-n]`` slice evaluated to ``xs[:0]``
    # and silently returned an empty array.
    k = min(abs(n), length)
    padding = np.full(k, np.nan)
    if n >= 0:
        return np.concatenate((padding, values[:length - k]))
    return np.concatenate((values[k:], padding))

    
def series_to_supervised(share_prices, timestamps, input_time_steps, dropnan=True):
    """Cut two parallel series into matching, non-overlapping windows.

    Despite the parameter names, the function is generic: it is called from
    ``split`` with a blood-pressure series as ``share_prices`` and a heart-rate
    series as ``timestamps``.

    Parameters
    ----------
    share_prices, timestamps : 1-D array-like
        The two series to window. Rows are paired by position.
    input_time_steps : int
        Window length; must be >= 1.
    dropnan : bool, default True
        Discard every window that contains a NaN in *either* series.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Two arrays of shape ``(n_windows, input_time_steps)``; row ``i`` of
        both arrays covers the same positions of the inputs.

    Raises
    ------
    ValueError
        If ``input_time_steps`` is smaller than 1.
    """
    if input_time_steps < 1:
        raise ValueError("input_time_steps must be a positive integer")

    share_prices_df = pd.DataFrame(share_prices)
    timestamps_df = pd.DataFrame(timestamps)

    # Column j of the aggregated frame is the series lagged by
    # (input_time_steps - 1 - j), so row r holds the window ending at r.
    lags = range(input_time_steps - 1, -1, -1)
    aggregated_share_prices = pd.concat(
        [share_prices_df.shift(lag) for lag in lags], axis=1).values
    aggregated_timestamps = pd.concat(
        [timestamps_df.shift(lag) for lag in lags], axis=1).values

    # Windows are paired by position, so only the common prefix is usable.
    n_rows = min(len(aggregated_share_prices), len(aggregated_timestamps))
    aggregated_share_prices = aggregated_share_prices[:n_rows]
    aggregated_timestamps = aggregated_timestamps[:n_rows]

    if dropnan:
        # One shared mask keeps both outputs aligned. Dropping NaN rows from
        # each frame independently desynchronised the windows whenever the
        # two series had different NaN patterns.
        complete = ~(pd.isna(aggregated_share_prices).any(axis=1)
                     | pd.isna(aggregated_timestamps).any(axis=1))
        aggregated_share_prices = aggregated_share_prices[complete]
        aggregated_timestamps = aggregated_timestamps[complete]

    # Keep every input_time_steps-th window so that consecutive windows do
    # not share samples.
    return (aggregated_share_prices[::input_time_steps],
            aggregated_timestamps[::input_time_steps])
 

def split(BP_data, HR_data, col, time_steps):
    """Build the windowed samples and one-hot labels for a single recording.

    Reads the notebook-level ``labels`` mapping (column name -> label string),
    so that mapping has to be defined before this function is called.

    Parameters
    ----------
    BP_data, HR_data : mapping of column name -> 1-D series
        Blood-pressure and heart-rate recordings; ``col`` is looked up in both.
    col : hashable
        Name of the recording to process.
    time_steps : int
        Window length passed on to ``series_to_supervised``.

    Returns
    -------
    X : numpy.ndarray, shape (n_windows, time_steps, 2)
        ``X[i, t]`` is the pair ``[BP, HR]`` at step ``t`` of window ``i``.
    y : numpy.ndarray, shape (n_windows, 2)
        ``[0., 1.]`` for every window when ``labels[col] == 'Synkope'``,
        otherwise ``[1., 0.]``.
    """
    BP_supervised, HR_supervised = series_to_supervised(BP_data[col], HR_data[col], time_steps)

    # Pair windows by position, as the former zip() did.
    n_windows = min(len(BP_supervised), len(HR_supervised))
    # Stacking on a new last axis yields (n_windows, time_steps, 2) even when
    # n_windows == 0. The previous list-based construction collapsed that
    # case to shape (0,), which broke concatenation in split_df.
    X = np.stack((BP_supervised[:n_windows], HR_supervised[:n_windows]), axis=-1)

    label = [0., 1.] if labels[col] == 'Synkope' else [1., 0.]
    y = np.tile(label, (n_windows, 1))
    return X, y


def split_df(BP_data, HR_data, time_steps):
    """Window every recording in ``BP_data`` and stack the results.

    Parameters
    ----------
    BP_data, HR_data : column-iterable mappings (e.g. pandas DataFrames)
        Iterating ``BP_data`` yields the recording names; each name is passed
        to ``split`` together with both containers.
    time_steps : int
        Window length.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Samples and labels of all recordings concatenated along axis 0, in
        column order. Both arrays are empty when no recording yields a window.
    """
    X_parts, y_parts = [], []
    empty_result = (np.array([]), np.array([]))
    for col in BP_data:
        X_single, y_single = split(BP_data, HR_data, col, time_steps)
        if X_single.size == 0:
            # A recording too short for a single window contributes nothing.
            # Skipping it avoids a dimension mismatch in np.concatenate.
            empty_result = (X_single, y_single)
            continue
        X_parts.append(X_single)
        y_parts.append(y_single)

    if not X_parts:
        return empty_result
    # Concatenate once instead of growing the arrays inside the loop, which
    # copied all previously collected data on every iteration.
    return np.concatenate(X_parts), np.concatenate(y_parts)

In [4]:
# Load the raw recordings, keeping only the columns listed in the split files.
BP_data = pd.read_csv(path + "DATA/" + BP_filename, low_memory=False)[all_indices]
HR_data = pd.read_csv(path + "DATA/" + HR_filename, low_memory=False)[all_indices]
# The first data row of every column holds that recording's label
# (compared against 'Synkope' in split()); the measurements start at row 1.
labels = {col: BP_data[col].iloc[0] for col in BP_data}

# Numeric part of each table, converted once and reused below.
BP_values = BP_data.iloc[1:].astype(np.float32)
HR_values = HR_data.iloc[1:].astype(np.float32)

# Global extrema over all recordings (pandas skips NaN here).
BP_max_value = BP_values.max().max()
BP_min_value = BP_values.min().min()
HR_max_value = HR_values.max().max()
HR_min_value = HR_values.min().min()

# One scaler per signal, fitted on the global [min, max] so that every
# recording is mapped with the same affine transform.
BP_scaler = MinMaxScaler().fit(np.array([BP_min_value, BP_max_value]).reshape(-1, 1))
HR_scaler = MinMaxScaler().fit(np.array([HR_min_value, HR_max_value]).reshape(-1, 1))


def _scale_frame(scaler, frame):
    """Apply a single-feature scaler to every cell of ``frame``."""
    # The scalers were fitted on ONE feature. Passing the multi-column frame
    # directly is rejected by scikit-learn versions that validate the number
    # of features, so the frame is flattened to a single column and restored.
    flat = frame.to_numpy().reshape(-1, 1)
    scaled = scaler.transform(flat).reshape(frame.shape)
    return pd.DataFrame(scaled, index=frame.index, columns=frame.columns)


BP_data_scaled = _scale_frame(BP_scaler, BP_values)
HR_data_scaled = _scale_frame(HR_scaler, HR_values)

In [5]:
# Fraction of each recording's valid (non-NaN) BP samples to discard from the
# start of the recording.
part_to_drop = 0.4
BP_data_scaled_trimmed = BP_data_scaled.copy()
HR_data_scaled_trimmed = HR_data_scaled.copy()
for col in BP_data_scaled_trimmed:
    n_valid = np.count_nonzero(~np.isnan(BP_data_scaled_trimmed[col]))
    n_rows_to_drop = int(part_to_drop * n_valid)
    if n_rows_to_drop == 0:
        # Nothing to trim. Also avoids calling shift(..., -0), whose
        # ``xs[:-0]`` slice is empty and would break the column assignment.
        continue

    # A negative shift moves the data towards the start and pads the end with
    # NaN, i.e. it removes the first n_rows_to_drop rows. The HR column is
    # trimmed by the same amount so both signals stay aligned.
    BP_data_scaled_trimmed[col] = shift(BP_data_scaled_trimmed[col], -n_rows_to_drop)
    HR_data_scaled_trimmed[col] = shift(HR_data_scaled_trimmed[col], -n_rows_to_drop)

In [19]:
# ---------------------------------------------------------------------------
# Parameter selection
# ---------------------------------------------------------------------------
# Author's original notes (presumably window length -> hidden size -> latent
# length; verify against the values actually used below):
#   32 -> 8 -> 2
#   64 -> 16 -> 2
#   128 -> 32 -> 2

# --- Grid that the training loop sweeps over --------------------------------
# time_steps and hidden_sizes are zipped pairwise by the training cell.
time_steps = [128, 64, 32]      # previously tried order: [32, 64, 128]
hidden_sizes = [32, 16, 8]
epochs = [10, 40, 60]           # previously tried: [1, 2, 3, 4, 5]
batch_sizes = [32, 64, 128]
# Earlier experiments, kept for reference only:
#   r_layers = [0, 2, 3, 4] / [2]
#   LSTM_cells = [1, 16, 64, 256] / [64]
#   neurons = [5, 20, 50] / [5]
#   dropout_rates = [0.2, 0.4, 0.0]
#   clips = [True, False]

# --- Fixed VRAE settings ----------------------------------------------------
# Network shape
hidden_size = 60            # 30? 60?
hidden_layer_depth = 1      # 2?
latent_length = 2           # 10? 20?
block = 'LSTM'              # options: LSTM, GRU
dropout_rate = 0.2

# Optimisation
optimizer = 'Adam'          # options: ADAM, SGD
learning_rate = 0.0005
loss = 'MSELoss'            # options: SmoothL1Loss, MSELoss
clip = True                 # options: True, False
max_grad_norm = 5
batch_size = 32
n_epochs = 40               # 20? 80?

# Runtime
cuda = False                # options: True, False
print_every = 30

In [20]:
def choose_color(label):
    """Map a one-hot label to a matplotlib colour code.

    Only the first component is inspected: ``1.0`` gives ``'r'``, ``0.0``
    gives ``'b'`` and any other value gives ``'m'``.
    """
    first = label[0]
    if first == 1.0:
        return 'r'
    return 'b' if first == 0.0 else 'm'
    
def plot_latent_space(X, Y, typ, directory):
    """Save three scatter plots of the 2-D latent space of the current model.

    Uses the notebook-level ``vrae`` object to encode ``X`` and writes
    ``{typ}_HR_MEANS_LS.png``, ``{typ}_BP_MEANS_LS.png`` and
    ``{typ}_LABELS_LS.png`` into ``directory``. The points are coloured by
    the per-window mean of feature 1 (HR), the per-window mean of feature 0
    (BP) and the class label respectively.

    Parameters
    ----------
    X : numpy.ndarray
        Windows; each element is indexed as ``window[:, 0]`` / ``window[:, 1]``.
    Y : sequence
        One label per window, passed to ``choose_color``.
    typ : str
        Prefix of the output file names (e.g. ``'TRAIN'``).
    directory : str
        Existing directory that receives the images.
    """
    embedding = vrae.transform(TensorDataset(torch.from_numpy(X)))
    # The colour sequences are cut to the number of encoded points.
    # NOTE(review): presumably transform() can return fewer rows than X
    # (e.g. an incomplete last batch) -- confirm against the VRAE code.
    n_points = embedding.shape[0]

    bp_means = scale([np.mean(window[:, 0]) for window in X], (0, 100))
    hr_means = scale([np.mean(window[:, 1]) for window in X], (0, 100))
    point_colors = [choose_color(label) for label in Y]

    # (output file, colour values, extra scatter keyword arguments)
    panels = (
        (f"{directory}/{typ}_HR_MEANS_LS.png", hr_means, {'cmap': 'gist_rainbow'}),
        (f"{directory}/{typ}_BP_MEANS_LS.png", bp_means, {'cmap': 'gist_rainbow'}),
        (f"{directory}/{typ}_LABELS_LS.png", point_colors, {}),
    )
    for target, colors, extra in panels:
        plt.figure(figsize=(12,8))
        plt.scatter(embedding[:, 0], embedding[:, 1], c=colors[:n_points], s=5, **extra)
        plt.savefig(target)

    # Release all figures; this function is called many times in a loop.
    plt.close('all')
    
def scale(x, out_range=(-1, 1), axis=None):
    """Linearly rescale ``x`` so that its min/max map onto ``out_range``.

    Parameters
    ----------
    x : array-like
        Values to rescale.
    out_range : (low, high), default (-1, 1)
        Target interval.
    axis : int or None, default None
        Axis along which min and max are taken; ``None`` uses the global
        extrema.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``. Where the input is constant
        (max == min) the result is the midpoint of ``out_range`` instead of
        NaN from a 0/0 division.
    """
    values = np.asarray(x)
    # keepdims lets the extrema broadcast against ``values`` for any axis;
    # without it, axis=1 on an (n, m) array produced shape (n,) and failed.
    low = np.min(values, axis=axis, keepdims=True)
    high = np.max(values, axis=axis, keepdims=True)
    span = high - low
    # Avoid dividing by zero for constant input: the numerator is 0 there, so
    # any non-zero divisor yields 0, i.e. the centre of the output range.
    safe_span = np.where(span == 0, 1, span)
    y = (values - (high + low) / 2) / safe_span
    return y * (out_range[1] - out_range[0]) + (out_range[1] + out_range[0]) / 2
            

In [21]:
# Grid search: for every (window length, hidden size) pair, train one VRAE per
# (epochs, batch size) combination and save latent-space plots for it.
plt.ioff()  # figures are only written to disk by plot_latent_space
for ts, h in zip(time_steps, hidden_sizes):
    # The windowed data depends only on ts, so it is built once per pair.
    X_train, y_train = split_df(BP_data_scaled_trimmed[train_indices], HR_data_scaled_trimmed[train_indices], ts)
    X_test, y_test = split_df(BP_data_scaled_trimmed[test_indices], HR_data_scaled_trimmed[test_indices], ts)
    train_dataset = TensorDataset(torch.from_numpy(X_train))
    test_dataset = TensorDataset(torch.from_numpy(X_test))
    sequence_length = X_train.shape[1]
    number_of_features = X_train.shape[2]
    for e in epochs:
        for bs in batch_sizes:
            # plot_latent_space() reads this model through the global name
            # ``vrae``, so the variable name must not change.
            vrae = VRAE(sequence_length=sequence_length,
                        number_of_features=number_of_features,
                        hidden_size=h,
                        hidden_layer_depth=hidden_layer_depth,
                        latent_length=latent_length,
                        batch_size=bs,
                        learning_rate=learning_rate,
                        n_epochs=e,
                        dropout_rate=dropout_rate,
                        optimizer=optimizer,
                        cuda=cuda,
                        print_every=print_every,
                        clip=clip,
                        max_grad_norm=max_grad_norm,
                        loss=loss,
                        block=block,
                        dload='a')

            epoch_loss, recon_loss, kl_loss = vrae.fit(train_dataset)
            epoch_loss = round(epoch_loss, 2)
            recon_loss = round(recon_loss, 2)
            kl_loss = round(kl_loss, 2)
            directory = "VRAE_results/" + f"LOSS{epoch_loss}_RECONLOSS{recon_loss}_KLLOSS{kl_loss}_LATENT{latent_length}_TS{ts}_BS{bs}_E{e}"
            # makedirs also creates the VRAE_results/ parent on a fresh
            # checkout and does not fail if a re-run produces the same name;
            # os.mkdir raised in both situations.
            os.makedirs(directory, exist_ok=True)

            plot_latent_space(X_train, y_train, 'TRAIN', directory)
            plot_latent_space(X_test, y_test, 'TEST', directory)
            del vrae

Epoch: 0
Batch 30, loss = 214.4090, recon_loss = 214.3918, kl_loss = 0.0172
Batch 60, loss = 107.5715, recon_loss = 107.4995, kl_loss = 0.0720
Average loss: 229.6017
Epoch: 1
Batch 30, loss = 100.3390, recon_loss = 100.1908, kl_loss = 0.1482
Batch 60, loss = 113.6138, recon_loss = 113.5305, kl_loss = 0.0833
Average loss: 109.9032
Epoch: 2
Batch 30, loss = 97.2541, recon_loss = 97.2277, kl_loss = 0.0264
Batch 60, loss = 79.1855, recon_loss = 79.1721, kl_loss = 0.0135
Average loss: 110.1601
Epoch: 3
Batch 30, loss = 112.8470, recon_loss = 112.8438, kl_loss = 0.0033
Batch 60, loss = 112.7895, recon_loss = 112.7850, kl_loss = 0.0044
Average loss: 110.6442
Epoch: 4
Batch 30, loss = 119.1137, recon_loss = 118.8837, kl_loss = 0.2300
Batch 60, loss = 89.0413, recon_loss = 88.5265, kl_loss = 0.5148
Average loss: 109.5991
Epoch: 5
Batch 30, loss = 102.1503, recon_loss = 101.5637, kl_loss = 0.5866
Batch 60, loss = 136.1575, recon_loss = 135.5035, kl_loss = 0.6540
Average loss: 108.5565
Epoch: 6
B



Epoch: 0
Batch 30, loss = 677.9473, recon_loss = 677.9083, kl_loss = 0.0389
Average loss: 1376.3770
Epoch: 1
Batch 30, loss = 239.9485, recon_loss = 239.1195, kl_loss = 0.8290
Average loss: 235.5407
Epoch: 2
Batch 30, loss = 176.5161, recon_loss = 175.7767, kl_loss = 0.7394
Average loss: 223.5679
Epoch: 3
Batch 30, loss = 240.7787, recon_loss = 239.8329, kl_loss = 0.9459
Average loss: 223.0174
Epoch: 4
Batch 30, loss = 259.6044, recon_loss = 258.3551, kl_loss = 1.2494
Average loss: 220.2708
Epoch: 5
Batch 30, loss = 269.8448, recon_loss = 268.5281, kl_loss = 1.3168
Average loss: 219.1929
Epoch: 6
Batch 30, loss = 204.7441, recon_loss = 203.7989, kl_loss = 0.9452
Average loss: 219.7244
Epoch: 7
Batch 30, loss = 195.8260, recon_loss = 194.9362, kl_loss = 0.8898
Average loss: 219.2205
Epoch: 8
Batch 30, loss = 196.1486, recon_loss = 195.0938, kl_loss = 1.0548
Average loss: 218.8713
Epoch: 9
Batch 30, loss = 228.8480, recon_loss = 227.7856, kl_loss = 1.0624
Average loss: 218.4111
Epoch: 0


Batch 60, loss = 42.5479, recon_loss = 41.1813, kl_loss = 1.3666
Average loss: 55.1053
Epoch: 20
Batch 30, loss = 59.2828, recon_loss = 57.8072, kl_loss = 1.4756
Batch 60, loss = 51.5489, recon_loss = 50.2347, kl_loss = 1.3142
Average loss: 53.5997
Epoch: 21
Batch 30, loss = 60.0715, recon_loss = 58.5848, kl_loss = 1.4867
Batch 60, loss = 40.9058, recon_loss = 39.3114, kl_loss = 1.5944
Average loss: 54.2964
Epoch: 22
Batch 30, loss = 58.2937, recon_loss = 56.8172, kl_loss = 1.4765
Batch 60, loss = 38.8024, recon_loss = 37.2865, kl_loss = 1.5159
Average loss: 53.2058
Epoch: 23
Batch 30, loss = 69.8085, recon_loss = 68.3973, kl_loss = 1.4112
Batch 60, loss = 44.9069, recon_loss = 43.3694, kl_loss = 1.5375
Average loss: 53.7090
Epoch: 24
Batch 30, loss = 58.3184, recon_loss = 56.7875, kl_loss = 1.5309
Batch 60, loss = 73.4997, recon_loss = 72.1889, kl_loss = 1.3109
Average loss: 54.5615
Epoch: 25
Batch 30, loss = 50.2949, recon_loss = 48.9187, kl_loss = 1.3762
Batch 60, loss = 34.7016, re

Batch 30, loss = 51.6509, recon_loss = 50.9141, kl_loss = 0.7369
Batch 60, loss = 43.5761, recon_loss = 42.8855, kl_loss = 0.6906
Batch 90, loss = 65.6313, recon_loss = 64.4928, kl_loss = 1.1384
Batch 120, loss = 50.8194, recon_loss = 50.0692, kl_loss = 0.7502
Batch 150, loss = 46.0662, recon_loss = 45.1322, kl_loss = 0.9340
Average loss: 50.8171
Epoch: 6
Batch 30, loss = 31.6368, recon_loss = 30.4728, kl_loss = 1.1640
Batch 60, loss = 36.2871, recon_loss = 35.1128, kl_loss = 1.1743
Batch 90, loss = 33.6430, recon_loss = 32.2957, kl_loss = 1.3473
Batch 120, loss = 37.8136, recon_loss = 36.4016, kl_loss = 1.4121
Batch 150, loss = 40.2557, recon_loss = 38.7884, kl_loss = 1.4673
Average loss: 30.1896
Epoch: 7
Batch 30, loss = 20.4259, recon_loss = 19.2710, kl_loss = 1.1549
Batch 60, loss = 22.9977, recon_loss = 21.8019, kl_loss = 1.1958
Batch 90, loss = 25.2520, recon_loss = 24.0311, kl_loss = 1.2209
Batch 120, loss = 22.7230, recon_loss = 21.2217, kl_loss = 1.5013
Batch 150, loss = 32.81

Batch 60, loss = 40.0576, recon_loss = 38.8130, kl_loss = 1.2446
Batch 90, loss = 24.0298, recon_loss = 22.8154, kl_loss = 1.2144
Batch 120, loss = 30.0162, recon_loss = 28.8015, kl_loss = 1.2147
Batch 150, loss = 31.0843, recon_loss = 29.8716, kl_loss = 1.2127
Average loss: 24.3259
Epoch: 34
Batch 30, loss = 18.9129, recon_loss = 17.6926, kl_loss = 1.2204
Batch 60, loss = 19.0581, recon_loss = 17.8191, kl_loss = 1.2391
Batch 90, loss = 30.0811, recon_loss = 28.8729, kl_loss = 1.2082
Batch 120, loss = 29.1538, recon_loss = 27.9130, kl_loss = 1.2408
Batch 150, loss = 32.0768, recon_loss = 30.8255, kl_loss = 1.2513
Average loss: 24.2466
Epoch: 35
Batch 30, loss = 21.5993, recon_loss = 20.3795, kl_loss = 1.2197
Batch 60, loss = 25.4898, recon_loss = 24.2399, kl_loss = 1.2499
Batch 90, loss = 23.5817, recon_loss = 22.3815, kl_loss = 1.2002
Batch 120, loss = 24.2128, recon_loss = 23.0025, kl_loss = 1.2103
Batch 150, loss = 18.8757, recon_loss = 17.6620, kl_loss = 1.2136
Average loss: 24.245

Batch 120, loss = 19.2325, recon_loss = 17.2992, kl_loss = 1.9333
Batch 150, loss = 19.8630, recon_loss = 18.0162, kl_loss = 1.8468
Average loss: 26.0652
Epoch: 10
Batch 30, loss = 20.9361, recon_loss = 19.1967, kl_loss = 1.7394
Batch 60, loss = 25.4505, recon_loss = 23.6690, kl_loss = 1.7814
Batch 90, loss = 37.3539, recon_loss = 35.7615, kl_loss = 1.5924
Batch 120, loss = 29.2022, recon_loss = 27.3285, kl_loss = 1.8738
Batch 150, loss = 23.9569, recon_loss = 22.4291, kl_loss = 1.5278
Average loss: 26.0836
Epoch: 11
Batch 30, loss = 27.1838, recon_loss = 25.5548, kl_loss = 1.6290
Batch 60, loss = 27.2986, recon_loss = 25.7233, kl_loss = 1.5752
Batch 90, loss = 24.9430, recon_loss = 23.3548, kl_loss = 1.5882
Batch 120, loss = 25.3486, recon_loss = 23.7756, kl_loss = 1.5730
Batch 150, loss = 22.0057, recon_loss = 20.5682, kl_loss = 1.4375
Average loss: 25.8637
Epoch: 12
Batch 30, loss = 23.5115, recon_loss = 21.7681, kl_loss = 1.7434
Batch 60, loss = 42.0704, recon_loss = 40.3525, kl_lo

Batch 60, loss = 22.2120, recon_loss = 20.9481, kl_loss = 1.2639
Batch 90, loss = 16.3709, recon_loss = 15.1304, kl_loss = 1.2405
Batch 120, loss = 34.5969, recon_loss = 33.3844, kl_loss = 1.2125
Batch 150, loss = 26.0315, recon_loss = 24.8341, kl_loss = 1.1974
Average loss: 24.1338
Epoch: 56
Batch 30, loss = 14.5444, recon_loss = 13.3432, kl_loss = 1.2012
Batch 60, loss = 33.0712, recon_loss = 31.8627, kl_loss = 1.2085
Batch 90, loss = 29.1042, recon_loss = 27.8483, kl_loss = 1.2559
Batch 120, loss = 18.9909, recon_loss = 17.8186, kl_loss = 1.1723
Batch 150, loss = 29.3140, recon_loss = 28.0455, kl_loss = 1.2685
Average loss: 24.0991
Epoch: 57
Batch 30, loss = 44.4012, recon_loss = 43.1284, kl_loss = 1.2727
Batch 60, loss = 21.4750, recon_loss = 20.2303, kl_loss = 1.2447
Batch 90, loss = 26.2104, recon_loss = 24.9560, kl_loss = 1.2545
Batch 120, loss = 28.9094, recon_loss = 27.6497, kl_loss = 1.2597
Batch 150, loss = 21.8112, recon_loss = 20.5776, kl_loss = 1.2336
Average loss: 24.370

Average loss: 98.5332
Epoch: 50
Batch 30, loss = 97.4126, recon_loss = 94.6065, kl_loss = 2.8061
Average loss: 98.5438
Epoch: 51
Batch 30, loss = 104.1590, recon_loss = 101.3620, kl_loss = 2.7970
Average loss: 96.0898
Epoch: 52
Batch 30, loss = 84.6778, recon_loss = 81.9724, kl_loss = 2.7054
Average loss: 96.7503
Epoch: 53
Batch 30, loss = 85.2447, recon_loss = 82.5342, kl_loss = 2.7105
Average loss: 97.8946
Epoch: 54
Batch 30, loss = 96.2460, recon_loss = 93.3837, kl_loss = 2.8623
Average loss: 97.1165
Epoch: 55
Batch 30, loss = 99.8493, recon_loss = 97.0653, kl_loss = 2.7839
Average loss: 96.1617
Epoch: 56
Batch 30, loss = 99.5336, recon_loss = 96.7732, kl_loss = 2.7605
Average loss: 97.2823
Epoch: 57
Batch 30, loss = 90.4625, recon_loss = 87.7745, kl_loss = 2.6880
Average loss: 97.3176
Epoch: 58
Batch 30, loss = 101.5892, recon_loss = 98.8531, kl_loss = 2.7361
Average loss: 96.5382
Epoch: 59
Batch 30, loss = 89.7662, recon_loss = 86.9727, kl_loss = 2.7935
Average loss: 97.0251
Epoch

Batch 240, loss = 14.4851, recon_loss = 13.5816, kl_loss = 0.9034
Batch 270, loss = 11.0248, recon_loss = 10.1354, kl_loss = 0.8894
Batch 300, loss = 12.3978, recon_loss = 11.3884, kl_loss = 1.0093
Average loss: 13.0244
Epoch: 5
Batch 30, loss = 10.7601, recon_loss = 9.7591, kl_loss = 1.0009
Batch 60, loss = 12.8833, recon_loss = 11.8776, kl_loss = 1.0057
Batch 90, loss = 14.6014, recon_loss = 13.6641, kl_loss = 0.9374
Batch 120, loss = 13.8061, recon_loss = 12.6762, kl_loss = 1.1298
Batch 150, loss = 9.8830, recon_loss = 8.8490, kl_loss = 1.0339
Batch 180, loss = 9.7379, recon_loss = 8.6763, kl_loss = 1.0615
Batch 210, loss = 22.1623, recon_loss = 21.2155, kl_loss = 0.9468
Batch 240, loss = 15.8692, recon_loss = 14.8898, kl_loss = 0.9794
Batch 270, loss = 6.8650, recon_loss = 5.9270, kl_loss = 0.9380
Batch 300, loss = 15.4513, recon_loss = 14.4415, kl_loss = 1.0098
Average loss: 12.7685
Epoch: 6
Batch 30, loss = 12.9907, recon_loss = 11.9508, kl_loss = 1.0400
Batch 60, loss = 13.4136,

Batch 180, loss = 13.5279, recon_loss = 12.4871, kl_loss = 1.0408
Batch 210, loss = 8.1544, recon_loss = 7.1695, kl_loss = 0.9849
Batch 240, loss = 9.7118, recon_loss = 8.5520, kl_loss = 1.1598
Batch 270, loss = 13.8547, recon_loss = 12.8880, kl_loss = 0.9667
Batch 300, loss = 10.6374, recon_loss = 9.6624, kl_loss = 0.9750
Average loss: 12.1761
Epoch: 29
Batch 30, loss = 9.3854, recon_loss = 8.4022, kl_loss = 0.9832
Batch 60, loss = 13.1241, recon_loss = 12.0150, kl_loss = 1.1090
Batch 90, loss = 9.4586, recon_loss = 8.4335, kl_loss = 1.0250
Batch 120, loss = 9.6563, recon_loss = 8.5692, kl_loss = 1.0870
Batch 150, loss = 11.5263, recon_loss = 10.5957, kl_loss = 0.9306
Batch 180, loss = 9.4886, recon_loss = 8.4739, kl_loss = 1.0147
Batch 210, loss = 14.3569, recon_loss = 13.1424, kl_loss = 1.2145
Batch 240, loss = 12.5893, recon_loss = 11.5185, kl_loss = 1.0708
Batch 270, loss = 16.4504, recon_loss = 15.3411, kl_loss = 1.1093
Batch 300, loss = 11.7453, recon_loss = 10.7643, kl_loss = 0

Batch 150, loss = 18.1924, recon_loss = 16.8042, kl_loss = 1.3881
Average loss: 20.8894
Epoch: 24
Batch 30, loss = 22.1749, recon_loss = 20.8500, kl_loss = 1.3249
Batch 60, loss = 19.1090, recon_loss = 17.7837, kl_loss = 1.3253
Batch 90, loss = 18.0506, recon_loss = 16.7893, kl_loss = 1.2613
Batch 120, loss = 18.0262, recon_loss = 16.6680, kl_loss = 1.3582
Batch 150, loss = 26.5773, recon_loss = 25.2344, kl_loss = 1.3429
Average loss: 20.5634
Epoch: 25
Batch 30, loss = 17.5614, recon_loss = 16.0560, kl_loss = 1.5053
Batch 60, loss = 15.9794, recon_loss = 14.6584, kl_loss = 1.3210
Batch 90, loss = 20.2852, recon_loss = 18.8016, kl_loss = 1.4836
Batch 120, loss = 16.1356, recon_loss = 14.6253, kl_loss = 1.5103
Batch 150, loss = 20.0834, recon_loss = 18.6004, kl_loss = 1.4829
Average loss: 20.2335
Epoch: 26
Batch 30, loss = 28.8858, recon_loss = 27.3947, kl_loss = 1.4911
Batch 60, loss = 21.1335, recon_loss = 19.6162, kl_loss = 1.5173
Batch 90, loss = 19.4052, recon_loss = 17.8202, kl_los

Batch 270, loss = 10.2101, recon_loss = 9.3135, kl_loss = 0.8966
Batch 300, loss = 13.5193, recon_loss = 12.6259, kl_loss = 0.8934
Average loss: 13.0152
Epoch: 6
Batch 30, loss = 12.3186, recon_loss = 11.3399, kl_loss = 0.9787
Batch 60, loss = 13.1801, recon_loss = 12.2238, kl_loss = 0.9563
Batch 90, loss = 10.1424, recon_loss = 9.1916, kl_loss = 0.9508
Batch 120, loss = 32.4843, recon_loss = 31.5614, kl_loss = 0.9229
Batch 150, loss = 8.5541, recon_loss = 7.6532, kl_loss = 0.9009
Batch 180, loss = 12.4528, recon_loss = 11.4707, kl_loss = 0.9821
Batch 210, loss = 12.1342, recon_loss = 11.1214, kl_loss = 1.0128
Batch 240, loss = 9.2030, recon_loss = 8.3125, kl_loss = 0.8905
Batch 270, loss = 6.6032, recon_loss = 5.7164, kl_loss = 0.8868
Batch 300, loss = 12.8708, recon_loss = 11.8408, kl_loss = 1.0301
Average loss: 12.7864
Epoch: 7
Batch 30, loss = 16.1767, recon_loss = 15.1014, kl_loss = 1.0753
Batch 60, loss = 15.3569, recon_loss = 14.4977, kl_loss = 0.8592
Batch 90, loss = 10.5154, r

Batch 270, loss = 8.3878, recon_loss = 7.4077, kl_loss = 0.9800
Batch 300, loss = 10.4813, recon_loss = 9.3669, kl_loss = 1.1145
Average loss: 11.5938
Epoch: 30
Batch 30, loss = 13.5147, recon_loss = 12.4033, kl_loss = 1.1115
Batch 60, loss = 10.8971, recon_loss = 9.7377, kl_loss = 1.1594
Batch 90, loss = 15.8684, recon_loss = 14.8123, kl_loss = 1.0561
Batch 120, loss = 11.1426, recon_loss = 10.0464, kl_loss = 1.0962
Batch 150, loss = 18.2664, recon_loss = 17.2502, kl_loss = 1.0161
Batch 180, loss = 9.9159, recon_loss = 8.8351, kl_loss = 1.0809
Batch 210, loss = 8.8602, recon_loss = 7.7740, kl_loss = 1.0862
Batch 240, loss = 10.6947, recon_loss = 9.5713, kl_loss = 1.1233
Batch 270, loss = 8.9802, recon_loss = 7.7888, kl_loss = 1.1914
Batch 300, loss = 10.9737, recon_loss = 9.8695, kl_loss = 1.1042
Average loss: 11.4704
Epoch: 31
Batch 30, loss = 8.5714, recon_loss = 7.3983, kl_loss = 1.1731
Batch 60, loss = 11.9498, recon_loss = 10.7353, kl_loss = 1.2145
Batch 90, loss = 10.8705, recon

Batch 270, loss = 15.1083, recon_loss = 13.7894, kl_loss = 1.3189
Batch 300, loss = 12.3909, recon_loss = 11.0897, kl_loss = 1.3012
Average loss: 11.4919
Epoch: 54
Batch 30, loss = 13.3396, recon_loss = 12.0795, kl_loss = 1.2601
Batch 60, loss = 12.0661, recon_loss = 10.8458, kl_loss = 1.2203
Batch 90, loss = 13.5102, recon_loss = 12.2361, kl_loss = 1.2742
Batch 120, loss = 7.5837, recon_loss = 6.2341, kl_loss = 1.3496
Batch 150, loss = 14.9669, recon_loss = 13.7394, kl_loss = 1.2276
Batch 180, loss = 9.1582, recon_loss = 7.9190, kl_loss = 1.2392
Batch 210, loss = 12.3773, recon_loss = 11.0557, kl_loss = 1.3216
Batch 240, loss = 9.4407, recon_loss = 8.1976, kl_loss = 1.2432
Batch 270, loss = 14.3776, recon_loss = 13.2345, kl_loss = 1.1431
Batch 300, loss = 11.5637, recon_loss = 10.2906, kl_loss = 1.2731
Average loss: 11.5657
Epoch: 55
Batch 30, loss = 14.5708, recon_loss = 13.3370, kl_loss = 1.2338
Batch 60, loss = 13.3207, recon_loss = 12.0402, kl_loss = 1.2805
Batch 90, loss = 12.413

Batch 150, loss = 24.8618, recon_loss = 23.6753, kl_loss = 1.1865
Average loss: 23.7216
Epoch: 34
Batch 30, loss = 21.3936, recon_loss = 20.1770, kl_loss = 1.2166
Batch 60, loss = 14.8320, recon_loss = 13.5982, kl_loss = 1.2338
Batch 90, loss = 27.3507, recon_loss = 26.0954, kl_loss = 1.2553
Batch 120, loss = 19.8577, recon_loss = 18.6360, kl_loss = 1.2217
Batch 150, loss = 26.3741, recon_loss = 25.1163, kl_loss = 1.2579
Average loss: 23.6879
Epoch: 35
Batch 30, loss = 29.0153, recon_loss = 27.7831, kl_loss = 1.2322
Batch 60, loss = 16.0220, recon_loss = 14.7584, kl_loss = 1.2636
Batch 90, loss = 28.1749, recon_loss = 26.9610, kl_loss = 1.2139
Batch 120, loss = 20.0535, recon_loss = 18.7831, kl_loss = 1.2703
Batch 150, loss = 18.8426, recon_loss = 17.6350, kl_loss = 1.2076
Average loss: 23.7264
Epoch: 36
Batch 30, loss = 23.8114, recon_loss = 22.6139, kl_loss = 1.1975
Batch 60, loss = 19.7572, recon_loss = 18.5293, kl_loss = 1.2280
Batch 90, loss = 22.3951, recon_loss = 21.1723, kl_los

Batch 60, loss = 50.0178, recon_loss = 48.3629, kl_loss = 1.6549
Average loss: 43.3139
Epoch: 43
Batch 30, loss = 35.3502, recon_loss = 33.6984, kl_loss = 1.6518
Batch 60, loss = 31.0722, recon_loss = 29.4636, kl_loss = 1.6085
Average loss: 43.0217
Epoch: 44
Batch 30, loss = 48.5475, recon_loss = 47.0412, kl_loss = 1.5063
Batch 60, loss = 40.4431, recon_loss = 38.6737, kl_loss = 1.7694
Average loss: 42.6970
Epoch: 45
Batch 30, loss = 38.3453, recon_loss = 36.8000, kl_loss = 1.5453
Batch 60, loss = 34.4911, recon_loss = 33.0312, kl_loss = 1.4600
Average loss: 42.2513
Epoch: 46
Batch 30, loss = 36.4299, recon_loss = 34.9236, kl_loss = 1.5063
Batch 60, loss = 49.8803, recon_loss = 48.2958, kl_loss = 1.5845
Average loss: 41.8760
Epoch: 47
Batch 30, loss = 41.5272, recon_loss = 39.8577, kl_loss = 1.6695
Batch 60, loss = 46.3983, recon_loss = 44.7998, kl_loss = 1.5985
Average loss: 41.4885
Epoch: 48
Batch 30, loss = 37.3165, recon_loss = 35.7087, kl_loss = 1.6077
Batch 60, loss = 37.7243, re

In [None]:
# NOTE(review): called without any data, so nothing is drawn; this cell has
# never been executed (In [None]) and looks like a leftover scratch cell.
plt.plot()