In [1]:
import os
import numpy as np
from tqdm import tqdm

BOARD_WIDTH = 7
BOARD_HEIGHT = 6
MAX_BOARDS = 100_000_000

In [2]:
def convert_moves_to_game_board(moves):
    moves_by_column = {}

    curr_player = 1
    for move in moves:
        col_index = int(move) - 1

        if col_index not in moves_by_column:
            moves_by_column[col_index] = []

        moves_by_column[col_index].append(curr_player)

        if curr_player == 1:
            curr_player = -1
        else:
            curr_player = 1

    game_board = np.zeros((BOARD_WIDTH, BOARD_HEIGHT), dtype=np.int8)
    for col_index, moves in moves_by_column.items():
        for row_index, move in enumerate(moves):
            game_board[col_index][row_index] = move

    return game_board

def get_col_scores(entries):
    col_scores = np.zeros(BOARD_WIDTH)

    for col_index, col_score in enumerate(entries):
        col_scores[col_index] = int(col_score)

    return col_scores


In [3]:
from scipy.signal import convolve2d

horizontal_kernel = np.array([[ 1, 1, 1, 1]])
vertical_kernel = np.transpose(horizontal_kernel)
diag1_kernel = np.eye(4, dtype=np.uint8)
diag2_kernel = np.fliplr(diag1_kernel)
detection_kernels = [horizontal_kernel, vertical_kernel, diag1_kernel, diag2_kernel]

def check_for_valid_game_board(game_board, player):
    for kernel in detection_kernels:
        if (convolve2d(game_board == player, kernel, mode="valid") == 4).any():
            return False
    return True

In [4]:
print("Counting lines")
num_lines = sum(1 for _ in open("answers20.txt"))
available_game_boards = np.array([i for i in range(num_lines)])

print("Randomly selecting game boards")
selected_game_boards = set(np.random.choice(available_game_boards, size=MAX_BOARDS, replace=False))
del available_game_boards

game_boards = np.zeros((MAX_BOARDS, BOARD_WIDTH, BOARD_HEIGHT), dtype=np.int8)
game_results = np.zeros((MAX_BOARDS, BOARD_WIDTH), dtype=np.int8)
with open("answers20.txt", "r") as f:
    curr_row = 0
    board_index = 0

    pbar = tqdm(total=num_lines, desc="Converting moves to game boards")
    for row in f:
        if curr_row in selected_game_boards:
            entries = row.split(" ")

            game_board = convert_moves_to_game_board(entries[0])

            # if check_for_valid_game_board(game_board, 1) and check_for_valid_game_board(game_board, 2):
            game_boards[board_index] = game_board

            col_scores = get_col_scores(entries[1:])
            max_value = np.max(col_scores)
            game_results[board_index] = (col_scores == max_value)

            board_index += 1

        curr_row += 1
        pbar.update(1)

print(len(game_boards))
del selected_game_boards


Counting lines
Randomly selecting game boards


Converting moves to game boards: 100%|█████████▉| 206060006/206062531 [15:49<00:00, 217485.09it/s]

100000000


In [5]:
print(game_results[0])
print(game_results[46456])
print(game_results[2000])

[0 1 0 1 0 0 0]
[0 0 0 1 0 0 0]
[0 0 0 1 0 0 0]


In [6]:
import torch

class Connect4Dataset(torch.utils.data.Dataset):
    def __init__(self, game_boards, game_results):
        self.data = torch.from_numpy(game_boards)
        self.targets = torch.from_numpy(game_results)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

In [7]:
config = {
    "num_epochs": 30,
    "learning_rate": 1e-8,
    "weight_decay": 1e-2,
    "clip_quantile": 0.75,
    "max_lr": 1e-3,
    "batch_size": 256,
    "test_prop": 0.1,
    "val_prop": 0.1,
    "device": "cuda:0",
}

In [8]:
import torch
from autoclip.torch import QuantileClip

model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),
    torch.nn.ReLU(),
    torch.nn.BatchNorm2d(32),
    torch.nn.Conv2d(32, 64, kernel_size=5, padding=2),
    torch.nn.ReLU(),
    torch.nn.BatchNorm2d(64),
    torch.nn.Conv2d(64, 128, kernel_size=5, padding=1),
    torch.nn.ReLU(),
    torch.nn.BatchNorm2d(128),
    torch.nn.Conv2d(128, 256, kernel_size=5, padding=1),
    torch.nn.ReLU(),
    torch.nn.BatchNorm2d(256),
    torch.nn.Flatten(),
    torch.nn.Linear(256 * 3 * 2, 512),
    torch.nn.LeakyReLU(),
    torch.nn.BatchNorm1d(512),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(512, 128),    
    torch.nn.LeakyReLU(),
    torch.nn.BatchNorm1d(128),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(128, 7),
).to(config["device"])


In [9]:
# Make Data Loaders
num_training_examples = int(len(game_boards) * (1 - config["test_prop"]))
train_dataset = Connect4Dataset(game_boards[:num_training_examples], game_results[:num_training_examples])
test_dataset = Connect4Dataset(game_boards[num_training_examples:], game_results[num_training_examples:])

# Clear Memory
del game_boards
del game_results

In [10]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=10,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=10,
)

In [11]:
optimizer = torch.optim.AdamW(
    model.parameters(),
)
optimizer = QuantileClip.as_optimizer(optimizer, config["clip_quantile"])
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config["max_lr"],
    steps_per_epoch=len(train_loader),
    epochs=config["num_epochs"],
)

In [12]:
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler

positive_examples = torch.sum(train_dataset.targets)
negative_examples = (len(train_dataset.targets)*BOARD_WIDTH) - positive_examples
loss_weights =  negative_examples / positive_examples

loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_weights).to(config["device"]))
scalar = GradScaler()

  loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_weights).to(config["device"]))


In [13]:
# Setup Training Loop
pbar = tqdm(total=config["num_epochs"], desc="Training")
device = config["device"]
bssf = None
bssf_loss = torch.inf

for epoch in range(config["num_epochs"]):
    print(f'--- Epoch {epoch+1}/{config["num_epochs"]} ---')

    # Train model
    model.train()
    training_losses = []
    training_accuracies = []
    training_confusion_matrix = np.zeros((2, 2), dtype=np.uint32)
    p_bar = tqdm(train_loader, desc="Training", position=0)
    for data, target in p_bar:
        data = data.to(device).float().unsqueeze(1)
        target = target.to(device).float()
        optimizer.zero_grad()

        with autocast():
            logits = model(data)
            loss = loss_fn(logits, target)


        scalar.scale(loss).backward()
        scalar.step(optimizer)
        scalar.update()
        lr_scheduler.step()

        training_losses.append(loss.item())


        predictions = (logits > 0).detach().cpu().numpy()
        detached_targets = target.detach().cpu().numpy()
        is_correct = (predictions == detached_targets)

        accuracy = np.mean(is_correct)
        training_accuracies.append(accuracy)

        # Calculate F1
        training_confusion_matrix[0,0] += np.sum(np.logical_and(predictions == 0, detached_targets == 0))
        training_confusion_matrix[0,1] += np.sum(np.logical_and(predictions == 0, detached_targets == 1))
        training_confusion_matrix[1,0] += np.sum(np.logical_and(predictions == 1, detached_targets == 0))
        training_confusion_matrix[1,1] += np.sum(np.logical_and(predictions == 1, detached_targets == 1))

        f1 = (2*training_confusion_matrix[1,1] / (2*training_confusion_matrix[1,1] + training_confusion_matrix[0,1] + training_confusion_matrix[1,0]))


        p_bar.set_postfix(
            {
                "loss": f"{sum(training_losses)/len(training_losses):.4f}",
                "accuracy": f"{sum(training_accuracies)/len(training_accuracies):.4f}",
                "f1": f"{f1:.4f}",
            }
        )

    model.eval()
    with torch.no_grad():
        test_losses = []
        test_accuracies = []
        test_confusion_matrix = np.zeros((2, 2), dtype=np.uint32)
        p_bar = tqdm(test_loader, desc="Test", position=0)
        for data, target in p_bar:
            data = data.to(device).float().unsqueeze(1)
            target = target.to(device).float()
            optimizer.zero_grad()

            with autocast():
                logits = model(data)
                loss = loss_fn(logits, target)

            test_losses.append(loss.item())

            # Calulate Accuracy and F1
            predictions = (logits > 0).detach().cpu().numpy()
            detached_targets = target.detach().cpu().numpy()
            is_correct = (predictions == detached_targets)

            accuracy = np.mean(is_correct)
            test_accuracies.append(accuracy)

            # Calculate F1
            test_confusion_matrix[0,0] += np.sum(np.logical_and(predictions == 0, detached_targets == 0))
            test_confusion_matrix[0,1] += np.sum(np.logical_and(predictions == 0, detached_targets == 1))
            test_confusion_matrix[1,0] += np.sum(np.logical_and(predictions == 1, detached_targets == 0))
            test_confusion_matrix[1,1] += np.sum(np.logical_and(predictions == 1, detached_targets == 1))

            f1 = (2*test_confusion_matrix[1,1] / (2*test_confusion_matrix[1,1] + test_confusion_matrix[0,1] + test_confusion_matrix[1,0]))


            p_bar.set_postfix(
                {
                    "loss": f"{sum(test_losses)/len(test_losses):.4f}",
                    "accuracy": f"{sum(test_accuracies)/len(test_accuracies):.4f}",
                    "f1": f"{f1:.4f}",
                }
            )

        if sum(test_losses)/len(test_losses) < bssf_loss:
            bssf_loss = sum(test_losses)/len(test_losses)
            bssf = model
        


Converting moves to game boards: 100%|██████████| 206062531/206062531 [15:55<00:00, 215619.59it/s]


--- Epoch 1/30 ---


Training: 100%|██████████| 351563/351563 [1:13:20<00:00, 79.88it/s, loss=0.1956, accuracy=0.9383, f1=0.8816]
Test: 100%|██████████| 39063/39063 [01:29<00:00, 436.15it/s, loss=0.1312, accuracy=0.9595, f1=0.9252]


--- Epoch 2/30 ---


Training: 100%|██████████| 351563/351563 [1:15:03<00:00, 78.06it/s, loss=0.1388, accuracy=0.9570, f1=0.9158]
Test: 100%|██████████| 39063/39063 [01:29<00:00, 435.09it/s, loss=0.1197, accuracy=0.9623, f1=0.9303]


--- Epoch 3/30 ---


Training: 100%|██████████| 351563/351563 [1:14:43<00:00, 78.42it/s, loss=0.1310, accuracy=0.9596, f1=0.9207]
Test: 100%|██████████| 39063/39063 [01:29<00:00, 437.15it/s, loss=0.1167, accuracy=0.9646, f1=0.9343]


--- Epoch 4/30 ---


Training: 100%|██████████| 351563/351563 [1:14:52<00:00, 78.25it/s, loss=0.1282, accuracy=0.9606, f1=0.9225]
Test: 100%|██████████| 39063/39063 [01:29<00:00, 434.20it/s, loss=0.1148, accuracy=0.9645, f1=0.9341]


--- Epoch 5/30 ---


Training: 100%|██████████| 351563/351563 [1:15:50<00:00, 77.25it/s, loss=0.1276, accuracy=0.9609, f1=0.9230]
Test: 100%|██████████| 39063/39063 [01:31<00:00, 425.03it/s, loss=0.1157, accuracy=0.9647, f1=0.9344]


--- Epoch 6/30 ---


Training: 100%|██████████| 351563/351563 [1:17:34<00:00, 75.53it/s, loss=0.1275, accuracy=0.9609, f1=0.9231]
Test: 100%|██████████| 39063/39063 [01:46<00:00, 368.23it/s, loss=0.1156, accuracy=0.9649, f1=0.9347]


--- Epoch 7/30 ---


Training: 100%|██████████| 351563/351563 [1:19:53<00:00, 73.33it/s, loss=0.1274, accuracy=0.9609, f1=0.9231]  
Test: 100%|██████████| 39063/39063 [01:49<00:00, 355.65it/s, loss=0.1152, accuracy=0.9645, f1=0.9341]


--- Epoch 8/30 ---


Training: 100%|██████████| 351563/351563 [1:18:14<00:00, 74.89it/s, loss=0.1272, accuracy=0.9610, f1=0.9233] 
Test: 100%|██████████| 39063/39063 [01:43<00:00, 377.29it/s, loss=0.1141, accuracy=0.9628, f1=0.9315]


--- Epoch 9/30 ---


Training:   6%|▌         | 20096/351563 [02:58<48:59, 112.74it/s, loss=0.1266, accuracy=0.9611, f1=0.9235] 


KeyboardInterrupt: 

In [17]:
bssf.eval()
with torch.no_grad():
    test_losses = []
    test_accuracies = []
    test_confusion_matrix = np.zeros((2, 2), dtype=np.uint32)
    p_bar = tqdm(test_loader, desc="Test", position=0)
    for data, target in p_bar:
        data = data.to(device).float().unsqueeze(1)
        print(data.shape)
        target = target.to(device).float()
        optimizer.zero_grad()

        with autocast():
            logits = bssf(data)
            loss = loss_fn(logits, target)

        test_losses.append(loss.item())

        # Calulate Accuracy and F1
        predictions = (logits > 0).detach().cpu().numpy()
        detached_targets = target.detach().cpu().numpy()
        is_correct = (predictions == detached_targets)

        accuracy = np.mean(is_correct)
        test_accuracies.append(accuracy)

        # Calculate F1
        test_confusion_matrix[0,0] += np.sum(np.logical_and(predictions == 0, detached_targets == 0))
        test_confusion_matrix[0,1] += np.sum(np.logical_and(predictions == 0, detached_targets == 1))
        test_confusion_matrix[1,0] += np.sum(np.logical_and(predictions == 1, detached_targets == 0))
        test_confusion_matrix[1,1] += np.sum(np.logical_and(predictions == 1, detached_targets == 1))

        f1 = (2*test_confusion_matrix[1,1] / (2*test_confusion_matrix[1,1] + test_confusion_matrix[0,1] + test_confusion_matrix[1,0]))


        p_bar.set_postfix(
            {
                "loss": f"{sum(test_losses)/len(test_losses):.4f}",
                "accuracy": f"{sum(test_accuracies)/len(test_accuracies):.4f}",
                "f1": f"{f1:.4f}",
            }
        )

Test:   0%|          | 44/39063 [00:00<09:26, 68.91it/s, loss=0.1218, accuracy=0.9656, f1=0.9359] 

torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
t

Test:   0%|          | 119/39063 [00:01<03:42, 174.97it/s, loss=0.1216, accuracy=0.9655, f1=0.9358]

torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
t

Test:   1%|          | 215/39063 [00:01<02:09, 300.21it/s, loss=0.1229, accuracy=0.9652, f1=0.9349]

torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
t

Test:   1%|          | 311/39063 [00:01<01:41, 380.92it/s, loss=0.1226, accuracy=0.9652, f1=0.9350]

torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
t

Test:   1%|          | 399/39063 [00:01<01:39, 387.14it/s, loss=0.1229, accuracy=0.9652, f1=0.9351]

torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
t

Test:   1%|▏         | 492/39063 [00:01<01:33, 412.83it/s, loss=0.1229, accuracy=0.9652, f1=0.9353]

torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
t

Test:   2%|▏         | 587/39063 [00:02<01:27, 441.59it/s, loss=0.1232, accuracy=0.9650, f1=0.9349]

torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
t

Test:   2%|▏         | 633/39063 [00:02<01:26, 446.61it/s, loss=0.1234, accuracy=0.9649, f1=0.9347]

torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
torch.Size([256, 1, 7, 6])
t

Test:   2%|▏         | 656/39063 [00:02<02:20, 273.05it/s, loss=0.1234, accuracy=0.9649, f1=0.9347]


KeyboardInterrupt: 

In [15]:
import datetime
import os

now = datetime.datetime.now()
torch.save(bssf.state_dict(), f"model_{now.minute}_{now.hour}_{now.day}_{now.month}.pt")


In [16]:
import json

file_name = f"model_{now.minute}_{now.hour}_{now.day}_{now.month}.json"
with open(file_name, "w") as f:
    metrics = {"accuracy": sum(test_accuracies)/len(test_accuracies), "f1": f1}
    json.dump(metrics, f)

In [None]:
import torch

# model.to("cpu")
# saved_model_dict = torch.load("model_1_0_2_12.pt")
# model.load_state_dict(saved_model_dict)

<All keys matched successfully>

In [None]:
# game_board = np.zeros((BOARD_WIDTH, BOARD_HEIGHT))




In [None]:
# game_board[6, 3] = 2
# print(game_board)

[[0. 0. 0. 0. 0. 0.]
 [1. 2. 0. 0. 0. 0.]
 [2. 1. 2. 2. 2. 1.]
 [2. 1. 2. 1. 1. 2.]
 [2. 2. 1. 2. 1. 1.]
 [1. 2. 2. 1. 1. 0.]
 [1. 2. 1. 2. 0. 0.]]


In [None]:
# results = model(torch.tensor(game_board.reshape(1, 1, BOARD_WIDTH, BOARD_HEIGHT)).float())
# print(results)
# print(torch.argmax(results))

tensor([[  3.8322,   8.6084, -17.3531,  -3.2702, -19.3175,   5.7995,  21.1401]],
       grad_fn=<AddmmBackward0>)
tensor(6)
