In [1]:
from connect_four.connect_four import ConnectFour, Agent
from connect_four.connect_four_dataset import (
    ConnectFourDataset, CharConnectFourDataset, DatasetPreprocessingConfig
)
import pickle
import torch
from mingpt.utils import sample
from mingpt.model import GPT, GPTConfig

# Run on the GPU when one is available, otherwise fall back to the CPU so the notebook
# still works on machines without CUDA (torch.cuda.current_device() raises there).
# A torch.device (not a bare int index) is also accepted by torch.load(map_location=...).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Load the pre-generated game transcriptions.
# NOTE(review): pickle.load can execute arbitrary code — only open dataset files you trust.
with open("connect_four/dataset/dataset_6x7_110000.pkl", mode="rb") as transcript_file:
    game_transcriptions = pickle.load(transcript_file)

# Wrap the transcriptions in the project dataset, then in CharConnectFourDataset,
# whose `.config` supplies the vocabulary size and block size used to build the model.
cf_data = ConnectFourDataset(
    data_size=0,
    train_size=100_000,
    games_to_use=game_transcriptions,
)
char_cf_dataset = CharConnectFourDataset(cf_data)

Dataset created has 100000 sequences, 8 unique words.


In [3]:
class GPTAgent(Agent):
    """Connect Four agent that asks a trained GPT to continue the current game history.

    The game's ``history`` is encoded with ``preprocessing_config.to_model_repr``.
    ``mingpt.utils.sample`` generates one more token, and that token is decoded with
    ``preprocessing_config.from_model_repr`` into the move to play.
    """

    def __init__(
        self,
        model: GPT,
        game: ConnectFour,
        preprocessing_config: DatasetPreprocessingConfig,
    ):
        """Store the model, the game being played and the token mapping.

        Args:
            model: trained GPT whose vocabulary matches ``preprocessing_config``.
            game: the game instance the agent plays in. It is referenced, not copied,
                so moves made on it are seen by later ``choose_move`` calls.
            preprocessing_config: provides the ``to_model_repr`` / ``from_model_repr``
                mappings between game-history entries and model token ids.
        """
        self.model = model
        self.game = game
        self.config = preprocessing_config

    def choose_move(self) -> int:
        """Predict the next move from the current game history.

        Returns:
            The value ``from_model_repr`` maps the generated token to (the move).
            NOTE(review): the prediction is not checked against the board, so it
            may be an illegal move (e.g. a full column).

        Raises:
            ValueError: if the game history is empty. The model needs at least one
                token of context, so this agent cannot make the opening move.
        """
        tokens = [self.config.to_model_repr[s] for s in self.game.history]
        if not tokens:
            raise ValueError(
                "GPTAgent needs at least one move in the game history to predict a reply"
            )
        # Put the context on the same device as the model's weights instead of relying
        # on a notebook-level `device` global being defined (and matching the model).
        model_device = next(self.model.parameters()).device
        context = torch.tensor(tokens, dtype=torch.long, device=model_device).unsqueeze(0)
        # Presumably `sample` returns the context followed by the generated token(s).
        # With steps=1, the last entry of row 0 is the model's move, so only it is decoded.
        generated = sample(self.model, context, 1, temperature=1.0)[0]
        return self.config.from_model_repr[int(generated[-1])]

In [4]:
game = ConnectFour()
# The architecture hyperparameters must match those the checkpoint was trained with,
# otherwise load_state_dict fails on mismatched parameter shapes.
mconf = GPTConfig(
    char_cf_dataset.config.vocab_size,
    char_cf_dataset.config.block_size,
    n_layer=2,
    n_head=8,
    n_embd=80,
)
model = GPT(mconf).to(device)
# Restore the weights straight onto the device the model now lives on. Without
# map_location, torch.load puts tensors back on the device they were saved from.
model.load_state_dict(
    torch.load(
        "./ckpts/gpt_at_20230618_093325.ckpt",
        map_location=next(model.parameters()).device,
    )
)
# Inference only: switch to eval mode so training-only layers such as dropout are off.
model.eval()
gpt_agent = GPTAgent(model, game, char_cf_dataset.config)

In [5]:
# Player 1 drops a piece into column 4. NOTE(review): confirm whether ConnectFour
# columns are 0- or 1-based.
# The GPT agent then predicts a reply from this one-move history; that prediction
# is the cell's displayed output.
game.make_move(column=4, piece=1)
gpt_agent.choose_move()

5