In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import os
from jiwer import wer
import yaml
import json 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
####### Config #######
config_path = "conf_1"
config_file = os.path.join(config_path, "config.yml")
with open(config_file,'r') as conf:
    config = yaml.load(conf, Loader=yaml.SafeLoader)


In [3]:
class AddNorm(nn.Module):
    def __init__(self, norm_shape: int, dropout=0.2):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)

In [4]:
class FeedForwardNetwork(nn.Module):
    def __init__(self, input_dim: int, hidden_ff_dim: int, dropout=0.2):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_ff_dim)
        self.dropout = nn.Dropout(dropout)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(hidden_ff_dim, input_dim)

    def forward(self, x):
        return self.linear2(self.dropout(self.relu1(self.linear1(x))))

In [5]:
class ShrinkNorm(nn.Module):
    def __init__(self, input_dim: int, shrink_norm_hidden: int, output_dim: int, dropout=0.2):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(input_dim, shrink_norm_hidden)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(shrink_norm_hidden, output_dim)
        self.ln = nn.LayerNorm(output_dim)

    def forward(self, x):
        return self.ln(self.linear2(self.dropout(self.relu1(self.linear1(x)))))

In [6]:
class SinoVietnameseTranslator(nn.Module):
    def __init__(self, tokenizer, base_model, vocab, hidden_ff_dim=512, model_hidden_dim=512, 
                 large_hidden_classification_head_dim=256, small_hidden_classification_head_dim=128,
                 shrink_norm_hidden=512, max_num_spellings=7, num_spelling_threshold=3, train_bert_param=True, dropout=0.2):
        super(SinoVietnameseTranslator, self).__init__()
        self.tokenizer = tokenizer
        self.bert = base_model
        self.vocab = vocab
        self.max_num_spellings = max_num_spellings
        
        for param in self.bert.parameters():
            param.requires_grad = train_bert_param
        
        self.shrink_norm = ShrinkNorm(self.bert.config.hidden_size,shrink_norm_hidden, model_hidden_dim, dropout)
        self.feed_forward = FeedForwardNetwork(model_hidden_dim, hidden_ff_dim, dropout)
        self.add_norm = AddNorm(model_hidden_dim, dropout)
        
        self.classification_heads = nn.ModuleDict()
        for sino_word, viet_spellings in self.vocab.items():
            if len(viet_spellings) > 1 and len(viet_spellings) <= num_spelling_threshold:
                num_spellings = len(viet_spellings)
                self.classification_heads[sino_word] = nn.Sequential(
                    nn.Linear(model_hidden_dim, small_hidden_classification_head_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(small_hidden_classification_head_dim, num_spellings),
                    nn.Softmax(dim=-1)
                )
            elif len(viet_spellings) > num_spelling_threshold:
                num_spellings = len(viet_spellings)
                self.classification_heads[sino_word] = nn.Sequential(
                    nn.Linear(model_hidden_dim, large_hidden_classification_head_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(large_hidden_classification_head_dim, num_spellings),
                    nn.Softmax(dim=-1)
                )

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        shrink_output = self.shrink_norm(sequence_output)
        projected_output = self.add_norm(shrink_output, self.feed_forward(shrink_output))
        
        batch_size, max_len = input_ids.size()
        predictions = torch.full((batch_size, max_len, self.max_num_spellings), -1.0, device=input_ids.device)
        
        for i in range(batch_size):
            for j in range(max_len):
                token_id = input_ids[i, j].item()
                if token_id == self.tokenizer.pad_token_id:
                    continue
                    
                sino_word = self.tokenizer.convert_ids_to_tokens(token_id)
                
                if sino_word in self.classification_heads:
                    logits = self.classification_heads[sino_word](projected_output[i, j])
                    predictions[i, j, :len(logits)] = logits
                else:
                    predictions[i, j, 0] = 1.0

        return predictions

In [7]:
class SinoVietnameseDataset(Dataset):
    def __init__(self, tokenizer, data, vocab, max_len=512):
        self.data = data
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sino_sent, viet_sent = self.data[idx]
        sino_tokens = self.tokenizer.encode(sino_sent, add_special_tokens=False, max_length=self.max_len, truncation=True)
        viet_spellings = viet_sent.split()

        input_ids = sino_tokens + [self.tokenizer.pad_token_id] * (self.max_len - len(sino_tokens))

        labels = []
        for i, sino_word_id in enumerate(sino_tokens):
            sino_word = self.tokenizer.convert_ids_to_tokens(sino_word_id)
            if sino_word in self.vocab:
                viet_spellings_for_word = self.vocab[sino_word]
                if len(viet_spellings_for_word) > 1:
                    label = viet_spellings_for_word.index(viet_spellings[i])
                else:
                    label = -1
            else:
                label = -1
            labels.append(label)

        labels += [-1] * (self.max_len - len(labels))  # Padding
        attention_mask = [1] * len(sino_tokens) + [0] * (self.max_len - len(sino_tokens))

        return {
            "input_ids": torch.tensor(input_ids),
            "labels": torch.tensor(labels),
            "attention_mask": torch.tensor(attention_mask),
        }


In [8]:

def load_data(data_file):
    data = []
    with open(data_file, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            if ',' not in line:
                continue
            sino_sent, viet_sent = line.strip().split(',')
            data.append((sino_sent, viet_sent))
    return data

train_data_path = "data/train.txt"
test_data_path = "data/test.txt"
train_data = load_data(train_data_path)
test_data = load_data(test_data_path)

In [9]:
with open('vocab/vocab.json', 'r') as vocab_file, open('vocab/sino_viet_words.json', 'r') as words_file:
    base_vocab = json.load(vocab_file)
    sino_viet_words = json.load(words_file)

print(type(base_vocab))
print(len(base_vocab))
print(type(sino_viet_words))
print(len(sino_viet_words))

<class 'dict'>
7688
<class 'list'>
7688


In [10]:
# Model Config
bert_model = config['model_config']['bert_model'] 

base_tokenizer = BertTokenizer.from_pretrained(bert_model)
base_tokenizer.add_tokens(sino_viet_words)

base_model = BertModel.from_pretrained(bert_model)
base_model.resize_token_embeddings(len(base_tokenizer))

Embedding(23683, 768)

In [11]:
# Data Config
batch_size = config['data_config']['batch_size']
max_len = config['data_config']['max_len']

train_dataset = SinoVietnameseDataset(base_tokenizer, train_data, base_vocab, max_len)
test_dataset = SinoVietnameseDataset(base_tokenizer, test_data, base_vocab, max_len)

print(f"Train set: {len(train_dataset)}")
print(f"Test set: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Train batch num: {len(train_loader)}")
print(f"Test batch num: {len(test_loader)}")

Train set: 153372
Test set: 17042
Train batch num: 4793
Test batch num: 533


In [12]:
def decode_predictions(predictions, input_ids, tokenizer, vocab):
    decoded_sentences = []
    for i, predicted_indices in enumerate(predictions):
        decoded_sentence = []
        for j, spelling_index in enumerate(predicted_indices):
            token = input_ids[i, j].item()
            if token == tokenizer.pad_token_id:
                continue
                
            sino_word = tokenizer.convert_ids_to_tokens(token)
            if spelling_index == -1:
                viet_spelling = vocab[sino_word][0]
            else:
                viet_spelling = vocab[sino_word][spelling_index]
            decoded_sentence.append(viet_spelling)

        decoded_sentences.append(" ".join(decoded_sentence))
    return decoded_sentences

In [13]:
def train(model, train_dataloader, test_dataloader, epochs=3, lr=1e-5, 
          max_grad_norm=1.0, model_load_path=None, config_folder_dir="config/"):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    tokenizer = model.tokenizer

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
    
    log_dir = os.path.join(config_folder_dir, f"running/")
    os.makedirs(log_dir, exist_ok=True)
    train_losses_dir = os.path.join(log_dir, f"train_losses.txt")
    test_losses_dir = os.path.join(log_dir, f"test_losses.txt")
    test_accuracies_dir = os.path.join(log_dir, f"test_accuracies.txt")
    test_wers_dir = os.path.join(log_dir, f"test_wers.txt")
    
    if model_load_path:
        model.load_state_dict(torch.load(model_load_path))
    # Determine the starting epoch
    start_epoch = 0
    if model_load_path:
        start_epoch = int(model_load_path.split("_")[-1].split(".")[0]) 
    
    for epoch in range(start_epoch, start_epoch + epochs):
        model.train()
        total_loss = 0

        # Training loop with progress bar
        train_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{start_epoch + epochs}", unit="batch")
        for batch in train_iterator:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask=attention_mask)

            # Flatten 
            preds = outputs.view(-1, outputs.size(-1))
            targets = labels.view(-1)

            loss = criterion(preds, targets) # batch loss
            total_loss += loss.item()

            loss.backward()
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            
            train_iterator.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_dataloader)

        with open(train_losses_dir, 'a') as tl:
            tl.write(f"{avg_train_loss};")

        print(f"Epoch {epoch+1}/{start_epoch + epochs}, Training Loss: {avg_train_loss}")

        ################################## Run test ##################################
        model.eval()
        total_test_loss = 0
        correct_predictions = 0 # calculate accuracies over sino words that have multiple viet spellings only
        total_predictions = 0
        all_ground_truths = []
        all_predictions = []
        
        with torch.no_grad():
            test_iterator = tqdm(test_dataloader, desc="Validating", unit="batch")
            for batch in test_iterator:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(input_ids, attention_mask=attention_mask)

                preds = outputs.view(-1, outputs.size(-1)) # Flatten
                targets = labels.view(-1)

                test_loss = criterion(preds, targets)
                total_test_loss += test_loss.item()

                predictions = torch.argmax(outputs, dim=-1)
                mask = labels != -1
                correct_predictions += (predictions[mask] == labels[mask]).sum().item()
                total_predictions += mask.sum().item()
                
                batch_predictions = decode_predictions(predictions, input_ids, tokenizer, model.vocab)
                batch_ground_truths = decode_predictions(labels, input_ids, tokenizer, model.vocab)
                all_predictions.extend(batch_predictions)
                all_ground_truths.extend(batch_ground_truths)
            
            avg_test_loss = total_test_loss / len(test_dataloader)
            with open(test_losses_dir, 'a') as tl2:
                tl2.write(f"{avg_test_loss};")
            
            test_accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
            with open(test_accuracies_dir, 'a') as ta:
                ta.write(f"{test_accuracy * 100};")
            
            test_wer = wer(all_ground_truths, all_predictions)
            with open(test_wers_dir, 'a') as tw:
                tw.write(f"{test_wer * 100};")
            
            print(f"Epoch {epoch+1}/{start_epoch + epochs}, Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy * 100:.4f}, Test WER: {test_wer * 100:.4f}")

        scheduler.step(avg_test_loss)
        print(f"Learning rate: {scheduler.get_last_lr()}")
        
        # Save the model after each epoch
        save_dir = os.path.join(log_dir, f"saved_model/")
        os.makedirs(save_dir, exist_ok=True)
        model_save_path = os.path.join(save_dir, f"sivi_model_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")

In [14]:
# Model config
hidden_ff_dim = config['model_config']['hidden_ff_dim']
model_hidden_dim = config['model_config']['model_hidden_dim']
shrink_norm_hidden = config['model_config']['shrink_norm_hidden']
large_hidden_classification_head_dim = config['model_config']['large_hidden_classification_head_dim']
small_hidden_classification_head_dim = config['model_config']['small_hidden_classification_head_dim']
max_num_spellings = config['model_config']['max_num_spellings']
num_spelling_threshold = config['model_config']['num_spelling_threshold']
train_bert_param = config['model_config']['train_bert_param']
dropout = config['model_config']['dropout']

model = SinoVietnameseTranslator(base_tokenizer, base_model, base_vocab, hidden_ff_dim=hidden_ff_dim, 
                                model_hidden_dim=model_hidden_dim, shrink_norm_hidden=shrink_norm_hidden,
                                large_hidden_classification_head_dim=large_hidden_classification_head_dim,
                                small_hidden_classification_head_dim=small_hidden_classification_head_dim,
                                max_num_spellings=max_num_spellings, train_bert_param=train_bert_param,
                                num_spelling_threshold=num_spelling_threshold, dropout=dropout)

num_param = sum([param.nelement() for param in model.parameters()]) / 1000000
print(f"{num_param:.1f}M params.")

151.9M params.


In [15]:
# Trainning config
num_epochs = config['training_config']['num_epochs']
learning_rate = config['training_config']['learning_rate']
model_load_path = None if config['training_config']['model_load_path'] == 'None' else config['training_config']['model_load_path']
config_folder_dir = config_path

train(model, train_loader, test_loader, epochs=num_epochs, lr=learning_rate,
    model_load_path=model_load_path, config_folder_dir=config_folder_dir)

  attn_output = torch.nn.functional.scaled_dot_product_attention(
Epoch 1/60: 100%|██████████| 4793/4793 [1:31:25<00:00,  1.14s/batch, loss=0.856]


Epoch 1/60, Training Loss: 0.9126091286580448


Validating: 100%|██████████| 533/533 [10:19<00:00,  1.16s/batch]


Epoch 1/60, Test Loss: 0.8474, Test Accuracy: 91.3767, Test WER: 2.5048
Learning rate: [1e-05]
Model saved to conf\saved_model/sivi_model_epoch_1.pt


Epoch 2/60: 100%|██████████| 4793/4793 [1:34:14<00:00,  1.18s/batch, loss=0.855]


Epoch 2/60, Training Loss: 0.836351030683209


Validating: 100%|██████████| 533/533 [09:34<00:00,  1.08s/batch]


Epoch 2/60, Test Loss: 0.8258, Test Accuracy: 92.3318, Test WER: 2.2274
Learning rate: [1e-05]
Model saved to conf\saved_model/sivi_model_epoch_2.pt


Epoch 3/60:  21%|██        | 1005/4793 [19:12<1:17:34,  1.23s/batch, loss=0.843]

In [15]:
torch.cuda.empty_cache()