In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
import torch
from torch import Tensor, distributions

import matplotlib.pyplot as plt
import seaborn as sns

In this exercise, we'll look at how `torch.distributions` can be used to build differentiable density-estimation methods.

## Sampling
Let's first draw a differentiable sample of points. Below, create two parameters, `concentration` and `rate`, both initialised at 2 and with `requires_grad=True` so that gradients can flow back to them, and use these to instantiate a `Gamma` distribution.

In [None]:
# your code here

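Now draw 100 points from the `Gamma` into a tensor called `samples`, making sure the draw stays differentiable with respect to the parameters (hint: compare `.sample()` with `.rsample()`).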
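A minimal, standalone sketch of why the choice of sampling method matters. In `torch.distributions`, `.sample()` draws inside `torch.no_grad()`, so the result is cut off from the computational graph, whereas `.rsample()` uses a reparameterised draw that remains differentiable with respect to the distribution's parameters. The names `conc`, `rt`, `gamma`, `plain` and `reparam` are hypothetical and deliberately differ from the exercise's variables.

```python
import torch
from torch import distributions

# Hypothetical standalone parameters (not the exercise's `concentration`/`rate`)
conc = torch.tensor(2.0, requires_grad=True)
rt = torch.tensor(2.0, requires_grad=True)
gamma = distributions.Gamma(conc, rt)

# .sample() draws under torch.no_grad(): the result carries no gradient history
plain = gamma.sample((100,))
# .rsample() draws via reparameterisation: the result depends differentiably on conc and rt
reparam = gamma.rsample((100,))

print(plain.requires_grad, reparam.requires_grad)  # False True

# Gradient of the sample mean with respect to both parameters
grad_conc, grad_rt = torch.autograd.grad(reparam.mean(), (conc, rt))
```

Since a Gamma draw is a standard-Gamma variate divided by `rate`, `grad_rt` is always negative here: increasing the rate shrinks every sample.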

In [None]:
# your code here

In [None]:
assert torch.autograd.grad(samples.mean(), concentration, retain_graph=True)[0].abs() > 0

## KDE
Now we'll plot their KDE (kernel density estimate) using seaborn's standard `kdeplot`. This does not track gradients, so the samples must first be detached from the computational graph.

In [None]:
sns.kdeplot(samples.detach().numpy())

Now we'll build our own differentiable KDE for the sample. A Gaussian KDE centres a Gaussian of fixed scale (the bandwidth) on each sample point and evaluates the mean of their PDFs at a set of query points.
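In symbols, for samples $x_1, \dots, x_n$ and bandwidth $h$, the KDE at a query point $x$ is the average of $n$ Gaussian PDFs:

$$
\hat{f}_h(x) = \frac{1}{n} \sum_{i=1}^{n} \mathcal{N}\left(x \mid x_i, h^2\right) = \frac{1}{n h \sqrt{2\pi}} \sum_{i=1}^{n} \exp\left(-\frac{(x - x_i)^2}{2h^2}\right)
$$

Because $\hat{f}_h(x)$ is a smooth function of every $x_i$, gradients can flow from the KDE back through the samples to the parameters of the distribution that generated them.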

Below, compute the KDE values `y` at each point in `x`, using a `Normal` distribution whose scale is the given bandwidth.
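One way to do this is sketched below, assuming broadcasting is used to evaluate every kernel at every query point at once. The helper `gaussian_kde` and the toy tensors `points` and `grid` are hypothetical; they are applied to a few fixed points rather than to the Gamma samples, so the exercise cell is still left for you.

```python
import torch
from torch import distributions


def gaussian_kde(samples: torch.Tensor, x: torch.Tensor, bandwidth: float) -> torch.Tensor:
    """Gaussian KDE of `samples` evaluated at each point of `x` (differentiable)."""
    # One Normal kernel centred on each sample point: batch shape (n,)
    kernels = distributions.Normal(loc=samples, scale=bandwidth)
    # x[:, None] has shape (m, 1); broadcasting against (n,) gives densities of shape (m, n)
    densities = kernels.log_prob(x[:, None]).exp()
    # Averaging over the n kernels gives the KDE value at each of the m query points
    return densities.mean(dim=1)


# Usage on hypothetical toy points
points = torch.tensor([1.0, 1.5, 2.0, 3.0], requires_grad=True)
grid = torch.linspace(0, 4, 50)
kde = gaussian_kde(points, grid, bandwidth=0.25)
print(kde.shape)  # torch.Size([50])
```

Evaluating `log_prob` and exponentiating is equivalent to evaluating the PDF directly; the result remains connected to `points`, so `torch.autograd.grad` can differentiate any KDE value with respect to them.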

In [None]:
bandwidth = 0.25
x = torch.linspace(0, 4, 50)

In [None]:
# your code here

In [None]:
assert y.shape == torch.Size([50])
assert torch.autograd.grad(y.mean(), concentration, retain_graph=True)[0].abs() > 0

Now we'll plot the result. It should look similar to the seaborn plot above; small differences are expected, because seaborn chooses its own bandwidth.

In [None]:
plt.plot(x, y.detach())

## Histogram
Histograms are widely used in high-energy physics (HEP) and other fields, but in the standard implementation the bin fills are not differentiable: each sample either lies in a given bin and contributes 1 to its population, or lies outside and contributes 0. The fills are therefore piecewise constant in the sample positions, so their gradient is zero almost everywhere.

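Using our `Normal` distribution, we can also compute a differentiable histogram. Treat the points in `x` as the bin edges (giving 49 bins) and store the result in a tensor called `bin_fills`. Have a think about how to do this (hint: use the `.cdf` method).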
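One possible approach is sketched below, under the assumption that each sample is smeared by a Gaussian whose width equals the bandwidth. The contribution of sample $x_i$ to the bin $[a, b)$ is then $\Phi\left(\frac{b - x_i}{h}\right) - \Phi\left(\frac{a - x_i}{h}\right)$, the kernel's probability mass inside the bin, where $\Phi$ is the standard normal CDF. Summing over samples gives a soft count that tends to the ordinary hard count as $h \to 0$. The helper `soft_histogram` and the toy tensors `points` and `edges` are hypothetical.

```python
import torch
from torch import distributions


def soft_histogram(samples: torch.Tensor, edges: torch.Tensor, bandwidth: float) -> torch.Tensor:
    """Differentiable histogram: each sample spreads a Gaussian's mass over the bins."""
    # One Normal kernel per sample point: batch shape (n,)
    kernels = distributions.Normal(loc=samples, scale=bandwidth)
    # CDF of every kernel at every edge: shape (m, n)
    cdf = kernels.cdf(edges[:, None])
    # Mass of each kernel inside each bin [edges[k], edges[k + 1]): shape (m - 1, n)
    mass_per_bin = cdf[1:] - cdf[:-1]
    # Summing over samples gives a soft count per bin: shape (m - 1,)
    return mass_per_bin.sum(dim=1)


# Usage on hypothetical toy points
points = torch.tensor([1.0, 1.5, 2.0, 3.0], requires_grad=True)
edges = torch.linspace(0, 4, 50)
fills = soft_histogram(points, edges, bandwidth=0.25)
print(fills.shape)  # torch.Size([49])
```

Because every kernel here lies well inside the edge range, the fills sum to almost exactly the number of samples; mass from kernels near the outer edges would leak out of the histogram.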

In [None]:
# your code here

In [None]:
assert torch.autograd.grad(bin_fills.mean(), concentration, retain_graph=True)[0].abs() > 0

In [None]:
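# x holds the bin edges: draw each bar from its left edge, one bin width wide
plt.bar(x[:-1], bin_fills.detach(), width=(x[1] - x[0]).item(), align='edge')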