In [None]:
from typing import Annotated, Callable
import json
import os
import typing
from typing import Awaitable
import asyncio

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain import chat_models
from pydantic import BaseModel, Field, RootModel
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from random import Random

### Define the shape of the profile an analyzer should return

In [None]:
# Profile an analyzer returns: an identity score constrained to [0, 1] plus a
# free-text horoscope. (Documented with a comment rather than a class docstring
# so the model's JSON schema is unchanged.)
class Profile(BaseModel):
    identity: float = Field(ge=0, le=1)
    horoscope: str = Field()

    def cmp(self, other: "Profile") -> float:
        """Distance between two profiles: the absolute identity difference."""
        difference = other.identity - self.identity
        return abs(difference)


# Presumably a profile for a relationship/couple (judging by the name); it
# carries only a free-text horoscope.
# NOTE(review): not referenced in this part of the notebook — confirm it is used.
class RelationshipProfile(BaseModel):
    horoscope: str = Field()

### Setup: data models and the cached user data

In [None]:
# Profile as stored in the production `profiles` table (validated from its
# `profile` column in update_data). Same fields and distance as Profile.
class ProdProfile(BaseModel):
    identity: float = Field(ge=0, le=1)
    horoscope: str = Field()

    def cmp(self, other: "ProdProfile") -> float:
        """Return the absolute gap between the two identity scores."""
        gap = other.identity - self.identity
        return abs(gap)


# One answered survey question: the question's text and the user's answer text.
class QuestionResponse(BaseModel):
    question: str = Field()
    response: str = Field()


# A user's survey submission: their name plus their answers.
class Response(BaseModel):
    first_name: str = Field()
    last_name: str = Field()  # update_data stores "" when the profile has no last name
    # Keyed by question id (see update_data).
    responses: dict[str, QuestionResponse] = Field()


# Everything cached per user in data/user_data.json.
class User(BaseModel):
    response: Response = Field()
    # Required key but nullable: None for users without a production profile.
    prod_profile: ProdProfile | None = Field()


# Mapping str -> Response; presumably keyed by user_id like user_data — TODO confirm.
UserSet = Annotated[dict[str, Response], Field()]

# List of string pairs; presumably (user_id, user_id) of linked partners — TODO confirm.
RelationshipLink = Annotated[list[tuple[str, str]], Field()]


# An analyzer is an async callable turning one user's Response into a Profile.
Analyzer = Callable[[Response], Awaitable[Profile]]

# Set to True to re-download questions, answers and profiles from Supabase and
# rewrite data/user_data.json (requires secrets.json with EEVA_SUPABASE_SERVICE_KEY).
# Kept False so a normal "Run All" only reads the cached file.
REFRESH_USER_DATA = False

if REFRESH_USER_DATA:

    def update_data():
        """Download survey data from Supabase and cache it in data/user_data.json.

        Joins the ``questions``, ``user_answers`` and ``profiles`` tables into one
        ``User`` per non-hidden user who has at least one answer, and writes
        ``{user_id: User.model_dump()}`` as indented JSON.
        """
        from supabase.lib.client_options import SyncClientOptions
        import supabase

        from typing import Any

        SUPABASE_URL_BASE = "https://znsozdvrmfdwxyymtgdz.supabase.co/"
        # Supabase's REST API caps the rows returned per request (1000 by
        # default), so a plain select() can silently truncate large tables such
        # as user_answers. Every table is therefore read page by page.
        PAGE_SIZE = 1000

        with open("secrets.json", "r") as f:
            secrets = json.load(f)

        sb_client = supabase.create_client(
            SUPABASE_URL_BASE,
            secrets["EEVA_SUPABASE_SERVICE_KEY"],
            options=SyncClientOptions(auto_refresh_token=False, persist_session=False),
        )

        def select_all(table: str, columns: str) -> list[dict[str, Any]]:
            """Return every row of ``table`` by requesting consecutive ``.range`` pages.

            Stops at the first empty page, so it is also correct when the
            server's row cap is smaller than PAGE_SIZE.
            """
            # NOTE(review): offset paging without .order() assumes the table is
            # not modified during the refresh; add an order on a unique key if needed.
            rows: list[dict[str, Any]] = []
            while True:
                page = (
                    sb_client.table(table)
                    .select(columns)
                    .range(len(rows), len(rows) + PAGE_SIZE - 1)
                    .execute()
                    .data
                )
                if not page:
                    return rows
                rows.extend(page)

        questions = select_all("questions", "*")
        raw_answers = select_all("user_answers", "user_id, question_id, answer_text")

        # user_id -> {question_id: answer_text}
        user_answer_lists: dict[str, dict[str, str]] = {}
        for ans in raw_answers:
            user_answer_lists.setdefault(ans["user_id"], {})[ans["question_id"]] = ans[
                "answer_text"
            ]

        raw_user_data = select_all(
            "profiles", "user_id,first_name,last_name,hidden,profile"
        )
        users: dict[str, dict[str, Any]] = {}
        for user in raw_user_data:
            user_id = user["user_id"]
            # Skip hidden users and users who never answered anything.
            if user["hidden"] or user_id not in user_answer_lists:
                continue
            users[user_id] = {
                "first_name": user["first_name"],
                "last_name": user["last_name"],
                "profile": ProdProfile.model_validate(user["profile"])
                if user["profile"]
                else None,
                "answers": user_answer_lists[user_id],
            }

        user_data: dict[str, User] = {}
        # Build a lookup for question text
        question_text_lookup = {q["id"]: q["text"] for q in questions}
        for user_id, user in users.items():
            answers = user["answers"]
            responses = {}
            for question_id, answer_text in answers.items():
                question_text = question_text_lookup.get(question_id, "")
                responses[question_id] = {
                    "question": question_text,
                    "response": answer_text,
                }
            response = Response(
                first_name=user["first_name"],
                last_name=user["last_name"] if user["last_name"] else "",
                responses=responses,
            )
            user_data[user_id] = User(
                response=response,
                prod_profile=user["profile"],
            )

        # open(..., "w") fails if the directory does not exist yet.
        os.makedirs("data", exist_ok=True)
        with open("data/user_data.json", "w", encoding="utf-8") as f:
            json.dump(
                {k: v.model_dump() for k, v in user_data.items()},
                f,
                indent=2,
            )

    update_data()

# Root model used only to validate the whole cached file at once
# (user_id -> User, as written by update_data).
class UserSetDeserializer(RootModel[dict[str, User]]):
    pass


with open("data/user_data.json", "r", encoding="utf-8") as f:
    raw_user_json = f.read()

user_data = UserSetDeserializer.model_validate_json(raw_user_json).root

In [None]:
# Known real couples: label -> [user_id, user_id] (keys of user_data).
# Both ids must belong to users with a prod_profile, otherwise the partner
# lookup in the analysis cell below fails.
# NOTE(review): labels appear to be the partners' first names and the ids are
# real user ids — keep this notebook's distribution restricted.
couple_pairs = {
    "TaisMagdalena": [
        "c6151208-4b40-478e-a7e4-152dca25c4d6",
        "94a72be8-ef68-4fd2-bc89-6a71995c1425",
    ],
    "PaulineKasper": [
        "8a758e48-8fcc-4fdb-9aa6-0f306da872da",
        "425cb72c-abc6-4a85-86c6-3d4fbbee3f06",
    ],
    "EllenJakob": [
        "69071fe7-1bc7-4451-bd27-53fb2628a6b8",
        "6ea9924d-7803-4c9d-ab7d-f8f3b477f360",
    ],
    "LukasSarah": [
        "53243eb9-441b-4028-99ca-936dadc5fcaa",
        "12e74d54-7160-4cfc-9b68-281cfdb68b9f",
    ],
    "KarlDiogo": [
        "ffef1878-0c55-460c-96ae-5692b08ffafe",
        "9c7cdcee-e785-4b01-b87a-a798bfa891a8",
    ],
    "AskerBente(Saphir)": [
        "a87bb894-9a4f-4592-8d05-7e69a34b1240",
        "2d5ad42c-f204-4852-a4c3-a2fca910705d",
    ],
    "MikkelHelena": [
        "aa781030-f670-49d1-b0fd-3b1dc1e51389",
        "ab139189-c904-495c-a3ae-8698fbb9ab6a",
    ],
}

### Test analyzer: are real couples closer in production identity score than random pairs?

In [None]:
# Identity score of every user that has a production profile. `user_ids[i]` and
# `values[i]` describe the same user; `row_of` maps a user id back to i.
# (Kept as a float vector directly instead of a mixed id/float string array.)
profiled = [
    (user_id, user.prod_profile.identity)
    for user_id, user in user_data.items()
    if user.prod_profile
]
user_ids = [user_id for user_id, _ in profiled]
values = np.array([identity for _, identity in profiled], dtype=float)
row_of = {user_id: row for row, user_id in enumerate(user_ids)}

# Fail with a readable message (instead of an IndexError from an empty
# np.where result) if a couple member is unknown or has no prod_profile.
missing = sorted(
    {user_id for pair in couple_pairs.values() for user_id in pair} - set(row_of)
)
if missing:
    raise KeyError(f"Couple members without a prod_profile: {missing}")

# (n_couples, 2) integer row indices into `values`, in couple_pairs order.
couple_indices = np.array(
    [[row_of[id1], row_of[id2]] for id1, id2 in couple_pairs.values()], dtype=int
).reshape(-1, 2)
couple_values = values[couple_indices]
# Observed statistic: |Δ identity| per couple, and its mean over all couples.
couple_scores = np.abs(couple_values[:, 0] - couple_values[:, 1])
avg_couple_score = np.mean(couple_scores, axis=0)

In [None]:
def gen_random_couples(num_users: int, num_random_couples: int, rng):
    """Pick ``2 * num_random_couples`` distinct indices from ``range(num_users)``.

    ``rng`` is expected to offer ``random.Random.sample`` (sampling without
    replacement), so no user appears twice. Consecutive elements (0, 1),
    (2, 3), ... form the random couples. ``sample`` raises ``ValueError`` when
    there are fewer users than requested members.
    """
    couple_members = 2 * num_random_couples
    return rng.sample(range(num_users), couple_members)


def gen_random_couple_sets(num_users: int, num_random_couples: int, num_sets: int, rng):
    """Draw ``num_sets`` independent sets of random couples from the same ``rng``.

    Returns an array of shape ``(num_sets, num_random_couples, 2)`` whose last
    axis holds the two partners' user indices.
    """
    draws = []
    for _ in range(num_sets):
        draws.append(gen_random_couples(num_users, num_random_couples, rng))
    return np.reshape(np.array(draws), (num_sets, num_random_couples, 2))


# Null model: draw NUM_RANDOM_SETS sets of len(couple_pairs) random couples
# (users distinct within a set) from the profiled users, and score them exactly
# like the real couples.
RANDOM_SEED = 45
NUM_RANDOM_SETS = 100_000

rng = Random(RANDOM_SEED)
# (NUM_RANDOM_SETS, n_couples, 2) row indices into `values`.
random_couple_sets = gen_random_couple_sets(
    len(values), len(couple_pairs), NUM_RANDOM_SETS, rng
)
partner_a = values[random_couple_sets[:, :, 0]]
partner_b = values[random_couple_sets[:, :, 1]]
# |Δ identity| per random couple, then the mean within each set.
random_couple_scores = np.abs(partner_a - partner_b)
random_avgs = random_couple_scores.mean(axis=1)

In [None]:
# Pairwise |Δ identity| between all profiled users: entry [i, j] is
# |values[i] - values[j]|, so the matrix is symmetric with a zero diagonal.
scores_square = np.abs(np.subtract.outer(values, values))
scores_square

In [None]:
# Shape check: selecting rows with the (n_couples, 2) couple_indices gives
# (n_couples, 2, n_users), i.e. each partner's distance to every user.
scores_square[couple_indices].shape

In [None]:
# For each member of each couple: how many other users are at least as close in
# identity as their partner (ties count)? The `<=` test always matches the
# member themself (distance 0) and their partner, hence the "- 2".
partner_distance = couple_scores[:, None, None]  # (n_couples, 1, 1)
member_distances = scores_square[couple_indices]  # (n_couples, 2, n_users)
steps_to_partner = (member_distances <= partner_distance).sum(axis=2) - 2
avg_steps_to_partner = steps_to_partner.mean()
steps_to_partner

In [None]:
# The same "steps to partner" statistic for every random couple set.
# Evaluating scores_square[random_couple_sets] in one shot materializes a
# float64 array of shape (n_sets, n_couples, 2, n_users): with the 100_000 sets
# and 7 couples drawn above that is ~11 MB per user. Processing the sets in
# chunks bounds memory; the per-element comparisons, and so the results, are
# identical.
STEPS_CHUNK_SIZE = 1_000

random_steps_to_partner = np.empty(random_couple_sets.shape, dtype=int)
for chunk_start in range(0, len(random_couple_sets), STEPS_CHUNK_SIZE):
    chunk = slice(chunk_start, chunk_start + STEPS_CHUNK_SIZE)
    # (chunk, n_couples, 2, n_users): is each user at least as close as the partner?
    within_partner_distance = (
        scores_square[random_couple_sets[chunk]]
        <= random_couple_scores[chunk, :, None, None]
    )
    # "- 2" drops the member themself (distance 0) and their partner.
    random_steps_to_partner[chunk] = within_partner_distance.sum(axis=3) - 2
random_avg_steps_to_partner = np.mean(random_steps_to_partner, axis=(1, 2))
random_avg_steps_to_partner

In [None]:
# Number of users within `cutoff` identity of each user. The count includes the
# user itself, because the diagonal of scores_square is 0.
cutoff = 0.1
neighbours_within_cutoff = scores_square < cutoff
neighbours_within_cutoff.sum(axis=1)

In [None]:
def plot_null_distribution(
    random_stats: np.ndarray,
    observed: float,
    *,
    observed_label: str,
    xlabel: str,
    title: str,
    metric_name: str,
) -> float:
    """Histogram a statistic over random couple sets and mark the observed value.

    Args:
        random_stats: one value of the statistic per random couple set.
        observed: the statistic for the real couples.
        observed_label: legend text for the observed-value line.
        xlabel: x-axis label.
        title: figure title.
        metric_name: phrase used in the printed summary sentence.

    Returns:
        Percentage of random samples strictly below ``observed``.
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.hist(random_stats, bins=50, color="skyblue", edgecolor="black")
    ax.axvline(
        observed,
        color="red",
        linestyle="--",
        linewidth=2,
        label=observed_label,
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Frequency")
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    plt.show()

    below = (random_stats < observed).mean() * 100
    # Ties are common for discrete statistics (e.g. averages of integer step
    # counts). The tie-inclusive share is the empirical one-sided p-value (in %)
    # for "real couples are closer than random"; the strict percentile
    # understates it when many random samples equal the observed value.
    at_or_below = (random_stats <= observed).mean() * 100
    print(
        f"The actual couples' {metric_name} is at the {below:.2f} percentile of the random distribution."
    )
    print(f"{at_or_below:.2f}% of random samples are at or below the observed value.")
    return below


percentile = plot_null_distribution(
    random_avgs,
    avg_couple_score,
    observed_label=f"Observed avg = {avg_couple_score:.3f}",
    xlabel="Average |Δ identity|",
    title="Distribution of Avg |Δ identity| (random samples)",
    metric_name="average score",
)
percentile_steps = plot_null_distribution(
    random_avg_steps_to_partner,
    avg_steps_to_partner,
    observed_label=f"Observed avg steps = {avg_steps_to_partner:.2f}",
    xlabel="Average steps to partner",
    title="Distribution of average steps to partner (random samples)",
    metric_name="average steps to partner",
)