In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import pprint

In [2]:
# Run on the GPU when PyTorch reports a usable CUDA device; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Report which device was selected.
print(f"Using device: {device}")

Using device: cpu


In [4]:
# Pin the encoding so the result does not depend on the platform's locale default.
# NOTE(review): the preview of `text` in this notebook begins with "{\rtf1\ansi...",
# so this file appears to be an RTF export rather than plain text; the RTF control
# words and trailing backslashes end up in the tokenizer vocabulary. Consider
# replacing it with the plain-text corpus.
text = Path('tiny-shakespeare.txt').read_text(encoding='utf-8')

In [6]:
# Preview the first 1000 characters of the loaded corpus.
print(text[:1000])

{\rtf1\ansi\ansicpg1252\cocoartf2822
\cocoatextscaling0\cocoaplatform0{\fonttbl\f0\fmodern\fcharset0 Courier;}
{\colortbl;\red255\green255\blue255;\red0\green0\blue0;}
{\*\expandedcolortbl;;\cssrgb\c0\c0\c0;}
\margl1440\margr1440\vieww20300\viewh13580\viewkind0
\deftab720
\pard\pardeftab720\partightenfactor0

\f0\fs26 \cf0 \expnd0\expndtw0\kerning0
\outl0\strokewidth0 \strokec2 First Citizen:\
Before we proceed any further, hear me speak.\
\
All:\
Speak, speak.\
\
First Citizen:\
You are all resolved rather to die than to famish?\
\
All:\
Resolved. resolved.\
\
First Citizen:\
First, you know Caius Marcius is chief enemy to the people.\
\
All:\
We know't, we know't.\
\
First Citizen:\
Let us kill him, and we'll have corn at our own price.\
Is't a verdict?\
\
All:\
No more talking on't; let it be done: away, away!\
\
Second Citizen:\
One word, good citizens.\
\
First Citizen:\
We are accounted poor citizens, the patricians good.\
What authority surfeits on would relieve us: if they\
wou

In [7]:
class CharTokenizer:
    """Character-level tokenizer: every distinct character is one token.

    Token ids are the positions of the characters in the vocabulary passed to
    the constructor, so the same vocabulary order always yields the same ids.
    """

    def __init__(self, vocabulary):
        """Build the two lookup tables from an iterable of characters.

        Args:
            vocabulary: iterable of single-character strings; the position of
                each character becomes its token id.
        """
        # Materialise once so one-shot iterables (e.g. generators) populate
        # both tables instead of being exhausted by the first pass.
        vocabulary = list(vocabulary)
        self.token_id_for_char = {
            char: token_id for token_id, char in enumerate(vocabulary)
        }
        self.char_for_token_id = dict(enumerate(vocabulary))

    @classmethod
    def train_from_text(cls, text):
        """Create a tokenizer whose vocabulary is the sorted set of characters in `text`.

        Implemented as a classmethod so subclasses get an instance of their
        own type; calling it on `CharTokenizer` behaves as before.
        """
        return cls(sorted(set(text)))

    def encode(self, text):
        """Convert a string into a 1-D `torch.long` tensor of token ids.

        Raises:
            KeyError: if `text` contains a character that is not in the vocabulary.
        """
        try:
            token_ids = [self.token_id_for_char[char] for char in text]
        except KeyError as error:
            # Same exception type as a raw dict lookup, but says what went wrong.
            raise KeyError(
                f"character {error.args[0]!r} is not in the tokenizer vocabulary"
            ) from None
        return torch.tensor(token_ids, dtype=torch.long)

    def decode(self, token_ids):
        """Convert token ids back into a string.

        Args:
            token_ids: a 1-D tensor of ids, a 0-d tensor holding a single id,
                or any iterable of ints (or a single int).

        Raises:
            KeyError: if an id is not in the vocabulary.
        """
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        # A 0-d tensor's `.tolist()` yields a bare int; treat it as one token.
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        return "".join(self.char_for_token_id[token_id] for token_id in token_ids)

    def vocabulary_size(self):
        """Return the number of distinct characters known to the tokenizer."""
        return len(self.token_id_for_char)

In [8]:
# Build a character-level tokenizer whose vocabulary is the sorted set of characters in `text`.
tokenizer = CharTokenizer.train_from_text(text)

In [9]:
# Sanity check: encode a sample string into a tensor of token ids.
print(tokenizer.encode("Hello world"))

tensor([29, 53, 60, 60, 63,  1, 71, 63, 66, 60, 52])


In [10]:
# Round-trip check: decoding the encoded string should print the original text.
print(tokenizer.decode(tokenizer.encode("Hello world")))

Hello world


In [11]:
# Number of distinct characters (tokens) in the tokenizer's vocabulary.
tokenizer.vocabulary_size()

77

In [13]:
# Pretty-printer used to dump the tokenizer's lookup tables; `depth=4` limits how
# many levels of nesting are expanded.
pp = pprint.PrettyPrinter(depth=4)

In [14]:
# Show the token id -> character lookup table.
pp.pprint(tokenizer.char_for_token_id)

{0: '\n',
 1: ' ',
 2: '!',
 3: '$',
 4: '&',
 5: "'",
 6: '*',
 7: ',',
 8: '-',
 9: '.',
 10: '0',
 11: '1',
 12: '2',
 13: '3',
 14: '4',
 15: '5',
 16: '6',
 17: '7',
 18: '8',
 19: ':',
 20: ';',
 21: '?',
 22: 'A',
 23: 'B',
 24: 'C',
 25: 'D',
 26: 'E',
 27: 'F',
 28: 'G',
 29: 'H',
 30: 'I',
 31: 'J',
 32: 'K',
 33: 'L',
 34: 'M',
 35: 'N',
 36: 'O',
 37: 'P',
 38: 'Q',
 39: 'R',
 40: 'S',
 41: 'T',
 42: 'U',
 43: 'V',
 44: 'W',
 45: 'X',
 46: 'Y',
 47: 'Z',
 48: '\\',
 49: 'a',
 50: 'b',
 51: 'c',
 52: 'd',
 53: 'e',
 54: 'f',
 55: 'g',
 56: 'h',
 57: 'i',
 58: 'j',
 59: 'k',
 60: 'l',
 61: 'm',
 62: 'n',
 63: 'o',
 64: 'p',
 65: 'q',
 66: 'r',
 67: 's',
 68: 't',
 69: 'u',
 70: 'v',
 71: 'w',
 72: 'x',
 73: 'y',
 74: 'z',
 75: '{',
 76: '}'}


In [15]:
# Show the character -> token id lookup table (the inverse of `char_for_token_id`).
pp.pprint(tokenizer.token_id_for_char)

{'\n': 0,
 ' ': 1,
 '!': 2,
 '$': 3,
 '&': 4,
 "'": 5,
 '*': 6,
 ',': 7,
 '-': 8,
 '.': 9,
 '0': 10,
 '1': 11,
 '2': 12,
 '3': 13,
 '4': 14,
 '5': 15,
 '6': 16,
 '7': 17,
 '8': 18,
 ':': 19,
 ';': 20,
 '?': 21,
 'A': 22,
 'B': 23,
 'C': 24,
 'D': 25,
 'E': 26,
 'F': 27,
 'G': 28,
 'H': 29,
 'I': 30,
 'J': 31,
 'K': 32,
 'L': 33,
 'M': 34,
 'N': 35,
 'O': 36,
 'P': 37,
 'Q': 38,
 'R': 39,
 'S': 40,
 'T': 41,
 'U': 42,
 'V': 43,
 'W': 44,
 'X': 45,
 'Y': 46,
 'Z': 47,
 '\\': 48,
 'a': 49,
 'b': 50,
 'c': 51,
 'd': 52,
 'e': 53,
 'f': 54,
 'g': 55,
 'h': 56,
 'i': 57,
 'j': 58,
 'k': 59,
 'l': 60,
 'm': 61,
 'n': 62,
 'o': 63,
 'p': 64,
 'q': 65,
 'r': 66,
 's': 67,
 't': 68,
 'u': 69,
 'v': 70,
 'w': 71,
 'x': 72,
 'y': 73,
 'z': 74,
 '{': 75,
 '}': 76}
