In [None]:
!wget https://web.eecs.umich.edu/~justincj/models/vgg16-00b39a1b.pth

--2022-05-06 17:26:05--  https://web.eecs.umich.edu/~justincj/models/vgg16-00b39a1b.pth
Resolving web.eecs.umich.edu (web.eecs.umich.edu)... 141.212.113.214
Connecting to web.eecs.umich.edu (web.eecs.umich.edu)|141.212.113.214|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 553451520 (528M) [application/x-tar]
Saving to: ‘vgg16-00b39a1b.pth’


2022-05-06 17:26:24 (28.1 MB/s) - ‘vgg16-00b39a1b.pth’ saved [553451520/553451520]



In [None]:
!unzip -qq /content/models.zip

In [None]:
!pip install --upgrade pip

Collecting pip
  Downloading pip-22.0.4-py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 14.1 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.0.4


In [None]:
!pip install -r requirements.txt

Collecting Pillow==9.0.1
  Using cached Pillow-9.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.3 MB)
Collecting opencv-python==4.5.5.64
  Using cached opencv_python-4.5.5.64-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (60.5 MB)
Collecting requests==2.24.0
  Using cached requests-2.24.0-py2.py3-none-any.whl (61 kB)
Collecting fast-dash==0.1.5
  Using cached fast_dash-0.1.5-py3-none-any.whl (32 kB)
Collecting dash-bootstrap-components<2.0.0,>=1.0.2
  Using cached dash_bootstrap_components-1.1.0-py3-none-any.whl (210 kB)
Collecting Flask<3.0.0,>=2.0.2
  Using cached Flask-2.1.2-py3-none-any.whl (95 kB)
Collecting mkdocs-material-extensions<2.0.0,>=1.0.1
  Using cached mkdocs_material_extensions-1.0.3-py3-none-any.whl (8.1 kB)
Collecting dash[testing]<3.0.0,>=2.2.0
  Using cached dash-2.3.1-py3-none-any.whl (9.6 MB)
Collecting dash-html-components==2.0.0
  Using cached dash_html_components-2.0.0-py3-none-any.whl (4.1 kB)
Collecting dash-core-components==2.0.

In [9]:
!pip install flask-ngrok

['Collecting flask-ngrok',
 '  Downloading flask_ngrok-0.0.25-py3-none-any.whl (3.1 kB)',
 'Installing collected packages: flask-ngrok',
 'Successfully installed flask-ngrok-0.0.25',
 '\x1b[0m']

In [2]:
from io import BytesIO
import base64

import torch
from torchvision import transforms
import cv2
import PIL

from custom_model import CustomModel
from vgg16 import VGG16
import utils

from fast_dash import FastDash, Fastify
from fast_dash.Components import UploadImage, Image, html
from fast_dash.utils import pil_to_b64
from dash import dcc


#### Define inference function
## Checkpoint lookup tables
# Keys are "<dataset>_<style>", where both parts are the dropdown labels passed
# through make_snake_case (e.g. "Tiny Imagenet" + "The Scream" ->
# "tiny_imagenet_the_scream"); values are paths to the saved state dicts.
# Both tables are generated from the same dataset/style lists so their keys
# cannot drift apart from each other or from the keys stylize builds.
_DATASET_DIRS = {'coco': 'COCO', 'tiny_imagenet': 'TinyImagenet'}
_STYLE_NAMES = ['rain_princess', 'the_scream', 'the_shipwreck', 'udnie', 'wave']


def _build_model_mapper(architecture_dir):
    """Map every "<dataset>_<style>" key to its checkpoint under models/<architecture_dir>/.

    Args:
        architecture_dir: sub-directory of ``models/`` holding this architecture's
            checkpoints (e.g. 'VGG16' or 'Custom').

    Returns:
        dict mapping "<dataset>_<style>" to "models/<architecture_dir>/<DatasetDir>/<style>.pth".
    """
    return {
        f"{dataset_key}_{style}": f"models/{architecture_dir}/{dataset_dir}/{style}.pth"
        for dataset_key, dataset_dir in _DATASET_DIRS.items()
        for style in _STYLE_NAMES
    }


## VGG16 mapper
vgg16_model_mapper = _build_model_mapper('VGG16')

## Custom-architecture mapper
custom_model_mapper = _build_model_mapper('Custom')
                         
def make_snake_case(x):
    """Turn a human-readable label into the snake_case form used in mapper keys.

    The label is lower-cased and every single space becomes an underscore,
    e.g. ``"Tiny Imagenet"`` -> ``"tiny_imagenet"``.
    """
    lowered = x.lower()
    return '_'.join(lowered.split(' '))


def load_image_base64(base64_str):
    """Decode a base64 payload into a PIL image converted to RGB mode.

    Args:
        base64_str: base64-encoded image file bytes. stylize passes only the
            part after the data-URL header's comma.

    Returns:
        The decoded image, converted with ``.convert('RGB')``.
    """
    # NOTE(review): a bare `import PIL` does not load the `PIL.Image` submodule;
    # this presumably works because another import in the cell loads it.
    # Consider importing `PIL.Image` explicitly.
    raw_bytes = base64.b64decode(base64_str)
    return PIL.Image.open(BytesIO(raw_bytes)).convert('RGB')


def saveimg(image):
    """Convert an array image to a PIL image. Despite the name, nothing is written to disk.

    Args:
        image: array-like supporting ``.clip`` and ``.astype``. This is the
            output of ``utils.ttoi`` in stylize; presumably an (H, W, C) numpy
            array — TODO confirm.

    Returns:
        ``PIL.Image.fromarray`` of the values clipped to [0, 255] and cast to uint8.
    """
    # Clip first so out-of-range values saturate instead of wrapping when cast to uint8.
    pixels = image.clip(0, 255)
    pixels = pixels.astype("uint8")
    return PIL.Image.fromarray(pixels)


def stylize(image, architecture, trained_on, style):
    """Apply a pre-trained style-transfer network to an uploaded image.

    Args:
        image: uploaded image string, either a data URL
            ("data:<mime>;base64,<payload>") or a bare base64 payload.
        architecture: 'VGG16' selects the VGG16 model; any other value selects CustomModel.
        trained_on: dataset label, e.g. 'COCO' or 'Tiny Imagenet'.
        style: style label, e.g. 'The Scream'.

    Returns:
        The result of ``pil_to_b64`` on the stylized image. Presumably a base64
        string consumed by the Image output component — TODO confirm.

    Raises:
        ValueError: if the image, dataset or style input is missing.
        KeyError: if no checkpoint exists for the (trained_on, style) pair.
    """
    if not image or not trained_on or not style:
        raise ValueError("Please upload an image and choose a dataset and a style.")

    # rpartition keeps only the text after the last comma. This accepts both
    # data URLs and bare base64 payloads, since base64 never contains commas.
    _, _, image_content = image.rpartition(',')
    content_image = load_image_base64(image_content)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    mapper_key = f"{make_snake_case(trained_on)}_{make_snake_case(style)}"
    if architecture == 'VGG16':
        style_model = VGG16()
        model_path = vgg16_model_mapper[mapper_key]
    else:
        style_model = CustomModel()
        model_path = custom_model_mapper[mapper_key]

    # ToTensor yields values in [0, 1]; scale by 255 so the model receives [0, 255] input.
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # map_location lets checkpoints saved on a GPU load on the CPU fallback device.
        state_dict = torch.load(model_path, map_location=device)
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        # Inference only: switch layers such as dropout/normalization to eval behaviour.
        style_model.eval()

        output = style_model(content_image).cpu()
        output = utils.ttoi(output.clone())

    stylized = saveimg(output)
    return pil_to_b64(stylized)


### Web app using Fast Dash!
def _fastify_dropdown(choices):
    """Wrap a Dash dropdown whose option labels equal their values as a Fast Dash input."""
    return Fastify(component=dcc.Dropdown(options=dict(zip(choices, choices))), assign_prop='value')


# Labels are converted by make_snake_case inside stylize to build checkpoint
# mapper keys, so they must stay in sync with those mappers.
architecture_dropdown = _fastify_dropdown(['VGG16', 'Custom'])
architecture_trained_on = _fastify_dropdown(['COCO', 'Tiny Imagenet'])
architecture_style = _fastify_dropdown(['Rain Princess', 'The Scream', 'The Shipwreck', 'Udnie', 'Wave'])

# Input order must match stylize's positional parameters: (image, architecture, trained_on, style).
app = FastDash(
    callback_fn=stylize,
    inputs=[UploadImage, architecture_dropdown, architecture_trained_on, architecture_style],
    outputs=Image,
    title='Neural Style Transfer',
    title_image_path='https://raw.githubusercontent.com/dkedar7/fast_dash/main/examples/Neural%20style%20transfer/assets/icon.png',
    subheader="Apply styles from well-known pieces of art to your own photos",
    github_url='https://github.com/dkedar7/fast_dash/',
    linkedin_url='https://linkedin.com/in/dkedar7/',
    twitter_url='https://twitter.com/dkedar7/',
    theme='JOURNAL',
)


if __name__ == '__main__':
    # NOTE(review): the captured output shows the server bound to 127.0.0.1:5000.
    # In a hosted notebook, a tunnel (e.g. the flask-ngrok package installed
    # above, currently unused) is presumably needed to reach it — TODO confirm.
    app.run()

 * Serving Flask app 'fast_dash.fast_dash' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:5000 (Press CTRL+C to quit)
