In [None]:
# @title Cloning The Repository {"display-mode":"form"}
!wget https://github.com/SanshruthR/VQGAN-CLIP/raw/refs/heads/master/content.zip

In [None]:
# @title Unzipping contents {"display-mode":"form"}
!unzip -o ./content.zip

In [None]:
# @title Resolving Dependencies {"display-mode":"form"}
%%capture
!pip install --no-deps ftfy regex tqdm
!pip install omegaconf pytorch-lightning
!pip uninstall torchtext --yes
!pip install einops

In [None]:
# @title Library Imports {"display-mode":"form"}
#import libraries
import numpy as np
import torch,os,imageio,pdb,math
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
import PIL
import matplotlib.pyplot as plt
import yaml
from omegaconf import OmegaConf
from CLIP import clip
import os
os.chdir('./taming-transformers')
from taming.models.vqgan import VQModel
os.chdir('..')
from PIL import Image
import cv2
import os
import imageio

In [None]:
# @title Start VQGAN-CLIP with Text Prompts Configuration {"display-mode":"form"}
# Device configuration
# Use the GPU when one is available; the models and inputs below are placed on `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# @markdown ## Text Prompts Configuration
# @markdown Enter your prompts below:

# @markdown ### Include Prompts (comma-separated)
# Each include prompt becomes one optimization stage; stages run in order,
# continuing from the previous result, which produces the morphing video.
include_text = "desert, heavy rain, cactus" # @param {type:"string"}

# @markdown ### Exclude Prompts (comma-separated)
# Prompts whose CLIP similarity to the generated image is penalized in the loss.
exclude_text = "confusing, blurry" # @param {type:"string"}

# @markdown ### Extra Style Prompts (comma-separated)
# Style prompts blended (with weight w2) into every include prompt's target embedding.
extras_text = "desert, clear, detailed, beautiful, good shape, detailed" # @param {type:"string"}


# Weight of the current include prompt's embedding in the target text embedding.
w1 = 1.0
# Weight of the extra style prompt embedding(s) in the target text embedding.
w2 = 0.9

def create_video(image_folder='./generated', video_name='morphing_video.mp4', fps=10):
    """Assemble the image frames in ``image_folder`` into a video file.

    Frames are the ``.png``/``.jpg`` files in the folder, in lexicographic
    filename order (the zero-padded names written by ``main`` therefore play
    back in the order they were generated).

    Args:
        image_folder: Directory containing the frames.
        video_name: Path of the video file to write.
        fps: Frames per second of the output video.

    Returns:
        ``video_name`` on success, or ``None`` when the folder is missing or
        contains no frames.
    """
    if not os.path.isdir(image_folder):
        print(f"Image folder '{image_folder}' does not exist.")
        return None

    images = sorted(img for img in os.listdir(image_folder) if img.endswith((".png", ".jpg")))
    if not images:
        # Return instead of calling exit(): inside a notebook exit() stops the kernel.
        print("No images found in the folder.")
        return None

    # imageio.v2 hosts the non-deprecated v2 API (imageio.imread emits a
    # DeprecationWarning). Imported here so the guards above work without it.
    import imageio.v2 as imageio_v2

    # The context manager finalizes the file even if reading a frame fails.
    with imageio_v2.get_writer(video_name, fps=fps) as video_writer:
        for image in images:
            video_writer.append_data(imageio_v2.imread(os.path.join(image_folder, image)))

    print(f"Video saved as {video_name}")
    return video_name

def save_from_tensors(tensor, output_dir, filename):
    """Save a 3-D ``(C, H, W)`` image tensor as an image file.

    Args:
        tensor: Image tensor laid out channel-first (it is transposed to
            ``(H, W, C)`` for PIL). Values are expected in [0, 1]; anything
            outside is clamped.
        output_dir: Directory to write into; created if it does not exist.
        filename: File name inside ``output_dir``; PIL picks the format from
            its extension.
    """
    # detach() allows tensors that still carry autograd history. Round to
    # nearest when quantizing to uint8 (as torchvision.utils.save_image does)
    # instead of truncating, which biased every pixel downward.
    img = tensor.detach().clamp(0, 1).mul(255).round().byte()
    img = img.cpu().numpy().transpose((1, 2, 0))
    os.makedirs(output_dir, exist_ok=True)
    Image.fromarray(img).save(os.path.join(output_dir, filename))

def norm_data(data):
    """Rescale values from [-1, 1] to [0, 1], clipping anything outside [-1, 1] first.

    Works on any array-like exposing ``.clip`` (torch tensors, numpy arrays).
    """
    clipped = data.clip(-1, 1)
    shifted = clipped + 1
    # Multiplying by 0.5 is exactly dividing by 2 in floating point.
    return shifted * 0.5

def setup_clip_model(name='ViT-B/32'):
    """Load a CLIP model in eval mode on ``device`` with its weights frozen.

    Args:
        name: CLIP architecture passed to ``clip.load`` (default ``'ViT-B/32'``,
            the previously hard-coded choice).

    Returns:
        The CLIP model. The image preprocessing transform returned by
        ``clip.load`` is discarded.
    """
    model, _ = clip.load(name, device=device, jit=False)
    # Only the VQGAN latents are optimized. Freezing CLIP stops backward()
    # from computing and accumulating .grad buffers for every CLIP weight;
    # gradients still flow through the model to its input.
    model.eval().requires_grad_(False)
    return model.to(device)

def setup_vqgan_model(config_path, checkpoint_path):
    """Build a taming-transformers ``VQModel`` and load its checkpoint weights.

    Args:
        config_path: OmegaConf YAML whose ``model.params`` are passed to ``VQModel``.
        checkpoint_path: Checkpoint file whose ``"state_dict"`` entry holds the weights.

    Returns:
        The model in eval mode on ``device``, with all weights frozen.
    """
    config = OmegaConf.load(config_path)
    model = VQModel(**config.model.params)
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints from a trusted source. It is
    # passed explicitly because PyTorch >= 2.6 defaults to weights_only=True,
    # which may reject full training checkpoints, and earlier versions emit
    # a FutureWarning when it is omitted.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    # strict=False: keys missing from / extra to the model are ignored silently.
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    # Only the latents are optimized. Freezing the VQGAN stops backward() from
    # filling .grad buffers for every weight; gradients still reach the input.
    model.requires_grad_(False)
    return model.eval().to(device)

def generator(x, model):
    """Decode latents ``x`` with ``model``: ``post_quant_conv`` followed by ``decoder``."""
    return model.decoder(model.post_quant_conv(x))

def encode_text(text, clip_model):
    """Return the CLIP text embedding of ``text``, with no autograd history.

    Args:
        text: Prompt passed to ``clip.tokenize``.
        clip_model: CLIP model whose ``encode_text`` produces the embedding.

    Returns:
        The embedding tensor produced by ``clip_model.encode_text``.
    """
    tokens = clip.tokenize(text).to(device)
    # Text embeddings are fixed optimization targets. Running the encoder
    # under no_grad avoids building an autograd graph that would only be
    # discarded by detach().
    with torch.no_grad():
        return clip_model.encode_text(tokens)

def create_encoding(include, exclude, extras, clip_model):
    """Encode three lists of prompt strings with CLIP.

    Returns:
        A 3-tuple ``(include_enc, exclude_enc, extras_enc)``. Each element is
        a list with one ``encode_text`` result per input string, in input order.
    """
    return tuple(
        [encode_text(prompt, clip_model) for prompt in group]
        for group in (include, exclude, extras)
    )

def create_crops(img, num_crops=32, size1=225, noise_factor=0.05):
    """Build a batch of randomly augmented, noisy 224x224 crops of ``img`` for CLIP.

    The image is zero-padded by ``size1 // 2`` on each side, randomly flipped
    and affine-transformed. Then ``num_crops`` square crops are taken at
    random positions and resized to 224x224. Each crop's side is
    ``N(1.2, 0.3)`` times ``size1``, clipped to ``[0.43, 1.9] * size1``, and
    is never larger than the padded image. Finally every crop gets Gaussian
    noise scaled by ``noise_factor`` times a per-crop uniform random factor.

    Args:
        img: 4-D image batch ``(N, C, H, W)`` (callers pass values in [0, 1]).
        num_crops: Number of crops to take.
        size1: Reference size controlling padding and crop scale.
        noise_factor: Maximum standard deviation of the added noise.

    Returns:
        Tensor of shape ``(num_crops * N, C, 224, 224)``.
    """
    aug_transform = torch.nn.Sequential(
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomAffine(30, translate=(0.1, 0.1), fill=0)
    ).to(img.device)

    p = size1 // 2
    img = torch.nn.functional.pad(img, (p, p, p, p), mode='constant', value=0)
    img = aug_transform(img)
    height, width = img.shape[-2], img.shape[-1]

    crop_set = []
    for _ in range(num_crops):
        gap1 = int(torch.normal(1.2, .3, ()).clip(.43, 1.9) * size1)
        # A crop can be neither empty nor larger than the padded image.
        gap1 = max(1, min(gap1, height, width))
        # Draw offsets over the padded image's real extent on each axis. The
        # old bound (2 * size1 - gap1) assumed a square image of side ~size1,
        # so for wider frames (main() uses 400 px) the right-hand part of the
        # image was never cropped.
        offsetx = torch.randint(0, height - gap1 + 1, ())
        offsety = torch.randint(0, width - gap1 + 1, ())
        crop = img[:, :, offsetx:offsetx + gap1, offsety:offsety + gap1]
        crop = torch.nn.functional.interpolate(crop, (224, 224), mode='bilinear', align_corners=True)
        crop_set.append(crop)

    img_crops = torch.cat(crop_set, 0)
    randnormal = torch.randn_like(img_crops, requires_grad=False)
    # One uniform noise scale per crop, created directly on the crops' device.
    randstotal = torch.rand((img_crops.shape[0], 1, 1, 1), device=img_crops.device)
    img_crops = img_crops + noise_factor * randstotal * randnormal

    return img_crops

def optimize_result(params, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc,
                    alpha=1.0, beta=0.5):
    """Compute the CLIP-guided loss for the current VQGAN latents.

    The latents are decoded, mapped to [0, 1], cut into augmented crops,
    normalized with CLIP's image mean/std and encoded by CLIP. The loss
    rewards cosine similarity to ``w1 * prompt + w2 * mean(extras_enc)`` and
    penalizes the mean cosine similarity to every embedding in ``exclude_enc``.

    Args:
        params: VQGAN latents being optimized.
        prompt: CLIP text embedding of the current include prompt.
        vqgan_model: VQGAN used to decode ``params``.
        clip_model: CLIP model used to encode the crops.
        w1: Weight of ``prompt`` in the target embedding.
        w2: Weight of the averaged extra-style embeddings in the target.
        extras_enc: List of CLIP text embeddings blended into the target; may be empty.
        exclude_enc: List of CLIP text embeddings to steer away from; may be empty.
        alpha: Weight of the include-similarity term.
        beta: Weight of the exclude-similarity penalty.

    Returns:
        Scalar loss tensor ``-alpha * mean(sim_include) + beta * mean(sim_exclude)``.
    """
    out = generator(params, vqgan_model)
    out = norm_data(out)
    out = create_crops(out)
    out = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                           (0.26862954, 0.26130258, 0.27577711))(out)
    img_enc = clip_model.encode_image(out)

    # Use every extra style prompt, not just the first one. Averaging keeps
    # the target's scale independent of how many extras were given, and is
    # identical to the old behaviour when exactly one extra is given.
    # Embeddings are concatenated along dim 0 (assumed (1, dim) each, as
    # returned by encode_text for a single string).
    final_enc = w1 * prompt
    if extras_enc:
        final_enc = final_enc + w2 * torch.cat(extras_enc).mean(dim=0, keepdim=True)
    # cosine_similarity is scale-invariant, so no explicit normalization is needed.
    main_loss = torch.cosine_similarity(final_enc, img_enc, dim=-1)
    loss = -alpha * main_loss.mean()

    if exclude_enc:
        exclude = torch.cat(exclude_enc)
        # Pairwise similarity of every exclude prompt with every crop:
        # (num_exclude, 1, dim) vs (1, num_crops, dim) -> (num_exclude, num_crops).
        penalize_loss = torch.cosine_similarity(exclude[:, None, :], img_enc[None, :, :], dim=-1)
        loss = loss + beta * penalize_loss.mean()

    return loss

def optimize(params, optimizer, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc):
    """Perform one gradient step on ``params``.

    Clears stale gradients, evaluates ``optimize_result`` with the given
    arguments, back-propagates it and applies ``optimizer.step()``.

    Returns:
        The loss tensor computed for this step (i.e. before the update).
    """
    # Clearing gradients before the forward pass is equivalent to clearing
    # them just before backward(): the forward pass never reads .grad.
    optimizer.zero_grad()
    step_loss = optimize_result(params, prompt, vqgan_model, clip_model,
                                w1, w2, extras_enc, exclude_enc)
    step_loss.backward()
    optimizer.step()
    return step_loss

def training_loop(params, optimizer, include_enc, exclude_enc, extras_enc, vqgan_model, clip_model, w1, w2,
                  total_iter=200, show_step=1):
    """Optimize ``params`` toward each include prompt in turn, recording snapshots.

    For every embedding in ``include_enc``, ``total_iter`` optimizer steps are
    run. Each prompt continues from the latents and optimizer state left by
    the previous one, which produces the morphing effect. Every
    ``show_step`` iterations the decoded image and a copy of the latents are
    recorded and the loss is printed.

    Returns:
        ``(res_img, res_z)``: lists of CPU tensors holding the decoded images
        (values in [0, 1]) and the latent snapshots, in recording order.

    Raises:
        ValueError: If ``show_step`` is smaller than 1.
    """
    if show_step < 1:
        raise ValueError(f"show_step must be >= 1, got {show_step}")

    res_img = []
    res_z = []

    for prompt in include_enc:
        for it in range(total_iter):
            loss = optimize(params, optimizer, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc)

            if it % show_step == 0:
                with torch.no_grad():
                    new_img = norm_data(generator(params, vqgan_model)[0])
                # Keep snapshots on the CPU. With one frame per step for every
                # prompt, keeping them on the GPU keeps growing device memory.
                res_img.append(new_img.cpu())
                res_z.append(params.detach().to("cpu", copy=True))
                print(f"loss: {loss.item():.4f}\nno. of iteration: {it}")

        torch.cuda.empty_cache()
    return res_img, res_z

def main():
    """Run the VQGAN-CLIP pipeline configured by the form fields.

    The pipeline:
    - parses the comma-separated prompt strings;
    - loads CLIP and the VQGAN;
    - encodes ``./gradient1.png`` as the starting latents, blended 60/40 with
      Gaussian noise;
    - optimizes toward each include prompt in turn;
    - writes every recorded frame to ``generated/`` and assembles them into
      a video.

    Raises:
        ValueError: If ``include_text`` contains no non-empty prompt.
    """
    # Process the input prompts. Empty include entries (e.g. from a trailing
    # comma) would add a pointless morph stage toward "", so drop them.
    include = [x.strip() for x in include_text.split(',') if x.strip()]
    if not include:
        raise ValueError("include_text must contain at least one non-empty prompt.")
    exclude = [x.strip() for x in exclude_text.split(',')]
    extras = [x.strip() for x in extras_text.split(',')]

    # Setup models
    clip_model = setup_clip_model()
    vqgan_model = setup_vqgan_model("./models/vqgan_imagenet_f16_16384/configs/model.yaml",
                                    "./models/vqgan_imagenet_f16_16384/checkpoints/last.ckpt")

    # Optimization hyper-parameters and start-image size (height, width).
    learning_rate = 0.1
    wd = 0.1
    size1, size2 = 225, 400

    # Load the start image. convert('RGB') guarantees 3 channels (PNGs may be
    # RGBA or grayscale), and the with-block closes the file.
    with PIL.Image.open('./gradient1.png') as source_image:
        initial_image = source_image.convert('RGB')
    # PIL's resize takes (width, height).
    initial_image = initial_image.resize((size2, size1))
    initial_image = torchvision.transforms.ToTensor()(initial_image).unsqueeze(0).to(device)
    # ToTensor yields [0, 1], but the VQGAN works in [-1, 1] (norm_data maps
    # its output back to [0, 1]), so rescale before encoding.
    initial_image = initial_image * 2 - 1

    with torch.no_grad():
        z, _, _ = vqgan_model.encode(initial_image)
        # Start from 60% encoded image + 40% Gaussian noise.
        z = z * 0.6 + torch.randn_like(z) * 0.4

    # Wrap a tensor that is already on `device`, so the Parameter stays a leaf
    # the optimizer can update.
    params = torch.nn.Parameter(z.to(device))
    optimizer = torch.optim.AdamW([params], lr=learning_rate, weight_decay=wd)

    # Encode prompts
    include_enc, exclude_enc, extras_enc = create_encoding(include, exclude, extras, clip_model)

    # Run training loop
    res_img, res_z = training_loop(params, optimizer, include_enc, exclude_enc, extras_enc,
                                   vqgan_model, clip_model, w1, w2)

    # Save results. First delete frames from a previous run; otherwise a
    # shorter run leaves stale frames that create_video() would append.
    output_dir = "generated"
    os.makedirs(output_dir, exist_ok=True)
    for old_name in os.listdir(output_dir):
        if old_name.startswith("generated_image_") and old_name.endswith(".png"):
            os.remove(os.path.join(output_dir, old_name))
    for i, img in enumerate(res_img):
        save_from_tensors(img, output_dir, f"generated_image_{i:03d}.png")

    print(f"Generated images saved in the '{output_dir}' directory.")
    create_video()

if __name__ == "__main__":
    main()

# Display the video, but only if one was actually written. create_video()
# skips writing when there are no frames, and a missing file would not
# display a usable video.
from IPython.display import display, Video
if os.path.exists("morphing_video.mp4"):
    display(Video("morphing_video.mp4", embed=True))
else:
    print("morphing_video.mp4 was not created; nothing to display.")

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 128MiB/s]


Working with z of shape (1, 256, 16, 16) = 65536 dimensions.


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:04<00:00, 122MB/s] 


Downloading vgg_lpips model from https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1 to taming/modules/autoencoder/lpips/vgg.pth


8.19kB [00:00, 2.76MB/s]                   
  self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)


loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.


  state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]


loss: -0.1498
no. of iteration: 0
loss: -0.1558
no. of iteration: 1
loss: -0.1638
no. of iteration: 2
loss: -0.1582
no. of iteration: 3
loss: -0.1624
no. of iteration: 4
loss: -0.1570
no. of iteration: 5
loss: -0.1663
no. of iteration: 6
loss: -0.1665
no. of iteration: 7
loss: -0.1670
no. of iteration: 8
loss: -0.1658
no. of iteration: 9
loss: -0.1719
no. of iteration: 10
loss: -0.1749
no. of iteration: 11
loss: -0.1737
no. of iteration: 12
loss: -0.1731
no. of iteration: 13
loss: -0.1801
no. of iteration: 14
loss: -0.1838
no. of iteration: 15
loss: -0.1755
no. of iteration: 16
loss: -0.1885
no. of iteration: 17
loss: -0.1902
no. of iteration: 18
loss: -0.1852
no. of iteration: 19
loss: -0.1812
no. of iteration: 20
loss: -0.1821
no. of iteration: 21
loss: -0.1851
no. of iteration: 22
loss: -0.1860
no. of iteration: 23
loss: -0.1750
no. of iteration: 24
loss: -0.1875
no. of iteration: 25
loss: -0.1943
no. of iteration: 26
loss: -0.1921
no. of iteration: 27
loss: -0.1909
no. of iteration

  img = imageio.imread(img_path)


Video saved as morphing_video.mp4
