In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [2]:
if torch.cuda.is_available():
    device = 'cuda'

In [3]:
for i in range(torch.cuda.device_count()):
    print(torch.cuda.get_device_name(i))

NVIDIA GeForce RTX 4070


In [4]:
with open('wizard_of_oz.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
chars = sorted(set(text))
print(chars)
vocab_size = len(chars)
print(vocab_size)

['\n', ' ', '!', '"', '$', '%', '&', "'", '(', ')', '*', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '¹', '‒', '—', '―', '‘', '’', '“', '”', '•', '™', '♠', '♦', '\ufeff']
96


In [6]:
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for idx, char in enumerate(chars)}

In [7]:
print(idx_to_char)

{0: '\n', 1: ' ', 2: '!', 3: '"', 4: '$', 5: '%', 6: '&', 7: "'", 8: '(', 9: ')', 10: '*', 11: ',', 12: '-', 13: '.', 14: '/', 15: '0', 16: '1', 17: '2', 18: '3', 19: '4', 20: '5', 21: '6', 22: '7', 23: '8', 24: '9', 25: ':', 26: ';', 27: '?', 28: 'A', 29: 'B', 30: 'C', 31: 'D', 32: 'E', 33: 'F', 34: 'G', 35: 'H', 36: 'I', 37: 'J', 38: 'K', 39: 'L', 40: 'M', 41: 'N', 42: 'O', 43: 'P', 44: 'Q', 45: 'R', 46: 'S', 47: 'T', 48: 'U', 49: 'V', 50: 'W', 51: 'X', 52: 'Y', 53: 'Z', 54: '[', 55: ']', 56: '_', 57: 'a', 58: 'b', 59: 'c', 60: 'd', 61: 'e', 62: 'f', 63: 'g', 64: 'h', 65: 'i', 66: 'j', 67: 'k', 68: 'l', 69: 'm', 70: 'n', 71: 'o', 72: 'p', 73: 'q', 74: 'r', 75: 's', 76: 't', 77: 'u', 78: 'v', 79: 'w', 80: 'x', 81: 'y', 82: 'z', 83: '¹', 84: '‒', 85: '—', 86: '―', 87: '‘', 88: '’', 89: '“', 90: '”', 91: '•', 92: '™', 93: '♠', 94: '♦', 95: '\ufeff'}


In [8]:
# data = [char_to_idx[char] for char in text ]

In [9]:
# if len(data) == len(text):
#     print(f"True! Length is {len(data)}")

In [10]:
# inputs = []
# targets = []
# seq_length = 100

# for i in range(0, len(data) -seq_length):
#     inputs.append(data[i:i+seq_length])
#     targets.append(data[i+1:i+seq_length+1])

# inputs = torch.tensor(inputs, dtype=torch.long)
# targets = torch.tensor(targets, dtype=torch.long)

In [11]:
# print(inputs.size())
# print(targets.size())
# inputs.to(device)
# targets.to(device)

In [12]:
class TextDataset(Dataset):
    def __init__(self, text, seq_length):
        chars = sorted(list(set(text)))
        self.char_to_idx = {char: idx for idx, char in enumerate(chars)}
        self.idx_to_char = {idx: char for idx, char in enumerate(chars)}
        self.vocab_size = len(chars)
        self.seq_length = seq_length
        self.data = [self.char_to_idx[char] for char in text]

        self.inputs = []
        self.targets = []
        for i in range(0, len(self.data) - seq_length):
            self.inputs.append(self.data[i:i + seq_length])
            self.targets.append(self.data[i + 1:i + seq_length + 1])

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return torch.tensor(self.inputs[idx], dtype=torch.long), torch.tensor(self.targets[idx], dtype=torch.long)

# Load the text data
with open('wizard_of_oz.txt', 'r', encoding='utf-8') as f:
    text = f.read()

seq_length = 50
dataset = TextDataset(text, seq_length)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [13]:
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size)

# Model parameters
input_size = dataset.vocab_size
hidden_size = 128
output_size = dataset.vocab_size

# Initialize model, loss function, and optimizer
model = SimpleRNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [14]:
model.to(device)

SimpleRNN(
  (embedding): Embedding(96, 128)
  (rnn): RNN(128, 128, batch_first=True)
  (fc): Linear(in_features=128, out_features=96, bias=True)
)

In [15]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for x_batch, y_batch in dataloader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        
        # Initialize hidden state for the current batch
        hidden = model.init_hidden(x_batch.size(0)).to(device)
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        output, hidden = model(x_batch, hidden)
        
        # Calculate loss
        loss = criterion(output.view(-1, output_size), y_batch.view(-1))
        
        # Backward pass
        loss.backward()
        
        # Update parameters
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}')

Epoch [1/10], Loss: 1.5196
Epoch [2/10], Loss: 1.4275
Epoch [3/10], Loss: 1.4144
Epoch [4/10], Loss: 1.4083
Epoch [5/10], Loss: 1.4043
Epoch [6/10], Loss: 1.4016
Epoch [7/10], Loss: 1.3999
Epoch [8/10], Loss: 1.3980
Epoch [9/10], Loss: 1.3972
Epoch [10/10], Loss: 1.3962


In [21]:
output[1,1].size()
output[1,1].tolist()

[-6.723585605621338,
 -3.8264291286468506,
 -3.2240824699401855,
 -16.451496124267578,
 -23.411029815673828,
 -29.336402893066406,
 -22.524686813354492,
 -6.478598594665527,
 -13.270509719848633,
 -11.285575866699219,
 -12.38095474243164,
 -1.6375879049301147,
 -1.388867974281311,
 0.30989038944244385,
 -12.410364151000977,
 -14.931523323059082,
 -12.222330093383789,
 -13.27371597290039,
 -15.596689224243164,
 -11.92841911315918,
 -14.370338439941406,
 -14.55559253692627,
 -19.44388198852539,
 -15.277057647705078,
 -12.98617172241211,
 -6.768560409545898,
 -4.385221481323242,
 -3.9468653202056885,
 -10.544624328613281,
 -4.4307379722595215,
 -7.464421272277832,
 -5.638528347015381,
 -9.103068351745605,
 -14.175461769104004,
 -6.72966194152832,
 -10.394250869750977,
 -5.823312282562256,
 -9.328401565551758,
 -11.48056411743164,
 -9.701556205749512,
 -11.409364700317383,
 -3.88547682762146,
 -3.6877083778381348,
 -13.180563926696777,
 -19.027984619140625,
 -11.714963912963867,
 -7.951646

In [29]:
output.shape

torch.Size([12, 50, 96])

In [32]:
predicted_indices = output.argmax(dim=-1)

predicted_indices = predicted_indices.view(-1).tolist()

print(len(predicted_indices))

predicted_chars = [dataset.idx_to_char[idx] for idx in predicted_indices]

predicted_text = ''.join(predicted_chars)

print("\nPredicted Text from the last batch of the last epoch:")
print(predicted_text)

600

Predicted Text from the last batch of the last epoch:
eleur opder tnt stncg , &n thet tn the storp  tltotn i she  th traase tnd tfher  Bord, sor iventy  b tn tim weaeete and then the sresty sroncess ouynee t tn travate pety  and ta sind ias tn to e teasutvery thiught  ahudstnd tard  ahldeth thes droryou tarths  an was tte  etea  ahet Ihe sar   of ton t tn ws tne  to te tefore the r lontersions ahth tehmg tolnvns e  and the sitten tenned th tyny piret consinuedth teclntett  ahll tnr sves ine soll  tf  
"Iow ds tncle Hewry   ahe sxduired  anter t srr esh-andtvee oonst snr ng tnd the  tycmieeng tn  O tes  ahin te sholl b peareon t l the salless of t
