In [None]:
%%capture
!pip install -U wandb
!pip install -U pandas # upgrade pandas
# !pip install -U dask["complete"]
!pip install swifter # first time installation
!pip install swifter[groupby]
# !pip install wrapt_timeout_decorator

In [None]:
import wandb
# Authenticate this kernel with Weights & Biases; the run, the raw-artifact
# download and the processed-artifact upload in later cells all go through W&B.
wandb.login()

In [None]:
# W&B run identity: project, human-readable run name and searchable tags.
project_name, run_name = "krea-open-prompts", "process-open-prompts-sd"
tags = ["download", "stable_diffusion", "process"]

# Config logged with the run. "dataset" selects which CSV inside the raw
# artifact is processed: either "prompts" or "sample_prompts".
_config = dict(dataset="sample_prompts")

In [None]:
# Start the W&B run that owns the artifact download/upload below.
run = wandb.init(
    project=project_name,
    name=run_name,
    config=_config,
    tags=tags,
)

In [None]:
# Read settings back from run.config rather than _config: W&B may override
# config values (e.g. when the run is part of a sweep).
config = run.config
dataset_name = config["dataset"]

# dataset_name is used below as a CSV file name, a local directory and an
# artifact name, so reject anything other than the two CSVs the raw artifact
# is documented to provide.
_VALID_DATASETS = ("prompts", "sample_prompts")
if dataset_name not in _VALID_DATASETS:
    raise ValueError(
        f"config['dataset'] must be one of {_VALID_DATASETS}, got {dataset_name!r}"
    )

In [None]:
# Declare the raw-data artifact as an input of this run, then download only the
# selected CSV entry from it.
# NOTE(review): ':latest' resolves to whichever version is newest at run time;
# pin an explicit version (e.g. ':v3') if results must be reproducible.
art = run.use_artifact('open-prompts-sd:latest', type='raw_data')
raw_csv_entry = art.get_path(f"{dataset_name}.csv")
dataset_path = raw_csv_entry.download()

In [None]:
# import ray
# ray.init(ignore_reinit_error=True, dashboard_host="0.0.0.0", include_dashboard=True)

In [None]:
import pandas as pd
import swifter
# import modin.pandas as mpd
# from modin.config import ProgressBar
# ProgressBar.enable()

In [None]:
# Raw table: its "raw_data" (JSON text) and "prompt" columns are consumed below.
raw_df = pd.read_csv(filepath_or_buffer=dataset_path)

In [None]:
import json

In [None]:
from pandas import json_normalize

In [None]:
def load_and_flatten_json(record):
    """Parse one JSON string and flatten nested objects into a single dict.

    Nested keys are joined with "_" (e.g. ``{"a": {"b": 1}}`` -> ``{"a_b": 1}``);
    this is presumably how columns such as "raw_discord_data_image_uri" arise.

    Args:
        record: a JSON document as a string.

    Returns:
        The first row produced by ``json_normalize``, as a dict. A top-level
        JSON array yields one row per element, and only the first is returned.
    """
    rows = json_normalize(json.loads(record), sep="_").to_dict("records")
    return rows[0]

In [None]:
# Parse + flatten every JSON blob (one dict per row); the union of the dicts'
# keys becomes df's columns. No intermediate name is kept, so the list of dicts
# can be freed once df is built.
df = pd.DataFrame(
    list(raw_df["raw_data"].swifter.apply(load_and_flatten_json))
)

In [None]:
# Column assignment aligns on index labels, not position: a mismatch between
# df and raw_df would silently fill prompts with NaN. df was built from a plain
# list (fresh RangeIndex), so check it still lines up row-for-row with raw_df.
assert df.index.equals(raw_df.index), "flattened rows no longer line up with raw_df"
df["prompt"] = raw_df["prompt"]

In [None]:
import gc

In [None]:
# Free the raw frame now that its data lives in `df` (flattened JSON + prompt).
# NOTE: not idempotent — re-running this cell, or any earlier cell that reads
# raw_df, requires re-running the read_csv cell first.
del raw_df
gc.collect()

In [None]:
# "modifiers" is not among the exported columns, so drop it early.
df = df.drop(columns=["modifiers"])

In [None]:
from PIL import Image
import requests

In [None]:
from pathlib import Path
import os

In [None]:
# dataset_name = "test_modin"

In [None]:
# One local directory per dataset (relative to the working directory);
# downloaded images are written here.
image_folder = Path(dataset_name)
image_folder.mkdir(exist_ok=True, parents=True)

In [None]:
# Per-request timeout, in seconds, for the image downloads below (requests.get).
default_timeout = 5

In [None]:
def _verify_image(path):
    """Raise if Pillow cannot identify `path` as an image or finds it broken."""
    # Image.open only parses the header; verify() additionally checks the file
    # contents (how thoroughly depends on the format) without decoding pixels.
    with Image.open(path) as img:
        img.verify()


def download_image_files(image_url, folder=None, timeout=None):
    """Download one image to disk and confirm Pillow can read it.

    An image already present under the target name is reused if it passes
    verification, so re-running the cell resumes an interrupted run instead of
    downloading everything again. A cached file that fails verification is
    deleted and downloaded afresh.

    Args:
        image_url: URL of the image. Non-string or empty values (e.g. NaN for a
            row where the field is missing) are reported as failures.
        folder: destination directory; defaults to the notebook-level
            `image_folder`.
        timeout: seconds passed to `requests.get`; defaults to the
            notebook-level `default_timeout`.

    Returns:
        `(str(local_path), True)` on success, `(None, False)` on any failure.
        Errors are printed rather than raised, so one bad URL cannot abort the
        apply over the whole column.
    """
    from urllib.parse import urlsplit

    if not isinstance(image_url, str) or not image_url:
        print(f"Skipping invalid image URL: {image_url!r}")
        return (None, False)

    # Use the last path segment of the URL; the query string and fragment are
    # ignored so they never end up in the file name.
    file_name = Path(urlsplit(image_url).path).name
    if file_name in ("", ".."):
        print(f"Cannot derive a file name from URL: {image_url}")
        return (None, False)

    folder = Path(image_folder if folder is None else folder)
    timeout = default_timeout if timeout is None else timeout
    # NOTE(review): files are keyed by URL basename only, so two different URLs
    # ending in the same file name would share (and reuse) one local file —
    # confirm basenames are unique in this dataset.
    file_path = folder / file_name
    tmp_path = file_path.with_name(file_name + ".part")
    try:
        # Allows retries without re-downloading files that are already good.
        if file_path.exists():
            try:
                _verify_image(file_path)
                return (str(file_path), True)
            except Exception as e:
                print(f"Re-downloading unreadable cached file {file_path}: {e}")
                file_path.unlink()

        response = requests.get(image_url, timeout=timeout)
        # Fail on HTTP errors instead of saving an error page to disk.
        response.raise_for_status()
        # Write under a temporary name and move it into place only once it
        # verifies, so an interrupted run never leaves a broken file under the
        # final name (which the cache check above would otherwise reuse).
        tmp_path.write_bytes(response.content)
        _verify_image(tmp_path)
        os.replace(tmp_path, file_path)
        return (str(file_path), True)
    except Exception as e:
        print(f"{image_url}: {e}")
        # Remove any partial download so a later retry starts clean.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        return (None, False)

In [None]:
# from tqdm import tqdm
# tqdm.pandas()

In [None]:
#BUG: check if modin actually helps here
# df = mpd.DataFrame(df)

In [None]:
from tqdm import tqdm

In [None]:
# Download every image (one call per row); each result is a (path, ok) tuple.
image_urls = df["raw_discord_data_image_uri"]
responses = image_urls.swifter.apply(download_image_files)

In [None]:
# Split each (path, ok) result tuple into two parallel columns, row for row.
local_paths, success_flags = [], []
for local_path, succeeded in responses:
    local_paths.append(local_path)
    success_flags.append(succeeded)
df["local_image_location"] = local_paths
df["image_download_success"] = success_flags

In [None]:
# (rows, columns) of the full table, before failed downloads are filtered out.
df.shape

In [None]:
# Keep only the rows whose image was downloaded and could be opened.
processed_df = df.loc[df["image_download_success"]]

In [None]:
# (rows, columns) after keeping only successful downloads; compare with df.shape.
processed_df.shape

In [None]:
# All flattened columns available before choosing the export subset.
processed_df.columns

In [None]:
# Spot-check the first kept row across all columns.
processed_df.iloc[0]

In [None]:
# Spot-check the local file path recorded for the first kept row.
processed_df["local_image_location"].iloc[0]

In [None]:
# Columns written to the processed CSV, in output order: the prompt, the local
# image path, then the generation settings taken from the flattened
# "raw_discord_data_*" fields.
selected_columns = ["prompt", "local_image_location"] + [
    "raw_discord_data_" + field
    for field in (
        "seed",
        "width",
        "height",
        "is_grid",
        "num_step",
        "cfg_scale",
        "timestamp",
        "num_generations",
    )
]

In [None]:
# Narrow to the export columns, in the listed order (KeyError if any is missing).
processed_df = processed_df.loc[:, selected_columns]

In [None]:
# Spot-check the first row of the exported column subset.
processed_df.iloc[0]

In [None]:
# Write the filtered table to the working directory; it is uploaded together
# with the images below.
processed_df_path = dataset_name + ".csv"
processed_df.to_csv(path_or_buf=processed_df_path, index=False)

In [None]:
# Output artifact, named after the dataset and typed separately from the
# "raw_data" artifact it was derived from.
processed_data_art = wandb.Artifact(
    type="processed_data",
    name=dataset_name,
)

In [None]:
# Bundle the processed CSV and every file under the local image directory. The
# directory is stored under "<dataset_name>/" inside the artifact, presumably so
# the CSV's relative local_image_location values resolve from the artifact root.
# NOTE(review): add_dir uploads everything in image_folder, not only the images
# referenced by processed_df.
processed_data_art.add_file(processed_df_path)
processed_data_art.add_dir(image_folder, name=dataset_name)

In [None]:
# Upload the artifact and record it as an output of this run.
run.log_artifact(processed_data_art)

In [None]:
# Mark the W&B run as finished so any remaining data is uploaded.
run.finish()