In [11]:
import sys
sys.path.append("/app")

In [12]:
import numpy as np
import random
import matplotlib.pyplot as plt


from sqlalchemy.orm import Session
from Database.database import SessionLocal
from Database.models import Bandit



class ThompsonSamplingAdSelectorDB:
    """Thompson-sampling ad selector whose per-ad counts are stored in the database.

    Every ``Bandit`` row provides ``number_of_success`` and
    ``number_of_failures``. Each ad's click-through rate is modelled as
    ``Beta(successes + 1, failures + 1)``, i.e. a uniform ``Beta(1, 1)``
    when both counts are zero.
    """

    def __init__(self, session: "Session"):
        """Keep the session used for all queries and commits.

        The selector never closes the session; that is left to the caller.
        """
        self.session = session

    def select_ad(self) -> "tuple[Bandit, int]":
        """Pick the ad with the highest sampled click-through rate.

        Returns:
            ``(bandit, index)`` where ``index`` is the bandit's position in
            the list of all bandits ordered by ``bandit_id``.

        Raises:
            ValueError: If the query returns no bandits.
        """
        # Without an explicit ORDER BY the database may return rows in any
        # order (and the order can change after updates), which would make
        # the positional index meaningless between calls.
        bandits = self.session.query(Bandit).order_by(Bandit.bandit_id).all()
        if not bandits:
            raise ValueError(
                "Cannot select an ad: no bandits found in the database"
            )

        # One posterior draw per ad; the +1 terms keep both Beta parameters
        # strictly positive even when a count is zero.
        sampled_theta = [
            np.random.beta(b.number_of_success + 1, b.number_of_failures + 1)
            for b in bandits
        ]
        selected_index = int(np.argmax(sampled_theta))
        return bandits[selected_index], selected_index

    def update(self, bandit: "Bandit", reward: int) -> None:
        """Record the outcome of showing ``bandit`` and commit it.

        Args:
            bandit: The bandit whose counters are incremented.
            reward: ``1`` counts as a success; any other value counts as a
                failure.

        Raises:
            Exception: Whatever ``session.commit()`` raises, re-raised after
                the session has been rolled back.
        """
        if reward == 1:
            bandit.number_of_success += 1
        else:
            bandit.number_of_failures += 1

        try:
            self.session.commit()
        except Exception:
            # Roll back so the session is not left in a failed transaction
            # and can still be used by the caller.
            self.session.rollback()
            raise

if __name__ == "__main__":
    # Simulation driver: repeatedly pick an ad via Thompson sampling, draw a
    # simulated click from the ad's fixed click-through rate, and feed the
    # outcome back into the database-backed counters.
    n_rounds = 1000
    true_ctrs = [0.05, 0.13, 0.04, 0.20, 0.10]  # Simulated click-through rates

    selections = []
    total_reward = 0

    session = SessionLocal()
    try:
        selector = ThompsonSamplingAdSelectorDB(session)

        # The index returned by select_ad() is used to look up true_ctrs, so
        # fail before writing anything if there are bandits with no rate.
        n_bandits = session.query(Bandit).count()
        if n_bandits > len(true_ctrs):
            raise ValueError(
                f"Found {n_bandits} bandits but only {len(true_ctrs)} "
                "simulated click-through rates"
            )

        for _ in range(n_rounds):
            bandit, index = selector.select_ad()
            reward = 1 if random.random() < true_ctrs[index] else 0
            selector.update(bandit, reward)
            selections.append(index)
            total_reward += reward

        print(f"Total reward (clicks): {total_reward}")

        bandits = session.query(Bandit).all()
        for b in bandits:
            print(f"Ad {b.bandit_id}: Successes = {b.number_of_success}, Failures = {b.number_of_failures}")
    finally:
        # Release the session even if a round fails, and before the blocking
        # plot window is shown.
        session.close()

    # One bin per ad index, centred on the integer index.
    plt.hist(selections, bins=np.arange(len(true_ctrs) + 1) - 0.5, rwidth=0.6)
    plt.title("Ad Selections Over Time")
    plt.xlabel("Ad Index")
    plt.ylabel("Number of Times Selected")
    plt.show()

ModuleNotFoundError: No module named 'Database'