In [None]:

from IPython.display import clear_output

from ipaddress import ip_address, ip_network, IPv4Address, IPv6Address
import panel as pn
import ipywidgets as widgets
import pandas as pd
import altair as alt
import pathlib
import datetime

# Load Panel's 'vega' extension (used for the Altair charts built below).
pn.extension('vega')

# Names of the two CSV columns that hold the endpoints of a flow.
SRC_AND_DST_IP = ["SOURCE IP ADDRESS", "DESTINATION IP ADDRESS"]

SRC_IP = SRC_AND_DST_IP[0]
DST_IP = SRC_AND_DST_IP[1]

# Build the relative path from its components: a literal backslash is only a
# path separator on Windows, so "Task 1\\file.csv" would be treated as a
# single file name on Linux/macOS.
file_path = pathlib.Path("Task 1") / "20240701_16-17_100l.csv"
dir_path = pathlib.Path().resolve()

# Flow records; every processing function in this notebook reads from `data`.
csv_file = pd.read_csv(dir_path / file_path)
data: pd.DataFrame = pd.DataFrame(csv_file)

def get_first_time_flow_string():
    """Return the smallest "START TIME - FIRST SEEN" value in ``data``."""
    start_times = data["START TIME - FIRST SEEN"]
    return start_times.min()

def get_last_time_flow_string():
    """Return the largest "START TIME - FIRST SEEN" value in ``data``."""
    start_times = data["START TIME - FIRST SEEN"]
    return start_times.max()

def get_first_time_flow():
    """Return the earliest flow start time parsed into a ``datetime``."""
    earliest_text = get_first_time_flow_string()
    return datetime.datetime.fromisoformat(earliest_text)

def get_last_time_flow():
    """Return the latest flow start time parsed into a ``datetime``."""
    latest_text = get_last_time_flow_string()
    return datetime.datetime.fromisoformat(latest_text)

def ipv4_in_range_cmp(ip, min, max):
    """Check an IPv4 address against an inclusive integer range.

    Values that are not IPv4 addresses are accepted (``True``) so that this
    predicate only ever rejects IPv4 addresses outside ``[min, max]``.
    """
    if not is_ipv4(ip):
        return True
    ip_as_int = try_convert_ip_to_int(ip)
    return min <= ip_as_int <= max

def ipv6_in_range_cmp(ip, min, max):
    """Check an IPv6 address against an inclusive integer range.

    Values that are not IPv6 addresses are accepted (``True``) so that this
    predicate only ever rejects IPv6 addresses outside ``[min, max]``.
    """
    if not is_ipv6(ip):
        return True
    ip_as_int = try_convert_ip_to_int(ip)
    return min <= ip_as_int <= max

def try_convert_ip_to_int(ip):
    """Return the integer form of an IP address, or 0 if it cannot be parsed."""
    try:
        parsed = ip_address(ip)
    except Exception:
        # Unparseable input is mapped to 0 instead of propagating the error.
        return 0
    return int(parsed)

def is_ipv6(addr):
    """Return True when ``addr`` parses as an IPv6 address, False otherwise."""
    try:
        parsed = ip_address(addr)
    except Exception:
        return False
    return isinstance(parsed, IPv6Address)

def is_ipv4(addr):
    """Return True when ``addr`` parses as an IPv4 address, False otherwise."""
    try:
        parsed = ip_address(addr)
    except Exception:
        return False
    return isinstance(parsed, IPv4Address)

def to_datetime(x):
    """Parse an ISO-format timestamp string into a ``datetime``."""
    parse_iso = datetime.datetime.fromisoformat
    return parse_iso(x)

def time_to_seconds(time_str):
    """Convert a duration such as ``"5m 30s"`` into a number of seconds.

    The first whitespace-separated token is read as minutes (surrounding
    ``m`` characters removed) and the second as seconds (surrounding ``s``
    characters removed).
    """
    tokens = time_str.split()
    whole_minutes = int(tokens[0].strip('m'))
    leftover_seconds = int(tokens[1].strip('s'))
    return whole_minutes * 60 + leftover_seconds

def is_in_muni_subdomain(input_str):
    """Return True when ``input_str`` is ``muni.cz`` or one of its subdomains.

    The comparison is case-insensitive and ignores surrounding whitespace and
    a trailing root dot. A plain substring test would also accept unrelated
    names such as ``notmuni.cz`` or ``muni.cz.example.com``, so the name is
    matched as a domain suffix instead.

    Non-string input (for example a missing value read from the CSV) yields
    False rather than raising.
    """
    if not isinstance(input_str, str):
        return False
    domain = input_str.strip().rstrip('.').lower()
    return domain == "muni.cz" or domain.endswith(".muni.cz")

# Description style shared by most of the widgets in this notebook.
style = {'description_width': 'initial'}

# Button the user presses to apply the current settings.
apply_changes_widget = widgets.Button(
    description='Apply changes',
    tooltip='Apply changes',
    disabled=False,
    button_style='success',
    style=style,
)

# How many top entries to show; the options are strings.
top_n_widget = widgets.Dropdown(
    options=['10', '20', '30', '40'],
    value='10',
    description='Top N stats',
    disabled=False,
    style=style,
)

# Statistics per single IP address or per pair of addresses.
single_or_pair_widget = widgets.ToggleButtons(
    options=['Single', 'Pair'],
    description='Info about single IP adresss or pair:',
    disabled=False,
    button_style='info',
    style=style,
)

# Traffic direction selector.
src_dst_widget = widgets.ToggleButtons(
    options=['Received', 'Sent'],
    description='Direction:',
    disabled=False,
    button_style='info',
    style=style,
)

# Which statistic to compute.
data_type_widget = widgets.ToggleButtons(
    options=[
        'Bytes',
        'Packets',
        'Flows',
        'Duration with dst IP',
        'Most DNS requests',
        'Subdomains of muni.cz',
        'Unusual TLDs',
        "IP requesting unusual TLD",
    ],
    description='Data type:',
    disabled=False,
    button_style='info',
    style={'button_width': '200px'},
    layout={'width': '520px'},
)

# Start/end offsets of the analysed time window.
period_widget = widgets.IntRangeSlider(
    min=0,
    max=36000,
    value=[0, 36000],
    step=5,
    description="Time range",
    disabled=False,
    style=style,
)

# Comma-separated IPv4 networks; the default covers the whole IPv4 space.
ipv4_range_widget = widgets.Textarea(
    value='0.0.0.0/0,',
    description='List of IPv4 masks, separeted by ","',
    disabled=False,
    style=style,
    layout=widgets.Layout(width='50%'),
)

# Comma-separated IPv6 networks; the default covers the whole IPv6 space.
ipv6_range_widget = widgets.Textarea(
    value='0:0:0:0:0:0:0:0/0,',
    description='List of IPv6 masks, separeted by ","',
    disabled=False,
    style=style,
    layout=widgets.Layout(width='80%'),
)

# Protocol filter; "ALL" disables filtering. The selected text is compared
# verbatim against the "PROTOCOL" column, so each option has to be spelled
# exactly like the protocol name ("TFTFP" was a typo for "TFTP" and could
# never match).
protocol_widget = widgets.Dropdown(
    options=["ALL", "UDP", "SMTP",
             "DNS", "HTTP", "HTTPS",
             "SMB", "TELNET",
             "RDP", "FTP", "TFTP"
            ],
    description='Protocol',
    disabled=False,
    style = style,
)

# Port filter; each option other than "ALL" is a ", "-separated port list.
port_widget = widgets.Dropdown(
    options=[
        "ALL",
        "8080, 8888, 591, 82",
        "8443, 9443, 4443",
        "5353, 5355",
        "2121, 8021",
        "2525, 26",
        "3390, 3391, 4000",
    ],
    description='Ports',
    disabled=False,
    style=style,
)

# Minimal flow duration; starts out disabled with a placeholder value.
minimal_duration_widget = widgets.Text(
    value='disabled',
    description='Minimal duration of a flow in seconds',
    disabled=True,
    style=style,
    layout=widgets.Layout(width='50%'),
)

# All controls, in the order in which they are displayed.
ui_widgets = [
    top_n_widget,
    single_or_pair_widget,
    src_dst_widget,
    data_type_widget,
    minimal_duration_widget,
    period_widget,
    ipv4_range_widget,
    ipv6_range_widget,
    protocol_widget,
    port_widget,
    apply_changes_widget,
]


def display_widgets():
    """Display every control listed in ``ui_widgets``, in list order."""
    for control in ui_widgets:
        display(control)

def disable_button(button):
    """Clear the widget's button style and mark it as disabled."""
    # Assigned left to right: style first, then the disabled flag.
    button.button_style, button.disabled = '', True

def enable_button(button):
    """Give the widget the 'info' button style and mark it as enabled."""
    # Assigned left to right: style first, then the disabled flag.
    button.button_style, button.disabled = 'info', False

def process_flow_pair(curr_data, top_n_n):
    """Count flows per unordered (source, destination) address pair.

    Returns the ``top_n_n`` pairs with the most flows as a frame with the
    columns ``'src-dst'`` (both addresses joined by ``' - '``) and ``'flows'``.
    """
    # Sorting the two addresses makes A->B and B->A count as the same pair.
    unordered_pairs = curr_data[[SRC_IP, DST_IP]].apply(
        lambda row: tuple(sorted(row)), axis=1
    )
    pair_counts = unordered_pairs.value_counts().reset_index()
    pair_counts.columns = ['pair', 'flows']

    # Split each tuple back into two address columns.
    endpoints = pd.DataFrame(pair_counts['pair'].tolist(), index=pair_counts.index)
    pair_counts[[SRC_IP, DST_IP]] = endpoints

    top_pairs = pair_counts.drop(columns=['pair']).nlargest(top_n_n, columns='flows')
    top_pairs['src-dst'] = top_pairs[SRC_IP] + ' - ' + top_pairs[DST_IP]
    return top_pairs[['src-dst', 'flows']]

def process_duration_single(curr_data, g_data_type, minimal_duration, top_n_n):
    """Aggregate flow durations per destination IP address.

    Args:
        curr_data: Flow records; column ``g_data_type`` holds duration
            strings understood by ``time_to_seconds``.
        g_data_type: Name of the duration column.
        minimal_duration: Minimum duration in seconds (anything ``int()``
            accepts); shorter flows are dropped before aggregation.
        top_n_n: Number of destinations to return.

    Returns:
        A frame with the destination address column, ``DURATION`` (summed
        seconds) and ``COUNT`` (number of flows), limited to the ``top_n_n``
        largest ``DURATION`` values.

    Raises:
        ValueError: If ``minimal_duration`` is not an integer value.
    """
    threshold = int(minimal_duration)

    # Work on a copy so the caller's frame keeps its raw duration strings.
    processed_data = curr_data.copy()
    processed_data[g_data_type] = processed_data[g_data_type].apply(time_to_seconds)

    long_enough = processed_data[processed_data[g_data_type] >= threshold]
    per_destination = long_enough.groupby(DST_IP).agg(
        DURATION=(g_data_type, 'sum'),
        COUNT=(DST_IP, 'count'),
    ).reset_index()

    # Rank by the aggregated column's real name; ranking by `g_data_type`
    # only worked while that name happened to be "DURATION" as well.
    return per_destination.nlargest(top_n_n, columns='DURATION')

def process_ip_requesting_unusual_tld(curr_data, g_data_type, top_n_n, unusual_tlds_regex):
    """Rank source IPs by how many rows match ``unusual_tlds_regex``.

    Column ``g_data_type`` is searched with the regex (missing values never
    match); the result has the source address column and an
    ``"IP REQUESTING UNUSUAL TLD"`` count column, limited to ``top_n_n`` rows.
    """
    count_column = "IP REQUESTING UNUSUAL TLD"
    is_unusual = curr_data[g_data_type].str.contains(unusual_tlds_regex, regex=True, na=False)
    hits = curr_data[is_unusual]
    per_source = hits.groupby(SRC_IP)[SRC_IP].count().reset_index(name=count_column)
    return per_source.nlargest(top_n_n, columns=count_column)
    
    
def process_unusual_tld(curr_data, g_data_type, top_n_n, unusual_tlds_regex):
    """Count occurrences of each value in ``g_data_type`` matching the regex.

    Missing values never match. The result has the columns ``"UNUSUAL TLD"``
    (the matching value) and ``"COUNT"``, limited to the ``top_n_n`` most
    frequent values.
    """
    is_unusual = curr_data[g_data_type].str.contains(unusual_tlds_regex, regex=True, na=False)
    hits = curr_data[is_unusual]
    counts = hits.groupby(g_data_type)[g_data_type].count().reset_index(name='COUNT')
    top_counts = counts.nlargest(top_n_n, columns='COUNT')
    return top_counts.rename(columns={g_data_type: "UNUSUAL TLD"})

def process_by_type(curr_data, g_src_dst, g_data_type, top_n_n, single_or_pair, minimal_duration):
    """Reduce the filtered flows to the top-N table for the chosen statistic.

    Args:
        curr_data: Flow records that have already passed the filters.
        g_src_dst: Address column to group by in "Single" mode.
        g_data_type: Statistic key ("BYTES", "PACKETS", "FLOWS", "DURATION",
            "DNS QNAME", "UNUSUAL TLD", "IP REQUESTING UNUSUAL TLD" or
            "MOST DNS REQUESTS").
        top_n_n: Number of rows to keep.
        single_or_pair: "Pair" aggregates per unordered address pair; any
            other value aggregates per single address.
        minimal_duration: Minimum duration, only used for "DURATION".

    Returns:
        A frame whose first column is the label and whose second column is
        the value; ``curr_data`` itself when it is empty.
    """
    if curr_data.empty:
        return curr_data

    # Raw string: "\." in a normal literal is an invalid escape sequence.
    unusual_tlds_regex = r"\.xyz$|\.top$|\.club$|\.site$|\.online$|\.loan$|\.trade$|\.accountants$|\.space$"

    if single_or_pair == "Pair":
        if g_data_type == "FLOWS":
            return process_flow_pair(curr_data, top_n_n)
        # Sorting makes A->B and B->A the same pair; `assign` returns a new
        # frame, so no 'pair' column leaks into the caller's data.
        pairs = curr_data[[SRC_IP, DST_IP]].apply(lambda row: tuple(sorted(row)), axis=1)
        paired = curr_data.assign(pair=pairs)
        return paired.groupby('pair', as_index=False)[g_data_type].sum().nlargest(top_n_n, columns=g_data_type)

    if g_data_type == "FLOWS":
        processed_data = curr_data.groupby(g_src_dst)[g_src_dst].count().reset_index(name=g_data_type).nlargest(top_n_n, columns=g_data_type)
    elif g_data_type == "DURATION":
        processed_data = process_duration_single(curr_data, g_data_type, minimal_duration, top_n_n)
    elif g_data_type == "DNS QNAME":
        # Match muni.cz itself or names ending in ".muni.cz" (optional root
        # dot, any case). The bare pattern 'muni.cz' was an unanchored regex
        # whose "." matched any character.
        in_muni = curr_data[g_data_type].str.contains(r'(?:^|\.)muni\.cz\.?$', case=False, regex=True, na=False)
        processed_data = curr_data[in_muni]
        processed_data = processed_data.groupby(g_data_type)[g_data_type].count().reset_index(name='COUNT').nlargest(top_n_n, columns='COUNT')
    elif g_data_type == "UNUSUAL TLD":
        processed_data = process_unusual_tld(curr_data, "DNS QNAME", top_n_n, unusual_tlds_regex)
    elif g_data_type == "IP REQUESTING UNUSUAL TLD":
        processed_data = process_ip_requesting_unusual_tld(curr_data, "DNS QNAME", top_n_n, unusual_tlds_regex)
    elif g_data_type == "MOST DNS REQUESTS":
        # Use the filtered rows, not the global `data`, so the time, port,
        # protocol and IP-range filters apply to this statistic too.
        dns_rows = curr_data[curr_data['DNS QNAME'].notna()]
        processed_data = dns_rows.groupby(SRC_IP)[SRC_IP].count().reset_index(name=g_data_type).nlargest(top_n_n, columns=g_data_type)
    else:
        processed_data = curr_data[[g_src_dst, g_data_type]].groupby(g_src_dst).sum().reset_index().nlargest(top_n_n, columns=g_data_type)
    return processed_data

def process_by_ipvx_range_pair(curr_data, ip_ranges, cmp_func):
    """Keep flows whose source AND destination pass ``cmp_func`` for a subnet.

    Args:
        curr_data: Flow records to filter.
        ip_ranges: Iterable of network strings (e.g. ``"10.0.0.0/8"``);
            entries of at most one character after stripping are ignored.
        cmp_func: Predicate called as ``cmp_func(ip, range_start, range_end)``
            with the integer bounds of the subnet.

    Returns:
        The matching rows of ``curr_data`` in their original order. A row
        matched by several overlapping subnets appears once, so counts are
        not inflated. With no usable subnet an empty frame with the same
        columns is returned.

    Raises:
        ValueError: If an entry is not a valid network (from ``ip_network``).
    """
    combined_mask = None
    for ip_subnet in ip_ranges:
        subnet_text = ip_subnet.strip()
        # Skips the empty tail left by a trailing "," as well as
        # whitespace-only entries.
        if len(subnet_text) <= 1:
            continue
        network = ip_network(subnet_text)
        ip_start, ip_end = int(network[0]), int(network[-1])

        src_ok = curr_data[SRC_IP].apply(cmp_func, args=(ip_start, ip_end)).to_numpy(dtype=bool)
        dst_ok = curr_data[DST_IP].apply(cmp_func, args=(ip_start, ip_end)).to_numpy(dtype=bool)
        subnet_mask = src_ok & dst_ok
        # OR the per-subnet masks instead of concatenating row sets.
        combined_mask = subnet_mask if combined_mask is None else combined_mask | subnet_mask

    if combined_mask is None:
        return curr_data.iloc[0:0]
    return curr_data[combined_mask]

# TODO: fix filtering of IPv4 and IPv6 addresses separately
def process_by_ipvx_range_single(curr_data, g_src_dst, ip_ranges):
    """Keep flows whose ``g_src_dst`` address lies in one of the subnets.

    Args:
        curr_data: Flow records to filter.
        g_src_dst: Name of the address column to test.
        ip_ranges: Iterable of network strings (e.g. ``"10.0.0.0/8"``);
            entries of at most one character after stripping are ignored.

    Returns:
        The matching rows of ``curr_data`` in their original order. A row
        matched by several overlapping subnets appears once, so counts are
        not inflated. With no usable subnet an empty frame with the same
        columns is returned.

    Raises:
        ValueError: If an entry is not a valid network (from ``ip_network``).
    """
    # The integer form of each address does not depend on the subnet, so it
    # is computed once (unparseable addresses become 0).
    ip_as_int = curr_data[g_src_dst].apply(try_convert_ip_to_int)

    combined_mask = None
    for ip_subnet in ip_ranges:
        subnet_text = ip_subnet.strip()
        # Skips the empty tail left by a trailing "," as well as
        # whitespace-only entries.
        if len(subnet_text) <= 1:
            continue
        network = ip_network(subnet_text)
        ip_start, ip_end = int(network[0]), int(network[-1])

        subnet_mask = ip_as_int.between(ip_start, ip_end).to_numpy(dtype=bool)
        # OR the per-subnet masks instead of concatenating row sets.
        combined_mask = subnet_mask if combined_mask is None else combined_mask | subnet_mask

    if combined_mask is None:
        return curr_data.iloc[0:0]
    return curr_data[combined_mask]

def process_by_ip_range(curr_data, g_src_dst, ipv4_ranges_str, ipv6_ranges_str, single_or_pair):
    """Filter flows by the configured IPv4 and IPv6 subnets.

    In 'Single' mode only column ``g_src_dst`` is tested; otherwise both
    endpoint columns are. IPv4 and IPv6 rows are filtered against their own
    list of ranges and the two results are concatenated. An empty
    ``curr_data`` is returned unchanged.
    """
    if curr_data.empty:
        return curr_data

    if single_or_pair != 'Single':
        # Pair mode: a row belongs to a family if either endpoint does.
        touches_v4 = curr_data[SRC_IP].apply(is_ipv4) | curr_data[DST_IP].apply(is_ipv4)
        touches_v6 = curr_data[SRC_IP].apply(is_ipv6) | curr_data[DST_IP].apply(is_ipv6)
        v4_part = process_by_ipvx_range_pair(curr_data[touches_v4], ipv4_ranges_str, ipv4_in_range_cmp)
        v6_part = process_by_ipvx_range_pair(curr_data[touches_v6], ipv6_ranges_str, ipv6_in_range_cmp)
        return pd.concat([v4_part, v6_part])

    v4_rows = curr_data[curr_data[g_src_dst].apply(is_ipv4)]
    v6_rows = curr_data[curr_data[g_src_dst].apply(is_ipv6)]
    v4_part = process_by_ipvx_range_single(v4_rows, g_src_dst, ipv4_ranges_str)
    v6_part = process_by_ipvx_range_single(v6_rows, g_src_dst, ipv6_ranges_str)
    return pd.concat([v4_part, v6_part])

def process_by_time_range(curr_data, start_offset, end_offset):
    """Keep flows that start inside the selected time window.

    Both offsets are measured from the first flow in the data set, so the
    window is ``[first + start_offset, first + end_offset]`` (inclusive).
    The end used to be measured from the LAST flow, which meant that a
    non-negative ``end_offset`` could never exclude anything.

    Args:
        curr_data: Flow records with a "START TIME - FIRST SEEN" column of
            ISO-format timestamps.
        start_offset: Window start, in minutes after the first flow.
        end_offset: Window end, in minutes after the first flow.

    Returns:
        The rows inside the window; ``curr_data`` itself when it is empty.
    """
    # NOTE(review): offsets are applied as minutes, as before; confirm this
    # is the unit the "Time range" slider is meant to use.
    if curr_data.empty:
        return curr_data

    first_seen = get_first_time_flow()
    start = first_seen + datetime.timedelta(minutes=start_offset)
    end = first_seen + datetime.timedelta(minutes=end_offset)

    # Parse the column once and reuse it for both bounds.
    flow_starts = curr_data["START TIME - FIRST SEEN"].apply(to_datetime)
    return curr_data[(flow_starts >= start) & (flow_starts <= end)]

def process_by_protocol(curr_data, protocol):
    """Keep only rows whose "PROTOCOL" column equals ``protocol``.

    An empty frame, or the special value "ALL", returns ``curr_data``
    unchanged.
    """
    if curr_data.empty or protocol == "ALL":
        return curr_data

    # "SMTP" and "DNS" are currently filtered exactly like any other
    # protocol (open question: how to detect DNS/SMTP servers).
    matches_protocol = curr_data["PROTOCOL"] == protocol
    return curr_data[matches_protocol]

def process_by_port(curr_data, ports, g_src_dst):
    """Keep flows whose port is in the selected port list.

    Args:
        curr_data: Flow records to filter.
        ports: "ALL" (no filtering) or a comma-separated list of port
            numbers such as ``"8080, 8888, 591, 82"``.
        g_src_dst: Direction column; the "SOURCE PORT" column is tested when
            it equals ``SRC_IP``, otherwise "DESTINATION PORT".

    Returns:
        The matching rows in their original order; ``curr_data`` itself when
        it is empty or when ``ports`` is "ALL".

    Raises:
        ValueError: If an entry of ``ports`` is not an integer.
    """
    if curr_data.empty or ports == "ALL":
        return curr_data

    port_dir = "SOURCE PORT" if g_src_dst == SRC_IP else "DESTINATION PORT"

    # Tolerate any spacing around the commas and ignore empty entries.
    wanted_ports = [int(port) for port in ports.split(',') if port.strip()]

    # One membership test replaces a concat-per-port loop seeded with a
    # column-less frame.
    return curr_data[curr_data[port_dir].isin(wanted_ports)]

def get_traffic_data(g_src_dst, g_data_type, top_n, time_period,
                     ipv4_range_str, ipv6_range_str, single_or_pair,
                     protocol, ports, minimal_duration):
    """Filter the flow data, aggregate it and build the top-N bar chart.

    The global ``data`` frame goes through the time, port, protocol and
    IP-range filters (in that order) and is then reduced by
    ``process_by_type``. The first column of the result is the bar label and
    the second the bar value.

    Args:
        g_src_dst: Address column used for grouping and direction.
        g_data_type: Statistic key understood by ``process_by_type``.
        top_n: Number of bars; also selects the chart height.
        time_period: ``(start_offset, end_offset)`` pair.
        ipv4_range_str: Comma-separated IPv4 networks.
        ipv6_range_str: Comma-separated IPv6 networks.
        single_or_pair: "Single" or "Pair".
        protocol: Protocol filter value.
        ports: Port filter value.
        minimal_duration: Minimum duration for the "DURATION" statistic.

    Returns:
        An ``alt.JupyterChart`` wrapping the bar chart.
    """
    period_start, period_end = time_period

    processed_data = process_by_time_range(data, period_start, period_end)
    processed_data = process_by_port(processed_data, ports, g_src_dst)
    processed_data = process_by_protocol(processed_data, protocol)
    processed_data = process_by_ip_range(processed_data, g_src_dst,
                                         ipv4_range_str.split(','),
                                         ipv6_range_str.split(','),
                                         single_or_pair)
    processed_data = process_by_type(processed_data, g_src_dst, g_data_type,
                                     top_n, single_or_pair, minimal_duration)

    # NOTE(review): a result with fewer than two columns (possible when every
    # row was filtered out) raises IndexError here, as it did before.
    label_column = processed_data.columns[0]
    value_column = processed_data.columns[1]

    # `sort_values` returns a new frame; the result used to be discarded.
    # Descending order matches the chart's '-x' sort.
    processed_data = processed_data.sort_values(value_column, ascending=False)

    if g_data_type in ["DURATION", "DNS QNAME", "UNUSUAL TLD"]:
        wanted_tooltips = [g_data_type, "COUNT"]
    else:
        wanted_tooltips = [g_data_type]
    # Only reference columns that exist: for Pair + Flows the value column is
    # 'flows', not g_data_type, and an unknown tooltip field breaks the chart.
    tooltips = [name for name in wanted_tooltips if name in processed_data.columns]
    if not tooltips:
        tooltips = [value_column]

    print(processed_data)

    # NOTE(review): "Category" is not a column produced by this pipeline, so
    # this selection probably never carries a value. Confirm the intended
    # field together with `on_selection_change`.
    point = alt.selection_point(name="select", on="click", fields=["Category"])

    graph = alt.Chart(processed_data).mark_bar().encode(
        x=value_column,
        y=alt.Y(label_column, sort='-x'),
        tooltip=tooltips,
        color=alt.Color(value_column,
                        scale=alt.Scale(range=['lightgreen', 'green'])),
    ).properties(
        width=600,
        height=1000 if top_n >= 30 else 600,
        autosize=alt.AutoSizeParams(
            type='fit',
            contains='padding',
        ),
    ).add_params(point)

    return alt.JupyterChart(graph)


def on_change_singel_pair(v):
    """Disable the direction toggle in 'Pair' mode, enable it otherwise."""
    in_pair_mode = single_or_pair_widget.value == 'Pair'
    toggle = disable_button if in_pair_mode else enable_button
    toggle(src_dst_widget)

def on_change_data_type(v):
    """Adjust the dependent controls after the data-type toggle changes.

    The argument ``v`` is not used; the current widget values are read
    directly.

    NOTE(review): assigning to ``single_or_pair_widget`` traits below may
    re-enter ``on_change_singel_pair`` if that handler is registered on the
    widget, so the statement order here matters. Verify before reordering
    or simplifying.
    """
    clear_output(wait=True)
    if data_type_widget.value == 'Duration with dst IP':
        # Duration statistics: force 'Single', open the minimal-duration
        # field with a default of "5" and lock both toggles.
        single_or_pair_widget.value = 'Single'
        minimal_duration_widget.disabled = False
        minimal_duration_widget.value = "5"
        disable_button(src_dst_widget)
        disable_button(single_or_pair_widget)
    else: 
        # Any other data type: close the minimal-duration field again.
        minimal_duration_widget.disabled = True
        minimal_duration_widget.value = "disabled"
        # Unlock the toggles unless 'Pair' is selected on an enabled toggle.
        if single_or_pair_widget.value != 'Pair' or single_or_pair_widget.disabled:
            enable_button(src_dst_widget)
            enable_button(single_or_pair_widget)

    # DNS-based statistics: force 'Single' and lock both toggles (this
    # overrides the unlocking done in the `else` branch above).
    if data_type_widget.value in ['IP requesting unusual TLD',
                                    'Subdomains of muni.cz',
                                    'Unusual TLDs', 
                                    'Most DNS requests']:

        single_or_pair_widget.value = 'Single'
        disable_button(src_dst_widget)
        disable_button(single_or_pair_widget)
        
    elif data_type_widget.value != 'Duration with dst IP':
        # Remaining data types: repeats the statements of the `else` branch
        # above.
        minimal_duration_widget.disabled = True
        minimal_duration_widget.value = "disabled"
        if single_or_pair_widget.value != 'Pair' or single_or_pair_widget.disabled:
            enable_button(src_dst_widget)
            enable_button(single_or_pair_widget)

# Function to print selected bar
def on_selection_change(change):
    """Print the "Category" entry of the new selection, if there is one."""
    clear_output(wait=True)
    selection = change.new
    if not (selection and "Category" in selection):
        print("No selection")
        return
    print(f"Selected Category: {selection['Category']}")

# def update_selection(event):
#     selected = event.new.get(DST_IP, [])
#     if selected:
#         selected_value = selected[0][DST_IP] 
#         result_pane.object = f"**Selected:** {selected_value} → {data[selected_value]}"
#     else:
#         result_pane.object = "**No selection**"

def on_change(v):
    """Rebuild and display the chart from the current widget values.

    The argument ``v`` is not used. The cell output is cleared, the widget
    values are translated into the keys used by the processing functions,
    and the controls are shown again above the new chart.
    """
    clear_output(wait=True)

    # NOTE(review): 'Received' selects the SOURCE address column here, which
    # looks inverted (a receiver is normally the destination). Kept as is;
    # confirm the intended mapping.
    g_src_dst = SRC_IP if src_dst_widget.value == 'Received' else DST_IP

    # Widget label -> statistic key; labels not listed fall back to
    # "DURATION" (i.e. 'Duration with dst IP').
    data_type_keys = {
        "Bytes": "BYTES",
        "Packets": "PACKETS",
        "Flows": "FLOWS",
        "Subdomains of muni.cz": "DNS QNAME",
        "Unusual TLDs": "UNUSUAL TLD",
        "IP requesting unusual TLD": "IP REQUESTING UNUSUAL TLD",
        "Most DNS requests": "MOST DNS REQUESTS",
    }
    g_data_type = data_type_keys.get(data_type_widget.value, "DURATION")

    graph: alt.JupyterChart = get_traffic_data(
        g_src_dst, g_data_type,
        int(top_n_widget.value), period_widget.value,
        ipv4_range_widget.value, ipv6_range_widget.value,
        single_or_pair_widget.value, protocol_widget.value,
        port_widget.value, minimal_duration_widget.value,
    )

    display_widgets()
    display(graph)
    # Register the selection callback. The call used to be wrapped in
    # print(), which only wrote a stray "None" below the chart.
    graph.selections.observe(on_selection_change, "select")


# Wire the handlers to the widgets and show the controls once.
# NOTE(review): `on_trait_change` is called without a trait name, so these
# handlers presumably fire for any trait of the widget, not only `value`.
# traitlets also documents it as deprecated in favour of `observe`. Verify
# how the handlers interact before migrating.
single_or_pair_widget.on_trait_change(on_change_singel_pair)
data_type_widget.on_trait_change(on_change_data_type)
apply_changes_widget.on_click(on_change)
display_widgets()



     DESTINATION IP ADDRESS  DURATION  COUNT
4407        147.251.253.196       897      3
3438        147.251.164.122       889      4
3530          147.251.164.9       877      3
3433        147.251.164.114       845      3
97            10.16.108.124       617      4
3442        147.251.164.128       615      3
161           10.16.108.218       596      2
789             10.56.2.147       596      2
6101        157.157.126.240       596      2
7037         195.113.158.65       595      3


Dropdown(description='Top N stats', options=('10', '20', '30', '40'), style=DescriptionStyle(description_width…

ToggleButtons(description='Info about single IP adresss or pair:', disabled=True, options=('Single', 'Pair'), …

ToggleButtons(description='Direction:', disabled=True, options=('Received', 'Sent'), style=ToggleButtonsStyle(…

ToggleButtons(button_style='info', description='Data type:', index=3, layout=Layout(width='520px'), options=('…

Text(value='5', description='Minimal duration of a flow in seconds', layout=Layout(width='50%'), style=TextSty…

IntRangeSlider(value=(0, 36000), description='Time range', max=36000, step=5, style=SliderStyle(description_wi…

Textarea(value='0.0.0.0/0,', description='List of IPv4 masks, separeted by ","', layout=Layout(width='50%'), s…

Textarea(value='0:0:0:0:0:0:0:0/0,', description='List of IPv6 masks, separeted by ","', layout=Layout(width='…

Dropdown(description='Protocol', options=('ALL', 'UDP', 'SMTP', 'DNS', 'HTTP', 'HTTPS', 'SMB', 'TELNET', 'RDP'…

Dropdown(description='Ports', options=('ALL', '8080, 8888, 591, 82', '8443, 9443, 4443', '5353, 5355', '2121, …

Button(button_style='success', description='Apply changes', style=ButtonStyle(), tooltip='Apply changes')

JupyterChart(spec={'config': {'view': {'continuousWidth': 300, 'continuousHeight': 300}}, 'data': {'name': 'da…

None
