In [None]:
"""make variations of input image"""

import argparse, os, sys, glob
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import PIL
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
from torch import autocast
from contextlib import nullcontext
import time
from pytorch_lightning import seed_everything
import torch.nn.functional as F
sys.path.append(os.path.dirname(sys.path[0]))
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from models.encoders.model_irse import Backbone
from transformers import CLIPProcessor, CLIPModel
from torchvision.transforms import Compose, ToTensor, Normalize
from facenet_pytorch import InceptionResnetV1,MTCNN
# Image preprocessing pipeline: tensor conversion followed by per-channel
# normalization with mean 0.5 and std 0.5 on each of the three channels.
# NOTE(review): the name is misspelled ("transfroms"); kept as-is because other
# cells may reference it.
transfroms = Compose([
    ToTensor(),
    Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def chunk(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    Returns an iterator that yields tuples; the last tuple may be shorter.
    Iteration stops as soon as an empty tuple would be produced.
    """
    source = iter(it)

    def next_group():
        # Pull up to `size` items; an exhausted source gives the () sentinel.
        return tuple(islice(source, size))

    return iter(next_group, ())


def load_model_from_config(config, ckpt, verbose=False):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: configuration object exposing a ``model`` entry that is handed
            to ``instantiate_from_config``.
        ckpt: path to a checkpoint file containing a ``"state_dict"`` entry.
        verbose: when true, print the missing / unexpected state-dict keys.

    Returns:
        The instantiated model, moved to the module-level ``device`` and put
        in eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: tolerate keys that exist on only one side and report them.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose:
        for heading, keys in (("missing keys:", missing), ("unexpected keys:", unexpected)):
            if len(keys) > 0:
                print(heading)
                print(keys)

    model.to(device)
    model.eval()
    return model


def load_img(path, size=(512, 512)):
    """Load an image file as a normalized ``(1, 3, H, W)`` float tensor.

    The image is converted to RGB, resized to ``size`` (aspect ratio is NOT
    preserved), scaled to ``[0, 1]`` and then mapped to ``[-1, 1]``.

    Args:
        path: path of the image file to open.
        size: ``(width, height)`` passed to ``Image.resize``. Defaults to
            ``(512, 512)``, which is the size this function always used.

    Returns:
        ``torch.Tensor`` of dtype float32 with a leading batch axis of 1 and
        channels-first layout.
    """
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")
    # The original rounded (w, h) down to a multiple of 32 and then ignored the
    # result; the target size is what actually determines the output.
    image = image.resize(size, resample=PIL.Image.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0
    # HWC -> 1CHW
    pixels = pixels[None].transpose(0, 3, 1, 2)
    return 2. * torch.from_numpy(pixels) - 1.

In [None]:
# Locations of the inference config and the Stable Diffusion checkpoint.
config = "configs/stable-diffusion/v1-inference.yaml"
ckpt = "./models/sd/sd-v1-4.ckpt"

# `config` is rebound from the path string to the loaded OmegaConf object.
config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt)

# Module-level sampler used by the sampling / decoding helpers defined below.
sampler = DDIMSampler(model)

In [None]:
def sample_reverse(S,
            batch_size,
            shape,
            conditioning=None,
            callback=None,
            normals_sequence=None,
            img_callback=None,
            quantize_x0=False,
            eta=0.,
            mask=None,
            x0=None,
            temperature=1.,
            noise_dropout=0.,
            score_corrector=None,
            corrector_kwargs=None,
            verbose=True,
            x_T=None,
            log_every_t=100,
            unconditional_guidance_scale=1.,
            unconditional_conditioning=None,
            # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
            **kwargs
            ):
    """Set up the DDIM schedule on the module-level ``sampler`` and run
    ``ddim_reverse_sampling``.

    Args:
        S: number of DDIM steps passed to ``sampler.make_schedule``.
        batch_size: batch size; combined with ``shape`` to form the sample size.
        shape: ``(C, H, W)`` of one sample.
        conditioning: either a dict whose first value has a ``.shape``, or an
            indexable sequence whose first element has a ``.shape``. Only used
            here for the batch-size sanity check; it is forwarded unchanged.
        eta: DDIM eta passed to ``sampler.make_schedule``.
        x_T: optional starting tensor; forwarded to ``ddim_reverse_sampling``.
        normals_sequence, **kwargs: accepted for signature compatibility and
            not used by this function.
        Remaining arguments are forwarded verbatim to ``ddim_reverse_sampling``.

    Returns:
        ``(samples, intermediates)`` as returned by ``ddim_reverse_sampling``.
    """
    if conditioning is not None:
        if isinstance(conditioning, dict):
            cbs = conditioning[list(conditioning.keys())[0]].shape[0]
        else:
            # Sequence of per-bucket conditionings: inspect the first entry.
            # (The original formatted `conditioning.shape[0]` in the message,
            # which raises AttributeError for a list.)
            cbs = conditioning[0].shape[0]
        if cbs != batch_size:
            print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

    sampler.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
    # sampling
    C, H, W = shape
    size = (batch_size, C, H, W)
    print(f'Data shape for DDIM sampling is {size}, eta {eta}')

    samples, intermediates = ddim_reverse_sampling(conditioning, size,
                                                callback=callback,
                                                img_callback=img_callback,
                                                quantize_denoised=quantize_x0,
                                                mask=mask, x0=x0,
                                                ddim_use_original_steps=False,
                                                noise_dropout=noise_dropout,
                                                temperature=temperature,
                                                score_corrector=score_corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                x_T=x_T,
                                                log_every_t=log_every_t,
                                                unconditional_guidance_scale=unconditional_guidance_scale,
                                                unconditional_conditioning=unconditional_conditioning,
                                                )
    return samples, intermediates

@torch.no_grad()
def ddim_reverse_sampling(cond, shape,
                    x_T=None, ddim_use_original_steps=False,
                    callback=None, timesteps=None, quantize_denoised=False,
                    mask=None, x0=None, img_callback=None, log_every_t=100,
                    temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                    unconditional_guidance_scale=1., unconditional_conditioning=None,):
    """Walk the DDIM timesteps in ascending order, calling
    ``sampler.p_sample_ddim_reverse`` at every step.

    Args:
        cond: indexable collection of conditionings. Entry ``step * 10 // 1000``
            is used at each step, so it needs one entry per tenth of a
            1000-step range.
        shape: full sample shape; ``shape[0]`` is the batch size.
        x_T: starting tensor; random normal noise of ``shape`` when ``None``.
        ddim_use_original_steps: iterate over ``range(timesteps)`` instead of
            ``sampler.ddim_timesteps``.
        timesteps: optional cap used to take a prefix of ``sampler.ddim_timesteps``.
        mask, x0: when ``mask`` is given, ``x0`` is required and the current
            image is blended with ``q_sample(x0, ts)`` before each step.
        callback, img_callback: optional per-step hooks.
        log_every_t: store intermediates every this many steps (and at the end).
        unconditional_guidance_scale, unconditional_conditioning: accepted for
            signature compatibility but NOT forwarded; each step is run with
            scale 1.0 and the step's own conditioning as the unconditional one.

    Returns:
        ``(img, intermediates)`` where ``intermediates`` has the keys
        ``'x_inter'`` and ``'pred_x0'``.
    """
    device = sampler.model.betas.device
    b = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=device)
    else:
        img = x_T

    if timesteps is None:
        timesteps = sampler.ddpm_num_timesteps if ddim_use_original_steps else sampler.ddim_timesteps
    elif timesteps is not None and not ddim_use_original_steps:
        subset_end = int(min(timesteps / sampler.ddim_timesteps.shape[0], 1) * sampler.ddim_timesteps.shape[0]) - 1
        timesteps = sampler.ddim_timesteps[:subset_end]

    intermediates = {'x_inter': [img], 'pred_x0': [img]}
    time_range = (range(0,timesteps)) if ddim_use_original_steps else (timesteps)
    total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
    print(f"Running DDIM reverse Sampling with {total_steps} timesteps")

    iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
    #input_noise = torch.randn(img.shape, device=device)
    for i, step in enumerate(iterator):
        index = i
        ts = torch.full((b,), step, device=device, dtype=torch.long)
        # Conditioning bucket for this timestep. Computed from the scalar
        # `step` rather than the batch tensor `ts`, so it also works when the
        # batch size is larger than one (int() on a multi-element tensor raises).
        norm_t = int(step) * 10 // 1000
        if mask is not None:
            assert x0 is not None
            img_orig = sampler.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
            img = img_orig * mask + (1. - mask) * img
        #seed_everything(42)
        # The same conditioning is passed as conditional and unconditional
        # input, with a guidance scale fixed at 1.0.
        outs = sampler.p_sample_ddim_reverse(img, cond[norm_t], ts, index=index, use_original_steps=ddim_use_original_steps,
                                    quantize_denoised=quantize_denoised, temperature=temperature,
                                    noise_dropout=noise_dropout, score_corrector=score_corrector,
                                    corrector_kwargs=corrector_kwargs,
                                    unconditional_guidance_scale=1.0,
                                    unconditional_conditioning=cond[norm_t])
        img, pred_x0 = outs
        if callback: callback(i)
        if img_callback: img_callback(pred_x0, i)

        if index % log_every_t == 0 or index == total_steps - 1:
            intermediates['x_inter'].append(img)
            intermediates['pred_x0'].append(pred_x0)

    return img, intermediates



In [None]:
def decode( x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
            use_original_steps=False, input_noise = None,initial_img = None,resnet=None,guidance=None,loss_guidance_scale=0):
    """Denoise ``x_latent`` from step ``t_start`` down to the first timestep
    using ``sampler.p_sample_ddim2``.

    Args:
        x_latent: latent to start from; its first axis is the batch size.
        cond: indexable collection of conditionings; entry ``step * 10 // 1000``
            is used for steps above 600.
        t_start: number of leading timesteps to keep (they are then walked in
            descending order).
        unconditional_guidance_scale: forwarded to ``p_sample_ddim2``.
        unconditional_conditioning: used as the conditioning for steps <= 600.
        use_original_steps: use ``np.arange(1000)`` instead of
            ``sampler.ddim_timesteps``.
        input_noise, initial_img, resnet, guidance: forwarded unchanged to
            ``p_sample_ddim2``.
        loss_guidance_scale: NOTE(review) the passed value is never used; it
            is overwritten on every step with 1.5 (step > 600) or 2.5
            (otherwise). Confirm whether that is intended.

    Returns:
        The latent after the last denoising step.
    """

    timesteps = np.arange(1000) if use_original_steps else sampler.ddim_timesteps
    timesteps = timesteps[:t_start]

    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]
    print(f"Running DDIM Sampling with {total_steps} timesteps")

    iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
    x_dec = x_latent
    for i, step in enumerate(iterator):
        index = total_steps - i - 1
        ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
        # Use the scalar timestep, not the batch tensor `ts`: int() on a
        # multi-element tensor raises, which broke any batch size above one.
        step_value = int(step)
        norm_t = step_value * 10 // 1000
        #norm_t = int(ts/87)
        if step_value > 600:
            cond2 = cond[norm_t]
            loss_guidance_scale = 1.5
        else:
            cond2 = unconditional_conditioning
            loss_guidance_scale = 2.5

        print(norm_t)
        # The selected conditioning is passed as both the conditional and the
        # unconditional input.
        x_dec, _ = sampler.p_sample_ddim2(x_dec, cond2, ts, index=index, use_original_steps=use_original_steps,
                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                        unconditional_conditioning=cond2,
                                        input_noise = input_noise,
                                        initial_img = initial_img,
                                        resnet = resnet,
                                        guidance=guidance,
                                        loss_guidance_scale=loss_guidance_scale)
    return x_dec

In [None]:
def main(prompt = '', content_dir = '',ddim_steps = 50,strength = 0.5, model = None, resnet = None,seed=42):
    """Invert a face image with DDIM, then decode a face-guided variation of it.

    Steps, in order: load and encode the content image, build ten
    conditionings for the prompt (and ten for the empty prompt), noise the
    latent to ``t_enc`` steps, run ``sample_reverse`` and save its result and
    the conditionings under ``./out``, run ``decode`` on the noised latent,
    and save the decoded image grid as a PNG.

    Args:
        prompt: text prompt passed to ``model.get_learned_conditioning``.
        content_dir: path of the input image file (despite the name, a file).
        ddim_steps: number of DDIM steps for the schedule.
        strength: fraction of ``ddim_steps`` to noise/denoise; must be in [0, 1].
        model: the diffusion model; must be provided.
        resnet: face network forwarded to ``decode``.
        seed: seed passed to ``seed_everything`` (at start and again right
            before the stochastic encode).

    Returns:
        ``PIL.Image.Image`` holding the saved output grid.

    Raises:
        ValueError: if MTCNN finds no face in the input image.
        AssertionError: if ``strength`` is outside [0, 1].
    """
    ddim_eta=0.0
    n_iter=1
    n_samples=1
    n_rows=0
    scale=10.0
    # One conditioning per tenth of the 1000-step range; the sampling helpers
    # index the list with `step * 10 // 1000`.
    n_cond=10

    precision="autocast"
    outdir="./out"
    seed_everything(seed)

    mtcnn = MTCNN(image_size=160)
    os.makedirs(outdir, exist_ok=True)
    outpath = outdir

    batch_size = n_samples
    n_rows = n_rows if n_rows > 0 else batch_size
    data = [batch_size * [prompt]]

    # The samples directory is still created (and counted by the listing
    # below), although no per-sample files are written into it.
    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    grid_count = len(os.listdir(outpath)) + 10

    content_name =  content_dir.split('/')[-1].split('.')[0]
    content_image = load_img(content_dir).to(device)
    content_image = repeat(content_image, '1 ... -> b ...', b=batch_size)

    # Aligned face crop used as the identity reference during decoding.
    # Convert to RGB (as load_img does) so grayscale / RGBA files are handled.
    image = Image.open(content_dir).convert("RGB")
    align_img = mtcnn(image)
    if align_img is None:
        # MTCNN yields None when it cannot find a face; fail with a clear
        # message instead of an AttributeError on `.unsqueeze`.
        raise ValueError(f"no face detected in {content_dir}")
    initial_image = align_img.unsqueeze(0).to(device)
    content_latent = model.get_first_stage_encoding(model.encode_first_stage(content_image))  # move to latent space

    init_latent = content_latent

    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)

    assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
    t_enc = int(strength * ddim_steps)
    print(f"target t_enc is {t_enc} steps")

    precision_scope = autocast if precision == "autocast" else nullcontext
    with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
        all_samples = list()
        for n in trange(n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):
                uc = None
                if scale != 1.0:
                    uc = [model.get_learned_conditioning(batch_size * [""], content_image, i)
                          for i in range(n_cond)]
                if isinstance(prompts, tuple):
                    prompts = list(prompts)
                c = [model.get_learned_conditioning(prompts, content_image, i)
                     for i in range(n_cond)]

                # Re-seed so the noise added by the encode is reproducible
                # regardless of what consumed random numbers above.
                seed_everything(seed)
                z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))

                # Inversion of the clean latent; the result and the
                # conditionings are saved to disk for later use.
                x_inversion,_ = sample_reverse(ddim_steps,1,(4,512,512),c,verbose=False, eta=0.,x_T = init_latent,
                                               unconditional_guidance_scale=scale,
                                               unconditional_conditioning=uc[0],)
                torch.save(x_inversion,os.path.join(outpath, content_name+'_zt.pt'))
                torch.save(c,os.path.join(outpath, content_name+'_embedding.pt'))
                del x_inversion

                samples = decode(z_enc, c, t_enc,
                                 unconditional_guidance_scale=scale,
                                 unconditional_conditioning=uc[0],initial_img=initial_image,resnet=resnet,
                                 guidance = True,loss_guidance_scale=1)
                print(z_enc.shape, uc[0].shape, t_enc)

                x_samples = model.decode_first_stage(samples)

                # [-1, 1] -> [0, 1]
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                all_samples.append(x_samples)

        # additionally, save as grid
        grid = torch.stack(all_samples, 0)
        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
        grid = make_grid(grid, nrow=n_rows)

        # to image
        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
        output = Image.fromarray(grid.astype(np.uint8))
        output.save(os.path.join(outpath, content_name+f'-{grid_count:04}.png'))
        grid_count += 1

    return output

In [None]:
# model.cpu()
# Load the learned identity embedding into the model's embedding manager.
model.embedding_manager.load(
    './logs/Alejandro_Toledo2023-07-23T15-38-22_Alejandro_Toledo/checkpoints/embeddings.pt'
)

# Face network handed to main() as `resnet`; moved to `device` in eval mode.
# Disabled alternative backbone, kept for reference:
#resnet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
#resnet.load_state_dict(torch.load("/home/gpu/.cache/torch/hub/checkpoints/model_ir_se50.pth"))
facenet = InceptionResnetV1(pretrained='vggface2').to(device).eval()

model = model.to(device)

In [None]:
# Input image for this run (a single file, despite the variable name).
content_root = './dataset/Alejandro_Toledo_0037.jpg'

# '*' is the prompt passed through to the model's conditioning call.
# The returned image is the cell's last expression, so the notebook displays it.
main(
    prompt='*',
    content_dir=os.path.join(content_root),
    ddim_steps=100,
    strength=0.80,
    seed=100162,
    model=model,
    resnet=facenet,
)