In [1]:
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], ".."))

import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import torch

import data_utils.utils as data_utils
import inference.utils as inference_utils
import BigGAN_PyTorch.utils as biggan_utils
from data_utils.datasets_common import pil_loader
import torchvision.transforms as transforms
import time
from glob import glob
import h5py as h5
import torchvision

In [None]:
def get_data(root_path, model, resolution, which_dataset, visualize_instance_images):
    """Load the stored k-means instance features for one dataset/resolution.

    The feature file is selected by ``model``: ``"cc_icgan"`` reads the
    classification ResNet-50 features, any other model reads the
    self-supervised ones.

    Args:
        root_path: Directory that contains a ``stored_instances`` folder.
        model: Model name, e.g. ``"icgan"`` or ``"cc_icgan"``.
        resolution: Resolution embedded in the file name (formatted with %i).
        which_dataset: Dataset prefix of the file name, e.g. ``"imagenet"``.
        visualize_instance_images: When truthy, also build the transform used
            to display ground-truth instance images.

    Returns:
        ``(data, transform_list)``: ``data`` is the object stored in the
        ``.npy`` file (unwrapped with ``.item()``); ``transform_list`` is a
        ``transforms.Compose`` or ``None``.
    """
    feature_extractor = "classification" if model == "cc_icgan" else "selfsupervised"
    features_file = os.path.join(
        root_path,
        "stored_instances",
        "%s_res%i_rn50_%s_kmeans_k1000_instance_features.npy"
        % (which_dataset, resolution, feature_extractor),
    )
    # NOTE: allow_pickle=True unpickles the file contents; only load trusted files.
    data = np.load(features_file, allow_pickle=True).item()

    if not visualize_instance_images:
        return data, None
    # Transformation used for ImageNet images.
    return data, transforms.Compose(
        [data_utils.CenterCropLongEdge(), transforms.Resize(resolution)]
    )


def get_model(exp_name, root_path, backbone, device="cuda"):
    """Build the sampling configuration for an experiment and load its generator.

    Args:
        exp_name: Value passed as ``--experiment_name``.
        root_path: Value passed as ``--base_root``.
        backbone: Value passed as ``--model_backbone``.
        device: Device handed to ``load_model_inference`` (default ``"cuda"``).

    Returns:
        The loaded generator, switched to eval mode.
    """
    # Same builder order as before: base parser -> sample options -> backbone options.
    parser = inference_utils.add_backbone_parser(
        biggan_utils.add_sample_parser(biggan_utils.prepare_parser())
    )
    cli_args = [
        "--experiment_name", exp_name,
        "--base_root", root_path,
        "--model_backbone", backbone,
    ]
    # Load model and overwrite configuration parameters if stored in the model
    config = biggan_utils.update_config_roots(
        vars(parser.parse_args(args=cli_args)), change_weight_folder=False
    )
    # load_model_inference also returns an updated config; only the generator
    # is handed back to callers.
    generator, _ = inference_utils.load_model_inference(config, device=device)
    biggan_utils.count_parameters(generator)
    generator.eval()
    return generator


def get_conditionings(test_config, generator, data):
    """Sample noise vectors and instance/label conditionings for generation.

    ``test_config`` may be ``None`` or a dict. Any missing key falls back to
    the value this notebook runs with:

    * ``num_imgs_gen`` (5): images generated per conditioning instance.
    * ``num_conditionings_gen`` (5): number of conditioning instances.
    * ``z_var`` (1.0): passed as ``std`` to the Gaussian noise sampler.
    * ``index`` (None): fixed row of the k-means file used for every
      conditioning; ``None`` samples distinct rows at random.
    * ``num_kmeans_centers`` (1000): rows sampled from when ``index`` is None.
    * ``visualize_instance_images`` (False): also collect ``data["image_path"]``.
    * ``swap_target`` (None): class id that replaces each instance's own label.
    * ``model_backbone`` ("biggan"): ``"stylegan2"`` reads ``generator.z_dim``
      and builds one-hot labels; anything else reads ``generator.dim_z``.
    * ``model`` ("icgan"): for non-StyleGAN2 backbones, ``"cc_icgan"`` adds
      integer class labels; otherwise no labels are produced.

    Args:
        test_config: Sampling options (see above) or ``None``.
        generator: Loaded generator; only ``dim_z`` / ``z_dim`` is read.
        data: Mapping with ``"instance_features"`` and ``"labels"`` (plus
            ``"image_path"`` when visualizing), indexable by row.

    Returns:
        ``(z, all_feats, all_labels, all_img_paths)``. ``z`` and ``all_feats``
        have ``num_conditionings_gen * num_imgs_gen`` rows; each instance
        feature is repeated for ``num_imgs_gen`` consecutive rows.
        ``all_labels`` is a tensor with the same row count or ``None``.
        ``all_img_paths`` is a (possibly empty) list.

    Raises:
        ValueError: If ``num_conditionings_gen`` is smaller than 1.
    """
    defaults = {
        "num_imgs_gen": 5,
        "num_conditionings_gen": 5,
        "z_var": 1.0,
        "index": None,
        "num_kmeans_centers": 1000,
        "visualize_instance_images": False,
        "swap_target": None,
        "model_backbone": "biggan",
        "model": "icgan",
    }
    cfg = {**defaults, **(test_config or {})}
    num_imgs = cfg["num_imgs_gen"]
    num_conds = cfg["num_conditionings_gen"]
    if num_conds < 1:
        raise ValueError("num_conditionings_gen must be >= 1, got %r" % (num_conds,))
    is_stylegan2 = cfg["model_backbone"] == "stylegan2"

    # Obtain noise vectors (the latent size attribute differs per backbone)
    z = torch.empty(
        num_conds * num_imgs,
        generator.z_dim if is_stylegan2 else generator.dim_z,
    ).normal_(mean=0, std=cfg["z_var"])

    # Subsample distinct rows of the k-means centers file whenever no fixed
    # index is given, including the single-conditioning case.
    if cfg["index"] is None:
        total_idxs = np.random.choice(
            range(cfg["num_kmeans_centers"]), num_conds, replace=False
        )

    # Obtain features, labels and ground truth image paths
    all_feats, all_img_paths, all_labels = [], [], []
    for counter in range(num_conds):
        # Row in the k-means centers file
        idx = cfg["index"] if cfg["index"] is not None else total_idxs[counter]
        # Image paths to visualize ground-truth instance
        if cfg["visualize_instance_images"]:
            all_img_paths.append(data["image_path"][idx])
        # Instance feature, repeated once per generated image
        all_feats.append(
            torch.FloatTensor(data["instance_features"][idx : idx + 1]).repeat(
                num_imgs, 1
            )
        )
        # Label: a manually specified one, or the one stored with the instance
        if cfg["swap_target"] is not None:
            label_int = cfg["swap_target"]
        else:
            label_int = int(data["labels"][idx])
        # Format labels according to the backbone
        labels = None
        if is_stylegan2:
            dim_labels = 1000
            labels = torch.eye(dim_labels)[torch.LongTensor([label_int])].repeat(
                num_imgs, 1
            )
        elif cfg["model"] == "cc_icgan":
            labels = torch.LongTensor([label_int]).repeat(num_imgs)
        all_labels.append(labels)

    # Concatenate all conditionings
    all_feats = torch.cat(all_feats)
    all_labels = torch.cat(all_labels) if all_labels[0] is not None else None
    return z, all_feats, all_labels, all_img_paths

In [None]:
### -- Data -- ###
# Load the stored k-means instance features; no ground-truth image transform
# is built because visualize_instance_images is False.
data, transform_list = get_data(
    root_path="/Work2/Watch_This/ICGAN/ic_gan/pretrained_models_path",
    model="icgan",
    resolution=256,
    which_dataset="imagenet",
    visualize_instance_images=False,
)

In [None]:
device = "cpu"
# Experiment name built from model, backbone, dataset, resolution and an
# (empty) suffix.
exp_name = "{}_{}_{}_res{:d}{}".format("icgan", "biggan", "imagenet", 256, "")
generator = get_model(
    exp_name=exp_name,
    root_path="/Work2/Watch_This/ICGAN/ic_gan/pretrained_models_path",
    backbone="biggan",
    device=device,
)

In [None]:
import torch
from PIL import Image

In [None]:
# Export the first five stored instance feature vectors as individual tensors.
test_images_dir = "/Work2/Watch_This/ICGAN/ic_gan/data_utils/test_images"
os.makedirs(test_images_dir, exist_ok=True)
for i in range(5):
    torch.save(
        torch.tensor(data["instance_features"][i]),
        os.path.join(test_images_dir, "%d_feat_map.pth" % i),
    )

In [None]:
# Copy the ImageNet images behind the first five stored instances next to
# their exported feature tensors.
imagenet_root = "/Work1/imagenet/"
test_images_dir = "/Work2/Watch_This/ICGAN/ic_gan/data_utils/test_images"
os.makedirs(test_images_dir, exist_ok=True)
for i in range(5):
    # The context manager closes the source file; a bare Image.open(...) is
    # lazy and keeps the handle open until garbage collection.
    with Image.open(imagenet_root + data["image_path"][i]) as img:
        img.save(os.path.join(test_images_dir, "%d_img.JPEG" % i))

In [None]:
# Quick look at the sampled conditionings. `z` and `all_feats` are created by
# the get_conditionings cell further down, so report that instead of raising
# a NameError when this cell runs first on a fresh kernel.
(
    (z.shape, all_feats.shape)
    if all(name in globals() for name in ("z", "all_feats"))
    else "run the get_conditionings cell first"
)

In [None]:
# Seed the two RNGs get_conditionings draws from (torch for the noise vectors,
# numpy for the instance indices) so the generated grid is reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
z, all_feats, all_labels, all_img_paths = get_conditionings(None, generator, data)

In [6]:
# Read the first stored feature vector and its image path from the
# precomputed features file (self-supervised ResNet-50, per its file name)
# and print them.
feats_h5_path = "/Work2/Watch_This/ICGAN/ic_gan/data/ILSVRC256_feats_selfsupervised_resnet50.hdf5"
with h5.File(feats_h5_path, "r") as feats_file:
    feat_input, feat_path = feats_file["feats"][0], feats_file["paths"][0]
print(feat_input, feat_path)

[0.01786717 0.03970662 0.00147569 ... 0.00897138 0.00036234 0.07025905] b'/Work1/ima'


In [None]:
# Inspect the sample read from the HDF5 features file.
feat_input.shape, feat_path

In [5]:
# Wrap the HDF5 feature as a batch of one float tensor and L2-normalize it
# along dim 1. Staying in torch keeps the result a tensor the generator can
# consume (no numpy round-trip). Presumably feat_input is a 1-D feature
# vector, as the printed sample suggests; confirm if the file layout changes.
input_feat_map = torch.tensor(feat_input).unsqueeze(0).float()
input_feat_map = input_feat_map / input_feat_map.norm(dim=1, keepdim=True)


NameError: name 'feat_input' is not defined

In [None]:
# Generate one image from the first noise vector and the normalized HDF5
# feature (no class labels), without building an autograd graph.
with torch.no_grad():
    out_img1 = generator(z[0:1].float().to(device), None, input_feat_map.to(device))
# Map the output from [-1, 1] to [0, 1] for display.
out_img1 = torch.clamp((out_img1 * 0.5 + 0.5), 0, 1)
torchvision.transforms.functional.to_pil_image(out_img1[0].cpu())

In [None]:
# Conditioning labels from get_conditionings: None unless it is set up for a
# stylegan2 backbone or the class-conditional cc_icgan model.
all_labels

In [None]:
# Generate every (z, feature[, label]) triple in mini-batches without tracking
# gradients, converting each batch to integer RGB values in [0, 255].
BACKBONE = "biggan"
BATCH_SIZE = 16
all_generated_images = []
with torch.no_grad():
    num_samples = z.shape[0]
    # Stepping by BATCH_SIZE never yields a trailing empty batch (the former
    # `1 + n // 16` did whenever n was a multiple of 16).
    for start in range(0, num_samples, BATCH_SIZE):
        end = min(start + BATCH_SIZE, num_samples)
        labels_ = all_labels[start:end].to(device) if all_labels is not None else None
        gen_img = generator(
            z[start:end].to(device), labels_, all_feats[start:end].to(device)
        )
        if BACKBONE == "biggan":
            # Assumes output roughly in [-1, 1]; clamp like the StyleGAN2
            # branch so stray values cannot leave [0, 255].
            gen_img = torch.clamp((gen_img * 0.5 + 0.5) * 255, 0, 255).int()
        elif BACKBONE == "stylegan2":
            gen_img = torch.clamp((gen_img * 127.5 + 128), 0, 255).int()
        all_generated_images.append(gen_img.cpu())

In [None]:
# Assemble the generated images into one grid. Row i holds the images
# generated from the i-th conditioning instance, because get_conditionings
# repeats each instance feature for consecutive samples.
n_rows, n_cols = 5, 5
# Build a new array instead of rebinding `all_generated_images`, so this cell
# can be re-run without re-running the generation cell.
generated = torch.cat(all_generated_images).permute(0, 2, 3, 1).numpy()
num_imgs, height, width, channels = generated.shape
assert num_imgs >= n_rows * n_cols, "expected at least %d generated images, got %d" % (
    n_rows * n_cols,
    num_imgs,
)
# (rows*cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
big_plot = (
    generated[: n_rows * n_cols]
    .reshape(n_rows, n_cols, height, width, channels)
    .transpose(0, 2, 1, 3, 4)
    .reshape(n_rows * height, n_cols * width, channels)
)

# Ground-truth conditioning images are not shown: the data cell loads with
# visualize_instance_images=False, so transform_list is None.

fig, ax = plt.subplots(figsize=(n_cols * 5, n_rows * 5))
ax.imshow(big_plot)
ax.axis("off")

# Filename parameters; None leaves the corresponding tag out of the name.
instance_index, swap_class_idx, z_var = None, None, 1.0
fig_path = "%s_Generations_with_InstanceDataset_%s%s%s_zvar%0.2f.png" % (
    exp_name,
    "imagenet",
    "_index" + str(instance_index) if instance_index is not None else "",
    "_class_idx" + str(swap_class_idx) if swap_class_idx is not None else "",
    z_var,
)
fig.savefig(fig_path, dpi=600, bbox_inches="tight", pad_inches=0)


In [None]:
# File name of the PNG grid written by the plotting cell; it has no directory
# component, so it is saved relative to the notebook's working directory.
fig_path