In [1]:
import numpy as np

import torchvision.transforms as transforms
from torchvision.utils import make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

import itertools

from torch.utils.data import Dataset

import torchvision.transforms as transforms

import os
import glob
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import warnings

import base64

In [2]:
from flask import Flask, request, jsonify
import torch
from torchvision import transforms
from PIL import Image
import io
from flask_cors import CORS

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1), # Pads the input tensor using the reflection of the input boundary
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features), 
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        )

    def forward(self, x):
        return x + self.block(x)


class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_block):
        super(GeneratorResNet, self).__init__()
        
        channels = input_shape[0]
        
        # Initial Convolution Block
        out_features = 64
        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True)
        ]
        in_features = out_features
        
        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
        
        # Residual blocks
        for _ in range(num_residual_block):
            model += [ResidualBlock(out_features)]
            
        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2), # --> width*2, heigh*2
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            
        # Output Layer
        model += [nn.ReflectionPad2d(channels),
                  nn.Conv2d(out_features, channels, 7),
                  nn.Tanh()
                 ]
        
        # Unpacking
        self.model = nn.Sequential(*model) 
        
    def forward(self, x):
        return self.model(x)

In [4]:
# data (path)
root = 'data/'
# data (img)
img_height = 256
img_width = 256
channels = 3
# number of cpu threads to use during batch generation
n_cpu = 2
# training
epoch = 0 # epoch to start training from
n_epochs = 200 # number of epochs of training
batch_size = 1 # size of the batches
lr = 0.0002 # adam : learning rate
b1 = 0.5 # adam : decay of first order momentum of gradient
b2 = 0.999 # adam : decay of first order momentum of gradient
decay_epoch = 100 # suggested default : 100 (suggested 'n_epochs' is 200)
                 # epoch from which to start lr decay2

In [5]:
transforms_ = [
    transforms.Resize(int(img_height*1.12), Image.BICUBIC),
    transforms.RandomCrop((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

  transforms.Resize(int(img_height*1.12), Image.BICUBIC),


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_shape = (channels, img_height, img_width) # (3,256,256)
n_residual_blocks = 9 # suggested default, number of residual blocks in generator
BA = GeneratorResNet(input_shape, n_residual_blocks)
AB = GeneratorResNet(input_shape, n_residual_blocks)
optimizer = torch.optim.Adam(itertools.chain(AB.parameters(), BA.parameters()), lr=lr, betas=(b1,b2))
AB_checkpoint = torch.load('BA_epoch_115.pt',map_location=device)
BA_checkpoint = torch.load('BA_epoch_115.pt',map_location=device)

AB.load_state_dict(AB_checkpoint['state_dict'])
BA.load_state_dict(BA_checkpoint['state_dict'])
optimizer.load_state_dict(BA_checkpoint['optimizer_state_dict'])

AB.eval()
BA.eval()

GeneratorResNet(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace=True)
    (10): ResidualBlock(
      (block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): ReLU(inplace=True)
        (4): ReflectionPad2d((1, 1, 1, 1))
        

In [None]:
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#model = GeneratorResNet(input_shape=(3, 256, 256), num_residual_block=9)
#_ = GeneratorResNet(input_shape=(3, 256, 256), num_residual_block=9)

#optimizer = torch.optim.Adam(itertools.chain(_.parameters(), model.parameters()), lr=lr, betas=(b1,b2))

#checkpoint=torch.load('BA_epoch_115.pt',map_location=device)
#model.load_state_dict(checkpoint['state_dict'],)
#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#model.to(device)
#model.eval()
app = Flask(__name__)
CORS(app)
@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        try:
            file = Image.open(io.BytesIO(request.files['file'].read()))

            transform = transforms.Compose([
                #transforms.Resize((256, 256)),
                transforms.ToTensor(),
                # Normalize to [-1, 1]
                #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])

            # Transform the image
            image = transform(file).unsqueeze(0).to(device)

            # Disable gradient calculation
            with torch.no_grad():
                print("Image tensor shape:", image.shape)
                output = BA(image)
                print("Output tensor shape:", output.shape)
                print("Output min value:", torch.min(output))
                print("Output max value:", torch.max(output))

            # Convert tensor to PIL --> Byte Array
            output_image = transforms.ToPILImage()(output.squeeze().cpu().detach())
            byte_arr = io.BytesIO()
            output_image.save(byte_arr, format='JPEG')
            byte_arr = byte_arr.getvalue()
            img_str = base64.b64encode(byte_arr).decode()
            print("Base64 image string:", img_str)

            return jsonify({'result': img_str})

        except Exception as e:
            print("Error:", e)
            return jsonify({'error': str(e)}), 400
        
       
        
@app.route('/', methods=['GET'])    
def home():
    return "hello, world!"
    
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on all addresses.
 * Running on http://172.26.91.56:5000/ (Press CTRL+C to quit)


Image tensor shape: torch.Size([1, 3, 800, 800])


127.0.0.1 - - [09/Jun/2023 14:20:50] "POST /predict HTTP/1.1" 200 -


Output tensor shape: torch.Size([1, 3, 800, 800])
Output min value: tensor(-0.9880)
Output max value: tensor(0.9841)
Base64 image string: /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAMgAyADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwCr4m1PUZ9bjhfSbR7JI0eGRlkV2Vh83I