In [1]:
import numpy as np

import torchvision.transforms as transforms
from torchvision.utils import make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from torch.utils.data import Dataset

import torchvision.transforms as transforms

import os
import glob
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import warnings

import base64

In [2]:
from flask import Flask, request, jsonify
import torch
from torchvision import transforms
from PIL import Image
import io
from flask_cors import CORS

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1), # Pads the input tensor using the reflection of the input boundary
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features), 
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        )

    def forward(self, x):
        return x + self.block(x)


class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_block):
        super(GeneratorResNet, self).__init__()
        
        channels = input_shape[0]
        
        # Initial Convolution Block
        out_features = 64
        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True)
        ]
        in_features = out_features
        
        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
        
        # Residual blocks
        for _ in range(num_residual_block):
            model += [ResidualBlock(out_features)]
            
        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2), # --> width*2, heigh*2
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            
        # Output Layer
        model += [nn.ReflectionPad2d(channels),
                  nn.Conv2d(out_features, channels, 7),
                  nn.Tanh()
                 ]
        
        # Unpacking
        self.model = nn.Sequential(*model) 
        
    def forward(self, x):
        return self.model(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GeneratorResNet(input_shape=(3, 256, 256), num_residual_block=9)
model.load_state_dict(torch.load('BA_epoch_90.pth', map_location=device))
model.to(device)
model.eval()
app = Flask(__name__)
CORS(app)
@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        try:
            file = Image.open(io.BytesIO(request.files['file'].read()))

            transform = transforms.Compose([
                #transforms.Resize((256, 256)),
                transforms.ToTensor(),
                # Normalize to [-1, 1]
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])

            # Transform the image
            image = transform(file).unsqueeze(0).to(device)

            # Disable gradient calculation
            with torch.no_grad():
                print("Image tensor shape:", image.shape)
                output = model(image)
                print("Output tensor shape:", output.shape)
                print("Output min value:", torch.min(output))
                print("Output max value:", torch.max(output))

            # Convert tensor to PIL --> Byte Array
            output_image = transforms.ToPILImage()(output.squeeze().cpu().detach())
            byte_arr = io.BytesIO()
            output_image.save(byte_arr, format='JPEG')
            byte_arr = byte_arr.getvalue()
            img_str = base64.b64encode(byte_arr).decode()
            print("Base64 image string:", img_str)

            return jsonify({'result': img_str})

        except Exception as e:
            print("Error:", e)
            return jsonify({'error': str(e)}), 400
        
       
        
@app.route('/', methods=['GET'])    
def home():
    return "hello, world!"
    
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.26.91.183:5000
[33mPress CTRL+C to quit[0m


Image tensor shape: torch.Size([1, 3, 800, 800])


127.0.0.1 - - [09/Jun/2023 13:05:47] "POST /predict HTTP/1.1" 200 -


Output tensor shape: torch.Size([1, 3, 800, 800])
Output min value: tensor(-0.9713)
Output max value: tensor(0.9697)
Base64 image string: /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAMgAyADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwBl7fywu2ANwfAYdCPpVKbV5gFyAPwNTa