Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.pyc
.idea/
.vscode
.mypy_cache
2 changes: 1 addition & 1 deletion patches/kaggle_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class KaggleKernelCredentials(credentials.Credentials):
def refresh(self, request):
try:
client = UserSecretsClient()
self.token = client.get_bigquery_access_token()
self.token, self.expiry = client.get_bigquery_access_token()
except Exception as e:
raise RefreshError('Unable to refresh access token.') from e

Expand Down
22 changes: 18 additions & 4 deletions patches/kaggle_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import json
import os
import urllib.request
from urllib.error import HTTPError
from typing import Tuple, Optional
from datetime import datetime, timedelta

_KAGGLE_DEFAULT_URL_BASE = "https://www.kaggle.com"
_KAGGLE_URL_BASE_ENV_VAR_NAME = "KAGGLE_URL_BASE"
Expand Down Expand Up @@ -37,7 +40,7 @@ def __init__(self):
f'but none found in environment variable {_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME}')
self.headers = {'Content-type': 'application/json'}

def _make_post_request(self, data):
def _make_post_request(self, data: dict) -> dict:
url = f'{self.url_base}{self.GET_USER_SECRET_ENDPOINT}'
request_body = dict(data)
request_body['JWE'] = self.jwt_token
Expand All @@ -50,18 +53,29 @@ def _make_post_request(self, data):
raise BackendError(
'Unexpected response from the service.')
return response_json['result']
except urllib.error.HTTPError as e:
except HTTPError as e:
if e.code == 401 or e.code == 403:
raise CredentialError(f'Service responded with error code {e.code}.'
' Please ensure you have access to the resource.') from e
raise BackendError('Unexpected response from the service.') from e

def get_bigquery_access_token(self):
def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
"""Retrieves BigQuery access token information from the UserSecrets service.

This returns the token for the current kernel as well as its expiry (abs time) if it
is present.
Example usage:
client = UserSecretsClient()
token, expiry = client.get_bigquery_access_token()
"""
request_body = {
'Target': self.BIGQUERY_TARGET_VALUE
}
response_json = self._make_post_request(request_body)
if 'secret' not in response_json:
raise BackendError(
'Unexpected response from the service.')
return response_json['secret']
# Optionally return expiry if it is set.
expiresInSeconds = response_json.get('expiresInSeconds')
expiry = datetime.utcnow() + timedelta(seconds=expiresInSeconds) if expiresInSeconds else None
return response_json['secret'], expiry
11 changes: 8 additions & 3 deletions tests/test_user_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from http.server import BaseHTTPRequestHandler, HTTPServer
from test.support import EnvironmentVarGuard
from urllib.parse import urlparse
from datetime import datetime, timedelta
import mock

from google.auth.exceptions import DefaultCredentialsError
from google.cloud import bigquery
Expand Down Expand Up @@ -50,7 +52,7 @@ def set_request(self):

def get_response(self):
if success:
return {'result': {'secret': secret, 'secretType': 'refreshToken', 'secretProvider': 'google'}, 'wasSuccessful': "true"}
return {'result': {'secret': secret, 'secretType': 'refreshToken', 'secretProvider': 'google', 'expiresInSeconds': 3600}, 'wasSuccessful': "true"}
else:
return {'wasSuccessful': "false"}

Expand Down Expand Up @@ -82,13 +84,16 @@ def test_no_token_fails(self):
with self.assertRaises(CredentialError):
client = UserSecretsClient()

def test_get_access_token_succeeds(self):
@mock.patch('kaggle_secrets.datetime')
def test_get_access_token_succeeds(self, mock_dt):
secret = '12345'
now = datetime(1993, 4, 24)
mock_dt.utcnow = mock.Mock(return_value=now)

def call_get_access_token():
client = UserSecretsClient()
secret_response = client.get_bigquery_access_token()
self.assertEqual(secret_response, secret)
self.assertEqual(secret_response, (secret, now + timedelta(seconds=3600)))
self._test_client(call_get_access_token,
'/requests/GetUserSecretRequest', {'Target': 1, 'JWE': _TEST_JWT}, secret=secret)

Expand Down