In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.datasets import make_blobs

# Global plot theme for every figure below: enlarged fonts and a white grid.
# `sns.set` is only a backwards-compatible alias of `set_theme`, the preferred API.
sns.set_theme(font_scale=2., style='whitegrid')

## Illustrative Example: A Bad vs. a Good Dataset

In [None]:
# Side-by-side comparison of two labelled 2-D point sets:
#   left  -- Gaussian noise with labels assigned by index (label ignores position)
#   right -- two Gaussian blobs from `make_blobs`, labelled by blob membership
SEED = 137        # fixed seed so "Restart & Run All" redraws the same figure
N_POINTS = 100
class_cmap = sns.color_palette('Set1', as_cmap=True)

fig, axes = plt.subplots(figsize=(12, 5), ncols=2, sharey=True, sharex=True)

# Bad dataset: one isotropic cloud with std 4; the first half of the points is
# labelled 0 and the second half 1, so the label carries no spatial signal.
rng = np.random.default_rng(SEED)
X_bad = 4. * rng.standard_normal((N_POINTS, 2))
# Derive labels from the point count itself so len(y) == len(X) even for odd counts.
y_bad = (np.arange(N_POINTS) >= N_POINTS // 2).astype(int)

ax = axes[0]
ax.scatter(X_bad[:, 0], X_bad[:, 1], c=y_bad, cmap=class_cmap)
ax.set(xlabel='x', ylabel='y', title='A Bad Dataset')

# Good dataset: two Gaussian blobs, one per class.
X_good, y_good = make_blobs(N_POINTS, 2, centers=2, random_state=SEED)

ax = axes[1]
scatter = ax.scatter(X_good[:, 0], X_good[:, 1], c=y_good, cmap=class_cmap)
ax.set(xlabel='x', title='A Good Dataset')

# ax.legend() already attaches the legend to the axes; an extra add_artist()
# is only needed when stacking several legends on the same axes.
ax.legend(*scatter.legend_elements())

fig.tight_layout()
plt.show()

# Uncomment to export the figure as a PDF.
# fig.savefig('illustration.pdf', bbox_inches='tight')

## Visualize Training Results by GPU Count

In [None]:
# One panel per logged run: the tensor stored under 'module.X' in the checkpoint
# of the run that used 2**i GPUs, coloured by a first-half / second-half labelling.
N_RUNS = 4  # runs are expected under .log/gpus-{1, 2, 4, ..., 2**(N_RUNS - 1)}
class_cmap = sns.color_palette('Set1', as_cmap=True)

# squeeze=False keeps `axes` a 2-D array even if N_RUNS == 1.
fig, axes = plt.subplots(figsize=(6 * N_RUNS, 5), ncols=N_RUNS,
                         sharey=True, sharex=True, squeeze=False)

for i, ax in enumerate(axes.flat):
    n_gpus = 2 ** i
    # NOTE(review): torch.load unpickles the file; only point it at trusted run
    # directories (torch >= 1.13 also accepts weights_only=True).
    state = torch.load(f'.log/gpus-{n_gpus}/files/model.pt', map_location='cpu')
    # detach() is a no-op for tensors without grad and makes .numpy() safe otherwise.
    X = state['module.X'].detach().numpy()
    n = len(X)
    # First half labelled 0, the rest 1; built from n so len(y) == len(X) even
    # when n is odd (concatenating two n // 2 halves would drop a label).
    y = (np.arange(n) >= n // 2).astype(int)

    scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=class_cmap)
    ax.set(xlabel='x', title=f'{n_gpus} GPU{"s" if n_gpus > 1 else ""}')
    if i == 0:
        ax.set(ylabel='y')

# Legend on the rightmost panel only; ax.legend() already attaches it to the axes.
legend = ax.legend(*scatter.legend_elements())

fig.tight_layout()
plt.show()

# Uncomment to export the figure as a PDF.
# fig.savefig('toy_learn.pdf', bbox_inches='tight')