In [1]:
!pip install transformers accelerate optimum nvidia-ml-py



In [2]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json

torch.random.manual_seed(0)

<torch._C.Generator at 0x7f8ab46121f0>

In [3]:
from pynvml import *

def check_gpu(step):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"{step}: GPU memory used: {info.used // 1024**2} MB.")

In [16]:
def D(obj, label=None, c=True):
    print()
    if label:
        print(label)
        
    if isinstance(obj, tuple):
        print(len(obj))
    elif isinstance(obj, torch.Tensor) or isinstance(obj, np.ndarray):
        print(obj.shape)
        if c: # Contents
            display(obj)
    else:
        if c: # Contents
            display(obj)
            
def DS(obj, label=None):
    D(obj, label, c=False)

In [5]:
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    torch_dtype=torch.bfloat16,
    device_map='auto',
    trust_remote_code=True,
    use_cache=True,
    # attn_implementation='flash_attention_2',
)
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct")
device = "cuda" if torch.cuda.is_available() else "cpu"
print('device', device)

check_gpu('model init')

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


device cuda
model init: GPU memory used: 7813 MB.


In [6]:
max_candidates = 16
max_new_tokens = 3
batch_size = 8
p_falloff = 0.5 # UNIMPLEMENTED
prune_similar_sequences = True # UNIMPLEMENTED
prune_similar_branches = True # UNIMPLEMENTED
prune_similar_embeddings = True # UNIMPLEMENTED

In [7]:
def init_candidates(text: str):
    prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    inputs = tokenizer(prompt, return_tensors='pt')

    max_total_tokens = inputs.input_ids.shape[1] + max_new_tokens

    # (max_candidates, max_total_tokens)
    candidates = torch.zeros((max_candidates, max_total_tokens), dtype=torch.long, device=device)
    # (max_candidates, max_total_tokens)
    candidate_masks = torch.zeros((max_candidates, max_total_tokens), dtype=torch.bool, device=device)
    # (max_candidates)
    candidate_parents = torch.zeros((max_candidates), dtype=torch.long, device=device)
    # (max_candidates)
    candidate_logprobs = torch.zeros((max_candidates), dtype=torch.float32, device=device)

    candidates[0, :inputs.input_ids.shape[1]] = inputs.input_ids
    candidate_masks[0, :inputs.input_ids.shape[1]] = inputs.attention_mask
    candidate_parents[0] = 0
    candidate_logprobs[0] = 0.0

    return candidates, candidate_masks, candidate_parents, candidate_logprobs

candidates, candidate_masks, candidate_parents, candidate_logprobs = init_candidates('What is the most popular breed of dog?')
D(candidates)
D(candidate_masks)
D(candidate_parents)
D(candidate_logprobs)

check_gpu('candidates init')

torch.Size([16, 18])


tensor([[    1, 32010,  1724,   338,   278,  1556,  5972,  2078,   287,   310,
         11203, 29973, 29871, 32007, 32001,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0, 

torch.Size([16, 18])


tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, 

torch.Size([16])


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

torch.Size([16])


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')

candidates init: GPU memory used: 7843 MB.


In [8]:
# For testing batch inputs
inputs2 = tokenizer("<|user|>\n{} <|end|>\n<|assistant|>".format('What is the most popular breed of cat?'), return_tensors='pt')
candidates[11, :inputs2.input_ids.shape[1]] = inputs2.input_ids
candidate_masks[11, :inputs2.input_ids.shape[1]] = inputs2.attention_mask
candidate_parents[11] = 0
candidate_logprobs[11] = -1.3

check_gpu('test addl inputs added')

test addl inputs added: GPU memory used: 7843 MB.


In [9]:
test_be = tokenizer(["String A", "String B which is longer", "String C which is even longer"], return_tensors="pt", padding=True)
test_be

{'input_ids': tensor([[32000, 32000, 32000, 32000,     1,  1714,   319],
        [32000,     1,  1714,   350,   607,   338,  5520],
        [    1,  1714,   315,   607,   338,  1584,  5520]]), 'attention_mask': tensor([[0, 0, 0, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1]])}

In [10]:
tokenizer.batch_decode(test_be.input_ids)

['<|endoftext|><|endoftext|><|endoftext|><|endoftext|><s> String A',
 '<|endoftext|><s> String B which is longer',
 '<s> String C which is even longer']

In [29]:
def infer(candidates, candidate_masks, candidate_parents, candidate_logprobs):
    with torch.inference_mode():
        batches = (max_candidates + batch_size - 1) // batch_size  # Round up to nearest whole number of batches

        check_gpu('infer start')
        for i in range(0, batches, 1):
            batch_candidates = candidates[i * batch_size:(i + 1) * batch_size]
            D(batch_candidates)
            batch_candidate_masks = candidate_masks[i * batch_size:(i + 1) * batch_size]
            D(batch_candidate_masks)

            check_gpu('batch views made')

            batch_outputs = model(input_ids=batch_candidates, attention_mask=batch_candidate_masks)
            D(batch_outputs.logits)

            # Possibly turn off caching to save memory here?
            check_gpu('batch forward run')

            # del batch_outputs

            check_gpu('batch outputs deleted')
            break
            
        return batch_outputs

logits = infer(candidates, candidate_masks, candidate_parents, candidate_logprobs).logits
D(logits)

check_gpu('all batches run')

infer start: GPU memory used: 8165 MB.
torch.Size([8, 18])


tensor([[    1, 32010,  1724,   338,   278,  1556,  5972,  2078,   287,   310,
         11203, 29973, 29871, 32007, 32001,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0, 


torch.Size([8, 18])


tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, 


batch views made: GPU memory used: 8165 MB.
torch.Size([8, 18, 32064])


tensor([[[ 1.8125,  1.3438, -0.4473,  ...,  0.0000,  0.0000,  0.0000],
         [ 4.2812,  9.6875, 10.1250,  ...,  0.0000,  0.0000,  0.0000],
         [ 6.0625,  2.9844,  3.8281,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 1.5859, -1.2031, -3.5625,  ...,  0.0000,  0.0000,  0.0000],
         [ 1.0703,  0.7109, -5.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 3.5938, -3.8750, -3.3281,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0


batch forward run: GPU memory used: 8165 MB.
batch outputs deleted: GPU memory used: 8165 MB.
torch.Size([8, 18, 32064])


tensor([[[ 1.8125,  1.3438, -0.4473,  ...,  0.0000,  0.0000,  0.0000],
         [ 4.2812,  9.6875, 10.1250,  ...,  0.0000,  0.0000,  0.0000],
         [ 6.0625,  2.9844,  3.8281,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 1.5859, -1.2031, -3.5625,  ...,  0.0000,  0.0000,  0.0000],
         [ 1.0703,  0.7109, -5.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 3.5938, -3.8750, -3.3281,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0


all batches run: GPU memory used: 8165 MB.


In [20]:
# Actually, no attention mask is needed -- all candidates will always be the same number of tokens (having started from the same
# base and with the same number of generations), so all we have to do is feed a view of the candidates tensor with just valid tokens
# into the model). Separately keep track of length of candidate sequences.

# def top_p_tokens(logits, top_p=0.9):
#     """logits of shape (batch_size, curr_seq_len, vocab_size)"""
#     with torch.inference_mode():
last_tok_logits = logits[:, -1, :]

sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
D(sorted_probs)
cum_probs = torch.cumsum(sorted_probs, dim=-1)
D(cum_probs)

torch.Size([8, 32064])


tensor([[9.9975e-01, 9.6088e-05, 8.4797e-05,  ..., 9.9887e-20, 5.6914e-20,
         3.2429e-20],
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09],
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09],
        ...,
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09],
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09],
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09]], device='cuda:0')

torch.Size([8, 32064])


tensor([[0.9998, 0.9998, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000],
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000],
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000],
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')

In [22]:
# Create tensor of bools indicating which indices are cumulatively less than top_p
keep_indices = cum_probs < 0.9

# Keep the last element that went over top_p
keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
keep_indices[:, 0] = 1  # Always keep the first element

D(keep_indices)

torch.Size([8, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')

In [25]:
keep_toks = sorted_indices[keep_indices]
keep_probs = sorted_probs[keep_indices]

D(keep_toks)
D(keep_probs)

# top_p_tokens(logits)

torch.Size([90805])


tensor([29871, 24278, 26785,  ..., 15108, 15268, 16121], device='cuda:0')

torch.Size([90805])


tensor([9.9975e-01, 3.3339e-02, 1.7845e-02,  ..., 1.4587e-05, 1.4587e-05,
        1.4587e-05], device='cuda:0')

In [56]:
D(candidates.index_select(0, keep_indices.nonzero()[:, 0])) # COMPONENT A

torch.Size([90805, 18])


tensor([[    1, 32010,  1724,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        ...,
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0]], device='cuda:0')




In [60]:
D(sorted_indices[keep_indices]) # I think this is COMPONENT B

torch.Size([90805])


tensor([29871, 24278, 26785,  ..., 15108, 15268, 16121], device='cuda:0')




In [61]:
D(sorted_probs[keep_indices]) # I think this is COMPONENT C, still needs to be ln

torch.Size([90805])


tensor([9.9975e-01, 3.3339e-02, 1.7845e-02,  ..., 1.4587e-05, 1.4587e-05,
        1.4587e-05], device='cuda:0')




In [27]:
from sklearn.cluster import KMeans

from pynvml import *

def check_gpu(step):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"{step}: GPU memory used: {info.used // 1024**2} MB.")
    
def D(obj, label=None, c=True):
    print()
    if label:
        print(label)
        
    if isinstance(obj, tuple):
        print(len(obj))
    elif isinstance(obj, torch.Tensor) or isinstance(obj, np.ndarray):
        print(obj.shape)
        if c: # Contents
            display(obj)
    else:
        if c: # Contents
            display(obj)
            
def DS(obj, label=None):
    D(obj, label, c=False)
    
    
    

class InferenceTensor:
    def __init__(self):
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            torch_dtype=torch.bfloat16,
            device_map='auto',
            trust_remote_code=True,
            use_cache=True,
            # attn_implementation='flash_attention_2',
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.max_candidates = 20
        self.max_new_tokens = 100
        self.batch_size = 8
        self.p_falloff = 0.5 # UNIMPLEMENTED
        self.prune_similar_sequences = True # UNIMPLEMENTED
        self.prune_similar_branches = True # UNIMPLEMENTED
        self.prune_similar_embeddings = True # UNIMPLEMENTED
        
    def candidates_generator(self, top_p: float, max_beams: int, prompt: str):
        print(prompt)
        candidates, candidate_logprobs = self._init_candidates(prompt)
        for level_idx in range(self.max_new_tokens):
            logits, embeddings = self._infer(candidates[:self.max_candidates, ...], candidate_logprobs[:self.max_candidates, ...])
        
            if candidates.shape[0] > max_beams:
                candidates, candidate_parents, candidate_logprobs = self._k_means(output_embeddings, candidates, candidate_logprobs, max_beams)
                self._send_candidates('k_means', f"{level_idx}-k", candidates, candidate_parents, candidate_logprobs)

            candidates, candidate_parents, candidate_logprobs = self._top_p(output_logits, candidates, candidate_logprobs, top_p)
            self._send_candidates('top_p', f"{level_idx}-p", candidates, candidate_parents, candidate_logprobs)
            
            

        yield f"event: level\nid: END\ndata: []\n\n"

    def _send_candidates(self, event: str, idx: int, candidates, candidate_parents, candidate_logprobs):
        candidate_texts = self.tokenizer.batch_decode(candidates[:, -1])
        candidate_probs = candidate_logprobs.exp()
        candidate_dicts = []
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parents': list(candidate_parents[i]), 'prob': candidate_probs[i].item()})
        data = json.dumps(candidate_dicts)
        yield f"event: {event}\nid: {idx}\ndata: {data}\n\n"
        
    def _init_candidates(self, text: str):
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
        inputs = tokenizer(prompt, return_tensors='pt')
        D(inputs.input_ids, 'input_ids')
        print(tokenizer.batch_decode(inputs.input_ids))

        candidates = inputs.input_ids.to(device)
        candidate_logprobs = torch.zeros((1), dtype=torch.float32, device=device)

        return candidates, candidate_logprobs

    def _k_means(self, candidates, candidate_logprobs, max_beams):
        # === CPU ===
        embeddings_np = output_embeddings.float().numpy(force=True)
        D(embeddings_np, 'embeddings_np')
        k_means = KMeans(n_clusters=min(2, embeddings_np.shape[0]), random_state=0, n_init="auto")
        k_mean_space = k_means.fit_transform(embeddings_np)
        D(k_mean_space, 'k_mean_space')
        k_mean_clusters = k_means.predict(embeddings_np)
        D(k_mean_clusters, 'k_mean_clusters')
        k_mean_logprob_mass = np.bincount(k_mean_clusters, weights=candidate_logprobs.cpu())
        D(k_mean_logprob_mass, 'k_mean_logprob_mass')
        closest = np.argmin(k_mean_space, axis=0)
        D(closest, 'closest')
        # === END CPU ===
        
        new_candidates = candidates.index_select(0, closest)
        D(new_candidates, 'new_candidates')
        new_candidate_logprobs = k_mean_logprob_mass
        D(new_candidate_logprobs, 'new_candidate_logprobs')
        
        return new_candidates, [], new_candidate_logprobs
        
    
    def _top_p(self, logits, candidates, candidate_logprobs, top_p):
        last_tok_logits = logits[:, -1, :]
        D(last_tok_logits, 'last_tok_logits')

        sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
        DS(sorted_logits, 'sorted_logits')
        DS(sorted_indices, 'sorted_indices')
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        D(sorted_probs, 'sorted_probs')
        display(sorted_probs.sum(dim=1))
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        D(cum_probs, 'cum_probs')

        # Create tensor of bools indicating which indices are cumulatively less than top_p
        keep_indices = cum_probs < top_p

        # Keep the last element that went over top_p
        keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
        keep_indices[:, 0] = 1  # Always keep the first element
        D(keep_indices, 'keep_indices')

        new_candidate_parents = keep_indices.nonzero()[:, 0]
        D(new_candidate_parents, 'new_candidate_parents')

        # OPTIM: Potential optimization -- have a fixed tensor of size (max_candidates, max_tokens) and copy this into that (batch-aware).
        # OPTIM: consider which of these operations can be done in-place to prevent new allocations?
        carryover_candidates = candidates.index_select(0, new_candidate_parents)
        D(carryover_candidates, 'carryover_candidates')

        # Similar code could be used to trace entire origin of sequence. For now since server just traces parent of the preceding generation, not needed
        # carryover_candidate_parents = candidate_parents.index_select(0, carryover_candidate_indices)  # Not strictly necessary since 1d
        # D(carryover_candidate_parents, 'carryover_candidate_parents')

        carryover_candidate_logprobs = candidate_logprobs.index_select(0, new_candidate_parents)  # Not strictly necessary since 1d
        D(carryover_candidate_logprobs, 'carryover_candidate_logprobs')

        new_candidate_toks = sorted_indices[keep_indices].unsqueeze(1)
        D(new_candidate_toks, 'new_candidate_toks')
        new_candidate_tok_logprobs = sorted_probs[keep_indices].log()
        D(new_candidate_tok_logprobs, 'new_candidate_tok_logprobs')

        new_candidates = torch.cat([carryover_candidates, new_candidate_toks], dim=1)
        D(new_candidates, 'new_candidates')
        new_candidate_logprobs = carryover_candidate_logprobs.add_(new_candidate_tok_logprobs)
        D(new_candidate_logprobs, 'new_candidate_logprobs')

        return new_candidates, new_candidate_parents, new_candidate_logprobs


    def _infer(self, candidates, candidate_logprobs):
        with torch.inference_mode():
            num_batches = (candidates.shape[0] + batch_size - 1) // batch_size  # Round up to nearest whole number of batches
            print('\nnum_batches', num_batches)

            new_candidates_list = []
            new_candidate_parents_list = []
            new_candidate_logprobs_list = []

            check_gpu('infer start')
            output_logits_list = []
            output_embeddings_list = []
            for i in range(0, num_batches, 1):
                batch_candidates = candidates[i * batch_size:(i + 1) * batch_size]
                DS(batch_candidates, 'batch_candidates')
                batch_candidate_logprobs = candidate_logprobs[i * batch_size:(i + 1) * batch_size]
                DS(batch_candidate_logprobs, 'batch_candidate_logprobs')

                batch_outputs = model(input_ids=batch_candidates, output_hidden_states=True)
                DS(batch_outputs.logits, 'batch_logits')
                DS(batch_outputs.hidden_states[-1], 'hidden_states[-1]')

                output_logits_list.append(batch_outputs.logits)
                output_embeddings_list.append(batch_outputs.hidden_states[-1][:,-1,:])
                check_gpu('infer - after batch run')

            output_logits = torch.cat(output_logits_list, dim=0)
            output_embeddings = torch.cat(output_embeddings_list, dim=0)
            
            return output_logits, output_embeddings

            

it = InferenceTensor()

for x in it.candidates_generator(0.9, 5, 'What is the highest mountain?'):
    print(x)
    print()
    print('====================================')
    print()



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


What is the highest mountain?

input_ids
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]])

['<s><|user|> What is the highest mountain? <|end|><|assistant|>']

num_batches 1
infer start: GPU memory used: 15393 MB.

batch_candidates
torch.Size([1, 11])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 11, 32064])

hidden_states[-1]
torch.Size([1, 11, 3072])
infer - after batch run: GPU memory used: 15395 MB.

embeddings_np
(1, 3072)


array([[-0.73828125,  1.390625  ,  1.9765625 , ...,  2.109375  ,
        -0.765625  , -0.19335938]], dtype=float32)


k_mean_space
(1, 1)


array([[0.]], dtype=float32)


k_mean_clusters
(1,)


array([0], dtype=int32)


k_mean_logprob_mass
(1,)


array([0.])


closest
(1,)


array([0])


last_tok_logits
torch.Size([1, 32064])


tensor([[ 5.7812, -1.4688, -3.2969,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.2405e-01, 7.5850e-02, 8.8812e-05,  ..., 1.1622e-21, 7.0490e-22,
         3.7730e-22]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9240, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[450]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


new_candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')

infer end: GPU memory used: 15395 MB.
event: level
id: 0
data: [{"content": "The", "parent": 0, "prob": -0.0789911225438118}]





num_batches 1
infer start: GPU memory used: 15395 MB.

batch_candidates
torch.Size([1, 12])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 12, 32064])

hidden_states[-1]
torch.Size([1, 12, 3072])
infer - after batch run: GPU memory used: 15395 MB.

embeddings_np
(1, 3072)


array([[-0.5546875 ,  0.81640625,  1.546875  , ...,  0.96484375,
        -1.71875   ,  0.1953125 ]], dtype=float32)


k_mean_space
(1, 1)


array([[0.]], dtype=float32)


k_mean_clusters
(1,)


array([0], dtype=int32)


k_mean_logprob_mass
(1,)


array([-0.07899112])


closest
(1,)


array([0])


last_tok_logits
torch.Size([1, 32064])


tensor([[-1.2969,  1.0312, -5.0938,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9909e-01, 9.1105e-04, 1.2088e-06,  ..., 6.5938e-24, 3.9993e-24,
         1.4713e-24]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9991, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[9939]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0009], device='cuda:0')


new_candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')

infer end: GPU memory used: 15395 MB.
event: level
id: 1
data: [{"content": "highest", "parent": 0, "prob": -0.07990551739931107}]





num_batches 1
infer start: GPU memory used: 15395 MB.

batch_candidates
torch.Size([1, 13])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 13, 32064])

hidden_states[-1]
torch.Size([1, 13, 3072])
infer - after batch run: GPU memory used: 15397 MB.

embeddings_np
(1, 3072)


array([[-0.640625  , -0.61328125,  3.0625    , ..., -0.5703125 ,
        -1.875     ,  0.80859375]], dtype=float32)


k_mean_space
(1, 1)


array([[0.]], dtype=float32)


k_mean_clusters
(1,)


array([0], dtype=int32)


k_mean_logprob_mass
(1,)


array([-0.07990552])


closest
(1,)


array([0])


last_tok_logits
torch.Size([1, 32064])


tensor([[ 2.8906,  3.1875, -7.0938,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9984e-01, 1.2339e-04, 3.5352e-05,  ..., 5.1391e-24, 4.5352e-24,
         3.5320e-24]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9998, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[14378]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0002], device='cuda:0')


new_candidates
torch.Size([1, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-0.0801], device='cuda:0')

infer end: GPU memory used: 15397 MB.
event: level
id: 2
data: [{"content": "mountain", "parent": 0, "prob": -0.08006944507360458}]





num_batches 1
infer start: GPU memory used: 15397 MB.

batch_candidates
torch.Size([1, 14])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 14, 32064])

hidden_states[-1]
torch.Size([1, 14, 3072])
infer - after batch run: GPU memory used: 15397 MB.

embeddings_np
(1, 3072)


array([[-2.109375  , -1.609375  ,  2.3125    , ..., -2.21875   ,
        -1.5078125 , -0.14453125]], dtype=float32)


k_mean_space
(1, 1)


array([[0.]], dtype=float32)


k_mean_clusters
(1,)


array([0], dtype=int32)


k_mean_logprob_mass
(1,)


array([-0.08006945])


closest
(1,)


array([0])


last_tok_logits
torch.Size([1, 32064])


tensor([[ 4.8125,  1.2734, -5.8750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[8.1600e-01, 9.7458e-02, 8.6006e-02,  ..., 5.2119e-21, 4.0591e-21,
         1.6921e-21]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.8160, 0.9135, 0.9995,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-0.0801, -0.0801], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[373],
        [297]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.2033, -2.3283], device='cuda:0')


new_candidates
torch.Size([2, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-0.2834, -2.4084], device='cuda:0')

infer end: GPU memory used: 15397 MB.
event: level
id: 3
data: [{"content": "on", "parent": 0, "prob": -0.2834072411060333}, {"content": "in", "parent": 0, "prob": -2.40840744972229}]





num_batches 1
infer start: GPU memory used: 15397 MB.

batch_candidates
torch.Size([2, 15])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 15, 32064])

hidden_states[-1]
torch.Size([2, 15, 3072])
infer - after batch run: GPU memory used: 15415 MB.

embeddings_np
(2, 3072)


array([[-1.890625  , -1.40625   ,  3.4375    , ..., -0.01074219,
        -1.453125  , -2.46875   ],
       [-1.265625  , -1.328125  ,  3.375     , ..., -0.02392578,
        -1.3515625 , -2.40625   ]], dtype=float32)


k_mean_space
(2, 2)


array([[85.34896,  0.     ],
       [ 0.     , 85.34896]], dtype=float32)


k_mean_clusters
(2,)


array([1, 0], dtype=int32)


k_mean_logprob_mass
(2,)


array([-2.40840745, -0.28340724])


closest
(2,)


array([1, 0])


last_tok_logits
torch.Size([2, 32064])


tensor([[ 1.6797,  0.8750, -4.7812,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.3438, -5.4688, -4.6875,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.9996e-01, 3.5356e-05, 1.7603e-06,  ..., 3.5325e-24, 2.4278e-24,
         3.7232e-25],
        [8.1683e-01, 1.8226e-01, 5.1193e-04,  ..., 3.4021e-20, 3.4021e-20,
         7.5910e-21]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8168, 0.9991, 0.9996,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 1], device='cuda:0')


carryover_candidates
torch.Size([3, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.4084, -2.4084], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[11563],
        [ 4958],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-3.7432e-05, -2.0232e-01, -1.7023e+00], device='cuda:0')


new_candidates
torch.Size([3, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.6107, -4.1107], device='cuda:0')

infer end: GPU memory used: 15415 MB.
event: level
id: 4
data: [{"content": "Earth", "parent": 0, "prob": -0.28344467282295227}, {"content": "terms", "parent": 1, "prob": -2.6107285022735596}, {"content": "the", "parent": 1, "prob": -4.1107282638549805}]





num_batches 1
infer start: GPU memory used: 15415 MB.

batch_candidates
torch.Size([3, 16])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 16, 32064])

hidden_states[-1]
torch.Size([3, 16, 3072])
infer - after batch run: GPU memory used: 15433 MB.

embeddings_np
(3, 3072)


array([[-1.4765625 , -0.0112915 ,  1.4921875 , ..., -0.94921875,
        -0.27734375, -1.515625  ],
       [ 0.640625  , -1.109375  ,  2.71875   , ..., -0.9140625 ,
         2.5625    , -2.1875    ],
       [ 0.09326172, -0.6796875 ,  3.390625  , ..., -0.69921875,
        -3.28125   , -1.0859375 ]], dtype=float32)


k_mean_space
(3, 2)


array([[105.234764,  50.63994 ],
       [  0.      ,  93.1218  ],
       [106.760445,  50.63994 ]], dtype=float32)


k_mean_clusters
(3,)


array([1, 0, 1], dtype=int32)


k_mean_logprob_mass
(2,)


array([-2.6107285 , -4.39417294])


closest
(2,)


array([1, 0])


last_tok_logits
torch.Size([3, 32064])


tensor([[ -0.9766,  -6.3438, -10.3125,  ...,   0.0000,   0.0000,   0.0000],
        [  1.5703,  -2.2031,  -5.8750,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.4199,   2.5781,  -0.6797,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.8127e-01, 1.5861e-02, 2.1465e-03,  ..., 8.4822e-22, 5.8297e-22,
         1.6702e-22],
        [1.0000e+00, 4.0587e-10, 1.3177e-10,  ..., 1.3697e-25, 1.2088e-25,
         1.0668e-25],
        [9.9997e-01, 2.4300e-05, 9.4222e-07,  ..., 4.6267e-22, 3.1799e-22,
         3.1799e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.9813, 0.9971, 0.9993,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([3, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.6107, -4.1107], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[29892],
        [  310],
        [ 3186]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.8907e-02,  0.0000e+00, -2.7895e-05], device='cuda:0')


new_candidates
torch.Size([3, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-0.3024, -2.6107, -4.1108], device='cuda:0')

infer end: GPU memory used: 15453 MB.
event: level
id: 5
data: [{"content": ",", "parent": 0, "prob": -0.3023514449596405}, {"content": "of", "parent": 1, "prob": -2.6107285022735596}, {"content": "world", "parent": 2, "prob": -4.1107563972473145}]





num_batches 1
infer start: GPU memory used: 15453 MB.

batch_candidates
torch.Size([3, 17])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 17, 32064])

hidden_states[-1]
torch.Size([3, 17, 3072])
infer - after batch run: GPU memory used: 15459 MB.

embeddings_np
(3, 3072)


array([[-1.0859375 ,  2.140625  ,  2.5       , ..., -2.109375  ,
        -0.71484375, -1.5078125 ],
       [-2.953125  , -2.59375   ,  1.4453125 , ...,  1.1484375 ,
        -1.6171875 , -0.31835938],
       [-2.03125   , -1.4140625 ,  1.0546875 , ..., -1.390625  ,
         0.11914062, -0.3828125 ]], dtype=float32)


k_mean_space
(3, 2)


array([[94.75493 , 39.14829 ],
       [ 0.      , 87.11659 ],
       [96.256325, 39.14829 ]], dtype=float32)


k_mean_clusters
(3,)


array([1, 0, 1], dtype=int32)


k_mean_logprob_mass
(2,)


array([-2.6107285 , -4.41310784])


closest
(2,)


array([1, 0])


last_tok_logits
torch.Size([3, 32064])


tensor([[ -3.5156,  -2.1094, -10.3125,  ...,   0.0000,   0.0000,   0.0000],
        [  0.3340,  -0.2949,  -6.8750,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.1992,  -3.4062,  -8.2500,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[4.2495e-01, 2.9206e-01, 2.2746e-01,  ..., 1.3078e-19, 1.0185e-19,
         7.9321e-20],
        [4.5078e-01, 1.6583e-01, 1.6583e-01,  ..., 1.0251e-18, 6.2173e-19,
         4.8420e-19],
        [7.6299e-01, 1.1701e-01, 9.1126e-02,  ..., 3.7954e-21, 2.9558e-21,
         6.5953e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.4249, 0.7170, 0.9445,  ..., 1.0000, 1.0000, 1.0000],
        [0.4508, 0.6166, 0.7824,  ..., 1.0000, 1.0000, 1.0000],
        [0.7630, 0.8800, 0.9711,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([13])


tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2], device='cuda:0')


carryover_candidates
torch.Size([13, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   33


carryover_candidate_logprobs
torch.Size([13])


tensor([-0.3024, -0.3024, -0.3024, -2.6107, -2.6107, -2.6107, -2.6107, -2.6107,
        -2.6107, -2.6107, -4.1108, -4.1108, -4.1108], device='cuda:0')


new_candidate_toks
torch.Size([13, 1])


tensor([[  408],
        [ 2729],
        [  297],
        [11858],
        [ 3171],
        [19224],
        [ 2038],
        [  967],
        [ 2533],
        [11563],
        [29892],
        [  338],
        [  746]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([13])


tensor([-0.8558, -1.2308, -1.4808, -0.7968, -1.7968, -1.7968, -3.0468, -3.2968,
        -3.4218, -3.7968, -0.2705, -2.1455, -2.3955], device='cuda:0')


new_candidates
torch.Size([13, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   297],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958, 


new_candidate_logprobs
torch.Size([13])


tensor([-1.1581, -1.5331, -1.7831, -3.4075, -4.4075, -4.4075, -5.6575, -5.9075,
        -6.0325, -6.4075, -4.3813, -6.2563, -6.5063], device='cuda:0')

infer end: GPU memory used: 15461 MB.
event: level
id: 6
data: [{"content": "as", "parent": 0, "prob": -1.1581370830535889}, {"content": "based", "parent": 0, "prob": -1.5331370830535889}, {"content": "in", "parent": 0, "prob": -1.7831370830535889}, {"content": "elev", "parent": 1, "prob": -3.4075088500976562}, {"content": "height", "parent": 1, "prob": -4.407508850097656}, {"content": "peak", "parent": 1, "prob": -4.407508850097656}, {"content": "above", "parent": 1, "prob": -5.657508850097656}, {"content": "its", "parent": 1, "prob": -5.907508850097656}, {"content": "sum", "parent": 1, "prob": -6.032508850097656}, {"content": "Earth", "parent": 1, "prob": -6.407508850097656}, {"content": ",", "parent": 2, "prob": -4.381266117095947}, {"content": "is", "parent": 2, "prob": -6.256266117095947}, {"content": "when", "parent": 2, "prob": -6.506266117095947}]





num_batches 2
infer start: GPU memory used: 15461 MB.

batch_candidates
torch.Size([8, 18])

batch_candidate_logprobs
torch.Siz

array([[-1.1328125 ,  0.54296875,  0.23242188, ..., -1.6953125 ,
         1.109375  , -5.3125    ],
       [ 0.58984375, -1.3671875 ,  0.359375  , ...,  0.41015625,
         0.5390625 , -3.890625  ],
       [-2.453125  , -0.46875   ,  2.453125  , ...,  0.3046875 ,
        -2.703125  , -4.375     ],
       ...,
       [-1.3671875 ,  1.515625  ,  0.9296875 , ..., -0.84765625,
         0.04980469, -0.82421875],
       [ 1.6875    ,  1.        ,  0.35546875, ...,  0.85546875,
        -1.0546875 , -0.25195312],
       [-1.171875  ,  1.15625   , -0.16113281, ..., -1.765625  ,
         0.34375   , -1.6796875 ]], dtype=float32)


k_mean_space
(13, 2)


array([[ 64.968605,  97.289665],
       [ 69.84189 ,  97.70087 ],
       [ 62.93975 ,  94.86857 ],
       [ 85.346535,  51.75946 ],
       [ 63.661728,  87.620575],
       [ 61.88258 ,  85.99392 ],
       [ 73.69413 ,  92.78236 ],
       [ 63.86057 ,  89.13428 ],
       [ 83.0659  ,  51.75946 ],
       [ 66.70392 ,  90.96561 ],
       [ 59.258583,  96.899796],
       [ 73.730484, 100.27267 ],
       [ 63.0741  ,  94.941605]], dtype=float32)


k_mean_clusters
(13,)


array([0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], dtype=int32)


k_mean_logprob_mass
(2,)


array([-48.40575385,  -9.4400177 ])


closest
(2,)


array([10,  3])


last_tok_logits
torch.Size([13, 32064])


tensor([[ -3.6250,  -6.8750,  -6.4062,  ...,   0.0000,   0.0000,   0.0000],
        [ -4.1250,  -1.3672,  -4.3125,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.4863,  -6.3750,  -7.7188,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [ -3.3438,  -2.0469, -10.8750,  ...,   0.0000,   0.0000,   0.0000],
        [  2.2969,  -0.4414,  -6.3125,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.5352,  -6.8750, -11.3125,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([13, 32064])

sorted_indices
torch.Size([13, 32064])

sorted_probs
torch.Size([13, 32064])


tensor([[8.5700e-01, 9.0327e-02, 2.9325e-02,  ..., 8.3943e-22, 5.7693e-22,
         8.8475e-23],
        [9.9999e-01, 8.9396e-06, 2.5613e-06,  ..., 8.9318e-25, 2.8997e-25,
         1.5521e-25],
        [9.9985e-01, 9.6097e-05, 2.4297e-05,  ..., 4.3030e-23, 3.3512e-23,
         2.9574e-23],
        ...,
        [7.1745e-01, 1.2467e-01, 6.6733e-02,  ..., 2.3272e-20, 1.4115e-20,
         7.0274e-22],
        [9.9976e-01, 8.4798e-05, 5.1432e-05,  ..., 2.2288e-20, 1.9669e-20,
         9.2910e-21],
        [9.0999e-01, 6.5920e-02, 1.4709e-02,  ..., 1.2063e-22, 1.0646e-22,
         5.6981e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([13, 32064])


tensor([[0.8570, 0.9473, 0.9766,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.7174, 0.8421, 0.9089,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 0.9998, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.9100, 0.9759, 0.9906,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([13, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([23])


tensor([ 0,  0,  1,  2,  3,  4,  5,  5,  5,  6,  7,  7,  7,  7,  7,  7,  8,  9,
        10, 10, 10, 11, 12], device='cuda:0')


carryover_candidates
torch.Size([23, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   297],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958, 


carryover_candidate_logprobs
torch.Size([23])


tensor([-1.1581, -1.1581, -1.5331, -1.7831, -3.4075, -4.4075, -4.4075, -4.4075,
        -4.4075, -5.6575, -5.9075, -5.9075, -5.9075, -5.9075, -5.9075, -5.9075,
        -6.0325, -6.4075, -4.3813, -4.3813, -4.3813, -6.2563, -6.5063],
       device='cuda:0')


new_candidate_toks
torch.Size([23, 1])


tensor([[17005],
        [  310],
        [  373],
        [ 4958],
        [  362],
        [ 2038],
        [11858],
        [ 3171],
        [ 2038],
        [ 7205],
        [ 2533],
        [19224],
        [ 3171],
        [11563],
        [ 5534],
        [11858],
        [ 2415],
        [29915],
        [  746],
        [17005],
        [ 2729],
        [ 8040],
        [17005]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([23])


tensor([-1.5432e-01, -2.4043e+00, -1.4901e-05, -1.5343e-04, -1.0610e-05,
        -7.1930e-03, -3.5909e-01, -1.6091e+00, -2.3591e+00, -1.7558e-03,
        -1.0258e+00, -1.5258e+00, -1.9008e+00, -2.0258e+00, -3.6508e+00,
        -3.7758e+00, -2.6226e-06, -3.3480e-04, -3.3205e-01, -2.0821e+00,
        -2.7071e+00, -2.4256e-04, -9.4318e-02], device='cuda:0')


new_candidates
torch.Size([23, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         


new_candidate_logprobs
torch.Size([23])


tensor([-1.3125, -3.5625, -1.5332, -1.7833, -3.4075, -4.4147, -4.7666, -6.0166,
        -6.7666, -5.6593, -6.9333, -7.4333, -7.8083, -7.9333, -9.5583, -9.6833,
        -6.0325, -6.4078, -4.7133, -6.4633, -7.0883, -6.2565, -6.6006],
       device='cuda:0')

infer end: GPU memory used: 15693 MB.
event: level
id: 7
data: [{"content": "measured", "parent": 0, "prob": -1.3124570846557617}, {"content": "of", "parent": 0, "prob": -3.5624570846557617}, {"content": "on", "parent": 1, "prob": -1.5331519842147827}, {"content": "terms", "parent": 2, "prob": -1.7832905054092407}, {"content": "ation", "parent": 3, "prob": -3.407519578933716}, {"content": "above", "parent": 4, "prob": -4.41470193862915}, {"content": "elev", "parent": 5, "prob": -4.766594409942627}, {"content": "height", "parent": 5, "prob": -6.016594409942627}, {"content": "above", "parent": 5, "prob": -6.766594886779785}, {"content": "sea", "parent": 6, "prob": -5.65926456451416}, {"content": "sum", "parent": 7, "prob": -6.933293342590332}, {"content": "peak", "parent": 7, "prob": -7.433293342590332}, {"content": "height", "parent": 7, "prob": -7.808293342590332}, {"content": "Earth", "parent": 7, "prob": -7.933293342590332}, {"content": "global", "parent": 7, "prob": -9.5582933425903

array([[-1.46875   ,  0.26171875,  0.54296875, ...,  0.90625   ,
        -0.578125  , -2.484375  ],
       [-1.2421875 , -1.09375   , -0.9296875 , ..., -0.6015625 ,
        -1.859375  , -1.984375  ],
       [-2.25      , -1.4609375 ,  0.9296875 , ...,  0.9296875 ,
        -1.453125  , -1.3046875 ],
       ...,
       [ 0.11425781, -0.7734375 ,  1.96875   , ...,  1.9453125 ,
        -1.1796875 ,  0.10009766],
       [-0.37890625,  0.76953125, -0.23730469, ..., -1.6640625 ,
         0.42382812, -1.828125  ],
       [-0.84375   , -0.296875  ,  0.30859375, ...,  1.109375  ,
        -0.22167969, -2.25      ]], dtype=float32)


k_mean_space
(20, 2)


array([[108.090355,  68.845276],
       [112.92927 ,  83.16716 ],
       [100.42233 ,  62.03671 ],
       [110.15066 ,  77.7317  ],
       [103.00117 ,  60.813286],
       [109.12773 ,  70.14271 ],
       [105.55327 ,  74.92014 ],
       [103.51917 ,  59.99682 ],
       [107.0962  ,  68.711716],
       [106.75428 ,  72.73436 ],
       [  0.      ,  78.689095],
       [ 95.984436,  59.706863],
       [101.82982 ,  61.92126 ],
       [105.34387 ,  71.42927 ],
       [100.11471 ,  66.62799 ],
       [101.1882  ,  73.36465 ],
       [ 94.55702 ,  63.723984],
       [104.92426 ,  72.03608 ],
       [107.20744 ,  69.53959 ],
       [108.92861 ,  68.5642  ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([  -6.93329334, -105.25608993])


closest
(2,)


array([10, 11])


last_tok_logits
torch.Size([20, 32064])


tensor([[  1.5078,  -4.2500,  -7.2812,  ...,   0.0000,   0.0000,   0.0000],
        [ -2.9219, -11.1250,  -3.2969,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.3633,  -1.3984,  -6.5625,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [ -2.5156,  -2.7188,  -1.5703,  ...,   0.0000,   0.0000,   0.0000],
        [ -1.5625,  -6.7812, -12.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  1.8359,   0.2910,  -8.8125,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.5255e-01, 4.7425e-02, 7.5150e-06,  ..., 4.4073e-22, 1.6214e-22,
         6.7588e-23],
        [7.1638e-01, 1.5985e-01, 5.8804e-02,  ..., 4.9193e-20, 2.6331e-20,
         2.0507e-20],
        [7.1501e-01, 2.0485e-01, 4.0338e-02,  ..., 3.6279e-19, 2.4934e-19,
         7.1438e-20],
        ...,
        [1.0000e+00, 3.9279e-07, 6.0236e-08,  ..., 1.0324e-22, 2.6103e-23,
         1.3972e-23],
        [9.4171e-01, 4.6885e-02, 9.2322e-03,  ..., 1.6029e-22, 1.4145e-22,
         1.6894e-23],
        [8.1752e-01, 1.8241e-01, 2.8905e-05,  ..., 1.3202e-21, 1.3915e-22,
         1.2280e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9526, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.7164, 0.8762, 0.9350,  ..., 1.0000, 1.0000, 1.0000],
        [0.7150, 0.9199, 0.9602,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9417, 0.9886, 0.9978,  ..., 1.0000, 1.0000, 1.0000],
        [0.8175, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([31])


tensor([ 0,  1,  1,  1,  2,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 11, 11, 12,
        13, 14, 14, 14, 14, 15, 16, 16, 16, 17, 18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([31, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         


carryover_candidate_logprobs
torch.Size([31])


tensor([-1.3125, -3.5625, -3.5625, -3.5625, -1.5332, -1.5332, -1.7833, -3.4075,
        -4.4147, -4.7666, -6.0166, -6.7666, -5.6593, -6.9333, -7.4333, -7.4333,
        -7.4333, -7.8083, -7.9333, -9.5583, -9.5583, -9.5583, -9.5583, -9.6833,
        -6.0325, -6.0325, -6.0325, -6.4078, -4.7133, -6.4633, -6.4633],
       device='cuda:0')


new_candidate_toks
torch.Size([31, 1])


tensor([[  491],
        [  590],
        [ 1857],
        [ 1286],
        [  967],
        [  278],
        [  310],
        [ 2038],
        [ 7205],
        [  362],
        [ 2038],
        [ 7205],
        [ 3233],
        [ 2415],
        [ 2038],
        [11858],
        [29915],
        [ 2038],
        [29915],
        [19224],
        [ 2533],
        [ 3171],
        [11858],
        [  362],
        [ 3171],
        [ 2038],
        [11858],
        [29879],
        [17005],
        [  491],
        [  515]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([31])


tensor([-4.8612e-02, -3.3354e-01, -1.8335e+00, -2.8335e+00, -3.3545e-01,
        -1.5855e+00,  0.0000e+00, -8.0603e-03, -5.9605e-06, -1.2387e-04,
        -3.7585e-03, -2.7418e-06, -2.5153e-05, -7.1526e-07, -3.2222e-01,
        -1.9472e+00, -2.5722e+00, -2.7378e-03, -7.9454e-02, -8.2963e-01,
        -1.4546e+00, -1.9546e+00, -2.2046e+00, -2.6822e-05, -8.2659e-01,
        -9.5159e-01, -1.8266e+00, -5.9605e-07, -6.0061e-02, -2.0149e-01,
        -1.7015e+00], device='cuda:0')


new_candidates
torch.Size([31, 20])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310,   590],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310,  1857],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310,  1286],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729,   373,   967],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729,   373,   278],
        [    1, 32010,  1724,   338,   278,  9


new_candidate_logprobs
torch.Size([31])


tensor([ -1.3611,  -3.8960,  -5.3960,  -6.3960,  -1.8686,  -3.1186,  -1.7833,
         -3.4156,  -4.4147,  -4.7667,  -6.0204,  -6.7666,  -5.6593,  -6.9333,
         -7.7555,  -9.3805, -10.0055,  -7.8110,  -8.0127, -10.3879, -11.0129,
        -11.5129, -11.7629,  -9.6833,  -6.8591,  -6.9841,  -7.8591,  -6.4078,
         -4.7734,  -6.6648,  -8.1648], device='cuda:0')

infer end: GPU memory used: 15803 MB.
event: level
id: 8
data: [{"content": "by", "parent": 0, "prob": -1.3610693216323853}, {"content": "my", "parent": 1, "prob": -3.8960018157958984}, {"content": "current", "parent": 1, "prob": -5.396001815795898}, {"content": "now", "parent": 1, "prob": -6.396001815795898}, {"content": "its", "parent": 2, "prob": -1.868605375289917}, {"content": "the", "parent": 2, "prob": -3.118605136871338}, {"content": "of", "parent": 3, "prob": -1.7832905054092407}, {"content": "above", "parent": 4, "prob": -3.4155797958374023}, {"content": "sea", "parent": 5, "prob": -4.414708137512207}, {"content": "ation", "parent": 6, "prob": -4.76671838760376}, {"content": "above", "parent": 7, "prob": -6.020352840423584}, {"content": "sea", "parent": 8, "prob": -6.766597747802734}, {"content": "level", "parent": 9, "prob": -5.659289836883545}, {"content": "mit", "parent": 10, "prob": -6.933294296264648}, {"content": "above", "parent": 11, "prob": -7.7555131912231445}, {"co

array([[-2.546875  , -1.453125  ,  1.859375  , ...,  1.375     ,
        -1.3203125 , -1.828125  ],
       [-0.04077148, -2.109375  , -0.78515625, ...,  1.4921875 ,
        -2.453125  ,  0.22167969],
       [-0.78125   , -0.35546875,  1.328125  , ..., -2.1875    ,
        -1.8125    , -1.3515625 ],
       ...,
       [-2.140625  , -1.0546875 ,  1.28125   , ...,  0.54296875,
        -0.59765625, -0.78125   ],
       [ 0.25195312, -1.1171875 ,  2.515625  , ...,  1.9453125 ,
        -0.8828125 ,  0.11230469],
       [-2.140625  ,  0.48632812,  1.671875  , ..., -0.71484375,
         0.765625  ,  0.82421875]], dtype=float32)


k_mean_space
(20, 2)


array([[82.02928 , 57.83898 ],
       [86.31077 , 88.55866 ],
       [85.77922 , 66.65087 ],
       [89.64893 , 73.17649 ],
       [83.02188 , 55.618355],
       [83.949165, 58.62445 ],
       [81.674576, 55.881218],
       [43.21181 , 81.42808 ],
       [66.71655 , 83.85886 ],
       [86.25273 , 60.581882],
       [41.89849 , 79.461914],
       [66.35204 , 83.29448 ],
       [79.35331 , 54.730457],
       [86.8606  , 60.046146],
       [44.742443, 78.33824 ],
       [86.83229 , 74.70681 ],
       [90.86739 , 69.88292 ],
       [43.382965, 80.06935 ],
       [93.36756 , 74.79514 ],
       [84.965416, 54.670177]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-40.07978487, -75.06957483])


closest
(2,)


array([10, 19])


last_tok_logits
torch.Size([20, 32064])


tensor([[ -1.6172,   1.2969,  -7.1875,  ...,   0.0000,   0.0000,   0.0000],
        [  0.6797,  -8.6250,  -4.2500,  ...,   0.0000,   0.0000,   0.0000],
        [  0.1162,  -6.5312,  -2.5625,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [  2.2812,  -3.3438, -11.8125,  ...,   0.0000,   0.0000,   0.0000],
        [ -2.7188,  -3.0000,  -0.2246,  ...,   0.0000,   0.0000,   0.0000],
        [  2.5781,  -2.2188,  -5.0938,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[7.4603e-01, 1.4690e-01, 8.9100e-02,  ..., 1.6631e-20, 7.8561e-21,
         6.9330e-21],
        [5.5856e-01, 4.3501e-01, 6.2051e-03,  ..., 2.1214e-23, 1.8721e-23,
         1.6521e-23],
        [9.5682e-01, 1.7525e-02, 1.0629e-02,  ..., 3.9068e-22, 3.0426e-22,
         1.8455e-22],
        ...,
        [9.9999e-01, 4.2228e-06, 2.9023e-06,  ..., 8.4743e-24, 6.5998e-24,
         2.4279e-24],
        [1.0000e+00, 1.1254e-07, 2.5110e-08,  ..., 1.0881e-23, 1.1469e-24,
         2.2583e-25],
        [7.1712e-01, 1.4121e-01, 1.2462e-01,  ..., 4.2604e-22, 2.9281e-22,
         5.0883e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.7460, 0.8929, 0.9820,  ..., 1.0000, 1.0000, 1.0000],
        [0.5586, 0.9936, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        [0.9568, 0.9743, 0.9850,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.7171, 0.8583, 0.9829,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([36])


tensor([ 0,  0,  0,  1,  1,  2,  3,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  6,
         7,  8,  9, 10, 11, 12, 12, 13, 13, 13, 14, 15, 16, 17, 18, 19, 19, 19],
       device='cuda:0')


carryover_candidates
torch.Size([36, 20])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310,   590],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310,   590],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310,  1857],
        [    1, 32010,  1724,   338,   278,  9


carryover_candidate_logprobs
torch.Size([36])


tensor([ -1.3611,  -1.3611,  -1.3611,  -3.8960,  -3.8960,  -5.3960,  -6.3960,
         -1.8686,  -1.8686,  -3.1186,  -3.1186,  -3.1186,  -3.1186,  -1.7833,
         -1.7833,  -1.7833,  -1.7833,  -1.7833,  -3.4156,  -4.4147,  -4.7667,
         -6.0204,  -6.7666,  -5.6593,  -5.6593,  -6.9333,  -6.9333,  -6.9333,
         -7.7555,  -9.3805, -10.0055,  -7.8110,  -8.0127, -10.3879, -10.3879,
        -10.3879], device='cuda:0')


new_candidate_toks
torch.Size([36, 1])


tensor([[  967],
        [  278],
        [19224],
        [ 7134],
        [ 1833],
        [20398],
        [29892],
        [19224],
        [ 2533],
        [ 2533],
        [19224],
        [ 3171],
        [ 3001],
        [11858],
        [ 3171],
        [19224],
        [  967],
        [ 5272],
        [ 7205],
        [ 3233],
        [ 2038],
        [ 7205],
        [ 3233],
        [11858],
        [  338],
        [ 2038],
        [29915],
        [11858],
        [ 7205],
        [  362],
        [29879],
        [ 7205],
        [29879],
        [ 2038],
        [  338],
        [11858]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([36])


tensor([-2.9299e-01, -1.9180e+00, -2.4180e+00, -5.8239e-01, -8.3239e-01,
        -4.4145e-02, -1.9551e-05, -1.2063e-01, -2.7456e+00, -8.0460e-01,
        -1.0546e+00, -2.3046e+00, -3.8046e+00, -9.4504e-01, -1.3200e+00,
        -1.8200e+00, -3.0700e+00, -3.0700e+00, -7.1526e-06, -1.4305e-06,
        -5.9583e-03, -1.5020e-05, -5.9605e-07, -3.4680e-01, -1.3468e+00,
        -8.6266e-01, -9.8766e-01, -2.1127e+00, -3.4571e-06, -6.6998e-05,
        -7.5102e-06, -9.0599e-06, -1.1921e-07, -3.3252e-01, -1.9575e+00,
        -2.0825e+00], device='cuda:0')


new_candidates
torch.Size([36, 21])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
         19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310,   590,
          7134],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408,   310,   590,
          1833],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11


new_candidate_logprobs
torch.Size([36])


tensor([ -1.6541,  -3.2791,  -3.7791,  -4.4784,  -4.7284,  -5.4401,  -6.3960,
         -1.9892,  -4.6142,  -3.9232,  -4.1732,  -5.4232,  -6.9232,  -2.7283,
         -3.1033,  -3.6033,  -4.8533,  -4.8533,  -3.4156,  -4.4147,  -4.7727,
         -6.0204,  -6.7666,  -6.0061,  -7.0061,  -7.7960,  -7.9210,  -9.0460,
         -7.7555,  -9.3806, -10.0055,  -7.8110,  -8.0127, -10.7204, -12.3454,
        -12.4704], device='cuda:0')

infer end: GPU memory used: 15853 MB.
event: level
id: 9
data: [{"content": "its", "parent": 0, "prob": -1.6540625095367432}, {"content": "the", "parent": 0, "prob": -3.2790627479553223}, {"content": "peak", "parent": 0, "prob": -3.7790627479553223}, {"content": "knowledge", "parent": 1, "prob": -4.47838830947876}, {"content": "last", "parent": 1, "prob": -4.72838830947876}, {"content": "measurements", "parent": 2, "prob": -5.4401469230651855}, {"content": ",", "parent": 3, "prob": -6.396021366119385}, {"content": "peak", "parent": 4, "prob": -1.9892395734786987}, {"content": "sum", "parent": 4, "prob": -4.614239692687988}, {"content": "sum", "parent": 5, "prob": -3.9232091903686523}, {"content": "peak", "parent": 5, "prob": -4.173209190368652}, {"content": "height", "parent": 5, "prob": -5.423209190368652}, {"content": "total", "parent": 5, "prob": -6.923209190368652}, {"content": "elev", "parent": 6, "prob": -2.7283291816711426}, {"content": "height", "parent": 6, "prob": -3.10332918

array([[-3.578125  , -1.8984375 ,  0.37304688, ...,  0.69921875,
        -0.375     , -1.59375   ],
       [-1.734375  , -0.68359375,  0.06225586, ...,  1.2578125 ,
         0.29492188, -1.0703125 ],
       [-0.37304688, -1.4609375 ,  2.890625  , ..., -0.48242188,
         0.6796875 , -1.046875  ],
       ...,
       [-0.5234375 , -2.828125  ,  1.859375  , ..., -2.21875   ,
        -0.41015625, -0.53125   ],
       [ 1.8359375 , -0.71484375,  0.06152344, ...,  0.50390625,
        -1.21875   ,  2.0625    ],
       [-1.9765625 ,  1.140625  , -1.296875  , ..., -1.78125   ,
         0.0045166 ,  1.4453125 ]], dtype=float32)


k_mean_space
(20, 2)


array([[54.889275, 78.07944 ],
       [60.380497, 78.97818 ],
       [51.087166, 73.564995],
       [90.23411 , 77.70281 ],
       [84.14997 , 84.881096],
       [80.79718 , 63.606934],
       [81.298676, 63.996265],
       [52.3949  , 71.988045],
       [69.22221 , 83.256584],
       [67.64776 , 82.19654 ],
       [55.464928, 72.7298  ],
       [72.7466  , 60.54184 ],
       [65.45474 , 82.047485],
       [85.38649 , 70.93598 ],
       [72.50519 , 58.409164],
       [51.31735 , 72.83523 ],
       [56.361176, 77.84605 ],
       [83.93292 , 70.81448 ],
       [80.798004, 82.83262 ],
       [78.99681 , 61.898376]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-46.93592846, -36.8374629 ])


closest
(2,)


array([ 2, 14])


last_tok_logits
torch.Size([20, 32064])


tensor([[-1.3203e+00,  2.2500e+00, -6.5312e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.6875e+00,  3.9219e+00, -4.3125e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 2.0410e-01, -2.4062e+00, -5.3750e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [-1.6797e+00,  7.8735e-03, -1.5781e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.3203e+00, -5.1875e+00, -1.0375e+01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 2.3633e-01, -5.3750e+00, -6.4375e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.8566e-01, 4.9966e-02, 4.4095e-02,  ..., 1.1976e-20, 1.0568e-20,
         8.2307e-21],
        [4.6710e-01, 2.2064e-01, 1.9472e-01,  ..., 8.7189e-20, 5.9924e-20,
         3.2075e-20],
        [8.7683e-01, 8.1557e-02, 2.0621e-02,  ..., 2.7883e-22, 9.0522e-23,
         5.4904e-23],
        ...,
        [9.9999e-01, 3.2887e-06, 2.2603e-06,  ..., 3.5326e-24, 1.8909e-24,
         6.1387e-25],
        [1.0000e+00, 1.3710e-06, 1.6374e-07,  ..., 2.6442e-28, 1.1023e-28,
         3.5786e-29],
        [9.9994e-01, 5.8291e-05, 7.3378e-07,  ..., 3.1173e-24, 2.1425e-24,
         2.5589e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8857, 0.9356, 0.9797,  ..., 1.0000, 1.0000, 1.0000],
        [0.4671, 0.6877, 0.8825,  ..., 1.0000, 1.0000, 1.0000],
        [0.8768, 0.9584, 0.9790,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([33])


tensor([ 0,  0,  1,  1,  1,  1,  2,  2,  3,  4,  5,  6,  7,  7,  7,  8,  9, 10,
        10, 10, 11, 12, 13, 14, 15, 15, 15, 16, 16, 16, 17, 18, 19],
       device='cuda:0')


carryover_candidates
torch.Size([33, 21])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11


carryover_candidate_logprobs
torch.Size([33])


tensor([-1.6541, -1.6541, -3.2791, -3.2791, -3.2791, -3.2791, -3.7791, -3.7791,
        -4.4784, -4.7284, -5.4401, -6.3960, -1.9892, -1.9892, -1.9892, -4.6142,
        -3.9232, -4.1732, -4.1732, -4.1732, -5.4232, -6.9232, -2.7283, -3.1033,
        -3.6033, -3.6033, -3.6033, -4.8533, -4.8533, -4.8533, -4.8533, -3.4156,
        -4.4147], device='cuda:0')


new_candidate_toks
torch.Size([33, 1])


tensor([[19224],
        [ 2533],
        [ 3171],
        [ 2533],
        [19224],
        [ 5272],
        [11858],
        [ 2038],
        [ 5700],
        [ 2767],
        [29892],
        [  338],
        [29915],
        [11858],
        [ 2038],
        [ 2415],
        [ 2415],
        [29915],
        [ 2038],
        [20888],
        [ 2038],
        [ 3171],
        [  362],
        [ 2038],
        [11858],
        [ 2038],
        [ 3171],
        [19224],
        [ 3171],
        [11858],
        [ 4279],
        [ 3233],
        [  338]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([33])


tensor([-1.2142e-01, -2.9964e+00, -7.6121e-01, -1.5112e+00, -1.6362e+00,
        -2.8862e+00, -1.3145e-01, -2.5064e+00, -4.8506e-03, -8.6415e-03,
        -1.9719e-03, -1.2755e-05, -6.4153e-01, -1.3915e+00, -1.8915e+00,
        -4.7684e-07, -7.1086e-04, -5.9533e-01, -1.5953e+00, -1.7203e+00,
        -6.4690e-02, -4.0630e-02, -2.2769e-05, -4.4883e-03, -2.3338e-01,
        -2.3584e+00, -2.3584e+00, -6.1641e-01, -1.1164e+00, -2.9914e+00,
        -7.0334e-06, -1.5497e-06, -6.0560e-05], device='cuda:0')


new_candidates
torch.Size([33, 22])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967,  2533],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278,  2533],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32


new_candidate_logprobs
torch.Size([33])


tensor([-1.7755, -4.6505, -4.0403, -4.7903, -4.9153, -6.1653, -3.9105, -6.2855,
        -4.4832, -4.7370, -5.4421, -6.3960, -2.6308, -3.3808, -3.8808, -4.6142,
        -3.9239, -4.7685, -5.7685, -5.8935, -5.4879, -6.9638, -2.7284, -3.1078,
        -3.8367, -5.9617, -5.9617, -5.4697, -5.9697, -7.8447, -4.8533, -3.4156,
        -4.4148], device='cuda:0')

infer end: GPU memory used: 15909 MB.
event: level
id: 10
data: [{"content": "peak", "parent": 0, "prob": -1.7754803895950317}, {"content": "sum", "parent": 0, "prob": -4.650480270385742}, {"content": "height", "parent": 1, "prob": -4.0402703285217285}, {"content": "sum", "parent": 1, "prob": -4.7902703285217285}, {"content": "peak", "parent": 1, "prob": -4.9152703285217285}, {"content": "alt", "parent": 1, "prob": -6.1652703285217285}, {"content": "elev", "parent": 2, "prob": -3.9105100631713867}, {"content": "above", "parent": 2, "prob": -6.285510063171387}, {"content": "cut", "parent": 3, "prob": -4.48323917388916}, {"content": "update", "parent": 4, "prob": -4.737029552459717}, {"content": ",", "parent": 5, "prob": -5.4421186447143555}, {"content": "is", "parent": 6, "prob": -6.396034240722656}, {"content": "'", "parent": 7, "prob": -2.6307716369628906}, {"content": "elev", "parent": 7, "prob": -3.3807716369628906}, {"content": "above", "parent": 7, "prob": -3.8807716369628906}, {"

array([[-2.390625  , -0.578125  ,  2.640625  , ...,  1.2890625 ,
         0.13671875,  0.06689453],
       [ 0.61328125, -1.640625  ,  2.03125   , ..., -2.0625    ,
         0.49609375,  0.04785156],
       [-0.09619141, -1.40625   ,  2.46875   , ...,  1.9765625 ,
         0.41015625, -0.55078125],
       ...,
       [-0.11035156,  0.3125    ,  1.1328125 , ..., -0.42382812,
        -1.234375  , -3.328125  ],
       [-2.5625    , -1.5       ,  1.921875  , ...,  2.328125  ,
         0.39648438, -1.8828125 ],
       [-1.328125  ,  1.1640625 ,  2.3125    , ..., -1.1484375 ,
         0.6328125 ,  0.03149414]], dtype=float32)


k_mean_space
(20, 2)


array([[52.456715, 74.04165 ],
       [82.232346, 69.7017  ],
       [77.23338 , 69.04499 ],
       [83.272125, 68.939995],
       [53.590088, 74.42818 ],
       [88.41295 , 70.31234 ],
       [88.818665, 69.75629 ],
       [83.19874 , 60.407867],
       [84.866585, 77.32373 ],
       [79.6626  , 87.32031 ],
       [69.37577 , 83.5088  ],
       [72.305664, 85.298004],
       [70.317474, 81.21001 ],
       [84.02613 , 65.01364 ],
       [81.955635, 60.037983],
       [54.56602 , 76.077126],
       [54.675316, 75.360374],
       [69.64279 , 84.45496 ],
       [81.789215, 60.48908 ],
       [72.237175, 70.9672  ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-39.20340884, -53.24918079])


closest
(2,)


array([ 0, 14])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.3359, -5.2188, -5.6875,  ...,  0.0000,  0.0000,  0.0000],
        [-1.2344, -2.1719, -3.8906,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9414, -0.7539, -7.2500,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.2832, -0.2363, -4.1250,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.4766,  0.6328, -8.8125,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.7812,  0.4512, -0.3848,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[5.8104e-01, 1.8864e-01, 1.6647e-01,  ..., 4.2053e-21, 2.8903e-21,
         1.3653e-21],
        [1.0000e+00, 2.5613e-06, 3.9278e-07,  ..., 6.2617e-23, 4.3036e-23,
         1.3972e-23],
        [9.6075e-01, 3.7252e-02, 1.8547e-03,  ..., 1.0664e-21, 3.9229e-22,
         5.3091e-23],
        ...,
        [1.0000e+00, 7.1941e-09, 6.6916e-10,  ..., 2.6972e-26, 1.2741e-26,
         1.1244e-26],
        [9.9987e-01, 9.6099e-05, 1.4737e-05,  ..., 6.2609e-23, 3.7974e-23,
         4.5354e-24],
        [4.8236e-01, 4.2568e-01, 3.4942e-02,  ..., 1.2843e-21, 8.8268e-22,
         7.7897e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.5810, 0.7697, 0.9362,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9608, 0.9980, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.4824, 0.9080, 0.9430,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([24])


tensor([ 0,  0,  0,  1,  2,  3,  4,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
        15, 16, 17, 18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([24, 22])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967,  2533],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32


carryover_candidate_logprobs
torch.Size([24])


tensor([-1.7755, -1.7755, -1.7755, -4.6505, -4.0403, -4.7903, -4.9153, -4.9153,
        -6.1653, -3.9105, -6.2855, -4.4832, -4.7370, -5.4421, -6.3960, -2.6308,
        -3.3808, -3.8808, -4.6142, -3.9239, -4.7685, -5.7685, -5.8935, -5.8935],
       device='cuda:0')


new_candidate_toks
torch.Size([24, 1])


tensor([[29915],
        [11858],
        [ 2038],
        [ 2415],
        [ 2038],
        [ 2415],
        [29915],
        [ 2038],
        [ 4279],
        [  362],
        [ 7205],
        [ 2696],
        [  297],
        [  338],
        [ 8040],
        [29879],
        [  362],
        [ 7205],
        [29915],
        [29915],
        [29879],
        [ 7205],
        [  278],
        [ 2038]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([24])


tensor([-5.4293e-01, -1.6679e+00, -1.7929e+00, -3.8147e-06, -4.0036e-02,
        -2.9751e-02, -3.7244e-01, -1.4974e+00, -7.7486e-06, -2.3243e-04,
        -1.2279e-05, -1.1803e-03, -6.1969e-02,  0.0000e+00, -8.7027e-05,
        -1.6570e-05, -2.3088e-04, -1.6928e-05, -4.6181e-02, -2.0844e-02,
         0.0000e+00, -1.2840e-04, -7.2907e-01, -8.5407e-01], device='cuda:0')


new_candidates
torch.Size([24, 23])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967,  2533,  2415],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278,  3171,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 1437


new_candidate_logprobs
torch.Size([24])


tensor([-2.3184, -3.4434, -3.5684, -4.6505, -4.0803, -4.8200, -5.2877, -6.4127,
        -6.1653, -3.9107, -6.2855, -4.4844, -4.7990, -5.4421, -6.3961, -2.6308,
        -3.3810, -3.8808, -4.6604, -3.9448, -4.7685, -5.7687, -6.6226, -6.7476],
       device='cuda:0')

infer end: GPU memory used: 15963 MB.
event: level
id: 11
data: [{"content": "'", "parent": 0, "prob": -2.3184120655059814}, {"content": "elev", "parent": 0, "prob": -3.4434120655059814}, {"content": "above", "parent": 0, "prob": -3.5684120655059814}, {"content": "mit", "parent": 1, "prob": -4.650484085083008}, {"content": "above", "parent": 2, "prob": -4.080306529998779}, {"content": "mit", "parent": 3, "prob": -4.820021152496338}, {"content": "'", "parent": 4, "prob": -5.287709712982178}, {"content": "above", "parent": 4, "prob": -6.412710189819336}, {"content": "itude", "parent": 5, "prob": -6.16527795791626}, {"content": "ation", "parent": 6, "prob": -3.9107425212860107}, {"content": "sea", "parent": 7, "prob": -6.2855224609375}, {"content": "off", "parent": 8, "prob": -4.484419345855713}, {"content": "in", "parent": 9, "prob": -4.798998832702637}, {"content": "is", "parent": 10, "prob": -5.4421186447143555}, {"content": "Mount", "parent": 11, "prob": -6.396121501922607}, {"content

array([[ 1.8359375e+00, -4.1992188e-01,  2.4687500e+00, ...,
         6.8359375e-01, -1.7456055e-02, -5.0781250e-02],
       [ 2.3750000e+00, -1.4140625e+00,  3.2968750e+00, ...,
         8.4375000e-01, -5.2734375e-01, -1.8125000e+00],
       [-2.1093750e+00, -2.2343750e+00,  1.4062500e+00, ...,
         1.1093750e+00, -5.0292969e-02, -1.8828125e+00],
       ...,
       [ 1.6171875e+00, -1.3515625e+00,  3.6718750e-01, ...,
         1.7890625e+00, -1.6328125e+00,  2.1718750e+00],
       [ 8.2031250e-01,  2.4218750e-01,  1.7421875e+00, ...,
         1.4843750e-01, -1.3203125e+00, -3.6250000e+00],
       [ 9.3359375e-01,  1.8652344e-01,  1.3671875e+00, ...,
        -1.7700195e-03, -1.4140625e+00, -3.5468750e+00]], dtype=float32)


k_mean_space
(20, 2)


array([[ 72.23879 ,  89.73633 ],
       [ 75.872925, 100.29541 ],
       [ 65.14652 ,  95.09122 ],
       [ 64.582   ,  95.19549 ],
       [ 68.453735,  94.84407 ],
       [ 65.80657 ,  95.745255],
       [ 69.68479 ,  87.529686],
       [ 65.94327 ,  95.9941  ],
       [ 67.01432 ,  97.575   ],
       [ 64.51316 ,  98.56519 ],
       [ 72.0294  , 103.86715 ],
       [ 80.26332 , 104.383675],
       [ 82.406525, 102.63307 ],
       [ 83.687645,  49.518433],
       [ 86.797424,  57.908524],
       [ 66.124245,  93.57507 ],
       [ 62.33776 ,  95.99142 ],
       [ 71.811066, 104.22248 ],
       [ 79.85231 ,  41.11964 ],
       [ 80.93922 ,  40.10299 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-70.1190083 , -20.44342494])


closest
(2,)


array([16, 19])


last_tok_logits
torch.Size([20, 32064])


tensor([[  0.1021,  -1.2578,  -3.6562,  ...,   0.0000,   0.0000,   0.0000],
        [  2.5156,  -4.5000,   0.0398,  ...,   0.0000,   0.0000,   0.0000],
        [  2.7188,  -2.7500, -10.1875,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [  3.0312,  -5.0938,  -8.0000,  ...,   0.0000,   0.0000,   0.0000],
        [ -2.2969,   0.6875,  -5.0000,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.6641,   1.1797,  -4.2812,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 3.0590e-07, 1.6374e-07,  ..., 3.7234e-25, 3.2859e-25,
         1.5521e-25],
        [9.9982e-01, 1.7953e-04, 1.8551e-07,  ..., 1.3969e-23, 9.6010e-24,
         6.5986e-24],
        [9.9999e-01, 5.4222e-06, 1.2099e-06,  ..., 1.6687e-24, 8.9318e-25,
         3.7233e-25],
        ...,
        [1.0000e+00, 3.9279e-07, 1.4450e-07,  ..., 7.5759e-29, 5.2068e-29,
         2.7870e-29],
        [1.0000e+00, 1.4166e-09, 8.5922e-10,  ..., 1.8538e-26, 8.7565e-27,
         6.8196e-27],
        [1.0000e+00, 5.9053e-10, 4.5991e-10,  ..., 3.0563e-26, 1.8538e-26,
         2.5088e-27]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  8,  9, 10, 11, 12, 13, 14, 15, 15,
        16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 23])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967,  2533,  2415],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278,  3171,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 1437


carryover_candidate_logprobs
torch.Size([22])


tensor([-2.3184, -3.4434, -3.5684, -4.6505, -4.0803, -4.8200, -5.2877, -6.4127,
        -6.1653, -6.1653, -3.9107, -6.2855, -4.4844, -4.7990, -5.4421, -6.3961,
        -2.6308, -2.6308, -3.3810, -3.8808, -4.6604, -3.9448], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[29879],
        [  362],
        [ 7205],
        [29915],
        [ 7205],
        [29915],
        [29879],
        [ 7205],
        [  310],
        [ 2038],
        [ 2038],
        [ 3233],
        [  297],
        [29871],
        [ 8040],
        [18274],
        [ 3171],
        [11858],
        [ 2038],
        [ 3233],
        [29879],
        [29879]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([-5.9605e-07, -1.8008e-04, -7.2718e-06, -5.5956e-03, -9.0484e-05,
        -1.3082e-03, -8.0827e-05, -3.6479e-05, -3.5567e-01, -1.2307e+00,
        -4.1341e-02, -8.3447e-07, -5.2367e-03, -2.2918e-03, -8.3092e-05,
        -2.3842e-06, -3.3624e-01, -1.3362e+00, -3.5316e-02, -4.7684e-07,
         0.0000e+00,  0.0000e+00], device='cuda:0')


new_candidates
torch.Size([22, 24])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967,  2533,  2415, 29915],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           278,  3171,  2038,  7205],
        [    1, 3201


new_candidate_logprobs
torch.Size([22])


tensor([-2.3184, -3.4436, -3.5684, -4.6561, -4.0804, -4.8213, -5.2878, -6.4127,
        -6.5210, -7.3960, -3.9521, -6.2855, -4.4897, -4.8013, -5.4422, -6.3961,
        -2.9670, -3.9670, -3.4163, -3.8808, -4.6604, -3.9448], device='cuda:0')

infer end: GPU memory used: 16021 MB.
event: level
id: 12
data: [{"content": "s", "parent": 0, "prob": -2.3184127807617188}, {"content": "ation", "parent": 1, "prob": -3.443592071533203}, {"content": "sea", "parent": 2, "prob": -3.5684194564819336}, {"content": "'", "parent": 3, "prob": -4.6560797691345215}, {"content": "sea", "parent": 4, "prob": -4.080397129058838}, {"content": "'", "parent": 5, "prob": -4.821329116821289}, {"content": "s", "parent": 6, "prob": -5.287790775299072}, {"content": "sea", "parent": 7, "prob": -6.412746906280518}, {"content": "of", "parent": 8, "prob": -6.5209503173828125}, {"content": "above", "parent": 8, "prob": -7.3959503173828125}, {"content": "above", "parent": 9, "prob": -3.952083110809326}, {"content": "level", "parent": 10, "prob": -6.285523414611816}, {"content": "in", "parent": 11, "prob": -4.4896559715271}, {"content": "", "parent": 12, "prob": -4.801290512084961}, {"content": "Mount", "parent": 13, "prob": -5.442201614379883}, {"content": "Eve

array([[-2.171875  , -1.078125  ,  0.7109375 , ...,  1.59375   ,
        -1.7265625 , -0.9375    ],
       [-1.765625  , -0.86328125,  3.4375    , ...,  0.6640625 ,
        -0.28320312, -0.53515625],
       [ 1.8671875 , -1.15625   ,  0.7734375 , ...,  1.71875   ,
        -1.6640625 ,  2.296875  ],
       ...,
       [ 1.953125  , -1.4765625 ,  2.296875  , ...,  1.3046875 ,
        -0.3984375 , -1.265625  ],
       [-2.15625   , -2.296875  ,  1.3203125 , ...,  0.91796875,
        -0.9609375 , -1.8125    ],
       [-1.03125   , -0.64453125,  2.421875  , ..., -0.01409912,
        -0.21777344,  0.49609375]], dtype=float32)


k_mean_space
(20, 2)


array([[64.33764 , 90.532936],
       [66.66351 , 93.335464],
       [66.30068 , 99.19117 ],
       [83.19065 , 44.52203 ],
       [67.57986 , 99.54421 ],
       [83.31626 , 43.936   ],
       [65.188805, 89.89933 ],
       [66.52929 , 99.05127 ],
       [72.82356 , 89.345795],
       [65.82837 , 92.90611 ],
       [64.264946, 93.41304 ],
       [63.074104, 90.660805],
       [82.92864 , 97.509865],
       [84.79663 , 97.58915 ],
       [87.535126, 63.116627],
       [90.867455, 74.52381 ],
       [67.52167 , 92.527596],
       [77.972534, 96.14137 ],
       [62.988106, 92.93969 ],
       [63.534187, 91.76606 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-72.7879765 , -21.31573439])


closest
(2,)


array([18,  5])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.1826,  0.8203, -2.8594,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0781, -3.5469, -2.9219,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.2812, -6.4688, -7.2188,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.9180, -3.3125, -0.3145,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.6094, -2.0312, -9.3750,  ...,  0.0000,  0.0000,  0.0000],
        [-0.5938, -4.6562, -6.7188,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[6.2059e-01, 3.3218e-01, 4.4955e-02,  ..., 3.4980e-21, 2.7243e-21,
         1.6524e-21],
        [9.7530e-01, 1.5764e-02, 7.4465e-03,  ..., 5.1134e-22, 4.5126e-22,
         2.1316e-22],
        [1.0000e+00, 2.6996e-07, 6.0236e-08,  ..., 1.4154e-28, 1.1023e-28,
         2.4595e-29],
        ...,
        [9.9999e-01, 1.3007e-05, 1.1253e-07,  ..., 1.2330e-23, 1.0881e-23,
         6.5997e-24],
        [9.9994e-01, 5.1442e-05, 1.9946e-06,  ..., 8.4738e-24, 3.1174e-24,
         3.1174e-24],
        [9.9993e-01, 4.0063e-05, 7.8888e-06,  ..., 1.8908e-24, 4.7806e-25,
         4.7806e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.6206, 0.9528, 0.9977,  ..., 1.0000, 1.0000, 1.0000],
        [0.9753, 0.9911, 0.9985,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  0,  1,  2,  3,  4,  5,  6,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 24])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967,  2533,  2415, 29915],
        [    1, 3201


carryover_candidate_logprobs
torch.Size([22])


tensor([-2.3184, -2.3184, -3.4436, -3.5684, -4.6561, -4.0804, -4.8213, -5.2878,
        -5.2878, -6.4127, -6.5210, -7.3960, -3.9521, -6.2855, -4.4897, -4.8013,
        -5.4422, -6.3961, -2.9670, -3.9670, -3.4163, -3.8808], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[ 3171],
        [11858],
        [ 2038],
        [ 3233],
        [29879],
        [ 3233],
        [29879],
        [ 3171],
        [11858],
        [ 3233],
        [  967],
        [ 7205],
        [ 7205],
        [29892],
        [29871],
        [29906],
        [18274],
        [  342],
        [ 2038],
        [  362],
        [ 7205],
        [29892]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([-4.7709e-01, -1.1021e+00, -2.5009e-02, -3.5763e-07,  0.0000e+00,
        -2.0266e-06,  0.0000e+00, -5.4489e-01, -1.0449e+00, -1.5497e-06,
        -2.5988e-05, -7.9158e-05, -7.4866e-05, -2.9564e-05, -2.6584e-05,
         0.0000e+00, -5.7221e-06, -1.0133e-05, -4.4743e-04, -1.3352e-05,
        -5.8295e-05, -6.6759e-05], device='cuda:0')


new_candidates
torch.Size([22, 25])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967,  2533,  2415, 


new_candidate_logprobs
torch.Size([22])


tensor([-2.7955, -3.4205, -3.4686, -3.5684, -4.6561, -4.0804, -4.8213, -5.8327,
        -6.3327, -6.4127, -6.5210, -7.3960, -3.9522, -6.2856, -4.4897, -4.8013,
        -5.4422, -6.3961, -2.9675, -3.9670, -3.4164, -3.8809], device='cuda:0')

infer end: GPU memory used: 16081 MB.
event: level
id: 13
data: [{"content": "height", "parent": 0, "prob": -2.795501232147217}, {"content": "elev", "parent": 0, "prob": -3.4205009937286377}, {"content": "above", "parent": 1, "prob": -3.4686009883880615}, {"content": "level", "parent": 2, "prob": -3.568419933319092}, {"content": "s", "parent": 3, "prob": -4.6560797691345215}, {"content": "level", "parent": 4, "prob": -4.080399036407471}, {"content": "s", "parent": 5, "prob": -4.821329116821289}, {"content": "height", "parent": 6, "prob": -5.832678318023682}, {"content": "elev", "parent": 6, "prob": -6.332678318023682}, {"content": "level", "parent": 7, "prob": -6.412748336791992}, {"content": "its", "parent": 8, "prob": -6.520976543426514}, {"content": "sea", "parent": 9, "prob": -7.396029472351074}, {"content": "sea", "parent": 10, "prob": -3.952157974243164}, {"content": ",", "parent": 11, "prob": -6.285552978515625}, {"content": "", "parent": 12, "prob": -4.489682674407959}, {"conte

array([[-0.55078125, -0.875     ,  1.8828125 , ...,  1.5546875 ,
        -1.15625   , -1.078125  ],
       [ 2.28125   , -1.46875   ,  2.03125   , ...,  1.2578125 ,
        -0.34179688, -1.1640625 ],
       [-2.359375  , -2.234375  ,  0.87109375, ...,  0.765625  ,
        -1.3984375 , -2.015625  ],
       ...,
       [ 0.65625   ,  1.0859375 ,  0.25390625, ...,  2.265625  ,
         2.8125    ,  2.484375  ],
       [-2.359375  , -1.5703125 ,  0.59375   , ...,  0.20605469,
        -1.6171875 , -2.453125  ],
       [-1.0625    , -1.5234375 ,  2.28125   , ...,  0.71484375,
        -0.625     ,  0.15332031]], dtype=float32)


k_mean_space
(20, 2)


array([[60.773357, 81.3287  ],
       [70.42866 , 91.990486],
       [68.90856 , 84.35555 ],
       [70.56264 , 47.265945],
       [63.19098 , 84.68183 ],
       [69.50292 , 50.048866],
       [62.295574, 84.9526  ],
       [60.79238 , 82.21347 ],
       [70.26658 , 92.31546 ],
       [70.69095 , 48.994755],
       [68.72219 , 88.33718 ],
       [83.715   , 60.44298 ],
       [83.84033 , 60.808823],
       [74.74814 , 75.9135  ],
       [89.35571 , 83.6034  ],
       [79.68253 , 93.375305],
       [84.64626 , 96.857704],
       [73.84807 , 84.98984 ],
       [67.82317 , 83.62964 ],
       [62.659344, 76.034454]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-67.70804715, -29.89943743])


closest
(2,)


array([0, 3])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.9297, -0.1836, -4.7188,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.1016, -3.9844,  0.3457,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.5469, -2.5781, -9.5000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.9297, -5.7188, -6.9375,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.4062, -4.4688, -9.3125,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.5781, -0.3496, -1.9219,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9444e-01, 5.2184e-03, 1.3906e-04,  ..., 1.6059e-21, 1.4172e-21,
         9.7406e-22],
        [9.9999e-01, 1.1479e-05, 1.2752e-07,  ..., 1.5832e-23, 7.4785e-24,
         6.5997e-24],
        [9.9988e-01, 1.0890e-04, 2.2601e-06,  ..., 2.0327e-23, 2.0327e-23,
         1.3970e-23],
        ...,
        [9.9968e-01, 2.6117e-04, 4.5385e-05,  ..., 4.8751e-23, 4.3022e-23,
         9.5996e-24],
        [9.9994e-01, 5.1442e-05, 8.9392e-06,  ..., 2.6101e-23, 2.3034e-23,
         1.2329e-23],
        [9.9422e-01, 3.1644e-03, 2.1748e-03,  ..., 1.7836e-23, 7.4353e-24,
         7.8368e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9944, 0.9997, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9997, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9942, 0.9974, 0.9996,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([25])


tensor([ 0,  1,  2,  3,  4,  4,  4,  5,  6,  6,  6,  7,  8,  9, 10, 10, 11, 12,
        13, 14, 15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([25, 25])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967,  2533,  2415, 


carryover_candidate_logprobs
torch.Size([25])


tensor([-2.7955, -3.4205, -3.4686, -3.5684, -4.6561, -4.6561, -4.6561, -4.0804,
        -4.8213, -4.8213, -4.8213, -5.8327, -6.3327, -6.4127, -6.5210, -6.5210,
        -7.3960, -3.9522, -6.2856, -4.4897, -4.8013, -5.4422, -6.3961, -2.9675,
        -3.9670], device='cuda:0')


new_candidate_toks
torch.Size([25, 1])


tensor([[ 2038],
        [  362],
        [ 7205],
        [29892],
        [11858],
        [ 3171],
        [ 5272],
        [29892],
        [11858],
        [ 3171],
        [ 5272],
        [ 2038],
        [  362],
        [29892],
        [19224],
        [ 2533],
        [ 3233],
        [ 3233],
        [  338],
        [29906],
        [29900],
        [  342],
        [29889],
        [ 7205],
        [ 2038]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([25])


tensor([-5.5727e-03, -1.1802e-05, -1.1624e-04, -6.0797e-06, -4.9857e-01,
        -1.2486e+00, -2.3736e+00, -9.1795e-05, -5.7206e-01, -1.4471e+00,
        -1.6971e+00, -5.5188e-03, -2.6226e-05, -9.0246e-05, -5.3333e-01,
        -1.1583e+00, -2.1458e-06, -2.8610e-06,  0.0000e+00,  0.0000e+00,
         0.0000e+00, -7.3910e-06, -3.2454e-04, -6.4494e-05, -5.7993e-03],
       device='cuda:0')


new_candidates
torch.Size([25, 26])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
  


new_candidate_logprobs
torch.Size([25])


tensor([-2.8011, -3.4205, -3.4687, -3.5684, -5.1546, -5.9046, -7.0296, -4.0805,
        -5.3934, -6.2684, -6.5184, -5.8382, -6.3327, -6.4128, -7.0543, -7.6793,
        -7.3960, -3.9522, -6.2856, -4.4897, -4.8013, -5.4422, -6.3965, -2.9675,
        -3.9728], device='cuda:0')

infer end: GPU memory used: 16143 MB.
event: level
id: 14
data: [{"content": "above", "parent": 0, "prob": -2.8010740280151367}, {"content": "ation", "parent": 1, "prob": -3.4205129146575928}, {"content": "sea", "parent": 2, "prob": -3.468717336654663}, {"content": ",", "parent": 3, "prob": -3.5684261322021484}, {"content": "elev", "parent": 4, "prob": -5.154644966125488}, {"content": "height", "parent": 4, "prob": -5.904644966125488}, {"content": "alt", "parent": 4, "prob": -7.029644966125488}, {"content": ",", "parent": 5, "prob": -4.080491065979004}, {"content": "elev", "parent": 6, "prob": -5.393393516540527}, {"content": "height", "parent": 6, "prob": -6.268393516540527}, {"content": "alt", "parent": 6, "prob": -6.518393516540527}, {"content": "above", "parent": 7, "prob": -5.838197231292725}, {"content": "ation", "parent": 8, "prob": -6.332704544067383}, {"content": ",", "parent": 9, "prob": -6.412838459014893}, {"content": "peak", "parent": 10, "prob": -7.0543060302734375}, {"co

array([[-2.703125  , -1.296875  ,  0.41796875, ...,  0.01855469,
        -1.515625  , -2.28125   ],
       [-1.0234375 , -2.046875  ,  1.6484375 , ...,  1.4296875 ,
        -0.67578125, -0.28125   ],
       [ 2.09375   , -1.34375   ,  0.27148438, ...,  1.140625  ,
        -1.34375   ,  2.25      ],
       ...,
       [-1.3359375 , -2.078125  ,  1.390625  , ..., -0.34765625,
         0.55859375,  0.5       ],
       [ 0.84375   ,  0.40820312,  0.05371094, ...,  1.328125  ,
        -1.4375    , -0.7890625 ],
       [ 0.9765625 ,  0.        ,  2.        , ..., -1.3515625 ,
        -1.3046875 , -4.46875   ]], dtype=float32)


k_mean_space
(20, 2)


array([[87.55755 , 69.356674],
       [70.541374, 61.88573 ],
       [91.97991 , 78.919556],
       [90.74477 , 57.76753 ],
       [54.359833, 85.89207 ],
       [62.866653, 70.02063 ],
       [54.27445 , 86.828735],
       [90.74997 , 57.821526],
       [54.26709 , 85.80427 ],
       [61.98991 , 69.926315],
       [54.025368, 86.70563 ],
       [87.68707 , 69.731926],
       [69.556946, 63.379185],
       [90.88506 , 57.82705 ],
       [78.20322 , 63.928196],
       [78.47575 , 87.00762 ],
       [83.29239 , 54.697857],
       [83.9315  , 53.83229 ],
       [93.483315, 75.8771  ],
       [92.76551 , 80.4699  ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-43.94842148, -65.10069609])


closest
(2,)


array([10, 17])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 2.2031, -4.8750, -9.7500,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.6367, -1.6484, -1.4453,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.1875, -6.4688, -6.0625,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.7305, -3.2969, -4.0625,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.6875, -2.1562, -4.7188,  ...,  0.0000,  0.0000,  0.0000],
        [-2.4844, -1.3125, -4.7500,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9992e-01, 7.4846e-05, 5.4218e-06,  ..., 6.2612e-23, 6.2612e-23,
         2.9576e-23],
        [9.9308e-01, 5.2112e-03, 7.9917e-04,  ..., 1.3164e-22, 1.3164e-22,
         4.8429e-23],
        [1.0000e+00, 1.9947e-06, 2.8453e-08,  ..., 9.7276e-29, 7.5759e-29,
         5.9001e-29],
        ...,
        [9.9999e-01, 1.0130e-05, 7.3381e-07,  ..., 6.1999e-24, 3.5326e-24,
         3.1175e-24],
        [9.9993e-01, 5.8291e-05, 2.5611e-06,  ..., 3.0169e-21, 1.6148e-21,
         9.7944e-22],
        [1.0000e+00, 1.7832e-11, 3.9790e-12,  ..., 2.3803e-26, 1.8538e-26,
         9.2293e-28]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9931, 0.9983, 0.9991,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 26])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
  


carryover_candidate_logprobs
torch.Size([20])


tensor([-2.8011, -3.4205, -3.4687, -3.5684, -5.1546, -5.9046, -7.0296, -4.0805,
        -5.3934, -6.2684, -6.5184, -5.8382, -6.3327, -6.4128, -7.0543, -7.6793,
        -7.3960, -3.9522, -6.2856, -4.4897], device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[ 7205],
        [ 2038],
        [ 3233],
        [  338],
        [  362],
        [ 2038],
        [ 4279],
        [  338],
        [  362],
        [ 2038],
        [ 4279],
        [ 7205],
        [ 2038],
        [  338],
        [ 2038],
        [ 2415],
        [29892],
        [29892],
        [ 8040],
        [29900]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([-8.2973e-05, -6.9413e-03, -2.0266e-06,  0.0000e+00, -2.0385e-05,
        -3.1343e-02, -1.0729e-06,  0.0000e+00, -3.1233e-05, -1.2201e-02,
        -1.3113e-06, -4.0651e-05, -2.8050e-02,  0.0000e+00, -4.6015e-02,
         0.0000e+00, -8.1542e-05, -1.1802e-05, -6.5448e-05,  0.0000e+00],
       device='cuda:0')


new_candidates
torch.Size([20, 27])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29


new_candidate_logprobs
torch.Size([20])


tensor([-2.8012, -3.4275, -3.4687, -3.5684, -5.1547, -5.9360, -7.0296, -4.0805,
        -5.3934, -6.2806, -6.5184, -5.8382, -6.3608, -6.4128, -7.1003, -7.6793,
        -7.3961, -3.9522, -6.2856, -4.4897], device='cuda:0')

infer end: GPU memory used: 16207 MB.
event: level
id: 15
data: [{"content": "sea", "parent": 0, "prob": -2.801156997680664}, {"content": "above", "parent": 1, "prob": -3.4274542331695557}, {"content": "level", "parent": 2, "prob": -3.468719482421875}, {"content": "is", "parent": 3, "prob": -3.5684261322021484}, {"content": "ation", "parent": 4, "prob": -5.154665470123291}, {"content": "above", "parent": 5, "prob": -5.93598747253418}, {"content": "itude", "parent": 6, "prob": -7.029645919799805}, {"content": "is", "parent": 7, "prob": -4.080491065979004}, {"content": "ation", "parent": 8, "prob": -5.393424987792969}, {"content": "above", "parent": 9, "prob": -6.280594825744629}, {"content": "itude", "parent": 10, "prob": -6.518394947052002}, {"content": "sea", "parent": 11, "prob": -5.838237762451172}, {"content": "above", "parent": 12, "prob": -6.360754013061523}, {"content": "is", "parent": 13, "prob": -6.412838459014893}, {"content": "above", "parent": 14, "prob": -7.100320816040039

array([[ 2.34375   , -0.65234375,  1.328125  , ...,  1.5546875 ,
        -1.53125   ,  2.453125  ],
       [-2.921875  , -1.9609375 ,  0.51171875, ...,  0.36132812,
        -1.3984375 , -1.984375  ],
       [-1.21875   , -1.6875    ,  2.015625  , ..., -0.54296875,
         0.625     ,  0.28515625],
       ...,
       [-1.1171875 ,  0.76953125, -1.4453125 , ..., -0.00927734,
        -0.59375   ,  0.8984375 ],
       [ 0.36523438,  1.796875  ,  1.4140625 , ...,  1.828125  ,
        -0.28320312, -2.71875   ],
       [ 0.3203125 ,  2.046875  ,  1.        , ..., -0.59375   ,
         1.546875  , -0.296875  ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 73.84493 , 101.21751 ],
       [ 57.546036,  95.36855 ],
       [ 61.70305 ,  94.15863 ],
       [ 87.0548  ,  31.047789],
       [ 59.246727,  98.12012 ],
       [ 57.4699  ,  94.58    ],
       [ 59.366417,  97.751686],
       [ 86.68634 ,  31.339916],
       [ 59.61545 ,  98.44553 ],
       [ 57.769245,  94.43616 ],
       [ 59.644627,  98.04832 ],
       [ 73.66572 , 101.050156],
       [ 57.88315 ,  95.58581 ],
       [ 87.00747 ,  30.786615],
       [ 59.1855  ,  96.13044 ],
       [ 66.088615,  95.45535 ],
       [ 71.24657 ,  91.532265],
       [ 71.261345,  92.13603 ],
       [ 91.00093 ,  56.36619 ],
       [ 95.51597 ,  87.606766]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-84.33694911, -24.83705664])


closest
(2,)


array([ 5, 13])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 3.1875, -8.3750, -5.6562,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.7344, -5.1250, -9.4375,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3926, -4.2188, -4.9688,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.2344, -3.7188, -4.8750,  ...,  0.0000,  0.0000,  0.0000],
        [ 8.1875, -0.1260, -0.7109,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.8594, -2.9531, -0.5664,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 2.5613e-06, 7.7344e-08,  ..., 2.6442e-28, 1.6038e-28,
         1.1023e-28],
        [9.9991e-01, 8.4811e-05, 2.5611e-06,  ..., 1.0323e-22, 6.2612e-23,
         5.5255e-23],
        [9.9999e-01, 8.9397e-06, 3.0590e-07,  ..., 2.4279e-24, 2.1426e-24,
         1.0121e-24],
        ...,
        [1.0000e+00, 4.9445e-09, 2.6466e-09,  ..., 1.5521e-25, 8.3079e-26,
         7.3317e-26],
        [9.9999e-01, 4.2228e-06, 7.3382e-07,  ..., 6.7320e-22, 5.9409e-22,
         4.0831e-22],
        [1.0000e+00, 3.7751e-11, 3.7751e-11,  ..., 7.7276e-27, 1.0458e-27,
         9.2293e-28]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15, 15,
        16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 27])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29


carryover_candidate_logprobs
torch.Size([22])


tensor([-2.8012, -3.4275, -3.4687, -3.5684, -5.1547, -5.9360, -7.0296, -4.0805,
        -5.3934, -6.2806, -6.5184, -5.8382, -6.3608, -6.4128, -7.1003, -7.6793,
        -7.6793, -7.6793, -7.3961, -3.9522, -6.2856, -4.4897], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[ 3233],
        [ 7205],
        [29892],
        [ 8040],
        [ 2038],
        [ 7205],
        [ 2038],
        [ 8040],
        [ 2038],
        [ 7205],
        [ 2038],
        [ 3233],
        [ 7205],
        [ 8040],
        [ 7205],
        [ 2038],
        [29892],
        [  515],
        [  338],
        [  338],
        [18274],
        [29906]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([-2.7418e-06, -9.0722e-05, -9.8944e-06, -6.6163e-05, -2.0787e-02,
        -1.8580e-04, -2.7305e-02, -7.2005e-05, -5.9636e-02, -1.0801e-04,
        -2.7153e-02, -1.3113e-06, -1.6989e-04, -5.7222e-05, -8.7623e-05,
        -8.0659e-01, -8.0659e-01, -2.3066e+00,  0.0000e+00,  0.0000e+00,
        -5.8413e-06,  0.0000e+00], device='cuda:0')


new_candidates
torch.Size([22, 28])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  8040],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9


new_candidate_logprobs
torch.Size([22])


tensor([-2.8012, -3.4275, -3.4687, -3.5685, -5.1755, -5.9362, -7.0570, -4.0806,
        -5.4531, -6.2807, -6.5455, -5.8382, -6.3609, -6.4129, -7.1004, -8.4859,
        -8.4859, -9.9859, -7.3961, -3.9522, -6.2856, -4.4897], device='cuda:0')

infer end: GPU memory used: 16275 MB.
event: level
id: 16
data: [{"content": "level", "parent": 0, "prob": -2.8011598587036133}, {"content": "sea", "parent": 1, "prob": -3.4275450706481934}, {"content": ",", "parent": 2, "prob": -3.4687294960021973}, {"content": "Mount", "parent": 3, "prob": -3.5684924125671387}, {"content": "above", "parent": 4, "prob": -5.175452709197998}, {"content": "sea", "parent": 5, "prob": -5.936173439025879}, {"content": "above", "parent": 6, "prob": -7.056951522827148}, {"content": "Mount", "parent": 7, "prob": -4.080563068389893}, {"content": "above", "parent": 8, "prob": -5.453061580657959}, {"content": "sea", "parent": 9, "prob": -6.280703067779541}, {"content": "above", "parent": 10, "prob": -6.545548439025879}, {"content": "level", "parent": 11, "prob": -5.8382391929626465}, {"content": "sea", "parent": 12, "prob": -6.360923767089844}, {"content": "Mount", "parent": 13, "prob": -6.412895679473877}, {"content": "sea", "parent": 14, "prob": -7.100408554077

array([[-0.84375   , -1.5390625 ,  2.34375   , ..., -0.11865234,
         0.56640625,  0.703125  ],
       [ 2.078125  , -1.2734375 ,  0.96484375, ...,  1.3046875 ,
        -1.4140625 ,  2.1875    ],
       [-1.109375  ,  0.71875   , -1.4609375 , ...,  0.14550781,
        -0.52734375,  0.80859375],
       ...,
       [-1.7109375 , -1.59375   ,  2.078125  , ..., -1.375     ,
         0.4765625 , -0.3828125 ],
       [ 0.84375   ,  0.38085938, -0.12451172, ...,  0.8203125 ,
        -1.640625  , -0.5859375 ],
       [ 1.0078125 ,  0.47265625,  0.10009766, ...,  1.109375  ,
        -1.3984375 , -0.63671875]], dtype=float32)


k_mean_space
(20, 2)


array([[ 76.86075 ,  68.51397 ],
       [ 89.27525 ,  30.119253],
       [ 76.68603 ,  89.60724 ],
       [ 65.64901 , 104.69057 ],
       [ 53.04538 ,  90.70579 ],
       [ 88.66759 ,  29.643015],
       [ 53.04711 ,  91.47436 ],
       [ 65.46593 , 104.71868 ],
       [ 53.178432,  91.27894 ],
       [ 88.37614 ,  30.024246],
       [ 52.76295 ,  91.82493 ],
       [ 76.473   ,  68.09911 ],
       [ 89.10831 ,  30.354155],
       [ 65.713104, 104.716095],
       [ 89.51196 ,  33.41116 ],
       [ 54.58307 ,  91.9938  ],
       [ 75.03917 ,  89.22654 ],
       [ 64.202126,  93.51482 ],
       [ 64.3659  , 102.174385],
       [ 64.364044, 102.09687 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-80.06767225, -37.74515295])


closest
(2,)


array([10,  5])


last_tok_logits
torch.Size([20, 32064])


tensor([[-1.2734, -5.2188, -6.3750,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.4375, -6.4688, -5.7500,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.7188, -4.5625, -6.2812,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0776, -7.4062, -8.9375,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.8750, -1.6641, -4.9375,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.3125, -2.1250, -4.5312,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9995e-01, 5.1442e-05, 1.1253e-07,  ..., 2.7511e-24, 1.8908e-24,
         1.6686e-24],
        [1.0000e+00, 1.7603e-06, 4.6912e-08,  ..., 1.2490e-28, 1.1023e-28,
         6.6857e-29],
        [1.0000e+00, 2.9990e-09, 7.5826e-10,  ..., 1.0668e-25, 3.4633e-26,
         3.0563e-26],
        ...,
        [9.9971e-01, 1.5842e-04, 8.4794e-05,  ..., 2.8055e-22, 2.4759e-22,
         8.4719e-24],
        [9.9986e-01, 1.2339e-04, 6.1434e-06,  ..., 2.6622e-21, 1.4250e-21,
         9.7936e-22],
        [9.9991e-01, 8.4810e-05, 3.7263e-06,  ..., 3.4185e-21, 1.8298e-21,
         1.8298e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9997, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 28])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  8040],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9


carryover_candidate_logprobs
torch.Size([20])


tensor([-2.8012, -3.4275, -3.4687, -3.5685, -5.1755, -5.9362, -7.0570, -4.0806,
        -5.4531, -6.2807, -6.5455, -5.8382, -6.3609, -6.4129, -7.1004, -8.4859,
        -8.4859, -9.9859, -7.3961, -3.9522], device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[29892],
        [ 3233],
        [  338],
        [18274],
        [ 7205],
        [ 3233],
        [ 7205],
        [18274],
        [ 7205],
        [ 3233],
        [ 7205],
        [29892],
        [ 3233],
        [18274],
        [ 3233],
        [ 7205],
        [  338],
        [ 7205],
        [ 8040],
        [ 8040]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([-5.1857e-05, -1.7881e-06,  0.0000e+00, -8.2255e-06, -1.3245e-04,
        -3.2187e-06, -3.8417e-04, -3.3379e-06, -1.5415e-04, -1.6689e-06,
        -1.6357e-04, -9.6803e-05, -1.7881e-06, -5.9605e-06, -4.7684e-07,
        -1.6941e-04, -6.3004e-04, -2.8644e-04, -1.3853e-04, -9.4180e-05],
       device='cuda:0')


new_candidates
torch.Size([20, 29])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205,  3233],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892,   338],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  8040, 18274],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 3200


new_candidate_logprobs
torch.Size([20])


tensor([-2.8012, -3.4275, -3.4687, -3.5685, -5.1756, -5.9362, -7.0573, -4.0806,
        -5.4532, -6.2807, -6.5457, -5.8383, -6.3609, -6.4129, -7.1004, -8.4861,
        -8.4865, -9.9862, -7.3963, -3.9523], device='cuda:0')

infer end: GPU memory used: 16345 MB.
event: level
id: 17
data: [{"content": ",", "parent": 0, "prob": -2.8012118339538574}, {"content": "level", "parent": 1, "prob": -3.427546977996826}, {"content": "is", "parent": 2, "prob": -3.4687294960021973}, {"content": "Ever", "parent": 3, "prob": -3.5685007572174072}, {"content": "sea", "parent": 4, "prob": -5.1755852699279785}, {"content": "level", "parent": 5, "prob": -5.936176776885986}, {"content": "sea", "parent": 6, "prob": -7.05733585357666}, {"content": "Ever", "parent": 7, "prob": -4.08056640625}, {"content": "sea", "parent": 8, "prob": -5.453215599060059}, {"content": "level", "parent": 9, "prob": -6.280704975128174}, {"content": "sea", "parent": 10, "prob": -6.545711994171143}, {"content": ",", "parent": 11, "prob": -5.838335990905762}, {"content": "level", "parent": 12, "prob": -6.360925674438477}, {"content": "Ever", "parent": 13, "prob": -6.412901878356934}, {"content": "level", "parent": 14, "prob": -7.100409030914307}, {"conten

array([[-1.0078125 ,  0.93359375, -1.359375  , ...,  0.3671875 ,
        -0.36328125,  0.9140625 ],
       [-0.8203125 , -1.921875  ,  2.125     , ..., -0.14550781,
         0.875     ,  0.45703125],
       [ 1.        ,  0.27148438,  0.12158203, ...,  0.92578125,
        -1.375     , -0.63671875],
       ...,
       [ 0.60546875, -0.6796875 ,  2.328125  , ...,  0.35546875,
        -2.109375  ,  3.28125   ],
       [ 0.3125    ,  1.8203125 ,  1.375     , ...,  1.4921875 ,
        -0.65234375, -2.859375  ],
       [ 0.31054688,  1.9375    ,  1.46875   , ...,  1.8515625 ,
        -0.33789062, -2.84375   ]], dtype=float32)


k_mean_space
(20, 2)


array([[106.35795 ,  62.797592],
       [ 95.87268 ,  53.782043],
       [111.92684 ,  67.80514 ],
       [114.133156,  74.64795 ],
       [ 15.420188,  86.97797 ],
       [ 95.93352 ,  53.370987],
       [ 13.988643,  86.622986],
       [114.06879 ,  74.67621 ],
       [ 15.305088,  87.00738 ],
       [ 95.4317  ,  52.983444],
       [ 13.644617,  86.719666],
       [106.33786 ,  62.662918],
       [ 95.31118 ,  53.367344],
       [114.13881 ,  74.745445],
       [ 94.6712  ,  54.998135],
       [ 16.682106,  87.723946],
       [112.09856 ,  68.5229  ],
       [ 45.371502,  89.49208 ],
       [113.62226 ,  71.63664 ],
       [113.68181 ,  71.69305 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-42.7040987 , -75.11105633])


closest
(2,)


array([10,  9])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.2969, -4.0625, -6.5000,  ...,  0.0000,  0.0000,  0.0000],
        [-1.6094, -5.1875, -5.1875,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.2812, -1.9609, -4.7812,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 3.5469, -6.1250, -2.5312,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.8125,  0.0479, -0.5781,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.7188,  0.2637, -0.5352,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 4.3635e-09, 1.1033e-09,  ..., 1.9930e-25, 1.2088e-25,
         4.4469e-26],
        [9.9997e-01, 2.7536e-05, 1.8553e-07,  ..., 3.5325e-24, 2.4279e-24,
         1.6686e-24],
        [9.9990e-01, 8.4810e-05, 3.7263e-06,  ..., 2.0734e-21, 1.8298e-21,
         1.6148e-21],
        ...,
        [1.0000e+00, 1.6374e-07, 1.6374e-07,  ..., 2.3803e-26, 1.8538e-26,
         1.1244e-26],
        [1.0000e+00, 1.2099e-06, 9.4224e-07,  ..., 4.6268e-22, 2.1856e-22,
         1.9287e-22],
        [1.0000e+00, 2.5613e-06, 9.4224e-07,  ..., 7.6283e-22, 6.7320e-22,
         6.7320e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 29])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205,  3233],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892,   338],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  8040, 18274],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 3200


carryover_candidate_logprobs
torch.Size([20])


tensor([-2.8012, -3.4275, -3.4687, -3.5685, -5.1756, -5.9362, -7.0573, -4.0806,
        -5.4532, -6.2807, -6.5457, -5.8383, -6.3609, -6.4129, -7.1004, -8.4861,
        -8.4865, -9.9862, -7.3963, -3.9523], device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[  338],
        [29892],
        [ 8040],
        [  342],
        [ 3233],
        [29892],
        [ 3233],
        [  342],
        [ 3233],
        [29892],
        [ 3233],
        [  338],
        [29892],
        [  342],
        [29892],
        [ 3233],
        [ 8040],
        [ 3233],
        [18274],
        [18274]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([ 0.0000e+00, -2.8015e-05, -9.5014e-05, -7.2718e-06, -2.3842e-06,
        -8.5000e-05, -2.5034e-06, -7.2718e-06, -1.9074e-06, -8.5119e-05,
        -1.5497e-06,  0.0000e+00, -4.6254e-05, -6.4373e-06, -3.0160e-05,
        -8.3447e-07, -1.4545e-04, -2.3842e-07, -3.0994e-06, -4.6492e-06],
       device='cuda:0')


new_candidates
torch.Size([20, 30])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205,  3233, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892,   338,  8040],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  8040, 18274,   342],
        [    1, 32010,  1724,   338,   278,  993


new_candidate_logprobs
torch.Size([20])


tensor([-2.8012, -3.4276, -3.4688, -3.5685, -5.1756, -5.9363, -7.0573, -4.0806,
        -5.4532, -6.2808, -6.5457, -5.8383, -6.3610, -6.4129, -7.1004, -8.4861,
        -8.4867, -9.9862, -7.3963, -3.9523], device='cuda:0')

infer end: GPU memory used: 16417 MB.
event: level
id: 18
data: [{"content": "is", "parent": 0, "prob": -2.8012118339538574}, {"content": ",", "parent": 1, "prob": -3.42757511138916}, {"content": "Mount", "parent": 2, "prob": -3.468824625015259}, {"content": "est", "parent": 3, "prob": -3.5685081481933594}, {"content": "level", "parent": 4, "prob": -5.1755876541137695}, {"content": ",", "parent": 5, "prob": -5.9362616539001465}, {"content": "level", "parent": 6, "prob": -7.057338237762451}, {"content": "est", "parent": 7, "prob": -4.080573558807373}, {"content": "level", "parent": 8, "prob": -5.453217506408691}, {"content": ",", "parent": 9, "prob": -6.280790328979492}, {"content": "level", "parent": 10, "prob": -6.545713424682617}, {"content": "is", "parent": 11, "prob": -5.838335990905762}, {"content": ",", "parent": 12, "prob": -6.360971927642822}, {"content": "est", "parent": 13, "prob": -6.412908554077148}, {"content": ",", "parent": 14, "prob": -7.100439071655273}, {"content": "l

array([[ 1.0703125 ,  0.34375   ,  0.02490234, ...,  0.953125  ,
        -1.4921875 , -0.6953125 ],
       [-1.140625  ,  0.73828125, -1.4765625 , ...,  0.35742188,
        -0.40234375,  0.73046875],
       [ 0.20507812,  1.796875  ,  1.4375    , ...,  1.7578125 ,
        -0.41015625, -2.78125   ],
       ...,
       [-0.40820312, -2.640625  ,  1.828125  , ..., -0.4921875 ,
         0.43164062,  2.34375   ],
       [-1.2890625 ,  0.8046875 ,  2.40625   , ..., -0.89453125,
        -0.33984375,  2.078125  ],
       [-1.296875  ,  0.80859375,  2.25      , ..., -0.83203125,
        -0.33203125,  2.25      ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 88.66526 ,  50.75498 ],
       [ 45.919212,  92.360664],
       [ 95.142586,  49.982918],
       [ 68.991234,  87.14514 ],
       [ 42.8192  ,  94.48282 ],
       [ 46.89272 ,  89.095505],
       [ 42.293377,  94.37488 ],
       [ 68.96273 ,  86.91394 ],
       [ 42.441353,  94.081406],
       [ 47.21174 ,  89.06212 ],
       [ 42.02693 ,  93.96336 ],
       [ 88.95346 ,  50.597115],
       [ 46.045826,  92.3973  ],
       [ 69.06783 ,  86.87033 ],
       [ 46.50631 ,  92.3601  ],
       [ 42.476284,  93.876465],
       [ 95.41326 ,  50.32144 ],
       [ 51.338486,  97.19077 ],
       [100.65658 ,  67.7804  ],
       [100.53791 ,  67.88595 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-85.87213612, -31.943573  ])


closest
(2,)


array([10,  2])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.4062, -1.7109, -4.7188,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.2266, -4.3750, -6.3438,  ...,  0.0000,  0.0000,  0.0000],
        [ 8.0625,  0.1543, -0.2080,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.7891, -4.1562, -7.5938,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.8125, -3.6562, -4.7812,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.7734, -3.5000, -4.3438,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9988e-01, 1.0890e-04, 3.7262e-06,  ..., 3.4184e-21, 2.0733e-21,
         1.4250e-21],
        [1.0000e+00, 3.8507e-09, 1.2502e-09,  ..., 2.2583e-25, 1.2088e-25,
         3.0563e-26],
        [9.9999e-01, 4.2228e-06, 1.0677e-06,  ..., 1.1099e-21, 7.6283e-22,
         7.6283e-22],
        ...,
        [9.9994e-01, 5.1442e-05, 6.9619e-06,  ..., 2.0328e-23, 3.1173e-24,
         7.8819e-25],
        [9.9999e-01, 6.9622e-06, 1.3709e-06,  ..., 1.0881e-23, 6.5998e-24,
         4.0030e-24],
        [9.9999e-01, 6.1442e-06, 7.3382e-07,  ..., 1.0881e-23, 6.5998e-24,
         3.5326e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 30])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205,  3233, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892,   338,  8040],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  8040, 18274,   342],
        [    1, 32010,  1724,   338,   278,  993


carryover_candidate_logprobs
torch.Size([20])


tensor([-2.8012, -3.4276, -3.4688, -3.5685, -5.1756, -5.9363, -7.0573, -4.0806,
        -5.4532, -6.2808, -6.5457, -5.8383, -6.3610, -6.4129, -7.1004, -8.4861,
        -8.4867, -9.9862, -7.3963, -3.9523], device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[ 8040],
        [  338],
        [18274],
        [29889],
        [29892],
        [  338],
        [29892],
        [29889],
        [29892],
        [  338],
        [29892],
        [ 8040],
        [  338],
        [29889],
        [  338],
        [29892],
        [18274],
        [29892],
        [  342],
        [  342]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([-1.2172e-04,  0.0000e+00, -6.7950e-06, -2.5153e-05, -1.2994e-05,
         0.0000e+00, -2.7895e-05, -3.4810e-05, -1.9074e-05,  0.0000e+00,
        -4.5658e-05, -1.2196e-04,  0.0000e+00, -2.7776e-05,  0.0000e+00,
        -1.5974e-05, -4.1723e-06, -6.4017e-05, -8.5831e-06, -7.1526e-06],
       device='cuda:0')


new_candidates
torch.Size([20, 31])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205,  3233, 29892,
           338],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892,   338,  8040,
         18274],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  8040, 18274,   342,
 


new_candidate_logprobs
torch.Size([20])


tensor([-2.8013, -3.4276, -3.4688, -3.5685, -5.1756, -5.9363, -7.0574, -4.0806,
        -5.4532, -6.2808, -6.5458, -5.8385, -6.3610, -6.4129, -7.1004, -8.4861,
        -8.4867, -9.9862, -7.3963, -3.9523], device='cuda:0')

infer end: GPU memory used: 16491 MB.
event: level
id: 19
data: [{"content": "Mount", "parent": 0, "prob": -2.8013336658477783}, {"content": "is", "parent": 1, "prob": -3.42757511138916}, {"content": "Ever", "parent": 2, "prob": -3.4688315391540527}, {"content": ".", "parent": 3, "prob": -3.568533420562744}, {"content": ",", "parent": 4, "prob": -5.175600528717041}, {"content": "is", "parent": 5, "prob": -5.9362616539001465}, {"content": ",", "parent": 6, "prob": -7.057366371154785}, {"content": ".", "parent": 7, "prob": -4.080608367919922}, {"content": ",", "parent": 8, "prob": -5.4532365798950195}, {"content": "is", "parent": 9, "prob": -6.280790328979492}, {"content": ",", "parent": 10, "prob": -6.545759201049805}, {"content": "Mount", "parent": 11, "prob": -5.838458061218262}, {"content": "is", "parent": 12, "prob": -6.360971927642822}, {"content": ".", "parent": 13, "prob": -6.412936210632324}, {"content": "is", "parent": 14, "prob": -7.100439071655273}, {"content": ",", "parent":

array([[ 0.25      ,  1.78125   ,  1.3359375 , ...,  1.828125  ,
        -0.47070312, -2.6875    ],
       [ 1.0390625 ,  0.2734375 ,  0.0213623 , ...,  1.0625    ,
        -1.375     , -0.69921875],
       [-1.3203125 ,  0.8046875 ,  2.25      , ..., -1.        ,
        -0.34960938,  2.234375  ],
       ...,
       [-1.0703125 ,  0.68359375, -1.734375  , ...,  0.24707031,
         0.17285156,  1.2421875 ],
       [ 0.515625  ,  0.88671875, -0.29492188, ...,  2.46875   ,
         2.53125   ,  2.890625  ],
       [ 0.6171875 ,  0.98046875, -0.1875    , ...,  2.203125  ,
         2.296875  ,  3.03125   ]], dtype=float32)


k_mean_space
(20, 2)


array([[91.25027 , 54.47929 ],
       [82.901146, 34.23847 ],
       [99.47821 , 80.1759  ],
       [62.55421 , 89.10691 ],
       [40.067482, 88.23181 ],
       [82.41197 , 34.319633],
       [40.222836, 88.12033 ],
       [62.69413 , 89.17519 ],
       [39.78277 , 88.15085 ],
       [82.5073  , 34.06991 ],
       [40.061943, 87.94708 ],
       [91.39897 , 54.77702 ],
       [83.06154 , 34.30362 ],
       [62.909885, 89.24661 ],
       [82.81    , 34.473583],
       [42.326523, 91.752716],
       [99.65027 , 80.18983 ],
       [43.39331 , 91.69321 ],
       [66.98922 , 87.61697 ],
       [66.947044, 87.75988 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-68.11491466, -49.70133853])


closest
(2,)


array([8, 9])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 8.2500, -0.0162, -0.2637,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.1250, -2.1719, -4.6875,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.8203, -3.4375, -4.3438,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.6250, -2.0156, -5.2188,  ...,  0.0000,  0.0000,  0.0000],
        [-1.6875, -6.8438, -7.7812,  ...,  0.0000,  0.0000,  0.0000],
        [-1.1641, -6.5625, -7.6562,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9999e-01, 5.4222e-06, 8.3152e-07,  ..., 6.7319e-22, 5.2428e-22,
         5.2428e-22],
        [9.9986e-01, 1.2339e-04, 5.4215e-06,  ..., 3.0167e-21, 2.6622e-21,
         1.4250e-21],
        [9.9999e-01, 6.1442e-06, 6.4759e-07,  ..., 9.6026e-24, 8.4743e-24,
         3.5326e-24],
        ...,
        [1.0000e+00, 1.7258e-08, 8.1520e-09,  ..., 1.8722e-25, 1.7588e-25,
         1.3697e-25],
        [9.9996e-01, 2.1445e-05, 1.1478e-05,  ..., 2.3035e-23, 2.3035e-23,
         4.0028e-24],
        [9.9996e-01, 1.6701e-05, 1.6701e-05,  ..., 3.7978e-23, 3.3515e-23,
         7.4783e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([24])


tensor([ 0,  1,  2,  3,  3,  3,  4,  5,  6,  7,  7,  8,  9, 10, 11, 12, 13, 13,
        14, 15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([24, 31])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205,  3233, 29892,
           338],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892,   338,  8040,
         18274],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  8040, 18274,   342,
 


carryover_candidate_logprobs
torch.Size([24])


tensor([-2.8013, -3.4276, -3.4688, -3.5685, -3.5685, -3.5685, -5.1756, -5.9363,
        -7.0574, -4.0806, -4.0806, -5.4532, -6.2808, -6.5458, -5.8385, -6.3610,
        -6.4129, -6.4129, -7.1004, -8.4861, -8.4867, -9.9862, -7.3963, -3.9523],
       device='cuda:0')


new_candidate_toks
torch.Size([24, 1])


tensor([[18274],
        [ 8040],
        [  342],
        [  739],
        [ 8011],
        [ 5976],
        [  338],
        [ 8040],
        [  338],
        [  739],
        [ 8011],
        [  338],
        [ 8040],
        [  338],
        [18274],
        [ 8040],
        [ 8011],
        [  739],
        [ 8040],
        [  338],
        [  342],
        [  338],
        [29889],
        [29889]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([24])


tensor([-7.3910e-06, -1.3805e-04, -7.0334e-06, -7.6691e-01, -1.2669e+00,
        -1.5169e+00,  0.0000e+00, -1.3579e-04,  0.0000e+00, -6.5812e-01,
        -9.0812e-01,  0.0000e+00, -1.0622e-04,  0.0000e+00, -4.6492e-06,
        -1.2410e-04, -5.4141e-01, -1.0414e+00, -1.2482e-04,  0.0000e+00,
        -8.3447e-06,  0.0000e+00, -3.7432e-05, -3.9936e-05], device='cuda:0')


new_candidates
torch.Size([24, 32])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205,  3233, 29892,
           338,  8040],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892,   338,  8040,
         18274,   342],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  


new_candidate_logprobs
torch.Size([24])


tensor([-2.8013, -3.4277, -3.4688, -4.3354, -4.8354, -5.0854, -5.1756, -5.9364,
        -7.0574, -4.7387, -4.9887, -5.4532, -6.2809, -6.5458, -5.8385, -6.3611,
        -6.9543, -7.4543, -7.1006, -8.4861, -8.4867, -9.9862, -7.3963, -3.9523],
       device='cuda:0')

infer end: GPU memory used: 16567 MB.
event: level
id: 20
data: [{"content": "Ever", "parent": 0, "prob": -2.8013410568237305}, {"content": "Mount", "parent": 1, "prob": -3.42771315574646}, {"content": "est", "parent": 2, "prob": -3.468838691711426}, {"content": "It", "parent": 3, "prob": -4.335441589355469}, {"content": "Its", "parent": 3, "prob": -4.835441589355469}, {"content": "Loc", "parent": 3, "prob": -5.085441589355469}, {"content": "is", "parent": 4, "prob": -5.175600528717041}, {"content": "Mount", "parent": 5, "prob": -5.936397552490234}, {"content": "is", "parent": 6, "prob": -7.057366371154785}, {"content": "It", "parent": 7, "prob": -4.7387261390686035}, {"content": "Its", "parent": 7, "prob": -4.988725662231445}, {"content": "is", "parent": 8, "prob": -5.4532365798950195}, {"content": "Mount", "parent": 9, "prob": -6.2808966636657715}, {"content": "is", "parent": 10, "prob": -6.545759201049805}, {"content": "Ever", "parent": 11, "prob": -5.838462829589844}, {"content": "

array([[-1.4609375 ,  0.80859375,  2.21875   , ..., -0.91796875,
        -0.3828125 ,  2.296875  ],
       [ 0.12353516,  1.8515625 ,  1.359375  , ...,  1.8359375 ,
        -0.41015625, -2.828125  ],
       [ 0.71875   ,  1.109375  , -0.12304688, ...,  2.03125   ,
         2.109375  ,  2.890625  ],
       ...,
       [-1.4765625 , -1.0078125 , -0.43945312, ..., -0.984375  ,
        -0.13671875,  1.3828125 ],
       [ 0.19628906,  1.9453125 ,  1.4296875 , ...,  1.640625  ,
        -0.52734375, -2.953125  ],
       [ 0.921875  ,  0.30078125,  0.05004883, ...,  0.7109375 ,
        -1.4296875 , -0.82421875]], dtype=float32)


k_mean_space
(20, 2)


array([[ 95.77378 ,  84.749596],
       [ 89.9187  ,  41.182384],
       [ 67.612686,  88.57681 ],
       [ 50.097122,  94.959885],
       [ 53.83165 ,  90.474174],
       [ 84.37982 , 100.91382 ],
       [ 83.72956 ,  41.68895 ],
       [ 89.80456 ,  41.169594],
       [ 83.59567 ,  41.957153],
       [ 49.9677  ,  94.70631 ],
       [ 54.731148,  90.4161  ],
       [ 83.702774,  41.52716 ],
       [ 89.816086,  41.291534],
       [ 83.574524,  41.79064 ],
       [ 95.827255,  84.74792 ],
       [ 89.955986,  41.541405],
       [ 53.9348  ,  89.93346 ],
       [ 49.915257,  94.87122 ],
       [ 89.87333 ,  40.785843],
       [ 83.89089 ,  41.71164 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-41.86129904, -70.46451783])


closest
(2,)


array([17, 18])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.7266, -3.4375, -4.0938,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.9688,  0.0737, -0.2471,  ...,  0.0000,  0.0000,  0.0000],
        [-1.3828, -7.0625, -7.9062,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 3.3750, -5.5938, -9.4375,  ...,  0.0000,  0.0000,  0.0000],
        [ 8.0000,  0.0413, -0.6758,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.5938, -1.9922, -4.7188,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9999e-01, 8.9397e-06, 7.3381e-07,  ..., 7.4785e-24, 5.1399e-24,
         2.7512e-24],
        [9.9999e-01, 6.1442e-06, 1.2099e-06,  ..., 7.6283e-22, 6.7319e-22,
         6.7319e-22],
        [9.9998e-01, 8.9396e-06, 3.7266e-06,  ..., 1.7940e-23, 1.5832e-23,
         4.5359e-24],
        ...,
        [5.1559e-01, 2.7598e-01, 1.8968e-01,  ..., 5.7227e-22, 3.0631e-22,
         3.2285e-23],
        [1.0000e+00, 2.5613e-06, 9.4224e-07,  ..., 6.7320e-22, 4.6268e-22,
         4.0831e-22],
        [9.9988e-01, 1.0890e-04, 4.7845e-06,  ..., 2.6622e-21, 8.6430e-22,
         7.6274e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.5156, 0.7916, 0.9812,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([31])


tensor([ 0,  1,  2,  3,  3,  3,  4,  4,  4,  4,  5,  6,  7,  8,  9,  9,  9, 10,
        11, 12, 13, 14, 15, 16, 16, 16, 17, 17, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([31, 32])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879, 11858,   362,  2038,  7205,  3233, 29892,
           338,  8040],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 11858,   362,  2038,  7205,  3233, 29892,   338,  8040,
         18274,   342],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224,  2038,  7205,  3233, 29892,   338,  


carryover_candidate_logprobs
torch.Size([31])


tensor([-2.8013, -3.4277, -3.4688, -4.3354, -4.3354, -4.3354, -4.8354, -4.8354,
        -4.8354, -4.8354, -5.0854, -5.1756, -5.9364, -7.0574, -4.7387, -4.7387,
        -4.7387, -4.9887, -5.4532, -6.2809, -6.5458, -5.8385, -6.3611, -6.9543,
        -6.9543, -6.9543, -7.4543, -7.4543, -7.4543, -7.1006, -8.4861],
       device='cuda:0')


new_candidate_toks
torch.Size([31, 1])


tensor([[  342],
        [18274],
        [29889],
        [  338],
        [22170],
        [15028],
        [11858],
        [ 2533],
        [19224],
        [ 3171],
        [  630],
        [ 8040],
        [18274],
        [ 8040],
        [22170],
        [  338],
        [15028],
        [19224],
        [ 8040],
        [18274],
        [ 8040],
        [  342],
        [18274],
        [19224],
        [11858],
        [ 2533],
        [22170],
        [  338],
        [15028],
        [18274],
        [ 8040]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([31])


tensor([-9.8944e-06, -8.8215e-06, -1.6809e-05, -1.0024e+00, -1.0024e+00,
        -1.3774e+00, -5.9839e-01, -1.7234e+00, -1.8484e+00, -2.3484e+00,
        -3.5763e-07, -1.2244e-04, -3.5763e-06, -1.5713e-04, -7.1767e-01,
        -9.6767e-01, -2.0927e+00, -3.6830e-02, -1.2089e-04, -4.4108e-06,
        -1.2327e-04, -7.2718e-06, -7.0334e-06, -5.7079e-01, -1.4458e+00,
        -1.9458e+00, -6.6244e-01, -1.2874e+00, -1.6624e+00, -4.6492e-06,
        -1.2363e-04], device='cuda:0')


new_candidates
torch.Size([31, 33])


tensor([[    1, 32010,  1724,  ...,  8040, 18274,   342],
        [    1, 32010,  1724,  ...,   338,  8040, 18274],
        [    1, 32010,  1724,  ..., 18274,   342, 29889],
        ...,
        [    1, 32010,  1724,  ..., 29889,   739, 15028],
        [    1, 32010,  1724,  ...,   338,  8040, 18274],
        [    1, 32010,  1724,  ..., 29892,   338,  8040]], device='cuda:0')


new_candidate_logprobs
torch.Size([31])


tensor([-2.8014, -3.4277, -3.4689, -5.3379, -5.3379, -5.7129, -5.4338, -6.5588,
        -6.6838, -7.1838, -5.0854, -5.1757, -5.9364, -7.0575, -5.4564, -5.7064,
        -6.8314, -5.0256, -5.4534, -6.2809, -6.5459, -5.8385, -6.3611, -7.5251,
        -8.4001, -8.9001, -8.1168, -8.7418, -9.1168, -7.1006, -8.4862],
       device='cuda:0')

infer end: GPU memory used: 16647 MB.
event: level
id: 21
data: [{"content": "est", "parent": 0, "prob": -2.8013510704040527}, {"content": "Ever", "parent": 1, "prob": -3.4277219772338867}, {"content": ".", "parent": 2, "prob": -3.468855619430542}, {"content": "is", "parent": 3, "prob": -5.337873935699463}, {"content": "reaches", "parent": 3, "prob": -5.337873935699463}, {"content": "stands", "parent": 3, "prob": -5.712873935699463}, {"content": "elev", "parent": 4, "prob": -5.433828830718994}, {"content": "sum", "parent": 4, "prob": -6.558828830718994}, {"content": "peak", "parent": 4, "prob": -6.683828830718994}, {"content": "height", "parent": 4, "prob": -7.183829307556152}, {"content": "ated", "parent": 5, "prob": -5.085442066192627}, {"content": "Mount", "parent": 6, "prob": -5.175723075866699}, {"content": "Ever", "parent": 7, "prob": -5.9364013671875}, {"content": "Mount", "parent": 8, "prob": -7.057523727416992}, {"content": "reaches", "parent": 9, "prob": -5.4563984870910645},

array([[ 6.0937500e-01,  1.3203125e+00, -2.3040771e-03, ...,
         1.9687500e+00,  2.0781250e+00,  3.0468750e+00],
       [-1.5156250e+00,  8.8671875e-01,  2.1250000e+00, ...,
        -9.2968750e-01, -3.8671875e-01,  2.1718750e+00],
       [-3.2343750e+00, -5.3125000e-01, -1.2597656e-01, ...,
         1.1640625e+00, -6.5234375e-01,  3.4218750e+00],
       ...,
       [-6.7578125e-01, -1.0234375e+00, -3.8085938e-01, ...,
        -4.2773438e-01,  2.7343750e-02,  1.9687500e+00],
       [ 2.4707031e-01,  1.9140625e+00,  1.4375000e+00, ...,
         1.7187500e+00, -4.8632812e-01, -2.9687500e+00],
       [-1.3671875e+00,  1.0468750e+00,  2.2343750e+00, ...,
        -1.0937500e+00, -3.9257812e-01,  2.1406250e+00]], dtype=float32)


k_mean_space
(20, 2)


array([[64.79196 , 90.929276],
       [92.202644, 54.21237 ],
       [67.9245  , 94.13288 ],
       [58.121735, 96.651276],
       [62.362576, 97.44144 ],
       [57.604137, 95.323395],
       [80.16796 , 98.6681  ],
       [79.43078 , 97.162834],
       [60.459625, 97.951996],
       [65.86938 , 99.67846 ],
       [73.902985, 97.74773 ],
       [86.34197 , 53.938488],
       [92.11281 , 53.85186 ],
       [86.34392 , 53.88055 ],
       [61.885685, 97.30748 ],
       [57.64414 , 96.48313 ],
       [57.709324, 95.04974 ],
       [60.65973 , 97.820274],
       [86.48695 , 54.056847],
       [92.26301 , 53.95682 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-76.62433839, -33.3316288 ])


closest
(2,)


array([ 5, 12])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.7305, -6.5625, -7.3438,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.6172, -3.3750, -3.9688,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.0312, -2.2812, -4.8750,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 4.8750, -3.1406, -7.7188,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.6250, -0.1367, -0.6914,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.7109, -3.2500, -4.3125,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9998e-01, 1.1479e-05, 6.1441e-06,  ..., 2.9578e-23, 2.6102e-23,
         8.4741e-24],
        [9.9999e-01, 6.1442e-06, 5.7150e-07,  ..., 5.1399e-24, 3.5326e-24,
         2.1426e-24],
        [3.7507e-01, 2.9210e-01, 2.9210e-01,  ..., 3.4171e-23, 1.4245e-23,
         1.2571e-23],
        ...,
        [3.5143e-01, 2.4153e-01, 2.1315e-01,  ..., 7.6807e-23, 3.2018e-23,
         2.4936e-23],
        [9.9999e-01, 3.2887e-06, 1.0677e-06,  ..., 5.2429e-22, 4.6268e-22,
         4.6268e-22],
        [9.9999e-01, 4.7851e-06, 3.9278e-07,  ..., 8.4743e-24, 5.1399e-24,
         1.8909e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.3751, 0.6672, 0.9593,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.3514, 0.5930, 0.8061,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([31])


tensor([ 0,  1,  2,  2,  2,  3,  3,  4,  5,  5,  6,  7,  8,  8,  8,  9, 10, 11,
        12, 13, 14, 15, 15, 16, 16, 17, 17, 17, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([31, 33])


tensor([[    1, 32010,  1724,  ...,  8040, 18274,   342],
        [    1, 32010,  1724,  ...,   338,  8040, 18274],
        [    1, 32010,  1724,  ..., 18274,   342, 29889],
        ...,
        [    1, 32010,  1724,  ..., 29889,  8011, 19224],
        [    1, 32010,  1724,  ..., 29892,   338,  8040],
        [    1, 32010,  1724,  ...,   338,  8040, 18274]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([31])


tensor([-2.8014, -3.4277, -3.4689, -3.4689, -3.4689, -5.3379, -5.3379, -5.3379,
        -5.7129, -5.7129, -5.4338, -6.5588, -6.6838, -6.6838, -6.6838, -7.1838,
        -5.0854, -5.1757, -5.9364, -7.0575, -5.4564, -5.7064, -5.7064, -6.8314,
        -6.8314, -5.0256, -5.0256, -5.0256, -5.0256, -5.4534, -6.2809],
       device='cuda:0')


new_candidate_toks
torch.Size([31, 1])


tensor([[29889],
        [  342],
        [  739],
        [ 5976],
        [ 8011],
        [  760],
        [ 5982],
        [  385],
        [  472],
        [14235],
        [  362],
        [ 2415],
        [15028],
        [  338],
        [22170],
        [  338],
        [  297],
        [18274],
        [  342],
        [18274],
        [  385],
        [  760],
        [ 5982],
        [  472],
        [14235],
        [  364],
        [  338],
        [22170],
        [15028],
        [18274],
        [  342]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([31])


tensor([-2.3604e-05, -7.0334e-06, -9.8065e-01, -1.2307e+00, -1.2307e+00,
        -2.2897e-01, -1.6040e+00, -2.7516e-03, -1.0681e-01, -2.7318e+00,
        -2.7418e-06, -2.6226e-06, -7.4062e-01, -9.9062e-01, -2.4906e+00,
        -5.5350e-03, -1.4139e-04, -6.7950e-06, -5.4836e-06, -5.8413e-06,
        -9.9378e-03, -1.8390e-01, -1.8089e+00, -1.6333e-01, -2.2883e+00,
        -1.0457e+00, -1.4207e+00, -1.5457e+00, -1.6707e+00, -5.2452e-06,
        -5.3644e-06], device='cuda:0')


new_candidates
torch.Size([31, 34])


tensor([[    1, 32010,  1724,  ..., 18274,   342, 29889],
        [    1, 32010,  1724,  ...,  8040, 18274,   342],
        [    1, 32010,  1724,  ...,   342, 29889,   739],
        ...,
        [    1, 32010,  1724,  ...,  8011, 19224, 15028],
        [    1, 32010,  1724,  ...,   338,  8040, 18274],
        [    1, 32010,  1724,  ...,  8040, 18274,   342]], device='cuda:0')


new_candidate_logprobs
torch.Size([31])


tensor([-2.8014, -3.4277, -4.4495, -4.6995, -4.6995, -5.5668, -6.9418, -5.3406,
        -5.8197, -8.4447, -5.4338, -6.5588, -7.4245, -7.6745, -9.1745, -7.1894,
        -5.0856, -5.1757, -5.9364, -7.0575, -5.4663, -5.8903, -7.5153, -6.9947,
        -9.1197, -6.0713, -6.4463, -6.5713, -6.6963, -5.4534, -6.2809],
       device='cuda:0')

infer end: GPU memory used: 16729 MB.
event: level
id: 22
data: [{"content": ".", "parent": 0, "prob": -2.801374673843384}, {"content": "est", "parent": 1, "prob": -3.4277291297912598}, {"content": "It", "parent": 2, "prob": -4.449508190155029}, {"content": "Loc", "parent": 2, "prob": -4.6995086669921875}, {"content": "Its", "parent": 2, "prob": -4.6995086669921875}, {"content": "part", "parent": 3, "prob": -5.566843032836914}, {"content": "located", "parent": 3, "prob": -6.941843032836914}, {"content": "an", "parent": 4, "prob": -5.340625762939453}, {"content": "at", "parent": 5, "prob": -5.819679260253906}, {"content": "approximately", "parent": 5, "prob": -8.444679260253906}, {"content": "ation", "parent": 6, "prob": -5.433831691741943}, {"content": "mit", "parent": 7, "prob": -6.558831691741943}, {"content": "stands", "parent": 8, "prob": -7.424450874328613}, {"content": "is", "parent": 8, "prob": -7.674450874328613}, {"content": "reaches", "parent": 8, "prob": -9.174450874328613},

array([[-3.046875  , -0.45507812, -0.10400391, ...,  0.9609375 ,
        -0.73828125,  3.578125  ],
       [ 0.56640625,  1.21875   , -0.0859375 , ...,  2.078125  ,
         2.        ,  2.953125  ],
       [-1.375     , -0.99609375, -0.51953125, ..., -1.        ,
         0.03222656,  1.4765625 ],
       ...,
       [-1.3984375 ,  0.97265625,  2.171875  , ..., -1.0859375 ,
        -0.44335938,  2.109375  ],
       [ 0.58984375,  1.1953125 ,  0.09033203, ...,  2.109375  ,
         2.140625  ,  2.890625  ],
       [-1.40625   ,  0.96875   ,  2.234375  , ..., -1.015625  ,
        -0.39257812,  2.078125  ]], dtype=float32)


k_mean_space
(20, 2)


array([[78.02924 , 69.18174 ],
       [76.14844 , 63.844078],
       [48.57608 , 70.90965 ],
       [94.03716 , 81.64813 ],
       [82.53905 , 68.351746],
       [71.27621 , 81.96005 ],
       [84.21786 , 72.40042 ],
       [83.79999 , 66.85806 ],
       [84.20218 , 61.58619 ],
       [88.31221 , 66.27467 ],
       [54.593082, 74.042435],
       [47.027225, 73.417854],
       [79.107056, 60.60243 ],
       [75.256645, 57.964096],
       [82.688835, 62.413345],
       [83.79395 , 64.66053 ],
       [93.19177 , 78.286575],
       [98.77252 , 80.44645 ],
       [75.70946 , 63.474804],
       [98.76393 , 80.45846 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-22.00901461, -96.89291549])


closest
(2,)


array([11, 13])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 6.0000, -1.6484, -4.9375,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9609, -6.6250, -7.4375,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.5156, -5.4688, -9.3125,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.8594, -3.1875, -3.9375,  ...,  0.0000,  0.0000,  0.0000],
        [-1.1641, -6.9375, -8.3125,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.7344, -3.0312, -4.1250,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[3.7631e-01, 2.9307e-01, 2.9307e-01,  ..., 2.3563e-23, 2.0795e-23,
         7.6499e-24],
        [9.9997e-01, 1.1479e-05, 7.8891e-06,  ..., 4.8765e-23, 2.0328e-23,
         7.4784e-24],
        [4.2169e-01, 3.7214e-01, 1.9919e-01,  ..., 3.6451e-22, 2.8388e-22,
         7.1776e-23],
        ...,
        [9.9999e-01, 6.1442e-06, 5.0434e-07,  ..., 8.4743e-24, 4.5360e-24,
         2.7512e-24],
        [9.9996e-01, 1.6701e-05, 1.3007e-05,  ..., 2.6102e-23, 1.3971e-23,
         7.4783e-24],
        [9.9999e-01, 5.4222e-06, 5.7150e-07,  ..., 9.6026e-24, 5.8243e-24,
         3.1175e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.3763, 0.6694, 0.9624,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.4217, 0.7938, 0.9930,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([36])


tensor([ 0,  0,  0,  1,  2,  2,  2,  3,  4,  4,  4,  4,  5,  6,  7,  7,  7,  8,
         8,  9, 10, 11, 11, 11, 12, 13, 13, 14, 15, 15, 15, 15, 16, 17, 18, 19],
       device='cuda:0')


carryover_candidates
torch.Size([36, 34])


tensor([[    1, 32010,  1724,  ..., 18274,   342, 29889],
        [    1, 32010,  1724,  ..., 18274,   342, 29889],
        [    1, 32010,  1724,  ..., 18274,   342, 29889],
        ...,
        [    1, 32010,  1724,  ...,   338,  8040, 18274],
        [    1, 32010,  1724,  ...,  8040, 18274,   342],
        [    1, 32010,  1724,  ...,   338,  8040, 18274]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([36])


tensor([-2.8014, -2.8014, -2.8014, -3.4277, -4.4495, -4.4495, -4.4495, -4.6995,
        -4.6995, -4.6995, -4.6995, -4.6995, -5.5668, -6.9418, -5.3406, -5.3406,
        -5.3406, -5.8197, -5.8197, -8.4447, -5.4338, -6.5588, -6.5588, -6.5588,
        -7.4245, -7.6745, -7.6745, -9.1745, -7.1894, -7.1894, -7.1894, -7.1894,
        -5.0856, -5.1757, -5.9364, -7.0575], device='cuda:0')


new_candidate_toks
torch.Size([36, 1])


tensor([[  739],
        [ 5976],
        [ 8011],
        [29889],
        [  338],
        [22170],
        [15028],
        [  630],
        [ 3171],
        [19224],
        [ 2533],
        [11858],
        [  310],
        [  297],
        [11858],
        [21210],
        [ 5272],
        [14235],
        [  385],
        [29871],
        [  338],
        [22170],
        [  338],
        [  364],
        [  472],
        [14235],
        [29871],
        [  385],
        [14235],
        [ 5279],
        [22444],
        [17644],
        [  278],
        [  342],
        [29889],
        [  342]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([36])


tensor([-9.7735e-01, -1.2273e+00, -1.2273e+00, -2.7776e-05, -8.6350e-01,
        -9.8850e-01, -1.6135e+00, -3.5763e-07, -8.1152e-01, -1.1865e+00,
        -1.9365e+00, -2.4365e+00,  0.0000e+00, -6.5378e-04, -3.0690e-01,
        -2.1819e+00, -2.5569e+00, -4.3139e-01, -1.0564e+00, -1.0365e-03,
        -4.2316e-02, -1.6608e-01, -3.0411e+00, -3.0411e+00, -3.7485e-02,
        -5.1838e-01, -1.0184e+00, -6.2637e-02, -8.9753e-01, -1.6475e+00,
        -1.7725e+00, -1.8975e+00, -2.0266e-06, -6.9142e-06, -3.6598e-05,
        -6.1989e-06], device='cuda:0')


new_candidates
torch.Size([36, 35])


tensor([[    1, 32010,  1724,  ...,   342, 29889,   739],
        [    1, 32010,  1724,  ...,   342, 29889,  5976],
        [    1, 32010,  1724,  ...,   342, 29889,  8011],
        ...,
        [    1, 32010,  1724,  ...,  8040, 18274,   342],
        [    1, 32010,  1724,  ..., 18274,   342, 29889],
        [    1, 32010,  1724,  ...,  8040, 18274,   342]], device='cuda:0')


new_candidate_logprobs
torch.Size([36])


tensor([-3.7787, -4.0287, -4.0287, -3.4278, -5.3130, -5.4380, -6.0630, -4.6995,
        -5.5110, -5.8860, -6.6360, -7.1360, -5.5668, -6.9425, -5.6475, -7.5225,
        -7.8975, -6.2511, -6.8761, -8.4457, -5.4761, -6.7249, -9.5999, -9.5999,
        -7.4619, -8.1928, -8.6928, -9.2371, -8.0869, -8.8369, -8.9619, -9.0869,
        -5.0856, -5.1757, -5.9364, -7.0575], device='cuda:0')

infer end: GPU memory used: 16813 MB.
event: level
id: 23
data: [{"content": "It", "parent": 0, "prob": -3.778719902038574}, {"content": "Loc", "parent": 0, "prob": -4.028719902038574}, {"content": "Its", "parent": 0, "prob": -4.028719902038574}, {"content": ".", "parent": 1, "prob": -3.4277570247650146}, {"content": "is", "parent": 2, "prob": -5.3130035400390625}, {"content": "reaches", "parent": 2, "prob": -5.4380035400390625}, {"content": "stands", "parent": 2, "prob": -6.0630035400390625}, {"content": "ated", "parent": 3, "prob": -4.699509143829346}, {"content": "height", "parent": 4, "prob": -5.51102352142334}, {"content": "peak", "parent": 4, "prob": -5.88602352142334}, {"content": "sum", "parent": 4, "prob": -6.63602352142334}, {"content": "elev", "parent": 4, "prob": -7.13602352142334}, {"content": "of", "parent": 5, "prob": -5.566843032836914}, {"content": "in", "parent": 6, "prob": -6.9424967765808105}, {"content": "elev", "parent": 7, "prob": -5.647525787353516}, {"content":

array([[-1.3984375 , -0.8515625 , -0.3984375 , ..., -1.109375  ,
        -0.08105469,  1.5078125 ],
       [-1.9375    ,  1.2265625 ,  1.3515625 , ..., -1.7421875 ,
        -0.97265625,  1.28125   ],
       [-2.765625  , -0.875     ,  2.953125  , ...,  0.94140625,
        -2.953125  ,  3.53125   ],
       ...,
       [-0.48828125,  0.5390625 ,  0.53515625, ...,  0.8828125 ,
         1.390625  ,  2.421875  ],
       [-1.4375    , -1.9375    , -0.53125   , ...,  1.1953125 ,
        -1.1328125 ,  2.921875  ],
       [-0.546875  ,  0.9375    ,  2.125     , ...,  2.375     ,
         1.859375  ,  0.65234375]], dtype=float32)


k_mean_space
(20, 2)


array([[65.99383 , 76.91128 ],
       [79.70474 , 90.25883 ],
       [67.32238 , 77.66152 ],
       [69.65406 , 79.18935 ],
       [72.77636 , 61.85229 ],
       [68.56029 , 75.081055],
       [71.28752 , 63.198627],
       [82.37024 , 64.62455 ],
       [66.32462 , 82.83236 ],
       [63.636497, 80.76198 ],
       [72.77289 , 89.37121 ],
       [67.93249 , 90.62088 ],
       [86.24428 , 63.417778],
       [85.57189 , 62.792236],
       [68.06934 , 90.26772 ],
       [79.854935, 93.98855 ],
       [71.26544 , 90.04172 ],
       [75.77947 , 65.63219 ],
       [66.824196, 75.86202 ],
       [82.325806, 68.624596]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-73.81466556, -43.28164625])


closest
(2,)


array([9, 4])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 3.6094, -5.2188, -8.9375,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.1797, -7.4062, -1.0156,  ...,  0.0000,  0.0000,  0.0000],
        [-0.5391, -2.7656, -6.2500,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.6992, -2.9531, -7.5000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0144, -2.2031, -5.5625,  ...,  0.0000,  0.0000,  0.0000],
        [-0.6797,  0.7227, -3.6875,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[4.6998e-01, 3.2302e-01, 1.9592e-01,  ..., 5.2164e-22, 3.5852e-22,
         4.8520e-23],
        [1.0000e+00, 1.8554e-07, 1.1254e-07,  ..., 1.4437e-26, 7.7276e-27,
         6.0183e-27],
        [4.8726e-01, 4.3001e-01, 6.5944e-02,  ..., 6.1283e-22, 4.7728e-22,
         2.8948e-22],
        ...,
        [1.0000e+00, 1.6374e-07, 7.7344e-08,  ..., 4.0030e-24, 1.8909e-24,
         4.7809e-25],
        [5.8532e-01, 3.1330e-01, 4.8046e-02,  ..., 3.0688e-22, 2.3900e-22,
         1.1289e-22],
        [7.5491e-01, 2.4508e-01, 7.0430e-08,  ..., 7.4681e-21, 7.4681e-21,
         3.3139e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.4700, 0.7930, 0.9889,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.4873, 0.9173, 0.9832,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.5853, 0.8986, 0.9467,  ..., 1.0000, 1.0000, 1.0000],
        [0.7549, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([32])


tensor([ 0,  0,  0,  1,  2,  2,  3,  3,  3,  4,  4,  5,  6,  6,  7,  8,  9,  9,
         9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 18, 18, 19, 19],
       device='cuda:0')


carryover_candidates
torch.Size([32, 35])


tensor([[    1, 32010,  1724,  ...,   342, 29889,   739],
        [    1, 32010,  1724,  ...,   342, 29889,   739],
        [    1, 32010,  1724,  ...,   342, 29889,   739],
        ...,
        [    1, 32010,  1724,  ..., 15028,   472,   385],
        [    1, 32010,  1724,  ..., 15028, 14235, 29871],
        [    1, 32010,  1724,  ..., 15028, 14235, 29871]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([32])


tensor([-3.7787, -3.7787, -3.7787, -4.0287, -4.0287, -4.0287, -3.4278, -3.4278,
        -3.4278, -5.3130, -5.3130, -5.4380, -6.0630, -6.0630, -4.6995, -5.5110,
        -5.8860, -5.8860, -5.8860, -6.6360, -7.1360, -5.5668, -6.9425, -5.6475,
        -7.5225, -7.8975, -6.2511, -6.8761, -6.8761, -6.8761, -8.4457, -8.4457],
       device='cuda:0')


new_candidate_toks
torch.Size([32, 1])


tensor([[22170],
        [  338],
        [15028],
        [  630],
        [19224],
        [11858],
        [  739],
        [ 5976],
        [ 8011],
        [  760],
        [ 5982],
        [  385],
        [  472],
        [14235],
        [  297],
        [  338],
        [15028],
        [  338],
        [  364],
        [ 2415],
        [  362],
        [  278],
        [  278],
        [  362],
        [  573],
        [ 4279],
        [29871],
        [21210],
        [11858],
        [24293],
        [29947],
        [29906]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([32])


tensor([-7.5506e-01, -1.1301e+00, -1.6301e+00, -3.5763e-07, -7.1895e-01,
        -8.4395e-01, -9.1569e-01, -1.1657e+00, -1.4157e+00, -2.2819e-01,
        -1.6032e+00, -7.1802e-03, -1.7715e-01, -2.3021e+00, -1.6738e-04,
        -4.3298e-03, -7.7427e-01, -8.9927e-01, -2.5243e+00, -1.6689e-06,
        -2.0266e-06, -2.0981e-05, -2.5034e-06, -4.5300e-06,  0.0000e+00,
        -3.0994e-06, -3.5763e-07, -5.3559e-01, -1.1606e+00, -3.0356e+00,
        -2.8115e-01, -1.4062e+00], device='cuda:0')


new_candidates
torch.Size([32, 36])


tensor([[    1, 32010,  1724,  ..., 29889,   739, 22170],
        [    1, 32010,  1724,  ..., 29889,   739,   338],
        [    1, 32010,  1724,  ..., 29889,   739, 15028],
        ...,
        [    1, 32010,  1724,  ...,   472,   385, 24293],
        [    1, 32010,  1724,  ..., 14235, 29871, 29947],
        [    1, 32010,  1724,  ..., 14235, 29871, 29906]], device='cuda:0')


new_candidate_logprobs
torch.Size([32])


tensor([-4.5338, -4.9088, -5.4088, -4.0287, -4.7477, -4.8727, -4.3434, -4.5934,
        -4.8434, -5.5412, -6.9162, -5.4452, -6.2402, -8.3652, -4.6997, -5.5154,
        -6.6603, -6.7853, -8.4103, -6.6360, -7.1360, -5.5669, -6.9425, -5.6475,
        -7.5225, -7.8975, -6.2511, -7.4117, -8.0367, -9.9117, -8.7269, -9.8519],
       device='cuda:0')

infer end: GPU memory used: 16899 MB.
event: level
id: 24
data: [{"content": "reaches", "parent": 0, "prob": -4.53377628326416}, {"content": "is", "parent": 0, "prob": -4.90877628326416}, {"content": "stands", "parent": 0, "prob": -5.40877628326416}, {"content": "ated", "parent": 1, "prob": -4.028720378875732}, {"content": "peak", "parent": 2, "prob": -4.747668743133545}, {"content": "elev", "parent": 2, "prob": -4.872668743133545}, {"content": "It", "parent": 3, "prob": -4.343447208404541}, {"content": "Loc", "parent": 3, "prob": -4.593447208404541}, {"content": "Its", "parent": 3, "prob": -4.843447208404541}, {"content": "part", "parent": 4, "prob": -5.541192531585693}, {"content": "located", "parent": 4, "prob": -6.916192531585693}, {"content": "an", "parent": 5, "prob": -5.445183753967285}, {"content": "at", "parent": 6, "prob": -6.240150451660156}, {"content": "approximately", "parent": 6, "prob": -8.365150451660156}, {"content": "in", "parent": 7, "prob": -4.699676513671875}, {"c

array([[-1.265625  ,  2.28125   ,  0.40429688, ..., -0.6171875 ,
         0.53125   ,  0.32226562],
       [-0.53515625,  1.5       , -0.41210938, ...,  0.49414062,
         0.56640625,  1.796875  ],
       [-0.6015625 ,  2.09375   ,  0.36132812, ...,  2.078125  ,
         2.359375  ,  1.3671875 ],
       ...,
       [-0.4140625 ,  0.6953125 ,  1.3828125 , ..., -0.65234375,
         1.40625   ,  2.15625   ],
       [-0.67578125,  1.59375   , -0.34570312, ..., -1.3671875 ,
        -0.84765625,  0.92578125],
       [ 0.55078125, -2.09375   ,  0.05297852, ..., -1.0625    ,
        -0.64453125,  1.875     ]], dtype=float32)


k_mean_space
(20, 2)


array([[84.704216, 60.846184],
       [62.692173, 66.17322 ],
       [77.414696, 56.023403],
       [50.789986, 82.46612 ],
       [83.7075  , 60.22111 ],
       [92.62242 , 82.50975 ],
       [77.10047 , 63.650818],
       [75.35546 , 89.554245],
       [84.67168 , 67.34648 ],
       [68.9627  , 86.65834 ],
       [48.851093, 81.49922 ],
       [85.90068 , 64.09535 ],
       [82.6357  , 58.650528],
       [82.75021 , 65.71864 ],
       [69.04335 , 88.15451 ],
       [82.41789 , 65.19605 ],
       [80.71386 , 53.892216],
       [75.93704 , 53.974247],
       [90.5592  , 76.59544 ],
       [83.47874 , 61.458572]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-30.68800545, -82.80752563])


closest
(2,)


array([10, 16])


last_tok_logits
torch.Size([20, 32064])


tensor([[-1.4141, -8.6250, -4.7812,  ...,  0.0000,  0.0000,  0.0000],
        [-0.4316, -3.7188, -5.7812,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.3125, -0.7695, -4.5625,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.6133, -6.1875, -5.7188,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.6250,  0.7461, -1.3281,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.3125, -1.0859, -6.1250,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9074e-01, 6.6756e-03, 2.4558e-03,  ..., 3.1505e-22, 2.4536e-22,
         1.3133e-22],
        [7.7334e-01, 2.2157e-01, 4.0581e-03,  ..., 1.1616e-22, 1.7814e-23,
         1.0805e-23],
        [7.1825e-01, 2.3318e-01, 4.5916e-02,  ..., 3.5728e-21, 7.9720e-22,
         2.4073e-23],
        ...,
        [6.3463e-01, 2.9978e-01, 4.0570e-02,  ..., 1.2486e-20, 7.5728e-21,
         1.6897e-21],
        [1.0000e+00, 2.3824e-07, 1.5230e-08,  ..., 6.1388e-25, 5.4175e-25,
         1.7588e-25],
        [7.3179e-01, 9.9037e-02, 7.7130e-02,  ..., 7.1679e-22, 7.1679e-22,
         5.5823e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9907, 0.9974, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.7733, 0.9949, 0.9990,  ..., 1.0000, 1.0000, 1.0000],
        [0.7183, 0.9514, 0.9974,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.6346, 0.9344, 0.9750,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.7318, 0.8308, 0.9080,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([39])


tensor([ 0,  1,  1,  2,  2,  3,  4,  4,  4,  5,  6,  6,  6,  7,  8,  8,  8,  9,
        10, 11, 11, 11, 11, 12, 12, 13, 14, 15, 15, 15, 15, 15, 16, 17, 17, 18,
        19, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([39, 36])


tensor([[    1, 32010,  1724,  ..., 29889,   739, 22170],
        [    1, 32010,  1724,  ..., 29889,   739,   338],
        [    1, 32010,  1724,  ..., 29889,   739,   338],
        ...,
        [    1, 32010,  1724,  ...,  8011,  2533,  2415],
        [    1, 32010,  1724,  ...,  8011,  2533,  2415],
        [    1, 32010,  1724,  ...,  8011,  2533,  2415]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([39])


tensor([-4.5338, -4.9088, -4.9088, -5.4088, -5.4088, -4.0287, -4.7477, -4.7477,
        -4.7477, -4.8727, -4.3434, -4.3434, -4.3434, -4.5934, -4.8434, -4.8434,
        -4.8434, -5.5412, -6.9162, -5.4452, -5.4452, -5.4452, -5.4452, -6.2402,
        -6.2402, -8.3652, -4.6997, -5.5154, -5.5154, -5.5154, -5.5154, -5.5154,
        -6.6603, -6.7853, -6.7853, -8.4103, -6.6360, -6.6360, -6.6360],
       device='cuda:0')


new_candidate_toks
torch.Size([39, 1])


tensor([[  385],
        [  760],
        [ 5982],
        [  472],
        [14235],
        [  297],
        [15028],
        [  338],
        [22170],
        [  362],
        [  338],
        [22170],
        [15028],
        [  630],
        [19224],
        [ 3171],
        [ 2533],
        [  310],
        [  297],
        [11858],
        [21210],
        [24293],
        [ 5272],
        [14235],
        [  385],
        [29871],
        [  278],
        [14235],
        [22444],
        [ 5279],
        [17644],
        [15687],
        [  472],
        [14235],
        [29871],
        [ 4637],
        [22170],
        [  338],
        [  364]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([39])


tensor([-9.3022e-03, -2.5704e-01, -1.5070e+00, -3.3093e-01, -1.4559e+00,
        -1.5248e-04, -6.6369e-01, -1.1637e+00, -2.4137e+00, -1.7881e-06,
        -8.4685e-01, -1.0969e+00, -1.4719e+00, -3.5763e-07, -6.5085e-01,
        -1.2759e+00, -2.1509e+00,  0.0000e+00, -5.8680e-04, -9.1832e-01,
        -1.1683e+00, -1.6683e+00, -2.6683e+00, -2.8546e-01, -1.4105e+00,
        -9.1499e-04, -1.6689e-06, -1.2689e+00, -1.3939e+00, -1.7689e+00,
        -1.7689e+00, -2.2689e+00, -7.4305e-02, -4.5472e-01, -1.2047e+00,
        -2.3842e-07, -3.1226e-01, -2.3123e+00, -2.5623e+00], device='cuda:0')


new_candidates
torch.Size([39, 37])


tensor([[    1, 32010,  1724,  ...,   739, 22170,   385],
        [    1, 32010,  1724,  ...,   739,   338,   760],
        [    1, 32010,  1724,  ...,   739,   338,  5982],
        ...,
        [    1, 32010,  1724,  ...,  2533,  2415, 22170],
        [    1, 32010,  1724,  ...,  2533,  2415,   338],
        [    1, 32010,  1724,  ...,  2533,  2415,   364]], device='cuda:0')


new_candidate_logprobs
torch.Size([39])


tensor([-4.5431, -5.1658, -6.4158, -5.7397, -6.8647, -4.0289, -5.4114, -5.9114,
        -7.1614, -4.8727, -5.1903, -5.4403, -5.8153, -4.5934, -5.4943, -6.1193,
        -6.9943, -5.5412, -6.9168, -6.3635, -6.6135, -7.1135, -8.1135, -6.5256,
        -7.6506, -8.3661, -4.6997, -6.7843, -6.9093, -7.2843, -7.2843, -7.7843,
        -6.7346, -7.2400, -7.9900, -8.4103, -6.9483, -8.9483, -9.1983],
       device='cuda:0')

infer end: GPU memory used: 16989 MB.
event: level
id: 25
data: [{"content": "an", "parent": 0, "prob": -4.543078422546387}, {"content": "part", "parent": 1, "prob": -5.165813446044922}, {"content": "located", "parent": 1, "prob": -6.415813446044922}, {"content": "at", "parent": 2, "prob": -5.739710807800293}, {"content": "approximately", "parent": 2, "prob": -6.864710807800293}, {"content": "in", "parent": 3, "prob": -4.028872966766357}, {"content": "stands", "parent": 4, "prob": -5.4113545417785645}, {"content": "is", "parent": 4, "prob": -5.911354064941406}, {"content": "reaches", "parent": 4, "prob": -7.1613545417785645}, {"content": "ation", "parent": 5, "prob": -4.872670650482178}, {"content": "is", "parent": 6, "prob": -5.190298080444336}, {"content": "reaches", "parent": 6, "prob": -5.440298080444336}, {"content": "stands", "parent": 6, "prob": -5.815298080444336}, {"content": "ated", "parent": 7, "prob": -4.593447685241699}, {"content": "peak", "parent": 8, "prob": -5.49429893

array([[ 0.28125   , -1.9453125 ,  0.07568359, ..., -0.38671875,
        -1.296875  ,  1.9609375 ],
       [ 2.        , -0.4921875 ,  0.89453125, ...,  0.59765625,
         1.375     ,  0.16210938],
       [ 1.265625  ,  0.6953125 , -1.765625  , ..., -1.8984375 ,
        -1.8359375 ,  1.8359375 ],
       ...,
       [-0.828125  , -0.5625    ,  1.015625  , ...,  3.234375  ,
        -0.984375  , -0.46875   ],
       [-1.7421875 , -1.84375   ,  1.8046875 , ...,  1.8046875 ,
        -0.88671875, -1.0390625 ],
       [ 2.234375  , -2.28125   ,  1.0390625 , ...,  2.984375  ,
        -1.078125  ,  0.3828125 ]], dtype=float32)


k_mean_space
(20, 2)


array([[69.03571 , 79.52779 ],
       [78.07524 , 90.333824],
       [65.98451 , 88.66746 ],
       [59.07278 , 83.3934  ],
       [65.36866 , 86.40913 ],
       [70.321396, 94.35501 ],
       [58.500706, 77.49097 ],
       [58.82649 , 76.541626],
       [59.957626, 80.454346],
       [77.85963 , 47.29264 ],
       [61.915066, 82.400375],
       [61.679764, 81.98321 ],
       [57.878326, 80.719185],
       [67.17727 , 89.846176],
       [76.14735 , 56.10939 ],
       [78.58862 , 50.296516],
       [87.86792 , 70.74103 ],
       [71.2702  , 94.46863 ],
       [68.95688 , 93.78552 ],
       [87.51543 , 71.65382 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-84.73937702, -29.84407043])


closest
(2,)


array([12,  9])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.7500, -2.5781, -4.9688,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.2188, -0.2891, -4.3750,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.0156,  1.2422, -4.8438,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.5547, -4.4688, -5.7500,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.6875, -3.4375, -2.3750,  ...,  0.0000,  0.0000,  0.0000],
        [-0.1011, -4.5000,  3.1875,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[6.8422e-01, 1.9603e-01, 7.2116e-02,  ..., 4.8549e-23, 1.2275e-23,
         3.9851e-24],
        [1.0000e+00, 2.9990e-09, 5.9053e-10,  ..., 1.8909e-24, 1.4726e-24,
         4.2191e-25],
        [9.9948e-01, 4.8784e-04, 1.6693e-05,  ..., 1.1093e-21, 7.6244e-22,
         1.3965e-23],
        ...,
        [9.9997e-01, 1.3007e-05, 1.1479e-05,  ..., 7.6282e-22, 6.7318e-22,
         3.3516e-23],
        [1.0000e+00, 1.7603e-06, 2.3824e-07,  ..., 9.1107e-23, 8.0402e-23,
         4.3036e-23],
        [9.9999e-01, 4.7851e-06, 4.4508e-07,  ..., 2.9578e-23, 2.0329e-23,
         1.7940e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.6842, 0.8803, 0.9524,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9995, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([28])


tensor([ 0,  0,  0,  1,  2,  3,  3,  4,  5,  6,  7,  7,  8,  9, 10, 10, 11, 12,
        12, 13, 14, 14, 14, 15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([28, 37])


tensor([[    1, 32010,  1724,  ...,   739, 22170,   385],
        [    1, 32010,  1724,  ...,   739, 22170,   385],
        [    1, 32010,  1724,  ...,   739, 22170,   385],
        ...,
        [    1, 32010,  1724,  ...,   338,   760,   310],
        [    1, 32010,  1724,  ...,   338,  5982,   297],
        [    1, 32010,  1724,  ..., 22170,   385, 11858]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([28])


tensor([-4.5431, -4.5431, -4.5431, -5.1658, -6.4158, -5.7397, -5.7397, -6.8647,
        -4.0289, -5.4114, -5.9114, -5.9114, -7.1614, -4.8727, -5.1903, -5.1903,
        -5.4403, -5.8153, -5.8153, -4.5934, -5.4943, -5.4943, -5.4943, -6.1193,
        -6.9943, -5.5412, -6.9168, -6.3635], device='cuda:0')


new_candidate_toks
torch.Size([28, 1])


tensor([[11858],
        [21210],
        [24293],
        [  310],
        [  297],
        [14235],
        [  385],
        [29871],
        [  278],
        [  472],
        [14235],
        [29871],
        [  385],
        [  338],
        [  760],
        [ 5982],
        [  385],
        [  472],
        [14235],
        [  297],
        [15028],
        [  338],
        [  364],
        [  338],
        [ 2415],
        [  278],
        [  278],
        [  362]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([28])


tensor([-3.7948e-01, -1.6295e+00, -2.6295e+00,  0.0000e+00, -5.2233e-04,
        -1.6273e-01, -1.9127e+00, -8.0898e-04, -2.1458e-06, -9.0614e-02,
        -4.2878e-01, -1.1788e+00, -9.0411e-02, -5.1049e-02, -2.8609e-01,
        -1.4111e+00, -8.2203e-03, -2.3427e-01, -1.8593e+00, -1.3877e-04,
        -8.3857e-01, -9.6357e-01, -2.2136e+00, -7.1033e-03, -2.1458e-06,
        -2.5153e-05, -2.2650e-06, -5.8413e-06], device='cuda:0')


new_candidates
torch.Size([28, 38])


tensor([[    1, 32010,  1724,  ..., 22170,   385, 11858],
        [    1, 32010,  1724,  ..., 22170,   385, 21210],
        [    1, 32010,  1724,  ..., 22170,   385, 24293],
        ...,
        [    1, 32010,  1724,  ...,   760,   310,   278],
        [    1, 32010,  1724,  ...,  5982,   297,   278],
        [    1, 32010,  1724,  ...,   385, 11858,   362]], device='cuda:0')


new_candidate_logprobs
torch.Size([28])


tensor([-4.9226, -6.1726, -7.1726, -5.1658, -6.4163, -5.9024, -7.6524, -6.8655,
        -4.0289, -5.5020, -6.3401, -7.0901, -7.2518, -4.9237, -5.4764, -6.6014,
        -5.4485, -6.0496, -7.6746, -4.5936, -6.3329, -6.4579, -7.7079, -6.1264,
        -6.9943, -5.5412, -6.9168, -6.3635], device='cuda:0')

infer end: GPU memory used: 17081 MB.
event: level
id: 26
data: [{"content": "elev", "parent": 0, "prob": -4.922554016113281}, {"content": "impress", "parent": 0, "prob": -6.172554016113281}, {"content": "aston", "parent": 0, "prob": -7.172554016113281}, {"content": "of", "parent": 1, "prob": -5.165813446044922}, {"content": "in", "parent": 2, "prob": -6.416335582733154}, {"content": "approximately", "parent": 3, "prob": -5.9024434089660645}, {"content": "an", "parent": 3, "prob": -7.6524434089660645}, {"content": "", "parent": 4, "prob": -6.865520000457764}, {"content": "the", "parent": 5, "prob": -4.028875350952148}, {"content": "at", "parent": 6, "prob": -5.501968860626221}, {"content": "approximately", "parent": 7, "prob": -6.340133190155029}, {"content": "", "parent": 7, "prob": -7.090133190155029}, {"content": "an", "parent": 8, "prob": -7.251765727996826}, {"content": "is", "parent": 9, "prob": -4.92371940612793}, {"content": "part", "parent": 10, "prob": -5.4763875007629395}, {

array([[ 2.265625  , -2.125     ,  1.2578125 , ...,  2.984375  ,
        -1.3203125 ,  0.44335938],
       [ 1.03125   , -0.74609375, -2.390625  , ...,  2.703125  ,
         1.09375   , -0.4921875 ],
       [-0.734375  , -2.578125  , -1.59375   , ...,  1.0625    ,
         2.734375  , -0.3828125 ],
       ...,
       [-3.390625  ,  1.9140625 ,  1.859375  , ...,  1.96875   ,
         1.28125   ,  0.9375    ],
       [-0.91796875, -0.51953125,  1.2578125 , ..., -0.43945312,
         0.83203125,  2.015625  ],
       [-1.5234375 , -1.484375  ,  2.        , ...,  2.140625  ,
        -0.49804688, -0.6953125 ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 80.4217  ,  97.80879 ],
       [ 79.532326, 101.53924 ],
       [ 81.27289 , 100.57022 ],
       [ 84.66197 ,  53.657482],
       [ 84.27821 ,  45.739346],
       [ 56.531975,  85.4599  ],
       [ 55.684998,  88.64757 ],
       [ 78.55846 ,  63.928295],
       [ 86.702385,  60.69852 ],
       [ 54.412186,  84.103294],
       [ 58.876812,  87.86296 ],
       [ 78.05056 ,  64.48614 ],
       [ 57.685345,  90.719086],
       [ 62.82198 ,  87.7587  ],
       [ 81.44677 ,  92.290665],
       [ 78.31633 ,  83.02599 ],
       [ 57.782024,  90.60984 ],
       [ 55.47032 ,  84.05182 ],
       [ 60.187683,  87.49336 ],
       [ 85.24978 ,  46.096092]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-87.0905571 , -34.16026402])


closest
(2,)


array([9, 4])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.0903, -3.6562,  3.5781,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.7812, -4.2188,  1.6953,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.0781,  1.5234,  0.2539,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.8984, -4.2812, -4.3125,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.9883, -1.8594, -5.0625,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.4688, -3.5938, -2.7188,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9999e-01, 5.4222e-06, 2.1024e-07,  ..., 1.2330e-23, 9.6026e-24,
         4.5360e-24],
        [1.0000e+00, 3.8507e-09, 1.0262e-10,  ..., 2.5088e-27, 1.9538e-27,
         2.6442e-28],
        [9.9999e-01, 6.9623e-06, 2.5110e-08,  ..., 8.7565e-27, 8.7565e-27,
         6.0182e-27],
        ...,
        [8.1411e-01, 1.8165e-01, 2.9361e-03,  ..., 6.2103e-22, 2.5888e-22,
         1.7793e-22],
        [9.9929e-01, 7.0967e-04, 1.5524e-06,  ..., 3.5301e-24, 2.1411e-24,
         2.8977e-25],
        [1.0000e+00, 1.5535e-06, 4.4508e-07,  ..., 2.4766e-22, 1.9287e-22,
         8.0402e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.8141, 0.9958, 0.9987,  ..., 1.0000, 1.0000, 1.0000],
        [0.9993, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([29])


tensor([ 0,  1,  2,  3,  4,  5,  6,  6,  7,  7,  8,  9, 10, 11, 11, 12, 12, 13,
        13, 14, 15, 16, 16, 16, 16, 17, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([29, 38])


tensor([[    1, 32010,  1724,  ..., 22170,   385, 11858],
        [    1, 32010,  1724,  ..., 22170,   385, 21210],
        [    1, 32010,  1724,  ..., 22170,   385, 24293],
        ...,
        [    1, 32010,  1724,  ...,   739, 15028,   472],
        [    1, 32010,  1724,  ...,   739, 15028, 14235],
        [    1, 32010,  1724,  ...,  5976,   630,   297]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([29])


tensor([-4.9226, -6.1726, -7.1726, -5.1658, -6.4163, -5.9024, -7.6524, -7.6524,
        -6.8655, -6.8655, -4.0289, -5.5020, -6.3401, -7.0901, -7.0901, -7.2518,
        -7.2518, -4.9237, -4.9237, -5.4764, -6.6014, -5.4485, -5.4485, -5.4485,
        -5.4485, -6.0496, -6.0496, -7.6746, -4.5936], device='cuda:0')


new_candidate_toks
torch.Size([29, 1])


tensor([[  362],
        [  573],
        [14424],
        [  278],
        [  278],
        [29871],
        [21210],
        [11858],
        [29947],
        [29906],
        [  379],
        [14235],
        [29871],
        [29947],
        [29906],
        [11858],
        [ 5272],
        [14235],
        [22444],
        [  310],
        [  297],
        [21210],
        [11858],
        [24293],
        [ 5272],
        [14235],
        [  385],
        [29871],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([29])


tensor([-5.8413e-06,  0.0000e+00, -6.9142e-06, -2.2292e-05, -2.5034e-06,
        -7.1526e-07, -2.7811e-01, -1.7781e+00, -4.2870e-01, -1.0537e+00,
        -1.0050e-01, -1.6355e-02, -9.5367e-07, -4.2871e-01, -1.0537e+00,
        -1.1028e-01, -2.8603e+00, -2.4144e-01, -2.1164e+00,  0.0000e+00,
        -4.1070e-04, -1.0235e+00, -1.3985e+00, -1.3985e+00, -2.2735e+00,
        -2.0566e-01, -1.7057e+00, -7.1330e-04, -2.2650e-06], device='cuda:0')


new_candidates
torch.Size([29, 39])


tensor([[    1, 32010,  1724,  ...,   385, 11858,   362],
        [    1, 32010,  1724,  ...,   385, 21210,   573],
        [    1, 32010,  1724,  ...,   385, 24293, 14424],
        ...,
        [    1, 32010,  1724,  ..., 15028,   472,   385],
        [    1, 32010,  1724,  ..., 15028, 14235, 29871],
        [    1, 32010,  1724,  ...,   630,   297,   278]], device='cuda:0')


new_candidate_logprobs
torch.Size([29])


tensor([ -4.9226,  -6.1726,  -7.1726,  -5.1658,  -6.4163,  -5.9024,  -7.9306,
         -9.4306,  -7.2942,  -7.9192,  -4.1294,  -5.5183,  -6.3401,  -7.5188,
         -8.1438,  -7.3620, -10.1120,  -5.1652,  -7.0402,  -5.4764,  -6.6018,
         -6.4720,  -6.8470,  -6.8470,  -7.7220,  -6.2552,  -7.7552,  -7.6753,
         -4.5936], device='cuda:0')

infer end: GPU memory used: 17175 MB.
event: level
id: 27
data: [{"content": "ation", "parent": 0, "prob": -4.92255973815918}, {"content": "ive", "parent": 1, "prob": -6.172554016113281}, {"content": "ishing", "parent": 2, "prob": -7.172561168670654}, {"content": "the", "parent": 3, "prob": -5.165835857391357}, {"content": "the", "parent": 4, "prob": -6.416337966918945}, {"content": "", "parent": 5, "prob": -5.902444362640381}, {"content": "impress", "parent": 6, "prob": -7.930551528930664}, {"content": "elev", "parent": 6, "prob": -9.430551528930664}, {"content": "8", "parent": 7, "prob": -7.294220924377441}, {"content": "2", "parent": 7, "prob": -7.919220924377441}, {"content": "H", "parent": 8, "prob": -4.129379749298096}, {"content": "approximately", "parent": 9, "prob": -5.51832389831543}, {"content": "", "parent": 10, "prob": -6.340134143829346}, {"content": "8", "parent": 11, "prob": -7.5188469886779785}, {"content": "2", "parent": 11, "prob": -8.143847465515137}, {"content": "e

array([[ 0.30273438,  0.14453125,  1.1953125 , ...,  0.984375  ,
        -0.30273438,  3.375     ],
       [ 1.2578125 , -1.3046875 ,  0.75      , ...,  1.140625  ,
        -0.1796875 ,  2.1875    ],
       [ 1.7109375 , -1.265625  ,  0.26757812, ...,  1.0546875 ,
        -0.07763672,  2.6875    ],
       ...,
       [-1.1796875 ,  0.80859375, -0.5859375 , ..., -0.5703125 ,
        -0.12792969,  0.8515625 ],
       [ 1.3515625 ,  1.4375    ,  2.515625  , ..., -0.19726562,
         2.        ,  0.94140625],
       [-0.796875  , -0.58984375,  0.99609375, ...,  3.15625   ,
        -1.015625  , -0.40429688]], dtype=float32)


k_mean_space
(20, 2)


array([[ 97.778076,  68.88869 ],
       [ 92.69312 ,  61.930874],
       [ 93.03903 ,  62.18209 ],
       [ 43.149075,  84.85436 ],
       [ 45.283047,  85.71428 ],
       [ 91.2011  ,  65.42462 ],
       [101.04499 ,  81.133545],
       [100.83461 ,  72.27648 ],
       [ 95.156555,  71.56538 ],
       [ 96.927864,  72.83278 ],
       [ 71.26971 ,  89.682396],
       [ 95.17271 ,  63.8084  ],
       [ 91.7714  ,  65.38654 ],
       [ 95.08813 ,  71.10929 ],
       [ 97.09904 ,  72.711494],
       [100.77658 ,  72.4921  ],
       [100.480896,  74.777534],
       [ 96.09758 ,  64.80738 ],
       [ 97.15695 ,  72.6688  ],
       [ 54.51204 ,  84.17888 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -21.18794107, -113.94522285])


closest
(2,)


array([3, 1])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 2.8594, -3.6250, -1.7344,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1045,  0.3281, -0.4727,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.2207, -1.5625, -0.1309,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.0078, -2.7500, -6.5625,  ...,  0.0000,  0.0000,  0.0000],
        [-0.5430, -2.4062,  5.9062,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.7852, -4.3125, -5.2500,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 1.7603e-06, 3.4663e-07,  ..., 2.9578e-23, 1.7940e-23,
         3.5326e-24],
        [8.3082e-01, 1.2741e-01, 3.6504e-02,  ..., 4.0516e-23, 4.0516e-23,
         3.5755e-23],
        [8.5241e-01, 6.9970e-02, 5.4493e-02,  ..., 3.0716e-22, 1.2804e-22,
         7.7661e-23],
        ...,
        [1.0000e+00, 1.2752e-07, 4.6912e-08,  ..., 1.8909e-24, 6.1388e-25,
         3.2859e-25],
        [6.8760e-01, 2.2323e-01, 4.9809e-02,  ..., 1.5329e-20, 1.0535e-20,
         8.6479e-22],
        [9.9998e-01, 1.3007e-05, 1.0130e-05,  ..., 5.2428e-22, 5.2428e-22,
         3.3516e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8308, 0.9582, 0.9947,  ..., 1.0000, 1.0000, 1.0000],
        [0.8524, 0.9224, 0.9769,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.6876, 0.9108, 0.9606,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([25])


tensor([ 0,  1,  1,  2,  2,  3,  4,  5,  5,  6,  7,  8,  9, 10, 11, 12, 12, 13,
        14, 15, 16, 17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([25, 39])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871


carryover_candidate_logprobs
torch.Size([25])


tensor([ -4.9226,  -6.1726,  -6.1726,  -7.1726,  -7.1726,  -5.1658,  -6.4163,
         -5.9024,  -5.9024,  -7.9306,  -9.4306,  -7.2942,  -7.9192,  -4.1294,
         -5.5183,  -6.3401,  -6.3401,  -7.5188,  -8.1438,  -7.3620, -10.1120,
         -5.1652,  -7.0402,  -7.0402,  -5.4764], device='cuda:0')


new_candidate_toks
torch.Size([25, 1])


tensor([[  310],
        [ 3171],
        [11858],
        [ 3171],
        [11858],
        [  379],
        [  379],
        [29906],
        [29947],
        [  573],
        [  362],
        [29892],
        [29929],
        [ 3039],
        [29871],
        [29947],
        [29906],
        [29892],
        [29929],
        [  362],
        [ 4279],
        [29871],
        [14831],
        [10478],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([25])


tensor([-2.6226e-06, -1.8535e-01, -2.0603e+00, -1.5969e-01, -2.6597e+00,
        -4.1285e-04, -7.9111e-02, -4.7408e-01, -9.7408e-01,  0.0000e+00,
        -4.0770e-05, -2.0266e-06, -1.1921e-07, -2.1458e-06, -1.1921e-07,
        -2.8115e-01, -1.4062e+00, -1.1921e-07, -1.1921e-07, -2.5034e-06,
        -9.5367e-07, -2.3842e-07, -3.7455e-01, -1.4996e+00, -2.3842e-05],
       device='cuda:0')


new_candidates
torch.Size([25, 40])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573, 11858],
        [    1, 32010,  1724,   338,   278,  9939


new_candidate_logprobs
torch.Size([25])


tensor([ -4.9226,  -6.3579,  -8.2329,  -7.3322,  -9.8322,  -5.1662,  -6.4954,
         -6.3765,  -6.8765,  -7.9306,  -9.4306,  -7.2942,  -7.9192,  -4.1294,
         -5.5183,  -6.6213,  -7.7463,  -7.5188,  -8.1438,  -7.3620, -10.1120,
         -5.1652,  -7.4147,  -8.5397,  -5.4764], device='cuda:0')

infer end: GPU memory used: 17271 MB.
event: level
id: 28
data: [{"content": "of", "parent": 0, "prob": -4.922562599182129}, {"content": "height", "parent": 1, "prob": -6.357900142669678}, {"content": "elev", "parent": 1, "prob": -8.232900619506836}, {"content": "height", "parent": 2, "prob": -7.332248687744141}, {"content": "elev", "parent": 2, "prob": -9.83224868774414}, {"content": "H", "parent": 3, "prob": -5.166248798370361}, {"content": "H", "parent": 4, "prob": -6.495448589324951}, {"content": "2", "parent": 5, "prob": -6.376521110534668}, {"content": "8", "parent": 5, "prob": -6.876521110534668}, {"content": "ive", "parent": 6, "prob": -7.930551528930664}, {"content": "ation", "parent": 7, "prob": -9.43059253692627}, {"content": ",", "parent": 8, "prob": -7.294222831726074}, {"content": "9", "parent": 9, "prob": -7.919220924377441}, {"content": "imal", "parent": 10, "prob": -4.129382133483887}, {"content": "", "parent": 11, "prob": -5.51832389831543}, {"content": "8", "parent":

array([[ 0.11767578,  1.484375  ,  2.046875  , ..., -1.2890625 ,
         1.828125  ,  1.359375  ],
       [ 1.7109375 , -0.86328125,  0.25585938, ..., -0.43359375,
        -0.5625    ,  3.234375  ],
       [ 2.        , -2.890625  ,  2.1875    , ...,  2.515625  ,
        -1.078125  ,  0.40429688],
       ...,
       [ 0.89453125, -2.28125   ,  2.25      , ...,  2.546875  ,
         2.65625   ,  2.015625  ],
       [-1.6640625 , -2.953125  , -0.16015625, ...,  1.625     ,
         0.51171875,  2.125     ],
       [ 0.71875   , -0.08007812,  1.1796875 , ...,  0.15136719,
        -0.12890625,  2.96875   ]], dtype=float32)


k_mean_space
(20, 2)


array([[65.199234, 85.374115],
       [49.67933 , 90.667244],
       [66.58165 , 92.822655],
       [49.81128 , 90.46457 ],
       [66.356804, 92.615555],
       [97.09263 , 77.02731 ],
       [97.09672 , 76.97668 ],
       [92.083405, 64.25718 ],
       [90.02295 , 51.41108 ],
       [61.128407, 87.84298 ],
       [49.53425 , 90.187355],
       [87.84377 , 59.67683 ],
       [93.92931 , 61.45652 ],
       [97.16072 , 85.11417 ],
       [77.47883 , 78.20372 ],
       [90.458496, 51.427315],
       [92.60038 , 64.405014],
       [87.244736, 60.646393],
       [93.70486 , 61.46208 ],
       [48.301632, 90.835846]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-66.91937351, -74.28782892])


closest
(2,)


array([19,  8])


last_tok_logits
torch.Size([20, 32064])


tensor([[-1.2656, -4.5000, -4.4688,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.2812, -2.4844, -3.7656,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2070, -2.8125,  5.5000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 3.5781,  2.0781, -1.6016,  ...,  0.0000,  0.0000,  0.0000],
        [-1.9922, -1.0938, -8.6875,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.3750, -3.8906, -2.2812,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.8337e-01, 1.5895e-02, 6.1630e-04,  ..., 1.4015e-21, 7.5015e-22,
         3.5435e-22],
        [1.0000e+00, 3.9279e-07, 7.7344e-08,  ..., 6.1388e-25, 5.4175e-25,
         2.5590e-25],
        [1.0000e+00, 4.1399e-08, 2.8453e-08,  ..., 5.7100e-26, 3.0563e-26,
         1.1244e-26],
        ...,
        [9.9994e-01, 5.1442e-05, 6.1439e-06,  ..., 1.1099e-21, 3.6032e-22,
         8.3075e-26],
        [1.0000e+00, 3.2242e-08, 1.4166e-09,  ..., 1.1469e-24, 7.8824e-25,
         5.4175e-25],
        [1.0000e+00, 1.5535e-06, 1.2752e-07,  ..., 7.4786e-24, 5.8243e-24,
         4.7809e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9834, 0.9993, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([23])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  9,  9, 10, 11, 12, 13, 14, 14,
        15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([23, 40])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573, 11858],
        [    1, 32010,  1724,   338,   278,  9939


carryover_candidate_logprobs
torch.Size([23])


tensor([-4.9226, -6.3579, -8.2329, -7.3322, -9.8322, -5.1662, -6.4954, -6.3765,
        -6.8765, -7.9306, -7.9306, -7.9306, -9.4306, -7.2942, -7.9192, -4.1294,
        -5.5183, -5.5183, -6.6213, -7.7463, -7.5188, -8.1438, -7.3620],
       device='cuda:0')


new_candidate_toks
torch.Size([23, 1])


tensor([[14235],
        [  310],
        [  362],
        [  310],
        [  362],
        [ 3039],
        [ 3039],
        [29929],
        [29892],
        [29871],
        [ 3171],
        [11858],
        [  310],
        [29947],
        [29892],
        [  388],
        [29947],
        [29906],
        [29892],
        [29929],
        [29947],
        [29892],
        [  310]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([23])


tensor([-1.6771e-02, -4.7684e-07,  0.0000e+00, -4.7684e-07, -1.1921e-07,
        -9.5367e-07, -1.3113e-06,  0.0000e+00, -3.5763e-07, -5.9171e-01,
        -1.2167e+00, -2.2167e+00, -1.4305e-06, -4.7684e-06,  0.0000e+00,
        -5.2341e-03, -4.7408e-01, -9.7408e-01, -7.1526e-07, -1.1921e-07,
        -5.7818e-05,  0.0000e+00, -1.9074e-06], device='cuda:0')


new_candidates
torch.Size([23, 41])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362,   310,
         14235],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573,  3171,
           310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573, 11858,
           362],
 


new_candidate_logprobs
torch.Size([23])


tensor([ -4.9393,  -6.3579,  -8.2329,  -7.3322,  -9.8322,  -5.1662,  -6.4955,
         -6.3765,  -6.8765,  -8.5223,  -9.1473, -10.1473,  -9.4306,  -7.2942,
         -7.9192,  -4.1346,  -5.9924,  -6.4924,  -6.6213,  -7.7463,  -7.5189,
         -8.1438,  -7.3620], device='cuda:0')

infer end: GPU memory used: 17369 MB.
event: level
id: 29
data: [{"content": "approximately", "parent": 0, "prob": -4.939333915710449}, {"content": "of", "parent": 1, "prob": -6.357900619506836}, {"content": "ation", "parent": 2, "prob": -8.232900619506836}, {"content": "of", "parent": 3, "prob": -7.332249164581299}, {"content": "ation", "parent": 4, "prob": -9.83224868774414}, {"content": "imal", "parent": 5, "prob": -5.166249752044678}, {"content": "imal", "parent": 6, "prob": -6.495450019836426}, {"content": "9", "parent": 7, "prob": -6.376521110534668}, {"content": ",", "parent": 8, "prob": -6.876521587371826}, {"content": "", "parent": 9, "prob": -8.522257804870605}, {"content": "height", "parent": 9, "prob": -9.147257804870605}, {"content": "elev", "parent": 9, "prob": -10.147258758544922}, {"content": "of", "parent": 10, "prob": -9.430594444274902}, {"content": "8", "parent": 11, "prob": -7.294227600097656}, {"content": ",", "parent": 12, "prob": -7.919220924377441}, {"content":

array([[-0.73828125,  0.3359375 ,  0.36523438, ..., -0.17578125,
         0.97265625,  1.4453125 ],
       [ 0.65234375,  1.3671875 ,  1.6875    , ..., -0.85546875,
         1.8515625 ,  1.625     ],
       [ 0.42382812, -0.95703125,  0.69140625, ...,  0.26367188,
        -0.78125   ,  3.453125  ],
       ...,
       [-0.65625   , -0.640625  ,  2.3125    , ...,  0.19921875,
         0.89453125,  0.11767578],
       [ 0.84765625, -2.515625  ,  2.140625  , ...,  2.46875   ,
         2.03125   ,  1.859375  ],
       [-1.5859375 , -3.421875  , -0.42382812, ...,  1.6015625 ,
         0.56640625,  2.015625  ]], dtype=float32)


k_mean_space
(20, 2)


array([[61.257538, 84.30941 ],
       [49.323547, 82.17404 ],
       [51.43592 , 88.60801 ],
       [49.137306, 82.17531 ],
       [51.896137, 88.709496],
       [94.33668 , 77.67006 ],
       [93.79246 , 77.9339  ],
       [93.59517 , 60.988064],
       [87.228645, 57.79818 ],
       [79.34617 , 73.86726 ],
       [56.89569 , 87.637596],
       [77.91035 , 90.68908 ],
       [50.376293, 82.980804],
       [90.86692 , 66.233444],
       [94.50982 , 71.4735  ],
       [86.34668 , 88.31911 ],
       [89.18586 , 56.7171  ],
       [91.643555, 68.48902 ],
       [87.16658 , 57.867218],
       [93.70446 , 61.11695 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-69.55436039, -75.50282049])


closest
(2,)


array([ 3, 16])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.0625, -3.4062, -9.0625,  ...,  0.0000,  0.0000,  0.0000],
        [-0.4199, -3.5312, -3.5781,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.8750, -6.4375, -1.3359,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-2.5312, -0.5312, -5.7500,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.6094,  1.3438, -2.2344,  ...,  0.0000,  0.0000,  0.0000],
        [-2.3750, -1.9844, -8.8125,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 3.3983e-09, 3.3983e-09,  ..., 1.5521e-25, 1.0668e-25,
         8.3079e-26],
        [9.8144e-01, 1.7976e-02, 4.7904e-04,  ..., 4.5410e-22, 3.5365e-22,
         1.8930e-22],
        [1.0000e+00, 1.2752e-07, 2.5110e-08,  ..., 2.2583e-25, 1.3697e-25,
         1.2088e-25],
        ...,
        [1.0000e+00, 6.3488e-09, 4.5991e-10,  ..., 1.6359e-26, 5.3111e-27,
         2.5088e-27],
        [9.9998e-01, 1.8925e-05, 2.5612e-06,  ..., 2.4765e-22, 2.1855e-22,
         6.4701e-26],
        [1.0000e+00, 2.5110e-08, 1.8190e-09,  ..., 6.1388e-25, 6.1388e-25,
         2.2583e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9814, 0.9994, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  1,  2,  3,  4,  5,  5,  6,  7,  8,  9,  9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 41])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362,   310,
         14235],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573,  3171,
           310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573, 11858,
           362],
 


carryover_candidate_logprobs
torch.Size([22])


tensor([ -4.9393,  -6.3579,  -8.2329,  -7.3322,  -9.8322,  -5.1662,  -5.1662,
         -6.4955,  -6.3765,  -6.8765,  -8.5223,  -8.5223,  -9.1473, -10.1473,
         -9.4306,  -7.2942,  -7.9192,  -4.1346,  -5.9924,  -6.4924,  -6.6213,
         -7.7463], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[29871],
        [14235],
        [  310],
        [14235],
        [  310],
        [ 9010],
        [  388],
        [  388],
        [29892],
        [29947],
        [29906],
        [29947],
        [  310],
        [  362],
        [14235],
        [29946],
        [29900],
        [  294],
        [29892],
        [29929],
        [29947],
        [29892]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([ 0.0000e+00, -1.8735e-02, -1.1921e-07, -3.0551e-02, -2.3842e-07,
        -1.6023e-01, -1.9102e+00, -1.2510e-02,  0.0000e+00, -9.2984e-06,
        -2.5193e-01, -1.5019e+00, -3.5763e-07, -2.3842e-07, -1.0348e-02,
        -1.1921e-07,  0.0000e+00, -2.6128e-04, -1.1921e-07,  0.0000e+00,
        -2.1577e-05,  0.0000e+00], device='cuda:0')


new_candidates
torch.Size([22, 42])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362,   310,
         14235, 29871],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573,  3171,
           310, 14235],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 21210,   573, 11858,
    


new_candidate_logprobs
torch.Size([22])


tensor([ -4.9393,  -6.3766,  -8.2329,  -7.3628,  -9.8322,  -5.3265,  -7.0765,
         -6.5080,  -6.3765,  -6.8765,  -8.7742, -10.0242,  -9.1473, -10.1473,
         -9.4409,  -7.2942,  -7.9192,  -4.1349,  -5.9924,  -6.4924,  -6.6213,
         -7.7463], device='cuda:0')

infer end: GPU memory used: 17471 MB.
event: level
id: 30
data: [{"content": "", "parent": 0, "prob": -4.939333915710449}, {"content": "approximately", "parent": 1, "prob": -6.3766350746154785}, {"content": "of", "parent": 2, "prob": -8.232900619506836}, {"content": "approximately", "parent": 3, "prob": -7.362800121307373}, {"content": "of", "parent": 4, "prob": -9.83224868774414}, {"content": "aya", "parent": 5, "prob": -5.326475143432617}, {"content": "ay", "parent": 5, "prob": -7.076475143432617}, {"content": "ay", "parent": 6, "prob": -6.507959842681885}, {"content": ",", "parent": 7, "prob": -6.376521110534668}, {"content": "8", "parent": 8, "prob": -6.87653112411499}, {"content": "2", "parent": 9, "prob": -8.774191856384277}, {"content": "8", "parent": 9, "prob": -10.024191856384277}, {"content": "of", "parent": 10, "prob": -9.147257804870605}, {"content": "ation", "parent": 11, "prob": -10.147258758544922}, {"content": "approximately", "parent": 12, "prob": -9.44094181060791}, {

array([[-0.0703125 ,  0.78125   ,  2.609375  , ...,  2.375     ,
         2.59375   ,  1.1875    ],
       [-0.3984375 ,  0.34179688,  0.83203125, ...,  0.23535156,
         0.8515625 ,  1.296875  ],
       [ 0.5390625 ,  1.1171875 ,  1.953125  , ..., -1.046875  ,
         2.234375  ,  1.4921875 ],
       ...,
       [ 0.83203125, -0.296875  , -1.40625   , ..., -1.71875   ,
        -0.79296875,  1.0546875 ],
       [ 0.90625   , -2.765625  ,  2.171875  , ...,  2.71875   ,
         2.171875  ,  1.765625  ],
       [-1.6171875 , -3.40625   , -0.24707031, ...,  1.8984375 ,
         0.70703125,  2.078125  ]], dtype=float32)


k_mean_space
(20, 2)


array([[77.39951 , 70.78623 ],
       [85.78515 , 52.004368],
       [85.09017 , 49.801872],
       [85.88702 , 52.456463],
       [85.06984 , 49.76431 ],
       [91.883995, 76.31758 ],
       [94.056   , 78.49978 ],
       [92.95781 , 76.39675 ],
       [69.25408 , 91.759224],
       [62.396065, 88.55771 ],
       [65.0766  , 88.421555],
       [57.418827, 84.71892 ],
       [84.81037 , 49.96564 ],
       [89.225296, 70.54723 ],
       [86.412315, 52.49837 ],
       [64.18688 , 85.83003 ],
       [65.01243 , 82.47455 ],
       [92.28941 , 78.53021 ],
       [60.103302, 81.78563 ],
       [62.19979 , 91.39416 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-59.74968624, -88.5251646 ])


closest
(2,)


array([11,  4])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.3652,  0.0601, -4.2500,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.9961, -4.5938, -8.9375,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9727, -4.4375, -3.9688,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.8750, -5.1875, -3.9688,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.0781,  1.6797, -2.6719,  ...,  0.0000,  0.0000,  0.0000],
        [-3.5312, -2.3906, -8.6875,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[7.0579e-01, 2.9421e-01, 5.4050e-09,  ..., 4.7513e-22, 4.7513e-22,
         2.8818e-22],
        [1.0000e+00, 7.1941e-09, 3.8507e-09,  ..., 2.2583e-25, 1.3697e-25,
         8.3079e-26],
        [8.6480e-01, 1.3262e-01, 2.4291e-03,  ..., 2.7501e-22, 2.4269e-22,
         6.1362e-23],
        ...,
        [8.3915e-01, 1.4582e-01, 6.4070e-03,  ..., 1.0214e-18, 1.0214e-18,
         3.0844e-20],
        [9.9999e-01, 7.8892e-06, 2.2603e-06,  ..., 3.1799e-22, 1.3256e-22,
         4.4469e-26],
        [1.0000e+00, 4.3635e-09, 4.5991e-10,  ..., 4.2191e-25, 3.7234e-25,
         1.5521e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.7058, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8648, 0.9974, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.8392, 0.9850, 0.9914,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([25])


tensor([ 0,  0,  1,  2,  2,  3,  4,  4,  5,  6,  6,  7,  8,  9, 10, 11, 12, 13,
        14, 15, 16, 17, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([25, 42])


tensor([[    1, 32010,  1724,  ...,   310, 14235, 29871],
        [    1, 32010,  1724,  ...,   310, 14235, 29871],
        [    1, 32010,  1724,  ...,  3171,   310, 14235],
        ...,
        [    1, 32010,  1724,  ...,  3039,   388,   294],
        [    1, 32010,  1724,  ..., 29871, 29947, 29892],
        [    1, 32010,  1724,  ..., 29871, 29906, 29929]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([25])


tensor([ -4.9393,  -4.9393,  -6.3766,  -8.2329,  -8.2329,  -7.3628,  -9.8322,
         -9.8322,  -5.3265,  -7.0765,  -7.0765,  -6.5080,  -6.3765,  -6.8765,
         -8.7742, -10.0242,  -9.1473, -10.1473,  -9.4409,  -7.2942,  -7.9192,
         -4.1349,  -4.1349,  -5.9924,  -6.4924], device='cuda:0')


new_candidate_toks
torch.Size([25, 1])


tensor([[29906],
        [29947],
        [29871],
        [14235],
        [29871],
        [29871],
        [14235],
        [29871],
        [ 3464],
        [  294],
        [  273],
        [  294],
        [29900],
        [29946],
        [29929],
        [29892],
        [14235],
        [  310],
        [29871],
        [29947],
        [29906],
        [  373],
        [29892],
        [29947],
        [29892]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([25])


tensor([-3.4844e-01, -1.2234e+00,  0.0000e+00, -1.4525e-01, -2.0203e+00,
         0.0000e+00, -1.6275e-01, -1.9127e+00, -6.9055e-03, -3.1326e-01,
        -1.3133e+00, -3.3539e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00, -2.6984e-02, -1.1921e-07,  0.0000e+00, -7.0990e-04,
        -1.0021e-01, -1.7536e-01, -1.9254e+00, -1.0252e-05,  0.0000e+00],
       device='cuda:0')


new_candidates
torch.Size([25, 43])


tensor([[    1, 32010,  1724,  ..., 14235, 29871, 29906],
        [    1, 32010,  1724,  ..., 14235, 29871, 29947],
        [    1, 32010,  1724,  ...,   310, 14235, 29871],
        ...,
        [    1, 32010,  1724,  ...,   388,   294, 29892],
        [    1, 32010,  1724,  ..., 29947, 29892, 29947],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892]], device='cuda:0')


new_candidate_logprobs
torch.Size([25])


tensor([ -5.2878,  -6.1628,  -6.3766,  -8.3782, -10.2532,  -7.3628,  -9.9950,
        -11.7450,  -5.3334,  -7.3897,  -8.3897,  -6.5083,  -6.3765,  -6.8765,
         -8.7742, -10.0242,  -9.1742, -10.1473,  -9.4409,  -7.2949,  -8.0194,
         -4.3102,  -6.0602,  -5.9924,  -6.4924], device='cuda:0')

infer end: GPU memory used: 17575 MB.
event: level
id: 31
data: [{"content": "2", "parent": 0, "prob": -5.287778377532959}, {"content": "8", "parent": 0, "prob": -6.162778377532959}, {"content": "", "parent": 1, "prob": -6.3766350746154785}, {"content": "approximately", "parent": 2, "prob": -8.378152847290039}, {"content": "", "parent": 2, "prob": -10.253152847290039}, {"content": "", "parent": 3, "prob": -7.362800121307373}, {"content": "approximately", "parent": 4, "prob": -9.994996070861816}, {"content": "", "parent": 4, "prob": -11.744997024536133}, {"content": "range", "parent": 5, "prob": -5.333380699157715}, {"content": "as", "parent": 6, "prob": -7.389737129211426}, {"content": "an", "parent": 6, "prob": -8.389737129211426}, {"content": "as", "parent": 7, "prob": -6.508295059204102}, {"content": "0", "parent": 8, "prob": -6.376521110534668}, {"content": "4", "parent": 9, "prob": -6.87653112411499}, {"content": "9", "parent": 10, "prob": -8.774191856384277}, {"content": ",", "pa

array([[-0.66796875, -0.625     ,  2.140625  , ...,  0.41210938,
         0.7265625 ,  0.08349609],
       [-0.31640625, -3.640625  ,  1.625     , ...,  2.078125  ,
         2.90625   ,  0.8203125 ],
       [ 0.01660156,  0.61328125,  2.390625  , ...,  2.515625  ,
         2.53125   ,  0.9921875 ],
       ...,
       [-0.34179688,  1.09375   ,  2.078125  , ..., -0.58203125,
         2.171875  ,  1.765625  ],
       [-0.23339844,  0.890625  ,  2.59375   , ...,  2.3125    ,
         2.703125  ,  1.2109375 ],
       [-1.6953125 ,  0.02319336, -0.171875  , ..., -2.640625  ,
         1.75      ,  1.3515625 ]], dtype=float32)


k_mean_space
(20, 2)


array([[101.77995 ,  72.347015],
       [100.95052 ,  67.51747 ],
       [101.01085 ,  48.125725],
       [ 99.35344 ,  61.134945],
       [100.941185,  49.31276 ],
       [100.97314 ,  48.103878],
       [ 99.29146 ,  61.574966],
       [100.919235,  49.281494],
       [ 52.550846,  91.9714  ],
       [ 39.740692,  92.95759 ],
       [ 60.94179 ,  92.16094 ],
       [ 45.080822,  92.41913 ],
       [ 94.88489 ,  72.271545],
       [101.56849 ,  72.68459 ],
       [101.733154,  79.84658 ],
       [100.58074 ,  66.01972 ],
       [ 98.92479 ,  60.68345 ],
       [ 96.80264 ,  63.20805 ],
       [101.04279 ,  48.539764],
       [102.40139 ,  80.084755]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -27.62115002, -133.67010689])


closest
(2,)


array([9, 5])


last_tok_logits
torch.Size([20, 32064])


tensor([[-2.2031, -0.3828, -5.7500,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1240, -2.7812, -6.9062,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2480, -0.1523, -4.0625,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.0781, -4.2500, -4.8750,  ...,  0.0000,  0.0000,  0.0000],
        [-0.4004, -0.1426, -4.3438,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.4004, -3.2500, -1.3125,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 2.3356e-09, 4.5991e-10,  ..., 9.9224e-27, 7.7276e-27,
         2.2140e-27],
        [1.0000e+00, 4.1399e-08, 1.4166e-09,  ..., 3.4633e-26, 3.4633e-26,
         1.2741e-26],
        [7.3106e-01, 2.6894e-01, 7.6523e-09,  ..., 4.3432e-22, 3.3825e-22,
         2.0516e-22],
        ...,
        [9.1259e-01, 8.4884e-02, 2.2621e-03,  ..., 2.2601e-22, 2.2601e-22,
         5.0429e-23],
        [7.3106e-01, 2.6894e-01, 4.6413e-09,  ..., 6.3193e-22, 3.3825e-22,
         1.4100e-22],
        [9.7069e-01, 2.9312e-02, 2.1941e-06,  ..., 2.1921e-25, 1.3296e-25,
         2.6181e-26]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.7311, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9126, 0.9975, 0.9997,  ..., 1.0000, 1.0000, 1.0000],
        [0.7311, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9707, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([29])


tensor([ 0,  1,  2,  2,  3,  4,  4,  5,  5,  6,  7,  7,  8,  9, 10, 10, 11, 11,
        11, 12, 12, 13, 14, 15, 16, 17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([29, 43])


tensor([[    1, 32010,  1724,  ..., 14235, 29871, 29906],
        [    1, 32010,  1724,  ..., 14235, 29871, 29947],
        [    1, 32010,  1724,  ...,   310, 14235, 29871],
        ...,
        [    1, 32010,  1724,  ...,   310, 14235, 29871],
        [    1, 32010,  1724,  ...,   310, 14235, 29871],
        [    1, 32010,  1724,  ..., 29947, 29946, 29947]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([29])


tensor([ -5.2878,  -6.1628,  -6.3766,  -6.3766,  -8.3782, -10.2532, -10.2532,
         -7.3628,  -7.3628,  -9.9950, -11.7450, -11.7450,  -5.3334,  -7.3897,
         -8.3897,  -8.3897,  -6.5083,  -6.5083,  -6.5083,  -6.3765,  -6.3765,
         -6.8765,  -8.7742, -10.0242,  -9.1742, -10.1473,  -9.4409,  -9.4409,
         -7.2949], device='cuda:0')


new_candidate_toks
torch.Size([29, 1])


tensor([[29929],
        [29892],
        [29906],
        [29947],
        [29871],
        [29906],
        [29947],
        [29906],
        [29947],
        [29871],
        [29906],
        [29947],
        [  297],
        [  322],
        [14378],
        [ 3464],
        [  373],
        [  322],
        [29892],
        [29906],
        [29941],
        [29947],
        [29892],
        [29947],
        [29871],
        [14235],
        [29906],
        [29947],
        [29889]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([29])


tensor([ 0.0000e+00,  0.0000e+00, -3.1326e-01, -1.3133e+00,  0.0000e+00,
        -1.4268e-01, -2.0177e+00, -4.2870e-01, -1.0537e+00,  0.0000e+00,
        -1.7975e-01, -1.8047e+00, -1.6527e-02, -3.0549e-02, -3.1953e-01,
        -1.4445e+00, -6.6969e-01, -1.2947e+00, -1.6697e+00, -1.2693e-01,
        -2.1269e+00, -4.3062e-04,  0.0000e+00, -2.7418e-06,  0.0000e+00,
        -9.1471e-02, -3.1326e-01, -1.3133e+00, -2.9753e-02], device='cuda:0')


new_candidates
torch.Size([29, 44])


tensor([[    1, 32010,  1724,  ..., 29871, 29906, 29929],
        [    1, 32010,  1724,  ..., 29871, 29947, 29892],
        [    1, 32010,  1724,  ..., 14235, 29871, 29906],
        ...,
        [    1, 32010,  1724,  ..., 14235, 29871, 29906],
        [    1, 32010,  1724,  ..., 14235, 29871, 29947],
        [    1, 32010,  1724,  ..., 29946, 29947, 29889]], device='cuda:0')


new_candidate_logprobs
torch.Size([29])


tensor([ -5.2878,  -6.1628,  -6.6899,  -7.6899,  -8.3782, -10.3958, -12.2708,
         -7.7915,  -8.4165,  -9.9950, -11.9247, -13.5497,  -5.3499,  -7.4203,
         -8.7093,  -9.8343,  -7.1780,  -7.8030,  -8.1780,  -6.5034,  -8.5034,
         -6.8770,  -8.7742, -10.0242,  -9.1742, -10.2387,  -9.7542, -10.7542,
         -7.3247], device='cuda:0')

infer end: GPU memory used: 17681 MB.
event: level
id: 32
data: [{"content": "9", "parent": 0, "prob": -5.287778377532959}, {"content": ",", "parent": 1, "prob": -6.162778377532959}, {"content": "2", "parent": 2, "prob": -6.689896583557129}, {"content": "8", "parent": 2, "prob": -7.689896583557129}, {"content": "", "parent": 3, "prob": -8.378152847290039}, {"content": "2", "parent": 4, "prob": -10.395828247070312}, {"content": "8", "parent": 4, "prob": -12.270828247070312}, {"content": "2", "parent": 5, "prob": -7.791500568389893}, {"content": "8", "parent": 5, "prob": -8.41650104522705}, {"content": "", "parent": 6, "prob": -9.994996070861816}, {"content": "2", "parent": 7, "prob": -11.924741744995117}, {"content": "8", "parent": 7, "prob": -13.549741744995117}, {"content": "in", "parent": 8, "prob": -5.349907875061035}, {"content": "and", "parent": 9, "prob": -7.420285701751709}, {"content": "mountain", "parent": 10, "prob": -8.709269523620605}, {"content": "range", "parent": 10, "pr

array([[-1.6875000e+00, -3.4218750e+00, -2.6953125e-01, ...,
         1.7578125e+00,  7.8125000e-01,  2.1406250e+00],
       [ 1.0234375e+00, -2.7343750e+00,  2.0781250e+00, ...,
         2.6093750e+00,  2.2187500e+00,  1.6250000e+00],
       [-5.8203125e-01, -6.2109375e-01,  1.9687500e+00, ...,
         2.0214844e-01,  8.5937500e-01,  1.0937500e-01],
       ...,
       [-8.5937500e-01, -6.1718750e-01, -2.2812500e+00, ...,
        -1.4218750e+00,  4.4921875e-01,  2.3559570e-02],
       [-1.3750000e+00,  2.6757812e-01, -3.3906250e+00, ...,
        -3.1093750e+00, -3.8671875e-01,  2.2216797e-02],
       [ 1.0234375e+00, -3.4062500e+00, -2.1209717e-03, ...,
        -2.5390625e-02, -1.0625000e+00,  3.3203125e-01]], dtype=float32)


k_mean_space
(20, 2)


array([[68.57523 , 93.72384 ],
       [63.28459 , 92.69749 ],
       [51.78747 , 92.928406],
       [48.516747, 92.728264],
       [68.82899 , 90.49719 ],
       [52.8757  , 92.89696 ],
       [48.69975 , 92.54777 ],
       [52.048492, 92.98158 ],
       [48.57298 , 92.70879 ],
       [68.80169 , 90.45074 ],
       [53.10762 , 92.88735 ],
       [48.612698, 92.53363 ],
       [94.52488 , 69.76408 ],
       [95.38847 , 56.874302],
       [97.412865, 72.718704],
       [96.37502 , 62.482048],
       [96.71058 , 71.33169 ],
       [95.219185, 54.91192 ],
       [96.5922  , 55.502552],
       [79.12891 , 96.43001 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-115.0560894,  -54.4726882])


closest
(2,)


array([ 3, 17])


last_tok_logits
torch.Size([20, 32064])


tensor([[-3.8750, -2.9219, -9.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.2969,  1.8672, -2.5312,  ...,  0.0000,  0.0000,  0.0000],
        [-2.4844, -0.4980, -6.0938,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.2344, -8.5000, -6.6875,  ...,  0.0000,  0.0000,  0.0000],
        [-2.7031, -7.2188, -5.3438,  ...,  0.0000,  0.0000,  0.0000],
        [-2.2344, -4.7812, -3.3281,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 3.3983e-09, 5.2114e-10,  ..., 3.7234e-25, 1.3697e-25,
         7.3317e-26],
        [9.9999e-01, 6.9622e-06, 1.3709e-06,  ..., 3.1799e-22, 1.5021e-22,
         5.7099e-26],
        [1.0000e+00, 1.8190e-09, 2.7895e-10,  ..., 6.8196e-27, 4.1363e-27,
         1.9538e-27],
        ...,
        [4.2808e-01, 2.2913e-01, 1.0824e-01,  ..., 3.9782e-21, 3.5108e-21,
         2.7342e-21],
        [6.1575e-01, 1.5569e-01, 9.4428e-02,  ..., 9.4345e-21, 3.4707e-21,
         2.3854e-21],
        [1.0000e+00, 6.8256e-08, 5.6028e-09,  ..., 8.9319e-25, 8.9319e-25,
         6.9562e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.4281, 0.6572, 0.7654,  ..., 1.0000, 1.0000, 1.0000],
        [0.6157, 0.7714, 0.8659,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([31])


tensor([ 0,  1,  2,  3,  4,  4,  5,  6,  7,  8,  9,  9, 10, 11, 12, 13, 14, 15,
        15, 16, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([31, 44])


tensor([[    1, 32010,  1724,  ..., 29871, 29906, 29929],
        [    1, 32010,  1724,  ..., 29871, 29947, 29892],
        [    1, 32010,  1724,  ..., 14235, 29871, 29906],
        ...,
        [    1, 32010,  1724,  ...,   388,   294, 29892],
        [    1, 32010,  1724,  ...,   388,   294, 29892],
        [    1, 32010,  1724,  ..., 29892, 29900, 29906]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([31])


tensor([ -5.2878,  -6.1628,  -6.6899,  -7.6899,  -8.3782,  -8.3782, -10.3958,
        -12.2708,  -7.7915,  -8.4165,  -9.9950,  -9.9950, -11.9247, -13.5497,
         -5.3499,  -7.4203,  -8.7093,  -9.8343,  -9.8343,  -7.1780,  -7.8030,
         -7.8030,  -7.8030,  -7.8030,  -7.8030,  -7.8030,  -8.1780,  -8.1780,
         -8.1780,  -8.1780,  -6.5034], device='cuda:0')


new_candidate_toks
torch.Size([31, 1])


tensor([[29892],
        [29947],
        [29929],
        [29892],
        [29906],
        [29947],
        [29929],
        [29892],
        [29929],
        [29892],
        [29906],
        [29947],
        [29929],
        [29892],
        [14325],
        [  338],
        [ 3464],
        [  297],
        [  322],
        [  278],
        [22170],
        [  756],
        [  967],
        [15028],
        [12185],
        [  269],
        [  373],
        [13407],
        [  851],
        [ 1546],
        [29929]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([31])


tensor([ 0.0000e+00, -8.4639e-06,  0.0000e+00, -1.1921e-07, -2.8115e-01,
        -1.4062e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.1921e-07,
        -3.4844e-01, -1.2234e+00,  0.0000e+00,  0.0000e+00, -2.9639e-02,
        -7.5117e-02, -1.7405e-05, -1.1505e-01, -2.2401e+00, -1.6689e-06,
        -8.4845e-01, -1.4734e+00, -2.2234e+00, -2.7234e+00, -2.8484e+00,
        -2.9734e+00, -4.8492e-01, -1.8599e+00, -2.3599e+00, -2.9849e+00,
        -1.1921e-07], device='cuda:0')


new_candidates
torch.Size([31, 45])


tensor([[    1, 32010,  1724,  ..., 29906, 29929, 29892],
        [    1, 32010,  1724,  ..., 29947, 29892, 29947],
        [    1, 32010,  1724,  ..., 29871, 29906, 29929],
        ...,
        [    1, 32010,  1724,  ...,   294, 29892,   851],
        [    1, 32010,  1724,  ...,   294, 29892,  1546],
        [    1, 32010,  1724,  ..., 29900, 29906, 29929]], device='cuda:0')


new_candidate_logprobs
torch.Size([31])


tensor([ -5.2878,  -6.1628,  -6.6899,  -7.6899,  -8.6593,  -9.7843, -10.3958,
        -12.2708,  -7.7915,  -8.4165, -10.3434, -11.2184, -11.9247, -13.5497,
         -5.3795,  -7.4954,  -8.7093,  -9.9493, -12.0743,  -7.1780,  -8.6514,
         -9.2764, -10.0264, -10.5264, -10.6514, -10.7764,  -8.6629, -10.0379,
        -10.5379, -11.1629,  -6.5034], device='cuda:0')

infer end: GPU memory used: 17789 MB.
event: level
id: 33
data: [{"content": ",", "parent": 0, "prob": -5.287778377532959}, {"content": "8", "parent": 1, "prob": -6.162786960601807}, {"content": "9", "parent": 2, "prob": -6.689896583557129}, {"content": ",", "parent": 3, "prob": -7.689896583557129}, {"content": "2", "parent": 4, "prob": -8.659302711486816}, {"content": "8", "parent": 4, "prob": -9.784302711486816}, {"content": "9", "parent": 5, "prob": -10.395828247070312}, {"content": ",", "parent": 6, "prob": -12.270828247070312}, {"content": "9", "parent": 7, "prob": -7.791500568389893}, {"content": ",", "parent": 8, "prob": -8.41650104522705}, {"content": "2", "parent": 9, "prob": -10.343441009521484}, {"content": "8", "parent": 9, "prob": -11.218441009521484}, {"content": "9", "parent": 10, "prob": -11.924741744995117}, {"content": ",", "parent": 11, "prob": -13.549741744995117}, {"content": "Asia", "parent": 12, "prob": -5.379547119140625}, {"content": "is", "parent": 13, "prob":

array([[-2.1875    , -1.4765625 , -0.296875  , ...,  1.546875  ,
         0.3203125 ,  1.546875  ],
       [ 0.15039062, -2.09375   ,  0.05639648, ...,  1.59375   ,
        -0.33984375,  0.9296875 ],
       [-1.65625   , -3.359375  , -0.23144531, ...,  1.625     ,
         0.73046875,  2.140625  ],
       ...,
       [-1.3125    ,  0.27929688,  0.87890625, ..., -0.35351562,
         1.296875  ,  0.89453125],
       [-0.40234375, -0.16601562, -0.5390625 , ..., -1.140625  ,
         0.18554688,  0.16210938],
       [-1.78125   ,  3.140625  ,  1.6484375 , ...,  1.7421875 ,
         2.046875  ,  2.90625   ]], dtype=float32)


k_mean_space
(20, 2)


array([[74.4481  , 93.98    ],
       [70.5942  , 93.89801 ],
       [53.826454, 93.92943 ],
       [54.21012 , 92.85838 ],
       [68.08883 , 92.76134 ],
       [47.436253, 92.21164 ],
       [53.91446 , 94.002594],
       [55.22065 , 92.8229  ],
       [53.952915, 93.89214 ],
       [54.246006, 92.79006 ],
       [68.176254, 92.77921 ],
       [47.4886  , 92.22805 ],
       [53.97142 , 93.96721 ],
       [55.470493, 92.81512 ],
       [94.72799 , 60.22534 ],
       [96.016106, 59.72695 ],
       [95.43145 , 56.67249 ],
       [95.16564 , 67.980286],
       [95.99397 , 60.177197],
       [97.77292 , 70.0088  ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-130.18498755,  -50.78586912])


closest
(2,)


array([ 5, 16])


last_tok_logits
torch.Size([20, 32064])


tensor([[-1.4453, -0.8516, -4.5000,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.3750, -0.3359, -0.5625,  ...,  0.0000,  0.0000,  0.0000],
        [-3.6875, -2.9688, -9.3750,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 5.2500, -3.5000, -3.1875,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.2969, -6.7188, -6.8750,  ...,  0.0000,  0.0000,  0.0000],
        [-4.5312, -5.3438, -4.5312,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 4.2778e-11, 2.9401e-11,  ..., 3.0563e-26, 2.3803e-26,
         2.1006e-26],
        [1.0000e+00, 9.9312e-08, 4.8474e-11,  ..., 4.9401e-28, 1.8174e-28,
         5.2068e-29],
        [1.0000e+00, 4.3635e-09, 1.4166e-09,  ..., 2.5590e-25, 1.3697e-25,
         5.0390e-26],
        ...,
        [9.8107e-01, 1.2350e-02, 2.1461e-03,  ..., 2.2309e-18, 1.1941e-18,
         1.0538e-18],
        [7.6809e-01, 9.1736e-02, 8.0956e-02,  ..., 2.0451e-21, 1.4056e-21,
         5.8593e-22],
        [9.9987e-01, 4.5394e-05, 4.0060e-05,  ..., 1.4250e-21, 5.9402e-22,
         4.0827e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9811, 0.9934, 0.9956,  ..., 1.0000, 1.0000, 1.0000],
        [0.7681, 0.8598, 0.9408,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 45])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362,   310,
         14235, 29871, 29906, 29929, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362,   310,
         14235, 29871, 29947, 29892, 29947],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   73


carryover_candidate_logprobs
torch.Size([22])


tensor([ -5.2878,  -6.1628,  -6.6899,  -7.6899,  -8.6593,  -9.7843, -10.3958,
        -12.2708,  -7.7915,  -8.4165, -10.3434, -11.2184, -11.9247, -13.5497,
         -5.3795,  -7.4954,  -8.7093,  -9.9493, -12.0743, -12.0743, -12.0743,
         -7.1780], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[29900],
        [29946],
        [29892],
        [29947],
        [29929],
        [29892],
        [29892],
        [29947],
        [29892],
        [29947],
        [29929],
        [29892],
        [29892],
        [29947],
        [  322],
        [ 5982],
        [  322],
        [14325],
        [  338],
        [  967],
        [22170],
        [ 5139]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([ 0.0000e+00, -1.1921e-07,  0.0000e+00, -1.8239e-05,  0.0000e+00,
         0.0000e+00,  0.0000e+00, -2.5630e-05,  0.0000e+00, -1.4067e-05,
         0.0000e+00, -1.1921e-07,  0.0000e+00, -2.5034e-05, -7.8896e-02,
        -1.7636e-02, -1.0946e-02, -1.9112e-02, -2.6384e-01, -2.3888e+00,
        -2.5138e+00, -1.2506e-04], device='cuda:0')


new_candidates
torch.Size([22, 46])


tensor([[    1, 32010,  1724,  ..., 29929, 29892, 29900],
        [    1, 32010,  1724,  ..., 29892, 29947, 29946],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892],
        ...,
        [    1, 32010,  1724,  ...,  3464,   322,   967],
        [    1, 32010,  1724,  ...,  3464,   322, 22170],
        [    1, 32010,  1724,  ...,   373,   278,  5139]], device='cuda:0')


new_candidate_logprobs
torch.Size([22])


tensor([ -5.2878,  -6.1628,  -6.6899,  -7.6899,  -8.6593,  -9.7843, -10.3958,
        -12.2709,  -7.7915,  -8.4165, -10.3434, -11.2184, -11.9247, -13.5498,
         -5.4584,  -7.5130,  -8.7202,  -9.9684, -12.3382, -14.4632, -14.5882,
         -7.1781], device='cuda:0')

infer end: GPU memory used: 17901 MB.
event: level
id: 34
data: [{"content": "0", "parent": 0, "prob": -5.287778377532959}, {"content": "4", "parent": 1, "prob": -6.162786960601807}, {"content": ",", "parent": 2, "prob": -6.689896583557129}, {"content": "8", "parent": 3, "prob": -7.689914703369141}, {"content": "9", "parent": 4, "prob": -8.659302711486816}, {"content": ",", "parent": 5, "prob": -9.784302711486816}, {"content": ",", "parent": 6, "prob": -10.395828247070312}, {"content": "8", "parent": 7, "prob": -12.270853996276855}, {"content": ",", "parent": 8, "prob": -7.791500568389893}, {"content": "8", "parent": 9, "prob": -8.416515350341797}, {"content": "9", "parent": 10, "prob": -10.343441009521484}, {"content": ",", "parent": 11, "prob": -11.218441009521484}, {"content": ",", "parent": 12, "prob": -11.924741744995117}, {"content": "8", "parent": 13, "prob": -13.549766540527344}, {"content": "and", "parent": 14, "prob": -5.458443641662598}, {"content": "located", "parent": 15, 

array([[ 1.7656250e+00,  1.1015625e+00,  1.5937500e+00, ...,
         1.2187500e+00,  1.3671875e+00, -3.3416748e-03],
       [-1.6640625e+00, -2.8281250e+00,  9.1406250e-01, ...,
        -7.1093750e-01, -1.7578125e-01, -6.0546875e-01],
       [-2.2656250e+00, -1.5234375e+00, -2.6171875e-01, ...,
         1.4375000e+00,  3.3007812e-01,  1.5703125e+00],
       ...,
       [-3.6523438e-01, -5.9326172e-02,  1.8554688e-02, ...,
        -6.4843750e-01, -1.0234375e+00, -2.4218750e-01],
       [-4.7070312e-01,  1.6796875e+00, -8.4375000e-01, ...,
         6.7187500e-01,  8.2812500e-01,  8.5156250e-01],
       [-2.2656250e+00, -7.1093750e-01,  3.4843750e+00, ...,
         1.1640625e+00, -2.9375000e+00,  3.3281250e+00]], dtype=float32)


k_mean_space
(20, 2)


array([[71.74477 , 89.163635],
       [73.789154, 96.01921 ],
       [57.342957, 96.8481  ],
       [56.35747 , 96.294426],
       [63.520348, 96.39402 ],
       [65.357414, 94.55193 ],
       [57.065662, 96.57618 ],
       [56.269695, 95.84833 ],
       [57.17808 , 96.77075 ],
       [56.16857 , 96.18569 ],
       [63.51341 , 96.2971  ],
       [65.10734 , 94.47589 ],
       [57.087486, 96.46628 ],
       [56.22254 , 95.77005 ],
       [94.229515, 46.751297],
       [95.05763 , 63.41826 ],
       [94.43369 , 48.885162],
       [93.86058 , 65.94241 ],
       [93.793915, 53.06645 ],
       [92.12369 , 70.994965]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-130.18507051,  -58.46148443])


closest
(2,)


array([ 9, 14])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 5.5312, -1.4766,  2.1094,  ...,  0.0000,  0.0000,  0.0000],
        [-1.0391,  2.7500, -2.2188,  ...,  0.0000,  0.0000,  0.0000],
        [-1.4453, -1.1172, -4.4375,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.8320, -3.0312, -9.6250,  ...,  0.0000,  0.0000,  0.0000],
        [-0.5312, -5.3438, -7.2812,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.7344, -0.9102, -4.9062,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.2414e-01, 7.5858e-02, 5.9257e-08,  ..., 3.3300e-22, 2.2887e-22,
         1.2251e-22],
        [9.9974e-01, 2.6119e-04, 1.4927e-10,  ..., 2.3796e-26, 9.9198e-27,
         1.3425e-27],
        [1.0000e+00, 7.0529e-11, 2.9401e-11,  ..., 3.4633e-26, 3.4633e-26,
         3.0563e-26],
        ...,
        [9.0464e-01, 9.5348e-02, 1.1767e-05,  ..., 1.4322e-23, 1.4322e-23,
         2.4889e-24],
        [9.6964e-01, 2.2804e-02, 6.5334e-03,  ..., 7.7961e-23, 6.8801e-23,
         4.7286e-23],
        [4.4202e-01, 4.4202e-01, 3.6283e-02,  ..., 2.1987e-21, 9.1657e-22,
         3.3719e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9241, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9997, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9046, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9696, 0.9924, 0.9990,  ..., 1.0000, 1.0000, 1.0000],
        [0.4420, 0.8840, 0.9203,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([27])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 14, 14, 14,
        14, 15, 16, 16, 17, 18, 19, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([27, 46])


tensor([[    1, 32010,  1724,  ..., 29929, 29892, 29900],
        [    1, 32010,  1724,  ..., 29892, 29947, 29946],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892],
        ...,
        [    1, 32010,  1724,  ...,  3464,   322,   967],
        [    1, 32010,  1724,  ...,  3464,   322,   967],
        [    1, 32010,  1724,  ...,  3464,   322,   967]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([27])


tensor([ -5.2878,  -6.1628,  -6.6899,  -7.6899,  -8.6593,  -9.7843, -10.3958,
        -12.2709,  -7.7915,  -8.4165, -10.3434, -11.2184, -11.9247, -13.5498,
         -5.4584,  -5.4584,  -5.4584,  -5.4584,  -5.4584,  -7.5130,  -8.7202,
         -8.7202,  -9.9684, -12.3382, -14.4632, -14.4632, -14.4632],
       device='cuda:0')


new_candidate_toks
torch.Size([27, 1])


tensor([[29906],
        [29947],
        [29900],
        [29946],
        [29892],
        [29947],
        [29900],
        [29946],
        [29900],
        [29946],
        [29892],
        [29947],
        [29900],
        [29946],
        [22170],
        [  269],
        [  338],
        [  756],
        [15028],
        [  373],
        [  338],
        [  967],
        [  322],
        [ 5982],
        [11858],
        [19224],
        [ 2533]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([27])


tensor([-7.8890e-02, -2.6128e-04,  0.0000e+00, -1.1921e-07,  0.0000e+00,
        -2.2054e-05,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.1921e-07,
         0.0000e+00, -1.9074e-05,  0.0000e+00,  0.0000e+00, -7.3447e-01,
        -1.6095e+00, -2.6095e+00, -2.6095e+00, -2.6095e+00, -4.9591e-02,
        -1.6947e-01, -2.5445e+00, -1.0022e-01, -3.0831e-02, -8.1641e-01,
        -8.1641e-01, -3.3164e+00], device='cuda:0')


new_candidates
torch.Size([27, 47])


tensor([[    1, 32010,  1724,  ..., 29892, 29900, 29906],
        [    1, 32010,  1724,  ..., 29947, 29946, 29947],
        [    1, 32010,  1724,  ..., 29929, 29892, 29900],
        ...,
        [    1, 32010,  1724,  ...,   322,   967, 11858],
        [    1, 32010,  1724,  ...,   322,   967, 19224],
        [    1, 32010,  1724,  ...,   322,   967,  2533]], device='cuda:0')


new_candidate_logprobs
torch.Size([27])


tensor([ -5.3667,  -6.1630,  -6.6899,  -7.6899,  -8.6593,  -9.7843, -10.3958,
        -12.2709,  -7.7915,  -8.4165, -10.3434, -11.2185, -11.9247, -13.5498,
         -6.1929,  -7.0679,  -8.0679,  -8.0679,  -8.0679,  -7.5626,  -8.8897,
        -11.2647, -10.0687, -12.3690, -15.2796, -15.2796, -17.7796],
       device='cuda:0')

infer end: GPU memory used: 18015 MB.
event: level
id: 35
data: [{"content": "2", "parent": 0, "prob": -5.366668224334717}, {"content": "8", "parent": 1, "prob": -6.163048267364502}, {"content": "0", "parent": 2, "prob": -6.689896583557129}, {"content": "4", "parent": 3, "prob": -7.689914703369141}, {"content": ",", "parent": 4, "prob": -8.659302711486816}, {"content": "8", "parent": 5, "prob": -9.784324645996094}, {"content": "0", "parent": 6, "prob": -10.395828247070312}, {"content": "4", "parent": 7, "prob": -12.270853996276855}, {"content": "0", "parent": 8, "prob": -7.791500568389893}, {"content": "4", "parent": 9, "prob": -8.416515350341797}, {"content": ",", "parent": 10, "prob": -10.343441009521484}, {"content": "8", "parent": 11, "prob": -11.218460083007812}, {"content": "0", "parent": 12, "prob": -11.924741744995117}, {"content": "4", "parent": 13, "prob": -13.549766540527344}, {"content": "reaches", "parent": 14, "prob": -6.192916393280029}, {"content": "s", "parent": 14, "p

array([[ 0.90234375, -3.359375  ,  0.09863281, ..., -0.19335938,
        -1.2109375 ,  0.36328125],
       [-1.953125  , -0.37890625, -0.1328125 , ..., -2.421875  ,
         1.5390625 ,  1.2734375 ],
       [ 1.890625  ,  1.109375  ,  1.6484375 , ...,  1.28125   ,
         1.5078125 , -0.05004883],
       ...,
       [-1.359375  ,  0.38476562,  1.2265625 , ...,  2.421875  ,
        -1.53125   , -0.16210938],
       [-1.109375  ,  2.1875    ,  0.9609375 , ...,  1.9765625 ,
         2.828125  ,  1.2265625 ],
       [-2.4375    ,  0.15039062,  1.3203125 , ...,  2.28125   ,
         0.5390625 ,  1.0234375 ]], dtype=float32)


k_mean_space
(20, 2)


array([[90.85442 , 69.78236 ],
       [93.94159 , 75.80857 ],
       [79.86469 , 55.424717],
       [90.54539 , 57.14542 ],
       [48.50851 , 82.68598 ],
       [87.92536 , 68.985855],
       [82.14858 , 56.94947 ],
       [90.28612 , 56.531586],
       [79.71657 , 55.26185 ],
       [90.293945, 56.821884],
       [48.33104 , 82.482605],
       [87.841286, 68.90491 ],
       [81.84246 , 56.660736],
       [90.11753 , 56.25722 ],
       [90.98103 , 79.59023 ],
       [74.66222 , 87.49599 ],
       [86.91258 , 80.19817 ],
       [89.45147 , 78.532295],
       [89.17013 , 77.567276],
       [73.29888 , 86.55531 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -33.63328981, -141.65818596])


closest
(2,)


array([10,  8])


last_tok_logits
torch.Size([20, 32064])


tensor([[-2.2812, -5.0625, -3.5312,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0261, -3.0625, -0.8164,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.2188, -1.6172,  1.8516,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.9609, -1.1172, -5.5625,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.2500, -1.0078, -5.2188,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.6406, -5.0312,  5.1562,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 4.6912e-08, 4.9445e-09,  ..., 1.1469e-24, 6.9562e-25,
         6.1388e-25],
        [8.1756e-01, 1.8242e-01, 1.2050e-05,  ..., 1.0625e-24, 3.4494e-25,
         2.3707e-25],
        [9.4660e-01, 5.3403e-02, 5.7019e-08,  ..., 3.8651e-22, 3.4109e-22,
         1.1074e-22],
        ...,
        [9.5782e-01, 4.2084e-02, 5.5836e-05,  ..., 1.6303e-22, 1.2697e-22,
         5.2929e-23],
        [9.0750e-01, 5.1198e-02, 3.9873e-02,  ..., 5.3914e-22, 3.2701e-22,
         4.4255e-23],
        [1.0000e+00, 1.0677e-06, 6.4759e-07,  ..., 1.5021e-22, 1.5021e-22,
         9.6026e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8176, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9466, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9578, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9075, 0.9587, 0.9986,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  1,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 47])


tensor([[    1, 32010,  1724,  ..., 29892, 29900, 29906],
        [    1, 32010,  1724,  ..., 29947, 29946, 29947],
        [    1, 32010,  1724,  ..., 29947, 29946, 29947],
        ...,
        [    1, 32010,  1724,  ..., 14325,   322,   756],
        [    1, 32010,  1724,  ..., 14325,   322, 15028],
        [    1, 32010,  1724,  ...,   338,  5982,   373]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([22])


tensor([ -5.3667,  -6.1630,  -6.1630,  -6.6899,  -7.6899,  -8.6593,  -9.7843,
        -10.3958, -12.2709,  -7.7915,  -8.4165, -10.3434, -11.2185, -11.9247,
        -13.5498,  -6.1929,  -7.0679,  -8.0679,  -8.0679,  -8.0679,  -8.0679,
         -7.5626], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[29929],
        [29889],
        [27881],
        [29906],
        [29947],
        [29900],
        [29946],
        [29906],
        [29947],
        [29906],
        [29947],
        [29900],
        [29946],
        [29906],
        [29947],
        [  385],
        [ 1169],
        [ 5982],
        [ 7258],
        [  385],
        [  472],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([ 0.0000e+00, -2.0143e-01, -1.7014e+00, -5.4882e-02, -5.5299e-04,
         0.0000e+00, -1.1921e-07, -3.8041e-02, -1.0896e-04, -3.8042e-02,
        -7.0990e-04,  0.0000e+00, -1.1921e-07, -2.9750e-02, -1.5844e-04,
        -4.2458e-03,  0.0000e+00, -1.1816e-01, -2.9932e+00, -4.3095e-02,
        -9.7064e-02, -4.8876e-06], device='cuda:0')


new_candidates
torch.Size([22, 48])


tensor([[    1, 32010,  1724,  ..., 29900, 29906, 29929],
        [    1, 32010,  1724,  ..., 29946, 29947, 29889],
        [    1, 32010,  1724,  ..., 29946, 29947, 27881],
        ...,
        [    1, 32010,  1724,  ...,   322,   756,   385],
        [    1, 32010,  1724,  ...,   322, 15028,   472],
        [    1, 32010,  1724,  ...,  5982,   373,   278]], device='cuda:0')


new_candidate_logprobs
torch.Size([22])


tensor([ -5.3667,  -6.3645,  -7.8645,  -6.7448,  -7.6905,  -8.6593,  -9.7843,
        -10.4339, -12.2710,  -7.8295,  -8.4172, -10.3434, -11.2185, -11.9545,
        -13.5499,  -6.1972,  -7.0679,  -8.1861, -11.0611,  -8.1110,  -8.1650,
         -7.5626], device='cuda:0')

infer end: GPU memory used: 18131 MB.
event: level
id: 36
data: [{"content": "9", "parent": 0, "prob": -5.366668224334717}, {"content": ".", "parent": 1, "prob": -6.364477634429932}, {"content": "meters", "parent": 1, "prob": -7.86447811126709}, {"content": "2", "parent": 2, "prob": -6.744779109954834}, {"content": "8", "parent": 3, "prob": -7.690467834472656}, {"content": "0", "parent": 4, "prob": -8.659302711486816}, {"content": "4", "parent": 5, "prob": -9.784324645996094}, {"content": "2", "parent": 6, "prob": -10.433869361877441}, {"content": "8", "parent": 7, "prob": -12.270962715148926}, {"content": "2", "parent": 8, "prob": -7.82954216003418}, {"content": "8", "parent": 9, "prob": -8.417224884033203}, {"content": "0", "parent": 10, "prob": -10.343441009521484}, {"content": "4", "parent": 11, "prob": -11.218460083007812}, {"content": "2", "parent": 12, "prob": -11.954492568969727}, {"content": "8", "parent": 13, "prob": -13.549924850463867}, {"content": "an", "parent": 14, "prob

array([[-1.5390625e+00, -5.8203125e-01,  8.5937500e-01, ...,
        -1.8281250e+00,  4.6289062e-01,  2.1718750e+00],
       [-4.3359375e-01, -1.1328125e+00,  1.6171875e+00, ...,
        -1.9836426e-03,  1.6328125e+00, -1.2451172e-02],
       [-7.8125000e-01, -1.9375000e+00,  1.0546875e+00, ...,
        -5.0781250e-01,  1.0253906e-01,  2.7734375e-01],
       ...,
       [ 1.3750000e+00,  5.0390625e-01, -1.8281250e+00, ...,
        -1.0312500e+00, -6.5625000e-01,  2.2656250e-01],
       [-4.0000000e+00,  2.6171875e-01,  7.4609375e-01, ...,
         1.0078125e+00, -2.3125000e+00, -9.3750000e-01],
       [-5.3125000e-01, -1.1953125e+00,  9.1796875e-01, ...,
         1.0078125e+00, -2.6250000e+00,  4.0625000e+00]], dtype=float32)


k_mean_space
(20, 2)


array([[62.857456, 89.67278 ],
       [85.91454 , 72.372826],
       [72.050545, 90.634674],
       [86.73931 , 42.0552  ],
       [51.74243 , 86.51598 ],
       [80.67085 , 61.05726 ],
       [83.31906 , 57.4002  ],
       [86.36425 , 41.867336],
       [52.47704 , 87.426575],
       [86.5737  , 41.88604 ],
       [51.703297, 86.5331  ],
       [80.49314 , 60.834038],
       [83.00199 , 57.12631 ],
       [86.353775, 41.721085],
       [52.58706 , 87.36773 ],
       [75.13521 , 98.47288 ],
       [73.4977  , 97.84301 ],
       [75.384415, 98.61621 ],
       [78.037315, 99.376526],
       [74.487   , 96.52552 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-95.78297663, -83.33268929])


closest
(2,)


array([10, 13])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.2188, -3.1250, -1.3672,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.8320, -0.4297, -1.9688,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.5938,  0.7734, -4.3750,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.9844,  3.6406, -0.9570,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.5625, -2.6719, -4.7812,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3516, -0.4766, -6.6250,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 1.5535e-06, 9.4224e-07,  ..., 1.6687e-24, 8.9318e-25,
         6.1388e-25],
        [1.0000e+00, 5.7150e-07, 1.8190e-09,  ..., 1.8538e-26, 7.7276e-27,
         1.7243e-27],
        [9.9825e-01, 1.3245e-03, 4.2999e-04,  ..., 1.1449e-24, 1.0103e-24,
         1.0103e-24],
        ...,
        [9.8770e-01, 4.5740e-03, 3.5622e-03,  ..., 3.6303e-20, 2.8273e-20,
         1.5134e-20],
        [9.3200e-01, 6.7514e-02, 1.4769e-04,  ..., 8.6613e-21, 5.9528e-21,
         4.0913e-21],
        [8.7432e-01, 1.1833e-01, 4.0489e-03,  ..., 3.3206e-23, 1.5685e-23,
         9.5137e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9982, 0.9996, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9877, 0.9923, 0.9958,  ..., 1.0000, 1.0000, 1.0000],
        [0.9320, 0.9995, 0.9997,  ..., 1.0000, 1.0000, 1.0000],
        [0.8743, 0.9926, 0.9967,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([25])


tensor([ 0,  1,  2,  3,  4,  4,  5,  6,  7,  8,  9, 10, 10, 11, 12, 13, 14, 15,
        15, 16, 16, 17, 18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([25, 48])


tensor([[    1, 32010,  1724,  ..., 29900, 29906, 29929],
        [    1, 32010,  1724,  ..., 29946, 29947, 29889],
        [    1, 32010,  1724,  ..., 29946, 29947, 27881],
        ...,
        [    1, 32010,  1724,  ...,   322,   338,  7258],
        [    1, 32010,  1724,  ...,   322,   756,   385],
        [    1, 32010,  1724,  ...,   322,   756,   385]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([25])


tensor([ -5.3667,  -6.3645,  -7.8645,  -6.7448,  -7.6905,  -7.6905,  -8.6593,
         -9.7843, -10.4339, -12.2710,  -7.8295,  -8.4172,  -8.4172, -10.3434,
        -11.2185, -11.9545, -13.5499,  -6.1972,  -6.1972,  -7.0679,  -7.0679,
         -8.1861, -11.0611,  -8.1110,  -8.1110], device='cuda:0')


new_candidate_toks
torch.Size([25, 1])


tensor([[ 6900],
        [29947],
        [  313],
        [29929],
        [29889],
        [27881],
        [29906],
        [29947],
        [29929],
        [29889],
        [29929],
        [29889],
        [27881],
        [29906],
        [29947],
        [29929],
        [29889],
        [11858],
        [21210],
        [  373],
        [19434],
        [  373],
        [  491],
        [ 6221],
        [11858]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([25])


tensor([-4.0531e-06, -5.9605e-07, -1.7563e-03, -1.1921e-07, -2.0144e-01,
        -1.7014e+00, -5.4882e-02, -7.0990e-04,  0.0000e+00, -1.3259e-03,
        -1.1921e-07, -1.6026e-01, -1.9103e+00, -4.2999e-02, -7.0990e-04,
         0.0000e+00, -1.5023e-03, -1.6369e-01, -2.7887e+00, -1.5946e-01,
        -2.7845e+00, -1.2375e-02, -7.0425e-02, -1.3431e-01, -2.1343e+00],
       device='cuda:0')


new_candidates
torch.Size([25, 49])


tensor([[    1, 32010,  1724,  ..., 29906, 29929,  6900],
        [    1, 32010,  1724,  ..., 29947, 29889, 29947],
        [    1, 32010,  1724,  ..., 29947, 27881,   313],
        ...,
        [    1, 32010,  1724,  ...,   338,  7258,   491],
        [    1, 32010,  1724,  ...,   756,   385,  6221],
        [    1, 32010,  1724,  ...,   756,   385, 11858]], device='cuda:0')


new_candidate_logprobs
torch.Size([25])


tensor([ -5.3667,  -6.3645,  -7.8662,  -6.7448,  -7.8919,  -9.3919,  -8.7142,
         -9.7850, -10.4339, -12.2723,  -7.8295,  -8.5775, -10.3275, -10.3864,
        -11.2192, -11.9545, -13.5514,  -6.3609,  -8.9859,  -7.2274,  -9.8524,
         -8.1985, -11.1315,  -8.2453, -10.2453], device='cuda:0')

infer end: GPU memory used: 18249 MB.
event: level
id: 37
data: [{"content": "feet", "parent": 0, "prob": -5.366672515869141}, {"content": "8", "parent": 1, "prob": -6.36447811126709}, {"content": "(", "parent": 2, "prob": -7.866234302520752}, {"content": "9", "parent": 3, "prob": -6.744779109954834}, {"content": ".", "parent": 4, "prob": -7.891907215118408}, {"content": "meters", "parent": 4, "prob": -9.39190673828125}, {"content": "2", "parent": 5, "prob": -8.714184761047363}, {"content": "8", "parent": 6, "prob": -9.7850341796875}, {"content": "9", "parent": 7, "prob": -10.433869361877441}, {"content": ".", "parent": 8, "prob": -12.27228832244873}, {"content": "9", "parent": 9, "prob": -7.82954216003418}, {"content": ".", "parent": 10, "prob": -8.57748794555664}, {"content": "meters", "parent": 10, "prob": -10.32748794555664}, {"content": "2", "parent": 11, "prob": -10.38644027709961}, {"content": "8", "parent": 12, "prob": -11.219169616699219}, {"content": "9", "parent": 13, "prob"

array([[ 0.828125  , -1.9609375 ,  2.015625  , ...,  2.046875  ,
         0.62890625,  0.66015625],
       [-0.29882812, -1.0625    ,  1.0390625 , ..., -3.0625    ,
         0.19140625, -0.18457031],
       [-0.10107422,  2.265625  ,  0.71875   , ...,  1.046875  ,
         1.375     ,  1.3125    ],
       ...,
       [ 1.9609375 , -1.4140625 ,  0.83203125, ...,  3.109375  ,
        -1.5390625 ,  0.69140625],
       [ 1.421875  , -0.6875    , -2.53125   , ...,  2.953125  ,
         0.8046875 , -0.36914062],
       [-1.90625   ,  0.45703125,  1.578125  , ...,  2.        ,
         1.3125    ,  0.88671875]], dtype=float32)


k_mean_space
(20, 2)


array([[ 66.150955,  94.40137 ],
       [ 68.04461 ,  81.38147 ],
       [ 86.1845  ,  81.53655 ],
       [ 48.27027 ,  91.44191 ],
       [ 85.38873 ,  38.518528],
       [ 64.54642 ,  90.53494 ],
       [ 86.231285,  66.80202 ],
       [ 52.83819 ,  86.73454 ],
       [ 48.61215 ,  92.0104  ],
       [ 85.76631 ,  38.953835],
       [ 48.135406,  91.34095 ],
       [ 85.23058 ,  38.20132 ],
       [ 64.51202 ,  90.56622 ],
       [ 86.03888 ,  66.692444],
       [ 53.017574,  86.74403 ],
       [ 48.34037 ,  91.82411 ],
       [ 85.568504,  38.394054],
       [ 85.81543 ,  99.483536],
       [ 88.23655 , 101.50904 ],
       [ 85.965805,  98.82294 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-111.99151945,  -69.25996971])


closest
(2,)


array([10, 11])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.9688,  0.3691, -5.3125,  ...,  0.0000,  0.0000,  0.0000],
        [-1.9766, -2.4062, -2.1406,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.6719,  0.5312, -4.8750,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.0496, -4.4062,  3.7031,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.3438, -3.5938,  1.6875,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.5859, -3.2812,  4.6562,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9923e-01, 4.3041e-04, 3.3521e-04,  ..., 1.6674e-24, 1.1460e-24,
         6.9508e-25],
        [1.0000e+00, 1.7603e-06, 1.4450e-07,  ..., 2.8428e-27, 1.5217e-27,
         1.1851e-27],
        [9.9743e-01, 2.4724e-03, 7.4660e-05,  ..., 4.8641e-23, 4.2926e-23,
         2.2976e-23],
        ...,
        [9.9999e-01, 7.8893e-06, 3.0590e-07,  ..., 3.1175e-24, 2.4279e-24,
         3.2858e-25],
        [1.0000e+00, 1.9556e-08, 4.5991e-10,  ..., 1.4437e-26, 1.2741e-26,
         1.9538e-27],
        [9.9991e-01, 5.8289e-05, 2.1443e-05,  ..., 5.2424e-22, 2.4763e-22,
         1.0323e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9992, 0.9997, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9974, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  7,  8,  9, 10, 11, 12, 13, 14, 14, 15,
        16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 49])


tensor([[    1, 32010,  1724,  ..., 29906, 29929,  6900],
        [    1, 32010,  1724,  ..., 29947, 29889, 29947],
        [    1, 32010,  1724,  ..., 29947, 27881,   313],
        ...,
        [    1, 32010,  1724,  ..., 22170,   385, 11858],
        [    1, 32010,  1724,  ..., 22170,   385, 21210],
        [    1, 32010,  1724,  ...,   269,  1169,   373]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([22])


tensor([ -5.3667,  -6.3645,  -7.8662,  -6.7448,  -7.8919,  -9.3919,  -8.7142,
         -9.7850,  -9.7850, -10.4339, -12.2723,  -7.8295,  -8.5775, -10.3275,
        -10.3864, -11.2192, -11.2192, -11.9545, -13.5514,  -6.3609,  -8.9859,
         -7.2274], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[  313],
        [29953],
        [29906],
        [ 6900],
        [29947],
        [  313],
        [29929],
        [29889],
        [27881],
        [ 6900],
        [29947],
        [ 6900],
        [29947],
        [  313],
        [29929],
        [29889],
        [27881],
        [ 6900],
        [29947],
        [  362],
        [  573],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([-7.6597e-04, -1.9074e-06, -2.5708e-03, -3.9339e-06, -5.9605e-07,
        -1.5319e-02, -1.1921e-07, -3.1329e-01, -1.3133e+00, -9.2153e-05,
        -1.1921e-07, -5.6029e-06, -4.7684e-07, -1.9627e-02,  0.0000e+00,
        -1.6025e-01, -1.9103e+00, -1.9189e-04, -1.1921e-07, -8.3447e-06,
         0.0000e+00, -9.4657e-05], device='cuda:0')


new_candidates
torch.Size([22, 50])


tensor([[    1, 32010,  1724,  ..., 29929,  6900,   313],
        [    1, 32010,  1724,  ..., 29889, 29947, 29953],
        [    1, 32010,  1724,  ..., 27881,   313, 29906],
        ...,
        [    1, 32010,  1724,  ...,   385, 11858,   362],
        [    1, 32010,  1724,  ...,   385, 21210,   573],
        [    1, 32010,  1724,  ...,  1169,   373,   278]], device='cuda:0')


new_candidate_logprobs
torch.Size([22])


tensor([ -5.3674,  -6.3645,  -7.8688,  -6.7448,  -7.8919,  -9.4072,  -8.7142,
        -10.0983, -11.0983, -10.4340, -12.2723,  -7.8295,  -8.5775, -10.3471,
        -10.3864, -11.3794, -13.1294, -11.9547, -13.5514,  -6.3609,  -8.9859,
         -7.2275], device='cuda:0')

infer end: GPU memory used: 18369 MB.
event: level
id: 38
data: [{"content": "(", "parent": 0, "prob": -5.367438316345215}, {"content": "6", "parent": 1, "prob": -6.364480018615723}, {"content": "2", "parent": 2, "prob": -7.868804931640625}, {"content": "feet", "parent": 3, "prob": -6.7447829246521}, {"content": "8", "parent": 4, "prob": -7.891907691955566}, {"content": "(", "parent": 5, "prob": -9.407225608825684}, {"content": "9", "parent": 6, "prob": -8.714184761047363}, {"content": ".", "parent": 7, "prob": -10.098323822021484}, {"content": "meters", "parent": 7, "prob": -11.098323822021484}, {"content": "feet", "parent": 8, "prob": -10.433961868286133}, {"content": "8", "parent": 9, "prob": -12.27228832244873}, {"content": "feet", "parent": 10, "prob": -7.829547882080078}, {"content": "8", "parent": 11, "prob": -8.577488899230957}, {"content": "(", "parent": 12, "prob": -10.347114562988281}, {"content": "9", "parent": 13, "prob": -10.38644027709961}, {"content": ".", "parent": 14,

array([[ 0.25390625,  0.06445312, -0.07958984, ...,  0.42382812,
        -0.1484375 ,  0.6328125 ],
       [ 0.484375  ,  0.11376953, -1.5078125 , ..., -1.546875  ,
         0.625     ,  2.515625  ],
       [-0.36132812, -3.        ,  0.24023438, ..., -1.0234375 ,
         0.05053711,  1.3359375 ],
       ...,
       [ 1.15625   , -1.7265625 ,  1.2890625 , ...,  2.109375  ,
         0.24609375,  1.484375  ],
       [-0.38476562, -0.6796875 ,  1.484375  , ..., -2.75      ,
         0.43554688,  0.08056641],
       [-0.39648438,  0.703125  ,  1.546875  , ...,  1.2109375 ,
        -0.34960938,  3.09375   ]], dtype=float32)


k_mean_space
(20, 2)


array([[88.95135 , 73.04549 ],
       [58.491215, 85.555504],
       [83.54263 , 91.31858 ],
       [86.802574, 39.640842],
       [41.295544, 90.04743 ],
       [88.86996 , 69.69746 ],
       [62.52221 , 84.964806],
       [67.74807 , 91.06437 ],
       [83.968155, 43.926132],
       [86.28345 , 39.437527],
       [41.62991 , 90.03608 ],
       [86.76772 , 39.71321 ],
       [41.209255, 89.93619 ],
       [88.89813 , 69.68932 ],
       [62.492107, 85.09029 ],
       [67.4075  , 90.80111 ],
       [84.056694, 43.816357],
       [86.29057 , 39.685375],
       [41.680553, 90.02525 ],
       [95.05013 , 82.45187 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-97.10476589, -92.67336369])


closest
(2,)


array([12,  9])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.3125,  2.2188,  1.0781,  ...,  0.0000,  0.0000,  0.0000],
        [-2.4062, -4.4375,  0.2422,  ...,  0.0000,  0.0000,  0.0000],
        [-1.4297, -4.4688, -6.0938,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.4453,  0.3340, -5.3125,  ...,  0.0000,  0.0000,  0.0000],
        [-0.8906, -1.3047, -1.8047,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.3438, -1.6953, -1.5312,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9532e-01, 4.6092e-03, 5.8022e-05,  ..., 1.7856e-23, 1.0830e-23,
         1.1415e-24],
        [1.0000e+00, 5.0435e-07, 2.3824e-07,  ..., 2.3803e-26, 9.9224e-27,
         7.7276e-27],
        [1.0000e+00, 8.1520e-09, 1.2502e-09,  ..., 5.4175e-25, 1.0668e-25,
         5.0390e-26],
        ...,
        [9.8929e-01, 6.6658e-03, 4.0430e-03,  ..., 1.2198e-23, 8.3836e-24,
         2.7217e-24],
        [1.0000e+00, 9.4224e-07, 3.9279e-07,  ..., 9.2293e-28, 4.9401e-28,
         4.9401e-28],
        [1.0000e+00, 1.9947e-06, 9.4224e-07,  ..., 4.8766e-23, 1.5832e-23,
         1.2330e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9953, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9893, 0.9960, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 50])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362,   310,
         14235, 29871, 29906, 29929, 29892, 29900, 29906, 29929,  6900,   313],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,  7205,  3233, 29892,   338,
          8040, 18274,   342, 29889,   739, 22170,   385, 11858,   362,   310,
         14235, 29871, 29947, 29892, 29947, 29946, 29947, 29889, 29947, 29953],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408, 17005,   491,
           967, 19224, 29915, 29879,  3171,  2038,


carryover_candidate_logprobs
torch.Size([20])


tensor([ -5.3674,  -6.3645,  -7.8688,  -6.7448,  -7.8919,  -9.4072,  -8.7142,
        -10.0983, -11.0983, -10.4340, -12.2723,  -7.8295,  -8.5775, -10.3471,
        -10.3864, -11.3794, -13.1294, -11.9547, -13.5514,  -6.3609],
       device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[29947],
        [27881],
        [29929],
        [  313],
        [29953],
        [29906],
        [ 6900],
        [29947],
        [  313],
        [  313],
        [29953],
        [  313],
        [29953],
        [29906],
        [ 6900],
        [29947],
        [  313],
        [  313],
        [29953],
        [  310]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([-4.6947e-03, -1.0729e-06,  0.0000e+00, -3.1209e-02, -2.9802e-06,
        -4.2074e-03, -3.2187e-06, -3.5763e-07, -5.7749e-03, -9.8313e-03,
        -1.5497e-06, -2.4714e-02, -2.3842e-06, -6.0462e-03, -5.0068e-06,
        -3.5763e-07, -5.5232e-03, -1.0768e-02, -1.4305e-06, -4.5300e-06],
       device='cuda:0')


new_candidates
torch.Size([20, 51])


tensor([[    1, 32010,  1724,  ...,  6900,   313, 29947],
        [    1, 32010,  1724,  ..., 29947, 29953, 27881],
        [    1, 32010,  1724,  ...,   313, 29906, 29929],
        ...,
        [    1, 32010,  1724,  ..., 29929,  6900,   313],
        [    1, 32010,  1724,  ..., 29889, 29947, 29953],
        [    1, 32010,  1724,  ..., 11858,   362,   310]], device='cuda:0')


new_candidate_logprobs
torch.Size([20])


tensor([ -5.3721,  -6.3645,  -7.8688,  -6.7760,  -7.8919,  -9.4114,  -8.7142,
        -10.0983, -11.1041, -10.4438, -12.2723,  -7.8543,  -8.5775, -10.3532,
        -10.3864, -11.3794, -13.1349, -11.9655, -13.5514,  -6.3609],
       device='cuda:0')

infer end: GPU memory used: 18493 MB.
event: level
id: 39
data: [{"content": "8", "parent": 0, "prob": -5.372133255004883}, {"content": "meters", "parent": 1, "prob": -6.364480972290039}, {"content": "9", "parent": 2, "prob": -7.868804931640625}, {"content": "(", "parent": 3, "prob": -6.775992393493652}, {"content": "6", "parent": 4, "prob": -7.891910552978516}, {"content": "2", "parent": 5, "prob": -9.411433219909668}, {"content": "feet", "parent": 6, "prob": -8.714187622070312}, {"content": "8", "parent": 7, "prob": -10.098323822021484}, {"content": "(", "parent": 8, "prob": -11.104098320007324}, {"content": "(", "parent": 9, "prob": -10.443793296813965}, {"content": "6", "parent": 10, "prob": -12.272290229797363}, {"content": "(", "parent": 11, "prob": -7.854261875152588}, {"content": "6", "parent": 12, "prob": -8.577491760253906}, {"content": "2", "parent": 13, "prob": -10.353160858154297}, {"content": "feet", "parent": 14, "prob": -10.386445045471191}, {"content": "8", "parent": 1

array([[ 0.60546875, -1.4765625 ,  0.16992188, ...,  0.94140625,
         2.046875  , -1.4921875 ],
       [-0.10351562, -1.0234375 ,  1.4296875 , ..., -0.30859375,
        -0.22167969, -0.59765625],
       [-0.17089844, -1.890625  , -0.16503906, ..., -0.421875  ,
         1.5859375 ,  2.3125    ],
       ...,
       [ 0.31640625,  0.25195312, -0.21777344, ...,  0.37695312,
         0.09667969,  0.50390625],
       [ 0.26171875,  0.20410156, -1.2109375 , ..., -1.265625  ,
         1.2578125 ,  2.171875  ],
       [-0.484375  ,  1.9140625 ,  2.28125   , ..., -1.1171875 ,
         1.828125  ,  1.453125  ]], dtype=float32)


k_mean_space
(20, 2)


array([[90.10572 , 79.53927 ],
       [62.57332 , 86.20877 ],
       [82.908676, 87.48329 ],
       [90.493774, 43.012665],
       [41.867027, 92.02343 ],
       [95.302414, 69.179184],
       [65.17532 , 85.92303 ],
       [56.815075, 90.992516],
       [87.95672 , 55.64458 ],
       [90.63702 , 43.83751 ],
       [43.709282, 92.482765],
       [90.38806 , 42.93334 ],
       [41.836857, 92.05574 ],
       [95.22843 , 69.11169 ],
       [65.21199 , 86.00862 ],
       [57.285934, 90.7896  ],
       [87.985886, 55.769226],
       [90.564354, 43.603973],
       [43.68908 , 92.52885 ],
       [91.25613 , 76.920204]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-97.10478401, -92.77613735])


closest
(2,)


array([12, 11])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.9648, -2.5156, -6.7812,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3086, -1.5781, -3.7969,  ...,  0.0000,  0.0000,  0.0000],
        [-1.4844, -2.5469, -6.1250,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 3.9531,  1.4531,  0.2891,  ...,  0.0000,  0.0000,  0.0000],
        [-1.4609, -4.7188,  0.7695,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3965, -3.6719, -4.8125,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 9.7362e-10, 2.1724e-10,  ..., 8.3079e-26, 6.4702e-26,
         2.1006e-26],
        [9.9991e-01, 4.5396e-05, 4.0062e-05,  ..., 1.4725e-24, 1.2995e-24,
         1.7586e-25],
        [1.0000e+00, 9.2374e-09, 4.0587e-10,  ..., 7.8824e-25, 1.5521e-25,
         1.2088e-25],
        ...,
        [9.5845e-01, 2.5542e-02, 1.5492e-02,  ..., 7.7062e-23, 5.2964e-23,
         1.3391e-23],
        [1.0000e+00, 1.9947e-06, 4.4508e-07,  ..., 5.7099e-26, 1.6359e-26,
         1.1244e-26],
        [9.4466e-01, 5.3294e-02, 1.8236e-03,  ..., 1.8585e-20, 1.6401e-20,
         4.1469e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9585, 0.9840, 0.9995,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9447, 0.9980, 0.9998,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 51])


tensor([[    1, 32010,  1724,  ...,  6900,   313, 29947],
        [    1, 32010,  1724,  ..., 29947, 29953, 27881],
        [    1, 32010,  1724,  ...,   313, 29906, 29929],
        ...,
        [    1, 32010,  1724,  ..., 29929,  6900,   313],
        [    1, 32010,  1724,  ..., 29889, 29947, 29953],
        [    1, 32010,  1724,  ..., 11858,   362,   310]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([20])


tensor([ -5.3721,  -6.3645,  -7.8688,  -6.7760,  -7.8919,  -9.4114,  -8.7142,
        -10.0983, -11.1041, -10.4438, -12.2723,  -7.8543,  -8.5775, -10.3532,
        -10.3864, -11.3794, -13.1349, -11.9655, -13.5514,  -6.3609],
       device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[29892],
        [  313],
        [29892],
        [29947],
        [27881],
        [29929],
        [  313],
        [29953],
        [29906],
        [29947],
        [27881],
        [29947],
        [27881],
        [29929],
        [  313],
        [29953],
        [29906],
        [29947],
        [27881],
        [14235]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([ 0.0000e+00, -8.5715e-05,  0.0000e+00, -9.8538e-03, -9.5367e-07,
         0.0000e+00, -6.5533e-03, -1.1921e-06, -2.9654e-03, -3.0850e-02,
        -3.8147e-06, -1.2616e-02, -9.5367e-07,  0.0000e+00, -8.2081e-03,
        -8.3447e-07, -3.3431e-03, -4.2433e-02, -3.2187e-06, -5.6931e-02],
       device='cuda:0')


new_candidates
torch.Size([20, 52])


tensor([[    1, 32010,  1724,  ...,   313, 29947, 29892],
        [    1, 32010,  1724,  ..., 29953, 27881,   313],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892],
        ...,
        [    1, 32010,  1724,  ...,  6900,   313, 29947],
        [    1, 32010,  1724,  ..., 29947, 29953, 27881],
        [    1, 32010,  1724,  ...,   362,   310, 14235]], device='cuda:0')


new_candidate_logprobs
torch.Size([20])


tensor([ -5.3721,  -6.3646,  -7.8688,  -6.7858,  -7.8919,  -9.4114,  -8.7207,
        -10.0983, -11.1071, -10.4746, -12.2723,  -7.8669,  -8.5775, -10.3532,
        -10.3947, -11.3794, -13.1383, -12.0079, -13.5514,  -6.4178],
       device='cuda:0')

infer end: GPU memory used: 18619 MB.
event: level
id: 40
data: [{"content": ",", "parent": 0, "prob": -5.372133255004883}, {"content": "(", "parent": 1, "prob": -6.364566802978516}, {"content": ",", "parent": 2, "prob": -7.868804931640625}, {"content": "8", "parent": 3, "prob": -6.78584623336792}, {"content": "meters", "parent": 4, "prob": -7.891911506652832}, {"content": "9", "parent": 5, "prob": -9.411433219909668}, {"content": "(", "parent": 6, "prob": -8.720741271972656}, {"content": "6", "parent": 7, "prob": -10.0983247756958}, {"content": "2", "parent": 8, "prob": -11.107063293457031}, {"content": "8", "parent": 9, "prob": -10.47464370727539}, {"content": "meters", "parent": 10, "prob": -12.272294044494629}, {"content": "8", "parent": 11, "prob": -7.866878032684326}, {"content": "meters", "parent": 12, "prob": -8.577492713928223}, {"content": "9", "parent": 13, "prob": -10.353160858154297}, {"content": "(", "parent": 14, "prob": -10.3946533203125}, {"content": "6", "parent": 15,

array([[ 1.046875  , -0.12158203,  0.51953125, ...,  1.1171875 ,
         1.7421875 , -1.5       ],
       [ 0.3984375 ,  3.21875   , -0.05126953, ...,  1.2109375 ,
         0.29492188,  1.4921875 ],
       [-0.62109375, -0.01348877,  2.25      , ..., -0.70703125,
         1.296875  ,  2.21875   ],
       ...,
       [ 0.75      , -1.421875  ,  0.22363281, ...,  1.        ,
         1.9921875 , -1.328125  ],
       [ 0.1875    , -0.73046875,  1.234375  , ..., -0.19726562,
        -0.15234375, -0.24316406],
       [-0.5390625 ,  0.38476562,  0.35742188, ..., -0.00793457,
         1.1875    ,  1.59375   ]], dtype=float32)


k_mean_space
(20, 2)


array([[91.71794 , 67.53204 ],
       [70.46751 , 89.19425 ],
       [93.55508 , 74.71912 ],
       [88.59834 , 40.71595 ],
       [47.547073, 92.14337 ],
       [88.28661 , 57.889122],
       [66.59206 , 87.291016],
       [66.664764, 95.01041 ],
       [90.2484  , 70.38594 ],
       [89.24724 , 41.189205],
       [47.534   , 92.17664 ],
       [88.56321 , 40.73658 ],
       [47.685818, 92.25211 ],
       [88.35365 , 57.933567],
       [66.58131 , 87.180176],
       [66.603165, 95.05487 ],
       [90.31193 , 70.53094 ],
       [89.1407  , 41.12273 ],
       [47.750507, 92.40645 ],
       [72.64149 , 94.74296 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-95.66863728, -94.38613415])


closest
(2,)


array([10,  3])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.3730,  0.7500, -3.7656,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1631,  0.5586, -4.4375,  ...,  0.0000,  0.0000,  0.0000],
        [-4.9062, -0.6172, -4.0312,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.7461, -2.6875, -7.0312,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0942, -2.7188, -4.7500,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.7734, -2.9219, -8.3125,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 2.7895e-10, 5.4928e-11,  ..., 1.4437e-26, 3.6503e-27,
         1.6904e-29],
        [9.9959e-01, 3.7997e-04, 1.0126e-05,  ..., 1.5015e-22, 1.3251e-22,
         5.5237e-23],
        [1.0000e+00, 7.5826e-10, 4.0587e-10,  ..., 8.3079e-26, 5.7100e-26,
         5.7100e-26],
        ...,
        [1.0000e+00, 2.3356e-09, 4.0587e-10,  ..., 1.9930e-25, 1.9930e-25,
         9.4141e-26],
        [9.9984e-01, 1.3982e-04, 1.8922e-05,  ..., 1.2328e-23, 2.7508e-24,
         1.0120e-24],
        [1.0000e+00, 6.3488e-09, 6.3488e-09,  ..., 1.7588e-25, 1.5521e-25,
         2.6972e-26]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9996, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 52])


tensor([[    1, 32010,  1724,  ...,   313, 29947, 29892],
        [    1, 32010,  1724,  ..., 29953, 27881,   313],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892],
        ...,
        [    1, 32010,  1724,  ...,  6900,   313, 29947],
        [    1, 32010,  1724,  ..., 29947, 29953, 27881],
        [    1, 32010,  1724,  ...,   362,   310, 14235]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([20])


tensor([ -5.3721,  -6.3646,  -7.8688,  -6.7858,  -7.8919,  -9.4114,  -8.7207,
        -10.0983, -11.1071, -10.4746, -12.2723,  -7.8669,  -8.5775, -10.3532,
        -10.3947, -11.3794, -13.1383, -12.0079, -13.5514,  -6.4178],
       device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[29947],
        [29906],
        [29900],
        [29892],
        [  313],
        [29892],
        [29947],
        [27881],
        [29929],
        [29892],
        [  313],
        [29892],
        [  313],
        [29892],
        [29947],
        [27881],
        [29929],
        [29892],
        [  313],
        [29871]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([ 0.0000e+00, -4.1154e-04,  0.0000e+00,  0.0000e+00, -7.1879e-04,
         0.0000e+00, -6.8047e-03, -8.3447e-07,  0.0000e+00,  0.0000e+00,
        -1.5928e-04,  0.0000e+00, -6.6141e-04,  0.0000e+00, -7.7465e-03,
        -8.3447e-07,  0.0000e+00,  0.0000e+00, -1.5928e-04,  0.0000e+00],
       device='cuda:0')


new_candidates
torch.Size([20, 53])


tensor([[    1, 32010,  1724,  ..., 29947, 29892, 29947],
        [    1, 32010,  1724,  ..., 27881,   313, 29906],
        [    1, 32010,  1724,  ..., 29929, 29892, 29900],
        ...,
        [    1, 32010,  1724,  ...,   313, 29947, 29892],
        [    1, 32010,  1724,  ..., 29953, 27881,   313],
        [    1, 32010,  1724,  ...,   310, 14235, 29871]], device='cuda:0')


new_candidate_logprobs
torch.Size([20])


tensor([ -5.3721,  -6.3650,  -7.8688,  -6.7858,  -7.8926,  -9.4114,  -8.7275,
        -10.0983, -11.1071, -10.4746, -12.2725,  -7.8669,  -8.5782, -10.3532,
        -10.4024, -11.3794, -13.1383, -12.0079, -13.5516,  -6.4178],
       device='cuda:0')

infer end: GPU memory used: 18747 MB.
event: level
id: 41
data: [{"content": "8", "parent": 0, "prob": -5.372133255004883}, {"content": "2", "parent": 1, "prob": -6.364978313446045}, {"content": "0", "parent": 2, "prob": -7.868804931640625}, {"content": ",", "parent": 3, "prob": -6.78584623336792}, {"content": "(", "parent": 4, "prob": -7.892630100250244}, {"content": ",", "parent": 5, "prob": -9.411433219909668}, {"content": "8", "parent": 6, "prob": -8.727545738220215}, {"content": "meters", "parent": 7, "prob": -10.098325729370117}, {"content": "9", "parent": 8, "prob": -11.107063293457031}, {"content": ",", "parent": 9, "prob": -10.47464370727539}, {"content": "(", "parent": 10, "prob": -12.272453308105469}, {"content": ",", "parent": 11, "prob": -7.866878032684326}, {"content": "(", "parent": 12, "prob": -8.578154563903809}, {"content": ",", "parent": 13, "prob": -10.353160858154297}, {"content": "8", "parent": 14, "prob": -10.402400016784668}, {"content": "meters", "parent": 15, 

array([[ 1.546875  , -1.1171875 ,  2.078125  , ...,  0.96484375,
        -1.1796875 ,  0.70703125],
       [-0.2109375 , -2.875     ,  0.43359375, ..., -0.734375  ,
        -0.21582031,  1.6015625 ],
       [ 1.9609375 ,  1.7734375 , -0.296875  , ..., -0.25195312,
         1.8125    , -1.09375   ],
       ...,
       [ 1.3203125 , -0.01098633,  0.78515625, ...,  1.2890625 ,
         1.9375    , -1.15625   ],
       [ 0.7265625 ,  3.46875   , -0.53125   , ...,  0.8828125 ,
         0.22167969,  1.7109375 ],
       [-0.06201172,  0.96875   ,  2.671875  , ...,  2.40625   ,
         2.953125  ,  1.2265625 ]], dtype=float32)


k_mean_space
(20, 2)


array([[86.82638 , 74.966995],
       [75.684746, 85.22536 ],
       [77.70503 , 91.35616 ],
       [87.088104, 39.749313],
       [49.550865, 92.95918 ],
       [79.45781 , 74.1526  ],
       [79.81419 , 53.03799 ],
       [70.243195, 95.96803 ],
       [71.00947 , 80.336426],
       [87.22532 , 40.099552],
       [49.566666, 93.43704 ],
       [86.98138 , 39.707714],
       [49.555725, 92.90884 ],
       [79.32788 , 74.04409 ],
       [79.773056, 52.76487 ],
       [70.32718 , 96.072014],
       [71.16589 , 80.440796],
       [87.11679 , 39.787674],
       [49.607685, 93.45726 ],
       [74.8745  , 88.68003 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-108.66950893,  -81.40192604])


closest
(2,)


array([ 4, 11])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.0938,  1.1172,  2.0156,  ...,  0.0000,  0.0000,  0.0000],
        [-1.6797, -4.2188, -6.3125,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.9062,  4.0000,  5.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.2363,  0.2393, -3.5156,  ...,  0.0000,  0.0000,  0.0000],
        [-1.0547, -0.2969, -6.2812,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.4922,  0.7305, -3.0781,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9988e-01, 1.2339e-04, 2.1722e-10,  ..., 1.1242e-26, 1.1242e-26,
         1.3427e-27],
        [1.0000e+00, 2.6466e-09, 6.6916e-10,  ..., 4.2191e-25, 3.0563e-26,
         2.3803e-26],
        [1.0000e+00, 3.2242e-08, 1.3440e-08,  ..., 5.8243e-24, 5.8243e-24,
         5.1399e-24],
        ...,
        [1.0000e+00, 1.4931e-10, 7.0529e-11,  ..., 2.1006e-26, 6.8196e-27,
         2.4595e-29],
        [9.9796e-01, 1.1685e-03, 7.0872e-04,  ..., 4.0748e-22, 4.0748e-22,
         3.5960e-22],
        [5.0000e-01, 5.0000e-01, 1.4227e-08,  ..., 1.5085e-21, 1.1748e-21,
         8.0746e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9980, 0.9991, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        [0.5000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([21])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([21, 53])


tensor([[    1, 32010,  1724,  ..., 29947, 29892, 29947],
        [    1, 32010,  1724,  ..., 27881,   313, 29906],
        [    1, 32010,  1724,  ..., 29929, 29892, 29900],
        ...,
        [    1, 32010,  1724,  ..., 29953, 27881,   313],
        [    1, 32010,  1724,  ...,   310, 14235, 29871],
        [    1, 32010,  1724,  ...,   310, 14235, 29871]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([21])


tensor([ -5.3721,  -6.3650,  -7.8688,  -6.7858,  -7.8926,  -9.4114,  -8.7275,
        -10.0983, -11.1071, -10.4746, -12.2725,  -7.8669,  -8.5782, -10.3532,
        -10.4024, -11.3794, -13.1383, -12.0079, -13.5516,  -6.4178,  -6.4178],
       device='cuda:0')


new_candidate_toks
torch.Size([21, 1])


tensor([[29946],
        [29929],
        [29906],
        [29947],
        [29906],
        [29900],
        [29892],
        [  313],
        [29892],
        [29947],
        [29906],
        [29947],
        [29906],
        [29900],
        [29892],
        [  313],
        [29892],
        [29947],
        [29906],
        [29906],
        [29947]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([21])


tensor([-1.2339e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00, -7.7266e-04,
         0.0000e+00,  0.0000e+00, -2.3624e-04,  0.0000e+00,  0.0000e+00,
        -2.0216e-03,  0.0000e+00, -8.6792e-04,  0.0000e+00,  0.0000e+00,
        -2.4363e-04,  0.0000e+00,  0.0000e+00, -2.0432e-03, -6.9315e-01,
        -6.9315e-01], device='cuda:0')


new_candidates
torch.Size([21, 54])


tensor([[    1, 32010,  1724,  ..., 29892, 29947, 29946],
        [    1, 32010,  1724,  ...,   313, 29906, 29929],
        [    1, 32010,  1724,  ..., 29892, 29900, 29906],
        ...,
        [    1, 32010,  1724,  ..., 27881,   313, 29906],
        [    1, 32010,  1724,  ..., 14235, 29871, 29906],
        [    1, 32010,  1724,  ..., 14235, 29871, 29947]], device='cuda:0')


new_candidate_logprobs
torch.Size([21])


tensor([ -5.3723,  -6.3650,  -7.8688,  -6.7858,  -7.8934,  -9.4114,  -8.7275,
        -10.0986, -11.1071, -10.4746, -12.2745,  -7.8669,  -8.5790, -10.3532,
        -10.4024, -11.3797, -13.1383, -12.0079, -13.5536,  -7.1109,  -7.1109],
       device='cuda:0')

infer end: GPU memory used: 18877 MB.
event: level
id: 42
data: [{"content": "4", "parent": 0, "prob": -5.372256755828857}, {"content": "9", "parent": 1, "prob": -6.364978313446045}, {"content": "2", "parent": 2, "prob": -7.868804931640625}, {"content": "8", "parent": 3, "prob": -6.78584623336792}, {"content": "2", "parent": 4, "prob": -7.893402576446533}, {"content": "0", "parent": 5, "prob": -9.411433219909668}, {"content": ",", "parent": 6, "prob": -8.727545738220215}, {"content": "(", "parent": 7, "prob": -10.098562240600586}, {"content": ",", "parent": 8, "prob": -11.107063293457031}, {"content": "8", "parent": 9, "prob": -10.47464370727539}, {"content": "2", "parent": 10, "prob": -12.27447509765625}, {"content": "8", "parent": 11, "prob": -7.866878032684326}, {"content": "2", "parent": 12, "prob": -8.579022407531738}, {"content": "0", "parent": 13, "prob": -10.353160858154297}, {"content": ",", "parent": 14, "prob": -10.402400016784668}, {"content": "(", "parent": 15, "prob": -11

array([[ 0.1640625 , -3.234375  ,  0.6640625 , ..., -1.3828125 ,
         0.859375  , -0.18847656],
       [-0.23925781, -1.4921875 ,  0.30273438, ...,  0.00946045,
         0.99609375,  2.25      ],
       [ 0.25      , -2.59375   ,  0.6171875 , ..., -0.24316406,
        -0.84765625,  1.375     ],
       ...,
       [ 1.6171875 , -0.76953125,  2.140625  , ...,  1.109375  ,
        -0.81640625,  0.8203125 ],
       [-0.11962891, -2.640625  ,  0.47070312, ..., -0.5546875 ,
        -0.07763672,  1.25      ],
       [-0.53125   , -0.515625  ,  2.375     , ...,  0.45898438,
         1.078125  , -0.03808594]], dtype=float32)


k_mean_space
(20, 2)


array([[75.619156, 68.912735],
       [72.92066 , 92.411156],
       [70.90035 , 88.20484 ],
       [82.93246 , 36.187756],
       [50.502113, 88.90461 ],
       [70.57056 , 93.21538 ],
       [85.22314 , 63.259403],
       [70.411064, 97.705376],
       [65.614624, 92.508675],
       [83.28619 , 36.233757],
       [50.307358, 88.938225],
       [82.83455 , 36.273067],
       [50.49674 , 88.896614],
       [70.593025, 93.157265],
       [85.11813 , 63.037006],
       [70.35143 , 97.67417 ],
       [65.70874 , 92.53437 ],
       [83.06122 , 36.00881 ],
       [50.38509 , 89.01384 ],
       [67.49241 , 86.31219 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-129.13343477,  -61.63745546])


closest
(2,)


array([10, 17])


last_tok_logits
torch.Size([20, 32064])


tensor([[-1.5391,  1.3203,  1.1250,  ...,  0.0000,  0.0000,  0.0000],
        [-1.8906, -2.5938, -7.9688,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.7227, -2.3750, -1.5469,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 4.3750,  1.6172,  1.7969,  ...,  0.0000,  0.0000,  0.0000],
        [-1.6641, -4.8438, -6.4375,  ...,  0.0000,  0.0000,  0.0000],
        [-1.6172, -0.0579, -5.5625,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9683e-01, 3.1727e-03, 2.1655e-10,  ..., 5.6918e-26, 4.4328e-26,
         4.3458e-28],
        [1.0000e+00, 1.7258e-08, 1.6919e-10,  ..., 2.8998e-25, 2.5590e-25,
         2.2583e-25],
        [1.0000e+00, 1.1033e-09, 7.5826e-10,  ..., 3.6503e-27, 2.8428e-27,
         2.8428e-27],
        ...,
        [9.9996e-01, 3.5356e-05, 7.5823e-10,  ..., 9.9221e-27, 4.6869e-27,
         9.2290e-28],
        [1.0000e+00, 2.6466e-09, 1.1033e-09,  ..., 3.7234e-25, 3.4633e-26,
         2.3803e-26],
        [1.0000e+00, 1.1861e-08, 8.5922e-10,  ..., 1.2741e-26, 1.1244e-26,
         1.7243e-27]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9968, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 54])


tensor([[    1, 32010,  1724,  ..., 29892, 29947, 29946],
        [    1, 32010,  1724,  ...,   313, 29906, 29929],
        [    1, 32010,  1724,  ..., 29892, 29900, 29906],
        ...,
        [    1, 32010,  1724,  ..., 29947, 29892, 29947],
        [    1, 32010,  1724,  ..., 27881,   313, 29906],
        [    1, 32010,  1724,  ..., 14235, 29871, 29906]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([20])


tensor([ -5.3723,  -6.3650,  -7.8688,  -6.7858,  -7.8934,  -9.4114,  -8.7275,
        -10.0986, -11.1071, -10.4746, -12.2745,  -7.8669,  -8.5790, -10.3532,
        -10.4024, -11.3797, -13.1383, -12.0079, -13.5536,  -7.1109],
       device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[29947],
        [29892],
        [29929],
        [29946],
        [29929],
        [29906],
        [29947],
        [29906],
        [29900],
        [29946],
        [29929],
        [29946],
        [29929],
        [29906],
        [29947],
        [29906],
        [29900],
        [29946],
        [29929],
        [29929]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([-3.1777e-03,  0.0000e+00,  0.0000e+00, -4.0055e-05,  0.0000e+00,
        -1.1921e-07,  0.0000e+00, -8.8510e-04,  0.0000e+00, -5.1500e-05,
         0.0000e+00, -3.1233e-05,  0.0000e+00, -3.5763e-07,  0.0000e+00,
        -6.8843e-04,  0.0000e+00, -3.5406e-05,  0.0000e+00,  0.0000e+00],
       device='cuda:0')


new_candidates
torch.Size([20, 55])


tensor([[    1, 32010,  1724,  ..., 29947, 29946, 29947],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892],
        [    1, 32010,  1724,  ..., 29900, 29906, 29929],
        ...,
        [    1, 32010,  1724,  ..., 29892, 29947, 29946],
        [    1, 32010,  1724,  ...,   313, 29906, 29929],
        [    1, 32010,  1724,  ..., 29871, 29906, 29929]], device='cuda:0')


new_candidate_logprobs
torch.Size([20])


tensor([ -5.3754,  -6.3650,  -7.8688,  -6.7859,  -7.8934,  -9.4114,  -8.7275,
        -10.0994, -11.1071, -10.4747, -12.2745,  -7.8669,  -8.5790, -10.3532,
        -10.4024, -11.3804, -13.1383, -12.0079, -13.5536,  -7.1109],
       device='cuda:0')

infer end: GPU memory used: 19011 MB.
event: level
id: 43
data: [{"content": "8", "parent": 0, "prob": -5.375434398651123}, {"content": ",", "parent": 1, "prob": -6.364978313446045}, {"content": "9", "parent": 2, "prob": -7.868804931640625}, {"content": "4", "parent": 3, "prob": -6.785886287689209}, {"content": "9", "parent": 4, "prob": -7.893402576446533}, {"content": "2", "parent": 5, "prob": -9.411433219909668}, {"content": "8", "parent": 6, "prob": -8.727545738220215}, {"content": "2", "parent": 7, "prob": -10.099447250366211}, {"content": "0", "parent": 8, "prob": -11.107063293457031}, {"content": "4", "parent": 9, "prob": -10.474695205688477}, {"content": "9", "parent": 10, "prob": -12.27447509765625}, {"content": "4", "parent": 11, "prob": -7.866909503936768}, {"content": "9", "parent": 12, "prob": -8.579022407531738}, {"content": "2", "parent": 13, "prob": -10.353160858154297}, {"content": "8", "parent": 14, "prob": -10.402400016784668}, {"content": "2", "parent": 15, "prob": -

array([[-0.796875  , -1.3671875 , -0.2265625 , ..., -3.328125  ,
         1.90625   ,  2.015625  ],
       [-0.58984375, -0.15429688,  2.796875  , ..., -0.3203125 ,
         0.78515625,  2.21875   ],
       [ 0.49414062, -1.25      ,  1.2109375 , ..., -0.765625  ,
         1.7265625 ,  2.25      ],
       ...,
       [ 0.        , -3.078125  ,  0.78515625, ..., -1.3671875 ,
         1.0546875 , -0.23242188],
       [-0.22851562, -1.171875  ,  0.33789062, ..., -0.11621094,
         0.96875   ,  2.203125  ],
       [-1.6953125 , -3.21875   ,  0.07519531, ...,  1.53125   ,
         1.140625  ,  1.984375  ]], dtype=float32)


k_mean_space
(20, 2)


array([[91.181435, 76.43267 ],
       [63.153713, 79.65682 ],
       [81.18521 , 85.77864 ],
       [92.26162 , 46.650463],
       [25.08246 , 84.38062 ],
       [96.60411 , 59.931442],
       [93.36351 , 67.680115],
       [91.68924 , 66.160736],
       [94.97557 , 72.03208 ],
       [92.4422  , 46.91199 ],
       [25.606361, 84.725395],
       [92.15982 , 46.44732 ],
       [25.153805, 84.25779 ],
       [96.51769 , 59.919262],
       [93.28956 , 67.8408  ],
       [91.57241 , 66.34794 ],
       [94.84195 , 71.92896 ],
       [91.96089 , 46.245193],
       [25.430086, 84.49545 ],
       [49.961395, 86.12221 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -63.64526415, -127.1305356 ])


closest
(2,)


array([ 4, 17])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.4277, -4.0938, -1.5859,  ...,  0.0000,  0.0000,  0.0000],
        [-4.7812,  0.6641, -4.2812,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.9609, -4.2500, -5.1875,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.2734,  1.5547,  1.4688,  ...,  0.0000,  0.0000,  0.0000],
        [-1.5156, -2.5938, -7.8750,  ...,  0.0000,  0.0000,  0.0000],
        [-2.5000, -1.7344, -7.3750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[6.2246e-01, 3.7754e-01, 3.3751e-06,  ..., 2.4428e-26, 1.6789e-26,
         1.3075e-26],
        [1.0000e+00, 8.5922e-10, 7.5826e-10,  ..., 1.0121e-24, 6.1388e-25,
         4.2191e-25],
        [9.9995e-01, 4.0063e-05, 2.5612e-06,  ..., 1.8908e-24, 1.4726e-24,
         1.2995e-24],
        ...,
        [9.9753e-01, 2.4726e-03, 4.0486e-10,  ..., 2.0954e-26, 1.8492e-26,
         2.3278e-28],
        [1.0000e+00, 4.1399e-08, 4.0587e-10,  ..., 8.9319e-25, 8.9319e-25,
         7.8824e-25],
        [1.0000e+00, 1.7258e-08, 3.8507e-09,  ..., 6.9562e-25, 4.7809e-25,
         4.7809e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.6225, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9975, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([21])


tensor([ 0,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([21, 55])


tensor([[    1, 32010,  1724,  ..., 29947, 29946, 29947],
        [    1, 32010,  1724,  ..., 29947, 29946, 29947],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892],
        ...,
        [    1, 32010,  1724,  ..., 29892, 29947, 29946],
        [    1, 32010,  1724,  ...,   313, 29906, 29929],
        [    1, 32010,  1724,  ..., 29871, 29906, 29929]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([21])


tensor([ -5.3754,  -5.3754,  -6.3650,  -7.8688,  -6.7859,  -7.8934,  -9.4114,
         -8.7275, -10.0994, -11.1071, -10.4747, -12.2745,  -7.8669,  -8.5790,
        -10.3532, -10.4024, -11.3804, -13.1383, -12.0079, -13.5536,  -7.1109],
       device='cuda:0')


new_candidate_toks
torch.Size([21, 1])


tensor([[27881],
        [29889],
        [29900],
        [ 6900],
        [29947],
        [29892],
        [29929],
        [29946],
        [29929],
        [29906],
        [29947],
        [29892],
        [29947],
        [29892],
        [29929],
        [29946],
        [29929],
        [29906],
        [29947],
        [29892],
        [29892]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([21])


tensor([-4.7408e-01, -9.7408e-01,  0.0000e+00, -4.5658e-05, -5.2337e-03,
         0.0000e+00,  0.0000e+00, -2.4319e-05,  0.0000e+00, -1.1921e-07,
        -2.4756e-03,  0.0000e+00, -5.2337e-03,  0.0000e+00,  0.0000e+00,
        -3.5406e-05,  0.0000e+00, -2.3842e-07, -2.4756e-03,  0.0000e+00,
         0.0000e+00], device='cuda:0')


new_candidates
torch.Size([21, 56])


tensor([[    1, 32010,  1724,  ..., 29946, 29947, 27881],
        [    1, 32010,  1724,  ..., 29946, 29947, 29889],
        [    1, 32010,  1724,  ..., 29929, 29892, 29900],
        ...,
        [    1, 32010,  1724,  ..., 29947, 29946, 29947],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892]], device='cuda:0')


new_candidate_logprobs
torch.Size([21])


tensor([ -5.8495,  -6.3495,  -6.3650,  -7.8689,  -6.7911,  -7.8934,  -9.4114,
         -8.7276, -10.0994, -11.1071, -10.4772, -12.2745,  -7.8721,  -8.5790,
        -10.3532, -10.4024, -11.3804, -13.1383, -12.0104, -13.5536,  -7.1109],
       device='cuda:0')

infer end: GPU memory used: 19147 MB.
event: level
id: 44
data: [{"content": "meters", "parent": 0, "prob": -5.849515914916992}, {"content": ".", "parent": 0, "prob": -6.349515914916992}, {"content": "0", "parent": 1, "prob": -6.364978313446045}, {"content": "feet", "parent": 2, "prob": -7.8688507080078125}, {"content": "8", "parent": 3, "prob": -6.7911200523376465}, {"content": ",", "parent": 4, "prob": -7.893402576446533}, {"content": "9", "parent": 5, "prob": -9.411433219909668}, {"content": "4", "parent": 6, "prob": -8.727570533752441}, {"content": "9", "parent": 7, "prob": -10.099447250366211}, {"content": "2", "parent": 8, "prob": -11.107063293457031}, {"content": "8", "parent": 9, "prob": -10.477170944213867}, {"content": ",", "parent": 10, "prob": -12.27447509765625}, {"content": "8", "parent": 11, "prob": -7.872143268585205}, {"content": ",", "parent": 12, "prob": -8.579022407531738}, {"content": "9", "parent": 13, "prob": -10.353160858154297}, {"content": "4", "parent": 14, "

array([[-0.90625   , -0.86328125,  0.03271484, ..., -1.40625   ,
        -0.29296875,  2.1875    ],
       [-0.8671875 , -2.203125  ,  1.8515625 , ..., -1.0390625 ,
         1.34375   , -0.1953125 ],
       [ 1.640625  ,  1.        , -2.078125  , ..., -1.2578125 ,
         1.2421875 , -1.171875  ],
       ...,
       [ 0.29101562, -2.59375   ,  0.6015625 , ..., -0.36914062,
        -0.87109375,  1.3828125 ],
       [-0.81640625, -1.15625   , -0.02038574, ..., -3.109375  ,
         1.828125  ,  2.046875  ],
       [-0.5859375 , -0.06445312,  2.65625   , ..., -0.28515625,
         0.83203125,  2.46875   ]], dtype=float32)


k_mean_space
(20, 2)


array([[72.794075, 93.10971 ],
       [71.28965 , 86.456955],
       [88.31009 , 73.4167  ],
       [73.22502 , 92.13543 ],
       [45.70586 , 90.87055 ],
       [86.91654 , 38.986736],
       [62.262634, 86.50365 ],
       [64.31976 , 80.6424  ],
       [84.922905, 58.432934],
       [81.50757 , 72.00031 ],
       [45.939705, 90.811935],
       [87.51563 , 39.488293],
       [45.87483 , 90.903206],
       [86.79898 , 38.89863 ],
       [62.102726, 86.39974 ],
       [64.47877 , 80.486145],
       [84.75596 , 58.35909 ],
       [81.35806 , 71.90918 ],
       [45.97597 , 90.82187 ],
       [87.48291 , 39.53477 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-96.11331272, -94.39066219])


closest
(2,)


array([ 4, 13])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.0547, -3.7344, -7.6875,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3535, -1.2578, -1.0859,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.0000,  1.0000,  3.2969,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.7266, -2.6406, -1.4219,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.9141, -3.9219, -1.8906,  ...,  0.0000,  0.0000,  0.0000],
        [-4.2812,  0.6289, -4.1562,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[6.1912e-01, 3.7552e-01, 5.3565e-03,  ..., 5.9044e-31, 3.5812e-31,
         2.7890e-31],
        [1.0000e+00, 6.4759e-07, 1.0467e-08,  ..., 6.3432e-28, 6.3432e-28,
         5.5978e-28],
        [1.0000e+00, 1.5230e-08, 1.3308e-09,  ..., 1.7588e-25, 8.3079e-26,
         2.1006e-26],
        ...,
        [1.0000e+00, 1.4166e-09, 8.5922e-10,  ..., 2.5088e-27, 1.7243e-27,
         1.7243e-27],
        [6.7918e-01, 3.2082e-01, 4.1730e-06,  ..., 2.6654e-26, 1.8319e-26,
         8.6531e-27],
        [1.0000e+00, 1.1033e-09, 7.5826e-10,  ..., 1.4726e-24, 1.2996e-24,
         6.1388e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.6191, 0.9946, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.6792, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([26])


tensor([ 0,  0,  1,  2,  3,  3,  4,  4,  5,  6,  7,  8,  9, 10, 10, 11, 12, 12,
        13, 14, 15, 16, 17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([26, 56])


tensor([[    1, 32010,  1724,  ..., 29946, 29947, 27881],
        [    1, 32010,  1724,  ..., 29946, 29947, 27881],
        [    1, 32010,  1724,  ..., 29946, 29947, 29889],
        ...,
        [    1, 32010,  1724,  ..., 29947, 29946, 29947],
        [    1, 32010,  1724,  ..., 29947, 29946, 29947],
        [    1, 32010,  1724,  ..., 29906, 29929, 29892]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([26])


tensor([ -5.8495,  -5.8495,  -6.3495,  -6.3650,  -7.8689,  -7.8689,  -6.7911,
         -6.7911,  -7.8934,  -9.4114,  -8.7276, -10.0994, -11.1071, -10.4772,
        -10.4772, -12.2745,  -7.8721,  -7.8721,  -8.5790, -10.3532, -10.4024,
        -11.3804, -13.1383, -12.0104, -12.0104, -13.5536], device='cuda:0')


new_candidate_toks
torch.Size([26, 1])


tensor([[  467],
        [29897],
        [29947],
        [29941],
        [29897],
        [  467],
        [29889],
        [27881],
        [29900],
        [ 6900],
        [29947],
        [29892],
        [29929],
        [27881],
        [29889],
        [29900],
        [29889],
        [27881],
        [29900],
        [ 6900],
        [29947],
        [29892],
        [29929],
        [27881],
        [29889],
        [29900]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([26])


tensor([-4.7945e-01, -9.7945e-01, -5.9605e-07,  0.0000e+00, -2.1301e-01,
        -1.7130e+00, -4.7408e-01, -9.7408e-01,  0.0000e+00, -4.4347e-05,
        -4.0784e-03,  0.0000e+00,  0.0000e+00, -2.5193e-01, -1.5019e+00,
         0.0000e+00, -4.7408e-01, -9.7408e-01,  0.0000e+00, -3.8982e-05,
        -6.7153e-03,  0.0000e+00,  0.0000e+00, -3.8688e-01, -1.1369e+00,
         0.0000e+00], device='cuda:0')


new_candidates
torch.Size([26, 57])


tensor([[    1, 32010,  1724,  ..., 29947, 27881,   467],
        [    1, 32010,  1724,  ..., 29947, 27881, 29897],
        [    1, 32010,  1724,  ..., 29947, 29889, 29947],
        ...,
        [    1, 32010,  1724,  ..., 29946, 29947, 27881],
        [    1, 32010,  1724,  ..., 29946, 29947, 29889],
        [    1, 32010,  1724,  ..., 29929, 29892, 29900]], device='cuda:0')


new_candidate_logprobs
torch.Size([26])


tensor([ -6.3290,  -6.8290,  -6.3495,  -6.3650,  -8.0819,  -9.5819,  -7.2652,
         -7.7652,  -7.8934,  -9.4115,  -8.7316, -10.0994, -11.1071, -10.7291,
        -11.9791, -12.2745,  -8.3462,  -8.8462,  -8.5790, -10.3532, -10.4092,
        -11.3804, -13.1383, -12.3973, -13.1473, -13.5536], device='cuda:0')

infer end: GPU memory used: 19285 MB.
event: level
id: 45
data: [{"content": ").", "parent": 0, "prob": -6.3289642333984375}, {"content": ")", "parent": 0, "prob": -6.8289642333984375}, {"content": "8", "parent": 1, "prob": -6.34951639175415}, {"content": "3", "parent": 2, "prob": -6.364978313446045}, {"content": ")", "parent": 3, "prob": -8.08185863494873}, {"content": ").", "parent": 3, "prob": -9.58185863494873}, {"content": ".", "parent": 4, "prob": -7.265199661254883}, {"content": "meters", "parent": 4, "prob": -7.765199661254883}, {"content": "0", "parent": 5, "prob": -7.893402576446533}, {"content": "feet", "parent": 6, "prob": -9.411478042602539}, {"content": "8", "parent": 7, "prob": -8.731649398803711}, {"content": ",", "parent": 8, "prob": -10.099447250366211}, {"content": "9", "parent": 9, "prob": -11.107063293457031}, {"content": "meters", "parent": 10, "prob": -10.729104995727539}, {"content": ".", "parent": 10, "prob": -11.979104995727539}, {"content": "0", "parent": 11,

array([[-1.1796875 , -1.015625  , -0.25      , ...,  1.2265625 ,
        -0.5546875 ,  0.8125    ],
       [-1.1875    , -0.2734375 ,  1.3203125 , ..., -0.43554688,
        -2.375     ,  1.578125  ],
       [ 0.36914062, -0.921875  ,  0.19140625, ..., -2.65625   ,
         0.83984375, -0.35546875],
       ...,
       [-0.734375  , -0.671875  , -0.07519531, ..., -1.3046875 ,
        -0.453125  ,  1.9609375 ],
       [ 1.65625   ,  1.09375   , -2.09375   , ..., -1.46875   ,
         1.265625  , -1.1328125 ],
       [-0.890625  ,  0.58203125,  0.87109375, ..., -0.20605469,
         0.10107422,  0.9921875 ]], dtype=float32)


k_mean_space
(20, 2)


array([[65.13752 , 84.0963  ],
       [57.27679 , 93.38889 ],
       [79.679115, 81.2581  ],
       [90.46406 , 70.860695],
       [57.437214, 93.85298 ],
       [64.51011 , 83.74046 ],
       [88.42451 , 56.911167],
       [43.410088, 93.45321 ],
       [91.031265, 53.22197 ],
       [44.951042, 92.479065],
       [74.49931 , 85.90769 ],
       [89.40372 , 70.46377 ],
       [74.64916 , 88.097565],
       [43.596386, 92.937965],
       [88.54364 , 57.43518 ],
       [91.24886 , 53.778637],
       [88.27912 , 56.694294],
       [44.031624, 93.78657 ],
       [91.05029 , 53.063976],
       [45.217503, 92.44717 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-104.1150794 ,  -72.80185318])


closest
(2,)


array([ 7, 18])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 6.6562, -2.3125,  5.2188,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.0938,  0.5664, -7.2500,  ...,  0.0000,  0.0000,  0.0000],
        [-1.2422, -1.7188, -2.6094,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.6523, -3.4688, -7.4062,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.1406,  0.4648,  3.1094,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3340, -6.7188, -9.8750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[5.9720e-01, 2.4895e-01, 1.3325e-01,  ..., 2.1950e-20, 5.5499e-21,
         4.8978e-21],
        [8.7926e-01, 7.2174e-02, 3.0086e-02,  ..., 2.0254e-23, 1.2285e-23,
         1.0841e-23],
        [9.9999e-01, 1.3007e-05, 5.0434e-07,  ..., 1.3428e-27, 6.3431e-28,
         6.3431e-28],
        ...,
        [7.7414e-01, 2.2179e-01, 4.0623e-03,  ..., 1.7710e-30, 1.0742e-30,
         4.4778e-31],
        [1.0000e+00, 1.3440e-08, 2.0612e-09,  ..., 2.8998e-25, 1.0668e-25,
         3.9244e-26],
        [5.4687e-01, 4.2590e-01, 2.7227e-02,  ..., 3.0012e-30, 3.0012e-30,
         3.0012e-30]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.5972, 0.8461, 0.9794,  ..., 1.0000, 1.0000, 1.0000],
        [0.8793, 0.9514, 0.9815,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.7741, 0.9959, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.5469, 0.9728, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([32])


tensor([ 0,  0,  0,  1,  1,  2,  3,  4,  4,  5,  5,  5,  6,  7,  7,  8,  9,  9,
        10, 10, 11, 12, 13, 13, 14, 15, 16, 17, 17, 18, 19, 19],
       device='cuda:0')


carryover_candidates
torch.Size([32, 57])


tensor([[    1, 32010,  1724,  ..., 29947, 27881,   467],
        [    1, 32010,  1724,  ..., 29947, 27881,   467],
        [    1, 32010,  1724,  ..., 29947, 27881,   467],
        ...,
        [    1, 32010,  1724,  ..., 29929, 29892, 29900],
        [    1, 32010,  1724,  ..., 29906, 29929,  6900],
        [    1, 32010,  1724,  ..., 29906, 29929,  6900]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([32])


tensor([ -6.3290,  -6.3290,  -6.3290,  -6.8290,  -6.8290,  -6.3495,  -6.3650,
         -8.0819,  -8.0819,  -9.5819,  -9.5819,  -9.5819,  -7.2652,  -7.7652,
         -7.7652,  -7.8934,  -9.4115,  -9.4115,  -8.7316,  -8.7316, -10.0994,
        -11.1071, -10.7291, -10.7291, -11.9791, -12.2745,  -8.3462,  -8.8462,
         -8.8462,  -8.5790, -10.3532, -10.3532], device='cuda:0')


new_candidate_toks
torch.Size([32, 1])


tensor([[ 2398],
        [ 5976],
        [ 8040],
        [  322],
        [ 2038],
        [29953],
        [29896],
        [ 2038],
        [  322],
        [ 5976],
        [ 2398],
        [ 8040],
        [29947],
        [  467],
        [29897],
        [29941],
        [  467],
        [29897],
        [27881],
        [29889],
        [29900],
        [ 6900],
        [29897],
        [  467],
        [29947],
        [29941],
        [29947],
        [  467],
        [29897],
        [29941],
        [  467],
        [29897]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([32])


tensor([-5.1551e-01, -1.3905e+00, -2.0155e+00, -1.2868e-01, -2.6287e+00,
        -1.3471e-05,  0.0000e+00, -5.4978e-01, -9.2478e-01, -7.1761e-01,
        -1.3426e+00, -1.4676e+00, -1.0729e-06, -1.6370e-01, -1.9137e+00,
         0.0000e+00, -4.9793e-01, -9.9793e-01, -5.7594e-01, -8.2594e-01,
         0.0000e+00, -3.2545e-05, -4.7734e-01, -9.7734e-01, -7.1526e-07,
         0.0000e+00, -1.4305e-06, -2.5601e-01, -1.5060e+00,  0.0000e+00,
        -6.0354e-01, -8.5354e-01], device='cuda:0')


new_candidates
torch.Size([32, 58])


tensor([[    1, 32010,  1724,  ..., 27881,   467,  2398],
        [    1, 32010,  1724,  ..., 27881,   467,  5976],
        [    1, 32010,  1724,  ..., 27881,   467,  8040],
        ...,
        [    1, 32010,  1724,  ..., 29892, 29900, 29941],
        [    1, 32010,  1724,  ..., 29929,  6900,   467],
        [    1, 32010,  1724,  ..., 29929,  6900, 29897]], device='cuda:0')


new_candidate_logprobs
torch.Size([32])


tensor([ -6.8445,  -7.7195,  -8.3445,  -6.9576,  -9.4576,  -6.3495,  -6.3650,
         -8.6316,  -9.0066, -10.2995, -10.9245, -11.0495,  -7.2652,  -7.9289,
         -9.6789,  -7.8934,  -9.9094, -10.4094,  -9.3076,  -9.5576, -10.0994,
        -11.1071, -11.2064, -11.7064, -11.9791, -12.2745,  -8.3462,  -9.1022,
        -10.3522,  -8.5790, -10.9567, -11.2067], device='cuda:0')

infer end: GPU memory used: 19425 MB.
event: level
id: 46
data: [{"content": "However", "parent": 0, "prob": -6.844473838806152}, {"content": "Loc", "parent": 0, "prob": -7.719473838806152}, {"content": "Mount", "parent": 0, "prob": -8.344473838806152}, {"content": "and", "parent": 1, "prob": -6.957643985748291}, {"content": "above", "parent": 1, "prob": -9.45764446258545}, {"content": "6", "parent": 2, "prob": -6.34952974319458}, {"content": "1", "parent": 3, "prob": -6.364978313446045}, {"content": "above", "parent": 4, "prob": -8.631635665893555}, {"content": "and", "parent": 4, "prob": -9.006635665893555}, {"content": "Loc", "parent": 5, "prob": -10.299467086791992}, {"content": "However", "parent": 5, "prob": -10.924467086791992}, {"content": "Mount", "parent": 5, "prob": -11.049467086791992}, {"content": "8", "parent": 6, "prob": -7.265200614929199}, {"content": ").", "parent": 7, "prob": -7.928903102874756}, {"content": ")", "parent": 7, "prob": -9.678903579711914}, {"content": 

array([[ 0.06176758, -1.5234375 ,  1.1640625 , ...,  1.2109375 ,
        -1.1953125 , -2.28125   ],
       [-1.890625  ,  0.83203125,  1.921875  , ..., -0.859375  ,
        -0.71484375,  0.47070312],
       [-0.5390625 , -0.00357056,  2.21875   , ...,  1.328125  ,
        -2.265625  , -1.5390625 ],
       ...,
       [-1.3671875 ,  0.10351562,  1.7109375 , ..., -0.69140625,
        -1.7890625 ,  1.53125   ],
       [-0.734375  , -0.8984375 , -0.05810547, ..., -1.1953125 ,
        -0.515625  ,  2.0625    ],
       [-0.70703125, -2.1875    ,  1.78125   , ..., -0.984375  ,
         1.5625    , -0.01226807]], dtype=float32)


k_mean_space
(20, 2)


array([[63.651997, 83.40016 ],
       [66.06156 , 85.388855],
       [68.03119 , 88.111595],
       [79.717804, 66.881096],
       [74.812614, 59.889366],
       [85.60282 , 70.92146 ],
       [79.51736 , 85.169525],
       [76.573944, 61.28129 ],
       [79.719696, 66.80575 ],
       [65.37161 , 84.64147 ],
       [62.537598, 82.35388 ],
       [67.66622 , 87.79963 ],
       [87.8566  , 72.45157 ],
       [52.482597, 68.35231 ],
       [83.67804 , 59.45998 ],
       [79.03721 , 85.78974 ],
       [53.750782, 68.77935 ],
       [84.157265, 59.985054],
       [80.94048 , 64.516266],
       [87.06319 , 76.93047 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-87.27851057, -86.6217823 ])


closest
(2,)


array([13, 14])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.8203, -7.9062, -7.5000,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.4844, -7.0000, -0.7188,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.4375,  2.1562,  0.4824,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.7734,  0.8633, -7.0938,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3281, -3.9062, -7.8750,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3613, -1.2500, -0.9180,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 1.8554e-07, 6.0236e-08,  ..., 2.2073e-21, 1.3388e-21,
         3.1800e-22],
        [1.0000e+00, 1.6374e-07, 4.6912e-08,  ..., 2.1006e-26, 1.1244e-26,
         3.6503e-27],
        [1.0000e+00, 1.0677e-06, 9.9312e-08,  ..., 1.5832e-23, 5.8243e-24,
         3.7234e-25],
        ...,
        [7.0097e-01, 2.5787e-01, 2.7179e-02,  ..., 7.2367e-23, 4.9737e-23,
         4.3893e-23],
        [8.1545e-01, 1.8195e-01, 2.5954e-03,  ..., 8.8121e-31, 6.8629e-31,
         6.8629e-31],
        [1.0000e+00, 5.0435e-07, 1.9556e-08,  ..., 2.5088e-27, 2.2140e-27,
         2.2140e-27]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.7010, 0.9588, 0.9860,  ..., 1.0000, 1.0000, 1.0000],
        [0.8155, 0.9974, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([27])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 13, 13, 14, 14,
        14, 15, 16, 16, 17, 17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([27, 58])


tensor([[    1, 32010,  1724,  ..., 27881,   467,  2398],
        [    1, 32010,  1724,  ..., 27881,   467,  5976],
        [    1, 32010,  1724,  ..., 27881,   467,  8040],
        ...,
        [    1, 32010,  1724,  ..., 29946, 29947, 27881],
        [    1, 32010,  1724,  ..., 29946, 29947, 27881],
        [    1, 32010,  1724,  ..., 29946, 29947, 29889]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([27])


tensor([ -6.8445,  -7.7195,  -8.3445,  -6.9576,  -9.4576,  -6.3495,  -6.3650,
         -8.6316,  -9.0066, -10.2995, -10.9245, -11.0495,  -7.2652,  -7.9289,
         -7.9289,  -7.9289,  -9.6789,  -9.6789,  -9.6789,  -7.8934,  -9.9094,
         -9.9094, -10.4094, -10.4094,  -9.3076,  -9.3076,  -9.5576],
       device='cuda:0')


new_candidate_toks
torch.Size([27, 1])


tensor([[29892],
        [  630],
        [18274],
        [  338],
        [ 7205],
        [27881],
        [29889],
        [ 7205],
        [  338],
        [  630],
        [29892],
        [18274],
        [29953],
        [ 5976],
        [ 2398],
        [ 8040],
        [  322],
        [ 2038],
        [  297],
        [29896],
        [ 5976],
        [ 8040],
        [ 2038],
        [  322],
        [  467],
        [29897],
        [29947]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([27])


tensor([-3.5763e-07, -1.1921e-07, -1.1921e-06, -1.3793e-04, -3.2728e-04,
        -1.1921e-06, -3.6955e-06, -3.2514e-04, -1.7841e-04, -3.5763e-07,
        -3.5763e-07, -7.1526e-07, -1.3471e-05, -6.5349e-01, -1.4035e+00,
        -1.5285e+00, -3.7223e-01, -1.6222e+00, -2.8722e+00,  0.0000e+00,
        -3.6721e-01, -1.3672e+00, -3.5530e-01, -1.3553e+00, -2.0401e-01,
        -1.7040e+00, -4.7684e-07], device='cuda:0')


new_candidates
torch.Size([27, 59])


tensor([[    1, 32010,  1724,  ...,   467,  2398, 29892],
        [    1, 32010,  1724,  ...,   467,  5976,   630],
        [    1, 32010,  1724,  ...,   467,  8040, 18274],
        ...,
        [    1, 32010,  1724,  ..., 29947, 27881,   467],
        [    1, 32010,  1724,  ..., 29947, 27881, 29897],
        [    1, 32010,  1724,  ..., 29947, 29889, 29947]], device='cuda:0')


new_candidate_logprobs
torch.Size([27])


tensor([ -6.8445,  -7.7195,  -8.3445,  -6.9578,  -9.4580,  -6.3495,  -6.3650,
         -8.6320,  -9.0068, -10.2995, -10.9245, -11.0495,  -7.2652,  -8.5824,
         -9.3324,  -9.4574, -10.0511, -11.3011, -12.5511,  -7.8934, -10.2766,
        -11.2766, -10.7647, -11.7647,  -9.5116, -11.0116,  -9.5576],
       device='cuda:0')

infer end: GPU memory used: 19567 MB.
event: level
id: 47
data: [{"content": ",", "parent": 0, "prob": -6.8444743156433105}, {"content": "ated", "parent": 1, "prob": -7.719473838806152}, {"content": "Ever", "parent": 2, "prob": -8.344474792480469}, {"content": "is", "parent": 3, "prob": -6.957781791687012}, {"content": "sea", "parent": 4, "prob": -9.457971572875977}, {"content": "meters", "parent": 5, "prob": -6.349531173706055}, {"content": ".", "parent": 6, "prob": -6.3649821281433105}, {"content": "sea", "parent": 7, "prob": -8.63196086883545}, {"content": "is", "parent": 8, "prob": -9.006814002990723}, {"content": "ated", "parent": 9, "prob": -10.299467086791992}, {"content": ",", "parent": 10, "prob": -10.924467086791992}, {"content": "Ever", "parent": 11, "prob": -11.049468040466309}, {"content": "6", "parent": 12, "prob": -7.265213966369629}, {"content": "Loc", "parent": 13, "prob": -8.58239459991455}, {"content": "However", "parent": 13, "prob": -9.33239459991455}, {"content": 

array([[ 1.9765625 ,  0.3515625 ,  1.2734375 , ..., -2.03125   ,
        -0.47460938, -0.6953125 ],
       [ 0.2109375 ,  0.80859375, -2.171875  , ..., -1.5234375 ,
        -2.15625   ,  0.8984375 ],
       [-0.859375  ,  1.2265625 ,  1.5078125 , ..., -0.28515625,
        -0.4140625 ,  0.5078125 ],
       ...,
       [-2.4375    , -2.921875  ,  1.328125  , ...,  0.35351562,
        -0.7265625 , -0.91796875],
       [-1.7109375 , -0.25585938,  2.        , ...,  1.578125  ,
        -1.609375  , -1.7734375 ],
       [-1.015625  , -1.0390625 ,  0.74609375, ..., -0.640625  ,
         2.671875  ,  1.515625  ]], dtype=float32)


k_mean_space
(20, 2)


array([[65.79338 , 85.95291 ],
       [64.35824 , 85.88543 ],
       [89.92133 , 70.71758 ],
       [59.90501 , 84.49674 ],
       [76.86652 , 58.97367 ],
       [81.24384 , 73.07968 ],
       [92.143654, 77.452995],
       [79.35915 , 61.707325],
       [59.825268, 84.49692 ],
       [64.2677  , 85.82468 ],
       [65.78882 , 85.99066 ],
       [89.891846, 70.72542 ],
       [88.54752 , 72.02973 ],
       [82.74354 , 74.672485],
       [73.71366 , 80.48304 ],
       [76.35861 , 84.6883  ],
       [66.22111 , 83.110405],
       [67.817024, 67.94495 ],
       [70.7775  , 83.408394],
       [90.904106, 72.6305  ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-104.44566679,  -73.93939972])


closest
(2,)


array([8, 4])


last_tok_logits
torch.Size([20, 32064])


tensor([[-1.0938, -6.1250, -7.7812,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3066,  1.8203, -3.9375,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.9062, -1.1484, -4.9062,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 4.6875, -2.8906, -7.5000,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.1406, -5.6562, -5.3438,  ...,  0.0000,  0.0000,  0.0000],
        [-2.4531, -4.6562, -8.0625,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.6332e-01, 1.1684e-01, 1.7918e-02,  ..., 2.7453e-22, 2.4228e-22,
         1.0100e-22],
        [9.9974e-01, 2.0341e-04, 4.0055e-05,  ..., 7.6264e-22, 4.6256e-22,
         2.4759e-22],
        [1.0000e+00, 3.4663e-07, 3.2242e-08,  ..., 3.4633e-26, 1.2741e-26,
         9.9224e-27],
        ...,
        [9.9988e-01, 1.0890e-04, 6.1435e-06,  ..., 1.9477e-21, 7.1653e-22,
         1.0323e-22],
        [9.9998e-01, 8.9396e-06, 2.2603e-06,  ..., 1.9287e-22, 2.3035e-23,
         1.5832e-23],
        [9.9999e-01, 6.1442e-06, 1.7258e-08,  ..., 1.0121e-24, 8.9318e-25,
         7.8823e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8633, 0.9802, 0.9981,  ..., 1.0000, 1.0000, 1.0000],
        [0.9997, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([25])


tensor([ 0,  0,  1,  2,  3,  3,  4,  5,  5,  6,  7,  8,  8,  9, 10, 10, 11, 12,
        13, 14, 15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([25, 59])


tensor([[    1, 32010,  1724,  ...,   467,  2398, 29892],
        [    1, 32010,  1724,  ...,   467,  2398, 29892],
        [    1, 32010,  1724,  ...,   467,  5976,   630],
        ...,
        [    1, 32010,  1724,  ..., 27881, 29897,  2038],
        [    1, 32010,  1724,  ..., 27881, 29897,   297],
        [    1, 32010,  1724,  ..., 29900, 29941, 29896]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([25])


tensor([ -6.8445,  -6.8445,  -7.7195,  -8.3445,  -6.9578,  -6.9578,  -9.4580,
         -6.3495,  -6.3495,  -6.3650,  -8.6320,  -9.0068,  -9.0068, -10.2995,
        -10.9245, -10.9245, -11.0495,  -7.2652,  -8.5824,  -9.3324,  -9.4574,
        -10.0511, -11.3011, -12.5511,  -7.8934], device='cuda:0')


new_candidate_toks
torch.Size([25, 1])


tensor([[  565],
        [  372],
        [  297],
        [  342],
        [ 5982],
        [  760],
        [ 3233],
        [  467],
        [29897],
        [29955],
        [ 3233],
        [ 5982],
        [  760],
        [  297],
        [  565],
        [  372],
        [  342],
        [27881],
        [  630],
        [29892],
        [18274],
        [  338],
        [ 7205],
        [  278],
        [29889]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([25])


tensor([-1.4697e-01, -2.1470e+00, -2.6247e-04, -3.5763e-07, -2.8725e-01,
        -1.4123e+00, -3.5763e-06, -1.2389e-01, -2.6239e+00,  0.0000e+00,
        -1.7881e-06, -3.2010e-01, -1.3201e+00, -2.4530e-04, -1.4943e-01,
        -2.1494e+00, -4.7684e-07, -1.0729e-06, -2.3842e-07, -2.3842e-07,
        -2.3842e-07, -2.7941e-04, -1.2077e-04, -2.1339e-05, -6.1989e-06],
       device='cuda:0')


new_candidates
torch.Size([25, 60])


tensor([[    1, 32010,  1724,  ...,  2398, 29892,   565],
        [    1, 32010,  1724,  ...,  2398, 29892,   372],
        [    1, 32010,  1724,  ...,  5976,   630,   297],
        ...,
        [    1, 32010,  1724,  ..., 29897,  2038,  7205],
        [    1, 32010,  1724,  ..., 29897,   297,   278],
        [    1, 32010,  1724,  ..., 29941, 29896, 29889]], device='cuda:0')


new_candidate_logprobs
torch.Size([25])


tensor([ -6.9914,  -8.9914,  -7.7197,  -8.3445,  -7.2450,  -8.3700,  -9.4580,
         -6.4734,  -8.9734,  -6.3650,  -8.6320,  -9.3269, -10.3269, -10.2997,
        -11.0739, -13.0739, -11.0495,  -7.2652,  -8.5824,  -9.3324,  -9.4574,
        -10.0514, -11.3013, -12.5512,  -7.8934], device='cuda:0')

infer end: GPU memory used: 19713 MB.
event: level
id: 48
data: [{"content": "if", "parent": 0, "prob": -6.9914398193359375}, {"content": "it", "parent": 0, "prob": -8.991439819335938}, {"content": "in", "parent": 1, "prob": -7.719736099243164}, {"content": "est", "parent": 2, "prob": -8.344474792480469}, {"content": "located", "parent": 3, "prob": -7.24503231048584}, {"content": "part", "parent": 3, "prob": -8.37003231048584}, {"content": "level", "parent": 4, "prob": -9.457975387573242}, {"content": ").", "parent": 5, "prob": -6.473419189453125}, {"content": ")", "parent": 5, "prob": -8.973419189453125}, {"content": "7", "parent": 6, "prob": -6.3649821281433105}, {"content": "level", "parent": 7, "prob": -8.631962776184082}, {"content": "located", "parent": 8, "prob": -9.326910018920898}, {"content": "part", "parent": 8, "prob": -10.326910018920898}, {"content": "in", "parent": 9, "prob": -10.299712181091309}, {"content": "if", "parent": 10, "prob": -11.073897361755371}, {"content": 

array([[ 1.765625  ,  1.125     ,  1.828125  , ..., -0.71875   ,
        -0.46289062, -2.65625   ],
       [-0.9296875 , -0.33203125,  0.77734375, ...,  0.08789062,
         0.2890625 , -0.10400391],
       [-2.046875  , -1.4140625 ,  1.921875  , ...,  1.875     ,
        -0.48632812, -0.53515625],
       ...,
       [ 0.07666016, -0.70703125,  0.13867188, ..., -1.6875    ,
        -1.6796875 ,  1.6015625 ],
       [ 0.27929688,  0.9375    , -2.390625  , ..., -1.6484375 ,
        -2.046875  ,  0.96875   ],
       [ 2.015625  ,  0.62890625,  1.234375  , ..., -2.203125  ,
        -0.59765625, -0.75390625]], dtype=float32)


k_mean_space
(20, 2)


array([[ 84.423515,  27.42764 ],
       [ 71.83086 ,  93.27856 ],
       [ 71.72406 , 102.85621 ],
       [ 66.34459 , 103.03731 ],
       [ 62.30621 , 103.78716 ],
       [ 69.09827 , 102.55245 ],
       [ 60.65613 ,  98.143974],
       [ 62.42602 ,  88.36611 ],
       [ 68.24346 ,  99.04252 ],
       [ 81.87162 , 104.59578 ],
       [ 61.210487,  98.41263 ],
       [ 62.19334 , 103.65983 ],
       [ 68.97147 , 102.55776 ],
       [ 71.7618  , 103.00573 ],
       [ 84.203224,  27.061867],
       [ 71.50991 ,  93.05094 ],
       [ 66.42849 , 103.11823 ],
       [ 72.819695, 101.68816 ],
       [ 63.535213, 104.10643 ],
       [ 81.118324,  54.32143 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-150.4969821 ,  -27.39773178])


closest
(2,)


array([ 6, 14])


last_tok_logits
torch.Size([20, 32064])


tensor([[-2.7500, -8.0625, -9.3750,  ...,  0.0000,  0.0000,  0.0000],
        [-1.9453, -6.5000, -6.0625,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.0312, -3.5625, -1.2656,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.7656, -3.7500, -6.2812,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0510,  1.2578, -4.3125,  ...,  0.0000,  0.0000,  0.0000],
        [-1.4609, -6.9375, -8.3125,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[4.3255e-01, 4.3255e-01, 1.2393e-01,  ..., 1.9368e-19, 1.7093e-19,
         1.8015e-20],
        [9.9440e-01, 5.2181e-03, 3.3358e-04,  ..., 3.7767e-23, 3.7767e-23,
         2.0215e-23],
        [1.0000e+00, 9.4224e-07, 5.0435e-07,  ..., 2.1856e-22, 1.7021e-22,
         1.1698e-22],
        ...,
        [8.5509e-01, 9.0126e-02, 5.4664e-02,  ..., 2.1031e-29, 1.1257e-29,
         9.9344e-30],
        [9.9957e-01, 3.3532e-04, 6.6028e-05,  ..., 5.9384e-22, 3.1786e-22,
         1.7014e-22],
        [7.7734e-01, 1.9654e-01, 2.3473e-02,  ..., 3.5966e-22, 3.5966e-22,
         9.0936e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.4325, 0.8651, 0.9890,  ..., 1.0000, 1.0000, 1.0000],
        [0.9944, 0.9996, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.8551, 0.9452, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.9996, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.7773, 0.9739, 0.9974,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([29])


tensor([ 0,  0,  0,  1,  2,  3,  4,  5,  6,  7,  7,  7,  8,  8,  9, 10, 11, 12,
        13, 14, 14, 14, 15, 16, 17, 17, 18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([29, 60])


tensor([[    1, 32010,  1724,  ...,  2398, 29892,   565],
        [    1, 32010,  1724,  ...,  2398, 29892,   565],
        [    1, 32010,  1724,  ...,  2398, 29892,   565],
        ...,
        [    1, 32010,  1724,  ...,   467,  5976,   630],
        [    1, 32010,  1724,  ...,   467,  2398, 29892],
        [    1, 32010,  1724,  ...,   467,  2398, 29892]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([29])


tensor([ -6.9914,  -6.9914,  -6.9914,  -8.9914,  -7.7197,  -8.3445,  -7.2450,
         -8.3700,  -9.4580,  -6.4734,  -6.4734,  -6.4734,  -8.9734,  -8.9734,
         -6.3650,  -8.6320,  -9.3269, -10.3269, -10.2997, -11.0739, -11.0739,
        -11.0739, -13.0739, -11.0495,  -7.2652,  -7.2652,  -8.5824,  -9.3324,
         -9.3324], device='cuda:0')


new_candidate_toks
torch.Size([29, 1])


tensor([[  366],
        [  591],
        [13858],
        [29915],
        [  278],
        [  338],
        [  297],
        [  310],
        [29889],
        [ 2398],
        [ 5976],
        [ 8040],
        [  322],
        [ 2038],
        [ 6900],
        [29889],
        [  297],
        [  310],
        [  278],
        [  591],
        [  366],
        [13858],
        [29915],
        [  338],
        [  467],
        [29897],
        [  297],
        [  565],
        [  372]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([29])


tensor([-8.3806e-01, -8.3806e-01, -2.0881e+00, -5.6135e-03, -1.7881e-06,
        -4.4108e-06, -4.6121e-04,  0.0000e+00, -5.5298e-02, -7.9922e-01,
        -1.0492e+00, -1.6742e+00, -1.8732e-01, -2.0623e+00, -2.8847e-04,
        -5.8010e-02, -4.1750e-04,  0.0000e+00, -1.5497e-06, -6.7434e-01,
        -1.0493e+00, -2.0493e+00, -5.7653e-03, -3.8147e-06, -1.5655e-01,
        -2.4065e+00, -4.2537e-04, -2.5188e-01, -1.6269e+00], device='cuda:0')


new_candidates
torch.Size([29, 61])


tensor([[    1, 32010,  1724,  ..., 29892,   565,   366],
        [    1, 32010,  1724,  ..., 29892,   565,   591],
        [    1, 32010,  1724,  ..., 29892,   565, 13858],
        ...,
        [    1, 32010,  1724,  ...,  5976,   630,   297],
        [    1, 32010,  1724,  ...,  2398, 29892,   565],
        [    1, 32010,  1724,  ...,  2398, 29892,   372]], device='cuda:0')


new_candidate_logprobs
torch.Size([29])


tensor([ -7.8295,  -7.8295,  -9.0795,  -8.9971,  -7.7197,  -8.3445,  -7.2455,
         -8.3700,  -9.5133,  -7.2726,  -7.5226,  -8.1476,  -9.1607, -11.0357,
         -6.3653,  -8.6900,  -9.3273, -10.3269, -10.2997, -11.7482, -12.1232,
        -13.1232, -13.0797, -11.0495,  -7.4218,  -9.6718,  -8.5828,  -9.5843,
        -10.9593], device='cuda:0')

infer end: GPU memory used: 19861 MB.
event: level
id: 49
data: [{"content": "you", "parent": 0, "prob": -7.8294997215271}, {"content": "we", "parent": 0, "prob": -7.8294997215271}, {"content": "considering", "parent": 0, "prob": -9.079500198364258}, {"content": "'", "parent": 1, "prob": -8.997053146362305}, {"content": "the", "parent": 2, "prob": -7.719738006591797}, {"content": "is", "parent": 3, "prob": -8.34447956085205}, {"content": "in", "parent": 4, "prob": -7.245493412017822}, {"content": "of", "parent": 5, "prob": -8.37003231048584}, {"content": ".", "parent": 6, "prob": -9.513273239135742}, {"content": "However", "parent": 7, "prob": -7.272637367248535}, {"content": "Loc", "parent": 7, "prob": -7.522637367248535}, {"content": "Mount", "parent": 7, "prob": -8.147637367248535}, {"content": "and", "parent": 8, "prob": -9.16073989868164}, {"content": "above", "parent": 8, "prob": -11.03573989868164}, {"content": "feet", "parent": 9, "prob": -6.365270614624023}, {"content": ".", "

array([[ 0.07177734, -0.31835938, -0.53515625, ...,  1.71875   ,
        -2.109375  ,  1.4140625 ],
       [ 0.5078125 ,  0.06005859,  1.3046875 , ...,  2.15625   ,
        -2.109375  ,  0.41015625],
       [ 0.50390625, -0.17382812,  2.984375  , ..., -1.109375  ,
        -1.15625   , -0.02429199],
       ...,
       [-0.87890625, -0.21972656,  1.203125  , ...,  3.53125   ,
        -1.0234375 , -0.38867188],
       [-4.03125   , -1.        ,  0.90625   , ...,  0.640625  ,
        -0.17675781,  1.2421875 ],
       [ 0.51953125,  0.07519531,  1.3125    , ...,  2.15625   ,
        -2.109375  ,  0.36523438]], dtype=float32)


k_mean_space
(20, 2)


array([[48.346436, 84.08612 ],
       [38.112522, 79.15063 ],
       [83.95073 , 77.696526],
       [68.5026  , 78.67719 ],
       [95.141136, 63.580395],
       [90.03499 , 70.86494 ],
       [93.671684, 58.14258 ],
       [94.32963 , 58.631187],
       [78.87758 , 59.866703],
       [81.8815  , 75.98581 ],
       [76.00404 , 83.24003 ],
       [92.06165 , 75.76356 ],
       [90.24066 , 72.50946 ],
       [83.78916 , 73.641235],
       [90.994354, 81.68891 ],
       [79.37706 , 59.934563],
       [93.62766 , 58.216503],
       [94.35469 , 58.846695],
       [95.07875 , 63.659176],
       [37.413746, 78.485245]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -43.92692566, -130.89846659])


closest
(2,)


array([19,  6])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.7539, -2.7656, -4.0312,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.3125, -1.4688, -1.6484,  ...,  0.0000,  0.0000,  0.0000],
        [-0.1709, -6.7812, -7.3125,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.2500, -4.1562, -6.3125,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.7812,  1.6328, -0.3105,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.4609, -1.2656, -1.6406,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[7.4920e-01, 1.6717e-01, 7.8965e-02,  ..., 4.6913e-23, 3.6536e-23,
         2.5111e-23],
        [9.7360e-01, 5.7893e-03, 5.1090e-03,  ..., 1.3875e-21, 1.3875e-21,
         4.2318e-22],
        [9.2246e-01, 2.7856e-02, 1.0248e-02,  ..., 6.3343e-20, 6.3343e-20,
         2.7831e-21],
        ...,
        [9.9997e-01, 2.1445e-05, 1.0130e-05,  ..., 5.2427e-22, 4.0830e-22,
         4.3035e-23],
        [8.3521e-01, 1.6446e-01, 1.9257e-04,  ..., 1.3758e-19, 6.4989e-20,
         5.0613e-20],
        [9.7845e-01, 4.5311e-03, 3.9987e-03,  ..., 1.3100e-21, 1.3100e-21,
         4.2528e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.7492, 0.9164, 0.9953,  ..., 1.0000, 1.0000, 1.0000],
        [0.9736, 0.9794, 0.9845,  ..., 1.0000, 1.0000, 1.0000],
        [0.9225, 0.9503, 0.9606,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8352, 0.9997, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.9784, 0.9830, 0.9870,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([29])


tensor([ 0,  0,  1,  2,  3,  4,  4,  5,  5,  6,  7,  8,  8,  8,  9, 10, 11, 12,
        13, 14, 14, 15, 15, 15, 16, 17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([29, 61])


tensor([[    1, 32010,  1724,  ..., 29892,   565,   366],
        [    1, 32010,  1724,  ..., 29892,   565,   366],
        [    1, 32010,  1724,  ..., 29892,   565,   591],
        ...,
        [    1, 32010,  1724,  ...,   630,   297,   278],
        [    1, 32010,  1724,  ...,   630,   297,   278],
        [    1, 32010,  1724,  ..., 29892,   565,   591]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([29])


tensor([ -7.8295,  -7.8295,  -7.8295,  -9.0795,  -8.9971,  -7.7197,  -7.7197,
         -8.3445,  -8.3445,  -7.2455,  -8.3700,  -9.5133,  -9.5133,  -9.5133,
         -7.2726,  -7.5226,  -8.1476,  -9.1607, -11.0357,  -6.3653,  -6.3653,
         -8.6900,  -8.6900,  -8.6900,  -9.3273, -10.3269, -10.2997, -10.2997,
        -11.7482], device='cuda:0')


new_candidate_toks
torch.Size([29, 1])


tensor([[29915],
        [  526],
        [ 2050],
        [  278],
        [29879],
        [  379],
        [10082],
        [  760],
        [ 5982],
        [  278],
        [  278],
        [ 2398],
        [ 5976],
        [ 8040],
        [29892],
        [  630],
        [18274],
        [  338],
        [ 7205],
        [  511],
        [29897],
        [ 5976],
        [ 8040],
        [ 2398],
        [  278],
        [  278],
        [  379],
        [10082],
        [ 2050]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([29])


tensor([-2.8874e-01, -1.7887e+00, -2.6750e-02, -8.0716e-02,  0.0000e+00,
        -1.8008e-01, -1.8051e+00, -3.8756e-01, -1.1376e+00, -1.0729e-06,
        -3.2425e-05, -5.7184e-01, -1.5718e+00, -1.5718e+00, -1.1921e-07,
        -3.5763e-07, -3.8147e-06, -2.4614e-04, -5.8414e-05, -1.8006e-01,
        -1.9301e+00, -9.3489e-01, -1.0599e+00, -1.4349e+00, -1.4305e-06,
        -3.2545e-05, -1.8007e-01, -1.8051e+00, -2.1786e-02], device='cuda:0')


new_candidates
torch.Size([29, 62])


tensor([[    1, 32010,  1724,  ...,   565,   366, 29915],
        [    1, 32010,  1724,  ...,   565,   366,   526],
        [    1, 32010,  1724,  ...,   565,   591,  2050],
        ...,
        [    1, 32010,  1724,  ...,   297,   278,   379],
        [    1, 32010,  1724,  ...,   297,   278, 10082],
        [    1, 32010,  1724,  ...,   565,   591,  2050]], device='cuda:0')


new_candidate_logprobs
torch.Size([29])


tensor([ -8.1182,  -9.6182,  -7.8562,  -9.1602,  -8.9971,  -7.8998,  -9.5248,
         -8.7320,  -9.4820,  -7.2455,  -8.3701, -10.0851, -11.0851, -11.0851,
         -7.2726,  -7.5226,  -8.1476,  -9.1610, -11.0358,  -6.5453,  -8.2953,
         -9.6249,  -9.7499, -10.1249,  -9.3273, -10.3269, -10.4798, -12.1048,
        -11.7700], device='cuda:0')

infer end: GPU memory used: 20011 MB.
event: level
id: 50
data: [{"content": "'", "parent": 0, "prob": -8.118244171142578}, {"content": "are", "parent": 0, "prob": -9.618244171142578}, {"content": "consider", "parent": 1, "prob": -7.8562493324279785}, {"content": "the", "parent": 2, "prob": -9.160216331481934}, {"content": "s", "parent": 3, "prob": -8.997053146362305}, {"content": "H", "parent": 4, "prob": -7.899820327758789}, {"content": "Mah", "parent": 4, "prob": -9.524820327758789}, {"content": "part", "parent": 5, "prob": -8.732037544250488}, {"content": "located", "parent": 5, "prob": -9.482037544250488}, {"content": "the", "parent": 6, "prob": -7.245494365692139}, {"content": "the", "parent": 7, "prob": -8.370064735412598}, {"content": "However", "parent": 8, "prob": -10.085112571716309}, {"content": "Loc", "parent": 8, "prob": -11.085112571716309}, {"content": "Mount", "parent": 8, "prob": -11.085112571716309}, {"content": ",", "parent": 9, "prob": -7.272637367248535}, {"conten

array([[ 0.33203125,  0.90234375, -0.41015625, ...,  0.1875    ,
         0.07568359, -0.14648438],
       [ 0.31640625, -1.2109375 , -0.12695312, ..., -0.546875  ,
         0.00714111,  0.85546875],
       [ 0.921875  , -0.41015625,  2.9375    , ..., -1.4453125 ,
        -1.7890625 , -0.41601562],
       ...,
       [-1.4140625 ,  0.34179688, -1.9296875 , ...,  0.34570312,
         0.31640625,  0.671875  ],
       [ 2.15625   , -0.70703125,  0.203125  , ...,  1.765625  ,
        -0.18359375,  1.140625  ],
       [-1.3125    ,  1.09375   ,  0.0859375 , ..., -2.390625  ,
        -0.4375    ,  1.28125   ]], dtype=float32)


k_mean_space
(20, 2)


array([[66.83693 , 54.124043],
       [83.56892 , 66.66783 ],
       [86.27535 , 63.94856 ],
       [85.88606 , 67.0226  ],
       [71.88771 , 55.946377],
       [77.49296 , 89.09219 ],
       [80.558784, 90.06868 ],
       [72.69913 , 83.867386],
       [60.390484, 84.37273 ],
       [64.725655, 85.96369 ],
       [64.75158 , 84.97183 ],
       [83.9703  , 69.08742 ],
       [75.579704, 84.58314 ],
       [83.260056, 75.66896 ],
       [85.96409 , 60.99408 ],
       [61.057175, 84.47718 ],
       [81.38006 , 88.41803 ],
       [68.43795 , 79.30258 ],
       [82.59847 , 70.07519 ],
       [80.583176, 72.51966 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-87.17065239, -89.77400112])


closest
(2,)


array([8, 0])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.8242, -4.9688, -1.7031,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0253, -4.4688, -0.3926,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.9648, -5.9688, -5.6562,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.7539, -2.2500, -5.1250,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.0312, -3.0781, -6.1250,  ...,  0.0000,  0.0000,  0.0000],
        [-1.6172, -3.9219, -7.3438,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 2.2603e-06, 1.3710e-06,  ..., 7.0954e-23, 2.0329e-23,
         2.4279e-24],
        [6.4019e-01, 3.0240e-01, 5.2550e-02,  ..., 6.5440e-24, 4.2251e-24,
         1.9958e-24],
        [9.8057e-01, 8.4836e-03, 2.4306e-03,  ..., 3.1806e-20, 1.0326e-20,
         6.6012e-22],
        ...,
        [7.7308e-01, 2.2149e-01, 4.5969e-03,  ..., 2.0180e-23, 1.5716e-23,
         1.5716e-23],
        [1.0000e+00, 1.5535e-06, 4.4508e-07,  ..., 1.8537e-26, 1.6359e-26,
         5.5978e-28],
        [4.6145e-01, 2.4699e-01, 9.0864e-02,  ..., 2.6010e-21, 2.2954e-21,
         8.4442e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.6402, 0.9426, 0.9951,  ..., 1.0000, 1.0000, 1.0000],
        [0.9806, 0.9891, 0.9915,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.7731, 0.9946, 0.9992,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.4614, 0.7084, 0.7993,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([30])


tensor([ 0,  1,  1,  2,  3,  3,  3,  4,  5,  6,  7,  8,  9,  9, 10, 11, 12, 13,
        14, 14, 15, 16, 17, 17, 18, 19, 19, 19, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([30, 62])


tensor([[    1, 32010,  1724,  ...,   565,   366, 29915],
        [    1, 32010,  1724,  ...,   565,   366,   526],
        [    1, 32010,  1724,  ...,   565,   366,   526],
        ...,
        [    1, 32010,  1724,  ..., 29955,  6900,   511],
        [    1, 32010,  1724,  ..., 29955,  6900,   511],
        [    1, 32010,  1724,  ..., 29955,  6900,   511]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([30])


tensor([ -8.1182,  -9.6182,  -9.6182,  -7.8562,  -9.1602,  -9.1602,  -9.1602,
         -8.9971,  -7.8998,  -9.5248,  -8.7320,  -9.4820,  -7.2455,  -7.2455,
         -8.3701, -10.0851, -11.0851, -11.0851,  -7.2726,  -7.2726,  -7.5226,
         -8.1476,  -9.1610,  -9.1610, -11.0358,  -6.5453,  -6.5453,  -6.5453,
         -6.5453,  -6.5453], device='cuda:0')


new_candidate_toks
torch.Size([30, 1])


tensor([[  276],
        [ 6721],
        [16811],
        [  278],
        [ 9939],
        [ 3171],
        [ 2533],
        [ 7088],
        [ 3039],
        [  284],
        [  310],
        [  297],
        [  379],
        [10082],
        [  379],
        [29892],
        [  630],
        [18274],
        [  565],
        [  372],
        [  297],
        [  342],
        [ 5982],
        [  760],
        [ 3233],
        [ 5034],
        [  408],
        [ 2466],
        [ 5982],
        [ 5998]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([30])


tensor([-3.8147e-06, -4.4599e-01, -1.1960e+00, -1.9620e-02, -4.5597e-01,
        -1.5810e+00, -2.7060e+00, -1.0307e-01, -2.0266e-06, -1.3972e-04,
         0.0000e+00, -8.8100e-05, -1.1325e-01, -2.2382e+00, -5.7052e-04,
         0.0000e+00, -1.1921e-07, -3.3379e-06, -1.2864e-01, -2.2536e+00,
        -2.4411e-04, -3.5763e-07, -2.5737e-01, -1.5074e+00, -2.0266e-06,
        -7.7339e-01, -1.3984e+00, -2.3984e+00, -2.3984e+00, -2.6484e+00],
       device='cuda:0')


new_candidates
torch.Size([30, 63])


tensor([[    1, 32010,  1724,  ...,   366, 29915,   276],
        [    1, 32010,  1724,  ...,   366,   526,  6721],
        [    1, 32010,  1724,  ...,   366,   526, 16811],
        ...,
        [    1, 32010,  1724,  ...,  6900,   511,  2466],
        [    1, 32010,  1724,  ...,  6900,   511,  5982],
        [    1, 32010,  1724,  ...,  6900,   511,  5998]], device='cuda:0')


new_candidate_logprobs
torch.Size([30])


tensor([ -8.1182, -10.0642, -10.8142,  -7.8759,  -9.6162, -10.7412, -11.8662,
         -9.1001,  -7.8998,  -9.5250,  -8.7320,  -9.4821,  -7.3587,  -9.4837,
         -8.3706, -10.0851, -11.0851, -11.0851,  -7.4013,  -9.5263,  -7.5229,
         -8.1476,  -9.4184, -10.6684, -11.0358,  -7.3187,  -7.9437,  -8.9437,
         -8.9437,  -9.1937], device='cuda:0')

infer end: GPU memory used: 20163 MB.
event: level
id: 51
data: [{"content": "re", "parent": 0, "prob": -8.118247985839844}, {"content": "asking", "parent": 1, "prob": -10.064233779907227}, {"content": "referring", "parent": 1, "prob": -10.814233779907227}, {"content": "the", "parent": 2, "prob": -7.8758697509765625}, {"content": "highest", "parent": 3, "prob": -9.616189002990723}, {"content": "height", "parent": 3, "prob": -10.741189002990723}, {"content": "sum", "parent": 3, "prob": -11.866189002990723}, {"content": "worth", "parent": 4, "prob": -9.100122451782227}, {"content": "imal", "parent": 5, "prob": -7.899822235107422}, {"content": "al", "parent": 6, "prob": -9.5249605178833}, {"content": "of", "parent": 7, "prob": -8.732037544250488}, {"content": "in", "parent": 8, "prob": -9.482125282287598}, {"content": "H", "parent": 9, "prob": -7.358743190765381}, {"content": "Mah", "parent": 9, "prob": -9.483743667602539}, {"content": "H", "parent": 10, "prob": -8.370635032653809}, {"con

array([[ 0.7109375 , -1.3125    , -0.17480469, ..., -0.59375   ,
        -0.22949219,  0.88671875],
       [-0.6640625 , -1.4609375 ,  1.3828125 , ...,  0.75390625,
        -0.7265625 ,  2.3125    ],
       [ 0.21679688, -1.484375  ,  1.109375  , ...,  1.59375   ,
         1.40625   ,  0.3515625 ],
       ...,
       [-0.796875  ,  1.265625  ,  1.640625  , ..., -0.20019531,
        -0.4296875 ,  0.17480469],
       [ 1.9609375 ,  1.25      ,  1.6640625 , ..., -0.6640625 ,
        -0.71484375, -2.90625   ],
       [-0.9453125 , -0.19238281,  0.84375   , ...,  0.13183594,
         0.5       , -0.16503906]], dtype=float32)


k_mean_space
(20, 2)


array([[ 64.47277 ,  97.378235],
       [ 66.85828 , 100.26478 ],
       [ 65.04819 ,  97.30726 ],
       [ 71.25875 , 102.996124],
       [ 73.4861  , 103.63292 ],
       [ 70.39234 , 103.54687 ],
       [ 77.54523 , 104.82875 ],
       [ 64.87268 ,  93.464264],
       [ 82.14398 ,  98.751236],
       [ 82.09417 ,  97.358154],
       [ 74.85702 ,  93.508286],
       [ 73.86413 ,  95.35897 ],
       [ 85.662704,  38.17246 ],
       [ 86.356926,  71.58515 ],
       [ 85.28105 ,  38.587505],
       [ 67.808205, 102.75696 ],
       [ 76.97523 ,  99.763794],
       [ 82.31462 , 103.32129 ],
       [ 69.535164, 103.47981 ],
       [ 71.472015,  99.30564 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-163.01811695,  -25.21312189])


closest
(2,)


array([ 0, 12])


last_tok_logits
torch.Size([20, 32064])


tensor([[  0.4258,  -2.8281,   0.0496,  ...,   0.0000,   0.0000,   0.0000],
        [  5.0938,   0.1035,   1.0234,  ...,   0.0000,   0.0000,   0.0000],
        [  3.9219,  -3.3281,   1.8984,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [  2.7969,  -1.5859,  -5.5625,  ...,   0.0000,   0.0000,   0.0000],
        [ -2.7812,  -8.3750, -10.3750,  ...,   0.0000,   0.0000,   0.0000],
        [ -2.1562,  -6.5938,  -6.2500,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[6.6923e-01, 2.4620e-01, 7.0536e-02,  ..., 3.6981e-23, 2.5417e-23,
         1.4482e-23],
        [9.5494e-01, 2.8837e-02, 3.4440e-03,  ..., 2.3650e-22, 1.0494e-22,
         9.2613e-23],
        [9.9998e-01, 1.4739e-05, 1.5535e-06,  ..., 1.4251e-21, 9.7948e-22,
         2.6362e-22],
        ...,
        [1.0000e+00, 5.7150e-07, 6.8256e-08,  ..., 2.2583e-25, 5.0390e-26,
         4.4469e-26],
        [4.8192e-01, 4.2530e-01, 8.3746e-02,  ..., 1.4831e-19, 1.1551e-19,
         8.3672e-21],
        [9.9567e-01, 4.0691e-03, 2.2956e-04,  ..., 2.5990e-23, 2.2936e-23,
         8.4376e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.6692, 0.9154, 0.9860,  ..., 1.0000, 1.0000, 1.0000],
        [0.9549, 0.9838, 0.9872,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.4819, 0.9072, 0.9910,  ..., 1.0000, 1.0000, 1.0000],
        [0.9957, 0.9997, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([29])


tensor([ 0,  0,  1,  2,  3,  3,  3,  4,  5,  5,  5,  6,  6,  7,  7,  8,  9, 10,
        11, 12, 13, 14, 15, 15, 16, 17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([29, 63])


tensor([[    1, 32010,  1724,  ...,   366, 29915,   276],
        [    1, 32010,  1724,  ...,   366, 29915,   276],
        [    1, 32010,  1724,  ...,   366,   526,  6721],
        ...,
        [    1, 32010,  1724,  ...,  2398, 29892,   565],
        [    1, 32010,  1724,  ...,  2398, 29892,   565],
        [    1, 32010,  1724,  ...,  2398, 29892,   372]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([29])


tensor([ -8.1182,  -8.1182, -10.0642, -10.8142,  -7.8759,  -7.8759,  -7.8759,
         -9.6162, -10.7412, -10.7412, -10.7412, -11.8662, -11.8662,  -9.1001,
         -9.1001,  -7.8998,  -9.5250,  -8.7320,  -9.4821,  -7.3587,  -9.4837,
         -8.3706, -10.0851, -10.0851, -11.0851, -11.0851,  -7.4013,  -7.4013,
         -9.5263], device='cuda:0')


new_candidate_toks
torch.Size([29, 1])


tensor([[ 6721],
        [16811],
        [ 1048],
        [  304],
        [ 9939],
        [ 3171],
        [11563],
        [14378],
        [  310],
        [29879],
        [  515],
        [ 2415],
        [  310],
        [  451],
        [ 3585],
        [  388],
        [  574],
        [  278],
        [  278],
        [ 3039],
        [  284],
        [ 3039],
        [  565],
        [  372],
        [  297],
        [  342],
        [  366],
        [  591],
        [29915]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([29])


tensor([-4.0163e-01, -1.4016e+00, -4.6109e-02, -2.2054e-05, -2.7492e-01,
        -2.0249e+00, -3.6499e+00, -7.6479e-02, -2.5094e-01, -2.1259e+00,
        -2.6259e+00, -2.0143e-01, -1.7014e+00, -1.6051e-01, -1.9105e+00,
        -2.1856e-03, -1.2414e-03, -1.5378e-05, -1.0729e-06, -1.4305e-06,
        -1.5785e-04, -1.3113e-06, -1.6433e-01, -2.0393e+00, -5.7929e-04,
        -7.1526e-07, -7.2997e-01, -8.5497e-01, -4.3396e-03], device='cuda:0')


new_candidates
torch.Size([29, 64])


tensor([[    1, 32010,  1724,  ..., 29915,   276,  6721],
        [    1, 32010,  1724,  ..., 29915,   276, 16811],
        [    1, 32010,  1724,  ...,   526,  6721,  1048],
        ...,
        [    1, 32010,  1724,  ..., 29892,   565,   366],
        [    1, 32010,  1724,  ..., 29892,   565,   591],
        [    1, 32010,  1724,  ..., 29892,   372, 29915]], device='cuda:0')


new_candidate_logprobs
torch.Size([29])


tensor([ -8.5199,  -9.5199, -10.1103, -10.8143,  -8.1508,  -9.9008, -11.5258,
         -9.6927, -10.9921, -12.8671, -13.3671, -12.0676, -13.5676,  -9.2606,
        -11.0106,  -7.9020,  -9.5262,  -8.7321,  -9.4821,  -7.3587,  -9.4839,
         -8.3706, -10.2494, -12.1244, -11.0857, -11.0851,  -8.1312,  -8.2562,
         -9.5306], device='cuda:0')

infer end: GPU memory used: 20319 MB.
event: level
id: 52
data: [{"content": "asking", "parent": 0, "prob": -8.519875526428223}, {"content": "referring", "parent": 0, "prob": -9.519875526428223}, {"content": "about", "parent": 1, "prob": -10.110342979431152}, {"content": "to", "parent": 2, "prob": -10.814255714416504}, {"content": "highest", "parent": 3, "prob": -8.150792121887207}, {"content": "height", "parent": 3, "prob": -9.900793075561523}, {"content": "Earth", "parent": 3, "prob": -11.525792121887207}, {"content": "mountain", "parent": 4, "prob": -9.692667961120605}, {"content": "of", "parent": 5, "prob": -10.992124557495117}, {"content": "s", "parent": 5, "prob": -12.867124557495117}, {"content": "from", "parent": 5, "prob": -13.367124557495117}, {"content": "mit", "parent": 6, "prob": -12.067620277404785}, {"content": "of", "parent": 6, "prob": -13.567620277404785}, {"content": "not", "parent": 7, "prob": -9.260635375976562}, {"content": "mention", "parent": 7, "prob": -11.0106

array([[-0.67578125, -1.546875  ,  1.265625  , ...,  0.25976562,
        -0.7578125 ,  2.421875  ],
       [ 0.2109375 , -1.3984375 ,  1.3203125 , ...,  1.0859375 ,
         1.53125   , -0.0703125 ],
       [-0.67578125, -0.53125   ,  3.765625  , ..., -1.1796875 ,
        -1.5       , -1.3125    ],
       ...,
       [-2.953125  , -0.3125    ,  0.41796875, ...,  1.578125  ,
        -0.24902344,  0.71484375],
       [-3.859375  , -0.46289062,  0.9765625 , ...,  0.17382812,
        -0.00836182,  1.8828125 ],
       [-0.41015625,  0.31054688,  0.49414062, ..., -0.8671875 ,
        -3.171875  ,  3.046875  ]], dtype=float32)


k_mean_space
(20, 2)


array([[67.40446 , 88.833855],
       [65.81376 , 85.25972 ],
       [56.606018, 91.177956],
       [56.982986, 91.334045],
       [64.79164 , 91.93775 ],
       [61.728775, 92.3309  ],
       [71.41765 , 93.40288 ],
       [65.119026, 91.90292 ],
       [60.28819 , 91.14017 ],
       [55.041233, 88.58988 ],
       [72.36656 , 93.78688 ],
       [69.86999 , 93.073654],
       [69.59473 , 91.61901 ],
       [77.453804, 93.43973 ],
       [71.138794, 83.79623 ],
       [90.98329 , 67.30357 ],
       [92.13793 , 75.3865  ],
       [89.03525 , 59.226982],
       [89.59408 , 59.04065 ],
       [93.63096 , 69.48271 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-161.36728001,  -43.00113392])


closest
(2,)


array([ 9, 18])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.7188,  0.1650,  0.6562,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.0312, -3.0469,  2.0781,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.2344, -0.0630, -4.5625,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.6328, -0.1455, -4.0938,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.4844,  2.4219,  0.2578,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.5469, -4.9375,  2.2656,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.6300e-01, 2.9080e-02, 2.3870e-03,  ..., 2.8768e-22, 1.7449e-22,
         1.4465e-22],
        [9.9997e-01, 1.3007e-05, 1.9947e-06,  ..., 5.6365e-21, 4.1237e-21,
         9.2013e-22],
        [8.6398e-01, 6.2586e-02, 3.3500e-02,  ..., 3.5629e-21, 2.6067e-21,
         4.5297e-22],
        ...,
        [9.9929e-01, 2.0332e-04, 1.7943e-04,  ..., 1.9660e-20, 1.9660e-20,
         1.6138e-21],
        [8.6681e-01, 1.3293e-01, 1.7637e-04,  ..., 7.6428e-20, 2.8116e-20,
         2.1897e-20],
        [9.9142e-01, 8.5775e-03, 1.1157e-07,  ..., 6.0861e-25, 3.6914e-25,
         1.9759e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9630, 0.9921, 0.9945,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8640, 0.9266, 0.9601,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9993, 0.9995, 0.9997,  ..., 1.0000, 1.0000, 1.0000],
        [0.8668, 0.9997, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.9914, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([34])


tensor([ 0,  1,  2,  2,  3,  4,  4,  5,  5,  6,  7,  7,  7,  8,  8,  8,  8,  8,
         8,  9, 10, 11, 11, 11, 12, 12, 13, 14, 15, 16, 17, 18, 18, 19],
       device='cuda:0')


carryover_candidates
torch.Size([34, 64])


tensor([[    1, 32010,  1724,  ..., 29915,   276,  6721],
        [    1, 32010,  1724,  ..., 29915,   276, 16811],
        [    1, 32010,  1724,  ...,   526,  6721,  1048],
        ...,
        [    1, 32010,  1724,  ...,  5982,   297,   278],
        [    1, 32010,  1724,  ...,  5982,   297,   278],
        [    1, 32010,  1724,  ...,   278,   379,  3039]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([34])


tensor([ -8.5199,  -9.5199, -10.1103, -10.1103, -10.8143,  -8.1508,  -8.1508,
         -9.9008,  -9.9008, -11.5258,  -9.6927,  -9.6927,  -9.6927, -10.9921,
        -10.9921, -10.9921, -10.9921, -10.9921, -10.9921, -12.8671, -13.3671,
        -12.0676, -12.0676, -12.0676, -13.5676, -13.5676,  -9.2606, -11.0106,
         -7.9020,  -9.5262,  -8.7321,  -9.4821,  -9.4821,  -7.3587],
       device='cuda:0')


new_candidate_toks
torch.Size([34, 1])


tensor([[ 1048],
        [  304],
        [  278],
        [19223],
        [  278],
        [14378],
        [19223],
        [  310],
        [29879],
        [29915],
        [  373],
        [ 2038],
        [  746],
        [  263],
        [  278],
        [19223],
        [14378],
        [ 1090],
        [  599],
        [  310],
        [  278],
        [ 3171],
        [  310],
        [ 2038],
        [  599],
        [11563],
        [  292],
        [  292],
        [  294],
        [  332],
        [  379],
        [  379],
        [10082],
        [  388]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([34])


tensor([-3.7700e-02, -2.5630e-05, -1.4621e-01, -2.7712e+00, -4.1432e-02,
        -1.7236e-01, -1.9224e+00, -1.7106e-01, -2.5461e+00, -2.5276e-03,
        -2.1926e-01, -2.7193e+00, -3.2193e+00, -1.3330e+00, -1.4580e+00,
        -1.4580e+00, -2.3330e+00, -2.8330e+00, -2.9580e+00, -4.8347e-02,
        -7.4378e-02, -4.0249e-01, -1.5275e+00, -2.7775e+00, -1.6278e-01,
        -2.9128e+00, -1.3113e-06, -4.7684e-07, -3.3575e-04, -5.1971e-04,
        -7.0979e-04, -1.4293e-01, -2.0179e+00, -8.6148e-03], device='cuda:0')


new_candidates
torch.Size([34, 65])


tensor([[    1, 32010,  1724,  ...,   276,  6721,  1048],
        [    1, 32010,  1724,  ...,   276, 16811,   304],
        [    1, 32010,  1724,  ...,  6721,  1048,   278],
        ...,
        [    1, 32010,  1724,  ...,   297,   278,   379],
        [    1, 32010,  1724,  ...,   297,   278, 10082],
        [    1, 32010,  1724,  ...,   379,  3039,   388]], device='cuda:0')


new_candidate_logprobs
torch.Size([34])


tensor([ -8.5576,  -9.5199, -10.2565, -12.8815, -10.8557,  -8.3231, -10.0731,
        -10.0719, -12.4469, -11.5283,  -9.9119, -12.4119, -12.9119, -12.3251,
        -12.4501, -12.4501, -13.3251, -13.8251, -13.9501, -12.9155, -13.4415,
        -12.4701, -13.5951, -14.8451, -13.7304, -16.4804,  -9.2606, -11.0106,
         -7.9023,  -9.5267,  -8.7328,  -9.6251, -11.5001,  -7.3674],
       device='cuda:0')

infer end: GPU memory used: 20477 MB.
event: level
id: 53
data: [{"content": "about", "parent": 0, "prob": -8.557576179504395}, {"content": "to", "parent": 1, "prob": -9.519901275634766}, {"content": "the", "parent": 2, "prob": -10.256549835205078}, {"content": "mountains", "parent": 2, "prob": -12.881549835205078}, {"content": "the", "parent": 3, "prob": -10.855687141418457}, {"content": "mountain", "parent": 4, "prob": -8.32314682006836}, {"content": "mountains", "parent": 4, "prob": -10.07314682006836}, {"content": "of", "parent": 5, "prob": -10.071856498718262}, {"content": "s", "parent": 5, "prob": -12.446856498718262}, {"content": "'", "parent": 6, "prob": -11.528319358825684}, {"content": "on", "parent": 7, "prob": -9.911931037902832}, {"content": "above", "parent": 7, "prob": -12.411931037902832}, {"content": "when", "parent": 7, "prob": -12.911931037902832}, {"content": "a", "parent": 8, "prob": -12.325116157531738}, {"content": "the", "parent": 8, "prob": -12.450116157531738}

array([[-1.1796875 , -0.27929688,  3.765625  , ..., -1.578125  ,
        -1.3671875 , -1.1875    ],
       [-1.046875  , -0.58203125,  4.6875    , ..., -1.8359375 ,
        -1.109375  , -1.15625   ],
       [-2.078125  , -0.06201172,  3.390625  , ..., -0.9375    ,
         0.12011719,  0.17382812],
       ...,
       [ 1.34375   ,  0.37109375,  2.609375  , ...,  2.        ,
         0.375     , -0.49023438],
       [-0.65234375, -2.359375  ,  2.421875  , ...,  0.296875  ,
         0.8046875 , -2.1875    ],
       [ 0.13769531, -1.390625  ,  2.609375  , ...,  0.5234375 ,
         0.04589844, -1.109375  ]], dtype=float32)


k_mean_space
(20, 2)


array([[78.43347 , 48.916756],
       [78.40378 , 49.44536 ],
       [87.708015, 57.705147],
       [87.06135 , 55.371918],
       [88.01002 , 58.748875],
       [86.306786, 61.551517],
       [86.55435 , 57.361248],
       [79.244995, 51.961414],
       [87.50915 , 58.107388],
       [90.545975, 72.50492 ],
       [48.82683 , 74.93038 ],
       [48.826836, 81.67753 ],
       [87.12727 , 67.208275],
       [84.1781  , 57.969383],
       [84.795975, 57.55641 ],
       [87.886345, 58.23756 ],
       [89.72801 , 61.97622 ],
       [92.60566 , 82.71032 ],
       [85.21981 , 61.678528],
       [80.00131 , 49.502087]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -22.32386208, -208.66769028])


closest
(2,)


array([10,  0])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.1562,  0.0344, -4.1562,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.2969, -1.3281, -4.2812,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.8281,  1.8203, -1.3516,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.7305, -4.2500, -4.8125,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.9219,  1.1562, -6.7188,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.5234,  0.3789, -7.0312,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.6643e-01, 6.2764e-02, 4.3137e-02,  ..., 4.3099e-21, 3.3565e-21,
         1.9125e-21],
        [9.4133e-01, 1.7241e-02, 1.7241e-02,  ..., 1.7398e-20, 1.1233e-20,
         7.7201e-21],
        [9.6215e-01, 2.9054e-02, 4.4556e-03,  ..., 2.2164e-23, 1.9559e-23,
         1.9559e-23],
        ...,
        [5.6102e-01, 4.3693e-01, 2.0234e-03,  ..., 1.7840e-22, 5.1114e-23,
         3.5130e-23],
        [9.6814e-01, 1.3810e-02, 1.0755e-02,  ..., 2.6034e-20, 2.1583e-20,
         4.8158e-21],
        [3.7727e-01, 1.3879e-01, 1.2248e-01,  ..., 1.1611e-19, 3.7694e-20,
         3.0941e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8664, 0.9292, 0.9723,  ..., 1.0000, 1.0000, 1.0000],
        [0.9413, 0.9586, 0.9758,  ..., 1.0000, 1.0000, 1.0000],
        [0.9621, 0.9912, 0.9957,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.5610, 0.9979, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9681, 0.9820, 0.9927,  ..., 1.0000, 1.0000, 1.0000],
        [0.3773, 0.5161, 0.6385,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([44])


tensor([ 0,  0,  1,  2,  3,  4,  5,  5,  5,  6,  6,  7,  7,  7,  7,  7,  8,  9,
        10, 10, 10, 10, 10, 10, 11, 12, 13, 14, 14, 14, 15, 15, 16, 17, 17, 18,
        19, 19, 19, 19, 19, 19, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([44, 65])


tensor([[    1, 32010,  1724,  ...,   276,  6721,  1048],
        [    1, 32010,  1724,  ...,   276,  6721,  1048],
        [    1, 32010,  1724,  ...,   276, 16811,   304],
        ...,
        [    1, 32010,  1724,  ...,  3171, 29879,   310],
        [    1, 32010,  1724,  ...,  3171, 29879,   310],
        [    1, 32010,  1724,  ...,  3171, 29879,   310]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([44])


tensor([ -8.5576,  -8.5576,  -9.5199, -10.2565, -12.8815, -10.8557,  -8.3231,
         -8.3231,  -8.3231, -10.0731, -10.0731, -10.0719, -10.0719, -10.0719,
        -10.0719, -10.0719, -12.4469, -11.5283,  -9.9119,  -9.9119,  -9.9119,
         -9.9119,  -9.9119,  -9.9119, -12.4119, -12.9119, -12.3251, -12.4501,
        -12.4501, -12.4501, -12.4501, -12.4501, -13.3251, -13.8251, -13.8251,
        -13.9501, -12.9155, -12.9155, -12.9155, -12.9155, -12.9155, -12.9155,
        -12.9155, -12.9155], device='cuda:0')


new_candidate_toks
torch.Size([44, 1])


tensor([[  278],
        [19223],
        [  278],
        [ 9939],
        [  373],
        [ 9939],
        [  373],
        [  746],
        [ 2038],
        [  373],
        [ 2038],
        [19223],
        [  278],
        [  263],
        [14378],
        [ 1090],
        [  310],
        [29879],
        [  738],
        [  916],
        [11563],
        [  263],
        [ 1269],
        [ 1422],
        [ 7205],
        [17005],
        [14378],
        [ 9939],
        [14378],
        [11563],
        [  373],
        [ 2038],
        [ 1236],
        [13405],
        [  344],
        [19223],
        [19223],
        [  278],
        [  599],
        [ 1090],
        [14378],
        [ 1716],
        [11563],
        [ 4655]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([44])


tensor([-0.1434, -2.7684, -0.0605, -0.0386, -0.0582, -0.0409, -0.2013, -2.7013,
        -2.9513, -0.1090, -2.6090, -1.1115, -1.3615, -1.8615, -2.2365, -2.8615,
        -0.0296,  0.0000, -0.8730, -1.6230, -2.1230, -2.6230, -2.6230, -3.3730,
        -0.0130, -0.0205, -0.0104, -0.7198, -0.9698, -3.0948, -0.2065, -1.8315,
        -0.1033, -0.5780, -0.8280, -0.0324, -0.9748, -1.9748, -2.0998, -2.0998,
        -2.2248, -4.0998, -4.3498, -4.5998], device='cuda:0')


new_candidates
torch.Size([44, 66])


tensor([[    1, 32010,  1724,  ...,  6721,  1048,   278],
        [    1, 32010,  1724,  ...,  6721,  1048, 19223],
        [    1, 32010,  1724,  ..., 16811,   304,   278],
        ...,
        [    1, 32010,  1724,  ..., 29879,   310,  1716],
        [    1, 32010,  1724,  ..., 29879,   310, 11563],
        [    1, 32010,  1724,  ..., 29879,   310,  4655]], device='cuda:0')


new_candidate_logprobs
torch.Size([44])


tensor([ -8.7010, -11.3260,  -9.5804, -10.2951, -12.9397, -10.8965,  -8.5244,
        -11.0244, -11.2744, -10.1821, -12.6821, -11.1833, -11.4333, -11.9333,
        -12.3083, -12.9333, -12.4765, -11.5283, -10.7849, -11.5349, -12.0349,
        -12.5349, -12.5349, -13.2849, -12.4249, -12.9324, -12.3355, -13.1699,
        -13.4199, -15.5449, -12.6566, -14.2816, -13.4284, -14.4031, -14.6531,
        -13.9825, -13.8903, -14.8903, -15.0153, -15.0153, -15.1403, -17.0153,
        -17.2653, -17.5153], device='cuda:0')

infer end: GPU memory used: 20637 MB.
event: level
id: 54
data: [{"content": "the", "parent": 0, "prob": -8.700953483581543}, {"content": "mountains", "parent": 0, "prob": -11.325953483581543}, {"content": "the", "parent": 1, "prob": -9.580357551574707}, {"content": "highest", "parent": 2, "prob": -10.295137405395508}, {"content": "on", "parent": 3, "prob": -12.939702987670898}, {"content": "highest", "parent": 4, "prob": -10.896538734436035}, {"content": "on", "parent": 5, "prob": -8.524429321289062}, {"content": "when", "parent": 5, "prob": -11.024429321289062}, {"content": "above", "parent": 5, "prob": -11.274429321289062}, {"content": "on", "parent": 6, "prob": -10.182137489318848}, {"content": "above", "parent": 6, "prob": -12.682136535644531}, {"content": "mountains", "parent": 7, "prob": -11.183318138122559}, {"content": "the", "parent": 7, "prob": -11.433318138122559}, {"content": "a", "parent": 7, "prob": -11.933318138122559}, {"content": "mountain", "parent": 7, "prob": -12.3

array([[-2.4218750e+00,  2.0751953e-02,  3.2500000e+00, ...,
        -9.0625000e-01,  1.5136719e-01, -3.3447266e-02],
       [-8.9355469e-02, -7.2656250e-01,  2.9062500e+00, ...,
        -1.3125000e+00,  7.5390625e-01,  1.0703125e+00],
       [-1.7890625e+00, -4.0893555e-03,  4.1562500e+00, ...,
        -7.6562500e-01,  2.0703125e-01,  2.1582031e-01],
       ...,
       [-1.7187500e+00, -1.1562500e+00,  1.7578125e+00, ...,
        -3.4218750e+00, -2.1191406e-01, -5.7421875e-01],
       [-3.8671875e-01, -2.1250000e+00,  3.5781250e+00, ...,
        -1.7109375e+00,  6.7871094e-02, -7.8906250e-01],
       [-1.3515625e+00, -7.5000000e-01,  2.8906250e+00, ...,
        -4.1015625e-01, -1.0595703e-01, -2.5468750e+00]], dtype=float32)


k_mean_space
(20, 2)


array([[92.33421 , 56.753643],
       [89.61084 , 61.35803 ],
       [92.667915, 57.226032],
       [88.74047 , 56.92788 ],
       [89.959946, 60.027527],
       [89.110466, 57.63178 ],
       [87.38272 , 61.977684],
       [89.73344 , 69.62505 ],
       [37.45103 , 81.16897 ],
       [87.411674, 58.2745  ],
       [36.80694 , 80.0221  ],
       [89.612915, 64.40186 ],
       [87.9445  , 57.027534],
       [88.08416 , 57.856705],
       [91.008965, 66.45528 ],
       [68.04355 , 86.30516 ],
       [84.87985 , 53.23344 ],
       [91.28643 , 63.412903],
       [95.05578 , 67.49861 ],
       [95.932785, 70.06264 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -36.889884  , -186.65254879])


closest
(2,)


array([10, 16])


last_tok_logits
torch.Size([20, 32064])


tensor([[  4.0625,   0.5312,  -0.6719,  ...,   0.0000,   0.0000,   0.0000],
        [  3.4375,   4.2500,  -6.1562,  ...,   0.0000,   0.0000,   0.0000],
        [  3.1406,   0.5117,  -0.7305,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [  1.8828,   1.7500,  -3.6250,  ...,   0.0000,   0.0000,   0.0000],
        [  1.8828,  -4.6875, -10.1875,  ...,   0.0000,   0.0000,   0.0000],
        [  1.6094,   0.4395,  -5.3750,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.6498e-01, 2.9140e-02, 2.3919e-03,  ..., 3.2343e-23, 2.5189e-23,
         1.7312e-23],
        [9.1108e-01, 4.0030e-02, 1.0121e-02,  ..., 1.3822e-21, 1.2984e-21,
         8.3834e-22],
        [9.7317e-01, 1.3882e-02, 5.7867e-03,  ..., 3.9736e-22, 1.8770e-22,
         1.2900e-22],
        ...,
        [2.4128e-01, 2.1293e-01, 1.4634e-01,  ..., 1.0804e-19, 7.4254e-20,
         2.1274e-20],
        [8.6037e-01, 4.8539e-02, 3.3360e-02,  ..., 5.4953e-21, 2.9414e-21,
         2.9414e-21],
        [8.1586e-01, 1.8204e-01, 1.3899e-03,  ..., 8.4228e-23, 6.5597e-23,
         5.1087e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9650, 0.9941, 0.9965,  ..., 1.0000, 1.0000, 1.0000],
        [0.9111, 0.9511, 0.9612,  ..., 1.0000, 1.0000, 1.0000],
        [0.9732, 0.9871, 0.9928,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.2413, 0.4542, 0.6006,  ..., 1.0000, 1.0000, 1.0000],
        [0.8604, 0.9089, 0.9423,  ..., 1.0000, 1.0000, 1.0000],
        [0.8159, 0.9979, 0.9993,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([53])


tensor([ 0,  1,  2,  3,  4,  5,  5,  6,  6,  6,  6,  6,  7,  8,  9,  9,  9, 10,
        11, 11, 12, 12, 12, 13, 14, 15, 15, 16, 16, 16, 16, 16, 16, 16, 17, 17,
        17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 19, 19],
       device='cuda:0')


carryover_candidates
torch.Size([53, 66])


tensor([[    1, 32010,  1724,  ...,  6721,  1048,   278],
        [    1, 32010,  1724,  ...,  6721,  1048, 19223],
        [    1, 32010,  1724,  ..., 16811,   304,   278],
        ...,
        [    1, 32010,  1724,  ..., 14378,   373,   738],
        [    1, 32010,  1724,  ..., 14378,   373,   916],
        [    1, 32010,  1724,  ..., 14378,   373,   916]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([53])


tensor([ -8.7010, -11.3260,  -9.5804, -10.2951, -12.9397, -10.8965, -10.8965,
         -8.5244,  -8.5244,  -8.5244,  -8.5244,  -8.5244, -11.0244, -11.2744,
        -10.1821, -10.1821, -10.1821, -12.6821, -11.1833, -11.1833, -11.4333,
        -11.4333, -11.4333, -11.9333, -12.3083, -12.9333, -12.9333, -12.4765,
        -12.4765, -12.4765, -12.4765, -12.4765, -12.4765, -12.4765, -11.5283,
        -11.5283, -11.5283, -11.5283, -11.5283, -11.5283, -11.5283, -11.5283,
        -11.5283, -11.5283, -11.5283, -11.5283, -11.5283, -11.5283, -11.5283,
        -10.7849, -10.7849, -11.5349, -11.5349], device='cuda:0')


new_candidate_toks
torch.Size([53, 1])


tensor([[ 9939],
        [  373],
        [ 9939],
        [14378],
        [  916],
        [14378],
        [ 2998],
        [  738],
        [  916],
        [11563],
        [ 1269],
        [  263],
        [17005],
        [ 7205],
        [  916],
        [ 1422],
        [ 3814],
        [ 7205],
        [  373],
        [ 2038],
        [ 9939],
        [14378],
        [11563],
        [14378],
        [ 1236],
        [13405],
        [  344],
        [19223],
        [  278],
        [  599],
        [ 1090],
        [14378],
        [ 5164],
        [ 1700],
        [24235],
        [11855],
        [ 9939],
        [ 1737],
        [ 4818],
        [ 7101],
        [ 2246],
        [11420],
        [ 7136],
        [10150],
        [ 1090],
        [ 7773],
        [12463],
        [13290],
        [13694],
        [15754],
        [ 6432],
        [ 3814],
        [ 6432]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([53])


tensor([-0.0357, -0.0931, -0.0272, -0.0595, -0.0080, -0.1270, -2.2520, -0.8045,
        -1.5545, -2.3045, -2.4295, -2.8045, -0.0175, -0.0064, -0.2147, -2.4647,
        -2.5897, -0.0052, -0.1620, -2.1620, -0.4554, -1.3304, -3.3304, -0.0088,
        -0.0980, -0.3894, -1.1394, -0.8105, -1.9355, -2.1855, -2.1855, -2.6855,
        -4.0605, -4.5605, -1.4218, -1.5468, -1.9218, -3.0468, -3.0468, -3.0468,
        -3.4218, -3.4218, -3.7968, -3.7968, -4.2968, -4.5468, -4.5468, -4.6718,
        -4.6718, -0.1504, -3.0254, -0.2035, -1.7035], device='cuda:0')


new_candidates
torch.Size([53, 67])


tensor([[    1, 32010,  1724,  ...,  1048,   278,  9939],
        [    1, 32010,  1724,  ...,  1048, 19223,   373],
        [    1, 32010,  1724,  ...,   304,   278,  9939],
        ...,
        [    1, 32010,  1724,  ...,   373,   738,  6432],
        [    1, 32010,  1724,  ...,   373,   916,  3814],
        [    1, 32010,  1724,  ...,   373,   916,  6432]], device='cuda:0')


new_candidate_logprobs
torch.Size([53])


tensor([ -8.7366, -11.4191,  -9.6075, -10.3546, -12.9477, -11.0235, -13.1485,
         -9.3289, -10.0789, -10.8289, -10.9539, -11.3289, -11.0420, -11.2808,
        -10.3968, -12.6468, -12.7718, -12.6874, -11.3453, -13.3453, -11.8887,
        -12.7637, -14.7637, -11.9421, -12.4063, -13.3227, -14.0727, -13.2870,
        -14.4120, -14.6620, -14.6620, -15.1620, -16.5370, -17.0370, -12.9501,
        -13.0751, -13.4501, -14.5751, -14.5751, -14.5751, -14.9501, -14.9501,
        -15.3251, -15.3251, -15.8251, -16.0751, -16.0751, -16.2001, -16.2001,
        -10.9353, -13.8103, -11.7384, -13.2384], device='cuda:0')

infer end: GPU memory used: 20799 MB.
event: level
id: 55
data: [{"content": "highest", "parent": 0, "prob": -8.736605644226074}, {"content": "on", "parent": 1, "prob": -11.41907787322998}, {"content": "highest", "parent": 2, "prob": -9.607549667358398}, {"content": "mountain", "parent": 3, "prob": -10.354646682739258}, {"content": "other", "parent": 4, "prob": -12.947659492492676}, {"content": "mountain", "parent": 5, "prob": -11.023494720458984}, {"content": "known", "parent": 5, "prob": -13.148494720458984}, {"content": "any", "parent": 6, "prob": -9.328913688659668}, {"content": "other", "parent": 6, "prob": -10.078912734985352}, {"content": "Earth", "parent": 6, "prob": -10.828912734985352}, {"content": "each", "parent": 6, "prob": -10.953913688659668}, {"content": "a", "parent": 6, "prob": -11.328912734985352}, {"content": "measured", "parent": 7, "prob": -11.041959762573242}, {"content": "sea", "parent": 8, "prob": -11.280831336975098}, {"content": "other", "parent": 9, "prob": 

array([[-1.6875    , -1.6796875 ,  5.5625    , ...,  0.265625  ,
        -0.8984375 ,  0.8828125 ],
       [-0.69140625,  0.44921875,  3.609375  , ..., -0.94140625,
         0.875     , -3.375     ],
       [-1.6328125 , -2.03125   ,  5.6875    , ...,  0.45117188,
        -0.39257812,  0.82421875],
       ...,
       [ 1.2578125 , -1.359375  , -1.234375  , ...,  2.640625  ,
        -1.1640625 ,  0.32617188],
       [ 0.47070312, -0.07519531,  3.65625   , ..., -1.0390625 ,
         0.44335938, -2.953125  ],
       [-1.28125   , -2.6875    ,  0.74609375, ...,  1.28125   ,
        -0.265625  ,  0.15917969]], dtype=float32)


k_mean_space
(20, 2)


array([[85.833755, 60.674984],
       [88.373436, 56.138107],
       [86.126526, 61.02873 ],
       [84.45035 , 65.540985],
       [92.09924 , 54.039375],
       [85.010796, 66.14728 ],
       [86.738106, 64.89748 ],
       [92.348495, 56.75853 ],
       [93.03602 , 54.02605 ],
       [84.949646, 72.4237  ],
       [91.865944, 65.36847 ],
       [87.22724 , 55.218174],
       [70.91554 , 82.70152 ],
       [47.664867, 91.020744],
       [92.28215 , 51.443447],
       [88.69988 , 52.47318 ],
       [94.2167  , 79.70719 ],
       [46.256435, 89.16398 ],
       [87.10993 , 57.158825],
       [64.70159 , 84.228325]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -48.35548973, -176.91784286])


closest
(2,)


array([17, 14])


last_tok_logits
torch.Size([20, 32064])


tensor([[  3.2344,   3.9844, -10.2500,  ...,   0.0000,   0.0000,   0.0000],
        [  3.0000,  -1.2266,  -7.0625,  ...,   0.0000,   0.0000,   0.0000],
        [  2.6719,   3.5156, -10.1875,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [  0.9805,  -2.6719,  -5.1250,  ...,   0.0000,   0.0000,   0.0000],
        [  1.1094,  -3.4844,  -6.2500,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.2832,  -2.1094,  -9.2500,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.1072e-01, 5.8220e-02, 2.4270e-02,  ..., 4.5302e-21, 3.5281e-21,
         1.6666e-21],
        [9.9208e-01, 4.5943e-03, 1.0251e-03,  ..., 4.8380e-23, 9.5267e-24,
         1.0041e-24],
        [8.4813e-01, 1.3006e-01, 1.2098e-02,  ..., 7.8819e-21, 6.1384e-21,
         2.8996e-21],
        ...,
        [9.9997e-01, 2.7536e-05, 9.9309e-08,  ..., 5.8999e-29, 2.4595e-29,
         1.3165e-29],
        [9.9521e-01, 1.1653e-03, 1.0283e-03,  ..., 3.5861e-22, 3.1647e-22,
         1.0274e-22],
        [8.8883e-01, 7.2960e-02, 2.6840e-02,  ..., 6.4330e-21, 4.4213e-21,
         3.0387e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9107, 0.9689, 0.9932,  ..., 1.0000, 1.0000, 1.0000],
        [0.9921, 0.9967, 0.9977,  ..., 1.0000, 1.0000, 1.0000],
        [0.8481, 0.9782, 0.9903,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9952, 0.9964, 0.9974,  ..., 1.0000, 1.0000, 1.0000],
        [0.8888, 0.9618, 0.9886,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([39])


tensor([ 0,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  7,  8,  8,  9,
         9,  9,  9,  9,  9,  9,  9,  9, 10, 11, 12, 13, 14, 14, 15, 15, 16, 17,
        18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([39, 67])


tensor([[    1, 32010,  1724,  ...,  1048,   278,  9939],
        [    1, 32010,  1724,  ...,  1048, 19223,   373],
        [    1, 32010,  1724,  ...,   304,   278,  9939],
        ...,
        [    1, 32010,  1724,  ...,   310, 19223,   373],
        [    1, 32010,  1724,  ...,   310, 19223,  2038],
        [    1, 32010,  1724,  ...,   310, 19223,  2038]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([39])


tensor([ -8.7366, -11.4191,  -9.6075,  -9.6075, -10.3546, -10.3546, -12.9477,
        -12.9477, -11.0235, -11.0235, -13.1485, -13.1485,  -9.3289,  -9.3289,
         -9.3289, -10.0789, -10.0789, -10.8289, -10.8289, -10.8289, -10.8289,
        -10.8289, -10.8289, -10.8289, -10.8289, -10.8289, -10.9539, -11.3289,
        -11.0420, -11.2808, -10.3968, -10.3968, -12.6468, -12.6468, -12.7718,
        -12.6874, -11.3453, -13.3453, -13.3453], device='cuda:0')


new_candidate_toks
torch.Size([39, 1])


tensor([[14378],
        [  916],
        [14378],
        [ 2998],
        [  373],
        [  746],
        [ 6432],
        [ 3814],
        [  373],
        [  746],
        [14378],
        [ 1298],
        [15754],
        [  916],
        [ 6432],
        [ 3814],
        [ 6432],
        [  491],
        [ 2729],
        [29915],
        [  429],
        [  746],
        [29892],
        [  297],
        [  393],
        [13858],
        [25523],
        [15754],
        [  515],
        [ 3233],
        [ 3814],
        [ 6432],
        [ 3814],
        [ 6432],
        [ 1691],
        [ 3233],
        [  916],
        [ 7205],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([39])


tensor([-9.3524e-02, -7.9466e-03, -1.6472e-01, -2.0397e+00, -1.8638e-01,
        -2.4364e+00, -4.7538e-01, -9.7538e-01, -1.1972e-01, -2.8697e+00,
        -3.7728e-01, -1.2523e+00, -2.1722e-01, -2.5922e+00, -2.7172e+00,
        -2.2780e-01, -1.6028e+00, -8.2192e-01, -2.0719e+00, -2.0719e+00,
        -2.5719e+00, -3.1969e+00, -3.3219e+00, -3.8219e+00, -3.8219e+00,
        -3.8219e+00, -9.8616e-02, -3.1781e-02, -1.8445e-02, -8.1063e-06,
        -1.8384e-01, -1.8088e+00, -1.3252e-01, -2.2575e+00, -1.2279e-05,
        -2.7776e-05, -4.8012e-03, -1.1785e-01, -2.6178e+00], device='cuda:0')


new_candidates
torch.Size([39, 68])


tensor([[    1, 32010,  1724,  ...,   278,  9939, 14378],
        [    1, 32010,  1724,  ..., 19223,   373,   916],
        [    1, 32010,  1724,  ...,   278,  9939, 14378],
        ...,
        [    1, 32010,  1724,  ..., 19223,   373,   916],
        [    1, 32010,  1724,  ..., 19223,  2038,  7205],
        [    1, 32010,  1724,  ..., 19223,  2038,   278]], device='cuda:0')


new_candidate_logprobs
torch.Size([39])


tensor([ -8.8301, -11.4270,  -9.7723, -11.6473, -10.5410, -12.7910, -13.4230,
        -13.9230, -11.1432, -13.8932, -13.5258, -14.4008,  -9.5461, -11.9211,
        -12.0461, -10.3067, -11.6817, -11.6508, -12.9008, -12.9008, -13.4008,
        -14.0258, -14.1508, -14.6508, -14.6508, -14.6508, -11.0525, -11.3607,
        -11.0604, -11.2808, -10.5807, -12.2057, -12.7793, -14.9043, -12.7718,
        -12.6874, -11.3501, -13.4632, -15.9632], device='cuda:0')

infer end: GPU memory used: 20963 MB.
event: level
id: 56
data: [{"content": "mountain", "parent": 0, "prob": -8.830129623413086}, {"content": "other", "parent": 1, "prob": -11.427024841308594}, {"content": "mountain", "parent": 2, "prob": -9.772273063659668}, {"content": "known", "parent": 2, "prob": -11.647273063659668}, {"content": "on", "parent": 3, "prob": -10.541027069091797}, {"content": "when", "parent": 3, "prob": -12.791027069091797}, {"content": "cel", "parent": 4, "prob": -13.423038482666016}, {"content": "plan", "parent": 4, "prob": -13.923038482666016}, {"content": "on", "parent": 5, "prob": -11.143211364746094}, {"content": "when", "parent": 5, "prob": -13.893211364746094}, {"content": "mountain", "parent": 6, "prob": -13.525774955749512}, {"content": "point", "parent": 6, "prob": -14.400774955749512}, {"content": "planet", "parent": 7, "prob": -9.546137809753418}, {"content": "other", "parent": 7, "prob": -11.921138763427734}, {"content": "cel", "parent": 7, "prob": -12

array([[-2.625     , -0.3515625 ,  3.390625  , ..., -1.28125   ,
        -0.08886719, -0.55078125],
       [-1.2109375 , -0.13476562,  2.359375  , ..., -2.        ,
         0.48632812, -2.921875  ],
       [-2.953125  , -0.828125  ,  3.9375    , ..., -1.359375  ,
         0.16308594, -0.5390625 ],
       ...,
       [-1.2265625 , -2.0625    ,  3.015625  , ..., -1.8984375 ,
         1.46875   , -1.2890625 ],
       [-0.46679688, -0.3203125 ,  1.8515625 , ...,  0.46289062,
         1.75      , -0.9609375 ],
       [ 1.6953125 , -0.38671875,  2.375     , ..., -1.359375  ,
        -1.234375  , -1.2734375 ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 54.18247 ,  98.03536 ],
       [ 69.65913 ,  91.61021 ],
       [ 54.794025,  98.24817 ],
       [ 64.487144,  99.15272 ],
       [ 62.682514,  96.265594],
       [ 60.08402 ,  98.392204],
       [ 91.20106 ,  43.721764],
       [ 87.57409 ,  62.317654],
       [ 62.820396,  96.349976],
       [ 59.21685 ,  97.25133 ],
       [ 55.477573,  97.735825],
       [ 65.73158 ,  98.29838 ],
       [ 64.51429 ,  88.09326 ],
       [ 67.79462 ,  93.736626],
       [ 90.01204 ,  45.203438],
       [ 87.51955 ,  62.9588  ],
       [ 90.345406,  42.18647 ],
       [ 68.19369 , 100.52603 ],
       [ 68.95328 ,  97.051155],
       [ 76.43719 ,  99.34706 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-176.89149475,  -61.38063812])


closest
(2,)


array([ 0, 16])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 3.7031,  3.5781, -5.8125,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.1875,  1.7500, -4.9375,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.7812,  2.9375, -6.0625,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.5156,  0.6016, -6.2188,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0266, -0.2793, -3.4219,  ...,  0.0000,  0.0000,  0.0000],
        [-2.2500, -0.1416, -4.0938,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.0558e-01, 1.0902e-01, 2.4326e-02,  ..., 7.1029e-20, 5.5318e-20,
         3.3552e-20],
        [4.9885e-01, 4.9885e-01, 1.0912e-03,  ..., 1.7623e-24, 9.4327e-25,
         8.3243e-25],
        [8.5348e-01, 7.9386e-02, 1.7714e-02,  ..., 1.4059e-19, 5.1721e-20,
         2.7684e-20],
        ...,
        [5.1310e-01, 3.1121e-01, 4.7726e-02,  ..., 9.5775e-20, 8.4521e-20,
         7.4589e-20],
        [9.9982e-01, 5.8284e-05, 5.1436e-05,  ..., 3.1170e-24, 1.2993e-24,
         1.1467e-24],
        [1.0000e+00, 6.4759e-07, 1.1861e-08,  ..., 6.2617e-23, 6.2617e-23,
         4.3036e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8056, 0.9146, 0.9389,  ..., 1.0000, 1.0000, 1.0000],
        [0.4989, 0.9977, 0.9988,  ..., 1.0000, 1.0000, 1.0000],
        [0.8535, 0.9329, 0.9506,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.5131, 0.8243, 0.8720,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 0.9999, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([39])


tensor([ 0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  4,  5,  5,  6,  7,  8,  8,  8,
         9,  9, 10, 10, 10, 11, 12, 12, 12, 13, 13, 14, 15, 16, 17, 17, 17, 17,
        17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([39, 68])


tensor([[    1, 32010,  1724,  ...,   278,  9939, 14378],
        [    1, 32010,  1724,  ...,   278,  9939, 14378],
        [    1, 32010,  1724,  ..., 19223,   373,   916],
        ...,
        [    1, 32010,  1724,  ...,   373, 11563,   491],
        [    1, 32010,  1724,  ...,   373, 11563,  2729],
        [    1, 32010,  1724,  ...,   373, 11563, 29915]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([39])


tensor([ -8.8301,  -8.8301, -11.4270, -11.4270,  -9.7723,  -9.7723, -11.6473,
        -11.6473, -10.5410, -10.5410, -10.5410, -12.7910, -12.7910, -13.4230,
        -13.9230, -11.1432, -11.1432, -11.1432, -13.8932, -13.8932, -13.5258,
        -13.5258, -13.5258, -14.4008,  -9.5461,  -9.5461,  -9.5461, -11.9211,
        -11.9211, -12.0461, -10.3067, -11.6817, -11.6508, -11.6508, -11.6508,
        -11.6508, -11.6508, -12.9008, -12.9008], device='cuda:0')


new_candidate_toks
torch.Size([39, 1])


tensor([[  373],
        [  746],
        [ 3814],
        [ 6432],
        [  373],
        [  746],
        [14378],
        [ 1298],
        [  738],
        [  263],
        [  916],
        [17005],
        [13858],
        [  342],
        [ 1691],
        [  738],
        [  263],
        [  916],
        [17005],
        [13858],
        [  373],
        [  297],
        [  746],
        [  373],
        [  297],
        [29892],
        [  916],
        [15754],
        [ 6432],
        [  342],
        [ 1691],
        [  342],
        [ 2967],
        [ 7977],
        [ 2246],
        [  278],
        [ 4158],
        [  373],
        [29879]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([39])


tensor([-2.1620e-01, -2.2162e+00, -6.9545e-01, -6.9545e-01, -1.5843e-01,
        -2.5334e+00, -5.1503e-01, -1.0150e+00, -5.8149e-01, -1.4565e+00,
        -2.2065e+00, -5.8446e-01, -8.3446e-01, -2.3842e-06, -3.5763e-07,
        -3.6308e-01, -1.7381e+00, -2.4881e+00, -5.8924e-01, -8.3924e-01,
        -2.1845e-01, -2.5934e+00, -2.8434e+00, -1.1628e-02, -6.3569e-01,
        -1.1357e+00, -2.2607e+00, -3.2085e-01, -1.3209e+00, -2.2650e-06,
        -8.1063e-06, -4.1723e-06, -6.6729e-01, -1.1673e+00, -3.0423e+00,
        -3.9173e+00, -3.9173e+00, -1.8080e-04, -5.9605e-07], device='cuda:0')


new_candidates
torch.Size([39, 69])


tensor([[    1, 32010,  1724,  ...,  9939, 14378,   373],
        [    1, 32010,  1724,  ...,  9939, 14378,   746],
        [    1, 32010,  1724,  ...,   373,   916,  3814],
        ...,
        [    1, 32010,  1724,  ..., 11563,   491,  4158],
        [    1, 32010,  1724,  ..., 11563,  2729,   373],
        [    1, 32010,  1724,  ..., 11563, 29915, 29879]], device='cuda:0')


new_candidate_logprobs
torch.Size([39])


tensor([ -9.0463, -11.0463, -12.1225, -12.1225,  -9.9307, -12.3057, -12.1623,
        -12.6623, -11.1225, -11.9975, -12.7475, -13.3755, -13.6255, -13.4230,
        -13.9230, -11.5063, -12.8813, -13.6313, -14.4824, -14.7324, -13.7442,
        -16.1192, -16.3692, -14.4124, -10.1818, -10.6818, -11.8068, -12.2420,
        -13.2420, -12.0461, -10.3067, -11.6817, -12.3181, -12.8181, -14.6931,
        -15.5681, -15.5681, -12.9010, -12.9008], device='cuda:0')

infer end: GPU memory used: 21131 MB.
event: level
id: 57
data: [{"content": "on", "parent": 0, "prob": -9.04632568359375}, {"content": "when", "parent": 0, "prob": -11.04632568359375}, {"content": "plan", "parent": 1, "prob": -12.122471809387207}, {"content": "cel", "parent": 1, "prob": -12.122471809387207}, {"content": "on", "parent": 2, "prob": -9.930700302124023}, {"content": "when", "parent": 2, "prob": -12.305700302124023}, {"content": "mountain", "parent": 3, "prob": -12.162300109863281}, {"content": "point", "parent": 3, "prob": -12.662300109863281}, {"content": "any", "parent": 4, "prob": -11.12252140045166}, {"content": "a", "parent": 4, "prob": -11.99752140045166}, {"content": "other", "parent": 4, "prob": -12.74752140045166}, {"content": "measured", "parent": 5, "prob": -13.375483512878418}, {"content": "considering", "parent": 5, "prob": -13.625483512878418}, {"content": "est", "parent": 6, "prob": -13.423041343688965}, {"content": "ets", "parent": 7, "prob": -13.923038482

array([[ 0.48828125, -0.33398438,  3.8125    , ..., -1.328125  ,
         0.2890625 , -3.71875   ],
       [-1.265625  , -0.23925781,  2.734375  , ..., -1.8828125 ,
         0.9921875 , -0.47070312],
       [-1.4140625 ,  1.5625    ,  0.7109375 , ..., -1.578125  ,
         0.25976562,  0.78515625],
       ...,
       [-0.921875  ,  0.03417969,  2.890625  , ..., -1.1328125 ,
         1.0234375 , -2.546875  ],
       [-1.640625  , -0.9375    ,  2.4375    , ...,  0.72265625,
         0.9296875 , -0.55078125],
       [-0.41601562, -1.2109375 ,  3.65625   , ..., -1.78125   ,
        -0.265625  , -0.35351562]], dtype=float32)


k_mean_space
(20, 2)


array([[54.01814 , 79.48579 ],
       [81.21404 , 47.809723],
       [80.283966, 94.677765],
       [83.80721 , 95.447716],
       [55.145557, 79.629234],
       [80.554344, 47.581936],
       [81.850334, 61.355473],
       [84.3637  , 65.49132 ],
       [47.487514, 81.906876],
       [50.11306 , 77.21988 ],
       [50.787342, 85.156456],
       [87.51943 , 51.750034],
       [74.714516, 50.492428],
       [87.117546, 97.27469 ],
       [78.04128 , 67.44442 ],
       [48.603523, 84.16489 ],
       [49.2834  , 77.83594 ],
       [51.773647, 86.11455 ],
       [87.946144, 53.572407],
       [74.56175 , 49.986866]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-130.53145504, -118.31553078])


closest
(2,)


array([8, 5])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.3438,  1.2969, -4.0312,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.4219, -3.2656, -0.8867,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.1250, -4.3750, -1.9141,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.8750,  1.7812, -4.3125,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.2031, -3.3594, -2.5469,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.9375, -2.3750, -3.3281,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[5.2146e-01, 2.7912e-01, 1.0268e-01,  ..., 1.3173e-20, 1.0259e-20,
         6.2224e-21],
        [5.5634e-01, 4.3328e-01, 2.9194e-03,  ..., 1.3072e-21, 3.3052e-22,
         8.3568e-23],
        [1.0000e+00, 1.2752e-07, 9.9312e-08,  ..., 5.8243e-24, 1.4726e-24,
         1.2996e-24],
        ...,
        [6.7817e-01, 3.2034e-01, 6.1841e-04,  ..., 1.7702e-23, 1.2166e-23,
         1.0737e-23],
        [9.0407e-01, 9.5289e-02, 2.6765e-04,  ..., 3.2577e-22, 2.8749e-22,
         1.9759e-22],
        [5.8924e-01, 1.1603e-01, 5.4808e-02,  ..., 2.1658e-20, 1.9113e-20,
         9.5158e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.5215, 0.8006, 0.9033,  ..., 1.0000, 1.0000, 1.0000],
        [0.5563, 0.9896, 0.9925,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.6782, 0.9985, 0.9991,  ..., 1.0000, 1.0000, 1.0000],
        [0.9041, 0.9994, 0.9996,  ..., 1.0000, 1.0000, 1.0000],
        [0.5892, 0.7053, 0.7601,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([55])


tensor([ 0,  0,  0,  1,  1,  2,  3,  4,  4,  4,  5,  5,  6,  6,  6,  7,  8,  8,
         8,  9,  9, 10, 10, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 14, 14,
        14, 15, 15, 15, 16, 16, 16, 17, 17, 18, 19, 19, 19, 19, 19, 19, 19, 19,
        19], device='cuda:0')


carryover_candidates
torch.Size([55, 69])


tensor([[    1, 32010,  1724,  ...,  9939, 14378,   373],
        [    1, 32010,  1724,  ...,  9939, 14378,   373],
        [    1, 32010,  1724,  ...,  9939, 14378,   373],
        ...,
        [    1, 32010,  1724,  ..., 14378,   746, 13858],
        [    1, 32010,  1724,  ..., 14378,   746, 13858],
        [    1, 32010,  1724,  ..., 14378,   746, 13858]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([55])


tensor([ -9.0463,  -9.0463,  -9.0463, -11.0463, -11.0463, -12.1225, -12.1225,
         -9.9307,  -9.9307,  -9.9307, -12.3057, -12.3057, -12.1623, -12.1623,
        -12.1623, -12.6623, -11.1225, -11.1225, -11.1225, -11.9975, -11.9975,
        -12.7475, -12.7475, -13.3755, -13.6255, -13.6255, -13.6255, -13.6255,
        -13.6255, -13.6255, -13.6255, -13.6255, -13.6255, -13.4230, -13.9230,
        -13.9230, -13.9230, -11.5063, -11.5063, -11.5063, -12.8813, -12.8813,
        -12.8813, -13.6313, -13.6313, -14.4824, -14.7324, -14.7324, -14.7324,
        -14.7324, -14.7324, -14.7324, -14.7324, -14.7324, -14.7324],
       device='cuda:0')


new_candidate_toks
torch.Size([55, 1])


tensor([[  738],
        [  263],
        [  916],
        [13858],
        [17005],
        [ 1691],
        [  342],
        [  738],
        [  263],
        [  916],
        [13858],
        [17005],
        [  373],
        [  297],
        [  746],
        [  373],
        [  916],
        [15754],
        [ 6432],
        [15754],
        [ 1422],
        [ 3814],
        [ 6432],
        [  515],
        [  278],
        [11563],
        [  916],
        [14378],
        [  263],
        [15754],
        [19223],
        [  599],
        [ 1090],
        [  616],
        [29892],
        [  297],
        [  470],
        [15754],
        [  916],
        [ 6432],
        [15754],
        [ 1422],
        [ 6432],
        [ 3814],
        [ 6432],
        [  515],
        [  278],
        [11563],
        [  263],
        [14378],
        [15754],
        [ 4158],
        [  916],
        [  967],
        [ 1749]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([55])


tensor([-6.5113e-01, -1.2761e+00, -2.2761e+00, -5.8638e-01, -8.3638e-01,
        -2.3842e-07, -2.6226e-06, -3.9348e-01, -1.6435e+00, -2.6435e+00,
        -5.8801e-01, -8.3801e-01, -2.6270e-01, -2.6377e+00, -2.6377e+00,
        -1.6553e-02, -4.7573e-01, -1.4757e+00, -1.9757e+00, -2.3650e-01,
        -1.9865e+00, -3.8810e-01, -1.1381e+00, -7.9710e-02, -4.6288e-01,
        -2.2129e+00, -3.2129e+00, -3.2129e+00, -3.7129e+00, -3.8379e+00,
        -3.9629e+00, -4.3379e+00, -4.3379e+00, -7.1526e-07, -4.3454e-01,
        -1.6845e+00, -2.1845e+00, -8.2979e-01, -9.5479e-01, -1.8298e+00,
        -4.3981e-01, -1.4398e+00, -2.6898e+00, -3.8836e-01, -1.1384e+00,
        -1.0084e-01, -5.2892e-01, -2.1539e+00, -2.9039e+00, -3.0289e+00,
        -3.0289e+00, -4.1539e+00, -4.2789e+00, -4.4039e+00, -4.6539e+00],
       device='cuda:0')


new_candidates
torch.Size([55, 70])


tensor([[    1, 32010,  1724,  ..., 14378,   373,   738],
        [    1, 32010,  1724,  ..., 14378,   373,   263],
        [    1, 32010,  1724,  ..., 14378,   373,   916],
        ...,
        [    1, 32010,  1724,  ...,   746, 13858,   916],
        [    1, 32010,  1724,  ...,   746, 13858,   967],
        [    1, 32010,  1724,  ...,   746, 13858,  1749]], device='cuda:0')


new_candidate_logprobs
torch.Size([55])


tensor([ -9.6975, -10.3225, -11.3225, -11.6327, -11.8827, -12.1225, -12.1225,
        -10.3242, -11.5742, -12.5742, -12.8937, -13.1437, -12.4250, -14.8000,
        -14.8000, -12.6789, -11.5982, -12.5982, -13.0982, -12.2340, -13.9840,
        -13.1356, -13.8856, -13.4552, -14.0884, -15.8384, -16.8384, -16.8384,
        -17.3384, -17.4634, -17.5884, -17.9634, -17.9634, -13.4230, -14.3576,
        -15.6076, -16.1076, -12.3361, -12.4611, -13.3361, -13.3211, -14.3211,
        -15.5711, -14.0197, -14.7697, -14.5833, -15.2614, -16.8864, -17.6364,
        -17.7614, -17.7614, -18.8864, -19.0114, -19.1364, -19.3864],
       device='cuda:0')

infer end: GPU memory used: 21301 MB.
event: level
id: 58
data: [{"content": "any", "parent": 0, "prob": -9.697453498840332}, {"content": "a", "parent": 0, "prob": -10.322453498840332}, {"content": "other", "parent": 0, "prob": -11.322453498840332}, {"content": "considering", "parent": 1, "prob": -11.632707595825195}, {"content": "measured", "parent": 1, "prob": -11.882707595825195}, {"content": "ets", "parent": 2, "prob": -12.122471809387207}, {"content": "est", "parent": 3, "prob": -12.122474670410156}, {"content": "any", "parent": 4, "prob": -10.324176788330078}, {"content": "a", "parent": 4, "prob": -11.574176788330078}, {"content": "other", "parent": 4, "prob": -12.574176788330078}, {"content": "considering", "parent": 5, "prob": -12.893712997436523}, {"content": "measured", "parent": 5, "prob": -13.143712997436523}, {"content": "on", "parent": 6, "prob": -12.42500114440918}, {"content": "in", "parent": 6, "prob": -14.80000114440918}, {"content": "when", "parent": 6, "prob": -14.8

array([[-0.96875   , -0.00860596,  3.703125  , ..., -1.8828125 ,
         0.23535156, -0.390625  ],
       [-1.3203125 ,  0.13476562,  4.28125   , ..., -0.09570312,
         0.43164062, -0.984375  ],
       [-0.98828125,  0.1640625 ,  2.703125  , ..., -1.3515625 ,
         0.83203125, -2.671875  ],
       ...,
       [-0.01165771,  0.4609375 ,  0.578125  , ..., -0.8125    ,
         0.984375  , -0.08544922],
       [-0.3125    ,  2.78125   ,  0.51171875, ..., -1.59375   ,
         0.20703125, -0.49609375],
       [-0.6953125 ,  0.69921875,  1.234375  , ..., -0.765625  ,
         2.21875   ,  0.40429688]], dtype=float32)


k_mean_space
(20, 2)


array([[75.87932 , 41.617405],
       [72.37803 , 47.78649 ],
       [78.857994, 46.826942],
       [52.618813, 74.815155],
       [56.311615, 89.7061  ],
       [59.790855, 82.487785],
       [92.60613 , 91.33107 ],
       [78.007576, 42.860153],
       [72.68726 , 46.837223],
       [79.613556, 47.659637],
       [52.063007, 74.55534 ],
       [57.480156, 90.18555 ],
       [75.91296 , 58.550953],
       [67.35506 , 78.68883 ],
       [57.24728 , 83.07555 ],
       [75.80606 , 61.31499 ],
       [72.289505, 39.873283],
       [57.776947, 77.38706 ],
       [83.793686, 93.542206],
       [58.36542 , 75.7596  ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-129.20583248, -114.63946915])


closest
(2,)


array([10, 16])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 2.8906, -1.0000, -9.8750,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.7969, -1.6094, -8.2500,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.2344,  2.8281, -3.8125,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.6406, -2.5625, -7.9375,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.6797, -5.4062,  0.3477,  ...,  0.0000,  0.0000,  0.0000],
        [-0.4004, -1.9219, -4.8750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[5.8256e-01, 2.7518e-01, 1.2999e-01,  ..., 1.2732e-22, 9.9158e-23,
         1.7231e-23],
        [8.7767e-01, 8.1636e-02, 1.6075e-02,  ..., 4.6015e-22, 2.4630e-22,
         1.6928e-22],
        [6.7795e-01, 3.2024e-01, 1.0193e-03,  ..., 7.3769e-24, 3.9486e-24,
         1.6460e-24],
        ...,
        [4.5342e-01, 3.5312e-01, 1.2991e-01,  ..., 2.3772e-22, 1.8514e-22,
         4.1310e-23],
        [1.0000e+00, 1.6374e-07, 9.9312e-08,  ..., 1.3256e-22, 3.7979e-23,
         1.0881e-23],
        [9.2190e-01, 1.6885e-02, 1.6885e-02,  ..., 9.0301e-22, 9.0301e-22,
         2.9316e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.5826, 0.8577, 0.9877,  ..., 1.0000, 1.0000, 1.0000],
        [0.8777, 0.9593, 0.9754,  ..., 1.0000, 1.0000, 1.0000],
        [0.6779, 0.9982, 0.9992,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.4534, 0.8065, 0.9365,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9219, 0.9388, 0.9557,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([46])


tensor([ 0,  0,  0,  1,  1,  2,  2,  3,  3,  3,  3,  3,  4,  5,  5,  5,  6,  7,
         7,  7,  8,  8,  9,  9, 10, 10, 10, 10, 10, 11, 12, 12, 13, 13, 14, 14,
        15, 15, 15, 16, 16, 17, 17, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([46, 70])


tensor([[    1, 32010,  1724,  ..., 14378,   373,   738],
        [    1, 32010,  1724,  ..., 14378,   373,   738],
        [    1, 32010,  1724,  ..., 14378,   373,   738],
        ...,
        [    1, 32010,  1724,  ...,   373,   738, 15754],
        [    1, 32010,  1724,  ...,   373,   738,  6432],
        [    1, 32010,  1724,  ...,   373,   263, 15754]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([46])


tensor([ -9.6975,  -9.6975,  -9.6975, -10.3225, -10.3225, -11.3225, -11.3225,
        -11.6327, -11.6327, -11.6327, -11.6327, -11.6327, -11.8827, -12.1225,
        -12.1225, -12.1225, -12.1225, -10.3242, -10.3242, -10.3242, -11.5742,
        -11.5742, -12.5742, -12.5742, -12.8937, -12.8937, -12.8937, -12.8937,
        -12.8937, -13.1437, -12.4250, -12.4250, -14.8000, -14.8000, -14.8000,
        -14.8000, -12.6789, -12.6789, -12.6789, -11.5982, -11.5982, -12.5982,
        -12.5982, -12.5982, -13.0982, -12.2340], device='cuda:0')


new_candidate_toks
torch.Size([46, 1])


tensor([[  916],
        [15754],
        [ 6432],
        [15754],
        [ 1422],
        [ 3814],
        [ 6432],
        [  278],
        [11563],
        [  916],
        [14378],
        [15754],
        [  515],
        [29892],
        [  297],
        [  470],
        [  616],
        [15754],
        [  916],
        [ 6432],
        [15754],
        [ 1422],
        [ 3814],
        [ 6432],
        [  278],
        [11563],
        [15754],
        [  263],
        [14378],
        [  515],
        [  738],
        [  263],
        [ 4958],
        [  278],
        [13858],
        [17005],
        [  738],
        [  263],
        [11563],
        [15754],
        [ 6432],
        [  297],
        [  916],
        [29892],
        [  342],
        [  916]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([46])


tensor([-5.4032e-01, -1.2903e+00, -2.0403e+00, -1.3048e-01, -2.5055e+00,
        -3.8869e-01, -1.1387e+00, -2.7439e-01, -2.6494e+00, -3.3994e+00,
        -3.7744e+00, -4.1494e+00, -7.1061e-02, -4.1342e-01, -1.6634e+00,
        -2.1634e+00, -1.1921e-06, -6.7865e-01, -1.0537e+00, -2.0537e+00,
        -2.2467e-01, -1.9747e+00, -4.3092e-01, -1.0559e+00, -2.7174e-01,
        -2.5217e+00, -3.6467e+00, -3.7717e+00, -3.7717e+00, -5.5615e-02,
        -1.5438e-01, -2.0294e+00, -4.4848e-01, -1.0735e+00, -1.9322e-01,
        -1.9432e+00, -5.1989e-01, -1.3949e+00, -2.1449e+00, -3.9274e-01,
        -1.1427e+00, -7.9093e-01, -1.0409e+00, -2.0409e+00, -3.5763e-07,
        -8.1314e-02], device='cuda:0')


new_candidates
torch.Size([46, 71])


tensor([[    1, 32010,  1724,  ...,   373,   738,   916],
        [    1, 32010,  1724,  ...,   373,   738, 15754],
        [    1, 32010,  1724,  ...,   373,   738,  6432],
        ...,
        [    1, 32010,  1724,  ...,   738, 15754, 29892],
        [    1, 32010,  1724,  ...,   738,  6432,   342],
        [    1, 32010,  1724,  ...,   263, 15754,   916]], device='cuda:0')


new_candidate_logprobs
torch.Size([46])


tensor([-10.2378, -10.9878, -11.7378, -10.4529, -12.8279, -11.7111, -12.4611,
        -11.9071, -14.2821, -15.0321, -15.4071, -15.7821, -11.9538, -12.5359,
        -13.7859, -14.2859, -12.1225, -11.0028, -11.3778, -12.3778, -11.7988,
        -13.5488, -13.0051, -13.6301, -13.1655, -15.4155, -16.5405, -16.6655,
        -16.6655, -13.1993, -12.5794, -14.4544, -15.2485, -15.8735, -14.9932,
        -16.7432, -13.1987, -14.0737, -14.8237, -11.9910, -12.7410, -13.3892,
        -13.6392, -14.6392, -13.0982, -12.3153], device='cuda:0')

infer end: GPU memory used: 21473 MB.
event: level
id: 59
data: [{"content": "other", "parent": 0, "prob": -10.237777709960938}, {"content": "planet", "parent": 0, "prob": -10.987777709960938}, {"content": "cel", "parent": 0, "prob": -11.737777709960938}, {"content": "planet", "parent": 1, "prob": -10.452933311462402}, {"content": "different", "parent": 1, "prob": -12.827933311462402}, {"content": "plan", "parent": 2, "prob": -11.711139678955078}, {"content": "cel", "parent": 2, "prob": -12.461139678955078}, {"content": "the", "parent": 3, "prob": -11.90709400177002}, {"content": "Earth", "parent": 3, "prob": -14.282094955444336}, {"content": "other", "parent": 3, "prob": -15.03209400177002}, {"content": "mountain", "parent": 3, "prob": -15.40709400177002}, {"content": "planet", "parent": 3, "prob": -15.782094955444336}, {"content": "from", "parent": 4, "prob": -11.953768730163574}, {"content": ",", "parent": 5, "prob": -12.53589153289795}, {"content": "in", "parent": 5, "prob": -13.78

array([[-1.71875   , -0.43164062,  2.8125    , ..., -1.8828125 ,
        -0.16992188, -0.98046875],
       [-0.16015625,  0.546875  ,  0.5234375 , ..., -0.84375   ,
         0.97265625, -0.03735352],
       [-0.4453125 ,  2.890625  ,  0.58984375, ..., -1.671875  ,
         0.24804688, -0.48632812],
       ...,
       [-0.2421875 ,  0.421875  ,  0.85546875, ..., -0.6171875 ,
         1.4140625 ,  0.04223633],
       [-1.5625    , -0.38671875,  2.90625   , ..., -1.8984375 ,
         0.1953125 , -0.88671875],
       [-0.37304688,  3.078125  ,  0.45703125, ..., -1.5546875 ,
         0.29296875, -0.41015625]], dtype=float32)


k_mean_space
(20, 2)


array([[77.06771 , 52.844227],
       [72.79372 , 55.99341 ],
       [91.874084, 70.103035],
       [70.62335 , 57.212223],
       [76.04249 , 55.220047],
       [88.42264 , 73.35366 ],
       [91.705635, 70.153046],
       [55.543232, 77.94179 ],
       [60.15891 , 78.26783 ],
       [67.14698 , 54.55634 ],
       [59.290092, 81.022736],
       [56.7386  , 61.578976],
       [62.805035, 84.616806],
       [80.923805, 76.204216],
       [84.636986, 70.69305 ],
       [72.63646 , 58.22167 ],
       [75.85296 , 56.67506 ],
       [73.73068 , 57.091297],
       [79.04699 , 54.88538 ],
       [91.68856 , 69.968094]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -69.33214664, -182.93720341])


closest
(2,)


array([7, 0])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 2.8125,  1.1094, -6.3125,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.0781, -1.6641, -7.1250,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.8516, -5.0938,  0.8633,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.3496, -2.3281, -7.8125,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.5312, -0.2227, -6.7500,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.6484, -5.4688,  0.4473,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[7.0191e-01, 2.9260e-01, 3.2505e-03,  ..., 1.1947e-22, 9.3046e-23,
         4.9804e-23],
        [4.7774e-01, 2.8977e-01, 1.7575e-01,  ..., 4.1296e-22, 3.0213e-22,
         7.1762e-23],
        [1.0000e+00, 1.1254e-07, 8.7642e-08,  ..., 1.0324e-22, 2.9578e-23,
         5.8243e-24],
        ...,
        [3.5794e-01, 3.5794e-01, 2.1710e-01,  ..., 7.9010e-22, 3.5060e-22,
         5.3767e-23],
        [6.7703e-01, 3.1981e-01, 1.6782e-03,  ..., 4.8039e-23, 4.2394e-23,
         2.2692e-23],
        [1.0000e+00, 1.3574e-07, 1.1979e-07,  ..., 2.4766e-22, 7.0955e-23,
         2.0329e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.7019, 0.9945, 0.9978,  ..., 1.0000, 1.0000, 1.0000],
        [0.4777, 0.7675, 0.9433,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.3579, 0.7159, 0.9330,  ..., 1.0000, 1.0000, 1.0000],
        [0.6770, 0.9968, 0.9985,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([48])


tensor([ 0,  0,  1,  1,  1,  2,  3,  4,  4,  5,  6,  7,  7,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  8,  9,  9, 10, 10, 10, 11, 12, 12,
        13, 13, 14, 15, 15, 16, 17, 17, 17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([48, 71])


tensor([[    1, 32010,  1724,  ...,   373,   738,   916],
        [    1, 32010,  1724,  ...,   373,   738,   916],
        [    1, 32010,  1724,  ...,   373,   738, 15754],
        ...,
        [    1, 32010,  1724,  ...,   373,   738,   916],
        [    1, 32010,  1724,  ...,   373,   738,   916],
        [    1, 32010,  1724,  ...,   373,   738,  6432]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([48])


tensor([-10.2378, -10.2378, -10.9878, -10.9878, -10.9878, -11.7378, -10.4529,
        -12.8279, -12.8279, -11.7111, -12.4611, -11.9071, -11.9071, -11.9071,
        -11.9071, -11.9071, -11.9071, -11.9071, -11.9071, -11.9071, -11.9071,
        -11.9071, -11.9071, -11.9071, -11.9071, -11.9071, -11.9071, -14.2821,
        -15.0321, -15.0321, -15.4071, -15.4071, -15.4071, -15.7821, -11.9538,
        -11.9538, -12.5359, -12.5359, -13.7859, -14.2859, -14.2859, -12.1225,
        -11.0028, -11.0028, -11.0028, -11.3778, -11.3778, -12.3778],
       device='cuda:0')


new_candidate_toks
torch.Size([48, 1])


tensor([[15754],
        [ 6432],
        [  297],
        [  916],
        [29892],
        [  342],
        [  916],
        [15754],
        [ 6432],
        [ 1691],
        [  342],
        [ 2533],
        [ 3171],
        [ 9939],
        [24235],
        [11563],
        [ 5418],
        [ 2967],
        [ 7977],
        [ 2159],
        [15754],
        [11855],
        [ 1090],
        [14378],
        [19604],
        [ 2246],
        [ 4818],
        [29915],
        [ 3814],
        [ 6432],
        [20238],
        [ 4158],
        [ 3171],
        [  653],
        [  278],
        [  967],
        [  769],
        [ 6167],
        [ 1749],
        [ 2730],
        [ 6432],
        [17873],
        [  297],
        [  916],
        [29892],
        [15754],
        [ 6432],
        [  342]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([48])


tensor([-3.5395e-01, -1.2289e+00, -7.3868e-01, -1.2387e+00, -1.7387e+00,
        -3.5763e-07, -7.0962e-02, -2.0491e-01, -1.7049e+00, -1.0729e-06,
        -1.4305e-06, -1.2051e+00, -2.4551e+00, -2.4551e+00, -2.4551e+00,
        -2.5801e+00, -2.9551e+00, -3.0801e+00, -3.2051e+00, -3.4551e+00,
        -3.9551e+00, -4.0801e+00, -4.2051e+00, -4.2051e+00, -4.2051e+00,
        -4.4551e+00, -4.8301e+00, -1.2713e-02, -4.4162e-01, -1.1916e+00,
        -5.0888e-01, -1.3839e+00, -2.8839e+00, -6.0267e-03, -4.7333e-01,
        -1.0983e+00, -6.9832e-01, -8.2332e-01, -3.1813e-03, -4.1791e-01,
        -1.4179e+00, -9.2418e-04, -1.0274e+00, -1.0274e+00, -1.5274e+00,
        -3.9004e-01, -1.1400e+00, -4.7684e-07], device='cuda:0')


new_candidates
torch.Size([48, 72])


tensor([[    1, 32010,  1724,  ...,   738,   916, 15754],
        [    1, 32010,  1724,  ...,   738,   916,  6432],
        [    1, 32010,  1724,  ...,   738, 15754,   297],
        ...,
        [    1, 32010,  1724,  ...,   738,   916, 15754],
        [    1, 32010,  1724,  ...,   738,   916,  6432],
        [    1, 32010,  1724,  ...,   738,  6432,   342]], device='cuda:0')


new_candidate_logprobs
torch.Size([48])


tensor([-10.5917, -11.4667, -11.7265, -12.2265, -12.7265, -11.7378, -10.5239,
        -13.0328, -14.5328, -11.7111, -12.4611, -13.1122, -14.3622, -14.3622,
        -14.3622, -14.4872, -14.8622, -14.9872, -15.1122, -15.3622, -15.8622,
        -15.9872, -16.1122, -16.1122, -16.1122, -16.3622, -16.7372, -14.2948,
        -15.4737, -16.2237, -15.9160, -16.7910, -18.2910, -15.7881, -12.4271,
        -13.0521, -13.2342, -13.3592, -13.7891, -14.7038, -15.7038, -12.1234,
        -12.0302, -12.0302, -12.5302, -11.7679, -12.5179, -12.3778],
       device='cuda:0')

infer end: GPU memory used: 21647 MB.
event: level
id: 60
data: [{"content": "planet", "parent": 0, "prob": -10.59172534942627}, {"content": "cel", "parent": 0, "prob": -11.46672534942627}, {"content": "in", "parent": 1, "prob": -11.726457595825195}, {"content": "other", "parent": 1, "prob": -12.226457595825195}, {"content": ",", "parent": 1, "prob": -12.726457595825195}, {"content": "est", "parent": 2, "prob": -11.737777709960938}, {"content": "other", "parent": 3, "prob": -10.523895263671875}, {"content": "planet", "parent": 4, "prob": -13.032840728759766}, {"content": "cel", "parent": 4, "prob": -14.532840728759766}, {"content": "ets", "parent": 5, "prob": -11.711140632629395}, {"content": "est", "parent": 6, "prob": -12.461141586303711}, {"content": "sum", "parent": 7, "prob": -13.112217903137207}, {"content": "height", "parent": 7, "prob": -14.362217903137207}, {"content": "highest", "parent": 7, "prob": -14.362217903137207}, {"content": "diameter", "parent": 7, "prob": -14.362217

array([[-0.59375   ,  0.55859375,  0.15820312, ..., -0.24707031,
        -0.40429688,  0.01745605],
       [-0.421875  ,  2.84375   ,  0.6015625 , ..., -1.234375  ,
        -0.02160645, -0.53125   ],
       [ 0.3046875 ,  0.97265625,  0.53515625, ...,  0.69140625,
         2.09375   , -0.07421875],
       ...,
       [-0.01373291, -1.484375  ,  1.5078125 , ...,  0.7109375 ,
        -0.9921875 ,  0.6328125 ],
       [-1.84375   , -1.1796875 ,  2.078125  , ..., -1.09375   ,
         0.13183594,  1.6171875 ],
       [-0.80859375, -1.2421875 ,  2.375     , ..., -0.859375  ,
         2.71875   ,  1.078125  ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 91.37025 ,  54.445873],
       [ 51.53405 ,  87.210266],
       [ 98.874725,  74.5799  ],
       [ 88.78372 ,  62.750988],
       [100.437035,  69.10931 ],
       [ 52.62346 ,  89.52557 ],
       [ 88.11512 ,  62.24352 ],
       [ 89.970276,  51.03741 ],
       [ 52.6124  ,  87.987175],
       [ 91.67438 ,  52.563526],
       [ 52.03531 ,  88.77374 ],
       [101.0376  ,  75.831955],
       [ 97.55759 ,  59.479164],
       [100.41086 ,  71.86343 ],
       [ 92.21925 ,  52.896603],
       [100.54913 ,  70.36215 ],
       [ 98.725464,  63.150787],
       [ 98.85391 ,  71.626114],
       [ 97.86475 ,  65.69737 ],
       [ 88.83502 ,  47.766636]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -50.19848537, -213.54893684])


closest
(2,)


array([ 1, 19])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.6875, -0.2480, -4.1250,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.6562, -4.3438,  1.3438,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.7188,  2.6875, -3.7656,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 6.3750, -2.2969, -7.5000,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.3594, -2.0938,  0.7617,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.7500,  1.7656,  4.1562,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[6.0320e-01, 3.6586e-01, 1.6075e-02,  ..., 2.3138e-22, 9.6452e-23,
         9.0608e-23],
        [1.0000e+00, 1.2752e-07, 9.9312e-08,  ..., 9.1108e-23, 1.2330e-23,
         5.1399e-24],
        [9.5252e-01, 4.7423e-02, 4.3244e-05,  ..., 8.0719e-24, 5.5478e-24,
         4.3206e-24],
        ...,
        [7.2851e-01, 9.8593e-02, 5.2773e-02,  ..., 7.1357e-22, 2.0444e-22,
         1.8042e-22],
        [8.8860e-01, 3.9042e-02, 3.0406e-02,  ..., 3.0379e-21, 9.8627e-22,
         5.2791e-22],
        [9.6675e-01, 1.5626e-02, 8.3640e-03,  ..., 1.7867e-20, 1.2280e-20,
         7.9286e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.6032, 0.9691, 0.9851,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9525, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.7285, 0.8271, 0.8799,  ..., 1.0000, 1.0000, 1.0000],
        [0.8886, 0.9276, 0.9580,  ..., 1.0000, 1.0000, 1.0000],
        [0.9668, 0.9824, 0.9907,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([38])


tensor([ 0,  0,  1,  2,  3,  4,  4,  5,  6,  7,  7,  8,  9,  9,  9, 10, 11, 12,
        12, 12, 13, 13, 13, 14, 14, 14, 14, 14, 14, 15, 16, 17, 17, 17, 17, 18,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([38, 72])


tensor([[    1, 32010,  1724,  ...,   738,   916, 15754],
        [    1, 32010,  1724,  ...,   738,   916, 15754],
        [    1, 32010,  1724,  ...,   738,   916,  6432],
        ...,
        [    1, 32010,  1724,  ..., 13858,   278,  7977],
        [    1, 32010,  1724,  ..., 13858,   278,  7977],
        [    1, 32010,  1724,  ..., 13858,   278,  2159]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([38])


tensor([-10.5917, -10.5917, -11.4667, -11.7265, -12.2265, -12.7265, -12.7265,
        -11.7378, -10.5239, -13.0328, -13.0328, -14.5328, -11.7111, -11.7111,
        -11.7111, -12.4611, -13.1122, -14.3622, -14.3622, -14.3622, -14.3622,
        -14.3622, -14.3622, -14.3622, -14.3622, -14.3622, -14.3622, -14.3622,
        -14.3622, -14.4872, -14.8622, -14.9872, -14.9872, -14.9872, -14.9872,
        -15.1122, -15.1122, -15.3622], device='cuda:0')


new_candidate_toks
torch.Size([38, 1])


tensor([[  297],
        [29892],
        [  342],
        [ 1749],
        [ 1135],
        [  393],
        [  769],
        [  616],
        [ 1135],
        [29892],
        [  297],
        [  342],
        [  297],
        [29892],
        [  470],
        [  616],
        [ 2415],
        [  310],
        [  515],
        [ 2038],
        [14378],
        [19224],
        [ 1298],
        [  310],
        [  470],
        [29892],
        [  313],
        [  515],
        [  322],
        [29915],
        [  515],
        [29899],
        [  470],
        [  304],
        [  322],
        [  470],
        [  310],
        [  310]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([38])


tensor([-5.0550e-01, -1.0055e+00, -4.7684e-07, -4.8648e-02, -5.9605e-07,
        -7.6114e-01, -7.6114e-01, -3.5763e-07, -1.0729e-06, -2.4650e-01,
        -1.9965e+00, -1.1921e-07, -8.7074e-01, -1.1207e+00, -1.6207e+00,
        -2.1458e-06, -1.0022e-01, -4.7223e-01, -1.3472e+00, -2.4722e+00,
        -3.8216e-01, -1.8822e+00, -2.5072e+00, -1.0464e+00, -1.2964e+00,
        -2.0464e+00, -2.7964e+00, -2.7964e+00, -3.5464e+00, -9.5860e-04,
        -2.7895e-02, -3.1676e-01, -2.3168e+00, -2.9418e+00, -3.0668e+00,
        -1.1811e-01, -3.2431e+00, -3.3814e-02], device='cuda:0')


new_candidates
torch.Size([38, 73])


tensor([[    1, 32010,  1724,  ...,   916, 15754,   297],
        [    1, 32010,  1724,  ...,   916, 15754, 29892],
        [    1, 32010,  1724,  ...,   916,  6432,   342],
        ...,
        [    1, 32010,  1724,  ...,   278,  7977,   470],
        [    1, 32010,  1724,  ...,   278,  7977,   310],
        [    1, 32010,  1724,  ...,   278,  2159,   310]], device='cuda:0')


new_candidate_logprobs
torch.Size([38])


tensor([-11.0972, -11.5972, -11.4667, -11.7751, -12.2265, -13.4876, -13.4876,
        -11.7378, -10.5239, -13.2793, -15.0293, -14.5328, -12.5819, -12.8319,
        -13.3319, -12.4611, -13.2124, -14.8344, -15.7094, -16.8344, -14.7444,
        -16.2444, -16.8694, -15.4086, -15.6586, -16.4086, -17.1586, -17.1586,
        -17.9086, -14.4882, -14.8901, -15.3040, -17.3040, -17.9290, -18.0540,
        -15.2303, -18.3553, -15.3960], device='cuda:0')

infer end: GPU memory used: 21825 MB.
event: level
id: 61
data: [{"content": "in", "parent": 0, "prob": -11.097225189208984}, {"content": ",", "parent": 0, "prob": -11.597225189208984}, {"content": "est", "parent": 1, "prob": -11.466726303100586}, {"content": "our", "parent": 2, "prob": -11.775105476379395}, {"content": "than", "parent": 3, "prob": -12.226458549499512}, {"content": "that", "parent": 4, "prob": -13.487594604492188}, {"content": "then", "parent": 4, "prob": -13.487594604492188}, {"content": "ial", "parent": 5, "prob": -11.737777709960938}, {"content": "than", "parent": 6, "prob": -10.523896217346191}, {"content": ",", "parent": 7, "prob": -13.27934455871582}, {"content": "in", "parent": 7, "prob": -15.02934455871582}, {"content": "est", "parent": 8, "prob": -14.532840728759766}, {"content": "in", "parent": 9, "prob": -12.581881523132324}, {"content": ",", "parent": 9, "prob": -12.831881523132324}, {"content": "or", "parent": 9, "prob": -13.331880569458008}, {"content": "

array([[ 0.69921875,  0.828125  ,  0.37695312, ...,  0.7421875 ,
         1.6796875 , -0.48046875],
       [-0.421875  ,  0.29882812, -0.47070312, ...,  1.640625  ,
         0.26953125, -1.1640625 ],
       [ 2.015625  ,  0.16601562,  0.83203125, ...,  1.5234375 ,
        -0.11181641,  0.32617188],
       ...,
       [-1.71875   , -1.7109375 ,  2.5625    , ...,  0.6484375 ,
         0.30859375, -1.5078125 ],
       [-0.546875  , -1.5859375 ,  2.15625   , ...,  1.109375  ,
        -0.5234375 ,  1.1484375 ],
       [-1.8671875 , -1.75      ,  0.30273438, ..., -0.03564453,
        -0.7578125 ,  0.03344727]], dtype=float32)


k_mean_space
(20, 2)


array([[58.775852, 86.97051 ],
       [79.02252 , 36.797264],
       [76.36439 , 98.00191 ],
       [69.01895 , 92.95245 ],
       [63.309265, 81.99037 ],
       [76.42409 , 72.62202 ],
       [79.35873 , 42.1187  ],
       [61.033707, 84.11539 ],
       [64.78049 , 83.19201 ],
       [78.61862 , 36.688015],
       [56.344555, 82.75894 ],
       [76.31333 , 97.89849 ],
       [57.39292 , 84.58642 ],
       [79.34668 , 37.11544 ],
       [60.913532, 74.82506 ],
       [57.60797 , 80.8982  ],
       [83.934784, 77.07504 ],
       [75.31443 , 67.90724 ],
       [73.90525 , 84.93988 ],
       [73.791145, 86.055244]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-169.30817604,  -92.73052597])


closest
(2,)


array([10,  9])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.2500,  2.8281, -2.7031,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.7500, -0.8008,  0.9766,  ...,  0.0000,  0.0000,  0.0000],
        [-3.3125, -0.3633, -3.2500,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 4.7812,  5.1250, -4.2500,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.5664,  0.5117, -5.3750,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.6719, -0.5586, -7.6562,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.8901e-01, 1.0987e-02, 3.2526e-06,  ..., 4.2143e-24, 3.2821e-24,
         2.1191e-24],
        [5.1448e-01, 4.0067e-01, 3.2889e-02,  ..., 8.9324e-21, 6.9565e-21,
         8.3084e-22],
        [1.0000e+00, 6.8256e-08, 5.3158e-08,  ..., 8.5846e-29, 6.2186e-30,
         1.0806e-30],
        ...,
        [4.7745e-01, 3.2815e-01, 1.0653e-01,  ..., 5.6973e-21, 5.0278e-21,
         1.4405e-21],
        [6.0235e-01, 2.8453e-01, 1.0467e-01,  ..., 6.3431e-21, 6.3431e-21,
         1.8173e-21],
        [7.2896e-01, 2.0885e-01, 5.9836e-02,  ..., 2.8240e-21, 1.7128e-21,
         2.9764e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9890, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.5145, 0.9152, 0.9480,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.4774, 0.8056, 0.9121,  ..., 1.0000, 1.0000, 1.0000],
        [0.6023, 0.8869, 0.9916,  ..., 1.0000, 1.0000, 1.0000],
        [0.7290, 0.9378, 0.9976,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([46])


tensor([ 0,  1,  1,  2,  3,  4,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  7,
         8,  9,  9,  9,  9, 10, 11, 12, 13, 13, 13, 14, 14, 14, 15, 16, 16, 16,
        16, 16, 17, 17, 17, 18, 18, 18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([46, 73])


tensor([[    1, 32010,  1724,  ...,   916, 15754,   297],
        [    1, 32010,  1724,  ...,   916, 15754, 29892],
        [    1, 32010,  1724,  ...,   916, 15754, 29892],
        ...,
        [    1, 32010,  1724,  ...,   278,  3171,   515],
        [    1, 32010,  1724,  ...,   278,  3171,  2038],
        [    1, 32010,  1724,  ...,   278,  3171,  2038]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([46])


tensor([-11.0972, -11.5972, -11.5972, -11.4667, -11.7751, -12.2265, -13.4876,
        -13.4876, -13.4876, -13.4876, -13.4876, -13.4876, -13.4876, -13.4876,
        -13.4876, -13.4876, -13.4876, -11.7378, -10.5239, -13.2793, -13.2793,
        -13.2793, -13.2793, -15.0293, -14.5328, -12.5819, -12.8319, -12.8319,
        -12.8319, -13.3319, -13.3319, -13.3319, -12.4611, -13.2124, -13.2124,
        -13.2124, -13.2124, -13.2124, -14.8344, -14.8344, -14.8344, -15.7094,
        -15.7094, -15.7094, -16.8344, -16.8344], device='cuda:0')


new_candidate_toks
torch.Size([46, 1])


tensor([[ 1749],
        [  393],
        [  769],
        [  616],
        [21635],
        [11563],
        [  723],
        [ 3620],
        [ 3611],
        [21578],
        [  722],
        [ 7111],
        [  393],
        [ 6167],
        [ 8040],
        [  366],
        [  372],
        [ 3573],
        [11563],
        [  769],
        [  393],
        [ 6167],
        [ 8040],
        [ 1749],
        [  616],
        [ 1749],
        [  769],
        [  393],
        [ 6167],
        [ 2730],
        [ 6432],
        [  297],
        [17873],
        [ 3171],
        [ 2038],
        [ 6198],
        [29915],
        [  310],
        [  263],
        [  967],
        [  278],
        [  278],
        [  967],
        [ 2967],
        [  278],
        [ 7205]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([46])


tensor([-1.1052e-02, -6.6461e-01, -9.1461e-01, -1.1921e-07, -6.2296e-02,
        -3.9339e-06, -6.5482e-01, -1.6548e+00, -2.4048e+00, -3.0298e+00,
        -3.1548e+00, -3.4048e+00, -1.2322e+00, -1.2322e+00, -1.7322e+00,
        -2.3572e+00, -2.6072e+00, -8.0707e-04, -5.1260e-06, -3.8410e-01,
        -1.7591e+00, -3.1341e+00, -3.5091e+00, -4.1186e-03,  0.0000e+00,
        -4.1034e-03, -6.9457e-01, -9.4457e-01, -2.8196e+00, -3.2037e-01,
        -2.0704e+00, -2.4454e+00, -5.1464e-04, -6.3166e-01, -1.7567e+00,
        -2.2567e+00, -2.5067e+00, -3.3817e+00, -7.3930e-01, -1.1143e+00,
        -2.2393e+00, -5.0692e-01, -1.2569e+00, -2.2569e+00, -3.1614e-01,
        -1.5661e+00], device='cuda:0')


new_candidates
torch.Size([46, 74])


tensor([[    1, 32010,  1724,  ..., 15754,   297,  1749],
        [    1, 32010,  1724,  ..., 15754, 29892,   393],
        [    1, 32010,  1724,  ..., 15754, 29892,   769],
        ...,
        [    1, 32010,  1724,  ...,  3171,   515,  2967],
        [    1, 32010,  1724,  ...,  3171,  2038,   278],
        [    1, 32010,  1724,  ...,  3171,  2038,  7205]], device='cuda:0')


new_candidate_logprobs
torch.Size([46])


tensor([-11.1083, -12.2618, -12.5118, -11.4667, -11.8374, -12.2265, -14.1424,
        -15.1424, -15.8924, -16.5174, -16.6424, -16.8924, -14.7198, -14.7198,
        -15.2198, -15.8448, -16.0948, -11.7386, -10.5239, -13.6634, -15.0384,
        -16.4134, -16.7884, -15.0335, -14.5328, -12.5860, -13.5265, -13.7764,
        -15.6514, -13.6522, -15.4022, -15.7772, -12.4617, -13.8441, -14.9691,
        -15.4691, -15.7191, -16.5941, -15.5737, -15.9487, -17.0737, -16.2164,
        -16.9664, -17.9664, -17.1506, -18.4006], device='cuda:0')

infer end: GPU memory used: 22005 MB.
event: level
id: 62
data: [{"content": "our", "parent": 0, "prob": -11.108277320861816}, {"content": "that", "parent": 1, "prob": -12.261832237243652}, {"content": "then", "parent": 1, "prob": -12.511832237243652}, {"content": "ial", "parent": 2, "prob": -11.466726303100586}, {"content": "solar", "parent": 3, "prob": -11.837401390075684}, {"content": "Earth", "parent": 4, "prob": -12.226462364196777}, {"content": "would", "parent": 5, "prob": -14.142416954040527}, {"content": "changes", "parent": 5, "prob": -15.142416954040527}, {"content": "title", "parent": 5, "prob": -15.892416954040527}, {"content": "distinction", "parent": 5, "prob": -16.51741600036621}, {"content": "var", "parent": 5, "prob": -16.64241600036621}, {"content": "depends", "parent": 5, "prob": -16.89241600036621}, {"content": "that", "parent": 6, "prob": -14.719832420349121}, {"content": "Olymp", "parent": 6, "prob": -14.719832420349121}, {"content": "Mount", "parent": 6, "prob":

array([[-1.4921875 ,  1.7578125 , -0.91796875, ..., -1.8828125 ,
         2.890625  , -0.63671875],
       [-0.15722656, -0.21386719, -0.94921875, ...,  1.7421875 ,
        -1.0234375 , -1.6796875 ],
       [-1.546875  , -0.49804688, -0.58203125, ...,  2.4375    ,
         0.00927734, -1.15625   ],
       ...,
       [ 0.12792969,  0.5625    ,  0.6796875 , ..., -1.203125  ,
        -0.36914062, -0.55078125],
       [-0.98046875,  0.19042969, -0.04296875, ...,  0.41210938,
        -1.5234375 ,  0.31054688],
       [-1.328125  , -0.70703125,  0.        , ...,  2.28125   ,
        -0.21289062, -1.765625  ]], dtype=float32)


k_mean_space
(20, 2)


array([[93.767815, 76.21784 ],
       [84.66808 , 52.781708],
       [84.080124, 72.53225 ],
       [80.97438 , 62.100155],
       [95.80922 , 79.62229 ],
       [86.36336 , 55.076324],
       [78.694145, 46.61085 ],
       [82.35752 , 50.99703 ],
       [83.10915 , 51.100956],
       [81.04549 , 48.89475 ],
       [61.681747, 83.23377 ],
       [84.81845 , 60.00627 ],
       [83.30619 , 50.026054],
       [63.27438 , 91.10842 ],
       [61.8736  , 86.80052 ],
       [89.97828 , 64.58607 ],
       [80.8161  , 48.147552],
       [89.58132 , 59.372536],
       [87.8069  , 56.96553 ],
       [84.3816  , 70.852325]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -46.58208084, -232.58504295])


closest
(2,)


array([10,  6])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.1553,  2.6406, -5.9688,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.4062, -7.5625,  1.2266,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.5000, -3.9219, -2.4688,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 3.1094, -1.0000, -5.8438,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.2969,  4.7812, -2.4688,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.5625, -4.3750, -1.2969,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.3984e-01, 6.0082e-02, 6.2083e-05,  ..., 1.5683e-24, 9.5123e-25,
         3.2549e-26],
        [6.7661e-01, 1.3323e-01, 6.2935e-02,  ..., 1.0470e-19, 4.9458e-20,
         1.5084e-20],
        [8.0762e-01, 8.5123e-02, 5.1630e-02,  ..., 2.3118e-20, 1.8005e-20,
         3.1287e-21],
        ...,
        [4.5704e-01, 2.7721e-01, 1.3095e-01,  ..., 4.4326e-23, 2.8619e-23,
         1.9669e-23],
        [9.8521e-01, 6.6383e-03, 6.6383e-03,  ..., 2.0028e-23, 1.2931e-23,
         1.8629e-24],
        [7.4031e-01, 1.0019e-01, 5.3628e-02,  ..., 2.5562e-20, 6.8798e-21,
         1.3547e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9398, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.6766, 0.8098, 0.8728,  ..., 1.0000, 1.0000, 1.0000],
        [0.8076, 0.8927, 0.9444,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.4570, 0.7343, 0.8652,  ..., 1.0000, 1.0000, 1.0000],
        [0.9852, 0.9919, 0.9985,  ..., 1.0000, 1.0000, 1.0000],
        [0.7403, 0.8405, 0.8941,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([51])


tensor([ 0,  1,  1,  1,  1,  2,  2,  2,  3,  4,  5,  6,  6,  6,  7,  7,  7,  7,
         8,  8,  8,  8,  9,  9,  9, 10, 11, 12, 12, 12, 12, 12, 12, 13, 14, 15,
        15, 15, 15, 16, 16, 16, 17, 17, 17, 17, 18, 19, 19, 19, 19],
       device='cuda:0')


carryover_candidates
torch.Size([51, 74])


tensor([[    1, 32010,  1724,  ..., 15754,   297,  1749],
        [    1, 32010,  1724,  ..., 15754, 29892,   393],
        [    1, 32010,  1724,  ..., 15754, 29892,   393],
        ...,
        [    1, 32010,  1724,  ..., 15754, 29892,   769],
        [    1, 32010,  1724,  ..., 15754, 29892,   769],
        [    1, 32010,  1724,  ..., 15754, 29892,   769]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([51])


tensor([-11.1083, -12.2618, -12.2618, -12.2618, -12.2618, -12.5118, -12.5118,
        -12.5118, -11.4667, -11.8374, -12.2265, -14.1424, -14.1424, -14.1424,
        -15.1424, -15.1424, -15.1424, -15.1424, -15.8924, -15.8924, -15.8924,
        -15.8924, -16.5174, -16.5174, -16.5174, -16.6424, -16.8924, -14.7198,
        -14.7198, -14.7198, -14.7198, -14.7198, -14.7198, -14.7198, -15.2198,
        -15.8448, -15.8448, -15.8448, -15.8448, -16.0948, -16.0948, -16.0948,
        -11.7386, -11.7386, -11.7386, -11.7386, -10.5239, -13.6634, -13.6634,
        -13.6634, -13.6634], device='cuda:0')


new_candidate_toks
torch.Size([51, 1])


tensor([[21635],
        [  723],
        [ 3611],
        [ 3620],
        [21578],
        [ 6167],
        [  393],
        [  366],
        [ 3573],
        [ 1788],
        [29892],
        [  367],
        [ 1735],
        [ 8839],
        [29889],
        [29901],
        [ 8679],
        [ 2729],
        [ 5771],
        [14393],
        [ 3620],
        [  723],
        [ 3620],
        [ 5771],
        [14393],
        [  583],
        [  373],
        [  723],
        [ 3620],
        [ 7111],
        [ 3611],
        [  722],
        [21578],
        [  375],
        [ 6167],
        [29915],
        [ 1795],
        [  881],
        [  723],
        [  723],
        [29915],
        [ 7111],
        [29892],
        [  297],
        [  916],
        [ 2629],
        [29892],
        [ 6167],
        [  366],
        [ 8040],
        [  393]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([51])


tensor([-6.2044e-02, -3.9065e-01, -2.0157e+00, -2.7657e+00, -3.0157e+00,
        -2.1366e-01, -2.4637e+00, -2.9637e+00, -1.1436e-02, -4.7684e-07,
        -1.0106e-02, -5.9696e-01, -1.3470e+00, -2.0970e+00, -1.0268e+00,
        -1.0268e+00, -2.1518e+00, -2.4018e+00, -1.2051e+00, -1.2051e+00,
        -1.5801e+00, -1.9551e+00, -8.9807e-01, -1.3981e+00, -1.3981e+00,
        -1.1921e-07, -3.2607e-03, -9.1117e-01, -1.5362e+00, -2.1612e+00,
        -2.4112e+00, -2.9112e+00, -3.2862e+00, -7.5939e-05, -1.8247e-02,
        -9.5280e-01, -1.2028e+00, -1.7028e+00, -2.2028e+00, -6.9027e-01,
        -9.4027e-01, -2.8153e+00, -7.8298e-01, -1.2830e+00, -2.0330e+00,
        -2.0330e+00, -1.4896e-02, -3.0069e-01, -2.3007e+00, -2.9257e+00,
        -3.1757e+00], device='cuda:0')


new_candidates
torch.Size([51, 75])


tensor([[    1, 32010,  1724,  ...,   297,  1749, 21635],
        [    1, 32010,  1724,  ..., 29892,   393,   723],
        [    1, 32010,  1724,  ..., 29892,   393,  3611],
        ...,
        [    1, 32010,  1724,  ..., 29892,   769,   366],
        [    1, 32010,  1724,  ..., 29892,   769,  8040],
        [    1, 32010,  1724,  ..., 29892,   769,   393]], device='cuda:0')


new_candidate_logprobs
torch.Size([51])


tensor([-11.1703, -12.6525, -14.2775, -15.0275, -15.2775, -12.7255, -14.9755,
        -15.4755, -11.4782, -11.8374, -12.2366, -14.7394, -15.4894, -16.2394,
        -16.1692, -16.1692, -17.2942, -17.5442, -17.0975, -17.0975, -17.4725,
        -17.8475, -17.4155, -17.9155, -17.9155, -16.6424, -16.8957, -15.6310,
        -16.2560, -16.8810, -17.1310, -17.6310, -18.0060, -14.7199, -15.2381,
        -16.7976, -17.0476, -17.5476, -18.0476, -16.7851, -17.0351, -18.9101,
        -12.5216, -13.0216, -13.7716, -13.7716, -10.5388, -13.9641, -15.9641,
        -16.5891, -16.8391], device='cuda:0')

infer end: GPU memory used: 22187 MB.
event: level
id: 63
data: [{"content": "solar", "parent": 0, "prob": -11.170321464538574}, {"content": "would", "parent": 1, "prob": -12.652486801147461}, {"content": "title", "parent": 1, "prob": -14.277486801147461}, {"content": "changes", "parent": 1, "prob": -15.027486801147461}, {"content": "distinction", "parent": 1, "prob": -15.277486801147461}, {"content": "Olymp", "parent": 2, "prob": -12.72549057006836}, {"content": "that", "parent": 2, "prob": -14.97549057006836}, {"content": "you", "parent": 2, "prob": -15.47549057006836}, {"content": "body", "parent": 3, "prob": -11.478161811828613}, {"content": "system", "parent": 4, "prob": -11.83740234375}, {"content": ",", "parent": 5, "prob": -12.236568450927734}, {"content": "be", "parent": 6, "prob": -14.73937702178955}, {"content": "change", "parent": 6, "prob": -15.48937702178955}, {"content": "depend", "parent": 6, "prob": -16.239377975463867}, {"content": ".", "parent": 7, "prob": -16.169166

array([[-1.6171875 , -0.3984375 , -1.4453125 , ..., -1.359375  ,
        -0.2890625 , -1.0546875 ],
       [ 0.43359375, -0.08544922, -0.48828125, ...,  0.13574219,
        -0.11865234, -0.7734375 ],
       [-1.09375   ,  0.23339844, -0.75      , ...,  0.12695312,
         0.28320312, -0.5390625 ],
       ...,
       [ 0.08349609,  0.6796875 , -0.18652344, ...,  0.39257812,
        -0.20996094,  0.45898438],
       [-0.36132812,  0.01818848, -0.953125  , ...,  0.8359375 ,
        -0.2265625 ,  1.1640625 ],
       [-1.078125  , -0.00686646, -0.28320312, ...,  0.64453125,
         0.69140625, -0.20214844]], dtype=float32)


k_mean_space
(20, 2)


array([[94.66699 , 79.642624],
       [76.13791 , 47.4206  ],
       [81.23095 , 51.881462],
       [75.61276 , 47.802834],
       [78.75456 , 49.010483],
       [96.424255, 86.185646],
       [81.74955 , 56.22521 ],
       [89.22261 , 69.30659 ],
       [71.943436, 63.336067],
       [67.942184, 70.41019 ],
       [45.526707, 77.70972 ],
       [56.82128 , 85.08193 ],
       [74.041565, 47.820152],
       [84.38015 , 53.13748 ],
       [40.939503, 71.8608  ],
       [37.04832 , 75.40273 ],
       [84.12952 , 54.725227],
       [81.50347 , 51.222855],
       [79.26512 , 56.821453],
       [73.68036 , 52.344418]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -71.15168095, -223.82206249])


closest
(2,)


array([15,  1])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 6.1875,  0.0845, -0.6875,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.5938,  1.6797,  3.3438,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.8438, -1.7031,  3.5938,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.6406,  0.7930,  3.1562,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.6875,  4.0312,  2.5469,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.0000,  3.7656,  5.2188,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 1.6374e-07, 2.5110e-08,  ..., 2.9963e-28, 2.6442e-28,
         1.2490e-28],
        [6.9698e-01, 1.5552e-01, 1.0689e-01,  ..., 5.4233e-20, 4.7860e-20,
         2.5618e-20],
        [6.0563e-01, 2.8608e-01, 8.1964e-02,  ..., 3.0126e-21, 1.6125e-21,
         9.1879e-22],
        ...,
        [9.9999e-01, 6.9623e-06, 2.1024e-07,  ..., 9.6983e-23, 8.5587e-23,
         4.3036e-23],
        [9.9986e-01, 1.0889e-04, 1.4737e-05,  ..., 2.0733e-21, 1.8297e-21,
         1.6147e-21],
        [9.9955e-01, 4.3055e-04, 3.7250e-06,  ..., 4.9721e-21, 4.3878e-21,
         2.6614e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.6970, 0.8525, 0.9594,  ..., 1.0000, 1.0000, 1.0000],
        [0.6056, 0.8917, 0.9737,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9996, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([43])


tensor([ 0,  1,  1,  1,  2,  2,  2,  3,  3,  3,  3,  4,  4,  4,  5,  6,  6,  6,
         6,  6,  7,  7,  8,  8,  9, 10, 10, 10, 11, 12, 12, 12, 12, 12, 13, 14,
        14, 15, 15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([43, 75])


tensor([[    1, 32010,  1724,  ...,   297,  1749, 21635],
        [    1, 32010,  1724,  ..., 29892,   393,   723],
        [    1, 32010,  1724,  ..., 29892,   393,   723],
        ...,
        [    1, 32010,  1724,  ...,   393,  3620,  2729],
        [    1, 32010,  1724,  ...,   393,  3611,  5771],
        [    1, 32010,  1724,  ...,   393,  3611, 14393]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([43])


tensor([-11.1703, -12.6525, -12.6525, -12.6525, -14.2775, -14.2775, -14.2775,
        -15.0275, -15.0275, -15.0275, -15.0275, -15.2775, -15.2775, -15.2775,
        -12.7255, -14.9755, -14.9755, -14.9755, -14.9755, -14.9755, -15.4755,
        -15.4755, -11.4782, -11.4782, -11.8374, -12.2366, -12.2366, -12.2366,
        -14.7394, -15.4894, -15.4894, -15.4894, -15.4894, -15.4894, -16.2394,
        -16.1692, -16.1692, -16.1692, -16.1692, -17.2942, -17.5442, -17.0975,
        -17.0975], device='cuda:0')


new_candidate_toks
torch.Size([43, 1])


tensor([[ 1788],
        [  367],
        [ 1735],
        [ 8839],
        [ 5771],
        [14393],
        [  723],
        [29889],
        [  278],
        [29901],
        [ 8679],
        [ 5771],
        [14393],
        [  723],
        [  375],
        [  723],
        [ 3611],
        [ 3620],
        [21578],
        [29915],
        [ 1795],
        [29915],
        [29892],
        [  297],
        [29892],
        [  769],
        [  393],
        [  366],
        [ 6167],
        [29889],
        [  278],
        [29901],
        [ 2729],
        [ 8679],
        [  373],
        [ 1152],
        [ 1551],
        [   13],
        [ 6167],
        [  373],
        [  373],
        [  304],
        [  304]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([43])


tensor([-1.1921e-07, -3.6100e-01, -1.8610e+00, -2.2360e+00, -5.0148e-01,
        -1.2515e+00, -2.5015e+00, -7.7846e-01, -1.5285e+00, -1.5285e+00,
        -3.0285e+00, -6.0277e-01, -1.1028e+00, -2.8528e+00, -2.5034e-05,
        -5.0046e-01, -2.1255e+00, -2.5005e+00, -3.0005e+00, -3.1255e+00,
        -6.4458e-01, -8.9458e-01, -2.0250e-01, -1.9525e+00, -1.1693e-02,
        -7.4276e-01, -8.6776e-01, -2.4928e+00, -2.1467e-02, -1.0348e+00,
        -1.5348e+00, -1.7848e+00, -2.0348e+00, -2.2848e+00, -1.1375e-03,
        -1.8469e-01, -2.0597e+00, -2.0557e-01, -2.2056e+00, -5.3288e-05,
        -7.6294e-06, -1.4449e-04, -4.4958e-04], device='cuda:0')


new_candidates
torch.Size([43, 76])


tensor([[    1, 32010,  1724,  ...,  1749, 21635,  1788],
        [    1, 32010,  1724,  ...,   393,   723,   367],
        [    1, 32010,  1724,  ...,   393,   723,  1735],
        ...,
        [    1, 32010,  1724,  ...,  3620,  2729,   373],
        [    1, 32010,  1724,  ...,  3611,  5771,   304],
        [    1, 32010,  1724,  ...,  3611, 14393,   304]], device='cuda:0')


new_candidate_logprobs
torch.Size([43])


tensor([-11.1703, -13.0135, -14.5135, -14.8885, -14.7790, -15.5290, -16.7790,
        -15.8059, -16.5559, -16.5559, -18.0559, -15.8803, -16.3803, -18.1303,
        -12.7255, -15.4759, -17.1009, -17.4759, -17.9759, -18.1009, -16.1201,
        -16.3701, -11.6807, -13.4307, -11.8491, -12.9793, -13.1043, -14.7293,
        -14.7608, -16.5242, -17.0242, -17.2742, -17.5242, -17.7742, -16.2405,
        -16.3539, -18.2289, -16.3747, -18.3747, -17.2942, -17.5442, -17.0977,
        -17.0980], device='cuda:0')

infer end: GPU memory used: 22371 MB.
event: level
id: 64
data: [{"content": "system", "parent": 0, "prob": -11.170321464538574}, {"content": "be", "parent": 1, "prob": -13.0134859085083}, {"content": "change", "parent": 1, "prob": -14.5134859085083}, {"content": "depend", "parent": 1, "prob": -14.8884859085083}, {"content": "goes", "parent": 2, "prob": -14.778966903686523}, {"content": "belongs", "parent": 2, "prob": -15.528966903686523}, {"content": "would", "parent": 2, "prob": -16.778966903686523}, {"content": ".", "parent": 3, "prob": -15.805944442749023}, {"content": "the", "parent": 3, "prob": -16.555944442749023}, {"content": ":", "parent": 3, "prob": -16.555944442749023}, {"content": "depending", "parent": 3, "prob": -18.055944442749023}, {"content": "goes", "parent": 4, "prob": -15.880257606506348}, {"content": "belongs", "parent": 4, "prob": -16.38025665283203}, {"content": "would", "parent": 4, "prob": -18.13025665283203}, {"content": "us", "parent": 5, "prob": -12.72551536

array([[-0.11279297, -0.70703125,  1.34375   , ..., -0.234375  ,
         0.39648438, -0.74609375],
       [-1.546875  , -1.703125  ,  0.76171875, ...,  2.65625   ,
         0.2890625 ,  1.375     ],
       [ 1.546875  ,  2.171875  ,  0.23730469, ...,  1.3203125 ,
         0.97265625, -0.02966309],
       ...,
       [ 1.0703125 ,  1.3359375 ,  0.27734375, ...,  2.46875   ,
         0.98046875,  0.3515625 ],
       [-0.94140625, -0.83984375, -1.1484375 , ..., -1.1015625 ,
         0.06591797, -1.6640625 ],
       [-0.45117188, -0.9453125 ,  2.71875   , ...,  1.5234375 ,
        -1.3203125 ,  0.578125  ]], dtype=float32)


k_mean_space
(20, 2)


array([[76.77899 , 67.62502 ],
       [88.94279 , 72.34953 ],
       [43.457367, 60.438538],
       [43.12991 , 68.79462 ],
       [73.94971 , 51.97111 ],
       [69.26501 , 46.9382  ],
       [66.72624 , 46.121433],
       [78.05156 , 62.86014 ],
       [62.149563, 58.38898 ],
       [82.96086 , 66.11031 ],
       [45.567677, 69.99504 ],
       [72.56056 , 50.94698 ],
       [67.837   , 46.18592 ],
       [64.71644 , 45.866028],
       [78.15938 , 87.82926 ],
       [60.508156, 48.307945],
       [70.30605 , 50.558865],
       [40.841705, 58.489174],
       [67.835464, 48.259605],
       [90.27715 , 75.48401 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -77.65937805, -239.23309803])


closest
(2,)


array([17, 13])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.0156, -0.3320, -1.8750,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.5625, -2.2656, -0.1680,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.2500,  2.1094,  5.6562,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 3.9219,  2.2344,  5.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 8.0000,  0.4668,  3.7031,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.9531,  4.0625,  0.8867,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9884e-01, 2.3029e-04, 1.7935e-04,  ..., 2.3238e-22, 1.3241e-22,
         5.8755e-23],
        [9.9608e-01, 2.4690e-03, 6.2427e-04,  ..., 1.4480e-19, 1.9597e-20,
         1.7294e-20],
        [6.1219e-01, 2.8918e-01, 3.0479e-02,  ..., 1.2044e-20, 8.8116e-21,
         5.0207e-21],
        ...,
        [7.0596e-01, 2.0226e-01, 2.7373e-02,  ..., 2.9403e-20, 2.0208e-20,
         1.8984e-20],
        [5.6769e-01, 3.4432e-01, 4.6599e-02,  ..., 6.3637e-21, 5.6159e-21,
         8.0905e-22],
        [9.9976e-01, 2.0342e-04, 2.1440e-05,  ..., 1.9475e-21, 1.6145e-21,
         7.6265e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9988, 0.9991, 0.9993,  ..., 1.0000, 1.0000, 1.0000],
        [0.9961, 0.9986, 0.9992,  ..., 1.0000, 1.0000, 1.0000],
        [0.6122, 0.9014, 0.9318,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.7060, 0.9082, 0.9356,  ..., 1.0000, 1.0000, 1.0000],
        [0.5677, 0.9120, 0.9586,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([33])


tensor([ 0,  1,  2,  2,  3,  4,  5,  6,  6,  6,  7,  7,  8,  9,  9, 10, 11, 12,
        13, 13, 13, 13, 14, 15, 15, 15, 16, 16, 17, 17, 18, 18, 19],
       device='cuda:0')


carryover_candidates
torch.Size([33, 76])


tensor([[    1, 32010,  1724,  ...,  1749, 21635,  1788],
        [    1, 32010,  1724,  ...,   393,   723,   367],
        [    1, 32010,  1724,  ...,   393,   723,  1735],
        ...,
        [    1, 32010,  1724,  ...,   769,   393, 21578],
        [    1, 32010,  1724,  ...,   769,   393, 21578],
        [    1, 32010,  1724,  ...,   769,   393, 29915]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([33])


tensor([-11.1703, -13.0135, -14.5135, -14.5135, -14.8885, -14.7790, -15.5290,
        -16.7790, -16.7790, -16.7790, -15.8059, -15.8059, -16.5559, -16.5559,
        -16.5559, -18.0559, -15.8803, -16.3803, -18.1303, -18.1303, -18.1303,
        -18.1303, -12.7255, -15.4759, -15.4759, -15.4759, -17.1009, -17.1009,
        -17.4759, -17.4759, -17.9759, -17.9759, -18.1009], device='cuda:0')


new_candidate_toks
torch.Size([33, 1])


tensor([[29892],
        [ 6167],
        [  278],
        [29889],
        [  373],
        [  304],
        [  304],
        [  748],
        [ 6852],
        [ 1735],
        [ 1152],
        [ 1551],
        [ 1234],
        [ 6167],
        [   13],
        [  373],
        [  304],
        [  304],
        [ 1735],
        [  748],
        [ 6852],
        [ 6416],
        [  341],
        [  367],
        [ 8839],
        [ 1735],
        [ 5771],
        [14393],
        [29889],
        [29901],
        [ 5771],
        [14393],
        [29879]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([33])


tensor([-1.1584e-03, -3.9238e-03, -4.9071e-01, -1.2407e+00, -1.5854e-03,
        -7.1526e-06, -4.8162e-05, -6.1046e-01, -1.4855e+00, -1.7355e+00,
        -1.6613e-01, -2.1661e+00, -7.6450e-02, -4.2563e-01, -1.3006e+00,
        -1.1159e-04, -1.6451e-05, -5.6388e-05, -6.0173e-01, -1.4767e+00,
        -2.2267e+00, -3.2267e+00, -7.1526e-07, -2.4132e-01, -2.2413e+00,
        -2.7413e+00, -5.5328e-01, -1.0533e+00, -3.4820e-01, -1.5982e+00,
        -5.6618e-01, -1.0662e+00, -2.3982e-04], device='cuda:0')


new_candidates
torch.Size([33, 77])


tensor([[    1, 32010,  1724,  ..., 21635,  1788, 29892],
        [    1, 32010,  1724,  ...,   723,   367,  6167],
        [    1, 32010,  1724,  ...,   723,  1735,   278],
        ...,
        [    1, 32010,  1724,  ...,   393, 21578,  5771],
        [    1, 32010,  1724,  ...,   393, 21578, 14393],
        [    1, 32010,  1724,  ...,   393, 29915, 29879]], device='cuda:0')


new_candidate_logprobs
torch.Size([33])


tensor([-11.1715, -13.0174, -15.0042, -15.7542, -14.8901, -14.7790, -15.5290,
        -17.3894, -18.2644, -18.5144, -15.9721, -17.9721, -16.6324, -16.9816,
        -17.8566, -18.0561, -15.8803, -16.3803, -18.7320, -19.6070, -20.3570,
        -21.3570, -12.7255, -15.7173, -17.7173, -18.2173, -17.6542, -18.1542,
        -17.8241, -19.0741, -18.5421, -19.0421, -18.1012], device='cuda:0')

infer end: GPU memory used: 22557 MB.
event: level
id: 65
data: [{"content": ",", "parent": 0, "prob": -11.171480178833008}, {"content": "Olymp", "parent": 1, "prob": -13.017409324645996}, {"content": "the", "parent": 2, "prob": -15.004199981689453}, {"content": ".", "parent": 2, "prob": -15.754199981689453}, {"content": "on", "parent": 3, "prob": -14.890070915222168}, {"content": "to", "parent": 4, "prob": -14.778974533081055}, {"content": "to", "parent": 5, "prob": -15.52901554107666}, {"content": "go", "parent": 6, "prob": -17.389427185058594}, {"content": "belong", "parent": 6, "prob": -18.264427185058594}, {"content": "change", "parent": 6, "prob": -18.514427185058594}, {"content": "For", "parent": 7, "prob": -15.97206974029541}, {"content": "On", "parent": 7, "prob": -17.972068786621094}, {"content": "answer", "parent": 8, "prob": -16.632394790649414}, {"content": "Olymp", "parent": 9, "prob": -16.981569290161133}, {"content": "\n", "parent": 9, "prob": -17.856569290161133}, {"co

array([[-0.5546875 , -0.06298828, -1.40625   , ...,  1.578125  ,
         0.43164062, -1.6015625 ],
       [-0.34179688, -0.00830078, -0.0222168 , ..., -0.7734375 ,
         0.609375  , -0.58984375],
       [-0.26757812,  1.4609375 ,  0.00726318, ...,  0.42773438,
        -0.41015625, -0.03710938],
       ...,
       [-1.4765625 , -1.6875    ,  1.5546875 , ...,  2.484375  ,
        -0.19628906,  1.09375   ],
       [ 1.0703125 ,  1.4609375 ,  0.9765625 , ...,  1.59375   ,
         1.6796875 ,  0.0177002 ],
       [-0.06542969, -0.546875  , -0.5625    , ...,  0.5078125 ,
        -0.26953125,  0.40625   ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 62.01418 , 111.34171 ],
       [ 88.59718 ,  13.449341],
       [ 57.315773, 101.3704  ],
       [ 59.750996, 103.53038 ],
       [ 53.868443, 103.965515],
       [ 61.482487, 107.76216 ],
       [ 61.98826 , 107.710526],
       [ 58.87951 , 110.028175],
       [ 53.82684 , 105.128525],
       [ 53.897892, 104.315895],
       [ 52.18819 , 100.8029  ],
       [ 68.147125, 109.26588 ],
       [ 56.32138 , 104.71552 ],
       [ 89.47088 ,  13.449341],
       [ 80.13937 , 112.01057 ],
       [ 55.69441 , 103.31309 ],
       [ 61.37565 , 107.735504],
       [ 61.64262 , 107.62368 ],
       [ 54.445114, 104.96367 ],
       [ 57.91926 , 109.23137 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-298.38495159,  -29.99897861])


closest
(2,)


array([10,  1])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.7188,  0.2539, -0.6992,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3965, -6.7500, -0.1689,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.0938, -0.5117,  3.9219,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 4.3438, -2.5000,  0.6914,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.4688,  0.6094,  1.8125,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.4688,  4.2500,  3.9688,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.4506e-01, 8.9069e-02, 2.2520e-02,  ..., 5.0705e-21, 4.2036e-21,
         3.9489e-21],
        [9.9992e-01, 4.5396e-05, 2.1444e-05,  ..., 6.3866e-21, 1.8298e-21,
         1.2576e-21],
        [9.8766e-01, 8.5449e-03, 1.3104e-03,  ..., 7.6093e-21, 4.6153e-21,
         3.8262e-21],
        ...,
        [9.8295e-01, 1.5888e-02, 3.7365e-04,  ..., 7.6485e-20, 5.2567e-20,
         2.8137e-20],
        [6.3425e-01, 1.4152e-01, 9.7266e-02,  ..., 1.3415e-19, 7.1807e-20,
         5.2535e-20],
        [9.9993e-01, 5.8291e-05, 4.2226e-06,  ..., 1.3387e-21, 9.7943e-22,
         6.3237e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8451, 0.9341, 0.9567,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9877, 0.9962, 0.9975,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9830, 0.9988, 0.9992,  ..., 1.0000, 1.0000, 1.0000],
        [0.6343, 0.7758, 0.8730,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([30])


tensor([ 0,  0,  1,  2,  3,  4,  4,  5,  6,  7,  8,  9,  9,  9,  9, 10, 10, 11,
        11, 12, 13, 14, 15, 16, 17, 18, 18, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([30, 77])


tensor([[    1, 32010,  1724,  ..., 21635,  1788, 29892],
        [    1, 32010,  1724,  ..., 21635,  1788, 29892],
        [    1, 32010,  1724,  ...,   723,   367,  6167],
        ...,
        [    1, 32010,  1724,  ..., 21578,   723,  1735],
        [    1, 32010,  1724,  ..., 21578,   723,  1735],
        [    1, 32010,  1724,  ..., 21578,   723,   748]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([30])


tensor([-11.1715, -11.1715, -13.0174, -15.0042, -15.7542, -14.8901, -14.8901,
        -14.7790, -15.5290, -17.3894, -18.2644, -18.5144, -18.5144, -18.5144,
        -18.5144, -15.9721, -15.9721, -17.9721, -17.9721, -16.6324, -16.9816,
        -17.8566, -18.0561, -15.8803, -16.3803, -18.7320, -18.7320, -18.7320,
        -18.7320, -19.6070], device='cuda:0')


new_candidate_toks
torch.Size([30, 1])


tensor([[  393],
        [  769],
        [  375],
        [ 1234],
        [ 1152],
        [  278],
        [  607],
        [ 6167],
        [ 6167],
        [  304],
        [  304],
        [29889],
        [16205],
        [ 2729],
        [ 8679],
        [ 2777],
        [ 1342],
        [16852],
        [ 9548],
        [29889],
        [  375],
        [   13],
        [  278],
        [ 6167],
        [ 6167],
        [29889],
        [16205],
        [ 2729],
        [ 8679],
        [  304]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([30])


tensor([-1.6834e-01, -2.4183e+00, -8.1662e-05, -1.2416e-02, -6.0202e-02,
        -1.3950e-01, -2.1395e+00, -1.7015e-02, -1.5236e-02, -5.0069e-05,
        -2.4936e-04, -9.6662e-01, -1.2166e+00, -2.0916e+00, -2.2166e+00,
        -3.1934e-01, -1.3193e+00, -3.0409e-01, -1.4291e+00, -8.0211e-02,
        -4.1247e-05, -1.1921e-07, -8.0922e-02, -1.6967e-02, -1.7195e-02,
        -4.5531e-01, -1.9553e+00, -2.3303e+00, -3.2053e+00, -6.9501e-05],
       device='cuda:0')


new_candidates
torch.Size([30, 78])


tensor([[    1, 32010,  1724,  ...,  1788, 29892,   393],
        [    1, 32010,  1724,  ...,  1788, 29892,   769],
        [    1, 32010,  1724,  ...,   367,  6167,   375],
        ...,
        [    1, 32010,  1724,  ...,   723,  1735,  2729],
        [    1, 32010,  1724,  ...,   723,  1735,  8679],
        [    1, 32010,  1724,  ...,   723,   748,   304]], device='cuda:0')


new_candidate_logprobs
torch.Size([30])


tensor([-11.3398, -13.5898, -13.0175, -15.0166, -15.8144, -15.0296, -17.0296,
        -14.7960, -15.5443, -17.3895, -18.2647, -19.4811, -19.7311, -20.6061,
        -20.7311, -16.2914, -17.2914, -18.2762, -19.4012, -16.7126, -16.9816,
        -17.8566, -18.1370, -15.8972, -16.3975, -19.1873, -20.6873, -21.0623,
        -21.9373, -19.6071], device='cuda:0')

infer end: GPU memory used: 22747 MB.
event: level
id: 66
data: [{"content": "that", "parent": 0, "prob": -11.339821815490723}, {"content": "then", "parent": 0, "prob": -13.589821815490723}, {"content": "us", "parent": 1, "prob": -13.017491340637207}, {"content": "answer", "parent": 2, "prob": -15.016615867614746}, {"content": "For", "parent": 3, "prob": -15.814401626586914}, {"content": "the", "parent": 4, "prob": -15.029571533203125}, {"content": "which", "parent": 4, "prob": -17.029571533203125}, {"content": "Olymp", "parent": 5, "prob": -14.795989036560059}, {"content": "Olymp", "parent": 6, "prob": -15.544251441955566}, {"content": "to", "parent": 7, "prob": -17.389476776123047}, {"content": "to", "parent": 8, "prob": -18.264677047729492}, {"content": ".", "parent": 9, "prob": -19.481050491333008}, {"content": "accordingly", "parent": 9, "prob": -19.731050491333008}, {"content": "based", "parent": 9, "prob": -20.606050491333008}, {"content": "depending", "parent": 9, "prob": -20.7

array([[-0.23828125, -0.78125   , -0.9609375 , ...,  1.8046875 ,
        -1.0390625 , -1.203125  ],
       [-1.5859375 , -0.67578125, -0.54296875, ...,  2.890625  ,
        -0.21875   , -1.1015625 ],
       [-2.203125  ,  0.33203125, -0.25195312, ..., -2.421875  ,
         1.2890625 ,  0.5859375 ],
       ...,
       [-1.0078125 , -1.1171875 ,  0.07617188, ..., -0.08349609,
         1.7265625 ,  3.1875    ],
       [-1.5234375 ,  0.41992188, -0.05371094, ..., -2.1875    ,
        -1.4296875 , -0.44921875],
       [-1.953125  ,  1.2265625 ,  2.75      , ...,  1.8828125 ,
        -1.2421875 ,  0.55859375]], dtype=float32)


k_mean_space
(20, 2)


array([[88.65407 , 66.822586],
       [51.60335 , 77.63485 ],
       [75.41444 , 82.588165],
       [79.897194, 51.579044],
       [66.50965 , 51.926285],
       [76.32527 , 49.999496],
       [77.503586, 52.14829 ],
       [90.145096, 75.29437 ],
       [90.17212 , 75.269264],
       [47.351143, 81.93508 ],
       [47.68401 , 82.13655 ],
       [46.07719 , 65.55393 ],
       [79.408356, 50.20558 ],
       [84.30696 , 55.18987 ],
       [87.771   , 60.67359 ],
       [68.04816 , 52.656204],
       [68.2621  , 52.51483 ],
       [81.314674, 70.350746],
       [96.63047 , 79.875   ],
       [48.05273 , 64.07689 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -98.4551239 , -236.89852524])


closest
(2,)


array([11,  5])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 3.3281, -6.5938,  0.9336,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.3438, -2.0312, -2.4219,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.2812, -2.0156, -4.3125,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.1445,  2.7188, -1.8047,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.4062,  7.0000, -1.8672,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.9688,  3.1719,  4.6250,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[6.4502e-01, 2.3729e-01, 8.7294e-02,  ..., 3.4155e-21, 2.8315e-21,
         2.2052e-21],
        [8.1029e-01, 8.5403e-02, 4.0342e-02,  ..., 4.3333e-20, 3.8241e-20,
         2.9782e-20],
        [1.0000e+00, 1.9947e-06, 2.3824e-07,  ..., 8.4743e-24, 7.4786e-24,
         6.5998e-24],
        ...,
        [9.9805e-01, 1.9267e-03, 8.9224e-06,  ..., 1.0860e-23, 6.5870e-24,
         1.4698e-24],
        [1.0000e+00, 2.1024e-07, 8.7642e-08,  ..., 7.4786e-24, 7.4786e-24,
         7.8824e-25],
        [8.8203e-01, 1.0534e-01, 5.9430e-03,  ..., 1.2696e-20, 1.2696e-20,
         1.0525e-20]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.6450, 0.8823, 0.9696,  ..., 1.0000, 1.0000, 1.0000],
        [0.8103, 0.8957, 0.9360,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9981, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8820, 0.9874, 0.9933,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([28])


tensor([ 0,  0,  0,  1,  1,  1,  2,  3,  4,  4,  5,  5,  5,  6,  7,  8,  9, 10,
        11, 12, 13, 14, 15, 16, 17, 18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([28, 78])


tensor([[    1, 32010,  1724,  ...,  1788, 29892,   393],
        [    1, 32010,  1724,  ...,  1788, 29892,   393],
        [    1, 32010,  1724,  ...,  1788, 29892,   393],
        ...,
        [    1, 32010,  1724,  ..., 29889,  1551,  9548],
        [    1, 32010,  1724,  ...,   278,  1234, 29889],
        [    1, 32010,  1724,  ...,   278,  1234, 29889]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([28])


tensor([-11.3398, -11.3398, -11.3398, -13.5898, -13.5898, -13.5898, -13.0175,
        -15.0166, -15.8144, -15.8144, -15.0296, -15.0296, -15.0296, -17.0296,
        -14.7960, -15.5443, -17.3895, -18.2647, -19.4811, -19.7311, -20.6061,
        -20.7311, -16.2914, -17.2914, -18.2762, -19.4012, -16.7126, -16.7126],
       device='cuda:0')


new_candidate_toks
torch.Size([28, 1])


tensor([[  723],
        [ 3611],
        [21578],
        [ 6167],
        [  393],
        [  366],
        [  341],
        [29889],
        [ 2777],
        [ 1342],
        [15754],
        [ 2702],
        [ 6432],
        [15754],
        [  375],
        [  375],
        [ 6167],
        [ 6167],
        [ 1152],
        [29889],
        [  373],
        [  373],
        [29892],
        [29892],
        [29892],
        [  375],
        [ 1152],
        [ 1551]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([28])


tensor([-4.3847e-01, -1.4385e+00, -2.4385e+00, -2.1037e-01, -2.4604e+00,
        -3.2104e+00, -2.9802e-06, -6.0118e-02, -3.9029e-01, -1.1403e+00,
        -3.9079e-01, -1.6408e+00, -2.3908e+00, -7.1977e-02, -5.2734e-04,
        -2.8310e-04, -1.5700e-02, -1.7707e-02, -8.3101e-02, -8.9117e-02,
        -6.0560e-05, -2.0393e-04, -6.9501e-05, -5.4957e-05, -1.9499e-03,
        -4.7684e-07, -1.2553e-01, -2.2505e+00], device='cuda:0')


new_candidates
torch.Size([28, 79])


tensor([[    1, 32010,  1724,  ..., 29892,   393,   723],
        [    1, 32010,  1724,  ..., 29892,   393,  3611],
        [    1, 32010,  1724,  ..., 29892,   393, 21578],
        ...,
        [    1, 32010,  1724,  ...,  1551,  9548,   375],
        [    1, 32010,  1724,  ...,  1234, 29889,  1152],
        [    1, 32010,  1724,  ...,  1234, 29889,  1551]], device='cuda:0')


new_candidate_logprobs
torch.Size([28])


tensor([-11.7783, -12.7783, -13.7783, -13.8002, -16.0502, -16.8002, -13.0175,
        -15.0767, -16.2047, -16.9547, -15.4204, -16.6704, -17.4204, -17.1015,
        -14.7965, -15.5445, -17.4052, -18.2824, -19.5642, -19.8202, -20.6061,
        -20.7313, -16.2915, -17.2915, -18.2781, -19.4012, -16.8381, -18.9631],
       device='cuda:0')

infer end: GPU memory used: 22939 MB.
event: level
id: 67
data: [{"content": "would", "parent": 0, "prob": -11.77829360961914}, {"content": "title", "parent": 0, "prob": -12.77829360961914}, {"content": "distinction", "parent": 0, "prob": -13.77829360961914}, {"content": "Olymp", "parent": 1, "prob": -13.800189971923828}, {"content": "that", "parent": 1, "prob": -16.050189971923828}, {"content": "you", "parent": 1, "prob": -16.800189971923828}, {"content": "M", "parent": 2, "prob": -13.017494201660156}, {"content": ".", "parent": 3, "prob": -15.076733589172363}, {"content": "instance", "parent": 4, "prob": -16.204687118530273}, {"content": "example", "parent": 4, "prob": -16.954687118530273}, {"content": "planet", "parent": 5, "prob": -15.420360565185547}, {"content": "specific", "parent": 5, "prob": -16.670360565185547}, {"content": "cel", "parent": 5, "prob": -17.420360565185547}, {"content": "planet", "parent": 6, "prob": -17.10154914855957}, {"content": "us", "parent": 7, "prob": -

array([[ 0.42773438, -0.5234375 , -0.19628906, ...,  0.265625  ,
        -0.37109375, -1.2421875 ],
       [-1.15625   ,  0.36328125, -0.62890625, ...,  0.23242188,
        -0.16015625, -0.40625   ],
       [-0.72265625, -0.80859375, -1.2109375 , ..., -1.234375  ,
        -0.06079102, -1.59375   ],
       ...,
       [-0.20117188, -0.4609375 , -0.09179688, ..., -0.640625  ,
         0.875     , -0.90625   ],
       [ 0.87890625,  0.76953125,  2.25      , ..., -0.08789062,
        -0.22167969, -0.23925781],
       [-1.9375    ,  1.15625   ,  2.703125  , ...,  1.3984375 ,
        -0.9921875 ,  0.52734375]], dtype=float32)


k_mean_space
(20, 2)


array([[53.1548  , 87.87933 ],
       [57.899815, 93.131966],
       [54.83648 , 90.54945 ],
       [89.2031  , 50.70235 ],
       [58.918747, 94.87343 ],
       [70.665565, 96.91118 ],
       [90.70107 , 80.64343 ],
       [58.550007, 84.03012 ],
       [51.764034, 84.19277 ],
       [52.160538, 84.22386 ],
       [58.41058 , 91.49962 ],
       [52.53759 , 82.49777 ],
       [83.48178 , 97.993645],
       [57.797073, 92.63282 ],
       [86.82689 , 59.04468 ],
       [86.92184 , 59.250576],
       [88.877754, 46.627773],
       [88.80904 , 46.52421 ],
       [52.90893 , 83.20485 ],
       [57.372208, 82.03272 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-225.41831875,  -92.84629631])


closest
(2,)


array([ 8, 17])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.2188,  0.2695,  0.0649,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.7812, -1.7578,  2.3594,  ...,  0.0000,  0.0000,  0.0000],
        [ 8.1250, -0.3848,  2.7812,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0771, -6.8125,  0.5586,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.0000,  1.7891,  1.7500,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.1250,  2.9531,  4.2500,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.7040e-01, 1.3842e-02, 1.2216e-02,  ..., 2.0122e-21, 6.5327e-22,
         5.4158e-22],
        [7.7382e-01, 1.7266e-01, 4.9469e-02,  ..., 4.8455e-23, 4.5519e-23,
         1.2251e-23],
        [6.7704e-01, 2.4907e-01, 5.5575e-02,  ..., 5.4978e-22, 4.2817e-22,
         2.2692e-23],
        ...,
        [9.9954e-01, 2.6114e-04, 1.7948e-04,  ..., 9.2890e-21, 8.1975e-21,
         1.2571e-21],
        [7.2848e-01, 2.6799e-01, 2.9771e-03,  ..., 1.0485e-20, 5.6125e-21,
         4.6529e-21],
        [9.4089e-01, 5.3082e-02, 2.6428e-03,  ..., 1.3543e-20, 1.1952e-20,
         1.1227e-20]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9704, 0.9842, 0.9965,  ..., 1.0000, 1.0000, 1.0000],
        [0.7738, 0.9465, 0.9960,  ..., 1.0000, 1.0000, 1.0000],
        [0.6770, 0.9261, 0.9817,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9995, 0.9998, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.7285, 0.9965, 0.9994,  ..., 1.0000, 1.0000, 1.0000],
        [0.9409, 0.9940, 0.9966,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([27])


tensor([ 0,  1,  1,  2,  2,  3,  4,  4,  4,  5,  5,  6,  7,  8,  9, 10, 11, 12,
        13, 13, 14, 15, 16, 17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([27, 79])


tensor([[    1, 32010,  1724,  ..., 29892,   393,   723],
        [    1, 32010,  1724,  ..., 29892,   393,  3611],
        [    1, 32010,  1724,  ..., 29892,   393,  3611],
        ...,
        [    1, 32010,  1724,  ...,  1735, 29889,  1152],
        [    1, 32010,  1724,  ...,  1735, 29889,  1152],
        [    1, 32010,  1724,  ...,  1735, 16205, 29889]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([27])


tensor([-11.7783, -12.7783, -12.7783, -13.7783, -13.7783, -13.8002, -16.0502,
        -16.0502, -16.0502, -16.8002, -16.8002, -13.0175, -15.0767, -16.2047,
        -16.9547, -15.4204, -16.6704, -17.4204, -17.1015, -17.1015, -14.7965,
        -15.5445, -17.4052, -18.2824, -19.5642, -19.5642, -19.8202],
       device='cuda:0')


new_candidate_toks
torch.Size([27, 1])


tensor([[  367],
        [ 5771],
        [14393],
        [ 5771],
        [14393],
        [  375],
        [  723],
        [ 3611],
        [21578],
        [29915],
        [ 1795],
        [  787],
        [ 1152],
        [29892],
        [29892],
        [  297],
        [15754],
        [  342],
        [  366],
        [29915],
        [  341],
        [  341],
        [  375],
        [  375],
        [ 2777],
        [ 1342],
        [ 1152]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([27])


tensor([-3.0045e-02, -2.5641e-01, -1.7564e+00, -3.9003e-01, -1.3900e+00,
        -2.9326e-05, -4.1853e-01, -1.6685e+00, -2.1685e+00, -6.5138e-01,
        -9.0138e-01,  0.0000e+00, -4.4524e-02, -8.1063e-06, -5.6029e-06,
        -9.4610e-02, -7.5167e-02,  0.0000e+00, -1.0970e-01, -2.6097e+00,
        -1.5497e-06, -1.4305e-06, -4.9925e-04, -4.5662e-04, -3.1680e-01,
        -1.3168e+00, -6.0925e-02], device='cuda:0')


new_candidates
torch.Size([27, 80])


tensor([[    1, 32010,  1724,  ...,   393,   723,   367],
        [    1, 32010,  1724,  ...,   393,  3611,  5771],
        [    1, 32010,  1724,  ...,   393,  3611, 14393],
        ...,
        [    1, 32010,  1724,  ..., 29889,  1152,  2777],
        [    1, 32010,  1724,  ..., 29889,  1152,  1342],
        [    1, 32010,  1724,  ..., 16205, 29889,  1152]], device='cuda:0')


new_candidate_logprobs
torch.Size([27])


tensor([-11.8083, -13.0347, -14.5347, -14.1683, -15.1683, -13.8002, -16.4687,
        -17.7187, -18.2187, -17.4516, -17.7016, -13.0175, -15.1213, -16.2047,
        -16.9547, -15.5150, -16.7455, -17.4204, -17.2112, -19.7112, -14.7965,
        -15.5445, -17.4057, -18.2828, -19.8809, -20.8809, -19.8811],
       device='cuda:0')

infer end: GPU memory used: 16559 MB.
event: level
id: 68
data: [{"content": "be", "parent": 0, "prob": -11.80833911895752}, {"content": "goes", "parent": 1, "prob": -13.034708023071289}, {"content": "belongs", "parent": 1, "prob": -14.534708023071289}, {"content": "goes", "parent": 2, "prob": -14.168320655822754}, {"content": "belongs", "parent": 2, "prob": -15.168320655822754}, {"content": "us", "parent": 3, "prob": -13.800219535827637}, {"content": "would", "parent": 4, "prob": -16.46871566772461}, {"content": "title", "parent": 4, "prob": -17.71871566772461}, {"content": "distinction", "parent": 4, "prob": -18.21871566772461}, {"content": "'", "parent": 5, "prob": -17.451570510864258}, {"content": "might", "parent": 5, "prob": -17.701570510864258}, {"content": "ons", "parent": 6, "prob": -13.017494201660156}, {"content": "For", "parent": 7, "prob": -15.121257781982422}, {"content": ",", "parent": 8, "prob": -16.204694747924805}, {"content": ",", "parent": 9, "prob": -16.95469284057

array([[-1.671875  , -1.2578125 ,  1.296875  , ...,  2.484375  ,
         0.3203125 ,  1.3828125 ],
       [-0.12402344,  0.38085938, -0.89453125, ...,  1.03125   ,
        -1.109375  ,  0.97265625],
       [-1.390625  ,  0.43359375, -0.14746094, ...,  1.4296875 ,
        -0.18945312, -0.44921875],
       ...,
       [ 1.765625  ,  0.49609375, -0.33203125, ...,  1.5234375 ,
         0.04467773, -0.85546875],
       [-0.83203125,  0.6015625 ,  0.44726562, ...,  1.015625  ,
        -2.6875    ,  2.484375  ],
       [-0.14746094, -0.4921875 ,  2.171875  , ...,  1.6953125 ,
        -0.578125  , -0.09130859]], dtype=float32)


k_mean_space
(20, 2)


array([[86.91793 , 49.500763],
       [51.244244, 83.47274 ],
       [46.803444, 76.92171 ],
       [50.649815, 83.47774 ],
       [45.909904, 76.788994],
       [91.01521 , 74.93631 ],
       [49.218025, 76.26941 ],
       [53.25785 , 81.794334],
       [50.770077, 79.74806 ],
       [58.975414, 49.33152 ],
       [61.883945, 84.771645],
       [84.027054, 70.24468 ],
       [62.931812, 62.637764],
       [85.57137 , 45.022083],
       [85.6081  , 45.062553],
       [55.217533, 71.224594],
       [61.874386, 78.77958 ],
       [81.90001 , 91.741264],
       [67.99484 , 84.34144 ],
       [87.44221 , 55.269142]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-193.90588188, -124.06951714])


closest
(2,)


array([ 4, 13])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 3.0781, -2.4375, -0.9570,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.1562,  3.1406,  1.1328,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.4688,  3.2812,  3.5938,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-2.0469,  0.1738, -1.2969,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.8438, -0.8164, -2.0938,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.6094,  3.4219,  0.6094,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9847e-01, 9.1048e-04, 1.3963e-04,  ..., 5.3397e-20, 3.2387e-20,
         1.9644e-20],
        [1.0000e+00, 1.2099e-06, 9.9312e-08,  ..., 4.5812e-23, 2.4521e-23,
         1.2330e-23],
        [9.9997e-01, 2.1445e-05, 2.9022e-06,  ..., 1.6149e-21, 1.1815e-21,
         1.1099e-21],
        ...,
        [1.0000e+00, 1.4450e-07, 1.1254e-07,  ..., 1.0458e-27, 9.7276e-29,
         1.6904e-29],
        [9.8453e-01, 1.0937e-02, 4.0235e-03,  ..., 3.9803e-23, 3.5126e-23,
         4.7538e-24],
        [1.0000e+00, 2.9023e-06, 9.4224e-07,  ..., 7.7043e-21, 4.6729e-21,
         8.6440e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9985, 0.9994, 0.9995,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9845, 0.9955, 0.9995,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([30])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  7,  8,  8,  9, 10, 11, 12, 12, 13, 13,
        13, 13, 14, 14, 14, 14, 15, 16, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([30, 80])


tensor([[    1, 32010,  1724,  ...,   393,   723,   367],
        [    1, 32010,  1724,  ...,   393,  3611,  5771],
        [    1, 32010,  1724,  ...,   393,  3611, 14393],
        ...,
        [    1, 32010,  1724,  ...,   278,  6432,   342],
        [    1, 32010,  1724,  ...,   607, 15754,   366],
        [    1, 32010,  1724,  ...,   607, 15754, 29915]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([30])


tensor([-11.8083, -13.0347, -14.5347, -14.1683, -15.1683, -13.8002, -16.4687,
        -17.7187, -17.7187, -18.2187, -18.2187, -17.4516, -17.7016, -13.0175,
        -15.1213, -15.1213, -16.2047, -16.2047, -16.2047, -16.2047, -16.9547,
        -16.9547, -16.9547, -16.9547, -15.5150, -16.7455, -16.7455, -17.4204,
        -17.2112, -19.7112], device='cuda:0')


new_candidate_toks
torch.Size([30, 1])


tensor([[ 6167],
        [  304],
        [  304],
        [  304],
        [  304],
        [  341],
        [  367],
        [ 5771],
        [14393],
        [ 5771],
        [14393],
        [  276],
        [  367],
        [  373],
        [ 2777],
        [ 1342],
        [ 6167],
        [  373],
        [ 9548],
        [16852],
        [ 6167],
        [  373],
        [ 9548],
        [  278],
        [ 1139],
        [  297],
        [29889],
        [  616],
        [29915],
        [29879]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([30])


tensor([-1.5337e-03, -1.4305e-06, -2.7180e-05, -4.7684e-06, -2.9564e-05,
        -3.5763e-07, -1.5176e-02, -2.7761e-01, -1.5276e+00, -4.2773e-01,
        -1.1777e+00, -6.2202e-02, -1.1439e-03, -1.3056e-02, -3.1823e-01,
        -1.3182e+00, -5.4213e-01, -1.4171e+00, -2.6671e+00, -3.4171e+00,
        -5.5710e-01, -1.3071e+00, -2.9321e+00, -3.4321e+00, -4.8721e-04,
        -6.7359e-01, -9.2359e-01, -4.7684e-07, -1.5591e-02, -4.5300e-06],
       device='cuda:0')


new_candidates
torch.Size([30, 81])


tensor([[    1, 32010,  1724,  ...,   723,   367,  6167],
        [    1, 32010,  1724,  ...,  3611,  5771,   304],
        [    1, 32010,  1724,  ...,  3611, 14393,   304],
        ...,
        [    1, 32010,  1724,  ...,  6432,   342,   616],
        [    1, 32010,  1724,  ..., 15754,   366, 29915],
        [    1, 32010,  1724,  ..., 15754, 29915, 29879]], device='cuda:0')


new_candidate_logprobs
torch.Size([30])


tensor([-11.8099, -13.0347, -14.5347, -14.1683, -15.1684, -13.8002, -16.4839,
        -17.9963, -19.2463, -18.6464, -19.3964, -17.5138, -17.7027, -13.0306,
        -15.4395, -16.4395, -16.7468, -17.6218, -18.8718, -19.6218, -17.5118,
        -18.2618, -19.8868, -20.3868, -15.5155, -17.4191, -17.6691, -17.4204,
        -17.2268, -19.7113], device='cuda:0')

infer end: GPU memory used: 17095 MB.
event: level
id: 69
data: [{"content": "Olymp", "parent": 0, "prob": -11.8098726272583}, {"content": "to", "parent": 1, "prob": -13.034709930419922}, {"content": "to", "parent": 2, "prob": -14.534735679626465}, {"content": "to", "parent": 3, "prob": -14.168325424194336}, {"content": "to", "parent": 4, "prob": -15.168350219726562}, {"content": "M", "parent": 5, "prob": -13.800219535827637}, {"content": "be", "parent": 6, "prob": -16.4838924407959}, {"content": "goes", "parent": 7, "prob": -17.996326446533203}, {"content": "belongs", "parent": 7, "prob": -19.246326446533203}, {"content": "goes", "parent": 8, "prob": -18.64644432067871}, {"content": "belongs", "parent": 8, "prob": -19.39644432067871}, {"content": "re", "parent": 9, "prob": -17.51377296447754}, {"content": "be", "parent": 10, "prob": -17.702714920043945}, {"content": "on", "parent": 11, "prob": -13.030550003051758}, {"content": "instance", "parent": 12, "prob": -15.439489364624023}, {"

array([[-0.04345703,  0.1171875 , -0.12890625, ..., -0.30664062,
         0.6484375 , -0.53515625],
       [-1.84375   , -1.4453125 ,  1.4921875 , ...,  1.921875  ,
        -0.31640625,  1.328125  ],
       [-1.4921875 , -1.3984375 ,  1.6015625 , ...,  2.140625  ,
        -0.21289062,  1.4609375 ],
       ...,
       [ 0.73828125, -1.546875  ,  2.03125   , ...,  0.3671875 ,
         0.2890625 , -1.234375  ],
       [-0.6640625 ,  0.48632812, -0.7109375 , ..., -2.234375  ,
        -1.328125  , -1.0078125 ],
       [-0.29296875, -0.23339844, -0.40429688, ...,  0.40429688,
        -0.5859375 ,  3.        ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 99.17431 ,  75.199905],
       [ 92.62412 ,  49.395584],
       [ 92.908165,  49.52098 ],
       [ 92.64023 ,  49.37547 ],
       [ 92.81446 ,  49.28722 ],
       [100.62189 ,  84.65215 ],
       [ 91.74976 ,  50.083984],
       [ 39.842716,  81.27502 ],
       [ 36.961063,  74.526276],
       [ 39.6752  ,  81.009995],
       [ 36.32027 ,  74.472595],
       [ 61.05286 ,  83.56231 ],
       [ 61.020645,  81.909225],
       [ 95.04823 ,  70.40561 ],
       [ 71.1771  ,  63.261818],
       [ 72.597466,  64.181694],
       [ 99.52915 ,  75.22681 ],
       [ 93.59718 ,  63.92911 ],
       [ 98.241646,  82.89084 ],
       [ 85.89197 ,  74.31235 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-110.50202942, -216.77193165])


closest
(2,)


array([10,  4])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.1670, -6.5625, -0.0225,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.4375, -3.2812,  0.2471,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.4688, -1.7109,  0.9648,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.6094, -1.4375, -0.0996,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.9297,  5.5938, -0.8945,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.5938,  6.1562, -0.2930,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9979e-01, 1.2338e-04, 4.0057e-05,  ..., 8.1996e-21, 8.1996e-21,
         2.0732e-21],
        [9.9346e-01, 5.9073e-03, 1.7839e-04,  ..., 4.6886e-20, 2.8438e-20,
         1.3433e-20],
        [9.9535e-01, 4.0678e-03, 1.2284e-04,  ..., 4.6976e-20, 3.6585e-20,
         2.8492e-20],
        ...,
        [8.9685e-01, 9.4528e-02, 6.8476e-03,  ..., 7.7525e-22, 6.8415e-22,
         9.2590e-23],
        [1.0000e+00, 1.2099e-06, 9.4224e-07,  ..., 5.5259e-23, 3.3517e-23,
         1.5832e-23],
        [7.2302e-01, 2.6598e-01, 4.8716e-03,  ..., 3.8285e-21, 3.3786e-21,
         3.1739e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9998, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9935, 0.9994, 0.9995,  ..., 1.0000, 1.0000, 1.0000],
        [0.9954, 0.9994, 0.9995,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.8969, 0.9914, 0.9982,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.7230, 0.9890, 0.9939,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([25])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 11, 11, 12, 12, 13, 14,
        15, 16, 17, 17, 18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([25, 81])


tensor([[    1, 32010,  1724,  ...,   723,   367,  6167],
        [    1, 32010,  1724,  ...,  3611,  5771,   304],
        [    1, 32010,  1724,  ...,  3611, 14393,   304],
        ...,
        [    1, 32010,  1724,  ...,  2777, 29892,  9548],
        [    1, 32010,  1724,  ...,  2777, 29892, 16852],
        [    1, 32010,  1724,  ...,  2777, 29892, 16852]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([25])


tensor([-11.8099, -13.0347, -14.5347, -14.1683, -15.1684, -13.8002, -16.4839,
        -17.9963, -19.2463, -18.6464, -19.3964, -17.5138, -17.5138, -17.5138,
        -17.7027, -17.7027, -13.0306, -15.4395, -16.4395, -16.7468, -17.6218,
        -17.6218, -18.8718, -19.6218, -19.6218], device='cuda:0')


new_candidate_toks
torch.Size([25, 1])


tensor([[  375],
        [ 6167],
        [ 6167],
        [ 6167],
        [ 6167],
        [  787],
        [ 6167],
        [  304],
        [  304],
        [  304],
        [  304],
        [ 5517],
        [16811],
        [ 3063],
        [16811],
        [ 8852],
        [16852],
        [29892],
        [29892],
        [  375],
        [16852],
        [ 9548],
        [  375],
        [  756],
        [29915]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([25])


tensor([-2.0679e-04, -6.5590e-03, -4.6593e-03, -6.5410e-03, -6.7282e-03,
         0.0000e+00, -1.2830e-03, -7.1526e-07, -1.6570e-05, -2.5034e-06,
        -1.7524e-05, -7.5192e-01, -1.0019e+00, -2.0019e+00, -1.6619e-01,
        -2.1662e+00, -9.3402e-04, -8.5831e-06, -7.9871e-06, -1.5427e-04,
        -1.0886e-01, -2.3589e+00, -3.6955e-06, -3.2432e-01, -1.3243e+00],
       device='cuda:0')


new_candidates
torch.Size([25, 82])


tensor([[    1, 32010,  1724,  ...,   367,  6167,   375],
        [    1, 32010,  1724,  ...,  5771,   304,  6167],
        [    1, 32010,  1724,  ..., 14393,   304,  6167],
        ...,
        [    1, 32010,  1724,  ..., 29892,  9548,   375],
        [    1, 32010,  1724,  ..., 29892, 16852,   756],
        [    1, 32010,  1724,  ..., 29892, 16852, 29915]], device='cuda:0')


new_candidate_logprobs
torch.Size([25])


tensor([-11.8101, -13.0413, -14.5394, -14.1749, -15.1751, -13.8002, -16.4852,
        -17.9963, -19.2463, -18.6464, -19.3965, -18.2657, -18.5157, -19.5157,
        -17.8689, -19.8689, -13.0315, -15.4395, -16.4395, -16.7470, -17.7307,
        -19.9807, -18.8718, -19.9461, -20.9461], device='cuda:0')

infer end: GPU memory used: 17295 MB.
event: level
id: 70
data: [{"content": "us", "parent": 0, "prob": -11.810079574584961}, {"content": "Olymp", "parent": 1, "prob": -13.041269302368164}, {"content": "Olymp", "parent": 2, "prob": -14.539395332336426}, {"content": "Olymp", "parent": 3, "prob": -14.174866676330566}, {"content": "Olymp", "parent": 4, "prob": -15.175078392028809}, {"content": "ons", "parent": 5, "prob": -13.800219535827637}, {"content": "Olymp", "parent": 6, "prob": -16.48517608642578}, {"content": "to", "parent": 7, "prob": -17.996326446533203}, {"content": "to", "parent": 8, "prob": -19.2463436126709}, {"content": "to", "parent": 9, "prob": -18.646446228027344}, {"content": "to", "parent": 10, "prob": -19.396461486816406}, {"content": "likely", "parent": 11, "prob": -18.26569366455078}, {"content": "referring", "parent": 11, "prob": -18.51569366455078}, {"content": "looking", "parent": 11, "prob": -19.51569366455078}, {"content": "referring", "parent": 12, "prob": -17.

array([[-2.359375  ,  0.51171875,  0.06176758, ..., -2.390625  ,
         0.8984375 ,  0.49023438],
       [-0.03686523, -0.2734375 , -0.296875  , ..., -0.359375  ,
         0.78125   , -1.0390625 ],
       [-0.1796875 , -0.1796875 , -0.23242188, ..., -0.46484375,
         0.9609375 , -0.85546875],
       ...,
       [-1.421875  , -1.40625   ,  1.8984375 , ...,  2.234375  ,
        -0.32617188,  0.46289062],
       [-1.3984375 , -1.3203125 ,  1.953125  , ...,  2.296875  ,
        -0.34375   ,  0.5625    ],
       [-2.21875   ,  0.30273438,  0.62890625, ..., -2.21875   ,
         0.31640625,  0.8125    ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 99.943794,  70.68219 ],
       [103.765526,  56.697575],
       [103.81442 ,  56.54302 ],
       [103.922935,  56.555897],
       [104.108   ,  56.609062],
       [ 99.353424,  75.55601 ],
       [104.49919 ,  57.00364 ],
       [ 19.059593,  81.67839 ],
       [ 18.730984,  81.729805],
       [ 19.01748 ,  81.705025],
       [ 17.638416,  81.68027 ],
       [105.505844,  73.76627 ],
       [ 92.74562 ,  60.16428 ],
       [103.38122 ,  72.5303  ],
       [ 92.53663 ,  61.28595 ],
       [100.77237 ,  71.026535],
       [ 97.93475 ,  68.92442 ],
       [ 35.83639 ,  80.21609 ],
       [ 36.05326 ,  80.21791 ],
       [100.09252 ,  70.78084 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-107.16457272, -222.8394289 ])


closest
(2,)


array([10,  2])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 5.2500, -2.4375, -4.3438,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0752, -7.1250,  0.2236,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0586, -7.0625,  0.1191,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.5938, -0.9336, -1.0703,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.4062, -1.0312, -1.2422,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.1562, -2.4062, -4.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 7.3382e-07, 1.6374e-07,  ..., 2.7512e-24, 1.4726e-24,
         4.7809e-25],
        [9.9938e-01, 3.7989e-04, 2.0334e-04,  ..., 1.7351e-20, 1.3513e-20,
         1.4243e-21],
        [9.9956e-01, 2.3046e-04, 1.7948e-04,  ..., 9.2891e-21, 6.3843e-21,
         9.7907e-22],
        ...,
        [4.9487e-01, 3.4012e-01, 4.6030e-02,  ..., 1.7257e-19, 5.6026e-20,
         4.9443e-20],
        [4.7152e-01, 3.6722e-01, 4.3859e-02,  ..., 2.1114e-19, 5.3384e-20,
         5.3384e-20],
        [1.0000e+00, 4.4508e-07, 3.0590e-07,  ..., 8.4743e-24, 1.8909e-24,
         1.6687e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9994, 0.9998, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9996, 0.9998, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.4949, 0.8350, 0.8810,  ..., 1.0000, 1.0000, 1.0000],
        [0.4715, 0.8387, 0.8826,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([27])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 13, 14, 15, 16,
        17, 17, 17, 17, 18, 18, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([27, 82])


tensor([[    1, 32010,  1724,  ...,   367,  6167,   375],
        [    1, 32010,  1724,  ...,  5771,   304,  6167],
        [    1, 32010,  1724,  ..., 14393,   304,  6167],
        ...,
        [    1, 32010,  1724,  ...,  1152,  1342, 29892],
        [    1, 32010,  1724,  ...,  1152,  1342, 29892],
        [    1, 32010,  1724,  ..., 29892,  6167,   375]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([27])


tensor([-11.8101, -13.0413, -14.5394, -14.1749, -15.1751, -13.8002, -16.4852,
        -17.9963, -19.2463, -18.6464, -19.3965, -18.2657, -18.5157, -19.5157,
        -19.5157, -17.8689, -19.8689, -13.0315, -15.4395, -15.4395, -15.4395,
        -15.4395, -16.4395, -16.4395, -16.4395, -16.4395, -16.7470],
       device='cuda:0')


new_candidate_toks
torch.Size([27, 1])


tensor([[  341],
        [  375],
        [  375],
        [  375],
        [  375],
        [  373],
        [  375],
        [ 6167],
        [ 6167],
        [ 6167],
        [ 6167],
        [16811],
        [  304],
        [  472],
        [  363],
        [  304],
        [  297],
        [29892],
        [ 6167],
        [  373],
        [ 9548],
        [  278],
        [ 6167],
        [  373],
        [ 9548],
        [  278],
        [  341]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([27])


tensor([-1.1921e-06, -6.2396e-04, -4.4123e-04, -5.8597e-04, -3.9985e-04,
        -9.3382e-03, -2.6820e-04, -5.9297e-03, -4.8256e-03, -6.6158e-03,
        -6.0702e-03, -4.4876e-02, -4.7446e-05, -2.0222e-01, -1.7022e+00,
        -2.4319e-05, -6.7178e-03, -4.9188e-02, -7.0346e-01, -1.0785e+00,
        -3.0785e+00, -3.4535e+00, -7.5178e-01, -1.0018e+00, -3.1268e+00,
        -3.3768e+00, -1.1921e-06], device='cuda:0')


new_candidates
torch.Size([27, 83])


tensor([[    1, 32010,  1724,  ...,  6167,   375,   341],
        [    1, 32010,  1724,  ...,   304,  6167,   375],
        [    1, 32010,  1724,  ...,   304,  6167,   375],
        ...,
        [    1, 32010,  1724,  ...,  1342, 29892,  9548],
        [    1, 32010,  1724,  ...,  1342, 29892,   278],
        [    1, 32010,  1724,  ...,  6167,   375,   341]], device='cuda:0')


new_candidate_logprobs
torch.Size([27])


tensor([-11.8101, -13.0419, -14.5398, -14.1755, -15.1755, -13.8096, -16.4854,
        -18.0023, -19.2512, -18.6531, -19.4025, -18.3106, -18.5157, -19.7179,
        -21.2179, -17.8689, -19.8756, -13.0807, -16.1430, -16.5180, -18.5180,
        -18.8930, -17.1913, -17.4413, -19.5663, -19.8163, -16.7470],
       device='cuda:0')

infer end: GPU memory used: 17497 MB.
event: level
id: 71
data: [{"content": "M", "parent": 0, "prob": -11.810080528259277}, {"content": "us", "parent": 1, "prob": -13.041893005371094}, {"content": "us", "parent": 2, "prob": -14.539836883544922}, {"content": "us", "parent": 3, "prob": -14.17545223236084}, {"content": "us", "parent": 4, "prob": -15.175477981567383}, {"content": "on", "parent": 5, "prob": -13.809557914733887}, {"content": "us", "parent": 6, "prob": -16.485445022583008}, {"content": "Olymp", "parent": 7, "prob": -18.002256393432617}, {"content": "Olymp", "parent": 8, "prob": -19.251169204711914}, {"content": "Olymp", "parent": 9, "prob": -18.65306282043457}, {"content": "Olymp", "parent": 10, "prob": -19.40253257751465}, {"content": "referring", "parent": 11, "prob": -18.310569763183594}, {"content": "to", "parent": 12, "prob": -18.5157413482666}, {"content": "at", "parent": 13, "prob": -19.71790885925293}, {"content": "for", "parent": 13, "prob": -21.21790885925293}, {"c

array([[-0.01867676,  0.64453125, -1.78125   , ..., -2.609375  ,
         0.00439453,  1.109375  ],
       [-2.15625   ,  0.6640625 ,  0.27148438, ..., -2.75      ,
         0.92578125,  0.85546875],
       [-2.046875  ,  0.81640625,  0.12011719, ..., -2.796875  ,
         0.79296875,  0.77734375],
       ...,
       [-2.796875  ,  0.24804688,  0.953125  , ..., -0.65234375,
         0.234375  ,  0.7734375 ],
       [-0.08203125,  0.12597656,  0.13574219, ..., -1.3828125 ,
         0.56640625, -1.0234375 ],
       [ 0.78515625, -1.65625   ,  2.125     , ...,  0.27734375,
         0.33007812, -1.265625  ]], dtype=float32)


k_mean_space
(20, 2)


array([[87.17218 , 98.75724 ],
       [50.535217, 89.14092 ],
       [50.79031 , 89.827805],
       [50.636177, 89.35108 ],
       [50.79877 , 89.73145 ],
       [95.0828  , 71.896324],
       [51.03296 , 89.568085],
       [51.838528, 93.859886],
       [51.796757, 93.891335],
       [51.975563, 94.016174],
       [52.037308, 94.14981 ],
       [87.77951 , 68.417046],
       [90.39258 , 36.003204],
       [90.84194 , 36.415787],
       [90.78735 , 36.7398  ],
       [90.402016, 35.863533],
       [91.49264 , 39.66328 ],
       [92.25157 , 75.67038 ],
       [52.39474 , 94.64322 ],
       [94.34655 , 61.494324]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-176.68017006, -158.91486454])


closest
(2,)


array([ 1, 15])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.4707, -5.1250, -3.0938,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.9688, -1.5000, -3.8438,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.2500, -1.2656, -3.6562,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 3.9688, -4.4062, -2.4531,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.4375, -5.9688, -0.5547,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.1094, -1.6719,  0.3594,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 1.1628e-10, 6.2241e-11,  ..., 1.8538e-26, 1.5217e-27,
         9.2293e-28],
        [1.0000e+00, 5.0435e-07, 1.1254e-07,  ..., 1.6687e-24, 7.8824e-25,
         6.1388e-25],
        [1.0000e+00, 3.4663e-07, 1.1254e-07,  ..., 1.1469e-24, 6.9562e-25,
         6.9562e-25],
        ...,
        [9.6700e-01, 2.2742e-02, 7.3831e-03,  ..., 1.0627e-22, 7.7749e-23,
         5.3436e-23],
        [9.9986e-01, 7.4841e-05, 5.1438e-05,  ..., 7.2366e-21, 6.3863e-21,
         4.9736e-21],
        [9.0795e-01, 8.4453e-02, 6.1177e-03,  ..., 4.7603e-22, 4.7603e-22,
         7.3001e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9670, 0.9897, 0.9971,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9080, 0.9924, 0.9985,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 83])


tensor([[    1, 32010,  1724,  ...,  6167,   375,   341],
        [    1, 32010,  1724,  ...,   304,  6167,   375],
        [    1, 32010,  1724,  ...,   304,  6167,   375],
        ...,
        [    1, 32010,  1724,  ...,   373, 16852, 29892],
        [    1, 32010,  1724,  ...,  2777, 29892,  6167],
        [    1, 32010,  1724,  ...,  2777, 29892,   373]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([20])


tensor([-11.8101, -13.0419, -14.5398, -14.1755, -15.1755, -13.8096, -16.4854,
        -18.0023, -19.2512, -18.6531, -19.4025, -18.3106, -18.5157, -19.7179,
        -21.2179, -17.8689, -19.8756, -13.0807, -16.1430, -16.5180],
       device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[  787],
        [  341],
        [  341],
        [  341],
        [  341],
        [16852],
        [  341],
        [  375],
        [  375],
        [  375],
        [  375],
        [  304],
        [ 6167],
        [ 6167],
        [ 6167],
        [ 6167],
        [ 6167],
        [  607],
        [  375],
        [16852]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([ 0.0000e+00, -9.5367e-07, -7.1526e-07, -9.5367e-07, -8.3447e-07,
        -6.5032e-04, -1.0729e-06, -4.5232e-04, -2.7929e-04, -6.4209e-04,
        -3.4970e-04, -7.7370e-05, -2.0164e-02, -4.7812e-02, -1.8809e-02,
        -2.6199e-02, -2.6369e-02, -3.3557e-02, -1.3948e-04, -9.6565e-02],
       device='cuda:0')


new_candidates
torch.Size([20, 84])


tensor([[    1, 32010,  1724,  ...,   375,   341,   787],
        [    1, 32010,  1724,  ...,  6167,   375,   341],
        [    1, 32010,  1724,  ...,  6167,   375,   341],
        ...,
        [    1, 32010,  1724,  ..., 16852, 29892,   607],
        [    1, 32010,  1724,  ..., 29892,  6167,   375],
        [    1, 32010,  1724,  ..., 29892,   373, 16852]], device='cuda:0')


new_candidate_logprobs
torch.Size([20])


tensor([-11.8101, -13.0419, -14.5398, -14.1755, -15.1755, -13.8102, -16.4854,
        -18.0027, -19.2514, -18.6537, -19.4029, -18.3106, -18.5359, -19.7657,
        -21.2367, -17.8951, -19.9020, -13.1142, -16.1431, -16.6145],
       device='cuda:0')

infer end: GPU memory used: 17701 MB.
event: level
id: 72
data: [{"content": "ons", "parent": 0, "prob": -11.810080528259277}, {"content": "M", "parent": 1, "prob": -13.04189395904541}, {"content": "M", "parent": 2, "prob": -14.539837837219238}, {"content": "M", "parent": 3, "prob": -14.175453186035156}, {"content": "M", "parent": 4, "prob": -15.1754789352417}, {"content": "Mars", "parent": 5, "prob": -13.810208320617676}, {"content": "M", "parent": 6, "prob": -16.48544692993164}, {"content": "us", "parent": 7, "prob": -18.002708435058594}, {"content": "us", "parent": 8, "prob": -19.251447677612305}, {"content": "us", "parent": 9, "prob": -18.653705596923828}, {"content": "us", "parent": 10, "prob": -19.402881622314453}, {"content": "to", "parent": 11, "prob": -18.31064796447754}, {"content": "Olymp", "parent": 12, "prob": -18.535905838012695}, {"content": "Olymp", "parent": 13, "prob": -19.76572036743164}, {"content": "Olymp", "parent": 14, "prob": -21.236717224121094}, {"content": "O

array([[-1.9765625 , -1.515625  ,  2.46875   , ...,  1.90625   ,
         2.203125  ,  4.34375   ],
       [ 0.11425781,  0.875     , -1.78125   , ..., -2.65625   ,
        -0.00860596,  1.390625  ],
       [-0.06982422,  0.9453125 , -1.875     , ..., -2.484375  ,
         0.02600098,  1.3828125 ],
       ...,
       [-2.015625  , -3.        , -0.578125  , ..., -0.31445312,
        -0.7265625 , -0.20800781],
       [-2.21875   ,  0.20996094,  0.5859375 , ..., -2.375     ,
         0.42578125,  0.8125    ],
       [-0.44726562, -1.1640625 ,  1.2109375 , ...,  0.41796875,
         1.390625  ,  1.375     ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 71.80375  , 109.23265  ],
       [ 60.842003 , 110.79015  ],
       [ 60.392063 , 110.647026 ],
       [ 60.73936  , 110.73181  ],
       [ 60.51748  , 110.67817  ],
       [ 76.23077  , 110.68842  ],
       [ 60.81355  , 110.92704  ],
       [ 55.60958  , 101.09985  ],
       [ 55.7762   , 101.19464  ],
       [ 55.640854 , 101.13864  ],
       [ 55.92829  , 101.405815 ],
       [ 79.09374  , 107.97988  ],
       [ 85.64003  ,   6.356444 ],
       [ 85.75028  ,   6.6188693],
       [ 85.81042  ,   7.70086  ],
       [ 85.512375 ,   7.064469 ],
       [ 85.84794  ,   7.9301114],
       [ 77.610374 , 110.45899  ],
       [ 56.41156  , 101.28131  ],
       [ 78.3531   , 110.51879  ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-238.53165054,  -97.33545494])


closest
(2,)


array([ 7, 12])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.1226, -4.4062, -5.3438,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0962, -4.4688, -3.0781,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1680, -4.2812, -2.8125,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 4.5000, -4.1875, -4.7500,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.2188, -2.1562, -3.9531,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2695,  1.4062, -1.4766,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9175e-01, 4.0530e-03, 2.4583e-03,  ..., 2.4079e-24, 1.2889e-24,
         1.9765e-25],
        [1.0000e+00, 2.1724e-10, 1.0262e-10,  ..., 1.1244e-26, 1.7243e-27,
         8.1448e-28],
        [1.0000e+00, 3.1609e-10, 1.4931e-10,  ..., 2.3803e-26, 6.8196e-27,
         2.8428e-27],
        ...,
        [8.5567e-01, 1.1580e-01, 1.2206e-02,  ..., 4.4862e-22, 1.8701e-22,
         6.8798e-23],
        [1.0000e+00, 3.0590e-07, 3.0590e-07,  ..., 9.6027e-24, 4.0030e-24,
         2.4279e-24],
        [9.9981e-01, 1.0889e-04, 3.5351e-05,  ..., 9.1091e-23, 7.5517e-23,
         5.5249e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9917, 0.9958, 0.9983,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.8557, 0.9715, 0.9837,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  1,  2,  3,  4,  5,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 84])


tensor([[    1, 32010,  1724,  ...,   375,   341,   787],
        [    1, 32010,  1724,  ...,  6167,   375,   341],
        [    1, 32010,  1724,  ...,  6167,   375,   341],
        ...,
        [    1, 32010,  1724,  ..., 16852, 29892,   607],
        [    1, 32010,  1724,  ..., 29892,  6167,   375],
        [    1, 32010,  1724,  ..., 29892,   373, 16852]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([22])


tensor([-11.8101, -13.0419, -14.5398, -14.1755, -15.1755, -13.8102, -13.8102,
        -16.4854, -18.0027, -19.2514, -18.6537, -19.4029, -18.3106, -18.5359,
        -19.7657, -21.2367, -17.8951, -19.9020, -13.1142, -13.1142, -16.1431,
        -16.6145], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[  373],
        [  787],
        [  787],
        [  787],
        [  787],
        [  338],
        [ 8640],
        [  787],
        [  341],
        [  341],
        [  341],
        [  341],
        [ 6167],
        [  375],
        [  375],
        [  375],
        [  375],
        [  375],
        [15028],
        [  338],
        [  341],
        [29892]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([-8.2886e-03,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
        -4.3110e-01, -1.1811e+00,  0.0000e+00, -9.5367e-07, -7.1526e-07,
        -9.5367e-07, -1.0729e-06, -2.5857e-02, -8.6192e-05, -5.0069e-05,
        -8.0708e-05, -8.4165e-05, -3.3618e-05, -1.5587e-01, -2.1559e+00,
        -1.4305e-06, -1.8545e-04], device='cuda:0')


new_candidates
torch.Size([22, 85])


tensor([[    1, 32010,  1724,  ...,   341,   787,   373],
        [    1, 32010,  1724,  ...,   375,   341,   787],
        [    1, 32010,  1724,  ...,   375,   341,   787],
        ...,
        [    1, 32010,  1724,  ..., 29892,   607,   338],
        [    1, 32010,  1724,  ...,  6167,   375,   341],
        [    1, 32010,  1724,  ...,   373, 16852, 29892]], device='cuda:0')


new_candidate_logprobs
torch.Size([22])


tensor([-11.8184, -13.0419, -14.5398, -14.1755, -15.1755, -14.2413, -14.9913,
        -16.4854, -18.0027, -19.2514, -18.6537, -19.4029, -18.3365, -18.5360,
        -19.7658, -21.2368, -17.8952, -19.9020, -13.2701, -15.2701, -16.1431,
        -16.6147], device='cuda:0')

infer end: GPU memory used: 17907 MB.
event: level
id: 73
data: [{"content": "on", "parent": 0, "prob": -11.818368911743164}, {"content": "ons", "parent": 1, "prob": -13.04189395904541}, {"content": "ons", "parent": 2, "prob": -14.539837837219238}, {"content": "ons", "parent": 3, "prob": -14.175453186035156}, {"content": "ons", "parent": 4, "prob": -15.1754789352417}, {"content": "is", "parent": 5, "prob": -14.241308212280273}, {"content": "holds", "parent": 5, "prob": -14.991308212280273}, {"content": "ons", "parent": 6, "prob": -16.48544692993164}, {"content": "M", "parent": 7, "prob": -18.002710342407227}, {"content": "M", "parent": 8, "prob": -19.251447677612305}, {"content": "M", "parent": 9, "prob": -18.65370750427246}, {"content": "M", "parent": 10, "prob": -19.402883529663086}, {"content": "Olymp", "parent": 11, "prob": -18.336503982543945}, {"content": "us", "parent": 12, "prob": -18.535991668701172}, {"content": "us", "parent": 13, "prob": -19.765769958496094}, {"content": "u

array([[-0.17089844, -0.02587891,  2.890625  , ...,  0.7109375 ,
        -0.20117188, -0.70703125],
       [-2.015625  , -0.93359375,  2.640625  , ...,  2.03125   ,
         2.421875  ,  4.4375    ],
       [-1.859375  , -0.91015625,  2.65625   , ...,  2.        ,
         2.21875   ,  4.5625    ],
       ...,
       [-1.9375    ,  0.61328125,  0.09130859, ..., -2.21875   ,
         0.796875  ,  0.46484375],
       [-2.5625    ,  0.625     ,  1.140625  , ..., -1.015625  ,
         3.234375  ,  1.15625   ],
       [-0.28515625, -1.4375    ,  1.3671875 , ..., -0.26171875,
         0.9140625 ,  0.44726562]], dtype=float32)


k_mean_space
(20, 2)


array([[115.00567  ,  79.41984  ],
       [108.400505 ,  51.929726 ],
       [108.544655 ,  51.993458 ],
       [108.30622  ,  51.89346  ],
       [108.41137  ,  51.863796 ],
       [112.13412  ,  72.085526 ],
       [113.1897   ,  81.276855 ],
       [108.286316 ,  51.809082 ],
       [  5.266364 ,  89.10621  ],
       [  3.6882894,  88.727585 ],
       [  3.6683803,  89.10953  ],
       [  4.650941 ,  88.807816 ],
       [110.57213  ,  82.11638  ],
       [105.52327  ,  55.67348  ],
       [105.639206 ,  55.984394 ],
       [105.36247  ,  55.92411  ],
       [105.59615  ,  55.966076 ],
       [105.92304  ,  56.56492  ],
       [116.6697   ,  77.87009  ],
       [112.45203  ,  72.618996 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -75.31074905, -258.68158245])


closest
(2,)


array([10,  7])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.0938,  0.9219, -1.5156,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.1172, -3.9844, -4.3125,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.5000, -3.1562, -4.0312,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 4.8750, -1.8672, -4.7188,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.0625,  1.5000, -3.6250,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.7656, -4.7500, -2.9688,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9943e-01, 5.5277e-04, 1.3000e-05,  ..., 1.0115e-24, 1.0115e-24,
         6.9522e-25],
        [9.8273e-01, 1.0917e-02, 3.5443e-03,  ..., 1.8582e-24, 1.8582e-24,
         5.3239e-25],
        [9.7903e-01, 1.3965e-02, 4.5338e-03,  ..., 2.6935e-24, 1.8512e-24,
         3.6453e-25],
        ...,
        [1.0000e+00, 4.4508e-07, 3.9279e-07,  ..., 3.5326e-24, 1.4726e-24,
         1.2996e-24],
        [8.4400e-01, 1.2943e-01, 1.2039e-02,  ..., 8.7133e-23, 7.6895e-23,
         1.3362e-23],
        [9.4909e-01, 2.8660e-02, 4.9804e-03,  ..., 1.3526e-21, 1.1937e-21,
         5.6385e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9994, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9827, 0.9936, 0.9972,  ..., 1.0000, 1.0000, 1.0000],
        [0.9790, 0.9930, 0.9975,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8440, 0.9734, 0.9855,  ..., 1.0000, 1.0000, 1.0000],
        [0.9491, 0.9778, 0.9827,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  1,  2,  3,  4,  5,  6,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 85])


tensor([[    1, 32010,  1724,  ...,   341,   787,   373],
        [    1, 32010,  1724,  ...,   375,   341,   787],
        [    1, 32010,  1724,  ...,   375,   341,   787],
        ...,
        [    1, 32010,  1724,  ..., 29892,   607, 15028],
        [    1, 32010,  1724,  ..., 29892,   607, 15028],
        [    1, 32010,  1724,  ..., 29892,   607,   338]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([22])


tensor([-11.8184, -13.0419, -14.5398, -14.1755, -15.1755, -14.2413, -14.9913,
        -14.9913, -16.4854, -18.0027, -19.2514, -18.6537, -19.4029, -18.3365,
        -18.5360, -19.7658, -21.2368, -17.8952, -19.9020, -13.2701, -13.2701,
        -15.2701], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[16852],
        [  373],
        [  373],
        [  373],
        [  373],
        [  278],
        [  393],
        [  278],
        [  373],
        [  787],
        [  787],
        [  787],
        [  787],
        [  375],
        [  341],
        [  341],
        [  341],
        [  341],
        [  341],
        [ 1048],
        [  472],
        [ 1048]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([-5.7076e-04, -1.7425e-02, -2.1195e-02, -1.1027e-02, -1.4839e-02,
        -1.7559e-02, -1.0629e-01, -2.3563e+00, -1.6335e-02,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00, -5.7103e-05, -1.4305e-06,
        -1.3113e-06, -1.1921e-06, -1.3113e-06, -1.9074e-06, -1.6960e-01,
        -2.0446e+00, -5.2247e-02], device='cuda:0')


new_candidates
torch.Size([22, 86])


tensor([[    1, 32010,  1724,  ...,   787,   373, 16852],
        [    1, 32010,  1724,  ...,   341,   787,   373],
        [    1, 32010,  1724,  ...,   341,   787,   373],
        ...,
        [    1, 32010,  1724,  ...,   607, 15028,  1048],
        [    1, 32010,  1724,  ...,   607, 15028,   472],
        [    1, 32010,  1724,  ...,   607,   338,  1048]], device='cuda:0')


new_candidate_logprobs
torch.Size([22])


tensor([-11.8189, -13.0593, -14.5610, -14.1865, -15.1903, -14.2589, -15.0976,
        -17.3476, -16.5018, -18.0027, -19.2514, -18.6537, -19.4029, -18.3366,
        -18.5360, -19.7658, -21.2368, -17.8952, -19.9020, -13.4397, -15.3147,
        -15.3223], device='cuda:0')

infer end: GPU memory used: 18115 MB.
event: level
id: 74
data: [{"content": "Mars", "parent": 0, "prob": -11.818939208984375}, {"content": "on", "parent": 1, "prob": -13.059319496154785}, {"content": "on", "parent": 2, "prob": -14.561033248901367}, {"content": "on", "parent": 3, "prob": -14.186480522155762}, {"content": "on", "parent": 4, "prob": -15.19031810760498}, {"content": "the", "parent": 5, "prob": -14.258867263793945}, {"content": "that", "parent": 6, "prob": -15.09759521484375}, {"content": "the", "parent": 6, "prob": -17.34759521484375}, {"content": "on", "parent": 7, "prob": -16.501781463623047}, {"content": "ons", "parent": 8, "prob": -18.002710342407227}, {"content": "ons", "parent": 9, "prob": -19.251447677612305}, {"content": "ons", "parent": 10, "prob": -18.65370750427246}, {"content": "ons", "parent": 11, "prob": -19.402883529663086}, {"content": "us", "parent": 12, "prob": -18.33656120300293}, {"content": "M", "parent": 13, "prob": -18.535993576049805}, {"content": 

array([[-1.3671875 , -2.046875  ,  1.546875  , ...,  1.2578125 ,
         2.375     ,  1.796875  ],
       [-0.12988281, -0.11865234,  3.25      , ...,  0.890625  ,
        -0.21582031, -0.578125  ],
       [-0.16015625, -0.17089844,  3.203125  , ...,  0.90625   ,
        -0.36132812, -0.62890625],
       ...,
       [ 0.06835938,  0.609375  , -2.015625  , ..., -2.59375   ,
         0.07177734,  1.2109375 ],
       [ 0.23046875,  0.625     , -1.7421875 , ..., -2.46875   ,
         0.02172852,  1.4765625 ],
       [-1.546875  ,  0.23144531,  1.6484375 , ...,  0.72265625,
        -0.06396484,  0.25390625]], dtype=float32)


k_mean_space
(20, 2)


array([[ 62.87408  , 107.61663  ],
       [ 56.02478  , 114.93263  ],
       [ 55.94858  , 114.93355  ],
       [ 56.00317  , 114.85869  ],
       [ 55.985554 , 114.845024 ],
       [ 70.73561  , 107.86343  ],
       [ 74.54296  , 111.02685  ],
       [ 67.9819   , 108.05742  ],
       [ 55.918564 , 114.87714  ],
       [ 56.732395 , 108.31699  ],
       [ 56.16227  , 108.27582  ],
       [ 56.824356 , 108.20125  ],
       [ 56.385662 , 108.3016   ],
       [ 77.76265  , 105.71346  ],
       [ 91.55608  ,   6.0862923],
       [ 91.371346 ,   5.6339417],
       [ 91.35349  ,   7.475703 ],
       [ 91.49324  ,   6.3059945],
       [ 91.40014  ,   7.228124 ],
       [ 79.111855 , 115.57674  ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-239.10893822,  -97.33579826])


closest
(2,)


array([ 8, 15])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 5.1250,  1.7969, -0.5703,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.0312,  1.3828, -1.0156,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.9062,  1.7578, -1.0234,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.1152, -5.1562, -3.1875,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.2734, -4.0000, -2.0625,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.5703,  0.5156, -2.5312,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.8054e-01, 1.1917e-01, 1.0867e-04,  ..., 3.3112e-24, 2.9222e-24,
         2.5788e-24],
        [9.9926e-01, 7.0965e-04, 3.1180e-05,  ..., 1.8895e-24, 1.8895e-24,
         1.4715e-24],
        [9.9926e-01, 7.0965e-04, 2.4283e-05,  ..., 1.8895e-24, 1.6675e-24,
         1.6675e-24],
        ...,
        [1.0000e+00, 7.9920e-11, 4.8474e-11,  ..., 1.8538e-26, 7.1878e-28,
         6.3432e-28],
        [1.0000e+00, 7.9920e-11, 6.2241e-11,  ..., 5.7100e-26, 1.7243e-27,
         1.5217e-27],
        [9.9984e-01, 5.8285e-05, 5.8285e-05,  ..., 7.0943e-23, 3.7973e-23,
         9.6011e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8805, 0.9997, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        [0.9993, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9993, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([24])


tensor([ 0,  0,  1,  2,  3,  4,  5,  5,  6,  6,  7,  7,  8,  9, 10, 11, 12, 13,
        14, 15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([24, 86])


tensor([[    1, 32010,  1724,  ...,   787,   373, 16852],
        [    1, 32010,  1724,  ...,   787,   373, 16852],
        [    1, 32010,  1724,  ...,   341,   787,   373],
        ...,
        [    1, 32010,  1724,  ...,  6167,   375,   341],
        [    1, 32010,  1724,  ...,  6167,   375,   341],
        [    1, 32010,  1724,  ...,   607, 15028,  1048]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([24])


tensor([-11.8189, -11.8189, -13.0593, -14.5610, -14.1865, -15.1903, -14.2589,
        -14.2589, -15.0976, -15.0976, -17.3476, -17.3476, -16.5018, -18.0027,
        -19.2514, -18.6537, -19.4029, -18.3366, -18.5360, -19.7658, -21.2368,
        -17.8952, -19.9020, -13.4397], device='cuda:0')


new_candidate_toks
torch.Size([24, 1])


tensor([[29892],
        [29889],
        [16852],
        [16852],
        [16852],
        [16852],
        [15655],
        [ 9939],
        [ 3611],
        [21578],
        [ 2407],
        [ 3611],
        [16852],
        [  373],
        [  373],
        [  373],
        [  373],
        [  341],
        [  787],
        [  787],
        [  787],
        [  787],
        [  787],
        [29871]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([24])


tensor([-1.2722e-01, -2.1272e+00, -7.4516e-04, -7.3800e-04, -7.5064e-04,
        -1.2123e-03, -1.5891e-01, -2.2839e+00, -6.1140e-01, -8.6140e-01,
        -7.0551e-01, -7.0551e-01, -6.4889e-04, -3.5001e-02, -3.3854e-02,
        -1.7594e-02, -1.9914e-02, -1.7881e-06,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00, -1.6476e-04], device='cuda:0')


new_candidates
torch.Size([24, 87])


tensor([[    1, 32010,  1724,  ...,   373, 16852, 29892],
        [    1, 32010,  1724,  ...,   373, 16852, 29889],
        [    1, 32010,  1724,  ...,   787,   373, 16852],
        ...,
        [    1, 32010,  1724,  ...,   375,   341,   787],
        [    1, 32010,  1724,  ...,   375,   341,   787],
        [    1, 32010,  1724,  ..., 15028,  1048, 29871]], device='cuda:0')


new_candidate_logprobs
torch.Size([24])


tensor([-11.9462, -13.9462, -13.0601, -14.5618, -14.1872, -15.1915, -14.4178,
        -16.5428, -15.7090, -15.9590, -18.0531, -18.0531, -16.5024, -18.0377,
        -19.2853, -18.6713, -19.4228, -18.3366, -18.5360, -19.7658, -21.2368,
        -17.8952, -19.9020, -13.4399], device='cuda:0')

infer end: GPU memory used: 18327 MB.
event: level
id: 75
data: [{"content": ",", "parent": 0, "prob": -11.946158409118652}, {"content": ".", "parent": 0, "prob": -13.946158409118652}, {"content": "Mars", "parent": 1, "prob": -13.060064315795898}, {"content": "Mars", "parent": 2, "prob": -14.561771392822266}, {"content": "Mars", "parent": 3, "prob": -14.187231063842773}, {"content": "Mars", "parent": 4, "prob": -15.191530227661133}, {"content": "tall", "parent": 5, "prob": -14.417773246765137}, {"content": "highest", "parent": 5, "prob": -16.54277229309082}, {"content": "title", "parent": 6, "prob": -15.708991050720215}, {"content": "distinction", "parent": 6, "prob": -15.958991050720215}, {"content": "record", "parent": 7, "prob": -18.05310821533203}, {"content": "title", "parent": 7, "prob": -18.05310821533203}, {"content": "Mars", "parent": 8, "prob": -16.502429962158203}, {"content": "on", "parent": 9, "prob": -18.03771209716797}, {"content": "on", "parent": 10, "prob": -19.2853012

array([[-2.75      ,  0.30078125,  1.359375  , ..., -0.5546875 ,
         0.03198242,  1.        ],
       [-1.1953125 , -1.1015625 , -0.09326172, ..., -0.75      ,
         0.4140625 ,  0.02441406],
       [-0.9765625 , -1.5       ,  1.9765625 , ...,  1.375     ,
         2.765625  ,  1.578125  ],
       ...,
       [ 0.02478027,  0.61328125, -1.90625   , ..., -2.5625    ,
         0.11230469,  1.1875    ],
       [-1.84375   , -0.70703125,  2.421875  , ...,  2.375     ,
         2.75      ,  4.34375   ],
       [-2.125     , -1.15625   ,  2.796875  , ...,  2.203125  ,
         2.578125  ,  4.        ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 60.22588  , 102.97254  ],
       [ 57.278255 ,  92.3051   ],
       [ 37.642372 , 100.3496   ],
       [ 37.462566 , 100.01975  ],
       [ 37.80034  , 100.23373  ],
       [ 37.693523 ,  99.88144  ],
       [ 77.54628  , 110.230606 ],
       [ 68.49946  , 106.92214  ],
       [ 43.703674 , 103.21048  ],
       [ 44.872585 , 103.67157  ],
       [ 51.31373  , 107.01676  ],
       [ 52.038605 , 106.10554  ],
       [ 37.32727  , 100.231735 ],
       [ 87.2499   ,   3.238854 ],
       [ 87.18347  ,   3.225188 ],
       [ 87.29104  ,   3.0080411],
       [ 87.233246 ,   3.2162752],
       [ 89.84739  , 115.20125  ],
       [ 59.373302 , 101.58189  ],
       [ 60.258945 , 101.13976  ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-254.7684164 ,  -75.41711235])


closest
(2,)


array([12, 15])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 3.5469, -4.5938, -2.9531,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.4062, -1.9844,  3.2031,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.6875,  0.9219, -0.9531,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.0669, -5.0312, -3.1250,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.5234, -3.1719, -3.4531,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.8008, -2.3594, -4.1875,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.3884e-01, 4.6742e-02, 1.0430e-02,  ..., 2.1842e-22, 1.5012e-22,
         6.2579e-23],
        [6.7595e-01, 3.1929e-01, 2.7624e-03,  ..., 2.5342e-18, 2.3807e-18,
         5.6546e-19],
        [7.7685e-01, 2.2257e-01, 2.2998e-04,  ..., 3.7510e-24, 3.7510e-24,
         3.1097e-24],
        ...,
        [1.0000e+00, 9.0561e-11, 7.9920e-11,  ..., 2.6972e-26, 9.2293e-28,
         7.1878e-28],
        [9.3990e-01, 3.6444e-02, 2.2104e-02,  ..., 4.2634e-24, 2.9302e-24,
         1.3841e-24],
        [9.7178e-01, 1.3862e-02, 1.2233e-02,  ..., 8.2352e-24, 5.6600e-24,
         4.1001e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9388, 0.9856, 0.9960,  ..., 1.0000, 1.0000, 1.0000],
        [0.6759, 0.9952, 0.9980,  ..., 1.0000, 1.0000, 1.0000],
        [0.7768, 0.9994, 0.9996,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9399, 0.9763, 0.9984,  ..., 1.0000, 1.0000, 1.0000],
        [0.9718, 0.9856, 0.9979,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([33])


tensor([ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  7,  7,  8,  8,  9,  9,
        10, 10, 10, 11, 11, 11, 12, 12, 13, 14, 15, 16, 17, 18, 19],
       device='cuda:0')


carryover_candidates
torch.Size([33, 87])


tensor([[    1, 32010,  1724,  ...,   373, 16852, 29892],
        [    1, 32010,  1724,  ...,   373, 16852, 29889],
        [    1, 32010,  1724,  ...,   373, 16852, 29889],
        ...,
        [    1, 32010,  1724,  ...,  6167,   375,   341],
        [    1, 32010,  1724,  ...,   375,   341,   787],
        [    1, 32010,  1724,  ...,   375,   341,   787]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([33])


tensor([-11.9462, -13.9462, -13.9462, -13.0601, -13.0601, -14.5618, -14.5618,
        -14.1872, -14.1872, -15.1915, -15.1915, -14.4178, -16.5428, -16.5428,
        -15.7090, -15.7090, -15.9590, -15.9590, -18.0531, -18.0531, -18.0531,
        -18.0531, -18.0531, -18.0531, -16.5024, -16.5024, -18.0377, -19.2853,
        -18.6713, -19.4228, -18.3366, -18.5360, -19.7658], device='cuda:0')


new_candidate_toks
torch.Size([33, 1])


tensor([[  607],
        [  739],
        [ 6167],
        [29892],
        [29889],
        [29892],
        [29889],
        [29892],
        [29889],
        [29892],
        [29889],
        [  342],
        [29892],
        [ 2998],
        [29892],
        [29889],
        [29892],
        [  411],
        [29892],
        [  363],
        [  411],
        [29892],
        [  363],
        [29889],
        [29892],
        [29889],
        [16852],
        [16852],
        [16852],
        [16852],
        [  787],
        [  373],
        [  373]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([33])


tensor([-6.3105e-02, -3.9164e-01, -1.1416e+00, -2.5251e-01, -1.5025e+00,
        -3.8744e-01, -1.1374e+00, -3.1377e-01, -1.3138e+00, -3.8742e-01,
        -1.1374e+00,  0.0000e+00, -2.2273e-01, -2.2227e+00, -1.0989e-01,
        -2.8599e+00, -1.8511e-01, -2.1851e+00, -2.9781e-01, -2.0478e+00,
        -2.5478e+00, -3.3226e-01, -1.8323e+00, -2.5823e+00, -1.2724e-01,
        -2.1272e+00, -8.4364e-04, -9.4727e-04, -6.6750e-04, -1.5553e-03,
         0.0000e+00, -6.1981e-02, -2.8623e-02], device='cuda:0')


new_candidates
torch.Size([33, 88])


tensor([[    1, 32010,  1724,  ..., 16852, 29892,   607],
        [    1, 32010,  1724,  ..., 16852, 29889,   739],
        [    1, 32010,  1724,  ..., 16852, 29889,  6167],
        ...,
        [    1, 32010,  1724,  ...,   375,   341,   787],
        [    1, 32010,  1724,  ...,   341,   787,   373],
        [    1, 32010,  1724,  ...,   341,   787,   373]], device='cuda:0')


new_candidate_logprobs
torch.Size([33])


tensor([-12.0093, -14.3378, -15.0878, -13.3126, -14.5626, -14.9492, -15.6992,
        -14.5010, -15.5010, -15.5789, -16.3289, -14.4178, -16.7655, -18.7655,
        -15.8189, -18.5689, -16.1441, -18.1441, -18.3509, -20.1009, -20.6009,
        -18.3854, -19.8854, -20.6354, -16.6297, -18.6297, -18.0386, -19.2862,
        -18.6720, -19.4244, -18.3366, -18.5980, -19.7944], device='cuda:0')

infer end: GPU memory used: 18541 MB.
event: level
id: 76
data: [{"content": "which", "parent": 0, "prob": -12.009263038635254}, {"content": "It", "parent": 1, "prob": -14.337800979614258}, {"content": "Olymp", "parent": 1, "prob": -15.087800979614258}, {"content": ",", "parent": 2, "prob": -13.312576293945312}, {"content": ".", "parent": 2, "prob": -14.562576293945312}, {"content": ",", "parent": 3, "prob": -14.949214935302734}, {"content": ".", "parent": 3, "prob": -15.699214935302734}, {"content": ",", "parent": 4, "prob": -14.500997543334961}, {"content": ".", "parent": 4, "prob": -15.500997543334961}, {"content": ",", "parent": 5, "prob": -15.578948020935059}, {"content": ".", "parent": 5, "prob": -16.328947067260742}, {"content": "est", "parent": 6, "prob": -14.417773246765137}, {"content": ",", "parent": 7, "prob": -16.7655029296875}, {"content": "known", "parent": 7, "prob": -18.7655029296875}, {"content": ",", "parent": 8, "prob": -15.818882942199707}, {"content": ".", "parent

array([[-1.9765625 , -2.921875  , -0.421875  , ..., -0.22558594,
        -0.671875  , -0.2890625 ],
       [-2.984375  , -2.59375   ,  0.56640625, ...,  0.18261719,
        -1.546875  ,  1.4375    ],
       [ 0.02978516, -0.40429688, -2.40625   , ..., -2.421875  ,
        -0.88671875, -0.578125  ],
       ...,
       [-0.56640625, -2.484375  ,  2.625     , ...,  1.3046875 ,
         0.38085938,  3.359375  ],
       [-1.109375  , -0.01477051, -0.27929688, ..., -1.5546875 ,
         1.015625  ,  0.91796875],
       [-0.703125  , -2.40625   ,  2.65625   , ...,  0.890625  ,
        -0.39648438,  2.1875    ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 72.49927 ,  57.200333],
       [ 73.162384,  59.01228 ],
       [ 79.62873 , 102.87701 ],
       [ 65.369026,  28.94716 ],
       [ 31.120148,  65.38941 ],
       [ 65.34992 ,  28.910528],
       [ 30.875412,  65.45199 ],
       [ 65.69838 ,  30.055958],
       [ 31.156126,  65.42857 ],
       [ 65.51233 ,  29.421888],
       [ 31.311075,  65.64144 ],
       [ 65.980385,  84.26596 ],
       [ 63.613117,  31.572834],
       [ 66.20675 ,  86.0249  ],
       [ 61.202385,  29.19848 ],
       [ 33.311794,  64.13887 ],
       [ 60.93945 ,  29.981419],
       [ 70.04112 ,  78.510056],
       [ 63.290768,  27.822838],
       [ 59.80049 ,  74.71976 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-167.17671299, -151.76820469])


closest
(2,)


array([ 6, 18])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 4.2812, -4.2188, -4.8750,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.5312, -4.2812, -5.5000,  ...,  0.0000,  0.0000,  0.0000],
        [-1.2734, -2.4375,  1.5938,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 3.5781, -3.4688, -1.3203,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.5000, -1.2031, -1.5000,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.7188,  1.8750,  0.2871,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.7805e-01, 8.1671e-02, 1.8223e-02,  ..., 7.5899e-22, 2.4641e-22,
         1.3189e-22],
        [9.2746e-01, 3.1736e-02, 1.4991e-02,  ..., 1.7888e-22, 1.5786e-22,
         2.1365e-23],
        [1.0000e+00, 3.9279e-07, 1.5230e-08,  ..., 5.2068e-29, 1.6904e-29,
         9.0481e-30],
        ...,
        [9.3896e-01, 4.6748e-02, 1.3394e-02,  ..., 7.1627e-22, 5.8795e-23,
         2.4509e-23],
        [4.4293e-01, 2.0923e-01, 1.6295e-01,  ..., 7.1530e-22, 1.9252e-22,
         4.0354e-23],
        [4.8992e-01, 3.8155e-01, 8.5135e-02,  ..., 1.1031e-19, 7.1220e-20,
         5.2106e-20]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8780, 0.9597, 0.9779,  ..., 1.0000, 1.0000, 1.0000],
        [0.9275, 0.9592, 0.9742,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9390, 0.9857, 0.9991,  ..., 1.0000, 1.0000, 1.0000],
        [0.4429, 0.6522, 0.8151,  ..., 1.0000, 1.0000, 1.0000],
        [0.4899, 0.8715, 0.9566,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([41])


tensor([ 0,  0,  1,  2,  3,  4,  4,  5,  6,  6,  7,  8,  8,  9, 10, 10, 11, 11,
        12, 12, 13, 13, 13, 14, 14, 14, 14, 15, 15, 16, 16, 16, 16, 17, 18, 18,
        18, 18, 19, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([41, 88])


tensor([[    1, 32010,  1724,  ..., 16852, 29892,   607],
        [    1, 32010,  1724,  ..., 16852, 29892,   607],
        [    1, 32010,  1724,  ..., 16852, 29889,   739],
        ...,
        [    1, 32010,  1724,  ...,   278,  2407,   363],
        [    1, 32010,  1724,  ...,   278,  2407,   363],
        [    1, 32010,  1724,  ...,   278,  2407,   363]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([41])


tensor([-12.0093, -12.0093, -14.3378, -15.0878, -13.3126, -14.5626, -14.5626,
        -14.9492, -15.6992, -15.6992, -14.5010, -15.5010, -15.5010, -15.5789,
        -16.3289, -16.3289, -14.4178, -14.4178, -16.7655, -16.7655, -18.7655,
        -18.7655, -18.7655, -15.8189, -15.8189, -15.8189, -15.8189, -18.5689,
        -18.5689, -16.1441, -16.1441, -16.1441, -16.1441, -18.1441, -18.3509,
        -18.3509, -18.3509, -18.3509, -20.1009, -20.1009, -20.1009],
       device='cuda:0')


new_candidate_toks
torch.Size([41, 1])


tensor([[15028],
        [  338],
        [15028],
        [  375],
        [  607],
        [ 6167],
        [  739],
        [  607],
        [ 6167],
        [  739],
        [  607],
        [ 6167],
        [  739],
        [  607],
        [ 6167],
        [  739],
        [29892],
        [ 2998],
        [13407],
        [  411],
        [14378],
        [29892],
        [ 1700],
        [  411],
        [13407],
        [ 1641],
        [  408],
        [  739],
        [ 6167],
        [  411],
        [ 1641],
        [13407],
        [19372],
        [  263],
        [  411],
        [13407],
        [ 1641],
        [19372],
        [ 1641],
        [  278],
        [  393]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([41])


tensor([-1.3005e-01, -2.5051e+00, -7.5311e-02, -3.5763e-07, -6.1153e-02,
        -4.3479e-01, -1.0598e+00, -7.3359e-02, -2.5821e-01, -1.5082e+00,
        -4.6213e-02, -3.9238e-01, -1.1424e+00, -5.6281e-02, -2.3072e-01,
        -1.6057e+00, -1.2617e-01, -2.8762e+00, -3.4328e-01, -1.5933e+00,
        -3.7617e-01, -1.7512e+00, -2.1262e+00, -4.0992e-01, -2.1599e+00,
        -2.4099e+00, -3.0349e+00, -1.3393e-01, -2.1339e+00, -4.5517e-01,
        -1.9552e+00, -2.4552e+00, -2.9552e+00, -6.2982e-02, -8.1434e-01,
        -1.5643e+00, -1.8143e+00, -2.0643e+00, -7.1351e-01, -9.6351e-01,
        -2.4635e+00], device='cuda:0')


new_candidates
torch.Size([41, 89])


tensor([[    1, 32010,  1724,  ..., 29892,   607, 15028],
        [    1, 32010,  1724,  ..., 29892,   607,   338],
        [    1, 32010,  1724,  ..., 29889,   739, 15028],
        ...,
        [    1, 32010,  1724,  ...,  2407,   363,  1641],
        [    1, 32010,  1724,  ...,  2407,   363,   278],
        [    1, 32010,  1724,  ...,  2407,   363,   393]], device='cuda:0')


new_candidate_logprobs
torch.Size([41])


tensor([-12.1393, -14.5143, -14.4131, -15.0878, -13.3737, -14.9974, -15.6224,
        -15.0226, -15.9574, -17.2074, -14.5472, -15.8934, -16.6434, -15.6352,
        -16.5597, -17.9347, -14.5439, -17.2939, -17.1088, -18.3588, -19.1417,
        -20.5167, -20.8917, -16.2288, -17.9788, -18.2288, -18.8538, -18.7028,
        -20.7028, -16.5993, -18.0993, -18.5993, -19.0993, -18.2071, -19.1653,
        -19.9153, -20.1653, -20.4153, -20.8144, -21.0644, -22.5644],
       device='cuda:0')

infer end: GPU memory used: 18757 MB.
event: level
id: 77
data: [{"content": "stands", "parent": 0, "prob": -12.139314651489258}, {"content": "is", "parent": 0, "prob": -14.514314651489258}, {"content": "stands", "parent": 1, "prob": -14.413111686706543}, {"content": "us", "parent": 2, "prob": -15.087800979614258}, {"content": "which", "parent": 3, "prob": -13.373729705810547}, {"content": "Olymp", "parent": 4, "prob": -14.99736213684082}, {"content": "It", "parent": 4, "prob": -15.62236213684082}, {"content": "which", "parent": 5, "prob": -15.022573471069336}, {"content": "Olymp", "parent": 6, "prob": -15.957428932189941}, {"content": "It", "parent": 6, "prob": -17.207427978515625}, {"content": "which", "parent": 7, "prob": -14.547210693359375}, {"content": "Olymp", "parent": 8, "prob": -15.8933744430542}, {"content": "It", "parent": 8, "prob": -16.643375396728516}, {"content": "which", "parent": 9, "prob": -15.635229110717773}, {"content": "Olymp", "parent": 10, "prob": -16.559669494

array([[-2.53125   ,  0.609375  ,  1.203125  , ..., -1.109375  ,
         3.015625  ,  1.265625  ],
       [-0.53125   , -1.4296875 ,  1.46875   , ..., -0.5078125 ,
         0.92578125,  0.484375  ],
       [-2.3125    ,  0.73828125,  1.046875  , ..., -0.99609375,
         2.734375  ,  1.3828125 ],
       ...,
       [-0.04296875, -0.88671875,  2.703125  , ...,  0.89453125,
        -0.83203125,  3.53125   ],
       [-1.78125   , -0.18945312,  0.6875    , ..., -0.22363281,
         2.296875  ,  1.4453125 ],
       [-0.12988281, -2.703125  ,  3.09375   , ...,  0.69921875,
        -0.27734375,  2.0625    ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 61.60269  , 110.83068  ],
       [ 65.12229  , 111.58962  ],
       [ 62.186684 , 110.58436  ],
       [ 79.55525  , 102.62348  ],
       [ 36.63399  , 111.73037  ],
       [ 96.87095  ,   4.3574033],
       [ 36.67003  , 111.03553  ],
       [ 36.56866  , 111.715904 ],
       [ 96.773705 ,   3.6847613],
       [ 36.586952 , 111.09353  ],
       [ 36.60419  , 111.704475 ],
       [ 96.80266  ,   3.36507  ],
       [ 36.761932 , 111.10339  ],
       [ 36.58422  , 111.668785 ],
       [ 96.70269  ,   4.3256164],
       [ 36.754395 , 111.205154 ],
       [ 56.272945 , 110.20443  ],
       [ 81.52392  , 113.531715 ],
       [ 62.494522 , 111.19438  ],
       [ 69.48529  , 113.87987  ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-249.44655514,  -63.40783501])


closest
(2,)


array([ 7, 11])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.5547,  1.0625, -4.2188,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.5469, -5.0938, -3.0469,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.0938,  2.6406, -3.3281,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.9531, -4.4375, -7.9688,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.0469,  1.7891, -4.4375,  ...,  0.0000,  0.0000,  0.0000],
        [-0.7734, -4.5312, -2.2969,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.1506e-01, 1.6049e-01, 1.0260e-02,  ..., 1.2243e-22, 6.5532e-23,
         1.8775e-23],
        [9.3703e-01, 3.2063e-02, 1.5146e-02,  ..., 1.1785e-21, 1.1785e-21,
         9.1782e-22],
        [7.1020e-01, 2.6127e-01, 1.1479e-02,  ..., 6.9564e-22, 3.7235e-22,
         1.0668e-22],
        ...,
        [9.1442e-01, 4.0177e-02, 2.4368e-02,  ..., 4.0141e-21, 2.1486e-21,
         6.9755e-22],
        [7.1648e-01, 2.6358e-01, 7.9594e-03,  ..., 2.2784e-22, 1.2195e-22,
         1.0762e-22],
        [9.8688e-01, 8.5382e-03, 1.1555e-03,  ..., 1.7134e-20, 2.3189e-21,
         2.7695e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8151, 0.9755, 0.9858,  ..., 1.0000, 1.0000, 1.0000],
        [0.9370, 0.9691, 0.9842,  ..., 1.0000, 1.0000, 1.0000],
        [0.7102, 0.9715, 0.9829,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9144, 0.9546, 0.9790,  ..., 1.0000, 1.0000, 1.0000],
        [0.7165, 0.9801, 0.9880,  ..., 1.0000, 1.0000, 1.0000],
        [0.9869, 0.9954, 0.9966,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([30])


tensor([ 0,  0,  1,  2,  2,  3,  4,  4,  5,  6,  7,  7,  8,  9, 10, 10, 11, 12,
        12, 13, 13, 14, 15, 16, 16, 16, 17, 18, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([30, 89])


tensor([[    1, 32010,  1724,  ..., 29892,   607, 15028],
        [    1, 32010,  1724,  ..., 29892,   607, 15028],
        [    1, 32010,  1724,  ..., 29892,   607,   338],
        ...,
        [    1, 32010,  1724,  ...,  9939, 29892, 13407],
        [    1, 32010,  1724,  ...,  9939, 29892, 13407],
        [    1, 32010,  1724,  ...,  9939, 29892,   411]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([30])


tensor([-12.1393, -12.1393, -14.5143, -14.4131, -14.4131, -15.0878, -13.3737,
        -13.3737, -14.9974, -15.6224, -15.0226, -15.0226, -15.9574, -17.2074,
        -14.5472, -14.5472, -15.8934, -16.6434, -16.6434, -15.6352, -15.6352,
        -16.5597, -17.9347, -14.5439, -14.5439, -14.5439, -17.2939, -17.1088,
        -17.1088, -18.3588], device='cuda:0')


new_candidate_toks
torch.Size([30, 1])


tensor([[ 1048],
        [  472],
        [ 1048],
        [ 1048],
        [  472],
        [  341],
        [15028],
        [  338],
        [  375],
        [15028],
        [15028],
        [  338],
        [  375],
        [15028],
        [15028],
        [  338],
        [  375],
        [15028],
        [  338],
        [15028],
        [  338],
        [  375],
        [15028],
        [13407],
        [  411],
        [20888],
        [14378],
        [  472],
        [ 1048],
        [  263]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([30])


tensor([-2.0450e-01, -1.8295e+00, -6.5043e-02, -3.4221e-01, -1.3422e+00,
        -2.5034e-05, -1.8844e-01, -2.0634e+00, -8.3447e-07, -8.4656e-02,
        -1.9917e-01, -2.0742e+00, -3.5763e-07, -8.4093e-02, -2.2259e-01,
        -1.8476e+00, -8.3447e-07, -1.0649e-01, -2.6065e+00, -2.3326e-01,
        -1.9833e+00, -5.9605e-07, -1.0334e-01, -5.0821e-01, -1.2582e+00,
        -2.8832e+00, -8.9467e-02, -3.3341e-01, -1.3334e+00, -1.3203e-02],
       device='cuda:0')


new_candidates
torch.Size([30, 90])


tensor([[    1, 32010,  1724,  ...,   607, 15028,  1048],
        [    1, 32010,  1724,  ...,   607, 15028,   472],
        [    1, 32010,  1724,  ...,   607,   338,  1048],
        ...,
        [    1, 32010,  1724,  ..., 29892, 13407,   472],
        [    1, 32010,  1724,  ..., 29892, 13407,  1048],
        [    1, 32010,  1724,  ..., 29892,   411,   263]], device='cuda:0')


new_candidate_logprobs
torch.Size([30])


tensor([-12.3438, -13.9688, -14.5794, -14.7553, -15.7553, -15.0878, -13.5622,
        -15.4372, -14.9974, -15.7070, -15.2217, -17.0967, -15.9574, -17.2915,
        -14.7698, -16.3948, -15.8934, -16.7499, -19.2499, -15.8685, -17.6185,
        -16.5597, -18.0380, -15.0522, -15.8022, -17.4272, -17.3834, -17.4422,
        -18.4422, -18.3720], device='cuda:0')

infer end: GPU memory used: 18975 MB.
event: level
id: 78
data: [{"content": "about", "parent": 0, "prob": -12.3438138961792}, {"content": "at", "parent": 0, "prob": -13.9688138961792}, {"content": "about", "parent": 1, "prob": -14.579357147216797}, {"content": "about", "parent": 2, "prob": -14.755326271057129}, {"content": "at", "parent": 2, "prob": -15.755326271057129}, {"content": "M", "parent": 3, "prob": -15.087825775146484}, {"content": "stands", "parent": 4, "prob": -13.562172889709473}, {"content": "is", "parent": 4, "prob": -15.437172889709473}, {"content": "us", "parent": 5, "prob": -14.997363090515137}, {"content": "stands", "parent": 6, "prob": -15.707018852233887}, {"content": "stands", "parent": 7, "prob": -15.2217435836792}, {"content": "is", "parent": 7, "prob": -17.096742630004883}, {"content": "us", "parent": 8, "prob": -15.957428932189941}, {"content": "stands", "parent": 9, "prob": -17.291521072387695}, {"content": "stands", "parent": 10, "prob": -14.76979923248291}

array([[-1.359375  ,  0.484375  ,  1.6484375 , ...,  0.82421875,
        -0.11230469,  0.21777344],
       [-1.8125    , -1.78125   ,  0.27929688, ..., -1.3046875 ,
         1.828125  ,  1.6875    ],
       [-1.375     ,  0.18652344,  1.5       , ...,  0.7890625 ,
         0.34375   ,  0.515625  ],
       ...,
       [-2.203125  ,  0.83984375,  1.203125  , ..., -0.8125    ,
         2.875     ,  1.75      ],
       [-0.04345703, -0.73046875,  2.046875  , ..., -0.66015625,
         0.31640625,  0.31054688],
       [-2.21875   ,  0.81640625,  1.2265625 , ..., -1.1171875 ,
         3.03125   ,  1.34375   ]], dtype=float32)


k_mean_space
(20, 2)


array([[61.415783, 98.41484 ],
       [59.407837, 98.99452 ],
       [65.634926, 98.205414],
       [62.86031 , 98.9322  ],
       [60.341915, 99.35504 ],
       [97.834694, 75.50293 ],
       [39.186573, 97.757965],
       [49.7615  , 94.28938 ],
       [92.29124 , 25.491177],
       [38.93161 , 97.57762 ],
       [39.475327, 97.67085 ],
       [50.04192 , 94.08641 ],
       [92.79954 , 25.425976],
       [39.204636, 97.601616],
       [38.684246, 97.76068 ],
       [50.377785, 94.45401 ],
       [92.75682 , 25.546791],
       [39.129486, 97.66232 ],
       [55.78942 , 94.447105],
       [39.09196 , 97.69274 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-248.75183964,  -61.93599319])


closest
(2,)


array([14, 12])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.3965,  0.8242, -2.7656,  ...,  0.0000,  0.0000,  0.0000],
        [-1.1250, -2.1094, -5.3438,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.5312, -0.5234,  2.6406,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.0469,  2.4219, -3.4375,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.3750, -8.4375, -5.6250,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.7500,  0.9844, -4.1875,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9979e-01, 1.0889e-04, 4.5390e-05,  ..., 6.2604e-23, 4.3027e-23,
         8.4726e-24],
        [6.4159e-01, 1.8382e-01, 5.2665e-02,  ..., 4.0979e-21, 3.6164e-21,
         3.8117e-22],
        [9.9936e-01, 2.9586e-04, 1.0884e-04,  ..., 1.9275e-22, 1.7010e-22,
         1.0317e-22],
        ...,
        [6.8442e-01, 2.8531e-01, 1.2536e-02,  ..., 6.7039e-22, 3.1667e-22,
         9.0727e-23],
        [6.6630e-01, 1.3120e-01, 7.9579e-02,  ..., 9.0095e-21, 6.1921e-21,
         4.8224e-21],
        [8.0955e-01, 1.5941e-01, 1.9039e-02,  ..., 1.3779e-22, 7.3756e-23,
         1.6457e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9998, 0.9999, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.6416, 0.8254, 0.8781,  ..., 1.0000, 1.0000, 1.0000],
        [0.9994, 0.9997, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.6844, 0.9697, 0.9823,  ..., 1.0000, 1.0000, 1.0000],
        [0.6663, 0.7975, 0.8771,  ..., 1.0000, 1.0000, 1.0000],
        [0.8095, 0.9690, 0.9880,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([36])


tensor([ 0,  1,  1,  1,  1,  2,  3,  4,  4,  4,  4,  5,  6,  6,  7,  8,  9,  9,
        10, 10, 11, 12, 13, 13, 14, 14, 15, 16, 17, 17, 18, 18, 18, 18, 19, 19],
       device='cuda:0')


carryover_candidates
torch.Size([36, 90])


tensor([[    1, 32010,  1724,  ...,   607, 15028,  1048],
        [    1, 32010,  1724,  ...,   607, 15028,   472],
        [    1, 32010,  1724,  ...,   607, 15028,   472],
        ...,
        [    1, 32010,  1724,  ..., 29889,   739,   338],
        [    1, 32010,  1724,  ..., 29892,   607, 15028],
        [    1, 32010,  1724,  ..., 29892,   607, 15028]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([36])


tensor([-12.3438, -13.9688, -13.9688, -13.9688, -13.9688, -14.5794, -14.7553,
        -15.7553, -15.7553, -15.7553, -15.7553, -15.0878, -13.5622, -13.5622,
        -15.4372, -14.9974, -15.7070, -15.7070, -15.2217, -15.2217, -17.0967,
        -15.9574, -17.2915, -17.2915, -14.7698, -14.7698, -16.3948, -15.8934,
        -16.7499, -16.7499, -19.2499, -19.2499, -19.2499, -19.2499, -15.8685,
        -15.8685], device='cuda:0')


new_candidate_toks
torch.Size([36, 1])


tensor([[29871],
        [ 1048],
        [  263],
        [  385],
        [  901],
        [29871],
        [29871],
        [ 1048],
        [  263],
        [  385],
        [  901],
        [  787],
        [ 1048],
        [  472],
        [ 1048],
        [  341],
        [ 1048],
        [  472],
        [ 1048],
        [  472],
        [ 1048],
        [  341],
        [ 1048],
        [  472],
        [ 1048],
        [  472],
        [ 1048],
        [  341],
        [ 1048],
        [  472],
        [ 1048],
        [  263],
        [  278],
        [14235],
        [ 1048],
        [  472]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([36])


tensor([-2.0941e-04, -4.4381e-01, -1.6938e+00, -2.9438e+00, -3.1938e+00,
        -6.4042e-04, -8.1650e-04, -5.2817e-01, -1.6532e+00, -2.5282e+00,
        -2.6532e+00,  0.0000e+00, -1.8608e-01, -1.9361e+00, -8.5453e-02,
        -2.1458e-05, -3.4291e-01, -1.3429e+00, -1.8951e-01, -1.9395e+00,
        -6.2899e-02, -8.8215e-06, -3.4664e-01, -1.3466e+00, -1.6984e-01,
        -2.0448e+00, -8.2320e-02, -8.7023e-06, -3.7918e-01, -1.2542e+00,
        -4.0601e-01, -2.0310e+00, -2.5310e+00, -2.5310e+00, -2.1128e-01,
        -1.8363e+00], device='cuda:0')


new_candidates
torch.Size([36, 91])


tensor([[    1, 32010,  1724,  ..., 15028,  1048, 29871],
        [    1, 32010,  1724,  ..., 15028,   472,  1048],
        [    1, 32010,  1724,  ..., 15028,   472,   263],
        ...,
        [    1, 32010,  1724,  ...,   739,   338, 14235],
        [    1, 32010,  1724,  ...,   607, 15028,  1048],
        [    1, 32010,  1724,  ...,   607, 15028,   472]], device='cuda:0')


new_candidate_logprobs
torch.Size([36])


tensor([-12.3440, -14.4126, -15.6626, -16.9126, -17.1626, -14.5800, -14.7561,
        -16.2835, -17.4085, -18.2835, -18.4085, -15.0878, -13.7483, -15.4983,
        -15.5226, -14.9974, -16.0499, -17.0499, -15.4113, -17.1613, -17.1596,
        -15.9574, -17.6382, -18.6382, -14.9396, -16.8146, -16.4771, -15.8934,
        -17.1291, -18.0041, -19.6559, -21.2809, -21.7809, -21.7809, -16.0798,
        -17.7048], device='cuda:0')

infer end: GPU memory used: 19197 MB.
event: level
id: 79
data: [{"content": "", "parent": 0, "prob": -12.344023704528809}, {"content": "about", "parent": 1, "prob": -14.412622451782227}, {"content": "a", "parent": 1, "prob": -15.662622451782227}, {"content": "an", "parent": 1, "prob": -16.912622451782227}, {"content": "more", "parent": 1, "prob": -17.162622451782227}, {"content": "", "parent": 2, "prob": -14.579998016357422}, {"content": "", "parent": 3, "prob": -14.756142616271973}, {"content": "about", "parent": 4, "prob": -16.283496856689453}, {"content": "a", "parent": 4, "prob": -17.408496856689453}, {"content": "an", "parent": 4, "prob": -18.283496856689453}, {"content": "more", "parent": 4, "prob": -18.408496856689453}, {"content": "ons", "parent": 5, "prob": -15.087825775146484}, {"content": "about", "parent": 6, "prob": -13.748250007629395}, {"content": "at", "parent": 6, "prob": -15.498250007629395}, {"content": "about", "parent": 7, "prob": -15.522625923156738}, {"content":

array([[-0.69140625, -1.40625   ,  2.25      , ...,  2.984375  ,
         2.328125  , -0.14746094],
       [-1.671875  , -0.37890625,  2.3125    , ...,  0.08740234,
        -0.30664062,  0.703125  ],
       [-2.828125  , -2.59375   ,  0.62109375, ...,  0.33789062,
        -1.328125  ,  1.921875  ],
       ...,
       [-1.671875  , -1.5703125 ,  0.57421875, ..., -1.140625  ,
         1.515625  ,  2.09375   ],
       [-1.7578125 ,  0.49609375,  1.921875  , ...,  0.63671875,
        -0.08251953,  0.27734375],
       [-1.7109375 , -1.75      ,  0.671875  , ..., -1.4453125 ,
         1.8359375 ,  1.8359375 ]], dtype=float32)


k_mean_space
(20, 2)


array([[86.24021 , 63.34025 ],
       [79.96895 , 41.217983],
       [44.378906, 77.858536],
       [45.773884, 78.58823 ],
       [56.34113 , 82.29604 ],
       [86.82388 , 64.107155],
       [86.69083 , 63.732254],
       [79.813484, 40.64316 ],
       [44.263092, 78.46262 ],
       [45.672474, 79.01741 ],
       [55.449993, 82.28884 ],
       [72.35343 , 89.887314],
       [80.2178  , 39.396843],
       [68.58546 , 57.45613 ],
       [83.530075, 44.77998 ],
       [84.7557  , 98.04656 ],
       [81.204475, 40.473297],
       [69.05254 , 58.323296],
       [80.02373 , 39.070316],
       [68.65063 , 56.973934]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-133.92356873, -182.81777573])


closest
(2,)


array([ 8, 18])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.6953,  2.8594, -0.2041,  ...,  0.0000,  0.0000,  0.0000],
        [-1.4141,  0.3770, -2.2812,  ...,  0.0000,  0.0000,  0.0000],
        [-1.7188,  2.9531, -4.6562,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.0312, -1.5078, -4.6562,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.8477,  0.2324, -2.1094,  ...,  0.0000,  0.0000,  0.0000],
        [-1.1562, -2.1719, -5.3438,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9478e-01, 5.2201e-03, 5.6852e-07,  ..., 1.8385e-20, 1.8385e-20,
         1.8385e-20],
        [1.0000e+00, 8.3153e-07, 3.0590e-07,  ..., 3.1175e-24, 1.8909e-24,
         1.2996e-24],
        [8.9404e-01, 5.7154e-02, 3.0592e-02,  ..., 5.3114e-22, 3.6505e-22,
         1.3429e-22],
        ...,
        [6.5308e-01, 1.2860e-01, 1.0015e-01,  ..., 2.5300e-21, 2.2327e-21,
         2.0768e-22],
        [9.9938e-01, 2.0334e-04, 1.7945e-04,  ..., 4.8736e-23, 2.6086e-23,
         2.6086e-23],
        [7.1310e-01, 1.2392e-01, 6.6328e-02,  ..., 3.1304e-21, 2.1515e-21,
         2.9117e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9948, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8940, 0.9512, 0.9818,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.6531, 0.7817, 0.8818,  ..., 1.0000, 1.0000, 1.0000],
        [0.9994, 0.9996, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        [0.7131, 0.8370, 0.9033,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([32])


tensor([ 0,  1,  2,  2,  3,  4,  5,  6,  7,  8,  8,  9,  9, 10, 11, 11, 12, 13,
        13, 13, 13, 14, 15, 16, 17, 17, 17, 17, 18, 19, 19, 19],
       device='cuda:0')


carryover_candidates
torch.Size([32, 91])


tensor([[    1, 32010,  1724,  ..., 15028,  1048, 29871],
        [    1, 32010,  1724,  ..., 15028,   472,  1048],
        [    1, 32010,  1724,  ..., 15028,   472,   263],
        ...,
        [    1, 32010,  1724,  ...,   607, 15028,   472],
        [    1, 32010,  1724,  ...,   607, 15028,   472],
        [    1, 32010,  1724,  ...,   607, 15028,   472]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([32])


tensor([-12.3440, -14.4126, -15.6626, -15.6626, -16.9126, -17.1626, -14.5800,
        -14.7561, -16.2835, -17.4085, -17.4085, -18.2835, -18.2835, -18.4085,
        -15.0878, -15.0878, -13.7483, -15.4983, -15.4983, -15.4983, -15.4983,
        -15.5226, -14.9974, -16.0499, -17.0499, -17.0499, -17.0499, -17.0499,
        -15.4113, -17.1613, -17.1613, -17.1613], device='cuda:0')


new_candidate_toks
torch.Size([32, 1])


tensor([[29896],
        [29871],
        [ 3171],
        [19372],
        [15899],
        [ 1135],
        [29896],
        [29896],
        [29871],
        [ 3171],
        [19372],
        [15899],
        [21210],
        [ 1135],
        [15028],
        [  338],
        [29871],
        [ 1048],
        [  263],
        [  385],
        [  901],
        [29871],
        [  787],
        [29871],
        [ 1048],
        [  263],
        [  385],
        [  901],
        [29871],
        [ 1048],
        [  263],
        [  385]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([32])


tensor([-5.2361e-03, -2.2650e-06, -1.1201e-01, -2.8620e+00, -8.8874e-02,
        -1.1921e-07, -3.1805e-03, -5.9336e-03, -5.1260e-06, -1.5505e-01,
        -2.5301e+00, -2.2005e-01, -1.9700e+00,  0.0000e+00, -5.0067e-01,
        -1.0007e+00, -8.2760e-04, -4.7914e-01, -1.7291e+00, -2.3541e+00,
        -3.2291e+00, -3.0269e-03,  0.0000e+00, -2.4928e-03, -4.2605e-01,
        -2.0511e+00, -2.3011e+00, -2.9261e+00, -6.2503e-04, -3.3814e-01,
        -2.0881e+00, -2.7131e+00], device='cuda:0')


new_candidates
torch.Size([32, 92])


tensor([[    1, 32010,  1724,  ...,  1048, 29871, 29896],
        [    1, 32010,  1724,  ...,   472,  1048, 29871],
        [    1, 32010,  1724,  ...,   472,   263,  3171],
        ...,
        [    1, 32010,  1724,  ..., 15028,   472,  1048],
        [    1, 32010,  1724,  ..., 15028,   472,   263],
        [    1, 32010,  1724,  ..., 15028,   472,   385]], device='cuda:0')


new_candidate_logprobs
torch.Size([32])


tensor([-12.3493, -14.4126, -15.7746, -18.5246, -17.0015, -17.1626, -14.5832,
        -14.7621, -16.2835, -17.5636, -19.9386, -18.5035, -20.2535, -18.4085,
        -15.5885, -16.0885, -13.7491, -15.9774, -17.2274, -17.8524, -18.7274,
        -15.5257, -14.9974, -16.0524, -17.4760, -19.1010, -19.3510, -19.9760,
        -15.4119, -17.4994, -19.2494, -19.8744], device='cuda:0')

infer end: GPU memory used: 19421 MB.
event: level
id: 80
data: [{"content": "1", "parent": 0, "prob": -12.349259376525879}, {"content": "", "parent": 1, "prob": -14.41262435913086}, {"content": "height", "parent": 2, "prob": -15.77463150024414}, {"content": "tower", "parent": 2, "prob": -18.52463150024414}, {"content": "estimated", "parent": 3, "prob": -17.001497268676758}, {"content": "than", "parent": 4, "prob": -17.162622451782227}, {"content": "1", "parent": 5, "prob": -14.583178520202637}, {"content": "1", "parent": 6, "prob": -14.762076377868652}, {"content": "", "parent": 7, "prob": -16.28350257873535}, {"content": "height", "parent": 8, "prob": -17.56355094909668}, {"content": "tower", "parent": 8, "prob": -19.938552856445312}, {"content": "estimated", "parent": 9, "prob": -18.503541946411133}, {"content": "impress", "parent": 9, "prob": -20.253541946411133}, {"content": "than", "parent": 10, "prob": -18.408496856689453}, {"content": "stands", "parent": 11, "prob": -15.5884971

array([[-0.67578125, -0.8203125 , -1.109375  , ...,  2.921875  ,
         1.9453125 , -3.0625    ],
       [-0.640625  , -2.015625  ,  1.78125   , ...,  2.609375  ,
         2.484375  , -0.24609375],
       [-0.09667969, -0.26171875, -0.37695312, ...,  1.140625  ,
        -0.97265625,  2.375     ],
       ...,
       [-1.9609375 , -0.546875  ,  2.4375    , ..., -0.10986328,
        -0.37109375,  0.8203125 ],
       [-2.890625  , -2.625     ,  0.80078125, ...,  0.48828125,
        -1.34375   ,  1.984375  ],
       [-1.7109375 , -2.265625  , -0.21484375, ..., -1.2578125 ,
        -0.46875   ,  3.015625  ]], dtype=float32)


k_mean_space
(20, 2)


array([[62.21567 , 90.974014],
       [59.333332, 78.9889  ],
       [69.11567 , 75.16287 ],
       [62.113426, 75.74187 ],
       [75.08503 , 47.68305 ],
       [75.743675, 57.03536 ],
       [62.537937, 91.09282 ],
       [62.12885 , 91.08267 ],
       [59.353516, 79.13377 ],
       [69.163864, 75.26496 ],
       [61.75128 , 75.80187 ],
       [75.133316, 47.774723],
       [69.28949 , 80.53426 ],
       [75.706955, 57.374565],
       [78.1881  , 65.37992 ],
       [79.638306, 64.978386],
       [58.90582 , 78.61303 ],
       [75.711845, 56.559254],
       [76.1956  , 55.900665],
       [76.91461 , 57.457653]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-178.19462776, -153.81031799])


closest
(2,)


array([16,  4])


last_tok_logits
torch.Size([20, 32064])


tensor([[-3.2812, -2.3125, -3.3125,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3125,  2.0312, -0.7305,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.2812,  0.6680,  1.0703,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.3281,  0.5820, -2.4531,  ...,  0.0000,  0.0000,  0.0000],
        [-1.1172,  3.4219, -4.2812,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.4375,  2.6406, -4.1250,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.8593e-01, 1.4064e-02, 3.2425e-06,  ..., 8.1911e-26, 3.4145e-26,
         2.6592e-26],
        [9.9409e-01, 5.9111e-03, 3.0409e-07,  ..., 2.0819e-20, 1.0468e-20,
         9.2383e-21],
        [9.9639e-01, 8.0183e-04, 5.5109e-04,  ..., 1.9218e-22, 1.9218e-22,
         5.1214e-24],
        ...,
        [1.0000e+00, 1.9947e-06, 1.2099e-06,  ..., 4.0030e-24, 4.0030e-24,
         1.8909e-24],
        [8.0528e-01, 9.6177e-02, 7.4902e-02,  ..., 6.1429e-22, 3.7259e-22,
         1.2096e-22],
        [8.5107e-01, 1.0165e-01, 1.5588e-02,  ..., 1.4486e-22, 7.7539e-23,
         6.0388e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9859, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9941, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9964, 0.9972, 0.9977,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8053, 0.9015, 0.9764,  ..., 1.0000, 1.0000, 1.0000],
        [0.8511, 0.9527, 0.9683,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([25])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 14, 15, 15,
        15, 16, 17, 18, 18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([25, 92])


tensor([[    1, 32010,  1724,  ...,  1048, 29871, 29896],
        [    1, 32010,  1724,  ...,   472,  1048, 29871],
        [    1, 32010,  1724,  ...,   472,   263,  3171],
        ...,
        [    1, 32010,  1724,  ..., 15028,   472,   263],
        [    1, 32010,  1724,  ..., 15028,   472,   385],
        [    1, 32010,  1724,  ..., 15028,   472,   385]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([25])


tensor([-12.3493, -14.4126, -15.7746, -18.5246, -17.0015, -17.1626, -14.5832,
        -14.7621, -16.2835, -17.5636, -19.9386, -18.5035, -20.2535, -18.4085,
        -15.5885, -15.5885, -16.0885, -16.0885, -16.0885, -13.7491, -15.9774,
        -17.2274, -17.2274, -17.8524, -17.8524], device='cuda:0')


new_candidate_toks
torch.Size([25, 1])


tensor([[29941],
        [29896],
        [  310],
        [  292],
        [ 3171],
        [29871],
        [29941],
        [29941],
        [29896],
        [  310],
        [  292],
        [ 3171],
        [  573],
        [29871],
        [  472],
        [ 1048],
        [  263],
        [  278],
        [ 1048],
        [29896],
        [29871],
        [ 3171],
        [19372],
        [15899],
        [21210]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([25])


tensor([-1.4169e-02, -5.9301e-03, -3.6198e-03, -1.5497e-06, -1.8279e-02,
        -7.9871e-06, -3.8052e-02, -4.8618e-02, -6.7182e-03, -5.3144e-03,
        -1.5497e-06, -1.4312e-02, -7.8678e-06, -6.4373e-06, -6.6729e-01,
        -7.9229e-01, -4.2641e-01, -1.8014e+00, -1.9264e+00, -2.0547e-02,
        -4.6492e-06, -2.1657e-01, -2.3416e+00, -1.6126e-01, -2.2863e+00],
       device='cuda:0')


new_candidates
torch.Size([25, 93])


tensor([[    1, 32010,  1724,  ..., 29871, 29896, 29941],
        [    1, 32010,  1724,  ...,  1048, 29871, 29896],
        [    1, 32010,  1724,  ...,   263,  3171,   310],
        ...,
        [    1, 32010,  1724,  ...,   472,   263, 19372],
        [    1, 32010,  1724,  ...,   472,   385, 15899],
        [    1, 32010,  1724,  ...,   472,   385, 21210]], device='cuda:0')


new_candidate_logprobs
torch.Size([25])


tensor([-12.3634, -14.4186, -15.7783, -18.5246, -17.0198, -17.1626, -14.6212,
        -14.8107, -16.2902, -17.5689, -19.9386, -18.5179, -20.2535, -18.4085,
        -16.2558, -16.3808, -16.5149, -17.8899, -18.0149, -13.7696, -15.9774,
        -17.4440, -19.5690, -18.0136, -20.1386], device='cuda:0')

infer end: GPU memory used: 19647 MB.
event: level
id: 81
data: [{"content": "3", "parent": 0, "prob": -12.363428115844727}, {"content": "1", "parent": 1, "prob": -14.418554306030273}, {"content": "of", "parent": 2, "prob": -15.778251647949219}, {"content": "ing", "parent": 3, "prob": -18.524633407592773}, {"content": "height", "parent": 4, "prob": -17.019775390625}, {"content": "", "parent": 5, "prob": -17.162630081176758}, {"content": "3", "parent": 6, "prob": -14.621230125427246}, {"content": "3", "parent": 7, "prob": -14.810694694519043}, {"content": "1", "parent": 8, "prob": -16.290220260620117}, {"content": "of", "parent": 9, "prob": -17.568864822387695}, {"content": "ing", "parent": 10, "prob": -19.938554763793945}, {"content": "height", "parent": 11, "prob": -18.517854690551758}, {"content": "ive", "parent": 12, "prob": -20.253549575805664}, {"content": "", "parent": 13, "prob": -18.40850257873535}, {"content": "at", "parent": 14, "prob": -16.255783081054688}, {"content": "abou

array([[ 0.6875    , -0.00750732,  1.421875  , ..., -0.59765625,
         0.12109375, -1.4140625 ],
       [-0.69921875, -1.15625   , -1.015625  , ...,  2.984375  ,
         1.9921875 , -2.953125  ],
       [-0.96484375, -0.5546875 ,  1.        , ..., -0.63671875,
         1.984375  ,  0.49023438],
       ...,
       [-2.59375   , -1.34375   ,  0.70703125, ..., -0.78515625,
         0.2734375 ,  0.46875   ],
       [-1.4140625 ,  0.19824219,  1.484375  , ...,  0.22851562,
        -0.38476562,  0.30273438],
       [-0.81640625, -0.8515625 , -0.94921875, ...,  2.890625  ,
         2.        , -2.890625  ]], dtype=float32)


k_mean_space
(20, 2)


array([[ 71.02283 , 112.9935  ],
       [ 67.81673 , 112.80239 ],
       [ 59.578896,  94.53647 ],
       [ 60.391552,  85.40281 ],
       [ 79.40723 ,   5.606532],
       [ 66.89506 , 107.70757 ],
       [ 71.075424, 113.02918 ],
       [ 71.23748 , 113.02709 ],
       [ 67.910645, 112.847755],
       [ 59.90856 ,  95.06189 ],
       [ 61.549572,  86.070564],
       [ 79.693596,   5.606532],
       [ 60.254883,  88.34572 ],
       [ 66.95679 , 107.84006 ],
       [ 64.119354,  95.03413 ],
       [ 61.99292 ,  99.39916 ],
       [ 73.171036, 107.344734],
       [ 75.57029 , 105.92096 ],
       [ 63.914368, 103.12728 ],
       [ 67.19946 , 112.83091 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-298.96504116,  -35.53763008])


closest
(2,)


array([2, 4])


last_tok_logits
torch.Size([20, 32064])


tensor([[-3.4375, -1.9766, -2.3906,  ...,  0.0000,  0.0000,  0.0000],
        [-3.7344, -2.7500, -3.6719,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.9570, -1.5156, -5.8750,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.9297, -1.3203, -3.4062,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.3438,  0.8984,  3.1875,  ...,  0.0000,  0.0000,  0.0000],
        [-3.2812, -2.3594, -3.1719,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9995e-01, 1.4739e-05, 1.4739e-05,  ..., 2.0328e-23, 5.8241e-24,
         2.7511e-24],
        [9.7701e-01, 2.2977e-02, 4.1258e-06,  ..., 7.1632e-26, 6.3215e-26,
         1.4105e-26],
        [5.1731e-01, 2.7690e-01, 1.4821e-01,  ..., 5.7417e-22, 3.0733e-22,
         2.3935e-22],
        ...,
        [5.8344e-01, 4.0099e-01, 1.3721e-02,  ..., 5.1443e-20, 3.1202e-20,
         1.4739e-20],
        [9.9830e-01, 1.0315e-03, 3.3489e-04,  ..., 3.1746e-22, 1.9255e-22,
         1.0306e-22],
        [9.8900e-01, 1.0987e-02, 3.6857e-06,  ..., 1.5351e-25, 8.2166e-26,
         4.3980e-26]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9770, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.5173, 0.7942, 0.9424,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.5834, 0.9844, 0.9981,  ..., 1.0000, 1.0000, 1.0000],
        [0.9983, 0.9993, 0.9997,  ..., 1.0000, 1.0000, 1.0000],
        [0.9890, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([31])


tensor([ 0,  1,  2,  2,  2,  3,  3,  4,  5,  6,  7,  8,  9,  9,  9, 10, 10, 11,
        12, 13, 14, 14, 14, 15, 16, 16, 16, 17, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([31, 93])


tensor([[    1, 32010,  1724,  ..., 29871, 29896, 29941],
        [    1, 32010,  1724,  ...,  1048, 29871, 29896],
        [    1, 32010,  1724,  ...,   263,  3171,   310],
        ...,
        [    1, 32010,  1724,  ...,   787,   338,   278],
        [    1, 32010,  1724,  ...,   787,   338,  1048],
        [    1, 32010,  1724,  ...,  1048, 29871, 29896]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([31])


tensor([-12.3634, -14.4186, -15.7783, -15.7783, -15.7783, -18.5246, -18.5246,
        -17.0198, -17.1626, -14.6212, -14.8107, -16.2902, -17.5689, -17.5689,
        -17.5689, -19.9386, -19.9386, -18.5179, -20.2535, -18.4085, -16.2558,
        -16.2558, -16.2558, -16.3808, -16.5149, -16.5149, -16.5149, -17.8899,
        -17.8899, -18.0149, -13.7696], device='cuda:0')


new_candidate_toks
torch.Size([31, 1])


tensor([[29889],
        [29941],
        [ 1048],
        [ 8886],
        [14235],
        [ 3171],
        [29871],
        [  310],
        [29896],
        [29889],
        [29889],
        [29941],
        [ 1048],
        [ 8886],
        [14235],
        [ 3171],
        [29871],
        [  310],
        [ 3171],
        [29896],
        [ 1048],
        [  385],
        [  263],
        [29871],
        [28396],
        [20364],
        [ 1700],
        [10150],
        [15655],
        [29871],
        [29941]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([31])


tensor([-4.6016e-05, -2.3254e-02, -6.5912e-01, -1.2841e+00, -1.9091e+00,
        -2.0592e-01, -1.7059e+00, -1.8187e-03, -1.1727e-03, -1.4424e-05,
        -7.3316e-05, -2.9765e-02, -9.2680e-01, -9.2680e-01, -1.9268e+00,
        -1.6428e-01, -1.9143e+00, -9.2848e-04, -4.3376e-02, -2.1905e-03,
        -9.2504e-01, -1.1750e+00, -1.5500e+00, -1.9654e-04, -6.3954e-01,
        -1.0145e+00, -3.6395e+00, -5.3882e-01, -9.1382e-01, -1.7016e-03,
        -1.1056e-02], device='cuda:0')


new_candidates
torch.Size([31, 94])


tensor([[    1, 32010,  1724,  ..., 29896, 29941, 29889],
        [    1, 32010,  1724,  ..., 29871, 29896, 29941],
        [    1, 32010,  1724,  ...,  3171,   310,  1048],
        ...,
        [    1, 32010,  1724,  ...,   338,   278, 15655],
        [    1, 32010,  1724,  ...,   338,  1048, 29871],
        [    1, 32010,  1724,  ..., 29871, 29896, 29941]], device='cuda:0')


new_candidate_logprobs
torch.Size([31])


tensor([-12.3635, -14.4418, -16.4374, -17.0624, -17.6874, -18.7306, -20.2306,
        -17.0216, -17.1638, -14.6212, -14.8108, -16.3200, -18.4957, -18.4957,
        -19.4957, -20.1028, -21.8528, -18.5188, -20.2969, -18.4107, -17.1808,
        -17.4308, -17.8058, -16.3810, -17.1545, -17.5294, -20.1545, -18.4287,
        -18.8037, -18.0166, -13.7807], device='cuda:0')

infer end: GPU memory used: 19875 MB.
event: level
id: 82
data: [{"content": ".", "parent": 0, "prob": -12.363473892211914}, {"content": "3", "parent": 1, "prob": -14.441808700561523}, {"content": "about", "parent": 2, "prob": -16.437368392944336}, {"content": "nearly", "parent": 2, "prob": -17.062368392944336}, {"content": "approximately", "parent": 2, "prob": -17.687368392944336}, {"content": "height", "parent": 3, "prob": -18.73055076599121}, {"content": "", "parent": 3, "prob": -20.23055076599121}, {"content": "of", "parent": 4, "prob": -17.021595001220703}, {"content": "1", "parent": 5, "prob": -17.163803100585938}, {"content": ".", "parent": 6, "prob": -14.621244430541992}, {"content": ".", "parent": 7, "prob": -14.810768127441406}, {"content": "3", "parent": 8, "prob": -16.31998634338379}, {"content": "about", "parent": 9, "prob": -18.49566078186035}, {"content": "nearly", "parent": 9, "prob": -18.49566078186035}, {"content": "approximately", "parent": 9, "prob": -19.49566078186

array([[-0.26757812, -1.1796875 ,  0.41796875, ...,  1.0703125 ,
         0.2890625 , -0.9609375 ],
       [ 0.421875  , -0.31054688,  1.359375  , ..., -0.67578125,
         0.29296875, -1.421875  ],
       [-1.7578125 ,  0.03613281,  1.7578125 , ..., -0.84765625,
        -0.00842285,  0.06933594],
       ...,
       [-1.1484375 , -0.51953125,  0.6171875 , ..., -0.42773438,
         1.078125  ,  0.17382812],
       [-0.14746094, -0.14257812, -0.84765625, ...,  0.08740234,
        -0.33203125,  2.375     ],
       [-0.71484375, -0.9921875 , -1.        , ...,  3.25      ,
         1.515625  , -2.734375  ]], dtype=float32)


k_mean_space
(20, 2)


array([[47.485195, 96.168755],
       [55.70908 , 93.282005],
       [94.27857 , 43.370567],
       [93.176285, 49.256866],
       [94.6814  , 42.20635 ],
       [96.18571 , 65.074554],
       [87.53383 , 70.023346],
       [94.755   , 49.64985 ],
       [63.87197 , 89.963524],
       [47.834736, 96.24281 ],
       [48.042206, 96.30679 ],
       [55.386086, 93.43971 ],
       [94.475395, 43.071396],
       [93.18216 , 49.450115],
       [94.77573 , 41.747307],
       [97.06493 , 65.41747 ],
       [87.49232 , 70.175606],
       [94.98824 , 50.22051 ],
       [97.84044 , 66.63816 ],
       [63.52487 , 89.97686 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-108.13177681, -244.42817116])


closest
(2,)


array([ 0, 14])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.3672,  1.0156, -1.6953,  ...,  0.0000,  0.0000,  0.0000],
        [-4.2812, -2.5469, -3.1406,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.5508,  0.4375, -1.8672,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.9570,  1.0234, -4.5625,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.3125,  2.1875,  0.0635,  ...,  0.0000,  0.0000,  0.0000],
        [-2.5469, -3.0156, -3.7344,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 7.3382e-07, 2.3824e-07,  ..., 1.2996e-24, 1.0121e-24,
         4.2191e-25],
        [9.9989e-01, 1.0890e-04, 1.5533e-06,  ..., 3.1172e-24, 2.4277e-24,
         1.4725e-24],
        [1.0000e+00, 8.7642e-08, 6.8256e-08,  ..., 7.8824e-25, 3.2859e-25,
         1.9930e-25],
        ...,
        [5.7194e-01, 1.8568e-01, 9.9389e-02,  ..., 2.8450e-21, 2.8450e-21,
         2.5107e-21],
        [9.9748e-01, 1.4997e-03, 4.8687e-04,  ..., 9.5785e-24, 5.1270e-24,
         2.1372e-24],
        [9.8592e-01, 1.4063e-02, 1.2824e-05,  ..., 3.0133e-26, 2.6592e-26,
         8.6332e-27]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.5719, 0.7576, 0.8570,  ..., 1.0000, 1.0000, 1.0000],
        [0.9975, 0.9990, 0.9995,  ..., 1.0000, 1.0000, 1.0000],
        [0.9859, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([27])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  7,  7,  7,  7,  8,  9, 10, 11, 12, 13,
        14, 15, 16, 17, 17, 17, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([27, 94])


tensor([[    1, 32010,  1724,  ..., 29896, 29941, 29889],
        [    1, 32010,  1724,  ..., 29871, 29896, 29941],
        [    1, 32010,  1724,  ...,  3171,   310,  1048],
        ...,
        [    1, 32010,  1724,  ..., 15899,  3171,   310],
        [    1, 32010,  1724,  ..., 21210,   573,  3171],
        [    1, 32010,  1724,  ...,  1135, 29871, 29896]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([27])


tensor([-12.3635, -14.4418, -16.4374, -17.0624, -17.6874, -18.7306, -20.2306,
        -17.0216, -17.0216, -17.0216, -17.0216, -17.0216, -17.1638, -14.6212,
        -14.8108, -16.3200, -18.4957, -18.4957, -19.4957, -20.1028, -21.8528,
        -18.5188, -18.5188, -18.5188, -18.5188, -20.2969, -18.4107],
       device='cuda:0')


new_candidate_toks
torch.Size([27, 1])


tensor([[29953],
        [29889],
        [29871],
        [29871],
        [29871],
        [  310],
        [29896],
        [ 1048],
        [ 8886],
        [29871],
        [ 2820],
        [14235],
        [29941],
        [29953],
        [29953],
        [29889],
        [29871],
        [29871],
        [29871],
        [  310],
        [29896],
        [ 8886],
        [ 1048],
        [29871],
        [14235],
        [  310],
        [29941]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([27])


tensor([-9.5367e-07, -1.1242e-04, -4.7684e-07, -1.5497e-06, -4.7684e-07,
        -1.0274e-03, -1.9920e-03, -9.4841e-01, -1.3234e+00, -1.9484e+00,
        -2.3234e+00, -2.6984e+00, -8.6259e-03, -1.6689e-06, -4.5300e-06,
        -2.1287e-04, -5.9605e-07, -2.0266e-06, -9.5367e-07, -1.3390e-03,
        -1.7748e-03, -5.5871e-01, -1.6837e+00, -2.3087e+00, -2.6837e+00,
        -2.5234e-03, -1.4181e-02], device='cuda:0')


new_candidates
torch.Size([27, 95])


tensor([[    1, 32010,  1724,  ..., 29941, 29889, 29953],
        [    1, 32010,  1724,  ..., 29896, 29941, 29889],
        [    1, 32010,  1724,  ...,   310,  1048, 29871],
        ...,
        [    1, 32010,  1724,  ...,  3171,   310, 14235],
        [    1, 32010,  1724,  ...,   573,  3171,   310],
        [    1, 32010,  1724,  ..., 29871, 29896, 29941]], device='cuda:0')


new_candidate_logprobs
torch.Size([27])


tensor([-12.3635, -14.4419, -16.4374, -17.0624, -17.6874, -18.7316, -20.2325,
        -17.9700, -18.3450, -18.9700, -19.3450, -19.7200, -17.1724, -14.6212,
        -14.8108, -16.3202, -18.4957, -18.4957, -19.4957, -20.1042, -21.8546,
        -19.0775, -20.2025, -20.8275, -21.2025, -20.2994, -18.4249],
       device='cuda:0')

infer end: GPU memory used: 20105 MB.
event: level
id: 83
data: [{"content": "6", "parent": 0, "prob": -12.36347484588623}, {"content": ".", "parent": 1, "prob": -14.44192123413086}, {"content": "", "parent": 2, "prob": -16.437368392944336}, {"content": "", "parent": 3, "prob": -17.06237030029297}, {"content": "", "parent": 4, "prob": -17.687368392944336}, {"content": "of", "parent": 5, "prob": -18.731578826904297}, {"content": "1", "parent": 6, "prob": -20.232542037963867}, {"content": "about", "parent": 7, "prob": -17.97000503540039}, {"content": "nearly", "parent": 7, "prob": -18.34500503540039}, {"content": "", "parent": 7, "prob": -18.97000503540039}, {"content": "around", "parent": 7, "prob": -19.34500503540039}, {"content": "approximately", "parent": 7, "prob": -19.72000503540039}, {"content": "3", "parent": 8, "prob": -17.172428131103516}, {"content": "6", "parent": 9, "prob": -14.621246337890625}, {"content": "6", "parent": 10, "prob": -14.810772895812988}, {"content": ".", "p

array([[ 1.5390625 ,  1.4140625 ,  2.515625  , ..., -2.78125   ,
         0.609375  , -1.4921875 ],
       [-0.38476562, -1.328125  ,  0.515625  , ...,  0.96484375,
         0.23828125, -0.7734375 ],
       [-0.37304688, -2.125     ,  1.53125   , ...,  2.28125   ,
         2.3125    , -0.29101562],
       ...,
       [-0.68359375, -2.34375   ,  1.375     , ...,  1.578125  ,
         2.140625  , -1.015625  ],
       [-0.50390625, -2.234375  ,  1.421875  , ...,  2.03125   ,
         2.375     , -0.60546875],
       [-0.36523438, -0.0612793 ,  0.51953125, ..., -0.37695312,
         1.90625   ,  0.46289062]], dtype=float32)


k_mean_space
(20, 2)


array([[92.55666 , 48.150677],
       [93.925385, 61.57979 ],
       [40.273903, 84.991585],
       [40.614033, 85.05946 ],
       [39.972885, 85.09276 ],
       [62.082455, 91.06931 ],
       [88.37734 , 73.57691 ],
       [59.295624, 91.021065],
       [59.498627, 89.82364 ],
       [40.332634, 85.52446 ],
       [59.97928 , 92.00836 ],
       [58.93999 , 91.46229 ],
       [94.75836 , 62.95971 ],
       [92.629845, 48.404827],
       [92.5421  , 48.442017],
       [94.01548 , 61.56137 ],
       [40.432793, 84.93511 ],
       [41.01831 , 84.96349 ],
       [40.2498  , 85.095215],
       [62.520435, 91.52134 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-240.85987473, -109.96258545])


closest
(2,)


array([4, 0])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.9766,  0.1758, -0.4902,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.6094,  0.5664, -1.8438,  ...,  0.0000,  0.0000,  0.0000],
        [-0.6367,  1.5156, -0.1279,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.2871,  2.3281,  0.0854,  ...,  0.0000,  0.0000,  0.0000],
        [-0.4922,  1.3672, -0.1943,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.8945, -2.3438, -5.6250,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9715e-01, 2.1813e-03, 5.5151e-04,  ..., 1.4211e-21, 7.6066e-22,
         1.4978e-22],
        [1.0000e+00, 6.4759e-07, 1.6374e-07,  ..., 3.0868e-25, 2.2583e-25,
         1.5521e-25],
        [9.9753e-01, 2.4726e-03, 5.0310e-07,  ..., 7.6853e-21, 6.3713e-21,
         5.9853e-21],
        ...,
        [5.9265e-01, 4.0732e-01, 1.1216e-05,  ..., 1.6096e-19, 1.6096e-19,
         1.1062e-19],
        [9.9752e-01, 2.4726e-03, 8.8297e-07,  ..., 2.3672e-20, 2.3672e-20,
         2.2238e-20],
        [3.7046e-01, 3.7046e-01, 1.9829e-01,  ..., 1.0396e-22, 7.1452e-23,
         4.3338e-23]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9971, 0.9993, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9975, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.5927, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9975, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.3705, 0.7409, 0.9392,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([28])


tensor([ 0,  1,  2,  3,  3,  4,  5,  5,  5,  6,  7,  8,  9, 10, 11, 12, 12, 12,
        13, 14, 15, 16, 17, 17, 18, 19, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([28, 95])


tensor([[    1, 32010,  1724,  ..., 29941, 29889, 29953],
        [    1, 32010,  1724,  ..., 29896, 29941, 29889],
        [    1, 32010,  1724,  ...,   310,  1048, 29871],
        ...,
        [    1, 32010,  1724,  ...,   292,  3171,   310],
        [    1, 32010,  1724,  ...,   292,  3171,   310],
        [    1, 32010,  1724,  ...,   292,  3171,   310]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([28])


tensor([-12.3635, -14.4419, -16.4374, -17.0624, -17.0624, -17.6874, -18.7316,
        -18.7316, -18.7316, -20.2325, -17.9700, -18.3450, -18.9700, -19.3450,
        -19.7200, -17.1724, -17.1724, -17.1724, -14.6212, -14.8108, -16.3202,
        -18.4957, -18.4957, -18.4957, -19.4957, -20.1042, -20.1042, -20.1042],
       device='cuda:0')


new_candidate_toks
torch.Size([28, 1])


tensor([[ 7800],
        [29953],
        [29896],
        [29896],
        [29906],
        [29896],
        [ 1048],
        [ 8886],
        [14235],
        [29941],
        [29871],
        [29871],
        [29896],
        [29871],
        [29871],
        [29889],
        [ 7800],
        [20052],
        [ 7800],
        [ 7800],
        [29953],
        [29896],
        [29896],
        [29906],
        [29896],
        [ 1048],
        [ 8886],
        [14235]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([28])


tensor([-2.8565e-03, -8.3447e-07, -2.4773e-03, -5.7595e-01, -8.2595e-01,
        -3.6024e-03, -5.8930e-01, -1.5893e+00, -1.7143e+00, -2.3253e-02,
        -4.7684e-07, -1.3113e-06, -1.9319e-03, -4.7684e-07, -7.1526e-07,
        -6.7656e-01, -1.3016e+00, -1.6766e+00, -4.8640e-03, -2.0932e-03,
        -1.9074e-06, -1.9306e-03, -5.2315e-01, -8.9815e-01, -2.4786e-03,
        -9.9301e-01, -9.9301e-01, -1.6180e+00], device='cuda:0')


new_candidates
torch.Size([28, 96])


tensor([[    1, 32010,  1724,  ..., 29889, 29953,  7800],
        [    1, 32010,  1724,  ..., 29941, 29889, 29953],
        [    1, 32010,  1724,  ...,  1048, 29871, 29896],
        ...,
        [    1, 32010,  1724,  ...,  3171,   310,  1048],
        [    1, 32010,  1724,  ...,  3171,   310,  8886],
        [    1, 32010,  1724,  ...,  3171,   310, 14235]], device='cuda:0')


new_candidate_logprobs
torch.Size([28])


tensor([-12.3663, -14.4419, -16.4398, -17.6383, -17.8883, -17.6910, -19.3209,
        -20.3209, -20.4459, -20.2558, -17.9700, -18.3450, -18.9719, -19.3450,
        -19.7200, -17.8490, -18.4740, -18.8490, -14.6261, -14.8129, -16.3202,
        -18.4976, -19.0188, -19.3938, -19.4981, -21.0972, -21.0972, -21.7222],
       device='cuda:0')

infer end: GPU memory used: 20339 MB.
event: level
id: 84
data: [{"content": "miles", "parent": 0, "prob": -12.366331100463867}, {"content": "6", "parent": 1, "prob": -14.441922187805176}, {"content": "1", "parent": 2, "prob": -16.43984603881836}, {"content": "1", "parent": 3, "prob": -17.638324737548828}, {"content": "2", "parent": 3, "prob": -17.888324737548828}, {"content": "1", "parent": 4, "prob": -17.69097137451172}, {"content": "about", "parent": 5, "prob": -19.32087516784668}, {"content": "nearly", "parent": 5, "prob": -20.32087516784668}, {"content": "approximately", "parent": 5, "prob": -20.44587516784668}, {"content": "3", "parent": 6, "prob": -20.255794525146484}, {"content": "", "parent": 7, "prob": -17.97000503540039}, {"content": "", "parent": 8, "prob": -18.345006942749023}, {"content": "1", "parent": 9, "prob": -18.97193717956543}, {"content": "", "parent": 10, "prob": -19.34500503540039}, {"content": "", "parent": 11, "prob": -19.72000503540039}, {"content": ".", "par

array([[-1.0078125 , -1.203125  ,  0.02050781, ..., -0.59375   ,
        -0.36523438, -0.41210938],
       [ 1.265625  ,  1.3125    ,  2.53125   , ..., -3.015625  ,
         0.40429688, -1.1953125 ],
       [-0.54296875, -1.0234375 , -1.046875  , ...,  3.09375   ,
         2.09375   , -2.890625  ],
       ...,
       [-0.95703125, -0.03222656,  0.60546875, ..., -0.4296875 ,
         0.24804688, -1.4375    ],
       [-0.70703125, -1.2734375 ,  0.47070312, ..., -0.14746094,
        -0.10205078, -0.02990723],
       [-0.921875  , -1.203125  ,  0.08935547, ..., -0.76171875,
        -0.296875  , -0.36523438]], dtype=float32)


k_mean_space
(20, 2)


array([[ 87.710915,  24.293892],
       [ 76.04551 ,  99.974815],
       [ 56.96311 , 102.01398 ],
       [ 56.960857, 101.79279 ],
       [ 64.84844 , 103.613976],
       [ 56.921135, 101.87582 ],
       [ 67.51014 ,  97.18651 ],
       [ 65.23475 ,  97.14073 ],
       [ 67.40163 ,  96.16169 ],
       [ 76.08476 , 101.67526 ],
       [ 54.685024, 100.74429 ],
       [ 54.45547 , 100.84608 ],
       [ 57.3601  , 102.23979 ],
       [ 54.842854, 100.65475 ],
       [ 54.897835, 100.89259 ],
       [ 77.28856 , 103.95347 ],
       [ 87.55029 ,  28.51959 ],
       [ 94.29963 ,  78.18145 ],
       [ 87.424515,  27.946484],
       [ 88.22936 ,  25.906885]], dtype=float32)


k_mean_clusters
(20,)


array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-276.6437521 ,  -79.12827492])


closest
(2,)


array([11,  0])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.8906, -2.7500, -3.6875,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.6562,  0.4629, -0.3281,  ...,  0.0000,  0.0000,  0.0000],
        [-3.6875, -3.3906, -4.5625,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.5820,  0.5195,  0.1318,  ...,  0.0000,  0.0000,  0.0000],
        [-1.9766, -5.1250, -4.3750,  ...,  0.0000,  0.0000,  0.0000],
        [-1.0859, -2.3906, -3.8594,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9962e-01, 3.3533e-04, 3.5344e-05,  ..., 1.6681e-24, 6.1364e-25,
         4.7791e-25],
        [9.9642e-01, 2.4699e-03, 9.0861e-04,  ..., 8.6131e-22, 8.6131e-22,
         1.1657e-22],
        [9.9592e-01, 4.0701e-03, 3.7115e-06,  ..., 5.6867e-26, 1.2689e-26,
         7.6961e-27],
        ...,
        [1.0000e+00, 1.5230e-08, 1.3440e-08,  ..., 9.4141e-26, 8.3079e-26,
         6.0183e-27],
        [9.9963e-01, 3.3534e-04, 1.4734e-05,  ..., 1.7934e-23, 4.5343e-24,
         2.7502e-24],
        [9.9964e-01, 2.9594e-04, 5.1426e-05,  ..., 6.9537e-25, 5.4155e-25,
         2.8987e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9996, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9964, 0.9989, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        [0.9959, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9996, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9996, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([21])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 11, 12, 13, 14, 15, 16,
        17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([21, 96])


tensor([[    1, 32010,  1724,  ..., 29889, 29953,  7800],
        [    1, 32010,  1724,  ..., 29941, 29889, 29953],
        [    1, 32010,  1724,  ...,  1048, 29871, 29896],
        ...,
        [    1, 32010,  1724,  ..., 29896, 29941, 20052],
        [    1, 32010,  1724,  ..., 29889, 29953,  7800],
        [    1, 32010,  1724,  ..., 29889, 29953,  7800]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([21])


tensor([-12.3663, -14.4419, -16.4398, -17.6383, -17.8883, -17.6910, -19.3209,
        -20.3209, -20.4459, -20.2558, -17.9700, -18.3450, -18.3450, -18.9719,
        -19.3450, -19.7200, -17.8490, -18.4740, -18.8490, -14.6261, -14.8129],
       device='cuda:0')


new_candidate_toks
torch.Size([21, 1])


tensor([[  313],
        [ 7800],
        [29941],
        [29946],
        [29906],
        [29941],
        [29871],
        [29871],
        [29871],
        [29889],
        [29896],
        [29906],
        [29896],
        [29941],
        [29896],
        [29896],
        [29953],
        [  313],
        [ 2699],
        [  313],
        [  313]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([21])


tensor([-3.8333e-04, -3.5902e-03, -4.0866e-03, -1.0056e-01, -3.4028e-04,
        -6.7207e-03, -5.9605e-07, -1.5497e-06, -1.4305e-06, -2.1696e-05,
        -4.0802e-03, -4.2871e-01, -1.0537e+00, -2.4804e-03, -2.8065e-03,
        -4.0808e-03, -6.7236e-05, -8.2211e-04,  0.0000e+00, -3.6509e-04,
        -3.6055e-04], device='cuda:0')


new_candidates
torch.Size([21, 97])


tensor([[    1, 32010,  1724,  ..., 29953,  7800,   313],
        [    1, 32010,  1724,  ..., 29889, 29953,  7800],
        [    1, 32010,  1724,  ..., 29871, 29896, 29941],
        ...,
        [    1, 32010,  1724,  ..., 29941, 20052,  2699],
        [    1, 32010,  1724,  ..., 29953,  7800,   313],
        [    1, 32010,  1724,  ..., 29953,  7800,   313]], device='cuda:0')


new_candidate_logprobs
torch.Size([21])


tensor([-12.3667, -14.4455, -16.4439, -17.7389, -17.8887, -17.6977, -19.3209,
        -20.3209, -20.4459, -20.2558, -17.9741, -18.7737, -19.3987, -18.9744,
        -19.3478, -19.7241, -17.8491, -18.4748, -18.8490, -14.6265, -14.8132],
       device='cuda:0')

infer end: GPU memory used: 20575 MB.
event: level
id: 85
data: [{"content": "(", "parent": 0, "prob": -12.366714477539062}, {"content": "miles", "parent": 1, "prob": -14.445512771606445}, {"content": "3", "parent": 2, "prob": -16.443933486938477}, {"content": "4", "parent": 3, "prob": -17.73888397216797}, {"content": "2", "parent": 4, "prob": -17.88866424560547}, {"content": "3", "parent": 5, "prob": -17.69769287109375}, {"content": "", "parent": 6, "prob": -19.32087516784668}, {"content": "", "parent": 7, "prob": -20.320877075195312}, {"content": "", "parent": 8, "prob": -20.445877075195312}, {"content": ".", "parent": 9, "prob": -20.255815505981445}, {"content": "1", "parent": 10, "prob": -17.974084854125977}, {"content": "2", "parent": 11, "prob": -18.773719787597656}, {"content": "1", "parent": 11, "prob": -19.398719787597656}, {"content": "3", "parent": 12, "prob": -18.974416732788086}, {"content": "1", "parent": 13, "prob": -19.347810745239258}, {"content": "1", "parent": 14, "p

array([[-0.78515625, -0.71484375, -1.        , ...,  0.83203125,
         0.21777344, -0.671875  ],
       [-0.46484375, -1.390625  ,  0.22558594, ..., -0.48632812,
        -0.19921875,  0.05249023],
       [ 0.578125  , -0.23144531,  1.1875    , ..., -0.6015625 ,
         0.22167969, -1.6171875 ],
       ...,
       [ 0.51171875, -0.49023438, -1.671875  , ...,  0.3203125 ,
         0.65625   , -1.0703125 ],
       [-0.453125  , -1.609375  ,  1.640625  , ..., -0.22558594,
        -0.609375  , -0.11425781],
       [-0.6640625 , -0.875     , -1.2109375 , ...,  1.3046875 ,
         0.3359375 , -0.6328125 ]], dtype=float32)


k_mean_space
(20, 2)


array([[67.42813 , 92.09151 ],
       [82.78113 , 85.06081 ],
       [82.4513  , 40.75803 ],
       [81.345566, 52.525864],
       [80.07257 , 58.107235],
       [82.481674, 40.77084 ],
       [58.933456, 85.86513 ],
       [58.41585 , 86.062004],
       [59.17711 , 85.89816 ],
       [83.426315, 73.44367 ],
       [55.374268, 80.2887  ],
       [65.5633  , 80.94694 ],
       [55.836037, 80.07027 ],
       [82.91519 , 42.376614],
       [55.291805, 80.14856 ],
       [55.4341  , 80.23243 ],
       [82.43136 , 68.52565 ],
       [67.93434 , 91.05871 ],
       [84.55561 , 82.60662 ],
       [67.8639  , 92.44026 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-215.21955967, -145.6974411 ])


closest
(2,)


array([14,  2])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.4434,  0.0491, -3.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-1.8828, -3.5625, -4.1875,  ...,  0.0000,  0.0000,  0.0000],
        [-4.2812, -3.2188, -3.4375,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.7109, -2.8281, -2.1719,  ...,  0.0000,  0.0000,  0.0000],
        [-1.4062, -1.6562, -4.2812,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.9727, -0.1729, -3.0469,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9984e-01, 6.6046e-05, 5.8285e-05,  ..., 6.7309e-22, 3.6028e-22,
         2.1852e-22],
        [9.9972e-01, 2.6118e-04, 7.8871e-06,  ..., 7.4765e-24, 1.8903e-24,
         8.9294e-25],
        [9.9999e-01, 1.0130e-05, 1.3709e-06,  ..., 6.9561e-25, 6.1387e-25,
         4.2191e-25],
        ...,
        [9.9122e-01, 4.5902e-03, 3.5749e-03,  ..., 1.8139e-21, 1.6007e-21,
         1.1002e-21],
        [9.9925e-01, 5.5267e-04, 1.0883e-04,  ..., 5.1361e-24, 2.7492e-24,
         8.9252e-25],
        [9.9974e-01, 1.0888e-04, 8.4797e-05,  ..., 4.6256e-22, 4.6256e-22,
         3.1792e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9998, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9997, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9912, 0.9958, 0.9994,  ..., 1.0000, 1.0000, 1.0000],
        [0.9993, 0.9998, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.9997, 0.9999, 0.9999,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([24])


tensor([ 0,  1,  2,  3,  3,  3,  4,  4,  5,  6,  7,  7,  8,  9, 10, 11, 12, 13,
        14, 15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([24, 97])


tensor([[    1, 32010,  1724,  ..., 29953,  7800,   313],
        [    1, 32010,  1724,  ..., 29889, 29953,  7800],
        [    1, 32010,  1724,  ..., 29871, 29896, 29941],
        ...,
        [    1, 32010,  1724,  ..., 29941,  7800,   313],
        [    1, 32010,  1724,  ..., 29941, 20052,  2699],
        [    1, 32010,  1724,  ..., 29953,  7800,   313]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([24])


tensor([-12.3667, -14.4455, -16.4439, -17.7389, -17.7389, -17.7389, -17.8887,
        -17.8887, -17.6977, -19.3209, -20.3209, -20.3209, -20.4459, -20.2558,
        -17.9741, -18.7737, -19.3987, -18.9744, -19.3478, -19.7241, -17.8491,
        -18.4748, -18.8490, -14.6265], device='cuda:0')


new_candidate_toks
torch.Size([24, 1])


tensor([[29906],
        [  313],
        [29889],
        [ 7800],
        [20052],
        [29892],
        [20052],
        [29892],
        [29889],
        [29896],
        [29896],
        [29906],
        [29896],
        [29953],
        [29941],
        [29906],
        [29946],
        [29889],
        [29941],
        [29941],
        [ 7800],
        [29906],
        [  313],
        [29906]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([24])


tensor([-1.5963e-04, -2.8203e-04, -1.2398e-05, -7.6320e-01, -1.3882e+00,
        -1.5132e+00, -1.8823e-01, -2.1882e+00, -1.4032e-04, -1.7033e-03,
        -2.8116e-01, -1.4062e+00, -1.7039e-03, -3.5763e-07, -1.1085e-02,
        -2.0906e-04, -4.9401e-02, -3.3379e-06, -2.3286e-02, -8.6348e-03,
        -1.7198e-03, -8.8214e-03, -7.4599e-04, -2.5532e-04], device='cuda:0')


new_candidates
torch.Size([24, 98])


tensor([[    1, 32010,  1724,  ...,  7800,   313, 29906],
        [    1, 32010,  1724,  ..., 29953,  7800,   313],
        [    1, 32010,  1724,  ..., 29896, 29941, 29889],
        ...,
        [    1, 32010,  1724,  ...,  7800,   313, 29906],
        [    1, 32010,  1724,  ..., 20052,  2699,   313],
        [    1, 32010,  1724,  ...,  7800,   313, 29906]], device='cuda:0')


new_candidate_logprobs
torch.Size([24])


tensor([-12.3669, -14.4458, -16.4439, -18.5021, -19.1271, -19.2521, -18.0769,
        -20.0769, -17.6978, -19.3226, -20.6020, -21.7270, -20.4476, -20.2558,
        -17.9852, -18.7739, -19.4481, -18.9744, -19.3711, -19.7327, -17.8508,
        -18.4836, -18.8497, -14.6267], device='cuda:0')

infer end: GPU memory used: 20813 MB.
event: level
id: 86
data: [{"content": "2", "parent": 0, "prob": -12.366873741149902}, {"content": "(", "parent": 1, "prob": -14.445795059204102}, {"content": ".", "parent": 2, "prob": -16.443946838378906}, {"content": "miles", "parent": 3, "prob": -18.5020809173584}, {"content": "kilom", "parent": 3, "prob": -19.1270809173584}, {"content": ",", "parent": 3, "prob": -19.2520809173584}, {"content": "kilom", "parent": 4, "prob": -18.07689094543457}, {"content": ",", "parent": 4, "prob": -20.076892852783203}, {"content": ".", "parent": 5, "prob": -17.697834014892578}, {"content": "1", "parent": 6, "prob": -19.32257843017578}, {"content": "1", "parent": 7, "prob": -20.602041244506836}, {"content": "2", "parent": 7, "prob": -21.727041244506836}, {"content": "1", "parent": 8, "prob": -20.447580337524414}, {"content": "6", "parent": 9, "prob": -20.255815505981445}, {"content": "3", "parent": 10, "prob": -17.985170364379883}, {"content": "2", "parent": 11,

array([[-0.26367188, -0.9453125 ,  1.46875   , ...,  3.4375    ,
        -1.6171875 , -3.578125  ],
       [-0.578125  , -1.0234375 , -1.0859375 , ...,  0.9296875 ,
         0.0402832 , -0.75      ],
       [-0.31445312, -0.921875  ,  0.43359375, ...,  0.89453125,
         0.40234375, -0.86328125],
       ...,
       [-0.5546875 , -0.9921875 ,  0.12597656, ...,  0.5625    ,
         0.24511719, -0.82421875],
       [ 0.29296875, -0.03613281,  1.0625    , ..., -0.78515625,
         0.2421875 , -1.2265625 ],
       [ 0.17578125, -0.234375  ,  1.0390625 , ..., -0.94921875,
         0.46875   , -1.4453125 ]], dtype=float32)


k_mean_space
(20, 2)


array([[70.19581 , 87.99509 ],
       [76.530754, 88.40404 ],
       [79.88891 , 60.842384],
       [90.174385, 79.49366 ],
       [94.27402 , 77.8144  ],
       [70.09259 , 77.896996],
       [93.51333 , 77.203995],
       [72.173164, 82.92667 ],
       [79.77452 , 60.877377],
       [50.85949 , 78.37223 ],
       [50.599236, 78.232925],
       [58.61902 , 80.37633 ],
       [50.726765, 78.151985],
       [80.915474, 71.52224 ],
       [80.046364, 51.61431 ],
       [75.29008 , 64.43858 ],
       [77.23118 , 60.19432 ],
       [80.23102 , 61.80301 ],
       [79.99181 , 51.45843 ],
       [79.97512 , 51.55168 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-148.24088383, -224.38911057])


closest
(2,)


array([10, 18])


last_tok_logits
torch.Size([20, 32064])


tensor([[-1.2266, -1.7578, -3.4375,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.2988,  0.2295, -3.2344,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.7656,  1.1328, -1.7500,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.3438,  1.7969, -0.4004,  ...,  0.0000,  0.0000,  0.0000],
        [-4.1562, -2.1875, -1.9297,  ...,  0.0000,  0.0000,  0.0000],
        [-4.4375, -2.5938, -2.8906,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 1.8554e-07, 4.9444e-09,  ..., 4.3596e-28, 1.6038e-28,
         8.5846e-29],
        [9.9975e-01, 1.2338e-04, 6.6040e-05,  ..., 9.7926e-22, 4.6257e-22,
         2.4760e-22],
        [1.0000e+00, 5.0435e-07, 1.1254e-07,  ..., 7.3317e-26, 2.6972e-26,
         1.6359e-26],
        ...,
        [1.0000e+00, 9.4224e-07, 7.7344e-08,  ..., 1.6359e-26, 9.9224e-27,
         9.9224e-27],
        [9.9997e-01, 2.1445e-05, 4.7850e-06,  ..., 1.8908e-24, 1.0121e-24,
         8.9316e-25],
        [9.9996e-01, 4.0064e-05, 9.4221e-07,  ..., 8.9315e-25, 4.7807e-25,
         4.2189e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 0.9999, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([23])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15, 16,
        16, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([23, 98])


tensor([[    1, 32010,  1724,  ...,  7800,   313, 29906],
        [    1, 32010,  1724,  ..., 29953,  7800,   313],
        [    1, 32010,  1724,  ..., 29896, 29941, 29889],
        ...,
        [    1, 32010,  1724,  ..., 29896, 29941, 29889],
        [    1, 32010,  1724,  ..., 29871, 29896, 29941],
        [    1, 32010,  1724,  ..., 29871, 29896, 29941]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([23])


tensor([-12.3669, -14.4458, -16.4439, -18.5021, -19.1271, -19.2521, -18.0769,
        -20.0769, -17.6978, -19.3226, -20.6020, -21.7270, -20.4476, -20.2558,
        -17.9852, -18.7739, -18.7739, -19.4481, -19.4481, -19.4481, -18.9744,
        -19.3711, -19.7327], device='cuda:0')


new_candidate_toks
torch.Size([23, 1])


tensor([[29906],
        [29906],
        [29953],
        [  313],
        [ 2699],
        [29900],
        [ 2699],
        [29900],
        [29953],
        [29941],
        [29946],
        [29906],
        [29941],
        [ 7800],
        [29889],
        [20052],
        [29892],
        [ 7800],
        [29892],
        [20052],
        [29953],
        [29889],
        [29889]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([23])


tensor([-2.3842e-07, -2.4530e-04, -5.9605e-07, -1.2875e-05, -4.7684e-07,
        -1.1274e-02,  0.0000e+00, -1.5497e-06, -9.5367e-07, -8.6335e-03,
        -6.2744e-02, -2.7070e-04, -3.8080e-02, -1.3871e-03, -1.3471e-05,
        -3.3644e-01, -1.4614e+00, -6.6655e-01, -1.4166e+00, -1.6666e+00,
        -1.0729e-06, -2.8134e-05, -4.1486e-05], device='cuda:0')


new_candidates
torch.Size([23, 99])


tensor([[    1, 32010,  1724,  ...,   313, 29906, 29906],
        [    1, 32010,  1724,  ...,  7800,   313, 29906],
        [    1, 32010,  1724,  ..., 29941, 29889, 29953],
        ...,
        [    1, 32010,  1724,  ..., 29941, 29889, 29953],
        [    1, 32010,  1724,  ..., 29896, 29941, 29889],
        [    1, 32010,  1724,  ..., 29896, 29941, 29889]], device='cuda:0')


new_candidate_logprobs
torch.Size([23])


tensor([-12.3669, -14.4460, -16.4439, -18.5021, -19.1271, -19.2634, -18.0769,
        -20.0769, -17.6978, -19.3312, -20.6648, -21.7273, -20.4857, -20.2572,
        -17.9852, -19.1104, -20.2354, -20.1147, -20.8647, -21.1147, -18.9744,
        -19.3711, -19.7328], device='cuda:0')

infer end: GPU memory used: 21053 MB.
event: level
id: 87
data: [{"content": "2", "parent": 0, "prob": -12.366873741149902}, {"content": "2", "parent": 1, "prob": -14.446040153503418}, {"content": "6", "parent": 2, "prob": -16.443946838378906}, {"content": "(", "parent": 3, "prob": -18.502094268798828}, {"content": "eters", "parent": 4, "prob": -19.1270809173584}, {"content": "0", "parent": 5, "prob": -19.263355255126953}, {"content": "eters", "parent": 6, "prob": -18.07689094543457}, {"content": "0", "parent": 7, "prob": -20.076894760131836}, {"content": "6", "parent": 8, "prob": -17.69783592224121}, {"content": "3", "parent": 9, "prob": -19.33121109008789}, {"content": "4", "parent": 10, "prob": -20.664785385131836}, {"content": "2", "parent": 11, "prob": -21.727312088012695}, {"content": "3", "parent": 12, "prob": -20.485660552978516}, {"content": "miles", "parent": 13, "prob": -20.2572021484375}, {"content": ".", "parent": 14, "prob": -17.985183715820312}, {"content": "kilom", "par

array([[ 0.8359375 , -1.078125  ,  2.0625    , ..., -2.15625   ,
         0.51953125, -0.49023438],
       [-0.05053711, -0.86328125,  1.46875   , ...,  3.515625  ,
        -1.7421875 , -3.5       ],
       [ 1.171875  ,  1.4765625 ,  2.46875   , ..., -3.09375   ,
         0.48046875, -1.25      ],
       ...,
       [-0.59375   , -2.34375   ,  1.765625  , ..., -0.59375   ,
         0.43554688,  0.11767578],
       [-2.359375  , -1.7109375 ,  2.        , ...,  0.10888672,
        -1.6328125 , -3.125     ],
       [-0.74609375, -1.1015625 ,  0.74609375, ..., -0.8125    ,
         0.08544922, -1.6875    ]], dtype=float32)


k_mean_space
(20, 2)


array([[62.52524 , 88.34422 ],
       [78.66881 , 95.94651 ],
       [61.605827, 92.09254 ],
       [86.87479 , 82.95901 ],
       [88.039856, 45.913548],
       [66.75895 , 96.2044  ],
       [88.402916, 46.444508],
       [71.6161  , 95.091606],
       [61.721783, 92.12282 ],
       [59.50988 , 93.04526 ],
       [57.405552, 88.904816],
       [57.600063, 90.294365],
       [59.640617, 93.04605 ],
       [88.96529 , 49.124325],
       [71.99561 , 96.98804 ],
       [89.28318 , 63.28686 ],
       [69.20704 , 95.91768 ],
       [87.65757 , 45.54295 ],
       [66.49971 , 95.040565],
       [90.333694, 63.49085 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-241.58914375, -136.30298615])


closest
(2,)


array([10, 17])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.6094, -0.5586, -1.2969,  ...,  0.0000,  0.0000,  0.0000],
        [-1.0547, -1.1562, -3.4688,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.9844, -0.0237, -0.6875,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.1250, -2.4062, -2.2344,  ...,  0.0000,  0.0000,  0.0000],
        [-1.3125,  1.2188, -1.5078,  ...,  0.0000,  0.0000,  0.0000],
        [-2.0312, -1.8906, -1.5391,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.9407e-01, 5.9109e-03, 1.1411e-05,  ..., 2.7349e-24, 1.8797e-24,
         7.8356e-25],
        [1.0000e+00, 1.2752e-07, 7.1941e-09,  ..., 1.0458e-27, 4.9401e-28,
         2.9963e-28],
        [9.9794e-01, 1.7001e-03, 2.9544e-04,  ..., 1.4990e-22, 1.0303e-22,
         2.0287e-23],
        ...,
        [9.9994e-01, 5.1442e-05, 3.2886e-06,  ..., 3.5324e-24, 1.0121e-24,
         8.9314e-25],
        [9.9258e-01, 6.6879e-03, 7.0490e-04,  ..., 6.2153e-23, 3.3268e-23,
         2.9359e-23],
        [1.0000e+00, 8.7642e-08, 6.0236e-08,  ..., 1.2088e-25, 2.6972e-26,
         5.3111e-27]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9941, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9979, 0.9996, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9926, 0.9993, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([23])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 10, 10, 11, 11, 12, 13, 14,
        15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([23, 99])


tensor([[    1, 32010,  1724,  ...,   313, 29906, 29906],
        [    1, 32010,  1724,  ...,  7800,   313, 29906],
        [    1, 32010,  1724,  ..., 29941, 29889, 29953],
        ...,
        [    1, 32010,  1724,  ..., 29896, 29946,  7800],
        [    1, 32010,  1724,  ..., 29896, 29946, 29892],
        [    1, 32010,  1724,  ..., 29896, 29946, 20052]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([23])


tensor([-12.3669, -14.4460, -16.4439, -18.5021, -19.1271, -19.2634, -18.0769,
        -20.0769, -17.6978, -19.3312, -20.6648, -20.6648, -20.6648, -21.7273,
        -21.7273, -20.4857, -20.2572, -17.9852, -19.1104, -20.2354, -20.1147,
        -20.8647, -21.1147], device='cuda:0')


new_candidate_toks
torch.Size([23, 1])


tensor([[20052],
        [29906],
        [ 7800],
        [29906],
        [  313],
        [29900],
        [  313],
        [29900],
        [ 7800],
        [29889],
        [ 7800],
        [20052],
        [29892],
        [20052],
        [29892],
        [29889],
        [  313],
        [29953],
        [ 2699],
        [29900],
        [  313],
        [29900],
        [ 2699]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([23])


tensor([-5.9519e-03, -1.1921e-07, -2.0608e-03, -1.8086e-02, -3.2955e-04,
        -1.1921e-07, -1.1826e-04,  0.0000e+00, -1.4759e-03, -1.1563e-05,
        -4.5856e-01, -1.5836e+00, -2.2086e+00, -1.4118e-01, -2.5162e+00,
        -6.6521e-05, -1.8914e-04, -8.3447e-07,  0.0000e+00, -1.5497e-06,
        -5.6030e-05, -7.4503e-03, -3.5763e-07], device='cuda:0')


new_candidates
torch.Size([23, 100])


tensor([[    1, 32010,  1724,  ..., 29906, 29906, 20052],
        [    1, 32010,  1724,  ...,   313, 29906, 29906],
        [    1, 32010,  1724,  ..., 29889, 29953,  7800],
        ...,
        [    1, 32010,  1724,  ..., 29946,  7800,   313],
        [    1, 32010,  1724,  ..., 29946, 29892, 29900],
        [    1, 32010,  1724,  ..., 29946, 20052,  2699]], device='cuda:0')


new_candidate_logprobs
torch.Size([23])


tensor([-12.3728, -14.4460, -16.4460, -18.5202, -19.1274, -19.2634, -18.0770,
        -20.0769, -17.6993, -19.3312, -21.1233, -22.2483, -22.8733, -21.8685,
        -24.2435, -20.4857, -20.2574, -17.9852, -19.1104, -20.2354, -20.1147,
        -20.8721, -21.1147], device='cuda:0')

infer end: GPU memory used: 21297 MB.
event: level
id: 88
data: [{"content": "kilom", "parent": 0, "prob": -12.372825622558594}, {"content": "2", "parent": 1, "prob": -14.446040153503418}, {"content": "miles", "parent": 2, "prob": -16.446006774902344}, {"content": "2", "parent": 3, "prob": -18.520179748535156}, {"content": "(", "parent": 4, "prob": -19.127410888671875}, {"content": "0", "parent": 5, "prob": -19.263355255126953}, {"content": "(", "parent": 6, "prob": -18.077009201049805}, {"content": "0", "parent": 7, "prob": -20.076894760131836}, {"content": "miles", "parent": 8, "prob": -17.699312210083008}, {"content": ".", "parent": 9, "prob": -19.331222534179688}, {"content": "miles", "parent": 10, "prob": -21.123348236083984}, {"content": "kilom", "parent": 10, "prob": -22.248348236083984}, {"content": ",", "parent": 10, "prob": -22.873348236083984}, {"content": "kilom", "parent": 11, "prob": -21.868488311767578}, {"content": ",", "parent": 11, "prob": -24.243488311767578}, {"cont

array([[ 0.28125   , -1.7578125 ,  1.40625   , ..., -2.703125  ,
         1.15625   ,  0.72265625],
       [ 0.98046875, -1.3359375 ,  2.125     , ..., -2.265625  ,
         0.53515625, -0.60546875],
       [ 0.31640625, -1.9140625 ,  1.1015625 , ..., -0.06542969,
         0.38671875,  0.03881836],
       ...,
       [ 1.1328125 ,  1.484375  ,  2.609375  , ..., -3.140625  ,
         0.44335938, -1.2421875 ],
       [-0.14160156, -1.9375    ,  3.1875    , ..., -0.71484375,
         0.02294922, -0.70703125],
       [ 0.27148438, -1.3046875 ,  1.8125    , ..., -0.16015625,
        -0.94140625, -2.828125  ]], dtype=float32)


k_mean_space
(20, 2)


array([[72.62072 , 89.35963 ],
       [84.58275 , 68.226326],
       [52.461292, 87.94771 ],
       [90.04273 , 72.67655 ],
       [74.03196 , 87.69619 ],
       [84.21491 , 63.33756 ],
       [74.56532 , 88.08601 ],
       [85.37686 , 63.97557 ],
       [52.550564, 88.09005 ],
       [92.01649 , 64.98016 ],
       [53.101685, 87.12991 ],
       [63.7515  , 89.22083 ],
       [90.988235, 65.05233 ],
       [62.97823 , 88.22508 ],
       [92.36382 , 64.33003 ],
       [92.04116 , 65.14134 ],
       [75.55863 , 85.54284 ],
       [88.219894, 68.74729 ],
       [58.211166, 88.18416 ],
       [91.511894, 66.04975 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-188.33051109, -197.46081257])


closest
(2,)


array([2, 5])


last_tok_logits
torch.Size([20, 32064])


tensor([[-2.0312e+00, -4.5938e+00, -1.3672e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 9.6680e-02, -5.6250e-01, -1.5859e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.3672e+00, -3.2500e+00, -3.4219e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 1.5156e+00,  1.4038e-03, -2.9688e-01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-5.6152e-02, -2.9102e-01, -2.6562e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-2.7344e+00, -2.7188e+00, -1.7656e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 1.1254e-07, 1.3440e-08,  ..., 2.1426e-24, 1.4726e-24,
         1.1469e-24],
        [9.9327e-01, 6.6926e-03, 1.8798e-05,  ..., 3.5089e-24, 2.1282e-24,
         1.0053e-24],
        [9.9998e-01, 1.4739e-05, 4.2228e-06,  ..., 2.7512e-24, 2.4279e-24,
         6.1387e-25],
        ...,
        [9.9854e-01, 1.0318e-03, 3.7957e-04,  ..., 1.0309e-22, 7.0851e-23,
         2.0299e-23],
        [9.9977e-01, 2.0342e-04, 1.8921e-05,  ..., 1.2993e-24, 6.1374e-25,
         1.3694e-25],
        [1.0000e+00, 4.0587e-10, 3.8128e-10,  ..., 4.0030e-24, 2.7512e-24,
         6.1388e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9933, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9985, 0.9996, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.9998, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([22])


tensor([ 0,  1,  2,  3,  4,  4,  5,  6,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([22, 100])


tensor([[    1, 32010,  1724,  ..., 29906, 29906, 20052],
        [    1, 32010,  1724,  ...,   313, 29906, 29906],
        [    1, 32010,  1724,  ..., 29889, 29953,  7800],
        ...,
        [    1, 32010,  1724,  ..., 29941, 29889, 29953],
        [    1, 32010,  1724,  ..., 29906, 20052,  2699],
        [    1, 32010,  1724,  ..., 29906, 29892, 29900]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([22])


tensor([-12.3728, -14.4460, -16.4460, -18.5202, -19.1274, -19.1274, -19.2634,
        -18.0770, -18.0770, -20.0769, -17.6993, -19.3312, -21.1233, -22.2483,
        -22.8733, -21.8685, -24.2435, -20.4857, -20.2574, -17.9852, -19.1104,
        -20.2354], device='cuda:0')


new_candidate_toks
torch.Size([22, 1])


tensor([[ 2699],
        [20052],
        [  313],
        [29906],
        [12717],
        [29947],
        [29900],
        [29896],
        [12717],
        [29900],
        [  313],
        [29953],
        [  313],
        [ 2699],
        [29900],
        [ 2699],
        [29900],
        [29953],
        [29906],
        [ 7800],
        [  313],
        [29900]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([22])


tensor([-1.1921e-07, -6.7529e-03, -1.9193e-05, -3.4094e-05, -2.2410e-01,
        -1.9741e+00, -2.9803e-05, -4.3988e-01, -1.1899e+00, -1.3232e-05,
        -2.2531e-05, -1.9074e-06, -1.8835e-05, -1.1921e-07, -5.5571e-03,
         0.0000e+00, -9.5367e-07, -1.9074e-06, -2.3264e-03, -1.4648e-03,
        -2.2599e-04,  0.0000e+00], device='cuda:0')


new_candidates
torch.Size([22, 101])


tensor([[    1, 32010,  1724,  ..., 29906, 20052,  2699],
        [    1, 32010,  1724,  ..., 29906, 29906, 20052],
        [    1, 32010,  1724,  ..., 29953,  7800,   313],
        ...,
        [    1, 32010,  1724,  ..., 29889, 29953,  7800],
        [    1, 32010,  1724,  ..., 20052,  2699,   313],
        [    1, 32010,  1724,  ..., 29892, 29900, 29900]], device='cuda:0')


new_candidate_logprobs
torch.Size([22])


tensor([-12.3728, -14.4528, -16.4460, -18.5202, -19.3515, -21.1015, -19.2634,
        -18.5169, -19.2669, -20.0769, -17.6993, -19.3312, -21.1234, -22.2483,
        -22.8789, -21.8685, -24.2435, -20.4857, -20.2597, -17.9866, -19.1106,
        -20.2354], device='cuda:0')

infer end: GPU memory used: 21543 MB.
event: level
id: 89
data: [{"content": "eters", "parent": 0, "prob": -12.372825622558594}, {"content": "kilom", "parent": 1, "prob": -14.45279312133789}, {"content": "(", "parent": 2, "prob": -16.446025848388672}, {"content": "2", "parent": 3, "prob": -18.520214080810547}, {"content": "about", "parent": 4, "prob": -19.35150718688965}, {"content": "8", "parent": 4, "prob": -21.10150718688965}, {"content": "0", "parent": 5, "prob": -19.263385772705078}, {"content": "1", "parent": 6, "prob": -18.516889572143555}, {"content": "about", "parent": 6, "prob": -19.266889572143555}, {"content": "0", "parent": 7, "prob": -20.076908111572266}, {"content": "(", "parent": 8, "prob": -17.6993350982666}, {"content": "6", "parent": 9, "prob": -19.33122444152832}, {"content": "(", "parent": 10, "prob": -21.123367309570312}, {"content": "eters", "parent": 11, "prob": -22.248348236083984}, {"content": "0", "parent": 12, "prob": -22.87890625}, {"content": "eters", "par

array([[-0.76953125, -0.51171875,  0.890625  , ..., -2.203125  ,
        -0.07226562,  0.72265625],
       [ 0.03344727, -1.875     ,  1.5234375 , ..., -2.640625  ,
         0.97265625,  0.32421875],
       [-0.31054688, -0.7578125 , -1.578125  , ...,  0.94921875,
        -0.10595703, -0.78125   ],
       ...,
       [ 1.2578125 ,  1.4375    ,  2.5       , ..., -3.09375   ,
         0.49804688, -1.0546875 ],
       [-0.09521484, -0.85546875,  1.3671875 , ...,  3.515625  ,
        -1.53125   , -3.328125  ],
       [ 0.03613281, -1.8671875 ,  1.1640625 , ...,  0.13574219,
         0.19824219, -0.07275391]], dtype=float32)


k_mean_space
(20, 2)


array([[83.40513 , 63.396046],
       [89.332664, 75.93184 ],
       [57.970253, 87.3275  ],
       [78.75251 , 67.866295],
       [65.32636 , 80.04684 ],
       [83.03747 , 76.15888 ],
       [81.32477 , 63.659286],
       [68.84296 , 81.679565],
       [67.32776 , 81.25012 ],
       [82.48736 , 63.363598],
       [58.172596, 87.6201  ],
       [66.650505, 79.502556],
       [61.077217, 85.99125 ],
       [85.69486 , 57.625393],
       [75.5253  , 81.19869 ],
       [86.11591 , 57.728626],
       [82.96196 , 76.46285 ],
       [66.602005, 79.54593 ],
       [71.491264, 88.63738 ],
       [85.31686 , 62.021008]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-195.35959244, -192.13460922])


closest
(2,)


array([ 2, 13])


last_tok_logits
torch.Size([20, 32064])


tensor([[-1.5391, -4.0625, -4.3125,  ...,  0.0000,  0.0000,  0.0000],
        [-2.1719, -4.7188, -2.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0474, -0.1953, -3.7969,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 1.6875,  0.3242, -0.1699,  ...,  0.0000,  0.0000,  0.0000],
        [-0.8438, -0.7891, -3.5312,  ...,  0.0000,  0.0000,  0.0000],
        [-1.5781, -3.2656, -3.4219,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 1.9947e-06, 3.0590e-07,  ..., 2.2140e-27, 1.7243e-27,
         1.3429e-27],
        [1.0000e+00, 6.8256e-08, 7.1941e-09,  ..., 1.0121e-24, 5.4175e-25,
         4.7809e-25],
        [9.9895e-01, 4.3029e-04, 4.3029e-04,  ..., 4.9691e-21, 4.3852e-21,
         8.6349e-22],
        ...,
        [9.9861e-01, 1.0319e-03, 2.9563e-04,  ..., 1.3238e-22, 9.0981e-23,
         2.3004e-23],
        [1.0000e+00, 6.4759e-07, 2.2159e-08,  ..., 2.6972e-26, 1.6359e-26,
         8.7565e-27],
        [9.9996e-01, 3.5356e-05, 4.2227e-06,  ..., 2.7511e-24, 1.2995e-24,
         6.9559e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9989, 0.9994, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9986, 0.9996, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([20])


tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='cuda:0')


carryover_candidates
torch.Size([20, 101])


tensor([[    1, 32010,  1724,  ..., 29906, 20052,  2699],
        [    1, 32010,  1724,  ..., 29906, 29906, 20052],
        [    1, 32010,  1724,  ..., 29953,  7800,   313],
        ...,
        [    1, 32010,  1724,  ..., 29941, 29889, 29953],
        [    1, 32010,  1724,  ...,  7800,   313, 29906],
        [    1, 32010,  1724,  ..., 29889, 29953,  7800]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([20])


tensor([-12.3728, -14.4528, -16.4460, -18.5202, -19.3515, -21.1015, -19.2634,
        -18.5169, -19.2669, -20.0769, -17.6993, -19.3312, -21.1234, -22.2483,
        -22.8789, -21.8685, -24.2435, -20.4857, -20.2597, -17.9866],
       device='cuda:0')


new_candidate_toks
torch.Size([20, 1])


tensor([[29897],
        [ 2699],
        [29906],
        [20052],
        [29871],
        [29889],
        [27881],
        [29941],
        [29871],
        [27881],
        [29906],
        [ 7800],
        [29906],
        [  313],
        [29900],
        [  313],
        [29900],
        [ 7800],
        [29906],
        [  313]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([20])


tensor([-2.9802e-06, -1.1921e-07, -1.0550e-03, -1.1286e-02, -1.1921e-07,
        -1.9431e-05, -1.5439e-03, -9.1666e-04,  0.0000e+00, -2.9364e-02,
        -1.8216e-03, -1.4857e-03, -1.6483e-02, -5.7964e-04,  0.0000e+00,
        -1.8104e-04,  0.0000e+00, -1.3917e-03, -5.9605e-07, -4.0413e-05],
       device='cuda:0')


new_candidates
torch.Size([20, 102])


tensor([[    1, 32010,  1724,  ..., 20052,  2699, 29897],
        [    1, 32010,  1724,  ..., 29906, 20052,  2699],
        [    1, 32010,  1724,  ...,  7800,   313, 29906],
        ...,
        [    1, 32010,  1724,  ..., 29889, 29953,  7800],
        [    1, 32010,  1724,  ...,   313, 29906, 29906],
        [    1, 32010,  1724,  ..., 29953,  7800,   313]], device='cuda:0')


new_candidate_logprobs
torch.Size([20])


tensor([-12.3728, -14.4528, -16.4471, -18.5315, -19.3515, -21.1015, -19.2649,
        -18.5178, -19.2669, -20.1063, -17.7012, -19.3327, -21.1399, -22.2489,
        -22.8789, -21.8687, -24.2435, -20.4871, -20.2597, -17.9867],
       device='cuda:0')

infer end: GPU memory used: 21791 MB.
event: level
id: 90
data: [{"content": ")", "parent": 0, "prob": -12.372828483581543}, {"content": "eters", "parent": 1, "prob": -14.45279312133789}, {"content": "2", "parent": 2, "prob": -16.447080612182617}, {"content": "kilom", "parent": 3, "prob": -18.5314998626709}, {"content": "", "parent": 4, "prob": -19.35150718688965}, {"content": ".", "parent": 5, "prob": -21.101526260375977}, {"content": "meters", "parent": 6, "prob": -19.264928817749023}, {"content": "3", "parent": 7, "prob": -18.517807006835938}, {"content": "", "parent": 8, "prob": -19.266889572143555}, {"content": "meters", "parent": 9, "prob": -20.106271743774414}, {"content": "2", "parent": 10, "prob": -17.701156616210938}, {"content": "miles", "parent": 11, "prob": -19.33271026611328}, {"content": "2", "parent": 12, "prob": -21.139850616455078}, {"content": "(", "parent": 13, "prob": -22.24892807006836}, {"content": "0", "parent": 14, "prob": -22.87890625}, {"content": "(", "paren

array([[-0.29882812, -0.83203125, -0.20800781, ..., -2.234375  ,
         0.51171875,  0.85546875],
       [-0.28125   , -0.703125  ,  0.6015625 , ..., -2.0625    ,
        -0.04467773,  0.70703125],
       [ 0.09667969, -0.78515625,  1.46875   , ...,  3.46875   ,
        -1.6640625 , -2.953125  ],
       ...,
       [ 0.10742188, -1.84375   ,  0.953125  , ..., -0.2265625 ,
         0.3515625 , -0.09130859],
       [ 0.890625  , -1.0703125 ,  2.125     , ..., -2.390625  ,
         0.60546875, -0.7734375 ],
       [-0.42382812, -0.88671875, -1.421875  , ...,  1.0234375 ,
        -0.11474609, -0.765625  ]], dtype=float32)


k_mean_space
(20, 2)


array([[92.28739 , 62.607265],
       [89.57332 , 59.246838],
       [61.676414, 86.868256],
       [95.07906 , 76.72927 ],
       [55.779762, 82.592964],
       [83.959564, 80.855835],
       [87.35109 , 54.500183],
       [89.80364 , 76.070915],
       [57.90948 , 80.62302 ],
       [87.98993 , 54.615604],
       [61.40742 , 86.66236 ],
       [91.26379 , 56.642696],
       [71.09598 , 84.464935],
       [55.747036, 84.21915 ],
       [83.953415, 70.47067 ],
       [59.50118 , 83.567924],
       [88.37857 , 70.26779 ],
       [91.38267 , 56.62191 ],
       [84.17071 , 73.33616 ],
       [59.223976, 85.208405]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-156.0107708 , -231.54960155])


closest
(2,)


array([13,  6])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 1.6328, -2.3750, -6.8438,  ...,  0.0000,  0.0000,  0.0000],
        [-1.4062, -4.3750, -4.3438,  ...,  0.0000,  0.0000,  0.0000],
        [-1.1875, -1.9453, -3.7031,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.0156, -3.0469, -3.5781,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0503,  0.1621, -1.0312,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2793, -0.3066, -3.5781,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.5181e-01, 1.4802e-01, 6.3760e-05,  ..., 9.9649e-23, 8.7940e-23,
         7.7607e-23],
        [9.9996e-01, 2.7535e-05, 1.1478e-05,  ..., 1.9538e-27, 7.1875e-28,
         7.1875e-28],
        [1.0000e+00, 1.8554e-07, 2.5110e-08,  ..., 2.6972e-26, 1.8538e-26,
         8.7565e-27],
        ...,
        [9.9997e-01, 3.1202e-05, 2.5612e-06,  ..., 1.8908e-24, 7.8821e-25,
         5.4173e-25],
        [9.9473e-01, 5.2199e-03, 2.1332e-05,  ..., 8.4297e-24, 8.4297e-24,
         1.2927e-24],
        [9.9952e-01, 1.7947e-04, 1.7947e-04,  ..., 3.0156e-21, 1.8291e-21,
         5.2404e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8518, 0.9998, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9947, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9995, 0.9997, 0.9999,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([24])


tensor([ 0,  0,  1,  2,  3,  4,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 13, 14,
        15, 15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([24, 102])


tensor([[    1, 32010,  1724,  ..., 20052,  2699, 29897],
        [    1, 32010,  1724,  ..., 20052,  2699, 29897],
        [    1, 32010,  1724,  ..., 29906, 20052,  2699],
        ...,
        [    1, 32010,  1724,  ..., 29889, 29953,  7800],
        [    1, 32010,  1724,  ...,   313, 29906, 29906],
        [    1, 32010,  1724,  ..., 29953,  7800,   313]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([24])


tensor([-12.3728, -12.3728, -14.4528, -16.4471, -18.5315, -19.3515, -19.3515,
        -21.1015, -19.2649, -18.5178, -19.2669, -20.1063, -17.7012, -19.3327,
        -21.1399, -22.2489, -22.2489, -22.8789, -21.8687, -21.8687, -24.2435,
        -20.4871, -20.2597, -17.9867], device='cuda:0')


new_candidate_toks
torch.Size([24, 1])


tensor([[ 1880],
        [15655],
        [29897],
        [29906],
        [ 2699],
        [29947],
        [29929],
        [29955],
        [  313],
        [29889],
        [29896],
        [  313],
        [29906],
        [  313],
        [29906],
        [12717],
        [29947],
        [29900],
        [29896],
        [12717],
        [29900],
        [  313],
        [20052],
        [29906]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([24])


tensor([-1.6039e-01, -1.9104e+00, -3.9459e-05, -2.3842e-07,  0.0000e+00,
        -3.2662e-01, -1.5766e+00, -1.1066e-02, -1.9239e-03, -1.7996e-04,
        -2.1887e-03, -1.1062e-03, -2.3842e-07, -2.9207e-05, -7.8562e-05,
        -1.9639e-01, -2.1964e+00, -6.6344e-04, -3.0531e-01, -1.5553e+00,
        -8.1063e-06, -3.4691e-05, -5.2825e-03, -4.8363e-04], device='cuda:0')


new_candidates
torch.Size([24, 103])


tensor([[    1, 32010,  1724,  ...,  2699, 29897,  1880],
        [    1, 32010,  1724,  ...,  2699, 29897, 15655],
        [    1, 32010,  1724,  ..., 20052,  2699, 29897],
        ...,
        [    1, 32010,  1724,  ..., 29953,  7800,   313],
        [    1, 32010,  1724,  ..., 29906, 29906, 20052],
        [    1, 32010,  1724,  ...,  7800,   313, 29906]], device='cuda:0')


new_candidate_logprobs
torch.Size([24])


tensor([-12.5332, -14.2832, -14.4528, -16.4471, -18.5315, -19.6781, -20.9281,
        -21.1126, -19.2669, -18.5180, -19.2691, -20.1074, -17.7012, -19.3327,
        -21.1399, -22.4453, -24.4453, -22.8796, -22.1740, -23.4240, -24.2435,
        -20.4872, -20.2650, -17.9872], device='cuda:0')

infer end: GPU memory used: 22041 MB.
event: level
id: 91
data: [{"content": "high", "parent": 0, "prob": -12.533217430114746}, {"content": "tall", "parent": 0, "prob": -14.283217430114746}, {"content": ")", "parent": 1, "prob": -14.452832221984863}, {"content": "2", "parent": 2, "prob": -16.447080612182617}, {"content": "eters", "parent": 3, "prob": -18.5314998626709}, {"content": "8", "parent": 4, "prob": -19.678131103515625}, {"content": "9", "parent": 4, "prob": -20.928131103515625}, {"content": "7", "parent": 5, "prob": -21.112592697143555}, {"content": "(", "parent": 6, "prob": -19.26685333251953}, {"content": ".", "parent": 7, "prob": -18.517986297607422}, {"content": "1", "parent": 8, "prob": -19.269079208374023}, {"content": "(", "parent": 9, "prob": -20.107378005981445}, {"content": "2", "parent": 10, "prob": -17.701156616210938}, {"content": "(", "parent": 11, "prob": -19.332738876342773}, {"content": "2", "parent": 12, "prob": -21.139928817749023}, {"content": "about", "par

array([[-0.75      , -1.5703125 ,  2.71875   , ..., -0.62109375,
         0.36132812, -0.04296875],
       [-0.0201416 , -1.890625  ,  2.140625  , ...,  0.3125    ,
         0.8046875 , -0.30273438],
       [-0.11523438, -0.7109375 ,  0.44140625, ..., -1.9296875 ,
         0.46484375,  1.0390625 ],
       ...,
       [-1.0078125 , -1.3515625 , -0.05273438, ..., -2.25      ,
        -0.97265625,  0.14941406],
       [-0.625     , -0.47265625,  1.3515625 , ...,  3.25      ,
        -0.36914062, -2.90625   ],
       [-1.734375  , -0.20703125,  0.23242188, ..., -0.90625   ,
         0.5703125 ,  1.265625  ]], dtype=float32)


k_mean_space
(20, 2)


array([[80.569016, 64.36727 ],
       [79.80171 , 63.791183],
       [78.30231 , 78.66023 ],
       [57.819942, 85.69732 ],
       [69.47014 , 71.93241 ],
       [62.287926, 88.38257 ],
       [60.610703, 85.49282 ],
       [64.32345 , 87.67045 ],
       [83.46856 , 59.772476],
       [77.03949 , 91.687996],
       [63.79626 , 78.54431 ],
       [82.75657 , 60.696804],
       [57.973343, 85.85482 ],
       [84.57064 , 66.83525 ],
       [57.691784, 86.97792 ],
       [78.501564, 55.38022 ],
       [61.2489  , 88.47769 ],
       [70.14357 , 88.30052 ],
       [66.821335, 81.12349 ],
       [79.95766 , 56.895042]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-257.27727795, -131.39269447])


closest
(2,)


array([14, 15])


last_tok_logits
torch.Size([20, 32064])


tensor([[  0.9219,  -0.0933, -10.3750,  ...,   0.0000,   0.0000,   0.0000],
        [  0.9922,   2.6406,  -9.1875,  ...,   0.0000,   0.0000,   0.0000],
        [  0.3535,  -3.2188,  -6.5938,  ...,   0.0000,   0.0000,   0.0000],
        ...,
        [ -0.1030,  -2.5156,   1.0859,  ...,   0.0000,   0.0000,   0.0000],
        [ -2.2500,  -3.2344,  -2.8281,  ...,   0.0000,   0.0000,   0.0000],
        [  2.2812, -10.1250,  -3.4062,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.1415e-01, 7.5038e-02, 3.7359e-03,  ..., 3.2940e-22, 2.5654e-22,
         1.6400e-23],
        [6.7606e-01, 3.1935e-01, 2.7629e-03,  ..., 1.2129e-23, 3.4749e-24,
         2.1076e-24],
        [9.2117e-01, 6.6729e-02, 1.1596e-02,  ..., 3.7613e-22, 2.9293e-22,
         9.5100e-23],
        ...,
        [9.9833e-01, 4.8728e-04, 4.8728e-04,  ..., 3.1747e-22, 2.8016e-22,
         7.0836e-23],
        [9.9908e-01, 9.1105e-04, 2.2583e-06,  ..., 3.6001e-22, 3.1771e-22,
         1.2442e-22],
        [1.0000e+00, 3.6535e-08, 9.2374e-09,  ..., 8.3079e-26, 6.4702e-26,
         6.4702e-26]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9142, 0.9892, 0.9929,  ..., 1.0000, 1.0000, 1.0000],
        [0.6761, 0.9954, 0.9982,  ..., 1.0000, 1.0000, 1.0000],
        [0.9212, 0.9879, 0.9995,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9983, 0.9988, 0.9993,  ..., 1.0000, 1.0000, 1.0000],
        [0.9991, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([25])


tensor([ 0,  1,  1,  2,  3,  4,  5,  6,  7,  8,  8,  9, 10, 11, 11, 11, 11, 12,
        13, 14, 15, 16, 17, 18, 19], device='cuda:0')


carryover_candidates
torch.Size([25, 103])


tensor([[    1, 32010,  1724,  ...,  2699, 29897,  1880],
        [    1, 32010,  1724,  ...,  2699, 29897, 15655],
        [    1, 32010,  1724,  ...,  2699, 29897, 15655],
        ...,
        [    1, 32010,  1724,  ..., 29900, 29900, 29900],
        [    1, 32010,  1724,  ...,  2699,   313, 29896],
        [    1, 32010,  1724,  ...,  2699,   313, 12717]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([25])


tensor([-12.5332, -14.2832, -14.2832, -14.4528, -16.4471, -18.5315, -19.6781,
        -20.9281, -21.1126, -19.2669, -19.2669, -18.5180, -19.2691, -20.1074,
        -20.1074, -20.1074, -20.1074, -17.7012, -19.3327, -21.1399, -22.4453,
        -24.4453, -22.8796, -22.1740, -23.4240], device='cuda:0')


new_candidate_toks
torch.Size([25, 1])


tensor([[29889],
        [29889],
        [29892],
        [ 1880],
        [20052],
        [29897],
        [29889],
        [ 7800],
        [ 7800],
        [29946],
        [12717],
        [29953],
        [29941],
        [29955],
        [12717],
        [ 9961],
        [11316],
        [20052],
        [29906],
        [20052],
        [29871],
        [29889],
        [27881],
        [29941],
        [29871]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([25])


tensor([-8.9759e-02, -3.9148e-01, -1.1415e+00, -8.2114e-02, -5.9787e-03,
        -5.5011e-02, -7.3910e-06, -1.6114e-02, -9.7009e-03, -5.0880e-01,
        -1.0088e+00, -6.5090e-05, -1.5491e-03, -8.2398e-01, -1.0740e+00,
        -2.1990e+00, -2.8240e+00, -6.7797e-03, -1.1062e-03, -1.4509e-02,
        -2.3842e-07, -6.6998e-05, -1.6682e-03, -9.1702e-04,  0.0000e+00],
       device='cuda:0')


new_candidates
torch.Size([25, 104])


tensor([[    1, 32010,  1724,  ..., 29897,  1880, 29889],
        [    1, 32010,  1724,  ..., 29897, 15655, 29889],
        [    1, 32010,  1724,  ..., 29897, 15655, 29892],
        ...,
        [    1, 32010,  1724,  ..., 29900, 29900, 27881],
        [    1, 32010,  1724,  ...,   313, 29896, 29941],
        [    1, 32010,  1724,  ...,   313, 12717, 29871]], device='cuda:0')


new_candidate_logprobs
torch.Size([25])


tensor([-12.6230, -14.6747, -15.4247, -14.5349, -16.4531, -18.5865, -19.6781,
        -20.9442, -21.1223, -19.7757, -20.2757, -18.5181, -19.2706, -20.9314,
        -21.1814, -22.3064, -22.9314, -17.7079, -19.3338, -21.1544, -22.4453,
        -24.4454, -22.8812, -22.1749, -23.4240], device='cuda:0')

infer end: GPU memory used: 22293 MB.
event: level
id: 92
data: [{"content": ".", "parent": 0, "prob": -12.622976303100586}, {"content": ".", "parent": 1, "prob": -14.674695014953613}, {"content": ",", "parent": 1, "prob": -15.424695014953613}, {"content": "high", "parent": 2, "prob": -14.53494644165039}, {"content": "kilom", "parent": 3, "prob": -16.453060150146484}, {"content": ")", "parent": 4, "prob": -18.586509704589844}, {"content": ".", "parent": 5, "prob": -19.678138732910156}, {"content": "miles", "parent": 6, "prob": -20.944244384765625}, {"content": "miles", "parent": 7, "prob": -21.12229347229004}, {"content": "4", "parent": 8, "prob": -19.775653839111328}, {"content": "about", "parent": 8, "prob": -20.275653839111328}, {"content": "6", "parent": 9, "prob": -18.518051147460938}, {"content": "3", "parent": 10, "prob": -19.270627975463867}, {"content": "7", "parent": 11, "prob": -20.93135643005371}, {"content": "about", "parent": 11, "prob": -21.18135643005371}, {"content": "

array([[-0.05126953,  0.51953125,  1.59375   , ...,  0.67578125,
        -2.09375   , -0.59765625],
       [-0.08154297,  0.51953125,  1.6015625 , ...,  0.671875  ,
        -2.15625   , -0.765625  ],
       [-1.59375   , -1.2578125 ,  0.73046875, ..., -2.1875    ,
        -0.01330566,  1.375     ],
       ...,
       [ 0.390625  , -1.609375  ,  1.7890625 , ..., -2.734375  ,
         1.1640625 ,  0.06201172],
       [ 0.07617188, -0.75      ,  1.34375   , ...,  3.375     ,
        -1.765625  , -3.015625  ],
       [ 0.19824219, -1.1328125 ,  1.28125   , ..., -2.5       ,
         1.3359375 , -0.5546875 ]], dtype=float32)


k_mean_space
(20, 2)


array([[65.38093 , 83.292496],
       [64.92045 , 82.789085],
       [66.908936, 87.28085 ],
       [67.70834 , 89.557365],
       [67.971695, 91.26024 ],
       [66.742615, 91.25531 ],
       [87.48237 , 65.88101 ],
       [59.77501 , 83.32359 ],
       [58.71761 , 82.00895 ],
       [85.81099 , 58.770805],
       [63.563   , 86.94976 ],
       [81.46994 , 65.53704 ],
       [84.22451 , 64.49955 ],
       [85.185425, 59.32375 ],
       [64.65942 , 88.08349 ],
       [78.429146, 90.04218 ],
       [62.679523, 84.316414],
       [68.06595 , 91.252655],
       [88.86556 , 68.0943  ],
       [68.51476 , 90.26077 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-259.92051506, -117.50767326])


closest
(2,)


array([8, 9])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 3.5781, -6.5000,  5.6250,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.3438, -6.3125,  5.5312,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.3125,  1.9062, -3.4531,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.6484, -4.9062, -2.7031,  ...,  0.0000,  0.0000,  0.0000],
        [-1.1641, -1.6953, -3.5469,  ...,  0.0000,  0.0000,  0.0000],
        [-2.9062, -5.4062, -2.4219,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[2.5452e-01, 1.9822e-01, 1.7493e-01,  ..., 2.5430e-20, 1.2012e-20,
         5.0074e-21],
        [2.4654e-01, 2.4654e-01, 1.6944e-01,  ..., 6.2901e-20, 2.4632e-20,
         1.0268e-20],
        [7.1672e-01, 7.5542e-02, 4.5819e-02,  ..., 1.3246e-20, 1.2444e-20,
         8.5525e-21],
        ...,
        [1.0000e+00, 5.3158e-08, 4.3635e-09,  ..., 4.7809e-25, 4.2191e-25,
         1.3697e-25],
        [1.0000e+00, 1.8554e-07, 5.3158e-08,  ..., 3.2534e-26, 2.1006e-26,
         9.9224e-27],
        [1.0000e+00, 2.2159e-08, 1.7258e-08,  ..., 3.5326e-24, 1.8909e-24,
         1.2996e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.2545, 0.4527, 0.6277,  ..., 1.0000, 1.0000, 1.0000],
        [0.2465, 0.4931, 0.6625,  ..., 1.0000, 1.0000, 1.0000],
        [0.7167, 0.7923, 0.8381,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([35])


tensor([ 0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  3,
         4,  5,  6,  7,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
       device='cuda:0')


carryover_candidates
torch.Size([35, 104])


tensor([[    1, 32010,  1724,  ..., 29897,  1880, 29889],
        [    1, 32010,  1724,  ..., 29897,  1880, 29889],
        [    1, 32010,  1724,  ..., 29897,  1880, 29889],
        ...,
        [    1, 32010,  1724,  ..., 29906, 29906, 20052],
        [    1, 32010,  1724,  ...,  7800,   313, 29906],
        [    1, 32010,  1724,  ..., 29906, 29906, 20052]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([35])


tensor([-12.6230, -12.6230, -12.6230, -12.6230, -12.6230, -12.6230, -14.6747,
        -14.6747, -14.6747, -14.6747, -14.6747, -14.6747, -15.4247, -15.4247,
        -15.4247, -15.4247, -15.4247, -14.5349, -16.4531, -18.5865, -19.6781,
        -20.9442, -20.9442, -21.1223, -19.7757, -20.2757, -18.5181, -19.2706,
        -20.9314, -21.1814, -22.3064, -22.9314, -17.7079, -19.3338, -21.1544],
       device='cuda:0')


new_candidate_toks
torch.Size([35, 1])


tensor([[32007],
        [ 1205],
        [19152],
        [   13],
        [  739],
        [ 3529],
        [ 1205],
        [32007],
        [19152],
        [  739],
        [   13],
        [ 3529],
        [ 3907],
        [16951],
        [ 2215],
        [ 5998],
        [  270],
        [29889],
        [ 2699],
        [ 2038],
        [29955],
        [29897],
        [  470],
        [29897],
        [29953],
        [29871],
        [ 7800],
        [29889],
        [29906],
        [29871],
        [ 2657],
        [29871],
        [ 2699],
        [29906],
        [ 2699]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([35])


tensor([-1.3684e+00, -1.6184e+00, -1.7434e+00, -2.1184e+00, -2.2434e+00,
        -2.4934e+00, -1.4002e+00, -1.4002e+00, -1.7752e+00, -2.1502e+00,
        -2.5252e+00, -2.6502e+00, -3.3307e-01, -2.5831e+00, -3.0831e+00,
        -3.3331e+00, -3.5831e+00, -3.3443e-02,  0.0000e+00, -9.5605e-02,
        -1.8178e-02, -2.2079e-01, -1.9708e+00, -4.5956e-02, -1.1016e-04,
        -2.2769e-05, -4.1941e-04, -1.1802e-05, -4.5300e-06, -3.0041e-05,
        -3.3379e-06, -1.3113e-06,  0.0000e+00, -2.3842e-07,  0.0000e+00],
       device='cuda:0')


new_candidates
torch.Size([35, 105])


tensor([[    1, 32010,  1724,  ...,  1880, 29889, 32007],
        [    1, 32010,  1724,  ...,  1880, 29889,  1205],
        [    1, 32010,  1724,  ...,  1880, 29889, 19152],
        ...,
        [    1, 32010,  1724,  ..., 29906, 20052,  2699],
        [    1, 32010,  1724,  ...,   313, 29906, 29906],
        [    1, 32010,  1724,  ..., 29906, 20052,  2699]], device='cuda:0')


new_candidate_logprobs
torch.Size([35])


tensor([-13.9914, -14.2414, -14.3664, -14.7414, -14.8664, -15.1164, -16.0749,
        -16.0749, -16.4499, -16.8249, -17.1999, -17.3249, -15.7578, -18.0078,
        -18.5078, -18.7578, -19.0078, -14.5684, -16.4531, -18.6821, -19.6963,
        -21.1650, -22.9150, -21.1682, -19.7758, -20.2757, -18.5185, -19.2706,
        -20.9314, -21.1814, -22.3064, -22.9314, -17.7079, -19.3338, -21.1544],
       device='cuda:0')

infer end: GPU memory used: 22549 MB.
event: level
id: 93
data: [{"content": "<|end|>", "parent": 0, "prob": -13.991351127624512}, {"content": "But", "parent": 0, "prob": -14.241351127624512}, {"content": "Keep", "parent": 0, "prob": -14.366351127624512}, {"content": "\n", "parent": 0, "prob": -14.741351127624512}, {"content": "It", "parent": 0, "prob": -14.866351127624512}, {"content": "Please", "parent": 0, "prob": -15.116351127624512}, {"content": "But", "parent": 1, "prob": -16.074928283691406}, {"content": "<|end|>", "parent": 1, "prob": -16.074928283691406}, {"content": "Keep", "parent": 1, "prob": -16.449928283691406}, {"content": "It", "parent": 1, "prob": -16.824928283691406}, {"content": "\n", "parent": 1, "prob": -17.199928283691406}, {"content": "Please", "parent": 1, "prob": -17.324928283691406}, {"content": "making", "parent": 2, "prob": -15.757761001586914}, {"content": "significantly", "parent": 2, "prob": -18.007761001586914}, {"content": "far", "parent": 2, "prob": -1

array([[-1.4062500e-01, -8.9843750e-02, -1.1328125e+00, ...,
        -1.7812500e+00, -1.0390625e+00,  7.7734375e-01],
       [-6.7968750e-01,  2.8710938e-01,  1.2421875e+00, ...,
         5.4296875e-01, -8.2778931e-04, -1.0625000e+00],
       [ 8.9453125e-01,  2.1582031e-01,  1.8920898e-02, ...,
        -2.7734375e-01, -2.0703125e-01, -3.6523438e-01],
       ...,
       [-2.0410156e-01,  6.0156250e-01,  1.6250000e+00, ...,
         7.1093750e-01, -2.0468750e+00, -6.6796875e-01],
       [ 9.2163086e-03, -1.1953125e+00,  1.6171875e+00, ...,
        -1.8750000e+00,  4.6142578e-02, -3.8476562e-01],
       [ 5.7812500e-01, -6.7578125e-01,  1.1640625e+00, ...,
         4.8046875e-01, -2.5976562e-01, -8.3203125e-01]], dtype=float32)


k_mean_space
(20, 2)


array([[ 50.992676,  93.04785 ],
       [ 97.35348 ,  52.896076],
       [102.550674,  67.038795],
       [ 51.016575,  90.608574],
       [100.56807 ,  59.765095],
       [100.49418 ,  71.538704],
       [ 97.057915,  52.62695 ],
       [ 50.99866 ,  93.008354],
       [102.54043 ,  66.89548 ],
       [100.455826,  59.734238],
       [ 51.03821 ,  90.454666],
       [100.453766,  71.384834],
       [ 97.614655,  59.47669 ],
       [ 91.691   ,  50.497715],
       [ 94.41499 ,  56.614822],
       [102.24989 ,  57.527977],
       [106.35573 ,  78.73318 ],
       [ 92.02298 ,  62.775494],
       [101.30459 ,  69.84164 ],
       [106.861244,  80.0892  ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -62.00755882, -265.00748825])


closest
(2,)


array([ 0, 13])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.6523, -2.5625,  3.9844,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.9141, -1.0156, -1.2109,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.1875,  3.0469, -0.5117,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 3.5781, -6.5938,  5.6875,  ...,  0.0000,  0.0000,  0.0000],
        [-1.0625, -4.0312, -5.2188,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.3438, -4.5938, -6.8125,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[8.9307e-01, 1.0666e-01, 1.1021e-04,  ..., 1.3684e-20, 1.2076e-20,
         5.0339e-21],
        [4.1095e-01, 2.1997e-01, 9.1696e-02,  ..., 1.5407e-18, 5.6680e-19,
         5.3246e-19],
        [9.9999e-01, 6.1442e-06, 1.3709e-06,  ..., 4.3465e-22, 3.6033e-22,
         4.5811e-23],
        ...,
        [3.0500e-01, 1.8499e-01, 1.6326e-01,  ..., 2.3733e-20, 6.7995e-21,
         4.1241e-21],
        [7.7695e-01, 1.7336e-01, 4.9669e-02,  ..., 8.2882e-26, 5.0270e-26,
         5.0270e-26],
        [8.7801e-01, 7.2071e-02, 4.9534e-02,  ..., 6.3547e-21, 5.6080e-21,
         1.8206e-21]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.8931, 0.9997, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        [0.4110, 0.6309, 0.7226,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.3050, 0.4900, 0.6533,  ..., 1.0000, 1.0000, 1.0000],
        [0.7769, 0.9503, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8780, 0.9501, 0.9996,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([49])


tensor([ 0,  0,  1,  1,  1,  1,  1,  1,  1,  2,  3,  4,  5,  5,  6,  6,  6,  6,
         6,  6,  6,  7,  7,  8,  9, 10, 11, 11, 12, 13, 13, 13, 14, 14, 14, 14,
        15, 15, 16, 17, 17, 17, 17, 17, 17, 18, 18, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([49, 105])


tensor([[    1, 32010,  1724,  ...,  1880, 29889, 32007],
        [    1, 32010,  1724,  ...,  1880, 29889, 32007],
        [    1, 32010,  1724,  ...,  1880, 29889,  1205],
        ...,
        [    1, 32010,  1724,  ..., 29906, 20052,  2699],
        [    1, 32010,  1724,  ...,  2699, 29897,  2038],
        [    1, 32010,  1724,  ...,  2699, 29897,  2038]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([49])


tensor([-13.9914, -13.9914, -14.2414, -14.2414, -14.2414, -14.2414, -14.2414,
        -14.2414, -14.2414, -14.3664, -14.7414, -14.8664, -15.1164, -15.1164,
        -16.0749, -16.0749, -16.0749, -16.0749, -16.0749, -16.0749, -16.0749,
        -16.0749, -16.0749, -16.4499, -16.8249, -17.1999, -17.3249, -17.3249,
        -15.7578, -18.0078, -18.0078, -18.0078, -18.5078, -18.5078, -18.5078,
        -18.5078, -18.7578, -18.7578, -19.0078, -14.5684, -14.5684, -14.5684,
        -14.5684, -14.5684, -14.5684, -16.4531, -16.4531, -18.6821, -18.6821],
       device='cuda:0')


new_candidate_toks
torch.Size([49, 1])


tensor([[32001],
        [32000],
        [  363],
        [  297],
        [  746],
        [ 1951],
        [  373],
        [ 4249],
        [ 6456],
        [  297],
        [   13],
        [29915],
        [ 4443],
        [ 1235],
        [  363],
        [  297],
        [  746],
        [ 1951],
        [  373],
        [ 4249],
        [ 6456],
        [32001],
        [32000],
        [  297],
        [29915],
        [   13],
        [ 4443],
        [ 1235],
        [  372],
        [  260],
        [ 6133],
        [ 7200],
        [13461],
        [ 1190],
        [ 2038],
        [  260],
        [  372],
        [  967],
        [ 4495],
        [32007],
        [19152],
        [ 1205],
        [   13],
        [  739],
        [ 3529],
        [29897],
        [  467],
        [  278],
        [16852]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([49])


tensor([-1.1309e-01, -2.2381e+00, -8.8928e-01, -1.5143e+00, -2.3893e+00,
        -2.6393e+00, -3.0143e+00, -3.1393e+00, -3.3893e+00, -9.7752e-06,
        -9.9112e-04, -2.8586e-02, -2.5400e-01, -1.5040e+00, -9.6689e-01,
        -1.4669e+00, -2.4669e+00, -2.5919e+00, -2.8419e+00, -2.9669e+00,
        -3.4669e+00, -1.1323e-01, -2.2382e+00, -1.1921e-05, -3.2272e-02,
        -5.8143e-04, -2.5364e-01, -1.5036e+00, -4.7684e-07, -6.2084e-01,
        -1.1208e+00, -3.1208e+00, -3.1991e-01, -2.4449e+00, -2.4449e+00,
        -2.8199e+00, -7.9391e-01, -7.9391e-01, -7.3910e-06, -1.1874e+00,
        -1.6874e+00, -1.8124e+00, -2.1874e+00, -2.4374e+00, -2.4374e+00,
        -2.5238e-01, -1.7524e+00, -1.3010e-01, -2.6301e+00], device='cuda:0')


new_candidates
torch.Size([49, 106])


tensor([[    1, 32010,  1724,  ..., 29889, 32007, 32001],
        [    1, 32010,  1724,  ..., 29889, 32007, 32000],
        [    1, 32010,  1724,  ..., 29889,  1205,   363],
        ...,
        [    1, 32010,  1724,  ..., 20052,  2699,   467],
        [    1, 32010,  1724,  ..., 29897,  2038,   278],
        [    1, 32010,  1724,  ..., 29897,  2038, 16852]], device='cuda:0')


new_candidate_logprobs
torch.Size([49])


tensor([-14.1044, -16.2294, -15.1306, -15.7556, -16.6306, -16.8806, -17.2556,
        -17.3806, -17.6306, -14.3664, -14.7423, -14.8949, -15.3703, -16.6203,
        -17.0418, -17.5418, -18.5418, -18.6668, -18.9168, -19.0418, -19.5418,
        -16.1882, -18.3132, -16.4499, -16.8572, -17.2005, -17.5786, -18.8286,
        -15.7578, -18.6286, -19.1286, -21.1286, -18.8277, -20.9527, -20.9527,
        -21.3277, -19.5517, -19.5517, -19.0078, -15.7558, -16.2558, -16.3808,
        -16.7558, -17.0058, -17.0058, -16.7054, -18.2054, -18.8122, -21.3122],
       device='cuda:0')

infer end: GPU memory used: 22807 MB.
event: level
id: 94
data: [{"content": "<|assistant|>", "parent": 0, "prob": -14.104438781738281}, {"content": "<|endoftext|>", "parent": 0, "prob": -16.22943878173828}, {"content": "for", "parent": 1, "prob": -15.130627632141113}, {"content": "in", "parent": 1, "prob": -15.755627632141113}, {"content": "when", "parent": 1, "prob": -16.630626678466797}, {"content": "since", "parent": 1, "prob": -16.880626678466797}, {"content": "on", "parent": 1, "prob": -17.255626678466797}, {"content": "among", "parent": 1, "prob": -17.380626678466797}, {"content": "remember", "parent": 1, "prob": -17.630626678466797}, {"content": "in", "parent": 2, "prob": -14.366360664367676}, {"content": "\n", "parent": 3, "prob": -14.742341995239258}, {"content": "'", "parent": 4, "prob": -14.894937515258789}, {"content": "note", "parent": 5, "prob": -15.37034797668457}, {"content": "let", "parent": 5, "prob": -16.62034797668457}, {"content": "for", "parent": 6, "prob": -17.0

array([[-2.375     ,  0.3125    , -0.8671875 , ...,  0.82421875,
        -3.        ,  1.875     ],
       [-0.51171875,  0.39648438,  0.5078125 , ..., -1.234375  ,
         1.4765625 , -2.        ],
       [ 0.11035156, -0.30078125,  1.6015625 , ..., -0.96484375,
        -1.078125  ,  0.10888672],
       ...,
       [-1.3359375 ,  0.484375  ,  0.9140625 , ..., -0.26953125,
         0.10498047, -0.4921875 ],
       [-0.62890625,  1.2734375 ,  2.5625    , ..., -0.8359375 ,
        -0.50390625, -1.03125   ],
       [-1.09375   ,  0.05688477,  3.5625    , ..., -1.375     ,
        -1.3203125 , -0.64453125]], dtype=float32)


k_mean_space
(20, 2)


array([[57.000744, 79.73839 ],
       [55.864563, 74.584   ],
       [74.81485 , 38.196358],
       [71.969246, 38.839   ],
       [71.67426 , 48.10482 ],
       [67.406715, 35.52003 ],
       [72.69454 , 45.528103],
       [70.48942 , 41.6053  ],
       [72.95837 , 54.31365 ],
       [83.296684, 70.52279 ],
       [51.482998, 75.24474 ],
       [66.61293 , 60.07294 ],
       [73.93167 , 60.457836],
       [64.79862 , 87.00657 ],
       [74.85761 , 38.02087 ],
       [72.039696, 38.827007],
       [71.615685, 48.04632 ],
       [67.61828 , 35.83678 ],
       [72.817566, 45.781273],
       [70.69934 , 41.953396]], dtype=float32)


k_mean_clusters
(20,)


array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -61.69656754, -271.04695797])


closest
(2,)


array([10,  5])


last_tok_logits
torch.Size([20, 32064])


tensor([[ 0.3125,  1.7656, -1.3828,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3828, -4.9688,  1.3203,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.4688,  6.4062,  2.3125,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 6.4688,  1.3438, -0.6328,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.2812,  6.7500, -1.1797,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.6875,  4.5938, -0.9570,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[2.2336e-01, 2.2336e-01, 8.2171e-02,  ..., 2.2539e-19, 2.2539e-19,
         6.8740e-20],
        [9.3048e-01, 5.2494e-02, 2.1667e-03,  ..., 5.2811e-14, 3.0091e-14,
         1.9428e-14],
        [9.4474e-01, 4.1509e-02, 4.9575e-03,  ..., 2.2420e-20, 2.1061e-20,
         8.7797e-21],
        ...,
        [5.3674e-01, 2.2375e-01, 1.0569e-01,  ..., 5.7654e-19, 3.7224e-19,
         9.4117e-20],
        [9.9689e-01, 1.1672e-03, 1.0301e-03,  ..., 8.7030e-21, 6.3673e-21,
         3.8619e-21],
        [5.4543e-01, 3.3082e-01, 3.9511e-02,  ..., 2.4666e-18, 1.4961e-18,
         1.3203e-18]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.2234, 0.4467, 0.5289,  ..., 1.0000, 1.0000, 1.0000],
        [0.9305, 0.9830, 0.9851,  ..., 1.0000, 1.0000, 1.0000],
        [0.9447, 0.9862, 0.9912,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.5367, 0.7605, 0.8662,  ..., 1.0000, 1.0000, 1.0000],
        [0.9969, 0.9981, 0.9991,  ..., 1.0000, 1.0000, 1.0000],
        [0.5454, 0.8762, 0.9158,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([67])


tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         1,  2,  3,  3,  3,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  6,
         7,  7,  7,  8,  9, 10, 10, 10, 10, 11, 12, 13, 14, 15, 15, 15, 16, 16,
        16, 16, 16, 16, 17, 17, 17, 17, 17, 18, 19, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([67, 106])


tensor([[    1, 32010,  1724,  ..., 29889, 32007, 32001],
        [    1, 32010,  1724,  ..., 29889, 32007, 32001],
        [    1, 32010,  1724,  ..., 29889, 32007, 32001],
        ...,
        [    1, 32010,  1724,  ..., 29889,  1205,  4249],
        [    1, 32010,  1724,  ..., 29889,  1205,  4249],
        [    1, 32010,  1724,  ..., 29889,  1205,  4249]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([67])


tensor([-14.1044, -14.1044, -14.1044, -14.1044, -14.1044, -14.1044, -14.1044,
        -14.1044, -14.1044, -14.1044, -14.1044, -14.1044, -14.1044, -14.1044,
        -14.1044, -14.1044, -14.1044, -14.1044, -16.2294, -15.1306, -15.7556,
        -15.7556, -15.7556, -16.6306, -16.6306, -16.6306, -16.6306, -16.6306,
        -16.6306, -16.8806, -16.8806, -16.8806, -16.8806, -16.8806, -16.8806,
        -17.2556, -17.3806, -17.3806, -17.3806, -17.6306, -14.3664, -14.7423,
        -14.7423, -14.7423, -14.7423, -14.8949, -15.3703, -16.6203, -17.0418,
        -17.5418, -17.5418, -17.5418, -18.5418, -18.5418, -18.5418, -18.5418,
        -18.5418, -18.5418, -18.6668, -18.6668, -18.6668, -18.6668, -18.6668,
        -18.9168, -19.0418, -19.0418, -19.0418], device='cuda:0')


new_candidate_toks
torch.Size([67, 1])


tensor([[  306],
        [ 2398],
        [  739],
        [ 3529],
        [ 4587],
        [  960],
        [ 5806],
        [ 1094],
        [  376],
        [  512],
        [ 3940],
        [ 1619],
        [  315],
        [ 1976],
        [ 7280],
        [ 8040],
        [19814],
        [  319],
        [32000],
        [11563],
        [ 4958],
        [  278],
        [ 2498],
        [16811],
        [ 5353],
        [  372],
        [ 2305],
        [29371],
        [13858],
        [ 8040],
        [11563],
        [  278],
        [ 1556],
        [  591],
        [18274],
        [11563],
        [11563],
        [  599],
        [  278],
        [29892],
        [ 3458],
        [ 9598],
        [12148],
        [ 3644],
        [ 3112],
        [29879],
        [  393],
        [  592],
        [11563],
        [ 4958],
        [  278],
        [ 2498],
        [16811],
        [ 5353],
        [  372],
        [ 2305],
        [29371],
        [13858],
        [ 8040


new_candidate_tok_logprobs
torch.Size([67])


tensor([-1.4990e+00, -1.4990e+00, -2.4990e+00, -2.7490e+00, -2.7490e+00,
        -2.8740e+00, -3.3740e+00, -3.6240e+00, -3.7490e+00, -3.8740e+00,
        -3.9990e+00, -4.4990e+00, -4.6240e+00, -4.6240e+00, -4.6240e+00,
        -4.7490e+00, -4.8740e+00, -4.9990e+00, -7.2057e-02, -5.6849e-02,
        -3.9939e-01, -1.6494e+00, -3.0244e+00, -3.6787e-01, -2.3679e+00,
        -3.1179e+00, -3.3679e+00, -3.3679e+00, -3.4929e+00, -7.8726e-01,
        -1.2873e+00, -2.1623e+00, -3.5373e+00, -3.7873e+00, -4.0373e+00,
        -3.4796e-03, -6.4710e-01, -1.0221e+00, -3.2721e+00, -8.1755e-03,
        -3.9339e-06, -7.9788e-01, -1.0479e+00, -2.4229e+00, -2.7979e+00,
        -1.1921e-07, -8.7835e-03, -6.6163e-05, -5.7879e-02, -4.0471e-01,
        -1.6547e+00, -3.0297e+00, -4.2827e-01, -2.1783e+00, -2.9283e+00,
        -3.3033e+00, -3.3033e+00, -3.4283e+00, -6.2225e-01, -1.4972e+00,
        -2.2472e+00, -3.6222e+00, -3.9972e+00, -3.1178e-03, -6.0619e-01,
        -1.1062e+00, -3.2312e+00], device='cuda:0')


new_candidates
torch.Size([67, 107])


tensor([[    1, 32010,  1724,  ..., 32007, 32001,   306],
        [    1, 32010,  1724,  ..., 32007, 32001,  2398],
        [    1, 32010,  1724,  ..., 32007, 32001,   739],
        ...,
        [    1, 32010,  1724,  ...,  1205,  4249, 11563],
        [    1, 32010,  1724,  ...,  1205,  4249,   599],
        [    1, 32010,  1724,  ...,  1205,  4249,   278]], device='cuda:0')


new_candidate_logprobs
torch.Size([67])


tensor([-15.6034, -15.6034, -16.6034, -16.8534, -16.8534, -16.9784, -17.4784,
        -17.7284, -17.8534, -17.9784, -18.1034, -18.6034, -18.7284, -18.7284,
        -18.7284, -18.8534, -18.9784, -19.1034, -16.3015, -15.1875, -16.1550,
        -17.4050, -18.7800, -16.9985, -18.9985, -19.7485, -19.9985, -19.9985,
        -20.1235, -17.6679, -18.1679, -19.0429, -20.4179, -20.6679, -20.9179,
        -17.2591, -18.0277, -18.4027, -20.6527, -17.6388, -14.3664, -15.5402,
        -15.7902, -17.1652, -17.5402, -14.8949, -15.3791, -16.6204, -17.0997,
        -17.9465, -19.1965, -20.5715, -18.9701, -20.7201, -21.4701, -21.8451,
        -21.8451, -21.9701, -19.2891, -20.1641, -20.9141, -22.2891, -22.6641,
        -18.9199, -19.6480, -20.1480, -22.2730], device='cuda:0')

infer end: GPU memory used: 16791 MB.
event: level
id: 95
data: [{"content": "I", "parent": 0, "prob": -15.603389739990234}, {"content": "However", "parent": 0, "prob": -15.603389739990234}, {"content": "It", "parent": 0, "prob": -16.603389739990234}, {"content": "Please", "parent": 0, "prob": -16.853389739990234}, {"content": "Of", "parent": 0, "prob": -16.853389739990234}, {"content": "If", "parent": 0, "prob": -16.978389739990234}, {"content": "While", "parent": 0, "prob": -17.478389739990234}, {"content": "As", "parent": 0, "prob": -17.728389739990234}, {"content": "\"", "parent": 0, "prob": -17.853389739990234}, {"content": "In", "parent": 0, "prob": -17.978389739990234}, {"content": "Note", "parent": 0, "prob": -18.103389739990234}, {"content": "My", "parent": 0, "prob": -18.603389739990234}, {"content": "C", "parent": 0, "prob": -18.728389739990234}, {"content": "Ab", "parent": 0, "prob": -18.728389739990234}, {"content": "Another", "parent": 0, "prob": -18.728389739990234}, {"c

array([[ 0.03540039,  1.9609375 , -2.        , ...,  1.1328125 ,
        -1.984375  , -1.2109375 ],
       [-0.71875   , -1.5703125 ,  0.99609375, ...,  0.5546875 ,
        -0.44140625, -0.4921875 ],
       [-1.078125  ,  0.34765625, -0.18847656, ..., -0.16210938,
        -1.09375   , -2.109375  ],
       ...,
       [-1.03125   ,  1.25      ,  1.25      , ...,  1.6640625 ,
        -1.0078125 , -0.703125  ],
       [-0.6953125 ,  0.73046875,  0.734375  , ..., -1.40625   ,
         2.140625  , -1.484375  ],
       [ 1.5625    ,  1.6171875 ,  0.46875   , ...,  1.046875  ,
         0.984375  , -0.34765625]], dtype=float32)


k_mean_space
(20, 2)


array([[86.15996 , 75.889565],
       [51.605556, 50.397877],
       [66.348495, 54.71866 ],
       [54.78463 , 78.03481 ],
       [71.78785 , 53.21588 ],
       [69.189705, 53.277664],
       [58.090504, 39.425713],
       [76.7526  , 59.874664],
       [66.546486, 50.33014 ],
       [70.00411 , 50.222866],
       [46.386955, 61.364777],
       [83.63052 , 69.4095  ],
       [88.919624, 76.3287  ],
       [92.561104, 81.70763 ],
       [65.24641 , 47.234142],
       [93.4617  , 79.96726 ],
       [42.3313  , 52.895184],
       [76.63705 , 56.815197],
       [69.28098 , 57.4657  ],
       [74.17829 , 60.66753 ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -53.93516922, -296.91481686])


closest
(2,)


array([16,  6])


last_tok_logits
torch.Size([20, 32064])


tensor([[-2.0312, -5.1875,  0.2617,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.4531,  3.7656, -0.4180,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.8281, -2.2812,  2.2344,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-2.2500, -0.6367, -1.6641,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.4219, -1.7891,  2.4375,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.8594,  3.0781, -3.2344,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[4.6927e-01, 1.7263e-01, 1.5235e-01,  ..., 8.3935e-18, 2.5599e-18,
         1.7594e-18],
        [9.9999e-01, 1.9947e-06, 5.7150e-07,  ..., 8.9050e-19, 7.8586e-19,
         7.8586e-19],
        [9.6090e-01, 2.9017e-02, 5.7137e-03,  ..., 1.2450e-18, 5.8810e-19,
         4.3026e-19],
        ...,
        [4.4201e-01, 1.4350e-01, 1.2664e-01,  ..., 7.9060e-18, 1.9989e-18,
         9.4424e-19],
        [7.0717e-01, 2.2732e-02, 1.6631e-02,  ..., 5.2569e-12, 3.6130e-12,
         2.9953e-12],
        [8.0047e-01, 9.5603e-02, 7.4456e-02,  ..., 4.8508e-20, 4.5569e-20,
         1.5748e-20]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.4693, 0.6419, 0.7943,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9609, 0.9899, 0.9956,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.4420, 0.5855, 0.7122,  ..., 1.0000, 1.0000, 1.0000],
        [0.7072, 0.7299, 0.7465,  ..., 1.0000, 1.0000, 1.0000],
        [0.8005, 0.8961, 0.9705,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([233])


tensor([ 0,  0,  0,  0,  1,  2,  3,  4,  5,  5,  6,  7,  7,  7,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 11, 11, 11,
        12, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15,
        16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 


carryover_candidates
torch.Size([233, 107])


tensor([[    1, 32010,  1724,  ..., 32007, 32001,   306],
        [    1, 32010,  1724,  ..., 32007, 32001,   306],
        [    1, 32010,  1724,  ..., 32007, 32001,   306],
        ...,
        [    1, 32010,  1724,  ...,  1205,   363, 11563],
        [    1, 32010,  1724,  ...,  1205,   363, 11563],
        [    1, 32010,  1724,  ...,  1205,   363, 11563]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([233])


tensor([-15.6034, -15.6034, -15.6034, -15.6034, -15.6034, -16.6034, -16.8534,
        -16.8534, -16.9784, -16.9784, -17.4784, -17.7284, -17.7284, -17.7284,
        -17.8534, -17.8534, -17.8534, -17.8534, -17.8534, -17.8534, -17.8534,
        -17.8534, -17.8534, -17.8534, -17.8534, -17.8534, -17.8534, -17.8534,
        -17.8534, -17.8534, -17.8534, -17.8534, -17.8534, -17.8534, -17.8534,
        -17.8534, -17.8534, -17.8534, -17.9784, -17.9784, -17.9784, -17.9784,
        -17.9784, -17.9784, -17.9784, -17.9784, -17.9784, -17.9784, -17.9784,
        -18.1034, -18.1034, -18.6034, -18.6034, -18.6034, -18.7284, -18.7284,
        -18.7284, -18.7284, -18.7284, -18.7284, -18.7284, -18.7284, -18.7284,
        -18.7284, -18.7284, -18.7284, -18.7284, -18.7284, -18.7284, -18.7284,
        -18.7284, -18.8534, -18.9784, -19.1034, -19.1034, -19.1034, -19.1034,
        -19.1034, -19.1034, -19.1034, -19.1034, -19.1034, -19.1034, -19.1034,
        -16.3015, -16.3015, -16.3015, -16.3015, -16.3015, -16.30


new_candidate_toks
torch.Size([233, 1])


tensor([[29915],
        [  508],
        [ 2274],
        [27746],
        [29892],
        [29915],
        [ 4443],
        [ 3236],
        [  366],
        [  596],
        [ 8040],
        [  310],
        [  385],
        [  263],
        [27648],
        [ 1576],
        [ 5618],
        [12148],
        [29902],
        [ 3644],
        [16382],
        [29923],
        [29924],
        [ 2887],
        [29968],
        [ 6246],
        [13696],
        [  797],
        [ 3112],
        [ 5328],
        [ 7900],
        [17245],
        [29903],
        [ 1762],
        [10401],
        [15666],
        [29950],
        [29940],
        [  278],
        [ 6124],
        [  263],
        [ 1737],
        [ 4958],
        [ 2498],
        [  509],
        [13858],
        [  393],
        [ 1206],
        [ 1749],
        [  393],
        [29901],
        [ 3095],
        [ 6437],
        [ 7306],
        [13946],
        [ 2929],
        [ 8031],
        [ 1298],
        [11158


new_candidate_tok_logprobs
torch.Size([233])


tensor([-7.5658e-01, -1.7566e+00, -1.8816e+00, -2.0066e+00, -5.0068e-06,
        -3.9889e-02, -2.8915e-03, -3.2129e-03, -1.4248e-01, -2.8925e+00,
        -8.6677e-02, -3.1577e-01, -1.8158e+00, -3.5658e+00, -8.0935e-01,
        -1.6844e+00, -3.0594e+00, -3.1844e+00, -3.1844e+00, -3.8094e+00,
        -4.1844e+00, -4.1844e+00, -4.1844e+00, -4.5594e+00, -5.0594e+00,
        -5.1844e+00, -5.1844e+00, -5.3094e+00, -5.3094e+00, -5.4344e+00,
        -5.4344e+00, -5.4344e+00, -5.4344e+00, -5.5594e+00, -5.5594e+00,
        -5.5594e+00, -5.5594e+00, -5.6844e+00, -7.8371e-01, -1.7837e+00,
        -1.9087e+00, -3.4087e+00, -3.6587e+00, -4.1587e+00, -4.2837e+00,
        -4.2837e+00, -4.6587e+00, -4.6587e+00, -4.6587e+00, -5.9233e-01,
        -8.4233e-01, -1.8954e-01, -3.1895e+00, -3.1895e+00, -1.2398e-05,
        -7.5102e-06, -1.0834e+00, -2.2084e+00, -2.7084e+00, -2.7084e+00,
        -2.9584e+00, -3.0834e+00, -3.0834e+00, -3.2084e+00, -3.3334e+00,
        -3.8334e+00, -3.8334e+00, -3.9584e+00, -4.2


new_candidates
torch.Size([233, 108])


tensor([[    1, 32010,  1724,  ..., 32001,   306, 29915],
        [    1, 32010,  1724,  ..., 32001,   306,   508],
        [    1, 32010,  1724,  ..., 32001,   306,  2274],
        ...,
        [    1, 32010,  1724,  ...,   363, 11563, 29892],
        [    1, 32010,  1724,  ...,   363, 11563, 10816],
        [    1, 32010,  1724,  ...,   363, 11563, 29915]], device='cuda:0')


new_candidate_logprobs
torch.Size([233])


tensor([-16.3600, -17.3600, -17.4850, -17.6100, -15.6034, -16.6433, -16.8563,
        -16.8566, -17.1209, -19.8709, -17.5651, -18.0442, -19.5442, -21.2942,
        -18.6627, -19.5377, -20.9127, -21.0377, -21.0377, -21.6627, -22.0377,
        -22.0377, -22.0377, -22.4127, -22.9127, -23.0377, -23.0377, -23.1627,
        -23.1627, -23.2877, -23.2877, -23.2877, -23.2877, -23.4127, -23.4127,
        -23.4127, -23.4127, -23.5377, -18.7621, -19.7621, -19.8871, -21.3871,
        -21.6371, -22.1371, -22.2621, -22.2621, -22.6371, -22.6371, -22.6371,
        -18.6957, -18.9457, -18.7929, -21.7929, -21.7929, -18.7284, -18.7284,
        -19.8118, -20.9368, -21.4368, -21.4368, -21.6868, -21.8118, -21.8118,
        -21.9368, -22.0618, -22.5618, -22.5618, -22.6868, -22.9368, -23.0618,
        -23.0618, -18.8554, -18.9784, -19.9198, -21.0448, -21.1698, -21.5448,
        -22.5448, -22.9198, -23.4198, -23.5448, -23.6698, -23.7948, -24.1698,
        -16.6480, -20.0855, -20.3980, -20.5230, -20.6480, -20.71

infer end: GPU memory used: 17813 MB.
event: level
id: 96
data: [{"content": "'", "parent": 0, "prob": -16.359966278076172}, {"content": "can", "parent": 0, "prob": -17.359966278076172}, {"content": "understand", "parent": 0, "prob": -17.484966278076172}, {"content": "apolog", "parent": 0, "prob": -17.609966278076172}, {"content": ",", "parent": 1, "prob": -15.603394508361816}, {"content": "'", "parent": 2, "prob": -16.643278121948242}, {"content": "note", "parent": 3, "prob": -16.856281280517578}, {"content": "course", "parent": 4, "prob": -16.85660171508789}, {"content": "you", "parent": 5, "prob": -17.120864868164062}, {"content": "your", "parent": 5, "prob": -19.870864868164062}, {"content": "Mount", "parent": 6, "prob": -17.565065383911133}, {"content": "of", "parent": 7, "prob": -18.044157028198242}, {"content": "an", "parent": 7, "prob": -19.544157028198242}, {"content": "a", "parent": 7, "prob": -21.294157028198242}, {"content": "Mount", "parent": 8, "prob": -18.662742614746094

array([[ 0.45898438, -0.4453125 ,  0.80078125, ..., -0.4453125 ,
        -0.49023438, -0.66015625],
       [ 0.53125   ,  1.640625  , -0.30273438, ...,  0.984375  ,
        -1.0546875 , -2.046875  ],
       [ 0.5234375 , -0.75      , -0.2734375 , ..., -0.06494141,
        -0.66015625, -2.25      ],
       ...,
       [ 1.734375  ,  0.2890625 , -0.51953125, ...,  1.6171875 ,
        -1.0234375 ,  0.31640625],
       [-1.1640625 ,  0.43359375, -1.0546875 , ...,  0.953125  ,
        -2.234375  , -1.21875   ],
       [-0.39257812, -0.44921875,  2.046875  , ..., -0.8359375 ,
        -1.4765625 , -0.91796875]], dtype=float32)


k_mean_space
(20, 2)


array([[41.513557, 53.158504],
       [72.96286 , 61.898376],
       [79.21259 , 61.04334 ],
       [72.45761 , 88.01832 ],
       [69.1314  , 52.162945],
       [46.95698 , 57.624012],
       [70.14189 , 49.01015 ],
       [64.917404, 52.79335 ],
       [64.6057  , 48.606964],
       [63.47246 , 44.79915 ],
       [51.92991 , 75.21432 ],
       [83.33861 , 71.365295],
       [82.837776, 66.39171 ],
       [79.3187  , 60.49258 ],
       [45.941273, 66.83665 ],
       [58.023582, 44.97755 ],
       [70.29918 , 54.261982],
       [75.76352 , 61.366554],
       [74.30819 , 61.834667],
       [67.24625 , 47.48498 ]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([ -86.84101868, -284.22412395])


closest
(2,)


array([0, 9])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.1001, -1.3281,  1.5391,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.4062,  0.1934,  2.2344,  ...,  0.0000,  0.0000,  0.0000],
        [-0.1924, -5.5000, -2.8594,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.5391,  2.6406,  0.3027,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.9375, -1.3281,  3.6562,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.0000,  1.2734,  3.8281,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[9.3898e-01, 4.6749e-02, 1.0431e-02,  ..., 3.9891e-18, 3.3071e-18,
         3.3071e-18],
        [5.6578e-01, 3.4316e-01, 5.9633e-02,  ..., 1.3294e-21, 7.5748e-22,
         1.5878e-22],
        [3.0132e-01, 2.6591e-01, 2.0709e-01,  ..., 1.4362e-19, 9.8711e-20,
         3.3777e-21],
        ...,
        [3.9338e-01, 2.3860e-01, 1.1271e-01,  ..., 2.9041e-19, 2.5629e-19,
         1.3718e-19],
        [5.7484e-01, 1.1319e-01, 6.0588e-02,  ..., 4.6079e-17, 4.0665e-17,
         3.5533e-18],
        [9.5246e-01, 1.0581e-02, 6.4176e-03,  ..., 1.0661e-17, 1.0333e-17,
         7.1016e-18]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.9390, 0.9857, 0.9962,  ..., 1.0000, 1.0000, 1.0000],
        [0.5658, 0.9089, 0.9686,  ..., 1.0000, 1.0000, 1.0000],
        [0.3013, 0.5672, 0.7743,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.3934, 0.6320, 0.7447,  ..., 1.0000, 1.0000, 1.0000],
        [0.5748, 0.6880, 0.7486,  ..., 1.0000, 1.0000, 1.0000],
        [0.9525, 0.9630, 0.9695,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([73])


tensor([ 0,  1,  1,  2,  2,  2,  2,  2,  3,  4,  4,  4,  4,  4,  5,  6,  7,  7,
         8,  8,  9,  9, 10, 11, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
        13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 15,
        15, 16, 16, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18,
        19], device='cuda:0')


carryover_candidates
torch.Size([73, 108])


tensor([[    1, 32010,  1724,  ..., 32001,   306, 29915],
        [    1, 32010,  1724,  ..., 32001,   306,   508],
        [    1, 32010,  1724,  ..., 32001,   306,   508],
        ...,
        [    1, 32010,  1724,  ..., 32001,   376, 29902],
        [    1, 32010,  1724,  ..., 32001,   376, 29902],
        [    1, 32010,  1724,  ..., 32001,   376,  3644]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([73])


tensor([-16.3600, -17.3600, -17.3600, -17.4850, -17.4850, -17.4850, -17.4850,
        -17.4850, -17.6100, -15.6034, -15.6034, -15.6034, -15.6034, -15.6034,
        -16.6433, -16.8563, -16.8566, -16.8566, -17.1209, -17.1209, -19.8709,
        -19.8709, -17.5651, -18.0442, -19.5442, -21.2942, -21.2942, -21.2942,
        -21.2942, -21.2942, -21.2942, -21.2942, -21.2942, -21.2942, -21.2942,
        -21.2942, -21.2942, -21.2942, -21.2942, -21.2942, -21.2942, -21.2942,
        -21.2942, -21.2942, -21.2942, -21.2942, -21.2942, -21.2942, -21.2942,
        -21.2942, -21.2942, -21.2942, -18.6627, -19.5377, -19.5377, -20.9127,
        -20.9127, -21.0377, -21.0377, -21.0377, -21.0377, -21.0377, -21.0377,
        -21.0377, -21.0377, -21.0377, -21.0377, -21.0377, -21.0377, -21.0377,
        -21.0377, -21.0377, -21.6627], device='cuda:0')


new_candidate_toks
torch.Size([73, 1])


tensor([[29885],
        [ 3867],
        [  884],
        [ 1286],
        [  393],
        [29889],
        [  366],
        [  278],
        [  675],
        [  565],
        [  372],
        [  746],
        [  727],
        [  297],
        [29879],
        [  393],
        [29991],
        [29892],
        [29915],
        [  526],
        [ 1139],
        [ 4066],
        [18274],
        [  590],
        [  319],
        [15754],
        [ 1298],
        [ 1737],
        [ 4086],
        [  901],
        [ 1101],
        [14378],
        [  319],
        [ 7134],
        [ 2143],
        [ 2625],
        [ 2498],
        [ 2982],
        [ 3353],
        [ 5844],
        [  337],
        [ 5534],
        [  260],
        [ 1700],
        [ 3133],
        [ 3407],
        [  716],
        [15171],
        [  848],
        [ 7481],
        [10754],
        [ 4443],
        [18274],
        [ 9939],
        [15655],
        [  338],
        [29915],
        [ 3867],
        [ 4443


new_candidate_tok_logprobs
torch.Size([73])


tensor([-6.2960e-02, -5.6954e-01, -1.0695e+00, -1.1996e+00, -1.3246e+00,
        -1.5746e+00, -2.0746e+00, -3.4496e+00, -2.4319e-05, -3.9046e-01,
        -2.1405e+00, -2.6405e+00, -3.3905e+00, -3.6405e+00, -2.3842e-07,
        -3.1442e-02, -1.8089e-01, -1.8059e+00, -1.8839e-01, -2.0634e+00,
        -2.4540e-01, -1.7454e+00, -4.3694e-04, -1.3484e-02, -1.5957e-02,
        -1.4780e+00, -1.9780e+00, -2.1030e+00, -2.2280e+00, -2.9780e+00,
        -3.3530e+00, -3.4780e+00, -3.6030e+00, -3.7280e+00, -4.1030e+00,
        -4.3530e+00, -4.6030e+00, -4.6030e+00, -4.7280e+00, -4.7280e+00,
        -4.8530e+00, -4.8530e+00, -4.9780e+00, -4.9780e+00, -4.9780e+00,
        -4.9780e+00, -5.1030e+00, -5.3530e+00, -5.4780e+00, -5.4780e+00,
        -5.4780e+00, -5.6030e+00, -2.6420e-02, -1.4565e-01, -2.2706e+00,
        -2.8995e-01, -1.4149e+00, -9.3298e-01, -1.4330e+00, -2.1830e+00,
        -2.3080e+00, -3.5580e+00, -3.9330e+00, -4.1830e+00, -5.5367e-01,
        -2.1787e+00, -2.8037e+00, -3.1787e+00, -3.3


new_candidates
torch.Size([73, 109])


tensor([[    1, 32010,  1724,  ...,   306, 29915, 29885],
        [    1, 32010,  1724,  ...,   306,   508,  3867],
        [    1, 32010,  1724,  ...,   306,   508,   884],
        ...,
        [    1, 32010,  1724,  ...,   376, 29902,   471],
        [    1, 32010,  1724,  ...,   376, 29902,  6839],
        [    1, 32010,  1724,  ...,   376,  3644,   366]], device='cuda:0')


new_candidate_logprobs
torch.Size([73])


tensor([-16.4229, -17.9295, -18.4295, -18.6846, -18.8096, -19.0596, -19.5596,
        -20.9346, -17.6100, -15.9939, -17.7439, -18.2439, -18.9939, -19.2439,
        -16.6433, -16.8877, -17.0375, -18.6625, -17.3093, -19.1843, -20.1163,
        -21.6163, -17.5655, -18.0576, -19.5601, -22.7721, -23.2721, -23.3971,
        -23.5221, -24.2721, -24.6471, -24.7721, -24.8971, -25.0221, -25.3971,
        -25.6471, -25.8971, -25.8971, -26.0221, -26.0221, -26.1471, -26.1471,
        -26.2721, -26.2721, -26.2721, -26.2721, -26.3971, -26.6471, -26.7721,
        -26.7721, -26.7721, -26.8971, -18.6892, -19.6834, -21.8084, -21.2027,
        -22.3277, -21.9707, -22.4707, -23.2207, -23.3457, -24.5957, -24.9707,
        -25.2207, -21.5914, -23.2164, -23.8414, -24.2164, -24.3414, -24.3414,
        -24.4664, -25.2164, -21.7115], device='cuda:0')

infer end: GPU memory used: 18079 MB.
event: level
id: 97
data: [{"content": "m", "parent": 0, "prob": -16.42292594909668}, {"content": "provide", "parent": 1, "prob": -17.92951011657715}, {"content": "also", "parent": 1, "prob": -18.42951011657715}, {"content": "now", "parent": 2, "prob": -18.68455696105957}, {"content": "that", "parent": 2, "prob": -18.80955696105957}, {"content": ".", "parent": 2, "prob": -19.05955696105957}, {"content": "you", "parent": 2, "prob": -19.55955696105957}, {"content": "the", "parent": 2, "prob": -20.93455696105957}, {"content": "ize", "parent": 3, "prob": -17.6099910736084}, {"content": "if", "parent": 4, "prob": -15.993858337402344}, {"content": "it", "parent": 4, "prob": -17.743858337402344}, {"content": "when", "parent": 4, "prob": -18.243858337402344}, {"content": "there", "parent": 4, "prob": -18.993858337402344}, {"content": "in", "parent": 4, "prob": -19.243858337402344}, {"content": "s", "parent": 5, "prob": -16.643278121948242}, {"content": "th

array([[-0.15722656, -0.57421875,  0.90625   , ..., -0.57421875,
        -1.7265625 , -1.1015625 ],
       [ 0.13671875, -1.5703125 , -0.17480469, ...,  0.50390625,
         0.11914062, -0.13476562],
       [-0.3203125 ,  1.78125   , -0.51953125, ...,  0.45898438,
        -0.640625  , -1.3203125 ],
       ...,
       [ 0.01940918,  1.390625  ,  0.9375    , ...,  0.5078125 ,
        -1.8515625 , -0.484375  ],
       [ 0.17871094,  0.578125  , -0.22949219, ...,  0.07666016,
        -0.49609375, -0.23730469],
       [ 0.40234375,  0.31445312, -0.14746094, ..., -0.69140625,
        -0.42578125, -0.09179688]], dtype=float32)


k_mean_space
(20, 2)


array([[82.873795, 78.4365  ],
       [65.4279  , 54.90606 ],
       [63.470024, 56.815895],
       [70.87477 , 66.062775],
       [56.077568, 36.61482 ],
       [55.868805, 43.680843],
       [51.214325, 61.945114],
       [59.820007, 49.486626],
       [69.81906 , 86.030716],
       [67.19227 , 49.60161 ],
       [43.46497 , 57.366707],
       [57.4212  , 40.722446],
       [49.173996, 59.830563],
       [70.2457  , 57.091797],
       [43.79295 , 50.788467],
       [68.27855 , 51.189674],
       [52.627785, 44.718475],
       [60.709183, 42.36437 ],
       [41.17529 , 50.812046],
       [52.45521 , 45.6979  ]], dtype=float32)


k_mean_clusters
(20,)


array([1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-107.85979652, -255.52369881])


closest
(2,)


array([18,  4])


last_tok_logits
torch.Size([20, 32064])


tensor([[-0.2031, -1.1953,  3.0625,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.4219,  2.6406,  2.4375,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.4062,  3.2500,  3.9688,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 4.2812,  5.3438,  3.5938,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3770, -0.8828,  0.8438,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.3750,  1.8984,  3.1406,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[7.9636e-01, 1.7769e-01, 1.4586e-02,  ..., 5.0864e-21, 2.7226e-21,
         2.4027e-21],
        [8.8725e-01, 4.4174e-02, 2.0866e-02,  ..., 3.4372e-21, 1.5253e-21,
         1.5253e-21],
        [8.2492e-01, 4.1070e-02, 4.1070e-02,  ..., 1.8209e-21, 1.5096e-21,
         1.4181e-21],
        ...,
        [2.9211e-01, 1.7717e-01, 1.0746e-01,  ..., 1.4969e-18, 1.4062e-18,
         1.3080e-19],
        [9.9187e-01, 7.5730e-03, 5.4859e-04,  ..., 5.1846e-16, 4.2982e-16,
         3.7932e-16],
        [5.6988e-01, 1.2716e-01, 9.9031e-02,  ..., 6.5162e-19, 6.5162e-19,
         2.3972e-19]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[0.7964, 0.9740, 0.9886,  ..., 1.0000, 1.0000, 1.0000],
        [0.8873, 0.9314, 0.9523,  ..., 1.0000, 1.0000, 1.0000],
        [0.8249, 0.8660, 0.9071,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.2921, 0.4693, 0.5767,  ..., 1.0000, 1.0000, 1.0000],
        [0.9919, 0.9994, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.5699, 0.6970, 0.7961,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([115])


tensor([ 0,  0,  1,  1,  2,  2,  2,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,
         6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         8,  9,  9,  9, 10, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15,
        15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16,
        16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
        17, 18, 19, 19, 19, 19, 19], device='cuda:0')


carryover_candidates
torch.Size([115, 109])


tensor([[    1, 32010,  1724,  ...,   306, 29915, 29885],
        [    1, 32010,  1724,  ...,   306, 29915, 29885],
        [    1, 32010,  1724,  ...,   306,   508,  3867],
        ...,
        [    1, 32010,  1724,  ...,   960,   366,   526],
        [    1, 32010,  1724,  ...,   960,   366,   526],
        [    1, 32010,  1724,  ...,   960,   366,   526]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([115])


tensor([-16.4229, -16.4229, -17.9295, -17.9295, -18.4295, -18.4295, -18.4295,
        -18.6846, -18.6846, -18.6846, -18.6846, -18.8096, -18.8096, -18.8096,
        -18.8096, -18.8096, -18.8096, -19.0596, -19.0596, -19.0596, -19.0596,
        -19.0596, -19.0596, -19.0596, -19.0596, -19.0596, -19.0596, -19.0596,
        -19.0596, -19.0596, -19.0596, -19.0596, -19.0596, -19.0596, -19.0596,
        -19.5596, -19.5596, -20.9346, -20.9346, -20.9346, -20.9346, -20.9346,
        -20.9346, -20.9346, -20.9346, -20.9346, -20.9346, -20.9346, -20.9346,
        -20.9346, -20.9346, -20.9346, -20.9346, -20.9346, -17.6100, -15.9939,
        -15.9939, -15.9939, -17.7439, -18.2439, -18.2439, -18.9939, -18.9939,
        -18.9939, -19.2439, -19.2439, -19.2439, -16.6433, -16.6433, -16.6433,
        -16.8877, -16.8877, -16.8877, -16.8877, -16.8877, -16.8877, -16.8877,
        -16.8877, -16.8877, -16.8877, -16.8877, -16.8877, -16.8877, -17.0375,
        -17.0375, -17.0375, -17.0375, -17.0375, -17.0375, -17.03


new_candidate_toks
torch.Size([115, 1])


tensor([[ 1244],
        [  385],
        [ 2472],
        [  366],
        [ 3867],
        [ 1371],
        [ 6985],
        [29889],
        [29991],
        [  393],
        [29892],
        [  366],
        [ 8040],
        [  278],
        [  727],
        [ 1737],
        [19223],
        [  739],
        [  960],
        [  450],
        [19152],
        [  512],
        [ 1932],
        [ 1670],
        [22738],
        [ 3529],
        [11563],
        [ 5806],
        [ 1551],
        [ 8040],
        [28418],
        [ 2266],
        [  382],
        [22907],
        [  887],
        [29915],
        [ 1795],
        [ 6964],
        [ 1139],
        [14679],
        [21578],
        [ 2009],
        [ 2346],
        [27742],
        [ 2847],
        [ 3407],
        [ 4066],
        [ 3619],
        [ 4328],
        [13500],
        [ 2498],
        [ 5972],
        [  817],
        [ 7037],
        [  363],
        [  366],
        [  591],
        [  596],
        [29915


new_candidate_tok_logprobs
torch.Size([115])


tensor([-2.2771e-01, -1.7277e+00, -1.1963e-01, -3.1196e+00, -1.9247e-01,
        -3.1925e+00, -3.1925e+00, -5.3594e-01, -1.7859e+00, -2.5359e+00,
        -2.5359e+00, -4.2973e-01, -2.1797e+00, -2.6797e+00, -3.3047e+00,
        -3.8047e+00, -4.6797e+00, -1.8930e+00, -2.0180e+00, -2.3930e+00,
        -2.5180e+00, -2.6430e+00, -2.6430e+00, -3.1430e+00, -3.1430e+00,
        -3.3930e+00, -3.3930e+00, -3.5180e+00, -3.7680e+00, -3.7680e+00,
        -4.0180e+00, -4.1430e+00, -4.2680e+00, -4.2680e+00, -4.3930e+00,
        -2.6766e-01, -1.7677e+00, -1.3001e+00, -1.9251e+00, -2.0501e+00,
        -2.1751e+00, -3.4251e+00, -3.4251e+00, -3.5501e+00, -3.6751e+00,
        -3.8001e+00, -3.9251e+00, -4.1751e+00, -4.1751e+00, -4.3001e+00,
        -4.4251e+00, -4.4251e+00, -4.6751e+00, -4.8001e+00, -4.6379e-03,
        -5.8892e-01, -1.0889e+00, -3.0889e+00, -6.5784e-02, -2.4153e-01,
        -1.8665e+00, -9.1451e-01, -1.1645e+00, -1.4145e+00, -8.7189e-01,
        -1.1219e+00, -1.4969e+00, -3.8017e-01, -1.7


new_candidates
torch.Size([115, 110])


tensor([[    1, 32010,  1724,  ..., 29915, 29885,  1244],
        [    1, 32010,  1724,  ..., 29915, 29885,   385],
        [    1, 32010,  1724,  ...,   508,  3867,  2472],
        ...,
        [    1, 32010,  1724,  ...,   366,   526,  6721],
        [    1, 32010,  1724,  ...,   366,   526, 13858],
        [    1, 32010,  1724,  ...,   366,   526, 25501]], device='cuda:0')


new_candidate_logprobs
torch.Size([115])


tensor([-16.6506, -18.1506, -18.0491, -21.0491, -18.6220, -21.6220, -21.6220,
        -19.2205, -20.4705, -21.2205, -21.2205, -19.2393, -20.9893, -21.4893,
        -22.1143, -22.6143, -23.4893, -20.9525, -21.0775, -21.4525, -21.5775,
        -21.7025, -21.7025, -22.2025, -22.2025, -22.4525, -22.4525, -22.5775,
        -22.8275, -22.8275, -23.0775, -23.2025, -23.3275, -23.3275, -23.4525,
        -19.8272, -21.3272, -22.2346, -22.8596, -22.9846, -23.1096, -24.3596,
        -24.3596, -24.4846, -24.6096, -24.7346, -24.8596, -25.1096, -25.1096,
        -25.2346, -25.3596, -25.3596, -25.6096, -25.7346, -17.6146, -16.5828,
        -17.0828, -19.0828, -17.8096, -18.4854, -20.1104, -19.9084, -20.1584,
        -20.4084, -20.1158, -20.3658, -20.7408, -17.0234, -18.3984, -19.2734,
        -18.2010, -18.2010, -19.0760, -20.0760, -20.0760, -20.0760, -20.4510,
        -20.5760, -20.8260, -21.0760, -21.2010, -21.2010, -21.3260, -17.7406,
        -19.1156, -19.7406, -19.8656, -20.2406, -20.6156, -20.74

infer end: GPU memory used: 18347 MB.
event: level
id: 98
data: [{"content": "here", "parent": 0, "prob": -16.650636672973633}, {"content": "an", "parent": 0, "prob": -18.150636672973633}, {"content": "information", "parent": 1, "prob": -18.049137115478516}, {"content": "you", "parent": 1, "prob": -21.049137115478516}, {"content": "provide", "parent": 2, "prob": -18.62197494506836}, {"content": "help", "parent": 2, "prob": -21.62197494506836}, {"content": "assist", "parent": 2, "prob": -21.62197494506836}, {"content": ".", "parent": 3, "prob": -19.22049903869629}, {"content": "!", "parent": 3, "prob": -20.47049903869629}, {"content": "that", "parent": 3, "prob": -21.22049903869629}, {"content": ",", "parent": 3, "prob": -21.22049903869629}, {"content": "you", "parent": 4, "prob": -19.23929214477539}, {"content": "Mount", "parent": 4, "prob": -20.98929214477539}, {"content": "the", "parent": 4, "prob": -21.48929214477539}, {"content": "there", "parent": 4, "prob": -22.11429214477539}, {

array([[ 2.765625  ,  0.421875  ,  0.87109375, ...,  0.73828125,
        -2.078125  , -2.015625  ],
       [ 0.8671875 , -0.640625  ,  1.8125    , ...,  0.92578125,
         1.90625   , -0.08105469],
       [-0.34765625,  0.3359375 ,  0.58203125, ..., -0.49804688,
         2.265625  , -0.12890625],
       ...,
       [-0.66015625, -0.53515625, -0.796875  , ..., -0.17480469,
        -1.546875  , -1.5703125 ],
       [-0.10595703, -0.54296875,  1.78125   , ..., -0.65234375,
        -1.171875  , -1.2265625 ],
       [-2.265625  ,  1.265625  ,  0.89453125, ...,  0.01940918,
        -0.45703125, -0.9140625 ]], dtype=float32)


k_mean_space
(20, 2)


array([[76.77141 , 83.785545],
       [89.77027 , 80.79341 ],
       [53.13246 , 64.075226],
       [45.93683 , 57.015167],
       [45.298973, 56.846184],
       [53.225315, 67.793205],
       [49.55399 , 64.15691 ],
       [49.13791 , 57.49525 ],
       [51.32902 , 58.836365],
       [39.80336 , 47.637608],
       [44.991264, 57.59416 ],
       [57.303852, 60.27711 ],
       [81.73933 , 69.081635],
       [55.35996 , 45.81582 ],
       [54.73745 , 48.29133 ],
       [56.551315, 47.21079 ],
       [61.524807, 54.23164 ],
       [58.186214, 51.27538 ],
       [50.346077, 59.182114],
       [56.917183, 45.849865]], dtype=float32)


k_mean_clusters
(20,)


array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1],
      dtype=int32)


k_mean_logprob_mass
(2,)


array([-240.06367111, -171.25219154])


closest
(2,)


array([ 9, 13])


last_tok_logits
torch.Size([20, 32064])


tensor([[-3.5625, -1.8984,  1.2109,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.5000, -3.0938, -1.1641,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.4531,  3.1875,  3.0469,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 2.4375,  0.3203,  5.3750,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.0938, -1.0078,  4.3125,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.3750,  5.7500,  4.3750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([20, 32064])

sorted_indices
torch.Size([20, 32064])

sorted_probs
torch.Size([20, 32064])


tensor([[1.0000e+00, 1.5535e-06, 7.3382e-07,  ..., 1.1469e-24, 1.0121e-24,
         2.2583e-25],
        [9.9974e-01, 2.0342e-04, 4.0055e-05,  ..., 3.3508e-23, 1.7936e-23,
         1.7936e-23],
        [5.9784e-01, 2.8240e-01, 4.9074e-02,  ..., 1.9201e-21, 1.9201e-21,
         1.0277e-21],
        ...,
        [9.7236e-01, 1.0802e-02, 9.5327e-03,  ..., 5.9512e-19, 1.5047e-19,
         2.1677e-20],
        [8.8436e-01, 4.9892e-02, 2.3567e-02,  ..., 1.3685e-19, 2.5315e-20,
         1.0553e-20],
        [4.3872e-01, 3.8717e-01, 6.7280e-02,  ..., 1.0076e-17, 3.4821e-18,
         2.6851e-19]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([20, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9997, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.5978, 0.8802, 0.9293,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9724, 0.9832, 0.9927,  ..., 1.0000, 1.0000, 1.0000],
        [0.8844, 0.9342, 0.9578,  ..., 1.0000, 1.0000, 1.0000],
        [0.4387, 0.8259, 0.8932,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([20, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([91])


tensor([ 0,  1,  2,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  9,  9, 10, 10, 10, 11, 11, 12, 13, 13, 13, 13, 13, 14, 14, 14, 14,
        14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 17, 18, 18, 19, 19, 19,
        19], device='cuda:0')


carryover_candidates
torch.Size([91, 110])


tensor([[    1, 32010,  1724,  ..., 29915, 29885,  1244],
        [    1, 32010,  1724,  ..., 29915, 29885,   385],
        [    1, 32010,  1724,  ...,   508,  3867,  2472],
        ...,
        [    1, 32010,  1724,  ...,  2274, 29889,   450],
        [    1, 32010,  1724,  ...,  2274, 29889,   450],
        [    1, 32010,  1724,  ...,  2274, 29889,   450]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([91])


tensor([-16.6506, -18.1506, -18.0491, -18.0491, -18.0491, -21.0491, -21.0491,
        -18.6220, -18.6220, -21.6220, -21.6220, -21.6220, -21.6220, -19.2205,
        -19.2205, -19.2205, -19.2205, -19.2205, -19.2205, -19.2205, -19.2205,
        -19.2205, -19.2205, -19.2205, -19.2205, -19.2205, -19.2205, -19.2205,
        -19.2205, -19.2205, -19.2205, -19.2205, -19.2205, -19.2205, -19.2205,
        -19.2205, -20.4705, -20.4705, -20.4705, -20.4705, -20.4705, -20.4705,
        -20.4705, -20.4705, -20.4705, -20.4705, -20.4705, -20.4705, -20.4705,
        -20.4705, -20.4705, -20.4705, -20.4705, -20.4705, -20.4705, -21.2205,
        -21.2205, -21.2205, -21.2205, -21.2205, -19.2393, -19.2393, -20.9893,
        -21.4893, -21.4893, -21.4893, -21.4893, -21.4893, -22.1143, -22.1143,
        -22.1143, -22.1143, -22.1143, -22.6143, -22.6143, -22.6143, -22.6143,
        -23.4893, -23.4893, -23.4893, -23.4893, -23.4893, -23.4893, -23.4893,
        -20.9525, -21.0775, -21.0775, -21.4525, -21.4525, -21.45


new_candidate_toks
torch.Size([91, 1])


tensor([[  304],
        [  319],
        [  373],
        [ 1048],
        [  322],
        [  411],
        [ 2472],
        [ 2472],
        [  366],
        [  366],
        [  411],
        [  366],
        [  411],
        [  450],
        [  739],
        [ 3529],
        [  512],
        [ 1932],
        [ 3374],
        [  960],
        [  887],
        [22738],
        [ 1670],
        [19065],
        [19152],
        [ 5806],
        [28418],
        [ 1317],
        [ 1551],
        [ 8040],
        [ 5169],
        [ 3938],
        [22907],
        [  382],
        [ 7806],
        [11563],
        [  450],
        [  739],
        [ 3374],
        [  960],
        [  887],
        [19065],
        [  512],
        [ 1932],
        [ 3529],
        [ 8040],
        [22738],
        [19152],
        [28418],
        [ 1670],
        [22907],
        [ 1317],
        [ 5806],
        [ 1551],
        [ 2803],
        [  366],
        [  746],
        [  366],
        [ 6452


new_candidate_tok_logprobs
torch.Size([91])


tensor([-3.2187e-06, -2.5651e-04, -5.1443e-01, -1.2644e+00, -3.0144e+00,
        -1.0962e-01, -2.3596e+00, -1.0974e-01, -3.1097e+00, -1.2474e-01,
        -3.1247e+00, -1.7169e-01, -1.9217e+00, -1.7100e+00, -2.0850e+00,
        -2.5850e+00, -2.7100e+00, -2.8350e+00, -2.8350e+00, -2.9600e+00,
        -3.3350e+00, -3.4600e+00, -3.5850e+00, -3.5850e+00, -3.5850e+00,
        -3.7100e+00, -3.8350e+00, -4.2100e+00, -4.3350e+00, -4.3350e+00,
        -4.4600e+00, -4.7100e+00, -4.7100e+00, -4.8350e+00, -5.0850e+00,
        -5.2100e+00, -1.7323e+00, -1.9823e+00, -2.1073e+00, -2.4823e+00,
        -2.9823e+00, -2.9823e+00, -3.3573e+00, -3.3573e+00, -3.4823e+00,
        -3.4823e+00, -3.4823e+00, -3.6073e+00, -3.7323e+00, -3.9823e+00,
        -4.1073e+00, -4.6073e+00, -4.6073e+00, -4.7323e+00, -4.7323e+00,
        -1.0670e-01, -3.7317e+00, -4.8041e-01, -1.3554e+00, -2.8554e+00,
        -5.7843e-01, -1.0784e+00, -3.4624e-04, -8.5766e-01, -1.4827e+00,
        -1.8577e+00, -2.4827e+00, -4.3577e+00, -1.1


new_candidates
torch.Size([91, 111])


tensor([[    1, 32010,  1724,  ..., 29885,  1244,   304],
        [    1, 32010,  1724,  ..., 29885,   385,   319],
        [    1, 32010,  1724,  ...,  3867,  2472,   373],
        ...,
        [    1, 32010,  1724,  ..., 29889,   450,  9939],
        [    1, 32010,  1724,  ..., 29889,   450, 11563],
        [    1, 32010,  1724,  ..., 29889,   450,  1840]], device='cuda:0')


new_candidate_logprobs
torch.Size([91])


tensor([-16.6506, -18.1509, -18.5636, -19.3136, -21.0636, -21.1588, -23.4088,
        -18.7317, -21.7317, -21.7467, -24.7467, -21.7937, -23.5437, -20.9305,
        -21.3055, -21.8055, -21.9305, -22.0555, -22.0555, -22.1805, -22.5555,
        -22.6805, -22.8055, -22.8055, -22.8055, -22.9305, -23.0555, -23.4305,
        -23.5555, -23.5555, -23.6805, -23.9305, -23.9305, -24.0555, -24.3055,
        -24.4305, -22.2028, -22.4528, -22.5778, -22.9528, -23.4528, -23.4528,
        -23.8278, -23.8278, -23.9528, -23.9528, -23.9528, -24.0778, -24.2028,
        -24.4528, -24.5778, -25.0778, -25.0778, -25.2028, -25.2028, -21.3272,
        -24.9522, -21.7009, -22.5759, -24.0759, -19.8177, -20.3177, -20.9896,
        -22.3470, -22.9720, -23.3470, -23.9720, -25.8470, -23.2455, -23.2455,
        -24.1205, -24.3705, -24.9955, -23.4444, -24.0694, -24.3194, -25.0694,
        -24.1182, -25.1182, -26.3682, -26.3682, -26.8682, -27.2432, -28.3682,
        -20.9806, -21.2004, -24.0754, -22.2764, -22.4014, -24.15

infer end: GPU memory used: 18617 MB.
event: level
id: 99
data: [{"content": "to", "parent": 0, "prob": -16.6506404876709}, {"content": "A", "parent": 1, "prob": -18.15089225769043}, {"content": "on", "parent": 2, "prob": -18.563566207885742}, {"content": "about", "parent": 2, "prob": -19.313566207885742}, {"content": "and", "parent": 2, "prob": -21.063566207885742}, {"content": "with", "parent": 3, "prob": -21.158754348754883}, {"content": "information", "parent": 3, "prob": -23.408754348754883}, {"content": "information", "parent": 4, "prob": -18.73171615600586}, {"content": "you", "parent": 4, "prob": -21.73171615600586}, {"content": "you", "parent": 5, "prob": -21.746719360351562}, {"content": "with", "parent": 5, "prob": -24.746719360351562}, {"content": "you", "parent": 6, "prob": -21.793663024902344}, {"content": "with", "parent": 6, "prob": -23.543664932250977}, {"content": "The", "parent": 7, "prob": -20.930509567260742}, {"content": "It", "parent": 7, "prob": -21.305509567260

In [13]:
D(keep_indices.flatten())

torch.Size([256512])


tensor([ True,  True,  True,  ..., False, False, False], device='cuda:0')

In [14]:
keep_indices.flatten()[0:1000].sum()

tensor(1000, device='cuda:0')

In [31]:
D(sorted_indices)

torch.Size([8, 32064])


tensor([[29871,   259, 29892,  ..., 22715, 25923, 24336],
        [24278, 26785,  7066,  ..., 12645, 15084,  6941],
        [24278, 26785,  7066,  ..., 12645, 15084,  6941],
        ...,
        [24278, 26785,  7066,  ..., 12645, 15084,  6941],
        [24278, 26785,  7066,  ..., 12645, 15084,  6941],
        [24278, 26785,  7066,  ..., 12645, 15084,  6941]], device='cuda:0')




In [37]:
D(keep_indices)
print(keep_indices.sum())
keep_indices.sum(dim=1)

torch.Size([8, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


tensor(90805, device='cuda:0')


tensor([    1, 12972, 12972, 12972, 12972, 12972, 12972, 12972],
       device='cuda:0')

In [32]:
x = sorted_indices[keep_indices]
D(x)

torch.Size([90805])


tensor([29871, 24278, 26785,  ..., 15108, 15268, 16121], device='cuda:0')




In [41]:
a = torch.randn(8, 5)
D(a)
b = torch.ones(8, 5, dtype=torch.long)
b[0][1] = 0
D(b)
c = a[b]
D(c)

torch.Size([8, 5])


tensor([[-0.2994, -0.1878,  1.9159,  0.6902, -2.3217],
        [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
        [-1.3952,  0.4751, -0.8137,  0.9242, -0.2473],
        [-1.4154,  0.9874, -1.4878,  0.5867,  0.1583],
        [ 0.1102, -0.8188,  0.6328, -1.9169,  1.1711],
        [ 0.0975,  0.9634,  0.8403, -1.2537,  0.9868],
        [-0.4947, -1.2830,  0.9552,  1.2836, -0.6659],
        [ 0.5651,  0.2877, -0.0334, -1.0619, -0.1144]])


torch.Size([8, 5])


tensor([[1, 0, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])


torch.Size([8, 5, 5])


tensor([[[-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-0.2994, -0.1878,  1.9159,  0.6902, -2.3217],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047]],

        [[-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047]],

        [[-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047]],

        [[-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.11




In [42]:
t = torch.tensor([[10, 30, 20], [60, 40, 50]])
sorted_idx = torch.argsort(t, dim=1)
D(sorted_idx)

torch.Size([2, 3])


tensor([[0, 2, 1],
        [1, 2, 0]])




In [43]:
x = np.arange(35).reshape(5, 7)

In [44]:
b = x > 20
D(b)

array([[False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False],
       [ True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True]])




In [45]:
b[:, 5]

array([False, False, False,  True,  True])

In [46]:
x[b]

array([21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34])

In [47]:
x[b[:, 5]]

array([[21, 22, 23, 24, 25, 26, 27],
       [28, 29, 30, 31, 32, 33, 34]])

In [50]:
b[3, :]

array([ True,  True,  True,  True,  True,  True,  True])

In [51]:
x[:, b[2, :]]

array([], shape=(5, 0), dtype=int64)

In [17]:
# def candidates_generator(text: str):
#     print(text)
#     candidates, candidate_masks, candidate_parents, candidate_logprobs = _init_candidates(text)

#     return candidates, candidate_masks, candidate_parents, candidate_logprobs
        


In [93]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json

torch.random.manual_seed(0)

# Not directly comparable to legacy Inference yet --:
# - Remove p falloff from original
# - Are max candidates and max new tokens taken into account the same way?
class InferenceTensor:
    def __init__(self):
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            torch_dtype=torch.bfloat16,
            device_map='auto',
            trust_remote_code=True,
            use_cache=True,
            # attn_implementation='flash_attention_2',
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.max_candidates = 20
        self.max_new_tokens = 10
        self.batch_size = 8
        self.p_falloff = 0.5 # UNIMPLEMENTED
        self.prune_similar_sequences = True # UNIMPLEMENTED
        self.prune_similar_branches = True # UNIMPLEMENTED
        self.prune_similar_embeddings = True # UNIMPLEMENTED

    def candidates_generator(self, text: str):
        print(text)
        candidates, candidate_logprobs = self._init_candidates(text)
        for i in range(self.max_new_tokens):
            candidates, candidate_parents, candidate_logprobs = self._infer(candidates[:self.max_candidates, ...], candidate_logprobs[:self.max_candidates, ...])
            candidate_texts = self.tokenizer.batch_decode(candidates[:, -1])
            candidate_dicts = []
            for i in range(len(candidate_texts)):
                candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i].item(), 'prob': candidate_logprobs[i].item()})
            data = json.dumps(candidate_dicts)
            yield f"event: level\nid: {i}\ndata: {data}\n\n"

        yield f"event: level\nid: END\ndata: []\n\n"

    def _init_candidates(self, text: str):
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
        inputs = self.tokenizer(prompt, return_tensors='pt')
        print(self.tokenizer.batch_decode(inputs.input_ids))

        candidates = inputs.input_ids.to(self.device)
        candidate_logprobs = torch.zeros((1), dtype=torch.float32, device=self.device)

        return candidates, candidate_logprobs

    def _top_p_single_batch(self, logits, candidates, candidate_logprobs):
        last_tok_logits = logits[:, -1, :]
        
        sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        
        # Create tensor of bools indicating which indices are cumulatively less than top_p
        keep_indices = cum_probs < 0.96

        # Keep the last element that went over top_p
        keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
        keep_indices[:, 0] = 1  # Always keep the first element
        
        new_candidate_parents = keep_indices.nonzero()[:, 0]
        
        # OPTIM: Potential optimization -- have a fixed tensor of size (max_candidates, max_tokens) and copy this into that (batch-aware).
        # OPTIM: consider which of these operations can be done in-place to prevent new allocations?
        carryover_candidates = candidates.index_select(0, new_candidate_parents)
        carryover_candidate_logprobs = candidate_logprobs.index_select(0, new_candidate_parents)  # Not strictly necessary since 1d
        
        new_candidate_toks = sorted_indices[keep_indices].unsqueeze(1)
        new_candidate_tok_logprobs = sorted_probs[keep_indices].log()
        
        new_candidates = torch.cat([carryover_candidates, new_candidate_toks], dim=1)
        new_candidate_logprobs = carryover_candidate_logprobs.add_(new_candidate_tok_logprobs)
        
        return new_candidates, new_candidate_parents, new_candidate_logprobs
        

    def _infer(self, candidates, candidate_logprobs):
        with torch.inference_mode():
            num_batches = (candidates.shape[0] + self.batch_size - 1) // self.batch_size  # Round up to nearest whole number of batches
            print('\nnum_batches', num_batches)
            new_candidates_list = []
            new_candidate_parents_list = []
            new_candidate_logprobs_list = []

            for i in range(0, num_batches, 1):
                batch_candidates = candidates[i * self.batch_size:(i + 1) * self.batch_size]
                batch_candidate_logprobs = candidate_logprobs[i * self.batch_size:(i + 1) * self.batch_size]

                batch_outputs = self.model(input_ids=batch_candidates)
                
                # TODO: Pruning step based on K-Means Clustering of embeddings here
                
                new_batch_candidates, new_batch_candidate_parents, new_batch_candidate_logprobs = self._top_p_single_batch(batch_outputs.logits, batch_candidates, batch_candidate_logprobs)
                new_candidates_list.append(new_batch_candidates)
                new_candidate_parents_list.append(new_batch_candidate_parents)
                new_candidate_logprobs_list.append(new_batch_candidate_logprobs)
                
            return torch.cat(new_candidates_list), torch.cat(new_candidate_parents_list), torch.cat(new_candidate_logprobs_list)
        




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


What is the highest mountain?
['<s><|user|> What is the highest mountain? <|end|><|assistant|>']

num_batches 1
event: level
id: 1
data: [{"content": "The", "parent": 0, "prob": -0.0789911225438118}, {"content": "As", "parent": 0, "prob": -2.578990936279297}]



num_batches 1
event: level
id: 1
data: [{"content": "highest", "parent": 0, "prob": -0.07979819923639297}, {"content": "of", "parent": 1, "prob": -2.578991651535034}]



num_batches 1
event: level
id: 2
data: [{"content": "mountain", "parent": 0, "prob": -0.079963319003582}, {"content": "my", "parent": 1, "prob": -2.6227312088012695}, {"content": "current", "parent": 1, "prob": -5.8727312088012695}]



num_batches 1
event: level
id: 8
data: [{"content": "on", "parent": 0, "prob": -0.3498345613479614}, {"content": "in", "parent": 0, "prob": -2.099834442138672}, {"content": "above", "parent": 0, "prob": -2.349834442138672}, {"content": "last", "parent": 1, "prob": -3.1989080905914307}, {"content": "knowledge", "parent": 1, "prob"

OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacty of 22.17 GiB of which 18.38 MiB is free. Including non-PyTorch memory, this process has 22.14 GiB memory in use. Of the allocated memory 21.80 GiB is allocated by PyTorch, and 120.71 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF