ran on training set
czaloom committed Jun 25, 2024
1 parent ae80f4b commit 93c450a
Showing 4 changed files with 333 additions and 172 deletions.
119 changes: 119 additions & 0 deletions api/valor_api/backend/metrics/embedding.py
@@ -0,0 +1,119 @@
from sqlalchemy import alias, select, true
from sqlalchemy.orm import Session

from valor_api import enums, schemas
from valor_api.backend import core, models
from valor_api.backend.metrics.metric_utils import (
    create_metric_mappings,
    get_or_create_row,
    log_evaluation_duration,
    log_evaluation_item_counts,
    prepare_filter_for_evaluation,
    validate_computation,
)


def _compute_embedding_metrics(
    db: Session,
    prediction_filter: schemas.Filter,
    groundtruth_filter: schemas.Filter,
    metrics_to_return: list[enums.MetricType] | None = None,
) -> list:
    """Compare reference and query embedding distributions via pairwise cosine distances."""
    # TODO: apply `groundtruth_filter` and `prediction_filter`; the empty
    # `.where()` clauses below are placeholders.
    queries = select(models.Embedding).where().subquery()
    ref = select(models.Embedding).where().subquery()
    ref_alias = alias(ref)

    # distances between all distinct pairs within the reference set
    ref_dist = db.scalars(
        select(ref.c.value.cosine_distance(ref_alias.c.value))
        .select_from(ref)
        .join(ref_alias, ref_alias.c.id != ref.c.id)
    )
    # a left join on TRUE pairs every reference embedding with every query embedding
    query_dist = db.scalars(
        select(ref.c.value.cosine_distance(queries.c.value))
        .select_from(ref)
        .join(queries, true(), isouter=True)
    )

    # TODO: feed `ref_dist` and `query_dist` into the two-sample tests selected
    # by `metrics_to_return` and return the resulting metric objects.
    return []


@validate_computation
def compute_embedding_metrics(
    *,
    db: Session,
    evaluation_id: int,
) -> int:
    """
    Create embedding metrics. This function is intended to be run using FastAPI's `BackgroundTasks`.

    Parameters
    ----------
    db : Session
        The database Session to query against.
    evaluation_id : int
        The job ID to create metrics for.

    Returns
    ----------
    int
        The evaluation job id.
    """

    # fetch evaluation
    evaluation = core.fetch_evaluation_from_id(db, evaluation_id)

    # unpack filters and params
    parameters = schemas.EvaluationParameters(**evaluation.parameters)
    groundtruth_filter, prediction_filter = prepare_filter_for_evaluation(
        db=db,
        filters=schemas.Filter(**evaluation.filters),
        dataset_names=evaluation.dataset_names,
        model_name=evaluation.model_name,
        task_type=parameters.task_type,
        label_map=parameters.label_map,
    )

    log_evaluation_item_counts(
        db=db,
        evaluation=evaluation,
        prediction_filter=prediction_filter,
        groundtruth_filter=groundtruth_filter,
    )

    if parameters.metrics_to_return is None:
        raise RuntimeError("Metrics to return should always be defined here.")

    metrics = _compute_embedding_metrics(
        db=db,
        prediction_filter=prediction_filter,
        groundtruth_filter=groundtruth_filter,
        metrics_to_return=parameters.metrics_to_return,
    )

    metric_mappings = create_metric_mappings(
        db=db,
        metrics=metrics,
        evaluation_id=evaluation.id,
    )

    for mapping in metric_mappings:
        # ignore the value column since the other columns are unique
        # identifiers, and we have empirically noticed that the value can
        # change slightly due to floating-point error
        get_or_create_row(
            db,
            models.Metric,
            mapping,
            columns_to_ignore=["value"],
        )

    log_evaluation_duration(
        evaluation=evaluation,
        db=db,
    )

    return evaluation_id
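For context: `cosine_distance` above is pgvector's SQLAlchemy comparator (it compiles to the `<=>` operator), and the two metric types registered below are the Cramér-von Mises and Kolmogorov-Smirnov two-sample tests. Below is a minimal sketch of how the two pairwise-distance samples could feed those tests, assuming `scipy` is available; the helper name and wiring are illustrative, not part of this commit.

from scipy import stats

def _two_sample_tests(ref_dist: list[float], query_dist: list[float]) -> dict:
    # Illustrative only: compare the reference-vs-reference distance
    # distribution against the reference-vs-query distance distribution.
    # Kolmogorov-Smirnov: maximum gap between the two empirical CDFs.
    ks = stats.ks_2samp(ref_dist, query_dist)
    # Cramer-von Mises: integrated squared difference between the CDFs.
    cvm = stats.cramervonmises_2samp(ref_dist, query_dist)
    return {
        "KolmogorovSmirnov": {"statistic": float(ks.statistic), "pvalue": float(ks.pvalue)},
        "CramerVonMises": {"statistic": float(cvm.statistic), "pvalue": float(cvm.pvalue)},
    }

A large statistic (equivalently, a small p-value) indicates that the query embeddings are distributed differently from the reference set, i.e. drift.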
2 changes: 2 additions & 0 deletions api/valor_api/enums.py
@@ -132,3 +132,5 @@ class MetricType(str, Enum):
    mIOU = "mIOU"
    PrecisionRecallCurve = "PrecisionRecallCurve"
    DetailedPrecisionRecallCurve = "DetailedPrecisionRecallCurve"
    CramerVonMises = "CramerVonMises"
    KolmogorovSmirnov = "KolmogorovSmirnov"
42 changes: 41 additions & 1 deletion api/valor_api/schemas/metrics.py
@@ -1,6 +1,7 @@
import numpy as np
from pydantic import BaseModel, ConfigDict, field_validator

from valor_api.enums import MetricType
from valor_api.schemas.types import Label


@@ -20,7 +21,7 @@ class Metric(BaseModel):
        The `Label` for the metric.
    """

    type: MetricType
    parameters: dict | None = None
    value: float | dict | None = None
    label: Label | None = None
@@ -745,3 +746,42 @@ def db_mapping(self, evaluation_id: int) -> dict:
"evaluation_id": evaluation_id,
"parameters": {"label_key": self.label_key},
}


class _EmbeddingMetric(BaseModel):
    statistics: dict[str, dict[str, float]]
    pvalues: dict[str, dict[str, float]]

    def db_mapping(self, evaluation_id: int, type_name: str) -> dict:
        """
        Creates a mapping for use when uploading the metric to the database.

        Parameters
        ----------
        evaluation_id : int
            The evaluation id.
        type_name : str
            The metric type name to store under.

        Returns
        ----------
        dict
            A mapping dictionary.
        """
        return {
            "value": {
                "statistics": self.statistics,
                "pvalues": self.pvalues,
            },
            "type": type_name,
            "evaluation_id": evaluation_id,
        }


class CramerVonMisesMetric(_EmbeddingMetric):
    def db_mapping(self, evaluation_id: int) -> dict:
        return super().db_mapping(evaluation_id, MetricType.CramerVonMises)


class KolmogorovSmirnovMetric(_EmbeddingMetric):
    def db_mapping(self, evaluation_id: int) -> dict:
        return super().db_mapping(evaluation_id, MetricType.KolmogorovSmirnov)
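As a quick illustration of the mapping these classes produce (the field values and keys here are made up for illustration, not taken from the commit):

metric = CramerVonMisesMetric(
    statistics={"embedding": {"train_vs_query": 0.42}},
    pvalues={"embedding": {"train_vs_query": 0.07}},
)
mapping = metric.db_mapping(evaluation_id=1)
# {
#     "value": {
#         "statistics": {"embedding": {"train_vs_query": 0.42}},
#         "pvalues": {"embedding": {"train_vs_query": 0.07}},
#     },
#     "type": MetricType.CramerVonMises,
#     "evaluation_id": 1,
# }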