<a href="https://colab.research.google.com/github/Jeevesh8/PyTorch-Learning-NBs/blob/master/PyTorch_Distributions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

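**Important links:**

- Blog post on PyTorch distributions: https://bochang.me/blog/posts/pytorch-distributions/
- Score-function (REINFORCE) gradient estimator in `torch.distributions`: https://pytorch.org/docs/stable/distributions.html#score-function
- Pathwise-derivative (reparameterization) gradient estimator in `torch.distributions`: https://pytorch.org/docs/stable/distributions.html#pathwise-derivative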

In [57]:
import torch

# A batch of independent Normals with batch shape (2, 3): every entry of
# `mu` is a mean (loc) and the matching entry of `sigma` its std-dev (scale).
mu = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
sigma = torch.tensor([[2., 3., 4.], [1., 2., 1.]])
dis = torch.distributions.Normal(mu, sigma)

# rsample() is the reparameterized (pathwise) sampler: if mu / sigma require
# grad, gradients flow back through the drawn values into them.
# The sample shape (2,) is prepended to the batch shape -> (2, 2, 3).
samples = dis.rsample((2,))
print(samples)

# Plain sample(): same shape, but gradients do not flow back through it.
print(dis.sample((2,)))

tensor([[[ 1.9505,  3.2115, -1.4568],
         [ 4.3786,  6.8014,  5.4601]],

        [[ 3.4084,  7.0948, 11.5182],
         [ 2.4940,  8.0826,  5.5930]]])
tensor([[[-0.3580, -0.1331, -3.6138],
         [ 4.9757,  2.4498,  7.5045]],

        [[ 0.1587,  5.2897,  4.4107],
         [ 4.0395,  1.8586,  7.0182]]])


In [60]:
# exp() undoes log(): shown on a small vector first.
x = torch.tensor([0., 1., 2.])
print(x, x.exp())

# Log-density of each sample from the previous cell under `dis`; taking
# exp() of it gives the density values themselves.
log_p = dis.log_prob(samples)
print(log_p, log_p.exp())

tensor([0., 1., 2.]) tensor([1.0000, 2.7183, 7.3891])
tensor([[[-1.7250, -2.0991, -2.9260],
         [-0.9906, -2.0177, -1.0647]],

        [[-2.3372, -3.4596, -4.5727],
         [-2.0530, -2.7999, -1.0018]]]) tensor([[[0.1782, 0.1226, 0.0536],
         [0.3714, 0.1330, 0.3448]],

        [[0.0966, 0.0314, 0.0103],
         [0.1283, 0.0608, 0.3672]]])


In [0]:
# Train two categorical distributions (Multinomial with total_count=1) with
# Adam. Each step draws one sample from each distribution and penalises the
# squared difference of their log-probabilities. sample() is not
# differentiable, so gradients reach x and y only through log_prob().
relu = torch.nn.ReLU()
x = torch.tensor([1., 2, 3], requires_grad=True)
y = torch.tensor([4., 5, 6], requires_grad=True)
# NOTE(review): x (the "target" tar_dis) is in the optimizer too, so both
# distributions are trained; in the recorded output both end near uniform.
optimizer = torch.optim.Adam([x, y], lr=0.01)

iters = 1000

for i in range(iters):
    optimizer.zero_grad()

    # Apply ReLU into new names instead of rebinding x/y: `x = relu(x)`
    # would leave `x` pointing at a non-leaf result while the optimizer
    # keeps updating the original leaf tensor.
    m = relu(x)
    n = relu(y)

    # Multinomial normalises the non-negative weights into probabilities
    # (e.g. [4, 5, 6] -> [0.2667, 0.3333, 0.4000] in the printed output).
    tar_dis = torch.distributions.multinomial.Multinomial(1, m)
    dis = torch.distributions.multinomial.Multinomial(1, n)

    if i == 0 or i == iters - 1:
        print(dis.probs, tar_dis.probs)

    a = dis.log_prob(dis.sample())
    b = tar_dis.log_prob(tar_dis.sample())

    loss = (a - b) * (a - b)
    loss.backward()

    # step() updates x and y in place without recording history, so they
    # remain leaf tensors with requires_grad=True; no detach/re-enable needed.
    optimizer.step()

tensor([0.2667, 0.3333, 0.4000], grad_fn=<DivBackward0>) tensor([0.1667, 0.3333, 0.5000], grad_fn=<DivBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<DivBackward0>) tensor([0.3333, 0.3333, 0.3333], grad_fn=<DivBackward0>)


In [0]:
from torch.distributions.kl import kl_divergence as kld

# Minimise KL(dis_1 || dis_2) between two 3-D Gaussians with Adam, optimising
# both means and both covariance matrices directly.
mean_1 = torch.tensor([1., 2, 3], requires_grad=True)
mean_2 = torch.tensor([2., 2, 3], requires_grad=True)
# NOTE(review): the covariances are optimised as raw matrices; nothing keeps
# them symmetric positive-definite, so a larger lr / more iterations could
# make MultivariateNormal reject them. Optimising a Cholesky factor
# (scale_tril) instead would avoid that.
var_1 = torch.tensor([[1., 0, 0], [0., 1, 0], [0, 0, 0.7]], requires_grad=True)
var_2 = torch.tensor([[1., 0, 0], [0., 1, 0], [0, 0, 0.8]], requires_grad=True)
params = [mean_1, mean_2, var_1, var_2]
optimizer = torch.optim.Adam(params, lr=0.01)

iters = 100
for i in range(iters):
    optimizer.zero_grad()

    # Rebuild the distributions every iteration from the current parameters.
    dis_1 = torch.distributions.multivariate_normal.MultivariateNormal(mean_1, var_1)
    dis_2 = torch.distributions.multivariate_normal.MultivariateNormal(mean_2, var_2)

    loss = kld(dis_1, dis_2)

    if i == 0 or i == iters - 1:
        print(loss)

    loss.backward()
    # step() updates the leaf tensors in place; they stay leaves with
    # requires_grad=True, so no detach/re-enable pass is needed.
    optimizer.step()


tensor(0.5043, grad_fn=<AddBackward0>)
tensor(8.4341e-06, grad_fn=<AddBackward0>)


In [77]:
x = torch.randn((1, 2, 3))
print(x)
# Repeat each row of x 4 times along a new axis:
# (1, 2, 3) -> unsqueeze -> (1, 2, 1, 3) -> repeat -> (1, 2, 4, 3).
# Same values as x.repeat_interleave(4, -2).reshape(1, 2, 4, 3).
print(x.unsqueeze(-2).repeat(1, 1, 4, 1))

tensor([[[-0.9830,  1.3225,  0.9491],
         [-0.6345, -1.1244,  1.5391]]])
tensor([[[[-0.9830,  1.3225,  0.9491],
          [-0.9830,  1.3225,  0.9491],
          [-0.9830,  1.3225,  0.9491],
          [-0.9830,  1.3225,  0.9491]],

         [[-0.6345, -1.1244,  1.5391],
          [-0.6345, -1.1244,  1.5391],
          [-0.6345, -1.1244,  1.5391],
          [-0.6345, -1.1244,  1.5391]]]])
