Fixes single file model loading #789

Merged 57 commits on Aug 18, 2024
Commits
6731b86
Fixes single file model loading
w4ffl35 May 12, 2024
ebe97d8
bump version
w4ffl35 May 12, 2024
fa0805f
Partially fixes sd loading
w4ffl35 May 13, 2024
2f6f257
Remove unused imports
w4ffl35 May 13, 2024
14a8893
Stable Diffusion loading improvements
w4ffl35 May 13, 2024
1276f8a
Modify path structure
w4ffl35 May 14, 2024
313de29
Modify path structure
w4ffl35 May 14, 2024
3a1029f
Loading fixes for controlnet
w4ffl35 May 14, 2024
ef80806
Installation fixes
w4ffl35 May 14, 2024
371cd12
Fixes model scanning when models are deleted
w4ffl35 May 14, 2024
79ab0d0
make clear_controlnet public
w4ffl35 May 14, 2024
f19e6f4
Fix scheduler path
w4ffl35 May 14, 2024
2322166
prevent fatal error when model is None
w4ffl35 May 14, 2024
3fd3024
Fixes model unloading
w4ffl35 May 14, 2024
1389db7
fix vae and controlnet paths in import widget
w4ffl35 May 14, 2024
a3889ee
Adds civitai token to settings and downloader
w4ffl35 May 14, 2024
5ab0016
Fixes causallm loading
w4ffl35 May 14, 2024
2aaa8ec
Closes #802 download all base models
w4ffl35 May 14, 2024
0478d43
Adds new templates for civitai and llama license
w4ffl35 May 14, 2024
b25698f
fixes #725 download civitai
w4ffl35 May 14, 2024
d975c29
move history class into separate file
w4ffl35 May 14, 2024
508fdbd
fix nltk path
w4ffl35 May 14, 2024
4431b30
Improves loading message during splash screen
w4ffl35 May 14, 2024
4e670b3
Improve logging messages
w4ffl35 May 14, 2024
86fe5a3
reduce delay when model manager is loading
w4ffl35 May 14, 2024
94943d7
Fixes usage of llama 3
w4ffl35 May 21, 2024
329fa04
Fixes usage of llama 3
w4ffl35 May 21, 2024
16cb9b7
Merge remote-tracking branch 'origin/devastator' into devastator
w4ffl35 May 21, 2024
d29a4fc
fix llamaindex
w4ffl35 May 22, 2024
318350a
Removes chat_history property from stream_chat call
w4ffl35 Jul 18, 2024
5ff8b17
Removes redundant code
w4ffl35 Jul 18, 2024
b3b40c6
Fixes embeddings path
w4ffl35 Jul 18, 2024
f12df97
more enums added
w4ffl35 Jul 18, 2024
9307d94
code formatting change
w4ffl35 Jul 18, 2024
8ea22cb
Remove unused imports
w4ffl35 Jul 18, 2024
198fc66
Fixes import models in import widget
w4ffl35 Jul 18, 2024
684e0e3
Fix LLM bootstrap key
w4ffl35 Jul 18, 2024
0f8c02f
Fix lora toggle
w4ffl35 Jul 18, 2024
045aad3
Fix lora path
w4ffl35 Jul 18, 2024
10c1496
update vae status for stats widget
w4ffl35 Jul 18, 2024
3cad7b3
Remove unused imports
w4ffl35 Jul 18, 2024
62bf9ef
Fix for vae and unet loading
w4ffl35 Jul 18, 2024
71aab24
fix model path
w4ffl35 Jul 18, 2024
c3a896e
catch errors when moving pipe
w4ffl35 Jul 18, 2024
0084771
Fixes for embeddings
w4ffl35 Jul 18, 2024
9d182f3
Remove redundant code
w4ffl35 Jul 18, 2024
6c9b1c7
Fix default paths
w4ffl35 Jul 18, 2024
d503d38
add vae_models to settings
w4ffl35 Jul 18, 2024
fb6f07b
adds vae settings
w4ffl35 Jul 18, 2024
47ff751
adds vae scanner worker
w4ffl35 Jul 18, 2024
e02b729
Caching: Implemented caching for directory scans to reduce I/O operat…
w4ffl35 Aug 8, 2024
fd03d1e
improvements to embedding and lora widget loading speed
w4ffl35 Aug 18, 2024
b4dca8b
adds image preset for prompts
w4ffl35 Aug 18, 2024
dd26dc8
remove old code
w4ffl35 Aug 18, 2024
ae1d95d
improvements to loading time
w4ffl35 Aug 18, 2024
4ffa965
improvements to loading time
w4ffl35 Aug 18, 2024
16c5259
version bump
w4ffl35 Aug 18, 2024
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='airunner',
version="3.0.0.dev10",
version="3.0.0.dev12",
author="Capsize LLC",
description="A Stable Diffusion GUI",
long_description=open("README.md", "r", encoding="utf-8").read(),
180 changes: 122 additions & 58 deletions src/airunner/aihandler/llm/agent/agent_llamaindex_mixin.py
@@ -1,23 +1,28 @@
import os.path
import string
from typing import Optional, List

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, ServiceContext, StorageContext, PromptHelper, \
SimpleKeywordTableIndex
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.data_structs import IndexDict
from llama_index.core.indices.keyword_table import KeywordTableSimpleRetriever
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.schema import TransformComponent
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.readers.file import EpubReader, PDFReader, MarkdownReader
from llama_index.core import Settings
from airunner.aihandler.llm.agent.html_file_reader import HtmlFileReader
from airunner.enums import SignalCode, LLMChatRole, AgentState
from airunner.enums import SignalCode, LLMChatRole, AgentState, LLMActionType


class AgentLlamaIndexMixin:
def __init__(self):
self.__documents = None
self.__index = None
self.__chat_engine = None
self.__retriever = None
self.__service_context: Optional[ServiceContext] = None
self.__storage_context: StorageContext = None
self.__transformations: Optional[List[TransformComponent]] = None
@@ -89,49 +94,66 @@ def load_rag(self, model, tokenizer):
self.__load_readers()
self.__load_file_extractor()
self.__load_documents()
self.__load_text_splitter()
self.__load_prompt_helper()
self.__load_service_context()
self.__load_document_index()
self.__load_retriever()
self.__load_context_chat_engine()

# self.__load_storage_context()
# self.__load_transformations()
# self.__load_index_struct()
self.__load_document_index()

def __load_llm(self, model, tokenizer):
self.__llm = HuggingFaceLLM(
model=model,
tokenizer=tokenizer,
max_new_tokens=1000,
generate_kwargs=dict(
top_k=50,
top_p=0.95,
temperature=0.9,
num_return_sequences=1,
num_beams=1,
no_repeat_ngram_size=3,
early_stopping=True,
do_sample=True,
# pad_token_id=tokenizer.eos_token_id,
# eos_token_id=tokenizer.eos_token_id,
# bos_token_id=tokenizer.bos_token_id,
try:
self.__llm = HuggingFaceLLM(
model=model,
tokenizer=tokenizer,
max_new_tokens=4096,
generate_kwargs=dict(
top_k=40,
top_p=0.90,
temperature=0.5,
num_return_sequences=1,
num_beams=1,
no_repeat_ngram_size=4,
early_stopping=True,
do_sample=True,
)
)
)
except Exception as e:
self.logger.error(f"Error loading LLM: {str(e)}")

@property
def is_llama_instruct(self):
return True

def perform_rag_search(
self,
prompt,
streaming: bool = False,
response_mode: ResponseMode = ResponseMode.COMPACT
self,
prompt,
streaming: bool = False,
response_mode: ResponseMode = ResponseMode.COMPACT
):
if self.__chat_engine is None:
raise RuntimeError(
"Chat engine is not initialized. "
"Please ensure __load_service_context "
"is called before perform_rag_search."
)

self.add_message_to_history(
prompt,
LLMChatRole.HUMAN
)

if response_mode in (
ResponseMode.ACCUMULATE
):
streaming = False

try:
query_engine = self.__index.as_query_engine(
streaming=streaming,
response_mode=response_mode,
)
print(f"Querying with prompt: {prompt}") # Debug: Show the prompt
engine = self.__chat_engine
except AttributeError as e:
self.logger.error(f"Error performing RAG search: {str(e)}")
if streaming:
@@ -145,14 +167,21 @@ def perform_rag_search(
)
)
return
response = query_engine.query(prompt)

inputs:str = self.get_rendered_template(LLMActionType.PERFORM_RAG_SEARCH, [])

response = engine.stream_chat(
message=inputs
)
response_text = ""
if streaming:
self.emit_signal(SignalCode.UNBLOCK_TTS_GENERATOR_SIGNAL)
is_first_message = True
is_end_of_message = False
for res in response.response_gen:
response_text += res
if response_text: # Only add a space if response_text is not empty
response_text += " "
response_text += res.strip()
self.emit_signal(
SignalCode.LLM_TEXT_STREAMED_SIGNAL,
dict(
@@ -163,10 +192,18 @@ def perform_rag_search(
)
)
is_first_message = False
self.add_message_to_history(
response_text,
LLMChatRole.ASSISTANT
)
response_text = ""
else:
response_text = response.response
is_first_message = True
self.add_message_to_history(
response_text,
LLMChatRole.ASSISTANT
)

self.emit_signal(
SignalCode.LLM_TEXT_STREAMED_SIGNAL,
Expand All @@ -178,11 +215,6 @@ def perform_rag_search(
)
)

self.add_message_to_history(
response_text,
LLMChatRole.ASSISTANT
)

return response

def __load_rag_model(self):
@@ -221,16 +253,50 @@ def __load_documents(self):
self.logger.error(f"Error loading documents: {str(e)}")
self.__documents = None

def __load_service_context(self):
self.logger.debug("Loading service context...")
self.__service_context = ServiceContext.from_defaults(
llm=self.__llm,
embed_model=Settings.embed_model,
chunk_size=self.__chunk_size,
chunk_overlap=self.__chunk_overlap,
system_prompt="Search the full text and find all relevant information related to the query.",
def __load_text_splitter(self):
self.__text_splitter = SentenceSplitter(
chunk_size=256,
chunk_overlap=20
)

def __load_prompt_helper(self):
self.__prompt_helper = PromptHelper(
context_window=4096,
num_output=1024,
chunk_overlap_ratio=0.1,
chunk_size_limit=None,
)

def __load_context_chat_engine(self):
context_retriever = self.__retriever # Your method to retrieve context

try:
self.__chat_engine = ContextChatEngine.from_defaults(
retriever=context_retriever,
service_context=self.__service_context,
chat_history=self.history,
memory=None, # Define or use an existing memory buffer if needed
system_prompt="Search the full text and find all relevant information related to the query.",
node_postprocessors=[], # Add postprocessors if utilized in your setup
llm=self.__llm, # Use the existing LLM setup
)
except Exception as e:
self.logger.error(f"Error loading chat engine: {str(e)}")

def __load_service_context(self):
self.logger.debug("Loading service context with ContextChatEngine...")
try:
# Update service context to use the newly created chat engine
self.__service_context = ServiceContext.from_defaults(
llm=self.__llm,
embed_model=Settings.embed_model,
#chat_engine=self.__chat_engine, # Include the chat engine in the service context
text_splitter=self.__text_splitter,
prompt_helper=self.__prompt_helper,
)
except Exception as e:
self.logger.error(f"Error loading service context with chat engine: {str(e)}")

# def __load_storage_context(self):
# self.logger.debug("Loading storage context...")
# path = os.path.expanduser(self.settings["path_settings"]["storage_path"])
@@ -284,25 +350,23 @@ def print_chunks(self):
for chunk in chunks:
print(chunk)

def print_indexed_chunks(self):
# Assuming get_indexed_nodes is a method that returns the indexed nodes
if self.__index is not None:
node_doc_ids = list(self.__index.index_struct.nodes_dict.values())
indexed_nodes = self.__index.docstore.get_nodes(node_doc_ids)
for i, node in enumerate(indexed_nodes):
print(f"Chunk {i + 1}: {node.text}") # Print first 200 characters of each chunk

def __load_document_index(self):
self.logger.debug("Loading index...")
try:
self.__index = VectorStoreIndex.from_documents(
self.__index = SimpleKeywordTableIndex.from_documents(
self.__documents,
service_context=self.__service_context,
# storage_context=self.__storage_context,
# transformations=self.__transformations,
# index_struct=self.__index_struct
)
self.logger.debug("Index loaded successfully.")
except TypeError as e:
self.logger.error(f"Error loading index: {str(e)}")

self.print_indexed_chunks()
def __load_retriever(self):
try:
self.__retriever = KeywordTableSimpleRetriever(
index=self.__index,
)
self.logger.debug("Retriever loaded successfully with index.")
except Exception as e:
self.logger.error(f"Error setting up the retriever: {str(e)}")

38 changes: 11 additions & 27 deletions src/airunner/aihandler/llm/agent/base_agent.py
@@ -204,7 +204,9 @@ def build_system_prompt(
self.history_prompt(),
]
system_prompt = self.add_vision_prompt(vision_history, system_prompt)
system_prompt = self.append_date_time_timezone(system_prompt)

if self.chatbot["use_datetime"]:
system_prompt = self.append_date_time_timezone(system_prompt)

elif action == LLMActionType.ANALYZE_VISION_HISTORY:
vision_history = vision_history[-10:] if len(vision_history) > 10 else vision_history
@@ -323,27 +325,11 @@ def prepare_messages(

return messages

@property
def _chat_template(self):
return (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"{{ '[INST] <<SYS>>' + message['content'] + ' <</SYS>>[/INST]' }}"
"{% elif message['role'] == 'user' %}"
"{{ '[INST]Consider the full chat history and then respond to this message from {{ username }}: ' + message['content'] + ' [/INST]' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ message['content'] + eos_token + ' ' }}"
"{% endif %}"
"{% endfor %}"
) if self.is_mistral else None

def get_rendered_template(
self,
action,
vision_history
):
conversation = []

action: LLMActionType,
vision_history: list
) -> str:
conversation = self.prepare_messages(
action,
vision_history=vision_history
@@ -358,7 +344,7 @@ def get_rendered_template(
)

rendered_template = self.tokenizer.apply_chat_template(
chat_template=self._chat_template,
chat_template=self.chat_template,
conversation=conversation,
tokenize=False
)
@@ -398,12 +384,10 @@ def generator_settings(self):

def get_model_inputs(
self,
action,
vision_history,
action: LLMActionType,
vision_history: list,
**kwargs
):
self.chat_template = kwargs.get("chat_template", self.chat_template)

self.rendered_template = self.get_rendered_template(
action,
vision_history
@@ -424,7 +408,7 @@ def run(
def run(
self,
prompt: str,
action,
action: str,
vision_history: list = [],
**kwargs
):
@@ -445,7 +429,7 @@ def do_run(

def do_run(
self,
action,
action: LLMActionType,
vision_history: list = [],
streamer=None,
do_emit_response: bool = True,
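The base_agent.py changes drop the hard-coded Mistral `[INST]` chat template and the `chat_template` override in `get_model_inputs()`; rendering now goes through `self.chat_template` and the tokenizer. A small sketch of the underlying Hugging Face call, with a placeholder model name (in airunner the tokenizer and any custom template come from the chatbot configuration); passing `chat_template=None` simply falls back to the template bundled with the tokenizer:

```python
# Sketch of the tokenizer.apply_chat_template() call that get_rendered_template()
# relies on after this PR. The model name is a placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the documents that were just indexed."},
]

rendered = tokenizer.apply_chat_template(
    conversation=conversation,
    chat_template=None,          # None -> use the template shipped with the tokenizer
    tokenize=False,              # return the rendered prompt string, as in the diff
    add_generation_prompt=True,  # append the assistant header so generation can start
)
print(rendered)
```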