In [1]:
from dataclasses import dataclass

import numpy as np
import seaborn as sns
import torch
import wandb
from contextual_gaussian import (ContextDataset, ContextualGaussian, ContextualGMM)
from dvi_process import DIS
from matplotlib import pyplot as plt
from score_function import ScoreFunction
from torch.distributions import Normal
from torch.utils.data import DataLoader
from train import train
from dataclasses import asdict

In [2]:
import torch_directml
device = torch.device("cpu") # = torch_directml.device()

In [3]:
dataset = ContextDataset(size=4096)
dataloader = DataLoader(dataset, batch_size=4096, shuffle=True)

In [4]:
@dataclass
class Config():
    num_steps = 16
    c_dim = 1
    z_dim = 1
    h_dim = 32
    num_layers = 3
    non_linearity = "SiLU"
    learning_rate = 3e-4
    num_epochs = 1000

config = Config()

In [5]:
score_function = ScoreFunction(
    c_dim=config.c_dim, 
    h_dim=config.h_dim, 
    z_dim=config.z_dim, 
    num_layers=config.num_layers, 
    non_linearity=config.non_linearity, 
    num_steps=config.num_steps
)

dvi_process = DIS(
    device=device, 
    z_dim=config.z_dim, 
    num_steps=config.num_steps, 
    score_function=score_function
).to(device)

optimizer = torch.optim.Adam(dvi_process.parameters(), lr=config.learning_rate)

In [6]:
target = ContextualGMM

In [7]:
wandb_logging = False
if wandb_logging:
    wandb.init(project="dvi", config=asdict(config))

In [8]:
losses = train(
    dvi_process, 
    device, 
    dataloader, 
    optimizer, 
    config.num_epochs, 
    target,
    wandb_logging=wandb_logging
)

100%|██████████| 1/1 [00:00<00:00,  1.88it/s, epoch=0, loss=inf]
  0%|          | 0/1 [00:00<?, ?it/s]


ValueError: Expected parameter loc (Tensor of shape (4096, 1)) of distribution Normal(loc: torch.Size([4096, 1]), scale: torch.Size([4096, 1])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan],
        [nan],
        [nan],
        ...,
        [nan],
        [nan],
        [nan]], grad_fn=<AddBackward0>)

In [17]:
num_samples = 256

context = torch.zeros((config.z_dim), device=device).unsqueeze(0).repeat(num_samples, 1)
# context = torch.ones((z_dim), device=device).unsqueeze(0).repeat(num_samples, 1) * 5

p_z_0 = Normal(
    torch.zeros((num_samples, config.z_dim), device=device),
    torch.ones((num_samples, config.z_dim), device=device) # * dvi_process.sigmas[0],
)

p_z_T = target(context)

dvi_process.eval()
with torch.no_grad():
    _, z_samples = dvi_process.run_chain(p_z_0, p_z_T, context.to(device))

    z_0_samples = z_samples[0].squeeze(1).tolist()
    z_T_samples = z_samples[-1].squeeze(1).tolist()

    z_trajectories = [[z[i].cpu().numpy() for z in z_samples] for i in range(num_samples)]

    z_samples = p_z_T.sample().squeeze(1).tolist()

fig, ax = plt.subplots(1, 4, figsize=(18, 3), gridspec_kw={'width_ratios': [1, 3, 1, 1]})

sns.kdeplot(z_0_samples, ax=ax[0])
ax[0].set_title("Prior $q(z_0)$")
ax[0].set_ylabel(None)
ax[0].set_xlim(-8, 8)
ax[0].set_ylim(0, 0.7)

for i in range(num_samples):
    ax[1].plot(z_trajectories[i])
ax[1].set_title("Samples from Forward Process $q(z_{0:T}|c)$")

sns.kdeplot(z_T_samples, ax=ax[2])
ax[2].set_title("Marginal $q(z_T|z_{0:T-1},c)$")
ax[2].set_ylabel(None)
ax[2].set_xlim(-8, 8)
ax[2].set_ylim(0, 0.7)

sns.kdeplot(z_samples, ax=ax[3])
ax[3].set_title("Ground-Truth $p(z|c)$")
ax[3].set_ylabel(None)
ax[3].set_xlim(-8, 8)
ax[3].set_ylim(0, 0.7)

plt.tight_layout()
plt.show()

tensor([-5.7756])


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)