In [1]:
# Automatically reload imported modules that have been edited, before each cell is executed.
%load_ext autoreload
%autoreload 2

In [2]:
import yaml
from primaite.simulator.sim_container import Simulation
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.hardware.nodes.router import Router

from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot


from primaite.simulator.network.hardware.nodes.router import ACLAction
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port

from ipaddress import IPv4Address


In [17]:
# import yaml


from typing import Dict
from primaite.simulator.network.hardware.base import NIC, Link, Node
from primaite.simulator.system.services.service import Service


class PrimaiteSession:
    """A PrimAITE session: a ``Simulation`` built from a YAML config plus the agents acting on it.

    Build one with :meth:`from_config`. Agent creation is not implemented yet, so ``agents`` stays empty.
    """

    def __init__(self):
        # Assigned by ``from_config``; ``None`` until a simulation has been built.
        self.simulation: "Simulation | None" = None
        self.agents = []
        # Lookups from the config's ``ref`` strings to simulation objects, kept on the session so that
        # later setup steps (e.g. agent creation) can resolve references made in the config.
        self.ref_map_nodes: Dict[str, str] = {}  # node ref -> node uuid
        self.ref_map_services: Dict[str, Service] = {}  # service ref -> installed service instance
        self.ref_map_links: Dict[str, str] = {}  # link ref -> link uuid

    @classmethod
    def from_config(cls, cfg_path):
        """Build a session from a YAML config file.

        Only ``simulation.network.nodes`` and ``simulation.network.links`` are read. All nodes (with
        their services and extra NICs) are created first, then the links joining node ports.

        :param cfg_path: Path to the YAML config file.
        :return: A new session whose ``simulation`` contains the configured network.
        :raises KeyError: If a required config key is missing, or a link references a node ref that
            was not created.
        """
        game = cls()
        with open(cfg_path, 'r') as file:
            # safe_load only builds plain YAML types (no arbitrary object construction).
            conf = yaml.safe_load(file)

        sim = Simulation()
        net = sim.network
        network_cfg = conf['simulation']['network']

        # 1. create nodes, with their services and extra NICs
        for node_cfg in network_cfg['nodes']:
            new_node = cls._create_node(node_cfg)
            if new_node is None:
                # Unknown node type (already reported): skip it so nothing gets attached to a wrong node.
                continue
            game._install_services(new_node, node_cfg.get('services') or [])
            for nic_cfg in (node_cfg.get('nics') or {}).values():
                new_node.connect_nic(NIC(ip_address=nic_cfg['ip_address'], subnet_mask=nic_cfg['subnet_mask']))
            net.add_node(new_node)
            game.ref_map_nodes[node_cfg['ref']] = new_node.uuid

        # 2. create links between node ports
        for link_cfg in network_cfg['links']:
            node_a = net.nodes[game.ref_map_nodes[link_cfg['endpoint_a_ref']]]
            node_b = net.nodes[game.ref_map_nodes[link_cfg['endpoint_b_ref']]]
            new_link = net.connect(
                endpoint_a=cls._node_port(node_a, link_cfg['endpoint_a_port']),
                endpoint_b=cls._node_port(node_b, link_cfg['endpoint_b_port']),
            )
            game.ref_map_links[link_cfg['ref']] = new_link.uuid

        # TODO 3. start/setup simulation objects
        # TODO 4. create agents
        # TODO 5. set up agents' actions and observation spaces
        game.simulation = sim
        return game

    @classmethod
    def _create_node(cls, node_cfg: dict) -> "Node | None":
        """Instantiate (and, for routers, configure) the node described by ``node_cfg``.

        :return: The new node, or ``None`` after printing a warning if ``type`` is not recognised.
        """
        n_type = node_cfg['type']
        if n_type in ('computer', 'server'):
            node_cls = Computer if n_type == 'computer' else Server
            return node_cls(
                hostname=node_cfg['hostname'],
                ip_address=node_cfg['ip_address'],
                subnet_mask=node_cfg['subnet_mask'],
                default_gateway=node_cfg['default_gateway'],
                dns_server=node_cfg.get('dns_server'),
            )
        if n_type in ('switch', 'router'):
            # Forward num_ports only when configured; otherwise defer to the node class's own default
            # (assumed to exist — confirm) instead of passing an explicit None.
            node_kwargs = {'hostname': node_cfg['hostname']}
            if node_cfg.get('num_ports') is not None:
                node_kwargs['num_ports'] = node_cfg['num_ports']
            if n_type == 'switch':
                return Switch(**node_kwargs)
            router = Router(**node_kwargs)
            cls._configure_router(router, node_cfg)
            return router
        print(f'invalid node type {n_type}')
        return None

    @classmethod
    def _configure_router(cls, router: "Router", node_cfg: dict):
        """Apply the optional ``ports`` and ``acl`` sections of a router config to ``router``."""
        for port_num, port_cfg in (node_cfg.get('ports') or {}).items():
            router.configure_port(port=port_num,
                                  ip_address=port_cfg['ip_address'],
                                  subnet_mask=port_cfg['subnet_mask'])
        for r_num, r_cfg in (node_cfg.get('acl') or {}).items():
            src_ip, dst_ip = cls._acl_ip_addresses(r_cfg)
            router.acl.add_rule(
                action=ACLAction[r_cfg['action']],
                src_port=cls._enum_or_none(Port, r_cfg.get('src_port')),
                dst_port=cls._enum_or_none(Port, r_cfg.get('dst_port')),
                protocol=cls._enum_or_none(IPProtocol, r_cfg.get('protocol')),
                src_ip_address=src_ip,
                dst_ip_address=dst_ip,
                position=r_num,
            )

    @staticmethod
    def _acl_ip_addresses(r_cfg: dict) -> tuple:
        """Return ``(src_ip, dst_ip)`` for an ACL rule config.

        Direction-specific ``src_ip`` / ``dst_ip`` keys take precedence; each falls back to the shared
        ``ip_address`` key, so configs that only set ``ip_address`` apply it to both directions.
        """
        shared_ip = r_cfg.get('ip_address')
        return r_cfg.get('src_ip', shared_ip), r_cfg.get('dst_ip', shared_ip)

    @staticmethod
    def _enum_or_none(enum_cls, name: "str | None"):
        """Look up enum member ``name`` in ``enum_cls``; a missing/empty name yields ``None``."""
        return enum_cls[name] if name else None

    @staticmethod
    def _node_port(node: "Node", port_num):
        """Return port ``port_num`` of ``node``: switches expose ``switch_ports``, other nodes ``ethernet_port``."""
        ports = node.switch_ports if isinstance(node, Switch) else node.ethernet_port
        return ports[port_num]

    def _install_services(self, node: "Node", services_cfg: list):
        """Install each configured service on ``node`` and record it in ``ref_map_services``.

        Unknown service types are reported with a printed message and skipped.
        """
        # Keys equal the ``name`` attribute of each software class; the installed instance is then
        # looked up in ``software_manager.software`` under that same key.
        service_types = {
            'DNSClient': DNSClient,
            'DNSServer': DNSServer,
            'DatabaseClient': DatabaseClient,
            'DatabaseService': DatabaseService,
            'DataManipulationBot': DataManipulationBot,
            # TODO: 'DatabaseBackup' and 'WebBrowser' are not mapped yet.
        }
        for service_cfg in services_cfg:
            service_ref = service_cfg['ref']
            service_type = service_cfg['type']
            service_cls = service_types.get(service_type)
            if service_cls is None:
                print(f"service type not found {service_type}")
                continue
            node.software_manager.install(service_cls)
            new_service = node.software_manager.software[service_type]
            self.ref_map_services[service_ref] = new_service

            # service-dependent options
            opt = service_cfg.get('options') or {}
            if service_type == 'DatabaseClient' and 'db_server_ip' in opt:
                new_service.configure(server_ip_address=IPv4Address(opt['db_server_ip']))
            elif service_type == 'DNSServer' and 'domain_mapping' in opt:
                for domain, ip in opt['domain_mapping'].items():
                    new_service.dns_register(domain, ip)

# YAML config describing the network to simulate (path relative to the notebook's working directory).
CONFIG_PATH = 'example_config.yaml'
s = PrimaiteSession.from_config(CONFIG_PATH)

2023-09-26 11:47:11,032: Added node bc149bf5-ccc4-4dcd-b419-629ec44b2c9a to Network 2c22989f-8f91-4c61-8be9-1afd733b3e1c
2023-09-26 11:47:11,035: Added node 9cacbaee-33cc-4423-a6c8-fe3dd75b1f87 to Network 2c22989f-8f91-4c61-8be9-1afd733b3e1c
2023-09-26 11:47:11,042: Added node d4444d66-7cc3-4cd4-acbd-202cb9fe37ff to Network 2c22989f-8f91-4c61-8be9-1afd733b3e1c
2023-09-26 11:47:11,045: Added node af170371-e99b-42b7-9525-65ca64522539 to Network 2c22989f-8f91-4c61-8be9-1afd733b3e1c
2023-09-26 11:47:11,049: Added node d6218f34-a104-469d-a08b-97329ad84c19 to Network 2c22989f-8f91-4c61-8be9-1afd733b3e1c
2023-09-26 11:47:11,052: Added node 831a3803-ae65-4cee-a17e-9c1220035bc9 to Network 2c22989f-8f91-4c61-8be9-1afd733b3e1c
2023-09-26 11:47:11,055: Added node 1b935654-065d-4cb9-82d9-d67fe3d3304e to Network 2c22989f-8f91-4c61-8be9-1afd733b3e1c
2023-09-26 11:47:11,059: Added node dd181916-076b-4d8a-ab97-a32052624b09 to Network 2c22989f-8f91-4c61-8be9-1afd733b3e1c
2023-09-26 11:47:11,064: Added n

service type not found DatabaseBackup
service type not found WebBrowser


In [None]:
# Leave the state as the cell's last expression so IPython renders it with its pretty-printer
# rather than as a single unformatted print() line.
s.simulation.describe_state()