In [3]:
import random
import json
from copy import deepcopy
from tqdm import tqdm

In [4]:
# Card colors, in the order they are stacked into the deck before shuffling.
colors = ["green", "red", "blue", "yellow", "purple"]
# How many cards of each color the deck contains (36 cards in total).
color_counts = {"green": 8, "red": 7, "blue": 7, "yellow": 7, "purple": 7}

def init_deck():
    """Build the full deck and return it shuffled.

    The deck holds ``color_counts[color]`` copies of every entry in
    ``colors``. It is shuffled in place with the module-level ``random``
    generator, so seeding ``random`` beforehand fixes the resulting order.

    Returns:
        list[str]: The shuffled deck of color names.
    """
    cards = [name for name in colors for _ in range(color_counts[name])]
    random.shuffle(cards)
    return cards

def deal_hands(deck, hand_size=14):
    """Deal two hands from the end of ``deck``, removing the dealt cards.

    Args:
        deck: Mutable list of cards. ``2 * hand_size`` cards are popped from
            its end; whatever is left stays in ``deck``.
        hand_size: Number of cards per hand. Defaults to 14.

    Returns:
        tuple[list, list]: ``(hand1, hand2)``. ``hand1`` receives the last
        ``hand_size`` cards of the deck (last card first), ``hand2`` the
        ``hand_size`` cards before those.

    Raises:
        IndexError: If the deck holds fewer than ``2 * hand_size`` cards.
            The check happens before any card is popped, so a deck that is
            too small is left untouched instead of half-dealt.
    """
    needed = 2 * hand_size
    if len(deck) < needed:
        raise IndexError(
            f"deck has {len(deck)} cards but {needed} are needed to deal two hands"
        )
    hand1 = [deck.pop() for _ in range(hand_size)]
    hand2 = [deck.pop() for _ in range(hand_size)]
    return hand1, hand2

def get_row1_range(board):
    """Return the lowest and highest occupied column numbers in row 1.

    Args:
        board: Iterable of position keys formatted as ``"row-col"``
            (typically the board dict). Only keys starting with ``"1-"``
            are considered.

    Returns:
        tuple: ``(min_col, max_col)`` as ints, or ``(None, None)`` when no
        row-1 position is present.
    """
    lowest = highest = None
    for position in board:
        if not position.startswith('1-'):
            continue
        column = int(position.split('-')[1])
        if lowest is None or column < lowest:
            lowest = column
        if highest is None or column > highest:
            highest = column
    return lowest, highest

def legal_moves(board, hand):
    """List every ``(position, color)`` move playable from ``hand``.

    Rules encoded here:

    * Empty row 1: the only legal position is ``"1-7"``, with any color held.
    * Row 1 holds at most 7 cards and may be extended one column to the left
      or right of its current extent, staying within columns 1-13; any color
      held may be played there.
    * Rows 2-7: a position ``"r-c"`` is playable when it is empty and both
      supporting positions ``"(r-1)-c"`` and ``"(r-1)-(c+1)"`` are occupied;
      the card played must match the color of at least one of them.

    Args:
        board: Dict mapping ``"row-col"`` position keys to color names.
        hand: List of color names held by the player (duplicates allowed).
            Colors must be mutually comparable (e.g. all strings) because
            they are sorted.

    Returns:
        list[tuple[str, str]]: Legal moves with each distinct color appearing
        at most once per position. Colors are emitted in sorted order rather
        than raw ``set`` order: string hashing is randomized per process, so
        set iteration order would make the move list (and any seeded random
        pick from it) differ between otherwise identical runs.
    """
    base_width = 7   # maximum number of cards in row 1
    last_column = 13  # rightmost column usable in row 1
    top_row = 7      # highest row of the pyramid

    # Distinct colors, computed once instead of once per candidate position.
    playable = sorted(set(hand))

    # A single scan of the board gives both the count and the extent of row 1.
    row1_cols = [int(pos.split('-')[1]) for pos in board if pos.startswith('1-')]
    if not row1_cols:
        return [('1-7', color) for color in playable]

    moves = []
    leftmost, rightmost = min(row1_cols), max(row1_cols)
    if len(row1_cols) < base_width:
        if leftmost > 1:
            moves.extend((f"1-{leftmost - 1}", color) for color in playable)
        if rightmost < last_column:
            moves.extend((f"1-{rightmost + 1}", color) for color in playable)

    for row in range(2, top_row + 1):
        # Each row up is one position narrower than the row beneath it.
        for col in range(1, last_column - row + 2):
            pos = f"{row}-{col}"
            below_left = f"{row - 1}-{col}"
            below_right = f"{row - 1}-{col + 1}"
            if pos in board or below_left not in board or below_right not in board:
                continue
            support_colors = (board[below_left], board[below_right])
            moves.extend((pos, color) for color in playable if color in support_colors)
    return moves

def simulate_one_game_collect_samples():
    """Play one random two-player game and return per-move training samples.

    A fresh deck is shuffled and two hands are dealt. Players alternate,
    starting with player 0. On each turn the current player plays a randomly
    chosen legal move; a player with no cards or no legal move is skipped, and
    the game ends once both players are skipped in succession.

    For every turn on which a move is played, one sample is emitted per legal
    move available at that moment.

    Returns:
        list[dict]: Samples with the keys

        * ``"turn"``   - index (0 or 1) of the player to move,
        * ``"board"``  - copy of the board before the move,
        * ``"hand"``   - copy of the mover's hand before the move,
        * ``"pos"`` / ``"color"`` - the candidate move,
        * ``"label"``  - 1 if this candidate was the move played, else 0,
        * ``"winner"`` - index of the player who finished with fewer cards
          (chosen at random on a tie); identical for all samples of a game.
    """
    deck = init_deck()
    hand0, hand1 = deal_hands(deck)
    hands = [hand0, hand1]
    board = {}
    turn = 0
    samples = []
    move_count = 0
    # skip[p] is True while player p's most recent turn was a forced pass.
    skip = [False, False]
    while True:
        hand = hands[turn]
        # An empty hand has no moves, so the move search is skipped entirely.
        legal = legal_moves(board, hand) if hand else []

        if not legal:
            skip[turn] = True
            # Both players passing back to back ends the game.
            if all(skip):
                break
            turn = 1 - turn
            continue
        skip[turn] = False  # This player can still move

        if move_count == 0:
            # Opening move is always at 1-7; the color is drawn from the hand
            # itself, so colors held in multiples are proportionally likelier.
            move = ('1-7', random.choice(hand))
        elif move_count == 1:
            # Second move prefers the right-hand side of the opening card.
            # NOTE(review): after an opening at 1-7 only 1-6 and 1-8 can be
            # legal, so '1-9' never matches here - confirm the intended rule.
            preferred = [m for m in legal if m[0] in ('1-8', '1-9')]
            move = random.choice(preferred or legal)
        else:
            move = random.choice(legal)

        pos_played, color_played = move
        # One sample per legal candidate (label 1 if played, else 0).
        for pos, color in legal:
            samples.append({
                "turn": turn,
                # Board values and hand entries are immutable strings, so a
                # shallow copy is a full snapshot and avoids deepcopy's cost.
                "board": dict(board),
                "hand": list(hand),
                "pos": pos,
                "color": color,
                "label": int(pos == pos_played and color == color_played),
            })

        # Apply the move.
        board[pos_played] = color_played
        hand.remove(color_played)
        turn = 1 - turn
        move_count += 1

    # The winner is the player with fewer remaining cards (random if tied).
    score0 = len(hands[0])
    score1 = len(hands[1])
    if score0 < score1:
        winner = 0
    elif score1 < score0:
        winner = 1
    else:
        winner = random.choice([0, 1])
    for s in samples:
        s['winner'] = winner
    return samples

In [5]:
N_GAMES = 20000
SEED = 42  # seed for the module-level RNG used by the shuffles and move choices
# The file name is derived from N_GAMES so it always matches the number of games written.
OUTPUT_PATH = f"penguin_party_{N_GAMES}_dataset.jsonl"

random.seed(SEED)

# One JSON object per line: every sample of every simulated game.
with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    for _ in tqdm(range(N_GAMES)):
        samples = simulate_one_game_collect_samples()
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")

100%|██████████| 20000/20000 [00:54<00:00, 370.29it/s]
