Skip to content

Commit

Permalink
Fix flake E721 issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jussiviinikka committed Sep 27, 2023
1 parent fe00a3a commit e750263
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 30 deletions.
2 changes: 1 addition & 1 deletion sumu/beeps.py
Expand Up @@ -8,7 +8,7 @@
class Beeps:
def __init__(self, *, dags, data):
self.dags = dags
if not type(self.dags[0]) == np.ndarray:
if type(self.dags[0]) is not np.ndarray:
self.dags = [family_sequence_to_adj_mat(d) for d in self.dags]
self.data = Data(data)
n = self.data.n
Expand Down
4 changes: 2 additions & 2 deletions sumu/bnet.py
Expand Up @@ -235,7 +235,7 @@ def from_dag(cls, dag, *, data=None, arity=2, ess=0.5, params="MP"):
data = np.array([], dtype=np.int32).reshape(0, len(nodes))

for i, node in enumerate(nodes):
if type(arity) == list:
if type(arity) is list:
node.arity = arity[i]
else:
node.arity = arity
Expand Down Expand Up @@ -373,7 +373,7 @@ def sample(self, N=1):
def __getitem__(self, node_name_or_index):

# TODO: bn["node1", "node2"].sample()
is_name = type(node_name_or_index) == str
is_name = type(node_name_or_index) is str

try:
if is_name:
Expand Down
6 changes: 3 additions & 3 deletions sumu/data.py
Expand Up @@ -17,14 +17,14 @@ class Data:
def __init__(self, data_or_path):

# Copying existing Data object
if type(data_or_path) == Data:
if type(data_or_path) is Data:
self.data = data_or_path.data
self.discrete = data_or_path.discrete
self.data_path = data_or_path.data_path
return

# Initializing from np.array
if type(data_or_path) == np.ndarray:
if type(data_or_path) is np.ndarray:
# TODO: Should cast all int types to np.int32 as that is what bdeu
# scorer expects. Also make sure float is np.float64, not
# np.float32 or so?
Expand All @@ -34,7 +34,7 @@ def __init__(self, data_or_path):
return

# Initializing from path
if type(data_or_path) == str:
if type(data_or_path) is str:
self.data_path = data_or_path
with open(data_or_path) as f:
# . is assumed to be a decimal separator
Expand Down
12 changes: 6 additions & 6 deletions sumu/gadget.py
Expand Up @@ -448,23 +448,23 @@ def _populate_default_parameters(self):

def _complete_user_given_parameters(self):
def complete(default, p):
if all(type(p[k]) != dict for k in p):
if all(type(p[k]) is not dict for k in p):
return dict(default, **p)
for k in p:
if type(p[k]) == dict:
if type(p[k]) is dict:
p[k] = complete(default[k], p[k])
return dict(default, **p)

for k in self.p:
if k not in self.default:
continue
if (
type(self.p[k]) == dict
type(self.p[k]) is dict
and "name" in self.p[k]
and self.p[k]["name"] != self.default[k]["name"]
):
continue
if type(self.p[k]) == dict:
if type(self.p[k]) is dict:
self.p[k] = complete(self.default[k], self.p[k])
elif self.p[k] is None:
self.p[k] = self.default[k]
Expand Down Expand Up @@ -782,7 +782,7 @@ def __call__(self, string):
print(string, file=f, flush=True)

def unlink(self):
if type(self._logfile) == pathlib.PosixPath:
if type(self._logfile) is pathlib.PosixPath:
self._logfile.unlink()

def silent(self):
Expand All @@ -791,7 +791,7 @@ def silent(self):
def dict(self, data):
def pretty_dict(d, n=0, string=""):
for k in d:
if type(d[k]) in (dict, Stats):
if type(d[k]) is dict: # in (dict, Stats):
string += f"{' '*n}{k}\n"
else:
string += f"{' '*n}{k}: {d[k]}\n"
Expand Down
36 changes: 18 additions & 18 deletions sumu/validate.py
Expand Up @@ -44,11 +44,11 @@


def is_int(val):
return type(val) == int or np.issubdtype(type(val), np.integer)
return type(val) is int or np.issubdtype(type(val), np.integer)


def is_float(val):
return type(val) == float or np.issubdtype(type(val), np.floating)
return type(val) is float or np.issubdtype(type(val), np.floating)


def is_num(val):
Expand All @@ -72,11 +72,11 @@ def is_pos_num(val):


def is_boolean(val):
return type(val) == bool
return type(val) is bool


def is_string(val):
return type(val) == str
return type(val) is str


def in_range(val, min_val, max_val, min_incl=True, max_incl=True):
Expand Down Expand Up @@ -163,9 +163,9 @@ def max_n_truthy(n, items):
):
lambda p:
not nested_in_dict(p, "params", "t_share") or
type(p["params"]["t_share"]) == dict and
type(p["params"]["t_share"]) is dict and
set(p["params"]["t_share"]).issubset({"C", "K", "d"}) and
all(type(v) == float for v in p["params"]["t_share"].values()) and
all(type(v) is float for v in p["params"]["t_share"].values()) and
all((v > 0 and v < 1) for v in p["params"]["t_share"].values()) and
sum(p["params"]["t_share"].values()) < 1

Expand Down Expand Up @@ -212,7 +212,7 @@ def max_n_truthy(n, items):
"move_weights" not in p or
all(
[
type(p["move_weights"]) == dict,
type(p["move_weights"]) is dict,
set(p["move_weights"]) == {
"R_split_merge",
"R_swap_node_pair",
Expand Down Expand Up @@ -542,17 +542,17 @@ def max_n_truthy(n, items):
"silent should be a boolean":
lambda p:
"silent" not in p or
type(p["silent"]) == bool,
type(p["silent"]) is bool,

"verbose_prefix should be a string":
lambda p:
"verbose_prefix" not in p or
type(p["verbose_prefix"]) == str,
type(p["verbose_prefix"]) is str,

"overwrite should be a boolean":
lambda p:
"overwrite" not in p or
type(p["overwrite"]) == bool,
type(p["overwrite"]) is bool,

"period should be a positive number":
lambda p:
Expand All @@ -567,8 +567,8 @@ def max_n_truthy(n, items):
"should be a list partitioning integers 0..n to sets":
lambda R:
all([
type(R) == list,
all([type(R_i) == set for R_i in R]),
type(R) is list,
all([type(R_i) is set for R_i in R]),
all([is_int(u) for R_i in R for u in R_i]),
sorted([u for R_i in R for u in R_i])
== list(range(max([max(R_i) for R_i in R]) + 1))
Expand All @@ -585,11 +585,11 @@ def max_n_truthy(n, items):
lambda dag:
all(
[
type(dag) == list,
all([type(f) == tuple for f in dag]),
type(dag) is list,
all([type(f) is tuple for f in dag]),
all([len(f) == 2 for f in dag]),
all([isinstance(f[0], (np.integer, int))] for f in dag),
all([type(f[1]) == set for f in dag]),
all([type(f[1]) is set for f in dag]),
all([isinstance(p, (np.integer, int)) for f in dag for p in f[1]]),
]
)
Expand All @@ -603,8 +603,8 @@ def max_n_truthy(n, items):
lambda C:
all(
[
type(C) == dict,
all(type(v) == tuple for v in C.values()),
type(C) is dict,
all(type(v) is tuple for v in C.values()),
all(is_int(vi)
for v in C.values() for vi in v),
]
Expand Down Expand Up @@ -642,7 +642,7 @@ class ValidationError(Exception):
def _make_validator(validator, validator_name, only_check_is_valid=False):
def validate(item):
for f in validator:
if type(f) == str and f[0] == "_":
if type(f) is str and f[0] == "_":
continue
try:
if not validator[f](item):
Expand Down

0 comments on commit e750263

Please sign in to comment.