# Set up FS_Encoder (FeatureStyleEncoder)

In [None]:
import shutil
import os

# Define a function to remove a directory and its contents recursively
def remove_folder(folder_path):
    """Delete ``folder_path`` together with everything beneath it.

    Thin wrapper around :func:`shutil.rmtree`; any error (for example a path
    that does not exist) propagates to the caller.
    """
    return shutil.rmtree(folder_path)

# Start from a clean checkout: drop a copy of the repository left behind by a
# previous run before it is cloned again below.
if os.path.exists('/kaggle/working/FeatureStyleEncoder'):
    remove_folder('/kaggle/working/FeatureStyleEncoder')
    
# Clone a fresh copy of the FeatureStyleEncoder repository into /kaggle/working.
%cd /kaggle/working/
!git clone https://github.com/InterDigitalInc/FeatureStyleEncoder.git

In [None]:
# Extra package installed for the repo's code — presumably a dependency of its
# modules; TODO confirm which module needs it.
!pip install face_alignment

In [None]:
# Download the pretrained weights into FeatureStyleEncoder/pretrained_models
# (the directory the ./pretrained_models/... argparse defaults point to).
%cd FeatureStyleEncoder

!pip install gdown
# NOTE(review): %mkdir reports an error (without stopping the cell) if the
# directory already exists.
%mkdir pretrained_models
%cd pretrained_models

# download pretrained encoder (presumably 143_enc.pth, the
# --pretrained_model_path default — confirm)
!gdown --fuzzy https://drive.google.com/file/d/1RnnBL77j_Can0dY1KOiXHvG224MxjvzC/view?usp=sharing

# download arcface pretrained model (presumably backbone.pth — confirm)
!gdown --fuzzy https://drive.google.com/file/d/1coFTz-Kkgvoc_gRT8JFzqCgeC3lAFWQp/view?usp=sharing

# download face parsing model from https://github.com/zllrunning/face-parsing.PyTorch
# (presumably 79999_iter.pth — confirm)
!gdown --fuzzy https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
    
# download pSp pretrained model from https://github.com/eladrich/pixel2style2pixel.git
# into FeatureStyleEncoder/pixel2style2pixel/pretrained_models (the location
# the --stylegan_model_path default expects).
%cd ../pixel2style2pixel
!mkdir pretrained_models

%cd pretrained_models
!gdown --fuzzy https://drive.google.com/file/d/1bMTNWkh5LArlaWSc_wa8VKyq2V42T2z0/view?usp=sharing

In [None]:
# NOTE(review): nothing visible in this notebook imports pydrive; this install
# may be unnecessary.
!pip install PyDrive

In [None]:
# Return to the repository root.
%cd /kaggle/working/FeatureStyleEncoder

In [None]:
# Run from the repository root so the repo-local modules below (utils, trainer)
# can be imported and the relative ./configs and ./pretrained_models paths used
# later resolve correctly.
%cd /kaggle/working/FeatureStyleEncoder

import matplotlib.pyplot as plt
%matplotlib inline

import torch
import argparse
import glob
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import yaml

from PIL import Image
from tqdm import tqdm
from torchvision import transforms, utils
# Star imports from the repository; presumably these provide Trainer, which is
# used below but not defined in this notebook — TODO confirm.
from utils.functions import *

from trainer import *

In [None]:
# cuDNN / autograd settings.
# NOTE(review): deterministic=True restricts cuDNN to deterministic algorithms,
# while benchmark=True lets it time candidates and pick the fastest one per
# input shape, so the selected algorithm can still vary between runs. Set
# benchmark=False if bit-exact reproducibility across runs matters.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
# Anomaly detection is a debugging aid; it adds overhead to autograd.
torch.autograd.set_detect_anomaly(True)
# Lift PIL's decompression-bomb pixel-count limit.
Image.MAX_IMAGE_PIXELS = None

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _str2bool(value):
    """Parse a boolean command-line value such as ``true``/``False``/``1``/``no``.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is True because any
    non-empty string is truthy, so the flag could never be switched off.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Options consumed by Trainer and the cells below. parse_args([]) parses an
# empty list instead of the notebook kernel's own argv, so every option takes
# its default here.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='001', help='Config name; loads ./configs/<config>.yaml.')
parser.add_argument('--pretrained_model_path', type=str, default='./pretrained_models/143_enc.pth', help='pretrained feature-style encoder checkpoint')
parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='pretrained stylegan2 model')
parser.add_argument('--arcface_model_path', type=str, default='./pretrained_models/backbone.pth', help='pretrained arcface model')
parser.add_argument('--parsing_model_path', type=str, default='./pretrained_models/79999_iter.pth', help='pretrained parsing model')
parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')
parser.add_argument('--resume', action='store_true', help='resume from checkpoint')
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path')
parser.add_argument('--checkpoint_noiser', type=str, default='', help='checkpoint file path')
parser.add_argument('--multigpu', type=_str2bool, default=True, help='use multiple gpus')
parser.add_argument('--input_path', type=str, default='./test/', help='evaluation data file path')
parser.add_argument('--save_path', type=str, default='./', help='output data save path')
fs_opts = parser.parse_args([])

In [None]:
# Load the experiment configuration ./configs/<config>.yaml; the context
# manager closes the file once it has been parsed.
with open('./configs/' + fs_opts.config + '.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Initialize trainer: build it, load the auxiliary networks from the paths in
# fs_opts, and move everything to the selected device.
trainer = Trainer(config, fs_opts)
trainer.initialize(fs_opts.stylegan_model_path, fs_opts.arcface_model_path, fs_opts.parsing_model_path)
trainer.to(device)

# Load the encoder checkpoint once. map_location lets a checkpoint saved from
# GPU tensors load on a CPU-only machine; load_state_dict copies the weights
# into the (already moved) encoder either way.
state_dict = torch.load(fs_opts.pretrained_model_path, map_location=device)
trainer.enc.load_state_dict(state_dict)
trainer.enc.eval()

print("Feature_style_encoder successfully loaded!")

# Make dataset: the first 10,000 FFHQ images (samples 0–9,999)

In [None]:
import os
import pathlib
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List

In [None]:
# Write a custom dataset class (inherits from torch.utils.data.Dataset)
from torch.utils.data import Dataset

# 1. Subclass torch.utils.data.Dataset
class ImageFolderCustom(Dataset):
    """Dataset over a contiguous slice of the image files in one directory.

    File names are sorted before slicing. ``os.listdir`` returns entries in
    arbitrary order, so without sorting a given ``range_of_samples`` would not
    reliably select the same files. Chunks built from consecutive ranges could
    then not be matched back to file names.

    Args:
        target_dir: Directory containing the image files (every entry is
            treated as an image).
        range_of_samples: ``(head, tail)`` slice bounds into the sorted file
            list; ``tail`` is exclusive, as in ``names[head:tail]``.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, target_dir: str,
                 range_of_samples: Tuple[int, int],
                 transform=None) -> None:
        self.image_dir_path = target_dir
        head, tail = range_of_samples
        self.image_names = sorted(os.listdir(target_dir))[head:tail]
        self.transform = transform

    def load_image(self, index: int) -> Image.Image:
        """Open the ``index``-th image of the slice and return it as RGB."""
        image_path = os.path.join(self.image_dir_path, self.image_names[index])
        # convert() forces the pixel data to load, so the file can be closed
        # on leaving the ``with``. "RGB" guarantees 3 channels, matching the
        # 3-channel Normalize used with this dataset (1- or 4-channel images
        # would not match it).
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def __len__(self) -> int:
        """Return the number of images in the slice."""
        return len(self.image_names)

    def __getitem__(self, index: int) -> "torch.Tensor | Image.Image":
        """Return the ``index``-th image, transformed if a transform was given."""
        img = self.load_image(index)
        if self.transform is not None:
            return self.transform(img)
        return img

In [None]:
# Preprocessing for the encoder: resize to 1024x1024, convert to a float tensor
# in [0, 1], then map each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
img_to_tensor = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

In [None]:
image_dir_path = "/kaggle/input/flickr-faces-hq-dataset-ffhq/images1024x1024"

data_custom = ImageFolderCustom(target_dir=image_dir_path,
                                range_of_samples = (0, 10_000),
                                transform=img_to_tensor)

In [None]:
len(data_custom)

In [None]:
data_custom[0].shape

### Prepare DataLoader

In [None]:
from torch.utils.data import DataLoader
import os

# Setup the batch size hyperparameter
BATCH_SIZE = 4
NUM_CORES = os.cpu_count()

# Turn datasets into iterables (batches)
dataloader = DataLoader(
    data_custom, # dataset to turn into iterable
    batch_size=BATCH_SIZE, # how many samples per batch? 
    shuffle=False, # shuffle data every epoch?
    
    num_workers = NUM_CORES,
    pin_memory = True
)

In [None]:
len(dataloader)

In [None]:
image_batch = next(iter(dataloader))
image_batch.shape

In [None]:
import os

# Output directories for the saved tensors; exist_ok=True keeps the cell
# re-runnable.
os.makedirs('/kaggle/working/features', exist_ok=True)
os.makedirs('/kaggle/working/latents', exist_ok=True)

In [None]:
import os
from tqdm import tqdm

# Release cached GPU blocks left by earlier cells (does nothing if CUDA was
# never initialized).
torch.cuda.empty_cache()
# torch.save does not create missing parent directories.
os.makedirs('/kaggle/working/latents', exist_ok=True)
os.makedirs('/kaggle/working/features', exist_ok=True)

with torch.no_grad():
    latents, features = [], []
    for image in tqdm(dataloader, total=len(dataloader)):
        output = trainer.test(img=image.to(device), return_latent=True)
        # The last element of `output` is taken as the feature and the one
        # before it as the latent (assumed return order of trainer.test —
        # TODO confirm against trainer.py).
        feature = output.pop()
        latent = output.pop()
        # Keep results in host memory: accumulating them on the GPU grows with
        # the number of batches, and CPU tensors can be torch.load-ed later
        # without a GPU.
        latents.append(latent.cpu())
        features.append(feature.cpu())

    latents = torch.cat(latents)
    features = torch.cat(features)
    torch.save(latents, '/kaggle/working/latents/latents10000.pt')
    torch.save(features, '/kaggle/working/features/features10000.pt')
    # Drop the concatenated tensors so their memory can be reclaimed.
    latents, features = [], []

# Make dataset: the next 10,000 FFHQ images (samples 10,000–19,999)

In [None]:
data_custom = ImageFolderCustom(target_dir=image_dir_path,
                                range_of_samples = (10_000, 20_000),
                                transform=img_to_tensor)

In [None]:
len(data_custom)

In [None]:
data_custom[0].shape

### Prepare DataLoader

In [None]:
# Setup the batch size hyperparameter
BATCH_SIZE = 4
NUM_CORES = os.cpu_count()

# Turn datasets into iterables (batches)
dataloader = DataLoader(
    data_custom, # dataset to turn into iterable
    batch_size=BATCH_SIZE, # how many samples per batch? 
    shuffle=False, # shuffle data every epoch?
    
    num_workers = NUM_CORES,
    pin_memory = True
)

In [None]:
len(dataloader)

In [None]:
image_batch = next(iter(dataloader))
image_batch.shape

In [None]:
import os
from tqdm import tqdm

# Release cached GPU blocks left by earlier cells (does nothing if CUDA was
# never initialized).
torch.cuda.empty_cache()
# torch.save does not create missing parent directories.
os.makedirs('/kaggle/working/latents', exist_ok=True)
os.makedirs('/kaggle/working/features', exist_ok=True)

with torch.no_grad():
    latents, features = [], []
    for image in tqdm(dataloader, total=len(dataloader)):
        output = trainer.test(img=image.to(device), return_latent=True)
        # The last element of `output` is taken as the feature and the one
        # before it as the latent (assumed return order of trainer.test —
        # TODO confirm against trainer.py).
        feature = output.pop()
        latent = output.pop()
        # Keep results in host memory: accumulating them on the GPU grows with
        # the number of batches, and CPU tensors can be torch.load-ed later
        # without a GPU.
        latents.append(latent.cpu())
        features.append(feature.cpu())

    latents = torch.cat(latents)
    features = torch.cat(features)
    torch.save(latents, '/kaggle/working/latents/latents20000.pt')
    torch.save(features, '/kaggle/working/features/features20000.pt')
    # Drop the concatenated tensors so their memory can be reclaimed.
    latents, features = [], []

# Make dataset: the next 10,000 FFHQ images (samples 20,000–29,999)

In [None]:
data_custom = ImageFolderCustom(target_dir=image_dir_path,
                                range_of_samples = (20_000, 30_000),
                                transform=img_to_tensor)

In [None]:
len(data_custom)

In [None]:
data_custom[0].shape

### Prepare DataLoader

In [None]:
# Setup the batch size hyperparameter
BATCH_SIZE = 4
NUM_CORES = os.cpu_count()

# Turn datasets into iterables (batches)
dataloader = DataLoader(
    data_custom, # dataset to turn into iterable
    batch_size=BATCH_SIZE, # how many samples per batch? 
    shuffle=False, # shuffle data every epoch?
    
    num_workers = NUM_CORES,
    pin_memory = True
)

In [None]:
len(dataloader)

In [None]:
image_batch = next(iter(dataloader))
image_batch.shape

In [None]:
import os
from tqdm import tqdm

# Release cached GPU blocks left by earlier cells (does nothing if CUDA was
# never initialized).
torch.cuda.empty_cache()
# torch.save does not create missing parent directories.
os.makedirs('/kaggle/working/latents', exist_ok=True)
os.makedirs('/kaggle/working/features', exist_ok=True)

with torch.no_grad():
    latents, features = [], []
    for image in tqdm(dataloader, total=len(dataloader)):
        output = trainer.test(img=image.to(device), return_latent=True)
        # The last element of `output` is taken as the feature and the one
        # before it as the latent (assumed return order of trainer.test —
        # TODO confirm against trainer.py).
        feature = output.pop()
        latent = output.pop()
        # Keep results in host memory: accumulating them on the GPU grows with
        # the number of batches, and CPU tensors can be torch.load-ed later
        # without a GPU.
        latents.append(latent.cpu())
        features.append(feature.cpu())

    latents = torch.cat(latents)
    features = torch.cat(features)
    torch.save(latents, '/kaggle/working/latents/latents30000.pt')
    torch.save(features, '/kaggle/working/features/features30000.pt')
    # Drop the concatenated tensors so their memory can be reclaimed.
    latents, features = [], []

In [None]:
# The kernel's working directory is inside the repository (set by the earlier
# `%cd /kaggle/working/FeatureStyleEncoder`). Step out before deleting it so
# the kernel is not left in a removed directory, where os.getcwd() and
# relative paths would fail.
os.chdir("/kaggle/working")
remove_folder("/kaggle/working/FeatureStyleEncoder")