In [1]:
!pip install torch

Collecting torch
  Downloading torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nv

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from embedding_new import PacketEmbedding


class PacketLevelEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_len, num_heads, num_layers, dropout):
        super(PacketLevelEncoder, self).__init__()

        # initialise the embedding layer
        self.embedding = PacketEmbedding(
            vocab_size, max_len, embed_dim, dropout)

        # initialsise the encoder from PyTorch
        self.encoder_layer = nn.TransformerEncoderLayer(
            embed_dim, num_heads, embed_dim * 4, dropout)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers)

        # initialise the mlm and sfbo predictor
        self.mlm_predictor = nn.Linear(embed_dim, vocab_size)
        self.sfbo_predictor = nn.Sequential(
            nn.Linear(embed_dim * 3, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, vocab_size)
        )

    def forward(self, packet_sequences, field_pos, header_pos):

        # splits the seq into masked seqs
        masked_packets, span_masks = apply_mlm_sfbo_masking(packet_sequences)
        packet_embeddings = self.embedding(masked_packets, field_pos, header_pos)
        encoded_packets = self.encoder(packet_embeddings)

        mlm_loss, mlm_logits = self.compute_mlm_loss(encoded_packets, masked_packets)
        # sfbo_loss = self.compute_sfbo_loss(
        #     encoded_packets, span_masks, packet_sequences)
        sfbo_loss = 0

        return mlm_loss, sfbo_loss, encoded_packets, masked_packets, mlm_logits 

    def compute_mlm_loss(self, encoded_packets, masked_packets):
        mlm_logits = self.mlm_predictor(encoded_packets)
        mlm_loss = F.cross_entropy(
            mlm_logits.view(-1, mlm_logits.size(-1)), masked_packets.view(-1))
        return mlm_loss, mlm_logits

    def compute_sfbo_loss(self, encoded_packets, span_masks, packet_sequences):
        sfbo_loss = 0
        total_spans = 0

        for i, (start_tokens, end_tokens, span_tokens) in enumerate(span_masks):
            if len(start_tokens) == 0:  # Skip if no spans for this sequence
                continue

            start_embeddings = encoded_packets[i, start_tokens].view(-1, encoded_packets.size(-1))
            end_embeddings = encoded_packets[i, end_tokens].view(-1, encoded_packets.size(-1))
            span_embeddings = encoded_packets[i, span_tokens].view(-1, encoded_packets.size(-1))

            # Ensure the correct shape
            print("start_embeddings shape:", start_embeddings.shape)
            print("end_embeddings shape:", end_embeddings.shape)
            print("span_embeddings shape:", span_embeddings.shape)

            # Concatenate the embeddings along the feature dimension (dim=-1)
            span_representations = torch.cat([start_embeddings, end_embeddings, span_embeddings], dim=-1)

            # Predict the entire span sequence
            sfbo_logits = self.sfbo_predictor(span_representations)

            # Flatten logits and targets
            flat_logits = sfbo_logits.view(-1, sfbo_logits.size(-1))
            flat_targets = packet_sequences[i, span_tokens].view(-1)

            # Calculate the loss for each token in the span
            sfbo_loss += F.cross_entropy(flat_logits, flat_targets)
            total_spans += len(span_tokens)

        return sfbo_loss / total_spans if total_spans > 0 else torch.tensor(0.0).to(packet_sequences.device)




def apply_sfbo_masking(packet_seq, sfbo_prob, max_span_length):
    num_tokens = packet_seq.size(0)
    sfbo_mask = torch.zeros(num_tokens, dtype=torch.bool)

    start_indices = []
    end_indices = []
    span_indices = []

    num_spans = max(1, int(sfbo_prob * num_tokens))
    print("num spans: ", num_spans)
    while num_spans > 0:
        start = random.randint(0, num_tokens - 1)
        span_length = random.randint(1, max_span_length)
        end = min(start + span_length, num_tokens) - 1
        if not sfbo_mask[start:end + 1].any():
            start_indices.append(start)
            end_indices.append(end)
            span_indices.extend(range(start, end + 1))
            sfbo_mask[start:end + 1] = True
            num_spans -= 1
    
    start_indices_tensor = torch.tensor(start_indices, dtype=torch.long)
    end_indices_tensor = torch.tensor(end_indices, dtype=torch.long)
    span_indices_tensor = torch.tensor(span_indices, dtype=torch.long)
    print("Number of spans selected:", len(start_indices_tensor))

    return start_indices_tensor, end_indices_tensor, span_indices_tensor

def apply_mlm_sfbo_masking(packet_sequences, mlm_prob=0.15, sfbo_prob=0.15, max_span_length=6):
    batch_size, seq_length, _ = packet_sequences.size()
    masked_sequences = packet_sequences.clone()

    # Apply MLM masking
    mlm_mask = torch.rand(batch_size, seq_length) < mlm_prob
    masked_sequences[mlm_mask] = torch.tensor(4).to(packet_sequences.device)

    # Apply SFBO masking
    span_masks = []
    for packet_seq in masked_sequences:
        span_mask = apply_sfbo_masking(packet_seq, sfbo_prob, max_span_length)
        span_masks.append(span_mask)

    return masked_sequences, span_masks


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from Input_Tokenizer import Tokenizer
from embedding_new import PacketEmbedding, FlowEmbedding
# from packet_encoder import PacketLevelEncoder, apply_mlm_sfbo_masking
# from flow_encoder import FlowLevelEncoder
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from field_header_pos_encoding import field_pos, header_pos

In [3]:
# Hyperparameters
vocab_size = 30000
embed_dim = 768
num_heads = 12
num_layers = 6
dropout = 0.1
max_flow_length = 510
mask_prob = 0.15
num_epochs = 10
max_len=512

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
device

device(type='cuda')

In [6]:
# loading vocabulary
vocab = {}
custom_vocab_path = r"/home/satvik/spark/spark/vocab_1.txt"
with open(custom_vocab_path, 'r', encoding='utf-8') as f:
    for line in f:
        token, token_id = line.strip().split('\t')
        vocab[token] = int(token_id)

In [7]:
# initialise the tokenizer
tokenizer = Tokenizer(vocab_file=custom_vocab_path)

In [26]:
# Initialize packet embedding and encoding modules
packet_embedding = PacketEmbedding(
    vocab_size, max_len=512, embed_dim=embed_dim, dropout=dropout).to(device)
packet_encoder = PacketLevelEncoder(   
    vocab_size, embed_dim, max_len, num_heads, num_layers, dropout).to(device)



In [9]:
optimizer = optim.Adam(list(packet_embedding.parameters(
)) + list(packet_encoder.parameters()) , lr=0.001)
criterion = nn.CrossEntropyLoss()

In [10]:
# data loader


class PacketSequenceDataset(Dataset):
    def __init__(self, packet_seq_dir, field_pos_dir, header_pos_dir, tokenizer):
        self.packet_seq_dir = packet_seq_dir
        self.field_pos_dir = field_pos_dir
        self.header_pos_dir = header_pos_dir
        self.tokenizer = tokenizer
        self.packet_sequences = []
        self.field_pos_sequences = []
        self.header_pos_sequences = []

        # Preprocess data
        packet_seq_files = [os.path.join(packet_seq_dir, file) for file in os.listdir(
            packet_seq_dir) if file.endswith('.txt')]
        field_pos_files = [os.path.join(field_pos_dir, file) for file in os.listdir(
            field_pos_dir) if file.endswith('.txt')]
        header_pos_files = [os.path.join(header_pos_dir, file) for file in os.listdir(
            header_pos_dir) if file.endswith('.txt')]

        for packet_seq_file, field_pos_file, header_pos_file in zip(packet_seq_files, field_pos_files, header_pos_files):
            with open(packet_seq_file, 'r', encoding='utf-8') as f:
                hex_dumps = f.readlines()
                padded_all_tokens, token_ids, mask, max_length = self.tokenizer.encode_packet(
                    hex_dumps)
                self.packet_sequences.append(token_ids.to(device))
                print("td :", token_ids.shape)

            '''
            with open(packet_seq_file, 'r', encoding='utf-8') as f:
                for line in f:
                    tokens = self.tokenizer.encode_packet(
                        line.strip(), add_special_tokens=True, truncation=True, padding='max_length')
                    self.packet_sequences.append(tokens)
                '''
            field_posn = field_pos(field_pos_file).to(device).long()
            print("fp: ", field_posn.shape) 
            self.field_pos_sequences.append(field_posn)

            header_posn = header_pos(header_pos_file).to(device).long()
            print("hp :", header_posn.shape)
            self.header_pos_sequences.append(header_posn)

    def __len__(self):
        return len(self.packet_sequences)


    def __getitem__(self, idx):
        return (
        self.packet_sequences[idx],
        self.field_pos_sequences[idx],
        self.header_pos_sequences[idx]
        )

In [11]:
import importlib
importlib.invalidate_caches()

In [12]:
# Create the dataset and data loader
dataset = PacketSequenceDataset(
    packet_seq_dir=r'/home/satvik/spark/spark/packets', field_pos_dir='/home/satvik/spark/spark/fields', header_pos_dir='/home/satvik/spark/spark/headers', tokenizer=tokenizer)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

td : torch.Size([4, 469])
field: torch.Size([4, 469])
fp:  torch.Size([4, 469])
header: torch.Size([4, 469])
hp : torch.Size([4, 469])


In [27]:
# Training loop
for epoch in range(num_epochs):
    print("epoch 1")
    for packet_sequences, field_pos, header_pos in train_loader:
        
        # Forward pass
        print("1")
        field_pos = field_pos.squeeze(0)
        header_pos = header_pos.squeeze(0)

        # for i in range(4):
        #     print(field_pos[i][:])
        print(header_pos.shape)
        # print("packet seq :", packet_sequences.shape)
        # print(packet_sequences)
        # packet_sequences = packet_sequences.squeeze(0)
        print("packet seq shape: ", packet_sequences.shape)

        mlm, sfbo, enc, mask, mlm_log = packet_encoder(packet_sequences, field_pos, header_pos)

        print('mlm:', mlm)
        print(sfbo)
        print(mask.shape)
        print(enc.shape)
        print(mlm_log.shape)
        break
    break
        # packet_seq_cpu = packet_sequences.cpu()
        # packet_seq_np = packet_seq_cpu.numpy()

        # # Open a file to write the data
        # with open('/home/satvik/spark/spark/packet_seq.txt', 'w') as file:
        #     for row in packet_seq_np:
        #         # Convert each row to a space-separated string and write it to the file
        #         row_str = ' '.join(map(str, row))
        #         file.write(row_str + '\n')

        # print("Tensor data has been written to 'packet_seq.txt'.")

        # Generate masked tokens
        # masked_packets, span_masks = apply_mlm_sfbo_masking(packet_sequences)
        # print("masked packets shape: ", masked_packets.shape)
        # print("span masks: ", span_masks)

        # # moving to same device
        # packet_sequences = packet_sequences.to(device)
        # fiels_pos = field_pos.to(device)
        # header_pos = header_pos.to(device)
        # masked_packets = masked_packets.to(device)
        
        # # compute packet embeddings
        # packet_embeddings = packet_embedding(
        #     token_ids=masked_packets, field_pos=field_pos, header_pos=header_pos)
        # print("pe shape: ", packet_embeddings.shape)
        
        # # calucate loss
        # mlm_loss = packet_encoder.compute_mlm_loss(packet_embeddings, masked_packets)
        # sfbo_loss = packet_encoder.compute_sfbo_loss(packet_embeddings, span_masks, packet_sequences)
        # # flow_encodings, mpm_losses = flow_encoder(flow_sequences)
        # print("mlm loss: ", mlm_loss.item())
        # print("sfbo_loss: ", sfbo_loss.item())
        # total_loss = mlm_loss  + sfbo_loss
        # print("total loss: ", total_loss.item())
        
        # # Compute total loss
        # # total_loss = mlm_loss + sfbo_lossD
        # ''' + sum(mpm_losses)'''
        
        # optimizer.zero_grad()
        # total_loss.backward()
        # optimizer.step()

epoch 1
1
torch.Size([4, 469])
packet seq shape:  torch.Size([1, 4, 469])
num spans:  1
Number of spans selected: 1
token ids : torch.Size([4, 469])
num packets:  4
seq len:  469
torch.Size([4, 469, 768])
torch.Size([4, 469, 768])
torch.Size([4, 469, 768])
torch.Size([4, 469, 768])
mlm: tensor(10.4134, device='cuda:0', grad_fn=<NllLossBackward0>)
0
torch.Size([1, 4, 469])
torch.Size([4, 469, 768])
torch.Size([4, 469, 30000])
