In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dgr_bert_model import *
from dgr_vae_model import *
from dgr_utils import *

In [3]:
# Runtime configuration: seed PyTorch and pick the compute device.
# torch.manual_seed seeds the default generators on all devices, so weight
# initialisation, RandomSampler shuffling and torch.randn latent draws are
# repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Use the first GPU when available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print("Available device: {}".format(device))
if use_cuda:
    print(torch.cuda.get_device_name(0))

Available device: cuda:0
NVIDIA GeForce GTX 1070


In [ ]:
"""
Builds and trains a VAE model on the extracted features from the FeatureExtractor

VAE model can be used to reconstruct or generate new features
"""

In [4]:
# Load the previously trained models from the feature-extraction stage.
# Checkpoints saved from a GPU cannot be deserialised on a CPU-only machine
# without remapping, so remap to CPU only when CUDA is unavailable. On a GPU
# machine (map_location=None) loading behaves exactly as torch.load's default.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# NOTE(review): these three models are not referenced by the cells shown here.
load_map_location = None if use_cuda else 'cpu'

bert_lstm = torch.load('data/model_bert_854.pth', map_location=load_map_location)

feature_extractor = torch.load('data/feature_extractor.pth', map_location=load_map_location)
feature_classifier = torch.load('data/feature_classifier.pth', map_location=load_map_location)

In [5]:
# Load the pre-extracted FeatureExtractor outputs (samples) and their labels
# for each split. Remap to CPU only when CUDA is unavailable, so tensors saved
# from a GPU still load on a CPU-only machine. GPU behaviour is unchanged.
load_map_location = None if use_cuda else 'cpu'

train_features_sample = torch.load('data/train_features_sample.pt', map_location=load_map_location)
train_features_label = torch.load('data/train_features_label.pt', map_location=load_map_location)

val_features_sample = torch.load('data/val_features_sample.pt', map_location=load_map_location)
val_features_label = torch.load('data/val_features_label.pt', map_location=load_map_location)

In [6]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

batch_size = 32


def make_features_dataloader(samples, labels, sampler_cls, batch_size):
    """Wrap paired (samples, labels) tensors in a batched DataLoader.

    Args:
        samples: feature tensor; its first dimension indexes examples.
        labels: label tensor with the same first-dimension length as `samples`
            (TensorDataset raises on a size mismatch).
        sampler_cls: sampler class constructed on the dataset, e.g.
            RandomSampler (shuffled) or SequentialSampler (fixed order).
        batch_size: number of examples per batch.

    Returns:
        (dataset, sampler, dataloader), so each piece stays inspectable.
    """
    dataset = TensorDataset(samples, labels)
    sampler = sampler_cls(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    return dataset, sampler, dataloader


# Training batches are reshuffled every epoch.
train_features_dataset, train_features_sampler, train_features_dataloader = make_features_dataloader(
    train_features_sample, train_features_label, RandomSampler, batch_size)

# Validation is iterated in a fixed order. Shuffling adds nothing to an
# evaluation pass, and a stable order (including a stable final partial batch)
# removes one source of epoch-to-epoch noise in val_loss.
val_features_dataset, val_features_sampler, val_features_dataloader = make_features_dataloader(
    val_features_sample, val_features_label, SequentialSampler, batch_size)

In [7]:
# VAE hyper-parameters. input_dim must equal the width of the extracted feature
# vectors (the encoder's first Linear layer takes input_dim inputs). latent_dim
# is the size of the latent vector that the decoder consumes.
input_dim = 1024
latent_dim = 128
learning_rate = 1e-4

# Fail fast with a clear message if the cached features don't fit the encoder,
# instead of a shape error deep inside the training loop.
assert train_features_sample.shape[-1] == input_dim, (
    "expected feature vectors of width {}, got {}".format(input_dim, train_features_sample.shape[-1])
)

vae = VAE(input_dim, latent_dim).to(device)
print(vae)

optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): Tanh()
  )
)


In [8]:
import copy

# Train the VAE, evaluating on the validation split after every epoch.
# Whenever val_loss improves, the whole model is saved to disk and a copy of
# its weights is kept in memory.
num_epoch = 100
model_checkpoint_path = 'data/vae.pth'

best_loss = float('inf')
best_state = None  # state_dict of the best epoch so far; None until the first improvement

for epoch in range(1, num_epoch + 1):
    vae.train()
    epoch_loss = train_epoch_vae(vae, train_features_dataloader, optimizer, device)

    vae.eval()
    val_loss = eval_vae(vae, val_features_dataloader, device)

    print("Epoch: {} train_loss: {:.5f} val_loss: {:.5f}".format(epoch, epoch_loss, val_loss))

    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(vae, model_checkpoint_path)
        # deepcopy: state_dict() returns references to the live parameters,
        # which later optimizer steps would overwrite.
        best_state = copy.deepcopy(vae.state_dict())
        print("saving model checkpoint to {}".format(model_checkpoint_path))

# Restore the best-validation weights, so later cells (e.g. generation) use the
# checkpointed model rather than whatever the final epoch produced.
if best_state is not None:
    vae.load_state_dict(best_state)
vae.eval()

Epoch: 1 train_oss: 451.47124 val_loss: 455.90908
saving model checkpoint to data/vae.pth
Epoch: 2 train_oss: 302.78167 val_loss: 288.45377
saving model checkpoint to data/vae.pth
Epoch: 3 train_oss: 258.52272 val_loss: 278.08301
saving model checkpoint to data/vae.pth
Epoch: 4 train_oss: 256.02034 val_loss: 278.65415
Epoch: 5 train_oss: 246.41565 val_loss: 235.90972
saving model checkpoint to data/vae.pth
Epoch: 6 train_oss: 199.55964 val_loss: 212.65284
saving model checkpoint to data/vae.pth
Epoch: 7 train_oss: 191.90834 val_loss: 207.18986
saving model checkpoint to data/vae.pth
Epoch: 8 train_oss: 189.27628 val_loss: 202.58018
saving model checkpoint to data/vae.pth
Epoch: 9 train_oss: 187.86692 val_loss: 201.20388
saving model checkpoint to data/vae.pth
Epoch: 10 train_oss: 190.23083 val_loss: 196.86234
saving model checkpoint to data/vae.pth
Epoch: 11 train_oss: 186.32210 val_loss: 198.95958
Epoch: 12 train_oss: 184.18966 val_loss: 199.47740
Epoch: 13 train_oss: 183.22485 val_lo

In [10]:
# Generate one synthetic feature vector: draw a standard-normal latent vector
# of size latent_dim, move it to the model's device, and run it through the
# decoder. No autograd graph is needed, so everything runs under no_grad.
with torch.no_grad():
    latent_sample = torch.randn(1, latent_dim).to(device)
    generated_item = vae.decoder(latent_sample)

print(generated_item)

tensor([[ 0.0250,  0.0772, -0.0244,  ...,  0.2842, -0.1136,  0.0600]],
       device='cuda:0')
