In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchtext.data as data
import torchtext.datasets as datasets
import pickle

# Record which PyTorch build produced the outputs in this notebook.
print('{}'.format(torch.__version__))

0.3.1.post2


In [2]:
class CNN_Text(nn.Module):
    """Convolutional text classifier.

    Tokens are embedded, passed through parallel 2-D convolutions of several
    window widths (each spanning the full embedding width), max-pooled over
    the sequence axis, concatenated, passed through dropout and projected to
    class logits by a linear layer.

    Args:
        embed_num: vocabulary size (rows of the embedding table).
        class_num: number of output classes.
        kernel_num: output channels per convolution (default 50).
        kernel_sizes: convolution window widths in tokens (default (2, 3, 4)).
        embed_dim: embedding dimension (default 100).
        dropout: dropout probability before the output layer (default 0.2).

    The defaults reproduce the previously hard-coded architecture, so
    existing callers and saved models keep the same layer shapes.
    """

    def __init__(self, embed_num, class_num, kernel_num=50,
                 kernel_sizes=(2, 3, 4), embed_dim=100, dropout=0.2):
        super(CNN_Text, self).__init__()
        self.embed = nn.Embedding(embed_num, embed_dim)
        # Kernel width equals embed_dim, so each filter slides only along the
        # token axis and produces a width-1 feature map.
        self.convs1 = nn.ModuleList(
            [nn.Conv2d(1, kernel_num, (k, embed_dim)) for k in kernel_sizes])
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(len(kernel_sizes) * kernel_num, class_num)

    def forward(self, x):
        """Map a batch of token ids to unnormalised class scores.

        Args:
            x: integer token ids of shape (N, W); W must be at least the
               largest kernel size, otherwise the convolution has no output.
        Returns:
            Logits of shape (N, class_num).
        """
        # Only the submodules (embed, convs1, dropout, fc1) are read here, so
        # modules pickled with an older version of this class still run.
        x = self.embed(x)  # (N, W, D)
        x = x.unsqueeze(1)  # (N, 1, W, D): single input channel for Conv2d
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]  # [(N, Co, W-K+1), ...]
        # Max over the whole remaining sequence length ("max over time").
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # [(N, Co), ...]
        x = torch.cat(x, 1)  # (N, len(Ks)*Co)
        x = self.dropout(x)
        logit = self.fc1(x)  # (N, C)
        return logit

In [3]:
class mydataset(data.Dataset):
    """torchtext Dataset read from a tab-separated file.

    The first line of the file is skipped as a header, and blank lines are
    ignored. Each remaining line needs at least three tab-separated columns:
    column 1 is a space-separated token list, where only the part of each
    token before the first ``/`` is kept (presumably stripping a tag
    suffix - confirm against the data), and column 2 is the label.
    Column 0 is not used.
    """

    @staticmethod
    def sort_key(ex):
        # Length of the tokenised text; torchtext iterators fall back to the
        # dataset's sort_key when sorting/bucketing examples.
        return len(ex.text)

    def __init__(self, text_field, label_field, path=None, examples=None, **kwargs):
        """
        Args:
            text_field: Field applied to the token column.
            label_field: Field applied to the label column.
            path: file to parse; ignored when ``examples`` is given. Falls
                back to ``self.dirname`` if a subclass defines one.
            examples: prebuilt list of ``data.Example``; skips file parsing.
            **kwargs: forwarded to ``data.Dataset.__init__``.
        Raises:
            ValueError: if neither ``examples`` nor a path is available.
            IndexError: if a non-blank data line has fewer than 3 columns.
        """
        fields = [('text', text_field), ('label', label_field)]
        if examples is None:
            if path is None:
                path = getattr(self, 'dirname', None)
            if path is None:
                raise ValueError('mydataset requires either `path` or `examples`')
            examples = self._read_examples(path, fields)
        super(mydataset, self).__init__(examples, fields, **kwargs)

    @staticmethod
    def _read_examples(path, fields):
        """Parse the TSV file at ``path`` into a list of ``data.Example``."""
        examples = []
        # `with` guarantees the handle is closed even if parsing fails.
        with open(path, 'r', encoding='utf-8') as f:
            next(f, None)  # skip the header line
            for raw in f:
                line = raw.strip()
                if not line:  # tolerate blank / trailing empty lines
                    continue
                cols = line.split('\t')
                tokens = [tok.split('/')[0] for tok in cols[1].split(' ')]
                examples.append(data.Example.fromlist([tokens, cols[2]], fields))
        return examples

In [4]:
# Every example's text is padded/truncated to this many tokens.
TEXT_LEN = 20
TRAIN_PATH = '../nsm/twit_ratings_train.txt'
TEST_PATH = '../nsm/twit_ratings_test.txt'
# Test batches hold one example; keep it at 1 unless the evaluation code
# counts every element of a batch.
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 100, 1

text_field = data.Field(batch_first=True, fix_length=TEXT_LEN)
# Non-sequential single label per example; no <unk> entry in its vocab.
label_field = data.Field(sequential=False, batch_first=True, unk_token=None)

train_data = mydataset(text_field, label_field, path=TRAIN_PATH)
test_data = mydataset(text_field, label_field, path=TEST_PATH)

# Vocabularies come from the training split only.
text_field.build_vocab(train_data)
label_field.build_vocab(train_data)

train_iter, test_iter = data.Iterator.splits(
    (train_data, test_data),
    batch_sizes=(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE),
    repeat=False,  # a single pass per `for` loop instead of cycling forever
)
# NOTE(review): presumably this must equal the row count of the saved model's
# embedding table, since token ids index into it - confirm.
len(text_field.vocab)

61490

In [5]:
def _storage_on_cpu(storage, location):
    """torch.load map_location hook: return each deserialised storage as-is,
    i.e. keep it on the CPU instead of its original save device."""
    return storage

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source. The object is used as a
# module below (presumably a pickled CNN_Text), so its class must be defined in
# this session before loading. On PyTorch >= 2.6 torch.load defaults to
# weights_only=True and will refuse a pickled module; pass weights_only=False
# there (the 0.3.x API this notebook ran on has no such argument).
cnn = torch.load('model/cnn_model_large.pt', map_location=_storage_on_cpu)

In [6]:
%%time
# Count correct / incorrect predictions over the whole test split. Every
# element of every batch is compared, so the totals are right for any
# test batch size (not just 1).
cnn.eval()  # switch off dropout for evaluation
correct = 0
total = 0
for batch in test_iter:
    pred = cnn(batch.text)           # (N, class_num) logits
    _, ans = torch.max(pred, dim=1)  # predicted class index per example, (N,)
    # .long() avoids any uint8 accumulation of the comparison mask; int()
    # yields a plain Python number on both old and new PyTorch.
    correct += int((ans.data == batch.label.data).long().sum())
    total += batch.label.size(0)
incorrect = total - correct

print ('correct : ', correct)
print ('incorrect : ', incorrect)
print ('accuracy : ', correct / total if total else 0.0)
print ()

correct :  38987
incorrect :  11013

Wall time: 19.7 s
