In [10]:
import json

import numpy as np
import matplotlib.pyplot as plt
import torch

from src.game import Game
from src.models.patch_mlp import PatchMLPModel
from src.models.unet import UnetModel
from src.models.conv import ConvModel
from src.player import *

In [11]:

# Pick the compute device in order of preference: CUDA first, then MPS,
# and fall back to the CPU when neither backend reports itself available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

cpu


In [12]:
# Destination of the JSON summary of per-model win rates.
result_file = 'weights/scores.json'

# Positional arguments passed to Game(r, c, m, n).
# NOTE(review): presumably rows, columns, mines and number of games --
# confirm against the Game constructor.
r = 16
c = 30
m = 99
n = 100

In [13]:
# Candidate models to evaluate, as (model class, path to saved weights) pairs.
# Each entry holds the class itself, not an instance: the evaluation loop calls
# `model_class.load(path, device)` on it, hence `type[...]` in the annotation.
models: list[tuple[type[MinesweeperModel], str]] = [
    (ConvModel, 'weights/conv_3x3_64.pth'),
    (ConvModel, 'weights/conv_3x3_128.pth'),
    (ConvModel, 'weights/conv_5x5_64.pth'),
    (ConvModel, 'weights/conv_5x5_128.pth'),
    (ConvModel, 'weights/conv_7x7_64.pth'),
    (PatchMLPModel, 'weights/patch_mlp_7x7_512.pth'),
    (PatchMLPModel, 'weights/patch_mlp_7x7_1024.pth'),
    (UnetModel, 'weights/unet_16x30_64.pth'),
]

In [14]:
# Accumulators for the evaluation: win rate per weights path, and the best
# win rate seen so far.
win_rates = dict()
max_win_rate = 0

# A single Game object is shared by every evaluation run below.
games = Game(r, c, m, n)
# The result of open_zero() is kept so it can be replayed through
# games.move() after each games.reset().
zeros = games.open_zero()

In [15]:
# Evaluate every candidate on the shared `games` object and remember the best.
# The running best is reset here so that re-running this cell starts fresh
# instead of comparing against a stale maximum left over from an earlier run.
max_win_rate = 0
best_model = None
for model_class, path in models:
    model = model_class.load(path, device)
    player = ThresholdPlayer(model)
    # Reset the games and apply the precomputed `zeros` move before the
    # player takes over.
    games.reset()
    games.move(zeros)
    player.play(games)
    win_rate = games.win_rate()
    win_rates[path] = win_rate
    print(f'{path} win rate: {win_rate}')
    # `best_model is None` guarantees that a model is selected even when every
    # win rate is 0 (otherwise `best_model` would stay undefined for later
    # cells); the strict `>` keeps the earliest model on ties.
    if best_model is None or win_rate > max_win_rate:
        max_win_rate = win_rate
        best_model = model

weights/conv_3x3_64.pth win rate: 0.52
weights/conv_3x3_128.pth win rate: 0.43
weights/conv_5x5_64.pth win rate: 0.44
weights/conv_5x5_128.pth win rate: 0.4
weights/conv_7x7_64.pth win rate: 0.43
weights/patch_mlp_7x7_512.pth win rate: 0.41
weights/patch_mlp_7x7_1024.pth win rate: 0.35
weights/unet_16x30_64.pth win rate: 0.39


In [16]:
# Persist `n` and the per-model win rates as an indented JSON document.
# Plain 'w' is sufficient because nothing is read back from the handle, and
# the encoding is pinned so the file does not depend on the platform default.
with open(result_file, 'w', encoding='utf-8') as fp:
    json.dump({'n': n, 'winRates': win_rates}, fp, indent=4)

In [17]:
# Re-run the best model with several threshold pairs and report the win rate
# obtained with each pair.
threshold_pairs = [(0.05, 0.95), (0.02, 0.98), (0.015, 0.985), (0.01, 0.99)]
for thresholds in threshold_pairs:
    lower, upper = thresholds
    player = ThresholdPlayer(best_model, lower, upper)
    # Same preparation as in the model comparison: reset, then replay `zeros`.
    games.reset()
    games.move(zeros)
    player.play(games)
    win_rate = games.win_rate()
    print(f'{thresholds} win rate: {win_rate}')

(0.05, 0.95) win rate: 0.41
(0.02, 0.98) win rate: 0.49
(0.015, 0.985) win rate: 0.52
(0.01, 0.99) win rate: 0.52
