# In [None]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax

# Sentences to score, read from the filtered CSV export
df = pd.read_csv("../data/filtered_sentences.csv")

# Entity strings whose sentiment is looked up in each sentence
entities = ["Tony", "Blair", "Tony Blair"]

# Prefer the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained aspect-based sentiment checkpoint and its matching tokenizer
model_name = "yangheng/deberta-v3-base-absa-v1.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Inference only: place the weights on the chosen device and switch to eval mode
model.to(device)
model.eval()

def get_sentiment(sentence, entity):
    """
    Classify the sentiment expressed towards ``entity`` within ``sentence``.

    Args:
        sentence: Text that mentions the entity.
        entity: Aspect/entity string whose sentiment is wanted.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is one of ``"negative"``,
        ``"neutral"`` or ``"positive"``; ``confidence`` is the softmax
        probability of that label as a NumPy float scalar.
    """
    # Hand the sentence and the aspect to the tokenizer as a text pair so that
    # it inserts the special tokens and segment ids itself. Spelling "[CLS]" /
    # "[SEP]" inside the string duplicates them, because the tokenizer adds
    # its own special tokens by default, and encodes everything as a single
    # segment.
    encoded_input = tokenizer(sentence, entity,
                              return_tensors="pt", truncation=True).to(device)

    with torch.no_grad():
        logits = model(**encoded_input).logits

    # Single-example batch: take row 0 of the class probabilities
    probs = softmax(logits, dim=1).cpu().numpy()[0]

    # NOTE(review): index order is assumed to match the checkpoint's label
    # ids — confirm against model.config.id2label.
    labels = ("negative", "neutral", "positive")
    best = probs.argmax()
    return labels[best], probs[best]

# Dummy DataFrame (replace this with your real one)
# df = pd.DataFrame({'sentences': ["Tony Blair was praised.", "Tony made a controversial decision."]})

# Create a new column with a list of sentiment results per entity
def analyze_entities(df, entities):
    """
    Attach per-entity sentiment results to every row of ``df``.

    For each value in ``df["sentences"]``, every entity that occurs in it as a
    substring is scored with ``get_sentiment``. The outcome is stored in a new
    ``"entity_sentiments"`` column as a dict mapping the entity to
    ``{"sentiment": label, "confidence": probability rounded to 3 places}``.
    Rows that mention none of the entities, or whose cell is not a string
    (e.g. NaN from an empty CSV field), get an empty dict.

    Args:
        df: DataFrame with a ``"sentences"`` column; modified in place.
        entities: Iterable of entity strings to look for.

    Returns:
        The same ``df`` object, with the ``"entity_sentiments"`` column added.
    """
    results = []

    for sentence in df["sentences"]:
        entity_sentiments = {}

        # Missing cells are not strings (pandas uses a float NaN / None), and
        # `entity in sentence` would raise TypeError on them, so skip the
        # lookup and keep an empty result for that row.
        if isinstance(sentence, str):
            for entity in entities:
                if entity not in sentence:
                    continue
                sentiment_label, confidence = get_sentiment(sentence, entity)
                entity_sentiments[entity] = {
                    "sentiment": sentiment_label,
                    "confidence": round(float(confidence), 3),
                }

        results.append(entity_sentiments)

    df["entity_sentiments"] = results
    return df

# Score every sentence; the result column is added to `df` itself
df = analyze_entities(df, entities)

# Show the first rows: sentence text next to its per-entity sentiment dict
print(df.head()[["sentences", "entity_sentiments"]])
