Feature mcboaty llm + new node McPrompty (#79)
* Update McBoaty Upscale Mechanism + add Per Tile Prompting generation + add new node to generate prompt from image
* hotfix llm vision name calling
* hotfix output size upscaling
* hotfix tiles order and execution duration display as int
* hotfix tiles order display and default llm model
* hotfix display index + default LLM Models list
* hotfix McPrompty name
* hotfix McPrompty name
* hotfix prestartup_script
* hotfix readme on tile prompting
Showing 11 changed files with 355 additions and 34 deletions.
@@ -0,0 +1,174 @@
import os
import requests
import torch
import folder_paths
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers import BlipProcessor, BlipForConditionalGeneration
from groq import Groq
from .image import MS_Image_v2 as MS_Image

from ...utils.log import log


# Image-to-text captioner backed by Microsoft Kosmos-2
class MS_Llm_Microsoft():

    @classmethod
    def __init__(self, model_name = "microsoft/kosmos-2-patch14-224"):
        self.name = model_name
        self.model = AutoModelForVision2Seq.from_pretrained(self.name)
        self.processor = AutoProcessor.from_pretrained(self.name)

    @classmethod
    def generate_prompt(self, image):

        # prompt_prefix = "<grounding>An image of"
        prompt_prefix = ""

        _image = MS_Image.tensor2pil(image)

        inputs = self.processor(text=prompt_prefix, images=_image, return_tensors="pt")

        # Generate the caption
        generated_ids = self.model.generate(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            image_embeds=None,
            image_embeds_position_mask=inputs["image_embeds_position_mask"],
            use_cache=True,
            max_new_tokens=128,
        )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        caption, _ = self.processor.post_process_generation(generated_text)

        return caption


# Image-to-text captioner backed by Salesforce BLIP (large)
class MS_Llm_Salesforce():

    @classmethod
    def __init__(self, model_name = "Salesforce/blip-image-captioning-large"):
        self.name = model_name
        self.model = BlipForConditionalGeneration.from_pretrained(self.name)
        self.processor = BlipProcessor.from_pretrained(self.name)

    @classmethod
    def generate_prompt(self, image):

        # prompt_prefix = "<grounding>An image of"
        prompt_prefix = ""

        _image = MS_Image.tensor2pil(image)

        inputs = self.processor(text=prompt_prefix, images=_image, return_tensors="pt")

        # Generate the caption
        generated_ids = self.model.generate(**inputs)
        caption = self.processor.decode(generated_ids[0], skip_special_tokens=True)

        return caption


# Image-to-text captioner backed by nlpconnect ViT-GPT2
class MS_Llm_Nlpconnect():

    @classmethod
    def __init__(self, model_name = "nlpconnect/vit-gpt2-image-captioning"):
        self.name = model_name
        self.model = VisionEncoderDecoderModel.from_pretrained(self.name)
        self.processor = ViTImageProcessor.from_pretrained(self.name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.name)

    @classmethod
    def generate_prompt(self, image):

        _image = MS_Image.tensor2pil(image)
        inputs = self.processor(images=_image, return_tensors="pt")
        # Generate the caption
        generated_ids = self.model.generate(
            inputs.pixel_values,
            max_length=16,
            num_beams=4,
            num_return_sequences=1
        )
        caption = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        return caption


# Facade used by the McBoaty/McPrompty nodes: captions a tile with the selected
# vision model, then optionally refines the caption through the Groq API
class MS_Llm():

    LLM_MODELS = [
        # 'gemma-7b-it',
        'llama3-70b-8192',
        # 'llama3-8b-8192',
        # 'mixtral-8x7b-32768',
    ]

    # list of models: https://huggingface.co/models?pipeline_tag=image-to-text&sort=downloads
    VISION_LLM_MODELS = [
        # 'nlpconnect/vit-gpt2-image-captioning',
        'microsoft/kosmos-2-patch14-224',
        # 'Salesforce/blip-image-captioning-large',
    ]

    @staticmethod
    def prestartup_script():
        folder_paths.add_model_folder_path("nlpconnect", os.path.join(folder_paths.models_dir, "nlpconnect"))

    @classmethod
    def __init__(self, vision_llm_name = "nlpconnect/vit-gpt2-image-captioning", llm_name = "llama3-8b-8192"):

        if vision_llm_name == 'microsoft/kosmos-2-patch14-224':
            self.vision_llm = MS_Llm_Microsoft()
        elif vision_llm_name == 'Salesforce/blip-image-captioning-large':
            self.vision_llm = MS_Llm_Salesforce()
        else:
            self.vision_llm = MS_Llm_Nlpconnect()

        self._groq_key = os.getenv("GROQ_API_KEY", "")
        self.llm = llm_name

    @classmethod
    def generate_tile_prompt(self, image, prompt_context, seed=None):
        prompt_tile = self.vision_llm.generate_prompt(image)
        if self.vision_llm.name == 'microsoft/kosmos-2-patch14-224':
            _prompt = self.get_grok_prompt(prompt_context, prompt_tile)
        else:
            _prompt = self.get_grok_prompt(prompt_context, prompt_tile)
        # Without a GROQ_API_KEY, fall back to the raw vision-model prompt
        if self._groq_key != "":
            prompt = self.call_grok_api(_prompt, seed)
        else:
            prompt = _prompt
        log(prompt, None, None, self.vision_llm.name)
        return prompt

    @classmethod
    def get_grok_prompt(self, prompt_context, prompt_tile):
        prompt = [
            f"tile_prompt: \"{prompt_tile}\".",
            f"full_image_prompt: \"{prompt_context}\".",
            "tile_prompt is part of full_image_prompt.",
            "If tile_prompt is describing something different than the full image, correct tile_prompt to match full_image_prompt.",
            "if you don't need to change the tile_prompt return the tile_prompt.",
            "your answer will strictly and only return the tile_prompt string without any decoration like markdown syntax."
        ]
        return " ".join(prompt)

    @classmethod
    def call_grok_api(self, prompt, seed=None):

        client = Groq(api_key=self._groq_key)  # Assuming the Groq client accepts an api_key parameter
        completion = client.chat.completions.create(
            model=self.llm,
            messages=[{
                "role": "user",
                "content": prompt
            }],
            temperature=1,
            max_tokens=1024,
            top_p=1,
            stream=False,
            stop=None,
            seed=seed,
        )

        return completion.choices[0].message.content
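
The MS_Llm class above is the glue for per-tile prompting: the selected vision model captions a tile, get_grok_prompt wraps that caption together with the full-image prompt into a correction instruction, and call_grok_api sends it to the chosen Groq model when a GROQ_API_KEY environment variable is set; without the key, the raw caption is returned unchanged. A minimal usage sketch (not part of the commit), assuming a ComfyUI-style image tensor of shape [batch, height, width, channels] with float values in 0..1 as a stand-in for a real tile:

import torch

# Hypothetical placeholder tile; in the actual workflow this would be a cropped
# tile of the upscaled image, not random noise.
tile = torch.rand(1, 512, 512, 3)

llm = MS_Llm(
    vision_llm_name="microsoft/kosmos-2-patch14-224",
    llm_name="llama3-70b-8192",
)
tile_prompt = llm.generate_tile_prompt(
    tile,
    prompt_context="a watercolor painting of a harbor at sunset",
    seed=42,
)
print(tile_prompt)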
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#
###
#
# Generate a prompt from an input image (McPrompty)
#
# Largely inspired by PYSSSSS - ShowText
#
###

from types import SimpleNamespace

from ...inc.lib.llm import MS_Llm

from ...utils.log import *


class PromptFromImage_v1:

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", {"label": "image"}),
                "vision_llm_model": (MS_Llm.VISION_LLM_MODELS, { "label": "Vision LLM Model", "default": "microsoft/kosmos-2-patch14-224" }),
                "llm_model": (MS_Llm.LLM_MODELS, { "label": "LLM Model", "default": "llama3-70b-8192" }),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
                "extra_pnginfo": "EXTRA_PNGINFO",
            },
        }

    INPUT_IS_LIST = False
    FUNCTION = "fn"
    OUTPUT_NODE = True
    OUTPUT_IS_LIST = (False,)
    CATEGORY = "MaraScott/Prompt"

    RETURN_TYPES = (
        "STRING",
    )
    RETURN_NAMES = (
        "Prompt",
    )

    @classmethod
    def fn(self, **kwargs):

        self.INPUTS = SimpleNamespace(
            image = kwargs.get('image', None)
        )
        self.LLM = SimpleNamespace(
            vision_model_name = kwargs.get('vision_llm_model', None),
            model_name = kwargs.get('llm_model', None),
            model = None,
        )
        self.LLM.model = MS_Llm(self.LLM.vision_model_name, self.LLM.model_name)

        # Caption the input image with the selected vision model
        self.OUTPUTS = SimpleNamespace(
            prompt = self.LLM.model.vision_llm.generate_prompt(self.INPUTS.image)
        )

        return {"ui": {"text": self.OUTPUTS.prompt}, "result": (self.OUTPUTS.prompt,)}