Skip to content

Commit

Permalink
add tests from ask
Browse files Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Mar 3, 2022
1 parent 7870aff commit 9588325
Showing 1 changed file with 113 additions and 1 deletion.
114 changes: 113 additions & 1 deletion test/test_data/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
from typing import Mapping
import warnings
from typing import List, Mapping

import numpy as np

import pandas as pd

import pytest

from sklearn.datasets import fetch_openml

from scipy.sparse import csr_matrix, spmatrix

from autoPyTorch.constants import (
BINARY,
CLASSIFICATION_TASKS,
CONTINUOUS,
MULTICLASS,
MULTICLASSMULTIOUTPUT,
CONTINUOUSMULTIOUTPUT,
TABULAR_REGRESSION,
TABULAR_CLASSIFICATION,
)
from autoPyTorch.data.utils import (
default_dataset_compression_arg,
get_dataset_compression_mapping,
megabytes,
reduce_dataset_size_if_too_large,
reduce_precision,
subsample,
validate_dataset_compression_arg
)
from autoPyTorch.utils.common import subsampler
Expand All @@ -37,6 +53,102 @@ def test_reduce_dataset_if_too_large(openmlid, as_frame, n_samples):
assert megabytes(X_converted) < megabytes(X)


@pytest.mark.parametrize("X", [np.asarray([[1, 1, 1]] * 30)])
@pytest.mark.parametrize("x_type", [list, np.ndarray, csr_matrix, pd.DataFrame])
@pytest.mark.parametrize(
    "y, task, output",
    [
        (np.asarray([0] * 15 + [1] * 15), TABULAR_CLASSIFICATION, BINARY),
        (np.asarray([0] * 10 + [1] * 10 + [2] * 10), TABULAR_CLASSIFICATION, MULTICLASS),
        (np.asarray([[1, 0, 1]] * 30), TABULAR_CLASSIFICATION, MULTICLASSMULTIOUTPUT),
        (np.asarray([1.0] * 30), TABULAR_REGRESSION, CONTINUOUS),
        (np.asarray([[1.0, 1.0, 1.0]] * 30), TABULAR_REGRESSION, CONTINUOUSMULTIOUTPUT),
    ],
)
@pytest.mark.parametrize("y_type", [list, np.ndarray, pd.DataFrame, pd.Series])
@pytest.mark.parametrize("random_state", [0])
@pytest.mark.parametrize("sample_size", [0.25, 0.5, 5, 10])
def test_subsample_validity(X, x_type, y, y_type, random_state, sample_size, task, output):
    """Check that `subsample` handles every supported container and target type.

    The 30-sample `X` and `y` fixtures are converted to `x_type` / `y_type`
    and passed to `subsample`. The test then asserts that:

    * the element dtype(s) of `X` and `y` are unchanged by subsampling,
    * `X_sampled` holds the requested number of samples, where a
      `sample_size` below 1 is a fraction of the input and any other value
      is an absolute count,
    * `y_sampled` holds exactly as many samples as `X_sampled`.

    Combinations of `pd.Series` with a multi-output target are skipped.

    (test adapted from autosklearn)
    """
    assert len(X) == len(y)  # Make sure our test data is correct

    if y_type == pd.Series and output in [
        MULTICLASSMULTIOUTPUT,
        CONTINUOUSMULTIOUTPUT,
    ]:
        # We can't have a pd.Series with multiple values as it's 1 dimensional
        pytest.skip("Can't have pd.Series as y when task is n-dimensional")

    def convert(arr, objtype):
        """Return the numpy fixture `arr` as an instance of `objtype`."""
        if objtype == np.ndarray:
            return arr
        if objtype == list:
            return arr.tolist()
        # csr_matrix, pd.DataFrame and pd.Series are built via their constructor
        return objtype(arr)

    def dtype(obj):
        """Return the element type of `obj` (per-column dtypes for a DataFrame)."""
        if isinstance(obj, list):
            # Lists carry no dtype: use the type of the first scalar,
            # looking one level deeper for a list of rows.
            first = obj[0]
            return type(first[0]) if isinstance(first, list) else type(first)
        if isinstance(obj, pd.DataFrame):
            return obj.dtypes
        return obj.dtype

    def size(obj):
        """Return the number of samples held by `obj`."""
        if isinstance(obj, spmatrix):  # spmatrix doesn't support __len__
            # A single-row matrix is measured along its second axis instead
            return obj.shape[0] if obj.shape[0] > 1 else obj.shape[1]
        return len(obj)

    def assert_same_dtype(before, after):
        """Assert that `after` has the same element dtype(s) as `before`."""
        if isinstance(before, pd.DataFrame):
            # Dataframe can have multiple types, one per column
            assert list(dtype(after)) == list(dtype(before))
        else:
            assert dtype(after) == dtype(before)

    # Convert our data to its given x_type or y_type
    X = convert(X, x_type)
    y = convert(y, y_type)

    # Subsample the data, ignoring any warnings
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        X_sampled, y_sampled = subsample(
            X,
            y=y,
            random_state=random_state,
            sample_size=sample_size,
            is_classification=task in CLASSIFICATION_TASKS,
        )

    # Check that the types of X and y remain the same after subsampling
    assert_same_dtype(X, X_sampled)
    assert_same_dtype(y, y_sampled)

    # Check the right amount of samples were taken
    if sample_size < 1:
        expected_size = int(sample_size * size(X))
    else:
        expected_size = sample_size
    assert size(X_sampled) == expected_size

    # The targets must be subsampled together with the features
    assert size(y_sampled) == size(X_sampled)


def test_validate_dataset_compression_arg():

data_compression_args = validate_dataset_compression_arg({}, 10)
Expand Down

0 comments on commit 9588325

Please sign in to comment.