Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions src/aleph/vm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,16 @@ async def fetch_port_redirect_config_and_setup(self):
try:
port_forwarding_settings = await get_user_settings(message.address, "port-forwarding")
vm_port_forwarding = port_forwarding_settings.get(self.vm_hash, {}) or {}
ports_requests = vm_port_forwarding.get("ports", {})
fetched_ports_requests = vm_port_forwarding.get("ports", {})
# Force port always to be int and save it as int
ports_requests = {int(key): value for key, value in fetched_ports_requests.items()}
# Always forward port 22
if not ports_requests.get(22, None):
ports_requests[22] = {"tcp": True, "udp": False}

except Exception:
logger.info("Could not fetch the port redirect settings for user %s", message.address, exc_info=True)

# Always forward port 22
ports_requests[22] = {"tcp": True, "udp": False}

await self.update_port_redirects(ports_requests)

async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]]):
Expand All @@ -130,40 +133,39 @@ async def update_port_redirects(self, requested_ports: dict[int, dict[str, bool]
current = self.mapped_ports[vm_port]
for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
if current[protocol]:
host_port = current["host"]
host_port = int(current["host"])
remove_port_redirect_rule(interface, host_port, vm_port, protocol)
del self.mapped_ports[vm_port]
del self.mapped_ports[int(vm_port)]
for vm_port in redirect_to_add:
target = requested_ports[vm_port]
host_port = fast_get_available_host_port()

for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
if target[protocol]:
add_port_redirect_rule(self.vm.vm_id, interface, host_port, vm_port, protocol)
self.mapped_ports[vm_port] = {"host": host_port, **target}
self.mapped_ports[int(vm_port)] = {"host": host_port, **target}

for vm_port in redirect_to_check:
current = self.mapped_ports[vm_port]
target = requested_ports[vm_port]
host_port = current["host"]
host_port = int(current["host"])
for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
if current[protocol] != target[protocol]:
if target[protocol]:
add_port_redirect_rule(self.vm.vm_id, interface, host_port, vm_port, protocol)
else:
remove_port_redirect_rule(interface, host_port, vm_port, protocol)
self.mapped_ports[vm_port] = {"host": host_port, **target}
self.mapped_ports[int(vm_port)] = {"host": host_port, **target}

# Save to DB
if self.record:
self.record.mapped_ports = self.mapped_ports
await save_record(self.record)
await self.save()

async def removed_all_ports_redirection(self):
if not self.vm:
return
interface = self.vm.tap_interface
# copy in a list since we modify dict during iteration
self.mapped_ports = {int(key): value for key, value in self.mapped_ports.items()}
for vm_port, map_detail in list(self.mapped_ports.items()):
host_port = map_detail["host"]
for protocol in SUPPORTED_PROTOCOL_FOR_REDIRECT:
Expand Down Expand Up @@ -543,8 +545,10 @@ async def stop(self) -> None:
self.times.stopping_at = datetime.now(tz=timezone.utc)
await self.all_runs_complete()
await self.record_usage()
await self.vm.teardown()
# First remove existing redirect rules for that VM
await self.removed_all_ports_redirection()
# After do the teardown
await self.vm.teardown()

self.times.stopped_at = datetime.now(tz=timezone.utc)
self.cancel_expiration()
Expand Down Expand Up @@ -608,6 +612,7 @@ async def save(self):
original_message=self.original.model_dump_json(),
persistent=self.persistent,
gpus=json.dumps(self.gpus, default=pydantic_encoder),
mapped_ports=self.mapped_ports,
)
pid_info = self.vm.to_dict() if self.vm else None
# Handle cases when the process cannot be accessed
Expand Down
137 changes: 131 additions & 6 deletions src/aleph/vm/network/firewall.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,11 +621,11 @@ def add_port_redirect_rule(
"match": {
"op": "==",
"left": {"payload": {"protocol": protocol, "field": "dport"}},
"right": host_port,
"right": int(host_port),
}
},
{
"dnat": {"addr": str(interface.guest_ip.ip), "port": vm_port},
"dnat": {"addr": str(interface.guest_ip.ip), "port": int(vm_port)},
},
],
}
Expand All @@ -648,7 +648,7 @@ def add_port_redirect_rule(
"match": {
"op": "==",
"left": {"payload": {"protocol": protocol, "field": "dport"}},
"right": vm_port,
"right": int(vm_port),
}
},
{"accept": None},
Expand Down Expand Up @@ -696,19 +696,20 @@ def remove_port_redirect_rule(interface: TapInterface, host_port: int, vm_port:
and "match" in expr[1]
and expr[1]["match"]["left"].get("payload", {}).get("protocol") == protocol
and expr[1]["match"]["left"]["payload"].get("field") == "dport"
and expr[1]["match"]["right"] == host_port
and int(expr[1]["match"]["right"]) == int(host_port)
and "dnat" in expr[2]
and expr[2]["dnat"].get("addr") == str(interface.guest_ip.ip)
and expr[2]["dnat"].get("port") == vm_port
and int(expr[2]["dnat"].get("port")) == int(vm_port)
):
rule_handle = entry["rule"]["handle"]
commands.append(
{
"delete": {
"rule": {
"family": "ip",
"table": prerouting_table,
"chain": chain_name,
"handle": entry["rule"]["handle"],
"handle": rule_handle,
}
}
}
Expand Down Expand Up @@ -761,3 +762,127 @@ def check_nftables_redirections(port: int) -> bool:
except Exception as e:
logger.warning(f"Error checking NAT redirections: {e}")
return False


def get_all_aleph_chains() -> list[str]:
"""Query nftables ruleset and return all chains created by aleph software.

This function scans the entire nftables ruleset and identifies all chains
whose names start with the configured NFTABLES_CHAIN_PREFIX. This includes
both supervisor chains (e.g., aleph-supervisor-nat, aleph-supervisor-filter,
aleph-supervisor-prerouting) and VM-specific chains (e.g., aleph-vm-nat-123,
aleph-vm-filter-123).

Returns:
A list of chain names that belong to aleph software

Raises:
Exception: If the nftables query fails
"""
logger.debug("Querying nftables for all aleph-related chains")
nft_ruleset = get_existing_nftables_ruleset()
aleph_chains = []

for entry in nft_ruleset:
if isinstance(entry, dict) and "chain" in entry:
chain_name = entry["chain"].get("name", "")
# Find all chains created by aleph software
if chain_name.startswith(settings.NFTABLES_CHAIN_PREFIX):
aleph_chains.append(chain_name)
logger.debug(f"Found aleph chain: {chain_name}")

logger.info(f"Found {len(aleph_chains)} aleph-related chains")
return aleph_chains


def remove_all_aleph_chains() -> tuple[list[str], list[tuple[str, str]]]:
"""Remove all chains created by aleph software from the nftables ruleset.

This function queries the nftables ruleset to find all chains that start with
the configured NFTABLES_CHAIN_PREFIX, then attempts to remove each one. This
ensures a clean slate by removing both tracked and untracked chains that may
have been left behind due to software crashes or inconsistent state.

The function uses the remove_chain() helper which handles:
- Removing all rules that jump to the chain
- Removing the chain itself

Returns:
A tuple containing:
- List of successfully removed chain names
- List of tuples (chain_name, error_message) for failed removals

Example:
removed, failed = remove_all_aleph_chains()
if failed:
logger.warning(f"Failed to remove {len(failed)} chains")
"""
logger.info("Removing all aleph-related chains from nftables")
aleph_chains = get_all_aleph_chains()

removed_chains = []
failed_chains = []

for chain_name in aleph_chains:
try:
remove_chain(chain_name)
removed_chains.append(chain_name)
logger.debug(f"Successfully removed chain: {chain_name}")
except Exception as e:
error_msg = str(e)
failed_chains.append((chain_name, error_msg))
logger.warning(f"Failed to remove chain {chain_name}: {error_msg}")

logger.info(f"Chain removal complete. Removed: {len(removed_chains)}, Failed: {len(failed_chains)}")
return removed_chains, failed_chains


def recreate_network_for_vms(vm_configurations: list[dict]) -> tuple[list[str], list[dict]]:
"""Recreate network rules for a list of VMs.

This function sets up nftables chains and rules for each VM in the provided list.
For each VM, it creates:
- NAT chain and masquerading rules for outbound traffic
- Filter chain and forwarding rules for traffic control
- Port forwarding rules if the VM is an instance (handled by caller)

Args:
vm_configurations: List of dictionaries, each containing:
- vm_id: Integer ID of the VM
- tap_interface: TapInterface object for the VM
- vm_hash: ItemHash of the VM (for logging)

Returns:
A tuple containing:
- List of successfully recreated VM hashes (as strings)
- List of dictionaries with failed VMs:
[{"vm_hash": str, "error": str}, ...]

Example:
vms = [
{"vm_id": 1, "tap_interface": tap1, "vm_hash": hash1},
{"vm_id": 2, "tap_interface": tap2, "vm_hash": hash2},
]
recreated, failed = recreate_network_for_vms(vms)
"""
logger.info(f"Recreating network rules for {len(vm_configurations)} VMs")
recreated_vms = []
failed_vms = []

for vm_config in vm_configurations:
vm_id = vm_config["vm_id"]
tap_interface = vm_config["tap_interface"]
vm_hash = vm_config["vm_hash"]

try:
# Recreate the basic VM network chains and rules
setup_nftables_for_vm(vm_id, tap_interface)
recreated_vms.append(str(vm_hash))
logger.debug(f"Recreated nftables for VM {vm_hash} (vm_id={vm_id})")
except Exception as e:
error_msg = str(e)
failed_vms.append({"vm_hash": str(vm_hash), "error": error_msg})
logger.error(f"Failed to recreate network for VM {vm_hash}: {error_msg}")

logger.info(f"VM network recreation complete. Success: {len(recreated_vms)}, Failed: {len(failed_vms)}")
return recreated_vms, failed_vms
2 changes: 1 addition & 1 deletion src/aleph/vm/network/port_availability_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ def fast_get_available_host_port() -> int:
LAST_ASSIGNED_HOST_PORT = host_port
if LAST_ASSIGNED_HOST_PORT > MAX_PORT:
LAST_ASSIGNED_HOST_PORT = MIN_DYNAMIC_PORT
return host_port
return int(host_port)
2 changes: 2 additions & 0 deletions src/aleph/vm/orchestrator/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
notify_allocation,
operate_reserve_resources,
operate_update,
recreate_network,
run_code_from_hostname,
run_code_from_path,
status_check_fastapi,
Expand Down Expand Up @@ -164,6 +165,7 @@ def setup_webapp(pool: VmPool | None):
other_routes = [
# /control APIs are used to control the VMs and access their logs
web.post("/control/allocations", update_allocations),
web.post("/control/network/recreate", recreate_network),
# Raise an HTTP Error 404 if attempting to access an unknown URL within these paths.
web.get("/about/{suffix:.*}", http_not_found),
web.get("/control/{suffix:.*}", http_not_found),
Expand Down
Loading
Loading