11 changes: 5 additions & 6 deletions .pre-commit-config.yaml
@@ -34,13 +34,12 @@ repos:
types_or: [python, jupyter]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.0
rev: v1.18.2
hooks:
- id: mypy
entry: python3 -m mypy --config-file pyproject.toml
language: system
types: [python]
exclude: "tests"
- id: mypy
name: mypy
entry: ./run_mypy.sh
language: system

- repo: local
hooks:
38 changes: 38 additions & 0 deletions mypy.ini
@@ -0,0 +1,38 @@
[mypy]
mypy_path=src
follow_imports = normal
ignore_missing_imports = False
install_types = True
pretty = True
non_interactive = True
disallow_untyped_defs = True
no_implicit_optional = True
check_untyped_defs = True
allow_untyped_decorators = False
allow_incomplete_defs = False
warn_redundant_casts = True
warn_unused_ignores = True
implicit_reexport = False
strict_equality = True
extra_checks = True
warn_unused_configs = True
allow_subclassing_any = False
exclude = (venv|examples/tutorial/*|tests)

[mypy-sklearn.*]
ignore_missing_imports = True

[mypy-syntheval.*]
ignore_missing_imports = True

[mypy-opacus.*]
ignore_missing_imports = True

[mypy-nltk.*]
ignore_missing_imports = True

[mypy-scipy.*]
ignore_missing_imports = True

[mypy-category_encoders.*]
ignore_missing_imports = True
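
For a sense of what the stricter flags above catch, here is a small hypothetical module (not part of the diff) with the errors mypy would report under this configuration:

# Hypothetical module illustrating two of the flags above.
def untyped(x):  # disallow_untyped_defs: function is missing a type annotation
    return x

def typed(x: int) -> int:  # accepted
    return x

def compare(s: str) -> bool:
    return s == 3  # strict_equality: non-overlapping equality check (str vs. int)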
26 changes: 0 additions & 26 deletions pyproject.toml
@@ -53,32 +53,6 @@ docs = [
[tool.uv]
default-groups = ["dev", "docs"]

[tool.mypy]
follow_imports = "normal"
ignore_missing_imports = false
install_types = true
pretty = true
non_interactive = true
disallow_untyped_defs = false
no_implicit_optional = true
check_untyped_defs = true
namespace_packages = true
explicit_package_bases = true
warn_unused_configs = true
allow_subclassing_any = false
allow_untyped_calls = true
allow_incomplete_defs = false
allow_untyped_decorators = false
warn_redundant_casts = true
warn_unused_ignores = true
implicit_reexport = false
strict_equality = true
extra_checks = true
mypy_path = "src"
files = ["src", "examples"]
exclude = [
"examples/tutorial/.*"
]

[tool.ruff]
include = ["*.py", "pyproject.toml"]
3 changes: 3 additions & 0 deletions run_mypy.sh
@@ -0,0 +1,3 @@
#!/bin/sh

mypy --config-file ./mypy.ini .
4 changes: 2 additions & 2 deletions src/midst_toolkit/attacks/ensemble/process_split_data.py
@@ -110,7 +110,7 @@ def generate_train_test_challenge_splits(
# Shuffle data
df_val = df_val.sample(frac=1, random_state=random_seed).reset_index(drop=True)

y_val = df_val["is_train"].values
y_val = df_val["is_train"].to_numpy()
df_val = df_val.drop(columns=["is_train"])

# Test set
@@ -139,7 +139,7 @@ def generate_train_test_challenge_splits(

df_test = df_test.sample(frac=1, random_state=random_seed).reset_index(drop=True)

y_test = df_test["is_train"].values
y_test = df_test["is_train"].to_numpy()
df_test = df_test.drop(columns=["is_train"])

return df_val, y_val, df_test, y_test
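
The two swaps from .values to .to_numpy() above are typing-driven rather than behavioral: the pandas stubs annotate .values loosely (it may be an ExtensionArray), while .to_numpy() is declared to return np.ndarray. A minimal sketch of the pattern, assuming pandas-stubs is installed:

import numpy as np
import pandas as pd

df = pd.DataFrame({"is_train": [0, 1, 1]})
y: np.ndarray = df["is_train"].to_numpy()  # checks cleanly under mypy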
29 changes: 24 additions & 5 deletions src/midst_toolkit/core/data_loaders.py
@@ -6,7 +6,9 @@
import pandas as pd


def load_multi_table(data_dir, verbose=True):
def load_multi_table(
data_dir: str, verbose: bool = True
) -> tuple[dict[str, Any], list[tuple[str, str]], dict[str, Any]]:
dataset_meta = json.load(open(os.path.join(data_dir, "dataset_meta.json"), "r"))

relation_order = dataset_meta["relation_order"]
@@ -71,7 +73,7 @@ def pipeline_process_data(
ratio: float = 0.9,
save: bool = False,
verbose: bool = True,
) -> tuple[dict[str, Any], dict[str, Any]]:
) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
# ruff: noqa: D103
num_data = data_df.shape[0]

@@ -93,15 +95,18 @@
num_train = int(num_data * ratio)
num_test = num_data - num_train

test_df: pd.DataFrame | None = None

if ratio < 1:
train_df, test_df, seed = train_val_test_split(data_df, cat_columns, num_train, num_test)
else:
train_df = data_df.copy()

train_df.columns = range(len(train_df.columns))
train_df.columns = list(range(len(train_df.columns)))

if ratio < 1:
test_df.columns = range(len(test_df.columns))
assert test_df is not None
test_df.columns = list(range(len(test_df.columns)))

col_info: dict[Any, Any] = {}

@@ -131,6 +136,7 @@

train_df.rename(columns=idx_name_mapping, inplace=True)
if ratio < 1:
assert test_df is not None
test_df.rename(columns=idx_name_mapping, inplace=True)

for col in num_columns:
@@ -139,6 +145,7 @@
train_df.loc[train_df[col] == "?", col] = "nan"

if ratio < 1:
assert test_df is not None
for col in num_columns:
test_df.loc[test_df[col] == "?", col] = np.nan
for col in cat_columns:
@@ -148,7 +155,12 @@
X_cat_train = train_df[cat_columns].to_numpy()
y_train = train_df[target_columns].to_numpy()

X_num_test: np.ndarray | None = None
X_cat_test: np.ndarray | None = None
y_test: np.ndarray | None = None

if ratio < 1:
assert test_df is not None
X_num_test = test_df[num_columns].to_numpy().astype(np.float32)
X_cat_test = test_df[cat_columns].to_numpy()
y_test = test_df[target_columns].to_numpy()
@@ -160,19 +172,22 @@
np.save(f"{save_dir}/y_train.npy", y_train)

if ratio < 1:
assert X_num_test is not None and X_cat_test is not None and y_test is not None
np.save(f"{save_dir}/X_num_test.npy", X_num_test)
np.save(f"{save_dir}/X_cat_test.npy", X_cat_test)
np.save(f"{save_dir}/y_test.npy", y_test)

train_df[num_columns] = train_df[num_columns].astype(np.float32)

if ratio < 1:
assert test_df is not None
test_df[num_columns] = test_df[num_columns].astype(np.float32)

if save:
train_df.to_csv(f"{save_dir}/train.csv", index=False)

if ratio < 1:
assert test_df is not None
test_df.to_csv(f"{save_dir}/test.csv", index=False)

if not os.path.exists(f"synthetic/{name}"):
@@ -181,12 +196,14 @@
train_df.to_csv(f"synthetic/{name}/real.csv", index=False)

if ratio < 1:
assert test_df is not None
test_df.to_csv(f"synthetic/{name}/test.csv", index=False)

info["column_names"] = column_names
info["train_num"] = train_df.shape[0]

if ratio < 1:
assert test_df is not None
info["test_num"] = test_df.shape[0]

info["idx_mapping"] = idx_mapping
@@ -227,6 +244,7 @@

if verbose:
if ratio < 1:
assert test_df is not None
str_shape = "Train dataframe shape: {}, Test dataframe shape: {}, Total dataframe shape: {}".format(
train_df.shape, test_df.shape, data_df.shape
)
@@ -251,7 +269,7 @@
# print('Num', num)
# print('Cat', cat)

data = {
data: dict[str, dict[str, Any]] = {
"df": {"train": train_df},
"numpy": {
"X_num_train": X_num_train,
@@ -261,6 +279,7 @@
}

if ratio < 1:
assert test_df is not None and X_num_test is not None and X_cat_test is not None and y_test is not None
data["df"]["test"] = test_df
data["numpy"]["X_num_test"] = X_num_test
data["numpy"]["X_cat_test"] = X_cat_test
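The repeated assert test_df is not None lines above are the standard mypy narrowing idiom: the value is only assigned when ratio < 1, but mypy cannot track that across separate branches, so each branch re-narrows pd.DataFrame | None to pd.DataFrame. A minimal sketch of the pattern, with hypothetical names:

import pandas as pd

def tail_split(df: pd.DataFrame, ratio: float) -> pd.DataFrame | None:
    test_df: pd.DataFrame | None = None
    if ratio < 1:
        test_df = df.iloc[int(len(df) * ratio) :]
    if ratio < 1:
        assert test_df is not None  # narrows the Optional for mypy
        test_df = test_df.reset_index(drop=True)
    return test_df
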
18 changes: 9 additions & 9 deletions src/midst_toolkit/core/logger.py
NOTE: I am being lazy in this file and adding a few typing ignores. Under normal circumstances, please yell at me for this. However, I know that @lotif will be removing this file and migrating its useful contents in this PR. So I'm not trying very hard here 😂

@@ -15,7 +15,7 @@
import time
import warnings
from collections import defaultdict
from collections.abc import Generator, Iterable
from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager
from typing import IO, Any

@@ -55,7 +55,7 @@ def __init__(self, filename_or_file: str | IO[str]):
self.file = filename_or_file # type: ignore[assignment]
self.own_file = False

def writekvs(self, kvs):
def writekvs(self, kvs: dict[str, Any]) -> None:
# Create strings for printing
key2str = {}
for key, val in sorted(kvs.items()):
@@ -84,7 +84,7 @@ def _truncate(self, s: str) -> str:
maxlen = 30
return s[: maxlen - 3] + "..." if len(s) > maxlen else s

def writeseq(self, seq):
def writeseq(self, seq: Iterable[str]) -> None:
seq = list(seq)
for i, elem in enumerate(seq):
self.file.write(elem)
@@ -103,7 +103,7 @@ def __init__(self, filename: str):
self.file = open(filename, "wt")
# ruff: noqa: SIM115

def writekvs(self, kvs):
def writekvs(self, kvs: dict[str, Any]) -> None:
for k, v in sorted(kvs.items()):
if hasattr(v, "dtype"):
kvs[k] = float(v)
@@ -121,7 +121,7 @@ def __init__(self, filename: str):
self.keys: list[str] = []
self.sep = ","

def writekvs(self, kvs):
def writekvs(self, kvs: dict[str, Any]) -> None:
# Add our current row to the history
extra_keys = list(kvs.keys() - self.keys)
extra_keys.sort()
@@ -297,16 +297,16 @@ def profile_kv(scopename: str) -> Generator[None, None, None]:
get_current().name2val[logkey] += time.time() - tstart


def profile(n):
def profile(n: str) -> Callable:
"""
Usage.

@profile("my_func")
def my_func(): code
"""

def decorator_with_name(func):
def func_wrapper(*args, **kwargs):
def decorator_with_name(func): # type: ignore
def func_wrapper(*args, **kwargs): # type: ignore
with profile_kv(n):
return func(*args, **kwargs)

@@ -483,7 +483,7 @@ def reset() -> None:


@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
def scoped_configure(dir=None, format_strs=None, comm=None): # type: ignore
# ruff: noqa: D103
prevlogger = Logger.CURRENT
configure(dir=dir, format_strs=format_strs, comm=comm)
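On the ignores called out in the comment above: the decorator could be fully typed with ParamSpec (Python 3.10+), which preserves the wrapped signature. A minimal, self-contained sketch; profile_kv below is a simplified stand-in for the module's context manager of the same name:

import functools
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

@contextmanager
def profile_kv(scopename: str) -> Generator[None, None, None]:
    # Stand-in: the real module accumulates elapsed time under a log key.
    start = time.time()
    try:
        yield
    finally:
        print(f"wait_{scopename}: {time.time() - start:.6f}s")

def profile(n: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
    def decorator_with_name(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def func_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            with profile_kv(n):
                return func(*args, **kwargs)
        return func_wrapper
    return decorator_with_name
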
4 changes: 2 additions & 2 deletions src/midst_toolkit/data_processing/utils.py
@@ -43,7 +43,7 @@ def __init__(
synthetic_data: pd.DataFrame,
categorical_columns: list[str] | None,
numerical_columns: list[str] | None,
holdout_data: pd.DataFrame = None,
holdout_data: pd.DataFrame | None = None,
) -> None:
"""
A class responsible for fitting encoders and scalers for categorical and numerical columns of dataframes,
@@ -214,6 +214,6 @@ def is_column_type_numerical(dataframe: pd.DataFrame, column_name: str) -> bool:
Returns:
True if the column contains numerical values. False otherwise.
"""
column_dtype = dataframe[column_name].dtype
column_dtype = dataframe[column_name].to_numpy().dtype

return np.issubdtype(column_dtype, np.integer) or np.issubdtype(column_dtype, np.floating)
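
Both utils.py changes are mypy-driven. The holdout_data default of None now carries an explicit Optional annotation, as no_implicit_optional requires, and the dtype is read from the NumPy array because a pandas extension dtype (e.g. nullable Int64) is not a NumPy dtype that np.issubdtype can accept. A minimal sketch of the Optional rule, with hypothetical function names:

import pandas as pd

def bad(holdout_data: pd.DataFrame = None) -> None:  # rejected: implicit Optional default
    ...

def good(holdout_data: pd.DataFrame | None = None) -> None:  # accepted
    ...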
@@ -120,6 +120,9 @@ def preprocess(

if real_data_test is None:
return (processed_synthetic_data, processed_real_data_train)

assert num_real_data_test_np is not None
assert cat_real_data_test_oh is not None
return (
processed_synthetic_data,
processed_real_data_train,