In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import string
import urllib.request

# Download Shakespeare dataset
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
urllib.request.urlretrieve(url, 'shakespeare.txt')

# Define the Shakespeare dataset
class ShakespeareDataset(Dataset):
    def __init__(self, file_path, seq_length=50):
        self.seq_length = seq_length
        self.chars = list(string.printable)
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}
        self.text = self.load_text(file_path)

    def load_text(self, file_path):
        with open(file_path, 'r') as f:
            text = f.read()
        return text

    def __len__(self):
        return len(self.text) - self.seq_length

    def __getitem__(self, idx):
        input_seq = self.text[idx:idx+self.seq_length]
        target = self.text[idx+self.seq_length]
        return torch.tensor([self.char_to_idx[ch] for ch in input_seq]), self.char_to_idx[target]

# Define the Transformer model
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, nhead) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(1, 0, 2)  # Change the sequence length dimension
        for layer in self.transformer_layers:
            x = layer(x)
        x = self.fc(x[-1, :, :])  # Only using the output from the last position
        return x

# Training parameters
batch_size = 64
seq_length = 50
lr = 0.001
epochs = 10

print(batch_size, seq_length, lr, epochs)

# Load and preprocess the Shakespeare dataset
dataset = ShakespeareDataset('shakespeare.txt', seq_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the model, criterion, and optimizer
model = Transformer(vocab_size=len(dataset.chars))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
for epoch in range(epochs):
    for batch_inputs, batch_targets in dataloader:
        optimizer.zero_grad()
        output = model(batch_inputs)
        loss = criterion(output, batch_targets)
        loss.backward()
        optimizer.step()

        print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')

# Generate text using the trained model
def generate_text(start_text, length=200, temperature=0.8):
    model.eval()
    with torch.no_grad():
        input_seq = torch.tensor([dataset.char_to_idx[ch] for ch in start_text])
        for _ in range(length):
            output = model(input_seq.unsqueeze(0))
            probabilities = F.softmax(output.squeeze() / temperature, dim=0)
            predicted_idx = torch.multinomial(probabilities, 1).item()
            input_seq = torch.cat((input_seq[1:], torch.tensor([predicted_idx])))
            print(dataset.idx_to_char[predicted_idx], end='')

# Generate text with a starting prompt
generate_text("The sun", length=300)


64 50 0.001 10
Epoch 1/10, Loss: 4.814452648162842
Epoch 1/10, Loss: 4.008248329162598
Epoch 1/10, Loss: 4.349300861358643
Epoch 1/10, Loss: 3.7609927654266357
Epoch 1/10, Loss: 3.525066375732422
Epoch 1/10, Loss: 3.4460384845733643
Epoch 1/10, Loss: 3.47165846824646
Epoch 1/10, Loss: 3.6625864505767822
Epoch 1/10, Loss: 3.567234516143799
Epoch 1/10, Loss: 3.409498691558838
Epoch 1/10, Loss: 3.438255786895752
Epoch 1/10, Loss: 3.3701398372650146
Epoch 1/10, Loss: 3.5919103622436523
Epoch 1/10, Loss: 3.348745822906494
Epoch 1/10, Loss: 3.546959400177002
Epoch 1/10, Loss: 3.43572998046875
Epoch 1/10, Loss: 3.2080280780792236
Epoch 1/10, Loss: 3.5196919441223145
Epoch 1/10, Loss: 3.4165432453155518
Epoch 1/10, Loss: 3.489027976989746
Epoch 1/10, Loss: 3.357367753982544
Epoch 1/10, Loss: 3.4218525886535645
Epoch 1/10, Loss: 3.256929397583008
Epoch 1/10, Loss: 3.4109857082366943
Epoch 1/10, Loss: 3.687011480331421
Epoch 1/10, Loss: 3.275266170501709
Epoch 1/10, Loss: 3.7433922290802
Epoch 1

KeyboardInterrupt: 