In [None]:
!pip install torchvision
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install tensorflow-io
!pip install Ninja
!pip install dlib
!pip install cog

In [None]:
!pip install gdown

In [None]:
%cd /kaggle/working/
!git clone https://github.com/wty-ustc/HairCLIP.git

In [None]:
# Work from the HairCLIP repository root; later cells use paths relative to it.
%cd /kaggle/working/HairCLIP

In [None]:
!git clone https://github.com/omertov/encoder4editing.git

# Data

In [None]:
!gdown --id '1gof8kYc_gDLUT4wQlmUdAtPnQIlCO26q' -O "/kaggle/working/HairCLIP/mapper/datasets/train_faces.pt"

In [None]:
!gdown --id '1j7RIfmrCoisxx3t-r-KC02Qc8barBecr' -O '/kaggle/working/HairCLIP/mapper/datasets/test_faces.pt'

In [None]:
import torch
# Pre-computed latent codes for the training faces (downloaded above).
train_latents = torch.load("/kaggle/working/HairCLIP/mapper/datasets/train_faces.pt")

In [None]:
# Pre-computed latent codes for the test faces (downloaded above).
test_latents = torch.load("/kaggle/working/HairCLIP/mapper/datasets/test_faces.pt")

In [None]:
# Sanity-check the shapes of the loaded latent tensors.
train_latents.shape, test_latents.shape

# Checkpoints

In [None]:
!gdown --id '1cUv_reLE6k3604or78EranS7XzuVMWeO' -O "/kaggle/working/HairCLIP/pretrained_models/e4e_ffhq_encode.pt"

In [None]:
!gdown --id '1pts5tkfAcWrg4TpLDu6ILF5wHID32Nzm' -O "/kaggle/working/HairCLIP/pretrained_models/stylegan2-ffhq-config-f.pt"

In [None]:
!gdown --id '1hqZT6ZMldhX3M_x378Sm4Z2HMYr-UwQ4' -O "/kaggle/working/HairCLIP/pretrained_models/hairclip.pt"

In [None]:
!gdown --id '1FS2V756j-4kWduGxfir55cMni5mZvBTv' -O "/kaggle/working/HairCLIP/pretrained_models/model_ir_se50.pth"

In [None]:
!wget -P /kaggle/working/HairCLIP/pretrained_models http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 # DOWNLOAD LINK

!bunzip2 /kaggle/working/HairCLIP/pretrained_models/shape_predictor_68_face_landmarks.dat.bz2

# LatentsDatasetInference.py

In [None]:
from torch.utils.data import Dataset
import sys
import numpy as np
import clip
import torch
import random
from PIL import Image
import torchvision.transforms as transforms
sys.path.insert(0, "/kaggle/working/HairCLIP/mapper")
from training import train_utils
import os

In [None]:
class LatentsDatasetInference(Dataset):
    """Inference dataset pairing each latent code with every requested hair edit.

    Each item is a 6-tuple ``(latent, hairstyle_text_embeddings,
    color_text_embeddings, descriptions, hairstyle_ref_tensors,
    color_ref_tensors)``. Every list holds one entry per edit. Unused slots
    contain the placeholder ``torch.Tensor([0])``.

    ``opts.input_type`` is ``"<hairstyle source>_<color source>"``; a single
    word applies to both parts. Each source is ``"text"`` or ``"image"``.
    """

    def __init__(self, latents, opts):
        self.latents = latents
        self.opts = opts
        hairstyle_source, color_source = self._input_sources()
        edits_hairstyle = self.opts.editing_type in ['hairstyle', 'both']
        edits_color = self.opts.editing_type in ['color', 'both']

        if edits_hairstyle and hairstyle_source == 'text':
            with open(self.opts.hairstyle_description, "r") as fd:
                self.hairstyle_description_list = fd.read().splitlines()
            # Drop the last 9 characters (expected to be "hairstyle"); manipulate_hairstyle
            # appends 'hairstyle' again when it builds the CLIP prompt.
            self.hairstyle_list = [description[:-9] for description in self.hairstyle_description_list]
        if edits_color and color_source == 'text':
            # "purple, red" -> ["purple ", "red "]; 'hair' is appended in manipulater_color.
            self.color_list = [color.strip() + ' ' for color in self.opts.color_description.split(',')]
        if edits_hairstyle and hairstyle_source == 'image':
            self.out_domain_hairstyle_img_path_list = sorted(train_utils.make_dataset(self.opts.hairstyle_ref_img_test_path))
        if edits_color and color_source == 'image':
            self.out_domain_color_img_path_list = sorted(train_utils.make_dataset(self.opts.color_ref_img_test_path))

        # Reference images are resized to 1024x1024 and scaled to [-1, 1].
        self.image_transform = transforms.Compose([transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    def _input_sources(self):
        """Return ``(hairstyle_source, color_source)`` parsed from ``opts.input_type``."""
        parts = self.opts.input_type.split('_')
        return parts[0], parts[-1]

    @staticmethod
    def _placeholders(count):
        """Return ``count`` independent ``torch.Tensor([0])`` placeholders for unused slots."""
        return [torch.Tensor([0]) for _ in range(count)]

    def _sample_reference_images(self, path_list):
        """Load ``opts.num_of_ref_img`` images picked at random (with replacement).

        Images are converted to RGB so grayscale/RGBA files fit the 3-channel Normalize.
        """
        return [self.image_transform(Image.open(random.choice(path_list)).convert('RGB'))
                for _ in range(self.opts.num_of_ref_img)]

    def manipulate_hairstyle(self, index):
        """Hairstyle-only edits: one per hairstyle prompt (text) or per sampled reference (image).

        Raises:
            ValueError: if the hairstyle part of ``opts.input_type`` is neither 'text' nor 'image'.
        """
        hairstyle_source, _ = self._input_sources()
        if hairstyle_source == 'text':
            count = len(self.hairstyle_list)
            descriptions = [hairstyle + 'hairstyle' for hairstyle in self.hairstyle_list]
            hairstyle_text_embedding_list = [clip.tokenize(description)[0] for description in descriptions]
            hairstyle_tensor_list = self._placeholders(count)
        elif hairstyle_source == 'image':
            count = self.opts.num_of_ref_img
            descriptions = ['hairstyle_out_domain_ref'] * count
            hairstyle_text_embedding_list = self._placeholders(count)
            hairstyle_tensor_list = self._sample_reference_images(self.out_domain_hairstyle_img_path_list)
        else:
            raise ValueError(f"Unsupported hairstyle input source: {hairstyle_source!r}")
        color_text_embedding_list = self._placeholders(count)
        color_tensor_list = self._placeholders(count)
        return self.latents[index], hairstyle_text_embedding_list, color_text_embedding_list, descriptions, hairstyle_tensor_list, color_tensor_list

    def manipulater_color(self, index):
        """Color-only edits: one per color prompt (text) or per sampled reference (image).

        Raises:
            ValueError: if the color part of ``opts.input_type`` is neither 'text' nor 'image'.
        """
        _, color_source = self._input_sources()
        if color_source == 'text':
            count = len(self.color_list)
            descriptions = [color + 'hair' for color in self.color_list]
            color_text_embedding_list = [clip.tokenize(description)[0] for description in descriptions]
            color_tensor_list = self._placeholders(count)
        elif color_source == 'image':
            count = self.opts.num_of_ref_img
            descriptions = ['color_out_domain_ref'] * count
            color_text_embedding_list = self._placeholders(count)
            color_tensor_list = self._sample_reference_images(self.out_domain_color_img_path_list)
        else:
            raise ValueError(f"Unsupported color input source: {color_source!r}")
        hairstyle_text_embedding_list = self._placeholders(count)
        hairstyle_tensor_list = self._placeholders(count)
        return self.latents[index], hairstyle_text_embedding_list, color_text_embedding_list, descriptions, hairstyle_tensor_list, color_tensor_list

    def manipulater_hairstyle_and_color(self, index):
        """Combined edits: the cross product of hairstyle edits (outer) and color edits (inner)."""
        returned_latent, hairstyle_text_embedding_list, _, hairstyle_descriptions, hairstyle_tensor_list, _ = self.manipulate_hairstyle(index)
        _, _, color_text_embedding_list, color_descriptions, _, color_tensor_list = self.manipulater_color(index)
        hairstyle_text_embedding_final_list = [h for h in hairstyle_text_embedding_list for _ in color_text_embedding_list]
        color_text_embedding_final_list = [c for _ in hairstyle_text_embedding_list for c in color_text_embedding_list]
        selected_description_list = [f'{h}-{c}' for h in hairstyle_descriptions for c in color_descriptions]
        hairstyle_tensor_final_list = [h for h in hairstyle_tensor_list for _ in color_tensor_list]
        color_tensor_final_list = [c for _ in hairstyle_tensor_list for c in color_tensor_list]
        return returned_latent, hairstyle_text_embedding_final_list, color_text_embedding_final_list, selected_description_list, hairstyle_tensor_final_list, color_tensor_final_list

    def __len__(self):
        return self.latents.shape[0]

    def __getitem__(self, index):
        """Dispatch on ``opts.editing_type``; raise ValueError instead of silently returning None."""
        handlers = {
            'hairstyle': self.manipulate_hairstyle,
            'color': self.manipulater_color,
            'both': self.manipulater_hairstyle_and_color,
        }
        handler = handlers.get(self.opts.editing_type)
        if handler is None:
            raise ValueError(f"Unknown editing_type: {self.opts.editing_type!r}")
        return handler(index)

# Predict.py

In [None]:
import sys
import tempfile
from argparse import Namespace

import dlib
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from cog import BasePredictor, Path, Input
from criteria.parse_related_loss import average_lab_color_loss

In [None]:
sys.path.insert(0, "encoder4editing")
from models.psp import pSp
from utils.alignment import align_face

In [None]:
sys.path.insert(0, "mapper")
# from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
from mapper.hairclip_mapper import HairCLIPMapper
from mapper.options.test_options import TestOptions

In [None]:
%matplotlib inline
with open("mapper/hairstyle_list.txt") as infile:
    HAIRSTYLE_LIST = sorted([line.rstrip() for line in infile])


class Predictor_e4e(BasePredictor):
    """HairCLIP hair editor driven by an e4e inversion of the input photo.

    ``setup`` loads the e4e encoder and the HairCLIP checkpoint once.
    ``predict`` does the rest:

    - aligns the face;
    - inverts it to a latent code with e4e;
    - runs the HairCLIP mapper;
    - saves a comparison image to ``/kaggle/working/output.png``.

    Relative paths (``pretrained_models/...``) are resolved against the current
    working directory, which the notebook sets to the HairCLIP repository root.
    """

    def setup(self):
        """Load the e4e encoder onto the GPU and the HairCLIP checkpoint onto the CPU."""
        self.device = "cuda:0"
        # use e4e to get latent code for an input image
        e4e_model_path = "pretrained_models/e4e_ffhq_encode.pt"
        e4e_ckpt = torch.load(e4e_model_path, map_location="cpu")
        e4e_opts = e4e_ckpt["opts"]
        e4e_opts["checkpoint_path"] = e4e_model_path
        e4e_opts = Namespace(**e4e_opts)

        self.e4e_net = pSp(e4e_opts)
        self.e4e_net.eval()
        self.e4e_net.cuda()
        print("e4e model successfully loaded!")

        # Resize to 256x256 and scale pixels from [0, 1] to [-1, 1] before encoding.
        self.img_transforms = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        # HairClip model: only the checkpoint is cached here. The mapper is built per
        # call in predict() because its options depend on the request.
        checkpoint_path = "pretrained_models/hairclip.pt"
        self.ckpt = torch.load(checkpoint_path, map_location="cpu")

    def predict(
        self,
        image: Path = Input(
            description="Input image. Image will be aligned and resized. Output will be the "
            "concatenation of the inverted input and the image with edited hair."
        ),
        editing_type: str = Input(
            choices=["hairstyle", "color", "both"],
            default="hairstyle",
            description="Edit hairstyle or color or both.",
        ),
        input_type: str = Input(
            choices = ["text", "image", "text_image", "image_text"],
            description="(1)_(2):  (1) for hairstyle and (2) for color"
        ),
        hairstyle_description: str = Input(
            choices=HAIRSTYLE_LIST,
            default=None,
            description="Hairstyle text prompt. "
            "Valid if input_type is text or text_image.",
        ),
        hairstyle_ref_img_test_path: Path = Input(),

        color_description: str = Input(
            default=None,
            description="Color text prompt, eg: purple, red, orange. "
            "Valid if editing_type is color or both.",
        ),
        color_ref_img_test_path: Path = Input()
    ) -> Path:
        """Edit the hair of the face in ``image`` and return the path of the result.

        ``input_type`` is "<hairstyle source>_<color source>". Each source is
        "text" or "image", and a single word applies to both.

        - Text sources use ``hairstyle_description`` / ``color_description``.
        - Image sources sample reference images from the folders
          ``hairstyle_ref_img_test_path`` / ``color_ref_img_test_path``.

        Raises:
            ValueError: if ``editing_type`` is unknown or an input required by the
                chosen ``input_type`` is missing.
        """
        # Presumably strips an enum-style prefix (e.g. "SomeEnum.both") from the value.
        editing_type_ = str(editing_type).split(".")[-1]
        # Keep None as None: str(None) is the string "None", which would defeat validation.
        hairstyle_description_ = (
            None if hairstyle_description is None
            else str(hairstyle_description).split(".")[-1]
        )
        self._validate_inputs(
            editing_type_,
            input_type,
            hairstyle_description_,
            color_description,
            hairstyle_ref_img_test_path,
            color_ref_img_test_path,
        )

        opts = Namespace(**self.ckpt["opts"])

        opts.num_of_ref_img = 5
        opts.editing_type = editing_type_
        opts.input_type = input_type
        opts.color_description = color_description
        if hairstyle_description_ is not None:
            # LatentsDatasetInference reads hairstyle prompts from a file, one per line.
            with open("hairstyle_description.txt", "w") as file:
                file.write(hairstyle_description_)
            opts.hairstyle_description = "hairstyle_description.txt"
        opts.color_ref_img_test_path = color_ref_img_test_path
        opts.hairstyle_ref_img_test_path = hairstyle_ref_img_test_path

        opts.checkpoint_path = "pretrained_models/hairclip.pt"
        # NOTE(review): the notebook's download cells do not fetch parsenet.pth; confirm it exists.
        opts.parsenet_weights = "pretrained_models/parsenet.pth"
        opts.stylegan_weights = "pretrained_models/stylegan2-ffhq-config-f.pt"
        opts.ir_se50_weights = "pretrained_models/model_ir_se50.pth"
        net = HairCLIPMapper(opts)
        net.eval()
        net.cuda()

        # Align the face, then get its latent code. img_transforms already resizes to
        # 256x256. The former `input_image.resize(...)` discarded its result (PIL returns
        # a new image), so it had no effect and has been removed.
        input_image = run_alignment(str(image))
        transformed_image = self.img_transforms(input_image)

        with torch.no_grad():
            _, latents = run_on_batch_e4e(
                transformed_image.unsqueeze(0), self.e4e_net
            )
            print("Latent code calculated!")
            print(f"Latent code shape: {latents.shape}")

        dataset = LatentsDatasetInference(latents=latents.cpu(), opts=opts)
        dataloader = DataLoader(dataset)

        average_color_loss = (
            average_lab_color_loss.AvgLabLoss(opts).to(self.device).eval()
        )

        # Overwritten per batch; a single input image yields exactly one batch.
        out_path = Path("/kaggle/working/output.png")

        for input_batch in tqdm(dataloader):
            with torch.no_grad():
                (
                    w,
                    hairstyle_text_inputs_list,
                    color_text_inputs_list,
                    _selected_description_tuple_list,
                    hairstyle_tensor_list,
                    color_tensor_list,
                ) = input_batch
                # Only the first edit of each per-item list is rendered.
                w = w.cuda().float()
                hairstyle_text_inputs = hairstyle_text_inputs_list[0].cuda()
                color_text_inputs = color_text_inputs_list[0].cuda()
                hairstyle_tensor = hairstyle_tensor_list[0].cuda()
                color_tensor = color_tensor_list[0].cuda()

                hairstyle_tensor_hairmasked = self._mask_hair(hairstyle_tensor, average_color_loss)
                color_tensor_hairmasked = self._mask_hair(color_tensor, average_color_loss)

                result_batch = run_on_batch(
                    w,
                    hairstyle_text_inputs,
                    color_text_inputs,
                    hairstyle_tensor_hairmasked,
                    color_tensor_hairmasked,
                    net,
                )

                couple_output = self._compose_output(result_batch, hairstyle_tensor, color_tensor)
                torchvision.utils.save_image(
                    couple_output, str(out_path), normalize=True
                )

        return out_path

    @staticmethod
    def _validate_inputs(editing_type, input_type, hairstyle_description,
                         color_description, hairstyle_ref_img_test_path,
                         color_ref_img_test_path):
        """Raise ValueError if an input required by the requested edit is missing.

        A "text" source needs a description. An "image" source needs a
        reference-image folder.
        """
        if editing_type not in ("hairstyle", "color", "both"):
            raise ValueError(f"Unknown editing_type: {editing_type!r}")
        source_parts = str(input_type).split("_")
        hairstyle_source, color_source = source_parts[0], source_parts[-1]
        if editing_type in ("hairstyle", "both"):
            if hairstyle_source == "text" and hairstyle_description is None:
                raise ValueError("Please provide description for hairstyle.")
            if hairstyle_source == "image" and hairstyle_ref_img_test_path is None:
                raise ValueError("Please provide a reference image folder for hairstyle.")
        if editing_type in ("color", "both"):
            if color_source == "text" and color_description is None:
                raise ValueError("Please provide description for color.")
            if color_source == "image" and color_ref_img_test_path is None:
                raise ValueError("Please provide a reference image folder for color.")

    @staticmethod
    def _mask_hair(tensor, average_color_loss):
        """Keep only the hair region of a batched reference image, or return a placeholder.

        Placeholders from LatentsDatasetInference arrive as shape (1, 1). Real
        references have 3 channels in dim 1.
        """
        if tensor.shape[1] != 1:
            return tensor * average_color_loss.gen_hair_mask(tensor)
        return torch.Tensor([0]).unsqueeze(0).cuda()

    @staticmethod
    def _compose_output(result_batch, hairstyle_tensor, color_tensor):
        """Stack (along the batch dim) reconstruction, edit and any references for saving.

        Order matches the original layout:
        - both references: [inversion, edit, hairstyle ref, color ref];
        - one reference: [inversion, ref, edit];
        - no references: [inversion, edit].
        """
        edited = result_batch[0][0].unsqueeze(0)
        inverted = result_batch[2][0].unsqueeze(0)
        has_hairstyle_ref = hairstyle_tensor.shape[1] != 1
        has_color_ref = color_tensor.shape[1] != 1
        if has_hairstyle_ref and has_color_ref:
            return torch.cat([inverted, edited, hairstyle_tensor, color_tensor])
        if has_hairstyle_ref:
            return torch.cat([inverted, hairstyle_tensor, edited])
        if has_color_ref:
            return torch.cat([inverted, color_tensor, edited])
        return torch.cat([inverted, edited])

In [None]:
import functools

datFile = "/kaggle/working/HairCLIP/pretrained_models/shape_predictor_68_face_landmarks.dat"


@functools.lru_cache(maxsize=None)
def _load_shape_predictor(predictor_path):
    """Load the dlib landmark predictor once per path instead of on every call."""
    return dlib.shape_predictor(predictor_path)


def run_alignment(image_path, predictor_path=datFile):
    """Align and crop the face in ``image_path`` with e4e's ``align_face``.

    Args:
        image_path: path of the photo to align.
        predictor_path: dlib 68-landmark model file; defaults to ``datFile``.

    Returns:
        The aligned image from ``align_face``. Presumably a PIL image: it exposes
        ``.size`` and is fed to torchvision transforms by the caller.
    """
    predictor = _load_shape_predictor(predictor_path)
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image

In [None]:
def run_on_batch_e4e(inputs, net):
    """Invert a batch of normalized face images with the e4e network.

    The batch is moved to CUDA as float32. Returns ``(images, latents)`` from
    ``net``, called with fixed noise and ``return_latents=True``.
    """
    batch = inputs.to("cuda").float()
    outputs = net(batch, randomize_noise=False, return_latents=True)
    images, latents = outputs
    return images, latents

In [None]:
def run_on_batch(
    inputs,
    hairstyle_text_inputs,
    color_text_inputs,
    hairstyle_tensor_hairmasked,
    color_tensor_hairmasked,
    net,
):
    """Apply the HairCLIP mapper to latent codes and decode edited and original images.

    The mapper output is added to ``inputs`` as a residual scaled by 0.1. Returns
    ``(x_hat, w_hat, x)``:

    - ``x_hat``: the edited image;
    - ``w_hat``: the latents the decoder returned for it;
    - ``x``: the decoding of the unedited ``inputs``.
    """
    with torch.no_grad():
        residual = net.mapper(
            inputs,
            hairstyle_text_inputs,
            color_text_inputs,
            hairstyle_tensor_hairmasked,
            color_tensor_hairmasked,
        )
        edited_codes = inputs + 0.1 * residual
        decoder_kwargs = dict(input_is_latent=True, randomize_noise=False, truncation=1)
        x_hat, w_hat = net.decoder([edited_codes], return_latents=True, **decoder_kwargs)
        x, _ = net.decoder([inputs], **decoder_kwargs)
    return x_hat, w_hat, x

# Check image

## Input_image

In [None]:
import shutil
import os

# Define a function to remove a directory and its contents recursively
def remove_folder(folder_path):
    shutil.rmtree(folder_path)

# Example usage:
if os.path.exists('/kaggle/working/Image'):
    remove_folder('/kaggle/working/Image')
    
!git clone https://github.com/HongQuan2003/21522490_final_ACV.git /kaggle/working/Image

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torchvision.io as io

input_image_path = "/kaggle/working/Image/Input_image/IMG_0127.jpeg"
input_image = io.read_image(input_image_path)

# Plot the image
plt.imshow(input_image.permute(1, 2, 0))
plt.axis('off')
plt.show()

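**Rotate (if needed)** — run the next cell only if the photo above is displayed sideways; it turns the image 90° clockwise.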

In [None]:
rotated_input_image_path = "/kaggle/working/rotated_input_image.jpeg"

# torch.rot90 over the (H, W) dims: k=-1 turns the image 90° *clockwise*
# (k=1 would be counter-clockwise).
ROTATION_K = -1

# Rotate only the un-rotated source. After the first run input_image_path points at the
# rotated copy, so re-running this cell must not rotate it a second time.
if input_image_path != rotated_input_image_path:
    source_image = io.read_image(input_image_path)
    rotated_input_image = TF.to_pil_image(torch.rot90(source_image, k=ROTATION_K, dims=(1, 2)))
    rotated_input_image.save(rotated_input_image_path)
    input_image_path = rotated_input_image_path

input_image = io.read_image(input_image_path)

# Plot the image
plt.imshow(input_image.permute(1, 2, 0))
plt.axis('off')
plt.show()

## Ref_image

### hairstyle

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torchvision.io as io
import os

ref_image_path = "/kaggle/input/flickr-faces-hq-dataset-ffhq/images1024x1024/00000.png"

if os.path.exists('/kaggle/working/reference_image_hairstyle') == False:
    os.mkdir('/kaggle/working/reference_image_hairstyle')
    
ref_image = io.read_image(ref_image_path)
ref_image = TF.to_pil_image(ref_image)
ref_image.save("/kaggle/working/reference_image_hairstyle/ref_img.jpeg")


ref_image = io.read_image("/kaggle/working/reference_image_hairstyle/ref_img.jpeg")
# Plot the image
plt.imshow(ref_image.permute(1, 2, 0))
plt.axis('off')
plt.show()

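**Rotate (if needed)** — run the next cell only if the hairstyle reference above is displayed sideways; it turns the image 90° clockwise.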

In [None]:
# torch.rot90 over the (H, W) dims: k=-1 turns the image 90° *clockwise*
# (k=1 would be counter-clockwise).
ROTATION_K = -1
hairstyle_ref_file = "/kaggle/working/reference_image_hairstyle/ref_img.jpeg"

# Rotate the original source (ref_image_path from the previous cell) rather than the saved
# JPEG, so re-running this cell does not rotate twice or re-compress the image again.
source_ref_image = io.read_image(ref_image_path, mode=io.ImageReadMode.RGB)
rotated_ref_image = TF.to_pil_image(torch.rot90(source_ref_image, k=ROTATION_K, dims=(1, 2)))
rotated_ref_image.save(hairstyle_ref_file)

ref_image = io.read_image(hairstyle_ref_file)

# Plot the image
plt.imshow(ref_image.permute(1, 2, 0))
plt.axis('off')
plt.show()

### Color

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torchvision.io as io
import os

ref_image_path = "/kaggle/input/celebahq-resized-256x256/celeba_hq_256/00003.jpg"

if os.path.exists('/kaggle/working/reference_image_color') == False:
    os.mkdir('/kaggle/working/reference_image_color')
    
ref_image = io.read_image(ref_image_path)
ref_image = TF.to_pil_image(ref_image)
ref_image.save("/kaggle/working/reference_image_color/ref_img.jpeg")


ref_image = io.read_image("/kaggle/working/reference_image_color/ref_img.jpeg")
# Plot the image
plt.imshow(ref_image.permute(1, 2, 0))
plt.axis('off')
plt.show()

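**Rotate (if needed)** — run the next cell only if the color reference above is displayed sideways; it turns the image 90° clockwise.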


In [None]:
# torch.rot90 over the (H, W) dims: k=-1 turns the image 90° *clockwise*
# (k=1 would be counter-clockwise).
ROTATION_K = -1
color_ref_file = "/kaggle/working/reference_image_color/ref_img.jpeg"

# Rotate the original source (ref_image_path from the previous cell) rather than the saved
# JPEG, so re-running this cell does not rotate twice or re-compress the image again.
source_ref_image = io.read_image(ref_image_path, mode=io.ImageReadMode.RGB)
rotated_ref_image = TF.to_pil_image(torch.rot90(source_ref_image, k=ROTATION_K, dims=(1, 2)))
rotated_ref_image.save(color_ref_file)

ref_image = io.read_image(color_ref_file)

# Plot the image
plt.imshow(ref_image.permute(1, 2, 0))
plt.axis('off')
plt.show()

# Edit hair with HairClip_e4e

In [None]:
# Predictor_e4e loads pretrained_models/... relative to the HairCLIP repository root.
%cd /kaggle/working/HairCLIP

In [None]:
# Load the e4e encoder and the HairCLIP checkpoint once; predict() can then be called repeatedly.
Hair_clip_e4e = Predictor_e4e()
Hair_clip_e4e.setup()

In [None]:
%matplotlib inline

# input_type "image_text": the hairstyle comes from the reference-image folder and the
# color from the text prompt. With an image hairstyle source, the hairstyle text
# prompt is not used by the dataset.
outpath = Hair_clip_e4e.predict(
        image = input_image_path,
        editing_type = "both",
        input_type = "image_text",
        hairstyle_description = "dreadlocks hairstyle",
        color_description = "purple",
        hairstyle_ref_img_test_path = "/kaggle/working/reference_image_hairstyle",
        color_ref_img_test_path = "/kaggle/working/reference_image_color",
)

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torchvision.io as io

# Assuming 'image_tensor' is your image tensor
output_image = io.read_image(str(outpath))

# Plot the image
plt.imshow(output_image.permute(1, 2, 0))
plt.axis('off')
plt.show()

# LatentsDataset

In [2]:
# The dataset/training imports below (mapper.*, criteria.*) resolve from the HairCLIP root.
%cd /kaggle/working/HairCLIP

/kaggle/working/HairCLIP


In [4]:
from torch.utils.data import Dataset
import numpy as np
import clip
import torch
import random
import sys
from PIL import Image
import torchvision.transforms as transforms
sys.path.insert(0, "/kaggle/working/HairCLIP/mapper")
from mapper.training import train_utils
import os

In [47]:
class LatentsDataset(Dataset):
    """Training/test dataset that samples one random hair edit per latent code.

    Each item is ``(latent, hairstyle_text_embedding, color_text_embedding,
    description, hairstyle_tensor, color_tensor)``. Unused slots hold the
    placeholder ``torch.Tensor([0])``.

    The edit kind is drawn with probabilities ``hairstyle_manipulation_prob``,
    ``color_manipulation_prob`` and ``both_manipulation_prob``. The remaining
    probability mass selects "no editing".
    """

    def __init__(self, latents, opts, status='train'):
        self.latents = latents
        self.opts = opts
        self.status = status
        assert (self.opts.hairstyle_manipulation_prob + self.opts.color_manipulation_prob + self.opts.both_manipulation_prob) <= 1
        with open(self.opts.hairstyle_description, "r") as fd:
            self.hairstyle_description_list = fd.read().splitlines()

        # Drop the last 9 characters (expected to be "hairstyle"); it is re-appended when sampling.
        self.hairstyle_list = [description[:-9] for description in self.hairstyle_description_list]
        # "purple, red" -> ["purple ", "red "]; 'hair' is appended when sampling.
        self.color_list = [color.strip() + ' ' for color in self.opts.color_description.split(',')]
        self.image_transform = transforms.Compose([transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        if self.status == 'train':
            self.out_domain_hairstyle_img_path_list = sorted(train_utils.make_dataset(self.opts.hairstyle_ref_img_train_path))
            self.out_domain_color_img_path_list = sorted(train_utils.make_dataset(self.opts.color_ref_img_train_path))
        else:
            self.out_domain_hairstyle_img_path_list = sorted(train_utils.make_dataset(self.opts.hairstyle_ref_img_test_path))
            self.out_domain_color_img_path_list = sorted(train_utils.make_dataset(self.opts.color_ref_img_test_path))

    def _load_image(self, path):
        """Open ``path`` as RGB (grayscale/RGBA would break the 3-channel Normalize) and transform it."""
        return self.image_transform(Image.open(path).convert('RGB'))

    def manipulate_hairstyle(self, index):
        """Hairstyle-only edit, driven by a random text prompt or a random reference image."""
        color_text_embedding = torch.Tensor([0])
        color_tensor = torch.Tensor([0])
        if random.random() < self.opts.hairstyle_text_manipulation_prob:
            selected_hairstyle_description = np.random.choice(self.hairstyle_list) + 'hairstyle'
            selected_description = selected_hairstyle_description
            hairstyle_text_embedding = clip.tokenize(selected_hairstyle_description)[0]
            hairstyle_tensor = torch.Tensor([0])
        else:
            hairstyle_text_embedding = torch.Tensor([0])
            hairstyle_tensor = self._load_image(random.choice(self.out_domain_hairstyle_img_path_list))
            selected_description = 'hairstyle_out_domain_ref'
        return self.latents[index], hairstyle_text_embedding, color_text_embedding, selected_description, hairstyle_tensor, color_tensor

    def manipulater_color(self, index):
        """Color-only edit: text prompt, in-domain color reference, or out-of-domain reference."""
        hairstyle_text_embedding = torch.Tensor([0])
        hairstyle_tensor = torch.Tensor([0])
        selected_color_description = np.random.choice(self.color_list) + 'hair'
        if random.random() < self.opts.color_text_manipulation_prob:
            selected_description = selected_color_description
            color_text_embedding = clip.tokenize(selected_color_description)[0]
            color_tensor = torch.Tensor([0])
        else:
            color_text_embedding = torch.Tensor([0])
            # Conditional probability of the in-domain branch given that text was not chosen.
            # Reaching this branch implies color_text_manipulation_prob < 1, so no division by zero.
            if random.random() < (self.opts.color_in_domain_ref_manipulation_prob / (1 - self.opts.color_text_manipulation_prob)):
                selected_description = 'color_in_domain_ref'
                image_index = random.randint(0, self.opts.num_for_each_augmented_color - 1)
                # os.path.join also works when the base path lacks a trailing '/'.
                in_domain_path = os.path.join(self.opts.color_ref_img_in_domain_path, selected_color_description, str(image_index).zfill(5) + '.jpg')
                color_tensor = self._load_image(in_domain_path)
            else:
                selected_description = 'color_out_domain_ref'
                color_tensor = self._load_image(random.choice(self.out_domain_color_img_path_list))
        return self.latents[index], hairstyle_text_embedding, color_text_embedding, selected_description, hairstyle_tensor, color_tensor

    def manipulater_hairstyle_and_color(self, index):
        """Combined edit: independent hairstyle and color samples joined as 'hair-color'."""
        returned_latent, hairstyle_text_embedding, _, selected_hairstyle_description, hairstyle_tensor, _ = self.manipulate_hairstyle(index)
        _, _, color_text_embedding, selected_color_description, _, color_tensor = self.manipulater_color(index)
        selected_description = f'{selected_hairstyle_description}-{selected_color_description}'
        return returned_latent, hairstyle_text_embedding, color_text_embedding, selected_description, hairstyle_tensor, color_tensor

    def no_editing(self, index):
        """Return the latent unchanged, with placeholders in every edit slot."""
        return self.latents[index], torch.Tensor([0]), torch.Tensor([0]), 'no_editing', torch.Tensor([0]), torch.Tensor([0])

    def __len__(self):
        return self.latents.shape[0]

    def __getitem__(self, index):
        """Sample an edit kind with the configured probabilities and apply it to ``index``."""
        editors = (self.manipulate_hairstyle, self.manipulater_color,
                   self.manipulater_hairstyle_and_color, self.no_editing)
        edit_probs = np.array([
            self.opts.hairstyle_manipulation_prob,
            self.opts.color_manipulation_prob,
            self.opts.both_manipulation_prob,
            1 - self.opts.hairstyle_manipulation_prob - self.opts.color_manipulation_prob - self.opts.both_manipulation_prob,
        ])
        # Pick an index instead of eval()-ing method-name strings. The global NumPy RNG
        # is consumed exactly as the original np.random.choice call did.
        choice = np.random.choice(len(editors), replace=False, p=edit_probs)
        return editors[choice](index)

# image_embedding_loss.py

In [None]:
import torch
import clip
import torchvision.transforms as transforms

class ImageEmbddingLoss(torch.nn.Module):
    """CLIP image-embedding loss: mean ``1 - cos`` between CLIP features of two image batches.

    Inputs appear to be in [-1, 1]: they are mapped with ``x * 0.5 + 0.5`` before
    CLIP's mean/std normalization. They are pooled to 224x224 first. Returns a
    tensor of shape (1, 1).
    """

    def __init__(self):
        super(ImageEmbddingLoss, self).__init__()
        self.model, _ = clip.load("ViT-B/32", device="cuda")
        # CLIP's published RGB mean/std.
        self.transform = transforms.Compose([transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        self.cosloss = torch.nn.CosineEmbeddingLoss()

    def _encode(self, images):
        """Pool to CLIP resolution, renormalize from [-1, 1], and embed with CLIP."""
        pooled = self.face_pool(images)
        renormed = self.transform(pooled * 0.5 + 0.5)
        return self.model.encode_image(renormed)

    def forward(self, masked_generated, masked_img_tensor):
        masked_generated_feature = self._encode(masked_generated)
        masked_img_tensor_feature = self._encode(masked_img_tensor)
        # Target 1 = "should be similar". Build it on the features' device rather than
        # hard-coding .cuda(), so the loss also runs on CPU.
        cos_target = torch.ones(masked_img_tensor.shape[0], device=masked_img_tensor_feature.device)
        similarity = self.cosloss(masked_generated_feature, masked_img_tensor_feature, cos_target).unsqueeze(0).unsqueeze(0)
        return similarity


# coach.py

In [48]:
# criteria.* and mapper.* imports for the training coach resolve from the HairCLIP root.
%cd /kaggle/working/HairCLIP

/kaggle/working/HairCLIP


In [49]:
import os
import sys
import clip
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from criteria.parse_related_loss import bg_loss, average_lab_color_loss
import criteria.clip_loss as clip_loss
# import criteria.image_embedding_loss as image_embedding_loss
from criteria import id_loss

# sys.path.insert(0, "mapper")
# from mapper.datasets.latents_dataset import LatentsDataset
from mapper.hairclip_mapper import HairCLIPMapper
from mapper.training.ranger import Ranger
from mapper.training import train_utils

In [50]:
class Coach:
    """Training driver for the HairCLIP mapper.

    Builds the HairCLIPMapper network, its loss modules, the optimizer and the
    train/test dataloaders from ``opts``, then runs the optimisation loop with
    periodic image dumps, metric logging, validation and checkpointing.

    NOTE(review): the notebook rejects batch sizes other than 1, and some paths
    below (e.g. the (1, 1) placeholder fed to the mapper when a sample has no
    reference image) appear to rely on that.
    """

    def __init__(self, opts):
        """Create network, losses, optimizer, dataloaders, logger and checkpoint dir.

        Args:
            opts: argparse-style namespace (see ``TrainOptions``). As side effects,
                ``opts.device`` is set, and ``opts.save_interval`` is replaced by
                ``opts.max_steps`` when it is None.
        """
        self.opts = opts
        self.global_step = 0
        self.device = 'cuda:0'
        self.opts.device = self.device

        # Initialize network
        self.net = HairCLIPMapper(self.opts).to(self.device)

        # Initialize losses
        self.id_loss = id_loss.IDLoss(self.opts).to(self.device).eval()
        self.clip_loss = clip_loss.CLIPLoss(opts)
        self.latent_l2_loss = nn.MSELoss().to(self.device).eval()
        self.background_loss = bg_loss.BackgroundLoss(self.opts).to(self.device).eval()
        # ImageEmbddingLoss is defined in an earlier notebook cell.
        self.image_embedding_loss = ImageEmbddingLoss()
        self.average_color_loss = average_lab_color_loss.AvgLabLoss(self.opts).to(self.device).eval()
        self.maintain_color_for_hairstyle_loss = average_lab_color_loss.AvgLabLoss(self.opts).to(self.device).eval()

        # Initialize optimizer
        self.optimizer = self.configure_optimizers()

        # Initialize dataset
        self.train_dataset, self.test_dataset = self.configure_datasets()
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.opts.batch_size,
                                           shuffle=True,
                                           num_workers=int(self.opts.workers),
                                           drop_last=True)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.opts.test_batch_size,
                                          shuffle=False,
                                          num_workers=int(self.opts.test_workers),
                                          drop_last=True)

        # Initialize logger
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.log_dir = log_dir
        self.logger = SummaryWriter(log_dir=log_dir)

        # Initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            self.opts.save_interval = self.opts.max_steps

    def _mask_reference(self, ref_tensor):
        """Return the hair-masked reference image, or a (1, 1) zero placeholder.

        A reference tensor whose channel dimension is 1 marks a sample without a
        reference image; the mapper then receives a zero placeholder instead.
        """
        if ref_tensor.shape[1] == 1:
            return torch.zeros((1, 1), device=self.device)
        return ref_tensor * self.average_color_loss.gen_hair_mask(ref_tensor)

    @staticmethod
    def _reference_images_for_logging(hairstyle_tensor, color_tensor):
        """Concatenate (along width) the reference images that are present, or return None."""
        present = [ref for ref in (hairstyle_tensor, color_tensor) if ref.shape[1] != 1]
        if not present:
            return None
        return torch.cat(present, dim=3)

    def _sample_latents(self, dataset_size, batch_size):
        """Sample ``dataset_size`` z vectors and map them through the decoder.

        Returns the latents produced by the decoder with ``return_latents=True``
        (truncation 0.7 towards ``net.latent_avg``), concatenated along dim 0.
        Consecutive, non-overlapping chunks of ``batch_size`` z vectors are used,
        so every sampled z contributes exactly once.
        """
        z = torch.randn(dataset_size, 512, device=self.device)
        chunks = []
        with torch.no_grad():
            for start in range(0, dataset_size, batch_size):
                _, latents = self.net.decoder([z[start:start + batch_size]],
                                              truncation=0.7,
                                              truncation_latent=self.net.latent_avg,
                                              return_latents=True)
                chunks.append(latents)
        return torch.cat(chunks)

    def train(self):
        """Optimise the mapper until ``opts.max_steps`` steps have been taken.

        Each step edits ``w`` with the mapper (as a scaled residual), decodes the
        result, computes the weighted loss and updates the mapper parameters.
        Images, metrics, validation and checkpoints are emitted at their
        configured intervals.
        """
        self.net.train()
        while self.global_step < self.opts.max_steps:
            for batch in self.train_dataloader:
                self.optimizer.zero_grad()
                w, hairstyle_text_inputs, color_text_inputs, selected_description_tuple, hairstyle_tensor, color_tensor = batch
                selected_description = ''.join(selected_description_tuple)

                w = w.to(self.device)
                hairstyle_text_inputs = hairstyle_text_inputs.to(self.device)
                color_text_inputs = color_text_inputs.to(self.device)
                hairstyle_tensor = hairstyle_tensor.to(self.device)
                color_tensor = color_tensor.to(self.device)

                # The source image only serves as a loss target; no gradient needed.
                with torch.no_grad():
                    x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1)
                hairstyle_tensor_hairmasked = self._mask_reference(hairstyle_tensor)
                color_tensor_hairmasked = self._mask_reference(color_tensor)
                # The mapper output is applied as a residual on w, scaled by 0.1.
                w_hat = w + 0.1 * self.net.mapper(w, hairstyle_text_inputs, color_text_inputs, hairstyle_tensor_hairmasked, color_tensor_hairmasked)
                x_hat, w_hat = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1)
                loss, loss_dict = self.calc_loss(w, x, w_hat, x_hat, hairstyle_text_inputs, color_text_inputs, hairstyle_tensor, color_tensor, selected_description)
                loss.backward()
                self.optimizer.step()

                # Logging related: every image_interval steps, plus every 100 steps
                # during the first 1000 (a `% 1000` test here would only match step 0).
                if self.global_step % self.opts.image_interval == 0 or (
                        self.global_step < 1000 and self.global_step % 100 == 0):
                    img_tensor = self._reference_images_for_logging(hairstyle_tensor, color_tensor)
                    self.parse_and_log_images(x, x_hat, img_tensor, title='images_train', selected_description=selected_description)
                if self.global_step % self.opts.board_interval == 0:
                    self.print_metrics(loss_dict, prefix='train', selected_description=selected_description)
                    self.log_metrics(loss_dict, prefix='train')

                # Validation related
                val_loss_dict = None
                if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
                    val_loss_dict = self.validate()
                    if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
                        self.best_val_loss = val_loss_dict['loss']
                        self.checkpoint_me(val_loss_dict, is_best=True)

                if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
                    # Prefer validation metrics in the checkpoint log when available.
                    self.checkpoint_me(val_loss_dict if val_loss_dict is not None else loss_dict, is_best=False)

                if self.global_step == self.opts.max_steps:
                    print('OMG, finished training!', flush=True)
                    break

                self.global_step += 1

    def validate(self):
        """Evaluate the mapper on the test loader (at most 201 batches).

        Returns:
            The aggregated loss dict, or None for the step-0 sanity pass (which
            stops after 5 batches) and when the test loader yields no batches.
        """
        self.net.eval()
        agg_loss_dict = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            if batch_idx > 200:
                break

            w, hairstyle_text_inputs, color_text_inputs, selected_description_tuple, hairstyle_tensor, color_tensor = batch
            selected_description = ''.join(selected_description_tuple)

            with torch.no_grad():
                w = w.to(self.device).float()
                hairstyle_text_inputs = hairstyle_text_inputs.to(self.device)
                color_text_inputs = color_text_inputs.to(self.device)
                hairstyle_tensor = hairstyle_tensor.to(self.device)
                color_tensor = color_tensor.to(self.device)
                x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=True, truncation=1)
                hairstyle_tensor_hairmasked = self._mask_reference(hairstyle_tensor)
                color_tensor_hairmasked = self._mask_reference(color_tensor)
                w_hat = w + 0.1 * self.net.mapper(w, hairstyle_text_inputs, color_text_inputs, hairstyle_tensor_hairmasked, color_tensor_hairmasked)
                x_hat, _ = self.net.decoder([w_hat], input_is_latent=True, randomize_noise=True, truncation=1)
                _, cur_loss_dict = self.calc_loss(w, x, w_hat, x_hat, hairstyle_text_inputs, color_text_inputs, hairstyle_tensor, color_tensor, selected_description)
            agg_loss_dict.append(cur_loss_dict)

            # Logging related
            img_tensor = self._reference_images_for_logging(hairstyle_tensor, color_tensor)
            self.parse_and_log_images(x, x_hat, img_tensor, title='images_val', selected_description=selected_description, index=batch_idx)

            # For first step just do sanity test on small amount of data
            if self.global_step == 0 and batch_idx >= 4:
                self.net.train()
                return None  # Do not log, inaccurate in first batch

        if not agg_loss_dict:
            # Empty test loader: nothing to aggregate or report.
            self.net.train()
            return None

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
        self.log_metrics(loss_dict, prefix='test')
        self.print_metrics(loss_dict, prefix='test', selected_description=selected_description)

        self.net.train()
        return loss_dict

    def checkpoint_me(self, loss_dict, is_best):
        """Save the network state and append a line to ``timestamp.txt``.

        Writes ``best_model.pt`` when ``is_best`` else ``latest_model.pt`` into
        ``self.checkpoint_dir``.
        """
        save_name = 'best_model.pt' if is_best else 'latest_model.pt'
        save_dict = self.__get_save_dict()
        checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
        torch.save(save_dict, checkpoint_path)
        with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
            if is_best:
                f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
            else:
                f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))

    def configure_optimizers(self):
        """Build the optimizer over the mapper parameters only (Adam or Ranger)."""
        params = list(self.net.mapper.parameters())
        if self.opts.optim_name == 'adam':
            optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
        else:
            optimizer = Ranger(params, lr=self.opts.learning_rate)
        return optimizer

    def configure_datasets(self):
        """Load (or sample) train/test latents and wrap them in ``LatentsDataset``.

        When a latents path is empty, latents are generated from random z with
        ``_sample_latents`` using the matching dataset size and batch size.
        """
        if self.opts.latents_train_path:
            train_latents = torch.load(self.opts.latents_train_path)
        else:
            train_latents = self._sample_latents(self.opts.train_dataset_size, self.opts.batch_size)

        if self.opts.latents_test_path:
            test_latents = torch.load(self.opts.latents_test_path)
        else:
            test_latents = self._sample_latents(self.opts.test_dataset_size, self.opts.test_batch_size)

        train_dataset = LatentsDataset(latents=train_latents.cpu(),
                                       opts=self.opts,
                                       status='train')
        test_dataset = LatentsDataset(latents=test_latents.cpu(),
                                      opts=self.opts,
                                      status='test')
        print("Number of training samples: {}".format(len(train_dataset)), flush=True)
        print("Number of test samples: {}".format(len(test_dataset)), flush=True)
        return train_dataset, test_dataset

    def calc_loss(self, w, x, w_hat, x_hat, hairstyle_text_inputs, color_text_inputs, hairstyle_tensor, color_tensor, selected_description):
        """Compute the weighted training loss.

        A tensor whose dim-1 size is 1 marks an absent text/image condition, and
        the corresponding term is skipped. Terms are enabled by their ``*_lambda``
        option being > 0 and are scaled by the group multipliers
        (text / image manipulation, attribute preservation).

        Returns:
            (loss, loss_dict): the total loss tensor (or 0.0 if no term is active)
            and a dict of per-term floats plus ``'loss'``.
        """
        loss_dict = {}
        loss = 0.0
        if self.opts.id_lambda > 0:
            loss_id, sim_improvement = self.id_loss(x_hat, x)
            loss_dict['loss_id'] = float(loss_id)
            loss_dict['id_improve'] = float(sim_improvement)
            loss += loss_id * self.opts.id_lambda * self.opts.attribute_preservation_lambda

        if self.opts.text_manipulation_lambda > 0:
            if hairstyle_text_inputs.shape[1] != 1:
                loss_text_hairstyle = self.clip_loss(x_hat, hairstyle_text_inputs).mean()
                loss_dict['loss_text_hairstyle'] = float(loss_text_hairstyle)
                loss += loss_text_hairstyle * self.opts.text_manipulation_lambda
            if color_text_inputs.shape[1] != 1:
                loss_text_color = self.clip_loss(x_hat, color_text_inputs).mean()
                loss_dict['loss_text_color'] = float(loss_text_color)
                loss += loss_text_color * self.opts.text_manipulation_lambda

        if self.opts.image_hairstyle_lambda > 0:
            if hairstyle_tensor.shape[1] != 1:
                if 'hairstyle_out_domain_ref' in selected_description:
                    loss_img_hairstyle = self.image_embedding_loss((x_hat * self.average_color_loss.gen_hair_mask(x_hat)), (hairstyle_tensor * self.average_color_loss.gen_hair_mask(hairstyle_tensor))).mean()
                    loss_dict['loss_img_hairstyle'] = float(loss_img_hairstyle)
                    loss += loss_img_hairstyle * self.opts.image_hairstyle_lambda * self.opts.image_manipulation_lambda

        if self.opts.image_color_lambda > 0:
            if color_tensor.shape[1] != 1:
                loss_img_color = self.average_color_loss(color_tensor, x_hat)
                loss_dict['loss_img_color'] = float(loss_img_color)
                loss += loss_img_color * self.opts.image_color_lambda * self.opts.image_manipulation_lambda

        if self.opts.maintain_color_lambda > 0:
            # Only when editing hairstyle without any color condition: keep the original hair color.
            if ((hairstyle_tensor.shape[1] != 1) or (hairstyle_text_inputs.shape[1] != 1)) and (color_tensor.shape[1] == 1) and (color_text_inputs.shape[1] == 1):
                loss_maintain_color_for_hairstyle = self.maintain_color_for_hairstyle_loss(x, x_hat)
                loss_dict['loss_maintain_color_for_hairstyle'] = float(loss_maintain_color_for_hairstyle)
                loss += loss_maintain_color_for_hairstyle * self.opts.maintain_color_lambda * self.opts.attribute_preservation_lambda
        if self.opts.background_lambda > 0:
            loss_background = self.background_loss(x, x_hat)
            loss_dict['loss_background'] = float(loss_background)
            loss += loss_background * self.opts.background_lambda * self.opts.attribute_preservation_lambda
        if self.opts.latent_l2_lambda > 0:
            loss_l2_latent = self.latent_l2_loss(w_hat, w)
            loss_dict['loss_l2_latent'] = float(loss_l2_latent)
            loss += loss_l2_latent * self.opts.latent_l2_lambda * self.opts.attribute_preservation_lambda
        loss_dict['loss'] = float(loss)
        return loss, loss_dict

    def log_metrics(self, metrics_dict, prefix):
        """Write every metric to TensorBoard under ``'<prefix>/<key>'`` at the current step."""
        for key, value in metrics_dict.items():
            self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step)

    def print_metrics(self, metrics_dict, prefix, selected_description):
        """Print metrics to stdout; the train header also shows the edit description."""
        if prefix == 'train':
            print('Metrics for {}, step {}'.format(prefix, self.global_step), selected_description, flush=True)
        else:
            print('Metrics for {}, step {}'.format(prefix, self.global_step), flush=True)
        for key, value in metrics_dict.items():
            print('\t{} = '.format(key), value, flush=True)

    def parse_and_log_images(self, x, x_hat, img_tensor, title, selected_description, index=None):
        """Save a grid of source, edited and (optional) reference images as a JPEG.

        ``img_tensor`` may hold one or more reference images concatenated along
        width; it is split back into panels of the generator's output width
        (``x.shape[3]``). If its width is not a multiple of that, only the
        source/edited pair is saved.
        """
        step_tag = str(self.global_step).zfill(5)
        if index is None:
            file_name = f'{step_tag}-{selected_description}.jpg'
        else:
            file_name = f'{step_tag}-{str(index).zfill(5)}-{selected_description}.jpg'
        path = os.path.join(self.log_dir, title, file_name)
        os.makedirs(os.path.dirname(path), exist_ok=True)

        panels = [x.detach().cpu(), x_hat.detach().cpu()]
        if img_tensor is not None:
            panel_width = x.shape[3]
            if img_tensor.shape[3] % panel_width == 0:
                panels.extend(chunk.detach().cpu() for chunk in torch.split(img_tensor, panel_width, dim=3))
        torchvision.utils.save_image(torch.cat(panels), path,
                                     normalize=True, scale_each=True, value_range=(-1, 1), nrow=len(panels))

    def __get_save_dict(self):
        """Checkpoint payload: full network state dict plus the options as a plain dict."""
        save_dict = {
            'state_dict': self.net.state_dict(),
            'opts': vars(self.opts)
        }
        return save_dict

# Train.py

In [51]:
import os
import json
import sys
import pprint

# sys.path.insert(0, "/kaggle/working/HairCLIP/mapper")
# from mapper.options.train_options import TrainOptions

In [52]:
from argparse import ArgumentParser


class TrainOptions:
    """Argparse-based training options for the HairCLIP mapper.

    Defaults point at Kaggle paths (``/kaggle/working`` and ``/kaggle/input``)
    so that ``TrainOptions().parse()`` works inside the notebook without
    command-line arguments.
    """

    def __init__(self):
        self.parser = ArgumentParser()
        self.initialize()

    def initialize(self):
        """Register every option on ``self.parser``."""
        # Paths and mapper sub-module switches
        self.parser.add_argument('--exp_dir', default="/kaggle/working/experiment", type=str, help='Path to experiment output directory')
        self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
        self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
        self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
        self.parser.add_argument('--latents_train_path', default="/kaggle/working/HairCLIP/mapper/datasets/train_faces.pt", type=str, help="The latents for the training")
        self.parser.add_argument('--latents_test_path', default="/kaggle/working/HairCLIP/mapper/datasets/test_faces.pt", type=str, help="The latents for the validation")

        # Reference images
        self.parser.add_argument('--hairstyle_ref_img_train_path', default="/kaggle/input/celebahq-resized-256x256/celeba_hq_256", type=str, help="The hairstyle reference image for the training")
        self.parser.add_argument('--hairstyle_ref_img_test_path', default="/kaggle/input/celebahq-resized-256x256/celeba_hq_256", type=str, help="The hairstyle reference image for the validation")
        self.parser.add_argument('--color_ref_img_train_path', default="/kaggle/input/celebahq-resized-256x256/celeba_hq_256", type=str, help="The color reference image for the training")
        self.parser.add_argument('--color_ref_img_test_path', default="/kaggle/input/celebahq-resized-256x256/celeba_hq_256", type=str, help="The color reference image for the validation")
        self.parser.add_argument('--color_ref_img_in_domain_path', default="", type=str, help="The color reference image in domain for the augmentation")
        self.parser.add_argument('--num_for_each_augmented_color', default=4000, type=int, help='Number for each augmented color')

        # Sampling probabilities for the kind of edit
        self.parser.add_argument('--hairstyle_manipulation_prob', default=0.5, type=float, help='Probability of only manipulating the hairstyle')
        self.parser.add_argument('--color_manipulation_prob', default=0.2, type=float, help='Probability of only manipulating the color')
        self.parser.add_argument('--both_manipulation_prob', default=0.27, type=float, help='Probability of simultaneously manipulating hairstyle and color')

        self.parser.add_argument('--hairstyle_text_manipulation_prob', default=0.5, type=float, help='Probability of using text to manipulate hairstyle')
        self.parser.add_argument('--color_text_manipulation_prob', default=0.5, type=float, help='Probability of using text to manipulate color')
        self.parser.add_argument('--color_in_domain_ref_manipulation_prob', default=0, type=float, help='Probability of using in-domain reference image to manipulate color')

        # Dataset sizes (only used when latents are sampled instead of loaded)
        self.parser.add_argument('--train_dataset_size', default=5000, type=int, help="Will be used only if no latents are given")
        self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given")

        # Dataloading
        self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training')
        self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference')
        self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
        self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')

        # Optimizer
        self.parser.add_argument('--learning_rate', default=0.0005, type=float, help='Optimizer learning rate')
        self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')

        # Loss group multipliers
        self.parser.add_argument('--text_manipulation_lambda', default=2.0, type=float, help='Text manipulation loss multiplier factor')
        self.parser.add_argument('--image_manipulation_lambda', default=1.0, type=float, help='Image manipulation loss multiplier factor')
        self.parser.add_argument('--attribute_preservation_lambda', default=1.0, type=float, help='Attribute preservation loss multiplier factor')

        # Individual loss weights
        self.parser.add_argument('--image_hairstyle_lambda', default=5.0, type=float, help='Image-based hairstyle manipulation loss multiplier factor')
        self.parser.add_argument('--image_color_lambda', default=0.02, type=float, help='Image-based color manipulation loss multiplier factor')

        self.parser.add_argument('--id_lambda', default=0.3, type=float, help='ID loss multiplier factor')
        self.parser.add_argument('--maintain_color_lambda', default=0.02, type=float, help='Color retention loss multiplier factor')
        self.parser.add_argument('--background_lambda', default=1.0, type=float, help='Background loss multiplier factor')
        self.parser.add_argument('--latent_l2_lambda', default=0.8, type=float, help='Latent L2 loss multiplier factor')

        # Pretrained weights
        self.parser.add_argument('--parsenet_weights', default='/kaggle/working/HairCLIP/pretrained_models/parsenet.pth', type=str, help='Path to Parsing model weights')
        self.parser.add_argument('--stylegan_weights', default='/kaggle/working/HairCLIP/pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights')
        self.parser.add_argument('--stylegan_size', default=1024, type=int)
        self.parser.add_argument('--ir_se50_weights', default='/kaggle/working/HairCLIP/pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss")
        self.parser.add_argument('--checkpoint_path', default="/kaggle/working/HairCLIP/pretrained_models/hairclip.pt", type=str, help='Path to HairCLIP model checkpoint')

        # Schedule
        self.parser.add_argument('--max_steps', default=1000, type=int, help='Maximum number of training steps')
        self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training')
        self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard')
        self.parser.add_argument('--val_interval', default=2000, type=int, help='Validation interval')
        self.parser.add_argument('--save_interval', default=2000, type=int, help='Model checkpoint interval')

        # Text prompts
        self.parser.add_argument('--hairstyle_description', default="/kaggle/working/HairCLIP/mapper/hairstyle_list.txt", type=str, help='Hairstyle text prompt list')
        self.parser.add_argument('--color_description', default = "purple, red, orange, yellow, green, blue, gray, brown, black, white, blond, pink", type=str, help='Color text prompt, eg: purple, red, orange')

    def parse(self, args=None):
        """Parse options.

        Args:
            args: optional list of command-line tokens, e.g.
                ``['--max_steps', '500']``. When None, an empty list is parsed so
                the Jupyter kernel's own ``sys.argv`` is never read and all
                defaults are used.

        Returns:
            argparse.Namespace with one attribute per option.
        """
        opts = self.parser.parse_args([] if args is None else args)
        return opts

In [53]:
opts = TrainOptions().parse()
# Guard: the training code in this notebook only handles single-sample batches.
if (opts.batch_size, opts.test_batch_size) != (1, 1):
    raise Exception('This version only supports batch size and test batch size to be 1.')

In [54]:
import shutil


def remove_folder(folder_path):
    """Recursively delete ``folder_path`` and everything inside it."""
    shutil.rmtree(folder_path)


# Start each run from a clean experiment directory so the "already exists"
# guard in the next cell does not trip on a previous run's output.
# NOTE: this permanently deletes earlier logs and checkpoints under opts.exp_dir.
if os.path.exists(opts.exp_dir):
    remove_folder(opts.exp_dir)

In [55]:
# Refuse to overwrite an existing experiment, then create its directory.
if os.path.exists(opts.exp_dir):
    raise Exception('Oops... {} already exists'.format(opts.exp_dir))
os.makedirs(opts.exp_dir, exist_ok=True)

# Show the options and persist them next to the run for reproducibility.
opts_dict = vars(opts)
pprint.pprint(opts_dict)
options_file = os.path.join(opts.exp_dir, 'opt.json')
with open(options_file, 'w') as fp:
    json.dump(opts_dict, fp, indent=4, sort_keys=True)

coach = Coach(opts)

{'attribute_preservation_lambda': 1.0,
 'background_lambda': 1.0,
 'batch_size': 1,
 'board_interval': 50,
 'both_manipulation_prob': 0.27,
 'checkpoint_path': '/kaggle/working/HairCLIP/pretrained_models/hairclip.pt',
 'color_description': 'purple, red, orange, yellow, green, blue, gray, brown, '
                      'black, white, blond, pink',
 'color_in_domain_ref_manipulation_prob': 0,
 'color_manipulation_prob': 0.2,
 'color_ref_img_in_domain_path': '',
 'color_ref_img_test_path': '/kaggle/input/celebahq-resized-256x256/celeba_hq_256',
 'color_ref_img_train_path': '/kaggle/input/celebahq-resized-256x256/celeba_hq_256',
 'color_text_manipulation_prob': 0.5,
 'exp_dir': '/kaggle/working/experiment',
 'hairstyle_description': '/kaggle/working/HairCLIP/mapper/hairstyle_list.txt',
 'hairstyle_manipulation_prob': 0.5,
 'hairstyle_ref_img_test_path': '/kaggle/input/celebahq-resized-256x256/celeba_hq_256',
 'hairstyle_ref_img_train_path': '/kaggle/input/celebahq-resized-256x256/celeba_hq

### Check Dataloader shape

In [56]:
# Pull one training batch and show the shape of each component.
(w, hairstyle_text_inputs, color_text_inputs,
 selected_description_tuple, hairstyle_tensor, color_tensor) = next(iter(coach.train_dataloader))
for component in (w, hairstyle_text_inputs, color_text_inputs):
    print(component.shape)
print(selected_description_tuple)
for component in (hairstyle_tensor, color_tensor):
    print(component.shape)

torch.Size([1, 18, 512])
torch.Size([1, 1])
torch.Size([1, 1])
('hairstyle_out_domain_ref-color_out_domain_ref',)
torch.Size([1, 3, 1024, 1024])
torch.Size([1, 3, 1024, 1024])


In [57]:
# Pull one test batch and show the shape of each component.
(w, hairstyle_text_inputs, color_text_inputs,
 selected_description_tuple, hairstyle_tensor, color_tensor) = next(iter(coach.test_dataloader))
for component in (w, hairstyle_text_inputs, color_text_inputs):
    print(component.shape)
print(selected_description_tuple)
for component in (hairstyle_tensor, color_tensor):
    print(component.shape)

torch.Size([1, 18, 512])
torch.Size([1, 77])
torch.Size([1, 1])
('pageboy hairstyle',)
torch.Size([1, 1])
torch.Size([1, 1])


### Check propagation

In [64]:
def _hair_masked_or_placeholder(ref):
    """Hair-mask a reference batch, or return a (1, 1) zero placeholder when the channel dim is 1."""
    if ref.shape[1] == 1:
        return torch.Tensor([0]).unsqueeze(0).cuda()
    return ref * coach.average_color_loss.gen_hair_mask(ref)


# Run a single test batch through decoder -> mapper -> decoder to check tensor shapes.
with torch.no_grad():
    for batch_idx, batch in enumerate(coach.test_dataloader):
        (w, hairstyle_text_inputs, color_text_inputs,
         selected_description_tuple, hairstyle_tensor, color_tensor) = batch
        selected_description = ''.join(selected_description_tuple)

        w, hairstyle_text_inputs, color_text_inputs, hairstyle_tensor, color_tensor = (
            t.to(coach.device)
            for t in (w, hairstyle_text_inputs, color_text_inputs, hairstyle_tensor, color_tensor)
        )

        x, _ = coach.net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1)
        hairstyle_tensor_hairmasked = _hair_masked_or_placeholder(hairstyle_tensor)
        color_tensor_hairmasked = _hair_masked_or_placeholder(color_tensor)
        latent_offset = coach.net.mapper(w, hairstyle_text_inputs, color_text_inputs,
                                         hairstyle_tensor_hairmasked, color_tensor_hairmasked)
        w_hat = w + 0.1 * latent_offset
        x_hat, w_hat = coach.net.decoder([w_hat], input_is_latent=True, return_latents=True,
                                         randomize_noise=False, truncation=1)

        break  # one batch is enough for a shape sanity check

In [63]:
# Shapes of the source/edited latents and images from the propagation check cell
# (run that cell first; the recorded execution counts show this one ran out of order).
for tensor in (w, w_hat, x, x_hat):
    print(tensor.shape)

torch.Size([1, 18, 512])
torch.Size([1, 18, 512])
torch.Size([1, 3, 1024, 1024])
torch.Size([1, 3, 1024, 1024])


### Train

In [None]:
# Run the optimisation loop until opts.max_steps steps have been taken.
coach.train()