In [25]:
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
import yaml
import torch
from utils.tokenizer import load_tokenizer
from train import model_from_config

device = "cuda" if torch.cuda.is_available() else "cpu"

tok = load_tokenizer("tokenizer_bytes")

with open("../data/small_train.txt", "r") as f:
    f.readline() # the first line is just "text"
    data = f.read()
    
with open("../configs/llama_medium.yaml", "r") as f:
    config = yaml.safe_load(f)
    


In [None]:
from typing import Iterator
from utils.sample import sample_with_temp, nucleus_sample
import torch

def generate_text(model, tokenizer, prompt, max_length=50, temperature=0.7)->Iterator[str]:
    model.eval()
    device = next(model.parameters()).device
    
    bos_token = tokenizer.token_to_id("<bos>")
    input_ids = torch.tensor([bos_token] + tokenizer.encode(prompt, add_special_tokens=False).ids).unsqueeze(0).to(device)
    
    generated = []
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            # next_token = nucleus_sample(next_token_logits, 0.5)
            next_token = sample_with_temp(next_token_logits, temperature)
        
        generated.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
        
        # Check for end-of-sequence token
        if next_token.item() == tokenizer.token_to_id("<eos>"):
            break
    
        
        yield tokenizer.decode(next_token.tolist())

# Example usage
prompt = ""
print(prompt, end="")
for token in generate_text(model, tok, prompt, max_length=1024, temperature=0.5):
    print(token, end="")

In [27]:
with open("../data/small_train.txt", "r") as f:
    f.readline() # the first line is just "text"
    data = f.read()
    
# first 10000 words
data = " ".join(data.split()[:10000])
    
data = data.lower()
data = data.replace(".", "")
data = data.replace(",", "")
data = data.replace("!", "")
data = data.replace("?", "")
data = data.replace("'", "")
data = data.replace("\"", "")

word_to_ids = {word: i for i, word in enumerate(set(data.split()))}
ids_to_word = {i: word for word, i in word_to_ids.items()}

In [28]:
tokenized_data = [word_to_ids[word] for word in data.split()]

In [29]:
history = [word_to_ids[word] for word in "how are you doing today man".split()] # tokens of the history, n-1 tokens


In [30]:
history

[548, 707, 997, 431, 404, 6]

In [31]:
from ngram_trie import PySmoothedTrie

trie = PySmoothedTrie(n_gram_max_length=7, root_capacity=None)

trie.fit(tokenized_data, n_gram_max_length=7, root_capacity=None, max_tokens=None)

trie.set_rule_set(["+++++++", "++++++", "+++++", "++++", "+++", "++", "+"])

trie.fit_smoothing()

trie.get_prediction_probabilities(history)

----- Calculating d values -----
Number of nodes: 0
----- Trie fitting -----
Expected time for 10000 tokens: -0.0015988 min
Expected ram usage for 10000 tokens: 5.397831143814615 MB
Time taken to fit trie: 60.363792ms
----- Shrinking to fit -----
Time taken to shrink to fit: 4.781333ms
----- Calculating size in RAM -----
Time taken to calculate size in RAM: 421.959µs
Size in RAM: 2.4482803344726563 MB
----- Calculating d values -----
Number of nodes: 53482
Time taken: 11.200916ms
Smoothing calculated, d1: 0.9308956725930896, d2: 1.1926212289483884, d3: 0.7352836621989014, uniform: 0.0008688097306689834
----- Getting prediction probabilities -----
Time taken to get prediction probabilities: 1.53675ms


[?25l[1G  0%|                                                  | 0/9994 [00:00<?, ?it/s][?25h[?25l[1G  0%|                                                  | 1/9994 [00:00<?, ?it/s][?25h[1G100%|███████████████████████████████████| 9994/9994 [00:00<00:00, 492557.18it/s]


[(0,
  [('++++++', nan),
   ('+++++', nan),
   ('++++', nan),
   ('+++', nan),
   ('++', nan),
   ('+', 0.000571997711674877)]),
 (1,
  [('++++++', nan),
   ('+++++', nan),
   ('++++', nan),
   ('+++', nan),
   ('++', nan),
   ('+', 0.0003857068026108205)]),
 (2,
  [('++++++', nan),
   ('+++++', nan),
   ('++++', nan),
   ('+++', nan),
   ('++', nan),
   ('+', 0.010352270437537834)]),
 (3,
  [('++++++', nan),
   ('+++++', nan),
   ('++++', nan),
   ('+++', nan),
   ('++', nan),
   ('+', 0.00015681697801731148)]),
 (4,
  [('++++++', nan),
   ('+++++', nan),
   ('++++', nan),
   ('+++', nan),
   ('++', nan),
   ('+', 8.805006939464557e-05)]),
 (5,
  [('++++++', nan),
   ('+++++', nan),
   ('++++', nan),
   ('+++', nan),
   ('++', nan),
   ('+', 0.0009445795298029898)]),
 (6,
  [('++++++', nan),
   ('+++++', nan),
   ('++++', nan),
   ('+++', nan),
   ('++', nan),
   ('+', 0.0003857068026108205)]),
 (7,
  [('++++++', nan),
   ('+++++', nan),
   ('++++', nan),
   ('+++', nan),
   ('++', na