In [2]:
import numpy as np 
import pandas as pd

In [3]:
from sklearn.model_selection import train_test_split
from sklearn_crfsuite import CRF
from sklearn_crfsuite.metrics import flat_f1_score
from sklearn_crfsuite.metrics import flat_classification_report

In [5]:
tr = pd.read_csv(r'./data/test.csv')
#te = pd.read_csv(r'./data/test.csv')

In [6]:
tr.head(5)

Unnamed: 0,sentence,word,POS,tag
0,1,Subordinated,NNP,O
1,1,Loan,NNP,O
2,1,Agreement,NNP,O
3,1,-,:,O
4,1,Silicium,NNP,I-ORG


In [7]:
tr.describe()

Unnamed: 0,sentence
count,13249.0
mean,144.385161
std,78.749781
min,1.0
25%,72.0
50%,156.0
75%,201.0
max,306.0


In [8]:
tr['Tag'].unique()

KeyError: 'Tag'

In [7]:
tr.isnull().sum()

Sentence #    1000616
Word                0
POS                 0
Tag                 0
dtype: int64

In [8]:
tr = tr.fillna(method = 'ffill')

In [9]:
tr

Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,Sentence: 1,of,IN,O
2,Sentence: 1,demonstrators,NNS,O
3,Sentence: 1,have,VBP,O
4,Sentence: 1,marched,VBN,O
...,...,...,...,...
1048570,Sentence: 47959,they,PRP,O
1048571,Sentence: 47959,responded,VBD,O
1048572,Sentence: 47959,to,TO,O
1048573,Sentence: 47959,the,DT,O


In [10]:
class sentence(object):
    def __init__(self, df):
        self.n_sent = 1
        self.df = df
        self.empty = False
        agg = lambda s : [(w, p, t) for w, p, t in zip(s['Word'].values.tolist(),
                                                       s['POS'].values.tolist(),
                                                       s['Tag'].values.tolist())]
        self.grouped = self.df.groupby("Sentence #").apply(agg)
        self.sentences = [s for s in self.grouped]
        
    def get_text(self):
        try:
            s = self.grouped['Sentence: {}'.format(self.n_sent)]
            self.n_sent +=1
            return s
        except:
            return None

In [11]:
#Displaying one full sentence
getter = sentence(tr)
sentences = [" ".join([s[0] for s in sent]) for sent in getter.sentences]
sentences[0]

'Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country .'

In [12]:
#sentence with its pos and tag.
sent = getter.get_text()
print(sent)

[('Thousands', 'NNS', 'O'), ('of', 'IN', 'O'), ('demonstrators', 'NNS', 'O'), ('have', 'VBP', 'O'), ('marched', 'VBN', 'O'), ('through', 'IN', 'O'), ('London', 'NNP', 'B-geo'), ('to', 'TO', 'O'), ('protest', 'VB', 'O'), ('the', 'DT', 'O'), ('war', 'NN', 'O'), ('in', 'IN', 'O'), ('Iraq', 'NNP', 'B-geo'), ('and', 'CC', 'O'), ('demand', 'VB', 'O'), ('the', 'DT', 'O'), ('withdrawal', 'NN', 'O'), ('of', 'IN', 'O'), ('British', 'JJ', 'B-gpe'), ('troops', 'NNS', 'O'), ('from', 'IN', 'O'), ('that', 'DT', 'O'), ('country', 'NN', 'O'), ('.', '.', 'O')]


In [13]:
sentences = getter.sentences

In [14]:
def word2features(sent, i):
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],
    }
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
            '-1:postag': postag1,
            '-1:postag[:2]': postag1[:2],
        })
    else:
        features['BOS'] = True

    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
            '+1:postag': postag1,
            '+1:postag[:2]': postag1[:2],
        })
    else:
        features['EOS'] = True

    return features


def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, postag, label in sent]

def sent2tokens(sent):
    return [token for token, postag, label in sent]

In [15]:
x_tr = [sent2features(s) for s in sentences]
y_tr = [sent2labels(s) for s in sentences]

In [16]:
x_train, x_val, y_train, y_val = train_test_split(x_tr, y_tr, test_size = 0.2)

In [17]:
crf = CRF(algorithm = 'lbfgs',
         c1 = 0.1,
         c2 = 0.1,
         max_iterations = 100,
         all_possible_transitions = False)
crf.fit(x_train, y_train)



CRF(algorithm='lbfgs', all_possible_transitions=False, c1=0.1, c2=0.1,
    keep_tempfiles=None, max_iterations=100)

In [18]:
#Predicting on the test set.
y_val_pred = crf.predict(x_val)

In [24]:
f1_score = flat_f1_score(y_val, y_val_pred, average=None)
print(f1_score)

[0.21428571 0.43636364 0.88391355 0.95347766 0.45783133 0.76991943
 0.84580598 0.90672993 0.20588235 0.24324324 0.79876161 0.51515152
 0.21052632 0.80689056 0.87568456 0.79967427 0.99289478]


In [20]:
report = flat_classification_report(y_val, y_val_pred)
print(report)



              precision    recall  f1-score   support

       B-art       0.50      0.14      0.21        88
       B-eve       0.56      0.36      0.44        67
       B-geo       0.86      0.91      0.88      7538
       B-gpe       0.96      0.94      0.95      3294
       B-nat       0.49      0.43      0.46        44
       B-org       0.79      0.75      0.77      4032
       B-per       0.86      0.83      0.85      3497
       B-tim       0.93      0.89      0.91      4095
       I-art       0.44      0.13      0.21        52
       I-eve       0.32      0.20      0.24        46
       I-geo       0.81      0.79      0.80      1467
       I-gpe       0.77      0.39      0.52        44
       I-nat       0.33      0.15      0.21        13
       I-org       0.80      0.81      0.81      3443
       I-per       0.86      0.90      0.88      3566
       I-tim       0.83      0.77      0.80      1273
           O       0.99      0.99      0.99    176802

    accuracy              