In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import requests
import re
import math
from collections import Counter

In [3]:
# 1. Data preparation
def get_auguste_maquet_corpus():
    url = "https://www.gutenberg.org/files/7849/7849-0.txt"
    response = requests.get(url)
    text = response.text
    # Remove header and footer
    start = text.find("*** START OF THE PROJECT GUTENBERG EBOOK")
    end = text.find("*** END OF THE PROJECT GUTENBERG EBOOK")
    text = text[start:end]
    # Clean text
    text = re.sub(r'[^\w\s]', '', text)
    text = text.lower()
    return text

# 2. Tokenization and Vocabulary
def tokenize(text):
    return text.split()

In [4]:
class Vocabulary:
    def __init__(self, tokens):
        self.itos = ["<unk>", "<pad>", "<sos>", "<eos>"] + list(set(tokens))
        self.stoi = {token: i for i, token in enumerate(self.itos)}
    
    def __len__(self):
        return len(self.itos)
    
    def encode(self, tokens):
        return [self.stoi.get(token, self.stoi["<unk>"]) for token in tokens]
    
    def decode(self, ids):
        return [self.itos[id] for id in ids]

In [5]:

corpus = get_auguste_maquet_corpus()
tokens = tokenize(corpus)
vocab = Vocabulary(tokens)

In [6]:
class TextDataset(Dataset):
    def __init__(self, text, vocab, seq_length):
        self.text = text
        self.vocab = vocab
        self.seq_length = seq_length
        self.tokens = tokenize(self.text)
        self.data = self.vocab.encode(self.tokens)

    def __len__(self):
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx:idx+self.seq_length])
        y = torch.tensor(self.data[idx+1:idx+self.seq_length+1])
        return x, y

In [21]:
# 3. Model architecture
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, tgt, tgt_mask=None):
        tgt = self.embedding(tgt) * math.sqrt(self.embedding.embedding_dim)
        tgt = self.pos_encoder(tgt)
        output = self.transformer_decoder(tgt, tgt, tgt_mask)
        return self.fc_out(output)
    
    def calculate_loss(self, output, target):
        return self.criterion(output.view(-1, output.size(-1)), target.view(-1))
    
    
    

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0)]

In [28]:

# 4. Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq_length = 64
batch_size = 32
d_model = 256
nhead = 8
num_layers = 4
dim_feedforward = 1024
lr = 0.001
epochs = 10

dataset = TextDataset(corpus, vocab, seq_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = TransformerDecoder(len(vocab), d_model, nhead, num_layers, dim_feedforward).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
print(model)

TransformerDecoder(
  (embedding): Embedding(5108, 256)
  (pos_encoder): PositionalEncoding()
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(

In [23]:
from tqdm import tqdm
# 5. Training loop
def train():
    model.train()
    total_loss = 0
    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        # print(x.shape)
        output = model(x)
        loss = criterion(output.view(-1, len(vocab)), y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

for epoch in tqdm(range(epochs)):
    loss = train()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")

# 6. Text generation
def generate_text(model, start_sequence, max_length=100):
    model.eval()
    with torch.no_grad():
        current_sequence = start_sequence
        for _ in range(max_length):
            input_tensor = torch.tensor(vocab.encode(current_sequence)).unsqueeze(0).to(device)
            output = model(input_tensor)
            next_token_idx = output[0, -1, :].argmax().item()
            next_token = vocab.itos[next_token_idx]
            current_sequence.append(next_token)
            if next_token == "<eos>":
                break
    return " ".join(current_sequence)

 10%|█         | 1/10 [01:08<10:12, 68.06s/it]

Epoch 1/10, Loss: 6.1964


 20%|██        | 2/10 [02:16<09:04, 68.03s/it]

Epoch 2/10, Loss: 6.1701


 30%|███       | 3/10 [03:23<07:55, 67.96s/it]

Epoch 3/10, Loss: 6.1682


 40%|████      | 4/10 [04:31<06:47, 67.91s/it]

Epoch 4/10, Loss: 6.1596


 50%|█████     | 5/10 [05:40<05:40, 68.09s/it]

Epoch 5/10, Loss: 6.1568


 60%|██████    | 6/10 [06:49<04:33, 68.44s/it]

Epoch 6/10, Loss: 6.1549


 70%|███████   | 7/10 [07:58<03:25, 68.65s/it]

Epoch 7/10, Loss: 6.1555


 80%|████████  | 8/10 [09:07<02:17, 68.89s/it]

Epoch 8/10, Loss: 6.1521


 90%|█████████ | 9/10 [10:15<01:08, 68.41s/it]

Epoch 9/10, Loss: 6.1494


100%|██████████| 10/10 [11:21<00:00, 68.15s/it]

Epoch 10/10, Loss: 6.1515





In [24]:
# Example usage
start_sequence = ["the", "count", "of"]
generated_text = generate_text(model, start_sequence)
print(generated_text)

the count of the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


In [25]:
# Update the text generation function to calculate perplexity
def generate_text_with_perplexity(model, start_sequence, max_length=10):
    model.eval()
    generated_sequence = start_sequence.copy()
    total_log_likelihood = 0
    total_tokens = 0
    
    with torch.no_grad():
        for _ in range(max_length):
            input_tensor = torch.tensor(vocab.encode(generated_sequence)).unsqueeze(0).to(device)
            output = model(input_tensor)
            next_token_logits = output[0, -1, :]
            next_token_idx = next_token_logits.argmax().item()
            next_token = vocab.itos[next_token_idx]
            
            # Calculate log likelihood of the chosen token
            log_likelihood = -model.criterion(next_token_logits.unsqueeze(0), torch.tensor([next_token_idx]).to(device)).item()
            total_log_likelihood += log_likelihood
            total_tokens += 1
            
            generated_sequence.append(next_token)
            if next_token == "<eos>" or len(generated_sequence) >= max_length:
                break
    
    generated_text = " ".join(generated_sequence)
    perplexity = math.exp(-total_log_likelihood / total_tokens)
    return generated_text, perplexity

In [26]:
# Example usage
start_sequence = ["the", "count", "of"]
generated_text, text_perplexity = generate_text_with_perplexity(model, start_sequence)
print("Generated Text:")
print(generated_text)
print(f"Text Perplexity: {text_perplexity:.4f}")

Generated Text:
the count of the the the the the the the
Text Perplexity: 16.5952


## New

In [29]:
# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq_length = 128  # Increased sequence length
batch_size = 64   # Increased batch size
d_model = 512     # Increased model dimension
nhead = 8
num_layers = 6    # Increased number of layers
dim_feedforward = 2048
lr = 0.0001       # Decreased learning rate
epochs = 50       # Increased number of epochs

dataset = TextDataset(corpus, vocab, seq_length)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

model = TransformerDecoder(len(vocab), d_model, nhead, num_layers, dim_feedforward).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [31]:
# Gradient clipping
clip_value = 1.0

def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    total_tokens = 0
    for _, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = model.calculate_loss(output, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        total_loss += loss.item() * y.numel()
        total_tokens += y.numel()
    
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return avg_loss, perplexity


def validate(model, dataloader, device):
    model.eval()
    total_loss = 0
    total_tokens = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = model.calculate_loss(output, y)
            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()
    
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return avg_loss, perplexity


# Training loop
for epoch in tqdm(range(epochs), desc='Training'):
    train_loss, train_perplexity = train_epoch(model, train_dataloader, optimizer, device)
    val_loss, val_perplexity = validate(model, val_dataloader, device)
    scheduler.step()
    
    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}")
    print("--------------------")

Training:   2%|▏         | 1/50 [03:36<2:56:50, 216.53s/it]

Epoch 1/50
Train Loss: 6.1937, Train Perplexity: 489.6367
Val Loss: 6.1833, Val Perplexity: 484.5922
--------------------


Training:   4%|▍         | 2/50 [07:14<2:54:03, 217.57s/it]

Epoch 2/50
Train Loss: 6.1852, Train Perplexity: 485.5127
Val Loss: 6.1799, Val Perplexity: 482.9481
--------------------


Training:   6%|▌         | 3/50 [10:52<2:50:36, 217.80s/it]

Epoch 3/50
Train Loss: 6.1814, Train Perplexity: 483.6522
Val Loss: 6.1774, Val Perplexity: 481.7247
--------------------


Training:   8%|▊         | 4/50 [14:31<2:47:07, 217.99s/it]

Epoch 4/50
Train Loss: 6.1783, Train Perplexity: 482.1849
Val Loss: 6.1749, Val Perplexity: 480.5353
--------------------


Training:  10%|█         | 5/50 [18:09<2:43:29, 217.99s/it]

Epoch 5/50
Train Loss: 6.1762, Train Perplexity: 481.1477
Val Loss: 6.1736, Val Perplexity: 479.9243
--------------------


Training:  12%|█▏        | 6/50 [21:47<2:39:51, 217.98s/it]

Epoch 6/50
Train Loss: 6.1743, Train Perplexity: 480.2290
Val Loss: 6.1721, Val Perplexity: 479.1731
--------------------


Training:  14%|█▍        | 7/50 [25:25<2:36:11, 217.95s/it]

Epoch 7/50
Train Loss: 6.1723, Train Perplexity: 479.2691
Val Loss: 6.1708, Val Perplexity: 478.5728
--------------------


Training:  16%|█▌        | 8/50 [29:03<2:32:35, 217.99s/it]

Epoch 8/50
Train Loss: 6.1709, Train Perplexity: 478.6131
Val Loss: 6.1690, Val Perplexity: 477.6878
--------------------


Training:  18%|█▊        | 9/50 [32:41<2:28:57, 217.99s/it]

Epoch 9/50
Train Loss: 6.1679, Train Perplexity: 477.1949
Val Loss: 6.1667, Val Perplexity: 476.5965
--------------------


Training:  20%|██        | 10/50 [36:18<2:25:17, 217.95s/it]

Epoch 10/50
Train Loss: 6.1667, Train Perplexity: 476.6041
Val Loss: 6.1659, Val Perplexity: 476.2101
--------------------


Training:  22%|██▏       | 11/50 [39:56<2:21:36, 217.85s/it]

Epoch 11/50
Train Loss: 6.1644, Train Perplexity: 475.5203
Val Loss: 6.1630, Val Perplexity: 474.8523
--------------------


Training:  24%|██▍       | 12/50 [43:34<2:17:59, 217.87s/it]

Epoch 12/50
Train Loss: 6.1635, Train Perplexity: 475.1067
Val Loss: 6.1627, Val Perplexity: 474.7148
--------------------


Training:  26%|██▌       | 13/50 [47:12<2:14:22, 217.90s/it]

Epoch 13/50
Train Loss: 6.1631, Train Perplexity: 474.9070
Val Loss: 6.1624, Val Perplexity: 474.5459
--------------------


Training:  28%|██▊       | 14/50 [50:50<2:10:49, 218.05s/it]

Epoch 14/50
Train Loss: 6.1629, Train Perplexity: 474.8148
Val Loss: 6.1622, Val Perplexity: 474.4508
--------------------


Training:  30%|███       | 15/50 [54:29<2:07:15, 218.17s/it]

Epoch 15/50
Train Loss: 6.1628, Train Perplexity: 474.7598
Val Loss: 6.1619, Val Perplexity: 474.3500
--------------------


Training:  32%|███▏      | 16/50 [58:07<2:03:38, 218.18s/it]

Epoch 16/50
Train Loss: 6.1626, Train Perplexity: 474.6569
Val Loss: 6.1619, Val Perplexity: 474.3189
--------------------


Training:  34%|███▍      | 17/50 [1:01:45<2:00:00, 218.19s/it]

Epoch 17/50
Train Loss: 6.1624, Train Perplexity: 474.5568
Val Loss: 6.1617, Val Perplexity: 474.2389
--------------------


Training:  36%|███▌      | 18/50 [1:05:23<1:56:20, 218.13s/it]

Epoch 18/50
Train Loss: 6.1620, Train Perplexity: 474.3818
Val Loss: 6.1611, Val Perplexity: 473.9674
--------------------


Training:  38%|███▊      | 19/50 [1:09:01<1:52:40, 218.07s/it]

Epoch 19/50
Train Loss: 6.1617, Train Perplexity: 474.2433
Val Loss: 6.1610, Val Perplexity: 473.8885
--------------------


Training:  40%|████      | 20/50 [1:12:39<1:49:00, 218.03s/it]

Epoch 20/50
Train Loss: 6.1616, Train Perplexity: 474.2050
Val Loss: 6.1606, Val Perplexity: 473.7190
--------------------


Training:  42%|████▏     | 21/50 [1:16:17<1:45:20, 217.96s/it]

Epoch 21/50
Train Loss: 6.1612, Train Perplexity: 474.0055
Val Loss: 6.1606, Val Perplexity: 473.7122
--------------------


Training:  44%|████▍     | 22/50 [1:19:55<1:41:42, 217.93s/it]

Epoch 22/50
Train Loss: 6.1611, Train Perplexity: 473.9655
Val Loss: 6.1606, Val Perplexity: 473.6917
--------------------


Training:  46%|████▌     | 23/50 [1:23:33<1:38:04, 217.94s/it]

Epoch 23/50
Train Loss: 6.1612, Train Perplexity: 473.9776
Val Loss: 6.1605, Val Perplexity: 473.6803
--------------------


Training:  48%|████▊     | 24/50 [1:27:11<1:34:25, 217.91s/it]

Epoch 24/50
Train Loss: 6.1612, Train Perplexity: 473.9799
Val Loss: 6.1605, Val Perplexity: 473.6793
--------------------


Training:  50%|█████     | 25/50 [1:30:48<1:30:47, 217.89s/it]

Epoch 25/50
Train Loss: 6.1612, Train Perplexity: 473.9826
Val Loss: 6.1605, Val Perplexity: 473.6687
--------------------


Training:  52%|█████▏    | 26/50 [1:34:26<1:27:08, 217.86s/it]

Epoch 26/50
Train Loss: 6.1613, Train Perplexity: 474.0324
Val Loss: 6.1605, Val Perplexity: 473.6613
--------------------


Training:  54%|█████▍    | 27/50 [1:38:04<1:23:30, 217.87s/it]

Epoch 27/50
Train Loss: 6.1611, Train Perplexity: 473.9391
Val Loss: 6.1605, Val Perplexity: 473.6512
--------------------


Training:  56%|█████▌    | 28/50 [1:41:42<1:19:52, 217.84s/it]

Epoch 28/50
Train Loss: 6.1611, Train Perplexity: 473.9724
Val Loss: 6.1605, Val Perplexity: 473.6478
--------------------


Training:  58%|█████▊    | 29/50 [1:45:20<1:16:15, 217.90s/it]

Epoch 29/50
Train Loss: 6.1610, Train Perplexity: 473.8867
Val Loss: 6.1604, Val Perplexity: 473.6382
--------------------


Training:  60%|██████    | 30/50 [1:48:58<1:12:38, 217.92s/it]

Epoch 30/50
Train Loss: 6.1611, Train Perplexity: 473.9620
Val Loss: 6.1604, Val Perplexity: 473.6372
--------------------


Training:  62%|██████▏   | 31/50 [1:52:36<1:09:00, 217.92s/it]

Epoch 31/50
Train Loss: 6.1609, Train Perplexity: 473.8575
Val Loss: 6.1604, Val Perplexity: 473.6335
--------------------


Training:  64%|██████▍   | 32/50 [1:56:14<1:05:22, 217.92s/it]

Epoch 32/50
Train Loss: 6.1609, Train Perplexity: 473.8449
Val Loss: 6.1604, Val Perplexity: 473.6331
--------------------


Training:  66%|██████▌   | 33/50 [1:59:52<1:01:44, 217.93s/it]

Epoch 33/50
Train Loss: 6.1610, Train Perplexity: 473.8961
Val Loss: 6.1604, Val Perplexity: 473.6328
--------------------


Training:  68%|██████▊   | 34/50 [2:03:30<58:06, 217.92s/it]  

Epoch 34/50
Train Loss: 6.1610, Train Perplexity: 473.9014
Val Loss: 6.1604, Val Perplexity: 473.6312
--------------------


Training:  70%|███████   | 35/50 [2:07:07<54:28, 217.93s/it]

Epoch 35/50
Train Loss: 6.1610, Train Perplexity: 473.8817
Val Loss: 6.1604, Val Perplexity: 473.6308
--------------------


Training:  72%|███████▏  | 36/50 [2:10:45<50:51, 217.94s/it]

Epoch 36/50
Train Loss: 6.1610, Train Perplexity: 473.8861
Val Loss: 6.1604, Val Perplexity: 473.6299
--------------------


Training:  74%|███████▍  | 37/50 [2:14:23<47:12, 217.91s/it]

Epoch 37/50
Train Loss: 6.1608, Train Perplexity: 473.8289
Val Loss: 6.1604, Val Perplexity: 473.6292
--------------------


Training:  76%|███████▌  | 38/50 [2:18:01<43:35, 217.98s/it]

Epoch 38/50
Train Loss: 6.1609, Train Perplexity: 473.8492
Val Loss: 6.1604, Val Perplexity: 473.6288
--------------------


Training:  78%|███████▊  | 39/50 [2:21:40<39:58, 218.01s/it]

Epoch 39/50
Train Loss: 6.1610, Train Perplexity: 473.8931
Val Loss: 6.1604, Val Perplexity: 473.6287
--------------------


Training:  80%|████████  | 40/50 [2:25:17<36:18, 217.89s/it]

Epoch 40/50
Train Loss: 6.1609, Train Perplexity: 473.8773
Val Loss: 6.1604, Val Perplexity: 473.6264
--------------------


Training:  82%|████████▏ | 41/50 [2:28:54<32:39, 217.71s/it]

Epoch 41/50
Train Loss: 6.1609, Train Perplexity: 473.8694
Val Loss: 6.1604, Val Perplexity: 473.6265
--------------------


Training:  84%|████████▍ | 42/50 [2:32:32<29:00, 217.56s/it]

Epoch 42/50
Train Loss: 6.1609, Train Perplexity: 473.8561
Val Loss: 6.1604, Val Perplexity: 473.6267
--------------------


Training:  86%|████████▌ | 43/50 [2:36:09<25:22, 217.48s/it]

Epoch 43/50
Train Loss: 6.1609, Train Perplexity: 473.8317
Val Loss: 6.1604, Val Perplexity: 473.6265
--------------------


Training:  88%|████████▊ | 44/50 [2:39:47<21:45, 217.55s/it]

Epoch 44/50
Train Loss: 6.1609, Train Perplexity: 473.8528
Val Loss: 6.1604, Val Perplexity: 473.6265
--------------------


Training:  90%|█████████ | 45/50 [2:43:24<18:07, 217.49s/it]

Epoch 45/50
Train Loss: 6.1610, Train Perplexity: 473.9023
Val Loss: 6.1604, Val Perplexity: 473.6267
--------------------


Training:  92%|█████████▏| 46/50 [2:47:01<14:29, 217.44s/it]

Epoch 46/50
Train Loss: 6.1611, Train Perplexity: 473.9275
Val Loss: 6.1604, Val Perplexity: 473.6268
--------------------


Training:  94%|█████████▍| 47/50 [2:50:39<10:52, 217.46s/it]

Epoch 47/50
Train Loss: 6.1609, Train Perplexity: 473.8597
Val Loss: 6.1604, Val Perplexity: 473.6267
--------------------


Training:  96%|█████████▌| 48/50 [2:54:16<07:14, 217.38s/it]

Epoch 48/50
Train Loss: 6.1610, Train Perplexity: 473.8829
Val Loss: 6.1604, Val Perplexity: 473.6268
--------------------


Training:  98%|█████████▊| 49/50 [2:57:54<03:37, 217.41s/it]

Epoch 49/50
Train Loss: 6.1610, Train Perplexity: 473.9024
Val Loss: 6.1604, Val Perplexity: 473.6267
--------------------


Training: 100%|██████████| 50/50 [3:01:31<00:00, 217.83s/it]

Epoch 50/50
Train Loss: 6.1609, Train Perplexity: 473.8779
Val Loss: 6.1604, Val Perplexity: 473.6267
--------------------





In [32]:
# Improved text generation function
def generate_text_with_sampling(model, start_sequence, max_length=100, temperature=0.7, top_k=50, top_p=0.9):
    model.eval()
    generated_sequence = start_sequence.copy()
    
    with torch.no_grad():
        for _ in range(max_length):
            input_tensor = torch.tensor(vocab.encode(generated_sequence)).unsqueeze(0).to(device)
            output = model(input_tensor)
            next_token_logits = output[0, -1, :] / temperature
            
            # Top-k sampling
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
            
            # Top-p (nucleus) sampling
            sorted_logits, sorted_indices = torch.sort(top_k_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            top_k_logits[indices_to_remove] = float('-inf')
            
            # Sample from the filtered distribution
            probs = torch.softmax(top_k_logits, dim=-1)
            next_token_idx = top_k_indices[torch.multinomial(probs, 1).item()].item()
            next_token = vocab.itos[next_token_idx]
            
            generated_sequence.append(next_token)
            if next_token == "<eos>" or len(generated_sequence) >= max_length:
                break
    
    return " ".join(generated_sequence)

# Example usage
start_sequence = ["the", "count", "of"]
generated_text = generate_text_with_sampling(model, start_sequence, temperature=0.7, top_k=50, top_p=0.9)
print("Generated Text:")
print(generated_text)

Generated Text:
the count of had the he its the that for i had the it but as but a to and you i to the the a and and he had to he to so to and and the and the a the that the the the he to was but and k to and the was his the this for you of but the at and was at the had and i the to and him and and not the of to to and this it the a the the that it him in the and the the and to


In [33]:
import torch
import torch.nn.functional as F
import math

# Improved text generation function with perplexity
def generate_text_with_sampling_and_perplexity(model, start_sequence, vocab, max_length=100, temperature=0.7, top_k=50, top_p=0.9):
    model.eval()
    generated_sequence = start_sequence.copy()
    perplexities = []
    
    with torch.no_grad():
        for _ in range(max_length):
            input_tensor = torch.tensor(vocab.encode(generated_sequence)).unsqueeze(0).to(device)
            output = model(input_tensor)
            next_token_logits = output[0, -1, :] / temperature
            
            # Top-k sampling
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
            
            # Top-p (nucleus) sampling
            sorted_logits, sorted_indices = torch.sort(top_k_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            top_k_logits[indices_to_remove] = float('-inf')
            
            # Sample from the filtered distribution
            probs = F.softmax(top_k_logits, dim=-1)
            next_token_idx = top_k_indices[torch.multinomial(probs, 1).item()].item()
            next_token = vocab.itos[next_token_idx]
            
            # Calculate perplexity for the generated token
            token_prob = probs[top_k_indices == next_token_idx].item()  # Probability of the sampled token
            perplexity = math.exp(-math.log(token_prob)) if token_prob > 0 else float('inf')
            perplexities.append(perplexity)
            
            generated_sequence.append(next_token)
            print(f"Generated Token: {next_token}, Perplexity: {perplexity:.4f}")
            
            # Break if end-of-sequence token is generated or max length is reached
            if next_token == "<eos>" or len(generated_sequence) >= max_length:
                break
    
    return " ".join(generated_sequence), perplexities

# Example usage
start_sequence = ["the", "count", "of"]
generated_text, perplexities = generate_text_with_sampling_and_perplexity(model, start_sequence, vocab, temperature=0.7, top_k=50, top_p=0.9)

print("Generated Text:")
print(generated_text)
print("Perplexities:", perplexities)

Generated Token: and, Perplexity: 17.1910
Generated Token: to, Perplexity: 8.6642
Generated Token: said, Perplexity: 64.9131
Generated Token: you, Perplexity: 35.3226
Generated Token: the, Perplexity: 5.7980
Generated Token: i, Perplexity: 50.8094
Generated Token: the, Perplexity: 4.3551
Generated Token: it, Perplexity: 27.9179
Generated Token: and, Perplexity: 14.5621
Generated Token: the, Perplexity: 5.3899
Generated Token: with, Perplexity: 75.6919
Generated Token: was, Perplexity: 43.3846
Generated Token: he, Perplexity: 17.4259
Generated Token: at, Perplexity: 122.5639
Generated Token: it, Perplexity: 24.7624
Generated Token: to, Perplexity: 7.1765
Generated Token: on, Perplexity: 98.8945
Generated Token: and, Perplexity: 17.3603
Generated Token: to, Perplexity: 8.6642
Generated Token: of, Perplexity: 18.0934
Generated Token: of, Perplexity: 22.5940
Generated Token: the, Perplexity: 3.9198
Generated Token: at, Perplexity: 81.3594
Generated Token: k, Perplexity: 32.5852
Generated T