<a href="https://colab.research.google.com/github/CaptainPlusPlus/btba_reproduction/blob/main/btba_reproduction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reproduction of BTBA Model for Unsupervised Word Alignment
Article can be found here: https://aclanthology.org/2021.acl-long.24.pdf

This reproduction offers a modification of the original transformer model as well as BART using the architecture enhancment for unsupervised learning of the alignment task as described in the article.

### Requirements
* Downloand and preprocess the deen, fren, roen texts from https://github.com/lilt/alignment-scripts.
* Upload the `bpe` lowercased preprocessed `train` & `test` folders as well as the sentencepiece `bpe` models.

## Custom Tokenizer Definition

Since the Sentencepiece tokenizer used in the article and alignment scripts it is compared against outputs a binary format model and vocabulary, the sentencepiece tokenizer must be adjusted to fit the HuggingFace models used to reproduce the article's transformer based approaach.

In [None]:
import sentencepiece as spm

class CustomSentencePieceTokenizer:
    def __init__(self, sentencepiece_model_path):
        self.sp = spm.SentencePieceProcessor()
        if not self.sp.Load(sentencepiece_model_path):
            raise FileNotFoundError("Failed to load SentencePiece model from specified path.")
        self.special_tokens = {'<s>': self.sp.piece_to_id('<s>'), '</s>': self.sp.piece_to_id('</s>'), '<unk>': self.sp.piece_to_id('<unk>')}
        self.additional_special_tokens = {'<pad>': self.sp.GetPieceSize(), '<mask>': self.sp.GetPieceSize() + 1}
        self.special_token_ids = {**self.special_tokens, **self.additional_special_tokens}

    def tokenize(self, text):
        return self.sp.encode_as_pieces(text)

    def convert_tokens_to_ids(self, tokens):
        return [self.special_token_ids.get(token, self.sp.piece_to_id(token)) for token in tokens]

    def convert_ids_to_tokens(self, ids):
        id_to_token_map = {id: token for token, id in self.special_token_ids.items()}
        id_to_token_map.update({id: self.sp.id_to_piece(id) for id in range(self.sp.GetPieceSize())})
        return [id_to_token_map.get(id, '<unk>') for id in ids]

    def get_vocab_size(self):
        return self.sp.GetPieceSize() + len(self.additional_special_tokens)

    def get_special_tokens(self):
        return {**self.special_tokens, **self.additional_special_tokens}

    def get_special_token_ids(self):
        return self.special_token_ids


## Tokenizer for BART without modifications

In [None]:
from transformers import BartForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch
import sentencepiece as spm

class CustomSentencePieceTokenizer:
    def __init__(self, sentencepiece_model_path):
        self.sp = spm.SentencePieceProcessor()
        if not self.sp.Load(sentencepiece_model_path):
            raise FileNotFoundError("Failed to load SentencePiece model from specified path.")
        self.special_tokens = {'<pad>': self.sp.piece_to_id('<pad>'), '<mask>': self.sp.piece_to_id('<mask>'), '<eos>': self.sp.piece_to_id('<eos>')}
        self.additional_special_tokens = {'<bos>': self.sp.GetPieceSize(), '<sep>': self.sp.GetPieceSize() + 1}
        self.special_token_ids = {**self.special_tokens, **self.additional_special_tokens}

    def tokenize(self, text):
        return self.sp.encode_as_pieces(text)

    def convert_tokens_to_ids(self, tokens):
        return [self.special_token_ids.get(token, self.sp.piece_to_id(token)) for token in tokens]

    def convert_ids_to_tokens(self, ids):
        id_to_token_map = {id: token for token, id in self.special_token_ids.items()}
        id_to_token_map.update({id: self.sp.id_to_piece(id) for id in range(self.sp.GetPieceSize())})
        return [id_to_token_map.get(id, '<unk>') for id in ids]

    def get_vocab_size(self):
        return self.sp.GetPieceSize() + len(self.additional_special_tokens)

    def get_special_tokens(self):
        return {**self.special_tokens, **self.additional_special_tokens}

    def get_special_token_ids(self):
        return self.special_token_ids

    def pad_token_id(self):
        return self.special_token_ids['<pad>']

In [None]:
def test_custom_tokenizer():
    tokenizer = CustomSentencePieceTokenizer('<PATH TO SENTENCEPIECE TRAIN MODEL FROM SOURCE OF PAIR>')
    test_sentence = "das ist ein test."
    print("Testing tokenization of sentence:", test_sentence)
    tokens = tokenizer.tokenize(test_sentence)
    print("Tokens:", tokens)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    print("Token IDs:", token_ids)
    tokens_from_ids = tokenizer.convert_ids_to_tokens(token_ids)
    print("Tokens from IDs:", tokens_from_ids)
    print("Special Tokens:", tokenizer.get_special_tokens())
    print("Special Token IDs:", tokenizer.get_special_token_ids())
    special_tokens_test = ['<pad>', '<mask>', '<s>', '</s>', '<unk>']
    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens_test)
    print("Special tokens to IDs:", list(zip(special_tokens_test, special_tokens_ids)))
    special_tokens_round_trip = tokenizer.convert_ids_to_tokens(special_tokens_ids)
    print("IDs back to special tokens:", special_tokens_round_trip)

test_custom_tokenizer()

Testing tokenization of sentence: das ist ein test.
Tokens: ['▁das', '▁ist', '▁ein', '▁test', '.']
Token IDs: [94, 158, 69, 4218, 39789]
Tokens from IDs: ['▁das', '▁ist', '▁ein', '▁test', '.']
Special Tokens: {'<pad>': 0, '<mask>': 0, '<eos>': 0, '<bos>': 40000, '<sep>': 40001}
Special Token IDs: {'<pad>': 0, '<mask>': 0, '<eos>': 0, '<bos>': 40000, '<sep>': 40001}
Special tokens to IDs: [('<pad>', 0), ('<mask>', 0), ('<s>', 1), ('</s>', 2), ('<unk>', 0)]
IDs back to special tokens: ['<unk>', '<unk>', '<s>', '</s>', '<unk>']


In [None]:
tokenizer = CustomSentencePieceTokenizer('PATH TO SENTENCEPIECE TRAIN MODEL FROM SOURCE OF PAIR')

## Load bpe lowercased data and save tokenized data
Load data, tokenize data and save it to file (so that subsequent runs don't have to reload data)

In [None]:
import torch

def load_data(src_file, tgt_file):
    with open(src_file, 'r', encoding='utf-8') as src_f, open(tgt_file, 'r', encoding='utf-8') as tgt_f:
        src_lines = [line.strip() for line in src_f.readlines()]
        tgt_lines = [line.strip() for line in tgt_f.readlines()]
    assert len(src_lines) == len(tgt_lines), "Source and target files should have the same number of lines."
    return src_lines, tgt_lines

def tokenize_and_save_data(src_lines, tgt_lines, tokenizer, src_path, tgt_path):
    tokenized_src = [torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line)), dtype=torch.long) for line in src_lines]
    tokenized_tgt = [torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line)), dtype=torch.long) for line in tgt_lines]
    torch.save(tokenized_src, src_path)
    torch.save(tokenized_tgt, tgt_path)
    print(f"Tokenized data saved to {src_path} and {tgt_path}")


### Define paths to data here after preprocessing

In [None]:
src_file_path = '<PATH TO SRC TRAIN, BPE LOWERCASE FORMAT>'
tgt_file_path = '<PATH TO TARGET TRAIN, BPE LOWERCASE FORMAT>'
model_path = '<PATH TO BPE SENTENCEPIECE TRAIN MODEL, BINARY>'
tokenized_src_path = '<PATH TO SAVE/LOAD TOKENIZED SRC TRAIN DATA>'
tokenized_tgt_path = '<PATH TO SAVE/LOAD TOKENIZED TARGERT TRAIN DATA>'

eval_src_path = '<PATH TO SRC EVAL, BPE LOWERCASE FORMAT>'
eval_tgt_path = '<PATH TO TARGET EVAL, BPE LOWERCASE FORMAT>'
tokenized_eval_src_path = '<PATH TO SAVE/LOAD TOKENIZED SRC EVAL DATA>'
tokenized_eval_tgt_path = '<PATH TO SAVE/LOAD TOKENIZED TARGET EVAL DATA>'

### Load and tokenize training and evaluation data

In [None]:
src_lines, tgt_lines = load_data(src_file_path, tgt_file_path)
tokenize_and_save_data(src_lines, tgt_lines, tokenizer, tokenized_src_path, tokenized_tgt_path)

In [None]:
src_lines, tgt_lines = load_data(eval_src_path, eval_tgt_path)
tokenize_and_save_data(src_lines, tgt_lines, tokenizer, tokenized_eval_src_path, tokenized_eval_tgt_path)

## Dynamic Masking for tje Dataset

* Every word in a sentence should be masked only once across the entire training - track the masking state and reset it after each epoch.
* Percentage-based: At least 10% of the words in each sentence must be masked, or one word if the sentence has less than ten words.

In [None]:
from torch.utils.data import Dataset

import numpy as np

class DynamicMaskingDataset(Dataset):
    def __init__(self, tokenized_src, tokenized_tgt, tokenizer, mask_probability=0.1):
        self.tokenized_src = tokenized_src
        self.tokenized_tgt = tokenized_tgt
        self.tokenizer = tokenizer
        self.mask_id = tokenizer.get_special_token_ids()['<mask>']
        self.pad_id = tokenizer.get_special_token_ids()['<pad>']
        self.mask_probability = mask_probability
        self.mask_tracker = {i: set() for i in range(len(tokenized_tgt))}

    def __len__(self):
        return len(self.tokenized_tgt)

    def reset_mask_tracker(self):
        self.mask_tracker = {i: set() for i in range(len(self.tokenized_tgt))}

    def mask_input(self, inputs, idx):
        num_tokens = len(inputs)
        num_to_mask = max(int(num_tokens * self.mask_probability), 1)
        labels = inputs.clone()
        candidate_mask = (inputs != self.pad_id) & (inputs != self.tokenizer.get_special_token_ids()['<s>']) & (inputs != self.tokenizer.get_special_token_ids()['</s>'])
        candidate_indices = np.setdiff1d(np.where(candidate_mask.numpy())[0], list(self.mask_tracker[idx]))

        if len(candidate_indices) == 0:
            self.mask_tracker[idx] = set()
            candidate_indices = np.where(candidate_mask.numpy())[0]

        if len(candidate_indices) < num_to_mask:
            num_to_mask = len(candidate_indices)

        if num_to_mask > 0:
            masked_indices = np.random.choice(candidate_indices, size=num_to_mask, replace=False)
            self.mask_tracker[idx].update(masked_indices)
            inputs[masked_indices] = self.mask_id
        else:
            labels.fill_(-100)

        labels[~candidate_mask] = -100
        return inputs, labels

    def __getitem__(self, idx):
        src = self.tokenized_src[idx]
        tgt = self.tokenized_tgt[idx]
        src, src_labels = self.mask_input(src, idx)
        tgt, tgt_labels = self.mask_input(tgt, idx)

        return {"input_ids": src, "labels": tgt_labels}


## Load tokenized data and and predefine to mask dynamically

In [None]:
tokenized_src = torch.load(tokenized_src_path)
tokenized_tgt = torch.load(tokenized_tgt_path)

In [None]:
tokenized_eval_src = torch.load(tokenized_eval_src_path)
tokenized_eval_tgt = torch.load(tokenized_eval_tgt_path)

In [None]:
dataset = DynamicMaskingDataset(tokenized_src, tokenized_tgt, tokenizer)

## For UNMODIFIED BART - load non masked dataset without collate

In [None]:
class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, tokenized_src, tokenized_tgt):
        self.tokenized_src = tokenized_src
        self.tokenized_tgt = tokenized_tgt

    def __len__(self):
        return len(self.tokenized_src)

    def __getitem__(self, idx):
        src = self.tokenized_src[idx]
        tgt = self.tokenized_tgt[idx]
        return {"input_ids": src, "labels": tgt}

In [None]:
dataset = SimpleDataset(tokenized_src, tokenized_tgt)


In [None]:
eval_dataset = SimpleDataset(tokenized_eval_src, tokenized_eval_tgt)

## Modigy BART according to BTBA architecture
* Remove final feed forward sublayer in the last decoder layer.
* Adjust model to initialize properly with these changes, padding to accomodate for masking is done in "collate"

In [None]:
from transformers import BartConfig, BartModel
import torch.nn as nn

class BTBADecoderLayer(nn.Module):
    def __init__(self, config, is_last_layer=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(config.d_model, config.decoder_attention_heads)
        self.multihead_attn = nn.MultiheadAttention(config.d_model, config.decoder_attention_heads)
        self.layer_norm1 = nn.LayerNorm(config.d_model)
        self.layer_norm2 = nn.LayerNorm(config.d_model)
        self.is_last_layer = is_last_layer
        if not is_last_layer:
            self.ffn = nn.Sequential(
                nn.Linear(config.d_model, config.decoder_ffn_dim),
                nn.ReLU(),
                nn.Linear(config.decoder_ffn_dim, config.d_model),
            )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        x = self.layer_norm1(x + self.dropout(self.self_attn(x, x, x, key_padding_mask=tgt_mask)[0]))
        x = self.layer_norm2(x + self.dropout(self.multihead_attn(x, memory, memory, key_padding_mask=src_mask)[0]))
        if not self.is_last_layer:
            x = self.ffn(x)
        return x

class BTBAModel(BartModel):
    def __init__(self, config):
        super().__init__(config)
        self.decoder.layers = nn.ModuleList([
            BTBADecoderLayer(config, is_last_layer=(i == config.decoder_layers - 1))
            for i in range(config.decoder_layers)
        ])


In [None]:
from transformers import BartForConditionalGeneration, BartConfig

config = BartConfig.from_pretrained('facebook/bart-large')
config.decoder_ffn_dim = 3072
model = BartForConditionalGeneration(config)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

### Load unmodified BART

In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

# BTBA for original transformer model modification
To perform training with the original transformer model
*Requires a lot of memory

In [None]:
import torch
import torch.nn as nn
from torch.nn import Transformer

class CustomDecoderLayer(nn.TransformerDecoderLayer):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", is_last_layer=False):
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation)
        self.is_last_layer = is_last_layer
        if self.is_last_layer:
            self.linear2 = None

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        x = tgt
        x = self.self_attn(x, x, x, attn_mask=tgt_mask,
                           key_padding_mask=tgt_key_padding_mask)[0]
        x = self.dropout1(x)
        x = self.norm1(tgt + x)

        x2 = self.multihead_attn(x, memory, memory, attn_mask=memory_mask,
                                 key_padding_mask=memory_key_padding_mask)[0]
        x = x + self.dropout2(x2)
        x = self.norm2(x)

        if not self.is_last_layer:
            x2 = self.linear2(self.dropout(self.activation(self.linear1(x))))
            x = x + self.dropout3(x2)
            x = self.norm3(x)

        return x

class CustomTransformerModel(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()
        self.transformer = Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)
        self.transformer.decoder.layers = nn.ModuleList([
            CustomDecoderLayer(d_model, nhead, is_last_layer=(i == num_decoder_layers - 1))
            for i in range(num_decoder_layers)
        ])

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
        return self.transformer(src, tgt, src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask)

d_model = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6

model = CustomTransformerModel(d_model, nhead, num_encoder_layers, num_decoder_layers)

## Collate function
To accomodate for masking while training modified BART (Padding and masking properly)

In [None]:
def collate_fn(batch):
    input_ids = pad_sequence([item['input_ids'] for item in batch], batch_first=True, padding_value=tokenizer.pad_token_id)
    labels = pad_sequence([item['labels'] for item in batch], batch_first=True, padding_value=-100)
    attention_mask = input_ids.ne(tokenizer.pad_token_id).int()
    return {'input_ids': input_ids, 'labels': labels, 'attention_mask': attention_mask}

In [None]:
# FOR UNMODIFIED BART, TO TEST PERFORMANCE WITH JUST MASKING
def collate_fn(batch):
    pad_token_id = tokenizer.pad_token_id()
    input_ids = pad_sequence([item['input_ids'] for item in batch], batch_first=True, padding_value=pad_token_id).to(device)
    labels = pad_sequence([item['labels'] for item in batch], batch_first=True, padding_value=-100).to(device)
    attention_mask = input_ids.ne(pad_token_id).int().to(device)
    return {'input_ids': input_ids, 'labels': labels, 'attention_mask': attention_mask}

In [None]:
### Load data

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

data_loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

for batch in data_loader:
    print("Batch shapes and device:")
    print("Input IDs:", batch['input_ids'].shape, batch['input_ids'].device)
    print("Labels:", batch['labels'].shape, batch['labels'].device)
    print("Attention Masks:", batch['attention_mask'].shape, batch['attention_mask'].device)

    model = model.to(batch['input_ids'].device)
    outputs = model(**batch)
    print("Output Logits Shape:", outputs.logits.shape)
    break


Batch shapes and device:
Input IDs: torch.Size([2, 82]) cuda:0
Labels: torch.Size([2, 82]) cuda:0
Attention Masks: torch.Size([2, 82]) cuda:0
Output Logits Shape: torch.Size([2, 82, 50265])


### For testing - subsample the data

In [None]:
import torch
from torch.utils.data import Subset
import multiprocessing
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

multiprocessing.set_start_method('spawn', force=True)

def subsample_dataset(dataset, factor=10):
    subset_indices = list(range(0, len(dataset), factor))
    return Subset(dataset, subset_indices)

subsampled_dataset = subsample_dataset(dataset, factor=400)

data_loader = DataLoader(subsampled_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)


### Customize Trainer for BART models

In [None]:
#BTBA/UNMODIFIED BART
from torch.nn.functional import cross_entropy
from transformers import Trainer, TrainingArguments

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels", None)
        outputs = model(**inputs)
        logits = outputs.logits

        # Match the logits and labels lengths
        if labels is not None and logits.size(1) != labels.size(1):
            min_seq_len = min(logits.size(1), labels.size(1))
            logits = logits[:, :min_seq_len, :].contiguous()
            labels = labels[:, :min_seq_len].contiguous()

        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)

        loss = cross_entropy(logits, labels, ignore_index=-100)
        return (loss, outputs) if return_outputs else loss




In [None]:
# ORIGINAL TRANSFORMER MODEL
from torch.nn.functional import cross_entropy
from transformers import Trainer, TrainingArguments

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels", None)
        outputs = model(**inputs)
        logits = outputs.logits

        if labels is not None and logits.size(1) != labels.size(1):
            min_seq_len = min(logits.size(1), labels.size(1))
            logits = logits[:, :min_seq_len, :].contiguous()
            labels = labels[:, :min_seq_len].contiguous()

        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)
        loss = cross_entropy(logits, labels, ignore_index=-100)
        return (loss, outputs) if return_outputs else loss


# Training and train parameters

In [None]:
# UNMODIFIED BART BART
from transformers import Trainer, TrainingArguments
from torch.utils.data import DataLoader
import numpy as np
import torch.multiprocessing as mp

training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    save_strategy="no",
    learning_rate=3e-5,
    per_device_train_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=False,
    fp16=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn
)

trainer.train()

In [None]:
# ORIGINAL TRANSFORMER MODEL BTBA

from transformers import Trainer, TrainingArguments
from torch.utils.data import DataLoader
import numpy as np  # Ensure numpy is imported
import torch.multiprocessing as mp

training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    fp16=True,
)


data_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn
)



trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)

trainer.train()


In [None]:
print("Total number of samples in dataset:", len(dataset))
print("Configured batch size:", training_args.per_device_train_batch_size)
print("Total number of batches per epoch:", len(dataset) // training_args.per_device_train_batch_size)


In [None]:
# BTBA BART TRAINING

from transformers import Trainer, TrainingArguments
from torch.utils.data import DataLoader
import torch

training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    save_strategy="no",
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=False,
    fp16=True,
    report_to="none"

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=subsampled_dataset,
    data_collator=collate_fn
)

trainer.train()


## Evaluate and test evaluation data
* Use this data with your talp (gold alignments) data
* Converts talp to expected AER and evaluation format.

In [None]:
def parse_alignment_line(line):
    sure_alignments = set()
    possible_alignments = set()
    for alignment in line.split():
        # Check if alignment is marked as 'possible'
        if 'p' in alignment:
            src, tgt = alignment.replace('p', '-').split('-')
            possible_alignments.add((int(src), int(tgt)))
        else:
            src, tgt = alignment.split('-')
            sure_alignments.add((int(src), int(tgt)))
            possible_alignments.add((int(src), int(tgt)))  # Include sure in possible

    return sure_alignments, possible_alignments

def process_talp_file(file_path):
    gold_alignments = {'sure': set(), 'possible': set()}

    with open(file_path, 'r') as file:
        for line in file:
            sure, possible = parse_alignment_line(line)
            gold_alignments['sure'].update(sure)
            gold_alignments['possible'].update(possible)

    return gold_alignments

alignment_file_path = '<PATH TO PAIR TEST TALP FILE>'
gold_alignments = process_talp_file(alignment_file_path)
print(gold_alignments['sure'])
print(gold_alignments['possible'])


{(15, 21), (26, 21), (18, 17), (7, 17), (26, 30), (15, 30), (18, 26), (26, 39), (29, 32), (8, 9), (19, 9), (11, 5), (8, 18), (19, 18), (11, 14), (11, 23), (33, 20), (33, 29), (10, 27), (25, 25), (4, 2), (33, 38), (25, 34), (3, 6), (22, 19), (14, 15), (3, 15), (22, 28), (34, 30), (14, 24), (15, 7), (7, 3), (15, 16), (26, 16), (7, 12), (18, 12), (15, 25), (26, 25), (18, 21), (7, 21), (18, 30), (29, 27), (8, 4), (29, 36), (30, 13), (21, 32), (11, 9), (10, 22), (33, 24), (25, 20), (33, 33), (25, 29), (22, 5), (25, 38), (3, 1), (14, 1), (22, 14), (34, 16), (14, 10), (3, 10), (34, 25), (22, 23), (14, 19), (22, 32), (37, 30), (14, 28), (15, 2), (36, 34), (15, 11), (7, 7), (18, 7), (26, 20), (7, 16), (18, 16), (29, 22), (21, 18), (29, 31), (21, 27), (11, 4), (40, 40), (10, 8), (10, 17), (25, 15), (2, 13), (10, 26), (33, 28), (25, 24), (33, 37), (25, 33), (3, 5), (14, 5), (22, 18), (14, 14), (3, 14), (22, 27), (14, 23), (36, 29), (15, 6), (17, 25), (28, 25), (7, 2), (18, 2), (28, 34), (7, 11), 

In [None]:
def calculate_aer(predicted_alignments, gold_alignments):
    sure = set(gold_alignments['sure'])
    possible = set(gold_alignments['possible']).union(sure)
    predicted = set(predicted_alignments)

    num_predicted = len(predicted)
    num_sure = len(sure)
    num_possible = len(possible)
    num_correct_predicted = len(predicted.intersection(possible))
    num_correct_sure = len(predicted.intersection(sure))

    precision = num_correct_predicted / num_predicted if num_predicted > 0 else 0
    recall = num_correct_sure / num_sure if num_sure > 0 else 0
    aer = 1 - (num_correct_predicted + num_correct_sure) / (num_predicted + num_sure)

    return aer, precision, recall

class CustomTrainer(Trainer):
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        total_aer = 0
        total_precision = 0
        total_recall = 0
        num_batches = 0

        for batch in eval_dataloader:
            batch = {k: v.to(self.args.device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = self.model(**batch)

            predicted_alignments = self.get_predicted_alignments(outputs)

            if num_batches < len(gold_alignments):
                gold_alignments_batch = gold_alignments[num_batches]
            else:
                continue  # Or handle the case where there's no corresponding gold data

            aer, precision, recall = calculate_aer(predicted_alignments, gold_alignments_batch)
            total_aer += aer
            total_precision += precision
            total_recall += recall
            num_batches += 1

        avg_aer = total_aer / num_batches
        avg_precision = total_precision / num_batches
        avg_recall = total_recall / num_batches

        metrics = {
            f"{metric_key_prefix}_aer": avg_aer,
            f"{metric_key_prefix}_precision": avg_precision,
            f"{metric_key_prefix}_recall": avg_recall
        }

        self.log(metrics)
        return metrics

    def get_predicted_alignments(self, outputs):
        predicted_alignments = set()
        if 'alignment' in outputs:
            alignments = outputs['alignment']
            for pair in alignments:
                predicted_alignments.add((pair[0].item(), pair[1].item()))  # Example conversion from tensor
        return predicted_alignments

training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy="epoch",
    save_strategy="no",
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=False,
    fp16=True,
    report_to="none"
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=subsampled_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn
)
trainer.train()

Epoch,Training Loss,Validation Loss


KeyError: 0

In [None]:
def evaluate_with_talp(model, tokenizer, data_loader, device):
    model.eval()
    model.to(device)
    alignments = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            outputs = model(input_ids)
            generated_tokens = torch.argmax(outputs.logits, dim=-1)
            alignments.extend(process_alignment_data(generated_tokens, tokenizer))
    return alignments

def process_alignment_data(generated_tokens, tokenizer):
    return [tokenizer.convert_ids_to_tokens(g) for g in generated_tokens]


# Optimizations
* FCBO parameter freezing
* Symmetrize and train on labels

In [None]:
def full_context_based_optimization(model, train_dataloader, optimizer, scheduler, device, num_iterations=50):
    model.train()
    for iteration in range(num_iterations):
        for batch in train_dataloader:
            input_ids, labels = batch['input_ids'].to(device), batch['labels'].to(device)
            outputs = model(input_ids, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()


In [None]:
def grow_diagonal_final_and(alignments_lr, alignments_rl):
    """
    Symmetrize alignments using the grow-diagonal-final-and heuristic.
    :param alignments_lr: List of alignments from left-to-right model
    :param alignments_rl: List of alignments from right-to-left model
    :return: Symmetrized alignments
    """
    alignments = set(alignments_lr).intersection(set(alignments_rl))
    alignments.update(set(alignments_lr).difference(set(alignments_rl)))
    alignments.update(set(alignments_rl).difference(set(alignments_lr)))

    def grow_diagonal(alignments):
        grown_alignments = set(alignments)
        for (i, j) in alignments:
            for (di, dj) in [(0, 1), (1, 0), (1, 1), (-1, -1)]:
                if (i + di, j + dj) in alignments_lr or (i + di, j + dj) in alignments_rl:
                    grown_alignments.add((i + di, j + dj))
        return grown_alignments

    alignments = grow_diagonal(alignments)
    return alignments

In [None]:
# Example usage:
alignments_lr = [(0, 0), (1, 1), (2, 2)]
alignments_rl = [(0, 0), (1, 1), (2, 3)]

symmetrized_alignments = grow_diagonal_final_and(alignments_lr, alignments_rl)
print(symmetrized_alignments)

### Cleanup memory for subsequent runs

In [None]:
import gc

gc.collect()
torch.cuda.empty_cache()