Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 44 additions & 53 deletions Src/IronPython.Modules/_ssl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,18 @@ public static void RAND_add(object buf, double entropy) {

[PythonType]
public class _SSLContext {
private readonly X509Certificate2Collection _cert_store = new X509Certificate2Collection();
private string _cafile;
internal readonly X509Certificate2Collection _cert_store = new X509Certificate2Collection();
internal string _cafile;
private int _verify_mode = SSL_VERIFY_NONE;

public _SSLContext(CodeContext context, int protocol = PROTOCOL_SSLv23) {
public _SSLContext(CodeContext context, int protocol) {
if (protocol != PROTOCOL_SSLv2 && protocol != PROTOCOL_SSLv23 && protocol != PROTOCOL_SSLv3 &&
protocol != PROTOCOL_TLSv1 && protocol != PROTOCOL_TLSv1_1 && protocol != PROTOCOL_TLSv1_2) {
throw PythonOps.ValueError("invalid protocol version");
}

this.protocol = protocol;

if (protocol != PROTOCOL_SSLv2)
options |= OP_NO_SSLv2;
if (protocol != PROTOCOL_SSLv3)
Expand Down Expand Up @@ -175,14 +176,20 @@ public void load_cert_chain(string certfile, string keyfile = null, object passw

}

public void load_verify_locations(CodeContext context, string cafile = null, string capath = null, object cadata = null) {
public void load_verify_locations(CodeContext context, object cafile = null, string capath = null, object cadata = null) {
if (cafile == null && capath == null && cadata == null) {
throw PythonOps.TypeError("cafile, capath and cadata cannot be all omitted");
}

if (cafile != null) {
_cert_store.Add(ReadCertificate(context, cafile));
_cafile = cafile;
if (cafile is not null) {
if (cafile is string s) {
_cafile = s;
} else if (cafile is Bytes b) {
_cafile = b.MakeString();
} else {
throw PythonOps.TypeError("cafile should be a valid filesystem path");
}
_cert_store.Add(ReadCertificate(context, _cafile));
}

if (capath != null) {
Expand All @@ -207,8 +214,8 @@ public void load_verify_locations(CodeContext context, string cafile = null, str
}
}

public object _wrap_socket(CodeContext context, PythonSocket.socket sock, bool server_side, string server_hostname = null, object ssl_sock = null) {
return new _SSLSocket(context, sock, server_side, null, _cafile, verify_mode, protocol | options, null, _cert_store) { _serverHostName = server_hostname };
public object _wrap_socket(CodeContext context, PythonSocket.socket sock, bool server_side, string server_hostname = null) {
return new _SSLSocket(context, this, sock, server_side, server_hostname);
}
}

Expand All @@ -225,34 +232,22 @@ public class _SSLSocket {
private Exception _validationFailure;
internal string _serverHostName;

public _SSLSocket(CodeContext context, PythonSocket.socket sock, string keyfile = null, string certfile = null, X509Certificate2Collection certs = null) {
_context = context;
_sslStream = new SslStream(new NetworkStream(sock._socket, false), true, CertValidationCallback);
_socket = sock;
_protocol = PythonSsl.PROTOCOL_SSLv23 | PythonSsl.OP_NO_SSLv2 | PythonSsl.OP_NO_SSLv3;
_validate = false;
_certCollection = certs ?? new X509Certificate2Collection();
}
public _SSLContext context { get; }

internal _SSLSocket(CodeContext context,
PythonSocket.socket sock,
bool server_side,
string keyfile = null,
string certfile = null,
int certs_mode = PythonSsl.CERT_NONE,
int protocol = (PythonSsl.PROTOCOL_SSLv23 | PythonSsl.OP_NO_SSLv2 | PythonSsl.OP_NO_SSLv3),
string cacertsfile = null,
X509Certificate2Collection certs = null) {
internal _SSLSocket(CodeContext context, _SSLContext sslcontext, PythonSocket.socket sock, bool server_side, string server_hostname) {
if (sock == null) {
throw PythonOps.TypeError("expected socket object, got None");
}

this.context = sslcontext;
_serverSide = server_side;
bool validate;
_certsMode = certs_mode;
_serverHostName = server_hostname;

_certsMode = sslcontext.verify_mode;

bool validate;
RemoteCertificateValidationCallback callback;
switch (certs_mode) {
switch (_certsMode) {
case PythonSsl.CERT_NONE:
validate = false;
callback = CertValidationCallback;
Expand All @@ -266,28 +261,24 @@ internal _SSLSocket(CodeContext context,
callback = CertValidationCallbackRequired;
break;
default:
throw new InvalidOperationException(String.Format("bad certs_mode: {0}", certs_mode));
throw new InvalidOperationException(String.Format("bad certs_mode: {0}", _certsMode));
}

_callback = callback;

if (certs != null) {
_certCollection = certs;
if (sslcontext._cert_store != null) {
_certCollection = sslcontext._cert_store;
}

if (certfile != null) {
_cert = PythonSsl.ReadCertificate(context, certfile);
}

if (cacertsfile != null) {
_certCollection = new X509Certificate2Collection(new[] { PythonSsl.ReadCertificate(context, cacertsfile) });
if (sslcontext._cafile != null) {
_cert = PythonSsl.ReadCertificate(context, sslcontext._cafile);
}

_socket = sock;

EnsureSslStream(false);

_protocol = protocol;
_protocol = sslcontext.protocol | sslcontext.options;
_validate = validate;
_context = context;
}
Expand Down Expand Up @@ -521,7 +512,7 @@ public object peer_certificate(bool binary_form) {

if (peerCert != null) {
if (binary_form) {
return peerCert.GetRawCertData().MakeString();
return Bytes.Make(peerCert.GetRawCertData());
} else if (_validate) {
return CertificateToPython(_context, peerCert);
}
Expand All @@ -548,24 +539,23 @@ public string issuer() {
return String.Empty;
}

[Documentation(@"read([len]) -> string

Read up to len bytes from the SSL socket.")]
public object read(CodeContext/*!*/ context, int len, ByteArray buffer = null) {
[Documentation(@"read(size, [buffer])
Read up to size bytes from the SSL socket.")]
public object read(CodeContext/*!*/ context, int size, ByteArray buffer = null) {
EnsureSslStream(true);

try {
byte[] buf = new byte[2048];
MemoryStream result = new MemoryStream(len);
MemoryStream result = new MemoryStream(size);
while (true) {
int readLength = (len < buf.Length) ? len : buf.Length;
int readLength = (size < buf.Length) ? size : buf.Length;
int bytes = _sslStream.Read(buf, 0, readLength);
if (bytes > 0) {
result.Write(buf, 0, bytes);
len -= bytes;
size -= bytes;
}

if (bytes == 0 || len == 0 || bytes < readLength) {
if (bytes == 0 || size == 0 || bytes < readLength) {
var res = result.ToArray();
if (buffer == null)
return Bytes.Make(res);
Expand Down Expand Up @@ -593,10 +583,9 @@ public string server() {
return String.Empty;
}

[Documentation(@"write(s) -> len
[Documentation(@"Writes the bytes-like object b into the SSL object.

Writes the string s into the SSL object. Returns the number
of bytes written.")]
Returns the number of bytes written.")]
public int write(CodeContext/*!*/ context, Bytes data) {
EnsureSslStream(true);

Expand Down Expand Up @@ -1046,8 +1035,10 @@ private static Exception ErrorDecoding(CodeContext context, params object[] args
public const int PROTOCOL_TLSv1_1 = 4;
public const int PROTOCOL_TLSv1_2 = 5;

public const uint OP_ALL = 0x80000BFF;
public const uint OP_DONT_INSERT_EMPTY_FRAGMENTS = 0x00000800;
public const int OP_ALL = unchecked((int)0x800003FF);
public const int OP_CIPHER_SERVER_PREFERENCE = 0x400000;
public const int OP_SINGLE_DH_USE = 0x100000;
public const int OP_SINGLE_ECDH_USE = 0x80000;
public const int OP_NO_SSLv2 = 0x01000000;
public const int OP_NO_SSLv3 = 0x02000000;
public const int OP_NO_TLSv1 = 0x04000000;
Expand Down
2 changes: 1 addition & 1 deletion Src/IronPythonTest/Cases/CPythonCasesManifest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ RunCondition=$(IS_POSIX)
[CPython.test_sqlite]
Ignore=true

[CPython.test_ssl]
[CPython.test_ssl] # IronPython.test_ssl_stdlib
Ignore=true
Reason=Blocking

Expand Down
6 changes: 5 additions & 1 deletion Src/StdLib/Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def handle_error(prefix):
sys.stdout.write(prefix + exc_format)

def can_clear_options():
if sys.implementation.name == 'ironpython': return True
# 0.9.8m or higher
return ssl._OPENSSL_API_VERSION >= (0, 9, 8, 13, 15)

Expand Down Expand Up @@ -569,7 +570,10 @@ def test_enum_certificates(self):
self.assertTrue(ssl.enum_certificates("ROOT"))

self.assertRaises(TypeError, ssl.enum_certificates)
self.assertRaises(WindowsError, ssl.enum_certificates, "")
if sys.implementation.name == "ironpython":
self.assertEqual(ssl.enum_certificates(""), [])
else:
self.assertRaises(WindowsError, ssl.enum_certificates, "")

trust_oids = set()
for storename in ("CA", "ROOT"):
Expand Down
32 changes: 25 additions & 7 deletions Tests/modules/network_related/test__ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import _ssl
import os
import socket
import sys
import unittest

from iptest import IronPythonTestCase, is_cli, is_netcoreapp, retryOnFailure, run_test, skipUnlessIronPython
Expand All @@ -27,13 +28,22 @@ def test_constants(self):
self.assertEqual(_ssl.CERT_NONE, 0)
self.assertEqual(_ssl.CERT_OPTIONAL, 1)
self.assertEqual(_ssl.CERT_REQUIRED, 2)
self.assertEqual(_ssl.PROTOCOL_SSLv2, 0)
if sys.version_info >= (3,5):
self.assertRaises(AttributeError, lambda: _ssl.PROTOCOL_SSLv2)
else:
self.assertEqual(_ssl.PROTOCOL_SSLv2, 0)
self.assertEqual(_ssl.PROTOCOL_SSLv23, 2)
self.assertEqual(_ssl.PROTOCOL_SSLv3, 1)
if sys.version_info >= (3,7):
self.assertRaises(AttributeError, lambda: _ssl.PROTOCOL_SSLv3)
else:
self.assertEqual(_ssl.PROTOCOL_SSLv3, 1)
self.assertEqual(_ssl.PROTOCOL_TLSv1, 3)
self.assertEqual(_ssl.PROTOCOL_TLSv1_1, 4)
self.assertEqual(_ssl.PROTOCOL_TLSv1_2, 5)
self.assertEqual(_ssl.OP_NO_SSLv2, 0x1000000)
if sys.version_info >= (3,7):
self.assertEqual(_ssl.OP_NO_SSLv2, 0)
else:
self.assertEqual(_ssl.OP_NO_SSLv2, 0x1000000)
self.assertEqual(_ssl.OP_NO_SSLv3, 0x2000000)
self.assertEqual(_ssl.OP_NO_TLSv1, 0x4000000)
self.assertEqual(_ssl.OP_NO_TLSv1_1, 0x10000000)
Expand Down Expand Up @@ -106,7 +116,8 @@ def test_SSLType_ssl(self):
context = _ssl._SSLContext(_ssl.PROTOCOL_SSLv23)
ssl_s = context._wrap_socket(s, False)

ssl_s.shutdown()
if is_cli:
ssl_s.shutdown()
s.close()

#sock, keyfile, certfile
Expand All @@ -133,6 +144,7 @@ def test_SSLType_ssl_neg(self):
#Cleanup
s.close()

@skipUnlessIronPython()
def test_SSLType_issuer(self):
#--Positive
s = socket.socket(socket.AF_INET)
Expand Down Expand Up @@ -165,6 +177,7 @@ def test_SSLType_issuer(self):
ssl_s.shutdown()
s.close()

@skipUnlessIronPython()
def test_SSLType_server(self):
#--Positive
s = socket.socket(socket.AF_INET)
Expand Down Expand Up @@ -206,8 +219,12 @@ def test_SSLType_read_and_write(self):
ssl_s = context._wrap_socket(s, False)
ssl_s.do_handshake()

self.assertIn("Writes the string s into the SSL object.", ssl_s.write.__doc__)
self.assertIn("Read up to len bytes from the SSL socket.", ssl_s.read.__doc__)
if is_cli or sys.version_info >= (3,5):
self.assertIn("Writes the bytes-like object b into the SSL object.", ssl_s.write.__doc__)
self.assertIn("Read up to size bytes from the SSL socket.", ssl_s.read.__doc__)
else:
self.assertIn("Writes the string s into the SSL object.", ssl_s.write.__doc__)
self.assertIn("Read up to len bytes from the SSL socket.", ssl_s.read.__doc__)

#Write
self.assertEqual(ssl_s.write(SSL_REQUEST),
Expand All @@ -225,7 +242,8 @@ def test_SSLType_read_and_write(self):
self.assertIn(SSL_RESPONSE, response)

#Cleanup
ssl_s.shutdown()
if is_cli:
ssl_s.shutdown()
s.close()

def test_parse_cert(self):
Expand Down
2 changes: 1 addition & 1 deletion Tests/test_socket_stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def load_tests(loader, standard_tests, pattern):
suite.addTest(unittest.expectedFailure(test.test_socket.UnbufferedFileObjectClassTestCase('testSmallReadNonBlocking'))) # TODO: figure out
suite.addTest(test.test_socket.UnbufferedFileObjectClassTestCase('testUnbufferedRead'))
suite.addTest(test.test_socket.UnbufferedFileObjectClassTestCase('testUnbufferedReadline'))
suite.addTest(test.test_socket.UnbufferedFileObjectClassTestCase('testWriteNonBlocking'))
#suite.addTest(test.test_socket.UnbufferedFileObjectClassTestCase('testWriteNonBlocking')) # fails intermittently during CI
suite.addTest(test.test_socket.UnicodeReadFileObjectClassTestCase('testAttributes'))
suite.addTest(test.test_socket.UnicodeReadFileObjectClassTestCase('testCloseAfterMakefile'))
suite.addTest(test.test_socket.UnicodeReadFileObjectClassTestCase('testClosedAttr'))
Expand Down
Loading