In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision as tv
from PIL import Image
import imageio
import numpy as np
import io
from matplotlib import pyplot as plt

In [2]:
import uvicorn
import numpy as np
import nest_asyncio
from enum import Enum
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import StreamingResponse

In [3]:
# Standard ImageNet per-channel means. The std is deliberately left at 1, so
# pixels are only mean-centred and stay on ToTensor's [0, 1] scale.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
UNIT_STD = [1, 1, 1]

# PIL image -> mean-centred float tensor resized to 512x512.
to_tensor = tv.transforms.Compose(
    [
        tv.transforms.Resize((512, 512)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=IMAGENET_MEAN, std=UNIT_STD),
    ]
)

# Inverse of to_tensor's normalisation: add the means back (subtracting the
# negated means), then clamp to [0, 1] so the tensor can become an image again.
unload = tv.transforms.Compose(
    [
        tv.transforms.Normalize(mean=[-m for m in IMAGENET_MEAN], std=UNIT_STD),
        tv.transforms.Lambda(lambda t: t.clamp(0, 1)),
    ]
)

# Tensor -> PIL image.
to_image = tv.transforms.ToPILImage()

In [4]:
def image_loader(image_name, device="cuda"):
    """Load an image and return it as a normalised tensor on ``device``.

    Args:
        image_name: a path or binary file-like object accepted by ``PIL.Image.open``.
        device: target device for the returned tensor (default ``"cuda"``).

    Returns:
        ``to_tensor`` applied to the image after conversion to RGB, i.e. a
        3-channel float tensor resized to 512x512.
    """
    # Image.open is lazy; the context manager closes any file it opened once
    # the pixels have been converted (convert() loads the data).
    with Image.open(image_name) as img:
        # Force 3 channels: RGBA / palette / greyscale inputs would otherwise
        # yield a tensor whose channel count does not match the 3-value
        # Normalize mean in to_tensor.
        rgb = img.convert("RGB")
    return to_tensor(rgb).to(device)

In [5]:
import os

# Directory for uploaded images. makedirs(exist_ok=True) is idempotent on
# notebook re-runs and avoids the check-then-create race of exists() + mkdir().
# NOTE(review): nothing else in the visible notebook reads or writes this directory.
dir_name = "images_uploaded"
os.makedirs(dir_name, exist_ok=True)

In [6]:
app = FastAPI(title="Style Transfer Pro")


@app.get("/")
def home():
    """Landing / health-check route: returns a plain message pointing users at /docs."""
    greeting = "Working like a charm! Go to http://localhost:8000/docs to test!"
    return greeting

# Output directory for stylised images and their progress GIFs.
STYLED_DIR = "styled_images"

# Accepted upload extensions (lower-case, no dot) -> media type of the response.
_MEDIA_TYPES = {"jpg": "image/jpeg", "jpeg": "image/jpeg", "png": "image/png"}

# Indices in torchvision's vgg19().features whose outputs are captured by hooks
# (conv1_1, conv2_1, conv3_1, conv4_1, conv4_2, conv5_1 in that layout).
_HOOK_LAYERS = (0, 5, 10, 19, 21, 28)

# Position, within the captured outputs, of the content layer (index 21);
# every other captured output is used as a style layer.
_CONTENT_POS = 4


def _gram_matrix(x):
    """Return the batched Gram matrix of a (b, c, h, w) feature map, divided by h*w."""
    b, c, h, w = x.size()
    flat = x.view(b, c, h * w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (h * w)


def _build_vgg(features):
    """Load a frozen, eval-mode VGG19 feature stack on the GPU.

    Forward hooks on ``_HOOK_LAYERS`` append each captured output to the
    ``features`` list (in forward order); max-pooling layers are replaced by
    2x2 average pooling.
    """
    vgg = tv.models.vgg19(pretrained=True).features
    vgg.cuda()
    for i, layer in enumerate(vgg):
        if i in _HOOK_LAYERS:
            layer.register_forward_hook(lambda _module, _inp, out: features.append(out))
        elif isinstance(layer, nn.MaxPool2d):
            vgg[i] = nn.AvgPool2d(kernel_size=2)
    vgg.eval()
    for p in vgg.parameters():
        p.requires_grad = False
    return vgg


def _extract(vgg, features, x):
    """Run ``x`` through ``vgg``; return (content_feature, style_features) from the hooks."""
    features.clear()
    vgg(x)
    captured = list(features)
    # Drop the shared list's references so earlier autograd graphs can be freed.
    features.clear()
    return captured[_CONTENT_POS], captured[:_CONTENT_POS] + captured[_CONTENT_POS + 1:]


def _stylize(vgg, features, content_img, style_img, alpha=1, beta=1e3, iterations=100):
    """Optimise a copy of ``content_img`` towards the style of ``style_img`` with L-BFGS.

    Returns one uint8 RGB numpy frame per outer L-BFGS step; the last frame is
    the final result.
    """
    with torch.no_grad():
        c_target, _ = _extract(vgg, features, content_img.unsqueeze(0))
        _, style_feats = _extract(vgg, features, style_img.unsqueeze(0))
        gram_targets = [_gram_matrix(f) for f in style_feats]

    image = content_img.clone().unsqueeze(0).requires_grad_()
    optimizer = optim.LBFGS([image], lr=1)
    mse_loss = nn.MSELoss(reduction="mean")

    def closure():
        # L-BFGS may evaluate this several times per step(); each call does a
        # fresh forward pass and returns the total loss.
        optimizer.zero_grad()
        content, style_now = _extract(vgg, features, image)
        c_loss = alpha * mse_loss(content, c_target)
        s_loss = 0
        for feat, g_target in zip(style_now, gram_targets):
            gram = _gram_matrix(feat)
            # NOTE(review): shape[0] is the batch dimension (always 1 here), so
            # this divisor is 1; kept so the tuned alpha/beta balance is unchanged.
            n_c = gram.shape[0]
            s_loss += beta * mse_loss(gram, g_target) / (n_c ** 2)
        total_loss = c_loss + s_loss
        total_loss.backward()
        return total_loss

    frames = []
    for _ in range(iterations):
        optimizer.step(closure)
        frames.append(np.array(to_image(unload(image[0].detach().cpu()))))
    return frames


@app.post("/predict")
def prediction(input_img: UploadFile = File(...), style: UploadFile = File(...)):
    """Apply the style of ``style`` to ``input_img`` and return the stylised image.

    Both uploads must have a .jpg, .jpeg or .png extension (case-insensitive),
    otherwise HTTP 415 is raised; image data that cannot be decoded yields
    HTTP 400. The result is saved as ``styled_images/<input file name>``
    together with ``styled_images/<stem>progress.gif`` (one frame per
    optimisation step) and streamed back with a media type matching its extension.
    """
    # basename() strips client-supplied directory parts so the upload name
    # cannot steer the output path outside STYLED_DIR.
    filename = os.path.basename(input_img.filename or "")
    stem, dot_ext = os.path.splitext(filename)
    ext = dot_ext[1:].lower()
    style_ext = os.path.splitext(style.filename or "")[1][1:].lower()
    if ext not in _MEDIA_TYPES or style_ext not in _MEDIA_TYPES:
        raise HTTPException(status_code=415, detail="Unsupported file provided")

    try:
        input_image = image_loader(io.BytesIO(input_img.file.read()))
        style_image = image_loader(io.BytesIO(style.file.read()))
    except OSError as exc:
        # Pillow raises UnidentifiedImageError (an OSError) for undecodable data.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    features = []
    vgg = _build_vgg(features)
    frames = _stylize(vgg, features, input_image, style_image)

    os.makedirs(STYLED_DIR, exist_ok=True)
    # Encode the progress GIF once, after optimisation, instead of re-encoding
    # every accumulated frame on each step.
    imageio.mimsave(os.path.join(STYLED_DIR, f"{stem}progress.gif"), frames)
    out_path = os.path.join(STYLED_DIR, filename)
    Image.fromarray(frames[-1]).save(out_path)

    # Read the result eagerly so the file handle is closed before returning.
    with open(out_path, mode="rb") as fh:
        payload = fh.read()
    return StreamingResponse(io.BytesIO(payload), media_type=_MEDIA_TYPES[ext])

In [7]:
# Allow uvicorn.run to start its own event loop inside the notebook kernel's
# already-running loop.
nest_asyncio.apply()

# Any non-empty DOCKER-SETUP value binds to all interfaces (needed to publish
# the port from a container); otherwise only local connections are accepted.
in_docker = bool(os.getenv("DOCKER-SETUP"))
if in_docker:
    host = "0.0.0.0"
else:
    host = "127.0.0.1"

uvicorn.run(app, host=host, port=8000)

INFO:     Started server process [12020]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO:     127.0.0.1:53350 - "POST /predict HTTP/1.1" 500 Internal Server Error
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "C:\Users\chris\AppData\Local\Programs\Python\Python38\lib\site-packages\uvicorn\protocols\http\h11_impl.py", line 396, in run_asgi
    result = await app(self.scope, self.receive, self.send)
  File "C:\Users\chris\AppData\Local\Programs\Python\Python38\lib\site-packages\uvicorn\middleware\proxy_headers.py", line 45, in __call__
    return await self.app(scope, receive, send)
  File "C:\Users\chris\AppData\Local\Programs\Python\Python38\lib\site-packages\fastapi\applications.py", line 199, in __call__
    await super().__call__(scope, receive, send)
  File "C:\Users\chris\AppData\Local\Programs\Python\Python38\lib\site-packages\starl