In [None]:
from typing import Annotated, Callable
import json
import os
import typing
from typing import Awaitable
import asyncio
from pathlib import Path
import aiofiles

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain.chat_models.base import BaseChatModel
from langchain import chat_models
from pydantic import BaseModel, Field, RootModel
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from random import Random

### Define the shape of the profile an analyzer should return

In [None]:
class Profile(BaseModel):
    """Analyzer result: a single ``identity`` score constrained to [0, 1]."""

    identity: float = Field(ge=0, le=1)

    def cmp(self, other: "Profile") -> float:
        """Return the absolute gap between this profile's identity and ``other``'s."""
        difference = self.identity - other.identity
        return abs(difference)

### Setup

In [None]:
class ProdProfile(BaseModel):
    """Profile as stored in the production ``profiles`` table.

    Carries the same bounded ``identity`` score as ``Profile`` plus a required
    ``horoscope`` string.
    """

    identity: float = Field(ge=0, le=1)
    # No default is supplied, so the field is required.
    horoscope: str

    def cmp(self, other: "ProdProfile") -> float:
        """Return the absolute gap between the two identity scores."""
        difference = self.identity - other.identity
        return abs(difference)


class QuestionResponse(BaseModel):
    """One questionnaire item: the question text and the answer given to it."""

    # Neither field has a default, so both are required.
    question: str
    response: str


class Response(BaseModel):
    """A person's name together with their questionnaire answers.

    ``responses`` maps a string key to a ``QuestionResponse``; the data builder
    in this notebook uses the question id as that key.
    """

    # All three fields are required (no defaults).
    first_name: str
    last_name: str
    responses: dict[str, QuestionResponse]


class User(BaseModel):
    """A user's questionnaire ``Response`` plus their production profile, if any."""

    response: Response
    # Required but nullable: the key must be present, its value may be None.
    # (Pydantic v2 does not give Optional annotations an implicit default.)
    prod_profile: ProdProfile | None


# Mapping from a string key to a questionnaire Response.
# NOTE(review): keys are presumably user ids. The data read from
# user_data.json below is validated as dict[str, User], not as this alias --
# confirm whether UserSet is still needed.
UserSet = Annotated[dict[str, Response], Field()]

# A couple is a pair of two strings (presumably user ids -- confirm).
Couple = Annotated[tuple[str, str], Field()]
# Named couples: string key -> Couple.
CouplePairs = Annotated[dict[str, Couple], Field()]


# An analyzer is an async callable that turns one Response into a Profile.
Analyzer = Callable[[Response], Awaitable[Profile]]

# Directory (relative to the working directory) holding user_data.json,
# couples.json and the identity.txt prompt read by the analyzer.
DATA_DIR = Path("data")

# Manual switch: change to True to re-download the data from Supabase and
# rewrite DATA_DIR / "user_data.json". While it is False, nothing in this
# block runs and the notebook works from the file already on disk.
if False:

    def update_data() -> None:
        """Rebuild ``DATA_DIR / "user_data.json"`` from the Supabase tables.

        Reads the ``questions``, ``user_answers`` and ``profiles`` tables, keeps
        every profile that is not hidden and has at least one answer, and writes
        one ``User`` record per remaining user id as indented JSON.

        Requires ``secrets.json`` in the working directory with the key
        ``EEVA_SUPABASE_SERVICE_KEY``.
        """
        # Imported locally so supabase is only needed when refreshing the data.
        from supabase.lib.client_options import SyncClientOptions
        import supabase

        from typing import Any

        SUPABASE_URL_BASE = "https://znsozdvrmfdwxyymtgdz.supabase.co/"

        with open("secrets.json", "r") as f:
            secrets = json.load(f)

        sb_client = supabase.create_client(
            SUPABASE_URL_BASE,
            secrets["EEVA_SUPABASE_SERVICE_KEY"],
            options=SyncClientOptions(auto_refresh_token=False, persist_session=False),
        )

        # NOTE(review): each table is fetched with a single execute(); if the
        # API caps the number of rows per request, large tables could be
        # truncated silently -- confirm against the Supabase client docs.
        questions = sb_client.table("questions").select("*").execute().data
        raw_answers = (
            sb_client.table("user_answers")
            .select("user_id, question_id, answer_text")
            .execute()
            .data
        )
        # Group the flat answer rows as user_id -> {question_id: answer_text}.
        user_answer_lists: dict[str, dict[str, str]] = {}
        for ans in raw_answers:
            user_answer_lists.setdefault(ans["user_id"], {})[ans["question_id"]] = ans[
                "answer_text"
            ]

        raw_user_data = (
            sb_client.table("profiles")
            .select("user_id,first_name,last_name,hidden,profile")
            .execute()
            .data
        )
        users: dict[str, dict[str, Any]] = {}
        for user in raw_user_data:
            user_id = user["user_id"]
            # Skip hidden profiles and users without any recorded answer.
            if user["hidden"] or user_id not in user_answer_lists:
                continue
            users[user_id] = {
                "first_name": user["first_name"],
                "last_name": user["last_name"],
                # A falsy stored profile becomes None; anything else must
                # validate as a ProdProfile.
                "profile": ProdProfile.model_validate(user["profile"])
                if user["profile"]
                else None,
                "answers": user_answer_lists[user_id],
            }

        user_data: dict[str, User] = {}
        # Build a lookup for question text
        question_text_lookup = {q["id"]: q["text"] for q in questions}
        for user_id, user in users.items():
            answers = user["answers"]
            responses = {}
            for question_id, answer_text in answers.items():
                # Unknown question ids get an empty question text.
                question_text = question_text_lookup.get(question_id, "")
                responses[question_id] = {
                    "question": question_text,
                    "response": answer_text,
                }
            response = Response(
                first_name=user["first_name"],
                # Response.last_name is a str, so a falsy last name is stored as "".
                last_name=user["last_name"] if user["last_name"] else "",
                responses=responses,
            )
            user_data[user_id] = User(
                response=response,
                prod_profile=user["profile"],
            )

        # model_dump() converts each User into plain dicts that json can write.
        with open(DATA_DIR / "user_data.json", "w", encoding="utf-8") as f:
            json.dump(
                {k: v.model_dump() for k, v in user_data.items()},
                f,
                indent=2,
            )

    update_data()

# Read the API keys from the local secrets file, then export them as
# environment variables for the OpenAI and Anthropic clients.
with open("secrets.json", "r") as f:
    secrets = json.loads(f.read())

os.environ["OPENAI_API_KEY"] = secrets["OPENAI_API_KEY"]
os.environ["ANTHROPIC_API_KEY"] = secrets["ANTHROPIC_API_KEY"]

# Chat model handed to the analyzer. Exactly one line should be active;
# the commented lines are the alternatives that can be swapped in.
# llm = chat_models.init_chat_model("gpt-4.1-nano", model_provider="openai")
# llm = chat_models.init_chat_model("gpt-4o-mini", model_provider="openai")
# llm = chat_models.init_chat_model("gpt-4o", model_provider="openai")
# llm = chat_models.init_chat_model("gpt-5", model_provider="openai")
# llm = chat_models.init_chat_model("gpt-5-mini", model_provider="openai")
llm = chat_models.init_chat_model("gpt-5-nano", model_provider="openai")
# llm = chat_models.init_chat_model("claude-3-5-haiku-latest", model_provider="anthropic")
# llm = chat_models.init_chat_model("claude-sonnet-4-20250514", model_provider="anthropic")

class UserSetDeserializer(RootModel[dict[str, User]]):
    """Root model used to validate the whole user_data.json mapping at once."""


# Parse the stored users into a plain dict of string key -> User.
with open(DATA_DIR / "user_data.json", "r", encoding="utf-8") as f:
    user_data = UserSetDeserializer.model_validate_json(f.read()).root

# couples.json maps a couple key to a list of ids; only the first two entries
# of each list are kept, as a (first, second) tuple.
with open(DATA_DIR / "couples.json", "r", encoding="utf-8") as f:
    couple_pairs_raw: dict[str, list[str]] = json.load(f)

couple_pairs: dict[str, Couple] = {
    couple_key: (members[0], members[1])
    for couple_key, members in couple_pairs_raw.items()
}

## Analyzer

### Define analyzer

In [None]:
async def analyze(response: Response, llm: BaseChatModel, data_path: Path) -> Profile:
    """Ask ``llm`` for an identity score for one user's answers.

    The text of ``data_path / "identity.txt"`` is used as the description of
    the ``identity`` field in the structured-output schema, and the user's
    answers are sent as one "question: answer" line per item.

    Args:
        response: The questionnaire answers to analyze.
        llm: Chat model supporting ``with_structured_output``.
        data_path: Directory containing ``identity.txt``.

    Returns:
        A ``Profile`` holding the identity score returned by the model.

    Raises:
        ValueError: If the model output is neither a dict nor an
            ``AnalyzerOutput`` instance.
    """
    async with aiofiles.open(
        data_path / "identity.txt", mode="r", encoding="utf-8"
    ) as f:
        identity_prompt = await f.read()

    class AnalyzerOutput(BaseModel):
        # NOTE(review): the single-space docstring is kept exactly as it was;
        # it presumably affects the schema description sent to the model --
        # confirm before changing it.
        """ """

        identity: float = Field(ge=0, le=1, description=identity_prompt)

    structured_llm = llm.with_structured_output(AnalyzerOutput)

    # The keys of ``responses`` are question ids in the data builder, so label
    # each answer with the stored question text; fall back to the key only when
    # no text was recorded for that question.
    content = "\n".join(
        f"{question_response.question or question_key}: {question_response.response}"
        for question_key, question_response in response.responses.items()
    )

    raw_output = await structured_llm.ainvoke(
        [
            SystemMessage(
                content="Please analyze the identity of this set of answers."
            ),
            HumanMessage(content=content),
        ]
    )
    # The structured output is handled in both of its possible forms: an
    # already-built model instance or a plain dict.
    if isinstance(raw_output, AnalyzerOutput):
        output = raw_output
    elif isinstance(raw_output, dict):
        output = AnalyzerOutput(**raw_output)
    else:
        raise ValueError(
            f"Unexpected output type: {type(raw_output)}. Expected dict or AnalyzerOutput."
        )

    return Profile(identity=output.identity)

### Calculate profiles

In [None]:
async def generate_user_profiles(
    user_id: str, user: User, num_profiles: int
) -> tuple[str, list[Profile]]:
    """Run the analyzer ``num_profiles`` times concurrently for one user.

    Uses the notebook-level ``llm`` and ``DATA_DIR``. Returns the user id
    together with the list of profiles, one per analyzer run.
    """
    pending = [
        asyncio.create_task(analyze(user.response, llm, DATA_DIR))
        for _ in range(num_profiles)
    ]
    generated = await asyncio.gather(*pending)
    return user_id, generated


async def generate_profiles(
    user_data: dict[str, User],
    num_profiles: int,
    user_subset: set[str] | None = None,
) -> dict[str, list[Profile]]:
    """Build ``user_id -> list[Profile]`` by analyzing all users concurrently.

    Args:
        user_data: Users to analyze, keyed by user id.
        num_profiles: Number of analyzer runs (profiles) per user.
        user_subset: If given, only the user ids in this set are analyzed;
            ``None`` (the default) analyzes every user in ``user_data``.

    Returns:
        A dict with one entry per analyzed user, each holding
        ``num_profiles`` profiles.
    """
    if user_subset is not None:
        user_data = {k: v for k, v in user_data.items() if k in user_subset}

    tasks = [
        asyncio.create_task(generate_user_profiles(user_id, user, num_profiles))
        for user_id, user in user_data.items()
    ]
    # gather() returns results in the order the tasks were passed in, and each
    # result already carries its own user id.
    results = await asyncio.gather(*tasks)
    return {user_id: user_profiles for user_id, user_profiles in results}

In [None]:
NUM_TESTS = 5

profiles = await generate_profiles(user_data, NUM_TESTS, None)
profile_list = [(id, profile) for id, profile in profiles.items()]

In [None]:
# Print, for each couple, the identity scores of both partners across all runs.
for key, (first_id, second_id) in couple_pairs.items():
    scores_by_partner = [
        [profile.identity for profile in profiles[partner_id]]
        for partner_id in (first_id, second_id)
    ]
    print(f"{key}: {scores_by_partner[0]} - {scores_by_partner[1]}")

### Test analyzer

In [None]:
# One row per analyzer run, one column per user (same order as profile_list).
values = np.array(
    [
        [profile.identity for profile in user_profiles]
        for _user_id, user_profiles in profile_list
    ]
).T

# Position of every user id in profile_list, so each couple is resolved with
# a dict lookup instead of a scan (and without rebinding the builtin `id`).
user_index = {
    user_id: idx for idx, (user_id, _user_profiles) in enumerate(profile_list)
}

# Couples with a partner missing from profile_list are skipped.
couple_indices_list = [
    [user_index[id1], user_index[id2]]
    for id1, id2 in couple_pairs.values()
    if id1 in user_index and id2 in user_index
]

# reshape(-1, 2) keeps a usable (0, 2) integer array when no couple is found.
couple_indices = np.array(couple_indices_list, dtype=int).reshape(-1, 2)
print(f"Values shape: {values.shape}")
print(f"Couple indices shape: {couple_indices.shape}")

In [None]:
# For every run, pick the two partners' scores of each couple (last axis = partner).
couple_values = values[:, couple_indices]
print(f"Couple values shape: {couple_values.shape}")
# Absolute identity gap between the partners, per run and per couple.
couple_scores = np.abs(couple_values[..., 0] - couple_values[..., 1])
print(f"Couple scores shape: {couple_scores.shape}")
# Mean gap over the couples, one value per run.
avg_couple_score = couple_scores.mean(axis=1)
print(f"Avg couple score shape: {avg_couple_score.shape}")

In [None]:
def gen_random_couples(num_users: int, num_random_couples: int, rng):
    """Draw ``2 * num_random_couples`` distinct user indices from ``range(num_users)``.

    The result is a flat list produced by ``rng.sample``; callers pair up
    consecutive entries to form couples.
    """
    sample_size = 2 * num_random_couples
    population = range(num_users)
    return rng.sample(population, sample_size)


def gen_random_couple_sets(num_users: int, num_random_couples: int, num_sets: int, rng):
    """Generate ``num_sets`` independent sets of random couples.

    Returns an array of shape ``(num_sets, num_random_couples, 2)`` whose last
    axis holds the two user indices of one couple.
    """
    flat_sets = []
    for _ in range(num_sets):
        flat_sets.append(gen_random_couples(num_users, num_random_couples, rng))
    target_shape = (num_sets, num_random_couples, 2)
    return np.array(flat_sets).reshape(target_shape)


# Fixed seed so the random baseline is reproducible.
rng = Random(45)
# Size each random set like the observed one: use the couples that were
# actually resolved into couple_indices, so both averages cover the same
# number of couples even if some couples were skipped.
random_couple_sets = gen_random_couple_sets(
    values.shape[1], len(couple_indices), 100_000, rng
)
print(random_couple_sets.shape)
# Absolute identity gap of every random couple, per run and per random set.
random_couple_scores = np.abs(
    values[:, random_couple_sets[:, :, 0]] - values[:, random_couple_sets[:, :, 1]]
)
print(random_couple_scores.shape)
# Mean gap over the couples of each set: one value per run and per set.
random_avgs = random_couple_scores.mean(axis=2)
print(random_avgs.shape)

In [None]:
# Pairwise gaps per run: scores_square[t, i, j] = |values[t, i] - values[t, j]|.
scores_square = np.abs(np.expand_dims(values, 2) - np.expand_dims(values, 1))
scores_square

In [None]:
# For each member of a couple, count the users whose identity is at least as
# close as the partner's (ties count because of the <= comparison).
# The count includes the member (distance 0) and the partner, hence the -2.
print(scores_square[:, couple_indices].shape)
at_least_as_close = (
    scores_square[:, couple_indices] <= couple_scores[:, :, None, None]
)
steps_to_partner = at_least_as_close.sum(axis=3) - 2
print(steps_to_partner.shape)
# Average over couples and over both partners: one value per run.
avg_steps_to_partner = steps_to_partner.mean(axis=(1, 2))
print(avg_steps_to_partner.shape)

In [None]:
# Same "steps to partner" count as for the real couples, applied to every
# random couple set. The -2 removes the member and the partner from the count.
# NOTE(review): scores_square[:, random_couple_sets] materializes one row of
# pairwise gaps per run, set, couple and partner; with many sets this
# intermediate can be very large -- consider batching over sets if memory is
# tight.
random_steps_to_partner = (
    np.sum(
        scores_square[:, random_couple_sets]
        <= random_couple_scores[:, :, :, None, None],
        axis=4,
    )
    - 2
)
# Average over couples and partners: one value per run and per random set.
random_avg_steps_to_partner = np.mean(random_steps_to_partner, axis=(2, 3))
print(random_avg_steps_to_partner.shape)

In [None]:
# One column per analyzer run: the top row compares the observed average
# identity gap with the random baseline, the bottom row does the same for the
# average steps to partner.
# squeeze=False keeps axs two-dimensional even when NUM_TESTS == 1.
fig, axs = plt.subplots(2, NUM_TESTS, figsize=(20, 9), squeeze=False)
for i in range(NUM_TESTS):
    ax = axs[0, i]
    ax.hist(random_avgs[i], bins=50, color="skyblue", edgecolor="black")
    ax.axvline(
        avg_couple_score[i],
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Observed avg = {avg_couple_score[i]:.3f}",
    )
    ax.set_xlabel("Average |Δ identity|")
    ax.set_ylabel("Frequency")
    ax.set_title("Distribution of Avg |Δ identity| (random samples)")
    ax.legend()

    ax = axs[1, i]
    ax.hist(
        random_avg_steps_to_partner[i], bins=50, color="lightgreen", edgecolor="black"
    )
    ax.axvline(
        avg_steps_to_partner[i],
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Observed avg = {avg_steps_to_partner[i]:.3f}",
    )
    # The black line marks the mean of the random baseline, so it gets its own
    # label instead of repeating "Observed avg".
    random_mean_steps = np.mean(random_avg_steps_to_partner[i])
    ax.axvline(
        random_mean_steps,
        color="black",
        linestyle="--",
        linewidth=2,
        label=f"Random mean = {random_mean_steps:.3f}",
    )
    ax.set_xlabel("Average steps to partner")
    ax.set_ylabel("Frequency")
    ax.set_title("Distribution of Avg steps to partner (random samples)")
    ax.legend()
fig.tight_layout()
plt.show()
