In [2]:
import codecs
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from sklearn.utils import shuffle
from gensim.models.keyedvectors import KeyedVectors
from gensim.test.utils import datapath
import numpy as np
import argparse
import copy
import operator
import re

In [4]:
def load_data(fpath,label):
    data = []
    with codecs.open(fpath,'r','utf-8',errors='ignore') as f:
        lines = f.readlines()
        for string in lines:
            string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
            string = re.sub(r"\'s", " \'s", string)
            string = re.sub(r"\'ve", " \'ve", string)
            string = re.sub(r"n\'t", " n\'t", string)
            string = re.sub(r"\'re", " \'re", string)
            string = re.sub(r"\'d", " \'d", string)
            string = re.sub(r"\'ll", " \'ll", string)
            string = re.sub(r",", " , ", string)
            string = re.sub(r"!", " ! ", string)
            string = re.sub(r"\(", " \( ", string)
            string = re.sub(r"\)", " \) ", string)
            string = re.sub(r"\?", " \? ", string)
            string = re.sub(r"\s{2,}", " ", string)
            string = string.rstrip()
            data.append((string.split(' '),label))
    return data
pos = load_data('../sentence/dataset/rt-polarity.pos',1)
neg = load_data('../sentence/dataset/rt-polarity.neg',0)
print(len(pos))
print(len(neg))
data = pos+neg
print(len(data))

5331
5331
10662


In [19]:
max_sentence_len = max([len(sentence) for sentence, _ in data])

vocab = []

for d, _ in data:
    for w in d:
        if w not in vocab:
            vocab.append(w)
vocab = sorted(vocab)
vocab_size = len(vocab)

w2i = {w:i for i,w in enumerate(vocab)}
i2w = {i:w for i,w in enumerate(vocab)}
word_vectors = KeyedVectors.load_word2vec_format('../sentence/dataset/GoogleNews-vectors-negative300.bin', binary=True)
wv_matrix = []
for i in range(len(vocab)):
            word = i2w[i]
            if word in word_vectors.vocab:
                wv_matrix.append(word_vectors.word_vec(word))
            else:
                wv_matrix.append(np.random.uniform(-0.01, 0.01, 300).astype("float32"))

wv_matrix.append(np.random.uniform(-0.01, 0.01, 300).astype("float32"))
wv_matrix.append(np.zeros(300).astype("float32"))
wv_matrix = np.array(wv_matrix)

In [111]:
div_idx = (int)(len(data)*0.8)
random.shuffle(data)
train_data = data[:div_idx]
test_data = data[div_idx:]

In [112]:
class Net(nn.Module):
    def __init__(self,vocab_size,embd_size,out_chs,filter_heights,pretrained_vec):
        super(Net,self).__init__()
        self.embedding = nn.Embedding(vocab_size+2,embd_size)
        self.embedding.weight.data.copy_(torch.from_numpy(pretrained_vec))
        self.conv = nn.ModuleList([nn.Conv2d(1,out_chs,(fh,embd_size),padding = (fh-1,0)) for fh in filter_heights])
        self.dropout = nn.Dropout(.5)
        self.fc1 = nn.Linear(out_chs*len(filter_heights),1)
        
    def forward(self,x):
        embeding = self.embedding(x)
        x = embeding.unsqueeze(1)
        conved = [F.relu(conv(x)).squeeze(3) for conv in self.conv]
        x = [F.avg_pool1d(i,i.size(2)).squeeze(2) for i in conved]
        features = torch.cat(x,1)
        x = self.dropout(features)
        x = self.fc1(x)
        probs = F.sigmoid(x)
        return probs, features, conved

In [113]:
def train(model,data,batch_size,n_epoch):
    model.train()
    if use_cuda:
        model.cuda()
    losses = []
    optimizer = torch.optim.Adadelta(model.parameters(),lr = 0.5)
    for epoch in range(n_epoch):
        epoch_loss = 0.0
        random.shuffle(data)
        for i in range(0,len(data)-batch_size,batch_size):
            in_data,labels = [],[]
            for sentence, label in data[i:i+batch_size]:
                index_vec = [w2i[w] for w in sentence]
                pad_len = max(0,max_sentence_len - len(index_vec))
                index_vec +=[0]*pad_len
                index_vec = index_vec[:max_sentence_len]
                in_data.append(index_vec)
                labels.append(label)
            sent_var = Variable(torch.LongTensor(in_data))
            if use_cuda:
                sent_var = sent_var.cuda()
            target_var = Variable(torch.Tensor(labels).unsqueeze(1))
            if use_cuda:
                target_var = target_var.cuda()
            optimizer.zero_grad()
            probs,features,conved = model(sent_var)
            loss = F.binary_cross_entropy(probs,target_var)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.data[0]
        
        print('epoch : {:d},loss : {:.3f}'.format(epoch,epoch_loss))
        losses.append(epoch_loss)
    print('Training avg loss : {:.3f}'.format(sum(losses)/len(losses)))
    
    return model, losses

def test(model,data, n_test,min_sentence_len):
    model.eval()
    loss = 0
    correct = 0
    params = list(model.parameters())
    print(len(params[-2][-0]))
    #for sentence,label in data[:n_test]:
    for sentence,label in data[:200]:
        if len(sentence) < min_sentence_len:
            continue
        index_vec = [w2i[w] for w in sentence]
        sent_var = Variable(torch.LongTensor([index_vec]))
        if use_cuda:
            sent_var = sent_var.cuda()
        out,features,conved = model(sent_var)
        score = out.data[0][0]
        pred = 1 if score > .5 else 0
        if pred == label:
            correct += 1
        loss += math.pow((label-score),2)
        print('pred : {:d}'.format(pred))
        print('label : {:d}'.format(label))
        word_weight=[]
        for i in range(0,300):
            if i<100:
                for start in range(2,len(conved[0][0][0])):
                    conved[0][0][i][start] = conved[0][0][i][start]*params[-2][0][i]
            if i>=100 and i<200:
                for start in range(3,len(conved[1][0][0])):
                    conved[1][0][i-100][start] = conved[1][0][i-100][start]*params[-2][0][i]
            if i>=200 and i<300:
                for start in range(4,len(conved[2][0][0])):
                    conved[2][0][i-200][start] = conved[2][0][i-200][start]*params[-2][0][i]
        for length in range(0,len(sentence)):
            sum = 0
            for i in range(0,300):
                if i<100:
                    sum = sum + conved[0][0][i][length+2]
                if i>=100 and i<200:
                    sum = sum + conved[1][0][i-100][length+3]
                if i>=200 and i<300:
                    sum = sum + conved[2][0][i-200][length+4]
            word_weight.append(sum)
    
        for i in range(0,len(sentence)):
            word_weight[i] = word_weight[i].cpu().data.numpy()
            word_weight[i] = word_weight[i].tolist()
            if pred == 0:
                 word_weight[i] = word_weight[i]*-1
            word_weight[i] = round(word_weight[i],6)
        print(sentence)
        print(word_weight)
        print('')
    #print('Test acc : {:.3f} ({:d}/{:d})'.format(correct/n_test,correct,n_test))
    #print('Test loss : {:.3f}'.format(loss/n_test))
    
    

out_ch = 100
embd_size = 300
batch_size = 50
n_epoch = 50
filter_size = [3,4,5]
print('filter : ',filter_size)
use_cuda = torch.cuda.is_available()
model = Net(vocab_size,embd_size,out_ch,filter_size,wv_matrix)
model,losses = train(model,train_data,batch_size,n_epoch)
test(model,test_data,len(test_data),max(filter_size))
print('')

filter :  [3, 4, 5]




epoch : 0,loss : 117.552
epoch : 1,loss : 116.254
epoch : 2,loss : 110.947
epoch : 3,loss : 100.278
epoch : 4,loss : 91.000
epoch : 5,loss : 83.339
epoch : 6,loss : 77.567
epoch : 7,loss : 73.313
epoch : 8,loss : 68.729
epoch : 9,loss : 64.493
epoch : 10,loss : 61.576
epoch : 11,loss : 57.561
epoch : 12,loss : 54.081
epoch : 13,loss : 50.563
epoch : 14,loss : 47.758
epoch : 15,loss : 44.000
epoch : 16,loss : 40.747
epoch : 17,loss : 38.484
epoch : 18,loss : 35.272
epoch : 19,loss : 33.027
epoch : 20,loss : 30.299
epoch : 21,loss : 28.240
epoch : 22,loss : 25.838
epoch : 23,loss : 23.330
epoch : 24,loss : 21.695
epoch : 25,loss : 19.208
epoch : 26,loss : 17.991
epoch : 27,loss : 16.957
epoch : 28,loss : 15.326
epoch : 29,loss : 14.323
epoch : 30,loss : 13.015
epoch : 31,loss : 12.234
epoch : 32,loss : 11.237
epoch : 33,loss : 10.159
epoch : 34,loss : 9.630
epoch : 35,loss : 8.780
epoch : 36,loss : 8.399
epoch : 37,loss : 7.721
epoch : 38,loss : 7.270
epoch : 39,loss : 6.994
epoch : 40,l

['the', 'concert', 'footage', 'is', 'stirring', ',', 'the', 'recording', 'sessions', 'are', 'intriguing', ',', 'and', 'on', 'the', 'way', 'to', 'striking', 'a', 'blow', 'for', 'artistic', 'integrity', 'this', 'quality', 'band', 'may', 'pick', 'up', 'new', 'admirers']
[-19.775461, -4.761189, 28.675604, 15.541658, -15.870186, -44.171337, -40.922565, -24.431423, 23.952269, 66.865341, 77.631668, 29.997263, 31.127552, 27.335131, 32.343876, 20.333517, -14.710588, -27.804474, -44.824253, -43.949875, 46.230282, 83.587845, 117.135536, 98.333328, 107.537262, 87.108147, 80.93438, 87.961456, 84.568161, 68.965927, 57.683598]

pred : 1
label : 1
['ludicrous', ',', 'but', 'director', 'carl', 'franklin', 'adds', 'enough', 'flourishes', 'and', 'freak', 'outs', 'to', 'make', 'it', 'entertaining']
[-65.095505, -19.011528, -42.948685, -51.745415, -40.352905, -0.22601, 55.244606, 52.696487, 63.497169, 32.235767, 21.251501, 22.486012, 41.847118, 63.606621, 58.391548, 60.332123]

pred : 0
label : 1
['like', 

['a', 'full', 'experience', ',', 'a', 'love', 'story', 'and', 'a', 'murder', 'mystery', 'that', 'expands', 'into', 'a', 'meditation', 'on', 'the', 'deep', 'deceptions', 'of', 'innocence']
[12.734287, 20.651222, 28.557899, 37.54327, 39.710297, 44.887192, 10.266661, -19.806511, -32.023792, -6.864705, -3.430657, 58.693054, 61.815979, 11.9614, 17.833763, 40.670776, 57.857929, 71.905624, 82.446899, 34.249683, 22.744511, 27.213842]

pred : 0
label : 1
['lends', 'itself', 'to', 'the', 'narcotizing', 'bland', '\\(', 'sinister', ',', 'though', 'not', 'nearly', 'so', 'sinister', 'as', 'the', 'biennial', 'disney', 'girl', 'movie', '\\)', 'machinations', 'of', 'the', 'biennial', 'disney', 'boy', 'movie']
[7.810628, 55.451157, 67.506668, 92.40815, 92.784996, 94.082092, -8.62385, -14.196405, 2.613597, 13.04607, 24.681719, 26.16189, 17.144817, 10.874607, 19.37583, 26.761429, 39.150127, 48.240398, 37.263596, 41.268272, 37.323162, 21.014545, 12.280672, 22.614151, 26.781193, 24.129211, -3.93845, 11.3645

['nair', 'does', "n't", 'use', 'monsoon', 'wedding', 'to', 'lament', 'the', 'loss', 'of', 'culture', 'instead', ',', 'she', 'sees', 'it', 'as', 'a', 'chance', 'to', 'revitalize', 'what', 'is', 'and', 'always', 'has', 'been', 'remarkable', 'about', 'clung', 'to', 'traditions']
[-30.93335, -63.870018, -62.010502, -71.516602, -79.448845, -64.849449, -40.387165, -8.381053, -0.546739, -13.411242, -24.285257, -19.160864, -51.408447, 21.332594, 23.304125, 53.047966, 14.944735, 2.966037, 6.079129, -3.911969, -24.184793, 22.960913, 75.160461, 76.296089, 89.735168, 104.834068, 57.784378, 24.424395, 45.709045, -30.459627, -14.467124, 27.735495, 38.570415]

pred : 1
label : 1
['offers', 'a', 'clear', 'eyed', 'chronicle', 'of', 'a', 'female', 'friendship', 'that', 'is', 'more', 'complex', 'and', 'honest', 'than', 'anything', 'represented', 'in', 'a', 'hollywood', 'film']
[84.009377, 56.620201, 59.720066, 52.795532, 57.841591, 26.285856, 37.794395, 39.500225, 46.132359, 37.567501, 51.811409, 45.0287

['the', 'pool', 'drowned', 'me', 'in', 'boredom']
[98.714615, 123.597679, 109.017609, 58.862667, 48.487576, 45.72884]

pred : 0
label : 0
['one', 'thing', 'is', 'for', 'sure', 'this', 'movie', 'does', 'not', 'tell', 'you', 'a', 'whole', 'lot', 'about', 'lily', 'chou', 'chou']
[41.716755, 40.565186, 0.757238, 17.406513, 31.480373, 42.017353, 26.502924, 12.739807, 1.682489, -1.632338, -8.678026, 42.126663, 79.741829, 102.331406, 142.136917, 160.881409, 111.818596, 59.515358]

pred : 0
label : 0
['a', 'depraved', ',', 'incoherent', ',', 'instantly', 'disposable', 'piece', 'of', 'hackery']
[108.199379, 148.808197, 129.54837, 156.924255, 48.198383, 65.931808, 50.792976, 2.456877, 40.459179, 42.679466]

pred : 1
label : 0
['its', 'solemn', 'pretension', 'prevents', 'us', 'from', 'sharing', 'the', 'awe', 'in', 'which', 'it', 'holds', 'itself']
[-77.297302, -87.418282, -18.233358, 54.678391, 73.43338, 67.432671, 62.606937, 49.475235, 57.822525, 13.415379, 16.087408, 9.719805, 7.405597, -11.220

['priggish', ',', 'lethargically', 'paced', 'parable', 'of', 'renewal']
[97.129181, 42.675289, 52.24213, 20.060371, 14.036299, -18.003242, -24.290138]

pred : 1
label : 1
['a', 'clever', 'blend', 'of', 'fact', 'and', 'fiction']
[10.06134, 10.538591, 30.677319, 12.576126, 13.663584, 24.599276, 13.289375]

pred : 1
label : 0
['the', 'man', 'from', 'elysian', 'fields', 'is', 'a', 'cold', ',', 'bliss', 'less', 'work', 'that', 'groans', 'along', 'thinking', 'itself', 'some', 'important', 'comment', 'on', 'how', 'life', 'throws', 'us', 'some', 'beguiling', 'curves']
[10.298983, 4.597484, -6.040841, -19.605362, -12.907642, 42.118866, 29.073524, 13.689567, 2.043786, -29.209835, -58.562771, -59.465939, -75.098526, -104.826355, -55.221619, -81.663414, -42.505405, -45.830818, -48.55122, -51.069412, 9.224797, 11.053307, 71.614693, 106.590492, 123.107925, 127.922508, 151.279877, 81.033669]

pred : 0
label : 0
['will', 'no', 'doubt', 'delight', 'plympton', "'s", 'legion', 'of', 'fans', 'others', 'ma

['one', 'of', 'the', 'most', 'incoherent', 'features', 'in', 'recent', 'memory']
[48.40131, 66.542595, 99.851425, 99.452072, 97.071571, -71.684395, -55.252136, -48.22831, -36.100067]

pred : 1
label : 1
['a', 'moving', 'and', 'weighty', 'depiction', 'of', 'one', 'family', "'s", 'attempts', 'to', 'heal', 'after', 'the', 'death', 'of', 'a', 'child']
[87.236977, 104.213509, 81.182571, 75.926727, 55.40004, -7.533272, -23.797939, -23.196154, -50.523262, -64.323418, -33.860207, -33.10535, -55.804977, 0.565491, 19.586098, 33.916344, 36.062855, 32.561749]

pred : 0
label : 0
['directed', 'in', 'a', 'flashy', ',', 'empty', 'sub', 'music', 'video', 'style', 'by', 'a', 'director', 'so', 'self', 'possessed', 'he', 'actually', 'adds', 'a', 'period', 'to', 'his', 'first', 'name']
[15.888957, 51.014866, 82.814583, 117.973396, 109.476143, 122.829147, 60.423786, 22.218874, 33.154377, 1.536284, 32.353687, 39.242283, 44.080616, 44.375607, 29.576357, 17.867031, 15.541459, 13.433466, 9.808824, 11.515253, 1

['the', 'piece', 'plays', 'as', 'well', 'as', 'it', 'does', 'thanks', 'in', 'large', 'measure', 'to', 'anspaugh', "'s", 'three', 'lead', 'actresses']
[9.388214, 15.470032, -18.175476, -2.644384, 18.93408, 34.875416, 62.009754, 70.725113, 73.05542, 11.575459, 9.075266, -5.495347, -24.464399, -34.003395, -39.830276, -46.913788, -30.112589, 5.321875]

pred : 0
label : 0
['completely', 'awful', 'iranian', 'drama', 'as', 'much', 'fun', 'as', 'a', 'grouchy', 'ayatollah', 'in', 'a', 'cold', 'mosque']
[-33.414246, -33.61615, -100.108955, -65.276611, -34.646763, -11.413033, 6.202788, 54.747162, 53.088245, 59.247379, 18.247732, 35.299099, 49.033085, 47.380692, 45.393181]

pred : 0
label : 0
['hardman', 'is', 'a', 'grating', ',', 'mannered', 'onscreen', 'presence', ',', 'which', 'is', 'especially', 'unfortunate', 'in', 'light', 'of', 'the', 'fine', 'work', 'done', 'by', 'most', 'of', 'the', 'rest', 'of', 'her', 'cast']
[37.139267, 71.029335, 84.386826, 107.323471, 62.081017, 76.880325, 51.503254,

['can', 'be', 'viewed', 'as', 'pure', 'composition', 'and', 'form', 'film', 'as', 'music']
[-8.026393, 0.294643, 36.176365, 68.944138, 96.796112, 91.797829, 64.684456, 58.603737, 15.357924, 8.061474, 15.773766]

pred : 1
label : 0
['a', 'minor', 'league', 'soccer', 'remake', 'of', 'the', 'longest', 'yard']
[43.250168, 42.104633, -36.046959, -21.530901, -6.016651, -19.57913, -22.922171, -17.476646, -21.313208]

pred : 1
label : 1
['if', 'the', 'man', 'from', 'elysian', 'fields', 'is', 'doomed', 'by', 'its', 'smallness', ',', 'it', 'is', 'also', 'elevated', 'by', 'it', 'the', 'kind', 'of', 'movie', 'that', 'you', 'enjoy', 'more', 'because', 'you', "'re", 'one', 'of', 'the', 'lucky', 'few', 'who', 'sought', 'it', 'out']
[23.630943, 10.298983, 4.597484, -37.523602, -65.159805, -79.158859, -47.751202, -38.725937, 23.391382, 30.391415, 41.341568, 27.619581, 26.398479, 20.425364, 13.077416, 4.572387, 1.166552, 0.752268, -2.441089, -4.698341, 19.206852, 30.448565, 53.686047, 67.135475, 72.4805

['the', 'movie', 'itself', 'appears', 'to', 'be', 'running', 'on', 'hypertime', 'in', 'reverse', 'as', 'the', 'truly', 'funny', 'bits', 'get', 'further', 'and', 'further', 'apart']
[40.270363, 40.122402, 35.96986, 24.112505, 4.600741, 4.395751, -8.213526, 1.307106, 6.622415, 20.469351, 10.062926, 9.029055, -1.355655, 6.768156, -37.825542, 15.237713, 12.565208, 33.346146, 29.068525, 40.45607, 31.831673]

pred : 0
label : 0
['every', 'nanosecond', 'of', 'the', 'the', 'new', 'guy', 'reminds', 'you', 'that', 'you', 'could', 'be', 'doing', 'something', 'else', 'far', 'more', 'pleasurable', 'something', 'like', 'scrubbing', 'the', 'toilet', 'or', 'emptying', 'rat', 'traps', 'or', 'doing', 'last', 'year', "'s", 'taxes', 'with', 'your', 'ex', 'wife']
[-7.618098, -20.739069, 5.750508, -14.280689, -39.384979, -63.17403, -68.805222, -88.94558, -18.116625, -2.742, 2.646477, 37.785397, 56.571072, 76.550728, 28.12365, 23.503851, -24.74507, -38.046623, -50.912174, 41.862831, 60.728817, 73.06778, 59.4

['saved', 'from', 'being', 'merely', 'way', 'cool', 'by', 'a', 'basic', ',', 'credible', 'compassion']
[-37.966232, -4.982893, -5.375708, 24.385378, 71.274452, 68.37944, 28.461254, 42.034107, 46.107384, 36.476795, 32.664951, -6.68827]

pred : 0
label : 0
['too', 'bad', 'writer', 'director', 'adam', 'rifkin', 'situates', 'it', 'all', 'in', 'a', 'plot', 'as', 'musty', 'as', 'one', 'of', 'the', 'golden', 'eagle', "'s", 'carpets']
[106.914352, 67.699287, 10.029917, -13.793454, -15.930635, -23.533785, -28.461119, 7.413605, 9.551571, 28.736135, 42.111092, 62.442066, 66.668564, 68.683578, 4.776466, -11.438304, -36.34956, -33.950726, -33.574532, -1.137913, 18.829786, 20.598867]

pred : 0
label : 0
['coughs', 'and', 'sputters', 'on', 'its', 'own', 'postmodern', 'conceit']
[76.093666, 78.863495, 71.201706, -16.404709, -14.167777, 4.387042, -27.229115, 34.656372]

pred : 1
label : 1
['with', 'a', 'large', 'cast', 'representing', 'a', 'broad', 'cross', 'section', ',', 'tavernier', "'s", 'film', 'b