In [38]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.random_projection import GaussianRandomProjection
import os
import sys

sys.path.append('../')
sys.path.append('../src/')

from src.generative import *
from src.utils import *

# Visualization of data-copying in the MNIST case

In [39]:
# Load MNIST through the project's wrapper; the object returned by `fit()`
# hands back the train/test arrays and their labels.
mnist = MNIST(root='../data.nosync/').fit()
X_train, y_train, X_test, y_test = mnist.get_train_and_test_data()

c = 8  # digit class used throughout this visualization
n_pixels = 28 * 28  # every image is flattened into one row of 784 values

# Keep only the images labeled `c`, one flattened image per row.
train_mask = y_train == c
test_mask = y_test == c
X_train_8 = X_train[train_mask].reshape(-1, n_pixels)
X_test_8 = X_test[test_mask].reshape(-1, n_pixels)

# Draw 20 distinct images from each split. The global seed is fixed first and
# the train indices are drawn before the test indices, so the picks are repeatable.
np.random.seed(0)
train_pick = np.random.choice(X_train_8.shape[0], size=20, replace=False)
X_train_8 = X_train_8[train_pick]
test_pick = np.random.choice(X_test_8.shape[0], size=20, replace=False)
X_test_8 = X_test_8[test_pick]

In [40]:
# Fit the project's Memorizer on the class-8 training subset, then draw
# generated samples from it.
# NOTE(review): the exact meaning of `n_copying` and `radius` is defined in
# src.generative and is not visible here -- confirm before tuning them.
q = Memorizer(radius=0.7, n_copying=3)
q.fit(X_train_8)

n_generated = 20  # number of samples requested from the fitted model
X_gen = q.sample(n_generated)

In [41]:
# Project every sample set to 2D with ONE shared random projection: the
# projection is fitted on the training subset and then re-used (transform only)
# for the test, generated and `q.subset` points so all of them live in the same
# 2D coordinate system.
#
# `random_state` is pinned so the projection matrix does not depend on whatever
# state the global NumPy RNG was left in by earlier cells; without it,
# re-running this cell on its own produces a different scatter plot each time.
rp = GaussianRandomProjection(n_components=2, random_state=0)
X_train_8_rp = rp.fit_transform(X_train_8)
X_test_8_rp = rp.transform(X_test_8)
X_gen_rp = rp.transform(X_gen)
subset_rp = rp.transform(q.subset)

In [45]:
# Two-panel figure:
#   (a) example images in pixel space (train / test / generated rows),
#   (b) the same sample sets after the shared 2D random projection.
# The figure is written to disk and then closed; nothing is displayed inline.
textwidth = set_plotting_params()  # project helper; its return value is used as the figure width
colors = sns.color_palette('colorblind')

fig = plt.figure(figsize=(textwidth, 0.4*textwidth))
grid_spec = fig.add_gridspec(1, 2, width_ratios=[1, 1])

# Left panel: 3 rows x 5 images, preceded by a narrow first column that holds
# the rotated row labels.
left_grid = grid_spec[0].subgridspec(3, 6, width_ratios=[0.3, 1, 1, 1, 1, 1])
fig.text(0.25, 0.925, '(a) Pixel-Space', va='center', ha='center')
row_labels = ['Train', 'Underfit', 'Copies']

for row, label in enumerate(row_labels):
    # Axis-less subplot that only carries the row label text.
    ax_label = fig.add_subplot(left_grid[row, 0])
    ax_label.text(0.5, 0.5, label, va='center', ha='center', fontsize=7, rotation=90)
    ax_label.axis('off')

    for col in range(5):
        if row == 0:
            # 'Train' row: the first three entries of q.subset, followed by
            # the first two images of the training subset.
            image = q.subset[col] if col < 3 else X_train_8[col - 3]
        elif row == 1:
            # 'Underfit' row: held-out test images.
            image = X_test_8[col]
        else:
            # 'Copies' row: samples drawn from the fitted Memorizer.
            image = X_gen[col]

        ax_img = fig.add_subplot(left_grid[row, col + 1])
        ax_img.imshow(image.reshape(28, 28), cmap='gray')
        ax_img.axis('off')  # images are shown without ticks or frame

# Right panel: scatter plot of the projected points.
ax2 = fig.add_subplot(grid_spec[0, 1])
ax2.scatter(X_gen_rp[:, 0], X_gen_rp[:, 1], color=colors[3], s=12, marker='x', alpha=0.7, label='Copies')
ax2.scatter(X_train_8_rp[:, 0], X_train_8_rp[:, 1], color=colors[0], s=12, alpha=0.7, label='Train')
# 'Underfit' gets its own palette entry; it previously re-used colors[3], the
# same color as 'Copies', which made the two legend entries differ only by marker.
ax2.scatter(X_test_8_rp[:, 0], X_test_8_rp[:, 1], color=colors[1], s=12, alpha=0.7, label='Underfit')
ax2.set_title("(b) 2D Random Projection")

# Outline each projected q.subset point with a circle.
# NOTE(review): 1.5 is a hard-coded radius in projected coordinates; confirm
# how it is meant to relate to the Memorizer's pixel-space radius.
for center in subset_rp:
    ax2.add_artist(plt.Circle(center, 1.5, fill=False, color='g', linewidth=0.4))

# The projected axes carry no interpretable units, so ticks are removed.
ax2.set_xticks([])
ax2.set_yticks([])
ax2.legend(loc='lower right', fontsize=7)

fig.tight_layout()

# Make sure the target directory exists so saving does not fail on a fresh checkout.
out_path = '../doc/algorithm_vis.png'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path, dpi=300)
plt.close(fig)