# In [63]:
from collections import Counter

import pandas as pd


def generate_triplets(
    interactions_path,
    items_path,
    start_date=None,
    end_date=None,
    min_count=1,
    output_path="category_triplets.csv",
):
    """
    Count consecutive category triplets in each user's liked interactions and save them.

    Interactions with ``interaction == 1`` are restricted to the requested date
    range, ordered by time per user and mapped to their item's ``cat2`` value
    (interactions whose item has no category are skipped). Every window of
    three consecutive categories whose members are pairwise different counts
    as one triplet. Triplets seen at least ``min_count`` times are written to
    ``output_path``, sorted by descending count.

    The ``time`` column is parsed with ``pandas.to_datetime`` so that filtering
    and ordering are chronological rather than lexicographic. Timezone-aware
    timestamps are not handled specially.

    :param interactions_path: Path to the interactions CSV file
        (columns ``time``, ``user_id``, ``item_id``, ``interaction``).
    :param items_path: Path to the items CSV file (columns ``item_id``, ``cat2``).
        If an ``item_id`` appears more than once, its first row is used.
    :param start_date: Start of the date range (inclusive, format: 'YYYY-MM-DD').
        ``None`` or an empty value means no lower bound.
    :param end_date: End of the date range (inclusive, format: 'YYYY-MM-DD').
        A value without a time of day covers that entire day. ``None`` or an
        empty value means no upper bound.
    :param min_count: Minimum count of triplets to include in the output.
    :param output_path: Path to save the resulting CSV file
        (columns ``user_category_1``, ``user_category_2``,
        ``predicted_category``, ``count``).
    :return: None; the result is written to ``output_path``.
    """
    # Load datasets
    interactions = pd.read_csv(
        interactions_path, usecols=["time", "user_id", "item_id", "interaction"]
    )
    items = pd.read_csv(items_path, usecols=["item_id", "cat2"])

    # Real timestamps make the range filter and the per-user ordering
    # chronological; raw strings would be compared character by character.
    interactions["time"] = pd.to_datetime(interactions["time"])

    # Build one mask for "liked" interactions inside the date range.
    keep = interactions["interaction"] == 1
    if start_date:
        keep &= interactions["time"] >= pd.Timestamp(start_date)
    if end_date:
        end = pd.Timestamp(end_date)
        if end == end.normalize():
            # Date-only bound: include every timestamp of that day.
            keep &= interactions["time"] < end + pd.Timedelta(days=1)
        else:
            keep &= interactions["time"] <= end

    # Sort by user_id and time for sequential processing
    liked = interactions.loc[keep, ["user_id", "time", "item_id"]].sort_values(
        by=["user_id", "time"]
    )

    # Map item_id to cat2 only for the rows that survived filtering. Duplicate
    # item rows are dropped first because Series.map needs a unique index.
    category_by_item = items.drop_duplicates(subset="item_id").set_index("item_id")[
        "cat2"
    ]
    liked = liked.assign(cat2=liked["item_id"].map(category_by_item)).dropna(
        subset=["cat2"]
    )

    # Count sliding windows of three pairwise-distinct categories per user.
    triplet_counter = Counter()
    for _, user_categories in liked.groupby("user_id", sort=False)["cat2"]:
        seq = user_categories.tolist()
        triplet_counter.update(
            (first, second, third)
            for first, second, third in zip(seq, seq[1:], seq[2:])
            if first != second and second != third and first != third
        )

    # most_common() orders by descending count and keeps first-seen order for
    # ties, so the output file is deterministic.
    rows = [
        (*triplet, count)
        for triplet, count in triplet_counter.most_common()
        if count >= min_count
    ]
    result_df = pd.DataFrame(
        rows,
        columns=["user_category_1", "user_category_2", "predicted_category", "count"],
    )

    # Save the sorted triplets with counts to a CSV file
    result_df.to_csv(output_path, index=False)

    print(f"Output saved to '{output_path}'.")


# Example usage
# Runs when the file is executed directly (or as a notebook cell), but not
# when it is imported as a module, so importing never triggers file I/O.
if __name__ == "__main__":
    generate_triplets(
        interactions_path="../data/interactions.csv",
        items_path="../data/items.csv",
        start_date="2024-12-10",
        end_date="2024-12-11",
        min_count=1,
        output_path="../data/category_triplets.csv",
    )


# Output saved to '../data/category_triplets.csv'.
