Greedy decoding and beam search are two common strategies for generating sequences from autoregressive models.
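Greedy decoding selects the most probable token at each step. Beam search instead keeps the `k` highest-scoring partial sequences (the *beam*) at every step. This lets it recover from an early choice that looks best locally but leads to a lower-probability sequence overall.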

In [1]:
import math

In [2]:
# Toy vocabulary: "<SOS>"/"<EOS>" start/end markers around six ordinary words.
vocab = "<SOS> I you love hate cats dogs <EOS>".split()

In [3]:
# A fake model: a hand-written bigram table standing in for a real network.
def toy_model(prev_token):
    """Return the next-token distribution given the previous token.

    Args:
        prev_token: the most recently generated token.

    Returns:
        A freshly built dict mapping each candidate next token to its
        probability, in a fixed order (decoders break ties by that order).
        Any token without an entry in the table, including "<EOS>" itself,
        maps to {"<EOS>": 1.0}.
    """
    stop = {"<EOS>": 1.0}
    bigram_table = {
        "<SOS>": {"I": 0.6, "you": 0.4},
        "I": {"love": 0.51, "hate": 0.49},
        "you": {"love": 0.9, "hate": 0.1},
        "love": {"cats": 0.5, "dogs": 0.5},
        "hate": {"cats": 0.9, "dogs": 0.1},
        "cats": stop,
        "dogs": stop,
    }
    if prev_token in bigram_table:
        return bigram_table[prev_token]
    return stop

In [4]:
# Greedy decoding
def greedy_decode(model=None, max_len=50):
    """Decode one sequence by always taking the single most probable next token.

    Args:
        model: callable mapping the previous token to a dict
            ``{next_token: probability}``. Defaults to ``toy_model``.
        max_len: maximum number of tokens to generate after "<SOS>". This
            guards against models that never emit "<EOS>", which would
            otherwise loop forever. A value <= 0 returns just ["<SOS>"].

    Returns:
        (sequence, score): the token list starting with "<SOS>" (ending in
        "<EOS>" unless ``max_len`` was reached first) and its total
        log-probability.

    Raises:
        ValueError: if the model returns an empty candidate dict.
    """
    if model is None:
        model = toy_model

    sequence = ["<SOS>"]
    score = 0.0

    for _ in range(max_len):
        next_probs = model(sequence[-1])
        if not next_probs:
            raise ValueError(f"model returned no candidates after {sequence[-1]!r}")

        # max() keeps the first maximum, so ties go to the token the model lists first.
        next_token = max(next_probs, key=next_probs.get)
        prob = next_probs[next_token]
        # math.log(0) raises; a zero-probability pick makes the sequence impossible (-inf).
        score += math.log(prob) if prob > 0 else -math.inf

        sequence.append(next_token)
        if next_token == "<EOS>":
            break

    return sequence, score

In [5]:
# Run greedy decoding on the toy model and report the sequence and its log-probability.
greedy_seq, greedy_score = greedy_decode()
print(f"Greedy output: {greedy_seq}")
print(f"Greedy log-prob: {greedy_score}")

Greedy output: ['<SOS>', 'I', 'love', 'cats', '<EOS>']
Greedy log-prob: -1.8773173575897015


In [7]:
# Beam search
def beam_search_decode(beam_size=2, model=None, max_len=50):
    """Decode with beam search, keeping the ``beam_size`` best partial sequences.

    At every step:
    - Each unfinished beam is extended by every candidate token the model
      proposes.
    - Finished beams (ending in "<EOS>") are carried over unchanged and
      compete with those extensions.
    - Only the top ``beam_size`` candidates by total log-probability survive.

    Args:
        beam_size: number of hypotheses kept per step; must be >= 1.
        model: callable mapping the previous token to ``{next_token: prob}``.
            Defaults to ``toy_model``.
        max_len: maximum number of expansion steps, so decoding terminates
            even if a model never emits "<EOS>".

    Returns:
        A list of up to ``beam_size`` ``(sequence, log_prob)`` tuples, sorted
        best first. Equal scores keep the order in which they were generated.

    Raises:
        ValueError: if ``beam_size`` < 1 (an empty beam would otherwise
            silently yield ``[]``).
    """
    if beam_size < 1:
        raise ValueError(f"beam_size must be >= 1, got {beam_size}")
    if model is None:
        model = toy_model

    beams = [(["<SOS>"], 0.0)]

    for _ in range(max_len):
        candidates = []

        for seq, score in beams:
            prev = seq[-1]
            if prev == "<EOS>":
                # Finished hypothesis: keep it so it can compete with longer ones.
                candidates.append((seq, score))
                continue

            for token, prob in model(prev).items():
                # Zero-probability continuations can never be chosen, and math.log(0) raises.
                if prob <= 0:
                    continue
                candidates.append((seq + [token], score + math.log(prob)))

        # Keep top-k beams; sorted() is stable even with reverse=True, so ties keep generation order.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]

        # Stop if all beams ended
        if all(seq[-1] == "<EOS>" for seq, _ in beams):
            break

    return beams

In [8]:
beams = beam_search_decode(beam_size=2)
# One line per surviving hypothesis, in the order returned by the decoder.
for seq, score in beams:
    print(f"Beam output: {seq} log-prob: {score}")

Beam output: ['<SOS>', 'you', 'love', 'cats', '<EOS>'] log-prob: -1.7147984280919264
Beam output: ['<SOS>', 'you', 'love', 'dogs', '<EOS>'] log-prob: -1.7147984280919264


In [9]:
# Compare results: greedy commits to "I" (p=0.6) at the first step, while beam
# search keeps "you" alive and ends up with a higher total log-probability.
print("Greedy:", greedy_seq)
print("Best beam:", beams[0][0])

Greedy: ['<SOS>', 'I', 'love', 'cats', '<EOS>']
Best beam: ['<SOS>', 'you', 'love', 'cats', '<EOS>']


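Greedy decoding is fast but short-sighted: it makes locally optimal decisions that can lead to suboptimal sequences. In this example it commits to "I" (p = 0.6) over "you" (p = 0.4). Yet "you love cats" has a higher total log-probability (≈ −1.71) than the greedy result "I love cats" (≈ −1.88).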

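Beam search trades extra computation for quality by keeping multiple hypotheses, often producing higher-probability (and frequently more fluent) outputs. It is still a heuristic, though: with a finite beam it is not guaranteed to find the single most probable sequence.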