# Create a Concept Generator Service
This notebook loads the FLUX.1-schnell text-to-image model into the Snowflake Model Registry and deploys it as a service that can be called from outside this notebook.

# Step 1 - Load Packages and Model

In [None]:
# Install external python packages from huggingface - this gives us easy access to HF models
!pip install diffusers
!pip install huggingface_hub
!pip install sentencepiece

In [None]:
# Numerical / tabular / UI libraries
import numpy as np
import pandas as pd
import streamlit as st
import torch

# Hugging Face diffusers pipeline class for FLUX models
from diffusers import FluxPipeline

# Snowflake: custom-model wrapper, signature inference, model registry, and session access
from snowflake.ml.model import custom_model, model_signature
from snowflake.ml.registry import Registry
from snowflake.snowpark.context import get_active_session

# Reuse the active Snowpark session instead of creating a new connection.
session = get_active_session()

In [None]:
# Fetch the FLUX.1-schnell repository from the Hugging Face Hub. snapshot_download stores
# the files in the local Hugging Face cache and returns the path of that snapshot.
# The single-file checkpoint 'flux1-schnell.safetensors' is skipped; the pipeline is
# loaded from the repository's diffusers-format files instead.
from huggingface_hub import snapshot_download

model_checkpoint_path = snapshot_download(
    ignore_patterns=['flux1-schnell.safetensors'],
    repo_id='black-forest-labs/FLUX.1-schnell',
)

In [None]:
# Create a custom model class for the instantiation and inference of this model
class ImageGenerationModel(custom_model.CustomModel):
    """Snowflake custom model that wraps a diffusers FluxPipeline for FLUX.1-schnell.

    The ``model_path`` artifact of the context must point at a local, diffusers-format
    snapshot of the FLUX.1-schnell repository (e.g. the directory returned by
    ``huggingface_hub.snapshot_download``).
    """

    # Generation settings recommended on the FLUX.1-schnell model card. The model is
    # timestep-distilled, so it needs only a few denoising steps, does not use
    # classifier-free guidance, and supports prompts of at most 256 T5 tokens.
    # (FluxPipeline's own defaults - 28 steps, guidance 3.5, 512 tokens - target FLUX.1-dev.)
    NUM_INFERENCE_STEPS = 4
    GUIDANCE_SCALE = 0.0
    MAX_SEQUENCE_LENGTH = 256

    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        # local_files_only=True: load strictly from the packaged "model_path" artifact and
        # never contact the Hugging Face Hub at load time.
        self.pipeline = FluxPipeline.from_pretrained(
            context.path("model_path"),
            local_files_only=True,
            torch_dtype=torch.float16,
        ).to('cuda')

    @custom_model.inference_api
    def predict(self, prompt_df: pd.DataFrame) -> pd.DataFrame:
        """Generate one image per prompt.

        Args:
            prompt_df: DataFrame whose first column holds the text prompts.

        Returns:
            DataFrame with a single ``images`` column; each cell is one generated image
            converted with ``np.array(...).tolist()`` to nested lists of pixel values.
        """
        prompts = prompt_df.iloc[:, 0].tolist()
        if not prompts:
            # Avoid invoking the pipeline with an empty batch; keep the output schema.
            return pd.DataFrame({"images": []})
        images = self.pipeline(
            prompts,
            num_inference_steps=self.NUM_INFERENCE_STEPS,
            guidance_scale=self.GUIDANCE_SCALE,
            max_sequence_length=self.MAX_SEQUENCE_LENGTH,
        ).images
        return pd.DataFrame({"images": [np.array(image).tolist() for image in images]})

In [None]:
# Build the model locally from the downloaded snapshot. The artifact key "model_path" is the
# name ImageGenerationModel.__init__ looks up; constructing it loads the pipeline onto CUDA.
path_list = {"model_path": model_checkpoint_path}
model_context = custom_model.ModelContext(artifacts=path_list)
img_model = ImageGenerationModel(context=model_context)

In [None]:
# Run one local inference with a sample prompt; the resulting DataFrame is used in the
# next cell as the example output for signature inference.
TEST_PROMPT = 'Picture of a seaside village at night. Anime style'
test_input_df = pd.DataFrame([[TEST_PROMPT]])
img = img_model.predict(test_input_df)

In [None]:
# Derive the registry's input/output schema from the sample prompt and the DataFrame that
# the local predict() call returned for it.
signature = model_signature.infer_signature(
    input_data=pd.DataFrame([[TEST_PROMPT]]),
    output_data=img,
)

In [None]:
# Register the wrapped pipeline in the Model Registry of the session's current database and
# schema. No version_name is passed, so the registry generates one - look it up in the
# Model Registry UI and use it as MODEL_VERSION in Step 3.
reg = Registry(session)
mv = reg.log_model(
    img_model,
    model_name='FLUX_1_schnell',
    signatures={"predict": signature},
    conda_dependencies=[
        "transformers",
        "conda-forge::diffusers",
        "pytorch",
        "sentencepiece",
    ],
    options={"cuda_version": "11.8"},  # CUDA version for the GPU runtime environment
)

In [None]:
# Create the image repository and a GPU compute pool for the inference service.
# WARNING: destructive - an existing pool with the same name is stopped (STOP ALL deletes
# every service running on it) and dropped, so the pool is always recreated from scratch.

# Compute Pool definition
IMAGE_REPO_NAME = "CONCEPT_GEN_SERVICE_REPO"
COMPUTE_POOL_NAME = "CONCEPT_GEN_SERVICE_POOL_L"
COMPUTE_POOL_NODES = 1  # used as both min_nodes and max_nodes: a fixed-size pool
COMPUTE_POOL_INSTANCE_TYPE = 'GPU_NV_L'

for reset_sql in (
    f"create image repository if not exists {IMAGE_REPO_NAME}",
    f"alter compute pool if exists {COMPUTE_POOL_NAME} stop all",
    f"drop compute pool if exists {COMPUTE_POOL_NAME}",
):
    session.sql(reset_sql).collect()

# The pool starts suspended, resumes automatically when needed, and suspends again after
# 300 seconds of inactivity.
create_pool_sql = (
    f"create compute pool if not exists {COMPUTE_POOL_NAME} min_nodes={COMPUTE_POOL_NODES} "
    f"max_nodes={COMPUTE_POOL_NODES} instance_family={COMPUTE_POOL_INSTANCE_TYPE} "
    f"initially_suspended=True auto_resume=True auto_suspend_secs=300"
)
session.sql(create_pool_sql).collect()

In [None]:
# Deploy the logged model version as an inference service on the GPU compute pool.
# The container image is built into IMAGE_REPO_NAME. The build uses the external access
# integration named below, which this notebook does not create - it must already exist.
# Name of the Service for powering inference
SERVICE_NAME = 'CONCEPT_GEN_SERVICE'

service_options = dict(
    service_name=SERVICE_NAME,
    service_compute_pool=COMPUTE_POOL_NAME,
    image_repo=IMAGE_REPO_NAME,
    gpu_requests="1",
    ingress_enabled=True,
    max_instances=int(COMPUTE_POOL_NODES),  # at most one instance per pool node
    build_external_access_integration="ALLOW_ALL_INTEGRATION",
)
mv.create_service(**service_options)

# Step 2 - Call the New Service and Generate an Image

In [None]:
# Call the deployed service (not the local model) with the test prompt and take the
# image from the first row of the result; it arrives as nested lists of pixel values.
service_input = pd.DataFrame([[TEST_PROMPT]])
model_output = mv.run(service_input, service_name=SERVICE_NAME)
img = model_output.loc[0, "images"]

In [None]:
# Turn the nested-list pixel data back into a PIL image and render it.
from PIL import Image
import numpy as np

pixels = np.asarray(img, dtype=np.uint8)
img_final = Image.fromarray(pixels)
st.image(img_final)

# Step 3 - Create the end-to-end concept test generator
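This step retrieves the model version from the Model Registry again (a "cold call"), so it can run without the Python objects created earlier in this notebook.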

In [None]:
# Concept inputs: the prompt for the background image, and the brand name used later in
# the Cortex tagline prompt.
CONCEPT_PROMPT = 'a paper towel roll with christmas-style designs on the paper'
BRAND_DESCRIPTION = 'Charisma paper towels'

# Location of the registered model and the deployed service. These must match the
# database/schema the model was logged into and the service created in Step 1.
DATABASE_NAME = 'CONCEPT_GEN_DB'
SCHEMA_NAME = 'CONCEPT_GEN_SCHEMA'
SERVICE_NAME = 'CONCEPT_GEN_SERVICE'
SELECTED_MODEL = 'FLUX_1_SCHNELL'
MODEL_VERSION = 'HAPPY_RAY_4'  # auto-generated version name - copy yours from the Model Registry

In [None]:
import json
import requests
import pandas as pd
import numpy as np
import streamlit as st
import snowflake.snowpark as snowpark
from PIL import Image, ImageDraw, ImageFont
from snowflake.cortex import Complete
from snowflake.ml.registry import registry
from snowflake.snowpark.context import get_active_session


session = get_active_session()
reg = registry.Registry(session=session, database_name=DATABASE_NAME, schema_name=SCHEMA_NAME)
registered_model = reg.get_model(SELECTED_MODEL)
# Use the pinned version when MODEL_VERSION is set. Otherwise (None / empty) fall back to the
# model's default version, so the notebook runs without copying an auto-generated name.
mv = registered_model.version(MODEL_VERSION) if MODEL_VERSION else registered_model.default

In [None]:
# Ask the deployed FLUX service for the concept's background image; the first row of
# the result holds it as nested lists of pixel values.
concept_input = pd.DataFrame([[CONCEPT_PROMPT]])
model_output = mv.run(concept_input, service_name=SERVICE_NAME)
img = model_output.loc[0, "images"]

In [None]:
# Import the brand logo (could be part of a Snowflake stage as well)
import io

img_url = 'https://raw.githubusercontent.com/sfc-gh-pnanisetty/concept-generator-service/refs/heads/main/charisma_paper_towels.png'
# Download the whole body with a timeout and fail loudly on HTTP errors. Handing PIL the
# raw socket stream (response.raw) skips both the status check and content decoding, so a
# 404 page or a compressed response surfaces only as an opaque "cannot identify image file".
logo_response = requests.get(img_url, timeout=30)
logo_response.raise_for_status()
img_logo = Image.open(io.BytesIO(logo_response.content))

In [None]:
# Convert the service's nested-list pixel data into a PIL image and display it.
unprocessed_img = model_output["images"][0]
img_background = Image.fromarray(np.array(unprocessed_img, dtype=np.uint8))
st.image(img_background)

In [None]:
# Overlay the logo on the top-left corner of a copy of the background, using the logo's
# alpha channel as the paste mask so its transparent areas stay see-through.
# Convert to RGBA first: Image.paste only accepts masks of mode "1", "L", "LA", "RGBA" or
# "RGBa", so a logo that decodes as RGB or palette ("P") would otherwise raise.
logo_rgba = img_logo.convert("RGBA")
img_background_temp = img_background.copy()
img_background_temp.paste(logo_rgba, (0, 0), mask=logo_rgba)
st.image(img_background_temp)

In [None]:
# Create a witty marketing tagline with Cortex
instruction = 'Please provide a witty advertising tagline that will be displayed at the bottom of the image described below. ' \
              'Please do not provide any additional text, json, or descriptions, just the tagline only. If the tagline is more than ' \
              '10 words long please insert a new line character'
prompt = BRAND_DESCRIPTION + ' brand with a background of ' + CONCEPT_PROMPT
cortex_prompt = [
    {'role': 'system', 'content': instruction},
    {'role': 'user', 'content': prompt},
]
raw_response = Complete('llama3.1-70b', cortex_prompt)
# For a list of chat messages the response is expected to be a JSON document with the text
# in choices[0]["messages"]. If it is not in that shape (e.g. plain text), use it verbatim
# rather than crashing.
try:
    tagline = json.loads(raw_response)['choices'][0]['messages']
except (json.JSONDecodeError, TypeError, KeyError, IndexError):
    tagline = raw_response
# Remove quotes and surrounding whitespace/newlines; otherwise they would be drawn onto
# the image and skew the tagline's centering.
tagline = tagline.replace('"', '').strip()
tagline

In [None]:
# Add the tagline to our image: extend the canvas with a white banner at the bottom and
# center the (possibly multi-line) tagline inside it.
BANNER_MIN_HEIGHT = 100  # px; the banner never shrinks below the original fixed height
BANNER_PADDING = 10      # px kept free above and below the text
TAGLINE_FONT_SIZE = 30

width, height = img_background_temp.size
font = ImageFont.load_default(size=TAGLINE_FONT_SIZE)

# Measure the text first (on a throwaway canvas) so the banner can grow to fit taglines
# that the prompt asked to split over several lines.
measure_draw = ImageDraw.Draw(Image.new(mode='RGB', size=(1, 1)))
left, top, right, bottom = measure_draw.multiline_textbbox((0, 0), tagline, font=font, align='center')
text_w, text_h = right - left, bottom - top
banner_h = max(BANNER_MIN_HEIGHT, text_h + 2 * BANNER_PADDING)

# Add a larger canvas to add a tagline at the bottom
full_concept_img = Image.new(mode='RGB', size=(width, height + banner_h), color=(255, 255, 255))
full_concept_img.paste(img_background_temp, (0, 0))

# Subtract the bbox origin offsets so the glyphs themselves (not the anchor) are centered.
draw = ImageDraw.Draw(full_concept_img)
text_x = (width - text_w) / 2 - left
text_y = height + (banner_h - text_h) / 2 - top
draw.multiline_text((text_x, text_y), tagline, font=font, fill='black', align='center')

In [None]:
# Render the finished concept: generated background, brand logo, and tagline banner.
st.image(image=full_concept_img)