diff --git a/annif/backend/svc.py b/annif/backend/svc.py
index e8e1f8617..626f42de3 100644
--- a/annif/backend/svc.py
+++ b/annif/backend/svc.py
@@ -53,7 +53,17 @@ def _corpus_to_texts_and_classes(corpus):
         classes = []
         for doc in corpus.documents:
             texts.append(doc.text)
-            classes.append(doc.uris[0])
+            # Exactly one class is recorded per document, so a document
+            # has to carry exactly one subject to be usable here.
+            if len(doc.uris) > 1:
+                raise NotSupportedException(
+                    'SVC backend does not support training on documents '
+                    'with multiple subjects.')
+            if not doc.uris:
+                raise NotSupportedException(
+                    'SVC backend does not support training on documents '
+                    'without subjects.')
+            classes.append(next(iter(doc.uris)))
         return texts, classes
 
     def _train_classifier(self, veccorpus, classes):
diff --git a/tests/conftest.py b/tests/conftest.py
index f92e40e92..5e8f1ab45 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -96,6 +96,24 @@ def document_corpus(subject_index):
     return doc_corpus
 
 
+@pytest.fixture(scope='module')
+def document_corpus_single_subject(document_corpus):
+    """Return document_corpus reduced to a single subject per document."""
+    docs_single_subj = []
+    for doc in document_corpus.documents:
+        if not doc.uris or not doc.labels:
+            # Leave unlabelled documents out instead of giving them a
+            # placeholder None subject.
+            continue
+        # NOTE: the URI and the label are picked independently of each
+        # other, so they may not describe the same subject.
+        uri = next(iter(doc.uris))
+        label = next(iter(doc.labels))
+        docs_single_subj.append(
+            annif.corpus.Document(doc.text, {uri}, {label}))
+    return annif.corpus.DocumentList(docs_single_subj)
+
+
 @pytest.fixture(scope='module')
 def pretrained_vectors():
     return py.path.local(os.path.join(
diff --git a/tests/test_backend_svc.py b/tests/test_backend_svc.py
index 9cc5245d0..df6928ba1 100644
--- a/tests/test_backend_svc.py
+++ b/tests/test_backend_svc.py
@@ -34,26 +34,26 @@ def test_svc_suggest_no_vectorizer(project):
         svc.suggest("example text")
 
 
-def test_svc_train(datadir, document_corpus, project):
+def test_svc_train(datadir, document_corpus_single_subject, project):
     svc_type = annif.backend.get_backend('svc')
     svc = svc_type(
         backend_id='svc',
         config_params={},
         project=project)
 
-    svc.train(document_corpus)
+    svc.train(document_corpus_single_subject)
     assert svc._model is not None
     assert datadir.join('svc-model.gz').exists()
 
 
-def test_svc_train_ngram(datadir, document_corpus, project):
+def test_svc_train_ngram(datadir, document_corpus_single_subject, project):
     svc_type = annif.backend.get_backend('svc')
     svc = svc_type(
         backend_id='svc',
         config_params={'ngram': 2},
         project=project)
 
-    svc.train(document_corpus)
+    svc.train(document_corpus_single_subject)
     assert svc._model is not None
     assert datadir.join('svc-model.gz').exists()
 
@@ -69,6 +69,18 @@ def test_svc_train_cached(datadir, project):
         svc.train("cached")
 
 
+def test_svc_train_multiple_subjects(datadir, document_corpus, project):
+    """Training on documents with several subjects must be rejected."""
+    svc_type = annif.backend.get_backend('svc')
+    svc = svc_type(
+        backend_id='svc',
+        config_params={},
+        project=project)
+
+    with pytest.raises(NotSupportedException):
+        svc.train(document_corpus)
+
+
 def test_svc_train_nodocuments(datadir, project, empty_corpus):
     svc_type = annif.backend.get_backend('svc')
     svc = svc_type(