## Replicate Training Workflow for SDXL

First, install the Replicate Python client with pip:

In [8]:
!pip install replicate





## Prepare & Upload Image Dataset

The training API expects a zip file containing your training images. A handful of images (5-6) is enough to fine-tune SDXL on a single person, but you might need more if your training subject is more complex or the images are very different. Keep the following guidelines in mind when preparing your images:

- Images can be of yourself, your pet, your favorite stuffed animal, or any unique object. For best results, your images should contain only the subject itself, with a minimum of background noise or other objects.
- Images can be in JPEG or PNG format.
- Dimensions and size don't matter.
- Filenames don't matter.
- Do not use images of other people without their consent.

Put your images in a folder and zip it up. The directory structure of the zip file doesn't matter:

```console
zip -r data.zip data
```
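If you would rather build the archive from Python (for example, inside this notebook), the sketch below does roughly the same thing as the `zip` command above. The helper name `zip_training_images` and the extension filter are illustrative assumptions, not part of Replicate's tooling; the filter simply follows the JPEG/PNG guideline listed earlier.

```python
import zipfile
from pathlib import Path

# Assumption: only JPEG and PNG files belong in the training archive.
ALLOWED_SUFFIXES = {".jpg", ".jpeg", ".png"}


def zip_training_images(image_dir, zip_path="data.zip"):
    """Zip every JPEG/PNG file under image_dir; return how many files were added."""
    image_dir = Path(image_dir)
    count = 0
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for path in sorted(image_dir.rglob("*")):
            if path.is_file() and path.suffix.lower() in ALLOWED_SUFFIXES:
                # Store paths relative to image_dir so the archive has no absolute paths.
                archive.write(path, arcname=path.relative_to(image_dir).as_posix())
                count += 1
    return count
```

Calling `zip_training_images("data", "data.zip")` returns the number of images written, which is a quick way to confirm that the archive contains as many images as you expected.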

Once the dataset is zipped, upload the file to a publicly accessible location, such as an S3 bucket or a GitHub Pages site. The workflow for hosting it via Replicate is shown below.

In [9]:
import requests
import json

def upload_to_replicate(api_token, file_path):
    """
    Uploads a file to replicate using the given API token.

    Args:
    - api_token (str): The API token for authorization.
    - file_path (str): The path to the file to upload.

    Returns:
    - str: The serving URL.
    """

    # Define headers for initial request
    headers_init = {
        "Authorization": f"Token {api_token}"
    }
    
    # POST request to get the upload URL
    response_init = requests.post(
        "https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip",
        headers=headers_init
    )

    # Handle possible errors in the initial request
    if response_init.status_code != 200:
        raise Exception(f"Initial request failed with status {response_init.status_code}: {response_init.text}")

    # Extract upload URL from the response
    upload_url = json.loads(response_init.text)['upload_url']

    # PUT request to upload the file
    with open(file_path, 'rb') as f:
        response_upload = requests.put(
            upload_url,
            headers={"Content-Type": "application/zip"},
            data=f
        )

    # Handle possible errors in the upload
    if response_upload.status_code != 200:
        raise Exception(f"Upload failed with status {response_upload.status_code}: {response_upload.text}")

    # Extract and return the serving URL
    serving_url = json.loads(response_init.text)['serving_url']

    return serving_url

In [10]:
import os

api_token = os.environ["REPLICATE_API_TOKEN"]  # keep the token out of the notebook; export it in your shell first
file_path = "data/embersteel-sdxl-training-dataset-cerephelo.zip"
serving_url = upload_to_replicate(api_token, file_path)
print(serving_url)

https://replicate.delivery/pbxt/JRpVDqtjQ2gPzBCnfkytT0uZyCBDKSj0UebvOwuIkhJFiJVX/data.zip


Great! Now that the dataset is hosted, we can start training. Set `input_url` to the serving URL printed by the previous cell (it is passed to the trainer as the "input_images" input), and set `destination` to the model on your Replicate account that should receive the fine-tuned SDXL version.
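Before spending money on a training run, it can help to sanity-check the three strings you are about to pass in. The sketch below is a hypothetical pre-flight helper (`check_training_args` is not part of the `replicate` package). The patterns are assumptions inferred from the values used in this notebook: a destination of the form `owner/model-name`, a version of the form `owner/model-name:` followed by 64 hexadecimal characters, and an HTTP(S) URL whose path ends in `.zip`.

```python
import re
from urllib.parse import urlparse

# Assumed shapes, inferred from the example values in this notebook.
DESTINATION_RE = re.compile(r"^[\w.-]+/[\w.-]+$")
VERSION_RE = re.compile(r"^[\w.-]+/[\w.-]+:[0-9a-f]{64}$")


def check_training_args(version, input_url, destination):
    """Return a list of problems; an empty list means the arguments look plausible."""
    problems = []
    if not VERSION_RE.match(version):
        problems.append("version should look like 'owner/model:<64 hex characters>'")
    if not DESTINATION_RE.match(destination):
        problems.append("destination should look like 'owner/model-name'")
    parsed = urlparse(input_url)
    if parsed.scheme not in ("http", "https") or not parsed.path.lower().endswith(".zip"):
        problems.append("input_url should be an http(s) URL pointing at a .zip file")
    return problems
```

An empty result only means the strings are well-formed; it does not confirm that the version exists or that the destination model has been created.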

In [11]:
import os
import replicate

def start_replicate_training(api_key, version, input_url, destination):
    """
    Start training using the Replicate API.

    Args:
    - api_key (str): Replicate API token.
    - version (str): Version string for training.
    - input_url (str): URL for input data.
    - destination (str): Destination string.

    Returns:
    - Training object.
    """
    
    # Set the API token as an environment variable
    os.environ["REPLICATE_API_TOKEN"] = api_key

    # Start the training
    training = replicate.trainings.create(
        version=version,
        input={
            "input_images": input_url,
            "lora_lr": 2e-4,
            "caption_prefix": 'embersteel',
        },
        destination=destination
    )

    return training

# Usage:
api_key = os.environ["REPLICATE_API_TOKEN"]  # read the token from the environment rather than hard-coding it
version = "stability-ai/sdxl:7ca7f0d3a51cd993449541539270971d38a24d9a0d42f073caf25190d41346d7"
input_url = "https://replicate.delivery/pbxt/JRpVDqtjQ2gPzBCnfkytT0uZyCBDKSj0UebvOwuIkhJFiJVX/data.zip"
destination = "adynblaed/embersteel-sdxl-cerephelo"

training_result = start_replicate_training(api_key, version, input_url, destination)

In [12]:
import time

def monitor_training_progress(training, check_interval=5, max_iterations=100):
    """
    Continuously monitor the progress of the Replicate training.

    Args:
    - training (object): The training object returned from start_replicate_training.
    - check_interval (int): The number of seconds to wait between each check. Default is 5 seconds.
    - max_iterations (int): Maximum number of times to check the training status. To run indefinitely, set to None.

    Prints:
    - Training status.
    - Last 10 logs.
    """
    
    iterations = 0
    while max_iterations is None or iterations < max_iterations:
        training.reload()
        print("Iteration:", iterations + 1)
        print("Status:", training.status)
        print("Logs:")
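        print("\n".join((training.logs or "").split("\n")[-10:]))  # logs can be None before the job starts
        print("="*50)  # Just for separating different iterations visually

        # Replicate reports "succeeded", "failed", or "canceled" as terminal statuses
        if training.status in ("succeeded", "failed", "canceled"):
            print(f"Training finished with status: {training.status}")
            break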

        time.sleep(check_interval)  # Wait before the next check
        iterations += 1

# Usage
# Assuming `training_result` is the object returned by your start_replicate_training() function.
monitor_training_progress(training_result, check_interval=5, max_iterations=100)

Iteration: 1
Status: processing
Logs:
67it [00:03, 16.38it/s]
69it [00:03, 16.24it/s]
69it [00:03, 18.16it/s]
Upscaling 69 images...
Downloading (…)lve/main/config.json:   0%|          | 0.00/772 [00:00<?, ?B/s]
Downloading (…)lve/main/config.json: 100%|██████████| 772/772 [00:00<00:00, 689kB/s]
Downloading pytorch_model.bin:   0%|          | 0.00/48.6M [00:00<?, ?B/s]
Downloading pytorch_model.bin:  65%|██████▍   | 31.5M/48.6M [00:00<00:00, 301MB/s]
Downloading pytorch_model.bin: 100%|██████████| 48.6M/48.6M [00:00<00:00, 250MB/s]
  0%|          | 0/69 [00:00<?, ?it/s]
Iteration: 2
Status: processing
Logs:
  3%|▎         | 2/69 [00:00<00:16,  4.16it/s]
  4%|▍         | 3/69 [00:00<00:22,  2.94it/s]
  6%|▌         | 4/69 [00:01<00:25,  2.57it/s]
  7%|▋         | 5/69 [00:01<00:26,  2.38it/s]
 10%|█         | 7/69 [00:02<00:20,  3.04it/s]
 13%|█▎        | 9/69 [00:02<00:17,  3.44it/s]
 14%|█▍        | 10/69 [00:03<00:19,  3.00it/s]
 16%|█▌        | 11/69 [00:03<00:21,  2.72it/s]
 17%|█▋