In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torchdata.datapipes.map import SequenceWrapper, Batcher
from torch.utils.data import DataLoader
from lightning import LightningModule, Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
import wandb
import einops
from matplotlib import pyplot as plt
import seaborn as sns
import random
from utils import PositionalEncoding, bincount_along_dim, get_returns
import warnings

# Silence two warnings matched by message pattern. NOTE(review): these look like
# Lightning dataloader warnings, expected here because the dataloaders defined
# below only yield placeholder batches - confirm against the installed version.
warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings("ignore", ".*train_dataloader yielded None.*")

# `sns.set` is only an alias of `set_theme`; seaborn documents `set_theme` as the
# preferred interface and notes that the alias may be removed in the future.
sns.set_theme(style="white")

In [None]:
class MastermindEnvironment:
    """Batched Mastermind game: ``batch_size`` independent secret codes played in lockstep.

    Episode state (rebuilt by :meth:`reset`):
        turn: 1-based index of the next guess to be played.
        code: (batch_size, n_pins) secret colours drawn with ``torch.randint``
            from ``[0, n_colors)``.
        guesses: (n_guesses, batch_size, n_pins) guesses played so far.
        scores_full, scores_color: (n_guesses + 1, batch_size) feedback history.
            Row 0 is an all-zero placeholder; row ``t`` is the feedback for guess ``t``.
    """

    def __init__(
        self,
        n_pins,
        n_colors,
        n_turns,
        batch_size=1,
        device=torch.device("cpu"),
    ):
        """Store the game dimensions and start a first episode.

        Args:
            n_pins: Number of positions in a code.
            n_colors: Number of distinct colours a pin can take.
            n_turns: Maximum number of guesses per episode.
            batch_size: Number of games played in parallel.
            device: Device on which the state tensors are created.
        """
        self.n_pins, self.n_colors, self.n_turns = n_pins, n_colors, n_turns
        self.batch_size = batch_size
        self.device = device

        # Episode state; filled in by reset() right below.
        self.turn = self.code = self.guesses = None
        self.scores_full = self.scores_color = None
        self.reset()

    @property
    def scores(self):
        """Feedback history as the pair ``(scores_full, scores_color)``."""
        return self.scores_full, self.scores_color

    def reset(self):
        """Draw new secret codes and clear the history.

        Returns:
            ``(guesses, scores)``: the empty guess history and the score history
            holding only the all-zero placeholder row.
        """
        history_kwargs = dict(dtype=torch.long, device=self.device)

        self.turn = 1
        self.code = torch.randint(
            low=0,
            high=self.n_colors,
            size=(self.batch_size, self.n_pins),
            device=self.device,
        )
        self.guesses = torch.zeros(
            (0, self.batch_size, self.n_pins), **history_kwargs
        )
        self.scores_full = torch.zeros((1, self.batch_size), **history_kwargs)
        self.scores_color = torch.zeros((1, self.batch_size), **history_kwargs)
        return self.guesses, self.scores

    def step(self, guess):
        """Play one guess for every game in the batch.

        Args:
            guess: (batch_size, n_pins) long tensor, or a nested list which is
                converted to one on ``self.device``.

        Returns:
            ``(guesses, scores, reward, done)`` where ``guesses`` and ``scores``
            are the updated histories, ``reward`` is the per-game reward from
            :meth:`reward_func` and ``done`` flags games whose code has been
            matched exactly by any guess so far.
        """
        if isinstance(guess, list):
            guess = torch.tensor(guess, dtype=torch.long, device=self.device)
        assert self.turn <= self.n_turns
        assert guess.size() == (self.batch_size, self.n_pins)

        self.guesses = torch.cat([self.guesses, guess[None]], dim=0)

        # Pins with the right colour in the right position.
        exact = self.code.eq(guess).sum(dim=1)
        self.scores_full = torch.cat([self.scores_full, exact[None]], dim=0)

        # Colours shared between code and guess regardless of position, minus
        # the exact matches already counted above.
        code_hist = bincount_along_dim(self.code, max_val=self.n_colors, dim=1)
        guess_hist = bincount_along_dim(guess, max_val=self.n_colors, dim=1)
        shared = torch.min(code_hist, guess_hist).sum(dim=1)
        misplaced = shared - exact
        self.scores_color = torch.cat([self.scores_color, misplaced[None]], dim=0)

        # A game stays done once any row of its history shows a full match.
        done = self.scores_full.eq(self.n_pins).any(dim=0)
        reward = self.reward_func(done)

        self.turn += 1

        return self.guesses, self.scores, reward, done

    def reward_func(self, done):
        """Per-game reward for the current turn.

        Solved games get 0, unsolved games get -1, and games still unsolved on
        the last allowed turn get a further -3.
        """
        reward = torch.where(done, 0.0, -1.0)
        if self.turn == self.n_turns:
            reward = torch.where(done, reward, reward - 3.0)
        return reward

    def visualize(self, show_answer=False, num_games=None, scale=0.5):
        """Draw the boards of the first ``num_games`` games side by side.

        Args:
            show_answer: Also draw the secret code on an extra row above the board.
            num_games: Number of games to draw; defaults to the whole batch.
            scale: Size multiplier for the figure and the markers.

        Returns:
            ``(fig, axes)``; ``axes`` is a one-element list when a single game is drawn.
        """
        if num_games is None:
            num_games = self.batch_size

        panel_width = 1.25 * scale * (self.n_pins + 1)
        panel_height = scale * (self.n_turns + 1)
        fig, axes = plt.subplots(
            1, num_games, figsize=(panel_width * num_games, panel_height)
        )
        if num_games == 1:
            axes = [axes]  # a single panel comes back as a bare Axes

        color_levels = [str(color) for color in range(self.n_colors)]
        marker_size = scale * 200
        pin_slots = list(range(1, self.n_pins + 1))
        turn_rows = list(range(1, self.n_turns + 1))
        answer_row = self.n_turns + 1.5
        # One marker per (played turn, pin), listed turn by turn.
        marker_x = pin_slots * (self.turn - 1)
        marker_y = [turn for turn in range(1, self.turn) for _ in pin_slots]

        for game in range(num_games):
            ax = axes[game]
            played = self.guesses[:, game].cpu().numpy()
            exact_hits = self.scores_full[1:, game].cpu().numpy()
            color_hits = self.scores_color[1:, game].cpu().numpy()
            secret = self.code[game].cpu().numpy()

            # Guesses
            sns.scatterplot(
                x=marker_x,
                y=marker_y,
                hue=[str(pin) for pin in played.flatten()],
                hue_order=color_levels,
                s=marker_size,
                ax=ax,
            )
            # Answer
            if show_answer:
                sns.scatterplot(
                    x=range(1, self.n_pins + 1),
                    y=[answer_row] * self.n_pins,
                    hue=[str(pin) for pin in secret],
                    hue_order=color_levels,
                    s=marker_size,
                    ax=ax,
                )
            # Bounds
            ax.set(
                xlim=(0, self.n_pins + 1),
                ylim=(0, self.n_turns + 1 + show_answer * 1.5),
            )
            if show_answer:
                ax.set_yticks(
                    turn_rows + [answer_row],
                    labels=turn_rows + ["Answer"],
                )
            else:
                ax.set_yticks(range(1, self.n_turns + 1))
            # Scores, shown as "full / color" on a right-hand axis aligned with the board
            score_axis = ax.twinx()
            score_axis.set_ylim(ax.get_ylim())
            score_axis.set_yticks(
                range(1, self.turn),
                labels=[f"{f} / {c}" for f, c in zip(exact_hits, color_hits)],
            )
            # Other aesthetics
            ax.set(xticks=[])
            ax.legend().remove()

        fig.tight_layout()
        return fig, axes

In [None]:
class MastermindPolicy(nn.Module):
    """Transformer policy that proposes the next guess from the game history.

    The history is encoded as an alternating sequence of score and guess
    tokens; the transformer output at the last token is split into one chunk
    per pin, and each chunk is mapped to a distribution over colours.
    """

    def __init__(
        self,
        n_pins,
        n_colors,
        n_turns,
        hidden_dim=128,
        n_heads=2,
        n_layers=2,
        dropout=0.0,
    ):
        """Build the embeddings, the transformer encoder and the per-pin head.

        Args:
            n_pins: Number of positions in a code; must divide ``hidden_dim``.
            n_colors: Number of colours a pin can take.
            n_turns: Maximum number of guesses per game (sizes the positional encoding).
            hidden_dim: Transformer model dimension.
            n_heads: Number of attention heads.
            n_layers: Number of encoder layers.
            dropout: Dropout used inside the encoder layers.
        """
        super().__init__()
        assert hidden_dim % n_pins == 0

        self.n_pins = n_pins
        self.n_colors = n_colors
        self.n_turns = n_turns
        self.hidden_dim = hidden_dim

        # Each pin owns an equal slice of the hidden vector.
        pin_dim = int(hidden_dim / n_pins)

        self.guess_embedding = nn.Linear(n_colors, pin_dim)
        self.score_embedding = nn.Linear(2 * (n_pins + 1), hidden_dim)

        # Longest sequence: (n_turns - 1) score/guess pairs plus the final score token.
        self.pe = PositionalEncoding(hidden_dim, max_len=n_turns * 2 - 1, dropout=0.0)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads, dropout=dropout
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=n_layers
        )

        self.policy_head = nn.Linear(pin_dim, n_colors)

    def forward(self, guesses, scores, exploration_eps=0.1):
        """Sample the next guess for every game in the batch.

        Args:
            guesses (torch.LongTensor): (n_guesses, bs, n_pins) with values in (0, n_colors - 1)
            scores (tuple[torch.LongTensor, torch.LongTensor]): Each is (n_guesses + 1, bs) with values in (0, n_pins)
            exploration_eps (float, optional): Probability that this call samples
                uniformly at random instead of from the policy. The draw is made
                once per call, so it applies to the whole batch. Defaults to 0.1.

        Returns:
            (tuple[torch.LongTensor, torch.FloatTensor]): Tuple of the action
            (bs, n_pins) and its probability under the policy (bs). The
            probability is always taken from the policy distribution, even when
            the action was sampled uniformly.
        """
        assert scores[0].size(0) == guesses.size(0) + 1
        n_guesses, bs, _ = guesses.size()
        full_scores, color_scores = scores

        # Embed guesses: one-hot per pin -> per-pin embedding -> concatenated per guess
        guess_tokens = self.guess_embedding(
            F.one_hot(guesses, num_classes=self.n_colors).float()
        )
        guess_tokens = guess_tokens.view(n_guesses, bs, self.hidden_dim)

        # Embed scores: both one-hot scores concatenated, then projected
        score_one_hots = [
            F.one_hot(part, num_classes=self.n_pins + 1).float()
            for part in (full_scores, color_scores)
        ]
        score_tokens = self.score_embedding(torch.cat(score_one_hots, dim=-1))

        # Interleave as score_0, guess_1, score_1, ..., guess_n, score_n
        history = einops.rearrange(
            [score_tokens[:-1], guess_tokens],
            "k n_guesses bs d -> (n_guesses k) bs d",
        )
        sequence = self.pe(torch.cat([history, score_tokens[-1:]], dim=0))

        # The last token's output carries the representation of the next action
        last_token = self.transformer(sequence)[-1]
        per_pin = last_token.view(bs, self.n_pins, -1)

        p = F.softmax(self.policy_head(per_pin), dim=-1)  # (bs, n_pins, n_colors)

        # Sample either uniformly (exploration) or from the policy
        explore = random.random() < exploration_eps
        sampling_dist = torch.ones_like(p) / self.n_colors if explore else p
        next_guess = torch.multinomial(
            sampling_dist.view(-1, self.n_colors),
            num_samples=1,
        ).view(bs, self.n_pins)

        # Probability of the whole guess = product of the per-pin policy probabilities
        chosen = p.gather(dim=-1, index=next_guess.unsqueeze(-1)).squeeze(-1)
        p_next_guess = chosen.prod(dim=-1)

        return next_guess, p_next_guess


In [None]:
class MastermindRL(LightningModule):
    """REINFORCE training of a ``MastermindPolicy`` on games it plays itself.

    Every training/validation step plays one batch of fresh episodes in
    ``self.env`` with ``self.policy``. The dataloaders only yield placeholder
    batches whose count sets the number of steps per epoch.
    """

    def __init__(
        self,
        n_pins=4,
        n_colors=8,
        n_turns=12,
        hidden_dim=128,
        n_heads=2,
        n_layers=2,
        dropout=0.2,
        batch_size=32,
        val_episodes=100,
        train_episodes=1000,
        lr=1e-3,
        exploration_eps=0.1,
    ):
        """Create the environment and the policy.

        Args:
            n_pins, n_colors, n_turns: Game dimensions, shared by environment and policy.
            hidden_dim, n_heads, n_layers, dropout: Policy architecture.
            batch_size: Number of games played in parallel per step.
            val_episodes, train_episodes: Episodes per epoch; the number of
                steps is the floor of episodes / batch_size.
            lr: Adam learning rate.
            exploration_eps: Exploration probability passed to the policy in training mode.
        """
        super().__init__()
        self.save_hyperparameters()

        self.env = MastermindEnvironment(
            n_pins=n_pins,
            n_colors=n_colors,
            n_turns=n_turns,
            batch_size=batch_size,
        )
        self.policy = MastermindPolicy(
            n_pins=n_pins,
            n_colors=n_colors,
            n_turns=n_turns,
            hidden_dim=hidden_dim,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
        )

    def _play_and_score(self):
        """Play one batch of episodes and return ``(loss, mean number of turns)``."""
        p, reward, done = self.unroll_episode()
        loss = self.loss_func(p, reward, done)
        # `done` has one pre-game row plus one row per turn, so the count of
        # not-done rows is the number of turns a solved game needed (unsolved
        # games count one more than the turns played).
        num_turns = (~done).sum(dim=0).float().mean()
        return loss, num_turns

    def training_step(self, batch, batch_idx):
        """Play a batch of episodes and return the REINFORCE loss; ``batch`` is ignored."""
        loss, num_turns = self._play_and_score()
        self.log("train/policy_loss", loss, on_epoch=True, on_step=False)
        self.log(
            "train/num_turns", num_turns, on_epoch=True, on_step=False, prog_bar=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Play a batch of episodes and log loss and turn count; ``batch`` is ignored."""
        loss, num_turns = self._play_and_score()
        self.log("val/policy_loss", loss)
        self.log("val/num_turns", num_turns, prog_bar=True)

    def on_validation_epoch_end(self):
        """No-op hook override."""
        pass

    def unroll_episode(self):
        """Play one batch of games to completion with the current policy.

        Exploration is only used while the module is in training mode. The
        rollout stops as soon as every game is solved or the turn limit is hit.

        Returns:
            ``(p_guesses, rewards, dones)``: ``p_guesses`` and ``rewards`` are
            (n_played_turns, batch_size); ``dones`` is
            (n_played_turns + 1, batch_size) and starts with an all-False row
            for the state before the first guess.
        """
        exploration_eps = self.hparams.exploration_eps if self.training else 0.0

        # Keep the environment on the module's device for every rollout, not
        # only after `on_fit_start`; otherwise a rollout outside `fit` on a
        # non-CPU module would stack `done` tensors from two different devices.
        self.env.device = self.device

        guesses, scores = self.env.reset()
        p_guesses, rewards, dones = [], [], []
        dones.append(
            torch.zeros(self.hparams.batch_size, dtype=torch.bool, device=self.device)
        )

        while not dones[-1].all() and self.env.turn <= self.hparams.n_turns:
            next_guess, p_next_guess = self.policy(
                guesses, scores, exploration_eps=exploration_eps
            )
            guesses, scores, reward, done = self.env.step(next_guess)
            p_guesses.append(p_next_guess)
            rewards.append(reward)
            dones.append(done)

        p_guesses, rewards, dones = (
            torch.stack(p_guesses),
            torch.stack(rewards),
            torch.stack(dones),
        )

        return p_guesses, rewards, dones

    def loss_func(self, p, reward, done):
        """REINFORCE loss: ``-log p(action) * return``, summed over turns, averaged over games.

        Args:
            p: (n_turns_played, batch_size) probability of each sampled guess.
            reward: (n_turns_played, batch_size) per-turn rewards.
            done: (n_turns_played + 1, batch_size) done flags, first row pre-game.
        """
        returns = get_returns(reward)
        # Clamp before the log: a probability that underflows to exactly 0 would
        # give -inf here, and the resulting inf/NaN poisons the gradients even
        # at positions that are masked out below.
        log_p = p.clamp_min(torch.finfo(p.dtype).tiny).log()
        loss = -log_p * returns
        loss = torch.where(
            ~done[:-1], loss, 0
        )  # Mask out transitions from completed trajectories
        return loss.sum(dim=0).mean()  # Sum across trajectories, mean across batch

    def on_fit_start(self):
        """Point the environment at the module's device when fitting starts."""
        self.env.device = self.device

    def _episode_loader(self, n_episodes):
        """Loader of placeholder batches, one per ``batch_size`` episodes (floored).

        The batches carry no information; the steps ignore them and generate
        their own episodes.
        """
        n_batches = n_episodes // self.hparams.batch_size
        placeholder = torch.zeros(n_batches, self.hparams.batch_size).bool()
        return DataLoader(placeholder, batch_size=None, num_workers=0)

    def train_dataloader(self):
        """Placeholder loader with ``train_episodes // batch_size`` steps per epoch."""
        return self._episode_loader(self.hparams.train_episodes)

    def val_dataloader(self):
        """Placeholder loader with ``val_episodes // batch_size`` steps per epoch."""
        return self._episode_loader(self.hparams.val_episodes)

    def configure_optimizers(self):
        """Adam over the policy parameters with learning rate ``lr``."""
        return torch.optim.Adam(self.policy.parameters(), lr=self.hparams.lr)


In [None]:
# Seed the RNGs the run draws from so it can be repeated: the policy uses
# Python's `random` for its exploration draw and torch for sampling guesses,
# and the environment uses torch to draw the secret codes.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

task = MastermindRL(n_pins=2, n_colors=3)

# Finish any W&B run that is still active (presumably one left over from a
# previous execution of this cell) before the logger starts a new one.
wandb.finish()
trainer = Trainer(
    max_epochs=1000,
    logger=WandbLogger(
        project="Mastermind",
        entity="ericelmoznino",
        save_dir="trained_models",
    ),
    # callbacks=[
    #     ModelCheckpoint(monitor="val/num_turns"),
    #     EarlyStopping(monitor="val/num_turns", patience=30),
    # ],
)
trainer.fit(model=task)
