In [1]:
import numpy as np
from sklearn.linear_model import lars_path
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import trange

from src.models.ae import ConvAE, Decoder, Encoder
from src.models.anode import ANODENet

In [76]:

class ConvNodeWithBatch(nn.Module):
    """Frame encoder -> ANODENet on the latent state -> frame decoder.

    A stack of frames is encoded frame by frame. The latent code of the leading
    frame(s) is concatenated with a finite-difference estimate of its time
    derivative, ``(z[1:] - z[:-1]) / dt``, and the result is handed to an
    ``ANODENet`` together with the evaluation ``times``. The leading
    ``latent_dim`` components of what the network returns are then decoded back
    to images. Inputs may or may not carry a leading batch axis.

    Args:
        device: device the sub-modules are moved to.
        size: stored on the instance; not otherwise used by this class.
        latent_dim: size of the latent code produced by the encoder.
        in_channels: number of image channels, forwarded to Encoder/Decoder.
        ode_hidden_dim: hidden width forwarded to ``ANODENet``.
        ode_out_dim: output width forwarded to ``ANODENet``.
        augment_dim: augmentation width forwarded to ``ANODENet``.
        time_dependent: forwarded to ``ANODENet``.
        ode_non_linearity: forwarded to ``ANODENet`` as ``non_linearity``.
        conv_activation: activation forwarded to Encoder and Decoder.
        latent_activation: forwarded to Encoder as ``relu``.
        stack_size: sizes the ANODENet input as ``latent_dim * (stack_size + 1)``.
            NOTE(review): ``forward`` squeezes the stack axis in the batched
            case, which only removes it for 2-frame stacks (``stack_size == 1``).
    """

    def __init__(self, device, size, latent_dim, in_channels,
    ode_hidden_dim, ode_out_dim, augment_dim=0, time_dependent=False, 
    ode_non_linearity='relu', conv_activation=nn.ReLU(),latent_activation=None, stack_size=1):
        super(ConvNodeWithBatch, self).__init__()
        self.device = device
        self.size = size
        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.conv_activation = conv_activation
        self.latent_activation = latent_activation
        self.ode_hidden_dim = ode_hidden_dim
        self.out_dim = ode_out_dim
        self.augment_dim = augment_dim
        self.time_dependent = time_dependent
        self.ode_non_linearity = ode_non_linearity
        # Kept alongside the other hyper-parameters so the instance fully
        # describes how it was built.
        self.stack_size = stack_size

        print("-"*50)
        print("Creating ConvAE...")
        self.encoder = TimeDistributed(
            Encoder(device=device, latent_dim=latent_dim, in_channels=in_channels,
            activation=conv_activation, relu=latent_activation).to(device), 
            len_shape_without_batch=4, # input without batch are (times, channels, height, width)
            batch_first=True
        )
        self.decoder = TimeDistributed(
            Decoder(device=device, latent_dim=latent_dim, in_channels=in_channels,
            activation=conv_activation).to(device),
            len_shape_without_batch=2, # input without batch are (times, latent_dim)
            batch_first=True
        )

        print("-"*50)
        print("Creating ANODENet...")
        # Forward the caller's `time_dependent` flag instead of a hard-coded
        # False, so the stored attribute and the actual network agree.
        self.node = ANODENet(device, latent_dim*(stack_size + 1), ode_hidden_dim, ode_out_dim, augment_dim,
            time_dependent=time_dependent, non_linearity=ode_non_linearity).to(device)

    def forward(self, images, times, dt):
        """Encode a frame stack, run the latent ANODENet, decode the trajectory.

        Args:
            images: ``[(batch), n_stack, in_channels, height, width]``.
            times: time points passed to the ANODENet.
            dt: time between two consecutive stacked frames, used for the
                finite-difference latent velocity.

        Returns:
            ``(reconstructed_images, sim)`` where ``sim`` is the latent
            trajectory, ``[(batch), times, latent_dim]``, and
            ``reconstructed_images`` is the decoder applied to it.

        Raises:
            ValueError: if the encoder output is neither 2D nor 3D.
        """
        # latent_z: [(batch), n_stack, latent_dim]
        latent_z = self.encoder(images)

        # latent_z_stack: state concatenated with its finite-difference velocity.
        if latent_z.dim() == 3:
            current, following = latent_z[:, :-1], latent_z[:, 1:]
            # squeeze(1) drops the stack axis, which has length 1 for a 2-frame stack.
            latent_z_stack = torch.cat([current, (following - current)/dt], dim=-1).squeeze(1)
        elif latent_z.dim() == 2:
            current, following = latent_z[:-1], latent_z[1:]
            latent_z_stack = torch.cat([current, (following - current)/dt], dim=-1)
        else:
            # Previously this fell through and crashed later with an
            # UnboundLocalError on `latent_z_stack`.
            raise ValueError(
                "Encoder output must be 2D (n_stack, latent_dim) or 3D "
                f"(batch, n_stack, latent_dim), received {latent_z.dim()}D"
            )

        # sim: [times, (batch), ode_out_dim]; keep only the latent-state part.
        sim = self.node(latent_z_stack, times)[..., :latent_z.shape[-1]]

        # Reorder to [(batch), times, latent_dim] for the batch-first decoder.
        if len(images.shape) == 5:
            sim = sim.swapdims(0,1)
        else:
            sim = sim.squeeze(1)

        reconstructed_images = self.decoder(sim)

        return reconstructed_images, sim



class TimeDistributed(nn.Module):
    """Apply ``module`` independently to every step of a (batched) sequence.

    If the input has exactly ``len_shape_without_batch`` dimensions it is passed
    to ``module`` unchanged. If it has one extra leading dimension, the two
    leading axes are merged into one, ``module`` is applied, and the two leading
    axes are restored on the output.

    Args:
        module: the wrapped ``nn.Module``.
        len_shape_without_batch: number of dimensions of an input that has no
            extra leading axis.
        batch_first: selects how the merged leading axis is split back. If True
            the output is ``(x.shape[0], -1, *feature_dims)``, otherwise
            ``(-1, x.shape[1], *feature_dims)``.
    """

    def __init__(self, module, len_shape_without_batch, batch_first=False):
        super(TimeDistributed, self).__init__()
        self.module = module
        self._len_shape_without_batch = len_shape_without_batch
        self.batch_first = batch_first

    def forward(self, x):
        """Run the wrapped module on ``x``, folding and unfolding the leading axes.

        Raises:
            AssertionError: if ``x`` does not have ``len_shape_without_batch``
                or ``len_shape_without_batch + 1`` dimensions.
        """
        n_dims = x.dim()
        # The original `a == b or c` was always truthy, so the check never fired.
        assert n_dims in (self._len_shape_without_batch, self._len_shape_without_batch + 1), f"Input must have shape {self._len_shape_without_batch}D or {self._len_shape_without_batch + 1}D, received {n_dims}D"

        if n_dims == self._len_shape_without_batch:
            return self.module(x)

        # Squash the two leading axes into a single one: (d0 * d1, *rest).
        batch_flatten_shapes = list(x.shape[1:])
        batch_flatten_shapes[0] = -1
        x_reshape = x.reshape(batch_flatten_shapes)
        y = self.module(x_reshape)

        # Restore the two leading axes, keeping every feature axis of the output.
        if self.batch_first:
            final_shapes = [x.shape[0], -1] + list(y.shape[1:])
        else:
            # The original used view(-1, x.size(1), y.size(-1)), which is only
            # correct when the module output has a single feature axis.
            final_shapes = [-1, x.shape[1]] + list(y.shape[1:])
        return y.reshape(final_shapes)


In [77]:
# Build a small CPU model (28x28 RGB frames, 16-dim latent) for the shape checks below.
model = ConvNodeWithBatch(
    device='cpu',
    size=28,
    latent_dim=16,
    in_channels=3,
    ode_hidden_dim=128,
    ode_out_dim=16,
    augment_dim=0,
    time_dependent=False,
    ode_non_linearity='relu',
    conv_activation=nn.ReLU(),
    latent_activation=None,
    stack_size=1,
)

--------------------------------------------------
Creating ConvAE...
Number of parameters in the encoder model: 127568
Number of parameters in the decoder model: 127715
--------------------------------------------------
Creating ANODENet...
Number of parameters in the model: 25392


In [78]:
# Batched shape check: 10 samples, each a stack of 2 frames of shape (3, 28, 28).
input_test_2 = torch.randn(10, 2, 3, 28, 28)
times = torch.linspace(0, 1, steps=20)
print(times)
res1 = model(input_test_2, times, dt=1./20)
# Prints the number of outputs followed by the shape of each one.
print(len(res1), *(out.shape for out in res1))


tensor([0.0000, 0.0526, 0.1053, 0.1579, 0.2105, 0.2632, 0.3158, 0.3684, 0.4211,
        0.4737, 0.5263, 0.5789, 0.6316, 0.6842, 0.7368, 0.7895, 0.8421, 0.8947,
        0.9474, 1.0000])
2 torch.Size([10, 20, 3, 28, 28]) torch.Size([10, 20, 16])


In [79]:
# Unbatched shape check: a single stack of 2 frames, reusing `times` from the previous cell.
input_test = torch.randn(2, 3, 28, 28)
res2 = model(input_test, times, dt=1./17)
# Prints the number of outputs followed by the shape of each one.
print(len(res2), *(out.shape for out in res2))

2 torch.Size([20, 3, 28, 28]) torch.Size([20, 16])
