In [None]:
%pip install pyarrow jsonlines pandas -q

In [1]:
import json
import os
import subprocess
from collections import Counter

import jsonlines
import pandas as pd
import requests
from pyarrow import parquet as pq
from tqdm import tqdm

In [None]:
# (url, local file name) pairs for the two cosmopedia-100k training shards.
files = [
    ("https://huggingface.co/datasets/HuggingFaceTB/cosmopedia-100k/resolve/main/data/train-00000-of-00002.parquet?download=true", "train-00000-of-00002.parquet"),
    ("https://huggingface.co/datasets/HuggingFaceTB/cosmopedia-100k/resolve/main/data/train-00001-of-00002.parquet?download=true", "train-00001-of-00002.parquet")
]

os.makedirs("./data", exist_ok=True)

for url, file in files:
    fp = "./data/" + file
    # An existing file is treated as a completed download (see the rename below).
    if os.path.exists(fp):
        continue

    # Stream into a temporary ".part" file and rename it only after a complete,
    # successful transfer. An HTTP error page or an interrupted download is then
    # never mistaken for a cached parquet file on the next run.
    tmp_fp = fp + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(tmp_fp, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_fp, fp)

    print(f"Downloaded {file}")

In [None]:
# Load both downloaded shards as pandas DataFrames (shard 0, then shard 1).
pq_file_01, pq_file_02 = [
    pq.read_table(shard_path).to_pandas()
    for shard_path in ("./data/train-00000-of-00002.parquet", "./data/train-00001-of-00002.parquet")
]

In [None]:
# Flatten both shards into a single list of row dicts (column name -> value).
# DataFrame.to_dict(orient="records") converts every row in one call instead of
# building a pandas Series per row via iterrows(), and yields native Python scalars.
json_conversations = []
for pq_file in tqdm([pq_file_01, pq_file_02], desc="shards"):
    json_conversations.extend(pq_file.to_dict(orient="records"))

print(json_conversations[0].keys())

In [None]:
# Persist every row as one JSON object per line so later sessions can skip the parquet step.
with jsonlines.open("./data/train.jsonl", "w") as writer:
    writer.write_all(tqdm(json_conversations))

In [None]:
# Reuse the conversations already in memory when an earlier cell built them.
# Otherwise (fresh kernel, or an empty list), reload them from the JSONL cache.
# The existence check no longer mutates the list: the previous pop/append probe
# rotated it by one element on every run.
try:
    json_conversations
except NameError:
    json_conversations = []

if not json_conversations:
    # jsonlines writes UTF-8, so read it back explicitly as UTF-8 rather than
    # relying on the platform default encoding.
    with open("./data/train.jsonl", "r", encoding="utf-8") as reader:
        json_conversations = [json.loads(line) for line in reader]

print("\n".join([str(val) for val in json_conversations[0].items()]))

In [None]:
# Count whitespace-separated words in both the generated text and the prompt.
# Each word is normalized exactly like Tokenizer.__get_token does before lookup:
# for words longer than one character, strip one trailing, then one leading ",.!?".
# Without this, variants such as "dog." would occupy vocabulary slots that
# encode() can never produce.
_PUNCTUATION = ",.!?"


def _normalize_word(word: str) -> str:
    """Apply the tokenizer's punctuation stripping to a single word."""
    if len(word) > 1:
        if word[-1] in _PUNCTUATION:
            word = word[:-1]
        if word[0] in _PUNCTUATION:
            word = word[1:]
    return word


word_counts = Counter()
for conv in tqdm(json_conversations):
    for field in ("text", "prompt"):
        word_counts.update(_normalize_word(word) for word in conv[field].split())

word_freq = dict(word_counts)

with open("./data/word_freq.json", "w") as writer:
    json.dump(word_freq, writer, indent=4)

In [2]:
# Restore word_freq from disk only when this kernel has not computed it yet
# (e.g. after a restart). Catch only NameError so that other errors surface
# instead of silently triggering a reload.
try:
    word_freq
except NameError:
    with open("./data/word_freq.json", "r", encoding="utf-8") as reader:
        word_freq = json.load(reader)

print(word_freq.get("the"))

In [65]:
class Tokenizer:
    """Whitespace-split, word-level tokenizer with frequency-ranked ids.

    Id 0 is the most frequent word of the ``word_freq`` mapping the tokenizer was
    built from, id 1 the next one, and so on. Words outside the vocabulary encode
    to -1, and ids outside the vocabulary decode to ``"N/A"``.
    """

    __token_dict: dict[str, int]
    __reverse_token_dict: dict[int, str]

    # Current vocabulary size, kept as an instance attribute. Do not add a
    # ``vocab_size()`` method: the attribute assigned in ``__init__`` shadows any
    # class-level method of the same name, so it could never be called on an instance.
    vocab_size: int

    def __init__(self, word_freq: dict[str, int]):
        """Build the vocabulary from a ``{word: count}`` mapping.

        Words are ranked by descending count. Ties keep the iteration order of
        ``word_freq``, because ``sorted`` is stable even with ``reverse=True``.
        """
        sorted_words = sorted(word_freq, key=word_freq.get, reverse=True)

        # Insertion order of __token_dict is the frequency ranking;
        # reduce_vocab_size relies on that.
        self.__token_dict = {word: idx for idx, word in enumerate(sorted_words)}
        self.__reverse_token_dict = {idx: word for word, idx in self.__token_dict.items()}
        self.vocab_size = len(self.__token_dict)

    def reduce_vocab_size(self, new_vocab_size: int):
        """Keep only the ``new_vocab_size`` most frequent words.

        The remaining ids stay 0..new_vocab_size-1. A size at or above the current
        vocabulary size is a no-op.

        Raises:
            ValueError: if ``new_vocab_size`` is negative (slicing with a negative
                bound would otherwise cut words from the wrong end).
        """
        if new_vocab_size < 0:
            raise ValueError(f"new_vocab_size must be non-negative, got {new_vocab_size}")

        # __token_dict is ordered by descending frequency, so everything past the
        # first new_vocab_size entries is the least frequent tail.
        for word in list(self.__token_dict)[new_vocab_size:]:
            del self.__reverse_token_dict[self.__token_dict.pop(word)]

        self.vocab_size = len(self.__token_dict)

    def __get_token(self, word: str) -> int:
        """Return the id of ``word``, or -1 if it is not in the vocabulary.

        For words longer than one character, one trailing and then one leading
        ``,.!?`` character are stripped before the lookup. A lone punctuation mark
        is looked up as-is.
        """
        if len(word) > 1:
            punctuations = ",.!?"
            if word[-1] in punctuations:
                word = word[:-1]
            if word[0] in punctuations:
                word = word[1:]

        return self.__token_dict.get(word, -1)

    def encode(self, text: str) -> list[int]:
        """Encode whitespace-separated words of ``text`` to ids (-1 for unknown words)."""
        return [self.__get_token(word) for word in text.split()]

    def encode_one_hot(self, text: str) -> list[list[int]]:
        """One-hot encode ``text``: one row of length ``vocab_size`` per word.

        Out-of-vocabulary words become all-zero rows. Indexing with their -1 id
        would silently mark the last (least frequent) vocabulary word instead.
        """
        one_hot_tokens = []
        for tok in self.encode(text):
            one_hot = [0] * self.vocab_size
            if tok >= 0:
                one_hot[tok] = 1
            one_hot_tokens.append(one_hot)
        return one_hot_tokens

    def __get_word(self, token: int) -> str:
        """Return the word for ``token``, or ``"N/A"`` if the id is not in the vocabulary."""
        return self.__reverse_token_dict.get(token, "N/A")

    def decode(self, tokens: list[int]) -> str:
        """Decode ids back to a space-joined string."""
        return " ".join([self.__get_word(tok) for tok in tokens])

    def decode_one_hot_tokens(self, one_hot_tokens: list[list[int]]) -> str:
        """Decode rows of one-hot vectors (or scores) by taking each row's argmax.

        The first maximum wins. All-zero or empty rows (how ``encode_one_hot``
        represents unknown words) decode to ``"N/A"``, as do argmax indices outside
        the vocabulary.
        """
        words = []
        for one_hot in one_hot_tokens:
            if not any(one_hot):
                words.append("N/A")
            else:
                words.append(self.__get_word(one_hot.index(max(one_hot))))
        return " ".join(words)

def pad_or_truncate(tokens: list[int], length: int, pad_value: int = 0) -> list[int]:
    """Return a new list of exactly ``length`` items: ``tokens`` right-padded with ``pad_value`` or truncated.

    Note: the default ``pad_value`` of 0 is also the id ``Tokenizer`` gives its most
    frequent word. Pass a value outside the range of real token ids when padding
    must be distinguishable from real tokens.

    Raises:
        ValueError: if ``length`` is negative (a negative slice bound would silently
            drop items from the end instead).
    """
    if length < 0:
        raise ValueError(f"length must be non-negative, got {length}")

    n_missing = length - len(tokens)
    if n_missing > 0:
        return tokens + [pad_value] * n_missing
    return tokens[:length]

In [66]:
VOCAB_SIZE = 64_000  # keep only the 64k most frequent words

tokenizer = Tokenizer(word_freq)
tokenizer.reduce_vocab_size(VOCAB_SIZE)
print(tokenizer.vocab_size)

64000


In [63]:
# Round-trip sanity check: in-vocabulary words should decode back to the same sentence.
txt = "the quick brown fox jumps over the lazy dog"
txt_out = tokenizer.decode(tokens := tokenizer.encode(txt))
print(txt_out)

the quick brown fox jumps over the lazy dog


In [6]:
import torch
import torch.nn as nn

print(torch.__version__)

# Query CUDA once and reuse the answer for both the report and the device choice.
cuda_available = torch.cuda.is_available()
print(cuda_available)

device = torch.device("cuda" if cuda_available else "cpu")
print(device)

2.2.0+cpu
False
cpu


In [20]:
def pad_or_truncate_tensor(tensor: torch.Tensor, length: int) -> torch.Tensor:
    """Zero-pad or truncate ``tensor`` along dim 0 to exactly ``length`` rows.

    Works for tensors of any rank >= 1; trailing dimensions are preserved.
    The zero padding is allocated on ``tensor``'s own device, so CUDA inputs work.
    The padding uses the default float dtype, so ``torch.cat`` type promotion
    turns a padded integer tensor into a float tensor. A truncated tensor keeps
    its original dtype.
    """
    n_missing = length - tensor.size(0)
    if n_missing <= 0:
        return tensor[:length]

    padding = torch.zeros(n_missing, *tensor.shape[1:], device=tensor.device)
    return torch.cat([tensor, padding], dim=0)

In [52]:
class Network(nn.Module):
    """Stack of fully connected layers mapping a context window to a same-sized output.

    Input has ``context_window_size`` features in its last dimension. In this
    notebook those features are token ids cast to float. The output has the same
    last-dimension size.
    """

    def __init__(self, hidden_layer_size: int, n_hidden_layers: int, context_window_size: int, activation=None):
        """
        Args:
            hidden_layer_size: width of every hidden layer.
            n_hidden_layers: number of ``hidden_layer_size x hidden_layer_size`` layers.
            context_window_size: number of input and output features.
            activation: optional ``nn.Module`` applied after the input layer and after
                each hidden layer (not after the output layer). ``None`` means identity,
                which matches the original architecture. Note that without a
                nonlinearity the whole stack composes into a single affine map, so the
                extra layers add no expressive power.
        """
        super().__init__()

        self.context_window_size = context_window_size

        self.input_layer = nn.Linear(context_window_size, hidden_layer_size)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(hidden_layer_size, hidden_layer_size) for _ in range(n_hidden_layers)]
        )
        self.output_layer = nn.Linear(hidden_layer_size, context_window_size)
        # Registered after the linear layers so the parameter order (and thus the
        # optimizer/state_dict layout) matches the original network.
        self.activation = activation if activation is not None else nn.Identity()

    def forward(self, x):
        """Apply input -> hidden layers -> output, with ``self.activation`` between layers."""
        x = self.activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = self.activation(layer(x))
        return self.output_layer(x)

    def random_init(self):
        """Xavier-uniform init for every weight matrix and zero init for every bias."""
        for layer in [self.input_layer, *self.hidden_layers, self.output_layer]:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def generate_text(self, input_text: str, tokenizer: "Tokenizer", n_words: int) -> str:
        """Generate ``n_words`` words continuing ``input_text``.

        Each forward pass yields ``context_window_size`` values. Every value is turned
        into a token id via ``int(abs(value))``, decoded with ``tokenizer``, and shifted
        into the context window. Generation stops as soon as ``n_words`` words exist.

        Returns:
            The generated words joined by single spaces (``""`` when ``n_words <= 0``).
        """
        if n_words <= 0:
            return ""

        # Follow the model's own parameters rather than a notebook-global device.
        model_device = next(self.parameters()).device

        tokens = torch.tensor(tokenizer.encode(input_text), dtype=torch.float32)
        # Pad/truncate before moving, so the padding and the tokens share a device.
        context = pad_or_truncate_tensor(tokens, self.context_window_size).to(model_device)

        # ceil(n_words / context_window_size) forward passes produce enough words.
        n_inferences = -(-n_words // self.context_window_size)

        words = []
        with torch.no_grad():  # inference only: don't build an autograd graph
            for _ in range(n_inferences):
                output = self(context)
                for value in output.cpu().tolist():
                    token_idx = int(abs(value))
                    words.append(tokenizer.decode([token_idx]))
                    # Slide the window: drop the oldest entry and append the new token.
                    next_token = torch.tensor([token_idx], dtype=torch.float32, device=model_device)
                    context = torch.cat([context[1:], next_token], dim=0)
                    if len(words) == n_words:
                        break

        return " ".join(words)


In [40]:
# Load every training record (one JSON object per line) into memory.
with jsonlines.open("./data/train.jsonl", "r") as reader:
    data = [record for record in reader]

In [42]:
first_record = data[0]
print(first_record.keys())

dict_keys(['prompt', 'text_token_length', 'text', 'seed_data', 'format', 'audience'])


In [43]:
# X holds the encoded prompts (inputs); y holds the encoded texts (targets).
X, y = [], []

for record in tqdm(data):
    prompt_ids = tokenizer.encode(record["prompt"])
    text_ids = tokenizer.encode(record["text"])
    X.append(torch.tensor(prompt_ids))
    y.append(torch.tensor(text_ids))

100%|██████████| 100000/100000 [00:54<00:00, 1839.04it/s]


In [46]:
# Longest encoded prompt and longest encoded text, used to size the context window.
max_len_x = max(map(len, X))
max_len_y = max(map(len, y))

print(max_len_x, max_len_y) # 415, 1805

context_window_size = 2048

415 1805


In [47]:
# Bring every sequence to context_window_size, then batch into 2-D tensors on the target device.
X = torch.stack([pad_or_truncate_tensor(seq, context_window_size) for seq in X]).to(device)
y = torch.stack([pad_or_truncate_tensor(seq, context_window_size) for seq in y]).to(device)

In [70]:
hidden_layer_size = 4096
n_hidden_layers = 16

network = Network(
    hidden_layer_size=hidden_layer_size,
    n_hidden_layers=n_hidden_layers,
    context_window_size=context_window_size,
).to(device)
network.random_init()

In [56]:
import matplotlib.pyplot as plt

In [71]:
# Training
#
# NOTE(review): nn.CrossEntropyLoss receives a float target with the same shape
# as the output, so PyTorch treats each y row as class *probabilities* over the
# context_window_size output positions. Here y holds raw token ids (plus -1 for
# out-of-vocabulary words and 0 padding), which is not a probability distribution.
# Negative targets even make the loss unbounded below, which matches the
# exploding raw outputs seen when the network is evaluated. A real fix needs a
# different model head (per-position vocabulary logits with integer class
# targets) — TODO.

lr = 0.001
n_epochs = 10
batch_size = 32
shuffle_seed = 0

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(network.parameters(), lr=lr)
# Seeded generator: the batch order differs per epoch but is reproducible across runs.
shuffle_gen = torch.Generator().manual_seed(shuffle_seed)

loss_data = []

for epoch in range(n_epochs):
    # Visit the samples in a fresh random order every epoch instead of a fixed one.
    perm = torch.randperm(len(X), generator=shuffle_gen).to(X.device)
    for i in tqdm(range(0, len(X), batch_size), desc=f"epoch {epoch}"):
        batch_idx = perm[i:i + batch_size]
        X_batch = X[batch_idx]
        y_batch = y[batch_idx]

        optimizer.zero_grad()
        output = network(X_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optimizer.step()

        loss_data.append(loss.item())

    # Overwrite one checkpoint per epoch so an interrupted run (epochs are long)
    # keeps the latest weights.
    torch.save(network.state_dict(), "./data/network_latest.pt")

    # Plot the cumulative per-batch loss and save it to ./data/loss_{epoch}.png
    fig, ax = plt.subplots()
    ax.plot(loss_data)
    ax.set(title=f"Training loss through epoch {epoch}", xlabel="batch", ylabel="loss")
    fig.savefig(f"./data/loss_{epoch}.png")
    plt.close(fig)


100%|██████████| 3125/3125 [1:38:58<00:00,  1.90s/it]
100%|██████████| 3125/3125 [1:38:30<00:00,  1.89s/it]
100%|██████████| 3125/3125 [1:41:45<00:00,  1.95s/it]
100%|██████████| 3125/3125 [1:41:59<00:00,  1.96s/it]
100%|██████████| 3125/3125 [1:42:43<00:00,  1.97s/it]
 14%|█▍        | 438/3125 [14:23<1:41:05,  2.26s/it]

In [69]:
# Inspect the raw network output for the first prompt. No backward pass follows,
# so run without autograd rather than building (and holding) a graph.
with torch.no_grad():
    output_toks = network(X[0].unsqueeze(0))
print(output_toks)

tensor([[75224272.0000, 68507648.0000, 75482024.0000,  ...,
          4921658.5000,  4691474.0000,  4265536.0000]],
       grad_fn=<AddmmBackward0>)
