In [None]:
##########################
#We recreate the MDN implementation described by Bishop C. [1994]
#########################

In [None]:
#import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
plt.rcParams["figure.figsize"]=15,15

In [None]:
#define the toy model inverse problem
def inv_model(t):
    #create a an array of uniform random variables
    eps = np.random.uniform(-0.1,0.1,t.shape)
    return t + 0.3*np.sin(2*np.pi*t) + eps
    
#we create a dataset
X_train = np.random.uniform(0,1,(4000))
Y_train = torch.Tensor(inv_model(X_train)).unsqueeze(1)
X_train = torch.Tensor(X_train).unsqueeze(1)

X_val = np.random.uniform(0,1,(4000))
Y_val = torch.Tensor(inv_model(X_val)).unsqueeze(1)
X_val = torch.Tensor(X_val).unsqueeze(1)

X_test = np.random.uniform(0,1,(4000))
Y_test = torch.Tensor(inv_model(X_test)).unsqueeze(1)
X_test = torch.Tensor(X_test).unsqueeze(1)

In [None]:
#plot the toy model
#predict on test set
X_test_sorted,_=torch.sort(X_test,0)
Y_test_predict = net(X_test_sorted).detach().numpy()
x = np.linspace(0,1,5000)
y = inv_model(x)
plt.scatter(x,y , color='grey', marker='o',s=5,facecolors='none', label="true function")
plt.plot(X_test_sorted,Y_test_predict,  "black", label="NN")
plt.legend(loc="lower right")
plt.xlabel("x")
plt.ylabel("y")
plt.title("True vs. Approximated function")
plt.show()

In [None]:
#######################
#### First we use a regular neural network to fit the inverse function
#######################

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.regressor = nn.Sequential(
            nn.Linear(1, 100),
            nn.ReLU(),
            nn.Linear(100, 1)
        )
        
    def forward(self, x):
        x = self.regressor(x)
        return x

net = Net()
params = list(net.parameters())

criterion = nn.MSELoss(reduction="mean")
optimizer = optim.Adam(net.parameters(), lr=0.001)

loss_v=[]
loss_t=[]

traindata = TensorDataset( X_train,Y_train )
dataloader = DataLoader(traindata, batch_size= 500, shuffle=True)
    
for epoch in range(5000):  # loop over the dataset multiple times
    net.train()
    ######NOTE THAT WE SWAP THE X AND Y HERE #########
    for batch_idx, (Y, X) in enumerate(dataloader):
        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = net(X)
        loss = criterion(input=outputs, target=Y)
        loss.backward()
        optimizer.step()
        loss_t.append(loss.item()) 
        loss_v.append(criterion(net(Y_val), X_val).item())
    if epoch%100 ==0:
        print(f"epoch = {epoch} train loss = {loss_t[-1]} validation loss = {loss_v[-1]}")
        
print('Finished Training')

plt.plot(loss_t, "-b", label="train")
plt.plot(loss_v , "-r", label="validation")
plt.legend(loc="upper right")
plt.xlabel("epochs")
plt.ylabel("loss")
plt.title("Train vs validation loss")

plt.show()

In [None]:
#plot the toy model
#predict on test set
Y_test_sorted,_=torch.sort(Y_test,0)
X_test_predict = net(Y_test_sorted).detach().numpy()

x = np.linspace(0,1,1000)
y = inv_model(x)
plt.scatter(y,x , color='grey', marker='o',s=5,facecolors='none', label="true function")
plt.plot(Y_test_sorted,X_test_predict,  "black", label="NN")
plt.legend(loc="lower right")
plt.xlabel("x")
plt.ylabel("y")
plt.title("True vs. Approximated function - On Input and Target Variables Interchanged")
plt.show()

In [None]:
#######################
#### Mixture Density Network to fit the inverse function
#######################

def MDN_Output_Layer(x):
    #we convert the final layer of the network into the MDN outputs
    with torch.no_grad():
        n=int(x.shape[1]/3)
    #first third - softmax layer for the alpha weightings
    alphas=torch.nn.functional.softmax(x[:,0:n],1)
    #second third - exponential layer for the variance
    variance =torch.exp(x[:,n:(2*n)])
    #final third - unchanged, for the means
    means = x[:,-n:]
    x_adjusted = torch.cat((alphas,variance,means),1)
    return x_adjusted

def MDN_loss(x,y):
    #we compute the MDN loss using a custom loss function
    with torch.no_grad():
        n=int(x.shape[1]/3)
    #first third - alpha weightings
    alphas = x[:,0:n]
    #second third - variance
    variances = x[:,n:(2*n)]
    #final third - means
    means = x[:,-n:]
    #compute the loss
    gaussians = torch.exp(-((y.repeat(1,n) - means)**2)/(2*variances**2))/(variances * (2*3.1415927410125732)**(0.5))
    #return the loss
    loss = -torch.sum(torch.log(torch.sum(gaussians*alphas,1)))/x.shape[0]
    return loss

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.regressor = nn.Sequential(
            nn.Linear(1, 50),
            nn.Tanh(),
            nn.Linear(50, 34*3)
        )
        
    def forward(self, x):
        x = self.regressor(x)
        #apply the custom output layer
        return MDN_Output_Layer(x)

net = Net()
params = list(net.parameters())

#We apply the MDN loss function
criterion = MDN_loss
optimizer = optim.Adam(net.parameters(), lr=0.001)

loss_v=[]
loss_t=[]
    
traindata = TensorDataset( X_train,Y_train )
dataloader = DataLoader(traindata, batch_size= 4000, shuffle=True)
    
for epoch in range(6000):  # loop over the dataset multiple times
    net.train()
    ######NOTE THAT WE SWAP THE X AND Y HERE #########
    for batch_idx, (Y, X) in enumerate(dataloader):
        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = net(X)
        loss = criterion(outputs, Y)
        loss.backward()
        optimizer.step()
        loss_t.append(loss.item()) 
    
        #evaluate the loss on validation set    
        loss_v.append(criterion(net(Y_val), X_val).item())
    if epoch%100 ==0:
        print(f"epoch = {epoch} train loss = {loss_t[-1]} validation loss = {loss_v[-1]}")
        
print('Finished Training')

plt.plot(loss_t, "-b", label="train")
plt.plot(loss_v , "-r", label="validation")
plt.legend(loc="upper right")
plt.xlabel("epochs")
plt.ylabel("loss")
plt.title("Train vs validation loss")

plt.show()

In [None]:
def random_choice_prob_index(a, axis=1):
    r = np.expand_dims(np.random.rand(a.shape[1-axis]), axis=axis)
    return (a.cumsum(axis=axis) > r).argmax(axis=axis)

def MDN_predict(x):
    #we make a random sample from the distibution conditional on x
    n=int(x.shape[1]/3)
    alphas = x[:,0:n]
    #second third - variance
    variances = x[:,n:(2*n)]
    #final third - means
    means = x[:,-n:]
    idx = random_choice_prob_index(alphas)[:,None]
    selections = np.random.normal(np.take_along_axis(means,idx,axis=1),np.take_along_axis(variances,idx,axis=1))
    ###################
    #Optional code to plot the mode fit
    #idx = np.argmax(alphas,1)[:,None]
    #selections = np.take_along_axis(means,np.argmax(alphas,1)[:,None],axis=1)
    ###################
    return selections

#plot the toy model
X_test_predict_MDN = net(Y_test).detach().numpy()
X_test_predict = MDN_predict(X_test_predict_MDN)
plt.scatter(Y_test,X_test , color='red', marker='o',s=5,facecolors='none', label="True Function")
plt.scatter(Y_test,X_test_predict, color='blue', marker='o',s=5,facecolors='none', label="MDN")
plt.legend(loc="lower right")
plt.xlabel("x")
plt.ylabel("y")
plt.title("MDN True vs. Approximated function - On Input and Target Variables Interchanged - " + version)
plt.show()    

In [None]:
#need to compute the cdf by randomly picking alphas and then corresponding gaussians.
def mdn_pdf(t,x):
    n=int(x.shape[1]/3)
    alphas = x[:,0:n]
    #second third - variance
    variances = x[:,n:(2*n)]
    #final third - means
    means = x[:,-n:]
    norm_pdfs = scipy.stats.norm.pdf(t,means, variances)
    return np.sum(alphas*norm_pdfs,1)

def mdn_sim_cdf(x):
    #simulate a cdf for the the MDN
    res=1000
    cdf = None
    for i in range(res):
        this_pdf = (mdn_pdf((i/(res-1))*np.ones((p.shape[0],1)),x)).reshape(-1,1)        
        if (cdf is None):
            cdf=this_pdf
        else:
            cdf=np.concatenate((cdf,this_pdf),axis=1)
    #now we have a bunch of pdfs, create a cdf
    cdf = np.cumsum(cdf,1)
    cdf = cdf/cdf[:,-1].reshape(-1,1)
    return cdf
    
#compute the scores
# the score functions
#average negative log predictive density (NLPD) (Good, 1952)
def NLPD(x_obs,y_obs,p):
    #compute the pdfs for each y|x
    prob = mdn_pdf(x_obs.view(-1,1).numpy(),p).reshape(-1,1)
    #set the minimum to 1/n, taking n=4000 to avoid infinite predictions
    prob = (prob<1/p.shape[0])*1/p.shape[0]+prob*(prob>=1/p.shape[0])
    return -np.sum(np.log(prob))/prob.shape[0]

#MAE to the median of the distribution 
def MAE(x_obs,y_obs,cdf):
    res=cdf.shape[1]
    medians = cdf*(cdf<=0.5)
    medians = (np.argmax(medians,1)/(res-1)).reshape(-1,1)
    return np.sum(np.abs(medians-x_obs.view(-1,1).numpy()))/medians.shape[0]

#Continuous Ranked Probability Score or the CRPS. The CRPS (Gneiting & Raftery, 2004)
def CRPS(x_obs,y_obs,cdf_in):
    #resolution 
    res = cdf_in.shape[1]
    #we get the cdf values at each point
    cdf = None
    for i in range(res):
        this_cdf = cdf_in[:,i].reshape(-1,1)
        if (cdf is None):
            cdf=this_cdf
        else:
            cdf=np.concatenate((cdf,this_cdf),axis=1)
    cdf_lhs = cdf*cdf*(cdf<=x_obs.view(-1,1).numpy())
    cdf_rhs = (cdf-1)*(cdf-1)*(cdf>=x_obs.view(-1,1).numpy())
    
    #now compute the approx area using trapezium rule
    cdf_lhs = np.sum(cdf_lhs[:,:-1] +cdf_lhs[:,1:],1)/(res*2)
    cdf_rhs = np.sum(cdf_rhs[:,:-1] +cdf_rhs[:,1:],1)/(res*2)
    crps = np.sum(cdf_lhs+cdf_rhs)/p.shape[0]
    return crps


In [None]:
p = net(Y_test).detach().numpy()
mdn_cdf = mdn_sim_cdf(p)
print(f"NLPD = {NLPD(X_test,Y_test,p)}")
print(f"MAE = {MAE(X_test,Y_test,mdn_cdf)}")
print(f"CRPS = {CRPS(X_test,Y_test,mdn_cdf)}")

Y_density = torch.linspace(0.0, 1.0,1000).view(-1,1)
X_density = torch.linspace(0.0, 1.0,1000).view(-1,1)
X_density_MDN = net(Y_density).detach()
X, Y = torch.meshgrid(X_density.view(-1), Y_density.view(-1))
Z = torch.zeros(X.shape)
for i in range(1000):
    Z[:,i] = torch.from_numpy(mdn_pdf(X[i,:].view(-1,1).numpy(),X_density_MDN.numpy()))
    Z[:,i] = Z[:,i]*(Z[:,i]>0)
    Z[:,i] = Z[:,i]/torch.sum(Z[:,i])
plt.contourf(X, Y, Z, 5, cmap='RdGy')
cbar = plt.colorbar()
cbar.ax.set_ylabel('Probability Density', rotation=90)
plt.xlabel("x")
plt.ylabel("y")
plt.title("MDN - " + version)
plt.show()
