In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import torch

sns.set(font_scale=1.2, style='whitegrid')

# Fix the global RNGs so that Restart & Run All reproduces the toy data,
# the DataLoader shuffling, the SGLD chains and therefore every figure.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Make the project package under ../src importable (e.g. ``data_aug``).
src_dir = (Path('..') / 'src').resolve()
if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

# Required setting: raises KeyError if the DATADIR environment variable is unset.
data_dir = os.environ['DATADIR']

## Toy Dataset

In [None]:
import torch
from torch.utils.data import Dataset


class ToyNoisyDataset(Dataset):
    '''1-D toy regression data with an optional spurious noise feature.

    Training inputs are drawn uniformly from [-1.5, 1.5]; test inputs are a
    regular grid of 100 points on [-6, 6]. Targets are ``x + obs_scale * eps``
    with ``eps ~ N(0, 1)``. The "spurious" view appends a second column of
    standard-normal noise that is independent of the target.

    Mode setters (each returns ``self`` so calls can be chained):
      * ``train()`` / ``eval()``: serve the training or the test split.
      * ``clean(mode)``: ``True`` -> inputs are ``[x]``; ``False`` -> ``[x, noise]``.
      * ``aug(mode)``: for spurious *training* data only, add fresh
        ``aug_scale``-scaled Gaussian noise to the spurious column on every access.

    Args:
      n: number of training points.
      obs_scale: observation-noise standard deviation (stored as ``self.sigma``).
      aug_scale: standard deviation of the augmentation noise.
    '''
    def __init__(self, n=10, obs_scale=1., aug_scale=1.):
        super().__init__()

        self._train = True
        self._clean = True
        self._aug = False
        self.aug_scale = aug_scale

        self.sigma = obs_scale

        x = 1.5 * (2 * torch.rand(n, 1) - 1)

        self.targets = x + self.sigma * torch.randn_like(x)
        self.clean_data = x
        self.spur_data = torch.cat([x, torch.randn_like(x)], dim=-1)

        x = torch.linspace(-6, 6, 100).unsqueeze(-1)

        self.test_targets = x + self.sigma * torch.randn_like(x)
        self.clean_test_data = x
        self.spur_test_data = torch.cat([x, torch.randn_like(x)], dim=-1)

    def train(self):
        self._train = True
        return self

    def eval(self):
        self._train = False
        return self

    def clean(self, mode=True):
        self._clean = mode
        return self

    def aug(self, mode=True):
        self._aug = mode
        return self

    def __len__(self):
        if self._train:
            return len(self.targets)
        return len(self.test_targets)

    def __getitem__(self, index):
        if self._train:
            x = self.clean_data[index] if self._clean else self.spur_data[index]
            if not self._clean and self._aug:
                # Integer/slice indexing returns a *view* of the stored tensor, so an
                # in-place update here would permanently accumulate noise into
                # ``spur_data`` and make repeated draws alias one another.
                # Augment a private copy instead.
                x = x.clone()
                x[..., -1] += self.aug_scale * torch.randn_like(x[..., -1])
            return x, self.targets[index]

        # Index the clean test inputs as well (previously the whole test set was
        # returned regardless of ``index``).
        x = self.clean_test_data[index] if self._clean else self.spur_test_data[index]
        return x, self.test_targets[index]

dataset = ToyNoisyDataset(n=10, obs_scale=1.).train()

fig, ax = plt.subplots()

# The dataset starts in clean mode, so these are the 1-D training inputs and
# their noisy targets.
x, y = dataset[:]
ax.scatter(x.numpy(), y.numpy())
ax.set(ylim=[-15, 15], xlim=[-6, 6], xlabel='x', ylabel='y', title='Toy training data')
fig.show()

## Inference

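Closed-form posterior and posterior predictive for Bayesian linear regression with prior $w \sim \mathcal{N}(0, \sigma_p^2 I)$ and Gaussian likelihood $y \mid X, w \sim \mathcal{N}(Xw, \sigma^2 I)$. These serve as the reference for the SGLD samples below.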

### Exact

In [None]:
def post_pred(X, y, X_test, obs_scale=1., prior_scale=1.):
    '''Exact Bayesian linear regression: weight posterior and predictive.

    Model: w ~ N(0, prior_scale^2 I),  y | X, w ~ N(X w, obs_scale^2 I).

    Args:
      X: (n, d) training design matrix.
      y: (n, 1) training targets.
      X_test: (m, d) test design matrix.
      obs_scale: observation-noise standard deviation.
      prior_scale: prior standard deviation of every weight.

    Returns:
      (test_mean, test_var, post_mean, post_prec):
        test_mean: (m, 1) predictive mean ``X_test @ post_mean``.
        test_var: (m, 1) diagonal of the predictive covariance, including the
          observation noise ``obs_scale**2``.
        post_mean: posterior mean of the weights.
        post_prec: (d, d) posterior precision matrix.
    '''
    noise_var = obs_scale**2
    num_features = X.size(-1)

    likelihood_prec = (X.T @ X).div(noise_var)
    prior_prec = torch.eye(num_features).div(prior_scale**2)
    precision = likelihood_prec + prior_prec

    weight_mean = torch.linalg.solve(precision, X.T @ y).div(noise_var)

    pred_mean = X_test @ weight_mean

    # diag(X_test @ Sigma @ X_test.T) without materialising the (m, m) matrix.
    cov_times_xt = torch.linalg.solve(precision, X_test.T)
    pred_var = (X_test * cov_times_xt.T).sum(dim=-1, keepdim=True) + noise_var

    return pred_mean, pred_var, weight_mean, precision

### SGLD (Stochastic Gradient Langevin Dynamics)

In [None]:
from copy import deepcopy
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.distributions import Normal

from data_aug.optim import SGLD
from data_aug.optim.lr_scheduler import CosineLR

class GaussianPriorAugmentedRegressionLoss(nn.Module):
  '''Gaussian likelihood + prior.

  ``forward`` returns

      -mean_i log N(Y_i | obs_mean_i, obs_scale^2)
        - (1 / N) * sum_p sum log N(p | 0, prior_scale^2)

  i.e. the negative log joint density divided by the dataset size ``N``.
  To get the unbiased density estimate, multiply by N.

  Args:
    params: iterable of parameter tensors that receive the prior. It is
      materialised into a list, so passing ``module.parameters()`` (a one-shot
      generator) keeps the prior active on every call, not just the first.
    aug_scale: stored as ``self.omega``; not read by ``forward``.
    obs_scale: standard deviation of the Gaussian likelihood.
    prior_scale: standard deviation of the zero-mean Gaussian prior.
  '''
  def __init__(self, params, aug_scale=10, obs_scale=1, prior_scale=1):
    super().__init__()

    # A generator would be exhausted after the first ``forward`` and the prior
    # term would then silently vanish; store a reusable list instead.
    self.theta = list(params)
    self.omega = aug_scale
    self.sigma = prior_scale
    self.sigma_obs = obs_scale

  def forward(self, obs_mean, Y, N=1):
    '''Per-example negative log joint: batch-mean NLL plus prior / ``N``.'''
    p_obs = Normal(obs_mean, self.sigma_obs)
    energy = -p_obs.log_prob(Y).mean()

    for p in self.theta:
      prior = Normal(torch.zeros_like(p), self.sigma)
      # Out-of-place accumulation keeps the autograd graph free of in-place ops.
      energy = energy - prior.log_prob(p).sum().div(N)

    return energy


def run_sgld(net, epochs, train_loader, criterion, sgld, sgld_scheduler=None,
             N=None, burn_in=1000, thin=100):
  '''Run SGLD on ``net`` and return the collected parameter samples.

  Every mini-batch: zero the gradients, evaluate ``criterion(net(X), Y, N=N)``,
  backpropagate and step ``sgld``.

  Without a scheduler, every optimizer step is a candidate sample. The first
  ``burn_in`` are discarded and every ``thin``-th one after that is kept. This
  selects exactly the states ``all_steps[burn_in::thin]`` but deep-copies only
  the kept ones.

  With a scheduler, a step is taken with ``noise=False`` while
  ``sgld_scheduler.get_last_beta() < sgld_scheduler.beta``. Otherwise a regular
  step is taken, and a sample is stored whenever
  ``sgld_scheduler.should_sample()``. ``sgld_scheduler.step()`` runs after every
  mini-batch.

  Args:
    net: model whose ``state_dict`` is sampled.
    epochs: number of passes over ``train_loader``.
    train_loader: iterable of ``(X, Y)`` mini-batches.
    criterion: callable ``criterion(prediction, Y, N=...)`` returning a scalar loss.
    sgld: optimizer exposing ``zero_grad()`` and ``step()``.
    sgld_scheduler: optional scheduler driving noise-free steps and sampling.
    N: dataset size forwarded to ``criterion``. Defaults to
      ``len(train_loader.dataset)`` rather than an implicit notebook global.
    burn_in: initial steps to discard (scheduler-free mode only).
    thin: keep every ``thin``-th step after burn-in (scheduler-free mode only).

  Returns:
    List of deep-copied ``state_dict``s.

  Raises:
    ValueError: if ``burn_in < 0`` or ``thin < 1``.
  '''
  if N is None:
    N = len(train_loader.dataset)
  if burn_in < 0 or thin < 1:
    raise ValueError(
        f'need burn_in >= 0 and thin >= 1, got burn_in={burn_in}, thin={thin}')

  samples = []
  n_steps = 0  # optimizer steps taken so far (scheduler-free mode)

  for _ in tqdm(range(epochs)):
    net.train()
    for (X, Y) in train_loader:
      sgld.zero_grad()

      f_hat = net(X)
      loss = criterion(f_hat, Y, N=N)

      loss.backward()

      if sgld_scheduler is None:
        sgld.step()
        # Same selection as storing every step and slicing [burn_in::thin].
        if n_steps >= burn_in and (n_steps - burn_in) % thin == 0:
          samples.append(deepcopy(net.state_dict()))
        n_steps += 1
      else:
        if sgld_scheduler.get_last_beta() < sgld_scheduler.beta:
          sgld.step(noise=False)
        else:
          sgld.step()

          if sgld_scheduler.should_sample():
            samples.append(deepcopy(net.state_dict()))

        sgld_scheduler.step()

  return samples

## Results

### With Clean Data

In [None]:
# Compare the exact posterior predictive with SGLD on the clean 1-D inputs.
fig, ax = plt.subplots(figsize=(7,5))

# Training points and the evaluation grid, both without the spurious column.
# Note: this leaves ``dataset`` in eval mode; it is switched back to train below.
_X, _y = dataset.train().clean()[:]
_X_test, _ = dataset.eval().clean()[:]

# A column of ones is appended so the last weight plays the role of the bias,
# matching the weight + bias of the nn.Linear model trained with SGLD below.
exact_mean, exact_var, post_mean, post_prec = post_pred(
    torch.cat([_X, torch.ones(len(_X), 1)], dim=-1), _y,
    torch.cat([_X_test, torch.ones(len(_X_test), 1)], dim=-1),
    obs_scale=dataset.sigma)

# Exact posterior mean and covariance (inverse precision) of [weight, bias].
print(post_mean)
print(torch.linalg.pinv(post_prec))

# Exact predictive mean with +/- 2 standard-deviation bands (green).
ax.scatter(_X[:, 0].numpy(), _y[:, 0].numpy(), alpha=.2, c='black')
ax.plot(_X_test[:, 0].numpy(), exact_mean[:, 0].numpy(), label='Exact', c='green')
ax.plot(_X_test[:, 0].numpy(), (exact_mean - 2 * exact_var.sqrt())[:, 0].numpy(),
             linestyle='dashed', c='green')
ax.plot(_X_test[:, 0].numpy(), (exact_mean + 2 * exact_var.sqrt())[:, 0].numpy(),
             linestyle='dashed', c='green')
# ax.fill_between(_X_test[:, 0].numpy(), y1=(exact_mean - 2 * exact_var.sqrt())[:, 0].numpy(),
#                      y2=(exact_mean + 2 * exact_var.sqrt())[:, 0].numpy(), alpha=.2, color='green')

######################
# SGLD on the same clean training data.

dataset = dataset.train().clean()

# N: training-set size, used for the 1/N temperature and the criterion's prior scaling.
N = len(dataset)
epochs = 50000
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# prior_scale=1 matches post_pred's default prior, so both target the same posterior.
net = nn.Linear(_X.size(-1), 1)
criterion = GaussianPriorAugmentedRegressionLoss(net.parameters(), prior_scale=1,
                                                 obs_scale=dataset.sigma)

# NOTE(review): temperature=1/N presumably compensates for the loss being the
# per-example energy (divided by N); confirm against data_aug.optim.SGLD.
sgld = SGLD(net.parameters(), lr=1e-3, momentum=0, temperature=1 / N)
sgld_scheduler = None #CosineLR(sgld, n_cycles=1, n_samples=100, T_max=len(train_loader) * epochs)

# With no scheduler, run_sgld returns a burned-in, thinned subset of per-step states.
samples = run_sgld(net, epochs, train_loader, criterion, sgld, sgld_scheduler)

# Predictions of every sampled model on the test grid: (num_samples, m, 1).
_y_test = []
with torch.no_grad():
  for s in samples:
    net.load_state_dict(s)
    _y_test.append(net(_X_test))
  _y_test = torch.stack(_y_test)

  # Rebind ``samples`` to a (num_samples, 2) tensor of [weight, bias];
  # ``.item()`` only works because there is a single input feature.
  samples = torch.Tensor([[s['weight'].item() for s in samples],
                         [s['bias'].item() for s in samples]]).T

# Law of total variance: observation noise + spread of the sampled means.
pred_mean = _y_test.mean(dim=0)
pred_var = dataset.sigma**2 + _y_test.pow(2).mean(dim=0) \
           - pred_mean.pow(2)

# Empirical posterior mean / covariance of [weight, bias]; compare with the exact values above.
print(samples.mean(dim=0, keepdim=True))
print(samples.T.cov())

# SGLD predictive mean with +/- 2 standard-deviation bands (red).
ax.plot(_X_test[:, 0].numpy(), pred_mean[:, 0].numpy(), label='SGLD', c='red')
ax.plot(_X_test[:, 0].numpy(), (pred_mean - 2 * pred_var.sqrt())[:, 0].numpy(),
             linestyle='dotted', c='red')
ax.plot(_X_test[:, 0].numpy(), (pred_mean + 2 * pred_var.sqrt())[:, 0].numpy(),
             linestyle='dotted', c='red')
ax.set(title='Clean Data')
ax.legend()

fig.show()
# fig.savefig('clean_data.png', bbox_inches='tight')
# fig.savefig('clean_data.png', bbox_inches='tight')

### With Spurious Data

In [None]:
# Same comparison when each input carries an extra spurious noise column: [x, noise].
fig, ax = plt.subplots(figsize=(7,5))

# Note: this leaves ``dataset`` in eval mode; it is switched back to train below.
_X, _y = dataset.train().clean(False)[:]
_X_test, _ = dataset.eval().clean(False)[:]

# Only the informative x column is plotted.
ax.scatter(_X[:, 0].numpy(), _y[:, 0].numpy(), alpha=.2, c='black')

# Exact posterior over [w_x, w_spurious, bias] (ones column appended as the bias feature).
exact_mean, exact_var, post_mean, post_prec = post_pred(
    torch.cat([_X, torch.ones(len(_X), 1)], dim=-1), _y,
    torch.cat([_X_test, torch.ones(len(_X_test), 1)], dim=-1),
    obs_scale=dataset.sigma)

print(post_mean)
print(torch.linalg.pinv(post_prec))

# Exact predictive mean with +/- 2 standard-deviation bands (green).
ax.plot(_X_test[:, 0].numpy(), exact_mean[:, 0].numpy(), label='Exact', c='green')
ax.plot(_X_test[:, 0].numpy(), (exact_mean - 2 * exact_var.sqrt())[:, 0].numpy(),
             linestyle='dashed', c='green')
ax.plot(_X_test[:, 0].numpy(), (exact_mean + 2 * exact_var.sqrt())[:, 0].numpy(),
             linestyle='dashed', c='green')

#########
# SGLD on the spurious-feature training data (no augmentation).

dataset = dataset.train().clean(False)

# N: training-set size, used for the 1/N temperature and the criterion's prior scaling.
N = len(dataset)
epochs = 50000
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Two input features -> weight of shape (1, 2) plus a bias.
net = nn.Linear(_X.size(-1), 1)
criterion = GaussianPriorAugmentedRegressionLoss(net.parameters(), prior_scale=1,
                                                 obs_scale=dataset.sigma)

# Smaller step size than in the clean-data cell.
sgld = SGLD(net.parameters(), lr=3e-5, momentum=0, temperature=1 / N)
sgld_scheduler = None #CosineLR(sgld, n_cycles=1, n_samples=100, T_max=len(train_loader) * epochs)

samples = run_sgld(net, epochs, train_loader, criterion, sgld, sgld_scheduler)

# Predictions of every sampled model on the test grid: (num_samples, m, 1).
_y_test = []
with torch.no_grad():
  for s in samples:
    net.load_state_dict(s)
    _y_test.append(net(_X_test))
  _y_test = torch.stack(_y_test)

  # Rebind ``samples`` to a (num_samples, 3) tensor with columns [w_x, w_spurious, bias].
  samples = torch.cat([
    torch.cat([s['weight'] for s in samples], dim=0),
    torch.cat([s['bias'] for s in samples], dim=0).unsqueeze(-1)], dim=-1)

# Law of total variance: observation noise + spread of the sampled means.
pred_mean = _y_test.mean(dim=0)
pred_var = dataset.sigma**2 + _y_test.pow(2).mean(dim=0) \
           - pred_mean.pow(2)

print(samples.mean(dim=0, keepdim=True))
print(samples.T.cov())

# SGLD predictive mean with +/- 2 standard-deviation bands (red).
ax.plot(_X_test[:, 0].numpy(), pred_mean[:, 0].numpy(), label='SGLD', c='red')
ax.plot(_X_test[:, 0].numpy(), (pred_mean - 2 * pred_var.sqrt())[:, 0].numpy(),
             linestyle='dotted', c='red')
ax.plot(_X_test[:, 0].numpy(), (pred_mean + 2 * pred_var.sqrt())[:, 0].numpy(),
             linestyle='dotted', c='red')

ax.set(title='Spurious Data')
ax.legend()

fig.show()
# fig.savefig('spur_data.png', bbox_inches='tight')

### With Spurious Data + Augmentation

In [None]:
# Spurious features with augmentation: the spurious column is re-noised on each access.
fig, ax = plt.subplots(figsize=(7,5))

# The exact posterior treats ``n_aug`` augmented copies of the training set as
# observed data, each with the full observation likelihood.
# NOTE(review): this relies on every ``[:]`` call returning a fresh, independent
# augmentation; confirm ToyNoisyDataset.__getitem__ does not add the noise in place
# to the stored spurious column (otherwise all copies alias one cumulatively-noised tensor).
n_aug = 5
_X, _y = [], []
for _ in range(n_aug):
    _A, _b = dataset.train().clean(False).aug()[:]
    _X.append(_A)
    _y.append(_b)
_X, _y = torch.cat(_X, dim=0), torch.cat(_y, dim=0)
# Only the first copy (len(dataset) rows, train mode) is plotted.
ax.scatter(_X[:len(dataset), 0].numpy(), _y[:len(dataset), 0].numpy(), alpha=.2, c='black')

# ``.aug()`` has no effect here: the eval branch of __getitem__ never augments.
_X_test, _ = dataset.eval().clean(False).aug()[:]

exact_mean, exact_var, post_mean, post_prec = post_pred(
    torch.cat([_X, torch.ones(len(_X), 1)], dim=-1), _y,
    torch.cat([_X_test, torch.ones(len(_X_test), 1)], dim=-1),
    obs_scale=dataset.sigma)
print(post_mean)
print(torch.linalg.pinv(post_prec))

# Exact predictive mean with +/- 2 standard-deviation bands (green).
ax.plot(_X_test[:, 0].numpy(), exact_mean[:, 0].numpy(), label='Exact', c='green')
ax.plot(_X_test[:, 0].numpy(), (exact_mean - 2 * exact_var.sqrt())[:, 0].numpy(),
             linestyle='dashed', c='green')
ax.plot(_X_test[:, 0].numpy(), (exact_mean + 2 * exact_var.sqrt())[:, 0].numpy(),
             linestyle='dashed', c='green')

######################
# SGLD where each mini-batch draws a fresh augmentation of the spurious column.

# Return value unused: the setters switch the dataset's modes in place.
dataset.train().clean(False).aug()

# N: training-set size, used for the 1/N temperature and the criterion's prior scaling.
N = len(dataset)
epochs = 50000
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

net = nn.Linear(_X.size(-1), 1)
# NOTE(review): scaling obs_scale by sqrt(epochs) multiplies the likelihood variance
# by ``epochs``, i.e. downweights every data term by 1/epochs, whereas the exact
# solution above uses n_aug full-weight copies. The two curves may therefore not
# target the same posterior; confirm this tempering is intended. (pred_var below
# also uses the unscaled dataset.sigma.)
criterion = GaussianPriorAugmentedRegressionLoss(net.parameters(), prior_scale=1,
                                                 obs_scale=dataset.sigma * np.sqrt(epochs))
sgld = SGLD(net.parameters(), lr=3e-5, momentum=0, temperature=1 / N)
sgld_scheduler = None #CosineLR(sgld, n_cycles=1, n_samples=100, T_max=len(train_loader) * epochs)

samples = run_sgld(net, epochs, train_loader, criterion, sgld, sgld_scheduler)

# Predictions of every sampled model on the test grid: (num_samples, m, 1).
_y_test = []
with torch.no_grad():
  for s in samples:
    net.load_state_dict(s)
    _y_test.append(net(_X_test))
  _y_test = torch.stack(_y_test)

  # Rebind ``samples`` to a (num_samples, 3) tensor with columns [w_x, w_spurious, bias].
  samples = torch.cat([
    torch.cat([s['weight'] for s in samples], dim=0),
    torch.cat([s['bias'] for s in samples], dim=0).unsqueeze(-1)], dim=-1)

# Law of total variance: observation noise + spread of the sampled means.
pred_mean = _y_test.mean(dim=0)
pred_var = dataset.sigma**2 + _y_test.pow(2).mean(dim=0) \
           - pred_mean.pow(2)

print(samples.mean(dim=0, keepdim=True))
print(samples.T.cov())

# SGLD predictive mean with +/- 2 standard-deviation bands (red).
ax.plot(_X_test[:, 0].numpy(), pred_mean[:, 0].numpy(), label='SGLD', c='red')
ax.plot(_X_test[:, 0].numpy(), (pred_mean - 2 * pred_var.sqrt())[:, 0].numpy(),
             linestyle='dotted', c='red')
ax.plot(_X_test[:, 0].numpy(), (pred_mean + 2 * pred_var.sqrt())[:, 0].numpy(),
             linestyle='dotted', c='red')

ax.set(title='Spurious Data + Augmentation')
ax.legend()

fig.show()
# fig.savefig('spur_data_aug.png', bbox_inches='tight')