In [None]:
from torch import nn
import torch
import torch.nn.functional as F

In [None]:
def softplus(x):
    """
    Softplus with a tiny positive floor, used to map unconstrained parameters to positive values.

    Computes ``F.softplus(x) + 1e-12``. The floor guarantees a strictly positive
    result even where ``F.softplus`` alone rounds to 0 for very negative inputs.
    """
    floor = 1e-12
    positive = F.softplus(x)
    return positive + floor

def lengthscales(var):
    """Constrain the unconstrained parameter ``var`` to positive lengthscales (elementwise softplus)."""
    constrained = softplus(var)
    return constrained

def variance(var):
    """Constrain the unconstrained parameter ``var`` to a positive variance (elementwise softplus)."""
    constrained = softplus(var)
    return constrained

In [None]:
def square_dist_dimwise(X, X2=None):
    """
    Computes squared euclidean distance (scaled) for dimwise kernel setting.

    Inputs are divided by the per-output lengthscales obtained from the
    module-level ``unconstrained_lengthscales`` (shape (D_out, D_in) in the
    dimwise setting), giving one distance matrix per output dimension.

    @param X: Input 1 (N,D_in)
    @param X2: Input 2 (M,D_in); if None, X is compared with itself (M = N)
    @return: Tensor (D_out,N,M) of non-negative squared distances
    """
    # Evaluate the lengthscale transform once and reuse it for X and X2.
    ls = lengthscales(unconstrained_lengthscales).unsqueeze(1)  # (D_out,1,D_in)
    X_scaled = X.unsqueeze(0) / ls  # (D_out,N,D_in)
    Xs = torch.sum(torch.pow(X_scaled, 2), dim=2)  # (D_out,N)
    if X2 is None:
        X2_scaled, X2s = X_scaled, Xs
    else:
        X2_scaled = X2.unsqueeze(0) / ls  # (D_out,M,D_in)
        X2s = torch.sum(torch.pow(X2_scaled, 2), dim=2)  # (D_out,M)
    dist = -2 * torch.einsum('dnk, dmk -> dnm', X_scaled, X2_scaled) + \
        Xs.unsqueeze(-1) + X2s.unsqueeze(1)  # (D_out,N,M)
    # The expanded form |x|^2 + |y|^2 - 2<x, y> can dip slightly below zero through
    # floating-point cancellation; clamp so distances are never negative.
    return torch.clamp(dist, min=0.0)

def square_dist(X, X2=None):
    """
    Computes squared euclidean distance (scaled) for non dimwise kernel setting.

    Inputs are divided by the lengthscales obtained from the module-level
    ``unconstrained_lengthscales`` (e.g. shape (D_in,) in the non-dimwise setting).

    @param X: Input 1 (N,D_in)
    @param X2: Input 2 (M,D_in); if None, X is compared with itself (M = N)
    @return: Tensor (N,M) of non-negative squared distances
    """
    # Evaluate the lengthscale transform once and reuse it for X and X2.
    ls = lengthscales(unconstrained_lengthscales)
    X_scaled = X / ls  # (N,D_in)
    Xs = torch.sum(torch.pow(X_scaled, 2), dim=1)  # (N,)
    if X2 is None:
        X2_scaled, X2s = X_scaled, Xs
    else:
        X2_scaled = X2 / ls  # (M,D_in)
        X2s = torch.sum(torch.pow(X2_scaled, 2), dim=1)  # (M,)
    dist = -2 * torch.matmul(X_scaled, X2_scaled.t()) + \
        torch.reshape(Xs, (-1, 1)) + torch.reshape(X2s, (1, -1))  # (N,M)
    # The expanded form |x|^2 + |y|^2 - 2<x, y> can dip slightly below zero through
    # floating-point cancellation; clamp so distances are never negative.
    return torch.clamp(dist, min=0.0)


In [None]:
def difference_matrix_a(X, X2=None):
    '''
    Computes pairwise differences of lengthscale-scaled inputs, (X/l)[n] - (X2/l)[m].

    The lengthscales l come from the module-level ``unconstrained_lengthscales``
    (non-dimwise setting).

    @param X: Input 1 (N,D_in)
    @param X2: Input 2 (M,D_in); if None, X is used (M = N)
    @return: Tensor (N,M,D_in)
    '''
    # Evaluate the lengthscale transform once and reuse it for X and X2.
    ls = lengthscales(unconstrained_lengthscales)
    X = X / ls  # (N,D_in)
    X2 = X if X2 is None else X2 / ls  # (M,D_in)
    return X[:, None, :] - X2[None, :, :]  # broadcast -> (N,M,D_in)

def difference_matrix_dimwise(X, X2=None):
    '''
    Computes pairwise differences of lengthscale-scaled inputs, one set per output dimension.

    The per-output lengthscales come from the module-level
    ``unconstrained_lengthscales`` (shape (D_out, D_in) in the dimwise setting).

    @param X: Input 1 (N,D_in)
    @param X2: Input 2 (M,D_in); if None, X is used (M = N)
    @return: Tensor (D_out,N,M,D_in)
    '''
    # Evaluate the lengthscale transform once and reuse it for X and X2.
    ls = lengthscales(unconstrained_lengthscales).unsqueeze(1)  # (D_out,1,D_in)
    X = X.unsqueeze(0) / ls  # (D_out,N,D_in)
    X2 = X if X2 is None else X2.unsqueeze(0) / ls  # (D_out,M,D_in)
    return X[:, :, None, :] - X2[:, None, :, :]  # broadcast -> (D_out,N,M,D_in)

In [None]:
def identity(X, X2=None):
    """
    Identity matrix whose size is the row count of X2, or of X when X2 is None.

    @param X: Input 1 (N,D_in)
    @param X2: Input 2 (M,D_in) or None
    @return: Tensor (M,M), or (N,N) when X2 is None
    """
    reference = X if X2 is None else X2
    n_rows = reference.shape[0]
    return torch.eye(n_rows)

## Non-dimwise kernel (`dimwise = False`): a single lengthscale vector of shape (D_in,) and a single variance

In [None]:
# Unconstrained hyperparameters (passed through softplus before use):
# one lengthscale per input dimension, shape (16,), and a single variance, shape (1,).
unconstrained_lengthscales = nn.Parameter(torch.ones(16), requires_grad=True)
unconstrained_variance = nn.Parameter(torch.ones(1), requires_grad=True)

In [None]:
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same inputs
X = torch.randint(5, (50, 16))  # inducing inputs (N, D_in), integer values in [0, 5)
# Square N x N case: compare X with itself. For separate data inputs (M, D_in)
# use X2 = torch.randint(5, (25, 16)) instead.
X2 = None

In [None]:
D_in = X.shape[1]  # input dimensionality
sq_dist = square_dist(X, X2)  # (N, M), lengthscale-scaled squared distances
K2 = torch.exp(-0.5 * sq_dist)  # (N, M)
K2 = K2.unsqueeze(0)  # (1, N, M)
diff = difference_matrix_a(X, X2)  # (N, M, D_in), lengthscale-scaled differences
# Squared scaled difference per input dimension. No sum over dimensions is taken, so
# this is only the diagonal of the outer product (x - x')(x - x')^T.
# TODO(review): if the full D_in x D_in matrix-valued term is intended, use
# torch.einsum('nmi,nmj->ijnm', diff, diff) instead.
K1_term = (diff * diff).permute(2, 0, 1)  # (D_in, N, M)
K3 = (D_in - 1.0) - sq_dist  # (N, M)
# NOTE(review): right-multiplying by an (M, M) identity leaves K3 unchanged;
# confirm whether an elementwise Kronecker-delta mask was intended here.
K3 = K3 @ identity(X, X2)  # (N, M)
K3 = K3.unsqueeze(0)  # (1, N, M)
K = (K1_term + K3) * K2  # (D_in, N, M)
K = torch.permute(K, (1, 2, 0))  # (N, M, D_in)
# Weight each input dimension by 1 / l_i^2, sum over D_in, then scale by the variance.
inv_ls2 = (1.0 / torch.pow(lengthscales(unconstrained_lengthscales), 2)).unsqueeze(-1)  # (D_in, 1)
K = K @ inv_ls2  # (N, M, 1)
K = K @ variance(unconstrained_variance).unsqueeze(-1)  # (N, M, 1) @ (1, 1) -> (N, M, 1)
n2 = X.shape[0] if X2 is None else X2.shape[0]
assert K.shape == (X.shape[0], n2, 1)

In [None]:
K.size()  # final kernel shape, expected (N, M, 1)

## Dimwise kernel (`dimwise = True`): separate lengthscales (D_out, D_in) and a variance per output, (D_out,)

In [None]:
# Unconstrained per-output hyperparameters (passed through softplus before use):
# lengthscales of shape (D_out=8, D_in=16) and one variance per output, shape (8,).
unconstrained_lengthscales = nn.Parameter(torch.ones(8, 16), requires_grad=True)
unconstrained_variance = nn.Parameter(torch.ones(8), requires_grad=True)

In [None]:
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same inputs
X = torch.randint(5, (50, 16))  # inducing inputs (N, D_in), integer values in [0, 5)
X2 = torch.randint(5, (25, 16))  # data inputs (M, D_in); set X2 = None for the square N x N case

In [None]:
D_in = X.shape[1]  # input dimensionality
sq_dist = square_dist_dimwise(X, X2)  # (D_out, N, M)
K2 = torch.exp(-0.5 * sq_dist)  # (D_out, N, M)
K2 = K2.unsqueeze(0)  # (1, D_out, N, M)
diff = difference_matrix_dimwise(X, X2)  # (D_out, N, M, D_in), lengthscale-scaled differences
# Squared scaled difference per input dimension. No sum over dimensions is taken, so
# this is only the diagonal of the outer product (x - x')(x - x')^T for each output.
# TODO(review): if the full D_in x D_in term is intended, use
# torch.einsum('dnmi,dnmj->ijdnm', diff, diff) instead.
K1_term = (diff * diff).permute(3, 0, 1, 2)  # (D_in, D_out, N, M)
K3 = (D_in - 1.0) - sq_dist  # (D_out, N, M)
# NOTE(review): right-multiplying by an (M, M) identity leaves K3 unchanged;
# confirm whether an elementwise Kronecker-delta mask was intended here.
K3 = K3 @ identity(X, X2)  # (D_out, N, M)
K3 = K3.unsqueeze(0)  # (1, D_out, N, M)
K = (K1_term + K3) * K2  # (D_in, D_out, N, M)
# Each output d is weighted by its OWN inverse squared lengthscales l[d, i]^-2 and its own
# variance[d]. A plain matmul with the (D_in, D_out) lengthscale matrix would mix every
# output's hyperparameters into every other output (and, with identical parameters,
# inflate K by a factor of D_out).
inv_ls2 = 1.0 / torch.pow(lengthscales(unconstrained_lengthscales), 2)  # (D_out, D_in)
K = torch.einsum('idnm,di->dnm', K, inv_ls2)  # sum over D_in -> (D_out, N, M)
K = K * variance(unconstrained_variance)[:, None, None]  # (D_out, N, M)
n2 = X.shape[0] if X2 is None else X2.shape[0]
assert K.shape == (unconstrained_variance.shape[0], X.shape[0], n2)

In [None]:
K.size()  # final kernel shape