<a href="https://colab.research.google.com/github/CaptainPlusPlus/btba_reproduction/blob/main/btba_reproduction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reproduction of BTBA Model for Unsupervised Word Alignment
The article can be found here: https://aclanthology.org/2021.acl-long.24.pdf

### Requirements
* Download and preprocess the deen, fren and roen texts from https://github.com/lilt/alignment-scripts.
* Upload the `bpe` lowercased preprocessed `train` & `test` folders as well as the sentencepiece `bpe` models.

## Custom Tokenizer Definition

The SentencePiece tokenizer used in the article (and in the alignment scripts it is compared against) produces a binary model and vocabulary. It must therefore be wrapped so that it fits the HuggingFace models used to reproduce the article's transformer-based approach.

In [1]:
import sentencepiece as spm

class CustomSentencePieceTokenizer:
    """Wrap a SentencePiece model and add extra ``<pad>``/``<mask>`` ids.

    The SentencePiece model supplies the base vocabulary, including its own
    ``<s>``, ``</s>`` and ``<unk>`` pieces. ``<pad>`` and ``<mask>`` are
    appended directly after that vocabulary (ids ``GetPieceSize()`` and
    ``GetPieceSize() + 1``). The size reported by :meth:`get_vocab_size` is
    therefore the SentencePiece size plus two.
    """

    def __init__(self, sentencepiece_model_path):
        """Load the SentencePiece model stored at *sentencepiece_model_path*.

        Raises:
            FileNotFoundError: if ``SentencePieceProcessor.Load`` returns a
                false value.
        """
        # NOTE(review): depending on the sentencepiece version, Load may raise
        # (e.g. OSError) on a bad path instead of returning False -- confirm.
        self.sp = spm.SentencePieceProcessor()
        if not self.sp.Load(sentencepiece_model_path):
            raise FileNotFoundError("Failed to load SentencePiece model from specified path.")
        # Ids that the SentencePiece model itself assigns to these pieces.
        self.special_tokens = {'<s>': self.sp.piece_to_id('<s>'), '</s>': self.sp.piece_to_id('</s>'), '<unk>': self.sp.piece_to_id('<unk>')}
        # Extra tokens placed just past the end of the SentencePiece vocabulary.
        self.additional_special_tokens = {'<pad>': self.sp.GetPieceSize(), '<mask>': self.sp.GetPieceSize() + 1}
        self.special_token_ids = {**self.special_tokens, **self.additional_special_tokens}

    def tokenize(self, text):
        """Split *text* into SentencePiece pieces (strings)."""
        return self.sp.encode_as_pieces(text)

    def convert_tokens_to_ids(self, tokens):
        """Map pieces to ids.

        Special tokens are resolved from the special-token table first. Any
        other piece is resolved through SentencePiece.
        """
        special_ids = self.special_token_ids
        # Only query SentencePiece for non-special tokens.
        return [special_ids[token] if token in special_ids else self.sp.piece_to_id(token) for token in tokens]

    def convert_ids_to_tokens(self, ids):
        """Map ids back to pieces.

        - Ids inside the SentencePiece vocabulary resolve through the model.
        - Ids outside it (``<pad>``, ``<mask>``) resolve through the
          special-token table.
        - Anything else becomes ``'<unk>'``.

        Each element is passed through ``int()``, so 0-d tensors (e.g. from
        iterating a 1-D id tensor) are handled like plain ints.
        """
        piece_count = self.sp.GetPieceSize()
        # Only a handful of entries; cheap to rebuild per call, unlike a full
        # vocabulary-sized map.
        id_to_special = {token_id: token for token, token_id in self.special_token_ids.items()}
        tokens = []
        for raw_id in ids:
            token_id = int(raw_id)
            if 0 <= token_id < piece_count:
                # In-range ids take the model's piece, matching the previous
                # precedence of SentencePiece pieces over the special table.
                tokens.append(self.sp.id_to_piece(token_id))
            else:
                tokens.append(id_to_special.get(token_id, '<unk>'))
        return tokens

    def get_vocab_size(self):
        """Return the SentencePiece vocabulary size plus the appended special tokens."""
        return self.sp.GetPieceSize() + len(self.additional_special_tokens)

    def get_special_tokens(self):
        """Return a new dict mapping every special token to its id."""
        return {**self.special_tokens, **self.additional_special_tokens}

    def get_special_token_ids(self):
        """Return the internal token-to-id dict of all special tokens (not a copy)."""
        return self.special_token_ids


In [None]:
def test_custom_tokenizer(
    sentencepiece_model_path='/content/drive/MyDrive/bachelor_thesis/data/alignment-scripts/train/bpe.deen.model',
    test_sentence="das ist ein test.",
):
    """Smoke-test CustomSentencePieceTokenizer by printing round trips.

    Args:
        sentencepiece_model_path: SentencePiece model to load. Defaults to the
            deen model so the original call keeps working; pass another
            pair's model to check it.
        test_sentence: sentence to tokenize and round-trip.
    """
    tokenizer = CustomSentencePieceTokenizer(sentencepiece_model_path)
    print("Testing tokenization of sentence:", test_sentence)
    tokens = tokenizer.tokenize(test_sentence)
    print("Tokens:", tokens)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    print("Token IDs:", token_ids)
    tokens_from_ids = tokenizer.convert_ids_to_tokens(token_ids)
    print("Tokens from IDs:", tokens_from_ids)
    print("Special Tokens:", tokenizer.get_special_tokens())
    print("Special Token IDs:", tokenizer.get_special_token_ids())
    # Special tokens should survive a tokens -> ids -> tokens round trip.
    special_tokens_test = ['<pad>', '<mask>', '<s>', '</s>', '<unk>']
    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens_test)
    print("Special tokens to IDs:", list(zip(special_tokens_test, special_tokens_ids)))
    special_tokens_round_trip = tokenizer.convert_ids_to_tokens(special_tokens_ids)
    print("IDs back to special tokens:", special_tokens_round_trip)

test_custom_tokenizer()

In [None]:
# model_path is otherwise first assigned in a later cell, so running the notebook
# top to bottom would raise NameError here. Keep this in sync with the data-path cell below.
model_path = '/content/drive/MyDrive/bachelor_thesis/data/alignment-scripts/train/bpe.deen.model'
tokenizer = CustomSentencePieceTokenizer(model_path)

## Load the lowercased BPE data and save the tokenized data

In [None]:
import torch

def load_data(src_file, tgt_file):
    """Read two line-aligned UTF-8 text files.

    Args:
        src_file: path of the source-side file.
        tgt_file: path of the target-side file.

    Returns:
        ``(src_lines, tgt_lines)``: two lists of strings, each line with
        surrounding whitespace stripped.

    Raises:
        AssertionError: if the files contain a different number of lines.
    """
    with open(src_file, 'r', encoding='utf-8') as source_handle, open(tgt_file, 'r', encoding='utf-8') as target_handle:
        source_sentences = list(map(str.strip, source_handle))
        target_sentences = list(map(str.strip, target_handle))
    assert len(source_sentences) == len(target_sentences), "Source and target files should have the same number of lines."
    return source_sentences, target_sentences

def _encode_lines(lines, tokenizer):
    """Return one 1-D ``torch.long`` id tensor per input line."""
    # dtype is explicit: torch.tensor([]) defaults to float32, so an empty line
    # would otherwise yield a float tensor that embedding lookups reject.
    return [
        torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line)), dtype=torch.long)
        for line in lines
    ]


def tokenize_and_save_data(src_lines, tgt_lines, tokenizer, src_path, tgt_path):
    """Tokenize parallel sentences and save each side with ``torch.save``.

    Args:
        src_lines: source sentences (strings).
        tgt_lines: target sentences (strings).
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        src_path: output file for the list of source id tensors.
        tgt_path: output file for the list of target id tensors.
    """
    tokenized_src = _encode_lines(src_lines, tokenizer)
    tokenized_tgt = _encode_lines(tgt_lines, tokenizer)
    torch.save(tokenized_src, src_path)
    torch.save(tokenized_tgt, tgt_path)
    print(f"Tokenized data saved to {src_path} and {tgt_path}")


In [None]:
# Change `lang_pair` to process another language pair.
# Other pairs' files are assumed to follow the same naming pattern -- verify.
lang_pair = 'deen'
data_dir = '/content/drive/MyDrive/bachelor_thesis/data/alignment-scripts/train'

src_file_path = f'{data_dir}/{lang_pair}.lc.plustest.src.bpe'
tgt_file_path = f'{data_dir}/{lang_pair}.lc.plustest.tgt.bpe'
model_path = f'{data_dir}/bpe.{lang_pair}.model'
tokenized_src_path = f'{data_dir}/{lang_pair}_tokenized_src.pt'
tokenized_tgt_path = f'{data_dir}/{lang_pair}_tokenized_tgt.pt'

# Rebuild the tokenizer from this pair's model so the saved ids match the data being tokenized.
tokenizer = CustomSentencePieceTokenizer(model_path)

src_lines, tgt_lines = load_data(src_file_path, tgt_file_path)
tokenize_and_save_data(src_lines, tgt_lines, tokenizer, tokenized_src_path, tokenized_tgt_path)

In [None]:
from transformers import BartConfig, BartModel
import torch.nn as nn

class BTBADecoderLayer(nn.Module):
    """Post-norm transformer decoder layer used by BTBAModel.

    The layer has up to three sublayers:
    - self-attention over ``x``;
    - cross-attention from ``x`` to ``memory``;
    - optionally, a position-wise feed-forward network (FFN).

    Every sublayer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        config: object exposing ``d_model``, ``decoder_attention_heads``,
            ``decoder_ffn_dim`` and ``dropout`` (e.g. a ``BartConfig``).
        include_ffn: whether to build the feed-forward sublayer. BTBAModel
            passes ``False`` for its last decoder layer.
    """

    def __init__(self, config, include_ffn=True):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(config.d_model, config.decoder_attention_heads)
        self.multihead_attn = nn.MultiheadAttention(config.d_model, config.decoder_attention_heads)
        self.layer_norm1 = nn.LayerNorm(config.d_model)
        self.layer_norm2 = nn.LayerNorm(config.d_model)
        self.include_ffn = include_ffn
        if include_ffn:
            self.ffn = nn.Sequential(
                nn.Linear(config.d_model, config.decoder_ffn_dim),
                nn.ReLU(),
                nn.Linear(config.decoder_ffn_dim, config.d_model),
            )
            # Norm for the FFN residual branch, mirroring the two attention sublayers.
            self.layer_norm3 = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        """Run self-attention, cross-attention and (optionally) the FFN.

        Both attention modules use nn.MultiheadAttention's default layout
        (``batch_first=False``). ``x`` and ``memory`` are therefore presumably
        ``(seq_len, batch, d_model)`` -- confirm against the caller.

        Args:
            x: target-side hidden states.
            memory: encoder hidden states attended to by cross-attention.
            src_mask: optional ``key_padding_mask`` for ``memory``.
            tgt_mask: optional ``key_padding_mask`` for ``x`` in
                self-attention. No causal mask is applied.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.layer_norm1(x + self.dropout(self.self_attn(x, x, x, key_padding_mask=tgt_mask)[0]))
        x = self.layer_norm2(x + self.dropout(self.multihead_attn(x, memory, memory, key_padding_mask=src_mask)[0]))
        if self.include_ffn:
            # Residual + dropout + norm, rather than replacing x with the raw FFN output.
            x = self.layer_norm3(x + self.dropout(self.ffn(x)))
        return x

class BTBAModel(BartModel):
    """BART model whose decoder layers are replaced by BTBADecoderLayer modules.

    Every decoder layer except the last keeps its feed-forward sublayer. The
    final one (index ``config.decoder_layers - 1``) is built with
    ``include_ffn=False``.
    """

    def __init__(self, config):
        super().__init__(config)
        assert hasattr(config, 'decoder_ffn_dim'), "decoder_ffn_dim is not defined in the configuration"
        final_index = config.decoder_layers - 1
        replacement_layers = []
        for layer_index in range(config.decoder_layers):
            replacement_layers.append(BTBADecoderLayer(config, include_ffn=layer_index != final_index))
        # NOTE(review): BartDecoder's own forward presumably calls each layer with
        # HF's keyword signature and unpacks a tuple result. BTBADecoderLayer.forward
        # takes (x, memory, src_mask, tgt_mask) and returns a plain tensor.
        # Confirm compatibility with the installed transformers version.
        # NOTE(review): these layers are created after BartModel.__init__, so they
        # presumably keep PyTorch's default init rather than HF's -- confirm.
        self.decoder.layers = nn.ModuleList(replacement_layers)

## Dynamic Masking for the Dataset

* Every word in a sentence should be masked only once per pass over the training data (epoch): track the masking state and reset it after each epoch.
* Percentage-based: at least 10% of the words in each sentence must be masked, or one word if the sentence has fewer than ten words.

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset

class DynamicMaskingDataset(Dataset):
    """Parallel sentence pairs with random masking of the target side.

    Each item contains:
    - the unmasked source ids;
    - the target ids with a random subset of ordinary tokens replaced by ``<mask>``;
    - labels holding the original id at every masked position and ``-100``
      everywhere else.

    The stored tensors are never modified.

    Args:
        tokenized_src: sequence of 1-D id tensors for the source side.
        tokenized_tgt: sequence of 1-D id tensors for the target side,
            index-aligned with ``tokenized_src``.
        tokenizer: object whose ``get_special_token_ids()`` returns a dict
            containing ``'<mask>'``, ``'<pad>'``, ``'<s>'`` and ``'</s>'``.
        mask_probability: fraction of maskable tokens to mask per sentence.
            At least one token is masked whenever any is maskable.
    """

    def __init__(self, tokenized_src, tokenized_tgt, tokenizer, mask_probability=0.1):
        self.tokenized_src = tokenized_src
        self.tokenized_tgt = tokenized_tgt
        self.tokenizer = tokenizer
        special_ids = tokenizer.get_special_token_ids()
        self.mask_id = special_ids['<mask>']
        self.pad_id = special_ids['<pad>']
        # Looked up once here instead of on every mask_input call.
        self.bos_id = special_ids['<s>']
        self.eos_id = special_ids['</s>']
        self.mask_probability = mask_probability

    def __len__(self):
        return len(self.tokenized_tgt)

    def mask_input(self, inputs):
        """Return ``(masked_inputs, labels)`` for a 1-D id tensor.

        ``inputs`` itself is not modified. Pad, ``<s>`` and ``</s>`` tokens are
        never masked. Labels are ``-100`` (ignored by the loss) at every
        position that was not masked.

        TODO(review): the stated "mask each word only once per epoch"
        requirement is not implemented. Positions are sampled independently
        on every call.
        """
        # Work on copies: writing into `inputs` would corrupt the stored dataset
        # and accumulate masks across epochs.
        inputs = inputs.clone()
        labels = torch.full_like(inputs, -100)

        # Don't mask special tokens.
        candidate_mask = (inputs != self.pad_id) & (inputs != self.bos_id) & (inputs != self.eos_id)
        candidate_indices = np.where(candidate_mask.numpy())[0]
        if len(candidate_indices) == 0:
            # Nothing maskable (e.g. empty sentence); np.random.choice would raise.
            return inputs, labels

        num_to_mask = max(int(len(candidate_indices) * self.mask_probability), 1)
        masked_indices = torch.as_tensor(
            np.random.choice(candidate_indices, size=num_to_mask, replace=False), dtype=torch.long
        )

        labels[masked_indices] = inputs[masked_indices]
        inputs[masked_indices] = self.mask_id
        return inputs, labels

    def __getitem__(self, idx):
        """Return the source ids, the masked target ids and the target labels for pair *idx*."""
        src = self.tokenized_src[idx]
        masked_tgt, labels = self.mask_input(self.tokenized_tgt[idx])
        # Only the target side is masked, so the source stays fully visible.
        # labels line up position-by-position with decoder_input_ids.
        return {"input_ids": src.clone(), "decoder_input_ids": masked_tgt, "labels": labels}


In [None]:
# Reload the id tensors saved earlier at these paths and wrap them for training.
tokenized_src, tokenized_tgt = (torch.load(saved_path) for saved_path in (tokenized_src_path, tokenized_tgt_path))

dataset = DynamicMaskingDataset(tokenized_src, tokenized_tgt, tokenizer)

In [None]:
from transformers import BartConfig

# Start from the bart-large architecture. Its vocab_size and special-token ids
# describe BART's own BPE vocabulary, so they are re-aligned with the
# SentencePiece tokenizer. <pad> is appended past the SentencePiece vocabulary,
# and the embedding table must be large enough to hold it.
config = BartConfig.from_pretrained('facebook/bart-large')
config.decoder_ffn_dim = 3072
special_ids = tokenizer.get_special_token_ids()
config.vocab_size = tokenizer.get_vocab_size()
config.pad_token_id = special_ids['<pad>']
config.bos_token_id = special_ids['<s>']
config.eos_token_id = special_ids['</s>']
# BART convention: decoding starts from the eos token.
config.decoder_start_token_id = special_ids['</s>']
model = BTBAModel(config)

In [None]:
from transformers import Trainer, TrainingArguments

# Periodic evaluation is left at its default ("no") because no eval_dataset is
# passed to the Trainer below. With evaluation_strategy="epoch" and no
# eval_dataset, the Trainer either rejects the arguments at construction or fails
# at the first evaluation, depending on the transformers version. The
# `evaluation_strategy` keyword is also deprecated in favour of `eval_strategy`.
training_args = TrainingArguments(
    output_dir='./results',
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
)

# NOTE(review): BartModel.forward presumably neither accepts `labels` nor returns
# a loss. Confirm the model produces a loss the Trainer can optimise.
# NOTE(review): dataset items have varying sequence lengths, so a padding
# data_collator is presumably required -- confirm.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()
