# Solving the Line Extending Game with Neural Networks

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import random
import warnings
import time
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

from nsai_experiments import line_extending_game_tools as lgt

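We'll first try a simple neural-network (NN) solution on the 5x5 subgrids, without reinforcement learning (RL). For each subgrid the question is whether the centre cell extends a line: that is, whether, in one of the eight compass directions, both the cell next to the centre and the cell beyond it are marked. We start by solving the problem the human way: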

In [3]:
def human_answer(subgrid):
    """Decide, the way a person would, whether the centre of a 5x5 subgrid extends a line.

    The centre is cell (2, 2). For each of the eight compass directions we look at
    the cell two steps away and the adjacent cell; if both are marked, a mark in the
    centre would extend that two-cell line. A line that merely passes *through* the
    centre (marked cells on opposite sides) does not count.

    Args:
        subgrid: 5x5 grid of truthy/falsy cells, indexable as ``subgrid[row, col]``
            (e.g. a 2-D NumPy boolean array).

    Returns:
        True if some direction has both of its cells marked, otherwise False.
    """
    # (row step, col step) away from the centre, in the order N, NE, E, SE, S, SW, W, NW.
    compass = ((-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1))
    centre = 2
    for d_row, d_col in compass:
        # The far cell is tested first; the near cell is only read if the far one is marked.
        if subgrid[centre + 2 * d_row, centre + 2 * d_col] and subgrid[centre + d_row, centre + d_col]:
            return True
    return False

In [4]:
# Hand-built test cases: the first eight each mark a two-cell line ending next to
# the centre in one compass direction (N, NE, E, SE, S, SW, W, NW); the last two
# surround the centre without forming such a line.
grids = [
    lgt.create_grid("""
    - - x - -
    - - x - -
    - - - - -
    - - - - -
    - - - - -
    """),
    lgt.create_grid("""
    - - - - x
    - - - x -
    - - - - -
    - - - - -
    - - - - -
    """),
    lgt.create_grid("""
    - - - - -
    - - - - -
    - - - x x
    - - - - -
    - - - - -
    """),
    lgt.create_grid("""
    - - - - -
    - - - - -
    - - - - -
    - - - x -
    - - - - x
    """),
    lgt.create_grid("""
    - - - - -
    - - - - -
    - - - - -
    - - x - -
    - - x - -
    """),
    lgt.create_grid("""
    - - - - -
    - - - - -
    - - - - -
    - x - - -
    x - - - -
    """),
    lgt.create_grid("""
    - - - - -
    - - - - -
    x x - - -
    - - - - -
    - - - - -
    """),
    lgt.create_grid("""
    x - - - -
    - x - - -
    - - - - -
    - - - - -
    - - - - -
    """),
    lgt.create_grid("""
    x x x x x
    x - - - x
    x - - - x
    x - - - x
    x x x x x
    """),
    lgt.create_grid("""
    - - - - -
    - x x x -
    - x - x -
    - x x x -
    - - - - -
    """),
    ]
# Check the answers instead of only eyeballing them, so a fresh Run All catches regressions.
expected_answers = [True] * 8 + [False] * 2
human_answers = [human_answer(grid) for grid in grids]
assert human_answers == expected_answers, human_answers
human_answers

[True, True, True, True, True, True, True, True, False, False]

In [5]:
def flatten_subgrid(subgrid):
    """Flatten an odd-by-odd 2-D grid to 1-D in row-major order, dropping its centre cell.

    Args:
        subgrid: 2-D array with an odd number of rows and an odd number of columns.

    Returns:
        A new 1-D array of ``rows * cols - 1`` entries with the same dtype as ``subgrid``.

    Raises:
        AssertionError: if either dimension is even (there is no single centre cell).
    """
    n_rows, n_cols = subgrid.shape
    assert n_rows % 2 == 1 and n_cols % 2 == 1
    cells = subgrid.flatten()
    # With both dimensions odd, the centre's row-major index is exactly half the cell count:
    # n_cols*(n_rows//2) + n_cols//2 == (n_rows*n_cols)//2.
    centre = cells.size // 2
    return np.concatenate((cells[:centre], cells[centre + 1:]))
# Sanity check: a grid whose only marked cell is the centre flattens to 24 cells
# (exactly one removed), none of them marked. Checking the length as well catches
# a flatten that drops too many cells, which a zero-sum check alone would miss.
centre_only_flat = flatten_subgrid(lgt.create_grid("""
    - - - - -
    - - - - -
    - - x - -
    - - - - -
    - - - - -
    """))
assert centre_only_flat.shape == (24,)
assert not centre_only_flat.any()

In [6]:
def create_subproblem():
    """Generate one random labelled example.

    Draws a 5x5 boolean grid in which each cell is marked independently with
    probability 0.25, using NumPy's global RNG (so ``np.random.seed`` makes the
    draw reproducible).

    Returns:
        (features, label): the grid flattened without its centre cell
        (via ``flatten_subgrid``) and the ``human_answer`` verdict for that grid.
    """
    board = np.random.choice([True, False], size=(5, 5), p=[0.25, 0.75])
    features = flatten_subgrid(board)
    label = human_answer(board)
    return features, label
create_subproblem()

(array([False, False, False,  True, False, False,  True, False, False,
        False, False,  True,  True, False, False, False,  True,  True,
        False, False, False, False, False,  True]),
 True)

Now we can generate training, validation, and test data (10,000 random examples, split 8,000 / 1,000 / 1,000):

In [7]:
# Seed NumPy's global RNG, which create_subproblem draws from, so the dataset is reproducible.
np.random.seed(47)
N_SAMPLES, N_TRAIN, N_VALID = 10_000, 8_000, 1_000
X_list, y_list = zip(*[create_subproblem() for _ in range(N_SAMPLES)])

X = torch.tensor(np.array(X_list), dtype=torch.float32)
y = torch.tensor(np.array(y_list), dtype=torch.float32).reshape(-1, 1)
# Fraction of positive labels (class balance). Vectorised mean instead of Python's
# builtin sum, which would iterate over all rows creating one tensor per step.
print(y.mean(dim=0))

VALID_END = N_TRAIN + N_VALID
X_train, y_train = X[:N_TRAIN], y[:N_TRAIN]
X_valid, y_valid = X[N_TRAIN:VALID_END], y[N_TRAIN:VALID_END]
X_test, y_test = X[VALID_END:], y[VALID_END:]
assert len(X_train) + len(X_valid) + len(X_test) == len(X) == N_SAMPLES

tensor([0.4005])


Our neural network will have a single fully connected hidden layer:

In [8]:
class MyNet(nn.Module):
    """One-hidden-layer binary classifier over a flattened subgrid.

    The input has ``n_rows * n_cols - 1`` features (the grid minus its centre
    cell). The hidden layer has twice that many ReLU units, and a single
    sigmoid output yields a value in (0, 1).
    """

    def __init__(self, n_rows, n_cols):
        super().__init__()
        n_inputs = n_rows * n_cols - 1  # centre cell is not an input
        width = 2 * n_inputs
        # Submodules are created hidden-before-output: with a fixed torch seed this
        # order determines the initial weights, so keep it.
        self.hidden = nn.Linear(n_inputs, width)
        self.relu = nn.ReLU()
        self.output = nn.Linear(width, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(output(relu(hidden(x)))); x's last dimension must be n_rows*n_cols - 1."""
        return self.sigmoid(self.output(self.relu(self.hidden(x))))

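The only special thing we're doing here is implementing L1 regularization by hand: we add the sum of the absolute values of all weight matrices (biases excluded), scaled by `l1_lambda`, to the loss in order to encourage sparse weights: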

In [9]:
def train(model, X_train, y_train, X_valid, y_valid, n_epochs = 50, batch_size = 10, lr = 0.1, l1_lambda = 0.002):
    """Train a binary classifier with minibatch SGD, BCE loss and an L1 weight penalty.

    Minibatches are taken in the fixed order of ``X_train`` (no shuffling). The L1
    term sums ``|w|`` over every parameter whose name ends in ``".weight"``, so
    biases are not penalised. It is added to the BCE loss with coefficient
    ``l1_lambda``. Validation accuracy is printed every ``max(1, n_epochs // 10)``
    epochs, starting with epoch 0.

    Args:
        model: module whose outputs lie in [0, 1] (required by ``nn.BCELoss``) and
            have the same shape as the targets.
        X_train, y_train: training inputs and float targets in [0, 1].
        X_valid, y_valid: validation inputs and targets, used only for the printed accuracy.
        n_epochs: number of passes over the training data.
        batch_size: examples per SGD step.
        lr: SGD learning rate.
        l1_lambda: coefficient of the L1 weight penalty.

    Returns:
        None; ``model`` is updated in place.
    """
    loss_fn = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr = lr)
    # For n_epochs < 10, n_epochs // 10 == 0 and the modulo below would raise
    # ZeroDivisionError; report every epoch in that case.
    report_every = max(1, n_epochs // 10)

    for epoch in range(n_epochs):
        model.train()
        for start_i in range(0, X_train.shape[0], batch_size):
            X_batch = X_train[start_i:start_i+batch_size]
            y_batch = y_train[start_i:start_i+batch_size]

            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)
            l1_norm = sum(val.abs().sum() for (name, val) in model.named_parameters() if name.endswith(".weight"))
            loss = loss + l1_lambda*l1_norm
            model.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % report_every == 0:
            model.eval()
            # Evaluation needs no gradients; avoid building an autograd graph over the validation set.
            with torch.no_grad():
                y_pred = model(X_valid)
            print(f"Epoch {epoch}:")
            print((y_pred.round() == y_valid).float().mean())

# Fix torch's RNG before building the network so its initial weights are reproducible.
torch.manual_seed(0)
mynet = MyNet(n_rows=5, n_cols=5)
train(mynet, X_train, y_train, X_valid, y_valid,
      n_epochs=50, batch_size=10, lr=0.1, l1_lambda=0.002)

Epoch 0:
tensor(0.7620)
Epoch 5:
tensor(0.9640)
Epoch 10:
tensor(1.)
Epoch 15:
tensor(1.)
Epoch 20:
tensor(1.)
Epoch 25:
tensor(1.)
Epoch 30:
tensor(1.)
Epoch 35:
tensor(1.)
Epoch 40:
tensor(1.)
Epoch 45:
tensor(1.)


In [10]:
mynet.eval()
# Inference only: disable autograd so no graph is built over the whole test set.
with torch.no_grad():
    y_pred = mynet(X_test)
print("Testing set:")
print((y_pred.round() == y_test).float().mean())

Testing set:
tensor(1.)


We get perfect accuracy on the test set! As the L1 penalty intended, only a few output weights are meaningfully nonzero: 9 of the 48 exceed 0.1 in magnitude.

In [11]:
# Output-layer weights as a flat NumPy vector, one entry per hidden unit. Use the
# named `.weight` attribute rather than relying on parameters() yielding it first.
output_weights = np.array(mynet.output.weight.detach()).reshape(-1)
with np.printoptions(precision = 2, suppress = True):
    print(output_weights)
# Number of hidden units the output effectively uses (|w| > 0.1), out of the total.
print((np.abs(output_weights) > 0.1).sum())
print(len(output_weights))

[ 0.   -0.    3.89  0.   -0.   -3.68  0.    3.9   0.    0.    0.    0.
 -0.   -0.    0.   -3.66  0.09 -0.    0.   -3.6  -0.   -0.   -0.    3.91
  0.   -0.    0.    0.    3.89 -0.   -0.    4.46  0.   -0.   -0.    0.
  0.    3.88 -0.    0.    0.    0.   -0.    0.    0.    0.    0.   -0.  ]
9
48


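Each of these nine output weights belongs to a hidden unit that has only a few nonzero input weights. The indices of those inputs (numbered as in the table below, with the centre cell removed) encode the pairs that the human would test for! Eight of the units each pick out exactly one of the eight direction pairs. The ninth, [2, 11, 15], mixes cells from the north, west and southwest directions.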

| col 0 | col 1 | col 2 | col 3 | col 4 |
| - | - | - | - | - |
| 0 | 1 | 2 | 3 | 4 |
| 5 | 6 | 7 | 8 | 9 |
| 10 | 11 | (centre) | 12 | 13 |
| 14 | 15 | 16 | 17 | 18 |
| 19 | 20 | 21 | 22 | 23 |

In [12]:
# Hidden-layer weights: one row per hidden unit, one column per input cell.
hidden_weights = np.array(mynet.hidden.weight.detach())
print(hidden_weights.shape)
print()
# For each hidden unit the output layer relies on (|w| > 0.1), list the input
# cells that unit reads with a weight above 0.1 in magnitude.
for unit, w_out in enumerate(output_weights):
    if abs(w_out) > 0.1:
        print(f"{w_out:.2f}")
        print(np.flatnonzero(np.abs(hidden_weights[unit]) > 0.1).tolist())
        print()

(48, 24)

3.89
[0, 6]

-3.68
[2, 7]

3.90
[4, 8]

-3.66
[10, 11]

-3.60
[15, 19]

3.91
[17, 23]

3.89
[16, 21]

4.46
[2, 11, 15]

3.88
[12, 13]



**TODO:** implement the reinforcement-learning (RL) version.