In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize

In [2]:
vocab = torch.load('vocab.pth')
word_vectors = torch.load('word_vectors.pth')

  vocab = torch.load('vocab.pth')
  word_vectors = torch.load('word_vectors.pth')


In [3]:
class KimCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_filters, filter_sizes, num_classes, dropout=0.5):
        super(KimCNN, self).__init__()
        
        self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(word_vectors.vectors), freeze=False)
        self.convs = nn.ModuleList([nn.Conv2d(1, num_filters, (fs, embedding_dim)) for fs in filter_sizes])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes)
        
    def forward(self, x):
        # for input of size (batch_size, sentence_length, embedding_dim)
        x = self.embedding(x)
        x = x.unsqueeze(1)
        # -> (batch_size, 1, sentence_length, embedding_dim)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
        # -> [(batch_size, num_filters, sentence_length - fs + 1) for fs in filter_sizes]
        x = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in x]
        # -> [(batch_size, num_filters) for fs in filter_sizes]
        x = torch.cat(x, 1)
        # -> (batch_size, num_filters * len(filter_sizes))
        x = self.dropout(x)
        x = self.fc(x)
        
        return x

In [6]:
model = KimCNN(len(vocab), 150, 100, [1, 2, 3, 4, 5, 6], 2)
model.load_state_dict(torch.load('kim_cnn.pth'))

  model.load_state_dict(torch.load('kim_cnn.pth'))


<All keys matched successfully>

In [7]:
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\84359\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\84359\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\84359\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [9]:
stopword = stopwords.words('english')
def preprocess(txt):
    # Remove URL
    txt = re.sub(r'((www\.[^\s]+)|(https?://[^\s]+))','', txt)
    # Also remove hashtags since in reddit they are a kind of link
    txt = re.sub(r'#([^\s]+)', '', txt)
    # Remove all the special characters
    txt = re.sub(r'\W', ' ', txt)
    # Remove numeric
    txt = re.sub(r"\d+", "", txt)
    # There are words that are glueTogetherSinceRedditCommentFormattingIsWeird
    match = re.search(r'[a-z][A-Z]', txt)
    if match:
        idx = match.start()
        txt = txt[:idx+1]+' '+txt[idx+1:]
    # remove all single characters
    txt = re.sub(r'\s+[a-zA-Z]\s+', ' ', txt)
    # Remove single characters from the start
    txt = re.sub(r'\^[a-zA-Z]\s+', ' ', txt) 
    # Substituting multiple spaces with single space
    txt = re.sub(r'\s+', ' ', txt, flags=re.I)
    # Converting to Lowercase
    txt = txt.lower()
    # Remove whitespace at begin and end string
    txt = txt.strip()

    txt = txt.split()
    txt = [word for word in txt if word not in stopword]
    txt = ' '.join(txt)

    stemmer = PorterStemmer()
    txt = txt.split()
    txt = [stemmer.stem(word) for word in txt]
    txt = ' '.join(txt)

    print(txt)

    return txt

In [10]:
PADDING = '<PAD>'
UNKNOWN = '<BRUH>'
MAX_LEN = 200
def prepare_data(sentence):
    sentence = preprocess(sentence)
    sentence = word_tokenize(sentence)
    res = []
    for word in sentence:
        if word not in vocab:
            res.append(vocab[UNKNOWN])
        else:
            res.append(vocab[word])
    if len(res) > MAX_LEN:
        res = res[:MAX_LEN]
    else:
        res += [vocab[PADDING]]*(MAX_LEN-len(res))
    return res

In [64]:
my_sentence = """
i hit myself in the balls accidentally
Gaming
ouch


"""
another_sentence = """
How to deal with the utter meaningless of it all?
Brain damage at a young age has left me with the inability to socialize. I hate my situation, i hate myself for my inability to change. I hate myself for my utter stupidity when it comes to social awareness.

The lack of meaningful relationships extending my entire life has left me dead inside. Been this way for a long time. What point is there in continuing to suffer? It never gets better. I'll continue to rot in my room for the remainder of my miserable life if I don't make a change, but change seems impossible for someone like me. And even if by some miracle I do change, would it even make me happy? I know, at the end of the day, nothing matters.

Suicide is a permanent solution to a temporary problem. Life is a temporary problem, sure. But my problems are permanent. Why suffer any longer? My life has zero value"""

In [83]:
# List of sentences
sentences = [
    """I’m tired of people pointing out how much I eat
Title says it all.

I hate it so much. Yesterday I ate one serving of something I had made (pasta casserole). A little while later, I ate two more servings because I was hungry again. My dad made a comment about it, asking me how many servings I had eaten.

I’ve been called fat since I was a child (even though I was underweight) by other people, and now I’m getting comments by my family.

It just sucks, and I wish it would stop.
""",
"""You are not just cooked.
You good sir are burned, fried, then deep-fried, grilled, boiled, roasted, fried again and then cooked"""
]

# Corresponding labels: 1 for suicidal, 0 for non-suicidal
sentences = [prepare_data(x) for x in sentences]

sentences = torch.tensor(sentences)

tire peopl point much eat titl say hate much yesterday ate one serv someth made pasta casserol littl later ate two serv hungri dad made comment ask mani serv eaten call fat sinc child even though underweight peopl get comment famili suck wish would stop
cook good sir burn fri deep fri grill boil roast fri cook


In [84]:
outputs = model(sentences)
_, predicted = torch.max(outputs, 1)
print(predicted)

tensor([1, 0])
