In [11]:
import torch
import torch.nn as nn
import torch.functional as F
import matplotlib.pyplot as plt


# Preprocessing

In [105]:
batch_size = 16
num_steps = 1000
lr = 1e-3
context_length = 32
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [40]:
with open('input.txt', 'r', encoding='utf-8' ) as f:
    data = f.read().lower()
len(data)

1115394

In [50]:
chars = sorted(list(set(data)))
chars.remove('3')
chars.append('<UNK>')
print(chars)
len(chars)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', ':', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '<UNK>']


39

In [54]:
vocab = {}
for t in range(len(chars)):
    vocab[t] = chars[t]

rev_vocab = {v:k for k,v in vocab.items()}
rev_vocab

{'\n': 0,
 ' ': 1,
 '!': 2,
 '$': 3,
 '&': 4,
 "'": 5,
 ',': 6,
 '-': 7,
 '.': 8,
 ':': 9,
 ';': 10,
 '?': 11,
 'a': 12,
 'b': 13,
 'c': 14,
 'd': 15,
 'e': 16,
 'f': 17,
 'g': 18,
 'h': 19,
 'i': 20,
 'j': 21,
 'k': 22,
 'l': 23,
 'm': 24,
 'n': 25,
 'o': 26,
 'p': 27,
 'q': 28,
 'r': 29,
 's': 30,
 't': 31,
 'u': 32,
 'v': 33,
 'w': 34,
 'x': 35,
 'y': 36,
 'z': 37,
 '<UNK>': 38}

In [70]:
def encode(text):
    out = []
    for t in text:
        out.append(rev_vocab.get(t, 38)) # 38 is the <UNK>
    return out


encode('hello1')

[19, 16, 23, 23, 26, 38]

In [72]:
def decode(nums):
    out = ''
    for i in nums:
        out += vocab.get(i, '<UNK>')
    return out


decode([19, 16, 23, 23, 26, 38])

'hello<UNK>'

In [159]:
encoded_data = encode(data)
train_data = encoded_data[:int(0.9*len(encoded_data))]
test_data = encoded_data[int(0.9*len(encoded_data)):]

len(train_data), len(test_data)

(1003854, 111540)

In [156]:
def get_batch(split='train'):
    start = torch.randint(0,len(encoded_data)-context_length-1,(batch_size,))
    print(start)
    outx, outy = [], []
    for i in start:
        outx.append(encoded_data[i:i+context_length])
        if split == 'train':
            outy.append(encoded_data[i+context_length])
    if split == 'train': return torch.tensor(outx), torch.tensor(outy)
    return torch.tensor(outx)

get_batch('test')

tensor([748845, 457334, 726204, 424344, 734603, 103634, 545085,  62470, 328412,
         48875, 942474, 108661,  10514, 394508, 247688, 140337])


tensor([[ 1, 24, 12, 25, 22, 20, 25, 15,  1, 34, 20, 31, 14, 19,  2,  1, 19, 16,
         25, 14, 16,  1, 34, 20, 31, 19,  1, 19, 16, 29,  6,  1],
        [ 1, 15, 20, 16,  6,  0, 31, 29, 12, 25, 30, 27, 12, 29, 16, 25, 31,  1,
         19, 16, 29, 16, 31, 20, 14, 30,  6,  1, 13, 16,  1, 13],
        [25, 18,  1, 31, 26,  1, 31, 19, 16, 24,  1, 13, 26, 31, 19,  9,  1, 34,
         16, 29, 16,  1, 24, 36,  1, 34, 20, 17, 16,  5, 30,  1],
        [22,  9,  0, 34, 16, 23, 14, 26, 24, 16,  6,  1, 24, 36,  1, 30, 26, 25,
          9,  1, 34, 19, 26,  1, 12, 29, 16,  1, 31, 19, 16,  1],
        [ 1, 30, 31, 20, 29, 30,  1, 12, 24, 26, 25, 18, 30, 31,  1, 36, 26, 32,
         11,  1, 14, 26, 24, 16,  6,  1, 30, 20, 29,  6,  1, 25],
        [31, 26,  1, 19, 16, 12, 29,  1, 26, 17,  1, 31, 19, 16, 20, 29,  1, 29,
         16, 12, 15, 20, 25, 16, 30, 30,  6,  1, 12, 25, 15,  1],
        [21, 32, 23, 20, 16, 31,  9,  0, 18, 20, 33, 16,  1, 24, 16,  6,  1, 18,
         20, 33, 16,  1, 24, 16,  2, 

# Model

# Training

# Inference