# How to prepare the dataset

In this notebook, I will demonstrate how to prepare a dataset for fine-tuning the InstructPix2Pix Stable Diffusion model. You will need a free [Hugging Face](https://hf.co) account to store the dataset. You do not need an AWS account or any other cloud service to run this notebook; everything here is executed on my local machine.

## Install Dependencies

You will need the following Python packages:

- `huggingface-hub`
- `datasets`
- `pillow`

In [None]:
! pip install --no-cache-dir huggingface-hub datasets pillow

## Login to Huggingface

Obtain an access token with write permission by following these steps:

- Log in to your Hugging Face account.
- Go to Settings -> Access Tokens.
- Click on New Token. Provide a Name for your token and select Write as the Type for full access.
- Copy the token.


In [None]:
! huggingface-cli login --token=<paste-your-token>

## Organize the Images

- Store all your **original** images in a directory named `original`.
- Store all your **edited** images in a directory named `edited`.
- Store all prompts in a single text file named `prompts.txt`, one prompt per line, where each line corresponds to its respective image pair.
- `original/image_1.jpg` should correspond to `edited/image_1.jpg`, and so on.
- Accordingly, the prompt on line 1 of `prompts.txt` should be the prompt used to train the model on the first image pair, line 2 on the second pair, and so on.
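
The naming convention above can be checked before building anything. The following is a minimal sketch, not part of the original workflow: the helper name `find_unpaired` and the throwaway demo files are assumptions made for illustration. It reports file names that exist in only one of the two directories.

```python
import tempfile
from pathlib import Path


def find_unpaired(original_dir: Path, edited_dir: Path) -> tuple[list[str], list[str]]:
    """Return the file names that are present in only one of the two directories."""
    original_names = {p.name for p in original_dir.iterdir() if p.is_file()}
    edited_names = {p.name for p in edited_dir.iterdir() if p.is_file()}
    only_original = sorted(original_names - edited_names)
    only_edited = sorted(edited_names - original_names)
    return only_original, only_edited


# Demo on a throwaway directory tree (the file names are made up).
root = Path(tempfile.mkdtemp())
(root / "original").mkdir()
(root / "edited").mkdir()
for name in ("image_1.jpg", "image_2.jpg"):
    (root / "original" / name).touch()
(root / "edited" / "image_1.jpg").touch()

missing_edits, missing_originals = find_unpaired(root / "original", root / "edited")
print("originals without an edited counterpart:", missing_edits)
print("edited images without an original:", missing_originals)
```

On your own data, you would call `find_unpaired(Path("original"), Path("edited"))` and expect both lists to be empty.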


In [None]:
# Paths to directories
from pathlib import Path
ORIGINAL_IMAGES = Path("original")
EDITED_IMAGES = Path("edited")
PROMPTS = Path("prompts.txt")

## Validation

In [None]:
# check that the directories and the prompts file exist
if not ORIGINAL_IMAGES.exists():
    raise FileNotFoundError(f"Directory: {ORIGINAL_IMAGES.absolute()} not found")
if not EDITED_IMAGES.exists():
    raise FileNotFoundError(f"Directory: {EDITED_IMAGES.absolute()} not found")
if not PROMPTS.exists():
    raise FileNotFoundError(f"File: {PROMPTS.absolute()} not found")

# check if directory contains images
ORIGINAL_IMAGES_COUNT = len(list(ORIGINAL_IMAGES.iterdir()))
EDITED_IMAGES_COUNT = len(list(EDITED_IMAGES.iterdir()))

if ORIGINAL_IMAGES_COUNT == 0:
    raise FileNotFoundError(f"Directory: {ORIGINAL_IMAGES.absolute()} does not contain any images")
else:
    print(f"original images: {ORIGINAL_IMAGES_COUNT}")
    
if EDITED_IMAGES_COUNT == 0:
    raise FileNotFoundError(f"Directory: {EDITED_IMAGES.absolute()} does not contain any images")
else:
    print(f"edited images: {EDITED_IMAGES_COUNT}")
    
if not (ORIGINAL_IMAGES_COUNT == EDITED_IMAGES_COUNT):
    raise ValueError("Mismatch in the number of images in original and edited images")
    
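# check if prompts.txt is empty
# (trailing newlines are stripped and blank lines are skipped while reading)
with open(PROMPTS, "r", encoding="utf-8") as fp:
    prompts = [line.strip() for line in fp if line.strip()]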

if len(prompts) == 0:
    raise ValueError(f"File: {PROMPTS.absolute()} does not contain any prompts")
elif not (len(prompts) == EDITED_IMAGES_COUNT):
    raise ValueError("The number of images does not match the number of prompts")
else:
    print(f"Prompts: {len(prompts)}")
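
Note that the counts above include every entry in each directory, so a stray file such as `.DS_Store` would be counted as an image. As a hedged sketch of one way to tighten this, the hypothetical helper `list_images` below keeps only files whose extension is in an assumed set (`IMAGE_SUFFIXES`, which you should adjust to your data) and returns them sorted by name.

```python
import tempfile
from pathlib import Path

# Assumed set of image extensions; adjust it to match your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def list_images(directory: Path) -> list[Path]:
    """Return the image files in `directory`, sorted by name."""
    return sorted(
        p for p in directory.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


# Demo on a throwaway directory containing two images and two other files.
demo_dir = Path(tempfile.mkdtemp())
for name in ("b.PNG", "a.jpg", ".DS_Store", "notes.txt"):
    (demo_dir / name).touch()

found = [p.name for p in list_images(demo_dir)]
print(found)
```

Using `len(list_images(ORIGINAL_IMAGES))` in place of `len(list(ORIGINAL_IMAGES.iterdir()))` would then count images only.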


## Load the Data

In [None]:
from PIL import Image
import datasets

In [None]:
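def load_samples(original_images_path: list[Path], edited_images_path: list[Path], prompts_list: list[str]):
    # iterdir() yields entries in arbitrary order, so sort both listings to pair
    # files with the same name. Sorting is lexicographic: zero-pad file names
    # (image_01, image_02, ...) to keep them aligned with the prompt lines.
    original_images_path = sorted(original_images_path)
    edited_images_path = sorted(edited_images_path)

    original_images: list[Image.Image] = []
    edited_images: list[Image.Image] = []

    for orig_img, edit_img in zip(original_images_path, edited_images_path):
        # load images
        original_images.append(Image.open(orig_img.absolute()))
        edited_images.append(Image.open(edit_img.absolute()))

    # format the dataset (built once, after all images are loaded)
    dataset_json = {
        "before": original_images,
        "after": edited_images,
        "prompt": [prompt.strip() for prompt in prompts_list]
    }

    # describe the column types
    features = datasets.Features({
        "before": datasets.Image(),
        "after": datasets.Image(),
        "prompt": datasets.Value('string')
    })

    # build the dataset
    return datasets.Dataset.from_dict(dataset_json, features=features)


ip2p_dataset = load_samples(list(ORIGINAL_IMAGES.iterdir()), list(EDITED_IMAGES.iterdir()), prompts)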

## Upload Dataset to Huggingface Hub

In [None]:
REPO_ID = "arnabdhar/instruct-pix2pix-dataset"

ip2p_dataset.push_to_hub(
    repo_id = REPO_ID,
    split = 'train',
    private = True
)