In [1]:
import torch
import torch.nn as nn
import numpy as np
import sys; sys.path.append('../src/')
import os
from model.cmpnts import MLP
from scipy.stats import ortho_group
import matplotlib.pyplot as plt
from math import sqrt

In [2]:
# Root directory for the generated toy datasets; save_dataset() creates its
# '<in_dim>-manifold/<out_dim>-ambient' subfolders beneath it. The path is
# relative, so it is resolved against the kernel's current working directory.
data_dir = '../data/toy'

In [3]:
def nonlinearity(dataset, epsilon=1.):
    """Summarise the singular-value spectrum of a mean-centred dataset.

    Args:
        dataset: Object whose full slice ``dataset[:]`` yields a 2-D tensor
            with one sample per row (a plain tensor works as well).
        epsilon: Reference scale the smallest singular value is compared to.
            It can be set to the ground-truth low-dimensional manifold's
            average std.

    Returns:
        A 4-tuple ``(s.max() / s.min(), s.min() / epsilon, epsilon, s)`` where
        ``s`` holds the singular values of ``X - X.mean(0)``.
    """
    X = dataset[:]
    centered = X - X.mean(0)
    # torch.svd is deprecated; torch.linalg.svdvals is its documented
    # replacement when only the singular values are needed. Older PyTorch
    # builds without it keep using the original call.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'svdvals'):
        s = torch.linalg.svdvals(centered)
    else:
        s = torch.svd(centered, compute_uv=False)[1]
    return s.max() / s.min(), s.min() / epsilon, epsilon, s

In [4]:
def save_dataset(X, Y, i):
    """Save an embedded dataset together with a text report of its spectrum.

    Two files are written to
    ``<data_dir>/<in_dim>-manifold/<out_dim>-ambient/`` (created if missing):
    ``<in_dim>-<out_dim>_<i>.npy`` holding ``Y`` as a NumPy array, and
    ``<in_dim>-<out_dim>_<i>.txt`` holding the singular-value statistics.

    Args:
        X: 2-D tensor of source samples; ``X.shape[1]`` is used as the
            manifold dimension in the file names.
        Y: 2-D tensor of embedded samples; ``Y.shape[1]`` is used as the
            ambient dimension in the file names.
        i: Index appended to the file names to tell repeated draws apart.
    """
    in_dim = X.shape[1]
    out_dim = Y.shape[1]
    path = os.path.join(data_dir, '{}-manifold'.format(in_dim), '{}-ambient'.format(out_dim))
    os.makedirs(path, exist_ok=True)

    data_path = os.path.join(path, '{}-{}_{}.npy'.format(in_dim, out_dim, i))
    report_path = os.path.join(path, '{}-{}_{}.txt'.format(in_dim, out_dim, i))
    # .cpu() returns the tensor itself when it already lives on the CPU, and
    # otherwise makes the .numpy() conversion possible for device tensors.
    np.save(data_path, Y.detach().cpu().numpy())

    # The spectrum of the centred X is exactly what nonlinearity() computes,
    # so reuse it instead of repeating the SVD call here.
    *_, s_x = nonlinearity(X)
    dis, ratio, eps, s = nonlinearity(Y, epsilon=s_x.mean())
    # Singular values are divided by sqrt(number of samples) in the report.
    scale = sqrt(len(X))
    with open(report_path, 'w') as report:
        report.write('s.max / s.min = {}\n'.format(dis))
        report.write('s.min / x.s.mean = {}\n'.format(ratio))
        report.write('x.s.mean = {}\n'.format(eps / scale))
        # NOTE(review): the values are written via the tensor's repr, which
        # PyTorch abbreviates for large tensors under its default print
        # options - confirm this is acceptable for the widest ambient sizes.
        report.write('singular values = {}'.format(s.sort()[0].flip(0) / scale))

# Nonlinear Dataset

In [5]:
from dataset.toy import IsometricEmbedding
from model.flow import _RealNVP

In [12]:
def build_flow(out_dim, n=10, w=256):
    """Build a `_RealNVP` flow acting on `out_dim`-dimensional vectors.

    Args:
        out_dim: Width of the vectors the flow transforms; it is the input and
            output size of every sub-network and the length of every mask.
        n: How many times the pair of complementary masks is repeated, which
            yields ``2 * n`` masks in total.
        w: Hidden width of the sub-networks.

    Returns:
        The `_RealNVP` instance built from the two network factories and the
        ``(2 * n, out_dim)`` float mask tensor.
    """
    def nets():
        # Factory for the Tanh-bounded network (presumably the coupling
        # layers' scale term - confirm against _RealNVP).
        return nn.Sequential(
            nn.Linear(out_dim, w), nn.ReLU(),
            nn.Linear(w, w), nn.ReLU(),
            nn.Linear(w, w), nn.ReLU(),
            nn.Linear(w, out_dim), nn.Tanh())

    def nett():
        # Factory for the unbounded network (presumably the coupling layers'
        # translation term - confirm against _RealNVP).
        return nn.Sequential(
            nn.Linear(out_dim, w), nn.ReLU(),
            nn.Linear(w, w), nn.ReLU(),
            nn.Linear(w, w), nn.ReLU(),
            nn.Linear(w, out_dim))

    # Split the coordinates into a leading and a trailing group. Sizing the
    # trailing group as the remainder keeps each mask exactly out_dim long
    # even when out_dim is odd (two equal halves would come up one short).
    head = out_dim // 2
    tail = out_dim - head
    masks = torch.as_tensor([
        [0]*head + [1]*tail,
        [1]*head + [0]*tail] * n, dtype=torch.float)
    return _RealNVP(nets, nett, masks)

In [19]:
# NOTE(review): tqdm is imported here but not referenced in this cell; the
# progress bars in the recorded output presumably come from Y.embed - confirm.
from tqdm import tqdm

# Dataset generation sweep. Each manifold dimension in `in_dims` is paired
# positionally (via zip) with a sample count in `sizes`; `out_dims` holds the
# multipliers that turn a manifold dimension into its ambient dimensions.
# NOTE(review): no random seed is set in the cells shown, so every run draws
# different samples and network initialisations.
in_dims = [2, 4, 8, 16, 32, 64] # [1, 2, 4, 8, 16, 32, 64]
out_dims = np.asarray([2, 4, 16, 64, 256], dtype=int)
sizes = [300, 1000, 3000, 10000, 30000, 100000]

for in_dim, size in zip(in_dims, sizes):
    # Element-wise product: ambient dimensions are in_dim * {2, 4, 16, 64, 256}.
    for out_dim in out_dims * in_dim:
        # Five independent datasets per (in_dim, out_dim) pair, indexed by i.
        for i in range(5):
            print('{} -> {} ({})'.format(in_dim, out_dim, i))

            # Source samples: `size` draws from a standard normal in in_dim dimensions.
            X = torch.randn(size, in_dim)
            Y = IsometricEmbedding(X, out_dim=out_dim)

            # A fresh flow (10 mask pairs, hidden width 256) is trained per dataset.
            flow = build_flow(out_dim, 10, 256)
            # int() truncates the square root before scaling, so a ratio of 2
            # gives 100 epochs, 4 gives 200, 16 gives 400, and so on.
            # batch_size and lam are forwarded to embed unchanged; their exact
            # meaning is defined in IsometricEmbedding, which is not shown here.
            Y.embed(flow = flow,
                    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-4),
                    epochs = 100 * int(sqrt(out_dim / in_dim)),
                    batch_size = 100,
                    lam = 100.)

            # Y[:] is handed to save_dataset, which reads .shape and calls
            # .detach() on it, so the full slice is expected to be a tensor.
            save_dataset(X, Y[:], i)

4 -> 8 (0)


Embedding: 100%|██████████████████████████████| 100/100 [00:16<00:00,  6.12it/s, isometricity=0.104, linearity=0.265]


4 -> 8 (1)


Embedding: 100%|██████████████████████████████| 100/100 [00:15<00:00,  6.35it/s, isometricity=0.0913, linearity=0.242]


4 -> 8 (2)


Embedding: 100%|██████████████████████████████| 100/100 [00:14<00:00,  6.80it/s, isometricity=0.0851, linearity=0.297]


4 -> 8 (3)


Embedding: 100%|██████████████████████████████| 100/100 [00:14<00:00,  6.78it/s, isometricity=0.0813, linearity=0.318]


4 -> 8 (4)


Embedding: 100%|██████████████████████████████| 100/100 [00:15<00:00,  6.49it/s, isometricity=0.0853, linearity=0.322]


4 -> 16 (0)


Embedding: 100%|██████████████████████████████| 200/200 [00:30<00:00,  6.47it/s, isometricity=0.0684, linearity=0.442]


4 -> 16 (1)


Embedding: 100%|██████████████████████████████| 200/200 [00:30<00:00,  6.58it/s, isometricity=0.0566, linearity=0.254]


4 -> 16 (2)


Embedding: 100%|██████████████████████████████| 200/200 [00:30<00:00,  6.60it/s, isometricity=0.0537, linearity=0.252]


4 -> 16 (3)


Embedding: 100%|██████████████████████████████| 200/200 [00:31<00:00,  6.41it/s, isometricity=0.0574, linearity=0.272]


4 -> 16 (4)


Embedding: 100%|██████████████████████████████| 200/200 [1:07:34<00:00, 20.27s/it, isometricity=0.0541, linearity=0.245]    


4 -> 64 (0)


Embedding: 100%|██████████████████████████████| 400/400 [01:05<00:00,  6.08it/s, isometricity=0.0945, linearity=0.702]


4 -> 64 (1)


Embedding: 100%|██████████████████████████████| 400/400 [01:03<00:00,  6.31it/s, isometricity=0.0885, linearity=0.635]


4 -> 64 (2)


Embedding: 100%|██████████████████████████████| 400/400 [01:03<00:00,  6.32it/s, isometricity=0.101, linearity=0.677] 


4 -> 64 (3)


Embedding: 100%|██████████████████████████████| 400/400 [01:02<00:00,  6.43it/s, isometricity=0.0967, linearity=0.628]


4 -> 64 (4)


Embedding: 100%|██████████████████████████████| 400/400 [01:01<00:00,  6.46it/s, isometricity=0.11, linearity=0.61]   


4 -> 256 (0)


Embedding: 100%|██████████████████████████████| 800/800 [02:51<00:00,  4.65it/s, isometricity=0.0667, linearity=0.381]


4 -> 256 (1)


Embedding: 100%|██████████████████████████████| 800/800 [02:54<00:00,  4.58it/s, isometricity=0.0584, linearity=0.412]


4 -> 256 (2)


Embedding: 100%|██████████████████████████████| 800/800 [02:52<00:00,  4.65it/s, isometricity=0.0459, linearity=0.364]


4 -> 256 (3)


Embedding: 100%|██████████████████████████████| 800/800 [02:53<00:00,  4.62it/s, isometricity=0.0489, linearity=0.405]


4 -> 256 (4)


Embedding: 100%|██████████████████████████████| 800/800 [02:53<00:00,  4.62it/s, isometricity=0.0575, linearity=0.396]


4 -> 1024 (0)


Embedding: 100%|██████████████████████████████| 1600/1600 [11:21<00:00,  2.35it/s, isometricity=0.0297, linearity=0.232]


4 -> 1024 (1)


Embedding: 100%|██████████████████████████████| 1600/1600 [11:26<00:00,  2.33it/s, isometricity=0.0238, linearity=0.225]


4 -> 1024 (2)


Embedding: 100%|██████████████████████████████| 1600/1600 [11:25<00:00,  2.33it/s, isometricity=0.0255, linearity=0.237]


4 -> 1024 (3)


Embedding: 100%|██████████████████████████████| 1600/1600 [11:17<00:00,  2.36it/s, isometricity=0.0295, linearity=0.229]


4 -> 1024 (4)


Embedding: 100%|██████████████████████████████| 1600/1600 [11:20<00:00,  2.35it/s, isometricity=0.0325, linearity=0.23] 


8 -> 16 (0)


Embedding: 100%|██████████████████████████████| 100/100 [00:14<00:00,  6.94it/s, isometricity=0.0818, linearity=0.363]


8 -> 16 (1)


Embedding: 100%|██████████████████████████████| 100/100 [00:14<00:00,  6.96it/s, isometricity=0.0837, linearity=0.324]


8 -> 16 (2)


Embedding: 100%|██████████████████████████████| 100/100 [00:14<00:00,  7.01it/s, isometricity=0.0892, linearity=0.382]


8 -> 16 (3)


Embedding: 100%|██████████████████████████████| 100/100 [00:13<00:00,  7.15it/s, isometricity=0.0915, linearity=0.394]


8 -> 16 (4)


Embedding: 100%|██████████████████████████████| 100/100 [00:13<00:00,  7.14it/s, isometricity=0.0885, linearity=0.303]


8 -> 32 (0)


Embedding: 100%|██████████████████████████████| 200/200 [00:29<00:00,  6.74it/s, isometricity=0.0886, linearity=0.468]


8 -> 32 (1)


Embedding: 100%|██████████████████████████████| 200/200 [00:28<00:00,  6.94it/s, isometricity=0.0975, linearity=0.394]


8 -> 32 (2)


Embedding: 100%|██████████████████████████████| 200/200 [00:28<00:00,  7.02it/s, isometricity=0.0783, linearity=0.372]


8 -> 32 (3)


Embedding: 100%|██████████████████████████████| 200/200 [00:28<00:00,  6.96it/s, isometricity=0.108, linearity=0.768] 


8 -> 32 (4)


Embedding: 100%|██████████████████████████████| 200/200 [00:28<00:00,  6.96it/s, isometricity=0.0845, linearity=0.384]


8 -> 128 (0)


Embedding: 100%|██████████████████████████████| 400/400 [01:12<00:00,  5.53it/s, isometricity=0.0692, linearity=0.466]


8 -> 128 (1)


Embedding: 100%|██████████████████████████████| 400/400 [01:12<00:00,  5.49it/s, isometricity=0.0687, linearity=0.48] 


8 -> 128 (2)


Embedding: 100%|██████████████████████████████| 400/400 [01:12<00:00,  5.55it/s, isometricity=0.0768, linearity=0.459]


8 -> 128 (3)


Embedding: 100%|██████████████████████████████| 400/400 [01:14<00:00,  5.40it/s, isometricity=0.0693, linearity=0.492]


8 -> 128 (4)


Embedding: 100%|██████████████████████████████| 400/400 [01:12<00:00,  5.48it/s, isometricity=0.0687, linearity=0.431]


8 -> 512 (0)


Embedding: 100%|██████████████████████████████| 800/800 [03:57<00:00,  3.37it/s, isometricity=0.035, linearity=0.186] 


8 -> 512 (1)


Embedding: 100%|██████████████████████████████| 800/800 [03:54<00:00,  3.41it/s, isometricity=0.0352, linearity=0.194]


8 -> 512 (2)


Embedding: 100%|██████████████████████████████| 800/800 [03:52<00:00,  3.43it/s, isometricity=0.0358, linearity=0.183]


8 -> 512 (3)


Embedding: 100%|██████████████████████████████| 800/800 [03:53<00:00,  3.43it/s, isometricity=0.0327, linearity=0.19] 


8 -> 512 (4)


Embedding: 100%|██████████████████████████████| 800/800 [03:56<00:00,  3.39it/s, isometricity=0.0321, linearity=0.19] 


8 -> 2048 (0)


Embedding: 100%|██████████████████████████████| 1600/1600 [16:26<00:00,  1.62it/s, isometricity=0.0224, linearity=0.135]


8 -> 2048 (1)


Embedding: 100%|██████████████████████████████| 1600/1600 [16:10<00:00,  1.65it/s, isometricity=0.0235, linearity=0.134]


8 -> 2048 (2)


Embedding: 100%|██████████████████████████████| 1600/1600 [16:14<00:00,  1.64it/s, isometricity=0.0237, linearity=0.133]


8 -> 2048 (3)


Embedding: 100%|██████████████████████████████| 1600/1600 [16:15<00:00,  1.64it/s, isometricity=0.0226, linearity=0.128]


8 -> 2048 (4)


Embedding: 100%|██████████████████████████████| 1600/1600 [16:17<00:00,  1.64it/s, isometricity=0.0255, linearity=0.144]


16 -> 32 (0)


Embedding: 100%|██████████████████████████████| 100/100 [00:14<00:00,  6.75it/s, isometricity=0.0826, linearity=0.593]


16 -> 32 (1)


Embedding: 100%|██████████████████████████████| 100/100 [00:14<00:00,  6.78it/s, isometricity=0.076, linearity=0.624]


16 -> 32 (2)


Embedding: 100%|██████████████████████████████| 100/100 [00:14<00:00,  6.84it/s, isometricity=0.0691, linearity=0.61]


16 -> 32 (3)


Embedding: 100%|██████████████████████████████| 100/100 [00:15<00:00,  6.60it/s, isometricity=0.0692, linearity=0.588]


16 -> 32 (4)


Embedding: 100%|██████████████████████████████| 100/100 [00:14<00:00,  6.70it/s, isometricity=0.072, linearity=0.625]


16 -> 64 (0)


Embedding: 100%|██████████████████████████████| 200/200 [00:31<00:00,  6.41it/s, isometricity=0.112, linearity=0.661]


16 -> 64 (1)


Embedding: 100%|██████████████████████████████| 200/200 [00:30<00:00,  6.47it/s, isometricity=0.117, linearity=0.717]


16 -> 64 (2)


Embedding: 100%|██████████████████████████████| 200/200 [00:31<00:00,  6.45it/s, isometricity=0.115, linearity=0.667] 


16 -> 64 (3)


Embedding: 100%|██████████████████████████████| 200/200 [00:30<00:00,  6.47it/s, isometricity=0.124, linearity=0.701]


16 -> 64 (4)


Embedding: 100%|██████████████████████████████| 200/200 [00:30<00:00,  6.58it/s, isometricity=0.112, linearity=0.714] 


16 -> 256 (0)


Embedding: 100%|██████████████████████████████| 400/400 [01:22<00:00,  4.84it/s, isometricity=0.0449, linearity=0.26] 


16 -> 256 (1)


Embedding: 100%|██████████████████████████████| 400/400 [01:21<00:00,  4.90it/s, isometricity=0.048, linearity=0.268] 


16 -> 256 (2)


Embedding: 100%|██████████████████████████████| 400/400 [01:21<00:00,  4.88it/s, isometricity=0.0511, linearity=0.263]


16 -> 256 (3)


Embedding: 100%|██████████████████████████████| 400/400 [01:22<00:00,  4.87it/s, isometricity=0.0472, linearity=0.259]


16 -> 256 (4)


Embedding: 100%|██████████████████████████████| 400/400 [01:22<00:00,  4.86it/s, isometricity=0.0488, linearity=0.26] 


16 -> 1024 (0)


Embedding: 100%|██████████████████████████████| 800/800 [05:01<00:00,  2.66it/s, isometricity=0.0339, linearity=0.198]


16 -> 1024 (1)


Embedding: 100%|██████████████████████████████| 800/800 [05:03<00:00,  2.64it/s, isometricity=0.0343, linearity=0.198]


16 -> 1024 (2)


Embedding: 100%|██████████████████████████████| 800/800 [05:01<00:00,  2.65it/s, isometricity=0.0318, linearity=0.2]  


16 -> 1024 (3)


Embedding: 100%|██████████████████████████████| 800/800 [04:57<00:00,  2.69it/s, isometricity=0.031, linearity=0.201] 


16 -> 1024 (4)


Embedding: 100%|██████████████████████████████| 800/800 [04:58<00:00,  2.68it/s, isometricity=0.0319, linearity=0.201]


16 -> 4096 (0)


Embedding: 100%|██████████████████████████████| 1600/1600 [26:52<00:00,  1.01s/it, isometricity=0.0249, linearity=0.173]


16 -> 4096 (1)


Embedding: 100%|██████████████████████████████| 1600/1600 [26:31<00:00,  1.01it/s, isometricity=0.0212, linearity=0.174]


16 -> 4096 (2)


Embedding: 100%|██████████████████████████████| 1600/1600 [26:39<00:00,  1.00it/s, isometricity=0.0288, linearity=0.178]


16 -> 4096 (3)


Embedding: 100%|██████████████████████████████| 1600/1600 [26:53<00:00,  1.01s/it, isometricity=0.03, linearity=0.18]   


16 -> 4096 (4)


Embedding: 100%|██████████████████████████████| 1600/1600 [26:46<00:00,  1.00s/it, isometricity=0.0265, linearity=0.173]


32 -> 64 (0)


Embedding: 100%|██████████████████████████████| 100/100 [00:15<00:00,  6.54it/s, isometricity=0.0705, linearity=1.11]


32 -> 64 (1)


Embedding: 100%|██████████████████████████████| 100/100 [00:15<00:00,  6.59it/s, isometricity=0.0725, linearity=1.08]


32 -> 64 (2)


Embedding: 100%|██████████████████████████████| 100/100 [00:15<00:00,  6.31it/s, isometricity=0.0745, linearity=1.13]


32 -> 64 (3)


Embedding: 100%|██████████████████████████████| 100/100 [00:15<00:00,  6.57it/s, isometricity=0.0587, linearity=1.13]


32 -> 64 (4)


Embedding: 100%|██████████████████████████████| 100/100 [00:15<00:00,  6.55it/s, isometricity=0.068, linearity=1.11]


32 -> 128 (0)


Embedding: 100%|██████████████████████████████| 200/200 [00:36<00:00,  5.53it/s, isometricity=0.0861, linearity=0.704]


32 -> 128 (1)


Embedding: 100%|██████████████████████████████| 200/200 [00:36<00:00,  5.53it/s, isometricity=0.085, linearity=0.731] 


32 -> 128 (2)


Embedding: 100%|██████████████████████████████| 200/200 [00:35<00:00,  5.63it/s, isometricity=0.0849, linearity=0.728]


32 -> 128 (3)


Embedding: 100%|██████████████████████████████| 200/200 [00:35<00:00,  5.69it/s, isometricity=0.0966, linearity=0.745]


32 -> 128 (4)


Embedding: 100%|██████████████████████████████| 200/200 [00:35<00:00,  5.63it/s, isometricity=0.0852, linearity=0.714]


32 -> 512 (0)


Embedding: 100%|██████████████████████████████| 400/400 [01:53<00:00,  3.52it/s, isometricity=0.0536, linearity=0.365]


32 -> 512 (1)


Embedding: 100%|██████████████████████████████| 400/400 [01:54<00:00,  3.50it/s, isometricity=0.0516, linearity=0.361]


32 -> 512 (2)


Embedding: 100%|██████████████████████████████| 400/400 [01:51<00:00,  3.60it/s, isometricity=0.0543, linearity=0.359]


32 -> 512 (3)


Embedding: 100%|██████████████████████████████| 400/400 [01:50<00:00,  3.62it/s, isometricity=0.0499, linearity=0.36] 


32 -> 512 (4)


Embedding: 100%|██████████████████████████████| 400/400 [01:50<00:00,  3.61it/s, isometricity=0.0572, linearity=0.368]


32 -> 2048 (0)


Embedding: 100%|██████████████████████████████| 800/800 [07:01<00:00,  1.90it/s, isometricity=0.0432, linearity=0.252]


32 -> 2048 (1)


Embedding: 100%|██████████████████████████████| 800/800 [07:03<00:00,  1.89it/s, isometricity=0.0437, linearity=0.259]


32 -> 2048 (2)


Embedding: 100%|██████████████████████████████| 800/800 [07:13<00:00,  1.85it/s, isometricity=0.0441, linearity=0.261]


32 -> 2048 (3)


Embedding: 100%|██████████████████████████████| 800/800 [06:59<00:00,  1.91it/s, isometricity=0.0398, linearity=0.254]


32 -> 2048 (4)


Embedding: 100%|██████████████████████████████| 800/800 [06:59<00:00,  1.91it/s, isometricity=0.0431, linearity=0.257]


32 -> 8192 (0)


Embedding: 100%|██████████████████████████████| 1600/1600 [49:54<00:00,  1.87s/it, isometricity=0.0361, linearity=0.224]


32 -> 8192 (1)


Embedding: 100%|██████████████████████████████| 1600/1600 [51:11<00:00,  1.92s/it, isometricity=0.0347, linearity=0.229]


32 -> 8192 (2)


Embedding: 100%|██████████████████████████████| 1600/1600 [51:02<00:00,  1.91s/it, isometricity=0.037, linearity=0.228] 


32 -> 8192 (3)


Embedding: 100%|██████████████████████████████| 1600/1600 [50:19<00:00,  1.89s/it, isometricity=0.0378, linearity=0.231]


32 -> 8192 (4)


Embedding:  54%|████████████████▎             | 868/1600 [24:37<20:46,  1.70s/it, isometricity=0.0394, linearity=0.264]


KeyboardInterrupt: 