In [None]:
from jass_bot.heuristics.graf import get_graf_scores, push_threshold
from itertools import combinations
import math
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa

In [None]:
# Card ids 0..35, one entry per card in the deck.
ALL_CARDS = [*range(36)]
# Number of cards in a single hand.
n_hand = 9

In [None]:
# Count of distinct hands: choose n_hand cards out of the full deck.
n_combinations = math.comb(
    len(ALL_CARDS),  # cards to choose from
    n_hand,  # cards per hand
)
n_combinations

In [None]:
# Two rows are stored per hand (forehand flag 0 and 1).
total_rows = 2 * n_combinations
# Columns: hand cards + forehand flag + selected trump + six trump scores.
total_cols = n_hand + 2 + 6
expected_final_shape = total_rows, total_cols
expected_final_shape

In [None]:
def get_sample(comb, forehand, selected_trump, trump_scores):
    """Flatten one example into a single flat row (a list).

    The row is laid out as: every card in ``comb``, then ``forehand``, then
    ``selected_trump``, then every entry of ``trump_scores``.
    """
    row = list(comb)
    row.append(forehand)
    row.append(selected_trump)
    row.extend(trump_scores)
    return row

In [None]:
# Number of hands buffered before a flush; one row group will contain this
# times 2, because two rows are stored per hand.
batch_size = 128 * 1024
parquet_file_path = "./data/graf-data.parquet"

# Column layout of one row:
#   c1..c9   -> the nine hand cards (int8)
#   fh       -> forehand flag (int8)
#   trump    -> selected trump (int8)
#   ts0..ts5 -> score of the hand for each of the six trump options (int16)
schema = pa.schema(
    [(f"c{card_pos}", pa.int8()) for card_pos in range(1, 10)]
    + [("fh", pa.int8()), ("trump", pa.int8())]
    + [(f"ts{trump_idx}", pa.int16()) for trump_idx in range(6)]
)


In [None]:
# Deliberate guard: regenerating the file takes hours and overwrites the existing
# one. Comment this line out only when a rebuild is really intended.
assert False, "you sure you wanna override it? takes 3 hours"


def _write_samples(writer, rows):
    """Write the buffered sample rows to the Parquet file as one record batch.

    Parameters
    ----------
    writer : open ``pq.ParquetWriter`` created with ``schema``.
    rows : list of flat rows as produced by ``get_sample``; column order must
        match ``schema.names``.
    """
    frame = pd.DataFrame(rows, columns=schema.names)
    writer.write_batch(pa.record_batch(frame, schema=schema))


# The score table only depends on the trump index, so fetch the six tables once
# instead of six times per hand inside the loop.
# NOTE(review): assumes get_graf_scores(trump) returns the same array on every
# call (no hidden state) — confirm against jass_bot.heuristics.graf.
graf_scores_by_trump = [get_graf_scores(trump) for trump in range(6)]  # 4 suits, obeabe, uneufe

samples = []

# The context manager closes the writer even when the loop raises or is
# interrupted, so the file handle is never leaked.
with pq.ParquetWriter(parquet_file_path, schema) as writer:
    for i, comb in tqdm(enumerate(combinations(ALL_CARDS, n_hand)), total=n_combinations):
        hand = np.array(comb)  # built once per hand, reused for all six lookups
        trump_scores = [np.sum(scores[hand]) for scores in graf_scores_by_trump]
        selected_trump = np.argmax(trump_scores)
        would_push = trump_scores[selected_trump] < push_threshold

        # forehand == 0: always the best-scoring trump
        samples.append(get_sample(comb, 0, selected_trump, trump_scores))
        # push is not a trump itself but has the value 6 (7th value)
        # if it's forehand, you can push, otherwise just pick the best you can
        samples.append(get_sample(comb, 1, 6 if would_push else selected_trump, trump_scores))

        # flush every batch_size hands (= 2 * batch_size rows) to bound memory use
        if (i + 1) % batch_size == 0:
            _write_samples(writer, samples)
            samples = []

    # write the remaining samples (last, partial batch) to the file
    if samples:
        _write_samples(writer, samples)
        samples = []

In [None]:
import dask.dataframe as dd

In [None]:
# Open the generated Parquet file as a dask DataFrame.
ddf = dd.read_parquet(path=parquet_file_path)
ddf

In [None]:
# Number of rows per value of the "trump" column, computed over the whole file.
trump_counts = ddf["trump"].value_counts().compute()
trump_counts

In [None]:
# Sanity check: the per-trump counts must add up to the expected number of rows.
assert trump_counts.sum() == total_rows, (
    f"Has {trump_counts.sum()} rows, should be {total_rows} (diff = {total_rows - trump_counts.sum()})"
)

I forgot to write the remaining samples after the loop; otherwise the row count would be correct, I'm sure.