Skip to content

Commit

Permalink
signed connections (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
earonesty committed Oct 23, 2023
1 parent d9d0e98 commit 4eaa02f
Show file tree
Hide file tree
Showing 5 changed files with 368 additions and 12 deletions.
133 changes: 133 additions & 0 deletions ai_worker/key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import base64
import os
from hashlib import sha256
from typing import Optional, Union

import coincurve as secp256k1

"""Minimalist pub/priv key classes for signing and verification based on coincurve"""


class PublicKey:
def __init__(self,
raw_bytes: Union[bytes, "PrivateKey", secp256k1.keys.PublicKey, secp256k1.keys.PublicKeyXOnly, str]):
"""
:param raw_bytes: The formatted public key.
:type data: bytes, private key to copy, or b64 str
"""
if isinstance(raw_bytes, PrivateKey):
self.raw_bytes = raw_bytes.public_key.raw_bytes
elif isinstance(raw_bytes, secp256k1.keys.PublicKey):
self.raw_bytes = raw_bytes.format(compressed=True)[2:]
elif isinstance(raw_bytes, secp256k1.keys.PublicKeyXOnly):
self.raw_bytes = raw_bytes.format()
elif isinstance(raw_bytes, str):
self.raw_bytes = base64.urlsafe_b64decode(raw_bytes)
else:
self.raw_bytes = raw_bytes

def to_b64(self) -> str:
return base64.urlsafe_b64encode(self.raw_bytes).decode()

def verify(self, sig: str, message: bytes) -> bool:
pk = secp256k1.PublicKeyXOnly(self.raw_bytes)
return pk.verify(base64.urlsafe_b64decode(sig), message)

@classmethod
def from_b64(cls, b64: str, /) -> 'PublicKey':
return cls(base64.urlsafe_b64decode(b64))

def __repr__(self):
pubkey = self.to_b64()
return f'PublicKey({pubkey[:10]}...{pubkey[-10:]})'

def __eq__(self, other):
return isinstance(other, PublicKey) and self.raw_bytes == other.raw_bytes

def __hash__(self):
return hash(self.raw_bytes)

def __str__(self):
"""Return public key in b64 form
:return: string
:rtype: str
"""
return self.to_b64()

def __bytes__(self):
"""Return raw bytes
:return: Raw bytes
:rtype: bytes
"""
return self.raw_bytes


class PrivateKey:
def __init__(self, raw_secret: Optional[bytes] = None) -> None:
"""
:param raw_secret: The secret used to initialize the private key.
If not provided or `None`, a new key will be generated.
:type raw_secret: bytes
"""
if raw_secret is not None:
self.raw_secret = raw_secret
else:
self.raw_secret = os.urandom(32)

sk = secp256k1.PrivateKey(self.raw_secret)
self.public_key = PublicKey(sk.public_key_xonly)

@classmethod
def from_b64(cls, b64: str, /):
"""Load a PrivateKey from its b64 form."""
return cls(base64.urlsafe_b64decode(b64))

def __hash__(self):
return hash(self.raw_secret)

def __eq__(self, other):
return isinstance(other, PrivateKey) and self.raw_secret == other.raw_secret

def to_b64(self) -> str:
return base64.b64encode(self.raw_secret).decode()

def sign(self, message: bytes, aux_randomness: bytes = b'') -> str:
sk = secp256k1.PrivateKey(self.raw_secret)
return base64.urlsafe_b64encode(sk.sign_schnorr(message, aux_randomness)).decode()

def __repr__(self):
pubkey = base64.urlsafe_b64encode(public_key).decode()
return f'PrivateKey({pubkey[:10]}...{pubkey[-10:]})'

def __str__(self):
"""Return private key in b64 form
:return: b64 string
:rtype: str
"""
return self.to_b64()

def __bytes__(self):
"""Return raw bytes
:return: Raw bytes
:rtype: bytes
"""
return self.raw_secret


def test_cp():
pk = PrivateKey()
pk2 = PrivateKey(pk.raw_secret)
assert pk == pk2


def test_fromb64():
pk = PrivateKey()
pk2 = PrivateKey.from_b64(pk.to_b64())
assert pk == pk2


def test_sig():
pk = PrivateKey()
pub = pk.public_key
sig = pk.sign(b'1' * 32)
assert pub.verify(sig, b'1' * 32)
60 changes: 53 additions & 7 deletions ai_worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import platform
import sys
import time
from hashlib import sha256, md5
from pprint import pprint
from typing import Optional, List

from base64 import urlsafe_b64encode as b64encode, urlsafe_b64decode as b64decode
import psutil
import websockets
from httpx import Response, AsyncClient
Expand All @@ -27,6 +28,7 @@
from gguf_loader.main import get_size

from .gguf_reader import GGUFReader
from .key import PrivateKey
from .version import VERSION

APP_NAME = "gputopia"
Expand Down Expand Up @@ -54,10 +56,12 @@ class GpuInfo(BaseModel):

class ConnectMessage(BaseModel):
worker_version: str
worker_id: str
ln_url: str # sent for back compat. will drop this eventually
pubkey: str
slug: str = ""
sig: str = ""
ln_url: str # sent for back compat. will drop this eventually
ln_address: str
auth_key: str
auth_key: str # user private auth token for queenbee
cpu_count: int
disk_space: int
vram: int
Expand All @@ -74,7 +78,6 @@ class Config(BaseSettings):
env_prefix=ENV_PREFIX, case_sensitive=False)
auth_key: str = Field('', description="authentication key for a user account")
queen_url: str = Field(DEFAULT_COORDINATOR, description="websocket url of the coordinator")
worker_id: str = Field('', description="unique private worker id. autogenerated by default.")
ln_address: str = Field('DONT_PAY_ME', description="a lightning address")
loops: int = Field(0, description="quit after getting this number of jobs")
debug: bool = Field(False, description="verbose debugging info")
Expand All @@ -85,7 +88,8 @@ class Config(BaseSettings):
tensor_split: str = Field("", description="comma-delimited list of ratio numbers, one for each gpu")
force_layers: int = Field(0, description="force layers to load in the gpu")
layer_offset: int = Field(2, description="reduce the layer guess by this")

config: str = Field(os.path.expanduser("~/.config/gputopia"), description="config file location")
privkey: str = Field("", description=argparse.SUPPRESS, exclude=True)

def get_free_space_mb(dirname):
"""Return folder/drive free space (in megabytes)."""
Expand All @@ -103,11 +107,40 @@ class WorkerMain:
def __init__(self, conf: Config):
self.__connect_info: Optional[ConnectMessage] = None
self.conf = conf
self._gen_or_load_priv()
self.__sk = PrivateKey(b64decode(self.conf.privkey))
self.pubkey = self.__sk.public_key.to_b64()
if self.conf.main_gpu or self.conf.tensor_split:
self.slug = b64encode(md5((str(self.conf.main_gpu) + self.conf.tensor_split).encode()).digest()).decode()
else:
self.slug = ""
self.stopped = False
self.llama = None
self.llama_model = None
self.llama_cli: Optional[AsyncClient] = None

def _gen_or_load_priv(self) -> None:
if not self.conf.privkey:
cfg = self.conf.config
if os.path.exists(cfg):
with open(cfg, encoding="utf8") as fh:
js = json.load(fh)
else:
js = {}
if not js.get("privkey"):
js["privkey"] = b64encode(os.urandom(32)).decode()
with open(cfg, "w", encoding="utf8") as fh:
json.dump(js, fh, indent=4)
self.conf.privkey = js["privkey"]

def sign(self, msg: ConnectMessage):
js = msg.model_dump(mode="json")
js.pop("sig", None)
# this is needed for a consistent dump!
dump = json.dumps(js, separators=(",", ":"), sort_keys=True, ensure_ascii=False)
h32 = sha256(dump.encode()).digest()
msg.sig = self.__sk.sign(h32)

async def test_model(self):
pprint(self.connect_info().model_dump())
start = time.monotonic()
Expand Down Expand Up @@ -207,7 +240,8 @@ def _get_connect_info(self) -> ConnectMessage:

connect_msg = ConnectMessage(
worker_version=VERSION,
worker_id=self.conf.worker_id,
pubkey=self.pubkey,
slug=self.slug,
ln_url=self.conf.ln_address, # todo: remove eventually
ln_address=self.conf.ln_address,
auth_key=self.conf.auth_key,
Expand All @@ -216,6 +250,8 @@ def _get_connect_info(self) -> ConnectMessage:
vram=psutil.virtual_memory().available,
)

self.sign(connect_msg)

try:
nv = nvidia_smi.getInstance()
dq = nv.DeviceQuery()
Expand Down Expand Up @@ -353,6 +389,8 @@ def main(argv=None):
help=description,
action="store_true" if field.annotation is bool else "store",
)
if field.default:
args["default"] = field.default
if field.annotation is bool:
args.pop("type")
arg_names.append(name)
Expand All @@ -365,6 +403,13 @@ def main(argv=None):

args = parser.parse_args(args=argv)

if os.path.exists(args.config):
with open(args.config, "r", encoding="utf8") as fh:
for k, v in json.load(fh).items():
cv = getattr(args, k)
if cv is None or cv == Config.model_fields[k].default:
setattr(args, k, v)

if args.debug:
log.setLevel(logging.DEBUG)

Expand All @@ -380,6 +425,7 @@ def main(argv=None):

conf = Config(**{k: getattr(args, k) for k in arg_names if getattr(args, k) is not None})

log.debug("config: %s", conf)
wm = WorkerMain(conf)

asyncio.run(wm.run())
Loading

0 comments on commit 4eaa02f

Please sign in to comment.