In [0]:
# Locations of the two raw CSV inputs on the mounted volume
feedback_file_path, product_usage_file_path = (
    "/mnt/AI-Boost-Project/feedback.csv",
    "/mnt/AI-Boost-Project/product_usage.csv",
)

In [0]:
# Read both CSV files into Spark DataFrames using the same reader settings
# (first line is the header, column types are inferred from the data).
feedback_df, product_usage_df = (
    spark.read.format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
    .load(csv_path)
    for csv_path in (feedback_file_path, product_usage_file_path)
)

In [0]:
# Merge the DataFrames on customer_id (left join: every feedback row is kept)
customer_df = feedback_df.join(
    product_usage_df,
    on="customer_id",
    how="left",
)

# Show the merged result
display(customer_df)

In [0]:
import os
from huggingface_hub import InferenceClient
from pyspark.sql import Row


# Initialize the Hugging Face inference client; the token is read from the
# environment so it is never stored in the notebook.
client = InferenceClient(
    provider="cerebras",
    api_key=os.environ["HF_TOKEN"],
)

# Pull the message column to the driver. Null messages are dropped here because
# str.join() in the prompt below raises TypeError on None.
messages_list = [
    row['message']
    for row in customer_df.select("message").collect()
    if row['message'] is not None
]


# Create prompt for theme extraction. The message count is computed from the
# data instead of being hard-coded, so the prompt stays accurate when the
# input files change.
prompt = f"""
You are a data analyst. Below is a list of {len(messages_list)} customer support messages.

Your task:
- Identify exactly 3 to 5 main recurring themes (do not echo messages).
- Each theme must be no more than 5 words.
- **Return ONLY an array of strings.**
- Do NOT include any extra words or explanation


Messages:
{chr(10).join(messages_list)}
"""


# Call the model (temperature=0 to keep the theme list as stable as possible)
completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": prompt}],
    max_tokens=300,
    temperature=0
)


# Extract themes from the response; fall back to the raw completion repr if the
# message object has no `content` attribute.
response_message = completion.choices[0].message
if hasattr(response_message, 'content'):
    themes = response_message.content
else:
    themes = str(completion)

print(themes)



In [0]:
# Collect customer ids and messages from the Spark DataFrame in ONE action.
# Two separate collect() calls are two independent Spark jobs over an unordered
# join result, so their row order is not guaranteed to match; zipping them
# could pair a message with the wrong customer_id.
collected_rows = customer_df.select("customer_id", "message").collect()
messages_list = [row['message'] for row in collected_rows]
customer_ids = [row['customer_id'] for row in collected_rows]


# Function to map a message to a theme
def classify_message(message, themes):
    """Ask the chat model to pick exactly one theme for a customer message.

    Args:
        message: Text of the customer message to classify.
        themes: The theme options; inserted verbatim into the prompt.

    Returns:
        The model's reply with surrounding whitespace stripped, or "Unknown"
        when the message is missing/blank, the request fails, or the reply
        is empty.
    """
    # Nothing to classify: skip the API round-trip entirely.
    if message is None or not str(message).strip():
        return "Unknown"

    prompt = f"""
You are an assistant helping classify customer messages.
Choose exactly ONE theme from this list of options:

{themes}

Message:
{message}

Answer with only the theme text.
"""
    # The request itself lives inside the try block so that a single failed
    # call does not abort a long classification loop; the caller gets the
    # same "Unknown" fallback that a malformed response produces.
    try:
        completion = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=300,
            temperature=0
        )
        answer = completion.choices[0].message.content.strip()
    except Exception as exc:
        # Report the failure instead of swallowing it silently.
        print(f"classify_message: falling back to 'Unknown' ({exc})")
        return "Unknown"

    # An empty reply carries no theme; normalise it to the fallback label.
    return answer or "Unknown"


# Classify every message and pair the predicted theme with its customer id
results = [
    Row(customer_id=cid, message=msg, theme=classify_message(msg, themes))
    for cid, msg in zip(customer_ids, messages_list)
]

# Turn the list of Rows back into a Spark DataFrame and show it
result_df = spark.createDataFrame(results)
display(result_df)

In [0]:
# Customer attributes to carry alongside each classified message
info_columns = ["customer_id", "created_at", "total_spend", "subscription_tier"]
customer_info_df = customer_df.select(*info_columns)

# Attach the customer attributes to the classified messages.
# NOTE(review): customer_info_df has one row per row of customer_df, so if a
# customer_id occurs more than once this join multiplies rows — confirm that
# customer_id is unique here or de-duplicate before joining.
final_df = result_df.join(customer_info_df, on="customer_id", how="left")

# Persist the result, replacing any previous contents of the table
final_df.write.mode("overwrite").saveAsTable("workspace.default.final_customer_themes")

#Theme Table

In [0]:
# Display the final DataFrame (classified messages joined with customer attributes)
display(final_df)

#Joined View

In [0]:
from pyspark.sql import functions as F
from pyspark.sql import Window

# Per-theme aggregates: number of rows in final_df and mean total_spend
theme_counts = final_df.groupBy("theme").agg(
    F.count("*").alias("customer_count"),
    F.avg("total_spend").alias("avg_total_spend")
)

# Retrieve one sample message per theme. The random ordering is seeded so that
# re-running the notebook on the same data does not silently change the
# sample shown in the report.
window = Window.partitionBy("theme").orderBy(F.rand(seed=42))
sample_messages = (
    result_df.withColumn("rn", F.row_number().over(window))
    .filter(F.col("rn") == 1)
    .select("theme", F.col("message").alias("sample_message"))
)

# Join counts and sample messages
joined_view = theme_counts.join(sample_messages, on="theme", how="left")

# Show the joined view
display(joined_view)

#Theme vs Subscription Tier Chart

In [0]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Convert Spark DataFrame to Pandas for plotting
pdf = final_df.toPandas()

# Count number of messages per theme × subscription tier
theme_tier_counts = pdf.groupby(['theme', 'subscription_tier']).size().reset_index(name='count')

# Upper y-limit: never below the original fixed value of 18, but grows with
# ~10% headroom when a bar is taller, so bars are no longer clipped.
max_count = theme_tier_counts['count'].max() if not theme_tier_counts.empty else 0
y_upper = max(18, int(max_count * 1.1) + 1)

plt.figure(figsize=(12,6))
sns.barplot(
    x='theme',
    y='count',
    hue='subscription_tier',
    data=theme_tier_counts,
)
plt.title('Number of Issues by Theme and Subscription Tier')
plt.xlabel('Theme')
plt.ylabel('Number of Messages')
plt.xticks(rotation=45, ha='right')  # theme labels are long; tilt to avoid overlap
plt.ylim(0, y_upper)
plt.legend(title='Subscription Tier')
plt.tight_layout()
plt.show()
