In [5]:
import dynamiqs as dq
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from scipy.integrate import simpson
from scipy.ndimage import gaussian_filter

# Task A

In [7]:
# Utility functions for applying and removing affine distortions of Wigner functions.

def affine_distort(wigner, alpha, b):
    """Apply the affine map ``W -> alpha * W + b`` to a Wigner function.

    Parameters
    ----------
    wigner : array-like
        Sampled Wigner function (any shape; the operation is elementwise).
    alpha : scalar
        Multiplicative scale factor.
    b : scalar
        Additive offset.

    Returns
    -------
    array-like
        ``alpha * wigner + b``.
    """
    scaled = alpha * wigner
    return scaled + b

def add_gaussian_noise(wigner, sigma, key=None):
    """Return ``wigner`` plus i.i.d. zero-mean Gaussian noise of standard deviation ``sigma``.

    Parameters
    ----------
    wigner : array-like
        Sampled Wigner function; the noise array has the same shape.
    sigma : scalar
        Standard deviation of the additive noise.
    key : jax PRNG key, optional
        Key used to draw the noise. Defaults to ``jax.random.key(0)``, which
        reproduces the previous behaviour. With the default, every call on an
        array of the same shape receives *identical* noise, so pass a fresh key
        (e.g. from ``jax.random.split``) to get independent realisations.

    Returns
    -------
    jax.Array
        ``wigner + sigma * N(0, 1)`` with the same shape as ``wigner``.
    """
    if key is None:
        key = jax.random.key(0)
    noise = sigma * jax.random.normal(key, jnp.shape(wigner))
    return jnp.add(wigner, noise)

def integrate_wigner(wigner, x, y):
    """Integrate a sampled Wigner function over its 2-D grid with Simpson's rule.

    For a properly normalised Wigner function the result is 1.

    Parameters
    ----------
    wigner : array-like, 2-D
        Samples indexed as ``wigner[i, j]`` <-> ``(x[i], y[j])``: the last axis
        is integrated against ``y``, the remaining axis against ``x``.
    x : 1-D array-like
        Sample points along the first axis.
    y : 1-D array-like
        Sample points along the last axis.

    Returns
    -------
    float
        The double integral of ``wigner`` over the grid.
    """
    # Collapse the y direction first, leaving a 1-D profile over x ...
    x_profile = simpson(wigner, x=y, axis=-1)
    # ... then integrate that profile over x.
    return simpson(x_profile, x=x, axis=-1)

def recover_affine_offset(wigner):
    """Estimate the additive offset ``b`` of an affinely distorted Wigner function.

    The estimate is the average of the means of the four boundary lines of the
    grid (first/last row, first/last column). It assumes the undistorted
    function is ~0 on the grid boundary. Corner samples belong to two boundary
    lines and so are counted twice; this is assumed to be negligible.

    Parameters
    ----------
    wigner : 2-D array-like
        Sampled (distorted) Wigner function.

    Returns
    -------
    jax.Array
        Scalar offset estimate.
    """
    n_rows, n_cols = wigner.shape
    boundary_lines = (
        wigner[0, :],            # first row
        wigner[n_rows - 1, :],   # last row
        wigner[:, 0],            # first column
        wigner[:, n_cols - 1],   # last column
    )
    line_means = [jnp.mean(line) for line in boundary_lines]
    return 0.25 * (line_means[0] + line_means[1] + line_means[2] + line_means[3])

def recover_affine_scaling(zeroed_wigner, x, y):
    """Estimate the multiplicative scale ``alpha`` of an affinely distorted Wigner function.

    ``zeroed_wigner`` must already have the offset ``b`` removed. Since
    ``alpha * W`` integrates to ``alpha`` when the undistorted ``W`` integrates
    to 1, the integral over the grid is the scale estimate.

    Parameters
    ----------
    zeroed_wigner : 2-D array-like
        Offset-corrected Wigner samples, indexed ``[x_index, y_index]``.
    x, y : 1-D array-like
        Grid sample points along the first and last axis respectively.

    Returns
    -------
    float
        Estimated ``alpha``.
    """
    total_weight = integrate_wigner(zeroed_wigner, x, y)
    return total_weight

def remove_affine_distortion(affine_wigner, x, y):
    """Undo an affine distortion ``alpha * W + b`` of a sampled Wigner function.

    The offset ``b`` is estimated from the grid boundary, where the undistorted
    function is assumed to be negligible. The scale ``alpha`` is then estimated
    as the integral of the offset-corrected map, assuming the undistorted
    function integrates to 1.

    Parameters
    ----------
    affine_wigner : 2-D array-like
        Distorted Wigner samples, indexed ``[x_index, y_index]``.
    x, y : 1-D array-like
        Grid sample points along the first and last axis respectively.

    Returns
    -------
    array-like
        ``(affine_wigner - b) / alpha``, the recovered Wigner function.

    Raises
    ------
    ValueError
        If the offset-corrected map integrates to exactly zero, so ``alpha``
        cannot be recovered.
    """
    b = recover_affine_offset(affine_wigner)
    zeroed_wigner = affine_wigner - b
    # The scale must be measured on the offset-corrected input itself,
    # not on any other (e.g. global) Wigner array.
    alpha = recover_affine_scaling(zeroed_wigner, x, y)
    if alpha == 0:
        raise ValueError("cannot recover affine scale: offset-corrected Wigner function integrates to zero")
    return zeroed_wigner / alpha

def remove_affine_gaussian_noise(noisy_wigner, x, y, sigma):
    """Denoise a Wigner function: undo its affine distortion, then Gaussian-smooth it.

    Parameters
    ----------
    noisy_wigner : 2-D array-like
        Affinely distorted, noisy Wigner samples, indexed ``[x_index, y_index]``.
    x, y : 1-D array-like
        Grid sample points along the first and last axis respectively.
    sigma : scalar or sequence of scalars
        Standard deviation of the Gaussian smoothing kernel, passed to
        ``scipy.ndimage.gaussian_filter``. It is measured in grid pixels, not
        in phase-space units, and is unrelated to the noise level.

    Returns
    -------
    numpy.ndarray
        The affine-corrected, smoothed Wigner function.
    """
    # Remove the affine distortion (name was previously misspelled -> NameError).
    affine_less_wigner = remove_affine_distortion(noisy_wigner, x, y)
    # Suppress the remaining high-frequency noise.
    return gaussian_filter(affine_less_wigner, sigma)