In [None]:
'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612, modified by wmathor
  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
              https://github.com/JayParks/transformer
              https://blog.csdn.net/BXD1314/article/details/126187598
              https://github.com/wmathor/nlp-tutorial/blob/master/5-1.Transformer/Transformer_Torch.ipynb
'''
import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import re
import wandb

# Authenticate with Weights & Biases and start a run in the "ww" project so that
# later wandb.log() calls have a run to record into.
# NOTE(review): the `%wandb login` line magic and `wandb.login()` look redundant —
# confirm the magic is registered in this environment and whether both are needed.
%wandb login
wandb.login()
wandb.init(project="ww")
# Read the raw corpus: one sample per line, "<source>\t<target>".
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
with open('/root/Multi-Lingual-IME/VtoW/translanguaging-IME/datasets/chinese_news_cangjie_mix.txt', 'r', encoding='utf-8') as file:
        raw_data = file.readlines()

new_dataset = []

# Turn each two-field line into an [enc_input, dec_input, dec_output] triple.
# Lines that do not split into exactly two tab-separated fields are skipped.
for line in raw_data:
    parts = line.strip().split('\t')
    if len(parts) == 2:
        source_text = parts[0]
        target_text = parts[1]
        
        sample = [
            source_text,  # enc_input
            f'S {target_text}',  # dec_input: target prefixed with the start token 'S'
            f'{target_text} E'  # dec_output: target followed by the end token 'E'
        ]
        
        new_dataset.append(sample)

# Persist the triples as tab-separated lines in pretext.txt (working directory).
with open('pretext.txt', 'w', encoding='utf-8') as file:
    for sample in new_dataset:
        file.write('\t'.join(sample) + '\n')
        
# Read the triples back and put spaces around every character in U+4E00-U+9FFF and
# around the punctuation 。，！？ in the decoder strings, so that a whitespace
# split yields one token per such character.
sentences = []
with open('pretext.txt', 'r', encoding='utf-8') as file:
    for line in file:
        parts = line.strip().split('\t')
        if len(parts) == 3:
            parts[1] = re.sub(r'([\u4e00-\u9fff。，！？])', r' \1 ', parts[1])
            parts[2] = re.sub(r'([\u4e00-\u9fff。，！？])', r' \1 ', parts[2])
            sentences.append(parts)

# print(sentences[0], sentences[1],sentences[2])
# NOTE: this assignment REPLACES the `sentences` list built from pretext.txt with
# three hard-coded [enc_input, dec_input, dec_output] samples, so everything that
# reads `sentences` afterwards only sees these three.
# The source strings end with a literal 'P' token (the padding symbol, index 0 in src_vocab).
sentences =[['ogf wfyr tgdi yrhjr zxah jrb fusmg hqnl ino nd hgr yrit smha yhml mfj ogd dtrg nmsu nlomo zxab kb omnr ebcd ytap zxai P', 'S  焦  點  對  話 ： 胡  耀  邦  之  子  告  誡  習  近  平  集  權  危  險  ，  有  何  深  意  ？ ', ' 焦  點  對  話 ： 胡  耀  邦  之  子  告  誡  習  近  平  集  權  危  險  ，  有  何  深  意  ？  E'],
            ['her gpd ornin yycb mig ewot abme  l jbtj yrnl a yonk ewot thm k P', 'S  各  地  舒  適  至  溫  暖   中  南  部  日  夜  溫  差  大 ', ' 各  地  舒  適  至  溫  暖   中  南  部  日  夜  溫  差  大  E'],
            ['(afhhh )mtmbc ocsh csh ssr oybc yjksj ypsm ytwg orse afhhh llml  tnfd ybuc aa zxah mf yryvo hgks bq hdi mtln P', 'S ( 影 ) 頭  份  分  局  偵  辦  虐  童  假  影  片   蘇  貞  昌 ： 不  該  動  用  私  刑 ', '( 影 ) 頭  份  分  局  偵  辦  虐  童  假  影  片   蘇  貞  昌 ： 不  該  動  用  私  刑  E']]
# sentences = [
#           # enc_input           dec_input         dec_output
# ["ogf wfyr tgdi yrhjr zxah jrb fusmg hqnl ino nd hgr yrit smha yhml mfj ogd dtrg nmsu nlomo zxab kb omnr ebcd ytap zxai","S 焦點對話：胡耀邦之子告誡習近平集權危險，有何深意？","焦點對話：胡耀邦之子告誡習近平集權危險，有何深意？ E"],
# ["her gpd ornin yycb mig ewot abme  l jbtj yrnl a yonk ewot thm k","S 各地舒適至溫暖 中南部日夜溫差大","各地舒適至溫暖 中南部日夜溫差大 E"]
# ]

# Vocabularies start with the special tokens; index 0 is reserved for padding ('P').
src_vocab = {'P': 0}
src_vocab_size = 1

tgt_vocab = {'P': 0, 'S': 1, 'E': 2}
idx2word = {0: 'P', 1: 'S', 2: 'E'}
tgt_vocab_size = 3

# Longest encoder / decoder-input sequence seen so far, counted in tokens.
src_len = 0
tgt_len = 0

# Walk every sample once: unseen whitespace-separated tokens get the next free id.
# Only enc_input and dec_input are scanned; 'E' (the extra dec_output token) is predefined.
for enc_input, dec_input, dec_output in sentences:
    src_tokens = enc_input.split()
    for token in src_tokens:
        if token in src_vocab:
            continue
        src_vocab[token] = src_vocab_size
        src_vocab_size += 1
    if len(src_tokens) > src_len:
        src_len = len(src_tokens)

    tgt_tokens = dec_input.split()
    for token in tgt_tokens:
        if token in tgt_vocab:
            continue
        tgt_vocab[token] = tgt_vocab_size
        idx2word[tgt_vocab_size] = token  # keep the reverse mapping in sync
        tgt_vocab_size += 1
    if len(tgt_tokens) > tgt_len:
        tgt_len = len(tgt_tokens)

print("src_vocab:", src_vocab)
print("tgt_vocab:", tgt_vocab)


# Padding Should be Zero
# src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
# src_vocab_size = len(src_vocab)

# tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
# idx2word = {i: w for i, w in enumerate(tgt_vocab)}
# tgt_vocab_size = len(tgt_vocab)

# src_len = 5 # enc_input max sequence length
# tgt_len = 6 # dec_input(=dec_output) max sequence length

# Transformer hyper-parameters (module-level; read by the model classes defined in this file).
d_model = 512  # embedding size, i.e. width of every token representation
d_ff = 2048 # hidden width of the position-wise feed-forward network
d_k = d_v = 64  # per-head dimension of K (= Q) and of V
n_layers = 6  # number of stacked layers in the encoder and in the decoder
n_heads = 8  # number of heads in multi-head attention (n_heads * d_k == d_model here)

In [18]:
def make_data(sentences, src_vocab, tgt_vocab):
    """Convert sentence triples into zero-padded index tensors.

    Args:
        sentences: sequence of ``[enc_input, dec_input, dec_output]`` strings;
            each string is tokenised with ``str.split()``.
        src_vocab: mapping token -> index for the encoder side.
        tgt_vocab: mapping token -> index for the decoder side.

    Returns:
        Tuple ``(enc_inputs, dec_inputs, dec_outputs)`` of ``torch.LongTensor``.
        ``enc_inputs`` is padded to the longest encoder sequence; ``dec_inputs``
        and ``dec_outputs`` are padded to one common length so they always have
        identical shapes. Padding uses index 0.

    Raises:
        KeyError: if a token is missing from the corresponding vocabulary.
    """
    enc_inputs, dec_inputs, dec_outputs = [], [], []

    for sample in sentences:
        enc_inputs.append([src_vocab[tok] for tok in sample[0].split()])
        dec_inputs.append([tgt_vocab[tok] for tok in sample[1].split()])
        dec_outputs.append([tgt_vocab[tok] for tok in sample[2].split()])

    max_enc_len = max((len(seq) for seq in enc_inputs), default=0)
    # The decoder width must cover BOTH decoder sequences: using only the
    # dec_input lengths leaves dec_output rows ragged (or mis-shaped) whenever
    # a dec_output is longer than every dec_input.
    max_dec_len = max((len(seq) for seq in dec_inputs + dec_outputs), default=0)

    def pad(rows, width):
        # Right-pad every row with the padding index 0 up to `width`.
        return [row + [0] * (width - len(row)) for row in rows]

    return (
        torch.LongTensor(pad(enc_inputs, max_enc_len)),
        torch.LongTensor(pad(dec_inputs, max_dec_len)),
        torch.LongTensor(pad(dec_outputs, max_dec_len)),
    )

# Index and pad the whole corpus once; these three tensors back the dataset defined below.
enc_inputs, dec_inputs, dec_outputs = make_data(sentences, src_vocab, tgt_vocab)

class MyDataSet(Data.Dataset):
    """Dataset over three parallel tensors indexed along their first dimension."""

    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super().__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs

    def __len__(self):
        # Number of samples == size of the first dimension of the encoder inputs.
        return len(self.enc_inputs)

    def __getitem__(self, idx):
        # One (enc_input, dec_input, dec_output) triple.
        return tuple(
            tensor[idx]
            for tensor in (self.enc_inputs, self.dec_inputs, self.dec_outputs)
        )

# Shuffled mini-batches of two samples each.
loader = Data.DataLoader(
    dataset=MyDataSet(enc_inputs, dec_inputs, dec_outputs),
    batch_size=2,
    shuffle=True,
)

In [19]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to sequence-first embeddings, then applies dropout."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        # Even feature columns carry the sine, odd columns the cosine.
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Stored as a buffer of shape [max_len, 1, d_model] so it broadcasts over the batch.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        returns: tensor of the same shape
        '''
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len])

def get_attn_pad_mask(seq_q, seq_k):
    '''
    Build the padding mask for attention from queries `seq_q` to keys `seq_k`.

    seq_q: [batch_size, len_q] token indices
    seq_k: [batch_size, len_k] token indices (len_q and len_k may differ)
    returns: bool tensor [batch_size, len_q, len_k]; True marks key positions
             holding the PAD index 0, i.e. the positions to be masked out
    '''
    _, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    key_is_pad = (seq_k.data == 0)[:, None, :]  # [batch_size, 1, len_k]
    # Every query row shares the same key mask.
    return key_is_pad.expand(batch_size, len_q, len_k)

def get_attn_subsequence_mask(seq):
    '''
    Build the look-ahead mask that hides future positions.

    seq: [batch_size, tgt_len]
    returns: uint8 CPU tensor [batch_size, tgt_len, tgt_len] with 1 strictly
             above the diagonal (future positions) and 0 elsewhere
    '''
    batch_size, tgt_len = seq.size(0), seq.size(1)
    all_ones = torch.ones([batch_size, tgt_len, tgt_len], dtype=torch.uint8)
    # diagonal=1 keeps only the entries strictly above the main diagonal.
    return torch.triu(all_ones, diagonal=1)

In [20]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V with masking."""

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, len_q, len_k]; non-zero / True entries are masked out
        returns: (context [batch_size, n_heads, len_q, d_v],
                  attn    [batch_size, n_heads, len_q, len_k])
        '''
        # Scale by the key dimension actually present in the tensors rather than a
        # notebook-level global, so the result cannot drift from a stale `d_k`.
        scale = math.sqrt(Q.size(-1))
        scores = torch.matmul(Q, K.transpose(-1, -2)) / scale  # [batch_size, n_heads, len_q, len_k]
        # A large negative score becomes ~0 probability after the softmax.
        scores = scores.masked_fill(attn_mask.bool(), -1e9)

        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
        return context, attn

class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and layer normalisation.

    Layer sizes come from the module-level `d_model`, `d_k`, `d_v` and `n_heads`.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        self.attention = ScaledDotProductAttention()
        # Registered once as a sub-module so its affine parameters are trained,
        # saved in the state dict and moved together with the model. Building a
        # fresh LayerNorm inside forward() would discard them on every call.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, len_q, len_k]
        returns: (output [batch_size, len_q, d_model],
                  attn   [batch_size, n_heads, len_q, len_k])
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # [batch_size, n_heads, len_v(=len_k), d_v]

        # The same mask is applied to every head.
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # [batch_size, n_heads, len_q, len_k]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = self.attention(Q, K, V, attn_mask)
        # Merge the heads back into one feature axis.
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)  # [batch_size, len_q, n_heads * d_v]
        output = self.fc(context)  # [batch_size, len_q, d_model]
        return self.layer_norm(output + residual), attn

class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward block (Linear -> ReLU -> Linear) with residual + LayerNorm.

    Layer sizes come from the module-level `d_model` and `d_ff`.
    """

    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
        # Registered once so its parameters are trained and follow the module's
        # device, instead of being re-created (and forced onto CUDA) per call.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        returns: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)  # [batch_size, seq_len, d_model]

class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by the position-wise feed-forward net."""

    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        returns: (enc_outputs [batch_size, src_len, d_model],
                  attn        [batch_size, n_heads, src_len, src_len])
        '''
        # Self-attention: the same tensor acts as query, key and value.
        attended, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        return self.pos_ffn(attended), attn

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, feed-forward net."""

    def __init__(self):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        returns: (dec_outputs   [batch_size, tgt_len, d_model],
                  dec_self_attn [batch_size, n_heads, tgt_len, tgt_len],
                  dec_enc_attn  [batch_size, n_heads, tgt_len, src_len])
        '''
        # 1) decoder attends to its own positions
        self_attended, dec_self_attn = self.dec_self_attn(
            dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask
        )
        # 2) queries from the decoder, keys/values from the encoder output
        cross_attended, dec_enc_attn = self.dec_enc_attn(
            self_attended, enc_outputs, enc_outputs, dec_enc_attn_mask
        )
        # 3) position-wise feed-forward
        return self.pos_ffn(cross_attended), dec_self_attn, dec_enc_attn

class Encoder(nn.Module):
    """Token embedding + positional encoding followed by `n_layers` encoder layers."""

    def __init__(self):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        returns: (enc_outputs [batch_size, src_len, d_model],
                  list with one [batch_size, n_heads, src_len, src_len] attention map per layer)
        '''
        embedded = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        # The positional encoding works sequence-first, hence the transposes around it.
        hidden = self.pos_emb(embedded.transpose(0, 1)).transpose(0, 1)
        pad_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]

        attn_per_layer = []
        for layer in self.layers:
            hidden, layer_attn = layer(hidden, pad_mask)
            attn_per_layer.append(layer_attn)
        return hidden, attn_per_layer

class Decoder(nn.Module):
    """Token embedding + positional encoding followed by `n_layers` decoder layers."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        returns: (dec_outputs [batch_size, tgt_len, d_model],
                  list of per-layer self-attention maps [batch_size, n_heads, tgt_len, tgt_len],
                  list of per-layer encoder-decoder attention maps [batch_size, n_heads, tgt_len, src_len])
        '''
        # Follow the device of the inputs instead of hard-coding CUDA.
        device = dec_inputs.device

        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        # The positional encoding works sequence-first, hence the transposes around it.
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1)  # [batch_size, tgt_len, d_model]

        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).to(device)  # [batch_size, tgt_len, tgt_len]
        # A position is masked when it is padding OR lies in the future.
        dec_self_attn_mask = dec_self_attn_pad_mask.bool() | dec_self_attn_subsequence_mask.bool()

        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)  # [batch_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

class Transformer(nn.Module):
    """Encoder-decoder model with a final linear projection onto the target vocabulary.

    All three sub-modules are moved to CUDA at construction time.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder().cuda()
        self.decoder = Decoder().cuda()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()

    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        returns: (logits [batch_size * tgt_len, tgt_vocab_size],
                  enc_self_attns, dec_self_attns, dec_enc_attns) where the last three
                  are per-layer lists of attention maps
        '''
        memory, enc_self_attns = self.encoder(enc_inputs)  # memory: [batch_size, src_len, d_model]
        hidden, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, memory)
        logits = self.projection(hidden)  # [batch_size, tgt_len, tgt_vocab_size]
        # Flatten batch and time so the logits line up with a flattened target vector.
        flat_logits = logits.view(-1, logits.size(-1))
        return flat_logits, enc_self_attns, dec_self_attns, dec_enc_attns

# Build the model on the GPU. The loss ignores targets equal to 0 (the padding
# index); the optimizer is plain SGD with learning rate 1e-3 and momentum 0.99.
model = Transformer().cuda()
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

In [None]:
%pip install scikit-learn
from sklearn.metrics import precision_score, recall_score, f1_score

# Training loop: 1000 passes over `loader`; loss and token-level precision / recall / F1
# are logged to Weights & Biases once per batch.
for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        targets = dec_outputs.view(-1)
        loss = criterion(outputs, targets)

        # Score only real tokens: the loss already ignores padding (index 0), so
        # counting padded positions would make the metrics disagree with what is
        # actually being optimised.
        not_pad = targets != 0
        y_true = targets[not_pad].cpu().numpy()
        y_pred = outputs.argmax(dim=-1)[not_pad].cpu().numpy()

        # zero_division=0 scores classes with no predicted / true samples as 0
        # instead of emitting a warning on every batch.
        precision = precision_score(y_true, y_pred, average="weighted", zero_division=0)
        recall = recall_score(y_true, y_pred, average="weighted", zero_division=0)
        f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)

        wandb.log({"loss": loss.item(), "precision": precision, "recall": recall, "f1": f1})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


In [22]:
def greedy_decoder(model, enc_input, start_symbol, max_len=256, end_symbol=None, verbose=True):
    """
    Generate the decoder input greedily (beam search with K=1).

    At inference time the target sequence is unknown, so it is built one token at a
    time: the sequence so far is fed to the decoder and the most likely next token
    is appended, until the end symbol is predicted or `max_len` tokens were produced.
    Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding

    :param model: Transformer model exposing `encoder`, `decoder` and `projection`
    :param enc_input: The encoder input, shape [1, src_len]
    :param start_symbol: Index of the start symbol (pass tgt_vocab["S"])
    :param max_len: Upper bound on the number of decoder tokens; guarantees termination
                    even if the model never predicts the end symbol
    :param end_symbol: Index of the end symbol; defaults to tgt_vocab["E"]
    :param verbose: If True, print every predicted token index
    :return: The decoder input, shape [1, n]; it starts with `start_symbol` and does
             not contain the end symbol
    """
    if end_symbol is None:
        end_symbol = tgt_vocab["E"]

    device = enc_input.device
    # Decode in eval mode (dropout off) without autograd, then restore the mode
    # the model was in so a following training step is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            enc_outputs, enc_self_attns = model.encoder(enc_input)
            dec_input = torch.zeros(1, 0, dtype=enc_input.dtype, device=device)
            next_symbol = start_symbol
            for _ in range(max_len):
                step = torch.tensor([[next_symbol]], dtype=enc_input.dtype, device=device)
                dec_input = torch.cat([dec_input, step], -1)
                dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
                projected = model.projection(dec_outputs)
                # Most likely token at every position; only the last one is new.
                prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
                next_word = prob.data[-1]
                if verbose:
                    print(next_word)
                next_symbol = next_word.item()
                if next_symbol == end_symbol:
                    break
    finally:
        model.train(was_training)
    return dec_input

In [23]:
# # Test
# enc_inputs, _, _ = next(iter(loader))
# enc_inputs = enc_inputs.cuda()
# for i in range(len(enc_inputs)):
#     greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])
#     predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)
#     predict = predict.data.max(1, keepdim=True)[1]
#     print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])

# Smoke test: translate one short source string with the trained model.
test_data = [['jrb fusmg', '', '']]
enc_input = test_data[0][0]

# Fail with a readable message if the input contains tokens the vocabulary has never seen.
unknown_tokens = [n for n in enc_input.split() if n not in src_vocab]
if unknown_tokens:
    raise KeyError(f"tokens not in src_vocab: {unknown_tokens}")
enc_input_tensor = torch.LongTensor([src_vocab[n] for n in enc_input.split()]).unsqueeze(0).cuda()

# Inference must run in eval mode (dropout off) and needs no gradients; the
# previous mode is restored afterwards so re-running the training cell still trains.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        greedy_dec_input = greedy_decoder(model, enc_input_tensor, start_symbol=tgt_vocab["S"])
        predict, _, _, _ = model(enc_input_tensor, greedy_dec_input)
finally:
    model.train(was_training)

# predict: [tgt_len, tgt_vocab_size] -> most likely index per position.
predict = predict.data.max(1, keepdim=True)[1]
# squeeze(-1) keeps a 1-D tensor even when only one token was generated.
predicted_tokens = [idx2word[n.item()] for n in predict.squeeze(-1)]

print(enc_input, '->', predicted_tokens)


tensor(3, device='cuda:0')
tensor(4, device='cuda:0')
tensor(5, device='cuda:0')
tensor(6, device='cuda:0')
tensor(7, device='cuda:0')
tensor(8, device='cuda:0')
tensor(9, device='cuda:0')
tensor(10, device='cuda:0')
tensor(11, device='cuda:0')
tensor(12, device='cuda:0')
tensor(13, device='cuda:0')
tensor(14, device='cuda:0')
tensor(15, device='cuda:0')
tensor(16, device='cuda:0')
tensor(17, device='cuda:0')
tensor(18, device='cuda:0')
tensor(19, device='cuda:0')
tensor(20, device='cuda:0')
tensor(21, device='cuda:0')
tensor(22, device='cuda:0')
tensor(23, device='cuda:0')
tensor(24, device='cuda:0')
tensor(25, device='cuda:0')
tensor(26, device='cuda:0')
tensor(27, device='cuda:0')
tensor(2, device='cuda:0')
jrb fusmg -> ['焦', '點', '對', '話', '：', '胡', '耀', '邦', '之', '子', '告', '誡', '習', '近', '平', '集', '權', '危', '險', '，', '有', '何', '深', '意', '？', 'E']
