In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.stats import norm
from sklearn.random_projection import GaussianRandomProjection
import os
import sys

sys.path.append('../')
sys.path.append('../src/')

from src.generative import *
from src.utils import *

In [71]:
# Learning in generative models (presentation figure):
# (a) training samples from N(2, 1), drawn by rejection sampling under its density;
# (b) the density of a generative model q that has learned that distribution.
textwidth = set_plotting_params()
colors = sns.color_palette('colorblind')
markers = ['o', 'x', 's', '^', 'v', '<', '>', 'd', 'p', 'P']

X_MIN, X_MAX = -2, 6                 # shared x-range of both panels
N_PROPOSALS = 1000                   # uniform proposals for the rejection sampler
rng_goal = np.random.default_rng(0)  # seeded so the scatter is reproducible

fig, axs = plt.subplots(1, 2, figsize=(textwidth, 0.5*textwidth))

# Left subplot: uniformly draw points under the PDF curve
mean_train, std_train = 2, 1
# Envelope for rejection sampling: the true peak of the Gaussian density, which
# is attained at the mean. (The maximum over a finite x-grid can fall short of
# it and would clip the top of the curve.)
pdf_peak = norm.pdf(mean_train, mean_train, std_train)

# Rejection sampling: propose (x, y) uniformly in the bounding box and keep the
# points that lie under the density curve.
x_uniform = rng_goal.uniform(X_MIN, X_MAX, N_PROPOSALS)
y_uniform = rng_goal.uniform(0, pdf_peak, N_PROPOSALS)
under_curve = y_uniform < norm.pdf(x_uniform, mean_train, std_train)

axs[0].scatter(x_uniform[under_curve], y_uniform[under_curve], s=5, alpha=0.8, color=colors[0], marker=markers[0])
axs[0].set_title('(a) Training Data')
axs[0].set_xlabel('x')
axs[0].set_yticks([])

# Right subplot: PDF of the learned Gaussian distribution (identical to the
# training distribution, i.e. q has learned the data perfectly)
mean_learned, std_learned = mean_train, std_train
x = np.linspace(X_MIN, X_MAX, 500)
pdf = norm.pdf(x, mean_learned, std_learned)

axs[1].plot(x, pdf, color=colors[1])
axs[1].fill_between(x, pdf, color=colors[1], alpha=0.3)
axs[1].set_title('(b) Learned Generative Model q')
axs[1].set_xlabel('x')
axs[1].set_yticks([])

# same x-limits in both subplots so the panels are directly comparable
for ax in axs:
    ax.set_xlim(X_MIN, X_MAX)

plt.tight_layout()
out_path = '../doc/Presentation/LearningGoal.png'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path, dpi=300)
plt.close()



# Visualization of data-copying

## Mixed Generative Model

In [55]:
# Mixed generative model on half-moons data:
# (a) an underfitting component q_underfit, (b) a data-copying component q_copy,
# (c) their mixture q = rho * q_copy + (1 - rho) * q_underfit.
textwidth = set_plotting_params()
colors = sns.color_palette('colorblind')
markers = ['o', 'x', 's', '^', 'v', '<', '>', 'd', 'p', 'P']

# Seed the global NumPy RNG so the figure is reproducible on Restart & Run All.
# NOTE(review): assumes Halfmoons / Memorizer / Mixture sample via np.random — confirm.
np.random.seed(0)

# Radius of the circles marking the copied training points in panel (c).
# Presumably chosen a bit larger than the copying radius (0.05) so the copies
# are visibly enclosed — confirm.
COPY_CIRCLE_RADIUS = 0.07

fig, axs = plt.subplots(1, 3, figsize=(textwidth, 2))

# (a) Underfitting Component
X_train_underfit = Halfmoons(noise=0.1).sample(1000)
underfit = Memorizer(radius=0.25, n_copying=len(X_train_underfit)).fit(X_train_underfit)
X_underfit = underfit.sample(2000)

axs[0].scatter(X_underfit[:, 0], X_underfit[:, 1], alpha=0.7, s=5, color=colors[1], marker=markers[1], label="Generated")
axs[0].scatter(X_train_underfit[:, 0], X_train_underfit[:, 1], alpha=0.7, s=1, color=colors[0], marker=markers[0], label="Train")
axs[0].set_title("(a) $q_{underfit}$")
axs[0].set_xticks([])
axs[0].set_yticks([])
axs[0].legend()

# (c) Mixture. This is built before (b) on purpose: panel (b) samples from the
# `copying` component, which is presumably fitted in place by `Mixture.fit`.
X_train_mix = Halfmoons(noise=0.1).sample(50)
underfit = Memorizer(radius=0.25, n_copying=len(X_train_mix))
copying = Memorizer(radius=0.05, n_copying=5)
q_mix = Mixture(rho=0.4, q1=copying, q2=underfit).fit(X_train_mix)
X_mixture = q_mix.sample(200)

axs[2].scatter(X_mixture[:, 0], X_mixture[:, 1], alpha=0.5, s=5, color=colors[1], marker=markers[1], label="Generated")
axs[2].scatter(X_train_mix[:, 0], X_train_mix[:, 1], alpha=0.5, s=1, color=colors[0], marker=markers[0], label="Train")

# circle the training points the copying component memorized
for center in q_mix.q1.subset:
    axs[2].add_patch(plt.Circle(center, COPY_CIRCLE_RADIUS, color="black", fill=False, linewidth=0.3))

axs[2].set_title("(c) $q = \\rho q_{copy} + (1 - \\rho) q_{underfit}$")
axs[2].set_xticks([])
axs[2].set_yticks([])

# (b) Copying Component (same training set as the mixture)
X_copying = copying.sample(100)

axs[1].scatter(X_copying[:, 0], X_copying[:, 1], alpha=0.5, s=5, color=colors[1], marker=markers[1], label="Generated")
axs[1].scatter(X_train_mix[:, 0], X_train_mix[:, 1], alpha=0.5, s=1, color=colors[0], marker=markers[0], label="Train")
axs[1].set_title("(b) $q_{copy}$")
axs[1].set_xticks([])
axs[1].set_yticks([])

plt.tight_layout()
out_path = "../doc/MixedModel.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path, dpi=300)
plt.close()

## Intuition using MNIST

In [2]:
# Load MNIST and keep a small, fixed random subset of one digit class,
# flattened to 28*28 = 784-dimensional pixel vectors.
mnist = MNIST(root='../data.nosync/').fit()
X_train, y_train, X_test, y_test = mnist.get_train_and_test_data()

c = 8          # digit class to visualize
N_SUBSET = 20  # number of train and of test images kept

X_train_8 = X_train[y_train == c].reshape(-1, 28*28)
X_test_8 = X_test[y_test == c].reshape(-1, 28*28)

# extract a random subset of N_SUBSET samples (seeded for reproducibility)
np.random.seed(0)
X_train_8 = X_train_8[np.random.choice(X_train_8.shape[0], N_SUBSET, replace=False)]
X_test_8 = X_test_8[np.random.choice(X_test_8.shape[0], N_SUBSET, replace=False)]

In [3]:
# Memorizer fitted on the digit subset; per its arguments it copies n_copying=3
# training points with radius 0.8 (pixel space). 20 samples are drawn to match
# the size of the training subset.
q = Memorizer(radius=0.8, n_copying=3)
q.fit(X_train_8)
X_gen = q.sample(20)

In [4]:
# Project the flattened 784-d digits (train, test, generated, and the memorized
# training points) to 2-D with one shared Gaussian random projection.
# random_state is fixed so re-running this cell gives the same projection,
# instead of depending on the current state of the global NumPy RNG.
rp = GaussianRandomProjection(n_components=2, random_state=0)
X_train_8_rp = rp.fit_transform(X_train_8)
X_test_8_rp = rp.transform(X_test_8)
X_gen_rp = rp.transform(X_gen)
subset_rp = rp.transform(q.subset)

In [7]:
# Figure: (a) training digits vs. memorizer samples in pixel space,
# (b) the same points after the 2-D random projection.
textwidth = set_plotting_params()
colors = sns.color_palette('colorblind')
markers = ['o', 'x', 's', '^', 'v', '<', '>', 'd', 'p', 'P']

N_IMAGE_COLS = 5     # images shown per row in panel (a)
N_SUBSET_SHOWN = 3   # memorized training points shown first in the "Train" row
# Circle radius around memorized points in panel (b), in projected units.
# NOTE(review): not derived from the Memorizer radius (0.8, pixel space);
# presumably hand-tuned for legibility.
SUBSET_CIRCLE_RADIUS_RP = 2

fig = plt.figure(figsize=(textwidth, 0.3*textwidth))
grid_spec = fig.add_gridspec(1, 2, width_ratios=[1, 1])

# Left panel: a 2 x (1 + N_IMAGE_COLS) sub-grid. Column 0 holds the rotated
# row labels, the remaining columns hold one 28x28 image each.
left_grid = grid_spec[0, 0].subgridspec(2, N_IMAGE_COLS + 1, width_ratios=[0.3] + [1] * N_IMAGE_COLS)
fig.text(0.25, 0.925, '(a) Pixel-Space', va='center', ha='center')

# Row content: the memorized training points followed by other training images,
# then the first generated samples.
# NOTE(review): X_train_8[0] / X_train_8[1] may themselves be in q.subset, in
# which case the "Train" row shows a duplicate image — confirm.
image_rows = {
    'Train': [q.subset[i] for i in range(N_SUBSET_SHOWN)]
             + [X_train_8[i] for i in range(N_IMAGE_COLS - N_SUBSET_SHOWN)],
    'Generated': [X_gen[i] for i in range(N_IMAGE_COLS)],
}

for row, (row_label, images) in enumerate(image_rows.items()):
    ax_label = fig.add_subplot(left_grid[row, 0])  # text-only cell for the row label
    ax_label.text(0.5, 0.5, row_label, va='center', ha='center', fontsize=7, rotation=90)
    ax_label.axis('off')

    for col, image in enumerate(images):
        ax_img = fig.add_subplot(left_grid[row, col + 1])
        ax_img.imshow(image.reshape(28, 28), cmap='gray')
        ax_img.axis('off')

# Right panel: scatter plot in the 2-D random projection
ax2 = fig.add_subplot(grid_spec[0, 1])
ax2.scatter(X_gen_rp[:, 0], X_gen_rp[:, 1], color=colors[1], s=12, marker=markers[1], alpha=0.7, label='Generated')
ax2.scatter(X_train_8_rp[:, 0], X_train_8_rp[:, 1], color=colors[0], s=12, marker=markers[0], alpha=0.7, label='Train')
ax2.set_title("(b) 2D Random Projection")
# circle the projected memorized training points
for center in subset_rp:
    ax2.add_artist(plt.Circle(center, SUBSET_CIRCLE_RADIUS_RP, fill=False, color='black', linewidth=0.7))
ax2.set_xticks([])
ax2.set_yticks([])
ax2.legend(loc='lower right', fontsize=7)

plt.tight_layout()
out_path = '../doc/algorithm_vis.png'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path, dpi=300)
plt.close()

## Visualizing the Data-Copying Definition

In [7]:
# First concept: blow-up. Three training points; the generated set places a
# dense, tight cloud of copies around X_bu[0] and only a few looser samples
# around the other two. Validation points are drawn around every training point.
rng_bu = np.random.default_rng(0)  # seeded so the figure is reproducible
X_bu = np.array([[0, -.25], [-0.5, 0], [0.5, 0.5]])

# (number of generated samples, noise std) per training point in X_bu
copy_spec = [(10, 0.1), (3, 0.2), (3, 0.2)]
n_valid_per_point, valid_std = 2, 0.2

# generated set: isotropic Gaussian noise around each training point
X_gen_bu = np.vstack([
    rng_bu.normal(0, std, (n, 2)) + center
    for center, (n, std) in zip(X_bu, copy_spec)
])
# validation set: the same small number of samples around every training point
X_valid_bu = np.vstack([
    rng_bu.normal(0, valid_std, (n_valid_per_point, 2)) + center
    for center in X_bu
])

In [8]:
# Second concept: locality. q copies a few training points, while the rest of
# the generated set is drawn uniformly, like the training data, so q deviates
# from the data distribution only locally.
np.random.seed(0)  # reproducible figure on Restart & Run All
X = np.random.rand(10, 2)      # training points in the unit square
X_val = np.random.rand(10, 2)  # validation points, same distribution

# NOTE(review): presumably Memorizer samples near n_copying=3 memorized training
# points (q.subset) within `radius` — confirm against src.generative.
q = Memorizer(n_copying=3, radius=0.075)
X_gen_copies = q.fit(X).sample(15)
X_gen_uniform = np.random.rand(10, 2)  # non-copying part of the generated set

X_gen = np.vstack([X_gen_copies, X_gen_uniform])

In [9]:
# Paper figure for the two ingredients of the data-copying definition:
# (a) blow-up around one training point, (b) locality of copying.
textwidth = set_plotting_params()
colors = sns.color_palette('colorblind')
markers = ['o', 'x', 's', '^', 'v', '<', '>', 'd', 'p', 'P']

fig, axs = plt.subplots(1, 2, figsize=(textwidth, 0.5*textwidth))

# (axis, train, generated, validation) for each panel; colors/markers index
# 0/1/2 are used consistently for train/generated/validation.
panels = [(axs[0], X_bu, X_gen_bu, X_valid_bu), (axs[1], X, X_gen, X_val)]
set_labels = ['Train', 'Generated', 'Validation']
for ax, *point_sets in panels:
    for k, (points, label) in enumerate(zip(point_sets, set_labels)):
        ax.scatter(points[:, 0], points[:, 1], color=colors[k], s=12, marker=markers[k], alpha=0.7, label=label)

# (a) blow-up: circle the over-sampled training point
axs[0].add_artist(plt.Circle(X_bu[0], 0.2, fill=False, color='black', linewidth=0.7))
axs[0].set_xlim(-1, 1)
axs[0].set_ylim(-1, 1)

# (b) locality: circle the memorized training points.
# NOTE(review): 0.080 is presumably chosen slightly above the Memorizer radius
# (0.075) so the copies sit inside the circles — confirm.
for center in q.subset:
    axs[1].add_artist(plt.Circle(center, 0.080, fill=False, color='black', linewidth=0.7))

axs[0].set_title("(a) Blow-up", fontsize=8)
axs[1].set_title("(b) Locality", fontsize=8)
axs[0].legend(loc='upper left', fontsize=8)

for ax in axs:
    # Equal data-unit scaling on both axes; equal limits alone do not make the
    # panel square, and without this the circles render as ellipses.
    ax.set_aspect('equal')
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
out_path = '../doc/data_copying.png'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path, dpi=300)
plt.close()

In [11]:
# Single-panel versions of the data-copying figure for the presentation slides.
# Reuses textwidth / colors / markers and the point sets from the cells above.
PRESENTATION_DIR = '../doc/Presentation'
os.makedirs(PRESENTATION_DIR, exist_ok=True)


def _save_single_panel(train, generated, validation, centers, radius, filename, lim=None):
    """Draw one square train/generated/validation scatter with circles and save it.

    Args:
        train, generated, validation: arrays of 2-D points (column 0 = x, 1 = y).
        centers: iterable of 2-D points to circle.
        radius: circle radius in data units.
        filename: file name inside PRESENTATION_DIR.
        lim: optional (low, high) applied to both x and y limits.
    """
    fig, ax = plt.subplots(1, 1, figsize=(0.5*textwidth, 0.5*textwidth))
    ax.scatter(train[:, 0], train[:, 1], color=colors[0], s=12, marker=markers[0], alpha=0.7, label='Train')
    ax.scatter(generated[:, 0], generated[:, 1], color=colors[1], s=12, marker=markers[1], alpha=0.7, label='Generated')
    ax.scatter(validation[:, 0], validation[:, 1], color=colors[2], s=12, marker=markers[2], alpha=0.7, label='Validation')
    for center in centers:
        ax.add_artist(plt.Circle(center, radius, fill=False, color='black', linewidth=0.7))
    if lim is not None:
        ax.set_xlim(*lim)
        ax.set_ylim(*lim)
    ax.set_aspect('equal')  # equal scaling so the circles render as circles
    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(loc='upper left', fontsize=8)
    fig.tight_layout()
    fig.savefig(os.path.join(PRESENTATION_DIR, filename), dpi=300)
    plt.close(fig)


# blow-up
_save_single_panel(X_bu, X_gen_bu, X_valid_bu, centers=[X_bu[0]], radius=0.2,
                   filename='data_copying_blowup.png', lim=(-1, 1))

# locality: circles around the memorized training points q.subset
_save_single_panel(X, X_gen, X_val, centers=q.subset, radius=0.080,
                   filename='data_copying_locality.png')