In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
from pathlib import Path

# Load the whole corpus into memory as a single string, decoded as UTF-8.
# The path is relative, so it is resolved against the kernel's working directory.
# NOTE(review): the preview printed in the next cell shows that the file opens with
# a table of contents of play titles, so this front matter becomes training data
# too -- confirm that is intended.
text = Path('../../tiny-shakespeare.txt').read_text(encoding='utf-8')

In [7]:
# Preview the first 1000 characters to sanity-check what was loaded.
print(text[:1000])

    THE SONNETS
    ALL’S WELL THAT ENDS WELL
    THE TRAGEDY OF ANTONY AND CLEOPATRA
    AS YOU LIKE IT
    THE COMEDY OF ERRORS
    THE TRAGEDY OF CORIOLANUS
    CYMBELINE
    THE TRAGEDY OF HAMLET, PRINCE OF DENMARK
    THE FIRST PART OF KING HENRY THE FOURTH
    THE SECOND PART OF KING HENRY THE FOURTH
    THE LIFE OF KING HENRY THE FIFTH
    THE FIRST PART OF HENRY THE SIXTH
    THE SECOND PART OF KING HENRY THE SIXTH
    THE THIRD PART OF KING HENRY THE SIXTH
    KING HENRY THE EIGHTH
    THE LIFE AND DEATH OF KING JOHN
    THE TRAGEDY OF JULIUS CAESAR
    THE TRAGEDY OF KING LEAR
    LOVE’S LABOUR’S LOST
    THE TRAGEDY OF MACBETH
    MEASURE FOR MEASURE
    THE MERCHANT OF VENICE
    THE MERRY WIVES OF WINDSOR
    A MIDSUMMER NIGHT’S DREAM
    MUCH ADO ABOUT NOTHING
    THE TRAGEDY OF OTHELLO, THE MOOR OF VENICE
    PERICLES, PRINCE OF TYRE
    KING RICHARD THE SECOND
    KING RICHARD THE THIRD
    THE TRAGEDY OF ROMEO AND JULIET
    THE TAMING OF THE SHREW
    THE TEMPEST
    

In [8]:

class CharTokenizer:
  """Character-level tokenizer mapping each distinct character to an integer id.

  Ids are assigned by position in the vocabulary: the first character gets
  id 0, the next one id 1, and so on.
  """

  def __init__(self, vocabulary):
    """Build the lookup tables for both directions.

    Args:
      vocabulary: iterable of characters (list, string, generator, ...).
        Characters that repeat are ignored after their first occurrence so
        that ids stay contiguous and `vocabulary_size()` always equals the
        largest id plus one.
    """
    # Materialise once: this removes duplicates while keeping first-seen order,
    # and it also makes one-shot iterables (generators) safe, since both tables
    # are built from the same list instead of iterating the argument twice.
    unique_chars = list(dict.fromkeys(vocabulary))
    self.token_id_for_char = {char: token_id for token_id, char in enumerate(unique_chars)}
    self.char_for_token_id = dict(enumerate(unique_chars))

  @staticmethod
  def train_from_text(text):
    """Create a tokenizer whose vocabulary is every distinct character of `text`.

    The characters are sorted, so the same text always yields the same ids.
    """
    return CharTokenizer(sorted(set(text)))

  def encode(self, text):
    """Convert a string into a 1-D `torch.long` tensor with one id per character.

    Raises:
      KeyError: if `text` contains a character that is not in the vocabulary.
    """
    try:
      token_ids = [self.token_id_for_char[char] for char in text]
    except KeyError as err:
      # Same exception type as a plain dict lookup, but say what went wrong.
      raise KeyError(f"character {err.args[0]!r} is not in the tokenizer vocabulary") from None
    return torch.tensor(token_ids, dtype=torch.long)

  def decode(self, token_ids):
    """Convert token ids back into a string.

    Args:
      token_ids: a 1-D tensor of ids, or any other iterable of ids such as a
        plain list or tuple. A 0-d tensor is treated as a single id.

    Raises:
      KeyError: if an id has no character in the vocabulary.
    """
    if hasattr(token_ids, 'tolist'):
      ids = token_ids.tolist()
    else:
      ids = list(token_ids)
    if not isinstance(ids, list):
      # tolist() on a 0-d tensor returns a bare scalar rather than a list.
      ids = [ids]
    return ''.join(self.char_for_token_id[token_id] for token_id in ids)

  def vocabulary_size(self):
    """Return the number of distinct characters (equivalently, of distinct ids)."""
    return len(self.token_id_for_char)

In [9]:
# Build the vocabulary from every distinct character that occurs in the corpus.
tokenizer = CharTokenizer.train_from_text(text)

In [10]:
# Round-trip check: encode a sample string to ids, then decode the ids back to text.
print(tokenizer.encode("Hello world"))
print(tokenizer.decode(tokenizer.encode("Hello world")))

tensor([31, 57, 64, 64, 67,  2, 75, 67, 70, 64, 56])
Hello world


In [11]:
# Number of distinct characters the tokenizer knows (one id per character).
print(f"Vocabulary size: {tokenizer.vocabulary_size()}")

Vocabulary size: 98


In [12]:
from torch.utils.data import Dataset

class TokenIdsDataset(Dataset):
  """Sliding-window dataset over a sequence of token ids.

  Item `pos` is the pair `(x, y)`, where `x` holds `block_size` consecutive
  ids starting at `pos` and `y` is the same window moved one position to the
  right, so `y[i]` is the id that follows `x[i]` in `data`.
  """

  def __init__(self, data, block_size):
    """Store the token sequence and the window length.

    Args:
      data: sliceable sequence of token ids that supports `len()`.
      block_size: number of ids in each `x` and each `y` window.
    """
    self.data = data
    self.block_size = block_size

  def __len__(self):
    """Return how many start positions yield a full `x` and a full `y`."""
    # `y` ends at pos + block_size, so the last usable start position is
    # len(data) - block_size - 1. Clamp at zero: when the data is shorter than
    # one block the difference is negative and len() would raise ValueError.
    return max(0, len(self.data) - self.block_size)

  def __getitem__(self, pos):
    """Return the `(x, y)` window pair starting at `pos`.

    Negative positions count from the end, as with built-in sequences.

    Raises:
      IndexError: if `pos` is outside the dataset. IndexError (rather than a
        bare assert) is what lets `for item in dataset` and `list(dataset)`
        stop cleanly, and the check is not stripped under `python -O`.
    """
    size = len(self)
    start = pos
    if start < 0:
      # Not `+=`: that would modify a caller-owned tensor index in place.
      start = start + size
    if not 0 <= start < size:
      raise IndexError(f"position {pos} is out of range for a dataset of {size} items")

    x = self.data[start:start + self.block_size]
    y = self.data[start + 1:start + 1 + self.block_size]
    return x, y

In [13]:
# Encode the full corpus once, then expose it as overlapping 64-token windows.
tokenized_text = tokenizer.encode(text)
dataset = TokenIdsDataset(tokenized_text, block_size=64)

In [14]:
# First example: x is tokens [0, 64) of the corpus and y is tokens [1, 65).
x, y = dataset[0]

In [15]:
x

tensor([ 2,  2,  2,  2, 43, 31, 28,  2, 42, 38, 37, 37, 28, 43, 42,  1,  2,  2,
         2,  2, 24, 35, 35, 94, 42,  2, 46, 28, 35, 35,  2, 43, 31, 24, 43,  2,
        28, 37, 27, 42,  2, 46, 28, 35, 35,  1,  2,  2,  2,  2, 43, 31, 28,  2,
        43, 41, 24, 30, 28, 27, 48,  2, 38, 29])

In [16]:
tokenizer.decode(x)

'    THE SONNETS\n    ALL’S WELL THAT ENDS WELL\n    THE TRAGEDY OF'

In [17]:
from torch.utils.data import DataLoader, RandomSampler

# Pick window start positions at random, with replacement, and batch them two at a time.
# NOTE(review): no generator is passed and no seed is set in the visible cells, so
# the sampled batches (and the outputs shown below) will differ from run to run.
sampler = RandomSampler(dataset, replacement=True)
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)

In [18]:
# Pull a single batch. This rebinds x and y (previously one example from the
# dataset) to batched tensors with a leading dimension of batch_size=2.
x, y = next(iter(dataloader))

In [19]:
x.shape

torch.Size([2, 64])

In [20]:
x

tensor([[ 2, 53,  2, 65, 67, 66, 71, 72, 57, 70, 10,  1,  1, 38, 35, 27,  2, 35,
         24, 27, 48, 10,  1, 31, 57, 53, 70, 72, 71,  2, 67, 58,  2, 65, 67, 71,
         72,  2, 60, 53, 70, 56,  2, 72, 57, 65, 68, 57, 70,  1, 36, 57, 64, 72,
          2, 53, 66, 56,  2, 64, 53, 65, 57, 66],
        [70, 23,  2, 48, 67, 73,  2, 56, 67,  2, 71, 73, 70, 57, 64, 77,  2, 54,
         53, 70,  2, 72, 60, 57,  1, 56, 67, 67, 70,  2, 73, 68, 67, 66,  2, 77,
         67, 73, 70,  2, 67, 75, 66,  2, 64, 61, 54, 57, 70, 72, 77,  2, 61, 58,
          2, 77, 67, 73,  2, 56, 57, 66, 77,  2]])

In [21]:
tokenizer.decode(x[0])

' a monster.\n\nOLD LADY.\nHearts of most hard temper\nMelt and lamen'

In [22]:
# The targets are the inputs shifted one character ahead: compare this with the
# decoded x[0] in the previous cell.
tokenizer.decode(y[0])

'a monster.\n\nOLD LADY.\nHearts of most hard temper\nMelt and lament'