
Commit d3fe6ff

Merge pull request #84 from Toloka/test_fixes

Fix tests
dustalov committed Sep 29, 2023
2 parents de3e56a + ebc1374 commit d3fe6ff
Showing 2 changed files with 26 additions and 19 deletions.
crowdkit/aggregation/classification/dawid_skene.py (5 changes: 4 additions & 1 deletion)
@@ -127,7 +127,10 @@ def _e_step(data: pd.DataFrame, priors: pd.Series, errors: pd.DataFrame) -> pd.D
         # row is equal to 0. This trick ensures proper scaling after exponentiating and
         # does not affect the result of E-step
         scaled_likelihoods = np.exp2(log_likelihoods.sub(log_likelihoods.max(axis=1), axis=0))
-        return scaled_likelihoods.div(scaled_likelihoods.sum(axis=1), axis=0)
+        scaled_likelihoods = scaled_likelihoods.div(scaled_likelihoods.sum(axis=1), axis=0)
+        # Convert column types to the label type
+        scaled_likelihoods.columns = pd.Index(scaled_likelihoods.columns, name='label', dtype=data.label.dtype)
+        return scaled_likelihoods
 
     def _evidence_lower_bound(self, data: pd.DataFrame, probas: pd.DataFrame, priors: pd.Series, errors: pd.DataFrame) -> float:
         # calculate joint probability log-likelihood expectation over probas
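The deleted line normalized and returned in one step; the new code keeps the intermediate result so the column index can be cast back to the label dtype before returning. The max-subtraction in the surrounding code is the usual log-space stability trick: shifting each row of log-likelihoods by its row maximum keeps exp2 within floating-point range, and the shift cancels during the row-wise normalization. A minimal standalone sketch of that trick (the values and label names are illustrative, not taken from the library):

import numpy as np
import pandas as pd

# Illustrative log2-likelihoods for three tasks over two labels; the first
# row is so small that naive exponentiation underflows to zero.
log_likelihoods = pd.DataFrame(
    [[-2000.0, -2001.0], [-3.0, -1.0], [-0.5, -0.7]],
    columns=['cat', 'dog'],
)

# Naive version: exp2(-2000) underflows, so row 0 becomes 0/0 = NaN.
naive = np.exp2(log_likelihoods)
naive = naive.div(naive.sum(axis=1), axis=0)

# Stable version: subtract each row's maximum first; the shift cancels out
# after row-wise normalization, so finite rows agree with the naive version.
shifted = np.exp2(log_likelihoods.sub(log_likelihoods.max(axis=1), axis=0))
stable = shifted.div(shifted.sum(axis=1), axis=0)
print(stable)  # row 0 is now the well-defined [2/3, 1/3]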
crowdkit/datasets/_loaders.py (40 changes: 22 additions & 18 deletions)
@@ -2,6 +2,7 @@
 from typing import Optional, Tuple, Dict, Callable, Union
 
 import pandas as pd
+import numpy as np
 
 from ._base import get_data_dir, fetch_remote
 
@@ -17,6 +18,21 @@ def _load_dataset(data_name: str, data_dir: Optional[str], data_url: str, checks
     return full_data_path
 
 
+def _load_ms_coco_dataframes(data_path: str) -> Tuple[pd.DataFrame, pd.Series]:
+    labels = np.load(join(data_path, 'crowd_labels.npz'))
+    rows = []
+    for key in labels.files:
+        task, worker = key.split('\t')
+        rows.append([int(task), worker, labels[key]])
+    labels = pd.DataFrame(rows, columns=['task', 'worker', 'segmentation'])
+
+    true_labels = np.load(join(data_path, 'gt.npz'))
+    true_labels = {int(key): true_labels[key] for key in true_labels.files}
+    true_labels = pd.Series(true_labels, name='true_segmentation', index=pd.Index(true_labels.keys(), name='task'))
+
+    return labels, true_labels
+
+
 def load_relevance2(data_dir: Optional[str] = None) -> Tuple[pd.DataFrame, pd.Series]:
     data_name = 'relevance-2'
     data_url = 'https://tlk.s3.yandex.net/dataset/crowd-kit/relevance-2.zip'
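The new helper encodes an assumed archive layout: keys in crowd_labels.npz are a task id and a worker id joined by a tab, each entry holding that worker's segmentation mask, while gt.npz is keyed by task id alone. A round-trip sketch under that assumption, calling the private helper on synthetic files just to show the resulting shapes:

import numpy as np
from crowdkit.datasets._loaders import _load_ms_coco_dataframes

# Synthetic stand-ins for the expected archives: crowd_labels.npz keys look
# like '<task>\t<worker>'; gt.npz keys are plain '<task>' ids.
mask = np.zeros((4, 4), dtype=bool)
np.savez('crowd_labels.npz', **{'0\tw123': mask, '0\tw456': mask})
np.savez('gt.npz', **{'0': mask})

labels, true_labels = _load_ms_coco_dataframes('.')
print(labels.columns.tolist())  # ['task', 'worker', 'segmentation']
print(true_labels.index.name)   # 'task'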
@@ -51,34 +67,22 @@ def load_dataframes(data_path: str) -> Tuple[pd.DataFrame, pd.Series]:
 
 def load_mscoco(data_dir: Optional[str] = None) -> Tuple[pd.DataFrame, pd.Series]:
     data_name = 'mscoco'
-    data_url = 'https://tlk.s3.yandex.net/dataset/crowd-kit/mscoco.zip'
-    checksum_url = 'https://tlk.s3.yandex.net/dataset/crowd-kit/mscoco.md5'
-
-    def load_dataframes(data_path: str) -> Tuple[pd.DataFrame, pd.Series]:
-        labels = pd.read_pickle(join(data_path, 'crowd_labels.zip')).rename(columns={'performer': 'worker'})
-        true_labels = pd.read_pickle(join(data_path, 'gt.zip')).set_index('task')['true_segmentation']
-
-        return labels, true_labels
+    data_url = 'https://huggingface.co/datasets/toloka/crowdkit-datasets/resolve/af2c00549cc026eaea80c18c54a686d98a58fd6e/mscoco.zip'
+    checksum_url = 'https://huggingface.co/datasets/toloka/crowdkit-datasets/resolve/79d5468d12d233153c0fdcee0dd61b98980ff7a4/mscoco.md5'
 
     full_data_path = _load_dataset(data_name, data_dir, data_url, checksum_url)
 
-    return load_dataframes(full_data_path)
+    return _load_ms_coco_dataframes(full_data_path)
 
 
 def load_mscoco_small(data_dir: Optional[str] = None) -> Tuple[pd.DataFrame, pd.Series]:
     data_name = 'mscoco_small'
-    data_url = 'https://tlk.s3.yandex.net/dataset/crowd-kit/mscoco_small.zip'
-    checksum_url = 'https://tlk.s3.yandex.net/dataset/crowd-kit/mscoco_small.md5'
-
-    def load_dataframes(data_path: str) -> Tuple[pd.DataFrame, pd.Series]:
-        labels = pd.read_pickle(join(data_path, 'crowd_labels.zip')).rename(columns={'performer': 'worker'})
-        true_labels = pd.read_pickle(join(data_path, 'gt.zip'))
-
-        return labels, true_labels
+    data_url = 'https://huggingface.co/datasets/toloka/crowdkit-datasets/resolve/0e0cac7f51869d4b20d83842c578ca3d013af7b7/mscoco_small.zip'
+    checksum_url = 'https://huggingface.co/datasets/toloka/crowdkit-datasets/resolve/bb48658b78db95845ff2a8d3db3e533a493ab819/mscoco_small.md5'
 
     full_data_path = _load_dataset(data_name, data_dir, data_url, checksum_url)
 
-    return load_dataframes(full_data_path)
+    return _load_ms_coco_dataframes(full_data_path)
 
 
 def load_crowdspeech_dataframes(data_path: str) -> Tuple[pd.DataFrame, pd.Series]:
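Note that the replacement URLs resolve against pinned repository revisions on Hugging Face, so the archives and their .md5 companions cannot drift apart. With both loaders now delegating to the shared helper, load_mscoco and load_mscoco_small return identically shaped data: a labels DataFrame with task, worker, and segmentation columns, and a true_segmentation Series indexed by task. A usage sketch; the download-then-cache behavior is inferred from _load_dataset and fetch_remote, not shown in this diff:

from crowdkit.datasets._loaders import load_mscoco_small

# The first call presumably fetches the pinned archive into the local data
# directory (the paired .md5 URL suggests checksum verification); later
# calls would reuse the cached copy.
labels, true_labels = load_mscoco_small()
print(labels[['task', 'worker']].head())
print(true_labels.name)  # 'true_segmentation'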
