In [None]:
import json
import random
import os
import re

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

def display_images(images, img_size):
    """Render a sequence of images side by side in a single row.

    Args:
        images: Sequence of objects that ``np.array`` can convert to an image
            array (the notebook passes PIL images).
        img_size: Edge length, in inches, allotted to each image; the figure
            is ``img_size * len(images)`` wide and ``img_size`` tall.

    Returns:
        None. The figure is shown via ``plt.show()``. An empty ``images``
        sequence is a no-op.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots rejects ncols=0, and there is nothing to draw anyway.
        return

    # squeeze=False always returns a 2-D array of axes, so a single image
    # goes through the same code path as several images.
    _, axes = plt.subplots(
        nrows=1,
        ncols=num_images,
        figsize=(img_size * num_images, img_size),
        squeeze=False,
    )
    for ax, image in zip(axes.flatten(), images):
        ax.axis('off')
        ax.imshow(np.array(image))
    plt.show()

def extract_path_and_convert_token(input_data, img_dir):
    """Collect tagged image paths from a prompt and swap each tag for ``<image>``.

    Every ``<img_path>...<img_path>`` span in ``input_data`` contributes one
    path (joined onto ``img_dir``) and is replaced by the literal ``<image>``.

    Args:
        input_data: Prompt text that may contain tagged image paths.
        img_dir: Directory prefix joined in front of every extracted path.

    Returns:
        A ``(converted_text, img_paths)`` tuple, with paths in order of
        appearance.
    """
    tagged_path = re.compile(r'<img_path>(.*?)<img_path>')

    img_paths = []
    for match in tagged_path.finditer(input_data):
        img_paths.append(os.path.join(img_dir, match.group(1)))

    input_data_converted = tagged_path.sub('<image>', input_data)
    return input_data_converted, img_paths

class InstructionDataset():
    """In-memory instruction dataset loaded from a JSON file.

    Shuffling and integer indexing assume the top-level JSON value is a list;
    ``__getitem__`` reads the ``'input'`` and ``'output'`` keys of each sample.

    Attributes are set in ``__init__``:
        data: The parsed JSON content (shuffled in place when requested).
        total_samples: ``len(data)`` at load time.
        image_dir_path: Directory prefix applied to extracted image paths.
        json_path: Base name of the JSON file that was loaded.
    """

    def __init__(self, json_path, image_dir_path, shuffle=True, seed=None):
        """Load the samples and optionally shuffle them.

        Args:
            json_path: Path to the JSON file, read as UTF-8.
            image_dir_path: Directory prefix for image paths found in inputs.
            shuffle: When true, shuffle the loaded samples in place.
            seed: Optional seed for a reproducible shuffle. When ``None``
                (the default) the module-level ``random`` state is used,
                as before.
        """
        with open(json_path, encoding='utf-8') as f:
            self.data = json.load(f)
        self.total_samples = len(self.data)
        if shuffle:
            if seed is None:
                random.shuffle(self.data)
            else:
                # A private generator keeps the order reproducible without
                # disturbing the global random state.
                random.Random(seed).shuffle(self.data)
        self.image_dir_path = image_dir_path
        self.json_path = os.path.basename(json_path)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_text, output, img_paths)`` for the sample at ``idx``.

        The input text has its tagged image paths replaced and the extracted
        paths are returned alongside it (see ``extract_path_and_convert_token``).
        """
        sample = self.data[idx]
        output_data = sample['output']
        input_data, img_paths = extract_path_and_convert_token(sample['input'], self.image_dir_path)

        return input_data, output_data, img_paths

In [None]:

# Load one instruction dataset and preview its first few (shuffled) samples.
# The commented-out paths are alternative datasets that can be swapped in.
dataset  = InstructionDataset(
    # json_path='/cpfs/user/chendelong/open_flamingo_v2/instruction_dataset/converted_datasets/mimic_it/SN_instructions.json',
    # image_dir_path='/cpfs/shared/research-llm/instruc_data_en/multimodal_instruct_tuning/MIMIC-IT/images/SN',
    # json_path = '/cpfs/user/chendelong/open_flamingo_v2/instruction_dataset/converted_datasets/mimic_it/VST_instructions.json',
    # image_dir_path='/cpfs/shared/research-llm/instruc_data_en/multimodal_instruct_tuning/MIMIC-IT/images/VST',

    # json_path = '/cpfs/user/chendelong/open_flamingo_v2/instruction_dataset/converted_datasets/mimic_it/LACONV_instructions.json',
    # json_path = '/cpfs/user/chendelong/open_flamingo_v2/instruction_dataset/converted_datasets/mimic_it/LACR_I2I_instructions.json',
    # json_path = '/cpfs/user/chendelong/open_flamingo_v2/instruction_dataset/converted_datasets/mimic_it/LACR_T2T_instructions.json',
    # json_path = '/cpfs/user/chendelong/open_flamingo_v2/instruction_dataset/converted_datasets/mimic_it/LADD_instructions.json',
    # image_dir_path='/cpfs/shared/research-llm/instruc_data_en/multimodal_instruct_tuning/MIMIC-IT/images/LA',

    json_path = '/cpfs/shared/research-llm/instruc_data_en/multimodal_instruct_tuning/artemis/artemis-v2/dataset/combined/train/artemis_train_instruction.json',
    image_dir_path='',
    shuffle=True
    )
print(f'total samples: {len(dataset)}')
# Preview at most 10 samples; capping by the dataset size avoids an
# IndexError when the dataset holds fewer than 10 entries.
for i in range(min(10, len(dataset))):
    input_data, output_data, img_paths = dataset[i]
    if len(img_paths)>0:
        # Only samples that reference images get a figure.
        print(img_paths)
        display_images([Image.open(img_path) for img_path in img_paths], img_size=3)
    print(f'[input]\n{input_data}\n[output]\n{output_data}\n{"-"*64}')

In [None]:
# Walk every dataset listed in the config file and print a short text preview
# of each one.
dataset_config = '/cpfs/user/chendelong/open_flamingo_v2/instruction_dataset/configs/datasets.json'

# A context manager closes the config file promptly; UTF-8 matches how
# InstructionDataset reads its JSON files.
with open(dataset_config, 'r', encoding='utf-8') as f:
    datasets_info = json.load(f)
for dataset_info in datasets_info:

    # Entries whose path does not mention 'cpfs' are resolved against the
    # project root.
    if 'cpfs' not in dataset_info['json_path']:
        dataset_info['json_path'] = '/cpfs/user/chendelong/open_flamingo_v2/' + dataset_info['json_path']

    dataset  = InstructionDataset(
        json_path=dataset_info['json_path'],
        image_dir_path=dataset_info['img_dir'],
        shuffle=True
    )
    print('='*64)
    for k,v in dataset_info.items():
        print(f'{k}: {v}')
    print(f'total samples: {len(dataset)}')
    # Preview at most 3 samples; capping by the dataset size avoids an
    # IndexError on datasets with fewer than 3 entries.
    for i in range(min(3, len(dataset))):
        input_data, output_data, img_paths = dataset[i]
        # if len(img_paths)>0:
        #     display_images([Image.open(img_path) for img_path in img_paths], img_size=3)
        print(f'[input]\n{input_data}\n[output]\n{output_data}\n{"-"*64}')