In [2]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import RichProgressBar
import torch
import torch.nn.functional as F
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

Read data, encode the categorical column `EJ`, impute missing values with medians. Then convert into PyTorch datasets and dataloaders.

In [3]:
# Prefix that is string-concatenated in front of each CSV file name below
# ("train.csv", "test.csv", "sample_submission.csv"). Because it is joined
# with `+`, a non-empty value must end with a path separator. The empty
# string resolves the files relative to the current working directory.
data_path = ""


def read_numpy(df_path):
    """Load a CSV file into float32 NumPy arrays.

    The file is read with its ``Id`` column as the index. The ``EJ`` column
    is replaced by a boolean flag that is True where the value equals
    ``"B"``; after the float32 cast this becomes 1.0 / 0.0.

    Args:
        df_path: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        ``(X, y)`` when the file has a ``Class`` column, where ``X`` holds
        every other column and ``y`` holds ``Class``; otherwise just the
        full feature matrix. All arrays are ``float32``.
    """
    frame = pd.read_csv(df_path, index_col="Id")
    frame["EJ"] = frame["EJ"].eq("B")

    # Unlabelled file: there is no target to split off.
    if "Class" not in frame.columns:
        return frame.to_numpy(dtype=np.float32)

    features = frame.drop(columns="Class").to_numpy(dtype=np.float32)
    labels = frame["Class"].to_numpy(dtype=np.float32)
    return features, labels


# Load the labelled data; missing values are still NaN at this point.
X_avail, y_avail = read_numpy(data_path + "train.csv")

# Split BEFORE fitting the imputer. Fitting the medians on all labelled rows
# and splitting afterwards lets validation rows influence the imputation
# statistics, which leaks information into training and makes validation
# metrics optimistic. The split itself is unchanged (same seed, same ratio).
X_train, X_val, y_train, y_val = train_test_split(
    X_avail, y_avail, test_size=0.2, random_state=42
)

# Learn the per-column medians from the training rows only, then apply the
# same fitted transform to every other split.
imputer = SimpleImputer(strategy="median")
X_train = imputer.fit_transform(X_train)
X_val = imputer.transform(X_val)
# Keep `X_avail` bound to an imputed copy of the full labelled matrix, as before.
X_avail = imputer.transform(X_avail)
X_test = imputer.transform(read_numpy(data_path + "test.csv"))
sample_df = pd.read_csv(data_path + "sample_submission.csv", index_col="Id")

X_train, y_train, X_val, y_val, X_test = map(
    torch.tensor, [X_train, y_train, X_val, y_val, X_test]
)

# Full-batch loaders: each loader yields its whole split as a single batch.
train_set = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_set, batch_size=len(train_set), num_workers=4)
val_set = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_set, batch_size=len(val_set), num_workers=4)

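Let's use a simple linear model as a baseline: a Bayesian logistic regression written in Pyro, with standard-normal priors on the weights and the bias.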

In [22]:
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam


class BayesianModel(PyroModule):
    """Bayesian linear classifier with standard-normal priors.

    A single ``nn.Linear(in_features, 1)`` layer produces one logit per row.
    Its weight and bias are Pyro sample sites with independent ``Normal(0, 1)``
    priors. The likelihood is a ``RelaxedBernoulliStraightThrough`` with a
    fixed temperature of 1000.

    Args:
        in_features: Number of input columns. Defaults to 56, the value that
            used to be hard-coded, so ``BayesianModel()`` behaves as before.
    """

    def __init__(self, in_features=56):
        super().__init__()
        self.line = PyroModule[nn.Linear](in_features, 1)
        # Replace the deterministic parameters by priors of matching shape;
        # `to_event` declares all parameter dimensions as one joint event.
        self.line.weight = PyroSample(
            dist.Normal(0, 1)
            .expand(self.line.weight.shape)
            .to_event(self.line.weight.dim())
        )
        self.line.bias = PyroSample(
            dist.Normal(0, 1)
            .expand(self.line.bias.shape)
            .to_event(self.line.bias.dim())
        )

    def forward(self, x, y=None):
        """Run the model on a batch, optionally conditioning on labels.

        Args:
            x: Feature tensor whose first dimension indexes instances and
                whose last dimension has ``in_features`` entries.
            y: Optional labels, one per instance, shaped ``(N,)`` or
                ``(N, 1)``. When omitted, the ``"obs"`` site is sampled
                instead of observed.

        Returns:
            The value of the ``"obs"`` sample site, one entry per instance:
            the rescaled labels when ``y`` is given, otherwise a draw from
            the likelihood.
        """
        # nn.Linear yields shape (N, 1). Drop the trailing axis so each logit
        # lines up with exactly one plate index. Keeping (N, 1) next to (N,)
        # labels silently broadcasts the log-density to (N, N), scoring every
        # logit against every label.
        logits = self.line(x).squeeze(-1)

        obs = None
        if y is not None:
            # Labels are squashed from {0, 1} to {0.1, 0.9}, presumably because
            # the relaxed-Bernoulli log-density is not finite at exactly 0 or 1.
            # The reshape accepts both (N,) and (N, 1) labels.
            obs = (0.1 + 0.8 * y).reshape(logits.shape)

        with pyro.plate("instances", len(x)):
            return pyro.sample(
                "obs",
                dist.RelaxedBernoulliStraightThrough(
                    temperature=torch.tensor(1000.0), logits=logits
                ),
                obs=obs,
            )


In [24]:
# Reset Pyro's global parameter store and seed its RNG so that re-running
# this cell starts from the same state every time.
pyro.clear_param_store()
pyro.set_rng_seed(1)

model = BayesianModel()

# Pull the first batch from the training loader and reuse it for every step.
x, y = next(iter(train_loader))
print(x.shape, y.shape)

# Mean-field normal guide over the model's latent sites, trained with SVI.
guide = AutoNormal(model)
optimizer = Adam({"lr": 0.001})
svi = SVI(model, guide, optimizer, Trace_ELBO())

num_steps = 2
log_every = 100
for step in range(num_steps):
    # Report the ELBO loss per label rather than summed over the batch.
    loss = svi.step(x, y) / y.numel()
    if step % log_every:
        continue
    print("step {} loss = {:0.4g}".format(step, loss))

tensor([[2.1792e+00, 5.6189e+03, 8.5200e+01, 5.8469e+02, 5.0253e+01, 7.2112e+00,
         2.5578e-02, 1.1835e+01, 1.2299e+00, 5.0224e+03],
        [4.0166e-01, 3.8636e+03, 1.3555e+02, 9.3406e+00, 8.1387e+00, 5.3774e+00,
         2.5578e-02, 1.0808e+01, 3.7459e+00, 3.3809e+03],
        [4.5294e-01, 2.3797e+03, 9.2472e+01, 4.1992e+01, 8.1387e+00, 7.2201e+00,
         2.5578e-02, 6.3020e+00, 1.2299e+00, 2.2115e+03],
        [1.9229e-01, 2.2356e+03, 8.5200e+01, 1.8356e+01, 1.3592e+01, 3.9245e+00,
         1.4129e-01, 9.7177e+00, 1.2299e+00, 4.6595e+03],
        [1.7092e-01, 1.0729e+03, 9.0687e+01, 8.4300e+00, 8.1387e+00, 4.5890e+00,
         2.5578e-02, 8.6652e+00, 1.2299e+00, 3.5968e+03],
        [1.8801e-01, 1.8761e+03, 8.5200e+01, 1.1273e+01, 8.1387e+00, 2.8172e+00,
         2.5578e-02, 8.6400e+00, 1.2299e+00, 3.7982e+03],
        [3.2475e-01, 3.8251e+03, 8.5200e+01, 3.0642e+01, 1.4531e+01, 8.4603e+00,
         2.5578e-02, 1.1035e+01, 2.7409e+00, 7.9881e+03],
        [6.1958e-01, 4.7115