In [1]:
import numpy as np
import torch
from torchdiffeq import odeint

from data.wrappers import load_data
from model.odegpvae import ODEGPVAE

# Configuration: everything a reader might tune lives in this cell.
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_root = 'data/'
task = 'mnist'
# Latent size passed to ODEGPVAE(q=...); later cells slice the latent state with `q`,
# which previously was never defined and raised NameError on a fresh kernel.
q = 8
# Seed torch so the reparameterisation noise and GP draws are reproducible.
SEED = 0
torch.manual_seed(SEED)


In [2]:
# ----- data -----
trainset, testset = load_data(data_root, task, plot=True)

# ----- model -----
# NOTE(review): even after .to(device), the SVGP variational scale_tril is reported
# as not on CUDA in a later cell; check how ODEGPVAE registers that tensor.
odegpvae = ODEGPVAE(q=8, n_filt=16, device=device)
odegpvae = odegpvae.to(device)  # nn.Module.to returns the module itself

In [3]:
# Presumably the number of minibatches: iterating trainset yields (N, T, nc, d, d) batches below.
n_train_batches = len(trainset)
n_train_batches

20

In [52]:
# Sanity check: the model and each sub-module were moved to `device`.
# Replaces six copy-pasted `.is_cuda` probes and also works when device == 'cpu'.
_expected_device_type = torch.device(device).type
for _name, _module in [
    ("odegpvae", odegpvae),
    ("ode_model", odegpvae.ode_model),
    ("enc_s", odegpvae.enc_s),
    ("enc_v", odegpvae.enc_v),
    ("decoder", odegpvae.decoder),
]:
    _param_device = next(_module.parameters()).device
    assert _param_device.type == _expected_device_type, f"{_name} is on {_param_device}, expected {device}"
# The SVGP carries its own `device` attribute (a torch.device).
assert odegpvae.ode_model.svgp.device.type == _expected_device_type, odegpvae.ode_model.svgp.device
odegpvae.ode_model.svgp.device

In [58]:
# ----- training setup -----
Nepoch = 500  # not referenced by any of the cells shown here yet
optimizer = torch.optim.Adam(params=odegpvae.parameters(), lr=0.001)

In [59]:
# Take the first training minibatch and move it to the model's device.
minibatch = next(iter(trainset)).to(device)
X = minibatch
# Layout (N, T, nc, d, d); shown below as (25, 16, 1, 28, 28).
# Unpacking into `d` twice would silently keep only the width, so check squareness explicitly.
N, T, nc, d, d_w = X.shape
assert d == d_w, f"expected square frames, got {d}x{d_w}"

In [62]:
X.size()  # (N, T, nc, d, d)

torch.Size([25, 16, 1, 28, 28])

In [63]:
X.select(1, 0).size()  # first frame of every sequence, same as X[:, 0]

torch.Size([25, 1, 28, 28])

In [64]:
# Encode the first frame of each sequence into the mean / log-scale used to sample s0.
s0_mu, s0_logv = odegpvae.enc_s(X[:, 0])

In [65]:
# The velocity encoder sees the first v_steps frames, stacked along dim 1.
# Squeeze only the channel axis (dim 2, a no-op unless nc == 1). A bare torch.squeeze
# would also drop the batch axis when N == 1 and the time axis when v_steps == 1.
v0_mu, v0_logv = odegpvae.enc_v(X[:, :odegpvae.v_steps].squeeze(2))

In [66]:
# Shapes of the encoder outputs for position (s0) and velocity (v0).
for _mu, _logv in ((s0_mu, s0_logv), (v0_mu, v0_logv)):
    print(_mu.shape, _logv.shape)

torch.Size([25, 8]) torch.Size([25, 8])
torch.Size([25, 8]) torch.Size([25, 8])


In [67]:
# Standard-normal noise for the reparameterised samples of s0 and v0 (drawn in that order).
eps_s0, eps_v0 = torch.randn_like(s0_mu), torch.randn_like(v0_mu)

In [68]:
eps_v0.device.type == 'cuda'

True

In [69]:
# Reparameterised draws: mean + exp(logv) * noise.
# NOTE(review): exp(logv) is used as the standard deviation. If `logv` is a log-variance
# (as the name suggests), the std would be exp(0.5 * logv); confirm against the encoder/ELBO.
s0 = s0_mu + s0_logv.exp() * eps_s0  # N,q
v0 = v0_mu + v0_logv.exp() * eps_v0  # N,q

In [70]:
odegpvae._zero_mean.device.type == 'cuda'

True

In [71]:
# Log-density of the noise draws under odegpvae.mvn, summed over the s0 and v0 parts.
# NOTE(review): by change of variables, log q(z0) would also include -sum(log std),
# i.e. -sum(s0_logv) - sum(v0_logv) given std = exp(logv); confirm the omission is intended.
logp0_s = odegpvae.mvn.log_prob(eps_s0)
logp0_v = odegpvae.mvn.log_prob(eps_v0)
logp0 = logp0_s + logp0_v  # N

In [15]:
logp0.size()  # one value per sequence in the batch

torch.Size([25])

## ODE: integrate the latent state forward in time

In [72]:
# Initial latent state: velocity first, then position (later cells decode only the second half).
z0 = torch.cat((v0, s0), dim=1)

In [73]:
z0.size()  # (N, 2q)

torch.Size([25, 16])

In [77]:
odegpvae.ode_model.svgp.Lzz.evaluate().device.type == 'cuda'

True

In [82]:
# This reports False although the model was moved to `device`. It is the likely cause of
# the cuda/cpu RuntimeError raised when drawing from the SVGP posterior below.
odegpvae.ode_model.svgp.variational_strategy.variational_distribution.scale_tril.device.type == 'cuda'

False

In [74]:
# Draw one function sample from the SVGP posterior and wrap it as an ODE right-hand side
# f(t, x) for odeint.
# NOTE(review): this cell raised "Expected all tensors to be on the same device ... cuda:0
# and cpu" inside bmm. The variational scale_tril check above reports it is not on CUDA,
# so the draw likely mixes devices. Move that tensor to `device` before relying on this.
gp_draw = odegpvae.ode_model.svgp.draw_posterior_function()


def oderhs(t, x):
    """ODE right-hand side: delegate to svgp.ode_rhs with the fixed posterior draw `gp_draw`."""
    return odegpvae.ode_model.svgp.ode_rhs(t, x, gp_draw)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_bmm)

In [29]:
# Integrate the latent state and its log-density together (tuple state) with fixed-step Euler.
# `odeint` is imported in the top imports cell.
dt = 0.1  # time spacing between consecutive frames on the ODE time axis
ts = dt * torch.arange(T, dtype=torch.float, device=z0.device)  # allocate on device directly
zt, logp = odeint(oderhs, (z0, logp0), ts, method="euler")

In [30]:
# Put the single trajectory in (L, N, T, ...) form with L = 1 latent sample.
ztL = torch.stack([zt.permute(1, 0, 2)], dim=0)   # 1,N,T,2q
logpL = torch.stack([logp.permute(1, 0)], dim=0)  # 1,N,T

## Decode: map latent positions back to image frames

In [31]:
# Decode only the position half of the latent state back into frames.
# Sizes are derived from ztL instead of hard-coding q = 8 and L = 1.
q = ztL.shape[-1] // 2                          # latent state is [v, s], each of size q
L = ztL.shape[0]                                # number of latent trajectory samples
st_muL = ztL[:, :, :, q:]                       # L,N,T,q  position part only
s = odegpvae.fc3(st_muL.reshape(L * N * T, q))  # L*N*T,h_dim
Xrec = odegpvae.decoder(s)                      # L*N*T,nc,d,d
Xrec = Xrec.view([L, N, T, nc, d, d])           # L,N,T,nc,d,d

In [32]:
Xrec.size()  # (L, N, T, nc, d, d)

torch.Size([1, 25, 16, 1, 28, 28])

In [2]:
# (A stray debug loop that printed 0..8 was removed; it was unrelated to the model.)


In [77]:
ztL.size()  # (L, N, T, 2q)

torch.Size([1, 25, 16, 16])

In [78]:
# `q` was never defined in this notebook (NameError on a fresh kernel); derive it from ztL.
q = ztL.shape[-1] // 2     # latent state is [v, s], each of size q
st_muL = ztL[:, :, :, q:]  # L,N,T,q  position half only

In [79]:
# Flatten (L, N, T) into one batch axis; -1 infers L instead of hard-coding a single sample.
s = odegpvae.fc3(st_muL.reshape(-1, q))  # L*N*T,h_dim

In [80]:
s.size()  # (L*N*T, h_dim)

torch.Size([400, 1024])

In [81]:
# Decode each projected latent position to an image; output shown below as (400, 1, 28, 28).
Xrec = odegpvae.decoder(s)

In [82]:
Xrec.size()  # flat (L*N*T, nc, d, d) before reshaping

torch.Size([400, 1, 28, 28])

In [83]:
# Restore (L, N, T, nc, d, d); -1 infers L instead of hard-coding a single sample.
Xrec = Xrec.view([-1, N, T, nc, d, d])

In [84]:
Xrec.size()  # (L, N, T, nc, d, d)

torch.Size([1, 25, 16, 1, 28, 28])

In [85]:
# One copy of the data per latent sample, giving (L, N, T, nc, d, d) to line up with Xrec.
# repeat() with one more entry than X has dims prepends the new leading axis.
XL = X.repeat([Xrec.shape[0], 1, 1, 1, 1, 1])

In [92]:
# Scratch: reshape a small random tensor, then flatten it with view.
# Show only the resulting shape rather than dumping all 64 values into the notebook.
a = torch.rand(4, 2, 4, 2)
a = a.reshape(4, 4, 2, 2)
assert a.shape == torch.Size([4, 4, 2, 2])
a.view(-1, 64).shape