In [1]:
!pip install transformers accelerate optimum nvidia-ml-py



In [2]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json
from sklearn.cluster import KMeans

torch.random.manual_seed(0)

<torch._C.Generator at 0x7f5d6897a1f0>

In [3]:


from pynvml import *

def check_gpu(step):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"{step}: GPU memory used: {info.used // 1024**2} MB.")
    
def D(obj, label=None, c=True):
    print()
    if label:
        print(label)
        
    if isinstance(obj, tuple):
        print(len(obj))
    elif isinstance(obj, torch.Tensor) or isinstance(obj, np.ndarray):
        print(obj.shape)
        if c: # Contents
            display(obj)
    else:
        if c: # Contents
            display(obj)
            
def DS(obj, label=None):
    D(obj, label, c=False)
    


In [6]:


class InferenceTensor:
    def __init__(self):
        print('Initializing model...')
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            torch_dtype=torch.bfloat16,
            device_map='auto',
            trust_remote_code=True,
            use_cache=True,
            # attn_implementation='flash_attention_2',
        )
        print('Initializing tokenizer...')
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.max_new_tokens = 10
        self.batch_size = 8
        
    def candidates_generator(self, top_p: float, max_beams: int, prompt: str):
        print(prompt)
        candidates, candidate_logprobs = self._init_candidates(prompt)
        for level_idx in range(self.max_new_tokens):
            logits, embeddings = self._infer(candidates, candidate_logprobs)
        
            if candidates.shape[0] > max_beams:
                candidates, candidate_parents, candidate_aunts, candidate_logprobs, logits = self._k_means(logits, embeddings, candidates, candidate_logprobs, max_beams)
                yield self._format_k_means(level_idx, candidates, candidate_parents, candidate_aunts, candidate_logprobs)

            candidates, candidate_parents, candidate_logprobs = self._top_p(logits, candidates, candidate_logprobs, top_p)
            yield self._format_top_p(level_idx, candidates, candidate_parents, candidate_logprobs)

        yield f"event: message\nid: END\ndata: []\n\n"

    def _format_k_means(self, level_idx, candidates, candidate_parents, candidate_aunts, candidate_logprobs):
        candidate_texts = self.tokenizer.batch_decode(candidates[:, -1])
        candidate_probs = candidate_logprobs.exp()
        candidate_dicts = []
        idx = f"{level_idx}-k"
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i], 'aunts': candidate_aunts[i], 'prob': candidate_probs[i].item()})
        data = json.dumps({'id': idx, 'level_type': 'k_means', 'nodes': candidate_dicts})
        return f"event: message\nid: {idx}\"\ndata: {data}\n\n"

    def _format_top_p(self, level_idx, candidates, candidate_parents, candidate_logprobs):
        candidate_texts = self.tokenizer.batch_decode(candidates[:, -1])
        candidate_probs = candidate_logprobs.exp()
        candidate_dicts = []
        idx = f"{level_idx}-p"
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i], 'prob': candidate_probs[i].item()})
        data = json.dumps({'id': idx, 'level_type': 'top_p', 'nodes': candidate_dicts})
        return f"event: message\nid: {idx}\ndata: {data}\n\n"


    def _init_candidates(self, text: str):
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
        inputs = self.tokenizer(prompt, return_tensors='pt')
        D(inputs.input_ids, 'input_ids')
        print(self.tokenizer.batch_decode(inputs.input_ids))

        candidates = inputs.input_ids.to(self.device)
        candidate_logprobs = torch.zeros((1), dtype=torch.float32, device=self.device)

        return candidates, candidate_logprobs

    def _k_means(self, logits, embeddings, candidates, candidate_logprobs, max_beams):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        # === CPU ===
        embeddings_np = embeddings.float().numpy(force=True)
        D(embeddings_np, 'embeddings_np')
        k_means = KMeans(n_clusters=min(2, embeddings_np.shape[0]), random_state=0, n_init="auto")
        k_mean_space = k_means.fit_transform(embeddings_np)
        D(k_mean_space, 'k_mean_space')
        k_mean_clusters = k_means.predict(embeddings_np)
        D(k_mean_clusters, 'k_mean_clusters')
        k_mean_logprob_mass = np.bincount(k_mean_clusters, weights=candidate_logprobs.cpu())
        D(k_mean_logprob_mass, 'k_mean_logprob_mass')
        closest = np.argmin(k_mean_space, axis=0)
        D(closest, 'closest')
        # === END CPU ===
        
        closest_indices = torch.from_numpy(closest).to(self.device)
        new_candidates = candidates.index_select(0, closest_indices)
        D(new_candidates, 'new_candidates')
        new_candidate_parents = closest_indices.tolist()
        D(new_candidate_parents, 'new_candidate_parents')
        new_candidate_aunts = [torch.nonzero(torch.from_numpy(k_mean_clusters).to(self.device) == i).squeeze(-1).tolist() for i in range(new_candidates.shape[0])]
        D(new_candidate_aunts, 'new_candidate_aunts')
        new_candidate_logprobs = torch.from_numpy(k_mean_logprob_mass).to(self.device)
        D(new_candidate_logprobs, 'new_candidate_logprobs')
        new_candidate_logits = logits.index_select(0, closest_indices)
        
        return new_candidates, new_candidate_parents, new_candidate_aunts, new_candidate_logprobs, new_candidate_logits
        
    def _top_p(self, logits, candidates, candidate_logprobs, top_p):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        
        last_tok_logits = logits[:, -1, :]
        D(last_tok_logits, 'last_tok_logits')

        sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
        DS(sorted_logits, 'sorted_logits')
        DS(sorted_indices, 'sorted_indices')
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        D(sorted_probs, 'sorted_probs')
        display(sorted_probs.sum(dim=1))
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        D(cum_probs, 'cum_probs')

        # Create tensor of bools indicating which indices are cumulatively less than top_p
        keep_indices = cum_probs < top_p

        # Keep the last element that went over top_p
        keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
        keep_indices[:, 0] = 1  # Always keep the first element
        D(keep_indices, 'keep_indices')

        new_candidate_parents = keep_indices.nonzero()[:, 0]
        D(new_candidate_parents, 'new_candidate_parents')

        # OPTIM: Potential optimization -- have a fixed tensor of size (max_candidates, max_tokens) and copy this into that (batch-aware).
        # OPTIM: consider which of these operations can be done in-place to prevent new allocations?
        carryover_candidates = candidates.index_select(0, new_candidate_parents)
        D(carryover_candidates, 'carryover_candidates')

        # Similar code could be used to trace entire origin of sequence. For now since server just traces parent of the preceding generation, not needed
        # carryover_candidate_parents = candidate_parents.index_select(0, carryover_candidate_indices)  # Not strictly necessary since 1d
        # D(carryover_candidate_parents, 'carryover_candidate_parents')

        carryover_candidate_logprobs = candidate_logprobs.index_select(0, new_candidate_parents)  # Not strictly necessary since 1d
        D(carryover_candidate_logprobs, 'carryover_candidate_logprobs')

        new_candidate_toks = sorted_indices[keep_indices].unsqueeze(1)
        D(new_candidate_toks, 'new_candidate_toks')
        new_candidate_tok_logprobs = sorted_probs[keep_indices].log()
        D(new_candidate_tok_logprobs, 'new_candidate_tok_logprobs')

        new_candidates = torch.cat([carryover_candidates, new_candidate_toks], dim=1)
        D(new_candidates, 'new_candidates')
        new_candidate_logprobs = carryover_candidate_logprobs.add_(new_candidate_tok_logprobs)
        D(new_candidate_logprobs, 'new_candidate_logprobs')

        return new_candidates, new_candidate_parents.tolist(), new_candidate_logprobs


    def _infer(self, candidates, candidate_logprobs):
        with torch.inference_mode():
            num_batches = (candidates.shape[0] + self.batch_size - 1) // self.batch_size  # Round up to nearest whole number of batches
            D(num_batches, 'num_batches')

            check_gpu('infer start')
            output_logits_list = []
            output_embeddings_list = []
            for i in range(0, num_batches, 1):
                batch_candidates = candidates[i * self.batch_size:(i + 1) * self.batch_size]
                DS(batch_candidates, 'batch_candidates')
                batch_candidate_logprobs = candidate_logprobs[i * self.batch_size:(i + 1) * self.batch_size]
                DS(batch_candidate_logprobs, 'batch_candidate_logprobs')

                batch_outputs = self.model(input_ids=batch_candidates, output_hidden_states=True)
                DS(batch_outputs.logits, 'batch_logits')
                DS(batch_outputs.hidden_states[-1], 'hidden_states[-1]')

                output_logits_list.append(batch_outputs.logits)
                output_embeddings_list.append(batch_outputs.hidden_states[-1][:,-1,:])
                check_gpu('infer - after batch run')

            output_logits = torch.cat(output_logits_list, dim=0)
            output_embeddings = torch.cat(output_embeddings_list, dim=0)
            
            return output_logits, output_embeddings

In [7]:
it = InferenceTensor()

for x in it.candidates_generator(0.9, 2, 'What is the highest mountain?'):
    print(x)
    print()
    print('====================================')
    print()

Initializing model...


`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Initializing tokenizer...
What is the highest mountain?

input_ids
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]])

['<s><|user|> What is the highest mountain? <|end|><|assistant|>']

num_batches


1

infer start: GPU memory used: 15175 MB.

batch_candidates
torch.Size([1, 11])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 11, 32064])

hidden_states[-1]
torch.Size([1, 11, 3072])
infer - after batch run: GPU memory used: 15179 MB.

candidates
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 5.7812, -1.4688, -3.2969,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.2405e-01, 7.5850e-02, 8.8812e-05,  ..., 1.1622e-21, 7.0490e-22,
         3.7730e-22]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9240, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[450]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


new_candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')

event: message
id: 0-p
data: {"id": "0-p", "level_type": "top_p", "nodes": [{"content": "The", "parent": 0, "prob": 0.924048125743866}]}





num_batches


1

infer start: GPU memory used: 15179 MB.

batch_candidates
torch.Size([1, 12])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 12, 32064])

hidden_states[-1]
torch.Size([1, 12, 3072])
infer - after batch run: GPU memory used: 15179 MB.

candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[-1.2969,  1.0312, -5.0938,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9909e-01, 9.1105e-04, 1.2088e-06,  ..., 6.5938e-24, 3.9993e-24,
         1.4713e-24]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9991, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[9939]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0009], device='cuda:0')


new_candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')

event: message
id: 1-p
data: {"id": "1-p", "level_type": "top_p", "nodes": [{"content": "highest", "parent": 0, "prob": 0.9232035875320435}]}





num_batches


1

infer start: GPU memory used: 15179 MB.

batch_candidates
torch.Size([1, 13])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 13, 32064])

hidden_states[-1]
torch.Size([1, 13, 3072])
infer - after batch run: GPU memory used: 15181 MB.

candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 2.8906,  3.1875, -7.0938,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9984e-01, 1.2339e-04, 3.5352e-05,  ..., 5.1391e-24, 4.5352e-24,
         3.5320e-24]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9998, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[14378]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0002], device='cuda:0')


new_candidates
torch.Size([1, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-0.0801], device='cuda:0')

event: message
id: 2-p
data: {"id": "2-p", "level_type": "top_p", "nodes": [{"content": "mountain", "parent": 0, "prob": 0.9230522513389587}]}





num_batches


1

infer start: GPU memory used: 15181 MB.

batch_candidates
torch.Size([1, 14])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 14, 32064])

hidden_states[-1]
torch.Size([1, 14, 3072])
infer - after batch run: GPU memory used: 15181 MB.

candidates
torch.Size([1, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-0.0801], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 4.8125,  1.2734, -5.8750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[8.1600e-01, 9.7458e-02, 8.6006e-02,  ..., 5.2119e-21, 4.0591e-21,
         1.6921e-21]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.8160, 0.9135, 0.9995,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-0.0801, -0.0801], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[373],
        [297]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.2033, -2.3283], device='cuda:0')


new_candidates
torch.Size([2, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-0.2834, -2.4084], device='cuda:0')

event: message
id: 3-p
data: {"id": "3-p", "level_type": "top_p", "nodes": [{"content": "on", "parent": 0, "prob": 0.7532129883766174}, {"content": "in", "parent": 0, "prob": 0.08995844423770905}]}





num_batches


1

infer start: GPU memory used: 15181 MB.

batch_candidates
torch.Size([2, 15])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 15, 32064])

hidden_states[-1]
torch.Size([2, 15, 3072])
infer - after batch run: GPU memory used: 15199 MB.

candidates
torch.Size([2, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-0.2834, -2.4084], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 1.6797,  0.8750, -4.7812,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.3438, -5.4688, -4.6875,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.9996e-01, 3.5356e-05, 1.7603e-06,  ..., 3.5325e-24, 2.4278e-24,
         3.7232e-25],
        [8.1683e-01, 1.8226e-01, 5.1193e-04,  ..., 3.4021e-20, 3.4021e-20,
         7.5910e-21]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8168, 0.9991, 0.9996,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 1], device='cuda:0')


carryover_candidates
torch.Size([3, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.4084, -2.4084], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[11563],
        [ 4958],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-3.7432e-05, -2.0232e-01, -1.7023e+00], device='cuda:0')


new_candidates
torch.Size([3, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.6107, -4.1107], device='cuda:0')

event: message
id: 4-p
data: {"id": "4-p", "level_type": "top_p", "nodes": [{"content": "Earth", "parent": 0, "prob": 0.7531847953796387}, {"content": "terms", "parent": 1, "prob": 0.0734809935092926}, {"content": "the", "parent": 1, "prob": 0.016395829617977142}]}





num_batches


1

infer start: GPU memory used: 15219 MB.

batch_candidates
torch.Size([3, 16])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 16, 32064])

hidden_states[-1]
torch.Size([3, 16, 3072])
infer - after batch run: GPU memory used: 15235 MB.

candidates
torch.Size([3, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.6107, -4.1107], device='cuda:0')


embeddings_np
(3, 3072)


array([[-1.4765625 , -0.0112915 ,  1.4921875 , ..., -0.94921875,
        -0.27734375, -1.515625  ],
       [ 0.640625  , -1.109375  ,  2.71875   , ..., -0.9140625 ,
         2.5625    , -2.1875    ],
       [ 0.09326172, -0.6796875 ,  3.390625  , ..., -0.69921875,
        -3.28125   , -1.0859375 ]], dtype=float32)


k_mean_space
(3, 2)


array([[105.234764,  50.63994 ],
       [  0.      ,  93.1218  ],
       [106.760445,  50.63994 ]], dtype=float32)


k_mean_clusters
(3,)


array([1, 0, 1], dtype=int32)


k_mean_logprob_mass
(2,)


array([-2.6107285 , -4.39417294])


closest
(2,)


array([1, 0])


new_candidates
torch.Size([2, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563]], device='cuda:0')


new_candidate_parents


[1, 0]


new_candidate_aunts


[[1], [0, 2]]


new_candidate_logprobs
torch.Size([2])


tensor([-2.6107, -4.3942], device='cuda:0', dtype=torch.float64)

event: message
id: 5-k"
data: {"id": "5-k", "level_type": "k_means", "nodes": [{"content": "terms", "parent": 1, "aunts": [1], "prob": 0.07348099318896636}, {"content": "Earth", "parent": 0, "aunts": [0, 2], "prob": 0.012349089582051961}]}





candidates
torch.Size([2, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-2.6107, -4.3942], device='cuda:0', dtype=torch.float64)


last_tok_logits
torch.Size([2, 32064])


tensor([[  1.5703,  -2.2031,  -5.8750,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.9766,  -6.3438, -10.3125,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[1.0000e+00, 4.0587e-10, 1.3177e-10,  ..., 1.3697e-25, 1.2088e-25,
         1.0668e-25],
        [9.8127e-01, 1.5861e-02, 2.1465e-03,  ..., 8.4822e-22, 5.8297e-22,
         1.6702e-22]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9813, 0.9971, 0.9993,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


carryover_candidates
torch.Size([2, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-2.6107, -4.3942], device='cuda:0', dtype=torch.float64)


new_candidate_toks
torch.Size([2, 1])


tensor([[  310],
        [29892]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([ 0.0000, -0.0189], device='cuda:0')


new_candidates
torch.Size([2, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-2.6107, -4.4131], device='cuda:0', dtype=torch.float64)

event: message
id: 5-p
data: {"id": "5-p", "level_type": "top_p", "nodes": [{"content": "of", "parent": 0, "prob": 0.07348099318896636}, {"content": ",", "parent": 1, "prob": 0.012117801617336242}]}





num_batches


1

infer start: GPU memory used: 15235 MB.

batch_candidates
torch.Size([2, 17])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 17, 32064])

hidden_states[-1]
torch.Size([2, 17, 3072])
infer - after batch run: GPU memory used: 15235 MB.

candidates
torch.Size([2, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-2.6107, -4.4131], device='cuda:0', dtype=torch.float64)


last_tok_logits
torch.Size([2, 32064])


tensor([[  0.2129,  -0.4570,  -6.7500,  ...,   0.0000,   0.0000,   0.0000],
        [ -3.5469,  -2.0000, -10.3750,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[3.9407e-01, 1.8615e-01, 1.8615e-01,  ..., 1.0154e-18, 6.1589e-19,
         5.4352e-19],
        [4.0809e-01, 3.1782e-01, 2.1843e-01,  ..., 1.1798e-19, 8.6316e-20,
         7.6173e-20]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.3941, 0.5802, 0.7664,  ..., 1.0000, 1.0000, 1.0000],
        [0.4081, 0.7259, 0.9443,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([10])


tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1], device='cuda:0')


carryover_candidates
torch.Size([10, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   33


carryover_candidate_logprobs
torch.Size([10])


tensor([-2.6107, -2.6107, -2.6107, -2.6107, -2.6107, -2.6107, -2.6107, -4.4131,
        -4.4131, -4.4131], device='cuda:0', dtype=torch.float64)


new_candidate_toks
torch.Size([10, 1])


tensor([[11858],
        [ 3171],
        [19224],
        [ 2038],
        [  967],
        [ 2533],
        [11563],
        [  408],
        [ 2729],
        [  297]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([10])


tensor([-0.9312, -1.6812, -1.6812, -3.0562, -3.3062, -3.3062, -3.5562, -0.8963,
        -1.1463, -1.5213], device='cuda:0')


new_candidates
torch.Size([10, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,   967],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  2533],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958, 


new_candidate_logprobs
torch.Size([10])


tensor([-3.5420, -4.2920, -4.2920, -5.6670, -5.9170, -5.9170, -6.1670, -5.3094,
        -5.5594, -5.9344], device='cuda:0', dtype=torch.float64)

event: message
id: 6-p
data: {"id": "6-p", "level_type": "top_p", "nodes": [{"content": "elev", "parent": 0, "prob": 0.02895674790207378}, {"content": "height", "parent": 0, "prob": 0.013678199185093203}, {"content": "peak", "parent": 0, "prob": 0.013678199185093203}, {"content": "above", "parent": 0, "prob": 0.0034583899410235456}, {"content": "its", "parent": 0, "prob": 0.002693396794235407}, {"content": "sum", "parent": 0, "prob": 0.002693396794235407}, {"content": "Earth", "parent": 0, "prob": 0.0020976195324725463}, {"content": "as", "parent": 1, "prob": 0.00494511213828901}, {"content": "based", "parent": 1, "prob": 0.0038512569761225794}, {"content": "in", "parent": 1, "prob": 0.002646927945096369}]}





num_batches


2

infer start: GPU memory used: 15235 MB.

batch_candidates
torch.Size([8, 18])

batch_candidate_logprobs
torch.Size([8])

batch_logits
torch.Size([8, 18, 32064])

hidden_states[-1]
torch.Size([8, 18, 3072])
infer - after batch run: GPU memory used: 15353 MB.

batch_candidates
torch.Size([2, 18])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 18, 32064])

hidden_states[-1]
torch.Size([2, 18, 3072])
infer - after batch run: GPU memory used: 15381 MB.

candidates
torch.Size([10, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,   967],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  2533],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958, 


candidate_logprobs
torch.Size([10])


tensor([-3.5420, -4.2920, -4.2920, -5.6670, -5.9170, -5.9170, -6.1670, -5.3094,
        -5.5594, -5.9344], device='cuda:0', dtype=torch.float64)


embeddings_np
(10, 3072)


array([[ 1.5625    , -2.265625  ,  1.78125   , ...,  1.15625   ,
         1.265625  , -0.3515625 ],
       [-0.71484375,  0.29296875,  3.296875  , ...,  0.05078125,
         1.15625   , -0.35546875],
       [-0.75      , -1.2109375 ,  3.328125  , ...,  0.03637695,
        -1.0546875 , -0.5078125 ],
       ...,
       [-1.1328125 ,  0.54296875,  0.23242188, ..., -1.6953125 ,
         1.109375  , -5.3125    ],
       [ 0.51171875, -1.3359375 ,  0.40625   , ...,  0.41601562,
         0.46289062, -3.9375    ],
       [-2.484375  , -0.44921875,  2.4375    , ...,  0.34375   ,
        -2.703125  , -4.28125   ]], dtype=float32)


k_mean_space
(10, 2)


array([[51.670624, 84.66034 ],
       [87.51908 , 61.940083],
       [85.87011 , 58.84681 ],
       [92.760315, 71.27328 ],
       [89.12618 , 60.97973 ],
       [51.670624, 82.49791 ],
       [90.98064 , 63.01418 ],
       [97.24759 , 67.10196 ],
       [97.51835 , 69.43807 ],
       [94.846634, 62.403366]], dtype=float32)


k_mean_clusters
(10,)


array([0, 1, 1, 1, 1, 0, 1, 1, 1, 1], dtype=int32)


k_mean_logprob_mass
(2,)


array([ -9.45890415, -43.13782734])


closest
(2,)


array([0, 2])


new_candidates
torch.Size([2, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224]],
       device='cuda:0')


new_candidate_parents


[0, 2]


new_candidate_aunts


[[0, 5], [1, 2, 3, 4, 6, 7, 8, 9]]


new_candidate_logprobs
torch.Size([2])


tensor([ -9.4589, -43.1378], device='cuda:0', dtype=torch.float64)

event: message
id: 7-k"
data: {"id": "7-k", "level_type": "k_means", "nodes": [{"content": "elev", "parent": 0, "aunts": [0, 5], "prob": 7.799201197092837e-05}, {"content": "peak", "parent": 2, "aunts": [1, 2, 3, 4, 6, 7, 8, 9], "prob": 1.8428060336253292e-19}]}





candidates
torch.Size([2, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224]],
       device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([ -9.4589, -43.1378], device='cuda:0', dtype=torch.float64)


last_tok_logits
torch.Size([2, 32064])


tensor([[-0.7148, -1.0703, -1.1406,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.1250, -1.9844, -7.7188,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.9999e-01, 7.8892e-06, 1.9947e-06,  ..., 2.6103e-23, 1.5832e-23,
         6.5997e-24],
        [6.9831e-01, 2.0007e-01, 9.4507e-02,  ..., 2.5163e-22, 1.7294e-22,
         1.6086e-23]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.6983, 0.8984, 0.9929,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 1, 1, 1], device='cuda:0')


carryover_candidates
torch.Size([4, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([4])


tensor([ -9.4589, -43.1378, -43.1378, -43.1378], device='cuda:0',
       dtype=torch.float64)


new_candidate_toks
torch.Size([4, 1])


tensor([[  362],
        [11858],
        [ 3171],
        [ 2038]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-1.0610e-05, -3.5909e-01, -1.6091e+00, -2.3591e+00], device='cuda:0')


new_candidates
torch.Size([4, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  2038]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([4])


tensor([ -9.4589, -43.4969, -44.7469, -45.4969], device='cuda:0',
       dtype=torch.float64)

event: message
id: 7-p
data: {"id": "7-p", "level_type": "top_p", "nodes": [{"content": "ation", "parent": 0, "prob": 7.799118450478204e-05}, {"content": "elev", "parent": 1, "prob": 1.286858130179515e-19}, {"content": "height", "parent": 1, "prob": 3.686910491506643e-20}, {"content": "above", "parent": 1, "prob": 1.741572991525982e-20}]}





num_batches


1

infer start: GPU memory used: 15405 MB.

batch_candidates
torch.Size([4, 19])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 19, 32064])

hidden_states[-1]
torch.Size([4, 19, 3072])
infer - after batch run: GPU memory used: 15405 MB.

candidates
torch.Size([4, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  2038]],
       device='cuda:0')


candidate_logprobs
torch.Size([4])


tensor([ -9.4589, -43.4969, -44.7469, -45.4969], device='cuda:0',
       dtype=torch.float64)


embeddings_np
(4, 3072)


array([[-1.2265625 , -1.7265625 ,  1.828125  , ..., -0.35742188,
         0.8828125 ,  0.328125  ],
       [ 2.125     , -2.265625  ,  2.328125  , ...,  0.8203125 ,
         0.359375  , -2.0625    ],
       [-1.0546875 , -0.04760742,  1.234375  , ..., -2.640625  ,
        -0.08251953, -0.78125   ],
       [-2.328125  , -1.921875  ,  2.46875   , ...,  1.109375  ,
         0.11474609, -1.1484375 ]], dtype=float32)


k_mean_space
(4, 2)


array([[ 42.043888, 100.05355 ],
       [ 64.13674 , 105.39813 ],
       [ 43.182697, 100.9499  ],
       [ 88.62732 ,   0.      ]], dtype=float32)


k_mean_clusters
(4,)


array([0, 0, 0, 1], dtype=int32)


k_mean_logprob_mass
(2,)


array([-97.70274085, -45.49691314])


closest
(2,)


array([0, 3])


new_candidates
torch.Size([2, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  2038]],
       device='cuda:0')


new_candidate_parents


[0, 3]


new_candidate_aunts


[[0, 1, 2], [3]]


new_candidate_logprobs
torch.Size([2])


tensor([-97.7027, -45.4969], device='cuda:0', dtype=torch.float64)

event: message
id: 8-k"
data: {"id": "8-k", "level_type": "k_means", "nodes": [{"content": "ation", "parent": 0, "aunts": [0, 1, 2], "prob": 3.700315724286182e-43}, {"content": "above", "parent": 3, "aunts": [3], "prob": 1.741572991525982e-20}]}





candidates
torch.Size([2, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  2038]],
       device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-97.7027, -45.4969], device='cuda:0', dtype=torch.float64)


last_tok_logits
torch.Size([2, 32064])


tensor([[-0.0850, -0.9180, -5.1875,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.6797, -1.0234, -9.4375,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.9033e-01, 3.1520e-03, 2.1663e-03,  ..., 1.5174e-20, 1.3391e-20,
         2.3270e-21],
        [1.0000e+00, 8.3153e-07, 5.0435e-07,  ..., 1.6687e-24, 1.1469e-24,
         8.9319e-25]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[0.9903, 0.9935, 0.9956,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


carryover_candidates
torch.Size([2, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  2038]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-97.7027, -45.4969], device='cuda:0', dtype=torch.float64)


new_candidate_toks
torch.Size([2, 1])


tensor([[2038],
        [7205]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-9.7195e-03, -2.3842e-06], device='cuda:0')


new_candidates
torch.Size([2, 20])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  2038,  7205]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-97.7125, -45.4969], device='cuda:0', dtype=torch.float64)

event: message
id: 8-p
data: {"id": "8-p", "level_type": "top_p", "nodes": [{"content": "above", "parent": 0, "prob": 3.66452455770134e-43}, {"content": "sea", "parent": 1, "prob": 1.7415688392922036e-20}]}





num_batches


1

infer start: GPU memory used: 15405 MB.

batch_candidates
torch.Size([2, 20])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 20, 32064])

hidden_states[-1]
torch.Size([2, 20, 3072])
infer - after batch run: GPU memory used: 15405 MB.

candidates
torch.Size([2, 20])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  2038,  7205]],
       device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-97.7125, -45.4969], device='cuda:0', dtype=torch.float64)


last_tok_logits
torch.Size([2, 32064])


tensor([[  1.5781,  -3.1719, -11.3125,  ...,   0.0000,   0.0000,   0.0000],
        [  0.8398,  -5.3438, -11.2500,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.9999e-01, 1.9947e-06, 1.5535e-06,  ..., 3.3516e-23, 2.9578e-23,
         2.3035e-23],
        [1.0000e+00, 2.6996e-07, 1.8554e-07,  ..., 4.9401e-28, 3.8473e-28,
         5.9001e-29]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


carryover_candidates
torch.Size([2, 20])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  2038,  7205]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-97.7125, -45.4969], device='cuda:0', dtype=torch.float64)


new_candidate_toks
torch.Size([2, 1])


tensor([[7205],
        [3233]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-5.9605e-06, -5.9605e-07], device='cuda:0')


new_candidates
torch.Size([2, 21])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362,  2038,
          7205],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 19224,  2038,  7205,
          3233]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-97.7125, -45.4969], device='cuda:0', dtype=torch.float64)

event: message
id: 9-p
data: {"id": "9-p", "level_type": "top_p", "nodes": [{"content": "sea", "parent": 0, "prob": 3.664502715432991e-43}, {"content": "level", "parent": 1, "prob": 1.7415678012362957e-20}]}




event: message
id: END
data: []




