In [1]:
import torch
from torchvision.transforms import Compose, Resize, CenterCrop
import numpy as np
import os
import yaml
from yaml.loader import SafeLoader
from model import _G
from util.visualization import visualize_occupancy
import matplotlib.pyplot as plt
from PIL import Image
import clip
import cv2


# Hyper-parameters for the voxel generator (e.g. "dim", "latent_len"),
# parsed with YAML's safe loader so no arbitrary Python objects are built.
with open('config.yaml') as config_file:
    config = yaml.safe_load(config_file)

if torch.cuda.is_available():
    print("using cuda")

using cuda


In [19]:
def plot_img(img):
    """Display the first image of a channels-first batch with matplotlib.

    ``img`` is taken off the autograd graph and moved to the CPU. It is then
    reordered from (N, C, H, W) to (N, H, W, C), the layout ``plt.imshow``
    expects, and only element 0 of the batch is shown.
    """
    channels_last = img.detach().cpu().permute(0, 2, 3, 1)
    first_image = channels_last.numpy()[0]
    plt.imshow(first_image)
    plt.show()
    
def generate_model(Z, threshold=0.5, generator=None):
    """Decode latent codes with the generator and render the first sample.

    Args:
        Z: latent batch fed directly to the generator; presumably shaped
            (nof_samples, config["latent_len"]) on the generator's device —
            TODO confirm against ``_G``.
        threshold: occupancy cutoff; voxels whose generator output is strictly
            greater than this are rendered as filled. Defaults to 0.5, the
            value used elsewhere in this notebook.
        generator: module used for decoding. Defaults to the notebook-global
            ``G`` so existing ``generate_model(Z)`` calls behave as before.

    Returns:
        None. The first sample is passed to ``visualize_occupancy`` for display.
    """
    if generator is None:
        generator = G
    with torch.no_grad():
        generation = generator(Z)
    # Under no_grad the output carries no graph, so no detach() is needed.
    occupancy = (generation.cpu().numpy() > threshold).astype(int)
    visualize_occupancy(occupancy[0].squeeze())

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel RGB normalisation constants used by CLIP's own preprocessing,
# reshaped to (3, 1, 1) so they broadcast over (N, 3, H, W) image batches.
mean = torch.as_tensor((0.48145466, 0.4578275, 0.40821073), dtype=torch.float, device=device).view(-1, 1, 1)
std = torch.as_tensor((0.26862954, 0.26130258, 0.27577711), dtype=torch.float, device=device).view(-1, 1, 1)

# Tensor-side resize/crop to CLIP's 224x224 input resolution.
transf = Compose([Resize(224, interpolation=Image.BICUBIC), CenterCrop(224)])

model, preprocess = clip.load("ViT-B/32", device=device)

G_path = os.path.join("checkpoints", "G_300.pth")

G = _G(config["dim"], config["latent_len"])

# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
G.load_state_dict(torch.load(G_path, map_location=device))
G.to(device)
G.eval()



_G(
  (layer1): Sequential(
    (0): ConvTranspose3d(256, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer2): Sequential(
    (0): ConvTranspose3d(256, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer3): Sequential(
    (0): ConvTranspose3d(128, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer4): Sequential(
    (0): ConvTranspose3d(64, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer5): Sequential(
    (0): ConvTranspose3d(32, 1, kernel_size=(4, 4, 4), stride

In [4]:
nof_samples = 1
Z = torch.Tensor(nof_samples, config["latent_len"]).normal_(0, 0.33).to(device)
# Pure inference: no_grad stops G from recording an autograd graph that the
# global `generation` would otherwise keep alive.
with torch.no_grad():
    generation = G(Z)
generation_cpu = (generation.cpu().numpy() > 0.5).astype(int)

In [5]:
# Drop the singleton axes of the first thresholded sample, then render it.
first_occupancy = generation_cpu[0].squeeze()
visualize_occupancy(first_occupancy)

Output()

In [129]:
# Draw a fresh latent code Z1 (normal, mean 0, std 0.33) and render its shape.
nof_samples = 1
latent_shape = (nof_samples, config["latent_len"])
Z1 = torch.Tensor(*latent_shape).normal_(mean=0, std=0.33).to(device)
generate_model(Z1)

Output()

In [131]:
# Draw a fresh latent code Z2 (normal, mean 0, std 0.33) and render its shape.
nof_samples = 1
latent_shape = (nof_samples, config["latent_len"])
Z2 = torch.Tensor(*latent_shape).normal_(mean=0, std=0.33).to(device)
generate_model(Z2)

Output()

In [132]:
# Draw a fresh latent code Z3 (normal, mean 0, std 0.33) and render its shape.
nof_samples = 1
latent_shape = (nof_samples, config["latent_len"])
Z3 = torch.Tensor(*latent_shape).normal_(mean=0, std=0.33).to(device)
generate_model(Z3)

Output()

In [136]:
# Latent arithmetic: shift Z3 by the offset between Z2 and Z1, then render.
Z_mix = Z3 + Z2 - Z1
generate_model(Z_mix)

Output()

In [154]:
# Draw the latent code Z (normal, mean 0, std 0.33) that later cells optimise.
nof_samples = 1
latent_shape = (nof_samples, config["latent_len"])
Z = torch.Tensor(*latent_shape).normal_(mean=0, std=0.33).to(device)
generate_model(Z)

Output()

In [157]:
nof_iterations = 50
prompt = "airliner"
learning_rate = 0.001
frames_dir = "images"

# Only the latent code Z is optimised. Freezing CLIP and the generator stops
# their parameters from accumulating unused .grad tensors every iteration.
model.requires_grad_(False)
G.requires_grad_(False)

# The text embedding is a fixed target. Computing it once under no_grad keeps
# each iteration's graph independent, so backward() needs no retain_graph.
with torch.no_grad():
    text_features = model.encode_text(clip.tokenize([prompt]).to(device))
target_features = text_features[:1]

os.makedirs(frames_dir, exist_ok=True)
# Zero-pad frame numbers so lexicographic filename order equals iteration order.
frame_digits = max(2, len(str(nof_iterations)))

# Z is the latent sampled in an earlier cell; it is optimised in place.
Z.requires_grad_()

optimizer = torch.optim.Adam([Z], lr=learning_rate)

for i in range(nof_iterations):
    optimizer.zero_grad()

    generation = G(Z)

    if i == 0 or i == nof_iterations - 1:
        generation_cpu = (generation.detach().cpu().numpy() > 0.5).astype(int)
        visualize_occupancy(generation_cpu[0].squeeze())

    # Max-project the voxel grid along each spatial axis. Assuming generation is
    # (N, 1, D, D, D) (TODO confirm), this yields three (N, 1, D, D)
    # silhouettes stacked along the batch axis.
    image_sides = torch.cat([generation.amax(dim=axis) for axis in (2, 3, 4)])

    # Repeat the single channel to RGB, resize/crop to 224x224, then apply
    # CLIP's normalisation (out-of-place to keep the autograd graph simple).
    rgb_sides = image_sides.expand(-1, 3, -1, -1)
    img = (transf(rgb_sides) - mean) / std

    image_features = model.encode_image(img)
    similarity = torch.nn.functional.cosine_similarity(image_features, target_features)

    print("Iteration:", i, "Similarity:", similarity.sum().item())

    # Maximise the summed similarity over all views by minimising its negative.
    (-similarity.sum()).backward()

    optimizer.step()

    # Save the first three views side by side as one (224, 672, 3) frame,
    # min-max rescaled to [0, 1]. A constant frame is written as black instead
    # of dividing by zero.
    img_print = torch.cat(tuple(img[:3]), dim=2).permute(1, 2, 0).detach().cpu().numpy()
    value_range = img_print.max() - img_print.min()
    if value_range > 0:
        img_print = (img_print - img_print.min()) / value_range
    else:
        img_print = np.zeros_like(img_print)

    plt.imsave(os.path.join(frames_dir, f"{i + 1:0{frame_digits}d}.jpg"), img_print)

Output()

Iteration: 0 Similarity: 0.68408203125
Iteration: 1 Similarity: 0.6845703125
Iteration: 2 Similarity: 0.6845703125
Iteration: 3 Similarity: 0.6845703125
Iteration: 4 Similarity: 0.6845703125
Iteration: 5 Similarity: 0.68408203125
Iteration: 6 Similarity: 0.6845703125
Iteration: 7 Similarity: 0.6845703125
Iteration: 8 Similarity: 0.68505859375
Iteration: 9 Similarity: 0.6845703125
Iteration: 10 Similarity: 0.6845703125
Iteration: 11 Similarity: 0.6845703125
Iteration: 12 Similarity: 0.6845703125
Iteration: 13 Similarity: 0.68505859375
Iteration: 14 Similarity: 0.68505859375
Iteration: 15 Similarity: 0.6845703125
Iteration: 16 Similarity: 0.68505859375
Iteration: 17 Similarity: 0.68505859375
Iteration: 18 Similarity: 0.68505859375
Iteration: 19 Similarity: 0.68505859375
Iteration: 20 Similarity: 0.68505859375
Iteration: 21 Similarity: 0.6845703125
Iteration: 22 Similarity: 0.68505859375
Iteration: 23 Similarity: 0.68505859375
Iteration: 24 Similarity: 0.685546875
Iteration: 25 Similarity

Output()

Iteration: 49 Similarity: 0.68505859375


In [143]:
image_folder = 'images'
video_name = 'video.mp4'
fps = 4

# os.listdir order is arbitrary. Frame names are zero-padded, so a plain sort
# restores iteration order.
# NOTE(review): stale frames from an earlier, longer run in this folder would
# also be included.
images = sorted(img for img in os.listdir(image_folder) if img.endswith(".jpg"))
if not images:
    raise FileNotFoundError(f"no .jpg frames found in {image_folder!r}")

frame = cv2.imread(os.path.join(image_folder, images[0]))
if frame is None:
    raise IOError(f"could not read frame {images[0]!r}")
height, width = frame.shape[:2]

# fourcc 0 is not a valid codec for an .mp4 container; 'mp4v' (MPEG-4 Part 2)
# is widely available in OpenCV builds.
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
try:
    for image in images:
        video.write(cv2.imread(os.path.join(image_folder, image)))
finally:
    # Release even on error so the container is finalised and the file handle
    # freed. No HighGUI windows were opened, so destroyAllWindows() is not
    # needed; it also raises on headless OpenCV builds.
    video.release()