In [51]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# --- Hyperparameters -------------------------------------------------------
batch_size = 1          # number of sequences drawn per batch
seq_len = 8             # number of characters per sequence
eval_iters = 200        # batches averaged per split when estimating the loss
learning_rate = 0.01    # step size handed to the optimizer

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# To force CPU execution, use: device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [52]:
# The corpus can be fetched with:
#   wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# and is read here from data/tinyshakespeare.txt.
with open('data/tinyshakespeare.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Vocabulary: every distinct character of the corpus, in sorted order.
# A character's position in this list is its integer id.
chars = sorted(set(text))
d_input = len(chars)  # vocabulary size

print(chars)
print(d_input)

def stoi1hot(c):
    """Encode character `c` as a one-hot NumPy vector.

    Returns an `np.int8` array of length `d_input` that is zero everywhere
    except for a 1 at the position of `c` in `chars`.

    Raises:
        ValueError: if `c` does not occur in `chars` (raised by `list.index`).
    """
    one_hot = np.zeros(d_input, dtype=np.int8)
    position = chars.index(c)
    one_hot[position] = 1
    return one_hot

def itos1hot(vec):
    """Decode a one-hot vector back into its character.

    Args:
        vec: 1-D one-hot vector, given either as a NumPy array (as produced
            by `stoi1hot`) or as a torch tensor.

    Returns:
        The entry of `chars` at the position where `vec` equals 1.

    Raises:
        RuntimeError: if `vec` does not contain exactly one element equal
            to 1 (raised by `.item()` on a tensor that is not single-element).
    """
    # `nonzero(as_tuple=True)` is a torch-only signature; normalising the
    # input first lets NumPy one-hot vectors round-trip too. Tensors pass
    # through `as_tensor` unchanged.
    vec = torch.as_tensor(vec)
    hot_positions = (vec == 1).nonzero(as_tuple=True)[0]
    return chars[hot_positions.item()]

# Character -> index table, built once so that encoding a long text costs one
# dictionary lookup per character instead of a linear `chars.index` scan.
_char_to_index = {ch: idx for idx, ch in enumerate(chars)}


def stoi(c):
    """Return the index of character `c` in `chars`.

    Raises:
        ValueError: if `c` is not part of the vocabulary.
    """
    try:
        return _char_to_index[c]
    except (KeyError, TypeError):
        # Keep the exception type that a failed `chars.index(c)` lookup raises.
        raise ValueError(f"{c!r} is not in list") from None


def itos(n):
    """Return the character stored at index `n` of `chars`."""
    return chars[n]


def encode(s):
    """Convert the string `s` into a 1-D `torch.long` tensor of character indices."""
    return torch.tensor([stoi(c) for c in s], dtype=torch.long)


def decode(m):
    """Convert an iterable of character indices `m` back into a string."""
    return ''.join(itos(i) for i in m)

# Round-trip check: show the index tensor for 'abc', then the string decoded from it.
code = encode('abc')
print(code, decode(code), sep='\n')


['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65
tensor([39, 40, 41])
abc


In [53]:
# Encode the whole corpus, then keep the first 90% of the indices for
# training and the remaining 10% for evaluation.
data = encode(text)
n_split = int(0.9 * len(data))
train_data, test_data = data[:n_split], data[n_split:]

def get_batch(mode):
    """Draw a random batch of input/target index sequences.

    Args:
        mode: 'train' samples from `train_data`; any other value samples
            from `test_data`.

    Returns:
        A pair `(x, y)` of tensors on `device`, each of shape
        `(batch_size, seq_len)`. Row `i` of `y` is row `i` of `x` shifted one
        position forward in the source, i.e. the next-character targets.
    """
    if mode == 'train':
        source = train_data
    else:
        source = test_data

    # One random start offset per sequence in the batch.
    starts = torch.randint(len(source) - seq_len - 1, (batch_size, ))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + seq_len])
        targets.append(source[start + 1:start + 1 + seq_len])

    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

@torch.no_grad()
def estimate_loss():
    """Estimate the model's mean loss on random batches.

    For each of the modes 'train' and 'eval', `eval_iters` batches are drawn
    with `get_batch(mode)` and their losses averaged. The model is put into
    evaluation mode for the measurement and back into training mode afterwards.

    Returns:
        dict with keys 'train' and 'eval', each mapping to the mean loss as a
        0-d tensor.
    """
    def mean_loss(mode):
        # Average the scalar loss over `eval_iters` freshly sampled batches.
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            inputs, targets = get_batch(mode)
            _, loss = model(inputs, targets)
            losses[k] = loss.item()
        return losses.mean()

    model.eval()
    out = {mode: mean_loss(mode) for mode in ('train', 'eval')}
    model.train()
    return out

In [54]:
class Bigram(nn.Module):
    """Bigram character model: one embedding table maps each token id to logits.

    Row `i` of the `d_model x d_model` embedding table holds the unnormalised
    scores (logits) for the token that follows token `i`, so every prediction
    depends only on the current token.

    Args:
        d_model: vocabulary size; also the width of each logit vector.
        seq_len: maximum number of trailing tokens fed to the model by
            `generate`.
        n_batch: nominal batch size. Stored as an attribute; `forward` derives
            the actual batch size from its input.
    """

    def __init__(self, d_model, seq_len, n_batch):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.n_batch = n_batch
        self.embedding = nn.Embedding(d_model, d_model)

    def forward(self, x, y=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            x: integer tensor of token ids, shape (batch, seq).
            y: optional integer tensor of target ids with the same shape as `x`.

        Returns:
            `(logits, loss)`. Without `y`, `logits` has shape
            (batch, seq, d_model) and `loss` is None. With `y`, `logits` is
            flattened to (batch * seq, d_model) and `loss` is the scalar
            cross-entropy against the flattened targets.
        """
        z = self.embedding(x)  # (batch, seq, d_model)

        if y is None:
            return z, None

        # cross_entropy wants (N, classes) logits and (N,) targets. Inferring
        # N with -1 keeps this working for any batch size / sequence length,
        # not only the sizes passed to the constructor.
        z = z.reshape(-1, self.d_model)
        y = y.reshape(-1)
        loss = F.cross_entropy(z, y)
        return z, loss

    @torch.no_grad()  # sampling needs no autograd graph
    def generate(self, x, new_seq_len):
        """Sample `new_seq_len` additional tokens after the context `x`.

        Args:
            x: integer tensor of token ids, shape (batch, seq).
            new_seq_len: number of tokens to append to every row.

        Returns:
            Integer tensor of shape (batch, seq + new_seq_len): the original
            context followed by the sampled tokens.
        """
        for _ in range(new_seq_len):
            # Feed at most the last `seq_len` tokens. `x` is 2-D, so the
            # sequence axis is the last one.
            context = x[:, -self.seq_len:]
            logits, _ = self(context)              # (batch, <=seq_len, d_model)
            last_logits = logits[:, -1, :]         # (batch, d_model)
            probs = F.softmax(last_logits, dim=-1)
            next_ids = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            x = torch.cat([x, next_ids], dim=-1)   # (batch, seq + 1)
        return x

In [55]:
model = Bigram(d_model=d_input, seq_len=seq_len, n_batch=batch_size)
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

max_iters = 30000
# Report the estimated losses ten times over the run. The max() guard keeps
# the modulo below valid when max_iters is smaller than 10.
eval_interval = max(1, max_iters // 10)

# The counter is named `step` so that the builtin `iter()` stays usable in the
# notebook namespace after this cell has run.
for step in range(1, max_iters + 1):

    x, y = get_batch('train')

    z, loss = model(x, y)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Also report after the very first update to show the starting loss.
    if step % eval_interval == 0 or step == 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['eval']:.4f}")

step 1: train loss 4.6347, val loss 4.6328
step 3000: train loss 2.6008, val loss 2.5533
step 6000: train loss 2.4793, val loss 2.5410
step 9000: train loss 2.4745, val loss 2.4569
step 12000: train loss 2.4725, val loss 2.5058
step 15000: train loss 2.5049, val loss 2.4921
step 18000: train loss 2.4347, val loss 2.5447
step 21000: train loss 2.4932, val loss 2.4809
step 24000: train loss 2.4957, val loss 2.5258
step 27000: train loss 2.5243, val loss 2.5462
step 30000: train loss 2.4750, val loss 2.5156


In [59]:
# Prompt the model with "AB". The extra leading dimension makes it a batch of one.
context = torch.unsqueeze(encode("AB"), 0).to(device)
# Sample 300 further characters and keep the single sequence of the batch.
result = model.generate(context, new_seq_len=300)[0]
print(decode(result))

ABire,
IARCloutand,
Thed f tin
NClan derrl ad cot tast omy gghe y

Th d tol w yeaisue,
OMI:
Wher Ifrt pr cte ud:
Clas:
Anocer thet the t, foks; rgl:
Goowe
Thas ad, p s.

Whathealou pe g panyoow:
Thidotithit lot brun
O:
AREETowishoovagusein. LERI'stoours, s my ous theloscas n stonute lllo hanth f s m R
