In [1]:
import torch
import numpy as np
from kvae_model import KVAEModel
from kvae import KVAE
from auxiliary import load_bouncing_ball, export_vid, export_latent_space_vis, save_frames
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Mini-batch size used by both data loaders (and passed to the KVAE trainer).
batch_size = 64

# The auxiliary helper returns the (train, validation) dataset pair.
train_dataset, val_dataset = load_bouncing_ball("nonlinear_ball_data", "circle", singular=False)

# Both loaders share the same settings: shuffled batches, loaded in the main
# process (num_workers=0).
loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=0)
train_dataloader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)
val_dataloader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)

## Load pretrained model

In [3]:
# Keyword arguments for the KVAE network. x_dim is the flattened size of a
# 32x32 frame; everything is kept on the CPU.
model_hparams = dict(
    x_dim=32 * 32,
    a_dim=4,
    z_dim=4,
    x_2d=True,
    init_kf_mat=0.05,
    noise_transition=0.08,
    noise_emission=0.03,
    init_cov=20,
    K=1,
    dim_RNN_alpha=50,
    num_RNN_alpha=2,
    dropout_p=0,
    scale_reconstruction=0.3,
    device='cpu',
)
model = KVAEModel(**model_hparams).to('cpu')
model.build()

# Keyword arguments for the KVAE wrapper that owns the model.
trainer_hparams = dict(
    model=model,
    lr=3e-3,        # the original cell noted 3e-6 as an alternative
    lr_tot=1e-3,    # the original cell noted 1e-6 as an alternative
    epochs=20,
    batch_size=batch_size,
    early_stop_patience=20,
    save_frequency=1,
    only_vae_epochs=5,
    kf_update_epochs=5,
    save_dir="results",
)
kvae = KVAE(**trainer_hparams)

In [4]:
# Read the saved weights onto the CPU, then copy them into the model.
# The final expression's return value is what the cell displays.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint_path = "results/t25_circle_IS_G_a4/model_at_85.pt"
checkpoint_state = torch.load(checkpoint_path, map_location='cpu')
kvae.model.load_state_dict(checkpoint_state)

<All keys matched successfully>

## Extract results from a training sample

In [5]:
# Export a video for the loaded model using the training data loader.
# NOTE(review): the meaning of j, both and epoch is defined in
# auxiliary.export_vid and is not visible here; epoch=555 looks like an
# arbitrary tag rather than a real epoch number -- confirm.
export_vid(kvae, train_dataloader, j=0, both=True, epoch=555)

torch.Size([64, 32, 32, 40]) (64, 32, 32, 40)


In [6]:
# Build one batch from the first 64 dataset items and push it through the
# model's debug forward pass, in eval mode and without gradient tracking.
first_items = train_dataloader.dataset[0:64]
t = torch.Tensor(first_items)

kvae.model.eval()
with torch.no_grad():
    debug_outputs = kvae.model.forward_debug(t, compute_loss=False)

# forward_debug yields five values; mu_smooth is the one used further down.
x, y_forward, a, a_gen, mu_smooth = debug_outputs

In [6]:
# NOTE(review): placeholder of shape (1, 1, 4). z_temp is reassigned from
# mu_smooth in the following cell before it is used, so this cell appears to be
# leftover scratch code.
z_temp = np.array([[[0.0,0.0,0.0,0.0]]])

In [7]:
# mu_smooth is laid out as (seq_len, batch_size, z_dim).
# Keep every time step and latent dimension, but only the first three
# sequences of the batch.
z_temp = mu_smooth[:, :3, :]

In [8]:
# Inspect the selected latent values at index [0, 0]; per the layout noted for
# mu_smooth (seq_len, batch_size, z_dim) this is the first time step of the
# first selected sequence.
z_temp[0, 0, :]

array([ -3.7295022,  -2.42     , -21.364344 , -11.816754 ], dtype=float32)

In [8]:
# Copy the selected latents into a torch tensor with no autograd history.
latent_tensor = torch.tensor(z_temp).detach()

# NOTE(review): this rebinds the notebook-level `batch_size` (64 in the
# data-loading cell) to the number of selected sequences; cells that are
# re-run afterwards will see the new value.
seq_len, batch_size, _ = latent_tensor.shape

# Reorder axes from (seq_len, batch_size, z_dim) to (batch_size, z_dim, seq_len).
z_temp = latent_tensor.permute(1, 2, 0)

In [9]:
# Number of extra time steps to generate after the seq_len inferred ones
# (the sample buffers and the rollout loop below run to seq_len + N).
N = 200

In [10]:
# Zero-filled buffers covering the inferred prefix plus the N generated steps;
# the last axis is time.
total_steps = seq_len + N
z_samples = torch.zeros(batch_size, kvae.model.z_dim, total_steps)
y_samples = torch.zeros(batch_size, kvae.model.x_dim, total_steps)
a_samples = torch.zeros(batch_size, kvae.model.a_dim, total_steps)

In [11]:
# Starting latent state for the rollout, kept 3-D as (batch_size, z_dim, 1).
# NOTE(review): the slice -2:-1 selects the second-to-last time step
# (index seq_len-2), not the last one (-1:) -- confirm this is intended.
z = z_temp[:, :, -2:-1] #batch_size, z_dim, 1

In [12]:
# Leading dimension used when flattening the model's A, B and C below.
# Presumably the number of dynamics matrices, matching K = 1 in the
# KVAEModel configuration -- confirm.
K = 1

In [13]:
# View each of the model's K matrices as one flat row.
kvae_net = kvae.model
A_flatten = kvae_net.A.view(K, kvae_net.z_dim * kvae_net.z_dim)  # (K, z_dim*z_dim)
B_flatten = kvae_net.B.view(K, kvae_net.z_dim * kvae_net.u_dim)  # (K, z_dim*u_dim)
C_flatten = kvae_net.C.view(K, kvae_net.a_dim * kvae_net.z_dim)  # (K, a_dim*z_dim)

In [14]:
# Restore each flattened matrix to (1, rows, cols), then tile it along the
# batch axis so every sequence gets its own copy for batched matmul.
A_mix = A_flatten.view(1, kvae.model.z_dim, kvae.model.z_dim).repeat(batch_size, 1, 1)
B_mix = B_flatten.view(1, kvae.model.z_dim, kvae.model.u_dim).repeat(batch_size, 1, 1)
C_mix = C_flatten.view(1, kvae.model.a_dim, kvae.model.z_dim).repeat(batch_size, 1, 1)

In [15]:
def z_to_y(model, z, C): #z in (batch_size, z_dim, seq_len)
    """Project latent states through C and decode the result with the model.

    The whole computation runs under ``torch.no_grad()``, so neither returned
    tensor carries an autograd graph even when ``C`` requires gradients.

    Args:
        model: object exposing ``eval()`` and ``decode(a)``. ``decode`` is
            called with the projected states arranged time-first, i.e.
            (seq_len, batch_size, a_dim).
        z: latent states, (batch_size, z_dim, seq_len).
        C: batched projection matrices, (batch_size, a_dim, z_dim), as
            required by ``C.bmm(z)``.

    Returns:
        Tuple ``(a_gen, y)``:
            a_gen: ``C @ z`` per batch element, (batch_size, a_dim, seq_len).
            y: output of ``model.decode`` with its axes permuted by
               ``(1, -1, 0)``, i.e. (batch_size, dim, seq_len).

    Side effects:
        Switches ``model`` to evaluation mode and leaves it there.
    """
    model.eval()
    with torch.no_grad():
        a_gen = C.bmm(z)  # (batch_size, a_dim, seq_len)
        # decode() is fed the time axis first; its output is permuted back so
        # the batch axis leads and time comes last.
        y = model.decode(a_gen.permute(-1, 0, 1)).permute(1, -1, 0)
    # Permuting a_gen to time-first and back again is the identity, so the
    # batched product is returned directly.
    return a_gen, y

In [16]:
# Decode the inferred latents, then copy latents, projected states and frames
# into the first seq_len time slots of the sample buffers.
a_gen, y = z_to_y(kvae.model, z_temp, C_mix)
z_samples[..., :seq_len] = z_temp
a_samples[..., :seq_len] = a_gen
y_samples[..., :seq_len] = y

In [17]:
# Sanity check: shape of the inferred prefix of z_samples
# (batch, z_dim, seq_len).
z_samples[:, :, :seq_len].shape

torch.Size([3, 4, 40])

In [18]:
# Roll the latent dynamics forward for N steps beyond the inferred sequence.
#
# Each step advances the state stored at the previous time index with A, then
# decodes that new state, so z_samples, a_samples and y_samples at index i all
# describe the same latent. (Decoding before the transition paired the frame of
# the old state with the latent of the new one.)
#
# The state is read from z_samples[:, :, i-1] rather than carried in a
# notebook-level variable, so the rollout starts from the last inferred step
# (index seq_len-1) and re-running this cell reproduces the same result.
# `z` is still rebound each step and holds the final latent afterwards.
#
# no_grad keeps the N-step recurrence from building an autograd graph through
# the model's A matrix.
with torch.no_grad():
    for i in range(seq_len, seq_len+N):
        z = A_mix.bmm(z_samples[:, :, i-1:i])   # (batch, z_dim, 1)
        a_gen, y = z_to_y(kvae.model, z, C_mix)

        z_samples[:, :, i:i+1] = z
        a_samples[:, :, i:i+1] = a_gen
        y_samples[:, :, i:i+1] = y

In [19]:
# Unflatten the x_dim = 32*32 axis into 32x32 images:
# (batch_size, 32, 32, seq_len+N).
# NOTE: y_samples is rebound to the 4-D view, so the cells above that index it
# as a 3-D buffer must be re-run from the allocation cell.
y_samples = y_samples.view(batch_size, 32, 32, seq_len+N)

In [20]:
# Hand the image-shaped samples to the auxiliary helper.
# NOTE(review): output location and format are defined in
# auxiliary.save_frames and are not visible here.
save_frames(y_samples)