In [13]:
import torch
import dnnlib
import legacy
import numpy as np
import pickle
import copy

# Training presets keyed by configuration name. Only `fmaps`, `map` and `mbstd`
# are read by the later cells of this notebook; the remaining fields are kept so
# the table stays complete.
cfg_specs = {
    # Placeholder values (-1) are meant to be populated dynamically based on
    # resolution and GPU count; no visible cell of this notebook does that.
    'auto': {
        'ref_gpus': -1, 'kimg': 25000, 'mb': -1, 'mbstd': -1, 'fmaps': -1,
        'lrate': -1, 'gamma': -1, 'ema': -1, 'ramp': 0.05, 'map': 2,
    },
    # Uses mixed-precision, unlike the original StyleGAN2.
    'stylegan2': {
        'ref_gpus': 8, 'kimg': 25000, 'mb': 32, 'mbstd': 4, 'fmaps': 1,
        'lrate': 0.002, 'gamma': 10, 'ema': 10, 'ramp': None, 'map': 8,
    },
    'paper256': {
        'ref_gpus': 8, 'kimg': 25000, 'mb': 64, 'mbstd': 8, 'fmaps': 0.5,
        'lrate': 0.0025, 'gamma': 1, 'ema': 20, 'ramp': None, 'map': 8,
    },
    'paper512': {
        'ref_gpus': 8, 'kimg': 25000, 'mb': 64, 'mbstd': 8, 'fmaps': 1,
        'lrate': 0.0025, 'gamma': 0.5, 'ema': 20, 'ramp': None, 'map': 8,
    },
    'paper1024': {
        'ref_gpus': 8, 'kimg': 25000, 'mb': 32, 'mbstd': 4, 'fmaps': 1,
        'lrate': 0.002, 'gamma': 2, 'ema': 10, 'ramp': None, 'map': 8,
    },
    'cifar': {
        'ref_gpus': 2, 'kimg': 100000, 'mb': 64, 'mbstd': 32, 'fmaps': 1,
        'lrate': 0.0025, 'gamma': 0.01, 'ema': 500, 'ramp': 0.05, 'map': 2,
    },
}

# Preset to build the target networks with, e.g. 'stylegan2' or 'paper256'.
cfg = 'stylegan2'
# Output resolution of the generator, e.g. 256 or 1024.
img_resolution = 256
# Location of the unofficial checkpoint to convert (placeholder; edit before running).
ckpt_path = '/path/to/unofficial/weight/path'

In [14]:
img_channels = 3

# Validate the chosen preset explicitly: an `assert` is skipped under `python -O`
# and gives no hint about the valid choices.
if cfg not in cfg_specs:
    raise ValueError(f"unknown cfg {cfg!r}; expected one of {sorted(cfg_specs)}")
spec = dnnlib.EasyDict(cfg_specs[cfg])

# The 'auto' preset stores -1 placeholders that nothing in this notebook resolves.
# Using them as-is would yield a negative channel_base / mbstd group size, so fail
# early with a clear message instead of building a broken network description.
unresolved = [field for field in ('fmaps', 'map', 'mbstd') if spec[field] < 0]
if unresolved:
    raise ValueError(
        f"cfg {cfg!r} leaves {unresolved} as unresolved placeholders; "
        "pick a concrete preset such as 'stylegan2' or 'paper256'"
    )

# Constructor descriptions for the generator and the discriminator.
args = dnnlib.EasyDict()
args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())

# Channel widths are shared by G and D and scale with the preset's `fmaps`.
args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
args.G_kwargs.mapping_kwargs.num_layers = spec.map
args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

# Generator implementation to instantiate (default: the regular generator).
args.G_kwargs["class_name"] = "training.networks.Generator"

# Pruned model: uncomment the two lines below to target the pruned generator instead.

# args.G_kwargs["class_name"] = "training.networks_stylekd.Generator"
# args.G_kwargs["prune_ratio"] = 0.7


In [15]:
device = 'cpu'

# Constructor arguments shared by the generator and the discriminator.
common_kwargs = {'c_dim': 0, 'img_resolution': img_resolution, 'img_channels': img_channels}

# Instantiate both networks from their kwargs descriptions, then switch them to
# eval mode, disable gradients and place them on `device`.
G_ema = dnnlib.util.construct_class_by_name(**args.G_kwargs, **common_kwargs)
G_ema = G_ema.eval().requires_grad_(False).to(device)
D = dnnlib.util.construct_class_by_name(**args.D_kwargs, **common_kwargs)  # subclass of torch.nn.Module
D = D.eval().requires_grad_(False).to(device)

# Only the 'g_ema' entry of the checkpoint is used.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only point `ckpt_path` at checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))['g_ema']

In [18]:
def convert_mapping(ckpt_from, G_ema):
    """Rename the mapping-network entries of ``ckpt_from`` in place.

    For every fully-connected layer ``i`` of ``G_ema.mapping`` the source keys
    ``style.{i+1}.weight`` / ``style.{i+1}.bias`` are moved to
    ``mapping.fc{i}.weight`` / ``mapping.fc{i}.bias``. Values are not modified.

    The call is safe to repeat: a source key that is absent while its target
    key already exists is treated as already converted.

    Args:
        ckpt_from: Mutable mapping of parameter name to value; modified in place.
        G_ema: Target generator. Only ``G_ema.mapping.state_dict()`` is read,
            to determine how many ``fc{i}`` layers exist.

    Raises:
        KeyError: If a source key is missing and its target key is not present.
    """
    def _move(src_key, dst_key):
        # Move one entry, tolerating a previous successful conversion.
        if src_key not in ckpt_from:
            if dst_key in ckpt_from:
                return
            raise KeyError(
                f"neither source key {src_key!r} nor converted key {dst_key!r} "
                "is present in the checkpoint"
            )
        ckpt_from[dst_key] = ckpt_from[src_key]
        del ckpt_from[src_key]

    # Count the fc layers by name instead of assuming that the state dict holds
    # exactly one entry besides the weight/bias pairs.
    suffix = ".weight"
    num_layers = sum(
        1
        for key in G_ema.mapping.state_dict()
        if key.startswith("fc") and key.endswith(suffix) and key[2:-len(suffix)].isdigit()
    )

    for idx in range(num_layers):
        # Source layers are numbered from 1, target layers from 0.
        _move(f"style.{idx+1}.weight", f"mapping.fc{idx}.weight")
        _move(f"style.{idx+1}.bias", f"mapping.fc{idx}.bias")

In [19]:
def convert_torgb(ckpt_from, G_ema):
    """Rename and reshape the to-RGB entries of ``ckpt_from`` in place.

    ``to_rgb1.*`` is moved to ``synthesis.b4.torgb.*`` and ``to_rgbs.{i}.*`` to
    ``synthesis.b{res}.torgb.*`` for the resolutions 8, 16, ... up to
    ``G_ema.img_resolution``. Per layer:

    * ``conv.weight``            -> ``weight`` (index 0 along the first axis)
    * ``bias``                   -> ``bias`` (``.squeeze()`` applied)
    * ``conv.modulation.weight`` -> ``affine.weight`` (unchanged)
    * ``conv.modulation.bias``   -> ``affine.bias`` (unchanged)

    The call is safe to repeat: a source key that is absent while its target
    key already exists is treated as already converted.

    Args:
        ckpt_from: Mutable mapping of parameter name to array/tensor; modified
            in place.
        G_ema: Target generator. Only ``G_ema.img_resolution`` is read; it is
            assumed to be a power of two.

    Raises:
        KeyError: If a source key is missing and its target key is not present.
    """
    def _move(src_key, dst_key, transform=None):
        # Move one entry, tolerating a previous successful conversion.
        if src_key not in ckpt_from:
            if dst_key in ckpt_from:
                return
            raise KeyError(
                f"neither source key {src_key!r} nor converted key {dst_key!r} "
                "is present in the checkpoint"
            )
        value = ckpt_from[src_key]
        ckpt_from[dst_key] = value if transform is None else transform(value)
        del ckpt_from[src_key]

    def _convert_layer(src_prefix, dst_prefix):
        # Convert the four entries that make up one to-RGB layer.
        _move(f"{src_prefix}.conv.weight", f"{dst_prefix}.weight", lambda t: t[0])
        _move(f"{src_prefix}.bias", f"{dst_prefix}.bias", lambda t: t.squeeze())
        _move(f"{src_prefix}.conv.modulation.weight", f"{dst_prefix}.affine.weight")
        _move(f"{src_prefix}.conv.modulation.bias", f"{dst_prefix}.affine.bias")

    # Block resolutions 4, 8, ..., img_resolution computed with integer math.
    num_blocks = int(round(float(np.log2(G_ema.img_resolution)))) - 1
    resolutions = [4 * 2 ** exponent for exponent in range(num_blocks)]

    # The 4x4 block has a dedicated source name; the others are numbered from 0.
    _convert_layer("to_rgb1", "synthesis.b4.torgb")
    for idx, res in enumerate(resolutions[1:]):
        _convert_layer(f"to_rgbs.{idx}", f"synthesis.b{res}.torgb")

In [20]:
def convert_const(ckpt_from, G_ema):
    """Move the constant synthesis input of ``ckpt_from`` to its target key.

    ``input.input`` is replaced by ``synthesis.b4.const``, keeping index 0 along
    the first axis of the stored value.

    The call is safe to repeat: if the source key is absent while the target
    key already exists, nothing is changed.

    Args:
        ckpt_from: Mutable mapping of parameter name to array/tensor; modified
            in place.
        G_ema: Unused; accepted so all converters share one signature.

    Raises:
        KeyError: If neither the source key nor the target key is present.
    """
    from_const_key = "input.input"
    to_const_key = "synthesis.b4.const"

    if from_const_key not in ckpt_from:
        if to_const_key in ckpt_from:
            return  # already converted on a previous run
        raise KeyError(
            f"neither source key {from_const_key!r} nor converted key "
            f"{to_const_key!r} is present in the checkpoint"
        )

    ckpt_from[to_const_key] = ckpt_from[from_const_key][0]
    del ckpt_from[from_const_key]

In [21]:
def convert_convs(ckpt_from, G_ema):
    """Rename and reshape the synthesis convolution entries of ``ckpt_from`` in place.

    ``conv1.*`` is moved to ``synthesis.b4.conv1.*``. For each further
    resolution (8, 16, ... up to ``G_ema.img_resolution``, numbered from 0) the
    pair ``convs.{2*i}.*`` / ``convs.{2*i+1}.*`` is moved to
    ``synthesis.b{res}.conv0.*`` / ``synthesis.b{res}.conv1.*``. Per layer:

    * ``conv.weight``            -> ``weight`` (index 0 along the first axis)
    * ``activate.bias``          -> ``bias`` (unchanged)
    * ``conv.modulation.weight`` -> ``affine.weight`` (unchanged)
    * ``conv.modulation.bias``   -> ``affine.bias`` (unchanged)
    * ``noise.weight``           -> ``noise_strength`` (``.squeeze()`` applied)

    The call is safe to repeat: a source key that is absent while its target
    key already exists is treated as already converted.

    Args:
        ckpt_from: Mutable mapping of parameter name to array/tensor; modified
            in place.
        G_ema: Target generator. Only ``G_ema.img_resolution`` is read; it is
            assumed to be a power of two.

    Raises:
        KeyError: If a source key is missing and its target key is not present.
    """
    def _move(src_key, dst_key, transform=None):
        # Move one entry, tolerating a previous successful conversion.
        if src_key not in ckpt_from:
            if dst_key in ckpt_from:
                return
            raise KeyError(
                f"neither source key {src_key!r} nor converted key {dst_key!r} "
                "is present in the checkpoint"
            )
        value = ckpt_from[src_key]
        ckpt_from[dst_key] = value if transform is None else transform(value)
        del ckpt_from[src_key]

    def _convert_layer(src_prefix, dst_prefix):
        # Convert the five entries that make up one styled convolution layer.
        _move(f"{src_prefix}.conv.weight", f"{dst_prefix}.weight", lambda t: t[0])
        _move(f"{src_prefix}.activate.bias", f"{dst_prefix}.bias")
        _move(f"{src_prefix}.conv.modulation.weight", f"{dst_prefix}.affine.weight")
        _move(f"{src_prefix}.conv.modulation.bias", f"{dst_prefix}.affine.bias")
        _move(f"{src_prefix}.noise.weight", f"{dst_prefix}.noise_strength", lambda t: t.squeeze())

    # Block resolutions 4, 8, ..., img_resolution computed with integer math.
    num_blocks = int(round(float(np.log2(G_ema.img_resolution)))) - 1
    resolutions = [4 * 2 ** exponent for exponent in range(num_blocks)]

    # The 4x4 block only has a single convolution with a dedicated source name.
    _convert_layer("conv1", "synthesis.b4.conv1")

    # Every later block owns two consecutive entries of the flat `convs` list.
    for idx, res in enumerate(resolutions[1:]):
        for conv_idx in (0, 1):
            _convert_layer(f"convs.{2 * idx + conv_idx}", f"synthesis.b{res}.conv{conv_idx}")

In [22]:
# Rewrite the checkpoint keys to the layout of `G_ema`. Every converter edits
# `ckpt` in place and handles its own family of keys.
convert_mapping(ckpt_from=ckpt, G_ema=G_ema)
convert_torgb(ckpt_from=ckpt, G_ema=G_ema)
convert_const(ckpt_from=ckpt, G_ema=G_ema)
convert_convs(ckpt_from=ckpt, G_ema=G_ema)

In [24]:
# Load the converted weights. `strict=False` is needed because some entries of
# G_ema (the recorded output lists `resample_filter` and `noise_const` ones) have
# no counterpart in the converted checkpoint.
load_result = G_ema.load_state_dict(ckpt, strict=False)

# `strict=False` would also hide a botched conversion: any learnable parameter
# that is reported missing keeps its random initialisation. Fail loudly then.
learnable_names = {param_name for param_name, _ in G_ema.named_parameters()}
missing_learnable = sorted(learnable_names.intersection(load_result.missing_keys))
if missing_learnable:
    raise RuntimeError(
        f"{len(missing_learnable)} learnable parameter(s) were not provided by the "
        f"converted checkpoint: {missing_learnable}"
    )

# Keep the load report as the cell output.
load_result

_IncompatibleKeys(missing_keys=['synthesis.b4.resample_filter', 'synthesis.b4.conv1.resample_filter', 'synthesis.b4.conv1.noise_const', 'synthesis.b8.resample_filter', 'synthesis.b8.conv0.resample_filter', 'synthesis.b8.conv0.noise_const', 'synthesis.b8.conv1.resample_filter', 'synthesis.b8.conv1.noise_const', 'synthesis.b16.resample_filter', 'synthesis.b16.conv0.resample_filter', 'synthesis.b16.conv0.noise_const', 'synthesis.b16.conv1.resample_filter', 'synthesis.b16.conv1.noise_const', 'synthesis.b32.resample_filter', 'synthesis.b32.conv0.resample_filter', 'synthesis.b32.conv0.noise_const', 'synthesis.b32.conv1.resample_filter', 'synthesis.b32.conv1.noise_const', 'synthesis.b64.resample_filter', 'synthesis.b64.conv0.resample_filter', 'synthesis.b64.conv0.noise_const', 'synthesis.b64.conv1.resample_filter', 'synthesis.b64.conv1.noise_const', 'synthesis.b128.resample_filter', 'synthesis.b128.conv0.resample_filter', 'synthesis.b128.conv0.noise_const', 'synthesis.b128.conv1.resample_filt

In [25]:
# Assemble the pickle payload: 'G' and 'G_ema' are separate deep copies of the
# converted generator, 'D' is a deep copy of the freshly constructed discriminator.
converted_pickle = dict()
for name, module in zip(('G', 'D', 'G_ema'), (G_ema, D, G_ema)):
    module = (
        copy.deepcopy(module)
        .eval()
        .requires_grad_(False)
        .cpu()
    )
    converted_pickle[name] = module

# Write the payload next to the notebook.
with open('converted_model.pkl', 'wb') as f:
    pickle.dump(converted_pickle, f)

In [26]:
# Round-trip check: read the exported pickle back through the project's loader
# and keep its 'G_ema' entry.
with dnnlib.util.open_url('converted_model.pkl') as pkl_file:
    network_pkl = legacy.load_network_pkl(pkl_file)
loaded_G = network_pkl['G_ema']