In [8]:
import os

import torch
from rich import print
from rich.traceback import install
from tqdm.auto import tqdm
from transformers import CLIPProcessor

from tali.data.data import ModalityTypes
from tali.data.data_plus import TALIBase

# Rich-formatted tracebacks for any exception raised in later cells.
install()

# Dataset locations. `setdefault` keeps a value already exported in the
# environment and only falls back to these machine-specific paths otherwise.
# WIT_DATASET_DIR is not passed explicitly below — presumably it is read inside
# the tali package (TODO confirm).
os.environ.setdefault("TALI_DATASET_DIR", "/data0/TALI-data/")
os.environ.setdefault("WIT_DATASET_DIR", "/data1/wit/")

# Single source of truth for the image/text checkpoint: the dataset's built-in
# preprocessing and the reference CLIPProcessor below must use the same model,
# otherwise comparing their outputs later in this notebook is meaningless.
image_text_model_name: str = "openai/clip-vit-base-patch16"

tali_dataset = TALIBase(
    set_name="train",
    tali_dataset_dir=os.environ["TALI_DATASET_DIR"],
    modality_list=[
        ModalityTypes.wit_image.value,
        ModalityTypes.wit_caption.value,
        ModalityTypes.wit_title.value,
        ModalityTypes.wit_main_body.value,
    ],
    num_samples_per_episode=32,
    rng_seed=42,
    top_k_tali=10,
    image_size=224,
    num_video_frames=10,
    num_audio_frames=32000,
    clip_duration_in_seconds=3.0,
    deterministic_sampling=True,
    dummy_batch_mode=False,
    image_text_model_name=image_text_model_name,
    audio_model_name="openai/whisper-base",
    use_model_preprocessing=True,
    total_num_samples=None,
    cache_generated_samples_in_memory=False,
    cache_num_samples=10,
)

# Reference preprocessor, loaded from the same checkpoint as the dataset above.
clip_preprocessor = CLIPProcessor.from_pretrained(image_text_model_name)


In [20]:
# Fetch item 0 through the TALIBase wrapper and pull out the WIT caption image
# and caption text entries it returns, in that order.
sample = tali_dataset[0]
image, text = (
    sample[key]
    for key in ("wikipedia_caption_image", "wikipedia_caption_text")
)

print(sample)

In [17]:
# Display whatever the dataset returned under "wikipedia_caption_text"
# (raw string or token ids depending on use_model_preprocessing — TODO confirm).
print(text)

In [18]:
# Shape of the caption image returned by the dataset for item 0.
print(image.shape)

In [19]:
# One-line summary (shape and value statistics) of the dataset-provided image,
# as a quick sanity check of its value range.
image_summary = ", ".join(
    f"{label}: {value}"
    for label, value in (
        ("image", image.shape),
        ("mean", image.mean()),
        ("std", image.std()),
        ("min", image.min()),
        ("max", image.max()),
    )
)
print(image_summary)

In [13]:
# Underlying dataset object held by the TALIBase wrapper; it is sliced directly in
# the next cells, presumably bypassing the wrapper's own __getitem__ logic (TODO confirm).
core_dataset = tali_dataset.dataset

In [14]:
# First 32 rows of the raw dataset; the result is indexed by column name below.
# NOTE(review): 32 mirrors num_samples_per_episode in the setup cell — confirm
# whether that coupling is intended.
samples = core_dataset[:32]

In [35]:
# Pull the raw columns needed for a reference preprocessing pass.
wit_idx, wit_image, wit_caption = (
    samples[column]
    for column in ("wit_idx", "wikipedia_caption_image", "wikipedia_caption_text")
)

# Run CLIPProcessor directly on the raw images and captions so its output can be
# compared against the dataset's own preprocessing.
image_inputs = clip_preprocessor(images=wit_image, return_tensors="pt")
wit_image_tokens = image_inputs.pixel_values

text_inputs = clip_preprocessor(
    text=wit_caption,
    padding=True,
    truncation=True,
    return_tensors="pt",
)
wit_text_tokens = text_inputs.input_ids

print(wit_image_tokens.shape)

In [37]:
# Compare the dataset's built-in preprocessing (`image`, `text` from
# `tali_dataset[0]`) with CLIPProcessor applied directly to the raw samples
# (`wit_image_tokens`, `wit_text_tokens`). A zero sum means identical tensors.
# NOTE(review): confirm `tali_dataset[0]` and `core_dataset[:32]` refer to the
# same samples in the same order — otherwise a nonzero diff says nothing about
# preprocessing.
for label, produced, reference in (
    ("image", image, wit_image_tokens),
    ("text", text, wit_text_tokens),
):
    produced_shape = tuple(getattr(produced, "shape", ()))
    reference_shape = tuple(getattr(reference, "shape", ()))
    # Element-wise subtraction silently broadcasts mismatched shapes, which
    # turns the sum below into a meaningless number; surface that explicitly.
    if produced_shape != reference_shape:
        print(f"{label} shape mismatch: {produced_shape} vs {reference_shape}")

diff_image = torch.sum(torch.abs(image - wit_image_tokens))
diff_text = torch.sum(torch.abs(text - wit_text_tokens))
print(diff_image)
print(diff_text)