## Init

In [None]:
# Pinned diffusers release used throughout this notebook.
! pip install diffusers==0.26.3
! pip install ipycanvas
! pip install -U "huggingface_hub[cli]"
! pip install peft
# Interactive login with a Hugging Face access token.
! huggingface-cli login
# NOTE(review): `transforms` is installed here and imported further down — confirm it is not a typo for `transformers`.
! pip install transforms xformers torchvision sgmllib3k

In [None]:
# Clone the RMBG-1.4 repository into ./RMBG; the import cell loads `RMBG.briarmbg` and `RMBG.utilities` from it.
! git clone https://huggingface.co/briaai/RMBG-1.4 ./RMBG
! pip install -r ./RMBG/requirements.txt

In [None]:
# Canvas widget, OpenCV and matplotlib for the drawing / display cells (ipycanvas is also installed in the first cell).
! pip install ipycanvas opencv-python matplotlib
! pip install accelerate

In [None]:
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler, ControlNetModel, StableDiffusionXLControlNetPipeline, EulerAncestralDiscreteScheduler
import torch
from PIL import Image
import random
from PIL import Image
import io
from ipycanvas import Canvas
from ipywidgets import Image as canvas_image
import cv2
import matplotlib.pyplot as plt
import numpy as np
from diffusers import AutoPipelineForInpainting, ControlNetModel
from diffusers.utils import load_image
import torch
from skimage import io
import torch, os
import numpy as np
from PIL import Image
from RMBG.briarmbg import BriaRMBG
from RMBG.utilities import preprocess_image, postprocess_image
from huggingface_hub import hf_hub_download
import gc
import transforms
import torchvision.transforms.functional as F

In [None]:
# Largest signed 32-bit integer; used as the upper bound when drawing seeds.
MAXINT32 = (1 << 31) - 1

# Ten seeds drawn once, then indexed by image position inside generate().
seeds = [random.randint(0, MAXINT32) for _slot in range(10)]

# Comma-separated list of things the models are steered away from.
negative_prompt = (
    "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,"
    "Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,"
    "Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,"
    "Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,"
    "Missing legs,Too many fingers"
)

# Default number of denoising steps for the non-FAST pipelines.
num_inference_steps = 50

def show_img(img):
  """Display an image with matplotlib, with the axes hidden.

  Args:
    img: A PIL image, or any object `Image.fromarray` accepts (it is
      converted to a PIL image first).
  """
  pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)

  # The figure size passed to matplotlib is the pixel size divided by 50.
  width_px, height_px = pil_img.size
  plt.figure(figsize=(width_px / 50, height_px / 50))
  plt.imshow(pil_img)
  plt.axis('off')
  plt.show()


def clear_gpu_cache():
  """Run Python garbage collection, then ask PyTorch to empty its CUDA cache."""
  # Collect first so objects that were only just `del`-ed are gone before the cache is emptied.
  gc.collect()
  torch.cuda.empty_cache()


def concatenate_images(images, concatenation_type='horizontal', output_path=None):
    """Paste several images next to each other on a new RGB canvas.

    Args:
        images: Non-empty sequence of images exposing `.width` / `.height`.
        concatenation_type: 'horizontal' (left to right) or 'vertical'
            (top to bottom).
        output_path: When truthy, the combined image is also saved there.

    Returns:
        The combined image, created with `Image.new("RGB", ...)`.

    Raises:
        ValueError: If `images` is empty or `concatenation_type` is unknown.
    """
    if not images:
        raise ValueError("Input list of images is empty.")
    if concatenation_type not in ['horizontal', 'vertical']:
        raise ValueError("Invalid concatenation type. Use 'horizontal' or 'vertical'.")

    side_by_side = concatenation_type == 'horizontal'
    widths = [image.width for image in images]
    heights = [image.height for image in images]

    # Along the concatenation axis the sizes add up; across it the largest wins.
    if side_by_side:
        canvas_size = (sum(widths), max(heights))
    else:
        canvas_size = (max(widths), sum(heights))
    combined_image = Image.new("RGB", canvas_size)

    offset = 0
    for image, image_width, image_height in zip(images, widths, heights):
        if side_by_side:
            combined_image.paste(image, (offset, 0))
            offset += image_width
        else:
            combined_image.paste(image, (0, offset))
            offset += image_height

    if output_path:
        combined_image.save(output_path)

    return combined_image

def generate(model_name: str, prompt: str, height: int = 1024, width: int = 1024, number_of_images = 3, show=True):
    """Generate images from a text prompt with a Bria text-to-image model.

    Args:
        model_name: Hugging Face model id. "briaai/BRIA-2.2-HD" and
            "briaai/BRIA-2.2-FAST" load their UNet from `model_name` and plug
            it into the "briaai/BRIA-2.2" pipeline; any other value is loaded
            directly as a `DiffusionPipeline`.
        prompt: Text prompt.
        height: Output image height in pixels.
        width: Output image width in pixels.
        number_of_images: How many images to generate. Image `i` uses
            `seeds[i]`; the seed list is reused cyclically when more images
            are requested than there are seeds.
        show: When true, display all generated images side by side.

    Returns:
        List of generated PIL images.
    """
    clear_gpu_cache()

    if model_name == "briaai/BRIA-2.2-HD":
        unet = UNet2DConditionModel.from_pretrained(model_name, torch_dtype=torch.float16, use_safetensors=True)
        pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.2", unet=unet, torch_dtype=torch.float16, use_safetensors=True)
        del unet
        sampling_kwargs = {"num_inference_steps": num_inference_steps}

    elif model_name == "briaai/BRIA-2.2-FAST":
        unet = UNet2DConditionModel.from_pretrained(model_name, torch_dtype=torch.float16)
        pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.2", unet=unet, torch_dtype=torch.float16)
        del unet
        # The FAST variant samples with the LCM scheduler in 8 steps.
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        sampling_kwargs = {"num_inference_steps": 8, "guidance_scale": 5.0}

    else:
        pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16, use_safetensors=True)
        sampling_kwargs = {"num_inference_steps": num_inference_steps}

    pipe.to("cuda")

    gen_images = [
        pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            # Wrap around so asking for more images than seeds does not raise IndexError.
            generator=torch.Generator().manual_seed(seeds[ind % len(seeds)]),
            **sampling_kwargs,
        ).images[0]
        for ind in range(number_of_images)
    ]

    del pipe

    # concatenate_images rejects an empty list, so skip the preview in that case.
    if show and gen_images:
        concat_imgs = concatenate_images(gen_images)
        # Size the figure from the whole strip (pixels / 50), not from a single image.
        w, h = concat_imgs.size
        plt.figure(figsize=(w / 50, h / 50))
        plt.imshow(concat_imgs)
        plt.show()

    clear_gpu_cache()

    return gen_images

def generate_with_style(model_name: str, prompt: str, height: int = 1024, width: int = 1024, number_of_images = 3):
    """Generate images with a LoRA-fine-tuned style on top of "briaai/BRIA-2.0".

    Args:
        model_name: Style key; must be one of the keys of `map_of_models`.
        prompt: Text prompt.
            NOTE(review): the placeholder templates in `map_of_norm_prompts`
            do not interpolate `prompt` — presumably the real templates
            should embed it; confirm when filling them in.
        height: Output image height in pixels.
        width: Output image width in pixels.
        number_of_images: How many images to generate.

    Returns:
        List of generated PIL images.

    Raises:
        KeyError: If `model_name` is not a known style key.
    """
    map_of_models = {
        "style_1": "<path_to_your_checkpoint>",
        "style_2": "<path_to_your_checkpoint>",
        "style_3": "<path_to_your_checkpoint>",
    }
    map_of_norm_prompts = {
        "style_1": f"...",
        "style_2": f"...",
        "style_3": f"...",
    }

    # Fail before the expensive pipeline load instead of after it.
    if model_name not in map_of_models:
        raise KeyError(f"Unknown style {model_name!r}; expected one of {sorted(map_of_models)}")

    clear_gpu_cache()

    pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.0", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
    pipe.load_lora_weights(f"/home/ubuntu/finetune/output/{map_of_models[model_name]}")

    gen_images = [pipe(
        map_of_norm_prompts[model_name],
        num_inference_steps=num_inference_steps,
        negative_prompt=negative_prompt,
        # Forward the requested size; previously these parameters were silently ignored.
        height=height,
        width=width,
    ).images[0] for _ in range(number_of_images)]

    del pipe

    # concatenate_images rejects an empty list, so skip the preview in that case.
    if gen_images:
        concat_imgs = concatenate_images(gen_images)
        # Size the figure from the whole strip (pixels / 50), not from a single image.
        w, h = concat_imgs.size
        plt.figure(figsize=(w / 50, h / 50))
        plt.imshow(concat_imgs)
        plt.show()
        del concat_imgs

    clear_gpu_cache()

    return gen_images

## Bria GTC demo

This demo shows Bria foundation models running through the diffusers pipeline.
<p align="left">
    <br>
    <img src="https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/en/imgs/diffusers_library.jpg" width="200"/>
    <br>
</p>

Models shown in this demo:
- Bria 2.2
- Bria HD
- Bria Fast
- Bria Inpainting
- Bria Remove-BG

#### Get access to our models:
https://huggingface.co/briaai

## Generate images from text

In [None]:
# Three prompts that differ only in the planet colour.
prompt_1 = "A Mars rover exploring the red planet's surface"
prompt_2 = "A Mars rover exploring the green planet's surface"
prompt_3 = "A Mars rover exploring the blue planet's surface"

### Bria 2.2

In [None]:
# Base BRIA-2.2 model, first prompt.
generate(
    model_name="briaai/BRIA-2.2",
    prompt=prompt_1,
    height=1024,
    width=1024,
)

In [None]:
# Base BRIA-2.2 model, second prompt.
generate(
    model_name="briaai/BRIA-2.2",
    prompt=prompt_2,
    height=1024,
    width=1024,
)

In [None]:
# Base BRIA-2.2 model, third prompt.
generate(
    model_name="briaai/BRIA-2.2",
    prompt=prompt_3,
    height=1024,
    width=1024,
)

### Bria HD

In [None]:
# HD variant: a single 1536x1536 image.
generate(
    model_name="briaai/BRIA-2.2-HD",
    prompt=prompt_1,
    height=1536,
    width=1536,
    number_of_images=1,
)

### Bria Fast (1.6 it/s)

In [None]:
# FAST variant of BRIA-2.2, first prompt.
generate(
    model_name="briaai/BRIA-2.2-FAST",
    prompt=prompt_1,
    height=1024,
    width=1024,
)

## Bria 2.2 fine-tuned with LoRA

In [None]:
clear_gpu_cache()

# Style keys accepted by generate_with_style.
fine_tuned_models = ["style_1", "style_2", "style_3"]

# One batch of images per style, all with the same prompt.
[
    generate_with_style(model_name=style_key, prompt=prompt_1)
    for style_key in fine_tuned_models
]

# Inpainting

### Draw Mask

In [None]:
fname = "/tmp/inpaint.png"

# Generate one image without the preview, then shrink it to 512x512.
img = generate(
    model_name="briaai/BRIA-2.2",
    prompt=prompt_1,
    height=1024,
    width=1024,
    number_of_images=1,
    show=False,
)[0].resize((512, 512))
img.save(fname)

# The drawing widget below is created with this size.
org_img_size = img.size
w, h = org_img_size

fig = plt.figure(figsize=(w / 100, h / 100))
plt.imshow(img)
plt.axis('off')
plt.show()



In [None]:
from enum import Enum
import copy

import numpy as np
from ipywidgets import IntSlider, link, HBox, VBox, Button, ToggleButtons
from ipycanvas import MultiCanvas, hold_canvas

# Colab-specific: turn on the custom widget manager.
# NOTE(review): presumably required so third-party widgets such as ipycanvas render in Colab — confirm.
from google.colab import output
output.enable_custom_widget_manager()

class Tools(Enum):
    """Drawing tools offered by DrawingWidget's tool selector."""
    BRUSH = 0   # free-hand stroke
    SQUARE = 1  # filled rectangle spanned by the drag
    CIRCLE = 2  # filled circle derived from the drag

class DrawingWidget(object):
    """Interactive drawing widget built on a three-layer ipycanvas MultiCanvas.

    Layer 0 holds the background, layer 1 the committed drawing and layer 2 a
    temporary preview of the stroke or shape that is currently being dragged.
    """

    # Only immutable defaults live on the class. The mutable per-widget state
    # (shape, history, future) is created in __init__ so that separate widgets
    # never share the same list objects.
    drawing = False
    position = None
    output_array = None
    drawing_line_width = 3
    max_history = 100
    tool_selection = None

    def __init__(self, width, height, background="#FFFFFF", alpha=1.0, default_style="#FFFFFF", default_radius=10):
        """
        Create a MultiCanvas with three layers: a background layer, the drawing layer, and a temporary layer during drawing.

        params:
            width (int): Width of the canvas.
            height (int): Height of the canvas.
            background (string, np.Array): background in the first layer.
                Can be given as a hex-code (str) or a numpy array with values to fill in (default: #FFFFFF).
            alpha (float): Transparency of the drawing layer. Helpful for masking (default: 1.0).
            default_style (str): Hex-code (str) of the stroke colour of the drawing layer (default: #FFFFFF)
            default_radius (int): Default brush radius (default: 10).
        """
        # Initialization
        self.background = background
        self.alpha = alpha
        self.default_style = default_style
        self.default_radius = default_radius

        # Per-instance mutable state.
        self.shape = []      # points of the stroke/shape being dragged
        self.history = []    # drawing-layer snapshots available to undo
        self.future = []     # drawing-layer snapshots available to redo

        self.init_canvas(width, height)

    def get_image_data(self, background=False):
        """Return image data of the whole MultiCanvas (background=True) or of the drawing layer only."""
        if background:
            return self.canvas.get_image_data()
        return self.canvas._canvases[1].get_image_data()

    def init_canvas(self, width, height):
        """Create the MultiCanvas, draw the background, hook up mouse events and set layer styles."""
        self.canvas = MultiCanvas(n_canvases=3, width=width, height=height, sync_image_data=True)
        self.canvas._canvases[1].sync_image_data = True
        self.reset_background()

        self.canvas.on_mouse_down(self.on_mouse_down)
        self.canvas.on_mouse_move(self.on_mouse_move)
        self.canvas.on_mouse_up(self.on_mouse_up)

        # Preview layer: thin blue outline.
        self.canvas[2].stroke_style = "#4287F5"
        self.canvas[2].fill_style = "#4287F5"
        self.canvas[2].line_cap = 'round'
        self.canvas[2].line_width = self.drawing_line_width

        # Drawing layer: user-configurable brush.
        self.canvas[1].stroke_style = self.default_style
        self.canvas[1].line_cap = 'round'
        self.canvas[1].line_join = 'round'
        self.canvas[1].line_width = self.default_radius
        self.canvas[1].global_alpha = self.alpha

    def reset_background(self, *args):
        """Redraw layer 0 from `self.background` (numpy image data or a fill colour)."""
        with hold_canvas():
            if isinstance(self.background, np.ndarray):
                self.canvas[0].put_image_data(self.background)
            else:
                self.canvas[0].fill_style = self.background
                self.canvas[0].fill_rect(0, 0, self.canvas.width, self.canvas.height)

    def show(self):
        """Build the tool controls and return the canvas plus controls as a single widget."""
        # UI controls
        self.tool_selection = ToggleButtons(options=[('Brush ', Tools.BRUSH), ('Square ', Tools.SQUARE), ('Circle ', Tools.CIRCLE)],
                                            value=Tools.BRUSH,
                                            icons=['brush', 'square', 'circle'])
        radius_slider = IntSlider(description="Brush radius:", value=self.default_radius, min=1, max=100)
        clear_button = Button(description="Clear")
        clear_button.on_click(self.clear_canvas)
        undo_button = Button(description="Undo", icon="rotate-left")
        undo_button.on_click(self.undo)
        redo_button = Button(description="Redo", icon="rotate-right")
        redo_button.on_click(self.redo)

        # Link UI controls to canvas
        link((radius_slider, "value"), (self.canvas[1], "line_width"))

        # Display in grid
        return HBox((self.canvas, VBox((self.tool_selection, radius_slider, clear_button, HBox((undo_button, redo_button))))))

    def _current_tool(self):
        """Return the selected tool, falling back to the brush if `show()` has not built the selector yet."""
        if self.tool_selection is None:
            return Tools.BRUSH
        return self.tool_selection.value

    @staticmethod
    def _circle_from_drag(start, x, y):
        """Return (center_x, center_y, radius) of the circle for a drag from `start` to (x, y).

        The radius is half the larger absolute drag extent, so it is never
        negative; the centre is offset from `start` towards the cursor. For a
        drag towards the bottom-right this is `start + radius` on both axes.
        """
        dx = x - start[0]
        dy = y - start[1]
        radius = max(abs(dx), abs(dy)) / 2
        center_x = start[0] + (radius if dx >= 0 else -radius)
        center_y = start[1] + (radius if dy >= 0 else -radius)
        return center_x, center_y, radius

    def save_to_history(self):
        """Push the current drawing layer onto the undo history and clear the redo stack."""
        self.history.append(self.canvas._canvases[1].get_image_data())
        if len(self.history) > self.max_history:
            # Drop the oldest snapshot.
            self.history = self.history[1:]
        self.future = []

    def on_mouse_down(self, x, y):
        """Start a stroke/shape at (x, y) and snapshot the drawing layer for undo."""
        self.drawing = True
        self.position = (x, y)
        self.shape = [self.position]
        self.save_to_history()

    def on_mouse_move(self, x, y):
        """Preview the stroke or shape on the temporary layer while the button is held."""
        if not self.drawing:
            return

        tool = self._current_tool()
        with hold_canvas():
            if tool == Tools.BRUSH:
                # Preview with the real brush width, then restore the thin outline width.
                self.canvas[2].line_width = self.canvas[1].line_width
                self.canvas[2].stroke_line(self.position[0], self.position[1], x, y)
                self.canvas[2].line_width = self.drawing_line_width
            elif tool == Tools.SQUARE:
                start_x, start_y = self.shape[0]
                self.canvas[2].clear()
                self.canvas[2].stroke_rect(start_x, start_y, x - start_x, y - start_y)
            elif tool == Tools.CIRCLE:
                center_x, center_y, radius = self._circle_from_drag(self.shape[0], x, y)
                self.canvas[2].clear()
                self.canvas[2].stroke_circle(center_x, center_y, radius)

            self.position = (x, y)

        self.shape.append(self.position)

    def on_mouse_up(self, x, y):
        """Commit the previewed stroke or shape to the drawing layer."""
        # A release without a preceding press on the canvas has nothing to commit.
        if not self.drawing:
            return
        self.drawing = False

        tool = self._current_tool()
        with hold_canvas():
            self.canvas[2].clear()
            if tool == Tools.BRUSH:
                self.canvas[1].stroke_lines(self.shape)
            elif tool == Tools.SQUARE:
                start_x, start_y = self.shape[0]
                self.canvas[1].fill_rect(start_x, start_y, x - start_x, y - start_y)
            elif tool == Tools.CIRCLE:
                center_x, center_y, radius = self._circle_from_drag(self.shape[0], x, y)
                self.canvas[1].fill_circle(center_x, center_y, radius)

        self.shape = []

    def clear_canvas(self, *args):
        """Erase the drawing layer (undoable)."""
        self.save_to_history()
        with hold_canvas():
            self.canvas[1].clear()

    def undo(self, *args):
        """Restore the most recent snapshot from the history, keeping the current state for redo."""
        if self.history:
            with hold_canvas():
                self.future.append(self.canvas._canvases[1].get_image_data())
                self.canvas[1].clear()
                self.canvas[1].put_image_data(self.history[-1])
                self.history = self.history[:-1]

    def redo(self, *args):
        """Re-apply the most recently undone state, keeping the current state for undo."""
        if self.future:
            with hold_canvas():
                self.history.append(self.canvas._canvases[1].get_image_data())
                self.canvas[1].clear()
                self.canvas[1].put_image_data(self.future[-1])
                self.future = self.future[:-1]

# Canvas sized like the resized image, with the image pixels as the background layer.
drawing_widget = DrawingWidget(
    width=org_img_size[0],
    height=org_img_size[1],
    background=np.array(img),
)
drawing_widget.show()

### Execute Inpainting

In [None]:
from PIL import Image

# Release cached GPU memory before this cell runs the inpainting pipeline.
clear_gpu_cache()

def inpainting(image: Image.Image, mask: Image.Image, prompt: str, negative_prompt: str, model_name: str = "briaai/BRIA-1.4-Inpainting"):
    """Inpaint the masked region of `image` according to `prompt`.

    Args:
        image: Source PIL image.
        mask: PIL mask image marking the region to repaint.
        prompt: Text describing what to paint into the masked region.
        negative_prompt: Text describing what to avoid.
        model_name: Hugging Face id of the inpainting model.

    Returns:
        The first image produced by the pipeline.
    """
    pipeline = AutoPipelineForInpainting.from_pretrained(model_name, torch_dtype=torch.float32)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)

    # load_image accepts PIL images directly, so the inputs no longer take a
    # lossy JPEG round trip through fixed files in /tmp (which also added
    # compression artefacts to the mask).
    init_image = load_image(image)
    mask_image = load_image(mask)

    inpainted = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
    del pipeline
    return inpainted


def get_red_mask(image):
    """Return a binary mask of the pixels that are exactly (255, 255, 255).

    Despite the name, no HSV conversion or red-range test happens here: both
    bounds given to `cv2.inRange` are (255, 255, 255), so only pure-white
    pixels are selected. The mask is then cleaned with a 5x5 morphological
    opening followed by a closing.
    """
    white = np.array([255, 255, 255])
    mask = cv2.inRange(image, white, white)

    # Opening, then closing, with the same 5x5 kernel.
    kernel = np.ones((5, 5), np.uint8)
    for morph_op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        mask = cv2.morphologyEx(mask, morph_op, kernel)

    return mask


# Pixel data of the drawing layer (the user's brush strokes).
image = drawing_widget.get_image_data()
# Canvas data is expected to be RGBA; drop the alpha channel without swapping
# red and blue (the previous COLOR_BGR2RGB conversion reordered the channels).
if image.shape[-1] == 4:
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
red_mask = get_red_mask(image)

result = inpainting(model_name="briaai/BRIA-1.4-Inpainting", image=img, mask=Image.fromarray(red_mask), prompt="Tank", negative_prompt="")

# Source image, drawing layer, mask and inpainted result side by side.
res = concatenate_images([img, Image.fromarray(image), Image.fromarray(red_mask), result])
show_img(res)


# Remove Background

In [None]:
def remove_background(image):
    """Cut the foreground out of `image` with the RMBG-1.4 model.

    Args:
        image: Input image; it is converted with `np.array(image)`.

    Returns:
        An RGBA PIL image of the original size in which the original pixels
        are pasted through the predicted mask onto a fully transparent canvas.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the pretrained weights directly; building an extra BriaRMBG() first
    # only allocated a model that was thrown away immediately.
    net = BriaRMBG.from_pretrained("briaai/RMBG-1.4", revision="refs/pr/19")
    net.to(device)
    net.eval()

    # Prepare input
    model_input_size = [1024, 1024]
    orig_im = np.array(image)
    orig_im_size = orig_im.shape[0:2]
    image_tensor = preprocess_image(orig_im, model_input_size).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        result = net(image_tensor)
    del net

    # Post-process the first output back to the original height/width.
    result_image = postprocess_image(result[0][0], orig_im_size)

    # Create a transparent image
    pil_im = Image.fromarray(result_image)
    no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    orig_image = Image.fromarray(orig_im)

    # Paste the original image onto the transparent image, using the prediction as the mask.
    no_bg_image.paste(orig_image, mask=pil_im)

    return no_bg_image

In [None]:
clear_gpu_cache()

# One generated image, no preview.
images = generate(
    model_name="briaai/BRIA-2.2",
    prompt=prompt_1,
    height=1024,
    width=1024,
    number_of_images=1,
    show=False,
)

# Despite its name, result_image_np holds whatever remove_background returns.
result_image_np = remove_background(images[0])
concat_imgs = concatenate_images([images[0], result_image_np])

# Original and cut-out side by side.
w, h = concat_imgs.size
fig = plt.figure(figsize=(w / 50, h / 50))
plt.imshow(concat_imgs)
plt.show()

# Control Net Canny

In [None]:
def resize_image(image):
    """Convert to RGB, centre-crop to a square and resize to 1024x1024.

    Args:
        image: PIL image.

    Returns:
        The cropped and resized image.
    """
    rgb_image = image.convert('RGB')
    # The square's side is the shorter of the two image dimensions.
    side = min(rgb_image.size)
    square_image = F.center_crop(rgb_image, (side, side))
    return F.resize(square_image, (1024, 1024))

def get_canny_filter(image):
    """Run Canny edge detection and return the edges as a three-channel PIL image.

    Args:
        image: A numpy array, or anything `np.array` can convert (e.g. a PIL image).

    Returns:
        PIL image whose three channels all hold the same edge map.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)

    low_threshold, high_threshold = 100, 200
    edges = cv2.Canny(pixels, low_threshold, high_threshold)

    # Replicate the single-channel edge map into three identical channels.
    edges_3ch = np.stack((edges, edges, edges), axis=2)
    return Image.fromarray(edges_3ch)

In [None]:
# Load the Canny ControlNet in half precision.
controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.2-ControlNet-Canny",
    torch_dtype=torch.float16
)

# Attach the ControlNet to the base BRIA-2.2 pipeline.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.2",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

# Replace the pipeline's scheduler with an explicitly configured Euler-ancestral scheduler.
pipe.scheduler = EulerAncestralDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1
)
pipe.force_zeros_for_empty_prompt = False
pipe.to("cuda")

prompt = "Spacecraft"

# NOTE(review): hard-coded input path — the file must exist at /content/web.png before this cell runs.
original_img = Image.open('/content/web.png')
# Square 1024x1024 RGB input, then its Canny edge map as the conditioning image.
input_image = resize_image(original_img)
canny_image = get_canny_filter(input_image)

gen_img = pipe(prompt=prompt, negative_prompt="", image=canny_image, controlnet_conditioning_scale=1.0, height=1024, width=1024).images[0]

# Input, edge map and generated image side by side.
concat_imgs = concatenate_images([input_image, canny_image, gen_img])
show_img(concat_imgs)

# Drop the pipeline reference and release cached GPU memory.
del pipe
clear_gpu_cache()