From 78d60e37c09503c35d04b23d205b50a714cf9dd1 Mon Sep 17 00:00:00 2001 From: Laimonas Simutis Date: Tue, 23 Dec 2014 16:10:25 -0500 Subject: [PATCH 1/3] classification tests and fixes --- .../KNearestNeighborClassifier.cs | 3 +- .../Lucene.Net.Classification.sln | 24 ++ .../SimpleNaiveBayesClassifier.cs | 2 +- .../ClassificationTestBase.cs | 326 ++++++++++++++++++ .../KNearestNeighborClassifierTest.cs | 52 +++ .../Lucene.Net.Tests.Classification.csproj | 75 ++++ .../Properties/AssemblyInfo.cs | 36 ++ .../SimpleNaiveBayesClassifierTest.cs | 69 ++++ .../app.config | 11 + .../packages.config | 4 + 10 files changed, 599 insertions(+), 3 deletions(-) create mode 100644 src/Lucene.Net.Tests.Classification/ClassificationTestBase.cs create mode 100644 src/Lucene.Net.Tests.Classification/KNearestNeighborClassifierTest.cs create mode 100644 src/Lucene.Net.Tests.Classification/Lucene.Net.Tests.Classification.csproj create mode 100644 src/Lucene.Net.Tests.Classification/Properties/AssemblyInfo.cs create mode 100644 src/Lucene.Net.Tests.Classification/SimpleNaiveBayesClassifierTest.cs create mode 100644 src/Lucene.Net.Tests.Classification/app.config create mode 100644 src/Lucene.Net.Tests.Classification/packages.config diff --git a/src/Lucene.Net.Classification/KNearestNeighborClassifier.cs b/src/Lucene.Net.Classification/KNearestNeighborClassifier.cs index e40025465d..480799aaa0 100644 --- a/src/Lucene.Net.Classification/KNearestNeighborClassifier.cs +++ b/src/Lucene.Net.Classification/KNearestNeighborClassifier.cs @@ -93,10 +93,9 @@ private ClassificationResult SelectClassFromNeighbors(TopDocs topDocs) foreach (ScoreDoc scoreDoc in topDocs.ScoreDocs) { BytesRef cl = new BytesRef(_indexSearcher.Doc(scoreDoc.Doc).GetField(_classFieldName).StringValue); - int count = classCounts[cl]; if (classCounts.ContainsKey(cl)) { - classCounts[cl] = count + 1; + classCounts[cl] = classCounts[cl] + 1; } else { diff --git a/src/Lucene.Net.Classification/Lucene.Net.Classification.sln b/src/Lucene.Net.Classification/Lucene.Net.Classification.sln index 99650495c6..8f6ef9dc1d 100644 --- a/src/Lucene.Net.Classification/Lucene.Net.Classification.sln +++ b/src/Lucene.Net.Classification/Lucene.Net.Classification.sln @@ -7,6 +7,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Lucene.Net", "..\Lucene.Net EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Lucene.Net.Queries", "..\Lucene.Net.Queries\Lucene.Net.Queries.csproj", "{69D7956C-C2CC-4708-B399-A188FEC384C4}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Lucene.Net.Tests.Classification", "..\Lucene.Net.Tests.Classification\Lucene.Net.Tests.Classification.csproj", "{4D77E491-F50F-4A0C-9BD9-F9AB655720AD}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Lucene.Net.Tests", "..\Lucene.Net.Tests\Lucene.Net.Tests.csproj", "{DE63DB10-975F-460D-AF85-572C17A91284}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -47,6 +51,26 @@ Global {69D7956C-C2CC-4708-B399-A188FEC384C4}.Release|Mixed Platforms.ActiveCfg = Release|Any CPU {69D7956C-C2CC-4708-B399-A188FEC384C4}.Release|Mixed Platforms.Build.0 = Release|Any CPU {69D7956C-C2CC-4708-B399-A188FEC384C4}.Release|x86.ActiveCfg = Release|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Debug|Mixed Platforms.ActiveCfg = Debug|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Debug|Mixed Platforms.Build.0 = Debug|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Debug|x86.ActiveCfg = Debug|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Release|Any CPU.Build.0 = Release|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Release|Mixed Platforms.ActiveCfg = Release|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Release|Mixed Platforms.Build.0 = Release|Any CPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD}.Release|x86.ActiveCfg = Release|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Debug|Mixed Platforms.ActiveCfg = Debug|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Debug|Mixed Platforms.Build.0 = Debug|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Debug|x86.ActiveCfg = Debug|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Release|Any CPU.Build.0 = Release|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Release|Mixed Platforms.ActiveCfg = Release|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Release|Mixed Platforms.Build.0 = Release|Any CPU + {DE63DB10-975F-460D-AF85-572C17A91284}.Release|x86.ActiveCfg = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/Lucene.Net.Classification/SimpleNaiveBayesClassifier.cs b/src/Lucene.Net.Classification/SimpleNaiveBayesClassifier.cs index a045c80f8c..799a8c4b59 100644 --- a/src/Lucene.Net.Classification/SimpleNaiveBayesClassifier.cs +++ b/src/Lucene.Net.Classification/SimpleNaiveBayesClassifier.cs @@ -91,7 +91,7 @@ private String[] TokenizeDoc(String doc) TokenStream tokenStream = _analyzer.TokenStream(textFieldName, new StringReader(doc)); try { - CharTermAttribute charTermAttribute = tokenStream.AddAttribute(); + ICharTermAttribute charTermAttribute = tokenStream.AddAttribute(); tokenStream.Reset(); while (tokenStream.IncrementToken()) { diff --git a/src/Lucene.Net.Tests.Classification/ClassificationTestBase.cs b/src/Lucene.Net.Tests.Classification/ClassificationTestBase.cs new file mode 100644 index 0000000000..269d8c4c66 --- /dev/null +++ b/src/Lucene.Net.Tests.Classification/ClassificationTestBase.cs @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Diagnostics; +using System.Text; +using Lucene.Net.Analysis; +using Lucene.Net.Documents; +using Lucene.Net.Index; +using Lucene.Net.Randomized.Generators; +using Lucene.Net.Search; +using Lucene.Net.Store; +using Lucene.Net.Util; +using NUnit.Framework; + +namespace Lucene.Net.Classification +{ + + /** + * Base class for testing {@link Classifier}s + */ + public abstract class ClassificationTestBase : Util.LuceneTestCase + { + public readonly static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " + + "If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section."; + public static readonly BytesRef POLITICS_RESULT = new BytesRef("politics"); + + public static readonly String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." + + " Truth is, Amazon may know more."; + public static readonly BytesRef TECHNOLOGY_RESULT = new BytesRef("technology"); + + private RandomIndexWriter indexWriter; + private Directory dir; + private FieldType ft; + + protected String textFieldName; + protected String categoryFieldName; + + String booleanFieldName; + + [SetUp] + public override void SetUp() + { + base.SetUp(); + dir = NewDirectory(); + indexWriter = new RandomIndexWriter(Random(), dir); + textFieldName = "text"; + categoryFieldName = "cat"; + booleanFieldName = "bool"; + ft = new FieldType(TextField.TYPE_STORED); + ft.StoreTermVectors = true; + ft.StoreTermVectorOffsets = true; + ft.StoreTermVectorPositions = true; + } + + [TearDown] + public void tearDown() + { + base.TearDown(); + indexWriter.Dispose(); + dir.Dispose(); + } + + protected void CheckCorrectClassification(IClassifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) + { + CheckCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null); + } + + protected void CheckCorrectClassification(IClassifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) + { + AtomicReader atomicReader = null; + try + { + PopulateSampleIndex(analyzer); + atomicReader = SlowCompositeReaderWrapper.Wrap(indexWriter.Reader); + classifier.Train(atomicReader, textFieldName, classFieldName, analyzer, query); + ClassificationResult classificationResult = classifier.AssignClass(inputDoc); + NotNull(classificationResult.AssignedClass); + AreEqual(expectedResult, classificationResult.AssignedClass, "got an assigned class of " + classificationResult.AssignedClass); + IsTrue(classificationResult.Score > 0, "got a not positive score " + classificationResult.Score); + } + finally + { + if (atomicReader != null) + atomicReader.Dispose(); + } + } + protected void CheckOnlineClassification(IClassifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) + { + CheckOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null); + } + + protected void CheckOnlineClassification(IClassifier classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) + { + AtomicReader atomicReader = null; + try + { + PopulateSampleIndex(analyzer); + atomicReader = SlowCompositeReaderWrapper.Wrap(indexWriter.Reader); + classifier.Train(atomicReader, textFieldName, classFieldName, analyzer, query); + ClassificationResult classificationResult = classifier.AssignClass(inputDoc); + NotNull(classificationResult.AssignedClass); + AreEqual(expectedResult, classificationResult.AssignedClass, "got an assigned class of " + classificationResult.AssignedClass); + IsTrue(classificationResult.Score > 0, "got a not positive score " + classificationResult.Score); + UpdateSampleIndex(analyzer); + ClassificationResult secondClassificationResult = classifier.AssignClass(inputDoc); + Equals(classificationResult.AssignedClass, secondClassificationResult.AssignedClass); + Equals(classificationResult.Score, secondClassificationResult.Score); + + } + finally + { + if (atomicReader != null) + atomicReader.Dispose(); + } + } + + private void PopulateSampleIndex(Analyzer analyzer) + { + indexWriter.DeleteAll(); + indexWriter.Commit(); + + String text; + + Document doc = new Document(); + text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " + + "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " + + "the Unknown Soldier in Warsaw Tuesday."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "politics", ft)); + doc.Add(new Field(booleanFieldName, "true", ft)); + + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" + + " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "politics", ft)); + doc.Add(new Field(booleanFieldName, "true", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "And there's a threshold question that he has to answer for the American people and " + + "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " + + "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\""; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "politics", ft)); + doc.Add(new Field(booleanFieldName, "true", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " + + "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " + + "Albany's School of Criminal Justice."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "politics", ft)); + doc.Add(new Field(booleanFieldName, "true", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " + + "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " + + "world through the Internet."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "technology", ft)); + doc.Add(new Field(booleanFieldName, "false", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "So, about all those experts and analysts who've spent the past year or so saying " + + "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "technology", ft)); + doc.Add(new Field(booleanFieldName, "false", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "More than 400 million people trust Google with their e-mail, and 50 million store files" + + " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " + + "generally transfer or store huge volumes of personal data online."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "technology", ft)); + doc.Add(new Field(booleanFieldName, "false", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "unlabeled doc"; + doc.Add(new Field(textFieldName, text, ft)); + indexWriter.AddDocument(doc, analyzer); + + indexWriter.Commit(); + } + + protected void CheckPerformance(IClassifier classifier, Analyzer analyzer, String classFieldName) + { + AtomicReader atomicReader = null; + var stopwatch = new Stopwatch(); + stopwatch.Start(); + try + { + PopulatePerformanceIndex(analyzer); + atomicReader = SlowCompositeReaderWrapper.Wrap(indexWriter.Reader); + classifier.Train(atomicReader, textFieldName, classFieldName, analyzer); + stopwatch.Stop(); + long trainTime = stopwatch.ElapsedMilliseconds; + IsTrue(trainTime < 120000, "training took more than 2 mins : " + trainTime / 1000 + "s"); + } + finally + { + if (atomicReader != null) + atomicReader.Dispose(); + } + } + + private void PopulatePerformanceIndex(Analyzer analyzer) + { + indexWriter.DeleteAll(); + indexWriter.Commit(); + + FieldType ft = new FieldType(TextField.TYPE_STORED); + ft.StoreTermVectors = true; + ft.StoreTermVectorOffsets = true; + ft.StoreTermVectorPositions = true; + int docs = 1000; + Random random = new Random(); + for (int i = 0; i < docs; i++) + { + Boolean b = random.NextBoolean(); + Document doc = new Document(); + doc.Add(new Field(textFieldName, createRandomString(random), ft)); + doc.Add(new Field(categoryFieldName, b ? "technology" : "politics", ft)); + doc.Add(new Field(booleanFieldName, b.ToString(), ft)); + indexWriter.AddDocument(doc, analyzer); + } + indexWriter.Commit(); + } + + private String createRandomString(Random random) + { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < 20; i++) + { + builder.Append(TestUtil.RandomSimpleString(random, 5)); + builder.Append(" "); + } + return builder.ToString(); + } + + private void UpdateSampleIndex(Analyzer analyzer) + { + String text; + + Document doc = new Document(); + text = "Warren Bennis says John F. Kennedy grasped a key lesson about the presidency that few have followed."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "politics", ft)); + doc.Add(new Field(booleanFieldName, "true", ft)); + + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "Julian Zelizer says Bill Clinton is still trying to shape his party, years after the White House, while George W. Bush opts for a much more passive role."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "politics", ft)); + doc.Add(new Field(booleanFieldName, "true", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "Crossfire: Sen. Tim Scott passes on Sen. Lindsey Graham endorsement"; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "politics", ft)); + doc.Add(new Field(booleanFieldName, "true", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "Illinois becomes 16th state to allow same-sex marriage."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "politics", ft)); + doc.Add(new Field(booleanFieldName, "true", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "Apple is developing iPhones with curved-glass screens and enhanced sensors that detect different levels of pressure, according to a new report."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "technology", ft)); + doc.Add(new Field(booleanFieldName, "false", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "The Xbox One is Microsoft's first new gaming console in eight years. It's a quality piece of hardware but it's also noteworthy because Microsoft is using it to make a statement."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "technology", ft)); + doc.Add(new Field(booleanFieldName, "false", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "Google says it will replace a Google Maps image after a California father complained it shows the body of his teen-age son, who was shot to death in 2009."; + doc.Add(new Field(textFieldName, text, ft)); + doc.Add(new Field(categoryFieldName, "technology", ft)); + doc.Add(new Field(booleanFieldName, "false", ft)); + indexWriter.AddDocument(doc, analyzer); + + doc = new Document(); + text = "second unlabeled doc"; + doc.Add(new Field(textFieldName, text, ft)); + indexWriter.AddDocument(doc, analyzer); + + indexWriter.Commit(); + } + } +} \ No newline at end of file diff --git a/src/Lucene.Net.Tests.Classification/KNearestNeighborClassifierTest.cs b/src/Lucene.Net.Tests.Classification/KNearestNeighborClassifierTest.cs new file mode 100644 index 0000000000..0a69365587 --- /dev/null +++ b/src/Lucene.Net.Tests.Classification/KNearestNeighborClassifierTest.cs @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using Lucene.Net.Analysis; +using Lucene.Net.Index; +using Lucene.Net.Search; +using Lucene.Net.Util; +using NUnit.Framework; + +namespace Lucene.Net.Classification +{ + /** + * Testcase for {@link KNearestNeighborClassifier} + */ + public class KNearestNeighborClassifierTest : ClassificationTestBase + { + [Test] + public void TestBasicUsage() + { + // usage with default MLT min docs / term freq + CheckCorrectClassification(new KNearestNeighborClassifier(3), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(Random()), textFieldName, categoryFieldName); + // usage without custom min docs / term freq for MLT + CheckCorrectClassification(new KNearestNeighborClassifier(3, 2, 1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(Random()), textFieldName, categoryFieldName); + } + + [Test] + public void TestBasicUsageWithQuery() + { + CheckCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(Random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it"))); + } + + [Test] + public void TestPerformance() + { + CheckPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(Random()), categoryFieldName); + } + } +} \ No newline at end of file diff --git a/src/Lucene.Net.Tests.Classification/Lucene.Net.Tests.Classification.csproj b/src/Lucene.Net.Tests.Classification/Lucene.Net.Tests.Classification.csproj new file mode 100644 index 0000000000..164c8f50cc --- /dev/null +++ b/src/Lucene.Net.Tests.Classification/Lucene.Net.Tests.Classification.csproj @@ -0,0 +1,75 @@ + + + + + Debug + AnyCPU + {4D77E491-F50F-4A0C-9BD9-F9AB655720AD} + Library + Properties + Lucene.Net.Tests.Classification + Lucene.Net.Tests.Classification + v4.5.1 + 512 + + + true + full + false + bin\Debug\ + DEBUG;TRACE + prompt + 4 + + + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + + + + ..\Lucene.Net.Tests\bin\Debug\Lucene.Net.TestFramework.dll + + + False + ..\Lucene.Net.Classification\packages\NUnit.2.6.4\lib\nunit.framework.dll + + + + + + + + + + + + + + + + + + {E067B8BB-D8E7-4040-BEB8-EFF8BB4149BD} + Lucene.Net.Classification + + + {5D4AD9BE-1FFB-41AB-9943-25737971BF57} + Lucene.Net + + + + + + + + \ No newline at end of file diff --git a/src/Lucene.Net.Tests.Classification/Properties/AssemblyInfo.cs b/src/Lucene.Net.Tests.Classification/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..e1a47e214f --- /dev/null +++ b/src/Lucene.Net.Tests.Classification/Properties/AssemblyInfo.cs @@ -0,0 +1,36 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Lucene.Net.Tests.Classification")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Lucene.Net.Tests.Classification")] +[assembly: AssemblyCopyright("Copyright © 2014")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("253246a8-7b09-4251-ab4c-7971d3b2be4a")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("1.0.*")] +[assembly: AssemblyVersion("1.0.0.0")] +[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/src/Lucene.Net.Tests.Classification/SimpleNaiveBayesClassifierTest.cs b/src/Lucene.Net.Tests.Classification/SimpleNaiveBayesClassifierTest.cs new file mode 100644 index 0000000000..7f6d3a0acc --- /dev/null +++ b/src/Lucene.Net.Tests.Classification/SimpleNaiveBayesClassifierTest.cs @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using Lucene.Net.Analysis; +using Lucene.Net.Index; +using Lucene.Net.Search; +using Lucene.Net.Util; +using NUnit.Framework; + +namespace Lucene.Net.Classification +{ + /** + * Testcase for {@link SimpleNaiveBayesClassifier} + */ + // TODO : eventually remove this if / when fallback methods exist for all un-supportable codec methods (see LUCENE-4872) + // [Util.LuceneTestCase.SuppressCodecs("Lucene3x")] TODO : seems like we lost ability to pass in params in SupressCodecs constructor + [SuppressCodecs] + public class SimpleNaiveBayesClassifierTest : ClassificationTestBase + { + [Test] + public void TestBasicUsage() + { + CheckCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(Random()), textFieldName, categoryFieldName); + CheckCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(Random()), textFieldName, categoryFieldName); + } + + [Test] + public void TestBasicUsageWithQuery() + { + CheckCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(Random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it"))); + } + + [Test] + [Ignore("Need to figure out what to do with NGramAnalyzer, issues with things in Analysis.Common project")] + public void TestNGramUsage() + { + //CheckCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName); + } + + //private class NGramAnalyzer : Analyzer + //{ + // public override TokenStreamComponents CreateComponents(String fieldName, TextReader reader) + // { + // Tokenizer tokenizer = new KeywordTokenizer(reader); + // return new TokenStreamComponents(tokenizer, new ReverseStringFilter(TEST_VERSION_CURRENT, new EdgeNGramTokenFilter(TEST_VERSION_CURRENT, new ReverseStringFilter(TEST_VERSION_CURRENT, tokenizer), 10, 20))); + // } + //} + + [Test] + public void TestPerformance() + { + CheckPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(Random()), categoryFieldName); + } + } +} \ No newline at end of file diff --git a/src/Lucene.Net.Tests.Classification/app.config b/src/Lucene.Net.Tests.Classification/app.config new file mode 100644 index 0000000000..6feedb05ff --- /dev/null +++ b/src/Lucene.Net.Tests.Classification/app.config @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/src/Lucene.Net.Tests.Classification/packages.config b/src/Lucene.Net.Tests.Classification/packages.config new file mode 100644 index 0000000000..b25e5bda35 --- /dev/null +++ b/src/Lucene.Net.Tests.Classification/packages.config @@ -0,0 +1,4 @@ + + + + \ No newline at end of file From 3134b63c48366fd6f56b87dc984f0f97a33eb7fb Mon Sep 17 00:00:00 2001 From: Laimonas Simutis Date: Tue, 23 Dec 2014 21:16:07 -0500 Subject: [PATCH 2/3] add Util\DataSplitter and corresponding tests --- .../Lucene.Net.Classification.csproj | 1 + .../Utils/DatasetSplitter.cs | 150 ++++++++++++++++++ .../Lucene.Net.Tests.Classification.csproj | 2 + .../Utils/DataSplitterTest.cs | 145 +++++++++++++++++ 4 files changed, 298 insertions(+) create mode 100644 src/Lucene.Net.Classification/Utils/DatasetSplitter.cs create mode 100644 src/Lucene.Net.Tests.Classification/Utils/DataSplitterTest.cs diff --git a/src/Lucene.Net.Classification/Lucene.Net.Classification.csproj b/src/Lucene.Net.Classification/Lucene.Net.Classification.csproj index 8d31ed5c41..cbefa2cd04 100644 --- a/src/Lucene.Net.Classification/Lucene.Net.Classification.csproj +++ b/src/Lucene.Net.Classification/Lucene.Net.Classification.csproj @@ -44,6 +44,7 @@ + diff --git a/src/Lucene.Net.Classification/Utils/DatasetSplitter.cs b/src/Lucene.Net.Classification/Utils/DatasetSplitter.cs new file mode 100644 index 0000000000..e5c64e900e --- /dev/null +++ b/src/Lucene.Net.Classification/Utils/DatasetSplitter.cs @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.IO; +using Lucene.Net.Analysis; +using Lucene.Net.Documents; +using Lucene.Net.Index; +using Lucene.Net.Search; +using Directory = Lucene.Net.Store.Directory; + +namespace Lucene.Net.Classification.Utils +{ + /** + * Utility class for creating training / test / cross validation indexes from the original index. + */ + public class DatasetSplitter + { + + private readonly double _crossValidationRatio; + private readonly double _testRatio; + + /** + * Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes + * + * @param testRatio the ratio of the original index to be used for the test IDX as a double between 0.0 and 1.0 + * @param crossValidationRatio the ratio of the original index to be used for the c.v. IDX as a double between 0.0 and 1.0 + */ + public DatasetSplitter(double testRatio, double crossValidationRatio) + { + this._crossValidationRatio = crossValidationRatio; + this._testRatio = testRatio; + } + + /** + * Split a given index into 3 indexes for training, test and cross validation tasks respectively + * + * @param originalIndex an {@link AtomicReader} on the source index + * @param trainingIndex a {@link Directory} used to write the training index + * @param testIndex a {@link Directory} used to write the test index + * @param crossValidationIndex a {@link Directory} used to write the cross validation index + * @param analyzer {@link Analyzer} used to create the new docs + * @param fieldNames names of fields that need to be put in the new indexes or null if all should be used + * @throws IOException if any writing operation fails on any of the indexes + */ + public void Split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, Analyzer analyzer, params string[] fieldNames) + { + // create IWs for train / test / cv IDXs + IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(Util.Version.LUCENE_CURRENT, analyzer)); + IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(Util.Version.LUCENE_CURRENT, analyzer)); + IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(Util.Version.LUCENE_CURRENT, analyzer)); + + try + { + int size = originalIndex.MaxDoc; + + IndexSearcher indexSearcher = new IndexSearcher(originalIndex); + TopDocs topDocs = indexSearcher.Search(new MatchAllDocsQuery(), Int32.MaxValue); + + // set the type to be indexed, stored, with term vectors + FieldType ft = new FieldType(TextField.TYPE_STORED); + ft.StoreTermVectors = true; + ft.StoreTermVectorOffsets = true; + ft.StoreTermVectorPositions = true; + + int b = 0; + + // iterate over existing documents + foreach (ScoreDoc scoreDoc in topDocs.ScoreDocs) + { + // create a new document for indexing + Document doc = new Document(); + if (fieldNames != null && fieldNames.Length > 0) + { + foreach (String fieldName in fieldNames) + { + doc.Add(new Field(fieldName, originalIndex.Document(scoreDoc.Doc).GetField(fieldName).ToString(), ft)); + } + } + else + { + foreach (IndexableField storableField in originalIndex.Document(scoreDoc.Doc).Fields) + { + if (storableField.ReaderValue != null) + { + doc.Add(new Field(storableField.Name(), storableField.ReaderValue, ft)); + } + else if (storableField.BinaryValue() != null) + { + doc.Add(new Field(storableField.Name(), storableField.BinaryValue(), ft)); + } + else if (storableField.StringValue != null) + { + doc.Add(new Field(storableField.Name(), storableField.StringValue, ft)); + } + else if (storableField.NumericValue != null) + { + doc.Add(new Field(storableField.Name(), storableField.NumericValue.ToString(), ft)); + } + } + } + + // add it to one of the IDXs + if (b % 2 == 0 && testWriter.MaxDoc < size * _testRatio) + { + testWriter.AddDocument(doc); + } + else if (cvWriter.MaxDoc < size * _crossValidationRatio) + { + cvWriter.AddDocument(doc); + } + else + { + trainingWriter.AddDocument(doc); + } + b++; + } + } + catch (Exception e) + { + throw new IOException("Exceptio in DatasetSplitter", e); + } + finally + { + testWriter.Commit(); + cvWriter.Commit(); + trainingWriter.Commit(); + // close IWs + testWriter.Dispose(); + cvWriter.Dispose(); + trainingWriter.Dispose(); + } + } + + } +} \ No newline at end of file diff --git a/src/Lucene.Net.Tests.Classification/Lucene.Net.Tests.Classification.csproj b/src/Lucene.Net.Tests.Classification/Lucene.Net.Tests.Classification.csproj index 164c8f50cc..693aacd7cd 100644 --- a/src/Lucene.Net.Tests.Classification/Lucene.Net.Tests.Classification.csproj +++ b/src/Lucene.Net.Tests.Classification/Lucene.Net.Tests.Classification.csproj @@ -50,6 +50,7 @@ + @@ -64,6 +65,7 @@ +