Permalink
| // Accord Machine Learning Library | |
| // The Accord.NET Framework | |
| // http://accord-framework.net | |
| // | |
| // Copyright © César Souza, 2009-2017 | |
| // cesarsouza at gmail.com | |
| // | |
| // This library is free software; you can redistribute it and/or | |
| // modify it under the terms of the GNU Lesser General Public | |
| // License as published by the Free Software Foundation; either | |
| // version 2.1 of the License, or (at your option) any later version. | |
| // | |
| // This library is distributed in the hope that it will be useful, | |
| // but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU | |
| // Lesser General Public License for more details. | |
| // | |
| // You should have received a copy of the GNU Lesser General Public | |
| // License along with this library; if not, write to the Free Software | |
| // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA | |
| // | |
| namespace Accord.MachineLearning.DecisionTrees | |
| { | |
| using System; | |
| using System.Collections.Generic; | |
| using System.IO; | |
| using System.Linq.Expressions; | |
| using System.Reflection; | |
| using System.Reflection.Emit; | |
| using System.Runtime.Serialization; | |
| using System.Runtime.Serialization.Formatters.Binary; | |
| using Accord.MachineLearning.DecisionTrees.Rules; | |
| using Accord.Math; | |
| using Learning; | |
| using Statistics.Filters; | |
| /// <summary> | |
| /// Decision tree (for both discrete and continuous classification problems). | |
| /// </summary> | |
| /// | |
| /// <remarks> | |
| /// <para> | |
| /// Represents a decision tree which can be compiled to code at run-time. For sample usage | |
| /// and example of learning, please see the documentation pages for the <see cref="ID3Learning"> | |
| /// ID3</see> and <see cref="C45Learning">C4.5 learning algorithms</see>.</para> | |
| /// | |
| /// <para> | |
| /// It is also possible to create <see cref="RandomForest">random forests</see> using | |
| /// the <see cref="RandomForestLearning">random forest learning algorithm</see>.</para> | |
| /// </remarks> | |
| /// | |
| /// <example> | |
| /// <para> | |
| /// This example shows the simplest way to induce a decision tree with discrete variables.</para> | |
| /// <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\ID3LearningTest.cs" region="doc_learn_simplest" /> | |
| /// | |
| ///<para> | |
| /// This example shows a common textbook example, and how to induce a decision tree using a | |
| /// <see cref="Codification">codebook</see> to convert string (text) variables into discrete symbols.</para> | |
| /// <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\ID3LearningTest.cs" region="doc_learn_mitchell" /> | |
| /// <para>For more examples with discrete variables, please see <see cref="ID3Learning"/></para> | |
| /// | |
| /// <para> | |
| /// This example shows the simplest way to induce a decision tree with continuous variables.</para> | |
| /// <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\C45LearningTest.cs" region="doc_simplest" /> | |
| /// <para>For more examples with continuous variables, please see <see cref="C45Learning"/></para> | |
| /// </example> | |
| /// | |
| /// <seealso cref="ID3Learning"/> | |
| /// <seealso cref="C45Learning"/> | |
| /// <seealso cref="RandomForestLearning"/> | |
| /// | |
| [Serializable] | |
| [SerializationBinder(typeof(DecisionTree.DecisionTreeBinder))] | |
| public class DecisionTree : MulticlassClassifierBase, IEnumerable<DecisionNode> | |
| { | |
| private DecisionNode root; | |
| private DecisionVariableCollection attributes; | |
| /// <summary> | |
| /// Gets or sets the root node for this tree. | |
| /// </summary> | |
| /// | |
| public DecisionNode Root | |
| { | |
| get { return root; } | |
| set { root = value; } | |
| } | |
| /// <summary> | |
| /// Gets the collection of attributes processed by this tree. | |
| /// </summary> | |
| /// | |
| public DecisionVariableCollection Attributes | |
| { | |
| get { return attributes; } | |
| } | |
| /// <summary> | |
| /// Creates a new <see cref="DecisionTree"/> to process | |
| /// the given <paramref name="inputs"/> and the given | |
| /// number of possible <paramref name="classes"/>. | |
| /// </summary> | |
| /// | |
| /// <param name="inputs">An array specifying the attributes to be processed by this tree.</param> | |
| /// <param name="classes">The number of possible output classes for the given attributes.</param> | |
| /// | |
| public DecisionTree(IList<DecisionVariable> inputs, int classes) | |
| { | |
| if (classes <= 0) | |
| throw new ArgumentOutOfRangeException("classes"); | |
| if (inputs == null) | |
| throw new ArgumentNullException("inputs"); | |
| for (int i = 0; i < inputs.Count; i++) | |
| if (inputs[i].Range.Length == 0) | |
| throw new ArgumentException("Attribute " + i + " is a constant."); | |
| this.attributes = new DecisionVariableCollection(inputs); | |
| this.NumberOfInputs = inputs.Count; | |
| this.NumberOfOutputs = classes; | |
| this.NumberOfClasses = classes; | |
| } | |
| /// <summary> | |
| /// Computes the tree decision for a given input. | |
| /// </summary> | |
| /// | |
| /// <param name="input">The input data.</param> | |
| /// | |
| /// <returns>A predicted class for the given input.</returns> | |
| /// | |
| public override int Decide(double[] input) | |
| { | |
| return decide(input, Root); | |
| } | |
| /// <summary> | |
| /// Computes the tree decision for a given input. | |
| /// </summary> | |
| /// | |
| /// <param name="input">The input data.</param> | |
| /// | |
| /// <returns>A predicted class for the given input.</returns> | |
| /// | |
| public override int Decide(int[] input) | |
| { | |
| return decide(input, Root); | |
| } | |
| /// <summary> | |
| /// Computes the tree decision for a given input. | |
| /// </summary> | |
| /// | |
| /// <param name="input">The input data.</param> | |
| /// <param name="subtree">The node where the decision starts.</param> | |
| /// | |
| /// <returns>A predicted class for the given input.</returns> | |
| /// | |
| public int Decide(double[] input, DecisionNode subtree) | |
| { | |
| if (subtree == null) | |
| throw new ArgumentNullException("subtree"); | |
| if (subtree.Owner != this) | |
| throw new ArgumentException("The node does not belong to this tree.", "subtree"); | |
| return decide(input, subtree); | |
| } | |
| private static int decide(double[] input, DecisionNode subtree) | |
| { | |
| DecisionNode current = subtree; | |
| // Start reasoning | |
| while (current != null) | |
| { | |
| // Check if this is a leaf | |
| if (current.IsLeaf) | |
| { | |
| // This is a leaf node. The decision | |
| // process thus should stop here. | |
| return (current.Output.HasValue) ? current.Output.Value : -1; | |
| } | |
| // This node is not a leaf. Continue the | |
| // decision process following the children | |
| // Get the next attribute to guide reasoning | |
| int attribute = current.Branches.AttributeIndex; | |
| // Check which child is responsible for dealing | |
| // which the particular value of the attribute | |
| DecisionNode nextNode = null; | |
| foreach (DecisionNode branch in current.Branches) | |
| { | |
| if (branch.Compute(input[attribute])) | |
| { | |
| // This is the child node responsible for dealing | |
| // which this particular attribute value. Choose it | |
| // to continue reasoning. | |
| nextNode = branch; break; | |
| } | |
| } | |
| current = nextNode; | |
| } | |
| // Normal execution should not reach here. | |
| throw new InvalidOperationException("The tree is degenerated. This is often a sign that " | |
| + "the tree is expecting discrete inputs, but it was given only real values."); | |
| } | |
| private static int decide(int[] input, DecisionNode subtree) | |
| { | |
| DecisionNode current = subtree; | |
| // Start reasoning | |
| while (current != null) | |
| { | |
| // Check if this is a leaf | |
| if (current.IsLeaf) | |
| { | |
| // This is a leaf node. The decision | |
| // process thus should stop here. | |
| return (current.Output.HasValue) ? current.Output.Value : -1; | |
| } | |
| // This node is not a leaf. Continue the | |
| // decision process following the children | |
| // Get the next attribute to guide reasoning | |
| int attribute = current.Branches.AttributeIndex; | |
| // Check which child is responsible for dealing | |
| // which the particular value of the attribute | |
| DecisionNode nextNode = null; | |
| foreach (DecisionNode branch in current.Branches) | |
| { | |
| if (branch.Compute(input[attribute])) | |
| { | |
| // This is the child node responsible for dealing | |
| // which this particular attribute value. Choose it | |
| // to continue reasoning. | |
| nextNode = branch; break; | |
| } | |
| } | |
| current = nextNode; | |
| } | |
| // Normal execution should not reach here. | |
| throw new InvalidOperationException("The tree is degenerated. This is often a sign that " | |
| + "the tree is expecting discrete inputs, but it was given only real values."); | |
| } | |
| /// <summary> | |
| /// Returns an enumerator that iterates through the tree. | |
| /// </summary> | |
| /// | |
| /// <returns> | |
| /// An <see cref="T:System.Collections.IEnumerator"/> object that can be | |
| /// used to iterate through the collection. | |
| /// </returns> | |
| /// | |
| public IEnumerator<DecisionNode> GetEnumerator() | |
| { | |
| if (Root == null) | |
| yield break; | |
| var stack = new Stack<DecisionNode>(new[] { Root }); | |
| while (stack.Count != 0) | |
| { | |
| DecisionNode current = stack.Pop(); | |
| yield return current; | |
| if (current.Branches != null) | |
| for (int i = current.Branches.Count - 1; i >= 0; i--) | |
| stack.Push(current.Branches[i]); | |
| } | |
| } | |
| System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() | |
| { | |
| return GetEnumerator(); | |
| } | |
| /// <summary> | |
| /// Traverse the tree using a <see cref="DecisionTreeTraversal">tree | |
| /// traversal method</see>. Can be iterated with a foreach loop. | |
| /// </summary> | |
| /// | |
| /// <param name="method">The tree traversal method. Common methods are | |
| /// available in the <see cref="TreeTraversal"/>static class.</param> | |
| /// | |
| /// <returns>An <see cref="IEnumerable{T}"/> object which can be used to | |
| /// traverse the tree using the chosen traversal method.</returns> | |
| /// | |
| public IEnumerable<DecisionNode> Traverse(DecisionTreeTraversalMethod method) | |
| { | |
| return new TreeTraversal(Root, method); | |
| } | |
| /// <summary> | |
| /// Traverse a subtree using a <see cref="DecisionTreeTraversal">tree | |
| /// traversal method</see>. Can be iterated with a foreach loop. | |
| /// </summary> | |
| /// | |
| /// <param name="method">The tree traversal method. Common methods are | |
| /// available in the <see cref="TreeTraversal"/>static class.</param> | |
| /// <param name="subtree">The root of the subtree to be traversed.</param> | |
| /// | |
| /// <returns>An <see cref="IEnumerable{T}"/> object which can be used to | |
| /// traverse the tree using the chosen traversal method.</returns> | |
| /// | |
| public IEnumerable<DecisionNode> Traverse(DecisionTreeTraversalMethod method, DecisionNode subtree) | |
| { | |
| if (subtree.Owner != this) | |
| throw new ArgumentException("The node does not belong to this tree.", "subtree"); | |
| return new TreeTraversal(subtree, method); | |
| } | |
| /// <summary> | |
| /// Transforms the tree into a set of <see cref="DecisionSet">decision rules</see>. | |
| /// </summary> | |
| /// | |
| /// <returns>A <see cref="DecisionSet"/> created from this tree.</returns> | |
| /// | |
| public DecisionSet ToRules() | |
| { | |
| return DecisionSet.FromDecisionTree(this); | |
| } | |
| #if !NET35 | |
| /// <summary> | |
| /// Creates an <see cref="Expression">Expression Tree</see> representation | |
| /// of this decision tree, which can in turn be compiled into code. | |
| /// </summary> | |
| /// | |
| /// <returns>A tree in the form of an expression tree.</returns> | |
| /// | |
| public Expression<Func<double[], int>> ToExpression() | |
| { | |
| return new DecisionTreeExpressionCreator(this).Create(); | |
| } | |
| #if !NETSTANDARD2_0 | |
| /// <summary> | |
| /// Creates a .NET assembly (.dll) containing a static class of | |
| /// the given name implementing the decision tree. The class will | |
| /// contain a single static Compute method implementing the tree. | |
| /// </summary> | |
| /// | |
| /// <param name="assemblyName">The name of the assembly to generate.</param> | |
| /// <param name="className">The name of the generated static class.</param> | |
| /// | |
| public void ToAssembly(string assemblyName, string className) | |
| { | |
| ToAssembly(assemblyName, "Accord.MachineLearning.DecisionTrees.Custom", className); | |
| } | |
| /// <summary> | |
| /// Creates a .NET assembly (.dll) containing a static class of | |
| /// the given name implementing the decision tree. The class will | |
| /// contain a single static Compute method implementing the tree. | |
| /// </summary> | |
| /// | |
| /// <param name="assemblyName">The name of the assembly to generate.</param> | |
| /// <param name="moduleName">The namespace which should contain the class.</param> | |
| /// <param name="className">The name of the generated static class.</param> | |
| /// | |
| public void ToAssembly(string assemblyName, string moduleName, string className) | |
| { | |
| AssemblyBuilder da = AppDomain.CurrentDomain.DefineDynamicAssembly( | |
| new AssemblyName(assemblyName), AssemblyBuilderAccess.Save); | |
| ModuleBuilder dm = da.DefineDynamicModule(moduleName, assemblyName); | |
| TypeBuilder dt = dm.DefineType(className); | |
| MethodBuilder method = dt.DefineMethod("Compute", | |
| MethodAttributes.Public | MethodAttributes.Static); | |
| // Compile the tree into a method | |
| ToExpression().CompileToMethod(method); | |
| dt.CreateType(); | |
| da.Save(assemblyName); | |
| } | |
| #endif | |
| #endif | |
| /// <summary> | |
| /// Generates a C# class implementing the decision tree. | |
| /// </summary> | |
| /// | |
| /// <param name="className">The name for the generated class.</param> | |
| /// | |
| /// <returns>A string containing the generated class.</returns> | |
| /// | |
| public string ToCode(string className) | |
| { | |
| using (MemoryStream stream = new MemoryStream()) | |
| { | |
| TextWriter writer = new StreamWriter(stream); | |
| var treeWriter = new DecisionTreeWriter(writer); | |
| treeWriter.Write(this, className); | |
| writer.Flush(); | |
| stream.Seek(0, SeekOrigin.Begin); | |
| TextReader reader = new StreamReader(stream); | |
| return reader.ReadToEnd(); | |
| } | |
| } | |
| /// <summary> | |
| /// Generates a C# class implementing the decision tree. | |
| /// </summary> | |
| /// | |
| /// <param name="className">The name for the generated class.</param> | |
| /// <param name="writer">The <see cref="TextWriter"/> where the class should be written.</param> | |
| /// | |
| public void ToCode(TextWriter writer, string className) | |
| { | |
| var treeWriter = new DecisionTreeWriter(writer); | |
| treeWriter.Write(this, className); | |
| } | |
| /// <summary> | |
| /// Computes the height of the tree, defined as the | |
| /// greatest distance (in links) between the tree's | |
| /// root node and its leaves. | |
| /// </summary> | |
| /// | |
| /// <returns>The tree's height.</returns> | |
| /// | |
| public int GetHeight() | |
| { | |
| int maxHeight = 0; | |
| foreach (var node in this) | |
| { | |
| if (!node.IsLeaf) | |
| continue; | |
| int h = node.GetHeight(); | |
| if (h > maxHeight) | |
| maxHeight = h; | |
| } | |
| return maxHeight; | |
| } | |
| #region Obsolete | |
| /// <summary> | |
| /// Obsolete. Please use <see cref="Accord.IO.Serializer.Save{T}(T, string)"/> (or use it as an extension method). | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use Accord.IO.Serializer.Save(path) instead (or use it as an extension method).")] | |
| public void Save(string path) | |
| { | |
| Accord.IO.Serializer.Save(this, path); | |
| } | |
| /// <summary> | |
| /// Obsolete. Please use <see cref="Accord.IO.Serializer.Save{T}(T, Stream, IO.SerializerCompression)"/> (or use it as an extension method). | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use Accord.IO.Serializer.Save(stream) instead (or use it as an extension method).")] | |
| public void Save(Stream stream) | |
| { | |
| Accord.IO.Serializer.Save(this, stream); | |
| } | |
| /// <summary> | |
| /// Obsolete. Please use <see cref="Accord.IO.Serializer.Load{T}(Stream, IO.SerializerCompression)"/>. | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use Accord.IO.Serializer.Load<DecisionTree>(stream) instead.")] | |
| public static DecisionTree Load(Stream stream) | |
| { | |
| return Accord.IO.Serializer.Load<DecisionTree>(stream); | |
| } | |
| /// <summary> | |
| /// Obsolete. Please use <see cref="Accord.IO.Serializer.Load{T}(string)"/>. | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use Accord.IO.Serializer.Load<DecisionTree>(path) instead.")] | |
| public static DecisionTree Load(string path) | |
| { | |
| return Accord.IO.Serializer.Load<DecisionTree>(path); | |
| } | |
| /// <summary> | |
| /// Deprecated. Please use the NumberOfOutputs property instead. | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use NumberOfOutputs instead.")] | |
| public int OutputClasses { get { return NumberOfOutputs; } } | |
| /// <summary> | |
| /// Deprecated. Please use the NumberOfInputs property. | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use NumberOfInputs instead.")] | |
| public int InputCount { get { return NumberOfInputs; } } | |
| /// <summary> | |
| /// Deprecated. Please use the Decide() method instead. | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use the Decide() method instead.")] | |
| public int Compute(int[] input) | |
| { | |
| return Decide(input); | |
| } | |
| /// <summary> | |
| /// Deprecated. Please use the Decide() method instead. | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use the Decide() method instead.")] | |
| public int Compute(double[] input) | |
| { | |
| return Decide(input); | |
| } | |
| /// <summary> | |
| /// Deprecated. Please use the Decide() method instead. | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use the Decide() method instead.")] | |
| public int[] Compute(int[][] input) | |
| { | |
| return Decide(input); | |
| } | |
| /// <summary> | |
| /// Deprecated. Please use the Decide() method instead. | |
| /// </summary> | |
| /// | |
| [Obsolete("Please use the Decide() method instead.")] | |
| public int Compute(double[] input, DecisionNode subtree) | |
| { | |
| return Decide(input, subtree); | |
| } | |
| #endregion | |
| private class TreeTraversal : IEnumerable<DecisionNode> | |
| { | |
| private DecisionNode tree; | |
| private DecisionTreeTraversalMethod method; | |
| public TreeTraversal(DecisionNode tree, DecisionTreeTraversalMethod method) | |
| { | |
| this.tree = tree; | |
| this.method = method; | |
| } | |
| public IEnumerator<DecisionNode> GetEnumerator() | |
| { | |
| return method(tree); | |
| } | |
| System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() | |
| { | |
| return method(tree); | |
| } | |
| } | |
| #region Serialization backwards compatibility | |
| internal class DecisionTreeBinder : SerializationBinder | |
| { | |
| public override Type BindToType(string assemblyName, string typeName) | |
| { | |
| AssemblyName name = new AssemblyName(assemblyName); | |
| if (name.Version < new Version(3, 1, 0)) | |
| { | |
| if (typeName == "Accord.MachineLearning.DecisionTrees.DecisionTree") | |
| return typeof(DecisionTree_2_13); | |
| if (typeName == "AForge.DoubleRange") | |
| return typeof(Accord.DoubleRange); | |
| } | |
| return null; | |
| } | |
| } | |
| #pragma warning disable 0169 | |
| #pragma warning disable 0649 | |
| [Serializable] | |
| class DecisionTree_2_13 | |
| { | |
| public DecisionNode Root { get; set; } | |
| public DecisionVariableCollection Attributes { get; private set; } | |
| public int OutputClasses { get; private set; } | |
| public int InputCount { get; private set; } | |
| public static implicit operator DecisionTree(DecisionTree_2_13 obj) | |
| { | |
| var tree = new DecisionTree(obj.Attributes, obj.OutputClasses); | |
| tree.Root = obj.Root; | |
| foreach (DecisionNode node in tree) | |
| { | |
| if (node.Owner == null) | |
| node.Owner = tree; | |
| } | |
| return tree; | |
| } | |
| } | |
| #pragma warning restore 0169 | |
| #pragma warning restore 0649 | |
| #endregion | |
| } | |
| } | |