In [None]:
# Make the project files stored on Google Drive visible to this Colab runtime.
import google.colab.drive as drive

DRIVE_MOUNTPOINT = '/content/drive'
drive.mount(DRIVE_MOUNTPOINT)

In [None]:
# Work from the StyleGAN project folder on Drive so the relative paths used later
# (./metrics, ./latent, ./weights_test) resolve against it.
import os

PROJECT_DIR = "/content/drive/MyDrive/final_3/StyleGAN/"
os.chdir(PROJECT_DIR)

In [None]:
!pip install lpips ninja

In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torchvision import models
from torchvision.utils import save_image
from torchvision import transforms
from collections import OrderedDict
import numpy as np
import pickle
import torch_utils
from PIL import Image
from lpips import LPIPS
from math import log10
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec


import warnings
warnings.filterwarnings("ignore")

In [None]:
# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
device = torch.device(
    'mps:0' if torch.backends.mps.is_available()
    else 'cuda:0' if torch.cuda.is_available()
    else 'cpu'
)

# Load the EMA generator from the StyleGAN3 checkpoint.
# NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoint files.
with open('./metrics/stylegan3-r-ffhqu-1024x1024.pkl', 'rb') as pkl_file:
    G = pickle.load(pkl_file)['G_ema'].to(device)

# Wrap mapping + synthesis in one container so they can be switched to eval mode and
# moved together; both halves are then exposed individually.
g_all = nn.Sequential(OrderedDict(
    g_mapping=G.mapping,
    g_synthesis=G.synthesis,
))
g_all.eval().to(device)
g_mapping, g_synthesis = g_all  # iterating a Sequential yields its modules in order
print(device)

In [None]:
# Ensure the output folder exists. exist_ok=True avoids the check-then-create race of a
# separate os.path.exists() test.
weights_test_path = './weights_test'
os.makedirs(weights_test_path, exist_ok=True)

In [None]:
# Load every .npy latent in ./latent, ordered by the first number in its file name,
# as float32 tensors on `device`.
latent_folder = './latent'
files = os.listdir(latent_folder)
latent_files = [file for file in files if file.lower().endswith('.npy')]

# Parse the number before sorting. Sorting with re.search(...).group() as the key
# crashes on any .npy name without digits, which made the warning branch unreachable.
numbered_files = []
for latent_file in latent_files:
    match = re.search(r'\d+', latent_file)
    if match:
        numbered_files.append((int(match.group()), latent_file))
    else:
        print(f"파일: {latent_file}, 파일 이름에서 숫자를 찾을 수 없습니다.")
numbered_files.sort(key=lambda pair: pair[0])  # stable: ties keep listdir order
sorted_files = [name for _, name in numbered_files]

latents = []
# NOTE(review): `idx` is unused here but remains bound after the loop; code outside
# this cell may read it.
for idx, latent_file in enumerate(sorted_files):
    latent = np.load(os.path.join(latent_folder, latent_file))
    latents.append(torch.tensor(latent, dtype=torch.float32, device=device))

print(len(latents))

In [None]:
# Print the shape of the full latent, of its first batch entry, and of a single W row.
test_latent = latents[0].clone()
print(test_latent.shape)
print(test_latent[0].shape)
# Explicit row index: `idx` here was whatever an earlier loop left behind (the last
# file index), which goes out of range once there are more files than W rows.
print(test_latent[0][0].shape)

In [None]:
# Print the shapes of the full latent, its first batch entry, and the row ranges
# [:3], [3:6] and [6:] of that entry.
test_latent = latents[0].clone()
w_rows = test_latent[0]
for part in (test_latent, w_rows, w_rows[:3], w_rows[3:6], w_rows[6:]):
    print(part.shape)

In [None]:
## Check what each row of the latent controls (F1).
# Render the unmodified latent, then one variant per W row in which only that row is
# multiplied by BOOST (every other row keeps weight 1).
# NOTE(review): assumes each latent is shaped (1, num_ws, C) -- the [0, i] indexing
# relies on a leading batch dim; confirm against the saved .npy files.
LATENT_IDX = 0
LABEL = "F1"
BOOST = 5
GRID_COLS = 8

base_latent = latents[LATENT_IDX]
num_ws = base_latent.shape[1]  # derived from the data instead of hard-coding 16

variants = [base_latent.clone()]  # index 0: reference image
for i in range(num_ws):
    variant = base_latent.clone()
    variant[0, i] *= BOOST  # scaling the other rows by 1 is a no-op, so touch only row i
    variants.append(variant)

# Inference only: no_grad skips autograd bookkeeping for every 1024x1024 render.
with torch.no_grad():
    img_list = [((g_synthesis(v) + 1.0) / 2.0).clamp(0, 1).cpu() for v in variants]

plt.imshow(img_list[0].squeeze().permute(1, 2, 0).numpy())
plt.axis("off")
plt.title(f"{LABEL} latent vector")
plt.show()

# One panel per boosted row, GRID_COLS per line (2 x 8 for 16 rows).
grid_rows = -(-num_ws // GRID_COLS)  # ceiling division
fig, axes = plt.subplots(grid_rows, GRID_COLS, figsize=(20, 3 * grid_rows), squeeze=False)
for ax in axes.flat:
    ax.axis("off")  # also blanks unused panels when num_ws is not a multiple of GRID_COLS
for i, (ax, img) in enumerate(zip(axes.flat, img_list[1:])):
    ax.imshow(img.squeeze().permute(1, 2, 0).numpy())
    ax.set_title(f"index {i}")
fig.suptitle(f"{LABEL} latent vector with modified weight")
plt.show()

In [None]:
## Check what each group of latent rows controls (F1).
# Render the unmodified latent, then one variant per group of W rows in which only that
# group is multiplied by BOOST.
# The groups must be disjoint: the first group used to be 0:4, which overlapped row 3
# with the second group (3:6), so row 3 was boosted in both "section 0" and "section 1".
LATENT_IDX = 0
LABEL = "F1"
BOOST = 5
SECTIONS = [slice(0, 3), slice(3, 6), slice(6, 9), slice(9, 12), slice(12, None)]

base_latent = latents[LATENT_IDX]

variants = [base_latent.clone()]  # index 0: reference image
for section in SECTIONS:
    variant = base_latent.clone()
    variant[0, section] *= BOOST
    variants.append(variant)

# Inference only: no_grad skips autograd bookkeeping for every render.
with torch.no_grad():
    img_list = [((g_synthesis(v) + 1.0) / 2.0).clamp(0, 1).cpu() for v in variants]

plt.imshow(img_list[0].squeeze().permute(1, 2, 0).numpy())
plt.axis("off")
plt.title(f"{LABEL} latent vector")
plt.show()

# One panel per boosted section, all in a single row.
fig, axes = plt.subplots(1, len(SECTIONS), figsize=(20, 4), squeeze=False)
for i, (ax, img) in enumerate(zip(axes.flat, img_list[1:])):
    ax.imshow(img.squeeze().permute(1, 2, 0).numpy())
    ax.set_title(f"section {i}")
    ax.axis("off")
fig.suptitle(f"{LABEL} latent vector with modified weight")
plt.show()

In [None]:
## Check what each row of the latent controls (F2).
# Render the unmodified latent, then one variant per W row in which only that row is
# multiplied by BOOST (every other row keeps weight 1).
# NOTE(review): assumes each latent is shaped (1, num_ws, C) -- the [0, i] indexing
# relies on a leading batch dim; confirm against the saved .npy files.
LATENT_IDX = 1
LABEL = "F2"
BOOST = 5
GRID_COLS = 8

base_latent = latents[LATENT_IDX]
num_ws = base_latent.shape[1]  # derived from the data instead of hard-coding 16

variants = [base_latent.clone()]  # index 0: reference image
for i in range(num_ws):
    variant = base_latent.clone()
    variant[0, i] *= BOOST  # scaling the other rows by 1 is a no-op, so touch only row i
    variants.append(variant)

# Inference only: no_grad skips autograd bookkeeping for every 1024x1024 render.
with torch.no_grad():
    img_list = [((g_synthesis(v) + 1.0) / 2.0).clamp(0, 1).cpu() for v in variants]

plt.imshow(img_list[0].squeeze().permute(1, 2, 0).numpy())
plt.axis("off")
plt.title(f"{LABEL} latent vector")
plt.show()

# One panel per boosted row, GRID_COLS per line (2 x 8 for 16 rows).
grid_rows = -(-num_ws // GRID_COLS)  # ceiling division
fig, axes = plt.subplots(grid_rows, GRID_COLS, figsize=(20, 3 * grid_rows), squeeze=False)
for ax in axes.flat:
    ax.axis("off")  # also blanks unused panels when num_ws is not a multiple of GRID_COLS
for i, (ax, img) in enumerate(zip(axes.flat, img_list[1:])):
    ax.imshow(img.squeeze().permute(1, 2, 0).numpy())
    ax.set_title(f"index {i}")
fig.suptitle(f"{LABEL} latent vector with modified weight")
plt.show()

In [None]:
## Check what each group of latent rows controls (F2).
# Render the unmodified latent, then one variant per group of W rows in which only that
# group is multiplied by BOOST.
# The groups must be disjoint: the first group used to be 0:4, which overlapped row 3
# with the second group (3:6), so row 3 was boosted in both "section 0" and "section 1".
LATENT_IDX = 1
LABEL = "F2"
BOOST = 5
SECTIONS = [slice(0, 3), slice(3, 6), slice(6, 9), slice(9, 12), slice(12, None)]

base_latent = latents[LATENT_IDX]

variants = [base_latent.clone()]  # index 0: reference image
for section in SECTIONS:
    variant = base_latent.clone()
    variant[0, section] *= BOOST
    variants.append(variant)

# Inference only: no_grad skips autograd bookkeeping for every render.
with torch.no_grad():
    img_list = [((g_synthesis(v) + 1.0) / 2.0).clamp(0, 1).cpu() for v in variants]

plt.imshow(img_list[0].squeeze().permute(1, 2, 0).numpy())
plt.axis("off")
plt.title(f"{LABEL} latent vector")
plt.show()

# One panel per boosted section, all in a single row.
fig, axes = plt.subplots(1, len(SECTIONS), figsize=(20, 4), squeeze=False)
for i, (ax, img) in enumerate(zip(axes.flat, img_list[1:])):
    ax.imshow(img.squeeze().permute(1, 2, 0).numpy())
    ax.set_title(f"section {i}")
    ax.axis("off")
fig.suptitle(f"{LABEL} latent vector with modified weight")
plt.show()

In [None]:
# NOTE(review): this whole cell is commented out (legacy latent-blending experiments) and
# does nothing when run. If revived: the snippets overwrite entries of `latents` in place
# (e.g. latents[0] becomes the blend), so re-running compounds the weights, and the
# grouped experiment reads `recommendation_list`, which this cell never defines.
# ## Weights applied to the whole latent (full blend)
# weight = [0.5, 0.3, 0.1, 0.05, 0.02, 0.01, 0.01, 0.01]
# for idx, latent in enumerate(latents):
#     if idx == 0:
#         latents[0] = latent * weight[idx]
#     else:
#         latents[0] += latent * weight[idx]

# syn_img = g_synthesis(latents[0])
# syn_img = (syn_img+1.0)/2.0
# save_image(syn_img.clamp(0,1),f"out.png")


## Reference image: latent1.png

## Check what each row means individually
# weight_sum_each => weight 5 on one of the 16 rows of the 16x512 latent and 1 on the others, to see what each row means
# weight_dict = {
#     1: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5],
#     }


# for j in range(16):
#     latents[0][:,j] = latents[0][:,j] * weight_dict[1][j]


# syn_img = g_synthesis(latents[0])
# syn_img = (syn_img+1.0)/2.0
# save_image(syn_img.clamp(0,1),f"weight_sum_each_15.png")

## Check what each group of rows means
# weight_sum => groups 0:3, 3:6, 6:9, 9:12, 12:, experimenting by rotating the weights 2,1,1,1,1 across them


# weight_dict = {
#     1: [0.5, 0.4, 0.4, 0.4, 0.4],
#     2: [0.2, 0.2, 0.2, 0.2, 0.2],
#     3: [0.15, 0.15, 0.15, 0.15, 0.15],
#     4: [0.1, 0.12, 0.1, 0.1, 0.1],
#     5: [0.05, 0.08, 0.08, 0.05, 0.05],
#     6: [0.04, 0.06, 0.06, 0.04, 0.04],
#     7: [0.03, 0.05, 0.05, 0.03, 0.03],
#     8: [0.02, 0.04, 0.04, 0.02, 0.02],
#     }


# for i in range(len(recommendation_list)):
#     for j in range(5):
#         if j == 0:
#             latents[i][:, 0:3] = latents[i][:, 0:3] * weight_dict[i+1][j]
#         elif j == 1:
#             latents[i][:, 3:6] = latents[i][:, 3:6] * weight_dict[i+1][j]
#         elif j == 2:
#             latents[i][:, 6:9] = latents[i][:, 6:9] * weight_dict[i+1][j]
#         elif j == 3:
#             latents[i][:, 9:12] = latents[i][:, 9:12] * weight_dict[i+1][j]
#         elif j == 4:
#             latents[i][:, 12:] = latents[i][:, 12:] * weight_dict[i+1][j]

# # Sum the tensors in the list
# sum_tensor = torch.sum(torch.stack(latents), dim=0)
# print(sum_tensor.shape)

# syn_img = g_synthesis(sum_tensor)
# syn_img = (syn_img+1.0)/2.0
# save_image(syn_img.clamp(0,1),f"weight_sum_0.png")


In [None]:
# NOTE(review): this whole cell is commented out and does nothing when run. It is a
# self-contained variant of the blending experiments that also reloads the generator and
# the latents listed in `recommendation_list`. Note that the device auto-detection is
# immediately overridden by an unconditional `device = torch.device('cpu')`, and that
# `recommendation_list` contains 31 twice -- confirm that is intended. The snippets
# overwrite entries of `latents` in place, so re-running compounds the weights.
# # Define the device
# if torch.backends.mps.is_available():
#     device = torch.device('mps:0')
# elif torch.cuda.is_available():
#     device = torch.device('cuda:0')
# else:
#     device = torch.device('cpu')

# device = torch.device('cpu')
# with open('./metrics/stylegan3-r-ffhqu-1024x1024.pkl', 'rb') as f:
#     G = pickle.load(f)['G_ema'].to(device)

# g_all = nn.Sequential(OrderedDict([('g_mapping', G.mapping),
#         ('g_synthesis', G.synthesis)
#     ]))

# g_all.eval()
# g_all.to(device)
# g_mapping, g_synthesis = g_all[0],g_all[1]
# print(device)

# recommendation_list = [31,1,31,8,10,13,14,2]
# latents  = []

# for i in recommendation_list:
#   latent = np.load(f'./latent/latent{i}.npy')
#   latent = torch.tensor(latent, dtype=torch.float32, device=device)
#   latents.append(latent)


# ## Weights applied to the whole latent (full blend)
# weight = [0.5, 0.3, 0.1, 0.05, 0.02, 0.01, 0.01, 0.01]
# for idx, latent in enumerate(latents):
#     if idx == 0:
#         latents[0] = latent * weight[idx]
#     else:
#         latents[0] += latent * weight[idx]


# syn_img = g_synthesis(latents[0])
# syn_img = (syn_img+1.0)/2.0
# save_image(syn_img.clamp(0,1),f"out.png")


## Reference image: latent1.png

## Check what each row means individually
# weight_sum_each => weight 5 on one of the 16 rows of the 16x512 latent and 1 on the others, to see what each row means
# weight_dict = {
#     1: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5],
#     }


# for j in range(16):
#     latents[0][:,j] = latents[0][:,j] * weight_dict[1][j]


# syn_img = g_synthesis(latents[0])
# syn_img = (syn_img+1.0)/2.0
# save_image(syn_img.clamp(0,1),f"weight_sum_each_15.png")

## Check what each group of rows means
# weight_sum => groups 0:3, 3:6, 6:9, 9:12, 12:, experimenting by rotating the weights 2,1,1,1,1 across them


# weight_dict = {
#     1: [0.5, 0.4, 0.4, 0.4, 0.4],
#     2: [0.2, 0.2, 0.2, 0.2, 0.2],
#     3: [0.15, 0.15, 0.15, 0.15, 0.15],
#     4: [0.1, 0.12, 0.1, 0.1, 0.1],
#     5: [0.05, 0.08, 0.08, 0.05, 0.05],
#     6: [0.04, 0.06, 0.06, 0.04, 0.04],
#     7: [0.03, 0.05, 0.05, 0.03, 0.03],
#     8: [0.02, 0.04, 0.04, 0.02, 0.02],
#     }


# for i in range(len(recommendation_list)):
#     for j in range(5):
#         if j == 0:
#             latents[i][:, 0:3] = latents[i][:, 0:3] * weight_dict[i+1][j]
#         elif j == 1:
#             latents[i][:, 3:6] = latents[i][:, 3:6] * weight_dict[i+1][j]
#         elif j == 2:
#             latents[i][:, 6:9] = latents[i][:, 6:9] * weight_dict[i+1][j]
#         elif j == 3:
#             latents[i][:, 9:12] = latents[i][:, 9:12] * weight_dict[i+1][j]
#         elif j == 4:
#             latents[i][:, 12:] = latents[i][:, 12:] * weight_dict[i+1][j]

# # Sum the tensors in the list
# sum_tensor = torch.sum(torch.stack(latents), dim=0)
# print(sum_tensor.shape)

# syn_img = g_synthesis(sum_tensor)
# syn_img = (syn_img+1.0)/2.0
# save_image(syn_img.clamp(0,1),f"weight_sum_0.png")
