In [37]:
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from transformers import BertTokenizer
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

In [38]:
# Padding id. NOTE(review): not referenced by any cell visible here, and
# `fullfill` pads with 0 rather than this value — confirm which id [PAD] has
# in vocab.txt before relying on it.
pad = 1

# Sequence / training hyper-parameters.
max_length = 60       # fixed length that `fullfill` pads each id list to
batch_size = 64
d_model = 512         # embedding / transformer hidden size
learning_rate = 0.001
num_epochs = 3

# Seed torch's global RNG so weight initialisation, dropout and DataLoader
# shuffling are reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

In [39]:
def fullfill(tensor, length=None, pad_id=0):
    """Right-pad (truncating if necessary) a list of token ids to a fixed length.

    Args:
        tensor: list of int token ids, e.g. the output of ``tokenizer.encode``.
            Despite the name it is a plain Python list, not a torch tensor.
            It is not modified.
        length: target length; when None, the notebook-level ``max_length``
            is used.
        pad_id: id appended as padding; defaults to 0.

    Returns:
        A new list of exactly ``length`` ids. Over-long inputs are cut to
        ``length`` so every row has the same size and a batch of rows can be
        turned into one rectangular tensor.
    """
    if length is None:
        length = max_length
    # Slice first: without truncation, a long sequence would stay longer than
    # `length` and make the batch ragged.
    ids = list(tensor[:length])
    ids.extend([pad_id] * (length - len(ids)))
    return ids

In [40]:

# Build the tokenizer from the local vocab and register the domain marker
# words as special tokens so they are never split into word pieces.
tokenizer = BertTokenizer(vocab_file = "vocab.txt")
tokenizer.add_special_tokens({
    "additional_special_tokens": ["vindex", "command", "input"]
})

# Each non-blank line of train.txt is "<source> <target>", separated by
# exactly one space (a line with any other number of spaces raises here).
src_data = []
tgt_data = []
with open("./data/train.txt", "r", encoding = "utf-8") as f:
    for line in f:
        line = line.rstrip("\r\n")
        if not line.strip():
            continue  # tolerate blank / trailing lines instead of crashing the unpack
        src, tgt = line.split(" ")
        src_data.append(src)
        tgt_data.append(tgt)

# Let the tokenizer truncate to max_length (it accounts for the special tokens
# it adds), then pad every id list out to max_length.
src_tensors = [fullfill(tokenizer.encode(src, truncation = True, max_length = max_length)) for src in src_data]
tgt_tensors = [fullfill(tokenizer.encode(tgt, truncation = True, max_length = max_length)) for tgt in tgt_data]

train_data = list(zip(src_tensors, tgt_tensors))
# len(tokenizer) counts the base vocab plus the tokens actually added above,
# so every id encode() can emit is a valid embedding / output index.
dict_size = len(tokenizer)

In [41]:
def collate_fn(batch):
    """Turn a list of ``(src_ids, tgt_ids)`` pairs into a single LongTensor.

    With equal-length id lists the result has shape ``(batch, 2, seq_len)``:
    ``[:, 0]`` holds the source ids and ``[:, 1]`` the target ids.
    """
    pairs = list(batch)
    return torch.tensor(pairs, dtype=torch.long)
# Shuffled mini-batches of (src, tgt) id pairs, stacked by collate_fn.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [42]:
class CommandParser(nn.Module):
    def __init__(self):
        super(CommandParser, self).__init__()
        self.embedding = nn.Embedding(dict_size, d_model)
        self.transformer = nn.Transformer(d_model = d_model, nhead = 8, num_encoder_layers = 4, num_decoder_layers = 4, dim_feedforward = 2048, dropout = 0.2, batch_first = True)
        self.fc = nn.Linear(d_model, dict_size)
    def forward(self, src, tgt):
        embeded_src = self.embedding(src)
        embeded_tgt = self.embedding(tgt)
        out = self.transformer(embeded_src, embeded_tgt)
        out = self.fc(out)
        return out

In [43]:
model = CommandParser()
# Padding positions (id 0, the value `fullfill` pads with) carry no target
# information; ignore them so they do not contribute to the loss or gradients.
loss_fn = nn.CrossEntropyLoss(ignore_index = 0)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

In [45]:
model.train()  # ensure dropout is active even if a previous run switched to eval()
for epoch in range(num_epochs):
    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        # batch is (batch, 2, seq_len): [:, 0] = source ids, [:, 1] = target ids
        src, tgt = batch[:, 0], batch[:, 1]
        # Teacher forcing with a one-step shift: the decoder reads tokens
        # 0..n-2 and is scored on tokens 1..n-1. Scoring the decoder's own
        # input lets it learn to copy instead of predict (this only prevents
        # leakage if the decoder is causally masked).
        decoder_input = tgt[:, :-1]
        labels = tgt[:, 1:]
        optimizer.zero_grad()
        logits = model(src, decoder_input)
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        loss.backward()

        clip_grad_norm_(model.parameters(), max_norm = 1)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
    average_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")
    

KeyboardInterrupt: 