diff --git a/Src/IronPython/Runtime/ByteArray.cs b/Src/IronPython/Runtime/ByteArray.cs index 9e9606004..ebc5516d0 100644 --- a/Src/IronPython/Runtime/ByteArray.cs +++ b/Src/IronPython/Runtime/ByteArray.cs @@ -92,7 +92,7 @@ public void __init__(CodeContext context, object? source) { } IEnumerator ie = PythonOps.GetEnumerator(context, source); while (ie.MoveNext()) { - Add(GetByte(ie.Current)); + Add(ByteOps.GetByte(ie.Current)); } } } @@ -127,7 +127,7 @@ public void append(int item) { public void append(object? item) { lock (this) { - _bytes.Add(GetByte(item)); + _bytes.Add(ByteOps.GetByte(item)); } } @@ -138,13 +138,13 @@ public void extend([NotNull]IEnumerable seq) { } } - public void extend(object? seq) { + public void extend(CodeContext context, object? seq) { // We don't make use of the length hint when extending the byte array. // However, in order to match CPython behavior with invalid length hints we // we need to go through the motions and get the length hint and attempt // to convert it to an int. - extend(GetBytes(seq, useHint: true)); + extend(ByteOps.GetBytes(seq, useHint: true, context)); } public void insert(int index, int value) { @@ -205,7 +205,7 @@ public void remove(int value) { public void remove(object? value) { lock (this) { - RemoveByte(GetByte(value)); + RemoveByte(ByteOps.GetByte(value)); } } @@ -1214,7 +1214,7 @@ public object? this[int index] { } set { lock (this) { - _bytes[PythonOps.FixIndex(index, _bytes.Count)] = GetByte(value); + _bytes[PythonOps.FixIndex(index, _bytes.Count)] = ByteOps.GetByte(value); } } } @@ -1251,7 +1251,7 @@ public object? this[[NotNull]Slice slice] { // integers, longs, etc... - fill in an array of 0 bytes // list of bytes, indexables, etc... - IList list = GetBytes(value, useHint: false); + IList list = ByteOps.GetBytes(value, useHint: false); lock (this) { slice.indices(_bytes.Count, out int start, out int stop, out int step); @@ -1345,41 +1345,6 @@ private void SliceNoStep(int start, int stop, IList other) { } } - private static byte GetByte(object? value) { - if (Converter.TryConvertToIndex(value, out object index)) { - switch (index) { - case int i: return i.ToByteChecked(); - case BigInteger bi: return bi.ToByteChecked(); - default: throw new InvalidOperationException(); // unreachable - } - } - throw PythonOps.TypeError("an integer is required"); - } - - internal static IList GetBytes(object? value, bool useHint) { - switch (value) { - case IList lob when !(lob is ListGenericWrapper): - return lob; - case IBufferProtocol bp: - using (IPythonBuffer buf = bp.GetBuffer()) { - return buf.AsReadOnlySpan().ToArray(); - } - case ReadOnlyMemory rom: - return rom.ToArray(); - case Memory mem: - return mem.ToArray(); - default: - int len = 0; - if (useHint) PythonOps.TryInvokeLengthHint(DefaultContext.Default, value, out len); - List ret = new List(len); - IEnumerator ie = PythonOps.GetEnumerator(value); - while (ie.MoveNext()) { - ret.Add(GetByte(ie.Current)); - } - return ret; - } - } - #endregion #region IList Members diff --git a/Src/IronPython/Runtime/Bytes.cs b/Src/IronPython/Runtime/Bytes.cs index 6cbe153aa..49d28b22d 100644 --- a/Src/IronPython/Runtime/Bytes.cs +++ b/Src/IronPython/Runtime/Bytes.cs @@ -11,7 +11,6 @@ using System.Linq; using System.Linq.Expressions; using System.Numerics; -using System.Runtime.InteropServices; using System.Text; using Microsoft.Scripting.Runtime; @@ -19,7 +18,7 @@ using IronPython.Runtime.Operations; using IronPython.Runtime.Types; -using IronPython.Hosting; +using NotNullWhenAttribute = System.Diagnostics.CodeAnalysis.NotNullWhenAttribute; namespace IronPython.Runtime { [PythonType("bytes"), Serializable] @@ -31,53 +30,110 @@ public Bytes() { _bytes = new byte[0]; } - public Bytes([NotNull]IEnumerable bytes) { + public Bytes([NotNull] Bytes bytes) { + _bytes = bytes._bytes; + } + + public Bytes([NotNull] IEnumerable bytes) { _bytes = bytes.ToArray(); } - public Bytes([BytesLike, NotNull]IBufferProtocol source) { + public Bytes([NotNull] IBufferProtocol source) { using IPythonBuffer buffer = source.GetBuffer(BufferFlags.FullRO); _bytes = buffer.ToArray(); } - public Bytes(CodeContext context, object? source) { - if (PythonTypeOps.TryInvokeUnaryOperator(context, source, "__bytes__", out object? res)) { - if (res is Bytes bytes) { - _bytes = bytes._bytes; + [StaticExtensionMethod] + public static object __new__(CodeContext context, [NotNull] PythonType cls) { + if (cls == TypeCache.Bytes) { + return Empty; + } else { + return cls.CreateInstance(context); + } + } + + [StaticExtensionMethod] + public static object __new__(CodeContext context, [NotNull] PythonType cls, [NotNull] IBufferProtocol source) { + if (cls == TypeCache.Bytes) { + if (source.GetType() == typeof(Bytes)) { + return source; + } else if (TryInvokeBytesOperator(context, source, out Bytes? res)) { + return res; } else { - throw PythonOps.TypeError("__bytes__ returned non-bytes (got '{0}' from type '{1}')", PythonOps.GetPythonTypeName(res), PythonOps.GetPythonTypeName(source)); + return new Bytes(source); } - } else if (Converter.TryConvertToIndex(source, throwOverflowError: true, out int size)) { - if (size < 0) throw PythonOps.ValueError("negative count"); - _bytes = new byte[size]; } else { - _bytes = ByteArray.GetBytes(source, useHint: true).ToArray(); + return cls.CreateInstance(context, __new__(context, TypeCache.Bytes, source)); } } - public Bytes([NotNull]IEnumerable source) { - _bytes = source.Select(b => ((int)PythonOps.Index(b)).ToByteChecked()).ToArray(); + [StaticExtensionMethod] + public static object __new__(CodeContext context, [NotNull] PythonType cls, object? @object) { + if (cls == TypeCache.Bytes) { + return FromObject(context, @object); + } else { + return cls.CreateInstance(context, __new__(context, TypeCache.Bytes, @object)); + } } - public Bytes([NotNull]PythonList bytes) { - _bytes = ByteOps.GetBytes(bytes).ToArray(); + [StaticExtensionMethod] + public static object __new__(CodeContext context, [NotNull] PythonType cls, [NotNull] Extensible size) { + if (cls == TypeCache.Bytes) { + if (TryInvokeBytesOperator(context, size, out Bytes? res)) { + return res; + } else { + if (size < 0) throw PythonOps.ValueError("negative count"); + return new Bytes(new byte[size]); + } + } else { + return cls.CreateInstance(context, __new__(context, TypeCache.Bytes, size)); + } } - public Bytes(int size) { - if (size < 0) throw PythonOps.ValueError("negative count"); - _bytes = new byte[size]; + [StaticExtensionMethod] + public static object __new__(CodeContext context, [NotNull] PythonType cls, int size) { + if (cls == TypeCache.Bytes) { + if (size < 0) throw PythonOps.ValueError("negative count"); + return new Bytes(new byte[size]); + } else { + return cls.CreateInstance(context, __new__(context, TypeCache.Bytes, size)); + } } - public Bytes([NotNull]string @string) { + [StaticExtensionMethod] + public static object __new__(CodeContext context, [NotNull] PythonType cls, [NotNull] ExtensibleString @string) { + if (cls == TypeCache.Bytes) { + if (TryInvokeBytesOperator(context, @string, out Bytes? res)) { + return res; + } else { + throw PythonOps.TypeError("string argument without an encoding"); + } + } else { + return cls.CreateInstance(context, __new__(context, TypeCache.Bytes, @string)); + } + } + + [StaticExtensionMethod] + public static object __new__(CodeContext context, [NotNull] PythonType cls, [NotNull] string @string) { throw PythonOps.TypeError("string argument without an encoding"); } - public Bytes(CodeContext context, [NotNull]string @string, [NotNull]string encoding) { - _bytes = StringOps.encode(context, @string, encoding, "strict").UnsafeByteArray; + [StaticExtensionMethod] + public static object __new__(CodeContext context, [NotNull] PythonType cls, [NotNull] string @string, [NotNull] string encoding) { + if (cls == TypeCache.Bytes) { + return StringOps.encode(context, @string, encoding); + } else { + return cls.CreateInstance(context, __new__(context, TypeCache.Bytes, @string, encoding)); + } } - public Bytes(CodeContext context, [NotNull]string @string, [NotNull]string encoding, [NotNull]string errors) { - _bytes = StringOps.encode(context, @string, encoding, errors).UnsafeByteArray; + [StaticExtensionMethod] + public static object __new__(CodeContext context, [NotNull] PythonType cls, [NotNull] string @string, [NotNull] string encoding, [NotNull] string errors) { + if (cls == TypeCache.Bytes) { + return StringOps.encode(context, @string, encoding, errors); + } else { + return cls.CreateInstance(context, __new__(context, TypeCache.Bytes, @string, encoding, errors)); + } } private Bytes(byte[] bytes) { @@ -89,6 +145,21 @@ private Bytes(byte[] bytes) { internal static Bytes FromByte(byte b) => oneByteBytes[b]; + internal static Bytes FromObject(CodeContext context, object? o) { + if (o == null) { + throw PythonOps.TypeError("cannot convert 'NoneType' object to bytes"); + } else if (o.GetType() == typeof(Bytes)) { + return (Bytes)o; + } else if (TryInvokeBytesOperator(context, o, out Bytes? res)) { + return res; + } else if (Converter.TryConvertToIndex(o, throwOverflowError: true, out int size)) { + if (size < 0) throw PythonOps.ValueError("negative count"); + return new Bytes(new byte[size]); + } else { + return new Bytes(ByteOps.GetBytes(o, useHint: true, context).ToArray()); + } + } + internal static Bytes Make(byte[] bytes) => new Bytes(bytes); @@ -364,7 +435,7 @@ public Bytes join(object? sequence) { public Bytes join([NotNull]PythonList sequence) { if (sequence.__len__() == 0) { - return new Bytes(); + return Empty; } else if (sequence.__len__() == 1) { return JoinOne(sequence[0]); } @@ -892,6 +963,20 @@ internal ReadOnlyMemory AsMemory() { return _bytes.AsMemory(); } + private static bool TryInvokeBytesOperator(CodeContext context, object? obj, [NotNullWhen(true)] out Bytes? bytes) { + if (PythonTypeOps.TryInvokeUnaryOperator(context, obj, "__bytes__", out object? res)) { + if (res is Bytes b) { + bytes = b; + return true; + } else { + throw PythonOps.TypeError("__bytes__ returned non-bytes (got '{0}' from type '{1}')", PythonOps.GetPythonTypeName(res), PythonOps.GetPythonTypeName(obj)); + } + } else { + bytes = null; + return false; + } + } + private static Bytes JoinOne(object? curVal) { if (curVal?.GetType() == typeof(Bytes)) { return (Bytes)curVal; diff --git a/Src/IronPython/Runtime/Operations/ByteOps.cs b/Src/IronPython/Runtime/Operations/ByteOps.cs index 074bd52ac..ab2040b8d 100644 --- a/Src/IronPython/Runtime/Operations/ByteOps.cs +++ b/Src/IronPython/Runtime/Operations/ByteOps.cs @@ -106,11 +106,31 @@ internal static IList CoerceBytes(object? obj) { throw PythonOps.TypeError("a bytes-like object is required, not '{0}'", PythonTypeOps.GetName(obj)); } - internal static List GetBytes(ICollection bytes) { - return bytes.Select(GetByte).ToList(); + internal static IList GetBytes(object? value, bool useHint, CodeContext? context = null) { + switch (value) { + case IList lob when !(lob is ListGenericWrapper): + return lob; + case IBufferProtocol bp: + using (IPythonBuffer buf = bp.GetBuffer()) { + return buf.AsReadOnlySpan().ToArray(); + } + case ReadOnlyMemory rom: + return rom.ToArray(); + case Memory mem: + return mem.ToArray(); + default: + int len = 0; + if (useHint) PythonOps.TryInvokeLengthHint(context ?? DefaultContext.Default, value, out len); + List ret = new List(len); + IEnumerator ie = PythonOps.GetEnumerator(value); + while (ie.MoveNext()) { + ret.Add(GetByte(ie.Current)); + } + return ret; + } } - private static byte GetByte(object? o) { + internal static byte GetByte(object? o) { // TODO: move fast paths to TryConvertToIndex? switch (o) { case int ii: @@ -135,8 +155,13 @@ private static byte GetByte(object? o) { return ((BigInteger)ui).ToByteChecked(); } - if (Converter.TryConvertToIndex(o, out int i)) - return i.ToByteChecked(); + if (Converter.TryConvertToIndex(o, out object index)) { + switch (index) { + case int i: return i.ToByteChecked(); + case BigInteger bi: return bi.ToByteChecked(); + default: throw new InvalidOperationException(); // unreachable + } + } throw PythonOps.TypeError($"'{PythonTypeOps.GetName(o)}' object cannot be interpreted as an integer"); } diff --git a/Src/IronPython/Runtime/Operations/IntOps.cs b/Src/IronPython/Runtime/Operations/IntOps.cs index aff29b1a7..14ae06af8 100644 --- a/Src/IronPython/Runtime/Operations/IntOps.cs +++ b/Src/IronPython/Runtime/Operations/IntOps.cs @@ -545,7 +545,7 @@ public static BigInteger from_bytes(CodeContext context, object bytes, string by bool isLittle = byteorder == "little"; if (!isLittle && byteorder != "big") throw PythonOps.ValueError("byteorder must be either 'little' or 'big'"); - return FromBytes(new Bytes(context, bytes), isLittle, signed); + return FromBytes(Bytes.FromObject(context, bytes), isLittle, signed); } private static BigInteger FromBytes(IList bytes, bool isLittle, bool signed) { diff --git a/Src/IronPython/Runtime/Types/TypeCache.Generated.cs b/Src/IronPython/Runtime/Types/TypeCache.Generated.cs index fa0e9fd5e..b2560c56b 100644 --- a/Src/IronPython/Runtime/Types/TypeCache.Generated.cs +++ b/Src/IronPython/Runtime/Types/TypeCache.Generated.cs @@ -28,6 +28,7 @@ public static class TypeCache { private static PythonType setcollection; private static PythonType pythontype; private static PythonType str; + private static PythonType bytes; private static PythonType pythontuple; private static PythonType weakreference; private static PythonType pythonlist; @@ -123,6 +124,13 @@ public static PythonType String { } } + public static PythonType Bytes { + get { + if (bytes == null) bytes = DynamicHelpers.GetPythonTypeFromType(typeof(Bytes)); + return bytes; + } + } + public static PythonType PythonTuple { get { if (pythontuple == null) pythontuple = DynamicHelpers.GetPythonTypeFromType(typeof(PythonTuple)); diff --git a/Src/Scripts/generate_typecache.py b/Src/Scripts/generate_typecache.py index e82826cde..7a9a3ae47 100644 --- a/Src/Scripts/generate_typecache.py +++ b/Src/Scripts/generate_typecache.py @@ -44,6 +44,7 @@ def __init__(self, type, name=None, typeType='PythonType', entryName=None): TypeData('SetCollection', entryName='Set'), TypeData('PythonType'), TypeData('String', 'str'), + TypeData('Bytes'), TypeData('PythonTuple'), TypeData('WeakReference'), TypeData('PythonList'), diff --git a/Tests/test_bytes.py b/Tests/test_bytes.py index 6b781c854..46c00f427 100644 --- a/Tests/test_bytes.py +++ b/Tests/test_bytes.py @@ -31,6 +31,10 @@ class BytesTest(IronPythonTestCase): def test_init(self): b = bytes(b'abcd') + self.assertIs(bytes(b), b) + + sb = BytesSubclass(b) + self.assertIsNot(BytesSubclass(sb), sb) for testType in types: self.assertEqual(testType(b), b) @@ -180,14 +184,65 @@ def __bytes__(self): def __index__(self): return 42 - self.assertEquals(bytes(A4()), b'abc') - self.assertEquals(bytearray(A4()), bytearray(42)) + self.assertEqual(bytes(A4()), b'abc') + self.assertEqual(bytearray(A4()), bytearray(42)) + self.assertEqual(int.from_bytes(A4(), 'big'), 0x616263) class EmptyClass: pass t = EmptyClass() t.__bytes__ = lambda: b"1" self.assertRaisesRegex(TypeError, "'EmptyClass' object is not iterable", bytes, t) + class OtherBytesSubclass(bytes): pass + + class SomeClass: + def __bytes__(self): + return OtherBytesSubclass(b'SOME CLASS') + + self.assertEqual(bytes(SomeClass()), b'SOME CLASS') + self.assertIs(type(bytes(SomeClass())), OtherBytesSubclass) + self.assertEqual(BytesSubclass(SomeClass()), b'SOME CLASS') + self.assertIs(type(BytesSubclass(SomeClass())), BytesSubclass) + + class BytesBytesSubclass(bytes): + def __bytes__(self): + return BytesBytesSubclass(b"BYTES FROM BYTES") + + self.assertEqual(bytes(BytesBytesSubclass(b"JUST BYTES")), b"BYTES FROM BYTES") + self.assertIs(type(bytes(BytesBytesSubclass(b"JUST BYTES"))), BytesBytesSubclass) + + class ListSubclass(bytes): + def __bytes__(self): + return OtherBytesSubclass(b"BYTES FROM LIST") + + self.assertEqual(bytes(ListSubclass([1, 2, 3])), b"BYTES FROM LIST") + self.assertIs(type(bytes(ListSubclass([1, 2, 3]))), OtherBytesSubclass) + self.assertEqual(BytesSubclass(ListSubclass([1, 2, 3])), b"BYTES FROM LIST") + self.assertIs(type(BytesSubclass(ListSubclass([1, 2, 3]))), BytesSubclass) + + class StrSubclass(str): + def __bytes__(self): + return OtherBytesSubclass(b"BYTES FROM STR") + + if sys.version_info >= (3, 5) or sys.implementation.name == 'ironpython': + self.assertEqual(bytes(StrSubclass("STR")), b"BYTES FROM STR") + self.assertIs(type(bytes(StrSubclass("STR"))), OtherBytesSubclass) + self.assertEqual(BytesSubclass(StrSubclass("STR")), b"BYTES FROM STR") + self.assertIs(type(BytesSubclass(StrSubclass("STR"))), BytesSubclass) + else: + self.assertRaises(TypeError, bytes, StrSubclass("STR")) + + self.assertEqual(bytes(StrSubclass("STR"), 'ascii'), b"STR") + self.assertEqual(bytes(StrSubclass("STR"), 'ascii', 'ignore'), b"STR") + + class IntSubclass(int): + def __bytes__(self): + return OtherBytesSubclass(b"BYTES FROM INT") + + self.assertEqual(bytes(IntSubclass(-1)), b"BYTES FROM INT") + self.assertIs(type(bytes(IntSubclass(-1))), OtherBytesSubclass) + self.assertEqual(BytesSubclass(IntSubclass(-1)), b"BYTES FROM INT") + self.assertIs(type(BytesSubclass(IntSubclass(-1))), BytesSubclass) def test_capitalize(self): tests = [(b'foo', b'Foo'), @@ -575,6 +630,8 @@ def test_join(self): x = bytearray(x) self.assertTrue(id(x.join(b'')) != id(x)) + self.assertEqual(id(b'foo'.join([])), id(b'bar'.join([]))) + x = b'abc' self.assertEqual(id(b'foo'.join([x])), id(x)) diff --git a/Tests/test_bytes_stdlib.py b/Tests/test_bytes_stdlib.py index 3a73d6f67..2ac11af8f 100644 --- a/Tests/test_bytes_stdlib.py +++ b/Tests/test_bytes_stdlib.py @@ -177,7 +177,7 @@ def load_tests(loader, standard_tests, pattern): suite.addTest(test.test_bytes.BytesTest('test_contains')) suite.addTest(test.test_bytes.BytesTest('test_copy')) suite.addTest(test.test_bytes.BytesTest('test_count')) - #suite.addTest(test.test_bytes.BytesTest('test_custom')) + suite.addTest(test.test_bytes.BytesTest('test_custom')) suite.addTest(test.test_bytes.BytesTest('test_decode')) suite.addTest(test.test_bytes.BytesTest('test_empty_sequence')) suite.addTest(test.test_bytes.BytesTest('test_encoding'))