
Conditional Random Fields
CRFs are a type of discriminative undirected probabilistic graphical model whose nodes can be divided into exactly two disjoint sets X and Y, the observed and output variables, respectively; the conditional distribution p(Y|X) is then modeled.
Learning the parameters \theta is usually done by maximum likelihood learning for p(Y_i|X_i; \theta). If all nodes have exponential family distributions and all nodes are observed during training, this optimization is convex. It can be solved, for example, using gradient descent algorithms or quasi-Newton methods such as the L-BFGS algorithm.
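
As a rough illustration of what modelling p(Y|X) involves, the sketch below assumes a linear-chain CRF (the text above describes CRFs in general; the chain structure is an assumption made here because it suits sequence labelling). The emission and transition scores are hypothetical hand-picked numbers, not learned parameters, and the function names are illustrative only. It computes the normaliser Z(x) twice, by brute-force enumeration and with the forward algorithm, and then the conditional probability of one label sequence.

```python
import itertools
import math

def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def sequence_score(emissions, transitions, labels):
    """Unnormalised log-score of one label sequence (labels are integer indices)."""
    score = sum(emissions[t][y] for t, y in enumerate(labels))
    score += sum(transitions[a][b] for a, b in zip(labels, labels[1:]))
    return score

def log_partition_brute_force(emissions, transitions):
    """log Z(x) by enumerating every label sequence (exponential cost; toy sizes only)."""
    n_labels = len(transitions)
    all_sequences = itertools.product(range(n_labels), repeat=len(emissions))
    return log_sum_exp([sequence_score(emissions, transitions, ys) for ys in all_sequences])

def log_partition_forward(emissions, transitions):
    """log Z(x) with the forward algorithm (linear in the sequence length)."""
    n_labels = len(transitions)
    alpha = list(emissions[0])
    for t in range(1, len(emissions)):
        alpha = [
            emissions[t][j] + log_sum_exp([alpha[i] + transitions[i][j] for i in range(n_labels)])
            for j in range(n_labels)
        ]
    return log_sum_exp(alpha)

def conditional_prob(emissions, transitions, labels):
    """p(y | x) = exp(score(x, y)) / Z(x)."""
    log_z = log_partition_forward(emissions, transitions)
    return math.exp(sequence_score(emissions, transitions, labels) - log_z)

# Hypothetical scores for a 3-token sentence and 2 labels (index 0 and index 1).
emissions = [[1.0, 0.2], [0.3, 1.5], [0.8, 0.1]]   # emissions[t][label]
transitions = [[0.5, -0.2], [0.1, 0.4]]            # transitions[previous][current]

print(conditional_prob(emissions, transitions, (0, 1, 0)))
```

Maximum likelihood training adjusts the scores so that the observed label sequences receive high conditional probability; libraries such as sklearn_crfsuite handle this optimisation internally.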

CoNLL dataset 

BIO notation: 
* B indicates the beginning of an entity;
* I (inside) indicates a token inside an entity, used when an entity comprises more than one word;
* O (other) indicates a token that is not part of any entity.
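
To make the notation concrete, here is a minimal sketch that groups a BIO tag sequence into entity spans. The helper name `bio_to_spans` is hypothetical, and treating a stray I- tag (one with no matching entity open) as the start of a new entity is an assumption of this sketch, not part of the scheme described above.

```python
def bio_to_spans(tags):
    """Return (entity_type, start, end_exclusive) for each entity in a BIO tag list."""
    spans = []
    start, ent_type = None, None
    for i, tag in enumerate(tags):
        prefix, _, label = tag.partition("-")
        # An I- tag of the same type extends the entity that is currently open.
        continues = prefix == "I" and label == ent_type
        if not continues:
            if ent_type is not None:
                # Close the entity that ended just before this token.
                spans.append((ent_type, start, i))
                start, ent_type = None, None
            if prefix in ("B", "I"):
                # B- opens a new entity; a stray I- is also treated as a start (assumption).
                start, ent_type = i, label
    if ent_type is not None:
        # Close an entity that runs to the end of the sentence.
        spans.append((ent_type, start, len(tags)))
    return spans

tags = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
print(bio_to_spans(tags))
# [('ORG', 0, 1), ('MISC', 2, 3), ('MISC', 6, 7)]
```

A two-token name such as `['B-PER', 'I-PER']` comes out as a single span `('PER', 0, 2)`.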

Models trained on the Wikipedia corpus (Nothman et al., 2013) use a less fine-grained NER annotation scheme and recognise the following entities:

PER	Named person or family.
LOC	Name of politically or geographically defined location (cities, provinces, countries, international regions, bodies of water, mountains).
ORG	Named corporate, governmental, or other organizational entity.
MISC	Miscellaneous entities, e.g. events, nationalities, products or works of art.

In [73]:
! head /content/train.txt

EU	B-ORG
rejects	O
German	B-MISC
call	O
to	O
boycott	O
British	B-MISC
lamb	O
.	O



In [74]:
! pip install sklearn_crfsuite -q

In [75]:
import nltk
nltk.download('averaged_perceptron_tagger') 
from nltk.tag import pos_tag
from sklearn_crfsuite import CRF, metrics

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [76]:
"""
Load the training/testing data. 
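input: CoNLL-format data, but with only 2 tab-separated columns - words and NE tags.
output: A list where each item is 2 lists: the sentence as a list of tokens, and the NER tags as a list with one tag per token.
"""
def load_data_conll(file_path):
  myoutput, words, tags = [], [], []
  # The with-block closes the file even if parsing raises an error.
  with open(file_path) as fh:
    for line in fh:
      line = line.strip()
      if '\t' not in line:
        # Sentence ended; skip repeated blank lines so no empty sentence is added.
        if words:
          myoutput.append([words, tags])
          words, tags = [], []
      else:
        word, tag = line.split('\t')
        words.append(word)
        tags.append(tag)
  # Keep the last sentence when the file does not end with a blank line.
  if words:
    myoutput.append([words, tags])
  return myoutput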

In [77]:
"""
Get features for all words in the sentence
Features:
- word context: a window of 2 words on either side of the current word, and current word.
- POS context: a window of 2 POS tags on either side of the current word, and current tag. 
input: sentence as a list of tokens.
output: list of dictionaries. each dict represents features for that word.
"""
def sent2feats(sentence):
  feats = []
  sent_tags = pos_tag(sentence)
  # [('John', 'NNP'),("'s", 'POS'),...]
  for i in range(len(sentence)):
    word = sentence[i]
    #word features: word, prev 2 words, next 2 words in the sentence.
    #POS tag features: current tag, previous and next 2 tags.
    word_feats = {}
    word_feats['word'] = word
    word_feats['tag'] = sent_tags[i][1]
    if i==0:
      word_feats['prevWord'] = '<S>'
      word_feats['prevSecondWord'] = '<S>'
      word_feats['prevTag'] = '<S>'
      word_feats['prevSecondTag'] = '<S>'
    elif i == 1:
      word_feats['prevWord'] = sentence[i-1]
      word_feats['prevSecondWord'] = '<S>'
      word_feats['prevTag'] = sent_tags[i-1][1]
      word_feats['prevSecondTag'] = '<S>'
    else:
      word_feats['prevWord'] = sentence[i-1]
      word_feats['prevSecondWord'] = sentence[i-2]
      word_feats['prevTag'] = sent_tags[i-1][1]
      word_feats['prevSecondTag'] = sent_tags[i-2][1]
    if i == len(sentence) - 1:
      word_feats['nextWord'] = '</S>'
      word_feats['nextNextWord'] = '</S>'
      word_feats['nextTag'] = '</S>'
      word_feats['nextNextTag'] = '</S>'
    elif i == len(sentence) - 2:
      word_feats['nextWord'] = sentence[i+1]
      word_feats['nextNextWord'] = '</S>'
      word_feats['nextTag'] = sent_tags[i+1][1]
      word_feats['nextNextTag'] = '</S>'
    else:
      word_feats['nextWord'] = sentence[i+1]
      word_feats['nextNextWord'] = sentence[i+2]
      word_feats['nextTag'] = sent_tags[i+1][1]
      word_feats['nextNextTag'] = sent_tags[i+2][1]
   
    feats.append(word_feats)
  return feats 

In [78]:
#Extract features from the conll data, after loading it.
def get_features_conll(conll_data):
  feats, labels = [], []
  for sentence in conll_data:
    feats.append(sent2feats(sentence[0]))
    labels.append(sentence[1])
  return feats, labels

In [79]:
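def train_seq(feats, labels, feats_val, labels_val):
  # Train a CRF with L-BFGS; c1 and c2 are the L1 and L2 regularization coefficients.
  crf = CRF(algorithm = 'lbfgs', c1 = 0.1, c2 = 10, max_iterations = 50)
  crf.fit(feats, labels)
  # Use a separate name so the training labels argument is not overwritten.
  label_names = list(crf.classes_)
  labels_val_pred = crf.predict(feats_val)
  # Token-level weighted F1 and per-label report on the validation set.
  print(metrics.flat_f1_score(labels_val, labels_val_pred, average = 'weighted', labels = label_names))
  print(metrics.flat_classification_report(labels_val, labels_val_pred, labels = label_names, digits = 3))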
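
The evaluation in train_seq scores every label in crf.classes_, including O. Since O usually covers most tokens in NER data, a weighted F1 over all labels can look better than the entity tags alone would justify. A common variation, sketched below, drops O and sorts the remaining labels so the B- and I- tags of each entity type sit next to each other in the report. The helper name `entity_labels` and the `classes` list (a stand-in for crf.classes_) are hypothetical.

```python
def entity_labels(classes):
    """Drop 'O' and sort so B-/I- tags of the same entity type sit together."""
    kept = [label for label in classes if label != "O"]
    # Sort by entity type first (everything after the B/I prefix), then by prefix.
    return sorted(kept, key=lambda label: (label[1:], label[0]))

# Hypothetical stand-in for crf.classes_ after fitting.
classes = ["B-ORG", "O", "B-MISC", "B-PER", "I-PER", "B-LOC", "I-ORG", "I-MISC", "I-LOC"]
print(entity_labels(classes))
# ['B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']
```

The returned list could be passed as the `labels` argument of `metrics.flat_f1_score` and `metrics.flat_classification_report` in place of the full class list.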

In [80]:
def main():
  train_path = 'train.txt'
  val_path = 'val.txt'
  conll_train = load_data_conll(train_path)
  conll_val = load_data_conll(val_path)
  print(conll_train[:2])
  feats, labels = get_features_conll(conll_train)
  feats_val, labels_val = get_features_conll(conll_val)
  print(feats[:2])
  train_seq(feats, labels, feats_val, labels_val)
if __name__ == '__main__':
  main()

[[['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']], [['Peter', 'Blackburn'], ['B-PER', 'I-PER']]]
[[{'word': 'EU', 'tag': 'NNP', 'prevWord': '<S>', 'prevSecondWord': '<S>', 'prevTag': '<S>', 'prevSecondTag': '<S>', 'nextWord': 'rejects', 'nextNextWord': 'German', 'nextTag': 'VBZ', 'nextNextTag': 'JJ'}, {'word': 'rejects', 'tag': 'VBZ', 'prevWord': 'EU', 'prevSecondWord': '<S>', 'prevTag': 'NNP', 'prevSecondTag': '<S>', 'nextWord': 'German', 'nextNextWord': 'call', 'nextTag': 'JJ', 'nextNextTag': 'NN'}, {'word': 'German', 'tag': 'JJ', 'prevWord': 'rejects', 'prevSecondWord': 'EU', 'prevTag': 'VBZ', 'prevSecondTag': 'NNP', 'nextWord': 'call', 'nextNextWord': 'to', 'nextTag': 'NN', 'nextNextTag': 'TO'}, {'word': 'call', 'tag': 'NN', 'prevWord': 'German', 'prevSecondWord': 'rejects', 'prevTag': 'JJ', 'prevSecondTag': 'VBZ', 'nextWord': 'to', 'nextNextWord': 'boycott', 'nextTag': 'TO', 'nextNextTag': '