In [1]:
from noise_lib import LogLinearNoise
from graph_lib import Uniform
import numpy as np
import torch
from scipy.linalg import expm

In [2]:
def get_analytic_score_fn(p, graph):
    """Build the exact (analytic) concrete score for a known data distribution.

    At noise level ``sigma`` the marginal is ``p_t = expm(sigma * Q) @ p``. Here
    ``Q = graph.Q``, following the same ``exp_qt @ p`` convention used elsewhere
    in this notebook. The returned ``score_fn`` gives, for each position, the
    ratios ``p_t(y) / p_t(x)`` for every state ``y``.

    Args:
        p: data distribution over the ``dim`` states, shape ``(..., dim)``.
            The column-vector layout ``(..., dim, 1)`` is also accepted.
        graph: object exposing the rate matrix ``Q`` of shape ``(dim, dim)``.

    Returns:
        ``score_fn(x, sigma)``. Here ``x`` is an integer tensor of current states
        and ``sigma`` is a scalar or tensor broadcastable against ``x``. It returns
        a tensor of shape ``broadcast(sigma, p batch dims, x) + (dim,)``. Entry
        ``y`` of that tensor is ``p_t(y) / p_t(x)``, so the entry at ``y == x``
        is 1.
    """
    Q = torch.as_tensor(graph.Q)
    p = torch.as_tensor(p)
    dim = Q.shape[-1]
    # Accept the (..., dim, 1) column layout as well as (..., dim).
    if p.shape[-1] != dim and p.shape[-1] == 1:
        p = p.squeeze(-1)
    dtype = torch.promote_types(Q.dtype, p.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()

    def score_fn(x, sigma):
        x = torch.as_tensor(x)
        device = x.device
        Q_dev = Q.to(device=device, dtype=dtype)
        p_dev = p.to(device=device, dtype=dtype)
        sigma = torch.as_tensor(sigma, device=device, dtype=dtype)
        # Batched matrix exponential: one (dim, dim) transition matrix per sigma entry.
        Q_exp = torch.linalg.matrix_exp(Q_dev * sigma[..., None, None])
        # p_t = expm(sigma Q) @ p, computed as a batched mat-vec product.
        p_t = (Q_exp @ p_dev[..., None]).squeeze(-1)
        # Align p_t's batch dims with x so each state can be looked up per position.
        batch_shape = torch.broadcast_shapes(p_t.shape[:-1], x.shape)
        p_t = p_t.expand(*batch_shape, dim)
        idx = x.to(torch.long).expand(batch_shape)[..., None]
        den = torch.gather(p_t, -1, idx)  # p_t(x), shape batch_shape + (1,)
        return p_t / den

    return score_fn

In [3]:
# Two-state uniform graph plus a log-linear noise schedule; display the
# graph's rate matrix Q for reference.
dim = 2
graph, noise = Uniform(dim), LogLinearNoise()
Q = graph.Q
print(Q)

tensor([[-1.,  1.],
        [ 1., -1.]])


In [4]:
# Sample a random batch of states and diffusion times, build a random per-token
# data distribution, then compute its noised marginal p_t and analytic score.
SEED = 0
torch.manual_seed(SEED)  # reproducible under Restart & Run All

bs = 512
tokens = 2
x = torch.randint(0, dim, (bs, tokens))
t = torch.rand_like(x.float())
# Per-token distribution over the dim states, stored as columns: (tokens, dim, 1).
p = torch.rand(tokens, dim, 1)
p = p / p.sum(dim=1, keepdim=True)
sigma, dsigma = noise(t)

p_t = graph.get_pt(p, sigma)
print("x shape is \n", x.shape)
print("x is \n", x)
score = graph.get_analytic_score(x, p, sigma)
print("p_t is \n", p_t)
print("p_t shape is \n", p_t.shape)
print("score shape is \n", score.shape)
print("score is \n", score)

# Sanity checks: p is normalized per token, and the score ratio p_t(y)/p_t(x)
# equals 1 at y == x.
assert torch.allclose(p.sum(dim=1), torch.ones(tokens, 1))
assert score.shape == (bs, tokens, dim)
assert torch.allclose(score.gather(-1, x[..., None]), torch.ones(bs, tokens, 1))

x shape is 
 torch.Size([512, 2])
x is 
 tensor([[0, 0],
        [0, 1],
        [0, 0],
        ...,
        [0, 1],
        [0, 0],
        [0, 0]])
score matrix shape is  torch.Size([512, 2, 2, 2])
p_t is 
 tensor([[[[0.6024],
          [0.3976]],

         [[0.5322],
          [0.4678]]],


        [[[0.5350],
          [0.4650]],

         [[0.5335],
          [0.4665]]],


        [[[0.6496],
          [0.3504]],

         [[0.5011],
          [0.4989]]],


        ...,


        [[[0.7325],
          [0.2675]],

         [[0.5021],
          [0.4979]]],


        [[[0.6012],
          [0.3988]],

         [[0.5170],
          [0.4830]]],


        [[[0.8294],
          [0.1706]],

         [[0.5264],
          [0.4736]]]])
p_t shape is 
 torch.Size([512, 2, 2, 1])
score shape is 
 torch.Size([512, 2, 2])
score is 
 tensor([[[1.0000, 0.6601],
         [1.0000, 0.8789]],

        [[1.0000, 0.8691],
         [1.1437, 1.0000]],

        [[1.0000, 0.5393],
         [1.0000, 0.9956]],

In [5]:
# Cross-check graph.get_pt by building the full transition matrices explicitly.
# bs, tokens, p and sigma come from the previous cell (not redefined here), so
# the shapes stay consistent if those settings change.
indices = torch.arange(graph.dim).unsqueeze(0).unsqueeze(0)
indices = indices.expand(bs, tokens, -1)
exp_qt = graph.transition(indices, sigma[..., None])
print(exp_qt.shape)  # Shape is (bs, token, dim, dim)
print(p.shape)
# Broadcast into a new name so `p` keeps its (tokens, dim, 1) layout for later cells.
p_batched = p.expand(bs, tokens, -1, -1)
print(p_batched.shape)
p_t = exp_qt @ p_batched
print(p_t.shape)
print(p_t)

# The explicit product must match the library's marginal computation.
torch.testing.assert_close(p_t, graph.get_pt(p, sigma))

torch.Size([512, 2, 2, 2])
torch.Size([2, 2, 1])
torch.Size([512, 2, 2, 1])
torch.Size([512, 2, 2, 1])
tensor([[[[0.6024],
          [0.3976]],

         [[0.5322],
          [0.4678]]],


        [[[0.5350],
          [0.4650]],

         [[0.5335],
          [0.4665]]],


        [[[0.6496],
          [0.3504]],

         [[0.5011],
          [0.4989]]],


        ...,


        [[[0.7325],
          [0.2675]],

         [[0.5021],
          [0.4979]]],


        [[[0.6012],
          [0.3988]],

         [[0.5170],
          [0.4830]]],


        [[[0.8294],
          [0.1706]],

         [[0.5264],
          [0.4736]]]])
