## Import Libraries

In [1]:
import os
import json
import torch
import requests
import numpy as np
from PIL import Image
from io import BytesIO
from tqdm import tqdm
from transformers import AutoTokenizer, AutoProcessor
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

## ImageDataLoaderWithInstruction

In [2]:
def check(directory, filename):
    """
    Tell whether an entry called `filename` is listed directly inside `directory`.

    Args:
    - directory (str): Path of the folder to inspect.
    - filename (str): Entry name to look for among the folder's listing.

    Returns:
    - bool: True if `directory` is an existing folder whose listing contains
      `filename`, False otherwise (including when `directory` is not a folder).
    """
    # Short-circuit keeps os.listdir from being called on a non-directory.
    return os.path.isdir(directory) and filename in os.listdir(directory)

class ImageDataLoaderWithInstruction():
    
    """
    A class for loading images and associated instructions from a given JSON file. The class
    packs the processed input images, tokenized instructions and processed output images into
    a TensorDataset that can later be wrapped in a DataLoader.

    Attributes:
    - directory (str): Directory where the JSON file resides.
    - filename (str): Name of the JSON file to read from.
    - batch_size (int): Batch size stored for the caller; not used by this class itself.
    - processor (AutoProcessor): Processor for the CLIP model or similar models.
    - tokenizer (AutoTokenizer): Tokenizer for the CLIP model or similar models.
    - json_data, max_len, image_data: Populated by `prepare_dataloader`.
    """
    
    def __init__(self, directory, filename, processor, tokenizer, batch_size = 32):
        """
        Initialize the ImageDataLoaderWithInstruction with directory, filename, processor, tokenizer, and batch_size.

        Args:
        - directory (str): Directory where the JSON file resides.
        - filename (str): Name of the JSON file.
        - processor (AutoProcessor): Processor for the CLIP model or similar models.
        - tokenizer (AutoTokenizer): Tokenizer for the CLIP model or similar models.
        - batch_size (int, optional): Size of each batch in the DataLoader. Defaults to 32.
        """
        self.directory = directory
        self.filename = filename
        self.batch_size = batch_size
        self.processor = processor
        self.tokenizer = tokenizer
    
    def load_json_from_directory(self):
        """
        Load the JSON data from the specified directory and filename.

        Returns:
        - Loaded JSON data if the file exists, or an empty dictionary (after printing
          a message) if it does not.
        """
        if check(self.directory, self.filename):
            # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
            with open(os.path.join(self.directory, self.filename), 'r', encoding='utf-8') as json_file:
                return json.load(json_file)
        print(f"'{self.filename}' does not exist in the specified directory.")
        return {}
    
    def compute_max_instruction_length(self):
        """
        Compute the maximum instruction length (in tokens, as produced by
        `tokenizer.tokenize`) over `self.json_data`.

        Returns:
        - int: Maximum instruction length, or 0 when there is no data.
        """
        return max(
            (len(self.tokenizer.tokenize(item['instruction'])) for item in self.json_data),
            default=0,
        )
    
    def _process_image(self, image_path):
        """
        Open one image file, run it through the processor and return its "pixel_values" tensor.

        The file is opened in a `with` block so its handle is released as soon as the
        processor has consumed the pixels, instead of lingering until garbage collection.
        """
        with Image.open(image_path) as image:
            return self.processor(images=image, return_tensors="pt")["pixel_values"]
    
    def load_images_from_json(self):
        """
        Load images and their corresponding instructions from the JSON data. Images are processed 
        using the processor, and paths are constructed as
        `<directory>/images/<prefix of item['input'] before the first '-'>/<file name>`.

        Returns:
        - dict: Dictionary with keys 'input', 'output', and 'instruction'. Each maps the
          input image path to the processed input image, processed output image and raw
          instruction string respectively.
        """
        image_data = {'input': {}, 'output': {}, 'instruction': {}}

        for item in tqdm(self.json_data, desc="Processing images"):
            folder_name = item['input'].split('-')[0]
            image_folder = os.path.join(self.directory, "images", folder_name)
            input_image_path = os.path.join(image_folder, item['input'])
            output_image_path = os.path.join(image_folder, item['output'])

            input_image = self._process_image(input_image_path)
            output_image = self._process_image(output_image_path)

            # NOTE(review): all three maps are keyed by the input image path, so two JSON
            # entries sharing the same input image would overwrite each other -- confirm
            # that input paths are unique in the data.
            image_data['input'][input_image_path] = input_image
            image_data['output'][input_image_path] = output_image
            image_data['instruction'][input_image_path] = item['instruction']

        return image_data
    
    def prepare_dataloader(self):
        """
        Build a TensorDataset from the images and instructions loaded from the JSON. Images are
        stored as tensors and instructions are tokenized and padded to a common length.

        Despite its name, this method returns only the dataset; wrapping it in a DataLoader
        (e.g. with `self.batch_size`) is left to the caller.

        Returns:
        - TensorDataset: Dataset containing input images, tokenized instructions, and output images
          (in that order).

        Raises:
        - RuntimeError: If no JSON entries could be loaded.
        """
        self.json_data = self.load_json_from_directory()
        if not self.json_data:
            # Fail with a clear message instead of the obscure torch.cat error on an empty list.
            raise RuntimeError(
                f"No entries loaded from '{self.filename}' in '{self.directory}'; cannot build a dataset."
            )
        self.max_len = self.compute_max_instruction_length()
        self.image_data = self.load_images_from_json()

        train_input_imgs = []
        train_output_imgs = []
        input_ids = []

        for key, input_image in self.image_data['input'].items():
            train_input_imgs.append(input_image)
            train_output_imgs.append(self.image_data['output'][key])

            # max_len counts tokens without special tokens; the +10 margin leaves room for
            # them, so nothing is truncated. `padding="max_length"` alone is sufficient --
            # the deprecated `pad_to_max_length` flag is intentionally not passed.
            encoded_dict = self.tokenizer(
                self.image_data['instruction'][key],
                add_special_tokens=True,
                max_length=self.max_len + 10,
                padding="max_length",
                return_tensors="pt",
            )
            input_ids.append(encoded_dict["input_ids"].squeeze(dim=0))

        # Images are concatenated along dim 0 (presumably each is (1, C, H, W) -- TODO confirm),
        # while the squeezed token id vectors are stacked into a new leading dimension.
        train_input_imgs = torch.cat(train_input_imgs, dim=0)
        train_output_imgs = torch.cat(train_output_imgs, dim=0)
        input_ids = torch.stack(input_ids, dim=0)

        train_dataset = TensorDataset(train_input_imgs, input_ids, train_output_imgs)
        
        return train_dataset



In [3]:
# Checkpoint identifiers for the image processor and the text tokenizer.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
BERT_CHECKPOINT = "bert-base-uncased"

# Image preprocessing comes from the CLIP checkpoint, text tokenization from BERT.
processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT)

## Train DataLoader

In [4]:
# Build the training TensorDataset from train/edit_turns.json.
directory_path = "train"
filename = "edit_turns.json"
train_loader_instance = ImageDataLoaderWithInstruction(
    directory=directory_path,
    filename=filename,
    processor=processor,
    tokenizer=tokenizer,
    batch_size=32,
)
train_dataset = train_loader_instance.prepare_dataloader()

Processing images:   0%|                       | 2/8807 [00:00<07:33, 19.39it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Processing images: 100%|████████████████████| 8807/8807 [15:27<00:00,  9.50it/s]


## Dev DataLoader

In [5]:
# Build the development TensorDataset from dev/edit_turns.json.
directory_path = "dev"
filename = "edit_turns.json"
dev_loader_instance = ImageDataLoaderWithInstruction(
    directory=directory_path,
    filename=filename,
    processor=processor,
    tokenizer=tokenizer,
    batch_size=32,
)
dev_dataset = dev_loader_instance.prepare_dataloader()

Processing images: 100%|██████████████████████| 528/528 [00:57<00:00,  9.23it/s]


## Test DataLoader

In [9]:
# Build the test TensorDataset from test/edit_turns.json.
directory_path = "test"
filename = "edit_turns.json"
test_loader_instance = ImageDataLoaderWithInstruction(
    directory=directory_path,
    filename=filename,
    processor=processor,
    tokenizer=tokenizer,
    batch_size=32,
)
test_dataset = test_loader_instance.prepare_dataloader()

Processing images: 100%|████████████████████| 1053/1053 [00:45<00:00, 23.34it/s]


## Save Datasets and Params

In [6]:
def save_dataloader_components(dataset, batch_size, dataset_filename):
    """
    Save the TensorDataset to disk with `torch.save`.

    Only the dataset itself is written; no DataLoader parameters are saved.

    Args:
    - dataset (TensorDataset): The dataset you want to save.
    - batch_size (int): Currently unused; kept so existing calls keep working.
    - dataset_filename (str): Path of the file to write the dataset to.

    Returns:
    None
    """
    # Save the TensorDataset
    torch.save(dataset, dataset_filename)

In [7]:
# Persist the training dataset; arguments are (dataset, batch_size, dataset_filename).
train_dataset_filename = 'train_dataset_kandisky_bert_magicbrush.pth'
save_dataloader_components(train_dataset, 32, train_dataset_filename)

In [11]:
# Persist the development dataset; arguments are (dataset, batch_size, dataset_filename).
dev_dataset_filename = 'dev_dataset_kandisky_bert_magicbrush.pth'
save_dataloader_components(dev_dataset, 32, dev_dataset_filename)

In [12]:
# Persist the test dataset; arguments are (dataset, batch_size, dataset_filename).
test_dataset_filename = 'test_dataset_kandisky_bert_magicbrush.pth'
save_dataloader_components(test_dataset, 32, test_dataset_filename)