In [1]:
!pip install transformers accelerate optimum nvidia-ml-py torchmetrics



In [2]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json
from sklearn.cluster import KMeans
from torchmetrics.functional import pairwise_cosine_similarity

torch.random.manual_seed(0)

<torch._C.Generator at 0x7fe315056250>

In [3]:


from pynvml import *

def check_gpu(step):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"{step}: GPU memory used: {info.used // 1024**2} MB.")
    
def D(obj, label=None, c=True):
    print()
    if label:
        print(label)
        
    if isinstance(obj, tuple):
        print('tuple size', len(obj), ':')
        if c: # Contents
            display(obj)
    elif isinstance(obj, torch.Tensor) or isinstance(obj, np.ndarray):
        print(obj.shape)
        if c: # Contents
            display(obj)
    else:
        if c: # Contents
            display(obj)
            
def DS(obj, label=None):
    D(obj, label, c=False)
    


In [21]:
class InferenceTensor:
    def __init__(self):
        print('Initializing model...')
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            torch_dtype=torch.bfloat16,
            device_map='auto',
            trust_remote_code=True,
            use_cache=True,
            # attn_implementation='flash_attention_2',
        )
        print('Initializing tokenizer...')
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.eos_token_id = 32007

        self.batch_size = 8
        
    def candidates_generator(self, top_p: float, top_p_decay: float, top_k: float, max_beams: int, max_new_tokens: int, gather_algo: str, prompt: str):
        candidates, candidate_logprobs = self._init_candidates(prompt)
        all_finished = []
        all_finished_logprobs = []
        
        for level_idx in range(max_new_tokens):
            start = time.perf_counter()
            logits, embeddings = self._infer(candidates, candidate_logprobs)
            
            if candidates.shape[0] > max_beams:
                if gather_algo == 'k_means':
                    candidates, candidate_parents, candidate_aunts, candidate_logprobs, logits = self._k_means(logits, embeddings, candidates, candidate_logprobs, max_beams)
                    inference_duration = time.perf_counter() - start
                    print('K MEANS PRIOR {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
                    yield self._format_gather(level_idx, 'k', candidates, candidate_parents, candidate_aunts, candidate_logprobs, inference_duration)
                    print('K MEANS AFTER {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))

                elif gather_algo == 'farthest_neighbors':
                    candidates, candidate_parents, candidate_aunts, candidate_logprobs, logits = self._farthest_neighbors(logits, embeddings, candidates, candidate_logprobs, max_beams)
                    inference_duration = time.perf_counter() - start
                    print('F NEIGHBORS PRIOR {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
                    yield self._format_gather(level_idx, 'f', candidates, candidate_parents, candidate_aunts, candidate_logprobs, inference_duration)
                    print('F NEIGHBORS AFTER {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))

            logits, candidates, candidate_logprobs, max_beams, finished, finished_parents, finished_logprobs = self._select_finished(logits, candidates, candidate_logprobs, max_beams)
            if finished.shape[0] > 0:
                all_finished.extend(finished)
                all_finished_logprobs.extend(finished_logprobs)
            
            candidates, candidate_parents, candidate_logprobs = self._top_p(logits, candidates, candidate_logprobs, top_p, top_k)
            inference_duration = time.perf_counter() - start
            print('TOP P PRIOR {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
            yield self._format_top_p(level_idx, candidates, candidate_parents, candidate_logprobs, inference_duration, finished, finished_parents, finished_logprobs)
            print('TOP P AFTER {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
            top_p *= top_p_decay
            
            if candidates.shape[0] == 0:
                break

        yield f"event: message\nid: END\ndata: []\n\n"
        return all_finished, all_finished_logprobs
    

    def _format_gather(self, suffix, level_idx, candidates, candidate_parents, candidate_aunts, candidate_logprobs, duration):
        candidate_texts = self.tokenizer.convert_ids_to_tokens(candidates[:, -1], skip_special_tokens=True)
        candidate_dicts = []
        idx = f"{level_idx}-{suffix}"
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i], 'aunts': candidate_aunts[i], 'prob': candidate_logprobs[i].item()})
        data = json.dumps({'id': idx, 'level_type': 'gather', 'duration': duration, 'nodes': candidate_dicts})
        return f"event: message\nid: {idx}\"\ndata: {data}\n\n"

    def _format_top_p(self, level_idx, candidates, candidate_parents, candidate_logprobs, duration, finished, finished_parents, finished_logprobs):
        candidate_texts = self.tokenizer.convert_ids_to_tokens(candidates[:, -1])
        candidate_dicts = []
        idx = f"{level_idx}-p"
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i], 'prob': candidate_logprobs[i].item()})
        
        finished_texts = self.tokenizer.batch_decode(finished, skip_special_tokens=True)
        finished_parents = finished_parents.tolist()
        finished_dicts = []
        for i in range(len(finished_texts)):
            finished_dicts.append({'content': finished_texts[i], 'parent': finished_parents[i], 'prob': finished_logprobs[i].item()})
        
        data = json.dumps({'id': idx, 'level_type': 'sample', 'duration': duration, 'nodes': candidate_dicts, 'finished': finished_dicts})
        return f"event: message\nid: {idx}\ndata: {data}\n\n"


    def _init_candidates(self, text: str):
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
        inputs = self.tokenizer(prompt, return_tensors='pt')
        D(inputs.input_ids, 'input_ids')

        candidates = inputs.input_ids.to(self.device)
        candidate_logprobs = torch.zeros((1), dtype=torch.float32, device=self.device)

        return candidates, candidate_logprobs

    def _k_means(self, logits, embeddings, candidates, candidate_logprobs, max_beams):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        # === CPU ===
        embeddings_np = embeddings.float().numpy(force=True)
        D(embeddings_np, 'embeddings_np')
        k_means = KMeans(n_clusters=min(max_beams, embeddings_np.shape[0]), random_state=0, n_init="auto")
        k_mean_space = k_means.fit_transform(embeddings_np)
        D(k_mean_space, 'k_mean_space')
        k_mean_clusters = k_means.predict(embeddings_np)
        D(k_mean_clusters, 'k_mean_clusters')
        k_mean_logprob_mass = np.log(np.bincount(k_mean_clusters, weights=candidate_logprobs.cpu().exp()))
        D(k_mean_logprob_mass, 'k_mean_logprob_mass')
        closest = np.argmin(k_mean_space, axis=0)
        D(closest, 'closest')
        # === END CPU ===
        
        closest_indices = torch.from_numpy(closest).to(self.device)
        new_candidates = candidates.index_select(0, closest_indices)
        D(new_candidates, 'new_candidates')
        new_candidate_parents = closest_indices.tolist()
        D(new_candidate_parents, 'new_candidate_parents')
        new_candidate_aunts = [torch.nonzero(torch.from_numpy(k_mean_clusters).to(self.device) == i).squeeze(-1).tolist() for i in range(new_candidates.shape[0])]
        D(new_candidate_aunts, 'new_candidate_aunts')
        new_candidate_logprobs = torch.from_numpy(k_mean_logprob_mass).to(self.device)
        D(new_candidate_logprobs, 'new_candidate_logprobs')
        new_candidate_logits = logits.index_select(0, closest_indices)
        
        return new_candidates, new_candidate_parents, new_candidate_aunts, new_candidate_logprobs, new_candidate_logits
        
    def _farthest_neighbors(self, logits, embeddings, candidates, candidate_logprobs, max_beams):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        D(embeddings, 'embeddings')
        
        selected = torch.zeros((candidates.shape[0],), dtype=torch.bool).to(self.device)
        max_prob_idx = candidate_logprobs.argmax()
        selected[max_prob_idx] = 1
        
        D(selected, 'selected')
        
        for idx in range(min(max_beams - 1, candidates.shape[0])):
            selected_embeddings = embeddings[selected]
            D(selected_embeddings, 'selected_embeddings')
            # Add 2 because bfloat16 on cuda can have imprecision and we need 0 to be lower than every
            # cosine distance
            distances = torch.add(2, pairwise_cosine_similarity(embeddings, selected_embeddings), alpha=-1)
            D(distances, 'distances')
            min_distances = torch.min(distances, dim=1).values
            D(min_distances, 'min_distances')
            min_remaining_distances = min_distances * ~selected
            D(min_remaining_distances, 'min_remaining_distances')
            next_selected = min_remaining_distances.argmax(dim=0)
            selected[next_selected] = 1
            D(selected, 'selected (end of loop)')
            
        # We have all the candidates that are selected to move forward. Figure out which probability mass
        # to assign where.
        selected_embeddings = embeddings[selected]
        D(selected_embeddings, 'selected_embeddings')
        # Add 2 because bfloat16 on cuda can have imprecision and we need 0 to be lower than every
        # cosine distance
        distances = torch.add(2, pairwise_cosine_similarity(embeddings, selected_embeddings), alpha=-1)
        D(distances, 'distances')
        
        closest_per_candidate = distances.argmin(dim=1)
        D(closest_per_candidate, 'closest_per_candidate')
        
        new_candidates = candidates[selected]
        D(new_candidates, 'new_candidates')
        new_candidate_parents = torch.arange(candidates.shape[0]).to(self.device)[selected].tolist()
        D(new_candidate_parents, 'new_candidate_parents')
        new_candidate_aunts = [list(filter(lambda x: x != i, torch.nonzero(closest_per_candidate == i).squeeze(-1).tolist())) \
                       for i in range(new_candidates.shape[0])]
        D(new_candidate_aunts, 'new_candidate_aunts')
        new_candidate_logprobs = torch.zeros((new_candidates.shape[0],)).to(self.device)
        new_candidate_logprobs.index_add_(0, closest_per_candidate, candidate_logprobs)
        D(new_candidate_logprobs, 'new_candidate_logprobs')
        new_candidate_logits = logits[selected]
        
        return new_candidates, new_candidate_parents, new_candidate_aunts, new_candidate_logprobs, new_candidate_logits

    
    def _select_finished(self, logits, candidates, candidate_logprobs, max_beams):
        finished_mask = candidates[:,-1] == self.eos_token_id
        unfinished_mask = ~finished_mask
        D(finished_mask, 'finished_mask')
        
        new_logits = logits[unfinished_mask]
        new_candidates = candidates[unfinished_mask]
        new_candidate_logprobs = candidate_logprobs[unfinished_mask]
        new_max_beams = max_beams - finished_mask.sum()
        
        finished = candidates[finished_mask][:,:-1] # Remove the EOS token
        D(finished, 'finished')
        finished_parents = torch.arange(candidates.shape[0], device=self.device)[finished_mask]
        D(finished_parents, 'finished_parents')
        finished_logprobs = candidate_logprobs[finished_mask]
        D(finished_logprobs, 'finished_logprobs')
        
        return new_logits, new_candidates, new_candidate_logprobs, new_max_beams, finished, finished_parents, finished_logprobs
    

    def _top_p(self, logits, candidates, candidate_logprobs, top_p, top_k):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        
        last_tok_logits = logits[:, -1, :]
        D(last_tok_logits, 'last_tok_logits')

        sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
        DS(sorted_logits, 'sorted_logits')
        DS(sorted_indices, 'sorted_indices')
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        D(sorted_probs, 'sorted_probs')
        display(sorted_probs.sum(dim=1))
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        D(cum_probs, 'cum_probs')

        # Create tensor of bools indicating which indices are cumulatively less than top_p
        keep_indices = cum_probs < top_p

        # Keep the last element that went over top_p
        keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
        keep_indices[:, 0] = 1  # Always keep the first element
        D(keep_indices, 'keep_indices')

        # Don't keep any indices that are greater than top_k
        keep_indices[:, top_k:] = 0
        D(keep_indices, 'keep_indices after top_k')

        new_candidate_parents = keep_indices.nonzero()[:, 0]
        D(new_candidate_parents, 'new_candidate_parents')

        # OPTIM: Potential optimization -- have a fixed tensor of size (max_candidates, max_tokens) and copy this into that (batch-aware).
        # OPTIM: consider which of these operations can be done in-place to prevent new allocations?
        carryover_candidates = candidates.index_select(0, new_candidate_parents)
        D(carryover_candidates, 'carryover_candidates')

        # Similar code could be used to trace entire origin of sequence. For now since server just traces parent of the preceding generation, not needed
        # carryover_candidate_parents = candidate_parents.index_select(0, carryover_candidate_indices)  # Not strictly necessary since 1d
        # D(carryover_candidate_parents, 'carryover_candidate_parents')

        carryover_candidate_logprobs = candidate_logprobs.index_select(0, new_candidate_parents)  # Not strictly necessary since 1d
        D(carryover_candidate_logprobs, 'carryover_candidate_logprobs')

        new_candidate_toks = sorted_indices[keep_indices].unsqueeze(1)
        D(new_candidate_toks, 'new_candidate_toks')
        new_candidate_tok_logprobs = sorted_probs[keep_indices].log()
        D(new_candidate_tok_logprobs, 'new_candidate_tok_logprobs')

        new_candidates = torch.cat([carryover_candidates, new_candidate_toks], dim=1)
        D(new_candidates, 'new_candidates')
        new_candidate_logprobs = carryover_candidate_logprobs.add_(new_candidate_tok_logprobs)
        D(new_candidate_logprobs, 'new_candidate_logprobs')

        return new_candidates, new_candidate_parents.tolist(), new_candidate_logprobs


    def _infer(self, candidates, candidate_logprobs):
        with torch.inference_mode():
            num_batches = (candidates.shape[0] + self.batch_size - 1) // self.batch_size  # Round up to nearest whole number of batches
            D(num_batches, 'num_batches')

            check_gpu('infer start')
            output_logits_list = []
            output_embeddings_list = []
            for i in range(0, num_batches, 1):
                batch_candidates = candidates[i * self.batch_size:(i + 1) * self.batch_size]
                DS(batch_candidates, 'batch_candidates')
                batch_candidate_logprobs = candidate_logprobs[i * self.batch_size:(i + 1) * self.batch_size]
                DS(batch_candidate_logprobs, 'batch_candidate_logprobs')

                batch_outputs = self.model(input_ids=batch_candidates, output_hidden_states=True)
                DS(batch_outputs.logits, 'batch_logits')
                DS(batch_outputs.hidden_states[-1], 'hidden_states[-1]')

                output_logits_list.append(batch_outputs.logits)
                output_embeddings_list.append(batch_outputs.hidden_states[-1][:,-1,:])
                check_gpu('infer - after batch run')

            output_logits = torch.cat(output_logits_list, dim=0)
            output_embeddings = torch.cat(output_embeddings_list, dim=0)
            
            return output_logits, output_embeddings
        
it = InferenceTensor()

for x in it.candidates_generator(top_p=0.9, top_p_decay=1.0, top_k=5, max_beams=3, max_new_tokens=50, gather_algo='farthest_neighbors', prompt='What is the closest star to the Earth? Answer in 5 words or less.'):
    print(x)
    print()
    print('====================================')
    print()

Initializing model...




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Initializing tokenizer...

input_ids
torch.Size([1, 22])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001]])


num_batches


1

infer start: GPU memory used: 14564 MB.

batch_candidates
torch.Size([1, 22])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 22, 32064])

hidden_states[-1]
torch.Size([1, 22, 3072])
infer - after batch run: GPU memory used: 7476 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 21])


tensor([], device='cuda:0', size=(0, 21), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 22])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[11.4375,  9.4375,  7.8125,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[8.3999e-01, 1.4597e-01, 1.3577e-02,  ..., 4.0964e-23, 4.0964e-23,
         1.2370e-24]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.8400, 0.9860, 0.9995,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 22])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([0., 0.], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[ 450],
        [1019]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.1744, -1.9244], device='cuda:0')


new_candidates
torch.Size([2, 23])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-0.1744, -1.9244], device='cuda:0')

TOP P PRIOR 0: (15580.075339598) 2 candidates, 2.1145544099999825 inference time, 2.114557982999031 total time
event: message
id: 0-p
data: {"id": "0-p", "level_type": "sample", "duration": 2.1145544099999825, "nodes": [{"content": "\u2581The", "parent": 0, "prob": -0.1743605136871338}, {"content": "\u2581Pro", "parent": 0, "prob": -1.9243606328964233}], "finished": []}




TOP P AFTER 0: (15580.075701606) 2 candidates, 2.1145544099999825 inference time, 2.11491906899937 total time

num_batches


1

infer start: GPU memory used: 7476 MB.

batch_candidates
torch.Size([2, 23])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 23, 32064])

hidden_states[-1]
torch.Size([2, 23, 3072])
infer - after batch run: GPU memory used: 7500 MB.

finished_mask
torch.Size([2])


tensor([False, False], device='cuda:0')


finished
torch.Size([0, 22])


tensor([], device='cuda:0', size=(0, 22), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([2, 23])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-0.1744, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[4.8750, 6.0938, 2.9844,  ..., 0.0000, 0.0000, 0.0000],
        [1.0000, 0.5977, 1.7891,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.8566e-01, 1.4060e-02, 1.2164e-04,  ..., 1.8638e-24, 9.9760e-25,
         1.7336e-25],
        [1.0000e+00, 3.6535e-08, 3.8507e-09,  ..., 3.9244e-26, 1.1244e-26,
         9.9224e-27]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.9857, 0.9997, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


carryover_candidates
torch.Size([2, 23])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-0.1744, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[21438],
        [ 2657]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.0144,  0.0000], device='cuda:0')


new_candidates
torch.Size([2, 24])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-0.1888, -1.9244], device='cuda:0')

TOP P PRIOR 1: (15581.810057068) 2 candidates, 1.7343410199991922 inference time, 1.7343461380005465 total time
event: message
id: 1-p
data: {"id": "1-p", "level_type": "sample", "duration": 1.7343410199991922, "nodes": [{"content": "\u2581closest", "parent": 0, "prob": -0.1888073980808258}, {"content": "xim", "parent": 1, "prob": -1.9243606328964233}], "finished": []}




TOP P AFTER 1: (15581.810611016) 2 candidates, 1.7343410199991922 inference time, 1.734899459999724 total time

num_batches


1

infer start: GPU memory used: 7500 MB.

batch_candidates
torch.Size([2, 24])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 24, 32064])

hidden_states[-1]
torch.Size([2, 24, 3072])
infer - after batch run: GPU memory used: 7688 MB.

finished_mask
torch.Size([2])


tensor([False, False], device='cuda:0')


finished
torch.Size([0, 23])


tensor([], device='cuda:0', size=(0, 23), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([2, 24])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-0.1888, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 4.0625, -3.8125, -1.6406,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.3750, -1.3984, -2.5938,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.9960e-01, 3.7998e-04, 6.1418e-06,  ..., 1.5515e-25, 1.0663e-25,
         4.4452e-26],
        [9.9999e-01, 6.1442e-06, 1.8554e-07,  ..., 9.1107e-23, 4.8766e-23,
         1.0121e-24]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.9996, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


carryover_candidates
torch.Size([2, 24])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-0.1888, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[ 5810],
        [29874]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-3.9544e-04, -6.4373e-06], device='cuda:0')


new_candidates
torch.Size([2, 25])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-0.1892, -1.9244], device='cuda:0')

TOP P PRIOR 2: (15583.485639932) 2 candidates, 1.6749949519999063 inference time, 1.6749995930003934 total time
event: message
id: 2-p
data: {"id": "2-p", "level_type": "sample", "duration": 1.6749949519999063, "nodes": [{"content": "\u2581star", "parent": 0, "prob": -0.18920283019542694}, {"content": "a", "parent": 1, "prob": -1.924367070198059}], "finished": []}




TOP P AFTER 2: (15583.486027284) 2 candidates, 1.6749949519999063 inference time, 1.6753851800003758 total time

num_batches


1

infer start: GPU memory used: 7688 MB.

batch_candidates
torch.Size([2, 25])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 25, 32064])

hidden_states[-1]
torch.Size([2, 25, 3072])
infer - after batch run: GPU memory used: 7508 MB.

finished_mask
torch.Size([2])


tensor([False, False], device='cuda:0')


finished
torch.Size([0, 24])


tensor([], device='cuda:0', size=(0, 24), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([2, 25])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-0.1892, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 5.0312, -2.1250, -1.3906,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.9375, -3.5312, -0.2754,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[8.4862e-01, 1.4747e-01, 3.0606e-03,  ..., 1.1249e-22, 9.9275e-23,
         8.1490e-24],
        [9.9999e-01, 3.2887e-06, 1.5535e-06,  ..., 3.2858e-25, 1.7588e-25,
         5.7099e-26]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.8486, 0.9961, 0.9991,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 1], device='cuda:0')


carryover_candidates
torch.Size([3, 25])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-0.1892, -0.1892, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[ 304],
        [ 338],
        [2895]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.6415e-01, -1.9141e+00, -5.4836e-06], device='cuda:0')


new_candidates
torch.Size([3, 26])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-0.3533, -2.1033, -1.9244], device='cuda:0')

TOP P PRIOR 3: (15585.218911978) 3 candidates, 1.7328713280003285 inference time, 1.7328744160004135 total time
event: message
id: 3-p
data: {"id": "3-p", "level_type": "sample", "duration": 1.7328713280003285, "nodes": [{"content": "\u2581to", "parent": 0, "prob": -0.35334891080856323}, {"content": "\u2581is", "parent": 0, "prob": -2.103348970413208}, {"content": "\u2581Cent", "parent": 1, "prob": -1.9243725538253784}], "finished": []}




TOP P AFTER 3: (15585.219612177) 3 candidates, 1.7328713280003285 inference time, 1.7335743469993758 total time

num_batches


1

infer start: GPU memory used: 7508 MB.

batch_candidates
torch.Size([3, 26])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 26, 32064])

hidden_states[-1]
torch.Size([3, 26, 3072])
infer - after batch run: GPU memory used: 7676 MB.

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 25])


tensor([], device='cuda:0', size=(0, 25), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 26])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-0.3533, -2.1033, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 8.6250,  4.4062,  5.5000,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.3438,  2.0469,  1.5781,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.5000, -0.5469,  2.4531,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.9532e-01, 4.6092e-03, 5.8022e-05,  ..., 1.6941e-22, 1.5915e-22,
         6.2324e-23],
        [4.8121e-01, 4.8121e-01, 2.7148e-02,  ..., 1.2813e-21, 1.7340e-22,
         9.2814e-23],
        [9.9994e-01, 4.5397e-05, 7.8888e-06,  ..., 1.3971e-23, 4.0027e-24,
         1.7587e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.9953, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.4812, 0.9624, 0.9896,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 1, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([4, 26])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([4])


tensor([-0.3533, -2.1033, -2.1033, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([4, 1])


tensor([[11563],
        [  278],
        [ 1019],
        [29874]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-4.6923e-03, -7.3145e-01, -7.3145e-01, -6.1633e-05], device='cuda:0')


new_candidates
torch.Size([4, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


new_candidate_logprobs
torch.Size([4])


tensor([-0.3580, -2.8348, -2.8348, -1.9244], device='cuda:0')

TOP P PRIOR 4: (15586.987227756) 4 candidates, 1.7676017569992837 inference time, 1.7676067009997496 total time
event: message
id: 4-p
data: {"id": "4-p", "level_type": "sample", "duration": 1.7676017569992837, "nodes": [{"content": "\u2581Earth", "parent": 0, "prob": -0.3580411970615387}, {"content": "\u2581the", "parent": 1, "prob": -2.834794282913208}, {"content": "\u2581Pro", "parent": 1, "prob": -2.834794282913208}, {"content": "a", "parent": 2, "prob": -1.9244341850280762}], "finished": []}




TOP P AFTER 4: (15586.987813126) 4 candidates, 1.7676017569992837 inference time, 1.7681917049994809 total time

num_batches


1

infer start: GPU memory used: 7676 MB.

batch_candidates
torch.Size([4, 27])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 27, 32064])

hidden_states[-1]
torch.Size([4, 27, 3072])
infer - after batch run: GPU memory used: 7686 MB.

candidates
torch.Size([4, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


candidate_logprobs
torch.Size([4])


tensor([-0.3580, -2.8348, -2.8348, -1.9244], device='cuda:0')


embeddings
torch.Size([4, 3072])


tensor([[ 0.6406, -0.9219,  1.1172,  ...,  1.5547, -1.2734,  0.4902],
        [ 1.4297,  0.7031,  1.1250,  ...,  0.9375, -0.9375,  0.4629],
        [-2.2188,  0.7305, -0.4473,  ..., -1.6953, -0.0540,  0.2793],
        [ 0.7383, -1.8281,  1.8594,  ..., -2.5469,  0.9141, -0.3320]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([ True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.6406, -0.9219,  1.1172,  ...,  1.5547, -1.2734,  0.4902]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[1.0000],
        [1.7266],
        [1.8984],
        [1.9297]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.0000, 1.7266, 1.8984, 1.9297], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([0.0000, 1.7266, 1.8984, 1.9297], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True, False, False,  True], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[ 0.6406, -0.9219,  1.1172,  ...,  1.5547, -1.2734,  0.4902],
        [ 0.7383, -1.8281,  1.8594,  ..., -2.5469,  0.9141, -0.3320]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 2])


tensor([[1.0000, 1.9297],
        [1.7266, 1.9297],
        [1.8984, 1.9531],
        [1.9297, 0.9922]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.0000, 1.7266, 1.8984, 0.9922], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([0.0000, 1.7266, 1.8984, 0.0000], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True, False,  True,  True], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[ 0.6406, -0.9219,  1.1172,  ...,  1.5547, -1.2734,  0.4902],
        [-2.2188,  0.7305, -0.4473,  ..., -1.6953, -0.0540,  0.2793],
        [ 0.7383, -1.8281,  1.8594,  ..., -2.5469,  0.9141, -0.3320]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 3])


tensor([[1.0000, 1.8984, 1.9297],
        [1.7266, 1.8594, 1.9297],
        [1.8984, 1.0000, 1.9531],
        [1.9297, 1.9531, 0.9922]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([0, 0, 1, 2], device='cuda:0')


new_candidates
torch.Size([3, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


new_candidate_parents


[0, 2, 3]


new_candidate_aunts


[[1], [2], [3]]


new_candidate_logprobs
torch.Size([3])


tensor([-3.1928, -2.8348, -1.9244], device='cuda:0')

F NEIGHBORS PRIOR 5: (15588.744817004) 3 candidates, 1.756981631000599 inference time, 1.7569874880009593 total time
event: message
id: f-5"
data: {"id": "f-5", "level_type": "gather", "duration": 1.756981631000599, "nodes": [{"content": "\u2581Earth", "parent": 0, "aunts": [1], "prob": -3.192835569381714}, {"content": "\u2581Pro", "parent": 2, "aunts": [2], "prob": -2.834794282913208}, {"content": "a", "parent": 3, "aunts": [3], "prob": -1.9244341850280762}]}




F NEIGHBORS AFTER 5: (15588.745405089) 3 candidates, 1.756981631000599 inference time, 1.7575752150005428 total time

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 26])


tensor([], device='cuda:0', size=(0, 26), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-3.1928, -2.8348, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 4.8438,  1.5391, -1.1562,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.2422,  1.3594,  0.7656,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.5625, -3.9688,  0.3242,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.9806e-01, 1.7003e-03, 1.3957e-04,  ..., 6.5870e-24, 6.1880e-24,
         4.5272e-24],
        [1.0000e+00, 1.2752e-07, 4.5991e-10,  ..., 9.9224e-27, 6.0183e-27,
         4.3596e-28],
        [9.9994e-01, 5.1442e-05, 7.8888e-06,  ..., 2.0592e-28, 1.2490e-28,
         8.5840e-29]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.9981, 0.9998, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([3, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-3.1928, -2.8348, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[ 338],
        [2657],
        [5338]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.9372e-03, -1.1921e-07, -6.2229e-05], device='cuda:0')


new_candidates
torch.Size([3, 28])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-3.1948, -2.8348, -1.9245], device='cuda:0')

TOP P PRIOR 5: (15588.788267388) 3 candidates, 1.8004334410015872 inference time, 1.8004367370012915 total time
event: message
id: 5-p
data: {"id": "5-p", "level_type": "sample", "duration": 1.8004334410015872, "nodes": [{"content": "\u2581is", "parent": 0, "prob": -3.194772720336914}, {"content": "xim", "parent": 1, "prob": -2.834794521331787}, {"content": "uri", "parent": 2, "prob": -1.9244964122772217}], "finished": []}




TOP P AFTER 5: (15588.788617703) 3 candidates, 1.8004334410015872 inference time, 1.800786332001735 total time

num_batches


1

infer start: GPU memory used: 7686 MB.

batch_candidates
torch.Size([3, 28])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 28, 32064])

hidden_states[-1]
torch.Size([3, 28, 3072])
infer - after batch run: GPU memory used: 7676 MB.

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 27])


tensor([], device='cuda:0', size=(0, 27), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 28])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338]],
       device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-3.1948, -2.8348, -1.9245], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 6.0000,  2.2031,  1.1953,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.7656,  0.0420, -1.2969,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.6250, -0.1963,  8.1250,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[5.2175e-01, 4.6044e-01, 1.3904e-02,  ..., 3.0997e-22, 3.0997e-22,
         1.0063e-22],
        [1.0000e+00, 1.1861e-08, 1.6052e-09,  ..., 1.9930e-25, 4.4469e-26,
         3.0563e-26],
        [9.0407e-01, 9.5289e-02, 2.3620e-04,  ..., 3.6190e-24, 2.1950e-24,
         1.7095e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.5217, 0.9822, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9041, 0.9994, 0.9996,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([4, 28])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([4])


tensor([-3.1948, -3.1948, -2.8348, -1.9245], device='cuda:0')


new_candidate_toks
torch.Size([4, 1])


tensor([[  278],
        [ 1019],
        [29874],
        [  338]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-0.6506, -0.7756,  0.0000, -0.1008], device='cuda:0')


new_candidates
torch.Size([4, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([4])


tensor([-3.8453, -3.9703, -2.8348, -2.0253], device='cuda:0')

TOP P PRIOR 6: (15590.562910444) 4 candidates, 1.7742789870007982 inference time, 1.7742825080003968 total time
event: message
id: 6-p
data: {"id": "6-p", "level_type": "sample", "duration": 1.7742789870007982, "nodes": [{"content": "\u2581the", "parent": 0, "prob": -3.8453469276428223}, {"content": "\u2581Pro", "parent": 0, "prob": -3.9703469276428223}, {"content": "a", "parent": 1, "prob": -2.834794521331787}, {"content": "\u2581is", "parent": 2, "prob": -2.025339365005493}], "finished": []}




TOP P AFTER 6: (15590.563433243) 4 candidates, 1.7742789870007982 inference time, 1.774805135000861 total time

num_batches


1

infer start: GPU memory used: 7676 MB.

batch_candidates
torch.Size([4, 29])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 29, 32064])

hidden_states[-1]
torch.Size([4, 29, 3072])
infer - after batch run: GPU memory used: 7708 MB.

candidates
torch.Size([4, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


candidate_logprobs
torch.Size([4])


tensor([-3.8453, -3.9703, -2.8348, -2.0253], device='cuda:0')


embeddings
torch.Size([4, 3072])


tensor([[ 1.3750,  0.6758,  1.4531,  ...,  1.0156, -0.7930,  0.9414],
        [-2.0781,  0.7969, -0.2637,  ..., -1.6641,  0.0510,  0.3867],
        [-0.8594,  0.9336, -0.1465,  ..., -1.2500, -1.3047, -1.0391],
        [-0.6719,  0.5938,  0.6641,  ...,  0.8320, -0.2178, -1.0312]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([False, False, False,  True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.6719,  0.5938,  0.6641,  ...,  0.8320, -0.2178, -1.0312]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[1.6797],
        [1.8516],
        [1.7812],
        [1.0000]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.6797, 1.8516, 1.7812, 1.0000], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([1.6797, 1.8516, 1.7812, 0.0000], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([False,  True, False,  True], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-2.0781,  0.7969, -0.2637,  ..., -1.6641,  0.0510,  0.3867],
        [-0.6719,  0.5938,  0.6641,  ...,  0.8320, -0.2178, -1.0312]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 2])


tensor([[1.8516, 1.6797],
        [1.0000, 1.8516],
        [1.8203, 1.7812],
        [1.8516, 1.0000]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.6797, 1.0000, 1.7812, 1.0000], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([1.6797, 0.0000, 1.7812, 0.0000], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([False,  True,  True,  True], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[-2.0781,  0.7969, -0.2637,  ..., -1.6641,  0.0510,  0.3867],
        [-0.8594,  0.9336, -0.1465,  ..., -1.2500, -1.3047, -1.0391],
        [-0.6719,  0.5938,  0.6641,  ...,  0.8320, -0.2178, -1.0312]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 3])


tensor([[1.8516, 1.7969, 1.6797],
        [1.0000, 1.8203, 1.8516],
        [1.8203, 1.0000, 1.7812],
        [1.8516, 1.7812, 1.0000]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([2, 0, 1, 2], device='cuda:0')


new_candidates
torch.Size([3, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


new_candidate_parents


[1, 2, 3]


new_candidate_aunts


[[1], [2], [0, 3]]


new_candidate_logprobs
torch.Size([3])


tensor([-3.9703, -2.8348, -5.8707], device='cuda:0')

F NEIGHBORS PRIOR 7: (15592.349331655) 3 candidates, 1.7858830530003615 inference time, 1.78588647000106 total time
event: message
id: f-7"
data: {"id": "f-7", "level_type": "gather", "duration": 1.7858830530003615, "nodes": [{"content": "\u2581Pro", "parent": 1, "aunts": [1], "prob": -3.9703469276428223}, {"content": "a", "parent": 2, "aunts": [2], "prob": -2.834794521331787}, {"content": "\u2581is", "parent": 3, "aunts": [0, 3], "prob": -5.8706865310668945}]}




F NEIGHBORS AFTER 7: (15592.349815211) 3 candidates, 1.7858830530003615 inference time, 1.7863694190000388 total time

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 28])


tensor([], device='cuda:0', size=(0, 28), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-3.9703, -2.8348, -5.8707], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 1.7734,  1.9688,  1.1250,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.2812, -1.7500,  0.4062,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.1875,  0.4727,  1.8750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[1.0000e+00, 1.4450e-07, 4.5991e-10,  ..., 8.7565e-27, 4.1363e-27,
         7.1878e-28],
        [1.0000e+00, 1.7603e-06, 4.4508e-07,  ..., 4.4469e-26, 4.4469e-26,
         2.6972e-26],
        [4.9547e-01, 4.9547e-01, 7.0675e-03,  ..., 8.5175e-22, 5.8540e-22,
         2.0231e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.4955, 0.9909, 0.9980,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 1, 2, 2], device='cuda:0')


carryover_candidates
torch.Size([4, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([4])


tensor([-3.9703, -2.8348, -5.8707, -5.8707], device='cuda:0')


new_candidate_toks
torch.Size([4, 1])


tensor([[ 2657],
        [ 2895],
        [  278],
        [21438]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-1.1921e-07, -2.3842e-06, -7.0225e-01, -7.0225e-01], device='cuda:0')


new_candidates
torch.Size([4, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338, 21438]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([4])


tensor([-3.9703, -2.8348, -6.5729, -6.5729], device='cuda:0')

TOP P PRIOR 7: (15592.396236506) 4 candidates, 1.8327876920011477 inference time, 1.8327914540004713 total time
event: message
id: 7-p
data: {"id": "7-p", "level_type": "sample", "duration": 1.8327876920011477, "nodes": [{"content": "xim", "parent": 0, "prob": -3.9703471660614014}, {"content": "\u2581Cent", "parent": 1, "prob": -2.834796905517578}, {"content": "\u2581the", "parent": 2, "prob": -6.572933197021484}, {"content": "\u2581closest", "parent": 2, "prob": -6.572933197021484}], "finished": []}




TOP P AFTER 7: (15592.396660968) 4 candidates, 1.8327876920011477 inference time, 1.8332150250007544 total time

num_batches


1

infer start: GPU memory used: 7708 MB.

batch_candidates
torch.Size([4, 30])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 30, 32064])

hidden_states[-1]
torch.Size([4, 30, 3072])
infer - after batch run: GPU memory used: 7708 MB.

candidates
torch.Size([4, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338, 21438]],
       device='cuda:0')


candidate_logprobs
torch.Size([4])


tensor([-3.9703, -2.8348, -6.5729, -6.5729], device='cuda:0')


embeddings
torch.Size([4, 3072])


tensor([[-1.7344,  0.0605,  1.5000,  ..., -3.1875, -1.2266, -0.0317],
        [-0.1494, -0.2168,  1.4688,  ..., -1.5000,  0.6289,  0.2715],
        [-1.1875,  1.0938,  0.6406,  ..., -0.6055,  0.1279,  0.6992],
        [-0.0466, -0.2246,  2.1406,  ..., -0.5820,  0.0253,  0.4961]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([False,  True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.1494, -0.2168,  1.4688,  ..., -1.5000,  0.6289,  0.2715]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[1.7422],
        [1.0000],
        [1.9141],
        [1.8516]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.7422, 1.0000, 1.9141, 1.8516], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([1.7422, 0.0000, 1.9141, 1.8516], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([False,  True,  True, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-0.1494, -0.2168,  1.4688,  ..., -1.5000,  0.6289,  0.2715],
        [-1.1875,  1.0938,  0.6406,  ..., -0.6055,  0.1279,  0.6992]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 2])


tensor([[1.7422, 1.8594],
        [1.0000, 1.9141],
        [1.9141, 0.9922],
        [1.8516, 1.5391]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.7422, 1.0000, 0.9922, 1.5391], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([1.7422, 0.0000, 0.0000, 1.5391], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True,  True,  True, False], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[-1.7344,  0.0605,  1.5000,  ..., -3.1875, -1.2266, -0.0317],
        [-0.1494, -0.2168,  1.4688,  ..., -1.5000,  0.6289,  0.2715],
        [-1.1875,  1.0938,  0.6406,  ..., -0.6055,  0.1279,  0.6992]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 3])


tensor([[1.0000, 1.7422, 1.8594],
        [1.7422, 1.0000, 1.9141],
        [1.8594, 1.9141, 0.9922],
        [1.8047, 1.8516, 1.5391]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([0, 1, 2, 2], device='cuda:0')


new_candidates
torch.Size([3, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278]],
       device='cuda:0')


new_candidate_parents


[0, 1, 2]


new_candidate_aunts


[[], [], [3]]


new_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8348, -13.1459], device='cuda:0')

F NEIGHBORS PRIOR 8: (15594.159166329) 3 candidates, 1.7624914389998594 inference time, 1.7624969930002408 total time
event: message
id: f-8"
data: {"id": "f-8", "level_type": "gather", "duration": 1.7624914389998594, "nodes": [{"content": "xim", "parent": 0, "aunts": [], "prob": -3.9703471660614014}, {"content": "\u2581Cent", "parent": 1, "aunts": [], "prob": -2.834796905517578}, {"content": "\u2581the", "parent": 2, "aunts": [3], "prob": -13.145866394042969}]}




F NEIGHBORS AFTER 8: (15594.159671412) 3 candidates, 1.7624914389998594 inference time, 1.7630004940001527 total time

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 29])


tensor([], device='cuda:0', size=(0, 29), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278]],
       device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8348, -13.1459], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 2.8594,  1.0391, -1.3203,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.0625, -0.7695,  2.0469,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.8438,  3.7500,  1.8359,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[1.0000e+00, 9.2374e-09, 1.4166e-09,  ..., 3.2859e-25, 8.3079e-26,
         3.0563e-26],
        [9.9990e-01, 7.4844e-05, 1.4738e-05,  ..., 2.4277e-24, 7.8815e-25,
         4.7804e-25],
        [9.9478e-01, 5.2201e-03, 2.2485e-06,  ..., 1.1409e-24, 1.0068e-24,
         7.3661e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9948, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([3, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8348, -13.1459], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[29874],
        [29874],
        [21438]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([ 0.0000, -0.0001, -0.0052], device='cuda:0')


new_candidates
torch.Size([3, 31])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.1511], device='cuda:0')

TOP P PRIOR 8: (15594.206307384) 3 candidates, 1.8096334080000815 inference time, 1.8096370570001454 total time
event: message
id: 8-p
data: {"id": "8-p", "level_type": "sample", "duration": 1.8096334080000815, "nodes": [{"content": "a", "parent": 0, "prob": -3.9703471660614014}, {"content": "a", "parent": 1, "prob": -2.8349013328552246}, {"content": "\u2581closest", "parent": 2, "prob": -13.151103973388672}], "finished": []}




TOP P AFTER 8: (15594.206673252) 3 candidates, 1.8096334080000815 inference time, 1.810001933001331 total time

num_batches


1

infer start: GPU memory used: 7708 MB.

batch_candidates
torch.Size([3, 31])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 31, 32064])

hidden_states[-1]
torch.Size([3, 31, 3072])
infer - after batch run: GPU memory used: 7686 MB.

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 30])


tensor([], device='cuda:0', size=(0, 30), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 31])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.1511], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 7.0938, -1.5312,  0.2598,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.1250, -5.2188,  0.1982,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.6094, -2.3594, -0.8086,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[1.0000e+00, 1.3710e-06, 2.3824e-07,  ..., 3.4633e-26, 3.0563e-26,
         1.4437e-26],
        [9.9997e-01, 2.7536e-05, 3.7265e-06,  ..., 4.9399e-28, 3.8472e-28,
         2.6441e-28],
        [9.1932e-01, 7.5463e-02, 4.8242e-03,  ..., 4.1700e-24, 2.8660e-24,
         2.5293e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9193, 0.9948, 0.9996,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([3, 31])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.1511], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[2895],
        [5338],
        [5810]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.7881e-06, -3.4214e-05, -8.4119e-02], device='cuda:0')


new_candidates
torch.Size([3, 32])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438,  5810]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.2352], device='cuda:0')

TOP P PRIOR 9: (15595.950170498) 3 candidates, 1.7434830860001966 inference time, 1.7434882509987801 total time
event: message
id: 9-p
data: {"id": "9-p", "level_type": "sample", "duration": 1.7434830860001966, "nodes": [{"content": "\u2581Cent", "parent": 0, "prob": -3.970349073410034}, {"content": "uri", "parent": 1, "prob": -2.8349356651306152}, {"content": "\u2581star", "parent": 2, "prob": -13.235222816467285}], "finished": []}




TOP P AFTER 9: (15595.950594394) 3 candidates, 1.7434830860001966 inference time, 1.7439103319993592 total time

num_batches


1

infer start: GPU memory used: 7686 MB.

batch_candidates
torch.Size([3, 32])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 32, 32064])

hidden_states[-1]
torch.Size([3, 32, 3072])
infer - after batch run: GPU memory used: 7686 MB.

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 31])


tensor([], device='cuda:0', size=(0, 31), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 32])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438,  5810]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.2352], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 5.2812, -0.8867,  1.7891,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.2500,  1.0938,  5.8750,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.5312, -2.4062,  2.2344,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.9986e-01, 1.0889e-04, 2.1442e-05,  ..., 3.5321e-24, 1.2994e-24,
         1.1467e-24],
        [6.7824e-01, 3.2038e-01, 5.4580e-04,  ..., 2.0061e-23, 2.0061e-23,
         5.0723e-24],
        [7.9918e-01, 1.7832e-01, 1.8795e-02,  ..., 1.8593e-22, 1.7467e-22,
         1.0594e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.6782, 0.9986, 0.9992,  ..., 1.0000, 1.0000, 1.0000],
        [0.7992, 0.9775, 0.9963,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 1, 1, 2, 2], device='cuda:0')


carryover_candidates
torch.Size([5, 32])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  


carryover_candidate_logprobs
torch.Size([5])


tensor([ -3.9703,  -2.8349,  -2.8349, -13.2352, -13.2352], device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[29874],
        [29889],
        [29892],
        [  304],
        [29889]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-1.4402e-04, -3.8825e-01, -1.1383e+00, -2.2417e-01, -1.7242e+00],
       device='cuda:0')


new_candidates
torch.Size([5, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 2


new_candidate_logprobs
torch.Size([5])


tensor([ -3.9705,  -3.2232,  -3.9732, -13.4594, -14.9594], device='cuda:0')

TOP P PRIOR 10: (15597.707892949) 5 candidates, 1.757284855000762 inference time, 1.757289495000805 total time
event: message
id: 10-p
data: {"id": "10-p", "level_type": "sample", "duration": 1.757284855000762, "nodes": [{"content": "a", "parent": 0, "prob": -3.9704930782318115}, {"content": ".", "parent": 1, "prob": -3.223188877105713}, {"content": ",", "parent": 1, "prob": -3.973188877105713}, {"content": "\u2581to", "parent": 2, "prob": -13.459388732910156}, {"content": ".", "parent": 2, "prob": -14.959388732910156}], "finished": []}




TOP P AFTER 10: (15597.70840233) 5 candidates, 1.757284855000762 inference time, 1.7577971369992156 total time

num_batches


1

infer start: GPU memory used: 7686 MB.

batch_candidates
torch.Size([5, 33])

batch_candidate_logprobs
torch.Size([5])

batch_logits
torch.Size([5, 33, 32064])

hidden_states[-1]
torch.Size([5, 33, 3072])
infer - after batch run: GPU memory used: 7710 MB.

candidates
torch.Size([5, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 2


candidate_logprobs
torch.Size([5])


tensor([ -3.9705,  -3.2232,  -3.9732, -13.4594, -14.9594], device='cuda:0')


embeddings
torch.Size([5, 3072])


tensor([[ 0.6641, -1.5938,  2.0781,  ..., -2.4062,  1.1016, -0.1787],
        [-0.5273, -0.2432,  0.3477,  ..., -0.8438, -0.4629, -2.0000],
        [-0.3105, -0.4727,  0.3379,  ..., -0.9062,  1.7656, -0.8398],
        [ 1.0469,  1.0547,  1.4219,  ..., -1.5078, -0.0215, -0.1582],
        [-0.4785, -0.3008,  0.7031,  ..., -2.0312, -0.3926, -1.9688]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([5])


tensor([False,  True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.5273, -0.2432,  0.3477,  ..., -0.8438, -0.4629, -2.0000]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 1])


tensor([[1.9531],
        [1.0000],
        [1.6406],
        [1.7422],
        [1.0781]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([5])


tensor([1.9531, 1.0000, 1.6406, 1.7422, 1.0781], device='cuda:0',
       dtype=torch.bfloat16)


min_remaining_distances
torch.Size([5])


tensor([1.9531, 0.0000, 1.6406, 1.7422, 1.0781], device='cuda:0',
       dtype=torch.bfloat16)


selected (end of loop)
torch.Size([5])


tensor([ True,  True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[ 0.6641, -1.5938,  2.0781,  ..., -2.4062,  1.1016, -0.1787],
        [-0.5273, -0.2432,  0.3477,  ..., -0.8438, -0.4629, -2.0000]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 2])


tensor([[1.0000, 1.9531],
        [1.9531, 1.0000],
        [1.8828, 1.6406],
        [1.9141, 1.7422],
        [1.9531, 1.0781]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([5])


tensor([1.0000, 1.0000, 1.6406, 1.7422, 1.0781], device='cuda:0',
       dtype=torch.bfloat16)


min_remaining_distances
torch.Size([5])


tensor([0.0000, 0.0000, 1.6406, 1.7422, 1.0781], device='cuda:0',
       dtype=torch.bfloat16)


selected (end of loop)
torch.Size([5])


tensor([ True,  True, False,  True, False], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[ 0.6641, -1.5938,  2.0781,  ..., -2.4062,  1.1016, -0.1787],
        [-0.5273, -0.2432,  0.3477,  ..., -0.8438, -0.4629, -2.0000],
        [ 1.0469,  1.0547,  1.4219,  ..., -1.5078, -0.0215, -0.1582]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 3])


tensor([[1.0000, 1.9531, 1.9141],
        [1.9531, 1.0000, 1.7422],
        [1.8828, 1.6406, 1.5781],
        [1.9141, 1.7422, 1.0000],
        [1.9531, 1.0781, 1.7656]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([5])


tensor([0, 1, 2, 2, 1], device='cuda:0')


new_candidates
torch.Size([3, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438,  5810,   304]], device='cuda:0')


new_candidate_parents


[0, 1, 3]


new_candidate_aunts


[[], [4], [3]]


new_candidate_logprobs
torch.Size([3])


tensor([ -3.9705, -18.1826, -17.4326], device='cuda:0')

F NEIGHBORS PRIOR 11: (15599.565358785) 3 candidates, 1.856943021000916 inference time, 1.856946390000303 total time
event: message
id: f-11"
data: {"id": "f-11", "level_type": "gather", "duration": 1.856943021000916, "nodes": [{"content": "a", "parent": 0, "aunts": [], "prob": -3.9704930782318115}, {"content": ".", "parent": 1, "aunts": [4], "prob": -18.18257713317871}, {"content": "\u2581to", "parent": 3, "aunts": [3], "prob": -17.43257713317871}]}




F NEIGHBORS AFTER 11: (15599.565777447) 3 candidates, 1.856943021000916 inference time, 1.8573642380015372 total time

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 32])


tensor([], device='cuda:0', size=(0, 32), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438,  5810,   304]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([ -3.9705, -18.1826, -17.4326], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 6.1250e+00, -5.6562e+00,  6.5002e-03,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 6.9688e+00,  4.2812e+00,  1.3562e+01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 8.0000e+00,  4.5000e+00,  5.2812e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.9998e-01, 1.8925e-05, 3.2887e-06,  ..., 2.6442e-28, 2.6442e-28,
         2.0593e-28],
        [8.1523e-01, 1.8190e-01, 1.3888e-03,  ..., 8.2263e-19, 4.6872e-19,
         3.4292e-19],
        [9.7638e-01, 2.2962e-02, 5.4002e-04,  ..., 2.1552e-21, 2.0246e-21,
         5.8006e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8152, 0.9971, 0.9985,  ..., 1.0000, 1.0000, 1.0000],
        [0.9764, 0.9993, 0.9999,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 1, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([4, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 2


carryover_candidate_logprobs
torch.Size([4])


tensor([ -3.9705, -18.1826, -18.1826, -17.4326], device='cuda:0')


new_candidate_toks
torch.Size([4, 1])


tensor([[ 5338],
        [   13],
        [32007],
        [11563]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-2.4557e-05, -2.0428e-01, -1.7043e+00, -2.3903e-02], device='cuda:0')


new_candidates
torch.Size([4, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 3


new_candidate_logprobs
torch.Size([4])


tensor([ -3.9705, -18.3869, -19.8869, -17.4565], device='cuda:0')

TOP P PRIOR 11: (15599.607423129) 4 candidates, 1.8990063890014426 inference time, 1.8990121010010625 total time
event: message
id: 11-p
data: {"id": "11-p", "level_type": "sample", "duration": 1.8990063890014426, "nodes": [{"content": "uri", "parent": 0, "prob": -3.970517635345459}, {"content": "<0x0A>", "parent": 1, "prob": -18.386856079101562}, {"content": "<|end|>", "parent": 1, "prob": -19.886856079101562}, {"content": "\u2581Earth", "parent": 2, "prob": -17.456480026245117}], "finished": []}




TOP P AFTER 11: (15599.60795931) 4 candidates, 1.8990063890014426 inference time, 1.8995477740008937 total time

num_batches


1

infer start: GPU memory used: 7710 MB.

batch_candidates
torch.Size([4, 34])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 34, 32064])

hidden_states[-1]
torch.Size([4, 34, 3072])
infer - after batch run: GPU memory used: 7708 MB.

candidates
torch.Size([4, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 3


candidate_logprobs
torch.Size([4])


tensor([ -3.9705, -18.3869, -19.8869, -17.4565], device='cuda:0')


embeddings
torch.Size([4, 3072])


tensor([[ 1.0469, -0.2402, -0.2207,  ..., -0.0081,  0.8711,  1.2578],
        [ 1.2812, -0.2334, -3.3906,  ...,  0.3086, -0.3672, -0.1650],
        [-0.5469, -0.1416,  1.0156,  ..., -1.1797,  0.2246,  1.2578],
        [ 1.0234,  0.7031,  0.9375,  ..., -0.0227,  0.1191, -0.3457]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([ True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 1.0469, -0.2402, -0.2207,  ..., -0.0081,  0.8711,  1.2578]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[0.9922],
        [1.9141],
        [2.0312],
        [1.2812]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([0.9922, 1.9141, 2.0312, 1.2812], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([0.0000, 1.9141, 2.0312, 1.2812], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True, False,  True, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[ 1.0469, -0.2402, -0.2207,  ..., -0.0081,  0.8711,  1.2578],
        [-0.5469, -0.1416,  1.0156,  ..., -1.1797,  0.2246,  1.2578]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 2])


tensor([[0.9922, 2.0312],
        [1.9141, 1.7031],
        [2.0312, 1.0000],
        [1.2812, 2.0469]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([0.9922, 1.7031, 1.0000, 1.2812], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([0.0000, 1.7031, 0.0000, 1.2812], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True,  True,  True, False], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[ 1.0469, -0.2402, -0.2207,  ..., -0.0081,  0.8711,  1.2578],
        [ 1.2812, -0.2334, -3.3906,  ...,  0.3086, -0.3672, -0.1650],
        [-0.5469, -0.1416,  1.0156,  ..., -1.1797,  0.2246,  1.2578]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 3])


tensor([[0.9922, 1.9141, 2.0312],
        [1.9141, 1.0000, 1.7031],
        [2.0312, 1.7031, 1.0000],
        [1.2812, 1.9375, 2.0469]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([0, 1, 2, 0], device='cuda:0')


new_candidates
torch.Size([3, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889, 32007]], device='cuda:0')


new_candidate_parents


[0, 1, 2]


new_candidate_aunts


[[3], [], []]


new_candidate_logprobs
torch.Size([3])


tensor([-21.4270, -18.3869, -19.8869], device='cuda:0')

F NEIGHBORS PRIOR 12: (15601.499876418) 3 candidates, 1.8918919160005316 inference time, 1.891897355000765 total time
event: message
id: f-12"
data: {"id": "f-12", "level_type": "gather", "duration": 1.8918919160005316, "nodes": [{"content": "uri", "parent": 0, "aunts": [3], "prob": -21.426998138427734}, {"content": "<0x0A>", "parent": 1, "aunts": [], "prob": -18.386856079101562}, {"content": "<|end|>", "parent": 2, "aunts": [], "prob": -19.886856079101562}]}




F NEIGHBORS AFTER 12: (15601.500383505) 3 candidates, 1.8918919160005316 inference time, 1.8924029730005714 total time

finished_mask
torch.Size([3])


tensor([False, False,  True], device='cuda:0')


finished
torch.Size([1, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889]], device='cuda:0')


finished_parents
torch.Size([1])


tensor([2], device='cuda:0')


finished_logprobs
torch.Size([1])


tensor([-19.8869], device='cuda:0')


candidates
torch.Size([2, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-21.4270, -18.3869], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 8.1250,  1.5156,  5.6562,  ...,  0.0000,  0.0000,  0.0000],
        [-4.1250,  1.9844,  6.5625,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.3942e-01, 6.0055e-02, 1.9114e-04,  ..., 2.4521e-23, 1.4873e-23,
         1.3125e-23],
        [9.9899e-01, 2.6099e-04, 2.3033e-04,  ..., 8.8083e-20, 6.0538e-20,
         4.7147e-20]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.9394, 0.9995, 0.9997,  ..., 1.0000, 1.0000, 1.0000],
        [0.9990, 0.9993, 0.9995,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


carryover_candidates
torch.Size([2, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-21.4270, -18.3869], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[29889],
        [   13]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.0625, -0.0010], device='cuda:0')


new_candidates
torch.Size([2, 35])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-21.4895, -18.3879], device='cuda:0')

TOP P PRIOR 12: (15601.542900308) 2 candidates, 1.934916811000221 inference time, 1.9349204840000311 total time
event: message
id: 12-p
data: {"id": "12-p", "level_type": "sample", "duration": 1.934916811000221, "nodes": [{"content": ".", "parent": 0, "prob": -21.489492416381836}, {"content": "<0x0A>", "parent": 1, "prob": -18.38786506652832}], "finished": [{"content": "What is the closest star to the Earth? Answer in 5 words or less.  The closest star is Proxima Centauri.", "parent": 2, "prob": -19.886856079101562}]}




TOP P AFTER 12: (15601.543484503) 2 candidates, 1.934916811000221 inference time, 1.935504455001137 total time

num_batches


1

infer start: GPU memory used: 7708 MB.

batch_candidates
torch.Size([2, 35])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 35, 32064])

hidden_states[-1]
torch.Size([2, 35, 3072])
infer - after batch run: GPU memory used: 7674 MB.

finished_mask
torch.Size([2])


tensor([False, False], device='cuda:0')


finished
torch.Size([0, 34])


tensor([], device='cuda:0', size=(0, 34), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([2, 35])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-21.4895, -18.3879], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 7.0312,  4.0312, 13.3750,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.6562,  5.5312,  3.9531,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[5.8736e-01, 4.0369e-01, 6.5250e-03,  ..., 1.3357e-18, 1.1787e-18,
         9.7719e-19],
        [4.6826e-01, 3.2183e-01, 8.1371e-02,  ..., 3.1118e-17, 2.4235e-17,
         2.4235e-17]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.5874, 0.9911, 0.9976,  ..., 1.0000, 1.0000, 1.0000],
        [0.4683, 0.7901, 0.8715,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([6])


tensor([0, 0, 1, 1, 1, 1], device='cuda:0')


carryover_candidates
torch.Size([6, 35])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 298


carryover_candidate_logprobs
torch.Size([6])


tensor([-21.4895, -21.4895, -18.3879, -18.3879, -18.3879, -18.3879],
       device='cuda:0')


new_candidate_toks
torch.Size([6, 1])


tensor([[32007],
        [   13],
        [17245],
        [29898],
        [ 9842],
        [   13]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([6])


tensor([-0.5321, -0.9071, -0.7587, -1.1337, -2.5087, -3.5087], device='cuda:0')


new_candidates
torch.Size([6, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   4


new_candidate_logprobs
torch.Size([6])


tensor([-22.0216, -22.3966, -19.1466, -19.5216, -20.8966, -21.8966],
       device='cuda:0')

TOP P PRIOR 13: (15603.325859732) 6 candidates, 1.7823591209998995 inference time, 1.7823631909996038 total time
event: message
id: 13-p
data: {"id": "13-p", "level_type": "sample", "duration": 1.7823591209998995, "nodes": [{"content": "<|end|>", "parent": 0, "prob": -22.021602630615234}, {"content": "<0x0A>", "parent": 0, "prob": -22.396602630615234}, {"content": "However", "parent": 1, "prob": -19.146604537963867}, {"content": "(", "parent": 1, "prob": -19.521604537963867}, {"content": "Note", "parent": 1, "prob": -20.896604537963867}, {"content": "<0x0A>", "parent": 1, "prob": -21.896604537963867}], "finished": []}




TOP P AFTER 13: (15603.326373881) 6 candidates, 1.7823591209998995 inference time, 1.7828764710011455 total time

num_batches


1

infer start: GPU memory used: 7674 MB.

batch_candidates
torch.Size([6, 36])

batch_candidate_logprobs
torch.Size([6])

batch_logits
torch.Size([6, 36, 32064])

hidden_states[-1]
torch.Size([6, 36, 3072])
infer - after batch run: GPU memory used: 7724 MB.

candidates
torch.Size([6, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   4


candidate_logprobs
torch.Size([6])


tensor([-22.0216, -22.3966, -19.1466, -19.5216, -20.8966, -21.8966],
       device='cuda:0')


embeddings
torch.Size([6, 3072])


tensor([[-0.1602, -0.4160,  0.7227,  ..., -2.5156,  0.5859,  1.1953],
        [ 1.3984, -0.3574, -2.6094,  ...,  0.4004, -0.6328,  0.2100],
        [ 0.7734, -0.3867,  0.4473,  ..., -0.0537,  0.4902,  0.4688],
        [ 1.9766, -0.0654,  0.3828,  ...,  0.1328,  0.6016, -0.8438],
        [ 1.3281,  0.4570,  0.6094,  ...,  0.9766,  0.4531, -1.8516],
        [-0.1104,  0.1338,  0.8242,  ...,  0.4766,  0.2578, -0.8750]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([6])


tensor([False, False,  True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.7734, -0.3867,  0.4473,  ..., -0.0537,  0.4902,  0.4688]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([6, 1])


tensor([[2.0469],
        [1.9297],
        [1.0000],
        [1.6484],
        [1.5625],
        [1.6875]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([6])


tensor([2.0469, 1.9297, 1.0000, 1.6484, 1.5625, 1.6875], device='cuda:0',
       dtype=torch.bfloat16)


min_remaining_distances
torch.Size([6])


tensor([2.0469, 1.9297, 0.0000, 1.6484, 1.5625, 1.6875], device='cuda:0',
       dtype=torch.bfloat16)


selected (end of loop)
torch.Size([6])


tensor([ True, False,  True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-0.1602, -0.4160,  0.7227,  ..., -2.5156,  0.5859,  1.1953],
        [ 0.7734, -0.3867,  0.4473,  ..., -0.0537,  0.4902,  0.4688]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([6, 2])


tensor([[1.0000, 2.0469],
        [1.7500, 1.9297],
        [2.0469, 1.0000],
        [2.0156, 1.6484],
        [2.0000, 1.5625],
        [1.9766, 1.6875]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([6])


tensor([0, 0, 1, 1, 1, 1], device='cuda:0')


new_candidates
torch.Size([2, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245]], device='cuda:0')


new_candidate_parents


[0, 2]


new_candidate_aunts


[[1], [2, 3, 4, 5]]


new_candidate_logprobs
torch.Size([2])


tensor([-44.4182, -81.4614], device='cuda:0')

F NEIGHBORS PRIOR 14: (15605.362858911) 2 candidates, 2.0364689559992257 inference time, 2.03647463499874 total time
event: message
id: f-14"
data: {"id": "f-14", "level_type": "gather", "duration": 2.0364689559992257, "nodes": [{"content": "<|end|>", "parent": 0, "aunts": [1], "prob": -44.41820526123047}, {"content": "However", "parent": 2, "aunts": [2, 3, 4, 5], "prob": -81.46141815185547}]}




F NEIGHBORS AFTER 14: (15605.363453154) 2 candidates, 2.0364689559992257 inference time, 2.037068150999403 total time

finished_mask
torch.Size([2])


tensor([ True, False], device='cuda:0')


finished
torch.Size([1, 35])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889]], device='cuda:0')


finished_parents
torch.Size([1])


tensor([0], device='cuda:0')


finished_logprobs
torch.Size([1])


tensor([-44.4182], device='cuda:0')


candidates
torch.Size([1, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-81.4614], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[3.6094, 5.0625, 0.8008,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9999e-01, 2.5613e-06, 2.2603e-06,  ..., 3.4528e-20, 2.6891e-20,
         2.2293e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-81.4614], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[29892]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-1.3232e-05], device='cuda:0')


new_candidates
torch.Size([1, 37])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-81.4614], device='cuda:0')

TOP P PRIOR 14: (15605.420440399) 1 candidates, 2.0940503559995705 inference time, 2.0940563539988943 total time
event: message
id: 14-p
data: {"id": "14-p", "level_type": "sample", "duration": 2.0940503559995705, "nodes": [{"content": ",", "parent": 0, "prob": -81.46143341064453}], "finished": [{"content": "What is the closest star to the Earth? Answer in 5 words or less.  The closest star to Earth is Proxima Centauri.", "parent": 0, "prob": -44.41820526123047}]}




TOP P AFTER 14: (15605.421273132) 1 candidates, 2.0940503559995705 inference time, 2.094888858999184 total time

num_batches


1

infer start: GPU memory used: 7724 MB.

batch_candidates
torch.Size([1, 37])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 37, 32064])

hidden_states[-1]
torch.Size([1, 37, 3072])
infer - after batch run: GPU memory used: 7510 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 36])


tensor([], device='cuda:0', size=(0, 36), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 37])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-81.4614], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[5.6250, 7.6562, 5.9375,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[3.7173e-01, 2.8950e-01, 5.7007e-02,  ..., 1.1198e-18, 6.3807e-19,
         6.1844e-19]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.3717, 0.6612, 0.7182,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([5, 37])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 298


carryover_candidate_logprobs
torch.Size([5])


tensor([-81.4614, -81.4614, -81.4614, -81.4614, -81.4614], device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[ 565],
        [ 372],
        [ 304],
        [ 278],
        [1951]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-0.9896, -1.2396, -2.8646, -3.1146, -3.1146], device='cuda:0')


new_candidates
torch.Size([5, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   372],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         299


new_candidate_logprobs
torch.Size([5])


tensor([-82.4510, -82.7010, -84.3260, -84.5760, -84.5760], device='cuda:0')

TOP P PRIOR 15: (15607.164709243) 5 candidates, 1.743413137999596 inference time, 1.7434183709992794 total time
event: message
id: 15-p
data: {"id": "15-p", "level_type": "sample", "duration": 1.743413137999596, "nodes": [{"content": "\u2581if", "parent": 0, "prob": -82.45101928710938}, {"content": "\u2581it", "parent": 0, "prob": -82.70101928710938}, {"content": "\u2581to", "parent": 0, "prob": -84.32601928710938}, {"content": "\u2581the", "parent": 0, "prob": -84.57601928710938}, {"content": "\u2581since", "parent": 0, "prob": -84.57601928710938}], "finished": []}




TOP P AFTER 15: (15607.165234215) 5 candidates, 1.743413137999596 inference time, 1.7439427029985382 total time

num_batches


1

infer start: GPU memory used: 7510 MB.

batch_candidates
torch.Size([5, 38])

batch_candidate_logprobs
torch.Size([5])

batch_logits
torch.Size([5, 38, 32064])

hidden_states[-1]
torch.Size([5, 38, 3072])
infer - after batch run: GPU memory used: 7678 MB.

candidates
torch.Size([5, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   372],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         299


candidate_logprobs
torch.Size([5])


tensor([-82.4510, -82.7010, -84.3260, -84.5760, -84.5760], device='cuda:0')


embeddings
torch.Size([5, 3072])


tensor([[ 0.5547,  0.3086,  0.2773,  ...,  0.6523, -1.6406, -0.6641],
        [ 0.2285,  1.6875,  0.0615,  ...,  0.0505, -1.1641,  0.6055],
        [ 0.2695,  0.9570,  1.3516,  ...,  0.7734, -0.6055,  0.1553],
        [ 0.2637,  0.5703,  0.1875,  ...,  0.0513, -1.0234,  0.1836],
        [-0.5039,  1.5938,  1.5312,  ...,  0.3496, -1.3984, -0.9141]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([5])


tensor([ True, False, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.5547,  0.3086,  0.2773,  ...,  0.6523, -1.6406, -0.6641]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 1])


tensor([[1.0000],
        [1.5781],
        [1.4688],
        [1.4375],
        [1.2812]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2, 3, 4]]


new_candidate_logprobs
torch.Size([1])


tensor([-418.6301], device='cuda:0')

F NEIGHBORS PRIOR 16: (15609.065788245) 1 candidates, 1.9005322649991285 inference time, 1.9005360649989598 total time
event: message
id: f-16"
data: {"id": "f-16", "level_type": "gather", "duration": 1.9005322649991285, "nodes": [{"content": "\u2581if", "parent": 0, "aunts": [1, 2, 3, 4], "prob": -418.6300964355469}]}




F NEIGHBORS AFTER 16: (15609.066157562) 1 candidates, 1.9005322649991285 inference time, 1.9009043639998708 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 37])


tensor([], device='cuda:0', size=(0, 37), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-418.6301], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[6.2188, 4.7500, 7.4688,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[3.5380e-01, 2.7554e-01, 2.7554e-01,  ..., 2.1440e-20, 1.7775e-20,
         1.6698e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.3538, 0.6293, 0.9049,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([3, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-418.6301, -418.6301, -418.6301], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[366],
        [278],
        [591]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.0390, -1.2890, -1.2890], device='cuda:0')


new_candidates
torch.Size([3, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   591]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-419.6691, -419.9191, -419.9191], device='cuda:0')

TOP P PRIOR 16: (15609.109791577) 3 candidates, 1.9445358529992518 inference time, 1.9445386649986176 total time
event: message
id: 16-p
data: {"id": "16-p", "level_type": "sample", "duration": 1.9445358529992518, "nodes": [{"content": "\u2581you", "parent": 0, "prob": -419.6690979003906}, {"content": "\u2581the", "parent": 0, "prob": -419.9190979003906}, {"content": "\u2581we", "parent": 0, "prob": -419.9190979003906}], "finished": []}




TOP P AFTER 16: (15609.110234782) 3 candidates, 1.9445358529992518 inference time, 1.944982125000024 total time

num_batches


1

infer start: GPU memory used: 7678 MB.

batch_candidates
torch.Size([3, 39])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 39, 32064])

hidden_states[-1]
torch.Size([3, 39, 3072])
infer - after batch run: GPU memory used: 7600 MB.

candidates
torch.Size([3, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   591]],
       device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-419.6691, -419.9191, -419.9191], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[-0.8203,  0.5352, -0.2139,  ...,  0.6367, -1.0547,  1.1875],
        [-0.0498,  0.6758,  0.4648,  ...,  0.7617, -2.0469,  0.4004],
        [ 0.2559,  0.1260,  1.2734,  ...,  1.6016, -0.9844,  0.8320]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.8203,  0.5352, -0.2139,  ...,  0.6367, -1.0547,  1.1875]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[1.0000],
        [1.4688],
        [1.1641]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2]]


new_candidate_logprobs
torch.Size([1])


tensor([-1259.5073], device='cuda:0')

F NEIGHBORS PRIOR 17: (15610.924739611) 1 candidates, 1.8144888229999196 inference time, 1.8144934629999625 total time
event: message
id: f-17"
data: {"id": "f-17", "level_type": "gather", "duration": 1.8144888229999196, "nodes": [{"content": "\u2581you", "parent": 0, "aunts": [1, 2], "prob": -1259.50732421875}]}




F NEIGHBORS AFTER 17: (15610.925136262) 1 candidates, 1.8144888229999196 inference time, 1.8148900269989099 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 38])


tensor([], device='cuda:0', size=(0, 38), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1259.5073], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[5.0312, 3.5781, 3.7344,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[4.5353e-01, 2.7508e-01, 1.0120e-01,  ..., 6.6587e-19, 5.1858e-19,
         5.8183e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.4535, 0.7286, 0.8298,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([5, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278


carryover_candidate_logprobs
torch.Size([5])


tensor([-1259.5073, -1259.5073, -1259.5073, -1259.5073, -1259.5073],
       device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[29915],
        [  526],
        [ 6839],
        [18719],
        [ 2099]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-0.7907, -1.2907, -2.2907, -2.6657, -2.9157], device='cuda:0')


new_candidates
torch.Size([5, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366,   526],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366,  6839],
        [    1, 32010,  1724,   338,   278, 21438


new_candidate_logprobs
torch.Size([5])


tensor([-1260.2980, -1260.7980, -1261.7980, -1262.1730, -1262.4230],
       device='cuda:0')

TOP P PRIOR 17: (15610.979297249) 5 candidates, 1.8690467950000311 inference time, 1.8690507779992913 total time
event: message
id: 17-p
data: {"id": "17-p", "level_type": "sample", "duration": 1.8690467950000311, "nodes": [{"content": "'", "parent": 0, "prob": -1260.2979736328125}, {"content": "\u2581are", "parent": 0, "prob": -1260.7979736328125}, {"content": "\u2581meant", "parent": 0, "prob": -1261.7979736328125}, {"content": "\u2581strictly", "parent": 0, "prob": -1262.1729736328125}, {"content": "\u2581mean", "parent": 0, "prob": -1262.4229736328125}], "finished": []}




TOP P AFTER 17: (15610.979814127) 5 candidates, 1.8690467950000311 inference time, 1.869566864999797 total time

num_batches


1

infer start: GPU memory used: 7600 MB.

batch_candidates
torch.Size([5, 40])

batch_candidate_logprobs
torch.Size([5])

batch_logits
torch.Size([5, 40, 32064])

hidden_states[-1]
torch.Size([5, 40, 3072])
infer - after batch run: GPU memory used: 7698 MB.

candidates
torch.Size([5, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366,   526],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366,  6839],
        [    1, 32010,  1724,   338,   278, 21438


candidate_logprobs
torch.Size([5])


tensor([-1260.2980, -1260.7980, -1261.7980, -1262.1730, -1262.4230],
       device='cuda:0')


embeddings
torch.Size([5, 3072])


tensor([[ 0.4453,  0.6250, -0.4414,  ..., -0.0718, -0.2676,  0.0317],
        [ 0.3691, -0.0608,  0.4570,  ..., -1.1094, -0.0099,  0.0879],
        [-0.1465,  1.0547,  2.3906,  ...,  0.2734, -0.8711, -2.1406],
        [-1.0000,  0.8438,  0.5508,  ...,  1.3594, -1.0078,  0.0728],
        [ 0.0369,  0.9688,  2.5156,  ..., -0.0175, -0.5859, -1.9219]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([5])


tensor([ True, False, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.4453,  0.6250, -0.4414,  ..., -0.0718, -0.2676,  0.0317]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 1])


tensor([[1.0000],
        [1.7812],
        [1.7500],
        [1.7422],
        [1.7422]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2, 3, 4]]


new_candidate_logprobs
torch.Size([1])


tensor([-6307.4897], device='cuda:0')

F NEIGHBORS PRIOR 18: (15613.033702542) 1 candidates, 2.053873181999734 inference time, 2.053877594000369 total time
event: message
id: f-18"
data: {"id": "f-18", "level_type": "gather", "duration": 2.053873181999734, "nodes": [{"content": "'", "parent": 0, "aunts": [1, 2, 3, 4], "prob": -6307.48974609375}]}




F NEIGHBORS AFTER 18: (15613.03409103) 1 candidates, 2.053873181999734 inference time, 2.0542659790007747 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 39])


tensor([], device='cuda:0', size=(0, 39), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-6307.4897], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 1.2344, -1.2266, -0.1245,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9792e-01, 1.9264e-03, 1.2315e-04,  ..., 7.5147e-17, 5.4979e-17,
         5.1648e-17]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9979, 0.9998, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-6307.4897], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[276]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0021], device='cuda:0')


new_candidates
torch.Size([1, 41])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-6307.4917], device='cuda:0')

TOP P PRIOR 18: (15613.075967945) 1 candidates, 2.0961386910003057 inference time, 2.0961429130002216 total time
event: message
id: 18-p
data: {"id": "18-p", "level_type": "sample", "duration": 2.0961386910003057, "nodes": [{"content": "re", "parent": 0, "prob": -6307.49169921875}], "finished": []}




TOP P AFTER 18: (15613.076419427) 1 candidates, 2.0961386910003057 inference time, 2.0965940189998946 total time

num_batches


1

infer start: GPU memory used: 7698 MB.

batch_candidates
torch.Size([1, 41])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 41, 32064])

hidden_states[-1]
torch.Size([1, 41, 3072])
infer - after batch run: GPU memory used: 7514 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 40])


tensor([], device='cuda:0', size=(0, 40), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 41])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-6307.4917], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[3.9219, 4.3750, 5.1562,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[5.1326e-01, 3.9973e-01, 2.8956e-02,  ..., 9.0000e-20, 9.0000e-20,
         2.5785e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.5133, 0.9130, 0.9419,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 41])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-6307.4917, -6307.4917], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[16811],
        [ 6721]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.6670, -0.9170], device='cuda:0')


new_candidates
torch.Size([2, 42])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276,  6721]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-6308.1587, -6308.4087], device='cuda:0')

TOP P PRIOR 19: (15614.836294385) 2 candidates, 1.7598537750000105 inference time, 1.7598592009999265 total time
event: message
id: 19-p
data: {"id": "19-p", "level_type": "sample", "duration": 1.7598537750000105, "nodes": [{"content": "\u2581referring", "parent": 0, "prob": -6308.15869140625}, {"content": "\u2581asking", "parent": 0, "prob": -6308.40869140625}], "finished": []}




TOP P AFTER 19: (15614.836864853) 2 candidates, 1.7598537750000105 inference time, 1.7604296600002272 total time

num_batches


1

infer start: GPU memory used: 7514 MB.

batch_candidates
torch.Size([2, 42])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 42, 32064])

hidden_states[-1]
torch.Size([2, 42, 3072])
infer - after batch run: GPU memory used: 7548 MB.

candidates
torch.Size([2, 42])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276,  6721]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-6308.1587, -6308.4087], device='cuda:0')


embeddings
torch.Size([2, 3072])


tensor([[ 0.0312, -0.3145,  1.7656,  ...,  0.4570,  0.1367, -0.7812],
        [ 0.0098, -1.1797,  1.3438,  ..., -0.1582, -1.2656,  1.7656]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([2])


tensor([ True, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.0312, -0.3145,  1.7656,  ...,  0.4570,  0.1367, -0.7812]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 1])


tensor([[1.0000],
        [1.2812]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([2])


tensor([0, 0], device='cuda:0')


new_candidates
torch.Size([1, 42])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1]]


new_candidate_logprobs
torch.Size([1])


tensor([-12616.5674], device='cuda:0')

F NEIGHBORS PRIOR 20: (15616.668992118) 1 candidates, 1.832102002999818 inference time, 1.8321064049996494 total time
event: message
id: f-20"
data: {"id": "f-20", "level_type": "gather", "duration": 1.832102002999818, "nodes": [{"content": "\u2581referring", "parent": 0, "aunts": [1], "prob": -12616.5673828125}]}




F NEIGHBORS AFTER 20: (15616.669396556) 1 candidates, 1.832102002999818 inference time, 1.8325108389999514 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 41])


tensor([], device='cuda:0', size=(0, 41), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 42])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-12616.5674], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[6.5625, 5.4688, 8.6875,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.1653e-01, 5.8592e-02, 1.6787e-02,  ..., 1.1046e-19, 1.0377e-19,
         5.9124e-20]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9165, 0.9751, 0.9919,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 42])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-12616.5674], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[304]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0872], device='cuda:0')


new_candidates
torch.Size([1, 43])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-12616.6543], device='cuda:0')

TOP P PRIOR 20: (15616.708701391) 1 candidates, 1.8718112189999374 inference time, 1.8718161790002341 total time
event: message
id: 20-p
data: {"id": "20-p", "level_type": "sample", "duration": 1.8718112189999374, "nodes": [{"content": "\u2581to", "parent": 0, "prob": -12616.654296875}], "finished": []}




TOP P AFTER 20: (15616.709143048) 1 candidates, 1.8718112189999374 inference time, 1.8722559870002442 total time

num_batches


1

infer start: GPU memory used: 7548 MB.

batch_candidates
torch.Size([1, 43])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 43, 32064])

hidden_states[-1]
torch.Size([1, 43, 3072])
infer - after batch run: GPU memory used: 7500 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 42])


tensor([], device='cuda:0', size=(0, 42), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 43])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-12616.6543], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[8.0000, 8.4375, 4.9688,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[4.4095e-01, 4.4095e-01, 1.9374e-02,  ..., 2.1227e-18, 1.6532e-18,
         1.4445e-19]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.4409, 0.8819, 0.9013,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([3, 43])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   36


carryover_candidate_logprobs
torch.Size([3])


tensor([-12616.6543, -12616.6543, -12616.6543], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[  278],
        [10819],
        [  263]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-0.8188, -0.8188, -3.9438], device='cuda:0')


new_candidates
torch.Size([3, 44])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304, 10819],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 2989


new_candidate_logprobs
torch.Size([3])


tensor([-12617.4727, -12617.4727, -12620.5977], device='cuda:0')

TOP P PRIOR 21: (15618.43546264) 3 candidates, 1.7263055419989541 inference time, 1.726309431000118 total time
event: message
id: 21-p
data: {"id": "21-p", "level_type": "sample", "duration": 1.7263055419989541, "nodes": [{"content": "\u2581the", "parent": 0, "prob": -12617.47265625}, {"content": "\u2581stars", "parent": 0, "prob": -12617.47265625}, {"content": "\u2581a", "parent": 0, "prob": -12620.59765625}], "finished": []}




TOP P AFTER 21: (15618.435976978) 3 candidates, 1.7263055419989541 inference time, 1.7268235129995446 total time

num_batches


1

infer start: GPU memory used: 7500 MB.

batch_candidates
torch.Size([3, 44])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 44, 32064])

hidden_states[-1]
torch.Size([3, 44, 3072])
infer - after batch run: GPU memory used: 7708 MB.

candidates
torch.Size([3, 44])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304, 10819],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 2989


candidate_logprobs
torch.Size([3])


tensor([-12617.4727, -12617.4727, -12620.5977], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[-0.4941,  1.0469,  0.3203,  ..., -0.8203,  1.7734, -0.1719],
        [ 0.0240,  1.3828, -1.2656,  ..., -0.4824,  0.4355,  1.4141],
        [-0.6289,  1.1172, -0.1670,  ..., -0.4180, -0.3770,  0.1934]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.4941,  1.0469,  0.3203,  ..., -0.8203,  1.7734, -0.1719]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[0.9922],
        [1.4219],
        [1.2578]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 44])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2]]


new_candidate_logprobs
torch.Size([1])


tensor([-37855.5430], device='cuda:0')

F NEIGHBORS PRIOR 22: (15620.292719991) 1 candidates, 1.8567284390010173 inference time, 1.8567319860012503 total time
event: message
id: f-22"
data: {"id": "f-22", "level_type": "gather", "duration": 1.8567284390010173, "nodes": [{"content": "\u2581the", "parent": 0, "aunts": [1, 2], "prob": -37855.54296875}]}




F NEIGHBORS AFTER 22: (15620.293051119) 1 candidates, 1.8567284390010173 inference time, 1.8570621870003379 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 43])


tensor([], device='cuda:0', size=(0, 43), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 44])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-37855.5430], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[7.2500, 7.7188, 3.9688,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.2576e-01, 1.6956e-02, 1.4964e-02,  ..., 3.9330e-18, 2.8774e-18,
         1.8578e-18]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9258, 0.9427, 0.9577,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 44])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-37855.5430], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[21438]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0771], device='cuda:0')


new_candidates
torch.Size([1, 45])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-37855.6211], device='cuda:0')

TOP P PRIOR 22: (15620.328708056) 1 candidates, 1.8927166100002069 inference time, 1.8927202049999323 total time
event: message
id: 22-p
data: {"id": "22-p", "level_type": "sample", "duration": 1.8927166100002069, "nodes": [{"content": "\u2581closest", "parent": 0, "prob": -37855.62109375}], "finished": []}




TOP P AFTER 22: (15620.329031596) 1 candidates, 1.8927166100002069 inference time, 1.893042757999865 total time

num_batches


1

infer start: GPU memory used: 7708 MB.

batch_candidates
torch.Size([1, 45])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 45, 32064])

hidden_states[-1]
torch.Size([1, 45, 3072])
infer - after batch run: GPU memory used: 7688 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 44])


tensor([], device='cuda:0', size=(0, 44), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 45])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-37855.6211], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 7.3750,  3.6406, -1.2656,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[7.2492e-01, 1.8329e-01, 5.9505e-02,  ..., 3.8385e-21, 1.1707e-21,
         5.5299e-22]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.7249, 0.9082, 0.9677,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 45])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-37855.6211, -37855.6211], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[5810],
        [6432]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.3217, -1.6967], device='cuda:0')


new_candidates
torch.Size([2, 46])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  6432]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-37855.9414, -37857.3164], device='cuda:0')

TOP P PRIOR 23: (15621.973137192) 2 candidates, 1.6440922800011322 inference time, 1.6440958390012383 total time
event: message
id: 23-p
data: {"id": "23-p", "level_type": "sample", "duration": 1.6440922800011322, "nodes": [{"content": "\u2581star", "parent": 0, "prob": -37855.94140625}, {"content": "\u2581cel", "parent": 0, "prob": -37857.31640625}], "finished": []}




TOP P AFTER 23: (15621.973507982) 2 candidates, 1.6440922800011322 inference time, 1.644465665000098 total time

num_batches


1

infer start: GPU memory used: 7688 MB.

batch_candidates
torch.Size([2, 46])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 46, 32064])

hidden_states[-1]
torch.Size([2, 46, 3072])
infer - after batch run: GPU memory used: 7686 MB.

candidates
torch.Size([2, 46])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  6432]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-37855.9414, -37857.3164], device='cuda:0')


embeddings
torch.Size([2, 3072])


tensor([[ 0.5234,  0.7070,  0.4707,  ...,  0.6172,  1.3438, -0.5703],
        [-0.0684,  0.9961,  0.4414,  ..., -1.4688,  1.7031,  0.3398]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([2])


tensor([ True, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.5234,  0.7070,  0.4707,  ...,  0.6172,  1.3438, -0.5703]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 1])


tensor([[1.0000],
        [1.8594]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([2])


tensor([0, 0], device='cuda:0')


new_candidates
torch.Size([1, 46])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1]]


new_candidate_logprobs
torch.Size([1])


tensor([-75713.2578], device='cuda:0')

F NEIGHBORS PRIOR 24: (15623.700970211) 1 candidates, 1.7274483250002959 inference time, 1.727453104000233 total time
event: message
id: f-24"
data: {"id": "f-24", "level_type": "gather", "duration": 1.7274483250002959, "nodes": [{"content": "\u2581star", "parent": 0, "aunts": [1], "prob": -75713.2578125}]}




F NEIGHBORS AFTER 24: (15623.70135885) 1 candidates, 1.7274483250002959 inference time, 1.7278412989999197 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 45])


tensor([], device='cuda:0', size=(0, 45), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 46])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-75713.2578], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[5.6562, 3.8438, 0.8359,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[5.3988e-01, 1.5468e-01, 1.0631e-01,  ..., 1.5608e-19, 1.0727e-19,
         1.8641e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.5399, 0.6946, 0.8009,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([5, 46])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 2988


carryover_candidate_logprobs
torch.Size([5])


tensor([-75713.2578, -75713.2578, -75713.2578, -75713.2578, -75713.2578],
       device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[1788],
        [ 304],
        [ 297],
        [2629],
        [ 313]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-0.6164, -1.8664, -2.2414, -2.8664, -3.6164], device='cuda:0')


new_candidates
torch.Size([5, 47])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         2987


new_candidate_logprobs
torch.Size([5])


tensor([-75713.8750, -75715.1250, -75715.5000, -75716.1250, -75716.8750],
       device='cuda:0')

TOP P PRIOR 24: (15623.755635234) 5 candidates, 1.782113360999574 inference time, 1.782118740000442 total time
event: message
id: 24-p
data: {"id": "24-p", "level_type": "sample", "duration": 1.782113360999574, "nodes": [{"content": "\u2581system", "parent": 0, "prob": -75713.875}, {"content": "\u2581to", "parent": 0, "prob": -75715.125}, {"content": "\u2581in", "parent": 0, "prob": -75715.5}, {"content": "\u2581within", "parent": 0, "prob": -75716.125}, {"content": "\u2581(", "parent": 0, "prob": -75716.875}], "finished": []}




TOP P AFTER 24: (15623.756353775) 5 candidates, 1.782113360999574 inference time, 1.7828371890009294 total time

num_batches


1

infer start: GPU memory used: 7686 MB.

batch_candidates
torch.Size([5, 47])

batch_candidate_logprobs
torch.Size([5])

batch_logits
torch.Size([5, 47, 32064])

hidden_states[-1]
torch.Size([5, 47, 3072])
infer - after batch run: GPU memory used: 7700 MB.

candidates
torch.Size([5, 47])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         2987


candidate_logprobs
torch.Size([5])


tensor([-75713.8750, -75715.1250, -75715.5000, -75716.1250, -75716.8750],
       device='cuda:0')


embeddings
torch.Size([5, 3072])


tensor([[-0.6523,  0.1475,  0.9141,  ..., -1.2578,  1.1875, -0.7617],
        [ 0.7188,  1.4453,  1.1094,  ..., -1.1016,  0.7812, -1.1875],
        [ 0.2109, -0.4258,  0.0064,  ..., -0.4590,  2.1562,  0.3613],
        [-0.6641,  0.9258,  1.6172,  ..., -2.0312,  3.1875, -0.3145],
        [ 1.5391, -0.6484,  1.0781,  ...,  0.7656, -0.6406, -1.5312]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([5])


tensor([ True, False, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.6523,  0.1475,  0.9141,  ..., -1.2578,  1.1875, -0.7617]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 1])


tensor([[0.9922],
        [1.5000],
        [1.5000],
        [1.5703],
        [1.5625]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 47])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2, 3, 4]]


new_candidate_logprobs
torch.Size([1])


tensor([-378577.5000], device='cuda:0')

F NEIGHBORS PRIOR 25: (15625.752695713) 1 candidates, 1.9963165569988632 inference time, 1.9963197699999 total time
event: message
id: f-25"
data: {"id": "f-25", "level_type": "gather", "duration": 1.9963165569988632, "nodes": [{"content": "\u2581system", "parent": 0, "aunts": [1, 2, 3, 4], "prob": -378577.5}]}




F NEIGHBORS AFTER 25: (15625.753025121) 1 candidates, 1.9963165569988632 inference time, 1.996648500999072 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 46])


tensor([], device='cuda:0', size=(0, 46), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 47])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-378577.5000], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[6.6250, 4.1562, 3.0625,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[4.1303e-01, 2.2108e-01, 1.9510e-01,  ..., 6.0043e-20, 3.2139e-20,
         1.5181e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.4130, 0.6341, 0.8292,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([4, 47])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         2987


carryover_candidate_logprobs
torch.Size([4])


tensor([-378577.5000, -378577.5000, -378577.5000, -378577.5000],
       device='cuda:0')


new_candidate_toks
torch.Size([4, 1])


tensor([[  304],
        [  470],
        [29892],
        [  313]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-0.8842, -1.5092, -1.6342, -2.6342], device='cuda:0')


new_candidates
torch.Size([4, 48])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   470],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,


new_candidate_logprobs
torch.Size([4])


tensor([-378578.3750, -378579.0000, -378579.1250, -378580.1250],
       device='cuda:0')

TOP P PRIOR 25: (15625.796372166) 4 candidates, 2.03999273899899 inference time, 2.0399961809998786 total time
event: message
id: 25-p
data: {"id": "25-p", "level_type": "sample", "duration": 2.03999273899899, "nodes": [{"content": "\u2581to", "parent": 0, "prob": -378578.375}, {"content": "\u2581or", "parent": 0, "prob": -378579.0}, {"content": ",", "parent": 0, "prob": -378579.125}, {"content": "\u2581(", "parent": 0, "prob": -378580.125}], "finished": []}




TOP P AFTER 25: (15625.796796762) 4 candidates, 2.03999273899899 inference time, 2.0404203119996964 total time

num_batches


1

infer start: GPU memory used: 7700 MB.

batch_candidates
torch.Size([4, 48])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 48, 32064])

hidden_states[-1]
torch.Size([4, 48, 3072])
infer - after batch run: GPU memory used: 7652 MB.

candidates
torch.Size([4, 48])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   470],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,


candidate_logprobs
torch.Size([4])


tensor([-378578.3750, -378579.0000, -378579.1250, -378580.1250],
       device='cuda:0')


embeddings
torch.Size([4, 3072])


tensor([[ 0.5352,  1.3203,  0.1592,  ..., -1.5547,  0.7969, -0.9180],
        [ 0.1099,  0.7930,  0.4980,  ..., -1.2109,  0.7266,  1.0938],
        [-1.4141, -0.2012, -0.7734,  ..., -0.0294,  0.2490, -0.7188],
        [-0.1074, -1.4609,  0.0134,  ..., -0.4805, -0.4316, -0.3613]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([ True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.5352,  1.3203,  0.1592,  ..., -1.5547,  0.7969, -0.9180]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[1.0000],
        [1.5000],
        [1.4766],
        [1.6719]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([0, 0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 48])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2, 3]]


new_candidate_logprobs
torch.Size([1])


tensor([-1514316.6250], device='cuda:0')

F NEIGHBORS PRIOR 26: (15627.710874559) 1 candidates, 1.9140636240008462 inference time, 1.9140681130011217 total time
event: message
id: f-26"
data: {"id": "f-26", "level_type": "gather", "duration": 1.9140636240008462, "nodes": [{"content": "\u2581to", "parent": 0, "aunts": [1, 2, 3], "prob": -1514316.625}]}




F NEIGHBORS AFTER 26: (15627.711281222) 1 candidates, 1.9140636240008462 inference time, 1.9144743680008105 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 47])


tensor([], device='cuda:0', size=(0, 47), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 48])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1514316.6250], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[9.2500, 7.7500, 5.7500,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[8.4600e-01, 1.1449e-01, 3.2803e-02,  ..., 8.9977e-20, 6.5829e-20,
         4.5243e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.8460, 0.9605, 0.9933,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 48])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-1514316.6250, -1514316.6250], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[11563],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.1672, -2.1672], device='cuda:0')


new_candidates
torch.Size([2, 49])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304,   278]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-1514316.7500, -1514318.7500], device='cuda:0')

TOP P PRIOR 26: (15627.756868077) 2 candidates, 1.9600575690001278 inference time, 1.9600606540006993 total time
event: message
id: 26-p
data: {"id": "26-p", "level_type": "sample", "duration": 1.9600575690001278, "nodes": [{"content": "\u2581Earth", "parent": 0, "prob": -1514316.75}, {"content": "\u2581the", "parent": 0, "prob": -1514318.75}], "finished": []}




TOP P AFTER 26: (15627.757287564) 2 candidates, 1.9600575690001278 inference time, 1.9604799349999666 total time

num_batches


1

infer start: GPU memory used: 7652 MB.

batch_candidates
torch.Size([2, 49])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 49, 32064])

hidden_states[-1]
torch.Size([2, 49, 3072])
infer - after batch run: GPU memory used: 7570 MB.

candidates
torch.Size([2, 49])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304,   278]],
       device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-1514316.7500, -1514318.7500], device='cuda:0')


embeddings
torch.Size([2, 3072])


tensor([[-0.0776,  0.2393,  0.1191,  ..., -1.0547,  0.1021,  0.1572],
        [ 0.7812,  1.2188, -0.4707,  ..., -2.6875,  1.7500,  0.7852]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([2])


tensor([ True, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.0776,  0.2393,  0.1191,  ..., -1.0547,  0.1021,  0.1572]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 1])


tensor([[1.0000],
        [1.5703]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([2])


tensor([0, 0], device='cuda:0')


new_candidates
torch.Size([1, 49])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1]]


new_candidate_logprobs
torch.Size([1])


tensor([-3028635.5000], device='cuda:0')

F NEIGHBORS PRIOR 27: (15629.584275788) 1 candidates, 1.8269734580007935 inference time, 1.8269780139999057 total time
event: message
id: f-27"
data: {"id": "f-27", "level_type": "gather", "duration": 1.8269734580007935, "nodes": [{"content": "\u2581Earth", "parent": 0, "aunts": [1], "prob": -3028635.5}]}




F NEIGHBORS AFTER 27: (15629.584697442) 1 candidates, 1.8269734580007935 inference time, 1.8273998109998502 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 48])


tensor([], device='cuda:0', size=(0, 48), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 49])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-3028635.5000], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[5.7188, 8.5000, 2.5938,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[7.9090e-01, 8.3360e-02, 6.4921e-02,  ..., 2.1480e-19, 1.3868e-19,
         1.1497e-19]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.7909, 0.8743, 0.9392,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([3, 49])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657,


carryover_candidate_logprobs
torch.Size([3])


tensor([-3028635.5000, -3028635.5000, -3028635.5000], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[29892],
        [  313],
        [ 3265]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-0.2346, -2.4846, -2.7346], device='cuda:0')


new_candidates
torch.Size([3, 50])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563,   313],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,


new_candidate_logprobs
torch.Size([3])


tensor([-3028635.7500, -3028638.0000, -3028638.2500], device='cuda:0')

TOP P PRIOR 27: (15629.627396677) 3 candidates, 1.870094871999754 inference time, 1.8700984940005583 total time
event: message
id: 27-p
data: {"id": "27-p", "level_type": "sample", "duration": 1.870094871999754, "nodes": [{"content": ",", "parent": 0, "prob": -3028635.75}, {"content": "\u2581(", "parent": 0, "prob": -3028638.0}, {"content": "\u2581rather", "parent": 0, "prob": -3028638.25}], "finished": []}




TOP P AFTER 27: (15629.627784875) 3 candidates, 1.870094871999754 inference time, 1.8704858439996315 total time

num_batches


1

infer start: GPU memory used: 7570 MB.

batch_candidates
torch.Size([3, 50])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 50, 32064])

hidden_states[-1]
torch.Size([3, 50, 3072])
infer - after batch run: GPU memory used: 7634 MB.

candidates
torch.Size([3, 50])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563,   313],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,


candidate_logprobs
torch.Size([3])


tensor([-3028635.7500, -3028638.0000, -3028638.2500], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[-0.7109,  0.0796, -0.6914,  ..., -0.3320, -0.3320, -0.8320],
        [ 0.5234, -1.1797,  0.2471,  ..., -1.1172, -0.6602, -1.0781],
        [-1.0156,  0.2617,  0.3926,  ..., -2.4219, -0.1396, -1.5938]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.7109,  0.0796, -0.6914,  ..., -0.3320, -0.3320, -0.8320]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[1.0000],
        [1.5000],
        [1.4844]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 50])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2]]


new_candidate_logprobs
torch.Size([1])


tensor([-9085912.], device='cuda:0')

F NEIGHBORS PRIOR 28: (15631.539682298) 1 candidates, 1.9118757359992742 inference time, 1.9118793939996976 total time
event: message
id: f-28"
data: {"id": "f-28", "level_type": "gather", "duration": 1.9118757359992742, "nodes": [{"content": ",", "parent": 0, "aunts": [1, 2], "prob": -9085912.0}]}




F NEIGHBORS AFTER 28: (15631.540066902) 1 candidates, 1.9118757359992742 inference time, 1.9122635620005894 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 49])


tensor([], device='cuda:0', size=(0, 49), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 50])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-9085912.], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[6.2188, 4.9688, 4.1562,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[2.9895e-01, 2.9895e-01, 1.4122e-01,  ..., 4.8206e-19, 1.9477e-19,
         1.4250e-19]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.2990, 0.5979, 0.7391,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([5, 50])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,


carryover_candidate_logprobs
torch.Size([5])


tensor([-9085912., -9085912., -9085912., -9085912., -9085912.],
       device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[372],
        [769],
        [278],
        [393],
        [607]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-1.2075, -1.2075, -1.9575, -1.9575, -3.4575], device='cuda:0')


new_candidates
torch.Size([5, 51])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           769],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32


new_candidate_logprobs
torch.Size([5])


tensor([-9085913., -9085913., -9085914., -9085914., -9085915.],
       device='cuda:0')

TOP P PRIOR 28: (15631.586252602) 5 candidates, 1.9584458819990687 inference time, 1.9584490269990056 total time
event: message
id: 28-p
data: {"id": "28-p", "level_type": "sample", "duration": 1.9584458819990687, "nodes": [{"content": "\u2581it", "parent": 0, "prob": -9085913.0}, {"content": "\u2581then", "parent": 0, "prob": -9085913.0}, {"content": "\u2581the", "parent": 0, "prob": -9085914.0}, {"content": "\u2581that", "parent": 0, "prob": -9085914.0}, {"content": "\u2581which", "parent": 0, "prob": -9085915.0}], "finished": []}




TOP P AFTER 28: (15631.586795142) 5 candidates, 1.9584458819990687 inference time, 1.9589913380004873 total time

num_batches


1

infer start: GPU memory used: 7634 MB.

batch_candidates
torch.Size([5, 51])

batch_candidate_logprobs
torch.Size([5])

batch_logits
torch.Size([5, 51, 32064])

hidden_states[-1]
torch.Size([5, 51, 3072])
infer - after batch run: GPU memory used: 7744 MB.

candidates
torch.Size([5, 51])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           769],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32


candidate_logprobs
torch.Size([5])


tensor([-9085913., -9085913., -9085914., -9085914., -9085915.],
       device='cuda:0')


embeddings
torch.Size([5, 3072])


tensor([[-0.3691,  0.0586, -0.1777,  ...,  0.4102, -0.9062,  0.5156],
        [-1.0703,  0.8398,  0.1885,  ...,  2.1562, -1.1328, -0.8203],
        [-1.7188,  1.2500,  0.2832,  ..., -0.8594,  0.2891, -0.3418],
        [ 0.1338, -0.0747,  0.8008,  ...,  1.3906, -0.4590, -0.6484],
        [ 0.3984,  0.1270,  0.1250,  ..., -0.3672, -0.4863,  0.7109]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([5])


tensor([ True, False, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.3691,  0.0586, -0.1777,  ...,  0.4102, -0.9062,  0.5156]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 1])


tensor([[1.0000],
        [1.4297],
        [1.5156],
        [1.2344],
        [1.3125]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 51])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2, 3, 4]]


new_candidate_logprobs
torch.Size([1])


tensor([-45429572.], device='cuda:0')

F NEIGHBORS PRIOR 29: (15633.65211168) 1 candidates, 2.0653007889995934 inference time, 2.0653063139998267 total time
event: message
id: f-29"
data: {"id": "f-29", "level_type": "gather", "duration": 2.0653007889995934, "nodes": [{"content": "\u2581it", "parent": 0, "aunts": [1, 2, 3, 4], "prob": -45429572.0}]}




F NEIGHBORS AFTER 29: (15633.65253456) 1 candidates, 2.0653007889995934 inference time, 2.0657289310001943 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 50])


tensor([], device='cuda:0', size=(0, 50), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 51])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-45429572.], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 4.0312, -0.0762,  3.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[4.9189e-01, 2.9835e-01, 1.8096e-01,  ..., 9.2732e-19, 4.1149e-19,
         2.3446e-19]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.4919, 0.7902, 0.9712,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([3, 51])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32


carryover_candidate_logprobs
torch.Size([3])


tensor([-45429572., -45429572., -45429572.], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[  723],
        [29915],
        [  338]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-0.7095, -1.2095, -1.7095], device='cuda:0')


new_candidates
torch.Size([3, 52])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372, 29915],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
    


new_candidate_logprobs
torch.Size([3])


tensor([-45429572., -45429572., -45429572.], device='cuda:0')

TOP P PRIOR 29: (15633.701073366) 3 candidates, 2.114263307999863 inference time, 2.114267351998933 total time
event: message
id: 29-p
data: {"id": "29-p", "level_type": "sample", "duration": 2.114263307999863, "nodes": [{"content": "\u2581would", "parent": 0, "prob": -45429572.0}, {"content": "'", "parent": 0, "prob": -45429572.0}, {"content": "\u2581is", "parent": 0, "prob": -45429572.0}], "finished": []}




TOP P AFTER 29: (15633.701500059) 3 candidates, 2.114263307999863 inference time, 2.1146929179994913 total time

num_batches


1

infer start: GPU memory used: 7744 MB.

batch_candidates
torch.Size([3, 52])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 52, 32064])

hidden_states[-1]
torch.Size([3, 52, 3072])
infer - after batch run: GPU memory used: 7634 MB.

candidates
torch.Size([3, 52])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372, 29915],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
    


candidate_logprobs
torch.Size([3])


tensor([-45429572., -45429572., -45429572.], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[ 0.4473,  0.9336,  0.6211,  ...,  0.9258, -0.3633,  0.5352],
        [ 0.9297, -0.6250,  0.8555,  ...,  0.4238, -1.0312,  0.0508],
        [-0.2354,  0.1436,  0.7461,  ..., -1.5781, -0.1953, -1.0234]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.4473,  0.9336,  0.6211,  ...,  0.9258, -0.3633,  0.5352]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[1.0000],
        [1.8516],
        [1.5859]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 52])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2]]


new_candidate_logprobs
torch.Size([1])


tensor([-1.3629e+08], device='cuda:0')

F NEIGHBORS PRIOR 30: (15635.618682049) 1 candidates, 1.9171685010005604 inference time, 1.9171718859997782 total time
event: message
id: f-30"
data: {"id": "f-30", "level_type": "gather", "duration": 1.9171685010005604, "nodes": [{"content": "\u2581would", "parent": 0, "aunts": [1, 2], "prob": -136288720.0}]}




F NEIGHBORS AFTER 30: (15635.619019665) 1 candidates, 1.9171685010005604 inference time, 1.9175086630002625 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 51])


tensor([], device='cuda:0', size=(0, 51), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 52])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1.3629e+08], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[4.6250, 4.6875, 2.0156,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[8.4017e-01, 6.8966e-02, 5.3710e-02,  ..., 4.0047e-19, 2.5856e-19,
         4.1793e-21]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.8402, 0.9091, 0.9628,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 52])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-1.3629e+08, -1.3629e+08], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[ 367],
        [1603]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.1741, -2.6741], device='cuda:0')


new_candidates
torch.Size([2, 53])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,  1603]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-1.3629e+08, -1.3629e+08], device='cuda:0')

TOP P PRIOR 30: (15635.655750535) 2 candidates, 1.9542370849994768 inference time, 1.9542405089996464 total time
event: message
id: 30-p
data: {"id": "30-p", "level_type": "sample", "duration": 1.9542370849994768, "nodes": [{"content": "\u2581be", "parent": 0, "prob": -136288720.0}, {"content": "\u2581still", "parent": 0, "prob": -136288720.0}], "finished": []}




TOP P AFTER 30: (15635.65611476) 2 candidates, 1.9542370849994768 inference time, 1.9546037039999646 total time

num_batches


1

infer start: GPU memory used: 7634 MB.

batch_candidates
torch.Size([2, 53])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 53, 32064])

hidden_states[-1]
torch.Size([2, 53, 3072])
infer - after batch run: GPU memory used: 7570 MB.

candidates
torch.Size([2, 53])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,  1603]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-1.3629e+08, -1.3629e+08], device='cuda:0')


embeddings
torch.Size([2, 3072])


tensor([[-0.5820,  0.7109,  0.7344,  ..., -1.3750, -1.7266, -1.1328],
        [-0.5742,  2.0312,  1.0859,  ...,  0.6133, -0.3555,  1.0234]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([2])


tensor([ True, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.5820,  0.7109,  0.7344,  ..., -1.3750, -1.7266, -1.1328]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 1])


tensor([[0.9922],
        [1.6328]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([2])


tensor([0, 0], device='cuda:0')


new_candidates
torch.Size([1, 53])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1]]


new_candidate_logprobs
torch.Size([1])


tensor([-2.7258e+08], device='cuda:0')

F NEIGHBORS PRIOR 31: (15637.456191404) 1 candidates, 1.8000631730010355 inference time, 1.8000671600002534 total time
event: message
id: f-31"
data: {"id": "f-31", "level_type": "gather", "duration": 1.8000631730010355, "nodes": [{"content": "\u2581be", "parent": 0, "aunts": [1], "prob": -272577440.0}]}




F NEIGHBORS AFTER 31: (15637.456544107) 1 candidates, 1.8000631730010355 inference time, 1.8004190450010356 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 52])


tensor([], device='cuda:0', size=(0, 52), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 53])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-2.7258e+08], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[7.6875, 5.5938, 5.5000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[6.3588e-01, 2.6508e-01, 5.2196e-02,  ..., 5.9683e-20, 3.1946e-20,
         2.0626e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.6359, 0.9010, 0.9532,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 53])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-2.7258e+08, -2.7258e+08], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[278],
        [838]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.4527, -1.3277], device='cuda:0')


new_candidates
torch.Size([2, 54])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   838]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-2.7258e+08, -2.7258e+08], device='cuda:0')

TOP P PRIOR 31: (15637.497827281) 2 candidates, 1.841698443000496 inference time, 1.8417041989996505 total time
event: message
id: 31-p
data: {"id": "31-p", "level_type": "sample", "duration": 1.841698443000496, "nodes": [{"content": "\u2581the", "parent": 0, "prob": -272577440.0}, {"content": "\u2581Al", "parent": 0, "prob": -272577440.0}], "finished": []}




TOP P AFTER 31: (15637.498231689) 2 candidates, 1.841698443000496 inference time, 1.8421071480006503 total time

num_batches


1

infer start: GPU memory used: 7570 MB.

batch_candidates
torch.Size([2, 54])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 54, 32064])

hidden_states[-1]
torch.Size([2, 54, 3072])
infer - after batch run: GPU memory used: 7570 MB.

candidates
torch.Size([2, 54])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   838]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-2.7258e+08, -2.7258e+08], device='cuda:0')


embeddings
torch.Size([2, 3072])


tensor([[-0.4160,  0.4395, -0.8594,  ..., -0.9141, -0.2344,  0.8750],
        [ 1.0781,  0.3789, -0.8594,  ..., -1.1172,  2.4219, -2.3281]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([2])


tensor([ True, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.4160,  0.4395, -0.8594,  ..., -0.9141, -0.2344,  0.8750]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 1])


tensor([[1.0000],
        [1.7578]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([2])


tensor([0, 0], device='cuda:0')


new_candidates
torch.Size([1, 54])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1]]


new_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')

F NEIGHBORS PRIOR 32: (15639.304296886) 1 candidates, 1.806044300999929 inference time, 1.8060482580003736 total time
event: message
id: f-32"
data: {"id": "f-32", "level_type": "gather", "duration": 1.806044300999929, "nodes": [{"content": "\u2581the", "parent": 0, "aunts": [1], "prob": -545154880.0}]}




F NEIGHBORS AFTER 32: (15639.304678861) 1 candidates, 1.806044300999929 inference time, 1.8064294989999325 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 53])


tensor([], device='cuda:0', size=(0, 53), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 54])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 5.7500, -2.4844,  3.0156,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.8432e-01, 3.5500e-03, 3.5500e-03,  ..., 1.4786e-22, 8.9679e-23,
         7.3613e-24]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9843, 0.9879, 0.9914,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 54])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[838]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0158], device='cuda:0')


new_candidates
torch.Size([1, 55])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')

TOP P PRIOR 32: (15639.344657086) 1 candidates, 1.8464044970005489 inference time, 1.8464083630005916 total time
event: message
id: 32-p
data: {"id": "32-p", "level_type": "sample", "duration": 1.8464044970005489, "nodes": [{"content": "\u2581Al", "parent": 0, "prob": -545154880.0}], "finished": []}




TOP P AFTER 32: (15639.345012002) 1 candidates, 1.8464044970005489 inference time, 1.8467625770008453 total time

num_batches


1

infer start: GPU memory used: 7570 MB.

batch_candidates
torch.Size([1, 55])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 55, 32064])

hidden_states[-1]
torch.Size([1, 55, 3072])
infer - after batch run: GPU memory used: 7520 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 54])


tensor([], device='cuda:0', size=(0, 54), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 55])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[-0.7500, -2.8906,  1.6094,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9998e-01, 1.3007e-05, 4.7850e-06,  ..., 8.6439e-22, 6.7319e-22,
         4.0831e-22]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 55])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[2026]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-2.0385e-05], device='cuda:0')


new_candidates
torch.Size([1, 56])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')

TOP P PRIOR 33: (15641.097014836) 1 candidates, 1.7519875519992638 inference time, 1.75199166099992 total time
event: message
id: 33-p
data: {"id": "33-p", "level_type": "sample", "duration": 1.7519875519992638, "nodes": [{"content": "pha", "parent": 0, "prob": -545154880.0}], "finished": []}




TOP P AFTER 33: (15641.097387191) 1 candidates, 1.7519875519992638 inference time, 1.7523632630000066 total time

num_batches


1

infer start: GPU memory used: 7520 MB.

batch_candidates
torch.Size([1, 56])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 56, 32064])

hidden_states[-1]
torch.Size([1, 56, 3072])
infer - after batch run: GPU memory used: 7528 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 55])


tensor([], device='cuda:0', size=(0, 55), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 56])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 5.4688,  0.7539, -1.9844,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[1.0000e+00, 3.2887e-06, 3.9278e-07,  ..., 6.5998e-24, 5.1399e-24,
         3.1175e-24]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 56])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[2895]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-4.6492e-06], device='cuda:0')


new_candidates
torch.Size([1, 57])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')

TOP P PRIOR 34: (15642.892655259) 1 candidates, 1.7952528779987915 inference time, 1.7952573989987286 total time
event: message
id: 34-p
data: {"id": "34-p", "level_type": "sample", "duration": 1.7952528779987915, "nodes": [{"content": "\u2581Cent", "parent": 0, "prob": -545154880.0}], "finished": []}




TOP P AFTER 34: (15642.892951331) 1 candidates, 1.7952528779987915 inference time, 1.7955524680000963 total time

num_batches


1

infer start: GPU memory used: 7528 MB.

batch_candidates
torch.Size([1, 57])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 57, 32064])

hidden_states[-1]
torch.Size([1, 57, 3072])
infer - after batch run: GPU memory used: 7520 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 56])


tensor([], device='cuda:0', size=(0, 56), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 57])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[4.4062, 1.8906, 5.6250,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9997e-01, 2.4300e-05, 7.8891e-06,  ..., 1.3697e-25, 3.9243e-26,
         2.8427e-27]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 57])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[29874]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-3.3856e-05], device='cuda:0')


new_candidates
torch.Size([1, 58])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')

TOP P PRIOR 35: (15644.65198525) 1 candidates, 1.7590189830007148 inference time, 1.7590245770006732 total time
event: message
id: 35-p
data: {"id": "35-p", "level_type": "sample", "duration": 1.7590189830007148, "nodes": [{"content": "a", "parent": 0, "prob": -545154880.0}], "finished": []}




TOP P AFTER 35: (15644.653016354) 1 candidates, 1.7590189830007148 inference time, 1.7600556410015997 total time

num_batches


1

infer start: GPU memory used: 7520 MB.

batch_candidates
torch.Size([1, 58])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 58, 32064])

hidden_states[-1]
torch.Size([1, 58, 3072])
infer - after batch run: GPU memory used: 7520 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 57])


tensor([], device='cuda:0', size=(0, 57), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 58])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 7.1250, -2.4688, -0.0121,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[1.0000e+00, 7.7344e-08, 2.8453e-08,  ..., 5.5978e-28, 2.6442e-28,
         7.5759e-29]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 58])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[5338]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-1.1921e-07], device='cuda:0')


new_candidates
torch.Size([1, 59])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')

TOP P PRIOR 36: (15646.411160797) 1 candidates, 1.7581212479999522 inference time, 1.7581254449996777 total time
event: message
id: 36-p
data: {"id": "36-p", "level_type": "sample", "duration": 1.7581212479999522, "nodes": [{"content": "uri", "parent": 0, "prob": -545154880.0}], "finished": []}




TOP P AFTER 36: (15646.411596185) 1 candidates, 1.7581212479999522 inference time, 1.7585598659989046 total time

num_batches


1

infer start: GPU memory used: 7520 MB.

batch_candidates
torch.Size([1, 59])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 59, 32064])

hidden_states[-1]
torch.Size([1, 59, 3072])
infer - after batch run: GPU memory used: 7520 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 58])


tensor([], device='cuda:0', size=(0, 58), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 59])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-5.4515e+08], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 6.8125, -4.8750,  1.6016,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[6.4960e-01, 3.4771e-01, 1.2540e-03,  ..., 1.9214e-23, 4.8581e-24,
         2.0252e-24]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.6496, 0.9973, 0.9986,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 59])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-5.4515e+08, -5.4515e+08], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[1788],
        [5810]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.4314, -1.0564], device='cuda:0')


new_candidates
torch.Size([2, 60])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  5810]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-5.4515e+08, -5.4515e+08], device='cuda:0')

TOP P PRIOR 37: (15648.173821024) 2 candidates, 1.7622087459985778 inference time, 1.7622141779993399 total time
event: message
id: 37-p
data: {"id": "37-p", "level_type": "sample", "duration": 1.7622087459985778, "nodes": [{"content": "\u2581system", "parent": 0, "prob": -545154880.0}, {"content": "\u2581star", "parent": 0, "prob": -545154880.0}], "finished": []}




TOP P AFTER 37: (15648.174249952) 2 candidates, 1.7622087459985778 inference time, 1.7626420889992005 total time

num_batches


1

infer start: GPU memory used: 7520 MB.

batch_candidates
torch.Size([2, 60])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 60, 32064])

hidden_states[-1]
torch.Size([2, 60, 3072])
infer - after batch run: GPU memory used: 7600 MB.

candidates
torch.Size([2, 60])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  5810]],
       device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-5.4515e+08, -5.4515e+08], device='cuda:0')


embeddings
torch.Size([2, 3072])


tensor([[-0.2891,  0.6016,  0.0581,  ..., -0.0737,  0.7930,  0.9492],
        [-0.0771,  0.6484,  0.5352,  ...,  0.1631,  2.9531,  0.8828]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([2])


tensor([ True, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.2891,  0.6016,  0.0581,  ..., -0.0737,  0.7930,  0.9492]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 1])


tensor([[1.0000],
        [1.5859]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([2])


tensor([0, 0], device='cuda:0')


new_candidates
torch.Size([1, 60])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1]]


new_candidate_logprobs
torch.Size([1])


tensor([-1.0903e+09], device='cuda:0')

F NEIGHBORS PRIOR 38: (15649.981292453) 1 candidates, 1.8070210209989455 inference time, 1.8070257239996863 total time
event: message
id: f-38"
data: {"id": "f-38", "level_type": "gather", "duration": 1.8070210209989455, "nodes": [{"content": "\u2581system", "parent": 0, "aunts": [1], "prob": -1090309760.0}]}




F NEIGHBORS AFTER 38: (15649.981688651) 1 candidates, 1.8070210209989455 inference time, 1.8074210729992046 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 59])


tensor([], device='cuda:0', size=(0, 59), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 60])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1.0903e+09], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[8.1250, 4.2500, 5.8438,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[5.9380e-01, 3.6016e-01, 2.0319e-02,  ..., 2.6326e-20, 2.3233e-20,
         5.5183e-21]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.5938, 0.9540, 0.9743,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 60])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-1.0903e+09, -1.0903e+09], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[29892],
        [29889]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.5212, -1.0212], device='cuda:0')


new_candidates
torch.Size([2, 61])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29889]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-1.0903e+09, -1.0903e+09], device='cuda:0')

TOP P PRIOR 38: (15650.02714712) 2 candidates, 1.8528757069998392 inference time, 1.8528800919993955 total time
event: message
id: 38-p
data: {"id": "38-p", "level_type": "sample", "duration": 1.8528757069998392, "nodes": [{"content": ",", "parent": 0, "prob": -1090309760.0}, {"content": ".", "parent": 0, "prob": -1090309760.0}], "finished": []}




TOP P AFTER 38: (15650.027554957) 2 candidates, 1.8528757069998392 inference time, 1.853286883999317 total time

num_batches


1

infer start: GPU memory used: 7600 MB.

batch_candidates
torch.Size([2, 61])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 61, 32064])

hidden_states[-1]
torch.Size([2, 61, 3072])
infer - after batch run: GPU memory used: 7600 MB.

candidates
torch.Size([2, 61])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29889]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-1.0903e+09, -1.0903e+09], device='cuda:0')


embeddings
torch.Size([2, 3072])


tensor([[-0.2988,  1.6328, -0.8945,  ..., -0.6406,  1.6328,  2.5938],
        [-0.4590,  1.1562,  0.5820,  ..., -0.9062, -0.6758, -0.0505]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([2])


tensor([ True, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.2988,  1.6328, -0.8945,  ..., -0.6406,  1.6328,  2.5938]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 1])


tensor([[1.0000],
        [1.5078]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([2])


tensor([0, 0], device='cuda:0')


new_candidates
torch.Size([1, 61])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1]]


new_candidate_logprobs
torch.Size([1])


tensor([-2.1806e+09], device='cuda:0')

F NEIGHBORS PRIOR 39: (15651.859587433) 1 candidates, 1.832017670998539 inference time, 1.8320215809999354 total time
event: message
id: f-39"
data: {"id": "f-39", "level_type": "gather", "duration": 1.832017670998539, "nodes": [{"content": ",", "parent": 0, "aunts": [1], "prob": -2180619520.0}]}




F NEIGHBORS AFTER 39: (15651.859919187) 1 candidates, 1.832017670998539 inference time, 1.8323521649999748 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 60])


tensor([], device='cuda:0', size=(0, 60), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 61])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-2.1806e+09], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[5.9062, 5.3125, 5.7812,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[3.7544e-01, 2.9239e-01, 2.9239e-01,  ..., 9.5787e-20, 7.4599e-20,
         7.4599e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.3754, 0.6678, 0.9602,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([3, 61])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892],
        [    1, 32


carryover_candidate_logprobs
torch.Size([3])


tensor([-2.1806e+09, -2.1806e+09, -2.1806e+09], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[607],
        [310],
        [411]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-0.9797, -1.2297, -1.2297], device='cuda:0')


new_candidates
torch.Size([3, 62])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   310],
    


new_candidate_logprobs
torch.Size([3])


tensor([-2.1806e+09, -2.1806e+09, -2.1806e+09], device='cuda:0')

TOP P PRIOR 39: (15651.906684662) 3 candidates, 1.8791140209996229 inference time, 1.8791197159989679 total time
event: message
id: 39-p
data: {"id": "39-p", "level_type": "sample", "duration": 1.8791140209996229, "nodes": [{"content": "\u2581which", "parent": 0, "prob": -2180619520.0}, {"content": "\u2581of", "parent": 0, "prob": -2180619520.0}, {"content": "\u2581with", "parent": 0, "prob": -2180619520.0}], "finished": []}




TOP P AFTER 39: (15651.907144647) 3 candidates, 1.8791140209996229 inference time, 1.8795784300000378 total time

num_batches


1

infer start: GPU memory used: 7600 MB.

batch_candidates
torch.Size([3, 62])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 62, 32064])

hidden_states[-1]
torch.Size([3, 62, 3072])
infer - after batch run: GPU memory used: 7656 MB.

candidates
torch.Size([3, 62])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   310],
    


candidate_logprobs
torch.Size([3])


tensor([-2.1806e+09, -2.1806e+09, -2.1806e+09], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[ 1.5625e-01, -1.5030e-03, -1.3047e+00,  ..., -1.3125e+00,
         -2.6367e-01,  9.4922e-01],
        [-1.5859e+00,  2.3594e+00,  1.2695e-02,  ...,  4.8633e-01,
          2.3750e+00,  1.5527e-01],
        [-4.1602e-01,  7.8125e-01,  1.5078e+00,  ...,  2.7734e-01,
         -7.2656e-01, -3.9453e-01]], device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.1562, -0.0015, -1.3047,  ..., -1.3125, -0.2637,  0.9492]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[1.0000],
        [1.5547],
        [1.5156]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 62])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2]]


new_candidate_logprobs
torch.Size([1])


tensor([-6.5419e+09], device='cuda:0')

F NEIGHBORS PRIOR 40: (15653.817013388) 1 candidates, 1.9098477650004497 inference time, 1.9098520549996465 total time
event: message
id: f-40"
data: {"id": "f-40", "level_type": "gather", "duration": 1.9098477650004497, "nodes": [{"content": "\u2581which", "parent": 0, "aunts": [1, 2], "prob": -6541858816.0}]}




F NEIGHBORS AFTER 40: (15653.817415794) 1 candidates, 1.9098477650004497 inference time, 1.9102542590007943 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 61])


tensor([], device='cuda:0', size=(0, 61), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 62])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-6.5419e+09], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[7.7188, 0.9141, 6.1250,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[7.1263e-01, 1.5901e-01, 5.1623e-02,  ..., 1.5073e-19, 9.1422e-20,
         3.5802e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.7126, 0.8716, 0.9233,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([3, 62])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607],
    


carryover_candidate_logprobs
torch.Size([3])


tensor([-6.5419e+09, -6.5419e+09, -6.5419e+09], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[ 7805],
        [  338],
        [11624]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-0.3388, -1.8388, -2.9638], device='cuda:0')


new_candidates
torch.Size([3, 63])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607


new_candidate_logprobs
torch.Size([3])


tensor([-6.5419e+09, -6.5419e+09, -6.5419e+09], device='cuda:0')

TOP P PRIOR 40: (15653.863818291) 3 candidates, 1.9566526719991089 inference time, 1.9566568640002515 total time
event: message
id: 40-p
data: {"id": "40-p", "level_type": "sample", "duration": 1.9566526719991089, "nodes": [{"content": "\u2581includes", "parent": 0, "prob": -6541858816.0}, {"content": "\u2581is", "parent": 0, "prob": -6541858816.0}, {"content": "\u2581consists", "parent": 0, "prob": -6541858816.0}], "finished": []}




TOP P AFTER 40: (15653.86426135) 3 candidates, 1.9566526719991089 inference time, 1.9570990699994582 total time

num_batches


1

infer start: GPU memory used: 7656 MB.

batch_candidates
torch.Size([3, 63])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 63, 32064])

hidden_states[-1]
torch.Size([3, 63, 3072])
infer - after batch run: GPU memory used: 7656 MB.

candidates
torch.Size([3, 63])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607


candidate_logprobs
torch.Size([3])


tensor([-6.5419e+09, -6.5419e+09, -6.5419e+09], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[ 0.7578,  1.1641,  1.9219,  ...,  1.0781, -1.7734,  0.4375],
        [-1.3047,  1.9453, -1.8281,  ..., -0.9961,  0.7852,  1.0078],
        [ 0.1416,  0.8438, -0.0593,  ..., -0.4531,  2.3906, -0.2520]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.7578,  1.1641,  1.9219,  ...,  1.0781, -1.7734,  0.4375]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[1.0000],
        [1.4844],
        [1.6719]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 63])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2]]


new_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')

F NEIGHBORS PRIOR 41: (15655.773954934) 1 candidates, 1.9096787779999431 inference time, 1.9096820610011491 total time
event: message
id: f-41"
data: {"id": "f-41", "level_type": "gather", "duration": 1.9096787779999431, "nodes": [{"content": "\u2581includes", "parent": 0, "aunts": [1, 2], "prob": -19625576448.0}]}




F NEIGHBORS AFTER 41: (15655.774282297) 1 candidates, 1.9096787779999431 inference time, 1.9100085940008285 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 62])


tensor([], device='cuda:0', size=(0, 62), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 63])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 6.0938, -2.1094, -0.2227,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.5439e-01, 4.1933e-02, 2.6807e-03,  ..., 1.2651e-22, 3.1988e-23,
         8.0878e-24]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9544, 0.9963, 0.9990,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 63])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[1019]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0467], device='cuda:0')


new_candidates
torch.Size([1, 64])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')

TOP P PRIOR 41: (15655.813840944) 1 candidates, 1.949564543001543 inference time, 1.949568617001205 total time
event: message
id: 41-p
data: {"id": "41-p", "level_type": "sample", "duration": 1.949564543001543, "nodes": [{"content": "\u2581Pro", "parent": 0, "prob": -19625576448.0}], "finished": []}




TOP P AFTER 41: (15655.81419794) 1 candidates, 1.949564543001543 inference time, 1.949924208000084 total time

num_batches


1

infer start: GPU memory used: 7656 MB.

batch_candidates
torch.Size([1, 64])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 64, 32064])

hidden_states[-1]
torch.Size([1, 64, 3072])
infer - after batch run: GPU memory used: 7532 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 63])


tensor([], device='cuda:0', size=(0, 63), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 64])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 3.9531, -1.0312,  2.2812,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[1.0000e+00, 1.4450e-07, 2.5110e-08,  ..., 1.7588e-25, 1.0668e-25,
         1.0668e-25]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 64])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[2657]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-1.1921e-07], device='cuda:0')


new_candidates
torch.Size([1, 65])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')

TOP P PRIOR 42: (15657.582861798) 1 candidates, 1.7686505680012488 inference time, 1.768654374000107 total time
event: message
id: 42-p
data: {"id": "42-p", "level_type": "sample", "duration": 1.7686505680012488, "nodes": [{"content": "xim", "parent": 0, "prob": -19625576448.0}], "finished": []}




TOP P AFTER 42: (15657.5831759) 1 candidates, 1.7686505680012488 inference time, 1.7689673770000809 total time

num_batches


1

infer start: GPU memory used: 7532 MB.

batch_candidates
torch.Size([1, 65])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 65, 32064])

hidden_states[-1]
torch.Size([1, 65, 3072])
infer - after batch run: GPU memory used: 7532 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 64])


tensor([], device='cuda:0', size=(0, 64), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 65])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 2.9844,  0.7969, -2.4844,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[1.0000e+00, 2.2159e-08, 1.0467e-08,  ..., 8.3079e-26, 3.0563e-26,
         2.6972e-26]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 65])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[29874]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


new_candidates
torch.Size([1, 66])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')

TOP P PRIOR 43: (15659.398585268) 1 candidates, 1.8153959860010218 inference time, 1.8153994500007684 total time
event: message
id: 43-p
data: {"id": "43-p", "level_type": "sample", "duration": 1.8153959860010218, "nodes": [{"content": "a", "parent": 0, "prob": -19625576448.0}], "finished": []}




TOP P AFTER 43: (15659.398901387) 1 candidates, 1.8153959860010218 inference time, 1.8157145340010175 total time

num_batches


1

infer start: GPU memory used: 7532 MB.

batch_candidates
torch.Size([1, 66])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 66, 32064])

hidden_states[-1]
torch.Size([1, 66, 3072])
infer - after batch run: GPU memory used: 7532 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 65])


tensor([], device='cuda:0', size=(0, 65), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 66])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 9.1250, -4.4688, -4.8750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9995e-01, 3.5356e-05, 1.4738e-05,  ..., 2.1005e-26, 1.8537e-26,
         9.9219e-27]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 66])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[2895]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-5.3646e-05], device='cuda:0')


new_candidates
torch.Size([1, 67])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')

TOP P PRIOR 44: (15661.215577658) 1 candidates, 1.8166628550006862 inference time, 1.8166664929995022 total time
event: message
id: 44-p
data: {"id": "44-p", "level_type": "sample", "duration": 1.8166628550006862, "nodes": [{"content": "\u2581Cent", "parent": 0, "prob": -19625576448.0}], "finished": []}




TOP P AFTER 44: (15661.215917839) 1 candidates, 1.8166628550006862 inference time, 1.8170056079998176 total time

num_batches


1

infer start: GPU memory used: 7532 MB.

batch_candidates
torch.Size([1, 67])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 67, 32064])

hidden_states[-1]
torch.Size([1, 67, 3072])
infer - after batch run: GPU memory used: 7532 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 66])


tensor([], device='cuda:0', size=(0, 66), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 67])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 5.9375, -2.8438, -0.5977,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[1.0000e+00, 1.9947e-06, 1.1861e-08,  ..., 5.0390e-26, 2.3803e-26,
         2.1006e-26]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 67])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[29874]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-2.0266e-06], device='cuda:0')


new_candidates
torch.Size([1, 68])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')

TOP P PRIOR 45: (15663.037207863) 1 candidates, 1.8212758249992476 inference time, 1.821280870000919 total time
event: message
id: 45-p
data: {"id": "45-p", "level_type": "sample", "duration": 1.8212758249992476, "nodes": [{"content": "a", "parent": 0, "prob": -19625576448.0}], "finished": []}




TOP P AFTER 45: (15663.037575949) 1 candidates, 1.8212758249992476 inference time, 1.8216482709995034 total time

num_batches


1

infer start: GPU memory used: 7532 MB.

batch_candidates
torch.Size([1, 68])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 68, 32064])

hidden_states[-1]
torch.Size([1, 68, 3072])
infer - after batch run: GPU memory used: 7532 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 67])


tensor([], device='cuda:0', size=(0, 67), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 68])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 4.0938, -5.4062, -3.4375,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9994e-01, 5.8291e-05, 1.0676e-06,  ..., 2.2139e-27, 7.1873e-28,
         2.0592e-28]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 68])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[5338]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-6.1514e-05], device='cuda:0')


new_candidates
torch.Size([1, 69])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')

TOP P PRIOR 46: (15664.870494843) 1 candidates, 1.8328966310000396 inference time, 1.8329028539992578 total time
event: message
id: 46-p
data: {"id": "46-p", "level_type": "sample", "duration": 1.8328966310000396, "nodes": [{"content": "uri", "parent": 0, "prob": -19625576448.0}], "finished": []}




TOP P AFTER 46: (15664.870958514) 1 candidates, 1.8328966310000396 inference time, 1.8333650120002858 total time

num_batches


1

infer start: GPU memory used: 7532 MB.

batch_candidates
torch.Size([1, 69])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 69, 32064])

hidden_states[-1]
torch.Size([1, 69, 3072])
infer - after batch run: GPU memory used: 7538 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 68])


tensor([], device='cuda:0', size=(0, 68), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 69])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1.9626e+10], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[7.6875, 3.3750, 1.9375,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[6.1252e-01, 1.7549e-01, 1.3667e-01,  ..., 7.3819e-20, 6.9346e-20,
         3.2757e-20]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.6125, 0.7880, 0.9247,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([3, 69])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 


carryover_candidate_logprobs
torch.Size([3])


tensor([-1.9626e+10, -1.9626e+10, -1.9626e+10], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[29889],
        [  408],
        [29892]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-0.4902, -1.7402, -1.9902], device='cuda:0')


new_candidates
torch.Size([3, 70])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026, 


new_candidate_logprobs
torch.Size([3])


tensor([-1.9626e+10, -1.9626e+10, -1.9626e+10], device='cuda:0')

TOP P PRIOR 47: (15666.707381134) 3 candidates, 1.836402378999992 inference time, 1.8364063600001828 total time
event: message
id: 47-p
data: {"id": "47-p", "level_type": "sample", "duration": 1.836402378999992, "nodes": [{"content": ".", "parent": 0, "prob": -19625576448.0}, {"content": "\u2581as", "parent": 0, "prob": -19625576448.0}, {"content": ",", "parent": 0, "prob": -19625576448.0}], "finished": []}




TOP P AFTER 47: (15666.70781059) 3 candidates, 1.836402378999992 inference time, 1.836834882000403 total time

num_batches


1

infer start: GPU memory used: 7538 MB.

batch_candidates
torch.Size([3, 70])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 70, 32064])

hidden_states[-1]
torch.Size([3, 70, 3072])
infer - after batch run: GPU memory used: 7676 MB.

candidates
torch.Size([3, 70])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026, 


candidate_logprobs
torch.Size([3])


tensor([-1.9626e+10, -1.9626e+10, -1.9626e+10], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[-0.2363,  1.2031,  0.3945,  ..., -1.0000, -0.8047,  0.2002],
        [-0.6445, -0.0210,  1.4297,  ...,  0.1235,  1.5703, -0.5391],
        [-1.2188, -1.0938,  0.5703,  ...,  0.2109,  0.0259,  1.1094]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.2363,  1.2031,  0.3945,  ..., -1.0000, -0.8047,  0.2002]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[1.0000],
        [1.6875],
        [1.7500]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 70])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2]]


new_candidate_logprobs
torch.Size([1])


tensor([-5.8877e+10], device='cuda:0')

F NEIGHBORS PRIOR 48: (15668.754414919) 1 candidates, 2.0465892470001563 inference time, 2.0465943600011087 total time
event: message
id: f-48"
data: {"id": "f-48", "level_type": "gather", "duration": 2.0465892470001563, "nodes": [{"content": ".", "parent": 0, "aunts": [1, 2], "prob": -58876731392.0}]}




F NEIGHBORS AFTER 48: (15668.75486212) 1 candidates, 2.0465892470001563 inference time, 2.0470410020006966 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 69])


tensor([], device='cuda:0', size=(0, 69), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 70])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-5.8877e+10], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[5.6875, 2.4531, 7.6250,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[2.6942e-01, 2.3776e-01, 8.7466e-02,  ..., 3.1423e-17, 3.5255e-18,
         2.7457e-18]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.2694, 0.5072, 0.5946,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([5, 70])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026, 


carryover_candidate_logprobs
torch.Size([5])


tensor([-5.8877e+10, -5.8877e+10, -5.8877e+10, -5.8877e+10, -5.8877e+10],
       device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[1205],
        [4001],
        [ 960],
        [1152],
        [ 450]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-1.3115, -1.4365, -2.4365, -2.5615, -2.9365], device='cuda:0')


new_candidates
torch.Size([5, 71])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889,
          1205],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278


new_candidate_logprobs
torch.Size([5])


tensor([-5.8877e+10, -5.8877e+10, -5.8877e+10, -5.8877e+10, -5.8877e+10],
       device='cuda:0')

TOP P PRIOR 48: (15668.811103468) 5 candidates, 2.1032783999999083 inference time, 2.1032825210004376 total time
event: message
id: 48-p
data: {"id": "48-p", "level_type": "sample", "duration": 2.1032783999999083, "nodes": [{"content": "\u2581But", "parent": 0, "prob": -58876731392.0}, {"content": "\u2581Since", "parent": 0, "prob": -58876731392.0}, {"content": "\u2581If", "parent": 0, "prob": -58876731392.0}, {"content": "\u2581For", "parent": 0, "prob": -58876731392.0}, {"content": "\u2581The", "parent": 0, "prob": -58876731392.0}], "finished": []}




TOP P AFTER 48: (15668.811547595) 5 candidates, 2.1032783999999083 inference time, 2.103725185001167 total time

num_batches


1

infer start: GPU memory used: 7676 MB.

batch_candidates
torch.Size([5, 71])

batch_candidate_logprobs
torch.Size([5])

batch_logits
torch.Size([5, 71, 32064])

hidden_states[-1]
torch.Size([5, 71, 3072])
infer - after batch run: GPU memory used: 7852 MB.

candidates
torch.Size([5, 71])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889,
          1205],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278


candidate_logprobs
torch.Size([5])


tensor([-5.8877e+10, -5.8877e+10, -5.8877e+10, -5.8877e+10, -5.8877e+10],
       device='cuda:0')


embeddings
torch.Size([5, 3072])


tensor([[ 0.7461,  1.5469,  0.3438,  ..., -0.2119, -0.3438,  0.4316],
        [-0.1475,  1.2500,  0.1235,  ..., -0.2119, -0.8828, -0.6836],
        [ 0.0141,  0.4199,  0.1592,  ..., -0.1377, -0.8086, -0.6758],
        [ 0.6133, -0.0244,  2.1719,  ...,  0.4180, -0.3164,  1.2109],
        [-0.6758,  1.4844, -0.2080,  ..., -0.1943,  0.0605,  0.0304]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([5])


tensor([ True, False, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.7461,  1.5469,  0.3438,  ..., -0.2119, -0.3438,  0.4316]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 1])


tensor([[1.0000],
        [1.2109],
        [1.2500],
        [1.3594],
        [1.4219]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 71])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889,
          1205]], device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2, 3, 4]]


new_candidate_logprobs
torch.Size([1])


tensor([-2.9438e+11], device='cuda:0')

F NEIGHBORS PRIOR 49: (15671.107068304) 1 candidates, 2.295506821001254 inference time, 2.295510722000472 total time
event: message
id: f-49"
data: {"id": "f-49", "level_type": "gather", "duration": 2.295506821001254, "nodes": [{"content": "\u2581But", "parent": 0, "aunts": [1, 2, 3, 4], "prob": -294383648768.0}]}




F NEIGHBORS AFTER 49: (15671.107448485) 1 candidates, 2.295506821001254 inference time, 2.295890186000179 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 70])


tensor([], device='cuda:0', size=(0, 70), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 71])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889,
          1205]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-2.9438e+11], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[5.4375, 4.5000, 4.7500,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[4.5209e-01, 1.4677e-01, 1.1431e-01,  ..., 2.0649e-17, 2.0013e-17,
         6.2975e-18]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.4521, 0.5989, 0.7132,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([5, 71])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889,
          1205],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278


carryover_candidate_logprobs
torch.Size([5])


tensor([-2.9438e+11, -2.9438e+11, -2.9438e+11, -2.9438e+11, -2.9438e+11],
       device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[ 1951],
        [  565],
        [18719],
        [  363],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-0.7939, -1.9189, -2.1689, -2.9189, -3.4189], device='cuda:0')


new_candidates
torch.Size([5, 72])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367,   278,   838,  2026,  2895, 29874,  5338,  1788,
         29892,   607,  7805,  1019,  2657, 29874,  2895, 29874,  5338, 29889,
          1205,  1951],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811,   304,   278, 21438,  5810,  1788,   304, 11563, 29892,
           372,   723,   367


new_candidate_logprobs
torch.Size([5])


tensor([-2.9438e+11, -2.9438e+11, -2.9438e+11, -2.9438e+11, -2.9438e+11],
       device='cuda:0')

TOP P PRIOR 49: (15671.160215533) 5 candidates, 2.3486538409997593 inference time, 2.3486588880005 total time
event: message
id: 49-p
data: {"id": "49-p", "level_type": "sample", "duration": 2.3486538409997593, "nodes": [{"content": "\u2581since", "parent": 0, "prob": -294383648768.0}, {"content": "\u2581if", "parent": 0, "prob": -294383648768.0}, {"content": "\u2581strictly", "parent": 0, "prob": -294383648768.0}, {"content": "\u2581for", "parent": 0, "prob": -294383648768.0}, {"content": "\u2581the", "parent": 0, "prob": -294383648768.0}], "finished": []}




TOP P AFTER 49: (15671.160744312) 5 candidates, 2.3486538409997593 inference time, 2.3491857839999284 total time
event: message
id: END
data: []




