In [220]:
import torch
import numpy as np
import matplotlib.pyplot as plt

def compute_Ct(x, X, eps=1e-3):
    """
    Computes C_t(x, X) which counts the number of times the distance between each
    element in X and x is less than eps (threshold).
    
    Args:
    - x (torch.Tensor): A tensor representing a batch of reference points (num_x, in_dim).
    - X (torch.Tensor): Matrix of shape (num_timesteps, N_Agents, in_dim).
    - eps (float): Small radius defining the threshold for the ball.
    
    Returns:
    - Ct (torch.Tensor): A tensor with the count for each x in the batch (num_x,).
    """
    distance = torch.linalg.vector_norm(X.unsqueeze(0) - x.unsqueeze(1).unsqueeze(1).unsqueeze(1), dim=4, ord = float('inf'))  # Broadcasting over x
    #Ct_approx = torch.exp(-0.5 * ((distance - eps) / 1e-4)**2).sum(dim=(1, 2))
    Ct_approx = (torch.clamp(eps - distance, min=0) * (3 / eps)).sum(dim=(2, 3))
    Ct_true = (distance < eps).float().sum(dim=(2, 3))
    #Ct =distance.sum(dim=(1, 2))
    #print(Ct.grad_fn)
    return Ct_true, Ct_approx

# Create synthetic data
num_timesteps = 10
N_Agents = 5
in_dim = 2
#num_grid_points = 100

# Random tensor representing agent positions (X)
X = torch.randn(16, num_timesteps, N_Agents, in_dim, requires_grad = True)
# Define a range for x (reference points) over a meshgrid
eps = 1e-3
x_vals = torch.arange(0 ,1, eps * 2)  # Define range for x (1D grid)
num_grid_points = x_vals.shape[0]
X1, X2 = torch.meshgrid(x_vals, x_vals)  # Create a 2D meshgrid for x

# Flatten the meshgrid to make it easier to handle in batches
x_grid = torch.stack([X1.flatten(), X2.flatten()], dim=1)  # Shape (num_x, 2), where num_x is the number of x values
Ct_values = compute_Ct(x_grid, X, eps)

In [218]:
print(Ct_values[0].shape)


torch.Size([10000, 16])


In [175]:
from torch import nn

tensor(0.0291)

In [224]:
coeffs = torch.fft.rfft2(Ct_values[0][:,1].reshape(num_grid_points,num_grid_points), norm = 'ortho').real
#crit = nn.MSELoss()
#loss = crit(coeffs, torch.randn_like(coeffs, requires_grad=False))
#loss.backward()


In [6]:
""" function C_t """
import torch

# Function to compute the indicator function I_B(x_j(t_i), epsilon)
def indicator_function(x, x_j_point, epsilon=1e-2):
    distance = torch.norm(x - x_j_point)
    return (distance < epsilon).float()

# Function to compute C_t(x)
def compute_Ct(x, X):
    """
    X is matrix of computed values X_j shape(num_timessteps, N_Agents, in_dim)
    eps: radius of ball
    """
    eps = 1e-3
    return (torch.linalg.vector_norm(X - x , axis = 2) < eps).float().sum()

n = 10
N = 100
X = torch.randn([N,n,2])
compute_Ct(torch.tensor([0.1,0.5]), X)

tensor(0.)

In [136]:
import numpy as np
from scipy.integrate import nquad

def evaluate_integral(k, L):
    """
    Evaluate the integral for multiple k and L values.
    
    Parameters:
    k: List of k values [k1, k2, ...]
    L: List of L values [L1, L2, ...]

    Returns:
    Tensor of the square root of the evaluated integrals.
    """
    def integrand(*args):
        result = np.array([1.0])
        for i,coeff in enumerate(k.numpy()):
            result *= np.cos(coeff * args[i])**2
        return result

    # Define the integration limits for all dimensions
    limits = [(0, L) for L in L.numpy()]

    # Compute the nested integral
    integral, _ = nquad(integrand, limits)
    return torch.tensor(np.sqrt(integral))

0.5290007198584084


tensor(0.7273, dtype=torch.float64)

In [279]:
def mu_k_n_dim(k, mu_function, grid_step=0.01):
    """
    Computes the Fourier coefficient µ_k for the given Fourier indices k_indices in n dimensions.
    
    Args:
    - k_indices: List of integer indices for the Fourier modes (k1, k2, ..., kn).
    - mu_function: Function that takes (x1, x2, ..., xn) and returns the value of µ(x1, x2, ..., xn).
    - domain_lengths: List of lengths of the domain for each dimension (default is [0,1] in each dimension).
    - grid_step: The step size for discretizing the domain.

    Returns:
    - The Fourier coefficient µ_k.
    """

    L = torch.tensor([1.0, 1.0])
    n = len(k)
    grids = [torch.arange(0, L, grid_step) for L in L]
    meshgrid = torch.meshgrid(*grids)
    X = torch.stack(meshgrid, dim=-1)
    mu_values = mu_function(*X.unbind(dim=-1))

    # Fourier coefficients calculation
    k = torch.tensor(k, dtype=torch.float32)
    fourier_factors = [torch.cos(k[i] * X[..., i] * torch.pi / L[i]) for i in range(n)]
    print(fourier_factors[0].shape)
    # Compute the Fourier basis function f_k(x1, x2, ..., xn)
    fk = torch.prod(torch.stack(fourier_factors, dim=-1), dim=-1)
    print(fk.shape)
    # Compute the normalization constant h_k
    hk = torch.sqrt(torch.sum(fk**2)) * grid_step

    # Normalize the Fourier basis function f_k
    fk = fk / hk

    # Compute the discrete sum to approximate the inner product (integral)
    
    # Compute the Fourier coefficient µ_k
    print(fk.shape, mu_values.shape)
    mu_k_value = torch.sum(mu_values * fk * grid_step**2)

    return mu_k_value.item()

# Example usage:
# Define a function for the PDF (example: Gaussian function)
def mu_function(*x):
    # Assuming Gaussian PDF centered at (0.5, 0.5, ..., 0.5)
    return torch.ones([x[i].shape[0] for i in range(len(x))])
    return (1 / ((torch.pi) *(0.01))) * torch.exp(-sum([(xi - 0.5)**2 for xi in x]) / 0.01)
# Compute the Fourier coefficient for a given k_indices (e.g., [1, 1, 1] for 3 dimensions)
k_indices = [0, 0]
result = mu_k_n_dim(k_indices, mu_function)

print(f"mu_k = {result}")

torch.Size([100, 100])
torch.Size([100, 100])
torch.Size([100, 100]) torch.Size([100, 100])
mu_k = 0.9999995827674866


In [249]:
x = torch.ones([2,2,2,3])
print(x)
x[...,1]

tensor([[[[1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.]]]])


tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])

In [233]:
from itertools import product
class Ergodicity_Loss(nn.Module):
    def __init__(self, N_Agents, n_timesteps):
        super(Ergodicity_Loss, self).__init__()
        self.N_Agents = N_Agents
        self.n_timesteps = n_timesteps
        self.L = torch.tensor([1.,1.]) ## rectangular area
        self.in_dim = 2 ## Dimension of rect

    def compute_normalization_constant(self,k):
        """ 
        h_k
        """
        return evaluate_integral(k, self.L)

    def fourier_basis(self,x, k):
        """ 
        x: State at time t_k [Num_timesteps ,Batch_size, N_Agents, in_dim]
        k torch.tensor (in_dim)
        """
        k *= torch.pi * (self.L)**(-1)
        h_k = self.compute_normalization_constant(k)
        return (torch.cos(x * k.view(1,1,1,-1)).prod(dim = -1)) / h_k

        
    def compute_fourier_coefficients_agents_at_time_t(self,x, k):
        """
        x: State of Agents [Num_timesteps ,Batch_size, N_Agents, in_dim] 
        """
        # For now i just put as calculaated t 1s
        transform = self.fourier_basis(x,k)
        c_k = transform.sum(dim=-2).sum(dim=-1)    
        return c_k

    def compute_fourier_coefficients_density(self,x, k):
        """
        x: State of Agents [Num_timesteps ,Batch_size, N_Agents, in_dim] 
        """
        # For now i just put as calculaated t 1s
        transform = self.fourier_basis(x.view(-1, self.in_dim),k)
        result = transform.view(self.n_timesteps, -1 , self.N_Agents)
        c_k = result.sum(dim=0).sum(dim=1, keepdim=True)         
        return c_k

    def forward(self,x):
        """
        x: State of Agents [Num_timesteps ,Batch_size, N_Agents, in_dim] 
        """
        loss = torch.zeros(1)
        k = list(range(n+1))
        for sets in product(k, repeat = 4):
            loss += (self.compute_fourier_coefficients(x,k) - self.compute_fourier_coefficients_density(k))

        
        

In [234]:
j = 6
k = 499.
m = 499. 
coeffs = torch.fft.fft2(Ct_values[0][:,j].reshape(num_grid_points,num_grid_points), norm = 'ortho')
print(coeffs[int(k),int(m)])
Loss = Ergodicity_Loss(N_Agents, num_timesteps)
Loss.compute_fourier_coefficients_agents_at_time_t(X, k = torch.tensor([k,m]))[j]

tensor(-0.0033+0.0010j)


tensor(0.6824, grad_fn=<SelectBackward0>)

In [None]:
# Define grid size and frequency components
N1, N2 = 128, 128
K1, K2 = 1, 1  # Example frequencies
x1 = torch.linspace(0, 2 * torch.pi, N1)
x2 = torch.linspace(0, 2 * torch.pi, N2)

# Create a 2D grid of values (x1, x2)
X1, X2 = torch.meshgrid(x1, x2)
Loss.fourier_basis()
f_k_fft = torch.fft.fft2(fk)
print(f_k_fft)

In [229]:
coeffs.shape

torch.Size([500, 500])

In [None]:
x1 = torch.linspace(0, 1)  # Grid for x1
x2 = torch.linspace(0, 1)  # Grid for x2
X1, X2 = torch.meshgrid(x1, x2)
Z = compute_Ct(X1, X2)
fftn = torch.fft.fftn(x)