In [None]:
%%capture
!pip install --U wandb
!pip install pandarallel

In [None]:
import wandb
# Log in to Weights & Biases before creating the run and fetching artifacts.
wandb.login()

In [None]:
# W&B run identity.
project_name = "krea-open-prompts"
run_name = "process-and-log-available-open-prompts-sd-with-images"
tags = ["log", "stable_diffusion", "available", "process"]

# Run config; "dataset" selects which CSV to process: "prompts" or "sample_prompts".
_config = dict(dataset="prompts")

In [None]:
# Start the W&B run; _config is passed as the run's config and read back via run.config.
run = wandb.init(project=project_name, name=run_name, tags=tags, config=_config)

In [None]:
# Read settings from the run's config rather than the local _config dict.
# NOTE(review): presumably so values overridden on the W&B side (e.g. sweeps) take effect — confirm.
config = run.config
dataset_name = config['dataset']

In [None]:
# Fetch the raw open-prompts artifact and download only the CSV chosen by `dataset_name`.
# `get_entry` replaces the deprecated `Artifact.get_path` in current wandb releases.
art = run.use_artifact('open-prompts-sd:latest', type='raw_data')
dataset_path = art.get_entry(f"{dataset_name}.csv").download()

In [None]:
import psutil
# CPU counts used to size the pandarallel worker pool.
# psutil.cpu_count() returns None when the count cannot be determined; fall back
# so that pandarallel always receives a positive integer worker count.
NB_CORES = psutil.cpu_count() or 1
NB_PHYSICAL_CORES = psutil.cpu_count(logical=False) or NB_CORES

In [None]:
from pandarallel import pandarallel
# Use every logical core, not just the physical ones.
# TODO: benchmark whether the extra hyper-threaded workers actually speed up the JSON flattening.
pandarallel.initialize(progress_bar=True, nb_workers=NB_CORES)

In [None]:
import json
from pathlib import Path

import pandas as pd
from pandas import json_normalize

In [None]:
def load_and_flatten_json(record):
    """Parse one JSON string and flatten its nested objects into a single-level dict.

    Nested keys are joined with "_" (e.g. {"a": {"b": 1}} -> {"a_b": 1}).
    Only the first row produced by ``json_normalize`` is returned.
    """
    parsed = json.loads(record)
    table = json_normalize(parsed, sep="_")
    rows = table.to_dict(orient="records")
    return rows[0]

In [None]:
import os
# Folder that holds the locally stored images (./<dataset_name>/).
# Defined here because it is needed now; defining it only in a later cell breaks Restart & Run All.
image_folder = Path(".", dataset_name)
# TODO: store images in a bucket, read/write there, and have the artifact reference it.
# That would let multiple machines/processes write to it and remove the dependency
# on the local filesystem.
downloaded_images = set(os.listdir(image_folder))

In [None]:
# Parse each raw JSON record (in parallel) into a flat dict; one row per prompt.
raw_df = pd.read_csv(dataset_path)
df = pd.DataFrame(raw_df["raw_data"].parallel_apply(load_and_flatten_json).to_list())

# The local image file name is the last "/"-separated segment of the image URI.
# `n` must be passed by keyword: pandas 2.x rejects it positionally.
df["local_image_location"] = df["raw_discord_data_image_uri"].str.rsplit("/", n=1).str[-1]
# Both frames carry a default RangeIndex, so this aligns row-for-row.
df["prompt"] = raw_df["prompt"]
del raw_df  # the raw CSV is no longer needed; free its memory

# Keep only rows whose image is available locally. Filter rows only: restricting
# columns here would drop "prompt" and the other fields selected downstream.
df = df[df["local_image_location"].isin(downloaded_images)]

In [None]:
# Columns written to the processed CSV: the prompt, the local image file name, and
# a subset of the flattened `raw_discord_data_*` fields from the JSON records.
selected_columns = ["prompt", "local_image_location"] + [
    f"raw_discord_data_{field}"
    for field in (
        "seed",
        "width",
        "height",
        "is_grid",
        "num_step",
        "cfg_scale",
        "timestamp",
        "num_generations",
    )
]

In [None]:
# Keep only the columns listed in selected_columns, in that order.
df = df[selected_columns]

In [None]:
# Write the processed table to ./<dataset_name>.csv (without the index) for the artifact.
df_path = f"{dataset_name}.csv"
df.to_csv(df_path, index=False)

In [None]:
# Output artifact, named after the selected dataset and typed "processed_data".
data_art = wandb.Artifact(name=dataset_name, type="processed_data")

In [None]:
# Local image folder (./<dataset_name>); created if it does not already exist.
image_folder = Path(".", dataset_name)
image_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Add the processed CSV, plus the image folder's contents under the "<dataset_name>" prefix.
# NOTE(review): the CSV's local_image_location holds bare file names, not paths under that prefix.
data_art.add_file(df_path)
data_art.add_dir(image_folder, name=dataset_name)

In [None]:
# Upload the processed artifact as an output of this run.
run.log_artifact(data_art)

In [None]:
# Mark the W&B run as finished.
run.finish()