#### Helper stuff

In [85]:
import numpy as np
from pathlib import Path
from collections import Counter
from scipy.stats import entropy
from scipy.spatial.distance import jensenshannon
from nlp_cyber_ner.dataset import read_iob2_file
from nlp_cyber_ner.dataset import read_aptner, read_cyner, remove_leakage
from nlp_cyber_ner.config import PROCESSED_DATA_DIR, RAW_DATA_DIR, INTERIM_DATA_DIR, TOKENPROCESSED_DATA_DIR
import spacy
from spacy.tokens import Doc



#### We will derive 3 matrices:
- dev to dev 
- train to train
- train (deduplicated) to dev

Interpretations:
- dev to dev: divergence between the datasets the models were evaluated on. For a given model, one dev set comes from the same population as its training data and the others come from foreign populations.
- train to train: divergence between the data the model saw during training and the training data it would (probably) ideally have liked to see. I say "probably" because a model may end up performing better on a foreign dev set than the model trained on that dev set's own training set. This is relevant for comparison with the models on the diagonal (models whose train and dev sets come from the same dataset).
- train (deduplicated) to dev: divergence between the data the model was trained on and the data it was evaluated on.

#### Divergences for labels - token based counts

Potential problem here: the function simply strips the prefixes, so an entity spanning 3 tokens tagged B-, I-, I- is counted as 3 instances of that label. This is probably not what we want and should be fixed. I will create versions both WITH and WITHOUT this fix.

In [86]:
def extract_entity_label(label):
    """
    Return the entity type of an IOB2 tag.

    "B-malware" and "I-malware" both become "malware". Any label that does not
    begin with "B-" or "I-" (for example "O") is returned unchanged.
    """
    if label[:2] not in ("B-", "I-"):
        return label
    return label[2:]

def get_label_distribution_ignoringprefix(data, label_set):
    """
    Token-level distribution over the entity types in ``label_set``.

    Every token whose tag, after removing any "B-"/"I-" prefix, is one of the
    types in ``label_set`` adds one count to that type. A 3-token entity
    therefore counts three times. The raw counts are printed, then every entry
    gets 1e-10 added and the vector is normalized.

    Parameters
    ----------
    data : list of sentences, each a (words, labels) pair; only the labels
        (``sent[1]``) are read.
    label_set : list of entity types; fixes the order and length of the result.

    Returns
    -------
    np.ndarray of ``len(label_set)`` proportions summing to 1.
    """
    type_counts = Counter(
        entity
        for sentence in data
        for entity in map(extract_entity_label, sentence[1])
        if entity in label_set
    )
    raw = np.array([type_counts[entity] for entity in label_set], dtype=np.float64)
    print(raw)
    # The tiny constant keeps every entry > 0, so KL divergence against another
    # distribution stays finite, and avoids 0/0 when nothing matched.
    smoothed = raw + 1e-10
    return smoothed / smoothed.sum()

#### Divergence for labels - span based counts
A labeled span counts as a single occurrence of its underlying entity category. For example, consecutive B-MALWARE, I-MALWARE, I-MALWARE tags count as one MALWARE.

In [87]:
# Helper used by get_label_span_distribution: splits an IOB2 tag into
# (entity type, prefix) so span starts can be detected.

def extract_entity_labelandprefix(label):
    """
    Split an IOB2 tag into its entity type and prefix.

    Prefixed tags such as "B-Malware" return ("Malware", "B-") or
    ("Malware", "I-"). Any other label (for example "O") is returned unchanged,
    paired with the pseudo-prefix "O".
    """
    prefix = label[:2]
    if prefix in ("B-", "I-"):
        return label[2:], prefix
    return label, "O"

def get_label_span_distribution(data, label_set):
    """
    Span-level distribution over the entity types in ``label_set``.

    Each labeled span contributes a single count to its entity type. A count
    is added when an in-set tag either carries the "B-" prefix, or follows a
    token with no IOB2 prefix (or starts the sentence). The second case also
    catches spans annotated as starting with "I-". Continuation "I-" tokens add
    nothing. The previous prefix is reset for each sentence, which assumes
    spans never cross sentence boundaries.

    Parameters
    ----------
    data : list of sentences, each a (words, labels) pair; only the labels
        (``sent[1]``) are read.
    label_set : list of entity types; fixes the order and length of the result.

    Returns
    -------
    np.ndarray of ``len(label_set)`` proportions (1e-10 added to every entry
    before normalizing).
    """
    span_counts = Counter()
    for sentence in data:
        prev_prefix = "O"
        for tag in sentence[1]:
            entity, prefix = extract_entity_labelandprefix(tag)
            opens_span = prefix == "B-" or prev_prefix == "O"
            if opens_span and entity in label_set:
                span_counts[entity] += 1
            prev_prefix = prefix

    counts = np.array([span_counts[entity] for entity in label_set], dtype=np.float64)
    # Smoothing keeps every entry > 0 so KL divergences stay finite.
    smoothed = counts + 1e-10
    return smoothed / smoothed.sum()


In [None]:
# Recorded notes (kept as a bare string; it has no runtime effect): per-dataset
# entity-type counts from get_label_span_distribution with and without the
# "previous_prefix == 'O'" fallback condition. Column order presumably follows
# LABEL_SET (Malware, System, Organization, Vulnerability) — TODO confirm.
"""
With this statement:  or (ent_label in label_set and previous_prefix == "O")
[1221. 2450. 4422.  828.] - DNRTI 
[ 548.  822. 1390.  197.] - ATTACKNER 
[3119.  205. 4747.  484.] - APTNER 
[705. 838. 287.  48.] - CYNER

Without this statement: or (ent_label in label_set and previous_prefix == "O")
[1221. 2449. 4416.  828.] - DNRTI 
[ 548.  822. 1390.  197.] - ATTACKNER 
[3109.  205. 4742.  483.] - APTNER 
[703. 837. 284.  48.] - CYNER

Indicates that some labeled entities in the ground truth data start their span with I- rather than B- which 
"or (ent_label in label_set and previous_prefix == "O")" catches, since "(ent_label in label_set and prefix == "B-")" would catch everything
that starts with B-.
Very few cases of this though and not necessarily a problem - seems most prevelant in APTNER, likely because we remapped some prefixes.
"""

#### Divergence for label distributions w. prefixes
Probably worth doing, since it could highlight datasets whose annotated entities have vastly different lengths. For example, a malware entity consisting of 5 tokens contributes +1 to B-Malware and +4 to I-Malware.

In [89]:
def get_label_distribution_wprefix(data, label_set=None):
    """
    Distribution over IOB2-prefixed entity tags ("B-<type>" and "I-<type>").

    Unlike the prefix-stripping variants, B- and I- tags are counted
    separately. Datasets whose entities tend to be longer therefore put more
    mass on the I- tags. The raw counts are printed, then every entry gets
    1e-10 added and the vector is normalized.

    Parameters
    ----------
    data : list of sentences, each a (words, labels) pair; only the labels
        (``sent[1]``) are read.
    label_set : list of entity types, optional
        When None (default, the original behaviour), the columns are all B-/I-
        tags observed in ``data``, in sorted order. Two datasets can then get
        vectors of different lengths if one of them lacks a tag.
        When given, the columns are fixed to the sorted tags "B-<type>" and
        "I-<type>" for every type in ``label_set``, and tags of other types are
        ignored. Vectors from different datasets then always align.

    Returns
    -------
    np.ndarray of proportions summing to 1, one per column described above.
    """
    counter = Counter()
    for sent in data:
        for label in sent[1]:
            if label.startswith(("B-", "I-")):
                counter[label] += 1

    if label_set is None:
        tags = np.sort(list(counter.keys()))
    else:
        # Fixed, data-independent column set so JS/KL vectors share a dimension.
        tags = sorted({f"{prefix}{entity}" for entity in label_set for prefix in ("B-", "I-")})

    distribution = np.array([counter[tag] for tag in tags], dtype=np.float64)
    print(distribution)
    # Smoothing keeps every entry > 0 so KL divergences stay finite.
    distribution += 1e-10
    distribution /= distribution.sum()
    return distribution

#### Divergence for word distributions

Currently the common vocabulary is built by treating all the datasets as one and taking the 10,000 most common words. There may be a better way of doing this; we still need to check the paper Rob sent over email.

In [90]:


# Build one vocabulary shared by all datasets so their word distributions align.
def get_global_vocab(datasets, max_vocab_size=None):
    """
    Pool word counts across several datasets and return the vocabulary by frequency.

    Parameters
    ----------
    datasets : list of datasets, each a list of sentences (words, labels);
        only the words (``sent[0]``) are counted.
    max_vocab_size : int or None
        Keep only this many of the most frequent words; None keeps every word.

    Returns
    -------
    list of words ordered from most to least frequent. Words with equal counts
    keep the order in which they were first seen.
    """
    global_counter = Counter()
    for data in datasets:
        for sent in data:
            global_counter.update(sent[0])
    # most_common(None) returns every element, so both the capped and the
    # uncapped case come back in frequency order (the previous uncapped branch
    # returned insertion order, contradicting this docstring).
    return [word for word, _ in global_counter.most_common(max_vocab_size)]

# Word-frequency vector of one dataset, expressed over a fixed shared vocabulary.
def get_word_distribution(data, vocab):
    """
    Relative frequencies of the words in a fixed vocabulary.

    Words are counted over every sentence's word list (``sent[0]``); words
    outside ``vocab`` are ignored. The vector follows ``vocab``'s order, so
    distributions built with the same ``vocab`` are directly comparable.

    Returns
    -------
    np.ndarray of ``len(vocab)`` proportions (1e-10 is added to every entry
    before normalizing so no entry is exactly zero).
    """
    word_counts = Counter(word for sentence in data for word in sentence[0])
    raw = np.fromiter((word_counts[word] for word in vocab), dtype=np.float64)
    smoothed = raw + 1e-10
    return smoothed / smoothed.sum()

#### Divergence for POS distribution
We use an off-the-shelf model (spaCy's en_core_web_sm) to predict the POS tags. We have no way of knowing its accuracy on these datasets, so we assume it is equally accurate (or inaccurate) on each of them. This should be acceptable, given that we also have other metrics.

In [None]:
# Shell commands (IPython "!" syntax): download the small English spaCy
# pipeline that get_pos_distributions loads ("en_core_web_sm"), then run
# `spacy validate` to check installed pipelines against the installed spaCy version.
# Model card: https://huggingface.co/spacy/en_core_web_sm/tree/main
# Background on tag mapping: https://spacy.io/api/attributeruler, https://github.com/explosion/spaCy/issues/5637
!python -m spacy download en_core_web_sm
!python -m spacy validate

In [92]:
# Cache of loaded spaCy pipelines keyed by model name. get_pos_distributions is
# called once per dataset and once per train/dev pair, and reloading the model
# from disk each time was the dominant cost.
_SPACY_MODELS = {}


def _load_spacy_model(model_name):
    """Return the spaCy pipeline ``model_name``, loading it only on first use."""
    if model_name not in _SPACY_MODELS:
        _SPACY_MODELS[model_name] = spacy.load(model_name)
    return _SPACY_MODELS[model_name]


def get_pos_distributions(data, pos_tags=None, model_name="en_core_web_sm"):
    """
    Distribution of coarse POS tags (``token.pos_``) predicted by a spaCy model.

    Each sentence's pre-tokenized words are wrapped in a ``Doc``, so the
    dataset's own tokenization is kept, and run through the pipeline. Tags are
    counted over all tokens. The raw counts are printed, then every entry gets
    1e-10 added and the vector is normalized.

    Parameters
    ----------
    data : list of sentences, each a (words, labels) pair; only the words
        (``sent[0]``) are read.
    pos_tags : list of str, optional
        When None (default, the original behaviour), the columns are the tags
        observed in ``data``, in sorted order. Datasets whose predictions use
        different tag sets would then get vectors of different lengths.
        Pass a fixed list to guarantee aligned vectors; tags outside it are
        ignored.
    model_name : str
        spaCy pipeline to load (default "en_core_web_sm").

    Returns
    -------
    np.ndarray of proportions summing to 1, one per tag column.
    """
    nlp = _load_spacy_model(model_name)
    # NOTE: the model card (https://spacy.io/models/en#en_core_web_sm) only lists
    # the fine-grained tags. The fine -> coarse mapping appears to live in the
    # pipeline's attribute_ruler patterns (nlp.get_pipe("attribute_ruler").patterns).
    counter = Counter()
    docs = (Doc(nlp.vocab, words=sent[0]) for sent in data)
    for doc in nlp.pipe(docs):
        counter.update(token.pos_ for token in doc)

    tags = np.sort(list(counter.keys())) if pos_tags is None else list(pos_tags)
    distribution = np.array([counter[tag] for tag in tags], dtype=np.float64)
    print(distribution)
    # Smoothing keeps every entry > 0 so KL divergences stay finite.
    distribution += 1e-10
    distribution /= distribution.sum()
    return distribution

#### Divergence for span length distributions

In [93]:
def get_spanlength_distribution(data, max_length=5):
    """
    Distribution of the lengths (in tokens) of labeled entity spans.

    A span starts at a "B-" tag, or at an "I-" tag when no span is open. The
    second case tolerates annotations that start a span with "I-". A span
    continues over the following "I-" tags. It ends at any non-prefixed tag
    (e.g. "O"), at the next "B-" tag, or at the end of the sentence. Spans are
    assumed not to cross sentence boundaries.

    Fixes over the previous version:
    - The first token of every span is counted; previously it was dropped, so
      single-token spans were never recorded.
    - Spans that run to the end of a sentence are recorded.
    - Back-to-back "B-" spans are recorded separately.

    Parameters
    ----------
    data : list of sentences, each a (words, labels) pair; only the labels
        (``sent[1]``) are read.
    max_length : int
        Lengths 1..max_length get their own bin. Longer spans share the bin
        "length>{max_length}". The default of 5 gives the bins
        "1", "2", "3", "4", "5", "length>5".

    Returns
    -------
    np.ndarray of ``max_length + 1`` proportions, in the bin order above, with
    1e-10 added to every entry before normalizing.
    """
    overflow_key = f"length>{max_length}"
    bins = [str(k) for k in range(1, max_length + 1)] + [overflow_key]
    counter = Counter(dict.fromkeys(bins, 0))

    def _record(length):
        # Add a finished span to its bin; length 0 means no span was open.
        if length > max_length:
            counter[overflow_key] += 1
        elif length >= 1:
            counter[str(length)] += 1

    for sent in data:
        length = 0  # length of the currently open span; 0 when outside a span
        for label in sent[1]:
            if label.startswith("B-"):
                _record(length)  # a B- tag closes any span that was open
                length = 1
            elif label.startswith("I-"):
                # Continues the open span, or opens one if none is open (length == 0).
                length += 1
            else:
                _record(length)
                length = 0
        _record(length)  # flush a span that runs to the end of the sentence

    distribution = np.array([counter[b] for b in bins], dtype=np.float64)
    # Smoothing keeps every entry > 0 so KL divergences stay finite.
    distribution += 1e-10
    distribution /= distribution.sum()
    return distribution


#### Compute JS and KL divergence

In [94]:
def compute_divergence_matrix(distributions: dict):
    """
    Pairwise Jensen-Shannon distances and symmetric KL divergences.

    Parameters
    ----------
    distributions : dict
        Maps a dataset name to a 1-D probability vector. All vectors must have
        the same shape and the same category order.

    Returns
    -------
    (js_matrix, kl_matrix) : tuple of np.ndarray, each (n, n)
        Rows and columns follow the dict's insertion order.
        ``js_matrix`` holds scipy's Jensen-Shannon *distance* (the square root
        of the JS divergence).
        ``kl_matrix`` holds 0.5 * (KL(P||Q) + KL(Q||P)). The symmetric form
        makes the distance between two datasets independent of which one is
        taken as the reference.

    Raises
    ------
    ValueError
        If the distributions do not all have the same shape. This can happen
        when a distribution function derives its categories from the data and
        one dataset lacks a category.
    """
    dataset_names = list(distributions.keys())
    n_datasets = len(dataset_names)

    shapes = {name: np.shape(distributions[name]) for name in dataset_names}
    if len(set(shapes.values())) > 1:
        raise ValueError(
            f"All distributions must have the same shape to compare them; got {shapes}"
        )

    # Matrices for Jensen-Shannon distance and symmetric Kullback-Leibler divergence.
    js_matrix = np.zeros((n_datasets, n_datasets))
    kl_matrix = np.zeros((n_datasets, n_datasets))

    # Both measures are symmetric in (P, Q), so compute the upper triangle
    # (diagonal included) and mirror it.
    for i in range(n_datasets):
        P = distributions[dataset_names[i]]
        for j in range(i, n_datasets):
            Q = distributions[dataset_names[j]]
            js = jensenshannon(P, Q)
            kl_sym = 0.5 * (entropy(P, Q) + entropy(Q, P))
            js_matrix[i, j] = js_matrix[j, i] = js
            kl_matrix[i, j] = kl_matrix[j, i] = kl_sym
    return js_matrix, kl_matrix


#### Load data

In [95]:
# Load the train and dev splits of every corpus from PROCESSED_DATA_DIR.
# File names differ per corpus, so each one is spelled out explicitly.
# NOTE(review): word_index=0, tag_index=1 is passed only for the DNRTI and
# AttackNER *train* files; presumably these equal read_iob2_file's defaults — confirm.

# DNRTI
dnrti_path = PROCESSED_DATA_DIR / "dnrti"
dnrti_train_path, dnrti_dev_path = dnrti_path / "train.unified", dnrti_path / "valid.unified"
dnrti_train_data = read_iob2_file(dnrti_train_path, word_index=0, tag_index=1)
dnrti_dev_data = read_iob2_file(dnrti_dev_path)

# AttackNER
attackner_path = PROCESSED_DATA_DIR / "attackner"
attackner_train_path, attackner_dev_path = attackner_path / "train.unified", attackner_path / "dev.unified"
attackner_train_data = read_iob2_file(attackner_train_path, word_index=0, tag_index=1)
attackner_dev_data = read_iob2_file(attackner_dev_path)

# APTNER
aptner_path = PROCESSED_DATA_DIR / "APTNer"
aptner_train_path, aptner_dev_path = aptner_path / "APTNERtrain.unified", aptner_path / "APTNERdev.unified"
aptner_train_data = read_iob2_file(aptner_train_path)
aptner_dev_data = read_iob2_file(aptner_dev_path)

# CyNER
cyner_path = PROCESSED_DATA_DIR / "cyner"
cyner_train_path, cyner_dev_path = cyner_path / "train.unified", cyner_path / "valid.unified"
cyner_train_data = read_iob2_file(cyner_train_path)
cyner_dev_data = read_iob2_file(cyner_dev_path)

# Mismatched distribution dimensions?
# Several distribution functions derive their categories from the data rather
# than from a predefined set. If a dataset lacked a category, its vector would
# have a different length and the JS/KL computations would fail. With the
# current datasets no category is missing. A fixed category set per function
# would remove the risk entirely.

#### Getting the numbers:

In [96]:
# Named splits per corpus; this order defines the matrix rows/columns below.
train_datasets = [
    ("DNRTI", dnrti_train_data),
    ("AttackNER", attackner_train_data),
    ("APTNER", aptner_train_data),
    ("CyNER", cyner_train_data),
]

dev_datasets = [
    ("DNRTI", dnrti_dev_data),
    ("AttackNER", attackner_dev_data),
    ("APTNER", aptner_dev_data),
    ("CyNER", cyner_dev_data),
]

# Entity types compared by the label-based distributions.
LABEL_SET = ["Malware", "System", "Organization", "Vulnerability"]

# Shared vocabulary: the 10,000 most frequent words over every train split
# followed by every dev split.
all_datasets = [split for _, split in (*train_datasets, *dev_datasets)]
common_vocab = get_global_vocab(all_datasets, max_vocab_size=10000)

# Each entry maps a list of sentences to a probability vector.
distribution_methods = {
    "IgnoringPrefix": lambda data: get_label_distribution_ignoringprefix(data, LABEL_SET),
    "SpanBased": lambda data: get_label_span_distribution(data, LABEL_SET),
    "WithPrefix": get_label_distribution_wprefix,
    "WordDist": lambda data: get_word_distribution(data, common_vocab),
    "POSDist": get_pos_distributions,
    "SpanLengthDist": get_spanlength_distribution,
}

# For each distribution method, print Jensen-Shannon distance and symmetric KL
# matrices for train-vs-train, dev-vs-dev and deduplicated-train-vs-dev.
for method_name, dist_func in distribution_methods.items():
    print("\n\n===========================================")
    print(f"=== DISTRIBUTION METHOD: {method_name} ===")
    print("===========================================")

    # -------------------------------------------
    # A) TRAIN-TO-TRAIN
    # -------------------------------------------
    train_distributions = {name: dist_func(data) for name, data in train_datasets}

    js_train_train, kl_train_train = compute_divergence_matrix(train_distributions)
    print("\n--- Train-to-Train ---")
    print("Jensen-Shannon Distance Matrix:")
    print(js_train_train)
    print("\nSymmetric KL Divergence Matrix:")
    print(kl_train_train)

    # -------------------------------------------
    # B) DEV-TO-DEV
    # -------------------------------------------
    dev_distributions = {name: dist_func(data) for name, data in dev_datasets}

    js_dev_dev, kl_dev_dev = compute_divergence_matrix(dev_distributions)
    print("\n--- Dev-to-Dev ---")
    print("Jensen-Shannon Distance Matrix:")
    print(js_dev_dev)
    print("\nSymmetric KL Divergence Matrix:")
    print(kl_dev_dev)

    # -------------------------------------------
    # C) TRAIN(deduplicated)-TO-DEV
    # -------------------------------------------
    # Rows: training sets; columns: dev sets. The dev-side distribution does not
    # depend on the training set, so reuse the one from (B) rather than
    # recomputing it for every pair (expensive for POSDist, which runs spaCy).
    js_train_dev = np.zeros((len(train_datasets), len(dev_datasets)))
    kl_train_dev = np.zeros_like(js_train_dev)

    for i, (_, train_data) in enumerate(train_datasets):
        for j, (dev_name, dev_data) in enumerate(dev_datasets):
            # remove_leakage presumably drops training sentences that also occur
            # in dev_data — confirm; its second return value is unused here.
            clean_data, _ = remove_leakage(train_data, dev_data)
            P = dist_func(clean_data)
            Q = dev_distributions[dev_name]

            js_train_dev[i, j] = jensenshannon(P, Q)
            kl_train_dev[i, j] = 0.5 * (entropy(P, Q) + entropy(Q, P))

    print("\n--- Train(deduplicated)-to-Dev ---")
    print("Jensen-Shannon Distance Matrix:")
    print(js_train_dev)
    print("\nSymmetric KL Divergence Matrix:")
    print(kl_train_dev)
    print("===========================================================")



=== DISTRIBUTION METHOD: IgnoringPrefix ===
[1777. 3835. 5866. 1613.]
[ 901. 1886. 2608.  914.]
[4252.  232. 6135.  516.]
[ 898. 1321.  400.   90.]

--- Train-to-Train ---
Jensen-Shannon Distance Matrix:
[[0.         0.02911637 0.33555624 0.2947223 ]
 [0.02911637 0.         0.34469139 0.28697856]
 [0.33555624 0.34469139 0.         0.44287795]
 [0.2947223  0.28697856 0.44287795 0.        ]]

Symmetric KL Divergence Matrix:
[[0.         0.00339337 0.53514412 0.3627232 ]
 [0.00339337 0.         0.56385104 0.34459174]
 [0.53514412 0.56385104 0.         1.00717131]
 [0.3627232  0.34459174 1.00717131 0.        ]]
[280. 473. 810. 216.]
[133. 301. 331. 151.]
[444.  81. 600.  19.]
[265. 273. 162.  14.]

--- Dev-to-Dev ---
Jensen-Shannon Distance Matrix:
[[0.         0.07837468 0.28101456 0.2596976 ]
 [0.07837468 0.         0.34131368 0.25890805]
 [0.28101456 0.34131368 0.         0.29883119]
 [0.2596976  0.25890805 0.29883119 0.        ]]

Symmetric KL Divergence Matrix:
[[0.         0.024632