In [90]:
import numpy as np
import random

import tqdm
import torch
import torch.utils
from torch import optim
from torch import nn
from torch.nn import functional as F

import gobanana as gb

In [91]:
class NaiveGenerator(gb.nn.Generator):
    def __init__(self, board_shape, num_metrics, hidden_dim):
        super().__init__(board_shape, num_metrics)
        board_size = board_shape[0] * board_shape[1]
        input_size = board_size + num_metrics
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, board_size * 3)

    def forward(self, noise: torch.Tensor, metrics: torch.Tensor):
        x = torch.cat([noise.flatten(start_dim=1), metrics.flatten(start_dim=1)], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = x.reshape(-1, *self.board_shape, 3)
        x = F.softmax(x, dim=-1)
        return x

In [92]:
class NaiveEvaluator(gb.nn.Evaluator):
    def __init__(self, board_shape, num_metrics, hidden_dim):
        super().__init__(board_shape, num_metrics)
        board_size = board_shape[0] * board_shape[1]
        self.embed = nn.Embedding(3, 1)
        self.fc1 = nn.Linear(board_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_metrics)
        
    def forward(self, one_hot_boards):
        x = one_hot_boards
        x = self.embed(x).flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [93]:
def batch_generate_boards(generator, board_shape, metrics, batch_size=32):
    generator.eval()
    with torch.no_grad():
        noise = torch.rand(batch_size, *board_shape)
        metrics = torch.tensor(metrics).float()
        assert batch_size == metrics.shape[0]
        generator_out = generator(noise, metrics)
        board_matrices = generator_out.argmax(-1).numpy()
    return [gb.game.Board(mat) for mat in board_matrices]
    
def count_bananas(board: gb.game.Board):
    return np.sum(board.mat == board.BANANA).item()

def build_evaluator_training_samples(boards):
    training_samples = []
    for board in boards:
        actual_num_bananas = count_bananas(board)
        board_one_hot_tensor = torch.tensor(board.mat).long()
        metrics_tensor = torch.tensor(actual_num_bananas).float()
        sample = (board_one_hot_tensor, metrics_tensor)
        training_samples.append(sample)
    return training_samples

In [94]:
def train_evaluator(evaluator, train_loader, epochs=1):
    print("Training Evaluator ...")
    optimizer = optim.Adam(evaluator.parameters())
    loss_function = nn.MSELoss()
    evaluator.train()
    losses = []
    for _ in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            predicted_metrics = evaluator(x)
            loss = loss_function(y, predicted_metrics)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    print('loss:', np.mean(losses))
        
            
def train_generator(generator, evaluator, train_loader, epochs=1):
    print("Training Generator ...")
    # should I include both parameters here and freeze the other one?
    optimizer = optim.Adam(generator.parameters())
    loss_function = nn.MSELoss()
    generator.train()
    losses = []
    for _ in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            boards = generator(x, y)
            predicted_metrics = evaluator(boards.argmax(-1))
            loss = loss_function(y, predicted_metrics)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    print('loss:', np.mean(losses))
            
def get_random_metrics():
    desired_num_bananas = random.randint(0, 9)
    return [desired_num_bananas]

def get_random_metrics_batch(batch_size):
    return [get_random_metrics() for _ in range(batch_size)]

In [98]:
board_shape = (3, 3)
generator = NaiveGenerator(board_shape, 1, 32)
evaluator = NaiveEvaluator(board_shape, 1, 32)
batch_size = 256
iterations = 128

for _ in tqdm.tqdm(list(range(iterations))):
    batch_metrics = get_random_metrics_batch(batch_size)
    generated_boards = batch_generate_boards(generator, board_shape, batch_metrics, batch_size)
        
    evaluator_training_samples = build_evaluator_training_samples(generated_boards)
    evaluator_train_loader = torch.utils.data.DataLoader(
            evaluator_training_samples,
            batch_size=batch_size
    )

    train_evaluator(evaluator, evaluator_train_loader, epochs=10)

    generator_training_samples = []
    
    for _ in range(batch_size):
        noise = torch.rand(*board_shape)
        metrics = torch.tensor(get_random_metrics()).float()
        generator_training_samples.append((noise, metrics))
        
    generator_train_loader = torch.utils.data.DataLoader(
        generator_training_samples,
        batch_size=batch_size
    )
    train_generator(generator, evaluator, generator_train_loader, epochs=50)

100%|██████████| 128/128 [01:40<00:00,  1.28it/s]


Training Evaluator ...
loss: 24.443652534484862
Training Generator ...
loss: 23.22728729248047
Training Evaluator ...
loss: 22.045560836791992
Training Generator ...
loss: 23.636838912963867
Training Evaluator ...
loss: 19.0175573348999
Training Generator ...
loss: 19.70987319946289
Training Evaluator ...
loss: 15.33828763961792
Training Generator ...
loss: 14.698568344116211
Training Evaluator ...
loss: 10.660590648651123
Training Generator ...
loss: 15.611869812011719
Training Evaluator ...
loss: 7.495778942108155
Training Generator ...
loss: 11.276052474975586
Training Evaluator ...
loss: 4.045218467712402
Training Generator ...
loss: 10.237462043762207
Training Evaluator ...
loss: 1.5218568563461303
Training Generator ...
loss: 10.755376815795898
Training Evaluator ...
loss: 1.273093056678772
Training Generator ...
loss: 9.89112377166748
Training Evaluator ...
loss: 1.2643842577934266
Training Generator ...
loss: 8.506629943847656
Training Evaluator ...
loss: 0.9184126257896423
Tra

In [99]:
for _ in range(16):
    metrics = get_random_metrics()
    print("Desired metrics: ", metrics)
    metrics = torch.tensor(metrics).float().reshape(1, -1)
    noise = torch.rand(1, *board_shape)
    board = generator(noise, metrics)
    print(board.argmax(-1))
    print()



Desired metrics:  [8]
tensor([[[2, 0, 1],
         [2, 2, 2],
         [1, 2, 2]]])

Desired metrics:  [4]
tensor([[[0, 0, 1],
         [2, 2, 2],
         [1, 2, 2]]])

Desired metrics:  [8]
tensor([[[2, 0, 1],
         [0, 2, 0],
         [1, 2, 2]]])

Desired metrics:  [0]
tensor([[[1, 0, 2],
         [2, 1, 2],
         [1, 2, 2]]])

Desired metrics:  [2]
tensor([[[0, 0, 2],
         [2, 1, 2],
         [1, 2, 2]]])

Desired metrics:  [5]
tensor([[[0, 0, 2],
         [2, 2, 2],
         [1, 2, 2]]])

Desired metrics:  [0]
tensor([[[0, 0, 1],
         [2, 1, 2],
         [1, 2, 1]]])

Desired metrics:  [3]
tensor([[[0, 0, 1],
         [2, 2, 2],
         [1, 2, 2]]])

Desired metrics:  [8]
tensor([[[2, 0, 2],
         [0, 2, 0],
         [1, 2, 2]]])

Desired metrics:  [1]
tensor([[[0, 0, 2],
         [2, 1, 2],
         [1, 2, 2]]])

Desired metrics:  [5]
tensor([[[0, 0, 1],
         [2, 2, 2],
         [1, 2, 2]]])

Desired metrics:  [9]
tensor([[[2, 0, 2],
         [0, 2, 0],
   