In [7]:
!python -m pip install pysimplegui
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import PySimpleGUI as sg
import sys

sys.path.append("../")
from src.alpha_connect.connect4_game import Connect4Game
from src.alpha_connect.connect4_supervised_agent import Connect4SupervisedAgent
from src.alpha_connect.connect4_random_agent import Connect4RandomAgent
from src.alpha_connect.connect4_MCTS_agent import Connect4MCTSAgent
from src.alpha_connect.connect4_human_agent import Connect4HumanAgent


def create_connect4_board(board_rows, board_columns):
    """Create the Connect 4 game board layout using circles."""
    cell_size = 60  # Size of each cell in the grid
    circle_radius = 25  # Radius of each circle

    # Calculate the size of the graph canvas
    graph_canvas_size = (board_columns * cell_size, board_rows * cell_size)

    players = {
        "Human": Connect4HumanAgent(),
        "Supervised": Connect4SupervisedAgent("epoch=72-step=18688.ckpt"),
        "Random": Connect4RandomAgent(),
        "MCTS": Connect4MCTSAgent(),
    }

    # Create a Graph element to draw the circles
    graph = sg.Graph(
        canvas_size=graph_canvas_size,
        graph_bottom_left=(0, 0),
        graph_top_right=graph_canvas_size,
        background_color="blue",
        key="-GRAPH-",
        enable_events=True,
    )

    # Define the layout for the menu on the right
    menu_layout = [
        [sg.Button("Reset", key="-RESET-")],
        [sg.Button("Play", key="-PLAY-")],
        [sg.Text("Player 1:")],
        [sg.DropDown(["Human", "Supervised", "Random", "MCTS"], key="-PLAYER1-")],
        [sg.Text("Player 2:")],
        [sg.DropDown(["Human", "Supervised", "Random", "MCTS"], key="-PLAYER2-")],
    ]

    # Combine the graph and menu layouts
    layout = [[graph, sg.Column(menu_layout)]]

    # Create the window
    window = sg.Window("Connect 4", layout, finalize=True)

    # Draw the initial board with empty circles (white)
    for row in range(board_rows):
        for col in range(board_columns):
            center_x = col * cell_size + cell_size // 2
            center_y = row * cell_size + cell_size // 2
            graph.DrawCircle(
                (center_x, center_y),
                circle_radius,
                fill_color="white",
                line_color="white",
            )

    game = None
    # Event loop
    while True:
        event, values = window.read()
        if event == sg.WIN_CLOSED:
            break
        elif event == "-GRAPH-":  # If a graph area is clicked
            if game is None:
                continue
            x, y = values["-GRAPH-"]
            col = x // cell_size
            Connect4HumanAgent.add_waiting_move(col)

            for row in range(board_rows):
                for col in range(board_columns):
                    center_x = col * cell_size + cell_size // 2
                    center_y = row * cell_size + cell_size // 2
                    graph.DrawCircle(
                        (center_x, center_y),
                        circle_radius,
                        fill_color="white",
                        line_color="white",
                    )

        elif event == "-RESET-":
            game = None

            for row in range(board_rows):
                for col in range(board_columns):
                    center_x = col * cell_size + cell_size // 2
                    center_y = row * cell_size + cell_size // 2
                    graph.DrawCircle(
                        (center_x, center_y),
                        circle_radius,
                        fill_color="white",
                        line_color="white",
                    )
            print("Reset the game logic here")

        elif event == "-PLAY-":
            agent1 = players.get(values["-PLAYER1-"], Connect4HumanAgent())
            agent2 = players.get(values["-PLAYER2-"], Connect4HumanAgent())
            print("Play button logic here")
            if not play_game(
                agent1,
                agent2,
                graph,
                cell_size,
                circle_radius,
                board_rows,
                board_columns,
                window,
            ):
                break  # Exit the game if the window is closed

    window.close()


def update_board(graph, cell_size, circle_radius, board_rows, board_columns, game):
    colors = {-1: "white", 0: "yellow", 1: "red"}
    for row in range(board_rows):
        for col in range(board_columns):
            center_x = col * cell_size + cell_size // 2
            center_y = row * cell_size + cell_size // 2
            color = colors[game.get_value(col, row)]
            graph.DrawCircle(
                (center_x, center_y), circle_radius, fill_color=color, line_color=color
            )


def play_game(
    agent1, agent2, graph, cell_size, circle_radius, board_rows, board_columns, window
):
    game = Connect4Game(agent1, agent2)
    agents = [agent1, agent2]
    current_player = 0
    while not game.has_ended():
        window.refresh()
        if not isinstance(agents[current_player], Connect4HumanAgent):
            print("Playing move")
            game = game.play()
            current_player = 1 - current_player
            update_board(
                graph, cell_size, circle_radius, board_rows, board_columns, game
            )
            continue

        event, values = window.read()
        if event == sg.WIN_CLOSED:
            return False
        elif event == "-RESET-":
            game = game.reset()
            update_board(
                graph, cell_size, circle_radius, board_rows, board_columns, game
            )
            return True
        elif event == "-GRAPH-":  # If a graph area is clicked
            if game is None:
                continue
            x, y = values["-GRAPH-"]
            col = x // cell_size
            Connect4HumanAgent.add_waiting_move(col)
            game = game.play()
            current_player = 1 - current_player

            update_board(
                graph, cell_size, circle_radius, board_rows, board_columns, game
            )

    update_board(graph, cell_size, circle_radius, board_rows, board_columns, game)

    return True


if __name__ == "__main__":
    # Constants for the board size
    BOARD_ROWS = 6
    BOARD_COLUMNS = 7
    create_connect4_board(BOARD_ROWS, BOARD_COLUMNS)

Play button logic here
Playing move
tensor([[-3.3900,  1.4151, -0.7727,  1.0921,  0.3966,  0.1480, -0.8059]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.6776]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-1.8944, -0.4326, -1.1811,  1.1263,  1.3570, -1.5859, -1.9001]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.5452]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.7454, -0.5442, -1.5982,  6.6076, -0.9766, -1.5362, -2.2517]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.2788]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.1409, -0.3146, -1.7426,  6.6933, -2.1274, -1.0571, -2.9472]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.2488]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.1393, -0.6401, -3.6431,  0.2323,  1.5146,  2.4806, -2.0246]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.2114]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.8555, -1.4122, -2.7683,  0.0673, -1.3129,  5.0176, -3.1633]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.3330]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.0811,  0.0656, -2.3206,  1.0158,  2.0044,  0.8863, -1.9635]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0515]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.0787,  0.4446, -1.9887,  2.4813, -1.0939, -0.1351,  0.1884]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.2472]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-1.9433,  1.4478, -0.2239, -2.7830, -1.0188,  1.2845,  0.5645]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.1492]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.8008,  3.8318,  0.0695, -2.4419, -1.9976,  0.5888, -0.3834]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0360]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-1.5750, -0.7747,  1.0313, -3.7737, -0.2400,  0.6451,  0.6129]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0105]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.0194,  0.8039, -1.7845, -1.3432,  3.6980, -0.8704, -0.8910]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0265]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.2443, -1.1455,  4.6406, -0.8291, -1.4125, -0.7027, -1.8874]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0181]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Reset the game logic here
Play button logic here
Playing move
tensor([[-3.3900,  1.4151, -0.7727,  1.0921,  0.3966,  0.1480, -0.8059]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.6776]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.1860,  3.1168, -1.5757,  3.0996, -1.5580, -0.9006, -2.3867]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.2060]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.5837, -0.9854, -1.2451,  6.3238, -0.7209, -1.4961, -1.9684]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0695]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.0099, -2.4110, -0.3512, -3.0398,  3.6543, -2.6820,  0.9838]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0919]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-1.8302,  0.6305, -2.1090,  0.4095,  1.6344,  0.3507, -1.6099]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0647]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.7976, -1.6782, -0.5966, -1.3508,  1.0789,  2.1603, -2.6481]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0177]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.2850,  6.0653, -0.9876, -0.8824, -0.6206, -1.0325, -1.7305]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0071]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.0054, -2.7719, -2.0073,  2.8900, -1.3653,  2.9763, -2.4697]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[-0.0031]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.5546, -1.3779, -1.5108,  6.1487, -0.4611, -1.3085, -1.4654]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0058]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.0705, -2.1716,  1.3392, -2.2771,  3.4107, -1.8925, -0.7168]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0079]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.1062, -2.9978,  5.1961, -2.0361, -1.0742, -0.8125, -1.0915]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0085]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.9654, -2.1298,  6.1674, -0.9059, -0.5436, -2.3358, -2.2308]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[-0.0004]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Play button logic here
Playing move
tensor([[-3.3900,  1.4151, -0.7727,  1.0921,  0.3966,  0.1480, -0.8059]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.6776]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.9814,  0.3948,  3.7997, -2.1420, -0.6837, -0.9028, -1.1323]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.3573]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.7663,  0.1751,  2.9421, -0.0601,  1.6050, -0.5511, -2.2661]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0005]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.6738, -0.4671, -1.0214,  4.8412, -1.0748, -0.6370, -2.0110]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.1359]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.9269,  0.1380, -2.4846,  2.4668,  2.6933, -1.6685, -1.2156]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[-0.0083]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.1607,  0.3903, -1.8808,  4.5856,  2.4571, -3.7486, -1.2315]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[-0.0132]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.2053, -0.7731, -1.5941,  3.1004,  3.3811, -3.8426, -1.1926]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[-0.0179]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.6607, -0.1523, -1.2436,  4.9361, -0.6045, -1.2129, -1.0851]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[-0.0218]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.1843, -0.2059, -0.0847, -2.1639,  3.7809, -2.6196, -1.1023]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[-0.0145]], device='mps:0', grad_fn=<TanhBackward0>)
Playing move


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.6954,  6.4308, -1.1031, -1.2585,  0.0663, -0.3118, -2.0501]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0064]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Play button logic here
Playing move
tensor([[-3.3900,  1.4151, -0.7727,  1.0921,  0.3966,  0.1480, -0.8059]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.6776]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.1860,  3.1168, -1.5757,  3.0996, -1.5580, -0.9006, -2.3867]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.2060]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.5837, -0.9854, -1.2451,  6.3238, -0.7209, -1.4961, -1.9684]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0695]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.3774,  1.7991, -1.7517,  3.9376, -1.8180, -0.7140, -1.6585]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0139]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.5720, -1.4995, -1.1233, -1.7001, -1.1468,  4.7646, -1.6084]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0279]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.5042, -1.2604,  2.1320, -0.5579,  2.8089, -2.2168, -2.0625]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0122]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.3992, -0.9107,  5.3708,  0.0522, -1.2673, -0.8088, -2.2122]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0059]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.4443, -0.6008,  5.2738, -2.7613,  1.3282, -1.4579, -1.1443]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[-0.0118]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Play button logic here
Playing move
tensor([[-3.3900,  1.4151, -0.7727,  1.0921,  0.3966,  0.1480, -0.8059]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.6776]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.1860,  3.1168, -1.5757,  3.0996, -1.5580, -0.9006, -2.3867]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.2060]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.9605,  2.6615, -1.8449,  3.4286, -1.4444, -1.3391, -1.7576]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[-0.0197]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Play button logic here
Playing move
tensor([[-3.3900,  1.4151, -0.7727,  1.0921,  0.3966,  0.1480, -0.8059]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.6776]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.1860,  3.1168, -1.5757,  3.0996, -1.5580, -0.9006, -2.3867]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.2060]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.2898, -1.3915,  1.3813, -3.1526, -2.6743,  2.7964, -0.0269]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0150]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-3.6539, -1.0617, -1.3744,  5.6046, -1.2006, -0.1660, -1.6556]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.1311]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-4.1939, -0.4299,  2.8156, -1.1875,  1.9207, -1.8327, -1.4473]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0011]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.8371, -0.8527,  5.3941, -1.6142, -0.0806, -1.2201, -1.7808]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0493]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[-2.8906, -0.7439,  2.0297,  2.3487, -0.4608, -0.4669, -2.4620]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.1316]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Playing move
tensor([[ 2.1174, -2.1386, -0.1550, -2.0652, -0.8072, -0.9052, -3.0463]],
       device='mps:0', grad_fn=<LinearBackward0>) tensor([[0.0941]], device='mps:0', grad_fn=<TanhBackward0>)


  y_hat,value = self.model(torch.tensor(input, device='mps').type(torch.float32).view(1,3,6,7))


Reset the game logic here
