diff --git a/CHANGELOG.md b/CHANGELOG.md index 66aba94056f6..8a6293dc7b9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/ ### Task Library -- None +- Add task for creating new branches in a GitHub repository - [#1011](https://github.com/PrefectHQ/prefect/pull/1011) ### Fixes diff --git a/docs/guide/task_library/github.md b/docs/guide/task_library/github.md index 27f54a6c0bca..1a72b6e48c6a 100644 --- a/docs/guide/task_library/github.md +++ b/docs/guide/task_library/github.md @@ -22,3 +22,9 @@ Task for opening / creating new GitHub issues using the v3 version of the GitHub Task for retrieving GitHub repository information using the v3 version of the GitHub REST API. [API Reference](/api/unreleased/tasks/github.html#prefect-tasks-github-prs-getrepoinfo) + +## CreateBranch + +Task for creating new branches in a given GitHub repository using the v3 version of the GitHub REST API. + +[API Reference](/api/unreleased/tasks/github.html#prefect-tasks-github-prs-createbranch) diff --git a/docs/outline.toml b/docs/outline.toml index 2038b02f55cd..0c7587c25da3 100644 --- a/docs/outline.toml +++ b/docs/outline.toml @@ -161,7 +161,7 @@ classes = ["S3Download", "S3Upload"] [pages.tasks.github] title = "GitHub Tasks" module = "prefect.tasks.github" -classes = ["CreateGitHubPR", "OpenGitHubIssue", "GetRepoInfo"] +classes = ["CreateGitHubPR", "OpenGitHubIssue", "GetRepoInfo", "CreateBranch"] [pages.tasks.kubernetes] title = "Kubernetes Tasks" diff --git a/src/prefect/tasks/github/__init__.py b/src/prefect/tasks/github/__init__.py index 57d438f64db0..5b1e2c4a8541 100644 --- a/src/prefect/tasks/github/__init__.py +++ b/src/prefect/tasks/github/__init__.py @@ -7,4 +7,4 @@ from .issues import OpenGitHubIssue from .prs import CreateGitHubPR -from .repos import GetRepoInfo +from .repos import GetRepoInfo, CreateBranch diff --git a/src/prefect/tasks/github/repos.py b/src/prefect/tasks/github/repos.py index 63c9e29e6f4d..b7b7d888efe0 100644 --- a/src/prefect/tasks/github/repos.py +++ b/src/prefect/tasks/github/repos.py @@ -70,3 +70,88 @@ def run(self, repo: str = None, info_keys: List[str] = None) -> None: data = resp.json() return {key: data[key] for key in info_keys} + + +class CreateBranch(Task): + """ + Task for creating new branches using the v3 version of the GitHub REST API. + + Args: + - repo (str, optional): the name of the repository to create the branch in; must be provided in the + form `organization/repo_name` or `user/repo_name`; can also be provided to the `run` method + - base (str, optional): the name of the branch you want to branch off; can also + be provided to the `run` method. Defaults to "master". + - branch_name (str, optional): the name of the new branch; can also be provided to the `run` method + - token_secret (str, optional): the name of the Prefect Secret containing your GitHub Access Token; + defaults to "GITHUB_ACCESS_TOKEN" + - **kwargs (Any, optional): additional keyword arguments to pass to the standard Task init method + """ + + def __init__( + self, + repo: str = None, + base: str = "master", + branch_name: str = None, + token_secret: str = "GITHUB_ACCESS_TOKEN", + **kwargs: Any + ): + self.repo = repo + self.base = base + self.branch_name = branch_name + self.token_secret = token_secret + super().__init__(**kwargs) + + @defaults_from_attrs("repo", "base", "branch_name") + def run(self, repo: str = None, base: str = None, branch_name: str = None) -> dict: + """ + Run method for this Task. Invoked by calling this Task after initialization within a Flow context, + or by using `Task.bind`. + + Args: + - repo (str, optional): the name of the repository to open the issue in; must be provided in the + form `organization/repo_name`; defaults to the one provided at initialization + - base (str, optional): the name of the branch you want to branch off; if not provided here, + defaults to the one set at initialization + - branch_name (str, optional): the name of the new branch; if not provided here, defaults to + the one set at initialization + + Raises: + - ValueError: if a `repo` or `branch_name` was never provided, or if the base branch wasn't found + - HTTPError: if the GET request returns a non-200 status code + + Returns: + - dict: dictionary of the response (includes commit hash, etc.) + """ + if branch_name is None: + raise ValueError("A branch name must be provided.") + + if repo is None: + raise ValueError("A GitHub repository must be provided.") + + ## prepare the request + token = Secret(self.token_secret).get() + url = "https://api.github.com/repos/{}/git/refs".format(repo) + headers = { + "AUTHORIZATION": "token {}".format(token), + "Accept": "application/vnd.github.v3+json", + } + + ## gather branch information + resp = requests.get(url + "/heads", headers=headers) + resp.raise_for_status() + branch_data = resp.json() + + commit_sha = None + for branch in branch_data: + if branch.get("ref") == "refs/heads/{}".format(base): + commit_sha = branch.get("object", {}).get("sha") + break + + if commit_sha is None: + raise ValueError("Base branch {} not found.".format(base)) + + ## create new branch + new_branch = {"ref": "refs/heads/{}".format(branch_name), "sha": commit_sha} + resp = requests.post(url, headers=headers, json=new_branch) + resp.raise_for_status() + return resp.json() diff --git a/tests/tasks/github/test_repos.py b/tests/tasks/github/test_repos.py index a09972292c7c..a7425d6e7655 100644 --- a/tests/tasks/github/test_repos.py +++ b/tests/tasks/github/test_repos.py @@ -3,7 +3,7 @@ import pytest import prefect -from prefect.tasks.github import GetRepoInfo +from prefect.tasks.github import GetRepoInfo, CreateBranch from prefect.utilities.configuration import set_temporary_config @@ -32,8 +32,40 @@ def test_repo_is_required_eventually(self): assert "repo" in str(exc.value) -class TestCredentialsandProjects: - def test_creds_are_pulled_from_secret_at_runtime(self, monkeypatch): +class TestCreateBranchInitialization: + def test_initializes_with_nothing_and_sets_defaults(self): + task = CreateBranch() + assert task.repo is None + assert task.base == "master" + assert task.branch_name is None + assert task.token_secret == "GITHUB_ACCESS_TOKEN" + + def test_additional_kwargs_passed_upstream(self): + task = CreateBranch(name="test-task", checkpoint=True, tags=["bob"]) + assert task.name == "test-task" + assert task.checkpoint is True + assert task.tags == {"bob"} + + @pytest.mark.parametrize("attr", ["repo", "base", "branch_name", "token_secret"]) + def test_initializes_attr_from_kwargs(self, attr): + task = CreateBranch(**{attr: "my-value"}) + assert getattr(task, attr) == "my-value" + + def test_repo_is_required_eventually(self): + task = CreateBranch(branch_name="bob") + with pytest.raises(ValueError) as exc: + task.run() + assert "repo" in str(exc.value) + + def test_branch_name_is_required_eventually(self): + task = CreateBranch(repo="org/bob") + with pytest.raises(ValueError) as exc: + task.run() + assert "branch name" in str(exc.value) + + +class TestCredentials: + def test_creds_are_pulled_from_secret_at_runtime_repo_info(self, monkeypatch): task = GetRepoInfo() req = MagicMock() @@ -45,7 +77,21 @@ def test_creds_are_pulled_from_secret_at_runtime(self, monkeypatch): assert req.get.call_args[1]["headers"]["AUTHORIZATION"] == "token {'key': 42}" - def test_creds_secret_can_be_overwritten(self, monkeypatch): + def test_creds_are_pulled_from_secret_at_runtime_create_branch(self, monkeypatch): + task = CreateBranch() + + req = MagicMock() + monkeypatch.setattr("prefect.tasks.github.repos.requests", req) + + with set_temporary_config({"cloud.use_local_secrets": True}): + with prefect.context(secrets=dict(GITHUB_ACCESS_TOKEN={"key": 42})): + with pytest.raises(ValueError) as exc: + task.run(repo="org/repo", branch_name="new") + + assert req.get.call_args[1]["headers"]["AUTHORIZATION"] == "token {'key': 42}" + assert "not found" in str(exc.value) + + def test_creds_secret_can_be_overwritten_repo_info(self, monkeypatch): task = GetRepoInfo(token_secret="MY_SECRET") req = MagicMock() @@ -56,3 +102,35 @@ def test_creds_secret_can_be_overwritten(self, monkeypatch): task.run(repo="org/repo") assert req.get.call_args[1]["headers"]["AUTHORIZATION"] == "token {'key': 42}" + + def test_creds_secret_can_be_overwritten_create_branch(self, monkeypatch): + task = CreateBranch(token_secret="MY_SECRET") + + req = MagicMock() + monkeypatch.setattr("prefect.tasks.github.repos.requests", req) + + with set_temporary_config({"cloud.use_local_secrets": True}): + with prefect.context(secrets=dict(MY_SECRET={"key": 42})): + with pytest.raises(ValueError): + task.run(repo="org/repo", branch_name="new") + + assert req.get.call_args[1]["headers"]["AUTHORIZATION"] == "token {'key': 42}" + + +def test_base_name_is_filtered_for(monkeypatch): + task = CreateBranch(base="BOB", branch_name="NEWBRANCH") + + payload = [{"ref": "refs/heads/BOB", "object": {"sha": "salty"}}] + req = MagicMock( + get=MagicMock(return_value=MagicMock(json=MagicMock(return_value=payload))) + ) + monkeypatch.setattr("prefect.tasks.github.repos.requests", req) + + with set_temporary_config({"cloud.use_local_secrets": True}): + with prefect.context(secrets=dict(MY_SECRET={"key": 42})): + task.run(repo="org/repo") + + assert req.post.call_args[1]["json"] == { + "ref": "refs/heads/NEWBRANCH", + "sha": "salty", + }