Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow usage of unsigned S3 client #195

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## Unreleased

### New features
* Update S3ClientConfig to pass in the configuration for allowing unsigned requests, under boolean flag `unsigned`.


## v1.2.2 (March 22, 2024)

### New features
Expand Down
6 changes: 6 additions & 0 deletions doc/DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ Using S3ClientConfig you can set up the following parameters for the underlying
(max number of parts per upload is 10,000, minimum upload part size is 5 MiB).
Part size must have **values between 5MiB and 5GiB.** Is set by default to **8MiB** (may change in future).

* `unsigned(bool)`: Allows the usage of unsigned clients when accessing public datasets or when other mechanisms are
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
* `unsigned(bool)`: Allows the usage of unsigned clients when accessing public datasets or when other mechanisms are
* `unsigned(bool)`: Set to true to disable signing S3 requests.

(we don't want to suggest "other mechanisms" unless you're really, really sure)

in place to grant access.

For example this can be passed in like:
```py
from s3torchconnector import S3MapDataset, S3ClientConfig
Expand All @@ -165,6 +168,9 @@ s3_map_dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION, s3client_c
s3_checkpoint = S3Checkpoint(region=REGION, s3client_config=config)
# Works similarly for Lightning checkpoints.
s3_lightning_checkpoint = S3LightningCheckpoint(region=REGION, s3client_config=config)

# Use an unsigned S3 client
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Use an unsigned S3 client
# Disable signing to make requests without AWS credentials

s3_client = S3Client(region=REGION, s3client_config=S3ClientConfig(unsigned=True))
```

**When modifying the default values for these flags, we strongly recommend to run benchmarking to ensure you are not
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
throughput_target_gbps=self.s3client_config.throughput_target_gbps,
part_size=self.s3client_config.part_size,
user_agent_prefix=self.user_agent_prefix,
unsigned=self.s3client_config.unsigned,
)

def add_object(self, key: str, data: bytes) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _client_builder(self) -> MountpointS3Client:
user_agent_prefix=self._user_agent_prefix,
throughput_target_gbps=self._s3client_config.throughput_target_gbps,
part_size=self._s3client_config.part_size,
unsigned=self._s3client_config.unsigned,
)

def get_object(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ class S3ClientConfig:

throughput_target_gbps: float = 10.0
part_size: int = 8 * 1024 * 1024
unsigned: bool = False
17 changes: 16 additions & 1 deletion s3torchconnector/tst/e2e/test_e2e_s3datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.utils.data.datapipes.datapipe import MapDataPipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe

from s3torchconnector import S3IterableDataset, S3MapDataset
from s3torchconnector import S3IterableDataset, S3MapDataset, S3ClientConfig


def test_s3iterable_dataset_images_10_from_prefix(image_directory):
Expand Down Expand Up @@ -100,6 +100,21 @@ def test_dataset_unpickled_iterates(image_directory):
assert expected == actual


def test_unsigned_client():
s3_uri = "s3://s3torchconnector-demo/geonet/images/"
region = "us-east-1"
s3_dataset = S3MapDataset.from_prefix(
s3_uri=s3_uri,
region=region,
transform=lambda obj: obj.read(),
s3client_config=S3ClientConfig(unsigned=True),
)
s3_dataloader = _pytorch_dataloader(s3_dataset)
assert s3_dataloader is not None
assert isinstance(s3_dataloader.dataset, S3MapDataset)
assert len(s3_dataloader) >= 1296


def _compare_dataloaders(
local_dataloader: DataLoader, s3_dataloader: DataLoader, expected_batch_count: int
):
Expand Down
9 changes: 9 additions & 0 deletions s3torchconnector/tst/unit/test_s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float):
)
assert s3_client._client.part_size == part_size
assert s3_client._client.throughput_target_gbps == throughput_target_gbps
assert s3_client._client.unsigned is False


@pytest.mark.parametrize(
Expand All @@ -130,3 +131,11 @@ def test_s3_client_invalid_part_size_config(part_size: int):
)
# The client is lazily initialized
assert s3_client._client.part_size == part_size


def test_unsigned_s3_client():
s3_client = S3Client(
region=TEST_REGION,
s3client_config=S3ClientConfig(unsigned=True),
)
assert s3_client._client.unsigned is True
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class MountpointS3Client:
region: str
part_size: int
profile: Optional[str]
no_sign_request: bool
unsigned: Optional[bool]
user_agent_prefix: str
endpoint: str

Expand All @@ -21,7 +21,7 @@ class MountpointS3Client:
throughput_target_gbps: float = 10.0,
part_size: int = 8 * 1024 * 1024,
profile: Optional[str] = None,
no_sign_request: bool = False,
unsigned: Optional[bool] = False,
endpoint: Optional[str] = None,
): ...
def get_object(self, bucket: str, key: str) -> GetObjectStream: ...
Expand All @@ -39,6 +39,7 @@ class MockMountpointS3Client:
region: str
part_size: int
user_agent_prefix: str
unsigned: bool

def __init__(
self,
Expand All @@ -48,6 +49,7 @@ class MockMountpointS3Client:
throughput_target_gbps: float = 10.0,
part_size: int = 8 * 1024 * 1024,
user_agent_prefix: str = "mock_client",
unsigned: bool = False,
): ...
def create_mocked_client(self) -> MountpointS3Client: ...
def add_object(self, key: str, data: bytes) -> None: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def test_put_object_with_storage_class():
# TODO: Add hypothesis setup after aligning on limits
def test_mountpoint_client_pickles():
expected_profile = None
expected_no_sign_request = False
expected_unsigned = False
expected_region = REGION
expected_part_size = 5 * 2**20
expected_throughput_target_gbps = 3.5
Expand All @@ -254,7 +254,7 @@ def test_mountpoint_client_pickles():
part_size=expected_part_size,
throughput_target_gbps=expected_throughput_target_gbps,
profile=expected_profile,
no_sign_request=expected_no_sign_request,
unsigned=expected_unsigned,
)
dumped = pickle.dumps(client)
loaded = pickle.loads(dumped)
Expand All @@ -271,7 +271,7 @@ def test_mountpoint_client_pickles():
== expected_throughput_target_gbps
)
assert client.profile == loaded.profile == expected_profile
assert client.no_sign_request == loaded.no_sign_request == expected_no_sign_request
assert client.unsigned == loaded.unsigned == expected_unsigned


@pytest.mark.parametrize(
Expand Down
10 changes: 7 additions & 3 deletions s3torchconnectorclient/rust/src/mock_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,21 @@ pub struct PyMockClient {
pub(crate) part_size: usize,
#[pyo3(get)]
pub(crate) user_agent_prefix: String,
#[pyo3(get)]
pub(crate) unsigned: bool,
}

#[pymethods]
impl PyMockClient {
#[new]
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string()))]
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string(), unsigned=false))]
pub fn new(
region: String,
bucket: String,
throughput_target_gbps: f64,
part_size: usize,
user_agent_prefix: String,
unsigned: bool,
) -> PyMockClient {
let unordered_list_seed: Option<u64> = None;
let config = MockClientConfig { bucket, part_size, unordered_list_seed };
Expand All @@ -48,7 +51,8 @@ impl PyMockClient {
region,
throughput_target_gbps,
part_size,
user_agent_prefix
user_agent_prefix,
unsigned
}
}

Expand All @@ -59,7 +63,7 @@ impl PyMockClient {
self.throughput_target_gbps,
self.part_size,
None,
false,
self.unsigned,
self.mock_client.clone(),
None,
)
Expand Down
21 changes: 11 additions & 10 deletions s3torchconnectorclient/rust/src/mountpoint_s3_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct MountpointS3Client {
#[pyo3(get)]
profile: Option<String>,
#[pyo3(get)]
no_sign_request: bool,
unsigned: bool,
#[pyo3(get)]
user_agent_prefix: String,
#[pyo3(get)]
Expand All @@ -53,14 +53,14 @@ pub struct MountpointS3Client {
#[pymethods]
impl MountpointS3Client {
#[new]
#[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, no_sign_request=false, endpoint=None))]
#[pyo3(signature = (region, user_agent_prefix="".to_string(), throughput_target_gbps=10.0, part_size=8*1024*1024, profile=None, unsigned=false, endpoint=None))]
pub fn new_s3_client(
region: String,
user_agent_prefix: String,
throughput_target_gbps: f64,
part_size: usize,
profile: Option<String>,
no_sign_request: bool,
unsigned: bool,
endpoint: Option<String>,
) -> PyResult<Self> {
// TODO: Mountpoint has logic for guessing based on instance type. It may be worth having
Expand All @@ -72,7 +72,7 @@ impl MountpointS3Client {
} else {
EndpointConfig::new(&region).endpoint(Uri::new_from_str(&Allocator::default(), endpoint_str).unwrap())
};
let auth_config = auth_config(profile.as_deref(), no_sign_request);
let auth_config = auth_config(profile.as_deref(), unsigned);

let user_agent_suffix =
&format!("{}/{}", build_info::PACKAGE_NAME, build_info::FULL_VERSION);
Expand All @@ -96,7 +96,7 @@ impl MountpointS3Client {
throughput_target_gbps,
part_size,
profile,
no_sign_request,
unsigned,
crt_client,
endpoint,
))
Expand Down Expand Up @@ -154,7 +154,7 @@ impl MountpointS3Client {
slf.throughput_target_gbps.to_object(py),
slf.part_size.to_object(py),
slf.profile.to_object(py),
slf.no_sign_request.to_object(py),
slf.unsigned.to_object(py),
slf.endpoint.to_object(py),
];
Ok(PyTuple::new(py, state))
Expand All @@ -169,7 +169,8 @@ impl MountpointS3Client {
throughput_target_gbps: f64,
part_size: usize,
profile: Option<String>,
no_sign_request: bool,
// no_sign_request on mountpoint-s3-client
unsigned: bool,
client: Arc<Client>,
endpoint: Option<String>,
) -> Self
Expand All @@ -183,7 +184,7 @@ impl MountpointS3Client {
part_size,
region,
profile,
no_sign_request,
unsigned,
client: Arc::new(MountpointS3ClientInnerImpl::new(client)),
user_agent_prefix,
endpoint,
Expand All @@ -192,8 +193,8 @@ impl MountpointS3Client {
}
}

fn auth_config(profile: Option<&str>, no_sign_request: bool) -> S3ClientAuthConfig {
if no_sign_request {
fn auth_config(profile: Option<&str>, unsigned: bool) -> S3ClientAuthConfig {
if unsigned {
S3ClientAuthConfig::NoSigning
} else if let Some(profile_name) = profile {
S3ClientAuthConfig::Profile(profile_name.to_string())
Expand Down
Loading