In [4]:
# Install third-party packages with the %pip magic; no versions are pinned here.
%pip install aioboto3 boto3 pandas plotly requests pillow tabulate kaleido nbformat ipython python-dotenv

Collecting python-dotenv
  Downloading python_dotenv-1.0.0-py3-none-any.whl (19 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.0.0
Note: you may need to restart the kernel to use updated packages.


In [15]:
import copy
import os

import dotenv
import requests

from utils import queue_jobs
# Read settings from the local ".env" file; override=True lets values from the
# file replace variables that are already set in the process environment.
dotenv.load_dotenv(".env", override=True)

# Salad account and project settings (each is None when the variable is unset).
salad_api_key = os.environ.get("SALAD_API_KEY")
salad_org_id = os.environ.get("SALAD_ORG")
salad_project_name = os.environ.get("SALAD_PROJECT_NAME")

# Reporting and queue service settings (each is None when the variable is unset).
reporting_api_key = os.environ.get("REPORTING_API_KEY")
reporting_url = os.environ.get("REPORTING_URL")
queue_service_url = os.environ.get("QUEUE_SERVICE_URL")

# Root URL that the Salad API request URLs are built from.
salad_api_base_url = "https://api.salad.com/api/public"

# Headers passed along with requests to the Salad API.
salad_headers = {"accept": "application/json"}
salad_headers["Salad-Api-Key"] = salad_api_key


In [16]:
# Resources requested for each container.
vcpu = 2
memory = 12 * 1024  # 12288; presumably MB (12 GB) -- confirm against the Salad API

# Number of replicas requested for each container group.
replica_count_per_group = 10

# Template request body for creating a container group. "name", "image" and
# "gpu_classes" hold placeholders that create_container_group() overwrites.
create_container_group_payload = dict(
    name="replace-this",
    replicas=replica_count_per_group,
    autostart_policy=True,
    container=dict(
        image="replaceme:latest",
        resources=dict(
            cpu=vcpu,
            memory=memory,
            gpu_classes=[],
        ),
        environment_variables=dict(
            REPORTING_API_KEY=reporting_api_key,
            REPORTING_URL=reporting_url,
            QUEUE_SERVICE_URL=queue_service_url,
        ),
    ),
)


def get_gpu_classes():
    """Return the organization's GPU classes, minus "Stable Diffusion Compatible".

    Sends a GET request to the ``gpu-classes`` endpoint of the organization named
    by ``salad_org_id`` and reads the ``"items"`` list from the JSON response.

    Returns:
        list: every entry of ``"items"`` whose ``"name"`` is not
        "Stable Diffusion Compatible", in the order the API returned them.

    Raises:
        requests.HTTPError: if the API answers with a 4xx/5xx status.
        requests.Timeout: if the API does not answer within 30 seconds.
    """
    url = f"{salad_api_base_url}/organizations/{salad_org_id}/gpu-classes"
    # Without a timeout a stalled connection would block the kernel indefinitely.
    response = requests.get(url, headers=salad_headers, timeout=30)
    # Fail with the real HTTP error (e.g. a bad API key) rather than an opaque
    # KeyError on "items" when the body is an error document.
    response.raise_for_status()
    return [gpu for gpu in response.json()["items"] if gpu["name"] != "Stable Diffusion Compatible"]


def create_container_group(name, image, gpu, env=None):
    """Create a container group in the configured Salad project.

    The request body is built from a private deep copy of
    ``create_container_group_payload``, so the module-level template and its
    nested dicts are never modified and values from one call cannot leak into
    the next.

    Args:
        name: Value for the payload's "name" field.
        image: Value for the container's "image" field.
        gpu: Sent as the only element of the container's "gpu_classes" list
            (presumably a GPU class id -- confirm against callers).
        env: Optional mapping of extra environment variables, merged over the
            template's "environment_variables". ``None`` (the default) adds none.

    Returns:
        The decoded JSON body of the API response. The HTTP status is not
        checked here, so the body of an error response is returned as-is.
    """
    # dict.copy() is shallow: the nested "container" dicts would still be the
    # template's own objects, so a deep copy is required before editing them.
    payload = copy.deepcopy(create_container_group_payload)
    payload["name"] = name
    container = payload["container"]
    container["image"] = image
    container["resources"]["gpu_classes"] = [gpu]
    if env:
        container["environment_variables"].update(env)
    url = f"{salad_api_base_url}/organizations/{salad_org_id}/projects/{salad_project_name}/containers"
    response = requests.post(url, headers=salad_headers, json=payload)
    return response.json()


# Fetch the filtered GPU class list from the API when this cell runs.
gpu_classes = get_gpu_classes()
