In [1]:
# Pin this kernel to a single GPU. PCI_BUS_ID ordering makes device indices follow
# physical bus order, so "3" refers to a fixed card. Both variables must be set before
# CUDA is initialised, which is why this cell runs before `import torch` in the next cell.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=3


In [2]:
import torch
from pytorch_pretrained_biggan import (
    BigGAN,
    truncated_noise_sample,
    one_hot_from_int
)
import PIL.Image
import numpy as np
import os
import argparse
from tqdm import tqdm
import json
import pickle
import matplotlib.pyplot as plt
import utils_bigbigan as ubigbi

# Load the pretrained BigGAN-deep checkpoint and move it onto the single visible GPU.
model_name = 'biggan-deep-256'
model = BigGAN.from_pretrained(model_name)
model = model.cuda()

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Instructions for updating:
non-resource variables are not supported in the long term


In [3]:
def convert_to_images(obj):
    """Convert a batch of BigGAN generator outputs into a list of PIL images.

    Local replacement for the library's ``convert_to_images`` (not imported above).
    The upstream helper's import needs fixing (see the PR linked below), so this
    version relies on ``PIL.Image`` being imported at the top of the notebook.

    Args:
        obj: torch tensor (CPU or CUDA) or numpy array laid out as
            (batch, channels, height, width), with values nominally in [-1, 1].

    Returns:
        list of ``PIL.Image.Image``, one per batch element.
    """
    # need to fix import, see: https://github.com/huggingface/pytorch-pretrained-BigGAN/pull/14/commits/68a7446951f0b9400ebc7baf466ccc48cdf1b14c
    if not isinstance(obj, np.ndarray):
        # .cpu() lets CUDA tensors through as well; it is a no-op for CPU tensors.
        obj = obj.detach().cpu().numpy()
    # NCHW -> NHWC, the channel-last layout PIL.Image.fromarray expects.
    obj = obj.transpose((0, 2, 3, 1))
    # Rescale [-1, 1] -> [0, 256] and clamp into the uint8 range
    # (the x256 factor is kept as-is so the produced pixels do not change).
    obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255)
    # astype(np.uint8) truncates toward zero, exactly like the previous np.uint8(...) cast.
    return [PIL.Image.fromarray(frame.astype(np.uint8)) for frame in obj]

In [8]:
# Load the per-class sample metadata: a mapping from image filename to a sequence
# whose [0] is the latent fed to the model and [1] the class index (see the cells below).
dataset_name = '/data/vision/phillipi/ganclr/datasets/biggan256tr1-png_paper_figure_10_samples'
class_name = 'n02231487'
z_path = os.path.join(dataset_name, 'train', class_name, 'z_dataset.pkl')
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced/trust.
with open(z_path, 'rb') as z_file:
    z_d = pickle.load(z_file)

In [None]:
# Draw `bs` random entries from z_d, regenerate them from their stored latents,
# and show the generated images (compared against the PNGs on disk in the next cell).
A_z_list = []
A_str_list = []
bs = 4
SAMPLE_SEED = None  # set to an int for a reproducible draw; None = fresh OS entropy

z_keys = list(z_d.keys())  # built once instead of twice per loop iteration
assert z_keys, 'z_dataset.pkl contains no entries'
rng = np.random.RandomState(SAMPLE_SEED)
# Bound the draw by the real dataset size. The previous hard-coded 1300 exceeds
# len(z_d) (210 for this class), which made z_keys[idx] raise IndexError.
idx_rnd = rng.randint(len(z_keys), size=bs)
idx_list = []
for idx in idx_rnd:
    name = z_keys[idx]
    A_str_list.append(name)
    A_z_list.append(z_d[name][0])  # latent, stacked into noise_vector below
    idx_list.append(z_d[name][1])  # class index, one-hot encoded below

class_vector = one_hot_from_int(idx_list, batch_size=bs)
class_vector = torch.from_numpy(class_vector).cuda()
noise_vector = torch.from_numpy(np.stack(A_z_list)).cuda()

# from model -- truncation=1.0, presumably what the dataset was built with ("tr1"); TODO confirm
with torch.no_grad():
    output = model(noise_vector, class_vector, truncation=1.0)
output = output.cpu()
ims = convert_to_images(output)
ubigbi.imshow(ubigbi.imgrid(np.stack(ims), cols=4))

In [None]:
# now from disk: load the stored PNGs for the same randomly drawn entries.
ims_disk = []
for name in A_str_list:
    # `with` closes the file handle; np.array forces the lazily-loaded pixels
    # to be read before the file is closed.
    with PIL.Image.open(os.path.join(dataset_name, 'train', class_name, name)) as im:
        ims_disk.append(np.array(im))
ubigbi.imshow(ubigbi.imgrid(np.stack(ims_disk), cols=4))

In [None]:
# Regenerate a hand-picked set of entries from their stored latents, then show
# the corresponding PNGs from disk for a side-by-side check.
B_str_list = [
    'seed0994_sample01299_anchor.png',
    'seed0994_sample01214_1.0_1.png',
    'seed0994_sample00962_1.0_1.png',
    'seed0994_sample01075_anchor.png',
]
# Size the batch from this list, not from the `bs` left over by the random-sample
# cell: that name is undefined on a fresh kernel, and one_hot_from_int needs
# batch_size to equal len(idx_list).
B_count = len(B_str_list)
B_z_list = [z_d[name][0] for name in B_str_list]
idx_list = [z_d[name][1] for name in B_str_list]

class_vector = one_hot_from_int(idx_list, batch_size=B_count)
class_vector = torch.from_numpy(class_vector).cuda()
noise_vector = torch.from_numpy(np.stack(B_z_list)).cuda()

# from model
with torch.no_grad():
    output = model(noise_vector, class_vector, truncation=1.0)
output = output.cpu()
ims = convert_to_images(output)
ubigbi.imshow(ubigbi.imgrid(np.stack(ims), cols=4))

# now from disk (closing each file handle after its pixels are read)
ims_disk = []
for name in B_str_list:
    with PIL.Image.open(os.path.join(dataset_name, 'train', class_name, name)) as im:
        ims_disk.append(np.array(im))
ubigbi.imshow(ubigbi.imgrid(np.stack(ims_disk), cols=4))

In [9]:
# Number of stored entries (one per generated image) for this class.
len(z_d)

210

In [14]:
# Peek at one stored entry: [0] is the latent passed to the model as noise,
# [1] the class index passed to one_hot_from_int in the sampling cells.
sample_key = 'seed0313_sample00000_ang_1.png'
z_d[sample_key]

[array([ 0.9696869 , -1.5022956 ,  1.9868112 ,  2.047441  ,  0.66491055,
         0.13993809,  0.48872668,  0.22876877,  0.26554343,  0.9624948 ,
        -1.5528227 ,  0.9562741 ,  1.7246914 ,  1.3230776 ,  0.64363205,
        -0.51891255,  0.7672481 , -1.2031724 ,  2.1111617 ,  1.647719  ,
        -1.8709643 , -2.2367384 , -0.98279727, -0.07452065,  0.48659796,
         0.5637121 , -2.9666898 , -0.47821516,  1.544259  , -0.12092339,
         1.3283062 , -0.25168955,  2.2286413 ,  2.651914  ,  0.75995266,
        -0.3517795 , -0.3110422 , -1.0643257 , -0.8794848 , -0.00319077,
         0.06570914, -0.6019136 ,  0.86606973,  1.1753898 ,  2.420297  ,
         0.69352674, -0.5354007 , -0.6238876 ,  0.19461262,  0.06809068,
        -0.06923448, -0.66976285, -1.3350623 ,  0.72886413,  1.3438145 ,
         0.8670308 , -0.65357924,  1.8726108 , -1.4996046 ,  0.16940922,
         0.22020459,  2.730825  ,  1.205796  ,  0.9836327 , -1.660442  ,
        -1.4162284 ,  2.6179326 ,  0.5024657 , -1.3