In [9]:
def check_win(board, player):
    """
    Check if the player can win in the next move and return all such moves.
    """
    winning_moves = []
    winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                            (0, 4, 8), (2, 4, 6)]  # diagonals
    for a, b, c in winning_combinations:
        if board[a] == board[b] == player and board[c] == ' ':
            winning_moves.append(c)
        if board[a] == board[c] == player and board[b] == ' ':
            winning_moves.append(b)
        if board[b] == board[c] == player and board[a] == ' ':
            winning_moves.append(a)
    return winning_moves

def check_block(board, player):
    """
    Check if the opponent can win in the next move and return all blocking moves.
    """
    opponent = 'O' if player == 'X' else 'X'
    return check_win(board, opponent)

def is_unblocked_line(board, player, a, b, c):
    """
    Check if a line (defined by indices a, b, c) is unblocked for the given player.
    """
    return ((board[a] == board[b] == player and board[c] == ' ') or
            (board[a] == board[c] == player and board[b] == ' ') or
            (board[b] == board[c] == player and board[a] == ' '))

def check_fork(board, player):
    """
    Check if the player can create a fork and return all such moves.
    """
    fork_moves = []
    winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                            (0, 4, 8), (2, 4, 6)]  # diagonals
    for i in range(9):
        if board[i] == ' ':
            board[i] = player  # Temporarily place the player's marker
            unblocked_lines = 0  # Count of unblocked lines
            for a, b, c in winning_combinations:
                if is_unblocked_line(board, player, a, b, c):
                    unblocked_lines += 1
            if unblocked_lines >= 2:
                fork_moves.append(i)
            board[i] = ' '  # Reset the board
    return fork_moves

def is_two_in_a_row(board, player, a, b, c):
    return ((board[a] == player and board[b] == player and board[c] == ' ') or
            (board[a] == player and board[c] == player and board[b] == ' ') or
            (board[b] == player and board[c] == player and board[a] == ' '))

# Update the check_block_fork function to include the additional logic
def check_block_fork(board, player):
    opponent = 'O' if player == 'X' else 'X'
    forks = check_fork(board, opponent)

    # Check if opponent has two opposite corners
    if board[0] == board[8] == opponent:
        if board[2] == ' ' or board[6] == ' ':
            to_rtn = []
            for i in [1, 3, 5, 7]:
                if board[i] == ' ': to_rtn.append(i)
            return to_rtn
    if board[2] == board[6] == opponent:
        if board[0] == ' ' or board[8] == ' ':
            to_rtn = []
            for i in [1, 3, 5, 7]:
                if board[i] == ' ': to_rtn.append(i)
            return to_rtn

    if len(forks) > 1:
        two_in_a_row_moves = []
        winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                                (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                                (0, 4, 8), (2, 4, 6)]  # diagonals
        for i in range(9):
            if board[i] == ' ':
                board[i] = player  # Temporarily place the player's marker
                for a, b, c in winning_combinations:
                    if is_two_in_a_row(board, player, a, b, c):
                        two_in_a_row_moves.append(i)
                        break  # No need to check further combinations for this move
                board[i] = ' '  # Reset the board
        # Return intersection of two_in_a_row_moves and forks
        return list(set(two_in_a_row_moves) & set(forks))
    if len(forks) == 1:
        return forks
    return []

def check_center(board):
    """
    Check if the center is free and return it as a move if it is.
    """
    return [4] if board[4] == ' ' else []

def check_opposite_corner(board, player):
    """
    Check if the opponent is in the corner, and the opposite corner is free, and return all such moves.
    """
    opponent = 'O' if player == 'X' else 'X'
    opposite_corners = [(0, 8), (2, 6), (6, 2), (8, 0)]
    moves = []
    for a, b in opposite_corners:
        if board[a] == opponent and board[b] == ' ':
            moves.append(b)
        if board[b] == opponent and board[a] == ' ':
            moves.append(a)
    return moves

def check_empty_corner(board):
    """
    Check for any empty corners and return all such moves.
    """
    corners = [0, 2, 6, 8]
    return [corner for corner in corners if board[corner] == ' ']

def check_empty_side(board):
    """
    Check for any empty sides and return all such moves.
    """
    sides = [1, 3, 5, 7]
    return [side for side in sides if board[side] == ' ']

# Update Main Function

def get_optimal_moves(board, player):
    """
    Get all optimal moves for the given board and player.
    """
    for check in [check_win, check_block, check_fork, check_block_fork]:
        moves = check(board, player)
        if moves:
            return moves

    for check in [check_center, check_empty_corner, check_empty_side]:
        moves = check(board)
        if moves:
            return moves

    for check in [check_opposite_corner]:
        moves = check(board, player)
        if moves:
            return moves

    return []  # Should never reach this point in a valid game of Tic-Tac-Toe

In [10]:
# Initialize list to store sequences of all finished games
finished_games = []

def is_winner(board, player):
    """
    Check if the player has won on the current board.
    """
    winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                            (0, 4, 8), (2, 4, 6)]  # diagonals
    for a, b, c in winning_combinations:
        if board[a] == board[b] == board[c] == player:
            return True
    return False

def simulate_game(board, move_sequence, next_player):
    """
    Simulate a game of Tic-Tac-Toe recursively.
    
    Parameters:
        board (list): The current game board.
        move_sequence (list): The sequence of moves made so far.
        next_player (str): The player to move next ('X' or 'O').
    """
    
    # Check for game over conditions (win or draw)
    if is_winner(board, 'X') or is_winner(board, 'O'):
        finished_games.append(move_sequence[:])
        return
    if ' ' not in board:
        finished_games.append(move_sequence[:])
        return
    
    # Optimal player's move
    if next_player == 'X':
        optimal_moves = get_optimal_moves(board, next_player)
        for move in optimal_moves:
            board[move] = next_player  # Make the move
            move_sequence.append(move)  # Record the move
            simulate_game(board, move_sequence, 'O')  # Recursive call
            board[move] = ' '  # Undo the move
            move_sequence.pop()  # Remove the last move from the sequence
    
    # All moves for the non-optimal player
    else:
        for move in range(9):
            if board[move] == ' ':
                board[move] = next_player  # Make the move
                move_sequence.append(move)  # Record the move
                simulate_game(board, move_sequence, 'X')  # Recursive call
                board[move] = ' '  # Undo the move
                move_sequence.pop()  # Remove the last move from the sequence

# Initialize board and move_sequence
initial_board = [' ' for _ in range(9)]
initial_move_sequence = []

# Start the simulation with 'X' going first
simulate_game(initial_board, initial_move_sequence, 'X')

# Show some of the finished games to verify correctness
finished_games[:10], len(finished_games)

([[4, 0, 2, 1, 6],
  [4, 0, 2, 3, 6],
  [4, 0, 2, 5, 6],
  [4, 0, 2, 6, 3, 1, 5],
  [4, 0, 2, 6, 3, 5, 8, 1, 7],
  [4, 0, 2, 6, 3, 5, 8, 7, 1],
  [4, 0, 2, 6, 3, 7, 5],
  [4, 0, 2, 6, 3, 8, 5],
  [4, 0, 2, 7, 6],
  [4, 0, 2, 8, 6]],
 488)

In [11]:
import pickle

# Save the list of lists to appropriate file
with open('finished_games.pkl', 'wb') as f:
    pickle.dump(finished_games, f)

# Load the list of lists from appropriate file
with open('finished_games.pkl', 'rb') as f:
    finished_games = pickle.load(f)

In [12]:
finished_games

[[4, 0, 2, 1, 6],
 [4, 0, 2, 3, 6],
 [4, 0, 2, 5, 6],
 [4, 0, 2, 6, 3, 1, 5],
 [4, 0, 2, 6, 3, 5, 8, 1, 7],
 [4, 0, 2, 6, 3, 5, 8, 7, 1],
 [4, 0, 2, 6, 3, 7, 5],
 [4, 0, 2, 6, 3, 8, 5],
 [4, 0, 2, 7, 6],
 [4, 0, 2, 8, 6],
 [4, 0, 6, 1, 2],
 [4, 0, 6, 2, 1, 3, 7],
 [4, 0, 6, 2, 1, 5, 7],
 [4, 0, 6, 2, 1, 7, 8, 3, 5],
 [4, 0, 6, 2, 1, 7, 8, 5, 3],
 [4, 0, 6, 2, 1, 8, 7],
 [4, 0, 6, 3, 2],
 [4, 0, 6, 5, 2],
 [4, 0, 6, 7, 2],
 [4, 0, 6, 8, 2],
 [4, 0, 8, 1, 2, 3, 5],
 [4, 0, 8, 1, 2, 3, 6],
 [4, 0, 8, 1, 2, 5, 6],
 [4, 0, 8, 1, 2, 6, 5],
 [4, 0, 8, 1, 2, 7, 5],
 [4, 0, 8, 1, 2, 7, 6],
 [4, 0, 8, 2, 1, 3, 7],
 [4, 0, 8, 2, 1, 5, 7],
 [4, 0, 8, 2, 1, 6, 7],
 [4, 0, 8, 2, 1, 7, 6, 3, 5],
 [4, 0, 8, 2, 1, 7, 6, 5, 3],
 [4, 0, 8, 3, 6, 1, 7],
 [4, 0, 8, 3, 6, 1, 2],
 [4, 0, 8, 3, 6, 2, 7],
 [4, 0, 8, 3, 6, 5, 7],
 [4, 0, 8, 3, 6, 5, 2],
 [4, 0, 8, 3, 6, 7, 2],
 [4, 0, 8, 5, 6, 1, 7],
 [4, 0, 8, 5, 6, 1, 2],
 [4, 0, 8, 5, 6, 2, 7],
 [4, 0, 8, 5, 6, 3, 7],
 [4, 0, 8, 5, 6, 3, 2],
 [4, 0, 8, 5, 6,

In [13]:
# Initialize list to store sequences of all finished games
finished_games_O_first = []
o_wins = 0

def simulate_game_O_first(board, move_sequence):
    """
    Simulate a game of Tic-Tac-Toe recursively where 'O' goes first.
    
    Parameters:
        board (list): The current game board.
        move_sequence (list): The sequence of moves made so far.
    """
    
    # Check for game over conditions (win or draw)
    if is_winner(board, 'X') or is_winner(board, 'O'):
        # if is_winner(board, 'X'): print("X won!")
        # if is_winner(board, 'O'): 
        #     print("O won!")
        finished_games_O_first.append(move_sequence[:])
        return
    if ' ' not in board:
        finished_games_O_first.append(move_sequence[:])
        return
    
    # All moves for the non-optimal player
    for move in range(9):
        if board[move] == ' ':
            board[move] = 'O'  # Make the move
            move_sequence.append(move)  # Record the move

            # Optimal player's move
            optimal_moves = get_optimal_moves(board, 'X')
            for x_move in optimal_moves:
                board[x_move] = 'X'  # Make the move
                move_sequence.append(x_move)  # Record the move 
                simulate_game_O_first(board, move_sequence)  # Recursive call
                board[x_move] = ' '  # Undo the move
                move_sequence.pop()  # Remove the last move from the sequence
                
            board[move] = ' '  # Undo the move
            move_sequence.pop()  # Remove the last move from the sequence

# Initialize board and move_sequence
initial_board = [' ' for _ in range(9)]
initial_move_sequence = []

# Start the simulation with 'O' going first
simulate_game_O_first(initial_board, initial_move_sequence)

# Show some of the finished games to verify correctness
finished_games_O_first[:10], len(finished_games_O_first)

([[0, 4, 1, 2, 3, 6],
  [0, 4, 1, 2, 5, 6],
  [0, 4, 1, 2, 6, 3, 7, 5],
  [0, 4, 1, 2, 6, 3, 8, 5],
  [0, 4, 1, 2, 7, 6],
  [0, 4, 1, 2, 8, 6],
  [0, 4, 2, 1, 3, 7],
  [0, 4, 2, 1, 5, 7],
  [0, 4, 2, 1, 6, 7],
  [0, 4, 2, 1, 8, 7]],
 656)

In [14]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torch

# Revised TicTacToeDataset class
class TicTacToeDataset(Dataset):
    def __init__(self, finished_games_X_first, finished_games_O_first):
        self.data = []
        
        # For games where X went first
        for game in finished_games_X_first:
            for i in range(1, len(game), 2):  # Only X's moves
                sub_seq = [9] + game[:i]
                self.data.append((sub_seq, game[i]))
        
        # For games where X went second
        for game in finished_games_O_first:
            for i in range(0, len(game), 2):  # Only X's moves
                sub_seq = [9] + game[:i+1]
                self.data.append((sub_seq, game[i+1]))
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sub_seq, target = self.data[idx]
        # Pad the sub_seq to have length 9
        padded_sub_seq = sub_seq + [10] * (10 - len(sub_seq))
        return torch.tensor(padded_sub_seq, dtype=torch.long).view(-1), torch.tensor(target, dtype=torch.long).view(-1)

# Create Dataset
dataset = TicTacToeDataset(finished_games, finished_games_O_first)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

# Fetch a batch to check
j = 800
for i, (X, y) in enumerate(dataloader):
    if i == j:
        input_tensor = X
        output_tensor = y
        print("X:", X)
        print("y:", y)
        break

X: tensor([[ 9,  7,  4,  1, 10, 10, 10, 10, 10, 10],
        [ 9,  7,  4,  1,  0,  3, 10, 10, 10, 10],
        [ 9,  7, 10, 10, 10, 10, 10, 10, 10, 10],
        [ 9,  7,  4,  1, 10, 10, 10, 10, 10, 10]])
y: tensor([[0],
        [8],
        [4],
        [0]])


In [15]:
from copy import deepcopy
from itertools import product
import numpy as np
import torch

itos = {0: '-', 1: 'X', 2: '0'}
decode_to_sym = lambda l: ''.join([itos[i] for i in l])

def decode(game):
    decoded_game = np.zeros([9])
    for i in range(9):
        index = np.where(np.array(random_game) == i)
        
        if np.size(index) == 0:
            pass
        
        elif i%2 == 0:
            decoded_game[i] = 1
            
        else:
            decoded_game[i] = 2
    return decoded_game
        
def pretty_print_board(board: str) :
    """Pretty-print the board"""
    for i in range(0, 9, 3) :
        print(f"{board[i]} | {board[i+1]} | {board[i+2]}")
        if i < 6 :
            print("- "*5)

rand_idx = torch.randint(len(input_tensor), (1,))[0]
random_game = input_tensor[rand_idx].tolist()
print("Current game state:")
decoded_game = decode_to_sym(decode(random_game))
pretty_print_board(decoded_game)
print()

move = output_tensor[rand_idx].item()
decoded_game = decoded_game[:move] + 'X' + decoded_game[move+1:]
print("New game state:")
pretty_print_board(decoded_game)


Current game state:
- | 0 | -
- - - - - 
- | X | -
- - - - - 
- | 0 | -

New game state:
X | 0 | -
- - - - - 
- | X | -
- - - - - 
- | 0 | -


In [16]:
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Transformer

model = Transformer()

# Hyperparameters
batch_size = 128  # How many independent sequences will we process in parallel?
block_size = 9  # The size of the tic-tac-toe board
max_iters = 10000
eval_interval = 500
learning_rate = 1e-3
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
eval_iters = 100
n_embd = 32  # Reduced the embedding size
n_head = 2  # Reduced the number of heads
n_layer = 2  # Reduced the number of layers
dropout = 0.1

print(f'Training on {device}')

# Initialize random seed
torch.manual_seed(1337)

# Split into training and validation sets
n = int(0.90 * len(input_tensor))  # 90% for training
train_input = input_tensor[:n]
train_output = output_tensor[:n]
val_input = input_tensor[n:]
val_output = output_tensor[n:]

def get_batch(split):
    input_data = train_input if split == 'train' else val_input
    output_data = train_output if split == 'train' else val_output
    # Choose index locs for batch_size sequences
    ix = torch.randint(len(input_data) - block_size + 1, (batch_size,))
    # Get the input and output sequences
    x = input_data[ix]
    y = output_data[ix]
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
    
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class Transformer(nn.Module):

    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, 9)

        # better init, not covered in the original GPT video, but important, will improve training stability
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (B,T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        # Take the logits from corresponding to the last time step T
        logits = logits[:, -1, :]  # Now logits is (B, 9)

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

model = model.to(device)

print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay = 0.01)
train_loss_history=[]
val_loss_history=[]

for iter in tqdm(range(max_iters)) :
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses}")
        val_loss_history.append(losses['val'])

xb, yb = get_batch('train')

logits, loss = model(xb, yb)
train_loss_history.append(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()

Training on mps
44.140544 M parameters


  0%|                                                 | 0/10000 [00:00<?, ?it/s]


RuntimeError: random_ expects 'from' to be less than 'to', but got from=0 >= to=-5

In [49]:
import torch
from torch import nn
from torch.nn import Transformer
import math

batch_size = 128
block_size = 9  # The size of the tic-tac-toe board

device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Training on {device}')

# Initialize random seed
torch.manual_seed(1337)

n = int(0.90 * len(input_tensor))  # 90% for training
train_input = input_tensor[:n]
train_output = output_tensor[:n]
val_input = input_tensor[n:]
val_output = output_tensor[n:]

def get_batch(split):
    input_data = train_input if split == 'train' else val_input
    output_data = train_output if split == 'train' else val_output
    # Choose index locs for batch_size sequences
    ix = torch.randint(len(input_data) - block_size + 1, (batch_size,))
    # Get the input and output sequences
    x = input_data[ix]
    y = output_data[ix]
    x, y = x.to(device), y.to(device)
    return x, y

# Define the model
class TicTacToeTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length):
        super(TicTacToeTransformer, self).__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length)
        self.transformer = Transformer(d_model=d_model,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
        src = self.embed(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        tgt = self.embed(tgt) * math.sqrt(self.d_model)
        tgt = self.pos_encoder(tgt)
        output = self.transformer(src, tgt, src_mask, tgt_mask, src_key_padding_mask=src_padding_mask, tgt_key_padding_mask=tgt_padding_mask)
        output = self.out(output)
        return output

# Define the PositionalEncoding module
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super(PositionalEncoding, self).__init__() 
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x

# Model Hyperparameters
vocab_size = 9  # assuming 'chars' is the list of unique characters
d_model = 512  # the number of expected features in the input (required)
nhead = 2  # the number of heads in the multiheadattention model (required)
num_encoder_layers = 3  # the number of sub-encoder-layers in the encoder (required)
num_decoder_layers = 3  # the number of sub-decoder-layers in the decoder (required)
dim_feedforward = 2048  # the dimension of the feedforward network model (default: 2048)
max_seq_length = 10  # the maximum size of the input sequence (required)

# Initialize model
model = TicTacToeTransformer(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length)

# Print the model
print(model)


Training on mps
TicTacToeTransformer(
  (embed): Embedding(9, 512)
  (pos_encoder): PositionalEncoding()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleL

In [36]:
import torch
from torch.nn import CrossEntropyLoss

# Assume get_batch is a function that gives you a batch from your dataset
# And assumes 'Transformer' refers to your defined TicTacToeTransformer model or similar
# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Retrieve a single batch
xb, yb = get_batch('train')  # this function is not defined in the provided snippets, replace with your own
xb, yb = xb.to(device), yb.to(device)

print(xb.shape, yb.shape)

# Initialize your Transformer model with required hyperparameters
# Replace 'vocab_size' and other hyperparameters with appropriate values for your task
vocab_size = 3  # For Tic-Tac-Toe 'X', 'O', and '-'
m = TicTacToeTransformer(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length).to(device)

# Create the source and target masks, and the source padding mask
# These functions are placeholders and should be replaced with actual implementations
src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_masks(xb, yb)

# Forward pass through the model
logits = m(xb, yb, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)

# Since logits are expected to be of shape (N, C, L) where C is the number of classes,
# and CrossEntropyLoss expects (N, C), we might need to adjust dimensions
# Flatten the output for use with CrossEntropyLoss
logits = logits.view(-1, vocab_size)

# Flatten yb to match logits shape
yb = yb.view(-1)

# Define loss function
loss_fn = CrossEntropyLoss()

# Calculate loss
loss = loss_fn(logits, yb)

print(logits.shape)
print(f"Loss: {loss.item():.3f}")


RuntimeError: random_ expects 'from' to be less than 'to', but got from=0 >= to=-5