diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py
index 6b40c8aef..c042c530b 100644
--- a/src/aleph/vm/models.py
+++ b/src/aleph/vm/models.py
@@ -108,13 +108,16 @@ async def fetch_port_redirect_config_and_setup(self):
         try:
             port_forwarding_settings = await get_user_settings(message.address, "port-forwarding")
             vm_port_forwarding = port_forwarding_settings.get(self.vm_hash, {}) or {}
-            ports_requests = vm_port_forwarding.get("ports", {})
+            fetched_ports_requests = vm_port_forwarding.get("ports", {})
+            # Force port always to be int and save it as int
+            ports_requests = {int(key): value for key, value in fetched_ports_requests.items()}
+            # Always forward port 22
+            if not ports_requests.get(22, None):
+                ports_requests[22] = {"tcp": True, "udp": False}
+
         except Exception:
             logger.info("Could not fetch the port redirect settings for user %s", message.address, exc_info=True)
 
-        # Always forward port 22
-        ports_requests[22] = {"tcp": True, "udp": False}
-
         await self.update_port_redirects(ports_requests)
 
     async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]]):
@@ -130,9 +133,9 @@ async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]
             current = self.mapped_ports[vm_port]
             for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
                 if current[protocol]:
-                    host_port = current["host"]
+                    host_port = int(current["host"])
                     remove_port_redirect_rule(interface, host_port, vm_port, protocol)
-            del self.mapped_ports[vm_port]
+            del self.mapped_ports[int(vm_port)]
         for vm_port in redirect_to_add:
             target = requested_ports[vm_port]
             host_port = fast_get_available_host_port()
@@ -140,30 +143,29 @@ async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]
             for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
                 if target[protocol]:
                     add_port_redirect_rule(self.vm.vm_id, interface, host_port, vm_port, protocol)
-            self.mapped_ports[vm_port] = {"host": host_port, **target}
+            self.mapped_ports[int(vm_port)] = {"host": host_port, **target}
 
         for vm_port in redirect_to_check:
             current = self.mapped_ports[vm_port]
             target = requested_ports[vm_port]
-            host_port = current["host"]
+            host_port = int(current["host"])
             for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
                 if current[protocol] != target[protocol]:
                     if target[protocol]:
                         add_port_redirect_rule(self.vm.vm_id, interface, host_port, vm_port, protocol)
                     else:
                         remove_port_redirect_rule(interface, host_port, vm_port, protocol)
-            self.mapped_ports[vm_port] = {"host": host_port, **target}
+            self.mapped_ports[int(vm_port)] = {"host": host_port, **target}
 
         # Save to DB
-        if self.record:
-            self.record.mapped_ports = self.mapped_ports
-            await save_record(self.record)
+        await self.save()
 
     async def removed_all_ports_redirection(self):
         if not self.vm:
             return
         interface = self.vm.tap_interface
         # copy in a list since we modify dict during iteration
+        self.mapped_ports = {int(key): value for key, value in self.mapped_ports.items()}
         for vm_port, map_detail in list(self.mapped_ports.items()):
             host_port = map_detail["host"]
             for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
@@ -543,8 +545,10 @@ async def stop(self) -> None:
         self.times.stopping_at = datetime.now(tz=timezone.utc)
         await self.all_runs_complete()
         await self.record_usage()
-        await self.vm.teardown()
+        # First remove existing redirect rules for that VM
         await self.removed_all_ports_redirection()
+        # After do the teardown
+        await self.vm.teardown()
         self.times.stopped_at = datetime.now(tz=timezone.utc)
         self.cancel_expiration()
@@ -608,6 +612,7 @@ async def save(self):
             original_message=self.original.model_dump_json(),
             persistent=self.persistent,
             gpus=json.dumps(self.gpus, default=pydantic_encoder),
+            mapped_ports=self.mapped_ports,
         )
         pid_info = self.vm.to_dict() if self.vm else None
         # Handle cases when the process cannot be accessed
diff --git a/src/aleph/vm/network/firewall.py b/src/aleph/vm/network/firewall.py
index fc8babcda..7f9b99156 100644
--- a/src/aleph/vm/network/firewall.py
+++ b/src/aleph/vm/network/firewall.py
@@ -621,11 +621,11 @@ def add_port_redirect_rule(
                     "match": {
                         "op": "==",
                         "left": {"payload": {"protocol": protocol, "field": "dport"}},
-                        "right": host_port,
+                        "right": int(host_port),
                     }
                 },
                 {
-                    "dnat": {"addr": str(interface.guest_ip.ip), "port": vm_port},
+                    "dnat": {"addr": str(interface.guest_ip.ip), "port": int(vm_port)},
                 },
             ],
         }
@@ -648,7 +648,7 @@ def add_port_redirect_rule(
                     "match": {
                         "op": "==",
                         "left": {"payload": {"protocol": protocol, "field": "dport"}},
-                        "right": vm_port,
+                        "right": int(vm_port),
                     }
                 },
                 {"accept": None},
@@ -696,11 +696,12 @@ def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port:
                 and "match" in expr[1]
                 and expr[1]["match"]["left"].get("payload", {}).get("protocol") == protocol
                 and expr[1]["match"]["left"]["payload"].get("field") == "dport"
-                and expr[1]["match"]["right"] == host_port
+                and int(expr[1]["match"]["right"]) == int(host_port)
                 and "dnat" in expr[2]
                 and expr[2]["dnat"].get("addr") == str(interface.guest_ip.ip)
-                and expr[2]["dnat"].get("port") == vm_port
+                and int(expr[2]["dnat"].get("port")) == int(vm_port)
             ):
+                rule_handle = entry["rule"]["handle"]
                 commands.append(
                     {
                         "delete": {
@@ -708,7 +709,7 @@ def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port:
                                 "family": "ip",
                                 "table": prerouting_table,
                                 "chain": chain_name,
-                                "handle": entry["rule"]["handle"],
+                                "handle": rule_handle,
                             }
                         }
                     }
@@ -761,3 +762,127 @@ def check_nftables_redirections(port: int) -> bool:
     except Exception as e:
         logger.warning(f"Error checking NAT redirections: {e}")
         return False
+
+
+def get_all_aleph_chains() -> list[str]:
+    """Query nftables ruleset and return all chains created by aleph software.
+
+    This function scans the entire nftables ruleset and identifies all chains
+    whose names start with the configured NFTABLES_CHAIN_PREFIX. This includes
+    both supervisor chains (e.g., aleph-supervisor-nat, aleph-supervisor-filter,
+    aleph-supervisor-prerouting) and VM-specific chains (e.g., aleph-vm-nat-123,
+    aleph-vm-filter-123).
+
+    Returns:
+        A list of chain names that belong to aleph software
+
+    Raises:
+        Exception: If the nftables query fails
+    """
+    logger.debug("Querying nftables for all aleph-related chains")
+    nft_ruleset = get_existing_nftables_ruleset()
+    aleph_chains = []
+
+    for entry in nft_ruleset:
+        if isinstance(entry, dict) and "chain" in entry:
+            chain_name = entry["chain"].get("name", "")
+            # Find all chains created by aleph software
+            if chain_name.startswith(settings.NFTABLES_CHAIN_PREFIX):
+                aleph_chains.append(chain_name)
+                logger.debug(f"Found aleph chain: {chain_name}")
+
+    logger.info(f"Found {len(aleph_chains)} aleph-related chains")
+    return aleph_chains
+
+
+def remove_all_aleph_chains() -> tuple[list[str], list[tuple[str, str]]]:
+    """Remove all chains created by aleph software from the nftables ruleset.
+
+    This function queries the nftables ruleset to find all chains that start with
+    the configured NFTABLES_CHAIN_PREFIX, then attempts to remove each one. This
+    ensures a clean slate by removing both tracked and untracked chains that may
+    have been left behind due to software crashes or inconsistent state.
+
+    The function uses the remove_chain() helper which handles:
+    - Removing all rules that jump to the chain
+    - Removing the chain itself
+
+    Returns:
+        A tuple containing:
+        - List of successfully removed chain names
+        - List of tuples (chain_name, error_message) for failed removals
+
+    Example:
+        removed, failed = remove_all_aleph_chains()
+        if failed:
+            logger.warning(f"Failed to remove {len(failed)} chains")
+    """
+    logger.info("Removing all aleph-related chains from nftables")
+    aleph_chains = get_all_aleph_chains()
+
+    removed_chains = []
+    failed_chains = []
+
+    for chain_name in aleph_chains:
+        try:
+            remove_chain(chain_name)
+            removed_chains.append(chain_name)
+            logger.debug(f"Successfully removed chain: {chain_name}")
+        except Exception as e:
+            error_msg = str(e)
+            failed_chains.append((chain_name, error_msg))
+            logger.warning(f"Failed to remove chain {chain_name}: {error_msg}")
+
+    logger.info(f"Chain removal complete. Removed: {len(removed_chains)}, Failed: {len(failed_chains)}")
+    return removed_chains, failed_chains
+
+
+def recreate_network_for_vms(vm_configurations: list[dict]) -> tuple[list[str], list[dict]]:
+    """Recreate network rules for a list of VMs.
+
+    This function sets up nftables chains and rules for each VM in the provided list.
+    For each VM, it creates:
+    - NAT chain and masquerading rules for outbound traffic
+    - Filter chain and forwarding rules for traffic control
+    - Port forwarding rules if the VM is an instance (handled by caller)
+
+    Args:
+        vm_configurations: List of dictionaries, each containing:
+            - vm_id: Integer ID of the VM
+            - tap_interface: TapInterface object for the VM
+            - vm_hash: ItemHash of the VM (for logging)
+
+    Returns:
+        A tuple containing:
+        - List of successfully recreated VM hashes (as strings)
+        - List of dictionaries with failed VMs:
+          [{"vm_hash": str, "error": str}, ...]
+
+    Example:
+        vms = [
+            {"vm_id": 1, "tap_interface": tap1, "vm_hash": hash1},
+            {"vm_id": 2, "tap_interface": tap2, "vm_hash": hash2},
+        ]
+        recreated, failed = recreate_network_for_vms(vms)
+    """
+    logger.info(f"Recreating network rules for {len(vm_configurations)} VMs")
+    recreated_vms = []
+    failed_vms = []
+
+    for vm_config in vm_configurations:
+        vm_id = vm_config["vm_id"]
+        tap_interface = vm_config["tap_interface"]
+        vm_hash = vm_config["vm_hash"]
+
+        try:
+            # Recreate the basic VM network chains and rules
+            setup_nftables_for_vm(vm_id, tap_interface)
+            recreated_vms.append(str(vm_hash))
+            logger.debug(f"Recreated nftables for VM {vm_hash} (vm_id={vm_id})")
+        except Exception as e:
+            error_msg = str(e)
+            failed_vms.append({"vm_hash": str(vm_hash), "error": error_msg})
+            logger.error(f"Failed to recreate network for VM {vm_hash}: {error_msg}")
+
+    logger.info(f"VM network recreation complete. Success: {len(recreated_vms)}, Failed: {len(failed_vms)}")
+    return recreated_vms, failed_vms
diff --git a/src/aleph/vm/network/port_availability_checker.py b/src/aleph/vm/network/port_availability_checker.py
index 693d261d5..0b85c483a 100644
--- a/src/aleph/vm/network/port_availability_checker.py
+++ b/src/aleph/vm/network/port_availability_checker.py
@@ -63,4 +63,4 @@ def fast_get_available_host_port() -> int:
     LAST_ASSIGNED_HOST_PORT = host_port
     if LAST_ASSIGNED_HOST_PORT > MAX_PORT:
         LAST_ASSIGNED_HOST_PORT = MIN_DYNAMIC_PORT
-    return host_port
+    return int(host_port)
diff --git a/src/aleph/vm/orchestrator/supervisor.py b/src/aleph/vm/orchestrator/supervisor.py
index b8f3061f3..5937f91ab 100644
--- a/src/aleph/vm/orchestrator/supervisor.py
+++ b/src/aleph/vm/orchestrator/supervisor.py
@@ -38,6 +38,7 @@
     notify_allocation,
     operate_reserve_resources,
     operate_update,
+    recreate_network,
     run_code_from_hostname,
     run_code_from_path,
     status_check_fastapi,
@@ -164,6 +165,7 @@ def setup_webapp(pool: VmPool | None):
     other_routes = [
         # /control APIs are used to control the VMs and access their logs
        web.post("/control/allocations", update_allocations),
+        web.post("/control/network/recreate", recreate_network),
        # Raise an HTTP Error 404 if attempting to access an unknown URL within these paths.
        web.get("/about/{suffix:.*}", http_not_found),
        web.get("/control/{suffix:.*}", http_not_found),
diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py
index 978b57029..b880f8ada 100644
--- a/src/aleph/vm/orchestrator/views/__init__.py
+++ b/src/aleph/vm/orchestrator/views/__init__.py
@@ -27,6 +27,11 @@
 from aleph.vm.controllers.firecracker.program import FileTooLargeError
 from aleph.vm.hypervisors.firecracker.microvm import MicroVMFailedInitError
 from aleph.vm.models import VmExecution
+from aleph.vm.network.firewall import (
+    initialize_nftables,
+    recreate_network_for_vms,
+    remove_all_aleph_chains,
+)
 from aleph.vm.orchestrator import payment, status
 from aleph.vm.orchestrator.chain import STREAM_CHAINS
 from aleph.vm.orchestrator.custom_logs import set_vm_for_logging
@@ -429,6 +434,7 @@ def authenticate_api_request(request: web.Request) -> bool:
 
 
 allocation_lock = None
+network_recreation_lock = None
 
 
 async def update_allocations(request: web.Request):
@@ -547,6 +553,130 @@ async def update_allocations(request: web.Request):
     )
 
 
+async def recreate_network(request: web.Request):
+    """Recreate network settings for the CRN and all running VMs.
+
+    This endpoint performs a complete network reconfiguration by:
+    1. Querying the nftables ruleset to find all aleph-related chains
+    2. Removing ALL chains created by aleph software (both tracked and untracked)
+       including VM-specific chains and supervisor chains
+    3. Re-initializing the base network setup with nftables (creating fresh
+       supervisor chains: aleph-supervisor-nat, aleph-supervisor-filter,
+       aleph-supervisor-prerouting)
+    4. Recreating VM-specific chains and rules for each currently running VM
+    5. Restoring port forwarding rules for all running instances
+
+    This method is designed to handle cases where:
+    - Network rules have become duplicated or inconsistent
+    - Chains exist on the host that are no longer tracked by the software
+    - The firewall state needs to be reset to match the current VM pool
+
+    The operation is atomic and uses a lock to prevent concurrent modifications.
+
+    Returns:
+        JSON response with:
+        - success: Boolean indicating if all VMs were successfully recreated
+        - removed_chains_count: Number of chains that were removed
+        - removed_chains: List of chain names that were removed
+        - recreated_count: Number of VMs that were successfully recreated
+        - failed_count: Number of VMs that failed to recreate
+        - recreated_vms: List of VM hashes that were recreated
+        - failed_vms: List of VM hashes and errors for failed recreations
+    """
+    if not authenticate_api_request(request):
+        return web.HTTPUnauthorized(text="Authentication token received is invalid")
+
+    global network_recreation_lock
+    if network_recreation_lock is None:
+        network_recreation_lock = asyncio.Lock()
+
+    pool: VmPool = request.app["vm_pool"]
+
+    async with network_recreation_lock:
+        logger.info("Starting network recreation process")
+
+        # Step 1: Collect all running VMs and their network configuration
+        running_vms = []
+        for vm_hash, execution in pool.executions.items():
+            if execution.is_running and execution.vm and execution.vm.tap_interface:
+                running_vms.append(
+                    {
+                        "vm_hash": vm_hash,
+                        "vm_id": execution.vm.vm_id,
+                        "tap_interface": execution.vm.tap_interface,
+                        "execution": execution,
+                    }
+                )
+                logger.debug(f"Found running VM {vm_hash} with vm_id={execution.vm.vm_id}")
+
+        logger.info(f"Found {len(running_vms)} running VMs to recreate network rules for")
+
+        # Step 2: Remove all aleph-related chains (VM-specific and supervisor chains)
+        try:
+            removed_chains, failed_removals = remove_all_aleph_chains()
+            if failed_removals:
+                logger.warning(f"Failed to remove {len(failed_removals)} chains")
+                for chain_name, error in failed_removals:
+                    logger.warning(f" - {chain_name}: {error}")
+        except Exception as e:
+            logger.error(f"Error removing aleph chains: {e}")
+            return web.json_response(
+                {"success": False, "error": f"Failed to remove existing chains: {str(e)}"},
+                status=500,
+            )
+
+        # Step 3: Re-initialize the base network setup
+        logger.info("Re-initializing nftables")
+        try:
+            initialize_nftables()
+        except Exception as e:
+            logger.error(f"Error initializing nftables: {e}")
+            return web.json_response(
+                {"success": False, "error": f"Failed to initialize network: {str(e)}"},
+                status=500,
+            )
+
+        # Step 4: Recreate VM-specific chains and rules
+        try:
+            recreated_vms, failed_vms = recreate_network_for_vms(running_vms)
+        except Exception as e:
+            logger.error(f"Error recreating VM networks: {e}")
+            return web.json_response(
+                {"success": False, "error": f"Failed to recreate VM networks: {str(e)}"},
+                status=500,
+            )
+
+        # Step 5: Recreate port forwarding rules for instances
+        logger.info("Recreating port forwarding rules for instances")
+        for vm_info in running_vms:
+            execution = vm_info["execution"]
+            if execution.is_instance and str(vm_info["vm_hash"]) in recreated_vms:
+                try:
+                    await execution.fetch_port_redirect_config_and_setup()
+                    logger.debug(f"Recreated port redirects for instance {vm_info['vm_hash']}")
+                except Exception as e:
+                    logger.error(f"Error recreating port redirects for VM {vm_info['vm_hash']}: {e}")
+                    # Don't add to failed_vms as the VM network itself was created successfully
+
+        logger.info(
+            f"Network recreation complete. Removed chains: {len(removed_chains)}, "
+            f"Recreated VMs: {len(recreated_vms)}, Failed: {len(failed_vms)}"
+        )
+
+        return web.json_response(
+            {
+                "success": len(failed_vms) == 0,
+                "removed_chains_count": len(removed_chains),
+                "removed_chains": removed_chains,
+                "recreated_count": len(recreated_vms),
+                "failed_count": len(failed_vms),
+                "recreated_vms": recreated_vms,
+                "failed_vms": failed_vms,
+            },
+            status=200 if len(failed_vms) == 0 else 207,
+        )
+
+
 @cors_allow_all
 async def notify_allocation(request: web.Request):
     """Notify instance allocation, only used for Pay as you Go feature"""
diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py
index 0383e6703..adf367417 100644
--- a/src/aleph/vm/pool.py
+++ b/src/aleph/vm/pool.py
@@ -184,8 +184,9 @@ async def create_a_vm(
                 if resource in self.reservations:
                     del self.reservations[resource]
         except Exception:
-            # ensure the VM is removed from the pool on creation error
-            await execution.removed_all_ports_redirection()
+            if execution.is_instance:
+                # ensure the VM is removed from the pool on creation error
+                await execution.removed_all_ports_redirection()
             self.forget_vm(vm_hash)
             raise
@@ -290,9 +291,8 @@ async def load_persistent_executions(self):
             )
             mapped_ports = saved_execution.mapped_ports if saved_execution.mapped_ports else {}
 
-            # Ensure the key are int and not string. They get converted when serialized in the db
-            for k, v in mapped_ports.items():
-                execution.mapped_ports[int(k)] = v
+            execution.mapped_ports = {int(key): value for key, value in mapped_ports.items()}
+            logger.info("Loading existing mapped_ports %s", execution.mapped_ports)
 
             # Load and instantiate the rest of resources and already assigned GPUs
             await execution.prepare()
@@ -329,6 +329,8 @@
             # Start the snapshot manager for the VM
             if vm.support_snapshot and self.snapshot_manager:
                 await self.snapshot_manager.start_for(vm=execution.vm)
+
+            # Refresh port redirection changes
             await execution.fetch_port_redirect_config_and_setup()
             self.executions[vm_hash] = execution
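
A minimal client sketch for exercising the new POST /control/network/recreate endpoint introduced by this patch. The sketch is not part of the diff: the supervisor URL/port and the X-Auth-Signature allocation-token header are assumptions (the endpoint reuses authenticate_api_request, the same check as /control/allocations); adjust both to the target CRN deployment. The documented response fields (success, removed_chains_count, recreated_count, failed_count, failed_vms) come from the handler above.

import asyncio

import aiohttp

SUPERVISOR_URL = "http://localhost:4020"  # assumed supervisor address; deployment-specific
ALLOCATION_TOKEN = "change-me"  # assumed shared allocation token; deployment-specific


async def recreate_crn_network() -> None:
    # Header name is an assumption based on how /control/allocations is authenticated.
    headers = {"X-Auth-Signature": ALLOCATION_TOKEN}
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{SUPERVISOR_URL}/control/network/recreate", headers=headers) as resp:
            body = await resp.json()
            # HTTP 200: every running VM was recreated; HTTP 207: partial success.
            print(resp.status, "removed chains:", body["removed_chains_count"])
            print("recreated:", body["recreated_count"], "failed:", body["failed_count"])
            for failure in body.get("failed_vms", []):
                print("failed VM:", failure["vm_hash"], "-", failure["error"])


if __name__ == "__main__":
    asyncio.run(recreate_crn_network())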