diff --git a/.github/workflows/static-checking.yml b/.github/workflows/static-checking.yml index bc33d9327..56f978a50 100644 --- a/.github/workflows/static-checking.yml +++ b/.github/workflows/static-checking.yml @@ -27,10 +27,11 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install -r requirements-dev.txt + pip install -r requirements-torch.txt - name: CloudFormation Lint run: cfn-lint -t testing/cloudformation.yaml - name: Documentation Lint - run: pydocstyle awswrangler/ --add-ignore=D204 + run: pydocstyle awswrangler/ --add-ignore=D204,D403 - name: mypy check run: mypy awswrangler - name: Flake8 Lint diff --git a/.pylintrc b/.pylintrc index 132ce213a..4f41cb3fb 100644 --- a/.pylintrc +++ b/.pylintrc @@ -141,7 +141,8 @@ disable=print-statement, comprehension-escape, C0330, C0103, - W1202 + W1202, + too-few-public-methods # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/README.md b/README.md index d4a8a3cad..624ebc12c 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,7 @@ df = wr.db.read_sql_query("SELECT * FROM external_schema.my_table", con=engine) - [11 - CSV Datasets](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/11%20-%20CSV%20Datasets.ipynb) - [12 - CSV Crawler](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/12%20-%20CSV%20Crawler.ipynb) - [13 - Merging Datasets on S3](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/13%20-%20Merging%20Datasets%20on%20S3.ipynb) + - [14 - PyTorch](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/14%20-%20PyTorch.ipynb) - [**API Reference**](https://aws-data-wrangler.readthedocs.io/en/latest/api.html) - [Amazon S3](https://aws-data-wrangler.readthedocs.io/en/latest/api.html#amazon-s3) - [AWS Glue Catalog](https://aws-data-wrangler.readthedocs.io/en/latest/api.html#aws-glue-catalog) diff --git a/awswrangler/__init__.py b/awswrangler/__init__.py index ce11c7ad5..b7f931a3d 100644 --- a/awswrangler/__init__.py +++ b/awswrangler/__init__.py @@ -5,9 +5,18 @@ """ +import importlib import logging from awswrangler import athena, catalog, cloudwatch, db, emr, exceptions, s3 # noqa from awswrangler.__metadata__ import __description__, __license__, __title__, __version__ # noqa +if ( + importlib.util.find_spec("torch") + and importlib.util.find_spec("torchvision") + and importlib.util.find_spec("torchaudio") + and importlib.util.find_spec("PIL") +): # type: ignore + from awswrangler import torch # noqa + logging.getLogger("awswrangler").addHandler(logging.NullHandler()) diff --git a/awswrangler/db.py b/awswrangler/db.py index 491fe7784..5d16301ad 100644 --- a/awswrangler/db.py +++ b/awswrangler/db.py @@ -155,29 +155,15 @@ def read_sql_query( ... ) """ - if not isinstance(con, sqlalchemy.engine.Engine): # pragma: no cover - raise exceptions.InvalidConnection( - "Invalid 'con' argument, please pass a " - "SQLAlchemy Engine. 
Use wr.db.get_engine(), " - "wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()" - ) + _validate_engine(con=con) with con.connect() as _con: args = _convert_params(sql, params) cursor = _con.execute(*args) if chunksize is None: return _records2df(records=cursor.fetchall(), cols_names=cursor.keys(), index=index_col, dtype=dtype) - return _iterate_cursor(cursor=cursor, chunksize=chunksize, index=index_col, dtype=dtype) - - -def _iterate_cursor( - cursor, chunksize: int, index: Optional[Union[str, List[str]]], dtype: Optional[Dict[str, pa.DataType]] = None -) -> Iterator[pd.DataFrame]: - while True: - records = cursor.fetchmany(chunksize) - if not records: - break - df: pd.DataFrame = _records2df(records=records, cols_names=cursor.keys(), index=index, dtype=dtype) - yield df + return _iterate_cursor( + cursor=cursor, chunksize=chunksize, cols_names=cursor.keys(), index=index_col, dtype=dtype + ) def _records2df( @@ -207,6 +193,20 @@ def _records2df( return df +def _iterate_cursor( + cursor: Any, + chunksize: int, + cols_names: List[str], + index: Optional[Union[str, List[str]]], + dtype: Optional[Dict[str, pa.DataType]] = None, +) -> Iterator[pd.DataFrame]: + while True: + records = cursor.fetchmany(chunksize) + if not records: + break + yield _records2df(records=records, cols_names=cols_names, index=index, dtype=dtype) + + def _convert_params(sql: str, params: Optional[Union[List, Tuple, Dict]]) -> List[Any]: args: List[Any] = [sql] if params is not None: @@ -1087,3 +1087,12 @@ def unload_redshift_to_files( paths = [x[0].replace(" ", "") for x in _con.execute(sql).fetchall()] _logger.debug(f"paths: {paths}") return paths + + +def _validate_engine(con: sqlalchemy.engine.Engine) -> None: # pragma: no cover + if not isinstance(con, sqlalchemy.engine.Engine): + raise exceptions.InvalidConnection( + "Invalid 'con' argument, please pass a " + "SQLAlchemy Engine. Use wr.db.get_engine(), " + "wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()" + ) diff --git a/awswrangler/s3.py b/awswrangler/s3.py index f728937db..c083d52c5 100644 --- a/awswrangler/s3.py +++ b/awswrangler/s3.py @@ -111,7 +111,7 @@ def does_object_exist(path: str, boto3_session: Optional[boto3.Session] = None) raise ex # pragma: no cover -def list_objects(path: str, boto3_session: Optional[boto3.Session] = None) -> List[str]: +def list_objects(path: str, suffix: Optional[str] = None, boto3_session: Optional[boto3.Session] = None) -> List[str]: """List Amazon S3 objects from a prefix. Parameters @@ -120,6 +120,8 @@ def list_objects(path: str, boto3_session: Optional[boto3.Session] = None) -> Li S3 path (e.g. s3://bucket/prefix). boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. + suffix: str, optional + Suffix for filtering S3 keys. 
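+        A minimal usage sketch (hypothetical bucket and keys), keeping only the
+        CSV objects under a prefix:
+
+        >>> import awswrangler as wr
+        >>> wr.s3.list_objects('s3://bucket/prefix/', suffix='.csv')
+        ['s3://bucket/prefix/a.csv', 's3://bucket/prefix/b.csv']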
Returns ------- @@ -155,15 +157,16 @@ def list_objects(path: str, boto3_session: Optional[boto3.Session] = None) -> Li for content in contents: if (content is not None) and ("Key" in content): key: str = content["Key"] - paths.append(f"s3://{bucket}/{key}") + if (suffix is None) or key.endswith(suffix): + paths.append(f"s3://{bucket}/{key}") return paths -def _path2list(path: Union[str, List[str]], boto3_session: Optional[boto3.Session]) -> List[str]: +def _path2list(path: object, boto3_session: boto3.Session, suffix: str = None) -> List[str]: if isinstance(path, str): # prefix paths: List[str] = list_objects(path=path, boto3_session=boto3_session) elif isinstance(path, list): - paths = path + paths = path if suffix is None else [x for x in path if x.endswith(suffix)] else: raise exceptions.InvalidArgumentType(f"{type(path)} is not a valid path type. Please, use str or List[str].") return paths diff --git a/awswrangler/torch.py b/awswrangler/torch.py new file mode 100644 index 000000000..e7cd4518f --- /dev/null +++ b/awswrangler/torch.py @@ -0,0 +1,475 @@ +"""PyTorch Module.""" +import io +import logging +import os +import pathlib +import re +import tarfile +from collections.abc import Iterable +from io import BytesIO +from typing import Any, Callable, Iterator, List, Optional, Tuple, Union + +import boto3 # type: ignore +import numpy as np # type: ignore +import sqlalchemy # type: ignore +import torch # type: ignore +import torchaudio # type: ignore +from PIL import Image # type: ignore +from torch.utils.data.dataset import Dataset, IterableDataset # type: ignore +from torchvision.transforms.functional import to_tensor # type: ignore + +from awswrangler import _utils, db, s3 + +_logger: logging.Logger = logging.getLogger(__name__) + + +class _BaseS3Dataset: + """PyTorch Amazon S3 Map-Style Dataset.""" + + def __init__( + self, path: Union[str, List[str]], suffix: Optional[str] = None, boto3_session: Optional[boto3.Session] = None + ): + """PyTorch Map-Style S3 Dataset. + + Parameters + ---------- + path : Union[str, List[str]] + S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]). + suffix: str, optional + S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png). + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. 
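+        A construction sketch through a concrete subclass (hypothetical prefix);
+        ``suffix`` narrows the listing to matching keys only:
+
+        >>> import boto3
+        >>> import awswrangler as wr
+        >>> ds = wr.torch.ImageS3Dataset('s3://bucket/prefix/', '.png', boto3.Session())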
+ + Returns + ------- + torch.utils.data.Dataset + + """ + super().__init__() + self._session = _utils.ensure_session(session=boto3_session) + self._paths: List[str] = s3._path2list( # pylint: disable=protected-access + path=path, suffix=suffix, boto3_session=self._session + ) + + def _fetch_data(self, path: str) -> Any: + """Add parquet and csv support.""" + bucket, key = _utils.parse_path(path=path) + buff = BytesIO() + client_s3: boto3.client = _utils.client(service_name="s3", session=self._session) + client_s3.download_fileobj(Bucket=bucket, Key=key, Fileobj=buff) + buff.seek(0) + return buff + + @staticmethod + def _load_data(data: io.BytesIO, path: str) -> Any: + if path.endswith(".pt"): + data = torch.load(data) + elif path.endswith(".tar.gz") or path.endswith(".tgz"): + tarfile.open(fileobj=data) + raise NotImplementedError("Tar loader not implemented!") + # tar = tarfile.open(fileobj=data) + # for member in tar.getmembers(): + else: + raise NotImplementedError() + + return data + + +class _ListS3Dataset(_BaseS3Dataset, Dataset): + """PyTorch Amazon S3 Map-Style List Dataset.""" + + def __getitem__(self, index): + path = self._paths[index] + data = self._fetch_data(path) + return [self._data_fn(data), self._label_fn(path)] + + def __len__(self): + return len(self._paths) + + def _data_fn(self, data) -> Any: + raise NotImplementedError() + + def _label_fn(self, path: str) -> Any: + raise NotImplementedError() + + +class _S3PartitionedDataset(_ListS3Dataset): + """PyTorch Amazon S3 Map-Style Partitioned Dataset.""" + + def _label_fn(self, path: str) -> torch.Tensor: + label = int(re.findall(r"/(.*?)=(.*?)/", path)[-1][1]) + return torch.tensor([label]) # pylint: disable=not-callable + + def _data_fn(self, data) -> Any: + raise NotImplementedError() + + +# class S3FilesDataset(_BaseS3Dataset, Dataset): +# """PyTorch Amazon S3 Files Map-Style Dataset.""" +# +# def __init__( +# self, path: Union[str, List[str]], suffix: Optional[str] = None, boto3_session: Optional[boto3.Session] = None +# ): +# """PyTorch S3 Files Map-Style Dataset. +# +# Each file under Amazon S3 path would be handled as a tensor or batch of tensors. +# +# Note +# ---- +# All files will be loaded to memory since random access is needed. +# +# Parameters +# ---------- +# path : Union[str, List[str]] +# S3 prefix (e.g. s3://bucket/prefix) or +# list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]). +# boto3_session : boto3.Session(), optional +# Boto3 Session. The default boto3 session will be used if boto3_session receive None. +# +# Returns +# ------- +# torch.utils.data.Dataset +# +# """ +# super(S3FilesDataset, self).__init__(path, suffix, boto3_session) +# self._download_files() +# +# def _download_files(self) -> None: +# self._data = [] +# for path in self._paths: +# data = self._fetch_data(path) +# data = self._load_data(data, path) +# self._data.append(data) +# +# self.data = torch.cat(self._data, dim=0) +# +# def __getitem__(self, index): +# return self._data[index] +# +# def __len__(self): +# return len(self._data) + + +class LambdaS3Dataset(_ListS3Dataset): + """PyTorch Amazon S3 Lambda Map-Style Dataset.""" + + def __init__( + self, + path: Union[str, List[str]], + data_fn: Callable, + label_fn: Callable, + suffix: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, + ): + """PyTorch Amazon S3 Lambda Dataset. + + Parameters + ---------- + path : Union[str, List[str]] + S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. 
[s3://bucket/key0, s3://bucket/key1]). + data_fn: Callable + Function that receives a io.BytesIO object and returns a torch.Tensor + label_fn: Callable + Function that receives object path (str) and return a torch.Tensor + suffix: str, optional + S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png). + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + torch.utils.data.Dataset + + Examples + -------- + >>> import re + >>> import torch + >>> import awswrangler as wr + >>> ds = wr.torch.LambdaS3Dataset( + >>> 's3://bucket/path', + >>> data_fn=lambda x: torch.load(x), + >>> label_fn=lambda x: torch.Tensor(int(re.findall(r"/class=(.*?)/", x)[-1])), + >>> ) + + """ + super(LambdaS3Dataset, self).__init__(path, suffix, boto3_session) + self._data_func = data_fn + self._label_func = label_fn + + def _label_fn(self, path: str) -> torch.Tensor: + return self._label_func(path) + + def _data_fn(self, data) -> torch.Tensor: + return self._data_func(data) + + +class AudioS3Dataset(_S3PartitionedDataset): + """PyTorch S3 Audio Dataset.""" + + def __init__( + self, + path: Union[str, List[str]], + cache_dir: str = "/tmp/", + suffix: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, + ): + """PyTorch Amazon S3 Audio Dataset. + + Read individual WAV audio files stores in Amazon S3 and return + them as torch tensors. + + Note + ---- + This dataset assumes audio files are stored with the following structure: + + + :: + + bucket + ├── class=0 + │ ├── audio0.wav + │ └── audio1.wav + └── class=1 + ├── audio2.wav + └── audio3.wav + + Parameters + ---------- + path : Union[str, List[str]] + S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]). + suffix: str, optional + S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png). + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. 
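+        A sketch of how the ``class=`` partition in an object key becomes the
+        label (same regex the partitioned dataset uses; the key is hypothetical):
+
+        >>> import re
+        >>> import torch
+        >>> key = 's3://bucket/class=1/audio2.wav'
+        >>> torch.tensor([int(re.findall(r"/(.*?)=(.*?)/", key)[-1][1])])
+        tensor([1])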
+ + Returns + ------- + torch.utils.data.Dataset + + Examples + -------- + Create a Audio S3 Dataset + + >>> import awswrangler as wr + >>> ds = wr.torch.AudioS3Dataset('s3://bucket/path') + + + Training a Model + + >>> criterion = CrossEntropyLoss().to(device) + >>> opt = SGD(model.parameters(), 0.025) + >>> loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + >>> + >>> for epoch in range(epochs): + >>> + >>> correct = 0 + >>> model.train() + >>> for i, (inputs, labels) in enumerate(loader): + >>> + >>> # Forward Pass + >>> outputs = model(inputs) + >>> + >>> # Backward Pass + >>> loss = criterion(outputs, labels) + >>> loss.backward() + >>> opt.step() + >>> opt.zero_grad() + >>> + >>> # Accuracy + >>> _, predicted = torch.max(outputs.data, 1) + >>> correct += (predicted == labels).sum().item() + >>> accuracy = 100 * correct / ((i+1) * batch_size) + >>> print(f'batch: {i} loss: {loss.mean().item():.4f} acc: {accuracy:.2f}') + + """ + super(AudioS3Dataset, self).__init__(path, suffix, boto3_session) + self._cache_dir: str = cache_dir[:-1] if cache_dir.endswith("/") else cache_dir + + def _data_fn(self, filename: str) -> Tuple[Any, Any]: # pylint: disable=arguments-differ + waveform, sample_rate = torchaudio.load(filename) + os.remove(path=filename) + return waveform, sample_rate + + def _fetch_data(self, path: str) -> str: + bucket, key = _utils.parse_path(path=path) + filename: str = f"{self._cache_dir}/{bucket}/{key}" + pathlib.Path(filename).parent.mkdir(parents=True, exist_ok=True) + client_s3 = _utils.client(service_name="s3", session=self._session) + client_s3.download_file(Bucket=bucket, Key=key, Filename=filename) + return filename + + +class ImageS3Dataset(_S3PartitionedDataset): + """PyTorch Amazon S3 Image Dataset.""" + + def __init__(self, path: Union[str, List[str]], suffix: str, boto3_session: boto3.Session): + """PyTorch Amazon S3 Image Dataset. + + ImageS3Dataset assumes images are patitioned (within class= folders) in Amazon S3. + Each lisited object will be loaded by default Pillow library. + + Note + ---- + Assumes Images are stored with the following structure: + + + :: + + bucket + ├── class=0 + │ ├── img0.jpeg + │ └── img1.jpeg + └── class=1 + ├── img2.jpeg + └── img3.jpeg + + Parameters + ---------- + path : Union[str, List[str]] + S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]). + suffix: str, optional + S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png). + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + torch.utils.data.Dataset + + Examples + -------- + >>> import awswrangler as wr + >>> ds = wr.torch.ImageS3Dataset('s3://bucket/path') + + """ + super(ImageS3Dataset, self).__init__(path, suffix, boto3_session) + + def _data_fn(self, data: io.BytesIO) -> Any: + image = Image.open(data) + tensor = to_tensor(image) + return tensor + + +class S3IterableDataset(IterableDataset, _BaseS3Dataset): # pylint: disable=abstract-method + """PyTorch Amazon S3 Iterable Dataset. + + Parameters + ---------- + path : Union[str, List[str]] + S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]). + suffix: str, optional + S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png). + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. 
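+    A minimal sketch of writing a ``.pt`` shard this dataset can consume
+    (bucket and key are hypothetical); saving a ``(data, labels)`` tuple of
+    tensors instead makes iteration yield ``(data, label)`` pairs:
+
+    >>> import io
+    >>> import boto3
+    >>> import torch
+    >>> buff = io.BytesIO()
+    >>> torch.save(torch.randn(100, 3, 32, 32), buff)
+    >>> _ = buff.seek(0)
+    >>> boto3.client('s3').put_object(Body=buff.read(), Bucket='bucket', Key='prefix/file0.pt')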
+ + Returns + ------- + torch.utils.data.Dataset + + Examples + -------- + >>> import awswrangler as wr + >>> ds = wr.torch.S3IterableDataset('s3://bucket/path') + + """ + + def __iter__(self) -> Union[Iterator[torch.Tensor], Iterator[Tuple[torch.Tensor, torch.Tensor]]]: + """Iterate over data returning tensors or expanding Iterables.""" + for path in self._paths: + data = self._fetch_data(path) + data = self._load_data(data, path) + + if isinstance(data, torch.Tensor): + pass + elif isinstance(data, Iterable) and all([isinstance(d, torch.Tensor) for d in data]): + data = zip(*data) + else: + raise NotImplementedError(f"ERROR: Type: {type(data)} has not been implemented!") + + for d in data: + yield d + + +class SQLDataset(IterableDataset): # pylint: disable=too-few-public-methods,abstract-method + """Pytorch Iterable SQL Dataset.""" + + def __init__( + self, + sql: str, + con: sqlalchemy.engine.Engine, + label_col: Optional[Union[int, str]] = None, + chunksize: Optional[int] = None, + ): + """Pytorch Iterable SQL Dataset. + + Support for **Redshift**, **PostgreSQL** and **MySQL**. + + Parameters + ---------- + sql : str + Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html + con : sqlalchemy.engine.Engine + SQLAlchemy Engine. Please use, + wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine() + label_col : int, optional + Label column number. + chunksize : int, optional + The chunksize determines que number of rows to be retrived from the database at each time. + + Returns + ------- + torch.utils.data.dataset.IterableDataset + + Examples + -------- + >>> import awswrangler as wr + >>> con = wr.catalog.get_engine("aws-data-wrangler-postgresql") + >>> ds = wr.torch.SQLDataset('select * from public.tutorial', con=con) + + """ + super().__init__() + self._sql = sql + self._con = con + self._label_col = label_col + self._chunksize = chunksize + + def __iter__(self) -> Union[Iterator[torch.Tensor], Iterator[Tuple[torch.Tensor, torch.Tensor]]]: + """Iterate over the Dataset.""" + if torch.utils.data.get_worker_info() is not None: # type: ignore + raise NotImplementedError() + db._validate_engine(con=self._con) # pylint: disable=protected-access + with self._con.connect() as con: + cursor: Any = con.execute(self._sql) + if (self._label_col is not None) and isinstance(self._label_col, str): + label_col: Optional[int] = list(cursor.keys()).index(self._label_col) + else: + label_col = self._label_col + _logger.debug("label_col: %s", label_col) + if self._chunksize is None: + return SQLDataset._records2tensor(records=cursor.fetchall(), label_col=label_col) + return self._iterate_cursor(cursor=cursor, chunksize=self._chunksize, label_col=label_col) + + @staticmethod + def _iterate_cursor( + cursor: Any, chunksize: int, label_col: Optional[int] = None + ) -> Union[Iterator[torch.Tensor], Iterator[Tuple[torch.Tensor, torch.Tensor]]]: + while True: + records = cursor.fetchmany(chunksize) + if not records: + break + yield from SQLDataset._records2tensor(records=records, label_col=label_col) + + @staticmethod + def _records2tensor( + records: List[Tuple[Any]], label_col: Optional[int] = None + ) -> Union[Iterator[torch.Tensor], Iterator[Tuple[torch.Tensor, torch.Tensor]]]: # pylint: disable=unused-argument + for row in records: + if label_col is None: + arr_data: np.ndarray = np.array(row, dtype=np.float) + yield torch.as_tensor(arr_data, dtype=torch.float) # pylint: disable=no-member + else: + arr_data = np.array(row[:label_col] + 
row[label_col + 1 :], dtype=np.float) # noqa: E203 + arr_label: np.ndarray = np.array(row[label_col], dtype=np.long) + ts_data: torch.Tensor = torch.as_tensor(arr_data, dtype=torch.float) # pylint: disable=no-member + ts_label: torch.Tensor = torch.as_tensor(arr_label, dtype=torch.long) # pylint: disable=no-member + yield ts_data, ts_label diff --git a/building/build-docs.sh b/building/build-docs.sh index c32c20aa0..8c807b485 100755 --- a/building/build-docs.sh +++ b/building/build-docs.sh @@ -4,4 +4,4 @@ set -ex pushd .. rm -rf docs/build docs/source/stubs make -C docs/ html -doc8 --ignore D005 docs/source +doc8 --ignore D005,D002 docs/source diff --git a/docs/source/api.rst b/docs/source/api.rst index 897fc7a3e..aea8bbed6 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -3,6 +3,19 @@ API Reference ============= +PyTorch +------- + +.. currentmodule:: awswrangler.torch + +.. autosummary:: + :toctree: stubs + + AudioS3Dataset + ImageS3Dataset + S3IterableDataset + SQLDataset + Amazon S3 --------- diff --git a/requirements-torch.txt b/requirements-torch.txt new file mode 100644 index 000000000..73b8aae36 --- /dev/null +++ b/requirements-torch.txt @@ -0,0 +1,4 @@ +torch~=1.4.0 +torchvision~=0.5.0 +torchaudio~=0.4.0 +Pillow~=7.1.1 diff --git a/setup-dev-env.sh b/setup-dev-env.sh index 692724ee0..c9c2e9902 100755 --- a/setup-dev-env.sh +++ b/setup-dev-env.sh @@ -3,5 +3,4 @@ set -ex pip install --upgrade pip pip install -r requirements-dev.txt -pip install -r requirements.txt -pip install -e . +pip install -e ".[torch]" diff --git a/setup.py b/setup.py index b363e6e58..f9c861a60 100644 --- a/setup.py +++ b/setup.py @@ -23,4 +23,7 @@ packages=find_packages(include=["awswrangler", "awswrangler.*"], exclude=["tests"]), python_requires=">=3.6, <3.9", install_requires=[open("requirements.txt").read().strip().split("\n")], + extras_require={ + "torch": open("requirements-torch.txt").read().strip().split("\n") + } ) diff --git a/testing/run-validations.sh b/testing/run-validations.sh index 966038ec9..d32fc7808 100755 --- a/testing/run-validations.sh +++ b/testing/run-validations.sh @@ -9,7 +9,7 @@ mv temp.yaml cloudformation.yaml pushd .. 
black --line-length 120 --target-version py36 awswrangler testing/test_awswrangler isort -rc --line-width 120 awswrangler testing/test_awswrangler -pydocstyle awswrangler/ --add-ignore=D204 +pydocstyle awswrangler/ --add-ignore=D204,D403 mypy awswrangler flake8 setup.py awswrangler testing/test_awswrangler pylint -j 0 awswrangler diff --git a/testing/test_awswrangler/test_torch.py b/testing/test_awswrangler/test_torch.py new file mode 100644 index 000000000..19a300400 --- /dev/null +++ b/testing/test_awswrangler/test_torch.py @@ -0,0 +1,273 @@ +import io +import logging +import re + +import boto3 +import numpy as np +import pandas as pd +import pytest +import torch +import torchaudio +from PIL import Image +from torch.utils.data import DataLoader +from torchvision.transforms.functional import to_tensor + +import awswrangler as wr + +logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") +logging.getLogger("awswrangler").setLevel(logging.DEBUG) +logging.getLogger("botocore.credentials").setLevel(logging.CRITICAL) + + +@pytest.fixture(scope="module") +def cloudformation_outputs(): + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") + outputs = {} + for output in response.get("Stacks")[0].get("Outputs"): + outputs[output.get("OutputKey")] = output.get("OutputValue") + yield outputs + + +@pytest.fixture(scope="module") +def bucket(cloudformation_outputs): + if "BucketName" in cloudformation_outputs: + bucket = cloudformation_outputs["BucketName"] + else: + raise Exception("You must deploy/update the test infrastructure (CloudFormation)") + yield bucket + + +@pytest.fixture(scope="module") +def parameters(cloudformation_outputs): + parameters = dict(postgresql={}, mysql={}, redshift={}) + parameters["postgresql"]["host"] = cloudformation_outputs["PostgresqlAddress"] + parameters["postgresql"]["port"] = 3306 + parameters["postgresql"]["schema"] = "public" + parameters["postgresql"]["database"] = "postgres" + parameters["mysql"]["host"] = cloudformation_outputs["MysqlAddress"] + parameters["mysql"]["port"] = 3306 + parameters["mysql"]["schema"] = "test" + parameters["mysql"]["database"] = "test" + parameters["redshift"]["host"] = cloudformation_outputs["RedshiftAddress"] + parameters["redshift"]["port"] = cloudformation_outputs["RedshiftPort"] + parameters["redshift"]["identifier"] = cloudformation_outputs["RedshiftIdentifier"] + parameters["redshift"]["schema"] = "public" + parameters["redshift"]["database"] = "test" + parameters["redshift"]["role"] = cloudformation_outputs["RedshiftRole"] + parameters["password"] = cloudformation_outputs["DatabasesPassword"] + parameters["user"] = "test" + yield parameters + + +@pytest.mark.parametrize("chunksize", [None, 1, 10]) +@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"]) +def test_torch_sql(parameters, db_type, chunksize): + schema = parameters[db_type]["schema"] + table = f"test_torch_sql_{db_type}_{str(chunksize).lower()}" + engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-{db_type}") + wr.db.to_sql( + df=pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]}), + con=engine, + name=table, + schema=schema, + if_exists="replace", + index=False, + index_label=None, + chunksize=None, + method=None, + ) + ds = list(wr.torch.SQLDataset(f"SELECT * FROM {schema}.{table}", con=engine, chunksize=chunksize)) + assert torch.all(ds[0].eq(torch.tensor([1.0, 4.0]))) + assert torch.all(ds[1].eq(torch.tensor([2.0, 5.0]))) + 
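+    # Third row follows the same pattern: with no label_col, both columns of the
+    # row are packed into a single float tensor.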
assert torch.all(ds[2].eq(torch.tensor([3.0, 6.0]))) + + +@pytest.mark.parametrize("chunksize", [None, 1, 10]) +@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"]) +def test_torch_sql_label(parameters, db_type, chunksize): + schema = parameters[db_type]["schema"] + table = f"test_torch_sql_label_{db_type}_{str(chunksize).lower()}" + engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-{db_type}") + wr.db.to_sql( + df=pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0], "c": [7, 8, 9]}), + con=engine, + name=table, + schema=schema, + if_exists="replace", + index=False, + index_label=None, + chunksize=None, + method=None, + ) + ts = list(wr.torch.SQLDataset(f"SELECT * FROM {schema}.{table}", con=engine, chunksize=chunksize, label_col=2)) + assert torch.all(ts[0][0].eq(torch.tensor([1.0, 4.0]))) + assert torch.all(ts[0][1].eq(torch.tensor([7], dtype=torch.long))) + assert torch.all(ts[1][0].eq(torch.tensor([2.0, 5.0]))) + assert torch.all(ts[1][1].eq(torch.tensor([8], dtype=torch.long))) + assert torch.all(ts[2][0].eq(torch.tensor([3.0, 6.0]))) + assert torch.all(ts[2][1].eq(torch.tensor([9], dtype=torch.long))) + + +def test_torch_image_s3(bucket): + folder = "test_torch_image_s3" + path = f"s3://{bucket}/{folder}/" + wr.s3.delete_objects(path=path, boto3_session=boto3.Session()) + s3 = boto3.client("s3") + ref_label = 0 + s3.put_object( + Body=open("docs/source/_static/logo.png", "rb").read(), + Bucket=bucket, + Key=f"{folder}/class={ref_label}/logo.png", + ContentType="image/png", + ) + ds = wr.torch.ImageS3Dataset(path=path, suffix="png", boto3_session=boto3.Session()) + image, label = ds[0] + assert image.shape == torch.Size([4, 494, 1636]) + assert label == torch.tensor(ref_label, dtype=torch.int) + wr.s3.delete_objects(path=path) + + +@pytest.mark.parametrize("drop_last", [True, False]) +def test_torch_image_s3_loader(bucket, drop_last): + folder = f"test_torch_image_s3_loader_{str(drop_last).lower()}" + path = f"s3://{bucket}/{folder}/" + wr.s3.delete_objects(path=path) + client_s3 = boto3.client("s3") + labels = np.random.randint(0, 4, size=(8,)) + for i, label in enumerate(labels): + client_s3.put_object( + Body=open("./docs/source/_static/logo.png", "rb").read(), + Bucket=bucket, + Key=f"{folder}/class={label}/logo{i}.png", + ContentType="image/png", + ) + ds = wr.torch.ImageS3Dataset(path=path, suffix="png", boto3_session=boto3.Session()) + batch_size = 2 + num_train = len(ds) + indices = list(range(num_train)) + loader = DataLoader( + ds, + batch_size=batch_size, + num_workers=4, + sampler=torch.utils.data.sampler.RandomSampler(indices), + drop_last=drop_last, + ) + for i, (image, label) in enumerate(loader): + assert image.shape == torch.Size([batch_size, 4, 494, 1636]) + assert label.dtype == torch.int64 + wr.s3.delete_objects(path=path) + + +def test_torch_lambda_s3(bucket): + path = f"s3://{bucket}/test_torch_lambda_s3/" + wr.s3.delete_objects(path=path) + s3 = boto3.client("s3") + ref_label = 0 + s3.put_object( + Body=open("./docs/source/_static/logo.png", "rb").read(), + Bucket=bucket, + Key=f"test_torch_lambda_s3/class={ref_label}/logo.png", + ContentType="image/png", + ) + ds = wr.torch.LambdaS3Dataset( + path=path, + suffix="png", + boto3_session=boto3.Session(), + data_fn=lambda x: to_tensor(Image.open(x)), + label_fn=lambda x: int(re.findall(r"/class=(.*?)/", x)[-1]), + ) + image, label = ds[0] + assert image.shape == torch.Size([4, 494, 1636]) + assert label == torch.tensor(ref_label, dtype=torch.int) + 
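+    # label_fn decoded the "class=0" partition back out of the object key, so the
+    # label must match the value used at upload time; objects are cleaned up below.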
wr.s3.delete_objects(path=path) + + +def test_torch_audio_s3(bucket): + size = (1, 8_000 * 5) + audio = torch.randint(low=-25, high=25, size=size) / 100.0 + audio_file = "/tmp/amazing_sound.wav" + torchaudio.save(audio_file, audio, 8_000) + folder = "test_torch_audio_s3" + path = f"s3://{bucket}/{folder}/" + wr.s3.delete_objects(path=path) + s3 = boto3.client("s3") + ref_label = 0 + s3.put_object( + Body=open(audio_file, "rb").read(), + Bucket=bucket, + Key=f"{folder}/class={ref_label}/amazing_sound.wav", + ContentType="audio/wav", + ) + s3_audio_file = f"{bucket}/test_torch_audio_s3/class={ref_label}/amazing_sound.wav" + ds = wr.torch.AudioS3Dataset(path=s3_audio_file, suffix="wav") + loader = DataLoader(ds, batch_size=1) + for (audio, rate), label in loader: + assert audio.shape == torch.Size((1, *size)) + wr.s3.delete_objects(path=path) + + +# def test_torch_s3_file_dataset(bucket): +# cifar10 = "s3://fast-ai-imageclas/cifar10.tgz" +# batch_size = 64 +# for image, label in DataLoader( +# wr.torch.S3FilesDataset(cifar10), +# batch_size=batch_size, +# ): +# assert image.shape == torch.Size([batch_size, 3, 32, 32]) +# assert label.dtype == torch.int64 +# break + + +@pytest.mark.parametrize("drop_last", [True, False]) +def test_torch_s3_iterable(bucket, drop_last): + folder = f"test_torch_s3_iterable_{str(drop_last).lower()}" + path = f"s3://{bucket}/{folder}/" + wr.s3.delete_objects(path=path) + batch_size = 32 + client_s3 = boto3.client("s3") + for i in range(3): + batch = torch.randn(100, 3, 32, 32) + buff = io.BytesIO() + torch.save(batch, buff) + buff.seek(0) + client_s3.put_object(Body=buff.read(), Bucket=bucket, Key=f"{folder}/file{i}.pt") + + for image in DataLoader( + wr.torch.S3IterableDataset(path=f"s3://{bucket}/{folder}/file"), batch_size=batch_size, drop_last=drop_last + ): + if drop_last: + assert image.shape == torch.Size([batch_size, 3, 32, 32]) + else: + assert image[0].shape == torch.Size([3, 32, 32]) + + wr.s3.delete_objects(path=path) + + +@pytest.mark.parametrize("drop_last", [True, False]) +def test_torch_s3_iterable_with_labels(bucket, drop_last): + folder = f"test_torch_s3_iterable_with_labels_{str(drop_last).lower()}" + path = f"s3://{bucket}/{folder}/" + wr.s3.delete_objects(path=path) + batch_size = 32 + client_s3 = boto3.client("s3") + for i in range(3): + batch = (torch.randn(100, 3, 32, 32), torch.randint(2, size=(100,))) + buff = io.BytesIO() + torch.save(batch, buff) + buff.seek(0) + client_s3.put_object(Body=buff.read(), Bucket=bucket, Key=f"{folder}/file{i}.pt") + + for images, labels in DataLoader( + wr.torch.S3IterableDataset(path=f"s3://{bucket}/{folder}/file"), batch_size=batch_size, drop_last=drop_last + ): + if drop_last: + assert images.shape == torch.Size([batch_size, 3, 32, 32]) + assert labels.dtype == torch.int64 + assert labels.shape == torch.Size([batch_size]) + + else: + assert images[0].shape == torch.Size([3, 32, 32]) + assert labels[0].dtype == torch.int64 + assert labels[0].shape == torch.Size([]) + + wr.s3.delete_objects(path=path) diff --git a/tutorials/14 - PyTorch.ipynb b/tutorials/14 - PyTorch.ipynb new file mode 100644 index 000000000..b7af04627 --- /dev/null +++ b/tutorials/14 - PyTorch.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![AWS Data Wrangler](_static/logo.png \"AWS Data Wrangler\")](https://github.com/awslabs/aws-data-wrangler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PyTorch" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "## Table of Contents\n", + "* [1.Defining Training Function](#1.-Defining-Training-Function)\n", + "* [2.Training From Amazon S3](#2.-Traoning-From-Amazon-S3)\n", + "\t* [2.1 Writing PyTorch Dataset to S3](#2.1-Writing-PyTorch-Dataset-to-S3)\n", + "\t* [2.2 Training Network](#2.2-Training-Network)\n", + "* [3. Training From SQL Query](#3.-Training-From-SQL-Query)\n", + "\t* [3.1 Writing Data to SQL Database](#3.1-Writing-Data-to-SQL-Database)\n", + "\t* [3.3 Training Network From SQL](#3.3-Reading-single-JSON-file)\n", + "* [4. Creating Custom S3 Dataset](#4.-Creating-Custom-S3-Dataset)\n", + "\t* [4.1 Creating Custom PyTorch Dataset](#4.1-Creating-Custom-PyTorch-Dataset)\n", + "\t* [4.2 Writing Data to S3](#4.2-Writing-Data-to-S3)\n", + "\t* [4.3 Training Network](#4.4-Training-Network)\n", + "* [5. Delete objects](#5.-Delete-objects)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import io\n", + "\n", + "import boto3\n", + "import torch\n", + "import torchvision\n", + "import pandas as pd\n", + "import awswrangler as wr\n", + "\n", + "from torch.optim import SGD\n", + "from torch.nn import CrossEntropyLoss\n", + "from torch.utils.data import DataLoader" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "········\n" + ] + } + ], + "source": [ + "import getpass\n", + "bucket = getpass.getpass()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1. Defining Training Function" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def train(model, dataset, batch_size=64, epochs=2, device='cpu', num_workers=1):\n", + "\n", + " criterion = CrossEntropyLoss().to(device)\n", + " opt = SGD(model.parameters(), 0.025)\n", + " loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n", + "\n", + " for epoch in range(epochs):\n", + "\n", + " correct = 0 \n", + " model.train()\n", + " for i, (inputs, labels) in enumerate(loader):\n", + "\n", + " # Forward Pass\n", + " outputs = model(inputs)\n", + " \n", + " # Backward Pass\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " opt.step()\n", + " opt.zero_grad()\n", + " \n", + " # Accuracy\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " correct += (predicted == labels).sum().item()\n", + " accuracy = 100 * correct / ((i+1) * batch_size)\n", + "\n", + " print(f'batch: {i} loss: {loss.mean().item():.4f} acc: {accuracy:.2f}') " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. 
Training From Amazon S3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.1 Writing PyTorch Dataset to S3" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "client_s3 = boto3.client(\"s3\")\n", + "folder = \"tutorial_torch_dataset\"\n", + "\n", + "wr.s3.delete_objects(f\"s3://{bucket}/{folder}\")\n", + "for i in range(3):\n", + " batch = (\n", + " torch.randn(100, 3, 32, 32),\n", + " torch.randint(2, size=(100,)),\n", + " )\n", + " buff = io.BytesIO()\n", + " torch.save(batch, buff)\n", + " buff.seek(0)\n", + " client_s3.put_object(\n", + " Body=buff.read(),\n", + " Bucket=bucket,\n", + " Key=f\"{folder}/file{i}.pt\",\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.2 Training Network" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "batch: 0 loss: 7.0132 acc: 0.00\n", + "batch: 1 loss: 2.8764 acc: 21.09\n", + "batch: 2 loss: 0.9600 acc: 32.29\n", + "batch: 3 loss: 0.8676 acc: 36.33\n", + "batch: 4 loss: 1.1386 acc: 36.88\n", + "batch: 0 loss: 1.0754 acc: 51.56\n", + "batch: 1 loss: 1.4241 acc: 51.56\n", + "batch: 2 loss: 1.3019 acc: 51.04\n", + "batch: 3 loss: 0.8631 acc: 53.52\n", + "batch: 4 loss: 0.4252 acc: 54.38\n" + ] + } + ], + "source": [ + "train(\n", + " torchvision.models.resnet18(),\n", + " wr.torch.S3IterableDataset(path=f\"{bucket}/{folder}\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. Training Directly From SQL Query" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.1 Writing Data to SQL Database" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "eng = wr.catalog.get_engine(\"aws-data-wrangler-redshift\")\n", + "df = pd.DataFrame({\n", + " \"height\": [2, 1.4, 1.7, 1.8, 1.9, 2.2],\n", + " \"weight\": [100.0, 50.0, 70.0, 80.0, 90.0, 160.0],\n", + " \"target\": [1, 0, 0, 1, 1, 1]\n", + "})\n", + "\n", + "wr.db.to_sql(\n", + " df,\n", + " eng,\n", + " schema=\"public\",\n", + " name=\"torch\",\n", + " if_exists=\"replace\",\n", + " index=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.2 Training Network From SQL" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "batch: 0 loss: 8.8708 acc: 50.00\n", + "batch: 1 loss: 88.7789 acc: 50.00\n", + "batch: 2 loss: 0.8655 acc: 33.33\n", + "batch: 0 loss: 0.7036 acc: 50.00\n", + "batch: 1 loss: 0.7034 acc: 50.00\n", + "batch: 2 loss: 0.8447 acc: 33.33\n", + "batch: 0 loss: 0.7012 acc: 50.00\n", + "batch: 1 loss: 0.7010 acc: 50.00\n", + "batch: 2 loss: 0.8250 acc: 33.33\n", + "batch: 0 loss: 0.6992 acc: 50.00\n", + "batch: 1 loss: 0.6991 acc: 50.00\n", + "batch: 2 loss: 0.8063 acc: 33.33\n", + "batch: 0 loss: 0.6975 acc: 50.00\n", + "batch: 1 loss: 0.6974 acc: 50.00\n", + "batch: 2 loss: 0.7886 acc: 33.33\n" + ] + } + ], + "source": [ + "train(\n", + " torch.nn.Sequential(\n", + " torch.nn.Linear(2, 10),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(10, 2), \n", + " ),\n", + " wr.torch.SQLDataset(\n", + " sql=\"SELECT * FROM public.torch\",\n", + " con=eng,\n", + " label_col=\"target\",\n", + " chunksize=2\n", + " ),\n", + " num_workers=0,\n", + " 
batch_size=2,\n", + " epochs=5\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. Delete Objects" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "wr.s3.delete_objects(f\"s3://{bucket}/{folder}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file
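For quick reference, a condensed, hedged sketch of the S3 training flow covered in the tutorial above (the bucket name and prefix are hypothetical; the ".pt" shards are assumed to hold (images, labels) tensor tuples as written in section 2.1):

import torch
import torchvision
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader

import awswrangler as wr

# Iterate ".pt" shards stored under the prefix and train a small classifier.
dataset = wr.torch.S3IterableDataset(path="s3://my-bucket/tutorial_torch_dataset")
model = torchvision.models.resnet18()
criterion = CrossEntropyLoss()
opt = SGD(model.parameters(), 0.025)

for inputs, labels in DataLoader(dataset, batch_size=64):
    loss = criterion(model(inputs), labels)  # forward pass
    loss.backward()                          # backward pass
    opt.step()
    opt.zero_grad()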