Skip to content

Commit

Permalink
added checks for src_indices out of bounds of source vec and for nega…
Browse files Browse the repository at this point in the history
…tive index values.
  • Loading branch information
naylor-b committed Dec 2, 2015
1 parent ccfc52b commit 6e6e852
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 218 deletions.
2 changes: 1 addition & 1 deletion benchmarks/10Kmultipoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, adders, scalars):
for i,(a,s) in enumerate(zip(adders, scalars)):
c_name = 'p%d'%i
self.add(c_name, Point(a,s))
# self.connect('X', c_name+'.x', src_indices=[i,])
# self.connect('X', c_name+'.x', src_indices=[i])
self.connect(c_name+'.f2','aggregate.y%d'%i)

self.add('aggregate', Summer(size))
Expand Down
267 changes: 86 additions & 181 deletions openmdao/core/checks.py
Original file line number Diff line number Diff line change
@@ -1,148 +1,6 @@
""" Set of utilities for detecting and reporting connection errors."""

from six.moves import zip
from six import iterkeys, itervalues

class ConnectError(Exception):
""" Custom error that is raised when a connection is invalid."""

@classmethod
def _type_mismatch_error(cls, src, target):
msg = "Type {src[type]} of source '{src[promoted_name]}' must be the same as type {target[type]} of target '{target[promoted_name]}'"
msg = msg.format(src=src, target=target)

return cls(msg)

@classmethod
def _shape_mismatch_error(cls, src, target):
msg = "Shape {src[shape]} of source '{src[pathname]}' must be the same as shape {target[shape]} of target '{target[pathname]}'"
msg = msg.format(src=src, target=target)

return cls(msg)

@classmethod
def _size_mismatch_error(cls, src, target):
msg = "Size {isize} of the indexed sub-part of source '{src[promoted_name]}' must be the same as size {target[size]} of target '{target[promoted_name]}'"
msg = msg.format(src=src, target=target, isize=len(target['src_indices']))

return cls(msg)

@classmethod
def _indices_too_large(cls, src, target):
msg = "Size {isize} of target indices is larger than size {src[size]} of source '{src[promoted_name]}'"
msg = msg.format(src=src, target=target, isize=len(target['src_indices']))

return cls(msg)

@classmethod
def _val_and_shape_mismatch_error(cls, src, target):
msg = "Shape of the initial value {src[val].shape} of source '{src[promoted_name]}' must be the same as shape {target[shape]} of target '{target[promoted_name]}'"
msg = msg.format(src=src, target=target)

return cls(msg)

@classmethod
def nonexistent_src_error(cls, src, target):
""" Formats an error message for non-existant source in a connection.
Args
----
src : str
Name of source
target : str
Name of target
Returns
-------
str : error message
"""
msg = ("Source '{src}' cannot be connected to target '{target}': "
"'{src}' does not exist.")

msg = msg.format(src=src, target=target)

return cls(msg)

@classmethod
def nonexistent_target_error(cls, src, target):
""" Formats an error message for non-existant target in a connection.
Args
----
src : str
Name of source
target : str
Name of target
Returns
-------
str : error message
"""
msg = ("Source '{src}' cannot be connected to target '{target}': "
"'{target}' does not exist.")

msg = msg.format(src=src, target=target)

return cls(msg)

@classmethod
def invalid_target_error(cls, src, target):
""" Formats an error message for invalid target in a connection.
Args
----
src : str
Name of source
target : str
Name of target
Returns
-------
str : error message
"""
msg = ("Source '{src}' cannot be connected to target '{target}': "
"Target must be a parameter but '{target}' is an unknown.")

msg = msg.format(src=src, target=target)

return cls(msg)


def __make_metadata(metadata, to_prom_name):
'''
Add type field to metadata dict.
Returns a modified copy of `metadata`.
'''
metadata = dict(metadata)
metadata['type'] = type(metadata['val'])
metadata['promoted_name'] = to_prom_name[metadata['pathname']]

return metadata


def __get_metadata(paths, metadata_dict, to_prom_name):
metadata = []

for path in paths:
var_metadata = metadata_dict[path]
metadata.append(__make_metadata(var_metadata, to_prom_name))

return metadata


def _check_types_match(src, tgt):
if src['type'] == tgt['type']:
return

src_indices = tgt.get('src_indices')
if src_indices and len(src_indices) == 1 and tgt['type'] == float:
return

raise ConnectError._type_mismatch_error(src, tgt)

from six import iteritems

def check_connections(connections, params_dict, unknowns_dict, to_prom_name):
"""Checks the specified connections to make sure they are valid in
Expand All @@ -163,55 +21,102 @@ def check_connections(connections, params_dict, unknowns_dict, to_prom_name):
Raises
------
ConnectError
TypeError, or ValueError
Any invalidity in the connection raises an error.
"""
for tgt, (src, idxs) in iteritems(connections):
tmeta = params_dict[tgt]
smeta = unknowns_dict[src]
_check_types_match(smeta, tmeta, to_prom_name)
_check_shapes_match(smeta, tmeta, to_prom_name)

# Get metadata for all sources
srcs = (src for src, idxs in itervalues(connections))
sources = __get_metadata(srcs, unknowns_dict, to_prom_name)
def _check_types_match(src, tgt, to_prom_name):
stype = type(src['val'])
ttype = type(tgt['val'])

#Get metadata for all targets
targets = __get_metadata(iterkeys(connections), params_dict, to_prom_name)
if stype == ttype:
return

for source, target in zip(sources, targets):
_check_types_match(source, target)
_check_shapes_match(source, target)
src_indices = tgt.get('src_indices')
if src_indices and len(src_indices) == 1 and ttype == float:
return

raise TypeError("Type %s of source %s must be the same as type %s of "
"target %s." % (type(src['val']),
_both_names(src, to_prom_name),
type(tgt['val']), _both_names(tgt, to_prom_name)))

def _check_shapes_match(source, target):
# Use the type of the shape of source and target to determine the
# correct function to use for shape checking
check_shape_function = __shape_checks.get((type(source.get('shape')),
type(target.get('shape'))),
lambda x, y: None)
check_shape_function(source, target)
def _check_shapes_match(source, target, to_prom_name):
sshape = source.get('shape')
tshape = target.get('shape')
if sshape is not None and tshape is not None:
__check_shapes_match(source, target, to_prom_name)
elif sshape is None and tshape is not None:
__check_val_and_shape_match(source, target, to_prom_name)

def __check_shapes_match(src, target, to_prom_name):
src_idxs = target.get('src_indices')

def __check_shapes_match(src, target):
if src['shape'] != target['shape']:
if 'src_indices' in target:
if len(target['src_indices']) != target['size']:
raise ConnectError._size_mismatch_error(src, target)
elif len(target['src_indices']) > src['size']:
raise ConnectError._indices_too_large(src, target)
if src_idxs is not None:
if len(src_idxs) != target['size']:
raise ValueError("Size %d of the indexed sub-part of source "
"%s must be the same as size %d of target "
"%s." %
(len(target['src_indices']),
_both_names(src, to_prom_name),
target['size'],
_both_names(target, to_prom_name)))
if len(src_idxs) > src['size']:
raise ValueError("Size %d of target indices is larger than size"
" %d of source %s." %
(len(src_idxs), src['size'],
_both_names(src, to_prom_name)))
elif 'src_indices' in src:
if target['size'] != src['distrib_size']:
msg = ("Total size {src[distrib_size]} of distributed source "
"'{src[pathname]}' must be the same as size "
"{target[size]} of target '{target[pathname]}'")
msg = msg.format(src=src, target=target)
raise RuntimeError(msg)
raise ValueError("Total size %d of distributed source "
"%s must be the same as size "
"%d of target %s." %
(src['distrib_size'], _both_names(src, to_prom_name),
target['size'], _both_names(target, to_prom_name)))
else:
raise ConnectError._shape_mismatch_error(src, target)


def __check_val_and_shape_match(src, target):
raise ValueError("Shape %s of source %s must be the same as shape "
"%s of target %s." % (src['shape'],
_both_names(src, to_prom_name), target['shape'],
_both_names(target, to_prom_name)))

if src_idxs is not None:
if 'src_indices' in src:
ssize = src['distrib_size']
else:
ssize = src['size']
max_idx = max(src_idxs)
if max_idx >= ssize:
raise ValueError("%s src_indices contains an index (%d) that "
"exceeds the bounds of source variable %s of "
"size %d." %
(_both_names(target, to_prom_name),
max_idx, _both_names(src, to_prom_name), ssize))
min_idx = min(src_idxs)
if min_idx < 0:
raise ValueError("%s src_indices contains a negative index "
"(%d)." %
(_both_names(target, to_prom_name), min_idx))

def __check_val_and_shape_match(src, target, to_prom_name):
if src['val'].shape != target['shape']:
raise ConnectError._val_and_shape_mismatch_error(src, target)


__shape_checks = {
(tuple, tuple) : __check_shapes_match,
(type(None), tuple) : __check_val_and_shape_match
}
raise ValueError("Shape of the initial value %s of source "
"%s must be the same as shape %s of target %s." %
(src[val].shape, _both_names(src, to_prom_name),
target['shape'], _both_names(target, to_prom_name)))

def _both_names(meta, to_prom_name):
"""If the pathname differs from the promoted name, return a string with
both names. Otherwise, just return the pathname.
"""
name = meta['pathname']
pname = to_prom_name[name]
if name == pname:
return "'%s'" % name
else:
return "'%s' (%s)" % (name, pname)
15 changes: 10 additions & 5 deletions openmdao/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from openmdao.core.system import System
from openmdao.util.string_util import nearest_child, name_relative_to
from openmdao.util.graph import collapse_nodes
from openmdao.core.checks import ConnectError

#from openmdao.devtools.debug import diff_mem, mem_usage

Expand Down Expand Up @@ -583,8 +582,9 @@ def _get_explicit_connections(self):
try:
src_pathnames = to_abs_pnames[src]
except KeyError as error:
raise ConnectError.nonexistent_src_error(src, tgt)

raise NameError("Source '%s' cannot be connected to "
"target '%s': '%s' does not exist." %
(src, tgt, src))
try:
for tgt_pathname in to_abs_pnames[tgt]:
for src_pathname in src_pathnames:
Expand All @@ -595,9 +595,14 @@ def _get_explicit_connections(self):
try:
to_abs_uname[tgt]
except KeyError as error:
raise ConnectError.nonexistent_target_error(src, tgt)
raise NameError("Source '%s' cannot be connected to "
"target '%s': '%s' does not exist." %
(src, tgt, tgt))
else:
raise ConnectError.invalid_target_error(src, tgt)
raise NameError("Source '%s' cannot be connected to "
"target '%s': Target must be a "
"parameter but '%s' is an unknown." %
(src, tgt, tgt))

return connections

Expand Down
1 change: 0 additions & 1 deletion openmdao/core/test/test_pbo_desvar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

import unittest
from openmdao.api import Component, Problem, Group, IndepVarComp, ExecComp, Driver
from openmdao.core.checks import ConnectError
from openmdao.util.record_util import create_local_meta, update_local_meta


Expand Down
Loading

0 comments on commit 6e6e852

Please sign in to comment.