In [190]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [191]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU is available")

In [192]:
# Read the entire training corpus into one string (text mode, UTF-8).
with open('input.txt', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Character vocabulary: every distinct character in the corpus. Sorting makes the
# char <-> index mapping deterministic, so saved checkpoints stay compatible across runs.
chars = sorted(set(text))
vocab_size = len(chars)

print("Text Len:", len(text))
print("Vocab Size:", vocab_size)

char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = dict(enumerate(chars))


def encode(s):
    """Map a string to a list of vocabulary indices.

    Raises KeyError for any character that does not occur in the corpus.
    """
    return [char_to_ix[c] for c in s]


def decode(x):
    """Map an iterable of integer vocabulary indices back to a string."""
    return ''.join(ix_to_char[i] for i in x)


def decode_torch(x):
    """Like ``decode``, but for an iterable of single-element index tensors
    (e.g. a 1-D integer tensor); each element is converted with ``.item()``."""
    return ''.join(ix_to_char[i.item()] for i in x)

Text Len: 1833819
Vocab Size: 280


In [193]:
class LSTM(nn.Module):
    def __init__(self, hidden_dim, input_dim, output_dim, layers):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layers = layers

        self.forgetGate = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.inputGate = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.candidateGate = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.outputGate = nn.Linear(input_dim + hidden_dim, hidden_dim)

        # Initialize weights
        # Xavier initialization is used to prevent exploding gradients in deeper networks
        nn.init.xavier_uniform_(self.forgetGate.weight)
        nn.init.xavier_uniform_(self.inputGate.weight)
        nn.init.xavier_uniform_(self.candidateGate.weight)
        nn.init.xavier_uniform_(self.outputGate.weight)

        # Initialize biases
        # Forget gate bias is initialized to 1 to remember everything at the beginning 
        # Input, candidate, and output gate biases are initialized to 0 to forget nothing at the beginning
        nn.init.constant_(self.forgetGate.bias, 1.0)
        nn.init.constant_(self.inputGate.bias, 0.0)
        nn.init.constant_(self.candidateGate.bias, 0.0)
        nn.init.constant_(self.outputGate.bias, 0.0)

        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, hidden):
        # Hidden state intuition: this is like the short-term memory of the LSTM (just to carry information to the next time step roughly)
        # Cell state intuition: this is like the long-term memory of the LSTM (to carry information across many time steps)
        h_prev, c_prev = hidden
        cat = torch.cat((x, h_prev), 1)

        # Forget gate
        # Intuition: this is later multiplied element-wise with the cell state (forgets where 0 and keeps where 1).
        # Sigmoid squashes the values between 0 and 1 after the hidden state and input are concatenated and linearly transformed to create a forget vec
        f = torch.sigmoid(self.forgetGate(cat))

        # Input gate
        # Intuition: This is later multiplied element-wise with the candidate cell state (decides what to add to the cell state).
        # Sigmoid once again used to squash between 0 and 1
        i = torch.sigmoid(self.inputGate(cat))

        # Candidate cell state
        # Intuition: This is the new candidate cell state that will be added to the cell state.
        # Tanh squashes the values between -1 and 1 to prevent exploding gradients (RNN moment)
        c_hat = torch.tanh(self.candidateGate(cat))

        # Output gate
        # Intuition: This is later multiplied element-wise with the cell state (decides what to output to the hidden state).
        o = torch.sigmoid(self.outputGate(cat))

        # Cell state after forgetting and adding new candidate cell state
        f_t = f * c_prev + i * c_hat

        # Hidden state for next time step based on the new cell state
        h_t = o * torch.tanh(f_t)

        # Layer normalization
        # Intuition: This is to prevent the vanishing gradient problem in deeper NNs by normalizing the hidden state
        h_t = self.layer_norm(h_t)

        # Return like this so that we can pass it back in to the next time step autoregressively
        return h_t, (h_t, f_t)

In [194]:
class LayeredLSTM(nn.Module):
    """A stack of ``LSTM`` cells in which every layer keeps its own (h, c) state.

    Layer 0 consumes the external input; layer k > 0 consumes the hidden output of
    layer k-1 at the same time step. The recurrent state is a pair ``(h, c)`` of
    tensors shaped ``(layers, batch, hidden_dim)`` -- the same layout ``torch.nn.LSTM``
    uses -- so ``hidden = (hidden[0].detach(), hidden[1].detach())`` still works.

    Args:
        hidden_dim: hidden/cell size of every layer.
        input_dim: feature size of the input to layer 0.
        output_dim: forwarded to each ``LSTM`` cell (which only stores it).
        layers: number of stacked cells; must be >= 1.

    Raises:
        ValueError: if ``layers < 1``.
    """

    def __init__(self, hidden_dim, input_dim, output_dim, layers):
        super().__init__()
        if layers < 1:
            raise ValueError(f"layers must be >= 1, got {layers}")
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layers = layers

        # Layer 0 reads the raw input; deeper layers read the previous layer's hidden output.
        self.lstms = nn.ModuleList(
            [LSTM(hidden_dim, input_dim if i == 0 else hidden_dim, output_dim, layers)
             for i in range(layers)]
        )

    def forward(self, x, hidden):
        """Run one time step through every layer of the stack.

        Args:
            x: input to the first layer for this step.
            hidden: ``(h, c)``, each shaped ``(layers, batch, hidden_dim)``, as returned
                by ``init_hidden`` or a previous call. For backward compatibility, a
                2-D ``(batch, hidden_dim)`` pair is also accepted and used as the
                starting state of every layer.

        Returns:
            ``(out, (h, c))``: ``out`` is the top layer's hidden output, and ``h``/``c``
            stack every layer's new state along dim 0.
        """
        h_prev, c_prev = hidden
        if h_prev.dim() == 2:
            # Old single-pair format: reuse the same starting state for every layer.
            h_prev = h_prev.unsqueeze(0).expand(self.layers, -1, -1)
            c_prev = c_prev.unsqueeze(0).expand(self.layers, -1, -1)

        new_h, new_c = [], []
        for layer_idx, lstm in enumerate(self.lstms):
            # Each layer resumes from ITS OWN previous state, not from the state the
            # layer below produced at this same step.
            x, (h_t, c_t) = lstm(x, (h_prev[layer_idx], c_prev[layer_idx]))
            new_h.append(h_t)
            new_c.append(c_t)
        return x, (torch.stack(new_h), torch.stack(new_c))

    def init_hidden(self, batch_size):
        """Return zero ``(h, c)`` states shaped ``(layers, batch_size, hidden_dim)``.

        The tensors are allocated on the same device and dtype as the model's parameters.
        """
        weight = next(self.parameters())
        shape = (self.layers, batch_size, self.hidden_dim)
        return weight.new_zeros(shape), weight.new_zeros(shape)

In [195]:
class CharLSTM(nn.Module):
    """Character-level language model: embedding -> dropout -> one LSTM cell -> linear head.

    Each ``forward`` call advances the model by a single time step for a batch of
    token ids and returns one row of logits per batch element.

    Args:
        hidden_dim: LSTM hidden/cell size.
        input_dim: embedding size fed to the LSTM.
        output_dim: number of logits produced per step (the output vocabulary).
        layers: stored on the module; only a single ``LSTM`` cell is built.
        vocab_size: number of embedding rows. Defaults to ``output_dim``, since a
            character LM reads and predicts the same vocabulary.
        dropout: dropout probability applied to the embeddings (default 0.2).
    """

    def __init__(self, hidden_dim, input_dim, output_dim, layers, vocab_size=None, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layers = layers

        # Size the embedding from constructor arguments rather than a notebook global.
        num_embeddings = output_dim if vocab_size is None else vocab_size
        self.emb = nn.Embedding(num_embeddings, input_dim)
        self.dropout = nn.Dropout(dropout)
        self.lstm = LSTM(hidden_dim, input_dim, output_dim, layers)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, hidden):
        """Advance one time step.

        Args:
            x: integer token ids for this step (the training and sampling cells pass
               a 1-D batch of ids).
            hidden: ``(h, c)`` state from ``init_hidden`` or the previous call.

        Returns:
            ``(logits, hidden)``: the logits for this step (``output_dim`` per row)
            and the updated state to pass into the next call.
        """
        # Lookup table: token id -> dense vector.
        x = self.emb(x)

        # Randomly zero embedding entries during training (regularization); a no-op in eval().
        x = self.dropout(x)

        # One LSTM step; returns the emitted hidden state and the (h, c) pair for the next step.
        h_t, hidden = self.lstm(x, hidden)

        # Project the hidden state to vocabulary logits.
        out = self.fc(h_t)

        return out, hidden

    def init_hidden(self, batch_size):
        """Return zeroed ``(h, c)`` states shaped ``(batch_size, hidden_dim)``.

        The tensors are allocated on the same device and dtype as the model's
        parameters, so a model kept on the CPU gets CPU state even when a GPU exists.
        """
        weight = self.fc.weight
        return (weight.new_zeros(batch_size, self.hidden_dim),
                weight.new_zeros(batch_size, self.hidden_dim))

In [196]:
# Seed the global torch RNG first. Weight init, dropout, per-epoch offsets and
# sampling all draw from it, so this makes runs repeatable.
seed = 1337
torch.manual_seed(seed)

hidden_dim = 256          # LSTM hidden/cell state size
input_dim = 128           # character-embedding size fed to the LSTM
output_dim = vocab_size   # one logit per vocabulary character
layers = 1                # passed to CharLSTM (which builds a single LSTM cell)
seq_length = 100          # time steps per training chunk
num_epochs = 10           # NOTE: the training cell reassigns this before its loop
learning_rate = 0.002     # Adam step size
batch_size = 64           # sequences per batch

In [197]:
# Build the model on the target device, then its objective and optimizer.
model = CharLSTM(
    hidden_dim=hidden_dim,
    input_dim=input_dim,
    output_dim=output_dim,
    layers=layers,
).to(device)
loss_fn = nn.CrossEntropyLoss()   # takes raw logits and integer class targets
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [198]:
# Encode the whole corpus once; the first 10% becomes the validation split and the
# remaining 90% the training split.
encoded_text = torch.tensor(encode(text))
holdout = len(encoded_text) // 10
val_split, train_split = encoded_text[:holdout], encoded_text[holdout:]
# From here on, `encoded_text` refers to the training portion only.
encoded_text = train_split

In [None]:
# Training loop: truncated backpropagation-through-time over `batch_size` parallel,
# contiguous streams of the training text.
num_epochs = 100   # overrides the config cell; the sampling cell loads the epoch-99 checkpoint
grad_clip = 5.0    # max global gradient norm; BPTT through seq_length steps can blow up

for epoch in range(num_epochs):
    model.train()   # dropout on (a sampling/eval cell may have switched it off)
    total_loss = 0.0
    n = 0
    hidden = model.init_hidden(batch_size)

    # Random offset (< 1000 chars) so chunk boundaries shift from epoch to epoch.
    rand_start = torch.randint(0, 1000, (1,)).item()

    # Lay the offset training text out as `batch_size` rows of contiguous text:
    #   row b = characters [b*stream_len, (b+1)*stream_len)
    # Walking the columns in steps of `seq_length` means row b of one chunk continues
    # exactly where row b of the previous chunk stopped. Carrying `hidden` across chunks
    # is therefore valid, and every character is visited once per epoch.
    data = encoded_text[rand_start:]
    stream_len = max((len(data) - 1) // batch_size, 0)   # -1 leaves room for shifted targets
    inputs = data[:batch_size * stream_len].view(batch_size, stream_len)
    targets = data[1:batch_size * stream_len + 1].view(batch_size, stream_len)

    for i in range(0, stream_len - seq_length + 1, seq_length):
        # Input and next-character target chunks, each (batch_size, seq_length)
        x = inputs[:, i:i + seq_length].to(device)
        y = targets[:, i:i + seq_length].to(device)

        # Cut the graph once per chunk: gradients flow through all seq_length steps
        # of this chunk, but not back into earlier chunks.
        hidden = tuple(state.detach() for state in hidden)

        step_logits = []
        for j in range(seq_length):
            out, hidden = model(x[:, j], hidden)
            step_logits.append(out)
        y_pred = torch.stack(step_logits, dim=1)   # (batch_size, seq_length, vocab_size)

        # reshape (not view): column slices of `inputs`/`targets` are non-contiguous.
        loss = loss_fn(y_pred.reshape(-1, vocab_size), y.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        total_loss += loss.item()
        n += 1

        if n % 100 == 0:
            print(f'Epoch {epoch+1} Iter {n} Loss {loss.item()}')

    torch.save(model.state_dict(), 'CharRNN_shakespeare'+str(epoch)+'.pth')

    # max(n, 1) avoids a ZeroDivisionError if the text is shorter than one chunk.
    print(f'* Epoch {epoch+1} Total Aggregate Loss {total_loss/max(n, 1)}')

In [None]:
# Total number of scalar values across every parameter tensor of the model.
total_params = sum(map(torch.numel, model.parameters()))
print(f'Total Parameters: {total_params}')

In [None]:
# Reload the final checkpoint and generate 1000 characters autoregressively.
# The model, its state and the seed character all live on `device`, so there is
# no CPU/GPU mismatch.
model = CharLSTM(hidden_dim, input_dim, output_dim, layers).to(device)
model.load_state_dict(torch.load('CharRNN_shakespeare99.pth', map_location=device))
# eval() disables dropout, which would otherwise perturb every sampling step.
model.eval()

# Sampling
with torch.no_grad():
    hidden = model.init_hidden(1)
    # Seed with the first character of the training split.
    x = encoded_text[0].unsqueeze(0).to(device)
    sampled_ids = []
    for _ in range(1000):
        logits, hidden = model(x, hidden)

        # Softmax to convert logits to probabilities
        probs = F.softmax(logits, dim=1)

        # Sample the next character id and feed it back in as the next input.
        ix = torch.multinomial(probs[0], 1)
        sampled_ids.append(ix.item())
        x = ix

    out = decode(sampled_ids)
    print(out)
