Skip to content

Commit

Permalink
Merge pull request #164 from MichaelAquilina/better_add_command
Browse files Browse the repository at this point in the history
Improve behaviour of the add command
  • Loading branch information
MichaelAquilina committed Oct 30, 2018
2 parents 6759c4f + 5aff416 commit eab3ad3
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 112 deletions.
74 changes: 47 additions & 27 deletions s4/commands/add_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,68 @@
class AddCommand(Command):
def run(self):
target = self.args.copy_target_credentials

all_targets = list(self.config["targets"].keys())
if target is not None and target not in all_targets:
self.logger.info('"%s" is an unknown target', target)
self.logger.info("Choices are: %s", all_targets)
return

if target is not None:
target_config = self.config["targets"][target]

aws_access_key_id = target_config["aws_access_key_id"]
aws_secret_access_key = target_config["aws_secret_access_key"]
else:
aws_access_key_id = None
aws_secret_access_key = None

self.logger.info("To add a new target, please enter the following\n")

entry = {}
local_folder = utils.get_input(
"local folder (leave blank for current folder): "
"local folder path to sync [leave blank for current folder]: "
)
if not local_folder:
local_folder = os.getcwd()

entry["local_folder"] = os.path.expanduser(local_folder)

entry["endpoint_url"] = utils.get_input("endpoint url (leave blank for AWS): ")
entry["s3_uri"] = utils.get_input("s3 uri: ")
entry["region_name"] = utils.get_input("region name: ")
endpoint_url = utils.get_input(
"endpoint url [leave blank for AWS]: ", blank=True
)
bucket = utils.get_input("S3 Bucket [required]: ", required=True)
path = utils.get_input("S3 Path: ")
region_name = utils.get_input("region name [leave blank if unknown]: ")

if target is not None:
entry["aws_access_key_id"] = self.config["targets"][target][
"aws_access_key_id"
]
entry["aws_secret_access_key"] = self.config["targets"][target][
"aws_secret_access_key"
]
else:
entry["aws_access_key_id"] = utils.get_input("AWS Access Key ID: ")
entry["aws_secret_access_key"] = utils.get_input(
"AWS Secret Access Key: ", secret=True
if aws_access_key_id is None:
aws_access_key_id = utils.get_input(
"AWS Access Key ID [required]: ", required=True
)
aws_secret_access_key = utils.get_input(
"AWS Secret Access Key [required]: ", secret=True, required=True
)

default_name = os.path.basename(entry["s3_uri"])
default_name = os.path.basename(path)
name = utils.get_input(
"Provide a name for this entry [{}]: ".format(default_name)
"Provide a name for this entry [leave blank to default to '{}']: ".format(
default_name
)
)
name = name or default_name
if not local_folder:
local_folder = os.getcwd()

if not name:
name = default_name

self.config["targets"][name] = entry
local_folder = os.path.expanduser(local_folder)
local_folder = os.path.abspath(local_folder)

self.config["targets"][name] = {
"local_folder": local_folder,
"endpoint_url": endpoint_url,
"s3_uri": "s3://{}/{}".format(bucket, path),
"region_name": region_name,
"aws_access_key_id": aws_access_key_id,
"aws_secret_access_key": aws_secret_access_key,
}
utils.set_config(self.config)

self.logger.info(
"\nTarget has been added. Start syncing with the 'sync' command"
)
self.logger.info(
"You can edit anything you have entered here using the 'edit' command"
)
27 changes: 20 additions & 7 deletions s4/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,26 @@ def to_timestamp(dt):
return (dt - epoch) / datetime.timedelta(seconds=1)


def get_input(*args, secret=False, **kwargs):
if secret:
return getpass.getpass(*args, **kwargs)
else:
value = input(*args, **kwargs)
# Normalise empty inputs to None
return value if value else None
def get_input(*args, secret=False, required=False, blank=False, **kwargs):
"""
secret: Don't show user input when they are typing.
required: Keep prompting if the user enters an empty value.
blank: turn all empty strings into None.
"""

while True:
if secret:
value = getpass.getpass(*args, **kwargs)
else:
value = input(*args, **kwargs)

if blank:
value = value if value else None

if not required or value:
break

return value


def get_config():
Expand Down
90 changes: 46 additions & 44 deletions tests/commands/test_add_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@
@mock.patch("s4.utils.get_input")
class TestAddCommand(object):
def test_correct_behaviour(self, get_input, config_file):
fake_stream = utils.FakeInputStream(
[
"/home/user/Documents",
None,
"s3://mybucket/Documents",
"eu-west-2",
"aaaaaaaaaaaaaaaaaaaaaaaa",
"bbbbbbbbbbbbbbbbbbbbbbbb",
"",
]
)
get_input.side_effect = fake_stream
get_input.side_effect = [
"/home/user/Documents",
None,
"mybucket",
"Documents",
"eu-west-2",
"aaaaaaaaaaaaaaaaaaaaaaaa",
"bbbbbbbbbbbbbbbbbbbbbbbb",
"",
]
args = argparse.Namespace(copy_target_credentials=None)

command = AddCommand(args, {"targets": {}}, utils.create_logger())
Expand All @@ -49,18 +47,16 @@ def test_correct_behaviour(self, get_input, config_file):
assert new_config == expected_config

def test_default_local_folder(self, get_input, config_file):
fake_stream = utils.FakeInputStream(
[
None,
None,
"s3://mybucket/Documents",
"eu-west-2",
"aaaaaaaaaaaaaaaaaaaaaaaa",
"bbbbbbbbbbbbbbbbbbbbbbbb",
"",
]
)
get_input.side_effect = fake_stream
get_input.side_effect = [
None,
None,
"mybucket",
"Documents",
"eu-west-2",
"aaaaaaaaaaaaaaaaaaaaaaaa",
"bbbbbbbbbbbbbbbbbbbbbbbb",
"",
]
args = argparse.Namespace(copy_target_credentials=None)

command = AddCommand(args, {"targets": {}}, utils.create_logger())
Expand All @@ -72,10 +68,14 @@ def test_default_local_folder(self, get_input, config_file):
assert config["targets"]["Documents"]["local_folder"] == os.getcwd()

def test_copy_target_credentials(self, get_input, config_file):
fake_stream = utils.FakeInputStream(
["/home/user/Animals", None, "s3://mybucket/Zoo", "us-west-2", "Beasts"]
)
get_input.side_effect = fake_stream
get_input.side_effect = [
"/home/user/Animals",
None,
"mybucket",
"Zoo",
"us-west-2",
"Beasts",
]
args = argparse.Namespace(copy_target_credentials="bar")

command = AddCommand(
Expand Down Expand Up @@ -114,10 +114,14 @@ def test_copy_target_credentials(self, get_input, config_file):
assert new_config == expected_config

def test_copy_target_credentials_bad_target(self, get_input, capsys):
fake_stream = utils.FakeInputStream(
["/home/user/Animals", "", "s3://mybucket/Zoo", "us-west-2", "Beasts"]
)
get_input.side_effect = fake_stream
get_input.side_effect = [
"/home/user/Animals",
"",
"mybucket",
"Zoo",
"us-west-2",
"Beasts",
]
args = argparse.Namespace(copy_target_credentials="Foo")

command = AddCommand(args, {"targets": {"bar": {}}}, utils.create_logger())
Expand All @@ -128,18 +132,16 @@ def test_copy_target_credentials_bad_target(self, get_input, capsys):
assert err == ('"Foo" is an unknown target\n' "Choices are: ['bar']\n")

def test_custom_target_name(self, get_input, config_file):
fake_stream = utils.FakeInputStream(
[
"/home/user/Music",
None,
"s3://mybucket/Musiccccc",
"us-west-1",
"1234567890",
"abcdefghij",
"Tunes",
]
)
get_input.side_effect = fake_stream
get_input.side_effect = [
"/home/user/Music",
None,
"mybucket",
"Musiccccc",
"us-west-1",
"1234567890",
"abcdefghij",
"Tunes",
]
args = argparse.Namespace(copy_target_credentials=None)

command = AddCommand(args, {"targets": {}}, utils.create_logger())
Expand Down
20 changes: 9 additions & 11 deletions tests/commands/test_edit_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_missing_target(self, get_input, capsys):
assert err == ('"idontexist" is an unknown target\n' "Choices are: ['foo']\n")

def test_no_changes(self, get_input, config_file):
fake_stream = utils.FakeInputStream(["", "", "", "", "", ""])
fake_stream = ["", "", "", "", "", ""]
get_input.side_effect = fake_stream

args = argparse.Namespace(target="foo")
Expand Down Expand Up @@ -68,16 +68,14 @@ def test_no_changes(self, get_input, config_file):
assert expected_config == config

def test_correct_output(self, get_input, config_file):
fake_stream = utils.FakeInputStream(
[
"/home/user/Documents",
"https://example.com",
"s3://buckets/mybackup222",
"9999999999",
"bbbbbbbbbbbbbbbbbbbbbbbb",
"eu-west-2",
]
)
fake_stream = [
"/home/user/Documents",
"https://example.com",
"s3://buckets/mybackup222",
"9999999999",
"bbbbbbbbbbbbbbbbbbbbbbbb",
"eu-west-2",
]
get_input.side_effect = fake_stream

args = argparse.Namespace(target="foo")
Expand Down
4 changes: 2 additions & 2 deletions tests/commands/test_sync_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from s4.resolution import Resolution
from s4.sync import SyncWorker

from tests.utils import FakeInputStream, create_logger, set_local_contents
from tests.utils import create_logger, set_local_contents


@mock.patch("s4.utils.get_input")
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_skip(self, get_input, s3_client, local_client):

@mock.patch("s4.commands.sync_command.show_diff")
def test_diff(self, show_diff, get_input, s3_client, local_client):
get_input.side_effect = FakeInputStream(["d", "X"])
get_input.side_effect = ["d", "X"]

action_1 = SyncState(SyncState.UPDATED, 1111, 2222)
action_2 = SyncState(SyncState.DELETED, 3333, 4444)
Expand Down
45 changes: 35 additions & 10 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,45 @@

@mock.patch("getpass.getpass")
@mock.patch("builtins.input")
def test_get_input(getpass, input_fn):
utils.get_input("give me some info", secret=False)
class TestGetInput:
def test_required(self, input_fn, getpass):
input_fn.side_effect = ["", "", "something"]

assert getpass.call_count == 1
assert input_fn.call_count == 0
result = utils.get_input("give me some info", required=True)

assert result == "something"
assert input_fn.call_count == 3
assert getpass.call_count == 0

@mock.patch("getpass.getpass")
@mock.patch("builtins.input")
def test_get_input_secret(getpass, input_fn):
utils.get_input("give me some secret info", secret=True)
def test_not_secret(self, input_fn, getpass):
input_fn.return_value = "foo"

result = utils.get_input("give me some info", secret=False)

assert result == "foo"

assert getpass.call_count == 0
assert input_fn.call_count == 1

def test_blank(self, input_fn, getpass):
input_fn.return_value = ""

result = utils.get_input("give me some info", blank=True)

assert result is None

assert getpass.call_count == 0
assert input_fn.call_count == 1

def test_secret(self, input_fn, getpass):
getpass.return_value = "bar"

result = utils.get_input("give me some secret info", secret=True)

assert result == "bar"

assert getpass.call_count == 0
assert input_fn.call_count == 1
assert getpass.call_count == 1
assert input_fn.call_count == 0


class TestGetConfigFile(object):
Expand Down
11 changes: 0 additions & 11 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,6 @@
from s4.utils import to_timestamp


class FakeInputStream(object):
def __init__(self, results):
self.results = results
self.index = 0

def __call__(self, *args, **kwargs):
output = self.results[self.index]
self.index += 1
return output


class InterruptedBytesIO(object):
"""
Test helper class that imitates a BytesIO stream. Will return a stream of 0s for
Expand Down

0 comments on commit eab3ad3

Please sign in to comment.