In [2]:
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from model2 import GVAE
from temp_model import Sharpener
from sampler_trainer import Sampler
from collections import OrderedDict
from dataset1 import SketchDataset
from tqdm import tqdm

In [4]:
gpu_id = 0

def DDP_to_normal(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # Remove 'module.'
        new_state_dict[name] = v

    return new_state_dict

def freeze_model(model):
    model = model.eval()
    for p in model.parameters():
        p.requires_grad = False

In [14]:
vae_state_dict = DDP_to_normal(torch.load('checkpoints/model_checkpoint_gvae_ddp_Adam_mse-25_kld-.001_16layers16heads256hiddenencoder_16layers16heads256hiddendecoder_embedim1024_tempnodedim128_relu_after_node_layernorm.pth'))
sampler_state_dict = DDP_to_normal(torch.load('model_checkpoint_sampler_ddp_Adam_depth_32_1.pth'))
sharpener_state_dict = DDP_to_normal(torch.load('model_checkpoint_sharpener_ddp_Adam_16tflayers.pth'))

vae = GVAE(device = gpu_id)
vae.load_state_dict(vae_state_dict)
vae.eval()

sampler = Sampler(device = gpu_id)
sampler.load_state_dict(sampler_state_dict)
sampler.eval()

sharpener = Sharpener(device = gpu_id)
sharpener.load_state_dict(sharpener_state_dict)
sharpener.eval()

Sharpener(
  (mlp_in_nodes): Sequential(
    (0): Linear(in_features=20, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.05)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.05)
  )
  (mlp_in_edges): Sequential(
    (0): Linear(in_features=17, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.05)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): LeakyReLU(negative_slope=0.05)
  )
  (block_layers): ModuleList(
    (0-11): 12 x TransformerLayer(
      (attention_heads): MultiHeadAttention(
        (lin_query): Linear(in_features=256, out_features=256, bias=True)
        (lin_key): Linear(in_features=256, out_features=256, bias=True)
        (lin_value): Sequential(
          (0): Linear(in_features=256, out_features=256, bias=True)
        )
        (lin_mul): Sequential(
          (0): Linear(in_features=128, out_features=256, bias=True)
        )
        (lin_add): Sequential(
        

In [15]:
T = sampler.max_timestep
a_bar = torch.cos(0.5 * torch.pi * (torch.arange(0.0, 1.0, 1/(T + 1)) + .008) / 1.008) ** 2
a_bar = a_bar / a_bar[0]
a_bar = a_bar.to(gpu_id)

a = a_bar[1:] / a_bar[:-1]
a = torch.cat([a, torch.tensor([0.0]).to(gpu_id)])

sqrt_a = a.sqrt()
sqrt_a_bar = a_bar.sqrt().to(gpu_id)
sqrt_b_bar = (1 - a_bar).sqrt().to(gpu_id)

sqrt_post_var = torch.cat([torch.tensor([0.0]).to(gpu_id), (a_bar[:-1] / a_bar[1:] * (1 - a[1:]))]).sqrt()

def sample_latent(batch_size):
    sampled_latents = torch.randn((batch_size, 1024)).to(gpu_id)

    for i in reversed(range(1, sampler.max_timestep)):
        pred_latents = sampler(sampled_latents, torch.full(size = (batch_size,), fill_value = i).to(gpu_id))
      
        denoised_mean = (sqrt_a_bar[i - 1] * (1 - a[i]) * pred_latents + sqrt_a[i] * (1 - a_bar[i - 1]) * sampled_latents) / (1 - a_bar[i])
        if i > 1:
            sampled_latents = denoised_mean + sqrt_post_var[i] * torch.randn_like(denoised_mean)
        else:
            sampled_latents = denoised_mean
    
    return sampled_latents

In [16]:
with torch.no_grad():
    print("---- Sampling Latents ----")
    latents = sample_latent(2048)
    print("---- Decoding Latents ----")
    noisy_nodes, noisy_edges = vae.decoder(latents)
    print("---- Sharpening Graphs ----")
    nodes, edges = sharpener(noisy_nodes, noisy_edges)
    print("---- Saving Generated graphs ----")
    for i in tqdm(range(latents.size(0))):
        # SketchDataset.render_graph(noisy_nodes[i].cpu(), noisy_edges[i].cpu())
        fig = SketchDataset.render_graph(nodes[i].cpu(), edges[i].cpu())
        
        fig.savefig(f"test/gen/{i}.png")
        plt.close(fig)

---- Sampling Latents ----
---- Decoding Latents ----
---- Sharpening Graphs ----
---- Saving Generated graphs ----


100%|██████████| 2048/2048 [03:51<00:00,  8.83it/s]


In [19]:
print("--- Loading Dataset into Memory ---")
dataset = SketchDataset(root="data/")

--- Loading Dataset into Memory ---


In [22]:
print("--- Saving Real Graphs ---")
for i in tqdm(range(2048)):
    idx = torch.randint(0, dataset.nodes.size(0), (1,))
    fig = SketchDataset.render_graph(dataset.nodes[idx.item()], dataset.edges[idx.item()])
    fig.savefig(f"test/real/{i}.png")
    plt.close(fig)

--- Saving Real Graphs ---


100%|██████████| 2048/2048 [02:30<00:00, 13.61it/s]


In [6]:
import argparse
import os
import lpips

with torch.no_grad():
	## Initializing the model
	loss_fn = lpips.LPIPS(net='alex')
	loss_fn.to(gpu_id)

	# crawl directories
	files = os.listdir("test/real/")

	dist = 0

	for file in tqdm(files):
		if(os.path.exists(os.path.join("test/gen/",file))):
			# Load images
			img0 = lpips.im2tensor(lpips.load_image(os.path.join("test/real/",file))) # RGB image from [-1,1]
			img1 = lpips.im2tensor(lpips.load_image(os.path.join("test/gen/",file)))

			img0 = img0.to(gpu_id)
			img1 = img1.to(gpu_id)

			# Compute distance
			dist = dist + loss_fn.forward(img0,img1)
			
	print((dist / 2048).item())
		

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/chereds/.conda/envs/thesis/lib/python3.11/site-packages/lpips/weights/v0.1/alex.pth


100%|██████████| 2048/2048 [01:08<00:00, 29.77it/s]

0.331959992647171



