Skip to content

Commit

Permalink
Merge ebd523a into f4021f8
Browse files Browse the repository at this point in the history
  • Loading branch information
trappitsch committed Jan 20, 2022
2 parents f4021f8 + ebd523a commit a11b142
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 103 deletions.
115 changes: 86 additions & 29 deletions instruments/abstract_instruments/comm/usb_communicator.py
Expand Up @@ -10,7 +10,12 @@

import io

import usb.core
import usb.util

from instruments.abstract_instruments.comm import AbstractCommunicator
from instruments.units import ureg as u
from instruments.util_fns import assume_units

# CLASSES #####################################################################

Expand All @@ -28,12 +33,47 @@ class USBCommunicator(io.IOBase, AbstractCommunicator):
and it is suggested that it is not relied on.
"""

def __init__(self, conn):
def __init__(self, dev):
super(USBCommunicator, self).__init__(self)
# TODO: Check to make sure this is a USB connection
self._conn = conn
if not isinstance(dev, usb.core.Device):
raise TypeError("USBCommunicator must wrap a usb.core.Device object.")

# follow (mostly) pyusb tutorial

# set the active configuration. With no arguments, the first
# configuration will be the active one
dev.set_configuration()

# get an endpoint instance
cfg = dev.get_active_configuration()
intf = cfg[(0, 0)]

# initialize in and out endpoints
ep_out = usb.util.find_descriptor(
intf,
# match the first OUT endpoint
custom_match= \
lambda e: \
usb.util.endpoint_direction(e.bEndpointAddress) == \
usb.util.ENDPOINT_OUT
)

ep_in = usb.util.find_descriptor(
intf,
# match the first OUT endpoint
custom_match= \
lambda e: \
usb.util.endpoint_direction(e.bEndpointAddress) == \
usb.util.ENDPOINT_IN
)

if (ep_in or ep_out) is None:
raise IOError("USB endpoint not found.")

self._dev = dev
self._ep_in = ep_in
self._ep_out = ep_out
self._terminator = "\n"

# PROPERTIES #

@property
Expand All @@ -57,60 +97,77 @@ def terminator(self):
def terminator(self, newval):
if not isinstance(newval, str):
raise TypeError("Terminator for USBCommunicator must be specified "
"as a single character string.")
if len(newval) > 1:
raise ValueError("Terminator for USBCommunicator must only be 1 "
"character long.")
"as a character string.")
self._terminator = newval

@property
def timeout(self):
raise NotImplementedError
"""
Gets/sets the communication timeout of the USB communicator.
:type: `~pint.Quantity`
:units: As specified or assumed to be of units ``seconds``
"""
return assume_units(self._dev.default_timeout, u.ms).to(u.second)

@timeout.setter
def timeout(self, newval):
raise NotImplementedError
newval = assume_units(newval, u.second).to(u.ms).magnitude
self._dev.default_timeout = newval

# FILE-LIKE METHODS #

def close(self):
"""
Shutdown and close the USB connection
"""
try:
self._conn.shutdown()
finally:
self._conn.close()
self._dev.reset()
usb.util.dispose_resources(self._dev)

def read_raw(self, size=-1):
raise NotImplementedError
def read_raw(self, size=1000):
"""Read raw string back from device and return.
def read(self, size=-1, encoding="utf-8"):
raise NotImplementedError
String returned is most likely shorter than the size requested. Will
terminate by itself.
Read size of -1 will be transformed into 1000 bytes.
def write_raw(self, msg):
:param size: Size to read in bytes
:type size: int
"""
Write bytes to the raw usb connection object.
if size == -1:
size = 1000
term = self._terminator.encode("utf-8")
read_val = bytes(self._ep_in.read(size))
if term not in read_val:
raise IOError(f"Did not find the terminator in the returned string. "
f"Total size of {size} might not be enough.")
return read_val.rstrip(term)

def write_raw(self, msg):
"""Write bytes to the raw usb connection object.
:param bytes msg: Bytes to be sent to the instrument over the usb
connection.
"""
self._conn.write(msg)
self._ep_out.write(msg)

def seek(self, offset): # pylint: disable=unused-argument,no-self-use
return NotImplemented
raise NotImplementedError

def tell(self): # pylint: disable=no-self-use
return NotImplemented
raise NotImplementedError

def flush_input(self):
"""
Instruct the communicator to flush the input buffer, discarding the
entirety of its contents.
Not implemented for usb communicator
entirety of its contents. Read 1000 bytes at a time and be done
once a timeout error comes back (which means the buffer is empty).
"""
raise NotImplementedError
while True:
try:
self._ep_in.read(1000, 10) # read until any exception
except: # pylint: disable=bare-except
break

# METHODS #

Expand All @@ -124,9 +181,9 @@ def _sendcmd(self, msg):
:param str msg: The command message to send to the instrument
"""
msg += self._terminator
self._conn.sendall(msg)
self.write(msg)

def _query(self, msg, size=-1):
def _query(self, msg, size=1000):
"""
This is the implementation of ``query`` for communicating with
raw usb connections. This function is in turn wrapped by the concrete
Expand Down
29 changes: 2 additions & 27 deletions instruments/abstract_instruments/instrument.py
Expand Up @@ -16,9 +16,7 @@
from serial import SerialException
from serial.tools.list_ports import comports
import pyvisa
import usb
import usb.core
import usb.util

from instruments.abstract_instruments.comm import (
SocketCommunicator, USBCommunicator, VisaCommunicator, FileCommunicator,
Expand Down Expand Up @@ -674,39 +672,16 @@ def open_usb(cls, vid, pid):
method.
:param str vid: Vendor ID of the USB device to open.
:param int pid: Product ID of the USB device to open.
:param str pid: Product ID of the USB device to open.
:rtype: `Instrument`
:return: Object representing the connected instrument.
"""
# pylint: disable=no-member
dev = usb.core.find(idVendor=vid, idProduct=pid)
if dev is None:
raise IOError("No such device found.")

# Use the default configuration offered by the device.
dev.set_configuration()

# Copied from the tutorial at:
# http://pyusb.sourceforge.net/docs/1.0/tutorial.html
cfg = dev.get_active_configuration()
interface_number = cfg[(0, 0)].bInterfaceNumber
alternate_setting = usb.control.get_interface(dev, interface_number)
intf = usb.util.find_descriptor(
cfg, bInterfaceNumber=interface_number,
bAlternateSetting=alternate_setting
)

ep = usb.util.find_descriptor(
intf,
custom_match=lambda e:
usb.util.endpoint_direction(e.bEndpointAddress) ==
usb.util.ENDPOINT_OUT
)
if ep is None:
raise IOError("USB descriptor not found.")

return cls(USBCommunicator(ep))
return cls(USBCommunicator(dev))

@classmethod
def open_file(cls, filename):
Expand Down
54 changes: 10 additions & 44 deletions instruments/tests/test_base_instrument.py
Expand Up @@ -346,46 +346,23 @@ def test_instrument_open_vxi11(mock_vxi11_comm):
mock_vxi11_comm.assert_called_with("string", 1, key1="value")


@mock.patch("instruments.abstract_instruments.instrument.USBCommunicator")
@mock.patch("instruments.abstract_instruments.instrument.usb")
def test_instrument_open_usb(mock_usb):
def test_instrument_open_usb(mock_usb, mock_usb_comm):
"""Open USB device."""
# mock some behavior
mock_usb.core.find.return_value.__class__ = usb.core.Device # dev
mock_usb.core.find().get_active_configuration.return_value.__class__ = (
usb.core.Configuration
)
mock_usb.core.find.return_value.__class__ = usb.core.Device
mock_usb_comm.return_value.__class__ = USBCommunicator

# shortcuts for asserting calls
dev = mock_usb.core.find()
cfg = dev.get_active_configuration()
interface_number = cfg[(0, 0)].bInterfaceNumber
alternate_setting = mock_usb.control.get_interface(
dev, cfg[(0, 0)].bInterfaceNumber
)
# fake instrument
vid = "0x1000"
pid = "0x1000"
dev = mock_usb.core.find(idVendor=vid, idProduct=pid)

# call instrument
inst = ik.Instrument.open_usb("0x1000", 0x1000)

# assert calls according to manual
dev.set_configuration.assert_called() # check default configuration
dev.get_active_configuration.assert_called() # get active configuration
mock_usb.control.get_interface.assert_called_with(dev, interface_number)
mock_usb.util.find_descriptor.assert_any_call(
cfg,
bInterfaceNumber=interface_number,
bAlternateSetting=alternate_setting
)
# check the first argument of the `ep =` call
assert mock_usb.util.find_descriptor.call_args_list[1][0][0] == (
mock_usb.util.find_descriptor(
cfg,
bInterfaceNumber=interface_number,
bAlternateSetting=alternate_setting
)
)
inst = ik.Instrument.open_usb(vid, pid)

# assert instrument of correct class
assert isinstance(inst._file, USBCommunicator)
mock_usb_comm.assert_called_with(dev)


@mock.patch("instruments.abstract_instruments.instrument.usb")
Expand All @@ -398,17 +375,6 @@ def test_instrument_open_usb_no_device(mock_usb):
assert err_msg == "No such device found."


@mock.patch("instruments.abstract_instruments.instrument.usb")
def test_instrument_open_usb_ep_none(mock_usb):
"""Raise IOError if endpoint matching returns None."""
mock_usb.util.find_descriptor.return_value = None

with pytest.raises(IOError) as err:
_ = ik.Instrument.open_usb(0x1000, 0x1000)
err_msg = err.value.args[0]
assert err_msg == "USB descriptor not found."


@mock.patch("instruments.abstract_instruments.instrument.USBTMCCommunicator")
def test_instrument_open_usbtmc(mock_usbtmc_comm):
mock_usbtmc_comm.return_value.__class__ = USBTMCCommunicator
Expand Down

0 comments on commit a11b142

Please sign in to comment.