In [None]:
#%pip install ..
%pip install stability-sdk

In [None]:
import getpass, os

# NB: the host URL has no "https://" prefix and no trailing slash.
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'

# To get your API key, visit https://beta.dreamstudio.ai/membership
# Reuse a key that is already exported in the environment so the cell can be
# re-run (or run non-interactively) without prompting; only ask when missing.
if not os.environ.get('STABILITY_KEY'):
    os.environ['STABILITY_KEY'] = getpass.getpass('Enter your API Key')

In [None]:
from argparse import Namespace

import io
import torch
import numpy as np
import sys
import os
import dlib
from pathlib import Path

from PIL import Image


from models.Embedding import Embedding
from models.Alignment import Alignment
from models.Blending import Blending

from utils.drive import open_url
from utils.shape_predictor import align_face

import torchvision

In [None]:
# Options for the face-alignment step. `cache_dir`/`output_dir` are created by
# the shape-predictor cell; `unprocessed_dir`, `output_dir` and `output_size`
# are read by align_the_picture. `inter_method` is not read in this notebook.
align_args = Namespace(**{
    'unprocessed_dir': 'unprocessed',
    'output_dir': 'input/face',
    'output_size': 1024,
    'cache_dir': 'cache',
    'inter_method': 'bicubic',
})

In [None]:
# Load the dlib shape predictor that align_face uses for alignment.
# open_url is given cache_dir, presumably so a previous download is reused —
# confirm in utils.drive.
cache_dir = Path(align_args.cache_dir)
cache_dir.mkdir(exist_ok=True, parents=True)

output_dir = Path(align_args.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)

print("Downloading Shape Predictor")
f = open_url(
    "https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx",
    cache_dir=cache_dir,
    return_path=True,
)
predictor = dlib.shape_predictor(f)

In [None]:
#Automatic Image Cropper
def cropper(path='unprocessed/image.jpg', box=(448, 28, 1472, 1052)):
    """Crop the image at ``path`` in place to ``box`` (RGB output).

    The default box is a 1024x1024 window whose top-left corner is (448, 28),
    i.e. the centred window of a 1920x1080 frame (assumed source size —
    TODO confirm with whatever produces ``unprocessed/image.jpg``).

    The /hair route calls this on every request, so the file may already have
    been cropped by an earlier call. Cropping it again would read outside the
    image (PIL pads with black), so an image that already has the box's size
    is left untouched.

    Args:
        path: image file to crop; it is overwritten with the result.
        box: ``(left, upper, right, lower)`` crop rectangle in pixels.
    """
    left, upper, right, lower = box
    # Read fully and close the file before overwriting the same path.
    with Image.open(path) as src:
        img = src.convert('RGB')

    if img.size == (right - left, lower - upper):
        return

    img.crop(box).save(path)

In [None]:
import io
import warnings


In [None]:
import json
import os
import shutil
from flask import request

from flask import Flask, render_template

In [None]:
def align_the_picture(args=None, shape_predictor=None):
    """Detect, align and save every face found in the unprocessed images.

    For each file matching ``*.*`` in ``args.unprocessed_dir``, the faces
    returned by ``align_face`` are saved as PNGs in ``args.output_dir``:
    ``<stem>.png`` when one face is found, ``<stem>_<i>.png`` when several are.
    When ``args.output_size`` is set and smaller than 1024, each face is
    resized to ``output_size`` x ``output_size`` with LANCZOS.

    Args:
        args: alignment options; defaults to the module-level ``align_args``.
        shape_predictor: dlib shape predictor; defaults to the module-level
            ``predictor``.

    Raises:
        ValueError: if ``args.output_size`` is set but does not divide 1024.
    """
    # Resolve defaults at call time so later re-assignments of the globals are seen.
    if args is None:
        args = align_args
    if shape_predictor is None:
        shape_predictor = predictor

    size = args.output_size
    out_dir = Path(args.output_dir)

    # Validate the size and build the transforms once, not once per face.
    if size:
        factor = 1024 // size
        if size * factor != 1024:
            raise ValueError(f"output_size must divide 1024, got {size!r}")
        to_tensor = torchvision.transforms.ToTensor()
        to_pil = torchvision.transforms.ToPILImage()

    for im in Path(args.unprocessed_dir).glob("*.*"):
        faces = align_face(str(im), shape_predictor)

        for i, face in enumerate(faces):
            if size:
                # Round trip through a float tensor clamped to [0, 1], then back
                # to an 8-bit image. No GPU is needed for this (the previous
                # .cuda()/.cpu() hop produced identical values).
                face = to_pil(to_tensor(face).clamp(0, 1))
                if factor != 1:
                    face = face.resize((size, size), Image.LANCZOS)

            suffix = f"_{i}" if len(faces) > 1 else ""
            face.save(out_dir / f"{im.stem}{suffix}.png")

In [None]:
def Embedding_step(im_path1, args):
    """Run both Embedding inversion passes on a single image.

    Builds ``Embedding(args)``, then calls ``invert_images_in_W`` and
    ``invert_images_in_FS`` in that order, each with the one-element list
    ``[im_path1]``. Returns None; any outputs are produced by those methods.
    """
    embedder = Embedding(args)
    embedder.invert_images_in_W([im_path1])
    embedder.invert_images_in_FS([im_path1])

In [None]:
# Root of the Barbershop project. Override it with the BARBERSHOP_DIR
# environment variable instead of editing every hard-coded path below; the
# default reproduces the previous absolute paths exactly.
BARBERSHOP_DIR = os.environ.get("BARBERSHOP_DIR", "/mnt/d/Programming/Spectrum/Barbershop")

app = Flask(__name__, template_folder=os.path.join(BARBERSHOP_DIR, "templates"))

# Module-level placeholder; the /test route assigns its own local of this name.
userPrompt = ""
# Folder where /hair expects "image_<hair>.png" when blending is disabled.
al_src_path = os.path.join(BARBERSHOP_DIR, "static", "output", "Align_realistic")
# Folder where /hair expects "image_<hair>_<hair>_<sign>.png" after blending.
otp_src_path = os.path.join(BARBERSHOP_DIR, "static", "output")
# Folder holding the downloaded source image "image.jpg".
og_src_path = os.path.join(BARBERSHOP_DIR, "static", "unprocessed")
# Destination folder for the published result "image.png".
fin_path = os.path.join(BARBERSHOP_DIR, "static", "static")



@app.route('/test', methods=['POST'])
def test():
    """Generate an image from the user's prompt with the Stability API.

    Expects a JSON body whose payload (either an object or a JSON-encoded
    string of one) contains ``"userPrompt"``. Every image artifact returned is
    displayed in the notebook and saved to ``static/image.jpg``; later
    artifacts overwrite earlier ones. The parsed payload is echoed back.

    FIXME(review): ``stability_api`` and ``generation`` are not defined anywhere
    in this notebook. They presumably come from ``stability_sdk``: a client
    built from STABILITY_HOST/STABILITY_KEY, plus the generation protobuf
    module. Define them in a setup cell before serving, or this route raises
    NameError.
    """
    payload = request.get_json()
    # The front end appears to post a JSON-encoded string (hence the second
    # decode); accept an already-decoded object too instead of crashing.
    result = json.loads(payload) if isinstance(payload, str) else payload
    userPrompt = result["userPrompt"]
    print(userPrompt, type(userPrompt))
    answers = stability_api.generate(
        prompt="a realistic photo of " + userPrompt,
        start_schedule=0.95,
        # seed=34567, # if provided, specifying a random seed makes results deterministic
        steps=30,  # defaults to 50 if not specified
    )

    # Iterating over the generator produces the API responses.
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                img = Image.open(io.BytesIO(artifact.binary))
                display(img)
                # JPEG cannot store an alpha channel; normalise to RGB first
                # (a no-op copy when the artifact is already RGB).
                img.convert("RGB").save("static/image.jpg")
    return result

def _build_hair_args(userHair):
    """Return the Barbershop options for transferring hairstyle ``userHair``.

    The face is read from ``input/face/image.png`` (what align_the_picture
    writes for a single-face ``image.jpg``). The hairstyle reference is read
    from ``input/face/<userHair>.png``. The remaining values are the
    hyper-parameters this notebook has always used; they are consumed by
    Embedding, Alignment and Blending.
    """
    return Namespace(
        input_dir='input/face', output_dir='output',
        im_path1='image.png', im_path2=f'{userHair}.png', im_path3=f'{userHair}.png',
        sign='realistic', smooth=5, size=1024,
        ckpt='pretrained_models/ffhq.pt', channel_multiplier=2,
        latent=512, n_mlp=8, device='cuda', seed=None, tile_latent=False,
        opt_name='adam', learning_rate=0.01, lr_schedule='fixed',
        save_intermediate=False, save_interval=300, verbose=False,
        seg_ckpt='pretrained_models/seg.pth',
        percept_lambda=1.0, l2_lambda=1.0, p_norm_lambda=0.001, l_F_lambda=0.1,
        W_steps=250, FS_steps=250, ce_lambda=1.0, style_lambda=40000.0,
        align_steps1=100, align_steps2=100, face_lambda=1.0, hair_lambda=1.0,
        blend_steps=200,
    )


@app.route('/hair', methods=['POST'])
def hair():
    """Apply hairstyle ``id`` to the current face image and publish the result.

    Expects a JSON body (an object or a JSON-encoded string of one) with:
      - ``"id"``: hairstyle reference, used as ``input/face/<id>.png``;
      - ``"blendCheck"``: truthy to run the Blending step after Alignment.

    Pipeline:
      1. Crop the source image.
      2. Align faces.
      3. Embed the face.
      4. Align the hair mask.
      5. Optionally blend.
      6. Move the resulting PNG to ``<fin_path>/image.png``.

    The parsed payload is echoed back. An ``id`` that is not a plain file name
    component is rejected with HTTP 400.

    NOTE(review): if alignment finds several faces, it writes
    ``image_<i>.png`` instead of ``image.png``, and the embedding step will not
    find its input.
    """
    print("Running Hair")
    cropper()

    payload = request.get_json()
    # The front end appears to post a JSON-encoded string; accept an object too.
    result = json.loads(payload) if isinstance(payload, str) else payload
    # str() yields the same text the f-strings below produced for ints.
    userHair = str(result["id"])
    blendCheck = result["blendCheck"]
    print(blendCheck, type(blendCheck))

    # userHair comes from the request and is interpolated into file paths:
    # only accept a bare file-name component.
    if userHair in ("", ".", "..") or os.path.basename(userHair) != userHair:
        return {"error": "invalid hairstyle id"}, 400

    print(userHair)

    DOWNLOADED_IMAGE = os.path.join(og_src_path, 'image.jpg')
    print(DOWNLOADED_IMAGE)

    # TODO: image generation code goes here.
    print("Aligning Face")
    align_the_picture()

    print("Embedding")
    args = _build_hair_args(userHair)
    im_path1 = os.path.join(args.input_dir, args.im_path1)
    im_path2 = os.path.join(args.input_dir, args.im_path2)
    Embedding_step(im_path1, args)

    # Mask alignment
    align = Alignment(args)
    align.align_images(im_path1, im_path2, sign=args.sign, align_more_region=False, smooth=args.smooth)

    new_file = os.path.join(fin_path, "image.png")
    if blendCheck:
        blend = Blending(args)
        blend.blend_images(im_path1, im_path2, im_path2, sign=args.sign)
        old_file = os.path.join(otp_src_path, f"image_{userHair}_{userHair}_{args.sign}.png")
        # shutil.move tries os.rename first and falls back to copy+delete
        # across filesystems.
        shutil.move(old_file, new_file)
        print("Blending -> Finished Rename")
    else:
        old_file = os.path.join(al_src_path, f"image_{userHair}.png")
        # A single move: the previous rename-then-move always failed, because
        # the file had already been moved by the rename.
        shutil.move(old_file, new_file)
        print("Finished os manipulation move")
    return result


@app.route('/del', methods=['POST'])
def delete():
    """Delete ``<og_src_path>/image.jpg`` if it exists and echo the payload back.

    The request body is parsed like the other routes (an object or a
    JSON-encoded string of one) and returned unchanged. A missing file is
    reported on stdout rather than treated as an error.
    """
    print("deleting")
    payload = request.get_json()
    result = json.loads(payload) if isinstance(payload, str) else payload
    # Same path as before, derived from og_src_path instead of repeating it.
    target = os.path.join(og_src_path, 'image.jpg')
    # Try and handle the error, rather than check-then-remove, so a concurrent
    # deletion cannot turn into a 500.
    try:
        os.remove(target)
    except FileNotFoundError:
        print("The file does not exist")
    return result


@app.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = 'index.html'
    return render_template(page)


@app.route('/demo')
def demo():
    """Render the ``demo.html`` template."""
    page = 'demo.html'
    return render_template(page)


@app.route('/interpolation')
def interpolation():
    """Render the ``interpolation.html`` template."""
    page = 'interpolation.html'
    return render_template(page)


@app.route('/hairstyle')
def hairstyle():
    """Render the ``hairstyle.html`` template."""
    page = 'hairstyle.html'
    return render_template(page)


if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so running this cell
    # starts the Flask development server and blocks until it is stopped.
    app.run(debug=False)


In [None]:
# import os






# shutil.move(old_file, new_file)