# In [6]:
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import pickle
import random

# Define the NetformerDatasetDownstream class
class NetformerDatasetDownstream(Dataset):
    """Map-style dataset of tokenised session flows for downstream classification.

    Each item combines one session (an iterable of tokens), its segment labels
    and its sequence-level label. All values are returned as ``float32`` tensors.
    ``random_word`` can apply an 80/10/10 masking scheme (as in BERT) controlled by
    ``mask_ratio``. ``mask_ratio`` is initialised to 0, and ``random.random()`` is
    never below 0, so by default no token is altered.
    """

    def __init__(self, input_sequences, input_labels, input_segments, seq_len=2000):
        self.seq_len = seq_len  # stored only; no method in this class reads it
        self.session_flows = len(input_sequences)
        self.sessions = input_sequences
        self.segments = input_segments
        self.labels = input_labels
        self.special_token_dict = {'PAD': 0, 'MASK': 1028}
        self.mask_ratio = 0

    def __len__(self):
        """Return the number of sessions given at construction time."""
        return self.session_flows

    def __getitem__(self, item):
        """Build a fresh dict of float32 tensors for session ``item``."""
        session, segments, sequence_label = self.get_session_flow(item)
        masked, targets, positions = self.random_word(session)
        fields = (
            ("netformer_input", masked),
            ("netformer_label", targets),
            ("netformer_idx", positions),
            ("segment_label", segments),
            ("sequence_label", sequence_label),
        )
        return {name: torch.tensor(data, dtype=torch.float32) for name, data in fields}

    def random_word(self, sentence):
        """Apply random masking to ``sentence``.

        Returns ``(masked, targets, positions)``:
        - ``masked`` is the possibly-altered token list.
        - ``targets`` holds the original token at masked positions and 0 elsewhere.
        - ``positions`` is a 0/1 flag list marking the masked positions.
        """
        masked, targets, positions = [], [], []
        for token in sentence:
            draw = random.random()
            if not draw < self.mask_ratio:
                # Token kept as-is and not used as a prediction target.
                masked.append(token)
                targets.append(0)
                positions.append(0)
                continue
            # Rescale the draw to [0, 1) inside the masked bucket.
            draw /= self.mask_ratio
            if draw < 0.8:
                replacement = self.special_token_dict['MASK']
            elif draw < 0.9:
                replacement = self.random_selection(self.sessions)
            else:
                replacement = token
            masked.append(replacement)
            targets.append(token)
            positions.append(1)

        assert len(masked) == len(targets)
        return masked, targets, positions

    def random_selection(self, input_sequences):
        """Return one token chosen uniformly from a uniformly chosen session."""
        session = input_sequences[random.randrange(len(input_sequences))]
        return session[random.randrange(len(session))]

    def get_session_flow(self, item):
        """Return ``(session, segments, label)`` stored at position ``item``."""
        return tuple(source[item] for source in (self.sessions, self.segments, self.labels))

# Load the pickled dataset object (presumably a NetformerDatasetDownstream
# instance, which is why the class must be defined before this cell runs).
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open('CIC2018-dataset-all-new.pkl', 'rb') as pkl_file:
    netformer_dataset = pickle.load(pkl_file)

# In [8]:
# Function to sample the dataset
def sample_dataset(dataset, target_labels, samples_per_label, rng=None):
    """Draw a class-balanced random ``Subset`` of *dataset*.

    Args:
        dataset: Dataset supporting ``len()`` and exposing a position-indexable
            ``labels`` container whose entries convert with ``int()``.
        target_labels: Labels to keep. Duplicates are ignored, so a label is
            never sampled twice (doing so could repeat indices in the subset).
        samples_per_label: Number of indices drawn per label, without replacement.
        rng: Optional ``random.Random`` instance for reproducible sampling.
            When ``None``, the module-level ``random`` functions are used.

    Returns:
        A ``torch.utils.data.Subset`` whose indices are grouped by label, in the
        order each label first appears in *target_labels*.

    Raises:
        ValueError: If a target label has fewer than *samples_per_label* samples.
    """
    sampler = random if rng is None else rng

    # dict.fromkeys drops duplicate labels while keeping first-seen order;
    # the resulting dict also gives O(1) membership tests below.
    label_indices = {label: [] for label in dict.fromkeys(target_labels)}

    # Collect indices for each target label.
    for idx in range(len(dataset)):
        label = int(dataset.labels[idx])
        if label in label_indices:
            label_indices[label].append(idx)

    sampled_indices = []
    for label, indices in label_indices.items():
        if len(indices) < samples_per_label:
            raise ValueError(f"Not enough samples for label {label}")
        sampled_indices.extend(sampler.sample(indices, samples_per_label))

    # Create a subset of the dataset
    return Subset(dataset, sampled_indices)

# Labels to keep, and how many flows to draw for each of them
target_labels = [0, 1, 2, 4]
samples_per_label = 200

# Draw a class-balanced subset of the full dataset
sampled_dataset = sample_dataset(
    netformer_dataset,
    target_labels,
    samples_per_label,
)
# Relabel the dataset after sampling
def relabel_dataset(dataset, label_mapping):
    """Remap the labels of the samples selected by a ``Subset``, in place.

    Items produced by the dataset's ``__getitem__`` are freshly built dicts, so
    assigning to ``item['sequence_label']`` would be discarded. The new labels
    must therefore live in the underlying dataset's ``labels`` container.

    To avoid corrupting the full (unsampled) dataset, ``dataset.dataset`` is
    replaced with a shallow copy that owns a new ``labels`` list. Only the
    positions in ``dataset.indices`` are remapped in that list; the original
    underlying dataset object is not modified.

    Args:
        dataset: A ``torch.utils.data.Subset``-like object with ``dataset`` and
            ``indices`` attributes. Its ``dataset`` must expose a
            position-indexable ``labels`` container.
        label_mapping: Dict mapping each original (int) label to its new label.

    Raises:
        KeyError: If a sampled label is missing from *label_mapping*.
    """
    import copy  # local import: only this helper needs it

    base = dataset.dataset
    new_labels = list(base.labels)
    for idx in dataset.indices:
        # Read from the untouched source labels so a repeated index is not remapped twice.
        new_labels[idx] = label_mapping[int(base.labels[idx])]

    relabeled = copy.copy(base)
    relabeled.labels = new_labels
    dataset.dataset = relabeled

# Collapse the sampled label ids onto a contiguous range (4 -> 3)
label_mapping = {0: 0, 1: 1, 2: 2, 4: 3}
relabel_dataset(sampled_dataset, label_mapping)

# Function to convert subset to required format and save
def save_dataset_as_netformer_format(dataset, filename):
    """Pickle a ``Subset`` as a dict of column-wise plain Python lists.

    Every index in ``dataset.indices`` is fetched from ``dataset.dataset``. The
    tensor fields are converted with ``.tolist()``, and ``sequence_label`` is
    converted to ``int``. The resulting dict is written to *filename* with
    ``pickle``.
    """
    list_keys = ('netformer_input', 'netformer_label', 'netformer_idx', 'segment_label')
    columns = {key: [] for key in (*list_keys, 'sequence_label')}

    base = dataset.dataset
    for position in dataset.indices:
        sample = base[position]
        for key in list_keys:
            columns[key].append(sample[key].tolist())
        columns['sequence_label'].append(int(sample['sequence_label'].item()))

    with open(filename, 'wb') as handle:
        pickle.dump(columns, handle)

# Persist the sampled and relabeled subset as a PKL file
save_dataset_as_netformer_format(
    sampled_dataset,
    'CIC2018-dataset-sampled-200.pkl',
)