Skip to content

Commit 1a95a15

Browse files
wrap api client to add defaults
Signed-off-by: Kevin <kpostlet@redhat.com>
1 parent d14902d commit 1a95a15

File tree

7 files changed

+67
-69
lines changed

7 files changed

+67
-69
lines changed

src/codeflare_sdk/cluster/auth.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,7 @@ def __init__(
9393
self.token = token
9494
self.server = server
9595
self.skip_tls = skip_tls
96-
self.ca_cert_path = self._gen_ca_cert_path(ca_cert_path)
97-
98-
def _gen_ca_cert_path(self, ca_cert_path: str):
99-
if ca_cert_path is not None:
100-
return ca_cert_path
101-
elif "CF_SDK_CA_CERT_PATH" in os.environ:
102-
return os.environ.get("CF_SDK_CA_CERT_PATH")
103-
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
104-
return WORKBENCH_CA_CERT_PATH
105-
else:
106-
return None
96+
self.ca_cert_path = _gen_ca_cert_path(ca_cert_path)
10797

10898
def login(self) -> str:
10999
"""
@@ -119,25 +109,14 @@ def login(self) -> str:
119109
configuration.host = self.server
120110
configuration.api_key["authorization"] = self.token
121111

112+
api_client = client.ApiClient(configuration)
122113
if not self.skip_tls:
123-
if self.ca_cert_path is None:
124-
configuration.ssl_ca_cert = None
125-
elif os.path.isfile(self.ca_cert_path):
126-
print(
127-
f"Authenticated with certificate located at {self.ca_cert_path}"
128-
)
129-
configuration.ssl_ca_cert = self.ca_cert_path
130-
else:
131-
raise FileNotFoundError(
132-
f"Certificate file not found at {self.ca_cert_path}"
133-
)
134-
configuration.verify_ssl = True
114+
_client_with_cert(api_client, self.ca_cert_path)
135115
else:
136116
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
137117
print("Insecure request warnings have been disabled")
138118
configuration.verify_ssl = False
139119

140-
api_client = client.ApiClient(configuration)
141120
client.AuthenticationApi(api_client).get_api_group()
142121
config_path = None
143122
return "Logged into %s" % self.server
@@ -211,11 +190,33 @@ def config_check() -> str:
211190
return config_path
212191

213192

214-
def api_config_handler() -> Optional[client.ApiClient]:
215-
"""
216-
This function is used to load the api client if the user has logged in
217-
"""
218-
if api_client != None and config_path == None:
219-
return api_client
193+
def _client_with_cert(client: client.ApiClient, ca_cert_path: Optional[str] = None):
194+
client.configuration.verify_ssl = True
195+
cert_path = _gen_ca_cert_path(ca_cert_path)
196+
if cert_path is None:
197+
client.configuration.ssl_ca_cert = None
198+
elif os.path.isfile(cert_path):
199+
client.configuration.ssl_ca_cert = cert_path
200+
else:
201+
raise FileNotFoundError(f"Certificate file not found at {cert_path}")
202+
203+
204+
def _gen_ca_cert_path(ca_cert_path: Optional[str]):
205+
"""Gets the path to the default CA certificate file either through env config or default path"""
206+
if ca_cert_path is not None:
207+
return ca_cert_path
208+
elif "CF_SDK_CA_CERT_PATH" in os.environ:
209+
return os.environ.get("CF_SDK_CA_CERT_PATH")
210+
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
211+
return WORKBENCH_CA_CERT_PATH
220212
else:
221213
return None
214+
215+
216+
def get_api_client() -> client.ApiClient:
217+
"This function should load the api client with defaults"
218+
if api_client != None:
219+
return api_client
220+
to_return = client.ApiClient()
221+
_client_with_cert(to_return)
222+
return to_return

src/codeflare_sdk/cluster/awload.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from kubernetes import client, config
2626
from ..utils.kube_api_helpers import _kube_api_error_handling
27-
from .auth import config_check, api_config_handler
27+
from .auth import config_check, get_api_client
2828

2929

3030
class AWManager:
@@ -59,7 +59,7 @@ def submit(self) -> None:
5959
"""
6060
try:
6161
config_check()
62-
api_instance = client.CustomObjectsApi(api_config_handler())
62+
api_instance = client.CustomObjectsApi(get_api_client())
6363
api_instance.create_namespaced_custom_object(
6464
group="workload.codeflare.dev",
6565
version="v1beta2",
@@ -84,7 +84,7 @@ def remove(self) -> None:
8484

8585
try:
8686
config_check()
87-
api_instance = client.CustomObjectsApi(api_config_handler())
87+
api_instance = client.CustomObjectsApi(get_api_client())
8888
api_instance.delete_namespaced_custom_object(
8989
group="workload.codeflare.dev",
9090
version="v1beta2",

src/codeflare_sdk/cluster/cluster.py

+18-21
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from kubernetes import config
2727
from ray.job_submission import JobSubmissionClient
2828

29-
from .auth import config_check, api_config_handler
29+
from .auth import config_check, get_api_client
3030
from ..utils import pretty_print
3131
from ..utils.generate_yaml import (
3232
generate_appwrapper,
@@ -81,7 +81,7 @@ def __init__(self, config: ClusterConfiguration):
8181

8282
@property
8383
def _client_headers(self):
84-
k8_client = api_config_handler() or client.ApiClient()
84+
k8_client = get_api_client()
8585
return {
8686
"Authorization": k8_client.configuration.get_api_key_with_prefix(
8787
"authorization"
@@ -96,7 +96,7 @@ def _client_verify_tls(self):
9696

9797
@property
9898
def job_client(self):
99-
k8client = api_config_handler() or client.ApiClient()
99+
k8client = get_api_client()
100100
if self._job_submission_client:
101101
return self._job_submission_client
102102
if is_openshift_cluster():
@@ -142,7 +142,7 @@ def up(self):
142142

143143
try:
144144
config_check()
145-
api_instance = client.CustomObjectsApi(api_config_handler())
145+
api_instance = client.CustomObjectsApi(get_api_client())
146146
if self.config.appwrapper:
147147
if self.config.write_to_file:
148148
with open(self.app_wrapper_yaml) as f:
@@ -173,7 +173,7 @@ def up(self):
173173
return _kube_api_error_handling(e)
174174

175175
def _throw_for_no_raycluster(self):
176-
api_instance = client.CustomObjectsApi(api_config_handler())
176+
api_instance = client.CustomObjectsApi(get_api_client())
177177
try:
178178
api_instance.list_namespaced_custom_object(
179179
group="ray.io",
@@ -200,7 +200,7 @@ def down(self):
200200
self._throw_for_no_raycluster()
201201
try:
202202
config_check()
203-
api_instance = client.CustomObjectsApi(api_config_handler())
203+
api_instance = client.CustomObjectsApi(get_api_client())
204204
if self.config.appwrapper:
205205
api_instance.delete_namespaced_custom_object(
206206
group="workload.codeflare.dev",
@@ -359,7 +359,7 @@ def cluster_dashboard_uri(self) -> str:
359359
config_check()
360360
if is_openshift_cluster():
361361
try:
362-
api_instance = client.CustomObjectsApi(api_config_handler())
362+
api_instance = client.CustomObjectsApi(get_api_client())
363363
routes = api_instance.list_namespaced_custom_object(
364364
group="route.openshift.io",
365365
version="v1",
@@ -381,7 +381,7 @@ def cluster_dashboard_uri(self) -> str:
381381
return f"{protocol}://{route['spec']['host']}"
382382
else:
383383
try:
384-
api_instance = client.NetworkingV1Api(api_config_handler())
384+
api_instance = client.NetworkingV1Api(get_api_client())
385385
ingresses = api_instance.list_namespaced_ingress(self.config.namespace)
386386
except Exception as e: # pragma no cover
387387
return _kube_api_error_handling(e)
@@ -580,9 +580,6 @@ def get_current_namespace(): # pragma: no cover
580580
return active_context
581581
except Exception as e:
582582
print("Unable to find current namespace")
583-
584-
if api_config_handler() != None:
585-
return None
586583
print("trying to gather from current context")
587584
try:
588585
_, active_context = config.list_kube_config_contexts(config_check())
@@ -602,7 +599,7 @@ def get_cluster(
602599
):
603600
try:
604601
config_check()
605-
api_instance = client.CustomObjectsApi(api_config_handler())
602+
api_instance = client.CustomObjectsApi(get_api_client())
606603
rcs = api_instance.list_namespaced_custom_object(
607604
group="ray.io",
608605
version="v1",
@@ -657,7 +654,7 @@ def _create_resources(yamls, namespace: str, api_instance: client.CustomObjectsA
657654
def _check_aw_exists(name: str, namespace: str) -> bool:
658655
try:
659656
config_check()
660-
api_instance = client.CustomObjectsApi(api_config_handler())
657+
api_instance = client.CustomObjectsApi(get_api_client())
661658
aws = api_instance.list_namespaced_custom_object(
662659
group="workload.codeflare.dev",
663660
version="v1beta2",
@@ -684,7 +681,7 @@ def _get_ingress_domain(self): # pragma: no cover
684681

685682
if is_openshift_cluster():
686683
try:
687-
api_instance = client.CustomObjectsApi(api_config_handler())
684+
api_instance = client.CustomObjectsApi(get_api_client())
688685

689686
routes = api_instance.list_namespaced_custom_object(
690687
group="route.openshift.io",
@@ -703,7 +700,7 @@ def _get_ingress_domain(self): # pragma: no cover
703700
domain = route["spec"]["host"]
704701
else:
705702
try:
706-
api_client = client.NetworkingV1Api(api_config_handler())
703+
api_client = client.NetworkingV1Api(get_api_client())
707704
ingresses = api_client.list_namespaced_ingress(namespace)
708705
except Exception as e: # pragma: no cover
709706
return _kube_api_error_handling(e)
@@ -717,7 +714,7 @@ def _get_ingress_domain(self): # pragma: no cover
717714
def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
718715
try:
719716
config_check()
720-
api_instance = client.CustomObjectsApi(api_config_handler())
717+
api_instance = client.CustomObjectsApi(get_api_client())
721718
aws = api_instance.list_namespaced_custom_object(
722719
group="workload.codeflare.dev",
723720
version="v1beta2",
@@ -736,7 +733,7 @@ def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
736733
def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
737734
try:
738735
config_check()
739-
api_instance = client.CustomObjectsApi(api_config_handler())
736+
api_instance = client.CustomObjectsApi(get_api_client())
740737
rcs = api_instance.list_namespaced_custom_object(
741738
group="ray.io",
742739
version="v1",
@@ -758,7 +755,7 @@ def _get_ray_clusters(
758755
list_of_clusters = []
759756
try:
760757
config_check()
761-
api_instance = client.CustomObjectsApi(api_config_handler())
758+
api_instance = client.CustomObjectsApi(get_api_client())
762759
rcs = api_instance.list_namespaced_custom_object(
763760
group="ray.io",
764761
version="v1",
@@ -787,7 +784,7 @@ def _get_app_wrappers(
787784

788785
try:
789786
config_check()
790-
api_instance = client.CustomObjectsApi(api_config_handler())
787+
api_instance = client.CustomObjectsApi(get_api_client())
791788
aws = api_instance.list_namespaced_custom_object(
792789
group="workload.codeflare.dev",
793790
version="v1beta2",
@@ -816,7 +813,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
816813
dashboard_url = None
817814
if is_openshift_cluster():
818815
try:
819-
api_instance = client.CustomObjectsApi(api_config_handler())
816+
api_instance = client.CustomObjectsApi(get_api_client())
820817
routes = api_instance.list_namespaced_custom_object(
821818
group="route.openshift.io",
822819
version="v1",
@@ -835,7 +832,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
835832
dashboard_url = f"{protocol}://{route['spec']['host']}"
836833
else:
837834
try:
838-
api_instance = client.NetworkingV1Api(api_config_handler())
835+
api_instance = client.NetworkingV1Api(get_api_client())
839836
ingresses = api_instance.list_namespaced_ingress(
840837
rc["metadata"]["namespace"]
841838
)

src/codeflare_sdk/cluster/widgets.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .config import ClusterConfiguration
3030
from .model import RayClusterStatus
3131
from ..utils.kube_api_helpers import _kube_api_error_handling
32-
from .auth import config_check, api_config_handler
32+
from .auth import config_check, get_api_client
3333

3434

3535
def cluster_up_down_buttons(cluster: "codeflare_sdk.cluster.Cluster") -> widgets.Button:
@@ -343,7 +343,7 @@ def _delete_cluster(
343343

344344
try:
345345
config_check()
346-
api_instance = client.CustomObjectsApi(api_config_handler())
346+
api_instance = client.CustomObjectsApi(get_api_client())
347347

348348
if _check_aw_exists(cluster_name, namespace):
349349
api_instance.delete_namespaced_custom_object(

src/codeflare_sdk/utils/generate_cert.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from cryptography import x509
2020
from cryptography.x509.oid import NameOID
2121
import datetime
22-
from ..cluster.auth import config_check, api_config_handler
22+
from ..cluster.auth import config_check, get_api_client
2323
from kubernetes import client, config
2424
from .kube_api_helpers import _kube_api_error_handling
2525

@@ -103,7 +103,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
103103
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.key"}}'
104104
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt
105105
config_check()
106-
v1 = client.CoreV1Api(api_config_handler())
106+
v1 = client.CoreV1Api(get_api_client())
107107

108108
# Secrets have a suffix appended to the end so we must list them and gather the secret that includes cluster_name-ca-secret-
109109
secret_name = get_secret_name(cluster_name, namespace, v1)

src/codeflare_sdk/utils/generate_yaml.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import uuid
2828
from kubernetes import client, config
2929
from .kube_api_helpers import _kube_api_error_handling
30-
from ..cluster.auth import api_config_handler, config_check
30+
from ..cluster.auth import get_api_client, config_check
3131
from os import urandom
3232
from base64 import b64encode
3333
from urllib3.util import parse_url
@@ -57,7 +57,7 @@ def gen_names(name):
5757
def is_openshift_cluster():
5858
try:
5959
config_check()
60-
for api in client.ApisApi(api_config_handler()).get_api_versions().groups:
60+
for api in client.ApisApi(get_api_client()).get_api_versions().groups:
6161
for v in api.versions:
6262
if "route.openshift.io/v1" in v.group_version:
6363
return True
@@ -235,7 +235,7 @@ def get_default_kueue_name(namespace: str):
235235
# If the local queue is set, use it. Otherwise, try to use the default queue.
236236
try:
237237
config_check()
238-
api_instance = client.CustomObjectsApi(api_config_handler())
238+
api_instance = client.CustomObjectsApi(get_api_client())
239239
local_queues = api_instance.list_namespaced_custom_object(
240240
group="kueue.x-k8s.io",
241241
version="v1beta1",
@@ -261,7 +261,7 @@ def local_queue_exists(namespace: str, local_queue_name: str):
261261
# get all local queues in the namespace
262262
try:
263263
config_check()
264-
api_instance = client.CustomObjectsApi(api_config_handler())
264+
api_instance = client.CustomObjectsApi(get_api_client())
265265
local_queues = api_instance.list_namespaced_custom_object(
266266
group="kueue.x-k8s.io",
267267
version="v1beta1",

0 commit comments

Comments
 (0)