In [1]:
import numpy as np
import jax
import scipy
import matplotlib.pyplot as plt

In [8]:
class MetropolisHastings:
    """
    Random-walk Metropolis-Hastings sampler for a two-dimensional target.

    Proposals are drawn from a bivariate normal centred on the current state
    (unit variances, correlation 0.7). Only ratios of ``target_distr`` values
    are used, so the target may be unnormalised.

    Attributes filled in by ``sample``:
    - samples: np.array of shape (n_samples, 2) holding the chain.
    - acceptance_rate: fraction of proposals that were accepted.
    - mean: per-coordinate mean of the chain.
    - covariance: 2x2 sample covariance of the chain.
    """

    def __init__(self, target_distr, initial_state=None, seed=None):
        """
        Initialize the Metropolis-Hastings sampler.
        
        Parameters:
        - target_distr: function
            The target distribution from which we want to sample. It is
            called with a length-2 array and must return a scalar density
            (normalisation is not required).
        - initial_state: array-like, optional
            The initial state of the Markov chain. Default is [0, 0].
        - seed: int, optional
            Random seed for reproducibility. Default is None, in which case
            NumPy's global random functions are used.
        
        Raises:
        - ValueError: if initial_state is not a length-2 vector.
        """
        self.target_distr = target_distr
        if initial_state is None:
            self.initial_state = np.zeros(2)
        else:
            self.initial_state = np.array(initial_state, dtype=float)
        # The proposal covariance is 2x2, so fail early on any other shape
        # instead of deep inside the sampling loop.
        if self.initial_state.shape != (2,):
            raise ValueError(
                "initial_state must be a vector of length 2, got shape "
                f"{self.initial_state.shape}."
            )
        self.seed = seed
        self.samples = []
        self.acceptance_rate = 0.0
        self.mean = None
        self.covariance = None
        
        # A seeded sampler owns a private random stream, so creating it does
        # not reseed NumPy's global RNG for unrelated code. Without a seed the
        # module-level functions are used, so an external np.random.seed(...)
        # still takes effect.
        self._rng = np.random if seed is None else np.random.RandomState(seed)
    
    def sample(self, n_samples):
        """
        Generate samples using the Metropolis-Hastings algorithm.
        
        Every call restarts the chain from ``initial_state``. Rejected
        proposals repeat the current state, so exactly ``n_samples`` rows are
        returned.
        
        Parameters:
        - n_samples: int
            The number of samples to generate.
        
        Returns:
        - samples: np.array
            Array of generated samples with shape (n_samples, 2).
        
        Raises:
        - ValueError: if n_samples is negative.
        """
        n_samples = int(n_samples)
        if n_samples < 0:
            raise ValueError("n_samples must be non-negative.")
        
        chain = np.empty((n_samples, 2))
        if n_samples == 0:
            # Nothing to simulate: avoid the 0/0 acceptance rate and the
            # statistics of an empty chain.
            self.samples = chain
            self.acceptance_rate = 0.0
            self.mean = None
            self.covariance = None
            return self.samples
        
        sigmax = 1
        sigmay = 1
        rho = 0.7
        cov = np.array([[sigmax**2, sigmax * sigmay * rho], 
                        [sigmax * sigmay * rho, sigmay**2]])
        
        rng = self._rng
        target = self.target_distr
        
        x_t = self.initial_state
        p_t = target(x_t)  # cached so the target is evaluated once per step
        accept_count = 0
        
        for i in range(n_samples):
            x_prime = rng.multivariate_normal(x_t, cov)
            p_prime = target(x_prime)
            u = rng.uniform(0.0, 1.0)
            
            # The Gaussian random-walk proposal is symmetric, q(x'|x) == q(x|x'),
            # so the Hastings correction cancels and only the target ratio is
            # left. Because u < 1, "u < ratio" is the same test as
            # "u < min(1, ratio)".
            if p_t > 0:
                accepted = u < p_prime / p_t
            else:
                # The ratio is undefined when the current density is not
                # positive; always move so the chain can reach the support.
                accepted = True
            
            if accepted:
                x_t, p_t = x_prime, p_prime
                accept_count += 1
            chain[i] = x_t
        
        self.samples = chain
        self.acceptance_rate = accept_count / n_samples
        self.mean = chain.mean(axis=0)
        # The covariance of a single observation is undefined; report NaNs
        # directly rather than through a degrees-of-freedom warning.
        self.covariance = np.cov(chain.T) if n_samples > 1 else np.full((2, 2), np.nan)
        
        return self.samples
    
    def plot_samples(self):
        """
        Plot the generated samples as a 2-D scatter plot.
        
        Raises:
        - ValueError: if no samples have been generated yet.
        """
        if self.samples is None or len(self.samples) == 0:
            raise ValueError("No samples to plot. Please run the sample method first.")
        
        plt.figure(figsize=(10, 6))
        plt.scatter(self.samples[:, 0], self.samples[:, 1], alpha=0.5, s=1)
        plt.title('Samples from Target Distribution using Metropolis-Hastings')
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.axis('equal')
        plt.show()