forked from microsoft/TaskWeaver
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_embedding.py
120 lines (97 loc) · 4.07 KB
/
test_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import pytest
from injector import Injector
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.llm import QWenService, ZhipuAIService
from taskweaver.llm.ollama import OllamaService
from taskweaver.llm.openai import OpenAIService
from taskweaver.llm.sentence_transformer import SentenceTransformerService
# True only when the GITHUB_ACTIONS environment variable is exactly "true",
# which is how GitHub Actions marks its runner environment.
IN_GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS") == "true"
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_sentence_transformer_embedding():
    """Embed two sentences with SentenceTransformerService (all-mpnet-base-v2).

    Unlike the hosted-service embedding tests, this configuration needs no
    ``llm.api_key``, so the test is skipped only when running under GitHub
    Actions (as its skip reason states) instead of unconditionally.
    """
    app_injector = Injector()
    app_config = AppConfigSource(
        config={
            "llm.embedding_api_type": "sentence_transformer",
            "llm.embedding_model": "all-mpnet-base-v2",
        },
    )
    app_injector.binder.bind(AppConfigSource, to=app_config)
    sentence_transformer_service = app_injector.create_object(
        SentenceTransformerService,
    )

    text_list = ["This is a test sentence.", "This is another test sentence."]
    embeddings = sentence_transformer_service.get_embeddings(text_list)

    # Expect one vector per input text, each 768 values long for this model.
    assert len(embeddings) == len(text_list)
    for embedding in embeddings:
        assert len(embedding) == 768
@pytest.mark.skipif(True, reason="Test doesn't work in Github Actions.")
def test_openai_embedding():
    """Check that OpenAIService yields one 1536-long vector per input text.

    Skipped unconditionally: ``llm.api_key`` is left empty, so a key must be
    supplied (and the skip marker removed) to run this test.
    """
    injector = Injector()
    settings = {
        "llm.embedding_api_type": "openai",
        "llm.embedding_model": "text-embedding-ada-002",
        # need to configure llm.api_key in the config to run this test
        "llm.api_key": "",
    }
    injector.binder.bind(AppConfigSource, to=AppConfigSource(config=settings))
    service = injector.create_object(OpenAIService)

    sentences = ["This is a test sentence.", "This is another test sentence."]
    vectors = service.get_embeddings(sentences)

    assert len(vectors) == 2
    for vector in vectors:
        assert len(vector) == 1536
@pytest.mark.skipif(True, reason="Test doesn't work in Github Actions.")
def test_ollama_embedding():
    """Check that OllamaService (model "llama2") yields 4096-long vectors.

    Skipped unconditionally. NOTE(review): presumably this needs a reachable
    Ollama server with the model pulled; confirm before enabling.
    """
    injector = Injector()
    settings = {
        "llm.embedding_api_type": "ollama",
        "llm.embedding_model": "llama2",
    }
    injector.binder.bind(AppConfigSource, to=AppConfigSource(config=settings))
    service = injector.create_object(OllamaService)

    sentences = ["This is a test sentence.", "This is another test sentence."]
    vectors = service.get_embeddings(sentences)

    assert len(vectors) == 2
    for vector in vectors:
        assert len(vector) == 4096
@pytest.mark.skipif(True, reason="Test doesn't work in Github Actions.")
def test_qwen_embedding():
    """Check that QWenService yields one 1536-long vector per input text.

    Skipped unconditionally: ``llm.api_key`` is left empty, so a key must be
    supplied (and the skip marker removed) to run this test.
    """
    injector = Injector()
    settings = {
        "llm.embedding_api_type": "qwen",
        "llm.embedding_model": "text-embedding-v1",
        # need to configure llm.api_key in the config to run this test
        "llm.api_key": "",
    }
    injector.binder.bind(AppConfigSource, to=AppConfigSource(config=settings))
    service = injector.create_object(QWenService)

    sentences = ["This is a test sentence.", "This is another test sentence."]
    vectors = service.get_embeddings(sentences)

    assert len(vectors) == 2
    for vector in vectors:
        assert len(vector) == 1536
@pytest.mark.skipif(True, reason="Test doesn't work in Github Actions.")
def test_zhipuai_embedding():
    """Check that ZhipuAIService yields one 1024-long vector per input text.

    Skipped unconditionally: ``llm.api_key`` is left empty, so a key must be
    supplied (and the skip marker removed) to run this test.
    """
    injector = Injector()
    settings = {
        "llm.embedding_api_type": "zhipuai",
        "llm.embedding_model": "embedding-2",
        # need to configure llm.api_key in the config to run this test
        "llm.api_key": "",
    }
    injector.binder.bind(AppConfigSource, to=AppConfigSource(config=settings))
    service = injector.create_object(ZhipuAIService)

    sentences = ["This is a test sentence.", "This is another test sentence."]
    vectors = service.get_embeddings(sentences)

    assert len(vectors) == 2
    for vector in vectors:
        assert len(vector) == 1024