In [29]:
import torch
import torch.nn as nn
import numpy as np  

In [30]:
# Choose the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
has_cuda = torch.cuda.is_available()
if has_cuda:
    print(torch.cuda.device_count())  # how many CUDA devices are visible
device = "cuda" if has_cuda else ("mps" if torch.backends.mps.is_available() else "cpu")
device

'cpu'

In [31]:
from datasets import load_dataset

# Fixed seed so the 80/20 train/validation split -- and everything derived from
# it later (word frequencies, learned merges) -- is reproducible on Restart & Run All.
SPLIT_SEED = 42

imdb_dataset = load_dataset("imdb")
split = imdb_dataset["train"].train_test_split(train_size=0.8, seed=SPLIT_SEED)
# The held-out 20% of the original train split is used as validation.
imdb_train, imdb_valid = split["train"], split["test"]
imdb_test = imdb_dataset["test"]

In [33]:
[imdb_train[0]["text"]]  # peek at the text of the first training review

["This film has been receiving a lot of play lately during the day on either HBO or Cinemax. The reason is that they are assuming people would be interested in comparing it to the Leonardo DiCaprio/Tom Hanks caper of the same name. The only reason to see it is for the attractive Matt Lattanzi. Yum! Although I must say Matt was more than a little long in the tooth to be playing a high schooler. If he were a woman, they'd have had him playing the MOTHER of a high schooler! (Is is just me, or is his daughter starting to look like Shelley Duvall?) Oh yeah, the plot--who cares? Typical teen highjinx played by adults."]

In [34]:
from transformers import AutoTokenizer

# Pretrained GPT-2 tokenizer; its pre-tokenizer is used to split text into words.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")

In [44]:
# Count how often each pre-tokenized word occurs across the training reviews
# (keys are kept in first-seen order).
pre_tokenizer = tokenizer.backend_tokenizer.pre_tokenizer
word_freq = {}
for review in imdb_train["text"]:
    for token, _span in pre_tokenizer.pre_tokenize_str(review):
        if token in word_freq:
            word_freq[token] += 1
        else:
            word_freq[token] = 1

In [42]:
# Base symbols for BPE: every distinct character appearing in any word, sorted.
alphabet = sorted({ch for word in word_freq for ch in word})
print(alphabet)

['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '¡', '¢', '£', '¤', '¥', '¦', '§', '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', '¶', '·', '¸', '¹', 'º', '»', '¼', '½', '¾', '¿', 'Â', 'Ã', 'â', 'ï', 'Ĉ', 'ĉ', 'Đ', 'Ġ', 'Ģ', 'Ĥ', 'ĥ', 'Ħ', 'ħ', 'Ī', 'ī', 'Ĭ', 'į', 'İ', 'ĳ', 'ĵ', 'ķ', 'ĸ', 'Ĺ', 'ĺ', 'Ļ', 'ľ', 'Ŀ', 'ŀ', 'Ł', 'ł', 'Ń']


In [45]:
# Initial vocabulary: the special end-of-text token followed by the base alphabet.
vocab = ["<|endoftext|>"]
vocab.extend(alphabet)
print(vocab)

['<|endoftext|>', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '¡', '¢', '£', '¤', '¥', '¦', '§', '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', '¶', '·', '¸', '¹', 'º', '»', '¼', '½', '¾', '¿', 'Â', 'Ã', 'â', 'ï', 'Ĉ', 'ĉ', 'Đ', 'Ġ', 'Ģ', 'Ĥ', 'ĥ', 'Ħ', 'ħ', 'Ī', 'ī', 'Ĭ', 'į', 'İ', 'ĳ', 'ĵ', 'ķ', 'ĸ', 'Ĺ', 'ĺ', 'Ļ', 'ľ', 'Ŀ', 'ŀ', 'Ł', 'ł', 'Ń']


In [52]:
# Start every word as the list of its individual characters.
splits = {word: list(word) for word in word_freq}

def compute_pair_freq(splits, freqs=None):
    """Count adjacent symbol pairs across all words, weighted by word frequency.

    Args:
        splits: Mapping word -> list of that word's current symbols.
        freqs: Mapping word -> occurrence count. Defaults to the notebook-level
            ``word_freq`` so existing calls ``compute_pair_freq(splits)`` keep
            working; pass it explicitly to avoid the hidden global dependency.

    Returns:
        dict mapping ``(left, right)`` symbol tuples to their summed frequency,
        in first-seen order.
    """
    if freqs is None:
        freqs = word_freq
    pair_freq = {}
    for word, freq in freqs.items():
        split = splits[word]
        # zip yields nothing for 0- or 1-symbol words, so no special case needed.
        for pair in zip(split, split[1:]):
            pair_freq[pair] = pair_freq.get(pair, 0) + freq
    return pair_freq
pair_freq = compute_pair_freq(splits)
# Peek at the first seven pairs (first-seen order) and their counts.
for pair, freq in list(pair_freq.items())[:7]:
    print(pair, freq)

('T', 'h') 62926
('h', 'i') 166820
('i', 's') 235957
('Ġ', 'f') 188078
('f', 'i') 71670
('i', 'l') 104507
('l', 'm') 43829


In [53]:
# Most frequent adjacent pair. max() keeps the first maximal pair on ties, and
# yields "" / None when there are no pairs at all.
most_freq = max(pair_freq, key=pair_freq.get, default="")
max_freq = pair_freq.get(most_freq)
print(most_freq, max_freq)

('Ġ', 't') 628411


In [55]:
def merge_pair(a, b, splits):
    """Merge every adjacent occurrence of the symbol pair ``(a, b)``.

    Scans each word's symbols left to right and replaces each non-overlapping
    ``a, b`` adjacency with the single symbol ``a + b``.

    Args:
        a: Left symbol of the pair.
        b: Right symbol of the pair.
        splits: Mapping word -> list of that word's current symbols. Updated
            in place.

    Returns:
        The same ``splits`` mapping (returned for convenience).
    """
    merged = a + b
    # Iterate over a snapshot of the keys so assigning new values is safe.
    for word in list(splits):
        split = splits[word]
        if len(split) < 2:
            continue  # a single symbol has no pair to merge
        new_split = []
        i = 0
        while i < len(split):
            if i < len(split) - 1 and split[i] == a and split[i + 1] == b:
                new_split.append(merged)
                i += 2
            else:
                new_split.append(split[i])
                i += 1
        splits[word] = new_split
    return splits
# Demonstrate a single merge of the pair ("Ġ", "t").
splits = merge_pair(a="Ġ", b="t", splits=splits)


In [58]:
# Learn BPE merges until the vocabulary reaches `vocab_size`.
# State is rebuilt from scratch here so that:
#   * merges applied to `splits` by earlier cells (which are not recorded in
#     `merges`/`vocab`) cannot leak in -- otherwise `tokenize`, which starts
#     from single characters, could never reproduce those tokens; and
#   * re-running this cell gives the same result (idempotent).
vocab_size = 2000
vocab = ["<|endoftext|>"] + alphabet
splits = {word: list(word) for word in word_freq}
merges = {}

while len(vocab) < vocab_size:
    pair_freq = compute_pair_freq(splits)
    if not pair_freq:
        # Every word is already a single symbol; nothing left to merge.
        break
    # max() returns the first maximal pair, i.e. ties go to first-seen order.
    best_pair = max(pair_freq, key=pair_freq.get)
    splits = merge_pair(*best_pair, splits)
    merged = best_pair[0] + best_pair[1]
    merges[best_pair] = merged  # insertion order == merge priority
    vocab.append(merged)
    

In [59]:
def tokenize(text):
    """Tokenize a single string with the learned BPE merges.

    The text is pre-tokenized with the GPT-2 tokenizer's pre-tokenizer (the
    same one used to build ``word_freq``). Each pre-token is split into
    characters, and the learned ``merges`` are then applied in the order they
    were learned.

    Args:
        text: One string. ``pre_tokenize_str`` expects a str, so call this once
            per review rather than with a list.

    Returns:
        Flat list of token strings for the whole text.
    """
    words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    splits = [list(word) for word, _offset in words_with_offsets]
    # `merges` preserves learning order, which defines BPE merge priority.
    for pair, merged_token in merges.items():
        for idx, split in enumerate(splits):
            i = 0
            while i < len(split) - 1:
                if split[i] == pair[0] and split[i + 1] == pair[1]:
                    split = split[:i] + [merged_token] + split[i + 2:]
                else:
                    i += 1
            splits[idx] = split
    # Flatten in linear time (sum(list_of_lists, []) is quadratic).
    return [token for split in splits for token in split]
    


In [None]:
# `tokenize` takes one string, so tokenize each review separately.
# NOTE(review): this applies every learned merge to every word of every review
# and is likely slow on the full training split -- consider a subset first.
train_tokenize = [tokenize(text) for text in imdb_train["text"]]

def token