# Next Type - A Mobile Typing Assistant
[CS 505 - NLP] [Final Project]
Completed by Muhammad Aseef Imran
---

### Problem Statement

Compared to typing on a keyboard, typing on our phones considerably slower. Luckily most phones come built in with a feature to predict your next word to make up for this. However, after much experimentation, it seems most of these prediction algorithms seem to use simple N-grams with a window of around 3-4 words of left context. Many times, predicting the next word based on simple N-gram based probabilities work "fine" but often, this strategy produces poor results due to ignoring the context of the sentence. As NLP technology leaps forward exponentially we can do much better than simplistic N-grams. For this reason **Next Type** aims to bring the next generation of typing experiance to its users in order to raise productivity and typing speed for users leaving more time for the important things in life!

<div style="display: flex; justify-content: space-between;">
    <img src="assets/bad-phone-example-1.jpg" alt="N-gram example 1" style="width: 33%; padding-right: 10px;">
    <img src="assets/bad-phone-example-2.jpg" alt="N-gram example 2" style="width: 33%; padding-right: 10px;">
    <img src="assets/bad-phone-example-3.jpg" alt="N-gram example 3" style="width: 33%;">
</div>

Presented above are three examples where unfortunately, the standard N-gram model employed on "modern" devices fails miserably.
___

### Project Goals

N-gram models perform very poorly with shorter context window, and any large context windows require an immense amount of data to produce reasonable results. Moreover, at least in this case, our N-gram models do not consider the right context when suggesting words. However, despite having their weaknesses, the N-gram approach also has some nice benefits. Particularly, the nature of the N-gram model allows it to be easily updated with new data and adapt to the typing habits of its users with little processing power.

Keeping these things in mind, we can summarize our goals for the model we have set out to develop as:
1. The model should consider both the left and right context before suggesting a "natural" word that fits the context.
2. The model should be able to adapt to or learn from the user's word choices in various contexts.
3. The model should be able to learn the user's word choice as—outlined above—quick enough to be useful.
4. The model should be reasonably sized allowing it to be run and be updated on most modern-phone hardware with in reasonable time.

Additional things that may be considered (depending on time) in our model may be that:

5. The model should validate the input data for grammatical correctness to avoid learning incorrect patterns.
6. The model should (optionally) avoid suggesting profane language.

### Key Questions & Challenges

Let us better define the problem and explain exactly what we are seeking to achieve and possible challanges.

1. People evolve and change. So can I get a better accuracy by consider ALL known history? Or only "recent" history? Should the more recent history be weighted more"? If so, should history be judged more by time or the volume typed? If I'm wrong about this, then more data = better. OR the time people change in is just longer.
2. My problem is my data is ever evolving. So how do I prevent the model from overfitting? Overfitting will make newer data harder to generalize. (i.e. How do I prevent it from forgetting the stuff it knew in the base model?). Also suppose if I don't want my model to consider the "old" history, how can I make the model "forgot" the old stuff (some kind of vanishing gradient is my best choice)?

Shortcomings of this research: data may be biased. Only considering reddit users. Particularly those who type somewhat frequently on reddit.

### Obtaining the Fine-Tuning Data

It is important for my task that the fine-tunning data only come from a single person since my goal is to train the model to adapt to the typing behavior of a specific user. In order to achieve this we can use scraped messages from reddit for specific users and examine how the model adapts to their specific word-choice and typing habits. Particularly, we will be using data from a pre-scrapped dataset [Reddit comments/submissions 2005-06](https://academictorrents.com/details/89d24ff9d5fbc1efcdaf9d7689d72b7548f699fc). Further, we also want to make sure that sample data from any user we use:
* Provides a reasonably large enough dataset to train on
* Posts regularly (as opposed to posting a lot occasionally)

With that defined, our focus will be between and including messages posted between [1/2011 - 6/2012]. Why this particular time range? No particular reason besides that data with in this time range was reasonably enough sized to be processed quickly yet still leave us with enough data to work with.

We can then process the reddit dump creating a dictionary consisting of reddit users and the messages they sent. We further filter this data as follows:
* "Deleted" users aren't included
* First, we remove all authors with less than 546 (i.e. they must average more than 1 post a day) - although future filteration steps would've already ensured this requirement, we start of with this since this is a quick and dirty elimination step allowing for quicker processing in the next steps (which are a bit more complicated implementation wise)
* Second, we filter to only authors that made at least 18 post per month in the time range without missing any month.
* Third, we filter to only authors that made at least 2 post every week without missing a week.
* Fourth, we filter down the remaining users to those that posted on at least 80% of the days.
* In our final step, we post process the post messages by running it through a sequence-to-sequence model already developed by someone correcting silly grammatical errors. Otherwise, we may end up having messages in our data set that contain non-existent vocab.
TODO: grammer corrrection may be needed! https://huggingface.co/pszemraj/grammar-synthesis-small/tree/main also trained on t5
or https://huggingface.co/flexudy/t5-small-wav2vec2-grammar-fixer

### Establishing a Baseline

As previously mentioned, we want to develop a model the beats the naive n-gram models (with a window of 3) used by most keyboard apps. Therefore, we will define our baseline as 3-gram model trained on the entire reddit corpus. This model uses simple probabilities based on the frequency of the n-gram in the corpus to make predictions.

Although modern keyboard apps adapt and update their n-gram model by using data from its user, as a 'baseline' model, we neglect this detail in the interest of time.

### Definition a "Successful" Prediction

We define a successful prediction as a prediction that meets the following criteria:
```The predicted next word is one of the top 3 predictions or a "close enough" synonym of one of the top 3 words.```

#### Justification
The goal of our "predict the next word" is to allow our users to type faster by suggesting their next word. Generally, most mobile devices have enough screen space to suggest at least 3 words for the "next word". Moreover, users may not care enough to type to exact word they were thinking of if the suggested word has a "close enough" meaning to the word they intended to write.

#### Defining a "close enough" synonym
We will base our definition of a "close enough" synonym on the assumption that "most words have a synonym". Then, we define a synonym to be close enough as follows:
1. For each word in the word embedding space, find the word closest to that word using some distance metric.
2. Calculate the mean and standard deviation of the distances between the closest words.
3. Now we define a word to be "close enough" as being with in (the mean minus 1 standard deviation) unit distance from the predicted word.

### Evaluation Stategy

In light of our above outlined goals, we have two major evaluation goals:


1. How well does the model predict the user's next token after having seen x tokens of examples from the user. In other words, not only how well the model predicts the user's next token but also how fast the model improves its prediction as a function of the data it has already seen?

> We can evaluate "how well" the model predicts the user's next token by measuring the loss between what the user actually types vs what the model suggests. Then, we can further measure this loss as function of the number of tokens of examples the model has seen during its Fine-Tuning. For example, how does the loss change after the model has seen 1000 tokens of examples from the user?

2. How much computation is needed is to both run the model and update the model on new data?

> Measuring how long various parts of the model such prediction and training takes is trivial. (We can simply calculate the time between the target area of code). We may then analyze the run-time in context of the hardware the code is run on and comparing this information with current state of computational power of modern mobile devices. This information can be used to make an informed decision on the sequence lengths to input to the model to ensure our model can suggest new words to users in real-time on standard mobile hardware.

### Project Plan and Exploring Potential Solutions

Once again, our goal is to accuractely and effectively predict the user's next token in real-time while adapting to the user's behavior and writing style over time.

In order to reasonably meet these goals, we will fine-tune one or more combination of existing models such as T5, Bert, and GPT2, and/or their "Distilled" counterparts. We will use the Hugging Face transformers library simply due to its vast popularity and easy of use. I intend to use the SCC for rapid prototyping and experimentation as I already have significant experiance using the SCC at this stage.

Finally I should note that in my initial research, I have identified potential pitfalls with each of these models and their strengths and weaknesses for my task. However, further experimentation will be needed to make a final decision on which model (or model combinations) to use. Detailed experimentation with each model will be required to evaluate its pros/cons.

#### Bert
Having been trained on a mask-fill task, Bert naturally lends itself to the kind of project I am trying to do. Being a relatively small model, and still quite versatile for the task, Bert may be a great choice. However, one downside to Bert is that Bert seems to perform poorly when attempting to Mask-Fill multiple words in the middle of a sentence. (See the bottom of this notebook for a demonstration).

#### GPT2
GPT2 was essentially trained on predicting the next word. Indeed, this is the task we want to achieve ourselves. However, in some cases, we may need to predict the middle word (if someone is editing the middle of a sentence they wrote). This is not a task GPT2 was designed for although this may still be possible due to the surprising generality of the model. Further research and experimentation will be needed.

#### T5
T5 is an extremely general purpose model than can adapt to many NLP tasks. Unlike bert Being a substantially larger model than both GPT2 and Bert, T5 is slower to retrain. Yet at the same time, T5 seems to do a much better job mask-filling between sentences. Yet, in a realistic scenario how often does one write in the middle of the sentence? Is the increased computing cost really worth it? These are the questions I hope to answer with the first stages of my research.

#### The "Distilled Version"
Models like Distilled-Bert and Distilled-GPT2 are indeed smaller. However, one *major* pitfall to this may be that the model may struggle to generalize to new tasks and during retraining. Retraining forms an essential component to this project and depending on the severity of this effected, the "Distilled" models may prove ineffective. Further experimentation is needed to make any conclusions, however.

### Limitations
* Compared to the N-gram approach, this new model cannot easily learn new words?
* You may talk different with friends vs family vs boss. This training and results was done specifically for reddit. It may be the case that the model will not generalize as well to a broader domain in an actual key board app. (Still probability at least better than the ngram stuff right?)

### Resources and Publications

For this project, I am considering/planning to use the following resources for research:

1. Tunstall, L., Werra, L., Wolf, T., &amp; Géron, A. (2022). Natural language processing with transformers: Building language applications with hugging face. O’Reilly.

> This book has been repeatedly recommended by both Professor Snyder and other students in CS505. Upon a coarse inspection, I expect to particular find the sections on "Fine Tuning" various models helpful (since this is very much a Fine-Tuning project). The book also contains extensive details about various transformer architectures which will inevitably prove useful.

2. "Fine-Tune a Pretrained Model." Hugging Face, https://huggingface.co/docs/transformers/training. Accessed 1 Dec. 2023.

> This blog post contains examples Fine-Tuning bert using three different methods along with model evaluation. Although unlike source (#1), it goes into less details on the theory, the sample code is more rich, and easier to work with.

3. Raffel, Colin, et al. “Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer.” Arxiv, 23 Oct. 2019, Accessed 1 Dec. 2023.

> This is the original paper that introduced T5 to the world. As covered in lecture, similar to Bert, T5 was (partially) trained on Mask-Fill task where random spans of texts were removed which naturally make it a great candidate for what we want to achieve here. However, this Mask-Fill process was only part of the process. T5 is a very versitile (and large) model, and has huge applications. There is no better place to learn what it is, how it works, and how to use it than the original paper!

4. https://212digital.medium.com/fine-tuning-the-gpt-2-large-language-model-unlocking-its-full-potential-66e3a082ab9c

> This great blog post contains all the coded needed to get started on fine-tuning GPT2 and the warnings about the various pitfalls.

### Comparison and Exploratory Between Various Models

We begin with an exploratory between different promising transformer models that have been historically very successful with a variety of tasks. Namely, we will do a comparisons between Bert, GPT2, and T5 on various tasks.

1. We compare between the following 5 sentences to see how well the models do. These results are just to get a basic idea of each model's capabilities. We will "eyeball this result". Note that the "|" token represents the current "cursor" location.

    a. `After forcefully breaking into the bank, they|`

    b. `Every |, my family and I visit Hawaii.`

    c. `In my family, there is my |`

    d. `My favorite | is apple.`

    e. `It has been 2 months since I graduated. However, unfortunately I still haven't found a |. At this rate I won't be able to pay rent!`

2. After that, we will test each model to see how well the model does in predicting this "next" word token. This task is the most important task for our proposed application to do well.

3. Next, we will see how well each model does at predicting a randomly removed "middle" token.

4. Then, we do a similar challenge by now comparing how each model does at predicting when multiple "middle" tokens have been deleted.

5. Finally, we will choose the most promising model to develop a fine-tuning method for this continuous learning project.

#### Setup: Imports

In [1]:
# all imports here
from typing import Union
import os
import pickle
import gzip
import torch
from tqdm import tqdm
from typing import Dict, Tuple, List, Set, Any, Union
import statistics
import random
import re
from nltk.tokenize import sent_tokenize
from nltk import word_tokenize
import nltk
import time

from transformers import GPT2LMHeadModel, GPT2Tokenizer, PreTrainedTokenizerBase, PreTrainedModel, GPT2Config
from transformers import pipeline
from transformers import DistilBertForMaskedLM, DistilBertTokenizer, DistilBertConfig, AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     /usr4/cs505ws/aseef/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
test_sentences = [
    "After forcefully breaking into the bank, they |", # predict the next word using context
    "In my family, there is my |", # predict the next word using context
    "Every |, my family and I visit Hawaii.",  # simple mask fill with one missing word
    "My favorite | is apple.", # simple mask fill with one missing word
    "It has been 2 months since I graduated. However, unfortunately I still haven't found a | rate I won't be able to pay rent!", # need multiple words here before sentence makes sense
    "Following the American Civil War, | assassinated.",  # need multiple words here before sentence makes sense
]

In [3]:
data_dir = "/projectnb/cs505ws/projects/NextType/data"

def does_var_exists_gz(var_name: str) -> bool:
    return os.path.isfile(F'{data_dir}/{var_name}.pkl.gz')

def dump_var_gz(var_name: str, obj) -> None:
    os.makedirs(f"{data_dir}", exist_ok=True)
    with gzip.open(F'{data_dir}/{var_name}.pkl.gz', 'wb', compresslevel=1) as file:
        pickle.dump(obj, file)


def load_var_gz(var_name: str) -> Union[None, object]:
    if not does_var_exists_gz(var_name):
        return None

    file_path = F'{data_dir}/{var_name}.pkl.gz'  # Updated file extension
    with gzip.open(file_path, 'rb', compresslevel=1) as file:
        return pickle.load(file)

In [4]:
# Check if GPU is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device={device}")

Using device=cuda


In [5]:
author_to_posts_dict: Dict[str, Tuple[int, str]] = load_var_gz("author_to_lines")

In [6]:
random_user_sample = random.sample(list(author_to_posts_dict.keys()), 25)
random_user_sample

['analogkid01',
 'rednightmare',
 'alle0441',
 'ChingShih',
 'zulubanshee',
 'whoami9',
 'lakerswiz',
 'jared555',
 'Flammy',
 'Peritract',
 'from_the_sidelines',
 'Wyrmshadow',
 'wartornhero',
 'gndn',
 'PalermoJohn',
 'powercow',
 'mariesoleil',
 'browwiw',
 'peewinkle',
 'pillage',
 'cylinderhead',
 'chicofaraby',
 'Wiebelhaus',
 'Naly_D',
 'Maxion']

In [7]:
# sample up to 50 posts from each user
random_post_samples = []
for rand_user in random_user_sample:
    all_posts = author_to_posts_dict[rand_user]
    sampled_posts = random.sample(list(all_posts), 50)
    random_post_samples += sampled_posts

#### Data Cleaning

In [8]:
# this was the only small-ish grammar correction model I found.
# had I more time, I would create my own model. But I don't so I will focus
# on my primary task.
# the downsides of this models is that besides correcting spellings, it often alters the
# structure of the sentence which could fundamentally undermine our purpose.
# so question: does benefits of correcting grammar using this outweighs the harms?
# after all, if the model doesnt recgonize a word, it'll just ignore it and wont learn from it!
grammar_corrector = pipeline(
               'text2text-generation',
               'pszemraj/grammar-synthesis-small',
                 device=device
               )

In [9]:
import html
def normalize_text(post_text: str):
    # get rid of new lines
    post_text = re.sub("\n", " ", post_text)
    # remove html characters
    post_text = html.unescape(post_text)
    # Remove bold and italic formatting
    post_text = re.sub(r'(\*\*|__)(.*?)\1|(\*|_)(.*?)\3', r'\2\4', post_text)
    # Remove headers
    post_text = re.sub(r'^#{1,6}\s', '', post_text)
    # Remove hyperlinks
    post_text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'\1', post_text)
    # Remove inline code
    post_text = re.sub(r'`([^`]+)`', r'\1', post_text)
    # Remove block code
    post_text = re.sub(r'```(?:[^`]+|`(?!``))*```', '', post_text)
    # Remove lists (unordered and ordered)
    post_text = re.sub(r'^\s*([\*\-\+]\s|(\d+\.)\s)', '', post_text)
    post_text = re.sub("(\*\*|__)(.*?)\1|(\*|_)(.*?)\3", "", post_text)
    # remove double spaces
    post_text = re.sub(" {2,}", " ", post_text)
    # replace urls
    post_text = re.sub(r"https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)", "{URL}", post_text)
    # replace references to a specific subreddit but just the token "[subreddit]"
    post_text = re.sub(r"(\W)(r/[a-z0-9A-Z_]{2,10})(\W)", r"\1{SUB REDDIT}\3", post_text)
    post_text = re.sub(r"(\W)(/[a-z0-9A-Z_]{2,10})(\W)", r"\1{SUB REDDIT}\3", post_text)
    post_text = post_text.strip()
    return post_text

In [15]:
user_posts_corrected = list(user_posts)
for i in tqdm(range(len(user_posts_corrected))):
    creation, message = user_posts_corrected[i]
    message = normalize_text(message)
    sentences = sent_tokenize(message)
    for j in range(len(sentences)):
        sent = sentences[j]
        updated_message = grammar_corrector(sent)[0]['generated_text']
        sentences[j] = updated_message
    updated_message = ' '.join(sentences)
    user_posts_corrected[i] = (creation, updated_message)

  0%|                                                                                                                                                                                 | 0/3794 [00:00<?, ?it/s]


NameError: name 'normalize_text' is not defined

In [80]:
dump_var_gz(f"{random_key}-corrected-posts", user_posts_corrected)

In [11]:
user_posts_corrected = load_var_gz('The_Jackal-corrected-posts')

In [19]:
counter = 0
index = 0
for old_post, new_post in zip(user_posts, user_posts_corrected):
    normalized_post = normalize_text(old_post[1])
    if old_post[1].strip() != new_post[1].strip():
        print("-=+=--=+=--=+=--=+=--=+=--=+=-")
        print(f'Old Post {index}:', normalized_post)
        print("~~+~~~~+~~~~+~~~~+~~~~+~~~~+~~")
        print(f'New Post {index}:', new_post[1])
        print("-=+=--=+=--=+=--=+=--=+=--=+=-")
        counter += 1
    index += 1
    if counter > 9:
        break

#### T5-Small

In [11]:
# load in the t5 model
T5_path = 't5-small'
t5_config = T5Config.from_pretrained(T5_path)
t5_tokenizer = T5Tokenizer.from_pretrained(T5_path, legacy=False)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_path, config=t5_config).to(device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


##### T5-Small On the Sample Sentences

In [11]:
for sentence in test_sentences:

    input_text = sentence.replace("|", "<extra_id_0>")
    input_ids = t5_tokenizer(input_text, return_tensors="pt").to(device).input_ids

    # Generate predictions for the next token
    num_samples = 5
    with torch.no_grad():
        output = t5_model.generate(
            input_ids,
            num_beams=num_samples,
            min_new_tokens=2,
            max_new_tokens=2,
            num_return_sequences=num_samples,  # Generate multiple suggestions
            return_dict_in_generate=True,
            output_scores=True)

    probabilities = torch.nn.functional.softmax(output.sequences_scores, dim=-1)

    # Decode and print the predicted token
    suggestions = []
    for sample_output, prob in zip(output.sequences, probabilities):
        decoded_output = t5_tokenizer.decode(sample_output, skip_special_tokens=True)
        suggestions += [(decoded_output, prob.item())]
    print('Sentence:', sentence)
    print('Suggestions:', suggestions)
    print("-=+=--=+=--=+=--=+=--=+=-")

Sentence: After forcefully breaking into the bank, they |
Suggestions: [('broke', 0.22146426141262054), ('break', 0.21711385250091553), ('are', 0.21013765037059784), ('were', 0.17779791355133057), ('will', 0.1734863519668579)]
-=+=--=+=--=+=--=+=--=+=-
Sentence: In my family, there is my |
Suggestions: [('family', 0.27247941493988037), ('own', 0.25976160168647766), ('daughter', 0.15796756744384766), ('mother', 0.1555204540491104), ('home', 0.1542709767818451)]
-=+=--=+=--=+=--=+=--=+=-
Sentence: Every |, my family and I visit Hawaii.
Suggestions: [('year', 0.3100126087665558), ('day', 0.2699778378009796), ('month', 0.1532614529132843), ('week', 0.14059576392173767), ('time', 0.12615235149860382)]
-=+=--=+=--=+=--=+=--=+=-
Sentence: My favorite | is apple.
Suggestions: [('apple', 0.469388872385025), ('fruit', 0.16375479102134705), ('thing', 0.15048746764659882), ('pie', 0.11327831447124481), ('recipe', 0.10309050232172012)]
-=+=--=+=--=+=--=+=--=+=-
Sentence: It has been 2 months since 

##### T5-Small on "Predict the Next Token"

In [15]:
t5_correct_guesses = 0
t5_total_guesses = 0
t5_inference_times = []

for creation, post in tqdm(random_post_samples, smoothing=0):
    sent_tokenized = sent_tokenize(normalize_text(post))
    tokenized_words = [word_tokenize(sentence) for sentence in sent_tokenized]
    tokenized_words = [word for s in tokenized_words for word in s]
    for i in range(1, len(tokenized_words) - 1):
        start_time = time.time()
        current_prompt = ' '.join(tokenized_words[:i]) + " <extra_id_0>"
        current_prompt = re.sub(r" ([!.?,;:\"')\]}]{1,9})", r"\1", current_prompt)
        current_prompt = re.sub(r"([\[({]{1,9}) ", r"\1", current_prompt)
        actual_next_word = tokenized_words[i]
        input_ids = t5_tokenizer(current_prompt, return_tensors="pt").to(device).input_ids
        # t5 can only accept up to 512 tokens so if our tensor is bigger than this
        # we trim it before passing into the model
        if input_ids[0].shape[0] > 512:
            input_ids = input_ids[:, -512:]

        with torch.no_grad():
            # generate one word at a time
            num_samples = 7
            output = t5_model.generate(
                input_ids,
                num_beams=num_samples,
                min_new_tokens=2,
                max_new_tokens=2,
                num_return_sequences=num_samples,  # Generate multiple suggestions
                return_dict_in_generate=True,
                output_scores=True
            )

        probabilities = torch.nn.functional.softmax(output.sequences_scores, dim=-1)

        # Decode and print the predicted token
        suggestions = set()
        for sample_output, prob in zip(output.sequences, probabilities):
            if len(suggestions) >= 5:
                break
            decoded_output = t5_tokenizer.decode(sample_output, skip_special_tokens=True)
            if prob.item() < 0.05:
                # avoid bizarre suggestions by simply filtering out low prob
                # terms. We don't HAVE TO suggest exactly 5 words
                break
            if decoded_output.strip() == '':
                continue
            suggestions.add(decoded_output)

        if actual_next_word in suggestions:
            t5_correct_guesses += 1
        t5_total_guesses += 1
        t5_inference_times += [time.time() - start_time]


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [19:05<00:00,  1.09it/s]


In [16]:
t5_accuracy = t5_correct_guesses / t5_total_guesses
print(t5_accuracy)

0.3592931017170399


In [17]:
print('Avg Inference Time:', statistics.mean(t5_inference_times))
print('Inference Std:', statistics.stdev(t5_inference_times))

Avg Inference Time: 0.022976395781948235
Inference Std: 0.00535418975672332


##### T5-Small on "Predict the Middle Token"
OK, predicting the next token is fun and all. But what if a user wants to add something to the middle of their sentence? Wouldn't it be nice to be able to both use left and right context?

---
#### DistilBert

In [12]:
distbert_path = 'distilbert-base-cased'
distbert_model = DistilBertForMaskedLM.from_pretrained(distbert_path).to(device=device)
distbert_config = DistilBertConfig.from_pretrained(distbert_path)
distbert_tokenizer = DistilBertTokenizer.from_pretrained(distbert_path, config=distbert_config)

##### DistilBert On the Sample Sentences

In [20]:
for sentence in test_sentences:

    input_text = sentence.replace("|", "[MASK]")
    input_ids = distbert_tokenizer(input_text, return_tensors="pt").to(device).input_ids

    # Get the position of the masked token
    mask_token_index = torch.where(input_ids == distbert_tokenizer.mask_token_id)[1].tolist()[0]

    # Generate predictions for the next token
    with torch.no_grad():
        output = distbert_model(input_ids)
        predictions = output.logits

    # Get the top-k predicted tokens and their probabilities
    top_k = 5  # Adjust as needed
    probs, indices = torch.topk(predictions[0, mask_token_index], k=top_k, dim=-1)

    # Convert indices back to tokens
    predicted_tokens = distbert_tokenizer.convert_ids_to_tokens(indices.tolist())

    # Decode and print the predicted token
    suggestions = []
    for sample_output, prob in zip(predicted_tokens, probs.tolist()):
        suggestions += [(sample_output, prob)]
    print('Sentence:', sentence)
    print('Suggestions:', suggestions)
    print("-=+=--=+=--=+=--=+=--=+=-")

Sentence: After forcefully breaking into the bank, they |
Suggestions: [('!', 7.7542314529418945), ('.', 7.455650806427002), ('escape', 6.470279693603516), (':', 6.367412567138672), ('find', 6.3458967208862305)]
-=+=--=+=--=+=--=+=--=+=-
Sentence: In my family, there is my |
Suggestions: [('heart', 6.815902233123779), ('destiny', 6.189598560333252), ('.', 6.164964199066162), ('love', 6.162736415863037), ('family', 6.145949363708496)]
-=+=--=+=--=+=--=+=--=+=-
Sentence: Every |, my family and I visit Hawaii.
Suggestions: [('##day', 6.442167282104492), ('day', 5.875978946685791), ('morning', 5.826064586639404), ('##night', 5.271539688110352), ('night', 5.13131046295166)]
-=+=--=+=--=+=--=+=--=+=-
Sentence: My favorite | is apple.
Suggestions: [('fruit', 11.054349899291992), ('apple', 10.834712028503418), ('tree', 10.448899269104004), ('grape', 8.978232383728027), ('vegetable', 8.545177459716797)]
-=+=--=+=--=+=--=+=--=+=-
Sentence: It has been 2 months since I graduated. However, unfortu

##### DistilBert on "Predict the Next Token"

In [36]:
distbert_correct_guesses = 0
distbert_total_guesses = 0
distbert_inference_times = []

for creation, post in tqdm(random_post_samples, smoothing=0):

    sent_tokenized = sent_tokenize(normalize_text(post))
    tokenized_words = [word_tokenize(sentence) for sentence in sent_tokenized]
    tokenized_words = [word for s in tokenized_words for word in s]
    for i in range(1, len(tokenized_words) - 1):
        start_time = time.time()
        current_prompt = ' '.join(tokenized_words[:i]) + " [MASK]"
        current_prompt = re.sub(r" ([!.?,;:\"')\]}]{1,9})", r"\1", current_prompt)
        current_prompt = re.sub(r"([\[({]{1,9}) ", r"\1", current_prompt)
        actual_next_word = tokenized_words[i]
        input_ids: torch.Tensor = distbert_tokenizer(current_prompt, return_tensors="pt").to(device).input_ids
        # bert can only accept up to 512 tokens so if our tensor is bigger than this
        # we trim it before passing into the model
        if input_ids[0].shape[0] > 512:
            input_ids = input_ids[:, -512:]
        # Get the position of the masked token
        mask_token_index = torch.where(input_ids == distbert_tokenizer.mask_token_id)[1].tolist()[0]

        with torch.no_grad():
            # generate one word at a time
            output = distbert_model(input_ids)
            predictions = output.logits

        # Get the top-k predicted tokens and their probabilities
        top_k = 8  # Adjust as needed
        probs, indices = torch.topk(predictions[0, mask_token_index], k=top_k, dim=-1)

        # Convert indices back to tokens
        predicted_tokens = distbert_tokenizer.convert_ids_to_tokens(indices.tolist())

        # Decode and print the predicted token
        suggestions = set()
        for decoded_output, prob in zip(predicted_tokens, probs.tolist()):
            if len(suggestions) >= 5:
                break
            if prob < 0.04:
                # avoid bizarre suggestions by simply filtering out low prob
                # terms. We don't HAVE TO suggest exactly 5 words
                break
            # bert also suggests "sub-words". Ehh... we'll just ignore those.
            # otherwise stuff will get too complicated
            if decoded_output.startswith("##"):
                continue
            if decoded_output.strip() == '':
                continue
            suggestions.add(decoded_output)

        if actual_next_word in suggestions:
            distbert_correct_guesses += 1
        # since the way the bert tokenizer works, it can suggest "sub-words" - example: characteristically = characteristic + ##ally,
        # so we will give bert half a point for suggest the same start of the word
        for s in suggestions:
            if actual_next_word.startswith(s):
                distbert_correct_guesses += 0.5
                break
        distbert_total_guesses += 1
        distbert_inference_times += [time.time() - start_time]



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [04:14<00:00,  4.90it/s]


In [37]:
bert_accuracy = distbert_correct_guesses / distbert_total_guesses
print(bert_accuracy)

0.18298021889747967


In [38]:
print('Avg Inference Time:', statistics.mean(distbert_inference_times))
print('Inference Std:', statistics.stdev(distbert_inference_times))

Avg Inference Time: 0.005107757104114663
Inference Std: 0.001992808943326104


##### DistilBert on "Predict the Middle Token"

#### Distil-GPT2

In [10]:
# Choose the GPT-2 model variant
distgpt2_model_name = "distilgpt2"

# Load pre-trained GPT-2 model and tokenizer
distgpt2_model: GPT2LMHeadModel = GPT2LMHeadModel.from_pretrained(distgpt2_model_name)
distgpt2_tokenizer: PreTrainedTokenizerBase = GPT2Tokenizer.from_pretrained(distgpt2_model_name)
# Move the model to the GPU (if available)
distgpt2_model = distgpt2_model.to(device)

In [11]:
for sentence in test_sentences:
    # delete everything including and after |
    # GPT2 is physically unable to consider right context
    input_text = sentence[:sentence.index("|")].strip()
    # Tokenize the prompt
    input_ids = distgpt2_tokenizer.encode(input_text, return_tensors="pt").to(device)
    # Generate probabilities for the next words
    with torch.no_grad():
        outputs = distgpt2_model(input_ids)
        logits = outputs.logits
    # we are only interested in either the top 5 words or words with a probability of > 1%
    # after all, we don't want to suggest too many words!
    top_k = 5
    # Get the probability distribution for the next word
    next_word_probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1)
    top_k_values, top_k_indices = torch.topk(next_word_probs, k=top_k, dim=-1)
    # Normalize probabilities
    top_k_probs_normalized = top_k_values / top_k_values.sum()

    # Convert the probabilities to a list
    top_k_probs_list: List = top_k_probs_normalized.tolist()[0]
    top_k_indices_list: List = top_k_indices.tolist()[0]

    # Decode and print the predicted token
    suggestions = []
    for sample_output, prob in zip(top_k_indices_list, top_k_probs_list):
        decoded_output = distgpt2_tokenizer.decode(sample_output)
        suggestions += [(decoded_output, prob)]
    print('Sentence:', sentence)
    print('Suggestions:', suggestions)

Sentence: After forcefully breaking into the bank, they |
Suggestions: [(' were', 0.6859785914421082), (' found', 0.10938050597906113), (' had', 0.08820205926895142), (' took', 0.059164125472307205), (' began', 0.05727475509047508)]
Sentence: In my family, there is my |
Suggestions: [(' family', 0.27839019894599915), (' brother', 0.2245674729347229), (' mother', 0.18426108360290527), (' sister', 0.1591966301202774), (' wife', 0.1535845696926117)]
Sentence: Every |, my family and I visit Hawaii.
Suggestions: [(' The', 0.40960025787353516), ('.', 0.16524524986743927), (' A', 0.15191836655139923), ('\n', 0.13858996331691742), ('The', 0.13464611768722534)]
Sentence: My favorite | is apple.
Suggestions: [(' favorite', 0.44314318895339966), (' part', 0.18813173472881317), (' thing', 0.16900213062763214), (' of', 0.11418430507183075), ('.', 0.08553868532180786)]
Sentence: It has been 2 months since I graduated. However, unfortunately I still haven't found a | rate I won't be able to pay rent!

In [15]:
distilgpt2_correct_predictions = 0
distilgpt2_total_predictions = 0
distilgpt2_inference_times = []

#distgpt2_tokenizer.pad_token = distgpt2_tokenizer.eos_token
distgpt2_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
for creation, post in tqdm(random_post_samples, smoothing=0):
    sent_tokenized = sent_tokenize(normalize_text(post))
    tokenized_words = [word_tokenize(sentence) for sentence in sent_tokenized]
    tokenized_words = [word for s in tokenized_words for word in s]

    current_batch_prompt = []
    current_batch_answer = []
    for i in range(1, len(tokenized_words) - 1):
        start_time = time.time()
        current_prompt = ' '.join(tokenized_words[:i]).strip()
        current_prompt = re.sub(r" ([!.?,;:\"')\]}]{1,9})", r"\1", current_prompt)
        current_prompt = re.sub(r"([\[({]{1,9}) ", r"\1", current_prompt)
        actual_next_word = tokenized_words[i]
        current_batch_prompt += [current_prompt]
        current_batch_answer += [actual_next_word]

        assert len(current_batch_prompt) == len(current_batch_answer)
        if len(current_batch_prompt) < 8:
            continue

        # Tokenize the prompt
        input_ids = distgpt2_tokenizer.batch_encode_plus(current_batch_prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024, truncation_strategy="longest_first").input_ids.to(device)

        # Generate probabilities for the next words
        with torch.no_grad():
            outputs = distgpt2_model(input_ids)
            logits = outputs.logits
        # we are only interested in either the top 5 words or words with a probability of > 1%
        # after all, we don't want to suggest too many words!
        top_k = 5
        top_p = 0.04  # don't suggest 5 words just for the sake of suggesting 5 words
        # Get the probability distribution for the next word
        next_word_probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1)
        top_k_values, top_k_indices = torch.topk(next_word_probs, k=top_k, dim=-1)
        # Normalize probabilities
        top_k_probs_normalized = top_k_values / top_k_values.sum()
        # Convert the probabilities to a list
        top_k_probs_list: List = top_k_probs_normalized.tolist()
        top_k_indices_list: List = top_k_indices.tolist()

        # for each element in the batch
        for current_prompt, top_k_probs, top_k_indices, actual_next_word in zip(current_batch_prompt, top_k_probs_list, top_k_indices_list, current_batch_answer):

            guesses = set()
            for token_id, prob in zip(top_k_indices, top_k_probs):
                # ignore bizarre suggestions
                if prob < top_p:
                    continue
                token = distgpt2_tokenizer.decode([token_id])
                guesses.add(token)

            # we consider a prediction correct if the actual word was
            # one of the (up to) 5 words suggested
            if (' ' + actual_next_word) in guesses:
                distilgpt2_correct_predictions += 1
            distilgpt2_total_predictions += 1

            print('prompt:', current_prompt)
            print('suggestions:', guesses)
            print('answer:', actual_next_word)

        distilgpt2_inference_times += [time.time() - start_time]
        # clear the batch
        current_batch_prompt = []
        current_batch_answer = []


  0%|                                                                                                                                 | 0/1250 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [80,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [80,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [80,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [80,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [80,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [80,

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [41]:
distilgpt2_correct_predictions = 0
distilgpt2_total_predictions = 0
distilgpt2_inference_times = []

for creation, post in tqdm(random_post_samples, smoothing=0):
    sent_tokenized = sent_tokenize(normalize_text(post))
    tokenized_words = [word_tokenize(sentence) for sentence in sent_tokenized]
    tokenized_words = [word for s in tokenized_words for word in s]
    for i in range(1, len(tokenized_words) - 1):
        start_time = time.time()
        current_prompt = ' '.join(tokenized_words[:i])
        current_prompt = re.sub(r" ([!.?,;:\"')\]}]{1,9})", r"\1", current_prompt)
        current_prompt = re.sub(r"([\[({]{1,9}) ", r"\1", current_prompt)
        actual_next_word = tokenized_words[i]
        # Tokenize the prompt
        input_ids = distgpt2_tokenizer.encode(current_prompt, return_tensors="pt").to(device)
        if input_ids.shape[1] == 0:
            continue
        # gpt2 can only accept up to 1024 tokens so if our tensor is bigger than this
        # we trim it before passing into the model
        if input_ids[0].shape[0] > 1024:
            input_ids = input_ids[:, -1024:]
        # Generate probabilities for the next words
        with torch.no_grad():
            outputs = distgpt2_model(input_ids)
            logits = outputs.logits
        # we are only interested in either the top 5 words or words with a probability of > 1%
        # after all, we don't want to suggest too many words!
        top_k = 5
        top_p = 0.04  # don't suggest 5 words just for the sake of suggesting 5 words
        # Get the probability distribution for the next word
        next_word_probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1)
        top_k_values, top_k_indices = torch.topk(next_word_probs, k=top_k, dim=-1)
        # Normalize probabilities
        top_k_probs_normalized = top_k_values / top_k_values.sum()

        # Convert the probabilities to a list
        top_k_probs_list: List = top_k_probs_normalized.tolist()[0]
        top_k_indices_list: List = top_k_indices.tolist()[0]

        # filter to only words with a probability greater than p%
        removal_marked = []
        for j in range(len(top_k_probs_list)):
            if top_k_probs_list[j] < top_p:
                removal_marked += [(top_k_probs_list[j], top_k_indices_list[j])]

        for to_remove in removal_marked:
            top_k_probs_list.remove(to_remove[0])
            top_k_indices_list.remove(to_remove[1])

        guesses = set()
        for token_id, prob in zip(top_k_indices_list, top_k_probs_list):
            token = distgpt2_tokenizer.decode([token_id])
            guesses.add(token)

        # we consider a prediction correct if the actual word was
        # one of the (up to) 5 words suggested
        if ' ' + actual_next_word in guesses:
            distilgpt2_correct_predictions += 1
        distilgpt2_total_predictions += 1
        distilgpt2_inference_times += [time.time() - start_time]


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [04:02<00:00,  5.16it/s]


In [42]:
gpt2_accuracy = distilgpt2_correct_predictions / distilgpt2_total_predictions
print(gpt2_accuracy)

0.3662472919523781


In [14]:
print('Avg Inference Time:', statistics.mean(distilgpt2_inference_times))
print('Inference Std:', statistics.stdev(distilgpt2_inference_times))

Avg Inference Time: 0.00606661000856735
Inference Std: 0.0015065450347984649


### Summary of the results from Model Comparison

* T5 was the slowest model
GPU
* 24:39 - T5-Small
* 5:10 - DistilBert
* 5:32 - DistilGPT2
CPU
* 3:19:43: T5-Small
* 1:41:39 - DistilBert
* 2:29:33 - DistilGPT2
Accuracy:
* 37.27% - T5-Small
* 17.73% - DistilBert
* 35.86% - DistilGPT2

In [29]:
total_words = 0
for creation, post in random_post_samples:
    sent_tokenized = sent_tokenize(normalize_text(post))
    tokenized_words = [word_tokenize(sentence) for sentence in sent_tokenized]
    tokenized_words = [word for s in tokenized_words for word in s]
    total_words += len(tokenized_words)
print(total_words)

46875


#### The Problem of Catastrophic Forgetting
So far, in the examples we have tested, we have trained our model on the entire dataset at ones (as opposed to training the model on the data sequentially). That is, in the real world, in the type of application we want to develop, our model would be exposed to snippets of data over a long period of time. However, this introduces the problem of "catastrophic forgetting" where the model, while optimizing the weights on the new data forgets about the old tasks it learned to do.

To overcome this problem, we turn to DeepMind research paper titled ["Overcoming catastrophic forgetting in neural networks"](https://arxiv.org/abs/1612.00796). This paper arrives upon 3 key insights to help develop an algorithm for the aforementioned problem:
1. Many optimal configurations of the set of weights and biases θ exist.
2. It is likely that one those optimal configurations is close to the previous learned tasks.
3. It is therefore possible to constraint parameters to stay in a region of low error for previous learned tasks during training.

Inspired by biological mechanisms found in real life, the paper refers to this algorithm as elastic weight consolidation (EWC). EWC is essentially a new loss function. This is exactly what we are looking for because:
1. Ofcourse now we can learn new tasks with out forgetting the others.
2. EWC allows us to define the importance of each tasks relative to each other. In theory, this means we could classify more recent examples as more important! This gives us a lot of power.
3. EWC can be trained for an arbitrary number of new tasks.

https://github.com/moskomule/ewc.pytorch/blob/master/demo.ipynb
https://github.com/kuc2477/pytorch-ewc
https://github.com/ContinualAI/colab/blob/master/notebooks/intro_to_continual_learning.ipynb

In [22]:
# randomly select a user to train with
random_user = random.choice(list(author_to_posts_dict.keys()))
random_user

'VulturE'

In [23]:
training_data = author_to_posts_dict[random_user]
len(training_data)

2676

In [31]:
# Split the data into training and validation sets (e.g., 90% train, 10% validation)
split_index = int(0.9 * len(training_data))
train_data = [x[1] for x in training_data[:split_index]]
valid_data = [x[1] for x in training_data[split_index:]]

In [32]:
valid_data

["I'm sad that samsung no longer makes its own drives :(",
 'I remember the GLORIOUS DAY when Subway [actually had a letter go out to store managers asking that all employees start tessellating the cheese](http://gawker.com/5551263/subway-finally-agrees-to-tessellate-cheese). I have to remind them to do it still :(',
 'Does it need any?',
 'disco ball',
 "I'd second ES File Explorer. Being able to copy files from my Windows Shares easily is what made this app golden. Others I like:\n\n* [Parcels](https://play.google.com/store/apps/details?id=eu.zomtec.android.delivery) for tracking my many packages I send or receive.\n* [Aix Weather](https://play.google.com/store/apps/details?id=net.veierland.aix) all that you need for a weather widget\n* [AlarmDroid](https://play.google.com/store/apps/details?id=com.splunchy.android.alarmclock)\n* [Contact Simple Widget](https://play.google.com/store/apps/details?id=lcc.simplewidgets) simply because it uses nickname values from contacts properly\n* [G

In [33]:
# Convert the list of strings to a single string
train_text_data = []
for train in train_data:
    train_text_data += [normalize_text(train)]

valid_text_data = []
for valid in valid_data:
    valid_text_data += [normalize_text(valid)]

# Save the text data to a file
with open(f"{data_dir}/training.txt", "w", encoding="utf-8") as file:
    file.writelines(train_text_data)

# Save the text data to a file
with open(f"{data_dir}/validation.txt", "w", encoding="utf-8") as file:
    file.writelines(valid_text_data)



In [38]:
from transformers import TextDataset, DataCollatorForLanguageModeling, TrainingArguments, Trainer

model_name = "distilgpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

print(model == distgpt2_model)

# Load and preprocess the training data
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=f"{data_dir}/training.txt",
    block_size=128  # adjust block_size as needed
)

# Load and preprocess the validation data
valid_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=f"{data_dir}/validation.txt",
    block_size=128  # adjust block_size as needed
)

# Data collator for language modeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # we're not using masked language modeling here
)

# Training arguments
training_args = TrainingArguments(
    output_dir=f"{data_dir}/gpt2-finetuned",
    overwrite_output_dir=True,
    num_train_epochs=5,  # adjust as needed
    per_device_train_batch_size=2,
    save_steps=10_000,  # adjust as needed
    save_total_limit=2,
    prediction_loss_only=True,
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
)

# Fine-tune the GPT-2 model
trainer.train()

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


False


Step,Training Loss


TrainOutput(global_step=2025, training_loss=3.973758179464458, metrics={'train_runtime': 54.6245, 'train_samples_per_second': 74.142, 'train_steps_per_second': 37.071, 'total_flos': 132281479987200.0, 'train_loss': 3.973758179464458, 'epoch': 5.0})

In [39]:
distilgpt2_correct_predictions = 0
distilgpt2_total_predictions = 0
distilgpt2_inference_times = []

for creation, post in tqdm(random_post_samples, smoothing=0):
    sent_tokenized = sent_tokenize(normalize_text(post))
    tokenized_words = [word_tokenize(sentence) for sentence in sent_tokenized]
    tokenized_words = [word for s in tokenized_words for word in s]
    for i in range(1, len(tokenized_words) - 1):
        start_time = time.time()
        current_prompt = ' '.join(tokenized_words[:i])
        current_prompt = re.sub(r" ([!.?,;:\"')\]}]{1,9})", r"\1", current_prompt)
        current_prompt = re.sub(r"([\[({]{1,9}) ", r"\1", current_prompt)
        actual_next_word = tokenized_words[i]
        # Tokenize the prompt
        input_ids = distgpt2_tokenizer.encode(current_prompt, return_tensors="pt").to(device)
        if input_ids.shape[1] == 0:
            continue
        # gpt2 can only accept up to 1024 tokens so if our tensor is bigger than this
        # we trim it before passing into the model
        if input_ids[0].shape[0] > 1024:
            input_ids = input_ids[:, -1024:]
        # Generate probabilities for the next words
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits
        # we are only interested in either the top 5 words or words with a probability of > 1%
        # after all, we don't want to suggest too many words!
        top_k = 5
        top_p = 0.04  # don't suggest 5 words just for the sake of suggesting 5 words
        # Get the probability distribution for the next word
        next_word_probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1)
        top_k_values, top_k_indices = torch.topk(next_word_probs, k=top_k, dim=-1)
        # Normalize probabilities
        top_k_probs_normalized = top_k_values / top_k_values.sum()

        # Convert the probabilities to a list
        top_k_probs_list: List = top_k_probs_normalized.tolist()[0]
        top_k_indices_list: List = top_k_indices.tolist()[0]

        # filter to only words with a probability greater than p%
        removal_marked = []
        for j in range(len(top_k_probs_list)):
            if top_k_probs_list[j] < top_p:
                removal_marked += [(top_k_probs_list[j], top_k_indices_list[j])]

        for to_remove in removal_marked:
            top_k_probs_list.remove(to_remove[0])
            top_k_indices_list.remove(to_remove[1])

        guesses = set()
        for token_id, prob in zip(top_k_indices_list, top_k_probs_list):
            token = distgpt2_tokenizer.decode([token_id])
            guesses.add(token)

        # we consider a prediction correct if the actual word was
        # one of the (up to) 5 words suggested
        if ' ' + actual_next_word in guesses:
            distilgpt2_correct_predictions += 1
        distilgpt2_total_predictions += 1
        distilgpt2_inference_times += [time.time() - start_time]


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [04:17<00:00,  4.86it/s]


In [40]:
gpt2_accuracy = distilgpt2_correct_predictions / distilgpt2_total_predictions
print(gpt2_accuracy)

0.3338699739629917


Vocab updated??
https://github.com/huggingface/tokenizers/issues/1160