# Setup

In [None]:
from datetime import datetime
# import glob
import matplotlib.pyplot as plt
import numpy as np
# import os
import pandas as pd
import seaborn as sns
from typing import Iterable

In [None]:
# Keep the per-category DataFrames / the per-column unique-value sets around after use?
KEEP_CATEGORY_DFS = False
KEEP_SETS = False

# Columns removed from every category DataFrame during preprocessing
DROPS = {'description', 'director', 'movie_name', 'star'}

# Run the (slow, plotting) data-exploration cells?
EXPLORE = False

# NOTE: filters, as inclusive (low, high) bounds
F_RUNTIME = (30, 300)    # minutes
F_YEAR = (1920, 2020)

In [None]:
# NOTE: Combining the CSVs on disk, as below, is not used anymore, because each
# category's data is cleaned in its own DataFrame before concatenation.
# Previous approach, kept for reference:
#
#     csv_files = glob.glob(os.path.join("Data", '*.csv'))
#     combined_df = pd.concat((pd.read_csv(f) for f in csv_files), ignore_index=True)
#     combined_df.to_csv('movies.csv', index=False)
#     print(f"Combined {len(csv_files)} files into 'movies.csv'.")

# Category names; each one is also the stem of a CSV file in Data/.
categories = set(
    "action adventure animation biography crime family fantasy film_noir "
    "history horror mystery romance scifi sports thriller war".split()
)

In [None]:
# NOTE: Rather than defining each category DataFrame by hand
# (action = pd.read_csv('Data/action.csv'), adventure = pd.read_csv('Data/adventure.csv'), ...),
# the loop below defines them dynamically.

# Bind Data/<name>.csv to a global variable called <name>, for every category.
for name in categories:
    globals()[name] = pd.read_csv(f"Data/{name}.csv")
del name

# globals()['action'].head(10)

## Data Preparation

In [None]:
def split_listlikes(df: pd.DataFrame, splits: Iterable, make_sets: bool) -> None:
    """
    Cleans and splits specified comma-separated columns in a DataFrame.

    For each column ``col`` in ``splits``, a new column ``col + 's'`` is added.
    It holds a set of the stripped, lower-cased, non-empty tokens of the
    original value, or ``None`` where the original value is null. The original
    column is then dropped.

    Extra per-column handling:

    - ``director_id`` / ``star_id``: every token is trimmed from a name URL to
      the bare ID, e.g. ``'/name/nm0000001/'`` becomes ``'nm0000001'``.
    - ``genre``: the first row's ``category`` value is added to every row's
      genre set. Rows with no genre get a set containing only that category.
      The ``category`` column is then dropped.

    Columns that are not present (e.g. because this function has already been
    run on ``df``) are reported and skipped.

    Parameters:
        df: The DataFrame to process.
        splits: Columns to split and clean.
        make_sets: If `True`, stores the set of all unique tokens of each new
            column in a module-level global named ``col + 's'``.

    Returns:
        None: The function modifies the DataFrame in place.
    """

    def split_multiples(df: pd.DataFrame, col: str) -> None:
        """
        Splits the values in ``col`` by `,` and stores them in a new column
        ``col + 's'``, as a set of stripped, lower-cased, non-empty tokens.
        Null values become ``None``.

        Parameters:
            df: The DataFrame to process.
            col: The column to split.

        Returns:
            None: The function modifies the DataFrame in place.
        """

        def to_token_set(x):
            if not pd.notnull(x):
                return None
            tokens = (val.strip().lower() for val in x.split(','))
            # Empty tokens (from "a,,b" or a trailing comma) carry no
            # information, and would break the URL trimming below.
            return {tok for tok in tokens if tok}

        df[col + 's'] = df[col].apply(to_token_set)

    def trim_name_urls(df: pd.DataFrame, col: str) -> None:
        """
        Trims name URLs to just the ID part.
        This is used for director and star IDs.

        i.e. `'/name/nm0000001/'` becomes `'nm0000001'`.
        Tokens that contain no `/` are kept unchanged.

        Parameters:
            df: The DataFrame to process.
            col: The column to trim (holding sets of URL strings).

        Returns:
            None: The function modifies the DataFrame in place.
        """

        def trim(url: str) -> str:
            parts = url.split('/')
            # '/name/nm0000001/' -> ['', 'name', 'nm0000001', ''] -> 'nm0000001'.
            # Guard avoids an IndexError for tokens without any '/'.
            return parts[-2].strip() if len(parts) >= 2 else url

        df[col] = df[col].apply(
            lambda s: {trim(x) for x in s}
                if isinstance(s, set) and len(s) > 0
                else s)

    ##########

    for col in splits:
        col_s = col + 's'

        if col not in df.columns:
            # The column might not exist in the DataFrame (i.e. this function has already been performed)
            print(f"Column '{col}' not found in DataFrame.")
            if make_sets and col_s not in globals():
                # Initialize an empty set to avoid errors later on
                globals()[col_s] = set()
            continue

        # Split the column values by ',' and store them in a new column
        split_multiples(df, col)

        if col_s in ('director_ids', 'star_ids'):
            # Trim name URLs for director and star IDs
            trim_name_urls(df, col_s)
        elif col_s == 'genres' and 'category' in df.columns:
            # Every movie also carries the genre of the file it was listed in.
            # Rows without a genre (None) get just that category instead of
            # raising TypeError on `None | set`.
            if len(df) > 0:
                extra = {df['category'].iloc[0]}
                df[col_s] = df[col_s].apply(
                    lambda s: (s | extra) if isinstance(s, set) else set(extra))
            df.drop(columns=['category'], inplace=True)

        if make_sets:
            # Track unique values for each column
            globals()[col_s] = set().union(*df[col_s].dropna())

        df.drop(columns=[col], inplace=True)  # Drop the original column to save space

def clean_runtime(df: pd.DataFrame) -> None:
    """
    Cleans the runtime column in the DataFrame, converting it to minutes.

    A string such as ``'120 min'`` or ``'1,234 min'`` is reduced to the integer
    in its first whitespace-separated token, with commas removed. Values that
    are already numeric are converted with ``int()``, so running this twice is
    harmless. Nulls, and strings that cannot be parsed (e.g. ``''`` or
    ``'abc'``), become missing instead of raising.

    Parameters:
        df: The DataFrame to process.

    Returns:
        None: The function modifies the DataFrame in place.
    """

    def _parse(x):
        """Helper function to convert one runtime value to int or None."""
        if not pd.notnull(x):
            return None
        try:
            if isinstance(x, str):
                x = x.split()[0].replace(',', '')
            return int(x)
        except (ValueError, TypeError, IndexError):
            # IndexError: empty / whitespace-only string has no first token
            return None

    # Convert runtime to numeric values (in minutes)
    df['runtime'] = df['runtime'].apply(_parse)

def convert_year(df: pd.DataFrame) -> None:
    """
    Replaces the year column with integer years.

    A value is converted with ``int()``. Nulls, and values for which ``int()``
    raises ValueError or TypeError, become missing.

    Parameters:
        df: The DataFrame to process.

    Returns:
        None: The function modifies the DataFrame in place.
    """

    def _to_int_or_none(value):
        """Return ``int(value)``, or None for nulls and unconvertible values."""
        if not pd.notnull(value):
            return None
        try:
            return int(value)
        except (ValueError, TypeError):
            return None

    df['year'] = df['year'].apply(_to_int_or_none)

In [None]:
def prepare(category: str, drops: Iterable = None, make_sets: bool = True) -> None:
    """
    Performs all clean-up / preprocessing steps on one category's DataFrame.

    The DataFrame is the module-level global named ``category``. It is modified
    in place, in this order:

    1. Rows where every value is missing are dropped.
    2. The ``drops`` columns are dropped. If any of them is missing, a notice
       is printed and the function returns, since preprocessing has probably
       already been performed.
    3. A ``category`` column holding ``category`` is added. It is consumed by
       the genre handling in ``split_listlikes``.
    4. List-like columns are split, runtime is cleaned and year is converted.

    Parameters:
        category: Name of the category, which is also the name of the global
            DataFrame to process.
        drops: Columns to drop from the DataFrame. ``None`` drops nothing; a
            single column name may also be passed as a string.
        make_sets: If `True`, creates a set of unique values for each split column.

    Returns:
        None: The function modifies the DataFrame in place.
    """
    df = globals()[category]

    if drops is None:
        drops = set()
    elif isinstance(drops, str):
        drops = {drops}
    else:
        drops = set(drops)

    # Must happen before the 'category' column is added: once every row carries
    # the category, no row is entirely empty and nothing would be dropped.
    df.dropna(inplace=True, how='all')

    if drops:
        try:
            df.drop(columns=list(drops), inplace=True)  # Drop specified columns
        except KeyError:
            # Handle the case where the columns might not exist in the DataFrame
            print(f"Columns {drops} not found in DataFrame.\nHas preprocessing already been performed?")
            return

    # Added only after the check above, so re-running preparation does not put
    # a stray 'category' column back onto an already-processed DataFrame.
    df['category'] = category

    # NOTE: splits is the set of columns to split by ','.
    # The 'director' and 'star' columns are likely not useful, as they are
    # plain text (likely meaningless to a model). If they are included in
    # drops, they are skipped here.
    splits = {
        'genre',
        'director',
        'director_id',
        'star',
        'star_id',
    } - drops

    split_listlikes(df, splits, make_sets)
    clean_runtime(df)
    convert_year(df)

In [None]:
# Run the preprocessing pipeline on every category DataFrame (in place)
for name in categories:
    prepare(name, DROPS, KEEP_SETS)
del name

# NOTE: now, each category's DataFrame has had:
# - `dropna(how='all')` applied,
# - the DROPS columns dropped,
# - its list-like columns split into sets (original columns removed),
# - runtime and year converted to numbers.

In [None]:
# NOTE: testing

# type(globals()['action']['genres'][0])
# print(globals()['genres'])
# globals()['action']['certificate'].value_counts()

### Combination

In [None]:
# Combine all DataFrames in categories into one DataFrame.
# `categories` is a set of strings, whose iteration order changes between
# interpreter runs (string hash randomization). Row order here decides which
# value the later groupby 'first' keeps for duplicated movie_ids, so iterate
# in sorted order to make the result reproducible.
combined_df = pd.concat([globals()[name] for name in sorted(categories)], ignore_index=True)
# combined_df.head()

In [None]:
# NOTE: Exploring duplicate movie IDs

# # Find movie_ids that are duplicated across all DataFrames
# duplicated_ids = combined_df['movie_id'][combined_df['movie_id'].duplicated(keep=False)]

# # Filter rows where movie_id is in the list of duplicated IDs
# duplicates = combined_df[combined_df['movie_id'].isin(duplicated_ids)]
# # duplicates.head()
# duplicates[duplicates['movie_name'] == 'Black Panther: Wakanda Forever'].head()

In [None]:
# NOTE: Minimizing bloat for testing aggregation
# comb_example = combined_df[['movie_id', 'movie_name', 'genres']]
# # comb_example.head()
# dupl_example = duplicates[['movie_id', 'movie_name', 'genres']]
# # dupl_example[dupl_example['movie_name'] == 'Black Panther: Wakanda Forever'].head()

In [None]:
# Combine all rows with the same movie_id into a single row:
# union all genre sets, and keep the first non-null value for the other columns
# (GroupBy 'first' skips missing values). Rows whose movie_id itself is missing
# are excluded by groupby's default dropna=True.

# Identify the columns other than movie_id and genres
other_cols = [col for col in combined_df.columns if col not in ('movie_id', 'genres')]

# Create an aggregation dictionary:
#   - for all "other" columns, take 'first'
#   - for 'genres' (when that column exists), union the sets, skipping any
#     entry that is not a set (None / NaN) instead of failing on it
agg_dict = {col: 'first' for col in other_cols}
if 'genres' in combined_df.columns:
    agg_dict['genres'] = lambda genre_series: set().union(
        *(genres for genres in genre_series if isinstance(genres, set)))

# Perform the groupby aggregation
grouped_df = combined_df.groupby('movie_id', as_index=False).agg(agg_dict)

# grouped_df.info()
# grouped_df[grouped_df['movie_name'] == 'Black Panther: Wakanda Forever'].head()

In [None]:
# combined_df.info()
# grouped_df.info()

In [None]:
# NOTE: This is how many duplicate movie_ids were removed:
# print(combined_df.shape[0] - grouped_df.shape[0])

In [None]:
movies = grouped_df.copy()

# Intermediate objects from the preparation / combination cells (several only
# exist if the commented-out exploration cells were run); missing ones are skipped.
to_del = {'combined_df', 'duplicated_ids', 'duplicates', 'comb_example', 'dupl_example', 'other_cols', 'agg_dict', 'grouped_df'}
if not KEEP_CATEGORY_DFS:
    # Each category DataFrame lives in a global named after its category
    to_del = to_del | categories

for item in to_del:
    globals().pop(item, None)
del to_del, item

# Data Exploration

In [None]:
def get_link(id: str, type: str) -> str:
    """
    Builds the IMDb URL for a movie or a person.

    Parameters:
        id: The ID of the movie or person.
        type: `'movie'` (a ``/title/`` link) or `'person'` (a ``/name/`` link).

    Returns:
        str: The formatted link.

    Raises:
        ValueError: If ``type`` is neither `'movie'` nor `'person'`.
    """
    base_url = "https://www.imdb.com"
    path_segments = (('movie', 'title'), ('person', 'name'))
    for kind, segment in path_segments:
        if type == kind:
            return f"{base_url}/{segment}/{id}/"
    raise ValueError("Type must be either 'movie' or 'person'.")
    
def show_sorted(df, cols: list, order_by: str, max: bool = True, n: int = 10) -> None:
    """
    Prints the top ``n`` rows of ``df`` sorted by ``order_by``.

    Parameters:
        df: The DataFrame to sort.
        cols: Columns to display (any iterable of column names). ``order_by``
            is appended if it is not already included; the caller's list is
            left unmodified.
        order_by: The column name to sort by.
        max: If True, sorts in descending order (largest first); otherwise
            ascending.
        n: The number of rows to print.

    Returns:
        None: The result is printed, not returned.
    """
    # Copy: appending to the caller's list would make it grow on every call.
    cols = list(cols)
    if order_by not in cols:
        cols.append(order_by)
    print(df.sort_values(by=order_by, ascending=not max).head(n)[cols])

In [None]:
# print(get_link(movies['movie_id'][0], 'movie'))

In [None]:
movies_linked = movies.copy()
movies_linked['movie_id'] = \
    movies_linked['movie_id'].apply(
        lambda x: get_link(x, 'movie')
            if pd.notnull(x)
            else None)

# Turn each set of person IDs into a list of profile links.
# - The IDs are sorted first: set iteration order for strings varies between
#   interpreter runs, so the list order would otherwise not be reproducible.
# - Columns that were removed via DROPS are skipped rather than raising KeyError.
for people_col in ('director_ids', 'star_ids'):
    if people_col in movies_linked.columns:
        movies_linked[people_col] = \
            movies_linked[people_col].apply(
                lambda ids: [get_link(pid, 'person') for pid in sorted(ids)]
                    if isinstance(ids, set)
                    else None)
del people_col

In [None]:
# print(movies_linked.head(1))

### Runtime

In [None]:
def plot_runtime_histogram(df, bins: int = 30):
    """
    Plots a histogram of the 'runtime' column with logarithmically spaced bins
    and a logarithmic x-axis, and adds count labels on top of each non-empty bar.

    Only positive runtimes are plotted. Missing values are excluded too, since
    comparisons with NaN are False. If there is nothing to plot, a notice is
    printed instead.

    Parameters:
        df (pandas.DataFrame): DataFrame containing the 'runtime' column.
        bins (int): Number of logarithmically spaced bin edges, i.e.
            ``bins - 1`` bars. Defaults to 30 and must be at least 2.

    Raises:
        ValueError: If ``bins`` is less than 2.
    """
    if bins < 2:
        raise ValueError("bins must be at least 2.")

    # Filter for valid runtime values
    runtimes = df['runtime']
    runtimes = runtimes[runtimes > 0]
    if runtimes.empty:
        print("No positive runtimes to plot.")
        return

    lo, hi = runtimes.min(), runtimes.max()
    if lo < hi:
        # Set up logarithmic bin edges
        edges = np.logspace(np.log10(lo), np.log10(hi), bins)
    else:
        # All runtimes are equal, so log-spaced edges would all coincide
        # (zero-width bins). Let matplotlib build bins - 1 equal-width bins
        # around the single value instead.
        edges = bins - 1

    plt.figure(figsize=(10, 5))
    counts, _, patches = plt.hist(runtimes, bins=edges, color='skyblue', edgecolor='black')

    # Add count labels on top of bars
    for count, patch in zip(counts, patches):
        if count > 0:
            bar_center = patch.get_x() + patch.get_width() / 2
            plt.text(bar_center, count + 0.5, f'{int(count)}', ha='center', va='bottom', fontsize=8, rotation=0)

    plt.xscale('log')
    plt.xlabel('Runtime (log scale)')
    plt.ylabel('Frequency')
    plt.title('Runtimes')
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.show()

In [None]:
if EXPLORE:
    # Runtime distribution over the whole dataset
    plot_runtime_histogram(df=movies)

In [None]:
if EXPLORE:
    # Longest runtimes first
    show_sorted(movies_linked, cols=['movie_id', 'runtime', 'year'], order_by='runtime')
# NOTE: the first one here, "Ekalavya", has a runtime of 138 minutes, not 5,538
# (off by exactly 5,400 minutes, i.e. a spurious +90h?); the others are accurate, though.

In [None]:
if EXPLORE:
    # Keep only movies whose runtime lies within F_RUNTIME (bounds inclusive)
    ml_f_runtime = movies_linked[movies_linked['runtime'].between(F_RUNTIME[0], F_RUNTIME[1])].copy()

In [None]:
if EXPLORE:
    # Runtime distribution after applying the F_RUNTIME filter
    plot_runtime_histogram(df=ml_f_runtime)

### Years

In [None]:
def plot_year_histogram(df, bin_width: int = 4):
    """
    Plots a histogram of the 'year' column, with count labels above each
    non-empty bar.

    Bins are ``bin_width`` whole years wide and start on a whole year, and
    bars are drawn over the bin they count. Each bar therefore spans exactly
    the years ``[a, a + bin_width)`` it represents. Missing years are ignored;
    if no years remain, a notice is printed instead.

    Parameters:
        df (pandas.DataFrame): DataFrame containing the 'year' column.
        bin_width (int): Number of years per bin (default 4, at least 1).

    Raises:
        ValueError: If ``bin_width`` is less than 1.
    """
    if bin_width < 1:
        raise ValueError("bin_width must be at least 1.")

    years = df['year'].dropna()
    if years.empty:
        print("No years to plot.")
        return

    first = int(np.floor(years.min()))
    last = int(np.ceil(years.max()))
    # Edges first, first + w, ... up to the first edge beyond `last`. This
    # guarantees at least one bin (even when all years are equal), and the
    # latest year falls inside a regular [a, a + w) bin.
    edges = np.arange(first, last + bin_width + 1, bin_width)

    plt.figure(figsize=(10, 5))
    counts, _, patches = plt.hist(
        years, bins=edges,
        color='lightcoral', edgecolor='black'
    )

    # Add labels above bars
    for count, patch in zip(counts, patches):
        if count > 0:
            plt.text(
                patch.get_x() + patch.get_width() / 2,
                count + 0.5,
                str(int(count)),
                ha='center', va='bottom', fontsize=8, rotation=0
            )

    plt.xlabel('Release Year')
    plt.ylabel('Number of Movies')
    plt.title('Histogram of Movie Release Years')
    plt.grid(True, linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.show()

In [None]:
if EXPLORE:
    # Release-year distribution over the whole dataset
    plot_year_histogram(df=movies_linked)

In [None]:
if EXPLORE:
    # Most recent release years first
    show_sorted(movies_linked, cols=['movie_id', 'year', 'runtime'], order_by='year')

In [None]:
if EXPLORE:
    ml_f_year = movies_linked.copy()
    # ml_f_year = ml_f_year[(ml_f_year['year'] >= datetime.now().year)]
    ml_f_year = ml_f_year[(ml_f_year['year'] >= F_YEAR[0]) & (ml_f_year['year'] <= F_YEAR[1])]
    # Movies at or beyond the upper year bound, taken from the unfiltered data.
    # ml_f_year is already capped at F_YEAR[1], so filtering it here could only
    # ever return movies from exactly that year.
    ml_f_year_new = movies_linked[movies_linked['year'] >= F_YEAR[1]].copy()

In [None]:
if EXPLORE:
    # Column summary of the movies at or after the upper year bound
    ml_f_year_new.info()

    # (alternative) summary of the year-filtered movies:
    # ml_f_year.info()

In [None]:
if EXPLORE:
    # Release-year distribution after applying the F_YEAR filter
    plot_year_histogram(df=ml_f_year)

### Show Me the Numbahs

In [None]:
if EXPLORE:
    # has_certificate = movies_linked.copy()
    # has_certificate = has_certificate[has_certificate['certificate'].notnull()]

    # has_votes = movies_linked.copy()
    # has_votes = has_votes[has_votes['votes'].notnull()]

    # Movies with a non-null 'gross(in $)' value
    has_gross = movies_linked[movies_linked['gross(in $)'].notnull()].copy()

In [None]:
if EXPLORE:
    # Column summary of the movies that have gross data
    has_gross.info()

    # (alternatives, if those cells are re-enabled)
    # has_certificate.info()
    # has_votes.info()

In [None]:
# gross_pre2015 = has_gross.copy()
# gross_pre2015 = gross_pre2015[gross_pre2015['year'] < 2015]

# gross_post2015 = has_gross.copy()
# gross_post2015 = gross_post2015[gross_post2015['year'] >= 2015]

In [None]:
# gross_pre2015.info()
# gross_post2015.info()

In [None]:
if EXPLORE:
    # Runtime and release-year distributions of the movies that have gross data
    plot_runtime_histogram(df=has_gross)
    plot_year_histogram(df=has_gross)

# Training