In [1]:
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

from src.model import MineSweeperDataset, PatchMLPModel, OnHotEncodingTransform, Game
from src.player import ThresholdPlayer

In [2]:
# Choose the compute backend in order of preference: CUDA first, then Apple's
# MPS backend, and finally the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [3]:
# --- Reproducibility --------------------------------------------------------
# Seed the global RNGs of the numeric libraries this notebook imports so that
# a "Restart Kernel -> Run All" yields the same torch-side randomness.
# NOTE(review): Game's own source of randomness is not visible from this
# notebook; if it uses something other than NumPy's / torch's global RNG it
# must be seeded separately.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Board configuration ----------------------------------------------------
# Positional arguments handed to Game for every batch of boards
# (presumably rows, columns, mines -- confirm against Game's signature).
BOARD_ARGS = (16, 30, 99)
OPEN_ARG = 0.5   # value passed to Game.random_open
FLAG_ARG = 0.3   # value passed to Game.random_flags


def make_games(n):
    """Create `n` games with the shared board configuration and apply the
    same `random_open` / `random_flags` preparation to them.

    Args:
        n: forwarded to Game as its `n` keyword argument.

    Returns:
        The prepared Game object.
    """
    games = Game(*BOARD_ARGS, n=n)
    games.random_open(OPEN_ARG)
    games.random_flags(FLAG_ARG)
    return games


# --- Datasets ---------------------------------------------------------------
train = make_games(5000)
# The test games previously relied on Game's default board arguments; they are
# now built with the same explicit configuration as the training games so both
# sets are guaranteed to share one board setup.
test = make_games(100)

transform = OnHotEncodingTransform(2)
training_data = MineSweeperDataset(train, transform)
test_data = MineSweeperDataset(test, transform)

# NOTE(review): the training loader iterates in dataset order (no shuffle);
# consider shuffle=True if MineSweeperDataset is a map-style dataset.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# --- Model, player and optimizer --------------------------------------------
model = PatchMLPModel(2, device)
player = ThresholdPlayer(model)
# The optimizer is attached to the parameters of the wrapped `model.model`.
optimizer = torch.optim.Adam(model.model.parameters(), lr=0.0001)

In [None]:
# Initial training: run a fixed number of passes, calling the model's train
# step on the training loader and then its test step on the held-out loader.
epochs = 5
for epoch_number in range(1, epochs + 1):
    print(f"Epoch {epoch_number}\n-------------------------------")
    model.train(train_dataloader, optimizer)
    model.test(test_dataloader)

In [None]:
# Continued training: on every iteration a fresh batch of games is generated,
# mixed into the existing training dataset, and the model is trained/tested
# for `epochs` more passes.
reinforcing_iterations = 10
epochs = 1

# Make sure the checkpoint directory exists before the first save so the loop
# does not depend on a pre-existing `weights/` folder.
# (presumably model.save writes straight to the given path -- confirm)
Path('weights').mkdir(parents=True, exist_ok=True)

for i in range(reinforcing_iterations):
    # Checkpoint is written BEFORE this iteration's training, so file `_{i}`
    # holds the weights as they were at the start of iteration i.
    model.save(f'weights/patch_mlp_4x4_200x4_{i}.pth')

    # Generate new games with the same preparation used for the initial data.
    games = Game(16, 30, 99, n = 500)
    games.random_open(0.5)
    games.random_flags(0.3)

    # `mix` updates training_data in place; the loader is rebuilt afterwards
    # so it is constructed from the updated dataset.
    training_data.mix(games)
    train_dataloader = DataLoader(training_data, batch_size=64)

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        model.train(train_dataloader, optimizer)
        model.test(test_dataloader)

In [6]:
# Save the final weights. The parent directory is created first so this cell
# also works when `weights/` does not exist yet; the path string handed to
# model.save is unchanged.
final_weights_path = 'weights/patch_mlp_4x4_200x4.pth'
Path(final_weights_path).parent.mkdir(parents=True, exist_ok=True)
model.save(final_weights_path)

In [None]:
import plotly_express as px
import pandas as pd

# Collect the model's recorded train/test loss logs into one frame.
# Each log is wrapped in a Series so that logs of different lengths are
# aligned by position and padded with NaN; building the frame from two raw
# lists of unequal length would raise a ValueError instead.
# (assumes both logs are flat sequences of scalar values -- confirm)
df = pd.DataFrame({
    'train': pd.Series(model.train_loss_log),
    'test': pd.Series(model.test_loss_log),
})

# Wide-form line plot: one line per column. Title and axis labels are set so
# the figure can be read on its own. The plot is the cell's last expression
# so the notebook displays it.
px.line(
    df,
    title='Train vs. test loss',
    labels={'index': 'log entry', 'value': 'loss', 'variable': 'split'},
)