In [87]:
from torch_utils import misc
import dnnlib
import pickle
import torch
import PIL.Image
import numpy as np
from typing import List, Optional
from pathlib import Path

# setting

In [88]:
outdir = "./outputs"
path = Path(outdir)
path.mkdir(parents=True, exist_ok=True)
# 생성값 고정
truncation_psi = 0.5
noise_mode = 'const'
# ['const', 'random', 'none'] default:'const'

# import model

In [89]:
with open('./pretrained_models/stylegan_human_v2_1024.pkl', 'rb') as f:
    model = pickle.load(f,encoding='latin1')

In [90]:
model.keys()

dict_keys(['training_set_kwargs', 'G', 'D', 'G_ema', 'augment_pipe'])

# setting device

In [91]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [119]:
G = model['G_ema'].to(device)
gen = torch.load('./checkpoint/Gmodel_038_1.67.pth').to(device)

In [120]:
for i in range(10):
    seed=i

    target_z = np.array([])
    target_w = np.array([])
    latent_out = outdir.replace('/images/','')
    
    ## pretrained models
    label = torch.zeros([1, G.c_dim], device=device)
    z = torch.from_numpy(np.random.RandomState(
        seed # 랜덤값 고정
    ).randn(1, G.z_dim)).to(device)

    if target_z.size==0:
        target_z= z.cpu()
    else:
        target_z=np.append(target_z, z.cpu(), axis=0) 

    w = G.mapping(z, label,truncation_psi=truncation_psi)
    img = G.synthesis(w, noise_mode=noise_mode, force_fp32 = True)
    if target_w.size==0:
        target_w= w.cpu()
    else:
        target_w=np.append(target_w, w.cpu(), axis=0) 

    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
    
    ## fine_turning model
    label = torch.zeros([1, gen.c_dim], device=device)
    z = torch.from_numpy(np.random.RandomState(seed).randn(1, gen.z_dim)).to(device)

    w = gen.mapping(z, label,
                    truncation_psi=truncation_psi)
    img = gen.synthesis(w, noise_mode=noise_mode,
                        force_fp32 = True)
    # print(type(img))
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # 채널 순서 변경
    PIL.Image.fromarray(
        img[0].cpu().numpy(), 
        'RGB'
    # ).save(f'{outdir}/tuning_seed{seed:04d}.png')
    ).save(f'{outdir}/tuning_seed{seed:04d}.png')

# save model(torch)

In [8]:
path = Path('./export_model/')
path.mkdir(parents=True, exist_ok=True)
torch.save(G,Path(path,'gen_ema.pth'))

# load model(torch)

In [9]:
gen = torch.load(Path(path,'gen_ema.pth'))

In [10]:
dir(gen)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_backward_hooks',
 '_get_name',
 '_init_args',
 '_init_kwargs',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_orig_class_name',
 '_orig_module_src',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_s

In [11]:
for child in gen.modules():
    print(child)

Generator(
  (synthesis): SynthesisNetwork(
    (b4): SynthesisBlock(
      (conv1): SynthesisLayer(
        (affine): FullyConnectedLayer()
      )
      (torgb): ToRGBLayer(
        (affine): FullyConnectedLayer()
      )
    )
    (b8): SynthesisBlock(
      (conv0): SynthesisLayer(
        (affine): FullyConnectedLayer()
      )
      (conv1): SynthesisLayer(
        (affine): FullyConnectedLayer()
      )
      (torgb): ToRGBLayer(
        (affine): FullyConnectedLayer()
      )
    )
    (b16): SynthesisBlock(
      (conv0): SynthesisLayer(
        (affine): FullyConnectedLayer()
      )
      (conv1): SynthesisLayer(
        (affine): FullyConnectedLayer()
      )
      (torgb): ToRGBLayer(
        (affine): FullyConnectedLayer()
      )
    )
    (b32): SynthesisBlock(
      (conv0): SynthesisLayer(
        (affine): FullyConnectedLayer()
      )
      (conv1): SynthesisLayer(
        (affine): FullyConnectedLayer()
      )
      (torgb): ToRGBLayer(
        (affine): FullyConn

In [12]:
for param in list(gen.named_parameters()):
    #register_parameter('mapping')
    print(param[0],'\t:\t', (param[1]).shape)

synthesis.b4.const 	:	 torch.Size([512, 4, 2])
synthesis.b4.conv1.weight 	:	 torch.Size([512, 512, 3, 3])
synthesis.b4.conv1.noise_strength 	:	 torch.Size([])
synthesis.b4.conv1.bias 	:	 torch.Size([512])
synthesis.b4.conv1.affine.weight 	:	 torch.Size([512, 512])
synthesis.b4.conv1.affine.bias 	:	 torch.Size([512])
synthesis.b4.torgb.weight 	:	 torch.Size([3, 512, 1, 1])
synthesis.b4.torgb.bias 	:	 torch.Size([3])
synthesis.b4.torgb.affine.weight 	:	 torch.Size([512, 512])
synthesis.b4.torgb.affine.bias 	:	 torch.Size([512])
synthesis.b8.conv0.weight 	:	 torch.Size([512, 512, 3, 3])
synthesis.b8.conv0.noise_strength 	:	 torch.Size([])
synthesis.b8.conv0.bias 	:	 torch.Size([512])
synthesis.b8.conv0.affine.weight 	:	 torch.Size([512, 512])
synthesis.b8.conv0.affine.bias 	:	 torch.Size([512])
synthesis.b8.conv1.weight 	:	 torch.Size([512, 512, 3, 3])
synthesis.b8.conv1.noise_strength 	:	 torch.Size([])
synthesis.b8.conv1.bias 	:	 torch.Size([512])
synthesis.b8.conv1.affine.weight 	:	 to

In [13]:
label = torch.zeros([1, gen.c_dim], device=device)
z = torch.from_numpy(np.random.randn(1, gen.z_dim)).to(device)

w = gen.mapping(z, label,
                truncation_psi=truncation_psi)
img = gen.synthesis(w, noise_mode=noise_mode,
                    force_fp32 = True)
# print(type(img))
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
# 채널 순서 변경
PIL.Image.fromarray(
    img[0].cpu().numpy(), 
    'RGB'
).save(f'{outdir}/asdf.png')

In [14]:
discriminator = model['D']

In [15]:
torch.save(discriminator,Path(path,'discriminator.pth'))

In [16]:
generator = model['G']

In [17]:
torch.save(generator,Path(path,'generator.pth'))

In [18]:
model['training_set_kwargs']

{'class_name': 'training.dataset.ImageFolderDataset',
 'path': 'all_non_repeat.txt',
 'use_labels': False,
 'max_size': 231175,
 'xflip': True,
 'ceph': True,
 'resolution': 1024}

In [19]:
print(discriminator)

Discriminator(
  (b1024): DiscriminatorBlock(
    (fromrgb): Conv2dLayer()
    (conv0): Conv2dLayer()
    (conv1): Conv2dLayer()
    (skip): Conv2dLayer()
  )
  (b512): DiscriminatorBlock(
    (conv0): Conv2dLayer()
    (conv1): Conv2dLayer()
    (skip): Conv2dLayer()
  )
  (b256): DiscriminatorBlock(
    (conv0): Conv2dLayer()
    (conv1): Conv2dLayer()
    (skip): Conv2dLayer()
  )
  (b128): DiscriminatorBlock(
    (conv0): Conv2dLayer()
    (conv1): Conv2dLayer()
    (skip): Conv2dLayer()
  )
  (b64): DiscriminatorBlock(
    (conv0): Conv2dLayer()
    (conv1): Conv2dLayer()
    (skip): Conv2dLayer()
  )
  (b32): DiscriminatorBlock(
    (conv0): Conv2dLayer()
    (conv1): Conv2dLayer()
    (skip): Conv2dLayer()
  )
  (b16): DiscriminatorBlock(
    (conv0): Conv2dLayer()
    (conv1): Conv2dLayer()
    (skip): Conv2dLayer()
  )
  (b8): DiscriminatorBlock(
    (conv0): Conv2dLayer()
    (conv1): Conv2dLayer()
    (skip): Conv2dLayer()
  )
  (b4): DiscriminatorEpilogue(
    (mbstd): Mini

In [20]:
w.shape

torch.Size([1, 18, 512])

In [21]:
z.shape

torch.Size([1, 512])

In [22]:
label.shape

torch.Size([1, 0])

In [23]:
gen(z=z,c=label)

tensor([[[[0.8581, 0.9246, 0.9495,  ..., 0.9624, 0.9901, 0.9272],
          [0.9444, 0.9468, 0.9549,  ..., 1.0303, 0.9980, 1.0108],
          [0.9157, 0.9371, 0.9693,  ..., 0.9997, 1.0271, 0.9929],
          ...,
          [0.9527, 0.9666, 0.9930,  ..., 0.9643, 0.9672, 1.0001],
          [0.9337, 0.9437, 0.9804,  ..., 1.0000, 1.0071, 1.0174],
          [0.7385, 0.9678, 0.9935,  ..., 0.9320, 0.9092, 0.7725]],

         [[0.8422, 0.9247, 0.9781,  ..., 0.9773, 0.9813, 0.8590],
          [0.9539, 0.9645, 0.9919,  ..., 1.0254, 0.9849, 0.9386],
          [0.9256, 0.9550, 0.9760,  ..., 0.9954, 1.0218, 0.9470],
          ...,
          [0.9689, 0.9891, 1.0046,  ..., 0.9888, 0.9842, 0.9972],
          [0.9458, 0.9641, 0.9977,  ..., 1.0024, 1.0141, 1.0221],
          [0.7343, 0.9723, 0.9953,  ..., 0.9343, 0.8989, 0.8040]],

         [[0.8356, 0.9236, 0.9731,  ..., 0.9830, 0.9789, 0.8590],
          [0.9805, 0.9703, 0.9918,  ..., 1.0353, 1.0040, 0.9239],
          [0.9405, 0.9534, 0.9745,  ..., 1