In [1]:
import torch
import os
import sys
import re
from shutil import copyfile
import numpy as np


sys.path.append('../loss')
from watson import WatsonDistance
from watson_fft import WatsonDistanceFft
from watson_vgg import WatsonDistanceVgg
from color_wrapper import ColorWrapper

# Save params for use outside of evaluation framework

In [2]:
def params_remove_prefix(params):
    old_keys = list(params.keys())
    for k in old_keys:
        params[k[4:]] = params[k]
        params.pop(k)
    return params

In [10]:
weight_dirs = os.listdir('./checkpoints/')

for weight_dir in weight_dirs:
    path=os.path.join('./checkpoints/', weight_dir, 'latest_net_.pth')
    state_dict = torch.load(path, map_location='cpu')
    if not('pnet_lin' in weight_dir):
        state_dict = params_remove_prefix(state_dict)
    torch.save(state_dict, os.path.join('../loss/weights', weight_dir + '.pth'))

# Inspect weights

In [8]:
weight_dirs = os.listdir('./checkpoints/')
weight_dirs

['rgb_watson_vgg_trial0',
 'gray_watson_dct_trial0',
 'gray_pnet_lin_vgg_trial0',
 'gray_watson_vgg_trial0',
 'rgb_watson_dct_trial0',
 'rgb_pnet_lin_vgg_trial0',
 'rgb_watson_fft_trial0',
 'gray_watson_fft_trial0']

In [21]:
def get_parameter_count(state_dict):
    size = 0
    for key in state_dict:
        size += torch.prod(torch.tensor(state_dict[key].shape)).item()
    return int(size)

weight_dirs = os.listdir('./checkpoints/')

for weight_dir in weight_dirs:
    path=os.path.join('./checkpoints/', weight_dir, 'latest_net_.pth')
    state_dict = torch.load(path, map_location='cpu')
    size = get_parameter_count(state_dict)
    print("model {} has {} parameters".format(weight_dir, size))

model rgb_watson_vgg_trial0 has 14716172 parameters
model gray_watson_dct_trial0 has 132 parameters
model gray_pnet_lin_vgg_trial0 has 1472 parameters
model gray_watson_vgg_trial0 has 14716172 parameters
model rgb_watson_dct_trial0 has 411 parameters
model rgb_watson_fft_trial0 has 267 parameters
model gray_watson_fft_trial0 has 84 parameters


In [23]:
path=os.path.join('./checkpoints/', 'rgb_watson_vgg_trial0', 'latest_net_.pth')
torch.load(path, map_location='cpu')

OrderedDict([('net.shift', tensor([[[[-0.0300]],
              
                       [[-0.0880]],
              
                       [[-0.1880]]]])), ('net.scale', tensor([[[[0.4580]],
              
                       [[0.4480]],
              
                       [[0.4500]]]])), ('net.t0_tild',
              tensor([ 1.6665,  1.3944,  1.8071,  0.4880,  1.5426,  1.2504,  1.1417,  1.4551,
                       1.5342, -1.6020,  1.2098,  1.6442,  1.3703,  1.6884,  1.4577, -0.3602,
                       1.6789,  0.6749,  1.4583, -1.8348, -1.7571, -0.2289,  1.8768, -1.1867,
                       1.1493,  1.8374, -1.1729,  1.4361, -1.5052,  1.1556,  0.9521, -0.1256,
                       1.1809,  1.6935,  1.7781,  1.4170,  1.2724, -1.7293,  1.0174,  1.5716,
                      -1.1117, -1.8404,  1.7971,  1.5930,  1.3669,  1.4646, -1.5323,  1.7043,
                       1.8134,  1.8625,  1.8061,  1.5721,  1.7469,  1.2741,  1.2116, -1.1692,
                      -1.3987,  

# Grayscale

In [3]:
path=os.path.join('./checkpoints/', 'gray_watson_dct_trial0', 'latest_net_.pth')
l = WatsonDistance()
l.load_state_dict(params_remove_prefix(torch.load(path, map_location='cpu')))
print('luminance alpha: ', l.alpha.item())
print('contrast w: ', l.w.item())
print('pooling beta: ', l.beta.item())
print('QM: ', l.t)

luminance alpha:  -0.0067008789628744125
contrast w:  0.1926499456167221
pooling beta:  0.9743978977203369
QM:  tensor([[ 1.0988,  0.7369,  3.3444, 10.9491, 12.5315,  1.2571,  2.6511,  0.2575],
        [ 0.7858,  5.2716,  1.5247,  2.4578,  2.6859,  1.2124,  0.8375,  0.3511],
        [ 5.0928,  2.8245, 14.9405,  9.7730,  3.6165,  1.2752,  0.9091,  0.4840],
        [10.3639,  5.9086, 10.8770,  8.0454,  2.4509,  1.2691,  0.9340,  0.3341],
        [10.8749,  5.4385,  5.0274,  2.6190,  0.9379,  1.1439,  0.9016,  0.3753],
        [ 1.8720,  1.7446,  1.6660,  1.2375,  0.9895,  0.6829,  0.6728,  0.2311],
        [ 4.3739,  1.0282,  1.0608,  1.0639,  0.9034,  0.4963,  0.6014,  0.6030],
        [ 0.1862,  0.6187,  0.7538,  0.4654,  0.4612,  0.2287,  0.5607,  0.2649]])


In [4]:
path=os.path.join('./checkpoints/', 'gray_watson_fft_trial0', 'latest_net_.pth')
l = WatsonDistanceFft()
l.load_state_dict(params_remove_prefix(torch.load(path, map_location='cpu')))
print('luminance alpha: ', l.alpha.item())
print('contrast w: ', l.w.item())
print('pooling beta: ', l.beta.item())
print('QM: ', l.t)
print('Phase Weights: ', l.w_phase)

luminance alpha:  -0.020072540268301964
contrast w:  0.1594894528388977
pooling beta:  0.5561908483505249
QM:  tensor([[0.4125, 1.5324, 2.2538, 3.3158, 0.6956],
        [1.7432, 0.5361, 0.5019, 0.7160, 0.6297],
        [2.6141, 0.4985, 0.3500, 0.4961, 0.4433],
        [3.6895, 0.7436, 0.5119, 0.5728, 0.4197],
        [0.7143, 0.7332, 0.4649, 0.6083, 0.2855],
        [3.7143, 0.7388, 0.5131, 0.5742, 0.4166],
        [2.6253, 0.4966, 0.3395, 0.5095, 0.4429],
        [1.7346, 0.5422, 0.4921, 0.7015, 0.6270]])
Phase Weights:  tensor([[0.0000, 0.0478, 0.0415, 0.0458, 0.0000],
        [0.0478, 0.0381, 0.0353, 0.0362, 0.0400],
        [0.0433, 0.0352, 0.0365, 0.0381, 0.0450],
        [0.0471, 0.0358, 0.0377, 0.0478, 0.0692],
        [0.0000, 0.0381, 0.0442, 0.0598, 0.0000],
        [0.0471, 0.0359, 0.0373, 0.0471, 0.0690],
        [0.0436, 0.0351, 0.0358, 0.0381, 0.0453],
        [0.0478, 0.0381, 0.0352, 0.0363, 0.0399]])


In [5]:
path=os.path.join('./checkpoints/', 'gray_watson_vgg_trial0', 'latest_net_.pth')
l = WatsonDistanceVgg()
l.load_state_dict(params_remove_prefix(torch.load(path, map_location='cpu')), strict=False)
print('luminance alpha: ', l.alpha.item())
print('contrast w: ', l.w.item())
print('pooling beta: ', l.beta.item())
print('QM: ', l.t)

luminance alpha:  -0.020072540268301964
contrast w:  0.1594894528388977
pooling beta:  0.5561908483505249
QM:  tensor([[0.4125, 1.5324, 2.2538, 3.3158, 0.6956],
        [1.7432, 0.5361, 0.5019, 0.7160, 0.6297],
        [2.6141, 0.4985, 0.3500, 0.4961, 0.4433],
        [3.6895, 0.7436, 0.5119, 0.5728, 0.4197],
        [0.7143, 0.7332, 0.4649, 0.6083, 0.2855],
        [3.7143, 0.7388, 0.5131, 0.5742, 0.4166],
        [2.6253, 0.4966, 0.3395, 0.5095, 0.4429],
        [1.7346, 0.5422, 0.4921, 0.7015, 0.6270]])
Phase Weights:  tensor([[0.0000, 0.0478, 0.0415, 0.0458, 0.0000],
        [0.0478, 0.0381, 0.0353, 0.0362, 0.0400],
        [0.0433, 0.0352, 0.0365, 0.0381, 0.0450],
        [0.0471, 0.0358, 0.0377, 0.0478, 0.0692],
        [0.0000, 0.0381, 0.0442, 0.0598, 0.0000],
        [0.0471, 0.0359, 0.0373, 0.0471, 0.0690],
        [0.0436, 0.0351, 0.0358, 0.0381, 0.0453],
        [0.0478, 0.0381, 0.0352, 0.0363, 0.0399]])


# Color

In [18]:
path=os.path.join('./checkpoints/', 'rgb_watson_dct_trial0', 'latest_net_.pth')
l = ColorWrapper(WatsonDistance, (), {'trainable': False, 'blocksize': 8}, trainable=False)
l.load_state_dict(params_remove_prefix(torch.load(path, map_location='cpu')))
print('weights Y Cb Cr:', l.w)
for name, single_channel_model in [('L', l.ly), ('Cb', l.lcb), ('Cr', l.lcr)]:
    print(name, ' luminance alpha: ', single_channel_model.alpha.item())
    print(name, ' contrast w: ', single_channel_model.w.item())
    print(name, ' pooling beta: ', single_channel_model.beta.item())
    print(name, ' QM: ', single_channel_model.t)

weights Y Cb Cr: tensor([0.1719, 0.3905, 0.4376])
L  luminance alpha:  0.038675736635923386
L  contrast w:  0.7726239562034607
L  pooling beta:  3.100041389465332
L  QM:  tensor([[ 1.7759,  1.6741,  2.2121,  3.8511,  3.4907,  1.2192,  2.7582,  0.6499],
        [ 1.8539,  4.9332,  1.4256,  0.0925,  1.1353,  0.3349,  0.4021,  1.1729],
        [ 2.7925,  1.6860,  5.4023,  2.5244,  1.3205,  0.3462,  0.4044,  0.5511],
        [ 3.0943,  1.5132,  3.6586,  3.0976,  1.2802,  1.0769,  0.7096,  0.9958],
        [ 2.9054,  1.4571,  1.9852,  2.2368,  1.3320,  1.8620,  2.9352,  7.6524],
        [ 1.9409,  1.1379,  0.8128,  1.2723,  1.4484,  5.2171,  9.8331, 12.7165],
        [ 3.8321,  1.0217,  0.8770,  1.2383,  2.8755,  9.8815, 13.9435, 17.0857],
        [ 0.6075,  1.2019,  1.3003,  1.7630,  7.6920, 12.7506, 17.0824, 20.7693]])
Cb  luminance alpha:  -0.1839788407087326
Cb  contrast w:  0.5128593444824219
Cb  pooling beta:  2.1932907104492188
Cb  QM:  tensor([[1.1051, 1.7078, 4.7066, 7.5938, 6.7558

In [19]:
path=os.path.join('./checkpoints/', 'rgb_watson_fft_trial0', 'latest_net_.pth')
l = ColorWrapper(WatsonDistanceFft, (), {'trainable': False, 'blocksize': 8}, trainable=False)
l.load_state_dict(params_remove_prefix(torch.load(path, map_location='cpu')))
print('weights Y Cb Cr:', l.w)
for name, single_channel_model in [('L', l.ly), ('Cb', l.lcb), ('Cr', l.lcr)]:
    print(name, ' luminance alpha: ', single_channel_model.alpha.item())
    print(name, ' contrast w: ', single_channel_model.w.item())
    print(name, ' pooling beta: ', single_channel_model.beta.item())
    print(name, ' QM: ', single_channel_model.t)
    print(name, 'Phase Weights: ', single_channel_model.w_phase)

weights Y Cb Cr: tensor([0.3504, 0.3059, 0.3437])
L  luminance alpha:  -0.0115112429484725
L  contrast w:  0.2857496738433838
L  pooling beta:  0.6494485139846802
L  QM:  tensor([[ 0.9073,  3.4689,  4.7331, 16.1478,  0.3566],
        [ 4.2219,  0.1456,  0.1165,  0.4201,  0.4759],
        [ 6.1246,  0.0998,  0.0697,  0.2826,  0.3909],
        [18.7375,  0.3907,  0.2816,  1.2255,  1.9442],
        [ 0.4733,  0.5170,  0.4409,  2.2793,  0.2361],
        [18.9578,  0.3965,  0.2944,  1.2229,  1.9370],
        [ 6.2007,  0.0977,  0.0669,  0.2751,  0.3903],
        [ 4.1944,  0.1486,  0.1066,  0.4141,  0.4695]])
L Phase Weights:  tensor([[0.0000, 0.0061, 0.0062, 0.0075, 0.0000],
        [0.0064, 0.0055, 0.0055, 0.0062, 0.0081],
        [0.0073, 0.0056, 0.0063, 0.0074, 0.0095],
        [0.0079, 0.0062, 0.0072, 0.0105, 0.0158],
        [0.0000, 0.0078, 0.0091, 0.0154, 0.0000],
        [0.0079, 0.0062, 0.0070, 0.0098, 0.0159],
        [0.0073, 0.0056, 0.0060, 0.0073, 0.0096],
        [0.0064, 0.0

In [None]:
path=os.path.join('./checkpoints/', 'rgb_watson_vgg_trial0', 'latest_net_.pth')
l = WatsonDistanceVgg()
l.load_state_dict(params_remove_prefix(torch.load(path, map_location='cpu')), strict=False)
print('luminance alpha: ', l.alpha.item())
print('contrast w: ', l.w.item())
print('pooling beta: ', l.beta.item())
print('QM: ', l.t)