In [None]:
from unittest.mock import Mock, patch
from urllib.parse import parse_qs, urlparse

import pytest
from pytest_django.asserts import assertTemplateUsed

from allauth.account.models import EmailAddress
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.internal import statekit
from allauth.socialaccount.models import SocialAccount
from allauth.socialaccount.providers.saml.utils import build_saml_config
from django.urls import reverse, reverse_lazy
from django.utils.http import urlencode

In [None]:
@pytest.mark.parametrize(
    "is_connect,state_kwargs,relay_state,expected_url",
    [
        (False, None, None, "/accounts/profile/"),
        (False, None, "/foo", "/foo"),
        (True, {"process": "connect"}, None, reverse_lazy("socialaccount_connections")),
        (True, {"process": "connect", "next_url": "/conn"}, None, "/conn"),
    ],
)
def test_acs(
    request,
    is_connect,
    db,
    saml_settings,
    acs_saml_response,
    mocked_signature_validation,
    expected_url,
    relay_state,
    state_kwargs,
    sociallogin_setup_state,
):
    """POSTing a SAML response to the ACS endpoint ends in a redirect to ``expected_url``.

    Cases covered:

    - login without RelayState -> ``/accounts/profile/``;
    - login with a bare-path RelayState (``/foo``) -> that path;
    - connect flow, where the state is created with ``sociallogin_setup_state``
      and sent as ``RelayState=state=<id>`` -> the connections page, or the
      state's ``next_url`` when one is given.

    Afterwards the SAML social account must exist with the expected attributes
    and email address.
    """
    if is_connect:
        # The connect flow runs with an authenticated client and its user.
        client = request.getfixturevalue("auth_client")
        user = request.getfixturevalue("user")
    else:
        client = request.getfixturevalue("client")
        user = None

    if state_kwargs:
        # relay_state is overwritten below, so stateful cases must not also
        # supply an explicit one.
        assert not relay_state
        state_id = sociallogin_setup_state(client, **state_kwargs)
        relay_state = urlencode({"state": state_id})

    data = {"SAMLResponse": acs_saml_response}
    if relay_state is not None:
        data["RelayState"] = relay_state
    resp = client.post(
        reverse("saml_acs", kwargs={"organization_slug": "org"}), data=data
    )
    # The ACS view redirects to a separate "finish" URL; the final redirect is
    # issued from there.
    finish_url = reverse("saml_finish_acs", kwargs={"organization_slug": "org"})
    assert resp.status_code == 302
    assert resp["location"] == finish_url

    resp = client.get(finish_url)
    # Check the status first so a non-redirect fails with a clear assertion
    # instead of a KeyError on the missing Location header.
    assert resp.status_code == 302
    assert resp["location"] == expected_url

    account = SocialAccount.objects.get(
        provider="urn:dev-123.us.auth0.com", uid="dummysamluid"
    )
    assert account.extra_data["Role"] == ["view-profile", "manage-account-links"]
    email = EmailAddress.objects.get(user=account.user)
    assert email.email == (user.email if is_connect else "john.doe@email.org")

In [None]:
def test_acs_error(client, db, saml_settings):
    """A malformed SAMLResponse renders the authentication error page (HTTP 200)."""
    data = {"SAMLResponse": "bad-response"}
    resp = client.post(
        reverse("saml_acs", kwargs={"organization_slug": "org"}), data=data
    )
    assert resp.status_code == 200
    # Same helper as test_login_on_get; on failure it lists the templates
    # that were actually rendered.
    assertTemplateUsed(resp, "socialaccount/authentication_error.html")

In [None]:
def test_acs_get(client, db, saml_settings):
    """ACS expects POST: a GET renders the authentication error page (HTTP 200)."""
    resp = client.get(reverse("saml_acs", kwargs={"organization_slug": "org"}))
    assert resp.status_code == 200
    # Same helper as test_login_on_get; on failure it lists the templates
    # that were actually rendered.
    assertTemplateUsed(resp, "socialaccount/authentication_error.html")

In [None]:
def test_sls_get(client, db, saml_settings):
    """A bare GET to SLS, carrying no SAML payload, is rejected with HTTP 400.

    GET itself is not the problem: test_sls shows that a GET carrying a
    ``SAMLRequest`` query parameter is handled and redirected. What triggers
    the 400 is the absence of any SAML message.
    """
    resp = client.get(reverse("saml_sls", kwargs={"organization_slug": "org"}))
    assert resp.status_code == 400

In [None]:
def test_login_on_get(client, db, saml_settings):
    """A GET on the SAML login endpoint renders ``socialaccount/login.html``.

    The page is served with HTTP 200. Only a POST (see test_login) produces
    the redirect to the IdP.
    """
    login_url = reverse("saml_login", kwargs={"organization_slug": "org"})
    response = client.get(login_url)
    assert response.status_code == 200
    assertTemplateUsed(response, "socialaccount/login.html")

In [None]:
def test_login(client, db, saml_settings):
    """POSTing to the login endpoint redirects to the IdP and stores the flow state.

    The ``process``/``next`` query parameters must end up in the session-stored
    state. The id of that state travels to the IdP inside ``RelayState``,
    encoded as ``state=<id>``.
    """
    resp = client.post(
        reverse("saml_login", kwargs={"organization_slug": "org"})
        + "?process=connect&next=/foo"
    )
    assert resp.status_code == 302
    location = resp["location"]
    assert location.startswith("https://dev-123.us.auth0.com/samlp/456?SAMLRequest=")
    resp_query = parse_qs(urlparse(location).query)
    # Assert presence first so a missing RelayState fails with a clear message
    # rather than a TypeError from subscripting None.
    assert "RelayState" in resp_query, "RelayState missing from IdP redirect"
    relay_state = resp_query["RelayState"][0]
    state_id = parse_qs(relay_state)["state"][0]
    # Session entries are keyed by state id, and element 0 holds the state
    # dict. Presumably the remaining elements are bookkeeping (TODO confirm
    # against statekit).
    state = client.session[statekit.STATES_SESSION_KEY][state_id][0]
    assert state == {"process": "connect", "data": None, "next": "/foo"}

In [None]:
def test_metadata(client, db, saml_settings):
    """The SP metadata endpoint returns an XML ``md:EntityDescriptor`` document."""
    metadata_url = reverse("saml_metadata", kwargs={"organization_slug": "org"})
    response = client.get(metadata_url)
    expected_prefix = b'<?xml version="1.0"?>\n<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata'
    assert response.status_code == 200
    assert response.content.startswith(expected_prefix)

In [None]:
def test_sls(auth_client, db, saml_settings, user_factory, sls_saml_request):
    """A GET to SLS carrying a SAMLRequest logs out once and redirects to the IdP.

    The account adapter's ``logout`` must be called exactly once. The response
    must redirect to the IdP endpoint with a ``SAMLResponse`` query parameter.
    """
    # NOTE(review): user_factory is requested but not used directly; presumably
    # kept for fixture side effects. TODO confirm.
    sls_url = "{}?{}".format(
        reverse("saml_sls", kwargs={"organization_slug": "org"}),
        urlencode({"SAMLRequest": sls_saml_request}),
    )
    with patch(
        "allauth.account.adapter.DefaultAccountAdapter.logout"
    ) as logout_mock:
        response = auth_client.get(sls_url)
    # The mock keeps its call record after the patch is undone.
    assert logout_mock.call_count == 1
    assert response.status_code == 302
    assert response["location"].startswith(
        "https://dev-123.us.auth0.com/samlp/456?SAMLResponse="
    )

In [None]:
@pytest.mark.parametrize(
    "provider_config",
    [
        {
            "idp": {
                "entity_id": "dummy",
                "sso_url": "https://idp.org/sso/",
                "slo_url": "https://idp.saml.org/slo/",
                "x509cert": "cert",
            }
        },
    ],
)
def test_build_saml_config_without_metadata_url(rf, provider_config):
    """Explicit IdP settings map onto the camelCase ``idp`` section of the config.

    No metadata URL is involved, so no remote metadata fetch needs mocking.
    """
    idp_config = build_saml_config(rf.get("/"), provider_config, "org")["idp"]
    expected = {
        "entityId": "dummy",
        "x509cert": "cert",
        "singleSignOnService": {"url": "https://idp.org/sso/"},
        "singleLogoutService": {"url": "https://idp.saml.org/slo/"},
    }
    for key, value in expected.items():
        assert idp_config[key] == value

In [None]:
@pytest.mark.parametrize(
    "provider_config",
    [
        {
            "idp": {
                "entity_id": "dummy",
                "metadata_url": "https://idp.org/sso/",
            }
        },
        {
            "idp": {
                "entity_id": "dummy",
                "metadata_url": "https://idp.org/sso/",
            },
            "sp": {"entity_id": "dummy-sp-entity-id"},
        },
    ],
)
def test_build_saml_config(rf, provider_config):
    """With an IdP ``metadata_url``, IdP settings come from the parsed remote metadata.

    The remote parse is mocked. The SP entity id is the configured
    ``sp.entity_id`` when present. Otherwise it is the absolute URL of the SP
    metadata endpoint.
    """
    request = rf.get("/")
    with patch(
        "onelogin.saml2.idp_metadata_parser.OneLogin_Saml2_IdPMetadataParser.parse_remote"
    ) as parse_mock:
        parse_mock.return_value = {
            "idp": {
                "entityId": "dummy",
                "singleSignOnService": {"url": "https://idp.org/sso/"},
                "singleLogoutService": {"url": "https://idp.saml.org/slo/"},
                "x509cert": "cert",
            }
        }
        config = build_saml_config(request, provider_config, "org")

    assert config["idp"]["entityId"] == "dummy"
    assert config["idp"]["x509cert"] == "cert"
    assert config["idp"]["singleSignOnService"] == {"url": "https://idp.org/sso/"}
    assert config["idp"]["singleLogoutService"] == {"url": "https://idp.saml.org/slo/"}

    expected_sp_entity_id = provider_config.get("sp", {}).get("entity_id")
    if not expected_sp_entity_id:
        # Fallback: the SP's own metadata URL. Reverse it with kwargs, the
        # same way every other test in this module does.
        expected_sp_entity_id = request.build_absolute_uri(
            reverse("saml_metadata", kwargs={"organization_slug": "org"})
        )
    assert config["sp"]["entityId"] == expected_sp_entity_id

In [None]:
@pytest.mark.parametrize(
    "data, result, uid",
    [
        (
            {"urn:oasis:names:tc:SAML:attribute:subject-id": ["123"]},
            {"uid": "123", "email": "nameid@saml.org"},
            "123",
        ),
        ({}, {"email": "nameid@saml.org"}, "nameid@saml.org"),
    ],
)
def test_extract_attributes(db, data, result, uid, settings):
    """The SAML provider derives uid/email from a (mocked) SAML auth object.

    A ``subject-id`` attribute, when present, becomes the uid. Otherwise the
    uid falls back to the NameID. In both cases the email-format NameID is
    reported as the email.
    """
    settings.SOCIALACCOUNT_PROVIDERS = {
        "saml": {
            "APPS": [
                {
                    "client_id": "org",
                    "provider_id": "urn:dev-123.us.auth0.com",
                }
            ]
        }
    }
    provider = get_adapter().get_provider(request=None, provider="saml")
    # Configure the stubbed getters via Mock's dotted-name keyword arguments.
    onelogin_data = Mock(
        **{
            "get_attributes.return_value": data,
            "get_nameid.return_value": "nameid@saml.org",
            "get_nameid_format.return_value": (
                "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"
            ),
        }
    )
    extracted = provider._extract(onelogin_data)
    assert extracted == result
    extracted_uid = provider.extract_uid(onelogin_data)
    assert extracted_uid == uid