In [1]:
import math
import time
import sqlite3
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm

import patch_nr_paper_model
import utils_patch_nr_paper
from dataset.FastPatchExtractor import FastImageLoader
from flow_models.PatchFlowModel import PatchFlowModel
from regularisers import PatchNrRegulariser

In [2]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)
print(f'Using Device: {DEVICE}')

Using Device: cuda


In [None]:
# Experiment configuration: trained PatchNR flow checkpoint, evaluation image,
# experiment name and output directory.
model_path = 'results/patch_nr/custom_patch_nr/version_3/custom_patch_nr_final.pth'
image_path = 'data/material_pt_nr/train.png'
name = 'material_prior_variance_experiment'
result_path = 'results/prior_variance/'

# Patch-sample counts swept by the experiment: range(start, end, step_size),
# so `end` itself is excluded (1000, 11000, ..., 141000).
start = 1000
end = 150000
step_size = 10000
# Intended number of regulariser evaluations per sample count.
# NOTE(review): confirm the sweep loop uses this value rather than a literal.
samples_per_step = 200

# Seed the RNGs (torch.manual_seed covers CPU and all CUDA devices) so that
# Restart & Run All reproduces the same sweep.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Load the trained patch flow. Its 'dimension' hyper-parameter is the length of a
# flattened patch; square patches are assumed, i.e. dimension == patch_size ** 2
# (presumably single-channel p x p patches -- confirm against the training config).
model = PatchFlowModel(path=model_path)
patch_dimension = model.hparams['dimension']
patch_size = int(math.sqrt(patch_dimension))
# int(sqrt(...)) truncates silently; fail loudly instead of sweeping with a wrong patch size.
assert patch_size * patch_size == patch_dimension, (
    f"model dimension {patch_dimension} is not a perfect square; cannot infer patch size"
)

In [None]:
# Normalize with mean 0 and std 255, i.e. divide pixel values by 255
# (maps 0..255 input to [0, 1], assuming the loader yields raw 8-bit values -- confirm).
normalization = transforms.Compose([
    transforms.Normalize(mean=[0], std=[255.0]),
])
image_loader = FastImageLoader(image_path, device=DEVICE, transform=normalization)
# The first (and presumably only) image is the reference the prior is evaluated on.
ground_truth = image_loader[0].to(device=DEVICE)

In [None]:
steps = list(range(start, end, step_size))  # patch-sample counts of the sweep (x-axis)

In [None]:
# For each patch-sample count in `steps`, build a PatchNR regulariser, evaluate it
# `samples_per_step` times on the same ground-truth image and record the spread of
# the returned values (presumably stochastic because `sample_number` patches are
# drawn per evaluation -- confirm) plus the wall-clock cost of those evaluations.
res = []    # one (min, mean, max) tuple of 0-d tensors per entry of `steps`
times = []  # seconds spent on the evaluations for each entry of `steps`
for current_step in steps:
    custom_regulariser = PatchNrRegulariser(model, p_size=patch_size, sample_number=current_step, padding=True, padding_size=16, device=DEVICE)
    step_values = []
    start_time = time.perf_counter()
    # Only scalar values are kept, so no autograd graph needs to be built.
    # NOTE(review): assumes evaluate() does not call backward() internally.
    with torch.no_grad():
        for _ in tqdm(range(samples_per_step), desc=f'sample_number={current_step}'):
            custom_val = custom_regulariser.evaluate(ground_truth)
            # .item() forces a device sync, so the timing below includes GPU work.
            step_values.append(custom_val.item())
    end_time = time.perf_counter()
    step_tensor = torch.tensor(step_values)
    val_min = torch.min(step_tensor)
    val_mean = torch.mean(step_tensor)
    val_max = torch.max(step_tensor)
    res.append((val_min, val_mean, val_max))
    times.append(end_time - start_time)