Note, on mac m1, use conda to install transformers:
```
conda install -c huggingface transformers
```

In [1]:
import pandas as pd
import re
import string
from transformers import pipeline
from tqdm import tqdm
from sklearn.metrics import accuracy_score

In [2]:
# config
APPLY_SAMPLE = False
SAMPLE_SIZE = 500

TWEET_DATA = '../data/01_raw/product_sentiment.csv'
LABEL_MAPPING = { 
    "No emotion toward brand or product": {
        "alt_label": "NEU",
        "class": 1,
    },
    "Positive emotion": {
        "alt_label": "POS",
        "class": 2,
    },
    "Negative emotion": {
        "alt_label": "NEG",
        "class": 0,
    },
}
USE_CLEANED_TWEET = True

BASE_MODEL = "finiteautomata/bertweet-base-sentiment-analysis"
# BASE_MODEL = "cardiffnlp/twitter-roberta-base-sentiment"

# save filepath
processed_indicator = 'raw'
if USE_CLEANED_TWEET:
    processed_indicator = 'cleaned'
PREDICTION_FILEPATH = f"../data/07_model_output/{BASE_MODEL}/{processed_indicator}_tweet_predicitons.csv"

if APPLY_SAMPLE:
    PREDICTION_FILEPATH = PREDICTION_FILEPATH.replace(".csv", f"_sample_{SAMPLE_SIZE}.csv")


### Load tweet data

In [3]:
df = pd.read_csv(TWEET_DATA)
df.rename(columns={'is_there_an_emotion_directed_at_a_brand_or_product': 'label'}, inplace=True)

### Preprocess data
1. Remove "I can't tell" labels - 156 rows (1.7% of data) and any missing tweets (1 data point)
2. Clean tweet text - remove links and `@` / `#` prefixes
3. Rename the labels - see `LABEL_MAPPING`

In [4]:
# helper functions
def _strip_links(text):
    link_regex = re.compile("((https?):((//)|(\\\\))+([\w\d:#@%/;$()~_?\+-=\\\.&](#!)?)*)", re.DOTALL)
    links = re.findall(link_regex, text)
    for link in links:
        text = text.replace(link[0], ", ")
    return text


def _strip_all_entities(text):
    entity_prefixes = ["@", ".@", "#", ".#"]
    # replace all other punctuation with a space
    # for separator in string.punctuation:
    #     if separator not in entity_prefixes:
    #         text = text.replace(separator, " ")
    words = []
    for word in text.split():
        word = word.strip()
        if word:
            if word[0] not in entity_prefixes:
                words.append(word)
    return " ".join(words)

def preprocess_tweet_text(text):
    return _strip_all_entities(_strip_links(text))

def rename_labels(row):
    row['alt_label'] = LABEL_MAPPING[row['label']]['alt_label']
    row['class'] = LABEL_MAPPING[row['label']]['class']
    return row

In [5]:
# remove unknown labels and drop na
df_cln = df.copy()[df['label'] != "I can't tell"].dropna(subset=['tweet_text'])
len_before, len_after = len(df), len(df_cln)
print(f'len before cln: {len_before}, len after cln: {len_after}')

# clean text
df_cln['tweet_text_cln'] = df_cln['tweet_text'].apply(preprocess_tweet_text)

# rename labels
df_cln = df_cln.apply(rename_labels, axis=1)
df_cln.head(3)

len before cln: 9093, len after cln: 8936


Unnamed: 0,index,tweet_text,emotion_in_tweet_is_directed_at,label,tweet_text_cln,alt_label,class
0,0,.@wesley83 I have a 3G iPhone. After 3 hrs twe...,iPhone,Negative emotion,.@wesley83 I have a 3G iPhone. After 3 hrs twe...,NEG,0
1,1,@jessedee Know about @fludapp ? Awesome iPad/i...,iPad or iPhone App,Positive emotion,Know about ? Awesome iPad/iPhone app that you'...,POS,2
2,2,@swonderlin Can not wait for #iPad 2 also. The...,iPad,Positive emotion,Can not wait for 2 also. They should sale them...,POS,2


### Load model

Using [pre-trained sentiment model](https://huggingface.co/finiteautomata/bertweet-base-sentiment-analysis?text=net-purpose+is+cool), fine-tuned on ~40k tweets

In [6]:
# use sentiment model fine-tuned on tweet data
sentiment_pipeline = pipeline(model=BASE_MODEL)

# test
data = ["lots of love for this", "windows os is bad", "london has a high population", "london is cool"]
predictions = sentiment_pipeline(data)
for i, pred in enumerate(predictions):
    print(f"test sentence: \"{data[i]}\", predicted sentiment: {pred['label']}({pred['score']:.2f})")


Downloading:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

test sentence: "lots of love for this", predicted sentiment: POS(0.99)
test sentence: "windows os is bad", predicted sentiment: NEG(0.98)
test sentence: "london has a high population", predicted sentiment: NEU(0.84)
test sentence: "london is cool", predicted sentiment: POS(0.99)


### Make predictions


In [7]:
text_column = 'tweet_text'
if USE_CLEANED_TWEET:
    text_column = 'tweet_text_cln'

# sample data if required
df_pred = df_cln.copy()
if APPLY_SAMPLE:
    df_pred = df_pred.sample(SAMPLE_SIZE)

# tweet list
tweets = df_pred[text_column].tolist()

In [8]:
# full data takes ~35mins on cpu
# for loop is just as fast as `sentiment_pipeline(tweets)` on cpu

# TODO detect device (cuda) and change prediction function accordingly - for loop will likely be slower (?)
pred = []
for tweet in tqdm(tweets):
    pred.extend(sentiment_pipeline(tweet))

  1%|          | 71/8936 [00:20<43:05,  3.43it/s] 


KeyboardInterrupt: 

In [9]:
# add to df
df_pred['pred_label'] = [prediction['label'] for prediction in pred]
df_pred['pred_scores'] = [prediction['score'] for prediction in pred]

def rename_pred_label(x):
    labels = {'LABEL_0':"NEG", 'LABEL_1': "NEU", 'LABEL_2': "POS"}
    if x in labels.keys():
        x = labels[x]
    return x

df_pred['pred_label'] = df_pred['pred_label'].apply(rename_pred_label)

ValueError: Length of values (71) does not match length of index (8936)

In [None]:
# save
df_pred.to_csv(PREDICTION_FILEPATH)

### Results analysis

In [10]:
PREDICTION_FILEPATH

'../data/07_model_output/finiteautomata/bertweet-base-sentiment-analysis/cleaned_tweet_predicitons.csv'

In [11]:

df_pred = pd.read_csv(PREDICTION_FILEPATH)

accuracy_score(df_pred['alt_label'], df_pred['pred_label'])

0.6324977618621307

In [12]:
from sklearn.metrics import classification_report

report = classification_report(df_pred['alt_label'], df_pred['pred_label'], output_dict=True)


In [13]:
report

{'NEG': {'precision': 0.33712512926577043,
  'recall': 0.5719298245614035,
  'f1-score': 0.42420299284320107,
  'support': 570},
 'NEU': {'precision': 0.7623318385650224,
  'recall': 0.6310319227913883,
  'f1-score': 0.6904955320877335,
  'support': 5388},
 'POS': {'precision': 0.548874323168994,
  'recall': 0.6467427803895232,
  'f1-score': 0.5938029905965777,
  'support': 2978},
 'accuracy': 0.6324977618621307,
 'macro avg': {'precision': 0.5494437636665955,
  'recall': 0.616568175914105,
  'f1-score': 0.5695005051758374,
  'support': 8936},
 'weighted avg': {'precision': 0.6640726280513758,
  'recall': 0.6324977618621307,
  'f1-score': 0.6412859152647652,
  'support': 8936}}