In [1]:
!pip install -U FlagEmbedding
!pip install -U angle-emb
!pip install openpyxl



In [2]:
!pip install datasets



In [3]:
import warnings
import pandas as pd
import numpy as np
from numpy.linalg import norm
from FlagEmbedding import FlagModel
# from angle_emb import AnglE
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from sentence_transformers import  util
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

## configuration

In [4]:
file_path = r"Z:\BayWatch\yehoshua\dev\Netivot\notebook\twitter_embedding\all_data_us__updated_to_2024-01-11__16_54.xlsx"  # path to all data twitter sentiment file
model_name = "e5"  # e.g bge,angle,e5
# bge_model = FlagModel('BAAI/bge-large-zh-v1.5',
#                   use_fp16=True, normalize_embeddings=False) # Setting use_fp16 to True speeds up computation with a slight performance degradation
# angle_model = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda()
model = AutoModel.from_pretrained('intfloat/e5-large-v2')
num_samples = 500

In [5]:
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    # source: https://huggingface.co/intfloat/e5-large-v2
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

In [6]:
def compute_similarity(query_emb, embs):
  if not np.isclose(np.linalg.norm(query_emb), 1, atol=1e-12):
                print ("query embedding is not normalized: ", str(np.linalg.norm(query_emb, axis=0)), "normalizing...")
                query_emb = query_emb / np.linalg.norm(query_emb, axis=0)
  if not np.isclose(np.linalg.norm(embs), 1, atol=1e-12):
              print ("sentences embedding is not normalized: ", str(np.linalg.norm(embs, axis=0)), "normalizing...")
              embs = embs / np.linalg.norm(embs, axis=0)
  return query_emb @ embs.T

In [7]:
def compute_cosine_similarity(a, b):
  if a.ndim == 1:
    a = np.expand_dims(a, axis=0)
  if b.ndim == 1:
    b = np.expand_dims(b, axis=0)
  cos_sim = (a @ b.T) / (norm(a, axis=1)*norm(b, axis=1))
  return cos_sim.squeeze()

In [8]:
def encode_bge(model, sentences):
  return model.encode(sentences)


def encode_angle(model, sentences):
  return model.encode(sentences, to_numpy=True)


def encode_e5(model, sentences):
  tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-large-v2')
  batch_dict = tokenizer(sentences, max_length=512, padding=True, truncation=True, return_tensors='pt')
  outputs = model(**batch_dict)
  return average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

In [9]:
def embed_text(model, model_name, query, sentences):
    query_emb = eval(f"encode_{model_name}(model, query)")
    sentence_emb = eval(f"encode_{model_name}(model, sentences)")
    return query_emb, sentence_emb

In [10]:
tweets_df = pd.read_excel(file_path)

In [11]:
tweets_df_processed = tweets_df.copy(deep=True)
# drop NaN
tweets_df_processed.dropna(subset=['israel_related_similarity', 'conflict_sentiment', 'conflict_sentiment_v'], inplace=True)
# Convert conflict_sentiment to numeric
tweets_df_processed['conflict_sentiment'] = pd.to_numeric(tweets_df_processed['conflict_sentiment'], errors='coerce')
tweets_df_processed['conflict_sentiment_v'] = tweets_df_processed['conflict_sentiment_v'].astype(int)
# Convert conflict_sentiment values which are different than 0 to 1
tweets_df_processed['conflict_sentiment'] = tweets_df_processed['conflict_sentiment'].apply(lambda x: 0 if x == 0 else 1)
tweets_df_processed['conflict_sentiment_v'] = tweets_df_processed['conflict_sentiment_v'].apply(lambda x: 0 if x == 0 else 1)

In [12]:
query = "The ongoing conflict between Israel and Gaza involves military actions, political tensions, and humanitarian concerns"
sentences = tweets_df_processed["post_text"].tolist()

In [16]:
query_emb, sentence_emb = embed_text(model, model_name, query, sentences[:num_samples])

## compute_cosine_similarity

In [17]:
similarity = compute_cosine_similarity(query_emb.detach().numpy(), sentence_emb.detach().numpy())

In [18]:
# ada_similarity = tweets_df_processed['israel_related_similarity'].values[:num_samples]

In [19]:
def find_threshold(df, label_column, similarity, error_rate=10, print=False):
  threshold_list, precision_list, recall_list = [], [], []
  thresholds = similarity
  thresholds = list(set([round(float(value), 3) for value in thresholds]))
  thresholds.sort(reverse=True)
  max_f1 = 0
  findings = {}
  for thr in thresholds:
    labels = df[label_column].astype(int)
    predictions = similarity >= thr
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    if print: print(f'Threshold: {thr} \t Precision: {precision:.4f} \t Recall: {recall:.4f} \t F1-score: {f1:.4f}')
    if f1 > max_f1:
      max_f1 = f1
      findings['best_f1'] = {
        "thr": round(thr, 2),
        "precision": round(precision, 2),
        "recall": round(recall, 2),
        "f1": round(f1, 2)
      }
    if recall >= 1 - error_rate/100 and 'error_rate_condition' not in findings:
      findings['error_rate_condition'] = {
        "thr": round(thr, 2),
        "precision": round(precision, 2),
        "recall": round(recall, 2),
        "f1": round(f1, 2),
        "num_rows": np.count_nonzero(predictions),
        "confusion_matrix": confusion_matrix(labels, predictions)
      }
    threshold_list.append(thr)
    precision_list.append(precision)
    recall_list.append(recall)
  return(threshold_list, recall_list, precision_list, findings)

In [22]:
warnings.filterwarnings('ignore')
threshold_list, recall_list, precision_list, findings = find_threshold(tweets_df_processed[:num_samples], label_column='conflict_sentiment', similarity=similarity, error_rate=10)
findings

{'best_f1': {'thr': 0.77, 'precision': 0.5, 'recall': 1.0, 'f1': 0.67},
 'error_rate_condition': {'thr': 0.77,
  'precision': 0.5,
  'recall': 1.0,
  'f1': 0.67,
  'num_rows': 2,
  'confusion_matrix': array([[48,  1],
         [ 0,  1]], dtype=int64)}}