From 82f060bbde5c8bfabda6e26811c5c2198123f8ec Mon Sep 17 00:00:00 2001 From: cn1010 <1062108372@qq.com> Date: Tue, 14 Jan 2020 16:07:16 +0800 Subject: [PATCH] Enhance security (1) (#1403) * inputs validation * more validation * recover constructor in ECPoint * format * minor changes * format * changes based on review * recover validation in AesEncrypt --- src/neo/Cryptography/Base58.cs | 1 + src/neo/Cryptography/BloomFilter.cs | 2 + src/neo/Cryptography/ECC/ECDsa.cs | 107 ------------------ src/neo/Cryptography/ECC/ECFieldElement.cs | 5 +- src/neo/Cryptography/ECC/ECPoint.cs | 3 +- src/neo/Cryptography/MerkleTree.cs | 2 +- .../Cryptography/ECC/UT_ECDsa.cs | 55 --------- .../Cryptography/ECC/UT_ECFieldElement.cs | 24 ++++ .../Cryptography/ECC/UT_ECPoint.cs | 18 ++- 9 files changed, 41 insertions(+), 176 deletions(-) delete mode 100644 src/neo/Cryptography/ECC/ECDsa.cs delete mode 100644 tests/neo.UnitTests/Cryptography/ECC/UT_ECDsa.cs diff --git a/src/neo/Cryptography/Base58.cs b/src/neo/Cryptography/Base58.cs index 316f1de72b..f571f9540d 100644 --- a/src/neo/Cryptography/Base58.cs +++ b/src/neo/Cryptography/Base58.cs @@ -12,6 +12,7 @@ public static class Base58 public static byte[] Base58CheckDecode(this string input) { + if (input is null) throw new ArgumentNullException(nameof(input)); byte[] buffer = Decode(input); if (buffer.Length < 4) throw new FormatException(); byte[] checksum = buffer.Sha256(0, buffer.Length - 4).Sha256(); diff --git a/src/neo/Cryptography/BloomFilter.cs b/src/neo/Cryptography/BloomFilter.cs index c4b79d01f2..32cfa50dd4 100644 --- a/src/neo/Cryptography/BloomFilter.cs +++ b/src/neo/Cryptography/BloomFilter.cs @@ -1,3 +1,4 @@ +using System; using System.Collections; using System.Linq; @@ -16,6 +17,7 @@ public class BloomFilter public BloomFilter(int m, int k, uint nTweak, byte[] elements = null) { + if (k < 0 || m < 0) throw new ArgumentOutOfRangeException(); this.seeds = Enumerable.Range(0, k).Select(p => (uint)p * 0xFBA4C795 + nTweak).ToArray(); this.bits = elements == null ? new BitArray(m) : new BitArray(elements); this.bits.Length = m; diff --git a/src/neo/Cryptography/ECC/ECDsa.cs b/src/neo/Cryptography/ECC/ECDsa.cs deleted file mode 100644 index 08f48b1dc3..0000000000 --- a/src/neo/Cryptography/ECC/ECDsa.cs +++ /dev/null @@ -1,107 +0,0 @@ -using System; -using System.Numerics; -using System.Security.Cryptography; - -namespace Neo.Cryptography.ECC -{ - public class ECDsa - { - private readonly byte[] privateKey; - private readonly ECPoint publicKey; - private readonly ECCurve curve; - - public ECDsa(byte[] privateKey, ECCurve curve) - : this(curve.G * privateKey) - { - this.privateKey = privateKey; - } - - public ECDsa(ECPoint publicKey) - { - this.publicKey = publicKey; - this.curve = publicKey.Curve; - } - - private BigInteger CalculateE(BigInteger n, ReadOnlySpan message) - { - int messageBitLength = message.Length * 8; - BigInteger trunc = new BigInteger(message, isUnsigned: true, isBigEndian: true); - if (n.GetBitLength() < messageBitLength) - { - trunc >>= messageBitLength - n.GetBitLength(); - } - return trunc; - } - - public BigInteger[] GenerateSignature(ReadOnlySpan message) - { - if (privateKey == null) throw new InvalidOperationException(); - BigInteger e = CalculateE(curve.N, message); - BigInteger d = new BigInteger(privateKey, isUnsigned: true, isBigEndian: true); - BigInteger r, s; - using (RandomNumberGenerator rng = RandomNumberGenerator.Create()) - { - do - { - BigInteger k; - do - { - do - { - k = rng.NextBigInteger(curve.N.GetBitLength()); - } - while (k.Sign == 0 || k.CompareTo(curve.N) >= 0); - ECPoint p = ECPoint.Multiply(curve.G, k); - BigInteger x = p.X.Value; - r = x.Mod(curve.N); - } - while (r.Sign == 0); - s = (k.ModInverse(curve.N) * (e + d * r)).Mod(curve.N); - if (s > curve.N / 2) - { - s = curve.N - s; - } - } - while (s.Sign == 0); - } - return new BigInteger[] { r, s }; - } - - private static ECPoint SumOfTwoMultiplies(ECPoint P, BigInteger k, ECPoint Q, BigInteger l) - { - int m = Math.Max(k.GetBitLength(), l.GetBitLength()); - ECPoint Z = P + Q; - ECPoint R = P.Curve.Infinity; - for (int i = m - 1; i >= 0; --i) - { - R = R.Twice(); - if (k.TestBit(i)) - { - if (l.TestBit(i)) - R = R + Z; - else - R = R + P; - } - else - { - if (l.TestBit(i)) - R = R + Q; - } - } - return R; - } - - public bool VerifySignature(ReadOnlySpan message, BigInteger r, BigInteger s) - { - if (r.Sign < 1 || s.Sign < 1 || r.CompareTo(curve.N) >= 0 || s.CompareTo(curve.N) >= 0) - return false; - BigInteger e = CalculateE(curve.N, message); - BigInteger c = s.ModInverse(curve.N); - BigInteger u1 = (e * c).Mod(curve.N); - BigInteger u2 = (r * c).Mod(curve.N); - ECPoint point = SumOfTwoMultiplies(curve.G, u1, publicKey, u2); - BigInteger v = point.X.Value.Mod(curve.N); - return v.Equals(r); - } - } -} diff --git a/src/neo/Cryptography/ECC/ECFieldElement.cs b/src/neo/Cryptography/ECC/ECFieldElement.cs index 948afc263d..9ebd48082d 100644 --- a/src/neo/Cryptography/ECC/ECFieldElement.cs +++ b/src/neo/Cryptography/ECC/ECFieldElement.cs @@ -10,6 +10,8 @@ internal class ECFieldElement : IComparable, IEquatable= curve.Q) throw new ArgumentException("x value too large in field element"); this.Value = value; @@ -19,6 +21,7 @@ public ECFieldElement(BigInteger value, ECCurve curve) public int CompareTo(ECFieldElement other) { if (ReferenceEquals(this, other)) return 0; + if (!curve.Equals(other.curve)) throw new InvalidOperationException("Invalid comparision for points with different curves"); return Value.CompareTo(other.Value); } @@ -35,7 +38,7 @@ public override bool Equals(object obj) public bool Equals(ECFieldElement other) { - return Value.Equals(other.Value); + return Value.Equals(other.Value) && curve.Equals(other.curve); } private static BigInteger[] FastLucasSequence(BigInteger p, BigInteger P, BigInteger Q, BigInteger k) diff --git a/src/neo/Cryptography/ECC/ECPoint.cs b/src/neo/Cryptography/ECC/ECPoint.cs index 4189dcbb9b..f0b0b19a36 100644 --- a/src/neo/Cryptography/ECC/ECPoint.cs +++ b/src/neo/Cryptography/ECC/ECPoint.cs @@ -22,7 +22,7 @@ public bool IsInfinity internal ECPoint(ECFieldElement x, ECFieldElement y, ECCurve curve) { - if ((x != null && y == null) || (x == null && y != null)) + if ((x is null ^ y is null) || (curve is null)) throw new ArgumentException("Exactly one of the field elements is null"); this.X = x; this.Y = y; @@ -31,6 +31,7 @@ internal ECPoint(ECFieldElement x, ECFieldElement y, ECCurve curve) public int CompareTo(ECPoint other) { + if (!Curve.Equals(other.Curve)) throw new InvalidOperationException("Invalid comparision for points with different curves"); if (ReferenceEquals(this, other)) return 0; int result = X.CompareTo(other.X); if (result != 0) return result; diff --git a/src/neo/Cryptography/MerkleTree.cs b/src/neo/Cryptography/MerkleTree.cs index 89d0a275b5..870eaabbb9 100644 --- a/src/neo/Cryptography/MerkleTree.cs +++ b/src/neo/Cryptography/MerkleTree.cs @@ -15,7 +15,7 @@ public class MerkleTree internal MerkleTree(UInt256[] hashes) { - if (hashes.Length == 0) throw new ArgumentException(); + if (hashes is null || hashes.Length == 0) throw new ArgumentException(); this.root = Build(hashes.Select(p => new MerkleTreeNode { Hash = p }).ToArray()); int depth = 1; for (MerkleTreeNode i = root; i.LeftChild != null; i = i.LeftChild) diff --git a/tests/neo.UnitTests/Cryptography/ECC/UT_ECDsa.cs b/tests/neo.UnitTests/Cryptography/ECC/UT_ECDsa.cs deleted file mode 100644 index 62206741cb..0000000000 --- a/tests/neo.UnitTests/Cryptography/ECC/UT_ECDsa.cs +++ /dev/null @@ -1,55 +0,0 @@ -using FluentAssertions; -using Microsoft.VisualStudio.TestTools.UnitTesting; -using Neo.Wallets; -using System; -using System.Numerics; -using ECDsa = Neo.Cryptography.ECC.ECDsa; - -namespace Neo.UnitTests.Cryptography -{ - [TestClass] - public class UT_ECDsa - { - private KeyPair key = null; - - [TestInitialize] - public void TestSetup() - { - key = UT_Crypto.generateCertainKey(32); - } - - [TestMethod] - public void TestECDsaConstructor() - { - Action action = () => new ECDsa(key.PublicKey); - action.Should().NotThrow(); - action = () => new ECDsa(key.PrivateKey, key.PublicKey.Curve); - action.Should().NotThrow(); - } - - [TestMethod] - public void TestGenerateSignature() - { - ECDsa sa = new ECDsa(key.PrivateKey, key.PublicKey.Curve); - byte[] message = System.Text.Encoding.Default.GetBytes("HelloWorld"); - for (int i = 0; i < 10; i++) - { - BigInteger[] result = sa.GenerateSignature(message); - result.Length.Should().Be(2); - } - sa = new ECDsa(key.PublicKey); - Action action = () => sa.GenerateSignature(message); - action.Should().Throw(); - } - - [TestMethod] - public void TestVerifySignature() - { - ECDsa sa = new ECDsa(key.PrivateKey, key.PublicKey.Curve); - byte[] message = System.Text.Encoding.Default.GetBytes("HelloWorld"); - BigInteger[] result = sa.GenerateSignature(message); - sa.VerifySignature(message, result[0], result[1]).Should().BeTrue(); - sa.VerifySignature(message, new BigInteger(-100), result[1]).Should().BeFalse(); - } - } -} diff --git a/tests/neo.UnitTests/Cryptography/ECC/UT_ECFieldElement.cs b/tests/neo.UnitTests/Cryptography/ECC/UT_ECFieldElement.cs index 6b2965b41f..725fc46962 100644 --- a/tests/neo.UnitTests/Cryptography/ECC/UT_ECFieldElement.cs +++ b/tests/neo.UnitTests/Cryptography/ECC/UT_ECFieldElement.cs @@ -23,6 +23,29 @@ public void TestECFieldElementConstructor() action.Should().Throw(); } + [TestMethod] + public void TestCompareTo() + { + ECFieldElement X1 = new ECFieldElement(new BigInteger(100), ECCurve.Secp256k1); + ECFieldElement Y1 = new ECFieldElement(new BigInteger(200), ECCurve.Secp256k1); + ECFieldElement X2 = new ECFieldElement(new BigInteger(300), ECCurve.Secp256k1); + ECFieldElement Y2 = new ECFieldElement(new BigInteger(400), ECCurve.Secp256k1); + ECFieldElement X3 = new ECFieldElement(new BigInteger(100), ECCurve.Secp256r1); + ECFieldElement Y3 = new ECFieldElement(new BigInteger(400), ECCurve.Secp256r1); + ECPoint point1 = new ECPoint(X1, Y1, ECCurve.Secp256k1); + ECPoint point2 = new ECPoint(X2, Y1, ECCurve.Secp256k1); + ECPoint point3 = new ECPoint(X1, Y2, ECCurve.Secp256k1); + ECPoint point4 = new ECPoint(X3, Y3, ECCurve.Secp256r1); + + point1.CompareTo(point1).Should().Be(0); + point1.CompareTo(point2).Should().Be(-1); + point2.CompareTo(point1).Should().Be(1); + point1.CompareTo(point3).Should().Be(-1); + point3.CompareTo(point1).Should().Be(1); + Action action = () => point3.CompareTo(point4); + action.Should().Throw(); + } + [TestMethod] public void TestEquals() { @@ -30,6 +53,7 @@ public void TestEquals() object element = new ECFieldElement(input, ECCurve.Secp256k1); element.Equals(element).Should().BeTrue(); element.Equals(1).Should().BeFalse(); + element.Equals(new ECFieldElement(input, ECCurve.Secp256r1)).Should().BeFalse(); input = new BigInteger(200); element.Equals(new ECFieldElement(input, ECCurve.Secp256k1)).Should().BeFalse(); diff --git a/tests/neo.UnitTests/Cryptography/ECC/UT_ECPoint.cs b/tests/neo.UnitTests/Cryptography/ECC/UT_ECPoint.cs index b38df0fa44..9a8232d9fb 100644 --- a/tests/neo.UnitTests/Cryptography/ECC/UT_ECPoint.cs +++ b/tests/neo.UnitTests/Cryptography/ECC/UT_ECPoint.cs @@ -28,18 +28,12 @@ public static byte[] generatePrivateKey(int privateKeyLength) public void TestCompareTo() { ECFieldElement X1 = new ECFieldElement(new BigInteger(100), ECCurve.Secp256k1); - ECFieldElement Y1 = new ECFieldElement(new BigInteger(200), ECCurve.Secp256k1); - ECFieldElement X2 = new ECFieldElement(new BigInteger(300), ECCurve.Secp256k1); - ECFieldElement Y2 = new ECFieldElement(new BigInteger(400), ECCurve.Secp256k1); - ECPoint point1 = new ECPoint(X1, Y1, ECCurve.Secp256k1); - ECPoint point2 = new ECPoint(X2, Y1, ECCurve.Secp256k1); - ECPoint point3 = new ECPoint(X1, Y2, ECCurve.Secp256k1); + ECFieldElement X2 = new ECFieldElement(new BigInteger(200), ECCurve.Secp256k1); + ECFieldElement X3 = new ECFieldElement(new BigInteger(100), ECCurve.Secp256r1); - point1.CompareTo(point1).Should().Be(0); - point1.CompareTo(point2).Should().Be(-1); - point2.CompareTo(point1).Should().Be(1); - point1.CompareTo(point3).Should().Be(-1); - point3.CompareTo(point1).Should().Be(1); + X1.CompareTo(X2).Should().Be(-1); + Action action = () => X1.CompareTo(X3); + action.Should().Throw(); } [TestMethod] @@ -60,6 +54,8 @@ public void TestECPointConstructor() action.Should().Throw(); action = () => new ECPoint(null, Y, ECCurve.Secp256k1); action.Should().Throw(); + action = () => new ECPoint(null, Y, null); + action.Should().Throw(); } [TestMethod]