In [None]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.autograd.functional import jacobian
from torch.utils.data import Subset
from torch import Tensor
from typing import Tuple, Callable
from itertools import chain
import copy
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch.cuda.amp import GradScaler
from torch.distributions import MultivariateNormal
from typing import Union

In [8]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

scaler = GradScaler(device)

Using cuda device


  scaler = GradScaler(device)


In [None]:
class LangevinSampling:
    """Langevin-style sampler that keeps a persistent buffer of chains.

    Every update moves a chain along the gradient of the model output with
    respect to the chain and adds bounded uniform noise::

        x <- x + step_size * d model(x)/dx + noise_scale * U[-1, 1)

    Chains are initialised (and refreshed) uniformly in ``[-1, 1)``.

    Args:
        model: Callable mapping a batch of shape ``[B, *sample_size]`` to one
            scalar per sample (``B`` elements in total). Samples are assumed to
            be scored independently of each other, so that the gradient of the
            summed output equals the per-chain gradients.
        batch_size: Number of chains held in the persistent buffer.
        sample_size: Shape of a single sample (without the batch dimension).
        refresh_prob: Probability with which each persistent chain is
            re-initialised at the start of ``persist_sample``. A float or a
            tensor broadcastable to ``[batch_size]``.
        n_steps: Default number of update steps per sampling call.
        step_size: Multiplier applied to the gradient in each update.
        noise_scale: Half-width of the uniform noise added in each update.
        device: Device on which the chains are stored.
    """

    def __init__(
            self,
            model: Callable,
            batch_size: int,
            sample_size: list,
            refresh_prob: float,
            n_steps: int,
            step_size: float,
            noise_scale: float,
            device: str,
        ):
        self.model = model
        self.batch_size = batch_size
        self.sample_size = sample_size
        self.n_steps = n_steps
        self.step_size = step_size
        self.noise_scale = noise_scale
        self.device = device
        # A plain float has no ``.to``; ``as_tensor`` accepts floats and tensors
        # alike and places the value next to the chains.
        self.refresh_prob = torch.as_tensor(
            refresh_prob, dtype=torch.float32, device=device
        )
        self.buffer = self._init_chains(batch_size)

    def _init_chains(self, n_chains: int) -> Tensor:
        """Return ``n_chains`` fresh chains drawn uniformly from ``[-1, 1)``."""
        return 2 * torch.rand([n_chains, *self.sample_size], device=self.device) - 1

    def _get_refresh_indices(self) -> Tensor:
        """Return the 1-D indices of the persistent chains selected for refresh."""
        # One Bernoulli(refresh_prob) draw per chain, i.e. along the batch axis.
        mask = torch.rand(self.batch_size, device=self.device) < self.refresh_prob
        return torch.nonzero(mask, as_tuple=True)[0]

    def _refresh_chains(self) -> None:
        """Re-initialise the randomly selected persistent chains in place."""
        indexes_to_refresh = self._get_refresh_indices()
        self.buffer[indexes_to_refresh] = self._init_chains(indexes_to_refresh.numel())

    def _compute_per_chain_grads(self, x: Tensor) -> Tensor:
        """Return d model(x)[i] / d x[i] for every chain ``i``.

        The input tensor is left untouched: the gradient is taken with respect
        to a detached copy, so ``x`` is never flagged as requiring grad.

        Raises:
            ValueError: If the model does not return exactly one value per chain.
        """
        x_in = x.detach().requires_grad_(True)
        # Sampling must work even when the caller is inside ``torch.no_grad()``.
        with torch.enable_grad():
            y = self.model(x_in)
            if y.numel() != x_in.size(0):
                raise ValueError(
                    "model must return one scalar per sample, got output of shape "
                    f"{tuple(y.shape)} for a batch of {x_in.size(0)}"
                )
            # With independently scored samples, differentiating the sum gives
            # every per-chain gradient in a single backward pass.
            (grads,) = torch.autograd.grad(y.sum(), x_in)
        return grads

    def _langevin_step(self, x: Tensor) -> Tensor:
        """Apply one update to ``x`` in place and return it."""
        noise = self.noise_scale * (2 * torch.rand_like(x) - 1)
        grad = self.step_size * self._compute_per_chain_grads(x)
        # Update the tensor that was passed in (not unconditionally the buffer),
        # outside of autograd so no graph accumulates across steps.
        with torch.no_grad():
            x.add_(grad + noise)
        return x

    def persist_sample(self) -> Tensor:
        """Refresh some persistent chains, advance all of them, and return them.

        Returns:
            A detached copy of the buffer, so later sampling calls do not
            modify the tensor handed to the caller.
        """
        self._refresh_chains()
        for _ in range(self.n_steps):
            self._langevin_step(self.buffer)

        return self.buffer.detach().clone()

    def sample(self, batch_size: int, n_steps: Union[int, None] = None) -> Tensor:
        """Run fresh chains that are independent of the persistent buffer.

        Args:
            batch_size: Number of chains to run.
            n_steps: Number of update steps; defaults to ``self.n_steps``.

        Returns:
            Tensor of shape ``[batch_size, *sample_size]`` on ``self.device``.
        """
        if n_steps is None:
            n_steps = self.n_steps

        x = self._init_chains(batch_size)
        for _ in range(n_steps):
            self._langevin_step(x)

        return x
