In [29]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
File Name: main.py
Author: Alexandre Donciu-Julin
Date: 2024-10-14
Description: Main file to run inference on the model.
"""

# load custom modules ------------------------------------------------------------------------------
import os
import pandas as pd
import importlib
import helpers
import prompting
importlib.reload(helpers)
importlib.reload(prompting)

# constants and paths ------------------------------------------------------------------------------
# pickle file paths, relative to the current working directory
DATA_PROCESSED_PKL = 'pickle/data_processed.pkl'
DATA_CLUSTERED_PKL = 'pickle/data_clustered.pkl'
DATA_SA_PKL = 'pickle/data_sentiment_analysis.pkl'
DATA_SCORED_PKL = 'pickle/data_scored.pkl'
# 100-character horizontal rule printed between console sections
SEP = 100 * '-'

# global variables ---------------------------------------------------------------------------------
# to avoid loading the model and tokenizer multiple times:
# keep whatever is already bound in this namespace, so that re-executing this code in a live
# session (notebook cell re-run, importlib.reload) does not throw away a loaded model.
# On a fresh interpreter both names start as None.
model = globals().get('model')
tokenizer = globals().get('tokenizer')


def review_best_product_by_category(
        model: object,
        tokenizer: object,
        category: str,
        n: int = 10
) -> dict:
    """Review the best product in a category based on n reviews.

    The function runs three inference steps and prints progress to stdout:
      1. summarize each sampled review of the category's top product,
      2. extract the recurring ideas from those summaries,
      3. generate a final title and review from the recurring ideas.

    The text produced by step 1 is cached in
    ``pickle/reviews_<category lowercased, spaces replaced by underscores>.pkl``
    and reused on later calls.

    NOTE: the cache file name depends on the category only. A cached summary is
    therefore reused even if ``n`` differs from the call that created it; delete
    the pickle file to force step 1 to run again.

    Args:
        model (object): The model instance to use for inference.
        tokenizer (object): The tokenizer instance to use for inference.
        category (str): The category to extract the best product from.
        n (int, optional): Number of reviews to use for the final review. Defaults to 10.

    Returns:
        dict: The final review, with the keys "category", "product", "title" and "review".
    """

    print(f"{SEP}\nCATEGORY: {category}")

    # get the name of the top product in the category
    top_product_name = helpers.get_top_products_per_category(category, 1)[0]
    print(f"{SEP}\nTOP PRODUCT NAME: {top_product_name}\n{SEP}")

    print(f"{SEP}\nINFERENCE 1: SAMPLE OF PRODUCT REVIEWS:\n")

    # load review text from pickle if available
    pickle_path = os.path.join('pickle', f"reviews_{category.lower().replace(' ', '_')}.pkl")

    if os.path.exists(pickle_path):
        review_text = helpers.load_pickled_reviews(pickle_path)
        print(f"{review_text}")

    else:
        # the sample is only needed on a cache miss, so it is drawn here rather than up front
        sample_reviews = helpers.sample_product_reviews(top_product_name, category, n)

        # summarize every review, echoing each summary as soon as it is produced
        summaries = []
        for position, review in enumerate(sample_reviews, start=1):
            summary = prompting.generate_review_summary(model, tokenizer, review)
            numbered_summary = f"[Review {position}]: {summary}"
            print(numbered_summary)
            summaries.append(numbered_summary)

        # one summary per line, each line terminated by a newline
        review_text = "".join(f"{line}\n" for line in summaries)

        # pickle the review text
        helpers.pickle_list_reviews(review_text, pickle_path)

    # generate recurring ideas
    recurring_ideas = prompting.generate_reviews_recurring_ideas(
        model, tokenizer, review_text, max_tokens=75
    )
    print(f"{SEP}\nINFERENCE 2: RECURRING IDEAS:\n\n{recurring_ideas}")

    # generate final review
    review_title, product_review = prompting.generate_final_review(
        model, tokenizer, top_product_name, recurring_ideas, max_tokens=150
    )
    # only the header is printed here; the review itself is returned to the caller
    print(f"{SEP}\nINFERENCE 3: FINAL REVIEW:\n")

    return {
        "category": category,
        "product": top_product_name,
        "title": review_title,
        "review": product_review,
    }


if __name__ == "__main__":

    # load Hugging Face token to environment
    if not os.environ.get('HF_TOKEN'):
        # getpass reads the token without echoing it, so the secret is not left on screen
        from getpass import getpass
        os.environ['HF_TOKEN'] = getpass('Enter API token for Hugging Face: ')
    else:
        print('Hugging Face token already loaded to environment')

    # load the model and tokenizer; (re)load when either one is missing, since
    # inference needs both of them
    if model is None or tokenizer is None:
        model, tokenizer = prompting.load_model_and_tokenizer()
    else:
        print('Model and tokenizer already loaded')

    # get the unique categories
    categories = helpers.get_categories_from_dataset()

    # review the best product for a given category
    # (the index is hard-coded; change it to review another category)
    category = categories[3]

    review_dict = review_best_product_by_category(model, tokenizer, category)

    # print every field of the review followed by a blank line
    for k, v in review_dict.items():
        print(f"{k}: {v}")
        print()

Hugging Face token already loaded to environment


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

----------------------------------------------------------------------------------------------------
CATEGORY: Pet Supplies
----------------------------------------------------------------------------------------------------
TOP PRODUCT NAME: Cat Litter Box Covered Tray Kitten Extra Large Enclosed Hooded Hidden Toilet
----------------------------------------------------------------------------------------------------


KeyboardInterrupt: 