In [None]:
# Install the face-analysis (deepface) and diffusion (diffusers) packages used below.
!pip install deepface diffusers
from deepface import DeepFace
from diffusers.utils import load_image
import numpy as np
# Download a file from Google Drive by id; the next line expects it to be saved as "example.png".
!gdown 1PAalGvpi9X2jpqHFUJVHTuEgfG9wRYSe
im = load_image("example.png")
# NOTE(review): the module-level `objs` is not read again in the visible cells — presumably this
# call only warms DeepFace up (e.g. fetches its model weights) before the working directory
# changes below — TODO confirm.
objs = DeepFace.analyze(img_path=np.array(im.resize((768, 1024))), actions=["age", "gender"])

# Remaining Python dependencies for the diffusion / ControlNet / OpenCV code.
!pip install xformers transformers accelerate
!pip install controlnet_aux
!pip install opencv-contrib-python
# Clone the Real-ESRGAN wrapper and switch into it: from here on, relative paths such as
# 'weights/RealESRGAN_x4.pth' resolve inside the clone.
# NOTE(review): presumably `from realesrgan import RealESRGAN` is also resolved from this
# directory — TODO confirm.
!git clone https://github.com/boomb0om/Real-ESRGAN-colab
%cd Real-ESRGAN-colab
# Download a file from Google Drive by id (expected to be RealESRGAN_x4.pth, see the `mv` below)
# and move it to the path that the upscaler loads its weights from.
!gdown 1SGHdZAln4en65_NQeQY9UjchtkEF9f5F
!mv RealESRGAN_x4.pth weights/RealESRGAN_x4.pth

In [3]:
# imports
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, StableDiffusionImg2ImgPipeline
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from torchvision.transforms.functional import to_pil_image
from diffusers.utils import load_image
from tqdm import tqdm
from torch import autocast
from realesrgan import RealESRGAN
from IPython.display import display
import cv2
import numpy as np
import torch
import requests
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.transforms as T

In [None]:
class ImproveImageQuality:
    """Re-render the person in a photo while keeping the original clothing pixels.

    The pipeline (see ``improve_quality_from_url``) runs these stages:

    1. resize/crop the input so both sides are multiples of 8;
    2. segment the image into clothing / body / background classes;
    3. generate a new person with a Canny-conditioned ControlNet pipeline;
    4. paste the original clothing pixels back over the generated image;
    5. run a low-strength img2img pass over the composite;
    6. run Real-ESRGAN and resize the result back to the working size.

    All models are loaded in ``__init__``; the two diffusion pipelines and the
    upscaler are placed on ``"cuda"``, the segmentation model is left on the
    device ``from_pretrained`` returns it on.
    """

    # Label ids of the "mattmdjaga/segformer_b2_clothes" segmentation model:
    #   {
    #     "0": "Background",
    #     "1": "Hat",
    #     "2": "Hair",
    #     "3": "Sunglasses",
    #     "4": "Upper-clothes",
    #     "5": "Skirt",
    #     "6": "Pants",
    #     "7": "Dress",
    #     "8": "Belt",
    #     "9": "Left-shoe",
    #     "10": "Right-shoe",
    #     "11": "Face",
    #     "12": "Left-leg",
    #     "13": "Right-leg",
    #     "14": "Left-arm",
    #     "15": "Right-arm",
    #     "16": "Bag",
    #     "17": "Scarf"
    #   }
    # Ids copied back from the input photo: Hat, Sunglasses, Upper-clothes, Skirt,
    # Pants, Dress, Belt, Left-shoe, Right-shoe, Scarf.
    _CLOTHES_LABELS = (1, 3, 4, 5, 6, 7, 8, 9, 10, 17)
    _BACKGROUND_LABEL = 0

    # Spellings (lower-cased) that map to "male" in the prompt. "Man" is the label
    # this class has always compared DeepFace's `dominant_gender` against.
    _MALE_ALIASES = frozenset({"man", "male", "m"})

    _PROMPT_HEAD = "A professional photograph of a {age}-year old {gender} fashion model, "
    _PROMPT_TAIL = (
        "ultra-realistic, high detailed skin, realistic proportions, sharp facial features, "
        "depth of field, soft natural light, cinematic photo, 85mm lens, dslr, photorealistic"
    )

    def __init__(self):
        """Load every model used by the pipeline."""
        # for generating in the same position
        self.controlnet_canny = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
        )
        self.pipe_controlnet_canny = StableDiffusionControlNetPipeline.from_pretrained(
            "emilianJR/epiCRealism",
            controlnet=self.controlnet_canny,
            safety_checker=None,
            torch_dtype=torch.float16,
        ).to("cuda")

        # for clothes segmentation
        self.image_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
        self.clothes_seg_model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

        # for reducing mismatchings and noise
        self.pipe_regenerate = StableDiffusionImg2ImgPipeline.from_pretrained(
            "emilianJR/epiCRealism", torch_dtype=torch.float16, safety_checker=None
        ).to("cuda")

        # for upscaling
        self.esrgan = RealESRGAN("cuda", scale=4)
        self.esrgan.load_weights('weights/RealESRGAN_x4.pth')

        self.to_tensor = T.ToTensor()
        self.negative_prompt = "blurry, low quality, deformed, mutated, disfigured, ugly, distorted face, asymmetrical face, bad anatomy, bad hands, extra limbs, poorly drawn face, cloned face, artifact, jpeg artifacts, cartoon, anime, sketch, out of frame"

    def preprocessing_image(self, image_input, max_side=1024):
        """Resize the image so its longer side is ``max_side`` and crop to multiples of 8.

        Args:
            image_input: PIL image.
            max_side: target length, in pixels, of the longer side after resizing.

        Returns:
            ``(image, (width, height))`` — the cropped PIL image and its size.
        """
        multiple = 8
        orig_width, orig_height = image_input.size  # PIL reports (width, height)
        coeff = max(orig_width, orig_height) / max_side
        resized = image_input.resize((int(orig_width / coeff), int(orig_height / coeff)))

        pixels = np.array(resized)
        height, width = pixels.shape[:2]
        # Trim the remainder modulo 8, split as evenly as possible between both edges.
        extra_h, extra_w = height % multiple, width % multiple
        top, left = extra_h // 2, extra_w // 2
        bottom = height - (extra_h - top)
        right = width - (extra_w - left)
        pixels = pixels[top:bottom, left:right]
        return Image.fromarray(pixels), tuple(pixels.shape[:2][::-1])

    def make_canny_mask(self, image_preprocessed, low_threshold=100, high_threshold=150):
        """Return a 3-channel PIL image holding the Canny edge map of the input.

        Args:
            image_preprocessed: PIL image to extract edges from.
            low_threshold: lower hysteresis threshold passed to ``cv2.Canny``.
            high_threshold: upper hysteresis threshold passed to ``cv2.Canny``.
        """
        edges = cv2.Canny(np.array(image_preprocessed), low_threshold, high_threshold)
        # Replicate the single-channel edge map into three identical channels.
        edges = edges[:, :, None]
        return Image.fromarray(np.concatenate([edges, edges, edges], axis=2))

    def generate_person_with_canny(self, canny_mask, prompt, steps=40, guidance=9):
        """Generate an image from ``prompt`` conditioned on the Canny edge image.

        Args:
            canny_mask: edge image produced by ``make_canny_mask``.
            prompt: positive text prompt.
            steps: number of inference steps.
            guidance: guidance scale.

        Returns:
            The first image produced by the ControlNet pipeline.
        """
        return self.pipe_controlnet_canny(
            prompt=prompt,
            image=canny_mask,
            num_inference_steps=steps,
            guidance_scale=guidance,
            negative_prompt=self.negative_prompt,
        ).images[0]

    def get_pred_seg_from_image(self, image_preprocessed):
        """Predict a per-pixel class-id map for the image.

        Returns:
            Tensor of shape ``(height, width)`` with the arg-max class id per pixel,
            after bilinearly upsampling the logits to the image size.
        """
        inputs = self.image_processor(images=image_preprocessed, return_tensors="pt")
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = self.clothes_seg_model(**inputs)
            upsampled_logits = nn.functional.interpolate(
                outputs.logits.cpu(),
                size=image_preprocessed.size[::-1],  # PIL size is (w, h); interpolate wants (h, w)
                mode="bilinear",
                align_corners=False,
            )
        return upsampled_logits.argmax(dim=1)[0]

    def cut_clothes_from_image(self, pred_seg, image_preprocessed):
        """Split the image into clothing pixels and the person without background.

        Args:
            pred_seg: class-id map from ``get_pred_seg_from_image``.
            image_preprocessed: PIL image the map was predicted for.

        Returns:
            ``(mask, cutted_clothes, person_image)`` — the boolean clothing mask,
            the image tensor with everything but clothing zeroed, and a PIL image
            with background pixels zeroed.
        """
        mask = np.isin(pred_seg, list(self._CLOTHES_LABELS))
        mask_bkg = (pred_seg == self._BACKGROUND_LABEL)
        image_tensor = self.to_tensor(image_preprocessed)
        cutted_clothes = image_tensor * mask
        person_image = image_tensor * ~mask_bkg
        return mask, cutted_clothes, to_pil_image(person_image)

    def combine_generated_image_and_clothes(self, generated_image_canny, mask, cutted_clothes):
        """Paste the original clothing pixels over the generated image.

        Args:
            generated_image_canny: generated PIL image.
            mask: boolean clothing mask from ``cut_clothes_from_image``.
            cutted_clothes: clothing-only tensor from ``cut_clothes_from_image``.

        Returns:
            PIL image: generated pixels outside the mask, original clothing inside it.
        """
        generated_image_tensor = self.to_tensor(generated_image_canny)
        generated_image_tensor *= 1 - mask  # clear the clothing region
        generated_image_tensor += cutted_clothes  # fill it with the original pixels
        return to_pil_image(generated_image_tensor)

    def regenerate_image(self, combined_image, prompt, strength=0.25, guidance_scale=6):
        """Run an img2img pass over the composite image.

        Args:
            combined_image: PIL image to refine.
            prompt: positive text prompt.
            strength: img2img strength.
            guidance_scale: guidance scale.

        Returns:
            The first image produced by the img2img pipeline.
        """
        return self.pipe_regenerate(
            prompt=prompt,
            negative_prompt=self.negative_prompt,
            image=combined_image,
            strength=strength,
            guidance_scale=guidance_scale,
        ).images[0]

    @staticmethod
    def _normalize_gender(gender):
        """Map a gender label to the word used in the prompt ("male" or "female").

        Accepts "Man" as well as "male"/"man"/"m" in any letter case; every other
        value maps to "female", as it always has.
        """
        label = str(gender).strip().lower()
        return "male" if label in ImproveImageQuality._MALE_ALIASES else "female"

    @staticmethod
    def _build_prompt(age, gender, background=None):
        """Build the positive prompt, inserting a background clause when one is given."""
        head = ImproveImageQuality._PROMPT_HEAD.format(age=age, gender=gender)
        scene = "" if background is None else f"{background} background, "
        return head + scene + ImproveImageQuality._PROMPT_TAIL

    def improve_quality_from_url(self, url, background=None, gender=None, age=None):
        """Run the whole pipeline on one image.

        Args:
            url: image location handed to ``load_image``.
            background: optional background description inserted into the prompt.
            gender: optional gender label; estimated with DeepFace when ``None``.
            age: optional age; estimated with DeepFace when ``None``.

        Returns:
            ``(input_image, age, gender, result)`` — the input resized to the working
            size, the age and normalized gender used in the prompt, and the final
            image resized to the same working size.
        """
        image_input = load_image(url)

        preprocessed_image, size = self.preprocessing_image(image_input)

        pred_seg = self.get_pred_seg_from_image(preprocessed_image)
        mask, cutted_clothes, person_image = self.cut_clothes_from_image(pred_seg, preprocessed_image)

        # Only ask DeepFace for the attributes the caller did not supply.
        actions = []
        if age is None:
            actions.append("age")
        if gender is None:
            actions.append("gender")
        if actions:
            # DeepFace documents numpy input as BGR, while PIL arrays are RGB.
            face_input = cv2.cvtColor(np.array(person_image), cv2.COLOR_RGB2BGR)
            objs = DeepFace.analyze(img_path=face_input, actions=actions)
            if age is None:
                age = objs[0]["age"]
            if gender is None:
                gender = objs[0]["dominant_gender"]

        gender = self._normalize_gender(gender)
        prompt = self._build_prompt(age, gender, background)

        canny_mask = self.make_canny_mask(person_image)
        generated_image = self.generate_person_with_canny(canny_mask, prompt)

        combined_image = self.combine_generated_image_and_clothes(generated_image, mask, cutted_clothes)

        regenerated_image = self.regenerate_image(combined_image.resize(size), prompt)

        esr_image = self.esrgan.predict(regenerated_image)

        # The upscaled result is resized back to the working size so it matches the returned input.
        return image_input.resize(size), age, gender, esr_image.resize(size)

def image_grid(imgs, rows, cols):
    """Tile equally sized images into one ``rows`` x ``cols`` RGB image.

    Images are placed left to right, top to bottom. The cell size is taken
    from the first image.

    Args:
        imgs: sequence of PIL images; its length must equal ``rows * cols``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new PIL image of size ``(cols * w, rows * h)``.

    Raises:
        AssertionError: if ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols, f"expected {rows * cols} images, got {len(imgs)}"

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for index, img in enumerate(imgs):
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

# Build the pipeline once; the constructor loads every model it uses, so later cells reuse `model`.
model = ImproveImageQuality()

### Provide your image (local path or URL)

In [None]:
# Replace the placeholder with a local file path or an image URL; the next cell passes it to
# model.improve_quality_from_url.
path_to_the_image = "<YOUR PATH or IMAGE URL>"

In [None]:
# Optional keyword arguments: background, gender, age.
input_image, age, gender, result = model.improve_quality_from_url(path_to_the_image)
print(f"age = {age}\ngender = {gender}")
# image_grid is a module-level function, not a method of ImproveImageQuality.
image_grid([input_image, result], 1, 2)