In [9]:
import argparse
import torch
from utils import get_model, LoadEncoder
from models.engine import TripleCondDDIMSampler
from torchvision.utils import save_image, make_grid
from collections import OrderedDict

In [21]:
class Args(argparse.Namespace):
    """Notebook stand-in for the training script's parsed command-line arguments.

    Defaults are stored as *instance* attributes (through Namespace.__init__)
    rather than class attributes. This means `vars(args)`, `repr(args)` and
    `"name" in args` all reflect the configuration. It also means list-valued
    defaults are not shared between instances. Keyword arguments override
    individual defaults, e.g. ``Args(w=2.0, num_sample=8)``.
    """

    def __init__(self, **overrides):
        # Built fresh on every call so mutable defaults are never shared.
        config = dict(
            arch="unetattentiontriple",
            img_size=64,
            num_timestep=1000,           # passed to the sampler as T
            beta=(0.0001, 0.02),         # passed to the sampler as beta
            # presumably class counts per condition (size, attribute, object) -- TODO confirm
            num_condition=[2, 2, 4],
            emb_size=128,
            channel_mult=[1, 2, 2, 2],
            num_res_blocks=2,
            use_spatial_transformer=True,
            num_heads=4,
            num_sample_missing=50,       # sample count for held-out ("missing") conditions
            num_sample=5,                # sample count per target condition
            w=1.8,                       # passed to the sampler as w
            projection_dim=512,
            only_table=False,
            concat=False,
            only_encoder=False,
            num_head_channels=-1,
            encoder_path=None,
        )
        config.update(overrides)
        super().__init__(**config)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = Args()

In [22]:
CKPT_PATH = "checkpoints/Zappo50KTriple/HeelSandalCA/model_100.pth"
WRAPPER_PREFIX = "module."  # prefix added by nn.DataParallel / DistributedDataParallel wrappers

model = get_model(args)

# Load onto the CPU: the model is still on the CPU at this point (it is moved
# to `device` below), and this lets a GPU-saved checkpoint load on a CPU-only box.
ckpt = torch.load(CKPT_PATH, map_location="cpu")["model"]

# Strip the wrapper prefix so keys match the unwrapped model. Matching the
# trailing dot avoids mangling keys that merely start with "module".
new_dict = OrderedDict(
    (k[len(WRAPPER_PREFIX):] if k.startswith(WRAPPER_PREFIX) else k, v)
    for k, v in ckpt.items()
)

try:
    model.load_state_dict(new_dict)
    print("All keys successfully match")
except RuntimeError as err:
    # Strict loading raises RuntimeError listing missing / unexpected /
    # shape-mismatched keys; show the reason instead of hiding it.
    print("some keys are missing!")
    print(err)

# Inference only: freeze all weights and switch the model to eval mode.
model.requires_grad_(False)
model.eval()
model.to(device)

sampler = TripleCondDDIMSampler(
    model=model,
    beta=args.beta,
    T=args.num_timestep,
    w=args.w,
).to(device)

# Optional encoder-conditioned sampler (currently disabled).
# NOTE(review): DDIMSamplerEncoder is not imported anywhere in this notebook;
# add the import before re-enabling this branch.
# if args.encoder_path is not None:
#     encoder = LoadEncoder(args).to(device)
#     sampler = DDIMSamplerEncoder(
#             model = model,
#             encoder = encoder,
#             beta = args.beta,
#             T = args.num_timestep,
#             w = args.w,
#             only_encoder = args.only_encoder
#     ).to(device)

All keys successfully match


In [24]:
import os

from config import Zappo50KTriple, TripleCond

CFG = Zappo50KTriple()

SEED = 0            # fixed seed so the saved grid is reproducible run-to-run
SAMPLE_STEPS = 100  # DDIM steps passed to the sampler
GRID_NROW = 10      # images per grid row (2 conditions per row when num_sample == 5)
OUT_PATH = "SampledImg/TripleCond/SeenCA.png"

# Held-out combinations (see the commented-out call at the bottom of this cell).
missings = ["Right Heel Sandal", "Left Heel Sandal"]
# Combinations rendered and saved by this cell.
targets = [
    "Right Flat Boot", "Right Flat Shoe", "Right Flat Slipper", "Right Flat Sandal", "Right Heel Boot", "Right Heel Shoe", "Right Heel Slipper",
    "Left Flat Boot", "Left Flat Shoe", "Left Flat Slipper", "Left Flat Sandal", "Left Heel Boot", "Left Heel Shoe", "Left Heel Slipper"
]


def sample_condition(name, num_samples, steps=SAMPLE_STEPS):
    """Sample `num_samples` images for one condition string such as "Right Flat Boot".

    `name` is split on single spaces. The first token is looked up in
    CFG.SIZE2IDX, the second in CFG.ATR2IDX and the last in CFG.OBJ2IDX.

    Returns the sampler output mapped with x * 0.5 + 0.5. This maps [-1, 1]
    to [0, 1], assuming the sampler emits values in [-1, 1] (TODO confirm).
    """
    tokens = name.split(" ")
    size_idx = CFG.SIZE2IDX[tokens[0]]
    atr_idx = CFG.ATR2IDX[tokens[1]]
    obj_idx = CFG.OBJ2IDX[tokens[-1]]

    def as_batch(idx):
        # One label per sample in the batch.
        return torch.tensor(idx, dtype=torch.long, device=device).repeat(num_samples)

    # Noise is drawn on the CPU and then moved, so a given seed yields the
    # same starting noise regardless of `device`.
    x_i = torch.randn(num_samples, 3, args.img_size, args.img_size).to(device)
    x0 = sampler(x_i, as_batch(size_idx), as_batch(atr_idx), as_batch(obj_idx), steps=steps)
    return x0 * 0.5 + 0.5


torch.manual_seed(SEED)
images = torch.cat([sample_condition(t, args.num_sample) for t in targets], dim=0)

os.makedirs(os.path.dirname(OUT_PATH), exist_ok=True)
# Images are already in [0, 1] (see sample_condition). The old
# `normalized=True` kwarg is not a make_grid parameter. Depending on the
# torchvision version it was either rejected with TypeError or silently ignored.
save_image(images, OUT_PATH, nrow=GRID_NROW)

# To render the held-out combinations instead:
# torch.manual_seed(SEED)
# missing_images = torch.cat([sample_condition(m, args.num_sample_missing) for m in missings], dim=0)
# save_image(missing_images, "SampledImg/TripleCond/HeelSandalCA.png", nrow=GRID_NROW)

100%|██████████| 100/100 [00:09<00:00, 10.22it/s, step=1, sample=1]
100%|██████████| 100/100 [00:10<00:00,  9.75it/s, step=1, sample=1]
100%|██████████| 100/100 [00:09<00:00, 10.25it/s, step=1, sample=1]
100%|██████████| 100/100 [00:09<00:00, 10.00it/s, step=1, sample=1]
100%|██████████| 100/100 [00:09<00:00, 10.00it/s, step=1, sample=1]
100%|██████████| 100/100 [00:09<00:00, 10.41it/s, step=1, sample=1]
100%|██████████| 100/100 [00:09<00:00, 10.56it/s, step=1, sample=1]
100%|██████████| 100/100 [00:09<00:00, 10.58it/s, step=1, sample=1]
100%|██████████| 100/100 [00:09<00:00, 10.38it/s, step=1, sample=1]
100%|██████████| 100/100 [00:09<00:00, 10.45it/s, step=1, sample=1]
100%|██████████| 100/100 [00:09<00:00, 10.43it/s, step=1, sample=1]

In [19]:
# Quick sanity check: echo every target condition, one per line.
for condition_name in targets:
    print(condition_name)

Right Flat Boot
Right Flat Shoe
Right Flat Slipper
Right Flat Sandal
Right Heel Boot
Right Heel Shoe
Right Heel Slipper
Left Flat Boot
Left Flat Shoe
Left Flat Slipper
Left Flat Sandal
Left Heel Boot
Left Heel Shoe
Left Heel Slipper
