Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
.DS_Store
metrics.jsonl
service_account.json
play_integrity_service_account.json
.ruff_cache
# Since we are running it as a library, better not to commit the lock file
uv.lock
Expand Down
117 changes: 117 additions & 0 deletions scripts/play_integrity_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""
E2E Play Integrity flow against a running MLPA server.
Requires a real Play Integrity token and configured service account file on the server.
"""

import argparse
import json
import os
from typing import Optional

import httpx

from mlpa.core.config import env

DEFAULT_BASE_URL = f"http://0.0.0.0:{env.PORT or 8080}"
DEFAULT_SERVICE_TYPE = "ai"


def _print_json(payload: dict) -> None:
print(json.dumps(payload, indent=2))


def _require_value(value: Optional[str], name: str) -> str:
if value:
return value
raise SystemExit(f"Missing required value for {name}.")


def run(args: argparse.Namespace) -> None:
integrity_token = _require_value(
args.integrity_token or os.getenv("MLPA_PLAY_INTEGRITY_TOKEN"),
"integrity_token",
)
user_id = _require_value(
args.user_id or os.getenv("MLPA_PLAY_USER_ID"),
"user_id",
)

verify_response = httpx.post(
f"{args.base_url}/verify/play",
json={"integrity_token": integrity_token, "user_id": user_id},
timeout=args.timeout_s,
)
verify_response.raise_for_status()
access_token = verify_response.json().get("access_token")
if not access_token:
raise SystemExit("No access_token returned from /verify/play.")

headers = {
"authorization": f"Bearer {access_token}",
"use-play-integrity": "true",
"service-type": args.service_type,
}
payload = {
"model": args.model or env.MODEL_NAME,
"messages": [{"role": "user", "content": args.message}],
"stream": args.stream,
}

if args.stream:
with httpx.stream(
"POST",
f"{args.base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=args.timeout_s,
) as response:
response.raise_for_status()
for line in response.iter_lines():
if line:
print(line)
else:
response = httpx.post(
f"{args.base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=args.timeout_s,
)
response.raise_for_status()
_print_json(response.json())


def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="E2E Play Integrity verification + chat completion."
)
subparsers = parser.add_subparsers(dest="command")

run_parser = subparsers.add_parser("run", help="Verify and request a completion.")
run_parser.add_argument("--integrity-token", dest="integrity_token")
run_parser.add_argument("--user-id", dest="user_id")
run_parser.add_argument(
"--base-url", dest="base_url", default="http://localhost:8080"
)
run_parser.add_argument("--timeout-s", dest="timeout_s", type=int, default=30)
run_parser.add_argument(
"--service-type", dest="service_type", default=DEFAULT_SERVICE_TYPE
)
run_parser.add_argument("--model", dest="model")
run_parser.add_argument("--message", dest="message", default="What is 2+2?")
run_parser.add_argument("--stream", dest="stream", action="store_true")
run_parser.set_defaults(func=run)

return parser


def main() -> None:
parser = build_parser()
args = parser.parse_args()
if not getattr(args, "command", None):
parser.print_help()
raise SystemExit(2)
args.func(args)


if __name__ == "__main__":
main()
45 changes: 25 additions & 20 deletions src/mlpa/core/auth/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mlpa.core.config import env
from mlpa.core.routers.appattest import app_attest_auth
from mlpa.core.routers.fxa import fxa_auth
from mlpa.core.utils import parse_app_attest_jwt
from mlpa.core.utils import extract_user_from_play_integrity_jwt, parse_app_attest_jwt


async def authorize_request(
Expand All @@ -15,30 +15,35 @@ async def authorize_request(
service_type: Annotated[ServiceType, Header()],
use_app_attest: Annotated[bool | None, Header()] = None,
use_qa_certificates: Annotated[bool | None, Header()] = None,
use_play_integrity: Annotated[bool | None, Header()] = None,
) -> AuthorizedChatRequest:
if not authorization:
raise HTTPException(status_code=401, detail="Missing authorization header")
if use_app_attest:
# Apple App Attest
assertionAuth = parse_app_attest_jwt(authorization, "assert")
data = await app_attest_auth(assertionAuth, chat_request, use_qa_certificates)
if data:
if data.get("error"):
raise HTTPException(status_code=401, detail=data["error"])
return AuthorizedChatRequest(
user=f"{assertionAuth.key_id_b64}:{service_type.value}", # "user" is key_id_b64 from app attest
service_type=service_type.value,
**chat_request.model_dump(exclude_unset=True),
)
if not data or data.get("error"):
raise HTTPException(status_code=401, detail=data["error"])
return AuthorizedChatRequest(
user=f"{assertionAuth.key_id_b64}:{service_type.value}", # "user" is key_id_b64 from app attest
service_type=service_type.value,
**chat_request.model_dump(exclude_unset=True),
)
elif use_play_integrity:
# Google Play integrity
play_user_id = extract_user_from_play_integrity_jwt(authorization)
return AuthorizedChatRequest(
user=f"{play_user_id}:{service_type.value}",
service_type=service_type.value,
**chat_request.model_dump(exclude_unset=True),
)
else:
fxa_user_id = await fxa_auth(authorization)
if fxa_user_id:
if fxa_user_id.get("error"):
raise HTTPException(status_code=401, detail=fxa_user_id["error"])
return AuthorizedChatRequest(
user=f"{fxa_user_id['user']}:{service_type.value}",
service_type=service_type.value,
**chat_request.model_dump(exclude_unset=True),
)
raise HTTPException(
status_code=401, detail="Please authenticate with App Attest or FxA."
)
if not fxa_user_id or fxa_user_id.get("error"):
raise HTTPException(status_code=401, detail=fxa_user_id["error"])
return AuthorizedChatRequest(
user=f"{fxa_user_id['user']}:{service_type.value}",
service_type=service_type.value,
**chat_request.model_dump(exclude_unset=True),
)
7 changes: 7 additions & 0 deletions src/mlpa/core/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class UserUpdatePayload(BaseModel):
blocked: bool | None = None


# iOS App Attest
class AttestationAuth(BaseModel):
key_id_b64: str
challenge_b64: str
Expand All @@ -37,6 +38,12 @@ class AssertionAuth(BaseModel):
assertion_obj_b64: str


# Google Play Integrity
class PlayIntegrityRequest(BaseModel):
integrity_token: str
user_id: str
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this allows empty value, but if it is empty, lower on https://github.com/Firefox-AI/MLPA/pull/79/changes#diff-b6cf33ca99c89749f7b57bcb0e80bf6fdc324b9f25577c2cb7af154baa0a62d9R37 may bypass the validation

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? play_user_id is only non null if extract_user_from_play_integrity_jwt succeeds

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@noahpodgurski so I was wondering if it's possible for play_user_id to be empty, 'cause if not, and it can't be invalid after extract_user_from_play_integrity_jwt succeeds, then maybe we don't need if play_user_id: on line 37 of authorize.py

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@noahpodgurski went over this with more analysis, think it's fine for now to handle the "" user_id situation:

extract_user_from_play_integrity_jwt has two outcomes:

  1. Returns payload["sub"] (a string)
  2. Raises HTTPException(401)

  It never returns None or empty string silently. If sub is missing, jwtoxide rejects it (it's in
  required_spec_claims). If decoding fails for any reason, the except block raises 401.

  The only way play_user_id could be empty string is the scenario from before -- someone issues a JWT with sub: ""
  -- which would require them to know MLPA_ACCESS_TOKEN_SECRET. At that point you have bigger problems.



class AuthorizedChatRequest(ChatRequest):
user: str
service_type: str
Expand Down
7 changes: 7 additions & 0 deletions src/mlpa/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def valid_service_types(self) -> list[str]:
APP_ATTEST_QA_BUCKET_PREFIX: str | None = None
APP_ATTEST_QA_GCP_PROJECT_ID: str | None = None

# Play Integrity
PLAY_INTEGRITY_PACKAGE_NAME: str = "com.example.app"
PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE: str = "play_integrity_service_account.json"
PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS: int = 30
MLPA_ACCESS_TOKEN_SECRET: str = "mlpa-dev-secret"
MLPA_ACCESS_TOKEN_TTL_SECONDS: int = 6000

# FxA
CLIENT_ID: str = "default-client-id"
CLIENT_SECRET: str = "default-client-secret"
Expand Down
3 changes: 3 additions & 0 deletions src/mlpa/core/routers/play/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from mlpa.core.routers.play.play import router as play_router

__all__ = ["play_router"]
108 changes: 108 additions & 0 deletions src/mlpa/core/routers/play/play.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import hashlib
from functools import lru_cache

import httpx
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from google.auth.transport.requests import Request
from google.oauth2 import service_account
from pydantic import BaseModel

from mlpa.core.classes import PlayIntegrityRequest
from mlpa.core.config import env
from mlpa.core.http_client import get_http_client
from mlpa.core.utils import issue_mlpa_access_token, raise_and_log

router = APIRouter()

PLAY_INTEGRITY_SCOPE = "https://www.googleapis.com/auth/playintegrity"
ALLOWED_DEVICE_VERDICTS = {
"MEETS_DEVICE_INTEGRITY",
"MEETS_BASIC_INTEGRITY",
"MEETS_STRONG_INTEGRITY",
}


@lru_cache(maxsize=1)
def _get_service_account_credentials():
return service_account.Credentials.from_service_account_file(
env.PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE,
scopes=[PLAY_INTEGRITY_SCOPE],
)


def _get_play_integrity_access_token() -> str:
credentials = _get_service_account_credentials()
if not credentials.valid:
credentials.refresh(Request())
if not credentials.token:
raise HTTPException(status_code=500, detail="Failed to fetch access token")
return credentials.token


async def _decode_integrity_token(integrity_token: str) -> dict:
access_token = await run_in_threadpool(_get_play_integrity_access_token)
client = get_http_client()
try:
response = await client.post(
f"https://playintegrity.googleapis.com/v1/{env.PLAY_INTEGRITY_PACKAGE_NAME}:decodeIntegrityToken",
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"integrity_token": integrity_token},
timeout=env.PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS,
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise_and_log(e, False, 401)
except Exception as e:
raise_and_log(e, False, 502, "Play Integrity validation service unavailable")
return response.json()


def _validate_integrity_payload(payload: dict, expected_hash: str) -> None:
request_details = payload.get("requestDetails", {})
Comment thread
ti3x marked this conversation as resolved.
package_name = request_details.get("requestPackageName")
if package_name != env.PLAY_INTEGRITY_PACKAGE_NAME:
raise HTTPException(status_code=401, detail="Invalid package name")

token_request_hash = request_details.get("requestHash")
if token_request_hash != expected_hash:
raise HTTPException(status_code=401, detail="Invalid request hash")

app_integrity = payload.get("appIntegrity", {})
acceptable_recognition_verdicts = [
"PLAY_RECOGNIZED",
]
if env.MLPA_DEBUG:
acceptable_recognition_verdicts.append("UNRECOGNIZED_VERSION")
if (
app_integrity.get("appRecognitionVerdict")
not in acceptable_recognition_verdicts
):
raise HTTPException(status_code=401, detail="App not recognized by Play")

device_integrity = payload.get("deviceIntegrity", {})
device_verdicts = set(device_integrity.get("deviceRecognitionVerdict", []))
if not device_verdicts.intersection(ALLOWED_DEVICE_VERDICTS):
raise HTTPException(status_code=401, detail="Device integrity check failed")


@router.post("/play", tags=["Play Integrity"])
async def verify_play_integrity(payload: PlayIntegrityRequest):
decoded = await _decode_integrity_token(payload.integrity_token)
token_payload = decoded.get("tokenPayloadExternal") or decoded.get("tokenPayload")
if not token_payload:
raise HTTPException(status_code=401, detail="Invalid Play Integrity token")

expected_hash = hashlib.sha256(payload.user_id.encode("utf-8")).hexdigest()

_validate_integrity_payload(token_payload, expected_hash)

access_token = issue_mlpa_access_token(payload.user_id)
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": env.MLPA_ACCESS_TOKEN_TTL_SECONDS,
}
38 changes: 37 additions & 1 deletion src/mlpa/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import ast
import base64
import json
import time

from fastapi import HTTPException
from fxa.oauth import Client
from jwtoxide import DecodingKey, ValidationOptions, decode
from jwtoxide import DecodingKey, ValidationOptions, decode, encode

from mlpa.core.classes import AssertionAuth, AttestationAuth
from mlpa.core.config import LITELLM_MASTER_AUTH_HEADERS, env
Expand Down Expand Up @@ -160,3 +161,38 @@ def raise_and_log(
else response_text_prefix or GENERIC_UPSTREAM_ERROR
},
)


def extract_user_from_play_integrity_jwt(authorization: str):
token = authorization.removeprefix("Bearer ").split()[0]
try:
payload = decode(
token,
env.MLPA_ACCESS_TOKEN_SECRET,
ValidationOptions(
required_spec_claims={"exp", "iat", "sub"},
iss={"mlpa"},
aud=None,
validate_aud=False,
validate_exp=True,
validate_nbf=False,
verify_signature=True,
algorithms=["HS256"],
),
)
return payload["sub"]
except Exception as e:
logger.error(f"Play Integrity JWT decode error: {e}")
raise HTTPException(status_code=401, detail="Invalid MLPA access token")


def issue_mlpa_access_token(user_id: str) -> str:
now = int(time.time())
payload = {
"sub": user_id,
"iat": now,
"exp": now + env.MLPA_ACCESS_TOKEN_TTL_SECONDS,
"iss": "mlpa",
"typ": "mlpa_access",
}
return encode(payload, env.MLPA_ACCESS_TOKEN_SECRET, algorithm="HS256")
Loading