In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import plotly.express as px

from seq2seq_translation.tokenization.sentencepiece_tokenizer import SentencePieceTokenizer
from seq2seq_translation.datasets.datasets import LanguagePairsDatasets

In [2]:
# Root of the seq2seq_translation checkout; every artifact path in this cell is resolved against it.
# NOTE(review): machine-specific absolute path — change it here when running on another machine.
PROJECT_ROOT = Path('/Users/adam.amster/seq2seq_translation')

# Per-sentence eval metrics on the WMT14 test set. `attention` is presumably the beam-search run
# (later cells label its scores "beam"); `greedy` is greedy decoding.
attention = pd.read_csv(PROJECT_ROOT / 'results' / 'eval_metrics_attention_wmt14_test_wmt14_bleu.csv')
greedy = pd.read_csv(PROJECT_ROOT / 'results' / 'eval_metrics_greedy_wmt14_test.csv')
# The two frames are compared row-by-row by position, so they must cover the same sentences.
assert len(attention) == len(greedy), f'row count mismatch: {len(attention)} vs {len(greedy)}'

# model_prefix is passed as a plain string, exactly as before.
source_tokenizer = SentencePieceTokenizer(model_prefix=str(PROJECT_ROOT / 'tokenizer' / '30000' / 'en30000'))
target_tokenizer = SentencePieceTokenizer(model_prefix=str(PROJECT_ROOT / 'tokenizer' / '30000' / 'fr30000'))

In [3]:


from seq2seq_translation.sentence_pairs_dataset import SentencePairsDataset


def construct_test_dset(
		out_dir: Path = Path('/Users/adam.amster/seq2seq_translation/datasets/wmt14_test'),
		source_lang: str = 'en',
		target_lang: str = 'fr',
):
	"""Build a ``SentencePairsDataset`` over every pair of a test split.

	Called with no arguments this builds the WMT14 en->fr test set, exactly as before.
	Tokenizers are taken from the module-level ``source_tokenizer`` / ``target_tokenizer``.

	Args:
		out_dir: directory handed to ``LanguagePairsDatasets`` as ``out_dir``.
		source_lang: source language code passed to ``LanguagePairsDatasets``.
		target_lang: target language code passed to ``LanguagePairsDatasets``.

	Returns:
		The ``SentencePairsDataset`` covering all indices of the split.
	"""
	test_datasets = LanguagePairsDatasets(
		out_dir=out_dir,
		source_lang=source_lang,
		target_lang=target_lang,
		is_test=True,
	)

	return SentencePairsDataset(
		datasets=test_datasets,
		# Use every pair in the split, in its original order.
		idxs=np.arange(len(test_datasets)),
		source_tokenizer=source_tokenizer,
		target_tokenizer=target_tokenizer,
		# NOTE(review): presumably disables length truncation/filtering — confirm in SentencePairsDataset.
		max_length=None,
	)

In [4]:
# Build the test split once; later cells fetch source sentences from it by row position.
test_dset = construct_test_dset()

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [5]:
def get_beam_search_better_examples(n: int = 10):
	"""Print the ``n`` test sentences where beam search beats greedy decoding by the most BLEU.

	Reads the module-level ``attention`` (beam-search metrics), ``greedy``, ``test_dset`` and
	``source_tokenizer``. Metric rows and dataset items are matched by position.

	Args:
		n: number of examples to print (default 10).
	"""
	assert len(attention) == len(greedy), 'metric frames must have the same number of rows'
	# Subtract positionally, matching the positional .iloc lookups below.
	diff = attention['bleu'].to_numpy() - greedy['bleu'].to_numpy()
	# Stable descending order, so ties keep their row order and the output is reproducible
	# (torch.sort is not stable by default). NaN differences sort last.
	order = np.argsort(-diff, kind='stable')
	for idx in order[:n]:
		idx = int(idx)
		print('beam bleu', attention.iloc[idx]['bleu'])
		print('greedy bleu', greedy.iloc[idx]['bleu'])
		print('input', source_tokenizer.decode(test_dset[idx][0]))
		print('attention pred', attention.iloc[idx]['pred'])
		print('greedy pred', greedy.iloc[idx]['pred'])
		print('='*11)

In [6]:
def get_greedy_better_examples(n: int = 10):
	"""Print the ``n`` test sentences where greedy decoding beats beam search by the most BLEU.

	Reads the module-level ``attention`` (beam-search metrics), ``greedy``, ``test_dset`` and
	``source_tokenizer``. Metric rows and dataset items are matched by position.

	Args:
		n: number of examples to print (default 10).
	"""
	assert len(attention) == len(greedy), 'metric frames must have the same number of rows'
	# Subtract positionally, matching the positional .iloc lookups below.
	diff = greedy['bleu'].to_numpy() - attention['bleu'].to_numpy()
	# Stable descending order, so ties keep their row order and the output is reproducible
	# (torch.sort is not stable by default). NaN differences sort last.
	order = np.argsort(-diff, kind='stable')
	for idx in order[:n]:
		idx = int(idx)
		print('beam bleu', attention.iloc[idx]['bleu'])
		print('greedy bleu', greedy.iloc[idx]['bleu'])
		print('input', source_tokenizer.decode(test_dset[idx][0]))
		print('attention pred', attention.iloc[idx]['pred'])
		print('greedy pred', greedy.iloc[idx]['pred'])
		print('='*11)

In [11]:
# Per-sentence BLEU gain of beam search over greedy decoding (Series.sub aligns on the row index).
difference = attention['bleu'].sub(greedy['bleu'])

# Histogram of the gains (nbins=20), with transparent backgrounds and black text,
# presumably so it blends into the blog page it is exported for.
fig = (
	px.histogram(difference, title='Beam search - greedy search BLEU score', nbins=20)
	.update_xaxes(title_text='difference')
	.update_layout(
		autosize=True,
		plot_bgcolor='rgba(0,0,0,0)',  # transparent plot area
		paper_bgcolor='rgba(0,0,0,0)',  # transparent outer background
		font={'color': 'black'},  # global font colour for all figure text
	)
)

# Export the figure as JSON for embedding; this cell does not display it inline.
fig.write_json('/Users/adam.amster/aamster.github.io/assets/plotly/2024-10-03-sequence_to_sequence_translation/beam_search.json')