In [None]:
#| default_exp nostr

# nostr

> Key adjustments to python-nostr so that it works properly for rebroadcastr

This repository relies heavily on [python-nostr](https://github.com/jeffthibault/python-nostr), which is still in active development. There are a few features that are not yet integrated into the python-nostr library, so I add them here. The classes defined here will be used in the rebroadcastr client instead of the corresponding classes from python-nostr.

In [None]:
#| export
import warnings

from nostr import key
from nostr import bech32
from fastcore.utils import *

In [None]:
#| hide
from nbdev.showdoc import *

## PrivateKey class
Adds a `from_hex` constructor.

In [None]:
#| export

class PrivateKey(key.PrivateKey):
    """`nostr.key.PrivateKey` extended with a constructor from a hex string."""

    def __init__(self, *args, **kwargs):
        # Pure pass-through: construction is delegated unchanged to python-nostr.
        super().__init__(*args, **kwargs)

    @classmethod
    def from_hex(cls, hex: str) -> 'PrivateKey':
        """Build a private key from the hex encoding of its raw secret bytes."""
        raw_secret = bytes.fromhex(hex)
        return cls(raw_secret)

### tests
make sure we can generate private keys in various ways

In [None]:
private_key = PrivateKey()

# Round-trip through hex and through bech32 (nsec); both must reproduce the key.
hex_roundtrip = PrivateKey.from_hex(private_key.hex())
nsec_roundtrip = PrivateKey.from_nsec(private_key.bech32())
assert hex_roundtrip.hex() == private_key.hex()
assert nsec_roundtrip.bech32() == private_key.bech32()

## PublicKey class
Adds `from_hex` and `from_npub` constructors.

In [None]:
#| export
from nostr import bech32


class PublicKey(key.PublicKey):
    """`nostr.key.PublicKey` extended with hex and bech32 (npub) constructors."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def from_npub(cls, npub: str) -> 'PublicKey':
        """Load a PublicKey from its bech32 ``npub`` form.

        Args:
            npub: bech32-encoded public key whose human-readable part is ``npub``.

        Returns:
            PublicKey: key built from the decoded raw bytes.

        Raises:
            ValueError: if ``npub`` is not valid bech32, or its human-readable
                part is not ``npub`` (e.g. an ``nsec`` private key was passed).
        """
        hrp, data, _spec = bech32.bech32_decode(npub)
        if data is None:
            # Do not echo the input: it might be a mistyped secret key.
            raise ValueError('invalid bech32 string')
        if hrp != 'npub':
            # Refuse other prefixes; decoding an nsec here would silently turn
            # secret key bytes into a "public" key.
            raise ValueError(f"expected an 'npub' key, got prefix {hrp!r}")
        # Converting the 5-bit groups back to bytes with padding yields one
        # trailing padding byte for a 32-byte key; drop it.
        raw_bytes = bech32.convertbits(data, 5, 8)[:-1]
        return cls(bytes(raw_bytes))

    @classmethod
    def from_hex(cls, hex: str) -> 'PublicKey':
        """Build a public key from the hex encoding of its raw bytes."""
        return cls(bytes.fromhex(hex))

### tests
make sure we can generate public keys in various ways

In [None]:
# Use the public key derived from `private_key`, not its raw secret bytes;
# wrapping the secret in a PublicKey would test (and expose) the wrong data.
public_key = PublicKey(private_key.public_key.raw_bytes)

# Round-trip through hex and through bech32 (npub); both must reproduce the key.
assert public_key.hex() == PublicKey.from_hex(public_key.hex()).hex()
assert public_key.bech32() == PublicKey.from_npub(public_key.bech32()).bech32()

## relay class changes

In [None]:
#| export
import json
import time
import threading
from threading import Lock
from typing import Union
from queue import Queue
from nostr import message_pool
from nostr import relay, relay_manager
from nostr.relay import RelayPolicy
from nostr.message_pool import EventMessage, NoticeMessage, EndOfStoredEventsMessage
from nostr.message_type import RelayMessageType
from nostr.event import Event

We have to make a critical change to the `MessagePool` class so that it can keep multiple copies of an event when they come from different relays. The message pool can now store every unique event/url combination instead of only every unique event id.

This is controlled by a `first_response_only` argument that defaults to True. In `rebroadcastr` we will mostly set it to False so that we can check for a message on multiple relays.

In [None]:
#| export

class MessagePool(relay_manager.MessagePool):
    """Message pool that can optionally keep one copy of an event per relay.

    Args:
        first_response_only: if True (default), an event is queued only the first
            time its id is seen, from any relay. If False, an event is queued once
            per relay URL it arrives from, so the same event received from several
            relays produces several `EventMessage`s.
    """

    def __init__(self, first_response_only: bool = True) -> None:
        self.first_response_only = first_response_only
        self.events: Queue[EventMessage] = Queue()
        self.notices: Queue[NoticeMessage] = Queue()
        self.eose_notices: Queue[EndOfStoredEventsMessage] = Queue()
        # Dedup keys already queued: event ids, or "id:url" strings when
        # first_response_only is False.
        self._unique_objects: set = set()
        self.lock: Lock = Lock()

    def _process_message(self, message: str, url: str):
        """Parse a raw relay message and put it on the matching queue.

        EVENT messages are deduplicated (see class docs); NOTICE and EOSE
        messages are always queued. Other message types are ignored.

        Args:
            message: raw JSON text received from a relay.
            url: relay URL attached to the queued message and, when
                ``first_response_only`` is False, part of the dedup key.
        """
        message_json = json.loads(message)
        message_type = message_json[0]
        if message_type == RelayMessageType.EVENT:
            subscription_id = message_json[1]
            e = message_json[2]
            event = Event(e['pubkey'], e['content'], e['created_at'], e['kind'], e['tags'], e['id'], e['sig'])
            object_id = event.id if self.first_response_only else f'{event.id}:{url}'
            with self.lock:
                if object_id not in self._unique_objects:
                    self.events.put(EventMessage(event, subscription_id, url))
                    # Record exactly the key that was checked above.
                    self._unique_objects.add(object_id)
        elif message_type == RelayMessageType.NOTICE:
            self.notices.put(NoticeMessage(message_json[1], url))
        elif message_type == RelayMessageType.END_OF_STORED_EVENTS:
            self.eose_notices.put(EndOfStoredEventsMessage(message_json[1], url))


We add a context manager for connections to both the `Relay` and `RelayManager` classes. The `Relay` class also gets a new method that opens a single relay connection. The `connect` method has to run in a separate thread so that the Python script can continue, even when we are only connecting to a single relay. An `is_connected` property also lets us easily check the connection status.

In [None]:
#| export

class Connection:
    """Context manager that opens connections on construction and closes them on exit.

    Works with any object exposing ``open_connections(*args, **kwargs)`` and
    ``close_connections()``, here a `Relay` or a `RelayManager`.

    Args:
        relay_or_manager: object whose connections are managed.
        *args, **kwargs: forwarded to ``open_connections``.
    """
    def __init__(self, relay_or_manager: Union[relay.Relay, relay_manager.RelayManager],
                 *args, **kwargs):
        # Attribute name kept for compatibility even though it may hold a single Relay.
        self.relay_manager = relay_or_manager
        try:
            self.conn = self.relay_manager.open_connections(*args, **kwargs)
        except BaseException:
            # Opening can fail after some connection threads were already started
            # (e.g. an assertion on connection status). __exit__ will never run
            # in that case, so tear down here on a best-effort basis and
            # re-raise the original error.
            try:
                self.relay_manager.close_connections()
            except Exception:
                pass
            raise

    def __enter__(self):
        """Return what ``open_connections`` returned or, if that was None, the managed object."""
        return self.relay_manager if self.conn is None else self.conn

    def __exit__(self, exc_type, exc_value, traceback):
        """Close all connections; exceptions raised in the ``with`` body propagate."""
        self.relay_manager.close_connections()


class Relay(relay.Relay):
    """`nostr.relay.Relay` with connection helpers.

    Adds an `is_connected` property, `open_connections`/`close_connections`
    methods with the same names as on `RelayManager`, and `connection()` for
    use in a ``with`` block.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __repr__(self):
        return json.dumps(self.to_json_object(), indent=2)

    @property
    def is_connected(self) -> bool:
        """True if the underlying websocket has a socket that reports being connected."""
        return False if self.ws.sock is None else self.ws.sock.connected

    def open_connections(self, ssl_options: dict = None):
        """Run ``self.connect`` in a background thread and return immediately.

        The connection is not guaranteed to be established when this returns;
        poll `is_connected` if you need to wait for it.

        Args:
            ssl_options: forwarded to ``connect``; a fresh empty dict when omitted.
        """
        # Fresh dict per call instead of a mutable default shared across calls.
        if ssl_options is None:
            ssl_options = {}
        threading.Thread(
                target=self.connect,
                args=(ssl_options,),
                name=f"{self.url}-thread"
        ).start()

    def close_connections(self):
        """Close this relay's connection."""
        self.close()

    def connection(self, *args, **kwargs):
        """Return a `Connection` context manager; arguments go to `open_connections`."""
        return Connection(self, *args, **kwargs)

### tests

In [None]:
import ssl

In [None]:
a_relay = Relay(url='wss://relay.nostr.ch',
                policy=RelayPolicy(),
                message_pool=MessagePool()
                )

assert not a_relay.is_connected
with a_relay.connection():
    # The relay connects on a background thread; poll with a deadline instead
    # of assuming a single fixed sleep is always long enough.
    deadline = time.monotonic() + 10
    while not a_relay.is_connected and time.monotonic() < deadline:
        time.sleep(0.1)
    assert a_relay.is_connected
assert not a_relay.is_connected

## RelayManager class changes

In [None]:
#| export

from nostr import relay_manager
from nostr.relay import RelayPolicy

The `RelayManager` class gets a connection status attribute and an `__iter__` method, and it now references our updated `Relay` class.

In [None]:
#| export

class RelayManager(relay_manager.RelayManager):
    """`nostr.relay_manager.RelayManager` using this module's `Relay` and `MessagePool`.

    Adds several features over the base class:
    - iteration over the managed relays
    - a ``connection()`` context manager
    - per-relay connection statuses
    - removal of relays that fail to connect

    Args:
        first_response_only: forwarded to `MessagePool`.
        *args, **kwargs: forwarded to the python-nostr base class.
    """
    def __init__(self, first_response_only: bool = True,  *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.relays: dict[str, Relay] = {}
        self.message_pool = MessagePool(first_response_only=first_response_only)
        self._is_connected = False

    def __iter__(self):
        """Iterate over the managed `Relay` objects."""
        return iter(self.relays.values())

    def connection(self, *args, **kwargs):
        """Return a `Connection` context manager; arguments go to `open_connections`."""
        return Connection(self, *args, **kwargs)

    def open_connections(self, ssl_options: dict=None):
        """Connect every relay in its own thread, then drop relays that failed.

        Waits a fixed 1 second for connections to come up. Relays still not
        connected after that are removed with a warning.
        """
        for relay in self.relays.values():
            threading.Thread(
                target=relay.connect,
                args=(ssl_options,),
                name=f"{relay.url}-thread"
            ).start()
        time.sleep(1)
        self.remove_closed_relays()
        assert all(self.connection_statuses.values())
        self._is_connected = True

    def close_connections(self):
        """Close every currently connected relay."""
        for relay in self.relays.values():
            if relay.is_connected:
                relay.close()
        assert not any(self.connection_statuses.values())
        self._is_connected = False

    def remove_closed_relays(self):
        """Remove, with a warning, every relay that is not currently connected."""
        # connection_statuses builds a new dict, so removing relays while
        # iterating over it is safe.
        for url, connected in self.connection_statuses.items():
            if not connected:
                warnings.warn(
                    f'{url} is not connected... removing relay.'
                )
                self.remove_relay(url=url)

    def add_relay(self, url: str, read: bool=True, write: bool=True, subscriptions=None):
        """Add a relay at ``url`` with the given read/write policy.

        Args:
            subscriptions: initial subscriptions dict for the relay. When omitted,
                a new empty dict is created so relays never share one.
        """
        if subscriptions is None:
            subscriptions = {}
        policy = RelayPolicy(read, write)
        relay = Relay(url, policy, self.message_pool, subscriptions)
        self.relays[url] = relay

    def remove_relay(self, url: str):
        """Close the relay at ``url`` if it is connected, then stop managing it.

        Raises:
            KeyError: if ``url`` is not a managed relay.
        """
        if self.relays[url].is_connected:
            self.relays[url].close()
        # NOTE(review): a relay that is still connecting is removed without being
        # closed and may finish connecting later, untracked. Consider closing it
        # unconditionally if Relay.close() is confirmed safe on unconnected relays.
        self.relays.pop(url)

    @property
    def connection_statuses(self) -> dict:
        """Map each relay URL to whether that relay is currently connected.

        Returns:
            dict: ``{url: bool}`` of connection statuses
        """
        return {url: relay.is_connected for url, relay in self.relays.items()}

In [None]:
urls = ['wss://relay.nostr.ch', 'wss://relay.damus.io']
manager = RelayManager()
for url in urls:
    manager.add_relay(url=url)

# While the connection context is active, every remaining relay should be connected...
with manager.connection():
    print('opening')
    time.sleep(1)
    print(manager.connection_statuses)
    assert all(manager.connection_statuses.values())
    print('closing')
# ...and once it exits, none should be.
print(manager.connection_statuses)
assert not any(manager.connection_statuses.values())


opening
{'wss://relay.nostr.ch': True, 'wss://relay.damus.io': True}
closing
{'wss://relay.nostr.ch': False, 'wss://relay.damus.io': False}


In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # write the `#| export` cells of this notebook out to the Python module