In [1]:
import time
import copy
import numpy as np
import pandas as pd
import seaborn as sn
from tqdm import tqdm
import matplotlib.pyplot as plt

import sklearn
from sklearn.manifold import TSNE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

from models import *
from utils import *
from datasets import *

In [2]:
import torchvision as tv
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Data preprocessing: convert to tensor and normalise with the MNIST mean/std.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# Same pipeline as `transform`; kept as an alias in case another cell refers to this name.
trainTransform = transform
trainset = tv.datasets.MNIST(root='./data', train=True, download=False, transform=transform)
testset = tv.datasets.MNIST(root='./data', train=False, download=False, transform=transform)

# Subsample N_SAMPLES train and N_SAMPLES test images.
# - replace=False: np.random.choice defaults to sampling WITH replacement, which put
#   duplicate images into the subset.
# - seeded generator: re-running the notebook reproduces the same subset.
N_SAMPLES = 1000
rng = np.random.default_rng(42)
idx = rng.choice(len(trainset), size=N_SAMPLES, replace=False)
trainset = [trainset[i] for i in idx]
idx = rng.choice(len(testset), size=N_SAMPLES, replace=False)
testset = [testset[i] for i in idx]

# Each item is (image, label). Add a leading batch axis so vstack yields (N, 1, 28, 28).
traindata = [img.unsqueeze(0) for img, _ in trainset]
trainlabel = [label for _, label in trainset]
testdata = [img.unsqueeze(0) for img, _ in testset]
testlabel = [label for _, label in testset]

X_train = torch.vstack(traindata)
y_train = torch.tensor(trainlabel)
X_test = torch.vstack(testdata)
y_test = torch.tensor(testlabel)

assert X_train.shape == (N_SAMPLES, 1, 28, 28) and X_test.shape == (N_SAMPLES, 1, 28, 28)
assert len(y_train) == N_SAMPLES and len(y_test) == N_SAMPLES

X_train.shape, X_test.shape, y_train.shape, y_test.shape

(torch.Size([1000, 1, 28, 28]),
 torch.Size([1000, 1, 28, 28]),
 torch.Size([1000]),
 torch.Size([1000]))

# UMAP baseline (umap-learn): 2-D embeddings of raw pixels over an n_neighbors × min_dist grid

In [8]:
import umap

# Baseline sweep: fit umap-learn directly on the flattened training images for every
# (n_neighbors, min_dist) pair, record the fit time, and save a scatter plot per config.
# NOTE(review): no random_state is passed, so embeddings are not reproducible. Passing one
# makes umap-learn run single-threaded, which would change the timings recorded here.

time_cost = []  # fit time in seconds for each config, in sweep order
# Flattening is loop-invariant, so do it once: (N, 1, 28, 28) -> (N, 784).
X_train_flat = X_train.view(-1, 28 * 28).numpy()

for n in [5, 10, 20, 50, 100, 200]:

    for min_dist in [0.03, 0.1, 0.25, 0.5, 0.8, 0.99]:

        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        t1 = time.perf_counter()
        reducer = umap.UMAP(n_neighbors=n, min_dist=min_dist, n_components=2, metric='euclidean')
        emb = reducer.fit_transform(X_train_flat)
        t2 = time.perf_counter()
        print("n_neighbors: {}, min_dist:{}, time: {}".format(n, min_dist, t2-t1))
        time_cost.append(t2 - t1)

        draw_z(
            z=normalise(emb),
            cls=y_train,
            s=1,
            title="n: {}, min_dist: {}, time: {}".format(n, min_dist, t2-t1),
            display=False,
            save_path="./data/facts_check/vis_umapemb_n{}_mindist{}.png".format(n, min_dist)
        )
        # Plots are only saved to disk. draw_z appears to leave its figure open: the
        # recorded run emitted matplotlib warnings from plt.figure inside it. Close
        # figures so they don't pile up across the 36 fits.
        plt.close('all')

n_neighbors: 5, min_dist:0.03, time: 2.309999704360962
n_neighbors: 5, min_dist:0.1, time: 2.2879889011383057
n_neighbors: 5, min_dist:0.25, time: 2.6669931411743164
n_neighbors: 5, min_dist:0.5, time: 2.2809908390045166
n_neighbors: 5, min_dist:0.8, time: 2.2739901542663574
n_neighbors: 5, min_dist:0.99, time: 2.287997007369995
n_neighbors: 10, min_dist:0.03, time: 2.901000738143921
n_neighbors: 10, min_dist:0.1, time: 2.440992593765259
n_neighbors: 10, min_dist:0.25, time: 2.402992010116577
n_neighbors: 10, min_dist:0.5, time: 2.825998306274414
n_neighbors: 10, min_dist:0.8, time: 2.397998571395874
n_neighbors: 10, min_dist:0.99, time: 2.4529964923858643
n_neighbors: 20, min_dist:0.03, time: 2.557995557785034
n_neighbors: 20, min_dist:0.1, time: 2.957990884780884
n_neighbors: 20, min_dist:0.25, time: 2.5889925956726074
n_neighbors: 20, min_dist:0.5, time: 2.5638492107391357
n_neighbors: 20, min_dist:0.8, time: 2.8739993572235107
n_neighbors: 20, min_dist:0.99, time: 2.539990901947021

  fig = plt.figure(figsize=figsize)


n_neighbors: 50, min_dist:0.25, time: 3.2139925956726074


  plt.figure(dpi=1500)


n_neighbors: 50, min_dist:0.5, time: 2.801990270614624
n_neighbors: 50, min_dist:0.8, time: 2.797992467880249
n_neighbors: 50, min_dist:0.99, time: 2.834996461868286
n_neighbors: 100, min_dist:0.03, time: 3.3759913444519043
n_neighbors: 100, min_dist:0.1, time: 2.984992265701294
n_neighbors: 100, min_dist:0.25, time: 3.0069992542266846
n_neighbors: 100, min_dist:0.5, time: 2.9509918689727783
n_neighbors: 100, min_dist:0.8, time: 3.3719921112060547
n_neighbors: 100, min_dist:0.99, time: 3.0249969959259033
n_neighbors: 200, min_dist:0.03, time: 3.2029902935028076
n_neighbors: 200, min_dist:0.1, time: 3.590991973876953
n_neighbors: 200, min_dist:0.25, time: 3.152992010116577
n_neighbors: 200, min_dist:0.5, time: 3.43599009513855
n_neighbors: 200, min_dist:0.8, time: 3.184990406036377
n_neighbors: 200, min_dist:0.99, time: 3.5589921474456787


# UMAP + neural network: train a conv encoder with the UMAP loss for each (n_neighbors, min_dist)

In [13]:
# Neural-network UMAP sweep. For each (n_neighbors, min_dist): build the UMAP neighbour
# graph, train a fresh conv encoder with the UMAP loss, then save its weights and loss curve.
# The full n_neighbors grid is [5, 10, 20, 50, 100, 200]; this run covers the larger values.
N_NEIGHBORS_GRID = [50, 100, 200]
MIN_DIST_GRID = [0.03, 0.1, 0.25, 0.5, 0.8, 0.99]
N_EPOCHS = 100
batch_size = 1000  # 1024
graph_batch_size = 1000  # batch_size argument passed to ConstructUMAPGraph
# Use the GPU when present; otherwise fall back to CPU instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class Encoder(nn.Module):
    """Conv encoder mapping a batch of 1x28x28 images to `output_dim`-d embeddings.

    Defined once here rather than re-created inside every iteration of the sweep.
    """

    def __init__(self, output_dim=2):
        super().__init__()
        # kernel 3 / stride 2 / no padding: 28x28 -> 13x13 -> 6x6, hence 128*6*6 below.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(128 * 6 * 6, 512)
        self.linear2 = nn.Linear(512, 512)
        self.linear3 = nn.Linear(512, output_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        # Linear output layer. A ReLU here would clamp every embedding coordinate
        # to >= 0 and could zero out a whole axis, collapsing the 2-D layout.
        return self.linear3(x)


for n in N_NEIGHBORS_GRID:

    for min_dist in MIN_DIST_GRID:

        print("n_neighbours: {}, min_dist: {}".format(n, min_dist))
        # Seed torch so the encoder's weight initialisation is reproducible per config.
        torch.manual_seed(42)

        # 1. dataset: construct the graph of nearest neighbours for train and test
        graph_constructor_train = ConstructUMAPGraph(metric='euclidean', n_neighbors=n, batch_size=graph_batch_size, random_state=42)
        epochs_per_sample_train, head_train, tail_train, weight_train = graph_constructor_train(X_train)

        graph_constructor_test = ConstructUMAPGraph(metric='euclidean', n_neighbors=n, batch_size=graph_batch_size, random_state=42)
        epochs_per_sample_test, head_test, tail_test, weight_test = graph_constructor_test(X_test)

        train_dataset = UMAPDataset(
            data=X_train, labels=y_train,
            epochs_per_sample=epochs_per_sample_train, head=head_train, tail=tail_train, weight=weight_train,
            device=device, batch_size=batch_size, feedback=None
        )
        # Not used by the training loop below; built so it is available to other cells.
        test_dataset = UMAPDataset(
            data=X_test, labels=y_test,
            epochs_per_sample=epochs_per_sample_test, head=head_test, tail=tail_test, weight=weight_test,
            device=device, batch_size=batch_size, feedback=None
        )

        criterion = UMAPLoss(device=device, min_dist=min_dist, batch_size=batch_size, negative_sample_rate=5, edge_weight=None, repulsion_strength=1.0)

        # 2. model
        model = Encoder(output_dim=2).to(device)

        # 3. training
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

        train_losses = []  # summed batch loss per epoch
        for epoch in tqdm(range(N_EPOCHS)):
            train_loss = 0.
            # get_batches yields 6-tuples; only the two endpoint batches are used here.
            for batch_to, batch_from, _index_to, _index_from, _labels, _feedback in train_dataset.get_batches():
                optimizer.zero_grad()
                embedding_to = model(batch_to)
                embedding_from = model(batch_from)
                loss = criterion(embedding_to, embedding_from)
                train_loss += loss.item()
                loss.backward()
                optimizer.step()

            train_losses.append(train_loss)

        torch.save(model.state_dict(), './data/facts_check/umap_n{}_dist{}.pt'.format(n, min_dist))
        torch.save(torch.tensor(train_losses), './data/facts_check/loss_umap_n{}_dist{}.pt'.format(n, min_dist))

n_neighbours: 50, min_dist: 0.03
Tue Nov 28 17:38:44 2023 Building RP forest with 7 trees
Tue Nov 28 17:38:44 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations
Tue Nov 28 17:38:44 2023 Building RP forest with 7 trees
Tue Nov 28 17:38:44 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations


100%|██████████| 100/100 [10:15<00:00,  6.16s/it]

n_neighbours: 50, min_dist: 0.1
Tue Nov 28 17:49:01 2023 Building RP forest with 7 trees
Tue Nov 28 17:49:01 2023 NN descent for 10 iterations
	 1  /  10





	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations
Tue Nov 28 17:49:01 2023 Building RP forest with 7 trees
Tue Nov 28 17:49:01 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations


100%|██████████| 100/100 [10:20<00:00,  6.21s/it]

n_neighbours: 50, min_dist: 0.25
Tue Nov 28 17:59:23 2023 Building RP forest with 7 trees
Tue Nov 28 17:59:23 2023 NN descent for 10 iterations
	 1  /  10





	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations
Tue Nov 28 17:59:23 2023 Building RP forest with 7 trees
Tue Nov 28 17:59:24 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations


100%|██████████| 100/100 [10:15<00:00,  6.15s/it]

n_neighbours: 50, min_dist: 0.5
Tue Nov 28 18:09:40 2023 Building RP forest with 7 trees
Tue Nov 28 18:09:40 2023 NN descent for 10 iterations
	 1  /  10





	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations
Tue Nov 28 18:09:40 2023 Building RP forest with 7 trees
Tue Nov 28 18:09:41 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations


100%|██████████| 100/100 [10:14<00:00,  6.14s/it]

n_neighbours: 50, min_dist: 0.8
Tue Nov 28 18:19:56 2023 Building RP forest with 7 trees
Tue Nov 28 18:19:56 2023 NN descent for 10 iterations
	 1  /  10





	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations
Tue Nov 28 18:19:56 2023 Building RP forest with 7 trees
Tue Nov 28 18:19:56 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations


100%|██████████| 100/100 [10:17<00:00,  6.17s/it]

n_neighbours: 50, min_dist: 0.99
Tue Nov 28 18:30:15 2023 Building RP forest with 7 trees
Tue Nov 28 18:30:15 2023 NN descent for 10 iterations
	 1  /  10





	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations
Tue Nov 28 18:30:15 2023 Building RP forest with 7 trees
Tue Nov 28 18:30:15 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations


100%|██████████| 100/100 [10:19<00:00,  6.19s/it]

n_neighbours: 100, min_dist: 0.03
Tue Nov 28 18:40:35 2023 Building RP forest with 7 trees
Tue Nov 28 18:40:35 2023 NN descent for 10 iterations
	 1  /  10





	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations
Tue Nov 28 18:40:36 2023 Building RP forest with 7 trees
Tue Nov 28 18:40:36 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations


100%|██████████| 100/100 [40:51<00:00, 24.52s/it]

n_neighbours: 100, min_dist: 0.1
Tue Nov 28 19:21:31 2023 Building RP forest with 7 trees
Tue Nov 28 19:21:31 2023 NN descent for 10 iterations
	 1  /  10





	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations
Tue Nov 28 19:21:31 2023 Building RP forest with 7 trees
Tue Nov 28 19:21:32 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations


100%|██████████| 100/100 [40:54<00:00, 24.55s/it]

n_neighbours: 100, min_dist: 0.25
Tue Nov 28 20:02:29 2023 Building RP forest with 7 trees
Tue Nov 28 20:02:29 2023 NN descent for 10 iterations





	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations
Tue Nov 28 20:02:30 2023 Building RP forest with 7 trees
Tue Nov 28 20:02:30 2023 NN descent for 10 iterations
	 1  /  10
	 2  /  10
	 3  /  10
	Stopping threshold met -- exiting after 3 iterations


  1%|          | 1/100 [00:31<52:34, 31.86s/it]


KeyboardInterrupt: 