In [1]:
from flask import Flask, request, send_file
from PIL import Image
import torch
import torchvision.transforms as transforms
import io


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Load the pretrained generator once at start-up so every request reuses it.
# `Generator` is defined in the local module `autoencoder_gen.py`.
from autoencoder_gen import Generator

# NOTE(review): the arguments presumably mean (in_channels=1, out_channels=3, features=32),
# matching the 1-channel input built in `generate` and the 3-channel RGB output that is
# turned into a PNG — confirm against autoencoder_gen.Generator.
generator = Generator(1, 3, 32)

# torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
generator.load_state_dict(torch.load('generator_autoencoder_unet_128_state.pt', map_location='cpu'))

# eval() puts any dropout / batch-norm layers into inference mode; it returns the module itself.
generator = generator.eval()

In [3]:
# Side length (pixels) that uploads are resized to before inference.
IMAGE_SIZE = 128
# Number of channels in the RGB image normalized by preprocess_image.
CHANNELS_IMG = 3

# Flask application exposing the /generate endpoint.
app = Flask(__name__)

@app.route('/generate', methods=['POST'])
def generate():
    """Run the generator on a grayscale version of an uploaded image.

    Expects a multipart/form-data POST with the picture in the ``image`` file
    field. The picture is converted to RGB, resized and normalized by
    ``preprocess_image``, collapsed to one channel by averaging R, G and B,
    passed through ``generator``, and the result is returned as a PNG.

    Returns:
        A PNG ``send_file`` response on success, or a plain-text 400 response
        when the ``image`` field is missing or cannot be decoded as an image.
    """
    upload = request.files.get('image')
    if upload is None:
        return 'Missing "image" file in the request.', 400

    try:
        # Image.open is lazy; convert() loads the pixel data, so corrupt or
        # truncated uploads fail here (-> 400) rather than later as a 500.
        # It also normalizes every mode (L, LA, P, RGBA, CMYK, ...) to 3-channel
        # RGB, which preprocess_image's 3-channel Normalize requires; for an
        # image that is already RGB it returns an identical copy.
        image = Image.open(upload).convert('RGB')
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError is a subclass of OSError.
        return 'Uploaded file is not a readable image.', 400

    input_image = preprocess_image(image)
    # Average the normalized R, G, B planes into a (1, 1, H, W) input. The same
    # linear Normalize is applied to every channel, so this equals normalizing
    # the unweighted-mean grayscale image.
    gray_image = input_image.mean(dim=1, keepdim=True)

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        output_image = generator(gray_image)

    output_array = postprocess_output(output_image)

    # Encode the result as PNG in memory and stream it back to the client.
    output_stream = io.BytesIO()
    Image.fromarray(output_array).save(output_stream, format='PNG')
    output_stream.seek(0)
    return send_file(output_stream, mimetype='image/png')

def preprocess_image(image, image_size=(IMAGE_SIZE, IMAGE_SIZE)):
    """Turn a PIL image into a normalized, batched tensor for the generator.

    Args:
        image: PIL image with CHANNELS_IMG channels (the caller converts to RGB,
            since the Normalize step below has exactly CHANNELS_IMG entries).
        image_size: target size passed to ``transforms.Resize``.

    Returns:
        Tensor with a leading batch dimension of 1, i.e. (1, C, H, W).
    """
    channel_mean = [0.5] * CHANNELS_IMG
    channel_std = [0.5] * CHANNELS_IMG

    steps = [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output onto [-1, 1].
        transforms.Normalize(channel_mean, channel_std),
    ]
    pipeline = transforms.Compose(steps)

    tensor = pipeline(image)
    return tensor.unsqueeze(0)

def postprocess_output(output_tensor):
    """Convert a generator output tensor into an 8-bit HWC image array.

    Undoes the [-1, 1] scaling used by ``preprocess_image`` by mapping
    [-1, 1] onto [0, 255].

    Args:
        output_tensor: tensor of shape (1, C, H, W) (or (C, H, W)) with values
            expected in [-1, 1]; anything outside that range is clipped.

    Returns:
        numpy.ndarray of dtype uint8 and shape (H, W, C), suitable for
        ``PIL.Image.fromarray``.
    """
    image = output_tensor.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0)
    # Round to the nearest integer before casting: astype('uint8') truncates,
    # which biased every pixel downward (e.g. 253.7 -> 253).
    pixels = ((image + 1) / 2 * 255).round()
    return pixels.clip(0, 255).astype('uint8')



In [None]:
if __name__ == '__main__':
    import os

    # Runs Flask's built-in development server (it warns against production
    # use at startup). The defaults keep the previous behavior: listen on all
    # addresses ('0.0.0.0') at port 8000. Both can be overridden through
    # GENERATOR_HOST / GENERATOR_PORT without editing the notebook.
    host = os.environ.get('GENERATOR_HOST', '0.0.0.0')
    port = int(os.environ.get('GENERATOR_PORT', '8000'))
    app.run(host=host, port=port)

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on all addresses.
 * Running on http://192.168.178.13:8000/ (Press CTRL+C to quit)
127.0.0.1 - - [07/May/2023 17:10:27] "POST /generate HTTP/1.1" 200 -
127.0.0.1 - - [07/May/2023 17:10:38] "POST /generate HTTP/1.1" 200 -
127.0.0.1 - - [07/May/2023 17:10:57] "POST /generate HTTP/1.1" 200 -
127.0.0.1 - - [07/May/2023 17:11:45] "POST /generate HTTP/1.1" 200 -
127.0.0.1 - - [07/May/2023 17:13:11] "POST /generate HTTP/1.1" 200 -
127.0.0.1 - - [07/May/2023 22:43:51] "POST /generate HTTP/1.1" 200 -
127.0.0.1 - - [07/May/2023 22:49:55] "POST /generate HTTP/1.1" 200 -
127.0.0.1 - - [07/May/2023 22:52:40] "POST /generate HTTP/1.1" 200 -
127.0.0.1 - - [07/May/2023 22:52:44] "POST /generate HTTP/1.1" 200 -
127.0.0.1 - - [07/May/2023 22:52:50] "POST /generate HTTP/1.1" 200 -
