# Deep Markov Model with motion model

In [139]:
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter
import numpy as np

In [140]:
batch_size = 128#128
epochs = 5
seed = 1
torch.manual_seed(seed)

<torch._C.Generator at 0x2b0703ecb4f0>

In [141]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [142]:
#計測モデルとか
def get_ot(st,lmap,max_range):
    dis = torch.sqrt((st[:,0]-lmap[0])**2+(st[:,1]-lmap[1])**2)
    angle = torch.atan2((lmap[1]-st[:,1]),(lmap[0]-st[:,0]))-st[:,2]
    return torch.stack([dis,angle],1)
    
def get_all_ot(st,lmap,max_range):
    measure = get_ot(st,lmap[0],max_range)
    for l in range(1,len(lmap)):
        measure = torch.cat([measure, get_ot(st,lmap[l],max_range)],1)
    return torch.tensor(measure)

In [143]:
landmark_num = 10
start_pos = [2.0,4.0,0.0]#x0,y0,yaw0

In [144]:
landmark_dim = 2
x_dim = landmark_num*2
h_dim = 32 #32
hidden_dim = 32 #32
z_dim = 3
u_dim = 2
t_max = 139

In [145]:
#データの読み込み
transform = transforms.Compose([transforms.ToTensor()])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

#data loader #とりあえず1時系列分を分身させて食わせてる
#[time,s_x,s_y,s_yaw,uv,ur,ot[1],,,,ot[N]]
data = np.loadtxt('vehicle_motion_data.csv', delimiter=',')
data = torch.tensor([data],dtype=torch.float32)
st = data[0,:,1:(1+z_dim)]
ut = data[0,:,(1+z_dim):(1+z_dim+u_dim)]
ot = data[0,:,(1+z_dim+u_dim):(1+z_dim+u_dim+x_dim)]

print(st.size())
st=st.repeat(1000,1,1)
ut=ut.repeat(1000,1,1)
ot=ot.repeat(1000,1,1)
print(st.size())


landmark = np.loadtxt('landmark_data.csv',delimiter=',')

train = torch.utils.data.TensorDataset(ot,ut)
train_loader = torch.utils.data.DataLoader(train, shuffle=False,**kwargs)
test = torch.utils.data.TensorDataset(ot,ut)
test_loader = torch.utils.data.DataLoader(test, shuffle=False,**kwargs)

torch.Size([139, 3])
torch.Size([1000, 139, 3])


In [146]:
from pixyz.models import Model
from pixyz.losses import KullbackLeibler, CrossEntropy, IterativeLoss
from pixyz.distributions import Bernoulli, Normal, Deterministic
from pixyz.utils import print_latex

In [147]:
class RNN(Deterministic):
    def __init__(self):
        super(RNN, self).__init__(cond_var=["x"], var=["h"])
        self.rnn = nn.GRU(x_dim, h_dim, bidirectional=True)
#         self.h0 = torch.zeros(2, batch_size, self.rnn.hidden_size).to(device)
        self.h0 = nn.Parameter(torch.zeros(2, 1, self.rnn.hidden_size))
        self.hidden_size = self.rnn.hidden_size
        
    def forward(self, x):
        h0 = self.h0.expand(2, x.size(1), self.rnn.hidden_size).contiguous()
        h, _ = self.rnn(x, h0)
        return {"h": h}

In [148]:
# class Generator(Bernoulli):
#     def __init__(self):
#         super(Generator, self).__init__(cond_var=["z"], var=["x"])
#         self.fc1 = nn.Linear(z_dim, hidden_dim)
#         self.fc2 = nn.Linear(hidden_dim, x_dim)
    
#     def forward(self, z):
#         print(z.size()) #[128,3]
#         h = F.relu(self.fc1(z))
#         return {"probs": torch.sigmoid(self.fc2(h))}
class Generator(Normal):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"])
    
    def forward(self, z):#計測モデルそのまま
        ot=get_all_ot(z,landmark,[1000,1000])
        return {"loc": ot,"scale":torch.tensor(0.3).to(device)}

In [149]:
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["h", "z_prev"], var=["z"])
        self.fc1 = nn.Linear(z_dim, h_dim*2)
        self.fc21 = nn.Linear(h_dim*2, z_dim)
        self.fc22 = nn.Linear(h_dim*2, z_dim)
        
    def forward(self, h, z_prev):
        h_z = torch.tanh(self.fc1(z_prev))
        h = 0.5 * (h + h_z)
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [150]:
class Prior(Normal):
    def __init__(self):
        super(Prior, self).__init__(cond_var=["z_prev","u"], var=["z"])
        
    def forward(self, z_prev, u):
        # motion model for two-wheel robot x,y,orient,v,steering
        z = torch.zeros(len(z_prev),z_dim).to(device)
        z[:,2] = z_prev[:,2] + u[:,1]
        z[:,0] = z_prev[:,0] + u[:,0] * torch.cos(z_prev[:,2] + u[:,1])
        z[:,1] = z_prev[:,1] + u[:,0] * torch.sin(z_prev[:,2] + u[:,1])

        return {"loc": z, "scale": torch.tensor(0.3).to(device)}

In [151]:
prior = Prior().to(device)
encoder = Inference().to(device)
decoder = Generator().to(device)
rnn = RNN().to(device)

In [152]:
print(prior)
print("*"*80)
print(encoder)
print("*"*80)
print(decoder)
print("*"*80)
print(rnn)

Distribution:
  p(z|z_{prev},u)
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev', 'u'], input_var=['z_prev', 'u'], features_shape=torch.Size([])
  )
********************************************************************************
Distribution:
  p(z|h,z_{prev})
Network architecture:
  Inference(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['h', 'z_prev'], input_var=['h', 'z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=3, out_features=64, bias=True)
    (fc21): Linear(in_features=64, out_features=3, bias=True)
    (fc22): Linear(in_features=64, out_features=3, bias=True)
  )
********************************************************************************
Distribution:
  p(x|z)
Network architecture:
  Generator(
    name=p, distribution_name=Normal,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
  )
*****************************************************************

In [153]:
generate_from_prior = prior * decoder
print(generate_from_prior)
print_latex(generate_from_prior)

Distribution:
  p(x,z|z_{prev},u) = p(x|z)p(z|z_{prev},u)
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev', 'u'], input_var=['z_prev', 'u'], features_shape=torch.Size([])
  )
  Generator(
    name=p, distribution_name=Normal,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
  )


<IPython.core.display.Math object>

In [154]:
step_loss = CrossEntropy(encoder, decoder) + KullbackLeibler(encoder, prior)
_loss = IterativeLoss(step_loss, max_iter=t_max, 
                      series_var=["x", "h", "u"], update_value={"z": "z_prev"})
loss = _loss.expectation(rnn).mean()

In [155]:
dmm = Model(loss, distributions=[rnn, encoder, decoder, prior], 
            optimizer=optim.RMSprop, optimizer_params={"lr": 5e-4}, clip_grad_value=10)

In [156]:
print(dmm)
print_latex(dmm)

Distributions (for training): 
  p(h|x), p(z|h,z_{prev}), p(x|z), p(z|z_{prev},u) 
Loss function: 
  mean \left(\mathbb{E}_{p(h|x)} \left[\sum_{t=1}^{139} \left(D_{KL} \left[p(z|h,z_{prev})||p(z|z_{prev},u) \right] - \mathbb{E}_{p(z|h,z_{prev})} \left[\log p(x|z) \right]\right) \right] \right) 
Optimizer: 
  RMSprop (
  Parameter Group 0
      alpha: 0.99
      centered: False
      eps: 1e-08
      lr: 0.0005
      momentum: 0
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [157]:
def data_loop(epoch, loader, model, device, train_mode=False):
    mean_loss = 0
    for idx,[o,u] in enumerate(tqdm(loader)):#batch_idx, (data, _) in enumerate(tqdm(loader)):
        o = o.to(device)
        u = u.to(device)
        batch_size = o.size()[0]
        x = o.transpose(0, 1) #多分転置してるだけ
        u = u.transpose(0, 1)
        z_prev = torch.tensor(start_pos)#初期姿勢
        z_prev = z_prev.repeat(batch_size, 1).to(device)
        if train_mode:
            mean_loss += model.train({'x': x, 'u':u, 'z_prev': z_prev}).item() * batch_size
        else:
            mean_loss += model.test({'x': x, 'u':u, 'z_prev': z_prev}).item() * batch_size
    mean_loss /= len(loader.dataset)
    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss

In [161]:
def plot_image_from_latent(batch_size):
    x = []
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    u0 = torch.zeros(batch_size, u_dim).to(device)
    for step in range(t_max):
        samples = generate_from_prior.sample({'z_prev': z_prev,'u':u0})
        x_t = decoder.sample_mean({"z": samples["z"]})
        z_prev = samples["z"]
        x.append(x_t[None, :])
    x = torch.cat(x, dim=0).transpose(0, 1)
    return x

In [162]:
writer = SummaryWriter()

for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, dmm, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, dmm, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)

    sample = plot_image_from_latent(batch_size)[:, None][1,:]
    writer.add_image('Image_from_latent', sample, epoch)

  # This is added back by InteractiveShellApp.init_path()


torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:00<00:06,  1.01it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:05,  1.03it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:04,  1.05it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:03<00:03,  1.07it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:04<00:02,  1.07it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:05<00:01,  1.08it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:06<00:00,  1.10it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:07<00:00,  1.12it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 1 Train loss: 22884450.9920
torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:00<00:05,  1.38it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:04,  1.42it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:03,  1.45it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:02<00:02,  1.47it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:03<00:02,  1.49it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:04<00:01,  1.50it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:04<00:00,  1.53it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:05<00:00,  1.56it/s]


Test loss: 22876623.6640


  0%|          | 0/8 [00:00<?, ?it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:00<00:06,  1.02it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:05,  1.02it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:04,  1.04it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:03<00:03,  1.06it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:04<00:02,  1.07it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:05<00:01,  1.07it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:06<00:00,  1.09it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:07<00:00,  1.10it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 2 Train loss: 22869250.5600
torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:00<00:05,  1.39it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:04,  1.42it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:03,  1.45it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:02<00:02,  1.46it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:03<00:02,  1.48it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:04<00:01,  1.49it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:04<00:00,  1.53it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:05<00:00,  1.55it/s]


Test loss: 22862714.6560


  0%|          | 0/8 [00:00<?, ?it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:00<00:06,  1.02it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:05,  1.04it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:04,  1.05it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:03<00:03,  1.06it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:04<00:02,  1.07it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:05<00:01,  1.08it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:06<00:00,  1.09it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:07<00:00,  1.11it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 3 Train loss: 22859683.8880
torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:00<00:05,  1.39it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:04,  1.43it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:03,  1.45it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:02<00:02,  1.45it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:03<00:02,  1.48it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:04<00:01,  1.47it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:04<00:00,  1.51it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:05<00:00,  1.54it/s]


Test loss: 22843510.1440


  0%|          | 0/8 [00:00<?, ?it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:00<00:06,  1.02it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:05,  1.04it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:04,  1.06it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:03<00:03,  1.07it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:04<00:02,  1.07it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:05<00:01,  1.08it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:06<00:00,  1.10it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:07<00:00,  1.11it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 4 Train loss: 22835421.7280
torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:00<00:05,  1.37it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:04,  1.41it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:03,  1.43it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:02<00:02,  1.45it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:03<00:02,  1.47it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:04<00:01,  1.48it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:04<00:00,  1.51it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:05<00:00,  1.54it/s]


Test loss: 22815634.3040


  0%|          | 0/8 [00:00<?, ?it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:01<00:07,  1.03s/it]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:05,  1.00it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:04,  1.03it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:03<00:03,  1.05it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:04<00:02,  1.06it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:05<00:01,  1.07it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:06<00:00,  1.09it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:07<00:00,  1.11it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 5 Train loss: 22800271.2960
torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 12%|█▎        | 1/8 [00:00<00:05,  1.37it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 25%|██▌       | 2/8 [00:01<00:04,  1.41it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 38%|███▊      | 3/8 [00:02<00:03,  1.44it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 50%|█████     | 4/8 [00:02<00:02,  1.46it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 62%|██████▎   | 5/8 [00:03<00:02,  1.49it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 75%|███████▌  | 6/8 [00:03<00:01,  1.52it/s]

torch.Size([139, 128, 20])
torch.Size([139, 128, 2])


 88%|████████▊ | 7/8 [00:04<00:00,  1.55it/s]

torch.Size([139, 104, 20])
torch.Size([139, 104, 2])


100%|██████████| 8/8 [00:05<00:00,  1.57it/s]


Test loss: 22780466.4320
