In [1]:
import sys
import os

sys.path.append(os.path.abspath(os.path.join('..', '..')))

import argparse
from pprint import pp
import torch
from torch import nn
from tqdm import tqdm
import numpy as np
import json
import os
from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter

from lpn.utils import load_dataset, load_config
from lpn.utils import get_model
from lpn.utils import get_loss_hparams_and_lr, get_loss
from lpn.utils import trainer
from lpn.utils import utils
import matplotlib.pyplot as plt

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Path to the test-dataset config file.
dataset_config_path = "./configs/test_dataset.json"

In [None]:
model_config_path = "./models/ne_lpn_sub_sq/s=0.1/model_config.json"
model_weight_path = "./models/ne_lpn_sub_sq/s=0.1/model.pt"

# Build the NE network from its config and place it on `device` (GPU if
# available, else CPU). Switch to inference mode: this notebook only evaluates it.
model_config = load_config(model_config_path)
model = get_model(model_config).to(device)
model.eval()
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads on a CPU-only machine.
model.load_state_dict(
    torch.load(model_weight_path, map_location=device)["model_state_dict"]
)

init weights


<All keys matched successfully>

In [11]:
Lpn_config_path = "./models/lpn/s=0.1/model_config.json"
Lpn_weight_path = "./models/lpn/s=0.1/model.pt"

# Build the baseline LPN from its config on `device` and switch it to
# inference mode, mirroring how the NE model is loaded.
lpn_config = load_config(Lpn_config_path)
lpn = get_model(lpn_config).to(device)
lpn.eval()
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads on a CPU-only machine.
lpn.load_state_dict(
    torch.load(Lpn_weight_path, map_location=device)["model_state_dict"]
)

init weights


<All keys matched successfully>

In [6]:
# Load the "test" split described by the dataset config and iterate over it
# one image at a time, in a fixed order.
dataset_config = load_config(dataset_config_path)
test_dataset = load_dataset(dataset_config, "test")
test_data_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

dataset:  celeba


In [12]:
# Shift-equivariance check: for a constant intensity shift c, compare each
# network's output on the shifted input, f(x + c), with its clean output shifted
# by the same amount, f(x) + c. The MSE between the two measures how far the
# network is from commuting with constant shifts.
shift_factors = [0.1, 0.2, 0.3, 0.4, 0.5]
num_steps = 500  # upper bound on the number of test batches per shift factor


@torch.no_grad()  # pure evaluation: no autograd graphs needed
def evaluate_shift_equivariance(nets, data_loader, shift_factor, max_steps, device):
    """Average shift-equivariance MSE for one or more networks.

    Args:
        nets: mapping from a display name to a torch module.
        data_loader: iterable of batches, each a mapping with an "image" tensor.
        shift_factor: constant added to every element of the input.
        max_steps: maximum number of batches to evaluate.
        device: device the images are moved to before the forward passes.

    Returns:
        Dict mapping each name in ``nets`` to the mean, over the batches
        evaluated, of MSE(net(x + shift_factor), net(x) + shift_factor).

    Raises:
        ValueError: if no batches were evaluated (empty loader or max_steps <= 0).
    """
    for net in nets.values():
        # Disables dropout / uses running batch-norm stats if such layers exist;
        # a no-op otherwise.
        net.eval()

    mse_loss = nn.MSELoss()
    running_loss = {name: 0.0 for name in nets}
    num_batches = 0
    for step, batch in enumerate(data_loader):
        if step >= max_steps:
            break
        clean_image = batch["image"].to(device)
        # A Python scalar broadcasts over every element, equivalent to adding
        # torch.full_like(clean_image, shift_factor).
        shifted_image = clean_image + shift_factor
        for name, net in nets.items():
            shifted_out = net(shifted_image)
            clean_out_shifted = net(clean_image) + shift_factor
            running_loss[name] += mse_loss(shifted_out, clean_out_shifted).item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches to evaluate")
    # Divide by the batches actually seen, not max_steps: the loader may be shorter.
    return {name: total / num_batches for name, total in running_loss.items()}


for shift_factor in shift_factors:
    avg_losses = evaluate_shift_equivariance(
        {"NE": model, "LPN": lpn}, test_data_loader, shift_factor, num_steps, device
    )
    ne_avg_loss = avg_losses["NE"]
    lpn_avg_loss = avg_losses["LPN"]

    print(f"Shift factor: {shift_factor}")
    print(f"NE avg loss: {ne_avg_loss}")
    print(f"LPN avg loss: {lpn_avg_loss}")
    print(f"LPN avg loss - NE avg loss: {lpn_avg_loss - ne_avg_loss}\n")

Shift factor: 0.1
NE avg loss: 1.4885827883922078e-13
LPN avg loss: 9.136728108569515e-05
LPN avg loss - NE avg loss: 9.136728093683687e-05

Shift factor: 0.2
NE avg loss: 2.313394794637238e-13
LPN avg loss: 0.00042952534833239043
LPN avg loss - NE avg loss: 0.00042952534810105096

Shift factor: 0.3
NE avg loss: 1.976973893700598e-13
LPN avg loss: 0.001341682005473558
LPN avg loss - NE avg loss: 0.0013416820052758608

Shift factor: 0.4
NE avg loss: 1.9832702612755963e-13
LPN avg loss: 0.0035485970449153684
LPN avg loss - NE avg loss: 0.0035485970447170414

Shift factor: 0.5
NE avg loss: 2.331373898662305e-13
LPN avg loss: 0.008025354051169416
LPN avg loss - NE avg loss: 0.008025354050936278

