-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_loader.py
70 lines (60 loc) · 2.36 KB
/
model_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from langchain_community.llms.ollama import Ollama
from typing import Optional, Dict
import logging
class ModelLoader:
    """
    Load and interact with an Ollama-hosted language model.

    Attributes:
        model_name (str): The name of the model to use.
        base_url (str): The base URL of the model API.
        context_window (int): The maximum context window size (num_ctx) requested.
        client (Ollama): The underlying Ollama client used for generation.
    """

    # Default maximum context window (tokens) requested from the model.
    DEFAULT_CONTEXT_WINDOW = 32768

    def __init__(
        self,
        model_name: str,
        client: Optional[Ollama] = None,
        base_url: str = "http://localhost:11434",
    ):
        """
        Initializes the ModelLoader with the specified model name and base URL.

        Args:
            model_name (str): The name of the model to use.
            client (Optional[Ollama]): A pre-configured Ollama client. If None,
                a client is created for ``model_name`` at ``base_url``.
            base_url (str): The base URL of the model API.
        """
        self.base_url = base_url
        self.model_name = model_name
        self.context_window = self.DEFAULT_CONTEXT_WINDOW
        # BUG FIX: the original built the client without the model name, so
        # `model_name` was silently ignored and Ollama's default model was
        # used. The context window was likewise computed but never applied;
        # both are now set on the client at construction time.
        self.client = client or Ollama(
            base_url=base_url,
            model=model_name,
            num_ctx=self.context_window,
        )
        self.logger = logging.getLogger(__name__)
        # NOTE(review): configuring the root logger from a library class is a
        # side effect callers may not expect; kept for backward compatibility.
        logging.basicConfig(level=logging.INFO)

    def _get_options(self) -> Dict[str, int]:
        """
        Constructs the per-request options for the model prompt.

        Returns:
            dict: Options mapping; currently only {"num_ctx": context_window}.
        """
        return {"num_ctx": self.context_window}

    def generate(self, prompt: str) -> str:
        """
        Generates a response from the AI model based on the provided prompt.

        Args:
            prompt (str): The input prompt for the model. A dict with a
                "prompt" key is also tolerated and unwrapped for convenience.

        Returns:
            str: The generated response from the model.

        Raises:
            RuntimeError: If the underlying client call fails for any reason;
                the original exception is attached as the cause.
        """
        try:
            if isinstance(prompt, dict):
                prompt = prompt["prompt"]  # tolerate {"prompt": "..."} inputs
            # TODO(review): `max_tokens` is not an Ollama option (Ollama uses
            # `num_predict`); confirm this kwarg is actually honored upstream.
            response = self.client.generate(
                prompts=[prompt],  # the API expects a list of prompt strings
                max_tokens=150,
                temperature=0.7,
            )
            # LLMResult nests generations as [prompt][candidate]; take the
            # first candidate of the single prompt we sent.
            return response.generations[0][0].text
        except Exception as e:
            # Log with traceback, then surface a uniform error to callers,
            # chaining the original exception as the cause.
            self.logger.exception("Error generating response: %s", e)
            raise RuntimeError(f"Error generating response: {e}") from e