In [1]:
import re
import json
import os
from typing import Optional, List, Any, Dict, Set, Tuple
import numpy as np
from copy import deepcopy
import random

In [2]:
def read_json(fname: str) -> Any:
    """
    Given a filename, reads a json file and returns the data stored inside.

    The file is always decoded as UTF-8, independent of the platform's
    default locale encoding.

    Input:
        fname (str):
            Name of the file to be read. Must point to an existing
            file and end with ".json".

    Output:
        data (Any):
            The data loaded from the json file.

    Raises:
        AssertionError:
            If fname is not an existing file or does not end
            with ".json".
    """

    assert os.path.isfile(fname)
    assert fname.endswith(".json")

    # Explicit encoding: without it, open() falls back to the locale
    # encoding and non-ASCII content can be mis-decoded or fail to decode.
    with open(fname, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data


def write_json(
    data: Any,
    fname: str,
) -> None:
    """
    Given a data and the filename, writes the data to the specified
    fname as UTF-8 encoded json (indent of 4, non-ASCII characters
    written as-is).
    If the directory that the specified filename should be in
    does not exist, then it creates the directory first. A bare
    filename with no directory component is written relative to the
    current working directory.

    Input:
        data (Any):
            the data that needs to be stored in a json format.

        fname (str):
            path to the file where the data needs to be saved.
            Must end with ".json".

    Output:
        None

    Raises:
        AssertionError:
            If fname is not a string ending with ".json".
    """

    assert isinstance(fname, str) and fname.endswith(".json")

    # A bare filename such as "out.json" has an empty directory part and
    # os.makedirs("") raises, so only create a directory when one is named.
    root_dir = os.path.dirname(fname)
    if root_dir:
        os.makedirs(root_dir, exist_ok=True)

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened with an encoding that can represent them on every platform.
    with open(fname, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

In [3]:
def can_place_ship(
    row: int, 
    col: int, 
    size: int, 
    orientation: str,
    grid_size: int,
    board: List[List[str]],
) -> bool:
    """
    Checks whether a ship can be placed on the board without leaving the
    grid or overlapping an occupied cell.

    Input:
        row (int):
            Row index of the ship's first cell.

        col (int):
            Column index of the ship's first cell.

        size (int):
            Number of cells the ship occupies.

        orientation (str):
            "horizontal" (ship extends towards increasing columns) or
            "vertical" (ship extends towards increasing rows).

        grid_size (int):
            Side length of the square board.

        board (List[List[str]]):
            The board; a cell is free when it holds ".".

    Output:
        can_place (bool):
            True if every cell the ship would cover is inside the grid
            and free, False otherwise.

    Raises:
        ValueError:
            If orientation is neither "horizontal" nor "vertical".
    """
    if orientation not in ("horizontal", "vertical"):
        raise ValueError(
            f"Given orientation {orientation} is not allowed."
        )

    # Reject anchors outside the grid explicitly: a negative index would
    # otherwise wrap around to the far edge of the board and be treated as
    # a legal cell, and an index >= grid_size would raise IndexError.
    if not (0 <= row < grid_size and 0 <= col < grid_size):
        return False

    if orientation == "horizontal":
        if col + size > grid_size:
            return False

        return all(board[row][c] == "." for c in range(col, col + size))

    if row + size > grid_size:
        return False

    return all(board[r][col] == "." for r in range(row, row + size))


def do_place_ship(
    ship_name: str, 
    row: int, 
    col: int, 
    size: int, 
    orientation: str,
    board: List[List[str]],
    ship_positions: Dict[str, Set[Tuple[int, int]]],
) -> None:
    """
    Writes a ship onto the board and records the cells it covers.

    No validity check is performed here; the cells are overwritten
    unconditionally.

    Input:
        ship_name (str):
            Key into ship_positions under which the cells are recorded.

        row (int):
            Row index of the ship's first cell.

        col (int):
            Column index of the ship's first cell.

        size (int):
            Number of cells the ship occupies.

        orientation (str):
            "horizontal" (cells advance along columns) or
            "vertical" (cells advance along rows).

        board (List[List[str]]):
            The board, mutated in place: covered cells are set to "S".

        ship_positions (Dict[str, Set[Tuple[int, int]]]):
            Mutated in place: each covered (row, col) pair is added to
            the set stored under ship_name.

    Output:
        None

    Raises:
        ValueError:
            If orientation is neither "horizontal" nor "vertical".
    """
    # Translate the orientation into a per-step (row, col) increment.
    if orientation == "horizontal":
        row_step, col_step = 0, 1
    elif orientation == "vertical":
        row_step, col_step = 1, 0
    else:
        raise ValueError(
            f"Given orientation {orientation} is not allowed."
        )

    for offset in range(size):
        cell_row = row + row_step * offset
        cell_col = col + col_step * offset
        board[cell_row][cell_col] = "S"
        ship_positions[ship_name].add((cell_row, cell_col))
        
        
def place_ships_randomly(
    board: List[List[str]],
    ship_sizes: Dict[str, int],
    grid_size: int,
    ship_positions: Dict[str, Set[Tuple[int, int]]],
) -> None:
    """
    Places every ship in ship_sizes at a random legal position on the board.

    Ships are placed one at a time, in the iteration order of ship_sizes.
    For each ship, every (orientation, row, col) combination accepted by
    can_place_ship is collected and one of them is drawn uniformly at
    random. This always terminates, unlike repeatedly drawing random
    positions until one fits, which never finishes when no legal position
    exists.

    Input:
        board (List[List[str]]):
            The board, mutated in place by do_place_ship.

        ship_sizes (Dict[str, int]):
            Maps each ship name to the number of cells it occupies.

        grid_size (int):
            Side length of the square board.

        ship_positions (Dict[str, Set[Tuple[int, int]]]):
            Mutated in place by do_place_ship; must already contain an
            entry for every ship name.

    Output:
        None

    Raises:
        ValueError:
            If some ship has no legal position left on the board.
    """
    for ship_name, size in ship_sizes.items():
        candidates = [
            (orientation, row, col)
            for orientation in ("horizontal", "vertical")
            for row in range(grid_size)
            for col in range(grid_size)
            if can_place_ship(
                row=row,
                col=col,
                size=size,
                orientation=orientation,
                grid_size=grid_size,
                board=board,
            )
        ]

        # Fail loudly instead of spinning forever on an impossible layout
        # (ship longer than the grid, or no free run of cells left).
        if not candidates:
            raise ValueError(
                f"Cannot place ship {ship_name} of size {size} on a "
                f"{grid_size}x{grid_size} board: no free position is left."
            )

        orientation, row, col = random.choice(candidates)

        do_place_ship(
            ship_name=ship_name,
            row=row,
            col=col,
            size=size,
            orientation=orientation,
            board=board,
            ship_positions=ship_positions,
        )
                
                
def create_initial_board_representation(
    grid_size: int,
) -> str:
    """
    Builds the text view of an untouched board: a header line of column
    numbers followed by one line per row, every cell shown as ".".

    Column numbers start at 1 and row labels start at "A"; each number
    and each cell is right-aligned to width 2 and separated by a single
    space. The header is prefixed with four spaces and each row with its
    label followed by two spaces.

    Input:
        grid_size (int):
            Side length of the square board.

    Output:
        representation (str):
            The header and the rows joined by newlines.
    """
    column_labels = [f"{number:>2}" for number in range(1, grid_size + 1)]
    header_line = "    " + " ".join(column_labels)

    # Every row shows the same run of untouched cells; only the label changes.
    untouched_cells = " ".join([".".rjust(2)] * grid_size)
    first_label = ord("A")
    board_lines = [
        f"{chr(first_label + index)}  {untouched_cells}"
        for index in range(grid_size)
    ]

    return header_line + "\n" + "\n".join(board_lines)

In [4]:
# Ships placed on every generated board: ship name -> number of cells it
# occupies. Passed as `ship_sizes` by generate_battleship_config_file.
SHIP_SIZES: Dict[str, int] = {
    "Carrier": 5,
    "Battleship": 4,
    "Destroyer": 2,
}


def generate_battleship_boards(
    num_examples: int,
    data_type: str,
    ship_sizes: Dict[str, int],
) -> List[Dict[str, Any]]:
    """
    Generates random Battleship boards together with the blank view that
    is shown to the agent.

    Input:
        num_examples (int):
            Number of boards to generate.

        data_type (str):
            "train" (grid size drawn at random from 5 or 6 for each
            board) or "test" (grid size fixed to 7).

        ship_sizes (Dict[str, int]):
            Maps each ship name to the number of cells it occupies.

    Output:
        all_data (List[Dict[str, Any]]):
            One dict per board with the keys:
                "agent": the text view of the untouched board.
                "env": a dict holding "hidden_board_representation"
                    (the grid with "S" on ship cells and "." elsewhere)
                    and "ship_positions" (ship name -> sorted list of
                    (row, col) cells).

    Raises:
        ValueError:
            If data_type is neither "train" nor "test".
    """
    # Validate once, up front, so a bad data_type is reported even when
    # num_examples is 0 and the loop body never runs.
    if data_type not in ("train", "test"):
        raise ValueError(
            f"Given data type {data_type} is not supported."
        )

    all_data = []

    for _ in range(num_examples):
        grid_size = random.randint(5, 6) if data_type == "train" else 7

        board = [
            ["." for _ in range(grid_size)]
            for _ in range(grid_size)
        ]

        ship_positions = {
            ship_name: set()
            for ship_name in ship_sizes
        }

        place_ships_randomly(
            board=board,
            ship_sizes=ship_sizes,
            grid_size=grid_size,
            ship_positions=ship_positions,
        )

        board_representation_for_agent = (
            create_initial_board_representation(
                grid_size=grid_size,
            )
        )

        # Sets are not json-serializable and iterate in arbitrary order;
        # store each ship's cells as a sorted list for stable output.
        serializable_positions = {
            ship_name: sorted(cells)
            for ship_name, cells in ship_positions.items()
        }

        all_data.append(
            {
                "agent": board_representation_for_agent,
                "env": {
                    "hidden_board_representation": board,
                    "ship_positions": serializable_positions,
                },
            }
        )

    return all_data
        
        
def generate_battleship_config_file(
    path_to_old_config: str,
    path_to_new_config: str,
    num_train_examples: int,
    num_test_examples: int,
) -> None:
    """
    Loads an existing config, replaces its "train" and "eval" entries
    with freshly generated Battleship boards and saves the result to a
    new file.

    Input:
        path_to_old_config (str):
            Path to the json config that is used as the starting point.

        path_to_new_config (str):
            Path where the updated json config is written.

        num_train_examples (int):
            Number of boards generated with data type "train" and
            stored under the "train" key.

        num_test_examples (int):
            Number of boards generated with data type "test" and
            stored under the "eval" key.

    Output:
        None
    """
    config = read_json(fname=path_to_old_config)

    # (key in the config, number of boards, data type for the generator)
    split_plan = (
        ("train", num_train_examples, "train"),
        ("eval", num_test_examples, "test"),
    )

    for config_key, example_count, split_name in split_plan:
        config[config_key] = generate_battleship_boards(
            num_examples=example_count,
            data_type=split_name,
            ship_sizes=SHIP_SIZES,
        )

    write_json(data=config, fname=path_to_new_config)

In [5]:
# Both config files live in the same directory. The path is specific to the
# author's machine and has to be adjusted before running this cell elsewhere.
ENV_CONFIG_DIR = "/Users/fahimtajwar/academics/exploration/verl/verl/paprika/environments/env_configs"

# Start from the existing config and write the regenerated train/eval
# boards to a sibling file.
generate_battleship_config_file(
    path_to_old_config=f"{ENV_CONFIG_DIR}/battleship.json",
    path_to_new_config=f"{ENV_CONFIG_DIR}/battleship_new.json",
    num_train_examples=1500,
    num_test_examples=100,
)