In [2]:
import torch
import torch.nn.functional as F
from modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_name = "meta-llama/Llama-2-7b-hf"  # Adjust based on your access
tokenizer = LlamaTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.57it/s]


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [97]:
from typing import List
class SearchNode:
    def __init__(self, root, idx, token_id, token_score):
        self.root: 'SearchTree' = root
        self.idx: int = idx
        self.token_id: Tensor = token_id
        self.token_score: float = token_score
        self.parent: Optional['SearchNode'] = None
        self.children: List['SearchNode'] = []
        self.acc_score: float = token_score


    def add_children(self, child):
        self.children.append(child)
        child.parent = self
        child.acc_score = self.acc_score + child.token_score
        self.root.node_count += 1

    def delete_child(self, child):
        self.children.remove(child)
        self.root.node_count -= 1


class SearchTree:
    def __init__(self, beam_width=3):
        self.node_count: int = 0
        self.model = model
        self.device = model.device
        self.root: List[SearchNode] = []
        self.beam_width: int = beam_width

def generate_causal_mask(searchTree: SearchTree,input_len: int,nodes: List[SearchNode]) -> torch.Tensor:
    branch_count = len(nodes)
    mask = torch.full((1, 1, branch_count, searchTree.node_count + input_len), -65504, device=device, dtype=torch.float16)
    mask[0, 0,:,:input_len] = 0
    tmp = nodes.copy()
    #print("========")
    while True:
        end = False
        for i in range(branch_count):
            #print(i, tmp[i].idx)
            mask[0, 0, i, tmp[i].idx + input_len] = 0
            if tmp[i].parent is not None:
                tmp[i] = tmp[i].parent
            else:
                end = True
        if end:
            return mask


In [98]:
import torch
import torch.nn.functional as F
from collections import deque

def prune_completed(searchTree: SearchTree, leaf: SearchNode):
    prune_node = leaf
    prune_nodes = []
    while True:
        prune_nodes.append(prune_node)
        if prune_node.parent == None:
            break
        prune_node.parent.delete_child(prune_node)
        if len(prune_node.parent.children) == 0:
            break
        prune_node = prune_node.parent
    
    
def prune_kv_cache(past_key_values, prune_nodes: List[int]):
    pass


@torch.no_grad()
def generate_next_tokens(model, input_ids, beam_width = 3, max_tokens=300):
    device = model.device
    past_key_values = None
    input_len = input_ids.shape[1]
    print("input length: ", input_len)

    #generate the first 3 tokens
    outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values
    token_scores = F.log_softmax(outputs.logits, dim=-1)

    token_scores, tokens = torch.topk(token_scores, beam_width, dim=-1, largest=True, sorted=True)
    searchTree = SearchTree(beam_width = beam_width)
    newest_branch = []
    idx = 0

    #define eos token
    eos_token_id = model.config.eos_token_id

    
    for i in range(beam_width):
        searchNode = SearchNode(searchTree, idx, tokens[0][-1][i],token_scores[0][-1][i])
        idx += 1
        newest_branch.append(searchNode)
        searchTree.root.append(searchNode)
        searchTree.node_count += 1
    
    completed_branches = []
    alive_beams = beam_width

    for i in range(input_len, max_tokens):
        #construct position_ids
        print("alive_beams: ", alive_beams)

        position_ids = torch.tensor([[i for _ in range(alive_beams)]], device=device)
        
        #construct attention_mask
        attention_mask = generate_causal_mask(searchTree,input_len , newest_branch)
        #print(attention_mask)
        #print(attention_mask[0][0])

        #construct input_ids
        input_ids = torch.tensor([[node.token_id for node in newest_branch]], device=device)
        
        #generate candidate tokens
        #print("kv: ", past_key_values[0][0].shape)
        #print("atm: ", attention_mask.shape)
        #print("node_count: ", searchTree.node_count)
        outputs = model(input_ids, past_key_values=past_key_values, position_ids=position_ids, attention_mask=attention_mask, use_cache=True)
        past_key_values = outputs.past_key_values
        #calculate token scores
        token_scores = F.log_softmax(outputs.logits, dim=-1)
        token_scores, tokens = torch.topk(token_scores, alive_beams, dim=-1, largest=True, sorted=True)
        #print(token_scores.shape)
        #list n candidates from n branches
        candidates = torch.empty(0, device=model.device)
        candidate_scores = []
        for i in range(alive_beams):
            branch_score = newest_branch[i].acc_score
            for j in range(alive_beams):
                candidates = torch.cat((candidates, branch_score+token_scores[0][i][j].unsqueeze(0)))
                candidate_scores.append(token_scores[0][i][j])
        #print(candidates)
        wa, token_idxs = torch.topk(candidates, alive_beams, dim=-1, largest=True, sorted=True)
        #print("picks ", token_idxs, wa)

        #update newest_branch and searchTree

        tmp_newest_branch = []
        
        completed_nodes = []
        for i in range(alive_beams):
            token_idx = token_idxs[i]
            token_id = tokens[0][int(token_idx/alive_beams)][int(token_idx%alive_beams)]
            searchNode = SearchNode(searchTree, idx, token_id=token_id, token_score = candidate_scores[i])
            
            #print(int(token_idx/beam_width)," add child")
            newest_branch[int(token_idx/alive_beams)].add_children(searchNode)
            if token_id == eos_token_id:
                print(searchNode.idx, "ended")
                completed_nodes.append(searchNode)
                completed_branches.append(searchNode)
                searchTree.node_count -= 1
                #tmp_newest_branch.append(searchNode)
            else:
                idx += 1
                tmp_newest_branch.append(searchNode)
                
        alive_beams -= len(completed_nodes)
        newest_branch = tmp_newest_branch
    
    #find the best branch
    max_score=0
    max_idx = 0
    for i in range(alive_beams):
        if newest_branch[i].acc_score > max_score:
            max_score = newest_branch[i].acc_score
            max_idx = i

    #construct the output
    outputs = []
    newest_branch = newest_branch + completed_branches
    for i in range(beam_width):
        output = torch.empty(0, device=model.device)
        branch_parent = newest_branch[i]
        while branch_parent is not None:
            output = torch.cat((output, branch_parent.token_id.unsqueeze(0)))
            branch_parent = branch_parent.parent
        output=output.flip(dims=[0])
        outputs.append(output)
        #outputs = torch.cat((outputs, output.unsqueeze(0)))
    return outputs




In [100]:
input_ids = tokenizer.encode("Such a nice day.", return_tensors="pt").to(model.device)
beam_width = 4
output = generate_next_tokens(model, input_ids, beam_width=beam_width, max_tokens=400)
for i in range(beam_width):
    print(":", tokenizer.decode(output[i]))

input length:  6
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
alive_beams:  4
57 ended
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
alive_beams:  3
139 ended
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
alive_beams:  2
aliv

In [20]:
import torch
import torch.nn.functional as F
from collections import deque




@torch.no_grad()
def generate_next_tokens(model, input_ids, beam_width = 5, max_tokens=100):
    device = model.device
    past_key_values = None
    input_len = input_ids.shape[1]
    print("input length: ", input_len)


    outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values
    token_scores = F.log_softmax(outputs.logits, dim=-1)

    token_scores, tokens = torch.topk(token_scores, beam_width, dim=-1, largest=True, sorted=True)
    searchTree = SearchTree(beam_width = beam_width)
    newest_branch = []
    round = 0
    eos_token_id = model.config.eos_token_id
    for i in range(beam_width):
        searchNode = SearchNode(round, i,tokens[0][-1][i],token_scores[0][-1][i])
        newest_branch.append(searchNode)
        searchTree.root.append(searchNode)

    completed_branches = []

    for i in range(input_len, max_tokens):
        #construct position_ids
        position_ids = torch.tensor([[i for _ in range(beam_width)]], device=device)
        
        #construct attention_mask
        attention_mask_length = input_len + (i-input_len+1) * beam_width
        attention_mask = torch.full((1, 1, beam_width, attention_mask_length), -65504, device=device, dtype=torch.float16)
        attention_mask[0,0,:,:input_len] = 0
        for idx, node in enumerate(newest_branch):
            node_parent = node
            while node_parent is not None:
                attention_mask[0, 0, idx, input_len + node_parent.round * beam_width + node_parent.sib_idx] = 0
                node_parent = node_parent.parent

        #print("mask", attention_mask)
        #print("pos_id", position_ids)
        #construct input_ids
        input_ids = torch.tensor([[node.token_id for node in newest_branch]], device=device)
        
        #generate candidate tokens
        outputs = model(input_ids, past_key_values=past_key_values, position_ids=position_ids, attention_mask=attention_mask, use_cache=True)
        past_key_values = outputs.past_key_values
        token_scores = F.log_softmax(outputs.logits, dim=-1)
        token_scores, tokens = torch.topk(token_scores, beam_width, dim=-1, largest=True, sorted=True)
        candidates = torch.empty(0, device=model.device)
        for i in range(beam_width):
            branch_score = newest_branch[i].acc_score
            for j in range(beam_width):
                candidates = torch.cat((candidates, branch_score+token_scores[0][i][j].unsqueeze(0)))
        #print(candidates)
        token_scores, token_idxs =torch.topk(candidates, beam_width, dim=-1, largest=True, sorted=True)
        #print("selected", token_idxs)

        #update newest_branch and searchTree
        round += 1
        tmp_newest_branch = []
        #print("new tokens", tokens)
        for i in range(beam_width):
            token_idx = token_idxs[i]
            token_id = tokens[0][int(token_idx/beam_width)][token_idx%beam_width]
            searchNode = SearchNode(round=round,sib_idx=i, token_id=token_id, token_score = token_scores[i])
            print(int(token_idx/beam_width)," add child")
            newest_branch[int(token_idx/beam_width)].add_children(searchNode)
            tmp_newest_branch.append(searchNode)
        newest_branch = tmp_newest_branch
    
    #find the best branch
    max_score=0
    max_idx = 0
    for i in range(beam_width):
        if newest_branch[i].acc_score > max_score:
            max_score = newest_branch[i].acc_score
            max_idx = i

    #construct the output
    outputs = torch.empty(0, device=model.device)
    for i in range(beam_width):
        output = torch.empty(0, device=model.device)
        branch_parent = newest_branch[i]
        while branch_parent is not None:
            output = torch.cat((output, branch_parent.token_id.unsqueeze(0)))
            branch_parent = branch_parent.parent
        output=output.flip(dims=[0])
        outputs = torch.cat((outputs, output.unsqueeze(0)))
    return outputs
    
    #construct the output
    output = torch.empty(0, device=model.device)
    best_branch_parent = newest_branch[max_idx]
    while best_branch_parent is not None:
        output = torch.cat((output, best_branch_parent.token_id.unsqueeze(0)))
        best_branch_parent = best_branch_parent.parent
    output=output.flip(dims=[0])
    return output

