In [1]:
import random
import time
import math
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import KDTree
from scipy.stats import wasserstein_distance

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader

In [3]:
x = torch.rand(5,2)

In [4]:
y = torch.rand(5,2)

In [7]:
p = nn.PairwiseDistance(p=2)

In [9]:
torch.norm(x[0]-y[0],p=2)

tensor(0.2827)

In [8]:
p(x,y)

tensor([0.2827, 0.4016, 0.9251, 0.3960, 0.7157])

In [2]:
# Device that later cells pass to `.to(device)`; currently pinned to the CPU.
# To use a GPU when one is present, switch to the line below. Note the call
# parentheses on `is_available()`: the bare function object is always truthy.
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device('cpu')

In [3]:
# (mean, std) of each 1-D Gaussian family. Every family contributes 50 sets,
# appended family by family in the order listed here.
gaussian_families = [
    (0.3, 0.5),
    (0.0, 1.0),
    (-0.5, 0.75),
    (1.0, 1.0),
    (-0.3, 0.1),
    (0.7, 0.2),
]

set_dist = []
for mean, std in gaussian_families:
    family = torch.distributions.normal.Normal(torch.tensor([mean]), torch.tensor([std]))
    for _ in range(50):
        # One set = 100 draws, returned with shape (100, 1).
        set_dist.append(family.sample([100]))
    


In [4]:
set_dist = torch.stack(set_dist)

In [5]:
set_dist.shape

torch.Size([300, 100, 1])

In [27]:
class Set2Set(nn.Module):
    """Attention-based set pooling driven by an LSTM (Set2Set style).

    At every step the LSTM emits a query `q`, the set elements are weighted by
    a softmax over their dot product with `q`, and the weighted sum `r` is
    concatenated with `q` to form the next LSTM input `q_star`.
    """

    def __init__(self, input_dim, hidden_dim, act_fn=nn.Tanh, num_layers=1):
        '''
        Args:
            input_dim: feature dimension d of each set element.
            hidden_dim: dimension of the set representation q_star = [q, r],
                which is also the INPUT dimension of the LSTM.
                The LSTM's hidden size is `input_dim` (q must be dotted with
                the embeddings) while its initial state is sized
                `hidden_dim - input_dim`, so `forward` only runs when
                hidden_dim == 2 * input_dim.
            act_fn: activation class instantiated and applied to the output.
            num_layers: number of stacked LSTM layers.
        '''
        super(Set2Set, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        if hidden_dim <= input_dim:
            print('ERROR: Set2Set output_dim should be larger than input_dim')
        # the hidden is a concatenation of weighted sum of embedding and LSTM output
        self.lstm_output_dim = hidden_dim - input_dim
        self.lstm = nn.LSTM(hidden_dim, input_dim, num_layers=num_layers, batch_first=True)

        # Project the final set representation down to 4 output features.
        self.pred = nn.Linear(hidden_dim, 4)
        self.act = act_fn()

    def forward(self, embedding):
        '''
        Args:
            embedding: [batch_size x n x d] embedding matrix
        Returns:
            out: [batch_size x 4] activated projection of the pooled set
                representation.
        '''
        batch_size = embedding.size()[0]
        n = embedding.size()[1]

        # Allocate the recurrent state next to the input instead of forcing
        # CUDA, so the module works on whatever device the data lives on.
        state_shape = (self.num_layers, batch_size, self.lstm_output_dim)
        hidden = (embedding.new_zeros(state_shape),
                  embedding.new_zeros(state_shape))

        q_star = embedding.new_zeros(batch_size, 1, self.hidden_dim)
        # One processing step per set element.
        for _ in range(n):
            # q: batch_size x 1 x input_dim
            q, hidden = self.lstm(q_star, hidden)
            # e: batch_size x n x 1 (dot product of every element with q)
            e = embedding @ torch.transpose(q, 1, 2)
            # a: attention weights over the n elements
            a = torch.softmax(e, dim=1)
            # r: batch_size x 1 x input_dim attention-weighted sum
            r = torch.sum(a * embedding, dim=1, keepdim=True)
            q_star = torch.cat((q, r), dim=2)
        q_star = torch.squeeze(q_star, dim=1)
        out = self.act(self.pred(q_star))

        return out

In [3]:
class DeepSet(nn.Module):
    """Deep Sets model: per-element MLP, sum pooling over the set axis, then an MLP head.

    Args:
        in_features: feature size of each set element.
        set_features: size of the pooled set representation.

    The head always produces 2 output features per set.
    """

    def __init__(self, in_features, set_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = set_features

        # Applied independently to every element of the set.
        element_layers = [
            nn.Linear(in_features, 50),
            nn.ELU(inplace=True),
            nn.Linear(50, 100),
            nn.ELU(inplace=True),
            nn.Linear(100, set_features),
        ]
        self.feature_extractor = nn.Sequential(*element_layers)

        # Applied once to the pooled representation.
        head_layers = [
            nn.Linear(set_features, 30),
            nn.ELU(inplace=True),
            nn.Linear(30, 30),
            nn.ELU(inplace=True),
            nn.Linear(30, 10),
            nn.ELU(inplace=True),
            nn.Linear(10, 2),
        ]
        self.regressor = nn.Sequential(*head_layers)

    def forward(self, input):
        """Encode `input`; dim 1 is the set axis that gets summed away."""
        per_element = self.feature_extractor(input)
        pooled = per_element.sum(dim=1)
        return self.regressor(pooled)


In [None]:
class Encoder(nn.Module):
    """Set encoder: multi-head attention, feed-forward block, mean pooling.

    Args:
        dim_Q: feature size of the query input.
        dim_K: feature size of the key input (values are projected from the
            same input, so this is also the value input size).
        dim_V: stored on the module but not used by any layer.
        d_model: width of the projected queries/keys/values and of the output.
        num_heads: number of attention heads; `d_model // num_heads` is the
            chunk width used to split the projections.
        ln: when True, apply LayerNorm after attention and after the
            feed-forward block.
        skip: when True, add the projected queries to the attention output.
    """
    def __init__(self, dim_Q, dim_K, dim_V, d_model, num_heads, ln=False, skip=True):
        super(Encoder, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.skip = skip
        self.d_model = d_model
        self.fc_q = nn.Linear(dim_Q, d_model)
        self.fc_k = nn.Linear(dim_K, d_model)
        self.fc_v = nn.Linear(dim_K, d_model)
        if ln:
            self.ln0 = nn.LayerNorm(d_model)
            self.ln1 = nn.LayerNorm(d_model)
        # Classic pointwise feed-forward block from "Attention Is All You Need".
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model))

    def forward(self, Q, K):
        """Attend from `Q` to `K` and return the mean over dim 1 of the result.

        The split/concat logic below indexes dim 2 as the feature axis and
        dim 0 as the batch axis, so both inputs are expected to be 3-D.
        """
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Fold the heads into the batch axis: chunk the feature axis and stack
        # the chunks along dim 0.
        dim_split = self.d_model // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # Scaled dot-product attention (scaled by sqrt(d_model)).
        A = torch.softmax(Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(self.d_model), dim=-1)
        attended = A.bmm(V_)

        # Optional residual connection, then unfold the heads back onto the
        # feature axis. This is computed once; the original built the no-skip
        # variant first and then immediately overwrote it.
        heads = Q_ + attended if getattr(self, 'skip', True) else attended
        O = torch.cat(heads.split(Q.size(0), 0), 2)

        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + self.ff(O)
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        # Pool over dim 1 to obtain one vector per batch entry.
        O = torch.mean(O, dim=1)

        return O

In [None]:
class SelfAttention(nn.Module):
    """Self-attention wrapper: runs `Encoder` with the same input as query and key.

    Args:
        dim_in: feature size of the input, used for the query, key and value sizes.
        dim_out: `d_model` of the wrapped encoder.
        num_heads, ln, skip: forwarded unchanged to `Encoder`.
    """

    def __init__(self, dim_in=18, dim_out=8, num_heads=2, ln=True, skip=True):
        super().__init__()
        self.Encoder = Encoder(
            dim_Q=dim_in,
            dim_K=dim_in,
            dim_V=dim_in,
            d_model=dim_out,
            num_heads=num_heads,
            ln=ln,
            skip=skip,
        )

    def forward(self, X):
        """Encode `X` by attending from `X` to itself."""
        return self.Encoder(Q=X, K=X)


In [None]:
# Small constant that naive_estimator adds inside its logarithm so that a zero
# distance ratio does not evaluate log(0).
eps = 1e-15
"""Approximating KL divergences between two probability densities using samples. 
    It is buggy. Use at your own peril
"""

def knn_distance(point, sample, k):
    """Return the Euclidean distance from `point` to its rank-`k` closest row of `sample`.

    `k` is a zero-based rank into the sorted distances: k=0 is the closest row.
    """
    offsets = sample - point
    ranked = np.sort(np.linalg.norm(offsets, axis=1))
    return ranked[k]


def verify_sample_shapes(s1, s2, k):
    """Assert that both samples are 2-D [N, D] arrays with the same D.

    `k` is accepted but not inspected.
    """
    rank1, rank2 = len(s1.shape), len(s2.shape)
    # Each sample must be a matrix of shape [N, D].
    assert rank1 == 2 and rank2 == 2
    # The two samples must live in the same D-dimensional space.
    dim1, dim2 = s1.shape[1], s2.shape[1]
    assert dim1 == dim2


def naive_estimator(s1, s2, k=1):
    """Estimate D(P|Q) from samples with a brute-force (numpy) k-NN search.

    s1: (N_1, D) sample drawn from distribution P
    s2: (N_2, D) sample drawn from distribution Q
    k: number of neighbours considered (default 1)
    return: estimated D(P|Q)
    """
    verify_sample_shapes(s1, s2, k)

    n_p, n_q = len(s1), len(s2)
    dim = float(s1.shape[1])
    estimate = np.log(n_q / (n_p - 1))
    weight = dim / n_p

    for query in s1:
        # `query` is not a member of s2, so its k-th neighbour there has rank k-1.
        nu = knn_distance(query, s2, k - 1)
        # Within s1 the query matches itself at rank 0, so rank k is the k-th neighbour.
        rho = knn_distance(query, s1, k)
        estimate += weight * np.log((nu / rho) + eps)
    return estimate


def scipy_estimator(s1, s2, k=1):
    """Estimate D(P|Q) from samples using scipy's KDTree for the k-NN search.

    s1: (N_1, D) sample drawn from distribution P
    s2: (N_2, D) sample drawn from distribution Q
    k: number of neighbours considered (default 1)
    return: estimated D(P|Q)
    """
    verify_sample_shapes(s1, s2, k)

    n, m = len(s1), len(s2)
    d = float(s1.shape[1])

    nu_d, _ = KDTree(s2).query(s1, k)
    # k + 1 neighbours within s1, because every point finds itself first.
    rho_d, _ = KDTree(s1).query(s1, k + 1)
    rho_k = rho_d[:, -1]

    # KDTree.query returns a 1-D array for k == 1 and an (N_1, k) array for k > 1.
    nu_k = nu_d[:, -1] if k > 1 else nu_d

    return np.log(m / (n - 1)) + (d / n) * np.sum(np.log(nu_k / rho_k))


def skl_estimator(s1, s2, k=1):
    """ KL-Divergence estimator using scikit-learn's NearestNeighbors
        s1: (N_1,D) Sample drawn from distribution P
        s2: (N_2,D) Sample drawn from distribution Q
        k: Number of neighbours considered (default 1)
        return: estimated D(P|Q)
    """
    verify_sample_shapes(s1, s2, k)

    n, m = len(s1), len(s2)
    d = float(s1.shape[1])
    D = np.log(m / (n - 1))

    # Constructor arguments are passed by keyword: current scikit-learn
    # releases reject positional arguments here. `radius=10` reproduces the
    # second positional value of the original call.
    s1_neighbourhood = NearestNeighbors(n_neighbors=k+1, radius=10).fit(s1)
    s2_neighbourhood = NearestNeighbors(n_neighbors=k, radius=10).fit(s2)

    for p1 in s1:
        # k+1 neighbours within s1 because p1 finds itself first.
        s1_distances, _ = s1_neighbourhood.kneighbors([p1], k+1)
        s2_distances, _ = s2_neighbourhood.kneighbors([p1], k)
        # The last column holds the farthest of the requested neighbours.
        rho = s1_distances[0][-1]
        nu = s2_distances[0][-1]
        D += (d/n)*np.log(nu/rho)
    return D


# List of all estimators
Estimators = [naive_estimator, scipy_estimator, skl_estimator]

In [7]:
class SinkhornDistance(nn.Module):
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.
    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Any
            other value (including ``None``) behaves like 'none'.
            Default: 'none'
        thresh (float, optional): the iterations stop early once the mean
            absolute change of the dual variable ``u`` drops below this value.
            Default: 1e-1
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`
    """
    def __init__(self, eps, max_iter, reduction='none', thresh=1e-1):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.thresh = thresh

    def forward(self, x, y):
        """Return the Sinkhorn cost between point clouds `x` and `y`."""
        C = self._cost_matrix(x, y)  # Wasserstein cost function
        x_points = x.shape[-2]
        y_points = y.shape[-2]

        # Both marginals are fixed with equal weights. They are created on the
        # inputs' device (no dependence on a notebook-level global) and shaped
        # from C's leading dimensions, so a cloud with a single point or a
        # batch of size one keeps the layout the updates below rely on.
        batch_shape = C.shape[:-2]
        mu = torch.full((*batch_shape, x_points), 1.0 / x_points,
                        dtype=torch.float, device=C.device)
        nu = torch.full((*batch_shape, y_points), 1.0 / y_points,
                        dtype=torch.float, device=C.device)

        # Dual potentials, updated in the log domain.
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        # Sinkhorn iterations
        for _ in range(self.max_iter):
            u1 = u  # previous value, used to measure the update size
            u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()

            if err.item() < self.thresh:
                break

        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, u, v))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates:
        :math:`M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon`."""
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1

In [8]:
sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction=None).to(device)

In [9]:
class MyDataset(Dataset):
    """Dataset over the first axis of `data`, with an optional per-item transform.

    Args:
        data: container exposing `.float()`; the float-converted result is
            stored and indexed.
        transform: optional callable applied to every item that is fetched.
    """

    def __init__(self, data, transform=None):
        self.data = data.float()
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        return self.transform(item) if self.transform else item
    

In [10]:
dataset = MyDataset(set_dist)
loader = DataLoader(dataset, batch_size = 12, shuffle = True)

In [4]:
# Build the DeepSet encoder (1 feature per set element, 36 pooled features)
# and its optimizer, then restore previously saved weights.
model = DeepSet(1, 36).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# map_location keeps the load working when the checkpoint was written from a
# different device than the one selected in `device`.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load('normal_encode_1D_600.pkl', map_location=device))
model.eval()

DeepSet(
  (feature_extractor): Sequential(
    (0): Linear(in_features=1, out_features=50, bias=True)
    (1): ELU(alpha=1.0, inplace=True)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): ELU(alpha=1.0, inplace=True)
    (4): Linear(in_features=100, out_features=36, bias=True)
  )
  (regressor): Sequential(
    (0): Linear(in_features=36, out_features=30, bias=True)
    (1): ELU(alpha=1.0, inplace=True)
    (2): Linear(in_features=30, out_features=30, bias=True)
    (3): ELU(alpha=1.0, inplace=True)
    (4): Linear(in_features=30, out_features=10, bias=True)
    (5): ELU(alpha=1.0, inplace=True)
    (6): Linear(in_features=10, out_features=2, bias=True)
  )
)

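The Wasserstein distance has the following properties:
1) W(aX, aY) = |a| W(X, Y)
2) W(X+x, Y+x) = W(X, Y)
3) W_2^2(X+x, Y) = ||x + E(X) - E(Y)||^2 + W_2^2(X, Y)  (the norm is squared; the identity is exact when the W_2^2 term on the right-hand side is taken between the centered variables X - E(X) and Y - E(Y))
The next step is to implement these properties as terms of the training loss.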

In [17]:
num_epochs = 500
# One Python float per epoch: the loss of that epoch's last processed batch.
# Floats (not tensors) are stored so no autograd graph is kept alive.
running_loss = []
for t in range(num_epochs):
    last_batch_loss = None
    for n_batch, batch in enumerate(loader):
        # Inputs do not need gradients; only the model parameters are optimized.
        n_data = batch.to(device)
        n_sets = len(n_data)
        if n_sets < 2:
            # A lone set forms no pairs, so the pairwise loss (and its
            # normalisation by the pair count) is undefined.
            continue

        # Random scale `a` and shift `b` used by the scaling/translation terms.
        a = torch.rand(1).to(device)
        b = torch.rand(1).to(device)

        optimizer.zero_grad()
        y = model(n_data)
        y_a = model(a*n_data)
        y_translate = model(n_data + b)

        loss = 0

        for i in range(n_sets):
            for j in range(i+1, n_sets):

                y_ij = torch.norm(y[i]-y[j], p=2)
                # The Sinkhorn target and the mean offset depend only on the
                # data, so they are computed without tracking gradients.
                with torch.no_grad():
                    w_ij = sinkhorn(n_data[i], n_data[j])
                    mean_offset_ij = torch.mean(n_data[i] - n_data[j])

                ya_ij = torch.norm(y_a[i]-y_a[j], p=2)
                y_translate_ij = torch.norm(y_translate[i]-y_translate[j], p=2)

                diff_translate_ij = torch.norm(y_translate[i]-y[j], p=2)**2

                # Terms, in order: match the Sinkhorn cost, scaling property,
                # translation invariance, one-sided translation property.
                loss += (torch.norm(y_ij-w_ij, p=2)
                         + (ya_ij-a*y_ij)**2
                         + (y_translate_ij - y_ij)**2
                         + (diff_translate_ij - (b + mean_offset_ij)**2 - y_ij**2)**2)
        #TODO FIX THE LAST TERMS WITH PAIRWISE DISTANCES (SEE PYTORCH CODE)

        # Average over the number of unordered pairs in the batch.
        loss = loss/(n_sets*(n_sets-1)/2)

        loss.backward()

        optimizer.step()
        last_batch_loss = loss.item()

    if last_batch_loss is not None:
        running_loss.append(last_batch_loss)
   
   

In [33]:
#normal_encode_1D pkl is after 100 epochs. Do not do this again
torch.save(model.state_dict(),'normal_encode_1D_600.pkl')

In [32]:
torch.save({
            
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            
            }, 'normal_encode_1D_600epoch.pt')

In [7]:
from torch.distributions import normal

In [8]:
# Ten zero-mean Gaussians that differ only in their (small) standard deviation.
# The names are bound in the same order as the values: N1 -> 0.01, N2 -> 0.001,
# ..., N9 -> 0.008, N0 -> 0.009.
N1, N2, N3, N4, N5, N6, N7, N8, N9, N0 = (
    normal.Normal(torch.tensor([0.0]), torch.tensor([std]))
    for std in (0.01, 0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008, 0.009)
)

In [9]:
# Draw 500 points from each Gaussian and reshape every draw to (1, 500, 1),
# i.e. a batch holding one set of 500 scalar points. Sampling happens in the
# order N1, N2, ..., N9, N0.
n1, n2, n3, n4, n5, n6, n7, n8, n9, n0 = (
    dist.sample([500]).view(1, -1, 1)
    for dist in (N1, N2, N3, N4, N5, N6, N7, N8, N9, N0)
)

In [10]:
model(n1)

tensor([[-16.7333,   2.0709]], grad_fn=<AddmmBackward>)

In [11]:
model(n2)

tensor([[-16.7347,   2.0731]], grad_fn=<AddmmBackward>)

In [12]:
model(n3)

tensor([[-16.7349,   2.0725]], grad_fn=<AddmmBackward>)

In [13]:
model(n4)

tensor([[-16.7346,   2.0730]], grad_fn=<AddmmBackward>)

In [14]:
model(n5)

tensor([[-16.7348,   2.0720]], grad_fn=<AddmmBackward>)

In [15]:
model(n6)

tensor([[-16.7337,   2.0743]], grad_fn=<AddmmBackward>)

In [16]:
model(n7)

tensor([[-16.7346,   2.0716]], grad_fn=<AddmmBackward>)

In [17]:
model(n8)

tensor([[-16.7338,   2.0729]], grad_fn=<AddmmBackward>)

In [18]:
model(n9)

tensor([[-16.7335,   2.0732]], grad_fn=<AddmmBackward>)

In [19]:
model(n0)

tensor([[-16.7328,   2.0738]], grad_fn=<AddmmBackward>)

In [20]:
model(torch.zeros(1,500,1))

tensor([[-16.7348,   2.0729]], grad_fn=<AddmmBackward>)