Skip to content

Commit

Permalink
fix a type hint and add tests for iterable (#12309)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-imamichi committed Apr 30, 2024
1 parent 958cc9b commit 95476b7
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 1 deletion.
2 changes: 1 addition & 1 deletion qiskit/primitives/backend_sampler_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _validate_pubs(self, pubs: list[SamplerPub]):
UserWarning,
)

def _run(self, pubs: Iterable[SamplerPub]) -> PrimitiveResult[PubResult]:
def _run(self, pubs: list[SamplerPub]) -> PrimitiveResult[PubResult]:
pub_dict = defaultdict(list)
# consolidate pubs with the same number of shots
for i, pub in enumerate(pubs):
Expand Down
12 changes: 12 additions & 0 deletions test/python/primitives/test_backend_estimator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,18 @@ def test_job_size_limit_backend_v1(self):
estimator.run([(qc, op, param_list)] * k).result()
self.assertEqual(run_mock.call_count, 10)

def test_iter_pub(self):
"""test for an iterable of pubs"""
backend = BasicSimulator()
circuit = self.ansatz.assign_parameters([0, 1, 1, 2, 3, 5])
pm = generate_preset_pass_manager(optimization_level=0, backend=backend)
circuit = pm.run(circuit)
estimator = BackendEstimatorV2(backend=backend, options=self._options)
observable = self.observable.apply_layout(circuit.layout)
result = estimator.run(iter([(circuit, observable), (circuit, observable)])).result()
np.testing.assert_allclose(result[0].data.evs, [-1.284366511861733], rtol=self._rtol)
np.testing.assert_allclose(result[1].data.evs, [-1.284366511861733], rtol=self._rtol)


if __name__ == "__main__":
unittest.main()
17 changes: 17 additions & 0 deletions test/python/primitives/test_backend_sampler_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,23 @@ def test_job_size_limit_backend_v1(self):
self._assert_allclose(result[0].data.meas, np.array({0: self._shots}))
self._assert_allclose(result[1].data.meas, np.array({1: self._shots}))

def test_iter_pub(self):
"""Test of an iterable of pubs"""
backend = BasicSimulator()
qc = QuantumCircuit(1)
qc.measure_all()
qc2 = QuantumCircuit(1)
qc2.x(0)
qc2.measure_all()
sampler = BackendSamplerV2(backend=backend)
result = sampler.run(iter([qc, qc2]), shots=self._shots).result()
self.assertIsInstance(result, PrimitiveResult)
self.assertEqual(len(result), 2)
self.assertIsInstance(result[0], PubResult)
self.assertIsInstance(result[1], PubResult)
self._assert_allclose(result[0].data.meas, np.array({0: self._shots}))
self._assert_allclose(result[1].data.meas, np.array({1: self._shots}))


if __name__ == "__main__":
unittest.main()
9 changes: 9 additions & 0 deletions test/python/primitives/test_statevector_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,15 @@ def test_precision_seed(self):
result = job.result()
np.testing.assert_allclose(result[0].data.evs, [1.5555572817900956])

def test_iter_pub(self):
"""test for an iterable of pubs"""
estimator = StatevectorEstimator()
circuit = self.ansatz.assign_parameters([0, 1, 1, 2, 3, 5])
observable = self.observable.apply_layout(circuit.layout)
result = estimator.run(iter([(circuit, observable), (circuit, observable)])).result()
np.testing.assert_allclose(result[0].data.evs, [-1.284366511861733])
np.testing.assert_allclose(result[1].data.evs, [-1.284366511861733])


if __name__ == "__main__":
unittest.main()
16 changes: 16 additions & 0 deletions test/python/primitives/test_statevector_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,22 @@ def test_no_cregs(self):
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0].data), 0)

def test_iter_pub(self):
"""Test of an iterable of pubs"""
qc = QuantumCircuit(1)
qc.measure_all()
qc2 = QuantumCircuit(1)
qc2.x(0)
qc2.measure_all()
sampler = StatevectorSampler()
result = sampler.run(iter([qc, qc2]), shots=self._shots).result()
self.assertIsInstance(result, PrimitiveResult)
self.assertEqual(len(result), 2)
self.assertIsInstance(result[0], PubResult)
self.assertIsInstance(result[1], PubResult)
self._assert_allclose(result[0].data.meas, np.array({0: self._shots}))
self._assert_allclose(result[1].data.meas, np.array({1: self._shots}))


if __name__ == "__main__":
unittest.main()

0 comments on commit 95476b7

Please sign in to comment.