<a href="https://colab.research.google.com/github/AndreSlavescu/Token-Sampling/blob/main/token_decoding_techniques.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Token Decoding Methods

1. Gumbel Max-Trick
2. Top-K Decoding
2. Greedy Decoding


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import urllib
from torch.utils.data import Dataset, DataLoader

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = urllib.request.urlopen(url)
shakespeare_text = response.read().decode('utf-8')[:100000]

chars = sorted(list(set(shakespeare_text)))
vocab_size = len(chars)
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

data = [char_to_idx[ch] for ch in shakespeare_text]

seq_length = 64

class DatasetLoader(Dataset):
    def __init__(self, data, seq_length):
        self.data = data
        self.seq_length = seq_length

    def __len__(self):
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_length]
        y = self.data[idx + 1:idx + self.seq_length + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

batch_size = seq_length
dataset = DatasetLoader(data, seq_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [3]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMLanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden):
        x = self.embedding(x)
        out, hidden = self.lstm(x, hidden)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        return (weight.new(self.lstm.num_layers, batch_size, self.lstm.hidden_size).zero_(),
                weight.new(self.lstm.num_layers, batch_size, self.lstm.hidden_size).zero_())

embedding_dim = 128
hidden_dim = 256
num_layers = 2
model = LSTMLanguageModel(vocab_size, embedding_dim, hidden_dim, num_layers)

In [5]:
import time

num_epochs = 5
learning_rate = 0.001

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    model.train()
    start_time = time.time()

    total_loss = 0
    batch_count = len(dataloader)

    for batch_idx, (x, y) in enumerate(dataloader):
        hidden = model.init_hidden(x.size(0))
        hidden = tuple([h.data for h in hidden])

        model.zero_grad()
        output, hidden = model(x, hidden)
        loss = criterion(output.view(-1, vocab_size), y.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if (batch_idx + 1) % 10 == 0:
            avg_loss = total_loss / (batch_idx + 1)
            elapsed_time = time.time() - start_time
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{batch_count}], '
                  f'Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}, Time: {elapsed_time:.2f}s')

    avg_epoch_loss = total_loss / batch_count
    print(f'End of Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_epoch_loss:.4f}, '
          f'Total Time: {time.time() - start_time:.2f}s')

torch.save(model.state_dict(), 'gumbel_sampling_shakespeare_lstm.pth')

Epoch [1/5], Batch [10/1562], Loss: 1.5700, Avg Loss: 1.5771, Time: 1.65s
Epoch [1/5], Batch [20/1562], Loss: 1.4954, Avg Loss: 1.5583, Time: 3.25s
Epoch [1/5], Batch [30/1562], Loss: 1.5441, Avg Loss: 1.5519, Time: 4.85s
Epoch [1/5], Batch [40/1562], Loss: 1.4940, Avg Loss: 1.5399, Time: 6.53s
Epoch [1/5], Batch [50/1562], Loss: 1.5513, Avg Loss: 1.5344, Time: 8.23s
Epoch [1/5], Batch [60/1562], Loss: 1.4678, Avg Loss: 1.5286, Time: 10.07s
Epoch [1/5], Batch [70/1562], Loss: 1.4679, Avg Loss: 1.5234, Time: 12.03s
Epoch [1/5], Batch [80/1562], Loss: 1.4875, Avg Loss: 1.5205, Time: 13.67s
Epoch [1/5], Batch [90/1562], Loss: 1.4634, Avg Loss: 1.5162, Time: 15.27s
Epoch [1/5], Batch [100/1562], Loss: 1.4520, Avg Loss: 1.5114, Time: 16.86s
Epoch [1/5], Batch [110/1562], Loss: 1.4274, Avg Loss: 1.5073, Time: 18.48s
Epoch [1/5], Batch [120/1562], Loss: 1.4642, Avg Loss: 1.5031, Time: 20.12s
Epoch [1/5], Batch [130/1562], Loss: 1.4498, Avg Loss: 1.4987, Time: 21.97s
Epoch [1/5], Batch [140/15

In [13]:
import torch
import torch.nn.functional as F

def gumbel_distribution_sample(model, start_seq, max_len, temperature=1.0):
    model.eval()
    chars = [char_to_idx[ch] for ch in start_seq]
    input_seq = torch.tensor(chars, dtype=torch.long).unsqueeze(0)

    hidden = model.init_hidden(1)
    avg_log_prob = 0

    for _ in range(max_len):
        output, hidden = model(input_seq, hidden)
        logits = output[:, -1, :] / temperature

        # Add Gumbel noise
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))
        noisy_logits = logits + gumbel_noise
        next_char_idx = torch.argmax(noisy_logits).item()
        chars.append(next_char_idx)
        input_seq = torch.tensor([[next_char_idx]], dtype=torch.long)

        # Calculate log probability
        log_probs = F.log_softmax(noisy_logits, dim=-1)
        avg_log_prob += log_probs[0, next_char_idx].item()

    avg_perplexity = torch.exp(torch.tensor(-avg_log_prob / max_len))

    return ''.join(idx_to_char[idx] for idx in chars), avg_perplexity.item()

def top_k_sample(model, start_seq, max_len, k=5):
    model.eval()
    chars = [char_to_idx[ch] for ch in start_seq]
    input_seq = torch.tensor(chars, dtype=torch.long).unsqueeze(0)

    hidden = model.init_hidden(1)
    avg_log_prob = 0

    for _ in range(max_len):
        output, hidden = model(input_seq, hidden)
        logits = output[:, -1, :]
        probabilities = F.softmax(logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(probabilities, k)
        next_char_idx = top_k_indices[0, torch.multinomial(top_k_probs, 1)].item()
        chars.append(next_char_idx)
        input_seq = torch.tensor([[next_char_idx]], dtype=torch.long)

        # Calculate log probability
        log_probs = F.log_softmax(logits, dim=-1)
        avg_log_prob += log_probs[0, next_char_idx].item()

    avg_perplexity = torch.exp(torch.tensor(-avg_log_prob / max_len))

    return ''.join(idx_to_char[idx] for idx in chars), avg_perplexity.item()

def greedy_sample(model, start_seq, max_len):
    model.eval()
    chars = [char_to_idx[ch] for ch in start_seq]
    input_seq = torch.tensor(chars, dtype=torch.long).unsqueeze(0)

    hidden = model.init_hidden(1)
    avg_log_prob = 0

    for _ in range(max_len):
        output, hidden = model(input_seq, hidden)
        logits = output[:, -1, :]
        next_char_idx = torch.argmax(logits).item()
        chars.append(next_char_idx)
        input_seq = torch.tensor([[next_char_idx]], dtype=torch.long)

        # Calculate log probability
        log_probs = F.log_softmax(logits, dim=-1)
        avg_log_prob += log_probs[0, next_char_idx].item()

    avg_perplexity = torch.exp(torch.tensor(-avg_log_prob / max_len))

    return ''.join(idx_to_char[idx] for idx in chars), avg_perplexity.item()

start_seq = "To be or not to be"

gumbel_distribution_text, gumbel_distribution_perplexity = gumbel_distribution_sample(model, start_seq, 100)
top_k_text, top_k_perplexity = top_k_sample(model, start_seq, 100)
greedy_text, greedy_perplexity = greedy_sample(model, start_seq, 100)

gumbel_distribution_generated_text = f"gumbel max-trick sampling:\n\n{gumbel_distribution_text} \n\nperplexity: {gumbel_distribution_perplexity}\n"
top_k_generated_text = f"top-k sampling:\n{top_k_text} \n\nperplexity: {top_k_perplexity}\n"
greedy_generated_text = f"greedy sampling:\n{greedy_text} \n\nperplexity: {greedy_perplexity}\n"

print(gumbel_distribution_generated_text)
print()
print(greedy_generated_text)
print()
print(top_k_generated_text)


gumbel max-trick sampling:

To be or not to be their loves and safeguard
Of what that want might ruin.

MENENIUS:
Noble lady!
Come, go with us; sp 

perplexity: 1.042262077331543


greedy sampling:
To be or not to be their bedfellow.
Worthy Cominius, speak.
Nay, keep your place.

First Senator:
Sit, Coriolanus; nev 

perplexity: 1.030051350593567


top-k sampling:
To be or not to be say?

First Citizen:
It was an answer: how apply you to the people's voices,
Allow their officers a 

perplexity: 1.127898931503296

