In [2]:
import numpy as np
import random

import tqdm
import torch
import torch.utils
from torch import optim
from torch import nn
from torch.nn import functional as F

import gobanana as gb

In [3]:
class NaiveGenerator(gb.nn.Generator):
    def __init__(self, board_shape, num_metrics, hidden_dim):
        super().__init__(board_shape, num_metrics)
        board_size = board_shape[0] * board_shape[1]
        input_size = board_size + num_metrics
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, board_size * 3)

    def forward(self, noise: torch.Tensor, metrics: torch.Tensor):
        x = torch.cat([noise.flatten(start_dim=1), metrics.flatten(start_dim=1)], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = x.reshape(-1, *self.board_shape, 3)
        x = F.softmax(x, dim=-1)
        return x

class NaiveEvaluator(gb.nn.Evaluator):
    def __init__(self, board_shape, num_metrics, hidden_dim):
        super().__init__(board_shape, num_metrics)
        board_size = board_shape[0] * board_shape[1]
        self.embed = nn.Embedding(3, 1)
        self.fc1 = nn.Linear(board_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_metrics)
        
    def forward(self, one_hot_boards):
        x = one_hot_boards
        x = self.embed(x).flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [15]:
def batch_generate_boards(generator, board_shape, metrics, batch_size=32):
    generator.eval()
    with torch.no_grad():
        noise = torch.rand(batch_size, *board_shape)
        metrics = torch.tensor(metrics).float()
        assert batch_size == metrics.shape[0]
        generator_out = generator(noise, metrics)
        board_matrices = generator_out.argmax(-1).numpy()
    return [gb.game.Board(mat) for mat in board_matrices]
    
def count_bananas(board: gb.game.Board):
    return np.sum(board.mat == board.BANANA).item()

def build_evaluator_training_samples(boards):
    training_samples = []
    for board in boards:
        actual_num_bananas = count_bananas(board)
        board_one_hot_tensor = torch.tensor(board.mat).long()
        metrics_tensor = torch.tensor([actual_num_bananas]).float()
        sample = (board_one_hot_tensor, metrics_tensor)
        training_samples.append(sample)
    return training_samples

def train_evaluator(evaluator, train_loader, epochs=1):
    print("Training Evaluator ...")
    optimizer = optim.Adam(evaluator.parameters())
    loss_function = nn.MSELoss()
    evaluator.train()
    losses = []
    for _ in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            predicted_metrics = evaluator(x)
            loss = loss_function(y, predicted_metrics)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    print('loss:', np.mean(losses))
        
            
def train_generator(generator, evaluator, train_loader, epochs=1):
    print("Training Generator ...")
    # should I include both parameters here and freeze the other one?
    optimizer = optim.Adam(generator.parameters())
    loss_function = nn.MSELoss()
    generator.train()
    losses = []
    for _ in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            boards = generator(x, y)
            # argmax maybe the culprit of the generator not training.
            predicted_metrics = evaluator(boards.argmax(-1))
            loss = loss_function(y, predicted_metrics)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    print('loss:', np.mean(losses))
            
def get_random_metrics():
    desired_num_bananas = random.randint(0, 9)
    return [desired_num_bananas]

def get_random_metrics_batch(batch_size):
    return [get_random_metrics() for _ in range(batch_size)]

In [16]:
board_shape = (3, 3)
generator = NaiveGenerator(board_shape, 1, 32)
evaluator = NaiveEvaluator(board_shape, 1, 32)
batch_size = 32
iterations = 256

for _ in tqdm.notebook.tqdm(list(range(iterations))):
    batch_metrics = get_random_metrics_batch(batch_size)
    generated_boards = batch_generate_boards(generator, board_shape, batch_metrics, batch_size)
        
    evaluator_training_samples = build_evaluator_training_samples(generated_boards)
    evaluator_train_loader = torch.utils.data.DataLoader(
            evaluator_training_samples,
            batch_size=batch_size
    )

    train_evaluator(evaluator, evaluator_train_loader, epochs=10)

    generator_training_samples = []
    
    for _ in range(batch_size):
        noise = torch.rand(*board_shape)
        metrics = torch.tensor(get_random_metrics()).float()
        generator_training_samples.append((noise, metrics))
        
    generator_train_loader = torch.utils.data.DataLoader(
        generator_training_samples,
        batch_size=batch_size
    )
    train_generator(generator, evaluator, generator_train_loader, epochs=50)

HBox(children=(FloatProgress(value=0.0, max=256.0), HTML(value='')))

Training Evaluator ...
loss: 10.314599323272706
Training Generator ...
loss: 31.059799194335938
Training Evaluator ...
loss: 7.0968194007873535
Training Generator ...
loss: 21.19251823425293
Training Evaluator ...
loss: 5.373669099807739
Training Generator ...
loss: 20.91891098022461
Training Evaluator ...
loss: 3.4882835388183593
Training Generator ...
loss: 13.732887268066406
Training Evaluator ...
loss: 2.440849769115448
Training Generator ...
loss: 7.155974864959717
Training Evaluator ...
loss: 2.2714247703552246
Training Generator ...
loss: 12.292977333068848
Training Evaluator ...
loss: 1.089270567893982
Training Generator ...
loss: 13.53931713104248
Training Evaluator ...
loss: 0.5066334903240204
Training Generator ...
loss: 12.536434173583984
Training Evaluator ...
loss: 0.3618632242083549
Training Generator ...
loss: 15.90229606628418
Training Evaluator ...
loss: 0.1527445949614048
Training Generator ...
loss: 11.450594902038574
Training Evaluator ...
loss: 0.13096734061837195

loss: 20.817113876342773
Training Evaluator ...
loss: 0.003651901485864073
Training Generator ...
loss: 9.428675651550293
Training Evaluator ...
loss: 0.0021866044960916044
Training Generator ...
loss: 18.948225021362305
Training Evaluator ...
loss: 0.002349258927279152
Training Generator ...
loss: 17.39179801940918
Training Evaluator ...
loss: 0.003239898144965991
Training Generator ...
loss: 12.7905912399292
Training Evaluator ...
loss: 0.0029234834073577076
Training Generator ...
loss: 22.336849212646484
Training Evaluator ...
loss: 0.0026343995792558416
Training Generator ...
loss: 19.645057678222656
Training Evaluator ...
loss: 0.0019261546665802598
Training Generator ...
loss: 14.60218334197998
Training Evaluator ...
loss: 0.0025092610769206656
Training Generator ...
loss: 13.838748931884766
Training Evaluator ...
loss: 0.0017454047512728721
Training Generator ...
loss: 18.158031463623047
Training Evaluator ...
loss: 0.0027454954280983655
Training Generator ...
loss: 19.629297256

loss: 15.329361915588379
Training Evaluator ...
loss: 0.002337996521964669
Training Generator ...
loss: 17.225425720214844
Training Evaluator ...
loss: 0.0017221759771928191
Training Generator ...
loss: 14.31745719909668
Training Evaluator ...
loss: 0.0024513077456504106
Training Generator ...
loss: 18.15622329711914
Training Evaluator ...
loss: 0.0015132169864955358
Training Generator ...
loss: 11.941939353942871
Training Evaluator ...
loss: 0.0024546986183850094
Training Generator ...
loss: 13.19661808013916
Training Evaluator ...
loss: 0.0016992169475997798
Training Generator ...
loss: 14.386531829833984
Training Evaluator ...
loss: 0.0021198746428126468
Training Generator ...
loss: 16.795015335083008
Training Evaluator ...
loss: 0.0016650445802952162
Training Generator ...
loss: 11.429699897766113
Training Evaluator ...
loss: 0.0019288884301204233
Training Generator ...
loss: 15.864179611206055
Training Evaluator ...
loss: 0.0017372203670674934
Training Generator ...
loss: 17.66786

In [19]:
# test evaluator
for _ in range(16):
    metrics = get_random_metrics()
    metrics = torch.tensor(metrics).float().reshape(1, -1)
    noise = torch.rand(1, *board_shape)
    gen_board = generator(noise, metrics)
    board_mat = gen_board.argmax(-1).long()
    pred = evaluator(board_mat)
    board = gb.game.Board(board_mat.squeeze().numpy())
    num_bananas = count_bananas(board)
    print('num bananas: ', num_bananas, '\tprediction:', pred.item())

num bananas:  4 	prediction: 4.039809703826904
num bananas:  2 	prediction: 2.0237693786621094
num bananas:  4 	prediction: 4.03702974319458
num bananas:  2 	prediction: 2.0237693786621094
num bananas:  5 	prediction: 5.042510509490967
num bananas:  2 	prediction: 2.0237693786621094
num bananas:  4 	prediction: 4.03702974319458
num bananas:  2 	prediction: 2.0237693786621094
num bananas:  2 	prediction: 2.0237693786621094
num bananas:  2 	prediction: 2.0237693786621094
num bananas:  4 	prediction: 4.039809703826904
num bananas:  2 	prediction: 2.0237693786621094
num bananas:  2 	prediction: 2.0237693786621094
num bananas:  2 	prediction: 2.0237693786621094
num bananas:  3 	prediction: 3.0684990882873535
num bananas:  2 	prediction: 2.0237693786621094


In [20]:
# test generator 
for _ in range(16):
    metrics = get_random_metrics()
    print("Desired metrics: ", metrics)
    metrics = torch.tensor(metrics).float().reshape(1, -1)
    noise = torch.rand(1, *board_shape)
    board = generator(noise, metrics)
    print(board.argmax(-1))
    print()

Desired metrics:  [2]
tensor([[[1, 1, 2],
         [1, 2, 2],
         [0, 2, 0]]])

Desired metrics:  [1]
tensor([[[1, 0, 2],
         [1, 2, 2],
         [0, 2, 0]]])

Desired metrics:  [7]
tensor([[[1, 1, 0],
         [1, 2, 0],
         [0, 2, 0]]])

Desired metrics:  [0]
tensor([[[1, 0, 2],
         [1, 2, 2],
         [0, 2, 0]]])

Desired metrics:  [2]
tensor([[[2, 1, 2],
         [1, 2, 2],
         [0, 2, 0]]])

Desired metrics:  [8]
tensor([[[1, 1, 0],
         [1, 2, 0],
         [0, 2, 0]]])

Desired metrics:  [3]
tensor([[[1, 1, 2],
         [1, 2, 2],
         [0, 2, 0]]])

Desired metrics:  [7]
tensor([[[1, 1, 0],
         [1, 2, 0],
         [0, 2, 0]]])

Desired metrics:  [4]
tensor([[[1, 1, 2],
         [1, 2, 2],
         [0, 2, 0]]])

Desired metrics:  [8]
tensor([[[1, 1, 0],
         [1, 2, 0],
         [0, 2, 0]]])

Desired metrics:  [3]
tensor([[[1, 1, 2],
         [1, 2, 2],
         [0, 2, 0]]])

Desired metrics:  [6]
tensor([[[1, 1, 0],
         [1, 2, 0],
   