diff --git a/cartoframes/data/observatory/category.py b/cartoframes/data/observatory/category.py index 72d4d6f26..311e1916e 100644 --- a/cartoframes/data/observatory/category.py +++ b/cartoframes/data/observatory/category.py @@ -1,5 +1,6 @@ import pandas as pd +from cartoframes.exceptions import DiscoveryException from .repository.category_repo import get_category_repo from .repository.dataset_repo import get_dataset_repo @@ -22,7 +23,15 @@ def get_by_id(category_id): return get_category_repo().get_by_id(category_id) def datasets(self): - return get_dataset_repo().get_by_category(self[_CATEGORY_ID_FIELD]) + return get_dataset_repo().get_by_category(self._get_id()) + + def _get_id(self): + try: + return self[_CATEGORY_ID_FIELD] + except KeyError: + raise DiscoveryException('Unsupported function: this instance actually represents a subset of Categories ' + 'class. You should use `Categories.get_by_id("category_id")` to obtain a valid ' + 'instance of the Category class and then attempt this function on it.') def __eq__(self, other): return self.equals(other) @@ -41,6 +50,10 @@ def _constructor(self): def _constructor_sliced(self): return Category + def __init__(self, data): + super(Categories, self).__init__(data) + self.set_index(_CATEGORY_ID_FIELD, inplace=True, drop=False) + @staticmethod def get_all(): return get_category_repo().get_all() diff --git a/cartoframes/data/observatory/country.py b/cartoframes/data/observatory/country.py index 164672d17..e691ba918 100644 --- a/cartoframes/data/observatory/country.py +++ b/cartoframes/data/observatory/country.py @@ -1,5 +1,6 @@ import pandas as pd +from cartoframes.exceptions import DiscoveryException from .repository.geography_repo import get_geography_repo from .repository.country_repo import get_country_repo from .repository.dataset_repo import get_dataset_repo @@ -22,10 +23,18 @@ def get_by_id(iso_code3): return get_country_repo().get_by_id(iso_code3) def datasets(self): - return get_dataset_repo().get_by_country(self[_COUNTRY_ID_FIELD]) + return get_dataset_repo().get_by_country(self._get_id()) def geographies(self): - return get_geography_repo().get_by_country(self[_COUNTRY_ID_FIELD]) + return get_geography_repo().get_by_country(self._get_id()) + + def _get_id(self): + try: + return self[_COUNTRY_ID_FIELD] + except KeyError: + raise DiscoveryException('Unsupported function: this instance actually represents a subset of Countries ' + 'class. You should use `Countries.get_by_id("country_id")` to obtain a valid ' + 'instance of the Country class and then attempt this function on it.') def __eq__(self, other): return self.equals(other) @@ -44,6 +53,10 @@ def _constructor(self): def _constructor_sliced(self): return Country + def __init__(self, data): + super(Countries, self).__init__(data) + self.set_index(_COUNTRY_ID_FIELD, inplace=True, drop=False) + @staticmethod def get_all(): return get_country_repo().get_all() diff --git a/cartoframes/data/observatory/dataset.py b/cartoframes/data/observatory/dataset.py index d92748d76..d9757e466 100644 --- a/cartoframes/data/observatory/dataset.py +++ b/cartoframes/data/observatory/dataset.py @@ -1,5 +1,6 @@ import pandas as pd +from cartoframes.exceptions import DiscoveryException from .repository.dataset_repo import get_dataset_repo from .repository.variable_repo import get_variable_repo @@ -21,7 +22,15 @@ def get_by_id(dataset_id): return get_dataset_repo().get_by_id(dataset_id) def variables(self): - return get_variable_repo().get_by_dataset(self[_DATASET_ID_FIELD]) + return get_variable_repo().get_by_dataset(self._get_id()) + + def _get_id(self): + try: + return self[_DATASET_ID_FIELD] + except KeyError: + raise DiscoveryException('Unsupported function: this instance actually represents a subset of Datasets ' + 'class. You should use `Datasets.get_by_id("dataset_id")` to obtain a valid ' + 'instance of the Dataset class and then attempt this function on it.') def __eq__(self, other): return self.equals(other) @@ -40,6 +49,10 @@ def _constructor(self): def _constructor_sliced(self): return Dataset + def __init__(self, data): + super(Datasets, self).__init__(data) + self.set_index(_DATASET_ID_FIELD, inplace=True, drop=False) + @staticmethod def get_all(): return get_dataset_repo().get_all() diff --git a/cartoframes/data/observatory/geography.py b/cartoframes/data/observatory/geography.py index 1e54a2062..d0963736a 100644 --- a/cartoframes/data/observatory/geography.py +++ b/cartoframes/data/observatory/geography.py @@ -1,9 +1,10 @@ import pandas as pd +from cartoframes.exceptions import DiscoveryException from .repository.dataset_repo import get_dataset_repo from .repository.geography_repo import get_geography_repo -_GEOGRAPHY_FIELD_ID = 'id' +_GEOGRAPHY_ID_FIELD = 'id' class Geography(pd.Series): @@ -21,7 +22,15 @@ def get_by_id(geography_id): return get_geography_repo().get_by_id(geography_id) def datasets(self): - return get_dataset_repo().get_by_geography(self[_GEOGRAPHY_FIELD_ID]) + return get_dataset_repo().get_by_geography(self._get_id()) + + def _get_id(self): + try: + return self[_GEOGRAPHY_ID_FIELD] + except KeyError: + raise DiscoveryException('Unsupported function: this instance actually represents a subset of Geographies ' + 'class. You should use `Geographies.get_by_id("geography_id")` to obtain a valid ' + 'instance of the Geography class and then attempt this function on it.') def __eq__(self, other): return self.equals(other) @@ -40,6 +49,10 @@ def _constructor(self): def _constructor_sliced(self): return Geography + def __init__(self, data): + super(Geographies, self).__init__(data) + self.set_index(_GEOGRAPHY_ID_FIELD, inplace=True, drop=False) + @staticmethod def get_all(): return get_geography_repo().get_all() diff --git a/cartoframes/data/observatory/provider.py b/cartoframes/data/observatory/provider.py index cfcf4abf4..6c4aa6332 100644 --- a/cartoframes/data/observatory/provider.py +++ b/cartoframes/data/observatory/provider.py @@ -1,5 +1,6 @@ import pandas as pd +from cartoframes.exceptions import DiscoveryException from .repository.provider_repo import get_provider_repo from .repository.dataset_repo import get_dataset_repo @@ -22,7 +23,15 @@ def get_by_id(provider_id): return get_provider_repo().get_by_id(provider_id) def datasets(self): - return get_dataset_repo().get_by_provider(self[_PROVIDER_ID_FIELD]) + return get_dataset_repo().get_by_provider(self._get_id()) + + def _get_id(self): + try: + return self[_PROVIDER_ID_FIELD] + except KeyError: + raise DiscoveryException('Unsupported function: this instance actually represents a subset of Providers ' + 'class. You should use `Providers.get_by_id("category_id")` to obtain a valid ' + 'instance of the Provider class and then attempt this function on it.') def __eq__(self, other): return self.equals(other) @@ -41,6 +50,10 @@ def _constructor(self): def _constructor_sliced(self): return Provider + def __init__(self, data): + super(Providers, self).__init__(data) + self.set_index(_PROVIDER_ID_FIELD, inplace=True, drop=False) + @staticmethod def get_all(): return get_provider_repo().get_all() diff --git a/cartoframes/data/observatory/repository/category_repo.py b/cartoframes/data/observatory/repository/category_repo.py index 63abf6512..3282f31ea 100644 --- a/cartoframes/data/observatory/repository/category_repo.py +++ b/cartoframes/data/observatory/repository/category_repo.py @@ -31,6 +31,9 @@ def _to_category(result): @staticmethod def _to_categories(results): + if len(results) == 0: + return None + from cartoframes.data.observatory.category import Categories return Categories([CategoryRepository._to_category(result) for result in results]) diff --git a/cartoframes/data/observatory/repository/country_repo.py b/cartoframes/data/observatory/repository/country_repo.py index 8bef5b61a..f16634a50 100644 --- a/cartoframes/data/observatory/repository/country_repo.py +++ b/cartoframes/data/observatory/repository/country_repo.py @@ -31,6 +31,9 @@ def _to_country(result): @staticmethod def _to_countries(results): + if len(results) == 0: + return None + from cartoframes.data.observatory.country import Countries return Countries([CountryRepository._to_country(result) for result in results]) diff --git a/cartoframes/data/observatory/repository/dataset_repo.py b/cartoframes/data/observatory/repository/dataset_repo.py index 181c70fc1..520432428 100644 --- a/cartoframes/data/observatory/repository/dataset_repo.py +++ b/cartoframes/data/observatory/repository/dataset_repo.py @@ -46,6 +46,9 @@ def _to_dataset(result): @staticmethod def _to_datasets(results): + if len(results) == 0: + return None + from cartoframes.data.observatory.dataset import Datasets return Datasets(DatasetRepository._to_dataset(result) for result in results) diff --git a/cartoframes/data/observatory/repository/geography_repo.py b/cartoframes/data/observatory/repository/geography_repo.py index 17905fdb2..3aa2237a7 100644 --- a/cartoframes/data/observatory/repository/geography_repo.py +++ b/cartoframes/data/observatory/repository/geography_repo.py @@ -34,6 +34,9 @@ def _to_geography(result): @staticmethod def _to_geographies(results): + if len(results) == 0: + return None + from cartoframes.data.observatory.geography import Geographies return Geographies(GeographyRepository._to_geography(result) for result in results) diff --git a/cartoframes/data/observatory/repository/provider_repo.py b/cartoframes/data/observatory/repository/provider_repo.py index ac364c88d..bddee2fbd 100644 --- a/cartoframes/data/observatory/repository/provider_repo.py +++ b/cartoframes/data/observatory/repository/provider_repo.py @@ -34,6 +34,9 @@ def _to_provider(result): @staticmethod def _to_providers(results): + if len(results) == 0: + return None + from cartoframes.data.observatory.provider import Providers return Providers([ProviderRepository._to_provider(result) for result in results]) diff --git a/cartoframes/data/observatory/repository/variable_repo.py b/cartoframes/data/observatory/repository/variable_repo.py index 14e9fe348..3856e57a0 100644 --- a/cartoframes/data/observatory/repository/variable_repo.py +++ b/cartoframes/data/observatory/repository/variable_repo.py @@ -34,6 +34,9 @@ def _to_variable(result): @staticmethod def _to_variables(results): + if len(results) == 0: + return None + from cartoframes.data.observatory.variable import Variables return Variables([VariableRepository._to_variable(result) for result in results]) diff --git a/cartoframes/data/observatory/variable.py b/cartoframes/data/observatory/variable.py index e84f8d65d..782f66621 100644 --- a/cartoframes/data/observatory/variable.py +++ b/cartoframes/data/observatory/variable.py @@ -1,9 +1,10 @@ import pandas as pd +from cartoframes.exceptions import DiscoveryException from .repository.dataset_repo import get_dataset_repo from .repository.variable_repo import get_variable_repo -_VARIABLE_FIELD_ID = 'id' +_VARIABLE_ID_FIELD = 'id' class Variable(pd.Series): @@ -21,7 +22,15 @@ def get_by_id(variable_id): return get_variable_repo().get_by_id(variable_id) def datasets(self): - return get_dataset_repo().get_by_variable(self[_VARIABLE_FIELD_ID]) + return get_dataset_repo().get_by_variable(self._get_id()) + + def _get_id(self): + try: + return self[_VARIABLE_ID_FIELD] + except KeyError: + raise DiscoveryException('Unsupported function: this instance actually represents a subset of Variables ' + 'class. You should use `Variables.get_by_id("variable_id")` to obtain a valid ' + 'instance of the Variable class and then attempt this function on it.') def __eq__(self, other): return self.equals(other) @@ -40,6 +49,10 @@ def _constructor(self): def _constructor_sliced(self): return Variable + def __init__(self, data): + super(Variables, self).__init__(data) + self.set_index(_VARIABLE_ID_FIELD, inplace=True, drop=False) + @staticmethod def get_all(): return get_variable_repo().get_all() diff --git a/examples/07_catalog/discovery.ipynb b/examples/07_catalog/discovery.ipynb index e79f200d6..fcc2b9889 100644 --- a/examples/07_catalog/discovery.ipynb +++ b/examples/07_catalog/discovery.ipynb @@ -96,6 +96,22 @@ "isinstance(filtered_country, pd.Series)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use the id to access with loc, since the id corresponds to the DataFrame's index:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "countries.loc['spain']" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/test/data/observatory/repository/test_category_repo.py b/test/data/observatory/repository/test_category_repo.py index 5dc00e4ed..fcbc5ce14 100644 --- a/test/data/observatory/repository/test_category_repo.py +++ b/test/data/observatory/repository/test_category_repo.py @@ -1,7 +1,6 @@ import unittest from cartoframes.exceptions import DiscoveryException -from cartoframes.data.observatory.category import Categories from cartoframes.data.observatory.repository.category_repo import CategoryRepository from cartoframes.data.observatory.repository.repo_client import RepoClient @@ -28,6 +27,9 @@ def test_get_all(self, mocked_repo): mocked_repo.assert_called_once_with() assert categories == test_categories + id1 = db_category1['id'] + assert categories.loc[id1] == test_category1 + @patch.object(RepoClient, 'get_categories') def test_get_all_when_empty(self, mocked_repo): # Given @@ -39,7 +41,7 @@ def test_get_all_when_empty(self, mocked_repo): # Then mocked_repo.assert_called_once_with() - assert categories == Categories([]) + assert categories is None @patch.object(RepoClient, 'get_categories') def test_get_by_id(self, mocked_repo): diff --git a/test/data/observatory/repository/test_country_repo.py b/test/data/observatory/repository/test_country_repo.py index d319c2a8f..f65f670a8 100644 --- a/test/data/observatory/repository/test_country_repo.py +++ b/test/data/observatory/repository/test_country_repo.py @@ -1,7 +1,6 @@ import unittest from cartoframes.exceptions import DiscoveryException -from cartoframes.data.observatory.country import Countries from cartoframes.data.observatory.repository.country_repo import CountryRepository from cartoframes.data.observatory.repository.repo_client import RepoClient @@ -37,7 +36,7 @@ def test_get_all_when_empty(self, mocked_repo): countries = repo.get_all() # Then - assert countries == Countries([]) + assert countries is None @patch.object(RepoClient, 'get_countries') def test_get_by_id(self, mocked_repo): diff --git a/test/data/observatory/repository/test_dataset_repo.py b/test/data/observatory/repository/test_dataset_repo.py index 7fcfa23a5..b6e280e51 100644 --- a/test/data/observatory/repository/test_dataset_repo.py +++ b/test/data/observatory/repository/test_dataset_repo.py @@ -1,7 +1,6 @@ import unittest from cartoframes.exceptions import DiscoveryException -from cartoframes.data.observatory.dataset import Datasets from cartoframes.data.observatory.repository.dataset_repo import DatasetRepository from cartoframes.data.observatory.repository.repo_client import RepoClient @@ -39,7 +38,7 @@ def test_get_all_when_empty(self, mocked_repo): # Then mocked_repo.assert_called_once_with() - assert datasets == Datasets([]) + assert datasets is None @patch.object(RepoClient, 'get_datasets') def test_get_by_id(self, mocked_repo): diff --git a/test/data/observatory/repository/test_geography_repo.py b/test/data/observatory/repository/test_geography_repo.py index 0d968fd0f..4171c8431 100644 --- a/test/data/observatory/repository/test_geography_repo.py +++ b/test/data/observatory/repository/test_geography_repo.py @@ -1,7 +1,6 @@ import unittest from cartoframes.exceptions import DiscoveryException -from cartoframes.data.observatory.geography import Geographies from cartoframes.data.observatory.repository.geography_repo import GeographyRepository from cartoframes.data.observatory.repository.repo_client import RepoClient @@ -39,7 +38,7 @@ def test_get_all_when_empty(self, mocked_repo): # Then mocked_repo.assert_called_once_with() - assert geographies == Geographies([]) + assert geographies is None @patch.object(RepoClient, 'get_geographies') def test_get_by_id(self, mocked_repo): diff --git a/test/data/observatory/repository/test_provider_repo.py b/test/data/observatory/repository/test_provider_repo.py index e0de4ec97..78bd1b250 100644 --- a/test/data/observatory/repository/test_provider_repo.py +++ b/test/data/observatory/repository/test_provider_repo.py @@ -1,7 +1,6 @@ import unittest from cartoframes.exceptions import DiscoveryException -from cartoframes.data.observatory.provider import Providers from cartoframes.data.observatory.repository.provider_repo import ProviderRepository from cartoframes.data.observatory.repository.repo_client import RepoClient @@ -39,7 +38,7 @@ def test_get_all_when_empty(self, mocked_repo): # Then mocked_repo.assert_called_once_with() - assert providers == Providers([]) + assert providers is None @patch.object(RepoClient, 'get_providers') def test_get_by_id(self, mocked_repo): diff --git a/test/data/observatory/repository/test_variable_repo.py b/test/data/observatory/repository/test_variable_repo.py index 2760683eb..857cb523d 100644 --- a/test/data/observatory/repository/test_variable_repo.py +++ b/test/data/observatory/repository/test_variable_repo.py @@ -1,7 +1,6 @@ import unittest from cartoframes.exceptions import DiscoveryException -from cartoframes.data.observatory.variable import Variables from cartoframes.data.observatory.repository.variable_repo import VariableRepository from cartoframes.data.observatory.repository.repo_client import RepoClient @@ -39,7 +38,7 @@ def test_get_all_when_empty(self, mocked_repo): # Then mocked_repo.assert_called_once_with() - assert variables == Variables([]) + assert variables is None @patch.object(RepoClient, 'get_variables') def test_get_by_id(self, mocked_repo): diff --git a/test/data/observatory/test_category.py b/test/data/observatory/test_category.py index bae95bb93..748cd0f57 100644 --- a/test/data/observatory/test_category.py +++ b/test/data/observatory/test_category.py @@ -6,8 +6,9 @@ from cartoframes.data.observatory.repository.category_repo import CategoryRepository from cartoframes.data.observatory.repository.dataset_repo import DatasetRepository +from cartoframes.exceptions import DiscoveryException -from .examples import test_category1, test_datasets, test_categories +from .examples import test_category1, test_datasets, test_categories, db_category1 try: from unittest.mock import Mock, patch @@ -18,7 +19,7 @@ class TestCategory(unittest.TestCase): @patch.object(CategoryRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_category_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_category1 @@ -31,7 +32,7 @@ def test_get_by_id(self, mocked_repo): assert category == test_category1 @patch.object(DatasetRepository, 'get_by_category') - def test_get_datasets(self, mocked_repo): + def test_get_datasets_by_category(self, mocked_repo): # Given mocked_repo.return_value = test_datasets @@ -43,6 +44,14 @@ def test_get_datasets(self, mocked_repo): assert isinstance(datasets, Datasets) assert datasets == test_datasets + def test_get_datasets_by_category_fails_if_column_Series(self): + # Given + category = test_categories.id + + # Then + with self.assertRaises(DiscoveryException): + category.datasets() + class TestCategories(unittest.TestCase): @@ -60,7 +69,7 @@ def test_get_all(self, mocked_repo): assert categories == test_categories @patch.object(CategoryRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_category_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_category1 @@ -71,3 +80,29 @@ def test_get_by_id(self, mocked_repo): assert isinstance(category, pd.Series) assert isinstance(category, Category) assert category == test_category1 + + @patch.object(CategoryRepository, 'get_all') + def test_categories_are_indexed_with_id(self, mocked_repo): + # Given + mocked_repo.return_value = test_categories + category_id = db_category1['id'] + + # When + categories = Categories.get_all() + category = categories.loc[category_id] + + # Then + assert category == test_category1 + + @patch.object(CategoryRepository, 'get_all') + def test_categories_slice_is_category_and_series(self, mocked_repo): + # Given + mocked_repo.return_value = test_categories + + # When + categories = Categories.get_all() + category = categories.iloc[0] + + # Then + assert isinstance(category, Category) + assert isinstance(category, pd.Series) diff --git a/test/data/observatory/test_country.py b/test/data/observatory/test_country.py index 6d2037319..3cd5ea1cc 100644 --- a/test/data/observatory/test_country.py +++ b/test/data/observatory/test_country.py @@ -7,8 +7,9 @@ from cartoframes.data.observatory.repository.geography_repo import GeographyRepository from cartoframes.data.observatory.repository.dataset_repo import DatasetRepository from cartoframes.data.observatory.repository.country_repo import CountryRepository +from cartoframes.exceptions import DiscoveryException -from .examples import test_country1, test_datasets, test_countries, test_geographies +from .examples import test_country1, test_datasets, test_countries, test_geographies, db_country1 try: from unittest.mock import Mock, patch @@ -19,7 +20,7 @@ class TestCountry(unittest.TestCase): @patch.object(CountryRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_country_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_country1 @@ -32,7 +33,7 @@ def test_get_by_id(self, mocked_repo): assert country == test_country1 @patch.object(DatasetRepository, 'get_by_country') - def test_get_datasets(self, mocked_repo): + def test_get_datasets_by_country(self, mocked_repo): # Given mocked_repo.return_value = test_datasets @@ -44,8 +45,16 @@ def test_get_datasets(self, mocked_repo): assert isinstance(datasets, Datasets) assert datasets == test_datasets + def test_get_datasets_by_country_fails_if_column_Series(self): + # Given + country = test_countries.country_iso_code3 + + # Then + with self.assertRaises(DiscoveryException): + country.datasets() + @patch.object(GeographyRepository, 'get_by_country') - def test_get_geographies(self, mocked_repo): + def test_get_geographies_by_country(self, mocked_repo): # Given mocked_repo.return_value = test_geographies @@ -57,11 +66,19 @@ def test_get_geographies(self, mocked_repo): assert isinstance(geographies, Geographies) assert geographies == test_geographies + def test_get_geographies_by_country_fails_if_column_Series(self): + # Given + country = test_countries.country_iso_code3 + + # Then + with self.assertRaises(DiscoveryException): + country.geographies() + class TestCountries(unittest.TestCase): @patch.object(CountryRepository, 'get_all') - def test_get_all(self, mocked_repo): + def test_get_all_countries(self, mocked_repo): # Given mocked_repo.return_value = test_countries @@ -74,7 +91,7 @@ def test_get_all(self, mocked_repo): assert countries == test_countries @patch.object(CountryRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_country_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_country1 @@ -85,3 +102,29 @@ def test_get_by_id(self, mocked_repo): assert isinstance(country, pd.Series) assert isinstance(country, Country) assert country == test_country1 + + @patch.object(CountryRepository, 'get_all') + def test_countries_are_indexed_with_id(self, mocked_repo): + # Given + mocked_repo.return_value = test_countries + country_id = db_country1['country_iso_code3'] + + # When + countries = Countries.get_all() + country = countries.loc[country_id] + + # Then + assert country == test_country1 + + @patch.object(CountryRepository, 'get_all') + def test_countries_slice_is_country_and_series(self, mocked_repo): + # Given + mocked_repo.return_value = test_countries + + # When + countries = Countries.get_all() + country = countries.iloc[0] + + # Then + assert isinstance(country, Country) + assert isinstance(country, pd.Series) diff --git a/test/data/observatory/test_dataset.py b/test/data/observatory/test_dataset.py index 946a5c621..69cdef0b2 100644 --- a/test/data/observatory/test_dataset.py +++ b/test/data/observatory/test_dataset.py @@ -5,8 +5,9 @@ from cartoframes.data.observatory.dataset import Datasets, Dataset from cartoframes.data.observatory.repository.variable_repo import VariableRepository from cartoframes.data.observatory.repository.dataset_repo import DatasetRepository +from cartoframes.exceptions import DiscoveryException -from .examples import test_dataset1, test_datasets, test_variables +from .examples import test_dataset1, test_datasets, test_variables, db_dataset1 try: from unittest.mock import Mock, patch @@ -17,7 +18,7 @@ class TestDataset(unittest.TestCase): @patch.object(DatasetRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_dataset_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_dataset1 @@ -30,7 +31,7 @@ def test_get_by_id(self, mocked_repo): assert dataset == test_dataset1 @patch.object(VariableRepository, 'get_by_dataset') - def test_get_variables(self, mocked_repo): + def test_get_variables_by_dataset(self, mocked_repo): # Given mocked_repo.return_value = test_variables @@ -42,11 +43,19 @@ def test_get_variables(self, mocked_repo): assert isinstance(variables, Variables) assert variables == test_variables + def test_get_variables_by_dataset_fails_if_column_Series(self): + # Given + dataset = test_datasets.id + + # Then + with self.assertRaises(DiscoveryException): + dataset.variables() + class TestDatasets(unittest.TestCase): @patch.object(DatasetRepository, 'get_all') - def test_get_all(self, mocked_repo): + def test_get_all_datasets(self, mocked_repo): # Given mocked_repo.return_value = test_datasets @@ -58,7 +67,7 @@ def test_get_all(self, mocked_repo): assert isinstance(datasets, Datasets) @patch.object(DatasetRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_dataset_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_dataset1 @@ -69,3 +78,29 @@ def test_get_by_id(self, mocked_repo): assert isinstance(dataset, pd.Series) assert isinstance(dataset, Dataset) assert dataset == test_dataset1 + + @patch.object(DatasetRepository, 'get_all') + def test_datasets_are_indexed_with_id(self, mocked_repo): + # Given + mocked_repo.return_value = test_datasets + dataset_id = db_dataset1['id'] + + # When + datasets = Datasets.get_all() + dataset = datasets.loc[dataset_id] + + # Then + assert dataset == test_dataset1 + + @patch.object(DatasetRepository, 'get_all') + def test_datasets_slice_is_dataset_and_series(self, mocked_repo): + # Given + mocked_repo.return_value = test_datasets + + # When + datasets = Datasets.get_all() + dataset = datasets.iloc[0] + + # Then + assert isinstance(dataset, Dataset) + assert isinstance(dataset, pd.Series) diff --git a/test/data/observatory/test_geography.py b/test/data/observatory/test_geography.py index 3f6fdd685..e769dc424 100644 --- a/test/data/observatory/test_geography.py +++ b/test/data/observatory/test_geography.py @@ -1,13 +1,13 @@ import unittest import pandas as pd -from cartoframes.data.observatory.geography import Geography, Geographies +from cartoframes.data.observatory.geography import Geography, Geographies from cartoframes.data.observatory.repository.geography_repo import GeographyRepository - from cartoframes.data.observatory.dataset import Datasets from cartoframes.data.observatory.repository.dataset_repo import DatasetRepository +from cartoframes.exceptions import DiscoveryException -from .examples import test_geography1, test_geographies, test_datasets +from .examples import test_geography1, test_geographies, test_datasets, db_geography1 try: from unittest.mock import Mock, patch @@ -18,7 +18,7 @@ class TestGeography(unittest.TestCase): @patch.object(GeographyRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_geography_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_geography1 @@ -31,7 +31,7 @@ def test_get_by_id(self, mocked_repo): assert geography == test_geography1 @patch.object(DatasetRepository, 'get_by_geography') - def test_get_datasets(self, mocked_repo): + def test_get_datasets_by_geography(self, mocked_repo): # Given mocked_repo.return_value = test_datasets @@ -43,11 +43,19 @@ def test_get_datasets(self, mocked_repo): assert isinstance(datasets, Datasets) assert datasets == test_datasets + def test_get_datasets_by_geography_fails_if_column_Series(self): + # Given + geography = test_geographies.id + + # Then + with self.assertRaises(DiscoveryException): + geography.datasets() + class TestGeographies(unittest.TestCase): @patch.object(GeographyRepository, 'get_all') - def test_get_all(self, mocked_repo): + def test_get_all_geographies(self, mocked_repo): # Given mocked_repo.return_value = test_geographies @@ -60,7 +68,7 @@ def test_get_all(self, mocked_repo): assert geographies == test_geographies @patch.object(GeographyRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_geography_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_geography1 @@ -71,3 +79,29 @@ def test_get_by_id(self, mocked_repo): assert isinstance(geography, pd.Series) assert isinstance(geography, Geography) assert geography == test_geography1 + + @patch.object(GeographyRepository, 'get_all') + def test_geographies_are_indexed_with_id(self, mocked_repo): + # Given + mocked_repo.return_value = test_geographies + geography_id = db_geography1['id'] + + # When + geographies = Geographies.get_all() + geography = geographies.loc[geography_id] + + # Then + assert geography == test_geography1 + + @patch.object(GeographyRepository, 'get_all') + def test_geographies_slice_is_geography_and_series(self, mocked_repo): + # Given + mocked_repo.return_value = test_geographies + + # When + geographies = Geographies.get_all() + geography = geographies.iloc[0] + + # Then + assert isinstance(geography, Geography) + assert isinstance(geography, pd.Series) diff --git a/test/data/observatory/test_provider.py b/test/data/observatory/test_provider.py index 21d02c3ea..e54c8caaa 100644 --- a/test/data/observatory/test_provider.py +++ b/test/data/observatory/test_provider.py @@ -6,8 +6,9 @@ from cartoframes.data.observatory.repository.provider_repo import ProviderRepository from cartoframes.data.observatory.repository.dataset_repo import DatasetRepository +from cartoframes.exceptions import DiscoveryException -from .examples import test_datasets, test_provider1, test_providers +from .examples import test_datasets, test_provider1, test_providers, db_provider1 try: from unittest.mock import Mock, patch @@ -18,7 +19,7 @@ class TestProvider(unittest.TestCase): @patch.object(ProviderRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_provider_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_provider1 @@ -31,7 +32,7 @@ def test_get_by_id(self, mocked_repo): assert provider == test_provider1 @patch.object(DatasetRepository, 'get_by_provider') - def test_get_datasets(self, mocked_repo): + def test_get_datasets_by_provider(self, mocked_repo): # Given mocked_repo.return_value = test_datasets @@ -43,11 +44,19 @@ def test_get_datasets(self, mocked_repo): assert isinstance(datasets, Datasets) assert datasets == test_datasets + def test_get_datasets_by_provider_fails_if_column_Series(self): + # Given + provider = test_providers.id + + # Then + with self.assertRaises(DiscoveryException): + provider.datasets() + class TestProviders(unittest.TestCase): @patch.object(ProviderRepository, 'get_all') - def test_get_all(self, mocked_repo): + def test_get_all_providers(self, mocked_repo): # Given mocked_repo.return_value = test_providers @@ -60,7 +69,7 @@ def test_get_all(self, mocked_repo): assert providers == test_providers @patch.object(ProviderRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_provider_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_provider1 @@ -71,3 +80,29 @@ def test_get_by_id(self, mocked_repo): assert isinstance(provider, pd.Series) assert isinstance(provider, Provider) assert provider == test_provider1 + + @patch.object(ProviderRepository, 'get_all') + def test_providers_are_indexed_with_id(self, mocked_repo): + # Given + mocked_repo.return_value = test_providers + provider_id = db_provider1['id'] + + # When + providers = Providers.get_all() + provider = providers.loc[provider_id] + + # Then + assert provider == test_provider1 + + @patch.object(DatasetRepository, 'get_all') + def test_providers_slice_is_provider_and_series(self, mocked_repo): + # Given + mocked_repo.return_value = test_providers + + # When + providers = Datasets.get_all() + provider = providers.iloc[0] + + # Then + assert isinstance(provider, Provider) + assert isinstance(provider, pd.Series) diff --git a/test/data/observatory/test_variable.py b/test/data/observatory/test_variable.py index 5ed7c7b49..5d7c04160 100644 --- a/test/data/observatory/test_variable.py +++ b/test/data/observatory/test_variable.py @@ -5,8 +5,9 @@ from cartoframes.data.observatory.dataset import Datasets from cartoframes.data.observatory.repository.variable_repo import VariableRepository from cartoframes.data.observatory.repository.dataset_repo import DatasetRepository +from cartoframes.exceptions import DiscoveryException -from .examples import test_datasets, test_variable1, test_variables +from .examples import test_datasets, test_variable1, test_variables, db_variable1 try: from unittest.mock import Mock, patch @@ -17,7 +18,7 @@ class TestVariable(unittest.TestCase): @patch.object(VariableRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_variable_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_variable1 @@ -30,7 +31,7 @@ def test_get_by_id(self, mocked_repo): assert variable == test_variable1 @patch.object(DatasetRepository, 'get_by_variable') - def test_get_datasets(self, mocked_repo): + def test_get_datasets_by_variable(self, mocked_repo): # Given mocked_repo.return_value = test_datasets @@ -42,11 +43,19 @@ def test_get_datasets(self, mocked_repo): assert isinstance(datasets, Datasets) assert datasets == test_datasets + def test_get_datasets_by_variable_fails_if_column_Series(self): + # Given + variable = test_variables.id + + # Then + with self.assertRaises(DiscoveryException): + variable.datasets() + class TestVariables(unittest.TestCase): @patch.object(VariableRepository, 'get_all') - def test_get_all(self, mocked_repo): + def test_get_all_variables(self, mocked_repo): # Given mocked_repo.return_value = test_variables @@ -59,7 +68,7 @@ def test_get_all(self, mocked_repo): assert countries == test_variables @patch.object(VariableRepository, 'get_by_id') - def test_get_by_id(self, mocked_repo): + def test_get_variable_by_id(self, mocked_repo): # Given mocked_repo.return_value = test_variable1 @@ -70,3 +79,29 @@ def test_get_by_id(self, mocked_repo): assert isinstance(variable, pd.Series) assert isinstance(variable, Variable) assert variable == test_variable1 + + @patch.object(VariableRepository, 'get_all') + def test_variables_are_indexed_with_id(self, mocked_repo): + # Given + mocked_repo.return_value = test_variables + variable_id = db_variable1['id'] + + # When + variables = Variables.get_all() + variable = variables.loc[variable_id] + + # Then + assert variable == test_variable1 + + @patch.object(VariableRepository, 'get_all') + def test_variables_slice_is_variable_and_series(self, mocked_repo): + # Given + mocked_repo.return_value = test_variables + + # When + variables = Variables.get_all() + variable = variables.iloc[0] + + # Then + assert isinstance(variable, Variable) + assert isinstance(variable, pd.Series)