In [26]:
import torch
import torch.nn as nn
import numpy as np
from torch.optim import Adam
from data_rnn import load_ndfa, load_brackets
from utils import device
from random import choices
from tqdm import tqdm

In [27]:
x_train, (i2w, w2i) = load_ndfa(n=150_000)

In [28]:
class Dataset():
    def __init__(self, x, maxsize=10000, bs=32):
        self.x = [[w2i['.start']] + i + [w2i['.end']] for i in x]
        self.c = 0
        self.maxsize = max(maxsize, len(max(self.x, key=lambda x: len(x))) + 1)
        self.bs = bs
    
    def shuffle(self):
        self.c = 0
        self.x = [self.x[i] for i in np.random.permutation(range(len(self.x)))]
    
    def get(self):
        temp = []
        while len(temp) < self.maxsize:
            if self.c >= len(self.x) or len(temp) + len(self.x[self.c]) + 1 > self.maxsize:
                temp.extend([w2i['.pad']] * (self.maxsize - len(temp)))
            else:
                temp.extend(self.x[self.c] + [w2i['.pad']])
                self.c += 1
        return torch.tensor(temp, dtype=torch.long)
    
    def dataloader(self):
        while self.c < len(self.x):
            x = torch.concat([self.get().view(1, -1) for _ in range(self.bs)]).to(device)
            x = x[x.count_nonzero(dim=1) > 1]
            y = torch.zeros_like(x)
            y[:, :-1] = x[:, 1:]
            yield x, y.view(-1)

In [29]:
class Network(nn.Module):
    def __init__(self, emb=32, h=16):
        super().__init__()
        self.emb = nn.Embedding(len(i2w), emb)
        self.lstm = nn.LSTM(emb, h, batch_first=True)
        self.linear = nn.Linear(h, len(w2i))
    
    def forward(self, x):
        x = self.emb(x)
        x, _ = self.lstm(x)
        x = self.linear(x)
        return x.view(-1, len(w2i))

In [30]:
model = Network().to(device)
optimizer = Adam([p for p in model.parameters() if p.requires_grad], lr=3e-4, weight_decay=1e-4)
dl = Dataset(x_train, bs=8, maxsize=200)
criterion = nn.CrossEntropyLoss()

In [31]:
for epoch in range(10):
    model.train()
    dl.shuffle()
    total_loss = 0
    c = 0
    for x, y in dl.dataloader():
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        c += 1
    print(f'Epoch {epoch}, Train Loss: {total_loss / c:.2f}')

Epoch 0, Train Loss: 0.94
Epoch 1, Train Loss: 0.24
Epoch 2, Train Loss: 0.21
Epoch 3, Train Loss: 0.20
Epoch 4, Train Loss: 0.19
Epoch 5, Train Loss: 0.19
Epoch 6, Train Loss: 0.19
Epoch 7, Train Loss: 0.19
Epoch 8, Train Loss: 0.19
Epoch 9, Train Loss: 0.19


In [32]:
sm = nn.Softmax(dim=1)
for _ in range(10):
    seq = [w2i['.start'], w2i['s']]
    while w2i['.end'] not in seq:
        seq.append(choices(range(len(w2i)), weights=sm(model(torch.tensor([seq], dtype=torch.long, device=device)))[-1, :].detach(), k=1)[0])
    print(''.join([i2w[i] for i in seq]))

.startsklm!klm!s.end
.startsuvw!uvw!uvw!s.end
.startsklm!klmm!klm!klm!klm!s.end
.startsuvw!uvw!uvw!s.end
.startsabc!abc!abc!s.end
.startsklm!s.end
.startsuvw!uvw!uvw!uvw!s.end
.startsuvw!uvw!s.end
.startsabc!abc!s.end
.startss.end


In [33]:
x_train, (i2w, w2i) = load_brackets(n=150_000)

In [34]:
model = Network().to(device)
optimizer = Adam([p for p in model.parameters() if p.requires_grad], lr=3e-4, weight_decay=1e-4)
dl = Dataset(x_train, bs=8, maxsize=200)
criterion = nn.CrossEntropyLoss()

In [35]:
for epoch in range(10):
    model.train()
    dl.shuffle()
    total_loss = 0
    c = 0
    for x, y in dl.dataloader():
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        c += 1
    print(f'Epoch {epoch}, Train Loss: {total_loss / c:.2f}')

Epoch 0, Train Loss: 1.33
Epoch 1, Train Loss: 0.91
Epoch 2, Train Loss: 0.71
Epoch 3, Train Loss: 0.61
Epoch 4, Train Loss: 0.55
Epoch 5, Train Loss: 0.52
Epoch 6, Train Loss: 0.50
Epoch 7, Train Loss: 0.49
Epoch 8, Train Loss: 0.48
Epoch 9, Train Loss: 0.48


In [36]:
sm = nn.Softmax(dim=1)
for _ in range(10):
    seq = [w2i['.start'], w2i['('], w2i['('], w2i[')']]
    while w2i['.end'] not in seq:
        seq.append(choices(range(len(w2i)), weights=sm(model(torch.tensor([seq], dtype=torch.long, device=device)))[-1, :].detach(), k=1)[0])
    print(sum([1 if w2i['('] == i else -1 for i in seq[1:-1]]) == 0, ':', ''.join([i2w[i] for i in seq]))

True : .start(()()(()(()()(()()((()))((((())(())))())))(()))()(())).end
True : .start(()()).end
True : .start(()(()(((((()(())))((()((())()))()))())()((()())(((((()((()()))))()))(())(()()(()))(())((()(())))))))).end
True : .start(()).end
True : .start(()(())).end
True : .start(()()()((()()))(()))(()()).end
True : .start(()()).end
True : .start(()).end
True : .start(()(()((())))).end
False : .start(()(())())).end
