In [None]:
%load_ext autoreload
%autoreload 2


import seaborn as sns
sns.set_style('whitegrid')
import matplotlib.pyplot as plt
import random
import torch
from torch.nn.utils.rnn import pad_packed_sequence, PackedSequence, pack_sequence
from src.models.rae_words import RAEWords as RAE
from src.models.vrae_words import VRAEWords as VRAE
from src.models.iaf_words import IAFWords as IAF
from src.models.iaf_words import VariationalInference as VI_AIF

from src.models.vrae_chars import VariationalInference


from src.data.toy import ToyData, Continuous
from src.data.common import get_loader
from src.models.common import CriterionTrainer, OneHotPacked, VITrainer

from torch.optim import Adam
from torch.nn import CrossEntropyLoss, MSELoss

import numpy as np

# Seed every random number generator used in this notebook so that a
# Restart-and-Run-All reproduces the same data, initialisation and samples.
seed = 42

# `torch.set_deterministic` was deprecated in favour of
# `torch.use_deterministic_algorithms` and is absent from newer PyTorch
# releases; prefer the new name and fall back for old installations.
if hasattr(torch, "use_deterministic_algorithms"):
    torch.use_deterministic_algorithms(True)
else:
    torch.set_deterministic(True)

torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Defining data

In [None]:
# Generation settings shared by all three splits. `error_scale` is also used
# further down as the standard deviation of the Gaussian noise added to
# hand-built sequences.
data_parameters = dict(
    max_length=16,
    min_length=2,
    error_scale=0.05,
)

batch_size = 100
max_epochs = 40

# The splits are instantiated in train -> validation -> test order.
train_data, validation_data, test_data = (
    Continuous(num_observations=n_obs, **data_parameters)
    for n_obs in (10_000, 1000, 1000)
)

# Pull a single batch of `num_tests` test sequences for the qualitative
# evaluation cells below.
num_tests = 1000
test_loader = get_loader(test_data, batch_size=num_tests)
x_test = next(iter(test_loader))

Data consists of variable-length sequences. Each sequence repeats a single 1-dimensional value (chosen uniformly at random in (0, 1)) with some Gaussian noise added, and ends with -1.

In [None]:
# Overlay the first ten training sequences on one set of axes.
first_sequences = train_data.values[:10]
for sequence in first_sequences:
    plt.plot(sequence)

# Recurrent Autoencoder

In [None]:
# Recurrent autoencoder with a 2-dimensional latent code.
rae = RAE(
    input_dim=1,
    latent_features=2,
    encoder_hidden_size=32,
    decoder_hidden_size=32,
)

# The learning rate was previously 0.00, which turns every Adam step into a
# no-op so the model never leaves its random initialisation. Use a non-zero
# rate in the same range as the other models in this notebook.
optimizer_parameters = {
    "lr": 0.001,
}

# Summed (not averaged) squared error over all elements of a batch.
criterion = MSELoss(reduction="sum")
optimizer = Adam(rae.parameters(), **optimizer_parameters)

class CriterionTrainerNoCache(CriterionTrainer):
    """`CriterionTrainer` whose checkpoint hooks are overridden to do nothing."""

    def save_checkpoint(self):
        """Intentionally a no-op: nothing is written."""

    def restore_checkpoint(self):
        """Intentionally a no-op: nothing is loaded."""


In [None]:
# Monitor training on the dedicated validation split. The test split is kept
# for the qualitative evaluation cells (`x_test`), so it must not double as
# the validation set.
mt = CriterionTrainerNoCache(
    criterion=criterion,
    model=rae,
    optimizer=optimizer,
    batch_size=batch_size,
    max_epochs=max_epochs,
    training_data=train_data,
    validation_data=validation_data,
)

mt.train(progress_bar='epoch')

In [None]:
# Label both curves; without a legend they cannot be told apart.
plt.plot(mt.training_loss, label='training')
plt.plot(mt.validation_loss, label='validation')
plt.ylabel('loss')
plt.title('RAE loss')
plt.legend();

Some encoded/decoded examples:

In [None]:
output_rae = rae(x_test)

# The autoencoder output is deterministic (nothing is sampled here); it only
# needs unpacking into a padded tensor plus the per-sequence lengths.
sample_rae_padded, sequence_lengths = pad_packed_sequence(output_rae)

target_padded, _ = pad_packed_sequence(x_test)

# Three random sequences from the batch; dashed = reconstruction, solid = target.
idx = random.sample(range(len(sequence_lengths)), k=3)

for t, i in enumerate(idx):

    # Keep exactly the valid time steps of sequence `i`. The previous form
    # sliced to `length + 1` and then dropped the last element, which lost a
    # real time step for maximum-length sequences because the first slice is
    # clamped at the padded length.
    # NOTE(review): if the lengths include the trailing -1 terminator it is
    # plotted as well; slice to `n_steps - 1` to hide it.
    n_steps = int(sequence_lengths[i])
    decoded_rae = sample_rae_padded[:n_steps, i]
    target = target_padded[:n_steps, i]

    plt.plot(decoded_rae.detach().numpy(), color=f'C{t}', linestyle='dashed')
    plt.plot(target.detach().numpy(), color=f'C{t}')


In [None]:
# Show the model's repr as the cell output.
rae

In [None]:
# seaborn is already imported and styled in the first cell, so the duplicate
# import and `set_style` call are dropped here.

# Latent code of every test sequence, split into its two coordinates.
x, y = rae.encoder(x_test).detach().squeeze().numpy().T

lengths = sequence_lengths.squeeze().numpy()

# Mean of each target sequence over its first `length - 1` steps, i.e. leaving
# out the last valid step (the -1 end marker, per the data description).
target_averages = [
    target_padded[:, i][:length - 1].numpy().mean()
    for i, length in enumerate(lengths)
]

ax = sns.scatterplot(
    x=x,
    y=y,
    size=lengths,
    hue=target_averages,
)
ax.set(
    xlabel='latent dimension 1',
    ylabel='latent dimension 2',
    title='RAE latent space (colour: mean value, size: length)',
);

# Variational recurrent autoencoder

In [None]:
# Variational recurrent autoencoder with a 2-dimensional latent space.
vrae_config = dict(
    input_dim=1,
    latent_features=2,
    encoder_hidden_size=64,
    decoder_hidden_size=64,
)
vrae = VRAE(**vrae_config)

optimizer_parameters = dict(lr=0.002)

# NOTE(review): `VariationalInference` is imported from `vrae_chars` although
# the model is the word-level VRAE, and the meaning of the positional `1` is
# not visible here - confirm both against the module.
vi = VariationalInference(1)
optimizer = Adam(vrae.parameters(), **optimizer_parameters)

class VITrainerNoCache(VITrainer):
    """`VITrainer` whose checkpoint hooks are overridden to do nothing."""

    def save_checkpoint(self):
        """Intentionally a no-op: nothing is written."""

    def restore_checkpoint(self):
        """Intentionally a no-op: nothing is loaded."""

In [None]:
# Monitor training on the dedicated validation split. The test split is kept
# for the qualitative evaluation cells (`x_test`), so it must not double as
# the validation set.
mt = VITrainerNoCache(
    vi=vi,
    model=vrae,
    optimizer=optimizer,
    batch_size=batch_size,
    max_epochs=max_epochs,
    training_data=train_data,
    validation_data=validation_data,
)

mt.train(progress_bar='epoch')

In [None]:
# Label both curves; without a legend they cannot be told apart.
plt.plot(mt.training_loss, label='training')
plt.plot(mt.validation_loss, label='validation')
plt.ylabel('loss')
plt.title('VRAE loss')
plt.legend();

In [None]:
output_vrae = vrae(x_test)

# Sample from the observation model and re-wrap the flat sample as a packed
# sequence. The sort indices of `x_test` are carried over so that unpacking
# yields the same batch order as `target_padded`; when they are `None` this is
# identical to passing only `batch_sizes`.
sample_vrae_packed = PackedSequence(
    output_vrae['px'].sample(),
    x_test.batch_sizes,
    x_test.sorted_indices,
    x_test.unsorted_indices,
)
sample_vrae_padded, sequence_lengths = pad_packed_sequence(sample_vrae_packed)

target_padded, _ = pad_packed_sequence(x_test)

# Three random sequences from the batch; dashed = sample, solid = target.
idx = random.sample(range(len(sequence_lengths)), k=3)

for t, i in enumerate(idx):

    # Keep exactly the valid time steps of sequence `i`. The previous form
    # sliced to `length + 1` and then dropped the last element, which lost a
    # real time step for maximum-length sequences because the first slice is
    # clamped at the padded length.
    # NOTE(review): if the lengths include the trailing -1 terminator it is
    # plotted as well; slice to `n_steps - 1` to hide it.
    n_steps = int(sequence_lengths[i])
    decoded_vrae = sample_vrae_padded[:n_steps, i]
    target = target_padded[:n_steps, i]

    plt.plot(decoded_vrae.detach().numpy(), color=f'C{t}', linestyle='dashed')
    plt.plot(target.detach().numpy(), color=f'C{t}')


In [None]:
# The encoder returns a pair; only the first element is used below, converted
# to a NumPy array detached from the autograd graph.
# NOTE(review): presumably the pair is (mean, log std) of q(z|x) - confirm.
encoder_output = vrae.encoder(x_test)
mu_tensor, log_sigma = encoder_output
mu = mu_tensor.detach().numpy()

In [None]:
# seaborn is already imported in the first cell, so the duplicate import is
# dropped here.

# Split the encoder means into their two latent coordinates.
x, y = mu.T

lengths = sequence_lengths.squeeze().numpy()

# Mean of each target sequence over its first `length - 1` steps, i.e. leaving
# out the last valid step (the -1 end marker, per the data description).
target_averages = [
    target_padded[:, i][:length - 1].numpy().mean()
    for i, length in enumerate(lengths)
]

ax = sns.scatterplot(
    x=x,
    y=y,
    size=lengths,
    hue=target_averages,
)
ax.set(
    xlabel='latent dimension 1',
    ylabel='latent dimension 2',
    title='VRAE encoder means (colour: mean value, size: length)',
);

In [None]:
# Five target values, each turned into one noisy sequence of fixed length.
values = np.linspace(0, 1, 5)
length = 8

noise_scale = data_parameters['error_scale']

sequences = []
for value in values:
    # `length` noisy copies of `value`, followed by the -1 end marker.
    noisy_steps = [value + random.gauss(0, noise_scale) for _ in range(length)]
    sequences.append(torch.tensor(noisy_steps + [-1]).view(-1, 1).float())

packed_input = pack_sequence(sequences)
output = vrae(packed_input)

n_samples = 1_000

# Collect `n_samples` draws from output['qz'] for every input sequence;
# row j of each draw belongs to values[j].
samples = {value: [] for value in values}
for _ in range(n_samples):
    draw = output['qz'].sample()
    for j, value in enumerate(values):
        samples[value].append(draw[j, :].numpy())

# Flatten into long-format columns for plotting.
plotting_data = {'x': [], 'y': [], 'value': []}
for value, draws in samples.items():
    xs, ys = zip(*draws)
    plotting_data['x'].extend(xs)
    plotting_data['y'].extend(ys)
    plotting_data['value'].extend([value] * len(xs))
    

In [None]:
# Bivariate KDE of the collected draws, one colour per entry of the 'value' column.
sns.displot(plotting_data, x='x', y='y', hue='value', kind='kde')

In [None]:
# One fixed target value, rendered as sequences of decreasing length
# (longest first, as `pack_sequence` expects by default).
value = 0.2
lengths = np.arange(2, 10)[::-1]

noise_scale = data_parameters['error_scale']

sequences = []
for length in lengths:
    # `length` noisy copies of `value`, followed by the -1 end marker.
    noisy_steps = [value + random.gauss(0, noise_scale) for _ in range(length)]
    sequences.append(torch.tensor(noisy_steps + [-1]).view(-1, 1).float())

packed_input = pack_sequence(sequences)
output = vrae(packed_input)

n_samples = 1_000

# Collect `n_samples` draws from output['qz'] for every input sequence;
# row j of each draw belongs to lengths[j].
samples = {length: [] for length in lengths}
for _ in range(n_samples):
    draw = output['qz'].sample()
    for j, length in enumerate(lengths):
        samples[length].append(draw[j, :].numpy())

# Flatten into long-format columns for plotting.
plotting_data = {'x': [], 'y': [], 'length': []}
for length, draws in samples.items():
    xs, ys = zip(*draws)
    plotting_data['x'].extend(xs)
    plotting_data['y'].extend(ys)
    plotting_data['length'].extend([length] * len(xs))
    

In [None]:
# Bivariate KDE of the collected draws, one colour per entry of the 'length' column.
sns.displot(plotting_data, x='x', y='y', hue='length', kind='kde')

# IAF Variational Recurrent Autoencoder

In [None]:
# Recurrent autoencoder whose latent posterior is transformed by a 4-step flow.
iaf = IAF(
    input_dim=1,
    latent_features=2,
    encoder_hidden_size=32,
    decoder_hidden_size=32,
    flow_depth=4,
    flow_hidden_features=24,
    flow_context_features=2,
)

optimizer_parameters = {
    "lr": 0.001,
}

# `VITrainerNoCache` is already defined in the VRAE section above and is
# reused here; the identical second definition that used to live in this cell
# only shadowed it.

# `VI_AIF` is the import alias of the IAF module's `VariationalInference`.
vi = VI_AIF()
optimizer = Adam(iaf.parameters(), **optimizer_parameters)

In [None]:
# Monitor training on the dedicated validation split. The test split is kept
# for the qualitative evaluation cells (`x_test`), so it must not double as
# the validation set.
mt = VITrainerNoCache(
    vi=vi,
    model=iaf,
    optimizer=optimizer,
    batch_size=batch_size,
    max_epochs=max_epochs,
    training_data=train_data,
    validation_data=validation_data,
)

mt.train(progress_bar='epoch')

In [None]:
# Label both curves; without a legend they cannot be told apart.
plt.plot(mt.training_loss, label='training')
plt.plot(mt.validation_loss, label='validation')
plt.ylabel('loss')
plt.title('IAF loss')
plt.legend();

In [None]:
output_iaf = iaf(x_test)

# Sample from the observation model and re-wrap the flat sample as a packed
# sequence. The sort indices of `x_test` are carried over so that unpacking
# yields the same batch order as `target_padded`; when they are `None` this is
# identical to passing only `batch_sizes`.
sample_iaf_packed = PackedSequence(
    output_iaf['px'].sample(),
    x_test.batch_sizes,
    x_test.sorted_indices,
    x_test.unsorted_indices,
)
# Use IAF-specific names so this cell no longer overwrites the VRAE results.
sample_iaf_padded, sequence_lengths = pad_packed_sequence(sample_iaf_packed)

target_padded, _ = pad_packed_sequence(x_test)

# Three random sequences from the batch; dashed = sample, solid = target.
idx = random.sample(range(len(sequence_lengths)), k=3)

for t, i in enumerate(idx):

    # Keep exactly the valid time steps of sequence `i`. The previous form
    # sliced to `length + 1` and then dropped the last element, which lost a
    # real time step for maximum-length sequences because the first slice is
    # clamped at the padded length.
    # NOTE(review): if the lengths include the trailing -1 terminator it is
    # plotted as well; slice to `n_steps - 1` to hide it.
    n_steps = int(sequence_lengths[i])
    decoded_iaf = sample_iaf_padded[:n_steps, i]
    target = target_padded[:n_steps, i]

    plt.plot(decoded_iaf.detach().numpy(), color=f'C{t}', linestyle='dashed')
    plt.plot(target.detach().numpy(), color=f'C{t}')

In [None]:
# seaborn is already imported in the first cell, so the duplicate import is
# dropped here.

# Latent values returned by the model's forward pass on the test batch.
z = output_iaf['z'].detach().numpy()

x, y = z.T

lengths = sequence_lengths.squeeze().numpy()

# Mean of each target sequence over its first `length - 1` steps, i.e. leaving
# out the last valid step (the -1 end marker, per the data description).
target_averages = [
    target_padded[:, i][:length - 1].numpy().mean()
    for i, length in enumerate(lengths)
]

ax = sns.scatterplot(
    x=x,
    y=y,
    size=lengths,
    hue=target_averages,
)
ax.set(
    xlabel='latent dimension 1',
    ylabel='latent dimension 2',
    title='IAF latent values (colour: mean value, size: length)',
)
# Equal scaling on both axes, applied to the axes seaborn drew on.
ax.set_aspect('equal', 'box');