# In [2]:
from flask import Flask, render_template, request
from flask.helpers import send_file
from PIL import Image
from io import BytesIO
import loader
import modelpipeline
from transformers import CLIPTokenizer
import torch

# Flask application object; the route decorators further down register their
# view functions on it.
app = Flask(__name__)

# Switches controlling which hardware backends this server may use.
ALLOW_CUDA = True
ALLOW_MPS = False

# Choose the compute device: CUDA first, then Apple MPS, otherwise CPU.
# The CPU fallback is assigned up front so that turning a flag off really
# disables that backend, and each flag is tested before probing the backend
# so a disallowed backend is never queried.
DEVICE = "cpu"
if ALLOW_CUDA and torch.cuda.is_available():
    DEVICE = "cuda"
elif ALLOW_MPS and torch.backends.mps.is_available():
    # is_available() is the runtime check; is_built() only reports that the
    # installed PyTorch binary was compiled with MPS support.
    DEVICE = "mps"
print(f"Using device: {DEVICE}")

# Build the tokenizer and load the model weights once, at module load, so the
# /generate handler can reuse them on every request.
# The paths are relative, so they resolve against the current working directory.
tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
model_file = "../data/v1-5-pruned-emaonly.ckpt"
# NOTE(review): `loader` is project-local; presumably this returns the set of
# sub-models that modelpipeline.generate expects in `models` - confirm.
models = loader.preload_models_from_standard_weights(model_file, DEVICE)

@app.route("/")
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page

@app.route("/generate", methods=["POST"])
def generate():
    """Run the image-to-image pipeline on an uploaded picture and return a PNG.

    Form fields read from the POST request:
        prompt: text prompt (required; may be empty).
        uncond_prompt: unconditional prompt; defaults to "" when omitted.
        strength: parsed with ``float()`` (required).
        image: uploaded image file (required).

    Returns:
        The generated image as an ``image/png`` response, or a plain-text
        message with HTTP status 400 when a field is missing or malformed.
    """
    prompt = request.form.get("prompt")
    if prompt is None:
        return "Missing required form field: prompt", 400

    # An omitted unconditional prompt becomes an empty string rather than None.
    uncond_prompt = request.form.get("uncond_prompt", "")

    # float("") raises ValueError, so a missing field and a non-numeric value
    # are both reported as a client error instead of crashing the handler.
    # NOTE(review): the pipeline may restrict the accepted range of strength;
    # that check, if any, lives in modelpipeline and is not visible here.
    try:
        strength = float(request.form.get("strength", ""))
    except ValueError:
        return "Form field 'strength' must be a number", 400

    upload = request.files.get("image")
    if upload is None:
        return "Missing required file upload: image", 400
    try:
        input_image = Image.open(upload)
    except OSError:
        # Pillow raises an OSError subclass when it cannot identify the file.
        return "Uploaded file is not a readable image", 400

    do_cfg = True  # Example value, modify as needed
    cfg_scale = 8  # Example value, modify as needed
    sampler = "ddpm"  # Example value, modify as needed
    num_inference_steps = 50  # Example value, modify as needed
    seed = 42  # Example value, modify as needed

    output_image = modelpipeline.generate(
        prompt=prompt,
        uncond_prompt=uncond_prompt,
        input_image=input_image,
        strength=strength,
        do_cfg=do_cfg,
        cfg_scale=cfg_scale,
        sampler_name=sampler,
        n_inference_steps=num_inference_steps,
        seed=seed,
        models=models,
        device=DEVICE,
        idle_device="cpu",
        tokenizer=tokenizer,
    )

    # Encode the result into an in-memory PNG and rewind the buffer so that
    # send_file streams it from the first byte.
    output_buffer = BytesIO()
    output_image.save(output_buffer, format="PNG")
    output_buffer.seek(0)

    return send_file(output_buffer, mimetype="image/png")

if __name__ == "__main__":
    # Start Flask's development server when this file is run as a script.
    # debug=True keeps the interactive debugger, which lets anyone who can
    # reach the server execute code; do not expose it beyond localhost.
    # The auto-reloader is disabled because it re-executes this module in a
    # child process, which would load the model weights a second time.
    app.run(debug=True, use_reloader=False)


# Notebook output:
#   Using device: cpu
#   100%|██████████| 10/10 [01:01<00:00,  6.12s/it]
