Brill-based POS tagger with around 75-80% accuracy, built on an AUBT trainer

AUBT - Affix Unigram Bigram Trigram Tagger

commit 2c7015d9501d37e8713c6e7bcf76acdd111fb40a (parent e392117)
Abhishek Gupta authored July 01, 2012
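For context, the approach benchmarked in this commit chains several NLTK taggers into a backoff sequence (the AUBT of the commit message) and then layers Brill transformation-based learning on top. A minimal sketch of the chain, condensed from the diff below (Python 2 and the pre-3.0 nltk.tag API are assumed; the helper name make_aubt is illustrative, not from the repository):

import nltk

def make_aubt(train_sents):
	# AUBT backoff chain: Affix -> Unigram -> Bigram -> Trigram;
	# each tagger defers to the one before it whenever it cannot
	# tag a token itself
	backoff = nltk.tag.AffixTagger(train_sents)
	for cls in (nltk.tag.UnigramTagger, nltk.tag.BigramTagger, nltk.tag.TrigramTagger):
		backoff = cls(train_sents, backoff=backoff)
	return backoff

The Brill trainer then takes this chained tagger as its initial tagger and learns contextual correction rules from the templates listed in POSTag.py below.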
238  POSTag.py
@@ -1,86 +1,79 @@
 import nltk
 import os,errno
 import random
-
-'''
-Important Constants
-'''
-fileNLTRdata="data/nltr"
-analyzedDataDir="analyzed_data/"
-taggers=[nltk.tag.UnigramTagger, nltk.tag.BigramTagger, nltk.tag.TrigramTagger]
-
+import params
+from nltk.tag import brill
 
 '''
 Supporting Functions
 '''
 def backoff_tagger(tagged_sents, tagger_classes, backoff=None):
-    if not backoff:
-        backoff = tagger_classes[0](tagged_sents)
-        del tagger_classes[0]
+	if not backoff:
+		backoff = tagger_classes[0](tagged_sents)
 
-    for cls in tagger_classes:
-        tagger = cls(tagged_sents, backoff=backoff)
-        backoff = tagger
+	for cls in tagger_classes:
+		tagger = cls(tagged_sents, backoff=backoff)
+		backoff = tagger
 
-    return backoff
+	return backoff
 
 def split_arr(arr):
-    try:
-        return map(lambda token: (token.split('\\')[0],token.split('\\')[1].split(".")[0]), arr)
-    except:
-        pass
+	try:
+		return map(lambda token: (token.split('\\')[0],token.split('\\')[1].split(".")[0]), arr)
+	except:
+		pass
 
 def stats(arr,dataset):
-    numLines=len(arr)
-    numWords=len(reduce(lambda x,y:x+y,arr))
-    lines=["Stats for dataset "+dataset+"\n","Number of lines: "+str(numLines)+"\n","Number of words: "+str(numWords)+"\n"]
-    fwrite=open(analyzedDataDir+dataset+".stat",'w')
-    fwrite.writelines(lines)
-    fwrite.close()
-
-
+	numLines=len(arr)
+	numWords=len(reduce(lambda x,y:x+y,arr))
+	lines=["Stats for dataset "+dataset+"\n","Number of lines: "+str(numLines)+"\n","Number of words: "+str(numWords)+"\n"]
+	fwrite=open(params.analyzedDataDir+dataset+".stat",'w')
+	fwrite.writelines(lines)
+	fwrite.close()
+
+
 def pos_frequency(arr,dataset):
-    arr=reduce(lambda x,y:x+y,arr)
-    new_arr=map(lambda a:a[1],arr)
-    lines=[]
-    for e in set(new_arr):
-        lines.append((e,new_arr.count(e)))
-    lines=sorted(lines,key=lambda a:a[1],reverse=True)
-    lines=map(lambda (a,b): str(a)+";"+str(b)+"\n", lines)
-    fwrite=open(analyzedDataDir+dataset+".freq",'w')
-    fwrite.writelines(lines)
-    fwrite.close()
+	arr=reduce(lambda x,y:x+y,arr)
+	new_arr=map(lambda a:a[1],arr)
+	lines=[]
+	for e in set(new_arr):
+		lines.append((e,new_arr.count(e)))
+	lines=sorted(lines,key=lambda a:a[1],reverse=True)
+	lines=map(lambda (a,b): str(a)+";"+str(b)+"\n", lines)
+	fwrite=open(params.analyzedDataDir+dataset+".freq",'w')
+	fwrite.writelines(lines)
+	fwrite.close()
 
 def accuracy(tagger,test_set):
-    matched=0
-    total=0
-
-    for data in test_set:
-        text=map(lambda a:a[0],data)
-        ctags=map(lambda a:a[1],data)
-        ntags=map(lambda a:a[1],tagger.tag(text))
+	matched=0
+	total=0
+
+	for data in test_set:
+		text=map(lambda a:a[0],data)
+		ctags=map(lambda a:a[1],data)
+		ntags=map(lambda a:a[1],tagger.tag(text))
 
-        #Compare ctags and ntags for evaluation
-        for i in xrange(len(ctags)):
-            if ctags[i]==ntags[i]:
-                matched=matched+1
+		#Compare ctags and ntags for evaluation
+		for i in xrange(len(ctags)):
+			if ctags[i]==ntags[i]:
+				matched=matched+1
 
-        total=total+len(ctags)
+		total=total+len(ctags)
 
-    return float(matched)/total
+	return float(matched)/total
 
 #Filter the sentences from bangla.pos as provided by NLTK (prepared by IIT Kharagpur)
 tagged_sents=nltk.corpus.indian.tagged_sents(fileids="bangla.pos")
 filtered_sents1=[]
 
 for i,sent in enumerate(tagged_sents):
-    try:
-        tagger=nltk.tag.UnigramTagger([sent])
-        filtered_sents1.append(sent)
-    except ValueError:
-        pass
+	try:
+		tagger=nltk.tag.UnigramTagger([sent])
+		filtered_sents1.append(sent)
+	except ValueError:
+		pass
 
-fread=open(fileNLTRdata)
+fread=open(params.fileNLTRdata)
 lines=fread.readlines()
 fread.close()
 
@@ -89,22 +82,22 @@ def accuracy(tagger,test_set):
 lines=map(lambda line:split_arr(line), lines)
 
 for i,line in enumerate(lines):
-    try:
-        tagger=nltk.tag.UnigramTagger([line])
-        filtered_sents2.append(line)
-    except (TypeError, ValueError):
-        pass
+	try:
+		tagger=nltk.tag.UnigramTagger([line])
+		filtered_sents2.append(line)
+	except (TypeError, ValueError):
+		pass
 
 '''
 Generate some statistical data
 '''
 try:
-    os.makedirs(analyzedDataDir)
+	os.makedirs(params.analyzedDataDir)
 except OSError as exc:
-    if exc.errno == errno.EEXIST:
-        pass
-    else:
-        raise
+	if exc.errno == errno.EEXIST:
+		pass
+	else:
+		raise
 
 stats(filtered_sents1,"iitk")
 stats(filtered_sents2,"nltr")
@@ -116,19 +109,100 @@ def accuracy(tagger,test_set):
 total_set=filtered_sents1
 #+filtered_sents2 - Don't include this data because of the non-standard POS tags used while tagging
 
-random.shuffle(total_set)
-
-len_set=len(total_set)
-train_length=int(0.8*len_set)
-
-#Prepare a training and a test set
-training_set=total_set[:train_length]
-test_set=total_set[train_length:]
-
-#RUN DIFFERENT CLASSIFIERS AND REPORT THEIR PERFORMANCES ON THE TEST DATA SET
-#CHOOSE THE BEST CLASSIFIER USING A BACKOFF TAGGER (Can we do this automatically?)
-#HOW TO USE A CONDITIONAL RANDOM FIELDS BASED TAGGER TO ACCOMPLISH THIS AND COMPARE RESULTS
-
-tagger = backoff_tagger(training_set, [nltk.tag.UnigramTagger, nltk.tag.BigramTagger, nltk.tag.TrigramTagger])
-
-print accuracy(tagger,test_set)
+scores=[]
+avg_scores=[]
+atagger=[nltk.tag.AffixTagger]
+utagger=[nltk.tag.UnigramTagger]
+btagger=[nltk.tag.BigramTagger]
+ttagger=[nltk.tag.TrigramTagger]
+ub_tagger=utagger+btagger
+ut_tagger=utagger+ttagger
+ubt_tagger=ub_tagger+ttagger
+aubt_tagger=atagger+ubt_tagger
+
+taggers=[utagger, ub_tagger, ut_tagger, ubt_tagger, atagger, aubt_tagger]
+tagger_names=["Unigram Tagger", "Unigram-Bigram Tagger", "Unigram-Trigram Tagger", "Unigram-Bigram-Trigram Tagger", "Affix-based Tagger", "Affix-Unigram-Bigram-Trigram Tagger"]
+
+brill_templates = [
+	brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,1)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (2,2)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,2)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,3)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateWordsRule, (1,1)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateWordsRule, (2,2)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateWordsRule, (1,2)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateWordsRule, (1,3)),
+	brill.ProximateTokensTemplate(brill.ProximateTagsRule, (-1, -1), (1,1)),
+	brill.ProximateTokensTemplate(brill.ProximateWordsRule, (-1, -1), (1,1))
+]
+
+fwrite=open(params.analyzedDataDir+"accuracy.txt",'w')
+
+for j in xrange(len(taggers)):
+	avg_scores.append(0)
+	scores.append([])
+	#tag_classes=copy.copy(taggers[j])
+	tag_classes=taggers[j]
+
+	for i in xrange(params.numTrials):
+		random.shuffle(total_set)
+		len_set=len(total_set)
+		train_length=int(0.8*len_set)
+
+		#Prepare a training and a test set
+		training_set=total_set[:train_length]
+		test_set=total_set[train_length:]
+
+		tagger=backoff_tagger(training_set,tag_classes)
+
+		scores[j].append(accuracy(tagger,test_set))
+		avg_scores[j]=avg_scores[j]+scores[j][i]
+
+	avg_scores[j]=float(avg_scores[j])/params.numTrials
+
+	lines=['Tagger:\t'+tagger_names[j]+"\n"]
+
+	line=""
+	for i in xrange(params.numTrials):
+		line=line+str(scores[j][i])+"\t"
+	line=line+"\n"
+
+	lines.append(line)
+
+	lines.append("Accuracy Score:\t"+str(avg_scores[j])+"\n")
+	lines.append("\n")
+
+	fwrite.writelines(lines)
+
+scores=[]
+avg_score=0
+for i in xrange(params.numTrials):
+	random.shuffle(total_set)
+	len_set=len(total_set)
+	train_length=int(0.8*len_set)
+
+	#Prepare a training and a test set
+	training_set=total_set[:train_length]
+	test_set=total_set[train_length:]
+
+	btrainer = nltk.tag.brill.FastBrillTaggerTrainer(backoff_tagger(training_set,aubt_tagger), brill_templates)
+	tagger = btrainer.train(training_set, max_rules=300, min_score=3)
+
+	scores.append(accuracy(tagger,test_set))
+	avg_score=avg_score+scores[i]
+
+avg_score=float(avg_score)/params.numTrials
+lines=['Tagger: Brill-based Tagger with AUBT as the initial tagger\n']
+
+line=""
+for i in xrange(params.numTrials):
+	line=line+str(scores[i])+"\t"
+line=line+"\n"
+
+lines.append(line)
+
+lines.append("Accuracy Score:\t"+str(avg_score)+"\n")
+lines.append("\n")
+fwrite.writelines(lines)
+
+fwrite.close()
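A note on the brill_templates block above, since the old API is terse: each template defines a family of candidate correction rules, not a single rule. SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,1)), for instance, yields rules of the form "change tag X to tag Y when the neighbouring token (on either side) carries tag Z", and the (1,3) variants condition on a token anywhere within three positions. During training, FastBrillTaggerTrainer keeps at most max_rules=300 such rules and drops any rule whose net error reduction on the training set falls below min_score=3.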
23  README.md
@@ -1,4 +1,19 @@
-bangla-pos-tagger
-=================
-
-POS Tagger for Bangla language based on Conditional Random Fields
+bangla-pos-tagger
+=================
+
+POS Tagger for Bangla language based on Conditional Random Fields
+
+Observations
+============
+* The Unigram-based tagger gives approximately 60-65% accuracy.
+* Adding Bigram- and Trigram-based taggers on top of it increases the accuracy to some extent.
+* Adding an affix-based tagger improves the accuracy a bit.
+
+Note: in the "accuracy.txt" file in the analyzed_data directory, only the relevant results have been added, i.e. the runs that were giving really good accuracies. The analysis is similar to that of the blog posts below.
+
+Relevant Blog Posts
+-------------------
+1. http://streamhacker.com/2008/12/29/how-to-train-a-nltk-chunker/
+2. http://streamhacker.com/2008/11/10/part-of-speech-tagging-with-nltk-part-2/
+3. http://streamhacker.com/2008/12/03/part-of-speech-tagging-with-nltk-part-3/
+4. http://streamhacker.com/2010/04/12/pos-tag-nltk-brill-classifier/
28  analyzed_data/accuracy.txt
@@ -0,0 +1,28 @@
+Tagger:	Unigram Tagger
+0.637562528422	0.635555555556	0.660314435445	0.63741223671	0.669927909372	0.669254658385	0.649638554217	0.638384754991	0.650075414781	0.656467315716
+Accuracy Score:	0.65045933636
+
+Tagger:	Unigram-Bigram Tagger
+0.62810707457	0.646478873239	0.641527446301	0.652214891612	0.646290247317	0.659907834101	0.654579265458	0.674937965261	0.658772353794	0.65340626554
+Accuracy Score:	0.651622221719
+
+Tagger:	Unigram-Trigram Tagger
+0.640189125296	0.667027027027	0.658502119642	0.669454008854	0.640883977901	0.630352045671	0.649098474341	0.645992366412	0.658221559861	0.640272679601
+Accuracy Score:	0.649999338461
+
+Tagger:	Unigram-Bigram-Trigram Tagger
+0.635526921307	0.672032193159	0.65507518797	0.658308751229	0.66065747614	0.649265078561	0.657560056524	0.649883449883	0.651266766021	0.653241650295
+Accuracy Score:	0.654281753109
+
+Tagger:	Affix-based Tagger
+0.338250790306	0.343848580442	0.340155945419	0.331627430911	0.329274479959	0.332566697332	0.350168350168	0.347126436782	0.341346153846	0.334858188472
+Accuracy Score:	0.338922305364
+
+Tagger:	Affix-Unigram-Bigram-Trigram Tagger
+0.777345649003	0.768137254902	0.777117384844	0.761423761424	0.775687409551	0.76867816092	0.758603238866	0.764930721452	0.764677804296	0.769818529131
+Accuracy Score:	0.768641991439
+
+Tagger: Brill-based Tagger with AUBT as the initial tagger
+0.782021425245	0.776097560976	0.76311844078	0.762596899225	0.767847971238	0.761102603369	0.784935579782	0.746578574799	0.757050964869	0.7818275154
+Accuracy Score:	0.768317753568
+
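Each "Accuracy Score" above is the arithmetic mean of the ten trial scores preceding it, one per shuffled 80/20 split. As a spot check, the Affix-Unigram-Bigram-Trigram trials sum to 7.686419914389, and 7.686419914389 / 10 = 0.768641991439, which matches the reported score. On these ten trials the Brill tagger's mean (0.7683) is essentially the same as the plain AUBT mean (0.7686), consistent with the hedged "around 75-80%" in the commit message.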
47  bangla_pos_tagger.py
@@ -0,0 +1,47 @@
+import nltk
+from nltk.tag import brill
+
+'''
+Supporting Functions
+'''
+def backoff_tagger(tagged_sents, tagger_classes, backoff=None):
+	if not backoff:
+		backoff = tagger_classes[0](tagged_sents)
+
+	for cls in tagger_classes:
+		tagger = cls(tagged_sents, backoff=backoff)
+		backoff = tagger
+
+	return backoff
+
+
+tagged_sents=nltk.corpus.indian.tagged_sents(fileids="bangla.pos")
+filtered_sents1=[]
+for i,sent in enumerate(tagged_sents):
+	try:
+		tagger=nltk.tag.UnigramTagger([sent])
+		filtered_sents1.append(sent)
+	except ValueError:
+		pass
+
+total_set=filtered_sents1
+
+aubt_tagger=[nltk.tag.AffixTagger,nltk.tag.UnigramTagger,nltk.tag.BigramTagger,nltk.tag.TrigramTagger]
+brill_templates = [
+	brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,1)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (2,2)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,2)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateTagsRule, (1,3)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateWordsRule, (1,1)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateWordsRule, (2,2)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateWordsRule, (1,2)),
+	brill.SymmetricProximateTokensTemplate(brill.ProximateWordsRule, (1,3)),
+	brill.ProximateTokensTemplate(brill.ProximateTagsRule, (-1, -1), (1,1)),
+	brill.ProximateTokensTemplate(brill.ProximateWordsRule, (-1, -1), (1,1))
+]
+
+btrainer = nltk.tag.brill.FastBrillTaggerTrainer(backoff_tagger(total_set,aubt_tagger), brill_templates)
+tagger = btrainer.train(total_set, max_rules=300, min_score=3)
+
+def pos_tag(query):
+	return tagger.tag(query)
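A usage note on the module above: NLTK's tag() method takes a sequence of tokens, so pos_tag expects a pre-tokenized sentence rather than a raw string (a raw string would be tagged character by character). A hypothetical call (Python 2; the sample tokens are placeholders, not from the repository):

import bangla_pos_tagger

# hypothetical pre-tokenized Bangla sentence
tokens = [u'আমি', u'ভাত', u'খাই']
print bangla_pos_tagger.pos_tag(tokens)   # list of (token, tag) pairs

Note that the module trains the tagger at import time, so the first import is slow and requires the NLTK indian corpus to be available.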
10  params.py
@@ -1,2 +1,10 @@
+'''
+Config
+'''
+
+fileNLTRdata="data/nltr"
+analyzedDataDir="analyzed_data/"
+numTrials=10
+
 iitk_dataset_original="data_sets/iitk/bangla.pos"
-iitk_dataset="data_sets/iitk/bangla.cleaned.pos"
+iitk_dataset="data_sets/iitk/bangla.cleaned.pos"
15  setup.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python
+
+from distutils.core import setup
+
+setup(name='Bangla POS Tagger',
+      version='1.0',
+      description='Bangla Based POS Tagger',
+      author='Abhishek Gupta',
+      author_email='abhishekgupta.iitd@gmail.com',
+      url='http://github.com/abhishekgupta92/bangla_pos_tagger',
+      py_modules=['bangla_pos_tagger'],
+      package_dir={'data':'data'},
+      package_data={'data': ['*']},
+      data_files=[('config',['params.py'])],
+     )
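The distutils script above supports the standard invocations, so a local install is presumably just:

python setup.py install

The data_files entry additionally copies params.py into a config directory under the install prefix, so the path settings can be adjusted after installation.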
