In [9]:
class MandM:
    """Bayesian update for the two-bag M&M problem.

    One bag is assumed to hold the 1994 color mix and the other the 1996 mix.
    One candy is drawn from each bag and its color is observed.

    Hypotheses:
        h1 -- bag 1 follows the 1994 distribution, bag 2 the 1996 one.
        h2 -- bag 1 follows the 1996 distribution, bag 2 the 1994 one.

    ``self.prior`` holds the current belief ``{'h1': P(H1), 'h2': P(H2)}``.
    Each call to :meth:`experiment` replaces it with the posterior, so
    repeated experiments accumulate evidence until :meth:`set_apriori`
    resets it.
    """

    def __init__(self, dist_1994, dist_1996):
        """Store the per-year color distributions and start from a uniform prior.

        Args:
            dist_1994: mapping color -> probability for the 1994 mix.
            dist_1996: mapping color -> probability for the 1996 mix.
        """
        self.d94 = dist_1994
        self.d96 = dist_1996
        # Single source of truth for the initial prior.
        self.set_apriori()

    def set_apriori(self, p_h1=0.5):
        """Reset the prior to ``{'h1': p_h1, 'h2': 1 - p_h1}``.

        Called with no argument, this restores the uniform prior (0.5 / 0.5).

        Raises:
            ValueError: if ``p_h1`` is not in the closed interval [0, 1].
        """
        # The negated chained comparison also rejects NaN.
        if not 0 <= p_h1 <= 1:
            raise ValueError(f'p_h1 must be in [0, 1], got {p_h1!r}')
        self.prior = {'h1': p_h1, 'h2': 1 - p_h1}

    def likelihood(self, color_bag1, color_bag2):
        """Return the pair ``(P(D|H1), P(D|H2))`` for the observed colors.

        D is "``color_bag1`` drawn from bag 1 and ``color_bag2`` drawn from
        bag 2". The draws are treated as independent, so each likelihood is
        the product of the two per-bag probabilities.

        Raises:
            KeyError: if a color is missing from either distribution.
        """
        # H1: bag 1 is the 1994 bag, bag 2 is the 1996 bag.
        likelihood_h1 = self.d94[color_bag1] * self.d96[color_bag2]
        # H2: the assignment is swapped.
        likelihood_h2 = self.d96[color_bag1] * self.d94[color_bag2]
        return (likelihood_h1, likelihood_h2)

    def total_prob(self, likelihood):
        """Return P(D) under the current prior (law of total probability).

        Args:
            likelihood: the pair ``(P(D|H1), P(D|H2))``.

        Returns:
            ``P(D|H1) * P(H1) + P(D|H2) * P(H2)`` using ``self.prior``.
        """
        return (likelihood[0] * self.prior['h1']) + (likelihood[1] * self.prior['h2'])

    def posterior(self, likelihood):
        """Return the posterior pair ``(P(H1|D), P(H2|D))`` via Bayes' rule.

        Args:
            likelihood: the pair ``(P(D|H1), P(D|H2))``.

        Raises:
            ZeroDivisionError: if P(D) is 0, i.e. the observation is impossible
                under every hypothesis that has nonzero prior probability (for
                example, a color whose probability is 0 in the relevant year).
                The posterior is undefined in that case.
        """
        p_d = self.total_prob(likelihood)
        if p_d == 0:
            raise ZeroDivisionError(
                'P(D) is 0: the observed colors are impossible under the '
                f'current prior {self.prior}; the posterior is undefined'
            )
        post_h1 = likelihood[0] * self.prior['h1'] / p_d
        post_h2 = likelihood[1] * self.prior['h2'] / p_d
        return (post_h1, post_h2)

    def experiment(self, color_bag1, color_bag2):
        """Observe one draw per bag, adopt the posterior as the new prior, and decide.

        The prior is replaced only after the posterior has been computed, so
        a failed update (unknown color, impossible observation) leaves the
        current belief unchanged.

        Returns:
            'h1' or 'h2', as chosen by :meth:`decision` on the updated belief.
        """
        post_h1, post_h2 = self.posterior(self.likelihood(color_bag1, color_bag2))
        self.prior = {'h1': post_h1, 'h2': post_h2}
        return self.decision()

    def decision(self):
        """Return the hypothesis with the larger current probability.

        After :meth:`experiment`, ``self.prior`` holds the posterior, so this
        is a maximum a posteriori (MAP) decision, not maximum likelihood.
        Ties are resolved in favor of 'h1'.
        """
        return 'h1' if self.prior['h1'] >= self.prior['h2'] else 'h2'

    def __repr__(self):
        return f'h1: {self.prior["h1"]:.4f}, h2: {self.prior["h2"]:.4f}'

# Probability of each candy color, per production year.
dist_1994 = {
    'brown': .3, 'yellow': .2, 'red': .2, 'green': .1,
    'orange': .1, 'tan': .1, 'blue': 0,
}
dist_1996 = {
    'brown': .13, 'yellow': .14, 'red': .13, 'green': .2,
    'orange': .16, 'tan': 0, 'blue': .24,
}

# Sanity checks: (expected printout, bag-1 color, bag-2 color, reset prior first?).
# The first check runs on a freshly built model; the second resets to the
# uniform prior so it does not inherit the first check's posterior.
_checks = [
    ('the result should be |h1: 0.7407, h2: 0.2593|', 'yellow', 'green', False),  # test 1
    ('the result should be |h1: 0.0000, h2: 1.0000|', 'blue', 'yellow', True),    # test 2
]

mandm = MandM(dist_1994, dist_1996)
for _expected, _bag1_color, _bag2_color, _reset in _checks:
    print(_expected)
    if _reset:
        mandm.set_apriori()
    mandm.experiment(_bag1_color, _bag2_color)
    print(mandm)

the result should be |h1: 0.7407, h2: 0.2593|
h1: 0.7407, h2: 0.2593
the result should be |h1: 0.0000, h2: 1.0000|
h1: 0.0000, h2: 1.0000
