In [93]:
import string
import pandas as pd
import numpy as np
import torch
from torchtext.data.utils import get_tokenizer
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from sklearn.multioutput import MultiOutputRegressor
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

# nltk.download('wordnet')

In [85]:
data = 'data/WASSA23_conv_level_with_labels_train.tsv'
df = pd.read_table(data, header=0)
new_col = []
for names in df.columns:
    new_col.append(names.strip())
df.columns = new_col
df.drop(["conversation_id", "turn_id", "speaker_number", "article_id", "speaker_id", "essay_id"], axis=1, inplace=True)

X_data, y_data = df.loc[:, 'text'], df.drop('text', axis=1) #df.loc[:,'Emotion']
X_train, X_test, y_train , y_test = train_test_split(X_data, y_data, train_size=0.8)
#reset index of training examples
X_train, X_test = X_train.reset_index(drop=True), X_test.reset_index(drop=True)
y_train, y_test = y_train.reset_index(drop=True), y_test.reset_index(drop=True)

In [38]:
X_train, X_test, y_train, y_test

(0       those are some crazy statistics               ...
 1       Incredibly sad, and such a realization of how ...
 2       I think we could all benefit from self suffici...
 3       I truly hope something can be done to combat t...
 4       Yes, there needs to be a way to do something s...
                               ...                        
 7015    Well me too with all that came forward.  There...
 7016    Yeah, I know people are blaming the movie, but...
 7017    These past years there has been too many.     ...
 7018    I think it was terrible that there were so man...
 7019    I remember hearing he jumped into the water he...
 Name: text, Length: 7020, dtype: object,
 0       bye                                           ...
 1       What did you think of the article?            ...
 2       It is hard to believe he had all those injurie...
 3       Yeah these sites that clickbait these cures ne...
 4       Yeah, it said the other two that had injuries ...
              

- tokenization
- remove stop word and punctuatuons, numbers
- lematization
- vectorization

In [86]:
def word_preprocessor(sentence):
    tok = get_tokenizer("basic_english")
    stop_words = set(stopwords.words('english'))
    punctuations = set(string.punctuation)
    lem = WordNetLemmatizer().lemmatize

    sentence = tok(sentence)
    sentence = [word for word in sentence if word not in stop_words]
    sentence = [word for word in sentence if word not in punctuations]
    sentence_str = ' '.join(sentence)
    sentence = lem(sentence_str)
    return sentence #sentence

In [87]:
X_train = X_train.apply(word_preprocessor)
X_test = X_test.apply(word_preprocessor)

#convert labels to array
X_train, X_test = np.array(X_train), np.array(X_test)
y_train, y_test = np.array(y_train[['EmotionalPolarity', 'Emotion', 'Empathy']]), np.array(y_test[['EmotionalPolarity', 'Emotion', 'Empathy']])


In [88]:
X_train.shape, y_train.shape, X_test.shape, y_test.shape

((7020,), (7020, 3), (1756,), (1756, 3))

In [95]:
regressor = make_pipeline(  
                        TfidfVectorizer(max_features=2048),
                        MultiOutputRegressor(Ridge())
                    )
regressor.fit(X_train, y_train)
y_pred = regressor.predict(X_test)

mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test[0], y_pred[0])
print(f'MeanSquaredError: \t {mse} \nMeanAbsoluteError: \t {mae}')

MeanSquaredError: 	 0.38767747301557226 
MeanAbsoluteError: 	 0.23008384988087646


In [98]:
y_pred[0:8], y_test[0:8]

(array([[1.33457107, 2.72398481, 1.63169567],
        [1.25178642, 1.694748  , 1.29955085],
        [1.95352253, 3.07710957, 3.06809042],
        [0.94203286, 1.77813621, 1.91456287],
        [1.46044139, 2.29102417, 1.914624  ],
        [0.30728377, 1.54681841, 0.97736431],
        [1.56066379, 1.97697262, 2.26354768],
        [1.19975846, 2.18944368, 2.88791831]]),
 array([[1.    , 2.6667, 1.3333],
        [1.    , 1.3333, 2.    ],
        [2.    , 3.3333, 3.    ],
        [1.    , 1.3333, 1.3333],
        [2.    , 3.3333, 2.3333],
        [0.3333, 1.6667, 1.3333],
        [1.6667, 2.6667, 2.6667],
        [1.6667, 2.3333, 3.6667]]))