In [1]:
!pip install transformers accelerate optimum nvidia-ml-py torchmetrics



In [3]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json
from sklearn.cluster import KMeans
from torchmetrics.functional import pairwise_cosine_similarity

torch.random.manual_seed(0)

<torch._C.Generator at 0x7fd86f22a2b0>

In [4]:


from pynvml import *

def check_gpu(step):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"{step}: GPU memory used: {info.used // 1024**2} MB.")
    
def D(obj, label=None, c=True):
    print()
    if label:
        print(label)
        
    if isinstance(obj, tuple):
        print('tuple size', len(obj), ':')
        if c: # Contents
            display(obj)
    elif isinstance(obj, torch.Tensor) or isinstance(obj, np.ndarray):
        print(obj.shape)
        if c: # Contents
            display(obj)
    else:
        if c: # Contents
            display(obj)
            
def DS(obj, label=None):
    D(obj, label, c=False)
    


In [29]:
class InferenceTensor:
    def __init__(self):
        print('Initializing model...')
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            torch_dtype=torch.bfloat16,
            device_map='auto',
            trust_remote_code=True,
            use_cache=True,
            # attn_implementation='flash_attention_2',
        )
        print('Initializing tokenizer...')
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.batch_size = 8
        
    def candidates_generator(self, top_p: float, top_p_decay: float, top_k: float, max_beams: int, max_new_tokens: int, prompt: str):
        candidates, candidate_logprobs = self._init_candidates(prompt)
        for level_idx in range(max_new_tokens):
            logits, embeddings = self._infer(candidates, candidate_logprobs)

            self._farthest_neighbors(logits, embeddings, candidates, candidate_logprobs, max_beams)
            
            if candidates.shape[0] > max_beams:
                start = time.perf_counter()
                candidates, candidate_parents, candidate_aunts, candidate_logprobs, logits = self._k_means(logits, embeddings, candidates, candidate_logprobs, max_beams)
                inference_duration = time.perf_counter() - start
                print('K MEANS PRIOR {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
                yield self._format_k_means(level_idx, candidates, candidate_parents, candidate_aunts, candidate_logprobs, inference_duration)
                print('K MEANS AFTER {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))

            start = time.perf_counter()
            candidates, candidate_parents, candidate_logprobs = self._top_p(logits, candidates, candidate_logprobs, top_p, top_k)
            inference_duration = time.perf_counter() - start
            print('TOP P PRIOR {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
            yield self._format_top_p(level_idx, candidates, candidate_parents, candidate_logprobs, inference_duration)
            print('TOP P AFTER {}: ({}) {} candidates, {} inference time, {} total time'.format(level_idx, time.perf_counter(), candidates.shape[0], inference_duration, time.perf_counter() - start))
            top_p *= top_p_decay

        yield f"event: message\nid: END\ndata: []\n\n"

    def _format_k_means(self, level_idx, candidates, candidate_parents, candidate_aunts, candidate_logprobs, duration):
        candidate_texts = self.tokenizer.convert_ids_to_tokens(candidates[:, -1], skip_special_tokens=True)
        candidate_probs = candidate_logprobs.exp()
        candidate_dicts = []
        idx = f"{level_idx}-k"
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i], 'aunts': candidate_aunts[i], 'prob': candidate_probs[i].item()})
        data = json.dumps({'id': idx, 'level_type': 'gather', 'duration': duration, 'nodes': candidate_dicts})
        return f"event: message\nid: {idx}\"\ndata: {data}\n\n"

    def _format_top_p(self, level_idx, candidates, candidate_parents, candidate_logprobs, duration):
        candidate_texts = self.tokenizer.convert_ids_to_tokens(candidates[:, -1], skip_special_tokens=True)
        candidate_probs = candidate_logprobs.exp()
        candidate_dicts = []
        idx = f"{level_idx}-p"
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i], 'prob': candidate_probs[i].item()})
        data = json.dumps({'id': idx, 'level_type': 'sample', 'duration': duration, 'nodes': candidate_dicts})
        return f"event: message\nid: {idx}\ndata: {data}\n\n"


    def _init_candidates(self, text: str):
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
        inputs = self.tokenizer(prompt, return_tensors='pt')
        D(inputs.input_ids, 'input_ids')

        candidates = inputs.input_ids.to(self.device)
        candidate_logprobs = torch.zeros((1), dtype=torch.float32, device=self.device)

        return candidates, candidate_logprobs

    def _k_means(self, logits, embeddings, candidates, candidate_logprobs, max_beams):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        # === CPU ===
        embeddings_np = embeddings.float().numpy(force=True)
        D(embeddings_np, 'embeddings_np')
        k_means = KMeans(n_clusters=min(max_beams, embeddings_np.shape[0]), random_state=0, n_init="auto")
        k_mean_space = k_means.fit_transform(embeddings_np)
        D(k_mean_space, 'k_mean_space')
        k_mean_clusters = k_means.predict(embeddings_np)
        D(k_mean_clusters, 'k_mean_clusters')
        k_mean_logprob_mass = np.log(np.bincount(k_mean_clusters, weights=candidate_logprobs.cpu().exp()))
        D(k_mean_logprob_mass, 'k_mean_logprob_mass')
        closest = np.argmin(k_mean_space, axis=0)
        D(closest, 'closest')
        # === END CPU ===
        
        closest_indices = torch.from_numpy(closest).to(self.device)
        new_candidates = candidates.index_select(0, closest_indices)
        D(new_candidates, 'new_candidates')
        new_candidate_parents = closest_indices.tolist()
        D(new_candidate_parents, 'new_candidate_parents')
        new_candidate_aunts = [torch.nonzero(torch.from_numpy(k_mean_clusters).to(self.device) == i).squeeze(-1).tolist() for i in range(new_candidates.shape[0])]
        D(new_candidate_aunts, 'new_candidate_aunts')
        new_candidate_logprobs = torch.from_numpy(k_mean_logprob_mass).to(self.device)
        D(new_candidate_logprobs, 'new_candidate_logprobs')
        new_candidate_logits = logits.index_select(0, closest_indices)
        
        return new_candidates, new_candidate_parents, new_candidate_aunts, new_candidate_logprobs, new_candidate_logits
        
    def _farthest_neighbors(self, logits, embeddings, candidates, candidate_logprobs, max_beams):
        print('FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        D(embeddings, 'embeddings')
        
        selected = torch.zeros((candidates.shape[0],), dtype=torch.bool).to(self.device)
        max_prob_idx = candidate_logprobs.argmax()
        selected[max_prob_idx] = 1
        
        D(selected, 'selected')
        
        for idx in range(min(max_beams - 1, candidates.shape[0])):
            selected_embeddings = embeddings[selected]
            D(selected_embeddings, 'selected_embeddings')
            # Add 2 because bfloat16 on cuda can have imprecision and we need 0 to be lower than every
            # cosine distance
            distances = torch.add(2, pairwise_cosine_similarity(embeddings, selected_embeddings), alpha=-1)
            D(distances, 'distances')
            min_distances = torch.min(distances, dim=1).values
            D(min_distances, 'min_distances')
            min_remaining_distances = min_distances * ~selected
            D(min_remaining_distances, 'min_remaining_distances')
            next_selected = min_remaining_distances.argmax(dim=0)
            selected[next_selected] = 1
            D(selected, 'selected (end of loop)')
            
        # We have all the candidates that are selected to move forward. Figure out which probability mass
        # to assign where.
        selected_embeddings = embeddings[selected]
        D(selected_embeddings, 'selected_embeddings')
        # Add 2 because bfloat16 on cuda can have imprecision and we need 0 to be lower than every
        # cosine distance
        distances = torch.add(2, pairwise_cosine_similarity(embeddings, selected_embeddings), alpha=-1)
        D(distances, 'distances')
        
        closest_per_candidate = distances.argmin(dim=1)
        D(closest_per_candidate, 'closest_per_candidate')
        
        new_candidates = candidates[selected]
        D(new_candidates, 'new_candidates')
        new_candidate_parents = torch.arange(candidates.shape[0]).to(self.device)[selected]
        D(new_candidate_parents, 'new_candidate_parents')
        new_candidate_aunts = [list(torch.nonzero(closest_per_candidate == i).squeeze(-1).tolist().filter(lambda x: x != i)) \
                       for i in range(new_candidates.shape[0])]
        D(new_candidate_aunts, 'new_candidate_aunts')
        new_candidate_logprobs = torch.zeros((new_candidates.shape[0],)).to(self.device)
        new_candidate_logprobs.index_add_(0, closest_per_candidate, candidate_logprobs)
        D(new_candidate_logprobs, 'new_candidate_logprobs')
        new_candidate_logits = logits[selected]
                


    def _top_p(self, logits, candidates, candidate_logprobs, top_p, top_k):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        
        last_tok_logits = logits[:, -1, :]
        D(last_tok_logits, 'last_tok_logits')

        sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
        DS(sorted_logits, 'sorted_logits')
        DS(sorted_indices, 'sorted_indices')
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        D(sorted_probs, 'sorted_probs')
        display(sorted_probs.sum(dim=1))
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        D(cum_probs, 'cum_probs')

        # Create tensor of bools indicating which indices are cumulatively less than top_p
        keep_indices = cum_probs < top_p

        # Keep the last element that went over top_p
        keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
        keep_indices[:, 0] = 1  # Always keep the first element
        D(keep_indices, 'keep_indices')

        # Don't keep any indices that are greater than top_k
        keep_indices[:, top_k:] = 0
        D(keep_indices, 'keep_indices after top_k')

        new_candidate_parents = keep_indices.nonzero()[:, 0]
        D(new_candidate_parents, 'new_candidate_parents')

        # OPTIM: Potential optimization -- have a fixed tensor of size (max_candidates, max_tokens) and copy this into that (batch-aware).
        # OPTIM: consider which of these operations can be done in-place to prevent new allocations?
        carryover_candidates = candidates.index_select(0, new_candidate_parents)
        D(carryover_candidates, 'carryover_candidates')

        # Similar code could be used to trace entire origin of sequence. For now since server just traces parent of the preceding generation, not needed
        # carryover_candidate_parents = candidate_parents.index_select(0, carryover_candidate_indices)  # Not strictly necessary since 1d
        # D(carryover_candidate_parents, 'carryover_candidate_parents')

        carryover_candidate_logprobs = candidate_logprobs.index_select(0, new_candidate_parents)  # Not strictly necessary since 1d
        D(carryover_candidate_logprobs, 'carryover_candidate_logprobs')

        new_candidate_toks = sorted_indices[keep_indices].unsqueeze(1)
        D(new_candidate_toks, 'new_candidate_toks')
        new_candidate_tok_logprobs = sorted_probs[keep_indices].log()
        D(new_candidate_tok_logprobs, 'new_candidate_tok_logprobs')

        new_candidates = torch.cat([carryover_candidates, new_candidate_toks], dim=1)
        D(new_candidates, 'new_candidates')
        new_candidate_logprobs = carryover_candidate_logprobs.add_(new_candidate_tok_logprobs)
        D(new_candidate_logprobs, 'new_candidate_logprobs')

        return new_candidates, new_candidate_parents.tolist(), new_candidate_logprobs


    def _infer(self, candidates, candidate_logprobs):
        with torch.inference_mode():
            num_batches = (candidates.shape[0] + self.batch_size - 1) // self.batch_size  # Round up to nearest whole number of batches
            D(num_batches, 'num_batches')

            check_gpu('infer start')
            output_logits_list = []
            output_embeddings_list = []
            for i in range(0, num_batches, 1):
                batch_candidates = candidates[i * self.batch_size:(i + 1) * self.batch_size]
                DS(batch_candidates, 'batch_candidates')
                batch_candidate_logprobs = candidate_logprobs[i * self.batch_size:(i + 1) * self.batch_size]
                DS(batch_candidate_logprobs, 'batch_candidate_logprobs')

                batch_outputs = self.model(input_ids=batch_candidates, output_hidden_states=True)
                DS(batch_outputs.logits, 'batch_logits')
                DS(batch_outputs.hidden_states[-1], 'hidden_states[-1]')

                output_logits_list.append(batch_outputs.logits)
                output_embeddings_list.append(batch_outputs.hidden_states[-1][:,-1,:])
                check_gpu('infer - after batch run')

            output_logits = torch.cat(output_logits_list, dim=0)
            output_embeddings = torch.cat(output_embeddings_list, dim=0)
            
            return output_logits, output_embeddings
        
it = InferenceTensor()

for x in it.candidates_generator(top_p=0.9, top_p_decay=0.99, top_k=2, max_beams=3, max_new_tokens=10, prompt='What is the highest mountain?'):
    print(x)
    print()
    print('====================================')
    print()

Initializing model...




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Initializing tokenizer...

input_ids
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]])


num_batches


1

infer start: GPU memory used: 14583 MB.

batch_candidates
torch.Size([1, 11])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 11, 32064])

hidden_states[-1]
torch.Size([1, 11, 3072])
infer - after batch run: GPU memory used: 14775 MB.
FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

candidates
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


embeddings
torch.Size([1, 3072])


tensor([[-0.7383,  1.3906,  1.9766,  ...,  2.1094, -0.7656, -0.1934]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([1])


tensor([True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.7383,  1.3906,  1.9766,  ...,  2.1094, -0.7656, -0.1934]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([1, 1])


tensor([[0.9922]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([1])


tensor([0.9922], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([1])


tensor([0.], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([1])


tensor([True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.7383,  1.3906,  1.9766,  ...,  2.1094, -0.7656, -0.1934]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([1, 1])


tensor([[0.9922]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([1])


tensor([0], device='cuda:0')


new_candidates
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


new_candidate_aunts


[[0]]


new_candidate_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


candidates
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 5.7812, -1.4688, -3.2969,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.2405e-01, 7.5850e-02, 8.8812e-05,  ..., 1.1622e-21, 7.0490e-22,
         3.7730e-22]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9240, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([0.], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[450]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


new_candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')

TOP P PRIOR 0: (8559.413267655) 1 candidates, 0.025079651999476482 inference time, 0.025082172998736496 total time
event: message
id: 0-p
data: {"id": "0-p", "level_type": "sample", "duration": 0.025079651999476482, "nodes": [{"content": "\u2581The", "parent": 0, "prob": 0.924048125743866}]}




TOP P AFTER 0: (8559.413601589) 1 candidates, 0.025079651999476482 inference time, 0.02541591699991841 total time

num_batches


1

infer start: GPU memory used: 14775 MB.

batch_candidates
torch.Size([1, 12])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 12, 32064])

hidden_states[-1]
torch.Size([1, 12, 3072])
infer - after batch run: GPU memory used: 14775 MB.
FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


embeddings
torch.Size([1, 3072])


tensor([[-0.5547,  0.8164,  1.5469,  ...,  0.9648, -1.7188,  0.1953]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([1])


tensor([True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.5547,  0.8164,  1.5469,  ...,  0.9648, -1.7188,  0.1953]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([1, 1])


tensor([[0.9922]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([1])


tensor([0.9922], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([1])


tensor([0.], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([1])


tensor([True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.5547,  0.8164,  1.5469,  ...,  0.9648, -1.7188,  0.1953]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([1, 1])


tensor([[0.9922]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([1])


tensor([0], device='cuda:0')


new_candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


new_candidate_aunts


[[0]]


new_candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[-1.2969,  1.0312, -5.0938,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9909e-01, 9.1105e-04, 1.2088e-06,  ..., 6.5938e-24, 3.9993e-24,
         1.4713e-24]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9991, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 12])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-0.0790], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[9939]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0009], device='cuda:0')


new_candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')

TOP P PRIOR 1: (8559.569830508) 1 candidates, 0.025697528000819148 inference time, 0.02570001899948693 total time
event: message
id: 1-p
data: {"id": "1-p", "level_type": "sample", "duration": 0.025697528000819148, "nodes": [{"content": "\u2581highest", "parent": 0, "prob": 0.9232035875320435}]}




TOP P AFTER 1: (8559.570429112) 1 candidates, 0.025697528000819148 inference time, 0.026299402999939048 total time

num_batches


1

infer start: GPU memory used: 14775 MB.

batch_candidates
torch.Size([1, 13])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 13, 32064])

hidden_states[-1]
torch.Size([1, 13, 3072])
infer - after batch run: GPU memory used: 14777 MB.
FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')


embeddings
torch.Size([1, 3072])


tensor([[-0.6406, -0.6133,  3.0625,  ..., -0.5703, -1.8750,  0.8086]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([1])


tensor([True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.6406, -0.6133,  3.0625,  ..., -0.5703, -1.8750,  0.8086]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([1, 1])


tensor([[1.]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([1])


tensor([1.], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([1])


tensor([0.], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([1])


tensor([True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.6406, -0.6133,  3.0625,  ..., -0.5703, -1.8750,  0.8086]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([1, 1])


tensor([[1.]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([1])


tensor([0], device='cuda:0')


new_candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


new_candidate_aunts


[[0]]


new_candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')


candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 2.8906,  3.1875, -7.0938,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[9.9984e-01, 1.2339e-04, 3.5352e-05,  ..., 5.1391e-24, 4.5352e-24,
         3.5320e-24]], device='cuda:0')

tensor([1.], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.9998, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


carryover_candidates
torch.Size([1, 13])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([1])


tensor([-0.0799], device='cuda:0')


new_candidate_toks
torch.Size([1, 1])


tensor([[14378]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([1])


tensor([-0.0002], device='cuda:0')


new_candidates
torch.Size([1, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


new_candidate_logprobs
torch.Size([1])


tensor([-0.0801], device='cuda:0')

TOP P PRIOR 2: (8559.725297012) 1 candidates, 0.02661793599872908 inference time, 0.026620985998306423 total time
event: message
id: 2-p
data: {"id": "2-p", "level_type": "sample", "duration": 0.02661793599872908, "nodes": [{"content": "\u2581mountain", "parent": 0, "prob": 0.9230522513389587}]}




TOP P AFTER 2: (8559.725592053) 1 candidates, 0.02661793599872908 inference time, 0.026915407999695162 total time

num_batches


1

infer start: GPU memory used: 14777 MB.

batch_candidates
torch.Size([1, 14])

batch_candidate_logprobs
torch.Size([1])

batch_logits
torch.Size([1, 14, 32064])

hidden_states[-1]
torch.Size([1, 14, 3072])
infer - after batch run: GPU memory used: 14777 MB.
FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

candidates
torch.Size([1, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-0.0801], device='cuda:0')


embeddings
torch.Size([1, 3072])


tensor([[-2.1094, -1.6094,  2.3125,  ..., -2.2188, -1.5078, -0.1445]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([1])


tensor([True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-2.1094, -1.6094,  2.3125,  ..., -2.2188, -1.5078, -0.1445]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([1, 1])


tensor([[1.]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([1])


tensor([1.], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([1])


tensor([0.], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([1])


tensor([True], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-2.1094, -1.6094,  2.3125,  ..., -2.2188, -1.5078, -0.1445]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([1, 1])


tensor([[1.]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([1])


tensor([0], device='cuda:0')


new_candidates
torch.Size([1, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


new_candidate_parents
torch.Size([1])


tensor([0], device='cuda:0')


new_candidate_aunts


[[0]]


new_candidate_logprobs
torch.Size([1])


tensor([-0.0801], device='cuda:0')


candidates
torch.Size([1, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


candidate_logprobs
torch.Size([1])


tensor([-0.0801], device='cuda:0')


last_tok_logits
torch.Size([1, 32064])


tensor([[ 4.8125,  1.2734, -5.8750,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([1, 32064])

sorted_indices
torch.Size([1, 32064])

sorted_probs
torch.Size([1, 32064])


tensor([[8.1600e-01, 9.7458e-02, 8.6006e-02,  ..., 5.2119e-21, 4.0591e-21,
         1.6921e-21]], device='cuda:0')

tensor([1.0000], device='cuda:0')


cum_probs
torch.Size([1, 32064])


tensor([[0.8160, 0.9135, 0.9995,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([1, 32064])


tensor([[ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 0], device='cuda:0')


carryover_candidates
torch.Size([2, 14])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([2])


tensor([-0.0801, -0.0801], device='cuda:0')


new_candidate_toks
torch.Size([2, 1])


tensor([[373],
        [297]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([2])


tensor([-0.2033, -2.3283], device='cuda:0')


new_candidates
torch.Size([2, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


new_candidate_logprobs
torch.Size([2])


tensor([-0.2834, -2.4084], device='cuda:0')

TOP P PRIOR 3: (8559.883655938) 2 candidates, 0.0287994959999196 inference time, 0.0288037570007873 total time
event: message
id: 3-p
data: {"id": "3-p", "level_type": "sample", "duration": 0.0287994959999196, "nodes": [{"content": "\u2581on", "parent": 0, "prob": 0.7532129883766174}, {"content": "\u2581in", "parent": 0, "prob": 0.08995844423770905}]}




TOP P AFTER 3: (8559.884080169) 2 candidates, 0.0287994959999196 inference time, 0.029228007999336114 total time

num_batches


1

infer start: GPU memory used: 14777 MB.

batch_candidates
torch.Size([2, 15])

batch_candidate_logprobs
torch.Size([2])

batch_logits
torch.Size([2, 15, 32064])

hidden_states[-1]
torch.Size([2, 15, 3072])
infer - after batch run: GPU memory used: 14793 MB.
FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

candidates
torch.Size([2, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-0.2834, -2.4084], device='cuda:0')


embeddings
torch.Size([2, 3072])


tensor([[-1.8906, -1.4062,  3.4375,  ..., -0.0107, -1.4531, -2.4688],
        [-1.2656, -1.3281,  3.3750,  ..., -0.0239, -1.3516, -2.4062]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([2])


tensor([ True, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-1.8906, -1.4062,  3.4375,  ..., -0.0107, -1.4531, -2.4688]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 1])


tensor([[1.0000],
        [1.5312]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([2])


tensor([1.0000, 1.5312], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([2])


tensor([0.0000, 1.5312], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([2])


tensor([True, True], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-1.8906, -1.4062,  3.4375,  ..., -0.0107, -1.4531, -2.4688],
        [-1.2656, -1.3281,  3.3750,  ..., -0.0239, -1.3516, -2.4062]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 2])


tensor([[1.0000, 1.5312],
        [1.5312, 0.9922]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([2])


tensor([1.0000, 0.9922], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([2])


tensor([0., 0.], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([2])


tensor([True, True], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-1.8906, -1.4062,  3.4375,  ..., -0.0107, -1.4531, -2.4688],
        [-1.2656, -1.3281,  3.3750,  ..., -0.0239, -1.3516, -2.4062]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([2, 2])


tensor([[1.0000, 1.5312],
        [1.5312, 0.9922]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([2])


tensor([0, 1], device='cuda:0')


new_candidates
torch.Size([2, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


new_candidate_parents
torch.Size([2])


tensor([0, 1], device='cuda:0')


new_candidate_aunts


[[0], [1]]


new_candidate_logprobs
torch.Size([2])


tensor([-0.2834, -2.4084], device='cuda:0')


candidates
torch.Size([2, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


candidate_logprobs
torch.Size([2])


tensor([-0.2834, -2.4084], device='cuda:0')


last_tok_logits
torch.Size([2, 32064])


tensor([[ 1.6797,  0.8750, -4.7812,  ...,  0.0000,  0.0000,  0.0000],
        [ 3.3438, -5.4688, -4.6875,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([2, 32064])

sorted_indices
torch.Size([2, 32064])

sorted_probs
torch.Size([2, 32064])


tensor([[9.9996e-01, 3.5356e-05, 1.7603e-06,  ..., 3.5325e-24, 2.4278e-24,
         3.7232e-25],
        [8.1683e-01, 1.8226e-01, 5.1193e-04,  ..., 3.4021e-20, 3.4021e-20,
         7.5910e-21]], device='cuda:0')

tensor([1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([2, 32064])


tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.8168, 0.9991, 0.9996,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([2, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 1], device='cuda:0')


carryover_candidates
torch.Size([3, 15])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.4084, -2.4084], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[11563],
        [ 4958],
        [  278]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-3.7432e-05, -2.0232e-01, -1.7023e+00], device='cuda:0')


new_candidates
torch.Size([3, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.6107, -4.1107], device='cuda:0')

TOP P PRIOR 4: (8560.052081643) 3 candidates, 0.026359686999057885 inference time, 0.026362716998846736 total time
event: message
id: 4-p
data: {"id": "4-p", "level_type": "sample", "duration": 0.026359686999057885, "nodes": [{"content": "\u2581Earth", "parent": 0, "prob": 0.7531847953796387}, {"content": "\u2581terms", "parent": 1, "prob": 0.0734809935092926}, {"content": "\u2581the", "parent": 1, "prob": 0.016395829617977142}]}




TOP P AFTER 4: (8560.052461901) 3 candidates, 0.026359686999057885 inference time, 0.026742464999188087 total time

num_batches


1

infer start: GPU memory used: 14793 MB.

batch_candidates
torch.Size([3, 16])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 16, 32064])

hidden_states[-1]
torch.Size([3, 16, 3072])
infer - after batch run: GPU memory used: 14809 MB.
FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

candidates
torch.Size([3, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.6107, -4.1107], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[-1.4766, -0.0113,  1.4922,  ..., -0.9492, -0.2773, -1.5156],
        [ 0.6406, -1.1094,  2.7188,  ..., -0.9141,  2.5625, -2.1875],
        [ 0.0933, -0.6797,  3.3906,  ..., -0.6992, -3.2812, -1.0859]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-1.4766, -0.0113,  1.4922,  ..., -0.9492, -0.2773, -1.5156]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[1.0000],
        [1.8047],
        [1.7422]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([3])


tensor([1.0000, 1.8047, 1.7422], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([3])


tensor([0.0000, 1.8047, 1.7422], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([3])


tensor([ True,  True, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-1.4766, -0.0113,  1.4922,  ..., -0.9492, -0.2773, -1.5156],
        [ 0.6406, -1.1094,  2.7188,  ..., -0.9141,  2.5625, -2.1875]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 2])


tensor([[1.0000, 1.8047],
        [1.8047, 0.9922],
        [1.7422, 1.8203]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([3])


tensor([1.0000, 0.9922, 1.7422], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([3])


tensor([0.0000, 0.0000, 1.7422], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([3])


tensor([True, True, True], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[-1.4766, -0.0113,  1.4922,  ..., -0.9492, -0.2773, -1.5156],
        [ 0.6406, -1.1094,  2.7188,  ..., -0.9141,  2.5625, -2.1875],
        [ 0.0933, -0.6797,  3.3906,  ..., -0.6992, -3.2812, -1.0859]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 3])


tensor([[1.0000, 1.8047, 1.7422],
        [1.8047, 0.9922, 1.8203],
        [1.7422, 1.8203, 1.0000]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


new_candidates
torch.Size([3, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


new_candidate_aunts


[[0], [1], [2]]


new_candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.6107, -4.1107], device='cuda:0')


candidates
torch.Size([3, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.6107, -4.1107], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ -0.9766,  -6.3438, -10.3125,  ...,   0.0000,   0.0000,   0.0000],
        [  1.5703,  -2.2031,  -5.8750,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.4199,   2.5781,  -0.6797,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.8127e-01, 1.5861e-02, 2.1465e-03,  ..., 8.4822e-22, 5.8297e-22,
         1.6702e-22],
        [1.0000e+00, 4.0587e-10, 1.3177e-10,  ..., 1.3697e-25, 1.2088e-25,
         1.0668e-25],
        [9.9997e-01, 2.4300e-05, 9.4222e-07,  ..., 4.6267e-22, 3.1799e-22,
         3.1799e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.9813, 0.9971, 0.9993,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([3, 16])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([3])


tensor([-0.2834, -2.6107, -4.1107], device='cuda:0')


new_candidate_toks
torch.Size([3, 1])


tensor([[29892],
        [  310],
        [ 3186]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([3])


tensor([-1.8907e-02,  0.0000e+00, -2.7895e-05], device='cuda:0')


new_candidates
torch.Size([3, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186]], device='cuda:0')


new_candidate_logprobs
torch.Size([3])


tensor([-0.3024, -2.6107, -4.1108], device='cuda:0')

TOP P PRIOR 5: (8560.22623629) 3 candidates, 0.028676537000137614 inference time, 0.02867960799994762 total time
event: message
id: 5-p
data: {"id": "5-p", "level_type": "sample", "duration": 0.028676537000137614, "nodes": [{"content": ",", "parent": 0, "prob": 0.7390782833099365}, {"content": "\u2581of", "parent": 1, "prob": 0.0734809935092926}, {"content": "\u2581world", "parent": 2, "prob": 0.016395367681980133}]}




TOP P AFTER 5: (8560.226610847) 3 candidates, 0.028676537000137614 inference time, 0.029053455000394024 total time

num_batches


1

infer start: GPU memory used: 14809 MB.

batch_candidates
torch.Size([3, 17])

batch_candidate_logprobs
torch.Size([3])

batch_logits
torch.Size([3, 17, 32064])

hidden_states[-1]
torch.Size([3, 17, 3072])
infer - after batch run: GPU memory used: 14817 MB.
FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

candidates
torch.Size([3, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-0.3024, -2.6107, -4.1108], device='cuda:0')


embeddings
torch.Size([3, 3072])


tensor([[-1.0859,  2.1406,  2.5000,  ..., -2.1094, -0.7148, -1.5078],
        [-2.9531, -2.5938,  1.4453,  ...,  1.1484, -1.6172, -0.3184],
        [-2.0312, -1.4141,  1.0547,  ..., -1.3906,  0.1191, -0.3828]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([3])


tensor([ True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-1.0859,  2.1406,  2.5000,  ..., -2.1094, -0.7148, -1.5078]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 1])


tensor([[1.0000],
        [1.6406],
        [1.4375]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([3])


tensor([1.0000, 1.6406, 1.4375], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([3])


tensor([0.0000, 1.6406, 1.4375], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([3])


tensor([ True,  True, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-1.0859,  2.1406,  2.5000,  ..., -2.1094, -0.7148, -1.5078],
        [-2.9531, -2.5938,  1.4453,  ...,  1.1484, -1.6172, -0.3184]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 2])


tensor([[1.0000, 1.6406],
        [1.6406, 0.9922],
        [1.4375, 1.6562]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([3])


tensor([1.0000, 0.9922, 1.4375], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([3])


tensor([0.0000, 0.0000, 1.4375], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([3])


tensor([True, True, True], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[-1.0859,  2.1406,  2.5000,  ..., -2.1094, -0.7148, -1.5078],
        [-2.9531, -2.5938,  1.4453,  ...,  1.1484, -1.6172, -0.3184],
        [-2.0312, -1.4141,  1.0547,  ..., -1.3906,  0.1191, -0.3828]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([3, 3])


tensor([[1.0000, 1.6406, 1.4375],
        [1.6406, 0.9922, 1.6562],
        [1.4375, 1.6562, 1.0000]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


new_candidates
torch.Size([3, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186]], device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 2], device='cuda:0')


new_candidate_aunts


[[0], [1], [2]]


new_candidate_logprobs
torch.Size([3])


tensor([-0.3024, -2.6107, -4.1108], device='cuda:0')


candidates
torch.Size([3, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186]], device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-0.3024, -2.6107, -4.1108], device='cuda:0')


last_tok_logits
torch.Size([3, 32064])


tensor([[ -3.5156,  -2.1094, -10.3125,  ...,   0.0000,   0.0000,   0.0000],
        [  0.3340,  -0.2949,  -6.8750,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.1992,  -3.4062,  -8.2500,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[4.2495e-01, 2.9206e-01, 2.2746e-01,  ..., 1.3078e-19, 1.0185e-19,
         7.9321e-20],
        [4.5078e-01, 1.6583e-01, 1.6583e-01,  ..., 1.0251e-18, 6.2173e-19,
         4.8420e-19],
        [7.6299e-01, 1.1701e-01, 9.1126e-02,  ..., 3.7954e-21, 2.9558e-21,
         6.5953e-22]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.4249, 0.7170, 0.9445,  ..., 1.0000, 1.0000, 1.0000],
        [0.4508, 0.6166, 0.7824,  ..., 1.0000, 1.0000, 1.0000],
        [0.7630, 0.8800, 0.9711,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([6])


tensor([0, 0, 1, 1, 2, 2], device='cuda:0')


carryover_candidates
torch.Size([6, 17])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186]], device='cuda:0')


carryover_candidate_logprobs
torch.Size([6])


tensor([-0.3024, -0.3024, -2.6107, -2.6107, -4.1108, -4.1108], device='cuda:0')


new_candidate_toks
torch.Size([6, 1])


tensor([[  408],
        [ 2729],
        [11858],
        [ 3171],
        [29892],
        [  338]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([6])


tensor([-0.8558, -1.2308, -0.7968, -1.7968, -0.2705, -2.1455], device='cuda:0')


new_candidates
torch.Size([6, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186,   338]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([6])


tensor([-1.1581, -1.5331, -3.4075, -4.4075, -4.3813, -6.2563], device='cuda:0')

TOP P PRIOR 6: (8560.403769524) 6 candidates, 0.030444187001194223 inference time, 0.0304478170000948 total time
event: message
id: 6-p
data: {"id": "6-p", "level_type": "sample", "duration": 0.030444187001194223, "nodes": [{"content": "\u2581as", "parent": 0, "prob": 0.3140707314014435}, {"content": "\u2581based", "parent": 0, "prob": 0.21585746109485626}, {"content": "\u2581elev", "parent": 1, "prob": 0.03312361240386963}, {"content": "\u2581height", "parent": 1, "prob": 0.012185496278107166}, {"content": ",", "parent": 2, "prob": 0.012509509921073914}, {"content": "\u2581is", "parent": 2, "prob": 0.0019183953991159797}]}




TOP P AFTER 6: (8560.404458674) 6 candidates, 0.030444187001194223 inference time, 0.031136888001128682 total time

num_batches


1

infer start: GPU memory used: 14817 MB.

batch_candidates
torch.Size([6, 18])

batch_candidate_logprobs
torch.Size([6])

batch_logits
torch.Size([6, 18, 32064])

hidden_states[-1]
torch.Size([6, 18, 3072])
infer - after batch run: GPU memory used: 14997 MB.
FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

candidates
torch.Size([6, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186,   338]],
       device='cuda:0')


candidate_logprobs
torch.Size([6])


tensor([-1.1581, -1.5331, -3.4075, -4.4075, -4.3813, -6.2563], device='cuda:0')


embeddings
torch.Size([6, 3072])


tensor([[-1.1953,  0.5078,  0.2773,  ..., -1.6484,  1.0625, -5.2812],
        [ 0.5156, -1.3750,  0.4180,  ...,  0.3477,  0.5781, -3.8750],
        [ 1.5938, -2.2969,  1.6953,  ...,  1.1328,  1.2344, -0.3027],
        [-0.7188,  0.2754,  3.2812,  ...,  0.1235,  1.1641, -0.3164],
        [-1.3672,  1.5156,  0.9297,  ..., -0.8477,  0.0498, -0.8242],
        [ 1.6875,  1.0000,  0.3555,  ...,  0.8555, -1.0547, -0.2520]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([6])


tensor([ True, False, False, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-1.1953,  0.5078,  0.2773,  ..., -1.6484,  1.0625, -5.2812]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([6, 1])


tensor([[1.0000],
        [1.6250],
        [1.8750],
        [1.6875],
        [1.5234],
        [1.7422]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([6])


tensor([1.0000, 1.6250, 1.8750, 1.6875, 1.5234, 1.7422], device='cuda:0',
       dtype=torch.bfloat16)


min_remaining_distances
torch.Size([6])


tensor([0.0000, 1.6250, 1.8750, 1.6875, 1.5234, 1.7422], device='cuda:0',
       dtype=torch.bfloat16)


selected (end of loop)
torch.Size([6])


tensor([ True, False,  True, False, False, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-1.1953,  0.5078,  0.2773,  ..., -1.6484,  1.0625, -5.2812],
        [ 1.5938, -2.2969,  1.6953,  ...,  1.1328,  1.2344, -0.3027]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([6, 2])


tensor([[1.0000, 1.8750],
        [1.6250, 1.8906],
        [1.8750, 1.0000],
        [1.6875, 1.7188],
        [1.5234, 1.8828],
        [1.7422, 1.9375]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([6])


tensor([1.0000, 1.6250, 1.0000, 1.6875, 1.5234, 1.7422], device='cuda:0',
       dtype=torch.bfloat16)


min_remaining_distances
torch.Size([6])


tensor([0.0000, 1.6250, 0.0000, 1.6875, 1.5234, 1.7422], device='cuda:0',
       dtype=torch.bfloat16)


selected (end of loop)
torch.Size([6])


tensor([ True, False,  True, False, False,  True], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[-1.1953,  0.5078,  0.2773,  ..., -1.6484,  1.0625, -5.2812],
        [ 1.5938, -2.2969,  1.6953,  ...,  1.1328,  1.2344, -0.3027],
        [ 1.6875,  1.0000,  0.3555,  ...,  0.8555, -1.0547, -0.2520]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([6, 3])


tensor([[1.0000, 1.8750, 1.7422],
        [1.6250, 1.8906, 1.8125],
        [1.8750, 1.0000, 1.9375],
        [1.6875, 1.7188, 1.7969],
        [1.5234, 1.8828, 1.6562],
        [1.7422, 1.9375, 0.9922]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([6])


tensor([0, 0, 1, 0, 0, 2], device='cuda:0')


new_candidates
torch.Size([3, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186,   338]],
       device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 2, 5], device='cuda:0')


new_candidate_aunts


[[0, 1, 3, 4], [2], [5]]


new_candidate_logprobs
torch.Size([3])


tensor([-11.4800,  -3.4075,  -6.2563], device='cuda:0')


candidates
torch.Size([6, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,   408],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   373, 11563, 29892,  2729],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186,   338]],
       device='cuda:0')


candidate_logprobs
torch.Size([6])


tensor([-1.1581, -1.5331, -3.4075, -4.4075, -4.3813, -6.2563], device='cuda:0')


embeddings_np
(6, 3072)


array([[-1.1953125 ,  0.5078125 ,  0.27734375, ..., -1.6484375 ,
         1.0625    , -5.28125   ],
       [ 0.515625  , -1.375     ,  0.41796875, ...,  0.34765625,
         0.578125  , -3.875     ],
       [ 1.59375   , -2.296875  ,  1.6953125 , ...,  1.1328125 ,
         1.234375  , -0.30273438],
       [-0.71875   ,  0.27539062,  3.28125   , ...,  0.12353516,
         1.1640625 , -0.31640625],
       [-1.3671875 ,  1.515625  ,  0.9296875 , ..., -0.84765625,
         0.04980469, -0.82421875],
       [ 1.6875    ,  1.        ,  0.35546875, ...,  0.85546875,
        -1.0546875 , -0.25195312]], dtype=float32)


k_mean_space
(6, 3)


array([[ 98.4211  ,  55.8501  , 110.24304 ],
       [100.15504 ,  60.411293, 110.894684],
       [100.08185 ,  94.79501 ,   0.      ],
       [  0.      ,  79.6112  , 100.08185 ],
       [ 90.86594 ,  52.89082 , 110.886734],
       [105.63401 ,  64.922554, 113.923164]], dtype=float32)


k_mean_clusters
(6,)


array([1, 1, 2, 0, 1, 1], dtype=int32)


k_mean_logprob_mass
(3,)


array([-4.40750886, -0.60815168, -3.40750889])


closest
(3,)


array([3, 4, 2])


new_candidates
torch.Size([3, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858]],
       device='cuda:0')


new_candidate_parents


[3, 4, 2]


new_candidate_aunts


[[3], [0, 1, 4, 5], [2]]


new_candidate_logprobs
torch.Size([3])


tensor([-4.4075, -0.6082, -3.4075], device='cuda:0', dtype=torch.float64)

K MEANS PRIOR 7: (8560.571607967) 3 candidates, 0.020589703000950976 inference time, 0.020593103999999585 total time
event: message
id: 7-k"
data: {"id": "7-k", "level_type": "gather", "duration": 0.020589703000950976, "nodes": [{"content": "\u2581height", "parent": 3, "aunts": [3], "prob": 0.012185496278107165}, {"content": ",", "parent": 4, "aunts": [0, 1, 4, 5], "prob": 0.5443560830317438}, {"content": "\u2581elev", "parent": 2, "aunts": [2], "prob": 0.03312361240386963}]}




K MEANS AFTER 7: (8560.572321589) 3 candidates, 0.020589703000950976 inference time, 0.021306536000338383 total time

candidates
torch.Size([3, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858]],
       device='cuda:0')


candidate_logprobs
torch.Size([3])


tensor([-4.4075, -0.6082, -3.4075], device='cuda:0', dtype=torch.float64)


last_tok_logits
torch.Size([3, 32064])


tensor([[  0.0183,   1.5000,  -7.1250,  ...,   0.0000,   0.0000,   0.0000],
        [ -3.3438,  -2.0469, -10.8750,  ...,   0.0000,   0.0000,   0.0000],
        [ -0.8047,  -1.1484,  -1.5938,  ...,   0.0000,   0.0000,   0.0000]],
       device='cuda:0')


sorted_logits
torch.Size([3, 32064])

sorted_indices
torch.Size([3, 32064])

sorted_probs
torch.Size([3, 32064])


tensor([[9.9433e-01, 2.1751e-03, 1.3193e-03,  ..., 6.3509e-21, 6.3509e-21,
         3.3994e-21],
        [7.1745e-01, 1.2467e-01, 6.6733e-02,  ..., 2.3272e-20, 1.4115e-20,
         7.0274e-22],
        [9.9999e-01, 6.1442e-06, 1.7603e-06,  ..., 1.7940e-23, 1.3972e-23,
         7.4785e-24]], device='cuda:0')

tensor([1.0000, 1.0000, 1.0000], device='cuda:0')


cum_probs
torch.Size([3, 32064])


tensor([[0.9943, 0.9965, 0.9978,  ..., 1.0000, 1.0000, 1.0000],
        [0.7174, 0.8421, 0.9089,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')


keep_indices
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


keep_indices after top_k
torch.Size([3, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]], device='cuda:0')


new_candidate_parents
torch.Size([4])


tensor([0, 1, 1, 2], device='cuda:0')


carryover_candidates
torch.Size([4, 18])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858]],
       device='cuda:0')


carryover_candidate_logprobs
torch.Size([4])


tensor([-4.4075, -0.6082, -0.6082, -3.4075], device='cuda:0',
       dtype=torch.float64)


new_candidate_toks
torch.Size([4, 1])


tensor([[ 2038],
        [  746],
        [17005],
        [  362]], device='cuda:0')


new_candidate_tok_logprobs
torch.Size([4])


tensor([-5.6907e-03, -3.3205e-01, -2.0821e+00, -8.7023e-06], device='cuda:0')


new_candidates
torch.Size([4, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892,   746],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892, 17005],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362]],
       device='cuda:0')


new_candidate_logprobs
torch.Size([4])


tensor([-4.4132, -0.9402, -2.6902, -3.4075], device='cuda:0',
       dtype=torch.float64)

TOP P PRIOR 7: (8560.599923596) 4 candidates, 0.027590836998570012 inference time, 0.02759406699988176 total time
event: message
id: 7-p
data: {"id": "7-p", "level_type": "sample", "duration": 0.027590836998570012, "nodes": [{"content": "\u2581above", "parent": 0, "prob": 0.012116349181550013}, {"content": "\u2581when", "parent": 1, "prob": 0.3905473049217478}, {"content": "\u2581measured", "parent": 1, "prob": 0.0678669371898257}, {"content": "ation", "parent": 2, "prob": 0.03312332415297053}]}




TOP P AFTER 7: (8560.600348328) 4 candidates, 0.027590836998570012 inference time, 0.02801866799927666 total time

num_batches


1

infer start: GPU memory used: 14997 MB.

batch_candidates
torch.Size([4, 19])

batch_candidate_logprobs
torch.Size([4])

batch_logits
torch.Size([4, 19, 32064])

hidden_states[-1]
torch.Size([4, 19, 3072])
infer - after batch run: GPU memory used: 14985 MB.
FARTHEST NEIGHBORS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

candidates
torch.Size([4, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892,   746],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892, 17005],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362]],
       device='cuda:0')


candidate_logprobs
torch.Size([4])


tensor([-4.4132, -0.9402, -2.6902, -3.4075], device='cuda:0',
       dtype=torch.float64)


embeddings
torch.Size([4, 3072])


tensor([[-2.4375, -0.9336,  1.3281,  ..., -0.2812, -0.4805, -0.6836],
        [-0.3789,  0.7695, -0.2373,  ..., -1.6641,  0.4238, -1.8281],
        [-0.8438, -0.2969,  0.3086,  ...,  1.1094, -0.2217, -2.2500],
        [-1.2266, -1.7266,  1.8281,  ..., -0.3574,  0.8828,  0.3281]],
       device='cuda:0', dtype=torch.bfloat16)


selected
torch.Size([4])


tensor([False,  True, False, False], device='cuda:0')


selected_embeddings
torch.Size([1, 3072])


tensor([[-0.3789,  0.7695, -0.2373,  ..., -1.6641,  0.4238, -1.8281]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 1])


tensor([[1.7734],
        [1.0000],
        [1.5938],
        [1.6484]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([1.7734, 1.0000, 1.5938, 1.6484], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([1.7734, 0.0000, 1.5938, 1.6484], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True,  True, False, False], device='cuda:0')


selected_embeddings
torch.Size([2, 3072])


tensor([[-2.4375, -0.9336,  1.3281,  ..., -0.2812, -0.4805, -0.6836],
        [-0.3789,  0.7695, -0.2373,  ..., -1.6641,  0.4238, -1.8281]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 2])


tensor([[0.9922, 1.7734],
        [1.7734, 1.0000],
        [1.7656, 1.5938],
        [1.7188, 1.6484]], device='cuda:0', dtype=torch.bfloat16)


min_distances
torch.Size([4])


tensor([0.9922, 1.0000, 1.5938, 1.6484], device='cuda:0', dtype=torch.bfloat16)


min_remaining_distances
torch.Size([4])


tensor([0.0000, 0.0000, 1.5938, 1.6484], device='cuda:0', dtype=torch.bfloat16)


selected (end of loop)
torch.Size([4])


tensor([ True,  True, False,  True], device='cuda:0')


selected_embeddings
torch.Size([3, 3072])


tensor([[-2.4375, -0.9336,  1.3281,  ..., -0.2812, -0.4805, -0.6836],
        [-0.3789,  0.7695, -0.2373,  ..., -1.6641,  0.4238, -1.8281],
        [-1.2266, -1.7266,  1.8281,  ..., -0.3574,  0.8828,  0.3281]],
       device='cuda:0', dtype=torch.bfloat16)


distances
torch.Size([4, 3])


tensor([[0.9922, 1.7734, 1.7188],
        [1.7734, 1.0000, 1.6484],
        [1.7656, 1.5938, 1.6719],
        [1.7188, 1.6484, 1.0000]], device='cuda:0', dtype=torch.bfloat16)


closest_per_candidate
torch.Size([4])


tensor([0, 1, 1, 2], device='cuda:0')


new_candidates
torch.Size([3, 19])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310,  3171,  2038],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,   278,  3186, 29892,   746],
        [    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001,   450,  9939, 14378,   297,  4958,   310, 11858,   362]],
       device='cuda:0')


new_candidate_parents
torch.Size([3])


tensor([0, 1, 3], device='cuda:0')


new_candidate_aunts


[[0], [1, 2], [3]]

RuntimeError: index_add_(): self (Float) and source (Double) must have the same scalar type

In [None]:
it = InferenceTensor()

for x in it.candidates_generator(top_p=0.9, top_p_decay=0.99, top_k=2, max_beams=3, max_new_tokens=6, prompt='What is the highest mountain?'):
    print(x)
    print()
    print('====================================')
    print()

In [8]:
a = torch.randn(3, 5)
display(a)
a.min(dim=1)

tensor([[ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845],
        [-1.3986,  0.4033,  0.8380, -0.7193, -0.4033],
        [-0.5966,  0.1820, -0.8567,  1.1006, -1.0712]])

torch.return_types.min(
values=tensor([-2.1788, -1.3986, -1.0712]),
indices=tensor([2, 0, 4]))

In [None]:
a.view((1, a.shape[1], a.shape[0]))

In [None]:
a.view((a.shape[0], a.shape[1], 1))

In [10]:
F.cosine_similarity(a.view((1, a.shape[1], a.shape[0])), a.view((a.shape[0], a.shape[1], 1)), dim=1)

tensor([[ 0.3547, -0.6298, -0.0134],
        [-0.2626,  0.1878,  0.4022],
        [-0.1902, -0.7346,  0.5564]])

In [9]:
F.cosine_similarity(a, a)

tensor([1.0000, 1.0000, 1.0000])

In [11]:
from torchmetrics.functional import pairwise_cosine_similarity

In [17]:
pairwise_cosine_similarity(a, a[[0, 1, 1]])

tensor([[ 1.0000, -0.7373, -0.7373],
        [-0.7373,  1.0000,  1.0000],
        [ 0.4869, -0.0486, -0.0486]])

In [13]:
pairwise_cosine_similarity(a)

tensor([[ 0.0000, -0.7373,  0.4869],
        [-0.7373,  0.0000, -0.0486],
        [ 0.4869, -0.0486,  0.0000]])

In [27]:
b = torch.randn(3, 5, dtype=torch.float16).to('cuda')
b

tensor([[ 0.0571,  0.2240,  0.5518, -0.5786,  0.0177],
        [ 0.1318,  1.0195, -0.4468,  0.4519, -0.9761],
        [ 0.7114, -0.7583, -0.6436, -0.6460, -0.1591]], device='cuda:0',
       dtype=torch.float16)

In [28]:
torch.add(2, pairwise_cosine_similarity(b, b), alpha=-1)

tensor([[1.0000, 2.2227, 2.0977],
        [2.2227, 1.0000, 2.2441],
        [2.0977, 2.2441, 1.0010]], device='cuda:0', dtype=torch.float16)

In [14]:
1 - pairwise_cosine_similarity(a, a)

tensor([[0.0000, 1.7373, 0.5131],
        [1.7373, 0.0000, 1.0486],
        [0.5131, 1.0486, 0.0000]])