### 1. Prepare Data

In [15]:
corpus = [
    "ăn quả nhớ kẻ trồng cây",
    "có chí thì nên"    
]
data_size = len(corpus)

# Define the max vocabulary size and sequence length
vocab_size = 15
sequence_length = 7

In [16]:
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Define tokenizer function
tokenizer = get_tokenizer("basic_english")

# Create a function to yield list of tokens
def yield_token(examples):
    for text in examples:
        yield tokenizer(text)

# Create vocabulary
vocab = build_vocab_from_iterator(iterator=yield_token(corpus), 
                                  specials=["<unk>", "<pad>", "<sos_topic1>", "<sos_topic2>", "<eos>"], 
                                  max_tokens=vocab_size)

vocab.set_default_index(vocab["<unk>"])
vocab.get_stoi()


{'trồng': 13,
 '<unk>': 0,
 'có': 7,
 '<pad>': 1,
 '<sos_topic2>': 3,
 'thì': 12,
 'ăn': 14,
 '<sos_topic1>': 2,
 '<eos>': 4,
 'chí': 5,
 'nên': 10,
 'cây': 6,
 'quả': 11,
 'kẻ': 8,
 'nhớ': 9}

In [17]:
data_X, data_y = [], []


corpus[0] =  "<sos_topic1> " + corpus[0] + " <eos>"
corpus[1] =  "<sos_topic2> " + corpus[1] + " <eos>" 

for vector in corpus:
    vector = vector.split()
    data_X.append(vector[:-1])
    data_y.append(vector[1:])

print(data_X)
print(data_y)

[['<sos_topic1>', 'ăn', 'quả', 'nhớ', 'kẻ', 'trồng', 'cây'], ['<sos_topic2>', 'có', 'chí', 'thì', 'nên']]
[['ăn', 'quả', 'nhớ', 'kẻ', 'trồng', 'cây', '<eos>'], ['có', 'chí', 'thì', 'nên', '<eos>']]


In [18]:
# Tokenize and numericalize your samples
def vectorize(X, y, vocab, sequence_length):
    X_ids = [vocab[token] for token in X][:sequence_length]
    y_ids = [vocab[token] for token in y][:sequence_length]

    X_ids_pad = X_ids + [vocab["<pad>"]] * (sequence_length - len(X))
    y_ids_pad = y_ids + [vocab["<pad>"]] * (sequence_length - len(y))

    return X_ids_pad, y_ids_pad

data_X_ids, data_y_ids =[], []

for X, y in zip(data_X, data_y):
    X_ids, y_ids = vectorize(X, y, vocab, sequence_length)
    data_X_ids.append(X_ids)
    data_y_ids.append(y_ids)

data_X_ids = torch.tensor(data_X_ids, dtype=torch.long)
data_y_ids = torch.tensor(data_y_ids, dtype=torch.long)



In [19]:
for x, y in zip(data_X_ids, data_y_ids):
    print(x)
    print(y)
    print()

tensor([ 2, 14, 11,  9,  8, 13,  6])
tensor([14, 11,  9,  8, 13,  6,  4])

tensor([ 3,  7,  5, 12, 10,  1,  1])
tensor([ 7,  5, 12, 10,  4,  1,  1])



### 2. Train with Decoder

In [30]:
class TG_Model(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, num_heads: int):
        super().__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

        self.mask = torch.triu(input=torch.ones(sequence_length, sequence_length), diagonal=1).bool()
        self.decoder_transformer = nn.TransformerEncoderLayer(
            d_model=embed_dim, 
            nhead=num_heads, 
            dim_feedforward=4, 
            dropout=0.0, 
            activation="relu", 
            batch_first=True, 
            bias=True
        )

        self.linear = nn.Linear(embed_dim, vocab_size)

    def forward(self, x): # shape x: [N, sequence_length]
        embedding = self.embedding(x) # shape: [N, sequence_length, embed_dim]
        output = self.decoder_transformer(embedding, src_mask=self.mask) # shape: [N, sequence_length, embed_dim]
        output = self.linear(output) # shape: [N, sequence_length, vocab_size]

        return output.permute(0, 2, 1) # shape: [N, vocab_size, sequence_length]

model = TG_Model(vocab_size=vocab_size, embed_dim=8, num_heads=1)
print(model)

TG_Model(
  (embedding): Embedding(15, 8)
  (decoder_transformer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
    )
    (linear1): Linear(in_features=8, out_features=4, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (linear2): Linear(in_features=4, out_features=8, bias=True)
    (norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.0, inplace=False)
    (dropout2): Dropout(p=0.0, inplace=False)
  )
  (linear): Linear(in_features=8, out_features=15, bias=True)
)


In [31]:
mock_data = data_X_ids
output = model(mock_data)
print(output.shape)

torch.Size([2, 15, 7])


In [33]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.05)

for _ in range(40):
    optimizer.zero_grad()
    output = model(data_X_ids)
    loss = criterion(output, data_y_ids)
    print(loss.item())

    loss.backward()
    optimizer.step()

0.9093026518821716
0.7072864174842834
0.5677958130836487
0.35140368342399597
0.2591227889060974
0.1747504621744156
0.11587797105312347
0.07337851077318192
0.047472644597291946
0.032829366624355316
0.023836275562644005
0.017551401630043983
0.012809100560843945
0.009483846835792065
0.007389824837446213
0.0060529462061822414
0.0048980629071593285
0.003810920985415578
0.0029855084139853716
0.002408376196399331
0.0019975665491074324
0.0016854798886924982
0.001431369804777205
0.0012159096077084541
0.0010354847181588411
0.0008896319195628166
0.0007741297013126314
0.0006835010135546327
0.0006120220059528947
0.0005548485787585378
0.0005084694712422788
0.00047018862096592784
0.0004379051097203046
0.000410138803999871
0.00038569868775084615
0.00036380221717990935
0.0003439132997300476
0.00032570032635703683
0.0003090443497058004
0.0002938518300652504


In [34]:
output = model(data_X_ids)
print(torch.argmax(output, axis=1))

tensor([[14, 11,  9,  8, 13,  6,  4],
        [ 7,  5, 12, 10,  4,  1,  1]])


In [35]:
data_y_ids

tensor([[14, 11,  9,  8, 13,  6,  4],
        [ 7,  5, 12, 10,  4,  1,  1]])

### 3. Inference

In [36]:
promt = '<sos_topic2> có'
promt = promt.split()
promt_ids = [vocab[token] for token in promt][:sequence_length]
promt_ids = promt_ids + [vocab["<pad>"]] * (sequence_length - len(promt))

print(promt_ids)

[3, 7, 1, 1, 1, 1, 1]


In [39]:

id2label = {id: label for label, id in vocab.get_stoi().items()}

for i in range(sequence_length - len(promt)):
    promt_tensor = torch.tensor(promt_ids, dtype=torch.long).reshape(1, -1)
    outputs = model(promt_tensor)
    outputs = torch.argmax(outputs, axis=1)   
    next_id = outputs[0][len(promt)+i-1]

    promt_ids[len(promt)+i] = next_id.item()
    prompt_token = [id2label[id] for id in promt_ids]
    print(promt_ids)
    print(prompt_token)

[3, 7, 5, 12, 10, 4, 6]
['<sos_topic2>', 'có', 'chí', 'thì', 'nên', '<eos>', 'cây']
[3, 7, 5, 12, 10, 4, 6]
['<sos_topic2>', 'có', 'chí', 'thì', 'nên', '<eos>', 'cây']
[3, 7, 5, 12, 10, 4, 6]
['<sos_topic2>', 'có', 'chí', 'thì', 'nên', '<eos>', 'cây']
[3, 7, 5, 12, 10, 4, 6]
['<sos_topic2>', 'có', 'chí', 'thì', 'nên', '<eos>', 'cây']
[3, 7, 5, 12, 10, 4, 6]
['<sos_topic2>', 'có', 'chí', 'thì', 'nên', '<eos>', 'cây']
