In [1]:
import pytest
from similarity import similar
from similarity import draw_tree_multiple_command
from similarity import create_dict_nodes_labels
from similarity import program_name_win
from similarity import token_with_tags
from similarity import normalize_digits
from similarity import normalize_uuid


@pytest.mark.parametrize(
    "text1, text2, is_similar",
    [("test", "test", 1.0), ("one", "two", 0.333), ("12345", "45678", 0.4)],
)
def test_similar(text1, text2, is_similar):
    assert is_similar == similar(text1, text2)


@pytest.mark.parametrize(
    "tags, tokens, graph_nodes, graph_edges",
    [
        # empty
        (
            [],
            [],
            [],
            [],
        ),
        # only program
        (
            ["COMM"],
            ["del"],
            [("COMM", 0, "del")],
            [(("COMM", 0, "del"), ("COMM", 0, "del"))],
        ),
        # normal command
        (
            ["COMM", "SUBCOMM", "FLAG", "FLAG_VALUE"],
            ["PIP", "install", "-upgrade", "pip"],
            [("COMM", 0, "PIP"), ("SUBCOMM", 1, "install"), ("FLAG", 2, "-upgrade"), ("FLAG_VALUE", 3, "pip")],
            [
                (("COMM", 0, "PIP"), ("COMM", 0, "PIP")),
                (("COMM", 0, "PIP"), ("SUBCOMM", 1, "install")),
                (("COMM", 0, "PIP"), ("FLAG", 2, "-upgrade")),
                (("FLAG", 2, "-upgrade"), ("FLAG_VALUE", 3, "pip")),
            ],
        ),
        # TODO repetition : now failing, node name collision
        (
            ["COMM", "SUBCOMM", "FLAG", "FLAG_VALUE"],
            ["pip", "install", "-upgrade", "pip"],
            [("COMM", 0, "pip"), ("SUBCOMM", 1, "install"), ("FLAG", 2, "-upgrade"), ("FLAG_VALUE", 3, "pip")],
            [
                (("COMM", 0, "pip"), ("COMM", 0, "pip")),
                (("COMM", 0, "pip"), ("SUBCOMM", 1, "install")),
                (("COMM", 0, "pip"), ("FLAG", 2, "-upgrade")),
                (("FLAG", 2, "-upgrade"), ("FLAG_VALUE", 3, "pip")),
            ],
        ),
        # combination of flags and flag values
        (
            ["COMM", "FLAG", "FLAG_VALUE", "FLAG", "FLAG", "FLAG_VALUE"],
            ["C:\\WINDOWS\\system32\\svchost.exe", "-k", "LocalService", "-p", "-s", "WebClient"],
            [
                ("COMM", 0, "C:\\WINDOWS\\system32\\svchost.exe"),
                ("FLAG", 1, "-k"),
                ("FLAG_VALUE", 2, "LocalService"),
                ("FLAG", 3, "-p"),
                ("FLAG", 4, "-s"),
                ("FLAG_VALUE", 5, "WebClient"),
            ],
            [
                (("COMM", 0, "C:\\WINDOWS\\system32\\svchost.exe"), ("COMM", 0, "C:\\WINDOWS\\system32\\svchost.exe")),
                (("COMM", 0, "C:\\WINDOWS\\system32\\svchost.exe"), ("FLAG", 1, "-k")),
                (("COMM", 0, "C:\\WINDOWS\\system32\\svchost.exe"), ("FLAG", 3, "-p")),
                (("COMM", 0, "C:\\WINDOWS\\system32\\svchost.exe"), ("FLAG", 4, "-s")),
                (("FLAG", 1, "-k"), ("FLAG_VALUE", 2, "LocalService")),
                (("FLAG", 4, "-s"), ("FLAG_VALUE", 5, "WebClient")),
            ],
        ),
        # many commands
        (
            ["COMM", "PARAM", "COMM", "COMM", "PARAM"],
            ["rundll32.exe", "param1", "MSIDE24.tmp", "process3", "param3"],
            [
                ("COMM", 0, "rundll32.exe"),
                ("PARAM", 1, "param1"),
                ("COMM", 2, "MSIDE24.tmp"),
                ("COMM", 3, "process3"),
                ("PARAM", 4, "param3"),
            ],
            [
                (("COMM", 0, "rundll32.exe"), ("COMM", 0, "rundll32.exe")),
                (("COMM", 0, "rundll32.exe"), ("COMM", 2, "MSIDE24.tmp")),
                (("COMM", 2, "MSIDE24.tmp"), ("COMM", 3, "process3")),
                (("COMM", 0, "rundll32.exe"), ("PARAM", 1, "param1")),
                (("COMM", 3, "process3"), ("PARAM", 4, "param3")),
            ],
        ),
    ],
)
def test_draw_tree_multiple_command(tags, tokens, graph_nodes, graph_edges):
    graph = draw_tree_multiple_command(tags, token_with_tags(tags, tokens), tokens)
    assert len(graph_nodes) == len(graph.nodes())
    assert len(graph_edges) == len(graph.edges())
    assert set(graph_nodes) == set(graph.nodes())
    assert set(graph_edges) == set(graph.edges())


@pytest.mark.parametrize(
    "labels, nodes, result",
    [
        # empty
        (
            [],
            [],
            {},
        ),
        # only program
        (
            ["COMM"],
            ["del"],
            {"COMM": ["del"]},
        ),
        # normal command
        (
            ["COMM", "SUBCOMM", "FLAG", "FLAG_VALUE"],
            ["PIP", "install", "-upgrade", "pip"],
            {"COMM": ["PIP"], "SUBCOMM": ["install"], "FLAG": ["-upgrade"], "FLAG_VALUE": ["pip"]},
        ),
        # combination of flags and flag values
        (
            ["COMM", "FLAG", "FLAG_VALUE", "FLAG", "FLAG", "FLAG_VALUE"],
            ["C:\\WINDOWS\\system32\\svchost.exe", "-k", "LocalService", "-p", "-s", "WebClient"],
            {
                "COMM": ["C:\\WINDOWS\\system32\\svchost.exe"],
                "FLAG": ["-k", "-p", "-s"],
                "FLAG_VALUE": ["LocalService", "WebClient"],
            },
        ),
        # many commands
        (
            ["COMM", "PARAM", "COMM", "COMM", "PARAM"],
            ["rundll32.exe", "param1", "MSIDE24.tmp", "process3", "param3"],
            {
                "COMM": ["rundll32.exe", "MSIDE24.tmp", "process3"],
                "PARAM": ["param1", "param3"],
            },
        ),
    ],
)
def test_create_dict_nodes_labels(labels, nodes, result):
    assert result == create_dict_nodes_labels(labels, nodes)


@pytest.mark.parametrize(
    "command, name",
    [
        ("C:\\WINDOWS\\system32\\svchost.exe", "svchost"),
        ("rundll32.exe", "rundll32"),
    ],
)
def test_program_name(command, name):
    assert name == program_name_win(command)


@pytest.mark.parametrize(
    "text, expected",
    [
        (["123ad"], ["000ad"]),
        (["abc", "1a"], ["abc", "0a"]),
        (["987"], ["000"]),
    ],
)
def test_normalize_digits(text, expected):
    assert expected == normalize_digits(text)


@pytest.mark.parametrize(
    "text, expected",
    [
        (["7916 0cd76741-59fa-4639-8064-5548b55b5cae 612"], ["7916 <<UUID>> 612"]),
        (["{135E5B8E-5CFA-4C44-B3A1-1486D11778B0}"], ["{<<UUID>>}"]),
        (["abc", "1a"], ["abc", "1a"]),
    ],
)
def test_normalize_uuid(text, expected):
    assert expected == normalize_uuid(text)


ModuleNotFoundError: No module named 'strsim'