In [1]:
from __future__ import annotations

import argparse
import os
import pathlib
import subprocess
import sys
from typing import Callable, Union

import dlib
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torchvision.transforms as T


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
from mapper.hairclip_mapper import HairCLIPMapper

In [4]:
def load_hairclip(ckpt_path: str | os.PathLike = "pretrained_models/hairclip.pt",
                  device: str = "cpu") -> nn.Module:
    """Build a HairCLIPMapper from a checkpoint, ready for inference.

    The checkpoint's stored ``opts`` dict is copied and overridden for
    text-driven editing of both hairstyle and colour. It is then wrapped in
    an ``argparse.Namespace`` and passed to ``HairCLIPMapper``.

    Args:
        ckpt_path: Path to the HairCLIP checkpoint. Defaults to the
            originally hard-coded ``"pretrained_models/hairclip.pt"``.
        device: Device written into ``opts.device`` and to which the model
            is moved. Defaults to ``"cpu"`` (the original behaviour).

    Returns:
        The mapper, moved to ``device`` and switched to eval mode.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On newer torch releases the default ``weights_only=True``
    # may reject the pickled ``opts`` entry; confirm against the installed
    # torch version.
    ckpt = torch.load(ckpt_path, map_location='cpu')

    # Copy so the loaded checkpoint dict itself is not mutated.
    opts_dict = dict(ckpt['opts'])
    opts_dict['device'] = device
    opts_dict['checkpoint_path'] = os.fspath(ckpt_path)
    opts_dict['editing_type'] = 'both'
    opts_dict['input_type'] = 'text'
    # NOTE(review): this path is prefixed with 'HairCLIP/' while ckpt_path is
    # not. Both are resolved against the current working directory; confirm
    # they agree with where the notebook is run from.
    opts_dict['hairstyle_description'] = 'HairCLIP/mapper/hairstyle_list.txt'
    opts_dict['color_description'] = 'red'
    opts = argparse.Namespace(**opts_dict)

    model = HairCLIPMapper(opts)
    model.to(device)
    model.eval()
    return model

In [5]:

def generate(editing_type: str, hairstyle_index: int,
             color_description: str, latent: torch.Tensor,
             hairclip: nn.Module | None = None) -> np.ndarray:
    """Apply a HairCLIP hair edit to one latent code and decode the result.

    Args:
        editing_type: Written into ``opts.editing_type`` (e.g. ``'both'`` or
            ``'color'``). With ``'color'`` the hairstyle prompt index is
            forced to 0.
        hairstyle_index: Index into the hairstyle text inputs returned by
            ``LatentsDatasetInference`` for this latent.
        color_description: Colour text prompt written into
            ``opts.color_description``.
        latent: A single latent code. A batch dimension is added before it
            is handed to ``LatentsDatasetInference`` (exact expected shape
            not visible here -- TODO confirm).
        hairclip: Optional, already loaded mapper to reuse. When ``None``
            (the default) a fresh model is loaded via ``load_hairclip()`` on
            every call, which is slow. Its ``opts`` are modified in place.

    Returns:
        The output of ``postprocess`` on the decoded image, clamped to
        [-1, 1] (annotated as ``np.ndarray``).
    """
    if hairclip is None:
        hairclip = load_hairclip()
    opts = hairclip.opts
    opts.editing_type = editing_type
    opts.color_description = color_description
    if editing_type == 'color':
        # A colour-only edit does not use the hairstyle prompt; just pick the
        # first entry.
        hairstyle_index = 0
    device = torch.device(opts.device)

    dataset = LatentsDatasetInference(latents=latent.unsqueeze(0).cpu(),
                                      opts=opts)
    w, hairstyle_text_inputs_list, color_text_inputs_list = dataset[0][:3]
    w = w.unsqueeze(0).to(device)
    hairstyle_text_inputs = hairstyle_text_inputs_list[
        hairstyle_index].unsqueeze(0).to(device)
    color_text_inputs = color_text_inputs_list[0].unsqueeze(0).to(device)
    # Zero placeholders of shape (1, 1) for the hair-masked inputs; same
    # value, shape and default dtype as torch.Tensor([0]).unsqueeze(0).
    hairstyle_tensor_hairmasked = torch.zeros(1, 1, device=device)
    color_tensor_hairmasked = torch.zeros(1, 1, device=device)

    # Pure inference: disable autograd so the forward passes do not build
    # and retain a computation graph.
    with torch.no_grad():
        # The mapper output is added as a residual scaled by 0.1 (presumably
        # the step size used by HairCLIP -- TODO confirm).
        w_hat = w + 0.1 * hairclip.mapper(
            w,
            hairstyle_text_inputs,
            color_text_inputs,
            hairstyle_tensor_hairmasked,
            color_tensor_hairmasked,
        )
        x_hat, _ = hairclip.decoder(
            [w_hat],
            input_is_latent=True,
            return_latents=True,
            randomize_noise=False,
            truncation=1,
        )
    res = torch.clamp(x_hat[0].detach(), -1, 1)
    # NOTE(review): `postprocess` is not defined in any visible cell (In [3]
    # is missing); make sure it is defined before this cell runs on a fresh
    # kernel.
    return postprocess(res)