In [9]:
import jax
import jax.numpy as jnp 
import jax.scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import ipywidgets as widgets
import jax.scipy as scipy

# Fixed seed: the plotting function samples with this same key on every call,
# so each redraw for the same parameters shows exactly the same points.
key = jax.random.PRNGKey(seed=0)




In [10]:
def Ellipse(mean, cov, confidence=0.95, num_points=100):
  """
  Compute the boundary of a bivariate-normal confidence ellipse.

  The ellipse is the set of points x with
  (x - mean)^T cov^{-1} (x - mean) = c^2, where c^2 is the `confidence`
  quantile of the chi-square distribution with 2 degrees of freedom
  (Johnson & Wichern (2008), result (4-8)).

  Args:
    mean: length-2 array-like, centre of the ellipse.
    cov: 2x2 symmetric positive-definite covariance matrix.
    confidence: probability mass enclosed by the ellipse, strictly between
      0 and 1. Defaults to 0.95.
    num_points: number of points sampled along the boundary. Defaults to 100.

  Returns:
    (ellipseX, ellipseY): two 1-D arrays of length `num_points` with the x and
    y coordinates of the boundary (the first and last points coincide).

  Raises:
    ValueError: if `confidence` is not strictly between 0 and 1.
  """
  if not 0.0 < confidence < 1.0:
    raise ValueError(f"confidence must be in (0, 1), got {confidence}")

  mean = jnp.asarray(mean)
  cov = jnp.asarray(cov)

  # For a symmetric positive-definite matrix the SVD coincides with the
  # eigendecomposition: the columns of `eigvect` are eigenvectors and
  # `eigval` holds the eigenvalues in descending order.
  eigvect, eigval, _ = jnp.linalg.svd(cov)

  # The major axis is the eigenvector of the largest eigenvalue, i.e. the
  # FIRST COLUMN of eigvect: (eigvect[0, 0], eigvect[1, 0]). arctan2 gives its
  # angle independently of the SVD's sign convention and without dividing by
  # zero when the major axis is vertical.
  theta = jnp.arctan2(eigvect[1, 0], eigvect[0, 0])

  # For 2 degrees of freedom the chi-square CDF is 1 - exp(-x/2), so its
  # quantile has the closed form -2*ln(1 - p) (~5.991 for p = 0.95). This
  # avoids needing a chi2 ppf, which jax.scipy.stats does not provide.
  clevel = jnp.sqrt(-2.0 * jnp.log1p(-confidence))

  # Semi-axis half-lengths are c*sqrt(lambda_i).
  a = clevel * jnp.sqrt(eigval[0])
  b = clevel * jnp.sqrt(eigval[1])

  # Axis-aligned ellipse: semi-axis `a` along x, `b` along y.
  t = jnp.linspace(0, 2 * jnp.pi, num=num_points)
  x0 = a * jnp.sin(t)
  y0 = b * jnp.cos(t)

  # Rotate by theta so the `a` axis lies along the major eigenvector, then
  # shift to the mean.
  cos_th, sin_th = jnp.cos(theta), jnp.sin(theta)
  ellipseX = mean[0] + x0 * cos_th - y0 * sin_th
  ellipseY = mean[1] + x0 * sin_th + y0 * cos_th

  return ellipseX, ellipseY


In [11]:
def _density_histogram(values, bins=20):
  """Return (left_edges, bin_width, density, bin_centres) of an equal-width, density-normalised histogram."""
  counts, edges = jnp.histogram(values, bins=bins)
  width = edges[1] - edges[0]
  # Normalise so the bar areas sum to 1, making the bars comparable to a pdf.
  density = counts / (jnp.sum(counts) * width)
  left = edges[:-1]
  return left, width, density, left + width / 2


def plot_bivariate_graph(mean0,mean1,cov00,cov11,cov01):
  """
  Plot samples of a bivariate normal distribution with its 95% confidence
  ellipse and the marginal histograms of X and Y.

  Draws `no_of_samples` points from N(mean, cov), with mean = [mean0, mean1]
  and cov = [[cov00, cov01], [cov01, cov11]], using the module-level PRNG
  `key` (so the same parameters always give the same sample). The histograms
  of X and Y are drawn as density-normalised bars on the back walls of the
  3D axes, with a line through the bar centres approximating each marginal pdf.

  If the covariance matrix is not positive definite, a message is printed
  and nothing is plotted.
  """
  # A covariance matrix must be positive definite to be sampled; the sliders
  # can produce e.g. covXX = covYY = 0.6 with covXY = 1, which is not.
  if cov00 <= 0 or cov11 <= 0 or cov00 * cov11 - cov01 ** 2 <= 0:
    print("The covariance matrix [[covXX, covXY], [covXY, covYY]] must be "
          "positive definite (covXX * covYY > covXY^2); adjust the sliders.")
    return

  no_of_samples = 5000
  lim = 6  # axes span [-lim, lim]; the histograms sit on the walls at +/-lim

  # Generate samples with the given mean and covariance.
  mean = jnp.array([mean0, mean1])
  cov = jnp.array([[cov00, cov01], [cov01, cov11]])
  X = jax.random.multivariate_normal(key=key, mean=mean, cov=cov, shape=(no_of_samples,))

  fig = plt.figure(figsize=[12, 6])
  ax = fig.add_subplot(projection='3d')
  ax.scatter(X[:, 0], X[:, 1], s=0.5)

  # 95% confidence contour on the floor (z = 0).
  ellipseX, ellipseY = Ellipse(mean, cov)
  ax.plot3D(ellipseX, ellipseY, jnp.zeros(ellipseX.shape[0]), color='red')

  # bar3d anchors each bar at its minimum corner, so bars start at the left
  # bin edges and extend by one bin width; a zero thickness along the wall
  # normal keeps them flat against the wall.

  # Histogram of X on the wall y = +lim.
  left_x, width_x, dens_x, centres_x = _density_histogram(X[:, 0])
  n_x = dens_x.shape[0]
  wall_y = jnp.full(n_x, float(lim))
  ax.bar3d(left_x, wall_y, jnp.zeros(n_x),
           jnp.full(n_x, width_x), jnp.zeros(n_x), dens_x, color='gray')
  # Join the bar centres to get the approximate pdf curve.
  ax.plot3D(centres_x, wall_y, dens_x)

  # Histogram of Y on the wall x = -lim.
  left_y, width_y, dens_y, centres_y = _density_histogram(X[:, 1])
  n_y = dens_y.shape[0]
  wall_x = jnp.full(n_y, float(-lim))
  ax.bar3d(wall_x, left_y, jnp.zeros(n_y),
           jnp.zeros(n_y), jnp.full(n_y, width_y), dens_y, color='gray')
  ax.plot3D(wall_x, centres_y, dens_y)

  # Keep the default height of 0.5, but grow it if a histogram is taller
  # (the pdf peak is ~0.515 for a variance of 0.6), so bars are not clipped.
  zmax = max(0.5, 1.05 * float(jnp.maximum(dens_x.max(), dens_y.max())))
  ax.set_zlim(0, zmax)
  ax.set_xlim(-lim, lim)
  ax.set_ylim(-lim, lim)

  ax.text(-lim, -lim, zmax / 2, "P(Y)", zdir='y')
  ax.set_xlabel("X axis")
  ax.set_ylabel("Y axis")
  ax.set_zlabel("P(X)")
  plt.show()

In [14]:
# Interactive controls for the entries of the covariance matrix and the mean.
# Every slider redraws only when released (continuous_update=False).
# NOTE(review): with covXX = covYY = 0.6 any covXY above 0.6 makes the matrix
# non-positive-definite, so the plot must cope with invalid combinations.
cov00 = widgets.FloatSlider(description='covXX', value=2.5, min=0.6, max=2.5,
                            step=0.1, continuous_update=False)
cov01 = widgets.FloatSlider(description='covXY', value=0, min=0, max=1,
                            step=0.05, continuous_update=False)
cov10 = widgets.FloatSlider(description='covYX', value=0, min=0, max=1,
                            step=0.05, continuous_update=False)
cov11 = widgets.FloatSlider(description='covYY', value=2.5, min=0.6, max=2.5,
                            step=0.1, continuous_update=False)
mean0 = widgets.FloatSlider(description='meanX', value=0, min=-2.5, max=2.5,
                            step=0.5, continuous_update=False)
mean1 = widgets.FloatSlider(description='meanY', value=0, min=-2.5, max=2.5,
                            step=0.5, continuous_update=False)

# The covariance matrix is symmetric, so the two off-diagonal sliders are
# kept in sync in the browser.
mylink = widgets.jslink((cov01, 'value'), (cov10, 'value'))

# Layout: the mean sliders stacked above a 2x2 grid mirroring the matrix.
ui1 = widgets.HBox([cov00, cov01])
ui2 = widgets.HBox([cov10, cov11])
uicov = widgets.VBox([ui1, ui2])
uimean = widgets.VBox([mean0, mean1])
ui = widgets.VBox([uimean, uicov])

# Only one off-diagonal slider feeds the plot; the other mirrors it.
out = widgets.interactive_output(
    plot_bivariate_graph,
    {
        'mean0': mean0,
        'mean1': mean1,
        'cov00': cov00,
        'cov01': cov01,
        'cov11': cov11,
    },
)
display(ui, out)

VBox(children=(VBox(children=(FloatSlider(value=0.0, continuous_update=False, description='meanX', max=2.5, mi…

Output()