# Deep Markov Model

In [147]:
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter
import numpy as np

In [148]:
batch_size = 128#128
epochs = 5
seed = 1
torch.manual_seed(seed)

<torch._C.Generator at 0x2b65be68e490>

In [149]:
# if torch.cuda.is_available():
#     device = "cuda"
# else:
#     device = "cpu"
device="cpu"

In [150]:
# def init_dataset(f_batch_size):
#     kwargs = {'num_workers': 1, 'pin_memory': True}
#     data_dir = '../data'
#     mnist_transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Lambda(lambda data: data[0])
#     ])
#     train_loader = torch.utils.data.DataLoader(
#         datasets.MNIST(data_dir, train=True, download=True,
#                        transform=mnist_transform),
#         batch_size=f_batch_size, shuffle=True, **kwargs)
#     test_loader = torch.utils.data.DataLoader(
#         datasets.MNIST(data_dir, train=False, transform=mnist_transform),
#         batch_size=f_batch_size, shuffle=True, **kwargs)

#     fixed_t_size = 28
#     return train_loader, test_loader, fixed_t_size

# train_loader, test_loader, t_max = init_dataset(batch_size)

In [151]:
landmark_num = 10

In [152]:

transform = transforms.Compose([transforms.ToTensor()])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}
#data loader #とりあえず1時系列分を分身させて食わせる
#[time,s_x,s_y,s_yaw,uv,ur,ot[1],,,,ot[N]]
data = np.loadtxt('vehicle_motion_data.csv', delimiter=',')
data = torch.tensor([data],dtype=torch.float32)
st = data[0,:,1:4]
ut = data[0,:,4:6]
ot = data[0,:,6:26]

print(st.size())
st=st.repeat(1000,1,1)
ut=ut.repeat(1000,1,1)
ot=ot.repeat(1000,1,1)
print(st.size())

train = torch.utils.data.TensorDataset(ot)
train_loader = torch.utils.data.DataLoader(train, shuffle=False,**kwargs)
test = torch.utils.data.TensorDataset(ot)
test_loader = torch.utils.data.DataLoader(test, shuffle=False,**kwargs)

torch.Size([139, 3])
torch.Size([1000, 139, 3])


In [153]:
from pixyz.models import Model
from pixyz.losses import KullbackLeibler, CrossEntropy, IterativeLoss
from pixyz.distributions import Bernoulli, Normal, Deterministic
from pixyz.utils import print_latex

In [154]:
x_dim = landmark_num*2
h_dim = 32 #32
hidden_dim = 32 #32
z_dim = 3
t_max = 139

In [155]:
class RNN(Deterministic):
    def __init__(self):
        super(RNN, self).__init__(cond_var=["x"], var=["h"])
        self.rnn = nn.GRU(x_dim, h_dim, bidirectional=True)
#         self.h0 = torch.zeros(2, batch_size, self.rnn.hidden_size).to(device)
        self.h0 = nn.Parameter(torch.zeros(2, 1, self.rnn.hidden_size))
        self.hidden_size = self.rnn.hidden_size
        
    def forward(self, x):
        h0 = self.h0.expand(2, x.size(1), self.rnn.hidden_size).contiguous()
        h, _ = self.rnn(x, h0)
        return {"h": h}

In [156]:
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"])
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, x_dim)
    
    def forward(self, z):
        h = F.relu(self.fc1(z))
        return {"probs": torch.sigmoid(self.fc2(h))}

In [157]:
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["h", "z_prev"], var=["z"])
        self.fc1 = nn.Linear(z_dim, h_dim*2)
        self.fc21 = nn.Linear(h_dim*2, z_dim)
        self.fc22 = nn.Linear(h_dim*2, z_dim)
        
    def forward(self, h, z_prev):
        h_z = torch.tanh(self.fc1(z_prev))
        h = 0.5 * (h + h_z)
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [158]:
class Prior(Normal):
    def __init__(self):
        super(Prior, self).__init__(cond_var=["z_prev"], var=["z"])
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        
    def forward(self, z_prev):
        h = F.relu(self.fc1(z_prev))
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [159]:
prior = Prior().to(device)
encoder = Inference().to(device)
decoder = Generator().to(device)
rnn = RNN().to(device)

In [160]:
print(prior)
print("*"*80)
print(encoder)
print("*"*80)
print(decoder)
print("*"*80)
print(rnn)

Distribution:
  p(z|z_{prev})
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev'], input_var=['z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=3, out_features=32, bias=True)
    (fc21): Linear(in_features=32, out_features=3, bias=True)
    (fc22): Linear(in_features=32, out_features=3, bias=True)
  )
********************************************************************************
Distribution:
  p(z|h,z_{prev})
Network architecture:
  Inference(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['h', 'z_prev'], input_var=['h', 'z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=3, out_features=64, bias=True)
    (fc21): Linear(in_features=64, out_features=3, bias=True)
    (fc22): Linear(in_features=64, out_features=3, bias=True)
  )
********************************************************************************
Distribution:
  p(x|z)
Network architecture:
  Generator(
    name=p, 

In [161]:
generate_from_prior = prior * decoder
print(generate_from_prior)
print_latex(generate_from_prior)

Distribution:
  p(x,z|z_{prev}) = p(x|z)p(z|z_{prev})
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev'], input_var=['z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=3, out_features=32, bias=True)
    (fc21): Linear(in_features=32, out_features=3, bias=True)
    (fc22): Linear(in_features=32, out_features=3, bias=True)
  )
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=3, out_features=32, bias=True)
    (fc2): Linear(in_features=32, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [162]:
step_loss = CrossEntropy(encoder, decoder) + KullbackLeibler(encoder, prior)
_loss = IterativeLoss(step_loss, max_iter=t_max, 
                      series_var=["x", "h"], update_value={"z": "z_prev"})
loss = _loss.expectation(rnn).mean()

In [163]:
dmm = Model(loss, distributions=[rnn, encoder, decoder, prior], 
            optimizer=optim.RMSprop, optimizer_params={"lr": 5e-4}, clip_grad_value=10)

In [164]:
print(dmm)
print_latex(dmm)

Distributions (for training): 
  p(h|x), p(z|h,z_{prev}), p(x|z), p(z|z_{prev}) 
Loss function: 
  mean \left(\mathbb{E}_{p(h|x)} \left[\sum_{t=1}^{139} \left(D_{KL} \left[p(z|h,z_{prev})||p(z|z_{prev}) \right] - \mathbb{E}_{p(z|h,z_{prev})} \left[\log p(x|z) \right]\right) \right] \right) 
Optimizer: 
  RMSprop (
  Parameter Group 0
      alpha: 0.99
      centered: False
      eps: 1e-08
      lr: 0.0005
      momentum: 0
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [165]:
def data_loop(epoch, loader, model, device, train_mode=False):
    mean_loss = 0
    for idx,[data] in enumerate(tqdm(loader)):#batch_idx, (data, _) in enumerate(tqdm(loader)):
        print(data.size())# [batchsize=128,28,28] #画像サイズは28*28
        print(data.dtype)
        data = data.to(device)
        batch_size = data.size()[0]
        x = data.transpose(0, 1) #多分転置してるだけ
        z_prev = torch.zeros(batch_size, z_dim).to(device)
        print(z_prev.dtype)
        if train_mode:
            mean_loss += model.train({'x': x, 'z_prev': z_prev}).item() * batch_size
        else:
            mean_loss += model.test({'x': x, 'z_prev': z_prev}).item() * batch_size
    mean_loss /= len(loader.dataset)
    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss

In [166]:
def plot_image_from_latent(batch_size):
    x = []
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    for step in range(t_max):
        samples = generate_from_prior.sample({'z_prev': z_prev})
        x_t = decoder.sample_mean({"z": samples["z"]})
        z_prev = samples["z"]
        x.append(x_t[None, :])
    x = torch.cat(x, dim=0).transpose(0, 1)
    return x

In [167]:
writer = SummaryWriter()

for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, dmm, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, dmm, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)

    sample = plot_image_from_latent(batch_size)[:, None][1,:]
    writer.add_image('Image_from_latent', sample, epoch)

  0%|          | 0/8 [00:00<?, ?it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:04,  1.40it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:01<00:04,  1.50it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:02<00:03,  1.44it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:02<00:02,  1.46it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:03<00:02,  1.42it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 75%|███████▌  | 6/8 [00:04<00:01,  1.44it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:04<00:00,  1.48it/s]

torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:05<00:00,  1.54it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 1 Train loss: -16217.8789
torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:02,  2.53it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:00<00:02,  2.76it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:00<00:01,  3.06it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:01<00:01,  3.28it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:01<00:00,  3.48it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:01<00:00,  4.04it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32
torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:02<00:00,  4.27it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Test loss: -37210.3535
torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:06,  1.13it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:01<00:04,  1.23it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:02<00:03,  1.27it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:02<00:03,  1.33it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:03<00:02,  1.34it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 75%|███████▌  | 6/8 [00:04<00:01,  1.41it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:04<00:00,  1.51it/s]

torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:05<00:00,  1.60it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 2 Train loss: -51480.8087
torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:02,  2.47it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:00<00:02,  2.74it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:00<00:01,  3.04it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:01<00:01,  3.25it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:01<00:00,  3.41it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:01<00:00,  3.81it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32
torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:02<00:00,  4.17it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Test loss: -71192.4667
torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:06,  1.10it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:01<00:05,  1.14it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:02<00:04,  1.21it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:03<00:03,  1.26it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:03<00:02,  1.36it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 75%|███████▌  | 6/8 [00:04<00:01,  1.33it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:05<00:00,  1.40it/s]

torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:05<00:00,  1.47it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 3 Train loss: -86594.9063
torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:03,  2.01it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:00<00:02,  2.33it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:01<00:01,  2.62it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:01<00:01,  2.74it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:01<00:00,  3.05it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:02<00:00,  3.70it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32
torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:02<00:00,  4.04it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Test loss: -107881.9201
torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:06,  1.02it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:01<00:05,  1.11it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:02<00:03,  1.26it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:02<00:02,  1.38it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:03<00:02,  1.38it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 75%|███████▌  | 6/8 [00:04<00:01,  1.42it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:04<00:00,  1.48it/s]

torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:05<00:00,  1.58it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 4 Train loss: -125476.2608
torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:03,  1.85it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:00<00:02,  2.12it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:01<00:02,  2.47it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:01<00:01,  2.64it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:01<00:01,  2.91it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:02<00:00,  3.54it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32
torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:02<00:00,  3.84it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Test loss: -148473.6246
torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:06,  1.15it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:01<00:04,  1.25it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:02<00:03,  1.29it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:02<00:02,  1.34it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:03<00:02,  1.34it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 75%|███████▌  | 6/8 [00:04<00:01,  1.40it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:04<00:00,  1.47it/s]

torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:05<00:00,  1.53it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 5 Train loss: -166968.2374
torch.Size([128, 139, 20])
torch.float32
torch.float32


 12%|█▎        | 1/8 [00:00<00:03,  1.80it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 25%|██▌       | 2/8 [00:00<00:02,  2.10it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 38%|███▊      | 3/8 [00:01<00:02,  2.44it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 50%|█████     | 4/8 [00:01<00:01,  2.78it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 62%|██████▎   | 5/8 [00:01<00:00,  3.10it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32


 88%|████████▊ | 7/8 [00:02<00:00,  3.74it/s]

torch.Size([128, 139, 20])
torch.float32
torch.float32
torch.Size([104, 139, 20])
torch.float32
torch.float32


100%|██████████| 8/8 [00:02<00:00,  4.08it/s]

Test loss: -193030.2448



