<a href="https://colab.research.google.com/github/HighCWu/anime_biggan_toy/blob/main/colab/pytorch_anime_biggan_for_discriminator_converter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Save weights

In [None]:
# Recursively copy the checkpoint directory from ./drive/My Drive into the working directory
# (assumes Google Drive is already mounted at ./drive — TODO confirm).
!cp drive/My\ Drive/anime-biggan-256px-run39-607250 ./ -r

In [None]:
import os
os.makedirs('images', exist_ok=True)
filelist = [
    '512px/0999/999999.jpg',
    '512px/0999/999.jpg',
    '512px/0999/998999.jpg',
    '512px/0999/997999.jpg'
]
for i in range(4):
    path = filelist[-i-1]
    print(f'Rsync image from rsync://78.46.86.149:873/danbooru2019/{path} to directory "images"')
    !rsync rsync://78.46.86.149:873/danbooru2019/$path ./images


In [None]:
import glob
import numpy as np
from PIL import Image
# Sort the paths: glob returns files in arbitrary order, and a fixed order keeps
# the batch (and therefore the dumped reference tensors) reproducible.
imgs_path = sorted(glob.glob('./images/*.jpg'))
if not imgs_path:
    raise FileNotFoundError('No .jpg files found in ./images; run the download cell first.')

# 256x256 window starting at pixel (127, 127).
crop_box = (127, 127, 127 + 256, 127 + 256)

imgs = []
for path in imgs_path:
    # The context manager closes the underlying file as soon as the pixels are copied out.
    with Image.open(path) as pil_img:
        # Force three channels so grayscale / CMYK JPEGs still stack into one batch.
        cropped = pil_img.convert('RGB').crop(crop_box)
        # Add a leading batch axis and scale the 8-bit values to [0, 1].
        img = (np.asarray(cropped)[None, ...] / 255.0).astype('float32')
    imgs.append(img)
imgs = np.concatenate(imgs, 0)

In [None]:
import os
import tensorflow.compat.v1 as tf
# The code below uses the TF1 graph/session API (placeholders, tf.Session), so v2 behavior is disabled.
tf.disable_v2_behavior()

import tensorflow_hub as hub

# Location of the exported hub module inside the checkpoint directory.
module_path = os.path.join('anime-biggan-256px-run39-607250', "tfhub")
tf.reset_default_graph()
module = hub.Module(module_path, 
            name="disc_module", tags={"disc", "bsNone"})
print('Loaded BigGAN module from:', module_path)

# Initialize the variables that exist in the graph at this point and keep the session for later cells.
initializer = tf.global_variables_initializer()
sess = tf.Session()
sess.run(initializer)

batch_size = 4
images = tf.placeholder(shape=[batch_size, 256, 256, 3], dtype=tf.float32)  # image batch, supplied through feed_dict
# Labels are a random-integer op in [0, 1000), not a placeholder, so they are sampled when the graph runs.
labels = tf.random.uniform([batch_size], maxval=1000, dtype=tf.int32, seed=0)
inputs = dict(images=images, labels=labels)

# Apply the module and keep its "prediction" output tensor.
prediction = module(inputs, as_dict=True)["prediction"]

In [None]:
# Print the name and static shape of every tensor whose name contains the discriminator scope.
for op in sess.graph.get_operations():
    for tensor in op.values():
        if 'disc_module_apply_default/discriminator' in tensor.name:
            print(tensor.name, tensor.shape)

In [None]:
import pickle

# Evaluate every global variable and keep it as a [name, value] pair, in graph order.
var_list = []
for var in tf.global_variables():
    var_list.append([var.name, sess.run(var)])

for name, value in var_list:
    print(name, value.shape)

# `with` guarantees the file is flushed and closed even if pickling fails.
with open('tf_discriminator.pkl', 'wb') as f:
    pickle.dump(var_list, f)

In [None]:
import collections
import pickle

# Name fragments of the intermediate tensors to capture.
tensors_name_con = [
    'conv1/add:0', 'conv2/add:0', 'conv_shortcut/add:0', 
    'conv2d_theta/Conv2D:0', 'conv2d_phi/Conv2D:0', 'conv2d_g/Conv2D:0', 'conv2d_attn_g/Conv2D:0',
    'final_fc/add:0', 'embedding_fc/MatMul_4:0'
]

# Ordered so the dump lists the inputs first, then the matching graph tensors
# in graph order, then the final prediction.
tensor_dict = collections.OrderedDict()
tensor_dict['images'] = images
tensor_dict['labels'] = labels
for op in sess.graph.get_operations():
    for tensor in op.values():
        if 'disc_module_apply_default/discriminator' not in tensor.name:
            continue
        if any(name_con in tensor.name for name_con in tensors_name_con):
            tensor_dict[tensor.name] = tensor
tensor_dict['prediction'] = prediction

# A single run evaluates everything together, so the sampled labels and the
# activations in the dump belong to the same forward pass.
ret = sess.run(tensor_dict, feed_dict={images:imgs})

for name, value in ret.items():
    print(name, value.shape)

# `with` guarantees the file is flushed and closed even if pickling fails.
with open('tf_tensor_samples.pkl', 'wb') as f:
    pickle.dump(ret, f)

## Convert weights
You may want to restart the kernel to release GPU memory after generating the samples with TensorFlow.

In [None]:
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F

import collections
# Keys of the activation capture buffer, in the order they are later compared.
tensors_name_con = [
    'images',
    'labels',
    'conv1/add:0',
    'conv2/add:0',
    'conv_shortcut/add:0',
    'conv2d_theta/Conv2D:0',
    'conv2d_phi/Conv2D:0',
    'conv2d_g/Conv2D:0',
    'conv2d_attn_g/Conv2D:0',
    'final_fc/add:0',
    'embedding_fc/MatMul_4:0',
    'prediction',
]

# Capture buffer: one independent, initially empty list per key.
gt = collections.OrderedDict((name, []) for name in tensors_name_con)

def l2_normalize(v, dim=None, eps=1e-12):
    """Divide `v` by its L2 norm taken along `dim`.

    `eps` is added to the norm before dividing, so an all-zero input yields
    zeros instead of NaN.
    """
    length = v.norm(dim=dim, keepdim=True)
    return v / (length + eps)
    
 
def unpool(value):
    """Zero-insertion unpooling: double both spatial sizes of a feature map.

    Version of the unpooling operation from
    https://www.robots.ox.ac.uk/~vgg/rg/papers/Dosovitskiy_Learning_to_Generate_2015_CVPR_paper.pdf
    Taken from: https://github.com/tensorflow/tensorflow/issues/2169
    Args:
        value: a 4-D Tensor in channels-first layout [b, ch, d0, d1]; it is
            permuted to channels-last internally and permuted back at the end.
    Returns:
        A Tensor of shape [b, ch, 2*d0, 2*d1]. Every input value lands at the
        even (row, column) position of the output; all other entries are zero.
    """
    channels_last = value.permute(0, 2, 3, 1)
    shape = list(channels_last.shape)
    n_spatial = len(shape) - 2
    # Fold the leading axes together so only the last `n_spatial` axes stay separate.
    out = channels_last.reshape([-1] + shape[-n_spatial:])
    # Append an equally sized block of zeros along each trailing axis, last axis first.
    for axis in reversed(range(1, n_spatial + 1)):
        out = torch.cat([out, torch.zeros_like(out)], axis)
    doubled = [-1] + [2 * size for size in shape[1:-1]] + [shape[-1]]
    return out.reshape(doubled).permute(0, 3, 1, 2)
 
 
class BatchNorm2d(nn.BatchNorm2d):
    """`nn.BatchNorm2d` with an extra mode that averages batch statistics.

    While `accumulating` is on, every forward pass adds the batch mean and
    (biased) batch variance to running sums, then normalises the input in eval
    mode with the averaged sums. `running_mean` / `running_var` are restored
    afterwards. On the first forward pass, previously accumulated statistics
    (if any) are copied into `running_mean` / `running_var`.
    """

    # Starting value of the accumulation counter; keeps the averaging division
    # well-defined before any batch has been accumulated.
    _COUNTER_EPS = 1e-12

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initialized = False
        self.accumulating = False
        # Read the size from the base class so num_features may be passed
        # positionally or by keyword.
        num_features = self.num_features
        self.accumulated_mean = Parameter(torch.zeros(num_features), requires_grad=False)
        self.accumulated_var = Parameter(torch.zeros(num_features), requires_grad=False)
        self.accumulated_counter = Parameter(torch.zeros(1) + self._COUNTER_EPS, requires_grad=False)

    def forward(self, inputs, *args, **kwargs):
        """Normalise `inputs`; in accumulating mode also update the statistic sums."""
        if not self.initialized:
            self.check_accumulation()
            self.set_initialized(True)
        if not self.accumulating:
            return super().forward(inputs, *args, **kwargs)

        # Accumulation always normalises with fixed statistics, hence eval mode.
        self.eval()
        with torch.no_grad():
            # Reduce over every axis except the channel axis (axis 1).
            axes = [0] + list(range(2, inputs.dim()))
            batch_mean = torch.mean(inputs, axes, keepdim=True)
            batch_var = torch.mean((inputs - batch_mean) ** 2, axes)
            self.accumulated_mean.copy_(self.accumulated_mean + batch_mean.reshape(-1))
            self.accumulated_var.copy_(self.accumulated_var + batch_var)
            self.accumulated_counter.copy_(self.accumulated_counter + 1)

            # Temporarily swap the averaged statistics into the running buffers,
            # which are what the base class reads in eval mode.
            saved_mean = self.running_mean.clone()
            saved_var = self.running_var.clone()
            self.running_mean.copy_(self.accumulated_mean / self.accumulated_counter)
            self.running_var.copy_(self.accumulated_var / self.accumulated_counter)
            try:
                out = super().forward(inputs, *args, **kwargs)
            finally:
                self.running_mean.copy_(saved_mean)
                self.running_var.copy_(saved_var)
            return out

    def check_accumulation(self):
        """Load the averaged statistics into the running buffers.

        Returns True if at least one batch had been accumulated (and the
        buffers were overwritten), False otherwise.
        """
        if self.accumulated_counter.item() > 1 - 1e-12:
            with torch.no_grad():
                self.running_mean.copy_(self.accumulated_mean / self.accumulated_counter)
                self.running_var.copy_(self.accumulated_var / self.accumulated_counter)
            return True
        return False

    def clear_accumulated(self):
        """Reset the statistic sums and the counter to their initial values."""
        with torch.no_grad():
            self.accumulated_mean.zero_()
            self.accumulated_var.zero_()
            # Same starting value as in __init__, so a cleared module behaves
            # like a freshly constructed one.
            self.accumulated_counter.fill_(self._COUNTER_EPS)

    def set_accumulating(self, status=True):
        """Switch accumulating mode on (truthy `status`) or off."""
        self.accumulating = bool(status)

    def set_initialized(self, status=False):
        """Mark whether the first-forward statistics check has already run."""
        self.initialized = bool(status)
 

class SpectralNorm(nn.Module):
    """Wrap a module and divide its weight by an estimated spectral norm.

    The raw weight is moved from the wrapped module onto this wrapper as
    `weight`, next to the persistent power-iteration vector `weight_u`. On
    every forward pass the estimate is refreshed with `power_iterations` steps
    and the rescaled weight is assigned to the wrapped module as a plain
    tensor attribute before the module is called.
    """

    def __init__(self, module, name='weight', power_iterations=2):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _as_matrix(self, w):
        """Flatten the weight `w` into the 2-D matrix used by the power iteration."""
        if w.dim() == 4:
            # 4-D (conv) weight: move axis 0 last, then fold the other axes together.
            return w.permute(2, 3, 1, 0).reshape(-1, w.shape[0])
        if isinstance(self.module, (nn.Linear, nn.Embedding)):
            # 2-D weight of Linear / Embedding: transpose.
            return w.permute(1, 0).reshape(-1, w.shape[0])
        return w.reshape(-1, w.shape[-1])

    @staticmethod
    def _iteration_side(matrix):
        """Return ("left" | "right", norm_dim) depending on which side of `matrix` is shorter."""
        if matrix.shape[0] <= matrix.shape[1]:
            return "left", 0
        return "right", 1

    def _update_u(self):
        """Run the power iteration, install the rescaled weight, and store the new `u`."""
        w = self.weight
        u = self.weight_u

        w_mat = self._as_matrix(w)
        side, norm_dim = self._iteration_side(w_mat)
        for _ in range(self.power_iterations):
            if side == "left":
                v = l2_normalize(torch.matmul(w_mat.t(), u), dim=norm_dim)
                u = l2_normalize(torch.matmul(w_mat, v), dim=norm_dim)
            else:
                v = l2_normalize(torch.matmul(u, w_mat.t()), dim=norm_dim)
                u = l2_normalize(torch.matmul(v, w_mat), dim=norm_dim)

        # sigma has shape (1, 1) on both sides and broadcasts over w below.
        if side == "left":
            sigma = torch.matmul(torch.matmul(u.t(), w_mat), v)
        else:
            sigma = torch.matmul(torch.matmul(v, w_mat), u.t())
        setattr(self.module, self.name, w / sigma)
        self.weight_u.copy_(u.detach())

    def _made_params(self):
        """Return True if `weight` and `weight_u` already exist on this wrapper."""
        return hasattr(self, 'weight') and hasattr(self, 'weight_u')

    def _make_params(self):
        """Take over the wrapped module's weight and create a unit-norm `u` vector."""
        w = getattr(self.module, self.name)

        w_mat = self._as_matrix(w)
        side, norm_dim = self._iteration_side(w_mat)
        u_shape = (w_mat.shape[0], 1) if side == "left" else (1, w_mat.shape[-1])

        u = Parameter(w.data.new(*u_shape).normal_(0, 1), requires_grad=False)
        u.copy_(l2_normalize(u, dim=norm_dim).detach())

        # Remove the raw parameter from the wrapped module so that the attribute
        # of the same name can later hold the rescaled (non-parameter) tensor.
        del self.module._parameters[self.name]
        # Registration order (weight, then weight_u) fixes the order in which
        # named_parameters() yields them.
        self.weight = w
        self.weight_u = u

    def forward(self, *args, **kwargs):
        """Refresh the normalised weight, then call the wrapped module."""
        self._update_u()
        return self.module.forward(*args, **kwargs)
    
    
class SelfAttention(nn.Module):
    """Self-attention block built from spectrally normalised 1x1 convolutions.

    On every forward pass the outputs of the four convolutions are appended,
    permuted to channels-last, to the module-level `gt` capture dict.
    """

    def __init__(self, in_dim, activation=torch.relu):
        super().__init__()
        self.chanel_in = in_dim
        self.activation = activation

        # The registration order below fixes the order of named_parameters().
        self.theta = SpectralNorm(nn.Conv2d(in_dim, in_dim // 8, 1, bias=False))
        self.phi = SpectralNorm(nn.Conv2d(in_dim, in_dim // 8, 1, bias=False))
        self.pool = nn.MaxPool2d(2, 2)
        self.g = SpectralNorm(nn.Conv2d(in_dim, in_dim // 2, 1, bias=False))
        self.o_conv = SpectralNorm(nn.Conv2d(in_dim // 2, in_dim, 1, bias=False))
        self.gamma = Parameter(torch.zeros(1))

    def forward(self, x):
        """Return `gamma * attention_output + x` for a 4-D input `x`."""
        batch, _, size_a, size_b = x.shape
        n_pos = size_b * size_a

        def record(key, feature_map):
            # Store a channels-last view under the given capture key.
            gt[key].append(feature_map.permute(0, 2, 3, 1))

        theta = self.theta(x)
        record('conv2d_theta/Conv2D:0', theta)
        phi = self.phi(x)
        record('conv2d_phi/Conv2D:0', phi)

        # Keys (phi) are max-pooled 2x2, so they cover a quarter of the positions.
        phi = self.pool(phi).reshape(batch, -1, n_pos // 4)
        theta = theta.reshape(batch, -1, n_pos).permute(0, 2, 1)
        attention = torch.softmax(torch.bmm(theta, phi), -1)

        g = self.g(x)
        record('conv2d_g/Conv2D:0', g)
        g = self.pool(g).reshape(batch, -1, n_pos // 4)

        attn_g = torch.bmm(g, attention.permute(0, 2, 1)).reshape(batch, -1, size_a, size_b)
        out = self.o_conv(attn_g)
        record('conv2d_attn_g/Conv2D:0', out)
        return self.gamma * out + x
 
 
class ConditionalBatchNorm2d(nn.Module):
    """Batch norm whose per-channel scale and shift are linear functions of a condition.

    The normalisation itself has no affine parameters; scale and shift come
    from two bias-free, spectrally normalised linear layers applied to `y`.
    """

    def __init__(self, num_features, num_classes, eps=1e-5, momentum=0.1):
        super().__init__()
        # The registration order below fixes the order of named_parameters().
        self.bn_in_cond = BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum)
        self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
        self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))

    def forward(self, x, y):
        """Normalise `x`, then apply the scale and shift predicted from `y`."""
        normed = self.bn_in_cond(x)
        # The predicted scale is used as-is (no +1 offset is applied).
        scale = self.gamma_embed(y)
        shift = self.beta_embed(y)
        # Reshape to (batch, channels, 1, 1) so both broadcast over the spatial axes.
        scale = torch.reshape(scale, (scale.shape[0], -1, 1, 1))
        shift = torch.reshape(shift, (shift.shape[0], -1, 1, 1))
        return scale * normed + shift
 

class ResBlock(nn.Module):
    """Residual block: two spectrally normalised convolutions with optional extras.

    Optional parts, selected by constructor flags:
      * conditional batch norm before each convolution (`conditional`),
      * zero-insertion upsampling before the first convolution (`upsample`),
      * 2x2 average pooling after the second convolution (`downsample`),
      * a 1x1 convolution shortcut, created only when the block resamples and
        `skip_proj` is not True; without it the block returns the convolution
        branch alone (nothing is added to it),
      * a self-attention layer applied to the block output (`use_attention`).

    Convolution outputs are appended, permuted to channels-last, to the
    module-level `gt` capture dict on every forward pass.
    `n_class` is accepted but not used by this class.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=[3, 3],
        padding=1,
        stride=1,
        n_class=None,
        conditional=True,
        activation=torch.relu,
        upsample=True,
        downsample=False,
        z_dim=128,
        use_attention=False,
        skip_proj=None
    ):
        super().__init__()

        # The registration order below fixes the order of named_parameters().
        if conditional:
            self.cond_norm1 = ConditionalBatchNorm2d(in_channel, z_dim)

        self.conv0 = SpectralNorm(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)
        )

        if conditional:
            self.cond_norm2 = ConditionalBatchNorm2d(out_channel, z_dim)

        self.conv1 = SpectralNorm(
            nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding)
        )

        # self.skip_proj records whether the 1x1 shortcut convolution exists.
        self.skip_proj = False
        if skip_proj is not True and (upsample or downsample):
            self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))
            self.skip_proj = True

        if use_attention:
            self.attention = SelfAttention(out_channel)

        self.upsample = upsample
        self.downsample = downsample
        self.activation = activation
        self.conditional = conditional
        self.use_attention = use_attention

    def forward(self, input, condition=None):
        """Apply the block to `input`; `condition` is only read when `conditional` is set."""
        out = input

        if self.conditional:
            out = self.cond_norm1(out, condition)
        out = self.activation(out)
        if self.upsample:
            out = unpool(out) # out = F.interpolate(out, scale_factor=2)
            # 'unpool:0' is not one of the pre-created capture keys, so create
            # its list on first use instead of raising KeyError.
            gt.setdefault('unpool:0', []).append(torch.Tensor.permute(out,[0,2,3,1]))
        out = self.conv0(out)
        gt['conv1/add:0'].append(torch.Tensor.permute(out,[0,2,3,1]))
        if self.conditional:
            out = self.cond_norm2(out, condition)
        out = self.activation(out)
        out = self.conv1(out)
        gt['conv2/add:0'].append(torch.Tensor.permute(out,[0,2,3,1]))

        if self.downsample:
            out = F.avg_pool2d(out, 2, 2)

        if self.skip_proj:
            # The shortcut gets the same resampling as the main branch.
            skip = input
            if self.upsample:
                skip = unpool(skip) # skip = F.interpolate(skip, scale_factor=2)
            skip = self.conv_sc(skip)
            gt['conv_shortcut/add:0'].append(torch.Tensor.permute(skip,[0,2,3,1]))
            if self.downsample:
                skip = F.avg_pool2d(skip, 2, 2)
            out = out + skip

        if self.use_attention:
            out = self.attention(out)

        return out
 
 
class Discriminator(nn.Module):
    """Class-conditional discriminator: a stack of residual blocks plus a projection head.

    The block stack is followed by ReLU, a sum over the spatial axes, a
    spectrally normalised linear layer, and a class-embedding inner product.
    The output is the sigmoid of the sum of the last two.

    Args:
        n_class: number of rows of the class embedding.
        chn: base channel width; the per-block widths are multiples of it.
        blocks_with_attention: comma-separated block ids such as "B2"
            (1-based) whose block gets a self-attention layer.
        resolution: input resolution; selects the channel multiplier table.
    """

    def __init__(self, n_class=1000, chn=96, blocks_with_attention="B2", resolution=256): 
        super().__init__()

        def DBlock(in_channel, out_channel, downsample=True, use_attention=False, skip_proj=None):
            return ResBlock(in_channel, out_channel, conditional=False, upsample=False, 
                        downsample=downsample, use_attention=use_attention, skip_proj=skip_proj)

        self.chn = chn
        self.colors = 3
        self.resolution = resolution  
        self.blocks_with_attention = set(blocks_with_attention.split(",")) 
        self.blocks_with_attention.discard('')

        dblock = []
        in_channels, out_channels = self.get_in_out_channels()

        # "B2" -> 2: 1-based indices of the blocks that get attention.
        self.sa_ids = [int(s.split('B')[-1]) for s in self.blocks_with_attention]

        # Every block but the last downsamples. A block whose input and output
        # widths are equal is built without the 1x1 shortcut convolution.
        for i, (nc_in, nc_out) in enumerate(zip(in_channels[:-1], out_channels[:-1])):
            dblock.append(DBlock(nc_in, nc_out, downsample=True, 
                        use_attention=(i+1) in self.sa_ids, skip_proj=nc_in==nc_out))
        dblock.append(DBlock(in_channels[-1], out_channels[-1], downsample=False, 
                        use_attention=len(out_channels) in self.sa_ids, skip_proj=in_channels[-1]==out_channels[-1]))
        self.blocks = nn.ModuleList(dblock)

        self.final_fc = SpectralNorm(nn.Linear(16 * chn, 1))

        self.embed_y = nn.Embedding(n_class, 16 * chn)
        self.embed_y.weight.data.uniform_(-0.1, 0.1)
        self.embed_y = SpectralNorm(self.embed_y)

    def get_in_out_channels(self):
        """Return `(in_channels, out_channels)`, one entry per block, for `self.resolution`.

        Raises:
            ValueError: if the resolution has no multiplier table.
        """
        multipliers_by_resolution = {
            1024: [1, 1, 1, 2, 4, 8, 8, 16, 16],
            512: [1, 1, 2, 4, 8, 8, 16, 16],
            256: [1, 2, 4, 8, 8, 16, 16],
            128: [1, 2, 4, 8, 16, 16],
            64: [2, 4, 8, 16, 16],
            32: [2, 2, 2, 2],
        }
        resolution = self.resolution
        if resolution not in multipliers_by_resolution:
            raise ValueError("Unsupported resolution: {}".format(resolution))
        out_channels = [self.chn * c for c in multipliers_by_resolution[resolution]]
        # Each block consumes what the previous one produced; the first consumes the image.
        in_channels = [self.colors] + out_channels[:-1]
        return in_channels, out_channels

    def forward(self, input, class_id):
        """Score the images `input` against the class ids `class_id`.

        Empties every list in the module-level `gt` capture dict first, so
        after the call `gt` holds the activations of this pass only.
        """
        for captured in gt.values():
            captured.clear()
        gt['images'].append(torch.Tensor.permute(input, [0,2,3,1]))
        out = input
        for dblock in self.blocks:
            out = dblock(out)
        out = torch.relu(out)
        out = torch.sum(out, [2,3])
        out_linear = self.final_fc(out)
        gt['final_fc/add:0'].append(out_linear)
        gt['labels'].append(class_id)
        class_emb = self.embed_y(class_id) 
        gt['embedding_fc/MatMul_4:0'].append(class_emb)

        # Projection term: inner product of the class embedding and the pooled features.
        prod = torch.sum((class_emb * out), 1, keepdim=True)

        # Compute the sigmoid once and use it for both the capture and the return value.
        prediction = torch.sigmoid(out_linear + prod)
        gt['prediction'].append(prediction)

        return prediction


In [None]:
# Smoke test: build the discriminator and push one random batch through it.
img = torch.rand(4,3,256,256).cuda()
y = torch.randint(0,1000,[4]).cuda()
print(img.shape)
d_256 = Discriminator(n_class=1000, chn=96, blocks_with_attention="B2", resolution=256).cuda()
# Inference only. Without no_grad, the activations stored in the global `gt`
# capture dict would keep the whole autograd graph alive on the GPU.
with torch.no_grad():
    pred = d_256(img, y)
print(pred.shape)

In [None]:
import pickle

# NOTE(security): pickle.load can execute arbitrary code. Only load the file
# written by the "Save weights" section of this notebook.
with open('tf_discriminator.pkl', 'rb') as f:
    tf_weights = pickle.load(f)


def pd_filter(x):
    """Filter hook for (name, parameter) pairs; currently keeps every pair."""
    return True


def _move_attention_gamma(named_params, shift=6):
    """Return a copy of `named_params` with each 'attention.gamma' entry moved `shift` places later.

    The `shift` entries after gamma each move up one slot and gamma takes the
    slot after them (presumably where the TF variable list has it — verify
    against the printed pairs below). If fewer than `shift` entries follow,
    gamma simply moves to the end.
    """
    reordered = list(named_params)
    for idx, (name, _) in enumerate(named_params):
        if 'attention.gamma' in name:
            followers = named_params[idx + 1:idx + 1 + shift]
            reordered[idx:idx + 1 + shift] = followers + [named_params[idx]]
    return reordered


_pd_params = list(filter(pd_filter, d_256.named_parameters()))
pd_params = _move_attention_gamma(_pd_params)

# TF variables and PyTorch parameters are paired purely by position.
for i, (tf_weight, pd_param) in enumerate(zip(tf_weights, pd_params)):
    tf_name, tf_value = tf_weight
    pt_name, pt_tensor = pd_param
    if len(pt_tensor.shape) == 4:
        # 4-D weight: reverse-and-swap the axes with the same permutation as before.
        weight = tf_value.transpose([3, 2, 0, 1])
    elif len(pt_tensor.shape) == 2 and tf_value.ndim == 2 and \
            pt_tensor.shape[0] == tf_value.shape[1] and \
            pt_tensor.shape[1] == tf_value.shape[0]:
        # 2-D weight stored transposed on the TF side.
        weight = tf_value.transpose([1, 0])
    else:
        weight = tf_value.reshape(pt_tensor.shape)
    # no_grad allows the in-place copy into parameters that require grad,
    # without toggling requires_grad by hand.
    with torch.no_grad():
        pt_tensor.copy_(torch.from_numpy(weight).to(pt_tensor.device))
    print(tf_name, tf_value.shape, pt_name, pt_tensor.shape)

In [None]:
import pickle

# NOTE(security): pickle.load can execute arbitrary code. Only load the file
# written by the "Save weights" section of this notebook.
with open('tf_tensor_samples.pkl', 'rb') as f:
    tf_tensors = pickle.load(f)

# Group the TF tensors by capture key: each TF tensor goes under the first key
# in tensors_name_con that occurs in its name.
gtt = collections.OrderedDict()
for n in tensors_name_con:
    gtt[n] = []
for key, tensor in tf_tensors.items():
    for n in tensors_name_con:
        if n in key:
            gtt[n].append([key, tensor])
            break

d_256.eval()
# Feed the same images (NHWC -> NCHW) and labels that produced the TF samples.
x = torch.from_numpy(tf_tensors['images'].astype('float32').transpose([0,3,1,2])).cuda()
y = torch.from_numpy(tf_tensors['labels'].astype('int64')).cuda()
# Inference only: skip autograd so the tensors captured in `gt` hold no graph.
with torch.no_grad():
    img = d_256(x, y)

In [None]:
# For every captured activation, print both names, both shapes, and the mean absolute difference.
for (_, tf_items), (pt_key, pt_items) in zip(gtt.items(), gt.items()):
    for (tf_key, tf_value), pt_value in zip(tf_items, pt_items):
        mean_abs_diff = (np.abs(tf_value - pt_value.detach().cpu().numpy())).mean()
        print(tf_key, pt_key, tf_value.shape, pt_value.shape, mean_abs_diff)

In [None]:
save_path = './anime-biggan-256px-run39-607250.discriminator.pth'
torch.save(d_256.state_dict(), save_path)
!cp ./anime-biggan-256px-run39-607250.discriminator.pth ./drive/My\ Drive/