In [1]:
import numpy as np

import torchvision.transforms as transforms
from torchvision.utils import make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from torch.utils.data import Dataset

import torchvision.transforms as transforms

import os
import glob
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import warnings

import base64
import json

In [2]:
from flask import Flask, request, jsonify
import torch
from torchvision import transforms
from PIL import Image
import io
from flask_cors import CORS

In [3]:
def Munchify(model, input_image):
    """Run a generator over one image and return the result as a base64 JPEG.

    Args:
        model: Callable generator (typically an ``nn.Module``). It is called
            with a ``(1, 3, H, W)`` float tensor normalised to ``[-1, 1]`` and
            is expected to return a tensor in the same layout and value range.
        input_image: Image to stylise. PIL images of any mode are converted to
            RGB first; any other input accepted by ``transforms.ToTensor`` is
            passed through unchanged.

    Returns:
        str: The generated image, JPEG-encoded and then base64-encoded as text.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Normalize() above is hard-wired to three channels, so RGBA / grayscale /
    # palette uploads must be converted to RGB before being transformed.
    image = input_image
    if isinstance(image, Image.Image):
        image = image.convert('RGB')
    image = transform(image).unsqueeze(0)

    # Run the input on whatever device the model's weights live on.
    first_param = next(iter(model.parameters()), None) if hasattr(model, 'parameters') else None
    if first_param is not None:
        image = image.to(first_param.device)

    # Inference only: skip autograd bookkeeping while running the generator.
    with torch.no_grad():
        image = model(image).detach()

    # Drop only the batch axis (a bare squeeze() would also drop H or W == 1)
    # and bring the result back to host memory before converting to ndarray.
    image = image.squeeze(0).cpu().numpy()

    # Map [-1, 1] -> [0, 255]. Scaling by 256 would turn an output of exactly
    # +1.0 into 256, which wraps to 0 when cast to uint8; round and clip instead.
    image = np.clip(np.rint((image + 1) / 2 * 255), 0, 255)
    image = image.astype(np.uint8).transpose(1, 2, 0)  # CHW -> HWC
    img = Image.fromarray(image, 'RGB')

    # Encode as JPEG in memory, then base64 so it can be embedded in JSON.
    byte_arr = io.BytesIO()
    img.save(byte_arr, format='JPEG')
    img_str = base64.b64encode(byte_arr.getvalue()).decode()

    return img_str

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 conv stages plus a skip connection.

    The channel count and spatial size of the input are preserved, so the
    block's output can be added directly to its input.
    """

    def __init__(self, in_features):
        """Build the block for feature maps with ``in_features`` channels."""
        super().__init__()

        def conv_stage():
            # Reflection padding of 1 keeps the spatial size through the
            # unpadded 3x3 convolution that follows it.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, kernel_size=3),
                nn.InstanceNorm2d(in_features),
            ]

        # stage -> ReLU -> stage; the layer order fixes the state_dict indices.
        layers = conv_stage() + [nn.ReLU(inplace=True)] + conv_stage()
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual


class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Layer sequence: a 7x7 convolution stem, two stride-2 downsampling
    convolutions, ``num_residual_block`` residual blocks, two
    upsample-plus-convolution stages, and a 7x7 output convolution followed by
    ``Tanh``. All layers are stored in ``self.model`` (an ``nn.Sequential``);
    the layer order is kept fixed so that the positional state_dict keys
    (``model.<index>.*``) stay stable.

    Args:
        input_shape: Sequence whose first element is the number of image
            channels; only that element is read here.
        num_residual_block: Number of ``ResidualBlock`` modules placed between
            the downsampling and upsampling stages.
    """

    def __init__(self, input_shape, num_residual_block):
        super().__init__()

        channels = input_shape[0]

        # "Same" padding for the unpadded 7x7 convolutions is 7 // 2 = 3.
        # Using the channel count as the padding width only gives 3 for
        # 3-channel input; for any other channel count it changes the output
        # size.
        edge_padding = 3

        # Initial Convolution Block
        out_features = 64
        model = [
            nn.ReflectionPad2d(edge_padding),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True)
        ]
        in_features = out_features

        # Downsampling: each stage doubles the channels and uses a stride of 2.
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Residual blocks operate at the bottleneck width.
        for _ in range(num_residual_block):
            model += [ResidualBlock(out_features)]

        # Upsampling: each stage doubles height/width and halves the channels.
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Output layer: back to the image channel count, squashed to [-1, 1].
        model += [
            nn.ReflectionPad2d(edge_padding),
            nn.Conv2d(out_features, channels, 7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the full layer stack to ``x`` and return the result."""
        return self.model(x)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Generators loaded from `path`, in sorted filename order.
g_models = []

path = "models/"

# os.listdir() returns entries in arbitrary order; sorting keeps the position
# of each model in g_models the same from run to run and machine to machine.
for model_file in sorted(os.listdir(path)):
    model_path = os.path.join(path, model_file)
    # Skip hidden entries (.DS_Store, .ipynb_checkpoints, ...) and anything
    # that is not a regular file, since torch.load() would fail on them.
    if model_file.startswith(".") or not os.path.isfile(model_path):
        continue
    g_model = GeneratorResNet((3, 256, 256), 9)
    # NOTE(security): torch.load() unpickles the file, which can execute
    # arbitrary code; only place trusted checkpoints in the models directory.
    model_checkpoint = torch.load(model_path, map_location=device)
    g_model.load_state_dict(model_checkpoint['state_dict'])
    g_model.to(device)
    g_model.eval()  # inference mode for every loaded generator
    g_models.append(g_model)
    
# Flask application object; the @app.route decorators below register their
# handlers on it, so it must be created before they run.
app = Flask(__name__)
# Apply flask-cors to the whole app using the extension's default settings.
# NOTE(review): the defaults are presumably permissive (all origins) — confirm
# this is intended before exposing the service beyond local development.
CORS(app)
@app.route('/predict', methods=['POST'])
def predict():
    """Run the uploaded image through every loaded generator.

    Expects a multipart form upload with the image in the ``file`` field.

    Returns:
        On success, a JSON object with one entry per model in ``g_models``,
        keyed ``result1``, ``result2``, ... in list order, each holding the
        base64 string returned by ``Munchify``.
        On failure, ``{"error": <message>}`` with status 400 when the upload
        is missing or cannot be decoded as an image, or status 500 when a
        model fails while processing it.
    """
    # The route only accepts POST, so no method check is needed here.
    uploaded = request.files.get('file')
    if uploaded is None:
        return jsonify({'error': "missing 'file' field in upload"}), 400

    try:
        # Decode eagerly and force RGB so that corrupt files fail here, as a
        # client error, and so the models always receive three channels.
        file = Image.open(io.BytesIO(uploaded.read())).convert('RGB')
    except Exception as e:
        app.logger.warning("Rejected upload: %s", e)
        return jsonify({'error': str(e)}), 400

    try:
        results = {
            f'result{i}': Munchify(model, file)
            for i, model in enumerate(g_models, start=1)
        }
    except Exception as e:
        # A failure at this point is a server-side problem, not a bad request.
        app.logger.exception("Style transfer failed")
        return jsonify({'error': str(e)}), 500

    return jsonify(results)
        
       
        
@app.route('/', methods=['GET'])
def home():
    """Handle GET requests to the root URL with a fixed greeting string."""
    greeting = "hello, world!"
    return greeting
    
# Start Flask's built-in server when this code runs as the main module.
# host='0.0.0.0' binds every network interface, so the service is reachable
# from other machines on port 5000. The server's own startup output recommends
# a production WSGI server for deployment.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on all addresses.
 * Running on http://192.168.105.221:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [15/Jun/2023 09:22:08] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:25:53] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:29:13] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:41:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:43:49] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:45:43] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:47:42] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:49:17] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:51:14] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:53:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:55:18] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:56:57] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun/2023 09:59:10] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [15/Jun