-
Notifications
You must be signed in to change notification settings - Fork 136
/
utils.py
117 lines (99 loc) · 4.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# utils.py
import torch
from torchtext import data
from torchtext.vocab import Vectors
import spacy
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
class Dataset(object):
    """Data-loading helper built on torchtext.

    Reads "<label>,<text>" CSV-style files, builds train/val/test
    BucketIterators, and exposes the vocabulary and pretrained word
    embeddings after ``load_data`` has run.
    """

    def __init__(self, config):
        # config must expose `max_sen_len` and `batch_size` (read in load_data).
        self.config = config
        self.train_iterator = None
        self.test_iterator = None
        self.val_iterator = None
        self.vocab = []
        self.word_embeddings = {}

    def parse_label(self, label):
        '''
        Get the actual label from a label string.

        Input:
            label (string) : label of the form '__label__2'
        Returns:
            label (int) : integer value corresponding to label string
        '''
        # The class id is the last character of the stripped string,
        # so this supports only single-digit labels (0-9).
        return int(label.strip()[-1])

    def get_pandas_df(self, filename):
        '''
        Load a "<label>,<text>" file into a pandas.DataFrame.

        Each line is split on the FIRST comma only, so the text portion
        may itself contain commas. Blank lines are skipped. The resulting
        frame has columns "text" and "label" (in that order, matching the
        torchtext field list used in load_data).
        '''
        texts, labels = [], []
        # Explicit encoding: without it, open() falls back to the
        # platform default and can mangle non-ASCII text on Windows.
        with open(filename, 'r', encoding='utf-8') as datafile:
            for line in datafile:
                line = line.strip()
                if not line:
                    # A blank line (e.g. a trailing newline) would otherwise
                    # crash the split/indexing below.
                    continue
                label, text = line.split(',', maxsplit=1)
                labels.append(self.parse_label(label))
                texts.append(text)
        return pd.DataFrame({"text": texts, "label": labels})

    def load_data(self, w2v_file, train_file, test_file, val_file=None):
        '''
        Loads the data from files.
        Sets up iterators for training, validation and test data.
        Also creates the vocabulary and word embeddings based on the data.

        Inputs:
            w2v_file (String): absolute path to file containing word embeddings (GloVe/Word2Vec)
            train_file (String): absolute path to training file
            test_file (String): absolute path to test file
            val_file (String): absolute path to validation file; when omitted,
                20% of the training data is held out for validation
        '''
        # NOTE(review): the 'en' shortcut was removed in spaCy 3.x; newer
        # installations need spacy.load('en_core_web_sm') instead.
        NLP = spacy.load('en')
        tokenizer = lambda sent: [x.text for x in NLP.tokenizer(sent) if x.text != " "]

        # Field definitions: text is tokenized, lower-cased and padded/truncated
        # to a fixed length; labels are used as raw integers (no vocab).
        TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True,
                          fix_length=self.config.max_sen_len)
        LABEL = data.Field(sequential=False, use_vocab=False)
        datafields = [("text", TEXT), ("label", LABEL)]

        def to_dataset(df):
            # Convert a DataFrame (columns ordered text, label) into a
            # torchtext Dataset. Shared by train/test/val loading below.
            examples = [data.Example.fromlist(row, datafields) for row in df.values.tolist()]
            return data.Dataset(examples, datafields)

        train_data = to_dataset(self.get_pandas_df(train_file))
        test_data = to_dataset(self.get_pandas_df(test_file))

        # Use an explicit validation file when given; otherwise carve 20%
        # off the training set.
        if val_file:
            val_data = to_dataset(self.get_pandas_df(val_file))
        else:
            train_data, val_data = train_data.split(split_ratio=0.8)

        # Vocabulary and pretrained vectors come from the training split only,
        # so test/val words outside it map to <unk>.
        TEXT.build_vocab(train_data, vectors=Vectors(w2v_file))
        self.word_embeddings = TEXT.vocab.vectors
        self.vocab = TEXT.vocab

        self.train_iterator = data.BucketIterator(
            train_data,
            batch_size=self.config.batch_size,
            sort_key=lambda x: len(x.text),
            repeat=False,
            shuffle=True)
        # Validation/test batches keep a fixed order for reproducible metrics.
        self.val_iterator, self.test_iterator = data.BucketIterator.splits(
            (val_data, test_data),
            batch_size=self.config.batch_size,
            sort_key=lambda x: len(x.text),
            repeat=False,
            shuffle=False)

        print("Loaded {} training examples".format(len(train_data)))
        print("Loaded {} test examples".format(len(test_data)))
        print("Loaded {} validation examples".format(len(val_data)))
def evaluate_model(model, iterator):
    '''
    Compute classification accuracy of `model` over all batches in `iterator`.

    Inputs:
        model: callable (e.g. nn.Module) mapping a batch of token-index
            tensors to per-class scores; must expose .eval()
        iterator: yields batches with `.text` (model input tensor) and
            `.label` (gold integer labels, 1-based)
    Returns:
        float: fraction of correctly classified examples in [0, 1]
            (0.0 when the iterator is empty)
    '''
    # Switch off dropout/batch-norm training behaviour for evaluation.
    model.eval()
    all_preds = []
    all_y = []
    # No gradients are needed for evaluation; no_grad saves memory and time.
    with torch.no_grad():
        for batch in iterator:
            x = batch.text.cuda() if torch.cuda.is_available() else batch.text
            y_pred = model(x)
            # argmax over the class dimension; +1 maps the 0-based class
            # index back to the file's 1-based label scheme (see parse_label).
            predicted = torch.max(y_pred.cpu().data, 1)[1] + 1
            all_preds.extend(predicted.numpy())
            all_y.extend(batch.label.numpy())
    if not all_y:
        # Empty iterator: avoid a nan from a zero-length mean.
        return 0.0
    # Plain accuracy = mean of the element-wise equality mask; identical to
    # sklearn's accuracy_score without the extra dependency.
    return float(np.mean(np.array(all_y) == np.array(all_preds)))