In [None]:
#@title Mounting your Google Drive
# Mount Google Drive into the Colab filesystem, then make the workspace folder
# the current directory so the relative paths used by later cells
# (./sd-concept-output, ./ngrok, buf/) resolve inside it.
from google.colab import drive
drive.mount('/content/drive')
mount_path = '/content/drive/My Drive/SDworkspace'#@param {type:"string"}
# NOTE(review): presumably this folder must already exist on Drive for %cd to succeed — confirm.
%cd $mount_path

## Initial Setup

In [None]:
#@title Install the required libs
# diffusers is pinned to 0.3.0 (with its "training" extra); transformers and ftfy are left unpinned.
!pip install -qq diffusers["training"]==0.3.0 transformers ftfy
# ipywidgets is constrained to the 7.x series.
# NOTE(review): presumably required by the notebook_login() widget used in the next cell — confirm.
!pip install -qq "ipywidgets>=7,<8"

In [None]:
#@title Login to the Hugging Face Hub
#@markdown Add a token with the "Write Access" role to be able to add your trained concept to the [Library of Concepts](https://huggingface.co/sd-concepts-library)
from huggingface_hub import notebook_login

# Ask for a Hugging Face access token. Later cells pass use_auth_token=True to
# from_pretrained(), which presumably picks up the token stored by this login.
notebook_login()

In [None]:
#@title Import required libraries and set up the model
import argparse
import itertools
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
# Identifier of the pretrained Stable Diffusion checkpoint
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"

# Load the CLIP tokenizer and text encoder from their sub-folders of the checkpoint
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=True
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="text_encoder",
    use_auth_token=True,
)

# File holding the learned concept embedding that gets loaded into CLIP
# (relative to the current working directory)
learned_embeds_path = "./sd-concept-output/learned_embeds.bin"
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
  """Register a learned concept embedding with the CLIP tokenizer and text encoder.

  The file at `learned_embeds_path` is read with `torch.load` and is expected to
  hold a mapping from token string to embedding tensor. Only the first entry of
  that mapping is used.

  Args:
    learned_embeds_path: Path of the saved embedding file.
    text_encoder: Text encoder whose input-embedding matrix is resized and
      updated in place.
    tokenizer: Tokenizer that receives the new token (modified in place).
    token: Token string to register the embedding under. When None, the token
      stored in the file is used.

  Raises:
    ValueError: If the file contains no embeddings, or if the tokenizer already
      contains `token`.
  """
  # NOTE(security): torch.load unpickles the file; only load embedding files from trusted sources.
  loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
  if len(loaded_learned_embeds) == 0:
    raise ValueError(f"No learned embeddings found in {learned_embeds_path}.")

  # Take the first (token, embedding) pair stored in the file
  trained_token, embeds = next(iter(loaded_learned_embeds.items()))

  # Cast to the text encoder's dtype. Tensor.to() returns a new tensor rather
  # than converting in place, so the result has to be re-assigned.
  dtype = text_encoder.get_input_embeddings().weight.dtype
  embeds = embeds.to(dtype)

  # Add the token to the tokenizer
  token = token if token is not None else trained_token
  num_added_tokens = tokenizer.add_tokens(token)
  if num_added_tokens == 0:
    raise ValueError(f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.")

  # Resize the token-embedding matrix to make room for the new token
  text_encoder.resize_token_embeddings(len(tokenizer))

  # Look up the new token's id and write the learned embedding into its row
  token_id = tokenizer.convert_tokens_to_ids(token)
  text_encoder.get_input_embeddings().weight.data[token_id] = embeds
# Inject the learned concept into the tokenizer / text encoder before the pipeline is built
load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer)

# Build the Stable Diffusion pipeline around the patched CLIP components
from torch import autocast

pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    torch_dtype=torch.float16,
    use_auth_token=True,
)
# Move the pipeline to the GPU
pipe = pipe.to("cuda")

## Setup Flask Server



In [None]:
#@title Install the required libs
!pip install -U flask-cors
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ngrok-stable-linux-amd64.zip

In [None]:
#@title Check your flask server endpoint 
# Start an ngrok HTTP tunnel to local port 6006 in the background
# (the same port the Flask app is started on at the end of the notebook).
get_ipython().system_raw('./ngrok http 6006 &')
# Query ngrok's local API on port 4040 and print the public URL of the first tunnel.
# NOTE(review): this runs immediately after ngrok is launched; if the tunnel is not
# registered yet, 'tunnels' may be empty and the [0] lookup would fail — presumably
# re-running the cell is enough in that case (TODO confirm).
!curl -s http://localhost:4040/api/tunnels | python3 -c "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

## Run

In [None]:
from flask import Flask, render_template,jsonify, request
from flask_cors import CORS 

import base64
from PIL import Image
import uuid

# Create the Flask application and wrap it with flask-cors using its default settings.
app = Flask(__name__)
CORS(app)


def image_file_to_base64(file_path):
    """Read a file and return its contents as a Base64-encoded string.

    Args:
        file_path: Path of the file to read (opened in binary mode).

    Returns:
        str: The file's bytes encoded as Base64 text.
    """
    with open(file_path, "rb") as image_file:
        data = base64.b64encode(image_file.read())
    return data.decode('utf-8')

# Parameters
prompt = "a birds in the style of kutani-ware" # prompt
num_samples = 4 # number of prompts passed to the pipeline per call
num_rows = 1 # number of pipeline calls per request

@app.route("/", methods=['GET','POST'])
def index():
  """Generate images for a prompt and return them as Base64-encoded PNG data.

  The prompt is read from the `prompt` query-string parameter. When that
  parameter is missing or empty, the module-level default `prompt` is used
  instead of passing None to the pipeline.

  Each generated image is saved as buf/<uuid>.png and then read back and
  Base64-encoded.

  Returns:
    A JSON response of the form {"image": [...]} with one Base64 string per
    generated image.
  """
  # A distinct local name keeps the module-level default `prompt` readable from here.
  requested_prompt = request.args.get('prompt') or prompt

  # Run inference
  all_images = []
  for _ in range(num_rows):
    with autocast("cuda"):
      images = pipe([requested_prompt] * num_samples, num_inference_steps=50, guidance_scale=7.5)["sample"]
      all_images.extend(images)

  # Make sure the output directory exists before the first image is written.
  os.makedirs("buf", exist_ok=True)
  res = []
  for img in all_images:
    image_id = str(uuid.uuid4())  # random unique file name; avoids shadowing the builtin id()
    image_path = f"buf/{image_id}.png"
    img.save(image_path)
    res.append(image_file_to_base64(image_path))

  return jsonify({"image":res})

# Start Flask's built-in server on port 6006 — the port the ngrok tunnel
# started earlier in the notebook forwards to.
if __name__ == '__main__':
    app.run(port=6006)