In [6]:
import pysmile
import pysmile_license
from pathlib import Path
from pysmile.learning import DataSet, EM
import itertools
import numpy as np

In [7]:
# --- Files -----------------------------------------------------------------
fileName = "DBNfromAG.xdsl"               # network read at start-up
outFileName = "DBNfromAG_learned.xdsl"    # where the modified network is written
tracesFileName = "dbnLogs100.csv"         # data file handed to learnParams

# --- Temporal settings -----------------------------------------------------
numSlices = 300   # passed to net.set_slice_count
dt = 1            # time step in seconds

# --- Node bookkeeping ------------------------------------------------------
# fixDiscrParams keys its gate tables on outcomes[0] / outcomes[1].
# NOTE(review): presumably "N" = not compromised, "C" = compromised -- confirm.
outcomes = ["N", "C"]
priorNodes = ["workStation_compromise"]
# NOTE(review): a "<node_id>: mean TTC" mapping was announced here by an old
# comment but no such mapping exists in this cell -- confirm whether it was
# removed on purpose.

# Every analytic node shares one accuracy value; the list is currently empty,
# so analyticsDict is an empty dict.
analyticAccuracy = 0.95
analyticNodes = []
analyticsDict = dict.fromkeys(analyticNodes, analyticAccuracy)

# Nodes whose definitions are overwritten with deterministic OR / AND tables.
orNodes = ["historianServer_NodeOR1"]
andNodes = [
    "historianServer_remoteShellAND",
    "MMSclient1_AND52",
    "MMSserver1_NodeAND23",
]

In [8]:
# Create an empty network, populate it from fileName, then set its slice
# count to numSlices. Later cells modify and save this same `net` object.
net = pysmile.Network()
net.read_file(fileName)
net.set_slice_count(numSlices)

In [9]:
def flattenExtended(cpt: list):
    """Flatten `cpt` by exactly one level.

    Elements that are lists are expanded in place; every other element
    (including tuples and strings) is kept as a single item. Lists nested
    deeper than one level are not expanded further.
    """
    return [
        leaf
        for entry in cpt
        for leaf in (entry if isinstance(entry, list) else (entry,))
    ]

def plotDefinitions(net: pysmile.Network):
    """Print one line per node: its id, its definition and its outcome ids.

    Handles and ids are paired positionally from net.get_all_nodes() and
    net.get_all_node_ids(). Nothing is returned.
    """
    for nodeHandle, nodeId in zip(net.get_all_nodes(), net.get_all_node_ids()):
        nodeDef = net.get_node_definition(nodeHandle)
        nodeOutcomes = net.get_outcome_ids(nodeHandle)
        print(f"Node ID: {nodeId}, Definition: {nodeDef}, Outcomes: {nodeOutcomes}")

def learnParams(net: pysmile.Network, fileName: str, randomize: bool = False, uniformize: bool = False, relevance: bool = True, seed: int = 98):
    """Run EM parameter learning on `net` using the data in `fileName`.

    Args:
        net: network passed to EM.learn.
        fileName: path of the data file read into a pysmile DataSet.
        randomize: forwarded to EM.set_randomize_parameters.
        uniformize: forwarded to EM.set_uniformize_parameters.
        relevance: forwarded to EM.set_relevance.
        seed: forwarded to EM.set_seed. Defaults to 98, the value that was
            previously hard-coded, so existing calls behave as before.

    Returns:
        None. The caller reads the outcome back from `net`
        (NOTE(review): assumes EM.learn updates `net` in place -- confirm
        against the pysmile documentation).
    """
    ds = DataSet()
    ds.read_file(fileName)
    # Column-to-node matching is computed by pysmile from the data set.
    matching = ds.match_network(net)

    em = EM()
    em.set_seed(seed)
    em.set_relevance(relevance)
    em.set_randomize_parameters(randomize)
    em.set_uniformize_parameters(uniformize)
    em.learn(ds, net, matching)

def findNodeHandle(net: pysmile.Network, nodeId: str):
    """Return the handle of the first node whose id equals `nodeId`.

    Ids and handles are paired positionally from net.get_all_node_ids() and
    net.get_all_nodes(). Returns None when no id matches.
    """
    allIds = net.get_all_node_ids()
    allHandles = net.get_all_nodes()
    matchingHandles = (
        handle
        for handle, candidateId in zip(allHandles, allIds)
        if candidateId == nodeId
    )
    return next(matchingHandles, None)

def fixDiscrParams(net: pysmile.Network, tacticsDict: dict, analyticsDict: dict, orNodes: list, andNodes: list):
    """Overwrite the definitions of analytic nodes and OR/AND gate nodes.

    Args:
        net: network whose node definitions are replaced in place.
        tacticsDict: accepted for interface compatibility; not read here.
        analyticsDict: maps node id -> accuracy. Each listed node gets the
            table [accuracy, 1-accuracy, 1-accuracy, accuracy].
            NOTE(review): that table has four entries, so it presumably
            expects a two-outcome node with one two-outcome parent -- confirm.
        orNodes: ids of nodes given a deterministic OR table.
        andNodes: ids of nodes given a deterministic AND table.

    Relies on the module-level `outcomes` list: a gate column is [0, 1] when
    outcomes[1] appears among the parent outcomes (OR) or when outcomes[0]
    does not appear among them (AND), and [1, 0] otherwise. Gate nodes with
    no parents are left untouched. Unknown ids are reported with print().
    """
    # Pass 1: analytic nodes get a fixed accuracy-based table.
    for nodeId in analyticsDict:
        accuracy = analyticsDict[nodeId]
        nodeHandle = findNodeHandle(net, nodeId)
        if nodeHandle is None:
            print(f"Node {nodeId} not found in the network.")
            continue
        net.set_node_definition(nodeHandle, [accuracy, 1-accuracy, 1-accuracy, accuracy])
        print("Set parameters for analytic node:", nodeId)

    # Pass 2: OR nodes first, then AND nodes, in the order given.
    for nodeId in itertools.chain(orNodes, andNodes):
        nodeHandle = findNodeHandle(net, nodeId)
        if nodeHandle is None:
            print(f"Node {nodeId} not found in the network.")
            continue

        parentHandles = net.get_parents(nodeHandle)
        if len(parentHandles) == 0:
            continue

        # An id listed in orNodes is always treated as OR, even when it is
        # reached through andNodes.
        isOrGate = nodeId in orNodes
        perParentOutcomes = [net.get_outcome_ids(handle) for handle in parentHandles]

        # One column per parent-outcome combination, rightmost parent varying
        # fastest; each column contributes two consecutive entries.
        flatTable = []
        for combo in itertools.product(*perParentOutcomes):
            if isOrGate:
                fires = outcomes[1] in combo
            else:
                fires = outcomes[0] not in combo
            flatTable.extend([0, 1] if fires else [1, 0])
        net.set_node_definition(nodeHandle, flatTable)

In [10]:
# Learn parameters from the traces file; uniformize is switched on while
# randomize and relevance are switched off.
learnParams(net, tracesFileName, randomize=False, uniformize=True, relevance=False)

In [11]:
# Run after learning: replaces the definitions of the analytic and OR/AND
# nodes. The second argument (tacticsDict) is None; fixDiscrParams never reads it.
fixDiscrParams(net, None, analyticsDict, orNodes, andNodes)

In [12]:
# Save the modified network under outFileName (the input file is left as is).
net.write_file(outFileName)

In [13]:
# Print every node's id, definition and outcome ids for inspection.
plotDefinitions(net)

Node ID: DMZ_scanIP, Definition: [0.995049504950495, 0.0049504950495049506], Outcomes: ['N', 'C']
Node ID: workStation_compromise, Definition: [0.002567282878116235, 0.9974327171218837], Outcomes: ['N', 'C']
Node ID: historianServer_addSSHkey, Definition: [0.994033893993792, 0.005966106006207975], Outcomes: ['N', 'C']
Node ID: historianServer_remoteShellAND, Definition: [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0], Outcomes: ['N', 'C']
Node ID: historian_scanVuln, Definition: [0.9948976735448269, 0.00510232645517306], Outcomes: ['N', 'C']
Node ID: historian_remoteSrvc, Definition: [0.9949344172481706, 0.005065582751829554], Outcomes: ['N', 'C']
Node ID: historianServer_shell, Definition: [0.995044299765697, 0.004955700234303086], Outcomes: ['N', 'C']
Node ID: historianServer_remoteShell, Definition: [0.6762922923800851, 0.32370770761991485], Outcomes: ['N', 'C']
Node ID: historianServer_NodeOR1, Definition: [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0], Outcomes: ['N', 'C']
Node ID: MMSclient1_