using System;
using Algorithms.MachineLearning;
using NUnit.Framework;

namespace Algorithms.Tests.MachineLearning;

/// <summary>
/// Unit tests for <see cref="LogisticRegression"/>: argument validation,
/// training on a linearly separable problem (logical AND), and metadata.
/// </summary>
[TestFixture]
public class LogisticRegressionTests
{
    [Test]
    public void Fit_ThrowsOnEmptyInput()
    {
        var model = new LogisticRegression();

        // The stripped generic type arguments are restored here: the non-generic
        // Assert.Throws overload requires an exception Type and did not compile.
        Assert.Throws<ArgumentException>(() => model.Fit(Array.Empty<double[]>(), Array.Empty<int>()));
    }

    [Test]
    public void Fit_ThrowsOnMismatchedLabels()
    {
        var model = new LogisticRegression();
        double[][] x = { new double[] { 1, 2 } };
        int[] y = { 1, 0 };

        Assert.Throws<ArgumentException>(() => model.Fit(x, y));
    }

    [Test]
    public void FitAndPredict_WorksOnSimpleData()
    {
        // Logical AND: only (1, 1) belongs to class 1, and the classes
        // are linearly separable, so gradient descent must converge.
        double[][] x =
        {
            new[] { 0.0, 0.0 },
            new[] { 0.0, 1.0 },
            new[] { 1.0, 0.0 },
            new[] { 1.0, 1.0 },
        };
        int[] y = { 0, 0, 0, 1 };
        var model = new LogisticRegression();

        model.Fit(x, y, epochs: 2000, learningRate: 0.1);

        Assert.That(model.Predict(new double[] { 0, 0 }), Is.EqualTo(0));
        Assert.That(model.Predict(new double[] { 0, 1 }), Is.EqualTo(0));
        Assert.That(model.Predict(new double[] { 1, 0 }), Is.EqualTo(0));
        Assert.That(model.Predict(new double[] { 1, 1 }), Is.EqualTo(1));
    }

    [Test]
    public void PredictProbability_ThrowsOnFeatureMismatch()
    {
        var model = new LogisticRegression();
        double[][] x = { new double[] { 1, 2 } };
        int[] y = { 1 };
        model.Fit(x, y);

        Assert.Throws<ArgumentException>(() => model.PredictProbability(new double[] { 1 }));
    }

    [Test]
    public void FeatureCount_ReturnsCorrectValue()
    {
        var model = new LogisticRegression();
        double[][] x = { new double[] { 1, 2, 3 } };
        int[] y = { 1 };

        model.Fit(x, y);

        Assert.That(model.FeatureCount, Is.EqualTo(3));
    }
}
using System;
using System.Linq;

namespace Algorithms.MachineLearning;

/// <summary>
/// Logistic regression model for binary classification.
/// Learns a linear decision boundary by minimising log-loss with batch
/// gradient descent, then classifies with a 0.5 threshold on the sigmoid output.
/// </summary>
public class LogisticRegression
{
    // Learned weight per feature; empty until Fit has been called.
    private double[] weights = [];

    // Learned intercept term.
    private double bias;

    /// <summary>
    /// Gets the number of features the model was fitted on (0 before fitting).
    /// </summary>
    public int FeatureCount => weights.Length;

    /// <summary>
    /// Fits the model using batch gradient descent.
    /// </summary>
    /// <param name="x">2D array of features (samples x features).</param>
    /// <param name="y">Array of labels (0 or 1), one entry per sample.</param>
    /// <param name="epochs">Number of gradient-descent iterations.</param>
    /// <param name="learningRate">Step size for each parameter update.</param>
    /// <exception cref="ArgumentException">
    /// Thrown when the input is empty, the rows of <paramref name="x"/> are ragged,
    /// or the number of samples and labels differ.
    /// </exception>
    public void Fit(double[][] x, int[] y, int epochs = 1000, double learningRate = 0.01)
    {
        if (x.Length == 0 || x[0].Length == 0)
        {
            throw new ArgumentException("Input features cannot be empty.");
        }

        if (x.Length != y.Length)
        {
            throw new ArgumentException("Number of samples and labels must match.");
        }

        int nSamples = x.Length;
        int nFeatures = x[0].Length;

        // A ragged row would otherwise surface later as an
        // IndexOutOfRangeException in the middle of training.
        if (x.Any(row => row.Length != nFeatures))
        {
            throw new ArgumentException("All samples must have the same number of features.");
        }

        weights = new double[nFeatures];
        bias = 0;

        for (int epoch = 0; epoch < epochs; epoch++)
        {
            // Accumulate the log-loss gradient over the whole batch.
            double[] dw = new double[nFeatures];
            double db = 0;
            for (int i = 0; i < nSamples; i++)
            {
                double pred = Sigmoid(Dot(x[i], weights) + bias);

                // For log-loss with a sigmoid, the gradient factor is simply (prediction - label).
                double error = pred - y[i];
                for (int j = 0; j < nFeatures; j++)
                {
                    dw[j] += error * x[i][j];
                }

                db += error;
            }

            // Average the gradients and take one descent step.
            for (int j = 0; j < nFeatures; j++)
            {
                weights[j] -= learningRate * dw[j] / nSamples;
            }

            bias -= learningRate * db / nSamples;
        }
    }

    /// <summary>
    /// Predicts the probability that a sample belongs to class 1.
    /// </summary>
    /// <param name="x">Feature vector with <see cref="FeatureCount"/> entries.</param>
    /// <returns>A probability in the open interval (0, 1).</returns>
    /// <exception cref="ArgumentException">
    /// Thrown when the feature count differs from the fitted model.
    /// </exception>
    public double PredictProbability(double[] x)
    {
        if (x.Length != weights.Length)
        {
            throw new ArgumentException("Feature count mismatch.");
        }

        return Sigmoid(Dot(x, weights) + bias);
    }

    /// <summary>
    /// Predicts the class label (0 or 1) for a single sample.
    /// </summary>
    /// <param name="x">Feature vector with <see cref="FeatureCount"/> entries.</param>
    /// <returns>1 when the predicted probability is at least 0.5; otherwise 0.</returns>
    public int Predict(double[] x) => PredictProbability(x) >= 0.5 ? 1 : 0;

    // Standard logistic function mapping any real number into (0, 1).
    private static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

    // Inner product of two equal-length vectors.
    private static double Dot(double[] a, double[] b) => a.Zip(b).Sum(pair => pair.First * pair.Second);
}