In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
import sys
sys.path.append('gdrive/MyDrive/Colab Notebooks/normalizing_flow_gaia')

In [34]:
!pip install -q git+https://github.com/probabilists/zuko
!pip install -q astroML

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
import zuko

In [35]:
import matplotlib.pyplot as plt
from matplotlib import cm
import sklearn.datasets as datasets
import matplotlib.colors as colors
import torch
from torch import nn
from torch import optim
from torch.distributions.normal import Normal
from torch.optim.lr_scheduler import ExponentialLR
import numpy as np
import copy
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import os

import sklearn.preprocessing
import numpy as np
from matplotlib import pyplot as plt

from astroML.density_estimation import XDGMM

In [6]:
class donut_dist():
    def __init__(self,sigma=0.1):
        self.sigma = sigma
        pass

    def sample_true(self,sam_size):
        """
        Draws n samples from the donut distribution. Returns
        a tensor containing (r,theta)-coordinates, with shape (n,2).
        """
        r = np.random.normal(1.0, self.sigma, size=sam_size)
        theta = np.random.uniform(0., 2.*np.pi, size=sam_size)
        return np.stack([r,theta],axis=1)

    def p_true(self,y):
        """
        Returns the true p(y) at the given points y.
        """
        r = np.sqrt(y[:,0]**2+y[:,1]**2)
        return np.exp(-0.5*((r-1.)/self.sigma)**2) / (self.sigma*r*((2*np.pi)**(3/2)))

In [7]:
def coord_tran(r_theta):
    x = r_theta[:,0]*np.cos(r_theta[:,1])
    y = r_theta[:,0]*np.sin(r_theta[:,1])
    return np.stack([x,y],axis=1)

In [8]:
path = f'gdrive/MyDrive/Colab Notebooks/normalizing_flow_gaia/donut'
def plot_sam(q,title_name='',save = False):
    xlim = (-2, 2)
    fig,ax = plt.subplots(1,1, figsize=(6,6))
    ax.hexbin(q[:,0], q[:,1], extent=xlim+xlim, lw=0.5)
    ax.set_xlim(xlim)
    ax.set_ylim(xlim)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.title(title_name)
    if save:
        plt.savefig(os.path.join(path,title_name), format='png',bbox_inches = 'tight' , dpi = 300)

    plt.show()

In [9]:
def plot_p(potential,title='potential'):
    xline = np.linspace(-2, 2, num=101)
    yline = np.linspace(-2, 2,num=101)
    xgrid, ygrid = np.meshgrid(xline, yline)
    xyinput = np.hstack([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)])
    zgrid = potential(xyinput).reshape(101, 101)
    cs = plt.contourf(xgrid, ygrid, zgrid)
    cb = plt.colorbar(cs)
    plt.title(title + " - p")
    plt.show()
    levels = np.linspace(-10, 0,10+1)
    cs = plt.contourf(xgrid, ygrid, np.log(zgrid),
                    levels,extend='both'
                      )
    cb = plt.colorbar(cs)
    plt.title(title + " - lnp")
    plt.show()

In [10]:
n_samples = 1024*16
dist = donut_dist(sigma = 0.1)
data_true = dist.sample_true(n_samples)
data_true_xy = coord_tran(data_true)
data_test = dist.sample_true(int(0.1*n_samples))
data_test_xy = coord_tran(data_test)

In [11]:
plot_sam(data_true_xy)
plot_p(dist.p_true,'true_potential')

Output hidden; open in https://colab.research.google.com to view.

In [12]:
def add_noise(q_true,amp=1):
    q_noise = np.zeros_like(q_true)
    sigma = amp*q_true[:,1]/(2*np.pi)
    q_noise[:,0] = q_true[:,0] + np.random.normal(0,sigma)
    q_noise[:,1] = q_true[:,1]
    q_noise_xy = coord_tran(q_noise)
    dr = np.array([np.random.normal(r, s, 100) for r, s in zip(q_true[:,0],sigma)]).std(axis=1)
    dx = dr*abs(np.cos(q_true[:,1]))
    dy = dr*abs(np.sin(q_true[:,1]))
    return q_noise,q_noise_xy,np.vstack([dx,dy]).T

In [13]:
data_noise,data_noise_xy,error_xy = add_noise(data_true,amp=0.3)
plot_sam(data_noise_xy,'noise_sample',save = True)
plot_sam(data_true_xy,'true_sample',save = True)

Output hidden; open in https://colab.research.google.com to view.

In [14]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)

cuda:0


In [15]:
def train_flow(flow,optimizer #,scheduler
               ,n_samples,num_iter,
            sam_size,scalar,x,y,dx,dy,X_test
            ,plot = True):
    # num_iter = 500
    epoch_his = []
    loss_his = []
    test_loss_his = []
    best_epoch = -1
    best_loss = np.inf
    best_test_loss = np.inf
    best_model = copy.deepcopy(flow)
    norm_factor = scalar.scale_[0]*scalar.scale_[1]
    x_noise_dis_0 = Normal(torch.tensor(x,dtype=torch.float32),
                                                    torch.tensor(dx,dtype=torch.float32))

    x_noise_dis_1 = Normal(torch.tensor(y,dtype=torch.float32),
                                                    torch.tensor(dy,dtype=torch.float32))
    X_test_tensor = torch.tensor(scalar.transform(X_test), dtype=torch.float32).to(device)


    for i in tqdm(range(num_iter)):
        flow.train()
        optimizer.step()
        # scheduler.step()
        optimizer.zero_grad()
        x_noise_0 = x_noise_dis_0.sample(sample_shape= (sam_size,)).T.reshape((-1,1))
        x_noise_1 = x_noise_dis_1.sample(sample_shape= (sam_size,)).T.reshape((-1,1))
        x_noise = torch.tensor(scalar.transform(np.hstack([x_noise_0,x_noise_1])), dtype=torch.float32).to(device)
        loss = -flow().log_prob(x_noise).reshape(n_samples,sam_size).logsumexp(dim=1).mean() + torch.log(torch.tensor(sam_size)) + torch.log(torch.tensor(norm_factor))
        # print(f'epoch{i} loss = {loss}')
        with torch.no_grad():
            test_loss = -flow().log_prob(X_test_tensor).mean() + torch.log(torch.tensor(norm_factor))
            if loss < best_loss:
                best_loss = loss.item()
                best_epoch = i
                best_test_loss = test_loss.item()
                best_model = copy.deepcopy(flow.state_dict())
            loss_his.append(loss.item())
            test_loss_his.append(test_loss.item())
            epoch_his.append(i+1)



        loss.backward()


    print(f"\\ best epoch = {best_epoch}", flush=True)
    print(f"best loss = {best_loss}", flush=True)
    print(f"best test loss = {best_test_loss}", flush=True)


    flow.load_state_dict(best_model)
    flow.eval()
    with torch.no_grad():
        flow_sample = scalar.inverse_transform(flow().sample((n_samples,)).cpu())
    if plot:
        plot_sam(flow_sample,f'Normalizing Flow Resampling epoch{num_iter}')
        plt.show()

    xline = torch.linspace(-2., 2.,steps=101)
    yline = torch.linspace(-2., 2., steps=101)
    xgrid, ygrid = torch.meshgrid(xline, yline)
    xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)
    xyinput = torch.tensor(scalar.transform(xyinput), dtype=torch.float32).to(device)
    with torch.no_grad():
        zgrid = (flow().log_prob(xyinput).exp().reshape(101, 101).cpu()/norm_factor).numpy()
    if plot:
        cs = plt.contourf(xgrid.numpy(), ygrid.numpy(), zgrid
                        )
        cb = plt.colorbar(cs)
        plt.title('iteration {} - potential'.format(num_iter))
        plt.show()


        levels = np.linspace(-10, 0,10+1)
        cs = plt.contourf(xgrid.numpy(), ygrid.numpy(), np.log(zgrid),
                        levels,extend='both'
                        )
        cb = plt.colorbar(cs)
        plt.title('iteration {} - lnp'.format(num_iter))
        plt.show()


        plt.plot(epoch_his,loss_his, label = 'train_loss')
        plt.plot(epoch_his,test_loss_his, label = 'test_loss')
        plt.legend()
        plt.title(f'loss history')
        plt.xlabel(f'number of epoch')
        plt.ylabel(f'loss')
        plt.show()
    # plt.savefig(os.path.join(self._path,'his'), format='png', bbox_inches = 'tight' , dpi = 500)
    return flow_sample,zgrid,[epoch_his,loss_his,test_loss_his],[best_epoch,best_loss,best_test_loss]

In [16]:
scalar = sklearn.preprocessing.StandardScaler()
scalar.fit(data_noise_xy)

In [20]:
from zuko.flows import Flow, MaskedAutoregressiveTransform, UnconditionalDistribution
from zuko.distributions import DiagNormal
from typing import Sequence



class set_flow(zuko.flows.Flow):
    def __init__(self, features: int, transforms: int,hidden_features: Sequence[int] = (16, 16), context: int = 0, bins: int = 8):
        transforms = [
            zuko.flows.MaskedAutoregressiveTransform(
                features=features,
                context=context,
                order=torch.randperm(features),
                univariate=zuko.transforms.MonotonicRQSTransform,
                shapes=[(bins,), (bins,), (bins-1,)],
                hidden_features = hidden_features,
            )
            for i in range(transforms)
        ]

        base = zuko.flows.UnconditionalDistribution(
            zuko.distributions.DiagNormal,
            torch.zeros(features),
            torch.ones(features),
            buffer=True,
        )

        super().__init__(transforms, base)

In [21]:
path = f'gdrive/MyDrive/Colab Notebooks/normalizing_flow_gaia/donut'
def ensemble_plot(sam_size_group,sample_collection,his_collection, potential_collection,name='ensemble'):
    n_models = len(sam_size_group)+1
    xlim = (-2, 2)
    xline = np.linspace(-2, 2, num=101)
    yline = np.linspace(-2, 2,num=101)
    xgrid, ygrid = np.meshgrid(xline, yline)
    xyinput = np.hstack([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)])
    levels = np.linspace(-10, 0,10+1)

    fig, axs = plt.subplots(nrows=3, ncols=n_models ,figsize=(2*n_models,3*2))


    axs[0, 0].hexbin(data_true_xy[:,0], data_true_xy[:,1], extent=xlim+xlim, lw=0.5)
    axs[0, 0].set_xlim(xlim)
    axs[0, 0].set_ylim(xlim)
    axs[0, 0].set_title("true sample")
    axs[0, 0].axis("off")
    zgrid = dist.p_true(xyinput).reshape(101, 101)
    # cs = plt.contourf(xgrid, ygrid, zgrid)
    axs[1, 0].contourf(xgrid, ygrid, zgrid)
    axs[1, 0].set_title("true potential" + " - p")
    axs[1, 0].axis("off")
    axs[2, 0].contourf(xgrid, ygrid, np.log(zgrid),
                    levels,extend='both'
                      )
    axs[2, 0].set_title("true potential" + " - lnp")
    axs[2, 0].axis("off")
    for i in range(len(sam_size_group)):
        axs[0, i+1].hexbin(sample_collection[i][:,0], sample_collection[i][:,1], extent=xlim+xlim, lw=0.5)
        axs[0, i+1].set_xlim(xlim)
        axs[0, i+1].set_ylim(xlim)
        axs[0, i+1].set_title(f"sam_size{sam_size_group[i]} NF")
        axs[0, i+1].axis("off")
        zgrid = potential_collection[i]
        axs[1, i+1].contourf(xgrid, ygrid, zgrid)
        axs[1, i+1].set_title(f"sam_size{sam_size_group[i]} NF" + " - p")
        axs[1, i+1].axis("off")
        axs[2, i+1].contourf(xgrid, ygrid, np.log(zgrid),
                        levels,extend='both'
                        )
        axs[2, i+1].set_title(f"sam_size{sam_size_group[i]} NF" + " - lnp")
        axs[2, i+1].axis("off")
    fig.tight_layout()
    plt.savefig(os.path.join(path,'donut'), format='png',bbox_inches = 'tight' , dpi = 600)
    plt.show()

    his_collection = np.array(his_collection)
    plt.figure(1)

    for i in range(len(sam_size_group)):
        plt.plot(his_collection[i][0,100:],his_collection[i][1,100:], label = f'sam_size {sam_size_group[i]}')
    plt.legend()
    plt.title(f'train loss')
    plt.xlabel(f'number of epoch')
    plt.ylabel(f'loss')
    plt.savefig(os.path.join(path,'train_loss'), format='png',bbox_inches = 'tight' , dpi = 300)
    plt.show()

    plt.figure(2)

    for i in range(len(sam_size_group)):
        plt.plot(his_collection[i][0,100:],his_collection[i][2,100:], label = f'sam_size {sam_size_group[i]}')
    plt.legend()
    plt.title(f'test loss')
    plt.xlabel(f'number of epoch')
    plt.ylabel(f'loss')
    plt.savefig(os.path.join(path,'test_loss'), format='png',bbox_inches = 'tight' , dpi = 300)
    plt.show()


In [31]:
sam_size_group = [1,5,10,30,50]
def train_ensemble(sam_size_group):
    sample_collection = []
    potential_collection = []
    his_collection = []
    result_collection=[]
    for sam_size in sam_size_group:
        flow = set_flow(features= 2, transforms=3, bins=8).to(device)
        optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)
        print(f'flow with sam_size {sam_size}')
        flow_sample,flow_potential,his,best_result = train_flow(flow=flow, optimizer = optimizer ,#scheduler=scheduler,
                                                    num_iter = 500,n_samples = n_samples
                                                    ,sam_size=sam_size, scalar=scalar,
                                                    x = data_noise_xy[:,0],y = data_noise_xy[:,1],
                                                    dx = error_xy[:,0] ,dy = error_xy[:,1],
                                                    X_test = data_test_xy,
                                                    plot = False)
        sample_collection.append(flow_sample)
        potential_collection.append(flow_potential)
        his_collection.append(his)
        result_collection.append(best_result)
    return sam_size_group,sample_collection,his_collection, potential_collection

In [32]:
sam_size_group,sample_collection,his_collection, potential_collection= train_ensemble(sam_size_group)


flow with sam_size 1


100%|██████████| 500/500 [00:18<00:00, 26.91it/s]

\ best epoch = 454
best loss = 1.660248458457684
best test loss = 1.3429451585538144



  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


flow with sam_size 5


100%|██████████| 500/500 [00:55<00:00,  9.06it/s]

\ best epoch = 472
best loss = 1.5440055728680844
best test loss = 1.1681258083111996





flow with sam_size 10


100%|██████████| 500/500 [01:50<00:00,  4.54it/s]

\ best epoch = 488
best loss = 1.5210728050000424
best test loss = 1.1171254516370053





flow with sam_size 30


100%|██████████| 500/500 [05:29<00:00,  1.52it/s]

\ best epoch = 489
best loss = 1.505116164756512
best test loss = 1.087812721801495





flow with sam_size 50


100%|██████████| 500/500 [09:09<00:00,  1.10s/it]

\ best epoch = 496
best loss = 1.509578406882977
best test loss = 1.0746729970700497





In [37]:
ensemble_plot(sam_size_group,sample_collection,his_collection, potential_collection,name='ensemble')

Output hidden; open in https://colab.research.google.com to view.

In [36]:
X = data_noise_xy
Xerr = np.zeros(X.shape + X.shape[-1:])
diag = np.arange(X.shape[-1])
Xerr[:, diag, diag] = np.vstack([error_xy[:,0] ** 2, error_xy[:,1] ** 2]).T
clf = XDGMM(n_components=10, max_iter=200)

clf.fit(X, Xerr)
XDsample = clf.sample(n_samples)
plot_sam(XDsample,'Extreme Deconvolution Resampling')

Xerr_test = np.zeros(data_test_xy.shape + data_test_xy.shape[-1:])
log_likelihood = np.log(np.sum(np.exp(clf.logprob_a(X=data_test_xy, Xerr=Xerr_test)),axis=-1)).mean()
print(f'log_likelihood for test set is {log_likelihood}')

xline = np.linspace(-2, 2, num=101)
yline = np.linspace(-2, 2,num=101)
xgrid, ygrid = np.meshgrid(xline, yline)
xyinput = np.hstack([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)])
xyinput_err = np.zeros(xyinput.shape + xyinput.shape[-1:])

zgrid = np.sum(np.exp(clf.logprob_a(X=xyinput, Xerr=xyinput_err)),axis=-1).reshape(101, 101)
cs = plt.contourf(xgrid, ygrid, zgrid)
cb = plt.colorbar(cs)
plt.title('XDGMM Distribution')
plt.show()

Output hidden; open in https://colab.research.google.com to view.

In [40]:
path = f'gdrive/MyDrive/Colab Notebooks/normalizing_flow_gaia/donut'
def ensemble_plot_xdgmm(sam_size_group,sample_collection,his_collection, potential_collection,name='ensemble'):
    n_models = len(sam_size_group)+2
    xlim = (-2, 2)
    xline = np.linspace(-2, 2, num=101)
    yline = np.linspace(-2, 2,num=101)
    xgrid, ygrid = np.meshgrid(xline, yline)
    xyinput = np.hstack([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)])
    levels = np.linspace(-10, 0,10+1)

    fig, axs = plt.subplots(nrows=3, ncols=n_models ,figsize=(2*n_models,3*2))


    axs[0, 0].hexbin(data_true_xy[:,0], data_true_xy[:,1], extent=xlim+xlim, lw=0.5)
    axs[0, 0].set_xlim(xlim)
    axs[0, 0].set_ylim(xlim)
    axs[0, 0].set_title("true sample")
    axs[0, 0].axis("off")
    zgrid = dist.p_true(xyinput).reshape(101, 101)
    # cs = plt.contourf(xgrid, ygrid, zgrid)
    axs[1, 0].contourf(xgrid, ygrid, zgrid)
    axs[1, 0].set_title("true potential" + " - p")
    axs[1, 0].axis("off")
    axs[2, 0].contourf(xgrid, ygrid, np.log(zgrid),
                    levels,extend='both'
                      )
    axs[2, 0].set_title("true potential" + " - lnp")
    axs[2, 0].axis("off")
    for i in range(len(sam_size_group)):
        axs[0, i+1].hexbin(sample_collection[i][:,0], sample_collection[i][:,1], extent=xlim+xlim, lw=0.5)
        axs[0, i+1].set_xlim(xlim)
        axs[0, i+1].set_ylim(xlim)
        axs[0, i+1].set_title(f"sam_size{sam_size_group[i]} NF")
        axs[0, i+1].axis("off")
        zgrid = potential_collection[i]
        axs[1, i+1].contourf(xgrid, ygrid, zgrid)
        axs[1, i+1].set_title(f"sam_size{sam_size_group[i]} NF" + " - p")
        axs[1, i+1].axis("off")
        axs[2, i+1].contourf(xgrid, ygrid, np.log(zgrid),
                        levels,extend='both'
                        )
        axs[2, i+1].set_title(f"sam_size{sam_size_group[i]} NF" + " - lnp")
        axs[2, i+1].axis("off")


    axs[0, n_models-1].hexbin(XDsample[:,0], XDsample[:,1], extent=xlim+xlim, lw=0.5)
    axs[0, n_models-1].set_xlim(xlim)
    axs[0, n_models-1].set_ylim(xlim)
    axs[0, n_models-1].set_title("xdgmm sample")
    axs[0, n_models-1].axis("off")
    zgrid = dist.p_true(xyinput).reshape(101, 101)
    # cs = plt.contourf(xgrid, ygrid, zgrid)
    axs[1, n_models-1].contourf(xgrid, ygrid, xdgmm_zgrid)
    axs[1, n_models-1].set_title("xdgmm" + " - p")
    axs[1, n_models-1].axis("off")
    axs[2, n_models-1].contourf(xgrid, ygrid, np.log(xdgmm_zgrid),
                    levels,extend='both'
                      )
    axs[2, n_models-1].set_title("xdgmm" + " - lnp")
    axs[2, n_models-1].axis("off")
    fig.tight_layout()
    plt.savefig(os.path.join(path,'donut_xdgmm'), format='png',bbox_inches = 'tight' , dpi = 600)
    plt.show()

    his_collection = np.array(his_collection)
    plt.figure(1)

    for i in range(len(sam_size_group)):
        plt.plot(his_collection[i][0,100:],his_collection[i][1,100:], label = f'sam_size {sam_size_group[i]}')
    plt.legend()
    plt.title(f'train loss')
    plt.xlabel(f'number of epoch')
    plt.ylabel(f'loss')
    plt.savefig(os.path.join(path,'train_loss'), format='png',bbox_inches = 'tight' , dpi = 300)
    plt.show()

    plt.figure(2)

    for i in range(len(sam_size_group)):
        plt.plot(his_collection[i][0,100:],his_collection[i][2,100:], label = f'sam_size {sam_size_group[i]}')
    plt.legend()
    plt.title(f'test loss')
    plt.xlabel(f'number of epoch')
    plt.ylabel(f'loss')
    plt.savefig(os.path.join(path,'test_loss'), format='png',bbox_inches = 'tight' , dpi = 300)
    plt.show()


In [41]:
ensemble_plot_xdgmm(sam_size_group,sample_collection,his_collection, potential_collection,name='ensemble')

Output hidden; open in https://colab.research.google.com to view.