In [1]:
import os
import re
import pandas as pd
from verstack import NaNImputer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import FunctionTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Move to the project root so that project-specific imports and relative paths
# (e.g. ./data/...) resolve the same way in every environment.
from pyprojroot import here

project_root = here()
os.chdir(project_root)

# Project specific imports
from src.utils import save_dataframe


## Data Querying and Loading

In [3]:
# DB related imports
from database.db_utils import init_db
from config.config_loader import load_config
from database.queries import prepped_data_query
from sqlalchemy import text

# Initialize the local PostgreSQL session factory from the configured DB URL.
Session = init_db(load_config("DB_URL"))
session = Session()
try:
    # Materialise the column names and every row *before* closing the session;
    # fetching from a result after its session was closed relies on the rows
    # having been buffered already.
    result = session.execute(text(prepped_data_query))
    columns = list(result.keys())
    rows = result.fetchall()
finally:
    # Always release the connection, even if the query fails.
    session.close()

# Convert to DataFrame
data = pd.DataFrame(rows, columns=columns)

# Save the base (not yet cleaned) dataset for future use
save_dataframe(data, "00_base_data.csv")

✅ Data successfully saved to ./data\00_base_data.csv with separator ','


In [4]:
# Reload the base snapshot from disk (the path is relative to the project root).
base_data_path = "./data/00_base_data.csv"
data = pd.read_csv(base_data_path)

In [5]:
# Column dtypes and non-null counts of the loaded base dataset.
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5997 entries, 0 to 5996
Data columns (total 22 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   title                    5997 non-null   object 
 1   release_date             5997 non-null   object 
 2   tmdb_vote_count          5997 non-null   int64  
 3   tmdb_vote_average        5997 non-null   float64
 4   genre_names              5997 non-null   object 
 5   budget                   3734 non-null   float64
 6   revenue                  4160 non-null   float64
 7   runtime_in_min           5995 non-null   float64
 8   tmdb_popularity          5997 non-null   float64
 9   production_company_name  5971 non-null   object 
 10  production_country_name  5992 non-null   object 
 11  spoken_languages         5994 non-null   object 
 12  director                 5994 non-null   object 
 13  writer                   5941 non-null   object 
 14  actors                  

In [6]:
# Number of missing (NaN/None) values per column
data.isna().sum()

title                         0
release_date                  0
tmdb_vote_count               0
tmdb_vote_average             0
genre_names                   0
budget                     2263
revenue                    1837
runtime_in_min                2
tmdb_popularity               0
production_company_name      26
production_country_name       5
spoken_languages              3
director                      3
writer                       56
actors                        8
imdb_rating                   9
imdb_votes                    7
metascore                  1489
age_rating                  435
awards                     1003
rotten_tomatoes_rating      928
meta_critic_rating         1489
dtype: int64

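## Multi-label Categorical Features

Several categorical features are multi-label: each cell holds a comma-separated list of values (e.g. `genre_names` or `actors`). Their cardinality and most frequent values are examined below.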

In [7]:
def count_unique_values_for_feature(df: pd.DataFrame, feature: str, delimiter: str = ",") -> int:
    """
    Splits the specified multi-label column by the delimiter and returns the number of
    distinct individual values.

    Each value is stripped of surrounding whitespace (matching print_top_categories), so
    "Drama" and "Drama " are counted once. The delimiter is matched literally, so regex
    metacharacters such as "|" or "." are safe to use.

    Args:
        df (pd.DataFrame): The DataFrame containing the data.
        feature (str): The name of the column to process.
        delimiter (str): The delimiter used to separate multiple values in the column.

    Returns:
        int: The number of unique values (missing cells are ignored).
    """
    # Escape the delimiter so it is not interpreted as a regex; also swallow any
    # whitespace that follows it.
    split_pattern = re.escape(delimiter) + r"\s*"
    values = df[feature].dropna().str.split(split_pattern).explode().str.strip()
    return int(values.nunique())

# Multi-label columns whose cardinality we want to inspect.
features = [
    "genre_names",
    "production_company_name",
    "production_country_name",
    "spoken_languages",
    "director",
    "writer",
    "actors",
]

# Map each column to its number of distinct individual values.
unique_counts = dict(
    (feature, count_unique_values_for_feature(data, feature)) for feature in features
)

# Report one line per column.
for feature in unique_counts:
    print(f"{feature}: {unique_counts[feature]} unique values")

genre_names: 19 unique values
production_company_name: 7238 unique values
production_country_name: 96 unique values
spoken_languages: 100 unique values
director: 3445 unique values
writer: 7741 unique values
actors: 7822 unique values


In [8]:
def print_top_categories(df: pd.DataFrame, column: str, top_n: int, delimiter: str = ",", others_label: str = "Others") -> None:
    """
    Prints the top_n unique values from a multi-label column and the total count of values 
    that fall outside the top_n (which would be grouped as 'Others').

    Args:
        df (pd.DataFrame): The DataFrame containing your data.
        column (str): The name of the multi-label column.
        top_n (int): The number of top categories to display.
        delimiter (str): The literal delimiter separating multiple values (default is a comma).
            Regex metacharacters are escaped.
        others_label (str): The label used for less frequent values.
    """
    # Split on the literal delimiter (plus any following whitespace) and count frequencies.
    split_pattern = re.escape(delimiter) + r"\s*"
    exploded = df[column].dropna().str.split(split_pattern).explode().str.strip()
    counts = exploded.value_counts()

    # value_counts is sorted by descending frequency, so everything after the first
    # top_n entries is what would be grouped under others_label.
    top_categories = counts.head(top_n)
    others_count = counts.iloc[top_n:].sum()

    print("--------------------------------------------------||")
    print(f"Top {top_n} unique values for '{column}':")
    print(top_categories)
    print(f"Total count of all other values (will be grouped as '{others_label}'): {others_count}")
    print("--------------------------------------------------||\n")


# How many of the most frequent values to print for each multi-label column.
top_values = dict(
    genre_names=20,
    production_company_name=20,
    production_country_name=10,
    spoken_languages=10,
    director=20,
    writer=20,
    actors=20,
)

for feature in top_values:
    top_n = top_values[feature]
    print_top_categories(data, feature, top_n)

--------------------------------------------------||
Top 20 unique values for 'genre_names':
genre_names
Drama              2618
Comedy             2053
Thriller           1690
Action             1431
Adventure           960
Horror              938
Romance             907
Crime               797
Science Fiction     743
Fantasy             675
Family              642
Animation           591
Mystery             561
History             311
Music               173
War                 158
Documentary         117
TV Movie             87
Western              57
Name: count, dtype: int64
Total count of all other values (will be grouped as 'Others'): 0
--------------------------------------------------||

--------------------------------------------------||
Top 20 unique values for 'production_company_name':
production_company_name
Universal Pictures       250
Warner Bros. Pictures    215
Columbia Pictures        187
Lionsgate                183
Paramount Pictures       151
20th Century Fox    

## Handling Nulls

### Model-based Imputer

#### Numeric and Categorical Imputer (verstack `NaNImputer`)

In [9]:
# Impute all remaining NaNs with verstack's NaNImputer.
def _impute_with_nanimputer(df: pd.DataFrame) -> pd.DataFrame:
    """
    Fill missing values in every column of ``df`` using verstack's NaNImputer.

    Works on a copy so the caller's frame is not modified.

    NOTE(review): a fresh NaNImputer is created and run on whatever frame is passed in.
    Wrapped in a stateless FunctionTransformer, this means imputation is redone on each
    transform call rather than learned once from the training data. Consider a fitted
    transformer before applying the pipeline to new data.

    Args:
        df (pd.DataFrame): Frame that may contain missing values.

    Returns:
        pd.DataFrame: The imputed frame as returned by ``NaNImputer.impute``.
    """
    imputer = NaNImputer()
    return imputer.impute(df.copy())


# Pipeline step. The helper keeps its own name so this transformer no longer shadows it,
# which means re-running later cells cannot pick up the wrong object.
impute_data = FunctionTransformer(_impute_with_nanimputer, validate=False)

In [10]:
# imputer = NaNImputer()
# data_imputed = imputer.impute(data)

In [11]:
# data_imputed.head(5)

In [12]:
# data_imputed.info()

In [13]:
def convert_to_numeric(df: pd.DataFrame) -> pd.DataFrame:
    """
    Coerce every column of ``df`` to a numeric dtype.

    In text columns, thousands separators (",") are removed before parsing, so "1,234"
    becomes 1234. Anything that still cannot be parsed (e.g. "N/A") becomes NaN.
    Columns that already have a numeric dtype, including bool, are left untouched.
    Previously they were round-tripped through strings, which turned bool columns
    into all-NaN.

    Args:
        df (pd.DataFrame): Input frame; it is not modified.

    Returns:
        pd.DataFrame: A copy of ``df`` with every column numeric.
    """
    df = df.copy()
    for col in df.columns:
        if pd.api.types.is_numeric_dtype(df[col]):
            continue
        # regex=False: treat "," literally regardless of the pandas version's default.
        cleaned = df[col].astype(str).str.replace(",", "", regex=False)
        df[col] = pd.to_numeric(cleaned, errors="coerce")
    return df

# Stateless sklearn wrapper around convert_to_numeric for use inside a ColumnTransformer.
to_numeric = FunctionTransformer(func=convert_to_numeric, validate=False)

In [14]:
# Helper: record which values were missing (before any imputation happens).
def add_missing_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """
    Append a ``<column>_missing`` flag (1 = value was null, 0 = present) for every column.

    The original columns are kept unchanged and the input frame is not modified.

    Args:
        df (pd.DataFrame): Frame whose columns should get missing-value flags.

    Returns:
        pd.DataFrame: A copy of ``df`` with one integer flag column added per input column.
    """
    out = df.copy()
    original_columns = list(out.columns)
    for name in original_columns:
        flag = out[name].isnull().astype(int)
        out[name + "_missing"] = flag
    return out

# Score and money columns with many nulls (see the missing-value counts above);
# these are the columns that get "<col>_missing" flags.
iter_cols = ['metascore', 'rotten_tomatoes_rating', 'meta_critic_rating', 'budget', 'revenue']

# Stateless sklearn wrapper around add_missing_indicators.
missing_indicator_transformer = FunctionTransformer(func=add_missing_indicators, validate=False)

In [15]:
def extract_awards_info(awards_str):
    """
    Extracts numerical awards information from an awards description string.

    Examples of handled text:
        "Won 3 Oscars. 56 wins & 128 nominations total"
        "Nominated for 2 BAFTA Awards. 10 wins & 30 nominations total"
        "Won 1 BAFTA Award28 wins & 50 nominations total"   (text run together)
        "1 win & 1 nomination"

    Parameters
    ----------
    awards_str : str or None/NaN
        The awards description string. Missing values, "N/A" and "" give all zeros.

    Returns
    -------
    pd.Series
        Integer counts with the following index:
        ["total_wins", "total_noms", "oscar_wins", "oscar_noms", "bafta_wins", "bafta_noms"]
    """
    labels = ["total_wins", "total_noms", "oscar_wins", "oscar_noms", "bafta_wins", "bafta_noms"]

    def first_int(pattern, text):
        # First integer captured by `pattern` (case-insensitive), or 0 if there is no match.
        match = re.search(pattern, text, flags=re.IGNORECASE)
        return int(match.group(1)) if match else 0

    # Handle missing or "N/A" values.
    if pd.isna(awards_str):
        return pd.Series([0] * len(labels), index=labels)
    text = str(awards_str).strip()
    if text in ("N/A", ""):
        return pd.Series([0] * len(labels), index=labels)

    # Overall totals, e.g. "56 wins & 128 nominations total" or "1 win & 1 nomination".
    # The digits are not anchored with \b because the text can run together ("Award28 wins").
    total_wins = first_int(r'(\d+)\s+wins?\b', text)
    total_noms = first_int(r'(\d+)\s+nominations?\b', text)

    # Oscars: "Won 3 Oscars" / "Nominated for 1 Oscar". The previous win pattern
    # ("Oscars. 56 wins") captured the overall win total instead of the Oscar count.
    oscar_wins = first_int(r'Won\s+(\d+)\s+Oscars?\b', text)
    oscar_noms = first_int(r'Nominated for\s+(\d+)\s+Oscars?\b', text)

    # BAFTAs: "Won 1 BAFTA Award" / "Nominated for 2 BAFTA Awards". The previous win
    # pattern also picked up the overall win total that follows the BAFTA mention.
    bafta_wins = first_int(r'Won\s+(\d+)\s*BAFTA', text)
    bafta_noms = first_int(r'Nominated for\s+(\d+)\s*BAFTA', text)

    return pd.Series([total_wins, total_noms, oscar_wins, oscar_noms, bafta_wins, bafta_noms],
                     index=labels)


def transform_awards(X):
    """
    Expand a single awards text column into numeric award-count columns.

    Expects X to be a DataFrame with a single column (e.g., 'awards'). Applies
    extract_awards_info row-wise and returns a DataFrame with one row per input row,
    keeping X's index. An empty X gives an empty frame that still has the award-count
    columns, so the column layout is the same for every input.
    """
    awards = X.iloc[:, 0]
    if awards.empty:
        # Series.apply on an empty Series returns a Series rather than a DataFrame, so
        # build the empty result explicitly. Its labels come from the all-zero NaN case.
        columns = extract_awards_info(float("nan")).index
        return pd.DataFrame(index=X.index, columns=columns, dtype="int64")
    # Each call returns a Series, so apply expands the result into columns.
    return awards.apply(extract_awards_info)

# Stateless sklearn wrapper so transform_awards can be used as a ColumnTransformer step.
awards_transformer = FunctionTransformer(func=transform_awards, validate=False)

In [16]:
from functools import partial

def transform_top_categories(X, column, top_n, delimiter=",", others_label="Others"):
    """
    Transforms a multi-label column by keeping only the top_n categories (based on frequency)
    and replacing all other categories with a generic label.

    Frequencies are computed on the ``X`` passed in. NOTE(review): wrapped in a stateless
    FunctionTransformer, this happens on every transform call. New data can therefore get
    a different "top N" than the training data; consider a fitted transformer.

    Parameters:
        X (pd.DataFrame): Input DataFrame (not modified).
        column (str): The name of the multi-label column to process.
        top_n (int): Number of top categories to keep.
        delimiter (str): Literal delimiter separating the values (regex characters are escaped).
        others_label (str): Label to assign to categories not among the top_n.

    Returns:
        pd.DataFrame: A DataFrame with one column (the processed column).
        - Values are re-joined with ``delimiter`` (no space after it).
        - Repeated labels are collapsed, keeping the first occurrence.
        - Missing cells stay missing.
    """
    X = X.copy()
    # Split on the literal delimiter (plus any following whitespace), explode, and count.
    split_pattern = re.escape(delimiter) + r"\s*"
    exploded = X[column].dropna().str.split(split_pattern).explode().str.strip()
    # A set gives O(1) membership checks in map_categories.
    top_categories = set(exploded.value_counts().head(top_n).index)

    def map_categories(cell):
        if pd.isna(cell):
            return cell
        # Split and strip each value, replacing those outside the top categories.
        cats = (cat.strip() for cat in cell.split(delimiter))
        new_cats = (cat if cat in top_categories else others_label for cat in cats)
        # dict.fromkeys removes duplicates while preserving first-seen order.
        return delimiter.join(dict.fromkeys(new_cats))

    X[column] = X[column].apply(map_categories)
    # Return a DataFrame with just the transformed column.
    return X[[column]]

def _top_n_transformer(column: str, top_n: int) -> FunctionTransformer:
    """Wrap transform_top_categories for a single comma-separated column as a FunctionTransformer."""
    return FunctionTransformer(
        func=partial(transform_top_categories, column=column, top_n=top_n, delimiter=",", others_label="Others"),
        validate=False,
    )


# Keep the 5 most frequent production countries; everything else becomes "Others".
transformer_prod_country = _top_n_transformer("production_country_name", 5)

# Same treatment for spoken languages.
transformer_spoken_lang = _top_n_transformer("spoken_languages", 5)

In [17]:
def add_date_features(df: pd.DataFrame, reference_year: int = 2025) -> pd.DataFrame:
    """
    Parse ``release_date`` and derive calendar features from it.

    Added columns:
        release_year, release_month, release_day
        is_weekend         1 if weekday >= 4 (Friday, Saturday or Sunday; Monday = 0), else 0.
                           Friday is included, presumably for theatrical opening
                           weekends; confirm.
        is_holiday_season  1 if released in June, July, November or December, else 0
        movie_age          reference_year - release_year

    Args:
        df (pd.DataFrame): Frame with a ``release_date`` column parseable by ``pd.to_datetime``.
        reference_year (int): Year used to compute ``movie_age``. The default of 2025 is the
            previously hard-coded value. FunctionTransformer calls ``func(X)``, so existing
            pipelines behave the same.

    Returns:
        pd.DataFrame: A copy of ``df`` with ``release_date`` converted to datetime and the
        features above appended.
    """
    df = df.copy()
    release = pd.to_datetime(df['release_date'])
    df['release_date'] = release
    df['release_year'] = release.dt.year
    df['release_month'] = release.dt.month
    df['release_day'] = release.dt.day
    df['is_weekend'] = (release.dt.weekday >= 4).astype(int)
    df['is_holiday_season'] = df['release_month'].isin([6, 7, 11, 12]).astype(int)
    df['movie_age'] = reference_year - df['release_year']
    return df

# Stateless sklearn wrapper around add_date_features for the ColumnTransformer.
date_features_transformer = FunctionTransformer(func=add_date_features, validate=False)

In [18]:
def calculate_roi(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add a return-on-investment column: ``roi = (revenue - budget) / budget``.

    A zero budget would make the ratio +/-inf (or NaN for 0/0). Such rows get NaN instead,
    so infinite values do not leak into downstream features.
    NOTE(review): in full_pipeline this step runs after imputation. Rows whose budget or
    revenue was imputed therefore get an ROI derived from imputed values.

    Args:
        df (pd.DataFrame): Frame with numeric ``budget`` and ``revenue`` columns (not modified).

    Returns:
        pd.DataFrame: A copy of ``df`` with an added ``roi`` column.
    """
    df = df.copy()
    # Mask zero budgets so the division yields NaN rather than inf.
    safe_budget = df['budget'].where(df['budget'] != 0)
    df['roi'] = (df['revenue'] - df['budget']) / safe_budget
    return df

# Stateless sklearn wrapper around calculate_roi, used as a pipeline step.
roi_transformer = FunctionTransformer(func=calculate_roi, validate=False)

In [19]:
# Column-wise feature engineering; every column not listed below passes through unchanged.
main_transformer = ColumnTransformer(
    transformers=[
        # Keep these columns and add "<col>_missing" flags before they are imputed.
        # Reuse iter_cols so the column list is defined in one place only.
        ('missing_indicator', missing_indicator_transformer, list(iter_cols)),
        # Replace the free-text awards column with numeric win/nomination counts.
        ('awards', awards_transformer, ['awards']),
        ('date_feature_engineering', date_features_transformer, ['release_date']),
        ('top_n_prod_country', transformer_prod_country, ['production_country_name']),
        ('top_n_spoken_lang', transformer_spoken_lang, ['spoken_languages']),
        # Strip thousands separators and coerce to numbers.
        ('to_numeric', to_numeric, ['imdb_rating', 'imdb_votes'])
    ],
    remainder='passthrough',
    # Keep the original column names (no "<transformer>__" prefix).
    verbose_feature_names_out=False
)

# Set output to pandas dataframe
main_transformer.set_output(transform='pandas')

# Apply the preprocessor to the data
# clean_data = main_transformer.fit_transform(data)
# clean_data.head()

In [20]:
# End-to-end preprocessing: column-wise feature engineering -> NaN imputation -> ROI feature.
# ROI is added after imputation, so it uses imputed budget/revenue where those were missing.
pipeline_steps = [
    ('main_transformer', main_transformer),
    ('impute_data', impute_data),
    ('roi_feature_engineering', roi_transformer),
    # TODO: re-enable once `dropper` / `columns_to_drop` exist. Pipeline steps must be
    # (name, transformer) pairs, so this 3-tuple would need reworking first:
    # ('dropper', dropper, columns_to_drop)
]
full_pipeline = Pipeline(steps=pipeline_steps)

# Return pandas DataFrames from every step.
full_pipeline.set_output(transform='pandas')

In [21]:
# Fit every pipeline step on the base data and transform it in a single pass
# (the NaNImputer progress log appears in the output below).
clean_data = full_pipeline.fit_transform(data)


 * Initiating NaNImputer.impute
     . Dataset dimensions:
     .. rows:         5997
     .. columns:      38
     .. mb in memory: 1.65
     .. NaN cols num: 15

   - Drop hopeless NaN cols
     . Missing values in production_company_name replaced by "Missing_data" string
     . Missing values in director replaced by "Missing_data" string
     . Missing values in writer replaced by "Missing_data" string
     . Missing values in actors replaced by "Missing_data" string

   - Processing whole data for imputation

   - Imputing single core 11 cols
     . Imputed (regression) - 1489     NaN in metascore
     . Imputed (regression) - 928      NaN in rotten_tomatoes_rating
     . Imputed (regression) - 1489     NaN in meta_critic_rating
     . Imputed (regression) - 2263     NaN in budget
     . Imputed (regression) - 1837     NaN in revenue
     . Imputed (multiclass) - 5        NaN in production_country_name
     . Imputed (multiclass) - 3        NaN in spoken_languages
     . Imputed (

In [22]:
# Check non-null counts and dtypes of the processed and engineered columns.
clean_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5997 entries, 0 to 5996
Data columns (total 39 columns):
 #   Column                          Non-Null Count  Dtype         
---  ------                          --------------  -----         
 0   metascore                       5997 non-null   float64       
 1   rotten_tomatoes_rating          5997 non-null   float64       
 2   meta_critic_rating              5997 non-null   float64       
 3   budget                          5997 non-null   float64       
 4   revenue                         5997 non-null   float64       
 5   metascore_missing               5997 non-null   int64         
 6   rotten_tomatoes_rating_missing  5997 non-null   int64         
 7   meta_critic_rating_missing      5997 non-null   int64         
 8   budget_missing                  5997 non-null   int64         
 9   revenue_missing                 5997 non-null   int64         
 10  total_wins                      5997 non-null   int64         
 11  tota

In [23]:
# Preview the first five processed rows.
clean_data.head(5)

Unnamed: 0,metascore,rotten_tomatoes_rating,meta_critic_rating,budget,revenue,metascore_missing,rotten_tomatoes_rating_missing,meta_critic_rating_missing,budget_missing,revenue_missing,...,tmdb_vote_average,genre_names,runtime_in_min,tmdb_popularity,production_company_name,director,writer,actors,age_rating,roi
0,12.904434,32.0,12.876945,15922140.0,17954290.0,1,0,1,1,1,...,6.0,"Action, Thriller, Crime",132.0,27.043,Apeitda,Jung Byung-gil,"Jung Byung-gil, Byeong-sik Jung","Joo Won, Kim Bo-min, Lee Sung-jae",TV-MA,0.127631
1,14.39881,95.994282,14.208377,9256521.0,5699404.0,1,1,1,1,1,...,7.829,"War, Drama, History",107.0,16.008,Miso Film,Ole Bornedal,Ole Bornedal,"Bertram Bisgaard Enevoldsen, Ester Birch, Ella...",TV-MA,-0.384282
2,48.0,56.0,48.0,13963940.0,12323160.0,0,0,0,1,1,...,6.595,"Family, Comedy, Fantasy",100.0,10.04,Walt Disney Pictures,Marc Lawrence,Marc Lawrence,"Anna Kendrick, Shirley MacLaine, Bill Hader",G,-0.117501
3,13.778372,82.137547,13.737696,10405390.0,7448100.0,1,1,1,1,1,...,7.278,Comedy,108.0,5.214,"Chapter 2, Moonshaker",Anthony Marciano,"Max Boublil, Anthony Marciano","Max Boublil, Alice Isaaz, Malik Zidi",TV-MA,-0.284207
4,13.221361,42.0,13.400988,6523460.0,3222323.0,1,0,1,1,1,...,7.6,Comedy,65.0,7.683,Missing_data,Stan Lathan,Dave Chappelle,Dave Chappelle,TV-MA,-0.506041


In [24]:
# Persist the fully processed dataset so later notebooks can start from it.
clean_output_name = "01_clean_data.csv"
save_dataframe(clean_data, clean_output_name)

✅ Data successfully saved to ./data\01_clean_data.csv with separator ','
