# Install Dependencies

In [None]:
# Install dependencies

!pip install pyngrok torchsde einops diffusers accelerate xformers==0.0.27.post2
!apt -y install -qq aria2

# Clone the Repository

In [None]:
# Clone the `totoro3` branch of the camenduru/ComfyUI fork into /content/TotoroUI
# and make that checkout the working directory. The later cells use relative
# paths such as "models/unet", and presumably rely on this cwd so that the
# `nodes`, `totoro` and `totoro_extras` imports resolve from the checkout — confirm.
# NOTE(review): on a re-run, `git clone` reports that the destination already
# exists and leaves the existing checkout untouched; the cell still continues.
%cd /content
!git clone -b totoro3 https://github.com/camenduru/ComfyUI /content/TotoroUI
%cd /content/TotoroUI

# Download Model Files

In [None]:

import subprocess

# (download URL, target directory relative to /content/TotoroUI, output filename)
model_urls = [
    ("https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors", "models/unet", "flux1-schnell.safetensors"),
    ("https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/ae.sft", "models/vae", "ae.sft"),
    ("https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/clip_l.safetensors", "models/clip", "clip_l.safetensors"),
    ("https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/t5xxl_fp8_e4m3fn.safetensors", "models/clip", "t5xxl_fp8_e4m3fn.safetensors")
]

# Download each file with aria2c (-c: continue partial downloads; -x/-s 16:
# up to 16 connections/splits; -k 1M: minimum split size).
# Unlike `!aria2c`, a non-zero exit status here raises CalledProcessError
# immediately. Otherwise a failed download would only surface later as a
# confusing "file not found" when the model loaders run.
for url, directory, filename in model_urls:
    completed = subprocess.run(
        [
            "aria2c", "--console-log-level=error", "-c",
            "-x", "16", "-s", "16", "-k", "1M",
            url,
            "-d", f"/content/TotoroUI/{directory}",
            "-o", filename,
        ],
        capture_output=True,
        text=True,
    )
    # Echo aria2c's output into the cell so progress/errors stay visible.
    print(completed.stdout, completed.stderr, sep="")
    completed.check_returncode()

# Import Libraries and Load Models

In [None]:
import random
import torch
import numpy as np
from PIL import Image
import base64
from io import BytesIO
import nodes
from nodes import NODE_CLASS_MAPPINGS
from totoro_extras import nodes_custom_sampler
from totoro import model_management
from flask import Flask, request, jsonify
from pyngrok import ngrok

# Load model components
# Instantiate the node classes used by the generation pipeline. Each entry in
# the mappings is a class; calling it yields the node object whose methods
# (load_clip, sample, decode, ...) are used below and in the /generate handler.
DualCLIPLoader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
UNETLoader = NODE_CLASS_MAPPINGS["UNETLoader"]()
RandomNoise = nodes_custom_sampler.NODE_CLASS_MAPPINGS["RandomNoise"]()
BasicGuider = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicGuider"]()
KSamplerSelect = nodes_custom_sampler.NODE_CLASS_MAPPINGS["KSamplerSelect"]()
BasicScheduler = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicScheduler"]()
SamplerCustomAdvanced = nodes_custom_sampler.NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
VAELoader = NODE_CLASS_MAPPINGS["VAELoader"]()
VAEDecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
EmptyLatentImage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()

# Load the text encoders (T5-XXL fp8 + CLIP-L), the FLUX.1-schnell UNet
# (requested with the "fp8_e4m3fn" weight dtype) and the VAE once. The
# request handler reuses these module-level objects.
# Node methods return tuples, hence the trailing [0].
# The Flask app is created in the "Define Flask App" cell, next to its routes.
with torch.inference_mode():
    clip = DualCLIPLoader.load_clip("t5xxl_fp8_e4m3fn.safetensors", "clip_l.safetensors", "flux")[0]
    unet = UNETLoader.load_unet("flux1-schnell.safetensors", "fp8_e4m3fn")[0]
    vae = VAELoader.load_vae("ae.sft")[0]

def closestNumber(n, m):
    """Return the multiple of ``m`` that is nearest to ``n``.

    The two candidates are ``m`` times the truncated quotient ``n / m`` (the
    one nearer zero), and the next multiple one step further from zero in
    ``n``'s direction. On an exact tie, the candidate farther from zero wins.
    """
    quotient = int(n / m)
    toward_zero = m * quotient
    step = 1 if (n * m) > 0 else -1
    away_from_zero = m * (quotient + step)
    if abs(n - toward_zero) < abs(n - away_from_zero):
        return toward_zero
    return away_from_zero

# Define Flask App

In [8]:


# Create the app in the same cell that registers '/': re-running this cell then
# starts from a fresh app instead of re-registering the route on the old one
# (Flask refuses to map an existing endpoint name to a new function).
# After re-running this cell, re-run the /generate cell as well.
app = Flask(__name__)

@app.route('/')
def index():
    """Serve the single-page image-generation UI.

    The page posts the form values as JSON to ``/generate``. It shows the
    returned image and seed (or the server's error message), offers a download
    link, and keeps a clickable history of generated images in the right-hand
    column. All CSS and JavaScript are inline.
    """
    return '''
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Generation Interface</title>
    <link href="https://fonts.googleapis.com/css2?family=Poppins:wght@400;600&display=swap" rel="stylesheet">
    <style>
        body {
            font-family: 'Poppins', sans-serif;
            margin: 0;
            padding: 10px;
            background-color: #1c1c1c;
            color: #ffffff;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
        }
        .main-container {
            display: grid;
            grid-template-columns: 300px 1fr 300px;
            gap: 20px;
            width: 100%;
            max-width: 1400px;
        }
        .form-container {
            padding: 10px;
            box-shadow: 0 0 15px rgba(0, 0, 0, 0.6);
            background-color: #2a2a2a;
            border-radius: 12px;
            height: 100%;
        }
        .form-container h1 {
            font-size: 20px;
            margin-bottom: 10px;
            color: #ff8c00;
            text-align: center;
        }
        .form-container label {
            font-size: 14px;
            margin-top: 5px;
            display: block;
            color: #ff8c00;
        }
        .form-container textarea, .form-container input, .form-container button, .form-container select {
            width: 100%;
            margin: 5px 0;
            padding: 5px;
            font-size: 14px;
            border: none;
            border-radius: 8px;
            background-color: #444;
            color: #fff;
            box-sizing: border-box;
        }
        .form-container button {
            background-color: #ff8c00;
            cursor: pointer;
            transition: background-color 0.3s;
        }
        .form-container button:hover {
            background-color: #e07b00;
        }
        .image-container {
            display: flex;
            flex-direction: column;
            justify-content: center;
            align-items: center;
            background-color: #2a2a2a;
            border-radius: 12px;
            padding: 10px;
            box-shadow: 0 0 15px rgba(0, 0, 0, 0.6);
        }
        .image-container img {
            max-width: 100%;
            max-height: 600px;
            border-radius: 12px;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
            margin-bottom: 10px;
        }
        .image-info {
            color: #ff8c00;
            font-size: 14px;
            text-align: center;
            margin-bottom: 10px;
        }
        .image-info p {
            margin: 0;
        }
        .download-button {
            background-color: #ff8c00;
            color: #fff;
            padding: 10px 20px;
            border-radius: 8px;
            text-decoration: none;
            font-weight: bold;
            transition: background-color 0.3s;
            display: none;
        }
        .download-button:hover {
            background-color: #e07b00;
        }
        .generated-images-section {
            padding: 10px;
            box-shadow: 0 0 15px rgba(0, 0, 0, 0.6);
            background-color: #2a2a2a;
            border-radius: 12px;
            height: 100%;
            overflow-y: auto;
        }
        .generated-images-section h2 {
            font-size: 18px;
            color: #ff8c00;
            margin-bottom: 10px;
            text-align: center;
        }
        .generated-images-container {
            display: flex;
            flex-direction: column;
            gap: 10px;
        }
        .generated-image-card {
            width: 100%;
            background-color: #333;
            border-radius: 8px;
            overflow: hidden;
            box-shadow: 0 0 5px rgba(0, 0, 0, 0.5);
            cursor: pointer;
            transition: transform 0.3s;
        }
        .generated-image-card:hover {
            transform: scale(1.05);
        }
        .generated-image-card img {
            width: 100%;
            height: auto;
            border-bottom: 1px solid #444;
        }
        .generated-image-card .prompt {
            padding: 5px;
            font-size: 12px;
            color: #ff8c00;
            text-align: center;
        }
        .loading-overlay {
            display: none;
            position: fixed;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            background-color: rgba(0, 0, 0, 0.75);
            z-index: 1000;
            justify-content: center;
            align-items: center;
            flex-direction: column;
        }
        .spinner {
            width: 60px;
            height: 60px;
            border: 8px solid #f3f3f3;
            border-top: 8px solid #ff8c00;
            border-radius: 50%;
            animation: spin 1s linear infinite;
        }
        @keyframes spin {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }
        #imageModal {
            display: none;
            position: fixed;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            background-color: rgba(0, 0, 0, 0.8);
            z-index: 1000;
            justify-content: center;
            align-items: center;
            flex-direction: column;
        }
        #imageModal img {
            max-width: 90%;
            max-height: 90%;
        }
        #imageModal .prompt {
            color: #ffffff;
            margin-top: 10px;
            text-align: center;
        }
    </style>
</head>
<body>
    <div class="main-container">
        <div class="form-container">
            <h1>Generate Image</h1>
            <form id="imageForm">
                <label for="prompt">Prompt</label>
                <textarea id="prompt" placeholder="Enter your prompt here" rows="4"></textarea>

                <label for="width">Width</label>
                <input type="number" id="width" placeholder="Width" value="800">

                <label for="height">Height</label>
                <input type="number" id="height" placeholder="Height" value="480">

                <label for="seed">Seed</label>
                <input type="number" id="seed" placeholder="Seed" value="0">

                <label for="steps">Steps</label>
                <input type="number" id="steps" placeholder="Steps" value="4">

                <label for="sampler">Sampler</label>
                <select id="sampler">
                    <option value="euler">Euler</option>
                    <option value="ddim">DDIM</option>
                    <option value="heun">Heun</option>
                    <option value="dpmpp_2m">DPM++ 2M</option>
                </select>

                <label for="scheduler">Scheduler</label>
                <input type="text" id="scheduler" placeholder="Scheduler" value="simple">

                <button type="button" onclick="generateImage()">Generate Image</button>
            </form>
        </div>
        <div class="image-container">
            <img id="generatedImage" alt="Generated Image" hidden>
            <div class="image-info" id="imageInfo"></div>
            <a id="downloadButton" class="download-button" href="#" download>Download Image</a>
        </div>
        <div class="generated-images-section">
            <h2>Generated Images</h2>
            <div id="generatedImagesContainer" class="generated-images-container"></div>
        </div>
    </div>
    <div id="loadingOverlay" class="loading-overlay">
        <div class="spinner"></div>
    </div>
    <div id="imageModal">
        <img src="" alt="Full Size Image">
        <div class="prompt"></div>
    </div>
    <script>
        async function generateImage() {
            const prompt = document.getElementById('prompt').value;
            const width = document.getElementById('width').value;
            const height = document.getElementById('height').value;
            const seed = document.getElementById('seed').value;
            const steps = document.getElementById('steps').value;
            const sampler = document.getElementById('sampler').value;
            const scheduler = document.getElementById('scheduler').value;

            const loadingOverlay = document.getElementById('loadingOverlay');
            const downloadButton = document.getElementById('downloadButton');
            const generatedImage = document.getElementById('generatedImage');
            const imageInfoEl = document.getElementById('imageInfo');

            loadingOverlay.style.display = 'flex';

            downloadButton.style.display = 'none';

            try {
                const response = await fetch('/generate', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ prompt, width, height, seed, steps, sampler, scheduler })
                });

                let data;
                try {
                    data = await response.json();
                } catch (parseError) {
                    throw new Error(`Server returned ${response.status} ${response.statusText}`);
                }
                if (!response.ok) {
                    throw new Error((data && data.error) || `Server returned ${response.status}`);
                }

                const imageUrl = data.imageUrl;
                const seedValue = data.seed;
                generatedImage.src = imageUrl;
                generatedImage.hidden = false;

                const imageInfo = `Seed: ${seedValue}, Steps: ${steps}, Sampler: ${sampler}, Scheduler: ${scheduler}`;
                imageInfoEl.innerText = imageInfo;

                downloadButton.href = imageUrl;
                downloadButton.download = `image_${Date.now()}.png`;
                downloadButton.style.display = 'inline-block';

                addGeneratedImageCard(imageUrl, prompt);
            } catch (error) {
                console.error('Error generating image:', error);
                imageInfoEl.innerText = `Error: ${error.message}`;
            } finally {
                loadingOverlay.style.display = 'none';
            }
        }

        function addGeneratedImageCard(imageUrl, prompt) {
            const container = document.getElementById('generatedImagesContainer');

            const card = document.createElement('div');
            card.className = 'generated-image-card';

            const img = document.createElement('img');
            img.src = imageUrl;
            card.appendChild(img);

            const promptText = document.createElement('div');
            promptText.className = 'prompt';
            promptText.innerText = prompt;
            card.appendChild(promptText);

            card.addEventListener('click', () => {
                openImageModal(imageUrl, prompt);
            });

            container.appendChild(card);
        }

        function openImageModal(imageUrl, prompt) {
            const modal = document.getElementById('imageModal');
            const modalImg = modal.querySelector('img');
            const modalPrompt = modal.querySelector('.prompt');

            modalImg.src = imageUrl;
            modalPrompt.innerText = prompt;

            modal.style.display = 'flex';
        }

        function closeImageModal() {
            const modal = document.getElementById('imageModal');
            modal.style.display = 'none';
        }

        document.addEventListener('DOMContentLoaded', function() {
            const loadingOverlay = document.getElementById('loadingOverlay');
            loadingOverlay.style.display = 'none';

            const modal = document.getElementById('imageModal');
            modal.addEventListener('click', closeImageModal);
        });
    </script>
</body>
</html>
    '''


# Define the Generate Endpoint

In [9]:
@app.route('/generate', methods=['POST'])
def generate():
    """Run the FLUX pipeline for one request and return the image as a data URL.

    Request body (JSON): ``prompt``, ``width``, ``height``, ``seed``, ``steps``,
    ``sampler`` and ``scheduler``. Numeric fields may arrive as strings (the
    bundled page sends raw ``<input>`` values) and are converted with ``int``.
    A ``seed`` of 0 means "pick a random seed".

    Returns:
        200 with ``{"imageUrl": "data:image/png;base64,...", "seed": <int>}``.
        400 with ``{"error": ...}`` for a missing or invalid field.
        500 with ``{"error": ...}`` if the pipeline itself fails (for example an
        unrecognized sampler or scheduler name), so the client always receives JSON.
    """
    # silent=True: a missing/non-JSON body yields None (-> {}) and is reported
    # as a 400 below rather than an HTML error page the client cannot parse.
    data = request.get_json(silent=True) or {}
    try:
        prompt = data['prompt']
        width = int(data['width'])
        height = int(data['height'])
        seed = int(data['seed'])
        steps = int(data['steps'])
        sampler_name = data['sampler']
        scheduler = data['scheduler']
    except (KeyError, TypeError, ValueError) as exc:
        return jsonify({"error": f"Invalid request field: {exc!r}"}), 400

    if steps < 1:
        return jsonify({"error": "steps must be at least 1"}), 400
    if not 0 <= seed <= 0xFFFFFFFFFFFFFFFF:
        return jsonify({"error": "seed must be between 0 and 2**64 - 1"}), 400

    if seed == 0:
        # Keep random seeds <= 2**53 - 1 (JavaScript's Number.MAX_SAFE_INTEGER).
        # Larger integers are rounded when the browser parses the JSON, so the
        # displayed seed could not reproduce the image.
        seed = random.randint(1, 2**53 - 1)
    print(f"Seed: {seed}")

    # The requested size is rounded to a multiple of 16. Clamp the result so
    # tiny or negative inputs cannot round to an empty (0-sized) latent.
    latent_width = max(16, closestNumber(width, 16))
    latent_height = max(16, closestNumber(height, 16))

    try:
        with torch.inference_mode():
            cond, pooled = clip.encode_from_tokens(clip.tokenize(prompt), return_pooled=True)
            cond = [[cond, {"pooled_output": pooled}]]
            print("Prompt encoded with CLIP model")

            noise = RandomNoise.get_noise(seed)[0]
            guider = BasicGuider.get_guider(unet, cond)[0]
            print("Noise and guider created")

            sampler = KSamplerSelect.get_sampler(sampler_name)[0]
            # The trailing 1.0 is passed as the scheduler's denoise amount
            # (full generation from noise) — presumably; confirm against BasicScheduler.
            sigmas = BasicScheduler.get_sigmas(unet, scheduler, steps, 1.0)[0]
            latent_image = EmptyLatentImage.generate(latent_width, latent_height)[0]
            print("Model components loaded and initialized")

            # Only the first output (the sampled latent) is decoded.
            sample, _ = SamplerCustomAdvanced.sample(noise, guider, sampler, sigmas, latent_image)
            model_management.soft_empty_cache()
            print("Image sampling completed")

            decoded = VAEDecode.decode(vae, sample)[0].detach()
            # [0]: take the first image of the decoded batch.
            img = Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0])
            print("Image generated and decoded")
    except Exception as exc:
        # Report pipeline failures as JSON so the page can show them.
        app.logger.exception("Image generation failed")
        return jsonify({"error": f"Image generation failed: {exc}"}), 500

    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    print("Image processing complete")

    return jsonify({
        "imageUrl": f"data:image/png;base64,{img_str}",
        "seed": seed
    })


# Start ngrok and Run Flask App

In [None]:

import getpass
import os

# Read the ngrok auth token from the NGROK_AUTH_TOKEN environment variable, or
# prompt for it, instead of hard-coding a secret in the notebook (where it would
# be saved and shared along with the .ipynb).
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN") or getpass.getpass("ngrok auth token: ")
FLASK_PORT = 5000

ngrok.set_auth_token(NGROK_AUTH_TOKEN)
# Open a public tunnel to the local Flask port and print its URL (the tunnel
# object's repr would also include the local address).
public_url = ngrok.connect(FLASK_PORT)
print(f"Public URL: {public_url.public_url}")

# app.run serves requests in the foreground; this cell keeps running until
# the kernel is interrupted.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=FLASK_PORT)