<a href="https://colab.research.google.com/github/LucianoRodriguez0764/ipadapter_faceid_colab/blob/main/Ipadapter_faceID_Latest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install dependencies and download the IP-Adapter-FaceID-Plus-v2 model

In [71]:
from IPython.display import clear_output as clear_output_py

# Install the third-party packages that the later cells import.
# NOTE(review): no versions are pinned, so a fresh runtime may pull newer releases than the ones this notebook was written against.
!pip install insightface
!pip install onnxruntime
!pip install diffusers
!pip install git+https://github.com/tencent-ailab/IP-Adapter.git
!pip install einops

# Wipe the installer logs from the cell output (this also hides any pip error messages).
clear_output_py()

In [None]:
# Download the IP-Adapter FaceID Plus v2 (SD 1.5) weights into the working directory, under the file name given to -O.
!wget -O ip-adapter-faceid-plusv2_sd15.bin https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15.bin

# Wipe the wget progress log from the cell output.
clear_output_py()

# Load the models to the pipeline

In [None]:
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
from PIL import Image as PILImage

# --- Model selection -------------------------------------------------------
# `v2` picks which adapter checkpoint file name is used further down.
v2 = True
base_model_path = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
vae_model_path = "stabilityai/sd-vae-ft-mse"
image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
if v2:
    ip_ckpt = "ip-adapter-faceid-plusv2_sd15.bin"
else:
    ip_ckpt = "ip-adapter-faceid-plus_sd15.bin"
device = "cuda"

# --- Scheduler, VAE and base pipeline --------------------------------------
noise_scheduler = DDIMScheduler(
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
)

# The VAE is loaded separately and cast to half precision, matching the
# torch_dtype requested for the pipeline below.
vae = AutoencoderKL.from_pretrained(vae_model_path)
vae = vae.to(dtype=torch.float16)

# The pipeline is built without a feature extractor and without a safety checker.
pipe = StableDiffusionPipeline.from_pretrained(
    base_model_path,
    vae=vae,
    scheduler=noise_scheduler,
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
)

# Hide the download/loading logs.
clear_output_py()

# Load Face embeds and FaceID Image

In [None]:
import os

import shutil

# Root of the working image folders.
# It is relative so that the folder wiped below is always the parent of the two
# folders created afterwards. The previous absolute "/content/images" only
# matched the relative "images/..." paths when the working directory was
# /content.
folder_path = "images"

# Start from an empty folder tree on every run of this cell.
# NOTE: this deletes any images previously placed in the folders.
try:
    shutil.rmtree(folder_path)
    print(f"Deleted folder: {folder_path}")
except FileNotFoundError:
    print(f"Folder not found: {folder_path}")
except OSError as e:
    # Best effort: report the problem and carry on to (re)create the folders.
    print(f"An error occurred: {e}")

# images/embeds -> face pictures averaged into the face-ID embedding
# images/face   -> the single picture used as the face-ID image
image_path = os.path.join(folder_path, "embeds")
image_path_face = os.path.join(folder_path, "face")

os.makedirs(image_path, exist_ok=True)
os.makedirs(image_path_face, exist_ok=True)

clear_output_py()

In [None]:
import cv2
from insightface.app import FaceAnalysis
from insightface.utils import face_align
import torch
import glob
from PIL import Image as PILImage

# Two detectors over the same model pack: the 640x640 one is tried first and
# the 320x320 one is the fallback when no face is found.
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

app_lower = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app_lower.prepare(ctx_id=0, det_size=(320, 320))

clear_output_py()


def _detect_faces(image):
    """Run face detection at 640x640 and retry at 320x320 when nothing is found.

    Args:
        image: the array returned by ``cv2.imread``.

    Returns:
        Whatever ``FaceAnalysis.get`` returned; falsy when both detectors fail.
    """
    faces = app.get(image)
    if not faces:
        print("face not detected with 640x640, trying with 320x320")
        faces = app_lower.get(image)
    return faces


imgs = glob.glob(os.path.join(image_path, '*.*'))

# One embedding tensor per image in which a face was found.
embeddings = []

if not imgs:
    print("You need to put face images on images/embeds for face-embedding and one image of face on images/face for face-id")

for img_file in imgs:
    print(img_file)
    image = cv2.imread(img_file)
    if image is None:
        # cv2.imread returns None instead of raising when a file cannot be
        # decoded; skip it rather than crash inside the detector.
        print("could not read image, skipping.")
        continue

    faces = _detect_faces(image)
    if not faces:
        print("face not detected.")
        continue

    # Only the first detected face of each picture is used.
    face_image = face_align.norm_crop(image, landmark=faces[0].kps, image_size=224)
    # The crop is an OpenCV (BGR) array. Convert a copy to an RGB PIL image so
    # the notebook shows the picture instead of the array's text repr.
    # `face_image` itself stays an array.
    display(PILImage.fromarray(cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)))

    faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
    embeddings.append(faceid_embed)

# Calculate the average embedding for the face ID
# Combine the per-image embeddings into the single `faceid_embeds` tensor that
# the generation cells read.
use_image_weights = False
weights = []  # optional per-image weights, in the same order as `embeddings`

# Normalise the weights so they sum to 1. An empty or all-zero list is left
# untouched, and the plain mean is used below.
weight_sum = sum(weights)
if weight_sum:
    weights = [elem / weight_sum for elem in weights]

if embeddings:
    if use_image_weights and weight_sum and len(weights) == len(embeddings):
        # Weighted average. The weights already sum to 1, so the weighted sum
        # is the result. The previous code stored it in `faceid_embeds_avg`
        # (never read anywhere) after dividing by `len(embeddings[i])`, which
        # left `faceid_embeds` undefined on this path.
        total = torch.zeros_like(embeddings[0])
        for embed, weight in zip(embeddings, weights):
            total.add_(embed * weight)
        faceid_embeds = total
    else:
        # Plain mean. Also used when the weights are missing or do not match
        # the number of embeddings.
        faceid_embeds = torch.mean(torch.stack(embeddings), dim=0)

    print("Collected face ID embedding")
else:
    print("No faces detected in the images provided.")

In [None]:
# Load the face-ID picture: the first file found in the face folder.
imgs_2 = glob.glob(os.path.join(image_path_face, '*.*'))
if not imgs_2:
    # Fail with an actionable message instead of an IndexError on imgs_2[0].
    raise FileNotFoundError(f"No image found in {image_path_face}: put one face image there for face-id")

image1 = cv2.imread(imgs_2[0])
if image1 is None:
    # cv2.imread returns None instead of raising when a file cannot be decoded.
    raise ValueError(f"Could not read {imgs_2[0]} as an image")

# Try the 640x640 detector first, then fall back to the 320x320 one.
faces1 = app.get(image1)
if not faces1:
    print("face not detected with 640x640, trying with 320x320")
    faces1 = app_lower.get(image1)

if faces1:
    print("face image detected.")
    face_image = face_align.norm_crop(image1, landmark=faces1[0].kps, image_size=224)
    # Show an RGB copy; `face_image` itself stays the OpenCV (BGR) array.
    display(PILImage.fromarray(cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)))
else:
    # NOTE(review): `face_image` is not assigned on this path, so it keeps
    # whatever an earlier cell left in it (if anything).
    print("face not detected.")

# Load the functions

In [None]:
from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus
import ipywidgets as widgets
import math

# Wrap the Stable Diffusion pipeline with the FaceID Plus adapter, using the
# image encoder, checkpoint file and device chosen in the model-loading cell.
ip_model = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_ckpt, device)

# Hide the loading logs.
clear_output_py()

# Output areas used by generate_images(): one for its printed text, one for the image rows.
image_output = widgets.Output()
text_output = widgets.Output()

# UNUSED
def display_images(images):
    """Show the given images side by side in a single horizontal row.

    Each image is converted to PNG bytes through its ``_repr_png_`` method and
    wrapped in an image widget.
    """
    png_widgets = []
    for picture in images:
        png_widgets.append(widgets.Image(value=picture._repr_png_()))
    row = widgets.HBox(png_widgets)
    display(row)

def _set_plot_parameter(name, value):
    """Point the module-level generation setting ``name`` at one sweep value.

    For ``face_image`` and ``faceid_embeds`` the sweep value is the *name* of
    another module-level variable, which is looked up and assigned (a missing
    name raises ``KeyError``). Any other value is assigned as given.
    """
    if name in ("face_image", "faceid_embeds"):
        globals()[name] = globals()[value]
        print(name, "set to:", value)
    else:
        globals()[name] = value
        print(name, "set to:", globals()[name])


def _run_generation(samples):
    """Call ``ip_model.generate`` with the current module-level settings.

    Args:
        samples: value passed as ``num_samples``.

    Returns:
        The images returned by ``ip_model.generate``.
    """
    return ip_model.generate(
        prompt=prompt+", "+extra_prompt, negative_prompt=negative_prompt,
        face_image=face_image, faceid_embeds=faceid_embeds,
        num_samples=samples, width=width, height=height,
        num_inference_steps=num_inference_steps, seed=seed,
        guidance_scale=guidance_scale, shortcut=shortcut, s_scale=s_scale,
    )


def _image_row(images):
    """Wrap the images in one horizontal box of PNG image widgets."""
    return widgets.HBox([widgets.Image(value=img._repr_png_()) for img in images])


def generate_images():
    """Generate images from the module-level settings and show them.

    Three modes, chosen by ``x_plot_generation`` / ``y_plot_generation``:

    * no plot: one call producing ``num_samples`` images in a single row;
    * X plot: one image per value of ``x_plot_values``, all in one row;
    * X and Y plot: one row per value of ``y_plot_values``, one image per
      value of ``x_plot_values`` in each row.

    Printed information goes to ``text_output`` and the image rows to
    ``image_output``. A seed of -1 is replaced by a random one, and 0 inference
    steps by 1.

    Fixes over the previous version:

    * the X-only sweep passed ``shortcut=True`` regardless of the ``shortcut``
      setting; all modes now pass ``shortcut``;
    * a sweep left the swept settings at their last value and forced the
      module-level ``num_samples`` to 1, silently affecting later runs; the
      swept settings are now restored and ``num_samples`` is left untouched.
    """
    global num_inference_steps, seed

    if num_inference_steps == 0:
        num_inference_steps = 1

    if seed == -1:
        seed = random.randint(1, 999999999)

    rows = []

    with text_output:
        print("samp:",num_samples, "steps:",num_inference_steps, "gui:",guidance_scale,"seed:", seed,"w:", width, "h:",height,"s_sc:", s_scale,
                "v2:",shortcut,"xplot:",x_plot_generation,"yplot:",y_plot_generation,"x_par:",x_plot_parameter,"y_par:",y_plot_parameter,"x_vals:",x_plot_values,"y_vals:",y_plot_values,"xy_rnd:",xy_different_seeds,"prompt:",prompt,", neg prompt:", negative_prompt)

        # A sweep produces exactly one image per cell of the plot.
        samples = 1 if x_plot_generation else num_samples

        # Remember the settings a sweep is about to overwrite so they can be
        # put back afterwards. The Y sweep only runs together with the X sweep.
        swept_names = []
        if x_plot_generation:
            swept_names.append(x_plot_parameter)
            if y_plot_generation:
                swept_names.append(y_plot_parameter)
        saved = {name: globals()[name] for name in swept_names if name in globals()}

        try:
            if x_plot_generation and y_plot_generation:
                for y_value in y_plot_values:
                    if y_plot_parameter in globals():
                        _set_plot_parameter(y_plot_parameter, y_value)
                    images = []
                    for column, x_value in enumerate(x_plot_values):
                        _set_plot_parameter(x_plot_parameter, x_value)
                        # The first column keeps the current seed.
                        if xy_different_seeds and column != 0:
                            seed = random.randint(1, 999999999)
                        print("Seed: ", seed)
                        images.append(_run_generation(samples)[0])
                    rows.append(_image_row(images))
            elif x_plot_generation:
                images = []
                for column, x_value in enumerate(x_plot_values):
                    _set_plot_parameter(x_plot_parameter, x_value)
                    if xy_different_seeds and column != 0:
                        seed = random.randint(1, 999999999)
                    print("Seed: ", seed)
                    images.extend(_run_generation(samples))
                rows.append(_image_row(images))
            else:
                print("Seed: ", seed)
                rows.append(_image_row(_run_generation(samples)))
        finally:
            # Put the swept settings back, even if generation failed part-way.
            globals().update(saved)

    with image_output:
        for row in rows:
            display(row)

In [65]:
import ipywidgets as widgets
from IPython.display import display

# Look shared by the controls of this UI: 200px description column, 500px total width.
slider_style = dict(description_width='200px')
slider_layout = widgets.Layout(width='500px')

# Positive-prompt entry: a text box plus the button that turns its content into tags.
text_input = widgets.Text(
    placeholder="Add prompts separated by commas",
    style=slider_style,
    layout=slider_layout,
)
add_button = widgets.Button(description="Add tag")

# Lists filled by the add_tag* callbacks: prompt tags, negative-prompt tags
# and the raw X / Y plot values.
prompt_array, neg_prompt_array = [], []
x_parameter_value_array, y_parameter_value_array = [], []

# Row holding one small widget per positive-prompt tag.
tags_box = widgets.HBox([])

def add_tag(b):
    """Turn the comma-separated text of ``text_input`` into prompt tags.

    Every non-empty, stripped piece is appended to ``prompt_array`` and shown
    in ``tags_box`` as a bold label with an "x" button that removes it again.
    The text box is cleared afterwards.

    Args:
        b: the clicked button (unused).
    """
    import html  # local import so this cell does not depend on another cell's imports

    for tag in (piece.strip() for piece in text_input.value.split(",")):
        if tag == "":
            continue
        # Escape the text: the label is an HTML widget, so characters such as
        # "<" or "&" would otherwise be interpreted as markup and not shown.
        label = widgets.HTML(value=f"<b>{html.escape(tag)}</b>")
        remove_button = widgets.Button(description="x", layout=widgets.Layout(width='30px'))
        tag_wi = widgets.HBox([label, remove_button])
        prompt_array.append(tag)

        tags_box.children = list(tags_box.children) + [tag_wi]
        # Bind the current widget and tag as default arguments so each button
        # removes its own tag.
        remove_button.on_click(lambda b, tag_wi=tag_wi, tag=tag: remove_tag(tag_wi,tag))
    text_input.value = ""

def remove_tag(tag_wi,tag):
    """Remove one prompt tag: its widget from ``tags_box`` and its entry from ``prompt_array``.

    Only the entry belonging to the clicked widget is removed. The previous
    version dropped every equal entry, so removing one of two identical tags
    left the other one visible but no longer part of the prompt.

    Args:
        tag_wi: the tag's widget inside ``tags_box``.
        tag: the tag text.
    """
    global prompt_array
    shown = list(tags_box.children)
    position = shown.index(tag_wi) if tag_wi in shown else -1
    tags_box.children = [child for child in shown if child != tag_wi]

    remaining = list(prompt_array)
    if 0 <= position < len(remaining) and remaining[position] == tag:
        # Widgets and entries are appended together, so they normally share a position.
        del remaining[position]
    elif tag in remaining:
        # Fallback if the two got out of step: drop the first matching entry.
        remaining.remove(tag)
    prompt_array = remaining

# Clicking the "Add tag" button runs add_tag on the current content of the prompt text box.
add_button.on_click(add_tag)

######## NEGATIVE PROMPTS ###########

# Negative-prompt entry: text box, "Add tag" button and the row that shows the added tags.
text_input_neg = widgets.Text(placeholder="Add negative prompts separated by commas",style=slider_style, layout=slider_layout)
add_button_neg = widgets.Button(description="Add tag")

tags_box_neg = widgets.HBox([])

def add_tag_neg(b):
    """Turn the comma-separated text of ``text_input_neg`` into negative-prompt tags.

    Every non-empty, stripped piece is appended to ``neg_prompt_array`` and
    shown in ``tags_box_neg`` as a bold label with an "x" button that removes
    it again. The text box is cleared afterwards.

    Args:
        b: the clicked button (unused).
    """
    import html  # local import so this cell does not depend on another cell's imports

    for tag in (piece.strip() for piece in text_input_neg.value.split(",")):
        if tag == "":
            continue
        # Escape the text: the label is an HTML widget, so characters such as
        # "<" or "&" would otherwise be interpreted as markup and not shown.
        label = widgets.HTML(value=f"<b>{html.escape(tag)}</b>")
        # CSS margin values are separated by spaces; the previous
        # comma-separated string was not a valid margin.
        remove_button = widgets.Button(description="x", layout=widgets.Layout(margin="0px 0px 0px 5px",width='30px'))
        tag_wi = widgets.HBox([label, remove_button])
        neg_prompt_array.append(tag)

        tags_box_neg.children = list(tags_box_neg.children) + [tag_wi]
        # Bind the current widget and tag as default arguments so each button
        # removes its own tag.
        remove_button.on_click(lambda b, tag_wi=tag_wi, tag=tag: remove_tag_neg(tag_wi,tag))
    text_input_neg.value = ""


def remove_tag_neg(tag_wi,tag):
    """Remove one negative-prompt tag: its widget from ``tags_box_neg`` and its entry from ``neg_prompt_array``.

    Only the entry belonging to the clicked widget is removed. The previous
    version dropped every equal entry, so removing one of two identical tags
    left the other one visible but no longer part of the negative prompt.

    Args:
        tag_wi: the tag's widget inside ``tags_box_neg``.
        tag: the tag text.
    """
    global neg_prompt_array
    shown = list(tags_box_neg.children)
    position = shown.index(tag_wi) if tag_wi in shown else -1
    tags_box_neg.children = [child for child in shown if child != tag_wi]

    remaining = list(neg_prompt_array)
    if 0 <= position < len(remaining) and remaining[position] == tag:
        # Widgets and entries are appended together, so they normally share a position.
        del remaining[position]
    elif tag in remaining:
        # Fallback if the two got out of step: drop the first matching entry.
        remaining.remove(tag)
    neg_prompt_array = remaining

# Clicking the negative-prompt "Add tag" button runs add_tag_neg on the current content of its text box.
add_button_neg.on_click(add_tag_neg)


########### x plot generation ###########


# X-plot value entry: text box, "Add value" button and the row that shows the added values.
# Text box and button are created disabled.
text_input_x = widgets.Text(placeholder="X Parameter values", disabled=True)
add_button_x = widgets.Button(description="Add value", disabled=True)

tags_box_x = widgets.HBox([])

def add_tag_x(b):
    """Turn the comma-separated text of ``text_input_x`` into X-plot values.

    Every non-empty, stripped piece is appended (as text) to
    ``x_parameter_value_array`` and shown in ``tags_box_x`` as a bold label
    with an "x" button that removes it again. The text box is cleared
    afterwards.

    Args:
        b: the clicked button (unused).
    """
    import html  # local import so this cell does not depend on another cell's imports

    for tag in (piece.strip() for piece in text_input_x.value.split(",")):
        if tag == "":
            continue
        # Escape the text: the label is an HTML widget, so characters such as
        # "<" or "&" would otherwise be interpreted as markup and not shown.
        label = widgets.HTML(value=f"<b>{html.escape(tag)}</b>")
        remove_button = widgets.Button(description="x", layout=widgets.Layout(width='30px'))
        tag_wi = widgets.HBox([label, remove_button])
        x_parameter_value_array.append(tag)
        tags_box_x.children = list(tags_box_x.children) + [tag_wi]
        # Bind the current widget and value as default arguments so each
        # button removes its own value.
        remove_button.on_click(lambda b, tag_wi=tag_wi, tag=tag: remove_tag_x(tag_wi,tag))
    text_input_x.value = ""



def remove_tag_x(tag_wi,tag):
    """Remove one X-plot value: its widget from ``tags_box_x`` and its entry from ``x_parameter_value_array``.

    The list may hold the values as typed (text) or already converted to
    numbers, so an entry matches when it equals the tag as text or as a
    number. The previous version converted the tag with ``int``/``float``
    first, which never matched values still stored as text and raised
    ``ValueError`` for non-numeric values. Only one entry is removed per call.

    Args:
        tag_wi: the value's widget inside ``tags_box_x``.
        tag: the value text as it was typed.
    """
    global x_parameter_value_array

    def _same_value(stored):
        # Text match first, then numeric match (e.g. 7.0 against "7").
        if stored == tag or str(stored) == tag:
            return True
        try:
            return float(stored) == float(tag)
        except (TypeError, ValueError):
            return False

    shown = list(tags_box_x.children)
    position = shown.index(tag_wi) if tag_wi in shown else -1
    tags_box_x.children = [child for child in shown if child != tag_wi]

    remaining = list(x_parameter_value_array)
    if 0 <= position < len(remaining) and _same_value(remaining[position]):
        # Widgets and entries are appended together, so they normally share a position.
        del remaining[position]
    else:
        # Fallback if the two got out of step: drop the first matching entry.
        for index, stored in enumerate(remaining):
            if _same_value(stored):
                del remaining[index]
                break
    x_parameter_value_array = remaining

# Clicking the X-plot "Add value" button runs add_tag_x on the current content of its text box.
add_button_x.on_click(add_tag_x)

########### y plot generation ###########


# Y-plot value entry: text box, "Add value" button and the row that shows the added values.
# Text box and button are created disabled.
text_input_y = widgets.Text(placeholder="Y Parameter values", disabled=True)
add_button_y = widgets.Button(description="Add value", disabled=True)

tags_box_y = widgets.HBox([])

def add_tag_y(b):
    """Turn the comma-separated text of ``text_input_y`` into Y-plot values.

    Every non-empty, stripped piece is appended (as text) to
    ``y_parameter_value_array`` and shown in ``tags_box_y`` as a bold label
    with an "x" button that removes it again. The text box is cleared
    afterwards.

    Args:
        b: the clicked button (unused).
    """
    import html  # local import so this cell does not depend on another cell's imports

    for tag in (piece.strip() for piece in text_input_y.value.split(",")):
        if tag == "":
            continue
        # Escape the text: the label is an HTML widget, so characters such as
        # "<" or "&" would otherwise be interpreted as markup and not shown.
        label = widgets.HTML(value=f"<b>{html.escape(tag)}</b>")
        remove_button = widgets.Button(description="x", layout=widgets.Layout(width='30px'))
        tag_wi = widgets.HBox([label, remove_button])
        y_parameter_value_array.append(tag)
        tags_box_y.children = list(tags_box_y.children) + [tag_wi]
        # Bind the current widget and value as default arguments so each
        # button removes its own value.
        remove_button.on_click(lambda b, tag_wi=tag_wi, tag=tag: remove_tag_y(tag_wi,tag))
    text_input_y.value = ""

def remove_tag_y(tag_wi,tag):
    """Remove one Y-plot value: its widget from ``tags_box_y`` and its entry from ``y_parameter_value_array``.

    The list may hold the values as typed (text) or already converted to
    numbers, so an entry matches when it equals the tag as text or as a
    number. The previous version converted the tag with ``int``/``float``
    first, which never matched values still stored as text and raised
    ``ValueError`` for non-numeric values. Only one entry is removed per call.

    Args:
        tag_wi: the value's widget inside ``tags_box_y``.
        tag: the value text as it was typed.
    """
    global y_parameter_value_array

    def _same_value(stored):
        # Text match first, then numeric match (e.g. 7.0 against "7").
        if stored == tag or str(stored) == tag:
            return True
        try:
            return float(stored) == float(tag)
        except (TypeError, ValueError):
            return False

    shown = list(tags_box_y.children)
    position = shown.index(tag_wi) if tag_wi in shown else -1
    tags_box_y.children = [child for child in shown if child != tag_wi]

    remaining = list(y_parameter_value_array)
    if 0 <= position < len(remaining) and _same_value(remaining[position]):
        # Widgets and entries are appended together, so they normally share a position.
        del remaining[position]
    else:
        # Fallback if the two got out of step: drop the first matching entry.
        for index, stored in enumerate(remaining):
            if _same_value(stored):
                del remaining[index]
                break
    y_parameter_value_array = remaining

# Clicking the Y-plot "Add value" button runs add_tag_y on the current content of its text box.
add_button_y.on_click(add_tag_y)

# Generate images with the IP-Adapter

In [None]:
from re import X
import random
import ipywidgets as widgets
from IPython.display import display, clear_output


#### DEFAULT SETTINGS ####

# Generation settings read by generate_images(); the numeric values match the
# initial values of the corresponding controls created further down.
width=512
height=512
num_samples=1
num_inference_steps=30
guidance_scale = 3
s_scale = 1
# X/Y plot state: whether each sweep is enabled and the values to sweep over.
x_plot_generation = False
y_plot_generation = False
x_plot_values = []
y_plot_values = []
# Fresh, empty tag rows and tag lists, so every run of this cell starts without tags.
# NOTE(review): these rebind names that the tag-widget cell also defines.
tags_box = widgets.HBox([])
tags_box_neg = widgets.HBox([])
tags_box_x = widgets.HBox([])
tags_box_y = widgets.HBox([])
prompt_array = []
neg_prompt_array = []
x_parameter_value_array = []
y_parameter_value_array = []
# Names of the settings swept along each plot axis (initial selector values).
x_plot_parameter = "num_inference_steps"
y_plot_parameter = "guidance_scale"
# Fixed text placed in front of the user's tags when the prompts are built.
initial_prompt = ""
# Appended to the prompt by generate_images(); also selectable as a plot parameter.
extra_prompt = ""
initial_negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, nfsw, extra legs, fused fingers, too many fingers, long neck"
# Passed as `shortcut` to ip_model.generate (printed under the label "v2:").
shortcut = True
xy_different_seeds = False
# NOTE(review): `face_id_embeds` is not read anywhere in the visible code (the
# generation code uses `faceid_embeds`); possibly a typo — confirm before removing.
face_id_embeds = faceid_embeds
# Self-assignment: changes nothing, but raises NameError if `face_image` was never defined.
face_image = face_image
# -1 makes generate_images() pick a random seed.
seed = -1

# Output size: two dropdowns offering the same list of pixel sizes.
width_selector = widgets.Dropdown(
    description='Width:',
    value=512,
    options=[256, 288, 384, 512, 640, 768, 1024],
    layout=slider_layout,
    style=slider_style,
)

height_selector = widgets.Dropdown(
    description='Height:',
    value=512,
    options=[256, 288, 384, 512, 640, 768, 1024],
    layout=slider_layout,
    style=slider_style,
)


def on_size_change(change):
    """Copy the current value of both size dropdowns into ``width`` and ``height``."""
    global width, height
    width, height = width_selector.value, height_selector.value


# The same callback serves both dropdowns.
height_selector.observe(on_size_change, names='value')
width_selector.observe(on_size_change, names='value')

# Number of inference steps: 0 to 100 in steps of 5, starting at 30.
num_inference_steps_slider = widgets.IntSlider(
    description='Num inference steps:',
    value=30,
    min=0,
    max=100,
    step=5,
    layout=slider_layout,
    style=slider_style,
)


def on_num_inference_steps_scale_change(change):
    """Keep the module-level ``num_inference_steps`` equal to the slider value."""
    global num_inference_steps
    current_steps = num_inference_steps_slider.value
    num_inference_steps = current_steps


num_inference_steps_slider.observe(on_num_inference_steps_scale_change, names='value')

# Number of images per generation: 1 to 6, starting at 1.
num_samples_slider = widgets.IntSlider(
    description='Num samples:',
    value=1,
    min=1,
    max=6,
    step=1,
    layout=slider_layout,
    style=slider_style,
)


def on_num_samples_scale_change(change):
    """Keep the module-level ``num_samples`` equal to the slider value."""
    global num_samples
    current_samples = num_samples_slider.value
    num_samples = current_samples


num_samples_slider.observe(on_num_samples_scale_change, names='value')

# Guidance scale: 0.5 to 15 in steps of 0.5, starting at 3.
guidance_scale_slider = widgets.FloatSlider(
    description='Guidance Scale:',
    value=3,
    min=0.5,
    max=15,
    step=0.5,
    layout=slider_layout,
    style=slider_style,
)


def on_guidance_scale_change(change):
    """Keep the module-level ``guidance_scale`` equal to the slider value."""
    global guidance_scale
    current_scale = guidance_scale_slider.value
    guidance_scale = current_scale


guidance_scale_slider.observe(on_guidance_scale_change, names='value')

# Seed: -1 up to 999999999, starting at -1 (which generate_images() replaces
# with a random seed).
seed_slider = widgets.IntSlider(
    description='Seed:',
    value=-1,
    min=-1,
    max=999999999,
    step=1,
    layout=slider_layout,
    style=slider_style,
)


def on_seed_change(change):
    """Keep the module-level ``seed`` equal to the slider value."""
    global seed
    chosen_seed = seed_slider.value
    seed = chosen_seed


seed_slider.observe(on_seed_change, names='value')

# s_scale: 0.1 to 3 in steps of 0.1, starting at 1.
s_scale_slider = widgets.FloatSlider(
    description='s_scale:',
    value=1,
    min=0.1,
    max=3,
    step=0.1,
    layout=slider_layout,
    style=slider_style,
)


def on_s_scale_change(change):
    """Keep the module-level ``s_scale`` equal to the slider value."""
    global s_scale
    current_scale = s_scale_slider.value
    s_scale = current_scale


s_scale_slider.observe(on_s_scale_change, names='value')

# Which setting each plot axis sweeps. Both dropdowns are created disabled.
x_plot_selector = widgets.Dropdown(
    description='x plot parameter:',
    value='num_inference_steps',
    options=['num_inference_steps', 'guidance_scale','s_scale','extra_prompt','face_image'],
    disabled=True,
    layout=slider_layout,
    style=slider_style,
)

y_plot_selector = widgets.Dropdown(
    description='y plot parameter:',
    value='guidance_scale',
    options=['num_inference_steps', 'guidance_scale','s_scale','extra_prompt','face_image'],
    disabled=True,
    layout=slider_layout,
    style=slider_style,
)

# When ticked, plot images other than the first of a row get a fresh random seed.
xy_randomseed_checkbox = widgets.Checkbox(
    description='random seed when XY plot',
    value=False,
    indent=True,
    disabled=False,
    layout=slider_layout,
    style=slider_style,
)

def y_plot_on_change(change):
    """Store the Y-axis dropdown's current choice in ``y_plot_parameter``."""
    global y_plot_parameter
    selected = y_plot_selector.value
    y_plot_parameter = selected

def x_plot_on_change(change):
    """Store the X-axis dropdown's current choice in ``x_plot_parameter``."""
    global x_plot_parameter
    selected = x_plot_selector.value
    x_plot_parameter = selected

def xy_randomseed_checkbox_on_change(change):
    """Store the state of the random-seed checkbox in ``xy_different_seeds``."""
    global xy_different_seeds
    ticked = xy_randomseed_checkbox.value
    xy_different_seeds = ticked

# Checkboxes that switch the X and Y sweeps on and off.
x_plot_checkbox = widgets.Checkbox(
    description='x plot generation',
    value=False,
    indent=True,
    disabled=False,
    layout=slider_layout,
    style=slider_style,
)

y_plot_checkbox = widgets.Checkbox(
    description='y plot generation',
    value=False,
    indent=True,
    disabled=False,
    layout=slider_layout,
    style=slider_style,
)

# Empty both output areas whenever this cell is run.
text_output.clear_output()
image_output.clear_output()

# Hook the plot dropdowns and the random-seed checkbox up to their callbacks.
x_plot_selector.observe(x_plot_on_change, names='value')
y_plot_selector.observe(y_plot_on_change, names='value')
xy_randomseed_checkbox.observe(xy_randomseed_checkbox_on_change, names='value')

def x_plot_checkbox_change(change):
    """Record whether the X plot is enabled and (un)lock its controls.

    ``x_plot_generation`` follows the checkbox; the X dropdown, text box, add
    button and tag row get ``disabled`` set to the opposite of the checkbox.
    """
    global x_plot_generation
    x_plot_generation = x_plot_checkbox.value

    for control in (x_plot_selector, text_input_x, add_button_x, tags_box_x):
        control.disabled = not x_plot_checkbox.value

def y_plot_checkbox_change(change):
    """Record whether the Y plot is enabled and (un)lock its controls.

    ``y_plot_generation`` follows the checkbox; the Y dropdown, text box, add
    button and tag row get ``disabled`` set to the opposite of the checkbox.
    """
    global y_plot_generation
    y_plot_generation = y_plot_checkbox.value

    for control in (y_plot_selector, text_input_y, add_button_y, tags_box_y):
        control.disabled = not y_plot_checkbox.value

# Hook the plot checkboxes up to their callbacks.
x_plot_checkbox.observe(x_plot_checkbox_change, names='value')
y_plot_checkbox.observe(y_plot_checkbox_change, names='value')


# The button that starts a generation, 300px wide and 40px high.
generate_button = widgets.Button(
    description="Generate images",
    layout=widgets.Layout(width='300px', height="40px"),
)

def _coerce_plot_values(parameter, raw_values):
    """Convert collected plot values to the type the swept setting expects.

    ``guidance_scale`` and ``s_scale`` values become floats,
    ``num_inference_steps`` values become ints, and values for any other
    parameter are returned unchanged (in a new list).

    Raises:
        ValueError: if a value cannot be converted to the required number type.
    """
    if parameter in ("guidance_scale", "s_scale"):
        return [float(value) for value in raw_values]
    if parameter == "num_inference_steps":
        return [int(value) for value in raw_values]
    return list(raw_values)


def on_generate_click(b):
    """Build the prompts and plot values from the UI state, then generate images.

    The prompt is ``initial_prompt`` followed by the collected tags; the
    negative prompt is built the same way. Both output areas are cleared
    before generating.

    Fixes over the previous version:

    * plot values are only converted while the matching plot is enabled; the
      integer conversion used to run even with the X/Y plot switched off, so
      leftover non-numeric values broke ordinary generations;
    * a value that cannot be converted is reported in the text output instead
      of raising inside the button callback;
    * after generating, the seed goes back to the seed slider's value instead
      of always -1, so a fixed seed chosen on the slider stays in effect.

    Args:
        b: the clicked button (unused).
    """
    global prompt, negative_prompt, seed
    global x_plot_values, y_plot_values, x_parameter_value_array, y_parameter_value_array

    prompt = initial_prompt+", "+", ".join(prompt_array)
    negative_prompt = initial_negative_prompt+", "+", ".join(neg_prompt_array)

    text_output.clear_output()
    image_output.clear_output()

    try:
        # The converted values are stored back so the value lists and the plot
        # values hold the same types.
        if x_plot_generation:
            x_parameter_value_array = _coerce_plot_values(x_plot_parameter, x_parameter_value_array)
        if y_plot_generation:
            y_parameter_value_array = _coerce_plot_values(y_plot_parameter, y_parameter_value_array)
    except ValueError as error:
        with text_output:
            print(f"Invalid plot value: {error}")
        return

    x_plot_values = list(x_parameter_value_array)
    y_plot_values = list(y_parameter_value_array)

    generate_images()

    # generate_images() may have replaced the seed with a random one; go back
    # to what the slider shows (-1 there still means "random next time").
    seed = seed_slider.value


generate_button.on_click(on_generate_click)

# Each button or text box below sits alone in a 500px-wide, centred flex row.
generate_button_container = widgets.HBox(
    [generate_button],
    layout=widgets.Layout(width='500px',justify_content="center",display="flex"),
)

add_button_container = widgets.HBox(
    [add_button],
    layout=widgets.Layout(width='500px',justify_content="center", display="flex"),
)

add_button_neg_container = widgets.HBox(
    [add_button_neg],
    layout=widgets.Layout(width='500px',justify_content="center", display="flex"),
)

# Prompt and negative-prompt entry, each followed by its tag row.
display(text_input, add_button_container, tags_box)
display(text_input_neg, add_button_neg_container, tags_box_neg)

# Generation settings.
display(
    width_selector,
    height_selector,
    num_samples_slider,
    num_inference_steps_slider,
    guidance_scale_slider,
    seed_slider,
    s_scale_slider,
)

# X/Y plot switches and parameter choices.
display(x_plot_checkbox,y_plot_checkbox,x_plot_selector,y_plot_selector,xy_randomseed_checkbox)

add_button_x_container = widgets.HBox(
    [add_button_x],
    layout=widgets.Layout(width='500px',justify_content="center", display="flex"),
)
add_button_y_container = widgets.HBox(
    [add_button_y],
    layout=widgets.Layout(width='500px',justify_content="center", display="flex"),
)
text_input_x_container = widgets.HBox(
    [text_input_x],
    layout=widgets.Layout(width='500px',justify_content="center", display="flex"),
)
text_input_y_container = widgets.HBox(
    [text_input_y],
    layout=widgets.Layout(width='500px',justify_content="center", display="flex"),
)

# X and Y plot value entry, each followed by its row of values.
display(text_input_x_container, add_button_x_container, tags_box_x)
display(text_input_y_container, add_button_y_container, tags_box_y)

display(generate_button_container)

# Text output first, generated images below it.
display(text_output, image_output)