In [1]:
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
import torch

In [2]:
data_files = './data_ezsocket/merged_api_para_pcapdata_dataset100k.csv'
# ; is the tab character in Python
ezsocket_dataset = load_dataset("csv", data_files=data_files, delimiter=";")

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

  return pd.read_csv(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs)


In [3]:
import re

####以下函数用于讲10进制数token化#####
def format_decimal_as_hexadecimal(decimal_str):
    # Convert the decimal string to an integer
    decimal_number = int(decimal_str)
    
    # Convert the integer to a hexadecimal string
    hex_str = hex(decimal_number)[2:]  # Strip the '0x' prefix
    
    # Ensure the length of the hex string is even
    if len(hex_str) % 2 != 0:
        hex_str = '0' + hex_str
    
    # Split the hex string into pairs of characters
    hex_pairs = [hex_str[i:i+2] for i in range(0, len(hex_str), 2)]
    
    # Join the pairs with commas
    formatted_hex = ','.join(hex_pairs)
    
    return formatted_hex


def convert_number(num_str):
    num = float(num_str)
    if num.is_integer():
        num = int(num)
        sign = "-" if num < 0 else "+"
        num_str = format_decimal_as_hexadecimal(str(num).lstrip('-'))
        return f"num,{sign},{num_str},num"
    else:
        sign = "-" if num < 0 else "+"
        num_str = num_str.lstrip('-')
        integer_part, fractional_part = num_str.split('.')
        combined_num = format_decimal_as_hexadecimal(integer_part + fractional_part.rstrip('0'))
        pos_num = format_decimal_as_hexadecimal(len(fractional_part.rstrip('0')))
        return f"num,{sign},{combined_num},pos,{pos_num},num"

def process_segment(segment):
    parts = segment.split(',')
    for i, part in enumerate(parts):
        if re.match(r'^-?\d+(\.\d+)?$', part):  # Match integers and floating-point numbers
            parts[i] = convert_number(part)
    result = ','.join(parts)
    result = result.replace(",", " ")
    return result
####以上函数用于将10进制数token化：process_segment(segment)#####

####以下函数用于将payload按两位分开，用','隔开#####
def split_payload_into_pairs(text):
    # 将文本按每两个字符分割
    pairs = [text[i:i+2] for i in range(0, len(text), 2)]
    # 用逗号连接分割后的文本
    result = ' '.join(pairs)
    return result
####以上函数用于将payload按两位分开，用','隔开#####

print(process_segment('GetMGNPot3,-3.5888855,97,'))
print(split_payload_into_pairs('47494f50010001004400000000000000e037000001002c0004000000010000000d0000006d6f63686147657444617461003c2377000000003700000016b0010000000000000000000000000002000000'))


GetMGNPot3 num - 02 23 9e d7 pos 07 num num + 61 num 
47 49 4f 50 01 00 01 00 44 00 00 00 00 00 00 00 e0 37 00 00 01 00 2c 00 04 00 00 00 01 00 00 00 0d 00 00 00 6d 6f 63 68 61 47 65 74 44 61 74 61 00 3c 23 77 00 00 00 00 37 00 00 00 16 b0 01 00 00 00 00 00 00 00 00 00 00 00 00 00 02 00 00 00


In [4]:
#使用map+lambda清洗数据
# clear_ezsocket_dataset = ezsocket_dataset.map(lambda x: {"Function and Parameters": x["Function and Parameters"].split(',', 1)[1]})
clear_ezsocket_dataset = ezsocket_dataset.map(lambda x: {"Function and Parameters": [o.split(',', 1)[1] for o in x["Function and Parameters"]]}, batched=True) #可加速处理，删除前面的时间戳
clear_ezsocket_dataset = clear_ezsocket_dataset.map(lambda x: {"Function and Parameters": [process_segment(o) for o in x["Function and Parameters"]]}, batched=True) #可加速处理，10进制参数token化
clear_ezsocket_dataset = clear_ezsocket_dataset.map(lambda x: {"Data Segment": [split_payload_into_pairs(o) for o in x["Data Segment"]]}, batched=True) #可加速处理，10进制参数token化

Map:   0%|          | 0/113871 [00:00<?, ? examples/s]

Map:   0%|          | 0/113871 [00:00<?, ? examples/s]

Map:   0%|          | 0/113871 [00:00<?, ? examples/s]

In [5]:
#划分训练集测试集和验证集
ezsocket_dataset_tt = clear_ezsocket_dataset["train"].train_test_split(train_size=0.8, seed=42)
ezsocket_dataset_tvt = ezsocket_dataset_tt["train"].train_test_split(train_size=0.9, seed=42)
ezsocket_dataset_tvt["validation"] = ezsocket_dataset_tvt.pop("test")
ezsocket_dataset_tvt["test"] = ezsocket_dataset_tt["test"]
ezsocket_dataset_tvt
#保存数据集使用：Arrow:	Dataset.save_to_disk()  CSV:	Dataset.to_csv()    JSON:	Dataset.to_json()

DatasetDict({
    train: Dataset({
        features: ['Function and Parameters', 'Data Segment'],
        num_rows: 81986
    })
    validation: Dataset({
        features: ['Function and Parameters', 'Data Segment'],
        num_rows: 9110
    })
    test: Dataset({
        features: ['Function and Parameters', 'Data Segment'],
        num_rows: 22775
    })
})

In [6]:
tokenizer=BertTokenizer(vocab_file='./vocab.txt')
src_tokenizer = BertTokenizer(vocab_file='./vocab.txt')
tgt_tokenizer = BertTokenizer(vocab_file='./vocab.txt')  # Or your target language tokenizer

In [7]:

# pretrained_model = CustomEncoderDecoderModel.from_encoder_decoder_pretrained("./bert_pretrained_100k/checkpoint-6400", "./bert_pretrained_100k/checkpoint-6400")
# Load your pre-trained model
pretrained_model = EncoderDecoderModel.from_encoder_decoder_pretrained("./gpt2_pretrained-500k/checkpoint-16000-500k", "./gpt2_pretrained-500k/checkpoint-16000-500k")

# Set up the decoder to match the target language vocabulary
pretrained_model.config.decoder_start_token_id = tgt_tokenizer.cls_token_id
pretrained_model.config.bos_token_id = tgt_tokenizer.cls_token_id
pretrained_model.config.eos_token_id = tgt_tokenizer.sep_token_id
pretrained_model.config.pad_token_id = tgt_tokenizer.pad_token_id

pretrained_model.config.vocab_size = tgt_tokenizer.vocab_size

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at ./gpt2_pretrained-500k/checkpoint-16000-500k and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.10.crossattention.c_attn.bias', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.10.crossattent

In [8]:
max_input_length = 50
max_target_length = 210

def preprocess_function(examples):
    inputs = [ex for ex in examples["Function and Parameters"]]
    targets = [ex for ex in examples["Data Segment"]]
    model_inputs = src_tokenizer(inputs, max_length=max_input_length, truncation=True, padding="max_length")
    
    # Tokenize targets with the target tokenizer
    labels = tgt_tokenizer(targets, max_length=max_target_length, truncation=True, padding="max_length")

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = ezsocket_dataset_tvt.map(preprocess_function, batched=True)


Map:   0%|          | 0/81986 [00:00<?, ? examples/s]

Map:   0%|          | 0/9110 [00:00<?, ? examples/s]

Map:   0%|          | 0/22775 [00:00<?, ? examples/s]

In [9]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
)

In [10]:
import numpy as np
from datasets import load_metric

metric = load_metric("sacrebleu")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tgt_tokenizer.batch_decode(predictions, skip_special_tokens=True)
    
    labels = np.where(labels != -100, labels, tgt_tokenizer.pad_token_id)
    decoded_labels = tgt_tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]
    print(decoded_preds)
    print(decoded_labels)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

  metric = load_metric("sacrebleu")


In [11]:
trainer = Seq2SeqTrainer(
    model=pretrained_model,
    args=training_args,
    train_dataset=tokenized_datasets["test"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tgt_tokenizer,
    # compute_metrics=compute_metrics
)

trainer.train()

  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)


Epoch,Training Loss,Validation Loss
1,No log,0.026758
2,0.060900,0.025523
3,0.025500,0.024776


  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)


TrainOutput(global_step=1068, training_loss=0.0419939813319217, metrics={'train_runtime': 1391.3749, 'train_samples_per_second': 49.106, 'train_steps_per_second': 0.768, 'total_flos': 4068320186880000.0, 'train_loss': 0.0419939813319217, 'epoch': 3.0})