In [4]:
import torch as t
from plotly import express as px, graph_objects as go, subplots

from src.envs import make_envs
from src.utils import seed
from src.thermostat import Thermostat
from src.training_and_testing import train, test

# Seed value handed to `seed()` at the top of both the training and the testing cell.
SEED = 42

In [2]:
# --- Training -------------------------------------------------------------
# Seed before the model is constructed and trained.
seed(SEED)

model = Thermostat()
lr = 1e-4
# maximize=True makes Adam ascend the objective instead of descending it.
optimizer = t.optim.Adam(model.parameters(), lr=lr, maximize=True)

train_hist = train(model, optimizer)

# Line plot of the recorded gains, one point per entry in train_hist.gains.
rounds = list(range(len(train_hist.gains)))
fig = px.line(x=rounds, y=train_hist.gains)
fig.update_layout(xaxis_title="round", yaxis_title="gain")
fig.show()
# Optional static export:
# fig.write_image("training_history_gains_plot.png")

100%|██████████| 1000/1000 [00:02<00:00, 470.07it/s]


In [10]:
# --- Testing --------------------------------------------------------------
# Re-seed with the same SEED before building the test environments.
seed(SEED)

test_envs = make_envs(10, temp_mu=22)
test_hist = test(model, test_envs, n_rounds=2000)

# One figure, two y-axes: the preference series on the primary axis and the
# mean-temperature series on the secondary axis, both against the round index.
fig = subplots.make_subplots(specs=[[{"secondary_y": True}]])
x = list(range(len(test_hist.prefs)))
# Take index 0 along the last axis and average over axis 1.
# NOTE(review): presumably axis 1 enumerates the environments and feature 0 is
# the temperature — confirm against the history layout produced by `test`.
mean_temps = test_hist.env_history[:, :, 0].mean(1)
fig.add_traces(
    data=[
        go.Scatter(x=x, y=test_hist.prefs, name="preference"),
        go.Scatter(x=x, y=mean_temps, name="temperature"),
    ],
    # Booleans, as documented for `secondary_ys`: first trace on the primary
    # y-axis, second trace on the secondary y-axis.
    secondary_ys=[False, True],
)
# Label every axis so the two differently scaled y-axes can be told apart.
fig.update_xaxes(title_text="round")
fig.update_yaxes(title_text="pref", secondary_y=False)
fig.update_yaxes(title_text="temperature", secondary_y=True)
# Optional static export:
# fig.write_image("testing_history_prefs_plot.png")
fig.show()

100%|██████████| 2000/2000 [00:00<00:00, 9945.05it/s]
