In [None]:
from pathlib import Path

import pandas as pd

import wandb

import torch

from transformers import VisionEncoderDecoderModel
from transformers import AutoTokenizer
from transformers import AutoFeatureExtractor

# --- Configuration ---------------------------------------------------------
# SEED is interpolated into the W&B table names read and written further down.
SEED = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One checkpoint name supplies both the model weights and its image
# preprocessing configuration.
CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# --- Model -----------------------------------------------------------------
# Load the pretrained vision-encoder/decoder model and place it on DEVICE.
MODEL = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
MODEL.to(DEVICE)

# --- Image preprocessing ---------------------------------------------------
# Feature extractor that turns images into the pixel tensors the encoder takes.
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(CHECKPOINT)

In [None]:
# Open a W&B run for this data-splitting job.
run = wandb.init(
    project='pokemon-cards',
    entity=None,
    job_type="data_split",
)

# Declare the raw-data artifact as an input of the run and fetch its files.
raw_data_at = run.use_artifact(
    'pkthunder/pokemon-cards/pokemon_cards:v2',
    type='raw_data',
)
path = Path(raw_data_at.download())

# Pull the seeded table out of the artifact and mirror it as a DataFrame so it
# can be filtered and merged with pandas.
raw_table_key = f"pokemon_table_full_seed_{SEED}"
original_table = raw_data_at.get(raw_table_key)
original_table_df = pd.DataFrame(original_table.data, columns=original_table.columns)

In [None]:
from PIL import ImageChops

# Survey the distinct image sizes in the table, keeping the first image seen
# for each size (dicts preserve insertion order, so the order matches a
# first-seen scan of the rows).
# NOTE(review): each cell of the "image" column is assumed to wrap a PIL image
# in its `.image` attribute -- confirm against the W&B table schema.
first_image_by_size = {}
for card_image in original_table_df["image"]:
    first_image_by_size.setdefault(card_image.image.size, card_image.image)
img_sizes = list(first_image_by_size)
images_to_check = list(first_image_by_size.values())

# Positional row of a card known to be blank; it serves as the reference when
# looking for other blank cards.
# NOTE(review): this position is specific to the current artifact version --
# re-check it whenever the raw table changes.
BLANK_CARD_ROW = 11776
blank_card = original_table_df.iloc[BLANK_CARD_ROW].image.image
if blank_card.mode != "RGB":
    blank_card = blank_card.convert(mode="RGB")

# Embed the reference card with the model's encoder. The pixel tensor must be
# moved to DEVICE explicitly: MODEL was moved to DEVICE, and feeding it a
# tensor from a different device raises a RuntimeError when CUDA is in use.
with torch.no_grad():
    blank_card_pixels = FEATURE_EXTRACTOR(images=blank_card, return_tensors="pt").pixel_values[0]
    blank_card_features = MODEL.encoder(blank_card_pixels.unsqueeze(0).to(DEVICE)).pooler_output

# blank_card_test = original_table_df.iloc[11777].image.image
# if blank_card_test.mode != "RGB":
#     blank_card_test = blank_card_test.convert(mode="RGB")

# actual_card_test = original_table_df.iloc[0].image.image
# if actual_card_test.mode != "RGB":
#     actual_card_test = actual_card_test.convert(mode="RGB")

In [None]:
# original_table_df.iloc[11776].image.image

In [None]:
# Cards whose encoder embedding has a cosine similarity to the blank reference
# card above this value are treated as blank.
SIMILARITY_THRESHOLD = 0.95

# Make sure the reference embedding lives on the same device as the per-card
# embeddings computed below (a no-op when it is already there).
reference_features = blank_card_features.to(DEVICE)

# Collected as (similarity, index label) pairs for every card flagged as blank.
row_ids = []
for i, row in original_table_df.iterrows():

    card = row.image.image
    if card.mode != "RGB":
        card = card.convert(mode="RGB")

    # Inputs must be moved to DEVICE explicitly: MODEL was moved to DEVICE, and
    # passing it a tensor from another device raises a RuntimeError on CUDA.
    with torch.no_grad():
        card_pixels = FEATURE_EXTRACTOR(images=card, return_tensors="pt").pixel_values[0]
        card_features = MODEL.encoder(card_pixels.unsqueeze(0).to(DEVICE)).pooler_output

    sim = torch.cosine_similarity(reference_features, card_features).item()
    if sim > SIMILARITY_THRESHOLD:
        row_ids.append((sim, i))

In [None]:
# Save every flagged card to disk so the automatic selection can be reviewed
# by eye. The idx values are index *labels* (they come from iterrows), so the
# lookup is label-based, matching the label-based drop below; a positional
# lookup would only agree while the index is a default 0..n-1 range.
for sim, idx in row_ids:
    original_table_df.at[idx, "image"].image.save(f"blank-card-{idx}.png")

# Remove the flagged rows and rebuild a W&B table from what remains.
new_table_df = original_table_df.drop(index=[idx for _, idx in row_ids])
new_table_wandb = wandb.Table(dataframe=new_table_df)

## Split by Card Set

In [None]:
from sklearn.model_selection import StratifiedGroupKFold

# Columns of the filtered table that drive the split.
card_ids = new_table_wandb.get_column('id')
captions = new_table_wandb.get_column('caption')
set_names = new_table_wandb.get_column('set_name')

# One row per card; 'fold' starts at the sentinel -1 and is overwritten with
# the number of the fold whose held-out part contains that card.
split_df = pd.DataFrame({'id': card_ids})
split_df['fold'] = -1

# The set name is passed both as the stratification target and as the group,
# so all cards of one set land in the same fold.
cv = StratifiedGroupKFold(n_splits=10)
fold_iterator = cv.split(card_ids, set_names, set_names)
for fold_number, (_, held_out_rows) in enumerate(fold_iterator):
    split_df.loc[held_out_rows, 'fold'] = fold_number

In [None]:
# Turn fold numbers into split labels: fold 0 is the test set, fold 1 the
# validation set, and every other fold is training data. The helper 'fold'
# column is removed from split_df in the process.
fold_of_card = split_df.pop('fold')
split_df['split'] = fold_of_card.map({0: 'test', 1: 'valid'}).fillna('train')

# Show how many cards ended up in each split.
split_df['split'].value_counts()

In [None]:
# Attach the new split label to every card. Both frames carry a 'split'
# column, so the merge suffixes them: 'split_x' is the one that came with the
# table and is discarded, 'split_y' is the new assignment and takes over the
# plain 'split' name (it stays the last column).
joined_df = (
    new_table_df
    .merge(split_df, on='id', how='left')
    .drop(columns='split_x')
    .rename(columns={'split_y': 'split'})
)

In [None]:
# Wrap the joined frame in a W&B table and register it, under a seed-specific
# key, inside a new artifact that holds the split data.
join_table = wandb.Table(dataframe=joined_df)
split_table_key = f"pokemon_table_full_data_split_seed_{SEED}"

processed_data_loc = wandb.Artifact('pokemon_cards_split_full', type="split_data")
processed_data_loc.add(join_table, split_table_key)
# join_table = wandb.JoinedTable(original_table, data_split_table, "id")

In [None]:
# Log the split-data artifact to the run, then close the run.
# Order matters: the artifact is logged while the run is still open.
run.log_artifact(processed_data_loc)
run.finish()