In [1]:
import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from data import *
from model import AIS_LSTM

In [2]:
# Restore the checkpoint that `Visualize.ipynb` identified as the best model.
model = AIS_LSTM.load_from_checkpoint(
    checkpoint_path='./logs/AIS_LSTM_model/version_15/checkpoints/epoch=203-step=21623.ckpt',
    map_location=None,
)
# eval() is deliberately the cell's final expression, so the notebook also
# displays the module structure it returns.
model.eval()

AIS_LSTM(
  (gen): Sequential(
    (0): Linear(in_features=63, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): LSTM(128, 64, batch_first=True)
    (5): ExtractRNNOutput()
  )
  (pred): Sequential(
    (0): Linear(in_features=89, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=33, bias=True)
  )
)

In [3]:
# Read the three saved episode splits in train / val / test order.
# NOTE(review): the *_data_file paths are not defined in this notebook;
# presumably they are supplied by `from data import *` -- confirm.
train_data, val_data, test_data = (
    torch.load(path)
    for path in (train_data_file, val_data_file, test_data_file)
)

In [4]:
# Before encoding, make sure every field of a split has as many entries
# as that split's 'index' field (one entry per episode).
for split in (train_data, val_data, test_data):
    expected_size = len(split['index'])
    assert all(len(field) == expected_size for field in split.values())

In [5]:
# Fields handed to the encoder, in the order the encoding loop unpacks a batch.
EPISODE_FIELDS = ('demographics', 'observations', 'actionvecs', 'lengths')
# Batch size used while encoding; no training happens in this notebook.
ENCODING_BATCH_SIZE = 128


def make_dataset(split):
    """Wrap the encoder input fields of one split in a ``TensorDataset``.

    Args:
        split: Mapping that holds a tensor under every key in
            ``EPISODE_FIELDS``; all tensors must share their first dimension.

    Returns:
        A ``TensorDataset`` yielding
        ``(demographics, observations, actionvecs, lengths)`` for each episode.

    Raises:
        KeyError: If ``split`` lacks one of the ``EPISODE_FIELDS`` keys.
    """
    return TensorDataset(*(split[field] for field in EPISODE_FIELDS))


def make_loader(dataset):
    """Build a fixed-order ``DataLoader`` over ``dataset``.

    ``shuffle=False`` keeps batches in stored order, which the later cells
    rely on when the concatenated encoder outputs are written back into the
    split dictionaries row-for-row.
    """
    return DataLoader(dataset, batch_size=ENCODING_BATCH_SIZE, shuffle=False)


train_dataset = make_dataset(train_data)
val_dataset = make_dataset(val_data)
test_dataset = make_dataset(test_data)

train_loader = make_loader(train_dataset)
val_loader = make_loader(val_dataset)
test_loader = make_loader(test_dataset)

In [6]:
# Run every split through the loaded model and collect the encoded states,
# one tensor per split, in train / val / test order.
all_encoded_states = []
with torch.no_grad():  # pure inference: no autograd graph is needed
    for loader in (train_loader, val_loader, test_loader):
        batch_outputs = []
        for dem, obs, act, _ in loader:  # the lengths field is not needed here
            # Model input: observation, demographic and action features joined
            # along the last (feature) axis.
            features = torch.cat((obs, dem, act), dim=-1)
            states = model(features)
            # Zero the encoded state wherever the whole observation vector is
            # zero. NOTE(review): presumably these are the padded steps past
            # the episode length -- the final cell of this notebook checks
            # exactly that region.
            is_blank_step = (obs == 0).all(dim=2)
            states[is_blank_step] = 0
            batch_outputs.append(states)
        # Stitch the batches back together along the episode axis.
        all_encoded_states.append(torch.cat(batch_outputs, dim=0))

In [7]:
# Unpack in the order the splits were encoded: train, val, test.
[train_states, val_states, test_states] = all_encoded_states

In [8]:
# Store each split's encoded states alongside its original fields.
for split, encoded in zip(
    (train_data, val_data, test_data),
    (train_states, val_states, test_states),
):
    split['statevecs'] = encoded

In [9]:
# Re-run the size check now that 'statevecs' has been added: every field of
# a split must still have one entry per episode.
for split in [train_data, val_data, test_data]:
    expected_size = len(split['index'])
    mismatched_keys = [
        key for key, value in split.items() if len(value) != expected_size
    ]
    assert not mismatched_keys

In [10]:
# Persist each split, now including its 'statevecs' field.
output_dir = '../data/episodes+encoded_state'
# torch.save does not create missing parent directories; create the target
# up front so a fresh checkout does not fail at the very last step.
os.makedirs(output_dir, exist_ok=True)
for split_name, split in (
    ('train', train_data),
    ('val', val_data),
    ('test', test_data),
):
    torch.save(split, f'{output_dir}/{split_name}_data.pt')

In [11]:
# Spot-check that, past each episode's recorded length, both the raw
# observations and the encoded state vectors are entirely zero.
# Only the leading training episodes are inspected; the count is capped at the
# split size so the check also works when fewer than 100 episodes exist.
num_checked = min(100, len(train_data['lengths']))
for i in range(num_checked):
    lng = train_data['lengths'][i]
    assert (train_data['observations'][i][lng:] == 0).all(), \
        f'non-zero observation past the recorded length in training episode {i}'
    assert (train_data['statevecs'][i][lng:] == 0).all(), \
        f'non-zero state vector past the recorded length in training episode {i}'