# Reparameterization Trick

https://medium.com/@llionj/the-reparameterization-trick-4ff30fe92954

In [1]:
import torch

from torch.optim import SGD, Adam

from torch.nn import MSELoss

In [17]:
uniform = torch.distributions.Uniform(-1, 1)  # input

In [20]:
unit_normal = torch.distributions.Normal(0, 1)

In [18]:
criterion = MSELoss()

# With rsample

In [33]:
delta = torch.tensor(1.0, requires_grad=True)
optimizer = Adam([delta], lr=0.05)
for _ in range(100):
    optimizer.zero_grad()
    x = uniform.sample()
    normal = torch.distributions.Normal(x, delta)
    y = normal.rsample()
    loss = criterion(y, x)
    loss.backward()
    optimizer.step()
print(delta)

tensor(-0.0068, requires_grad=True)


# With sample
This is expected to throw an exception.

In [34]:
delta = torch.tensor(1.0, requires_grad=True)
optimizer = Adam([delta], lr=0.05)
for _ in range(100):
    optimizer.zero_grad()
    x = uniform.sample()
    normal = torch.distributions.Normal(x, delta)
    y = normal.sample()  # sample instead of rsample
    loss = criterion(y, x)
    loss.backward()
    optimizer.step()
print(delta)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

# Manual reparametrization

In [35]:
delta = torch.tensor(1.0, requires_grad=True)
optimizer = Adam([delta], lr=0.05)
for _ in range(100):
    optimizer.zero_grad()
    x = uniform.sample()
    y = x + (unit_normal.sample() * delta)
    loss = criterion(y, x)
    loss.backward()
    optimizer.step()
print(delta)

tensor(0.0051, requires_grad=True)
