In [1]:
from pytorch_lightning import callbacks
import yaml
import argparse
import numpy as np
import cv2
import matplotlib.pyplot as plt 
from models import *
import torch.nn  as nn
from experiments.vae_experiment import VAEXperiment
from experiments.vae_pix2pix_exp import Pix2pixExperiment
import torch.backends.cudnn as cudnn
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TestTubeLogger
from torch.utils.data import DataLoader 
from terrain_loader import TerrainDataset
from pytorch_lightning.callbacks import ModelCheckpoint

In [2]:

with open("configs/vae_pix2pix.yml", 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(".\n\n",exc)

In [3]:
dataset = TerrainDataset(root = config['exp_params']['data_path'],
                        train=False,
                        hide_green=config['exp_params']['hide_green'],
                        norm=config['exp_params']['norm'])

sample_dataloader = DataLoader(dataset,
                        batch_size= 1,
                        num_workers=config['exp_params']['n_workers'],
                        shuffle = True,
                        drop_last=False)

In [10]:
 
! rsync -aP ada:/share3/shanthika_naik/pytorch_terrain_authoring/TerrainAuthoring_Pytorch/logs/VanillaVAE/version_0  /scratch/shan/
! rsync -aP ada:/share3/shanthika_naik/pytorch_terrain_authoring/TerrainAuthoring_Pytorch/logs/VAE_PIX2PIX/log0/version_3 /scratch/shan

receiving incremental file list
receiving incremental file list


In [11]:
#Vae Model
vae_model = vae_models[config['vae_model_params']['name']](**config['vae_model_params'])

# pix2pix model
gen_model = pix2pix_model[config['pix2pix_model_params']['gen_name']](config['exp_params']['in_channels'],config['exp_params']['out_channels'])
disc_model = pix2pix_model[config['pix2pix_model_params']['disc_name']](config['exp_params']['in_channels'])


In [12]:
if config['vae_model_params']['load_model'] :
    experiment_p2p = Pix2pixExperiment.load_from_checkpoint(config['pix2pix_model_params']['pretrained_model'], gen_model=gen_model,disc_model=disc_model,vae_model=vae_model,params=config['exp_params'])
    experiment_vae = VAEXperiment.load_from_checkpoint(config['vae_model_params']['pretrained_model'], vae_model=vae_model,params=config['exp_params'])
    print("[INFO] Loaded pretrained model")

[INFO] Loaded pretrained model


In [13]:
vae_model.eval()
gen_model.eval()

Generator(
  (down1): DownSample(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.3, inplace=True)
    )
  )
  (down2): DownSample(
    (model): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.3, inplace=True)
    )
  )
  (down3): DownSample(
    (model): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.3, inplace=True)
    )
  )
  (down4): DownSample(
    (model): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affin

In [14]:
def denormalize(result):
        # minv, maxv = torch.min(result), torch.max(result)
        new = (result+1)*127.5
        return torch.squeeze(new).detach().numpy().transpose((1,2,0)).astype(np.uint8)


In [15]:
def get_mse(ip,op ):

    res = vae_model(ip)[0] 
     
    res = gen_model(res) 
    res = denormalize(res)
    op = denormalize(op)
    # ip = denormalize(ip)
    ip = (ip*255)
    ip = torch.squeeze(ip).detach().numpy().transpose((1,2,0)).astype(np.uint8)
    res = cv2.GaussianBlur(res, (5, 5), 0)

    res_3 = cv2.GaussianBlur(res, (3,3), 0)
    res_5 = cv2.GaussianBlur(res, (5,5), 0)
    res_7 = cv2.GaussianBlur(res, (7,7), 0)
    res_11 = cv2.GaussianBlur(res, (11,11), 0)
    res_21 = cv2.GaussianBlur(res, (21,21), 0)


    mse = np.mean((res-op)**2)
    mse_3 = np.mean((res_3-op)**2)
    mse_5 = np.mean((res_5-op)**2)
    mse_7 = np.mean((res_7-op)**2)
    mse_11 = np.mean((res_11-op)**2)
    mse_21 = np.mean((res_21-op)**2)


    return mse,mse_3,mse_5,mse_7,mse_11,mse_21

In [16]:
def calculate_mse( ):
    count=0
    mse_loss = 0 
    mse_loss_3 = 0 
    mse_loss_5 = 0 
    mse_loss_7 = 0 
    mse_loss_11 = 0 
    mse_loss_21 = 0 

    for ip, op in sample_dataloader:
        count+=1
        ml,ml_3,ml_5,ml_7,ml_11,ml_21 = get_mse(ip,op )  
        mse_loss += ml
        mse_loss_3 +=ml_3
        mse_loss_5 +=ml_5
        mse_loss_7 +=ml_7
        mse_loss_11 +=ml_11
        mse_loss_21 +=ml_21

    
    return mse_loss,mse_loss_3,mse_loss_5,mse_loss_7,mse_loss_11,mse_loss_21

In [17]:
mse_loss,mse_loss_3,mse_loss_5,mse_loss_7,mse_loss_11,mse_loss_21 = calculate_mse()
print(mse_loss,mse_loss_3,mse_loss_5,mse_loss_7,mse_loss_11,mse_loss_21)


In [None]:
mse_loss_1

28.895053582564444