In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the project folder stored on Drive is reachable.
drive.mount('/content/drive')

In [None]:
import os

# Work from the project folder on the mounted Drive: every relative path used later
# (./metrics, ./images, ./latent, ./save_images) resolves against it, and presumably the
# StyleGAN3 repo modules (e.g. torch_utils) are importable from here — confirm.
os.chdir("/content/drive/MyDrive/final_3/StyleGAN/")

In [None]:
# Install the extra packages this notebook needs: lpips (perceptual distance) and ninja
# (presumably for JIT-compiling StyleGAN3's custom CUDA ops — confirm).
!pip install lpips ninja

In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torchvision import models
from torchvision.utils import save_image
from torchvision import transforms
from collections import OrderedDict
import numpy as np
import pickle
import torch_utils
from PIL import Image
from lpips import LPIPS
from math import log10
from tqdm import tqdm
import re

import warnings
# NOTE(review): this silences *every* warning for the whole session, including
# deprecation/misuse warnings from torch and torchvision (likely e.g. the
# `pretrained=` argument used below). Consider narrowing by category or module.
warnings.filterwarnings("ignore")

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Load the EMA generator from the StyleGAN3 snapshot.
# NOTE(review): pickle.load executes arbitrary code stored in the file, so only load
# trusted checkpoints. Unpickling presumably needs the StyleGAN3 repo modules (e.g. the
# `torch_utils` package imported earlier) importable from the working directory — confirm.
with open('./metrics/stylegan3-r-ffhqu-1024x1024.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].to(device)

# Wrap mapping + synthesis as named stages of one container. Module.eval() and
# Module.to() return the module itself, so chaining keeps the same object.
g_all = nn.Sequential(OrderedDict([
    ('g_mapping', G.mapping),
    ('g_synthesis', G.synthesis),
])).eval().to(device)
g_mapping = g_all.g_mapping
g_synthesis = g_all.g_synthesis
print(device)

# Despite its name this module *downsamples*: scale factor 256/1024 = 0.25
# (e.g. 1024 -> 256 pixels per side) before images go through the VGG perceptual net.
upsample = nn.Upsample(scale_factor=256 / 1024, mode='bilinear').to(device)

# Mean-reduced MSE reused for both pixel and feature-map losses.
MSE_loss = nn.MSELoss(reduction="mean").to(device)

class VGG16_perceptual(torch.nn.Module):
    """Frozen VGG16 feature extractor returning five intermediate activations.

    torchvision's ``vgg16().features`` stack is cut into five consecutive slices that
    end right after relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3 (the same taps the
    LPIPS VGG backbone uses).

    NOTE(review): inputs are passed through as-is (no ImageNet mean/std
    normalization) — confirm this is intended.

    Args:
        requires_grad: when False (default) all VGG weights are frozen, so gradients
            flow only to the input image.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        layers = list(models.vgg16(pretrained=True).features)
        self.slice1 = nn.Sequential(*layers[:4])     # -> relu1_2
        self.slice2 = nn.Sequential(*layers[4:9])    # -> relu2_2
        self.slice3 = nn.Sequential(*layers[9:16])   # -> relu3_3
        self.slice4 = nn.Sequential(*layers[16:23])  # -> relu4_3
        # 23:30, not 23:29: index 28 is conv5_3 and index 29 its ReLU. Stopping at 29
        # returned a pre-activation map under the name h_relu_5.
        self.slice5 = nn.Sequential(*layers[23:30])  # -> relu5_3
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the five slice outputs (shallow to deep) as a tuple."""
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        return tuple(activations)

def loss_function(syn_img, img, img_p, MSE_loss, upsample, perceptual, lpips):
    """Compute reconstruction losses between a synthesized image and the target.

    Args:
        syn_img: synthesized image, in [0, 1] (the caller rescales the generator output).
        img: full-resolution target image, in [0, 1].
        img_p: target already passed through ``upsample`` (perceptual-net resolution).
        MSE_loss: callable returning the mean squared error of two tensors.
        upsample: resize module applied to ``syn_img`` before the perceptual network.
        perceptual: network returning a sequence of feature maps (e.g. VGG16_perceptual).
        lpips: an ``lpips.LPIPS`` instance.

    Returns:
        ``(mse, per_loss, lpips_value)``:
            - the pixel-MSE tensor;
            - the sum of per-layer feature MSEs (tensor);
            - the LPIPS distance as a detached Python float, so it carries no gradient.
    """
    syn_feats = perceptual(upsample(syn_img))
    ref_feats = perceptual(img_p)
    mse = MSE_loss(syn_img, img)
    # LPIPS expects inputs in [-1, 1] by default; normalize=True rescales our [0, 1] images.
    lpips_distance = lpips(syn_img, img, normalize=True)

    per_loss = sum(MSE_loss(syn_f, ref_f) for syn_f, ref_f in zip(syn_feats, ref_feats))

    return mse, per_loss, lpips_distance.detach().cpu().view(-1).item()

def PSNR(mse, flag=0, max_val=1.0):
    """Return the peak signal-to-noise ratio in dB for a mean squared error.

    Args:
        mse: mean squared error, as a single-element tensor or a plain number.
        flag: kept for backward compatibility; only ``0`` is supported.
        max_val: peak pixel value (1.0 for images in [0, 1], the default).

    Returns:
        ``10 * log10(max_val**2 / mse)``; ``inf`` when ``mse`` is exactly 0.

    Raises:
        ValueError: if ``flag`` is not 0 (this previously surfaced as an
            UnboundLocalError because ``psnr`` was never assigned).
    """
    if flag != 0:
        raise ValueError(f"unsupported PSNR flag: {flag!r}")
    mse_value = float(mse)  # works for 1-element tensors and plain numbers alike
    if mse_value == 0:
        return float('inf')  # identical images: avoid ZeroDivisionError
    return 10 * log10(max_val ** 2 / mse_value)

def _optimize_latent(latent, to_ws, steps, phase, img_num, image, img_p, perceptual, lpips):
    """Run ``steps`` Adam iterations on the leaf tensor ``latent`` for one inversion phase.

    ``to_ws`` maps ``latent`` to the ``ws`` tensor fed to ``g_synthesis``. Every 500
    iterations the losses and PSNR are printed and the synthesized image is saved
    under ./save_images/.
    """
    optimizer = optim.Adam([latent], lr=0.01, betas=(0.9, 0.999), eps=1e-8)
    for e in tqdm(range(steps), desc=f'Optimizing Phase {phase} - Image {img_num}'):
        optimizer.zero_grad()
        # Map the generator output from [-1, 1] to [0, 1] to match the target image.
        syn_img = (g_synthesis(to_ws(latent)) + 1.0) / 2.0
        mse, per_loss, lpips_distance = loss_function(syn_img, image, img_p, MSE_loss, upsample, perceptual, lpips)
        # loss_function returns lpips_distance as a detached float: it shifts the
        # reported total but contributes no gradient.
        loss = mse + per_loss + lpips_distance
        loss.backward()
        optimizer.step()
        if (e + 1) % 500 == 0:
            # Host syncs (.cpu()/.item()) only on logging steps, not every iteration.
            psnr = PSNR(mse, flag=0)
            loss_np = loss.detach().cpu().numpy()
            loss_p = per_loss.detach().cpu().numpy()
            loss_m = mse.detach().cpu().numpy()
            print("iter{}: loss -- {}, lpips_loss -- {}, mse_loss -- {},  percep_loss -- {}, psnr -- {}".format(e + 1, loss_np, lpips_distance, loss_m, loss_p, psnr))
            save_image(syn_img.clamp(0, 1), f"./save_images/Hier_pass_morph{'P' + str(phase)}-{img_num}-{e + 1}.png")


def embedding_Hierarchical(image, img_num):
    """Invert ``image`` into the generator's latent space in two phases.

    Phase 1 (4000 steps) optimizes a single 512-d latent shared by every synthesis
    layer, starting from zeros. Phase 2 (6000 steps) copies it to one row per layer
    and optimizes the rows independently.

    Args:
        image: target image tensor on ``device``, in [0, 1], presumably (1, 3, 1024, 1024)
            to match the generator output — confirm.
        img_num: identifier used in progress labels and saved file names.

    Returns:
        The optimized per-layer latent, shape (1, num_ws, 512), a leaf tensor that
        still has requires_grad=True.
    """
    img_p = upsample(image.clone())
    perceptual = VGG16_perceptual().to(device)
    lpips = LPIPS(pretrained=True, net='vgg', version='0.1').to(device)
    # Intermediate images are saved here; create it up front instead of failing at step 500.
    os.makedirs("./save_images", exist_ok=True)
    # The original hard-coded 16 layers (presumably right for this 1024x1024 checkpoint).
    # Read the count from the network when it exposes it.
    num_ws = getattr(g_synthesis, 'num_ws', 16)

    # Phase 1: one w vector broadcast (without copying) to all synthesis layers.
    latent_w = torch.zeros((1, 512), requires_grad=True, device=device)
    _optimize_latent(latent_w, lambda w: w.unsqueeze(1).expand(-1, num_ws, -1), 4000, 1,
                     img_num, image, img_p, perceptual, lpips)

    # Phase 2: materialize an independent copy per layer as a fresh leaf tensor.
    # This replaces torch.tensor(tensor, ...), the copy-construct idiom PyTorch warns against.
    latent_w1 = latent_w.detach().unsqueeze(1).repeat(1, num_ws, 1).requires_grad_(True)
    _optimize_latent(latent_w1, lambda ws: ws, 6000, 2,
                     img_num, image, img_p, perceptual, lpips)

    return latent_w1

# Input images are read from ./images; one latent per image is written to ./latent.
input_folder = "./images"
output_folder = "./latent"
os.makedirs(output_folder, exist_ok=True)


def _first_number(filename):
    """Return the first run of decimal digits in *filename* as an int, or None if there is none."""
    match = re.search(r'\d+', filename)
    return int(match.group()) if match else None


def _sort_key(filename):
    """Order numbered files first by their number, then files without a number by name."""
    number = _first_number(filename)
    return (number is None, filename if number is None else number)


files = os.listdir(input_folder)
image_files = [file for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp'))]
# A key calling .group() directly raises AttributeError on a file name without digits,
# which made the "no number" branch below unreachable. _sort_key puts such files last.
sorted_files = sorted(image_files, key=_sort_key)

# Converts the PIL RGB image to a float tensor in [0, 1] with shape (C, H, W).
transform = transforms.Compose([transforms.ToTensor()])
for image_file in sorted_files:
    # Use the number in the file name to name the output latent.
    number = _first_number(image_file)
    if number is None:
        # Message (Korean): "File: <name>, could not find a number in the file name."
        print(f"파일: {image_file}, 파일 이름에서 숫자를 찾을 수 없습니다.")
        continue

    image_path = os.path.join(input_folder, image_file)
    latent_path = os.path.join(output_folder, f"latent{number}.npy")

    with open(image_path, "rb") as f:
        image = Image.open(f).convert("RGB")

    # Add a batch dimension -> (1, 3, H, W) and move to the generator's device.
    image = transform(image).unsqueeze(0).to(device)

    latent_vector = embedding_Hierarchical(image, number)
    np.save(latent_path, latent_vector.detach().cpu().numpy())