<a href="https://colab.research.google.com/github/Byronliang8/HubnessGANSampling/blob/main/Colab/Exploring_and_Exploiting_Hubness_Priors_for_High_Quality_GAN_Latent_Sampling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hubness Priors
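This notebook shows how to generate high-quality images by sampling GAN latent codes with hubness priors. The same procedure can be applied to other GAN models by computing their latent codes in the same way. The code builds on StyleGAN (paper: https://arxiv.org/abs/1812.04948) and StyleGAN2 (paper: https://arxiv.org/abs/1912.04958; code: https://nvlabs.github.io/stylegan2/versions.html), and uses NVIDIA's StyleGAN2-ADA PyTorch implementation (https://github.com/NVlabs/stylegan2-ada-pytorch).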



In [None]:
# Install scikit-hubness, which provides the NearestNeighbors class used below
# to build the k-nearest-neighbor graph for hubness estimation.
!pip install scikit-hubness

# Download the StyleGAN2-ADA (PyTorch) code. Its `dnnlib` and `legacy` modules
# are imported in the next cell, so we change into the cloned directory.
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
%cd stylegan2-ada-pytorch


In [None]:
# Imports, then download and load the pretrained generator (end of this cell)
import PIL.Image
from scipy.spatial import distance

import os
import re
from typing import List, Optional

import click
import dnnlib

import PIL.Image
import torch

import legacy

from io import BytesIO, StringIO
import IPython.display
import numpy as np
from math import ceil


from skhubness.neighbors import NearestNeighbors
import torch
from tqdm import tqdm
import dill as pickle

# Pretrained StyleGAN2-ADA generator trained on FFHQ, downloaded from NVIDIA's CDN.
# To try a different model, replace this URL with another network pickle.


network_pkl = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl"

# If the download fails, point `network_pkl` at a local copy instead, e.g. a file
# stored in your mounted Google Drive (see the drive.mount cell below):
# network_pkl = "/content/drive/My Drive/GAN/stylegan2-ffhq-config-f.pkl"

print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')  # a CUDA GPU is required; later cells also call .cuda()
with dnnlib.util.open_url(network_pkl) as f:
  # Keep only the 'G_ema' generator entry of the pickle and move it to the GPU.
  G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

# Leftover from an earlier version of this notebook (`Gs` is not defined here); unused.
#noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]

Loading networks from "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl"...


# The functions

In [None]:
# Helper functions
def generate_list_latent(recorder_list, dls):
    """Gather the latents at the given indices into a single stacked tensor.

    Args:
        recorder_list: iterable of integer indices into ``dls``.
        dls: indexable collection of equally-shaped tensors (e.g. a batch of
            latent codes).

    Returns:
        A tensor of shape ``(len(recorder_list), *dls[i].shape)``. An empty
        ``recorder_list`` makes ``torch.stack`` raise ``RuntimeError``.
    """
    picked = [dls[idx] for idx in recorder_list]
    return torch.stack(picked)

def generate_images(zs):
    """Render one PIL image per z-space latent with the global generator ``G``.

    Each latent is given a batch dimension and passed to ``G`` together with
    the module-level conditioning value ``c``. The output is then converted
    to an 8-bit RGB image.

    Args:
        zs: iterable of 1-D latent vectors.

    Returns:
        list of ``PIL.Image.Image`` objects in RGB mode.
    """
    result = []
    for z in zs:
        batch = z.reshape(1, len(z))  # G takes a batch; use a batch of one
        out = G(batch, c)
        # Channels-first -> channels-last, then rescale (presumably from
        # [-1, 1]) to [0, 255] and quantise to uint8.
        out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        result.append(PIL.Image.fromarray(out[0].cpu().numpy(), 'RGB'))
    return result

def generate_images_inW(w):
    """Render one PIL image per W+ latent by calling ``G.synthesis`` directly.

    Synthesis uses constant noise (``noise_mode='const'``) and full fp32
    precision.

    Args:
        w: iterable of 2-D latents; each is reshaped to
            ``(1, rows, cols)`` before synthesis.

    Returns:
        list of ``PIL.Image.Image`` objects in RGB mode.
    """
    images = []
    for latent in w:
        # Add the batch dimension expected by the synthesis network.
        batch = latent.reshape(1, len(latent), len(latent[0]))
        out = G.synthesis(batch, noise_mode='const', force_fp32=True)
        out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        images.append(PIL.Image.fromarray(out[0].cpu().numpy(), 'RGB'))
    return images


def imshow(imgs_list,format='png', jpeg_fallback=True):
    """Display an image inline in the notebook.

    Args:
        imgs_list: anything ``np.asarray`` can turn into a uint8 image array,
            e.g. the PIL image returned by ``createImageGrid``.
        format: PIL format name used to encode the image for display.
        jpeg_fallback: if True and displaying raises ``IOError``, retry once
            with JPEG encoding.

    Returns:
        Whatever ``IPython.display.display`` returns.
    """
    imgs_list = np.asarray(imgs_list, dtype=np.uint8)
    pil_img = PIL.Image.fromarray(imgs_list)
    if format.lower() in ('jpeg', 'jpg') and pil_img.mode == 'RGBA':
        # JPEG has no alpha channel, and createImageGrid() builds RGBA canvases.
        pil_img = pil_img.convert('RGB')
    str_file = BytesIO()
    pil_img.save(str_file, format)
    im_data = str_file.getvalue()
    try:
        disp = IPython.display.display(IPython.display.Image(im_data))
    except IOError:
        if jpeg_fallback and format != 'jpeg':
            # Build the message first, then print it (print() returns None).
            print('Warning: image was too large to display in format "{}"; '
                  'trying jpeg instead.'.format(format))
            return imshow(imgs_list, format='jpeg')
        raise
    return disp

def createImageGrid(images, scale=0.25, rows=1):
    """Tile PIL images into a single grid image.

    Every image is resized to the first image's size multiplied by ``scale``.
    Images are pasted row-major into a white RGBA canvas with ``rows`` rows and
    ``ceil(len(images) / rows)`` columns.

    Args:
        images: non-empty sequence of PIL images.
        scale: resize factor applied to the first image's width and height.
        rows: number of grid rows.

    Returns:
        The RGBA canvas as a ``PIL.Image.Image``.
    """
    w, h = images[0].size
    w = int(w * scale)
    h = int(h * scale)
    cols = ceil(len(images) / rows)
    canvas = PIL.Image.new('RGBA', (cols * w, rows * h), 'white')
    for i, img in enumerate(images):
        # LANCZOS is the filter the old ANTIALIAS alias referred to;
        # ANTIALIAS itself was removed in Pillow 10.
        img = img.resize((w, h), PIL.Image.LANCZOS)
        canvas.paste(img, (w * (i % cols), h * (i // cols)))
    return canvas

def getRecorder(value, s, threshold):
    """Return indices among the first ``s`` points whose score meets ``threshold``.

    Args:
        value: 2-D array-like whose first row holds one score per point,
            such as the k-occurrence array returned by ``getK_Occ``.
        s: number of leading entries of ``value[0]`` to inspect.
        threshold: minimum score (inclusive) for a point to be selected.

    Returns:
        list of int indices in increasing order.
    """
    scores = value[0]
    return [idx for idx in range(s) if scores[idx] >= threshold]

def getK_Occ(k, high_list):
    """Compute the k-occurrence (hubness score) of every point in ``high_list``.

    A k-nearest-neighbor graph is built over ``high_list``. For each point,
    the score counts how many points list it among their ``k`` neighbors.
    Points with large counts are hubs.

    NOTE(review): the query set is the fitted set. As with scikit-learn's
    ``kneighbors_graph(X)``, each point is presumably returned as one of its
    own ``k`` neighbors. Confirm against the installed scikit-hubness version.

    Args:
        k: number of neighbors per point.
        high_list: sequence of equally-sized 1-D feature vectors.

    Returns:
        2-D ``np.ndarray`` whose single row holds the k-occurrence of each point.
        These are the column sums of the sparse kNN adjacency matrix.
    """
    neigh = NearestNeighbors(n_neighbors=k)
    neigh.fit(high_list)
    # Row i marks the k neighbors of point i. Summing down the columns
    # (axis=0) counts how often each point appears as someone's neighbor.
    graph = neigh.kneighbors_graph(high_list)
    return np.asarray(np.sum(graph, axis=0))

def to_torch(list, device='cuda'):
    """Copy ``list`` into a new tensor placed on ``device``.

    Args:
        list: data accepted by ``torch.tensor``. The parameter keeps its
            builtin-shadowing name so existing keyword callers still work.
        device: target device. The default 'cuda' preserves the original
            GPU-only behavior; pass 'cpu' to stay on the host.

    Returns:
        ``torch.Tensor`` on ``device``.
    """
    return torch.tensor(list).to(device)

def to_numpy(list):
    """Convert each element of ``list`` to a host NumPy array.

    Iterating a tensor yields its slices along dim 0, so a batched tensor
    produces one array per row.

    ``detach()`` is needed because ``Tensor.numpy()`` refuses tensors that
    require grad. For tensors that do not, it changes nothing.

    Returns:
        list of ``np.ndarray``.
    """
    return [t.detach().cpu().numpy() for t in list]

def change_circle(hubs_code_w, index=0, step=0.01):
    """Build one perturbed copy of a latent per latent dimension.

    For every column ``i`` of ``hubs_code_w[index]``, a clone is made with
    ``step`` added to column ``i`` in every row (layer). The input tensor is
    never modified.

    Args:
        hubs_code_w: indexable collection of 2-D latent tensors
            (rows x columns, e.g. layers x latent dims).
        index: which latent of ``hubs_code_w`` to perturb.
        step: amount added to the selected column.

    Returns:
        list with one tensor per column; the i-th has column ``i`` shifted.
    """
    base = hubs_code_w[index]
    latent_circle = []
    # Cover every column and every row: the full index ranges, nothing skipped.
    for i in range(base.shape[1]):
        shifted = base.clone()
        shifted[:, i] += step
        latent_circle.append(shifted)
    return latent_circle

def get_trunca(latent_list, threshold=0.7):
    """Apply a truncation trick relative to the mean of ``latent_list``.

    Each latent is pulled toward the mean of the given latents:
    ``mean + (latent - mean) * threshold``. The centre is the mean of this
    batch, not a generator-wide average.

    Args:
        latent_list: iterable of tensors (converted with ``to_numpy``).
        threshold: interpolation factor; 1 keeps latents unchanged, 0 collapses
            them to the mean.

    Returns:
        tensor of the truncated latents, built with ``to_torch``.
    """
    arrays = to_numpy(latent_list)
    centre = get_mean(arrays)
    truncated = [centre + (arr - centre) * threshold for arr in arrays]
    return to_torch(truncated)

def get_mean(latent_list):
    """Return the element-wise mean of ``latent_list`` over its first axis."""
    return np.mean(latent_list, axis=0)

def get_distance(latent_list, avr):
    """Euclidean distance from each latent to the reference ``avr``.

    - 2-D latents: distance between their first row and ``avr[0]``.
    - Latents with fewer than 2 dimensions: distance to ``avr`` itself.
    - Latents with more than 2 dimensions are skipped (no entry is produced).

    Returns:
        list of floats, one per non-skipped latent, in input order.
    """
    out = []
    for latent in latent_list:
        rank = len(latent.shape)
        if rank > 2:
            continue  # higher-rank inputs are silently ignored
        if rank == 2:
            a, b = latent[0], avr[0]
        else:
            a, b = latent, avr
        out.append(distance.euclidean(a, b))
    return out

def edit_step(input, edit_vectors, size=10, step=-0.4):
    """Build ``size`` latents that move progressively along ``edit_vectors``.

    The i-th latent (i = 0 .. size-1) is ``input + edit_vectors * i * step``,
    so the first one is the unedited input. The results are converted to a
    CUDA tensor and reshaped to ``(size, 18, 512)``. The 18 x 512 layout is
    hard-coded; presumably it is the W+ shape of the 1024px generator
    (TODO confirm before using other models).
    """
    moved = [input + edit_vectors * i * step for i in range(size)]
    stacked = torch.tensor(moved).cuda()
    return stacked.reshape(size, 18, 512)

def affine_different_from_editing(ws,ws_smile):
    """Per-block difference of the conv1 style affine between two latents.

    For every synthesis block of the global generator ``G`` (b4 ... b1024),
    the returned entry is
    ``to_numpy(conv1.affine(ws[0][0]) - conv1.affine(ws_smile[0][0]))``.
    This is the difference of the affine outputs computed from the *first*
    w row of each latent. ``ws_smile`` is presumably an edited copy of ``ws``
    (e.g. moved along a smile direction); this is inferred from the name, so
    confirm with the caller.

    Args:
        ws: latent indexed as ``ws[0][row]`` with 512-wide rows. Assumed shape
            is (1, num_ws, 512); TODO confirm.
        ws_smile: second latent with the same layout as ``ws``.

    Returns:
        list with one entry per block. Each entry is the list ``to_numpy``
        produces from the difference tensor (presumably shape (1, C), which
        gives a one-element list).

    NOTE(review): the block forward passes below (``x_ws``/``img_ws`` and
    ``x_smile``/``img_smile``) are computed but never contribute to ``diffs``.
    They also always feed rows ``0:input_size`` of ``ws[0]`` to every block,
    instead of advancing a per-block offset as ``feature_map_analyze`` does.
    """
    bloc_list=[G.synthesis.b4,G.synthesis.b8,G.synthesis.b16,G.synthesis.b32,G.synthesis.b64,G.synthesis.b128,G.synthesis.b256,G.synthesis.b512,G.synthesis.b1024]
    x_ws=img_ws=x_smile=img_smile=None
    diffs=[]
    #w_iter = iter(ws.unbind(dim=1))
    for i in bloc_list:
        # Number of w rows this block consumes (its convs plus its ToRGB layers).
        input_size=i.num_torgb+i.num_conv
        # Forward the block for ws; x/img carry over to the next block.
        x_ws,img_ws=i(x_ws,img_ws,ws[0][0:input_size].reshape(1,input_size,512))
        # Style-affine output of conv1, computed from the first w row only.
        affine_ws = i.conv1.affine(ws[0][0].reshape(1,512))
        #if i.in_channels == 0:
        #   affine_ws = i.conv1.affine(ws[0][0].reshape(1,512))
        #elif i.architecture == 'resnet':
        #   affine_ws = i.conv1.affine(ws[0][0].reshape(1,512))

        # Same forward pass for the edited latent.
        x_smile,img_smile=i(x_smile,img_smile,ws_smile[0][0:input_size].reshape(1,input_size,512))
        # Style-affine output of conv1 for the edited latent's first w row.
        affine_smile = i.conv1.affine(ws_smile[0][0].reshape(1,512))
        #if i.in_channels == 0:
        #    affine_smile = i.conv1.affine(ws_smile[0][0].reshape(1,512))
        #elif i.architecture == 'resnet':
        #    affine_smile = i.conv1.affine(ws_smile[0][0].reshape(1,512))

        # Difference of the two style vectors for this block, moved to NumPy.
        diff_=affine_ws-affine_smile
        diff_=to_numpy(diff_)
        diffs.append(diff_)
    return diffs

def get_styles(affine):
    """Scale each row of ``affine`` so its largest absolute entry becomes 1.

    Divides by the per-row L-infinity norm (taken over ``dim=1`` and kept as a
    column so it broadcasts).
    """
    row_max = affine.norm(float('inf'), dim=1, keepdim=True)
    return affine / row_max

def style_layer_w_out(affines):
    """Apply ``get_styles`` to every affine output, preserving order.

    Returns:
        list of row-normalised tensors, one per element of ``affines``.
    """
    return [get_styles(affine) for affine in affines]

def affine_w_out(ws):
    """Collect the style-affine outputs of the synthesis convolutions for ``ws``.

    Rows of ``ws`` are consumed in order along dim 1 (via ``ws.unbind(dim=1)``).
    The first block contributes ``conv1.affine``; every later block contributes
    ``conv1.affine`` followed by ``conv0.affine``. ToRGB affines are not
    collected.

    Blocks are looked up from ``G.synthesis.block_resolutions`` instead of a
    fixed b4..b1024 list, so generators of any output resolution work. For a
    1024px generator the blocks are the same as before.

    Args:
        ws: W+ latents. Assumed shape is (batch, num_ws, w_dim); TODO confirm.

    Returns:
        list of affine outputs in the order described above.

    NOTE(review): within each block, conv1 receives a w row before conv0.
    Confirm this matches the intended layer-to-w assignment; it only matters
    when the rows of ``ws`` differ.
    """
    blocks = [getattr(G.synthesis, f'b{res}') for res in G.synthesis.block_resolutions]
    affines = []
    w_iter = iter(ws.unbind(dim=1))
    for j, block in enumerate(blocks):
        affines.append(block.conv1.affine(next(w_iter)))
        if j > 0:  # the first (lowest-resolution) block has no conv0
            affines.append(block.conv0.affine(next(w_iter)))
    return affines

def saveImgs(imgs, location):
    """Save each image as ``<location><index>.png`` while showing progress.

    ``location`` is used as a plain string prefix (e.g. ``'Random/'`` or
    ``'hubsimage/hubs_'``). Nothing here creates directories, so any
    directory part of it must already exist.
    """
    numbered = enumerate(imgs)
    for idx, img in log_progress(numbered, size=len(imgs), name="Saving images"):
        img.save(location + str(idx) + ".png")

def feature_map_analyze(ws):
    """Run ``G.synthesis`` block by block and keep every intermediate output.

    Block ``k`` receives ``num_conv + num_torgb`` consecutive rows of ``ws``
    (along dim 1) starting at the current offset. The offset then advances by
    ``num_conv``, so the last ``num_torgb`` rows of one block's slice are the
    first rows of the next block's slice.

    Blocks are looked up from ``G.synthesis.block_resolutions`` instead of a
    fixed b4..b1024 list, so generators of any output resolution work.

    Args:
        ws: W+ latents sliced along dim 1. Assumed shape is
            (batch, num_ws, w_dim); TODO confirm.

    Returns:
        (xs, imgs): the ``x`` and ``img`` outputs of every block, in block order.
    """
    blocks = [getattr(G.synthesis, f'b{res}') for res in G.synthesis.block_resolutions]
    xs = []
    imgs = []
    x = img = None
    w_idx = 0
    for block in blocks:
        block_ws = ws.narrow(1, w_idx, block.num_conv + block.num_torgb)
        x, img = block(x, img, block_ws)
        xs.append(x)
        imgs.append(img)
        w_idx += block.num_conv
    return xs, imgs

# Taken from https://github.com/alexanderkuk/log-progress
def log_progress(sequence, every=1, size=None, name='Items'):
    """Yield the items of ``sequence`` while showing an ipywidgets progress bar.

    Args:
        sequence: iterable to wrap.
        every: refresh the widget on the first item and on every ``every``-th
            item. If None, it is chosen from ``size`` so that roughly 200
            refreshes happen (a known size is required in that case).
        size: total number of items; taken from ``len(sequence)`` when omitted.
            If the length cannot be determined, the bar is indeterminate.
        name: label prefix.

    Yields:
        the items of ``sequence``, unchanged.

    The bar turns 'danger' if iteration is interrupted by any exception (which
    is re-raised), and 'success' when the sequence is exhausted.
    """
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    unknown_length = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            unknown_length = True

    if size is None:
        assert every is not None, 'sequence is iterator, set every'
    elif every is None:
        every = 1 if size <= 200 else int(size / 200)  # every 0.5%

    if unknown_length:
        bar = IntProgress(min=0, max=1, value=1)
        bar.bar_style = 'info'
    else:
        bar = IntProgress(min=0, max=size, value=0)
    label = HTML()
    display(VBox(children=[label, bar]))

    count = 0
    try:
        for count, item in enumerate(sequence, 1):
            if count == 1 or count % every == 0:
                if unknown_length:
                    label.value = f'{name}: {count} / ?'
                else:
                    bar.value = count
                    label.value = f'{name}: {count} / {size}'
            yield item
    except BaseException:
        # Deliberately broad (includes GeneratorExit/KeyboardInterrupt): only
        # recolours the bar, then re-raises.
        bar.bar_style = 'danger'
        raise
    else:
        bar.bar_style = 'success'
        bar.value = count
        label.value = f"{name}: {count or '?'}"

In [None]:
# Mount Google Drive at /content/drive (Colab only), e.g. to load a locally
# stored network pickle or to save generated images there.
from google.colab import drive
drive.mount('/content/drive')

# The random generation

In [None]:
# Inspect the generator's module structure (e.g. the synthesis blocks that the
# helper functions access as G.synthesis.b4, G.synthesis.b8, ...).
print(G)

In [None]:
# Baseline: sample s latents from the standard normal prior with no hubness
# filtering, map them to W, and display them as a 5-row grid.
s=25
zs = torch.randn([s, G.z_dim]).cuda()    # latent codes
c = None  # no conditioning label; generate_images() also reads this global

# truncation_psi=1 presumably leaves w untruncated, which makes
# truncation_cutoff irrelevant here (confirm in StyleGAN2-ADA networks.py).
w = G.mapping(zs, c, truncation_psi=1, truncation_cutoff=8)
imgs=generate_images_inW(w)
#saveImgs(imgs,'Random/')
imshow(createImageGrid(imgs[0:25],0.4,5))  # up to 25 images, 40% scale, 5 rows

# The hubness priors 

In [None]:
# Sample s latents, map them to W, and score each code's hubness by its
# k-occurrence in the k-nearest-neighbor graph of all sampled codes.
s=10000
k=5
zs = torch.randn([s, G.z_dim]).cuda()    # latent codes
c = None
w = G.mapping(zs, c, truncation_psi=1, truncation_cutoff=8)

# The kNN graph uses only the first row j[0] of each code. Presumably all rows
# are identical here because no style mixing is applied (TODO confirm).
w_np=to_numpy(w)
w_input=[]
for j in w_np:
    w_input.append(j[0])

value_w=getK_Occ(k,w_input)# hub score (k-occurrence) of every code
    #print('done')

# hubness priors: keep the codes whose k-occurrence is at least `threshold`.
# NOTE(review): if no code reaches the threshold, recorder_w is empty and the
# torch.stack inside generate_list_latent raises. Lower threshold or raise s.
threshold=50
recorder_w=getRecorder(value_w,s,threshold)
hubs_code_w = generate_list_latent(recorder_w, w)# the selected hub latents
print(hubs_code_w.shape)

hubs_imgs=generate_images_inW(hubs_code_w)
#dir='hubsimage/hubs_'
#saveImgs(hubs_imgs,dir)
imshow(createImageGrid(hubs_imgs[0:25],0.4,5))  # first 25 hubs, 40% scale, 5 rows