Fixed bug with handling of non-default scoring functions
rriva002 committed Jun 30, 2019
commit a7c9629 · 1 parent b5141b2
Showing 2 changed files with 19 additions and 9 deletions.
README.md (1 addition, 0 deletions)
@@ -19,6 +19,7 @@ Fork of Shawn Ng's [CNNs for Sentence Classification in PyTorch](https://github.
 
 ## To Do
 * Add support for cross-validation during training.
+* Implement sample weights in eval scoring?
 
 ## Parameters
 **lr : float, optional (default=0.01)**
cnn_text_classification.py (18 additions, 9 deletions)
Expand Up @@ -5,7 +5,7 @@
from collections import Counter
from copy import deepcopy
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import accuracy_score
from sklearn.metrics import accuracy_score, make_scorer
from time import time
from torch.autograd import Variable
from torchtext.data import Dataset, Example, Field, Iterator, Pipeline
@@ -17,7 +17,8 @@ def __init__(self, lr=0.001, epochs=256, batch_size=64, test_interval=100,
                  embed_dim=128, kernel_num=100, kernel_sizes="3,4,5",
                  static=False, device=-1, cuda=True, class_weight=None,
                  split_ratio=0.9, random_state=None, vectors=None,
-                 preprocessor=None, scoring=accuracy_score, verbose=0):
+                 preprocessor=None, scoring=make_scorer(accuracy_score),
+                 verbose=0):
         self.lr = lr
         self.epochs = epochs
         self.batch_size = batch_size
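This is the heart of the fix: the default `scoring` is now a scikit-learn scorer built with `make_scorer` instead of a bare metric function. A bare metric is called as `metric(y_true, y_pred)`, while a scorer is called as `scorer(estimator, X, y_true)` and runs the prediction itself, so the default and any user-supplied scorer now share one calling convention. A standalone sketch of the two conventions; the `DummyClassifier` and toy data are illustrative, not from this repository:

```python
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, make_scorer

X, y = [[0], [1], [2], [3]], [0, 0, 1, 1]
clf = DummyClassifier(strategy="most_frequent").fit(X, y)

# Bare metric: the caller supplies both label sequences.
print(accuracy_score(y, clf.predict(X)))  # 0.5

# Scorer: takes the estimator and raw X, predicts internally.
scorer = make_scorer(accuracy_score)
print(scorer(clf, X, y))                  # 0.5
```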
@@ -75,7 +76,9 @@ def __eval(self, data_iter):
             preds += torch.max(logit, 1)[1].view(target.size()).data.tolist()
             targets += target.data.tolist()
 
-        return self.scoring(targets, preds)
+        preds = [self.__label_field.vocab.itos[pred + 1] for pred in preds]
+        targets = [self.__label_field.vocab.itos[targ + 1] for targ in targets]
+        return self.scoring(_Eval(preds), None, targets)
 
     def fit(self, X, y, sample_weight=None):
         if self.random_state is not None:
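A scorer expects an estimator with a `predict` method, but `__eval` has already computed its predictions, so it passes a small `_Eval` shim (added at the end of this diff) that simply replays them, and `X` can be `None`. The `+ 1` offset converts the model's class indices back into label strings, presumably because index 0 of the torchtext label vocabulary is reserved, which matches `class_num = len(self.__label_field.vocab) - 1` in the next hunk. A minimal sketch of the mechanism with made-up labels (scorer internals can vary across scikit-learn versions):

```python
from sklearn.metrics import accuracy_score, make_scorer


class _Eval:
    """Stand-in 'estimator' that replays precomputed predictions."""
    def __init__(self, preds):
        self.__preds = preds

    def predict(self, X):
        return self.__preds  # X is ignored


preds = ["pos", "neg", "pos"]    # what the model predicted
targets = ["pos", "pos", "pos"]  # gold labels
scoring = make_scorer(accuracy_score)

# The scorer calls _Eval.predict(None), then scores the result.
print(scoring(_Eval(preds), None, targets))  # 0.666...
```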
@@ -95,10 +98,9 @@ def fit(self, X, y, sample_weight=None):
         embed_num = len(self.__text_field.vocab)
         class_num = len(self.__label_field.vocab) - 1
         kernel_sizes = [int(k) for k in self.kernel_sizes.split(",")]
-        self.__model = CNNText(embed_num, self.embed_dim, class_num,
-                               self.kernel_num, kernel_sizes, self.dropout,
-                               self.static,
-                               vectors=self.__text_field.vocab.vectors)
+        self.__model = _CNNText(embed_num, self.embed_dim, class_num,
+                                self.kernel_num, kernel_sizes, self.dropout,
+                                self.static, self.__text_field.vocab.vectors)
 
         if self.cuda and torch.cuda.is_available():
             torch.cuda.set_device(self.device)
@@ -253,10 +255,10 @@ def __print_elapsed_time(self, seconds):
         print("Completed training in {}.".format(times))
 
 
-class CNNText(nn.Module):
+class _CNNText(nn.Module):
     def __init__(self, embed_num, embed_dim, class_num, kernel_num,
                  kernel_sizes, dropout, static, vectors=None):
-        super(CNNText, self).__init__()
+        super(_CNNText, self).__init__()
 
         self.__embed = nn.Embedding(embed_num, embed_dim)
 
@@ -279,3 +281,10 @@ def forward(self, x):
         x = [F.relu(conv(x.unsqueeze(1))).squeeze(3) for conv in self.__convs1]
         x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
         return self.__fc1(self.__dropout(torch.cat(x, 1)))
+
+
+class _Eval():
+    def __init__(self, preds):
+        self.__preds = preds
+
+    def predict(self, X):
+        return self.__preds
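With the scorer convention in place, any non-default scoring function wrapped in `make_scorer` should pass straight through `__eval`. A hypothetical usage sketch; the `CNNClassifier` import name and the toy data are assumptions for illustration, not taken from this commit:

```python
from sklearn.metrics import f1_score, make_scorer

from cnn_text_classification import CNNClassifier  # hypothetical class name

X = ["great movie", "terrible plot", "loved it", "not good"]
y = ["pos", "neg", "pos", "neg"]

# A non-default scorer, the case this commit fixes.
clf = CNNClassifier(scoring=make_scorer(f1_score, average="macro"),
                    epochs=8, cuda=False)
clf.fit(X, y)
```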
