In [None]:
from demo_utils import *

# Show the QR code banner at the start of the demo.
# `Image` and `display_images` presumably come from the demo_utils star import above.
qr_code_image = Image.open("./visuals/qr_code.png")
display_images([qr_code_image], resize=300)

**Note:** This notebook was tested on an **Amazon EC2 G6e** instance (NVIDIA L40S Tensor Core GPU).

In [1]:
# This code requires a Hugging Face token with Bria permissions, as well as an AWS account with Bedrock permissions.
# Credentials are never hard-coded: values already present in the environment are reused,
# otherwise you are prompted for them (input is hidden). Leave a prompt empty to skip it,
# e.g. for the AWS keys when the EC2 instance has an IAM role with Bedrock access.
import os
from getpass import getpass


def _set_env_secret(name, prompt):
    """Set environment variable `name` from a hidden prompt unless it is already set.

    An empty answer (or a non-interactive run with no stdin) leaves the variable
    unset, so other credential sources can still be used.
    """
    if os.environ.get(name):
        return
    try:
        value = getpass(prompt)
    except EOFError:  # non-interactive execution: nothing to read, leave unset
        value = ""
    if value:
        os.environ[name] = value


_set_env_secret("AWS_ACCESS_KEY_ID", "AWS access key ID (leave empty to use the instance role): ")
_set_env_secret("AWS_SECRET_ACCESS_KEY", "AWS secret access key (leave empty to use the instance role): ")
# Only a default: an explicitly configured region is kept.
os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")

_set_env_secret("HF_TOKEN", "Hugging Face token: ")

# Tailored Generation

## LoRA Fine-Tuning Bria's Text-to-Image Model

#### What is LoRA fine-tuning?
LoRA (Low-Rank Adaptation) is a fine-tuning method that enables efficient training of text-to-image models by adjusting only a small set of low-rank matrices rather than modifying the entire model. 

Instead of retraining all the parameters of a large model, LoRA inserts lightweight trainable layers into the architecture and updates only these layers during training. 

This significantly reduces the computational cost and memory requirements while allowing the model to quickly learn new concepts. Because the core model remains unchanged, LoRA makes it easier to deploy and switch between different fine-tuned adaptations without having to maintain multiple large model copies.

#### Fine-Tuning Bria's Foundation Model
For fine-tuning, whether through LoRA or full fine-tuning, we'll use Bria's latest **4B-Adapt** model, which is designed to provide exceptional fine-tuning capabilities for commercial use.

This model excels at adapting to the tuned style while preserving remarkably high prompt alignment.

Bria-4B-Adapt weights, as well as training and inference instructions, can be found here: https://huggingface.co/briaai/BRIA-4B-Adapt

Bria also offers an API suite to train, manage and use fine-tuned models: https://docs.bria.ai/tailored-generation

In [None]:
# Download the Bria pipeline/transformer code and the LoRA training script from the Hugging Face Hub.

from huggingface_hub import hf_hub_download
import os

# `__file__` only exists when this code runs as a script; inside a notebook it raises
# NameError, so fall back to the current working directory. dirname() of a bare file
# name is "", which is also mapped to the current directory.
try:
    local_dir = os.path.dirname(__file__) or '.'
except NameError:
    local_dir = '.'

BRIA_SCRIPT_FILES = ['pipeline_bria.py', 'transformer_bria.py', 'bria_utils.py', 'train_lora.py']
for script_file in BRIA_SCRIPT_FILES:
    hf_hub_download(repo_id="briaai/BRIA-4B-Adapt", filename=script_file, local_dir=local_dir)

We want to fine-tune Bria-4B-Adapt on the following Bria elephant images:

In [None]:
import math
import os
from PIL import Image
from diffusers.utils import make_image_grid

# Load every PNG training image from the local dataset folder, keyed by file name.
images_dir = "./briaphant"
# Sort so the order is deterministic (os.listdir order is arbitrary); the same order is
# reused later when captioning the images and writing metadata.csv.
image_files = sorted(f for f in os.listdir(images_dir) if f.lower().endswith('.png'))
assert image_files, f"no .png images found in {images_dir!r}"
all_images = {f: Image.open(os.path.join(images_dir, f)) for f in image_files}

# make_image_grid needs exactly rows * cols images, so derive the row count from the number
# of images and pad the last row with blank tiles instead of hard-coding a 2x4 grid.
grid_cols = min(4, len(all_images))
grid_rows = math.ceil(len(all_images) / grid_cols)
grid_tiles = list(all_images.values())
grid_tiles += [Image.new("RGB", (256, 256), "white")] * (grid_rows * grid_cols - len(grid_tiles))
make_image_grid(grid_tiles, rows=grid_rows, cols=grid_cols, resize=256)

### Training Data

As training data, we're going to need a folder with all images and a csv with the image file names and captions.

For caption creation we can use a VLM (Vision-Language Model) by instructing it to return a description of each image in the dataset.
We'll choose a prefix, repeated in every caption, that contains a "trigger word". In this case, we'll name the character BriaPhant and include it in the prefix: "A 3d-render of BriaPhant, a purple elephant".

We can ask the VLM to write the caption by completing the given prefix.

We'll use the **Amazon Bedrock** API to call Claude 3 Sonnet:

In [4]:
import boto3
import json
import base64
import io

from botocore.exceptions import ClientError

# Bedrock Runtime client shared by the captioning helpers in this cell.
# The AWS Region is resolved from the environment / AWS config (e.g. AWS_DEFAULT_REGION);
# pass region_name=... here to pin a specific one.
bedrock_client = boto3.client(service_name="bedrock-runtime")

def center_crop_image(image):
    """Return the largest centered square crop of `image`.

    Square inputs are returned unchanged (same object). For non-square inputs the
    longer side is trimmed equally on both ends (integer division, so any odd
    leftover pixel is dropped from the far side).
    """
    w, h = image.size
    if w == h:
        return image
    side = min(w, h)
    # Exactly one of these offsets is non-zero: the one along the longer axis.
    x0 = (w - side) // 2
    y0 = (h - side) // 2
    return image.crop((x0, y0, x0 + side, y0 + side))

def load_images_for_bedrock(images, resize_resolution=512, center_crop=True):
    """Encode images as base64 JPEG strings for a Bedrock (Claude) request.

    Args:
        images: iterable of PIL images.
        resize_resolution: side length, in pixels, of the square output image.
        center_crop: if True, crop each image to its central square first (via
            `center_crop_image`) so the resize keeps the aspect ratio; if False,
            non-square images are stretched to a square.

    Returns:
        List of base64-encoded JPEG strings (UTF-8 decoded), one per input image.
    """
    base64_images = []
    for img in images:
        if center_crop:
            img = center_crop_image(img)
        # Convert to RGB *before* resizing: JPEG has no alpha/palette support, and Pillow
        # always resizes palette ("P") images with nearest-neighbour, so converting first
        # gives properly filtered downsampling for indexed PNGs.
        img = img.convert("RGB").resize((resize_resolution, resize_resolution))

        with io.BytesIO() as img_buffer:
            img.save(img_buffer, format="JPEG")  # encode in memory, no temp file
            img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

        base64_images.append(img_base64)

    return base64_images

# Bedrock model ID of the Claude model used for captioning.
# NOTE(review): Claude 3 Sonnet is an older model; confirm it is still available to your
# account/region in Bedrock, or switch to a newer Claude model ID.
CLAUDE_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"

def infer_with_bedrock(prompt, images, model_id = CLAUDE_MODEL_ID, temperature=0., top_p=0.3, top_k=5,
                       max_tokens=512, client=None):
    """Send a text prompt plus images to a Claude model on Amazon Bedrock and return its reply text.

    Args:
        prompt: text instruction for the model.
        images: iterable of base64-encoded JPEG strings (e.g. from `load_images_for_bedrock`).
        model_id: Bedrock model ID to invoke.
        temperature, top_p, top_k: sampling parameters forwarded to the model.
        max_tokens: maximum number of tokens to generate.
        client: Bedrock Runtime client; defaults to the module-level `bedrock_client`.

    Returns:
        The concatenated text of the response's text blocks, or None if the call
        failed (the error is printed) or the response contained no text.
    """
    if client is None:
        client = bedrock_client

    # Anthropic Messages format: one user turn with the prompt followed by the images.
    content = [{"type": "text", "text": prompt}]
    content.extend(
        {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": image}}
        for image in images
    )

    native_request = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "messages": [{"role": "user", "content": content}],
    }

    try:
        response = client.invoke_model(modelId=model_id, body=json.dumps(native_request))
    except Exception as e:  # covers botocore's ClientError; best-effort: report and return None
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        return None

    model_response = json.loads(response["body"].read())

    # The reply is a list of content blocks; keep the text ones instead of assuming the
    # first block exists and is text.
    texts = [
        block.get("text", "")
        for block in model_response.get("content", [])
        if block.get("type") == "text"
    ]
    if not texts:
        print(f"ERROR: '{model_id}' returned no text content.")
        return None
    return "".join(texts)

def generate_caption_with_claude(image, prefix):
    """Caption one training image with Claude, guaranteeing the caption starts with `prefix`.

    The prefix carries the trigger word (e.g. "BriaPhant"), which must appear in every
    training caption.

    Args:
        image: PIL image to describe.
        prefix: text the caption must start with.

    Returns:
        The caption string, or None if the Bedrock call failed.
    """
    images = load_images_for_bedrock([image], center_crop=True) # training code does center-crop by default, so we can feed the captioner the same

    prompt = f"""
Describe the image. Use up to 30 words, starting with: "{prefix}"

Response format:
{prefix} [Concise description of the image in up to 30 words]
"""

    caption = infer_with_bedrock(prompt, images, temperature=0.1)
    if caption is None:
        return None

    caption = caption.strip()
    # The model may add a preamble before the prefix or omit it entirely; enforce the
    # prefix so the trigger word is always present in the training caption.
    prefix_pos = caption.find(prefix)
    if prefix_pos >= 0:
        caption = caption[prefix_pos:]
    else:
        caption = f"{prefix} {caption}"
    return caption


In [None]:
import pandas as pd

# Caption prefix shared by every training image; "BriaPhant" is the trigger word.
CAPTION_PREFIX = "A 3d-render of BriaPhant, a purple elephant"

captions = []
for file_name, img in all_images.items():
    caption = generate_caption_with_claude(img, CAPTION_PREFIX)
    # Fail loudly: an image without a caption would leave the training metadata incomplete
    # (and `caption.replace` below would crash with an opaque AttributeError).
    if caption is None:
        raise RuntimeError(f"Captioning failed for {file_name!r}; fix the Bedrock error above and re-run this cell.")
    captions.append({"file_name": file_name, "caption": caption})
    display_images([img], title=caption.replace('BriaPhant,','BriaPhant,\n'), font_size=12, resize=350)



In [None]:
# Save the captions next to the training images as metadata.csv (columns: file_name, caption).
# NOTE(review): presumably read by train_lora.py as an imagefolder-style dataset via --dataset_name — confirm.
df = pd.DataFrame(captions)
assert not df.empty and df["caption"].notna().all(), "every training image needs a caption"
df.to_csv(os.path.join(images_dir, "metadata.csv"), index=False)
df.head()


### Training

Bria's Tailored Generation training API uses **Amazon SageMaker** to manage and run training sessions.

In this case, we can run the training with a batch size of 1 on the current EC2 instance. We'll use `gradient_accumulation_steps=4` to get an effective batch size of 4.

For larger datasets, and for full fine-tuning of Bria's adaptable model (4B-Adapt), we may want to increase the batch size and use multiple GPUs, which would likely require more VRAM.


In [7]:
# Training configuration, consumed by the training command below.
hf_model_id = "briaai/BRIA-4B-Adapt"        # base model to fine-tune
data_dir = images_dir                         # folder holding the images and metadata.csv
output_dir = "./briaphant_training_output"    # where the LoRA weights are written

# Optimisation settings.
rank = 128
max_training_steps = 1500
batch_size = 1
gradient_accumulation_steps = 4  # effective batch size = batch_size * gradient_accumulation_steps

os.makedirs(output_dir, exist_ok=True)

In [None]:
# Run the training script
!python train_lora.py \
    --pretrained_model_name_or_path {hf_model_id} \
    --dataset_name {data_dir} \
    --output_dir {output_dir} \
    --max_train_steps {max_training_steps} \
    --rank {rank} \
    --train_batch_size {batch_size} \
    --gradient_accumulation_steps {gradient_accumulation_steps}

Once training is finished, we can run inference on the trained model:

In [None]:
import torch
from pipeline_bria import BriaPipeline


# Load the same base model that was fine-tuned (hf_model_id from the training config),
# then attach the LoRA weights produced by the training cell.
tailored_pipe = BriaPipeline.from_pretrained(hf_model_id, torch_dtype=torch.bfloat16, trust_remote_code=True)
tailored_pipe.to(device="cuda")

# NOTE(review): assumes train_lora.py saved "pytorch_lora_weights.safetensors" in output_dir — confirm.
tailored_pipe.load_lora_weights(output_dir, weight_name="pytorch_lora_weights.safetensors")
# Alternative without training: load the example LoRA shipped in the Bria repo instead.
# tailored_pipe.load_lora_weights(hf_model_id, subfolder="example_finetuned_model", weight_name="bria_elephant.safetensors")

# The prompt reuses the training caption prefix so the trigger word activates the LoRA.
prompt = "A 3d-render of BriaPhant, a purple elephant, playing the trumpet"

seed = 42
lora_scale = 1.0  # LoRA strength, forwarded to the pipeline via joint_attention_kwargs
generator = torch.Generator("cuda").manual_seed(seed)
image = tailored_pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    generator=generator,
    joint_attention_kwargs={"scale": lora_scale},
    num_inference_steps=30,
).images[0]
display_images([image])

In [None]:
# Show the QR code banner again at the end of the demo.
closing_qr_code = Image.open("./visuals/qr_code.png")
display_images([closing_qr_code], resize=300)