Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions src/aleph/vm/network/firewall.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,3 +762,127 @@ def check_nftables_redirections(port: int) -> bool:
except Exception as e:
logger.warning(f"Error checking NAT redirections: {e}")
return False


def get_all_aleph_chains() -> list[str]:
    """Collect the names of all nftables chains that carry the aleph prefix.

    The ruleset obtained from ``get_existing_nftables_ruleset()`` is walked
    entry by entry. Every dict entry that holds a ``"chain"`` object
    contributes its ``"name"`` when that name begins with
    ``settings.NFTABLES_CHAIN_PREFIX``. This is meant to cover supervisor
    chains (e.g. aleph-supervisor-nat, aleph-supervisor-filter,
    aleph-supervisor-prerouting) as well as per-VM chains
    (e.g. aleph-vm-nat-123, aleph-vm-filter-123).

    Returns:
        The matching chain names, in ruleset order.

    Raises:
        Exception: Anything raised while querying the nftables ruleset is
            propagated unchanged.
    """
    logger.debug("Querying nftables for all aleph-related chains")
    matching: list[str] = []

    for item in get_existing_nftables_ruleset():
        # Only dict entries describing a chain are of interest.
        if not isinstance(item, dict) or "chain" not in item:
            continue

        name = item["chain"].get("name", "")
        if not name.startswith(settings.NFTABLES_CHAIN_PREFIX):
            continue

        matching.append(name)
        logger.debug(f"Found aleph chain: {name}")

    logger.info(f"Found {len(matching)} aleph-related chains")
    return matching


def remove_all_aleph_chains() -> tuple[list[str], list[tuple[str, str]]]:
    """Delete every aleph-prefixed chain currently present in nftables.

    The chain list comes from ``get_all_aleph_chains()``, so chains that are
    present on the host but no longer tracked by the software (for example
    after a crash) are included as well. Each chain is handed to
    ``remove_chain()``; a failure on one chain is recorded and does not stop
    the processing of the remaining chains.

    NOTE(review): ``remove_chain()`` is defined outside this view; it is
    expected to also drop the rules that jump to the chain - confirm there.

    Returns:
        A ``(removed, failed)`` tuple where ``removed`` lists the chain names
        that were deleted and ``failed`` lists ``(chain_name, error_message)``
        pairs for the chains whose removal raised.

    Example:
        removed, failed = remove_all_aleph_chains()
        if failed:
            logger.warning(f"Failed to remove {len(failed)} chains")
    """
    logger.info("Removing all aleph-related chains from nftables")

    targets = get_all_aleph_chains()
    succeeded: list[str] = []
    errors: list[tuple[str, str]] = []

    for name in targets:
        try:
            remove_chain(name)
        except Exception as exc:
            # Best effort: remember the failure and move on to the next chain.
            reason = str(exc)
            errors.append((name, reason))
            logger.warning(f"Failed to remove chain {name}: {reason}")
        else:
            succeeded.append(name)
            logger.debug(f"Successfully removed chain: {name}")

    logger.info(f"Chain removal complete. Removed: {len(succeeded)}, Failed: {len(errors)}")
    return succeeded, errors


def recreate_network_for_vms(vm_configurations: list[dict]) -> tuple[list[str], list[dict]]:
    """Set up the nftables chains and rules again for each given VM.

    ``setup_nftables_for_vm(vm_id, tap_interface)`` is called once per entry.
    A failure for one VM is logged and recorded, and the remaining VMs are
    still processed. Port forwarding is not handled here; the caller is
    responsible for restoring it.

    Args:
        vm_configurations: One dict per VM with the keys
            - ``vm_id``: ID of the VM
            - ``tap_interface``: tap interface object of the VM
            - ``vm_hash``: hash of the VM (used for logging and reporting)

    Returns:
        A ``(recreated, failed)`` tuple where ``recreated`` is the list of VM
        hashes (as strings) that were set up and ``failed`` is a list of
        ``{"vm_hash": str, "error": str}`` dicts.

    Raises:
        KeyError: If a configuration dict lacks one of the required keys.

    Example:
        vms = [
            {"vm_id": 1, "tap_interface": tap1, "vm_hash": hash1},
            {"vm_id": 2, "tap_interface": tap2, "vm_hash": hash2},
        ]
        recreated, failed = recreate_network_for_vms(vms)
    """
    logger.info(f"Recreating network rules for {len(vm_configurations)} VMs")
    succeeded: list[str] = []
    failures: list[dict] = []

    for config in vm_configurations:
        vm_id, tap_interface, vm_hash = (
            config["vm_id"],
            config["tap_interface"],
            config["vm_hash"],
        )

        try:
            # Rebuild the basic per-VM chains and rules.
            setup_nftables_for_vm(vm_id, tap_interface)
        except Exception as exc:
            reason = str(exc)
            failures.append({"vm_hash": str(vm_hash), "error": reason})
            logger.error(f"Failed to recreate network for VM {vm_hash}: {reason}")
        else:
            succeeded.append(str(vm_hash))
            logger.debug(f"Recreated nftables for VM {vm_hash} (vm_id={vm_id})")

    logger.info(f"VM network recreation complete. Success: {len(succeeded)}, Failed: {len(failures)}")
    return succeeded, failures
2 changes: 2 additions & 0 deletions src/aleph/vm/orchestrator/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
notify_allocation,
operate_reserve_resources,
operate_update,
recreate_network,
run_code_from_hostname,
run_code_from_path,
status_check_fastapi,
Expand Down Expand Up @@ -164,6 +165,7 @@ def setup_webapp(pool: VmPool | None):
other_routes = [
# /control APIs are used to control the VMs and access their logs
web.post("/control/allocations", update_allocations),
web.post("/control/network/recreate", recreate_network),
# Raise an HTTP Error 404 if attempting to access an unknown URL within these paths.
web.get("/about/{suffix:.*}", http_not_found),
web.get("/control/{suffix:.*}", http_not_found),
Expand Down
130 changes: 130 additions & 0 deletions src/aleph/vm/orchestrator/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
from aleph.vm.controllers.firecracker.program import FileTooLargeError
from aleph.vm.hypervisors.firecracker.microvm import MicroVMFailedInitError
from aleph.vm.models import VmExecution
from aleph.vm.network.firewall import (
initialize_nftables,
recreate_network_for_vms,
remove_all_aleph_chains,
)
from aleph.vm.orchestrator import payment, status
from aleph.vm.orchestrator.chain import STREAM_CHAINS
from aleph.vm.orchestrator.custom_logs import set_vm_for_logging
Expand Down Expand Up @@ -429,6 +434,7 @@ def authenticate_api_request(request: web.Request) -> bool:


# Module-level lock slots; both start as None at import time.
# NOTE(review): allocation_lock is presumably filled in lazily by
# update_allocations() - its use is not visible in this view; confirm.
allocation_lock = None
# Replaced by an asyncio.Lock on the first recreate_network() request so that
# network recreation requests are serialised.
network_recreation_lock = None


async def update_allocations(request: web.Request):
Expand Down Expand Up @@ -547,6 +553,130 @@ async def update_allocations(request: web.Request):
)


async def recreate_network(request: web.Request):
    """Recreate network settings for the CRN and all running VMs.

    While holding a module-level asyncio lock (so that two recreate requests
    never interleave), the handler:

    1. Collects every execution of the pool that is running and has a VM with
       a tap interface.
    2. Removes ALL aleph-prefixed nftables chains (tracked and untracked,
       VM-specific and supervisor chains).
    3. Re-initializes the base nftables setup via ``initialize_nftables()``.
    4. Recreates the VM-specific chains and rules for the collected VMs.
    5. Restores port forwarding for every instance whose network was
       recreated in step 4.

    The sequence is serialised but NOT transactional: the old chains are
    deleted before the new ones are built and nothing is rolled back, so a
    failure in step 3 or 4 returns an error while the firewall is only
    partially configured.

    Returns:
        401 if the request is not authenticated.
        500 with ``{"success": False, "error": str}`` if step 2, 3 or 4
        raises.
        Otherwise a JSON response (status 200 when no VM failed, 207 when at
        least one did) with:
        - success: True when every VM network was recreated
        - removed_chains_count: number of chains that were removed
        - removed_chains: names of the removed chains
        - failed_chain_removals: ``[{"chain": str, "error": str}, ...]`` for
          chains that could not be removed
        - recreated_count: number of VMs that were recreated
        - failed_count: number of VMs that failed to recreate
        - recreated_vms: hashes of the recreated VMs
        - failed_vms: ``[{"vm_hash": str, "error": str}, ...]``
        - failed_port_redirects: ``[{"vm_hash": str, "error": str}, ...]``
          for instances whose port forwarding could not be restored; these do
          not affect ``success`` because the VM network itself exists
    """
    if not authenticate_api_request(request):
        return web.HTTPUnauthorized(text="Authentication token received is invalid")

    global network_recreation_lock
    if network_recreation_lock is None:
        network_recreation_lock = asyncio.Lock()

    pool: VmPool = request.app["vm_pool"]

    async with network_recreation_lock:
        logger.info("Starting network recreation process")

        # Step 1: Collect all running VMs and their network configuration
        running_vms = []
        for vm_hash, execution in pool.executions.items():
            if execution.is_running and execution.vm and execution.vm.tap_interface:
                running_vms.append(
                    {
                        "vm_hash": vm_hash,
                        "vm_id": execution.vm.vm_id,
                        "tap_interface": execution.vm.tap_interface,
                        "execution": execution,
                    }
                )
                logger.debug(f"Found running VM {vm_hash} with vm_id={execution.vm.vm_id}")

        logger.info(f"Found {len(running_vms)} running VMs to recreate network rules for")

        # Step 2: Remove all aleph-related chains (VM-specific and supervisor chains)
        try:
            removed_chains, failed_removals = remove_all_aleph_chains()
        except Exception as e:
            logger.error(f"Error removing aleph chains: {e}")
            return web.json_response(
                {"success": False, "error": f"Failed to remove existing chains: {e}"},
                status=500,
            )

        if failed_removals:
            logger.warning(f"Failed to remove {len(failed_removals)} chains")
            for chain_name, error in failed_removals:
                logger.warning(f"  - {chain_name}: {error}")
        # Surface leftovers to the caller instead of only logging them.
        failed_chain_removals = [
            {"chain": chain_name, "error": error} for chain_name, error in failed_removals
        ]

        # Step 3: Re-initialize the base network setup
        logger.info("Re-initializing nftables")
        try:
            initialize_nftables()
        except Exception as e:
            logger.error(f"Error initializing nftables: {e}")
            return web.json_response(
                {"success": False, "error": f"Failed to initialize network: {e}"},
                status=500,
            )

        # Step 4: Recreate VM-specific chains and rules
        try:
            recreated_vms, failed_vms = recreate_network_for_vms(running_vms)
        except Exception as e:
            logger.error(f"Error recreating VM networks: {e}")
            return web.json_response(
                {"success": False, "error": f"Failed to recreate VM networks: {e}"},
                status=500,
            )

        # Step 5: Recreate port forwarding rules for instances
        logger.info("Recreating port forwarding rules for instances")
        recreated_lookup = set(recreated_vms)  # O(1) membership tests in the loop below
        failed_port_redirects = []
        for vm_info in running_vms:
            execution = vm_info["execution"]
            vm_hash_text = str(vm_info["vm_hash"])
            if not (execution.is_instance and vm_hash_text in recreated_lookup):
                continue
            try:
                await execution.fetch_port_redirect_config_and_setup()
                logger.debug(f"Recreated port redirects for instance {vm_info['vm_hash']}")
            except Exception as e:
                logger.error(f"Error recreating port redirects for VM {vm_info['vm_hash']}: {e}")
                # Not added to failed_vms: the VM network itself was created
                # successfully. Reported separately so the caller can see it.
                failed_port_redirects.append({"vm_hash": vm_hash_text, "error": str(e)})

        logger.info(
            f"Network recreation complete. Removed chains: {len(removed_chains)}, "
            f"Recreated VMs: {len(recreated_vms)}, Failed: {len(failed_vms)}"
        )

        all_vms_recreated = not failed_vms
        return web.json_response(
            {
                "success": all_vms_recreated,
                "removed_chains_count": len(removed_chains),
                "removed_chains": removed_chains,
                "failed_chain_removals": failed_chain_removals,
                "recreated_count": len(recreated_vms),
                "failed_count": len(failed_vms),
                "recreated_vms": recreated_vms,
                "failed_vms": failed_vms,
                "failed_port_redirects": failed_port_redirects,
            },
            status=200 if all_vms_recreated else 207,
        )


@cors_allow_all
async def notify_allocation(request: web.Request):
"""Notify instance allocation, only used for Pay as you Go feature"""
Expand Down
Loading