Skip to content

Commit

Permalink
Fix Pulse channel index validation (#10476)
Browse files Browse the repository at this point in the history
* Correct channel index validation

* Add tests and release note.

(cherry picked from commit 802a735)
  • Loading branch information
TsafrirA authored and mergify[bot] committed Jul 24, 2023
1 parent 7d964cb commit ccfd8cd
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
2 changes: 1 addition & 1 deletion qiskit/pulse/channels.py
Expand Up @@ -121,7 +121,7 @@ def _validate_index(self, index: Any) -> None:
if index.is_integer():
index = int(index)

if not isinstance(index, (int, np.integer)) and index < 0:
if not isinstance(index, (int, np.integer)) or index < 0:
raise PulseError("Channel index must be a nonnegative integer")

@property
Expand Down
@@ -0,0 +1,5 @@
---
fixes:
- |
Fixed a bug in :class:`.pulse.Channel` where index validation was done incorrectly and only
raised an error when the index was both non-integer and negative, instead of either.
36 changes: 36 additions & 0 deletions test/python/pulse/test_channels.py
Expand Up @@ -25,6 +25,7 @@
PulseChannel,
RegisterSlot,
SnapshotChannel,
PulseError,
)
from qiskit.test import QiskitTestCase

Expand Down Expand Up @@ -88,6 +89,13 @@ def test_default(self):
self.assertEqual(memory_slot.name, "m123")
self.assertTrue(isinstance(memory_slot, ClassicalIOChannel))

def test_validation(self):
"""Test channel validation"""
with self.assertRaises(PulseError):
MemorySlot(0.5)
with self.assertRaises(PulseError):
MemorySlot(-1)


class TestRegisterSlot(QiskitTestCase):
"""RegisterSlot tests."""
Expand All @@ -100,6 +108,13 @@ def test_default(self):
self.assertEqual(register_slot.name, "c123")
self.assertTrue(isinstance(register_slot, ClassicalIOChannel))

def test_validation(self):
"""Test channel validation"""
with self.assertRaises(PulseError):
RegisterSlot(0.5)
with self.assertRaises(PulseError):
RegisterSlot(-1)


class TestSnapshotChannel(QiskitTestCase):
"""SnapshotChannel tests."""
Expand All @@ -123,6 +138,13 @@ def test_default(self):
self.assertEqual(drive_channel.index, 123)
self.assertEqual(drive_channel.name, "d123")

def test_validation(self):
"""Test channel validation"""
with self.assertRaises(PulseError):
DriveChannel(0.5)
with self.assertRaises(PulseError):
DriveChannel(-1)


class TestControlChannel(QiskitTestCase):
"""ControlChannel tests."""
Expand All @@ -134,6 +156,13 @@ def test_default(self):
self.assertEqual(control_channel.index, 123)
self.assertEqual(control_channel.name, "u123")

def test_validation(self):
"""Test channel validation"""
with self.assertRaises(PulseError):
ControlChannel(0.5)
with self.assertRaises(PulseError):
ControlChannel(-1)


class TestMeasureChannel(QiskitTestCase):
"""MeasureChannel tests."""
Expand All @@ -145,6 +174,13 @@ def test_default(self):
self.assertEqual(measure_channel.index, 123)
self.assertEqual(measure_channel.name, "m123")

def test_validation(self):
"""Test channel validation"""
with self.assertRaises(PulseError):
MeasureChannel(0.5)
with self.assertRaises(PulseError):
MeasureChannel(-1)


if __name__ == "__main__":
unittest.main()

0 comments on commit ccfd8cd

Please sign in to comment.