In [3]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.model_selection import train_test_split
import math
from scipy.stats import wasserstein_distance

In [4]:
print(f'CUDA available: {torch.cuda.is_available()}')
print(f'CUDA current: {torch.cuda.current_device()}')

CUDA available: False


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [None]:
class dataGenerator():
    def __init__(self, batchsize, bins):
        self.batchsize=batchsize
        self.bins=bins
    
    def datagen(self):
        histograms = np.zeros((self.batchsize, 30), dtype=np.int64)
        means = np.zeros(self.batchsize)
        sigmas = np.zeros(self.batchsize)
        total_counts = np.zeros(self.batchsize)
        for i in range(self.batchsize):
            nsamples = np.random.randint(0,10000)
            mu = np.random.uniform(0.2,0.4)
            sigma = np.random.uniform(0.1, 0.3)

            samples = np.random.normal(mu, sigma, nsamples)

            hist,_ = np.histogram(samples, bins=self.bins, range=(0,1))
            histograms[i,:] = hist
            means[i] = mu
            sigmas[i] = sigma
            total_counts[i] = hist.sum()

        return histograms, means, sigmas, total_counts

In [None]:
data = dataGenerator(2500000,30)

In [None]:
hist, means, sigmas, counts = data.datagen()

In [None]:
means = np.array(means)
sigmas = np.array(sigmas)
counts = np.array(counts)
conds = np.column_stack([means, sigmas, counts])

In [None]:
hist_train, hist_test, conds_train, conds_test = train_test_split(hist,conds, test_size=0.2, random_state=42)

hist_test = torch.tensor(hist_test, dtype=torch.float32)
hist_test = hist_test/hist_test.max().item()
hist_train = torch.tensor(hist_train, dtype=torch.float32)
hist_train = hist_train/hist_train.max().item()
conds_test = torch.tensor(conds_test, dtype=torch.float32)
conds_train = torch.tensor(conds_train, dtype=torch.float32)

In [None]:
class Encoder(torch.nn.Module):
    def __init__(self, cond_dim: int, hidden_dim: int, input_dim=30):
        super().__init__()
        self.input_dim=input_dim
        self.cond_dim=cond_dim
        self.total_dim=self.input_dim + self.cond_dim
        self.Mu = torch.nn.Sequential(
            torch.nn.Linear(self.total_dim,hidden_dim),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim),

            torch.nn.Linear(hidden_dim,hidden_dim * 2),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim * 2),

            torch.nn.Linear(hidden_dim * 2, hidden_dim * 4),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim * 4),

            torch.nn.Linear(hidden_dim * 4, hidden_dim * 2),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim * 2),

            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim),

            torch.nn.Linear(hidden_dim, 32),
            torch.nn.Softplus()
        )

        self.Sigma = torch.nn.Sequential(
            torch.nn.Linear(self.total_dim,hidden_dim),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim),

            torch.nn.Linear(hidden_dim,hidden_dim * 2),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim * 2),

            torch.nn.Linear(hidden_dim * 2, hidden_dim * 4),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim * 4),

            torch.nn.Linear(hidden_dim * 4, hidden_dim * 2),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim * 2),

            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.LeakyReLU(0.2),
            torch.nn.BatchNorm1d(hidden_dim),

            torch.nn.Linear(hidden_dim, 32),
            torch.nn.Softplus()
        )


    def forward(self, input_vec, cond_vec):

        x = torch.cat([input_vec, cond_vec], dim=1)

        mu = self.Mu(x)
        sigma = self.Sigma(x)

        return mu, sigma  