#### Importing packages

In [None]:
%run Packages/Libraries.ipynb
%run Packages/Functions.ipynb
%run Packages/Networks.ipynb
%run Packages/Optimize.ipynb

#### Creating Datasets

In [None]:
# Folder holding the raw data, relative to the notebook's working directory.
data_root_dir = 'Dataset'

# load_datasets comes from one of the %run-loaded helper notebooks; its result
# is unpacked into a (train, test) pair.
train_dataset, test_dataset = (
    load_datasets(data_root_dir)
)

#### Creating the Variational Autoencoder

In [None]:
# Size of the latent code and the lambda_ hyper-parameter forwarded to the VAE
# constructor (VAE comes from one of the %run-loaded notebooks — see its
# definition for how lambda_ is used).
encoded_space_dim = 5
lambda_ = 0.75

# Fall back to the CPU when no CUDA device is available instead of crashing
# on `.to(device)` below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net_vae = VAE(encoded_space_dim=encoded_space_dim, lambda_=lambda_)
# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
net_vae.to(device)

loss_fn = torch.nn.MSELoss()
optim_vae = torch.optim.Adam(net_vae.parameters(), lr=0.001, weight_decay=1e-5)

In [None]:
# Training batches are shuffled every epoch; the evaluation set is iterated in a
# fixed order, so validation runs are reproducible from one epoch to the next
# (shuffling adds nothing when only evaluating).
train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=512, shuffle=False)

#### Training

In [None]:
import os

num_epochs = 50
load_weights = False  # resume from an existing checkpoint before the first epoch
load_best = True      # before every later epoch, roll back to the best checkpoint so far

ckpt_path = 'ckpt/net_vae_params.pth'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

val_loss_log_vae = []
best_val_loss = float('inf')
for epoch in range(num_epochs):
    # Epoch 0: optionally resume from a previous run (load_weights).
    # Later epochs: optionally restore the best weights saved so far (load_best).
    # Only the model weights are restored; the optimizer state keeps evolving.
    if (load_weights and epoch == 0) or (load_best and epoch > 0):
        print('Loaded!')
        # map_location lets a checkpoint written on GPU be loaded on CPU (and vice versa).
        net_vae.load_state_dict(torch.load(ckpt_path, map_location=device))
    print('EPOCH %d/%d' % (epoch + 1, num_epochs))
    train_epoch(net_vae, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optim_vae, show_steps=20,
                use_noise=True)
    val_loss = test_epoch(net_vae, dataloader=test_dataloader, loss_fn=loss_fn)
    val_loss_value = val_loss.item()
    val_loss_log_vae.append(val_loss_value)
    print('\n\n\t VALIDATION - EPOCH %d/%d - loss: %f\n\n' % (epoch + 1, num_epochs, val_loss_value))

    # Always checkpoint after the first epoch; afterwards only on a strict
    # improvement over every previous validation loss of this run.
    if epoch == 0 or val_loss_value < best_val_loss:
        print('Saved!')
        torch.save(net_vae.state_dict(), ckpt_path)
    best_val_loss = min(best_val_loss, val_loss_value)