In [24]:
import os
from pathlib import Path

import numpy as np
import torch

from seq2seq_translation.datasets.datasets import LanguagePairsDatasets
from seq2seq_translation.sentence_pairs_dataset import SentencePairsDataset
from seq2seq_translation.tokenization.sentencepiece_tokenizer import SentencePieceTokenizer
from seq2seq_translation.run import _fix_model_state_dict

In [25]:
# Seed the global PyTorch and NumPy RNGs with the same fixed value
# (np.random.choice is used in a later cell to pick a long example).
torch.manual_seed(1234)
np.random.seed(1234)

In [26]:
# One SentencePiece tokenizer per side of the en -> fr pair, each loaded from a
# model prefix on local disk.
source_tokenizer = SentencePieceTokenizer(
    model_prefix='/Users/adam.amster/seq2seq_translation/tokenizer/30000/en',
)
target_tokenizer = SentencePieceTokenizer(
    model_prefix='/Users/adam.amster/seq2seq_translation/tokenizer/30000/fr',
)

In [27]:
# Set DEVICE=cpu in this process's environment.
# NOTE(review): presumably read by seq2seq_translation to choose the torch device — confirm it
# is read lazily, because the package was already imported in an earlier cell.
os.environ['DEVICE'] = 'cpu'

In [28]:
def construct_test_dset():
    """Build the tokenized en -> fr test dataset.

    Loads the language-pair datasets from disk with ``is_test=True`` and wraps
    them in a ``SentencePairsDataset`` that spans every index, tokenizes with the
    module-level ``source_tokenizer`` / ``target_tokenizer`` and applies no
    length cap (``max_length=None``).

    Returns:
        The ``SentencePairsDataset`` covering the whole test set.
    """
    language_pairs = LanguagePairsDatasets(
        out_dir=Path('/Users/adam.amster/seq2seq_translation/datasets/wmt14_test'),
        source_lang='en',
        target_lang='fr',
        is_test=True,
    )
    every_idx = np.arange(len(language_pairs))

    # NOTE(review): EOS/PAD ids come from the source tokenizer only — presumably
    # both tokenizers share their special-token ids; confirm.
    return SentencePairsDataset(
        datasets=language_pairs,
        idxs=every_idx,
        source_tokenizer=source_tokenizer,
        target_tokenizer=target_tokenizer,
        max_length=None,
        eos_token_id=source_tokenizer.eot_idx,
        pad_token_id=source_tokenizer.pad_idx,
    )

In [29]:
# Build the test dataset once; later cells filter it by length and index into it.
test_dset = construct_test_dset()

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [30]:
from seq2seq_translation.config.transformer_config import TransformerConfig
import json
from seq2seq_translation.models.transformer.encoder_decoder import EncoderDecoderTransformer


def construct_model(
    config_path='/Users/adam.amster/Downloads/train_config-2.json',
    checkpoint_path='/Users/adam.amster/Downloads/ckpt.pt',
):
    """Instantiate the encoder-decoder transformer and load trained weights.

    The architecture hyper-parameters come from the JSON training config, the
    vocabulary size and special-token ids come from the module-level
    ``source_tokenizer``, and the weights come from the ``"model"`` entry of the
    checkpoint (passed through ``_fix_model_state_dict`` before loading).

    Args:
        config_path: Path to the JSON training config. Defaults to the location
            that was previously hard-coded in this function.
        checkpoint_path: Path to the ``torch.save``'d checkpoint containing a
            ``"model"`` state dict. Defaults to the previously hard-coded path.

    Returns:
        The ``EncoderDecoderTransformer`` in eval mode, with tensors mapped to CPU.
    """
    with open(config_path) as f:
        raw_config = json.load(f)
    # Keep the raw dict and the validated config under different names so the
    # type of ``config`` does not change half-way through the function.
    config = TransformerConfig.model_validate(raw_config)

    sp_processor = source_tokenizer.processor
    model = EncoderDecoderTransformer(
        n_attention_heads=config.n_head,
        n_layers=config.num_layers,
        vocab_size=source_tokenizer.vocab_size,
        d_model=config.d_model,
        block_size=config.max_input_length,
        feedforward_hidden_dim=config.feedforward_hidden_dim,
        sos_token_id=sp_processor.bos_id(),
        eos_token_id=sp_processor.eos_id(),
        pad_token_id=sp_processor.pad_id(),
        norm_first=config.norm_first,
        mlp_activation=config.activation,
        positional_encoding_type=config.positional_encoding_type,
    )

    # SECURITY: torch.load unpickles the file, so only load checkpoints from a
    # trusted source. NOTE(review): consider weights_only=True if the checkpoint
    # holds nothing but tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(_fix_model_state_dict(checkpoint["model"]))

    model.eval()
    return model

In [31]:
# Load the trained model once; the attention-extraction code below reads this global.
encoder_decoder = construct_model()

In [32]:
def get_test_dset_examples(min_len: int = 9, max_len: int = 19, dataset=None):
    """Select examples whose first element has a length within ``[min_len, max_len]``.

    Args:
        min_len: Smallest accepted ``len(example[0])`` (inclusive). Defaults to 9.
        max_len: Largest accepted ``len(example[0])`` (inclusive). Defaults to 19.
        dataset: Iterable of examples to filter. When ``None`` (the default) the
            module-level ``test_dset`` is used, which matches the previous
            zero-argument behaviour.

    Returns:
        A list of the matching examples, in iteration order.
    """
    if dataset is None:
        dataset = test_dset
    return [example for example in dataset if min_len <= len(example[0]) <= max_len]

In [33]:
# Length-filtered examples from get_test_dset_examples(), plus a separate pool of
# long examples (more than 19 tokens in the first element).
test_set = get_test_dset_examples()
long_examples = [pair for pair in test_dset if len(pair[0]) > 19]
# Fixed positions into `test_set` used by the plotting cells (hard-coded, not
# sampled at run time despite the name).
rand_idxs = [592, 533, 463, 187]

In [43]:
# Spot-check: decode the second element (target ids) of test example 302 back to text.
target_tokenizer.decode(test_dset[302][1])

'Chaque constructeur planifie de façon différente.'

In [11]:
from enum import Enum


class AttnType(Enum):
    """Selects which attention weights to extract and plot."""

    # Self-attention inside the encoder.
    ENCODER_SELF = 'ENCODER_SELF'
    # Self-attention inside the decoder.
    DECODER_SELF = 'DECODER_SELF'
    # Decoder-to-encoder cross-attention.
    DECODER_CROSS = 'DECODER_CROSS'

In [17]:
def get_attn_weights(attention_type: AttnType, examples):
    """Run the model on each example and collect the requested attention weights.

    For every example the source is run through the module-level
    ``encoder_decoder`` with gradients disabled, a prediction is generated with
    ``top_k=1``, and the decoded source / target / prediction are printed.

    Args:
        attention_type: ``AttnType.ENCODER_SELF`` or ``AttnType.DECODER_CROSS``.
            ``AttnType.DECODER_SELF`` has no extraction path here.
        examples: Iterable of examples where ``ex[0]`` holds the source ids and
            ``ex[1]`` the target ids, as tensors; a leading batch dimension is
            added with ``unsqueeze(0)``.

    Returns:
        ``(sources, preds, attn_weights)`` — three parallel lists with one entry
        per example: the batched source tensor, the generated prediction, and
        whatever attention-weight object the model returned.

    Raises:
        NotImplementedError: If ``attention_type`` is not one of the supported
            members. Previously this surfaced as an ``UnboundLocalError`` only
            after inference had already run.
    """
    supported = (AttnType.ENCODER_SELF, AttnType.DECODER_CROSS)
    if attention_type not in supported:
        raise NotImplementedError(
            f'get_attn_weights does not support {attention_type!r}; '
            f'supported types: {[t.name for t in supported]}'
        )

    sources = []
    preds = []
    attn_weights = []
    for ex in examples:
        source = ex[0].unsqueeze(0)
        target = ex[1].unsqueeze(0)
        with torch.no_grad():
            if attention_type == AttnType.ENCODER_SELF:
                # NOTE(review): the padding mask is all True — presumably True
                # means "real token" in this model's convention; confirm.
                attention_weights = encoder_decoder.encoder(
                    source,
                    src_key_padding_mask=torch.ones_like(source, dtype=torch.bool),
                    return_attention_weights=True,
                )
            else:  # AttnType.DECODER_CROSS (validated above)
                _, _, attention_weights = encoder_decoder.generate(
                    src=source,
                    top_k=1,
                    return_cross_attention_weights=True,
                )
            # NOTE(review): for DECODER_CROSS this is a second generate() call;
            # the first call's leading return value is presumably the same
            # prediction and could be reused — confirm the return order first.
            pred, _ = encoder_decoder.generate(
                src=source,
                top_k=1,
            )
            print(f'source: {source_tokenizer.decode(source[0])}')
            print(f'target: {target_tokenizer.decode(target[0])}')
            print(f'pred: {target_tokenizer.decode(pred)}')
        sources.append(source)
        preds.append(pred)
        attn_weights.append(attention_weights)
    return sources, preds, attn_weights

In [15]:
import math, numpy as np, plotly.graph_objects as go
from plotly.subplots import make_subplots
from enum import Enum

# ───────────  STYLE CONSTANTS  ───────────
# Vertical distance between consecutive token rows.
ROW_H_PX = 22
# Font size used for token labels and as the base size for layout text.
FONT_SIZE = 14
# x-distance between the left and right token columns.
COL_GAP = 1.8
# Opacity of each connecting line.
LINE_ALPHA = 0.45
# Line width is LW_SCALE multiplied by the attention weight.
LW_SCALE = 4
# Weights below this value are not drawn.
THRESH = 0.02
# ─────────────────────────────────────────

class AttnType(Enum):
    """Selects which attention weights to extract and plot.

    NOTE(review): an ``AttnType`` with string values is also defined in an
    earlier cell. Running this cell rebinds the name, so members obtained from
    the earlier class no longer compare equal to these ones.
    """

    ENCODER_SELF = 1
    DECODER_SELF = 2
    DECODER_CROSS = 3


# ─────────────────────  ONE HEAD  →  ONE SUBPLOT  ──────────────────────
def add_arc_subplot(fig, row, col,
                    weights, left_tok, right_tok,
                    *, show=True):
    """
    Render one attention head as a two-column token diagram in subplot (row, col).

    ``weights`` is unpacked as a 2-D matrix; its first axis is laid out down the
    left column (labelled with ``left_tok``) and its second axis down the right
    column (labelled with ``right_tok``), with the first token of each list on
    the first row (y = 0). Every weight that is not below THRESH becomes one
    straight blue segment of width ``LW_SCALE * weight`` joining the two rows.
    ``show`` is used as the initial ``visible`` flag of every trace added.
    """
    n_left, n_right = weights.shape
    left_y = np.arange(n_left) * ROW_H_PX
    right_y = np.arange(n_right) * ROW_H_PX
    y_limit = (max(n_left, n_right) - 1) * ROW_H_PX + ROW_H_PX / 2

    # Connecting lines go in first so the labels added afterwards sit on top.
    for left_idx, y_from in enumerate(left_y):
        for right_idx, y_to in enumerate(right_y):
            weight = float(weights[left_idx, right_idx])
            if weight < THRESH:
                continue
            segment = go.Scatter(
                x=[0.15, COL_GAP - 0.15],
                y=[y_from, y_to],
                mode='lines',
                line=dict(width=LW_SCALE * weight,
                          color=f'rgba(65,105,225,{LINE_ALPHA})'),
                hoverinfo='none',
                showlegend=False,
                visible=show)
            fig.add_trace(segment, row=row, col=col)

    # One text trace per column: (x positions, y positions, labels, text anchor).
    label_columns = (
        (np.zeros(n_left), left_y, left_tok, 'middle right'),
        (np.full(n_right, COL_GAP), right_y, right_tok, 'middle left'),
    )
    for xs, ys, labels, anchor in label_columns:
        fig.add_trace(
            go.Scatter(x=xs, y=ys,
                       mode='text', text=labels,
                       textfont=dict(size=FONT_SIZE, color='black'),
                       textposition=anchor,
                       showlegend=False, visible=show),
            row=row, col=col)

    # Hide both axes; leave horizontal padding around the two columns.
    fig.update_xaxes(visible=False, range=[-0.5, COL_GAP + .5], row=row, col=col)
    # NOTE(review): both an explicit range and autorange='reversed' are passed —
    # confirm which of the two Plotly honours.
    fig.update_yaxes(visible=False, range=[-ROW_H_PX / 2, y_limit],
                     autorange='reversed', row=row, col=col)


# ─────────────────────────────  DRIVER  ────────────────────────────────
def _token_pieces(tokenizer, token_ids):
    """Convert an iterable of scalar token-id tensors to SentencePiece piece strings."""
    return [tokenizer.processor.id_to_piece(tok.item()) for tok in token_ids]


def _layer_slider_steps(traces_by_layer, n_traces):
    """Build one slider step per layer that makes only that layer's traces visible."""
    steps = []
    for layer_idx, trace_ids in enumerate(traces_by_layer):
        visibility = [False] * n_traces
        for trace_id in trace_ids:
            visibility[trace_id] = True
        steps.append(dict(label=f'Layer {layer_idx+1}',
                          method='update',
                          args=[{'visible': visibility}]))
    return steps


def plot_attn_weights(examples, base_output_dir: Path, attn_type: AttnType, n_columns: int = 3,
                      max_examples: int = 1):
    """
    Build a multi-head / multi-layer arc-diagram figure per example and write it to HTML+JSON.

    Attention weights are obtained from ``get_attn_weights`` for all ``examples``.
    For each plotted example, one subplot per head is laid out in a grid of
    ``n_columns`` columns, and a slider switches between layers by toggling
    trace visibility. Output goes to
    ``<base_output_dir>/multi_head_<tag>_attention_weights_<i>.json`` and ``.html``,
    where ``<i>`` is the example's position in ``examples``.

    Args:
        examples: Examples forwarded unchanged to ``get_attn_weights``.
        base_output_dir: Directory the files are written into (not created here).
        attn_type: Which attention weights to plot; also selects the filename tag.
        n_columns: Number of subplot columns.
        max_examples: Upper bound on how many of the leading examples are plotted.
            Defaults to 1, matching the previously hard-coded ``range(1)``. It is
            clamped to the number of examples, so an empty ``examples`` writes
            nothing instead of raising ``IndexError``.

    Raises:
        KeyError: If ``attn_type`` is not a member of the current ``AttnType``
            (raised before any inference is run).
    """
    # Resolve the filename tag first so an unknown attn_type fails before inference.
    tag = {AttnType.ENCODER_SELF:  'encoder_self',
           AttnType.DECODER_SELF:  'decoder_self',
           AttnType.DECODER_CROSS: 'decoder_cross'}[attn_type]

    sources, preds, attn = get_attn_weights(attention_type=attn_type, examples=examples)

    # '<s>' and '</s>' are replaced with HTML entities before being used as labels.
    escapes = {'<s>': '&lt;s&gt;', '</s>': '&lt;/s&gt;'}

    for ex in range(min(max_examples, len(attn))):
        layer_weights = attn[ex]
        n_layers = len(layer_weights)
        n_heads = layer_weights[0].shape[1]
        n_rows = math.ceil(n_heads / n_columns)

        fig = make_subplots(
            rows=n_rows, cols=n_columns,
            horizontal_spacing=.09, vertical_spacing=.03,
            subplot_titles=[f'Head {h+1}' for h in range(n_heads)])

        # Token labels for this sentence.
        # NOTE(review): preds[ex][0] assumes the prediction carries a leading
        # batch dimension like sources[ex] does — confirm against generate().
        if attn_type == AttnType.ENCODER_SELF:
            left = right = _token_pieces(source_tokenizer, sources[ex][0])
        elif attn_type == AttnType.DECODER_SELF:
            left = right = _token_pieces(target_tokenizer, preds[ex][0])
        else:  # DECODER_CROSS
            left = _token_pieces(target_tokenizer, preds[ex][0])
            right = _token_pieces(source_tokenizer, sources[ex][0])

        left = [escapes.get(tok, tok) for tok in left]
        right = [escapes.get(tok, tok) for tok in right]

        # Draw every (layer, head) pair; remember which trace indices belong to
        # each layer so the slider can toggle them. Only layer 0 starts visible.
        traces_by_layer = [[] for _ in range(n_layers)]
        for lay in range(n_layers):
            for hd in range(n_heads):
                first_new = len(fig.data)
                add_arc_subplot(
                    fig,
                    hd // n_columns + 1,
                    hd % n_columns + 1,
                    weights=layer_weights[lay][0, hd],
                    left_tok=left,
                    right_tok=right,
                    show=(lay == 0))
                traces_by_layer[lay].extend(range(first_new, len(fig.data)))

        steps = _layer_slider_steps(traces_by_layer, len(fig.data))

        max_tok = max(len(left), len(right))
        subplot_h = max_tok * ROW_H_PX + 40

        fig.update_layout(
            sliders=[dict(active=0, steps=steps,
                          pad={'t':60,'b':10},
                          currentvalue=dict(font=dict(size=FONT_SIZE+1,
                                                      color='black')))],
            autosize=False,
            width  = n_columns * 320,
            height = n_rows * subplot_h + 120,
            margin=dict(t=20,l=20,r=20,b=60),
            font=dict(size=FONT_SIZE, color='black'),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
        )
        # Subplot titles are stored as layout annotations; restyle them to match.
        for ann in fig.layout.annotations:
            ann.font.update(size=FONT_SIZE+2, color='black')

        base = (f'{base_output_dir}/'
                f'multi_head_{tag}_attention_weights_{ex}')
        fig.write_json(base+'.json')
        fig.write_html(base+'.html')

In [99]:
# Encoder self-attention for the fixed `rand_idxs` examples, written to the blog's asset directory.
plot_attn_weights(
    attn_type=AttnType.ENCODER_SELF,
    examples=[test_set[idx] for idx in rand_idxs],
    base_output_dir=Path('/Users/adam.amster/aamster.github.io/assets/plotly/2025-04-13-sequence_to_sequence_translation_2/'),
)

source: But why such optimism for some and pessimism for others?
target: Les raisons d'un tel optimisme, chez les uns, et pessimisme, chez les autres ?
pred: Mais pourquoi un tel optimisme pour certains et le pessimisme pour d'autres?
source: Regulatory authority over phone calls belongs to the Federal Communications Commission, not the FAA.
target: Le pouvoir réglementaire concernant les téléphones portables appartient à la Federal Communications Commission et non à la FAA.
pred: Le pouvoir de réglementation des appels téléphoniques appartient à la Commission fédérale des communications, et non à la FAA.
source: They don't want us to dictate to them what makes them profitable.
target: Elles ne veulent pas qu'on leur dise ce qui leur permettra d'être rentables.
pred: Ils ne veulent pas que nous leur dictions ce qui les rend rentables.
source: The cinema was ventilated and everyone returned in good order.
target: La salle a été aérée et tout est rentré dans l'ordre.
pred: Le cinéma a ét

In [93]:
# Decoder cross-attention for the same fixed `rand_idxs` examples, same output directory.
plot_attn_weights(
    attn_type=AttnType.DECODER_CROSS,
    examples=[test_set[idx] for idx in rand_idxs],
    base_output_dir=Path('/Users/adam.amster/aamster.github.io/assets/plotly/2025-04-13-sequence_to_sequence_translation_2/'),
)

source: But why such optimism for some and pessimism for others?
target: Les raisons d'un tel optimisme, chez les uns, et pessimisme, chez les autres ?
pred: Mais pourquoi un tel optimisme pour certains et le pessimisme pour d'autres?
source: Regulatory authority over phone calls belongs to the Federal Communications Commission, not the FAA.
target: Le pouvoir réglementaire concernant les téléphones portables appartient à la Federal Communications Commission et non à la FAA.
pred: Le pouvoir de réglementation des appels téléphoniques appartient à la Commission fédérale des communications, et non à la FAA.
source: They don't want us to dictate to them what makes them profitable.
target: Elles ne veulent pas qu'on leur dise ce qui leur permettra d'être rentables.
pred: Ils ne veulent pas que nous leur dictions ce qui les rend rentables.
source: The cinema was ventilated and everyone returned in good order.
target: La salle a été aérée et tout est rentré dans l'ordre.
pred: Le cinéma a ét

In [22]:
# Encoder self-attention for one long example drawn with np.random.choice; output goes to /tmp.
plot_attn_weights(
    attn_type=AttnType.ENCODER_SELF,
    examples=[long_examples[np.random.choice(len(long_examples))]],
    base_output_dir=Path('/tmp'),
)

source: While Murphy said the purpose of the trip is to help improve relationships, he said some "tough love" will also be dispensed.
target: Murphy a déclaré que le but de ce voyage était de permettre d'améliorer les relations, mais également de faire preuve de « fermeté affectueuse ».
pred: Bien que le but du voyage soit d'améliorer les relations, il a dit que l'on dispensera aussi certains « amours très serrés ».
