In [1]:
import math
import torch
import torch.nn as nn
from torch.distributions import Normal, Categorical
import argparse
import os
from tqdm.notebook import tqdm
import torch.nn.functional as F
import numpy as np
from torch.utils import data
import pandas as pd
from scipy.stats import sem

### Hyperparameters

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="CDDP", choices=["CDDP", "VCL-BSSM"])
parser.add_argument("--task", default="sinus",
                    choices=["sinus", "lv", "lorenz"])
parser.add_argument("--data_dir", default="./data/")
parser.add_argument("--base", default="runs")
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=9)
parser.add_argument("--start_replication", type=int, default=1)
parser.add_argument("--max_replication", type=int, default=5)
args = parser.parse_args("")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

task = args.task
if task == "sinus":
    args.epochs = 300
    args.latent_size = 6
    args.eval_period = 60
    args.lrate = 5e-3
    args.context_size = 5
    args.hidden_size = 40
    args.n_samples = 10
    args.dt = 0.1
    args.memory_size = 20

elif task == "lv":
    args.epochs = 750
    args.latent_size = 6
    args.eval_period = 150
    args.lrate = 1e-3
    args.context_size = 8
    args.hidden_size = 40
    args.n_samples = 10
    args.dt = 0.4
    args.memory_size = 10

elif task == "lorenz":
    args.epochs = 500
    args.latent_size = 12
    args.eval_period = 100
    args.lrate = 5e-4
    args.context_size = 16
    args.hidden_size = 90
    args.n_samples = 10
    args.dt = 0.01
    args.memory_size = 15


num_tasks_dict = {
    "sinus": 5,
    "lv": 4,
    "lorenz": 4
}

sequences = {
    'sinus': {
        0: [4, 2, 3, 1, 0], 1: [0, 2, 3, 4, 1], 2: [2, 3, 0, 4, 1], 3: [0, 2, 1, 4, 3], 
        4: [3, 1, 2, 4, 0], 5: [3, 0, 2, 1, 4], 6: [0, 3, 2, 4, 1], 7: [2, 4, 3, 0, 1], 
        8: [1, 0, 3, 4, 2], 9: [4, 0, 2, 3, 1], 10: [1, 0, 4, 2, 3], 11: [3, 0, 2, 4, 1], 
        12: [1, 3, 2, 0, 4], 13: [2, 4, 3, 1, 0], 14: [1, 4, 3, 0, 2], 15: [2, 4, 1, 3, 0], 
        16: [0, 4, 2, 3, 1], 17: [1, 2, 4, 0, 3], 18: [1, 3, 2, 4, 0], 19:  [0, 3, 1, 2, 4]
        },
    'lv': {
        0: [0, 1, 2, 3], 1: [0, 1, 3, 2], 2: [0, 2, 1, 3], 3: [0, 2, 3, 1], 
        4: [0, 3, 1, 2], 5: [0, 3, 2, 1], 6: [1, 0, 2, 3], 7: [1, 0, 3, 2], 
        8: [1, 2, 0, 3], 9: [1, 2, 3, 0], 10: [1, 3, 0, 2], 11: [1, 3, 2, 0], 
        12: [2, 0, 1, 3], 13: [2, 0, 3, 1], 14: [2, 1, 0, 3], 15: [2, 1, 3, 0], 
        16: [2, 3, 0, 1], 17: [2, 3, 1, 0], 18: [3, 0, 1, 2], 19: [3, 0, 2, 1]},
    'lorenz': {
        0: [0, 1, 2, 3], 1: [0, 1, 3, 2], 2: [0, 2, 1, 3], 3: [0, 2, 3, 1], 
        4: [0, 3, 1, 2], 5: [0, 3, 2, 1], 6: [1, 0, 2, 3], 7: [1, 0, 3, 2], 
        8: [1, 2, 0, 3], 9: [1, 2, 3, 0]},
    }

os.makedirs(f"{args.base}", exist_ok=True)
os.makedirs(f"{args.base}/baselines", exist_ok=True)
os.makedirs(f"{args.base}/baselines/{args.task}", exist_ok=True)
os.makedirs(f"{args.base}/baselines/{args.task}/{args.model}", exist_ok=True)

## Data

In [3]:
class MyBatch:
    def __init__(self, x, y=None, mode=None, t=None):
        self.x = x  # high-dim data
        self.y = y  # tasks
        self.t = t  # times
        self.mode = mode  # modes
        self.N = x.shape[0]

class MyDataset:
    def __init__(self, xtr, ytr, modetr, ttr, xval=None, yval=None, xtest=None, ytest=None, modetest=None, ttest=None):
        self.train = MyBatch(xtr, ytr, modetr, ttr)
        if xtest is not None:
            self.test = MyBatch(xtest, ytest, modetest, ttest)

class Dataset(data.Dataset):
    def __init__(self, Xtr, y=None, mode=None, t=None):
        self.Xtr = Xtr 
        self.y = y
        self.mode = mode
        self.t = t

    def __len__(self):
        return len(self.Xtr)
    def __getitem__(self, idx):
        return self.Xtr[idx], self.y[idx], self.mode[idx], self.t[idx]

def load_dataset(task_name, data_dir, dt=0.001):
    if task_name == "sinus":
        name = "sinus waves"
    elif task_name == "lv":
        name = "Lotka Volterra"
    elif task_name == "lorenz":
        name = "Lorenz attractor"
    else:
        raise NotImplementedError(f"The {task} is not available")
        
    Xtr = np.load(os.path.join(data_dir, name, "training.npy"))
    Ytr = np.load(os.path.join(data_dir, name, "training_tasks.npy"))
    Modetr = np.load(os.path.join(data_dir, name, "training_modes.npy"))
    Ttr = dt * np.arange(0, Xtr.shape[1], dtype=np.float32)
    Ttr = np.tile(Ttr, [Xtr.shape[0], 1])

    Xtest = np.load(os.path.join(data_dir, name, "test.npy"))
    Ytest = np.load(os.path.join(data_dir, name, "test_tasks.npy"))
    Modetest = np.load(os.path.join(data_dir, name, "test_modes.npy"))
    Ttest = dt * np.arange(0, Xtest.shape[1], dtype=np.float32)
    Ttest = np.tile(Ttest, [Xtest.shape[0], 1])

    dataset = MyDataset(Xtr,Ytr,Modetr, Ttr,xtest=Xtest,ytest=Ytest,modetest=Modetest, ttest=Ttest)
    return dataset

def load_data(task, data_dir, dt=0.1):
    if task in ["sinus", "lorenz", "lv"]:
        dataset = load_dataset(task, data_dir, dt)
    else:
        raise NotImplementedError(f"The {task} is not available")

    [N, T, D] = dataset.train.x.shape
    return dataset, N, T, D


class Generator(object):
    def __init__(self, dataset, tasks, batch_size=8):
        self.tasks = tasks
        ids = dataset.y == tasks if isinstance(tasks, int) else np.array(list(map(lambda x: x in tasks, dataset.y[0])))
        self.x = dataset.x[ids[0]] if isinstance(tasks, int) else dataset.x[ids]
        self.y = dataset.y[ids] if isinstance(tasks, int) else dataset.y[:,ids][0]
        self.modes = dataset.mode[ids] if isinstance(tasks, int) else dataset.mode[:,ids][0]
        self.ts = dataset.t[ids[0]] if isinstance(tasks, int) else dataset.t[ids]
        self.batch_size = batch_size

    def get(self):
        ds = Dataset(self.x, self.y, self.modes, self.ts)
        params = {'batch_size': self.batch_size, 'shuffle': True}

        return data.DataLoader(ds, **params)



def get_generators(task, dataset, data_dir, dt, batch_size, train_task, test_tasks):
    if task in ["sinus", "lorenz", "lv"]:
        gen_train = Generator(dataset=dataset.train, batch_size=batch_size, tasks=train_task).get()
        gen_test = Generator(dataset=dataset.test, batch_size=batch_size, tasks=test_tasks).get()
    
    else:
        raise NotImplementedError(f"The {task} is not available")
    
    gen_dict = {
        "train": gen_train,
        "test": gen_test
    }

    return gen_dict

## Architectures

In [4]:
class Encoder(nn.Module):
    def __init__(self, in_size, out_size, context_size, task, hidden_size=30, device='cpu'):
        super().__init__()
        self.out_size = out_size
        self.hidden_size = hidden_size
        self.in_size = in_size
        self.context_size = context_size
        self.task = task
        
        self.activation = nn.Tanh

        if self.task in ["lorenz"]:
            self.encoder = nn.Sequential(
                nn.Linear(self.context_size * self.in_size, self.hidden_size),
                self.activation(),
                nn.LayerNorm(self.hidden_size),
                nn.Linear(self.hidden_size, self.hidden_size),
                self.activation(),
                nn.LayerNorm(self.hidden_size),
                nn.Linear(self.hidden_size, self.out_size * 2)
            )
        elif self.task in ["sinus", "lv"]:
            self.encoder = nn.Sequential(
                nn.Linear(context_size * in_size, out_size * 2)
            )
        else:
            raise NotImplementedError(f"{self.__name__} is not implemented for {task}!")
        
        self.to(device)

    def forward(self, x):
        out = self.encoder(x.flatten(1)) # embedding
        return out[:, :self.out_size], out[:, self.out_size:].clamp(-8, 8).exp()

class Decoder(nn.Module):
    def __init__(self, in_size, out_size, task, hidden_size=30, device='cpu'):
        super().__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.task = task

        activation = nn.Tanh
        
        if self.task in ["lorenz"]:
            self.decoder = nn.Sequential(
                nn.Linear(in_size, hidden_size),
                activation(),
                nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, hidden_size),
                activation(),
                nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, out_size)
            )
        elif self.task in ["sinus", "lv"]:
            self.decoder = nn.Sequential(
                nn.Linear(in_size, out_size)
            )
        else:
            raise NotImplementedError(f"{self.__name__} is not implemented for {task}!")

        self.to(device)

    def forward(self, x):
        return self.decoder(x)                         

class VBLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(VBLinear, self).__init__()
        self.n_in = in_features
        self.n_out = out_features

        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.mu_w = nn.Parameter(torch.Tensor(out_features, in_features))
        self.prior_mu_w = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=False)
        self.prior_mu_w.data.zero_()

        self.logsig2_w = nn.Parameter(torch.Tensor(out_features, in_features))
        self.prior_logsig2_w = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=False)
        self.prior_logsig2_w.data.zero_()
        
        
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.mu_w.size(1))
        self.mu_w.data.normal_(0, stdv)
        self.logsig2_w.data.zero_().normal_(-9, 0.001)  # var init via Louizos
        self.bias.data.zero_()

    def KL(self):
        return torch.distributions.kl.kl_divergence(Normal(self.mu_w, self.logsig2_w.clamp(-8, 8).exp().sqrt()), Normal(self.prior_mu_w, self.prior_logsig2_w.clamp(-8, 8).exp().sqrt()))
        

    def forward(self, input):
        # Sampling free forward pass only if MAP prediction and no training rounds
        if not self.training:
            return torch.nn.functional.linear(input, self.mu_w, self.bias)
        else:
            mu_out = torch.nn.functional.linear(input, self.mu_w, self.bias)
            logsig2_w = self.logsig2_w.clamp(-8, 8)
            s2_w = logsig2_w.exp()
            var_out = torch.nn.functional.linear(input.pow(2), s2_w) + 1e-8
            return mu_out + var_out.sqrt() * torch.randn_like(mu_out)

    def __repr__(self):
        return self.__class__.__name__  + " (" + str(self.n_in) + " -> " + str(self.n_out)  + ")"

class F_theta(nn.Module):
    def __init__(self, in_features, out_features, hidden_size, task, is_deterministic, reset):
        super().__init__()
        self.reset = reset
        self.is_deterministic = is_deterministic

        linear = nn.Linear if is_deterministic else VBLinear
        activation = nn.Tanh
        
        if task in ["sinus", "lorenz", "lv"]:
            self.F_theta = nn.Sequential(
                linear(in_features=in_features, out_features=hidden_size),
                activation(),
                nn.LayerNorm(hidden_size),
                linear(in_features=hidden_size, out_features=out_features)
            )
        else:
            raise NotImplementedError(f"{self.__name__} is not implemented for {task}!")

        if reset:
            self.reset_parameters()

    def forward(self, x):
        return self.F_theta(x)

    def KL(self):
        kl = 0
        for layer in self.F_theta.children():
            if isinstance(layer, VBLinear):
                kl += layer.KL().sum()

        return kl

    def reset_parameters(self):
        for layer in self.F_theta.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
            elif isinstance(layer, VBLinear):
                layer.reset_parameters()

    def update_priors(self):
        for layer in self.F_theta.children():
            if isinstance(layer, VBLinear):
                layer.prior_mu_w.data = layer.mu_w.data.clone()
                layer.prior_logsig2_w.data = layer.logsig2_w.data.clone()
                if self.reset:
                    layer.reset_parameters()


# VCL-BSSM

In [5]:
class VCL_BSSM(nn.Module):

    def __init__(self, in_size, out_size, latent_size, task, context_size=1, hidden_size=30,
                 device='cpu', n_samples=3, dt=0.1):
        super(VCL_BSSM, self).__init__()

        self.in_size = in_size
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.out_size = out_size
        self.device = device
        self.dt = dt
        self.n_samples = n_samples
        self.task = task

        
        self.mu0 = torch.nn.Parameter(torch.zeros(latent_size), requires_grad=True)
        self.logvar0 = torch.nn.Parameter(torch.zeros(latent_size), requires_grad=True)

        self.mu0_prior = torch.nn.Parameter(torch.zeros(latent_size), requires_grad=False)
        self.logvar0_prior = torch.nn.Parameter(torch.zeros(latent_size), requires_grad=False)
        
        self.Sigma = torch.nn.Parameter(torch.ones(latent_size) * 0.1, requires_grad=False)

        self.C = context_size
        self.encoder = Encoder(in_size=in_size, context_size=context_size, out_size=latent_size, hidden_size=hidden_size, device=device, task=task)
        self.decoder = Decoder(in_size=latent_size, out_size=out_size, task=task, hidden_size=hidden_size, device=device)

        self.F_theta = F_theta(in_features=latent_size, hidden_size=hidden_size, out_features=latent_size, task=task, is_deterministic=False, reset=True)
        
        self.to(device)


    def prior_dynamics(self, x_prev):
        mu = x_prev + self.F_theta(x_prev) * self.dt 
        var = (self.Sigma * self.dt).repeat(x_prev.shape[0], 1)
        return mu, var

    def forward(self, y):
        batch_size, T, D = y.shape
        lik = 0
        x0_mu, x0_logvar = self.encoder(y[:, :self.context_size])

        kl = torch.distributions.kl.kl_divergence(Normal(x0_mu, x0_logvar.sqrt()), Normal(self.mu0_prior, self.logvar0_prior.exp().sqrt()))

        x_prev = x0_mu + torch.sqrt(x0_logvar)*Normal(0,1).sample()
        
        y_prob = torch.zeros_like(y, device=self.device)
        for t in range(T):
            mu, logvar = self.prior_dynamics(x_prev)
            x_t = mu + torch.sqrt(logvar)*Normal(0,1).sample()
            y_rcnst = self.decoder(x_t)
            lik += - nn.MSELoss(reduction='sum')(y_rcnst, y[:, t,:]) 

            x_prev = x_t
            y_prob[:, t] = y_rcnst

        return lik.sum() - kl.sum() - self.F_theta.KL(), y_prob.cpu()
        
    def predict(self, y):
        batch_size, T, D = y.shape
        C = self.C
        x0_mu, x0_logvar = self.encoder(y[:, :self.context_size])

        x_prev = x0_mu + torch.sqrt(x0_logvar)*Normal(0,1).sample()
        
        y_prob = torch.zeros((batch_size, T - C, D), device=self.device)
        for t in range(T):
            mu, logvar = self.prior_dynamics(x_prev)
            x_t = mu + torch.sqrt(logvar)*Normal(0,1).sample()
            y_rcnst = self.decoder(x_t)
            x_prev = x_t
            if t - C >= 0:
                y_prob[:, t-C] = y_rcnst

        return y_prob.cpu()

    def update(self):
        self.F_theta.update_priors()


# CDDP

In [6]:
class CDDP(nn.Module):
    def __init__(self, in_size, out_size, latent_size, task, context_size=1, hidden_size=30,
                 device='cpu', n_samples=3, dt=0.1, memory_size=20):
        super(CDDP, self).__init__()

        self.in_size = in_size
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.out_size = out_size
        self.device = device
        self.dt = dt
        self.n_samples = n_samples
        self.task = task

        
        self.mu0 = torch.nn.Parameter(torch.zeros(latent_size), requires_grad=True)
        self.logvar0 = torch.nn.Parameter(torch.zeros(latent_size), requires_grad=True)

        self.mu0_prior = torch.nn.Parameter(torch.zeros(latent_size), requires_grad=False)
        self.logvar0_prior = torch.nn.Parameter(torch.zeros(latent_size), requires_grad=False)
        
        self.Sigma = torch.nn.Parameter(torch.ones(latent_size) * 0.1, requires_grad=False)

        self.C = context_size
        self.encoder = Encoder(in_size=in_size, context_size=context_size, out_size=latent_size, hidden_size=hidden_size, device=device, task=task)
        self.decoder = Decoder(in_size=latent_size, out_size=out_size, task=task, hidden_size=hidden_size, device=device)

        self.task = task
        self.memory_size = memory_size
        # memory
        self.memory = torch.nn.Parameter(torch.Tensor(self.memory_size, self.latent_size), requires_grad=False)
        self.memory.data.normal_(0, 0.01)
        self.memory.data.pow_(2)
        self.similarity = self.similarity_function
        
        self.alpha0 = torch.nn.Parameter(torch.ones(1, device=self.device), requires_grad=False)

        self.v1 = nn.Linear(self.latent_size * 2, self.latent_size)
        self.v2 = nn.Linear(self.latent_size * 2, self.latent_size)
        
        # transition model
        self.F_theta = F_theta(in_features=latent_size*2, hidden_size=hidden_size, out_features=latent_size, task=task, is_deterministic=False, reset=False)
        
        self.to(device)
        
    
    def update(self):
        return 

    def similarity_function(self, memory_sample, b):
        # cosine sim
        eps = 1e-8
        a_n, b_n = memory_sample.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
        a_norm = memory_sample / torch.max(a_n, eps * torch.ones_like(a_n))
        b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
        sim_mt = torch.mm(b_norm, a_norm.transpose(0, 1))
        return sim_mt

    def sample_from_memory(self):
        return self.memory
    
    def update_memory(self, encoded_context):
        # equation 12
        memory_sample = self.sample_from_memory()
        weights = self.get_memory_weights(encoded_context, memory_sample)
        
        mem_offset = torch.tanh((((1-weights).unsqueeze(2) * memory_sample) + (weights.unsqueeze(2) * encoded_context.unsqueeze(1))).mean(0))
        self.memory.data.zero_().add_(mem_offset)

        return weights

    def get_memory_weights(self, encoded_context, memory_sample):
        # equation 13
        similarity_measure = self.similarity(memory_sample, encoded_context.squeeze())
        weights = F.softmax(similarity_measure, dim=1)
        return weights    

    def prior_dynamics(self, x_prev, pi):
        mu = x_prev + self.F_theta(torch.cat((self.memory[pi], x_prev), 1)) * self.dt 
        var = (self.Sigma * self.dt).repeat(x_prev.shape[0], 1)
        return mu, var
        
    def get_memory_priors(self, batch_size):
        K = self.memory_size
        pi_primes = torch.distributions.Beta(torch.ones((K, batch_size), device=self.device), self.alpha0).sample()
        prior_pis = torch.zeros((K, batch_size), device=self.device)
        accumulator = torch.ones(batch_size, device=self.device)
        for k in range(K):
            prior_pi_k = pi_primes[k]
            j = k -1
            if j >= 0:
                accumulator = accumulator * (1 - pi_primes[j])
            prior_pi_k = prior_pi_k * accumulator
            prior_pis[k] = prior_pi_k
        
        return Categorical(prior_pis.transpose(1,0))

    def forward(self, y):
        batch_size, T, D = y.shape
        lik = 0
        encoder_mu, encoder_logvar = self.encoder(y[:, :self.context_size])
        weights = self.update_memory(encoded_context=encoder_mu)
        
        # equation 29 
        posterior_pi = Categorical(weights)
        prior_pi = self.get_memory_priors(batch_size=y.shape[0])
        kl_pi = torch.distributions.kl.kl_divergence(posterior_pi, prior_pi)
        pi = posterior_pi.sample()

        # equation 30
        x0_mu = self.v1(torch.cat((self.memory[pi], encoder_mu.squeeze()), 1))
        x0_logvar = self.v2(torch.cat((self.memory[pi], encoder_mu.squeeze()), 1)).clamp(-8, 8).exp()
        

        kl = torch.distributions.kl.kl_divergence(Normal(x0_mu, x0_logvar.sqrt()), Normal(self.mu0_prior, self.logvar0_prior.exp().sqrt()))

        x_prev = x0_mu + torch.sqrt(x0_logvar)*Normal(0,1).sample()
        
        y_prob = torch.zeros_like(y, device=self.device)
        for t in range(T):
            mu, logvar = self.prior_dynamics(x_prev, pi)
            x_t = mu + torch.sqrt(logvar)*Normal(0,1).sample()
            y_rcnst = self.decoder(x_t)
            lik += - nn.MSELoss(reduction='sum')(y_rcnst, y[:, t,:]) 
            
            x_prev = x_t
            y_prob[:, t] = y_rcnst

        return lik.sum() - (kl.sum() + self.F_theta.KL() + kl_pi.sum()), y_prob.cpu()

    def predict(self, y):      
        batch_size, T, D = y.shape
        C = self.C
        encoder_mu, encoder_logvar = self.encoder(y[:, :self.context_size])
        memory_sample = self.sample_from_memory()
        weights = self.get_memory_weights(encoder_mu, memory_sample)
        pi = Categorical(weights).sample()
        x0_mu = self.v1(torch.cat((self.memory[pi], encoder_mu), 1))
        x0_logvar = self.v2(torch.cat((self.memory[pi], encoder_mu), 1)).clamp(-8, 8).exp()

        x_prev = x0_mu + torch.sqrt(x0_logvar)*Normal(0,1).sample()
        
        y_prob = torch.zeros((batch_size, T - C, D), device=self.device)
        for t in range(T):
            mu, logvar = self.prior_dynamics(x_prev, pi)
            x_t = mu + torch.sqrt(logvar)*Normal(0,1).sample()
            y_rcnst = self.decoder(x_t)
            x_prev = x_t
            if t - C >= 0:
                y_prob[:, t-C] = y_rcnst

        return y_prob.cpu()


### Pipeline

In [7]:
def NMSE_score_train(preds, targets):
    return (torch.square(targets - preds).mean() / torch.square(targets).mean()).item()

def NMSE_score(preds, targets):
    return (torch.square(targets - preds.mean(0)).mean() / torch.square(targets).mean()).item()

def normal(target, mean):
    sigma = 0.1
    return 1/(sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5*np.square((target - mean)/sigma))

def NLL_score(preds, targets):
    N_sample, batch_size, T, D = preds.shape
    sum_of_normals = np.zeros((batch_size, D))
    for t in range(T):
        sum_of_normals_sample = np.zeros((batch_size, D))
        for sample in range(N_sample):
            pred_point = preds[sample, :, t, :]
            target_point = targets[:, t, :]
            tmp = normal(target_point, pred_point)
            sum_of_normals_sample = np.add(sum_of_normals_sample, tmp)
        sum_of_normals_sample = sum_of_normals_sample / N_sample
        sum_of_normals = np.add(sum_of_normals, sum_of_normals_sample)
    mean_of_normals = sum_of_normals / T
    nll = - np.log(mean_of_normals + 1e-3)
    return nll.mean().item()


def calculate_scores(probs, targets):
    return (
        NMSE_score(probs, targets), 
        NLL_score(probs, targets)
    )

FloatTensor = torch.FloatTensor
LongTensor = torch.LongTensor
ByteTensor = torch.ByteTensor

def train(model, optim, dataloader, task, device='cpu'):
    optim.zero_grad()
    model.train()
    train_elbo = 0
    train_num = 0
    y_probs = []
    targets = []
    for local_batch in dataloader:
        y, task_, mode, ts = local_batch
        y = FloatTensor(y.float()).to(device)
        elbo, y_prob = model(y)
        loss = -elbo
        loss.backward()
        optim.step()
        train_elbo += elbo.item()
        train_num += y.shape[0]

        y_probs.append(y_prob)
        targets.append(y.cpu())

    elbo = train_elbo / train_num
    nmse = NMSE_score_train(torch.vstack(y_probs), torch.vstack(targets))
    return elbo, nmse

def eval(task_name, model, dataloader, train_task, test_tasks, n_samples, name, device="cpu"):

    model.eval()
    with torch.no_grad():
        
        targets = []
        y_probs = [[] for _ in range(n_samples)]
        tasks = []
        modes = []
        for local_batch in dataloader:
            y, task, mode, ts = local_batch
            y = FloatTensor(y.float()).to(device)
            for sample in range(n_samples):
                y_prob = model.predict(y)
                y_probs[sample].extend(y_prob)
            
            targets.append(y[:, model.C:, :])
            tasks.append(task)
            modes.append(mode)
        
        y_probs = [torch.stack(item) for item in y_probs]
        y_probs, targets = torch.stack(y_probs).cpu(), torch.vstack(targets).cpu()

        tasks = torch.cat(tasks)
        modes = torch.cat(modes)
        nmse_mean, nll_mean = calculate_scores(y_probs, targets)
        

    return nmse_mean, nll_mean

In [8]:
def replicate(args, device, experiment_num):
    name = f"{args.base}/baselines/{args.task}/{args.model}/experiment_{str(experiment_num).zfill(2)}"
    os.makedirs(name, exist_ok=True)
    
    num_tasks = num_tasks_dict[args.task]
    tasks = sequences[args.task].copy()[experiment_num - 1].copy()
    train_task = tasks.pop()
    test_tasks = []
    test_tasks.append(train_task)
    print(f"train task: {train_task}, test tasks: {test_tasks}")

    dataset, N, T, D = load_data(args.task, args.data_dir, dt=args.dt)
    gen_dict = get_generators(args.task, dataset, args.data_dir, args.dt, args.batch_size, train_task, test_tasks)
    n_epochs = args.epochs * num_tasks

    context_size = args.context_size
    in_size = D
    out_size = D
    latent_size = args.latent_size
    hidden_size = args.hidden_size

    if args.model == "VCL-BSSM":
        model = VCL_BSSM(in_size=in_size, out_size=out_size, latent_size=latent_size, task=args.task,  context_size=context_size,
                                hidden_size=hidden_size, device=device, n_samples=args.n_samples, dt=args.dt)
    elif args.model == "CDDP":
        model = CDDP(in_size=in_size, out_size=out_size, latent_size=latent_size, task=args.task,  context_size=context_size,
                    hidden_size=hidden_size, device=device, n_samples=args.n_samples, dt=args.dt, memory_size=args.memory_size)
    else:
        raise NotImplementedError(f"The {args.model} is not available")
        
    model.to(device)

    lrate = args.lrate

    optim = torch.optim.Adam(model.parameters(), lr=lrate)

    history = {"elbo": [], "nmse_train": [], "nmse_test": [], "nll_test": [], "idx_train": [], "idx_test": []}
    
    for epoch in tqdm(range(n_epochs)):
        elbo, score = train(model, optim, gen_dict["train"], args.task, device)

        history["elbo"].append(elbo)
        history["nmse_train"].append(score)
        history["idx_train"].append(epoch)

        if epoch % args.eval_period == 0 or epoch == n_epochs - 1:
            mse, nll = eval(args.task, model, gen_dict["test"], train_task, test_tasks, args.n_samples, f"{name}/images/{epoch}.png",device)
            history["nmse_test"].append(mse)
            history["nll_test"].append(nll)
            history["idx_test"].append(epoch)
            print(f"Epoch {epoch}/{n_epochs} || Train: Elbo: {history['elbo'][-1]:.4f} NMSE: {history['nmse_train'][-1]:.5f}, Test NMSE: {history['nmse_test'][-1]:.5f}, NLL: {history['nll_test'][-1]:.5f}")
            
        
        if ((epoch % (n_epochs // num_tasks)) == 0 and epoch != 0) or epoch == n_epochs -1:
            state = {
                "model_state_dict": model.state_dict()
            }
            torch.save(state, f"{name}/checkpoint_trained_on_task_{train_task}.pt")

            if epoch == n_epochs -1:
                continue
                
            model.update()
            
            train_task = tasks.pop()
            test_tasks.append(train_task)
            print(f"train task: {train_task}, test tasks: {test_tasks}")

            gen_dict = get_generators(args.task, dataset, args.data_dir, args.dt, args.batch_size, train_task, test_tasks)
    
    return history
    

In [9]:
experiments = []
for exp_num in range(args.start_replication, args.max_replication + 1):
    print("#"*10, exp_num, "#"*10)
    results = replicate(args, device, exp_num)
    experiments.append(results)

########## 1 ##########
train task: 0, test tasks: [0]


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))

Epoch 0/1500 || Train: Elbo: -359.5756 NMSE: 1.03453, Test NMSE: 0.99228, NLL: 2.96502
Epoch 60/1500 || Train: Elbo: -354.5408 NMSE: 0.96687, Test NMSE: 1.02593, NLL: 3.83257
Epoch 120/1500 || Train: Elbo: -320.3667 NMSE: 0.42563, Test NMSE: 0.55307, NLL: 2.67718
Epoch 180/1500 || Train: Elbo: -330.2021 NMSE: 0.56372, Test NMSE: 0.43563, NLL: 2.36066
Epoch 240/1500 || Train: Elbo: -315.5135 NMSE: 0.33073, Test NMSE: 0.41808, NLL: 1.77968
Epoch 300/1500 || Train: Elbo: -312.1187 NMSE: 0.28306, Test NMSE: 0.40504, NLL: 1.70206
train task: 1, test tasks: [0, 1]
Epoch 360/1500 || Train: Elbo: -349.1324 NMSE: 0.17376, Test NMSE: 0.63221, NLL: 2.62394
Epoch 420/1500 || Train: Elbo: -335.2649 NMSE: 0.12468, Test NMSE: 0.66100, NLL: 3.38040
Epoch 480/1500 || Train: Elbo: -322.5072 NMSE: 0.07650, Test NMSE: 0.68257, NLL: 2.59725
Epoch 540/1500 || Train: Elbo: -310.5539 NMSE: 0.04030, Test NMSE: 0.61888, NLL: 2.31011
Epoch 600/1500 || Train: Elbo: -336.6002 NMSE: 0.14326, Test NMSE: 0.68414, NLL

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))

Epoch 0/1500 || Train: Elbo: -583.3878 NMSE: 1.07419, Test NMSE: 1.04017, NLL: 3.86532
Epoch 60/1500 || Train: Elbo: -434.6259 NMSE: 0.50251, Test NMSE: 0.55240, NLL: 2.53398
Epoch 120/1500 || Train: Elbo: -346.6899 NMSE: 0.17513, Test NMSE: 0.19101, NLL: 1.99713
Epoch 180/1500 || Train: Elbo: -333.8068 NMSE: 0.12292, Test NMSE: 0.15559, NLL: 1.73036
Epoch 240/1500 || Train: Elbo: -360.7942 NMSE: 0.22048, Test NMSE: 0.09997, NLL: 1.82115
Epoch 300/1500 || Train: Elbo: -313.1408 NMSE: 0.04918, Test NMSE: 0.02386, NLL: 1.00922
train task: 4, test tasks: [1, 4]
Epoch 360/1500 || Train: Elbo: -1506.8925 NMSE: 0.67849, Test NMSE: 0.98562, NLL: 4.25735
Epoch 420/1500 || Train: Elbo: -748.8832 NMSE: 0.23281, Test NMSE: 0.89755, NLL: 4.15965
Epoch 480/1500 || Train: Elbo: -503.4119 NMSE: 0.08925, Test NMSE: 0.84270, NLL: 3.58325
Epoch 540/1500 || Train: Elbo: -385.6931 NMSE: 0.02678, Test NMSE: 0.87674, NLL: 3.36827
Epoch 600/1500 || Train: Elbo: -575.5299 NMSE: 0.13959, Test NMSE: 0.85528, NL

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))

Epoch 0/1500 || Train: Elbo: -563.7564 NMSE: 1.00751, Test NMSE: 1.03583, NLL: 3.59322
Epoch 60/1500 || Train: Elbo: -431.2909 NMSE: 0.48840, Test NMSE: 0.59570, NLL: 4.38491
Epoch 120/1500 || Train: Elbo: -527.1037 NMSE: 0.84136, Test NMSE: 0.88539, NLL: 3.96103
Epoch 180/1500 || Train: Elbo: -354.8710 NMSE: 0.19348, Test NMSE: 0.16094, NLL: 1.81227
Epoch 240/1500 || Train: Elbo: -329.4986 NMSE: 0.10022, Test NMSE: 0.08824, NLL: 1.04684
Epoch 300/1500 || Train: Elbo: -341.8176 NMSE: 0.14906, Test NMSE: 0.13901, NLL: 1.60595
train task: 4, test tasks: [1, 4]
Epoch 360/1500 || Train: Elbo: -1206.0362 NMSE: 0.51239, Test NMSE: 1.01689, NLL: 4.71923
Epoch 420/1500 || Train: Elbo: -1149.4402 NMSE: 0.47286, Test NMSE: 0.97068, NLL: 4.44561
Epoch 480/1500 || Train: Elbo: -634.8441 NMSE: 0.16499, Test NMSE: 0.68170, NLL: 4.12532
Epoch 540/1500 || Train: Elbo: -494.8028 NMSE: 0.08709, Test NMSE: 0.54544, NLL: 2.98307
Epoch 600/1500 || Train: Elbo: -446.8491 NMSE: 0.06345, Test NMSE: 0.61345, N

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))

Epoch 0/1500 || Train: Elbo: -1441.4174 NMSE: 1.01490, Test NMSE: 1.04229, NLL: 5.55970
Epoch 60/1500 || Train: Elbo: -923.1426 NMSE: 0.55950, Test NMSE: 0.64145, NLL: 4.52191
Epoch 120/1500 || Train: Elbo: -687.8027 NMSE: 0.33496, Test NMSE: 0.51804, NLL: 4.83831
Epoch 180/1500 || Train: Elbo: -587.3781 NMSE: 0.23700, Test NMSE: 0.12673, NLL: 2.14255
Epoch 240/1500 || Train: Elbo: -380.1192 NMSE: 0.04968, Test NMSE: 0.02449, NLL: 1.45138
Epoch 300/1500 || Train: Elbo: -391.7322 NMSE: 0.05907, Test NMSE: 0.03878, NLL: 1.58881
train task: 4, test tasks: [3, 4]
Epoch 360/1500 || Train: Elbo: -1043.3315 NMSE: 0.42020, Test NMSE: 0.89361, NLL: 4.34718
Epoch 420/1500 || Train: Elbo: -417.2541 NMSE: 0.04646, Test NMSE: 0.12362, NLL: 2.07887
Epoch 480/1500 || Train: Elbo: -355.0490 NMSE: 0.01454, Test NMSE: 0.25280, NLL: 2.50133
Epoch 540/1500 || Train: Elbo: -638.0749 NMSE: 0.17522, Test NMSE: 0.43842, NLL: 3.70818
Epoch 600/1500 || Train: Elbo: -355.5464 NMSE: 0.01274, Test NMSE: 0.28571, N

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))

Epoch 0/1500 || Train: Elbo: -360.5812 NMSE: 1.03732, Test NMSE: 1.03257, NLL: 2.04755
Epoch 60/1500 || Train: Elbo: -338.3364 NMSE: 0.69428, Test NMSE: 0.81687, NLL: 2.67858
Epoch 120/1500 || Train: Elbo: -310.1020 NMSE: 0.26286, Test NMSE: 0.48400, NLL: 2.04931
Epoch 180/1500 || Train: Elbo: -316.6738 NMSE: 0.36596, Test NMSE: 0.72389, NLL: 1.47882
Epoch 240/1500 || Train: Elbo: -317.8038 NMSE: 0.36577, Test NMSE: 0.31938, NLL: 1.48534
Epoch 300/1500 || Train: Elbo: -301.9668 NMSE: 0.14067, Test NMSE: 0.26547, NLL: 1.48111
train task: 4, test tasks: [0, 4]
Epoch 360/1500 || Train: Elbo: -1182.0134 NMSE: 0.48814, Test NMSE: 0.53265, NLL: 3.15340
Epoch 420/1500 || Train: Elbo: -705.2769 NMSE: 0.21360, Test NMSE: 0.75847, NLL: 3.96147
Epoch 480/1500 || Train: Elbo: -597.4488 NMSE: 0.15062, Test NMSE: 0.84077, NLL: 3.97342
Epoch 540/1500 || Train: Elbo: -562.1586 NMSE: 0.12447, Test NMSE: 0.78607, NLL: 3.61554
Epoch 600/1500 || Train: Elbo: -434.5087 NMSE: 0.05331, Test NMSE: 0.89735, NL

In [10]:
scores = ["nmse", "nll"]
n_tasks = num_tasks_dict[args.task]
epochs = args.epochs
ls = []

indexes = [i*epochs for i in range(1, n_tasks)] + [n_tasks*epochs -1]
for exp_num in range(len(experiments)):
    for score in scores:
        for i, idx in enumerate(indexes):
            ls.append([f"{i + 1} tasks", score, experiments[exp_num][f"{score}_test"][experiments[exp_num][f"idx_test"].index(idx)]])


df = pd.DataFrame(ls, columns=['# Tasks', 'score', 'value'])
results = []
for nof_task in range(1, n_tasks + 1):
    mse = df[df.score == "nmse"]
    nll = df[df.score == "nll"]
    mse = mse[mse["# Tasks"] == f"{nof_task} tasks"]
    nll = nll[nll["# Tasks"] == f"{nof_task} tasks"]
    results.append([f"{nof_task} Tasks", f"{mse.value.mean():.3f} +- {sem(mse.value):.3f}", f"{nll.value.mean():.3f} +- {sem(nll.value):.3f}"])

results = pd.DataFrame(results, columns=["# Tasks", 'NMSE', 'NLL'])
results

Unnamed: 0,# Tasks,NMSE,NLL
0,1 Tasks,0.174 +- 0.072,1.477 +- 0.122
1,2 Tasks,0.667 +- 0.109,3.216 +- 0.263
2,3 Tasks,1.272 +- 0.186,4.226 +- 0.118
3,4 Tasks,1.022 +- 0.249,3.992 +- 0.314
4,5 Tasks,1.278 +- 0.172,4.480 +- 0.183


In [11]:
print(f'NMSE: {df[df.score == "nmse"].value.mean():.3f} +- {sem(df[df.score == "nmse"].value):.3f} ')
print(f'NLL: {df[df.score == "nll"].value.mean():.3f} +- {sem(df[df.score == "nll"].value):.3f} ')

NMSE: 0.883 +- 0.110 
NLL: 3.478 +- 0.238 
