In [59]:
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.feature_extraction import DictVectorizer
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd

# Remove NumPy's print line-width limit so wide array rows are not wrapped
# when displayed.
np.set_printoptions(linewidth=np.inf)

In [21]:
def parse_data(tweets: list[str]) -> list[dict[str, int]]:
    """Convert tweets into per-tweet bag-of-words count dictionaries.

    Each tweet is split on whitespace (``str.split()`` with no arguments) and
    every resulting token is counted exactly as it appears: no lowercasing,
    punctuation stripping, or other normalisation is applied.

    Args:
        tweets: The tweet strings to tokenise. Although annotated as a list,
            the argument is only iterated over, so any iterable of ``str``
            works.

    Returns:
        A list with one ``{token: count}`` dict per tweet, in input order.
        A tweet with no tokens (empty or whitespace-only) yields ``{}``.

    Raises:
        AttributeError: If an item has no ``split`` method, i.e. it is not a
            string.
    """
    corpus: list[dict[str, int]] = []
    for tweet in tweets:
        counts: dict[str, int] = {}
        for word in tweet.split():
            # get() covers both the first and repeated occurrences of a token.
            counts[word] = counts.get(word, 0) + 1
        corpus.append(counts)

    return corpus

In [27]:
# Read the tweet-emotion dataset, then show the dtype of each column.
df = pd.read_csv(filepath_or_buffer='../../datasets/tweet_emotions.csv')
df.dtypes

tweet_id      int64
sentiment    object
content      object
dtype: object

In [63]:
# Collect the distinct labels found in the 'sentiment' column and print them.
unique_emotions = pd.unique(df['sentiment'])
print(unique_emotions)

['empty'
 'sadness'
 'enthusiasm'
 'neutral'
 'worry'
 'surprise'
 'love'
 'fun'
 'hate'
 'happiness'
 'boredom'
 'relief'
 'anger']


In [55]:
# Inputs are the tweet text column; targets are the sentiment label column.
X, y = df['content'], df['sentiment']

In [35]:
# Vectorizer used to turn the per-tweet {word: count} dicts into a feature matrix.
dv = DictVectorizer()

In [56]:
# Vectorize the per-tweet word counts. The text is read from df['content']
# rather than from X so this cell stays re-runnable: X is overwritten here
# with the vectorized matrix, while parse_data needs strings it can .split().
X = dv.fit_transform(parse_data(df['content']))

In [58]:
# Display the vectorized feature matrix.
X

<40000x83297 sparse matrix of type '<class 'numpy.float64'>'
	with 505606 stored elements in Compressed Sparse Row format>

In [45]:
# Hold out 30% of the rows for evaluation. The fixed random_state makes the
# shuffle, and therefore every score computed from this split, reproducible
# across kernel restarts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)


AttributeError: split not found

In [38]:
# Multinomial Naive Bayes classifier, created with its default settings.
model = MultinomialNB()

In [46]:
# Fit the classifier on the training split only.
model.fit(X_train, y_train)

In [47]:
# Predict labels for the same rows the model was fitted on.
y_train_pred = model.predict(X_train)

In [48]:
# Accuracy on the training split.
accuracy_score(y_train, y_train_pred)

0.5792857142857143

In [49]:
# Predict on the held-out rows and report accuracy on that split.
y_test_pred = model.predict(X_test)
accuracy_score(y_true=y_test, y_pred=y_test_pred)

0.2901666666666667

In [64]:

# Confusion matrix for the held-out split, called as (true labels, predictions).
# NOTE(review): no `labels=` is passed, so row/column order is sklearn's
# default label ordering — confirm against the docs before reading the matrix.
confusion_matrix(y_test, y_test_pred)

array([[   0,    0,    0,    0,    0,    0,    0,    0,    7,    0,    2,    0,   21],
       [   0,    0,    0,    0,    0,    0,    0,    0,    6,    0,    2,    0,   40],
       [   0,    0,    0,    0,    1,   12,    1,    1,   74,    1,    8,    2,  156],
       [   0,    0,    0,    0,    0,   27,    0,    6,   60,    0,    5,    0,  128],
       [   0,    0,    0,    0,    0,   67,    0,   16,  127,    0,    7,    1,  326],
       [   0,    0,    0,    0,    0,  309,    1,   89,  353,    1,   32,    2,  766],
       [   0,    0,    0,    0,    0,    4,    1,    2,   31,    0,   27,    0,  298],
       [   0,    0,    0,    0,    1,  147,    0,  230,  210,    0,   21,    1,  515],
       [   0,    0,    4,    0,    2,  173,    0,   62,  813,    0,   92,    7, 1483],
       [   0,    0,    0,    0,    0,   37,    0,   20,   93,    0,   10,    0,  283],
       [   0,    0,    1,    0,    0,   25,    1,   13,  168,    0,  105,    0, 1243],
       [   0,    0,    1,    0,    0,   52,