Skip to content

Commit

Permalink
Transform many package related facts to use a nested function
Browse files Browse the repository at this point in the history
Each fact that is deduced from package rules, and start with
a bare package atom, is transformed into a "facts" atom containing
a nested function.

For instance we transformed

  version_declared(Package, ...) -> facts(Package, version_declared(...))

This allows us to clearly mark facts that represent a rule on the package,
and will be of help later when we'll have to distinguish the cases where
the atom "Package" is being used referred to package rules and not to a
node in the DAG.
  • Loading branch information
alalazo committed Aug 7, 2023
1 parent 8cd9497 commit b72e7aa
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 112 deletions.
83 changes: 54 additions & 29 deletions lib/spack/spack/solver/asp.py
Expand Up @@ -302,6 +302,8 @@ def argify(arg):
return clingo.String(str(arg))
elif isinstance(arg, int):
return clingo.Number(arg)
elif isinstance(arg, AspFunction):
return clingo.Function(arg.name, [argify(x) for x in arg.args], positive=positive)
else:
return clingo.String(str(arg))

Expand Down Expand Up @@ -918,16 +920,20 @@ def key_fn(version):
)

for weight, declared_version in enumerate(most_to_least_preferred):
# TODO: self.package_fact(pkg.name).version_declared(declared_version, weight=weight)
self.gen.fact(
fn.version_declared(
pkg.name, declared_version.version, weight, str(declared_version.origin)
fn.facts(
pkg.name,
fn.version_declared(
declared_version.version, weight, str(declared_version.origin)
),
)
)

# Declare deprecated versions for this package, if any
deprecated = self.deprecated_versions[pkg.name]
for v in sorted(deprecated):
self.gen.fact(fn.deprecated_version(pkg.name, v))
self.gen.fact(fn.facts(pkg.name, fn.deprecated_version(v)))

def spec_versions(self, spec):
"""Return list of clauses expressing spec's version constraints."""
Expand Down Expand Up @@ -970,7 +976,9 @@ def conflict_rules(self, pkg):
conflict_msg = default_msg.format(pkg.name, trigger, constraint)
constraint_msg = "conflict constraint %s" % str(constraint)
constraint_id = self.condition(constraint, name=pkg.name, msg=constraint_msg)
self.gen.fact(fn.conflict(pkg.name, trigger_id, constraint_id, conflict_msg))
self.gen.fact(
fn.facts(pkg.name, fn.conflict(trigger_id, constraint_id, conflict_msg))
)
self.gen.newline()

def compiler_facts(self):
Expand Down Expand Up @@ -1023,8 +1031,11 @@ def package_compiler_defaults(self, pkg):

for i, compiler in enumerate(reversed(matches)):
self.gen.fact(
fn.node_compiler_preference(
pkg.name, compiler.spec.name, compiler.spec.version, -i * 100
fn.facts(
pkg.name,
fn.node_compiler_preference(
compiler.spec.name, compiler.spec.version, -i * 100
),
)
)

Expand Down Expand Up @@ -1119,7 +1130,7 @@ def pkg_rules(self, pkg, tests):

if spack.spec.Spec() in when:
# unconditional variant
self.gen.fact(fn.variant(pkg.name, name))
self.gen.fact(fn.facts(pkg.name, fn.variant(name)))
else:
# conditional variant
for w in when:
Expand All @@ -1128,19 +1139,23 @@ def pkg_rules(self, pkg, tests):
msg += " when %s" % w

cond_id = self.condition(w, name=pkg.name, msg=msg)
self.gen.fact(fn.variant_condition(cond_id, pkg.name, name))
self.gen.fact(fn.facts(pkg.name, fn.conditional_variant(cond_id, name)))

single_value = not variant.multi
if single_value:
self.gen.fact(fn.variant_single_value(pkg.name, name))
self.gen.fact(fn.facts(pkg.name, fn.variant_single_value(name)))
self.gen.fact(
fn.variant_default_value_from_package_py(pkg.name, name, variant.default)
fn.facts(
pkg.name, fn.variant_default_value_from_package_py(name, variant.default)
)
)
else:
spec_variant = variant.make_default()
defaults = spec_variant.value
for val in sorted(defaults):
self.gen.fact(fn.variant_default_value_from_package_py(pkg.name, name, val))
self.gen.fact(
fn.facts(pkg.name, fn.variant_default_value_from_package_py(name, val))
)

values = variant.values
if values is None:
Expand All @@ -1151,7 +1166,9 @@ def pkg_rules(self, pkg, tests):
for sid, s in enumerate(values.sets):
for value in s:
self.gen.fact(
fn.variant_value_from_disjoint_sets(pkg.name, name, value, sid)
fn.facts(
pkg.name, fn.variant_value_from_disjoint_sets(name, value, sid)
)
)
union.update(s)
values = union
Expand All @@ -1178,7 +1195,9 @@ def pkg_rules(self, pkg, tests):
msg="empty (total) conflict constraint",
)
msg = "variant {0}={1} is conditionally disabled".format(name, value)
self.gen.fact(fn.conflict(pkg.name, trigger_id, constraint_id, msg))
self.gen.fact(
fn.facts(pkg.name, fn.conflict(trigger_id, constraint_id, msg))
)
else:
imposed = spack.spec.Spec(value.when)
imposed.name = pkg.name
Expand All @@ -1189,10 +1208,10 @@ def pkg_rules(self, pkg, tests):
name=pkg.name,
msg="%s variant %s value %s when %s" % (pkg.name, name, value, when),
)
self.gen.fact(fn.variant_possible_value(pkg.name, name, value))
self.gen.fact(fn.facts(pkg.name, fn.variant_possible_value(name, value)))

if variant.sticky:
self.gen.fact(fn.variant_sticky(pkg.name, name))
self.gen.fact(fn.facts(pkg.name, fn.variant_sticky(name)))

self.gen.newline()

Expand All @@ -1210,7 +1229,8 @@ def pkg_rules(self, pkg, tests):

# virtual preferences
self.virtual_preferences(
pkg.name, lambda v, p, i: self.gen.fact(fn.pkg_provider_preference(pkg.name, v, p, i))
pkg.name,
lambda v, p, i: self.gen.fact(fn.facts(pkg.name, fn.provider_preference(v, p, i))),
)

self.package_requirement_rules(pkg)
Expand All @@ -1232,15 +1252,16 @@ def condition(self, required_spec, imposed_spec=None, name=None, msg=None, node=
"""
named_cond = required_spec.copy()
named_cond.name = named_cond.name or name
assert named_cond.name, "must provide name for anonymous condtions!"
assert named_cond.name, "must provide name for anonymous conditions!"

# Check if we can emit the requirements before updating the condition ID counter.
# In this way, if a condition can't be emitted but the exception is handled in the caller,
# we won't emit partial facts.
requirements = self.spec_clauses(named_cond, body=True, required_from=name)

condition_id = next(self._condition_id_counter)
self.gen.fact(fn.condition(condition_id, msg))
self.gen.fact(fn.facts(named_cond.name, fn.condition(condition_id)))
self.gen.fact(fn.condition_reason(condition_id, msg))
for pred in requirements:
self.gen.fact(fn.condition_requirement(condition_id, *pred.args))

Expand All @@ -1259,13 +1280,15 @@ def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):

def package_provider_rules(self, pkg):
for provider_name in sorted(set(s.name for s in pkg.provided.keys())):
self.gen.fact(fn.possible_provider(pkg.name, provider_name))
self.gen.fact(fn.facts(pkg.name, fn.possible_provider(provider_name)))

for provided, whens in pkg.provided.items():
for when in whens:
msg = "%s provides %s when %s" % (pkg.name, provided, when)
condition_id = self.condition(when, provided, pkg.name, msg)
self.gen.fact(fn.provider_condition(condition_id, when.name, provided.name))
self.gen.fact(
fn.facts(when.name, fn.provider_condition(condition_id, provided.name))
)
self.gen.newline()

def package_dependencies_rules(self, pkg):
Expand All @@ -1291,7 +1314,9 @@ def package_dependencies_rules(self, pkg):
msg += " when %s" % cond

condition_id = self.condition(cond, dep.spec, pkg.name, msg)
self.gen.fact(fn.dependency_condition(condition_id, pkg.name, dep.spec.name))
self.gen.fact(
fn.facts(pkg.name, fn.dependency_condition(condition_id, dep.spec.name))
)

for t in sorted(deptypes):
# there is a declared dependency of type t
Expand Down Expand Up @@ -1449,7 +1474,7 @@ def external_packages(self):
for local_idx, spec in enumerate(external_specs):
msg = "%s available as external when satisfying %s" % (spec.name, spec)
condition_id = self.condition(spec, msg=msg)
self.gen.fact(fn.possible_external(condition_id, pkg_name, local_idx))
self.gen.fact(fn.facts(pkg_name, fn.possible_external(condition_id, local_idx)))
self.possible_versions[spec.name].add(spec.version)
self.gen.newline()

Expand Down Expand Up @@ -1495,7 +1520,9 @@ def target_preferences(self, pkg_name):
if str(preferred.architecture.target) == best_default and i != 0:
offset = 100
self.gen.fact(
fn.target_weight(pkg_name, str(preferred.architecture.target), i + offset)
fn.facts(
pkg_name, fn.target_weight(str(preferred.architecture.target), i + offset)
)
)

def spec_clauses(self, *args, **kwargs):
Expand Down Expand Up @@ -2041,11 +2068,11 @@ def define_version_constraints(self):
# generate facts for each package constraint and the version
# that satisfies it
for v in sorted(v for v in self.possible_versions[pkg_name] if v.satisfies(versions)):
self.gen.fact(fn.version_satisfies(pkg_name, versions, v))
self.gen.fact(fn.facts(pkg_name, fn.version_satisfies(versions, v)))

self.gen.newline()

def define_virtual_constraints(self):
def collect_virtual_constraints(self):
"""Define versions for constraints on virtuals.
Must be called before define_version_constraints().
Expand Down Expand Up @@ -2131,7 +2158,7 @@ def define_variant_values(self):
# spec_clauses(). We might want to order these facts by pkg and name
# if we are debugging.
for pkg, variant, value in self.variant_values_from_specs:
self.gen.fact(fn.variant_possible_value(pkg, variant, value))
self.gen.fact(fn.facts(pkg, fn.variant_possible_value(variant, value)))

def _facts_from_concrete_spec(self, spec, possible):
# tell the solver about any installed packages that could
Expand Down Expand Up @@ -2280,10 +2307,8 @@ def setup(self, driver, specs, reuse=None):
self.gen.h1("Variant Values defined in specs")
self.define_variant_values()

self.gen.h1("Virtual Constraints")
self.define_virtual_constraints()

self.gen.h1("Version Constraints")
self.collect_virtual_constraints()
self.define_version_constraints()

self.gen.h1("Compiler Version Constraints")
Expand Down

0 comments on commit b72e7aa

Please sign in to comment.