In [1]:
# Install required dependencies
!pip install flask pillow requests torch diffusers transformers accelerate




[notice] A new release of pip is available: 23.0.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
# Packages to be used
from flask import Flask, request, jsonify
import requests
from io import BytesIO
from PIL import Image
import torch
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from diffusers.utils import load_image
import logging

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Create app for api using Flask
app = Flask(__name__)
SERVER_PORT = 5000
SERVER_HOST = "0.0.0.0"

In [4]:
# Set up logging
logging.basicConfig(level=logging.DEBUG)

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# model_id = "CompVis/stable-diffusion-v1-4"
model_id = "runwayml/stable-diffusion-v1-5"
# pipeline = StableDiffusionPipeline.from_pretrained(model_id).to(device)
pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(model_id).to(device)


DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /api/models/runwayml/stable-diffusion-v1-5 HTTP/11" 200 6423
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /runwayml/stable-diffusion-v1-5/resolve/main/model_index.json HTTP/11" 200 0
Loading pipeline components...: 100%|██████████| 7/7 [00:02<00:00,  3.19it/s]


In [6]:
# Define function to treat the image and apply the aging function
def age_image(input_image: Image.Image) -> Image.Image:

    image_to_convert = input_image.convert("RGB")
    image_to_convert = image_to_convert.resize((768,512))
    # Aging transfer model code here
    # prompt = "Aged face, old person, wrinkles, gray hair"
    prompt = "A detailed close-up portrait of the same person but aged 30 years, with wrinkles, gray hair, age spots, and a mature, wise expression, realistic, high resolution"

    result = pipeline(prompt, image=image_to_convert, strength=0.75, guidance_scale=7.5)
    aged_image = result.images[0]
    return aged_image

In [7]:
# The main Flask app implementation
# Get image from URL, apply algorithm and save it to
@app.route('/age_picture', methods=['POST'])
def age_picture():
    image_url = request.form.get('url')

    if not image_url:
        app.logger.debug("No Image URL found")
        return jsonify({"error": "No Image URL found"}), 400

    try:
        response = requests.get(image_url)
        response.raise_for_status()
    except requests.RequestException as e:
        app.logger.debug(f"Error fetching image: {e}")
        return jsonify({"error": str(e)}), 400

    try:
        input_image = Image.open(BytesIO(response.content))
    except Exception as e:
        app.logger.debug(f"Error opening image, Invalid image format: {e}")
        return jsonify({"error": "Error opening image, Invalid image format"}), 400

    try:
        aged_image = age_image(input_image)
    except Exception as e:
        app.logger.debug(f"Error processing image: {e}")
        return jsonify({"error": "Error processing image", "pb":e}), 500

    output_path = "aged_image.jpg"
    
    try:
        aged_image.save(output_path)
    except Exception as e:
        app.logger.debug(f"Error saving image: {e}")
        return jsonify({"error": "Error saving image"}), 500

    return jsonify({"path": output_path})

In [8]:
# Launch the flask app server
def run_flask():
    app.run(host=SERVER_HOST, port=SERVER_PORT)

In [9]:
import threading

# Start flask server in a thread
thread = threading.Thread(target=run_flask)
thread.start()

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://192.168.1.113:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m


# TEST

In [10]:
# Test model using our api
endpoint_url = "http://127.0.0.1:5000/age_picture"
image_url = "https://img.freepik.com/free-photo/portrait-white-man-isolated_53876-40306.jpg"

form_data = {
    'url': image_url
}

response = requests.post(endpoint_url, data=form_data)

print(response.json())

DEBUG:urllib3.connectionpool:Starting new HTTP connection (1): 127.0.0.1:5000
DEBUG:urllib3.connectionpool:http://127.0.0.1:5000 "POST /age_picture HTTP/11" 200 26


DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): img.freepik.com:443
DEBUG:urllib3.connectionpool:https://img.freepik.com:443 "GET /free-photo/portrait-white-man-isolated_53876-40306.jpg HTTP/11" 200 43027
100%|██████████| 37/37 [29:35<00:00, 48.00s/it]
INFO:werkzeug:127.0.0.1 - - [04/Aug/2024 20:18:20] "POST /age_picture HTTP/1.1" 200 -


{'path': 'aged_image.jpg'}
