116 changes: 70 additions & 46 deletions qubes/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@

_NO_DEFAULT = object()


class Features(dict):
'''Manager of the features.
"""Manager of the features.
Features can have three distinct values: no value (not present in mapping,
which is closest thing to :py:obj:`None`), empty string (which is
Expand All @@ -38,7 +39,7 @@ class Features(dict):
This class inherits from dict, but has most of the methods that manipulate
the item disarmed (they raise NotImplementedError). The ones that are left
fire appropriate events on the qube that owns an instance of this class.
'''
"""

#
# Those are the methods that affect contents. Either disarm them or make
Expand All @@ -51,13 +52,13 @@ def __init__(self, subject, other=None, **kwargs):
self.update(other, **kwargs)

def __delitem__(self, key):
self.subject.fire_event('domain-feature-pre-delete:' + key, feature=key)
self.subject.fire_event("domain-feature-pre-delete:" + key, feature=key)
super().__delitem__(key)
self.subject.fire_event('domain-feature-delete:' + key, feature=key)
self.subject.fire_event("domain-feature-delete:" + key, feature=key)

def __setitem__(self, key, value):
if value is None or isinstance(value, bool):
value = '1' if value else ''
value = "1" if value else ""
else:
value = str(value)
try:
Expand All @@ -66,46 +67,58 @@ def __setitem__(self, key, value):
except KeyError:
has_oldvalue = False
if has_oldvalue:
self.subject.fire_event('domain-feature-pre-set:' + key,
self.subject.fire_event(
"domain-feature-pre-set:" + key,
pre_event=True,
feature=key, value=value, oldvalue=oldvalue)
feature=key,
value=value,
oldvalue=oldvalue,
)
else:
self.subject.fire_event('domain-feature-pre-set:' + key,
self.subject.fire_event(
"domain-feature-pre-set:" + key,
pre_event=True,
feature=key, value=value)
feature=key,
value=value,
)
super().__setitem__(key, value)
if has_oldvalue:
self.subject.fire_event('domain-feature-set:' + key, feature=key,
value=value, oldvalue=oldvalue)
self.subject.fire_event(
"domain-feature-set:" + key,
feature=key,
value=value,
oldvalue=oldvalue,
)
else:
self.subject.fire_event('domain-feature-set:' + key, feature=key,
value=value)
self.subject.fire_event(
"domain-feature-set:" + key, feature=key, value=value
)

def clear(self):
for key in tuple(self):
del self[key]

def pop(self, _key, _default=None):
'''Not implemented
"""Not implemented
:raises: NotImplementedError
'''
"""
raise NotImplementedError()

def popitem(self):
'''Not implemented
"""Not implemented
:raises: NotImplementedError
'''
"""
raise NotImplementedError()

def setdefault(self, _key, _default=None):
'''Not implemented
"""Not implemented
:raises: NotImplementedError
'''
"""
raise NotImplementedError()

def update(self, other=None, **kwargs):
if other is not None:
if hasattr(other, 'keys'):
if hasattr(other, "keys"):
for key in other:
self[key] = other[key]
else:
Expand All @@ -119,9 +132,16 @@ def update(self, other=None, **kwargs):
# end of overriding
#

def _recursive_check(self, attr=None, *, feature, default,
check_adminvm=False, check_app=False):
'''Recursive search for a feature.
def _recursive_check(
self,
attr=None,
*,
feature,
default,
check_adminvm=False,
check_app=False
):
"""Recursive search for a feature.
Traverse domains along one attribute, like
:py:attr:`qubes.vm.qubesvm.QubesVM.netvm` or
Expand All @@ -136,13 +156,15 @@ def _recursive_check(self, attr=None, *, feature, default,
If `check_app` is true, also the app feature is checked. This is not
implemented, as app does not have features yet.
'''
"""
if check_app:
raise NotImplementedError('app does not have features yet')
raise NotImplementedError("app does not have features yet")

assert isinstance(self.subject, _vm.BaseVM), (
'recursive checks do not work for {}'.format(
type(self.subject).__name__))
assert isinstance(
self.subject, _vm.BaseVM
), "recursive checks do not work for {}".format(
type(self.subject).__name__
)

subject = self.subject
while subject is not None:
Expand All @@ -154,7 +176,7 @@ def _recursive_check(self, attr=None, *, feature, default,
subject = getattr(subject, attr, None)

if check_adminvm:
adminvm = self.subject.app.domains['dom0']
adminvm = self.subject.app.domains["dom0"]
if adminvm not in (None, self.subject):
try:
return adminvm.features[feature]
Expand All @@ -169,26 +191,28 @@ def _recursive_check(self, attr=None, *, feature, default,
raise KeyError(feature)

def check_with_template(self, feature, default=_NO_DEFAULT):
'''Check for the specified feature; if this VM does not have it,
it checks with its template.'''
return self._recursive_check('template',
feature=feature, default=default)
"""Check for the specified feature; if this VM does not have it,
it checks with its template."""
return self._recursive_check(
"template", feature=feature, default=default
)

def check_with_netvm(self, feature, default=_NO_DEFAULT):
'''Check for the specified feature; if this VM does not have it,
it checks with its netvm.'''
return self._recursive_check('netvm',
feature=feature, default=default)
"""Check for the specified feature; if this VM does not have it,
it checks with its netvm."""
return self._recursive_check("netvm", feature=feature, default=default)

def check_with_adminvm(self, feature, default=_NO_DEFAULT):
'''Check for the specified feature; if this VM does not have it,
it checks with the AdminVM.'''
return self._recursive_check(check_adminvm=True,
feature=feature, default=default)
"""Check for the specified feature; if this VM does not have it,
it checks with the AdminVM."""
return self._recursive_check(
check_adminvm=True, feature=feature, default=default
)

def check_with_template_and_adminvm(self, feature, default=_NO_DEFAULT):
'''Check for the specified feature; if this VM does not have it,
it checks with its template. If the template does not have it, it
checks with the AdminVM.'''
return self._recursive_check('template', check_adminvm=True,
feature=feature, default=default)
"""Check for the specified feature; if this VM does not have it,
it checks with its template. If the template does not have it, it
checks with the AdminVM."""
return self._recursive_check(
"template", check_adminvm=True, feature=feature, default=default
)
395 changes: 213 additions & 182 deletions qubes/firewall.py

Large diffs are not rendered by default.

30 changes: 16 additions & 14 deletions qubes/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
# License along with this library; if not, see <https://www.gnu.org/licenses/>.
#

'''Qubes logging routines
"""Qubes logging routines
See also: :py:attr:`qubes.vm.qubesvm.QubesVM.log`
'''
"""

import logging
import sys
Expand All @@ -33,26 +33,26 @@ def __init__(self, *args, debug=False, **kwargs):
self.debug = debug

def formatMessage(self, record):
fmt = ''
fmt = ""
if self.debug:
fmt += '[%(processName)s %(module)s.%(funcName)s:%(lineno)d] '
if self.debug or record.name.startswith('vm.'):
fmt += '%(name)s: '
fmt += '%(message)s'
fmt += "[%(processName)s %(module)s.%(funcName)s:%(lineno)d] "
if self.debug or record.name.startswith("vm."):
fmt += "%(name)s: "
fmt += "%(message)s"

return fmt % record.__dict__


def enable():
'''Enable global logging
"""Enable global logging
Use :py:mod:`logging` module from standard library to log messages.
>>> import qubes.log
>>> qubes.log.enable() # doctest: +SKIP
>>> import logging
>>> logging.warning('Foobar') # doctest: +SKIP
'''
"""

if logging.root.handlers:
return
Expand All @@ -63,11 +63,12 @@ def enable():

logging.root.setLevel(logging.INFO)


def enable_debug():
'''Enable debug logging
"""Enable debug logging
Enable more messages and additional info to message format.
'''
"""

enable()

Expand All @@ -76,11 +77,12 @@ def enable_debug():

logging.root.setLevel(logging.DEBUG)


def get_vm_logger(vmname):
'''Initialise logging for particular VM name
"""Initialise logging for particular VM name
:param str vmname: VM's name
:rtype: :py:class:`logging.Logger`
'''
"""

return logging.getLogger('vm.' + vmname)
return logging.getLogger("vm." + vmname)
128 changes: 89 additions & 39 deletions qubes/qmemman/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
MIN_PREFMEM = 200 * 1024 * 1024
DOM0_MEM_BOOST = 350 * 1024 * 1024

log = logging.getLogger('qmemman.daemon.algo')
log = logging.getLogger("qmemman.daemon.algo")


# untrusted meminfo size is taken from xenstore key, thus its size is limited
Expand All @@ -53,11 +53,19 @@ def refresh_meminfo_for_domain(domain, untrusted_xenstore_key):
def prefmem(domain):
# dom0 is special, as it must have large cache, for vbds. Thus, give it
# a special boost
if domain.id == '0':
return int(min(domain.mem_used * CACHE_FACTOR + DOM0_MEM_BOOST,
domain.memory_maximum))
return int(max(min(domain.mem_used * CACHE_FACTOR, domain.memory_maximum),
MIN_PREFMEM))
if domain.id == "0":
return int(
min(
domain.mem_used * CACHE_FACTOR + DOM0_MEM_BOOST,
domain.memory_maximum,
)
)
return int(
max(
min(domain.mem_used * CACHE_FACTOR, domain.memory_maximum),
MIN_PREFMEM,
)
)


def memory_needed(domain):
Expand All @@ -72,8 +80,11 @@ def memory_needed(domain):
# to "xm memset" equivalent in order to obtain "memsize" of memory
# return empty list when the request cannot be satisfied
def balloon(memsize, domain_dictionary):
log.debug('balloon(memsize={!r}, domain_dictionary={!r})'.format(
memsize, domain_dictionary))
log.debug(
"balloon(memsize={!r}, domain_dictionary={!r})".format(
memsize, domain_dictionary
)
)
REQ_SAFETY_NET_FACTOR = 1.05
donors = list()
request = list()
Expand All @@ -85,20 +96,23 @@ def balloon(memsize, domain_dictionary):
continue
need = memory_needed(domain_dictionary[i])
if need < 0:
log.info('balloon: dom {} has actual memory {}'.format(i,
domain_dictionary[i].memory_actual))
log.info(
"balloon: dom {} has actual memory {}".format(
i, domain_dictionary[i].memory_actual
)
)
donors.append((i, -need))
available -= need

log.info('req={} avail={} donors={!r}'.format(memsize, available, donors))
log.info("req={} avail={} donors={!r}".format(memsize, available, donors))

if available < memsize:
return []
scale = 1.0 * memsize / available
for donors_iter in donors:
dom_id, mem = donors_iter
memborrowed = mem * scale * REQ_SAFETY_NET_FACTOR
log.info('borrow {} from {}'.format(memborrowed, dom_id))
log.info("borrow {} from {}".format(memborrowed, dom_id))
memtarget = int(domain_dictionary[dom_id].memory_actual - memborrowed)
request.append((dom_id, memtarget))
return request
Expand All @@ -111,11 +125,15 @@ def balloon(memsize, domain_dictionary):

# redistribute positive "total_available_memory" of memory between domains,
# proportionally to prefmem
def balance_when_enough_memory(domain_dictionary,
xen_free_memory, total_mem_pref, total_available_memory):
log.info('balance_when_enough_memory(xen_free_memory={!r}, '
'total_mem_pref={!r}, total_available_memory={!r})'.format(
xen_free_memory, total_mem_pref, total_available_memory))
def balance_when_enough_memory(
domain_dictionary, xen_free_memory, total_mem_pref, total_available_memory
):
log.info(
"balance_when_enough_memory(xen_free_memory={!r}, "
"total_mem_pref={!r}, total_available_memory={!r})".format(
xen_free_memory, total_mem_pref, total_available_memory
)
)

target_memory = {}
# memory not assigned because of static max
Expand All @@ -128,8 +146,9 @@ def balance_when_enough_memory(domain_dictionary,
continue
# distribute total_available_memory proportionally to mempref
scale = 1.0 * prefmem(domain_dictionary[i]) / total_mem_pref
target_nonint = prefmem(
domain_dictionary[i]) + scale * total_available_memory
target_nonint = (
prefmem(domain_dictionary[i]) + scale * total_available_memory
)
# prevent rounding errors
target = int(0.999 * target_nonint)
# do not try to give more memory than static max
Expand All @@ -142,8 +161,11 @@ def balance_when_enough_memory(domain_dictionary,
target_memory[i] = target
# distribute left memory across all acceptors
while left_memory > 0 and acceptors_count > 0:
log.info('left_memory={} acceptors_count={}'.format(
left_memory, acceptors_count))
log.info(
"left_memory={} acceptors_count={}".format(
left_memory, acceptors_count
)
)

new_left_memory = 0
new_acceptors_count = acceptors_count
Expand All @@ -152,8 +174,11 @@ def balance_when_enough_memory(domain_dictionary,
if target < domain_dictionary[i].memory_maximum:
memory_bonus = int(0.999 * (left_memory / acceptors_count))
if target + memory_bonus >= domain_dictionary[i].memory_maximum:
new_left_memory += target + memory_bonus - \
domain_dictionary[i].memory_maximum
new_left_memory += (
target
+ memory_bonus
- domain_dictionary[i].memory_maximum
)
target = domain_dictionary[i].memory_maximum
new_acceptors_count -= 1
else:
Expand All @@ -180,11 +205,19 @@ def balance_when_enough_memory(domain_dictionary,

# when not enough mem to make everyone be above prefmem, make donors be at
# prefmem, and redistribute anything left between acceptors
def balance_when_low_on_memory(domain_dictionary,
xen_free_memory, total_mem_pref_acceptors, donors, acceptors):
log.info('balance_when_low_on_memory(xen_free_memory={!r}, '
'total_mem_pref_acceptors={!r}, donors={!r}, acceptors={!r})'.format(
xen_free_memory, total_mem_pref_acceptors, donors, acceptors))
def balance_when_low_on_memory(
domain_dictionary,
xen_free_memory,
total_mem_pref_acceptors,
donors,
acceptors,
):
log.info(
"balance_when_low_on_memory(xen_free_memory={!r}, "
"total_mem_pref_acceptors={!r}, donors={!r}, acceptors={!r})".format(
xen_free_memory, total_mem_pref_acceptors, donors, acceptors
)
)
donors_rq = list()
acceptors_rq = list()
squeezed_mem = xen_free_memory
Expand All @@ -201,11 +234,13 @@ def balance_when_low_on_memory(domain_dictionary,
return donors_rq
for i in acceptors:
scale = 1.0 * prefmem(domain_dictionary[i]) / total_mem_pref_acceptors
target_nonint = \
target_nonint = (
domain_dictionary[i].memory_actual + scale * squeezed_mem
)
# do not try to give more memory than static max
target = \
min(int(0.999 * target_nonint), domain_dictionary[i].memory_maximum)
target = min(
int(0.999 * target_nonint), domain_dictionary[i].memory_maximum
)
acceptors_rq.append((i, target))
# print 'balance(low): xen_free_memory=', xen_free_memory, 'requests:',
# donors_rq + acceptors_rq
Expand All @@ -217,8 +252,11 @@ def balance_when_low_on_memory(domain_dictionary,
# return the list of (domain, memory_target) pairs to be passed to
# "xm memset" equivalent
def balance(xen_free_memory, domain_dictionary):
log.debug('balance(xen_free_memory={!r}, domain_dictionary={!r})'.format(
xen_free_memory, domain_dictionary))
log.debug(
"balance(xen_free_memory={!r}, domain_dictionary={!r})".format(
xen_free_memory, domain_dictionary
)
)

# sum of all memory requirements - in other words, the difference between
# memory required to be added to domains (acceptors) to make them be
Expand All @@ -245,8 +283,11 @@ def balance(xen_free_memory, domain_dictionary):
# print 'domain' , i, 'act/pref', \
# domain_dictionary[i].memory_actual, prefmem(domain_dictionary[i]), \
# 'need=', need
if need < 0 or domain_dictionary[i].memory_actual >= \
domain_dictionary[i].memory_maximum:
if (
need < 0
or domain_dictionary[i].memory_actual
>= domain_dictionary[i].memory_maximum
):
donors.append(i)
else:
acceptors.append(i)
Expand All @@ -256,8 +297,17 @@ def balance(xen_free_memory, domain_dictionary):

total_available_memory = xen_free_memory - total_memory_needed
if total_available_memory > 0:
return balance_when_enough_memory(domain_dictionary, xen_free_memory,
total_mem_pref, total_available_memory)
return balance_when_enough_memory(
domain_dictionary,
xen_free_memory,
total_mem_pref,
total_available_memory,
)
else:
return balance_when_low_on_memory(domain_dictionary, xen_free_memory,
total_mem_pref_acceptors, donors, acceptors)
return balance_when_low_on_memory(
domain_dictionary,
xen_free_memory,
total_mem_pref_acceptors,
donors,
acceptors,
)
5 changes: 3 additions & 2 deletions qubes/qmemman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import socket
import fcntl


class QMemmanClient:
def request_memory(self, amount):
self.sock = socket.socket(socket.AF_UNIX)
Expand All @@ -31,9 +32,9 @@ def request_memory(self, amount):
fcntl.fcntl(self.sock.fileno(), fcntl.F_SETFD, flags)

self.sock.connect("/var/run/qubes/qmemman.sock")
self.sock.send(str(int(amount)).encode('ascii')+b"\n")
self.sock.send(str(int(amount)).encode("ascii") + b"\n")
received = self.sock.recv(1024).strip()
if received == b'OK':
if received == b"OK":
return True
else:
return False
Expand Down
19 changes: 10 additions & 9 deletions qubes/qmemman/domainstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@
# You should have received a copy of the GNU General Public
# License along with this library; if not, see <https://www.gnu.org/licenses/>.


class DomainState:
def __init__(self, id):
self.memory_current = 0 # the current memory size
self.memory_actual = None # the current memory allocation (what VM
# is using or can use at any time)
self.memory_current = 0 # the current memory size
self.memory_actual = None # the current memory allocation (what VM
# is using or can use at any time)
self.memory_maximum = None # the maximum memory size
self.mem_used = None # used memory, computed based on meminfo
self.id = id # domain id
self.last_target = 0 # the last memset target
self.use_hotplug = False # use memory hotplug for mem-set
self.no_progress = False # no react to memset
self.mem_used = None # used memory, computed based on meminfo
self.id = id # domain id
self.last_target = 0 # the last memset target
self.use_hotplug = False # use memory hotplug for mem-set
self.no_progress = False # no react to memset
self.slow_memset_react = False # slow react to memset (after few
# tries still above target)
# tries still above target)

def __repr__(self):
return self.__dict__.__repr__()
260 changes: 170 additions & 90 deletions qubes/qmemman/systemstate.py

Large diffs are not rendered by default.

179 changes: 101 additions & 78 deletions qubes/rngdoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,194 +24,217 @@

import lxml.etree


class Element:
def __init__(self, schema, xml):
self.schema = schema
self.xml = xml


@property
def nsmap(self):
return self.schema.nsmap


@property
def name(self):
return self.xml.get('name')

return self.xml.get("name")

def get_description(self, xml=None, wrap=True):
if xml is None:
xml = self.xml

xml = xml.xpath('./doc:description', namespaces=self.nsmap)
xml = xml.xpath("./doc:description", namespaces=self.nsmap)
if not xml:
return ''
return ""
xml = xml[0]

if wrap:
return ''.join(self.schema.wrapper.fill(p) + '\n\n'
for p in textwrap.dedent(xml.text.strip('\n')).split('\n\n'))

return ' '.join(xml.text.strip().split())
return "".join(
self.schema.wrapper.fill(p) + "\n\n"
for p in textwrap.dedent(xml.text.strip("\n")).split("\n\n")
)

return " ".join(xml.text.strip().split())

def get_data_type(self, xml=None):
if xml is None:
xml = self.xml

value = xml.xpath('./rng:value', namespaces=self.nsmap)
value = xml.xpath("./rng:value", namespaces=self.nsmap)
if value:
value = '``{}``'.format(value[0].text.strip())
value = "``{}``".format(value[0].text.strip())
else:
metavar = xml.xpath('./doc:metavar', namespaces=self.nsmap)
metavar = xml.xpath("./doc:metavar", namespaces=self.nsmap)
if metavar:
value = '``{}``'.format(metavar[0].text.strip())
value = "``{}``".format(metavar[0].text.strip())
else:
value = ''
value = ""

xml = xml.xpath('./rng:data', namespaces=self.nsmap)
xml = xml.xpath("./rng:data", namespaces=self.nsmap)
if not xml:
return ('', value)
return ("", value)

xml = xml[0]
type_ = xml.get('type', '')
type_ = xml.get("type", "")

if not value:
pattern = xml.xpath('./rng:param[@name="pattern"]',
namespaces=self.nsmap)
pattern = xml.xpath(
'./rng:param[@name="pattern"]', namespaces=self.nsmap
)
if pattern:
value = '``{}``'.format(pattern[0].text.strip())
value = "``{}``".format(pattern[0].text.strip())

return type_, value


def get_attributes(self):
for xml in self.xml.xpath('''./rng:attribute |
for xml in self.xml.xpath(
"""./rng:attribute |
./rng:optional/rng:attribute |
./rng:choice/rng:attribute''', namespaces=self.nsmap):
required = 'yes' if xml.getparent() == self.xml else 'no'
./rng:choice/rng:attribute""",
namespaces=self.nsmap,
):
required = "yes" if xml.getparent() == self.xml else "no"
yield (xml, required)


def resolve_ref(self, ref):
refs = self.xml.xpath(
'//rng:define[name="{}"]/rng:element'.format(ref['name']))
'//rng:define[name="{}"]/rng:element'.format(ref["name"])
)
return refs[0] if refs else None


def get_child_elements(self):
for xml in self.xml.xpath('''./rng:element | ./rng:ref |
for xml in self.xml.xpath(
"""./rng:element | ./rng:ref |
./rng:optional/rng:element | ./rng:optional/rng:ref |
./rng:zeroOrMore/rng:element | ./rng:zeroOrMore/rng:ref |
./rng:oneOrMore/rng:element | ./rng:oneOrMore/rng:ref''',
namespaces=self.nsmap):
./rng:oneOrMore/rng:element | ./rng:oneOrMore/rng:ref""",
namespaces=self.nsmap,
):
parent = xml.getparent()
qname = lxml.etree.QName(parent)
if parent == self.xml:
number = '1'
elif qname.localname == 'optional':
number = '?'
elif qname.localname == 'zeroOrMore':
number = '\\*'
elif qname.localname == 'oneOrMore':
number = '\\+'
number = "1"
elif qname.localname == "optional":
number = "?"
elif qname.localname == "zeroOrMore":
number = "\\*"
elif qname.localname == "oneOrMore":
number = "\\+"
else:
print(parent.tag)
raise Exception(
f"Cannot choose number format for tag {parent.tag}")
f"Cannot choose number format for tag {parent.tag}"
)

if xml.tag == 'ref':
if xml.tag == "ref":
xml = self.resolve_ref(xml)
if xml is None:
continue

yield (self.schema.elements[xml.get('name')], number)

yield (self.schema.elements[xml.get("name")], number)

def write_rst(self, stream):
stream.write('.. _qubesxml-element-{}:\n\n'.format(self.name))
stream.write(make_rst_section('Element: **{}**'.format(self.name), '-'))
stream.write(".. _qubesxml-element-{}:\n\n".format(self.name))
stream.write(make_rst_section("Element: **{}**".format(self.name), "-"))
stream.write(self.get_description())

attrtable = []
for attr, required in self.get_attributes():
type_, value = self.get_data_type(attr)
attrtable.append((
attr.get('name'),
required,
type_,
value,
self.get_description(attr, wrap=False)))
attrtable.append(
(
attr.get("name"),
required,
type_,
value,
self.get_description(attr, wrap=False),
)
)

if attrtable:
stream.write(make_rst_section('Attributes', '^'))
write_rst_table(stream, attrtable,
('attribute', 'req.', 'type', 'value', 'description'))

childtable = [(':ref:`{0} <qubesxml-element-{0}>`'.format(
child.xml.get('name')), n)
for child, n in self.get_child_elements()]
stream.write(make_rst_section("Attributes", "^"))
write_rst_table(
stream,
attrtable,
("attribute", "req.", "type", "value", "description"),
)

childtable = [
(
":ref:`{0} <qubesxml-element-{0}>`".format(
child.xml.get("name")
),
n,
)
for child, n in self.get_child_elements()
]
if childtable:
stream.write(make_rst_section('Child elements', '^'))
write_rst_table(stream, childtable, ('element', 'number'))
stream.write(make_rst_section("Child elements", "^"))
write_rst_table(stream, childtable, ("element", "number"))


class Schema:
# pylint: disable=too-few-public-methods
nsmap = {
'rng': 'http://relaxng.org/ns/structure/1.0',
'q': 'http://qubes-os.org/qubes/3',
'doc': 'http://qubes-os.org/qubes-doc/1'}
"rng": "http://relaxng.org/ns/structure/1.0",
"q": "http://qubes-os.org/qubes/3",
"doc": "http://qubes-os.org/qubes-doc/1",
}

def __init__(self, xml):
self.xml = xml

self.wrapper = textwrap.TextWrapper(width=80,
break_long_words=False, break_on_hyphens=False)
self.wrapper = textwrap.TextWrapper(
width=80, break_long_words=False, break_on_hyphens=False
)

self.elements = {}
for node in self.xml.xpath('//rng:element', namespaces=self.nsmap):
for node in self.xml.xpath("//rng:element", namespaces=self.nsmap):
element = Element(self, node)
self.elements[element.name] = element


def make_rst_section(heading, char):
return '{}\n{}\n\n'.format(heading, char[0] * len(heading))
return "{}\n{}\n\n".format(heading, char[0] * len(heading))


def write_rst_table(stream, itr, heads):
stream.write('.. csv-table::\n')
stream.write(' :header: {}\n'.format(', '.join('"{}"'.format(c)
for c in heads)))
stream.write(' :widths: {}\n\n'.format(', '.join('1'
for c in heads)))
stream.write(".. csv-table::\n")
stream.write(
" :header: {}\n".format(", ".join('"{}"'.format(c) for c in heads))
)
stream.write(" :widths: {}\n\n".format(", ".join("1" for c in heads)))

for row in itr:
stream.write(' {}\n'.format(', '.join('"{}"'.format(i) for i in row)))
stream.write(" {}\n".format(", ".join('"{}"'.format(i) for i in row)))

stream.write('\n')
stream.write("\n")


def main(filename, example):
with open(filename, 'rb') as schema_f:
with open(filename, "rb") as schema_f:
schema = Schema(lxml.etree.parse(schema_f))

sys.stdout.write(make_rst_section('Qubes XML specification', '='))
sys.stdout.write('''
sys.stdout.write(make_rst_section("Qubes XML specification", "="))
sys.stdout.write(
"""
This is the documentation of qubes.xml autogenerated from RelaxNG source.
Quick example, worth thousands lines of specification:
.. literalinclude:: {}
:language: xml
'''[1:].format(example))
"""[
1:
].format(
example
)
)

for name in sorted(schema.elements):
schema.elements[name].write_rst(sys.stdout)


if __name__ == '__main__':
if __name__ == "__main__":
# pylint: disable=no-value-for-parameter
main(*sys.argv[1:])

Expand Down
762 changes: 424 additions & 338 deletions qubes/storage/__init__.py

Large diffs are not rendered by default.

299 changes: 192 additions & 107 deletions qubes/storage/callback.py

Large diffs are not rendered by default.

384 changes: 217 additions & 167 deletions qubes/storage/file.py

Large diffs are not rendered by default.

87 changes: 49 additions & 38 deletions qubes/storage/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# License along with this library; if not, see <https://www.gnu.org/licenses/>.
#

''' This module contains pool implementations for different OS kernels. '''
""" This module contains pool implementations for different OS kernels. """

import os

Expand All @@ -29,10 +29,10 @@


class LinuxModules(Volume):
''' A volume representing a ro linux kernel '''
"""A volume representing a ro linux kernel"""

def __init__(self, target_dir, kernel_version, **kwargs):
kwargs['vid'] = ''
kwargs["vid"] = ""
super().__init__(**kwargs)
self._kernel_version = kernel_version
self.target_dir = target_dir
Expand Down Expand Up @@ -62,21 +62,21 @@ def path(self):
kernels_dir = self.kernels_dir
if not kernels_dir:
return None
return os.path.join(kernels_dir, 'modules.img')
return os.path.join(kernels_dir, "modules.img")

@property
def vmlinuz(self):
kernels_dir = self.kernels_dir
if not kernels_dir:
return None
return os.path.join(kernels_dir, 'vmlinuz')
return os.path.join(kernels_dir, "vmlinuz")

@property
def initramfs(self):
kernels_dir = self.kernels_dir
if not kernels_dir:
return None
return os.path.join(kernels_dir, 'initramfs')
return os.path.join(kernels_dir, "initramfs")

@property
def revisions(self):
Expand All @@ -90,8 +90,8 @@ async def import_volume(self, src_volume):
# do nothing
return self
raise StoragePoolException(
'clone of LinuxModules volume from different'
' volume type is not supported'
"clone of LinuxModules volume from different"
" volume type is not supported"
)

async def create(self):
Expand All @@ -104,7 +104,8 @@ def ephemeral(self):
@ephemeral.setter
def ephemeral(self, value):
raise qubes.exc.QubesValueError(
'LinuxModules does not support setting ephemeral value')
"LinuxModules does not support setting ephemeral value"
)

async def remove(self):
pass
Expand All @@ -123,7 +124,8 @@ def revisions_to_keep(self):
def revisions_to_keep(self, value):
if value:
raise qubes.exc.QubesValueError(
'LinuxModules supports only revisions_to_keep=0')
"LinuxModules supports only revisions_to_keep=0"
)

@property
def rw(self):
Expand All @@ -133,7 +135,8 @@ def rw(self):
def rw(self, value):
if value:
raise qubes.exc.QubesValueError(
'LinuxModules supports only read-only volumes')
"LinuxModules supports only read-only volumes"
)

async def start(self):
return self
Expand All @@ -156,35 +159,38 @@ def block_device(self):


class LinuxKernel(Pool):
''' Provides linux kernels '''
driver = 'linux-kernel'
"""Provides linux kernels"""

driver = "linux-kernel"

def __init__(self, *, name, dir_path):
super().__init__(name=name, revisions_to_keep=0)
self.dir_path = dir_path

def init_volume(self, vm, volume_config):
assert not volume_config['rw']
assert not volume_config["rw"]

# migrate old config
if volume_config.get('snap_on_start', False) and not \
volume_config.get('source', None):
volume_config['snap_on_start'] = False
if volume_config.get("snap_on_start", False) and not volume_config.get(
"source", None
):
volume_config["snap_on_start"] = False

if volume_config.get('save_on_stop', False):
if volume_config.get("save_on_stop", False):
raise NotImplementedError(
'LinuxKernel pool does not support save_on_stop=True')
volume_config['pool'] = self
"LinuxKernel pool does not support save_on_stop=True"
)
volume_config["pool"] = self
volume = LinuxModules(self.dir_path, lambda: vm.kernel, **volume_config)

return volume

@property
def config(self):
return {
'name': self.name,
'dir_path': self.dir_path,
'driver': LinuxKernel.driver,
"name": self.name,
"dir_path": self.dir_path,
"driver": LinuxKernel.driver,
}

async def destroy(self):
Expand All @@ -204,28 +210,33 @@ def revisions_to_keep(self):
def revisions_to_keep(self, value):
if value:
raise qubes.exc.QubesValueError(
'LinuxKernel supports only revisions_to_keep=0')
"LinuxKernel supports only revisions_to_keep=0"
)

def included_in(self, app):
''' Check if there is pool containing /var/lib/qubes/vm-kernels '''
"""Check if there is pool containing /var/lib/qubes/vm-kernels"""
return qubes.storage.search_pool_containing_dir(
[pool for pool in app.pools.values() if pool is not self],
self.dir_path)
self.dir_path,
)

def list_volumes(self):
''' Return all known kernel volumes '''
return [LinuxModules(self.dir_path,
kernel_version,
pool=self,
name=kernel_version,
rw=False
)
for kernel_version in os.listdir(self.dir_path)]
"""Return all known kernel volumes"""
return [
LinuxModules(
self.dir_path,
kernel_version,
pool=self,
name=kernel_version,
rw=False,
)
for kernel_version in os.listdir(self.dir_path)
]


def _check_path(path):
''' Raise an :py:class:`qubes.storage.StoragePoolException` if ``path`` does
not exist.
'''
"""Raise an :py:class:`qubes.storage.StoragePoolException` if ``path`` does
not exist.
"""
if not os.path.exists(path):
raise StoragePoolException('Missing file: %s' % path)
raise StoragePoolException("Missing file: %s" % path)
619 changes: 369 additions & 250 deletions qubes/storage/lvm.py

Large diffs are not rendered by default.

280 changes: 169 additions & 111 deletions qubes/storage/reflink.py

Large diffs are not rendered by default.

120 changes: 75 additions & 45 deletions qubes/tarwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

BUF_SIZE = 409600


class TarSparseInfo(tarfile.TarInfo):
def __init__(self, name="", sparsemap=None):
super().__init__(name)
Expand All @@ -33,17 +34,19 @@ def __init__(self, name="", sparsemap=None):
self.sparsemap = sparsemap
self.sparsemap_buf = self.format_sparse_map()
# compact size
self.size = functools.reduce(lambda x, y: x+y[1], sparsemap,
0) + len(self.sparsemap_buf)
self.pax_headers['GNU.sparse.major'] = '1'
self.pax_headers['GNU.sparse.minor'] = '0'
self.pax_headers['GNU.sparse.name'] = name
self.pax_headers['GNU.sparse.realsize'] = str(self.realsize)
self.name = '{}/GNUSparseFile.{}/{}'.format(
os.path.dirname(name), os.getpid(), os.path.basename(name))
self.size = functools.reduce(
lambda x, y: x + y[1], sparsemap, 0
) + len(self.sparsemap_buf)
self.pax_headers["GNU.sparse.major"] = "1"
self.pax_headers["GNU.sparse.minor"] = "0"
self.pax_headers["GNU.sparse.name"] = name
self.pax_headers["GNU.sparse.realsize"] = str(self.realsize)
self.name = "{}/GNUSparseFile.{}/{}".format(
os.path.dirname(name), os.getpid(), os.path.basename(name)
)
else:
self.sparsemap = []
self.sparsemap_buf = b''
self.sparsemap_buf = b""

@property
def realsize(self):
Expand All @@ -52,24 +55,33 @@ def realsize(self):
return self.size

def format_sparse_map(self):
sparsemap_txt = (str(len(self.sparsemap)) + '\n' +
''.join('{}\n{}\n'.format(*entry) for entry in self.sparsemap))
sparsemap_txt = (
str(len(self.sparsemap))
+ "\n"
+ "".join("{}\n{}\n".format(*entry) for entry in self.sparsemap)
)
sparsemap_txt_len = len(sparsemap_txt)
if sparsemap_txt_len % tarfile.BLOCKSIZE:
padding = '\0' * (tarfile.BLOCKSIZE -
sparsemap_txt_len % tarfile.BLOCKSIZE)
padding = "\0" * (
tarfile.BLOCKSIZE - sparsemap_txt_len % tarfile.BLOCKSIZE
)
else:
padding = ''
padding = ""
return (sparsemap_txt + padding).encode()

def tobuf(self, format=tarfile.PAX_FORMAT, encoding=tarfile.ENCODING,
errors="strict"):
def tobuf(
self,
format=tarfile.PAX_FORMAT,
encoding=tarfile.ENCODING,
errors="strict",
):
# pylint: disable=redefined-builtin
header_buf = super().tobuf(format, encoding, errors)
return header_buf + self.sparsemap_buf


def get_sparse_map(input_file):
'''
"""
Return map of the file where actual data is present, ignoring zero-ed
blocks. Last entry of the map spans to the end of file, even if that part is
zero-size (when file ends with zeros).
Expand All @@ -78,7 +90,7 @@ def get_sparse_map(input_file):
:param input_file: io.File object
:return: iterable of (offset, size)
'''
"""
zero_block = bytearray(tarfile.BLOCKSIZE)
buf = bytearray(BUF_SIZE)
in_data_block = False
Expand All @@ -89,31 +101,33 @@ def get_sparse_map(input_file):
if not buf_len:
break
for offset in range(0, buf_len, tarfile.BLOCKSIZE):
if buf[offset:offset+tarfile.BLOCKSIZE] == zero_block:
if buf[offset : offset + tarfile.BLOCKSIZE] == zero_block:
if in_data_block:
in_data_block = False
yield (data_block_start,
buf_start_offset+offset-data_block_start)
yield (
data_block_start,
buf_start_offset + offset - data_block_start,
)
else:
if not in_data_block:
in_data_block = True
data_block_start = buf_start_offset+offset
data_block_start = buf_start_offset + offset
buf_start_offset += buf_len
if in_data_block:
yield (data_block_start, buf_start_offset-data_block_start)
yield (data_block_start, buf_start_offset - data_block_start)
else:
# always emit last slice to the input end - otherwise extracted file
# will be truncated
yield (buf_start_offset, 0)


def copy_sparse_data(input_stream, output_stream, sparse_map):
'''Copy data blocks from input to output according to sparse_map
"""Copy data blocks from input to output according to sparse_map
:param input_stream: io.IOBase input instance
:param output_stream: io.IOBase output instance
:param sparse_map: iterable of (offset, size)
'''
"""

buf = bytearray(BUF_SIZE)

Expand All @@ -130,38 +144,53 @@ def copy_sparse_data(input_stream, output_stream, sparse_map):
output_stream.write(buf_trailer)
left -= read
if not read:
raise EOFError('premature EOF')
raise EOFError("premature EOF")


def finalize(output):
'''Write EOF blocks'''
output.write(b'\0' * 512)
output.write(b'\0' * 512)
"""Write EOF blocks"""
output.write(b"\0" * 512)
output.write(b"\0" * 512)


def main(args=None):
parser = argparse.ArgumentParser()
parser.add_argument('--override-name', action='store', dest='override_name',
help='use this name in tar header')
parser.add_argument('--use-compress-program', default=None,
metavar='COMMAND', action='store', dest='use_compress_program',
help='Filter data through COMMAND.')
parser.add_argument('input_file',
help='input file name')
parser.add_argument('output_file', default='-', nargs='?',
help='output file name')
parser.add_argument(
"--override-name",
action="store",
dest="override_name",
help="use this name in tar header",
)
parser.add_argument(
"--use-compress-program",
default=None,
metavar="COMMAND",
action="store",
dest="use_compress_program",
help="Filter data through COMMAND.",
)
parser.add_argument("input_file", help="input file name")
parser.add_argument(
"output_file", default="-", nargs="?", help="output file name"
)
args = parser.parse_args(args)
with io.open(args.input_file, 'rb') as input_file:
with io.open(args.input_file, "rb") as input_file:
sparse_map = list(get_sparse_map(input_file))
header_name = args.input_file
if args.override_name:
header_name = args.override_name
tar_info = TarSparseInfo(header_name, sparse_map)
with io.open(('/dev/stdout' if args.output_file == '-'
else args.output_file),
'wb') as output:
with io.open(
("/dev/stdout" if args.output_file == "-" else args.output_file),
"wb",
) as output:
if args.use_compress_program:
# pylint: disable=consider-using-with
compress = subprocess.Popen([args.use_compress_program],
stdin=subprocess.PIPE, stdout=output)
compress = subprocess.Popen(
[args.use_compress_program],
stdin=subprocess.PIPE,
stdout=output,
)
output = compress.stdin
else:
compress = None
Expand All @@ -174,5 +203,6 @@ def main(args=None):
return compress.returncode
return 0

if __name__ == '__main__':

if __name__ == "__main__":
main()
966 changes: 607 additions & 359 deletions qubes/tests/__init__.py

Large diffs are not rendered by default.

105 changes: 62 additions & 43 deletions qubes/tests/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,72 +37,77 @@ def __init__(self, app, src, method, dest, arg, send_event=None):
self.send_event = send_event
try:
self.function = {
'mgmt.success': self.success,
'mgmt.success_none': self.success_none,
'mgmt.qubesexception': self.qubesexception,
'mgmt.exception': self.exception,
'mgmt.event': self.event,
"mgmt.success": self.success,
"mgmt.success_none": self.success_none,
"mgmt.qubesexception": self.qubesexception,
"mgmt.exception": self.exception,
"mgmt.event": self.event,
}[self.method.decode()]
except KeyError:
raise qubes.exc.ProtocolError('Invalid method')
raise qubes.exc.ProtocolError("Invalid method")

def execute(self, untrusted_payload):
self.task = asyncio.Task(self.function(
untrusted_payload=untrusted_payload))
self.task = asyncio.Task(
self.function(untrusted_payload=untrusted_payload)
)
return self.task

def cancel(self):
self.task.cancel()

async def success(self, untrusted_payload):
return 'src: {!r}, dest: {!r}, arg: {!r}, payload: {!r}'.format(
return "src: {!r}, dest: {!r}, arg: {!r}, payload: {!r}".format(
self.src, self.dest, self.arg, untrusted_payload
)

async def success_none(self, untrusted_payload):
pass

async def qubesexception(self, untrusted_payload):
raise qubes.exc.QubesException('qubes-exception')
raise qubes.exc.QubesException("qubes-exception")

async def exception(self, untrusted_payload):
raise Exception('exception')
raise Exception("exception")

async def event(self, untrusted_payload):
future = asyncio.get_event_loop().create_future()

class Subject:
name = 'subject'
name = "subject"

def __str__(self):
return 'subject'
return "subject"

self.send_event(Subject(), 'event', payload=untrusted_payload.decode())
self.send_event(Subject(), "event", payload=untrusted_payload.decode())
try:
# give some time to close the other end
await asyncio.sleep(0.1)
# should be canceled
self.send_event(Subject, 'event2',
payload=untrusted_payload.decode())
self.send_event(
Subject, "event2", payload=untrusted_payload.decode()
)
await future
except asyncio.CancelledError:
pass


class TC_00_QubesDaemonProtocol(qubes.tests.QubesTestCase):
def setUp(self):
super(TC_00_QubesDaemonProtocol, self).setUp()
self.app = unittest.mock.Mock()
self.app.log = self.log
self.sock_client, self.sock_server = socket.socketpair()
self.reader, self.writer = self.loop.run_until_complete(
asyncio.open_connection(sock=self.sock_client))
asyncio.open_connection(sock=self.sock_client)
)

connect_coro = self.loop.create_connection(
lambda: qubes.api.QubesDaemonProtocol(
TestMgmt, app=self.app),
sock=self.sock_server)
lambda: qubes.api.QubesDaemonProtocol(TestMgmt, app=self.app),
sock=self.sock_server,
)
self.transport, self.protocol = self.loop.run_until_complete(
connect_coro)
connect_coro
)

def tearDown(self):
self.writer.close()
Expand All @@ -114,68 +119,82 @@ def tearDown(self):
super(TC_00_QubesDaemonProtocol, self).tearDown()

def test_000_message_ok(self):
self.writer.write(b'mgmt.success+arg src name dest\0payload')
self.writer.write(b"mgmt.success+arg src name dest\0payload")
self.writer.write_eof()
with self.assertNotRaises(asyncio.TimeoutError):
response = self.loop.run_until_complete(
asyncio.wait_for(self.reader.read(), 1))
self.assertEqual(response,
b"0\0src: b'src', dest: b'dest', arg: b'arg', payload: b'payload'")
asyncio.wait_for(self.reader.read(), 1)
)
self.assertEqual(
response,
b"0\0src: b'src', dest: b'dest', arg: b'arg', payload: b'payload'",
)

def test_001_message_ok_in_parts(self):
self.writer.write(b'mgmt.success+arg')
self.writer.write(b"mgmt.success+arg")
self.loop.run_until_complete(self.writer.drain())
self.writer.write(b' dom0 name dom0\0payload')
self.writer.write(b" dom0 name dom0\0payload")
self.writer.write_eof()
with self.assertNotRaises(asyncio.TimeoutError):
response = self.loop.run_until_complete(
asyncio.wait_for(self.reader.read(), 1))
self.assertEqual(response,
b"0\0src: b'dom0', dest: b'dom0', arg: b'arg', payload: b'payload'")
asyncio.wait_for(self.reader.read(), 1)
)
self.assertEqual(
response,
b"0\0src: b'dom0', dest: b'dom0', arg: b'arg', payload: b'payload'",
)

def test_002_message_ok_empty(self):
self.writer.write(b'mgmt.success_none+arg dom0 name dom0\0payload')
self.writer.write(b"mgmt.success_none+arg dom0 name dom0\0payload")
self.writer.write_eof()
with self.assertNotRaises(asyncio.TimeoutError):
response = self.loop.run_until_complete(
asyncio.wait_for(self.reader.read(), 1))
asyncio.wait_for(self.reader.read(), 1)
)
self.assertEqual(response, b"0\0")

def test_003_exception_qubes(self):
self.writer.write(b'mgmt.qubesexception+arg dom0 name dom0\0payload')
self.writer.write(b"mgmt.qubesexception+arg dom0 name dom0\0payload")
self.writer.write_eof()
with self.assertNotRaises(asyncio.TimeoutError):
response = self.loop.run_until_complete(
asyncio.wait_for(self.reader.read(), 1))
asyncio.wait_for(self.reader.read(), 1)
)
self.assertEqual(response, b"2\0QubesException\0\0qubes-exception\0")

def test_004_exception_generic(self):
self.writer.write(b'mgmt.exception+arg dom0 name dom0\0payload')
self.writer.write(b"mgmt.exception+arg dom0 name dom0\0payload")
self.writer.write_eof()
with self.assertNotRaises(asyncio.TimeoutError):
response = self.loop.run_until_complete(
asyncio.wait_for(self.reader.read(), 1))
asyncio.wait_for(self.reader.read(), 1)
)
self.assertEqual(response, b"")

def test_005_event(self):
self.writer.write(b'mgmt.event+arg dom0 name dom0\0payload')
self.writer.write(b"mgmt.event+arg dom0 name dom0\0payload")
self.writer.write_eof()
with self.assertNotRaises(asyncio.TimeoutError):
response = self.loop.run_until_complete(
asyncio.wait_for(self.reader.readuntil(b'\0\0'), 1))
asyncio.wait_for(self.reader.readuntil(b"\0\0"), 1)
)
self.assertEqual(response, b"1\0subject\0event\0payload\0payload\0\0")
# this will trigger connection_lost, but only when next event is sent
self.sock_client.shutdown(socket.SHUT_RD)
# check if event-producing method is interrupted
with self.assertNotRaises(asyncio.TimeoutError):
self.loop.run_until_complete(
asyncio.wait_for(self.protocol.mgmt.task, 1))
asyncio.wait_for(self.protocol.mgmt.task, 1)
)

def test_006_target_adminvm(self):
self.writer.write(b'mgmt.success+arg src keyword adminvm\0payload')
self.writer.write(b"mgmt.success+arg src keyword adminvm\0payload")
self.writer.write_eof()
with self.assertNotRaises(asyncio.TimeoutError):
response = self.loop.run_until_complete(
asyncio.wait_for(self.reader.read(), 1))
self.assertEqual(response,
b"0\0src: b'src', dest: b'dom0', arg: b'arg', payload: b'payload'")
asyncio.wait_for(self.reader.read(), 1)
)
self.assertEqual(
response,
b"0\0src: b'src', dest: b'dom0', arg: b'arg', payload: b'payload'",
)
4,512 changes: 2,665 additions & 1,847 deletions qubes/tests/api_admin.py

Large diffs are not rendered by default.

197 changes: 109 additions & 88 deletions qubes/tests/api_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,32 @@
import json
import uuid


def mock_coro(f):
async def coro_f(*args, **kwargs):
return f(*args, **kwargs)

return coro_f


TEST_UUID = uuid.UUID("50c7dad4-5f1e-4586-9f6a-bf10a86ba6f0")


class TC_00_API_Misc(qubes.tests.QubesTestCase):
def setUp(self):
super().setUp()
self.app = mock.NonCallableMock()
self.dom0 = mock.NonCallableMock(spec=qubes.vm.adminvm.AdminVM)
self.dom0.name = 'dom0'
self.dom0.name = "dom0"
self.domains = {
'dom0': self.dom0,
"dom0": self.dom0,
}
self.app.domains = mock.MagicMock(**{
'__iter__.side_effect': lambda: iter(self.domains.values()),
'__getitem__.side_effect': self.domains.get,
})
self.app.domains = mock.MagicMock(
**{
"__iter__.side_effect": lambda: iter(self.domains.values()),
"__getitem__.side_effect": self.domains.get,
}
)

def tearDown(self):
self.domains.clear()
Expand All @@ -58,137 +63,153 @@ def create_mockvm(self, features=None):
vm = mock.Mock()
vm.features.check_with_template.side_effect = features.get
vm.run_service.return_value.wait = mock_coro(
vm.run_service.return_value.wait)
vm.run_service.return_value.wait
)
vm.run_service = mock_coro(vm.run_service)
vm.suspend = mock_coro(vm.suspend)
vm.resume = mock_coro(vm.resume)
return vm

def call_mgmt_func(self, method, arg=b'', payload=b''):
mgmt_obj = qubes.api.internal.QubesInternalAPI(self.app,
b'dom0', method, b'dom0', arg)
def call_mgmt_func(self, method, arg=b"", payload=b""):
mgmt_obj = qubes.api.internal.QubesInternalAPI(
self.app, b"dom0", method, b"dom0", arg
)

loop = asyncio.get_event_loop()
response = loop.run_until_complete(
mgmt_obj.execute(untrusted_payload=payload))
mgmt_obj.execute(untrusted_payload=payload)
)
return response

def test_000_suspend_pre(self):
running_vm = self.create_mockvm(features={'qrexec': True})
running_vm = self.create_mockvm(features={"qrexec": True})
running_vm.is_running.return_value = True

not_running_vm = self.create_mockvm(features={'qrexec': True})
not_running_vm = self.create_mockvm(features={"qrexec": True})
not_running_vm.is_running.return_value = False

no_qrexec_vm = self.create_mockvm()
no_qrexec_vm.is_running.return_value = True

self.domains.update({
'running': running_vm,
'not-running': not_running_vm,
'no-qrexec': no_qrexec_vm,
})
self.domains.update(
{
"running": running_vm,
"not-running": not_running_vm,
"no-qrexec": no_qrexec_vm,
}
)

ret = self.call_mgmt_func(b'internal.SuspendPre')
ret = self.call_mgmt_func(b"internal.SuspendPre")
self.assertIsNone(ret)
self.assertFalse(self.dom0.called)

self.assertNotIn(('run_service', ('qubes.SuspendPreAll',), mock.ANY),
not_running_vm.mock_calls)
self.assertNotIn(('suspend', (), {}),
not_running_vm.mock_calls)
self.assertNotIn(
("run_service", ("qubes.SuspendPreAll",), mock.ANY),
not_running_vm.mock_calls,
)
self.assertNotIn(("suspend", (), {}), not_running_vm.mock_calls)

self.assertIn(('run_service', ('qubes.SuspendPreAll',), mock.ANY),
running_vm.mock_calls)
self.assertIn(('suspend', (), {}),
running_vm.mock_calls)
self.assertIn(
("run_service", ("qubes.SuspendPreAll",), mock.ANY),
running_vm.mock_calls,
)
self.assertIn(("suspend", (), {}), running_vm.mock_calls)

self.assertNotIn(('run_service', ('qubes.SuspendPreAll',), mock.ANY),
no_qrexec_vm.mock_calls)
self.assertIn(('suspend', (), {}),
no_qrexec_vm.mock_calls)
self.assertNotIn(
("run_service", ("qubes.SuspendPreAll",), mock.ANY),
no_qrexec_vm.mock_calls,
)
self.assertIn(("suspend", (), {}), no_qrexec_vm.mock_calls)

def test_001_suspend_post(self):
running_vm = self.create_mockvm(features={'qrexec': True})
running_vm = self.create_mockvm(features={"qrexec": True})
running_vm.is_running.return_value = True
running_vm.get_power_state.return_value = 'Suspended'
running_vm.get_power_state.return_value = "Suspended"

not_running_vm = self.create_mockvm(features={'qrexec': True})
not_running_vm = self.create_mockvm(features={"qrexec": True})
not_running_vm.is_running.return_value = False
not_running_vm.get_power_state.return_value = 'Halted'
not_running_vm.get_power_state.return_value = "Halted"

no_qrexec_vm = self.create_mockvm()
no_qrexec_vm.is_running.return_value = True
no_qrexec_vm.get_power_state.return_value = 'Suspended'
no_qrexec_vm.get_power_state.return_value = "Suspended"

self.domains.update({
'running': running_vm,
'not-running': not_running_vm,
'no-qrexec': no_qrexec_vm,
})
self.domains.update(
{
"running": running_vm,
"not-running": not_running_vm,
"no-qrexec": no_qrexec_vm,
}
)

ret = self.call_mgmt_func(b'internal.SuspendPost')
ret = self.call_mgmt_func(b"internal.SuspendPost")
self.assertIsNone(ret)
self.assertFalse(self.dom0.called)

self.assertNotIn(('run_service', ('qubes.SuspendPostAll',), mock.ANY),
not_running_vm.mock_calls)
self.assertNotIn(('resume', (), {}),
not_running_vm.mock_calls)
self.assertNotIn(
("run_service", ("qubes.SuspendPostAll",), mock.ANY),
not_running_vm.mock_calls,
)
self.assertNotIn(("resume", (), {}), not_running_vm.mock_calls)

self.assertIn(('run_service', ('qubes.SuspendPostAll',), mock.ANY),
running_vm.mock_calls)
self.assertIn(('resume', (), {}),
running_vm.mock_calls)
self.assertIn(
("run_service", ("qubes.SuspendPostAll",), mock.ANY),
running_vm.mock_calls,
)
self.assertIn(("resume", (), {}), running_vm.mock_calls)

self.assertNotIn(('run_service', ('qubes.SuspendPostAll',), mock.ANY),
no_qrexec_vm.mock_calls)
self.assertIn(('resume', (), {}),
no_qrexec_vm.mock_calls)
self.assertNotIn(
("run_service", ("qubes.SuspendPostAll",), mock.ANY),
no_qrexec_vm.mock_calls,
)
self.assertIn(("resume", (), {}), no_qrexec_vm.mock_calls)

def test_010_get_system_info(self):
self.dom0.name = 'dom0'
self.dom0.tags = ['tag1', 'tag2']
self.dom0.name = "dom0"
self.dom0.tags = ["tag1", "tag2"]
self.dom0.default_dispvm = None
self.dom0.template_for_dispvms = False
self.dom0.label.icon = 'icon-dom0'
self.dom0.get_power_state.return_value = 'Running'
self.dom0.label.icon = "icon-dom0"
self.dom0.get_power_state.return_value = "Running"
self.dom0.uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
del self.dom0.guivm

vm = mock.NonCallableMock(spec=qubes.vm.qubesvm.QubesVM)
vm.name = 'vm'
vm.tags = ['tag3', 'tag4']
vm.name = "vm"
vm.tags = ["tag3", "tag4"]
vm.default_dispvm = vm
vm.template_for_dispvms = True
vm.label.icon = 'icon-vm'
vm.label.icon = "icon-vm"
vm.guivm = vm
vm.get_power_state.return_value = 'Halted'
vm.get_power_state.return_value = "Halted"
vm.uuid = TEST_UUID
self.domains['vm'] = vm

ret = json.loads(self.call_mgmt_func(b'internal.GetSystemInfo'))
self.assertEqual(ret, {
'domains': {
'dom0': {
'tags': ['tag1', 'tag2'],
'type': 'AdminVM',
'default_dispvm': None,
'template_for_dispvms': False,
'icon': 'icon-dom0',
'guivm': None,
'power_state': 'Running',
'uuid': "00000000-0000-0000-0000-000000000000",
},
'vm': {
'tags': ['tag3', 'tag4'],
'type': 'QubesVM',
'default_dispvm': 'vm',
'template_for_dispvms': True,
'icon': 'icon-vm',
'guivm': 'vm',
'power_state': 'Halted',
"uuid": str(TEST_UUID),
self.domains["vm"] = vm

ret = json.loads(self.call_mgmt_func(b"internal.GetSystemInfo"))
self.assertEqual(
ret,
{
"domains": {
"dom0": {
"tags": ["tag1", "tag2"],
"type": "AdminVM",
"default_dispvm": None,
"template_for_dispvms": False,
"icon": "icon-dom0",
"guivm": None,
"power_state": "Running",
"uuid": "00000000-0000-0000-0000-000000000000",
},
"vm": {
"tags": ["tag3", "tag4"],
"type": "QubesVM",
"default_dispvm": "vm",
"template_for_dispvms": True,
"icon": "icon-vm",
"guivm": "vm",
"power_state": "Halted",
"uuid": str(TEST_UUID),
},
}
}
})
},
)
439 changes: 258 additions & 181 deletions qubes/tests/api_misc.py

Large diffs are not rendered by default.

841 changes: 512 additions & 329 deletions qubes/tests/app.py

Large diffs are not rendered by default.

475 changes: 272 additions & 203 deletions qubes/tests/devices.py

Large diffs are not rendered by default.

1,213 changes: 688 additions & 525 deletions qubes/tests/devices_block.py

Large diffs are not rendered by default.

62 changes: 37 additions & 25 deletions qubes/tests/devices_pci.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


class TestVM(object):
def __init__(self, running=True, name='dom0', qid=0):
def __init__(self, running=True, name="dom0", qid=0):
self.name = name
self.qid = qid
self.is_running = lambda: running
Expand Down Expand Up @@ -122,36 +122,48 @@ def setUp(self):
super().setUp()
self.ext = qubes.ext.pci.PCIDeviceExtension()

@mock.patch('builtins.open', new=mock_file_open)
@mock.patch("builtins.open", new=mock_file_open)
def test_000_unsupported_device(self):
vm = TestVM()
vm.app.configure_mock(**{
'vmm.offline_mode': False,
'vmm.libvirt_conn.nodeDeviceLookupByName.return_value':
mock.Mock(**{"XMLDesc.return_value":
PCI_XML.format(*["0000"] * 3)
}),
'vmm.libvirt_conn.listAllDevices.return_value':
[mock.Mock(**{"XMLDesc.return_value":
PCI_XML.format(*["0000"] * 3),
"listCaps.return_value": ["pci"]
}),
mock.Mock(**{"XMLDesc.return_value":
PCI_XML.format(*["1000"] * 3),
"listCaps.return_value": ["pci"]
}),
]
})
devices = list(self.ext.on_device_list_pci(vm, 'device-list:pci'))
vm.app.configure_mock(
**{
"vmm.offline_mode": False,
"vmm.libvirt_conn.nodeDeviceLookupByName.return_value": mock.Mock(
**{"XMLDesc.return_value": PCI_XML.format(*["0000"] * 3)}
),
"vmm.libvirt_conn.listAllDevices.return_value": [
mock.Mock(
**{
"XMLDesc.return_value": PCI_XML.format(
*["0000"] * 3
),
"listCaps.return_value": ["pci"],
}
),
mock.Mock(
**{
"XMLDesc.return_value": PCI_XML.format(
*["1000"] * 3
),
"listCaps.return_value": ["pci"],
}
),
],
}
)
devices = list(self.ext.on_device_list_pci(vm, "device-list:pci"))
self.assertEqual(len(devices), 1)
self.assertEqual(devices[0].port_id, "00_14.0")
self.assertEqual(devices[0].vendor, "Intel Corporation")
self.assertEqual(devices[0].product,
"9 Series Chipset Family USB xHCI Controller")
self.assertEqual(
devices[0].product, "9 Series Chipset Family USB xHCI Controller"
)
self.assertEqual(devices[0].interfaces, [DeviceInterface("p0c0330")])
self.assertEqual(devices[0].parent_device, None)
self.assertEqual(devices[0].libvirt_name, "pci_0000_00_14_0")
self.assertEqual(devices[0].description,
"USB controller: Intel Corporation 9 Series "
"Chipset Family USB xHCI Controller")
self.assertEqual(
devices[0].description,
"USB controller: Intel Corporation 9 Series "
"Chipset Family USB xHCI Controller",
)
self.assertEqual(devices[0].device_id, "0x8086:0x8cb1::p0c0330")
110 changes: 54 additions & 56 deletions qubes/tests/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,62 +22,63 @@
import qubes.events
import qubes.tests


class TC_00_Emitter(qubes.tests.QubesTestCase):
def test_000_add_handler(self):
# need something mutable
testevent_fired = [False]

def on_testevent(subject, event):
# pylint: disable=unused-argument
if event == 'testevent':
if event == "testevent":
testevent_fired[0] = True

emitter = qubes.events.Emitter()
emitter.add_handler('testevent', on_testevent)
emitter.add_handler("testevent", on_testevent)
emitter.events_enabled = True
emitter.fire_event('testevent')
emitter.fire_event("testevent")
self.assertTrue(testevent_fired[0])


def test_001_decorator(self):
class TestEmitter(qubes.events.Emitter):
def __init__(self):
# pylint: disable=bad-super-call
super(TestEmitter, self).__init__()
self.testevent_fired = False

@qubes.events.handler('testevent')
@qubes.events.handler("testevent")
def on_testevent(self, event):
if event == 'testevent':
if event == "testevent":
self.testevent_fired = True

emitter = TestEmitter()
emitter.events_enabled = True
emitter.fire_event('testevent')
emitter.fire_event("testevent")
self.assertTrue(emitter.testevent_fired)

def test_002_fire_for_effect(self):
class TestEmitter(qubes.events.Emitter):
@qubes.events.handler('testevent')
@qubes.events.handler("testevent")
def on_testevent_1(self, event):
pass

@qubes.events.handler('testevent')
@qubes.events.handler("testevent")
def on_testevent_2(self, event):
yield 'testvalue1'
yield 'testvalue2'
yield "testvalue1"
yield "testvalue2"

@qubes.events.handler('testevent')
@qubes.events.handler("testevent")
def on_testevent_3(self, event):
return ('testvalue3', 'testvalue4')
return ("testvalue3", "testvalue4")

emitter = TestEmitter()
emitter.events_enabled = True

effect = emitter.fire_event('testevent')
effect = emitter.fire_event("testevent")

self.assertCountEqual(effect,
('testvalue1', 'testvalue2', 'testvalue3', 'testvalue4'))
self.assertCountEqual(
effect, ("testvalue1", "testvalue2", "testvalue3", "testvalue4")
)

def test_004_catch_all(self):
# need something mutable
Expand All @@ -92,76 +93,73 @@ def on_foo(subject, event, *args, **kwargs):
testevent_fired[0] += 1

emitter = qubes.events.Emitter()
emitter.add_handler('*', on_all)
emitter.add_handler('foo', on_foo)
emitter.add_handler("*", on_all)
emitter.add_handler("foo", on_foo)
emitter.events_enabled = True
emitter.fire_event('testevent')
emitter.fire_event("testevent")
self.assertEqual(testevent_fired[0], 1)
emitter.fire_event('foo')
emitter.fire_event("foo")
# now catch-all and foo should be executed
self.assertEqual(testevent_fired[0], 3)
emitter.fire_event('bar')
emitter.fire_event("bar")
self.assertEqual(testevent_fired[0], 4)

def test_005_instance_handlers(self):
class TestEmitter(qubes.events.Emitter):
@qubes.events.handler('testevent')
@qubes.events.handler("testevent")
def on_testevent_1(self, event):
yield 'testevent_1'
yield "testevent_1"

def on_testevent_2(subject, event):
yield 'testevent_2'
yield "testevent_2"

emitter = TestEmitter()
emitter.add_handler('testevent', on_testevent_2)
emitter.add_handler("testevent", on_testevent_2)
emitter.events_enabled = True

emitter2 = TestEmitter()
emitter2.events_enabled = True

with self.subTest('fire_event'):
effect = emitter.fire_event('testevent')
effect2 = emitter2.fire_event('testevent')
self.assertEqual(list(effect),
['testevent_1', 'testevent_2'])
self.assertEqual(list(effect2),
['testevent_1'])

with self.subTest('fire_event_pre'):
effect = emitter.fire_event('testevent', pre_event=True)
effect2 = emitter2.fire_event('testevent', pre_event=True)
self.assertEqual(list(effect),
['testevent_2', 'testevent_1'])
self.assertEqual(list(effect2),
['testevent_1'])
with self.subTest("fire_event"):
effect = emitter.fire_event("testevent")
effect2 = emitter2.fire_event("testevent")
self.assertEqual(list(effect), ["testevent_1", "testevent_2"])
self.assertEqual(list(effect2), ["testevent_1"])

with self.subTest("fire_event_pre"):
effect = emitter.fire_event("testevent", pre_event=True)
effect2 = emitter2.fire_event("testevent", pre_event=True)
self.assertEqual(list(effect), ["testevent_2", "testevent_1"])
self.assertEqual(list(effect2), ["testevent_1"])

def test_005_fire_for_effect_async(self):
class TestEmitter(qubes.events.Emitter):
@qubes.events.handler('testevent')
@qubes.events.handler("testevent")
async def on_testevent_1(self, event):
pass

@qubes.events.handler('testevent')
@qubes.events.handler("testevent")
async def on_testevent_2(self, event):
await asyncio.sleep(0.01)
return ['testvalue1']
return ["testvalue1"]

@qubes.events.handler('testevent')
@qubes.events.handler("testevent")
async def on_testevent_3(self, event):
return ('testvalue2', 'testvalue3')
return ("testvalue2", "testvalue3")

@qubes.events.handler('testevent')
@qubes.events.handler("testevent")
def on_testevent_4(self, event):
return ('testvalue4',)
return ("testvalue4",)

loop = asyncio.get_event_loop()
emitter = TestEmitter()
emitter.events_enabled = True

effect = loop.run_until_complete(emitter.fire_event_async('testevent'))
effect = loop.run_until_complete(emitter.fire_event_async("testevent"))

self.assertCountEqual(effect,
('testvalue1', 'testvalue2', 'testvalue3', 'testvalue4'))
self.assertCountEqual(
effect, ("testvalue1", "testvalue2", "testvalue3", "testvalue4")
)

def test_006_wildcard(self):
# need something mutable
Expand All @@ -176,15 +174,15 @@ def on_foo(subject, event, *args, **kwargs):
testevent_fired[0] += 1

emitter = qubes.events.Emitter()
emitter.add_handler('foo:*', on_foo)
emitter.add_handler('foo:bar', on_foobar)
emitter.add_handler("foo:*", on_foo)
emitter.add_handler("foo:bar", on_foobar)
emitter.events_enabled = True
emitter.fire_event('foo:testevent')
emitter.fire_event("foo:testevent")
self.assertEqual(testevent_fired[0], 1)
emitter.fire_event('foo:bar')
emitter.fire_event("foo:bar")
# now foo:bar and foo:* should be executed
self.assertEqual(testevent_fired[0], 3)
emitter.fire_event('foo:')
emitter.fire_event("foo:")
self.assertEqual(testevent_fired[0], 4)
emitter.fire_event('testevent')
emitter.fire_event("testevent")
self.assertEqual(testevent_fired[0], 4)
1,248 changes: 775 additions & 473 deletions qubes/tests/ext.py

Large diffs are not rendered by default.

189 changes: 128 additions & 61 deletions qubes/tests/extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import qubes.tests
import qubes.vm.appvm


class ProcessWrapper(object):
def __init__(self, proc, loop=None):
self._proc = proc
Expand All @@ -37,20 +38,22 @@ def __getattr__(self, item):
return getattr(self._proc, item)

def __setattr__(self, key, value):
if key.startswith('_'):
if key.startswith("_"):
return super(ProcessWrapper, self).__setattr__(key, value)
return setattr(self._proc, key, value)

def communicate(self, input=None):
if self._proc.stdin is not None and input is None:
input = b''
input = b""
return self._loop.run_until_complete(self._proc.communicate(input))

def wait(self):
return self._loop.run_until_complete(self._proc.wait())


class VMWrapper(object):
'''Wrap VM object to provide stable API for basic operations'''
"""Wrap VM object to provide stable API for basic operations"""

def __init__(self, vm, loop=None):
self._vm = vm
self._loop = loop or asyncio.get_event_loop()
Expand All @@ -59,7 +62,7 @@ def __getattr__(self, item):
return getattr(self._vm, item)

def __setattr__(self, key, value):
if key.startswith('_'):
if key.startswith("_"):
return super(VMWrapper, self).__setattr__(key, value)
return setattr(self._vm, key, value)

Expand All @@ -74,56 +77,87 @@ def __hash__(self):

def start(self, start_guid=True):
return self._loop.run_until_complete(
self._vm.start(start_guid=start_guid))
self._vm.start(start_guid=start_guid)
)

def shutdown(self):
return self._loop.run_until_complete(self._vm.shutdown())

def run(self, command, wait=False, user=None, passio_popen=False,
passio_stderr=False, gui=False, **kwargs):
def run(
self,
command,
wait=False,
user=None,
passio_popen=False,
passio_stderr=False,
gui=False,
**kwargs,
):
if gui:
try:
self._loop.run_until_complete(
self._vm.run_service_for_stdio('qubes.WaitForSession',
user=user))
self._vm.run_service_for_stdio(
"qubes.WaitForSession", user=user
)
)
except subprocess.CalledProcessError as err:
return err.returncode
if wait:
try:
self._loop.run_until_complete(
self._vm.run_for_stdio(command, user=user))
self._vm.run_for_stdio(command, user=user)
)
except subprocess.CalledProcessError as err:
return err.returncode
return 0
elif passio_popen:
p = self._loop.run_until_complete(self._vm.run(command, user=user,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE if passio_stderr else None))
p = self._loop.run_until_complete(
self._vm.run(
command,
user=user,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE if passio_stderr else None,
)
)
return ProcessWrapper(p, self._loop)
else:
asyncio.ensure_future(self._vm.run_for_stdio(command, user=user),
loop=self._loop)

def run_service(self, service, wait=True, input=None, user=None,
passio_popen=False,
passio_stderr=False, **kwargs):
asyncio.ensure_future(
self._vm.run_for_stdio(command, user=user), loop=self._loop
)

def run_service(
self,
service,
wait=True,
input=None,
user=None,
passio_popen=False,
passio_stderr=False,
**kwargs,
):
if wait:
try:
if isinstance(input, str):
input = input.encode()
self._loop.run_until_complete(
self._vm.run_service_for_stdio(service,
input=input, user=user))
self._vm.run_service_for_stdio(
service, input=input, user=user
)
)
except subprocess.CalledProcessError as err:
return err.returncode
return 0
elif passio_popen:
p = self._loop.run_until_complete(self._vm.run_service(service,
user=user,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE if passio_stderr else None))
p = self._loop.run_until_complete(
self._vm.run_service(
service,
user=user,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE if passio_stderr else None,
)
)
return ProcessWrapper(p, self._loop)


Expand All @@ -137,9 +171,13 @@ def setUp(self):
self.init_default_template(self.template)
if self.template is not None:
# also use this template for DispVMs
dispvm_base = self.app.add_new_vm('AppVM',
name=self.make_vm_name('dvm'),
template=self.template, label='red', template_for_dispvms=True)
dispvm_base = self.app.add_new_vm(
"AppVM",
name=self.make_vm_name("dvm"),
template=self.template,
label="red",
template_for_dispvms=True,
)
self.loop.run_until_complete(dispvm_base.create_on_disk())
self.app.default_dispvm = dispvm_base

Expand All @@ -160,18 +198,23 @@ def create_vms(self, names):
else:
template = self.app.default_template
for vmname in names:
vm = self.app.add_new_vm(qubes.vm.appvm.AppVM,
name=self.make_vm_name(vmname),
template=template,
label='red')
vm = self.app.add_new_vm(
qubes.vm.appvm.AppVM,
name=self.make_vm_name(vmname),
template=template,
label="red",
)
self.loop.run_until_complete(vm.create_on_disk())
self.app.save()

# get objects after reload
vms = []
for vmname in names:
vms.append(VMWrapper(self.app.domains[self.make_vm_name(vmname)],
loop=self.loop))
vms.append(
VMWrapper(
self.app.domains[self.make_vm_name(vmname)], loop=self.loop
)
)
return vms

def enable_network(self):
Expand All @@ -180,8 +223,9 @@ def enable_network(self):
"""
self.init_networking()

def qrexec_policy(self, service, source, destination, allow=True,
target=None):
def qrexec_policy(
self, service, source, destination, allow=True, target=None
):
"""
Allow qrexec calls for duration of the test
:param service: service name
Expand All @@ -196,62 +240,85 @@ def qrexec_policy(self, service, source, destination, allow=True,
# to the same file
# abort if policy exists before the test starts
if not self.test_policy_created:
open_mode = 'x'
open_mode = "x"
else:
open_mode = 'a'
with open('/etc/qubes/policy.d/10-test.policy', open_mode) as policy:
rule = f"{service} * {source} {destination} " \
f"{'allow' if allow else 'deny'}" \
f"{' target=' + target if target else ''}\n"
open_mode = "a"
with open("/etc/qubes/policy.d/10-test.policy", open_mode) as policy:
rule = (
f"{service} * {source} {destination} "
f"{'allow' if allow else 'deny'}"
f"{' target=' + target if target else ''}\n"
)
policy.write(rule)
if not self.test_policy_created:
self.test_policy_created = True
self.addCleanup(os.unlink, '/etc/qubes/policy.d/10-test.policy')
self.addCleanup(os.unlink, "/etc/qubes/policy.d/10-test.policy")


def load_tests(loader, tests, pattern):
include_list = None
if 'QUBES_TEST_EXTRA_INCLUDE' in os.environ:
include_list = os.environ['QUBES_TEST_EXTRA_INCLUDE'].split()
if "QUBES_TEST_EXTRA_INCLUDE" in os.environ:
include_list = os.environ["QUBES_TEST_EXTRA_INCLUDE"].split()
exclude_list = []
if 'QUBES_TEST_EXTRA_EXCLUDE' in os.environ:
exclude_list = os.environ['QUBES_TEST_EXTRA_EXCLUDE'].split()
if "QUBES_TEST_EXTRA_EXCLUDE" in os.environ:
exclude_list = os.environ["QUBES_TEST_EXTRA_EXCLUDE"].split()

for entry in importlib.metadata.entry_points(group='qubes.tests.extra'):
for entry in importlib.metadata.entry_points(group="qubes.tests.extra"):
if include_list is not None and entry.name not in include_list:
continue
if entry.name in exclude_list:
continue
try:
for test_case in entry.load()():
tests.addTests(loader.loadTestsFromNames([
'{}.{}'.format(test_case.__module__, test_case.__name__)]))
tests.addTests(
loader.loadTestsFromNames(
[
"{}.{}".format(
test_case.__module__, test_case.__name__
)
]
)
)
except Exception as err: # pylint: disable=broad-except

def runTest(self, err=err):
raise err
ExtraLoadFailure = type('ExtraLoadFailure',

ExtraLoadFailure = type(
"ExtraLoadFailure",
(qubes.tests.QubesTestCase,),
{entry.name: runTest})
{entry.name: runTest},
)
tests.addTest(ExtraLoadFailure(entry.name))

for entry in importlib.metadata.entry_points(
group='qubes.tests.extra.for_template'):
group="qubes.tests.extra.for_template"
):
if include_list is not None and entry.name not in include_list:
continue
if entry.name in exclude_list:
continue
try:
for test_case in entry.load()():
tests.addTests(loader.loadTestsFromNames(
qubes.tests.create_testcases_for_templates(
test_case.__name__, test_case,
module=sys.modules[test_case.__module__])))
tests.addTests(
loader.loadTestsFromNames(
qubes.tests.create_testcases_for_templates(
test_case.__name__,
test_case,
module=sys.modules[test_case.__module__],
)
)
)
except Exception as err: # pylint: disable=broad-except

def runTest(self, err=err):
raise err
ExtraForTemplateLoadFailure = type('ExtraForTemplateLoadFailure',

ExtraForTemplateLoadFailure = type(
"ExtraForTemplateLoadFailure",
(qubes.tests.QubesTestCase,),
{entry.name: runTest})
{entry.name: runTest},
)
tests.addTest(ExtraForTemplateLoadFailure(entry.name))

return tests
Loading