<a href="https://colab.research.google.com/github/AndreSlavescu/Token-Sampling/blob/main/token_sampling_techniques.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

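### Gumbel-Max Trick Sampling

Adding independent Gumbel(0, 1) noise to the (temperature-scaled) logits and taking the argmax draws a token from the softmax distribution over those logits. This notebook trains a small character-level LSTM and samples from it that way.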

In [22]:
import urllib
import urllib.request

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Tiny Shakespeare corpus, hosted in the char-rnn repository.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

# The context manager closes the HTTP response as soon as the body is read.
with urllib.request.urlopen(url) as response:
    # Only the first 100,000 characters of the decoded text are kept.
    shakespeare_text = response.read().decode('utf-8')[:100000]

# Character-level vocabulary: every distinct character, in sorted order,
# with lookup tables in both directions.
chars = sorted(set(shakespeare_text))
vocab_size = len(chars)
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

# The corpus encoded as a list of integer character ids.
data = [char_to_idx[ch] for ch in shakespeare_text]

# Number of characters in each training window.
seq_length = 64

class DatasetLoader(Dataset):
    """Sliding-window next-token dataset over a sequence of token ids.

    Item ``i`` is the pair ``(x, y)`` where ``x`` is the window
    ``data[i : i + seq_length]`` and ``y`` is the same window shifted one
    position to the right, so ``y[t]`` is the token that follows ``x[t]``.

    Args:
        data: Sliceable sequence of integer token ids (e.g. a list of ints).
        seq_length: Number of tokens in each window.
    """

    def __init__(self, data, seq_length):
        self.data = data
        self.seq_length = seq_length

    def __len__(self):
        """Return the number of complete (input, target) windows.

        Clamped at zero: when ``data`` has fewer than ``seq_length + 1``
        tokens there are no complete windows, and a negative value would make
        ``len()`` raise ``ValueError``.
        """
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx):
        """Return the ``idx``-th window as two ``torch.long`` tensors.

        Negative indices count from the end, as with built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising here is also what lets plain iteration over the
                dataset terminate, because iteration falls back to calling
                ``__getitem__`` with 0, 1, 2, ... until ``IndexError``.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} is out of range for dataset of length {n}")

        stop = pos + self.seq_length
        x = self.data[pos:stop]
        y = self.data[pos + 1:stop + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

# The mini-batch size is set equal to the window length.
batch_size = seq_length
dataset = DatasetLoader(data=data, seq_length=seq_length)
# Windows are served in shuffled order.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [23]:
class LSTMLanguageModel(nn.Module):
    """Token-level language model: embedding -> stacked LSTM -> linear head.

    Args:
        vocab_size: Number of distinct tokens; also the size of the output
            logit vector at every position.
        embedding_dim: Width of the token embeddings fed to the LSTM.
        hidden_dim: Width of the LSTM hidden and cell states.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: tensors are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Compute next-token logits for every position of ``x``.

        Args:
            x: Integer token ids of shape ``(batch, seq)``.
            hidden: ``(h_0, c_0)`` tuple as returned by :meth:`init_hidden`
                or by a previous call. ``None`` (the default) lets
                ``nn.LSTM`` start from zero states.

        Returns:
            ``(logits, hidden)`` where ``logits`` has shape
            ``(batch, seq, vocab_size)`` and ``hidden`` is the updated
            ``(h_n, c_n)`` tuple.
        """
        embedded = self.embedding(x)
        out, hidden = self.lstm(embedded, hidden)
        return self.fc(out), hidden

    def init_hidden(self, batch_size):
        """Return zero ``(h_0, c_0)`` states for a batch of ``batch_size``.

        Each state has shape ``(num_layers, batch_size, hidden_size)`` and is
        created with the dtype and device of the model's parameters.
        """
        # new_zeros inherits dtype/device from the parameter and returns a
        # tensor that does not require grad, replacing the legacy
        # `.data` + `Tensor.new(...).zero_()` idiom.
        weight = next(self.parameters())
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return weight.new_zeros(shape), weight.new_zeros(shape)

# Model hyperparameters: embedding width, LSTM state width, LSTM depth.
embedding_dim, hidden_dim, num_layers = 128, 256, 2

model = LSTMLanguageModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
)

In [24]:
import time

num_epochs = 5
learning_rate = 0.001
log_interval = 10  # print running statistics every `log_interval` batches

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    model.train()
    start_time = time.time()

    total_loss = 0.0
    batch_count = len(dataloader)

    for batch_idx, (x, y) in enumerate(dataloader):
        # Every batch starts from fresh zero states sized to the actual batch
        # (x.size(0) rather than batch_size, since the final batch can be
        # smaller). The states are newly created tensors with no autograd
        # history, so no detach step is needed.
        hidden = model.init_hidden(x.size(0))

        optimizer.zero_grad()
        output, hidden = model(x, hidden)
        # Flatten (batch, seq, vocab) logits and (batch, seq) targets so the
        # loss is averaged over every character position.
        loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if (batch_idx + 1) % log_interval == 0:
            avg_loss = total_loss / (batch_idx + 1)
            elapsed_time = time.time() - start_time
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{batch_count}], '
                  f'Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}, Time: {elapsed_time:.2f}s')

    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
    avg_epoch_loss = total_loss / max(batch_count, 1)
    print(f'End of Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_epoch_loss:.4f}, '
          f'Total Time: {time.time() - start_time:.2f}s')

# Persist only the learned weights (state_dict), not the whole module object.
torch.save(model.state_dict(), 'gumbel_sampling_shakespeare_lstm.pth')

Epoch [1/5], Batch [10/1562], Loss: 3.3346, Avg Loss: 3.7402, Time: 1.29s
Epoch [1/5], Batch [20/1562], Loss: 3.2499, Avg Loss: 3.5262, Time: 2.57s
Epoch [1/5], Batch [30/1562], Loss: 3.0540, Avg Loss: 3.4009, Time: 3.87s
Epoch [1/5], Batch [40/1562], Loss: 2.8428, Avg Loss: 3.2850, Time: 5.13s
Epoch [1/5], Batch [50/1562], Loss: 2.5955, Avg Loss: 3.1703, Time: 6.68s
Epoch [1/5], Batch [60/1562], Loss: 2.4825, Avg Loss: 3.0662, Time: 8.18s
Epoch [1/5], Batch [70/1562], Loss: 2.3920, Avg Loss: 2.9747, Time: 9.46s
Epoch [1/5], Batch [80/1562], Loss: 2.2790, Avg Loss: 2.8958, Time: 10.71s
Epoch [1/5], Batch [90/1562], Loss: 2.2375, Avg Loss: 2.8270, Time: 11.97s
Epoch [1/5], Batch [100/1562], Loss: 2.1851, Avg Loss: 2.7663, Time: 13.25s
Epoch [1/5], Batch [110/1562], Loss: 2.1339, Avg Loss: 2.7122, Time: 14.52s
Epoch [1/5], Batch [120/1562], Loss: 2.1024, Avg Loss: 2.6639, Time: 15.78s
Epoch [1/5], Batch [130/1562], Loss: 2.0444, Avg Loss: 2.6188, Time: 17.05s
Epoch [1/5], Batch [140/1562

In [28]:
def sample(model, start_seq, max_len, temperature=1.0):
    """Generate text from ``model`` using Gumbel-max sampling.

    The prompt is fed through the model once to build up the recurrent state;
    afterwards one character is generated per step and fed back as the next
    input. Each character is chosen as ``argmax(logits / temperature + g)``
    with ``g`` drawn i.i.d. from Gumbel(0, 1), which samples from
    ``softmax(logits / temperature)``.

    Relies on the module-level ``char_to_idx`` / ``idx_to_char`` lookup
    tables. The model is switched to eval mode and left in that mode.

    Args:
        model: Module exposing ``init_hidden(batch_size)`` and
            ``forward(input_ids, hidden) -> (logits, hidden)``.
        start_seq: Prompt string; every character must be a key of
            ``char_to_idx``.
        max_len: Number of characters to generate after the prompt.
        temperature: Softmax temperature. Values below 1 sharpen the
            distribution and values above 1 flatten it; ``0`` selects the
            most likely character at every step (greedy decoding).

    Returns:
        The prompt followed by the ``max_len`` generated characters.

    Raises:
        ValueError: If ``temperature`` is negative, or if ``start_seq`` is
            empty while ``max_len`` is positive (there is nothing to
            condition the first step on).
        KeyError: If ``start_seq`` contains a character not in ``char_to_idx``.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")
    if max_len > 0 and not start_seq:
        raise ValueError("start_seq must contain at least one character")

    model.eval()

    # Build inputs on the model's own device so sampling also works when the
    # model has been moved off the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    chars = [char_to_idx[ch] for ch in start_seq]
    input_seq = torch.tensor(chars, dtype=torch.long, device=device).unsqueeze(0)

    hidden = model.init_hidden(1)

    # Generation needs no gradients; without no_grad an autograd graph would
    # be kept alive across every step through the carried hidden state.
    with torch.no_grad():
        for _ in range(max_len):
            output, hidden = model(input_seq, hidden)

            # Only the logits at the last position predict the next character.
            logits = output[:, -1, :]

            if temperature == 0:
                # Zero-temperature limit: deterministic argmax, avoiding a
                # division by zero.
                next_char_idx = torch.argmax(logits, dim=-1).item()
            else:
                scaled_logits = logits / temperature
                # rand_like draws from [0, 1); clamp away from 0 so that
                # log(log(.)) stays finite.
                uniform = torch.rand_like(scaled_logits).clamp_min(
                    torch.finfo(scaled_logits.dtype).tiny
                )
                gumbel_noise = -torch.log(-torch.log(uniform))
                next_char_idx = torch.argmax(scaled_logits + gumbel_noise, dim=-1).item()

            chars.append(next_char_idx)
            # After the prompt, only the newest character is fed in; earlier
            # context is carried by `hidden`.
            input_seq = torch.tensor([[next_char_idx]], dtype=torch.long, device=device)

    return ''.join(idx_to_char[idx] for idx in chars)

# Prompt the model and let it continue for 100 more characters.
start_seq = "To be or not to be"
generated_text = sample(model, start_seq, max_len=100)
print(generated_text)

To be or not to be his worthy deeds disbench'd you nourive,
or peace to all's his:
When, by your price o' the consulsh
