<a href="https://colab.research.google.com/github/Jiang15/fashion-gan/blob/master/app.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive so the project checkout under MyDrive is reachable from the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
cd /content/drive/MyDrive/fashion-gan/stylegan/

/content/drive/MyDrive/fashion-gan/stylegan


In [None]:
# Colab environment setup: install / upgrade the packages this notebook needs.
# flask-ngrok: exposes the Flask app (last cell) through a public ngrok URL.
!pip install flask-ngrok
# Try to upgrade the system CUDA packages via apt, then repair broken dependencies.
# NOTE(review): `apt-get upgrade` without `-y` may stop at an interactive prompt.
!sudo apt-get upgrade cuda
!apt --fix-broken install
# ninja: presumably needed to JIT-build StyleGAN's custom ops (bias_act / upfirdn2d plugins).
!pip install ninja
# PyTorch 1.11 / torchvision 0.12 wheels built for CUDA 11.3 (the "+cu113" suffix).
# NOTE(review): torchaudio is pinned with `===` (PEP 440 arbitrary equality) and without a
# +cu113 suffix — confirm this is intended.
!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio===0.11.0 -f https://download.pytorch.org/whl/torch_stable.html
# DALL-E (text-to-image).
!pip install dalle-pytorch --upgrade
# sentencepiece + transformers: the MarianMT translation model and its tokenizer.
!pip install sentencepiece
!pip install transformers
# html2text: strips HTML markup from descriptions before translation.
!pip install html2text
# OpenAI CLIP and its dependencies, for text-image similarity.
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
# torchtext release matching torch 1.11.
!pip install torchtext==0.12.0

In [None]:
import cv2
import copy
import clip
import click
from einops import repeat
from google.colab.patches import cv2_imshow
import html2text
import imageio
import json
import math
from matplotlib import pyplot as plt
import numpy as np
import os
from PIL import Image, ImageChops, ImageOps
from pathlib import Path
import pandas as pd
import random
import re
import sentencepiece
import torch
from time import perf_counter
from tqdm import tqdm


import dnnlib
import legacy
from style_mixing import *
from helpers import *
from projector import img_to_w

from flask import Flask, render_template, request, jsonify, send_file
from flask_ngrok import run_with_ngrok


import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid, save_image

from transformers import FSMTForConditionalGeneration, FSMTTokenizer, MarianTokenizer, MarianMTModel
from dalle_pytorch import DALLE, VQGanVAE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer, SimpleTokenizer

#### Load network and model


In [None]:
# Load the 'G_ema' generator from the trained StyleGAN network pickle onto the GPU.
# The pickle is unpickled, so only load trusted checkpoints.
network = "training_runs/fashion-gan-final.pkl"
device = torch.device('cuda')
with dnnlib.util.open_url(network) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

# Load the principal components and the configuration of the semantically meaningful directions.
# Each dim_dic entry is read by /changeImage as {"id": component index, "mag": [low, high],
# "layers": [start, end]}.
vectors = np.load('out/vectors.npy')
with open('static/output/dim_dic.txt') as fobj:
    dim_dic = json.load(fobj)
# Default mixing weights (percent) of images 1/2/3 for each layer group of id_dic ('0' and '1').
# NOTE(review): 33+33+33 = 99, so the default blend is scaled by 0.99 — confirm this is intended.
data_percents = {'0':[33,33,33],'1':[33,33,33]}
# Layer ranges [start, end) of the w latent that each mixing group controls
# (the logged latent shape is [1, 14, 512], i.e. 14 layers).
id_dic = {
    0:[0,7],
    1:[7,14]
}

In [None]:
def resize_img(img, size=256, background=(255, 255, 255)):
  """Fit an image into a ``size`` x ``size`` RGB canvas, centred on a plain background.

  The image is shrunk (aspect ratio preserved) so that it fits inside the square. Images
  that are already smaller are not enlarged, only padded. The input image is left untouched.

  Args:
    img: PIL image to fit.
    size: side length of the square output in pixels (default 256, the size callers check for).
    background: RGB fill colour of the padding (default white).

  Returns:
    A new ``size`` x ``size`` RGB PIL image.
  """
  # Work on a copy: Image.thumbnail resizes in place and would otherwise modify the caller's image.
  fitted = img.copy()
  fitted.thumbnail((size, size))
  result = Image.new("RGB", (size, size), background)
  # Offsets that centre the fitted image on the canvas.
  offset = ((size - fitted.width) // 2, (size - fitted.height) // 2)
  result.paste(fitted, offset)
  return result

##### Text to image generation

In [None]:
# Load the fine-tuned DALL-E checkpoint (hyper-parameters + weights) used by text2image().
# torch.load unpickles the file, so only load trusted checkpoints.
dalle_path = "static/textimg/dalle-final.pt"
load_obj = torch.load(str(dalle_path))
# NOTE: vae_params is not used below; a default-constructed VQGanVAE() is passed instead.
dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')
dalle_params.pop('vae', None) # drop any stored 'vae' entry: the VAE instance is passed explicitly as vae=vae below

vae = VQGanVAE()
dalle = DALLE(vae = vae, **dalle_params).cuda()
dalle.load_state_dict(weights)

In [None]:
# Load the Helsinki-NLP MarianMT German->English translation model (used by en2ge) onto the GPU.
de_tokenizer_hel = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-de-en')
de_model_hel = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-de-en').to('cuda')

Downloading:   0%|          | 0.00/778k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/750k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.21M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/284M [00:00<?, ?B/s]

In [None]:

CHUNK_SIZE = 32
def en2ge(description, chunk_size=CHUNK_SIZE):
  """Translate German text to English with the MarianMT model loaded above.

  Despite its name, this function uses the ``opus-mt-de-en`` (German -> English) model.
  HTML markup is stripped first. The text is then split into chunks of at most ``chunk_size``
  words, which are translated together as one padded batch.

  Args:
    description: text (optionally containing HTML) to translate.
    chunk_size: maximum number of words per translated chunk.

  Returns:
    A list with one translated string per chunk (empty if the text contains no words).
  """
  text = html2text.html2text(description)  # strip HTML markup
  words = text.split()
  if not words:
    # Nothing to translate; avoid calling the tokenizer/model with an empty batch.
    return []
  chunks = [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
  # Put the inputs on whatever device the translation model lives on.
  batch = de_tokenizer_hel(chunks, return_tensors="pt", padding=True).to(de_model_hel.device)
  gen = de_model_hel.generate(**batch)
  return de_tokenizer_hel.batch_decode(gen, skip_special_tokens=True)

##### Text-image similarity measurement via CLIP model

In [None]:
# Load the TorchScript CLIP model (clip.pt) on the GPU in eval mode; used by similarity() and clip_score().
model = torch.jit.load("clip.pt").cuda().eval()
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, ToPILImage
from PIL import Image
# Constants stored in the TorchScript archive.
input_resolution = model.input_resolution.item()
context_length = model.context_length.item()
vocab_size = model.vocab_size.item()
# Image preprocessing: tensor/ndarray -> PIL image -> bicubic resize (shorter side to
# input_resolution) -> centre crop to input_resolution x input_resolution -> float tensor.
preprocess = Compose([
    ToPILImage(),
    Resize(input_resolution, interpolation=Image.BICUBIC),
    CenterCrop(input_resolution),
    ToTensor()
])
# NOTE(review): CLIP's reference preprocessing also normalises with the mean/std below. That step
# is commented out (and the imported Normalize is unused), so similarity scores are computed on
# un-normalised pixels. The thresholds used later (0.31 in stylegan_text_image, 0.2 in text2image)
# were presumably tuned on these scores; re-enabling normalisation would require re-tuning them.
# image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
# image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()

def similarity(output, text):
  """Return the CLIP cosine similarity between one image and the prompt "This is <text>".

  Args:
    output: a single image accepted by ``preprocess`` (e.g. an HxWx3 uint8 array or a CxHxW tensor).
    text: caption to compare against.

  Returns:
    A 1x1 numpy array holding the cosine similarity.
  """
  pixel_batch = torch.tensor(np.stack([preprocess(output)])).cuda()
  prompt_tokens = clip.tokenize(["This is " + text]).cuda()
  with torch.no_grad():
    img_emb = model.encode_image(pixel_batch).float()
    txt_emb = model.encode_text(prompt_tokens).float()
  # L2-normalise both embeddings so their dot product is a cosine similarity.
  img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
  txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
  return txt_emb.cpu().numpy() @ img_emb.cpu().numpy().T

def clip_score(image_path, text):
  """Score an image file against one or more captions with CLIP.

  Args:
    image_path: path of the image file to score.
    text: a caption or a list of captions.

  Returns:
    A numpy array of shape (1, number of captions) with the softmax probability of each caption
    for the image. With a single caption the probability is therefore always 1.0.
  """
  # `preprocess` starts with ToPILImage, which expects an array/tensor rather than a PIL image,
  # so decode the file to an RGB uint8 array first. The context manager closes the file handle.
  with Image.open(image_path) as img:
    pixels = np.asarray(img.convert("RGB"))
  image = preprocess(pixels).unsqueeze(0).to(device)
  tokens = clip.tokenize(text).to(device)

  with torch.no_grad():
    logits_per_image, _ = model(image, tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

  print("Label probs:", probs)
  return probs

In [None]:
def stylegan_text_image(text, num_images=2000, num_show =8, outputs_dir = "./output", threshold=0.31):
  """Sample StyleGAN images and keep the first ones CLIP judges to match ``text``.

  Seeds 0 .. num_images-1 are tried in order. Every image whose CLIP similarity with ``text``
  exceeds ``threshold`` is kept, until ``num_show`` images have been found. Kept images are
  written as ``<outputs_dir>/0.jpg``, ``1.jpg``, ...

  Args:
    text: caption to match.
    num_images: maximum number of seeds to try.
    num_show: number of matching images to collect before stopping.
    outputs_dir: directory for the JPEG files (created if missing).
    threshold: minimum CLIP similarity for an image to be kept.

  Returns:
    The number of images written. This may be fewer than ``num_show``; in that case older files
    with higher indices in ``outputs_dir`` are left untouched.
  """
  matches = []
  for seed in range(num_images):
    w = generate_w(generateZ_from_seed(seed, G), G, truncation_psi=1)
    img = generate_image_from_w(w, G, noise_mode = 'const')
    score = similarity(img, text)[0][0]
    if score > threshold:
      matches.append(img)
      print(score)
      if len(matches) >= num_show:
        break

  # os.path.join works whether or not outputs_dir ends with a separator
  # (plain concatenation turned the default "./output" into "./output0.jpg").
  os.makedirs(outputs_dir, exist_ok=True)
  for i, img in enumerate(matches):
    cv2.imwrite(os.path.join(outputs_dir, str(i) + '.jpg'), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
  return len(matches)


def text2image(input_text, top_k=0.99, num_images=100, batch_size=10, outputs_dir = "./output",
               min_score=0.2, num_save=4, first_index=4):
  """Generate images for one or more captions with the fine-tuned DALL-E model.

  ``input_text`` may hold several captions separated by ``|``. Each caption is passed through
  ``en2ge`` and tokenised, and ``num_images`` samples are drawn in batches of ``batch_size``.
  Every sample whose CLIP similarity with the caption exceeds ``min_score`` is a candidate.
  The ``num_save`` best candidates over all captions are written to
  ``<outputs_dir>/<first_index>.jpg``, ``<first_index + 1>.jpg``, ...

  Args:
    input_text: '|'-separated captions.
    top_k: ``filter_thres`` passed to ``dalle.generate_images``.
    num_images: samples drawn per caption.
    batch_size: samples generated per DALL-E call.
    outputs_dir: output directory (created if missing).
    min_score: minimum CLIP similarity for a sample to be kept.
    num_save: number of best samples written.
    first_index: file index of the best sample.
  """
  candidates = []  # (score, image) pairs; a list keeps samples that tie on score
  for caption in tqdm(input_text.split('|')):
    # NOTE(review): en2ge runs the German->English model on the caption; confirm which
    # language the captions are expected in here.
    # en2ge returns one string per word chunk; join them so long captions still tokenise
    # to a single row that can be repeated below.
    translated = " ".join(en2ge(caption))
    tokens = tokenizer.tokenize(translated, dalle.text_seq_len).cuda()
    tokens = repeat(tokens, '() n -> b n', b = num_images)
    for token_batch in tokens.split(batch_size):
      images = dalle.generate_images(token_batch, filter_thres = top_k)
      for img in images:
        score = similarity(img.squeeze(0), caption)[0][0]
        if score > min_score:
          candidates.append((score, img))

  out_dir = Path(outputs_dir)
  out_dir.mkdir(parents = True, exist_ok = True)
  # Highest CLIP score first; sort on the score only (the image tensors are not orderable).
  candidates.sort(key=lambda pair: pair[0], reverse=True)
  for i, (score, img) in enumerate(tqdm(candidates[:num_save], desc = 'saving images')):
    print(score)
    save_image(img, out_dir / f'{i + first_index}.jpg', normalize=True)


### Run Webapp

In [None]:
# Create the Flask web app. run_with_ngrok makes app.run() also open a public ngrok tunnel.
app = Flask(__name__)
run_with_ngrok(app)   # starts ngrok when the app is run

@app.route("/")
def index():
    """Render the main page showing the three current query images."""
    template_args = dict(
        img1="static/upload_image/query1.jpg",
        img2="static/upload_image/query2.jpg",
        img3="static/upload_image/query3.jpg",
    )
    return render_template('index.html', **template_args)

@app.route("/getImg1")
def getImg1():
    """Replace query image 1 with a random StyleGAN sample and refresh the mixed image.

    The new latent is cached as static/output/w_1.pt. The cached latents of images 1-3 are then
    blended per layer group (id_dic) with the module-level default weights in data_percents.

    Returns:
        JSON with cache-busted URLs of the new query image ("l") and the mixed image ("interaction").
    """
    query_path = os.path.join('static/upload_image','query1.jpg')
    seed = random.randint(0, 100)
    w_sample = generate_w(generateZ_from_seed(seed, G), G, truncation_psi=1)
    sample_img = generate_image_from_w(w_sample, G, noise_mode = 'const')
    torch.save(w_sample, 'static/output/w_1.pt')
    cv2.imwrite(query_path, cv2.cvtColor(sample_img, cv2.COLOR_RGB2BGR))

    mix_path = "static/image_interaction/query400_400.jpg"

    # Blend from the on-disk cache so all three slots are treated alike.
    w_1 = torch.load('static/output/w_1.pt')
    w_2 = torch.load('static/output/w_2.pt')
    w_3 = torch.load('static/output/w_3.pt')

    mixed = torch.zeros_like(w_1)
    groups = ((id_dic[0], data_percents['0']), (id_dic[1], data_percents['1']))
    for layers, pers in groups:
        band = slice(layers[0], layers[1])
        mixed[0, band] = w_1[0, band]*pers[0]/100 + \
                         w_2[0, band]*pers[1]/100 + \
                         w_3[0, band]*pers[2]/100
    print(seed)
    torch.save(mixed, 'static/image_interaction/w_interaction.pt')
    mixed_img = generate_image_from_w(mixed, G, noise_mode = 'const')
    cv2.imwrite(mix_path, cv2.cvtColor(mixed_img, cv2.COLOR_RGB2BGR))

    return json.dumps({"l": query_path+"?"+ str(random.randint(0,10000000)),
                       "interaction":mix_path+"?"+ str(random.randint(0,10000000))})

@app.route("/getImg2")
def getImg2():
    """Replace query image 2 with a random StyleGAN sample and refresh the mixed image.

    The new latent is cached as static/output/w_2.pt. The cached latents of images 1-3 are then
    blended per layer group (id_dic) with the module-level default weights in data_percents.

    Returns:
        JSON with cache-busted URLs of the new query image ("l") and the mixed image ("interaction").
    """
    seed = random.randint(0, 100)
    query_path = os.path.join('static/upload_image','query2.jpg')
    w_sample = generate_w(generateZ_from_seed(seed, G), G, truncation_psi=1)
    sample_img = generate_image_from_w(w_sample, G, noise_mode = 'const')
    torch.save(w_sample, 'static/output/w_2.pt')
    cv2.imwrite(query_path, cv2.cvtColor(sample_img, cv2.COLOR_RGB2BGR))
    mix_path = "static/image_interaction/query400_400.jpg"

    # Blend from the on-disk cache so all three slots are treated alike.
    w_1 = torch.load('static/output/w_1.pt')
    w_2 = torch.load('static/output/w_2.pt')
    w_3 = torch.load('static/output/w_3.pt')

    print(seed)
    mixed = torch.zeros_like(w_1)
    groups = ((id_dic[0], data_percents['0']), (id_dic[1], data_percents['1']))
    for layers, pers in groups:
        band = slice(layers[0], layers[1])
        mixed[0, band] = w_1[0, band]*pers[0]/100 + \
                         w_2[0, band]*pers[1]/100 + \
                         w_3[0, band]*pers[2]/100

    torch.save(mixed, 'static/image_interaction/w_interaction.pt')
    mixed_img = generate_image_from_w(mixed, G, noise_mode = 'const')
    cv2.imwrite(mix_path, cv2.cvtColor(mixed_img, cv2.COLOR_RGB2BGR))

    return json.dumps({"l": query_path+"?"+ str(random.randint(0,10000000)),
                       "interaction":mix_path+"?"+ str(random.randint(0,10000000))})

@app.route("/getImg3")
def getImg3():
    """Replace query image 3 with a random StyleGAN sample and refresh the mixed image.

    The new latent is cached as static/output/w_3.pt. The cached latents of images 1-3 are then
    blended per layer group (id_dic) with the module-level default weights in data_percents.

    Returns:
        JSON with cache-busted URLs of the new query image ("l") and the mixed image ("interaction").
    """
    query_path = os.path.join('static/upload_image','query3.jpg')

    seed = random.randint(0, 100)
    w_sample = generate_w(generateZ_from_seed(seed, G), G, truncation_psi=1)
    sample_img = generate_image_from_w(w_sample, G, noise_mode = 'const')
    torch.save(w_sample, 'static/output/w_3.pt')
    cv2.imwrite(query_path, cv2.cvtColor(sample_img, cv2.COLOR_RGB2BGR))

    mix_path = "static/image_interaction/query400_400.jpg"

    # Blend from the on-disk cache so all three slots are treated alike.
    w_1 = torch.load('static/output/w_1.pt')
    w_2 = torch.load('static/output/w_2.pt')
    w_3 = torch.load('static/output/w_3.pt')

    mixed = torch.zeros_like(w_1)
    groups = ((id_dic[0], data_percents['0']), (id_dic[1], data_percents['1']))
    for layers, pers in groups:
        band = slice(layers[0], layers[1])
        mixed[0, band] = w_1[0, band]*pers[0]/100 + \
                         w_2[0, band]*pers[1]/100 + \
                         w_3[0, band]*pers[2]/100
    print(seed)
    torch.save(mixed, 'static/image_interaction/w_interaction.pt')
    mixed_img = generate_image_from_w(mixed, G, noise_mode = 'const')
    cv2.imwrite(mix_path, cv2.cvtColor(mixed_img, cv2.COLOR_RGB2BGR))

    return json.dumps({"l": query_path+"?"+ str(random.randint(0,10000000)),
                       "interaction":mix_path+"?"+ str(random.randint(0,10000000))})


@app.route("/generate", methods=["POST"])
def generate():
    """Render index.html with a hybrid of image 1 (layers 0-6) and image 2 (layers 7 onward).

    Form fields:
        sleeve, pattern: edit strengths passed to ``E`` for components 0 and 4 of ``vectors``.
        imgShape, imgPattern (optional files): new pictures for slot 1 / slot 2. A supplied
            picture is padded to 256x256, saved over the slot's query image, projected into W
            with ``img_to_w`` and cached. A missing one falls back to the cached latent.
    """
    m_s = float(request.form["sleeve"])
    m_p = float(request.form["pattern"])

    upload_path1 = os.path.join('static/upload_image','query1.jpg')
    upload_path2 = os.path.join('static/upload_image','query2.jpg')

    w_1 = _generate_latent_for_slot(request.files.get("imgShape"), upload_path1, 'static/output/w_1.pt')
    w_2 = _generate_latent_for_slot(request.files.get("imgPattern"), upload_path2, 'static/output/w_2.pt')

    # Keep image 1's first 7 layers and take every later layer from image 2.
    w_1[0, 7:] = w_2[0, 7:]
    # NOTE(review): E comes from style_mixing; it is assumed to move w along the given components
    # ([0] / [4]) by the given strengths on the given layers — TODO confirm.
    w_tmp = E(vectors, [0], [m_s], [np.arange(0,4,1)], w_1.cpu().numpy())
    w_modify = E(vectors, [4], [m_p], [np.arange(3, 9,1)], w_tmp.cpu().numpy())

    img_hybrid = generate_image_from_w(w_modify, G)
    result_path = "static/output/result_img.png"
    cv2.imwrite(result_path, cv2.cvtColor(img_hybrid, cv2.COLOR_RGB2BGR))

    # Append random query strings so the browser reloads the overwritten files.
    return render_template('index.html', img = result_path+"?"+ str(random.randint(0,10000000)),
                           img1 = upload_path1+"?"+ str(random.randint(0,10000000)),
                           img2 = upload_path2+"?"+ str(random.randint(0,10000000)),
                           val1 = int(m_s), val2 = int(m_p))


def _generate_latent_for_slot(upload, image_path, latent_path, num_steps=300):
    """Return the W latent for one /generate slot: project ``upload`` if given, else load the cache."""
    if not upload:
        return torch.load(latent_path)
    upload.save(image_path)
    img = Image.open(image_path).convert("RGB")
    if img.size != (256, 256):
        img = resize_img(img)
    # Save the padded picture (not the raw upload) so the displayed query matches what is projected.
    img.save(image_path)
    w = img_to_w(img, G, num_steps = num_steps, mode = "avg")
    torch.save(w, latent_path)  # cache for later requests
    return w

@app.route("/getImage")
def getImage():
    """Start a new interaction session from a random StyleGAN sample.

    Saves the latent as static/image_interaction/w_interaction.pt and the image as query.jpg,
    and returns its cache-busted URL as JSON ("l").
    """
    seed = random.randint(0, 100)
    image_path = os.path.join('static/image_interaction','query.jpg')
    w_sample = generate_w(generateZ_from_seed(seed, G), G, truncation_psi=1)
    rendered = generate_image_from_w(w_sample, G, noise_mode = 'const')
    torch.save(w_sample, 'static/image_interaction/w_interaction.pt')
    cv2.imwrite(image_path, cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR))
    cache_buster = str(random.randint(0,10000000))
    return json.dumps({"l": image_path+"?"+ cache_buster})

@app.route("/static/upload_image", methods = ["POST"])
def upload_images():
    """Store an uploaded picture as static/upload_image/<random 0-100>.jpg, padded to 256x256.

    NOTE(review): with only 101 possible names, a new upload can overwrite an earlier one.
    """
    uploaded = request.files["image"]
    print(uploaded.filename, uploaded.name)
    target = 'static/upload_image/' + str(random.randint(0, 100))+".jpg"
    uploaded.save(target)
    picture = Image.open(target).convert("RGB")
    if picture.size != (256, 256):
        picture = resize_img(picture)
    picture.save(target)
    cache_buster = str(random.randint(0,10000000))
    return json.dumps({"l": target+"?"+ cache_buster})

@app.route("/changeImage", methods=["POST"])
def changeImage():
    """Edit the current interaction latent along two semantic directions chosen in the UI.

    JSON body:
        x, y: slider positions, mapped linearly from [-10, 10] onto each direction's "mag" range.
        xaxis, yaxis: keys of dim_dic.
        currentX, currentY: integer point id used to name the output files.
    Writes query<X>_<Y>.jpg and w_interaction<X>_<Y>.pt in static/image_interaction and returns
    the image's cache-busted URL as JSON ("l").
    """
    payload = request.get_json(force=True)
    x = float(payload['x'])
    y = float(payload['y'])
    # NOTE(review): the axis names are read crosswise — x is applied to the direction named by
    # 'yaxis' and y to the one named by 'xaxis'. Preserved as-is; confirm against the front end.
    xaxis = payload['yaxis']
    yaxis = payload['xaxis']
    current_x = int(payload['currentX'])
    current_y = int(payload['currentY'])

    w = torch.load('static/image_interaction/w_interaction.pt')
    suffix = str(current_x)+'_'+str(current_y)
    image_path = 'static/image_interaction/query'+suffix+'.jpg'

    def shift_along(direction, slider, w_np):
        """Move w_np along one dim_dic direction by the magnitude the slider position selects."""
        low = direction['mag'][0]
        high = direction['mag'][1]
        magnitude = (slider+10)/20*(high - low) + low
        layers = np.arange(direction['layers'][0], direction['layers'][1],1)
        return E(vectors, [direction["id"]], [magnitude], [layers], w_np)

    dir_x = dim_dic[xaxis]
    dir_y = dim_dic[yaxis]
    w_tmp = shift_along(dir_x, x, w.cpu().numpy())
    w_modify = shift_along(dir_y, y, w_tmp.cpu().numpy())

    edited = generate_image_from_w(w_modify, G, noise_mode = 'const')
    cv2.imwrite(image_path, cv2.cvtColor(edited, cv2.COLOR_RGB2BGR))
    torch.save(w_modify, 'static/image_interaction/w_interaction' +suffix+'.pt')

    return json.dumps({"l": image_path+"?"+ str(random.randint(0,10000000))})

@app.route("/stylemix", methods = ["POST"])
def stylemix():
    """Blend the three cached latents with the percentages posted by the client.

    JSON body: {"0": [p1, p2, p3], "1": [p1, p2, p3]} — percent weights of images 1/2/3 for the
    layer groups id_dic[0] and id_dic[1]. The weights apply to this request only; the
    module-level data_percents defaults are not updated.
    """
    mix_path = "static/image_interaction/query400_400.jpg"
    percents = request.get_json(force=True)
    w_1 = torch.load('static/output/w_1.pt')
    w_2 = torch.load('static/output/w_2.pt')
    w_3 = torch.load('static/output/w_3.pt')
    print(w_1.shape)

    mixed = copy.copy(w_1)
    groups = ((id_dic[0], percents['0']), (id_dic[1], percents['1']))
    for layers, pers in groups:
        band = slice(layers[0], layers[1])
        mixed[:, band] = w_1[0, band]*pers[0]/100 + \
                         w_2[0, band]*pers[1]/100 + \
                         w_3[0, band]*pers[2]/100

    torch.save(mixed, 'static/image_interaction/w_interaction.pt')
    mixed_img = generate_image_from_w(mixed, G, noise_mode = 'const')
    cv2.imwrite(mix_path, cv2.cvtColor(mixed_img, cv2.COLOR_RGB2BGR))
    cache_buster = str(random.randint(0,10000000))
    return json.dumps({"l": mix_path+"?"+ cache_buster})

@app.route("/upload1", methods = ["POST"])
def upload1():
    """Replace query image 1 with an uploaded picture and refresh the mixed image.

    The upload is padded to 256x256, projected into W (img_to_w, 1000 steps) and cached as
    static/output/w_1.pt. It is then blended with the cached latents of images 2 and 3 using the
    default data_percents weights.
    """
    uploaded = request.files["image"]
    query_path = 'static/upload_image/query1.jpg'
    uploaded.save(query_path)
    picture = Image.open(query_path).convert("RGB")
    if picture.size != (256, 256):
        picture = resize_img(picture)
    picture.save(query_path)
    projected = img_to_w(picture, G, num_steps = 1000, mode = "avg")
    torch.save(projected, 'static/output/w_1.pt')  # cache the projection
    mix_path = "static/image_interaction/query400_400.jpg"

    w_1 = torch.load('static/output/w_1.pt')
    w_2 = torch.load('static/output/w_2.pt')
    w_3 = torch.load('static/output/w_3.pt')

    mixed = copy.copy(w_1)
    groups = ((id_dic[0], data_percents['0']), (id_dic[1], data_percents['1']))
    for layers, pers in groups:
        band = slice(layers[0], layers[1])
        mixed[0, band] = w_1[0, band]*pers[0]/100 + \
                         w_2[0, band]*pers[1]/100 + \
                         w_3[0, band]*pers[2]/100

    torch.save(mixed, 'static/image_interaction/w_interaction.pt')
    mixed_img = generate_image_from_w(mixed, G, noise_mode = 'const')
    cv2.imwrite(mix_path, cv2.cvtColor(mixed_img, cv2.COLOR_RGB2BGR))

    return json.dumps({"l": query_path+"?"+ str(random.randint(0,10000000)),
                       "interaction":mix_path+"?"+ str(random.randint(0,10000000))})

@app.route("/upload2", methods = ["POST"])
def upload2():
    """Replace query image 2 with an uploaded picture and refresh the mixed image.

    The upload is padded to 256x256, projected into W (img_to_w, 1000 steps) and cached as
    static/output/w_2.pt. It is then blended with the cached latents of images 1 and 3 using the
    default data_percents weights.
    """
    uploaded = request.files["image"]
    query_path = 'static/upload_image/query2.jpg'
    uploaded.save(query_path)
    picture = Image.open(query_path).convert("RGB")
    if picture.size != (256, 256):
        picture = resize_img(picture)
    picture.save(query_path)
    projected = img_to_w(picture, G, num_steps = 1000, mode = "avg")
    torch.save(projected, 'static/output/w_2.pt')  # cache the projection
    mix_path = "static/image_interaction/query400_400.jpg"

    w_1 = torch.load('static/output/w_1.pt')
    w_2 = torch.load('static/output/w_2.pt')
    w_3 = torch.load('static/output/w_3.pt')

    mixed = copy.copy(w_1)
    groups = ((id_dic[0], data_percents['0']), (id_dic[1], data_percents['1']))
    for layers, pers in groups:
        band = slice(layers[0], layers[1])
        mixed[0, band] = w_1[0, band]*pers[0]/100 + \
                         w_2[0, band]*pers[1]/100 + \
                         w_3[0, band]*pers[2]/100

    torch.save(mixed, 'static/image_interaction/w_interaction.pt')
    mixed_img = generate_image_from_w(mixed, G, noise_mode = 'const')
    cv2.imwrite(mix_path, cv2.cvtColor(mixed_img, cv2.COLOR_RGB2BGR))

    return json.dumps({"l": query_path+"?"+ str(random.randint(0,10000000)),
                       "interaction":mix_path+"?"+ str(random.randint(0,10000000))})

@app.route("/upload3", methods = ["POST"])
def upload3():
    """Replace query image 3 with an uploaded picture and refresh the mixed image.

    The upload is padded to 256x256, projected into W (img_to_w, 1000 steps) and cached as
    static/output/w_3.pt. It is then blended with the cached latents of images 1 and 2 using the
    default data_percents weights.
    """
    uploaded = request.files["image"]
    query_path = 'static/upload_image/query3.jpg'
    uploaded.save(query_path)
    picture = Image.open(query_path).convert("RGB")
    if picture.size != (256, 256):
        picture = resize_img(picture)
    picture.save(query_path)
    projected = img_to_w(picture, G, num_steps = 1000, mode = "avg")
    torch.save(projected, 'static/output/w_3.pt')  # cache the projection

    mix_path = "static/image_interaction/query400_400.jpg"

    w_1 = torch.load('static/output/w_1.pt')
    w_2 = torch.load('static/output/w_2.pt')
    w_3 = torch.load('static/output/w_3.pt')

    mixed = copy.copy(w_1)
    groups = ((id_dic[0], data_percents['0']), (id_dic[1], data_percents['1']))
    for layers, pers in groups:
        band = slice(layers[0], layers[1])
        mixed[0, band] = w_1[0, band]*pers[0]/100 + \
                         w_2[0, band]*pers[1]/100 + \
                         w_3[0, band]*pers[2]/100

    torch.save(mixed, 'static/image_interaction/w_interaction.pt')
    mixed_img = generate_image_from_w(mixed, G, noise_mode = 'const')
    cv2.imwrite(mix_path, cv2.cvtColor(mixed_img, cv2.COLOR_RGB2BGR))

    return json.dumps({"l": query_path+"?"+ str(random.randint(0,10000000)),
                       "interaction":mix_path+"?"+ str(random.randint(0,10000000))})



@app.route('/text2img_re', methods=['POST'])
def text2img_re():
    """Find StyleGAN samples matching the posted caption and return their cache-busted URLs.

    The JSON body is the caption. The DALL-E based text2image() is an alternative generator
    that this route does not use.

    NOTE(review): URLs for all 8 slots are returned even if fewer matches were found. In that
    case the remaining slots show files left over from an earlier request.
    """
    text = request.get_json(force=True)
    num_images = 8
    stylegan_text_image(text, num_show=8, outputs_dir = "static/textimg/")
    urls = {slot: "static/textimg/"+str(slot)+".jpg?"+ str(random.randint(0,10000000))
            for slot in range(num_images)}
    return json.dumps(urls)

@app.route('/clearimg', methods=['POST'])
def clearimg():
    """Promote the selected interaction point to the new starting point of the session.

    JSON body: currentX / currentY identify the point whose files query<X>_<Y>.jpg and
    w_interaction<X>_<Y>.pt were written by /changeImage. Every other non-hidden file in
    static/image_interaction is deleted, and the selected pair is renamed to query400_400.jpg /
    w_interaction.pt.

    If the selected pair does not exist (e.g. no edit was made at that point), nothing is deleted
    and the current query400_400.jpg is returned, so the existing session state is preserved.
    """
    data = request.get_json(force=True)
    current_x = int(data['currentX'])
    current_y = int(data['currentY'])
    path = "static/image_interaction"
    suffix = str(current_x) + '_' + str(current_y)
    keep_img = 'query' + suffix + '.jpg'
    keep_w = 'w_interaction' + suffix + '.pt'
    final_img = os.path.join(path, 'query400_400.jpg')

    # Validate before deleting anything: deleting first and then failing to rename
    # would wipe w_interaction.pt and leave the session without a latent.
    if not (os.path.isfile(os.path.join(path, keep_img)) and os.path.isfile(os.path.join(path, keep_w))):
        return json.dumps({"l": final_img+"?"+ str(random.randint(0,10000000))})

    for name in os.listdir(path):
        full_path = os.path.join(path, name)
        # Keep the selected pair, hidden files, and anything that is not a regular file.
        if name in (keep_img, keep_w) or name.startswith('.') or not os.path.isfile(full_path):
            continue
        os.remove(full_path)
    # os.replace overwrites an existing destination on every platform (os.rename does not on Windows).
    os.replace(os.path.join(path, keep_img), final_img)
    os.replace(os.path.join(path, keep_w), os.path.join(path, 'w_interaction.pt'))
    return json.dumps({"l": final_img+"?"+ str(random.randint(0,10000000))})

# Start the Flask server (run_with_ngrok also opens the public tunnel). In a notebook kernel
# __name__ is "__main__", so this runs here and blocks the cell while the server is up.
if __name__ == "__main__":
    app.run()

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)


 * Running on http://7a6a-34-83-213-242.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040


127.0.0.1 - - [29/Oct/2021 13:33:00] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:01] "[37mGET /static/script.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:01] "[37mGET /static/css/style.css HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:02] "[37mGET /static/textimg/5.jpg HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:02] "[37mGET /static/textimg/4.jpg HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:02] "[37mGET /static/textimg/6.jpg HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:03] "[37mGET /static/textimg/7.jpg HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:03] "[37mGET /static/textimg/0.jpg HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:03] "[37mGET /static/textimg/3.jpg HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:03] "[37mGET /static/upload_image/query1.jpg HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:03] "[37mGET /static/upload_image/query2.jpg HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/202

Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "upfirdn2d_plugin"... Done.
54


127.0.0.1 - - [29/Oct/2021 13:33:48] "[37mGET /getImg1 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:49] "[37mGET /static/upload_image/query1.jpg?9019908 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:33:49] "[37mGET /static/image_interaction/query400_400.jpg?7263371 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:34:36] "[37mPOST /stylemix HTTP/1.1[0m" 200 -


torch.Size([1, 14, 512])


127.0.0.1 - - [29/Oct/2021 13:34:36] "[37mGET /static/image_interaction/query400_400.jpg?8233625 HTTP/1.1[0m" 200 -


model prepared
start sgd


127.0.0.1 - - [29/Oct/2021 13:36:09] "[37mPOST /upload1 HTTP/1.1[0m" 200 -


tensor(0.0637, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.0543, device='cuda:0', grad_fn=<MeanBackward0>)


127.0.0.1 - - [29/Oct/2021 13:36:09] "[37mGET /static/upload_image/query1.jpg?1637740 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:36:09] "[37mGET /static/image_interaction/query400_400.jpg?3439898 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:37:10] "[37mPOST /stylemix HTTP/1.1[0m" 200 -


torch.Size([1, 14, 512])


127.0.0.1 - - [29/Oct/2021 13:37:10] "[37mGET /static/image_interaction/query400_400.jpg?5584708 HTTP/1.1[0m" 200 -


0.31082553
0.3185576
0.31962842
0.3130533
0.31897944
0.32141474
0.31559896


127.0.0.1 - - [29/Oct/2021 13:45:46] "[37mPOST /text2img_re HTTP/1.1[0m" 200 -


0.31829464


127.0.0.1 - - [29/Oct/2021 13:45:46] "[37mGET /static/textimg/0.jpg?3038943 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:45:46] "[37mGET /static/textimg/2.jpg?4922884 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:45:46] "[37mGET /static/textimg/1.jpg?3507412 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:45:46] "[37mGET /static/textimg/3.jpg?8846143 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:45:47] "[37mGET /static/textimg/4.jpg?3776924 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:45:47] "[37mGET /static/textimg/5.jpg?2239124 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:45:47] "[37mGET /static/textimg/6.jpg?3114283 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:45:47] "[37mGET /static/textimg/7.jpg?8374662 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:46:29] "[37mPOST /changeImage HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/2021 13:46:29] "[37mGET /static/image_interaction/query589_418.jpg?9615363 HTTP/1.1[0m" 200 -
127.0.0.1 - - [29/Oct/202