Reference: [How to Beat a 12-Year-Old Girl in Battleship Using Transformers with Reinforcement Learning](https://medium.com/@theStump/how-to-beat-a-12-year-old-girl-in-battleship-using-transformers-with-reinforcement-learning-b506f7bea470)

In [5]:
from typing import List
import numpy as np
import numpy.typing as npt

def place_ships(board_height:int=10,board_width:int=10,ship_sizes:List[int]=[2,3,3,4,5]) -> np.ndarray:
    """Return a random, non-overlapping placement of ships as a flat board.

    Ships are placed one at a time. Each uses a uniformly random start cell and
    orientation, retried until the ship fits on the board without overlapping
    a previously placed ship.

    Args:
        board_height: size of the board's second axis (ships placed "H" extend along it).
        board_width: size of the board's first axis (ships placed "V" extend along it).
        ship_sizes: length of each ship to place. The list is not modified.

    Returns:
        float32 array of shape ``(1, board_height * board_width)``: 1 where a
        ship occupies a cell, 0 elsewhere.

    Raises:
        ValueError: if a ship is longer than both board dimensions, so no
            placement could ever succeed (previously this looped forever).
    """
    board = np.zeros(shape=(board_width,board_height), dtype=np.float32)
    board_size = board_width * board_height

    def can_place_ship(x, y, length, direction):
        """Check if a ship can be placed at (x, y) in a given direction without overlapping."""
        if direction == "H":  # extends along axis 1
            return y + length <= board_height and not board[x, y:y + length].any()
        # "V": extends along axis 0
        return x + length <= board_width and not board[x:x + length, y].any()

    def place_ship(x, y, length, direction):
        """Mark the cells covered by a ship at (x, y) in the given direction."""
        if direction == "H":
            board[x, y:y + length] = 1
        else:
            board[x:x + length, y] = 1

    for ship_size in ship_sizes:
        if ship_size > max(board_width, board_height):
            raise ValueError(
                f"ship of length {ship_size} cannot fit on a {board_width}x{board_height} board"
            )
        while True:
            # randint's upper bound is exclusive: pass the full dimension so a
            # ship may also start on the last row / column.
            x = np.random.randint(0, board_width)
            y = np.random.randint(0, board_height)
            direction = np.random.choice(["H", "V"])  # Horizontal or Vertical
            if can_place_ship(x, y, ship_size, direction):
                place_ship(x, y, ship_size, direction)
                break
    return np.reshape(board, (1, board_size))

def print_board(board: npt.NDArray, predicted_board: npt.NDArray = None):
    """Print ``board`` as an aligned grid, optionally next to ``predicted_board``.

    Columns are labelled with letters starting at ``A`` and rows with 1-based
    numbers. Every cell is printed as ``int(cell)``. Both arrays are indexed as
    2-D, and ``predicted_board`` needs at least as many rows as ``board``.

    Args:
        board: 2-D array of cell values (rows x columns).
        predicted_board: optional 2-D array printed to the right of ``board``.
    """
    cell_width = 1   # characters reserved for each cell value
    joiner = " | "   # text placed between neighbouring cells

    def column_letters(count):
        return joiner.join(f"{chr(65 + k):{cell_width}}" for k in range(count))

    def dash_line(count):
        return "-" * (count * (cell_width + 3) - 1)

    def format_row(number, cells):
        body = joiner.join(f"{int(value):{cell_width}}" for value in cells)
        return f"{number:{cell_width + 1}} | {body} |"

    header = "     " + column_letters(board.shape[1])
    border = "  " + dash_line(board.shape[1])
    if predicted_board is not None:
        header += "           " + column_letters(predicted_board.shape[1])
        border += "      " + dash_line(predicted_board.shape[1])

    print("Current Board \t\t\t\t\t Predicted Board")
    print(header)
    print(border)

    for number, cells in enumerate(board, start=1):
        line = format_row(number, cells)
        if predicted_board is not None:
            line += "    " + format_row(number, predicted_board[number - 1, :])
        print(line)


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Q, K and V are projected with learned linear layers and split into
    ``num_heads`` heads of size ``d_model // num_heads``. The heads are attended
    independently, concatenated, and projected back to ``d_model``.
    """
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Initialize dimensions
        self.d_model = d_model # Model's dimension
        self.num_heads = num_heads # Number of attention heads
        self.d_k = d_model // num_heads # Dimension of each head's key, query, and value

        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model) # Query transformation
        self.W_k = nn.Linear(d_model, d_model) # Key transformation
        self.W_v = nn.Linear(d_model, d_model) # Value transformation
        self.W_o = nn.Linear(d_model, d_model) # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            Q, K, V: per-head tensors as produced by ``split_heads``
                (batch, num_heads, seq, d_k).
            mask: optional tensor broadcastable to the score matrix
                (batch, num_heads, seq_q, seq_k); positions where it is 0 are
                excluded from attention.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Use the most negative value representable in the score dtype: a
            # literal -1e9 overflows float16 (max ~6.5e4), and masked_fill then
            # raises under autocast.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        attn_probs = torch.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, num_heads, seq, d_k)."""
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """(batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Attend queries ``Q`` over keys ``K`` / values ``V``, each (batch, seq, d_model)."""
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        return self.W_o(self.combine_heads(attn_output))

class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position: d_model -> d_ff -> ReLU -> d_model."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Registration order (fc1, fc2, relu) is kept so parameter init and state_dict keys match.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Feature ``2i`` gets sin(pos * w_i) and feature ``2i+1`` gets cos(pos * w_i),
    with w_i = 10000^(-2i / d_model). Odd ``d_model`` is supported.
    """
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column, so
        # trim div_term to match (for even d_model this is a no-op).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Buffer: moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add encodings for the first ``x.size(1)`` positions; x is (batch, seq, d_model)."""
        return x + self.pe[:, :x.size(1)]

class EncoderLayer(nn.Module):
    """One post-norm encoder layer.

    Self-attention, then a position-wise feed-forward network. Each sub-layer
    is wrapped in dropout, a residual connection, and LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)  # shared by both sub-layers

    def forward(self, x, mask):
        """Apply the layer to ``x``; ``mask`` is forwarded to self-attention (may be None)."""
        attended = self.norm1(x + self.dropout(self.self_attn(x, x, x, mask)))
        return self.norm2(attended + self.dropout(self.feed_forward(attended)))

class DecoderLayer(nn.Module):
    """One post-norm decoder layer.

    Self-attention, then cross-attention over the encoder output, then a
    position-wise feed-forward network. Each sub-layer is wrapped in dropout,
    a residual connection, and LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)  # shared by all three sub-layers

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Decode ``x`` against ``enc_output``.

        ``tgt_mask`` goes to self-attention and ``src_mask`` to cross-attention;
        either may be None.
        """
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, tgt_mask)))
        x = self.norm2(x + self.dropout(self.cross_attn(x, enc_output, enc_output, src_mask)))
        return self.norm3(x + self.dropout(self.feed_forward(x)))

class Transformer(nn.Module):
    """Encoder-decoder Transformer over flattened boards.

    The encoder embeds ``src`` token ids and the decoder embeds ``tgt`` token ids
    (token 0 is the decoder padding index). Each decoder position is projected
    to ``tgt_vocab_size`` logits.
    """
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model,padding_idx=0)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc_decoder = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def encoder(self, src: torch.Tensor, src_mask: torch.Tensor = None):
        """Run the encoder stack.

        Args:
            src (torch.Tensor): token ids of the current board, [batch, boardheight*boardwidth]
            src_mask (torch.Tensor, optional): attention mask; positions where it is 0
                are not attended. Defaults to None (attend everywhere).
        Returns:
            Tensor: encoder hidden states shaped [batch, seq_len, d_model]
        """
        enc_output = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)
        return enc_output

    def decoder(self, enc_output,tgt, src_mask=None, tgt_mask=None):
        """Run the decoder stack over ``tgt`` and return logits [batch, seq_len, tgt_vocab_size]."""
        dec_output = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)
        return self.fc_decoder(dec_output)

    def generate_mask(self, src, tgt):
        """Build boolean attention masks for 2-D ``src`` / ``tgt`` id tensors.

        Returns:
            (src_mask, tgt_mask): ``src_mask`` is [batch, 1, 1, src_len] and is
            False at padding (id 0). ``tgt_mask`` is [batch, 1, tgt_len, tgt_len]
            and combines tgt padding with a causal (no look-ahead) mask.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Build on tgt's device: a CPU tensor here makes `&` fail when tgt is on GPU.
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def generate_random_mask(self, src, tgt, p:float=0.15):
        """Like ``generate_mask`` but also drops each position with probability ``p``.

        NOTE(review): the target check here is ``tgt != 1`` while
        ``generate_mask`` and the decoder padding index use 0. Confirm which is
        intended before relying on this method.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 1).unsqueeze(1).unsqueeze(3)

        # Apply random masking
        random_src_mask = (torch.rand_like(src_mask.float()) > p).bool()
        random_tgt_mask = (torch.rand_like(tgt_mask.float()) > p).bool()

        src_mask = src_mask & random_src_mask
        tgt_mask = tgt_mask & random_tgt_mask

        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length,device=tgt.device), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask

        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """Unmasked encoder-decoder pass; returns logits [batch, tgt_len, tgt_vocab_size]."""
        return self.decoder(self.encoder(src), tgt)

In [None]:
'''
    Board values: 0 for cells not yet guessed, 1 for a miss (no ship), 2 for a hit (ship)
'''

from typing import List, Tuple
import torch.nn as nn
import torch.optim as optim
import torch, os
from torch import Tensor
import numpy as np
import pickle
import os.path as osp
from tqdm import trange,tqdm
import numpy.typing as npt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch.amp import autocast, GradScaler

SHIP_SIZES = [2,3,3,4,5]  # length of each ship placed on every generated board
board_height = 10
board_width = 10

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
torch.cuda.empty_cache()


def _uncover_random_cell(remaining: npt.NDArray, revealed: npt.NDArray, cell_values: npt.NDArray) -> npt.NDArray:
    """Reveal one uniformly random cell of ``remaining`` into ``revealed``; return ``remaining`` without it."""
    cell = np.random.choice(remaining)
    revealed[cell] = cell_values[cell]
    return remaining[remaining != cell]


def generate_game_data(nboards:int,board_height:int,board_width:int,ship_sizes:List[int],reveal_fraction:float=0.10) -> Tuple[npt.NDArray,npt.NDArray]:
    """Generates dummy game data for training.

    For every board, ``place_ships`` draws a random ship layout. Its cells are
    then uncovered one at a time in random order. Cell values are
    0 = not yet guessed, 1 = miss, 2 = hit.

    Each board contributes ``number_of_guesses = int(n_cells * (1 - reveal_fraction))``
    consecutive rows. The first row already has ``n_cells - number_of_guesses``
    cells uncovered, and each following row uncovers one more cell.

    Args:
        nboards (int): number of boards to generate
        board_height (int): board height in units
        board_width (int): board width in units
        ship_sizes (List[int]): Array of ship sizes e.g. [2,3,3,4,5]
        reveal_fraction (float, optional): fraction of each board uncovered
            before the first stored row, in [0, 1). Defaults to 0.10.

    Returns:
        Tuple[npt.NDArray,npt.NDArray]: ``(src, tgt)``. Both are float arrays
        shaped ``(nboards * number_of_guesses, board_height * board_width)``.
        ``src`` holds the partially uncovered boards; ``tgt`` holds the fully
        revealed board, repeated on each of that board's rows.

    Raises:
        ValueError: if ``reveal_fraction`` is outside [0, 1) or leaves no rows per board.
    """
    n_cells = board_height * board_width
    number_of_guesses = int(n_cells * (1 - reveal_fraction))
    if not 0.0 <= reveal_fraction < 1.0 or number_of_guesses <= 0:
        raise ValueError(f"reveal_fraction={reveal_fraction} leaves no guesses on a {n_cells}-cell board")
    n_prerevealed = n_cells - number_of_guesses

    src_board = np.zeros((nboards * number_of_guesses, n_cells))
    tgt_board = np.zeros((nboards * number_of_guesses, n_cells))
    for indx in trange(nboards):
        ship_positions = place_ships(board_height, board_width, ship_sizes)
        cell_values = np.where(ship_positions[0] == 1, 2, 1)  # 2 = ship (hit), 1 = water (miss)
        first_row = indx * number_of_guesses
        tgt_board[first_row:first_row + number_of_guesses, :] = cell_values

        revealed = np.zeros(n_cells)
        remaining = np.arange(n_cells)
        # Uncover the initial cells before the first stored row.
        for _ in range(n_prerevealed):
            remaining = _uncover_random_cell(remaining, revealed, cell_values)
        src_board[first_row, :] = revealed
        # Every subsequent row uncovers exactly one more cell.
        for guess in range(1, number_of_guesses):
            remaining = _uncover_random_cell(remaining, revealed, cell_values)
            src_board[first_row + guess, :] = revealed

    return src_board, tgt_board

def generate_square_subsequent_mask(size):
    """Return a ``size x size`` float additive causal mask.

    Entries strictly above the diagonal are ``-inf`` (future positions are
    blocked); every other entry is 0.
    """
    above_diagonal = torch.ones(size, size).triu(diagonal=1).bool()
    return torch.zeros(size, size).masked_fill(above_diagonal, float('-inf'))

def apply_random_mask(tgt: torch.Tensor, mask_token: int = 0, mask_prob: float = 0.15):
    """Randomly replace entries of ``tgt`` with ``mask_token``.

    Each entry is selected independently with probability ``mask_prob``.

    Returns:
        (masked_tgt, labels, loss_mask): ``masked_tgt`` is a copy of ``tgt`` with
        the selected entries replaced; ``labels`` is an untouched copy of
        ``tgt``; ``loss_mask`` is the boolean selection, usable to restrict a
        loss to the masked positions.
    """
    draw = torch.rand(tgt.shape, device=tgt.device)
    selected = draw < mask_prob

    masked_tgt, labels = tgt.clone(), tgt.clone()
    masked_tgt[selected] = mask_token
    return masked_tgt, labels, selected

def generate_games(ngames:int=2000,board_height:int=10,board_width:int=10,ship_sizes:List[int]=SHIP_SIZES):
    """Generate training games and pickle them to ``data/training_data.pickle``.

    Creates the ``data`` directory if needed. Writes (overwriting) a dict
    ``{'src': ..., 'tgt': ...}`` holding the arrays returned by ``generate_game_data``.

    Args:
        ngames (int, optional): number of games to generate. Defaults to 2000.
        board_height (int, optional): board height. Defaults to 10.
        board_width (int, optional): board width. Defaults to 10.
        ship_sizes (List[int], optional): ship sizes to use. Defaults to SHIP_SIZES.
    """
    print("Generating Games to play")
    # Honour the caller's ship_sizes rather than always using the global default.
    src, tgt = generate_game_data(ngames, board_height, board_width, ship_sizes)

    os.makedirs('data', exist_ok=True)
    with open('data/training_data.pickle', 'wb') as fh:
        pickle.dump({'src': src, 'tgt': tgt}, fh)


def load_model():
    """Load the most recently modified ``.pth`` checkpoint from ``data/``.

    Returns:
        (model, optimizer, epochs, data):
        - ``model``: the rebuilt Transformer, moved to ``device`` and set to eval mode.
        - ``optimizer``: an AdamW optimizer with its saved state restored.
        - ``epochs``: the ``'epochs'`` value stored in the checkpoint, or 0 if absent.
        - ``data``: the raw checkpoint dict.

    Raises:
        FileNotFoundError: if ``data/`` does not exist or holds no ``.pth`` files.
    """
    path = 'data'
    files = [
        os.path.join(path, f)
        for f in os.listdir(path)
        if f.endswith('.pth') and os.path.isfile(os.path.join(path, f))
    ]
    if not files:
        raise FileNotFoundError(f"No .pth checkpoints found in {path!r}")
    filename = max(files, key=os.path.getmtime)

    # NOTE: torch.load unpickles the file; only load checkpoints you created yourself.
    data = torch.load(filename,map_location=device)
    hparams = data['model']

    model = Transformer(src_vocab_size=hparams['src_vocab_size'],
                        tgt_vocab_size=hparams['tgt_vocab_size'],
                        d_model=hparams['d_model'],
                        num_heads=hparams['num_heads'],
                        num_layers=hparams['num_layers'],
                        d_ff=hparams['d_ff'],
                        max_seq_length=hparams['max_seq_length'],
                        dropout=hparams['dropout']).to(device)

    model.load_state_dict(hparams['state_dict'])
    model.eval()

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
    optimizer.load_state_dict(data['optimizer'])
    # Older checkpoints may not record an epoch count.
    epochs = hparams.get('epochs', 0)
    print(f"Loaded model with {epochs} epochs")
    return model,optimizer,epochs,data

def _mask_decoder_input(tgt_batch: Tensor, keep_prob: float = 0.1) -> Tensor:
    """Hide all but a random ``keep_prob`` fraction of target cells from the decoder.

    Hidden cells are set to 0, the decoder's padding token, so the model has to
    predict them instead of copying them from its own input.
    """
    guessed = tgt_batch != 0
    keep = torch.rand_like(tgt_batch.float()) > 1.0 - keep_prob
    return tgt_batch.masked_fill(~(guessed & keep), 0)


def _class_weighted_criterion(targets: Tensor, n_classes: int, loss_device) -> nn.CrossEntropyLoss:
    """Cross-entropy weighted by inverse class frequency in ``targets``; class 0 is ignored."""
    counts = torch.bincount(targets.reshape(-1), minlength=n_classes).float()
    weights = 1.0 / (counts + 1e-6)
    weights = weights / weights.sum()
    return nn.CrossEntropyLoss(weight=weights.to(loss_device), ignore_index=0)


def train(resume_training:bool=False,save_every_n_epoch:int=10,epochs:int=100):
    """Train the model on data generated by generate_game_data.

    Training pairs come from ``data/training_data.pickle``, which is generated
    on first use. They are split 70/30 into train/validation. Checkpoints are
    written to ``data/trained_model-<completed epochs>.pth``.

    Args:
        resume_training (bool, optional): Continue from the newest checkpoint in ``data/``
            (model, optimizer state, epoch count). Defaults to False.
        save_every_n_epoch (int, optional): Save every n epochs, and always after the
            last epoch. Defaults to 10.
        epochs (int, optional): Number of epochs to run in this call. Defaults to 100.
    """
    import contextlib

    arch = dict(
        src_vocab_size=3,   # 0 = not guessed, 1 = miss, 2 = hit
        tgt_vocab_size=3,
        d_model=512,
        num_heads=4,
        num_layers=4,
        d_ff=2048,
        max_seq_length=board_height * board_width,
        dropout=0.1,
    )
    batch_size = 128

    data_path = 'data/training_data.pickle'
    if not osp.exists(data_path):
        generate_games(ngames=20000,board_height=board_height,board_width=board_width,ship_sizes=SHIP_SIZES)
    # pickle.load can execute arbitrary code: only load files produced by generate_games.
    with open(data_path, 'rb') as fh:
        game_data = pickle.load(fh)

    if resume_training:
        model, optimizer, current_epochs, resumed = load_model()
        # Record the architecture that was actually loaded when saving again.
        arch = {key: resumed['model'][key] for key in arch}
    else:
        model = Transformer(**arch).to(device)
        optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
        current_epochs = 0
    tgt_vocab_size = arch['tgt_vocab_size']

    # Create these only after the optimizer is final, so the LR schedule and
    # (resumed) optimizer state apply to the optimizer that is actually stepped.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    use_amp = device.type == 'cuda'  # float16 autocast + loss scaling only on CUDA
    scaler = GradScaler(device='cuda', enabled=use_amp)

    def amp_context():
        """Half-precision autocast on CUDA; a no-op context elsewhere."""
        if use_amp:
            return autocast(device_type='cuda', dtype=torch.float16)
        return contextlib.nullcontext()

    print("Train Loop")
    # NOTE(review): rows of the same game are highly correlated. A row-level
    # random split puts them in both sets and inflates validation scores;
    # consider splitting by game.
    src_train, src_test, tgt_train, tgt_test = train_test_split(
        game_data['src'], game_data['tgt'], test_size=0.3, shuffle=True)
    src_train_t = torch.tensor(src_train, dtype=torch.long)
    tgt_train_t = torch.tensor(tgt_train, dtype=torch.long)
    src_test_t = torch.tensor(src_test, dtype=torch.long)
    tgt_test_t = torch.tensor(tgt_test, dtype=torch.long)

    train_loader = DataLoader(TensorDataset(src_train_t, tgt_train_t), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(src_test_t, tgt_test_t), batch_size=batch_size, shuffle=False)

    criterion_train = _class_weighted_criterion(tgt_train_t, tgt_vocab_size, device)
    criterion_val = _class_weighted_criterion(tgt_test_t, tgt_vocab_size, device)

    for epoch in range(epochs):
        epoch_no = current_epochs + epoch  # 0-based index across resumed runs

        model.train()
        pbar = tqdm(train_loader)
        for src_batch, tgt_batch in pbar:
            src_batch = src_batch.to(device)
            tgt_batch = tgt_batch.to(device)
            optimizer.zero_grad()

            with amp_context():
                output = model(src_batch, _mask_decoder_input(tgt_batch))
                loss = criterion_train(output.view(-1, tgt_vocab_size), tgt_batch.view(-1))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            n_rows = tgt_batch.size(0)  # the last batch may be smaller than batch_size
            predicted = output.argmax(dim=-1)
            hits = (predicted == 2).sum().item()
            matches = (predicted == tgt_batch).sum().item()
            pbar.set_description(f"Epoch: {epoch_no:d} Train Loss: {loss.item():0.2e} Hits match {hits/n_rows:0.2f} Matches {matches/n_rows:0.2f}")

        model.eval()
        total_val_loss = 0.0
        num_batches = 0
        pbar = tqdm(test_loader)
        with torch.no_grad():
            for src_batch, tgt_batch in pbar:
                src_batch = src_batch.to(device)
                tgt_batch = tgt_batch.to(device)
                # Mask the decoder input exactly as in training; feeding the full
                # target would let the model simply copy the answer.
                output = model(src_batch, _mask_decoder_input(tgt_batch))
                val_loss = criterion_val(output.view(-1, tgt_vocab_size), tgt_batch.view(-1))

                total_val_loss += val_loss.item()
                num_batches += 1
                pbar.set_description(f"Epoch: {epoch_no:d} Train Loss: {loss.item():0.2e} Val Loss: {val_loss.item():0.2e}")
        average_val_loss = total_val_loss / max(num_batches, 1)
        print(f"Epoch: {epoch_no:d} Train Loss: {loss.item():0.2e} Val Loss: {average_val_loss:0.2e}")
        scheduler.step()

        should_save = save_every_n_epoch > 0 and epoch % save_every_n_epoch == 0
        if should_save or (epoch == epochs - 1):
            completed = epoch_no + 1  # epochs finished, so a resumed run starts at the next one
            checkpoint = {
                'model': {'state_dict': model.state_dict(), **arch, 'epochs': completed},
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint, f"data/trained_model-{completed}.pth")
            if not resume_training:
                torch.save(checkpoint, "data/trained_model.bak.pth")
            print(f"Saved model at epoch {completed}")

if __name__ =="__main__":
    # Resume only when a checkpoint already exists. On a fresh run there is
    # nothing to load, and load_model() would fail after the training data had
    # been generated.
    has_checkpoint = osp.isdir('data') and any(name.endswith('.pth') for name in os.listdir('data'))
    train(resume_training=has_checkpoint)


Generating Games to play


  7%|▋         | 1467/20000 [00:11<01:56, 158.41it/s]