In [1]:
import os

import pymongo
from dotenv import load_dotenv
from pymongo import MongoClient
from tqdm import tqdm

load_dotenv("../env")
client = MongoClient(os.environ.get("MONGO_DB_CONNECTION"))
db = client.get_database("prismai")

In [2]:
collection_transition_scores = db.get_collection("test_swt")

In [None]:
from simple_dataset import Dataset
from transition_scores.data import FeaturesDict, TransitionScores

dataset = Dataset(
    FeaturesDict.new(**document) for document in collection_transition_scores.find()
)
len(dataset)

In [None]:
dataset.apply(TransitionScores.merge, "transition_scores")
type(dataset[0]["transition_scores"])

In [5]:
grouped_dataset = dataset.map(
    lambda doc: {"doc_id": doc["document"]["_id"]} | doc,
    in_place=False,
).group_documents_by(
    "doc_id",
    remainder_into="documents",
    in_place=True,
)

In [None]:
[len(group["documents"]) for group in grouped_dataset]

In [3]:
from itertools import islice

import numpy as np
from matplotlib import pyplot as plt

In [6]:
def get_oracle_ts(
    seq_len=1024,
    vocab_size=128,
    top_k=10,
) -> TransitionScores:
    target_ids = np.random.choice(vocab_size, seq_len, replace=True).tolist()
    target_probs = (np.arange(seq_len) / seq_len).tolist()
    target_ranks = np.random.choice(vocab_size, seq_len, replace=True).tolist()
    top_k_ids = [
        np.random.choice(vocab_size, top_k, replace=False) for _ in range(seq_len)
    ]
    top_k_probs = [
        list(sorted((np.random.rand(top_k) / 10).tolist(), reverse=True))
        for _ in range(seq_len)
    ]
    return TransitionScores(
        target_ids,
        target_probs,
        target_ranks,
        top_k_ids,
        top_k_probs,
    )

In [259]:
def viz2d(features: np.array, vmin=None, vmax=None):
    fig = plt.imshow(features, cmap="gray", vmin=vmin, vmax=vmax)
    fig.axes.set_axis_off()
    return fig

In [8]:
def features_2d_likelihood(
    transition_scores: TransitionScores,
    w: int,
    h: int,
    skip: int = 1,
    no_overlap: bool = False,
    take_first: bool = False,
    sort_slices: bool = True,
):
    if no_overlap:
        raise NotImplementedError("no_overlap not implemented")

    size = len(transition_scores) - skip - w
    if size < w:
        raise ValueError(
            f"Sequence is too short: ({len(transition_scores)} - {skip} - {w}) < {w}"
        )

    if take_first:
        offsets = np.arange(skip, size, w)[:h]
        if not sort_slices:
            offsets = np.random.permutation(offsets)
    else:
        offsets = np.random.choice(size, h, replace=False) + skip
        if sort_slices:
            offsets = np.sort(offsets)

    features = np.zeros((h, w))
    for i, offset in enumerate(offsets):
        features[i] = transition_scores.target_probs[offset : offset + w]

    return features

In [None]:
features = features_2d_likelihood(dataset[3]["transition_scores"], 32, 16, take_first=True)
viz2d(features)
plt.show()

features = features_2d_likelihood(dataset[-1]["transition_scores"], 32, 16, take_first=True)
viz2d(features)
plt.show()

In [10]:
def viz3d(features: np.array, vmin=None, vmax=None):
    vmin = vmin or features.min()
    vmax = vmax or features.max()

    h, w, d = features.shape
    d = int(np.ceil(np.sqrt(d)))
    fig, axes = plt.subplots(d, d, figsize=(5, h / w * 5))
    for i, ax in enumerate(axes.flat):
        if i < features.shape[-1]:
            ax.imshow(features[..., i], cmap="gray", vmin=vmin, vmax=vmax)
        ax.set_axis_off()
    return fig

In [11]:
from warnings import warn


def features_3d_liklihood_topk_stack(
    transition_scores: TransitionScores,
    w: int,
    h: int,
    d: int,
    skip: int = 1,
    take_first: bool = False,
    no_overlap: bool = False,
    remove_target: bool = True,
    sort_slices: bool = True,
    log: bool = False,
):
    if no_overlap:
        raise NotImplementedError("no_overlap not implemented")

    size = len(transition_scores) - skip - w
    if size < w:
        raise ValueError(
            f"Sequence is too short: ({len(transition_scores)} - {skip} - {w}) < {w}"
        )
    if size < w * h:
        warn(
            f"Sequence very short, will produce overlapping offsets: ({len(transition_scores)} - {skip} - {w}) = {size} < {w * h}"
        )
    ts = transition_scores[skip:]

    if take_first:
        offsets = np.arange(0, size, w)[:h]
        if not sort_slices:
            offsets = np.random.permutation(offsets)
    else:
        offsets = np.random.choice(size, h, replace=False) + skip
        if sort_slices:
            offsets = np.sort(offsets)

    top_k_probs = [
        list(
            islice(
                (
                    prob
                    for top_id, prob in zip(top_ids, top_probs)
                    if not remove_target or top_id != tgt_id
                ),
                d - 1,
            )
        )
        for tgt_id, top_ids, top_probs in zip(
            ts.target_ids,
            ts.top_k_indices,
            ts.top_k_probs,
        )
    ]
    top_k_probs = np.array(top_k_probs)

    features = np.empty((h, w, d))
    for i, offset in enumerate(offsets):
        features[i, :, 0] = ts.target_probs[offset : offset + w]
        features[i, :, 1:] = top_k_probs[offset : offset + w]  # [:, ::-1]
    if log:
        features = np.log1p(features)
    return features


In [None]:
ts = get_oracle_ts()
features = features_3d_liklihood_topk_stack(ts, 16, 16, 4, take_first=True)
viz3d(features)
plt.show()

In [None]:
features = features_3d_liklihood_topk_stack(
    dataset[3]["transition_scores"], 32, 16, 9, take_first=True,
    log="log_ratio"
)
viz3d(features)
plt.show()

mn, mx = features.min(), features.max()

features = features_3d_liklihood_topk_stack(
    dataset[-1]["transition_scores"], 32, 16, 9, take_first=True,
    log="log_ratio"
)
viz3d(features, vmin=mn, vmax=mx)
plt.show()

In [17]:
from typing import Literal


def features_3d_likelihood_top_k_ratio(
    transition_scores: TransitionScores,
    w: int,
    h: int,
    d: int,
    skip: int = 1,
    take_first: bool = False,
    no_overlap: bool = False,
    sort_slices: bool = True,
    log: None | Literal["log_ratio", "ratio_of_logs"] = None,
):
    if no_overlap:
        raise NotImplementedError("no_overlap not implemented")

    size = len(transition_scores) - skip - w
    if size < w:
        raise ValueError(
            f"Sequence is too short: ({len(transition_scores)} - {skip} - {w}) < {w}"
        )
    if size < w * h:
        warn(
            f"Sequence very short, will produce overlapping offsets: ({len(transition_scores)} - {skip} - {w}) = {size} < {w * h}"
        )
    ts = transition_scores[skip:]

    if take_first:
        offsets = np.arange(0, size, w)[:h]
        if not sort_slices:
            offsets = np.random.permutation(offsets)
    else:
        offsets = np.random.choice(size, h, replace=False) + skip
        if sort_slices:
            offsets = np.sort(offsets)

    top_k_probs = [
        list(
            islice(
                (prob for top_id, prob in zip(top_ids, top_probs) if top_id != tgt_id),
                d,
            )
        )
        for tgt_id, top_ids, top_probs in zip(
            ts.target_ids,
            ts.top_k_indices,
            ts.top_k_probs,
        )
    ]
    top_k_probs = np.array(top_k_probs)
    target_probs = np.array(ts.target_probs).reshape(-1, 1)

    if log == "ratio_of_logs":
        top_k_probs = np.log(top_k_probs)
        target_probs = np.log(target_probs)

    features = np.empty((h, w, d))
    for i, offset in enumerate(offsets):
        features[i, :, :] = np.true_divide(
            target_probs[offset : offset + w],
            top_k_probs[offset : offset + w],
        )
        if log == "log_ratio":
            features[i, :, :] = np.log(features[i, :, :])

    return features


In [None]:
human = features_3d_likelihood_top_k_ratio(
    dataset[3]["transition_scores"],
    32,
    16,
    9,
    take_first=True,
    # log="log_ratio",
)
viz3d(human)
plt.show()

ai = features_3d_likelihood_top_k_ratio(
    dataset[-1]["transition_scores"],
    32,
    16,
    9,
    take_first=True,
    # log="log_ratio",
)
viz3d(ai, vmin=human.min(), vmax=human.max())
plt.show()

In [None]:
viz2d(human.reshape(-1, 9).T)
plt.show()

viz2d(ai.reshape(-1, 9).T, vmin=human.min(), vmax=human.max())
plt.show()

In [None]:
features = features_3d_likelihood_top_k_ratio(ts, 32, 16, 9, take_first=True)
viz3d(features)
plt.show()

In [None]:
features_3d_likelihood_top_k_ratio(
    dataset[3]["transition_scores"], 32, 16, 9, take_first=True
).sum()

In [None]:
features_3d_likelihood_top_k_ratio(
    dataset[-1]["transition_scores"], 32, 16, 9, take_first=True
).sum()

In [251]:
def features_2d_log_likelihood_log_rank_ratio(
    transition_scores: TransitionScores,
    w: int,
    h: int,
    skip: int = 1,
    take_first: bool = False,
    no_overlap: bool = False,
    sort_slices: bool = True,
    epsilon=1e-8,
):
    if no_overlap:
        raise NotImplementedError("no_overlap not implemented")

    size = len(transition_scores) - skip - w
    if size < w:
        raise ValueError(
            f"Sequence is too short: ({len(transition_scores)} - {skip} - {w}) < {w}"
        )

    if take_first:
        offsets = np.arange(skip, size, w)[:h]
        if not sort_slices:
            offsets = np.random.permutation(offsets)
    else:
        offsets = np.random.choice(size, h, replace=False) + skip
        if sort_slices:
            offsets = np.sort(offsets)

    features = np.zeros((h, w))
    target_probs = np.array(transition_scores.target_probs)
    target_ranks = np.array(transition_scores.target_ranks)
    for i, offset in enumerate(offsets):
        features[i] = -np.true_divide(
            # probs are generally small but never zero, so log(x) is safe
            np.log(target_probs[offset : offset + w]),
            # ranks, however, are 0-indexed, so we use log1p to avoid log(0)
            # and add epsilon to avoid division by zero
            np.log1p(target_ranks[offset : offset + w]) + epsilon,
        )

    return features


In [None]:
human = features_2d_log_likelihood_log_rank_ratio(
    dataset[2]["transition_scores"], 32, 16, take_first=False
)
viz2d(human)
plt.show()

ai = features_2d_log_likelihood_log_rank_ratio(
    dataset[-1]["transition_scores"], 32, 16, take_first=False
)
viz2d(ai, vmin=human.min(), vmax=human.max())
plt.show()

In [291]:
def viz_line(features: np.array, vmin=None, vmax=None):
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_axis_off()
    ax.imshow(
        features.flatten().reshape(-1, 1).repeat(32, 1).T,
        cmap="gray",
        vmin=vmin,
        vmax=vmax,
    )
    return fig

In [None]:
human = features_2d_log_likelihood_log_rank_ratio(
    dataset[2]["transition_scores"],
    32,
    16,
    take_first=False,
    sort_slices=False,
)
viz_line(human)
plt.show()

ai = features_2d_log_likelihood_log_rank_ratio(
    dataset[-1]["transition_scores"],
    32,
    16,
    take_first=False,
    sort_slices=False,
)
viz_line(ai, vmin=human.min(), vmax=human.max())
plt.show()

In [293]:
import numpy as np


def log_likelihood_log_rank_ratio(
    target_probs: np.array,
    target_ranks: np.array,
    epsilon: float = 1e-8,
) -> np.array:
    """Compute the log-likelihood log-rank ratio.

    Args:
        target_probs (np.array): Target token probabilities.
        target_ranks (np.array): Zero-indexed target token ranks.

    Returns:
        np.array: array of the same shape as target_probs and target_ranks.
    """
    return -np.true_divide(
        # probs are generally small but never zero, so log(x) is safe
        np.log(target_probs),
        # ranks, however, are 0-indexed, so we use log1p to avoid log(0)
        # and add epsilon to avoid division by zero
        np.log1p(target_ranks) + epsilon,
    )

In [None]:
ts = dataset[2]["transition_scores"]
l = len(ts.target_probs) - 256
o = np.random.randint(1, l)
o = 1
human = log_likelihood_log_rank_ratio(
    ts.target_probs[o : o + 256],
    ts.target_ranks[o : o + 256],
)
viz_line(human)
plt.show()

ts = dataset[-1]["transition_scores"]
l = len(ts.target_probs) - 256
o = np.random.randint(1, l)
o = 1
ai = log_likelihood_log_rank_ratio(
    ts.target_probs[o : o + 256],
    ts.target_ranks[o : o + 256],
)
viz_line(ai, vmin=human.min(), vmax=human.max())
plt.show()

In [1]:
import os

from dotenv import load_dotenv

from luminar.document.data import DocumentClassificationDataModule
from luminar.mongo import MongoDBAdapter

load_dotenv("../env")

False

In [10]:
db = MongoDBAdapter(
    os.environ.get("MONGO_DB_CONNECTION"),
    "prismai",
    "collected_items",
    "synthesized_texts",
    "transition_scores",
    source_collection_limit=100,
)


In [11]:
dm = DocumentClassificationDataModule(db)

In [12]:
dm.setup()

  np.log(target_probs),


In [13]:
dm.train_data

[(array([1.27855146e+00, 1.34798792e+00, 2.03493063e+00, 1.48797243e+08,
         1.72505205e+00, 1.40413764e+00, 1.56457079e+00, 2.16193620e+00,
         1.39077830e+00, 1.83790971e+00, 1.40629756e+00, 1.76936231e+00,
         1.67441166e+00, 1.55129705e+00, 1.37006306e+00, 2.13476492e+00,
         2.15672902e+00, 1.40797145e+00, 2.28319284e+00, 1.96670193e+00,
         1.45935617e+00, 1.42534585e+00, 1.40346861e+00, 2.19910007e+00,
         1.37280477e+00, 1.91138666e+00, 1.45291292e+00, 1.46542989e+00,
         1.44720809e+00, 2.30223292e+00, 2.78026063e+00, 1.65930851e+00,
         1.44645375e+00, 2.25769053e+00, 7.91587253e+07, 1.49720410e+08,
         1.39080238e+00, 1.48482985e+00, 1.51920978e+00, 2.17185240e+06,
         4.16759115e+07, 1.36122098e+00, 1.53777643e+00, 5.00140724e+07,
         1.54440428e+00, 3.12922876e+00, 1.57365769e+00, 1.47298389e+00,
         1.56912665e+00, 9.99889968e+07, 6.21553527e+07, 1.59169050e+00,
         1.44610444e+00, 1.48566796e+00, 1.54868019

In [14]:
next(iter(dm.train_dataloader()))

RuntimeError: stack expects each tensor to be equal size, but got [256] at entry 0 and [216] at entry 8