In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import yaml
import torch
from utils.tokenizer import load_tokenizer
from train import model_from_config

device = "cuda" if torch.cuda.is_available() else "cpu"

tok = load_tokenizer("tokenizer_bytes")

with open("../data/small_train.txt", "r") as f:
    f.readline() # the first line is just "text"
    data = f.read()
    
with open("../configs/llama_medium.yaml", "r") as f:
    config = yaml.safe_load(f)
    
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = model_from_config(config).to(DEVICE)
    

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import numpy as np
def formatted_ngram_probs(token_probs):
    '''
    probs: list(token, list(rule, prob)) 
    returns  {rule: [(token, prob)]}
    
    So the tokens are sorted by their probability given a rule, with NaN values at the bottom
    '''
    
    rule_to_probs = {}
    
    for token, rule_probs in token_probs:
        for rule, prob in rule_probs:
            if rule not in rule_to_probs:
                rule_to_probs[rule] = []
            # Replace NaN with 0
            prob = 0 if np.isnan(prob) else prob
            rule_to_probs[rule].append((token, prob))
    
    # Sort tokens by probability within each rule
    for rule in rule_to_probs:
        rule_to_probs[rule].sort(key=lambda x: x[1], reverse=True)
    
    return rule_to_probs

    

In [4]:
def apply_rule(context, rule_context):
    '''
    context: list of tokens
    rule: string - "++"
    '''
    for rule in rule_context:
        if rule == "+":
            continue
        elif rule == "-":
            context = context[:-1]
        
    rule_context = rule_context[1:]
    
    return context


In [5]:
def model_logits(input_ids, model, rule):
    
    ## TODO: add marginalization
    
    context = apply_rule(input_ids, rule)
    
    input_ids_model = torch.tensor(context).unsqueeze(0).to(device)
    
    with torch.no_grad():
        outputs = model(input_ids_model)
        next_token_logits = outputs.logits[:, -1, :]
    
    return next_token_logits
        


In [6]:
class NgramStats:
    def __init__(self, tokenized_data):
        self.tokenized_data = tokenized_data
        
        from ngram_trie import PySmoothedTrie

        self.trie = PySmoothedTrie(n_gram_max_length=7, root_capacity=None)

        self.trie.fit(self.tokenized_data, n_gram_max_length=7, root_capacity=None, max_tokens=None)

    def top1_accuracy(self, text, rules, model):
        '''return {rule: {n_gram: prob, transformer: prob}}'''
        
        self.trie.set_rule_set(rules)

        self.trie.fit_smoothing()
        
        bos_token = tok.token_to_id("<bos>")
        input_ids = [bos_token] + tok.encode(text, add_special_tokens=False).ids
        
        n_gram_probs = self.trie.get_prediction_probabilities(input_ids)
        
        rule_to_ngram_probs = formatted_ngram_probs(n_gram_probs)
        
        top1_match = 0
        for rule in rules:
            import torch.nn.functional as F
            logits = model_logits(input_ids, model, rule)
            p = F.softmax(logits, dim=-1)
            top_probs, top_indices = torch.topk(p, k=1)
                        
            # if predicting the same token as the ngram model, increment top1_match
            top1_match += rule_to_ngram_probs[rule][0][0] == top_indices[0].item()
            
        return top1_match / len(rules)
            
            

In [7]:
from utils.dataset import MemmapDataset
from utils.tokenizer import load_tokenizer
from torch.utils.data import random_split, DataLoader

tok = load_tokenizer("tokenizer_bytes")
tok.pad_token = tok.token_to_id("<pad>")
full_ds = MemmapDataset(dataset_file="small_train", tokenizer_name='tokenizer_bytes', num_tokens=2048 - 1)

train_size = int(0.8 * len(full_ds))
val_size = int(0.1 * len(full_ds))
test_size = len(full_ds) - train_size - val_size
train_ds, val_ds, _ = random_split(full_ds, [train_size, val_size, test_size])

tokenized_data = []

for batch in train_ds:
    tokenized_data.extend(batch.tolist())



In [8]:
ns = NgramStats(tokenized_data)

----- Calculating d values -----
Number of nodes: 0
----- Trie fitting -----
Expected time for 180136 tokens: 0.00636129632 min
Expected ram usage for 180136 tokens: 63.477342066026786 MB
Time taken to fit trie: 119.170875ms
----- Shrinking to fit -----
Time taken to shrink to fit: 83.411ms
----- Calculating size in RAM -----
Time taken to calculate size in RAM: 18.663667ms
Size in RAM: 34.89246368408203 MB


[?25l[1G  0%|                                                | 0/180130 [00:00<?, ?it/s][?25h[?25l[1G  0%|                                                | 1/180130 [00:00<?, ?it/s][?25h[?25l[1G 45%|██████████████▏                | 82278/180130 [00:00<00:00, 1974632.20it/s][?25h[?25l[1G 82%|████████████████████████▋     | 147790/180130 [00:00<00:00, 1853925.17it/s][?25h[1G100%|██████████████████████████████| 180130/180130 [00:00<00:00, 1708310.01it/s]


In [9]:
from itertools import product

#symbols = ['-', '+', '*'] # TODO: add support for marginalization
symbols = ['-', '+']
rules = []

for length in range(1, 7):
    rules.extend([''.join(p) for p in product(symbols, repeat=length)])

In [10]:
prompt = "time. in a big"
top1 = ns.top1_accuracy(prompt, rules, model)
print(top1)

----- Calculating d values -----
Number of nodes: 762236
Time taken: 133.687834ms
Smoothing calculated, d1: 0.8988142808639557, d2: 1.136216810277788, d3: 1.2239528302386875, uniform: 0.00020491803278688525
----- Getting prediction probabilities -----
Time taken to get prediction probabilities: 153.829542ms


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)