From 98b215751c4c76f03bdd5805dda237d9ef30b8b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Lozier?= Date: Thu, 24 Mar 2022 14:08:43 -0400 Subject: [PATCH 1/2] Add _ssl.MemoryBIO --- Src/IronPython.Modules/_ssl.cs | 78 ++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/Src/IronPython.Modules/_ssl.cs b/Src/IronPython.Modules/_ssl.cs index 7a2b1deeb..4df88b4df 100644 --- a/Src/IronPython.Modules/_ssl.cs +++ b/Src/IronPython.Modules/_ssl.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Globalization; using System.IO; using System.Linq; @@ -610,6 +611,83 @@ public int write(CodeContext/*!*/ context, Bytes data) { } } +#nullable enable + + [PythonType] + public class MemoryBIO { + private bool _write_eof; + + public bool eof { get; private set; } + public int pending { get; private set; } + + private Bytes? buf; + private Queue queue = new Queue(); + + public MemoryBIO() { } + + public Bytes read(int size = -1) { + if (size == 0 || eof) { + return Bytes.Empty; + } + if (size == -1 || size > pending) { + size = pending; + } + + byte[] res = new byte[size]; + var resSpan = res.AsSpan(); + + if (buf is not null) { + var span = buf.AsSpan(); + var length = resSpan.Length; + if (length < span.Length) { + buf = Bytes.Make(span.Slice(length).ToArray()); + span = span.Slice(0, length); + } + else { + buf = null; + } + span.CopyTo(resSpan); + resSpan = resSpan.Slice(span.Length); + } + + while (resSpan.Length > 0) { + Debug.Assert(buf is null && queue.Count > 0); + var span = queue.Dequeue().AsSpan(); + var length = resSpan.Length; + if (length < span.Length) { + buf = Bytes.Make(span.Slice(length).ToArray()); + span = span.Slice(0, length); + } + span.CopyTo(resSpan); + resSpan = resSpan.Slice(span.Length); + } + + pending -= size; + if (_write_eof && pending == 0) eof = true; + return Bytes.Make(res); + } + + public int write(CodeContext context, [NotNull] IBufferProtocol b) { + if (_write_eof) throw PythonExceptions.CreateThrowable(SSLError(context), "cannot write() after write_eof()"); + + if (b is not Bytes bytes) { + using var buffer = b.GetBuffer(); + bytes = Bytes.Make(buffer.ToArray()); + } + if (bytes.Count == 0) return 0; + queue.Enqueue(bytes); + pending += bytes.Count; + return bytes.Count; + } + + public void write_eof() { + _write_eof = true; + if (pending == 0) eof = true; + } + } + +#nullable restore + public static object txt2obj(CodeContext context, string txt, bool name = false) { Asn1Object obj = null; if (name) { From 62bfafa88bd0f28c91dbe0e95400879f6de6be19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Lozier?= Date: Sat, 26 Mar 2022 13:18:15 -0400 Subject: [PATCH 2/2] Add some tests --- Tests/modules/network_related/test__ssl.py | 66 ++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/Tests/modules/network_related/test__ssl.py b/Tests/modules/network_related/test__ssl.py index 0b6bb3185..32e1c1c9f 100644 --- a/Tests/modules/network_related/test__ssl.py +++ b/Tests/modules/network_related/test__ssl.py @@ -9,6 +9,7 @@ import _ssl import os import socket +import sys import unittest from iptest import IronPythonTestCase, is_cli, is_netcoreapp, retryOnFailure, run_test, skipUnlessIronPython @@ -269,4 +270,69 @@ def test_cert_date_locale(self): finally: System.Threading.Thread.CurrentThread.CurrentCulture = culture +import _ssl as ssl + +# These come from the 3.5 stdlib and can eventually be removed +@unittest.skipUnless(is_cli or sys.version_info >= (3,5), "not in CPython 3.4") +class MemoryBIOTests(unittest.TestCase): + def test_read_write(self): + bio = ssl.MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + self.assertEqual(bio.read(), b'') + bio.write(b'foo') + bio.write(b'bar') + self.assertEqual(bio.read(), b'foobar') + self.assertEqual(bio.read(), b'') + bio.write(b'baz') + self.assertEqual(bio.read(2), b'ba') + self.assertEqual(bio.read(1), b'z') + self.assertEqual(bio.read(1), b'') + + def test_eof(self): + bio = ssl.MemoryBIO() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertFalse(bio.eof) + bio.write(b'foo') + self.assertFalse(bio.eof) + bio.write_eof() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(2), b'fo') + self.assertFalse(bio.eof) + self.assertEqual(bio.read(1), b'o') + self.assertTrue(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertTrue(bio.eof) + + def test_pending(self): + bio = ssl.MemoryBIO() + self.assertEqual(bio.pending, 0) + bio.write(b'foo') + self.assertEqual(bio.pending, 3) + for i in range(3): + bio.read(1) + self.assertEqual(bio.pending, 3-i-1) + for i in range(3): + bio.write(b'x') + self.assertEqual(bio.pending, i+1) + bio.read() + self.assertEqual(bio.pending, 0) + + def test_buffer_types(self): + bio = ssl.MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + bio.write(bytearray(b'bar')) + self.assertEqual(bio.read(), b'bar') + bio.write(memoryview(b'baz')) + self.assertEqual(bio.read(), b'baz') + + def test_error_types(self): + bio = ssl.MemoryBIO() + self.assertRaises(TypeError, bio.write, 'foo') + self.assertRaises(TypeError, bio.write, None) + self.assertRaises(TypeError, bio.write, True) + self.assertRaises(TypeError, bio.write, 1) + run_test(__name__)