In [1]:
import os
from pathlib import Path

import numpy as np
import torch

from seq2seq_translation.datasets.datasets import LanguagePairsDatasets
from seq2seq_translation.sentence_pairs_dataset import SentencePairsDataset
from seq2seq_translation.tokenization.sentencepiece_tokenizer import SentencePieceTokenizer
from seq2seq_translation.run import _fix_model_state_dict

In [2]:
# Seed both the PyTorch and the legacy NumPy global RNGs with the same value
# so that anything random downstream is repeatable across kernel restarts.
SEED = 1234
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Module-level tokenizers used by the dataset, model construction and plotting cells below.
# `model_prefix` points at the English (source) and French (target) SentencePiece models.
# NOTE(review): absolute, user-specific paths — consider a configurable base directory.
source_tokenizer = SentencePieceTokenizer(model_prefix='/Users/adam.amster/seq2seq_translation/tokenizer/30000/en')
target_tokenizer = SentencePieceTokenizer(model_prefix='/Users/adam.amster/seq2seq_translation/tokenizer/30000/fr')

In [4]:
# Select the CPU via the DEVICE environment variable.
# NOTE(review): this runs after `seq2seq_translation` was imported in the first cell; if the
# package reads DEVICE at import time the setting comes too late — confirm, or move it above the imports.
os.environ['DEVICE'] = 'cpu'

In [5]:
def construct_test_dset():
    """Build the en→fr test dataset as a tokenized `SentencePairsDataset`.

    The language pairs are loaded from the `wmt14_test` directory in test mode and
    every example is kept (`idxs` spans the whole dataset, no length cap). Tokenization
    uses the module-level `source_tokenizer` / `target_tokenizer`; the end-of-text and
    padding ids are both taken from the source tokenizer.

    Returns:
        The constructed `SentencePairsDataset`.
    """
    language_pairs = LanguagePairsDatasets(
        out_dir=Path('/Users/adam.amster/seq2seq_translation/datasets/wmt14_test'),
        source_lang='en',
        target_lang='fr',
        is_test=True,
    )

    # Select every example of the test split.
    all_idxs = np.arange(len(language_pairs))

    return SentencePairsDataset(
        datasets=language_pairs,
        idxs=all_idxs,
        source_tokenizer=source_tokenizer,
        target_tokenizer=target_tokenizer,
        max_length=None,
        eos_token_id=source_tokenizer.eot_idx,
        pad_token_id=source_tokenizer.pad_idx,
    )

In [6]:
# Build the test dataset once; `get_test_dset_examples` reads this module-level name.
test_dset = construct_test_dset()

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [7]:
from seq2seq_translation.config.transformer_config import TransformerConfig
import json
from seq2seq_translation.models.transformer.encoder_decoder import EncoderDecoderTransformer


def construct_model(
    config_path='/Users/adam.amster/Downloads/train_config-2.json',
    checkpoint_path='/Users/adam.amster/Downloads/ckpt.pt',
):
    """Build the encoder-decoder transformer and load trained weights on the CPU.

    The architecture hyper-parameters come from the training config JSON (validated
    through `TransformerConfig`). The vocabulary size and the special-token ids are
    taken from the module-level `source_tokenizer` only.

    Args:
        config_path: Path to the training config JSON. Defaults to the location
            that was previously hard-coded.
        checkpoint_path: Path to the checkpoint file; its `"model"` entry is passed
            through `_fix_model_state_dict` and loaded into the model. Defaults to
            the location that was previously hard-coded.

    Returns:
        The `EncoderDecoderTransformer` in eval mode.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(config_path, encoding='utf-8') as f:
        config = TransformerConfig.model_validate(json.load(f))

    model = EncoderDecoderTransformer(
        n_attention_heads=config.n_head,
        n_layers=config.num_layers,
        vocab_size=source_tokenizer.vocab_size,
        d_model=config.d_model,
        block_size=config.max_input_length,
        feedforward_hidden_dim=config.feedforward_hidden_dim,
        sos_token_id=source_tokenizer.processor.bos_id(),
        eos_token_id=source_tokenizer.processor.eos_id(),
        pad_token_id=source_tokenizer.processor.pad_id(),
        norm_first=config.norm_first,
        mlp_activation=config.activation,
        positional_encoding_type=config.positional_encoding_type,
    )

    # WARNING: torch.load unpickles the file, which can execute arbitrary code.
    # Only point `checkpoint_path` at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(_fix_model_state_dict(checkpoint["model"]))

    # Inference only: switch the model to eval mode.
    model.eval()
    return model

In [8]:
# Load the trained model once; `get_attn_weights` reads this module-level name.
encoder_decoder = construct_model()

In [9]:
def get_test_dset_examples():
    """Return the `test_dset` examples whose first element has length 9 to 19 inclusive.

    Iterates the module-level `test_dset` once and keeps the examples in dataset order.
    """
    kept = []
    for example in test_dset:
        # example[0] is the item whose length is filtered on (used as the source downstream).
        if 9 <= len(example[0]) <= 19:
            kept.append(example)
    return kept

In [10]:
# Materialise the length-filtered test examples once; the indices below refer to this list's order.
short_examples = get_test_dset_examples()
# Fixed positions into `short_examples` (hard-coded, not sampled, despite the name).
# NOTE(review): requires at least 593 filtered examples, otherwise indexing raises IndexError.
rand_idxs = [592, 533, 463, 187]

In [11]:
from enum import Enum


class AttnType(Enum):
    """Which kind of attention weights to extract and plot."""
    # Encoder self-attention: both plot axes are labelled with source tokens.
    ENCODER_SELF = 'ENCODER_SELF'
    # Decoder self-attention.
    # NOTE(review): `get_attn_weights` has no branch for this member, so it is not usable yet.
    DECODER_SELF = 'DECODER_SELF'
    # Decoder cross-attention: predicted tokens (y axis) against source tokens (x axis).
    DECODER_CROSS = 'DECODER_CROSS'

In [12]:
def get_attn_weights(attention_type: AttnType):
    """Run the model on the selected examples and collect attention weights.

    For every index in the module-level `rand_idxs`, the corresponding entry of
    `short_examples` is given a leading dimension of size 1 and passed through
    `encoder_decoder`. The source, reference target and prediction are printed
    as decoded text.

    Args:
        attention_type: `AttnType.ENCODER_SELF` (weights returned by the encoder) or
            `AttnType.DECODER_CROSS` (cross-attention weights returned by `generate`).

    Returns:
        A tuple `(sources, preds, attn_weights)` of three lists with one entry
        per index in `rand_idxs`.

    Raises:
        ValueError: If `attention_type` is not one of the two supported members.
            Previously `AttnType.DECODER_SELF` failed with an `UnboundLocalError`
            only after a full generation pass had already run.
    """
    supported = (AttnType.ENCODER_SELF, AttnType.DECODER_CROSS)
    if attention_type not in supported:
        raise ValueError(
            f'Unsupported attention type {attention_type!r}; '
            f'expected one of {[t.name for t in supported]}'
        )

    sources = []
    preds = []
    attn_weights = []
    for idx in rand_idxs:
        # Add a leading dimension of size 1 to each sequence.
        source = short_examples[idx][0].unsqueeze(0)
        target = short_examples[idx][1].unsqueeze(0)
        with torch.no_grad():
            if attention_type == AttnType.ENCODER_SELF:
                # NOTE(review): the mask is all True; presumably True marks real
                # (non-padding) tokens in this encoder — confirm the convention.
                attention_weights = encoder_decoder.encoder(
                    source,
                    src_key_padding_mask=torch.ones_like(source, dtype=torch.bool),
                    return_attention_weights=True,
                )
            else:
                _, _, attention_weights = encoder_decoder.generate(
                    src=source,
                    top_k=1,
                    return_cross_attention_weights=True,
                )
            # NOTE(review): for DECODER_CROSS this is a second generation pass; if the
            # first element returned above is the same prediction it could be reused
            # instead — confirm against `generate`.
            pred, _ = encoder_decoder.generate(
                src=source,
                top_k=1,
            )
            print(f'source: {source_tokenizer.decode(source[0])}')
            print(f'target: {target_tokenizer.decode(target[0])}')
            print(f'pred: {target_tokenizer.decode(pred)}')
        sources.append(source)
        preds.append(pred)
        attn_weights.append(attention_weights)
    return sources, preds, attn_weights

In [16]:
import math
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

def plot_attn_weights(attn_type: AttnType, out_dir=None):
    """Plot per-head attention heatmaps for the selected examples and export them.

    One figure is built per example returned by `get_attn_weights`: a grid of
    heatmaps with three columns (one subplot per attention head) and a slider that
    switches which layer's heatmaps are visible. Each figure is written both as
    JSON and as HTML to `{out_dir}/multi_head_{name}_weights_{idx}.json|.html`.

    The per-example weights are indexed as `weights[layer][0, head]`, and the
    number of heads is read from `weights[0].shape[1]`.

    Args:
        attn_type: `AttnType.ENCODER_SELF` (source tokens on both axes) or
            `AttnType.DECODER_CROSS` (predicted tokens on the y axis, source
            tokens on the x axis).
        out_dir: Directory the files are written to. Defaults to the blog assets
            directory that was previously hard-coded. The directory is not created.

    Raises:
        ValueError: If `attn_type` has no export name (currently
            `AttnType.DECODER_SELF`). Raised before any model inference runs;
            previously this surfaced as an unbound-variable error.
    """
    file_names = {
        AttnType.ENCODER_SELF: 'encoder_self_attention',
        AttnType.DECODER_CROSS: 'cross_attention',
    }
    if attn_type not in file_names:
        raise ValueError(
            f'Unsupported attention type {attn_type!r}; '
            f'expected one of {[t.name for t in file_names]}'
        )
    attention_name = file_names[attn_type]
    if out_dir is None:
        out_dir = '/Users/adam.amster/aamster.github.io/assets/plotly/2025-04-13-sequence_to_sequence_translation_2'

    sources, preds, attn_weights = get_attn_weights(attention_type=attn_type)

    n_cols = 3
    cell_px = 23  # pixels per heatmap cell; tweak until it looks right
    # '<s>' and '</s>' would be interpreted as HTML tags in tick labels, so escape them.
    escaped_pieces = {'<s>': '&lt;s&gt;', '</s>': ' &lt;/s&gt;'}

    for idx in range(len(sources)):
        n_layers = len(attn_weights[idx])
        n_heads = attn_weights[idx][0].shape[1]
        n_rows = math.ceil(n_heads / n_cols)

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols,
            horizontal_spacing=0.09,
            vertical_spacing=0.12,
            subplot_titles=[f'Head {h+1}' for h in range(n_heads)]
        )

        x_labels = [source_tokenizer.processor.id_to_piece(x.item()) for x in sources[idx][0]]
        if attn_type == AttnType.ENCODER_SELF:
            y_labels = list(x_labels)
        else:
            y_labels = [target_tokenizer.processor.id_to_piece(x.item()) for x in preds[idx][0]]
        x_labels = [escaped_pieces.get(piece, piece) for piece in x_labels]
        y_labels = [escaped_pieces.get(piece, piece) for piece in y_labels]

        # Traces are added layer-major / head-minor; the slider below relies on this order.
        for layer in range(n_layers):
            for head in range(n_heads):
                row_idx, col_idx = divmod(head, n_cols)
                fig.add_trace(
                    go.Heatmap(
                        z=attn_weights[idx][layer][0, head],
                        colorscale="Blues",
                        zmin=0,
                        zmax=1,
                        showscale=False,
                        visible=layer == 0,
                    ),
                    row=row_idx + 1,
                    col=col_idx + 1,
                )

        # Axis ticks are the same for every layer, so configure each subplot only once.
        x_indices = list(range(len(x_labels)))
        y_indices = list(range(len(y_labels)))
        for head in range(n_heads):
            row_idx, col_idx = divmod(head, n_cols)
            fig.update_yaxes(
                autorange='reversed',
                row=row_idx + 1,
                col=col_idx + 1,
                tickmode='array',
                tickvals=y_indices,
                ticktext=y_labels
            )
            fig.update_xaxes(
                tickangle=45,
                row=row_idx + 1,
                col=col_idx + 1,
                tickmode='array',
                tickvals=x_indices,
                ticktext=x_labels
            )

        # One slider step per layer: show that layer's traces and hide all others.
        steps = []
        for layer in range(n_layers):
            vis = [
                trace_layer == layer
                for trace_layer in range(n_layers)
                for _ in range(n_heads)
            ]
            steps.append(dict(label=f'Layer {layer+1}',
                              method='update',
                              args=[{'visible': vis}]))

        fig.update_layout(
            coloraxis=dict(
                colorscale="Blues"
            ),
            autosize=False,
            width=n_cols * len(x_labels) * cell_px + 200,  # +200 for margins
            height=n_rows * len(y_labels) * cell_px,
            plot_bgcolor='rgba(0,0,0,0)',  # transparent plot area
            paper_bgcolor='rgba(0,0,0,0)',  # transparent outer background
            font=dict(color='black'),  # keep tick labels readable on a transparent background
            margin=dict(t=20, l=40, r=20, b=60),
            sliders=[dict(active=0, steps=steps, pad=dict(t=60, b=10))],
        )

        out_stem = f'{out_dir}/multi_head_{attention_name}_weights_{idx}'
        fig.write_json(f'{out_stem}.json')
        fig.write_html(f'{out_stem}.html')
# Export the decoder cross-attention heatmaps for the selected examples.
plot_attn_weights(attn_type=AttnType.DECODER_CROSS)

source: But why such optimism for some and pessimism for others?
target: Les raisons d'un tel optimisme, chez les uns, et pessimisme, chez les autres ?
pred: Mais pourquoi un tel optimisme pour certains et le pessimisme pour d'autres?
source: Regulatory authority over phone calls belongs to the Federal Communications Commission, not the FAA.
target: Le pouvoir réglementaire concernant les téléphones portables appartient à la Federal Communications Commission et non à la FAA.
pred: Le pouvoir de réglementation des appels téléphoniques appartient à la Commission fédérale des communications, et non à la FAA.
source: They don't want us to dictate to them what makes them profitable.
target: Elles ne veulent pas qu'on leur dise ce qui leur permettra d'être rentables.
pred: Ils ne veulent pas que nous leur dictions ce qui les rend rentables.
source: The cinema was ventilated and everyone returned in good order.
target: La salle a été aérée et tout est rentré dans l'ordre.
pred: Le cinéma a ét

In [14]:
# Export the encoder self-attention heatmaps for the selected examples.
plot_attn_weights(attn_type=AttnType.ENCODER_SELF)

source: But why such optimism for some and pessimism for others?
target: Les raisons d'un tel optimisme, chez les uns, et pessimisme, chez les autres ?
pred: Mais pourquoi un tel optimisme pour certains et le pessimisme pour d'autres?
source: Regulatory authority over phone calls belongs to the Federal Communications Commission, not the FAA.
target: Le pouvoir réglementaire concernant les téléphones portables appartient à la Federal Communications Commission et non à la FAA.
pred: Le pouvoir de réglementation des appels téléphoniques appartient à la Commission fédérale des communications, et non à la FAA.
source: They don't want us to dictate to them what makes them profitable.
target: Elles ne veulent pas qu'on leur dise ce qui leur permettra d'être rentables.
pred: Ils ne veulent pas que nous leur dictions ce qui les rend rentables.
source: The cinema was ventilated and everyone returned in good order.
target: La salle a été aérée et tout est rentré dans l'ordre.
pred: Le cinéma a ét

In [19]:
# NOTE(review): identical to the cross-attention call made in the cell that defines
# `plot_attn_weights`; it overwrites the same output files — consider keeping only one.
plot_attn_weights(attn_type=AttnType.DECODER_CROSS)

KeyboardInterrupt: 