In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg as LA
plt.style.use("ggplot")

from os.path import join, exists
import os

import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F

from pixyz.models import Model
from pixyz.losses import ELBO, NLL

from sklearn.utils import shuffle

from models import *
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Seed torch as well (numpy is seeded in the training-data cell) so that network
# initialisation and torch-side sampling start from a fixed RNG state on every run.
torch.manual_seed(42)

# Output directories for the per-checkpoint snapshots saved by the NP and ANP
# training cells. exist_ok=True avoids the exists()-then-makedirs() race and the
# duplicated code; after the loop `path_gif` is "./ANP_png/", as before.
for path_gif in ("./NP_png/", "./ANP_png/"):
    os.makedirs(path_gif, exist_ok=True)

In [None]:
# Training data: N noisy observations of f(x) = sin(2*pi*x) + 1 at uniform inputs in [0, 1].
sigma_y = 0.2  # standard deviation of the additive Gaussian observation noise

N = 8
np.random.seed(42)
train_X = np.random.uniform(0, 1, N)
train_y = np.sin(2 * np.pi * train_X) + np.random.normal(0, sigma_y, N) + 1

# Dense grid (np.linspace default: 50 points) for drawing the noise-free ground truth.
x_true = np.linspace(0, 1)
plt.plot(x_true, np.sin(2 * np.pi * x_true) + 1, "g", label=r"$f(x)$")
plt.scatter(train_X, np.sin(2 * np.pi * train_X) + 1, c="b", label=r"$f(X)$")  # noise-free values at the inputs
plt.scatter(train_X, train_y, label=r"$f(X) + \epsilon$")  # noisy training targets
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()

In [None]:
# Gaussian-process regression baseline with an RBF kernel.

sigma_f = 0.5  # kernel amplitude (signal standard deviation)
l = 0.2        # kernel length scale
sigma_y = 0.2  # observation-noise std assumed by the GP (same value as the data cell)

def k(x, y):
    """Squared-exponential (RBF) kernel.

    Built only from numpy ufuncs, so it broadcasts: two scalars give a scalar,
    arrays give the elementwise kernel of the broadcast arguments.
    """
    return sigma_f ** 2 * np.exp(- ((x - y) ** 2) / (2 * l ** 2))

def k_(x):
    """Kernel vector k(train_X, x) between every training input and a scalar x."""
    # k already broadcasts over arrays, so no np.vectorize wrapper is needed.
    return k(train_X, x)

def m(x):
    """GP posterior mean of f at scalar x: (K_y^{-1} train_y) . k_(x).

    Reads the module-level K_y_inv, train_X and train_y.
    """
    return K_y_inv.dot(train_y).dot(k_(x))

def sd(x):
    """GP posterior standard deviation of the noise-free f at scalar x.

    The variance k(x, x) - k_(x)^T K_y^{-1} k_(x) is non-negative in exact
    arithmetic; it is clamped at 0 so round-off cannot produce NaN under sqrt.
    Observation noise sigma_y^2 is not added (this is the band for f, not y).
    """
    var = k(x, x) - k_(x).dot(K_y_inv).dot(k_(x))
    return np.sqrt(np.maximum(var, 0.0))

# Evaluate the GP posterior on a dense grid. k broadcasts, so all kernel matrices
# are built directly from arrays instead of element-by-element np.vectorize calls.
X, Y = np.meshgrid(train_X, train_X)
K = k(X, Y)  # (N, N) covariance of the noise-free values f(train_X)
K_y = K + sigma_y ** 2 * np.eye(len(train_X))  # (N, N) covariance of the noisy observations y
K_y_inv = LA.inv(K_y)  # explicit inverse is fine for this small N

Nx = 100
x = np.linspace(-0.2, 1.2, Nx)
K_s = k(train_X[:, None], x[None, :])  # (N, Nx) cross-covariance: training inputs vs grid points

# Posterior mean (K_y^{-1} y) . k_* for every grid point at once.
y_mean = K_y_inv.dot(train_y).dot(K_s)
# Posterior variance of the noise-free f, k(x, x) - k_*^T K_y^{-1} k_*, computed once
# for the whole grid and clamped at 0 so round-off cannot produce NaN under sqrt.
y_var = k(x, x) - np.einsum("in,ij,jn->n", K_s, K_y_inv, K_s)
y_sd = np.sqrt(np.maximum(y_var, 0.0))
y_upper = y_mean + 2 * y_sd  # upper 2-sigma band
y_under = y_mean - 2 * y_sd  # lower 2-sigma band

plt.scatter(train_X, train_y, label="train data")
plt.plot(x, y_mean, label=r"$m(x_{N+1})$")
plt.plot(x, y_upper, "k--", label=r"$m(x_{N+1}) + 2\sigma(x_{N+1})$")
plt.plot(x, y_under, "k--")
x_true = np.linspace(0, 1)
plt.plot(x_true, np.sin(2 * np.pi * x_true) + 1, "g", alpha=0.5, label=r"$f(x_{N+1})$")  # ground truth f(x)
plt.legend()
plt.show()

In [None]:
# NP model hyper-parameters
z_dim = 3    # latent z dimension (also passed to R_encoder / Decoder below)
d_dim = 128  # presumably the hidden width of the networks — TODO confirm in models.py
x_dim = 1
y_dim = 1

r_encoder = R_encoder(x_dim, y_dim, d_dim, z_dim).to(device)  # (x, y) -> r  (deterministic path)
s_encoder = S_encoder(x_dim, y_dim, d_dim, z_dim).to(device)  # (x, y) -> z  (latent path)
# s_encoder with its inputs renamed to x_c / y_c, used as the context-conditioned prior in
# the KL term. NOTE(review): assumed to share parameters with s_encoder, which would explain
# why it is not passed to the optimizer separately — confirm with pixyz replace_var.
s_encoder_context = s_encoder.replace_var(x="x_c", y="y_c").to(device)
decoder = Decoder(x_dim, y_dim, d_dim, z_dim).to(device)  # (x*, r, z) -> y*
opt = torch.optim.Adam(list(decoder.parameters()) + list(r_encoder.parameters()) +
                       list(s_encoder.parameters()), 1e-3)

# Sanity check before training: decoder means for random latent vectors on a wide grid.
x_grid = torch.from_numpy(np.arange(-4, 4, 0.1).reshape(-1, 1).astype(np.float32)).to(device)
with torch.no_grad():  # plotting only — no autograd graph needed
    for i in range(15):
        untrained_zs = torch.from_numpy(np.random.normal(size=(z_dim)).astype(np.float32)).to(device)
        # The same random vector is fed as both r and z, one copy per grid point.
        # NOTE(review): this assumes r also has dimension z_dim — confirm against R_encoder.
        zs_grid = untrained_zs.repeat(len(x_grid), 1)
        mu = decoder.sample_mean({"x_": x_grid, "r": zs_grid, "z": zs_grid})
        plt.plot(x_grid.cpu().numpy(), mu.detach().cpu().numpy(), linewidth=1)
plt.title("untrained NP: decoder means for random r, z")
plt.show()

In [None]:
train_X_ = torch.from_numpy(train_X.astype("float32")).to(device)
train_y_ = torch.from_numpy(train_y.astype("float32")).to(device)

# Dense evaluation grid for the periodic snapshots. It matches the GP grid, so the
# y_mean / y_upper / y_under reference curves line up with it. Built once, outside the loop.
Nx = 100
x = np.linspace(-0.2, 1.2, Nx)
x_ = torch.from_numpy(x.astype("float32")).to(device)

# Training
for i in range(20000):
    opt.zero_grad()

    # context / target split of the training points
    x_c, x_t, y_c, y_t = context_target_random_split(train_X_[:, None], train_y_[:, None])
    x_ct = torch.cat([x_c, x_t], dim=0)
    y_ct = torch.cat([y_c, y_t], dim=0)

    # deterministic path: mean-aggregate the per-point context representations
    r_mean = r_encoder(x_c, y_c)["r"].mean(0)

    # latent path: draw z from the encoder conditioned on the target set.
    # NOTE(review): the NP objective usually conditions this posterior on context ∪ target;
    # x_c and x_t are concatenated into x_ct here, which suggests they are disjoint —
    # confirm in context_target_random_split whether x_ct should be used instead.
    z_sample_target = s_encoder.sample({"x": x_t, "y": y_t})

    # Loss: NLL of the targets + single-sample estimate of KL(q(z | target) || q(z | context)),
    # i.e. log q(z | target) - log q(z | context) evaluated at the sampled z.
    nll = - decoder.log_likelihood({"x_": x_t, "r": r_mean.repeat(len(x_t), 1), "z": z_sample_target["z"].repeat(len(x_t), 1), "y_": y_t})
    kl = s_encoder.log_likelihood(z_sample_target) - s_encoder_context.log_likelihood({"x_c": x_c, "y_c": y_c, "z": z_sample_target["z"]})
    loss = nll.mean() + kl.mean()

    loss.backward()
    opt.step()

    # visualize every 200 iterations, treating all training points as the context
    if (i + 1) % 200 == 0:
        with torch.no_grad():  # inference only: no autograd graph needed
            r_mean = r_encoder(x_ct, y_ct)["r"].mean(0)
            for j in range(10):
                z = s_encoder.sample({"x": x_ct, "y": y_ct})["z"]
                y_ = decoder.sample_mean({"x_": x_[:, None], "r": r_mean.repeat(len(x_), 1), "z": z.repeat(len(x_), 1)})
                y_ = y_.detach().cpu().numpy()
                if j == 0:
                    plt.plot(x, y_, alpha=0.5, c="b", label="NP sample")
                else:
                    plt.plot(x, y_, alpha=0.5, c="b")

        plt.scatter(train_X, train_y)
        plt.title("epoch: {}".format(i + 1))
        plt.xlim(-0.22, 1.22)
        plt.ylim(-0.55, 2.4)
        plt.plot(x, y_mean, "k", label="GP mean")
        plt.plot(x, y_upper, "k--", label="GP 2sigma")
        plt.plot(x, y_under, "k--")
        plt.legend(loc='upper right')
        plt.savefig("./NP_png/{}".format(i + 1))
        plt.show()
        plt.close()  # release the figure so 100 snapshots never accumulate in memory

# ANP (Attentive Neural Process)

In [None]:
# ANP model hyper-parameters (note: z_dim differs from the NP cell)
z_dim = 2
d_dim = 128
x_dim = 1
y_dim = 1

r_encoder = R_encoder(x_dim, y_dim, d_dim, z_dim).to(device)  # (x, y) -> r  (deterministic path)
s_encoder = S_encoder(x_dim, y_dim, d_dim, z_dim).to(device)  # (x, y) -> z  (latent path)
# s_encoder with its inputs renamed to x_c / y_c, used as the context-conditioned prior in
# the KL term. NOTE(review): assumed to share parameters with s_encoder — confirm.
s_encoder_context = s_encoder.replace_var(x="x_c", y="y_c").to(device)
CA = CrossAttention(x_dim, d_dim, z_dim).to(device)  # (x_t, x_c, r) -> r_
decoder = Decoder(x_dim, y_dim, d_dim, z_dim).to(device)  # (x*, r, z) -> y*  (r comes from CA)

# All trainable modules, in the same order as before: decoder, r_encoder, s_encoder, CA.
anp_modules = (decoder, r_encoder, s_encoder, CA)
opt = torch.optim.Adam([p for module in anp_modules for p in module.parameters()], 1e-3)

In [None]:
train_X_ = torch.from_numpy(train_X.astype("float32")).to(device)
train_y_ = torch.from_numpy(train_y.astype("float32")).to(device)

# Dense evaluation grid for the periodic snapshots. It matches the GP grid, so the
# y_mean / y_upper / y_under reference curves line up with it. Built once, outside the loop.
Nx = 100
x = np.linspace(-0.2, 1.2, Nx)
x_ = torch.from_numpy(x.astype("float32")).to(device)

# Training
for epoch in range(20000):
    opt.zero_grad()

    # context / target split of the training points
    x_c, x_t, y_c, y_t = context_target_random_split(train_X_[:, None], train_y_[:, None])
    x_ct = torch.cat([x_c, x_t], dim=0)
    y_ct = torch.cat([y_c, y_t], dim=0)

    # deterministic path: per-point context representations, then cross-attention with the
    # target inputs as queries. NOTE(review): assumes CA returns one row per target point,
    # since r_ is passed to the decoder un-repeated alongside x_t — confirm in models.py.
    r = r_encoder(x_c, y_c)["r"]
    r = CA(x_t, x_c, r)["r_"]

    # latent path: draw z from the encoder conditioned on the target set
    z_sample_target = s_encoder.sample({"x": x_t, "y": y_t})

    # Loss: NLL of the targets + single-sample estimate of KL(q(z | target) || q(z | context)),
    # i.e. log q(z | target) - log q(z | context) evaluated at the sampled z.
    nll = - decoder.log_likelihood({"x_": x_t, "r": r, "z": z_sample_target["z"].repeat(len(x_t), 1), "y_": y_t})
    kl = s_encoder.log_likelihood(z_sample_target) - s_encoder_context.log_likelihood({"x_c": x_c, "y_c": y_c, "z": z_sample_target["z"]})
    loss = nll.mean() + kl.mean()
    loss.backward()
    opt.step()

    # visualize every 200 iterations, treating all training points as the context
    if (epoch + 1) % 200 == 0:
        with torch.no_grad():  # inference only: no autograd graph needed
            r = r_encoder(x_ct, y_ct)["r"]
            r = CA(x_[:, None], x_ct, r)["r_"]  # one attended representation per grid point
            for j in range(10):
                z = s_encoder.sample({"x": x_ct, "y": y_ct})["z"]
                y_ = decoder.sample_mean({"x_": x_[:, None], "r": r, "z": z.repeat(len(x_), 1)})
                y_ = y_.detach().cpu().numpy()
                if j == 0:
                    plt.plot(x, y_, alpha=0.5, c="b", label="ANP sample")
                else:
                    plt.plot(x, y_, alpha=0.5, c="b")

        plt.scatter(train_X, train_y)
        plt.title("epoch: {}".format(epoch + 1))
        plt.xlim(-0.22, 1.22)
        plt.ylim(-0.55, 2.4)
        plt.plot(x, y_mean, "k", label="GP mean")
        plt.plot(x, y_upper, "k--", label="GP 2sigma")
        plt.plot(x, y_under, "k--")
        plt.legend(loc='upper right')
        plt.savefig("./ANP_png/{}".format(epoch + 1))
        plt.show()
        plt.close()  # release the figure so 100 snapshots never accumulate in memory