diff --git a/pyproject.toml b/pyproject.toml index e92a16ef..0fe9a48b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "protobuf==6.31.1", "dijkstar==2.6.0", "huggingface-hub", - "lattica==1.0.0", + "lattica==1.0.1", ] [project.scripts] diff --git a/src/backend/server/scheduler_manage.py b/src/backend/server/scheduler_manage.py index da3f3d60..b98f7700 100644 --- a/src/backend/server/scheduler_manage.py +++ b/src/backend/server/scheduler_manage.py @@ -133,28 +133,39 @@ def _start_lattica(self): self.lattica = Lattica.builder().with_listen_addrs(self.host_maddrs).with_key_path(".") if len(self.relay_servers) > 0: - print(f"Using relay servers: {self.relay_servers}") - self.lattica.with_relay_servers(self.relay_servers).with_dcutr(True) + logger.info(f"Using relay servers: {self.relay_servers}") + self.lattica.with_relay_servers(self.relay_servers).with_dcutr(True).with_protocol("") if len(self.announce_maddrs) > 0: - print(f"Using announce maddrs: {self.announce_maddrs}") + logger.info(f"Using announce maddrs: {self.announce_maddrs}") self.lattica.with_external_addrs(self.announce_maddrs) if len(self.initial_peers) > 0: - print(f"Using initial peers: {self.initial_peers}") + logger.info(f"Using initial peers: {self.initial_peers}") self.lattica.with_bootstraps(self.initial_peers) self.lattica.build() logger.debug("Lattica node built") - if self.lattica.store( - "scheduler_peer_id", - self.lattica.peer_id(), - expiration_time=time.time() + 365 * 24 * 60 * 60, - ): - logger.info(f"Stored scheduler peer id: {self.lattica.peer_id()}") - else: - logger.error("Failed to store scheduler peer id") + store_success = False + for _ in range(10): + try: + if self.lattica.store( + "scheduler_peer_id", + self.lattica.peer_id(), + expiration_time=time.time() + 365 * 24 * 60 * 60, + ): + logger.info(f"Stored scheduler peer id: {self.lattica.peer_id()}") + store_success = True + break + logger.warning("Failed to store scheduler peer id, waiting for 10 seconds") + time.sleep(10) + except Exception as e: + logger.error(f"Failed to store scheduler peer id: {e}, waiting for 10 seconds") + time.sleep(10) + + if not store_success: + logger.error("Failed to store scheduler peer id, after 10 times") exit(1) self.connection_handler = RPCConnectionHandler( diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py index 508730ce..9102d572 100644 --- a/src/parallax/p2p/server.py +++ b/src/parallax/p2p/server.py @@ -180,9 +180,6 @@ def __init__( max_batch_size: Optional[int] = None, max_sequence_length: Optional[int] = None, ): - assert not ( - scheduler_addr is not None and len(initial_peers) > 0 - ), "scheduler_addr and initial_peers are not allowed at the same time" self.recv_from_peer_addr = recv_from_peer_addr self.send_to_peer_addr = send_to_peer_addr self.initial_peers = initial_peers @@ -220,9 +217,18 @@ def __init__( def build_lattica(self): self.lattica = Lattica.builder().with_listen_addrs(self.host_maddrs) + if self.scheduler_addr is not None and self.scheduler_addr != "auto": + if self.scheduler_addr.startswith("/"): + logger.info(f"Using scheduler addr: {self.scheduler_addr}") + self.lattica.with_bootstraps([self.scheduler_addr]) + self.scheduler_peer_id = self.scheduler_addr.split("/")[-1] + if len(self.relay_servers) > 0: logger.info(f"Using relay servers: {self.relay_servers}") self.lattica.with_relay_servers(self.relay_servers).with_dcutr(True) + if self.scheduler_peer_id is not None: + logger.info(f"Using protocol: /{self.scheduler_peer_id}") + self.lattica.with_protocol("/" + self.scheduler_peer_id) if len(self.announce_maddrs) > 0: logger.info(f"Using announce maddrs: {self.announce_maddrs}") @@ -232,11 +238,6 @@ def build_lattica(self): logger.info(f"Using initial peers: {self.initial_peers}") self.lattica.with_bootstraps(self.initial_peers) - if self.scheduler_addr is not None and self.scheduler_addr != "auto": - logger.info(f"Using scheduler addr: {self.scheduler_addr}") - self.lattica.with_bootstraps([self.scheduler_addr]) - self.scheduler_peer_id = self.scheduler_addr.split("/")[-1] - self.lattica.build() if self.scheduler_addr == "auto": @@ -272,7 +273,14 @@ def run(self): self.scheduler_stub = RPCConnectionHandler(self.lattica, None).get_stub( self.scheduler_peer_id ) - response = self.scheduler_stub.node_join(self.get_node_info()) + node_info = self.get_node_info() + if node_info == {}: + logger.error("Failed to get node info, try again after 10 seconds") + del self.lattica + self.lattica = None + time.sleep(10) + return self.run() + response = self.scheduler_stub.node_join(node_info) response = response.result(timeout=300) if response == {}: logger.error("Failed to join scheduler") @@ -508,7 +516,7 @@ def _announcer_thread(): while not self.stop_event.is_set(): # Announce the range ID try: - if self.scheduler_addr is not None: + if self.scheduler_peer_id is not None: self.scheduler_stub.node_update(self.get_node_info(is_update=True)) else: self.lattica.store( @@ -540,13 +548,18 @@ def get_node_info(self, is_update: bool = False): all_peers = [] for _ in range(1 if is_update else 30): all_peers = self.lattica.get_all_peers() - if len(all_peers) > 0: + if len(all_peers) > 0 and self.scheduler_peer_id in all_peers: break - logger.warning("No peers found, waiting for 1 second.") + logger.warning( + "No peers found or scheduler peer id not found, waiting for 1 second." + ) time.sleep(1) - if len(all_peers) == 0: - logger.warning("No peers found, send empty rtt_to_nodes.") + if len(all_peers) == 0 or self.scheduler_peer_id not in all_peers: + logger.warning( + "No peers found or scheduler peer id not found, return empty node info." + ) + return {} for peer_id in all_peers: rtt = None