In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from sqlalchemy.orm import Session
from datetime import datetime
from database.database import get_db
from database.models import Project, Bandit, Experiment


: 

In [None]:
# Amount added to a bandit's stored precision for every submitted reward; it
# also weights the reward sum when the posterior mean is recomputed.
# NOTE(review): presumably the assumed precision (1 / variance) of the reward
# noise -- confirm against the intended model.
TAU = 1.0

In [None]:
def create_project_gaussian(description, prices, db):
    """Create a project and one Gaussian bandit per price in one transaction.

    The project row and all of its bandit rows are committed together, so a
    failure while inserting bandits cannot leave behind a project whose
    ``number_bandits`` does not match the rows that actually exist.

    Args:
        description: Text stored on the new ``Project``.
        prices: Iterable of price points; one ``Bandit`` is created per entry.
        db: SQLAlchemy session used for the inserts.

    Returns:
        The persisted ``Project`` instance.

    Raises:
        Any exception raised by the session while flushing or committing. The
        session is rolled back before the exception propagates.
    """
    # Materialise once so generators are accepted and len() matches the number
    # of bandits that are actually inserted.
    prices = list(prices)
    proj = Project(description=description, number_bandits=len(prices))
    try:
        db.add(proj)
        # flush() emits the INSERT without committing, which makes
        # proj.project_id available for the bandit rows below.
        db.flush()
        # Every bandit starts with mean 0.0 and `variance` 1.0; the sampling and
        # update code in this file reads `variance` as a precision.
        db.add_all([
            Bandit(project_id=proj.project_id, price=p, mean=0.0,
                   variance=1.0, reward=0.0, trial=0)
            for p in prices
        ])
        db.commit()
    except Exception:
        # Leave the session usable for the caller instead of half-applied.
        db.rollback()
        raise
    return proj


In [None]:
def get_champion_bandit_gaussian(project_id, db):
    """Choose a bandit for a project by Thompson sampling.

    One value is drawn per bandit from ``Normal(mean, 1 / sqrt(variance))`` and
    the bandit with the largest draw is returned. Note that the ``variance``
    column is used here as a precision (inverse variance), not a variance.

    Args:
        project_id: Identifier of the project whose bandits are considered.
        db: SQLAlchemy session used for the query.

    Returns:
        The ``Bandit`` whose sampled value is the largest.

    Raises:
        ValueError: If the project has no bandits.
    """
    bandits = db.query(Bandit).filter(Bandit.project_id == project_id).all()
    if not bandits:
        # np.argmax on an empty list would raise an opaque ValueError; say
        # which project is the problem instead.
        raise ValueError(f"No bandits found for project_id={project_id!r}")

    # Draws are taken in query order, one per bandit, from the global NumPy RNG.
    samples = [
        np.random.normal(float(b.mean), 1.0 / np.sqrt(float(b.variance)))
        for b in bandits
    ]
    # int() turns the NumPy integer into a plain list index.
    return bandits[int(np.argmax(samples))]


In [None]:
def submit_reward_gaussian(bandit_id, reward_value, decision, db):
    """Record a reward for a bandit and update its stored Gaussian posterior.

    An ``Experiment`` row is added for the observation, and the bandit's
    ``variance`` (used as a precision), ``reward`` (running sum of rewards),
    ``mean``, ``trial`` and ``updated_at`` fields are updated, all in a single
    commit.

    Args:
        bandit_id: Identifier of the bandit that received the reward.
        reward_value: Observed reward; converted with ``float()``.
        decision: Value stored in the experiment's ``decision`` field.
        db: SQLAlchemy session used for the query and the commit.

    Returns:
        The updated ``Bandit`` instance.

    Raises:
        ValueError: If no bandit with ``bandit_id`` exists, or if
            ``reward_value`` is a string that ``float()`` cannot parse.
        TypeError: If ``reward_value`` is of a type ``float()`` rejects.
    """
    reward = float(reward_value)

    b = db.query(Bandit).filter(Bandit.bandit_id == bandit_id).first()
    if b is None:
        # .first() returns None for an unknown id; fail with a clear message
        # rather than an AttributeError on `b.project_id`.
        raise ValueError(f"No bandit found for bandit_id={bandit_id!r}")

    # One timestamp for the whole event keeps the experiment dates and the
    # bandit's updated_at consistent with each other.
    now = datetime.now()

    exp = Experiment(project_id=b.project_id, bandit_id=b.bandit_id,
                     decision=decision, reward=reward,
                     start_date=now, end_date=now)
    db.add(exp)

    lambda_old = float(b.variance)
    sum_x_old = float(b.reward)

    # Each observation adds TAU to the stored precision.
    lambda_new = lambda_old + TAU
    sum_x_new = sum_x_old + reward
    # There is no prior-mean term here, so this matches the conjugate update
    # only for a prior mean of 0, which is what create_project_gaussian sets.
    mean_new = (TAU * sum_x_new) / lambda_new

    b.variance = lambda_new
    b.reward = sum_x_new
    b.mean = mean_new
    b.trial += 1
    b.updated_at = now

    try:
        db.commit()
    except Exception:
        # Discard the pending experiment and bandit changes on failure.
        db.rollback()
        raise
    return b


In [None]:
def plot_bandit_distributions(bandits, title="Posterior Distributions"):
    """Plot one Normal density curve per bandit on a shared figure and show it.

    Each curve uses the bandit's ``mean`` as the location and
    ``1 / sqrt(variance)`` as the scale, evaluated on 400 evenly spaced points
    between -5 and 10.

    Args:
        bandits: Iterable of objects exposing ``mean``, ``variance`` and
            ``trial`` attributes.
        title: Figure title.
    """
    plt.figure(figsize=(8,5))
    grid_x = np.linspace(-5, 10, 400)

    for bandit in bandits:
        mu = float(bandit.mean)
        # `variance` is read as a precision, hence the inverse square root.
        sigma = 1.0 / np.sqrt(float(bandit.variance))
        density = norm.pdf(grid_x, mu, sigma)
        curve_label = f"Mean={mu:.2f}, Trials={bandit.trial}"
        plt.plot(grid_x, density, label=curve_label)

    plt.title(title)
    plt.legend()
    plt.grid()
    plt.show()

In [None]:
# Test: obtain a database session for the smoke-test cells that follow.
# NOTE(review): if get_db is a generator that closes the session in a
# `finally` block, discarding the generator here may trigger that cleanup
# right away -- confirm against database.database.get_db.
db = next(get_db())

In [None]:
# Create a test project with three price points; this writes rows to the database.
proj = create_project_gaussian("Gaussian Test", [10,20,30], db)

In [None]:
# Sample once from every bandit of the project and keep the one with the largest draw.
champ = get_champion_bandit_gaussian(proj.project_id, db)

In [None]:
# Record a reward of 15.0 for the selected bandit, with "test" as the decision value.
submit_reward_gaussian(champ.bandit_id, 15.0, "test", db)

In [None]:
# Load all bandits of the test project so they can be plotted.
bandits = db.query(Bandit).filter(Bandit.project_id == proj.project_id).all()

In [None]:
# Plot the distributions of the project's bandits after the single reward update.
plot_bandit_distributions(bandits, "After 1 Trial")