In [56]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
from sklearn.metrics import recall_score, f1_score, precision_score, accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import os
from jiwer import wer
import yaml
import json 
import math

In [57]:
####### Config #######
# Every hyper-parameter and path used below is read from <config_path>/config.yml.
config_path = "conf_11_gelu"
config_file = os.path.join(config_path, "config.yml")
with open(config_file, 'r') as conf:
    # safe_load is PyYAML's shorthand for load(..., Loader=SafeLoader).
    config = yaml.safe_load(conf)


In [58]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalisation.

    Computes ``LayerNorm(X + Dropout(Y))`` where ``X`` is the sub-layer input
    and ``Y`` is the sub-layer output.

    Args:
        norm_shape: size of the trailing dimension normalised by ``nn.LayerNorm``.
        dropout: dropout probability applied to ``Y`` before the addition.
    """

    def __init__(self, norm_shape: int, dropout=0.2):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        sublayer_out = self.dropout(Y)
        residual_sum = sublayer_out + X
        return self.ln(residual_sum)

class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network: Linear -> activation -> Dropout -> Linear.

    Args:
        input_dim: size of the input (and output) feature dimension.
        hidden_ff_dim: size of the intermediate projection.
        dropout: dropout probability applied after the activation.
        use_gelu: use GELU when true, ReLU otherwise.
    """

    def __init__(self, input_dim: int, hidden_ff_dim: int, dropout=0.2, use_gelu=True):
        super().__init__()
        # linear1 is created before linear2 so parameter order/initialisation is unchanged.
        self.linear1 = nn.Linear(input_dim, hidden_ff_dim)
        self.dropout = nn.Dropout(dropout)
        if use_gelu:
            self.activation = nn.GELU()
        else:
            self.activation = nn.ReLU()
        self.linear2 = nn.Linear(hidden_ff_dim, input_dim)

    def forward(self, x):
        expanded = self.linear1(x)
        activated = self.activation(expanded)
        regularised = self.dropout(activated)
        return self.linear2(regularised)

class ClassificationHead(nn.Module):
    """Per-character classifier producing a probability for each candidate reading.

    Pipeline: Linear -> ReLU -> Dropout -> Linear -> Softmax over the last dim.

    Args:
        input_dim: size of the incoming feature vector.
        head_hidden_dim: size of the hidden projection.
        num_readings: number of output classes (candidate readings).
        dropout: dropout probability applied after the activation.
        use_gelu: accepted for signature parity with the other modules but not
            used here; the activation is always ReLU.
            NOTE(review): confirm whether this is intentional.
    """

    def __init__(self, input_dim: int, head_hidden_dim: int, num_readings: int, dropout=0.2, use_gelu=True):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(input_dim, head_hidden_dim)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(head_hidden_dim, num_readings)
        self.output_prob = nn.Softmax(dim=-1)

    def forward(self, x):
        hidden = self.activation(self.linear1(x))
        scores = self.linear2(self.dropout(hidden))
        return self.output_prob(scores)
    
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Even feature indices receive ``sin(pos * div_term)`` and odd indices
    ``cos(pos * div_term)``. The table is stored as the buffer ``pe`` with
    shape ``(1, seq_len, d_model)``.

    Args:
        d_model: feature dimension of the inputs (odd values are supported).
        seq_len: maximum sequence length the table is pre-computed for.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angles; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer: moves with .to(device), is saved in the
        # state dict, and is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x is indexed as (batch, seq_len, d_model); only the first x.shape[1]
        # positions of the table are used. Buffers carry no gradient.
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)
    
class ShrinkNorm(nn.Module):
    """Project features to ``output_dim`` through a small MLP, then layer-normalise.

    Pipeline: Linear -> activation -> Dropout -> Linear -> LayerNorm.

    Args:
        input_dim: size of the incoming feature dimension.
        shrink_norm_hidden: size of the intermediate projection.
        output_dim: size of the produced feature dimension.
        dropout: dropout probability applied after the activation.
        use_gelu: use GELU when true, ReLU otherwise.
    """

    def __init__(self, input_dim: int, shrink_norm_hidden: int, output_dim: int, dropout=0.2, use_gelu=True):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(input_dim, shrink_norm_hidden)
        if use_gelu:
            self.activation = nn.GELU()
        else:
            self.activation = nn.ReLU()
        self.linear2 = nn.Linear(shrink_norm_hidden, output_dim)
        self.ln = nn.LayerNorm(output_dim)

    def forward(self, x):
        hidden = self.activation(self.linear1(x))
        projected = self.linear2(self.dropout(hidden))
        return self.ln(projected)

class TransformerEncoderBlock(nn.Module):
    """The Transformer encoder block.

    Self-attention and a position-wise feed-forward network, each wrapped in
    a residual connection plus layer normalisation (``AddNorm``).

    Args:
        model_hidden_dim: feature dimension of the block's input and output.
        ffn_num_hiddens: hidden size of the feed-forward network.
        num_heads: number of attention heads.
        dropout: dropout probability used by attention, the FFN and both AddNorm layers.
        use_gelu: forwarded to the feed-forward network to select GELU or ReLU.
    """
    def __init__(self, model_hidden_dim, ffn_num_hiddens, num_heads, dropout, use_gelu=True):
        super().__init__()
        self.multihead_attention = nn.MultiheadAttention(embed_dim=model_hidden_dim, num_heads=num_heads,
                                                         batch_first=True, dropout=dropout)

        self.addnorm1 = AddNorm(model_hidden_dim, dropout)
        # Keyword arguments on purpose: the previous positional call bound
        # `use_gelu` to FeedForwardNetwork's `dropout` parameter, so the
        # configured dropout was ignored and `use_gelu` never reached the FFN.
        # Parameter names and shapes are unchanged, so existing checkpoints
        # still load.
        self.ffn = FeedForwardNetwork(model_hidden_dim, ffn_num_hiddens, dropout=dropout, use_gelu=use_gelu)
        self.addnorm2 = AddNorm(model_hidden_dim, dropout)

    def forward(self, X, key_padding_mask):
        """Apply self-attention then the FFN.

        Args:
            X: input features; batch-first, as configured on the attention layer.
            key_padding_mask: mask forwarded to ``nn.MultiheadAttention``
                (or None for no masking).
        """
        # [0] selects the attention output; the second element holds the attention weights.
        attn_out = self.multihead_attention(X, X, X, key_padding_mask=key_padding_mask)[0]
        Y = self.addnorm1(X, attn_out)
        return self.addnorm2(Y, self.ffn(Y))
    
class TransformerEncoder(nn.Module):
    """The Transformer encoder: positional encoding followed by ``num_blks`` encoder blocks."""
    def __init__(self, model_hidden_dim, hidden_ff_dim,
                 num_heads, num_blks, max_len, dropout, use_gelu=True):
        super().__init__()
        self.model_hidden_dim = model_hidden_dim
        self.pos_encoding = PositionalEncoding(d_model=model_hidden_dim, seq_len=max_len, dropout=dropout)
        # nn.Sequential is used purely as a named container ("block0", "block1", ...);
        # the blocks are invoked one by one in forward() because each call also
        # takes the padding mask.
        self.blks = nn.Sequential()
        for blk_idx in range(num_blks):
            encoder_block = TransformerEncoderBlock(
                model_hidden_dim, hidden_ff_dim, num_heads, dropout, use_gelu)
            self.blks.add_module(f"block{blk_idx}", encoder_block)

    def forward(self, X, key_padding_mask):
        hidden = self.pos_encoding(X)
        for encoder_block in self.blks:
            hidden = encoder_block(hidden, key_padding_mask)
        return hidden

class SinoVietnameseTranslator(nn.Module):
    """BERT encoder + Transformer encoder + one classification head per ambiguous character.

    Each vocabulary entry with more than one candidate reading gets its own
    ``ClassificationHead``; entries with a single reading need no prediction.

    Args:
        tokenizer: tokenizer providing ``pad_token_id`` and ``convert_ids_to_tokens``.
        base_model: BERT-style model exposing ``config.hidden_size`` and
            returning ``last_hidden_state``.
        vocab: mapping from source token to its list of candidate readings.
        hidden_ff_dim: hidden size of the encoder feed-forward networks.
        model_hidden_dim: feature size used by the Transformer encoder.
        head_hidden_dim: hidden size of every classification head.
        shrink_norm_hidden: hidden size of the BERT -> encoder projection.
        max_num_spellings: width of the last dimension of the output tensor.
        train_bert_param: whether the BERT parameters require gradients.
        max_len: maximum sequence length for the positional encoding.
        num_heads: number of attention heads per encoder block.
        num_blks: number of encoder blocks.
        dropout: dropout probability used throughout.
        use_gelu: activation switch forwarded to the sub-modules.
    """
    def __init__(self, tokenizer, base_model, vocab, hidden_ff_dim=1024, model_hidden_dim=256,
                 head_hidden_dim=128, shrink_norm_hidden=512, max_num_spellings=7, train_bert_param=True,
                 max_len=512, num_heads=8, num_blks=6, dropout=0.2, use_gelu=True):
        super(SinoVietnameseTranslator, self).__init__()
        self.tokenizer = tokenizer
        self.bert = base_model
        self.vocab = vocab
        self.max_num_spellings = max_num_spellings
        self.max_len = max_len

        # Freeze or unfreeze the whole BERT backbone.
        for param in self.bert.parameters():
            param.requires_grad = train_bert_param

        self.shrink_norm = ShrinkNorm(self.bert.config.hidden_size, shrink_norm_hidden, model_hidden_dim, dropout, use_gelu)
        self.encoder = TransformerEncoder(model_hidden_dim, hidden_ff_dim, num_heads, num_blks, max_len, dropout, use_gelu)

        # One head per token that has several candidate readings.
        self.classification_heads = nn.ModuleDict()
        for sino_word, viet_spellings in self.vocab.items():
            num_readings = len(viet_spellings)
            if num_readings > 1:
                self.classification_heads[sino_word] = ClassificationHead(model_hidden_dim, head_hidden_dim, num_readings, dropout, use_gelu)

    def forward(self, input_ids, attention_mask=None):
        """Score every candidate reading of every non-padding token.

        Args:
            input_ids: integer tensor of shape (batch, seq_len).
            attention_mask: optional tensor, non-zero for real tokens; when
                omitted no key-padding mask is applied in the encoder.

        Returns:
            Float tensor of shape (batch, seq_len, max_num_spellings). Slots
            default to -1.0; for tokens with a head the first ``num_readings``
            slots hold the head's softmax output; other non-padding tokens get
            1.0 in slot 0 so that argmax selects index 0.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = self.shrink_norm(outputs.last_hidden_state)

        # Default to "no masking" so the call also works without an attention
        # mask (previously key_padding_mask was left unbound in that case).
        key_padding_mask = None
        if attention_mask is not None:
            # MultiheadAttention expects True at positions to IGNORE, hence the inversion.
            key_padding_mask = ~attention_mask.bool()

        projected_output = self.encoder(sequence_output, key_padding_mask)

        batch_size, max_len = input_ids.size()
        predictions = torch.full((batch_size, max_len, self.max_num_spellings), -1.0, device=input_ids.device)

        pad_token_id = self.tokenizer.pad_token_id
        # One bulk transfer to Python ints instead of one .item() call per token.
        token_rows = input_ids.tolist()
        for i, row in enumerate(token_rows):
            for j, token_id in enumerate(row):
                if token_id == pad_token_id:
                    continue

                sino_word = self.tokenizer.convert_ids_to_tokens(token_id)

                if sino_word in self.classification_heads:
                    probs = self.classification_heads[sino_word](projected_output[i, j])
                    predictions[i, j, :len(probs)] = probs
                else:
                    predictions[i, j, 0] = 1.0

        return predictions

In [59]:
class SinoVietnameseDataset(Dataset):
    """Dataset of (source sentence, space-separated readings) pairs.

    Every item is padded to ``max_len`` and exposes ``input_ids``, ``labels``
    and ``attention_mask`` tensors. A label is the index of the reference
    reading inside ``vocab[token]`` for tokens with several candidate
    readings, and -1 everywhere else (single-reading tokens, tokens missing
    from ``vocab`` and padding).
    """

    def __init__(self, tokenizer, data, vocab, max_len=512):
        self.tokenizer = tokenizer
        self.data = data
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sino_sent, viet_sent = self.data[idx]
        token_ids = self.tokenizer.encode(sino_sent, add_special_tokens=False,
                                          max_length=self.max_len, truncation=True)
        reference_readings = viet_sent.split()
        pad_count = self.max_len - len(token_ids)

        labels = []
        for position, token_id in enumerate(token_ids):
            token = self.tokenizer.convert_ids_to_tokens(token_id)
            label = -1
            if token in self.vocab:
                candidates = self.vocab[token]
                if len(candidates) > 1:
                    # The i-th token is matched with the i-th reference reading.
                    label = candidates.index(reference_readings[position])
            labels.append(label)

        input_ids = token_ids + [self.tokenizer.pad_token_id] * pad_count
        labels += [-1] * (self.max_len - len(labels))  # Padding
        attention_mask = [1] * len(token_ids) + [0] * pad_count

        return {
            "input_ids": torch.tensor(input_ids),
            "labels": torch.tensor(labels),
            "attention_mask": torch.tensor(attention_mask),
        }


In [60]:
def load_data(data_file):
    """Read ``<sino sentence>,<viet sentence>`` pairs from a UTF-8 text file.

    Lines without a comma are skipped. Only the FIRST comma is treated as the
    delimiter, so a comma inside the second field no longer raises
    ``ValueError`` during unpacking.

    Args:
        data_file: path of the text file, one pair per line.

    Returns:
        List of ``(sino_sent, viet_sent)`` tuples in file order.
    """
    data = []
    with open(data_file, 'r', encoding='utf-8') as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in f:
            if ',' not in line:
                continue
            sino_sent, viet_sent = line.strip().split(',', 1)
            data.append((sino_sent, viet_sent))
    return data

# Test split: one "<sino sentence>,<viet sentence>" pair per line, parsed by load_data.
test_data_path = "data/test.txt"
test_data = load_data(test_data_path)

In [61]:
# Decode both JSON files explicitly as UTF-8 rather than with the platform's
# default encoding, matching how the other text files in this notebook are
# read and written.
with open('vocab/vocab.json', 'r', encoding='utf-8') as vocab_file, \
        open('vocab/sino_viet_words.json', 'r', encoding='utf-8') as words_file:
    base_vocab = json.load(vocab_file)
    sino_viet_words = json.load(words_file)

In [62]:
# Model Config
# Name/path of the pretrained BERT checkpoint, taken from the YAML config.
bert_model = config['model_config']['bert_model'] 

# Extend the pretrained tokenizer with the extra tokens listed in sino_viet_words.
base_tokenizer = BertTokenizer.from_pretrained(bert_model)
base_tokenizer.add_tokens(sino_viet_words)

# Resize the embedding matrix so it matches the tokenizer's new length.
base_model = BertModel.from_pretrained(bert_model)
base_model.resize_token_embeddings(len(base_tokenizer))

Embedding(23683, 768)

In [63]:
# Data Config
batch_size = config['data_config']['batch_size']
max_len = config['data_config']['max_len']

# Every item of the dataset is padded to max_len (see SinoVietnameseDataset).
test_dataset = SinoVietnameseDataset(base_tokenizer, test_data, base_vocab, max_len)

print(f"Test set: {len(test_dataset)}")

# shuffle=False keeps the evaluation order identical to the order of the test file.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Test batch num: {len(test_loader)}")

Test set: 17042
Test batch num: 1066


In [64]:
# Model config
# Pull every architecture hyper-parameter out of the 'model_config' section once.
model_cfg = config['model_config']
hidden_ff_dim = model_cfg['hidden_ff_dim']
model_hidden_dim = model_cfg['model_hidden_dim']
shrink_norm_hidden = model_cfg['shrink_norm_hidden']
head_hidden_dim = model_cfg['head_hidden_dim']
max_num_spellings = model_cfg['max_num_spellings']
train_bert_param = model_cfg['train_bert_param']
num_heads = model_cfg['num_heads']
num_blks = model_cfg['num_blks']
dropout = model_cfg['dropout']
use_gelu = model_cfg['use_gelu']

model = SinoVietnameseTranslator(
    base_tokenizer,
    base_model,
    base_vocab,
    hidden_ff_dim=hidden_ff_dim,
    model_hidden_dim=model_hidden_dim,
    shrink_norm_hidden=shrink_norm_hidden,
    head_hidden_dim=head_hidden_dim,
    max_num_spellings=max_num_spellings,
    train_bert_param=train_bert_param,
    num_heads=num_heads,
    use_gelu=use_gelu,
    max_len=max_len,
    num_blks=num_blks,
    dropout=dropout,
)

# Parameter count in millions (trainable and frozen alike).
num_param = sum(p.nelement() for p in model.parameters()) / 1000000
print(f"{num_param:.1f}M params.")

152.5M params.


In [65]:
# One accumulator per character that has more than one candidate reading;
# decode_predictions appends to these lists while the test set is processed.
pred_records = {
    sino_char: {'labels': [], 'preds': [], 'num_appear': 0}
    for sino_char, viet_readings in base_vocab.items()
    if len(viet_readings) > 1
}

def decode_predictions(predictions, labels, input_ids, tokenizer, vocab, total_char, records=None):
    """Turn predicted/label indices of one batch back into readable sentences.

    Padding positions are skipped. For tokens with a single reading that
    reading is emitted on both sides. For tokens with several readings the
    predicted and reference readings are emitted and, when they differ, both
    are wrapped in ``#...#`` so mismatches stand out. Tokens missing from
    ``vocab`` are copied through unchanged on both sides, which matches the
    dataset and the model, both of which tolerate such tokens.

    Args:
        predictions: integer tensor (batch, seq_len) of predicted reading indices.
        labels: integer tensor (batch, seq_len) of reference indices (-1 where
            no prediction is required).
        input_ids: integer tensor (batch, seq_len) of token ids.
        tokenizer: object providing ``pad_token_id`` and ``convert_ids_to_tokens``.
        vocab: mapping from token to its list of candidate readings.
        total_char: running count of non-padding tokens seen so far.
        records: per-character accumulator (``labels``/``preds``/``num_appear``)
            updated in place for multi-reading tokens. Defaults to the
            notebook-level ``pred_records`` for backward compatibility.

    Returns:
        ``(sino_viet_sentences, ground_truth_sentences, predicted_sentences, total_char)``
        with ``total_char`` increased by the number of non-padding tokens in the batch.
    """
    if records is None:
        # Backward-compatible default: accumulate into the notebook global.
        records = pred_records

    predicted_sentences = []
    ground_truth_sentences = []
    sino_viet_sentences = []
    for i in range(predictions.size(0)):
        decoded_sentence = []
        ground_truth = []
        sino_viet_sentence = []
        for j in range(predictions.size(1)):
            token = input_ids[i, j].item()
            if token == tokenizer.pad_token_id:
                continue

            sino_word = tokenizer.convert_ids_to_tokens(token)
            sino_viet_sentence.append(sino_word)
            total_char = total_char + 1

            pred_index = predictions[i, j].item()
            label_index = labels[i, j].item()

            readings = vocab.get(sino_word)
            if readings is None:
                # Token without a vocabulary entry: nothing to predict, copy it through.
                pred_viet_spelling = label_viet_spelling = sino_word
            elif len(readings) == 1:
                assert pred_index == 0, "pred_index is not 0 in case of 1 spelling."
                assert label_index == -1, "label_index is not -1 in case of 1 spelling."

                pred_viet_spelling = label_viet_spelling = readings[0]
            else:
                record = records[sino_word]
                record['labels'].append(label_index)
                record['preds'].append(pred_index)
                record['num_appear'] += 1

                if pred_index != label_index:
                    pred_viet_spelling = f"#{readings[pred_index]}#"
                    label_viet_spelling = f"#{readings[label_index]}#"
                else:
                    pred_viet_spelling = label_viet_spelling = readings[label_index]

            decoded_sentence.append(pred_viet_spelling)
            ground_truth.append(label_viet_spelling)

        sino_viet_sentences.append("".join(sino_viet_sentence))
        predicted_sentences.append(" ".join(decoded_sentence))
        ground_truth_sentences.append(" ".join(ground_truth))

    return sino_viet_sentences, ground_truth_sentences, predicted_sentences, total_char

In [66]:
def save_test_results(test_accuracy, test_wer, all_source_sentences,
                      all_ground_truths, all_predictions, test_results_dir, total_char):
    """Write the evaluation summary and per-sentence results to ``test.txt``.

    The file starts with three header lines (accuracy, WER, character count)
    followed by one ``source,ground_truth,prediction`` line per sentence.

    Args:
        test_accuracy: accuracy value to report.
        test_wer: word error rate to report.
        all_source_sentences: source sentences.
        all_ground_truths: reference sentences, aligned with the sources.
        all_predictions: predicted sentences, aligned with the sources.
        test_results_dir: existing directory that receives ``test.txt``.
        total_char: total number of characters to report.
    """
    test_results_file_dir = os.path.join(test_results_dir, 'test.txt')
    # The sentences are written as UTF-8 explicitly so the file does not
    # depend on the platform's default encoding, consistent with the other
    # files written by this notebook.
    with open(test_results_file_dir, 'w', encoding='utf-8') as f:
        f.write(f'Test accuracy: {test_accuracy}\nTest WER: {test_wer}\nTotal number of characters: {total_char}\n')
        for source, ground_truth, prediction in zip(all_source_sentences, all_ground_truths, all_predictions):
            f.write(f'{source},{ground_truth},{prediction}\n')

    print(f"Test results saved successfully into {test_results_file_dir}")

In [67]:
def test(model, test_dataloader, model_load_path, config_folder_dir):
    """Evaluate a saved checkpoint on the test set and write a report.

    Loads ``checkpoint['model_state_dict']`` into ``model``, runs it over
    ``test_dataloader`` without gradients, prints the number of correct and
    total predictions, and stores the report through ``save_test_results``
    under ``<config_folder_dir>/test_results/``.

    Accuracy is measured only over positions whose label is not -1; the word
    error rate is computed with ``wer`` over the decoded sentences.

    Args:
        model: model exposing ``tokenizer`` and ``vocab`` attributes.
        test_dataloader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        model_load_path: path of the checkpoint file; must not be None.
        config_folder_dir: directory that receives the ``test_results`` folder.

    Raises:
        AssertionError: if ``model_load_path`` is None.
    """
    assert model_load_path is not None, "No model to load"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # map_location remaps the stored tensors onto the device chosen above, so
    # a checkpoint saved on a GPU can also be restored on a CPU-only machine.
    checkpoint = torch.load(model_load_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    tokenizer = model.tokenizer
    vocab = model.vocab

    test_results_dir = os.path.join(config_folder_dir, "test_results/")
    os.makedirs(test_results_dir, exist_ok=True)

    model.eval()
    correct_predictions = 0  # calculate accuracies over sino words that have multiple viet spellings only
    total_predictions = 0
    all_source_sentences = []
    all_ground_truths = []
    all_predictions = []
    total_char = 0
    with torch.no_grad():
        test_iterator = tqdm(test_dataloader, desc="Testing", unit="batch")
        for batch in test_iterator:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)

            predictions = torch.argmax(outputs, dim=-1)
            # Only positions that carry a real label take part in the accuracy.
            mask = labels != -1
            correct_predictions += (predictions[mask] == labels[mask]).sum().item()
            total_predictions += mask.sum().item()

            batch_source_sentences, batch_ground_truths, batch_predictions, total_char = decode_predictions(
                predictions, labels, input_ids, tokenizer, vocab, total_char)
            all_predictions.extend(batch_predictions)
            all_ground_truths.extend(batch_ground_truths)
            all_source_sentences.extend(batch_source_sentences)

        test_accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
        test_wer = wer(all_ground_truths, all_predictions)
        print(f"Correct preds: {correct_predictions}")
        print(f"Total preds: {total_predictions}")
        save_test_results(test_accuracy, test_wer, all_source_sentences,
                          all_ground_truths, all_predictions, test_results_dir, total_char)



In [68]:
# A config value equal to the string 'None' is treated as "no checkpoint configured".
model_load_path = None if config['training_config']['model_load_path'] == 'None' else config['training_config']['model_load_path']
print((model_load_path))
# Runs the evaluation; the report is written under <config_path>/test_results/.
test(model, test_loader, model_load_path, config_path)

conf_11_gelu\running\saved_model\sivi_model_epoch_76.pt


  attn_output = torch.nn.functional.scaled_dot_product_attention(
Testing: 100%|██████████| 1066/1066 [05:23<00:00,  3.29batch/s]


Correct preds: 52044
Total preds: 54550
Test results saved successfully into conf_11_gelu\test_results/test.txt


In [69]:
def cal_f1_recall(labels, predictions, num_readings):
    """Compute support-weighted precision, recall, F1 and accuracy.

    Args:
        labels: reference reading indices.
        predictions: predicted reading indices, aligned with ``labels``.
        num_readings: number of candidate readings of the character. Kept for
            call compatibility; "weighted" averaging was already used for
            every value, so it does not influence the result.

    Returns:
        ``(precision, recall, f1, accu)``.
    """
    # zero_division=0 yields the same 0.0 the default ("warn") produces for
    # classes without predicted/true samples, but without emitting one
    # UndefinedMetricWarning per call.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average="weighted", zero_division=0)
    accu = accuracy_score(labels, predictions)
    return precision, recall, f1, accu

In [70]:
# Filter pred_records to keep only entries where 'num_appear' > 0
pred_records = {sino_char: data for sino_char, data in pred_records.items() if data['num_appear'] > 0}

# Denominator for the per-character frequency: the total number of
# multi-reading tokens recorded during testing. It is derived from the records
# instead of being hard-coded (previously the literal 54550, the "Total preds"
# of one particular run), so it stays correct for other test sets or
# checkpoints. NOTE: this assumes pred_records was filled by exactly one
# test() pass.
total_multi_reading_tokens = sum(data['num_appear'] for data in pred_records.values())

# Initialize pred_results dictionary
pred_results = {}

# Process each filtered entry in pred_records
for sino_viet_char, data in pred_records.items():
    num_readings = len(base_vocab[sino_viet_char])
    labels = data['labels']
    preds = data['preds']
    precision, recall, f1, accuracy = cal_f1_recall(labels, preds, num_readings)
    num_appear = data['num_appear']
    freq = num_appear / total_multi_reading_tokens

    pred_results[sino_viet_char] = {
        'num_readings': num_readings,
        'num_appear': num_appear,
        'freq': freq,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

In [71]:
# Persist the raw per-character label/prediction lists next to the test report.
pred_records_dir = os.path.join(config_path, f"test_results/")
os.makedirs(pred_records_dir, exist_ok=True)
# ensure_ascii=False keeps the non-ASCII keys readable in the UTF-8 output file.
with open(f"{pred_records_dir}/pred_records.json", 'w', encoding="utf-8") as pred_records_file:
    json.dump(pred_records, pred_records_file, ensure_ascii=False, indent=4)

In [72]:
# Persist the per-character metrics (accuracy, precision, recall, F1, frequency).
pred_results_dir = os.path.join(config_path, f"test_results/")
os.makedirs(pred_results_dir, exist_ok=True)
# ensure_ascii=False keeps the non-ASCII keys readable in the UTF-8 output file.
with open(f"{pred_results_dir}/pred_results.json", 'w', encoding="utf-8") as pred_results_file:
    json.dump(pred_results, pred_results_file, ensure_ascii=False, indent=4)

In [57]:
# calculate average scores
# Unweighted mean over characters; NaN entries are skipped (NaN is the only
# value for which `x == x` is false).
metric_values = {
    metric: [scores[metric] for scores in pred_results.values() if scores[metric] == scores[metric]]
    for metric in ('precision', 'recall', 'f1')
}
all_precision = metric_values['precision']
all_recall = metric_values['recall']
all_f1 = metric_values['f1']

average_precision = sum(all_precision) / len(all_precision)
average_recall = sum(all_recall) / len(all_recall)
average_f1 = sum(all_f1) / len(all_f1)

for metric_label, metric_average in (("Precision", average_precision),
                                     ("Recall", average_recall),
                                     ("F1", average_f1)):
    print(f"Average {metric_label}: {metric_average}")

Average Precision: 0.7742575062782276
Average Recall: 0.7819213756390647
Average F1: 0.7686534070284696
