In [13]:
import os
from pathlib import Path

import numpy as np
import torch

from seq2seq_translation.datasets.datasets import LanguagePairsDatasets
from seq2seq_translation.sentence_pairs_dataset import SentencePairsDataset
from seq2seq_translation.tokenization.sentencepiece_tokenizer import SentencePieceTokenizer
from seq2seq_translation.run import _fix_model_state_dict

In [14]:
# Seed the torch and NumPy global RNGs so that random draws made later in the
# notebook (e.g. the torch.randint example selection) are repeatable across runs.
torch.random.manual_seed(1234)
np.random.seed(1234)

In [15]:
# Directory holding the SentencePiece model files. The default is the original
# machine-specific location; set SEQ2SEQ_TOKENIZER_DIR to run the notebook elsewhere.
TOKENIZER_DIR = os.environ.get('SEQ2SEQ_TOKENIZER_DIR', '/Users/adam.amster/Downloads/30000')

# English (source) and French (target) tokenizers, each addressed by its model prefix.
source_tokenizer = SentencePieceTokenizer(model_prefix=f'{TOKENIZER_DIR}/en30000')
target_tokenizer = SentencePieceTokenizer(model_prefix=f'{TOKENIZER_DIR}/fr30000')

In [16]:
# Set DEVICE to 'cpu' in the process environment.
# NOTE(review): presumably read by seq2seq_translation to choose the torch device.
# It is assigned after the package imports at the top of the notebook, so confirm the
# package does not read it at import time.
os.environ['DEVICE'] = 'cpu'

In [17]:
# Default location of the test-set arrow file (machine-specific; override via `data_path`).
_DEFAULT_TEST_DATA_PATH = Path('/Users/adam.amster/seq2seq_translation/datasets/wmt14_test/wmt___wmt14/fr-en/0.0.0/b199e406369ec1b7634206d3ded5ba45de2fe696/wmt14-test.arrow')


def construct_test_dset(data_path=None):
    """
    Build the en→fr test `SentencePairsDataset`.

    Parameters
    ----------
    data_path : path-like, optional
        Location of the test-set arrow file. Defaults to the original
        hard-coded path so existing calls behave exactly as before.

    Returns
    -------
    SentencePairsDataset
        Dataset over every index of the test split, with no length cap
        (`max_length=None`) and a BOS token added.
    """
    if data_path is None:
        data_path = _DEFAULT_TEST_DATA_PATH

    test_datasets = LanguagePairsDatasets(
        data_path=Path(data_path),
        source_lang='en',
        target_lang='fr',
        is_test=True,
    )

    # NOTE(review): EOS / PAD / BOS ids are all read from the *source* tokenizer;
    # presumably both tokenizers share special-token ids — confirm.
    return SentencePairsDataset(
        datasets=test_datasets,
        idxs=np.arange(len(test_datasets)),
        source_tokenizer=source_tokenizer,
        target_tokenizer=target_tokenizer,
        max_length=None,
        eos_token_id=source_tokenizer.eot_idx,
        pad_token_id=source_tokenizer.pad_idx,
        add_bos_token_id=True,
        bos_token_id=source_tokenizer.processor.bos_id(),
    )

In [18]:
# Materialise the test dataset once; later cells read it as a notebook-level global.
test_dset = construct_test_dset()

In [19]:
from seq2seq_translation.models.attention.attention import AttentionType
from seq2seq_translation.models.rnn import EncoderRNN, AttnDecoderRNN, EncoderDecoderRNN


def construct_model_attention(checkpoint_dir='/Users/adam.amster/Downloads/attention'):
    """
    Build the attention-based encoder/decoder RNN and load its saved weights.

    Parameters
    ----------
    checkpoint_dir : path-like, optional
        Directory containing `encoder.pt` and `decoder.pt`. Defaults to the
        original hard-coded location so existing calls behave as before.

    Returns
    -------
    EncoderDecoderRNN
        The assembled model with weights loaded onto the CPU, in eval mode.
    """
    # Hyper-parameters shared by the encoder and decoder.
    hidden_size = 1000
    embedding_dim = 1000
    num_layers = 4

    checkpoint_dir = Path(checkpoint_dir)

    encoder = EncoderRNN(
        input_size=source_tokenizer.processor.vocab_size(),
        hidden_size=hidden_size,
        bidirectional=True,
        pad_idx=source_tokenizer.processor.pad_id(),
        embedding_dim=embedding_dim,
        num_layers=num_layers,
    )

    # NOTE(review): the decoder's pad and start-of-sequence ids are read from the
    # *source* tokenizer while its vocab size and EOS id come from the target
    # tokenizer; presumably the special-token ids coincide — confirm.
    decoder = AttnDecoderRNN(
        hidden_size=hidden_size,
        attention_size=1000,
        output_size=target_tokenizer.processor.vocab_size(),
        encoder_bidirectional=True,
        max_len=200,
        attention_type=AttentionType.CosineSimilarityAttention,
        encoder_output_size=encoder.output_size,
        pad_idx=source_tokenizer.processor.pad_id(),
        num_embeddings=target_tokenizer.processor.vocab_size(),
        sos_token_id=source_tokenizer.processor.bos_id(),
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        eos_token_id=target_tokenizer.processor.eos_id(),
    )

    def _load_weights(module, filename):
        # torch.load unpickles the file, so only point this at trusted checkpoints.
        state_dict = torch.load(str(checkpoint_dir / filename), map_location='cpu')
        # _fix_model_state_dict is a project helper applied before loading
        # (presumably normalises the state-dict keys — confirm).
        module.load_state_dict(_fix_model_state_dict(state_dict))

    _load_weights(encoder, 'encoder.pt')
    _load_weights(decoder, 'decoder.pt')

    encoder_decoder = EncoderDecoderRNN(
        encoder=encoder,
        decoder=decoder,
    )
    encoder_decoder.eval()
    return encoder_decoder

In [20]:
# Build the model and load its weights once; later cells use it as a notebook-level global.
encoder_decoder = construct_model_attention()

In [21]:
def get_test_dset_examples(dset=None, min_len=10, max_len=20):
    """
    Select examples whose first element has a length within [min_len, max_len].

    Parameters
    ----------
    dset : iterable, optional
        Examples to filter; each example must be indexable, with `example[0]`
        supporting `len()`. Defaults to the notebook-level `test_dset`.
    min_len, max_len : int, optional
        Inclusive length bounds. The defaults (10 and 20) reproduce the
        original hard-coded filter.

    Returns
    -------
    list
        The matching examples, in iteration order.
    """
    if dset is None:
        dset = test_dset
    return [example for example in dset if min_len <= len(example[0]) <= max_len]

In [22]:
# Collect the length-filtered test examples, then draw 4 random positions into that list.
# torch.randint draws each index independently, so the same example can be picked twice,
# and it needs high > 0, i.e. at least one filtered example.
short_examples = get_test_dset_examples()
rand_idxs = torch.randint(low=0, high=len(short_examples), size=(4,))

In [23]:
# Show the source-side token pieces of the first sampled example.
[source_tokenizer.processor.id_to_piece(x.item()) for x in short_examples[rand_idxs[0]][0]]

['<s>',
 '▁But',
 '▁why',
 '▁such',
 '▁optimism',
 '▁for',
 '▁some',
 '▁and',
 '▁pes',
 'sim',
 'ism',
 '▁for',
 '▁others',
 '?',
 '</s>']

In [24]:
def get_attn_weights():
    """
    Run the model on each sampled example and collect its attention weights.

    Reads the notebook-level `rand_idxs`, `short_examples` and `encoder_decoder`.
    For every sampled index, the source tensor is given a leading batch axis
    and passed through the model with attention weights requested. The argmax
    over the last axis of the decoder outputs is taken as the prediction, and
    the decoded source and prediction are printed.

    Returns
    -------
    tuple of three lists
        `(sources, preds, attn_weights)`, one entry per sampled index.
    """
    collected = []
    for example_idx in rand_idxs:
        src_batch = short_examples[example_idx][0].unsqueeze(0)
        with torch.no_grad():
            model_out = encoder_decoder(
                src_batch,
                input_lengths=torch.tensor([src_batch.shape[1]]),
                return_attention_weights=True,
            )
            decoder_outputs, _, decoder_attn = model_out
            predicted_ids = decoder_outputs.argmax(dim=-1)[0]
            print(f'source: {source_tokenizer.decode(src_batch[0])}')
            print(f'target: {target_tokenizer.decode(predicted_ids)}')
        collected.append((src_batch, predicted_ids, decoder_attn))

    sources = [item[0] for item in collected]
    preds = [item[1] for item in collected]
    attn_weights = [item[2] for item in collected]
    return sources, preds, attn_weights

In [25]:
# Run inference on the sampled examples; prints each source sentence and the model's output.
sources, preds, attn_weights = get_attn_weights()

source: But why such optimism for some and pessimism for others?
target: Mais pourquoi un tel optimisme pour certains et le pessimisme pour les autres?
source: Regulatory authority over phone calls belongs to the Federal Communications Commission, not the FAA.
target: L'autorité réglementaire sur les appels téléphoniques appartient à la Commission fédérale des communications et non à la FAA.
source: They don't want us to dictate to them what makes them profitable.
target: Ils ne veulent pas que nous leur dictions ce qui les rend rentables.
source: The cinema was ventilated and everyone returned in good order.
target: Le cinéma a été ventilé et tous les gens sont revenus dans l'ordre.


In [34]:
import math, numpy as np, plotly.graph_objects as go
from plotly.subplots import make_subplots

# ───────────  STYLE CONSTANTS  ───────────
ROW_H_PX   = 22          # y-axis spacing between consecutive token rows; also used to derive the figure height
FONT_SIZE  = 14          # font size passed to Plotly for token labels and the layout font
COL_GAP    = 1.8         # x-distance between the left and right token columns
LINE_ALPHA = 0.45        # alpha channel of the rgba line colour
LW_SCALE   = 4           # line width is LW_SCALE * weight
THRESH     = 0.1        # weights below this value are not drawn
# ─────────────────────────────────────────


# ─────────────────────  ONE HEAD  →  ONE SUBPLOT  ──────────────────────
def add_arc_subplot(fig, row, col,
                    weights, left_tok, right_tok,
                    *, show=True, left_label='src', right_label='tgt'):
    """
    Draw a two-column attention diagram for one weight matrix in subplot (row, col).

    Parameters
    ----------
    fig : plotly figure
        Figure the traces are added to.
    row, col : int
        Subplot position passed to `fig.add_trace` / `fig.update_*axes`.
    weights : 2-D array-like
        Indexed as `weights[i, j]`; row i pairs with `left_tok[i]`, column j
        with `right_tok[j]`.
    left_tok, right_tok : sequence of str
        Labels for the left and right token columns.
    show : bool, keyword-only
        Passed as `visible` to every trace.
    left_label, right_label : str, keyword-only
        Prefixes for the left / right token in the hover tooltip. The defaults
        reproduce the original "src" / "tgt" wording.

    Notes
    -----
    * the first token in each list is placed at y = 0
    * one blue line is drawn per weight >= THRESH
    * the shorter column simply ends; no lines are drawn for missing rows
    """
    T_q, T_k = weights.shape
    y_left = np.arange(T_q) * ROW_H_PX
    y_right = np.arange(T_k) * ROW_H_PX
    ymax = (max(T_q, T_k) - 1) * ROW_H_PX + ROW_H_PX / 2

    # Loop-invariant pieces of every line trace. Each line is sampled densely
    # so its invisible markers give a hover target along the whole segment.
    n_pts = 20
    xs = np.linspace(0.15, COL_GAP - 0.15, n_pts)
    line_color = f'rgba(65,105,225,{LINE_ALPHA})'

    # 1) lines, added first so the text traces are drawn over them
    for i in range(T_q):
        for j in range(T_k):
            w = float(weights[i, j])
            if w < THRESH:
                continue

            tooltip = (
                f"{left_label}: {left_tok[i]}"
                f"<br>{right_label}: {right_tok[j]}"
                f"<br>weight: {w:.3f}"
            )
            fig.add_trace(
                go.Scatter(
                    x=xs,
                    y=np.linspace(y_left[i], y_right[j], n_pts),
                    mode='lines+markers',
                    line=dict(width=LW_SCALE * w, color=line_color),
                    # fully transparent markers sitting on the line
                    marker=dict(size=8, color='rgba(0,0,0,0)'),
                    # same tooltip on every point of the line
                    hoverinfo='text',
                    hovertext=[tooltip] * n_pts,
                    hovertemplate="%{hovertext}<extra></extra>",
                    showlegend=False,
                    visible=show),
                row=row, col=col)

    # 2) token labels
    fig.add_trace(
        go.Scatter(x=np.zeros(T_q), y=y_left,
                   mode='text', text=left_tok,
                   textfont=dict(size=FONT_SIZE, color='black'),
                   textposition='middle right',
                   showlegend=False, visible=show),
        row=row, col=col)

    fig.add_trace(
        go.Scatter(x=np.full(T_k, COL_GAP), y=y_right,
                   mode='text', text=right_tok,
                   textfont=dict(size=FONT_SIZE, color='black'),
                   textposition='middle left',
                   showlegend=False, visible=show),
        row=row, col=col)

    # 3) axes
    fig.update_xaxes(visible=False, range=[-0.5, COL_GAP + .5], row=row, col=col)
    # NOTE(review): with autorange='reversed' Plotly likely computes the y-range
    # from the data and ignores the explicit `range` — confirm against the
    # Plotly axis reference before relying on `ymax`.
    fig.update_yaxes(visible=False, range=[-ROW_H_PX / 2, ymax],
                     autorange='reversed', row=row, col=col)


# ─────────────────────────────  DRIVER  ────────────────────────────────
def _escape_special_tokens(pieces):
    """Replace the literal '<s>' / '</s>' pieces with HTML-escaped text for Plotly labels."""
    escaped = {'<s>': '&lt;s&gt;', '</s>': ' &lt;/s&gt;'}
    return [escaped.get(piece, piece) for piece in pieces]


def plot_attn_weights(base_output_dir: Path, n_columns: int = 3, n_examples: int = 1):
    """
    Plot decoder attention weights for sampled examples and write HTML + JSON.

    Calls `get_attn_weights()` (which re-runs inference and prints each
    sentence pair). For each of the first `n_examples` results it draws one
    diagram with the predicted tokens in the left column and the source tokens
    in the right column, then writes `rnn_attention_weights_<idx>.json` and
    `.html` into `base_output_dir`.

    Parameters
    ----------
    base_output_dir : Path
        Directory for the output files; created if it does not exist.
    n_columns : int
        Only scales the figure width (`n_columns * 320`).
    n_examples : int
        Number of examples to plot. The default of 1 matches the original
        behaviour of plotting only the first example.
    """
    sources, preds, attn_weights = get_attn_weights()

    out_dir = Path(base_output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    n_rows = 1
    for idx in range(min(n_examples, len(sources))):
        fig = make_subplots(
            rows=1, cols=1,
            horizontal_spacing=.09, vertical_spacing=.03)

        source_pieces = _escape_special_tokens(
            [source_tokenizer.processor.id_to_piece(x.item()) for x in sources[idx][0]])
        pred_pieces = _escape_special_tokens(
            [target_tokenizer.processor.id_to_piece(x.item()) for x in preds[idx]])

        # Predicted (target-language) tokens go on the left and source tokens
        # on the right, so the tooltip labels are passed explicitly to match.
        add_arc_subplot(
            fig, row=1, col=1,
            weights=attn_weights[idx][0],
            left_tok=pred_pieces,
            right_tok=source_pieces,
            show=True,
            left_label='tgt',
            right_label='src')

        max_tok = max(len(pred_pieces), len(source_pieces))
        subplot_h = max_tok * ROW_H_PX + 40

        fig.update_layout(
            autosize=False,
            width=n_columns * 320,
            height=n_rows * subplot_h + 120,
            margin=dict(t=20, l=20, r=20, b=60),
            font=dict(size=FONT_SIZE, color='black'),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
        )
        for ann in fig.layout.annotations:
            ann.font.update(size=FONT_SIZE + 2, color='black')

        base = out_dir / f'rnn_attention_weights_{idx}'
        fig.write_json(f'{base}.json')
        fig.write_html(f'{base}.html')
# Render the attention figure and write its .json / .html files into the given assets directory.
plot_attn_weights(base_output_dir=Path('/Users/adam.amster/aamster.github.io/assets/plotly/2025-04-13-sequence_to_sequence_translation_2/'))

source: But why such optimism for some and pessimism for others?
target: Mais pourquoi un tel optimisme pour certains et le pessimisme pour les autres?
source: Regulatory authority over phone calls belongs to the Federal Communications Commission, not the FAA.
target: L'autorité réglementaire sur les appels téléphoniques appartient à la Commission fédérale des communications et non à la FAA.
source: They don't want us to dictate to them what makes them profitable.
target: Ils ne veulent pas que nous leur dictions ce qui les rend rentables.
source: The cinema was ventilated and everyone returned in good order.
target: Le cinéma a été ventilé et tous les gens sont revenus dans l'ordre.
