Importing Libraries

In [None]:
import numpy as np
from PIL import Image
import binascii
from array import *
from scapy.all import rdpcap
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
import json
import math
import os
import subprocess
from bs4 import BeautifulSoup

Loading Model and Setting Paths

In [None]:
# Trained Keras model; read as a module-level global by
# binary_mask_generation() when it classifies the session image.
model = load_model("path/to/trained/model")
# The paths below are placeholders and must be replaced before running.
# pcap_file_path is parsed both by scapy (rdpcap) and by tshark.
pcap_file_path = "path/to/original/full/pcap/file/of/session/"
# Image that is fed to the model; its path is also split on '/' further down
# to derive the session name.
session_image_file_path = "path/to/greyscale/image/of/session/"
# Heatmap image that is thresholded into the binary mask.
heatmap_file_path = "path/to/heatmap/image/generated/with/gradcam"
# JSON file loaded with json.load(); it is iterated as a sequence of packet
# start offsets by getPcapData_from_Positions().
metadata_file_path = "path/to/metadata/json/file/of/session/"
# Used as a string prefix (f"{destination_path}{session_name}/"), so it is
# expected to end with a path separator.
destination_path = "path/to/store/results/of/reverse/lookup/"

'''Default Path of tshark. Change accordingly.'''
# Executable invoked by get_text_packets() to produce the PDML dump.
tshark_path = "C:/Program Files/Wireshark/tshark.exe"

''' Set to True in version W_Eth and Z_Eth. Set to False in version No_Eth.'''
# When False, the lookup functions skip the first 14 bytes of every packet
# and shift offsets by 14 accordingly.
has_ethernet = True

''' Set to False in case of Multiclass Classification.'''
# Selects between the 0.5 threshold and argmax in binary_mask_generation().
binary_classification = True

# Passed to cv2.resize() as the target size; image_size[0] is later used as
# the row width when flattening mask coordinates.
image_size =  (32,32)

Defining Classes

In [None]:
class packet_layer_data:
    """Name of a packet layer together with a length value.

    get_layer_for_position() fills ``length`` with ``len(layer)`` of the
    layer object returned by ``packet.getlayer(name)``.
    """

    def __init__(self, name, length):
        self.name, self.length = name, length

class packet_layer_data_real:
    """Name of a packet layer together with its own ("real") length.

    get_layer_for_position() computes ``real_length`` as the layer's length
    minus the length of the layer that follows it.
    """

    def __init__(self, name, real_length):
        self.name, self.real_length = name, real_length
        
class packet_layer_start:
    """Name of a packet layer together with the offset at which it starts."""

    def __init__(self, name, start):
        self.name, self.start = name, start
    
class yellow_regions_data:
    """One highlighted mask pixel mapped back to a byte of a packet.

    As filled in by getPcapData_from_Positions():
      position_image  -- flat pixel index inside the mask
      packet_number   -- 1-based number of the packet containing the byte
      position_packet -- offset of the byte relative to the packet start
      layer           -- name of the layer the byte falls into (or None)
      byte            -- the byte as a two-character hex string
    """

    def __init__(self, position_image, packet_number, position_packet, layer, byte):
        self.position_image, self.packet_number = position_image, packet_number
        self.position_packet = position_packet
        self.layer, self.byte = layer, byte

Defining Functions

In [None]:
def binary_mask_generation(heatmap_path, image_path, session_name, dir_name, image_size):
    """Threshold a heatmap into a binary mask and classify the session image.

    The heatmap is read, converted to RGB, resized to ``image_size`` and every
    pixel inside the yellow RGB range is set to 255 in the mask (0 elsewhere).
    The resized heatmap and the mask are saved as PNG files into ``dir_name``.
    The session image is then scaled to [0, 1] and passed to the module-level
    ``model``; the module-level ``binary_classification`` flag selects how the
    model output is interpreted.

    Args:
        heatmap_path: Path of the heatmap image.
        image_path: Path of the session image that is fed to the model.
        session_name: Name used in the plot titles and output file names.
        dir_name: Output directory prefix; it is concatenated directly with
            the file name, so it must end with a path separator.
        image_size: Target size handed to ``cv2.resize``.

    Returns:
        Tuple ``(binary_mask, confidence, prediction)`` where ``confidence``
        is a percentage and ``prediction`` is the predicted class index.

    Raises:
        FileNotFoundError: If the heatmap image is missing or unreadable.
        ValueError: If binary classification is selected but the model does
            not produce exactly one output value.
    """
    heatmap = cv2.imread(heatmap_path, cv2.IMREAD_COLOR)
    if heatmap is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Heatmap image is missing or unreadable: {heatmap_path}")
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    colored_heatmap = cv2.resize(heatmap, image_size)

    # RGB bounds of the colour range that counts as "highlighted".
    lower_yellow_threshold = np.array([150, 165, 0])
    upper_yellow_threshold = np.array([255, 255, 200])

    binary_mask = cv2.inRange(colored_heatmap, lower_yellow_threshold, upper_yellow_threshold)

    plt.imshow(colored_heatmap)
    plt.title(session_name)
    plt.axis('off')
    plt.savefig(f"{dir_name}heatmap_{session_name}.png", bbox_inches='tight', pad_inches=0.2)

    plt.imshow(binary_mask, cmap="gray")
    plt.title(f'Binary Mask of {session_name}')
    plt.axis('off')
    # The keyword is bbox_inches; savefig has no 'bbox' argument.
    plt.savefig(f"{dir_name}binary_mask_{session_name}.png", bbox_inches='tight', pad_inches=0.2)

    # Close the image file as soon as the pixel data has been copied.
    with Image.open(image_path) as session_image:
        input_image = np.array(session_image) / 255.0
    input_image = np.expand_dims(input_image, axis=0)

    prediction = model.predict(input_image)

    if binary_classification:
        if np.size(prediction) != 1:
            raise ValueError(
                "binary_classification expects a single model output, "
                f"got shape {np.shape(prediction)}"
            )
        # The output is compared against 0.5, i.e. treated as the probability
        # of class 1. A softmax over that single value is always 1.0, so the
        # confidence is derived from the probability of the chosen class.
        probability = float(np.ravel(prediction)[0])
        confidence = 100 * max(probability, 1.0 - probability)
        prediction = 1 if probability > 0.5 else 0
    else:
        score = tf.nn.softmax(prediction[0])
        confidence = 100 * np.max(score)
        prediction = np.argmax(prediction, axis=1)[0]

    print(f"This packet most likely belongs to {prediction} with a confidence score of {confidence}.")

    return binary_mask, confidence, prediction
        
def map_to_pcap(width, binary_mask):
    """Return the flat indices of all mask entries that equal 255.

    Each row of ``binary_mask`` is scanned over its first ``width`` columns;
    an entry at (row, column) is reported as ``row * width + column``. The
    resulting list is also printed.
    """
    highlighted_positions = [
        int(row_index) * int(width) + int(column)
        for row_index, row in enumerate(binary_mask)
        for column in range(width)
        if row[column] == 255
    ]

    print(highlighted_positions)
    return highlighted_positions
        
def get_layer_for_position(relative_position, packet):
    """Return the name of the packet layer that contains a byte offset.

    Layer names are taken from the first word of each " / "-separated part of
    ``packet.summary()``. When the module-level ``has_ethernet`` flag is
    False the 'Ether' layer is left out, so offsets are counted from the
    start of the layer that follows it.

    Args:
        relative_position: Byte offset relative to the start of the packet
            (or to the first non-Ethernet layer when ``has_ethernet`` is
            False).
        packet: Packet object providing ``summary()`` and ``getlayer(name)``.

    Returns:
        The name of the last layer whose start offset is not greater than
        ``relative_position``, or None if no layer qualifies (for example a
        negative offset or a packet without resolvable layers).
    """
    summary_parts = packet.summary().split(' / ')
    layer_names = [part.strip().split(" ")[0] for part in summary_parts if part.strip()]

    # Only drop 'Ether' when it is actually present; list.remove() raises
    # ValueError otherwise (e.g. for captures without an Ethernet header).
    if not has_ethernet and 'Ether' in layer_names:
        layer_names.remove('Ether')

    # (name, len(layer)) for every layer that can be resolved by name.
    layer_lengths = []
    for layer_name in layer_names:
        layer = packet.getlayer(layer_name)
        if layer is not None:
            layer_lengths.append((layer_name, len(layer)))

    found_layer = None
    layer_start = 0

    for index, (layer_name, total_length) in enumerate(layer_lengths):
        if relative_position < layer_start:
            break
        found_layer = layer_name

        # A layer's own size is its length minus the length of the layer
        # that follows it; the last layer keeps its full length.
        if index + 1 < len(layer_lengths):
            layer_start += total_length - layer_lengths[index + 1][1]
        else:
            layer_start += total_length

    return found_layer

def getPcapData_from_Positions(positions, meta_data, filename):
    """Map flat image positions back to the packet bytes they represent.

    For every position the last entry of ``meta_data`` whose (adjusted) start
    offset is not greater than the position selects the packet. The byte at
    the resulting packet-relative offset is looked up together with the layer
    it belongs to.

    When the module-level ``has_ethernet`` flag is False, the first 14 bytes
    of every packet are skipped and every non-zero start offset is shifted by
    14. NOTE(review): this presumably matches how the No_Eth images were
    generated -- confirm against the image generation code.

    Args:
        positions: Iterable of flat image positions.
        meta_data: Sequence of packet start offsets in ascending order; entry
            ``i`` belongs to packet ``i`` of the capture.
        filename: Path of the pcap file, read with ``rdpcap``.

    Returns:
        List of ``yellow_regions_data`` objects, one per position that falls
        inside the bytes of a packet. Positions that cannot be mapped are
        skipped.
    """
    packets = rdpcap(filename)
    yellow_regions = []
    # Raw bytes per packet number, so each packet is serialised only once.
    raw_bytes_cache = {}

    for position in positions:
        found_packet = None
        start_byte = None

        for i, start in enumerate(meta_data):
            if has_ethernet or start == 0:
                threshold = start
            else:
                threshold = int(start) + 14

            if position >= threshold:
                found_packet = i + 1
                start_byte = threshold
            else:
                break

        # Skip positions before the first packet and metadata entries that
        # have no matching packet in the capture.
        if found_packet is None or found_packet > len(packets):
            continue

        packet = packets[found_packet - 1]
        relative_position = position - start_byte

        raw_bytes = raw_bytes_cache.get(found_packet)
        if raw_bytes is None:
            raw_bytes = bytes(packet)
            if not has_ethernet:
                raw_bytes = raw_bytes[14:]
            raw_bytes_cache[found_packet] = raw_bytes

        if relative_position < len(raw_bytes):
            # Two lowercase hex digits, as binascii.hexlify would produce.
            byte_data = f"{raw_bytes[relative_position]:02x}"

            layer = get_layer_for_position(relative_position, packet)
            yellow_regions.append(
                yellow_regions_data(position, found_packet, relative_position, layer, byte_data)
            )

    return yellow_regions
        
def get_text_packets(pcap_path):
    """Dissect a pcap file with tshark and return its fields per packet.

    tshark (located via the module-level ``tshark_path``) is run with
    ``-T pdml``; its output is written to ``./pdml_dump.pdml`` and parsed
    with BeautifulSoup. For every ``<packet>`` a flat list of dictionaries is
    built: one entry per ``<proto>`` (the layer header line) followed by one
    entry per ``<field>`` inside it. The ``frame`` and ``geninfo``
    pseudo-protocols are skipped.

    Each dictionary has the keys ``layer_name``, ``field_abbr``,
    ``field_name``, ``field_value``, ``field_show``, ``field_length`` and
    ``field_start_byte``; length and start byte are the raw ``size`` / ``pos``
    attribute strings (or None when the attribute is absent).

    Args:
        pcap_path: Path of the capture file to dissect.

    Returns:
        List with one list of field dictionaries per packet, or None if
        running tshark or parsing its output failed (the error is printed).
    """
    pdml_file_path = './pdml_dump.pdml'
    # 'geninfo' is the pseudo-protocol tshark emits at the start of every PDML
    # packet; 'getinfo' is kept because the previous version excluded it.
    skipped_protocols = ("frame", "geninfo", "getinfo")

    try:
        # Redirect stdout from Python instead of passing '>' through a shell:
        # this works on every platform and avoids shell quoting of the paths.
        with open(pdml_file_path, 'wb') as pdml_out:
            subprocess.run([tshark_path, '-r', pcap_path, '-T', 'pdml'], stdout=pdml_out, check=True)

        with open(pdml_file_path, 'r', encoding='utf-8') as file:
            data = file.read()

        pdml_data = BeautifulSoup(data, 'xml')

        text_packets = []

        for packet in pdml_data.find_all('packet'):
            packet_fields = []

            for protocol in packet.find_all('proto'):
                layer_name = protocol.get('name')
                if layer_name in skipped_protocols:
                    continue

                # Header entry describing the layer itself.
                packet_fields.append({
                    "layer_name": layer_name,
                    "field_abbr": layer_name,
                    "field_name": f"Layer {layer_name.upper()}:",
                    "field_value": "",
                    "field_show": "",
                    "field_length": protocol.get('size'),
                    "field_start_byte": protocol.get('pos')
                })

                for field in protocol.find_all('field'):
                    # Prefer the human readable 'showname', fall back to 'show'.
                    field_name = field.get('showname', field.get('show'))

                    packet_fields.append({
                        "layer_name": layer_name,
                        "field_abbr": field.get('name'),
                        "field_name": f"    {field_name}",
                        "field_value": field.get('value'),
                        "field_show": field.get('show'),
                        "field_length": field.get('size'),
                        "field_start_byte": field.get('pos')
                    })

            text_packets.append(packet_fields)

        if text_packets != []:
            print(len(text_packets))
        else:
            print('Error in text_packets in get_text_packets function!')

        return text_packets

    except Exception as e:
        # Best effort: report the problem and signal failure with None.
        print(e)
        return None
        
def get_indices_of_fields(text_packets, highlighted_data):
    """Collect, per packet, the dissected fields covering each highlighted byte.

    Args:
        text_packets: Per-packet lists of field dictionaries as produced by
            get_text_packets(); None is treated like an empty list.
        highlighted_data: Iterable of objects with ``packet_number``
            (1-based), ``position_packet`` and ``byte`` attributes; None is
            treated like an empty iterable.

    Returns:
        One dictionary per packet. Packets without highlighted bytes yield
        ``{"is_highlighted_packet": False, "line_indices": []}``; all others
        yield ``{"is_highlighted_packet": True, "line_content": [...]}`` where
        every list item holds ``hex_byte_pos``, ``hex_byte_val`` and
        ``hex_byte_data`` (the names of all fields that contain the byte).
    """
    # Group the highlighted bytes once instead of rescanning them per packet.
    bytes_by_packet = {}
    for hexbyte in highlighted_data or []:
        bytes_by_packet.setdefault(hexbyte.packet_number, []).append(hexbyte)

    highlighted_layers = []

    for i, packet in enumerate(text_packets or []):
        highlighted_bytes_data = bytes_by_packet.get(i + 1)

        if not highlighted_bytes_data:
            highlighted_layers.append({
                "is_highlighted_packet": False,
                "line_indices": [],
            })
            continue

        # Field positions are relative to the full frame, so offsets that
        # were computed without the 14 leading bytes are shifted back.
        position_offset = 0 if has_ethernet else 14

        content_list = []

        for hexbyte in highlighted_bytes_data:
            byte_position = hexbyte.position_packet + position_offset

            hex_byte_data = []

            for field in packet:
                field_start = field['field_start_byte']
                field_length = field['field_length']
                # Entries without position information cannot be matched.
                if field_start is None or field_length is None:
                    continue

                field_start = int(field_start)
                # A field of size n covers bytes [start, start + n); the end
                # is exclusive, otherwise the byte following the field would
                # be attributed to it as well.
                if field_start <= byte_position < field_start + int(field_length):
                    hex_byte_data.append(field['field_name'])

            content_list.append({
                'hex_byte_pos': hexbyte.position_packet,
                'hex_byte_val': hexbyte.byte,
                'hex_byte_data': hex_byte_data
            })

        highlighted_layers.append({
            'is_highlighted_packet': True,
            'line_content': content_list
        })

    return highlighted_layers
        

Doing Reverse Lookup for Session

In [None]:
# Adjust the index of the session name in the path accordingly: the session
# image path is split on '/' and the component at this index (without its
# extension) becomes the session name.
index_of_name_in_path = 3
session_name = session_image_file_path.split('/')[index_of_name_in_path].split('.')[0]

session_results_folder = f"{destination_path}{session_name}/"

# exist_ok lets the cell be re-run for a session that was processed before.
os.makedirs(session_results_folder, exist_ok=True)

with open(metadata_file_path, 'r') as f:
    meta_data = json.load(f)

# Row width used to convert between flat positions and (x, y) pixels.
width = image_size[0]

# binary_mask_generation takes the session name AND the output folder.
binary_mask, confidence, prediction = binary_mask_generation(
    heatmap_file_path,
    session_image_file_path,
    session_name,
    session_results_folder,
    image_size,
)

yellow_regions = map_to_pcap(width, binary_mask)

Yellow_Data = getPcapData_from_Positions(yellow_regions, meta_data, pcap_file_path)

plt.figure(figsize=(width/2, width/2), dpi=60)

text_packets = get_text_packets(pcap_file_path)

imp_fields_content = get_indices_of_fields(text_packets, Yellow_Data)

# NOTE(review): the report is written next to the session folders, not inside
# session_results_folder, so it is overwritten by every session -- confirm
# that this is intended.
with open(f"{destination_path}/layers_data.txt", 'w') as text_file:
    for packet_index, pack in enumerate(imp_fields_content):
        if not pack.get('is_highlighted_packet'):
            continue
        print("*"*50, file=text_file)
        print(f"Packet No: {packet_index+1}", file=text_file)
        print("Highlighted Content:", file=text_file)
        print("_"*20, file=text_file)
        for byte_info in pack.get('line_content'):
            print(f"Hex Byte Position: {byte_info.get('hex_byte_pos')}", file=text_file)
            print(f"Hex Byte Value: {byte_info.get('hex_byte_val')}", file=text_file)
            print("Hex Byte Content:", file=text_file)
            print(" ", file=text_file)
            for line in byte_info.get('hex_byte_data'):
                print(line, file=text_file)
            print("_"*20, file=text_file)

# Annotate every highlighted pixel with its packet number, layer and byte.
for region in Yellow_Data:
    # Build the label for this region only; without a layer the packet number
    # alone is shown instead of reusing the label of the previous region.
    if region.layer is None:
        layer = str(region.packet_number)
    elif region.layer == "Ether":
        layer = str(region.packet_number)+".Eth"
    else:
        layer = str(region.packet_number)+"."+region.layer
    data = str(region.position_packet) + "_" + region.byte
    # Flat position -> (row, column) in the mask.
    y = int(region.position_image/width)
    x = int(region.position_image) - int(y*width)
    plt.text(x-0.4, y-0.1, layer)
    plt.text(x-0.4, y+0.3, data, fontdict={'size':9})
    square = plt.Rectangle((x-0.5, y-0.5), 1, 1, fc="none", ec="red")
    plt.gca().add_patch(square)

plt.imshow(binary_mask)
plt.axis('off')
plt.title(f"Prediction: {prediction}, Confidence Score: {confidence}%")
plt.show()
plt.clf()