In [None]:
from pathlib import Path

import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold

import wandb

# Open a W&B run for this split job and fetch the raw Pokemon card dataset artifact.
run = wandb.init(
    project='pokemon-cards',
    entity=None,
    job_type="data_split",
)
raw_data_at = run.use_artifact('pokemon_cards:latest')

# Download the artifact contents locally and keep the target directory as a Path.
download_dir = raw_data_at.download()
path = Path(download_dir)

In [None]:
# SEED only selects which stored table variant to read (the `_seed_{SEED}` suffix);
# the fold split further down uses no shuffling and does not depend on it.
SEED = 1

# Pull the card table out of the artifact and mirror it as a pandas DataFrame.
table_key = f"pokemon_table_1k_seed_{SEED}"
original_table = raw_data_at.get(table_key)
original_table_df = pd.DataFrame(original_table.data, columns=original_table.columns)

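## Split by Card Set

Cards are grouped by `set_name`, so every card from a given set lands in the same split and no set is shared between train, valid and test. Of the ten group folds, one is held out as test and one as valid, giving roughly an 80/10/10 split (by group, so exact sizes vary).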

In [None]:
# Assign every card to one of N_SPLITS folds, keeping whole card sets together.
# (StratifiedGroupKFold is imported in the top import cell.)
N_SPLITS = 10

card_ids = original_table.get_column('id')
set_names = original_table.get_column('set_name')

# One row per card, in table order; -1 marks "not yet assigned".
split_df = pd.DataFrame({'id': card_ids, 'fold': -1})

# NOTE(review): y and groups are both `set_names`, so each stratification class is a
# single group and stratification adds nothing beyond keeping sets intact. GroupKFold
# would state the intent more directly, but would change the fold assignment.
cv = StratifiedGroupKFold(n_splits=N_SPLITS)
for fold, (_, test_idxs) in enumerate(cv.split(card_ids, set_names, groups=set_names)):
    # split() yields positional indices; split_df has a fresh RangeIndex, so labels match.
    split_df.loc[test_idxs, 'fold'] = fold

assert (split_df['fold'] >= 0).all(), "every card must be assigned to a fold"

In [None]:
# Fold 0 is held out as test and fold 1 as valid; every other fold is train.
FOLD_TO_SPLIT = {0: 'test', 1: 'valid'}

# Guarded so re-running this cell is safe: once `fold` has been replaced by `split`,
# the labels are already final (previously a re-run reset every row to 'train' and then
# crashed on the missing `fold` column).
if 'fold' in split_df.columns:
    split_df = (
        split_df
        .assign(split=split_df['fold'].map(lambda fold: FOLD_TO_SPLIT.get(fold, 'train')))
        .drop(columns='fold')
    )
split_df['split'].value_counts()

In [None]:
# Replace the table's existing `split` column (if any) with the set-based assignment.
# Dropping it before the merge avoids the split_x / split_y suffix juggling, and
# validate='one_to_one' fails loudly on duplicated ids instead of silently multiplying rows.
joined_df = original_table_df.drop(columns='split', errors='ignore').merge(
    split_df, on='id', how='left', validate='one_to_one'
)
assert joined_df['split'].notna().all(), "every card should have received a split"

In [None]:
# Package the joined table (original columns plus the new `split` column) as a
# `split_data` artifact, named after the same SEED as the source table.
processed_data_loc = wandb.Artifact('pokemon_cards_split', type="split_data")
join_table = wandb.Table(dataframe=joined_df)
processed_data_loc.add(join_table, f"pokemon_table_1k_data_split_seed_{SEED}")

In [None]:
# Publish the split artifact under this run, then close the run.
logged_split_artifact = run.log_artifact(processed_data_loc)
run.finish()