
Dev/reinstate strict optional mypy #541

Merged
45 commits merged on Dec 2, 2021

Conversation

jklaise (Member) commented Nov 24, 2021

This PR reinstates strict Optional type-checking with mypy, both because it's a prerequisite of #511 and because it's a good idea in its own right (re-enabling strict behaviour caught multiple genuine bugs).

Specifically, setup.cfg now contains the following:

[mypy]
ignore_missing_imports = True
no_implicit_optional = True

Note that strict_optional = True is the default; it disallows None as a valid value for any type that doesn't declare it explicitly. Furthermore, no_implicit_optional now requires writing Optional[some_type] = None instead of some_type = None, which is the recommended behaviour.
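For example, under no_implicit_optional the first signature below is rejected and the second is required (an illustrative snippet, not code from the repo):

from typing import Optional

def f(x: int = None): ...            # error: default None doesn't match "int"
def g(x: Optional[int] = None): ...  # OK: the Optional is explicit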

Most of the code is just type changes, but where genuine bugs were revealed some (minor) logic changes were needed also.

I will post some comments on the PR shortly to ease reviewing.

I have additionally taken the chance to run isort on the affected modules for more pleasant import ordering.

One particularly pernicious pattern I have decided to refactor concerns initializing instance attributes to None and only later setting them to their proper values. So instead of this (which would now fail type-checking):

class Foo:
    def __init__(self):
        # to be set later
        self.attr = None # type: int

we now declare the type of the non-initialized attribute in the class instead:

class Foo:
    attr: int

Now you may ask: why not just stick with the first variant and annotate with # type: Optional[int]? The answer is that None is (almost always) not a valid value; in the self.attr = None pattern it is merely a placeholder to be initialized later on. This can lead to a variety of bugs if we don't explicitly check for None every time we want to use the (hopefully initialized) attribute. Type checkers like mypy are also unlikely to know that accessing self.attr after initialization carries no risk of it being None, because the runtime logic guaranteeing this can be arbitrarily complex and the type-checker cannot follow it. Finally, we may fail to handle all the cases that set self.attr to its non-None value and end up with a genuine bug where we believe self.attr is no longer None when it still is... Thus, it is better to declare the attribute but not initialize it to a generally invalid None value. In this scenario we can also check whether the value still needs initializing by calling hasattr(self, 'attr') instead of comparing to None.
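A minimal sketch of how the check looks under the new pattern (the attribute name and error message are illustrative):

class Foo:
    attr: int  # declared on the class, set later e.g. during fit()

    def use_attr(self) -> int:
        # check for initialization via hasattr rather than comparing to None;
        # a bare annotation does not create the attribute, so hasattr is False until set
        if not hasattr(self, 'attr'):
            raise AttributeError('`attr` accessed before being initialized')
        return self.attr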

A final minor curiosity: because our numpy version is older and doesn't include the numpy.typing advancements, declaring arr: np.ndarray = None doesn't actually result in a mypy error anywhere. However, we're good citizens and wrap these in Optional as well; this future-proofs the code for when later versions of numpy can be used.
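In other words, we write signatures like the following even though the older numpy doesn't force us to (a representative example, not a specific function from the codebase):

from typing import Optional
import numpy as np

def process(arr: Optional[np.ndarray] = None) -> np.ndarray:
    if arr is None:
        arr = np.zeros(1)  # illustrative default
    return arr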

@@ -218,7 +219,7 @@ def _sample_knn(x: np.ndarray, X_train: np.ndarray, nb_samples: int = 10) -> np.
return np.asarray(X_sampled) # shape=(nb_instances, nb_samples, nb_features)


-def _sample_grid(x: np.ndarray, feature_range: np.ndarray = None, epsilon: float = 0.04,
+def _sample_grid(x: np.ndarray, feature_range: np.ndarray, epsilon: float = 0.04,
jklaise (Member Author):

There is no None handling in the function so feature_range must be required.

@@ -341,10 +349,16 @@ def _infer_feature_range(X_train: np.ndarray) -> np.ndarray:
return np.vstack((X_train.min(axis=0), X_train.max(axis=0))).T


-class LinearityMeasure(object):
+class LinearityMeasure:
jklaise (Member Author):

Removed the old-style inheritance from object as it's not necessary in Python 3.

@@ -50,13 +51,13 @@ def load_cats(target_size: tuple = (299, 299), return_X_y: bool = False) -> Unio
(data, target)
Tuple if ``return_X_y`` is true
"""
-tar = tarfile.open(fileobj=BytesIO(pkgutil.get_data(__name__, "data/cats.tar.gz")), mode='r:gz')
+tar = tarfile.open(fileobj=BytesIO(pkgutil.get_data(__name__, "data/cats.tar.gz")), mode='r:gz') # type: ignore
jklaise (Member Author):

The three type: ignore comments in this module are because the underlying calls actually return Optional, presumably because the data is not guaranteed to exist or yield anything. I don't think we need to check for this explicitly, as we are processing data that we ship with the library.
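For reference, the Optional originates in the standard library stubs: pkgutil.get_data is typed as returning Optional[bytes], so for example:

import pkgutil
from io import BytesIO

data = pkgutil.get_data(__name__, 'data/cats.tar.gz')  # typed Optional[bytes]
buf = BytesIO(data)  # mypy error without a None check: expects a bytes-like value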

@@ -130,7 +130,7 @@ def explain(self, X: np.ndarray, features: List[int] = None, min_bin_points: int
feature_names = self.feature_names[features] # type: ignore
else:
feature_names = self.feature_names
-features = range(n_features) # type: ignore
+features = list(range(n_features))
jklaise (Member Author):

Expect a list, therefore a list you shall have.

@@ -24,7 +25,7 @@ def __init__(self, samplers: List[Callable], **kwargs) -> None:
"""

self.sample_fcn = samplers[0]
-self.samplers = None # type: List[Callable]
+self.samplers = None # type: Optional[List[Callable]]
jklaise (Member Author), Nov 24, 2021:

This is an unfortunate exception to the pattern described in the PR description, because the class AnchorBaseBeam is inherited by DistributedAnchorBaseBeam, which uses a self.samplers attribute even though it is never set for this base class. Sorting this out would require more refactoring than is warranted for this PR.

Contributor:

Not ideal, but agree this is probably the most elegant fix without refactoring.

@@ -835,7 +836,7 @@ def __init__(self, samplers: List[Callable], **kwargs) -> None:
self.pool = ActorPool(samplers)
self.samplers = samplers

-def _get_coverage_samples(self, coverage_samples: int, samplers: List[Callable] = None) -> np.ndarray:
+def _get_coverage_samples(self, coverage_samples: int, samplers: List[Callable]) -> np.ndarray: # type: ignore
jklaise (Member Author):

We ignore the type here because this method is an override of the one in the base class, except here samplers really is required.
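This is the usual Liskov substitution constraint: an override must accept at least everything the base method accepts, so tightening an Optional parameter into a required one is flagged. A stripped-down illustration (class and method names are made up):

from typing import Callable, List, Optional

class Base:
    def get_samples(self, n: int, samplers: Optional[List[Callable]] = None) -> None: ...

class Distributed(Base):
    # mypy flags this override as incompatible with the supertype,
    # hence the ignore in the actual code
    def get_samples(self, n: int, samplers: List[Callable]) -> None: ...  # type: ignore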

-if isinstance(segmentation_fn, str) and not segmentation_kwargs:
+# TODO: this logic needs improvement. We should check against a fixed set of strings
+# for built-ins instead of any `str`.
+if isinstance(segmentation_fn, str) and segmentation_kwargs is None:
jklaise (Member Author):

Using and not was a bug, as it would also match an empty dictionary. Now we explicitly deal with the None case, which conveniently doesn't require the type: ignore later on.
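For the record, the pitfall (illustrative snippet):

segmentation_kwargs = {}       # caller explicitly passed empty kwargs
not segmentation_kwargs        # True:  an empty dict is falsy, wrongly treated as missing
segmentation_kwargs is None    # False: only a genuinely absent value matches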

@@ -128,7 +133,7 @@ def set_instance_label(self, X: np.ndarray) -> None:
Instance to be explained.
"""

-label = self.predictor(X.reshape(1, -1))[0]
+label = self.predictor(X.reshape(1, -1))[0] # type: int
jklaise (Member Author):

We need to declare the type here otherwise we will get warnings later on when passing label to functions.

@@ -698,7 +705,7 @@ def __init__(self,
self.ohe = ohe
self.feature_names = feature_names

-if ohe:
+if ohe and categorical_names:
jklaise (Member Author):

We never validate which combinations of kwargs are consistent; this check is the easiest one to add for now.

Contributor:

Does ohe=True with categorical_names=None (or {}) cause an error in .fit as self.cat_vars_ohe is not set? Maybe we should raise an error in init if this is the case?

jklaise (Member Author):

I don't think any errors are raised currently; we should definitely add validation. Perhaps a new issue.
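Such a validation could be as simple as the following sketch in __init__ (hypothetical, not part of this PR):

if ohe and not categorical_names:
    raise ValueError('`categorical_names` must be provided when `ohe=True`.')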

@@ -1003,19 +1009,19 @@ def predictor(self) -> Callable:
return self._ohe_predictor if self.ohe else self._predictor

@predictor.setter
-def predictor(self, predictor: Optional[Callable]) -> None:
+def predictor(self, predictor: Callable) -> None:
jklaise (Member Author):

predictor should never be None.

@@ -285,7 +285,7 @@ def generate_condition(X_ohe: np.ndarray,

def sample_numerical(X_hat_num_split: List[np.ndarray],
X_ohe_num_split: List[np.ndarray],
-C_num_split: List[np.ndarray],
+C_num_split: Optional[List[np.ndarray]],
jklaise (Member Author):

None case is dealt with in the body so it can clearly be Optional (as inferred by the type-checker elsewhere in the code).

@@ -335,7 +335,7 @@ def sample_numerical(X_hat_num_split: List[np.ndarray],


def sample_categorical(X_hat_cat_split: List[np.ndarray],
-C_cat_split: List[np.ndarray]) -> List[np.ndarray]:
+C_cat_split: Optional[List[np.ndarray]]) -> List[np.ndarray]:
jklaise (Member Author):

None case is dealt with in the body so it can clearly be Optional (as inferred by the type-checker elsewhere in the code).

@@ -506,7 +506,7 @@ def get_inv_preprocessor(X_ohe: np.ndarray):

np_X_inv = np_X_inv[:, inv_perm].astype(object)
for i, fn in enumerate(feature_names):
-type = feature_types[fn] if fn in feature_types else float
+type = feature_types[fn] if fn in feature_types else float # type: ignore # TODO: closure resets type info?
Collaborator:

No, it does not. The casting is in the following line:

np_X_inv[:, i] = np_X_inv[:, i].astype(type)

Should be good as it is.

jklaise (Member Author):

The issue here is with the inferred type for feature_types: mypy can't see that it can't be None because the access happens inside an inner function, a known mypy limitation. I will see if I can reduce the scope of the ignore with an error code though.

if self.is_cat:
# compute dimensionality after conversion from OHE to ordinal encoding
shape = ohe_to_ord_shape(shape, cat_vars=cat_vars, is_ohe=self.ohe)
jklaise (Member Author):

I moved this inside the if self.is_cat block because it really doesn't make sense to do if there is no cat_vars (again, a case of not validating allowed combinations of kwargs).

Contributor:

Seems like it might be best to open a new issue about raising a warning for incorrect combinations of kwargs?


# define placeholder for mapping which can be fed after the fit step
-max_key = max(cat_vars, key=cat_vars.get)
+max_key = max(cat_vars)
jklaise (Member Author):

max is by default over the keys. This also avoids a type error because dict.get actually returns an Optional type.

Comment:

See #610: this change introduced a bug. The two calls are not equivalent and generate different outputs; max(cat_vars) returns the largest key, whereas max(cat_vars, key=cat_vars.get) returns the key with the largest value:

cat_vars = {0: 2, 2: 3, 5: 8, 13: 2, 15: 2, 17: 2, 19: 3}
max(cat_vars)                    # 19: the largest key
max(cat_vars, key=cat_vars.get)  # 5:  the key with the largest value (8)

@@ -701,7 +710,7 @@ def fit(self, train_data: np.ndarray, trustscore_kwargs: dict = None, d_type: st
else:
preds = np.argmax(self.predict(train_data), axis=1)

-self.cat_vars_ord = None
+self.cat_vars_ord = dict() # type: dict
jklaise (Member Author):

Avoids checking for None later on when passing to functions that expect a dict.

Comment on lines 746 to 748
if w is None:
msg = f"Must specify a value for `w` if using d_type='abdm-mvdm'"
raise ValueError(msg)
jklaise (Member Author):

This validation was missing which the type-checker caught as trying to do arithmetic with w=None and integers later in the function.

@@ -826,7 +839,7 @@ def loss_fn(self, pred_proba: np.ndarray, Y: np.ndarray) -> np.ndarray:
return loss_attack

def get_gradients(self, X: np.ndarray, Y: np.ndarray, grads_shape: tuple,
-cat_vars_ord: dict = None) -> np.ndarray:
+cat_vars_ord: dict) -> np.ndarray:
jklaise (Member Author):

Now cat_vars_ord is always a dict even if it's the empty one.

@@ -1045,15 +1058,16 @@ def compare(x: Union[float, int, np.ndarray], y: int) -> bool:
self.class_proto[c] = self.X_by_class[c][idx_c[0][-1]].reshape(1, -1)

if self.enc_or_kdtree:
-self.id_proto = min(dist_proto, key=dist_proto.get)
+self.id_proto = min(dist_proto)
jklaise (Member Author):

Similar comment to max above.

Comment on lines 1068 to 1070
if self.is_cat:
# set shape for perturbed instance and gradients
pert_shape = ohe_to_ord_shape(self.shape, cat_vars=self.cat_vars, is_ohe=self.ohe)
jklaise (Member Author):

Only makes sense to do if cat_vars was passed.

Comment on lines 183 to 190
X = self.X[rand_idx] # input array
Y_m = self.Y_m[rand_idx] # model's prediction
Y_t = self.Y_t[rand_idx] # counterfactual target
Z = self.Z[rand_idx] # input embedding
Z_cf_tilde = self.Z_cf_tilde[rand_idx] # noised counterfactual embedding
C = self.C[rand_idx] if (self.C is not None) else None # conditional array if exists
R_tilde = self.R_tilde[rand_idx] # noised counterfactual reward
jklaise (Member Author):

I can re-align the comments if desired; my IDE collapses everything when I format...

Contributor:

I'm also one for aligning nicely, but all IDEs seem to disagree with us...

@@ -491,7 +494,7 @@ def _validate_kwargs(self,
predictor: Callable,
encoder: 'Union[tensorflow.keras.Model, torch.nn.Module]',
decoder: 'Union[tensorflow.keras.Model, torch.nn.Module]',
-latent_dim: float,
+latent_dim: Optional[int],
jklaise (Member Author):

Bug in the type, it should be an int but may be left unspecified.

@@ -780,7 +783,7 @@ def _is_classification(pred: np.ndarray) -> bool:

def explain(self,
X: np.ndarray,
Y_t: np.ndarray = None, # TODO: remove default value (mypy error. explanation in the validation step)
jklaise (Member Author):

Note to self: missing Optional.

Contributor:

Is this still to fix then?

@@ -462,7 +465,7 @@ def _diversity(self,
X_cf, Y_m_cf, Y_t = results["X_cf"], results["Y_m_cf"], results["Y_t"]

# Select only counterfactuals where prediction matches the target.
-X_cf = X_cf[Y_t == Y_m_cf]
+X_cf = X_cf[Y_t == Y_m_cf] # type: ignore # TODO: fix me
jklaise (Member Author):

The error was: alibi/explainers/cfrl_tabular.py:468: error: Value of type "Optional[Any]" is not indexable. Fixing it properly felt like it would need more refactoring.

@@ -833,7 +838,7 @@ def explain(self,
baselines,
self.model,
self.layer,
-self.orig_call,
+self.orig_call, # type: ignore # TODO: fix me
jklaise (Member Author), Nov 24, 2021:

From the type-checker's point of view self.orig_call can still be None, even though we narrowed it down via the else clause (following if self.layer is None). Same comment for the occurrences below.
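This is a known mypy limitation: it narrows an attribute after a direct check on that same attribute, but it cannot link one attribute's None-ness to a check on a different attribute. A stripped-down illustration (names are made up):

from typing import Callable, Optional

class Explainer:
    def __init__(self, layer=None, orig_call: Optional[Callable] = None):
        self.layer = layer
        self.orig_call = orig_call

    def explain(self):
        if self.layer is not None:
            # at runtime orig_call is always set whenever layer is set,
            # but mypy cannot follow that invariant:
            self.orig_call()  # error: "None" not callable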

Contributor:

Mmn this one seems annoying...

group_names = ['group_{}'.format(i) for i in range(len(groups))]
# disable grouping or data weights if inputs are not correct
if self.ignore_weights:
weights = None
if not self.use_groups:
group_names, groups = None, None
else:
-self.feature_names = group_names
+self.feature_names = group_names # type:ignore # TODO: fix me
jklaise (Member Author):

group_names could be None as hinted by the method.

@@ -730,20 +732,20 @@ def fit(self, # type: ignore

# check user inputs to provide warnings if input is incorrect
self._check_inputs(background_data, group_names, groups, weights)
-if self.create_group_names:
+if self.create_group_names and groups:
jklaise (Member Author):

We don't check if groups is None beforehand, so type-checker complains that len(groups) may fail.


# perform grouping if requested by the user
self.background_data = self._get_data(background_data, group_names, groups, weights, **kwargs)
explainer_args = (self.predictor, self.background_data)
-explainer_kwargs = {'link': self.link} # type: Dict[str, Union[str, int]]
+explainer_kwargs = {'link': self.link} # type: Dict[str, Union[str, int, None]]
jklaise (Member Author):

Allow None because seed can be None.

@@ -49,7 +49,7 @@ def compile(self,

if isinstance(loss, list):
# check if the number of weights is the same as the number of partial losses
-if len(loss_weights) != len(loss):
+if len(self.loss_weights) != len(loss):
jklaise (Member Author):

self.loss_weights is already validated and can't be None.

alibi/saving.py Outdated
@@ -165,7 +165,7 @@ def _save_AnchorImage(explainer: 'AnchorImage', path: Union[str, os.PathLike]) -
dill.dump(segmentation_fn, f, recurse=True)

predictor = explainer.predictor
-explainer.predictor = None
+explainer.predictor = None # type: ignore
jklaise (Member Author):

Ignores in this file are purely because we temporarily set some attributes to None so that they are not saved to disk. The attributes are reinstated to their original values afterwards.
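Roughly, the pattern in this file is (a sketch; save_to_disk stands in for the actual serialization calls):

predictor = explainer.predictor
explainer.predictor = None  # type: ignore # the predictor should not be serialized
save_to_disk(explainer)     # hypothetical stand-in for the real saving logic
explainer.predictor = predictor  # reinstate the original value afterwards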

jklaise force-pushed the dev/reinstate-strict-optional-mypy branch from 785d218 to 745bd55 on Dec 2, 2021
jklaise (Member Author) commented Dec 2, 2021:

Enabling strict optional and testing against the latest numpy has flagged quite a few more issues, I will attempt to address them as part of this PR.

Comment on lines 66 to 68
self.num_classes: int
self.min_m: float
self.max_m: float
jklaise (Member Author):

I avoided initializing these to None, as that throws mypy errors when the attributes are passed into the numpy.random.randint function.
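The kind of error this avoids, roughly (illustrative snippet):

from typing import Optional
import numpy as np

def sample(num_classes: Optional[int] = None) -> int:
    # mypy flags Optional[int] being passed where an int is expected
    return np.random.randint(num_classes)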

@@ -839,7 +832,7 @@ def explain(self,

# Perform prediction in mini-batches.
n_minibatch = int(np.ceil(X.shape[0] / batch_size))
-all_results: Dict[str, np.ndarray] = {}
+all_results: Dict[str, Optional[np.ndarray]] = {}
jklaise (Member Author):

Changed this to Optional because C can be None. However, this complicates typing further down as all other entries cannot be None, but the type checker doesn't know that. A TypedDict might be a better fit here in the future.
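A sketch of what such a TypedDict could look like (only C is known to be Optional; the other field is illustrative):

from typing import Optional, TypedDict
import numpy as np

class Results(TypedDict):
    X_cf: np.ndarray         # counterfactuals, never None
    C: Optional[np.ndarray]  # conditional array, may be None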

RobertSamoilescu (Collaborator) left a comment:

Looks good to me.

jklaise (Member Author) commented Dec 2, 2021:

Reverted the removal of the model_type attribute from MockPredictor, as it is actually used to emulate the model-wrapper class behaviour from the shap library. Also added a comment for future reference.
