In [None]:
import config
import torch
import numpy as np
import matplotlib

from torch import tensor
from variational.nn import IsotropicGaussian, Sequential

In [None]:
from torch import nn
import variational.nn as vnn 

class ToyModel(nn.Module):
    """Small variational regressor: stochastic MLP encoder + linear decoder.

    The encoder maps ``in_features`` inputs through ReLU hidden layers into an
    ``IsotropicGaussian`` layer producing ``latent_dim`` values; the decoder is
    a single linear map from that latent to ``out_features`` outputs.

    The defaults reproduce the original 1 -> 12 -> 24 -> 12 -> 1 architecture
    (same module order, so ``encode[6]`` is still the Gaussian layer and the
    ``state_dict`` keys are unchanged).

    Args:
        in_features: Size of each input row.
        hidden: Widths of the ReLU hidden layers, in order.
        latent_dim: Output size of the Gaussian layer / input size of the decoder.
        out_features: Output size of the decoder.
    """

    def __init__(self, in_features=1, hidden=(12, 24, 12), latent_dim=1, out_features=1):
        super().__init__()
        layers = []
        width = in_features
        for h in hidden:
            layers += [nn.Linear(width, h), nn.ReLU()]
            width = h
        layers.append(IsotropicGaussian(width, latent_dim))
        # Remember where the Gaussian layer sits so callers need not hard-code it.
        self._latent_index = len(layers) - 1
        self.encode = vnn.Sequential(*layers)

        self.decode = nn.Linear(latent_dim, out_features)

    @property
    def latent(self):
        """The ``IsotropicGaussian`` layer at the end of ``self.encode``."""
        return self.encode[self._latent_index]

    def forward(self, x, n_samples):
        """Encode ``x`` with ``n_samples`` stochastic draws, decode, and average over dim 0.

        NOTE(review): assumes ``self.encode`` stacks the ``n_samples`` draws along
        dim 0, so ``.mean(0)`` averages over samples -- confirm against
        ``variational.nn.Sequential``.
        """
        z = self.encode(x, n_samples)
        return self.decode(z).mean(0)
    


# Smoke test: push a tiny random batch (3 rows, 1 feature) through a fresh
# model with 2 stochastic draws, then inspect the Gaussian layer's parameters.
toyModel = ToyModel()

x = torch.randn(3, 1, dtype=torch.float32)
y = toyModel(x, 2)

# Index 6 is the IsotropicGaussian at the end of the encoder; its mu/sigma are
# presumably populated by the forward pass above -- verify in variational.nn.
gaussian_layer = toyModel.encode[6]
mu = gaussian_layer.mu
sigma = gaussian_layer.sigma
print(y.shape, mu, sigma)

In [None]:
from dataset import SyntheticDataset1
from matplotlib import pyplot as plt

# NOTE(review): 1000 is presumably the number of samples to generate -- confirm.
dataset = SyntheticDataset1(1000)
# A bare expression that is not the last line of a cell is discarded by
# Jupyter, so print the shapes explicitly.
print(dataset.X.shape, dataset.y.shape)

fig, ax = plt.subplots(figsize=(5, 3))
ax.scatter(dataset.X, dataset.y, c=dataset.y, s=12)
ax.set(xlabel="X", ylabel="y", title="SyntheticDataset1")
plt.show()

In [None]:
from torch.nn.functional import mse_loss
from torch.optim import Adam
from torch.utils.data import DataLoader
from variational.loss import SGVBL

# Training configuration (previously inline magic numbers).
SEED = 0
N_EPOCHS = 1000
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
N_TRAIN_SAMPLES = 10  # passed as n_samples to ToyModel.forward
KL_WEIGHT = 1e-3      # NOTE(review): third arg of SGVBL's call; assumed to weight the KL term -- confirm
LOG_EVERY = 100       # print progress every N epochs instead of all N_EPOCHS lines

# Seed before building the model so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

model = ToyModel()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# NOTE(review): the positional `1` given to SGVBL is undocumented here; if it is
# the training-set size it should likely be len(dataset) -- confirm.
loss = SGVBL(model, 1, mle=mse_loss)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

for epoch in range(N_EPOCHS):
    batch_losses = []
    for x, y in dataloader:
        optimizer.zero_grad()
        y_hat = model(x, N_TRAIN_SAMPLES)
        batch_loss = loss(y_hat, y, KL_WEIGHT)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    # Report the mean over the epoch, not just the (noisy) last mini-batch.
    epoch_loss = sum(batch_losses) / len(batch_losses)
    if epoch % LOG_EVERY == 0 or epoch == N_EPOCHS - 1:
        print(f'Epoch {epoch}: {epoch_loss}')



In [None]:

# Inspect the trained `model` from the training cell -- not `toyModel`, which is
# the untrained smoke-test instance and would make this plot meaningless.
encode = model.encode
with torch.no_grad():
    # NOTE(review): ToyModel.forward calls the encoder as encode(x, n_samples);
    # this single-argument call assumes n_samples has a default -- confirm.
    _ = encode(dataset.X)
# Index 6 is the IsotropicGaussian layer; its mu/sigma presumably reflect the pass above.
mu, std = encode[6].mu.detach().numpy(), encode[6].sigma.detach().numpy()

fig, ax = plt.subplots(1, 2, figsize=(5, 3))
ax[0].scatter(dataset.X, dataset.y, c=dataset.y, s=12)
ax[0].set(title="data", xlabel="X", ylabel="y")
ax[1].scatter(dataset.X.ravel(), mu.ravel())
ax[1].set(title="latent mean", xlabel="X", ylabel="mu")
plt.show()