In [1]:
import datasets
from datasets import Dataset, load_dataset, load_from_disk, concatenate_datasets, disable_caching
import itertools
import copy
import pandas as pd
import numpy as np
import math
import os
import sys

# Prefer the notebook progress bar; fall back to the plain console bar only when
# the notebook variant cannot be imported (a bare `except:` would also swallow
# unrelated errors such as KeyboardInterrupt).
try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm

# ProjectManager is project-local (imported from the `utils` package). The single
# instance `PM` is what the cells below use for output locations
# (PM.directories, PM.fn_*), POS tag tables (PM.POS_*), and for
# PM.save_dataset / PM.load_dataset.
from utils import ProjectManager
PM = ProjectManager.ProjectManager()  # Instantiate the ProjectManager

# train ngram models
import nltk
from nltk.tag import map_tag
from nltk import ngrams
from nltk.probability import FreqDist, LaplaceProbDist

# Configure datasets and disable caching to manage memory usage
# 12 * 1000**3 = 12e9, i.e. decimal gigabytes rather than GiB
# (presumably interpreted as a byte count by `datasets` -- confirm against its docs).
datasets.config.IN_MEMORY_MAX_SIZE = 12 * (1000**3)  # 12GB
disable_caching()

# Define keys for word and sentence level features
# word_level_feature_keys: per-token fields of a sentence record.
word_level_feature_keys = ['words', 'pos_tags', 'predicate_lemmas', 'predicate_framenet_ids', 'word_senses', 'named_entities']
# sentence_level_feature_keys: per-sentence fields; same values as the default
# argument of process_sentence_features.
sentence_level_feature_keys = ['part_id', 'parse_tree', 'speaker']
# non_feature_keys: together with sentence_level_feature_keys these form the default
# `remove_keys` of process_word_features, i.e. the columns dropped before the
# word-level table is built.
non_feature_keys = ['srl_frames', 'coref_spans']

## Sentence-level features

In [2]:
def process_sentence_features(sentence_level_feature_keys=['part_id', 'parse_tree', 'speaker']):
    """
    Flatten the OntoNotes documents into one row per sentence and persist the table.

    Every split of ``conll2012_ontonotesv5`` / ``english_v12`` is concatenated into a
    single dataset. Each sentence then receives a running index, its position inside
    its document, and the index and id of that document. The resulting table is
    written twice: as a Hugging Face dataset on disk and through ``PM.save_dataset``.

    Args:
        sentence_level_feature_keys (list): Keys copied from every sentence record
            into the output table.
    """
    print("Running process_sentence_features")

    # Same base name as the CSV file, with an ".hf" suffix, inside the "hf_files" directory
    hf_output_path = PM.directories["hf_files"] / PM.fn_base_sent_features.replace(".csv", ".hf")

    # Load every split and glue them together into one corpus
    splits = load_dataset('conll2012_ontonotesv5', 'english_v12', keep_in_memory=True)
    corpus = concatenate_datasets([splits[split_name] for split_name in list(splits)])

    doc_ids = corpus["document_id"]
    docs = corpus["sentences"]

    # One (document index, position in document, document id, sentence record)
    # tuple per sentence, in corpus order
    rows = []
    for doc_idx, doc_id in enumerate(doc_ids):
        for position, sentence in enumerate(docs[doc_idx]):
            rows.append((doc_idx, position, doc_id, sentence))

    # Index columns first, then the requested features, then the token lists
    sentence_dict = {
        "sentence_idx": list(range(len(rows))),
        "part_idx": [position for _, position, _, _ in rows],
        "document_idx": [doc_idx for doc_idx, _, _, _ in rows],
        "document_id": [doc_id for _, _, doc_id, _ in rows],
    }
    for key in sentence_level_feature_keys:
        sentence_dict[key] = [sentence[key] for _, _, _, sentence in rows]
    sentence_dict["sentences"] = [sentence["words"] for _, _, _, sentence in rows]

    # Persist both representations of the same table
    print("Saving datasets")
    Dataset.from_dict(sentence_dict).save_to_disk(hf_output_path)
    PM.save_dataset("sentence_features", PM.fn_base_sent_features, data=sentence_dict)

    print("Sentence level features processed!")

# Build and save the sentence-level table, then reload the saved
# "sentence_features" dataset through PM and display it as the cell output.
process_sentence_features()
sentence_features = PM.load_dataset("sentence_features")
sentence_features



Running process_sentence_features


Found cached dataset conll2012_ontonotesv5 (/home/ben/.cache/huggingface/datasets/conll2012_ontonotesv5/english_v12/1.0.0/c541e760a5983b07e403e77ccf1f10864a6ae3e3dc0b994112eff9f217198c65)


  0%|          | 0/3 [00:00<?, ?it/s]

Saving datasets


Saving the dataset (0/1 shards):   0%|          | 0/143709 [00:00<?, ? examples/s]

Sentence level features processed!


Unnamed: 0,sentence_idx,part_idx,document_idx,document_id,part_id,parse_tree,speaker,sentences
0,0,0,0,bc/cctv/00/cctv_0001,0,(TOP(SBARQ(WHNP(WHNP (WP What) (NN kind) )(PP...,Speaker#1,"['What', 'kind', 'of', 'memory', '?']"
1,1,1,0,bc/cctv/00/cctv_0001,0,(TOP(S(NP (PRP We) )(ADVP (RB respectfully) )(...,Speaker#1,"['We', 'respectfully', 'invite', 'you', 'to', ..."
2,2,2,0,bc/cctv/00/cctv_0001,0,(TOP(NP(NP(NP (NNP WW) (NNP II) (NNPS Landma...,Speaker#1,"['WW', 'II', 'Landmarks', 'on', 'the', 'Great'..."
3,3,3,0,bc/cctv/00/cctv_0001,0,(TOP(SINV(VP (VBG Standing) (S(ADJP (JJ tall) ...,Speaker#1,"['Standing', 'tall', 'on', 'Taihang', 'Mountai..."
4,4,4,0,bc/cctv/00/cctv_0001,0,(TOP(S(NP (PRP It) )(VP (VBZ is) (VP (VBN comp...,Speaker#1,"['It', 'is', 'composed', 'of', 'a', 'primary',..."
...,...,...,...,...,...,...,...,...
143704,143704,0,13104,wb/sel/97/sel_9789,-1,,,"['The', 'rain', 'slacked', 'down', '.']"
143705,143705,0,13105,wb/sel/97/sel_9799,-1,,,"['Encouraging', 'but', 'I', 'wo', ""n't"", 'let'..."
143706,143706,0,13106,wb/sel/98/sel_9809,-1,,,"['It', ""'s"", 'not', 'the', 'judge', ""'s"", 'job..."
143707,143707,0,13107,wb/sel/98/sel_9829,-1,,,"['I', 'wo', ""n't"", 'interject', 'personal', 't..."


## Word-level features

In [3]:

def process_word_features(remove_keys=sentence_level_feature_keys + non_feature_keys):
    """
    Flatten the OntoNotes corpus into one row per word and persist it with index columns.

    Two datasets are written under the "word_features" name: the base word-level
    columns (``PM.fn_base_word_features``) and the three index columns
    ``word_order`` / ``word_idx`` / ``sentence_idx`` (``PM.fn_word_indexes``).
    The index columns are derived from the sentence table that
    ``process_sentence_features`` saved to disk, so that function must have run first.

    Args:
        remove_keys (list): Sentence-record columns dropped before the remaining
            columns are flattened to word level.
    """
    print("Running process_word_features")

    # Location of the sentence table written by process_sentence_features
    sent_features_path = PM.directories["hf_files"] / PM.fn_base_sent_features.replace(".csv", ".hf")

    # Load every split, merge them, and drop the document id column
    splits = load_dataset('conll2012_ontonotesv5', 'english_v12', keep_in_memory=True)
    documents = concatenate_datasets([splits[split_name] for split_name in list(splits)])
    documents = documents.remove_columns('document_id')

    # Collect the sentence records of every document, one document at a time
    sentence_records = []
    for row in tqdm(range(documents.num_rows)):
        single_document = documents.select([row])
        sentence_records.extend(itertools.chain.from_iterable(single_document["sentences"]))

    # Keep only the word-level columns and concatenate them across sentences
    per_sentence = Dataset.from_list(sentence_records).remove_columns(remove_keys)
    word_level_dict = {}
    for column in per_sentence.column_names:
        word_level_dict[column] = list(itertools.chain.from_iterable(per_sentence[column]))
    PM.save_dataset("word_features", PM.fn_base_word_features, data=word_level_dict)

    # Index columns: position inside the sentence, owning sentence, running word number
    print("Creating word indexes")
    sentences = load_from_disk(sent_features_path, keep_in_memory=True)["sentences"]
    sentence_lengths = [len(tokens) for tokens in sentences]

    word_order = []
    sentence_idx = []
    for sent_number, length in enumerate(sentence_lengths):
        word_order.extend(range(length))
        sentence_idx.extend([sent_number] * length)

    word_level_dict["word_order"] = word_order
    word_level_dict["sentence_idx"] = sentence_idx
    # The running index is simply 0..N-1 over all words
    word_level_dict["word_idx"] = list(range(len(word_order)))

    # Persist the three index columns on their own
    word_indexes = {
        "word_order": word_level_dict["word_order"],
        "word_idx": word_level_dict["word_idx"],
        "sentence_idx": word_level_dict["sentence_idx"],
    }
    PM.save_dataset("word_features", PM.fn_word_indexes, data=word_indexes)

    print("Word level features processed!")

# Build and save the word-level table and its index columns, then reload the
# "word_features" dataset through PM and display it as the cell output.
process_word_features()
word_features = PM.load_dataset("word_features")
word_features




Running process_word_features


Found cached dataset conll2012_ontonotesv5 (/home/ben/.cache/huggingface/datasets/conll2012_ontonotesv5/english_v12/1.0.0/c541e760a5983b07e403e77ccf1f10864a6ae3e3dc0b994112eff9f217198c65)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/13109 [00:00<?, ?it/s]

Creating word indexes
Word level features processed!


Unnamed: 0,word_order,word_idx,sentence_idx,words,pos_tags,predicate_lemmas,predicate_framenet_ids,word_senses,named_entities,pos_names,...,POS_7,POS_51_id,POS_12_id,POS_7_id,is_in_POS_6,function,tree_depth,unigram_probs,bigram_probs,trigram_probs
0,0,0,0,What,48,,,,0,WP,...,Noun,48,2,0,True,0,5,7.784044,14.373151,14.373151
1,1,1,0,kind,25,,,,0,NN,...,Noun,25,1,0,True,0,5,8.019463,11.305098,14.635171
2,2,2,0,of,18,,,,0,IN,...,Adposition,18,5,2,True,1,5,3.770704,8.346078,11.567118
3,3,3,0,memory,25,memory,,1.0,0,NN,...,Noun,25,1,0,True,0,6,10.404441,12.986857,14.635171
4,4,4,0,?,8,,,,0,.,...,X,8,11,6,False,2,3,6.262053,14.373151,14.635171
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2735679,19,2735679,143708,San,0,,,,0,XX,...,X,0,10,6,False,0,0,8.878384,13.680004,14.635171
2735680,20,2735680,143708,Luis,0,,,,0,XX,...,X,0,10,6,False,0,0,12.013879,14.373151,14.635171
2735681,21,2735681,143708,Obispo,0,,,,0,XX,...,X,0,10,6,False,0,0,14.153945,14.373151,14.635171
2735682,22,2735682,143708,journey,0,,,,0,XX,...,X,0,10,6,False,0,0,10.975891,14.373151,14.635171


In [4]:
def process_pos_features():
    """
    Derive part-of-speech columns at several granularities and save them.

    The integer ``pos_tags`` column of the "word_features" dataset is decoded to tag
    names via ``PM.POS_51_all_tags`` (position in that list = tag id). The names are
    then mapped to the universal tagset with nltk's ``map_tag('en-ptb', 'universal', ...)``
    and further collapsed through ``PM.POS_12_to_POS_7``. Name and id columns for each
    granularity, plus an ``is_in_POS_6`` flag (False exactly where the 7-way tag is
    "X"), are written with ``PM.save_dataset`` under ``PM.fn_pos_features``.

    Raises:
        KeyError: If a tag id or tag name is missing from one of the PM lookup tables.
    """
    word_features = PM.load_dataset("word_features")

    # Tag id -> tag name; ids are positions in PM.POS_51_all_tags
    id_to_tag = dict(enumerate(PM.POS_51_all_tags))

    # Decode the ids, then coarsen step by step: 51 tags -> 12 universal tags -> 7 classes
    POS_51 = [id_to_tag[tag_id] for tag_id in word_features["pos_tags"]]
    POS_12 = [map_tag('en-ptb', 'universal', tag) for tag in POS_51]
    POS_7 = [PM.POS_12_to_POS_7[tag] for tag in POS_12]

    # "pos_names" and "POS_51" intentionally hold the same list: both column names
    # are part of the saved schema.
    POS_dict = {
        "pos_names": POS_51,
        "POS_51": POS_51,
        "POS_12": POS_12,
        "POS_7": POS_7,
        "POS_51_id": [PM.POS_51_tag_to_id[tag] for tag in POS_51],
        "POS_12_id": [PM.POS_12_tag_to_id[tag] for tag in POS_12],
        "POS_7_id": [PM.POS_7_tag_to_id[tag] for tag in POS_7],
        "is_in_POS_6": [tag != "X" for tag in POS_7],
    }

    # Every column must describe the same number of words
    assert len({len(column) for column in POS_dict.values()}) == 1

    PM.save_dataset("word_features", PM.fn_pos_features, data=POS_dict)
# Compute and save the POS columns, then reload the "word_features" dataset
# through PM and display it as the cell output.
process_pos_features()
word_features = PM.load_dataset("word_features")
word_features


Unnamed: 0,word_order,word_idx,sentence_idx,words,pos_tags,predicate_lemmas,predicate_framenet_ids,word_senses,named_entities,pos_names,...,POS_7,POS_51_id,POS_12_id,POS_7_id,is_in_POS_6,function,tree_depth,unigram_probs,bigram_probs,trigram_probs
0,0,0,0,What,48,,,,0,WP,...,Noun,48,2,0,True,0,5,7.784044,14.373151,14.373151
1,1,1,0,kind,25,,,,0,NN,...,Noun,25,1,0,True,0,5,8.019463,11.305098,14.635171
2,2,2,0,of,18,,,,0,IN,...,Adposition,18,5,2,True,1,5,3.770704,8.346078,11.567118
3,3,3,0,memory,25,memory,,1.0,0,NN,...,Noun,25,1,0,True,0,6,10.404441,12.986857,14.635171
4,4,4,0,?,8,,,,0,.,...,X,8,11,6,False,2,3,6.262053,14.373151,14.635171
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2735679,19,2735679,143708,San,0,,,,0,XX,...,X,0,10,6,False,0,0,8.878384,13.680004,14.635171
2735680,20,2735680,143708,Luis,0,,,,0,XX,...,X,0,10,6,False,0,0,12.013879,14.373151,14.635171
2735681,21,2735681,143708,Obispo,0,,,,0,XX,...,X,0,10,6,False,0,0,14.153945,14.373151,14.635171
2735682,22,2735682,143708,journey,0,,,,0,XX,...,X,0,10,6,False,0,0,10.975891,14.373151,14.635171


## N-gram frequencies (`add_nltk_ngram_frequencies`)

In [5]:
def find_ngram_word_frequencies(words):
    """
    Score every word by the unigram, bigram, and trigram that ends at it.

    The whole word list is treated as ONE token stream, wrapped in a single
    ``<s>`` ... ``</s>`` pair (no per-sentence boundaries are inserted here). For each
    order n, an nltk ``FreqDist`` of all n-grams in that stream is smoothed with
    ``LaplaceProbDist`` (default bin count), and each n-gram is scored as the
    NEGATIVE natural log of its smoothed probability, so larger values mean rarer
    n-grams.

    Alignment of the returned lists with ``words`` (position k = word k):
        * unigram_probs[k]: score of ``(words[k],)``
        * bigram_probs[k]:  score of ``(previous token, words[k])``; the previous
          token of the first word is ``<s>``
        * trigram_probs[k]: score of the trigram ending at ``words[k]``; the first
          word has no full trigram, so its bigram score is used instead

    Args:
        words (list): Sequence of words (any iterable is accepted and copied to a list).

    Returns:
        dict: Keys "unigram_probs", "bigram_probs", "trigram_probs", each a list of
        floats with one entry per input word. All three lists are empty when
        ``words`` is empty.
    """

    def train_ngram_model(tokens, n):
        """
        Fit a Laplace-smoothed n-gram distribution over ``tokens``.

        Args:
            tokens (list): Token stream, including the boundary symbols.
            n (int): The order of the n-gram.

        Returns:
            dict: Maps every distinct n-gram (a tuple) to ``-log(probability)``.
        """
        counts = FreqDist(ngrams(tokens, n))
        smoothed = LaplaceProbDist(counts)
        # Iterate over distinct n-gram types rather than over every occurrence:
        # the score of a repeated n-gram only needs to be computed once.
        return {gram: -math.log(smoothed.prob(gram)) for gram in counts}

    words = list(words)
    if not words:
        # Nothing to score; with no trigrams the smoothed distribution would have no bins.
        return {"unigram_probs": [], "bigram_probs": [], "trigram_probs": []}

    # Add start and end tokens to the word list
    tokens = ['<s>'] + words + ['</s>']

    # Train models for unigrams, bigrams, and trigrams
    unigram_scores = train_ngram_model(tokens, 1)
    bigram_scores = train_ngram_model(tokens, 2)
    trigram_scores = train_ngram_model(tokens, 3)

    # N-grams that do not involve the end symbol, in stream order.
    # With N words: N + 1 unigrams, N bigrams, N - 1 trigrams.
    unigrams = [(token,) for token in tokens[:-1]]
    bigrams = list(zip(tokens[:-2], tokens[1:-1]))
    trigrams = list(zip(tokens[:-3], tokens[1:-2], tokens[2:-1]))

    # Look each n-gram up in the model of its own order. A missing key raises
    # KeyError instead of silently producing a shorter, misaligned list.
    unigram_probs = [unigram_scores[gram] for gram in unigrams[1:]]  # skip the '<s>' unigram
    bigram_probs = [bigram_scores[gram] for gram in bigrams]
    # The first word has no full trigram: start with its bigram score
    trigram_probs = bigram_probs[:1] + [trigram_scores[gram] for gram in trigrams]

    print("First 3 unigrams:", unigrams[:3])
    print("First 3 bigrams:", bigrams[:3])
    print("First 3 trigrams:", trigrams[:3])

    print("Length of all_words:     ", len(words))
    print("Length of unigram_probs: ", len(unigram_probs))
    print("Length of bigram_probs:  ", len(bigram_probs))
    print("Length of trigram_probs: ", len(trigram_probs))

    # Ensure all lists have matching lengths
    assert len(words) == len(unigram_probs) == len(bigram_probs) == len(trigram_probs)

    return {
        "unigram_probs": unigram_probs,
        "bigram_probs": bigram_probs,
        "trigram_probs": trigram_probs
    }

def add_nltk_ngram_frequencies():
    """
    Score the saved word list with the n-gram models and persist the scores.

    Reads the base word-level file (``PM.fn_base_word_features``) of the
    "word_features" dataset, passes its ``words`` column to
    ``find_ngram_word_frequencies``, and saves the returned columns under
    ``PM.fn_word_frequencies``.
    """
    print("Calculating n-gram frequencies")

    # Only the base word file is needed here: it holds the "words" column
    base_word_features = PM.load_dataset("word_features", PM.fn_base_word_features)
    corpus_words = list(base_word_features["words"])

    # One unigram / bigram / trigram score per word
    ngram_scores = find_ngram_word_frequencies(corpus_words)

    PM.save_dataset("word_features", PM.fn_word_frequencies, data=ngram_scores)

# Compute the n-gram scores for every word and save them with the word features.
add_nltk_ngram_frequencies()



Calculating n-gram frequencies


  data = pd.read_csv(file_path)


First 3 unigrams: [('<s>',), ('What',), ('kind',)]
First 3 bigrams: [('<s>', 'What'), ('What', 'kind'), ('kind', 'of')]
First 3 trigrams: [('<s>', 'What', 'kind'), ('What', 'kind', 'of'), ('kind', 'of', 'memory')]
Length of all_words:      2735684
Length of unigram_probs:  2735684
Length of bigram_probs:   2735684
Length of trigram_probs:  2735684


In [6]:
def add_function_vs_content():
    """
    Label every word as content (0), function (1), or punctuation (2) and save the labels.

    A word is labelled 1 when it appears in nltk's English stop-word list or is one
    of the clitic tokens ``'s 't 've 'll 're 'd 'm n't``; otherwise 0. Rows whose
    ``POS_12`` tag is "." are then overwritten with 2, regardless of the first step.
    The resulting column is saved under ``PM.fn_function``.
    """
    from nltk.corpus import stopwords

    # Load word features
    word_features = PM.load_dataset("word_features")

    # A set gives constant-time membership tests; the word table is scanned once
    # per row, so testing against a list would rescan the whole list for every word.
    function_words = set(stopwords.words('english'))
    function_words.update(["'s", "'t", "'ve", "'ll", "'re", "'d", "'m", "n't"])

    # Mark words as function or content
    # NOTE(review): the test is case-sensitive, so a capitalised token such as
    # "What" is labelled 0 (see the saved output) -- confirm this is intended.
    word_features["function"] = [1 if word in function_words else 0 for word in word_features["words"]]

    # Punctuation (universal tag ".") gets its own label and overrides the above
    word_features.loc[word_features["POS_12"] == ".", "function"] = 2

    # Save the function/content word labels
    function = {"function": list(word_features["function"])}
    PM.save_dataset("word_features", PM.fn_function, data=function)
    
# Compute and save the function/content/punctuation label for every word.
add_function_vs_content()


In [7]:
def add_tree_depth(max_sentence_idx=5000):
    """
    Compute the parse-tree depth of every word and save it as the ``tree_depth`` column.

    For each processed sentence the parse-tree string is read with
    ``nltk.tree.Tree.fromstring`` and the depth of a word is the length of its leaf
    tree position. Sentences without a string parse tree contribute zeros, one per
    word. Words of sentences that are not processed (index above
    ``max_sentence_idx``) are also filled with zeros, so the saved column always has
    one value per row of the word table.

    Args:
        max_sentence_idx (int | None): Index of the last sentence to process
            (inclusive). The default of 5000 reproduces the previous hard-coded
            cut-off, i.e. sentences 0..5000. Pass ``None`` to process every sentence.

    Raises:
        AssertionError: If a parse tree has a different number of leaves than the
            sentence has words, or if the final column length does not match the
            word table.
    """
    print("Calculating tree depth")

    from nltk.tree import Tree

    # Load sentence and word features
    sentence_features = PM.load_dataset("sentence_features")
    word_features = PM.load_dataset("word_features")

    # Number of words per sentence, computed once instead of masking the whole
    # word table for every sentence.
    words_per_sentence = word_features["sentence_idx"].value_counts().to_dict()

    parse_trees = list(sentence_features["parse_tree"])
    if max_sentence_idx is not None:
        # Applying the limit up front means it holds even when the sentence at the
        # cut-off has no parse tree.
        parse_trees = parse_trees[:max(max_sentence_idx + 1, 0)]

    tree_depth = []
    for i, tree_str in tqdm(enumerate(parse_trees), total=len(parse_trees)):
        n_words = words_per_sentence.get(i, 0)

        if not isinstance(tree_str, str):
            # If no valid parse tree, assign depth of 0
            tree_depth.extend([0] * n_words)
            continue

        tree = Tree.fromstring(tree_str)

        # Depth of a word = length of the path from the root to its leaf
        sentence_depths = [len(tree.leaf_treeposition(j)) for j in range(len(tree.leaves()))]
        assert len(sentence_depths) == n_words, f"Error in tree depth calculation for sentence {i}"
        tree_depth.extend(sentence_depths)

    # If tree depth list is shorter than the word features, fill the remaining with zeros
    diff = len(word_features) - len(tree_depth)
    tree_depth.extend([0] * diff)

    assert len(word_features) == len(tree_depth), "Tree depth list size mismatch"

    # Save tree depth data
    tree_depth = {"tree_depth": tree_depth}
    PM.save_dataset("word_features", PM.fn_tree_depth, data=tree_depth)

# Compute and save the per-word parse-tree depth.
add_tree_depth()

Calculating tree depth


0it [00:00, ?it/s]

## Preprocessing for GPT input (`preprocess_for_gpt_input`)

In [8]:
# ================================================
# Preprocessing for GPT Input
# ================================================

def preprocess_for_gpt_input(protected_tokens=["'s", "'t", "'ve", "'ll", "'re", "'d", "'m"]):
    """
    Build the token/spacing table used as GPT input and save it as the "gpt_input" dataset.

    Loads the "word_features" dataset, adds a ``model_words`` column (surface form of
    each token) and an ``append_space`` column (False where the token must attach
    directly to the preceding text, True otherwise), and saves the columns
    ``model_words, append_space, sentence_idx, word_idx, word_order``.

    Args:
        protected_tokens (list): List of tokens that need to be protected from modifications.
            NOTE(review): this argument is accepted and forwarded but is not read by
            the current implementation.
    """

    def add_model_words_and_append_spaces(df, protected_tokens=protected_tokens):
        """
        Add the ``model_words`` and ``append_space`` columns to ``df`` in place.

        Args:
            df (pd.DataFrame): Word features; must contain the columns ``words``,
                ``pos_tags``, ``pos_names`` and ``word_order``.
            protected_tokens (list): Currently unused (see the enclosing function).
        """
        n_rows = len(df)

        # Split "n't" as "<word>n" + "'t": the token becomes "'t" and the "n" moves
        # onto the preceding word. Row positions are used (not index labels), and the
        # very first row has no predecessor, so nothing is shifted there.
        model_words = list(df['words'])
        for pos in np.flatnonzero((df['words'] == "n't").to_numpy()):
            model_words[pos] = "'t"
            if pos > 0:
                model_words[pos - 1] = model_words[pos - 1] + "n"
        df['model_words'] = model_words

        # Turn the bracket escape tokens back into the literal bracket characters.
        # regex=False: the escapes are plain substrings, and the default of `regex`
        # differs between pandas versions.
        CONVERT_PARENTHESES = [
            ("(", "-LRB-"), (")", "-RRB-"),
            ("[", "-LSB-"), ("]", "-RSB-"),
            ("{", "-LCB-"), ("}", "-RCB-")
        ]
        for literal, escaped in CONVERT_PARENTHESES:
            df['model_words'] = df['model_words'].str.replace(escaped, literal, regex=False)

        def tagged(tag_id, tag_name):
            """Boolean row mask: POS id and POS name both match."""
            return ((df["pos_tags"] == tag_id) & (df["pos_names"] == tag_name)).to_numpy()

        no_space_before = np.zeros(n_rows, dtype=bool)
        no_space_after = np.zeros(n_rows, dtype=bool)

        # Closing quotes, commas, closing brackets, sentence punctuation, colons and
        # possessive markers attach to the token on their left
        for tag_id, tag_name in [(3, "''"), (5, ","), (7, "-RRB-"), (8, "."), (9, ":"), (30, "POS")]:
            no_space_before |= tagged(tag_id, tag_name)

        # Opening quotes and opening brackets attach to the token on their right
        for tag_id, tag_name in [(1, "``"), (6, "-LRB-")]:
            no_space_after |= tagged(tag_id, tag_name)

        # Hyphens attach on both sides
        for tag_id, tag_name in [(17, "HYPH")]:
            hyphen_rows = tagged(tag_id, tag_name)
            no_space_before |= hyphen_rows
            no_space_after |= hyphen_rows

        # No space before the first word of each sentence. `word_order` is the
        # position inside the sentence; `word_idx` is the running index over the
        # whole corpus, so testing it against 0 would only match the first row.
        no_space_before |= (df["word_order"] == 0).to_numpy()

        # "No space after row k" means "no space before row k + 1"; the last row has
        # no successor, so it is simply dropped by the slice.
        no_space_before[1:] |= no_space_after[:-1]

        df["append_space"] = ~no_space_before

    # Load word features dataset
    df = PM.load_dataset("word_features")

    # Process words and spacing
    add_model_words_and_append_spaces(df)

    # Save processed GPT input data
    gpt_input = df[["model_words", "append_space", "sentence_idx", "word_idx", "word_order"]]
    PM.save_dataset("gpt_input", data=gpt_input)
    
# Build and save the GPT input table, then reload the "gpt_input" dataset
# through PM and display it as the cell output.
preprocess_for_gpt_input()
gpt_input = PM.load_dataset("gpt_input")
gpt_input


Unnamed: 0,model_words,append_space,sentence_idx,word_idx,word_order
0,What,False,0,0,0
1,kind,True,0,1,1
2,of,True,0,2,2
3,memory,True,0,3,3
4,?,False,0,4,4
...,...,...,...,...,...
2735679,San,True,143708,2735679,19
2735680,Luis,True,143708,2735680,20
2735681,Obispo,True,143708,2735681,21
2735682,journey,True,143708,2735682,22
