In [None]:
# Mount Google Drive so the project code, data and checkpoints are reachable
# from the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

import os
import torch
# Work from the project code directory so that local modules (e.g. data_utils,
# imported in a later cell) resolve. NOTE(review): absolute Drive path — adjust
# per user.
os.chdir('/content/drive/MyDrive/Softmax_sampling/code')


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

Mounted at /content/drive


device(type='cpu')

# Utils

In [None]:
import os
import numpy as np
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm.notebook import trange, tqdm
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy

import argparse
import time
import math
import torch
import torch.nn as nn
from data_utils import *
#import model
import easydict
import torch.nn.functional as F
import torch.optim as optim
#import Adam

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
def gram_schmidt_columns(X):
    '''
    Orthonormalise the columns of a matrix through a QR factorisation.

    Parameters
    ----------
    X : array_like, 2-D
        Matrix whose columns should be orthonormalised.

    Returns
    -------
    Q : ndarray
        The ``Q`` factor of ``np.linalg.qr(X)``; the triangular factor is
        discarded.
    '''
    # Index 0 of the QR result is the orthonormal factor.
    return np.linalg.qr(X)[0]

def orthgonalize(V):
    '''
    Orthogonalise the rows of a matrix block by block.

    The rows of ``V`` are split into consecutive blocks of at most ``d`` rows
    (``d`` = number of columns). Each block is orthonormalised independently
    with ``gram_schmidt_columns`` applied to its transpose, so rows inside one
    block are mutually orthogonal while rows of different blocks are not
    related.

    Parameters
    ----------
    V : ndarray, shape (N, d)
        Matrix whose rows are to be orthogonalised. ``N`` may be smaller
        than, equal to, or larger than ``d``.

    Returns
    -------
    V_ : ndarray, shape (N, d)
        Block-orthogonalised matrix. The dtype is whatever the QR routine
        produces, so integer input is no longer truncated back to integers
        (the previous ``np.zeros_like(V)`` buffer inherited the input dtype).
    '''
    N, d = V.shape[0], V.shape[1]
    turns = N // d            # number of complete d-row blocks
    full_rows = turns * d     # rows covered by the complete blocks

    blocks = [gram_schmidt_columns(V[start:start + d, :].T).T
              for start in range(0, full_rows, d)]

    # Trailing partial block (also the only block when N < d).
    if N % d != 0 or turns == 0:
        blocks.append(gram_schmidt_columns(V[full_rows:, :].T).T)

    return np.vstack(blocks)

def orthogonal_gau(dim_0, dim_1):
    '''
    Sample a (dim_0, dim_1) random feature matrix with block-orthogonal rows.

    A standard normal matrix is drawn from the global NumPy random state,
    its rows are orthogonalised with ``orthgonalize``, and every
    orthogonalised row is multiplied by the L2 norm of the corresponding
    row of the original Gaussian draw.
    '''
    gaussian = np.random.normal(0, 1, (dim_0, dim_1))
    row_norms = np.linalg.norm(gaussian, axis=1)
    return orthgonalize(gaussian) * row_norms[:, np.newaxis]

def trig_att(x, y, random_feats_sfm):
    '''
    Trigonometric (sin/cos) random-feature kernel estimate between x and y.

    Parameters
    ----------
    x : ndarray, shape (n_x, d)
    y : ndarray, shape (n_y, d)
    random_feats_sfm : ndarray, shape (m, d)
        Projection matrix; one random feature per row.

    Returns
    -------
    ndarray, shape (n_x, n_y)
        Inner products of the 2m-dimensional feature maps of x and y. Each
        feature map is ``sqrt(1/m) * exp(||v||^2 / 2) * [sin(Wv), cos(Wv)]``.
        With Gaussian rows in ``random_feats_sfm`` this is presumably meant
        as an estimate of ``exp(x . y)`` -- confirm against the caller.
    '''
    scale = np.sqrt(1/(random_feats_sfm.shape[0]))

    def _features(pts):
        proj = random_feats_sfm.dot(pts.T)                      # (m, n)
        envelope = np.exp(np.linalg.norm(pts, axis=1)**2/2)[:, np.newaxis]
        return scale * envelope * np.vstack((np.sin(proj), np.cos(proj))).T

    return np.dot(_features(x), _features(y).T)
    
def pos_att(x, y, random_feats_sfm):
    '''
    Positive (exponential) random-feature kernel estimate between x and y.

    Parameters
    ----------
    x : ndarray, shape (n_x, d)
    y : ndarray, shape (n_y, d)
    random_feats_sfm : ndarray, shape (m, d)
        Projection matrix; one random feature per row.

    Returns
    -------
    ndarray, shape (n_x, n_y)
        Inner products of the 2m-dimensional feature maps of x and y. Each
        feature map is
        ``sqrt(1/(2m)) * exp(-||v||^2 / 2) * [exp(Wv), exp(-Wv)]``, so every
        entry is non-negative by construction.
    '''
    scale = np.sqrt(1/(2*random_feats_sfm.shape[0]))

    def _features(pts):
        proj = random_feats_sfm.dot(pts.T)                      # (m, n)
        damping = np.exp(-np.linalg.norm(pts, axis=1)**2/2)[:, np.newaxis]
        return scale * damping * np.vstack((np.exp(proj), np.exp(-proj))).T

    return np.dot(_features(x), _features(y).T)

def ang_hyb_lambda(x, y, random_feats_lambda):
    '''
    Sign-feature ("angular") mixing coefficient between rows of x and y.

    Each point is mapped to ``[sqrt(1/2), i * sqrt(1/(2m)) * sign(Wv)]``.
    The un-conjugated dot product of two such maps has real part
    ``1/2 - (1/(2m)) * sum_k sign(w_k . x) * sign(w_k . y)``, which is what
    is returned.

    Parameters
    ----------
    x : ndarray, shape (n_x, d)
    y : ndarray, shape (n_y, d)
    random_feats_lambda : ndarray, shape (m, d)

    Returns
    -------
    ndarray, shape (n_x, n_y), real-valued
    '''
    coeff = 1j*np.sqrt(1/(2*random_feats_lambda.shape[0]))

    def _features(pts):
        bias = np.repeat(np.sqrt(1/2), pts.shape[0])[:, np.newaxis]
        signs = (coeff * np.sign(random_feats_lambda.dot(pts.T))).T
        return np.hstack((bias, signs))

    # Plain (non-conjugating) product: the imaginary units multiply to -1.
    return np.dot(_features(x), _features(y).T).real

def gau_hyb_lambda(x, y, random_feats_lambda, lambda_=1):
    '''
    Sin/cos random-feature mixing coefficient between rows of x and y.

    Each point is mapped to
    ``sqrt(1/m) * [sin(lambda_ * Wv), cos(lambda_ * Wv)]`` and the matrix of
    inner products of those maps is returned, i.e. the average over the m
    features of ``cos(lambda_ * w_k . (x - y))``.

    Parameters
    ----------
    x : ndarray, shape (n_x, d)
    y : ndarray, shape (n_y, d)
    random_feats_lambda : ndarray, shape (m, d)
    lambda_ : scalar, optional
        Multiplier applied to the projections before sin/cos.

    Returns
    -------
    ndarray, shape (n_x, n_y)
    '''
    norm = 1*np.sqrt(1/(random_feats_lambda.shape[0]))

    def _features(pts):
        proj = lambda_*random_feats_lambda.dot(pts.T)           # (m, n)
        return (norm * np.vstack((np.sin(proj), np.cos(proj)))).T

    return np.dot(_features(x), _features(y).T)

def ang_hyb_att(x, y, random_feats_sfm, random_feats_lambda):
    '''
    Hybrid kernel estimate mixed with the angular coefficient.

    Computes ``w * pos_att + (1 - w) * trig_att`` element-wise, where
    ``w = ang_hyb_lambda(x, y, random_feats_lambda)`` and both base
    estimators use ``random_feats_sfm``.

    Returns
    -------
    ndarray, shape (n_x, n_y)
    '''
    trig_est = trig_att(x, y, random_feats_sfm)
    pos_est = pos_att(x, y, random_feats_sfm)
    weight = ang_hyb_lambda(x, y, random_feats_lambda)

    return np.multiply(weight, pos_est) + np.multiply(1 - weight, trig_est)

def gau_hyb_att(x, y, random_feats_sfm, random_feats_lambda):
    '''
    Hybrid kernel estimate mixed with the sin/cos (Gaussian) coefficient.

    Computes ``(1 - w) * pos_att + w * trig_att`` element-wise, where
    ``w = gau_hyb_lambda(x, y, random_feats_lambda)`` (default ``lambda_``)
    and both base estimators use ``random_feats_sfm``.

    Returns
    -------
    ndarray, shape (n_x, n_y)
    '''
    trig_est = trig_att(x, y, random_feats_sfm)
    pos_est = pos_att(x, y, random_feats_sfm)
    weight = gau_hyb_lambda(x, y, random_feats_lambda)

    return np.multiply(1 - weight, pos_est) + np.multiply(weight, trig_est)

# Generate Embeddings

In [None]:
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    The decoder shares its weight matrix with the embedding layer (weight
    tying), as in:
      * "Using the Output Embedding to Improve Language Models"
        (Press & Wolf 2016), https://arxiv.org/abs/1608.05859
      * "Tying Word Vectors and Word Classifiers: A Loss Framework for
        Language Modeling" (Inan et al. 2016), https://arxiv.org/abs/1611.01462

    Args:
        ntoken: vocabulary size.
        ninp: embedding dimension.
        nhid: LSTM hidden size; must equal ``ninp`` because of weight tying.
        nlayers: number of stacked LSTM layers.
        dropout: dropout probability for the embedding/output dropout and
            between LSTM layers.

    Raises:
        ValueError: if ``nhid != ninp``.
    """

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.2):
        super().__init__()
        # The tied decoder weight has shape (ntoken, ninp) and is applied to
        # LSTM outputs of size nhid, so the two sizes must agree. Fail early
        # instead of at the first forward pass with an init_hidden() state.
        if nhid != ninp:
            raise ValueError('With tied weights, nhid must be equal to ninp')

        self.nhid = nhid
        self.nlayers = nlayers

        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)  # token ids -> embeddings
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken, bias=False)

        # Weight tying: decoder and encoder share one Parameter.
        self.decoder.weight = self.encoder.weight

        self.init_weights()

    def init_weights(self):
        """Uniformly initialise the (shared) embedding/decoder weight."""
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        # decoder.weight is the same Parameter as encoder.weight here; the
        # second draw is kept so seeded initialisation matches earlier runs.
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input, hidden, softmax_temp = 1):
        """Run the language model on a batch of token ids.

        Args:
            input: token ids passed to the embedding layer.
            hidden: LSTM state tuple, e.g. from ``init_hidden``.
            softmax_temp: temperature factor; its square root scales both
                the decoder input and the decoder output, so the logits are
                multiplied by ``softmax_temp`` overall (the decoder has no
                bias).

        Returns:
            (decoded, hidden): ``decoded`` has shape
            ``(output.size(0) * output.size(1), ntoken)``; ``hidden`` is the
            updated LSTM state.
        """
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # Flatten the first two axes so every position is one row.
        output = output.view(output.size(0)*output.size(1), output.size(2))

        # Plain Python float keeps the tensor arithmetic inside torch.
        scale = float(np.sqrt(softmax_temp))
        decoded = scale * self.decoder(scale * output)
        return decoded, hidden

    def init_hidden(self, bsz):
        """Return a zero (h_0, c_0) state for batch size ``bsz``.

        The tensors are created on the same device and with the same dtype
        as the model parameters.
        """
        weight = next(self.parameters())
        shape = (self.nlayers, bsz, self.nhid)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [None]:
# Batch sizes passed to batchify for the training and the validation/test
# streams respectively.
train_batch_size = 32
eval_batch_size = 10

# Corpus and batchify are not defined in this notebook; they presumably come
# from `from data_utils import *` -- confirm.
# NOTE(review): absolute Drive path -- adjust per user.
data_path = '/content/drive/MyDrive/Softmax_sampling/data/ptb'
corpus_raw = Corpus(data_path)

train_data = batchify(corpus_raw.train, train_batch_size) # size(total_len//bsz, bsz)
val_data = batchify(corpus_raw.valid, eval_batch_size)
test_data = batchify(corpus_raw.test, eval_batch_size)

In [None]:
# NOTE(review): `interval` is not referenced by any cell visible here.
interval = 200 # interval to report
# Vocabulary size; used as both embedding rows and decoder outputs below.
ntokens = len(corpus_raw.dictionary)


#model hyperparameters

# Used for both the embedding size and the LSTM hidden size (see the
# RNNModel call below).
hidden_size = 200

n_layers = 2
net = RNNModel(ntokens, hidden_size, hidden_size, n_layers, dropout=.2)
# Step between batch start rows; passed to get_model_embeddings later on.
bptt = 64

In [None]:
#Download from drive and put the right path
# The checkpoint is first mapped onto the CPU and the model is then moved to
# `device`.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
net.load_state_dict(torch.load('/content/drive/MyDrive/Softmax_sampling/models/ptb/lstm_ptb_tied_warmstart_try4.pkl', map_location=torch.device('cpu')))
net.to(device)

RNNModel(
  (drop): Dropout(p=0.2, inplace=False)
  (encoder): Embedding(10000, 200)
  (rnn): LSTM(200, 200, num_layers=2, dropout=0.2)
  (decoder): Linear(in_features=200, out_features=10000, bias=False)
)

In [None]:
def get_model_embeddings(data_source, net, bptt):
    """Collect the LSTM outputs ("model embeddings") for every position.

    Args:
        data_source: batchified token tensor; its second dimension is the
            batch size (it sizes the initial hidden state).
        net: trained weight-tied LSTM exposing ``encoder``, ``rnn`` and
            ``init_hidden``.
        bptt: step between successive batch start rows.
            NOTE(review): ``get_batch`` is called without it, so it must
            match the sequence length ``get_batch`` uses -- confirm.

    Returns:
        Tensor of shape (number of processed positions, LSTM output size):
        the per-batch LSTM outputs flattened to 2-D and concatenated.
    """
    net.eval()  # evaluation mode disables dropout
    model_out = []
    with torch.no_grad():
        # Size the hidden state from the data rather than from the global
        # eval_batch_size, so any batchified split can be passed in.
        hidden = net.init_hidden(data_source.size(1))
        for i in range(0, data_source.size(0) - 1, bptt):
            data, _ = get_batch(data_source, i)  # targets are not needed
            data = data.to(device)

            emb = net.encoder(data)
            # NOTE(review): if emb is (seq, batch, emb) as the shape comments
            # elsewhere in this notebook suggest, dim=1 normalises across
            # the batch axis; dim=-1 may have been intended. Left unchanged
            # so existing results stay reproducible -- confirm.
            emb = F.normalize(emb, p=2, dim=1)
            output, hidden = net.rnn(emb, hidden)
            model_out.append(output.reshape(-1, output.shape[-1]))

    return torch.cat(model_out, dim=0)

def get_class_embeddings(net):
    """Look up the embedding of every vocabulary entry.

    Args:
        net: trained weight-tied LSTM exposing ``encoder``.

    Returns:
        Tensor of shape (number of dictionary entries, embedding size), one
        row per index in ``corpus_raw.dictionary.word2idx``, in the
        dictionary's iteration order. The result is always 2-D: the earlier
        ``[...]`` + ``.squeeze()`` construction collapsed a one-word
        vocabulary to a 0-d index.
    """
    indices = list(corpus_raw.dictionary.word2idx.values())
    # Explicit long dtype keeps an empty vocabulary a valid index tensor.
    classes = torch.tensor(indices, dtype=torch.long)
    return net.encoder(classes.to(device))

def cross_entropy(X,y):
    """
    Mean negative log-likelihood of the labelled classes.

    X holds one row of class probabilities per example
    (num_examples x num_classes); its entries are passed straight to log.
    y holds the integer class label of each example (not one-hot).
    """
    m = y.shape[0]
    # Probability assigned to the correct class of every example.
    picked = X[np.arange(m), y]
    return np.sum(-np.log(picked)) / m

def evaluate(data_source, bptt=64):
    """Average the criterion loss of the global ``net`` over ``data_source``.

    Args:
        data_source: batchified token tensor; its second dimension is the
            batch size (it sizes the initial hidden state).
        bptt: step between successive batch start rows (default 64, the
            value that used to be hard-coded).
            NOTE(review): ``get_batch`` is called without it, so it must
            match the sequence length ``get_batch`` uses -- confirm.

    Returns:
        Sum over batches of ``len(data) * criterion(output, targets)``
        divided by ``len(data_source)``.
    """
    net.eval()  # evaluation mode disables dropout
    total_loss = 0
    with torch.no_grad():
        # Size the hidden state from the data rather than from the global
        # eval_batch_size, so any batchified split can be evaluated.
        hidden = net.init_hidden(data_source.size(1))
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            data, targets = data.to(device), targets.to(device)
            output, hidden = net(data, hidden)
            # Weight each batch loss by its number of time steps.
            total_loss += len(data) * criterion(output, targets).data
            hidden = repackage_hidden(hidden)
    return total_loss / len(data_source)

def cross_entropy(X,y):
    """
    Mean negative log-likelihood of the labelled classes.

    NOTE(review): this re-defines (and shadows) the identical
    `cross_entropy` declared earlier in the same cell; one of the two can
    be removed.

    X holds one row of class probabilities per example
    (num_examples x num_classes).
    y holds the integer class label of each example (not one-hot).
    """
    n_examples = y.shape[0]
    nll = -np.log(X[range(n_examples), y])
    return nll.sum() / n_examples

## Evaluate model

In [None]:
# `evaluate` reads `criterion` as a global, so it must be defined before the
# call. The displayed value is the loss returned for the test split.
criterion = nn.CrossEntropyLoss().to(device)
evaluate(test_data)

tensor(4.8644, device='cuda:0')

## Get embeddings

In [None]:
# LSTM outputs for every processed test position and the embedding of every
# vocabulary entry; every experiment cell below reads these two tensors.
model_embeddings = get_model_embeddings(test_data, net, bptt)
class_embeddings = get_class_embeddings(net)

In [None]:
softmax_temp = 1
# Exact softmax distribution over the vocabulary for every model embedding.
# torch.softmax subtracts the row maximum internally, so large logits no
# longer overflow exp() to inf and turn whole rows into NaN after the
# normalisation; for finite inputs the result equals exp(logits) / row sum.
true_sfm = torch.softmax(torch.matmul(softmax_temp*model_embeddings,
                                      class_embeddings.T),
                         dim=1).cpu().detach().numpy()

In [None]:
# Sanity check: one row per model embedding, one column per class embedding.
true_sfm.shape

(82420, 10000)

# Wasserstein

In [None]:
from scipy.stats import wasserstein_distance

In [None]:
# Random-feature counts swept by the experiment loops below.
all_rf_dim = [64, 128, 256, 512]
# NOTE(review): `num_samples` is not referenced by any cell visible here.
num_samples = 100
#softmax_temp = 4

### Favor+

In [None]:
softmax_temp = 1
runtimes = 10
inter = 4121  # rows of model_embeddings processed per pos_att call

# Loop-invariant inputs, converted to NumPy once instead of on every chunk.
model_emb_np = softmax_temp * model_embeddings.cpu().detach().numpy()
class_emb_np = class_embeddings.cpu().detach().numpy()
# Common support of both distributions: one point per class.
vocab_support = np.arange(true_sfm.shape[1])

for runtime in range(runtimes):

    all_wass_dist = []
    np.random.seed(runtime)  # one reproducible random-feature draw per run
    for rf_dim in all_rf_dim:

        orth = orthogonal_gau(rf_dim, model_emb_np.shape[1])

        # Approximate the un-normalised softmax chunk by chunk; stepping by
        # `inter` also covers a trailing partial chunk, which used to be
        # dropped when the row count was not a multiple of `inter`.
        sfm_app = np.concatenate(
            [pos_att(model_emb_np[start:start + inter], class_emb_np, orth)
             for start in range(0, model_emb_np.shape[0], inter)],
            axis=0)

        # Clip negative estimates, then normalise each row to sum to one.
        sfm_app[sfm_app < 0] = 0
        sfm_app /= np.sum(sfm_app, axis=1, keepdims=True)

        wass_dist = []
        for i in trange(sfm_app.shape[0]):
            wass_dist.append(wasserstein_distance(u_values=vocab_support,
                                                  v_values=vocab_support,
                                                  u_weights=true_sfm[i],
                                                  v_weights=sfm_app[i]))
        all_wass_dist.append(sum(wass_dist)/len(wass_dist))
    # One file per run: mean Wasserstein distance for each feature count.
    np.save(f'pos_att_result_{runtime}.npy', np.array(all_wass_dist))

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

In [None]:
# Averages from the final runtime only: the loop above re-creates
# `all_wass_dist` at the start of every runtime.
all_wass_dist

[3239.3008271261256,
 3130.7470326393714,
 2897.1225213358157,
 2929.7967400559037]

### RFF

In [None]:
softmax_temp = 1
runtimes = 10
inter = 4121  # rows of model_embeddings processed per trig_att call

# Loop-invariant inputs, converted to NumPy once instead of on every chunk.
model_emb_np = softmax_temp * model_embeddings.cpu().detach().numpy()
class_emb_np = class_embeddings.cpu().detach().numpy()
# Common support of both distributions: one point per class.
vocab_support = np.arange(true_sfm.shape[1])

for runtime in range(runtimes):

    all_wass_dist = []
    np.random.seed(runtime)  # one reproducible random-feature draw per run
    for rf_dim in all_rf_dim:

        orth = orthogonal_gau(rf_dim, model_emb_np.shape[1])

        # Approximate the un-normalised softmax chunk by chunk; stepping by
        # `inter` also covers a trailing partial chunk, which used to be
        # dropped when the row count was not a multiple of `inter`.
        sfm_app = np.concatenate(
            [trig_att(model_emb_np[start:start + inter], class_emb_np, orth)
             for start in range(0, model_emb_np.shape[0], inter)],
            axis=0)

        # Clip negative estimates, then normalise each row to sum to one.
        sfm_app[sfm_app < 0] = 0
        sfm_app /= np.sum(sfm_app, axis=1, keepdims=True)

        wass_dist = []
        for i in trange(sfm_app.shape[0]):
            wass_dist.append(wasserstein_distance(u_values=vocab_support,
                                                  v_values=vocab_support,
                                                  u_weights=true_sfm[i],
                                                  v_weights=sfm_app[i]))
        all_wass_dist.append(sum(wass_dist)/len(wass_dist))

    # One file per run: mean Wasserstein distance for each feature count.
    np.save(f'trig_att_result_{runtime}.npy', np.array(all_wass_dist))

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

In [None]:
# Averages from the final runtime only: the loop above re-creates
# `all_wass_dist` at the start of every runtime.
all_wass_dist

[1587.554145427954, 1551.1965454080964, 1584.3983024572085, 1560.385117444572]

### Angular

In [None]:
# Total random-feature budget per setting.
all_rf_dim = np.array([64, 128, 256, 512])
# Number of features used for the mixing coefficient in every setting.
rf_dim_lam = np.full_like(all_rf_dim, 8)
# Features left for each base estimator. Floor division keeps these as
# integer dimensions (true division produced floats that needed int() casts).
rf_dim_base = all_rf_dim // rf_dim_lam

In [None]:
# Display the per-setting feature counts for the mixing coefficient and the
# base estimators.
rf_dim_lam, rf_dim_base

(array([8, 8, 8, 8]), array([ 8., 16., 32., 64.]))

In [None]:
softmax_temp = 1
inter = 4121  # rows of model_embeddings processed per ang_hyb_att call
runtimes = 10

# Loop-invariant inputs, converted to NumPy once instead of on every chunk.
model_emb_np = softmax_temp * model_embeddings.cpu().detach().numpy()
class_emb_np = class_embeddings.cpu().detach().numpy()
# Common support of both distributions: one point per class.
vocab_support = np.arange(true_sfm.shape[1])

# Start at run 0 like the other experiments; the loop previously began at 1,
# so `ang_hyb_att_result_0.npy` was never produced by a clean run.
for runtime in range(runtimes):

    all_wass_dist = []
    np.random.seed(runtime)  # one reproducible random-feature draw per run

    for base_dim, lam_dim in zip(rf_dim_base, rf_dim_lam):

        # Draw order (base first, then lambda) is kept so seeded results
        # match earlier runs.
        orth_base = orthogonal_gau(int(base_dim), model_emb_np.shape[1])
        orth_lam = orthogonal_gau(int(lam_dim), model_emb_np.shape[1])

        # Approximate the un-normalised softmax chunk by chunk; stepping by
        # `inter` also covers a trailing partial chunk, which used to be
        # dropped when the row count was not a multiple of `inter`.
        sfm_app = np.concatenate(
            [ang_hyb_att(model_emb_np[start:start + inter], class_emb_np,
                         orth_base, orth_lam)
             for start in range(0, model_emb_np.shape[0], inter)],
            axis=0)

        # Clip negative estimates, then normalise each row to sum to one.
        sfm_app[sfm_app < 0] = 0
        sfm_app /= np.sum(sfm_app, axis=1, keepdims=True)

        wass_dist = []
        for i in trange(sfm_app.shape[0]):
            wass_dist.append(wasserstein_distance(u_values=vocab_support,
                                                  v_values=vocab_support,
                                                  u_weights=true_sfm[i],
                                                  v_weights=sfm_app[i]))
        all_wass_dist.append(sum(wass_dist)/len(wass_dist))
    print(all_wass_dist)
    # One file per run: mean Wasserstein distance for each setting.
    np.save(f'ang_hyb_att_result_{runtime}.npy', np.array(all_wass_dist))

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

[1218.1581643053407, 1371.8770340134681, 1407.420551954666, 1410.440249308614]


  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

[1469.0805682070782, 1342.688001581585, 1509.6228144945928, 1278.3311094172539]


  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

[1374.4500593604077, 1392.440752651851, 1370.29292618758, 1488.9770286983528]


  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

[1377.3243138049124, 1514.0228173671858, 1445.4151215549402, 1439.2200189183648]


  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

[1303.9887225256011, 1451.735559591694, 1401.6307843602794, 1505.8463310311026]


  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

[1448.9441023506868, 1487.352938250347, 1294.856484127074, 1463.4246680425647]


  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

[1387.0045933468543, 1485.208704571898, 1305.4415590976728, 1495.5910231451521]


  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

[1408.8640771762787, 1525.8213022597663, 1350.8427569388764, 1369.5612988374767]


  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

[1477.8632939273682, 1423.5682801152873, 1424.757234959597, 1416.6376660433368]


In [None]:
# NOTE(review): hard-coded literal, not computed by this cell. Presumably a
# pasted result from an earlier angular run (the loop above starts at
# runtime 1) -- confirm its origin or remove the cell.
[1484.6591093271031, 1457.6874629301096, 1439.419395421821, 1390.7935842748614]

[1484.6591093271031, 1457.6874629301096, 1439.419395421821, 1390.7935842748614]

### Gaussian

In [None]:
softmax_temp = 1
inter = 4121  # rows of model_embeddings processed per gau_hyb_att call
runtimes = 10

# Loop-invariant inputs, converted to NumPy once instead of on every chunk.
model_emb_np = softmax_temp * model_embeddings.cpu().detach().numpy()
class_emb_np = class_embeddings.cpu().detach().numpy()
# Common support of both distributions: one point per class.
vocab_support = np.arange(true_sfm.shape[1])

for runtime in range(runtimes):

    all_wass_dist = []
    np.random.seed(runtime)  # one reproducible random-feature draw per run
    for base_dim, lam_dim in zip(rf_dim_base, rf_dim_lam):

        # Draw order (base first, then lambda) is kept so seeded results
        # match earlier runs.
        orth_base = orthogonal_gau(int(base_dim), model_emb_np.shape[1])
        orth_lam = orthogonal_gau(int(lam_dim), model_emb_np.shape[1])

        # Approximate the un-normalised softmax chunk by chunk; stepping by
        # `inter` also covers a trailing partial chunk, which used to be
        # dropped when the row count was not a multiple of `inter`.
        sfm_app = np.concatenate(
            [gau_hyb_att(model_emb_np[start:start + inter], class_emb_np,
                         orth_base, orth_lam)
             for start in range(0, model_emb_np.shape[0], inter)],
            axis=0)

        # Clip negative estimates, then normalise each row to sum to one.
        sfm_app[sfm_app < 0] = 0
        sfm_app /= np.sum(sfm_app, axis=1, keepdims=True)

        wass_dist = []
        for i in trange(sfm_app.shape[0]):
            wass_dist.append(wasserstein_distance(u_values=vocab_support,
                                                  v_values=vocab_support,
                                                  u_weights=true_sfm[i],
                                                  v_weights=sfm_app[i]))
        all_wass_dist.append(sum(wass_dist)/len(wass_dist))

    # One file per run: mean Wasserstein distance for each setting.
    np.save(f'gau_hyb_att_Wassresult_{runtime}.npy', np.array(all_wass_dist))

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

  0%|          | 0/82420 [00:00<?, ?it/s]

In [None]:
# Averages from the final runtime only: the loop above re-creates
# `all_wass_dist` at the start of every runtime.
all_wass_dist

[1526.935236872064, 1509.4146823428707, 1515.5406360965894, 1515.3829860007952]