In [None]:
import pandas as pd

from bandit.utils.processing import generate_vw_input, generate_vw_actions
from bandit.train import train_bandit

In [2]:
# Load the raw feedback log and keep only UK events, in chronological order,
# before they are turned into bandit training examples.
# kind="stable" keeps rows that share a timestamp in their original file order;
# the default quicksort gives no such guarantee, and training order matters.
# NOTE(review): read_csv leaves "timestamp" as strings, so this is a lexicographic
# sort -- chronological only for a fixed-width "YYYY-MM-DD HH:MM:SS" format like
# the rows shown below. Confirm the whole file uses that format.
user_feedback = (
    pd.read_csv("../data/user_feedback.csv")
    .loc[lambda df: df["country_code"] == "UK"]
    .sort_values(by="timestamp", kind="stable")
)
assert not user_feedback.empty, "no UK rows in user_feedback.csv"

user_feedback.head(3)

Unnamed: 0,user_id,model_id,country_code,timestamp,feedback
1038,U0007,7951efb0,UK,2025-11-17 08:28:25,1
1112,U0007,6c539336,UK,2025-11-17 08:33:59,1
1131,U0007,7951efb0,UK,2025-11-17 08:35:15,0


In [3]:
# Candidate models for the bandit, restricted to the UK market.
# The original CSV row index is kept, so rows stay traceable to the file.
model_features = pd.read_csv("../data/model_features.csv").loc[
    lambda frame: frame["country_code"] == "UK"
]

model_features.head(3)

Unnamed: 0,model_name,country_code,variant_id,model_type,version,model_id,MAE,RMSE,HR,cHR,MRR,Coverage,Precision@K,Recall@K,F1@K
27,LeadFinder,UK,UK0001,content-based,1.0.0,e5f97208,0.86498,0.5023,0.33534,0.23053,0.54197,0.38929,0.22548,0.4362,0.3083
28,LeadFinder,UK,UK0001,content-based,1.0.1,6618f78c,0.57413,0.67331,0.48502,0.60472,0.29218,0.36845,0.13847,0.23396,0.51418
29,LeadFinder,UK,UK0001,content-based,1.0.2,123e86e1,0.8712,0.61091,0.52602,0.68674,0.55842,0.20192,0.25489,0.66431,0.49503


In [4]:
# Turn the UK feedback log plus model features into VW text examples.
# generate_vw_input is project code; its exact output format is not visible here.
training_data = generate_vw_input(user_feedback, model_features)

# Dump the examples for reference/inspection; this notebook does not read it back.
# Examples are joined with a blank line between them -- presumably because each
# one spans several lines (multi-line / ADF format); TODO confirm against
# generate_vw_input. The encoding is explicit so the dump does not depend on the
# platform's default text encoding.
with open("../data/vw_input.txt", "w", encoding="utf-8") as f:
    f.write("\n\n".join(training_data))

In [5]:
# Train the bandit on the chronologically ordered examples built above.
model = train_bandit(training_data)

# Score every UK candidate model for a shared context with user_id=null.
# NOTE(review): presumably this stands in for an unknown / cold-start user;
# confirm how generate_vw_input encodes user_id during training.
actions = generate_vw_actions(model_features.to_dict(orient="records"))
predict_input = "shared |user user_id=null\n" + "\n".join(actions)
predictions = model.predict(predict_input)

# Expect one score per candidate row. Actions were built from model_features
# records in row order, assuming generate_vw_actions preserves that order.
# In the saved output the scores sum to ~1, which suggests a probability
# distribution over models -- this depends on train_bandit's VW options.
assert len(predictions) == len(model_features), (
    f"expected {len(model_features)} scores, got {len(predictions)}"
)

# Select the identifying columns first, then copy. This yields an independent
# frame that is safe to add a column to, without first duplicating every
# metric column.
model_predictions = model_features[
    ["country_code", "variant_id", "model_type", "version", "model_id"]
].copy()
model_predictions["prediction"] = predictions
model_predictions

Unnamed: 0,country_code,variant_id,model_type,version,model_id,prediction
27,UK,UK0001,content-based,1.0.0,e5f97208,0.116521
28,UK,UK0001,content-based,1.0.1,6618f78c,0.10617
29,UK,UK0001,content-based,1.0.2,123e86e1,0.09663
30,UK,UK0002,collaborative,1.0.0,b1a12453,0.157941
31,UK,UK0002,collaborative,1.0.1,6c539336,0.101143
32,UK,UK0002,collaborative,1.0.2,5d412d02,0.108748
33,UK,UK0003,hybrid,1.0.0,8f6180de,0.117339
34,UK,UK0003,hybrid,1.0.1,7951efb0,0.095794
35,UK,UK0003,hybrid,1.0.2,0b64fdc7,0.099714
