Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions Src/IronPython.Modules/_ssl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -610,6 +611,83 @@ public int write(CodeContext/*!*/ context, Bytes data) {
}
}

#nullable enable

[PythonType]
public class MemoryBIO {
private bool _write_eof;

public bool eof { get; private set; }
public int pending { get; private set; }

private Bytes? buf;
private Queue<Bytes> queue = new Queue<Bytes>();

public MemoryBIO() { }

public Bytes read(int size = -1) {
if (size == 0 || eof) {
return Bytes.Empty;
}
if (size == -1 || size > pending) {
size = pending;
}

byte[] res = new byte[size];
var resSpan = res.AsSpan();

if (buf is not null) {
var span = buf.AsSpan();
var length = resSpan.Length;
if (length < span.Length) {
buf = Bytes.Make(span.Slice(length).ToArray());
span = span.Slice(0, length);
}
else {
buf = null;
}
span.CopyTo(resSpan);
resSpan = resSpan.Slice(span.Length);
}

while (resSpan.Length > 0) {
Debug.Assert(buf is null && queue.Count > 0);
var span = queue.Dequeue().AsSpan();
var length = resSpan.Length;
if (length < span.Length) {
buf = Bytes.Make(span.Slice(length).ToArray());
span = span.Slice(0, length);
}
span.CopyTo(resSpan);
resSpan = resSpan.Slice(span.Length);
}

pending -= size;
if (_write_eof && pending == 0) eof = true;
return Bytes.Make(res);
}

public int write(CodeContext context, [NotNull] IBufferProtocol b) {
if (_write_eof) throw PythonExceptions.CreateThrowable(SSLError(context), "cannot write() after write_eof()");

if (b is not Bytes bytes) {
using var buffer = b.GetBuffer();
bytes = Bytes.Make(buffer.ToArray());
}
if (bytes.Count == 0) return 0;
queue.Enqueue(bytes);
pending += bytes.Count;
return bytes.Count;
}

public void write_eof() {
_write_eof = true;
if (pending == 0) eof = true;
}
}

#nullable restore

public static object txt2obj(CodeContext context, string txt, bool name = false) {
Asn1Object obj = null;
if (name) {
Expand Down
66 changes: 66 additions & 0 deletions Tests/modules/network_related/test__ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import _ssl
import os
import socket
import sys
import unittest

from iptest import IronPythonTestCase, is_cli, is_netcoreapp, retryOnFailure, run_test, skipUnlessIronPython
Expand Down Expand Up @@ -269,4 +270,69 @@ def test_cert_date_locale(self):
finally:
System.Threading.Thread.CurrentThread.CurrentCulture = culture

import _ssl as ssl

# These come from the 3.5 stdlib and can eventually be removed
@unittest.skipUnless(is_cli or sys.version_info >= (3,5), "not in CPython 3.4")
class MemoryBIOTests(unittest.TestCase):
def test_read_write(self):
bio = ssl.MemoryBIO()
bio.write(b'foo')
self.assertEqual(bio.read(), b'foo')
self.assertEqual(bio.read(), b'')
bio.write(b'foo')
bio.write(b'bar')
self.assertEqual(bio.read(), b'foobar')
self.assertEqual(bio.read(), b'')
bio.write(b'baz')
self.assertEqual(bio.read(2), b'ba')
self.assertEqual(bio.read(1), b'z')
self.assertEqual(bio.read(1), b'')

def test_eof(self):
bio = ssl.MemoryBIO()
self.assertFalse(bio.eof)
self.assertEqual(bio.read(), b'')
self.assertFalse(bio.eof)
bio.write(b'foo')
self.assertFalse(bio.eof)
bio.write_eof()
self.assertFalse(bio.eof)
self.assertEqual(bio.read(2), b'fo')
self.assertFalse(bio.eof)
self.assertEqual(bio.read(1), b'o')
self.assertTrue(bio.eof)
self.assertEqual(bio.read(), b'')
self.assertTrue(bio.eof)

def test_pending(self):
bio = ssl.MemoryBIO()
self.assertEqual(bio.pending, 0)
bio.write(b'foo')
self.assertEqual(bio.pending, 3)
for i in range(3):
bio.read(1)
self.assertEqual(bio.pending, 3-i-1)
for i in range(3):
bio.write(b'x')
self.assertEqual(bio.pending, i+1)
bio.read()
self.assertEqual(bio.pending, 0)

def test_buffer_types(self):
bio = ssl.MemoryBIO()
bio.write(b'foo')
self.assertEqual(bio.read(), b'foo')
bio.write(bytearray(b'bar'))
self.assertEqual(bio.read(), b'bar')
bio.write(memoryview(b'baz'))
self.assertEqual(bio.read(), b'baz')

def test_error_types(self):
bio = ssl.MemoryBIO()
self.assertRaises(TypeError, bio.write, 'foo')
self.assertRaises(TypeError, bio.write, None)
self.assertRaises(TypeError, bio.write, True)
self.assertRaises(TypeError, bio.write, 1)

run_test(__name__)