In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import glob

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from PIL import ImageFile

sys.path.append( '../' )
from ConvVAE import ConvVAE
from utils.Dataset import Rollout_Dataset, to_latent

In [2]:
# Rebuild the ConvVAE and restore its trained weights. The model is only used
# for inference below, so it is switched to eval mode right after loading.
n_channel = 3  # value passed to ConvVAE as `n_channel`
z_dim = 128    # value passed to ConvVAE as `z_dim`; also used later to size the RNN input
VAE_model = ConvVAE(n_channel=n_channel, z_dim=z_dim)
# map_location="cpu": the model above lives on the CPU, and without this a
# checkpoint that was saved from a CUDA device fails to deserialize on a
# machine that has no GPU.
VAE_model.load_state_dict(torch.load("./weights/segmodel_rollouts.pt", map_location="cpu"))
VAE_model = VAE_model.eval()

In [3]:
validation_prop = 0.15
split_seed = 0  # fixed seed so the train/validation split survives a kernel restart


def split_indices(n_examples, validation_prop, rng):
    """Randomly partition ``range(n_examples)`` into train and validation indices.

    Args:
        n_examples: total number of items to split.
        validation_prop: fraction of the items held out for validation.
        rng: ``numpy.random.Generator`` used to draw the permutation.

    Returns:
        ``(train_idxs, val_idxs)``: two disjoint integer arrays that together
        cover every index in ``range(n_examples)``. ``train_idxs`` is in random
        order and has ``round(n_examples * (1 - validation_prop))`` entries;
        ``val_idxs`` holds the remaining indices, sorted.
    """
    n_train = round(n_examples * (1 - validation_prop))
    # Permute the *full* index range. Sampling from ``n_examples - 1`` (as the
    # previous version did) can never pick the last index for training, so the
    # last file always ended up in the validation set.
    shuffled = rng.permutation(n_examples)
    return shuffled[:n_train], np.sort(shuffled[n_train:])


# Only the number of rollout files is needed here; the path lists are dropped
# again at the end of the cell.
s_paths = glob.glob("../S_Rollouts_11/*.npz")
d_paths = glob.glob("../D_Rollouts_11/*.npz")
print(f"S_Rollouts lenght: {len(s_paths)}")
print(f"D_Rollouts lenght: {len(d_paths)}")
n_static_examples = len(s_paths)
n_dinamic_examples = len(d_paths)

_split_rng = np.random.default_rng(split_seed)
# *_idxs -> training file indices, *_cidxs -> complementary (validation) indices
s_idxs, s_cidxs = split_indices(n_static_examples, validation_prop, _split_rng)
d_idxs, d_cidxs = split_indices(n_dinamic_examples, validation_prop, _split_rng)

print(f"s_idxs lenght: {s_idxs.shape[0]}")
print(f"s_cidxs lenght: {s_cidxs.shape[0]}")

print(f"d_idxs lenght: {d_idxs.shape[0]}")
print(f"d_cidxs lenght: {d_cidxs.shape[0]}")

del s_paths
del d_paths

S_Rollouts lenght: 97
D_Rollouts lenght: 100
s_idxs lenght: 82
s_cidxs lenght: 15
d_idxs lenght: 85
d_cidxs lenght: 15


In [20]:
# Rollout directories and windowing parameters handed to Rollout_Dataset.
s_path = "../S_Rollouts_11"
d_path = "../D_Rollouts_11"
seq_len = 3
buffer_size = 3  # number of files

# Index legend for the complementary ("complt") state vector:
#   0: steer
#   1: throttle
#   2: speed
#   3: x's orientation
#   4: y's orientation
#   5: z's orientation
_complt_states = [0, 1, 2]

# Frames are resized to 80x160 and converted to tensors before use.
transform = transforms.Compose([
    transforms.Resize((80, 160)),
    transforms.ToTensor(),
])

# Train / validation datasets over the static rollouts; they differ only in
# which file indices they are given.
s_train_dataset, s_val_dataset = (
    Rollout_Dataset(s_path, file_idxs, seq_len, buffer_size, _complt_states, transform=transform)
    for file_idxs in (s_idxs, s_cidxs)
)

batch_size = 32
# Both loaders iterate in dataset order (shuffle=False).
s_train_dataloader, s_val_dataloader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for dataset in (s_train_dataset, s_val_dataset)
)

In [30]:
from LSTM import MDN_RNN

# Hyper-parameters forwarded to the MDN_RNN constructor.
hidden_size, action_size, num_layers, gaussians = 512, 2, 1, 3
# Each RNN input step is sized as z_dim plus one entry per selected
# complementary state.
input_size = len(_complt_states) + z_dim

lstm_model = MDN_RNN(
    input_size=input_size,
    hidden_size=hidden_size,
    action_size=action_size,
    num_layers=num_layers,
    gaussians=gaussians,
)

In [34]:
epochs = 10
validation_epoch = 4  # validation runs every `validation_epoch` epochs
lr = 1e-3
opt = torch.optim.Adam(lstm_model.parameters(), lr=lr)


def _unpack_batch(batch):
    """Split a dataloader batch into its core tensors and optional complementary states.

    A batch holds ``(states, actions, rewards, next_states, terminals)`` and,
    when it has more than five entries, the complementary states and next
    complementary states as its last two entries.

    Returns:
        A 7-tuple ``(states, actions, rewards, next_states, terminals,
        complt_states, next_complt_states)``; the last two are ``None`` when
        the batch carries no complementary states.
    """
    if len(batch) > 5:
        states, actions, rewards, next_states, terminals = batch[:-2]
        complt_states, next_complt_states = batch[-2:]
    else:
        states, actions, rewards, next_states, terminals = batch
        complt_states, next_complt_states = None, None
    return states, actions, rewards, next_states, terminals, complt_states, next_complt_states


def _batch_loss(batch):
    """Run ``lstm_model`` on one batch and return its scalar (mean) loss.

    Shared by the training and validation loops so that both compute the loss
    the same way, from the batch they are actually given.
    """
    (states, actions, rewards, next_states, terminals,
     complt_states, next_complt_states) = _unpack_batch(batch)
    states, next_states = to_latent(VAE_model, states, next_states, complt_states, next_complt_states)
    # states -> (B, S, Z_dim+complt_states)   (original author's note)
    mus, sigmas, logpi, rs, ds = lstm_model(states, actions)
    gmm_loss = lstm_model.gmm_loss(next_states, mus, sigmas, logpi)
    loss = lstm_model.get_loss(gmm_loss, rewards, rs, terminals, ds)  # (B, S) per the original note
    return torch.mean(loss)  # reduce to a scalar with a global mean


for epoch in range(epochs):
    lstm_model.train()
    train_losses = []
    for X_tr in s_train_dataloader:
        loss = _batch_loss(X_tr)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_losses.append(loss.item())
    train_loss = sum(train_losses) / len(train_losses)
    print(f"epoch: {epoch+1} || train_loss: {train_loss}")

    if (epoch + 1) % validation_epoch == 0:
        # eval() switches the module in place. The previous version bound the
        # result to an unused name, unpacked the last *training* batch instead
        # of the validation batch, reused the stale training gmm_loss and
        # called an undefined get_loss().
        lstm_model.eval()
        val_losses = []
        with torch.no_grad():  # no gradients are needed while validating
            for X_val in s_val_dataloader:
                val_losses.append(_batch_loss(X_val).item())
        val_loss = sum(val_losses) / len(val_losses)
        print(f"epoch: {epoch+1} || val_loss: {val_loss}")

torch.Size([])
torch.Size([])
torch.Size([])
torch.Size([])


KeyboardInterrupt: 