# Huggingface Dataset Experiments

Experimented with the huggingface datasets instead. But not very efficient due to the exploding of the rows (as it will have many duplicated values that are stored seperatly, like the image and text encodings). It would be possible to do the exploding in a custom pytorch dataset in the end, but this sort of defeats the purpose of using huggingface datasets in the first place.

In [5]:
import meme_entity_detection
from pathlib import Path

project_dir = Path(meme_entity_detection.__file__).parent.parent.parent
data_dir = project_dir / "data" / "HVVMemesWithFaces"

In [6]:
import datasets

dataset = datasets.load_dataset(path=str(data_dir), split='train[0:100]')
dataset = dataset.map(lambda row: {"image": str(data_dir / "images" / row["image"])}, batched=False)

Generating train split: 5552 examples [00:00, 610560.46 examples/s]
Generating validation split: 1368 examples [00:00, 196446.45 examples/s]
Generating test split: 718 examples [00:00, 152281.06 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 6114.59 examples/s]


In [7]:
import pandas as pd

def explode_roles(df):
    roles = ['hero', 'villain', 'victim', 'other']
    exploded_df = pd.concat([df.explode(role).assign(role=role).rename(columns={role: 'entity'}) for role in roles])
    
    return exploded_df.dropna(subset=['entity']).drop(columns=roles).reset_index(drop=True)

dataset = dataset.with_format("pandas").map(explode_roles, batched=True)

Map: 100%|██████████| 100/100 [00:00<00:00, 5837.75 examples/s]


In [9]:
import transformers

processor: transformers.ViltProcessor = transformers.ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

def tokenize(examples):
    encoding = processor(text=examples['OCR'], images=examples['image'], return_tensors="pt", padding="max_length", truncation=True)
    return encoding

dataset = dataset.with_format().cast_column('image', datasets.Image(mode="RGB"))
encoded_dataset = dataset.map(tokenize, batched=True)

Map:   0%|          | 0/330 [00:04<?, ? examples/s]


KeyboardInterrupt: 

In [None]:
encoded_dataset[0]

{'OCR': 'PROUD TO BE\n***\nREDUICAN\nEPUBLICA\n',
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=225x225>,
 'entity': 'republican',
 'faces': None,
 'role': 'hero'}

In [None]:
import transformers
import PIL 

processor: transformers.ViltProcessor = transformers.ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

def preprocess_function(examples):
    encoding = processor(text=examples['OCR'], images=examples['image'], return_tensors="pt", padding="max_length", truncation=True)
    
    # labels = {
    #     'hero': examples['hero'],
    #     'villain': examples['villain'],
    #     'victim': examples['victim'],
    #     'other': examples['other']
    # }
    
    # encoding['labels'] = labels
    # return encoding

encoded_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset.column_names)

Map: 100%|██████████| 2433/2433 [00:24<00:00, 100.58 examples/s]
