In [None]:
import numpy as np
from functools import reduce


class Wigner:
    """Wigner function for a Gaussian distribution in phase space.

    Evaluating an instance at phase-space points ``r`` returns the normalised
    Gaussian

        W(r) = exp(-1/2 (r - mean)^T sigma^{-1} (r - mean))
               / ((2 pi)^(N/2) * sqrt(det sigma))

    with N = ``self.dim``, so that W integrates to 1 over phase space.
    ``displacement`` and ``phaseShift`` assume the coordinates are ordered
    per mode as (x1, p1, x2, p2, ...).
    """

    def __init__(self, mean, sigma):
        """
        Initialize the Wigner function with mean and covariance matrix.

        Parameters
        -----------
        mean (ndarray[K, N]) : Mean vector of the Gaussian distribution. Can be a 1D or 2D array.
            - If 1D, it is treated as a single point in N-dimensional space.
            - If 2D, each row corresponds to a point in N-dimensional space.
            - K is the number of points and N is the dimensionality of the mean vector.
            - The mean vector should be of shape (N,) or (K, N).
            Lists are accepted and converted with ``np.asarray``.

        sigma (ndarray[N, N]) : Covariance matrix of the Gaussian distribution. It should be a
            square matrix of shape (N, N) where N is the dimensionality of the mean vector.

        Raises
        ------
        np.linalg.LinAlgError : if ``sigma`` is singular or its determinant is not positive.
        """
        mean = np.asarray(mean)
        sigma = np.asarray(sigma)

        self.state = {
            'mean': mean,
            'cov': sigma,
        }
        self.mean = mean
        self.sigma = sigma
        # Last axis is the phase-space dimension for both (N,) and (K, N) means.
        self.dim = mean.shape[-1]

        self.sigmaDet, self.sigmaInv, self.norm = self._gaussian_cache(self.sigma)

    @staticmethod
    def _gaussian_cache(sigma):
        """Return ``(det, inverse, norm)`` for the covariance matrix ``sigma``.

        ``norm`` is the multivariate-normal prefactor
        1 / ((2 pi)^(N/2) * sqrt(det sigma)) with N = sigma.shape[-1].

        Raises
        ------
        np.linalg.LinAlgError : if ``sigma`` is singular (raised by ``np.linalg.inv``)
            or its determinant is not strictly positive (the norm would be NaN/inf).
        """
        sigma_det = np.linalg.det(sigma)
        sigma_inv = np.linalg.inv(sigma)
        if not sigma_det > 0:  # written this way so NaN is rejected too
            raise np.linalg.LinAlgError(
                f"Covariance matrix must have a positive determinant, got {sigma_det}."
            )
        n = sigma.shape[-1]
        norm = 1 / ((2 * np.pi) ** (n / 2) * np.sqrt(sigma_det))
        return sigma_det, sigma_inv, norm

    def updateState(self, F, d):
        """
        Apply the affine phase-space map r -> F r + d to the Gaussian state.

        The mean becomes ``F @ mean + d`` (row-wise for a 2D mean). The covariance
        becomes ``F @ sigma @ F.T``. Cached determinant, inverse and normalisation
        are recomputed.

        Parameters
        -----------
        F (ndarray[N, N]) : Linear transformation matrix.
        d (ndarray[N]) : Displacement added to the mean.

        Raises
        ------
        ValueError : if ``F`` is not (N, N), if ``d`` does not broadcast against the
            mean, or if the transformed covariance is singular / has a non-positive
            determinant. The state is left unchanged in every case.
        """
        F = np.asarray(F)
        if F.shape != (self.dim, self.dim):
            raise ValueError(
                f"Transformation matrix F must have shape {(self.dim, self.dim)}, got {F.shape}."
            )

        # mean @ F.T equals F @ mean for a 1D mean, and transforms each row of a 2D mean.
        new_mean = self.mean @ F.T + d
        new_sigma = F @ self.sigma @ F.T

        # Validate before committing, so a failed update leaves the object consistent.
        try:
            sigma_det, sigma_inv, norm = self._gaussian_cache(new_sigma)
        except np.linalg.LinAlgError as e:
            raise ValueError(
                "Covariance matrix is singular or not positive definite."
                f"\nDeterminant :\n{np.linalg.det(new_sigma)}"
                f"\nCovariance :\n{new_sigma}"
            ) from e

        self.mean = new_mean
        self.sigma = new_sigma
        self.state["mean"] = self.mean
        self.state["cov"] = self.sigma
        self.sigmaDet = sigma_det
        self.sigmaInv = sigma_inv
        self.norm = norm

    def displacement(self, alpha):
        """
        Displace the state by complex amplitude(s) ``alpha``, one per mode.

        Mode j contributes ``sqrt(2) * (Re(alpha_j), Im(alpha_j))`` to the mean.
        With M amplitudes, ``2 * M`` must therefore equal ``self.dim``.

        Parameters
        -----------
        alpha (complex | list | tuple | ndarray[M]) : Complex displacement amplitude(s).

        Raises
        ------
        TypeError : if ``alpha`` is not a number, list, tuple or numpy array.
        AssertionError : if the number of amplitudes does not match ``self.dim // 2``.
        """
        if isinstance(alpha, (int, float, complex, np.number)):
            alpha = np.atleast_1d(alpha)
        elif isinstance(alpha, (np.ndarray, list, tuple)):
            alpha = np.atleast_1d(np.asarray(alpha))
        else:
            raise TypeError("Displacement alpha must be a (complex) number, list, or numpy array.")

        # Each complex amplitude fills two phase-space coordinates (x, p).
        assert alpha.ndim == 1 and 2 * alpha.size == self.dim, (
            f"Displacement vector of shape {alpha.shape} must provide one complex amplitude "
            f"per mode, i.e. {self.dim // 2} value(s) for phase-space dimension {self.dim}."
        )

        # Interleave to [Re(alpha1), Im(alpha1), Re(alpha2), Im(alpha2), ...].
        alphavec = np.column_stack((np.real(alpha), np.imag(alpha))).ravel()
        # NOTE(review): the sqrt(2) factor presumably encodes an x = sqrt(2) Re(alpha)
        # quadrature convention — confirm against the intended hbar convention.
        d = np.sqrt(2) * alphavec
        self.updateState(np.eye(self.dim), d)

    def phaseShift(self, phi):
        """
        Rotate each mode's (x, p) pair by the angle ``phi``.

        Parameters
        -----------
        phi (float) : Rotation angle in radians.

        Raises
        ------
        AssertionError : if ``self.dim`` is odd (no (x, p) pairing possible).
        """
        rotation = np.array([[np.cos(phi), -np.sin(phi)],
                             [np.sin(phi), np.cos(phi)]])
        n_modes, remainder = divmod(self.dim, 2)
        assert remainder == 0, f"Phase shift needs an even phase-space dimension, got {self.dim}."
        # Block-diagonal: the same rotation acts on every mode; equals `rotation` when dim == 2.
        F = np.kron(np.eye(n_modes), rotation)
        self.updateState(F, np.zeros(self.dim))

    def BS(self, eta):
        """
        Apply a beam splitter transformation to the Wigner function.

        Parameters
        -----------
        eta (float) : Transmittance of the beam splitter (0 <= eta <= 1).

        NOTE(review): the 2x2 matrix below acts on a single (x, p) pair and has
        determinant -1, so it is not symplectic. A physical beam splitter mixes
        two modes (a 4-D phase space). Confirm the intended model.
        """
        assert 0 <= eta <= 1, "Transmittance must be between 0 and 1."
        F = np.array([[np.sqrt(eta), np.sqrt(1 - eta)],
                      [np.sqrt(1 - eta), -np.sqrt(eta)]])
        d = np.zeros(self.dim)
        self.updateState(F, d)

    def __repr__(self):
        out_string  = f"Wigner function :\n"
        out_string += f"        Mean vector : {self.mean}\n"
        out_string += f"  Covariance matrix :\n"
        out_string += f"{self.sigma}, shape={self.sigma.shape}\n"
        return out_string

    def __call__(self, r, kwargs=None):
        """
        Evaluate the Wigner function at points r.

        Parameters
        -----------
        r (ndarray[M, N]) : Points in phase space where the Wigner function is evaluated.
            Each row is one point in N-dimensional space; a single 1D point is also accepted.
        kwargs (dict) : Optional dictionary to update the state of the Wigner function.
            It can contain 'mean' and/or 'cov'. The update persists on the instance, and
            the cached inverse/determinant/normalisation are recomputed.

        Returns
        --------
        ndarray[M] : Values of the Wigner function at the points r.

        Raises
        ------
        ValueError : if ``kwargs`` contains neither 'mean' nor 'cov', or if the updated
            mean and covariance have inconsistent shapes.
        np.linalg.LinAlgError : if an updated covariance is singular or has a
            non-positive determinant.
        """
        if kwargs is not None:
            if not any(key in kwargs for key in ("mean", "cov")):
                raise ValueError("kwargs must contain 'mean' or 'cov' to update the state.")
            new_mean = np.asarray(kwargs["mean"]) if "mean" in kwargs else self.mean
            new_sigma = np.asarray(kwargs["cov"]) if "cov" in kwargs else self.sigma
            n = new_mean.shape[-1]
            if new_sigma.shape != (n, n):
                raise ValueError(
                    f"Covariance shape {new_sigma.shape} does not match mean dimension {n}."
                )
            # Compute the cache first so a bad covariance leaves the state untouched.
            self.sigmaDet, self.sigmaInv, self.norm = self._gaussian_cache(new_sigma)
            self.mean = new_mean
            self.sigma = new_sigma
            self.dim = n
            self.state["mean"] = self.mean
            self.state["cov"] = self.sigma

        r = np.atleast_2d(r)  # a single point becomes shape (1, N)
        assert r.shape[-1] == self.mean.shape[-1], "Each input point must match mean's shape"
        diff = r - self.mean

        # Row-wise quadratic form diff_n^T Sigma^{-1} diff_n.
        exponent = -0.5 * np.einsum("ni,ij,nj->n", diff, self.sigmaInv, diff)

        return self.norm * np.exp(exponent)



def test_input(size):
    I = np.eye(size)
    sigma = I / np.linalg.norm(I)
    r = np.zeros((size, ))
    return r, sigma

# Example input: a Gaussian over a 2-D phase space, centred at the origin.
rbar, sigma = test_input(2)
wigner = Wigner(rbar, sigma)

# Sample a 200 x 200 grid spanning [-5, 5] in both x and p.
x = np.linspace(-5, 5, 200)
p = np.linspace(-5, 5, 200)
X, P = np.meshgrid(x, p)
grid_points = np.column_stack((X.ravel(), P.ravel()))  # one (x, p) point per row

# Evaluate on the flattened grid, then fold the values back onto the grid layout.
wigner_values = wigner(grid_points).reshape(X.shape)

print(f"mean :\n{wigner.mean}\n\ncov :\n{wigner.sigma}")
with np.printoptions(precision=3, linewidth=160, suppress=True,
                     threshold=600, infstr=" oo "):
    print(f"Wigner function values:\n{wigner_values}")