In [1]:
# Clear the interactive namespace without prompting so the notebook starts from
# a clean kernel state (-s: soft reset, -f: no confirmation, per IPython's %reset).
%reset -s -f

In [2]:
import os
from functools import partial
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import pyro
import pyro.distributions as dist

In [11]:
smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)
assert pyro.__version__.startswith('1.3.1')
pyro.set_rng_seed(1)
pyro.enable_validation(True)
%matplotlib inline
plt.style.use('default')

In [19]:
# Directory holding the CSV inputs. Set the DATA_PATH environment variable to
# point elsewhere instead of editing this machine-specific default.
DATA_PATH = os.environ.get('DATA_PATH', r'C:\Users\Lenovo\Desktop\pl\project')
# *_csv names keep these paths distinct from the train() function defined later.
train_csv = os.path.join(DATA_PATH, 'train.csv')
test_csv = os.path.join(DATA_PATH, 'test.csv')
weights_csv = os.path.join(DATA_PATH, 'true_weights.csv')
# The first CSV column is used as the row index; .values returns the remaining
# columns as a NumPy array.
train_data = pd.read_csv(train_csv, index_col=0).values
test_data = pd.read_csv(test_csv, index_col=0).values
true_weights = pd.read_csv(weights_csv, index_col=0).values

In [24]:
from torch import nn
from pyro.nn import PyroModule

# PyroModule[nn.Linear] must behave both as a regular nn.Linear and as a PyroModule.
assert all(issubclass(PyroModule[nn.Linear], base) for base in (nn.Linear, PyroModule))

In [25]:
# Copy both NumPy arrays into float32 tensors (torch.float is an alias of torch.float32).
train_tensor, test_tensor = (
    torch.tensor(array, dtype=torch.float32) for array in (train_data, test_data)
)

In [37]:
# All columns but the last are features; the last column is the regression target.
X_train = train_tensor[:, :-1]
y_train = train_tensor[:, -1]
X_test = test_tensor[:, :-1]
y_test = test_tensor[:, -1]

In [41]:
# Classical (non-Bayesian) regression baseline: a single linear layer mapping
# every feature column of X_train to one output. The input width comes from
# the data, so it cannot drift out of sync with the CSV.
linear_reg_model = PyroModule[nn.Linear](X_train.shape[1], 1)

In [42]:
# Display the module repr to confirm the layer's in/out feature sizes.
linear_reg_model

PyroLinear(in_features=10, out_features=1, bias=True)

In [43]:
# Define the loss function and the optimizer.
# reduction='sum' adds squared errors over all rows, so the reported loss
# grows with the size of the training set.
loss_fn = torch.nn.MSELoss(reduction='sum')
optim = torch.optim.Adam(linear_reg_model.parameters(), lr=0.05)
# Only a couple of iterations under CI (smoke_test) to keep runs fast.
num_iterations = 2 if smoke_test else 1500

In [44]:
def train():
    """Run one full-batch optimisation step of the classical linear model.

    Uses the notebook-level ``linear_reg_model``, ``X_train``, ``y_train``,
    ``loss_fn`` and ``optim``.

    Returns:
        The loss tensor computed by ``loss_fn`` *before* this step's
        parameter update.
    """
    # Clear gradients left over from the previous step.
    optim.zero_grad()
    predictions = linear_reg_model(X_train).squeeze(-1)
    batch_loss = loss_fn(predictions, y_train)
    batch_loss.backward()
    optim.step()
    return batch_loss

for step in range(1, num_iterations + 1):
    loss = train()
    # Report the summed training loss every 50 iterations.
    if step % 50 == 0:
        print("[iteration %04d] loss: %.4f" % (step, loss.item()))


# Inspect learned parameters
print("Learned parameters:")
for name, param in linear_reg_model.named_parameters():
    print(name, param.detach().numpy())

[iteration 0050] loss: 1.3841
[iteration 0100] loss: 0.4040
[iteration 0150] loss: 0.3958
[iteration 0200] loss: 0.3958
[iteration 0250] loss: 0.3958
[iteration 0300] loss: 0.3958
[iteration 0350] loss: 0.3958
[iteration 0400] loss: 0.3958
[iteration 0450] loss: 0.3958
[iteration 0500] loss: 0.3958
[iteration 0550] loss: 0.3958
[iteration 0600] loss: 0.3958
[iteration 0650] loss: 0.3958
[iteration 0700] loss: 0.3958
[iteration 0750] loss: 0.3958
[iteration 0800] loss: 0.3958
[iteration 0850] loss: 0.3958
[iteration 0900] loss: 0.3958
[iteration 0950] loss: 0.3958
[iteration 1000] loss: 0.3958
[iteration 1050] loss: 0.3958
[iteration 1100] loss: 0.3958
[iteration 1150] loss: 0.3958
[iteration 1200] loss: 0.3958
[iteration 1250] loss: 0.3958
[iteration 1300] loss: 0.3958
[iteration 1350] loss: 0.3958
[iteration 1400] loss: 0.3958
[iteration 1450] loss: 0.3958
[iteration 1500] loss: 0.3958
Learned parameters:
weight [[-0.28477463 -0.60408837  0.02518461  1.5225981   1.46099    -0.5903652


In [88]:
from pyro.nn import PyroSample


class BayesianRegression(PyroModule):
    """Bayesian linear regression with a learned observation noise scale.

    Priors: weight ~ Normal(0, 1) and bias ~ Normal(0, 0.1), each declared as
    a single event block via ``to_event``; sigma ~ Uniform(0, 10).
    Likelihood: obs ~ Normal(linear(x), sigma), one draw per row of ``x``
    inside the "data" plate.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        # Replace the layer's deterministic parameters with priors. The
        # resulting sample sites appear as "linear.weight" / "linear.bias".
        self.linear.weight = PyroSample(dist.Normal(0., 1).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 0.1).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        """Run the generative model.

        Args:
            x: input features; its first dimension is the "data" plate size.
               Presumably (N, in_features) — TODO confirm for other callers.
            y: optional observed targets; when None, "obs" is sampled instead.

        Returns:
            The regression mean ``self.linear(x).squeeze(-1)``.
        """
        # NOTE: "sigma" is sampled before the linear layer's sites. Keep this
        # order: AutoDiagonalNormal appears to pack latents in trace order (the
        # fitted loc[0] ≈ -4.19 maps to sigma ≈ 10·sigmoid(-4.19) ≈ 0.149).
        sigma = pyro.sample("sigma", dist.Uniform(0., 10))
        # squeeze(-1) only drops a size-1 trailing dim, so this assumes
        # out_features == 1; otherwise mean keeps a trailing dimension.
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

In [89]:
from pyro.infer.autoguide import AutoDiagonalNormal

# Bayesian model sized from the data (one input per feature column of X_train)
# plus an automatic diagonal-Normal variational guide over its latent sites.
model = BayesianRegression(X_train.shape[1], 1)
guide = AutoDiagonalNormal(model)

In [90]:
from pyro.infer import SVI, Trace_ELBO


# Pyro's wrapped Adam takes its hyper-parameters as a dict.
adam = pyro.optim.Adam(dict(lr=0.03))
# Stochastic variational inference: fit the guide by minimising the
# Trace_ELBO loss of model vs. guide.
svi = SVI(model, guide, adam, loss=Trace_ELBO())

In [91]:
# Start from an empty parameter store so earlier runs don't leak into this fit.
pyro.clear_param_store()
for step in range(num_iterations):
    loss = svi.step(X_train, y_train)
    # Every 100 steps (starting with the first), report the loss divided by
    # the number of training rows.
    if not step % 100:
        print("[iteration %04d] loss: %.4f" % (step + 1, loss / len(train_tensor)))

[iteration 0001] loss: 3.1949
[iteration 0101] loss: 1.0819
[iteration 0201] loss: 0.9119
[iteration 0301] loss: 0.6070
[iteration 0401] loss: 0.4535
[iteration 0501] loss: 0.4700
[iteration 0601] loss: 0.7576
[iteration 0701] loss: 0.5230
[iteration 0801] loss: 0.5850
[iteration 0901] loss: 0.4364
[iteration 1001] loss: 0.7072
[iteration 1101] loss: 0.4160
[iteration 1201] loss: 0.6090
[iteration 1301] loss: 0.8364
[iteration 1401] loss: 0.6838


In [92]:
# Freeze the guide's parameters, then print every fitted parameter by name.
guide.requires_grad_(False)

for name in pyro.get_param_store().keys():
    print(name, pyro.param(name))

AutoDiagonalNormal.loc Parameter containing:
tensor([-4.1899, -0.2898, -0.5869,  0.0107,  1.5130,  1.4300, -0.5740,  1.0071,
         0.1779, -0.8020, -0.6576,  0.0130])
AutoDiagonalNormal.scale tensor([0.1085, 0.0187, 0.0248, 0.0263, 0.0229, 0.0178, 0.0241, 0.0243, 0.0162,
        0.0222, 0.0197, 0.0201])


In [97]:
# 25th/50th/75th percentiles of each latent site under the fitted guide.
guide.quantiles([0.25, 0.5, 0.75])

{'sigma': [tensor(0.1388), tensor(0.1492), tensor(0.1604)],
 'linear.weight': [tensor([[-0.3024, -0.6036, -0.0070,  1.4976,  1.4180, -0.5903,  0.9907,  0.1670,
           -0.8170, -0.6709]]),
  tensor([[-0.2898, -0.5869,  0.0107,  1.5130,  1.4300, -0.5740,  1.0071,  0.1779,
           -0.8020, -0.6576]]),
  tensor([[-0.2772, -0.5702,  0.0285,  1.5285,  1.4420, -0.5577,  1.0235,  0.1888,
           -0.7870, -0.6443]])],
 'linear.bias': [tensor([-0.0006]), tensor([0.0130]), tensor([0.0265])]}

In [95]:
from pyro.infer import Predictive


def summary(samples):
    """Summarise sampled values site by site.

    Args:
        samples: mapping from site name to a tensor whose leading dimension
            indexes the draws (as returned by ``Predictive``).

    Returns:
        dict mapping each site name to a dict with keys ``"mean"``, ``"std"``,
        ``"5%"`` and ``"95%"``, each reduced over the draw dimension (dim 0).
        The percentiles are empirical order statistics, not interpolated.
    """
    def empirical_quantile(values, q):
        # kthvalue is 1-indexed: clamp k into [1, n] so fewer than 20 draws
        # (where int(n * 0.05) == 0) doesn't request an invalid 0th value.
        n = len(values)
        k = min(max(int(n * q), 1), n)
        return values.kthvalue(k, dim=0)[0]

    site_stats = {}
    for site, values in samples.items():
        site_stats[site] = {
            "mean": torch.mean(values, 0),
            "std": torch.std(values, 0),
            "5%": empirical_quantile(values, 0.05),
            "95%": empirical_quantile(values, 0.95),
        }
    return site_stats


# Draw 800 samples with latents taken from the fitted guide. Keep only the
# weights, the predictive observations, and forward's return value ("_RETURN").
predictive = Predictive(
    model,
    guide=guide,
    num_samples=800,
    return_sites=("linear.weight", "obs", "_RETURN"),
)
samples = predictive(X_train)
pred_summary = summary(samples)

In [96]:
mu = pred_summary["_RETURN"]  # summary of the regression mean returned by forward
y = pred_summary["obs"]       # summary of the sampled observations
# One row per training example: mean and 5%/95% bands for mu and y, plus the truth.
predictions = pd.DataFrame(
    dict(
        mu_mean=mu["mean"],
        mu_perc_5=mu["5%"],
        mu_perc_95=mu["95%"],
        y_mean=y["mean"],
        y_perc_5=y["5%"],
        y_perc_95=y["95%"],
        true_y=y_train,
    )
)

In [98]:
from pyro import poutine
from pyro.poutine.util import prune_subsample_sites
import warnings


class Predict(torch.nn.Module):
    """nn.Module wrapper that draws one joint sample of the model's sites.

    ``forward`` samples the guide, replays the model against that draw, and
    returns the values of the model's stochastic sites (after subsample-site
    pruning) as a tuple ordered alphabetically by site name.
    """

    def __init__(self, model, guide):
        super().__init__()
        self.model = model
        self.guide = guide

    def forward(self, *args, **kwargs):
        # Draw latent values from the variational guide...
        guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
        # ...then run the model with those latents replayed, so only sites the
        # guide does not cover (e.g. "obs" when no y is passed) are sampled anew.
        replayed_model = poutine.replay(self.model, guide_trace)
        model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)
        # Sort by name so the returned tuple has a stable order.
        site_names = sorted(prune_subsample_sites(model_trace).stochastic_nodes)
        return tuple(model_trace.nodes[name]['value'] for name in site_names)

predict_fn = Predict(model, guide)
# Trace forward with the training inputs as the example. check_trace is off,
# presumably because every call draws fresh random samples, so re-running the
# trace for comparison would not reproduce the same outputs.
predict_module = torch.jit.trace_module(
    predict_fn,
    {"forward": (X_train,)},
    check_trace=False,
)

In [101]:
# Save next to the notebook (relative path) rather than at the filesystem root,
# which is typically not writable; then reload and run one prediction.
predict_module_path = 'reg_predict.pt'
torch.jit.save(predict_module, predict_module_path)
pred_loaded = torch.jit.load(predict_module_path)
pred_loaded(X_train)

(tensor([-0.0253]),
 tensor([[-0.2729, -0.6312, -0.0327,  1.5240,  1.4222, -0.5671,  1.0286,  0.1851,
          -0.8037, -0.6570]]),
 tensor([ 2.4579,  1.7805,  0.7715, -3.8812, -3.3339,  1.7881,  3.1425, -1.2856,
          0.1320,  2.2151,  0.2221,  1.3540, -0.2499, -2.7208,  4.2636,  2.5203,
         -4.3985,  0.1932, -1.9239, -1.6937, -1.8983,  2.6671,  3.8161, -0.2910,
          1.3351, -0.0052, -2.7304,  2.5826,  4.2411,  2.7677, -0.5513, -0.5916,
         -2.6256, -2.8155, -1.4596, -3.3038, -1.2420,  1.6937,  3.1560, -0.9714]),
 tensor(0.1311))