In [1]:
import os
from pathlib import Path

import numpy as np
import torch

from seq2seq_translation.datasets.datasets import LanguagePairsDatasets
from seq2seq_translation.models.transformer.decoder import DecoderTransformer
from seq2seq_translation.sentence_pairs_dataset import SentencePairsDataset
from seq2seq_translation.tokenization.sentencepiece_tokenizer import SentencePieceTokenizer
from seq2seq_translation.run import _fix_model_state_dict

In [2]:
# One seed shared by torch and numpy so both RNG streams are reproducible.
SEED = 1234
torch.random.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Path prefix of the trained SentencePiece model files.
# NOTE(review): absolute local path; this only runs on the author's machine.
TOKENIZER_MODEL_PREFIX = '/Users/adam.amster/seq2seq_translation/tokenizer/32000/32000'

tokenizer = SentencePieceTokenizer(
    model_prefix=TOKENIZER_MODEL_PREFIX,
    include_language_tag=True,
)

In [4]:
# Set the DEVICE environment variable to 'cpu' (presumably read by
# seq2seq_translation to choose where tensors live — TODO confirm).
# NOTE(review): this runs after the seq2seq_translation imports above; if the
# package reads DEVICE at import time, this assignment comes too late.
os.environ['DEVICE'] = 'cpu'

In [5]:
import json
from seq2seq_translation.config.transformer_config import TransformerConfig

# Read the raw training config, then validate it into a TransformerConfig.
# The raw dict gets its own name so `config` only ever means the validated object.
with open('/Users/adam.amster/Downloads/train_config_decoder.json') as config_file:
    raw_config = json.load(config_file)

config = TransformerConfig.model_validate(raw_config)

In [6]:
def construct_test_dset(data_path=None):
    """Build the tokenized WMT14 test dataset used for the attention plots.

    Relies on the notebook-level ``tokenizer`` and ``config`` objects.

    Args:
        data_path: Location of the ``wmt14-test.arrow`` file. When ``None``
            (the default) the original hard-coded local path is used, so
            existing zero-argument calls behave as before.

    Returns:
        A ``SentencePairsDataset`` covering every row of the test split
        (``idxs`` spans the full dataset, ``max_length=None``), with source and
        target combined into one sequence (``combine_source_and_target=True``).
    """
    if data_path is None:
        data_path = Path('/Users/adam.amster/seq2seq_translation/datasets/wmt14_test/wmt___wmt14/fr-en/0.0.0/b199e406369ec1b7634206d3ded5ba45de2fe696/wmt14-test.arrow')

    # NOTE(review): the language pair is hard-coded to en -> fr here, while the
    # language-tag ids below come from `config`; confirm the two always agree.
    test_datasets = LanguagePairsDatasets(
        data_path=Path(data_path),
        source_lang='en',
        target_lang='fr',
        is_test=True,
    )

    test_dset = SentencePairsDataset(
        datasets=test_datasets,
        idxs=np.arange(len(test_datasets)),
        combined_tokenizer=tokenizer,
        max_length=None,
        eos_token_id=tokenizer.eot_idx,
        pad_token_id=tokenizer.pad_idx,
        combine_source_and_target=True,
        source_language_tag_token_id=tokenizer.language_tag_map[config.source_lang],
        target_language_tag_token_id=tokenizer.language_tag_map[config.target_lang],
    )
    return test_dset

In [7]:
# Build the test dataset once; later cells index into it by position.
test_dset = construct_test_dset()

In [8]:
def construct_model(checkpoint_path='/Users/adam.amster/Downloads/ckpt_decoder.pt'):
    """Instantiate the decoder-only transformer and load trained weights.

    Hyperparameters come from the notebook-level ``config``; vocabulary size
    and the pad id come from the notebook-level ``tokenizer``.

    Args:
        checkpoint_path: File passed to ``torch.load``. The loaded object must
            be indexable with the key ``"model"``. Defaults to the original
            hard-coded path, so zero-argument calls behave as before.

    Returns:
        The ``DecoderTransformer`` with weights loaded, switched to eval mode.
    """
    model = DecoderTransformer(
        n_attention_heads=config.n_head,
        n_layers=config.num_layers,
        vocab_size=tokenizer.vocab_size,
        d_model=config.d_model,
        block_size=config.fixed_length,
        feedforward_hidden_dim=config.feedforward_hidden_dim,
        norm_first=config.norm_first,
        mlp_activation=config.activation,
        use_cross_attention=False,
        positional_encoding_type=config.positional_encoding_type,
        pad_token_idx=tokenizer.pad_idx,
    )

    # SECURITY: torch.load unpickles the file, so only load checkpoints from a
    # trusted source. NOTE(review): consider passing weights_only=True if the
    # checkpoint contains nothing but tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # The stored state dict is passed through _fix_model_state_dict before loading.
    model.load_state_dict(_fix_model_state_dict(checkpoint["model"]))

    model.eval()
    return model

In [9]:
# Instantiate the decoder and load the checkpoint weights (model is returned in eval mode).
model = construct_model()

In [10]:
# Find the dataset position(s) whose raw source sentence matches the query.
query_sentence = "The cinema was ventilated and everyone returned in good order."
[idx for idx, _ in enumerate(test_dset) if test_dset._datasets[idx][0] == query_sentence]

[634]

In [11]:
# Examples whose first element (the source sequence) is longer than this many entries.
MIN_SOURCE_LEN = 19
long_examples = [example for example in test_dset if len(example[0]) > MIN_SOURCE_LEN]

# Fixed test-set positions to visualize; 634 is the sentence located in the previous cell.
rand_idxs = [2184, 1956, 1679, 634]

In [72]:
def get_attn_weights(examples):
    """Generate a prediction for each example and collect its attention weights.

    Relies on the notebook-level ``model`` and ``tokenizer``. For every example
    the decoded source, target and prediction are printed.

    Args:
        examples: Iterable of items where ``ex[0]`` is the source token tensor
            and ``ex[1]`` the target token tensor; each gets a leading batch
            dimension via ``unsqueeze(0)``.

    Returns:
        ``(preds, attn_weights)``: two lists parallel to ``examples`` holding
        the two values returned by
        ``model.generate(..., return_attention_weights=True)``.
    """
    preds = []
    attn_weights = []
    for ex in examples:
        source = ex[0].unsqueeze(0)
        target = ex[1].unsqueeze(0)

        # Inference only: no gradients are needed for generation.
        with torch.no_grad():
            pred, attention_weights = model.generate(
                x=source,
                top_k=1,
                return_attention_weights=True,
                eot_token_id=tokenizer.eot_idx,
                pad_token_id=tokenizer.pad_idx,
                include_input=True,
            )

        print(f'source: {tokenizer.decode(source[0])}')
        print(f'target: {tokenizer.decode(target[0])}')
        print(f'pred: {tokenizer.decode(pred)}')

        preds.append(pred)
        attn_weights.append(attention_weights)
    return preds, attn_weights

In [82]:
import math, numpy as np, plotly.graph_objects as go
from plotly.subplots import make_subplots

# ───────────  STYLE CONSTANTS  ───────────
ROW_H_PX   = 22          # y-distance between consecutive token rows; also used to size the figure height
FONT_SIZE  = 14          # base font size for token labels and layout text
COL_GAP    = 1.8         # x-position of the right token column (left column sits at x = 0)
LINE_ALPHA = 0.45        # opacity of the attention lines
LW_SCALE   = 4           # line width is LW_SCALE * weight
THRESH     = 0.1         # weights strictly below THRESH are not drawn
# ─────────────────────────────────────────


# ─────────────────────  ONE HEAD  →  ONE SUBPLOT  ──────────────────────
def add_arc_subplot(fig, row, col,
                    weights, left_tok, right_tok,
                    *, show=True):
    """
    Render one attention head as a two-column line diagram in subplot (row, col).

    Tokens in ``left_tok`` are stacked at x = 0 and tokens in ``right_tok`` at
    x = COL_GAP, one per row. For every ``weights[i, j]`` that is not below
    THRESH, a line is drawn from left row ``i`` to right row ``j`` with width
    ``LW_SCALE * weight``.

    Traces are added in a fixed order: all lines first, then the left labels,
    then the right labels. Every trace gets ``visible=show``.

    NOTE(review): the y-axis receives both an explicit ``range`` and
    ``autorange='reversed'``; confirm which of the two Plotly honours.
    """
    n_queries, n_keys = weights.shape
    left_ys = np.arange(n_queries) * ROW_H_PX
    right_ys = np.arange(n_keys) * ROW_H_PX
    y_bottom = (max(n_queries, n_keys) - 1) * ROW_H_PX + ROW_H_PX / 2

    # --- attention lines (added first so the labels are drawn on top) ---
    n_pts = 20  # points per line; hover markers sit on each of them
    line_xs = np.linspace(0.15, COL_GAP - 0.15, n_pts)  # identical for every line
    line_color = f'rgba(65,105,225,{LINE_ALPHA})'
    hidden_marker = dict(size=8, color='rgba(0,0,0,0)')

    for q_idx in range(n_queries):
        for k_idx in range(n_keys):
            weight = float(weights[q_idx, k_idx])
            if weight < THRESH:
                continue

            tooltip = (
                f"src: {left_tok[q_idx]}"
                f"<br>tgt: {right_tok[k_idx]}"
                f"<br>weight: {weight:.3f}"
            )
            line_trace = go.Scatter(
                x=line_xs,
                y=np.linspace(left_ys[q_idx], right_ys[k_idx], n_pts),
                mode='lines+markers',
                line=dict(width=LW_SCALE * weight, color=line_color),
                # transparent markers give the hover something to land on
                marker=hidden_marker,
                hoverinfo='text',
                hovertext=[tooltip] * n_pts,
                hovertemplate="%{hovertext}<extra></extra>",
                showlegend=False,
                visible=show,
            )
            fig.add_trace(line_trace, row=row, col=col)

    # --- token labels: left column first, then right column ---
    label_columns = (
        (np.zeros(n_queries), left_ys, left_tok, 'middle right'),
        (np.full(n_keys, COL_GAP), right_ys, right_tok, 'middle left'),
    )
    for label_xs, label_ys, tokens, anchor in label_columns:
        label_trace = go.Scatter(
            x=label_xs, y=label_ys,
            mode='text', text=tokens,
            textfont=dict(size=FONT_SIZE, color='black'),
            textposition=anchor,
            showlegend=False, visible=show,
        )
        fig.add_trace(label_trace, row=row, col=col)

    # --- hide both axes; y is reversed so the first token is at the top ---
    fig.update_xaxes(visible=False, range=[-0.5, COL_GAP + .5], row=row, col=col)
    fig.update_yaxes(visible=False, range=[-ROW_H_PX / 2, y_bottom],
                     autorange='reversed', row=row, col=col)


# ─────────────────────────────  DRIVER  ────────────────────────────────
def _token_labels(token_ids):
    """Turn generated token ids into display strings for the subplot labels."""
    src_tag_id = tokenizer.language_tag_map[config.source_lang]
    tgt_tag_id = tokenizer.language_tag_map[config.target_lang]
    # Angle brackets are written as entities so the labels show literally
    # and are not treated as markup.
    escaped_specials = {'<s>': '&lt;s&gt;', '</s>': '&lt;/s&gt;'}

    labels = []
    for tok_id in token_ids:
        if tok_id == src_tag_id:
            labels.append(f'&lt;{config.source_lang}&gt;')
        elif tok_id == tgt_tag_id:
            labels.append(f'&lt;{config.target_lang}&gt;')
        else:
            piece = tokenizer.processor.id_to_piece(tok_id.item())
            labels.append(escaped_specials.get(piece, piece))
    return labels


def plot_attn_weights(examples, base_output_dir: Path, n_columns: int = 3,
                      max_examples: int = 1):
    """
    Build one multi-head / multi-layer attention figure per example and write
    each to ``<base_output_dir>/multi_head_decoder_attention_weights_<i>.json``
    and ``.html``.

    Each figure has one subplot per head, arranged in ``n_columns`` columns. A
    slider switches between layers by toggling trace visibility, and layer 1
    is shown initially.

    Args:
        examples: Passed straight to ``get_attn_weights``; generation runs for
            all of them even when fewer are plotted.
        base_output_dir: Directory for the output files; created if missing.
        n_columns: Number of subplot columns.
        max_examples: How many of the leading examples to plot. Defaults to 1,
            which matches the previous hard-coded behaviour; pass ``None`` to
            plot every example.

    Indexing used on the attention structure (layout otherwise not visible
    here — TODO confirm): ``attn[ex][0]`` is a per-layer sequence and
    ``attn[ex][0][layer][0, head]`` is the 2-D weight matrix for one head.
    """
    preds, attn = get_attn_weights(examples=examples)

    # Clamp to what is available so an empty or short input does not raise IndexError.
    n_examples = len(preds) if max_examples is None else min(max_examples, len(preds))

    output_dir = Path(base_output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for ex in range(n_examples):
        layer_weights = attn[ex][0]
        L = len(layer_weights)                # layers
        H = layer_weights[0].shape[1]         # heads
        n_rows = math.ceil(H / n_columns)

        fig = make_subplots(
            rows=n_rows, cols=n_columns,
            horizontal_spacing=.09, vertical_spacing=.03,
            subplot_titles=[f'Head {h+1}' for h in range(H)])

        # Self-attention: the same token sequence labels both columns.
        left = _token_labels(preds[ex][0])
        right = left

        # Build every (layer, head) subplot and remember which trace indices
        # belong to which layer so the slider can toggle them.
        traces_by_layer = [[] for _ in range(L)]
        for lay in range(L):
            for hd in range(H):
                r = hd // n_columns + 1
                c = hd % n_columns + 1
                start = len(fig.data)

                add_arc_subplot(
                    fig, r, c,
                    weights=layer_weights[lay][0, hd],
                    left_tok=left,
                    right_tok=right,
                    show=(lay == 0))          # only the first layer starts visible

                traces_by_layer[lay].extend(range(start, len(fig.data)))

        # One slider step per layer: show that layer's traces and hide the rest.
        n_traces = len(fig.data)
        steps = []
        for lay, ids in enumerate(traces_by_layer):
            layer_ids = set(ids)
            vis = [i in layer_ids for i in range(n_traces)]
            steps.append(dict(label=f'Layer {lay+1}',
                              method='update',
                              args=[{'visible': vis}]))

        max_tok = max(len(left), len(right))
        subplot_h = max_tok * ROW_H_PX + 40

        fig.update_layout(
            sliders=[dict(active=0, steps=steps,
                          pad={'t': 60, 'b': 10},
                          currentvalue=dict(font=dict(size=FONT_SIZE+1,
                                                      color='black')))],
            autosize=False,
            width=n_columns * 320,
            height=n_rows * subplot_h + 120,
            margin=dict(t=20, l=20, r=20, b=60),
            font=dict(size=FONT_SIZE, color='black'),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
        )
        # Subplot titles are stored as layout annotations.
        for ann in fig.layout.annotations:
            ann.font.update(size=FONT_SIZE+2, color='black')

        base = str(output_dir / f'multi_head_decoder_attention_weights_{ex}')
        fig.write_json(base + '.json')
        fig.write_html(base + '.html')
# Generate and save the attention figures for the hand-picked test examples.
selected_examples = [test_dset[idx] for idx in rand_idxs]
plot_output_dir = Path('/Users/adam.amster/aamster.github.io/assets/plotly/2025-04-13-sequence_to_sequence_translation_2/')
plot_attn_weights(examples=selected_examples, base_output_dir=plot_output_dir)

source: But why such optimism for some and pessimism for others?
target: Les raisons d'un tel optimisme, chez les uns, et pessimisme, chez les autres ?
pred: But why such optimism for some and pessimism for others? Mais pourquoi un tel optimisme pour certains et un tel pessimisme pour les autres?
source: Regulatory authority over phone calls belongs to the Federal Communications Commission, not the FAA.
target: Le pouvoir réglementaire concernant les téléphones portables appartient à la Federal Communications Commission et non à la FAA.
pred: Regulatory authority over phone calls belongs to the Federal Communications Commission, not the FAA. L'autorité réglementaire sur les appels téléphoniques appartient à la Commission fédérale des communications et non à la LGFP.
source: They don't want us to dictate to them what makes them profitable.
target: Elles ne veulent pas qu'on leur dise ce qui leur permettra d'être rentables.
pred: They don't want us to dictate to them what makes them prof