In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext.data as data
import spacy

In [2]:
spacy_en = spacy.load("en_core_web_sm")


def tokenizer(text):
    """Return the token strings of ``text`` using spaCy's English tokenizer only."""
    return list(map(lambda token: token.text, spacy_en.tokenizer(text)))


# Questions and answers share one preprocessing recipe: lower-cased spaCy tokens
# wrapped in <sos> ... <eos>.
question = data.Field(
    tokenize=tokenizer,
    lower=True,
    init_token="<sos>",
    eos_token="<eos>",
)
answer = data.Field(
    tokenize=tokenizer,
    lower=True,
    init_token="<sos>",
    eos_token="<eos>",
)

In [3]:
# Map CSV header columns to example attributes: "Question" -> .q, "Answer" -> .a.
datafields = {
    "Question": ("q", question),
    "Answer": ("a", answer),
}
dataset = data.TabularDataset(
    path="questions_answers.csv",
    format="csv",
    fields=datafields,
)

In [4]:
# Vocabularies come from the whole dataset (before splitting) and keep only tokens
# seen at least twice.
# NOTE(review): building on `dataset` instead of `train_data` lets validation-only
# words into the vocab. Harmless while valid_data is unused, but changing it would
# also change the token->index mapping that saved checkpoints were trained with.
question.build_vocab(dataset, min_freq=2)
answer.build_vocab(dataset, min_freq=2)

# 70% train / 30% validation. No random_state is passed, so the split is unseeded.
train_data, valid_data = dataset.split(split_ratio=0.7)

# Batches of 32 on the GPU, bucketed and sorted within each batch by question length.
train_iter = data.BucketIterator(
    dataset=train_data,
    batch_size=32,
    sort_key=lambda example: len(example.q),
    sort_within_batch=True,
    device="cuda",
)

In [5]:
class Transformer(nn.Module):
    """Sequence-to-sequence model around ``nn.Transformer`` with learned positions.

    Inputs are sequence-first integer token tensors of shape ``(seq_len, batch)``,
    matching ``nn.Transformer``'s default (``batch_first=False``) layout. For each
    side, token and learned position embeddings are summed and passed through
    dropout before entering ``nn.Transformer``. The decoder output is projected to
    target-vocabulary logits.

    Args:
        embedding_size: model width (``d_model``) shared by embeddings and transformer.
        src_vocab_size: number of source tokens.
        trg_vocab_size: number of target tokens (size of the output logits).
        src_pad_index: source token id treated as padding and masked from attention.
        num_heads, num_encoder_layers, num_decoder_layers, dropout: forwarded to
            ``nn.Transformer``.
        dense_dim: ``dim_feedforward`` of ``nn.Transformer``.
        max_len: number of learned positions per side. Longer inputs index past the
            position embeddings and fail.
    """

    def __init__(self, embedding_size,
                 src_vocab_size,
                 trg_vocab_size,
                 src_pad_index,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 dense_dim,
                 dropout,
                 max_len,
                ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            dense_dim,
            dropout
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_index = src_pad_index

    def make_src_mask(self, src):
        """Return a ``(batch, src_seq_len)`` bool mask that is True at padding tokens."""
        # nn.Transformer key-padding masks are batch-first, while src is sequence-first.
        src_mask = src.transpose(0, 1) == self.src_pad_index
        return src_mask

    def forward(self, src, trg):
        """Return logits of shape ``(trg_seq_len, batch, trg_vocab_size)``.

        Args:
            src: ``(src_seq_len, batch)`` source token ids.
            trg: ``(trg_seq_len, batch)`` decoder-input token ids.
        """
        src_seq_len = src.shape[0]
        trg_seq_len = trg.shape[0]
        # Follow the inputs' device instead of hard-coding CUDA.
        device = src.device

        # Position index t for every (t, batch) cell of each input.
        src_positions = torch.arange(src_seq_len, device=device).unsqueeze(1).expand_as(src)
        trg_positions = torch.arange(trg_seq_len, device=trg.device).unsqueeze(1).expand_as(trg)

        embed_src = self.dropout(
            self.src_word_embedding(src) + self.src_position_embedding(src_positions)
        )
        embed_trg = self.dropout(
            self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions)
        )

        src_padding_mask = self.make_src_mask(src)
        # Causal mask: target position t may only attend to positions <= t.
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).to(trg.device)

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
            # src_key_padding_mask only covers encoder self-attention. Without this,
            # the decoder's cross-attention also attends to padded encoder positions.
            memory_key_padding_mask=src_padding_mask,
        )
        return self.fc_out(out)

In [6]:
num_epochs = 125
learning_rate = 3e-4
src_vocab_size = len(question.vocab)
trg_vocab_size = len(answer.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.1
# Number of learned positions: the longest padded sequence (question or answer,
# <sos>/<eos> included) in any training batch. This takes a single pass over
# train_iter instead of two, because each pass re-batches and moves every batch to
# the GPU.
# NOTE(review): inputs longer than this at inference time (long user questions,
# long generations) will index past the position embeddings.
max_len = max(max(batch.q.shape[0], batch.a.shape[0]) for batch in train_iter)
dense_dim = 64
src_pad_index = question.vocab.stoi["<pad>"]

net = Transformer(embedding_size, src_vocab_size,
                  trg_vocab_size, src_pad_index,
                  num_heads, num_encoder_layers,
                  num_decoder_layers, dense_dim, dropout, max_len).cuda()

# Padding positions in the shifted target do not contribute to the loss.
trg_pad_index = answer.vocab.stoi["<pad>"]
loss_fn = nn.CrossEntropyLoss(ignore_index=trg_pad_index)
opt = optim.Adam(net.parameters(), lr=learning_rate)
# Cut the LR 10x after 10 epochs without improvement of the value passed to step().
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, factor=0.1, patience=10, verbose=True
)

In [7]:
test_sentence = "Who suggested Lincoln grow a beard?"


def predict(test_sentence, model, max_len=100):
    """Greedily decode an answer to ``test_sentence`` with ``model``.

    The sentence is tokenized and lower-cased the same way as the ``question``
    field and wrapped in <sos>/<eos>. Decoding is one step at a time: the
    highest-scoring next token is appended until <eos> appears or the step budget
    is used up.

    Args:
        test_sentence: raw question text.
        model: a ``Transformer`` from this notebook (its ``trg_position_embedding``
            size bounds the number of steps).
        max_len: maximum number of decoding steps (default 100). It is clamped to
            the model's number of target positions so the decoder input never
            outgrows the learned position table.

    Returns:
        A list of answer-vocabulary tokens. It starts with "<sos>" and ends with
        "<eos>" if one was generated.

    The model's train/eval mode is restored before returning.
    """
    src_stoi = question.vocab.stoi
    src_unk = src_stoi[question.unk_token]
    sos_index = answer.vocab.stoi["<sos>"]
    eos_index = answer.vocab.stoi["<eos>"]
    device = next(model.parameters()).device

    tokens = [question.init_token]
    tokens += [w.text.lower() for w in spacy_en.tokenizer(test_sentence)]
    tokens.append(question.eos_token)
    # .get() instead of []: legacy torchtext's stoi may be a defaultdict, and []
    # would permanently insert every unknown word into the vocabulary mapping.
    src_ids = [src_stoi.get(tok, src_unk) for tok in tokens]
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(1)

    outputs = [sos_index]
    max_steps = min(max_len, model.trg_position_embedding.num_embeddings)

    was_training = model.training
    model.eval()  # no dropout while decoding
    try:
        with torch.no_grad():
            for _ in range(max_steps):
                trg = torch.tensor(outputs, dtype=torch.long, device=device).unsqueeze(1)
                logits = model(src, trg)
                # Logits of the last target position, single batch element.
                best_guess = logits.argmax(2)[-1, :].item()
                outputs.append(best_guess)
                if best_guess == eos_index:
                    break
    finally:
        model.train(was_training)

    return [answer.vocab.itos[i] for i in outputs]

In [8]:
for epoch in range(num_epochs):
    # Decode the fixed probe question each epoch to watch qualitative progress.
    net.eval()
    pred_sen = predict(test_sentence, net)
    print(" ".join(pred_sen))

    # Switch back to training mode once per epoch; nothing inside the batch loop
    # changes the mode.
    net.train()
    losses = []
    # Loop variable is `batch`, not `data`: `data` would rebind the torchtext.data
    # module imported at the top and break re-running the earlier cells.
    for batch in train_iter:
        # Teacher forcing: the decoder sees a[:-1] and learns to predict a[1:].
        outs = net(batch.q, batch.a[:-1, :])
        outs = outs.reshape(-1, outs.shape[2])
        target = batch.a[1:].reshape(-1)
        opt.zero_grad()
        loss = loss_fn(outs, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1)
        opt.step()
        losses.append(loss.item())
    mean_loss = sum(losses) / len(losses)
    # The scheduler keys off the mean training loss (no validation loop exists).
    scheduler.step(mean_loss)
    print(f"[Epoch {epoch}] Loss {mean_loss}")

<sos> least too 10 blaine too do months months months do least too 10 being being being months do ursidae : months months buffalo too least blanco formally simple stronghold being monroe kivi blanco months across do france being monroe recognized monroe if mining buffalo months blaine do ursidae blaine do ursidae too 10 recognized monroe months cello recognized monroe buffalo blaine do ursidae : least too 10 months america pascal monroe monroe members least kilometers monroe pull entries blaine do least ursidae simple least kilometers months gouverneur least too cello recognized do months high if mining : formally r. recognized
[Epoch 0] Loss 4.906983041763306
<sos> <unk> <eos>
[Epoch 1] Loss 4.339334112803141
<sos> the <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> 

In [11]:
# torch.load unpickles the whole nn.Module, and pickle can execute arbitrary code,
# so only load checkpoints you trust.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects full-module
# pickles like this one. Prefer saving/loading a state_dict.
new_net = torch.load("chatbot_model.pth").cuda()
# A pickled module keeps the train/eval mode it was saved in. Force eval so
# dropout is off for inference.
new_net.eval()
predict(test_sentence, new_net)

['<sos>', 'grace', 'bedell', '.', '<eos>']

In [15]:
def save_vocab(vocab, path):
    """Write ``vocab`` to ``path`` as one ``index<TAB>token`` line per entry.

    The function iterates ``vocab.itos``, so every index is written exactly once and
    in index order. ``vocab.stoi`` is deliberately not used: in legacy torchtext it
    may be a defaultdict that gains an extra (token -> <unk> index) entry whenever an
    unknown token is looked up with ``[]``, which would add spurious lines.

    NOTE(review): tokens containing a tab or newline would break the line format.

    Args:
        vocab: object with an ``itos`` list of token strings.
        path: output file path; it is overwritten and written as UTF-8.
    """
    with open(path, "w", encoding="utf-8") as f:
        for index, token in enumerate(vocab.itos):
            f.write(f"{index}\t{token}\n")

# save_vocab(question.vocab, "src_vocab.txt")
# save_vocab(answer.vocab, "trg_vocab.txt")