## Jupyter notebook to plot the Gaussian distribution of the sampled Fourier features
### To study the effect of the Fourier-encoding feature size, 16 networks were trained with different combinations of feature size (64, 128, 256, 512) and standard deviation (25.2, 37.8, 50.4, 62.84)
### This notebook loads the saved checkpoints of all runs and plots the distribution of the sampled Fourier features (one histogram per feature size)

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
# Global parameters: output image size and the compute device
# (first CUDA GPU when available, CPU otherwise).
width, height = 512, 512
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Volume parameters: a sphere of radius 1 centred at the origin,
# bundled as (shape name, centre, radius).
center = torch.Tensor([0.0] * 3)
radius = 1.0
vol_params = ("sphere", center, radius)

# Seed NumPy's and PyTorch's global RNGs so reruns are reproducible.
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f1f580334d0>

In [3]:
# model paths
model_paths = {}
model_paths['25_64'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_25.2_64/models/200.pth'
model_paths['25_128'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_25.2_128/models/200.pth'
model_paths['25_256'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_25.2_256/models/200.pth'
model_paths['25_512'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_25.2_512/models/200.pth'

model_paths['38_64'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_37.8_64/models/200.pth'
model_paths['38_128'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_37.8_128/models/200.pth'
model_paths['38_256'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_37.8_256/models/200.pth'
model_paths['38_512'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_37.8_512/models/200.pth'

model_paths['50_64'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_50.4_64/models/200.pth'
model_paths['50_128'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_50.4_128/models/200.pth'
model_paths['50_256'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_50.4_256/models/200.pth'
model_paths['50_512'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_50.4_512/models/200.pth'

model_paths['63_64'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_62.84_64/models/200.pth'
model_paths['63_128'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_62.84_128/models/200.pth'
model_paths['63_256'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_62.84_256/models/200.pth'
model_paths['63_512'] = '/home/parika/WorkingDir/Thesis/Code/dvr/outputs/ejecta/dvr/effectOfFourierFeatures/gauss_62.84_512/models/200.pth'

In [4]:
# Experiment grid: 4 standard deviations x 4 feature sizes.
# The sd tags are the rounded values used in the model_paths keys.
standard_deviation = ['25', '38', '50', '63']
feature_size = ['64', '128', '256', '512']

# Legend labels: one per feature size and one per standard deviation.
labels_fs = ['features_' + size for size in feature_size]
labels_sd = ['sd_' + sd_tag for sd_tag in standard_deviation]

# Bar colours for a grouped histogram, one per dataset.
colors_features = ['#F6A131', '#0D7336', '#4400FF', '#FF0000']

# Run keys '<sd>_<fs>', ordered sd-major then feature size.
keys = [sd_tag + '_' + size for sd_tag in standard_deviation for size in feature_size]

# Load every run's 'input_map' tensor and scale it by 2*pi.
# NOTE(review): presumably the encoding evaluates sin/cos(2*pi * B x), so these are
# the effective frequencies -- confirm against the training code.
# Only 'input_map' is kept. The rest of each checkpoint is dropped right away
# instead of lingering in a module-level `checkpoint` variable.
# SECURITY: torch.load unpickles the file -- only load checkpoints you trust.
input_maps = {
    key: torch.load(model_paths[key], map_location=device)['input_map'] * 2 * np.pi
    for key in keys
}

In [6]:
# One grouped histogram per feature size. Each figure compares the flattened,
# 2*pi-scaled input_map values of every selected standard-deviation run.
# Each figure is saved to '<out_dir>fs_<feature size>.png'.
out_dir = '/home/parika/WorkingDir/Thesis/Documentation/ejecta/EffectOfFeatureSize/freq_dist_hist/'
n_bins = 10
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories

for fs in feature_size:
    plot_file = os.path.join(out_dir, 'fs_' + fs + '.png')

    # One 1-D array per sd run. detach() guards against tensors that still
    # track gradients, which .numpy() refuses.
    data = [
        torch.flatten(input_maps['%s_%s' % (sd, fs)]).detach().cpu().numpy()
        for sd in standard_deviation
    ]

    fig, ax = plt.subplots(figsize=(10, 4))
    # matplotlib needs exactly one colour per dataset. Slicing keeps this working
    # when fewer than four standard deviations are selected.
    _, bins, _ = ax.hist(data, histtype='bar', bins=n_bins,
                         color=colors_features[:len(data)], label=labels_sd)
    ax.grid(True, which='major', alpha=.3)
    ax.set_title('Sampled Fourier features, feature size %s' % fs, fontsize=15)
    ax.set_xlabel('Values sampled from the distribution', fontsize=15)
    ax.set_ylabel('Number of occurrences', fontsize=15)
    # Ticks sit exactly on the bin edges; only their labels are rounded.
    # (np.int no longer exists in NumPy >= 1.24, and truncating the positions
    # shifted the ticks off the edges.)
    ax.set_xticks(bins)
    ax.set_xticklabels([str(int(round(edge))) for edge in bins])
    ax.legend(prop={'size': 10})
    fig.savefig(plot_file)
    plt.close(fig)  # free each figure; many are created in this loop