Skip to content

Commit

Permalink
write nested requirements to output file unless --no-recursive is pre…
Browse files Browse the repository at this point in the history
…sent
  • Loading branch information
alanhamlett committed May 28, 2018
1 parent 03834b9 commit 4ccecad
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 58 deletions.
123 changes: 79 additions & 44 deletions pur/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
from .exceptions import StopUpdating


PUR_GLOBAL_UPDATED = 0


@click.command()
@click.option('-r', '--requirement', type=click.Path(),
help='The requirements.txt file to update; Defaults to using ' +
Expand Down Expand Up @@ -88,16 +91,12 @@ def pur(**options):

options['echo'] = True

global UPDATED
UPDATED = 0

# patch pip for handling nested requirements files
patch_pip(options)
global PUR_GLOBAL_UPDATED
PUR_GLOBAL_UPDATED = 0

output_file = options['output'] or options['requirement']
update_requirements(
input_file=options['requirement'],
output_file=output_file,
output_file=options['output'],
force=options['force'],
interactive=options['interactive'],
skip=options['skip'],
Expand All @@ -111,14 +110,11 @@ def pur(**options):
_echo('All requirements up-to-date.')

if options['nonzero_exit_code']:
if UPDATED > 0:
if PUR_GLOBAL_UPDATED > 0:
raise ExitCodeException(11)
raise ExitCodeException(10)


UPDATED = 0


def update_requirements(input_file=None, output_file=None, force=False,
interactive=False, skip=[], only=[], dry_run=False,
no_recursive=False, echo=False):
Expand All @@ -139,11 +135,42 @@ def update_requirements(input_file=None, output_file=None, force=False,
will be updated.
"""

global UPDATED
obuffer = StringIO()
updates = defaultdict(list)

# patch pip for handling nested requirements files
patch_pip(obuffer, updates, input_file=input_file, output_file=output_file,
force=force, interactive=interactive, skip=skip, only=only,
dry_run=dry_run, no_recursive=no_recursive, echo=echo)

_internal_update_requirements(obuffer, updates,
input_file=input_file,
output_file=output_file,
force=force,
interactive=interactive, skip=skip,
only=only, dry_run=dry_run,
no_recursive=no_recursive,
echo=echo)

if not dry_run:
if not output_file:
output_file = input_file
with open(output_file, 'w') as output:
output.write(obuffer.getvalue())

obuffer.close()

return updates


def _internal_update_requirements(obuffer, updates, input_file=None,
output_file=None, force=False,
interactive=False, skip=[], only=[],
dry_run=False, no_recursive=False,
echo=False):
global PUR_GLOBAL_UPDATED

updated = 0
buf = StringIO()
updates = defaultdict(list)

try:
requirements = get_requirements_and_latest(input_file, force=force)
Expand All @@ -155,15 +182,16 @@ def update_requirements(input_file=None, output_file=None, force=False,

try:
if should_update(req, spec_ver, latest_ver,
force=force,
interactive=interactive):
force=force,
interactive=interactive):

if not spec_ver[0]:
new_line = '{0}=={1}'.format(line, latest_ver)
else:
new_line = update_requirement(req, line, spec_ver,
latest_ver)
buf.write(new_line)
new_line = update_requirement_line(req, line,
spec_ver,
latest_ver)
obuffer.write(new_line)

if new_line != line:
msg = 'Updated {package}: {old} -> {new}'.format(
Expand Down Expand Up @@ -192,52 +220,48 @@ def update_requirements(input_file=None, output_file=None, force=False,
_echo(msg)

else:
buf.write(line)
obuffer.write(line)
except StopUpdating:
stop = True
buf.write(line)
obuffer.write(line)

else:
buf.write(line)
elif not output_file or not requirements_line(line, req):
obuffer.write(line)

buf.write("\n")
if not output_file or not requirements_line(line, req):
obuffer.write('\n')

except InstallationError as e:
raise click.ClickException(str(e))

if not dry_run:
with open(output_file, 'w') as output:
output.write(buf.getvalue())
elif echo:
_echo('==> ' + output_file + ' <==')
_echo(buf.getvalue())
if dry_run and echo:
_echo('==> ' + (output_file or input_file) + ' <==')
_echo(obuffer.getvalue())

buf.close()
PUR_GLOBAL_UPDATED += updated

UPDATED += updated

return updates


def patch_pip(options):
def patch_pip(obuffer, updates, **options):
"""Patch pip to also update nested requirements files.
:param obuffer: Output buffer for new requirements file.
:param updates: Dict for saving information about updated packages.
:param options: Dict containing original command line arguments.
"""

global UPDATED
seen = []

def patched_parse_requirements(*args, **kwargs):
global UPDATED
if not options['no_recursive']:
filename = args[0]
if not options['output'] and filename not in seen:
if filename not in seen:
if os.path.isfile(filename):
seen.append(filename)
update_requirements(
buf = StringIO()
_internal_update_requirements(
buf, updates,
input_file=filename,
output_file=filename,
output_file=options['output_file'],
force=options['force'],
interactive=options['interactive'],
skip=options['skip'],
Expand All @@ -246,6 +270,13 @@ def patched_parse_requirements(*args, **kwargs):
no_recursive=options['no_recursive'],
echo=options['echo'],
)
if not options['dry_run']:
if options['output_file']:
obuffer.write(buf.getvalue())
else:
with open(filename, 'w') as output:
output.write(buf.getvalue())
buf.close()
return []
req_file.parse_requirements = patched_parse_requirements

Expand Down Expand Up @@ -391,7 +422,7 @@ def join_lines(lines_enum):
new_line.append(line)
orig_lines.append(orig_line)
yield (primary_line_number, ''.join(new_line),
"\n".join(orig_lines))
'\n'.join(orig_lines))
new_line = []
orig_lines = []
else:
Expand All @@ -404,7 +435,7 @@ def join_lines(lines_enum):

# last line contains \
if new_line:
yield primary_line_number, ''.join(new_line), "\n".join(orig_lines)
yield primary_line_number, ''.join(new_line), '\n'.join(orig_lines)


def latest_version(req, session, finder, include_prereleases=False):
Expand Down Expand Up @@ -516,7 +547,7 @@ def ask_to_update(req, spec_ver, latest_ver):
raise StopUpdating()


def update_requirement(req, line, spec_ver, latest_ver):
def update_requirement_line(req, line, spec_ver, latest_ver):
"""Updates the version of a requirement line.
Returns a new requirement line with the package version updated.
Expand Down Expand Up @@ -548,6 +579,10 @@ def update_requirement(req, line, spec_ver, latest_ver):
return new_line


def requirements_line(line, req):
return not req and line and line.strip().startswith('-r ')


class ExitCodeException(click.ClickException):
def __init__(self, exit_code):
self.exit_code = exit_code
Expand Down
34 changes: 20 additions & 14 deletions tests/test_pur.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
import os
import shutil
import tempfile
from click.testing import CliRunner
from pip._internal.index import InstallationCandidate, Link, PackageFinder

from pur import pur, __version__

from click.testing import CliRunner
from pip._internal.index import InstallationCandidate, Link, PackageFinder

from . import utils
from .utils import u


class BaseTestCase(utils.TestCase):
class PurTestCase(utils.TestCase):

def setUp(self):
self.runner = CliRunner()
self.maxDiff = None

def test_help_contents(self):
args = ['--help']
Expand Down Expand Up @@ -105,8 +107,10 @@ def test_requirements_long_option_accepted(self):
def test_updates_package_to_output_file(self):
tempdir = tempfile.mkdtemp()
output = os.path.join(tempdir, 'output.txt')
requirements = open('tests/samples/requirements.txt').read()
args = ['-r', 'tests/samples/requirements.txt', '--output', output]
previous = open('tests/samples/requirements.txt').read()
requirements = os.path.join(tempdir, 'requirements.txt')
shutil.copy('tests/samples/requirements.txt', requirements)
args = ['-r', requirements, '--output', output]

with utils.mock.patch('pip._internal.index.PackageFinder.find_all_candidates') as mock_find_all_candidates:
project = 'flask'
Expand All @@ -120,9 +124,12 @@ def test_updates_package_to_output_file(self):
expected_output = "Updated flask: 0.9 -> 0.10.1\nAll requirements up-to-date.\n"
self.assertEquals(u(result.output), u(expected_output))
self.assertEquals(result.exit_code, 0)
self.assertEquals(open('tests/samples/requirements.txt').read(), requirements)
self.assertEquals(open('tests/samples/requirements.txt').read(), previous)
expected_requirements = open('tests/samples/results/test_updates_package').read()
self.assertEquals(open(output).read(), expected_requirements)

def test_does_not_update_nested_requirements_to_output_file(self):
def test_updates_nested_requirements_to_output_file(self):
tempdir = tempfile.mkdtemp()
tempdir = tempfile.mkdtemp()
output = os.path.join(tempdir, 'output.txt')
requirements = os.path.join(tempdir, 'requirements-with-nested-reqfile.txt')
Expand All @@ -131,6 +138,10 @@ def test_does_not_update_nested_requirements_to_output_file(self):
shutil.copy('tests/samples/requirements-nested.txt', requirements_nested)
args = ['-r', requirements, '--output', output]

expected_output = "Updated readtime: 0.9 -> 0.10.1\nAll requirements up-to-date.\n"
expected_requirements = open('tests/samples/results/test_updates_package_in_nested_requirements').read()
expected_requirements = expected_requirements.replace('-r requirements-nested.txt\n', open('tests/samples/results/test_updates_package_in_nested_requirements_nested').read())

with utils.mock.patch('pip._internal.index.PackageFinder.find_all_candidates') as mock_find_all_candidates:
project = 'readtime'
version = '0.10.1'
Expand All @@ -140,15 +151,10 @@ def test_does_not_update_nested_requirements_to_output_file(self):

result = self.runner.invoke(pur, args)
self.assertIsNone(result.exception)
expected_output = "All requirements up-to-date.\n"
self.assertEquals(u(result.output), u(expected_output))
self.assertEquals(result.exit_code, 0)

expected_requirements = open('tests/samples/requirements-with-nested-reqfile.txt').read()
self.assertEquals(open(requirements).read(), expected_requirements)
expected_requirements = open('tests/samples/requirements-nested.txt').read()
self.assertEquals(open(requirements_nested).read(), expected_requirements)
expected_requirements = open('tests/samples/results/test_updates_package_in_nested_requirements').read()
self.assertEquals(open(requirements_nested).read(), open('tests/samples/requirements-nested.txt').read())
self.assertEquals(open(requirements).read(), open('tests/samples/requirements-with-nested-reqfile.txt').read())
self.assertEquals(open(output).read(), expected_requirements)

def test_exit_code_from_no_updates(self):
Expand Down

0 comments on commit 4ccecad

Please sign in to comment.