In [None]:
import numpy as np
from scipy.linalg import orthogonal_procrustes

import numpy as np
from scipy.linalg import orthogonal_procrustes

def _center_and_normalize(matrix):
    """Mean-center the rows of ``matrix`` and scale each row to unit length."""
    centered = matrix - matrix.mean(axis=0, keepdims=True)
    norms = np.linalg.norm(centered, axis=1, keepdims=True)
    # A row equal to the mean becomes all zeros; dividing by 1 keeps it at
    # zero instead of producing NaNs that would break the SVD.
    norms[norms == 0] = 1.0
    return centered / norms


def compute_orthogonal_transform(base_model, other_model):
    """Rotate ``other_model``'s embeddings into ``base_model``'s space.

    An orthogonal Procrustes problem is solved on the vocabulary shared by
    both models (after mean-centering and unit-normalizing those rows), and
    the resulting rotation is applied to every vector of ``other_model``.

    Args:
        base_model: Target model (Democratic in this analysis). Must expose
            ``.wv.index_to_key`` and ``.wv[word]``.
        other_model: Source model to align (Republican). Must expose the
            same ``.wv`` interface plus ``.wv.vectors``.

    Returns:
        A deep copy of ``other_model.wv`` whose ``vectors`` have been
        multiplied by the rotation. ``other_model`` itself is not modified.

    Raises:
        ValueError: If the two models have no word in common.
    """
    # Function-scope import so the block does not depend on the file header.
    import copy

    base_kv = base_model.wv
    other_kv = other_model.wv

    # Walk the base vocabulary in its own order so the row order (and thus
    # floating-point rounding) is reproducible, unlike iterating a set.
    other_vocab = set(other_kv.index_to_key)
    common_vocab = [word for word in base_kv.index_to_key if word in other_vocab]

    print(f"Common vocabulary: {len(common_vocab)} words")
    if not common_vocab:
        raise ValueError("base_model and other_model share no vocabulary")

    # Work in float64 for the SVD; the result is cast back afterwards.
    A = np.asarray([base_kv[word] for word in common_vocab], dtype=np.float64)
    B = np.asarray([other_kv[word] for word in common_vocab], dtype=np.float64)

    A = _center_and_normalize(A)  # base (Dem) vectors
    B = _center_and_normalize(B)  # other (Rep) vectors

    # orthogonal_procrustes(X, Y) returns R minimizing ||X @ R - Y||.
    # The rotation is applied to the *other* model, so the source must be B
    # and the target A:  B @ R ≈ A.
    R, _ = orthogonal_procrustes(B, A)

    # deepcopy does not require the keyed-vectors class to provide copy().
    aligned_model = copy.deepcopy(other_kv)
    source_vectors = other_kv.vectors
    aligned_model.vectors = source_vectors.dot(R).astype(
        source_vectors.dtype, copy=False
    )

    return aligned_model


def get_neighbor_union(base_model, aligned_other_model, target_word, k):
    """Collect the top-``k`` neighbors of a word in two embedding spaces.

    Args:
        base_model: Model queried through ``base_model.wv.most_similar``.
        aligned_other_model: Keyed vectors queried directly through
            ``most_similar``.
        target_word: Word to look up; passed unchanged to both queries.
        k: Number of neighbors requested from each side (``topn``).

    Returns:
        A dict with the original ``target``, the two raw neighbor lists of
        ``(word, score)`` pairs, and ``union``: the sorted set of all
        neighbor words plus the lower-cased target.
    """
    from_base = base_model.wv.most_similar(target_word, topn=k)
    from_other = aligned_other_model.most_similar(target_word, topn=k)

    # Seed with the lower-cased target, then fold in both neighbor lists.
    vocabulary = {target_word.lower()}
    for hits in (from_base, from_other):
        vocabulary |= {word for word, _ in hits}

    summary = {
        "target": target_word,
        "base_neighbors": from_base,
        "other_neighbors": from_other,
        "union": sorted(vocabulary),
    }
    return summary

# Analysis parameters: the query word and how many nearest neighbors to
# request from each model.
TARGET_WORD = "trump"
TOP_K = 100
# dem_model and aligned_rep_kv are not defined in this section; they must
# already exist in the session.
# NOTE(review): aligned_rep_kv is presumably the result of
# compute_orthogonal_transform(dem_model, <rep model>) — confirm.
neighbor_info = get_neighbor_union(dem_model, aligned_rep_kv, TARGET_WORD, TOP_K)
# Bare tuple expression: first ten entries of the sorted union and its total
# size. It has no effect in a plain script; presumably this is the last line
# of a notebook cell and is meant to be echoed as the cell's output.
neighbor_info["union"][:10], len(neighbor_info["union"])
# ...existing code...
