In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
from tokenizer import CharTokenizer
import torch.nn.functional as F


In [3]:
import numpy as np

# Read the file and split into sentences
base_path = 'dataset'

with open(f'{base_path}/vi_sents', encoding='utf-8') as f:
    sentences = f.read().split('\n')

# Calculate lengths of each sentence
sentence_lengths = [len(sentence) for sentence in sentences]

# Calculate 95th percentile length
percentile_95 = np.percentile(sentence_lengths, 95)

print(percentile_95)


with open(f'{base_path}/en_sents', encoding='utf-8') as f:
    sentences = f.read().split('\n')

# Calculate lengths of each sentence
sentence_lengths = [len(sentence) for sentence in sentences]

# Calculate 95th percentile length
percentile_95 = np.percentile(sentence_lengths, 95)

print(percentile_95)

# # Filter sentences - keep only those <= 95th percentile length
# filtered_sentences = [sentence for sentence in sentences if len(sentence) <= percentile_95]

# # If you want to save the filtered sentences back to a file
# with open('filtered_file.txt', 'w', encoding='utf-8') as f:
#     f.write('\n'.join(filtered_sentences))

# print(f"95th percentile length: {percentile_95}")
# print(f"Original number of sentences: {len(sentences)}")
# print(f"Filtered number of sentences: {len(filtered_sentences)}")

62.0
59.0


In [4]:
#torch.manual_seed(1337)

base_path = 'dataset'

with open(f'{base_path}/vi_sents', encoding='utf-8') as f:
    vi_sentences = f.read().split('\n')

with open(f'{base_path}/en_sents', encoding='utf-8') as f:
    en_sentences = f.read().split('\n')


assert len(vi_sentences) == len(en_sentences), "Files have different number of sentences"

vi_lengths = [len(s) for s in vi_sentences]
en_lengths = [len(s) for s in en_sentences]
vi_percentile_95 = np.percentile(vi_lengths, 95)
en_percentile_95 = np.percentile(en_lengths, 95)


filtered_pairs = [(vi, en) for vi, en, vi_len, en_len in zip(
    vi_sentences, en_sentences, vi_lengths, en_lengths)
    if vi_len <= vi_percentile_95 and en_len <= en_percentile_95
]

# Separate back into Vietnamese and English
filtered_vi, filtered_en = zip(*filtered_pairs)

# Join with newlines
vi_text = '\n'.join(filtered_vi)
en_text = '\n'.join(filtered_en)


vi_chars = sorted(list(set(vi_text)))
en_chars = sorted(list(set(en_text)))

In [5]:
token_padding = "\uE000"
token_start =  "\uE001"
token_end = "\uE002"
vi_chars = [token_padding, token_start, token_end] + vi_chars
en_chars = [token_padding] + en_chars

In [6]:
vocab_vi_size = len(vi_chars)

stoi_vi = { ch:i for i,ch in enumerate(vi_chars) }
itos_vi = { i:ch for i,ch in enumerate(vi_chars) }

encode_vi = lambda s: [stoi_vi[c] for c in s]
decode_vi = lambda l: ''.join([itos_vi[i] for i in l])

vocab_eng_size = len(en_chars)

stoi_en = { ch:i for i,ch in enumerate(en_chars) }
itos_en = { i:ch for i,ch in enumerate(en_chars) }

encode_en = lambda s: [stoi_en[c] for c in s]
decode_en = lambda l: ''.join([itos_en[i] for i in l])


In [7]:
addToken = lambda text: "".join(f"{token_start} {line} {token_end}" for line in text.split("\n"))

vi = encode_vi(addToken(vi_text))
en = encode_en(en_text)

In [8]:
from torch.nn.utils.rnn import pad_sequence

def create_data(tokenized, token_end, add_end, token_padding):    
    sublists = []
    current = []
    for token in tokenized:
        if token == token_end:
            if current:
                if add_end:
                    current.append(token)
                sublists.append(torch.tensor(current))
                current = []
        else:
            current.append(token)
    if current:
        sublists.append(torch.tensor(current))
    padded = pad_sequence(sublists, batch_first=True, padding_value=token_padding)
    return padded

data_viet = create_data(vi, stoi_vi[token_end], True, stoi_vi[token_padding])
data_eng = create_data(en, stoi_en['\n'], False, stoi_en[token_padding])

In [9]:
data_viet.shape, data_eng.shape

(torch.Size([237469, 66]), torch.Size([237469, 59]))

In [10]:
iterations = 20000
eval_iters = 100
batch_size = 32
lr = 1e-3
emb_enc_size = 256
emb_dec_size = 256
num_heads = 3
num_layers = 5
eval_interval = 250
qk_dim = 128

device = 'mps'

In [11]:
n = int(0.9 * data_viet.shape[0])
train_data = (data_eng[:n], data_viet[:n])
val_data = (data_eng[n:], data_viet[n:])

In [12]:
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            (X_eng, x_viet), Y = get_batch(split)
            logits = model(X_eng, x_viet)
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            Y = Y.view(B*T)
            loss = F.cross_entropy(logits, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

def get_batch(split):
    data = train_data if split == 'train' else val_data
    N, L2 = data[1].shape
    ix = torch.randint(0, N, (batch_size,))
    l1, l2 = data
    x = (l1[ix].to(device),l2[ix, : L2 - 1].to(device))
    y = l2[ix, 1 : L2].to(device)

    return x, y

In [13]:
(x_eng, x_viet), y = get_batch('train')
print(f'Encoder input: {decode_en(x_eng[0].tolist())}')
for i in range(len(decode_vi(x_viet[0].tolist()))):
    print(f'When: "{decode_vi(x_viet[0].tolist())[:i + 1]}" then:"{decode_vi(y[0].tolist())[i]}"')

Encoder input: I've never been better
When: "" then:" "
When: " " then:"t"
When: " t" then:"ô"
When: " tô" then:"i"
When: " tôi" then:" "
When: " tôi " then:"c"
When: " tôi c" then:"h"
When: " tôi ch" then:"ư"
When: " tôi chư" then:"a"
When: " tôi chưa" then:" "
When: " tôi chưa " then:"b"
When: " tôi chưa b" then:"a"
When: " tôi chưa ba" then:"o"
When: " tôi chưa bao" then:" "
When: " tôi chưa bao " then:"g"
When: " tôi chưa bao g" then:"i"
When: " tôi chưa bao gi" then:"ờ"
When: " tôi chưa bao giờ" then:" "
When: " tôi chưa bao giờ " then:"t"
When: " tôi chưa bao giờ t" then:"ố"
When: " tôi chưa bao giờ tố" then:"t"
When: " tôi chưa bao giờ tốt" then:" "
When: " tôi chưa bao giờ tốt " then:"h"
When: " tôi chưa bao giờ tốt h" then:"ơ"
When: " tôi chưa bao giờ tốt hơ" then:"n"
When: " tôi chưa bao giờ tốt hơn" then:" "
When: " tôi chưa bao giờ tốt hơn " then:""
When: " tôi chưa bao giờ tốt hơn " then:""
When: " tôi ch

In [14]:
class SelfAttention(nn.Module):
    def __init__(self, in_dim, qk_dim, num_heads, context_size = None, dropout = 0.1): # context_size = None means not causal attention
        super().__init__()

        self.query = nn.Linear(in_dim, qk_dim * num_heads, bias=True)
        self.key = nn.Linear(in_dim, qk_dim * num_heads, bias=True)
        self.value = nn.Linear(in_dim, qk_dim * num_heads, bias=True)
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)

        self.context_size = context_size
        
        if self.context_size is not None:
            self.register_buffer('att_mask', torch.triu(torch.ones((context_size, context_size)), diagonal= 1).bool())

        self.proj = nn.Linear(qk_dim * num_heads, in_dim)

    def forward(self, x):
        B,T, _ = x.shape

        queries = self.query(x) # (B, T, qk_dim * num_heads)
        keys = self.key(x) # (B, T, qk_dim * num_heads)
        values = self.value(x) # (B, T, qk_dim * num_heads)

        queries = queries.reshape(B, T, self.num_heads, -1).transpose(1, 2) # (B, num_heads, T, qk_dim)
        keys = keys.reshape(B, T, self.num_heads, -1).transpose(1, 2) # (B, num_heads, T, qk_dim)
        values = values.reshape(B, T, self.num_heads, -1).transpose(1, 2) # (B, num_heads, T, qk_dim)

        att = queries @ keys.transpose(2, 3) * queries.shape[3]**(-0.5) # (B, num_heads, T, T) = (B, num_heads, T, qk_dim) x (B, num_heads, qk_dim, T)
        if self.context_size is not None:
            att = att.masked_fill(self.att_mask[:T,:T], float("-inf"))
        att_norm = F.softmax(att, dim = 3) # (B, num_heads, T, T)
        att_norm = self.dropout(att_norm)
        v = att_norm @ values #  (B, num_heads, T, qk_dim) = (B, num_heads, T, T) x (B, num_heads, T, qk_dim)
        v = v.transpose(1,2).reshape(B, T, -1)  # (B, T, qk_dim * num_heads)
        out = self.dropout(self.proj(v)) # (B, T, in_dim)
        return out


class CrossAttention(nn.Module):
    def __init__(self, in_enc_dim, in_dec_dim, qk_dim, num_heads, dropout = 0.1):
        super().__init__()
        self.query = nn.Linear(in_dec_dim, qk_dim * num_heads, bias=True) # Check if bias is necessary
        self.key = nn.Linear(in_enc_dim, qk_dim * num_heads, bias=True) # Check if bias is necessary
        self.value = nn.Linear(in_enc_dim, qk_dim * num_heads, bias=True) # Check if bias is necessary
        self.proj = nn.Linear(qk_dim * num_heads, in_dec_dim)
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_enc, x_dec):
        B,T_dec, _ = x_dec.shape
        B,T_enc, _ = x_enc.shape

        queries = self.query(x_dec) # (B, T, qk_dim * num_heads)
        keys = self.key(x_enc) # (B, T, qk_dim * num_heads)
        values = self.value(x_enc) # (B, T, qk_dim * num_heads)

        queries = queries.reshape(B, T_dec, self.num_heads, -1).transpose(1,2) # (B, num_heads, T_dec, qk_dim)
        keys = keys.reshape(B, T_enc, self.num_heads, -1).transpose(1,2) # (B, num_heads, T_enc, qk_dim)
        values = values.reshape(B, T_enc, self.num_heads, -1).transpose(1,2) # (B, num_heads, T_enc, qk_dim)

        att = queries @ keys.transpose(2, 3) * queries.shape[3]**(-0.5) # (B, num_heads, T_dec, T_enc) = (B, num_heads, T_dec, qk_dim) x (B, num_heads, qk_dim, T_enc)
        att_norm = F.softmax(att, dim = 3) # (B, num_heads, T_dec, T_enc)
        att_norm = self.dropout(att_norm)
        v = att_norm @ values # (B, num_heads, T_dec, qk_dim) = (B, num_heads, T_dec, T_enc) x (B, num_heads, T_enc, qk_dim)
        
        v = v.transpose(1,2).reshape(B, T_dec, -1)  # (B, T_dec, num_heads * qk_dim)
        out = self.dropout(self.proj(v))
        return out


class FeedForward(nn.Module):
    def __init__(self, in_dim, dropout=0.1):
        super().__init__()
        inner_dim = in_dim * 4
        self.first = nn.Linear(in_dim, inner_dim, bias = True)
        self.second = nn.Linear(inner_dim, in_dim, bias = True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.relu(self.first(x))
        x = self.second(x)
        x = self.dropout(x)
        return x

class Encoder(nn.Module):
    def __init__(self, in_dim, qk_dim, num_heads, dropout=0.1):
        super().__init__()
        self.sa = SelfAttention(in_dim, qk_dim, num_heads, dropout = dropout)
        self.ln1 = nn.LayerNorm(in_dim)
        self.ffn = FeedForward(in_dim, dropout = dropout)
        self.ln2 = nn.LayerNorm(in_dim)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class Decoder(nn.Module):
    def __init__(self, in_enc_dim, in_dec_dim, qk_dim, num_heads, context_size, dropout = 0.1):
        super().__init__()
        self.sa = SelfAttention(in_dec_dim, qk_dim, num_heads, context_size = context_size, dropout = dropout)
        self.ln1 = nn.LayerNorm(in_dec_dim)
        self.ca = CrossAttention(in_enc_dim, in_dec_dim, qk_dim, num_heads, dropout = dropout)
        self.ln2 = nn.LayerNorm(in_dec_dim)
        self.ffn = FeedForward(in_dec_dim, dropout = dropout)
        self.ln3 = nn.LayerNorm(in_dec_dim)


    def forward(self, x_enc, x_dec):
        x = x_dec + self.sa(self.ln1(x_dec))
        x = x + self.ca(x_enc, self.ln2(x))
        x = x + self.ffn(self.ln3(x))
        return x

class Transformer(nn.Module):
    def __init__(self, vocab_enc_size, emb_enc_dim, context_enc_size, vocab_dec_size, emb_dec_dim, context_dec_size, qk_dim, num_heads, num_layers, dropout = 0.1):
        super().__init__()
        assert num_layers > 0
        self.emb_enc = nn.Embedding(vocab_enc_size, emb_enc_dim)
        self.emb_dec = nn.Embedding(vocab_dec_size, emb_dec_dim)

        self.encoders = nn.ModuleList([Encoder(emb_enc_dim, qk_dim, num_heads,  dropout = dropout) for _ in range(num_layers)])
        self.decoders = nn.ModuleList([Decoder(emb_enc_dim, emb_dec_dim, qk_dim, num_heads, context_dec_size, dropout = dropout) for _ in range(num_layers)])
        self.linear = nn.Linear(emb_dec_dim, vocab_dec_size)

        self.register_buffer('pos_enc_emb', self.positional_encoding(emb_enc_dim, context_enc_size)) 
        self.register_buffer('pos_dec_emb', self.positional_encoding(emb_dec_dim, context_dec_size))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_enc, x_dec):
        x_enc_emb = self.dropout(self.emb_enc(x_enc) + self.pos_enc_emb[:x_enc.shape[1]]) # Broadcasting
        x_dec_emb = self.dropout(self.emb_dec(x_dec) + self.pos_dec_emb[:x_dec.shape[1]]) # Broadcasting

        for encoder in self.encoders:
            x_enc_emb = encoder(x_enc_emb)
            
        for decoder in self.decoders:
            x_dec_emb = decoder(x_enc_emb, x_dec_emb)

        out = self.linear(x_dec_emb)
        return out

    def positional_encoding(self, in_dim, length):
        pos = torch.arange(length)[:, None] # (length, 1)
        i = torch.arange(in_dim)[None, :] # (1, in_dim)
        #i = 10000 ** (i / in_dim)

        angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / in_dim)
        pe = pos * angle_rates  # (length, in_dim)

        #pe = pos / i # (length, in_dim)
        
        pe[:, ::2] = torch.sin(pe[:, ::2])
        pe[:, 1::2] = torch.cos(pe[:, 1::2])
        return pe
    
    def generate(self, x_enc, start_token_id, end_token_id, max_length=64):

        self.eval()
        with torch.no_grad():
            current = torch.tensor([[start_token_id]], device=x_enc.device)
            
            for _ in range(max_length):
                logits = self(x_enc, current)  # (1, t, vocab_size)
                logits = logits[:, -1, :]  # (1, vocab_size)
                probs = F.softmax(logits, dim=1)  # (1, vocab_size)
                next_token = torch.multinomial(probs, 1)  # (1, 1)
                current = torch.cat((current, next_token), dim=1)
                
                if next_token.item() == end_token_id:
                    break
                    
        return current
    

In [15]:
context_dec_size, context_enc_size = data_viet.shape[1] - 1, data_eng.shape[1]
vocab_vi_size, vocab_eng_size, context_dec_size, context_enc_size

(177, 104, 65, 59)

In [16]:
tf = Transformer(vocab_eng_size, emb_enc_size, context_enc_size, vocab_vi_size, emb_dec_size, context_dec_size, qk_dim, num_heads, num_layers)
tf.to(device)
optimizer = torch.optim.AdamW(tf.parameters(), lr)

In [17]:
losses = []
for iter in range(iterations):
    if iter % eval_interval == 0 or iter == iterations - 1:
        losses_i = estimate_loss(tf)
        losses.append(losses_i)
        print(f"step {iter}: train loss {losses_i['train']:.4f}, val loss {losses_i['val']:.4f}")
        
    (x_eng, x_viet), y = get_batch('train')
    logits = tf(x_eng, x_viet)
    B, T, C = logits.shape
    logits = logits.view(B*T, C)
    y = y.view(B*T)
    loss = F.cross_entropy(logits, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 5.5342, val loss 5.5332
step 250: train loss 0.8113, val loss 0.8030
step 500: train loss 0.6959, val loss 0.6807
step 750: train loss 0.6358, val loss 0.6324
step 1000: train loss 0.5868, val loss 0.5985
step 1250: train loss 0.5702, val loss 0.5799
step 1500: train loss 0.5557, val loss 0.5626
step 1750: train loss 0.5400, val loss 0.5330
step 2000: train loss 0.5141, val loss 0.5320
step 2250: train loss 0.5107, val loss 0.5140
step 2500: train loss 0.4938, val loss 0.4983
step 2750: train loss 0.4725, val loss 0.4867
step 3000: train loss 0.4673, val loss 0.4685
step 3250: train loss 0.4540, val loss 0.4651
step 3500: train loss 0.4519, val loss 0.4472
step 3750: train loss 0.4407, val loss 0.4477
step 4000: train loss 0.4335, val loss 0.4313
step 4250: train loss 0.4127, val loss 0.4126
step 4500: train loss 0.4072, val loss 0.4184
step 4750: train loss 0.4035, val loss 0.4136
step 5000: train loss 0.3926, val loss 0.3899
step 5250: train loss 0.3823, val loss 0

KeyboardInterrupt: 

In [18]:
state = {
    'epoch': iterations,
    'state_dict': tf.state_dict(),
    'optimizer': optimizer.state_dict(),
    'losses': losses
}
torch.save(state, './model.pt')

In [28]:
input = 'The crowd went wild.\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000\ue000'
out = tf.generate(torch.tensor(encode_en(input))[None,:].to(device), stoi_vi[token_start], stoi_vi[token_end], 1000)
#print(decode_vi(*out.tolist()))
out = decode_vi(*out.tolist())
print(out)

 Chiếc xe thời gian. 
