In [2]:
from collections import Counter
import random
from typing import NamedTuple

from fancy_einsum import einsum
import numpy as np
import torch as t
from torch import nn
from torch.distributions import Categorical, Normal, Uniform
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from tqdm import auto, tqdm
import pandas as pd
from plotly import express as px, graph_objects as go, subplots

In [3]:
SEED = 42

def seed(seed_val: float) -> None:
    random.seed(seed_val)
    np.random.seed(seed_val)
    t.manual_seed(seed_val)
    
def center_and_standardize(x: t.Tensor) -> t.Tensor:
    """Center and standardize `x` so that its mean is 0 and standard deviation is 1."""
    return (x - x.mean()) / x.std()

def normalize(x: t.Tensor) -> t.Tensor:
    """Normalize `x` so that it adds up to 1."""
    return x / x.sum()

### Constants / Defaults

In [4]:
ENV_TEMP_MU = 20
ENV_TEMP_SIGMA = 6
DEFAULT_N_ENVS = 1_000
DEFAULT_N_ROUNDS = 10_000

### Envs

In [5]:
def make_envs_normal(
    n_envs: int,
    *,
    temp_mu: float = ENV_TEMP_MU,  # temperature mean
    temp_sigma: float = ENV_TEMP_SIGMA,  # temperature standard deviation
) -> t.Tensor: # [n_envs 1]
    return (t.randn(n_envs) * temp_sigma + temp_mu).reshape(n_envs, 1)

def make_envs_uniform(
    n_envs: int,
    *,
    temp_mean: float = ENV_TEMP_MU,
    temp_diff: float = ENV_TEMP_SIGMA,
) -> t.Tensor: # [n_envs 1]
    temp_dist = Uniform(temp_mean - temp_diff, temp_mean + temp_diff)
    return temp_dist.sample(t.Size((n_envs, 1)))

# How much one binary action nudges temperature
ACTION_SIZE = 1e-2  # TODO better name?

# def act_in_envs(
#     envs: t.Tensor,  # [n_envs 1]
#     action_scores: t.Tensor,  # [n_envs n_actions(2 for now)]
# ) -> t.Tensor: # (-.5, +.5)
#     actions = (action_scores[:, 1] - action_scores[:, 0]).reshape(envs.shape)
#     new_envs = envs.clone().detach() + ACTION_SIZE * actions
#     return new_envs


### Thermostat

In [6]:
IN_DIM = 1
HIDDEN_DIM = 32
OUT_DIM = 3
N_MID_LAYERS = 4

class Thermostat(nn.Module):
    def __init__(
        self,
        *,
        temp_mu: float = 20,
        temp_sigma: float = 1,
        in_dim: int = IN_DIM,
        hidden_dim: int = HIDDEN_DIM,
        out_dim: int = OUT_DIM,
        n_mid_layers: int = N_MID_LAYERS
    ) -> None:
        super().__init__()
        self.temp_mu = temp_mu
        self.temp_sigma = temp_sigma
        self.temp = Normal(temp_mu, temp_sigma)
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        # Layers
        self.fc_in = nn.Linear(in_dim, hidden_dim)
        self.mid_layers = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_mid_layers)]
        )
        self.fc_out = nn.Linear(hidden_dim, out_dim)
    
    # (just for type checking, so that pylint knows model(x) is a Tensor)
    def __call__(self, *args, **kwargs) -> t.Tensor:
        return self.forward(*args, **kwargs)

    def forward(
        self,
        temp: t.Tensor,  # [n_envs 1]
    ) -> t.Tensor:  # [n_envs 2] logits of action (-1 or 1)
        x = F.relu(self.fc_in(temp))
        for layer in self.mid_layers:
            x = F.relu(layer(x))
        return self.fc_out(x)

    def temp_densities(
        self,
        temp: t.Tensor,  # [n_envs 1]
    ) -> t.Tensor:
        return self.temp.log_prob(temp)

    # def sample_action(
    #     self, action_probs: t.Tensor  # [n_envs n_actions]
    # ) -> t.Tensor:  # [n_envs 1] (row-wise: unit tensor [index of chosen action])
    #     return Categorical(action_probs).sample()
    
    # def act_in_envs(self, envs: t.Tensor, action_scores: t.Tensor) -> t.Tensor:
    #     action_probs = action_scores.log_softmax(1)
    #     actions = Categorical(action_probs).sample().reshape(envs.shape) - 1
    #     new_envs = envs + actions * ACTION_SIZE
    #     print(f"{action_probs.shape=}; {actions.shape=}; {new_envs.shape=}")
    #     return new_envs

In [7]:
@t.no_grad
def test_model(model: Thermostat) -> None:
    test_envs = t.arange(0, 40.1, 0.1).reshape(-1, 1)
    test_x = test_envs.ravel().tolist()
    action_scores = model(test_envs)
    action_logprobs = action_scores.log_softmax(1)
    action_inds = Categorical(logits=action_logprobs).sample()
    actions = action_inds.reshape(test_envs.shape) - 1
    new_test_envs = test_envs + actions * ACTION_SIZE
    # assert new_test_envs.shape == test_envs.shape
    
    temp_diffs = (new_test_envs - test_envs).ravel().tolist()
    pref_diffs = (model.temp_densities(new_test_envs).exp() - model.temp_densities(test_envs).exp()).ravel().tolist()
    # choice_counts = action_inds.bincount()
    
    fig = subplots.make_subplots(rows=2, cols=1, specs=[[{"secondary_y": True}], [{}]])
    fig.add_traces(
        [
            go.Scatter(x=test_x, y=temp_diffs, name="temp_diffs"),
            go.Scatter(x=test_x, y=pref_diffs, name="pref_diffs"),
            go.Scatter(x=test_x, y=action_inds.tolist(), name="choice")
        ],
        secondary_ys=[0, 1, None], rows=[1, 1, 2], cols=[1, 1, 1],
    )
    #TODO: there's probably a smarter way to do it but how and what exactly? box plot choice counts over time or something?
    # fig.add_trace(, row=2, col=1)
    #TODO: update layout(s)
    fig.show()

In [20]:
class TrainingHistory(NamedTuple):
    gains: list[float]
    n_rounds: int
    n_envs: int
    action_choice_history: list[tuple[int, int, int]]
    temp_diff_history: list[list[float]]
    pref_diff_history: list[list[float]]

def train(
    model: Thermostat,
    optimizer: t.optim.Optimizer,
    n_rounds: int = DEFAULT_N_ROUNDS,
    n_envs: int = DEFAULT_N_ENVS,
    *,
    progressbar: bool = True,
) -> TrainingHistory:
    gains: list[float] = []
    action_choice_history: list[tuple[int, int, int]] = []
    temp_diff_history: list[list[float]] = []
    pref_diff_history: list[list[float]] = []

    for round_i in auto.tqdm(range(n_rounds), disable=not progressbar):
        # Initialize envs
        envs = make_envs_uniform(n_envs)

        # Compute action scores on observation
        action_scores = model(envs)
        action_probs = action_scores.log_softmax(1).exp()
        action_inds = Categorical(probs=action_probs).sample()
        actions = action_inds.reshape(envs.shape) - 1
        action_counts = Counter(actions.ravel().tolist())
        action_choice_history.append((action_counts[-1], action_counts[0], action_counts[1]))
        new_envs = envs + actions * ACTION_SIZE
        
        post_prefs = model.temp_densities(new_envs).exp()
        gain = (
            post_prefs * action_probs.gather(
                dim=0, index=action_inds.unsqueeze(1)
            ).squeeze()
        ).pow(2).mean()
        
        gain.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Append to history
        gain = gain.item()
        gains.append(gain)
        
        if round_i % 200 == 0:
            print(f"[{round_i}] {gain=}; mean action scores: {action_scores.mean().item()}; mean action probs: {action_probs.mean().item()}")
            # test_model(model)
        temp_diffs = (new_envs - envs).ravel().tolist()
        pref_diffs = (post_prefs - model.temp_densities(envs).exp()).ravel().tolist()
        temp_diff_history.append(temp_diffs)
        pref_diff_history.append(pref_diffs)

    return TrainingHistory(gains, n_rounds, n_envs, action_choice_history, temp_diff_history, pref_diff_history)

In [21]:
seed(SEED)

model = Thermostat()
LR = 1e-3
optimizer = t.optim.Adam(model.parameters(), lr=LR, maximize=True)

th = train(model, optimizer, n_rounds=1000)

  0%|          | 0/1000 [00:00<?, ?it/s]

[0] gain=0.0029891079757362604; mean action scores: -0.1035313606262207; mean action probs: 0.3333333432674408
[200] gain=0.02121865376830101; mean action scores: -0.10997683554887772; mean action probs: 0.3333333432674408
[400] gain=0.022636136040091515; mean action scores: -0.09370272606611252; mean action probs: 0.3333333432674408
[600] gain=0.021327249705791473; mean action scores: -0.18016935884952545; mean action probs: 0.3333333432674408
[800] gain=0.023543115705251694; mean action scores: -0.25982779264450073; mean action probs: 0.3333333432674408


In [24]:
test_envs = t.arange(0, 40.1, 0.1).reshape(-1, 1)
action_scores = model(test_envs)
action_probs = action_scores.log_softmax(1).exp()
action_inds = Categorical(probs=action_probs).sample()
actions = action_inds.reshape(test_envs.shape) - 1
action_counts = Counter(actions.ravel().tolist())
new_test_envs = test_envs + actions * ACTION_SIZE


In [25]:
test_model(model)