In [7]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import yaml
import pprint
import os
import time
# weights and biases for tracking of metrics
import wandb 
# make the plots inline again
%matplotlib inline
# sometimes have to activate this to plot plots in notebook
# matplotlib.use('Qt5Agg')
from code import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
mu_list = np.array([[-4.7503373e-01, -8.7996745e-01, -5.0922018e-04],
                    [-1.6167518e-01,  6.5595394e-01, -7.3728257e-01],
                    [ 2.6248896e-01,  6.9851363e-01,  6.6571641e-01],
                    [ 1.0,  0,  0.0]], dtype='float32')

k_list = 1.5*np.array([13, 14,  12, 15],dtype='float32')

mu_list = torch.tensor(mu_list)
k_list = torch.tensor(k_list)

In [9]:
def get_power_spherical_samples(mu_list, k_list, nr_samples):
    """
    Args:
        mu_list np.array
        k_list np.array
        nr_samples scalar
    Returns:
        out torch.tensor 
    """
    nr_mixtures = len(mu_list)
    mixt_components = np.random.randint(low=0, high=nr_mixtures,size = nr_samples)

    # then count how often every mixture components occurs in sample
    # mix_comp_counter is dict with mix comps as keys and nr of samplings as values
    mix_comp_counter = Counter(mixt_components)
    
    print(mix_comp_counter)

    data = torch.tensor([])

    # the sample for each mixture component, as many samples as they occured in the sampling of the components

    for mix_comp in mix_comp_counter:

        dist = PowerSpherical(loc=mu_list[mix_comp].clone().detach().float(), 
                              scale=k_list[mix_comp].clone().detach().float())

        sample_per_comp = dist.sample((mix_comp_counter[mix_comp],))

        data = torch.cat([data, sample_per_comp], dim=0)

    # shuffle tensor
    return data[torch.randperm(nr_samples),:]

In [10]:
import time
start_time = time.time()
for _ in range(int(1e1)):
    data = get_power_spherical_samples(mu_list, k_list, 1)
print(f'time {time.time() - start_time}')

Counter({1: 1})
Counter({2: 1})
Counter({1: 1})
Counter({0: 1})
Counter({3: 1})
Counter({0: 1})
Counter({2: 1})
Counter({2: 1})
Counter({3: 1})
Counter({2: 1})
time 0.017911672592163086


In [11]:
start_time = time.time()
data = get_power_spherical_samples(mu_list, k_list, int(1e3))
print(f'time {time.time() - start_time}')

Counter({0: 281, 3: 264, 1: 244, 2: 211})
time 0.004645347595214844


In [12]:
power_spherical_data = PowerSphericalData(mu_list=mu_list, k_list=k_list, nr_samples=int(1e3))

train_loader = DataLoader(power_spherical_data, batch_size=256, shuffle=True)

for train_set in train_loader:
    print(train_set.shape)
    break

Secs for entropy calc 0.7115859985351562
torch.Size([256, 3])
