In [6]:
import argparse
import torch
from utils import get_model, LoadEncoder
from models.engine import DDIMSampler, DDIMSamplerEncoder
from torchvision.utils import save_image, make_grid
from collections import OrderedDict
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import os

In [8]:
class Args(argparse.Namespace):
    arch = "unetadagn"
    img_size=64
    num_timestep = 1000
    beta = (0.0001, 0.02)
    num_condition = [2, 4]
    emb_size = 128
    channel_mult = [1, 2, 2, 2]
    num_res_blocks = 2
    use_spatial_transformer = True
    num_heads = 4
    num_sample_missing = 1000
    num_sample = 10
    w = 1.8
    projection_dim=512
    only_table = False
    concat = False
    only_encoder = False
    num_head_channels = -1
    encoder_path = None
    compose = False
    dataset = "Zappo50K"
    exp = "HSliAdaGN"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = Args()

In [10]:
model = get_model(args)

ckpt = torch.load(os.path.join("checkpoints", args.dataset, args.exp, "model_100.pth"))["model"]
new_dict = OrderedDict()
    
for k, v in ckpt.items():
    if k.startswith("module"):
        new_dict[k[7:]] = v
    else:
        new_dict[k] = v
try:
    model.load_state_dict(new_dict)
    print("All keys successfully match")
except:
    print("some keys are missing!")

for p in model.parameters():
    p.requires_grad = False

model.eval()
model.to(device)

sampler = DDIMSampler(
    model=model,
    beta =args.beta,
    T=args.num_timestep,
    w=args.w,
).to(device)

if args.encoder_path != None:
    encoder = LoadEncoder(args).to(device)
    sampler = DDIMSamplerEncoder(
            model = model,
            encoder = encoder,
            beta = args.beta,
            T = args.num_timestep,
            w = args.w,
            only_encoder = args.only_encoder
    ).to(device)

All keys successfully match


In [15]:
from config import Zappo50K, toy_dataset, CelebA

if args.dataset == "Zappo50K":
    CFG = Zappo50K()
elif args.dataset == "CelebA":
    CFG = CelebA()
else:
    CFG = toy_dataset()
# CFG = Zappo50K()
# missing = "Heel Slipper"
# targets = ["Flat Boot", "Flat Shoe", "Flat Slipper", "Flat Sandal", "Heel Boot", "Heel Shoe", "Heel Sandal"]
missing = "Heel Slipper"
targets = ["Brown_Hair Male", "Black_Hair Male", "Gray_Hair Male", "Blond_Hair Male", "Brown_Hair Female", "Black_Hair Female", "Blond_Hair Female"]

atr, obj = CFG.ATR2IDX[missing.split(" ")[0]], CFG.OBJ2IDX[missing.split(" ")[-1]]
atr = torch.tensor(atr, dtype=torch.long, device=device).repeat(args.num_sample_missing)
obj = torch.tensor(obj, dtype=torch.long, device=device).repeat(args.num_sample_missing)
x_i = torch.randn(args.num_sample_missing, 3, 64, 64).to(device)
x0 = sampler(x_i, atr, obj, steps=100)
x0 = x0 * 0.5 + 0.5
x0 = x0.cpu()
i = 0
os.makedirs(os.path.join("SampledImg", args.dataset, args.exp), exist_ok=True)
for x in x0:
    img = transforms.ToPILImage(x)
    img.save(os.path.join("SampledImg", args.dataset, args.exp, "{:05d}.jpg".format(i)))
    i += 1
# save_image(x0, "SampledImg/DualCond/GrayFemaleAdaGN.png", nrow=10, noramlized=True)

# images = []
# for target in targets:
#     atr, obj = CFG.ATR2IDX[target.split(" ")[0]], CFG.OBJ2IDX[target.split(" ")[-1]]
#     atr = torch.tensor(atr, dtype=torch.long, device=device).repeat(args.num_sample)
#     obj = torch.tensor(obj, dtype=torch.long, device=device).repeat(args.num_sample)

#     x_i = torch.randn(args.num_sample, 3, 64, 64).to(device)
#     x0 = sampler(x_i, atr, obj, steps=100)
#     x0 = x0 * 0.5 + 0.5
#     images.append(x0)
# images = torch.concatenate(images, dim=0)
# save_image(images, "SampledImg/DualCond/SeenAdaGN.png", nrow=args.num_sample, normalized=True)

  0%|[38;2;101;101;181m          [0m| 0/100 [00:05<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 500.00 MiB (GPU 0; 31.75 GiB total capacity; 10.38 GiB already allocated; 423.69 MiB free; 11.52 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [12]:
def _image_grid(imgs, rows, cols):
        """
        input -
        1. imgs - np array dims num_images x **(image_dims)
        2. rows, cols - number of rows and columns in the image grid
        output - 
        1. saved image grid on the disk at the end when all the images are pasted after multiple iterations
        """
        n,c,w,h = imgs.shape

        grid = Image.new('RGB', size=(cols*w, rows*h))
        
        if isinstance(imgs, torch.Tensor):
            imgs = imgs.cpu().numpy()

        grid_w, grid_h = grid.size
        for idx in range(n):
            img = imgs[idx,:].squeeze()
            img = Image.fromarray(img)
            if img.mode != 'RGB':
                img = img.convert('RGB')
            grid.paste(img, box=(idx*w, 0))
        return grid

In [24]:
imgs = images.cpu().numpy()
img = imgs[0, :].squeeze()
# img = Image.fromarray((img * 255).astype(np.uint8))

3

In [14]:
print("{:04d}".format(1))

0001


In [59]:
from glob import glob
import os
import random
import shutil

data_root = "data/CelebA"
train_root = "data/CelebA_train"
val_root = "data/CelebA_val"
os.makedirs(train_root, exist_ok=True)
os.makedirs(val_root, exist_ok=True)

for target in os.listdir(data_root):
    os.makedirs(os.path.join(train_root, target), exist_ok=True)
    images = glob(os.path.join(data_root, target, "*.jpg"))
    train_imgs = random.sample(images, int(0.8 * len(images)))
    for img in train_imgs:
        name = img.split("/")[-1]
        shutil.move(img, os.path.join(train_root, target, name))
    print(target, len(train_imgs))


Gray_Hair Female 1011
Black_Hair Female 18652
Brown_Hair Male 10230
Blond_Hair Male 1399
Blond_Hair Female 22587
Gray_Hair Male 5788
Black_Hair Male 20124
Brown_Hair Female 23027


In [60]:
for target in os.listdir(data_root):
    os.makedirs(os.path.join(val_root, target), exist_ok=True)
    images = glob(os.path.join(data_root, target, "*.jpg"))
    for img in images:
        name = img.split("/")[-1]
        shutil.move(img, os.path.join(val_root, target, name))
    print(target, len(images))

Gray_Hair Female 253
Black_Hair Female 4664
Brown_Hair Male 2558
Blond_Hair Male 350
Blond_Hair Female 5647
Gray_Hair Male 1447
Black_Hair Male 5032
Brown_Hair Female 5757
