In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import ast
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from pprint import pprint
from collections import Counter
from galleries_mapping import *
import matplotlib.pyplot as plt
from sklearn.utils import resample

In [3]:
def plot_tags_count(data: pd.DataFrame | dict):
    if isinstance(data, pd.DataFrame):
        data = Counter([item for sublist in data['labels'] for item in sublist])
    sorted_values_cntr = {k: v for k, v in sorted(data.items(), key=lambda item: item[1], reverse=True)}
    plt.figure(figsize=(10, 10))
    plt.pie(
        sorted_values_cntr.values(),
        labels=sorted_values_cntr.keys(),
        autopct='%1.1f%%', startangle=0
    )
    plt.axis('equal')
    plt.title('Label Distribution in Test Set')

    plt.show()


def plot_variance_per_key(data: dict):
    plt.figure(figsize=(10, 10))
    plt.bar(data.keys(), data.values())
    plt.title('Label Distribution in Test Set')
    plt.show()

In [7]:
SRC_DIR = Path('/Volumes/external_drive')
FILTERED_PORNHUB_CATEGORIES = [
    'anal',
    'bbw',
    'big ass',
    'big dick',
    'big tits',
    'blonde',
    'blowjob',
    'bondage',
    'brunette',
    'cosplay',
    'creampie',
    'cumshot',
    'double penetration',
    'ebony',
    'feet',
    'fingering',
    'fisting',
    'handjob',
    'hardcore',
    'lesbian',
    'massage',
    'masturbation',
    'milf',
    'old/young',
    'pissing',
    'public',
    'pussy licking',
    'red head',
    'rough sex',
    'small tits',
    'smoking',
    'solo',
    'squirt',
    'strap on',
    'striptease',
    'tattooed women',
    'teen',
    'threesome',
    'toys',
    'transgender'
]

try:
    _df = pd.read_csv('datasets/images_high_res_dataset.csv')
except FileNotFoundError:
    _df = pd.read_csv(SRC_DIR / 'images_high_res_dataset.csv')

print("Dataset loaded.")
_df.drop(['models'], axis=1, inplace=True)
_df = _df[_df['categories'].notnull() & _df['categories'].apply(lambda x: x != [])]
_df = _df[_df['categories_suggestions'].notnull() & _df['categories_suggestions'].apply(lambda x: x != [])]
print("Cleaned categories")
_df['categories'] = _df['categories'].apply(ast.literal_eval)
_df['categories_suggestions'] = _df['categories_suggestions'].apply(ast.literal_eval)
print("Parsed list columns categories")

In [5]:
df = _df.copy(True)
df.shape

(1345434, 7)

# Merge suggestions and categories

In [8]:
def merge_categories(row):
    categories = set(row['categories'])
    categories_suggestions = set(row['categories_suggestions'])
    categories_superset = {category.strip().lower() for category in categories.union(categories_suggestions)}
    return list(categories_superset)


# df['labels'] = df['categories'].apply(lambda x: [category.strip().lower() for category in x])
df['labels'] = df.apply(merge_categories, axis=1)
df.drop(['categories_suggestions', 'categories'], axis=1, inplace=True)
df.shape

(1345434, 6)

# Purge nationalities from tags

In [10]:
nationality_tags_to_purge = {
    'african',
    'american',
    'arab',
    'argentina',
    'australian',
    'brazilian',
    'british',
    'canadian',
    'chinese',
    'colombian',
    'cuban',
    'czech',
    'dutch',
    'european',
    'filipina',
    'french',
    'german',
    'hungarian',
    'indian',
    'italian',
    'japanese',
    'korean',
    'mexican',
    'pinay',
    'polish',
    'russian',
    'spanish',
    'thai',
    'ukrainian',
    'venezuela',
    'white'
}
nationality_tags_to_purge = {n.lower() for n in nationality_tags_to_purge}
df['labels'] = df.labels.apply(lambda x: list(set(x) - nationality_tags_to_purge))
df.shape

(1345434, 6)

# Apply gallery mapping

In [11]:
def gallery_mapping(row):
    labels = row.labels

    out = []
    for L in labels:
        _fetched = GALLERIES_MAP.get(L, None)
        if _fetched is remove_tag:
            continue
        elif _fetched is remove_gallery:
            return None
        elif isinstance(_fetched, list):
            out.extend(_fetched)
        elif _fetched is keep_tag:
            out.append(L.lower())

    return list(set(out))


# gallery_mapping(df.iloc[0])
# df.apply(gallery_mapping, axis=1)
df['labels'] = df.apply(gallery_mapping, axis=1)
df = df[df['labels'].notnull()]
df.shape

(1199096, 6)

# Dataframe clean up

In [12]:
df['file_path'] = df['gallery_category'] + '/' + df['gallery_name'] + '/' + df['filename']
df = df.drop(
    ['gallery_category', 'gallery_name', 'filename'], axis=1
)
df.reset_index(inplace=True, drop=True)

In [None]:
df.drop(
    ['height', 'width'], axis=1, inplace=True
)  #.to_csv('datasets/full_dataset_with_labels.csv')

In [None]:
df.shape

# Balancing the dataset

In [None]:
label_counts = Counter(item for sublist in df['labels'] for item in sublist)
ascending_labels = [k for k, v in sorted(label_counts.items(), key=lambda item: item[1])]
label_proportions = {k: v / len(df) for k, v in sorted(label_counts.items(), key=lambda item: item[1])}

In [13]:
def get_rows_with_label(dataframe, label):
    return dataframe[dataframe['labels'].apply(lambda x: label in x)]


local_min = len(df)
balanced_dfs = {label: pd.DataFrame() for label in label_counts.keys()}

for label in tqdm(ascending_labels, total=len(ascending_labels), desc='Balancing dataset'):
    label_df = get_rows_with_label(df, label)
    label_df = df.apply(lambda x: None if len(x.labels) >= 8 else x, axis=1).dropna()

    label_df_length = len(label_df)

    if local_min > label_df_length:
        local_min = label_df_length

    if label_df_length >= local_min:
        _scale = local_min / label_df_length
    else:
        _scale = 1

    n_samples = int(local_min * _scale)
    label_df = resample(
        label_df,
        n_samples=n_samples,
        random_state=42
    )

    balanced_dfs[label] = label_df

balanced_df = pd.concat(balanced_dfs).drop_duplicates(subset='file_path').reset_index(drop=True)
balanced_df.shape

Balancing dataset:  11%|█         | 12/114 [03:51<32:48, 19.30s/it]

KeyboardInterrupt



In [None]:
plot_tags_count(balanced_df)

# One hot encoding

In [None]:
plot_tags_count(df)

In [None]:
def dataframe_one_hot_encoding(dataframe: pd.DataFrame):
    all_labels = set(label for labels in dataframe['labels'] for label in labels)
    one_hot_encoded = pd.DataFrame()

    for label in tqdm(all_labels, total=len(all_labels), desc='One hot encoding'):
        one_hot_encoded[label] = df['labels'].apply(lambda x: 1 if label in x else 0)

In [None]:
final_df = pd.concat([df, dataframe_one_hot_encoding(df)], axis=1)

In [None]:
final_df = final_df.drop(columns=['labels', 'width', 'height'], axis=1)
final_df.reset_index(drop=True, inplace=True)

In [None]:
cols_sorted = sorted(list(final_df.columns))
cols_sorted.remove("file_path")

final_df = final_df[["file_path", *cols_sorted]]

In [None]:
final_df.to_csv("full_one_hot_dataset.csv")

# Inspect redularization methods

In [None]:
df.shape, balanced_df.shape

In [None]:
_full_df = df.copy(True)
_balanced_df = balanced_df.copy(True)

In [None]:
X_full = balanced_df['file_path'].values
y_full = balanced_df['labels'].values

X = balanced_df['file_path'].values
y = balanced_df['labels'].values

In [None]:
# del df, balanced_df