In [1]:
!pip install transformers accelerate optimum nvidia-ml-py torchmetrics



In [2]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json
from sklearn.cluster import KMeans
from torchmetrics.functional import pairwise_cosine_similarity

torch.random.manual_seed(0)

<torch._C.Generator at 0x7fe315056250>

In [3]:


from pynvml import *

def check_gpu(step):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"{step}: GPU memory used: {info.used // 1024**2} MB.")
    
def D(obj, label=None, c=True):
    print()
    if label:
        print(label)
        
    if isinstance(obj, tuple):
        print('tuple size', len(obj), ':')
        if c: # Contents
            display(obj)
    elif isinstance(obj, torch.Tensor) or isinstance(obj, np.ndarray):
        print(obj.shape)
        if c: # Contents
            display(obj)
    else:
        if c: # Contents
            display(obj)
            
def DS(obj, label=None):
    D(obj, label, c=False)
    


In [18]:
class InferenceTensor:
    def __init__(self):
        print('Initializing model...')
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            torch_dtype=torch.bfloat16,
            device_map='auto',
            trust_remote_code=True,
            use_cache=True,
            # attn_implementation='flash_attention_2',
        )
        print('Initializing tokenizer...')
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.eos_token_id = 32007

        self.batch_size = 8
        
    def candidates_generator(self, top_p: float, top_p_decay: float, top_k: float, max_beams: int, max_new_tokens: int, gather_algo: str, prompt: str):
        candidates, candidate_logprobs = self._init_candidates(prompt)
        for level_idx in range(max_new_tokens):
            start = time.perf_counter()
            logits, embeddings = self._infer(candidates, candidate_logprobs)
            
            if candidates.shape[0] > max_beams:
                if gather_algo == 'k_means':
                    candidates, candidate_parents, candidate_aunts, candidate_logprobs, logits = self._k_means(logits, embeddings, candidates, candidate_logprobs, max_beams)
                    inference_duration = time.perf_counter() - start
                    print('K MEANS PRIOR {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
                    yield self._format_gather(level_idx, 'k', candidates, candidate_parents, candidate_aunts, candidate_logprobs, inference_duration)
                    print('K MEANS AFTER {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))

                elif gather_algo == 'farthest_neighbors':
                    candidates, candidate_parents, candidate_aunts, candidate_logprobs, logits = self._farthest_neighbors(logits, embeddings, candidates, candidate_logprobs, max_beams)
                    inference_duration = time.perf_counter() - start
                    print('F NEIGHBORS PRIOR {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
                    yield self._format_gather(level_idx, 'f', candidates, candidate_parents, candidate_aunts, candidate_logprobs, inference_duration)
                    print('F NEIGHBORS AFTER {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))

            logits, candidates, candidate_logprobs, max_beams, finished, finished_parents, finished_logprobs = self._select_finished(logits, candidates, candidate_logprobs, max_beams)

            candidates, candidate_parents, candidate_logprobs = self._top_p(logits, candidates, candidate_logprobs, top_p, top_k)
            inference_duration = time.perf_counter() - start
            print('TOP P PRIOR {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
            yield self._format_top_p(level_idx, candidates, candidate_parents, candidate_logprobs, inference_duration, finished, finished_parents, finished_logprobs)
            print('TOP P AFTER {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
            top_p *= top_p_decay

        yield f"event: message\nid: END\ndata: []\n\n"

    def _format_gather(self, suffix, level_idx, candidates, candidate_parents, candidate_aunts, candidate_logprobs, duration):
        candidate_texts = self.tokenizer.convert_ids_to_tokens(candidates[:, -1], skip_special_tokens=True)
        candidate_dicts = []
        idx = f"{level_idx}-{suffix}"
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i], 'aunts': candidate_aunts[i], 'prob': candidate_logprobs[i].item()})
        data = json.dumps({'id': idx, 'level_type': 'gather', 'duration': duration, 'nodes': candidate_dicts})
        return f"event: message\nid: {idx}\"\ndata: {data}\n\n"

    def _format_top_p(self, level_idx, candidates, candidate_parents, candidate_logprobs, duration, finished, finished_parents, finished_logprobs):
        candidate_texts = self.tokenizer.convert_ids_to_tokens(candidates[:, -1])
        candidate_dicts = []
        idx = f"{level_idx}-p"
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i], 'prob': candidate_logprobs[i].item()})
        
        finished_texts = self.tokenizer.batch_decode(finished, skip_special_tokens=True)
        finished_parents = finished_parents.tolist()
        finished_dicts = []
        for i in range(len(finished_texts)):
            finished_dicts.append({'content': finished_texts[i], 'parent': finished_parents[i], 'prob': finished_logprobs[i].item()})
        
        data = json.dumps({'id': idx, 'level_type': 'sample', 'duration': duration, 'nodes': candidate_dicts, 'finished': finished_dicts})
        return f"event: message\nid: {idx}\ndata: {data}\n\n"


    def _init_candidates(self, text: str):
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
        inputs = self.tokenizer(prompt, return_tensors='pt')
        D(inputs.input_ids, 'input_ids')

        candidates = inputs.input_ids.to(self.device)
        candidate_logprobs = torch.zeros((1), dtype=torch.float32, device=self.device)

        return candidates, candidate_logprobs

    def _k_means(self, logits, embeddings, candidates, candidate_logprobs, max_beams):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        # === CPU ===
        embeddings_np = embeddings.float().numpy(force=True)
        D(embeddings_np, 'embeddings_np')
        k_means = KMeans(n_clusters=min(max_beams, embeddings_np.shape[0]), random_state=0, n_init="auto")
        k_mean_space = k_means.fit_transform(embeddings_np)
        D(k_mean_space, 'k_mean_space')
        k_mean_clusters = k_means.predict(embeddings_np)
        D(k_mean_clusters, 'k_mean_clusters')
        k_mean_logprob_mass = np.log(np.bincount(k_mean_clusters, weights=candidate_logprobs.cpu().exp()))
        D(k_mean_logprob_mass, 'k_mean_logprob_mass')
        closest = np.argmin(k_mean_space, axis=0)
        D(closest, 'closest')
        # === END CPU ===
        
        closest_indices = torch.from_numpy(closest).to(self.device)
        new_candidates = candidates.index_select(0, closest_indices)
        D(new_candidates, 'new_candidates')
        new_candidate_parents = closest_indices.tolist()
        D(new_candidate_parents, 'new_candidate_parents')
        new_candidate_aunts = [torch.nonzero(torch.from_numpy(k_mean_clusters).to(self.device) == i).squeeze(-1).tolist() for i in range(new_candidates.shape[0])]
        D(new_candidate_aunts, 'new_candidate_aunts')
        new_candidate_logprobs = torch.from_numpy(k_mean_logprob_mass).to(self.device)
        D(new_candidate_logprobs, 'new_candidate_logprobs')
        new_candidate_logits = logits.index_select(0, closest_indices)
        
        return new_candidates, new_candidate_parents, new_candidate_aunts, new_candidate_logprobs, new_candidate_logits
        
    def _farthest_neighbors(self, logits, embeddings, candidates, candidate_logprobs, max_beams):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        D(embeddings, 'embeddings')
        
        selected = torch.zeros((candidates.shape[0],), dtype=torch.bool).to(self.device)
        max_prob_idx = candidate_logprobs.argmax()
        selected[max_prob_idx] = 1
        
        D(selected, 'selected')
        
        for idx in range(min(max_beams - 1, candidates.shape[0])):
            selected_embeddings = embeddings[selected]
            D(selected_embeddings, 'selected_embeddings')
            # Add 2 because bfloat16 on cuda can have imprecision and we need 0 to be lower than every
            # cosine distance
            distances = torch.add(2, pairwise_cosine_similarity(embeddings, selected_embeddings), alpha=-1)
            D(distances, 'distances')
            min_distances = torch.min(distances, dim=1).values
            D(min_distances, 'min_distances')
            min_remaining_distances = min_distances * ~selected
            D(min_remaining_distances, 'min_remaining_distances')
            next_selected = min_remaining_distances.argmax(dim=0)
            selected[next_selected] = 1
            D(selected, 'selected (end of loop)')
            
        # We have all the candidates that are selected to move forward. Figure out which probability mass
        # to assign where.
        selected_embeddings = embeddings[selected]
        D(selected_embeddings, 'selected_embeddings')
        # Add 2 because bfloat16 on cuda can have imprecision and we need 0 to be lower than every
        # cosine distance
        distances = torch.add(2, pairwise_cosine_similarity(embeddings, selected_embeddings), alpha=-1)
        D(distances, 'distances')
        
        closest_per_candidate = distances.argmin(dim=1)
        D(closest_per_candidate, 'closest_per_candidate')
        
        new_candidates = candidates[selected]
        D(new_candidates, 'new_candidates')
        new_candidate_parents = torch.arange(candidates.shape[0]).to(self.device)[selected].tolist()
        D(new_candidate_parents, 'new_candidate_parents')
        new_candidate_aunts = [list(filter(lambda x: x != i, torch.nonzero(closest_per_candidate == i).squeeze(-1).tolist())) \
                       for i in range(new_candidates.shape[0])]
        D(new_candidate_aunts, 'new_candidate_aunts')
        new_candidate_logprobs = torch.zeros((new_candidates.shape[0],)).to(self.device)
        new_candidate_logprobs.index_add_(0, closest_per_candidate, candidate_logprobs)
        D(new_candidate_logprobs, 'new_candidate_logprobs')
        new_candidate_logits = logits[selected]
        
        return new_candidates, new_candidate_parents, new_candidate_aunts, new_candidate_logprobs, new_candidate_logits

    
    def _select_finished(self, logits, candidates, candidate_logprobs, max_beams):
        finished_mask = candidates[:,-1] == self.eos_token_id
        unfinished_mask = ~finished_mask
        D(finished_mask, 'finished_mask')
        
        new_logits = logits[unfinished_mask]
        new_candidates = candidates[unfinished_mask]
        new_candidate_logprobs = candidate_logprobs[unfinished_mask]
        new_max_beams = max_beams - finished_mask.sum()
        
        finished = candidates[finished_mask][:,:-1] # Remove the EOS token
        D(finished, 'finished')
        finished_parents = torch.arange(candidates.shape[0], device=self.device)[finished_mask]
        D(finished_parents, 'finished_parents')
        finished_logprobs = candidate_logprobs[finished_mask]
        D(finished_logprobs, 'finished_logprobs')
        
        return new_logits, new_candidates, new_candidate_logprobs, new_max_beams, finished, finished_parents, finished_logprobs
    

    def _top_p(self, logits, candidates, candidate_logprobs, top_p, top_k):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        
        last_tok_logits = logits[:, -1, :]
        D(last_tok_logits, 'last_tok_logits')

        sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
        DS(sorted_logits, 'sorted_logits')
        DS(sorted_indices, 'sorted_indices')
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        D(sorted_probs, 'sorted_probs')
        display(sorted_probs.sum(dim=1))
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        D(cum_probs, 'cum_probs')

        # Create tensor of bools indicating which indices are cumulatively less than top_p
        keep_indices = cum_probs < top_p

        # Keep the last element that went over top_p
        keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
        keep_indices[:, 0] = 1  # Always keep the first element
        D(keep_indices, 'keep_indices')

        # Don't keep any indices that are greater than top_k
        keep_indices[:, top_k:] = 0
        D(keep_indices, 'keep_indices after top_k')

        new_candidate_parents = keep_indices.nonzero()[:, 0]
        D(new_candidate_parents, 'new_candidate_parents')

        # OPTIM: Potential optimization -- have a fixed tensor of size (max_candidates, max_tokens) and copy this into that (batch-aware).
        # OPTIM: consider which of these operations can be done in-place to prevent new allocations?
        carryover_candidates = candidates.index_select(0, new_candidate_parents)
        D(carryover_candidates, 'carryover_candidates')

        # Similar code could be used to trace entire origin of sequence. For now since server just traces parent of the preceding generation, not needed
        # carryover_candidate_parents = candidate_parents.index_select(0, carryover_candidate_indices)  # Not strictly necessary since 1d
        # D(carryover_candidate_parents, 'carryover_candidate_parents')

        carryover_candidate_logprobs = candidate_logprobs.index_select(0, new_candidate_parents)  # Not strictly necessary since 1d
        D(carryover_candidate_logprobs, 'carryover_candidate_logprobs')

        new_candidate_toks = sorted_indices[keep_indices].unsqueeze(1)
        D(new_candidate_toks, 'new_candidate_toks')
        new_candidate_tok_logprobs = sorted_probs[keep_indices].log()
        D(new_candidate_tok_logprobs, 'new_candidate_tok_logprobs')

        new_candidates = torch.cat([carryover_candidates, new_candidate_toks], dim=1)
        D(new_candidates, 'new_candidates')
        new_candidate_logprobs = carryover_candidate_logprobs.add_(new_candidate_tok_logprobs)
        D(new_candidate_logprobs, 'new_candidate_logprobs')

        return new_candidates, new_candidate_parents.tolist(), new_candidate_logprobs


    def _infer(self, candidates, candidate_logprobs):
        with torch.inference_mode():
            num_batches = (candidates.shape[0] + self.batch_size - 1) // self.batch_size  # Round up to nearest whole number of batches
            D(num_batches, 'num_batches')

            check_gpu('infer start')
            output_logits_list = []
            output_embeddings_list = []
            for i in range(0, num_batches, 1):
                batch_candidates = candidates[i * self.batch_size:(i + 1) * self.batch_size]
                DS(batch_candidates, 'batch_candidates')
                batch_candidate_logprobs = candidate_logprobs[i * self.batch_size:(i + 1) * self.batch_size]
                DS(batch_candidate_logprobs, 'batch_candidate_logprobs')

                batch_outputs = self.model(input_ids=batch_candidates, output_hidden_states=True)
                DS(batch_outputs.logits, 'batch_logits')
                DS(batch_outputs.hidden_states[-1], 'hidden_states[-1]')

                output_logits_list.append(batch_outputs.logits)
                output_embeddings_list.append(batch_outputs.hidden_states[-1][:,-1,:])
                check_gpu('infer - after batch run')

            output_logits = torch.cat(output_logits_list, dim=0)
            output_embeddings = torch.cat(output_embeddings_list, dim=0)
            
            return output_logits, output_embeddings
        
it = InferenceTensor()

for x in it.candidates_generator(top_p=0.9, top_p_decay=1.0, top_k=5, max_beams=3, max_new_tokens=20, gather_algo='farthest_neighbors', prompt='What is the closest star to the Earth? Answer in 5 words or less.'):
    print(x)
    print()
    print('====================================')
    print()

Initializing model...




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Initializing tokenizer...

input_ids
torch.Size([1, 22])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001]])


num_batches


1

infer start: GPU memory used: 7914 MB.

batch_candidates
torch.Size([1, 22])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 22, 32064])

hidden_states[-1]
torch.Size([1, 22, 3072])
infer - after batch run: GPU memory used: 7934 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 21])


tensor([], device='cuda:0', size=(0, 21), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 22])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[11.4375,  9.4375,  7.8125,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[8.3999e-01, 1.4597e-01, 1.3577e-02,  ..., 4.0964e-23, 4.0964e-23,
         1.2370e-24]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.8400, 0.9860, 0.9995,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 22])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([0., 0.], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[ 450],
        [1019]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.1744, -1.9244], device='cuda:0')


new_candidates
torch.Size([2, 23])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-0.1744, -1.9244], device='cuda:0')

TOP P PRIOR 0: (15099.106974232) 2 candidates, 0.2207208370000444 inference time, 0.22072459899936803 total time
event: message
id: 0-p
data: {"id": "0-p", "level_type": "sample", "duration": 0.2207208370000444, "nodes": [{"content": "\u2581The", "parent": 0, "prob": -0.1743605136871338}, {"content": "\u2581Pro", "parent": 0, "prob": -1.9243606328964233}], "finished": []}




TOP P AFTER 0: (15099.10732129) 2 candidates, 0.2207208370000444 inference time, 0.22107067899924004 total time

num_batches


1

infer start: GPU memory used: 7934 MB.

batch_candidates
torch.Size([2, 23])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 23, 32064])

hidden_states[-1]
torch.Size([2, 23, 3072])
infer - after batch run: GPU memory used: 7960 MB.

finished_mask
torch.Size([2])


tensor([False, False], device='cuda:0')


finished
torch.Size([0, 22])


tensor([], device='cuda:0', size=(0, 22), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([2, 23])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-0.1744, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[4.8750, 6.0938, 2.9844,  ..., 0.0000, 0.0000, 0.0000],
        [1.0000, 0.5977, 1.7891,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.8566e-01, 1.4060e-02, 1.2164e-04,  ..., 1.8638e-24, 9.9760e-25,
         1.7336e-25],
        [1.0000e+00, 3.6535e-08, 3.8507e-09,  ..., 3.9244e-26, 1.1244e-26,
         9.9224e-27]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.9857, 0.9997, 0.9998,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


carryover_candidates
torch.Size([2, 23])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-0.1744, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[21438],
        [ 2657]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.0144,  0.0000], device='cuda:0')


new_candidates
torch.Size([2, 24])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-0.1888, -1.9244], device='cuda:0')

TOP P PRIOR 1: (15099.335546255) 2 candidates, 0.22821116799968877 inference time, 0.22821650900004897 total time
event: message
id: 1-p
data: {"id": "1-p", "level_type": "sample", "duration": 0.22821116799968877, "nodes": [{"content": "\u2581closest", "parent": 0, "prob": -0.1888073980808258}, {"content": "xim", "parent": 1, "prob": -1.9243606328964233}], "finished": []}




TOP P AFTER 1: (15099.335954134) 2 candidates, 0.22821116799968877 inference time, 0.2286232700007531 total time

num_batches


1

infer start: GPU memory used: 7960 MB.

batch_candidates
torch.Size([2, 24])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 24, 32064])

hidden_states[-1]
torch.Size([2, 24, 3072])
infer - after batch run: GPU memory used: 7960 MB.

finished_mask
torch.Size([2])


tensor([False, False], device='cuda:0')


finished
torch.Size([0, 23])


tensor([], device='cuda:0', size=(0, 23), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([2, 24])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-0.1888, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 4.0625, -3.8125, -1.6406,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.3750, -1.3984, -2.5938,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.9960e-01, 3.7998e-04, 6.1418e-06,  ..., 1.5515e-25, 1.0663e-25,
         4.4452e-26],
        [9.9999e-01, 6.1442e-06, 1.8554e-07,  ..., 9.1107e-23, 4.8766e-23,
         1.0121e-24]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.9996, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


carryover_candidates
torch.Size([2, 24])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-0.1888, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[ 5810],
        [29874]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-3.9544e-04, -6.4373e-06], device='cuda:0')


new_candidates
torch.Size([2, 25])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-0.1892, -1.9244], device='cuda:0')

TOP P PRIOR 2: (15099.565083536) 2 candidates, 0.2291099750000285 inference time, 0.22911383499922522 total time
event: message
id: 2-p
data: {"id": "2-p", "level_type": "sample", "duration": 0.2291099750000285, "nodes": [{"content": "\u2581star", "parent": 0, "prob": -0.18920283019542694}, {"content": "a", "parent": 1, "prob": -1.924367070198059}], "finished": []}




TOP P AFTER 2: (15099.565450923) 2 candidates, 0.2291099750000285 inference time, 0.22948022300079174 total time

num_batches


1

infer start: GPU memory used: 7960 MB.

batch_candidates
torch.Size([2, 25])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 25, 32064])

hidden_states[-1]
torch.Size([2, 25, 3072])
infer - after batch run: GPU memory used: 7968 MB.

finished_mask
torch.Size([2])


tensor([False, False], device='cuda:0')


finished
torch.Size([0, 24])


tensor([], device='cuda:0', size=(0, 24), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([2, 25])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-0.1892, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 5.0312, -2.1250, -1.3906,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.9375, -3.5312, -0.2754,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[8.4862e-01, 1.4747e-01, 3.0606e-03,  ..., 1.1249e-22, 9.9275e-23,
         8.1490e-24],
        [9.9999e-01, 3.2887e-06, 1.5535e-06,  ..., 3.2858e-25, 1.7588e-25,
         5.7099e-26]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.8486, 0.9961, 0.9991,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 1], device='cuda:0')


carryover_candidates
torch.Size([3, 25])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-0.1892, -0.1892, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[ 304],
        [ 338],
        [2895]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.6415e-01, -1.9141e+00, -5.4836e-06], device='cuda:0')


new_candidates
torch.Size([3, 26])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-0.3533, -2.1033, -1.9244], device='cuda:0')

TOP P PRIOR 3: (15099.79538959) 3 candidates, 0.2299259939991316 inference time, 0.22993001799841295 total time
event: message
id: 3-p
data: {"id": "3-p", "level_type": "sample", "duration": 0.2299259939991316, "nodes": [{"content": "\u2581to", "parent": 0, "prob": -0.35334891080856323}, {"content": "\u2581is", "parent": 0, "prob": -2.103348970413208}, {"content": "\u2581Cent", "parent": 1, "prob": -1.9243725538253784}], "finished": []}




TOP P AFTER 3: (15099.795836057) 3 candidates, 0.2299259939991316 inference time, 0.2303755529992486 total time

num_batches


1

infer start: GPU memory used: 7968 MB.

batch_candidates
torch.Size([3, 26])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 26, 32064])

hidden_states[-1]
torch.Size([3, 26, 3072])
infer - after batch run: GPU memory used: 8008 MB.

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 25])


tensor([], device='cuda:0', size=(0, 25), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 26])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-0.3533, -2.1033, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 8.6250,  4.4062,  5.5000,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.3438,  2.0469,  1.5781,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.5000, -0.5469,  2.4531,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.9532e-01, 4.6092e-03, 5.8022e-05,  ..., 1.6941e-22, 1.5915e-22,
         6.2324e-23],
        [4.8121e-01, 4.8121e-01, 2.7148e-02,  ..., 1.2813e-21, 1.7340e-22,
         9.2814e-23],
        [9.9994e-01, 4.5397e-05, 7.8888e-06,  ..., 1.3971e-23, 4.0027e-24,
         1.7587e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.9953, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.4812, 0.9624, 0.9896,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 1, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([4, 26])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([4])


tensor([-0.3533, -2.1033, -2.1033, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([4, 1])


tensor([[11563],
        [  278],
        [ 1019],
        [29874]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-4.6923e-03, -7.3145e-01, -7.3145e-01, -6.1633e-05], device='cuda:0')


new_candidates
torch.Size([4, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


new_candidate_logprobs
torch.Size([4])


tensor([-0.3580, -2.8348, -2.8348, -1.9244], device='cuda:0')

TOP P PRIOR 4: (15100.164542516) 4 candidates, 0.3686762090001139 inference time, 0.3686968040001375 total time
event: message
id: 4-p
data: {"id": "4-p", "level_type": "sample", "duration": 0.3686762090001139, "nodes": [{"content": "\u2581Earth", "parent": 0, "prob": -0.3580411970615387}, {"content": "\u2581the", "parent": 1, "prob": -2.834794282913208}, {"content": "\u2581Pro", "parent": 1, "prob": -2.834794282913208}, {"content": "a", "parent": 2, "prob": -1.9244341850280762}], "finished": []}




TOP P AFTER 4: (15100.164976086) 4 candidates, 0.3686762090001139 inference time, 0.369128788001035 total time

num_batches


1

infer start: GPU memory used: 8028 MB.

batch_candidates
torch.Size([4, 27])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 27, 32064])

hidden_states[-1]
torch.Size([4, 27, 3072])
infer - after batch run: GPU memory used: 8042 MB.

candidates
torch.Size([4, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


candidate_logprobs
torch.Size([4])


tensor([-0.3580, -2.8348, -2.8348, -1.9244], device='cuda:0')


embeddings
torch.Size([4, 3072])


tensor([[ 0.6406, -0.9219,  1.1172,  ...,  1.5547, -1.2734,  0.4902],
        [ 1.4297,  0.7031,  1.1250,  ...,  0.9375, -0.9375,  0.4629],
        [-2.2188,  0.7305, -0.4473,  ..., -1.6953, -0.0540,  0.2793],
        [ 0.7383, -1.8281,  1.8594,  ..., -2.5469,  0.9141, -0.3320]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([ True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.6406, -0.9219,  1.1172,  ...,  1.5547, -1.2734,  0.4902]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[1.0000],
        [1.7266],
        [1.8984],
        [1.9297]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.0000, 1.7266, 1.8984, 1.9297], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([0.0000, 1.7266, 1.8984, 1.9297], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True, False, False,  True], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[ 0.6406, -0.9219,  1.1172,  ...,  1.5547, -1.2734,  0.4902],
        [ 0.7383, -1.8281,  1.8594,  ..., -2.5469,  0.9141, -0.3320]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 2])


tensor([[1.0000, 1.9297],
        [1.7266, 1.9297],
        [1.8984, 1.9531],
        [1.9297, 0.9922]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.0000, 1.7266, 1.8984, 0.9922], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([0.0000, 1.7266, 1.8984, 0.0000], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True, False,  True,  True], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[ 0.6406, -0.9219,  1.1172,  ...,  1.5547, -1.2734,  0.4902],
        [-2.2188,  0.7305, -0.4473,  ..., -1.6953, -0.0540,  0.2793],
        [ 0.7383, -1.8281,  1.8594,  ..., -2.5469,  0.9141, -0.3320]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 3])


tensor([[1.0000, 1.8984, 1.9297],
        [1.7266, 1.8594, 1.9297],
        [1.8984, 1.0000, 1.9531],
        [1.9297, 1.9531, 0.9922]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([0, 0, 1, 2], device='cuda:0')


new_candidates
torch.Size([3, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


new_candidate_parents


[0, 2, 3]


new_candidate_aunts


[[1], [2], [3]]


new_candidate_logprobs
torch.Size([3])


tensor([-3.1928, -2.8348, -1.9244], device='cuda:0')

F NEIGHBORS PRIOR 5: (15100.562300037) 3 candidates, 0.39731048899921007 inference time, 0.3973151969985338 total time
event: message
id: f-5"
data: {"id": "f-5", "level_type": "gather", "duration": 0.39731048899921007, "nodes": [{"content": "\u2581Earth", "parent": 0, "aunts": [1], "prob": -3.192835569381714}, {"content": "\u2581Pro", "parent": 2, "aunts": [2], "prob": -2.834794282913208}, {"content": "a", "parent": 3, "aunts": [3], "prob": -1.9244341850280762}]}




F NEIGHBORS AFTER 5: (15100.562806428) 3 candidates, 0.39731048899921007 inference time, 0.39782035699863627 total time

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 26])


tensor([], device='cuda:0', size=(0, 26), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-3.1928, -2.8348, -1.9244], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 4.8438,  1.5391, -1.1562,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.2422,  1.3594,  0.7656,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.5625, -3.9688,  0.3242,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.9806e-01, 1.7003e-03, 1.3957e-04,  ..., 6.5870e-24, 6.1880e-24,
         4.5272e-24],
        [1.0000e+00, 1.2752e-07, 4.5991e-10,  ..., 9.9224e-27, 6.0183e-27,
         4.3596e-28],
        [9.9994e-01, 5.1442e-05, 7.8888e-06,  ..., 2.0592e-28, 1.2490e-28,
         8.5840e-29]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.9981, 0.9998, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([3, 27])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-3.1928, -2.8348, -1.9244], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[ 338],
        [2657],
        [5338]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.9372e-03, -1.1921e-07, -6.2229e-05], device='cuda:0')


new_candidates
torch.Size([3, 28])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-3.1948, -2.8348, -1.9245], device='cuda:0')

TOP P PRIOR 5: (15100.609668954) 3 candidates, 0.4446796529991843 inference time, 0.4446839379997982 total time
event: message
id: 5-p
data: {"id": "5-p", "level_type": "sample", "duration": 0.4446796529991843, "nodes": [{"content": "\u2581is", "parent": 0, "prob": -3.194772720336914}, {"content": "xim", "parent": 1, "prob": -2.834794521331787}, {"content": "uri", "parent": 2, "prob": -1.9244964122772217}], "finished": []}




TOP P AFTER 5: (15100.61009946) 3 candidates, 0.4446796529991843 inference time, 0.44511345899991284 total time

num_batches


1

infer start: GPU memory used: 8042 MB.

batch_candidates
torch.Size([3, 28])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 28, 32064])

hidden_states[-1]
torch.Size([3, 28, 3072])
infer - after batch run: GPU memory used: 8042 MB.

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 27])


tensor([], device='cuda:0', size=(0, 27), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 28])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338]],
       device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-3.1948, -2.8348, -1.9245], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 6.0000,  2.2031,  1.1953,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.7656,  0.0420, -1.2969,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.6250, -0.1963,  8.1250,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[5.2175e-01, 4.6044e-01, 1.3904e-02,  ..., 3.0997e-22, 3.0997e-22,
         1.0063e-22],
        [1.0000e+00, 1.1861e-08, 1.6052e-09,  ..., 1.9930e-25, 4.4469e-26,
         3.0563e-26],
        [9.0407e-01, 9.5289e-02, 2.3620e-04,  ..., 3.6190e-24, 2.1950e-24,
         1.7095e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.5217, 0.9822, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9041, 0.9994, 0.9996,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([4, 28])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([4])


tensor([-3.1948, -3.1948, -2.8348, -1.9245], device='cuda:0')


new_candidate_toks
torch.Size([4, 1])


tensor([[  278],
        [ 1019],
        [29874],
        [  338]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-0.6506, -0.7756,  0.0000, -0.1008], device='cuda:0')


new_candidates
torch.Size([4, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([4])


tensor([-3.8453, -3.9703, -2.8348, -2.0253], device='cuda:0')

TOP P PRIOR 6: (15100.988347777) 4 candidates, 0.37823333400046977 inference time, 0.3782384930000262 total time
event: message
id: 6-p
data: {"id": "6-p", "level_type": "sample", "duration": 0.37823333400046977, "nodes": [{"content": "\u2581the", "parent": 0, "prob": -3.8453469276428223}, {"content": "\u2581Pro", "parent": 0, "prob": -3.9703469276428223}, {"content": "a", "parent": 1, "prob": -2.834794521331787}, {"content": "\u2581is", "parent": 2, "prob": -2.025339365005493}], "finished": []}




TOP P AFTER 6: (15100.988777299) 4 candidates, 0.37823333400046977 inference time, 0.3786663390001195 total time

num_batches


1

infer start: GPU memory used: 8042 MB.

batch_candidates
torch.Size([4, 29])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 29, 32064])

hidden_states[-1]
torch.Size([4, 29, 3072])
infer - after batch run: GPU memory used: 8068 MB.

candidates
torch.Size([4, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


candidate_logprobs
torch.Size([4])


tensor([-3.8453, -3.9703, -2.8348, -2.0253], device='cuda:0')


embeddings
torch.Size([4, 3072])


tensor([[ 1.3750,  0.6758,  1.4531,  ...,  1.0156, -0.7930,  0.9414],
        [-2.0781,  0.7969, -0.2637,  ..., -1.6641,  0.0510,  0.3867],
        [-0.8594,  0.9336, -0.1465,  ..., -1.2500, -1.3047, -1.0391],
        [-0.6719,  0.5938,  0.6641,  ...,  0.8320, -0.2178, -1.0312]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([False, False, False,  True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.6719,  0.5938,  0.6641,  ...,  0.8320, -0.2178, -1.0312]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[1.6797],
        [1.8516],
        [1.7812],
        [1.0000]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.6797, 1.8516, 1.7812, 1.0000], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([1.6797, 1.8516, 1.7812, 0.0000], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([False,  True, False,  True], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-2.0781,  0.7969, -0.2637,  ..., -1.6641,  0.0510,  0.3867],
        [-0.6719,  0.5938,  0.6641,  ...,  0.8320, -0.2178, -1.0312]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 2])


tensor([[1.8516, 1.6797],
        [1.0000, 1.8516],
        [1.8203, 1.7812],
        [1.8516, 1.0000]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.6797, 1.0000, 1.7812, 1.0000], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([1.6797, 0.0000, 1.7812, 0.0000], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([False,  True,  True,  True], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[-2.0781,  0.7969, -0.2637,  ..., -1.6641,  0.0510,  0.3867],
        [-0.8594,  0.9336, -0.1465,  ..., -1.2500, -1.3047, -1.0391],
        [-0.6719,  0.5938,  0.6641,  ...,  0.8320, -0.2178, -1.0312]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 3])


tensor([[1.8516, 1.7969, 1.6797],
        [1.0000, 1.8203, 1.8516],
        [1.8203, 1.0000, 1.7812],
        [1.8516, 1.7812, 1.0000]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([2, 0, 1, 2], device='cuda:0')


new_candidates
torch.Size([3, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


new_candidate_parents


[1, 2, 3]


new_candidate_aunts


[[1], [2], [0, 3]]


new_candidate_logprobs
torch.Size([3])


tensor([-3.9703, -2.8348, -5.8707], device='cuda:0')

F NEIGHBORS PRIOR 7: (15101.384811508) 3 candidates, 0.39602081600060046 inference time, 0.3960256739992474 total time
event: message
id: f-7"
data: {"id": "f-7", "level_type": "gather", "duration": 0.39602081600060046, "nodes": [{"content": "\u2581Pro", "parent": 1, "aunts": [1], "prob": -3.9703469276428223}, {"content": "a", "parent": 2, "aunts": [2], "prob": -2.834794521331787}, {"content": "\u2581is", "parent": 3, "aunts": [0, 3], "prob": -5.8706865310668945}]}




F NEIGHBORS AFTER 7: (15101.385359331) 3 candidates, 0.39602081600060046 inference time, 0.3965732710003067 total time

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 28])


tensor([], device='cuda:0', size=(0, 28), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-3.9703, -2.8348, -5.8707], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 1.7734,  1.9688,  1.1250,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.2812, -1.7500,  0.4062,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.1875,  0.4727,  1.8750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[1.0000e+00, 1.4450e-07, 4.5991e-10,  ..., 8.7565e-27, 4.1363e-27,
         7.1878e-28],
        [1.0000e+00, 1.7603e-06, 4.4508e-07,  ..., 4.4469e-26, 4.4469e-26,
         2.6972e-26],
        [4.9547e-01, 4.9547e-01, 7.0675e-03,  ..., 8.5175e-22, 5.8540e-22,
         2.0231e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.4955, 0.9909, 0.9980,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 1, 2, 2], device='cuda:0')


carryover_candidates
torch.Size([4, 29])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([4])


tensor([-3.9703, -2.8348, -5.8707, -5.8707], device='cuda:0')


new_candidate_toks
torch.Size([4, 1])


tensor([[ 2657],
        [ 2895],
        [  278],
        [21438]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-1.1921e-07, -2.3842e-06, -7.0225e-01, -7.0225e-01], device='cuda:0')


new_candidates
torch.Size([4, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338, 21438]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([4])


tensor([-3.9703, -2.8348, -6.5729, -6.5729], device='cuda:0')

TOP P PRIOR 7: (15101.431143243) 4 candidates, 0.4423533819990553 inference time, 0.44235743900026137 total time
event: message
id: 7-p
data: {"id": "7-p", "level_type": "sample", "duration": 0.4423533819990553, "nodes": [{"content": "xim", "parent": 0, "prob": -3.9703471660614014}, {"content": "\u2581Cent", "parent": 1, "prob": -2.834796905517578}, {"content": "\u2581the", "parent": 2, "prob": -6.572933197021484}, {"content": "\u2581closest", "parent": 2, "prob": -6.572933197021484}], "finished": []}




TOP P AFTER 7: (15101.431593555) 4 candidates, 0.4423533819990553 inference time, 0.44280683400029375 total time

num_batches


1

infer start: GPU memory used: 8068 MB.

batch_candidates
torch.Size([4, 30])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 30, 32064])

hidden_states[-1]
torch.Size([4, 30, 3072])
infer - after batch run: GPU memory used: 8068 MB.

candidates
torch.Size([4, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338, 21438]],
       device='cuda:0')


candidate_logprobs
torch.Size([4])


tensor([-3.9703, -2.8348, -6.5729, -6.5729], device='cuda:0')


embeddings
torch.Size([4, 3072])


tensor([[-1.7344,  0.0605,  1.5000,  ..., -3.1875, -1.2266, -0.0317],
        [-0.1494, -0.2168,  1.4688,  ..., -1.5000,  0.6289,  0.2715],
        [-1.1875,  1.0938,  0.6406,  ..., -0.6055,  0.1279,  0.6992],
        [-0.0466, -0.2246,  2.1406,  ..., -0.5820,  0.0253,  0.4961]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([False,  True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.1494, -0.2168,  1.4688,  ..., -1.5000,  0.6289,  0.2715]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[1.7422],
        [1.0000],
        [1.9141],
        [1.8516]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.7422, 1.0000, 1.9141, 1.8516], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([1.7422, 0.0000, 1.9141, 1.8516], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([False,  True,  True, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-0.1494, -0.2168,  1.4688,  ..., -1.5000,  0.6289,  0.2715],
        [-1.1875,  1.0938,  0.6406,  ..., -0.6055,  0.1279,  0.6992]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 2])


tensor([[1.7422, 1.8594],
        [1.0000, 1.9141],
        [1.9141, 0.9922],
        [1.8516, 1.5391]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.7422, 1.0000, 0.9922, 1.5391], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([1.7422, 0.0000, 0.0000, 1.5391], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True,  True,  True, False], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[-1.7344,  0.0605,  1.5000,  ..., -3.1875, -1.2266, -0.0317],
        [-0.1494, -0.2168,  1.4688,  ..., -1.5000,  0.6289,  0.2715],
        [-1.1875,  1.0938,  0.6406,  ..., -0.6055,  0.1279,  0.6992]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 3])


tensor([[1.0000, 1.7422, 1.8594],
        [1.7422, 1.0000, 1.9141],
        [1.8594, 1.9141, 0.9922],
        [1.8047, 1.8516, 1.5391]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([0, 1, 2, 2], device='cuda:0')


new_candidates
torch.Size([3, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278]],
       device='cuda:0')


new_candidate_parents


[0, 1, 2]


new_candidate_aunts


[[], [], [3]]


new_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8348, -13.1459], device='cuda:0')

F NEIGHBORS PRIOR 8: (15101.834227358) 3 candidates, 0.40261916700001166 inference time, 0.40262316900043515 total time
event: message
id: f-8"
data: {"id": "f-8", "level_type": "gather", "duration": 0.40261916700001166, "nodes": [{"content": "xim", "parent": 0, "aunts": [], "prob": -3.9703471660614014}, {"content": "\u2581Cent", "parent": 1, "aunts": [], "prob": -2.834796905517578}, {"content": "\u2581the", "parent": 2, "aunts": [3], "prob": -13.145866394042969}]}




F NEIGHBORS AFTER 8: (15101.834675641) 3 candidates, 0.40261916700001166 inference time, 0.4030704990000231 total time

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 29])


tensor([], device='cuda:0', size=(0, 29), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278]],
       device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8348, -13.1459], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 2.8594,  1.0391, -1.3203,  ...,  0.0000,  0.0000,  0.0000],
        [ 5.0625, -0.7695,  2.0469,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.8438,  3.7500,  1.8359,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[1.0000e+00, 9.2374e-09, 1.4166e-09,  ..., 3.2859e-25, 8.3079e-26,
         3.0563e-26],
        [9.9990e-01, 7.4844e-05, 1.4738e-05,  ..., 2.4277e-24, 7.8815e-25,
         4.7804e-25],
        [9.9478e-01, 5.2201e-03, 2.2485e-06,  ..., 1.1409e-24, 1.0068e-24,
         7.3661e-25]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9948, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([3, 30])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8348, -13.1459], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[29874],
        [29874],
        [21438]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([ 0.0000, -0.0001, -0.0052], device='cuda:0')


new_candidates
torch.Size([3, 31])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.1511], device='cuda:0')

TOP P PRIOR 8: (15101.880415076) 3 candidates, 0.4488071609994222 inference time, 0.4488100469989149 total time
event: message
id: 8-p
data: {"id": "8-p", "level_type": "sample", "duration": 0.4488071609994222, "nodes": [{"content": "a", "parent": 0, "prob": -3.9703471660614014}, {"content": "a", "parent": 1, "prob": -2.8349013328552246}, {"content": "\u2581closest", "parent": 2, "prob": -13.151103973388672}], "finished": []}




TOP P AFTER 8: (15101.880836678) 3 candidates, 0.4488071609994222 inference time, 0.4492315459992824 total time

num_batches


1

infer start: GPU memory used: 8084 MB.

batch_candidates
torch.Size([3, 31])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 31, 32064])

hidden_states[-1]
torch.Size([3, 31, 3072])
infer - after batch run: GPU memory used: 8084 MB.

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 30])


tensor([], device='cuda:0', size=(0, 30), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 31])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.1511], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 7.0938, -1.5312,  0.2598,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.1250, -5.2188,  0.1982,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.6094, -2.3594, -0.8086,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[1.0000e+00, 1.3710e-06, 2.3824e-07,  ..., 3.4633e-26, 3.0563e-26,
         1.4437e-26],
        [9.9997e-01, 2.7536e-05, 3.7265e-06,  ..., 4.9399e-28, 3.8472e-28,
         2.6441e-28],
        [9.1932e-01, 7.5463e-02, 4.8242e-03,  ..., 4.1700e-24, 2.8660e-24,
         2.5293e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9193, 0.9948, 0.9996,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([3, 31])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.1511], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[2895],
        [5338],
        [5810]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.7881e-06, -3.4214e-05, -8.4119e-02], device='cuda:0')


new_candidates
torch.Size([3, 32])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438,  5810]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.2352], device='cuda:0')

TOP P PRIOR 9: (15102.258115586) 3 candidates, 0.37726570399900083 inference time, 0.3772702059995936 total time
event: message
id: 9-p
data: {"id": "9-p", "level_type": "sample", "duration": 0.37726570399900083, "nodes": [{"content": "\u2581Cent", "parent": 0, "prob": -3.970349073410034}, {"content": "uri", "parent": 1, "prob": -2.8349356651306152}, {"content": "\u2581star", "parent": 2, "prob": -13.235222816467285}], "finished": []}




TOP P AFTER 9: (15102.258531366) 3 candidates, 0.37726570399900083 inference time, 0.37768496599892387 total time

num_batches


1

infer start: GPU memory used: 8084 MB.

batch_candidates
torch.Size([3, 32])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 32, 32064])

hidden_states[-1]
torch.Size([3, 32, 3072])
infer - after batch run: GPU memory used: 8084 MB.

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 31])


tensor([], device='cuda:0', size=(0, 31), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 32])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438,  5810]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([ -3.9703,  -2.8349, -13.2352], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 5.2812, -0.8867,  1.7891,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.2500,  1.0938,  5.8750,  ...,  0.0000,  0.0000,  0.0000],
        [ 7.5312, -2.4062,  2.2344,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.9986e-01, 1.0889e-04, 2.1442e-05,  ..., 3.5321e-24, 1.2994e-24,
         1.1467e-24],
        [6.7824e-01, 3.2038e-01, 5.4580e-04,  ..., 2.0061e-23, 2.0061e-23,
         5.0723e-24],
        [7.9918e-01, 1.7832e-01, 1.8795e-02,  ..., 1.8593e-22, 1.7467e-22,
         1.0594e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.6782, 0.9986, 0.9992,  ..., 1.0000, 1.0000, 1.0000],
        [0.7992, 0.9775, 0.9963,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 1, 1, 2, 2], device='cuda:0')


carryover_candidates
torch.Size([5, 32])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  


carryover_candidate_logprobs
torch.Size([5])


tensor([ -3.9703,  -2.8349,  -2.8349, -13.2352, -13.2352], device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[29874],
        [29889],
        [29892],
        [  304],
        [29889]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-1.4402e-04, -3.8825e-01, -1.1383e+00, -2.2417e-01, -1.7242e+00],
       device='cuda:0')


new_candidates
torch.Size([5, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 2


new_candidate_logprobs
torch.Size([5])


tensor([ -3.9705,  -3.2232,  -3.9732, -13.4594, -14.9594], device='cuda:0')

TOP P PRIOR 10: (15102.636325341) 5 candidates, 0.3777798039991467 inference time, 0.37778301799880865 total time
event: message
id: 10-p
data: {"id": "10-p", "level_type": "sample", "duration": 0.3777798039991467, "nodes": [{"content": "a", "parent": 0, "prob": -3.9704930782318115}, {"content": ".", "parent": 1, "prob": -3.223188877105713}, {"content": ",", "parent": 1, "prob": -3.973188877105713}, {"content": "\u2581to", "parent": 2, "prob": -13.459388732910156}, {"content": ".", "parent": 2, "prob": -14.959388732910156}], "finished": []}




TOP P AFTER 10: (15102.63674122) 5 candidates, 0.3777798039991467 inference time, 0.3781983449989639 total time

num_batches


1

infer start: GPU memory used: 8084 MB.

batch_candidates
torch.Size([5, 33])

batch_candidate_logprobs
torch.Size([5])

batch_logits
torch.Size([5, 33, 32064])

hidden_states[-1]
torch.Size([5, 33, 3072])
infer - after batch run: GPU memory used: 8108 MB.

candidates
torch.Size([5, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 2


candidate_logprobs
torch.Size([5])


tensor([ -3.9705,  -3.2232,  -3.9732, -13.4594, -14.9594], device='cuda:0')


embeddings
torch.Size([5, 3072])


tensor([[ 0.6641, -1.5938,  2.0781,  ..., -2.4062,  1.1016, -0.1787],
        [-0.5273, -0.2432,  0.3477,  ..., -0.8438, -0.4629, -2.0000],
        [-0.3105, -0.4727,  0.3379,  ..., -0.9062,  1.7656, -0.8398],
        [ 1.0469,  1.0547,  1.4219,  ..., -1.5078, -0.0215, -0.1582],
        [-0.4785, -0.3008,  0.7031,  ..., -2.0312, -0.3926, -1.9688]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([5])


tensor([False,  True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.5273, -0.2432,  0.3477,  ..., -0.8438, -0.4629, -2.0000]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 1])


tensor([[1.9531],
        [1.0000],
        [1.6406],
        [1.7422],
        [1.0781]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([5])


tensor([1.9531, 1.0000, 1.6406, 1.7422, 1.0781], device='cuda:0',
       dtype=torch.bfloat16)


min_remaining_distances
torch.Size([5])


tensor([1.9531, 0.0000, 1.6406, 1.7422, 1.0781], device='cuda:0',
       dtype=torch.bfloat16)


selected (end of loop)
torch.Size([5])


tensor([ True,  True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[ 0.6641, -1.5938,  2.0781,  ..., -2.4062,  1.1016, -0.1787],
        [-0.5273, -0.2432,  0.3477,  ..., -0.8438, -0.4629, -2.0000]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 2])


tensor([[1.0000, 1.9531],
        [1.9531, 1.0000],
        [1.8828, 1.6406],
        [1.9141, 1.7422],
        [1.9531, 1.0781]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([5])


tensor([1.0000, 1.0000, 1.6406, 1.7422, 1.0781], device='cuda:0',
       dtype=torch.bfloat16)


min_remaining_distances
torch.Size([5])


tensor([0.0000, 0.0000, 1.6406, 1.7422, 1.0781], device='cuda:0',
       dtype=torch.bfloat16)


selected (end of loop)
torch.Size([5])


tensor([ True,  True, False,  True, False], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[ 0.6641, -1.5938,  2.0781,  ..., -2.4062,  1.1016, -0.1787],
        [-0.5273, -0.2432,  0.3477,  ..., -0.8438, -0.4629, -2.0000],
        [ 1.0469,  1.0547,  1.4219,  ..., -1.5078, -0.0215, -0.1582]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 3])


tensor([[1.0000, 1.9531, 1.9141],
        [1.9531, 1.0000, 1.7422],
        [1.8828, 1.6406, 1.5781],
        [1.9141, 1.7422, 1.0000],
        [1.9531, 1.0781, 1.7656]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([5])


tensor([0, 1, 2, 2, 1], device='cuda:0')


new_candidates
torch.Size([3, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438,  5810,   304]], device='cuda:0')


new_candidate_parents


[0, 1, 3]


new_candidate_aunts


[[], [4], [3]]


new_candidate_logprobs
torch.Size([3])


tensor([ -3.9705, -18.1826, -17.4326], device='cuda:0')

F NEIGHBORS PRIOR 11: (15103.204692341) 3 candidates, 0.5679380130004574 inference time, 0.567942020999908 total time
event: message
id: f-11"
data: {"id": "f-11", "level_type": "gather", "duration": 0.5679380130004574, "nodes": [{"content": "a", "parent": 0, "aunts": [], "prob": -3.9704930782318115}, {"content": ".", "parent": 1, "aunts": [4], "prob": -18.18257713317871}, {"content": "\u2581to", "parent": 3, "aunts": [3], "prob": -17.43257713317871}]}




F NEIGHBORS AFTER 11: (15103.205134919) 3 candidates, 0.5679380130004574 inference time, 0.5683833769999183 total time

finished_mask
torch.Size([3])


tensor([False, False, False], device='cuda:0')


finished
torch.Size([0, 32])


tensor([], device='cuda:0', size=(0, 32), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([3, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 29874,  2895, 29874,  5338,   338,   278,
         21438,  5810,   304]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([ -3.9705, -18.1826, -17.4326], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ 6.1250e+00, -5.6562e+00,  6.5002e-03,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 6.9688e+00,  4.2812e+00,  1.3562e+01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 8.0000e+00,  4.5000e+00,  5.2812e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.9998e-01, 1.8925e-05, 3.2887e-06,  ..., 2.6442e-28, 2.6442e-28,
         2.0593e-28],
        [8.1523e-01, 1.8190e-01, 1.3888e-03,  ..., 8.2263e-19, 4.6872e-19,
         3.4292e-19],
        [9.7638e-01, 2.2962e-02, 5.4002e-04,  ..., 2.1552e-21, 2.0246e-21,
         5.8006e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8152, 0.9971, 0.9985,  ..., 1.0000, 1.0000, 1.0000],
        [0.9764, 0.9993, 0.9999,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 1, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([4, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,  1019,  2657, 2


carryover_candidate_logprobs
torch.Size([4])


tensor([ -3.9705, -18.1826, -18.1826, -17.4326], device='cuda:0')


new_candidate_toks
torch.Size([4, 1])


tensor([[ 5338],
        [   13],
        [32007],
        [11563]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-2.4557e-05, -2.0428e-01, -1.7043e+00, -2.3903e-02], device='cuda:0')


new_candidates
torch.Size([4, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 3


new_candidate_logprobs
torch.Size([4])


tensor([ -3.9705, -18.3869, -19.8869, -17.4565], device='cuda:0')

TOP P PRIOR 11: (15103.252032001) 4 candidates, 0.6152775720001955 inference time, 0.6152815690002171 total time
event: message
id: 11-p
data: {"id": "11-p", "level_type": "sample", "duration": 0.6152775720001955, "nodes": [{"content": "uri", "parent": 0, "prob": -3.970517635345459}, {"content": "<0x0A>", "parent": 1, "prob": -18.386856079101562}, {"content": "<|end|>", "parent": 1, "prob": -19.886856079101562}, {"content": "\u2581Earth", "parent": 2, "prob": -17.456480026245117}], "finished": []}




TOP P AFTER 11: (15103.252468012) 4 candidates, 0.6152775720001955 inference time, 0.6157164960004593 total time

num_batches


1

infer start: GPU memory used: 8130 MB.

batch_candidates
torch.Size([4, 34])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 34, 32064])

hidden_states[-1]
torch.Size([4, 34, 3072])
infer - after batch run: GPU memory used: 8130 MB.

candidates
torch.Size([4, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 3


candidate_logprobs
torch.Size([4])


tensor([ -3.9705, -18.3869, -19.8869, -17.4565], device='cuda:0')


embeddings
torch.Size([4, 3072])


tensor([[ 1.0469, -0.2402, -0.2207,  ..., -0.0081,  0.8711,  1.2578],
        [ 1.2812, -0.2334, -3.3906,  ...,  0.3086, -0.3672, -0.1650],
        [-0.5469, -0.1416,  1.0156,  ..., -1.1797,  0.2246,  1.2578],
        [ 1.0234,  0.7031,  0.9375,  ..., -0.0227,  0.1191, -0.3457]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([ True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 1.0469, -0.2402, -0.2207,  ..., -0.0081,  0.8711,  1.2578]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[0.9922],
        [1.9141],
        [2.0312],
        [1.2812]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([0.9922, 1.9141, 2.0312, 1.2812], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([0.0000, 1.9141, 2.0312, 1.2812], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True, False,  True, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[ 1.0469, -0.2402, -0.2207,  ..., -0.0081,  0.8711,  1.2578],
        [-0.5469, -0.1416,  1.0156,  ..., -1.1797,  0.2246,  1.2578]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 2])


tensor([[0.9922, 2.0312],
        [1.9141, 1.7031],
        [2.0312, 1.0000],
        [1.2812, 2.0469]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([0.9922, 1.7031, 1.0000, 1.2812], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([0.0000, 1.7031, 0.0000, 1.2812], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True,  True,  True, False], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[ 1.0469, -0.2402, -0.2207,  ..., -0.0081,  0.8711,  1.2578],
        [ 1.2812, -0.2334, -3.3906,  ...,  0.3086, -0.3672, -0.1650],
        [-0.5469, -0.1416,  1.0156,  ..., -1.1797,  0.2246,  1.2578]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 3])


tensor([[0.9922, 1.9141, 2.0312],
        [1.9141, 1.0000, 1.7031],
        [2.0312, 1.7031, 1.0000],
        [1.2812, 1.9375, 2.0469]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([0, 1, 2, 0], device='cuda:0')


new_candidates
torch.Size([3, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889, 32007]], device='cuda:0')


new_candidate_parents


[0, 1, 2]


new_candidate_aunts


[[3], [], []]


new_candidate_logprobs
torch.Size([3])


tensor([-21.4270, -18.3869, -19.8869], device='cuda:0')

F NEIGHBORS PRIOR 12: (15103.802176348) 3 candidates, 0.5496949910011608 inference time, 0.5496997280006326 total time
event: message
id: f-12"
data: {"id": "f-12", "level_type": "gather", "duration": 0.5496949910011608, "nodes": [{"content": "uri", "parent": 0, "aunts": [3], "prob": -21.426998138427734}, {"content": "<0x0A>", "parent": 1, "aunts": [], "prob": -18.386856079101562}, {"content": "<|end|>", "parent": 2, "aunts": [], "prob": -19.886856079101562}]}




F NEIGHBORS AFTER 12: (15103.802666382) 3 candidates, 0.5496949910011608 inference time, 0.5501885940011562 total time

finished_mask
torch.Size([3])


tensor([False, False,  True], device='cuda:0')


finished
torch.Size([1, 33])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889]], device='cuda:0')


finished_parents
torch.Size([1])


tensor([2], device='cuda:0')


finished_logprobs
torch.Size([1])


tensor([-19.8869], device='cuda:0')


candidates
torch.Size([2, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-21.4270, -18.3869], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 8.1250,  1.5156,  5.6562,  ...,  0.0000,  0.0000,  0.0000],
        [-4.1250,  1.9844,  6.5625,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.3942e-01, 6.0055e-02, 1.9114e-04,  ..., 2.4521e-23, 1.4873e-23,
         1.3125e-23],
        [9.9899e-01, 2.6099e-04, 2.3033e-04,  ..., 8.8083e-20, 6.0538e-20,
         4.7147e-20]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.9394, 0.9995, 0.9997,  ..., 1.0000, 1.0000, 1.0000],
        [0.9990, 0.9993, 0.9995,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


carryover_candidates
torch.Size([2, 34])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-21.4270, -18.3869], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[29889],
        [   13]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.0625, -0.0010], device='cuda:0')


new_candidates
torch.Size([2, 35])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-21.4895, -18.3879], device='cuda:0')

TOP P PRIOR 12: (15103.846010319) 2 candidates, 0.593529280000439 inference time, 0.5935338110011799 total time
event: message
id: 12-p
data: {"id": "12-p", "level_type": "sample", "duration": 0.593529280000439, "nodes": [{"content": ".", "parent": 0, "prob": -21.489492416381836}, {"content": "<0x0A>", "parent": 1, "prob": -18.38786506652832}], "finished": [{"content": "What is the closest star to the Earth? Answer in 5 words or less.  The closest star is Proxima Centauri.", "parent": 2, "prob": -19.886856079101562}]}




TOP P AFTER 12: (15103.846618127) 2 candidates, 0.593529280000439 inference time, 0.5941407800000889 total time

num_batches


1

infer start: GPU memory used: 8130 MB.

batch_candidates
torch.Size([2, 35])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 35, 32064])

hidden_states[-1]
torch.Size([2, 35, 3072])
infer - after batch run: GPU memory used: 8130 MB.

finished_mask
torch.Size([2])


tensor([False, False], device='cuda:0')


finished
torch.Size([0, 34])


tensor([], device='cuda:0', size=(0, 34), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([2, 35])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-21.4895, -18.3879], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 7.0312,  4.0312, 13.3750,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.6562,  5.5312,  3.9531,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[5.8736e-01, 4.0369e-01, 6.5250e-03,  ..., 1.3357e-18, 1.1787e-18,
         9.7719e-19],
        [4.6826e-01, 3.2183e-01, 8.1371e-02,  ..., 3.1118e-17, 2.4235e-17,
         2.4235e-17]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.5874, 0.9911, 0.9976,  ..., 1.0000, 1.0000, 1.0000],
        [0.4683, 0.7901, 0.8715,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([6])


tensor([0, 0, 1, 1, 1, 1], device='cuda:0')


carryover_candidates
torch.Size([6, 35])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 298


carryover_candidate_logprobs
torch.Size([6])


tensor([-21.4895, -21.4895, -18.3879, -18.3879, -18.3879, -18.3879],
       device='cuda:0')


new_candidate_toks
torch.Size([6, 1])


tensor([[32007],
        [   13],
        [17245],
        [29898],
        [ 9842],
        [   13]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([6])


tensor([-0.5321, -0.9071, -0.7587, -1.1337, -2.5087, -3.5087], device='cuda:0')


new_candidates
torch.Size([6, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   4


new_candidate_logprobs
torch.Size([6])


tensor([-22.0216, -22.3966, -19.1466, -19.5216, -20.8966, -21.8966],
       device='cuda:0')

TOP P PRIOR 13: (15104.216948419) 6 candidates, 0.37031494100119744 inference time, 0.3703183630004787 total time
event: message
id: 13-p
data: {"id": "13-p", "level_type": "sample", "duration": 0.37031494100119744, "nodes": [{"content": "<|end|>", "parent": 0, "prob": -22.021602630615234}, {"content": "<0x0A>", "parent": 0, "prob": -22.396602630615234}, {"content": "However", "parent": 1, "prob": -19.146604537963867}, {"content": "(", "parent": 1, "prob": -19.521604537963867}, {"content": "Note", "parent": 1, "prob": -20.896604537963867}, {"content": "<0x0A>", "parent": 1, "prob": -21.896604537963867}], "finished": []}




TOP P AFTER 13: (15104.217379764) 6 candidates, 0.37031494100119744 inference time, 0.3707491029999801 total time

num_batches


1

infer start: GPU memory used: 8130 MB.

batch_candidates
torch.Size([6, 36])

batch_candidate_logprobs
torch.Size([6])

batch_logits
torch.Size([6, 36, 32064])

hidden_states[-1]
torch.Size([6, 36, 3072])
infer - after batch run: GPU memory used: 8218 MB.

candidates
torch.Size([6, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889,    13],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   4


candidate_logprobs
torch.Size([6])


tensor([-22.0216, -22.3966, -19.1466, -19.5216, -20.8966, -21.8966],
       device='cuda:0')


embeddings
torch.Size([6, 3072])


tensor([[-0.1602, -0.4160,  0.7227,  ..., -2.5156,  0.5859,  1.1953],
        [ 1.3984, -0.3574, -2.6094,  ...,  0.4004, -0.6328,  0.2100],
        [ 0.7734, -0.3867,  0.4473,  ..., -0.0537,  0.4902,  0.4688],
        [ 1.9766, -0.0654,  0.3828,  ...,  0.1328,  0.6016, -0.8438],
        [ 1.3281,  0.4570,  0.6094,  ...,  0.9766,  0.4531, -1.8516],
        [-0.1104,  0.1338,  0.8242,  ...,  0.4766,  0.2578, -0.8750]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([6])


tensor([False, False,  True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.7734, -0.3867,  0.4473,  ..., -0.0537,  0.4902,  0.4688]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([6, 1])


tensor([[2.0469],
        [1.9297],
        [1.0000],
        [1.6484],
        [1.5625],
        [1.6875]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([6])


tensor([2.0469, 1.9297, 1.0000, 1.6484, 1.5625, 1.6875], device='cuda:0',
       dtype=torch.bfloat16)


min_remaining_distances
torch.Size([6])


tensor([2.0469, 1.9297, 0.0000, 1.6484, 1.5625, 1.6875], device='cuda:0',
       dtype=torch.bfloat16)


selected (end of loop)
torch.Size([6])


tensor([ True, False,  True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-0.1602, -0.4160,  0.7227,  ..., -2.5156,  0.5859,  1.1953],
        [ 0.7734, -0.3867,  0.4473,  ..., -0.0537,  0.4902,  0.4688]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([6, 2])


tensor([[1.0000, 2.0469],
        [1.7500, 1.9297],
        [2.0469, 1.0000],
        [2.0156, 1.6484],
        [2.0000, 1.5625],
        [1.9766, 1.6875]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([6])


tensor([0, 0, 1, 1, 1, 1], device='cuda:0')


new_candidates
torch.Size([2, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889, 32007],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245]], device='cuda:0')


new_candidate_parents


[0, 2]


new_candidate_aunts


[[1], [2, 3, 4, 5]]


new_candidate_logprobs
torch.Size([2])


tensor([-44.4182, -81.4614], device='cuda:0')

F NEIGHBORS PRIOR 14: (15104.929840619) 2 candidates, 0.7124478639998415 inference time, 0.7124513679991651 total time
event: message
id: f-14"
data: {"id": "f-14", "level_type": "gather", "duration": 0.7124478639998415, "nodes": [{"content": "<|end|>", "parent": 0, "aunts": [1], "prob": -44.41820526123047}, {"content": "However", "parent": 2, "aunts": [2, 3, 4, 5], "prob": -81.46141815185547}]}




F NEIGHBORS AFTER 14: (15104.930216833) 2 candidates, 0.7124478639998415 inference time, 0.712826884999231 total time

finished_mask
torch.Size([2])


tensor([ True, False], device='cuda:0')


finished
torch.Size([1, 35])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   304, 11563,   338,  1019,  2657,
         29874,  2895, 29874,  5338, 29889]], device='cuda:0')


finished_parents
torch.Size([1])


tensor([0], device='cuda:0')


finished_logprobs
torch.Size([1])


tensor([-44.4182], device='cuda:0')


candidates
torch.Size([1, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-81.4614], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[3.6094, 5.0625, 0.8008,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9999e-01, 2.5613e-06, 2.2603e-06,  ..., 3.4528e-20, 2.6891e-20,
         2.2293e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 36])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-81.4614], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[29892]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-1.3232e-05], device='cuda:0')


new_candidates
torch.Size([1, 37])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-81.4614], device='cuda:0')

TOP P PRIOR 14: (15104.968513975) 1 candidates, 0.7511211349992664 inference time, 0.7511250099996687 total time
event: message
id: 14-p
data: {"id": "14-p", "level_type": "sample", "duration": 0.7511211349992664, "nodes": [{"content": ",", "parent": 0, "prob": -81.46143341064453}], "finished": [{"content": "What is the closest star to the Earth? Answer in 5 words or less.  The closest star to Earth is Proxima Centauri.", "parent": 0, "prob": -44.41820526123047}]}




TOP P AFTER 14: (15104.969021251) 1 candidates, 0.7511211349992664 inference time, 0.7516315699995175 total time

num_batches


1

infer start: GPU memory used: 8246 MB.

batch_candidates
torch.Size([1, 37])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 37, 32064])

hidden_states[-1]
torch.Size([1, 37, 3072])
infer - after batch run: GPU memory used: 8246 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 36])


tensor([], device='cuda:0', size=(0, 36), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 37])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-81.4614], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[5.6250, 7.6562, 5.9375,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[3.7173e-01, 2.8950e-01, 5.7007e-02,  ..., 1.1198e-18, 6.3807e-19,
         6.1844e-19]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.3717, 0.6612, 0.7182,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([5, 37])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 298


carryover_candidate_logprobs
torch.Size([5])


tensor([-81.4614, -81.4614, -81.4614, -81.4614, -81.4614], device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[ 565],
        [ 372],
        [ 304],
        [ 278],
        [1951]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-0.9896, -1.2396, -2.8646, -3.1146, -3.1146], device='cuda:0')


new_candidates
torch.Size([5, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   372],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         299


new_candidate_logprobs
torch.Size([5])


tensor([-82.4510, -82.7010, -84.3260, -84.5760, -84.5760], device='cuda:0')

TOP P PRIOR 15: (15105.201600822) 5 candidates, 0.23256568199940375 inference time, 0.2325704949998908 total time
event: message
id: 15-p
data: {"id": "15-p", "level_type": "sample", "duration": 0.23256568199940375, "nodes": [{"content": "\u2581if", "parent": 0, "prob": -82.45101928710938}, {"content": "\u2581it", "parent": 0, "prob": -82.70101928710938}, {"content": "\u2581to", "parent": 0, "prob": -84.32601928710938}, {"content": "\u2581the", "parent": 0, "prob": -84.57601928710938}, {"content": "\u2581since", "parent": 0, "prob": -84.57601928710938}], "finished": []}




TOP P AFTER 15: (15105.202127398) 5 candidates, 0.23256568199940375 inference time, 0.23309660899940354 total time

num_batches


1

infer start: GPU memory used: 8246 MB.

batch_candidates
torch.Size([5, 38])

batch_candidate_logprobs
torch.Size([5])

batch_logits
torch.Size([5, 38, 32064])

hidden_states[-1]
torch.Size([5, 38, 3072])
infer - after batch run: GPU memory used: 8246 MB.

candidates
torch.Size([5, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   372],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   304],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         299


candidate_logprobs
torch.Size([5])


tensor([-82.4510, -82.7010, -84.3260, -84.5760, -84.5760], device='cuda:0')


embeddings
torch.Size([5, 3072])


tensor([[ 0.5547,  0.3086,  0.2773,  ...,  0.6523, -1.6406, -0.6641],
        [ 0.2285,  1.6875,  0.0615,  ...,  0.0505, -1.1641,  0.6055],
        [ 0.2695,  0.9570,  1.3516,  ...,  0.7734, -0.6055,  0.1553],
        [ 0.2637,  0.5703,  0.1875,  ...,  0.0513, -1.0234,  0.1836],
        [-0.5039,  1.5938,  1.5312,  ...,  0.3496, -1.3984, -0.9141]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([5])


tensor([ True, False, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.5547,  0.3086,  0.2773,  ...,  0.6523, -1.6406, -0.6641]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 1])


tensor([[1.0000],
        [1.5781],
        [1.4688],
        [1.4375],
        [1.2812]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2, 3, 4]]


new_candidate_logprobs
torch.Size([1])


tensor([-418.6301], device='cuda:0')

F NEIGHBORS PRIOR 16: (15105.751498384) 1 candidates, 0.5493504389996815 inference time, 0.5493542510012048 total time
event: message
id: f-16"
data: {"id": "f-16", "level_type": "gather", "duration": 0.5493504389996815, "nodes": [{"content": "\u2581if", "parent": 0, "aunts": [1, 2, 3, 4], "prob": -418.6300964355469}]}




F NEIGHBORS AFTER 16: (15105.751853237) 1 candidates, 0.5493504389996815 inference time, 0.5497084010003164 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 37])


tensor([], device='cuda:0', size=(0, 37), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-418.6301], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[6.2188, 4.7500, 7.4688,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[3.5380e-01, 2.7554e-01, 2.7554e-01,  ..., 2.1440e-20, 1.7775e-20,
         1.6698e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.3538, 0.6293, 0.9049,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([3, 38])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-418.6301, -418.6301, -418.6301], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[366],
        [278],
        [591]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.0390, -1.2890, -1.2890], device='cuda:0')


new_candidates
torch.Size([3, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   591]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-419.6691, -419.9191, -419.9191], device='cuda:0')

TOP P PRIOR 16: (15105.795835831) 3 candidates, 0.5936879380005848 inference time, 0.5936912840006698 total time
event: message
id: 16-p
data: {"id": "16-p", "level_type": "sample", "duration": 0.5936879380005848, "nodes": [{"content": "\u2581you", "parent": 0, "prob": -419.6690979003906}, {"content": "\u2581the", "parent": 0, "prob": -419.9190979003906}, {"content": "\u2581we", "parent": 0, "prob": -419.9190979003906}], "finished": []}




TOP P AFTER 16: (15105.796403212) 3 candidates, 0.5936879380005848 inference time, 0.5942582839998067 total time

num_batches


1

infer start: GPU memory used: 8246 MB.

batch_candidates
torch.Size([3, 39])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 39, 32064])

hidden_states[-1]
torch.Size([3, 39, 3072])
infer - after batch run: GPU memory used: 8246 MB.

candidates
torch.Size([3, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   278],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   591]],
       device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-419.6691, -419.9191, -419.9191], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[-0.8203,  0.5352, -0.2139,  ...,  0.6367, -1.0547,  1.1875],
        [-0.0498,  0.6758,  0.4648,  ...,  0.7617, -2.0469,  0.4004],
        [ 0.2559,  0.1260,  1.2734,  ...,  1.6016, -0.9844,  0.8320]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.8203,  0.5352, -0.2139,  ...,  0.6367, -1.0547,  1.1875]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[1.0000],
        [1.4688],
        [1.1641]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2]]


new_candidate_logprobs
torch.Size([1])


tensor([-1259.5073], device='cuda:0')

F NEIGHBORS PRIOR 17: (15106.170340699) 1 candidates, 0.37392315699980827 inference time, 0.3739270469995972 total time
event: message
id: f-17"
data: {"id": "f-17", "level_type": "gather", "duration": 0.37392315699980827, "nodes": [{"content": "\u2581you", "parent": 0, "aunts": [1, 2], "prob": -1259.50732421875}]}




F NEIGHBORS AFTER 17: (15106.170698863) 1 candidates, 0.37392315699980827 inference time, 0.37428453500069736 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 38])


tensor([], device='cuda:0', size=(0, 38), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-1259.5073], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[5.0312, 3.5781, 3.7344,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[4.5353e-01, 2.7508e-01, 1.0120e-01,  ..., 6.6587e-19, 5.1858e-19,
         5.8183e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.4535, 0.7286, 0.8298,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


carryover_candidates
torch.Size([5, 39])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278


carryover_candidate_logprobs
torch.Size([5])


tensor([-1259.5073, -1259.5073, -1259.5073, -1259.5073, -1259.5073],
       device='cuda:0')


new_candidate_toks
torch.Size([5, 1])


tensor([[29915],
        [  526],
        [ 6839],
        [18719],
        [ 2099]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([5])


tensor([-0.7907, -1.2907, -2.2907, -2.6657, -2.9157], device='cuda:0')


new_candidates
torch.Size([5, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366,   526],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366,  6839],
        [    1, 32010,  1724,   338,   278, 21438


new_candidate_logprobs
torch.Size([5])


tensor([-1260.2980, -1260.7980, -1261.7980, -1262.1730, -1262.4230],
       device='cuda:0')

TOP P PRIOR 17: (15106.217950917) 5 candidates, 0.4215334130003612 inference time, 0.42153753499951563 total time
event: message
id: 17-p
data: {"id": "17-p", "level_type": "sample", "duration": 0.4215334130003612, "nodes": [{"content": "'", "parent": 0, "prob": -1260.2979736328125}, {"content": "\u2581are", "parent": 0, "prob": -1260.7979736328125}, {"content": "\u2581meant", "parent": 0, "prob": -1261.7979736328125}, {"content": "\u2581strictly", "parent": 0, "prob": -1262.1729736328125}, {"content": "\u2581mean", "parent": 0, "prob": -1262.4229736328125}], "finished": []}




TOP P AFTER 17: (15106.218426049) 5 candidates, 0.4215334130003612 inference time, 0.4220116840006085 total time

num_batches


1

infer start: GPU memory used: 8246 MB.

batch_candidates
torch.Size([5, 40])

batch_candidate_logprobs
torch.Size([5])

batch_logits
torch.Size([5, 40, 32064])

hidden_states[-1]
torch.Size([5, 40, 3072])
infer - after batch run: GPU memory used: 8246 MB.

candidates
torch.Size([5, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366,   526],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366,  6839],
        [    1, 32010,  1724,   338,   278, 21438


candidate_logprobs
torch.Size([5])


tensor([-1260.2980, -1260.7980, -1261.7980, -1262.1730, -1262.4230],
       device='cuda:0')


embeddings
torch.Size([5, 3072])


tensor([[ 0.4453,  0.6250, -0.4414,  ..., -0.0718, -0.2676,  0.0317],
        [ 0.3691, -0.0608,  0.4570,  ..., -1.1094, -0.0099,  0.0879],
        [-0.1465,  1.0547,  2.3906,  ...,  0.2734, -0.8711, -2.1406],
        [-1.0000,  0.8438,  0.5508,  ...,  1.3594, -1.0078,  0.0728],
        [ 0.0369,  0.9688,  2.5156,  ..., -0.0175, -0.5859, -1.9219]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([5])


tensor([ True, False, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[ 0.4453,  0.6250, -0.4414,  ..., -0.0718, -0.2676,  0.0317]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([5, 1])


tensor([[1.0000],
        [1.7812],
        [1.7500],
        [1.7422],
        [1.7422]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([5])


tensor([0, 0, 0, 0, 0], device='cuda:0')


new_candidates
torch.Size([1, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915]],
       device='cuda:0')


new_candidate_parents


[0]


new_candidate_aunts


[[1, 2, 3, 4]]


new_candidate_logprobs
torch.Size([1])


tensor([-6307.4897], device='cuda:0')

F NEIGHBORS PRIOR 18: (15106.912269793) 1 candidates, 0.6938295380005002 inference time, 0.6938333750003949 total time
event: message
id: f-18"
data: {"id": "f-18", "level_type": "gather", "duration": 0.6938295380005002, "nodes": [{"content": "'", "parent": 0, "aunts": [1, 2, 3, 4], "prob": -6307.48974609375}]}




F NEIGHBORS AFTER 18: (15106.912626545) 1 candidates, 0.6938295380005002 inference time, 0.6941892489994643 total time

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 39])


tensor([], device='cuda:0', size=(0, 39), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915]],
       device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-6307.4897], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 1.2344, -1.2266, -0.1245,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9792e-01, 1.9264e-03, 1.2315e-04,  ..., 7.5147e-17, 5.4979e-17,
         5.1648e-17]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9979, 0.9998, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 40])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-6307.4897], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[276]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0021], device='cuda:0')


new_candidates
torch.Size([1, 41])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-6307.4917], device='cuda:0')

TOP P PRIOR 18: (15106.952377095) 1 candidates, 0.7339370589997998 inference time, 0.7339407780000329 total time
event: message
id: 18-p
data: {"id": "18-p", "level_type": "sample", "duration": 0.7339370589997998, "nodes": [{"content": "re", "parent": 0, "prob": -6307.49169921875}], "finished": []}




TOP P AFTER 18: (15106.952746851) 1 candidates, 0.7339370589997998 inference time, 0.7343094489988289 total time

num_batches


1

infer start: GPU memory used: 8246 MB.

batch_candidates
torch.Size([1, 41])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 41, 32064])

hidden_states[-1]
torch.Size([1, 41, 3072])
infer - after batch run: GPU memory used: 8246 MB.

finished_mask
torch.Size([1])


tensor([False], device='cuda:0')


finished
torch.Size([0, 40])


tensor([], device='cuda:0', size=(0, 40), dtype=torch.int64)


finished_parents
torch.Size([0])


tensor([], device='cuda:0', dtype=torch.int64)


finished_logprobs
torch.Size([0])


tensor([], device='cuda:0')


candidates
torch.Size([1, 41])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-6307.4917], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[3.9219, 4.3750, 5.1562,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[5.1326e-01, 3.9973e-01, 2.8956e-02,  ..., 9.0000e-20, 9.0000e-20,
         2.5785e-20]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.5133, 0.9130, 0.9419,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 41])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-6307.4917, -6307.4917], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[16811],
        [ 6721]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.6670, -0.9170], device='cuda:0')


new_candidates
torch.Size([2, 42])


tensor([[    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276, 16811],
        [    1, 32010,  1724,   338,   278, 21438,  5810,   304,   278, 11563,
         29973,   673,   297, 29871, 29945,  3838,   470,  3109, 29889, 29871,
         32007, 32001,   450, 21438,  5810,   338,  1019,  2657, 29874,  2895,
         29874,  5338, 29889,    13,    13, 17245, 29892,   565,   366, 29915,
           276,  6721]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-6308.1587, -6308.4087], device='cuda:0')

TOP P PRIOR 19: (15107.185344291) 2 candidates, 0.2325843079997867 inference time, 0.2325890320007602 total time
event: message
id: 19-p
data: {"id": "19-p", "level_type": "sample", "duration": 0.2325843079997867, "nodes": [{"content": "\u2581referring", "parent": 0, "prob": -6308.15869140625}, {"content": "\u2581asking", "parent": 0, "prob": -6308.40869140625}], "finished": []}




TOP P AFTER 19: (15107.185758699) 2 candidates, 0.2325843079997867 inference time, 0.23300291600025957 total time
event: message
id: END
data: []




