# In [None]:
from flask import Flask, send_file, request
import torch
from PIL import Image
from io import BytesIO
import torchvision.transforms as transforms
from model.custom.net import ConvAutoencoder

# Release memory held by PyTorch's CUDA caching allocator in this process
# (useful when this cell is re-run inside a long-lived notebook kernel).
torch.cuda.empty_cache()

app = Flask(__name__)

# Build the denoising autoencoder and restore its trained weights.
# map_location='cpu' makes the checkpoint load no matter which device it was
# saved from (e.g. a GPU index that does not exist on this machine) and avoids
# a transient second copy of the weights on the GPU; the model is moved to the
# GPU explicitly below.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model = ConvAutoencoder()
model.load_state_dict(torch.load('save/custom/model_1.pkl', map_location='cpu'))
model = model.cuda()
# Evaluation mode: affects layers such as dropout / batch-norm, if the model has any.
model.eval()

# PIL image -> float tensor; no normalisation is applied, so the model sees
# values in the range ToTensor produces.
transform = transforms.Compose([
    transforms.ToTensor(),
])


def serve_pil_image(pil_img):
    """Encode ``pil_img`` as a quality-100 JPEG and send it as a file download.

    The response has MIME type ``image/jpeg`` and is offered to the client as
    an attachment named ``denoised.jpg``.
    """
    buffer = BytesIO()
    pil_img.save(buffer, format='JPEG', quality=100)
    # Rewind so send_file streams the encoded bytes from the start.
    buffer.seek(0)
    response = send_file(
        buffer,
        mimetype='image/jpeg',
        as_attachment=True,
        download_name='denoised.jpg',
    )
    return response


@app.route("/")
def main():
    """Serve the front-end page ``index.html`` from the app's static folder."""
    page = 'index.html'
    return app.send_static_file(page)


@app.route("/denoise", methods=['POST'])
def denoise():
    """Denoise an uploaded image and return the result as a JPEG download.

    Expects a multipart form upload with the file under the ``image`` field.
    The image is converted to RGB, resized so both sides are multiples of 8,
    passed through ``model`` on the GPU, and returned via ``serve_pil_image``.

    Returns:
        The JPEG attachment response on success, or a ``(message, 400)``
        response if the upload cannot be decoded as an image or is smaller
        than 8 pixels on either side.
    """
    upload = request.files['image']
    try:
        image = Image.open(BytesIO(upload.read()))
        # Force 3-channel RGB: uploads may be grayscale, palette or RGBA, and
        # the output path below (uint8 HWC array -> JPEG) only works for RGB.
        # convert() also forces the pixel data to be decoded, so corrupt or
        # truncated files are caught here.
        image = image.convert('RGB')
    except OSError:
        # PIL raises UnidentifiedImageError (a subclass of OSError) for
        # non-image data, and OSError for truncated/corrupt image files.
        return "Uploaded file is not a valid image.", 400

    app.logger.debug("denoise: input size %s", image.size)
    # Shrink each side down to a multiple of 8 -- presumably the autoencoder's
    # total down-sampling factor, so the decoder output matches the input size.
    width, height = (8 * (side // 8) for side in image.size)
    if width == 0 or height == 0:
        return "Image must be at least 8x8 pixels.", 400
    image = image.resize((width, height))
    app.logger.debug("denoise: resized to %s", image.size)

    # Add a batch dimension and move to the GPU the model lives on.
    batch = transform(image).unsqueeze(0).cuda()
    with torch.no_grad():
        output = model(batch)

    # Clamp to [0, 1] before quantising: out-of-range model outputs would
    # otherwise overflow when cast to uint8. Round to the nearest level
    # instead of truncating. CHW -> HWC for PIL.
    pixels = (
        output.squeeze(0)
        .clamp(0.0, 1.0)
        .mul(255)
        .round()
        .to(dtype=torch.uint8)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )
    return serve_pil_image(Image.fromarray(pixels))


if __name__ == "__main__":
    # Launch Flask's built-in development server on port 8080.
    app.run(port=8080)