In [None]:
# 01_tfhub_eos_analysis.ipynb

# Goal:
# Analyze which tokens in the TFHub Sentence-T5 embedding space lie closest to the `</s>` token.
# Discover "soft `</s>`" tokens (e.g., `lucrarea`) and explain their effect on SCS
# (sharpened cosine similarity) scoring.

import tensorflow_hub as hub
import tensorflow_text as text  # Needed to register SentencePiece ops
import tensorflow as tf
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import sentencepiece as spm


In [None]:
# --- Encoder -------------------------------------------------------------
# Sentence-T5 wrapped as a Keras layer; "t5" is presumably a local copy of
# the TFHub module (swap in a tfhub.dev handle if needed).
model_url = "t5"
encoder = hub.KerasLayer(handle=model_url)

# --- Tokenizer -----------------------------------------------------------
# SentencePiece model read from inside the same directory (adjust path if needed).
sp = spm.SentencePieceProcessor()
sp.load("t5/spiece.model")

In [None]:

# Get embedding for a text input
def get_embed(text, encoder_fn=None):
    """Embed a single string with the Sentence-T5 encoder.

    Args:
        text: The input string to embed.
        encoder_fn: Optional callable to use instead of the notebook-level
            ``encoder`` (useful for testing or comparing models). When None,
            the global ``encoder`` is used, exactly as before.

    Returns:
        numpy array: the encoder output for ``[text]`` with all size-1 axes
        squeezed away (presumably shape ``(dim,)`` -- TODO confirm).
    """
    model = encoder if encoder_fn is None else encoder_fn
    output = model(tf.constant([text]))
    # Hub signatures may return a list/tuple of tensors or a dict of named
    # outputs; in every case the first entry is taken as the embedding.
    if isinstance(output, (list, tuple)):
        output = output[0]
    elif isinstance(output, dict):
        output = list(output.values())[0]
    return output.numpy().squeeze()


In [None]:
# Compute all token embeddings in batches
def get_all_token_embeddings(encoder, token_texts, batch_size=512):
    """Embed a sequence of strings with ``encoder`` in fixed-size batches.

    Args:
        encoder: Callable mapping a 1-D string tensor to embeddings; its output
            may be a tensor, or a list/tuple/dict whose first entry is the
            embedding tensor.
        token_texts: Sequence of strings to embed.
        batch_size: Number of strings per encoder call; must be positive.

    Returns:
        numpy array with one row per input string, in input order.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``token_texts`` is
            empty (previously these surfaced as opaque range/vstack errors).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(token_texts) == 0:
        raise ValueError("token_texts is empty; nothing to embed")

    all_embeddings = []
    for i in tqdm(range(0, len(token_texts), batch_size)):
        batch = tf.constant(token_texts[i:i + batch_size])
        output = encoder(batch)
        # Unwrap multi-output signatures: keep the first tensor.
        if isinstance(output, (list, tuple)):
            output = output[0]
        elif isinstance(output, dict):
            output = list(output.values())[0]
        all_embeddings.append(output.numpy())
    return np.vstack(all_embeddings)

# cosine_similarity (re-imported here; also imported in the first cell) backs the
# sharpened-cosine / raw-dot-product scoring in the next cell
from sklearn.metrics.pairwise import cosine_similarity



In [None]:
def compute_similarities(batch_embeds, target_embed, p=3):
    """Score every row of ``batch_embeds`` against one target embedding.

    Args:
        batch_embeds: 2-D array, one embedding per row.
        target_embed: The target embedding; any shape holding as many values
            as one row of ``batch_embeds`` (it is flattened, so both ``(d,)``
            and ``(1, d)`` work).
        p: Sharpening exponent applied to the cosine similarity.

    Returns:
        (sharp_cos, raw_dot): two 1-D arrays with one entry per row.
        ``sharp_cos`` is ``sign(cos) * |cos| ** p``, which equals ``cos ** p``
        for odd integer ``p`` (e.g. the default 3) but keeps the ordering
        meaningful for even or fractional ``p``. ``raw_dot`` is the
        unnormalised dot product.
    """
    target = np.asarray(target_embed).ravel()

    # Cosine similarity
    cos_sim = cosine_similarity(batch_embeds, target.reshape(1, -1)).ravel()
    # Sign-preserving sharpening: a plain cos_sim ** p would turn strongly
    # negative similarities into large positives for even p (and NaN for
    # fractional p), corrupting any ranking built on sharp_cos.
    sharp_cos = np.sign(cos_sim) * np.abs(cos_sim) ** p

    # Raw dot product
    raw_dot = np.dot(batch_embeds, target)
    return sharp_cos, raw_dot

In [None]:
# Turn every vocabulary id (0 .. size-1) back into its SentencePiece piece
# string, then embed each piece on its own.
# NOTE(review): pieces are fed verbatim, presumably including the "▁"
# word-boundary marker and control symbols -- confirm that is intended.
token_texts = list(map(sp.id_to_piece, range(sp.get_piece_size())))
all_token_embeddings = get_all_token_embeddings(
    encoder,
    token_texts,
    batch_size=512,
)



In [None]:
# Approximate the `</s>` embedding: embed each of its characters separately,
# average those vectors, then rescale the mean to unit L2 norm.
parts = list("</s>")
eos_fake = np.mean([get_embed(piece) for piece in parts], axis=0)
eos_fake = eos_fake / np.linalg.norm(eos_fake)


In [None]:
# Score the whole vocabulary against the approximated `</s>` direction
# (sharpened cosine with exponent 3, plus raw dot product).
cosine_sim, dot_sim = compute_similarities(
    batch_embeds=all_token_embeddings,
    target_embed=eos_fake,
    p=3,
)

In [None]:
# Keep the 100 highest-scoring tokens, best first.
top_k = 100
top_indices = np.flip(np.argsort(cosine_sim))[:top_k]

In [None]:
# `cosine_sim` holds the *sharpened* cosine returned by compute_similarities
# (called with p=3), not the plain cosine, so label the column accordingly.
print("Top tokens closest to </s> (sharpened cosine & dot product):")
for idx in top_indices:
    print(f"{token_texts[idx]:<15} | sharp cos: {cosine_sim[idx]:.4f} | dot: {dot_sim[idx]:.4f}")