# Подготовительный этап

In [None]:
!mkdir imgs

In [None]:
import pandas as pd
from keras.datasets import mnist
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.manifold import TSNE
import cv2
from google.colab.patches import cv2_imshow
from tqdm.notebook import tqdm
from ipywidgets import interact
from sklearn import datasets

%matplotlib inline
plt.style.use('ggplot')

Загружаем датасет MNIST и оставляем только его тестовую часть — 10 тыс. изображений рукописных цифр с метками.

In [None]:
# load_data() returns ((x_train, y_train), (x_test, y_test)); keep only the test split.
# Indexing instead of unpacking into `_` avoids overwriting IPython's `_`
# output-history variable (once assigned by the user, IPython stops updating it).
X, y = mnist.load_data()[1]

In [None]:
# Dimensions of the loaded test images.
np.shape(X)

In [None]:
# Preview the first 24 test digits, each captioned with its label.
fig, ax = plt.subplots(4, 6)
for axi, image, label in zip(ax.flat, X, y):
    # MNIST images are 28x28. Reshaping keeps this cell re-runnable after the
    # flattening cell below has turned X into one row of pixels per image.
    axi.imshow(image.reshape(28, 28), cmap='gray')
    axi.set(xticks=[], yticks=[])
    axi.set_xlabel(label, color='black')

In [None]:
# Flatten every image into a single row of pixel features: TSNE expects
# a 2-D (n_samples, n_features) input.
X = X.reshape(X.shape[0], -1)
X.shape

# t-SNE (scikit-learn)

Рассмотрим [библиотечную реализацию t-SNE из scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html) на небольшом наборе синтетических данных.

In [None]:
# Small synthetic dataset for trying out the library t-SNE; the CSV has no header row.
# NOTE(review): every column is passed to TSNE below — confirm the file has no label column.
link = 'https://courses.openedu.ru/assets/courseware/v1/6c89dd85d23926d43494d0e4dd968840/asset-v1:ITMOUniversity+INTROMLADVML+fall_2023_ITMO_mag+type@asset+block/94_16.csv'
data = pd.read_csv(filepath_or_buffer=link, header=None)
data.head()

In [None]:
# Embed the synthetic data in 2-D. With init='random', the layout changes on every
# run unless random_state is fixed.
t_SNE_lib = TSNE(n_components=2, init='random', perplexity=30, random_state=42)
data_tSNE_lib = t_SNE_lib.fit_transform(data)

fig, ax = plt.subplots()
ax.scatter(data_tSNE_lib[:, 0], data_tSNE_lib[:, 1])
ax.set(title='t-SNE of the synthetic data (perplexity=30)',
       xlabel='t-SNE 1', ylabel='t-SNE 2')
plt.show()

# t-SNE на MNIST

In [None]:
n_samples = 10000

X_sampled = X[:n_samples]
# random_state pins the random initialisation, so the embedding (and the plot
# below) is reproducible across runs.
t_SNE_lib = TSNE(n_components=2, init='random', perplexity=30, random_state=42)
X_tSNE_lib = t_SNE_lib.fit_transform(X_sampled)
X_tSNE_lib[:3]

Визуализируем результат, раскрасив точки по истинным меткам цифр.

In [None]:
plt.rcParams["figure.figsize"] = (12, 8)
fig, ax = plt.subplots()
# Slice the labels to the embedding's length. The perplexity sweep further down
# overwrites X_tSNE_lib with a 1000-point embedding, and `hue` must have the
# same length as x and y.
sns.scatterplot(
    x=X_tSNE_lib[:, 0],
    y=X_tSNE_lib[:, 1],
    hue=y[:len(X_tSNE_lib)],
    palette=sns.color_palette("hls", 10),
    ax=ax,
)
ax.set(title='t-SNE of MNIST test digits, colored by label',
       xlabel='t-SNE 1', ylabel='t-SNE 2')
plt.show()

In [None]:
n_samples = 1000
X_sampled = X[:n_samples]

# Embed the same subset once per perplexity value (10, 20, ..., 100).
# random_state is fixed for every run, so only perplexity varies between them.
images = []
for perp in tqdm(range(10, 101, 10)):
    t_SNE_lib = TSNE(
        n_components=2,
        perplexity=perp,
        init='random',
        random_state=42,
    )
    X_tSNE_lib = t_SNE_lib.fit_transform(X_sampled)
    images.append(X_tSNE_lib)

In [None]:
# Render each embedding from the sweep to imgs/step_<idx>.png for the viewer below.
# The imgs/ directory is created at the top of the notebook.
plt.rcParams["figure.figsize"] = (12, 8)  # loop-invariant, so set it once
for idx, img in enumerate(images):
    # A fresh figure per plot, so nothing is drawn onto a figure left open elsewhere.
    fig, ax = plt.subplots()
    # Take the label slice from the embedding itself rather than the global
    # n_samples, which another cell may have reassigned.
    sns.scatterplot(x=img[:, 0], y=img[:, 1], hue=y[:len(img)],
                    palette=sns.color_palette("hls", 10), ax=ax)
    ax.set_axis_off()
    fig.savefig(f'imgs/step_{idx}.png')
    plt.close(fig)

In [None]:
# Load the saved plots back, one per embedding in `images`.
# cv2.imread returns channels in BGR order, and returns None (without raising)
# when a file cannot be read. So check the result, then convert to the RGB
# order that plt.imshow expects; otherwise red and blue are swapped on screen.
digits = []
for idx in range(len(images)):
    path = f'imgs/step_{idx}.png'
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f'could not read {path}')
    digits.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

In [None]:
def browse_images(digits, perplexities=None):
    """Show a slider that flips through pre-rendered t-SNE plots.

    Args:
        digits: sequence of images accepted by ``plt.imshow``, one per
            perplexity value.
        perplexities: optional sequence of perplexity values aligned with
            ``digits``, used for the plot titles. When omitted, image ``i``
            is titled with perplexity ``(i + 1) * 10``, which matches the
            ``range(10, 110, 10)`` sweep in this notebook.

    Raises:
        ValueError: if ``digits`` is empty, or if ``perplexities`` does not
            have one entry per image.
    """
    n = len(digits)
    if n == 0:
        raise ValueError('browse_images() needs at least one image')
    if perplexities is None:
        perplexities = [(i + 1) * 10 for i in range(n)]
    elif len(perplexities) != n:
        raise ValueError(
            f'got {len(perplexities)} perplexities for {n} images'
        )

    def view_image(i):
        plt.imshow(digits[i])
        plt.title('Perplexity: ' + str(perplexities[i]))
        plt.axis('off')
        plt.show()

    # The slider ranges over valid indices 0 .. n-1.
    interact(view_image, i=(0, n - 1))

In [None]:
# Interactive viewer: move the slider to compare embeddings across perplexities.
browse_images(digits=digits)