In [None]:
# NOTE(review): the triple-quoted string below is an earlier, per-headline
# version of this cell, disabled by wrapping it in a string literal; none of
# the code inside it is executed. It scored each headline individually,
# joined the scores back on a `row_id` built with monotonically_increasing_id(),
# and wrote to /mnt/nyt/sentiment_augmented. Consider deleting it and relying
# on version-control history instead.
'''
#!pip install transformers
#!pip install torch
%pip install torch==1.13.1 transformers==4.26.1 --quiet



from pyspark.sql import SparkSession
from pyspark.sql.functions import col, monotonically_increasing_id
from transformers import pipeline
import torch


DELTA_PATH = "/mnt/nyt/archive_yearly"  # Your previously ingested data path
MODEL = "assemblyai/distilbert-base-uncased-sst2"  # Lightweight, accurate model
BATCH_SIZE = 64                            # Adjust for memory (Databricks CE = 15GB RAM)

# Load data
spark = SparkSession.builder.getOrCreate()
df = spark.read.format("delta").load(DELTA_PATH)
df = df.filter(col("headline").isNotNull())  # Remove nulls

# Add a unique ID to every row for safe join later
df_with_id = df.withColumn("row_id", monotonically_increasing_id())


# 4. COLLECT HEADLINES LOCALLY FOR SENTIMENT ANALYSIS
rows = df_with_id.select("row_id", "headline").collect()
headlines = [(row["row_id"], row["headline"]) for row in rows]


# 5. LOAD SENTIMENT ANALYSIS PIPELINE
# NOTE: Avoid `device_map="cuda"` in CE (no GPU support)
sentiment_pipeline = pipeline("sentiment-analysis", model=MODEL, tokenizer=MODEL, truncation=True)

# Run batch sentiment analysis
results = []
for row_id, headline in headlines:
    try:
        pred = sentiment_pipeline(headline)[0]
        label = "positive" if pred["label"] == "LABEL_1" else "negative"
        score = float(pred["score"])
        results.append((row_id, headline, label, score))
    except Exception as e:
        results.append((row_id, headline, "error", 0.0))  # fallback on error
        print(f"Error processing row_id {row_id}: {e}")

# --------------------------------------
# 6. CREATE SPARK DATAFRAME WITH SENTIMENT
sentiment_schema = ["row_id", "headline", "sentiment_label", "sentiment_score"]
sentiment_df = spark.createDataFrame(results, sentiment_schema)
sentiment_df = sentiment_df.withColumnRenamed("headline", "headline_sentiment")


# --------------------------------------
# 7. JOIN BACK TO ORIGINAL DATAFRAME
augmented_df = df_with_id.join(sentiment_df, on="row_id", how="left")

# --------------------------------------
# 8. DISPLAY RESULTS
augmented_df.select(
    "headline", "sentiment_label", "sentiment_score", "pub_date"
).orderBy("sentiment_score", ascending=False).show(20, truncate=False)

# --------------------------------------
# 9. (OPTIONAL) SAVE TO DELTA TABLE
OUTPUT_PATH = "/mnt/nyt/sentiment_augmented"
augmented_df.write.format("delta").mode("overwrite").save(OUTPUT_PATH)
'''


# 1. Imports
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, to_timestamp, window,
    collect_list, count, monotonically_increasing_id
)
from transformers import pipeline

# Configuration
MODEL_NAME = "assemblyai/distilbert-base-uncased-sst2"  # HF sentiment model
BATCH_SIZE = 64  # pipeline batch size: lower it on OOM, raise it if too slow

spark = SparkSession.builder.getOrCreate()

# 2. Load the NYT archive table, drop rows without a headline, parse pub_date
#    into a proper timestamp and keep only the columns used downstream.
articles = (
    spark.table("nyt_archive")
    .where(col("headline").isNotNull())
    .withColumn("timestamp", to_timestamp("pub_date"))
    .select("timestamp", "headline", "topic")
)

# Group articles: one output row per (24-hour window, topic) carrying the
# window bounds, the list of headlines in that group and the article count.
grouped = (
    articles.groupBy(
        window(col("timestamp"), "24 hours").alias("time_window"),
        "topic",
    )
    .agg(
        collect_list(col("headline")).alias("headlines"),
        count("*").alias("article_count"),
    )
    .select(
        col("time_window")["start"].alias("window_start"),
        col("time_window")["end"].alias("window_end"),
        "topic",
        "headlines",
        "article_count",
    )
    # NOTE: monotonically_increasing_id() depends on partition IDs, so these
    # IDs only stay stable across several uses of `grouped` if it is
    # materialized (e.g. checkpointed) first.
    .withColumn("group_id", monotonically_increasing_id())
)

# Run Sentiment Analysis
#
# `group_id` is produced by monotonically_increasing_id(), which Spark
# documents as non-deterministic (its value depends on partition IDs), and
# `grouped` ends in a shuffle. `grouped` is used twice below: collect() and
# the final join. Without materialization, each use recomputes it and may give
# the same group a different ID, silently attaching counts to the wrong
# (window, topic) row. localCheckpoint() computes `grouped` once and
# truncates its lineage, so both uses read the same rows and IDs.
grouped = grouped.localCheckpoint()

groups_local = grouped.select("group_id", "headlines").collect()

sentiment_pipe = pipeline(
    "sentiment-analysis",
    model=MODEL_NAME,
    tokenizer=MODEL_NAME,
    truncation=True,
    batch_size=BATCH_SIZE,
)

# Pipeline labels counted as positive; anything else counts as negative.
# "LABEL_1" is what this script has always mapped to positive; "POSITIVE" is
# accepted too in case the model config names its labels.
# NOTE(review): confirm against the model's id2label mapping.
POSITIVE_LABELS = {"LABEL_1", "POSITIVE"}

# Score every headline in a single pipeline call so the pipeline can fill
# each batch of BATCH_SIZE, instead of one call per (often small) group.
all_headlines = [h for row in groups_local for h in row["headlines"]]
all_positive = (
    [p["label"].upper() in POSITIVE_LABELS for p in sentiment_pipe(all_headlines)]
    if all_headlines
    else []
)

# Slice the flat predictions back into per-group counts; predictions come
# back in the same order as `all_headlines` was built above.
results = []
offset = 0
for row in groups_local:
    n_headlines = len(row["headlines"])
    pos_cnt = sum(all_positive[offset:offset + n_headlines])
    neg_cnt = n_headlines - pos_cnt
    offset += n_headlines
    maj_sent = "positive" if pos_cnt >= neg_cnt else "negative"  # tie -> positive
    results.append((row["group_id"], pos_cnt, neg_cnt, maj_sent))

# Build a tiny DataFrame to join back. The explicit schema keeps this working
# when `results` is empty (schema inference from an empty list fails).
schema = (
    "group_id BIGINT, positive_count BIGINT, "
    "negative_count BIGINT, majority_sentiment STRING"
)
sentiment_df = spark.createDataFrame(results, schema)

# Final dataframe: attach the counts to each (window, topic) group
final_df = (
    grouped
      .join(sentiment_df, on="group_id", how="left")
      .select(
          "window_start", "window_end", "topic",
          "article_count", "positive_count", "negative_count",
          "majority_sentiment"
      )
      .orderBy("window_start", "topic")
)


# Preview the first rows, then persist the result as a Delta table
final_df.show(n=20, truncate=False)

OUTPUT_PATH = "/mnt/nyt/sentiment_daily_topic"
final_df.write.save(OUTPUT_PATH, format="delta", mode="overwrite")

