In [61]:
import torch
import torch.nn as nn
import numpy as np  

In [62]:
if torch.cuda.is_available():
    device = "cuda"    
    print(torch.cuda.device_count())
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'cpu'

In [63]:
from datasets import load_dataset

imdb_dataset = load_dataset("imdb")
split = imdb_dataset["train"].train_test_split(train_size = 0.8)
imdb_train, imdb_valid = split["train"], split["test"]
imdb_test = imdb_dataset["test"]

In [64]:
imdb_train[:1]["text"]

["This British pot-boiler has one thing going for it: the young men are uniformly good looking. The older men are opinionated, right-wing Thatcherites whose behavior brings back all the acrimony of the Reagan/Thatcher years. Young or old, however, morals in this three-part mini-series are universally suspect and no one comes off particularly well.<br /><br />Nick is a handsome young gay man fresh out of Oxford. It is not pivotal to the story, but he has an extraordinarily beautiful head of hair which makes watching this drivel much easier. Nick comes to London with a friend, whose father Gerald is a rich conservative politician, and babysits his sister Cat while the family frolics in the south of France. They neglect to inform him that, when upset, Cat cuts herself with an assortment of knives and other kitchen implements. Nick mistakes their self-serving 'gratitude' for affection and moves in, finding out too late just how much they despise and patronize him. Inexplicably, Nick lives 

## Tokenization

In [65]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [66]:
word_freq = {}

for text in imdb_train["text"]:
    word_with_offset = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    new_word = [word for word, offset in word_with_offset]
    for word in new_word:
        word_freq[word] = word_freq.get(word, 0) + 1

In [67]:
alphabet = []
for word in word_freq.keys():
    for letter in word:
        if letter not in alphabet:
            alphabet.append(letter)
alphabet.sort()
print(alphabet)

['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '¡', '¢', '£', '¤', '¥', '¦', '§', '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', '¶', '·', '¸', '¹', 'º', '»', '¼', '½', '¾', '¿', 'Â', 'Ã', 'Å', 'â', 'ï', 'Ĉ', 'ĉ', 'Ġ', 'Ģ', 'ģ', 'Ĥ', 'ĥ', 'Ħ', 'ħ', 'Ī', 'ī', 'Ĭ', 'į', 'İ', 'ĳ', 'ĵ', 'ķ', 'ĸ', 'Ĺ', 'ĺ', 'Ļ', 'ļ', 'ľ', 'Ŀ', 'ŀ', 'Ł', 'ł', 'Ń']


In [68]:
vocab = ["<|endoftext|>"] + alphabet.copy()
print(vocab)

['<|endoftext|>', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '¡', '¢', '£', '¤', '¥', '¦', '§', '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', '¶', '·', '¸', '¹', 'º', '»', '¼', '½', '¾', '¿', 'Â', 'Ã', 'Å', 'â', 'ï', 'Ĉ', 'ĉ', 'Ġ', 'Ģ', 'ģ', 'Ĥ', 'ĥ', 'Ħ', 'ħ', 'Ī', 'ī', 'Ĭ', 'į', 'İ', 'ĳ', 'ĵ', 'ķ', 'ĸ', 'Ĺ', 'ĺ', 'Ļ', 'ļ', 'ľ', 'Ŀ', 'ŀ', 'Ł', 'ł', 'Ń']


In [69]:
splits = {word:[c for c in word] for word in word_freq.keys()}

def compute_pair_freq(splits):
    pair_freq = {}
    for word, freq in word_freq.items():
        split = splits[word]
        if len(split) == 1:
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i+1])
            pair_freq[pair] = pair_freq.get(pair, 0) + freq
    return pair_freq
pair_freq = compute_pair_freq(splits)
for i, (pair, freq) in enumerate(pair_freq.items()):
    print(pair,freq)
    if i > 5:
        break

('T', 'h') 63253
('h', 'i') 167671
('i', 's') 236952
('Ġ', 'B') 30353
('B', 'r') 5387
('r', 'i') 103700
('i', 't') 182617


In [70]:
most_freq = ""
max_freq = None

for pair, freq in pair_freq.items():
    if max_freq is None or max_freq < freq:
        most_freq = pair
        max_freq = freq
print(most_freq, max_freq)

('Ġ', 't') 630711


In [71]:
def merge_pair(a,b,splits):
    for word in word_freq.keys():
        split = splits[word]
        if len(split) == 1:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i+1] == b:
                split = split[:i] + [a + b] + split[i+2:]
            else:
                i += 1
        splits[word] = split
    return splits
splits = merge_pair("Ġ", "t", splits)


In [72]:
merges = {}
vocab_size = 2000

while vocab_size > len(vocab):
    pair_freq = compute_pair_freq(splits)
    best_pair = ""
    max_freq = None
    for pair, freq in pair_freq.items():
        if max_freq is None or max_freq < freq:
            best_pair = pair
            max_freq = freq
    splits = merge_pair(*best_pair,splits)
    merges[best_pair] = best_pair[0] + best_pair[1]
    vocab.append(best_pair[0] + best_pair[1])
    

In [73]:
def tokenize(text):    
    word_with_offset = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    new_words = [word for word, offset in word_with_offset]
    splits = [[l for l in word] for word in new_words]
    for pair, merge in merges.items():
        for idx, split in enumerate(splits):
            i = 0 
            while i < len(split) - 1:
                if split[i] == pair[0] and split[i+1] == pair[1]:
                    split = split[:i] + [merge] + split[i+2:]
                else:
                    i += 1
            splits[idx] = split
    return sum(splits, [])
    


In [76]:
print(vocab)

['<|endoftext|>', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '¡', '¢', '£', '¤', '¥', '¦', '§', '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', '¶', '·', '¸', '¹', 'º', '»', '¼', '½', '¾', '¿', 'Â', 'Ã', 'Å', 'â', 'ï', 'Ĉ', 'ĉ', 'Ġ', 'Ģ', 'ģ', 'Ĥ', 'ĥ', 'Ħ', 'ħ', 'Ī', 'ī', 'Ĭ', 'į', 'İ', 'ĳ', 'ĵ', 'ķ', 'ĸ', 'Ĺ', 'ĺ', 'Ļ', 'ļ', 'ľ', 'Ŀ', 'ŀ', 'Ł', 'ł', 'Ń', 'Ġa', 'he', 'in', 'Ġthe', 'Ġs', 're', 'Ġw', 'Ġo', 'is', 'er', 'nd', 'Ġf', 'Ġm', 'Ġb', 'it', 'at', 'ing', 'Ġc', 'en', 'es', 'or', 'ou', 'Ġth', 'on', 'Ġh', 'll', 'an', 'Ġto', 'Ġand', 'Ġof', 'Ġp', 'ar', '

In [82]:

token2id = {token:idx for idx, token in enumerate(vocab)}
id2token = {idx:token for token, idx in token2id.items()}

def tokenize_to_ids(text):
    tokens = tokenize(text)
    ids = []
    for t in tokens:
        ids.append(token2id[t])
    return ids

In [83]:
def decode(ids):
    text = "".join([id2token[i] for i in ids])
    text = text.replace("Ġ"," ")
    return text

In [90]:
tokenize_to_ids("I like this film")

[41, 355, 131, 340, 165, 246]

## Prepare Dataset for Training

In [91]:
X_train = tokenize_to_ids(imdb_train["text"])
y_train = imdb_train["label"]
X_valid = tokenize_to_ids(imdb_valid["text"])
y_valid = imdb_valid["label"]

TypeError: argument 's': 'list' object cannot be converted to 'PyString'

In [92]:
from torch.utils.data import TensorDataset, DataLoader

X_train_tensor = torch.tensor(X_train)
y_train_tensor = torch.tensor(y_train)
X_valid_tensor = torch.tensor(X_valid)
y_valid_tensor = torch.tensor(y_valid)

train_set = TensorDataset(X_train_tensor, y_train_tensor)
valid_set = TensorDataset(X_valid_tensor, y_valid_tensor)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)


NameError: name 'X_train' is not defined