
Commit

Merge pull request #34 from PetrochukM/dataset
Add set operation to dataset
PetrochukM committed May 6, 2018
2 parents 9925127 + 6f71f1d commit 1766cc3
Showing 3 changed files with 107 additions and 36 deletions.
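In short, this PR lets a Dataset be updated in place by column name, row index, or row slice. A minimal usage sketch distilled from the tests and implementation below (the import path is the one used by the test file):

    from torchnlp.datasets import Dataset

    dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
    dataset['a'] = ['aa', 'aaa']              # replace a whole column
    dataset[0] = {'c': 'c'}                   # replace a single row
    dataset[0:2] = [{'d': 'd'}, {'d': 'dd'}]  # replace a slice of rows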
34 changes: 34 additions & 0 deletions tests/datasets/test_dataset.py
@@ -1,4 +1,5 @@
import pytest
import random

from torchnlp.datasets import Dataset

@@ -24,6 +25,25 @@ def test_dataset_get_column():
dataset['c']


def test_dataset_set_column():
dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])

# Regular column update
dataset['a'] = ['aa', 'aaa']
assert dataset['a'] == ['aa', 'aaa']

# Too few
dataset['b'] = ['b']
assert dataset['b'] == ['b', None]

# Too many
dataset['c'] = ['c', 'cc', 'ccc']
assert dataset['c'] == ['c', 'cc', 'ccc']

# Smoke test (regression): shuffle exercises __len__, __getitem__ and __setitem__
random.shuffle(dataset)


def test_dataset_get_row():
dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
assert dataset[0] == {'a': 'a', 'b': 'b'}
@@ -32,6 +52,20 @@ def test_dataset_get_row():
dataset[2]


def test_dataset_set_row():
dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
dataset[0] = {'c': 'c'}
assert dataset['c'] == ['c', None]
assert dataset['a'] == [None, 'aa']

dataset[0:2] = [{'d': 'd'}, {'d': 'dd'}]
assert dataset[0] == {'d': 'd'}
assert dataset[1] == {'d': 'dd'}

with pytest.raises(IndexError):
dataset[2] = {'c': 'c'}


def test_dataset_equality():
dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
other_dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
43 changes: 41 additions & 2 deletions torchnlp/datasets/dataset.py
@@ -39,13 +39,52 @@ def __getitem__(self, key):
# Given a column string, return a list of column values.
if isinstance(key, str):
if key not in self.columns:
- raise AttributeError
+ raise AttributeError('Key not in columns.')
return [row[key] if key in row else None for row in self.rows]
# Given a row integer or slice, return the corresponding row(s).
elif isinstance(key, (int, slice)):
return self.rows[key]
else:
raise TypeError("Invalid argument type.")
raise TypeError('Invalid argument type.')

def __setitem__(self, key, item):
"""
Set a column, or one or more rows, of the dataset.
Args:
key (str, int, or slice): String referencing a column, or an integer or slice referencing row(s).
item (list or dict): Column values (list) for a column, or the row(s) to set (dict, or list of dicts for a slice).
"""
if isinstance(key, str):
column = item
self.columns.add(key)
if len(column) > len(self.rows):
for i, value in enumerate(column):
if i < len(self.rows):
self.rows[i][key] = value
else:
self.rows.append({key: value})
else:
for i, row in enumerate(self.rows):
if i < len(column):
self.rows[i][key] = column[i]
else:
self.rows[i][key] = None
elif isinstance(key, slice):
rows = item
for row in rows:
if not isinstance(row, dict):
raise ValueError('Row must be a dict.')
self.columns.update(row.keys())
self.rows[key] = rows
elif isinstance(key, int):
row = item
if not isinstance(row, dict):
raise ValueError('Row must be a dict.')
self.columns.update(row.keys())
self.rows[key] = row
else:
raise TypeError('Invalid argument type.')

def __len__(self):
return len(self.rows)
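
The column branch above gives assignment forgiving length semantics: short columns pad trailing rows with None, long columns append new rows. A short sketch of the resulting behavior (derived from the code as written, not part of the diff):

    dataset = Dataset([{'a': 'a'}, {'a': 'aa'}])

    # Shorter column than rows: remaining rows are padded with None.
    dataset['b'] = ['b']
    assert dataset['b'] == ['b', None]

    # Longer column than rows: extra values append rows holding only that key.
    dataset['c'] = ['c', 'cc', 'ccc']
    assert len(dataset) == 3

    # Slice assignment replaces rows and registers any new column names.
    dataset[0:2] = [{'d': 'd'}, {'d': 'dd'}]
    assert dataset[1] == {'d': 'dd'}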
66 changes: 32 additions & 34 deletions torchnlp/word_to_vector/bpemb.py
@@ -1,32 +1,26 @@
from torchnlp.word_to_vector.pretrained_word_vectors import _PretrainedWordVectors


# List of all 275 supported languages from http://cosyne.h-its.org/bpemb/data/
SUPPORTED_LANGUAGES = [
- 'ab', 'ace', 'ady', 'af', 'ak', 'als', 'am', 'an', 'ang', 'ar', 'arc',
- 'arz', 'as', 'ast', 'atj', 'av', 'ay', 'az', 'azb', 'ba', 'bar', 'bcl',
- 'be', 'bg', 'bi', 'bjn', 'bm', 'bn', 'bo', 'bpy', 'br', 'bs', 'bug', 'bxr',
- 'ca', 'cdo', 'ce', 'ceb', 'ch', 'chr', 'chy', 'ckb', 'co', 'cr', 'crh',
- 'cs', 'csb', 'cu', 'cv', 'cy', 'da', 'de', 'din', 'diq', 'dsb', 'dty', 'dv',
- 'dz', 'ee', 'el', 'en', 'eo', 'es', 'et', 'eu', 'ext', 'fa', 'ff', 'fi',
- 'fj', 'fo', 'fr', 'frp', 'frr', 'fur', 'fy', 'ga', 'gag', 'gan', 'gd', 'gl',
- 'glk', 'gn', 'gom', 'got', 'gu', 'gv', 'ha', 'hak', 'haw', 'he', 'hi',
- 'hif', 'hr', 'hsb', 'ht', 'hu', 'hy', 'ia', 'id', 'ie', 'ig', 'ik', 'ilo',
- 'io', 'is', 'it', 'iu', 'ja', 'jam', 'jbo', 'jv', 'ka', 'kaa', 'kab', 'kbd',
- 'kbp', 'kg', 'ki', 'kk', 'kl', 'km', 'kn', 'ko', 'koi', 'krc', 'ks', 'ksh',
- 'ku', 'kv', 'kw', 'ky', 'la', 'lad', 'lb', 'lbe', 'lez', 'lg', 'li', 'lij',
- 'lmo', 'ln', 'lo', 'lrc', 'lt', 'ltg', 'lv', 'mai', 'mdf', 'mg', 'mh',
- 'mhr', 'mi', 'min', 'mk', 'ml', 'mn', 'mr', 'mrj', 'ms', 'mt', 'mwl', 'my',
- 'myv', 'mzn', 'na', 'nap', 'nds', 'ne', 'new', 'ng', 'nl', 'nn', 'no',
- 'nov', 'nrm', 'nso', 'nv', 'ny', 'oc', 'olo', 'om', 'or', 'os', 'pa', 'pag',
- 'pam', 'pap', 'pcd', 'pdc', 'pfl', 'pi', 'pih', 'pl', 'pms', 'pnb', 'pnt',
- 'ps', 'pt', 'qu', 'rm', 'rmy', 'rn', 'ro', 'ru', 'rue', 'rw', 'sa', 'sah',
- 'sc', 'scn', 'sco', 'sd', 'se', 'sg', 'sh', 'si', 'sk', 'sl', 'sm', 'sn',
- 'so', 'sq', 'sr', 'srn', 'ss', 'st', 'stq', 'su', 'sv', 'sw', 'szl', 'ta',
- 'tcy', 'te', 'tet', 'tg', 'th', 'ti', 'tk', 'tl', 'tn', 'to', 'tpi', 'tr',
- 'ts', 'tt', 'tum', 'tw', 'ty', 'tyv', 'udm', 'ug', 'uk', 'ur', 'uz', 've',
- 'vec', 'vep', 'vi', 'vls', 'vo', 'wa', 'war', 'wo', 'wuu', 'xal', 'xh',
- 'xmf', 'yi', 'yo', 'za', 'zea', 'zh', 'zu'
+ 'ab', 'ace', 'ady', 'af', 'ak', 'als', 'am', 'an', 'ang', 'ar', 'arc', 'arz', 'as', 'ast',
+ 'atj', 'av', 'ay', 'az', 'azb', 'ba', 'bar', 'bcl', 'be', 'bg', 'bi', 'bjn', 'bm', 'bn', 'bo',
+ 'bpy', 'br', 'bs', 'bug', 'bxr', 'ca', 'cdo', 'ce', 'ceb', 'ch', 'chr', 'chy', 'ckb', 'co',
+ 'cr', 'crh', 'cs', 'csb', 'cu', 'cv', 'cy', 'da', 'de', 'din', 'diq', 'dsb', 'dty', 'dv', 'dz',
+ 'ee', 'el', 'en', 'eo', 'es', 'et', 'eu', 'ext', 'fa', 'ff', 'fi', 'fj', 'fo', 'fr', 'frp',
+ 'frr', 'fur', 'fy', 'ga', 'gag', 'gan', 'gd', 'gl', 'glk', 'gn', 'gom', 'got', 'gu', 'gv', 'ha',
+ 'hak', 'haw', 'he', 'hi', 'hif', 'hr', 'hsb', 'ht', 'hu', 'hy', 'ia', 'id', 'ie', 'ig', 'ik',
+ 'ilo', 'io', 'is', 'it', 'iu', 'ja', 'jam', 'jbo', 'jv', 'ka', 'kaa', 'kab', 'kbd', 'kbp', 'kg',
+ 'ki', 'kk', 'kl', 'km', 'kn', 'ko', 'koi', 'krc', 'ks', 'ksh', 'ku', 'kv', 'kw', 'ky', 'la',
+ 'lad', 'lb', 'lbe', 'lez', 'lg', 'li', 'lij', 'lmo', 'ln', 'lo', 'lrc', 'lt', 'ltg', 'lv',
+ 'mai', 'mdf', 'mg', 'mh', 'mhr', 'mi', 'min', 'mk', 'ml', 'mn', 'mr', 'mrj', 'ms', 'mt', 'mwl',
+ 'my', 'myv', 'mzn', 'na', 'nap', 'nds', 'ne', 'new', 'ng', 'nl', 'nn', 'no', 'nov', 'nrm',
+ 'nso', 'nv', 'ny', 'oc', 'olo', 'om', 'or', 'os', 'pa', 'pag', 'pam', 'pap', 'pcd', 'pdc',
+ 'pfl', 'pi', 'pih', 'pl', 'pms', 'pnb', 'pnt', 'ps', 'pt', 'qu', 'rm', 'rmy', 'rn', 'ro', 'ru',
+ 'rue', 'rw', 'sa', 'sah', 'sc', 'scn', 'sco', 'sd', 'se', 'sg', 'sh', 'si', 'sk', 'sl', 'sm',
+ 'sn', 'so', 'sq', 'sr', 'srn', 'ss', 'st', 'stq', 'su', 'sv', 'sw', 'szl', 'ta', 'tcy', 'te',
+ 'tet', 'tg', 'th', 'ti', 'tk', 'tl', 'tn', 'to', 'tpi', 'tr', 'ts', 'tt', 'tum', 'tw', 'ty',
+ 'tyv', 'udm', 'ug', 'uk', 'ur', 'uz', 've', 'vec', 'vep', 'vi', 'vls', 'vo', 'wa', 'war', 'wo',
+ 'wuu', 'xal', 'xh', 'xmf', 'yi', 'yo', 'za', 'zea', 'zh', 'zu'
]

# All supported vector dimensionalities for which embeddings were trained
@@ -40,6 +34,11 @@ class BPEmb(_PretrainedWordVectors):
"""
Byte-Pair Encoding (BPE) embeddings trained on Wikipedia for 275 languages
A collection of pre-trained subword unit embeddings in 275 languages, based
on Byte-Pair Encoding (BPE). In an evaluation using fine-grained entity typing as the
testbed, BPEmb performs competitively, and for some languages better than, alternative
subword approaches, while requiring vastly fewer resources and no tokenization.
References:
* https://arxiv.org/abs/1710.02187
* https://github.com/bheinzerling/bpemb
@@ -78,24 +77,23 @@ def __init__(self, language='en', dim=300, merge_ops=50000, **kwargs):
# Check if all parameters are valid
if language not in SUPPORTED_LANGUAGES:
raise ValueError(("Language '%s' not supported. Use one of the "
"following options instead:\n%s"
) % (language, SUPPORTED_LANGUAGES))
"following options instead:\n%s") % (language, SUPPORTED_LANGUAGES))
if dim not in SUPPORTED_DIMS:
raise ValueError(("Embedding dimensionality of '%d' not supported. "
"Use one of the following options instead:\n%s"
) % (dim, SUPPORTED_DIMS))
"Use one of the following options instead:\n%s") % (dim,
SUPPORTED_DIMS))
if merge_ops not in SUPPORTED_MERGE_OPS:
raise ValueError(("Number of '%d' merge operations not supported. "
"Use one of the following options instead:\n%s"
) % (merge_ops, SUPPORTED_MERGE_OPS))
"Use one of the following options instead:\n%s") %
(merge_ops, SUPPORTED_MERGE_OPS))

format_map = {'language': language, 'merge_ops': merge_ops, 'dim': dim}

# Assemble the file name under which to store the embeddings locally
name = self.file_name.format_map(format_map)
# Assemble the URL to download the embeddings from
- url = (self.url_base.format_map(format_map) +
- self.file_name.format_map(format_map) +
- self.zip_extension)
+ url = (
+ self.url_base.format_map(format_map) + self.file_name.format_map(format_map) +
+ self.zip_extension)

super(BPEmb, self).__init__(name, url=url, **kwargs)
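
For reference, a hedged usage sketch of the reformatted class (assuming BPEmb is re-exported from torchnlp.word_to_vector, which this diff does not show; the base class is expected to download and cache the vectors on first use):

    from torchnlp.word_to_vector import BPEmb

    # Defaults mirror the signature above: language='en', dim=300, merge_ops=50000.
    vectors = BPEmb(language='en', dim=300, merge_ops=50000)

    # Invalid settings fail fast with the ValueError messages assembled above, e.g.
    # BPEmb(language='xx') raises ValueError listing SUPPORTED_LANGUAGES.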
