In [1]:
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
st.image("https://storage.googleapis.com/mle-courses-prod/users/61b6fa1ba83a7e37c8309756/private-files/678dadd0-603b-11ef-b0a7-998b84b38d43-ProtonX_logo_horizontally__1_.png", width=100)  # Replace "logo.png" with your logo file path


# Function to build the 2D rotation matrix
# def build_rotation_matrix(theta):
#     """Returns the 2D rotation matrix for angle theta."""
#     return np.array([[np.cos(theta), -np.sin(theta)],
#                      [np.sin(theta), np.cos(theta)]])

def build_rotation_matrix(m, d):
    assert d % 2 == 0, "Embedding dimension d must be even."
    theta = [10000**(-2 * k / d) for k in range(d // 2)]
    R = np.zeros((d, d))
    for k in range(d // 2):
        theta_k = m * theta[k]
        rot_matrix = np.array([[np.cos(theta_k), -np.sin(theta_k)],
                               [np.sin(theta_k), np.cos(theta_k)]])
        R[2*k:2*k+2, 2*k:2*k+2] = rot_matrix
    return R


# Function to generate LaTeX for rotation matrix
def generate_rotation_matrix_latex():
    """Generates LaTeX representation of the 2D rotation matrix with theta value."""
    latex_matrix = rf"""
    R(\theta) =
    \begin{{bmatrix}}
    \cos(m\theta) & -\sin(m\theta) \\
    \sin(m\theta) & \cos(m\theta)
    \end{{bmatrix}}
    """
    return latex_matrix


# Function to apply Rotary Position Encoding (ROPE)
# def apply_rope(embedding, position, base_theta=0.2):
#     """Applies ROPE to the embedding based on position."""
#     rotated_embedding = embedding.copy()
#     theta = position * base_theta
#     for i in range(0, len(embedding) - 1, 2):  # Process in pairs
#         R = build_rotation_matrix(theta)
#         rotated_embedding[i:i+2] = R @ embedding[i:i+2]
#     return rotated_embedding

def apply_rope(embedding, position, d):
    rotated_embedding = embedding.copy()
    R = build_rotation_matrix(position, d)
    rotated_embedding = R @ embedding
    return rotated_embedding

# Function to calculate the angle between vectors
def calculate_angle_between_vectors(v1, v2):
    """Calculates the angle between two vectors."""
    dot_product = np.dot(v1, v2)
    norm_v1 = np.linalg.norm(v1)
    norm_v2 = np.linalg.norm(v2)
    return np.arccos(dot_product / (norm_v1 * norm_v2))

# Function to visualize embeddings with selectable m
def visualize_embeddings(word, embedding, rotated_embedding, m):
    """Visualizes original and rotated embeddings with arrows and arcs that properly connect vectors."""
    fig, ax = plt.subplots(figsize=(6,6))

    first_orgi = np.array(embedding[:2])
    second_orgi = np.array(embedding[2:])
    k_0_1_rotated = np.array(rotated_embedding[:2])
    k_2_3_rotated = np.array(rotated_embedding[2:])

    # Function to compute angle and draw arc correctly connecting vectors
    def draw_arc(vector1, vector2, arc_radius=0.3, color='green'):
        """Draws an arc between two vectors correctly positioned at the origin."""
        # Compute angle
        dot_product = np.dot(vector1, vector2)
        norm_v1 = np.linalg.norm(vector1)
        norm_v2 = np.linalg.norm(vector2)
        angle_rad = np.arccos(dot_product / (norm_v1 * norm_v2))
        angle_deg = np.degrees(angle_rad)

        # Determine start and end angles for the arc
        start_angle = np.arctan2(vector1[1], vector1[0])
        end_angle = np.arctan2(vector2[1], vector2[0])

        if end_angle < start_angle:
            end_angle += 2 * np.pi

        arc_theta = np.linspace(start_angle, end_angle, 30)
        arc_x = arc_radius * np.cos(arc_theta)
        arc_y = arc_radius * np.sin(arc_theta)
        ax.plot(arc_x, arc_y, color=color, linestyle='-', linewidth=1.5)

        # Annotate angle
        mid_angle = (start_angle + end_angle) / 2
        mid_x = arc_radius * np.cos(mid_angle)
        mid_y = arc_radius * np.sin(mid_angle)
        ax.text(mid_x, mid_y, f'{angle_deg:.2f}°', fontsize=10, color=color, ha='center')

    # Plot original vectors
    ax.quiver(0, 0, first_orgi[0], first_orgi[1], angles='xy', scale_units='xy', scale=1, color='blue', label='Original Vector 0-1')
    ax.text(first_orgi[0], first_orgi[1], f'[{first_orgi[0]:.2f}, {first_orgi[1]:.2f}]', fontsize=7, ha='left')

    ax.quiver(0, 0, second_orgi[0], second_orgi[1], angles='xy', scale_units='xy', scale=1, color='blue', label='Original Vector 2-3')
    ax.text(second_orgi[0], second_orgi[1], f'[{second_orgi[0]:.2f}, {second_orgi[1]:.2f}]', fontsize=7, ha='left')

    # Plot rotated vectors
    ax.quiver(0, 0, k_0_1_rotated[0], k_0_1_rotated[1], angles='xy', scale_units='xy', scale=1, color='red', label='Rotated Vector 0-1')
    ax.text(k_0_1_rotated[0], k_0_1_rotated[1], f'[{k_0_1_rotated[0]:.2f}, {k_0_1_rotated[1]:.2f}]', fontsize=7, ha='left')

    ax.quiver(0, 0, k_2_3_rotated[0], k_2_3_rotated[1], angles='xy', scale_units='xy', scale=1, color='red', label='Rotated Vector 2-3')
    ax.text(k_2_3_rotated[0], k_2_3_rotated[1], f'[{k_2_3_rotated[0]:.2f}, {k_2_3_rotated[1]:.2f}]', fontsize=7, ha='left')

    # Draw arcs for both rotations, ensuring they connect vectors
    draw_arc(first_orgi, k_0_1_rotated, arc_radius=0.3, color='green')
    draw_arc(second_orgi, k_2_3_rotated, arc_radius=0.4, color='purple')

    # Set plot limits
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.axhline(0, color='black', linewidth=0.5)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.set_title(f"Vector Rotation Visualization for '{word}' (m={m})")
    ax.legend()
    ax.grid(True)
    ax.set_aspect('equal')
    
    ax.legend(fontsize=3)
    st.pyplot(fig)

# Streamlit UI
st.markdown("### Rotary position embedding visualization")


# Sample sentence and embeddings
sentence = ["the", "cat", "sat", "on", "the", "mat"]
embeddings = np.array([
    [0.1, 0.2, -0.3, -0.4],
    [0.5, 0.6, 0.7, 0.8],  # Reference row
    [-0.25,1.08,-0.52,0.91],
    [0.4, 0.3, -0.2, 0.8],
    [0.2, 0.5, -0.1, 0.3],
    [0.5, 0.6, 0.7, 0.8]
])

# Use session state to store selected word position
if "selected_position" not in st.session_state:
    st.session_state.selected_position = 0  # Default to first word

# Render buttons for each word
cols = st.columns(len(sentence))
for i, word in enumerate(sentence):
    if cols[i].button(word, key=f"word_button_{i}"):  # Unique key for each button
        st.session_state.selected_position = i  # Store selected position in session state

# Get the selected word and its embedding
selected_position = st.session_state.selected_position
selected_word = sentence[selected_position]
original_embedding = embeddings[selected_position]
d = original_embedding.shape[0]
rotated_embedding = apply_rope(original_embedding, selected_position, d)


# Compute theta for the selected position
# theta_value = selected_position * 1.2  # Base theta from the ROPE function

# Display LaTeX matrix in Streamlit
st.markdown("Rotation Matrix")

st.latex(generate_rotation_matrix_latex())

# Visualize embeddings
visualize_embeddings(selected_word, original_embedding, rotated_embedding, selected_position)


# Show embeddings before and after ROPE
col1, col2 = st.columns(2)
with col1:
    st.write("**Original Embedding**")
    st.write(original_embedding)
with col2:
    st.write("**After ROPE Rotation**")
    st.write(rotated_embedding)
st.markdown("[Paper](https://arxiv.org/pdf/2104.09864v5)")

2025-02-27 23:58:57.664 
  command:

    streamlit run C:\Users\Sese\AppData\Roaming\Python\Python312\site-packages\ipykernel_launcher.py [ARGUMENTS]
2025-02-27 23:58:57.667 Session state does not function when running a script without `streamlit run`


AttributeError: st.session_state has no attribute "selected_position". Did you forget to initialize it? More info: https://docs.streamlit.io/library/advanced-features/session-state#initialization

In [2]:
! streamlit run app.py

Usage: streamlit run [OPTIONS] TARGET [ARGS]...
Try 'streamlit run --help' for help.

Error: Invalid value: File does not exist: app.py
