fixed error in test_expand
annh3 committed Aug 16, 2024
1 parent d958912 commit e981b8b
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 6 additions & 2 deletions ai_planning_searching/mcts_utils.py
@@ -178,9 +178,10 @@ def expand(root:Node, tokenizer, model, k, max_beam_len):
     scores = list(torch.chunk(beam_output.sequences_scores,chunks=k,dim=0))


-    beam_list = [(a,b,c) for a,b,c in zip(scores, next_tokens, str_repr)]
+    beam_list = [(a,[b],[c]) for a,b,c in zip(scores, next_tokens, str_repr)]

     for _ in range(max_beam_len-1):
+        print("Length of beam_list: ", len(beam_list))
         new_list = []
         for current_path in beam_list:
             """
@@ -192,7 +193,10 @@ def expand(root:Node, tokenizer, model, k, max_beam_len):
             https://huggingface.co/docs/transformers/v4.43.3/en/internal/generation_utils#transformers.generation.GenerateBeamDecoderOnlyOutput
             """
             # pdb.set_trace()
+            current_tokens = torch.cat(current_path[1])
+            current_tokens = current_tokens.unsqueeze(0)
+            # pdb.set_trace()
             beam_output = model.generate(
                 current_tokens,
                 max_new_tokens=1,
@@ -208,7 +212,7 @@ def expand(root:Node, tokenizer, model, k, max_beam_len):
             scores = list(torch.chunk(beam_output.sequences_scores,chunks=k,dim=0))
             for score,string,next_token in zip(scores,str_repr,next_tokens):
                 # add to beams
-                cur = (score,current_path[2]+[string],current_path[1]+[next_token])
+                cur = (score,current_path[1]+[next_token],current_path[2]+[string])
                 new_list.append(cur)

         beam_list = new_list
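For readers skimming this diff: a beam entry is the tuple (score, list_of_token_tensors, list_of_strings), so the corrected cur keeps tokens in slot 1 and strings in slot 2, matching the updated list comprehension in the first hunk. The sketch below is not the repository's expand(); it swaps model.generate for a stub that emits k dummy continuations per path, purely to illustrate the tuple ordering and the torch.cat/unsqueeze step added in this commit.

```python
# Minimal sketch, not the repo's expand(): the model call is stubbed out so the
# beam-tuple bookkeeping (score, token_tensors, strings) can be run standalone.
import torch

def expand_sketch(prompt_token, k, max_beam_len):
    # Mirrors the fixed construction: beam_list = [(a, [b], [c]) for ...]
    beam_list = [(0.0, [torch.tensor([prompt_token, t])], [f"tok{t}"]) for t in range(k)]

    for _ in range(max_beam_len - 1):
        new_list = []
        for current_path in beam_list:
            current_tokens = torch.cat(current_path[1])    # flatten the path's tokens
            current_tokens = current_tokens.unsqueeze(0)   # add a batch dim, as in this commit
            for t in range(k):  # stand-in for the k continuations model.generate returns per path
                score, next_token, string = -float(t), torch.tensor([t]), f"tok{t}"
                # Fixed ordering: tokens stay in slot 1, strings in slot 2.
                cur = (score, current_path[1] + [next_token], current_path[2] + [string])
                new_list.append(cur)
        beam_list = new_list
    return beam_list

print(len(expand_sketch(prompt_token=0, k=2, max_beam_len=3)))  # 8 == 2 ** 3
```

Since every path gains k continuations on each of the max_beam_len - 1 iterations, the final list has beam_width ** max_beam_len entries, which is exactly what the test change below now asserts.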
4 changes: 3 additions & 1 deletion ai_planning_searching/mcts_utils_tests.py
@@ -1,5 +1,7 @@
 """
 python -m unittest mcts_utils_tests.py
+python -m unittest mcts_utils_tests.testMCTSUtils.test_expand
 """
 import pdb
 import numpy as np
@@ -165,7 +167,7 @@ def test_expand(self):
         node_0.P_s_a = node_0.P_s_a / beam_width

         beam_list = expand(node_0, self.tokenizer, self.model, beam_width, max_beam_len)
-        self.assertEqual(len(beam_list),max_beam_len**beam_width)
+        self.assertEqual(len(beam_list),beam_width**max_beam_len)

     def test_evaluate_full_paths(self):
         seq_len = 2
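A quick check on the corrected count (the numbers here are illustrative, not taken from the test file): expand() starts with beam_width paths and each of the max_beam_len - 1 loop iterations multiplies the list by beam_width, so the length is beam_width ** max_beam_len, not max_beam_len ** beam_width.

```python
# Illustrative values only: beam_width = 2, max_beam_len = 3.
beam_width, max_beam_len = 2, 3
n_paths = beam_width * beam_width ** (max_beam_len - 1)  # growth of beam_list in expand()
assert n_paths == beam_width ** max_beam_len == 8        # corrected assertion holds
assert n_paths != max_beam_len ** beam_width             # old expression would give 9
```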
