In [1]:
import nltk
from nltk.stem.porter import *
from torch.nn import *
from torch.optim import *
import numpy as np
import pandas as pd
import torch,torchvision
import random
from tqdm import *
from torch.utils.data import Dataset,DataLoader
stemmer = PorterStemmer()

In [2]:
def tokenize(sentence):
    return nltk.word_tokenize(sentence)

In [3]:
tokenize('$1000')

['$', '1000']

In [4]:
def stem(word):
    return stemmer.stem(word.lower())

In [5]:
stem('organic')

'organ'

In [6]:
def bag_of_words(tokenized_words,all_words):
    tokenized_words = [stem(w) for w in tokenized_words]
    bag = np.zeros(len(all_words),dtype=np.float32)
    for idx,w in enumerate(all_words):
        if w in tokenized_words:
            bag[idx] = 1.0
    return bag

In [7]:
bag_of_words(['hi'],['oloegsegse','hi'])

array([0., 1.], dtype=float32)

In [8]:
data = pd.read_csv('./data.csv')

In [9]:
X = data['text']
y = data['label']

In [10]:
all_words = []
all_data = []
tags = []

In [11]:
for X_batch,y_batch in tqdm(zip(X,y)):
    X_batch = tokenize(X_batch)
    new_X = []
    for X_batch_in_X_batch in X_batch:
        new_X.append(stem(X_batch_in_X_batch))
    all_words.extend(new_X)
    all_data.append((new_X,y_batch))
    tags.append(y_batch)

3613it [00:01, 2624.53it/s]


In [12]:
all_words = sorted(set(all_words))
tags = sorted(set(tags))

In [13]:
tags

['anger', 'fear', 'joy', 'sadness']

In [14]:
X = []
y = []

In [15]:
for sentence,tag in all_data:
    sentence = bag_of_words(sentence,all_words)
    X.append(sentence)
    y.append(tags.index(tag))

In [16]:
from sklearn.model_selection import *
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.125,shuffle=False)
device = 'cuda'

In [17]:
X_train = torch.from_numpy(np.array(X_train)).to(device).float()
y_train = torch.from_numpy(np.array(y_train)).to(device).to(device).long()
X_test = torch.from_numpy(np.array(X_test)).to(device).to(device).float()
y_test = torch.from_numpy(np.array(y_test)).to(device).to(device).long()

In [18]:
def get_loss(model,X,y,criterion):
    preds = model(X)
    loss = criterion(preds,y)
    return loss.item()

In [19]:
def get_accuracy(model,X,y):
    correct = 0
    total = 0
    preds = model(X)
    for pred,y_batch in zip(preds,y):
        pred = int(torch.argmax(pred))
        if pred == y_batch:
            correct += 1
        total += 1
    print(correct,total)
    acc = round(correct/total,3)*100
    return acc

In [20]:
len(tags)

4

In [21]:
class Model(Module):
    def __init__(self,hidden=2048,iters=12,activation=ReLU):
        super().__init__()
        self.iters = iters
        self.activation = activation()
        self.hidden = hidden
        self.linear1 = Linear(len(all_words),hidden)
        self.linear2 = Linear(hidden,hidden)
        self.output = Linear(hidden,len(tags))
    
    def forward(self,X):
        preds = self.linear1(X)
        for _ in range(self.iters):
            preds = self.activation(self.linear2(preds))
        preds = self.output(preds)
        return preds

In [22]:
model = Model().to(device)

In [23]:
criterion = CrossEntropyLoss()

In [24]:
optimizer = Adam(model.parameters(),lr=0.001)

In [25]:
epochs = 100

In [26]:
batch_size = 8

In [27]:
import wandb

In [None]:
wandb.init(project='Emotion-Classification-NLP',name='baseline')
wandb.watch(model)
for _ in tqdm(range(epochs)):
    for idx in range(0,len(X_train),batch_size):
        X_batch = X_train[idx:idx+batch_size].to(device)
        y_batch = y_train[idx:idx+batch_size].to(device)
        preds = model(X_batch)
        loss = criterion(preds,y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    wandb.log({'Loss':get_loss(model,X_train,y_train,criterion)})
    wandb.log({'Val Loss':get_loss(model,X_test,y_test,criterion)})
    wandb.log({'Acc':get_accuracy(model,X_train,y_train)})
    wandb.log({'Val Acc':get_accuracy(model,X_test,y_test)})
wandb.watch(model)
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mranuga-d[0m (use `wandb login --relogin` to force relogin)


  1%|▍                                          | 1/100 [00:01<02:56,  1.78s/it]

793 3161
0 452


  2%|▊                                          | 2/100 [00:03<02:58,  1.82s/it]

857 3161
0 452


  3%|█▎                                         | 3/100 [00:05<02:58,  1.84s/it]

857 3161
0 452


  4%|█▋                                         | 4/100 [00:07<02:59,  1.87s/it]

857 3161
0 452


  5%|██▏                                        | 5/100 [00:09<02:56,  1.86s/it]

894 3161
0 452


  6%|██▌                                        | 6/100 [00:10<02:50,  1.81s/it]

996 3161
0 452


  7%|███                                        | 7/100 [00:12<02:47,  1.80s/it]

2243 3161
0 452


  8%|███▍                                       | 8/100 [00:14<02:43,  1.78s/it]

2268 3161
0 452


  9%|███▊                                       | 9/100 [00:16<02:42,  1.79s/it]

2505 3161
0 452


 10%|████▏                                     | 10/100 [00:18<02:42,  1.81s/it]

2551 3161
0 452


 11%|████▌                                     | 11/100 [00:19<02:39,  1.80s/it]

2606 3161
0 452


 12%|█████                                     | 12/100 [00:21<02:38,  1.81s/it]

2606 3161
0 452


 13%|█████▍                                    | 13/100 [00:23<02:37,  1.81s/it]

2557 3161
0 452


 14%|█████▉                                    | 14/100 [00:25<02:34,  1.80s/it]

936 3161
0 452


 15%|██████▎                                   | 15/100 [00:27<02:30,  1.77s/it]

2351 3161
0 452


 16%|██████▋                                   | 16/100 [00:28<02:26,  1.74s/it]

2504 3161
3 452


 17%|███████▏                                  | 17/100 [00:30<02:25,  1.75s/it]

2537 3161
40 452


 18%|███████▌                                  | 18/100 [00:32<02:23,  1.75s/it]

1216 3161
9 452


 19%|███████▉                                  | 19/100 [00:33<02:21,  1.75s/it]

392 3161
11 452


 20%|████████▍                                 | 20/100 [00:35<02:20,  1.76s/it]

2661 3161
36 452


 21%|████████▊                                 | 21/100 [00:37<02:19,  1.77s/it]

2932 3161
31 452


 22%|█████████▏                                | 22/100 [00:39<02:17,  1.77s/it]

2990 3161
53 452


 23%|█████████▋                                | 23/100 [00:41<02:17,  1.78s/it]

2896 3161
91 452


 24%|██████████                                | 24/100 [00:42<02:17,  1.80s/it]

2905 3161
107 452


 25%|██████████▌                               | 25/100 [00:44<02:15,  1.81s/it]

3045 3161
133 452


 26%|██████████▉                               | 26/100 [00:46<02:12,  1.79s/it]

3078 3161
130 452


 27%|███████████▎                              | 27/100 [00:48<02:11,  1.80s/it]

3087 3161
113 452


 28%|███████████▊                              | 28/100 [00:50<02:09,  1.80s/it]

3097 3161
137 452


 29%|████████████▏                             | 29/100 [00:51<02:06,  1.78s/it]

3097 3161
163 452


 30%|████████████▌                             | 30/100 [00:53<02:04,  1.78s/it]

3100 3161
147 452


 31%|█████████████                             | 31/100 [00:55<02:04,  1.80s/it]

3100 3161
159 452


 32%|█████████████▍                            | 32/100 [00:57<02:02,  1.80s/it]

3103 3161
148 452


 33%|█████████████▊                            | 33/100 [00:59<02:01,  1.82s/it]

3099 3161
159 452


 34%|██████████████▎                           | 34/100 [01:01<02:01,  1.83s/it]

3075 3161
173 452


 35%|██████████████▋                           | 35/100 [01:02<02:00,  1.86s/it]

3105 3161
166 452


 36%|███████████████                           | 36/100 [01:04<01:57,  1.84s/it]

3106 3161
171 452


 37%|███████████████▌                          | 37/100 [01:06<01:53,  1.81s/it]

3111 3161
184 452


 38%|███████████████▉                          | 38/100 [01:08<01:49,  1.77s/it]

3111 3161
189 452


 39%|████████████████▍                         | 39/100 [01:09<01:46,  1.74s/it]

3110 3161
181 452


 40%|████████████████▊                         | 40/100 [01:11<01:43,  1.72s/it]

3108 3161
175 452


 41%|█████████████████▏                        | 41/100 [01:13<01:40,  1.71s/it]

3114 3161
188 452


 42%|█████████████████▋                        | 42/100 [01:14<01:38,  1.69s/it]

3112 3161
204 452


 43%|██████████████████                        | 43/100 [01:16<01:35,  1.68s/it]

3110 3161
202 452


 44%|██████████████████▍                       | 44/100 [01:18<01:35,  1.71s/it]

3114 3161
200 452


 45%|██████████████████▉                       | 45/100 [01:20<01:34,  1.71s/it]

3115 3161
210 452


 46%|███████████████████▎                      | 46/100 [01:21<01:32,  1.72s/it]

3114 3161
194 452


 47%|███████████████████▋                      | 47/100 [01:23<01:30,  1.70s/it]

3119 3161
199 452


 48%|████████████████████▏                     | 48/100 [01:25<01:28,  1.71s/it]

3121 3161
200 452


 49%|████████████████████▌                     | 49/100 [01:26<01:26,  1.70s/it]

3118 3161
205 452


 50%|█████████████████████                     | 50/100 [01:28<01:24,  1.69s/it]

3121 3161
206 452


 51%|█████████████████████▍                    | 51/100 [01:30<01:23,  1.70s/it]

3117 3161
216 452


 52%|█████████████████████▊                    | 52/100 [01:32<01:23,  1.74s/it]

3111 3161
218 452


 53%|██████████████████████▎                   | 53/100 [01:33<01:20,  1.72s/it]

3121 3161
204 452


 54%|██████████████████████▋                   | 54/100 [01:35<01:18,  1.70s/it]

3117 3161
218 452


 55%|███████████████████████                   | 55/100 [01:37<01:16,  1.69s/it]

3120 3161
202 452


 56%|███████████████████████▌                  | 56/100 [01:38<01:14,  1.70s/it]

2724 3161
209 452


 57%|███████████████████████▉                  | 57/100 [01:40<01:12,  1.69s/it]

1724 3161
273 452


 58%|████████████████████████▎                 | 58/100 [01:42<01:11,  1.69s/it]

3014 3161
184 452


 59%|████████████████████████▊                 | 59/100 [01:43<01:09,  1.70s/it]

3105 3161
186 452


 60%|█████████████████████████▏                | 60/100 [01:45<01:08,  1.72s/it]

2295 3161
218 452


 61%|█████████████████████████▌                | 61/100 [01:47<01:06,  1.71s/it]

3102 3161
203 452


 62%|██████████████████████████                | 62/100 [01:49<01:06,  1.75s/it]

3050 3161
235 452


 63%|██████████████████████████▍               | 63/100 [01:50<01:05,  1.77s/it]

3111 3161
205 452


 64%|██████████████████████████▉               | 64/100 [01:52<01:02,  1.74s/it]

3117 3161
198 452


 65%|███████████████████████████▎              | 65/100 [01:54<01:00,  1.73s/it]

3120 3161
199 452


 66%|███████████████████████████▋              | 66/100 [01:56<00:58,  1.73s/it]

3122 3161
197 452


 67%|████████████████████████████▏             | 67/100 [01:57<00:56,  1.72s/it]

3122 3161
195 452


 68%|████████████████████████████▌             | 68/100 [01:59<00:54,  1.71s/it]

3123 3161
189 452


 69%|████████████████████████████▉             | 69/100 [02:01<00:52,  1.71s/it]

3125 3161
192 452


 70%|█████████████████████████████▍            | 70/100 [02:02<00:51,  1.71s/it]

3127 3161
194 452


 71%|█████████████████████████████▊            | 71/100 [02:04<00:49,  1.71s/it]

3127 3161
186 452


 72%|██████████████████████████████▏           | 72/100 [02:06<00:47,  1.70s/it]

3127 3161
190 452


 73%|██████████████████████████████▋           | 73/100 [02:07<00:46,  1.71s/it]

3128 3161
187 452


 74%|███████████████████████████████           | 74/100 [02:09<00:44,  1.70s/it]

3131 3161
187 452


 75%|███████████████████████████████▌          | 75/100 [02:11<00:42,  1.71s/it]

3129 3161
186 452


 76%|███████████████████████████████▉          | 76/100 [02:13<00:40,  1.70s/it]

3130 3161
187 452


 77%|████████████████████████████████▎         | 77/100 [02:14<00:39,  1.70s/it]

3129 3161
185 452


 78%|████████████████████████████████▊         | 78/100 [02:16<00:37,  1.69s/it]

3129 3161
185 452


 79%|█████████████████████████████████▏        | 79/100 [02:18<00:35,  1.69s/it]

3130 3161
185 452


 80%|█████████████████████████████████▌        | 80/100 [02:19<00:33,  1.69s/it]

3128 3161
185 452


In [None]:
correct = 0
total = 0
preds = model(X)
for pred,y_batch in zip(preds,y):
    pred = int(torch.argmax(pred))
    if pred == y_batch:
        correct += 1
    total += 1
print(correct,total)
acc = round(correct/total,3)*100

In [None]:
for pred in preds:
    print(pred)