In [None]:
from datasets import load_dataset
from tqdm.notebook import tqdm
from pathlib import Path
import polars as pl
import time
import random
import spacy
import cupy
import math

# Load the Dataset

In [None]:
# Columns kept from the LAION-HD subset; the frames loaded below are restricted to these.
subset = [
    'similarity',
    'hash',
    'punsafe',
    'pwatermark',
    'LANGUAGE',
    'caption',
    'url',
    'key',
    'status',
    'error_message',
    'width',
    'height',
    'original_width',
    'original_height',
]

# Alternative loader, kept for reference: reads the train/test splits straight from
# the Hugging Face hub instead of from local parquet files.
# splits = {
#     'train': 'data/train-*-of-*.parquet', 
#     'test': 'data/test-00000-of-00001-f5aa494af1d25f74.parquet'
# }

# df_train = pl.read_parquet(
#     'hf://datasets/yuvalkirstain/laion-hd-subset/' + splits['train']
# )[subset]

# df_test = pl.read_parquet(
#     'hf://datasets/yuvalkirstain/laion-hd-subset/' + splits['test']
# )[subset]

# Local parquet files (paths relative to the notebook's working directory).
# NOTE(review): presumably these are local copies of the hub splits above — confirm.
df_train = pl.read_parquet('data/laion-hd-subset-train')[subset]
df_test = pl.read_parquet('data/laion-hd-subset-test')[subset]

In [None]:
df_train

# Text Cleaning

In [None]:
# Text Cleaning Function

import re
import nltk
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

def clean_text(text):
    """Normalise a caption to lowercase ASCII letters separated by single spaces.

    Steps, in order: lowercase; drop every character that is neither a word
    character nor whitespace; drop non-ASCII characters; drop digit runs;
    collapse whitespace. Underscores survive because ``\\w`` matches them.
    Tokenisation, stop-word removal and lemmatisation are deliberately not applied.

    Args:
        text (str): Raw caption.

    Returns:
        str: The cleaned caption (possibly empty).
    """
    lowered = text.lower()
    without_punctuation = re.sub(r'[^\w\s]', '', lowered)
    # Encoding with "ignore" silently discards anything outside ASCII.
    ascii_only = without_punctuation.encode("ascii", "ignore").decode()
    without_digits = re.sub(r'\d+', '', ascii_only)
    # split() with no argument also trims and collapses repeated whitespace.
    return " ".join(without_digits.split())

# # Example usage
# raw_text = "Hello there!!! This is a sample text, containing numbers like 123 and punctuation."
# cleaned = clean_text(raw_text)
# print(cleaned)
# # Output might be: "hello sample text containing number like punctuation"

In [None]:
# Show one random caption before and after cleaning.
# randrange excludes the upper bound; randint(0, n) could return n, which is one
# past the last row and made the lookup below fail with IndexError.
row_index = random.randrange(df_train.shape[0])

caption = df_train.with_row_index().filter(
                    pl.col('index') == row_index
                )['caption'].to_list()[0]

cleaned_caption = clean_text(caption)

print(
    f"Original Caption:\n{caption}\n\n"
    f"Cleaned Caption:\n{cleaned_caption}"
)

# Dataset Cleansing

In [None]:
# Are uids unique?

df_train['hash'].unique().count() == df_train.shape[0]

In [None]:
# Are Captions unique?

df_train['caption'].unique().count() == df_train.shape[0]

In [None]:
condition = df_train['caption'].is_duplicated()

df_train.filter(
    condition
).sort('caption')

In [7]:
df_train = df_train.unique(
        subset='caption',
        keep='first',
        maintain_order=True
    )

In [None]:
# Are Captions unique?

df_train['caption'].unique().count() == df_train.shape[0]

In [None]:
# Are urls unique?

df_train['url'].unique().count() == df_train.shape[0]

In [None]:
df_train['LANGUAGE'].value_counts().sort('count', descending=True)

# Is English Caption?

In [None]:
# Caption Filtering by Language

import fasttext
import os
import re

def fasttext_load_model(model_name: str, dir_name: str, cwd: str):
    """Load a fastText model, searching ``cwd`` and then each of its ancestors.

    For every directory from ``cwd`` up to the filesystem root, looks for
    ``<directory>/<dir_name>/<model_name>`` and loads the first file found.

    Args:
        model_name (str): File name of the model (e.g. ``'lid.176.bin'``).
        dir_name (str): Name of the sub-directory expected to contain the model.
        cwd (str): Directory where the upward search starts.

    Returns:
        The object returned by ``fasttext.load_model``.

    Raises:
        FileNotFoundError: If no directory on the path to the root contains the
            model file. (The previous implementation recursed without bound here.)
        Exception: Whatever ``fasttext.load_model`` raises for a file that exists
            but cannot be loaded; such errors are no longer swallowed.
    """
    # abspath guarantees dirname() eventually reaches a fixed point (the root),
    # even when a relative path such as '.' is passed in.
    current_dir = os.path.abspath(cwd)
    while True:
        model_path = os.path.join(current_dir, dir_name, model_name)
        if os.path.isfile(model_path):
            return fasttext.load_model(model_path)

        parent_dir = os.path.dirname(current_dir)
        if parent_dir == current_dir:
            raise FileNotFoundError(
                f"Could not find {os.path.join(dir_name, model_name)!r} in "
                f"{os.path.abspath(cwd)!r} or any of its parent directories"
            )
        current_dir = parent_dir
    
def extract_str_from_pattern(s: str, regex: str) -> tuple:
    """Match ``regex`` at the start of ``s`` and return its capture groups.

    Args:
        s (str): Text to match.
        regex (str): Regular expression containing capture groups.

    Returns:
        tuple: One entry per capture group, in order. A group that did not
        take part in the match is ``None``.

    Raises:
        ValueError: If ``regex`` does not match at the start of ``s``.
    """
    match = re.match(regex, s)
    if match is None:
        # Previously this surfaced as "'NoneType' object has no attribute 'groups'".
        raise ValueError(f"{s!r} does not match pattern {regex!r}")
    return match.groups()

def is_english_sentence(sentence: str):
    """Run language identification on one sentence with the module-level ``model``.

    The sentence is trimmed and its newlines replaced by spaces before being
    passed to ``model.predict``.

    Args:
        sentence (str): Text to classify.

    Returns:
        dict: ``is_english`` (whether the first predicted label is
        ``'__label__en'``), ``lang_detected`` (the first predicted label) and
        ``score`` (the first returned score).
    """
    single_line = sentence.strip().replace('\n', ' ')
    labels, probabilities = model.predict(single_line)
    top_label = labels[0]
    return {
        "is_english": top_label == '__label__en',
        "lang_detected": top_label,
        "score": probabilities[0],
    }

model = fasttext_load_model('lid.176.bin', 'models', os.getcwd())

# if __name__ == "__main__":
#     model = fasttext_load_model('lid.176.bin', 'models', os.getcwd()) 
#     input_file = "/home/fbernardi/Documents/fair_spoke_8/train_cap.txt"
#     with open(input_file, 'r', encoding='utf-8') as f:
#         for line in f:
#             result = is_english_sentence(line)
#             detected_lang = extract_str_from_pattern(result['lang_detected'], r'^__label__(.+)$')
#             print(detected_lang[0], result)
        

In [None]:
# Batch Prediction
# Times one `model.predict` call over the full list of training captions.
start_time = time.time()
list_of_captions = df_train['caption'].to_list()
# Same normalisation as `is_english_sentence`: trim and replace newlines with spaces.
list_of_captions = [x.strip().replace('\n', ' ') for x in list_of_captions]

lang, scores = model.predict(list_of_captions)

end_time = time.time()

elapsed_time = end_time - start_time
print(f"Execution Time: {elapsed_time:.2f} seconds")

# Linear extrapolation of the measured per-row time to 4 million rows.
tot_elapsed_time = (elapsed_time/df_train.shape[0]) * 4e6
print(f"Esitmated Elapsing Time for 4M rows: {tot_elapsed_time:.2f} seconds")

In [13]:
# Polars dtype of the dict returned by `is_english_sentence`.
# The fields are passed as a list rather than a set literal: a set has no defined
# order, so the struct's field order (and the column order after `unnest`) could
# change from one kernel session to the next.
map_data_types = pl.Struct([
    pl.Field("is_english", pl.Boolean),
    pl.Field("lang_detected", pl.String),
    pl.Field("score", pl.Float64),
])

In [None]:
# Make the cell re-runnable: remove the output columns of a previous execution.
# Only columns that actually exist are dropped, so no exception has to be swallowed
# (the former bare `except: pass` would also have hidden unrelated errors).
previous_outputs = [
    col for col in ('is_english', 'lang_detected', 'score')
    if col in df_train.columns
]
df_train = df_train.drop(previous_outputs)

start_time = time.time()

# Per-row language detection; the returned dict becomes a struct column that is
# immediately expanded into `is_english`, `lang_detected` and `score`.
df_train = df_train.with_columns(
    df_train['caption'].map_elements(
        is_english_sentence,
        return_dtype=map_data_types,
        skip_nulls=True)
        .alias('lang_detection')
).unnest("lang_detection")

end_time = time.time()

elapsed_time = end_time - start_time
print(f"Execution Time: {elapsed_time:.2f} seconds")

# Linear extrapolation of the measured per-row time to 4 million rows.
tot_elapsed_time = (elapsed_time/df_train.shape[0]) * 4e6
print(f"Esitmated Elapsing Time for 4M rows: {tot_elapsed_time:.2f} seconds")

In [None]:
df_train

In [16]:
# Reduce fastText labels such as '__label__en' to the bare language code ('en').
# `strip_prefix` is vectorised and idempotent: re-running the cell leaves already
# stripped codes untouched, whereas the former per-row regex raised on values
# that no longer started with '__label__'. Nulls stay null.
df_train = df_train.with_columns(
    pl.col('lang_detected').str.strip_prefix('__label__')
)

In [None]:
df_train

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Horizontal bar charts comparing the dataset-provided language (row 1) with the
# language detected by fastText (row 2).
fig = make_subplots(
                    rows=2, cols=1,
                    shared_xaxes=True,
                    vertical_spacing=0.03,
                    subplot_titles=["Dataset Language", "Language Detected"]
                    )

for subplot_row, lang_column in enumerate(['LANGUAGE', 'lang_detected'], start=1):
    # Compute the counts once and read both axes from the same frame. Calling
    # value_counts() separately for x and y relies on two independent calls
    # ordering tied counts identically, which is not guaranteed.
    lang_counts = df_train[lang_column]\
                    .value_counts()\
                    .sort('count', descending=False)

    fig.add_trace(
        go.Bar(
            x=lang_counts['count'].to_list(),
            y=lang_counts[lang_column].to_list(),
            orientation='h'
            ),
        row=subplot_row, col=1
    )

fig.update_layout(height=800, width=1000, title_text="Distributions")
fig.show()

In [None]:
pl.Config(fmt_str_lengths=1000, tbl_width_chars=1000)

# Same Detected Language

df_train.filter(
    (pl.col('LANGUAGE') == pl.col('lang_detected'))
).select(['LANGUAGE', 'lang_detected', 'caption'])

In [None]:
# Only Detected Language (No Ground Truth)

df_train.filter(
    (pl.col('LANGUAGE') == "nolang")
).select(['LANGUAGE', 'lang_detected', 'caption'])

In [None]:
# Different Detected Language

df_train.filter(
    (pl.col('LANGUAGE') != "nolang")
    & (pl.col('LANGUAGE') != pl.col('lang_detected'))
).select(['LANGUAGE', 'lang_detected', 'caption'])

In [None]:
df_train['is_english'].value_counts()   

In [None]:
# Only English Captions (ground truth == detected)

df_train.filter(
    (pl.col('LANGUAGE') == 'en')
    & (pl.col('lang_detected') == 'en')
)['LANGUAGE', 'lang_detected', 'caption']

In [None]:
# Detected English Captions (ground truth different from english)

df_train.filter(
    (pl.col('LANGUAGE') != 'en')
    & (pl.col('lang_detected') == 'en')
)['LANGUAGE', 'lang_detected', 'caption']

In [None]:
# Detected Non-English Captions 

df_train.filter(
 (pl.col('lang_detected') != 'en')
)['LANGUAGE', 'lang_detected', 'caption']#.sample(10)

In [None]:
# Ground Truth English Captions Detected as Non-English

df_train.filter(
    (pl.col('LANGUAGE') == 'en')
    & (pl.col('lang_detected') != 'en')
)['LANGUAGE', 'lang_detected', 'caption']#.sample(10)

# Caption Filtering by Length and Word Count

In [None]:
# Caption Filtering by Length and Word Count

def filter_captions(caption: str, min_words: int=2, min_chars: int=5) -> bool:
    """Tell whether a caption is long enough to keep.

    Both thresholds are exclusive: the trimmed caption must contain strictly
    more than ``min_words`` whitespace-separated words and strictly more than
    ``min_chars`` characters.

    Args:
        caption (str): Caption text.
        min_words (int): Word count that must be exceeded. Default is 2.
        min_chars (int): Character count that must be exceeded. Default is 5.

    Returns:
        bool: True if both thresholds are exceeded, False otherwise.
    """
    trimmed = caption.strip()

    # Guard clause: too few words fails regardless of the character count.
    if len(trimmed.split()) <= min_words:
        return False

    return len(trimmed) > min_chars

In [None]:
df_train = df_train.with_columns(
        df_train['caption'].map_elements(
            filter_captions, 
            return_dtype=pl.Boolean, 
            skip_nulls=False
        ).alias('caption_long')
)

In [None]:
df_train['caption_long'].value_counts(normalize=True)

# Image Dimension and Ratio

In [None]:
import requests
from PIL import Image
from io import BytesIO

def check_image_conditions_from_url(image_url, min_dimension=200, max_aspect_ratio=3, timeout=10):
    """
    Downloads an image and checks that it satisfies the given conditions:
    - The smaller dimension is at least ``min_dimension`` pixels.
    - The aspect ratio (longer side / shorter side) is strictly below ``max_aspect_ratio``.

    Parameters:
        image_url (str): URL of the image file.
        min_dimension (int): Minimum size for the smaller dimension. Default is 200 pixels.
        max_aspect_ratio (float): Exclusive upper bound for the aspect ratio. Default is 3.
        timeout (float): Seconds to wait for the server before giving up. Default is 10.

    Returns:
        bool: True if the image satisfies the conditions, False otherwise.
        Any failure (network error, timeout, bad HTTP status, undecodable image)
        is printed and reported as False.
    """
    try:
        # Without a timeout a single unresponsive host would block the notebook forever.
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()  # Raise an error for bad HTTP responses

        # Only the size is needed; the context manager releases the image afterwards.
        with Image.open(BytesIO(response.content)) as img:
            width, height = img.size

        # Check the smaller dimension
        if min(width, height) < min_dimension:
            return False

        # Check the aspect ratio
        aspect_ratio = max(width / height, height / width)
        if aspect_ratio >= max_aspect_ratio:
            return False

        # Both conditions are satisfied
        return True

    except Exception as e:
        print(f"Error processing image from URL {image_url}: {e}")
        return False

def check_image_conditions_from_path(image_path, min_dimension=200, max_aspect_ratio=3):
    """
    Opens a local image and checks its size and aspect ratio.

    The image passes when its smaller side is not below ``min_dimension`` and
    its longer-to-shorter side ratio does not reach ``max_aspect_ratio``.

    Parameters:
        image_path (str): Path to the image file.
        min_dimension (int): Minimum size for the smaller dimension. Default is 200 pixels.
        max_aspect_ratio (float): Ratio at which the image is rejected. Default is 3.

    Returns:
        bool: True if the image satisfies the conditions, False otherwise
        (including when the file cannot be opened; the error is printed).
    """
    try:
        with Image.open(image_path) as picture:
            w, h = picture.size

        # The size check comes first so the ratio is only computed for images that pass it.
        if min(w, h) < min_dimension:
            return False

        ratio = max(w / h, h / w)
        return not ratio >= max_aspect_ratio

    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return False

def check_image_conditions_from_dimensions(width, height, min_dimension=200, max_aspect_ratio=3):
    """
    Checks if an image of the given size satisfies the conditions:
    - The smaller dimension is at least ``min_dimension`` pixels.
    - The aspect ratio (longer side / shorter side) is strictly below ``max_aspect_ratio``.

    Parameters:
        width (int | None): The width of the image in pixels.
        height (int | None): The height of the image in pixels.
        min_dimension (int): Minimum size for the smaller dimension. Default is 200 pixels.
        max_aspect_ratio (float): Exclusive upper bound for the aspect ratio. Default is 3.

    Returns:
        bool: True if the image satisfies the conditions, False otherwise.
        Missing (None) or non-positive dimensions are reported as False.
    """
    # A struct row can carry null fields even when the struct itself is not null,
    # so missing dimensions must be handled here rather than by `skip_nulls`.
    if width is None or height is None:
        return False

    # Non-positive sizes cannot describe a real image and would divide by zero below.
    if width <= 0 or height <= 0:
        return False

    shorter_side = min(width, height)
    if shorter_side < min_dimension:
        return False

    # For positive sizes this equals max(width / height, height / width).
    aspect_ratio = max(width, height) / shorter_side
    return aspect_ratio < max_aspect_ratio

def download_images(url_list, save_folder, timeout=10):
    """
    Downloads images from a list of URLs and saves them to a specified folder.

    Files are named ``image_<n>.jpg`` where ``n`` is the 1-based position of the
    URL in ``url_list``; the extension is fixed and not derived from the content.
    A failed download is printed and skipped, so the remaining URLs are still tried.

    Parameters:
        url_list (list): A list of image URLs.
        save_folder (str): The folder where the images will be saved (created if missing).
        timeout (float): Seconds to wait for each server before giving up. Default is 10.

    Returns:
        None
    """

    # Create the folder if it doesn't exist
    os.makedirs(save_folder, exist_ok=True)

    for i, url in enumerate(url_list):
        try:
            # The context manager closes the streamed connection even on errors;
            # the timeout keeps one dead host from stalling the whole loop.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()  # Raise an error for bad HTTP status codes

                # Determine the file name and path
                file_extension = 'jpg'
                file_name = f"image_{i + 1}.{file_extension}"
                file_path = os.path.join(save_folder, file_name)

                # Save the image to the folder
                with open(file_path, 'wb') as file:
                    for chunk in response.iter_content(1024):
                        file.write(chunk)

            print(f"Downloaded: {file_name}")

        except Exception as e:
            print(f"Failed to download image from {url}: {e}")

## From Url

In [None]:
start_time = time.time()


url = df_train['url'].sample(1).to_list()[0]
check_image_conditions_from_url(url)

end_time = time.time()

elapsed_time = end_time - start_time
print(f"Execution Time: {elapsed_time:.2f} seconds")

tot_elapsed_time = ((elapsed_time) * 4e6)/(60 * 60 * 24)
print(f"Esitmated Elapsing Time for 4M rows: {tot_elapsed_time:.2f} Days")


## From Path

In [None]:
parent = os.getcwd()
# Use a dedicated sub-folder: 'data' itself also holds the dataset files, and the
# timing cell that follows iterates over *everything* inside `images_path`.
image_dir = os.path.join('data', 'images')
images_path = os.path.join(parent, image_dir)

# Small sample, just enough to time the local-file check.
url_list = df_train['url'].head(10).to_list()

download_images(url_list, images_path)

In [None]:
# List the folder once so the loop and the per-file average use the same set of files.
image_files = os.listdir(images_path)

start_time = time.time()

for img in image_files:
    path = os.path.join(images_path, img)
    print(check_image_conditions_from_path(path))

end_time = time.time()

elapsed_time = end_time - start_time
print(f"Execution Time: {elapsed_time:.4f} seconds")

denominator = len(image_files)
if denominator:
    # Linear extrapolation of the measured per-file time to 4 million rows.
    tot_elapsed_time = ((elapsed_time/denominator) * 4e6)/60
    print(f"Esitmated Elapsing Time for 4M rows: {tot_elapsed_time:.2f} min")
else:
    # An empty folder (e.g. every download failed) used to raise ZeroDivisionError.
    print(f"No files found in {images_path}; nothing to extrapolate from.")

## From Dimensions

In [None]:
start_time = time.time()

df_train = df_train.with_columns(pl.struct(['original_width','original_height']).
                        map_elements(
                            lambda x: check_image_conditions_from_dimensions(
                                            x['original_width'], 
                                            x['original_height']), 
                            return_dtype=pl.Boolean, 
                            skip_nulls=True
                        ).alias('image_valid')
                    )


end_time = time.time()

elapsed_time = end_time - start_time
print(f"Execution Time: {elapsed_time:.4f} seconds")

denominator = df_train.shape[0]
tot_elapsed_time = ((elapsed_time/denominator) * 4e6)
print(f"Esitmated Elapsing Time for 4M rows: {tot_elapsed_time:.2f} seconds")

In [None]:
df_train['image_valid'].value_counts(normalize=True)

In [None]:
df_train.filter(pl.col('image_valid') == False).select(
    ['original_width', 'original_height', 'image_valid', 'url']
)

In [None]:
df_train['is_english', 'caption_long', 'image_valid']\
        .to_pandas()\
        .value_counts()\
        .sort_index()

# PoS

Part of Speech

- SpaCy Tags Meaning: [SpaCy Docs](https://spacy.io/api/token#attributes) -> [Universal POS tags](https://universaldependencies.org/u/pos/)
  - `ADJ`: adjective
  - `ADP`: adposition
  - `ADV`: adverb
  - `AUX`: auxiliary
  - `CCONJ`: coordinating conjunction
  - `DET`: determiner
  - `INTJ`: interjection
  - `NOUN`: noun
  - `NUM`: numeral
  - `PART`: particle
  - `PRON`: pronoun
  - `PROPN`: proper noun
  - `PUNCT`: punctuation
  - `SCONJ`: subordinating conjunction
  - `SYM`: symbol
  - `VERB`: verb
  - `X`: other
- [Universal Dependencies](https://universaldependencies.org/)

In [None]:
efficency_model = "en_core_web_sm"
accuracy_model = "en_core_web_trf"

# nlp = spacy.load(efficency_model, disable=['lemmatizer', 'ner'])
# nlp = spacy.load(efficency_model)
nlp = spacy.load(accuracy_model, disable=['lemmatizer', 'ner'])

In [None]:
df_eng_caption = df_train.filter(
                    (pl.col('is_english') == True)
                    & (pl.col('caption_long') == True)
                ).with_row_index()

In [None]:
# Pick one random English caption (and its image URL) to inspect.
# randrange excludes the upper bound; randint(0, n) could return n, for which no
# row exists, and the lookups below then failed with IndexError.
row_index = random.randrange(df_eng_caption.shape[0])

# Filter once and read both values from the same row.
sampled_row = df_eng_caption.filter(
                    pl.col('index') == row_index
                )

caption = sampled_row['caption'].to_list()[0]

url = sampled_row['url'].to_list()[0]

print('\n',caption,
      '\n',url)

In [None]:
doc = nlp(caption)

for token in doc:
    print(token.text, token.pos_, token.dep_)

In [None]:
spacy.displacy.render(doc, style='dep')

In [None]:
# spacy.require_gpu()
# nlp = spacy.load(accuracy_model)

nlp = spacy.load(efficency_model)

def tag_captions(caption: str, model: str='en_core_web_sm'):
    """Tag a caption with the module-level spaCy pipeline ``nlp``.

    Args:
        caption (str): Text to analyse.
        model (str): Not used by this function; the pipeline applied is whatever
            ``nlp`` currently refers to.

    Returns:
        list[tuple[str, str, str]]: One ``(text, pos_, dep_)`` tuple per token.
    """
    tags = []
    for token in nlp(caption):
        tags.append((token.text, token.pos_, token.dep_))
    return tags

In [None]:
df_eng_caption = df_eng_caption.with_columns(
    df_eng_caption['caption'].map_elements(
    lambda x: tag_captions(x), 
    return_dtype=pl.List, 
    skip_nulls=False).alias('pos_tags')
)

In [None]:
def extract_and_concatenate(list_of_lists: list[list[str]]) -> str:
    """
    Build a tag string from the second item of every entry.

    Each second item is followed by an underscore, so a non-empty result
    always ends with ``'_'``.

    Args:
        list_of_lists (list[list[str]]): Sequence of entries, each holding at
            least two strings.

    Returns:
        str: The second items in input order, each terminated by an underscore;
            an empty string for empty input.

    Raises:
        IndexError: If an entry has fewer than two items.

    Example:
        >>> extract_and_concatenate([['a', 'b'], ['c', 'd'], ['e', 'f']])
        'b_d_f_'

        >>> extract_and_concatenate([['x', 'y'], ['z', 'w']])
        'y_w_'
    """
    return "".join(entry[1] + "_" for entry in list_of_lists)

In [None]:
df_eng_caption = df_eng_caption.with_columns(
        df_eng_caption['pos_tags'].map_elements(
            lambda x: extract_and_concatenate(x),
            return_dtype=pl.String,
            skip_nulls=False
        ).alias('conc_tags')
    )

In [None]:
df_eng_caption.filter(
    pl.col('index') == 8631
)['caption', 'conc_tags']

In [None]:
df_eng_caption.filter(
    pl.col('index') == 8631
)['pos_tags']#.to_list()[0]

In [None]:
def count_token_pos(tokens: list[tuple[str, str, str]]) -> dict[str, int]:
    """
    Tally how often each POS tag occurs in a list of token tuples.

    Args:
        tokens (list[tuple[str, str, str]]): Tuples shaped like
            (token.text, token.pos_, token.dep_); only the second item is read.

    Returns:
        dict[str, int]: POS tag -> number of occurrences, with keys in order of
            first appearance. Tags that never occur are absent.
    """
    tally = {}
    for entry in tokens:
        tag = entry[1]
        if tag in tally:
            tally[tag] += 1
        else:
            tally[tag] = 1
    return tally

In [None]:
df_eng_caption = df_eng_caption.with_columns(
    df_eng_caption['pos_tags'].map_elements(
        lambda x: count_token_pos(x), 
        return_dtype=pl.Struct, 
        skip_nulls=True)
        .alias('count_tags')
).unnest("count_tags")

In [None]:
# POS tags expected as count columns after unnesting `count_tags`.
columns_to_replace = [
    'ADJ', 'NOUN', 'PUNCT', 'PROPN', 'VERB', 'SYM', 'NUM', 
    'PART', 'SCONJ', 'ADP', 'DET', 'PRON', 'SPACE', 'CCONJ', 
    'INTJ', 'X', 'AUX', 'ADV'
]

# Replace null counts with 0. A tag that never occurred in any caption has no
# column at all, so it is created as a constant 0 instead of being referenced
# (which would fail with a column-not-found error).
df_eng_caption = df_eng_caption.with_columns([
    pl.col(col).fill_null(0) if col in df_eng_caption.columns
    else pl.lit(0).alias(col)
    for col in columns_to_replace
])

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

def create_bar_plots(
        df : pl.DataFrame, 
        columns_to_plot: list[str], 
        num_cols: int,
        graph_heihgt: int=600,
    ):
    """Show a grid of bar charts, one value-count chart per requested column.

    Subplots are filled column by column (top to bottom, then left to right).
    Each chart has the distinct values of a column on the x axis and their
    counts on the y axis, ordered by ascending count.

    Args:
        df (pl.DataFrame): Frame holding the columns to plot.
        columns_to_plot (list[str]): Names of the columns to chart.
        num_cols (int): Number of subplot columns in the grid.
        graph_heihgt (int): Height in pixels of one subplot row. Default is 600.
            (Previously accepted but ignored; the default gives the same figure.)

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    num_graphs = len(columns_to_plot)

    num_rows = math.ceil(num_graphs/num_cols)

    fig = make_subplots(
                        rows=num_rows, 
                        cols=num_cols,
                        shared_xaxes=False,
                        vertical_spacing=0.03,
                        )

    for i, field in enumerate(columns_to_plot):
        # Column-major placement: graph i goes to column i // num_rows, row i % num_rows.
        col_offset, row_offset = divmod(i, num_rows)
        subplot_row, subplot_col = row_offset + 1, col_offset + 1

        # Compute the counts once so x and y are guaranteed to come from the same
        # ordering (two separate value_counts() calls need not order ties alike).
        counts = df[field].value_counts().sort('count', descending=False)

        fig.add_trace(
            go.Bar(
                x=counts[field].to_list(),
                y=counts['count'].to_list(),
                orientation='v',
                showlegend=True,
                name=field
                ),
            row=subplot_row, col=subplot_col
        )

        # Title above the subplot, positioned relative to the subplot's own domain.
        fig.add_annotation(
            text=field,
            xref="x2 domain", yref="y2 domain",
            x=0.5, y=1.1,
            showarrow=False,
            row=subplot_row, col=subplot_col
        )

    fig.update_layout(height=graph_heihgt*num_rows, width=600*num_cols)
    fig.show()

In [None]:
create_bar_plots(df_eng_caption, columns_to_replace, 3)

- without adjectives
- with numbers
- with punctuation

In [None]:
df_eng_caption.filter(
    pl.col('ADJ') == 0
)['caption', 'conc_tags', 'ADJ']

In [None]:
df_eng_caption.filter(
    pl.col('NOUN') == 0
)['caption', 'conc_tags', 'NOUN']

In [None]:
df_eng_caption.filter(
    (pl.col('ADJ') != 0)
    & (pl.col('NOUN') != 0)
    & (pl.col('NUM') == 0)
)['caption', 'conc_tags', 'ADJ', 'NOUN']

In [None]:
def extract_unique_tags(tags: str) -> str:
    """
    Reduce an underscore-terminated tag string to its distinct tags.

    The input is expected in the ``'TAG_TAG_..._'`` form, i.e. ending with an
    underscore; whatever follows the last underscore is discarded.

    Args:
        tags (str): Tags concatenated with underscores.

    Returns:
        str: The distinct tags in sorted order, joined by underscores
            (no trailing underscore).
    """
    pieces = tags.split('_')
    # split() always yields at least one piece; the last one is the text after
    # the final underscore and is dropped.
    pieces.pop()
    return "_".join(sorted(set(pieces)))

def count_unique_tags(tags: str) -> int:
    """
    Count the number of unique tags in a concatenated string of tags.

    The input is parsed the same way as in ``extract_unique_tags``: it is split
    on underscores and the piece after the last underscore is discarded.

    Args:
        tags (str): A string with tags concatenated by underscores
            (``'TAG_TAG_..._'`` form).

    Returns:
        int: The number of unique tags in the string; 0 for an empty string.
    """
    # Count the set directly. Joining the unique tags and splitting again
    # reported 1 for an empty input, because ''.split('_') is [''].
    return len(set(tags.split('_')[:-1]))

In [None]:
# Show the tag string of one random caption together with its unique tags.
# randrange excludes the upper bound; randint(0, n) could return n, for which no
# row exists, and the lookup below then failed with IndexError.
indice = random.randrange(df_eng_caption.shape[0])

original = df_eng_caption.filter(
    pl.col('index') == indice
)['conc_tags'].to_list()[0]

unique_sorted_caption = extract_unique_tags(original)
tag_num = count_unique_tags(original)

print(
    f"Original Tags:\n{original}\n\n"
    f"Unique Tags:\n{unique_sorted_caption}"
    f"\n\nNumber of Unique Tags: {tag_num}"
)

In [None]:
df_eng_caption = df_eng_caption.with_columns(
    df_eng_caption['conc_tags'].map_elements(
        lambda x: extract_unique_tags(x), 
        return_dtype=pl.String, 
        skip_nulls=True)
        .alias('unique_tags')
)

df_eng_caption = df_eng_caption.with_columns(
    df_eng_caption['conc_tags'].map_elements(
        lambda x: count_unique_tags(x), 
        return_dtype=pl.Int8, 
        skip_nulls=True)
        .alias('count_unique_tags')
)

In [None]:
df_eng_caption['unique_tags'].value_counts().sort('count', descending=True)

In [None]:
create_bar_plots(df_eng_caption, ['count_unique_tags'], 1)

## GCC PoS distribution

In [None]:
# Load the GCC training TSV as two string columns, caption first and URL second.
# NOTE(review): `has_header` is left at its default, so the first line of the file
# is treated as a header row. Confirm the TSV really has one; otherwise the first
# caption/url pair is consumed as column names and lost.
# NOTE(review): `schema` names the first column 'captions' while `new_columns`
# names it 'caption'; the cells below use 'caption'. Consider aligning the two.
gcc_train = pl.read_csv(
        'data/Train_GCC-training.tsv',
        separator='\t',
        new_columns = ['caption', 'url'],
        schema = {
            'captions': pl.String,
            'url': pl.String
        }
    )

gcc_train.shape

In [None]:
# gcc_val = pl.read_csv(
#         'data/Validation_GCC-1.1.0-Validation.tsv',
#         separator='\t',
#         new_columns = ['caption', 'url'],
#         schema = {
#             'captions': pl.String,
#             'url': pl.String
#         }
#     )

# gcc_val.shape

# gcc_train_image_label = pl.read_csv(
#         'data/Image_Labels_Subset_Train_GCC-Labels-training.tsv',
#         separator='\t',
#         has_header=False
# )[:, :-2]

# gcc_train_image_label.shape

### Cleaning

In [None]:
condition = gcc_train['caption'].is_duplicated()

gcc_train_duplicated = gcc_train.filter(
    condition
).sort('caption').with_row_index()

gcc_train_duplicated

In [None]:
# Distinct caption texts that occur more than once in the GCC training set.
duplicated_captions = gcc_train.filter(
    pl.col('caption').is_duplicated()
)['caption'].unique().to_list()


# Drop EVERY row whose caption is duplicated, so no copy of a repeated caption is
# kept. This differs from the LAION frame above, which keeps the first occurrence
# via `unique(keep='first')`.
gcc_train_clean = gcc_train.filter(
    ~pl.col("caption").is_in(duplicated_captions)
    )

In [None]:
gcc_train_clean

In [None]:
gcc_train_clean = gcc_train_clean.with_columns(
    gcc_train_clean['caption'].map_elements(
        lambda x: is_english_sentence(x), 
        return_dtype=map_data_types, 
        skip_nulls=True)
        .alias('lang_detection')
).unnest("lang_detection")

In [None]:
gcc_train_clean['is_english'].value_counts()

In [None]:
gcc_train_clean.filter(
    (pl.col('is_english') == True)
    & (pl.col('score') > 0.9)
)

In [None]:
gcc_train_clean = gcc_train_clean.with_columns(
        gcc_train_clean['caption'].map_elements(
            filter_captions, 
            return_dtype=pl.Boolean, 
            skip_nulls=False
        ).alias('caption_long')
)

In [None]:
gcc_train_clean['caption_long'].value_counts(normalize=True)

In [None]:
gcc_eng_caption = gcc_train_clean.filter(
                    (pl.col('is_english') == True)
                    & (pl.col('caption_long') == True)
                     & (pl.col('score') > 0.9)
                ).with_row_index()

In [None]:
gcc_eng_caption = gcc_eng_caption.with_columns(
    gcc_eng_caption['caption'].str.strip_chars_end(' .')
)

### Tagging

In [None]:
def add_missing_columns(df: pl.DataFrame, column_list: list[str]) -> pl.DataFrame:
    """Ensure every name in ``column_list`` exists as a column of ``df``.

    Args:
        df (pl.DataFrame): Frame to complete.
        column_list (list[str]): Column names that must be present.

    Returns:
        pl.DataFrame: ``df`` itself when nothing is missing; otherwise a frame
        with each missing column added and filled with nulls.
    """
    absent = [name for name in column_list if name not in df.columns]
    if not absent:
        return df
    # Each absent column is created as an all-null literal under its own name.
    return df.with_columns(
        [pl.lit(None).alias(name) for name in absent]
    )

In [None]:
# Switch spaCy to the GPU and load the pipeline named by `accuracy_model`.
# `tag_captions` reads the module-level `nlp`, so reassigning it here changes the
# tagger used by every call below.
spacy.require_gpu()
nlp = spacy.load(accuracy_model)

# nlp = spacy.load(efficency_model)


# 1) Tag every caption: one (text, pos_, dep_) tuple per token.
gcc_eng_caption = gcc_eng_caption.with_columns(
    gcc_eng_caption['caption'].map_elements(
    lambda x: tag_captions(x), 
    return_dtype=pl.List, 
    skip_nulls=False).alias('pos_tags')
)

# 2) Flatten the POS tags of each caption into an underscore-terminated string.
gcc_eng_caption = gcc_eng_caption.with_columns(
        gcc_eng_caption['pos_tags'].map_elements(
            lambda x: extract_and_concatenate(x),
            return_dtype=pl.String,
            skip_nulls=False
        ).alias('conc_tags')
    )

# 3) Count tokens per POS tag and expand the counts into one column per tag.
gcc_eng_caption = gcc_eng_caption.with_columns(
    gcc_eng_caption['pos_tags'].map_elements(
        lambda x: count_token_pos(x), 
        return_dtype=pl.Struct, 
        skip_nulls=True)
        .alias('count_tags')
).unnest("count_tags")

# POS tags expected as count columns.
columns_to_replace = [
    'ADJ', 'NOUN', 'PUNCT', 'PROPN', 'VERB', 'SYM', 'NUM', 
    'PART', 'SCONJ', 'ADP', 'DET', 'PRON', 'SPACE', 'CCONJ', 
    'INTJ', 'X', 'AUX', 'ADV'
]

# Tags absent from the unnested struct get a null column first, so the fill below
# can reference every name in `columns_to_replace`.
gcc_eng_caption = add_missing_columns(gcc_eng_caption, columns_to_replace)

# Replace null values in the selected columns only
gcc_eng_caption = gcc_eng_caption.with_columns([
    pl.col(col).fill_null(0) for col in columns_to_replace
])

# 4) Sorted, de-duplicated tag string per caption.
gcc_eng_caption = gcc_eng_caption.with_columns(
    gcc_eng_caption['conc_tags'].map_elements(
        lambda x: extract_unique_tags(x), 
        return_dtype=pl.String, 
        skip_nulls=True)
        .alias('unique_tags')
)

# 5) Number of distinct tags per caption.
gcc_eng_caption = gcc_eng_caption.with_columns(
    gcc_eng_caption['conc_tags'].map_elements(
        lambda x: count_unique_tags(x), 
        return_dtype=pl.Int8, 
        skip_nulls=True)
        .alias('count_unique_tags')
)

In [None]:
# SAVE
# gcc_eng_caption.write_parquet('data/gcc_eng_caption')

In [None]:
# LOAD
gcc_eng_caption = pl.read_parquet('data/gcc_eng_caption')

## Shutterstock PoS distribution

In [None]:
shutterstock = pl.read_csv('data/shutterstock.csv', 
            separator='\t',
            new_columns = ['url', 'caption'],
            schema = {
                'url': pl.String,
                'captions': pl.String
            }
    )

### Cleaning

In [None]:
condition = shutterstock['caption'].is_duplicated()

shutterstock_duplicated = shutterstock.filter(
    condition
).sort('caption').with_row_index()

shutterstock_duplicated

In [None]:
duplicated_captions = shutterstock.filter(
    pl.col('caption').is_duplicated()
)['caption'].unique().to_list()


shutterstock_clean = shutterstock.filter(
    ~pl.col("caption").is_in(duplicated_captions)
    )

shutterstock_clean

In [None]:
shutterstock_clean = shutterstock_clean.with_columns(
    shutterstock_clean['caption'].map_elements(
        lambda x: is_english_sentence(x), 
        return_dtype=map_data_types, 
        skip_nulls=True)
        .alias('lang_detection')
).unnest("lang_detection")

In [None]:
shutterstock_clean['is_english'].value_counts(normalize=True)*100

In [None]:
shutterstock_clean.filter(
    (pl.col('is_english') == True)
    & (pl.col('score') > 0.9)
)

In [None]:
shutterstock_clean = shutterstock_clean.with_columns(
        shutterstock_clean['caption'].map_elements(
            filter_captions, 
            return_dtype=pl.Boolean, 
            skip_nulls=False
        ).alias('caption_long')
)

In [None]:
shutterstock_clean['caption_long'].value_counts(normalize=True)

In [None]:
shutterstock_clean = shutterstock_clean.filter(
                    (pl.col('is_english') == True)
                    & (pl.col('caption_long') == True)
                     & (pl.col('score') > 0.9)
                ).with_row_index()

In [None]:
shutterstock_clean

### Tagging

In [None]:
# Switch spaCy to the GPU and load the pipeline named by `accuracy_model`.
# `tag_captions` reads the module-level `nlp`, so this is the tagger used below.
spacy.require_gpu()
nlp = spacy.load(accuracy_model)

# nlp = spacy.load(efficency_model)


# Tag every Shutterstock caption: one (text, pos_, dep_) tuple per token.
shutterstock_clean = shutterstock_clean.with_columns(
    shutterstock_clean['caption'].map_elements(
    lambda x: tag_captions(x), 
    return_dtype=pl.List, 
    skip_nulls=False).alias('pos_tags')
)

In [None]:
# Same post-processing as applied to the GCC captions, here on the Shutterstock frame.

# 1) Flatten the POS tags of each caption into an underscore-terminated string.
shutterstock_clean = shutterstock_clean.with_columns(
        shutterstock_clean['pos_tags'].map_elements(
            lambda x: extract_and_concatenate(x),
            return_dtype=pl.String,
            skip_nulls=False
        ).alias('conc_tags')
    )

# 2) Count tokens per POS tag and expand the counts into one column per tag.
shutterstock_clean = shutterstock_clean.with_columns(
    shutterstock_clean['pos_tags'].map_elements(
        lambda x: count_token_pos(x), 
        return_dtype=pl.Struct, 
        skip_nulls=True)
        .alias('count_tags')
).unnest("count_tags")

# POS tags expected as count columns.
columns_to_replace = [
    'ADJ', 'NOUN', 'PUNCT', 'PROPN', 'VERB', 'SYM', 'NUM', 
    'PART', 'SCONJ', 'ADP', 'DET', 'PRON', 'SPACE', 'CCONJ', 
    'INTJ', 'X', 'AUX', 'ADV'
]

# Tags absent from the unnested struct get a null column first, so the fill below
# can reference every name in `columns_to_replace`.
shutterstock_clean = add_missing_columns(
        shutterstock_clean, 
        columns_to_replace
    )

# Replace null values in the selected columns only
shutterstock_clean = shutterstock_clean.with_columns([
    pl.col(col).fill_null(0) for col in columns_to_replace
])

# 3) Sorted, de-duplicated tag string per caption.
shutterstock_clean = shutterstock_clean.with_columns(
    shutterstock_clean['conc_tags'].map_elements(
        lambda x: extract_unique_tags(x), 
        return_dtype=pl.String, 
        skip_nulls=True)
        .alias('unique_tags')
)

# 4) Number of distinct tags per caption.
shutterstock_clean = shutterstock_clean.with_columns(
    shutterstock_clean['conc_tags'].map_elements(
        lambda x: count_unique_tags(x), 
        return_dtype=pl.Int8, 
        skip_nulls=True)
        .alias('count_unique_tags')
)

In [None]:
shutterstock_clean

In [None]:
# SAVE
# shutterstock_clean.write_parquet('data/shutterstock_eng_caption')

In [None]:
# LOAD
shutterstock_clean = pl.read_parquet('data/shutterstock_eng_caption')

## Distributions

In [None]:
import plotly.graph_objects as go

# Overlay raw-count histograms of `count_unique_tags` for the two datasets.
fig = go.Figure()
fig.add_traces([
    go.Histogram(x=counts, name=label, opacity=0.75, nbinsx=20)
    for label, counts in (
        ('GCC Dataset', gcc_eng_caption['count_unique_tags']),
        ('Shutterstock Dataset', shutterstock_clean['count_unique_tags']),
    )
])

fig.update_layout(
    title='Distribution of Unique POS Tags Count',
    xaxis_title='Number of Unique POS Tags',
    yaxis_title='Count',
    barmode='overlay',
    width=800,
    height=500,
)

fig.show()

In [None]:
# Row counts per dataset. They are not referenced again in this cell: the
# normalisation below is done by plotly through histnorm='percent'.
gcc_total = len(gcc_eng_caption)
shutterstock_total = len(shutterstock_clean)

# Overlay percentage-normalised histograms of `count_unique_tags`.
fig = go.Figure()
fig.add_traces([
    go.Histogram(
        x=counts,
        name=label,
        opacity=0.75,
        nbinsx=20,
        histnorm='percent',  # Normalize to percentage
    )
    for label, counts in (
        ('GCC Dataset', gcc_eng_caption['count_unique_tags']),
        ('Shutterstock Dataset', shutterstock_clean['count_unique_tags']),
    )
])

fig.update_layout(
    title='Distribution of Unique POS Tags Count (Normalized)',
    xaxis_title='Number of Unique POS Tags',
    yaxis_title='Percentage',
    barmode='overlay',
    width=800,
    height=500,
)

fig.show()

In [None]:
# Distinct `unique_tags` values observed in each dataset.
gcc_patterns = set(gcc_eng_caption['unique_tags'].to_list())
shutterstock_patterns = set(shutterstock_clean['unique_tags'].to_list())

# Patterns shared by both datasets, and patterns exclusive to each one.
common_patterns = gcc_patterns & shutterstock_patterns
gcc_only = gcc_patterns.difference(shutterstock_patterns)
shutterstock_only = shutterstock_patterns.difference(gcc_patterns)

# One summary line per set size.
print(
    f"Total patterns in GCC: {len(gcc_patterns)}",
    f"Total patterns in Shutterstock: {len(shutterstock_patterns)}",
    f"Common patterns: {len(common_patterns)}",
    f"Patterns only in GCC: {len(gcc_only)}",
    f"Patterns only in Shutterstock: {len(shutterstock_only)}",
    sep="\n",
)

In [None]:
# Row counts per dataset. They are not referenced again in this cell: the
# normalisation below is done by plotly through histnorm='percent'.
gcc_total = len(gcc_eng_caption)
shutterstock_total = len(shutterstock_clean)

# Percentage-normalised histograms restricted to rows whose `unique_tags`
# value appears in `common_patterns`.
fig = go.Figure()
fig.add_traces([
    go.Histogram(
        x=frame.filter(
            pl.col('unique_tags').is_in(common_patterns)
        )['count_unique_tags'],
        name=label,
        opacity=0.75,
        nbinsx=20,
        histnorm='percent',  # Normalize to percentage
    )
    for label, frame in (
        ('GCC Dataset', gcc_eng_caption),
        ('Shutterstock Dataset', shutterstock_clean),
    )
])

fig.update_layout(
    title='Distribution of Unique POS Tags Count (Normalized) - Common Patterns',
    xaxis_title='Number of Unique POS Tags',
    yaxis_title='Percentage',
    barmode='overlay',
    width=800,
    height=500,
)

fig.show()

In [None]:
# Row counts per dataset. They are not referenced again in this cell: the
# normalisation below is done by plotly through histnorm='percent'.
gcc_total = len(gcc_eng_caption)
shutterstock_total = len(shutterstock_clean)

# Percentage-normalised histograms restricted, per dataset, to the patterns
# that occur only in that dataset (`gcc_only` / `shutterstock_only`).
fig = go.Figure()
fig.add_traces([
    go.Histogram(
        x=frame.filter(
            pl.col('unique_tags').is_in(exclusive_patterns)
        )['count_unique_tags'],
        name=label,
        opacity=0.75,
        nbinsx=20,
        histnorm='percent',  # Normalize to percentage
    )
    for label, frame, exclusive_patterns in (
        ('GCC Dataset', gcc_eng_caption, gcc_only),
        ('Shutterstock Dataset', shutterstock_clean, shutterstock_only),
    )
])

fig.update_layout(
    title='Distribution of Unique POS Tags Count (Normalized) - Unique Patterns',
    xaxis_title='Number of Unique POS Tags',
    yaxis_title='Percentage',
    barmode='overlay',
    width=800,
    height=500,
)

fig.show()

### No Punct, Zero Noun, Zero Adj

In [None]:
# Keep GCC rows with at least one ADJ, at least one NOUN and no NUM tokens
# (all predicates must hold).
# NOTE(review): the section heading says "No Punct", but the filter is on NUM,
# not PUNCT — confirm which one is intended.
gcc_eng_caption_cleaned = gcc_eng_caption.filter(
    pl.col('ADJ').ne(0),
    pl.col('NOUN').ne(0),
    pl.col('NUM').eq(0),
)

In [None]:
# Keep Shutterstock rows with at least one ADJ, at least one NOUN and no NUM
# tokens (all predicates must hold).
# NOTE(review): the section heading says "No Punct", but the filter is on NUM,
# not PUNCT — confirm which one is intended.
shutterstock_cleaned = shutterstock_clean.filter(
    pl.col('ADJ').ne(0),
    pl.col('NOUN').ne(0),
    pl.col('NUM').eq(0),
)

In [None]:
import plotly.graph_objects as go

# Overlay raw-count histograms of `count_unique_tags` for the filtered
# (`*_cleaned`) frames.
fig = go.Figure()
fig.add_traces([
    go.Histogram(x=counts, name=label, opacity=0.75, nbinsx=20)
    for label, counts in (
        ('GCC Dataset', gcc_eng_caption_cleaned['count_unique_tags']),
        ('Shutterstock Dataset', shutterstock_cleaned['count_unique_tags']),
    )
])

fig.update_layout(
    title='Distribution of Unique POS Tags Count',
    xaxis_title='Number of Unique POS Tags',
    yaxis_title='Count',
    barmode='overlay',
    width=800,
    height=500,
)

fig.show()

In [None]:
# Distinct `unique_tags` values observed in each filtered frame. This rebinds
# the pattern sets computed earlier from the unfiltered frames.
gcc_patterns = set(gcc_eng_caption_cleaned['unique_tags'].to_list())
shutterstock_patterns = set(shutterstock_cleaned['unique_tags'].to_list())

# Patterns shared by both datasets, and patterns exclusive to each one.
common_patterns = gcc_patterns & shutterstock_patterns
gcc_only = gcc_patterns.difference(shutterstock_patterns)
shutterstock_only = shutterstock_patterns.difference(gcc_patterns)

# One summary line per set size.
print(
    f"Total patterns in GCC: {len(gcc_patterns)}",
    f"Total patterns in Shutterstock: {len(shutterstock_patterns)}",
    f"Common patterns: {len(common_patterns)}",
    f"Patterns only in GCC: {len(gcc_only)}",
    f"Patterns only in Shutterstock: {len(shutterstock_only)}",
    sep="\n",
)

In [None]:
# Row counts per dataset; not referenced again in this cell.
# NOTE(review): these are taken from the unfiltered frames although the plot
# below uses the `*_cleaned` frames — confirm which totals are wanted.
gcc_total = len(gcc_eng_caption)
shutterstock_total = len(shutterstock_clean)

# Percentage-normalised histograms of the filtered frames, restricted to rows
# whose `unique_tags` value appears in `common_patterns`.
fig = go.Figure()
fig.add_traces([
    go.Histogram(
        x=frame.filter(
            pl.col('unique_tags').is_in(common_patterns)
        )['count_unique_tags'],
        name=label,
        opacity=0.75,
        nbinsx=20,
        histnorm='percent',  # Normalize to percentage
    )
    for label, frame in (
        ('GCC Dataset', gcc_eng_caption_cleaned),
        ('Shutterstock Dataset', shutterstock_cleaned),
    )
])

fig.update_layout(
    title='Distribution of Unique POS Tags Count (Normalized) - Common Patterns',
    xaxis_title='Number of Unique POS Tags',
    yaxis_title='Percentage',
    barmode='overlay',
    width=800,
    height=500,
)

fig.show()

In [None]:
# Write the shared patterns in sorted order, one per line.
with open('common_pos_patterns.txt', 'w') as f:
    f.writelines(pattern + '\n' for pattern in sorted(common_patterns))

print(f"Saved {len(common_patterns)} patterns to common_pos_patterns.txt")

In [None]:
import pickle

# Save the common_patterns set as a list to a pickle file.
# The previous `f.write(list(common_patterns), f)` could never succeed:
# write() takes a single str argument. pickle.dump() is the call that takes
# (object, file), and because it emits bytes the file is opened in 'wb' mode.
# NOTE(review): the path keeps its original '.txt' name so any reader elsewhere
# still finds it; consider renaming it to '.pkl' once callers are confirmed.
with open('common_patterns.txt', 'wb') as f:
    pickle.dump(list(common_patterns), f)