In [19]:
import argparse
import torch
from utils import get_model, LoadEncoder
from models.engine import DDIMSampler, DDIMSamplerEncoder
from torchvision.utils import save_image
from collections import OrderedDict

In [21]:
class Args(argparse.Namespace):
    """Fixed hyper-parameters for this sampling notebook, standing in for parsed CLI args.

    Defaults are declared as class attributes (so ``Args.w`` and similar keep
    working) and are deep-copied onto each instance in ``__init__``. Without
    that copy the values exist only on the class, so ``vars(args)``,
    ``repr(args)`` and ``"w" in args`` all see an empty namespace, and mutable
    defaults such as ``channel_mult`` would be shared by every instance.

    Keyword arguments override individual defaults, e.g. ``Args(w=3.0)``.
    """

    # Most of these are presumably consumed by get_model(args) — the exact
    # usage is not visible here; TODO confirm against utils.get_model.
    arch = "unetattention"
    img_size = 64
    num_timestep = 1000          # passed to the sampler as T
    beta = (0.0001, 0.02)        # passed to the sampler as beta
    num_condition = [2, 4]
    emb_size = 128
    channel_mult = [1, 2, 2, 2]
    num_res_blocks = 2
    use_spatial_transformer = True
    num_heads = 4
    num_sample = 100             # batch size of the sampling cell
    w = 1.8                      # passed to the sampler as w
    projection_dim = 512
    only_table = False
    concat = False
    only_encoder = False         # forwarded to DDIMSamplerEncoder
    num_head_channels = -1
    encoder_path = None          # when not None, an encoder is loaded via LoadEncoder(args)

    def __init__(self, **overrides):
        import copy

        # Gather every public, non-callable class attribute (subclass additions
        # included) and deep-copy it so instances never share mutable lists.
        cls = type(self)
        defaults = {
            name: copy.deepcopy(getattr(cls, name))
            for name in dir(cls)
            if not name.startswith("_") and not callable(getattr(cls, name))
        }
        defaults.update(overrides)
        super().__init__(**defaults)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = Args()

In [23]:
# Trained checkpoint to sample from.
CKPT_PATH = "checkpoints/Zappo50K/HeelSliAdaCA/model_100.pth"
# nn.DataParallel / DistributedDataParallel prefix parameter names with this
# when a wrapped model is saved; the checkpoint is presumably saved that way.
WRAPPER_PREFIX = "module."

model = get_model(args)

# map_location="cpu" lets a checkpoint saved from a GPU load on any machine;
# the model is moved to `device` further down. torch.load unpickles the file,
# so only load checkpoints from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location="cpu")["model"]

# Strip the wrapper prefix. Matching "module." (with the dot) avoids chopping
# characters off unrelated keys that merely start with "module".
new_dict = OrderedDict(
    (k[len(WRAPPER_PREFIX):] if k.startswith(WRAPPER_PREFIX) else k, v)
    for k, v in ckpt.items()
)

# strict=False returns the missing/unexpected keys instead of raising, so they
# can be listed. A tensor shape mismatch still raises RuntimeError on purpose:
# sampling from a checkpoint of a different architecture is meaningless.
incompatible = model.load_state_dict(new_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print("some keys do not match the model!")
    print("  missing keys:   ", incompatible.missing_keys)
    print("  unexpected keys:", incompatible.unexpected_keys)
else:
    print("All keys successfully match")

# Inference only: freeze all parameters and switch to eval mode.
model.requires_grad_(False)
model.eval()
model.to(device)

# Encoder-conditioned sampler when an encoder is configured, plain DDIM otherwise.
if args.encoder_path is not None:
    encoder = LoadEncoder(args).to(device)
    sampler = DDIMSamplerEncoder(
        model=model,
        encoder=encoder,
        beta=args.beta,
        T=args.num_timestep,
        w=args.w,
        only_encoder=args.only_encoder,
    ).to(device)
else:
    sampler = DDIMSampler(
        model=model,
        beta=args.beta,
        T=args.num_timestep,
        w=args.w,
    ).to(device)

All keys successfully match


In [25]:
from pathlib import Path

from config import Zappo50K, toy_dataset

# Sampling settings for this run.
target = "Heel Slipper"    # first word = attribute, last word = object
SEED = 0                   # fixes the initial noise (and, presumably, any sampler randomness)
NUM_STEPS = 100            # number of DDIM steps passed to the sampler
OUT_PATH = Path("SampledImg/HeelSliCA.png")

CFG = Zappo50K()
words = target.split(" ")
atr_idx = CFG.ATR2IDX[words[0]]
obj_idx = CFG.OBJ2IDX[words[-1]]

# The same (attribute, object) condition for every sample in the batch.
atr = torch.full((args.num_sample,), atr_idx, dtype=torch.long, device=device)
obj = torch.full((args.num_sample,), obj_idx, dtype=torch.long, device=device)

torch.manual_seed(SEED)
# Noise is drawn on the CPU generator and then moved, so a given seed yields the
# same starting noise on CPU-only and GPU machines. Spatial size follows
# args.img_size rather than a hardcoded 64.
x_i = torch.randn(args.num_sample, 3, args.img_size, args.img_size).to(device)
x0 = sampler(x_i, atr, obj, steps=NUM_STEPS)

# NOTE(review): save_image expects values in [0, 1]; confirm the sampler's output range.
OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
save_image(x0, OUT_PATH)

100%|██████████| 100/100 [01:44<00:00,  1.04s/it, step=1, sample=1]


torch.Size([100])