# Running the models locally

Before running the following cells, make sure you have downloaded the pretrained models into `../pretrained_models/` (one `<Style>_net_G_float.pth` file per style: Hosoda, Hayao, Shinkai and Paprika).

In [None]:
# Disable the Jedi backend for IPython tab completion
# (presumably a workaround for slow or broken completions — TODO confirm).
%config Completer.use_jedi=False

# Reload imported modules automatically so edits to the local source files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Make the parent directory importable so the project-local modules imported
# below (`network.Transformer`, `test_from_code`) can be found — they presumably
# live there; confirm against the repository layout.
sys.path.append("../")
from io import BytesIO

In [None]:
import base64
import requests
import torch
import os
import numpy as np
import argparse
from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as vutils
from network.Transformer import Transformer
from matplotlib.pyplot import imshow

from tqdm import tqdm_notebook
from test_from_code import transform

In [None]:
# Load one pretrained generator per style into `models`, keyed by style name.
styles = ["Hosoda", "Hayao", "Shinkai", "Paprika"]
pretrained_dir = os.path.join("..", "pretrained_models")

models = {}

for style in tqdm_notebook(styles):
    checkpoint = os.path.join(pretrained_dir, f"{style}_net_G_float.pth")
    model = Transformer()
    # map_location="cpu" lets the weights load on machines without a GPU,
    # whatever device the checkpoint was saved from.
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    # Put the module in evaluation mode: these models are only used for inference.
    model.eval()
    models[style] = model

In [None]:
### path of the input image fed to the models loaded above;
### change it to try the models on a different picture
path = "../../test-images/paris.jpg"

In [None]:
# Preview the input image; the trailing ';' suppresses the AxesImage repr.
img = Image.open(path)
imshow(img);

In [None]:
### pick a style in `styles`: ["Hosoda", "Hayao", "Shinkai", "Paprika"]
style = "Hosoda"
# Fail fast with a readable message instead of an obscure error later in transform().
assert style in models, f"Unknown style {style!r}; choose one of {list(models)}"

In [None]:
### choose a load_size: higher values give better results, but the transformation takes longer
load_size = 300

In [None]:
%%time
# Transform at the load_size chosen above (passed by keyword, like the runs below).
output300 = transform(models, style, path, load_size=load_size)

In [None]:
# Trailing ';' hides the AxesImage repr so only the picture is shown.
imshow(output300);

In [None]:
%%time
### same style and image with load_size=450, to compare speed and quality with the run above
output450 = transform(models, style, path, load_size=450)

In [None]:
# Trailing ';' hides the AxesImage repr so only the picture is shown.
imshow(output450);

In [None]:
%%time
### same style and image with load_size=650, to compare speed and quality with the runs above
output650 = transform(models, style, path, load_size=650)

In [None]:
# Trailing ';' hides the AxesImage repr so only the picture is shown.
imshow(output650);

## Use the deployed API

In this section, you need to provide the URL of your own deployed API endpoint (the `url` variable in the next cell).

In [None]:
# Encode a test image as base64 and build the JSON payload for the deployed API.
# A distinct name (instead of reusing `path`) keeps re-runs of the local-model
# cells above pointed at their own input image.
api_image_path = '../../test-images/lawrence.jpg'

# Read the file once: the same bytes feed both the preview and the payload.
with open(api_image_path, "rb") as image_file:
    image_bytes = image_file.read()
img = Image.open(BytesIO(image_bytes))
encoded_string = base64.b64encode(image_bytes).decode('utf-8')

# Replace the default with your own endpoint, or set CARTOONIFY_API_URL.
url = os.environ.get(
    "CARTOONIFY_API_URL",
    "https://tdi2o4m9vi.execute-api.us-west-1.amazonaws.com/dev/transform",
)

data = {
    "image": encoded_string,
    # NOTE(review): which style model_id 1 selects is decided by the API and is
    # not visible here — confirm.
    "model_id": 1,
    # presumably the same meaning as `load_size` in the local section — TODO confirm
    "load_size": 500
}

In [None]:
# Trailing ';' hides the AxesImage repr so only the picture is shown.
imshow(img);

In [None]:
%%time
# POST the payload. The timeout stops the cell from hanging indefinitely if the
# endpoint never answers (requests waits forever by default).
response = requests.post(url, json=data, timeout=60)
print(response)
if not response.ok:
    # Surface the server's error body before raising, so the cause is visible.
    print(response.text)
response.raise_for_status()
# The JSON body embeds the whole output image as a base64 string; list only its
# keys rather than dumping that string into the notebook output.
print(list(response.json()))

In [None]:
# Decode the API's base64 output into a PIL image for display.
payload = response.json()["output"]

# Keep only what follows the first comma (presumably a "data:image/...;base64,"
# prefix); a payload without a comma is kept whole.
payload = payload.split(",", 1)[-1]
# Appending "===" covers any stripped padding; this relies on the non-validating
# decoder tolerating excess '=' characters.
dec = base64.b64decode(payload + "===")

image = Image.open(BytesIO(dec))

In [None]:
# Trailing ';' hides the AxesImage repr so only the picture is shown.
imshow(image);