In [13]:
import html
import os
from pathlib import Path

import numpy as np
import torch

from seq2seq_translation.datasets.datasets import LanguagePairsDatasets
from seq2seq_translation.sentence_pairs_dataset import SentencePairsDataset
from seq2seq_translation.tokenization.sentencepiece_tokenizer import SentencePieceTokenizer
from seq2seq_translation.run import _fix_model_state_dict

In [14]:
# Seed torch and NumPy so the random example selection later in the notebook is reproducible.
SEED = 1234
torch.manual_seed(SEED)
np.random.seed(SEED)

In [15]:
# SentencePiece models for the English source and French target sides
# (the "30000" in the path presumably denotes vocab size — TODO confirm).
source_tokenizer, target_tokenizer = (
	SentencePieceTokenizer(model_prefix=prefix)
	for prefix in (
		'/Users/adam.amster/Downloads/30000/en30000',
		'/Users/adam.amster/Downloads/30000/fr30000',
	)
)

In [16]:
# Force CPU; presumably read by the seq2seq_translation package when choosing a device — confirm.
os.environ.update(DEVICE='cpu')

In [17]:
def construct_test_dset(
	out_dir=Path('/Users/adam.amster/seq2seq_translation/datasets/wmt14_test'),
	source_lang='en',
	target_lang='fr',
):
	"""Build the test split as a tokenized ``SentencePairsDataset``.

	Uses the module-level ``source_tokenizer`` and ``target_tokenizer``. Every
	pair found in ``out_dir`` is included (``idxs`` spans the whole dataset),
	no length cap is applied (``max_length=None``) and a BOS id is prepended
	(``add_bos_token_id=True``).

	Args:
		out_dir: Directory holding the test split; passed to ``LanguagePairsDatasets``.
		source_lang: Source-language code.
		target_lang: Target-language code.

	Returns:
		SentencePairsDataset over all test pairs.
	"""
	test_datasets = LanguagePairsDatasets(
		out_dir=Path(out_dir),
		source_lang=source_lang,
		target_lang=target_lang,
		is_test=True,
	)

	# NOTE(review): a single set of eos/pad/bos ids, all taken from the *source*
	# tokenizer, is passed for the pair dataset; this presumably relies on both
	# SentencePiece models sharing special-token ids — confirm.
	return SentencePairsDataset(
		datasets=test_datasets,
		idxs=np.arange(len(test_datasets)),
		source_tokenizer=source_tokenizer,
		target_tokenizer=target_tokenizer,
		max_length=None,
		eos_token_id=source_tokenizer.eot_idx,
		pad_token_id=source_tokenizer.pad_idx,
		add_bos_token_id=True,
		bos_token_id=source_tokenizer.processor.bos_id(),
	)

In [18]:
# Build the en→fr WMT14 test split once; later cells filter and index into it.
test_dset = construct_test_dset()

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [19]:
from seq2seq_translation.models.attention.attention import AttentionType
from seq2seq_translation.models.rnn import EncoderRNN, AttnDecoderRNN, EncoderDecoderRNN


def construct_model_attention(
	encoder_checkpoint='/Users/adam.amster/Downloads/attention/encoder.pt',
	decoder_checkpoint='/Users/adam.amster/Downloads/attention/decoder.pt',
):
	"""Build the attention encoder-decoder RNN and load pretrained weights.

	The architecture hyperparameters below are fixed and must match the ones the
	checkpoints were trained with, otherwise ``load_state_dict`` will not line up.

	Args:
		encoder_checkpoint: Path to the saved encoder state dict.
		decoder_checkpoint: Path to the saved decoder state dict.

	Returns:
		``EncoderDecoderRNN`` wrapping the loaded encoder and decoder, switched to eval mode.
	"""
	hidden_size = 1000
	embedding_dim = 1000
	num_layers = 4

	encoder = EncoderRNN(
		input_size=source_tokenizer.processor.vocab_size(),
		hidden_size=hidden_size,
		bidirectional=True,
		pad_idx=source_tokenizer.processor.pad_id(),
		embedding_dim=embedding_dim,
		num_layers=num_layers,
	)

	# NOTE(review): pad_idx and sos_token_id come from the *source* tokenizer while
	# eos_token_id comes from the target tokenizer; this relies on both SentencePiece
	# models sharing special-token ids — confirm before changing.
	decoder = AttnDecoderRNN(
		hidden_size=hidden_size,
		attention_size=1000,
		output_size=target_tokenizer.processor.vocab_size(),
		encoder_bidirectional=True,
		max_len=200,  # presumably the maximum decoded length — confirm in AttnDecoderRNN
		attention_type=AttentionType.CosineSimilarityAttention,
		encoder_output_size=encoder.output_size,
		pad_idx=source_tokenizer.processor.pad_id(),
		num_embeddings=target_tokenizer.processor.vocab_size(),
		sos_token_id=source_tokenizer.processor.bos_id(),
		embedding_dim=embedding_dim,
		num_layers=num_layers,
		eos_token_id=target_tokenizer.processor.eos_id(),
	)

	for module, checkpoint in ((encoder, encoder_checkpoint), (decoder, decoder_checkpoint)):
		# SECURITY: torch.load unpickles the file; only point this at trusted checkpoints.
		state_dict = torch.load(checkpoint, map_location='cpu')
		module.load_state_dict(_fix_model_state_dict(state_dict))

	encoder_decoder = EncoderDecoderRNN(encoder=encoder, decoder=decoder)
	encoder_decoder.eval()
	return encoder_decoder

In [20]:
# Build the attention model, load its checkpoints and put it in eval mode.
encoder_decoder = construct_model_attention()

In [21]:
def get_test_dset_examples(min_len=10, max_len=20, dset=None):
	"""Return the examples whose source sequence length lies in ``[min_len, max_len]``.

	Length is ``len(example[0])``, i.e. the number of source token ids, including
	any BOS/EOS ids the dataset adds.

	Args:
		min_len: Inclusive lower bound on the source length.
		max_len: Inclusive upper bound on the source length.
		dset: Iterable of examples to filter; defaults to the module-level ``test_dset``.

	Returns:
		List of matching examples, in iteration order.
	"""
	if dset is None:
		dset = test_dset
	return [example for example in dset if min_len <= len(example[0]) <= max_len]

In [22]:
short_examples = get_test_dset_examples()
# Draw 4 example indices from the seeded torch RNG.
# NOTE(review): randint samples with replacement, so an example can be drawn twice.
rand_idxs = torch.randint(len(short_examples), (4,))

In [23]:
# Display the SentencePiece pieces of the first sampled source sentence.
first_source_ids = short_examples[rand_idxs[0]][0]
[source_tokenizer.processor.id_to_piece(token_id.item()) for token_id in first_source_ids]

['<s>',
 '▁But',
 '▁why',
 '▁such',
 '▁optimism',
 '▁for',
 '▁some',
 '▁and',
 '▁pes',
 'sim',
 'ism',
 '▁for',
 '▁others',
 '?',
 '</s>']

In [24]:
def get_attn_weights(idxs=None):
	"""Run the model on selected short examples and collect predictions and attention.

	For each index, the source sequence is run through ``encoder_decoder`` as a
	batch of one under ``torch.no_grad()``. The prediction is the argmax over the
	last dim of the decoder outputs. The decoded source and prediction are printed.

	Args:
		idxs: Indices into ``short_examples``; defaults to the module-level ``rand_idxs``.

	Returns:
		Tuple ``(sources, preds, attn_weights)`` of equal-length lists:
		- source id tensors, each with a leading batch dim of 1;
		- predicted target ids, with the batch dim removed;
		- the attention weights returned by the model for each example.
	"""
	if idxs is None:
		idxs = rand_idxs
	sources, preds, attn_weights = [], [], []
	for idx in idxs:
		source = short_examples[idx][0].unsqueeze(0)  # add batch dim of 1
		with torch.no_grad():
			decoder_outputs, _, decoder_attn = encoder_decoder(
				source,
				input_lengths=torch.tensor([source.shape[1]]),
				return_attention_weights=True,
			)
		# Take the argmax token at every decoder step for the single batch element.
		pred = decoder_outputs.argmax(dim=-1)[0]
		print(f'source: {source_tokenizer.decode(source[0])}')
		print(f'target: {target_tokenizer.decode(pred)}')
		sources.append(source)
		preds.append(pred)
		attn_weights.append(decoder_attn)
	return sources, preds, attn_weights

In [25]:
# Prints each decoded source/prediction pair and keeps the tensors for the heatmap cell below.
sources, preds, attn_weights = get_attn_weights()

source: But why such optimism for some and pessimism for others?
target: Mais pourquoi un tel optimisme pour certains et le pessimisme pour les autres?
source: Regulatory authority over phone calls belongs to the Federal Communications Commission, not the FAA.
target: L'autorité réglementaire sur les appels téléphoniques appartient à la Commission fédérale des communications et non à la FAA.
source: They don't want us to dictate to them what makes them profitable.
target: Ils ne veulent pas que nous leur dictions ce qui les rend rentables.
source: The cinema was ventilated and everyone returned in good order.
target: Le cinéma a été ventilé et tous les gens sont revenus dans l'ordre.


In [27]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# Plot attention weights of the sampled examples as heatmaps (target pieces on y,
# source pieces on x) and export the figure as Plotly JSON.
N_ROWS, N_COLS = 1, 1  # a 2x2 grid would show all four sampled examples
ATTENTION_PLOT_PATH = '/Users/adam.amster/aamster.github.io/assets/plotly/2025-04-13-sequence_to_sequence_translation_2/attention_weights.json'


def _escape_tick_labels(pieces):
	"""HTML-escape ``<``, ``>`` and ``&`` so Plotly shows pieces like ``<s>``/``</s>`` literally."""
	return [html.escape(piece, quote=False) for piece in pieces]


fig = make_subplots(rows=N_ROWS, cols=N_COLS)

for i in range(N_ROWS):
	for j in range(N_COLS):
		# Column-major grid-cell -> example index (2x2: (0,0)->0, (1,0)->1, (0,1)->2, (1,1)->3).
		idx = j * N_ROWS + i
		x_labels = _escape_tick_labels(
			source_tokenizer.processor.id_to_piece(tok.item()) for tok in sources[idx][0]
		)
		y_labels = _escape_tick_labels(
			target_tokenizer.processor.id_to_piece(tok.item()) for tok in preds[idx]
		)

		fig.add_trace(
			go.Heatmap(
				z=attn_weights[idx][0],  # presumably (target_len, source_len) — TODO confirm
				colorscale="Blues",
				zmin=0,
				zmax=1,
				showscale=False,
			),
			row=i + 1,
			col=j + 1,
		)
		# With no explicit x/y, heatmap cells sit at integer positions 0..n-1,
		# so the tick positions are simply the piece indices.
		fig.update_yaxes(
			autorange='reversed',  # first target piece at the top
			tickmode='array',
			tickvals=list(range(len(y_labels))),
			ticktext=y_labels,
			row=i + 1,
			col=j + 1,
		)
		fig.update_xaxes(
			tickangle=45,
			tickmode='array',
			tickvals=list(range(len(x_labels))),
			ticktext=x_labels,
			row=i + 1,
			col=j + 1,
		)

# Layout and export apply to the whole figure, so run them once after all traces are added.
fig.update_layout(
	height=400,
	width=None,
	coloraxis=dict(colorscale="Blues"),
	autosize=True,
	plot_bgcolor='rgba(0,0,0,0)',  # transparent plot area
	paper_bgcolor='rgba(0,0,0,0)',  # transparent outer background
	font=dict(color='black'),  # keep tick labels visible on the transparent background
	margin=dict(t=0, r=0),
)
fig.write_json(ATTENTION_PLOT_PATH)