In [14]:
import torch
from torch.autograd import grad
import numpy as np

In [3]:
def potential(x):
    """Quadratic potential: the sum of the squared entries of ``x`` (scalar tensor)."""
    return x.pow(2).sum()

In [None]:
x = torch.tensor([1.0, 2.0], dtype=torch.float64)  # starting point for the gradient/doubling experiment

In [28]:
# Repeatedly take the gradient of `potential` at x, then double x in place.
# NOTE: x is mutated, so re-running this cell compounds the doubling
# (x grows by 2**1000 per run; a second run overflows float64 to inf).
x.requires_grad_()  # must be set *before* potential(x) is evaluated, or no graph is recorded
for _ in range(1000):
    # Scoped grad mode: nothing leaks into later cells, even on error.
    with torch.enable_grad():
        grad_vec = grad(potential(x), x)[0]
    # An in-place update of a grad-requiring leaf is only permitted with autograd off.
    with torch.no_grad():
        x.add_(x)

In [26]:
# `grad` returns a tuple; the loop cell already indexed it with [0], so this is the first gradient component.
# NOTE(review): the saved output below shows two values, i.e. it was produced while grad_vec still held
# the raw tuple (In [26] ran before In [28]) — stale; re-run to refresh.
grad_vec[0]

tensor([2.1430e+301, 4.2860e+301], dtype=torch.float64)

In [23]:
# NOTE(review): the saved output below is a 1-tuple, which predates the loop cell's `[0]` indexing
# (In [23] ran before In [28]); after a fresh Run All this displays a plain tensor.
grad_vec

(tensor([1.0715e+301, 2.1430e+301], dtype=torch.float64),)

In [32]:
import torch
import numpy as np
from torch.autograd import grad

In [64]:
def compute_grad(fn, x):
    """Return the gradient of the scalar ``fn(x)`` with respect to ``x``.

    Parameters
    ----------
    fn : callable
        Maps a tensor to a scalar tensor.
    x : torch.Tensor
        Point at which to differentiate. If it does not already require grad,
        a detached copy that does is differentiated instead of raising.

    Returns
    -------
    torch.Tensor
        ``d fn / d x`` with the same shape as ``x``; it carries no graph
        (``create_graph`` is not requested).

    Side effects
    ------------
    Leaves autograd globally *disabled* on return, as before; existing callers
    may rely on that so later tensor arithmetic records no graph. The
    ``finally`` makes this hold even when ``fn`` raises.
    """
    torch.set_grad_enabled(True)
    try:
        if not x.requires_grad:
            x = x.detach().requires_grad_()
        return grad(fn(x), x)[0]
    finally:
        torch.set_grad_enabled(False)

In [80]:
class BPS:
    """Bouncy Particle Sampler with pluggable event-time and potential functions.

    Each iteration of ``run`` asks ``bounce_fn`` for the time to the next event,
    moves in a straight line for that time, then reflects the velocity off the
    gradient of ``potential_fn`` at the new position (obtained via
    ``compute_grad``).

    Parameters
    ----------
    x : array-like
        Initial position; stored as a float32 tensor.
    v : array-like
        Initial velocity; stored as a float32 tensor.
    potential_fn : callable
        Maps a position tensor to a scalar. Only the direction of its gradient
        matters (it is normalised in the reflection).
    bounce_fn : callable
        ``bounce_fn(x, v)`` -> time until the next event.

    Notes
    -----
    Positions, velocities and event times are computed under
    ``torch.no_grad()``. No autograd graph is recorded between steps,
    whatever the global grad mode is.
    """

    def __init__(self, x, v, potential_fn, bounce_fn):
        super().__init__()
        self.dim = len(x)
        self.x = torch.tensor(x, requires_grad=True, dtype=torch.float)
        self.v = torch.tensor(v, dtype=torch.float)

        self.potential_fn = potential_fn
        self.bounce_fn = bounce_fn

    def propagate_state(self, t):
        """Advance the position along the current velocity for time ``t``."""
        # Without no_grad, an event time that depends on a grad-requiring x
        # would link each new x to the previous step's graph, growing memory
        # with every iteration.
        with torch.no_grad():
            self.x = self.x.detach() + self.v * t

    def simulate_new_event(self):
        """Return the time until the next bounce, as computed by ``bounce_fn``."""
        with torch.no_grad():
            return self.bounce_fn(self.x.detach(), self.v)

    def event_transition(self):
        """Reflect the velocity: v <- v - 2 (g.v) g / (g.g), with g = grad potential_fn(x).

        If g is exactly zero the reflection is undefined (0/0). In that case v is
        left unchanged rather than becoming NaN.
        """
        potential_grad = compute_grad(self.potential_fn, self.x.requires_grad_())
        with torch.no_grad():
            grad_sq_norm = potential_grad.dot(potential_grad)
            if grad_sq_norm > 0:
                self.v = self.v - 2 * potential_grad.dot(self.v) * potential_grad / grad_sq_norm

    def get_state(self):
        """Return ``(x, v)`` as NumPy arrays, detached from autograd."""
        return self.x.detach().numpy(), self.v.detach().numpy()

    def run(self, num_iterations):
        """Simulate ``num_iterations`` events and return the post-event states.

        Returns
        -------
        list of (np.ndarray, np.ndarray)
            ``(x, v)`` recorded after each reflection.
        """
        results = []
        for _ in range(num_iterations):
            new_t = self.simulate_new_event()
            self.propagate_state(new_t)
            self.event_transition()
            results.append(self.get_state())
        return results


def gaussian_bounce(mu, Sig):
    """Build an exact BPS event-time sampler for a Gaussian target N(mu, Sig).

    For U(x) = 0.5 (x - mu)^T Sig^{-1} (x - mu), the bounce rate along the ray
    x + s v is max(0, a + s b), with a = v^T Sig^{-1} (x - mu) and
    b = v^T Sig^{-1} v. Setting the integrated rate equal to -log(eps),
    eps ~ U(0, 1], and solving for s gives the closed forms below.
    (The code works with xv = a / 2 and vv = b / 2.)

    Parameters
    ----------
    mu : torch.Tensor
        Mean vector (1-D).
    Sig : torch.Tensor
        Covariance matrix (2-D, invertible).

    Returns
    -------
    callable
        ``func(x, v)`` -> first event time as a 0-d tensor.
    """
    inv_Sig = torch.inverse(Sig)
    unif_dist = torch.distributions.Uniform(0, 1)

    def func(x, v):
        # Uniform draws lie in [0, 1); flipping to (0, 1] keeps log(eps) finite.
        # eps == 0 would otherwise give an infinite event time and inf/NaN positions.
        epsilon = 1 - unif_dist.sample()
        vv = 0.5 * (v @ inv_Sig @ v)
        xv = 0.5 * (v @ inv_Sig @ (x - mu))
        if xv < 0:
            # Moving downhill: the rate stays zero until the minimum along the ray.
            return (-xv + torch.sqrt(-vv * torch.log(epsilon))) / vv
        return (-xv + torch.sqrt(xv ** 2 - vv * torch.log(epsilon))) / vv

    return func



# Reproducible demo: seed numpy (initial state) and torch (uniform draws
# inside the bounce sampler).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

dim = 1000
Sig = torch.eye(dim, dtype=torch.float)
mu = torch.zeros(dim, dtype=torch.float)

x = np.random.random(dim)
v = np.random.random(dim)

inv_Sig = torch.inverse(Sig)


def potential_fn(x):
    """Gaussian potential 0.5 (x - mu)^T Sig^{-1} (x - mu).

    This is the same target whose event times gaussian_bounce(mu, Sig) samples.
    Keep the two in sync: the reflection and the bounce rate must describe one
    distribution.
    """
    centered = x - mu
    return 0.5 * (centered @ inv_Sig @ centered)


bounce_fn = gaussian_bounce(mu, Sig)

bps = BPS(x, v, potential_fn, bounce_fn)

In [81]:
%%timeit 
bps.run(num_iterations=1)  # one event per timed call; NOTE: every call also advances bps's (x, v)

1.35 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [63]:
# Current sampler position (a tensor with bps.dim entries).
# NOTE(review): the saved output below has only 2 entries, so it predates the dim = 1000 setup
# (In [63] ran before In [80]) — stale; re-run to refresh.
bps.x

tensor([-0.2042, -1.3123])