In [None]:
from matplotlib import pyplot as plt
import numpy as np
import sys
sys.path.append('../')
from dataset import *
import seaborn as sns

In [None]:
# Which trained generator to inspect and which saved sample dump to read.
iteration = 1
model = 'gan/model=gaussian_wgan-zdim=100-lr=-4.00-rep=2'
# NOTE(review): `log_path` is not used by the sample-path templates, which are
# relative to the working directory — confirm whether they should be joined onto it.
log_path = '/home/ubuntu/data/logr'

In [None]:
# %-format templates for the sample files of the size / location / colour
# experiments, filled as `template % (config, model, iteration)`.
# Paths are relative to the current working directory.
# No trailing comma: each template must be a str, not a 1-tuple, or `%` fails.
size_path = 'crand_150_0_%d/%s-run=10/samples%d.npz'
location_path = 'crand_150_1_%d/%s-run=10/samples%d.npz'
color_path = 'crand_150_3_%d/%s-run=10/samples%d.npz'

In [None]:
# Histogram of DotsDataset.eval_size over the generated samples, per size configuration.
size_configs = [1, 3, 5, 7, 9]
size_hist = []
size_bins = []
for config in size_configs:
    # `with` closes the .npz file handle; the 'g' array is read into memory first.
    with np.load(size_path % (config, model, iteration)) as npz:
        samples = npz['g']
    vals = DotsDataset.eval_size(samples)
    hist, bins = np.histogram(vals, bins=60, range=(0.3, 0.8))
    # `np.float` was removed in NumPy 1.24; builtin `float` gives the same float64 dtype.
    hist = hist.astype(float)
    hist /= np.sum(hist)  # normalise counts to frequencies
    size_bins.append(0.5 * (bins[1:] + bins[:-1]))  # bin centres
    size_hist.append(hist)

def size_to_val(config):
    """Convert a size configuration index into the size value used as its reference.

    The mapping is affine: each configuration step adds 0.05, with offset 0.35.
    """
    offset = 0.35
    return config / 20.0 + offset

In [None]:
# Histogram of the first column of DotsDataset.eval_location over the generated
# samples, per location configuration (presumably the horizontal position — confirm).
loc_configs = [1, 3, 5, 7, 9]
loc_hist = []
loc_bins = []
for config in loc_configs:
    # The template is named `location_path` in the path cell.
    # `with` closes the .npz file handle.
    with np.load(location_path % (config, model, iteration)) as npz:
        samples = npz['g']
    vals = DotsDataset.eval_location(samples)[:, 0]
    hist, bins = np.histogram(vals, bins=50, range=(-0.4, 0.4))
    # `np.float` was removed in NumPy 1.24; builtin `float` gives the same float64 dtype.
    hist = hist.astype(float)
    hist /= np.sum(hist)  # normalise counts to frequencies
    loc_bins.append(0.5 * (bins[1:] + bins[:-1]))  # bin centres
    loc_hist.append(hist)

def loc_to_val(config):
    """Convert a location configuration index into the location value used as its reference.

    Configuration 5 maps to 0; each step away from it moves by 0.05.
    """
    centred = config - 5
    return centred / 20.0

In [None]:
# Histogram of DotsDataset.eval_color_proportion over the generated samples,
# per colour configuration.
color_configs = [1, 3, 5, 7, 9]
color_hist = []
color_bins = []
for config in color_configs:
    # `with` closes the .npz file handle; the 'g' array is read into memory first.
    with np.load(color_path % (config, model, iteration)) as npz:
        samples = npz['g']
    vals = DotsDataset.eval_color_proportion(samples)
    hist, bins = np.histogram(vals, bins=50, range=(0, 1))
    # `np.float` was removed in NumPy 1.24; builtin `float` gives the same float64 dtype.
    hist = hist.astype(float)
    hist /= np.sum(hist)  # normalise counts to frequencies
    color_bins.append(0.5 * (bins[1:] + bins[:-1]))  # bin centres
    color_hist.append(hist)

def color_to_val(config):
    """Convert a colour configuration index into the colour proportion used as its reference.

    Each configuration step corresponds to 0.1 (config / 10).
    """
    return config / 10.0

In [None]:
# One panel per attribute: a frequency curve of the generated samples for each
# configuration, plus a dotted line (same colour) at that configuration's reference value.
def _plot_attribute(ax, configs, bins_list, hist_list, to_val, xlabel):
    """Draw one panel on `ax` from per-config bin centres and frequencies."""
    palette = sns.color_palette("husl", len(configs))
    for config, bins, hist, colour in zip(configs, bins_list, hist_list, palette):
        ref = to_val(config)
        ax.plot(bins, hist, label='%.2f' % ref, c=colour)
        ax.axvline(x=ref, c=colour, ls=':', lw=1)
    ax.legend()
    ax.set_xlabel(xlabel)
    ax.set_ylabel('frequency in samples')


fig, (ax_size, ax_loc, ax_color) = plt.subplots(1, 3, figsize=(15, 4))
_plot_attribute(ax_size, size_configs, size_bins, size_hist, size_to_val,
                'size of generated circle')
_plot_attribute(ax_loc, loc_configs, loc_bins, loc_hist, loc_to_val,
                'location of generated circle')
_plot_attribute(ax_color, color_configs, color_bins, color_hist, color_to_val,
                'proportion of red color')

fig.tight_layout()
fig.savefig('../results/circle_size_wgan.pdf')
plt.show()