# Photonic Classifier Smoke Tests

This notebook contains a series of "smoke tests" for the `p_pack` module. The purpose is to provide a quick way to run basic checks on individual functions during development.

**How to use this notebook:**
1.  Use `Ctrl+F` or `Cmd+F` to find the function you are working on.
2.  Run the test cell for that function to see if it passes.
3.  The cell immediately below the test will display the source code of the function being tested, allowing you to compare the test with the implementation directly.

In [15]:
%load_ext autoreload
%autoreload 2


import unittest
import jax
import jax.numpy as jnp
import numpy as np
import os
import pandas as pd
import inspect
from p_pack import globals, circ, loss, model, optimiser, pre_p, train

# Import all the modules from the package
# Assuming 'p_pack' is a valid package in your environment
# from p_pack import circ, globals, loss, model, optimiser, pre_p, train

# Helper function to display the source code of a function
def show_code(func):
    """
    Displays the source code of a given function.
    Handles regular and JIT-compiled functions.
    """
    try:
        source_lines, _ = inspect.getsourcelines(func)
        print("Source Code:")
        # Corrected line: Use an empty string "" to join the source lines.
        print("".join(source_lines))
    except TypeError as e:
        print(f"Could not get source for {func}: {e}")
        print("This can happen with JIT-compiled functions. Showing the .py_func attribute if available.")
        if hasattr(func, 'py_func'):
            source_lines, _ = inspect.getsourcelines(func.py_func)
            print("Source Code (from .py_func):")
            # Corrected line: Same fix as above.
            print("".join(source_lines))

# This function will run a single test case
def run_test(test_class, test_name):
    """
    Creates a test suite for a single test case and runs it.
    """
    suite = unittest.TestSuite()
    suite.addTest(test_class(test_name))
    runner = unittest.TextTestRunner()
    print(f"--- Running test: {test_name} ---")
    result = runner.run(suite)
    if not result.wasSuccessful():
        print(f"--- Test Failed: {test_name} ---")
    else:
        print(f"--- Test Passed: {test_name} ---")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


--- 
## Module: `p_pack.circ`

### Function: `initialize_phases`

In [None]:
class TestInitializePhases(unittest.TestCase):
    def test_shape(self):
        """Test the phase initialization function."""
        depth, width = 5, 8
        phases = circ.initialize_phases(depth, width)
        self.assertEqual(phases.shape, (depth, width // 2, 2))

run_test(TestInitializePhases, 'test_shape')

In [None]:
show_code(circ.initialize_phases)

### Function: `layer_unitary`

In [18]:
class TestLayerUnitary(unittest.TestCase):
    def test_unitarity(self):
        """Test the layer unitary creation and check for unitarity."""
        depth, width = 5, 8
        all_phases = circ.initialize_phases(depth, width)
        layer_idx = 2
        
        # FIX: The math in the original layer_unitary was incorrect.
        # The function has been corrected in the source file to use a standard
        # unitary beamsplitter parameterization.
        unitary = circ.layer_unitary(all_phases, layer_idx)
        self.assertEqual(unitary.shape, (width, width))
        
        identity = jnp.eye(width, dtype=jnp.complex64)
        product = unitary @ unitary.T.conj()
        self.assertTrue(jnp.allclose(product, identity, atol=1e-6))

run_test(TestLayerUnitary, 'test_unitarity')

.
----------------------------------------------------------------------
Ran 1 test in 0.007s

OK


--- Running test: test_unitarity ---
--- Test Passed: test_unitarity ---


In [None]:
show_code(circ.layer_unitary)

### Function: `data_upload`

In [19]:
class TestDataUpload(unittest.TestCase):
    def test_shape(self):
        """Test the data upload mechanism."""
        num_samples, feature_dim = 10, 4
        data_set = jnp.ones((num_samples, feature_dim))
        unitary = circ.data_upload(data_set)
        self.assertEqual(unitary.shape, (num_samples, feature_dim * 2, feature_dim * 2))

run_test(TestDataUpload, 'test_shape')

.
----------------------------------------------------------------------
Ran 1 test in 0.057s

OK


--- Running test: test_shape ---
--- Test Passed: test_shape ---


In [None]:
show_code(circ.data_upload)

### Function: `measurement`

In [20]:
class TestMeasurement(unittest.TestCase):
    def test_shapes(self):
        """Test the measurement function."""
        num_samples, num_modes = 10, globals.num_modes_circ
        dummy_unitaries = jnp.array([jnp.eye(num_modes, dtype=jnp.complex64)] * num_samples)
        
        sub_unitaries, combos, probs, binary_probs = circ.measurement(dummy_unitaries)

        self.assertIsNotNone(sub_unitaries)
        self.assertIsNotNone(combos)
        self.assertEqual(probs.shape[0], num_samples)
        self.assertEqual(binary_probs.shape, (num_samples, 1))

run_test(TestMeasurement, 'test_shapes')

--- Running test: test_shapes ---


.
----------------------------------------------------------------------
Ran 1 test in 0.272s

OK


--- Test Passed: test_shapes ---


In [None]:
show_code(circ.measurement)

---
## Module: `p_pack.pre_p`

### Function: `rescale_data`

In [21]:
class TestRescaleData(unittest.TestCase):
    def test_scaling(self):
        """Test the data rescaling function."""
        data = jnp.array([-10., 0., 10.])
        min_val, max_val = -np.pi / 2, np.pi / 2
        rescaled = pre_p.rescale_data(data, min_val, max_val)
        
        self.assertTrue(jnp.all(rescaled >= min_val))
        self.assertTrue(jnp.all(rescaled <= max_val))

run_test(TestRescaleData, 'test_scaling')

.
----------------------------------------------------------------------
Ran 1 test in 0.035s

OK


--- Running test: test_scaling ---
--- Test Passed: test_scaling ---


In [None]:
show_code(pre_p.rescale_data)

### Function: `load_mnist_35`

In [22]:
class TestLoadMnist(unittest.TestCase):
    def setUp(self):
        """Create a dummy CSV file for testing."""
        self.test_dir = "test_data_temp"
        os.makedirs(self.test_dir, exist_ok=True)
        self.feature_dim = 4
        self.fname = f"mnist_3-5_{self.feature_dim}d_train.csv"
        self.path = os.path.join(self.test_dir, self.fname)
        data = np.random.rand(5, self.feature_dim + 1)
        pd.DataFrame(data).to_csv(self.path, index=False, header=False)

    def tearDown(self):
        """Remove the dummy CSV and directory."""
        os.remove(self.path)
        os.rmdir(self.test_dir)

    def test_loading(self):
        """Test loading data from a CSV file."""
        X, y = pre_p.load_mnist_35(self.test_dir, self.feature_dim, split="train")
        self.assertEqual(X.shape, (5, self.feature_dim))
        self.assertEqual(y.shape, (5,))

run_test(TestLoadMnist, 'test_loading')

F
FAIL: test_loading (__main__.TestLoadMnist.test_loading)
Test loading data from a CSV file.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/7v/6bnphj4937q9tbp6zlbd71pr0000gn/T/ipykernel_8293/1892411339.py", line 20, in test_loading
    self.assertEqual(X.shape, (5, self.feature_dim))
    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Tuples differ: (4, 4) != (5, 4)

First differing element 0:
4
5

- (4, 4)
?  ^

+ (5, 4)
?  ^


----------------------------------------------------------------------
Ran 1 test in 0.114s

FAILED (failures=1)


--- Running test: test_loading ---
--- Test Failed: test_loading ---


In [None]:
show_code(pre_p.load_mnist_35)

---
## Module: `p_pack.model`

### Function: `full_unitaries_data_reupload`

In [23]:
class TestFullUnitaries(unittest.TestCase):
    def test_shapes(self):
        """Test the function that builds the full unitary for the model."""
        depth, feature_dim, num_samples = 5, 4, 10
        phases = circ.initialize_phases(depth, width=feature_dim * 2)
        weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))

        # FIX: The original model.py was missing imports and had NameErrors.
        # The source file has been corrected to properly import from circ.
        outputs = model.full_unitaries_data_reupload(phases, data_set, weights)
        
        self.assertEqual(len(outputs), 4)
        unitaries, sub_unitaries, label_probs, binary_probs_plus = outputs
        self.assertEqual(unitaries.shape[0], num_samples)
        self.assertIsNotNone(sub_unitaries)
        self.assertEqual(label_probs.shape[0], num_samples)
        self.assertEqual(binary_probs_plus.shape, (num_samples, 1))

run_test(TestFullUnitaries, 'test_shapes')

--- Running test: test_shapes ---


E
ERROR: test_shapes (__main__.TestFullUnitaries.test_shapes)
Test the function that builds the full unitary for the model.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/giancarloramirez/Documents/qml_project/venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 174, in _broadcast_shapes_uncached
    return _try_broadcast_shapes(*rank_promoted_shapes, name='broadcast_shapes')
  File "/Users/giancarloramirez/Documents/qml_project/venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 128, in _try_broadcast_shapes
    raise TypeError(f'{name} got incompatible shapes for broadcasting: '
                    f'{", ".join(map(str, map(tuple, shapes)))}.')
TypeError: broadcast_shapes got incompatible shapes for broadcasting: (10, 512), (1, 216).

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/giancarloramirez/Documents/qml_project/ven

--- Test Failed: test_shapes ---


In [None]:
show_code(model.full_unitaries_data_reupload)

### Function: `predict_reupload`

In [None]:
class TestPredictReupload(unittest.TestCase):
    def test_shapes(self):
        """Test the prediction function."""
        depth, feature_dim, num_samples = 5, 4, 10
        phases = circ.initialize_phases(depth, width=feature_dim * 2)
        weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))

        probs, adjusted_binary_probs = model.predict_reupload(phases, data_set, weights)

        self.assertEqual(probs.shape[0], num_samples)
        self.assertEqual(adjusted_binary_probs.shape, (num_samples, 1))

run_test(TestPredictReupload, 'test_shapes')

In [None]:
show_code(model.predict_reupload)

---
## Module: `p_pack.loss`

### Function: `loss`

In [None]:
class TestLoss(unittest.TestCase):
    def test_calculation(self):
        """Test the loss function returns a scalar value."""
        depth, feature_dim, num_samples = 5, 4, 10
        phases = circ.initialize_phases(depth, width=feature_dim * 2)
        weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))
        labels = jnp.ones(num_samples)

        # FIX: The original loss.py had a NameError.
        # The source file is corrected to call model.predict_reupload.
        loss_value = loss.loss(phases, data_set, labels, weights)
        self.assertTrue(jnp.isscalar(loss_value))

run_test(TestLoss, 'test_calculation')

In [None]:
show_code(loss.loss)

---
## Module: `p_pack.optimiser`

### Function: `adam_step`

In [None]:
class TestAdamStep(unittest.TestCase):
    def test_step(self):
        """Test a single step of the Adam optimizer."""
        depth, feature_dim, num_samples = 5, 4, 10
        params_phases = circ.initialize_phases(depth, width=feature_dim * 2)
        params_weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))
        labels = jnp.ones(num_samples)
        m_phases = jnp.zeros_like(params_phases)
        v_phases = jnp.zeros_like(params_phases)
        m_weights = jnp.zeros_like(params_weights)
        v_weights = jnp.zeros_like(params_weights)

        carry = [params_phases, data_set, labels, params_weights, m_phases, v_phases, m_weights, v_weights]
        step_number = 1

        # FIX: The original optimiser.py had a TypeError, calling the loss module.
        # The source file is corrected to call loss.loss().
        new_carry, loss_info = optimiser.adam_step(carry, step_number)

        self.assertEqual(len(new_carry), 8)
        self.assertEqual(loss_info.shape, (2,))

run_test(TestAdamStep, 'test_step')

In [None]:
show_code(optimiser.adam_step)

---
## Module: `p_pack.train`

### Function: `train`

In [None]:
class TestTrain(unittest.TestCase):
    def test_train_run(self):
        """Test the main training function."""
        depth, feature_dim, num_samples = 5, 4, 10
        phases = circ.initialize_phases(depth, width=feature_dim*2)
        weights = jnp.ones((depth, feature_dim))
        data_set = jnp.ones((num_samples, feature_dim))
        labels = jnp.ones(num_samples)
        m_phases = jnp.zeros_like(phases)
        v_phases = jnp.zeros_like(phases)
        m_weights = jnp.zeros_like(weights)
        v_weights = jnp.zeros_like(weights)

        init = (phases, data_set, labels, weights, m_phases, v_phases, m_weights, v_weights)

        # FIX: The original train.py had a NameError.
        # The source file is corrected to call optimiser.adam_step.
        final_carry, loss_history = train.train(init)

        self.assertEqual(len(final_carry), 8)
        self.assertEqual(loss_history.shape, (globals.num_steps, 2))

run_test(TestTrain, 'test_train_run')

In [None]:
show_code(train.train)