model.py
import dataclasses as dc
import os
import typing as t

import requests
from llama_cpp import Llama

from superduperdb.ext.llm.model import BaseLLM
from superduperdb.misc.annotations import merge_docstrings


# TODO use core downloader already implemented
def download_uri(uri, save_path):
    """Download a file over HTTP(S).

    :param uri: URI to download
    :param save_path: place to save
    """
    response = requests.get(uri)
    if response.status_code == 200:
        with open(save_path, 'wb') as file:
            file.write(response.content)
    else:
        raise Exception(f"Error while downloading uri {uri}")
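
# Illustrative usage sketch (the URL and path below are hypothetical, not part
# of this module): the file is fetched in a single request and written to
# ``save_path``; any non-200 status raises an ``Exception``.
#
#     download_uri(
#         'https://example.com/models/tiny-llama.Q4_K_M.gguf',
#         save_path='.llama_cpp/tiny-llama.gguf',
#     )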


@merge_docstrings
@dc.dataclass(kw_only=True)
class LlamaCpp(BaseLLM):
    """Llama.cpp connector.

    :param model_name_or_path: path or name of model
    :param model_kwargs: dictionary of init-kwargs
    :param download_dir: local caching directory
    """

    model_name_or_path: str = "facebook/opt-125m"
    model_kwargs: t.Dict = dc.field(default_factory=dict)
    download_dir: str = '.llama_cpp'
    signature: str = 'singleton'

    def init(self):
        """Initialize the model.

        If the model_name_or_path is a URI, download it to the download_dir.
        """
        if self.model_name_or_path.startswith('http'):
            # Download the model weights into the local cache directory
            os.makedirs(self.download_dir, exist_ok=True)
            saved_path = os.path.join(self.download_dir, f'{self.identifier}.gguf')
            download_uri(self.model_name_or_path, saved_path)
            self.model_name_or_path = saved_path

        if self.predict_kwargs is None:
            self.predict_kwargs = {}

        self._model = Llama(self.model_name_or_path, **self.model_kwargs)

    def _generate(self, prompt: str, **kwargs) -> str:
        """Generate text from a prompt.

        :param prompt: The prompt to generate text from.
        :param kwargs: The keyword arguments to pass to the llm model.
        """
        out = self._model.create_completion(prompt, **self.predict_kwargs, **kwargs)
        return out['choices'][0]['text']
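
# Usage sketch (assumptions: ``identifier`` is supplied by the superduperdb
# component base class, and the GGUF path below is hypothetical):
#
#     llm = LlamaCpp(
#         identifier='my-llm',
#         model_name_or_path='./checkpoints/mistral-7b.Q4_K_M.gguf',
#         model_kwargs={'n_ctx': 2048},
#     )
#     llm.init()                        # loads the GGUF weights via llama_cpp.Llama
#     text = llm._generate('Hello, ')   # low-level hook; returns the completion text
#
# If ``model_name_or_path`` is an http(s) URL, ``init`` first downloads the file
# to ``download_dir`` as ``<identifier>.gguf`` and then loads it from there.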


@merge_docstrings
@dc.dataclass
class LlamaCppEmbedding(LlamaCpp):
    """Llama.cpp connector for embeddings."""

    def _generate(self, prompt: str, **kwargs):
        """Generate an embedding from a prompt.

        Note: llama.cpp only exposes embeddings when the underlying ``Llama``
        model is constructed with ``embedding=True`` (pass it via ``model_kwargs``);
        ``create_embedding`` itself does not accept that flag.

        :param prompt: The prompt to generate the embedding from.
        :param kwargs: The keyword arguments to pass to the llm model.
        """
        return self._model.create_embedding(
            prompt, **self.predict_kwargs, **kwargs
        )
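
# Usage sketch for embeddings (hypothetical identifier and path; note that
# ``embedding=True`` must be passed to the Llama constructor via model_kwargs):
#
#     embedder = LlamaCppEmbedding(
#         identifier='my-embedder',
#         model_name_or_path='./checkpoints/nomic-embed.Q4_K_M.gguf',
#         model_kwargs={'embedding': True},
#     )
#     embedder.init()
#     response = embedder._generate('some text to embed')
#     vector = response['data'][0]['embedding']   # OpenAI-style embedding response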