In [1]:
import numpy as np
from sklearn.linear_model import lars_path
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import trange

from src.models.ae import ConvAE, Decoder, Encoder
from src.models.anode import ANODENet

from src.data.box import GravityHoleBall
from src.data.generate import generate_gravity_hole_ball_images

from src.utils.utils import add_spatial_encoding
from src.utils.node import  BatchGetterMultiImages, train_convnode, train_convnode_with_batch
from src.utils.viz import  display_convnode_trajectory

In [12]:

class ConvNodeWithBatch(nn.Module):
    def __init__(self, device, size, latent_dim, in_channels,
    ode_hidden_dim, ode_out_dim, augment_dim=0, time_dependent=False, 
    ode_non_linearity='relu', conv_activation=nn.ReLU(),latent_activation=None, stack_size=1):
        super(ConvNodeWithBatch, self).__init__()
        self.device = device
        self.size = size
        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.conv_activation = conv_activation
        self.latent_activation = latent_activation
        self.ode_hidden_dim = ode_hidden_dim
        self.out_dim = ode_out_dim
        self.augment_dim = augment_dim
        self.time_dependent = time_dependent
        self.ode_non_linearity = ode_non_linearity

        print("-"*50)
        print("Creating ConvAE...")
        self.encoder = TimeDistributed(
            Encoder(device=device, latent_dim=latent_dim, in_channels=in_channels,
            activation=conv_activation, relu=latent_activation).to(device), 
            len_shape_without_batch=4, # input without batch are (times, channels, height, width)
            batch_first=True
        )
        self.decoder = TimeDistributed(
            Decoder(device=device, latent_dim=latent_dim, in_channels=in_channels,
            activation=conv_activation).to(device),
            len_shape_without_batch=2, # input without batch are (times, latent_dim)
            batch_first=True
        )

        print("-"*50)
        print("Creating ANODENet...")
        self.node = ANODENet(device, latent_dim*(stack_size + 1), ode_hidden_dim, ode_out_dim, augment_dim, time_dependent=False,
            non_linearity=ode_non_linearity).to(device)

    def forward(self, images, times, dt):
        # images: [(batch), n_stack, in_channels, height, width]
        # latent_z: [n_stack, latent_dim]
        # print("input_images: ", images.shape)
        latent_z = self.encoder(images)
        # print("latent_z: ", latent_z.shape)
        
        # latent_z_stack: [(batch), n_stack, latent_dim*(n_stack+1)]
        # for the moment n_stack = 1
        if len(latent_z.shape) == 3:
            latent_z_stack = torch.cat([latent_z[:, :-1], (latent_z[:, 1:]-latent_z[:, :-1])/dt], dim=-1).squeeze(1)
        

        elif len(latent_z.shape) == 2:
            latent_z_stack = torch.cat([latent_z[:-1], (latent_z[1:]-latent_z[:-1])/dt], dim=-1)

        # print("latent_z_stack: ", latent_z_stack.shape)

        # sim : [times, (batch),ode_out_dim]
        sim = self.node(latent_z_stack, times)[..., :latent_z.shape[-1]]
        # print("sim: ", sim.shape)
        # sim : [(batch), n_stack, ode_out_dim]
        if len(images.shape) == 5:
            sim = sim.swapdims(0,1)
        else:
            sim = sim.squeeze(1)
        # print("sim: ", sim.shape)

        reconstructed_images = self.decoder(sim)
        # print("reconstructed_images: ", reconstructed_images.shape)

        return reconstructed_images, sim

    def encode(self, images):

        return self.encoder(images)

    def decode(self, latent_z):
        return self.decoder(latent_z)




class TimeDistributed(nn.Module):
    def __init__(self, module, len_shape_without_batch, batch_first=False):
        super(TimeDistributed, self).__init__()
        self.module = module
        self._len_shape_without_batch = len_shape_without_batch
        self.batch_first = batch_first

    def forward(self, x):
        # x: [batch, time, *]
        assert len(x.shape) == self._len_shape_without_batch or self._len_shape_without_batch + 1, f"Input must have shape {self._len_shape_without_batch}D or {self._len_shape_without_batch + 1}D, received {len(x.shape)}D"

        if len(x.size()) == self._len_shape_without_batch:
            return self.module(x)

        batch_flatten_shapes = list(x.shape[1:])
        batch_flatten_shapes[0] = -1
        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().reshape(batch_flatten_shapes)  # (samples * timesteps, input_size)
        # print("TimeDistributed: x_reshape: ", x_reshape.shape)
        y = self.module(x_reshape)
        # print("TimeDistributed: y: ", y.shape)

        # We have to reshape Y
        
        if self.batch_first:
            final_shapes = [x.shape[0], -1] + list(y.shape[1:])
            y = y.contiguous().view(final_shapes)  # (samples, timesteps, output_size)
        else:
            final_shapes = [-1, x.shape[1]] + list(y.shape[1:])
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)
        # print("TimeDistributed: y return: ", y.shape)


        # print('TimeDistributed: y return: ', y.shape)    
        return y


In [32]:
conv_ode = ConvNodeWithBatch(device='cpu', size=28, latent_dim=25, in_channels=3,
    ode_hidden_dim=128, ode_out_dim=25, augment_dim=0, time_dependent=False,
    ode_non_linearity='relu', conv_activation=nn.ReLU(),latent_activation=None, stack_size=1)

--------------------------------------------------
Creating ConvAE...
Number of parameters in the encoder model: 137945
Number of parameters in the decoder model: 138083
--------------------------------------------------
Creating ANODENet...
Number of parameters in the model: 30765


In [33]:
input_test_2 = torch.randn(10, 2, 3, 28, 28)
times = torch.linspace(0, 1, 20)
print(times)
res1 = conv_ode(input_test_2, times, 1./20)
print(len(res1), res1[0].shape, res1[1].shape)


tensor([0.0000, 0.0526, 0.1053, 0.1579, 0.2105, 0.2632, 0.3158, 0.3684, 0.4211,
        0.4737, 0.5263, 0.5789, 0.6316, 0.6842, 0.7368, 0.7895, 0.8421, 0.8947,
        0.9474, 1.0000])
2 torch.Size([10, 20, 3, 28, 28]) torch.Size([10, 20, 25])


In [34]:
input_test = torch.randn(2, 3, 28, 28)
res2 = conv_ode(input_test, times, 1./17)
print(len(res2), res2[0].shape, res2[1].shape)

2 torch.Size([20, 3, 28, 28]) torch.Size([20, 25])


In [39]:
MARGIN_MIN = 5
MIN_INIT_VELOCITY = 200.
WIDTH, HEIGHT = 28, 28
RADIUS = 3

infos = {
    "MARGIN_MIN":MARGIN_MIN,
    "MIN_INIT_VELOCITY":MIN_INIT_VELOCITY,
    "WIDTH":WIDTH,
    "HEIGHT":HEIGHT,
    "RADIUS":RADIUS
}

x = WIDTH/4.
y = HEIGHT/4.
vx = 0.
vy = 0.

box = GravityHoleBall(x, y, vx, vy, (WIDTH, HEIGHT),RADIUS)


Num_pos_velocity = 1
N = 100
N_frames = 300 + Num_pos_velocity
dt = 1./N_frames

times = np.arange(0, N_frames*dt, dt)

# encoded_trajectory = generate_gravity_hole_ball_positions(box, N=N, N_frames=N_frames, dt=dt)[:,:,:]
# print(encoded_trajectory.shape)
print("-"*50)
print("Generating images...")
images = generate_gravity_hole_ball_images(box, N=N, N_frames=N_frames, dt=dt, infos=infos).reshape(-1, 1, HEIGHT, WIDTH)
print(images.shape)
# dataset = [(image, 0) for image in dataset]
# dataset = add_spatial_encoding(dataset)
# print(len(dataset), len(dataset[0]), dataset[0][0].shape)
images = torch.from_numpy(add_spatial_encoding(images)).float().reshape(N, -1, 3, HEIGHT, WIDTH)
print(images.shape)


--------------------------------------------------
Generating images...


100%|██████████| 100/100 [00:03<00:00, 32.05it/s]


(30100, 1, 28, 28)
torch.Size([100, 301, 3, 28, 28])


In [42]:

batch_size = 16
batch_time = 200
n_stack = 1
total_length = N_frames - Num_pos_velocity
# batch_size = 64

class BatchGetterMultiImages:
    def __init__(self, batch_time, batch_size, n_stack, total_length, dt, images, frac_train):
        # N: number of trajectories
        # M: number of time steps
        # D: dimension of the state space
        # positions: (N, T, D)
        self.times = torch.linspace(0., total_length*dt, total_length, dtype=torch.float64).float()
        if isinstance(images, torch.Tensor):
            self.true_images = images.float()

        elif isinstance(images, np.ndarray):
            self.true_images = torch.from_numpy(images).float()

        else:
            assert False, "positions must be either a torch.Tensor or a np.ndarray"

        self.N_train = int(images.shape[0]*frac_train)

        self.train_times = self.times #[:self.N_train]
        self.test_times = self.times #[self.N_train:]
        self.train_images = self.true_images[:self.N_train]
        self.test_images = self.true_images[self.N_train:]
        self.batch_size = batch_size
        self.n_stack = n_stack
        self.batch_time = batch_time
        self.dt = dt
        self.total_length = total_length

    def get_batch(self):
        index = np.random.randint(0, self.N_train, self.batch_size)
        s = torch.from_numpy(np.random.choice(np.arange(self.train_times.shape[0] - self.batch_time, dtype=np.int64), 1, replace=False))
        batch_y0 = self.train_images[index, s:s+self.n_stack+1].squeeze(0) # (M, D)
        batch_t = self.train_times[:self.batch_time]  # (T)
        batch_y = torch.stack([self.train_images[index, s + i] for i in range(self.batch_time)], dim=1).squeeze(1)  # (T, M, D)
        return batch_y0, batch_t, batch_y

getter = BatchGetterMultiImages(batch_time, batch_size, n_stack, total_length, dt, images, frac_train=1.)

optimizer = torch.optim.Adam(conv_ode.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.9)
loss_fn = nn.MSELoss()

In [43]:
sample = getter.get_batch()
print(sample[0].shape, sample[1].shape, sample[2].shape)
res1 = conv_ode(sample[0], sample[1], 1./total_length)
print(len(res1), res1[0].shape, res1[1].shape)

torch.Size([16, 2, 3, 28, 28]) torch.Size([200]) torch.Size([16, 200, 3, 28, 28])
2 torch.Size([16, 200, 3, 28, 28]) torch.Size([16, 200, 25])


In [44]:
epochs = 3000
root = None #"images/AE_ODE/Gravity/MultiTrajectories/Together/"
name = None # "conv_ode_1_ball_latent_{}_hidden_ode_{}_stack_{}_conv_activation_{}".format(latent_dim, ode_hidden_dim, stack_size, conv_activation)
display_fn = lambda i, model, out_display, getter, final_time, dt: display_convnode_trajectory(i, model, out_display, getter, final_time, dt, root=root, name=name)
# train_convnode(conv_ode, optimizer, scheduler, epochs, 1, getter, display=100, loss_fn=None, display_results_fn=display_fn)
train_convnode_with_batch(conv_ode, optimizer, scheduler, epochs, getter, display=100, loss_fn=None, display_results_fn=display_fn)

  0%|          | 10/3000 [00:07<38:06,  1.31it/s, Loss: 0.01467174]


KeyboardInterrupt: 