In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # `F` is the conventional alias for torch.nn.functional (cross_entropy, softmax, ...), not torch.functional
import matplotlib.pyplot as plt


# Preprocessing

In [2]:
# Hyperparameters shared by the cells below.
context_length = 32  # length (in tokens) of each window sliced out by get_batch
batch_size = 16      # number of windows sampled per batch
lr = 1e-3            # presumably the optimiser learning rate; not used in the visible cells
num_steps = 1000     # presumably the number of training steps; not used in the visible cells

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [3]:
# Read the whole corpus into memory, then lower-case it so that the
# vocabulary built from `data` contains no upper-case characters.
with open('input.txt', encoding='utf-8') as f:
    raw_text = f.read()
data = raw_text.lower()
len(data)

1115394

In [4]:
# Character vocabulary: every distinct character of the corpus in sorted
# order, followed by a catch-all '<UNK>' entry appended last.
# '3' is excluded from the vocabulary, so encode() maps it to the <UNK> id.
# A set difference is used instead of list.remove('3') so the cell does not
# raise ValueError when the corpus happens not to contain a '3'.
chars = sorted(set(data) - {'3'})
chars.append('<UNK>')
print(chars)
len(chars)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', ':', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '<UNK>']


39

In [5]:
# id -> character: each entry of `chars` is keyed by its position in the list.
vocab = dict(enumerate(chars))

# character -> id: the inverse mapping, used when encoding text.
rev_vocab = {ch: idx for idx, ch in vocab.items()}
rev_vocab

{'\n': 0,
 ' ': 1,
 '!': 2,
 '$': 3,
 '&': 4,
 "'": 5,
 ',': 6,
 '-': 7,
 '.': 8,
 ':': 9,
 ';': 10,
 '?': 11,
 'a': 12,
 'b': 13,
 'c': 14,
 'd': 15,
 'e': 16,
 'f': 17,
 'g': 18,
 'h': 19,
 'i': 20,
 'j': 21,
 'k': 22,
 'l': 23,
 'm': 24,
 'n': 25,
 'o': 26,
 'p': 27,
 'q': 28,
 'r': 29,
 's': 30,
 't': 31,
 'u': 32,
 'v': 33,
 'w': 34,
 'x': 35,
 'y': 36,
 'z': 37,
 '<UNK>': 38}

In [6]:
def encode(text):
    """Convert a string into a list of integer token ids, one per character.

    Args:
        text: iterable of single characters (normally a ``str``).

    Returns:
        A list of ints. Characters that are not keys of ``rev_vocab`` are
        mapped to the id of the ``'<UNK>'`` token. Empty input gives ``[]``.
    """
    # Look the fallback id up instead of hard-coding 38, so the function stays
    # correct if the vocabulary size changes.
    unk_id = rev_vocab['<UNK>']
    return [rev_vocab.get(ch, unk_id) for ch in text]


# Sanity check: '1' is not in the vocabulary, so the last id is the <UNK> id.
encode('hello1')

[19, 16, 23, 23, 26, 38]

In [7]:
def decode(nums):
    """Convert a sequence of token ids back into a string.

    Args:
        nums: iterable of ids. Elements may be plain ints or scalar objects
            exposing ``.item()`` (e.g. the 0-d tensors produced by iterating
            over a 1-D tensor).

    Returns:
        The concatenation of ``vocab[id]`` for every id; ids that are not keys
        of ``vocab`` contribute the literal text ``'<UNK>'``. Empty input
        gives ``''``.
    """
    pieces = []
    for i in nums:
        # 0-d tensors hash by identity and would never match the integer keys
        # of `vocab`, so unwrap them to a Python number first.
        key = i.item() if hasattr(i, 'item') else i
        pieces.append(vocab.get(key, '<UNK>'))
    # Join once at the end rather than growing a string with += in the loop.
    return ''.join(pieces)


# Sanity check: decode a handful of ids back into characters.
decode([34,  7, 31, 29, 2])

'w-tr!'

In [8]:
# Encode the full corpus as a 1-D tensor of token ids, then keep the first
# 90% for training and hold out the remaining 10% for testing.
encoded_data = torch.tensor(encode(data), dtype=torch.long)
split_at = int(0.9 * len(encoded_data))
train_data, test_data = encoded_data[:split_at], encoded_data[split_at:]

len(train_data), len(test_data)

(1003854, 111540)

In [9]:
def get_batch(split='train'):
    """Sample a random batch of fixed-length windows from one data split.

    Args:
        split: ``'train'`` to sample from ``train_data`` or ``'test'`` to
            sample from ``test_data``.

    Returns:
        For ``'train'``: a tuple ``(x, y)`` of tensors, each stacking
        ``batch_size`` windows of ``context_length`` ids, where ``y`` is ``x``
        shifted one position to the right (the next-token targets).
        For ``'test'``: only ``x``.
        Tensors are moved to ``device``.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'test'``.
            Previously such a value silently sampled from the test data and
            returned ``None``.
    """
    if split not in ('train', 'test'):
        raise ValueError(f"split must be 'train' or 'test', got {split!r}")

    # Named `source` rather than `data` to avoid shadowing the global corpus string.
    source = train_data if split == 'train' else test_data

    # randint's upper bound is exclusive; the bound leaves room after every
    # start offset for a full window plus the one extra target token.
    starts = torch.randint(0, len(source) - context_length - 1, (batch_size,)).tolist()

    outx = torch.stack([source[i:i + context_length] for i in starts]).to(device)
    if split == 'test':
        return outx

    outy = torch.stack([source[i + 1:i + context_length + 1] for i in starts]).to(device)
    return outx, outy

# Sanity check: draw one training batch and display the (inputs, targets) pair.
get_batch('train')

(tensor([[17, 20, 23, 23,  5, 15,  6,  1, 26, 32, 29,  1, 13, 23, 26, 26, 15,  1,
          20, 30,  1, 14, 26, 23, 15,  6,  1, 12, 25, 15,  1, 31],
         [19, 16,  1, 34, 16, 12, 29, 36,  1, 30, 32, 25,  1, 19, 12, 31, 19,  1,
          24, 12, 15, 16,  1, 12,  1, 18, 26, 23, 15, 16, 25,  1],
         [ 1, 20, 25,  1, 18, 29, 16, 16, 14, 16,  6,  7,  7,  0,  0, 24, 16, 25,
          16, 25, 20, 32, 30,  9,  0, 34, 16, 23, 23,  6,  1, 34],
         [26, 17,  1, 36, 26, 29, 22,  9,  0,  0, 19, 16, 25, 29, 36,  1, 13, 26,
          23, 20, 25, 18, 13, 29, 26, 22, 16,  9,  0, 33, 20, 23],
         [ 1, 17, 26, 29,  1, 31, 19, 16,  1, 30, 16, 23, 17, 30, 12, 24, 16,  1,
          19, 16, 12, 33, 16, 25,  0, 31, 19, 12, 31,  1, 17, 29],
         [ 6,  1, 12, 25, 15,  1, 15, 16, 12, 31, 19,  1, 31, 19, 36,  1, 23, 20,
          17, 16,  2,  0,  0, 18, 23, 26, 32, 14, 16, 30, 31, 16],
         [12, 25, 14, 16,  8,  0,  0, 22, 20, 25, 18,  1, 29, 20, 14, 19, 12, 29,
          15,  1, 20, 20

# Model

# Training

# Inference