In [5]:
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F

from torch.optim import Adam
from utils import ARDataset, AutoRegressiveNetwork, device
from data_rnn import load_ndfa, load_brackets

def sample(lnprobs, temperature=1.0):
    """Pick one index from a 1-D vector of unnormalised log-probabilities.

    Args:
        lnprobs: 1-D tensor of scores; softmax is taken over dim 0.
        temperature: divisor applied to ``lnprobs`` before the softmax.
            ``0.0`` means greedy decoding (plain argmax, no randomness).

    Returns:
        A 0-d tensor holding the chosen index.
    """
    if temperature != 0.0:
        scaled = lnprobs / temperature
        probs = F.softmax(scaled, dim=0)
        return dist.Categorical(probs).sample()
    # Greedy decoding: no sampling at zero temperature.
    return lnprobs.argmax()

In [6]:
# Training data and vocabularies from the NDFA generator: i2w maps index -> token,
# w2i maps token -> index (n presumably counts sequences -- see data_rnn.load_ndfa).
x_train, (i2w, w2i) = load_ndfa(n=150000)

In [7]:
# Seed torch's RNG before constructing the model so weight initialisation and
# torch-based sampling are reproducible on Restart & Run All.
# NOTE(review): it is not visible here whether ARDataset.shuffle draws from
# torch's RNG; seed `random`/numpy as well if it uses those.
SEED = 0
torch.manual_seed(SEED)

model = AutoRegressiveNetwork(w2i).to(device)
# Only trainable parameters are passed to the optimizer.
optimizer = Adam([p for p in model.parameters() if p.requires_grad], lr=3e-4, weight_decay=1e-4)
# bs / maxsize are forwarded unchanged to ARDataset (batch size / size cap, presumably).
dl = ARDataset(x_train, w2i, bs=16, maxsize=300)
criterion = nn.CrossEntropyLoss()

In [8]:
def generate_seq(model, n=10, max_len=100):
    """Print ``n`` sequences sampled from ``model``, each primed with ``.start s``.

    Tokens are sampled one at a time (via ``sample`` at its default temperature)
    until ``.end`` appears or the sequence reaches ``max_len`` tokens. Relies on
    the notebook-level ``w2i``/``i2w`` vocabularies and ``device``.

    Generation runs in eval mode under ``torch.no_grad()`` so no autograd graph
    is built and train-mode-only layers (e.g. dropout, if present) are off.
    The model's previous train/eval mode is restored afterwards.

    Args:
        model: autoregressive model taking a ``(1, seq_len)`` LongTensor.
        n: number of sequences to print.
        max_len: hard cap on sequence length (prompt included).
    """
    end_id = w2i['.end']
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(n):
                seq = [w2i['.start'], w2i['s']]
                while end_id not in seq and len(seq) < max_len:
                    inp = torch.tensor([seq], dtype=torch.long, device=device)
                    # NOTE(review): assumes model(inp)[-1, :] is the next-token
                    # score vector for this single sequence -- TODO confirm shape.
                    next_scores = model(inp)[-1, :]
                    # int(...) keeps `seq` a plain list of Python ints rather than
                    # a mix of ints and (possibly on-GPU) 0-d tensors.
                    seq.append(int(sample(next_scores)))
                print('\t', ''.join([i2w[i] for i in seq]))
    finally:
        model.train(was_training)

# Train for NUM_EPOCHS epochs; after each one print the mean training loss and
# a few sampled sequences so progress can be eyeballed.
NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    model.train()
    dl.shuffle()
    total_loss = 0.0
    n_batches = 0
    for x, y in dl.dataloader():
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    mean_loss = total_loss / n_batches if n_batches else float('nan')
    print(f'Epoch {epoch}, Train Loss: {mean_loss:.2f}')
    generate_seq(model)

Epoch 0, Train Loss: 1.85
	 .starts.pad!bclmm!!akkblcm!aak.end
	 .starts.end
	 .startsu!au.padsl.start!.pad.pad.padw.end
	 .startsa.startm!wb.unks!s.end
	 .startskmklmlm!k!kblm!abc!bcmk!klabasu!s.end
	 .starts.unk.start.padssa.end
	 .starts.end
	 .starts.end
	 .startsbl.padabk!b!kl.unks.unk.end
	 .starts.pad.end
Epoch 1, Train Loss: 0.68
	 .starts.end
	 .startsa.unkss.end
	 .starts.end
	 .startsuvwvb!uvw!uvws.end
	 .starts.unk!umw!usklm!kl!s.end
	 .startskl!lm!klm!s.end
	 .starts.startsabc!ab.unkc!.unk.padss.end
	 .startsabc!a!klm!.end
	 .startss.end
	 .startskm!klm!kabc!a.unkubc!am!ss.end
Epoch 2, Train Loss: 0.38
	 .startss.end
	 .startsabc!s.end
	 .startsklm!s.end
	 .startsuvvw!uvw!uvw!uvw!klm!uv!uvwvwm!abc!s.end
	 .startss.end
	 .starts.end
	 .startsklm!klm!s.end
	 .startsss.end
	 .startss.end
	 .starts.start.padw!uvw!uvw!uvw!uvw!uvwm!s.end
Epoch 3, Train Loss: 0.28
	 .startsabcs!abc.end
	 .startss.end
	 .startsklm!klm!klm!klm!klm!klm!klm!kklm!klm!klm!uvw!uvw!uvw!s.end
	 .startsuvw

In [9]:
# Training data and vocabularies from the bracket-language generator: i2w maps
# index -> token, w2i maps token -> index (n presumably counts sequences).
x_train, (i2w, w2i) = load_brackets(n=150000)

In [10]:
# Seed torch's RNG before constructing the model so weight initialisation and
# torch-based sampling are reproducible on Restart & Run All.
# NOTE(review): it is not visible here whether ARDataset.shuffle draws from
# torch's RNG; seed `random`/numpy as well if it uses those.
SEED = 0
torch.manual_seed(SEED)

model = AutoRegressiveNetwork(w2i).to(device)
# Only trainable parameters are passed to the optimizer.
optimizer = Adam([p for p in model.parameters() if p.requires_grad], lr=3e-4, weight_decay=1e-4)
# bs / maxsize are forwarded unchanged to ARDataset (batch size / size cap, presumably).
dl = ARDataset(x_train, w2i, bs=16, maxsize=300)
criterion = nn.CrossEntropyLoss()

In [11]:
def is_balanced(tokens, open_id, close_id):
    """Return True if the brackets in ``tokens`` form a well-nested string.

    Only ``open_id`` and ``close_id`` are counted; every other token is ignored.
    The running depth must never drop below zero and must end at zero. A bare
    open/close count is not enough: it would accept ``)(``.
    """
    depth = 0
    for tok in tokens:
        if tok == open_id:
            depth += 1
        elif tok == close_id:
            depth -= 1
            if depth < 0:
                return False
    return depth == 0


def generate_seq(model, n=10, max_len=100, temperature=0.3):
    """Print ``n`` sampled bracket sequences, each primed with ``.start (()``.

    Tokens are sampled one at a time until ``.end`` appears or the sequence
    reaches ``max_len`` tokens. Each line is prefixed with whether the brackets
    in the sequence are balanced. Relies on the notebook-level
    ``w2i``/``i2w`` vocabularies and ``device``.

    Generation runs in eval mode under ``torch.no_grad()`` so no autograd graph
    is built and train-mode-only layers (e.g. dropout, if present) are off.
    The model's previous train/eval mode is restored afterwards.

    Args:
        model: autoregressive model taking a ``(1, seq_len)`` LongTensor.
        n: number of sequences to print.
        max_len: hard cap on sequence length (prompt included).
        temperature: sampling temperature passed to ``sample``.
    """
    open_id, close_id, end_id = w2i['('], w2i[')'], w2i['.end']
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(n):
                seq = [w2i['.start'], open_id, open_id, close_id]
                while end_id not in seq and len(seq) < max_len:
                    inp = torch.tensor([seq], dtype=torch.long, device=device)
                    # NOTE(review): assumes model(inp)[-1, :] is the next-token
                    # score vector for this single sequence -- TODO confirm shape.
                    next_scores = model(inp)[-1, :]
                    # int(...) keeps `seq` a plain list of Python ints.
                    seq.append(int(sample(next_scores, temperature=temperature)))
                # Scan the whole sequence (non-bracket tokens are ignored) rather
                # than seq[1:-1]: a sequence cut off at max_len has no trailing
                # .end, and slicing would drop its last bracket.
                print('\t', is_balanced(seq, open_id, close_id), ':', ''.join([i2w[i] for i in seq]))
    finally:
        model.train(was_training)

# Train for NUM_EPOCHS epochs; after each one print the mean training loss and
# a few sampled sequences (with a balance check) to eyeball progress.
NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    model.train()
    dl.shuffle()
    total_loss = 0.0
    n_batches = 0
    for x, y in dl.dataloader():
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    mean_loss = total_loss / n_batches if n_batches else float('nan')
    print(f'Epoch {epoch}, Train Loss: {mean_loss:.2f}')
    generate_seq(model)

Epoch 0, Train Loss: 1.65
	 False : .start(()((()()))))))()).end
	 True : .start(()).end
	 False : .start(()(((()()(((((.start(()())(())(().unk()()(.end
	 False : .start(()()(()))()))(()()))))(())()())(()())()(().end
	 False : .start(()(.start()())().end
	 False : .start(())(()(().pad)(())))((())().pad))(()().unk.unk((((())()))))))()(.end
	 False : .start(()))(((()()()()()))(((.pad)).end
	 False : .start(().end
	 False : .start(()((((()(((())((()(.pad.end
	 False : .start(()(((())))((()().pad((()())).end
Epoch 1, Train Loss: 1.29
	 False : .start(()))(()))((()()(()()()(()(())())())()((()())()))((()()()()))()((((()())))(()())()()))()()((()))()(
	 False : .start(()())(()((())()(()))((())(((())((()(())())(()())()(()())((()())(()())(()(((()()()()((())))())))(()
	 False : .start(()(()()())())(())()))))()(())))(()()())()()))())(())()))())()()))(()))(()(())(()()())))())()(()()(
	 False : .start(()(()()())()(()()((()()()))())))(((((())))())()())(()((()(((())))(()(()()((()())()()((()(()())())(
