<a href="https://colab.research.google.com/github/ardapekis/colab/blob/master/language_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Installs PyTorch with pip. The command below is active as written; comment it
# out if torch is already available in this runtime.
# GPU Version
!pip3 install torch

In [0]:
# Run this if you need to download the training data

# Download and clean Karl Marx
!wget https://www.gutenberg.org/files/61/61.txt
with open('61.txt', 'r') as f:
  lines = f.readlines()
start_index = lines.index("MANIFESTO OF THE COMMUNIST PARTY\n")
end_index = lines.index("           WORKING MEN OF ALL COUNTRIES, UNITE!\n")
with open('marx.txt', 'w') as f:
  f.writelines(lines[start_index:end_index+1])

# Download and clean Confucius
!wget http://www.gutenberg.org/cache/epub/3330/pg3330.txt
with open('pg3330.txt', 'r') as f:
  lines = f.readlines()
start_index = lines.index("CONFUCIAN ANALECTS.\n")
end_index = lines.index("know men.'\n")
with open('confucius.txt', 'w') as f:
  f.writelines(lines[start_index:end_index+1])
  
# Download and clean Robert Frost
!wget http://www.gutenberg.org/cache/epub/3021/pg3021.txt
with open('pg3021.txt', 'r') as f:
  lines = f.readlines()
start_index = lines.index("A BOY'S WILL\n")
end_index = lines.index("    Of a love or a season?\n")
with open("frost.txt", 'w') as f:
  f.writelines(lines[start_index:end_index+1])
  
# Download and clean Sun Tzu
!wget http://www.gutenberg.org/cache/epub/44024/pg44024.txt
with open('pg44024.txt', 'r') as f:
  lines = f.readlines()
start_index = lines.index("THE ARTICLES OF SUNTZU\n")
end_index = lines.index("Wei shook the heavens.\n")
with open('suntzu.txt', 'w') as f:
  f.writelines(lines[start_index:end_index+1])

In [0]:
# Load the dataset
# Each cleaned book is read as text and converted to ASCII bytes; characters
# that cannot be encoded as ASCII are dropped by the 'ignore' error handler.
authors = ["marx", "confucius", "frost", "suntzu"]
corpus_paths = ['marx.txt', 'confucius.txt', 'frost.txt', 'suntzu.txt']

texts = []
for corpus_path in corpus_paths:
  with open(corpus_path, 'r') as f:
    texts.append(f.read().encode('ascii', 'ignore'))

# texts[i] is the book written by authors[i]; the individual names stay bound too.
marx, confucius, frost, suntzu = texts

In [0]:
#Helper code for word-level language modeling
class Dictionary(object):
    """Vocabulary that hands out consecutive integer ids to words.

    ``idx2word[i]`` is the word with id ``i`` and ``word2idx[word]`` is that
    word's id. Ids start at 0 and follow the order in which words were first
    added.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        """Register ``word`` if it has not been seen yet and return its id."""
        try:
            return self.word2idx[word]
        except KeyError:
            pass
        # New word: its id is the position it takes at the end of idx2word.
        self.idx2word.append(word)
        new_id = len(self.idx2word) - 1
        self.word2idx[word] = new_id
        return new_id

    def __len__(self):
        """Return the number of distinct words registered so far."""
        return len(self.idx2word)
      
class Corpus(object):
    """Word-level view of several texts that share one vocabulary.

    Attributes:
        dictionary: the ``Dictionary`` that maps every word seen in any of the
            texts to an integer id.
        texts: one 1-D ``torch.long`` tensor of word ids per input text, in the
            order the texts were given.

    ``tokenize`` uses ``torch``, so ``torch`` has to be imported before a
    ``Corpus`` is built.
    """

    def __init__(self, texts):
        self.dictionary = Dictionary()
        self.texts = [self.tokenize(text) for text in texts]

    def tokenize(self, text):
        """Split ``text`` on whitespace and return its word ids.

        Words that are not yet in ``self.dictionary`` are added to it.

        Args:
            text: a ``str`` or ``bytes`` object; it only needs ``split()``.

        Returns:
            A 1-D ``torch.long`` tensor with one id per word, in text order.
            An empty text gives an empty tensor.
        """
        # add_word returns the id of the word whether it is new or already
        # known, so a single pass both grows the vocabulary and encodes the text.
        ids = [self.dictionary.add_word(word) for word in text.split()]
        return torch.tensor(ids, dtype=torch.long)

In [0]:
from random import randint
from timeit import default_timer as timer

import multiprocessing

from tqdm import tqdm

import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn

In [0]:
def get_batch(params):
  """Draw random fixed-length windows from one author's text.

  Args:
    params: a 4-tuple ``(text, author, seq_len, batch_size)``. It is a single
      argument so the function can be used directly with ``Pool.map``.

  Returns:
    A list of ``batch_size`` pairs ``(author, window)`` where ``window`` is a
    slice of ``text`` of length ``seq_len`` starting at a uniformly random
    position.
  """
  text, author, seq_len, batch_size = params

  def sample_window():
    # randint includes both ends, so the last legal start is len(text) - seq_len.
    begin = randint(0, len(text) - seq_len)
    return (author, text[begin:begin + seq_len])

  return [sample_window() for _ in range(batch_size)]

In [0]:
class CharLSTM(nn.Module):
  """Character-level LSTM language model conditioned on the text's author.

  The recurrence is written out by hand with one linear layer per gate. At
  every step the gates read the author embedding, a character embedding and the
  previous hidden state, and the resulting hidden state is mapped to one score
  per character.
  """

  def __init__(self, num_authors, author_emb_dim, num_characters, character_emb_dim, hidden_size, output_size):
    super().__init__()
    # With forced=False the scores are fed back through the character
    # embedding, so the output width must equal the character vocabulary size.
    assert num_characters == output_size

    self.num_authors = num_authors
    self.author_emb_dim = author_emb_dim
    self.num_characters = num_characters
    self.character_emb_dim = character_emb_dim
    self.hidden_size = hidden_size
    self.output_size = output_size

    # The inputs are one-hot style vectors, so the embeddings are linear layers.
    self.author_embedding = nn.Linear(num_authors, author_emb_dim)
    self.character_embedding = nn.Linear(num_characters, character_emb_dim)

    # Learned initial hidden and cell state, shared by every sequence in a batch.
    self.hidden_init = nn.Parameter(torch.randn([self.hidden_size]))
    self.cell_init = nn.Parameter(torch.randn([self.hidden_size]))

    # Every gate reads [author embedding, character embedding, previous hidden].
    gate_in_features = author_emb_dim + character_emb_dim + hidden_size
    self.forget_gate = nn.Linear(gate_in_features, hidden_size)
    self.input_gate = nn.Linear(gate_in_features, hidden_size)
    self.update_gate = nn.Linear(gate_in_features, hidden_size)
    self.output_gate = nn.Linear(gate_in_features, hidden_size)

    # Maps the hidden state to one score per character.
    self.output_mlp = nn.Linear(hidden_size, output_size)

  def forward(self, author, text, forced=True):
    """Run the LSTM over a batch of sequences.

    Args:
      author: one-hot / many-hot encoding of each text's author, shaped
        (batch, seq_length, num_authors).
      text: character encoding shaped (batch, seq_length, num_characters).
      forced: when True, step k reads character k from `text`. When False,
        step k instead reads the softmax of the previous entry of the output
        sequence (the seed entry for step 0).

    Returns:
      A (batch, seq_length + 1, output_size) tensor. Entry 0 is the seed
      `text[:, 0] + 1e-1`; entry k + 1 is the score vector produced by step k.
    """
    batch_size = author.shape[0]

    author_emb = self.author_embedding(author)
    char_emb = self.character_embedding(text)

    # Per-step inputs used when the characters are taken from `text`.
    step_inputs = torch.cat([author_emb, char_emb], 2)

    # Recurrent state, starting from the learned initial values.
    hidden = self.hidden_init.expand(batch_size, -1)
    cell = self.cell_init.expand(batch_size, -1)

    # The output sequence starts with a seed built from the first character.
    outputs = [text[:, 0] + 1e-1]

    for step in range(step_inputs.shape[1]):
      if forced:
        step_input = step_inputs[:, step]
      else:
        # Feed the model's own previous output back in as a distribution.
        fed_back = nn.functional.softmax(outputs[-1], 1)
        step_input = torch.cat([author_emb[:, step], self.character_embedding(fed_back)], 1)

      gate_input = torch.cat([step_input, hidden], 1)

      forget_vector = torch.sigmoid(self.forget_gate(gate_input))
      input_vector = torch.sigmoid(self.input_gate(gate_input))
      update_vector = torch.tanh(self.update_gate(gate_input))
      output_vector = torch.sigmoid(self.output_gate(gate_input))

      # Keep part of the old memory and add the gated candidate update.
      cell = input_vector * update_vector + forget_vector * cell

      # The hidden output is the cell memory seen through the output gate.
      hidden = output_vector * torch.tanh(cell)

      # Translate the hidden state into character scores.
      outputs.append(self.output_mlp(hidden))

    return torch.stack(outputs, 1)
      

In [0]:
# Train on the GPU when one is visible and fall back to the CPU otherwise, so
# the notebook also runs on runtimes without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Cross-entropy between the model's character scores and the target character ids.
criterion = nn.CrossEntropyLoss()

# One worker process per author; the workers run get_batch to sample the text
# windows that make up each training batch.
pool_size = len(authors) * 1
pool = multiprocessing.Pool(pool_size)


def train_epoch(model, texts, authors, seq_len, batch_size, batches_per_epoch, force=True):
  """Train `model` on randomly sampled text windows for one epoch.

  Uses the notebook-level `device`, `criterion`, `pool` and `pool_size`.

  Args:
    model: the network to train; it is moved to `device`. It must expose
      `num_authors` and `num_characters` and accept
      `model(author_in, text_in, forced=...)`.
    texts: one text per author, indexed by author id. The samples are decoded
      with `np.frombuffer(..., dtype=np.uint8)`, so they are expected to be
      bytes-like.
    authors: author names, used only to label the sample printed at the end.
    seq_len: number of characters in every sampled window.
    batch_size: requested batch size; every worker contributes
      `batch_size // pool_size` windows.
    batches_per_epoch: number of optimisation steps to run.
    force: passed to the model as `forced` (read the true characters when
      True, the model's own previous output when False).

  Returns:
    The loss summed over all batches divided by `batches_per_epoch * seq_len`.
  """
  model = model.to(device)
  # A fresh optimizer is built on every call, so RMSprop's running statistics
  # start from scratch at the beginning of each epoch.
  optimizer = optim.RMSprop(model.parameters(), lr=2e-3, momentum=0.5)
  running_loss = 0.0
  samples_per_worker = batch_size // pool_size

  for _ in range(batches_per_epoch):
    # Grab the data. Use multiple workers; worker w samples from author
    # w % num_authors, and the chunks are concatenated in worker order.
    jobs = [(texts[worker % model.num_authors], worker % model.num_authors, seq_len, samples_per_worker) for worker in range(pool_size)]
    batch = [item for chunk in pool.map(get_batch, jobs) for item in chunk]

    # Format the data into one-hot vectors
    author_sample, text_samples = zip(*batch)
    author_one_hot = [torch.zeros([seq_len, model.num_authors]).to(device).scatter_(1, torch.LongTensor(seq_len * [[author]]).to(device), 1) for author in author_sample]
    author_in = torch.stack(author_one_hot, 0)
    text_one_hot = [torch.zeros([seq_len, model.num_characters]).to(device).scatter_(1, torch.LongTensor(np.frombuffer(text, dtype=np.uint8)).to(device).unsqueeze(1).clamp(max=model.num_characters-1), 1) for text in text_samples]
    text_in = torch.stack(text_one_hot, 0)

    # Predict
    optimizer.zero_grad()
    output = model(author_in, text_in, forced=force)

    # Calculate error
    # The model returns seq_len + 1 entries: entry 0 is the seed built from the
    # first character and entry k + 1 is the prediction made after reading
    # character k, i.e. the prediction for character k + 1. Dropping the seed
    # and the final entry (which predicts past the end of the window) leaves
    # predictions for characters 1 .. seq_len - 1, matching `target` below.
    output = output[:, 1:-1, :]
    # Targets are clamped exactly like the one-hot inputs above.
    target = torch.stack([torch.LongTensor(np.frombuffer(text, dtype=np.uint8)[1:]).to(device).clamp(max=model.num_characters-1) for text in text_samples], 0)
    loss = torch.tensor(0.0).to(device)
    for step in range(target.shape[1]):
      loss = loss + criterion(output[:, step], target[:, step])

    # Perform gradient descent
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 8)
    optimizer.step()

    running_loss += loss.item()

  # Show the predicted characters for the first sample of each author in the
  # last batch; worker chunks are samples_per_worker rows long.
  for author_index in range(len(authors)):
    print("{}: ".format(authors[author_index]) + "".join(list(map(chr, torch.argmax(output[author_index * samples_per_worker], 1).to(torch.uint8).tolist()))))

  return running_loss / (batches_per_epoch * seq_len)

In [0]:

# Build the model: 4 authors and 128 character classes.
model = CharLSTM(
  num_authors=4,
  author_emb_dim=32,
  num_characters=128,
  character_emb_dim=128,
  hidden_size=128,
  output_size=128,
)

# First phase: 16 epochs with train_epoch's default force=True, where the
# model reads the true characters at every step.
for epoch in range(16):
  print("Epoch #{}".format(epoch))
  start = timer()
  loss = train_epoch(model, texts, authors, seq_len=128, batch_size=1024, batches_per_epoch=16)
  print("Loss: {:.4f}".format(loss))
  print("Time: {:.1f} seconds".format(timer()-start))

In [0]:
# Second phase: 16 more epochs on the same model with force=False, where the
# model reads its own previous output instead of the true characters.
for epoch in range(16):
  print("Epoch #{}".format(epoch))
  start = timer()
  loss = train_epoch(model, texts, authors, seq_len=128, batch_size=1024, batches_per_epoch=16, force=False)
  print("Loss: {:.4f}".format(loss))
  print("Time: {:.1f} seconds".format(timer()-start))