Skip to content

Commit

Permalink
Re-add test, Fix #49
Browse files Browse the repository at this point in the history
  • Loading branch information
AmaliePauli committed Jul 10, 2020
1 parent e4c6a6f commit e317c61
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 9 deletions.
10 changes: 1 addition & 9 deletions danlp/download.py
Expand Up @@ -232,7 +232,6 @@ def update_to(self, b=1, bsize=1, tsize=None):
def download_dataset(dataset: str, cache_dir: str = DEFAULT_CACHE_DIR,
process_func: Callable = None, verbose: bool = False, force = False):
"""
:param verbose:
:param dataset:
:param cache_dir:
Expand Down Expand Up @@ -308,7 +307,6 @@ def _check_file(fname):
def _check_process_func(process_func: Callable):
"""
Checks that a process function takes the correct arguments
:param process_func:
"""
function_args = inspect.getfullargspec(process_func).args
Expand All @@ -319,7 +317,6 @@ def _check_process_func(process_func: Callable):

def _download_and_process(meta_info: dict, process_func: Callable, single_file_path, verbose):
"""
:param meta_info:
:param process_func:
:param single_file_path:
Expand All @@ -342,7 +339,6 @@ def _download_and_process(meta_info: dict, process_func: Callable, single_file_p

def _download_file(meta_info: dict, destination: str, verbose: bool = False):
"""
:param meta_info:
:param destination:
:param verbose:
Expand Down Expand Up @@ -374,7 +370,6 @@ def _unzip_process_func(tmp_file_path: str, meta_info: dict, cache_dir: str = DE
"""
Simple process function for processing models
that only needs to be unzipped after download.
:param tmp_file_path: The path to the downloaded raw file
:param clean_up_raw_data:
:param verbose:
Expand Down Expand Up @@ -404,7 +399,4 @@ def _unzip_process_func(tmp_file_path: str, meta_info: dict, cache_dir: str = DE

else: # Extract all the files to the name of the model/dataset
destination = os.path.join(cache_dir, meta_info['name'])
zip_file.extractall(path=destination)



zip_file.extractall(path=destination)
91 changes: 91 additions & 0 deletions tests/test_embeddings.py
@@ -0,0 +1,91 @@
import unittest

from gensim.models.keyedvectors import FastTextKeyedVectors

from danlp.download import MODELS, download_model, _unzip_process_func
from danlp.models.embeddings import load_wv_with_spacy, load_wv_with_gensim, load_context_embeddings_with_flair, \
AVAILABLE_EMBEDDINGS, AVAILABLE_SUBWORD_EMBEDDINGS


class TestEmbeddings(unittest.TestCase):
    """Tests for loading Danish word embeddings via spaCy, gensim and flair.

    A small 5k-vocabulary test embedding ('wiki.da.small.wv') is registered
    in the module-level MODELS dict so the download in setUp stays fast.
    """

    def setUp(self):
        # Register a smaller test embedding in the module-level MODELS dict
        # so the test suite does not download the full-size model.
        MODELS['wiki.da.small.wv'] = {
            'url': 'https://danlp.alexandra.dk/304bd159d5de/tests/wiki.da.small.zip',
            'vocab_size': 5000,
            'dimensions': 300,
            'md5_checksum': 'fcaa981a613b325ae4dc61aba235aa82',
            'size': 5594508,
            'file_extension': '.bin'
        }

        # setUp runs before EVERY test method; guard the append so the
        # module-level list does not accumulate duplicates across tests.
        if 'wiki.da.small.wv' not in AVAILABLE_EMBEDDINGS:
            AVAILABLE_EMBEDDINGS.append('wiki.da.small.wv')

        self.embeddings_for_testing = [
            'wiki.da.small.wv',
            'dslreddit.da.wv'
        ]
        # Download the models and unzip them.
        for emb in self.embeddings_for_testing:
            download_model(emb, process_func=_unzip_process_func)

    def test_embeddings_with_spacy(self):
        # 'wiki.da.small.swv' is not registered as a spaCy-loadable
        # embedding, so the loader must fail loudly.
        with self.assertRaises(ValueError):
            load_wv_with_spacy("wiki.da.small.swv")

        embeddings = load_wv_with_spacy("wiki.da.wv")

        sentence = embeddings('jeg gik ned af en gade')
        for token in sentence:
            self.assertTrue(token.has_vector)

    def test_embeddings_with_gensim(self):
        # Every registered test embedding must load and expose the vocab
        # size declared in its MODELS metadata.
        for emb in self.embeddings_for_testing:
            embeddings = load_wv_with_gensim(emb)
            self.assertEqual(MODELS[emb]['vocab_size'], len(embeddings.vocab))

    def test_embeddings_with_flair(self):
        from flair.data import Sentence

        embs = load_context_embeddings_with_flair()

        # 'bank' appears in two different contexts; both tokens should get
        # a context embedding of the expected dimensionality.
        sentence1 = Sentence('Han fik bank')
        sentence2 = Sentence('Han fik en ny bank')

        embs.embed(sentence1)
        embs.embed(sentence2)

        # Check length of context embeddings
        self.assertEqual(len(sentence1[2].embedding), 2364)
        self.assertEqual(len(sentence2[4].embedding), 2364)

    def test_fasttext_embeddings(self):
        # Register a smaller test subword embedding in the module-level
        # MODELS dict, mirroring the setup for word embeddings above.
        MODELS['ddt.swv'] = {
            'url': 'https://danlp.alexandra.dk/304bd159d5de/tests/ddt.swv.zip',
            'vocab_size': 5000,
            'dimensions': 100,
            'md5_checksum': 'c50c61e1b434908e2732c80660abf8bf',
            'size': 741125088,
            'file_extension': '.bin'
        }

        # Guard against duplicate registration if the suite is run more
        # than once in the same process (consistent with setUp).
        if 'ddt.swv' not in AVAILABLE_SUBWORD_EMBEDDINGS:
            AVAILABLE_SUBWORD_EMBEDDINGS.append('ddt.swv')

        download_model('ddt.swv', process_func=_unzip_process_func)

        fasttext_embeddings = load_wv_with_gensim('ddt.swv')

        self.assertEqual(type(fasttext_embeddings), FastTextKeyedVectors)

        # The word is not in the vocab
        self.assertNotIn('institutmedarbejdskontrakt', fasttext_embeddings.vocab)

        # However we can get an embedding because of subword units
        self.assertEqual(fasttext_embeddings['institutmedarbejdskontrakt'].size, 100)


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit e317c61

Please sign in to comment.