In [1]:
import nltk
from nltk.stem.porter import *
from torch.nn import *
from torch.optim import *
import numpy as np
import pandas as pd
import torch,torchvision
import random
from tqdm import *
from torch.utils.data import Dataset,DataLoader
stemmer = PorterStemmer()

In [2]:
def tokenize(sentence):
    return nltk.word_tokenize(sentence)

In [3]:
tokenize('$100')

['$', '100']

In [4]:
def stem(word):
    return stemmer.stem(word.lower())

In [5]:
stem('organic')

'organ'

In [6]:
def bag_of_words(tokenized_words,all_words):
    tokenized_words = [stem(w) for w in tokenized_words]
    bag = np.zeros(len(all_words))
    for idx,w in enumerate(all_words):
        if w in tokenized_words:
            bag[idx] = 1.0
    return bag

In [7]:
bag_of_words(['hi'],['how','hi'])

array([0., 1.])

In [8]:
data = pd.read_csv('./data.csv')

In [9]:
data = data[:1000]

In [10]:
X = data['Text']
y = data['Summary']

In [11]:
X_words = []
data = []
y_words = []

In [12]:
for X_batch,y_batch in tqdm(zip(X,y)):
    X_batch = tokenize(X_batch)
    y_batch = tokenize(y_batch)
    new_X = []
    new_y = []
    for Xb in X_batch:
        new_X.append(stem(Xb))
    for yb in y_batch:
        new_y.append(stem(yb))
    X_words.extend(new_X)
    y_words.extend(new_y)
    data.append([new_X,new_y])

1000it [00:01, 727.07it/s]


In [13]:
X_words = sorted(set(X_words))
y_words = sorted(set(y_words))

In [14]:
np.random.shuffle(data)

In [15]:
X = []
y = []

In [16]:
for X_batch,y_batch in tqdm(data):
    X.append(bag_of_words(X_batch,X_words))
    y.append(bag_of_words(y_batch,y_words))

100%|██████████████████████████████████████| 1000/1000 [00:03<00:00, 262.65it/s]


In [17]:
from sklearn.model_selection import * 

In [18]:
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.25,shuffle=False)

In [19]:
device = 'cuda'

In [20]:
X_train = torch.from_numpy(np.array(X_train)).to(device).float()
y_train = torch.from_numpy(np.array(y_train)).to(device).float()
X_test = torch.from_numpy(np.array(X_test)).to(device).float()
y_test = torch.from_numpy(np.array(y_test)).to(device).float()

In [21]:
# torch.save(X_train,'X_train.pt')
# torch.save(X_test,'X_test.pth')
# torch.save(y_train,'y_train.pt')
# torch.save(y_test,'y_test.pth')
# torch.save(X,'X.pt')
# torch.save(X,'X.pth')
# torch.save(y,'y.pt')
# torch.save(y,'y.pth')

In [22]:
# torch.save(X_words,'X_words.pt')
# torch.save(X_words,'X_words.pth')
# torch.save(data,'data.pt')
# torch.save(data,'data.pth')
# torch.save(y_words,'y_words.pt')
# torch.save(y_words,'y_words.pth')

In [23]:
def get_accuracy(model,X,y):
    preds = model(X)
    correct = 0
    total = 0
    for pred,yb in zip(preds,y):
        for pred_in_pred,yb_in_yb in zip(pred,yb):
            pred_in_pred = int(torch.argmax(pred_in_pred))
            yb_in_yb = int(yb_in_yb)
            if pred_in_pred == yb_in_yb:
                correct += 1
            total += 1
    acc = round(correct/total,3)*100
    return acc

In [24]:
def get_loss(model,X,y,criterion):
    preds = model(X)
    loss = criterion(preds,y)
    return loss.item()

In [25]:
class Model(Module):
    def __init__(self):
        super().__init__()
        self.activation = ReLU()
        self.iters = 10
        self.linear1 = Linear(len(X_words),256)
        self.linear2 = Linear(256,256)
        self.linear2bn = BatchNorm1d(256)
        self.output = Linear(256,len(y_words))
    
    def forward(self,X):
        preds = self.linear1(X)
        for _ in range(self.iters):
            preds = self.activation(self.linear2bn(self.linear2(preds)))
        preds = self.output(preds)
        return preds

In [26]:
model = Model().to(device)
criterion = MSELoss()
optimizer = Adam(model.parameters(),lr=0.001)
batch_size = 32
epochs = 100

In [27]:
import wandb
PROJECT_NAME = 'Summarize-Text-Review'

In [28]:
def matrix_to_words(words,matrix):
    word = []
    for idx,m in enumerate(matrix):
        m = int(torch.argmax(m))
#         print(m)
        if m == 1:
            word.append(words[idx])
    return word

In [None]:
wandb.init(project=PROJECT_NAME,name='baseline')
for _ in tqdm(range(epochs)):
    for i in range(0,len(X_train),batch_size):
        X_batch = X_train[i:i+batch_size].to(device)
        y_batch = y_train[i:i+batch_size].to(device)
        preds = model(X_batch)
        loss = criterion(preds,y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    torch.cuda.empty_cache()
    wandb.log({'Loss':get_loss(model,X_train,y_train,criterion)})
    torch.cuda.empty_cache()
    wandb.log({'Val Loss':get_loss(model,X_test,y_test,criterion)})
    torch.cuda.empty_cache()
    wandb.log({'Acc':get_accuracy(model,X_train,y_train)})
    torch.cuda.empty_cache()
    wandb.log({'Val Acc':get_accuracy(model,X_test,y_test)})
    torch.cuda.empty_cache()
    model.train()
    print(matrix_to_words(y_words,preds[0]))
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mranuga-d[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  1%|▍                                          | 1/100 [00:14<24:29, 14.85s/it]

[]


  2%|▊                                          | 2/100 [00:29<24:07, 14.77s/it]

[]


  3%|█▎                                         | 3/100 [00:44<24:07, 14.93s/it]

[]


  4%|█▋                                         | 4/100 [00:59<23:55, 14.95s/it]

[]


  5%|██▏                                        | 5/100 [01:14<23:40, 14.96s/it]

[]


  6%|██▌                                        | 6/100 [01:29<23:23, 14.93s/it]

[]


  7%|███                                        | 7/100 [01:44<23:09, 14.94s/it]

[]


  8%|███▍                                       | 8/100 [01:59<22:55, 14.95s/it]

[]


  9%|███▊                                       | 9/100 [02:14<22:37, 14.92s/it]

[]


 10%|████▏                                     | 10/100 [02:29<22:21, 14.91s/it]

[]


 11%|████▌                                     | 11/100 [02:44<22:07, 14.92s/it]

[]


 12%|█████                                     | 12/100 [03:00<22:21, 15.25s/it]

[]


 13%|█████▍                                    | 13/100 [03:17<23:00, 15.86s/it]

[]


 14%|█████▉                                    | 14/100 [03:32<22:20, 15.59s/it]

[]


 15%|██████▎                                   | 15/100 [03:47<21:46, 15.38s/it]

[]


 16%|██████▋                                   | 16/100 [04:02<21:21, 15.26s/it]

[]


 17%|███████▏                                  | 17/100 [04:17<20:56, 15.14s/it]

[]


 18%|███████▌                                  | 18/100 [04:32<20:36, 15.08s/it]

[]


 19%|███████▉                                  | 19/100 [04:46<20:17, 15.03s/it]

[]


 20%|████████▍                                 | 20/100 [05:01<20:01, 15.02s/it]

[]


 21%|████████▊                                 | 21/100 [05:16<19:43, 14.98s/it]

[]


 22%|█████████▏                                | 22/100 [05:31<19:26, 14.95s/it]

[]


 23%|█████████▋                                | 23/100 [05:48<19:46, 15.41s/it]

[]


 24%|██████████                                | 24/100 [06:03<19:26, 15.35s/it]

[]


 25%|██████████▌                               | 25/100 [06:18<19:08, 15.32s/it]

[]


 26%|██████████▉                               | 26/100 [06:34<19:00, 15.41s/it]

[]


 27%|███████████▎                              | 27/100 [06:50<18:56, 15.58s/it]

[]


 28%|███████████▊                              | 28/100 [07:05<18:39, 15.55s/it]

[]


 29%|████████████▏                             | 29/100 [07:20<18:09, 15.35s/it]

[]


 30%|████████████▌                             | 30/100 [07:35<17:44, 15.20s/it]

[]


 31%|█████████████                             | 31/100 [07:50<17:22, 15.11s/it]

[]


 32%|█████████████▍                            | 32/100 [08:05<17:04, 15.06s/it]

[]


 33%|█████████████▊                            | 33/100 [08:20<16:45, 15.01s/it]

[]


 34%|██████████████▎                           | 34/100 [08:35<16:31, 15.03s/it]

[]


 35%|██████████████▋                           | 35/100 [08:50<16:13, 14.98s/it]

[]


 36%|███████████████                           | 36/100 [09:05<15:58, 14.97s/it]

[]


 37%|███████████████▌                          | 37/100 [09:19<15:41, 14.95s/it]

[]


 38%|███████████████▉                          | 38/100 [09:35<15:28, 14.97s/it]

[]


 39%|████████████████▍                         | 39/100 [09:49<15:11, 14.94s/it]

[]


 40%|████████████████▊                         | 40/100 [10:04<14:54, 14.91s/it]

[]


 41%|█████████████████▏                        | 41/100 [10:19<14:40, 14.92s/it]

[]


 42%|█████████████████▋                        | 42/100 [10:35<14:39, 15.16s/it]

[]


 43%|██████████████████                        | 43/100 [10:51<14:39, 15.42s/it]

[]


 44%|██████████████████▍                       | 44/100 [11:07<14:30, 15.55s/it]

[]


 45%|██████████████████▉                       | 45/100 [11:22<14:08, 15.42s/it]

[]


 46%|███████████████████▎                      | 46/100 [11:37<13:43, 15.26s/it]

[]


 47%|███████████████████▋                      | 47/100 [11:53<13:37, 15.43s/it]

[]


 48%|████████████████████▏                     | 48/100 [12:08<13:25, 15.50s/it]

[]


 49%|████████████████████▌                     | 49/100 [12:24<13:15, 15.59s/it]

[]


 50%|█████████████████████                     | 50/100 [12:40<13:04, 15.70s/it]

[]


 51%|█████████████████████▍                    | 51/100 [12:56<12:52, 15.76s/it]

[]


 52%|█████████████████████▊                    | 52/100 [13:11<12:29, 15.62s/it]

[]


 53%|██████████████████████▎                   | 53/100 [13:27<12:18, 15.71s/it]

[]


 54%|██████████████████████▋                   | 54/100 [13:44<12:15, 15.98s/it]

[]


 55%|███████████████████████                   | 55/100 [14:00<12:02, 16.05s/it]

[]


 56%|███████████████████████▌                  | 56/100 [14:16<11:49, 16.13s/it]

[]


 57%|███████████████████████▉                  | 57/100 [14:33<11:40, 16.29s/it]

[]


 58%|████████████████████████▎                 | 58/100 [14:49<11:26, 16.35s/it]

[]


 59%|████████████████████████▊                 | 59/100 [15:06<11:08, 16.32s/it]

[]


 60%|█████████████████████████▏                | 60/100 [15:22<10:55, 16.39s/it]

[]


 61%|█████████████████████████▌                | 61/100 [15:38<10:33, 16.25s/it]

[]


 62%|██████████████████████████                | 62/100 [15:54<10:15, 16.20s/it]

[]


 63%|██████████████████████████▍               | 63/100 [16:09<09:44, 15.81s/it]

[]


 64%|██████████████████████████▉               | 64/100 [16:24<09:19, 15.54s/it]

[]


 65%|███████████████████████████▎              | 65/100 [16:39<08:57, 15.36s/it]

[]


 66%|███████████████████████████▋              | 66/100 [16:54<08:37, 15.22s/it]

[]


 67%|████████████████████████████▏             | 67/100 [17:09<08:18, 15.12s/it]

[]


 68%|████████████████████████████▌             | 68/100 [17:24<08:02, 15.06s/it]

[]


 69%|████████████████████████████▉             | 69/100 [17:39<07:45, 15.01s/it]

[]


 70%|█████████████████████████████▍            | 70/100 [17:53<07:29, 14.97s/it]

[]


 71%|█████████████████████████████▊            | 71/100 [18:08<07:13, 14.95s/it]

[]


 72%|██████████████████████████████▏           | 72/100 [18:23<06:57, 14.92s/it]

[]


 73%|██████████████████████████████▋           | 73/100 [18:38<06:42, 14.93s/it]

[]


 74%|███████████████████████████████           | 74/100 [18:54<06:37, 15.29s/it]

[]


 75%|███████████████████████████████▌          | 75/100 [19:10<06:22, 15.29s/it]

[]


 76%|███████████████████████████████▉          | 76/100 [19:25<06:04, 15.19s/it]

[]


 77%|████████████████████████████████▎         | 77/100 [19:42<06:03, 15.80s/it]

[]


 78%|████████████████████████████████▊         | 78/100 [19:58<05:47, 15.80s/it]

[]


 79%|█████████████████████████████████▏        | 79/100 [20:13<05:28, 15.65s/it]

[]


 80%|█████████████████████████████████▌        | 80/100 [20:28<05:12, 15.62s/it]

[]


 81%|██████████████████████████████████        | 81/100 [20:44<04:56, 15.62s/it]

[]


 82%|██████████████████████████████████▍       | 82/100 [20:59<04:40, 15.57s/it]

[]


 83%|██████████████████████████████████▊       | 83/100 [21:15<04:23, 15.47s/it]

[]


 84%|███████████████████████████████████▎      | 84/100 [21:30<04:08, 15.54s/it]

[]


 85%|███████████████████████████████████▋      | 85/100 [21:46<03:54, 15.64s/it]

[]


 86%|████████████████████████████████████      | 86/100 [22:02<03:40, 15.72s/it]

[]


 87%|████████████████████████████████████▌     | 87/100 [22:17<03:22, 15.55s/it]

[]


 88%|████████████████████████████████████▉     | 88/100 [22:32<03:04, 15.41s/it]

[]


 89%|█████████████████████████████████████▍    | 89/100 [22:48<02:51, 15.56s/it]

[]


 90%|█████████████████████████████████████▊    | 90/100 [23:05<02:39, 15.96s/it]

[]


 91%|██████████████████████████████████████▏   | 91/100 [23:22<02:26, 16.24s/it]

[]


 92%|██████████████████████████████████████▋   | 92/100 [23:39<02:10, 16.36s/it]

[]


 93%|███████████████████████████████████████   | 93/100 [23:56<01:56, 16.59s/it]

[]


 94%|███████████████████████████████████████▍  | 94/100 [24:12<01:39, 16.59s/it]

[]


 95%|███████████████████████████████████████▉  | 95/100 [24:27<01:20, 16.05s/it]

[]


 96%|████████████████████████████████████████▎ | 96/100 [24:42<01:02, 15.70s/it]

[]


 97%|████████████████████████████████████████▋ | 97/100 [24:57<00:46, 15.51s/it]

[]


 98%|█████████████████████████████████████████▏| 98/100 [25:12<00:30, 15.31s/it]

[]


 99%|█████████████████████████████████████████▌| 99/100 [25:27<00:15, 15.33s/it]

[]


In [None]:
torch.save(model,'model.pt')
torch.save(model,'model.pth')
torch.save(model.state_dict(),'model-sd.pt')
torch.save(model.state_dict(),'model-sd.pth')

In [None]:
matrix_to_words(y_words,preds[0])

In [None]:
word = []
for idx,m in enumerate(preds[5]):
    print(m)
    m = int(torch.argmax(m))
#   print(m)
    if m == 1:
        word.append(y_words[idx])