# Photonic Classifier Smoke Tests

This notebook contains a series of "smoke tests" for the `p_pack` module. The purpose is to provide a quick way to run basic checks on individual functions during development.

**How to use this notebook:**
1.  Use `Ctrl+F` or `Cmd+F` to find the function you are working on.
2.  Run the test cell for that function to see if it passes.
3.  The cell immediately below the test will display the source code of the function being tested, allowing you to compare the test with the implementation directly.

In [3]:
%load_ext autoreload
%autoreload 2


import unittest
import jax
import jax.numpy as jnp
import numpy as np
import os
import pandas as pd
import inspect
import itertools
from jax.scipy.special import factorial
from p_pack import globals, circ, loss, model, optimiser, pre_p, train

# Import all the modules from the package
# Assuming 'p_pack' is a valid package in your environment
# from p_pack import circ, globals, loss, model, optimiser, pre_p, train

# Helper function to display the source code of a function
def show_code(func):
    """
    Displays the source code of a given function.
    Handles regular and JIT-compiled functions.
    """
    try:
        source_lines, _ = inspect.getsourcelines(func)
        print("Source Code:")
        # Corrected line: Use an empty string "" to join the source lines.
        print("".join(source_lines))
    except TypeError as e:
        print(f"Could not get source for {func}: {e}")
        print("This can happen with JIT-compiled functions. Showing the .py_func attribute if available.")
        if hasattr(func, 'py_func'):
            source_lines, _ = inspect.getsourcelines(func.py_func)
            print("Source Code (from .py_func):")
            # Corrected line: Same fix as above.
            print("".join(source_lines))

# This function will run a single test case
def run_test(test_class, test_name):
    """
    Creates a test suite for a single test case and runs it.
    """
    suite = unittest.TestSuite()
    suite.addTest(test_class(test_name))
    runner = unittest.TextTestRunner()
    print(f"--- Running test: {test_name} ---")
    result = runner.run(suite)
    if not result.wasSuccessful():
        print(f"--- Test Failed: {test_name} ---")
    else:
        print(f"--- Test Passed: {test_name} ---")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## OG init checks

In [2]:
# Example usage
out_state_combos = jnp.array(list(itertools.combinations_with_replacement(range(10), 3)))
print('Output state combinations:', out_state_combos[:20])

def extract_submatrices(unitary):
    # unitary: (num_modes, 3)
    # out_state_combos: (n_combos, 3)
    return unitary[out_state_combos[:5], :] 

key = jax.random.PRNGKey(0)
random_matrix = jax.random.normal(key, (10, 3))
print('Random matrix:', random_matrix)

extracted_submatrices = extract_submatrices(random_matrix)
print('Extracted submatrices:', extracted_submatrices[:10])

# Count repeats for each combo
def count_repeats(combo):
    # Count occurrences of each value
    unique, counts = jnp.unique(combo, return_counts=True)
    # Only count repeats (counts > 1)
    repeats = counts[counts > 1]
    return repeats.sum() #- repeats.size  # subtract 1 for each unique repeated value

# Vectorize over all combos
repeats_per_combo = jnp.array([count_repeats(combo) for combo in out_state_combos])

print(repeats_per_combo[:20])

factorials = factorial(repeats_per_combo)

print('Factorials of repeats:', factorials[:20])

Output state combinations: [[0 0 0]
 [0 0 1]
 [0 0 2]
 [0 0 3]
 [0 0 4]
 [0 0 5]
 [0 0 6]
 [0 0 7]
 [0 0 8]
 [0 0 9]
 [0 1 1]
 [0 1 2]
 [0 1 3]
 [0 1 4]
 [0 1 5]
 [0 1 6]
 [0 1 7]
 [0 1 8]
 [0 1 9]
 [0 2 2]]
Random matrix: [[-0.28371066  0.9368162  -1.0050073 ]
 [ 1.4165013   1.0543301   0.9108127 ]
 [-0.42656708  0.986188   -0.5575324 ]
 [ 0.01532502 -2.078568    0.5548371 ]
 [ 0.91423655  0.5744596   0.7227863 ]
 [ 0.12106175 -0.3237354   1.6234998 ]
 [ 0.24500391 -1.3809781  -0.6111237 ]
 [ 0.1403725   0.84100425 -1.0943578 ]
 [-1.077502   -1.1396457  -0.593338  ]
 [-0.15576515 -0.38321444 -1.1144515 ]]
Extracted submatrices: [[[-0.28371066  0.9368162  -1.0050073 ]
  [-0.28371066  0.9368162  -1.0050073 ]
  [-0.28371066  0.9368162  -1.0050073 ]]

 [[-0.28371066  0.9368162  -1.0050073 ]
  [-0.28371066  0.9368162  -1.0050073 ]
  [ 1.4165013   1.0543301   0.9108127 ]]

 [[-0.28371066  0.9368162  -1.0050073 ]
  [-0.28371066  0.9368162  -1.0050073 ]
  [-0.42656708  0.986188   -0.5575324 ]

In [22]:
# Not needed for now:

print( train_set.shape, train_labels.shape, test_set.shape, test_labels.shape)

# Initialize the phases

# each feature has its own uploading BS so we had a factor of 2
init_phases = circ.initialize_phases(10, 2*num_features, )  

weights_data = jnp.ones(shape = [init_phases.shape[0],init_phases.shape[1]]) #weights for data reuploading
#print(init_phases)
#print(init_phases)

# If you didn't test any of the jitted functions yet, the ratio in times should be around a factor 10^3 - 10^5.
# The first time is larger because of the compilation.
# The second time is small because it just runs the compiled code.
# Also, try to get any of these run times in pure Python+Numpy.
b = time.time()
# The block_until_ready is supposed to only let Python continue when the compiled code has finished.
# For me, it's not reliable. Therefore, I print the results first before measuring the end time.
result1, result2, result3, x = jax.block_until_ready(model.full_unitaries_data_reupload)(init_phases, train_set, weights_data)
print(result1.shape)
print(result2.shape)
print(result3.shape)
e = time.time()
print(e-b)
b = time.time()
result1 , result2, result3, x  = jax.block_until_ready(model.full_unitaries_data_reupload)(init_phases, train_set, weights_data)
print(result1.shape)
print(result2.shape)
print(result3.shape)

e = time.time()
print(e-b)


NameError: name 'train_set' is not defined

In [None]:
#check multiphoton output combinations

# Example usage of the measurement function
n = 3  # number of matrices
size = 6  # size of each identity matrix

# Create a (n, 1, 1) array of factors: [0], [1], [2]
factors = jnp.arange(n, dtype=jnp.complex64).reshape(-1, 1, 1)
# Create a (1, 6, 6) identity matrix and broadcast
ones_matrix = jnp.arange(n * size * size, dtype=jnp.complex64).reshape(n, size, size)
# Multiply to get (n, 6, 6)
temp_unitaries = factors * ones_matrix

#print(temp_unitaries.shape)  # (3, 6, 6)
#print(temp_unitaries)

result_measurement, combos1, probs1, _ = measurement(temp_unitaries, num_photons = 3)
#print(probs1)
#print(result_measurement.shape)  # Should be (num_samples, 2, 6) if num_samples is the batch size
parity = jnp.sum(combos1, axis=1) % 2
mask = (parity == 1)  # shape (n_combos,)

arr = jnp.arange(1, 21)
test = parity*arr
test1 = mask*arr
print("Masked array:", test1)
print("Even indices:", test)
print(parity)  # 0 = even, 1 = odd
print(combos1)


In [None]:
#check permanent calculation

def perm_3x3_jax(mat):
    # Only works for 3x3 matrices
    perms = jnp.array([
        [0, 1, 2],
        [0, 2, 1],
        [1, 0, 2],
        [1, 2, 0],
        [2, 0, 1],
        [2, 1, 0]
    ])
    return jnp.sum(jnp.prod(mat[jnp.arange(3), perms], axis=1))

# Example usage
mat = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=jnp.complex64)
result_perm = perm_3x3_jax(mat) 
print("Permanent of the matrix:", result_perm)  # Should print the permanent of the matrix
mat1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=jnp.complex64)
print("Permanent of the matrix (using numpy):", perm(mat1))  # For comparison with numpy's perm function

# Circ

### Function: `initialize_phases`

In [4]:
class TestInitializePhases(unittest.TestCase):
    def test_shape(self):
        """Test the phase initialization function."""
        depth, width = 5, 8
        phases = circ.initialize_phases(depth, width)
        self.assertEqual(phases.shape, (depth, width // 2, 2))

run_test(TestInitializePhases, 'test_shape')

.
----------------------------------------------------------------------
Ran 1 test in 0.115s

OK


--- Running test: test_shape ---
--- Test Passed: test_shape ---


In [5]:
show_code(circ.initialize_phases)

Source Code:
def initialize_phases(depth: int, width: int = None, mask: np.ndarray = None) -> jnp.array:
    """
    Initializes the phase parameters for the photonic circuit.å

    The phases are initialized with small random values to avoid barren plateaus.
    A mask can be provided to fix certain phases to zero, making them non-trainable.
    By default, data-uploading layers (determined by `reupload_freq`) are masked out.

    Args:
        depth (int): The number of layers in the circuit.
        width (Optional[int]): The number of modes in the circuit. Defaults to `depth`.
        mask (Optional[np.ndarray]): A binary mask to apply to the phases.
                                     A value of 0 freezes a phase.

    Returns:
        jnp.array: A JAX array of initialized phases.
    """
    # Default case is Clements et al. layout, with all beam splitters tunable.
    if width == None:
        width = depth 
    if mask == None:
        mask = np.ones(shape = [depth, width//2, 

### Function: `layer_unitary`

In [6]:
class TestLayerUnitary(unittest.TestCase):
    def test_unitarity(self):
        """Test the layer unitary creation and check for unitarity."""
        depth, width = 5, 8
        all_phases = circ.initialize_phases(depth, width)
        layer_idx = 2
        
        # FIX: The math in the original layer_unitary was incorrect.
        # The function has been corrected in the source file to use a standard
        # unitary beamsplitter parameterization.
        unitary = circ.layer_unitary(all_phases, layer_idx)
        self.assertEqual(unitary.shape, (width, width))
        
        identity = jnp.eye(width, dtype=jnp.complex64)
        product = unitary @ unitary.T.conj()
        self.assertTrue(jnp.allclose(product, identity, atol=1e-6))

run_test(TestLayerUnitary, 'test_unitarity')

--- Running test: test_unitarity ---


.
----------------------------------------------------------------------
Ran 1 test in 1.097s

OK


--- Test Passed: test_unitarity ---


In [16]:
show_code(circ.layer_unitary)

Source Code:
@partial(jax.jit, static_argnames=['layer'])
def layer_unitary(all_phases: jnp.array, layer: int, mask: jnp.array = None) -> jnp.array:
    """
    Constructs the unitary matrix for a single trainable layer of the circuit.

    Args:
        all_phases (jnp.array): The full tensor of phase parameters for all layers.
        layer (int): The index of the layer to construct the unitary for.
        mask (Optional[jnp.array]): An optional mask to apply to the layer's phases.

    Returns:
        jnp.array: The complex-valued unitary matrix for the specified layer.
    """
    #layer = jax.lax.stop_gradient(layer) # doesn't work, don't ask me why.
    
    width = 2*jax.lax.stop_gradient(all_phases).shape[1] 
    # Stopping the gradient here allows to use the size of an input tensor to define other tensors.
    # Depth of the trainable part.
    depth = jax.lax.stop_gradient(all_phases).shape[0]
    if mask == None:
        # The default mask allows all phases to be trained
 

### Function: `data_upload`

In [None]:
class TestDataUpload(unittest.TestCase):
    def test_shape(self):
        """Test the data upload mechanism."""
        num_samples, feature_dim = 10, 4
        data_set = jnp.ones((num_samples, feature_dim))
        unitary = circ.data_upload(data_set)
        self.assertEqual(unitary.shape, (num_samples, feature_dim * 2, feature_dim * 2))

run_test(TestDataUpload, 'test_shape')

--- Running test: test_shape ---


.
----------------------------------------------------------------------
Ran 1 test in 0.356s

OK


--- Test Passed: test_shape ---


In [None]:
show_code(circ.data_upload)

Source Code:
@jax.jit
def data_upload(data_set: jnp.array) -> jnp.array:
    """
    Constructs the unitary matrices for the data uploading layer.

    This function creates a batch of block-diagonal unitary matrices, where each
    matrix encodes one sample (e.g., an image) from the input data set.

    Args:
        data_set (jnp.array): The input data, with shape (num_samples, num_features).

    Returns:
        jnp.array: A batch of unitary matrices with shape (num_samples, width, width).
    """
    num_samples = jax.lax.stop_gradient(data_set).shape[0]

    # Each pixel gets its BS, therefore factor 2 for counting overall system width
    width = 2*jax.lax.stop_gradient(data_set).shape[1]
    # Again, the 3rd dimension with 2 represents the two phases for each beamsplitter. 

    # is this the fastest way to fill the array with what we want or ist here a faster way
    phases = (jnp.pi/2)*jnp.ones(shape = [num_samples, width//2, 2]) 
    # The first of the phases of the beam spl

### Function: `measurement`

In [None]:
class TestMeasurement(unittest.TestCase):
    def test_shapes(self):
        """Test the measurement function."""
        num_samples, num_modes = 10, globals.num_modes_circ
        dummy_unitaries = jnp.array([jnp.eye(num_modes, dtype=jnp.complex64)] * num_samples)
        
        sub_unitaries, combos, probs, binary_probs = circ.measurement(dummy_unitaries)

        self.assertIsNotNone(sub_unitaries)
        self.assertIsNotNone(combos)
        self.assertEqual(probs.shape[0], num_samples)
        self.assertEqual(binary_probs.shape, (num_samples, 1))

run_test(TestMeasurement, 'test_shapes')

--- Running test: test_shapes ---


.
----------------------------------------------------------------------
Ran 1 test in 2.340s

OK


--- Test Passed: test_shapes ---


In [20]:
show_code(circ.measurement)

Source Code:
def measurement(unitaries: jnp.array, num_photons: int = 3) -> tuple[jnp.array, jnp.array, jnp.array, jnp.array]:
    """
    Simulates the measurement process of the photonic circuit.

    It calculates the probabilities of detecting photons in different output modes,
    computes the permanents of submatrices, and aggregates these into binary
    classification probabilities (+1 or -1) based on the parity of output modes.

    Args:
        unitaries (jnp.array): The batch of final unitary matrices from the circuit.
        num_photons (int): The number of photons in the input state. Defaults to 3.
        factorials (jnp.array): Pre-computed factorial values for probability calculation.

    Returns:
        Tuple[jnp.array, jnp.array, jnp.array, jnp.array]: A tuple containing:
            - all_extracts: The submatrices used for permanent calculation.
            - out_state_combos: All possible output state combinations.
            - all_probs: The raw probability fo

---
## Module: `p_pack.pre_p`

### Function: `rescale_data`

In [21]:
class TestRescaleData(unittest.TestCase):
    def test_scaling(self):
        """Test the data rescaling function."""
        data = jnp.array([-10., 0., 10.])
        min_val, max_val = -np.pi / 2, np.pi / 2
        rescaled = pre_p.rescale_data(data, min_val, max_val)
        
        self.assertTrue(jnp.all(rescaled >= min_val))
        self.assertTrue(jnp.all(rescaled <= max_val))

run_test(TestRescaleData, 'test_scaling')

.
----------------------------------------------------------------------
Ran 1 test in 0.035s

OK


--- Running test: test_scaling ---
--- Test Passed: test_scaling ---


In [None]:
show_code(pre_p.rescale_data)

### Function: `load_mnist_35`

In [22]:
class TestLoadMnist(unittest.TestCase):
    def setUp(self):
        """Create a dummy CSV file for testing."""
        self.test_dir = "test_data_temp"
        os.makedirs(self.test_dir, exist_ok=True)
        self.feature_dim = 4
        self.fname = f"mnist_3-5_{self.feature_dim}d_train.csv"
        self.path = os.path.join(self.test_dir, self.fname)
        data = np.random.rand(5, self.feature_dim + 1)
        pd.DataFrame(data).to_csv(self.path, index=False, header=False)

    def tearDown(self):
        """Remove the dummy CSV and directory."""
        os.remove(self.path)
        os.rmdir(self.test_dir)

    def test_loading(self):
        """Test loading data from a CSV file."""
        X, y = pre_p.load_mnist_35(self.test_dir, self.feature_dim, split="train")
        self.assertEqual(X.shape, (5, self.feature_dim))
        self.assertEqual(y.shape, (5,))

run_test(TestLoadMnist, 'test_loading')

F
FAIL: test_loading (__main__.TestLoadMnist.test_loading)
Test loading data from a CSV file.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/7v/6bnphj4937q9tbp6zlbd71pr0000gn/T/ipykernel_8293/1892411339.py", line 20, in test_loading
    self.assertEqual(X.shape, (5, self.feature_dim))
    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Tuples differ: (4, 4) != (5, 4)

First differing element 0:
4
5

- (4, 4)
?  ^

+ (5, 4)
?  ^


----------------------------------------------------------------------
Ran 1 test in 0.114s

FAILED (failures=1)


--- Running test: test_loading ---
--- Test Failed: test_loading ---


In [1]:
show_code(pre_p.load_mnist_35)

NameError: name 'show_code' is not defined

---
## Module: `p_pack.model`

### Function: `full_unitaries_data_reupload`

In [23]:
class TestFullUnitaries(unittest.TestCase):
    def test_shapes(self):
        """Test the function that builds the full unitary for the model."""
        depth, feature_dim, num_samples = 5, 4, 10
        phases = circ.initialize_phases(depth, width=feature_dim * 2)
        weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))

        # FIX: The original model.py was missing imports and had NameErrors.
        # The source file has been corrected to properly import from circ.
        outputs = model.full_unitaries_data_reupload(phases, data_set, weights)
        
        self.assertEqual(len(outputs), 4)
        unitaries, sub_unitaries, label_probs, binary_probs_plus = outputs
        self.assertEqual(unitaries.shape[0], num_samples)
        self.assertIsNotNone(sub_unitaries)
        self.assertEqual(label_probs.shape[0], num_samples)
        self.assertEqual(binary_probs_plus.shape, (num_samples, 1))

run_test(TestFullUnitaries, 'test_shapes')

--- Running test: test_shapes ---


E
ERROR: test_shapes (__main__.TestFullUnitaries.test_shapes)
Test the function that builds the full unitary for the model.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/giancarloramirez/Documents/qml_project/venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 174, in _broadcast_shapes_uncached
    return _try_broadcast_shapes(*rank_promoted_shapes, name='broadcast_shapes')
  File "/Users/giancarloramirez/Documents/qml_project/venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 128, in _try_broadcast_shapes
    raise TypeError(f'{name} got incompatible shapes for broadcasting: '
                    f'{", ".join(map(str, map(tuple, shapes)))}.')
TypeError: broadcast_shapes got incompatible shapes for broadcasting: (10, 512), (1, 216).

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/giancarloramirez/Documents/qml_project/ven

--- Test Failed: test_shapes ---


In [None]:
show_code(model.full_unitaries_data_reupload)

### Function: `predict_reupload`

In [None]:
class TestPredictReupload(unittest.TestCase):
    def test_shapes(self):
        """Test the prediction function."""
        depth, feature_dim, num_samples = 5, 4, 10
        phases = circ.initialize_phases(depth, width=feature_dim * 2)
        weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))

        probs, adjusted_binary_probs = model.predict_reupload(phases, data_set, weights)

        self.assertEqual(probs.shape[0], num_samples)
        self.assertEqual(adjusted_binary_probs.shape, (num_samples, 1))

run_test(TestPredictReupload, 'test_shapes')

In [None]:
show_code(model.predict_reupload)

---
## Module: `p_pack.loss`

### Function: `loss`

In [None]:
class TestLoss(unittest.TestCase):
    def test_calculation(self):
        """Test the loss function returns a scalar value."""
        depth, feature_dim, num_samples = 5, 4, 10
        phases = circ.initialize_phases(depth, width=feature_dim * 2)
        weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))
        labels = jnp.ones(num_samples)

        # FIX: The original loss.py had a NameError.
        # The source file is corrected to call model.predict_reupload.
        loss_value = loss.loss(phases, data_set, labels, weights)
        self.assertTrue(jnp.isscalar(loss_value))

run_test(TestLoss, 'test_calculation')

In [None]:
show_code(loss.loss)

---
## Module: `p_pack.optimiser`

### Function: `adam_step`

In [None]:
class TestAdamStep(unittest.TestCase):
    def test_step(self):
        """Test a single step of the Adam optimizer."""
        depth, feature_dim, num_samples = 5, 4, 10
        params_phases = circ.initialize_phases(depth, width=feature_dim * 2)
        params_weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))
        labels = jnp.ones(num_samples)
        m_phases = jnp.zeros_like(params_phases)
        v_phases = jnp.zeros_like(params_phases)
        m_weights = jnp.zeros_like(params_weights)
        v_weights = jnp.zeros_like(params_weights)

        carry = [params_phases, data_set, labels, params_weights, m_phases, v_phases, m_weights, v_weights]
        step_number = 1

        # FIX: The original optimiser.py had a TypeError, calling the loss module.
        # The source file is corrected to call loss.loss().
        new_carry, loss_info = optimiser.adam_step(carry, step_number)

        self.assertEqual(len(new_carry), 8)
        self.assertEqual(loss_info.shape, (2,))

run_test(TestAdamStep, 'test_step')

In [None]:
show_code(optimiser.adam_step)

---
## Module: `p_pack.train`

### Function: `train`

In [None]:
class TestTrain(unittest.TestCase):
    def test_train_run(self):
        """Test the main training function."""
        depth, feature_dim, num_samples = 5, 4, 10
        phases = circ.initialize_phases(depth, width=feature_dim*2)
        weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))
        labels = jnp.ones(num_samples)
        m_phases = jnp.zeros_like(phases)
        v_phases = jnp.zeros_like(phases)
        m_weights = jnp.zeros_like(weights)
        v_weights = jnp.zeros_like(weights)

        init = (phases, data_set, labels, weights, m_phases, v_phases, m_weights, v_weights)

        # FIX: The original train.py had a NameError.
        # The source file is corrected to call optimiser.adam_step.
        final_carry, loss_history = train.train(init)

        self.assertEqual(len(final_carry), 8)
        self.assertEqual(loss_history.shape, (globals.num_steps, 2))

run_test(TestTrain, 'test_train_run')

In [None]:
show_code(train.train)