In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.vae import VariationalAutoencoder
from models.lstm import LSTM, mdn_loss_function

from torch.utils.data import DataLoader
from lib.dataset import LSTMDataset

from lib.consts import *

In [2]:
# Directory holding the trained model checkpoints. File names are appended
# to this string directly (PATH + 'vae...'), so the trailing slash is required.
PATH = "models/trained/"

In [3]:

def train_epoch(lstm, optimizer, example):
    """Run a single optimisation step of ``lstm`` on one batch.

    Despite the name, this processes exactly one batch: it zeroes the
    gradients, resets the model's hidden state, does a forward pass,
    evaluates ``mdn_loss_function`` and applies one optimizer step.

    Args:
        lstm: Model exposing ``init_hidden(seq_len)`` and a settable
            ``hidden`` attribute; calling it returns ``(pi, sigma, mu)``.
        optimizer: Optimizer over ``lstm``'s parameters.
        example: Dict with tensors under ``'encoded'`` and ``'actions'``.
            They are concatenated along ``dim=2``, so they are assumed to be
            ``(batch, seq, features)`` -- TODO confirm against LSTMDataset.

    Returns:
        The batch loss as a Python float.
    """
    optimizer.zero_grad()
    lstm.hidden = lstm.init_hidden(SEQUENCE)

    encoded = example['encoded']
    x = torch.cat((encoded, example['actions']), dim=2).to(DEVICE)

    # Step at position -OFFSET, kept as a length-1 sequence. ``unsqueeze``
    # replaces the former ``view(2, -1, 32)`` so that the batch size and the
    # latent size are no longer hard-coded (identical result for 2 x 32).
    last_ex = encoded[:, -OFFSET].unsqueeze(1)

    # NOTE(review): roll(-1) followed by [:, 1:] yields steps 2..T-1 and then
    # the wrapped-around step 0, i.e. NOT a plain one-step shift. Behaviour is
    # preserved as-is; confirm that this is the intended prediction target.
    shifted = torch.roll(encoded, shifts=-1, dims=1)[:, 1:]

    # The target must live on the same device as the model outputs, otherwise
    # the loss computation fails whenever DEVICE is not the CPU.
    target = torch.cat((shifted, last_ex), dim=1).detach().to(DEVICE)

    pi, sigma, mu = lstm(x)

    loss = mdn_loss_function(pi, sigma, mu, target)
    loss.backward()
    optimizer.step()

    return float(loss)

In [4]:
def train_lstm(max_epochs=None, vae_version=8):
    """Train the LSTM on latent codes produced by a pre-trained VAE.

    Builds the LSTM, restores the VAE weights from
    ``PATH + 'vae<vae_version>.pt'``, and repeatedly iterates over the
    rollout dataset, encoding each batch of observations with the VAE and
    passing it to ``train_epoch`` for one optimisation step.

    Args:
        max_epochs: Number of passes over the dataloader. ``None`` (the
            default) keeps the original behaviour of looping until the
            kernel is interrupted.
        vae_version: Numeric suffix of the VAE checkpoint file to load.

    Returns:
        The trained LSTM once ``max_epochs`` passes are done (never returns
        when ``max_epochs`` is ``None``).
    """
    lstm = LSTM(SEQUENCE, HIDDEN_UNITS, LATENT_VEC,
                NUM_LAYERS, GAUSSIANS, HIDDEN_DIM).to(DEVICE)

    # The VAE is only used as a fixed encoder and is never moved off the CPU,
    # so map the checkpoint to CPU: a checkpoint saved from a GPU then also
    # loads on a CPU-only machine.
    vae = VariationalAutoencoder()
    vae_checkpoint = torch.load(PATH + 'vae' + str(vae_version) + '.pt',
                                map_location='cpu')
    vae.load_state_dict(vae_checkpoint['model_state_dict'])
    # NOTE(review): the VAE stays in training mode; if it contains
    # dropout/batch-norm or samples differently at inference, consider
    # calling vae.eval() here -- confirm against models/vae.py.

    optimizer = torch.optim.Adam(lstm.parameters(), lr=LR, weight_decay=L2_REG)

    dataset = LSTMDataset(data_file='rollouts.data', root_dir='data/')
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE_LSTM)

    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        running_loss = []

        for observations, actions in dataloader:
            # The VAE is not being optimised: skip autograd so backward()
            # in train_epoch does not traverse (or accumulate grads on) it.
            with torch.no_grad():
                # Encode every sequence in the batch instead of exactly two.
                encoded = torch.stack(
                    [vae(obs, encode=True) for obs in observations], dim=0)
            example = {'encoded': encoded,
                       'actions': actions}

            loss = train_epoch(lstm, optimizer, example)
            print('loss: ', loss)
            running_loss.append(loss)
            print('running loss: ', running_loss)

        epoch += 1

    return lstm

In [None]:
# Start training. Called without arguments, train_lstm loops indefinitely, so
# this cell keeps printing losses until the kernel is interrupted.
train_lstm()

loss:  1.6942962408065796
running loss:  [1.6942962408065796]
loss:  1.6362873315811157
running loss:  [1.6362873315811157]
loss:  1.6175556182861328
running loss:  [1.6175556182861328]
loss:  1.5978407859802246
running loss:  [1.5978407859802246]
loss:  1.6522492170333862
running loss:  [1.6522492170333862]
