In [1]:
#Add repo path to the system path
from pathlib import Path
import os, sys
# Locate the repository root: the closest directory (starting at the current
# working directory and moving upwards) that contains a '.gitignore' entry.
# The search is bounded by the filesystem root, so it cannot loop forever when
# the notebook is run outside of a repository.
start_dir = Path.cwd().resolve()
repo_path = next(
    (candidate for candidate in (start_dir, *start_dir.parents) if (candidate / '.gitignore').exists()),
    None,
)
if repo_path is None:
    raise FileNotFoundError(f"No '.gitignore' found in {start_dir} or any of its parent directories.")
# Make the repository importable; skip the insert if the entry is already there (cell re-runs)
if str(repo_path) not in sys.path:
    sys.path.insert(0, str(repo_path))

import csv
import pandas as pd
from torch.utils.data import Dataset

# Vanilla prompt

We want to extract specific information from the general metadata CSV file. This can be done by creating a function that takes the folder of interest as input, creates a CSV file listing the contents of the folder (the filenames), and then creates a new metadata CSV by looking up the IDs from the filenames CSV in the general metadata CSV.

In [3]:
def create_folder_csv(folder_dir:Path, image_extension: str) -> Path:
    """Creates (if missing) a csv file with the names of the image files in the folder.

    The csv is written to ``<folder_dir>/../../filenames/<folder name>.csv`` with one
    filename per row and no header row. If that file already exists it is returned
    untouched, so it can be stale when the folder contents changed after it was written.

    Args:
        folder_dir (Path): images folder
        image_extension (str): extension to keep, with or without the leading dot
            (png, jpg, etc.). The comparison is case-sensitive.

    Returns:
        Path: path of the csv file holding the filenames

    Raises:
        FileNotFoundError: if ``folder_dir`` does not exist (no csv is created in that case)
    """
    # get folder name from directory
    folder_name = folder_dir.name
    # accept both 'png' and '.png'
    suffix = f".{image_extension.lstrip('.')}"
    # check if the csv file with the filenames already exists
    csv_path = folder_dir.parent.parent / 'filenames' / f'{folder_name}.csv'
    if not csv_path.exists(): # if not, create it
        # List the folder BEFORE opening the csv: if the folder is missing we must not
        # leave an empty csv behind, because it would be reused as a valid cache later.
        # os.listdir returns entries in arbitrary order, so sort for a reproducible file.
        filenames = sorted(name for name in os.listdir(folder_dir) if name.endswith(suffix))
        # the 'filenames' directory may not exist yet
        csv_path.parent.mkdir(parents=True, exist_ok=True)
        # utf-8 matches the default encoding pandas uses when reading the file back
        with open(csv_path, mode='w', newline='', encoding='utf-8') as file:
            csv.writer(file).writerows([name] for name in filenames)
    return csv_path

def subset_csv(files_folder:Path, reference_folder:Path, image_extension: str = 'png') -> Path:
    """Creates a metadata csv restricted to the images present in the files folder.

    The image ids are the filenames of ``files_folder`` without their extension. Rows of
    the reference metadata whose ``image_id`` is one of those ids are written to
    ``<files_folder>/../../metadata/<files folder name>.csv``.

    Args:
        files_folder (Path): folder holding the image files
        reference_folder (Path): path of the general metadata csv file (despite the
            parameter name it is read as a csv file, not as a folder); it must have an
            ``image_id`` column
        image_extension (str, optional): extension of the image files, with or without
            the leading dot. Defaults to 'png'.

    Returns:
        Path: path of the saved subset csv
    """
    extension = image_extension.lstrip('.')
    suffix = f'.{extension}'
    csv_path = create_folder_csv(files_folder, extension) # create name csv if it does not exist
    # open csv file; dtype=str keeps every name as text
    name_csv = pd.read_csv(csv_path, header=None, dtype=str)
    # set column name as filename
    name_csv.columns = ['filename']
    # Remove the extension only when it is the trailing suffix. A plain
    # str.replace('.png', '', regex=True) treats '.' as a wildcard and is not anchored,
    # so it would also strip e.g. 'apng' from the middle of an id.
    filenames = name_csv['filename']
    has_suffix = filenames.str.endswith(suffix, na=False)
    name_csv['filename'] = filenames.where(~has_suffix, filenames.str[:-len(suffix)])
    # open general metadata csv file
    general_csv = pd.read_csv(reference_folder, header=0)
    # create new csv only with the filenames in the folder; ids are compared as text so
    # that an image_id column parsed as numbers still matches the filename strings
    new_csv = general_csv[general_csv['image_id'].astype(str).isin(name_csv['filename'])]
    # save new csv (the 'metadata' directory may not exist yet)
    save_path = files_folder.parent.parent / 'metadata' / f'{files_folder.name}.csv'
    save_path.parent.mkdir(parents=True, exist_ok=True)
    new_csv.to_csv(save_path, index=False)

    return save_path

In [3]:
# Build the metadata subset csv for the chosen image folder
folder_name = 'breast10p'
files_folder = repo_path.joinpath('data/images', folder_name)
reference_folder = repo_path.joinpath('dataset_analysis/metadata/metadata_Hologic.csv')
metadata_path = subset_csv(files_folder=files_folder, reference_folder=reference_folder)

In [4]:
# Load the subset metadata and keep only the image id and view position columns
metadata = pd.read_csv(metadata_path, header=0).loc[:, ['image_id', 'view_position']]
metadata

Unnamed: 0,image_id,view_position
0,1.2.826.0.1.3680043.9.3218.1.1.162126803.1843....,MLO
1,1.2.826.0.1.3680043.9.3218.1.1.162126803.1843....,CC
2,1.2.826.0.1.3680043.9.3218.1.1.162126803.1843....,CC
3,1.2.826.0.1.3680043.9.3218.1.1.162126803.1843....,MLO
4,1.2.826.0.1.3680043.9.3218.1.1.2984797.8432.15...,MLO
...,...,...
4054,1.2.826.0.1.3680043.9.3218.1.1.40099880.7359.1...,CC
4055,1.2.826.0.1.3680043.9.3218.1.1.400998809.7359....,MLO
4056,1.2.826.0.1.3680043.9.3218.1.1.400998809.7359....,MLO
4057,1.2.826.0.1.3680043.9.3218.1.1.400998809.7359....,MLO


# Analyze dreambooth dataset

In [4]:
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Every item is a dict with the transformed instance image and the token ids of the
    instance prompt. When ``class_data_root`` is given, the item also holds a transformed
    class (prior) image and the token ids of the class prompt.

    NOTE(review): the class uses the names ``transforms`` and ``Image``, which are not
    imported in the visible notebook cells - presumably ``torchvision.transforms`` and
    ``PIL.Image``; confirm they are imported before this cell is run.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
    ):
        """Lists the image files and builds the image transforms.

        Args:
            instance_data_root (str or Path): folder with the instance images
            instance_prompt (str): prompt used for every instance image
            tokenizer: callable tokenizer exposing ``model_max_length`` and returning
                an object with an ``input_ids`` attribute
            class_data_root (str or Path, optional): folder with the class (prior) images.
                It is created if it does not exist. Defaults to None (no class images).
            class_prompt (str, optional): prompt used for every class image
            class_num (int, optional): maximum number of class images to use
            size (int, optional): side of the square crop. Defaults to 512.
            center_crop (bool, optional): center crop instead of random crop. Defaults to False.

        Raises:
            ValueError: if the instance folder does not exist or has no files, or if
                class images are requested but none are available
        """
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

        self.instance_images_path = self._list_files(self.instance_data_root) # sorted file paths in the folder
        self.num_instance_images = len(self.instance_images_path) # number of images in the folder
        if self.num_instance_images == 0:
            # fail here with a clear message instead of a ZeroDivisionError in __getitem__
            raise ValueError(f"No instance images found in {self.instance_data_root}.")
        self.instance_prompt = instance_prompt # prompt for the instance images
        self._length = self.num_instance_images # length of the dataset

        if class_data_root is not None: # if there are prior images
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True) # create the folder if it doesn't exist
            self.class_images_path = self._list_files(self.class_data_root) # sorted file paths of the class images
            if class_num is not None: # use at most class_num of the available class images
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            if self.num_class_images < 1:
                # __getitem__ indexes with `index % num_class_images`, so zero cannot work
                raise ValueError(f"No class images available in {self.class_data_root}.")
            # the dataset is as long as the larger of the two image sets
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose( # classic image transforms for diffusion models
            [ # resize, crop to a size x size square, convert to tensor, then normalize with mean 0.5 / std 0.5
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    @staticmethod
    def _list_files(folder):
        """Returns the sorted paths of the regular files directly inside ``folder``."""
        # iterdir() yields entries in arbitrary order, so sort to make the index -> image
        # mapping reproducible; sub-directories (e.g. .ipynb_checkpoints) are not images.
        return sorted(path for path in Path(folder).iterdir() if path.is_file())

    def _load_image(self, path):
        """Opens the image at ``path``, converts it to RGB if needed and applies the transforms."""
        image = Image.open(path)
        if not image.mode == "RGB": # convert to RGB if not already
            image = image.convert("RGB")
        return self.image_transforms(image)

    def _tokenize(self, prompt):
        """Returns the token ids of ``prompt``, truncated and padded to the tokenizer max length."""
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length", # pad to max length
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    def __len__(self):
        return self._length # the max of the number of instance images and the number of class images

    def __getitem__(self, index):
        """returns example dictionary

        Args:
            index (int): index of the example

        Returns:
            dict: example dictionary with keys "instance_images", "instance_prompt_ids" and,
                when class images are used, "class_images", "class_prompt_ids"
        """
        example = {}
        # the modulo wraps around when there are more class images than instance images
        instance_path = self.instance_images_path[index % self.num_instance_images]
        example["instance_images"] = self._load_image(instance_path)
        example["instance_prompt_ids"] = self._tokenize(self.instance_prompt)

        if self.class_data_root: # if there are class images
            # same wrap-around idea as above
            class_path = self.class_images_path[index % self.num_class_images]
            example["class_images"] = self._load_image(class_path)
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example