Emotion Classifier

# Setup

## Imports 

In [25]:
import pandas as pd
pd.set_option('display.max_colwidth', 200)

import json
from typing import List, Dict, Union


## Reading in the Data

In [26]:
# Emotion names, one per line of the file; list position is the label index
with open('../datasets/GoEmotions/emotions.txt', 'r') as f:
    emotion_names = list(map(str.strip, f))

# Function to read and expand one file
def load_go_emotions_split(path, emotions=None):
    """
    Read one GoEmotions TSV split and append one 0/1 column per emotion.

    The file is read without a header row as three tab-separated fields:
    the text, a comma-separated list of integer label indices, and an id.

    Args:
        path: File path or file-like object accepted by ``pd.read_csv``.
        emotions (list, optional): Emotion names; position ``i`` names label
            index ``i``. Defaults to the module-level ``emotion_names``.

    Returns:
        pd.DataFrame: Columns 'text', 'emotion_labels', 'id' followed by one
        int64 indicator column per emotion name, in list order.

    Raises:
        ValueError: If a label cannot be parsed as an integer.
        IndexError: If a label index is negative or not smaller than the
            number of emotion names.
    """
    names = emotion_names if emotions is None else list(emotions)
    df = pd.read_csv(
        path,
        sep='\t',
        header=None,
        names=['text', 'emotion_labels', 'id'],
        # Force strings: a file whose rows all carry a single label would
        # otherwise be parsed as integers and break the comma split below.
        dtype={'emotion_labels': str},
    )

    # One entry per (row, label) pair; the index still identifies the source row.
    label_ids = df['emotion_labels'].str.split(',').explode().astype('int64')

    # Reject indices that do not name an emotion (negative ones used to wrap
    # around to the end of the list silently).
    invalid = label_ids[(label_ids < 0) | (label_ids >= len(names))]
    if not invalid.empty:
        raise IndexError(
            f"label index {int(invalid.iloc[0])} is outside 0..{len(names) - 1}"
        )

    # Indicator per pair, collapsed back to one row per example, then padded
    # with all-zero columns for emotions that never occur in this split.
    one_hot = (
        pd.get_dummies(label_ids, dtype='int64')
        .groupby(level=0)
        .max()
        .reindex(index=df.index, columns=range(len(names)), fill_value=0)
    )
    one_hot.columns = names
    return pd.concat([df, one_hot], axis=1)

# Load the three dataset splits in train / dev / test order
split_paths = (
    '../datasets/GoEmotions/train.tsv',
    '../datasets/GoEmotions/dev.tsv',
    '../datasets/GoEmotions/test.tsv',
)
go_emotions_train, go_emotions_val, go_emotions_test = map(load_go_emotions_split, split_paths)

# Parse the Ekman mapping from its JSON file
with open('../datasets/GoEmotions/ekman_mapping.json', 'r') as f:
    ekman_map = json.loads(f.read())

## Helper Functions

### Ekman Mapping

In [27]:
def ekman_category_breakdown(
    df: pd.DataFrame,
    emotion_columns: list,
    ekman_mapping: dict
) -> None:
    """
    Print the percentage of rows that fall under each Ekman umbrella category.

    A row counts towards a category when at least one of the category's
    fine-grained emotion columns is truthy for that row, so a row can count
    towards several categories and the percentages need not sum to 100.

    Args:
        df (pd.DataFrame): DataFrame with one indicator column per emotion.
        emotion_columns (list): List of the emotion column names. Accepted for
            interface compatibility; the function does not read it.
        ekman_mapping (dict): Dict mapping each Ekman category to the list of
            emotion column names it covers.

    Returns:
        None. The breakdown is written to stdout.
    """
    total = len(df)
    # Keep the 9-character label column, widening it only for longer names
    width = max([9, *(len(str(cat)) for cat in ekman_mapping)])

    print("Ekman category percentage breakdown:")
    for ekman_cat, fine_emotions in ekman_mapping.items():
        hits = int(df[fine_emotions].any(axis=1).sum())
        # An empty frame would otherwise divide 0 by 0 and print "nan%"
        pct = hits / total * 100 if total else 0.0
        print(f"{ekman_cat:<{width}} : {pct:.2f}%")


### Text Stats

In [28]:
def text_length_stats(
    df: pd.DataFrame,
    text_col: str = "text",
    by: str = "char",
    return_df: bool = False
) -> Union[pd.DataFrame, None]:
    """
    Print an aligned table of length statistics for a text column.

    Every value is converted with ``astype(str)`` before it is measured, so
    non-string entries are measured by their string form.

    Args:
        df (pd.DataFrame): DataFrame with the text data.
        text_col (str): Column name containing text. Default 'text'.
        by (str): 'char' for character count, 'word' for the number of
            whitespace-separated tokens.
        return_df (bool): If True, also return the stats as a DataFrame.

    Returns:
        Optional[pd.DataFrame]: DataFrame indexed by 'statistic' with a single
        'value' column when ``return_df`` is True, otherwise None.

    Raises:
        ValueError: If ``by`` is neither 'char' nor 'word'.
    """
    if by == "char":
        measure = len
    elif by == "word":
        measure = lambda text: len(text.split())
    else:
        raise ValueError("`by` must be 'char' or 'word'")

    # apply() on an empty column yields object dtype, which the numeric
    # reductions below do not handle reliably; pin the dtype explicitly.
    lengths = df[text_col].astype(str).apply(measure).astype("int64")

    # Computed once and reused (the quartiles feed both their own rows and the IQR)
    q1 = lengths.quantile(0.25)
    q3 = lengths.quantile(0.75)
    modes = lengths.mode()

    stats = [
        ("count",        lengths.count()),
        ("min",          lengths.min()),
        ("Q1",           q1),
        ("median",       lengths.median()),
        ("mean",         lengths.mean()),
        ("Q3",           q3),
        ("max",          lengths.max()),
        # No mode exists for an empty column
        ("mode",         modes.values[0] if not modes.empty else None),
        ("std dev",      lengths.std()),
        ("variance",     lengths.var()),
        ("IQR",          q3 - q1),
    ]

    unit = "characters" if by == "char" else "words"
    title = f"Text length statistics ({unit})"
    print(f"\n{title:^36}")
    print("=" * 36)
    for label, val in stats:
        if val is None:
            # None does not accept a width format spec, so show a placeholder
            shown = f"{'n/a':>10}"
        elif isinstance(val, float):
            shown = f"{val:>10.2f}"
        else:
            shown = f"{val:>10}"
        print(f"{label:<10}: {shown}")

    if return_df:
        return pd.DataFrame(stats, columns=["statistic", "value"]).set_index("statistic")
    return None


# Exploring Training Set

In [29]:
print("5 random rows of the training set:\n")
# Fixed seed so the same rows are shown on every Restart & Run All
display(
    go_emotions_train
    .sample(5, random_state=42)
    .style.set_properties(**{'white-space': 'pre-wrap'})
)


5 random rows of the training set:



Unnamed: 0,text,emotion_labels,id,admiration,amusement,anger,annoyance,approval,caring,confusion,curiosity,desire,disappointment,disapproval,disgust,embarrassment,excitement,fear,gratitude,grief,joy,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
9819,Taking a float plane is such a amazing experience - highly recommend!,13,edzypwe,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0
40075,Sounds hella aggressive. You can say the same thing while being nicer about it.,2,efayfae,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
18239,"I have a feeling this is gonna get really crowded, and end up being a shit show with such short notice for everything.",23,ef78ks9,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
34945,Thank you. If im not authorized to pickup from daycare tomorrow ill get my daughter friday with this court order and the police if necessary.,15,eetarpi,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0
36145,Of course the one Ospreys player was my boi [NAME] <3,18,edmlyqy,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0


In [30]:
print("go_emotions_train info:\n")
# DataFrame.info() writes its report itself and returns None, so wrapping it
# in print() would only append a stray "None" line.
go_emotions_train.info()

go_emotion_train info:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 43410 entries, 0 to 43409
Data columns (total 31 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   text            43410 non-null  object
 1   emotion_labels  43410 non-null  object
 2   id              43410 non-null  object
 3   admiration      43410 non-null  int64 
 4   amusement       43410 non-null  int64 
 5   anger           43410 non-null  int64 
 6   annoyance       43410 non-null  int64 
 7   approval        43410 non-null  int64 
 8   caring          43410 non-null  int64 
 9   confusion       43410 non-null  int64 
 10  curiosity       43410 non-null  int64 
 11  desire          43410 non-null  int64 
 12  disappointment  43410 non-null  int64 
 13  disapproval     43410 non-null  int64 
 14  disgust         43410 non-null  int64 
 15  embarrassment   43410 non-null  int64 
 16  excitement      43410 non-null  int64 
 17  fear            43410 non-

In [31]:
# Share of training rows covered by each Ekman category
ekman_category_breakdown(
    df=go_emotions_train,
    emotion_columns=emotion_names,
    ekman_mapping=ekman_map,
)

Ekman category percentage breakdown:
anger     : 12.85%
disgust   : 1.83%
fear      : 1.67%
joy       : 40.11%
sadness   : 7.52%
surprise  : 12.36%


In [32]:
# Character-length statistics of the training texts
text_length_stats(go_emotions_train, "text", "char")


Text length statistics (characters) 
count     :      43410
min       :          2
Q1        :      38.00
median    :      65.00
mean      :      68.40
Q3        :      96.00
max       :        703
mode      :         56
std dev   :      36.72
variance  :    1348.50
IQR       :      58.00


# Exploring Validation Set

In [33]:
print("5 random rows of the validation set:\n")
# Fixed seed so the same rows are shown on every Restart & Run All
display(
    go_emotions_val
    .sample(5, random_state=42)
    .style.set_properties(**{'white-space': 'pre-wrap'})
)


5 random rows of the validation set:



Unnamed: 0,text,emotion_labels,id,admiration,amusement,anger,annoyance,approval,caring,confusion,curiosity,desire,disappointment,disapproval,disgust,embarrassment,excitement,fear,gratitude,grief,joy,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
289,I miss those days.,25,ed4wj8r,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0
7,There it is!,27,ede4v0m,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
1887,Worst advice ever. This would piss me off.,211,eemwgao,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
295,"It appeared that she smacked her grape pretty damn good. Oh, and I chuckled.",117,effahnx,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
4859,Walmart is also a good place to look at the local fauna.,0,eeslhys,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [34]:
print("go_emotions_val info:\n")
# DataFrame.info() writes its report itself and returns None, so wrapping it
# in print() would only append a stray "None" line.
go_emotions_val.info()

go_emotions_val info:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5426 entries, 0 to 5425
Data columns (total 31 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   text            5426 non-null   object
 1   emotion_labels  5426 non-null   object
 2   id              5426 non-null   object
 3   admiration      5426 non-null   int64 
 4   amusement       5426 non-null   int64 
 5   anger           5426 non-null   int64 
 6   annoyance       5426 non-null   int64 
 7   approval        5426 non-null   int64 
 8   caring          5426 non-null   int64 
 9   confusion       5426 non-null   int64 
 10  curiosity       5426 non-null   int64 
 11  desire          5426 non-null   int64 
 12  disappointment  5426 non-null   int64 
 13  disapproval     5426 non-null   int64 
 14  disgust         5426 non-null   int64 
 15  embarrassment   5426 non-null   int64 
 16  excitement      5426 non-null   int64 
 17  fear            5426 non-null

In [35]:
# Share of validation rows covered by each Ekman category
ekman_category_breakdown(
    df=go_emotions_val,
    emotion_columns=emotion_names,
    ekman_mapping=ekman_map,
)

Ekman category percentage breakdown:
anger     : 13.21%
disgust   : 1.79%
fear      : 1.94%
joy       : 40.90%
sadness   : 7.19%
surprise  : 11.50%


In [36]:
# Character-length statistics of the validation texts
text_length_stats(go_emotions_val, "text", "char")


Text length statistics (characters) 
count     :       5426
min       :          5
Q1        :      37.00
median    :      64.00
mean      :      68.24
Q3        :      96.00
max       :        187
mode      :         37
std dev   :      36.91
variance  :    1362.24
IQR       :      59.00


# Exploring Test Set

In [37]:
print("5 random rows of the test set:\n")
# Fixed seed so the same rows are shown on every Restart & Run All
display(
    go_emotions_test
    .sample(5, random_state=42)
    .style.set_properties(**{'white-space': 'pre-wrap'})
)


5 random rows of the test set:



Unnamed: 0,text,emotion_labels,id,admiration,amusement,anger,annoyance,approval,caring,confusion,curiosity,desire,disappointment,disapproval,disgust,embarrassment,excitement,fear,gratitude,grief,joy,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
321,I think teams will still cut the cross-seam that made [NAME] and [NAME] so deadly last year. This is awesome nonetheless.,20,ednc0w3,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0
3349,"The lead designer, not the producer :)",27,efdfvft,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
1303,"No problem at all! If you ever want to talk to some internet random again, feel free to talk to me.",5,ed50vzo,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1606,"“I couldn’t, I wouldn’t” you said it ;)",27,eeu00m5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
3105,"Play hard, but we don't want to reveal our secret raptor killing strategy until the playoffs...",20,efgm57i,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0


In [38]:
print("go_emotions_test info:\n")
# DataFrame.info() writes its report itself and returns None, so wrapping it
# in print() would only append a stray "None" line.
go_emotions_test.info()

go_emotions_test info:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5427 entries, 0 to 5426
Data columns (total 31 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   text            5427 non-null   object
 1   emotion_labels  5427 non-null   object
 2   id              5427 non-null   object
 3   admiration      5427 non-null   int64 
 4   amusement       5427 non-null   int64 
 5   anger           5427 non-null   int64 
 6   annoyance       5427 non-null   int64 
 7   approval        5427 non-null   int64 
 8   caring          5427 non-null   int64 
 9   confusion       5427 non-null   int64 
 10  curiosity       5427 non-null   int64 
 11  desire          5427 non-null   int64 
 12  disappointment  5427 non-null   int64 
 13  disapproval     5427 non-null   int64 
 14  disgust         5427 non-null   int64 
 15  embarrassment   5427 non-null   int64 
 16  excitement      5427 non-null   int64 
 17  fear            5427 non-nul

In [39]:
# Share of test rows covered by each Ekman category
ekman_category_breakdown(
    df=go_emotions_test,
    emotion_columns=emotion_names,
    ekman_mapping=ekman_map,
)

Ekman category percentage breakdown:
anger     : 13.38%
disgust   : 2.27%
fear      : 1.81%
joy       : 38.77%
sadness   : 6.98%
surprise  : 12.47%


In [40]:
# Character-length statistics of the test texts
text_length_stats(go_emotions_test, "text", "char")


Text length statistics (characters) 
count     :       5427
min       :          5
Q1        :      37.00
median    :      65.00
mean      :      67.82
Q3        :      95.00
max       :        184
mode      :         24
std dev   :      36.32
variance  :    1319.03
IQR       :      58.00


# Prototype CNN

In [41]:
# Single hand-written example text (mixed case, punctuation, an emoticon and an
# emoji) that the cells below tokenize and push through the prototype layers.
sample = "UGH srsly my head is POUNDING! got no sleep AGAIN? fUCK this BS :( cant even think. just wanna cry rn 😭"

## Tokenization

In [42]:
import string

# Create a simple character vocabulary (expand as needed)
all_chars = list(string.ascii_lowercase + string.ascii_uppercase + string.digits +
                 string.punctuation + string.whitespace + "😴😭")
char2idx = {c: i for i, c in enumerate(all_chars)}

# Reserve a dedicated index for characters outside the vocabulary. It is placed
# after the known characters so their indices stay unchanged; falling back to 0
# would make every unknown character indistinguishable from 'a'.
unk_idx = len(char2idx)
vocab_size = len(char2idx) + 1  # known characters + the unknown slot

# Tokenize the sample
tokens = [char2idx.get(c, unk_idx) for c in sample]
print(tokens)


[46, 32, 33, 94, 18, 17, 18, 11, 24, 94, 12, 24, 94, 7, 4, 0, 3, 94, 8, 18, 94, 41, 40, 46, 39, 29, 34, 39, 32, 62, 94, 6, 14, 19, 94, 13, 14, 94, 18, 11, 4, 4, 15, 94, 26, 32, 26, 34, 39, 82, 94, 5, 46, 28, 36, 94, 19, 7, 8, 18, 94, 27, 44, 94, 77, 69, 94, 2, 0, 13, 19, 94, 4, 21, 4, 13, 94, 19, 7, 8, 13, 10, 75, 94, 9, 20, 18, 19, 94, 22, 0, 13, 13, 0, 94, 2, 17, 24, 94, 17, 13, 94, 101]


## 16d character vectors

In [43]:
import torch
import torch.nn as nn

# Seed the RNG so the randomly initialised embedding weights are the same on
# every fresh run of the notebook.
torch.manual_seed(0)

# 16d character embeddings
embedding_dim = 16
embeddings = nn.Embedding(vocab_size, embedding_dim)

char_tensor = torch.tensor(tokens).unsqueeze(0)  # shape: (1, seq_len)
char_embedded = embeddings(char_tensor)          # shape: (1, seq_len, 16)
print(char_embedded.shape)
print(char_embedded[0])  # Print the embeddings for each character


torch.Size([1, 103, 16])
tensor([[ 0.4292, -1.3076,  0.6829,  ...,  0.4087,  0.7021, -1.3959],
        [ 1.8084,  0.5969, -0.0820,  ...,  0.9963, -0.5001, -1.5666],
        [-1.4377, -2.5468,  0.0385,  ..., -1.3265, -0.4006, -0.3865],
        ...,
        [-0.4096,  0.4846,  0.6227,  ..., -0.1391, -2.4851, -0.6276],
        [-0.4211, -0.7751,  0.1817,  ..., -0.6272, -0.7914, -0.5547],
        [ 0.8762,  0.1688, -0.2232,  ...,  0.3199, -0.4995, -1.0901]],
       grad_fn=<SelectBackward0>)


## Trigrams

In [44]:
# Put the embedding axis before the sequence axis: (batch, channels, seq_len)
x3 = char_embedded.permute(0, 2, 1)

# 21 filters, each spanning 3 characters; padding=1 keeps seq_len unchanged
conv3 = nn.Conv1d(embedding_dim, 21, 3, padding=1)
trigrams = conv3(x3)  # (1, 21, seq_len)

print(trigrams.shape)
trigram_features = trigrams[0]
print(trigram_features.permute(1, 0))  # one row per position: (seq_len, 21)


torch.Size([1, 21, 103])
tensor([[-0.8361, -1.0257, -0.1592,  ..., -0.5087, -0.9399, -0.4476],
        [-0.0210, -0.0065,  0.1501,  ...,  0.2900,  0.4302,  0.4450],
        [ 0.4219,  0.6931, -0.0842,  ..., -0.3045,  0.2046, -0.4646],
        ...,
        [ 0.1030,  0.1343,  0.2223,  ...,  0.8430,  0.1525, -0.2251],
        [-1.0321,  0.0077, -0.0296,  ..., -0.3194, -0.5092, -0.5664],
        [-0.7594, -0.9428,  0.4877,  ...,  0.3981,  0.4297,  0.0063]],
       grad_fn=<PermuteBackward0>)


## Pentagrams

In [45]:
# 21 filters, each spanning 5 characters; padding=2 keeps seq_len unchanged
conv5 = nn.Conv1d(embedding_dim, 21, 5, padding=2)
pentagrams = conv5(x3)  # (1, 21, seq_len)

print(pentagrams.shape)
pentagram_features = pentagrams[0]
print(pentagram_features.permute(1, 0))  # one row per position: (seq_len, 21)


torch.Size([1, 21, 103])
tensor([[-1.1919, -0.9274, -0.0529,  ...,  0.6986, -0.3840, -0.1152],
        [-0.0344, -0.0528,  0.4243,  ..., -0.9816, -0.0536,  1.2123],
        [ 0.8655, -0.3201,  0.0120,  ..., -0.5864,  0.3383,  0.2093],
        ...,
        [-0.1185,  0.6189,  0.6760,  ...,  0.1322,  1.2785, -0.7947],
        [ 0.1628,  0.2216, -0.7359,  ...,  0.0291,  0.1646, -0.8811],
        [-0.4001, -0.4673, -0.0919,  ..., -0.1808, -0.2585, -0.0324]],
       grad_fn=<PermuteBackward0>)


## Heptagrams

In [46]:
# 22 filters, each spanning 7 characters; padding=3 keeps seq_len unchanged
conv7 = nn.Conv1d(embedding_dim, 22, 7, padding=3)
heptagrams = conv7(x3)  # (1, 22, seq_len)

print(heptagrams.shape)
heptagram_features = heptagrams[0]
print(heptagram_features.permute(1, 0))  # one row per position: (seq_len, 22)


torch.Size([1, 22, 103])
tensor([[-1.0663,  0.1089, -0.2064,  ..., -0.6108,  0.7009,  0.8391],
        [-0.5627,  0.1133,  0.5301,  ..., -0.4353, -0.5453,  0.2704],
        [ 0.2386,  0.6018, -1.0504,  ...,  1.0313, -0.2714,  0.4327],
        ...,
        [ 0.6374,  0.5192, -0.3694,  ...,  0.0784,  0.0736,  0.4718],
        [-0.1441, -0.5059, -0.5742,  ..., -0.0343,  0.6339,  0.6863],
        [-0.3361, -0.2474,  0.2531,  ..., -0.4836,  0.1478, -0.3208]],
       grad_fn=<PermuteBackward0>)


## Max pooling

In [47]:
# For each set of filters, take the max across the time dimension (seq_len)
# Keep, per filter, only its largest response over the sequence axis (dim 2)
trigram_pooled = trigrams.max(dim=2).values      # (1, 21)
pentagram_pooled = pentagrams.max(dim=2).values  # (1, 21)
heptagram_pooled = heptagrams.max(dim=2).values  # (1, 22)

for pooled in (trigram_pooled, pentagram_pooled, heptagram_pooled):
    print(pooled)


tensor([[0.9506, 1.2908, 1.3667, 1.4502, 1.2214, 1.0101, 0.9618, 1.5814, 1.6560,
         1.4454, 1.8745, 1.1384, 1.2658, 1.7635, 1.0730, 0.9360, 1.2501, 1.2008,
         1.0500, 1.2579, 1.3560]], grad_fn=<MaxBackward0>)
tensor([[0.9899, 1.4206, 1.1570, 1.3253, 1.0319, 1.3364, 1.1160, 1.3235, 1.5939,
         1.5837, 1.3754, 0.9213, 0.8833, 1.2693, 1.5981, 0.7960, 1.1314, 0.9276,
         1.6547, 1.7418, 1.2123]], grad_fn=<MaxBackward0>)
tensor([[1.1296, 1.2718, 1.2047, 1.2286, 1.5506, 1.7004, 1.7020, 1.9522, 1.0898,
         0.9598, 1.3179, 1.3270, 1.0898, 1.6458, 1.4549, 1.2193, 1.1689, 0.9310,
         1.6835, 1.6070, 1.5305, 1.3931]], grad_fn=<MaxBackward0>)


## 64d output

In [48]:
# Join the three pooled vectors along the feature axis: 21 + 21 + 22 = 64
pooled_parts = (trigram_pooled, pentagram_pooled, heptagram_pooled)
final_cnn_output = torch.cat(pooled_parts, dim=1)  # (1, 64)
print(final_cnn_output.shape, final_cnn_output, sep="\n")


torch.Size([1, 64])
tensor([[0.9506, 1.2908, 1.3667, 1.4502, 1.2214, 1.0101, 0.9618, 1.5814, 1.6560,
         1.4454, 1.8745, 1.1384, 1.2658, 1.7635, 1.0730, 0.9360, 1.2501, 1.2008,
         1.0500, 1.2579, 1.3560, 0.9899, 1.4206, 1.1570, 1.3253, 1.0319, 1.3364,
         1.1160, 1.3235, 1.5939, 1.5837, 1.3754, 0.9213, 0.8833, 1.2693, 1.5981,
         0.7960, 1.1314, 0.9276, 1.6547, 1.7418, 1.2123, 1.1296, 1.2718, 1.2047,
         1.2286, 1.5506, 1.7004, 1.7020, 1.9522, 1.0898, 0.9598, 1.3179, 1.3270,
         1.0898, 1.6458, 1.4549, 1.2193, 1.1689, 0.9310, 1.6835, 1.6070, 1.5305,
         1.3931]], grad_fn=<CatBackward0>)
