diff --git a/requirements.txt b/requirements.txt index 63cc203..fd150c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ flask -pykcs11 \ No newline at end of file +pykcs11 +retrying \ No newline at end of file diff --git a/setup.py b/setup.py index 494a67c..20c2a9c 100755 --- a/setup.py +++ b/setup.py @@ -9,8 +9,9 @@ version = '0.0.1' install_requires = [ - 'flask', - 'pykcs11' + 'flask', + 'pykcs11', + 'retrying' ] setup(name='pyeleven', diff --git a/src/pyeleven/__init__.py b/src/pyeleven/__init__.py index ec996c1..22e0bbd 100644 --- a/src/pyeleven/__init__.py +++ b/src/pyeleven/__init__.py @@ -1,48 +1,76 @@ +from base64 import b64decode from flask import Flask, request, jsonify -from .pk11 import pkcs11, intarray2bytes, mechanism, find_key, library +from .pk11 import pkcs11, load_library, slots_for_label +from .utils import mechanism, intarray2bytes import os -import sys import logging +from .pool import allocation +from retrying import retry +from PyKCS11 import PyKCS11Error __author__ = 'leifj' app = Flask(__name__) app.debug = True app.config.from_pyfile(os.path.join(os.getcwd(), 'config.py')) -app.secret_key = app.config.get("SECRET_KEY") -print app.config +max_retry = app.config.get('MAX_RETRY', 7) + + +def pin(): + return app.config.get('PKCS11PIN', None) + + +def secret_key(): + return app.config.get("SECRET_KEY") + + +def library_name(): + return str(app.config['PKCS11MODULE']) + + +#print app.config logging.basicConfig(level=logging.DEBUG) @app.route("/info") def _info(): - libn = app.config['PKCS11MODULE'] - return jsonify(dict(library=libn)) + return jsonify(dict(library=library_name())) + +@retry(stop_max_attempt_number=max_retry) +def _do_sign(label, keyname, mech, data, include_cert=True, require_cert=False): + if require_cert: + include_cert = True -def _find_slot(label): - slots = [] - lib = library(app.config['PKCS11MODULE']) - for slot in lib.getSlotList(): - slot_info = lib.getSlotInfo(slot) - if slot_info.get('label') == label: - slots.append(slot) - return slots + with pkcs11(library_name(), label, pin()) as si: + key, cert = si.find_key(keyname, find_cert=include_cert) + assert key is not None + result = dict(slot=label,signed=intarray2bytes(si.session.sign(key, data, mech)).encode('base64')) + if require_cert: + assert cert is not None + if cert and include_cert: + result['cert'] = cert + return result @app.route("///sign", methods=['POST']) def _sign(slot_or_label, keyname): - slots = [] - try: - slot = int(slot_or_label) - slots = [slot] - except ValueError: - slots = _find_slot(slot_or_label) + msg = request.get_json() + if not type(msg) is dict: + raise ValueError("request must be a dict") + + msg.setdefault('mech', 'RSAPKCS1') + if 'data' not in msg: + raise ValueError("missing 'data' in request") + data = b64decode(msg['data']) + mech = mechanism(msg['mech']) + return jsonify(_do_sign(slot_or_label, keyname, mech, data, require_cert=True)) + - if not slots: - raise ValueError("No slot found matching %s" % slot_or_label) +@app.route("///rawsign", methods=['POST']) +def _rawsign(slot_or_label, keyname): msg = request.get_json() if not type(msg) is dict: @@ -51,57 +79,47 @@ def _sign(slot_or_label, keyname): msg.setdefault('mech', 'RSAPKCS1') if 'data' not in msg: raise ValueError("missing 'data' in request") - data = msg['data'].decode('base64') - libn = app.config['PKCS11MODULE'] + data = b64decode(msg['data']) mech = mechanism(msg['mech']) - pin = app.config.get('PKCS11PIN', None) - for slot in slots: - try: - with pkcs11(libn, slot, pin=pin) as si: - key, cert = find_key(si, keyname) - assert key is not None - assert cert is not None - return jsonify(dict(slot=slot, - mech=msg['mech'], - signed=intarray2bytes(si.session.sign(key, data, mech)).encode('base64'))) - except Exception, ex: - logging.error(ex) - with pkcs11(libn, slot, pin=pin) as si: - si.exception = ex # invalidate it - - raise ValueError("Unable to sign using any of the matching slots") + return jsonify(_do_sign(slot_or_label, keyname, mech, data, include_cert=False)) @app.route("/", methods=['GET']) def _slot(slot_or_label): - slot = -1 - try: - slot = int(slot_or_label) - except ValueError: - slot = _find_slot(slot_or_label) + lib = load_library(library_name()) + slots = slots_for_label(slot_or_label, lib) + result = [] + for slot in slots: + r = dict() + try: + r['mechanisms'] = lib.getMechanismList(slot) + except PyKCS11Error, ex: + r['mechanisms'] = {'error': str(ex)} + try: + r['slot'] = lib.getSlotInfo(slot).to_dict() + except PyKCS11Error, ex: + r['slot'] = {'error': str(ex)} + try: + r['token'] = lib.getTokenInfo(slot).to_dict() + except PyKCS11Error, ex: + r['token'] = {'error': str(ex)} - lib = library(app.config['PKCS11MODULE']) - r = dict() - try: - r['mechanisms'] = lib.getMechanismList(slot) - except: - pass - try: - r['slot'] = lib.getSlotInfo(slot).to_dict() - except: - pass - try: - r['token'] = lib.getTokenInfo(slot).to_dict() - except: - pass - return jsonify(r) + result.append(r) + + return jsonify(dict(slots=result)) @app.route("/", methods=['GET']) def _token(): - lib = library(app.config['PKCS11MODULE']) + lib = load_library(library_name()) r = dict() + token_labels = dict() r['slots'] = lib.getSlotList() + for slot in r['slots']: + ti = lib.getTokenInfo(slot) + lst = token_labels.setdefault(ti.label.strip(), []) + lst.append(slot) + r['labels'] = token_labels return jsonify(r) diff --git a/src/pyeleven/pk11.py b/src/pyeleven/pk11.py index 541aaaf..11edb0c 100644 --- a/src/pyeleven/pk11.py +++ b/src/pyeleven/pk11.py @@ -1,14 +1,22 @@ -import base64 -from collections import namedtuple import threading +from .pool import ObjectPool, allocation +from .utils import intarray2bytes, cert_der2pem +from random import Random +import time +import logging +import PyKCS11 +from PyKCS11.LowLevel import CKA_ID, \ + CKA_LABEL, \ + CKA_CLASS, \ + CKO_PRIVATE_KEY, \ + CKO_CERTIFICATE, \ + CKK_RSA, \ + CKA_KEY_TYPE, \ + CKA_VALUE __author__ = 'leifj' -import logging -import PyKCS11 -from PyKCS11.LowLevel import CKA_ID, CKA_LABEL, CKA_CLASS, CKO_PRIVATE_KEY, CKO_CERTIFICATE, CKK_RSA, \ - CKA_KEY_TYPE, CKA_VALUE all_attributes = PyKCS11.CKA.keys() @@ -22,7 +30,6 @@ all_attributes = [e for e in all_attributes if isinstance(e, int)] thread_data = threading.local() -_session_lock = threading.RLock() def _modules(): @@ -37,15 +44,25 @@ def _sessions(): return thread_data.sessions -def mechanism(mech): - mn = "Mechanism%s" % mech - return getattr(PyKCS11, mn) +def _pools(): + if not hasattr(thread_data, 'pools'): + thread_data.pools = dict() + return thread_data.pools + + +def reset(): + _pools() + _sessions() + _modules() + thread_data.pools = dict() + thread_data.sessions = dict() + thread_data.modules = dict() -def library(lib_name): +def load_library(lib_name): modules = _modules() - if not lib_name in modules: - logging.debug("loading library %s" % lib_name) + if lib_name not in modules: + logging.debug("loading load_library %s" % lib_name) lib = PyKCS11.PyKCS11Lib() assert type(lib_name) == str # lib.load does not like unicode lib.load(lib_name) @@ -54,80 +71,133 @@ def library(lib_name): return modules[lib_name] -SessionInfo = namedtuple('SessionInfo', ['session', 'keys']) -class pkcs11(): +class SessionInfo(object): - def __init__(self, library, slot, pin=None): - self.library = library + def __init__(self, session, slot): + self.session = session self.slot = slot - self.pin = pin - self.exception = None - - def __enter__(self): - s = _sessions() - if self.library not in s: - s.setdefault(self.library, dict()) - - if self.slot not in s[self.library] or self.exception is not None: - s[self.library].setdefault(self.slot, dict()) - - lib = library(self.library) - session = lib.openSession(self.slot) - if self.pin is not None: - session.login(self.pin) - s[self.library][self.slot] = SessionInfo(session=session, keys=dict()) - - if self.slot not in s[self.library]: - raise EnvironmentError("Unable to open session") - - return s[self.library][self.slot] - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - -def intarray2bytes(x): - return ''.join(chr(i) for i in x) - - -def find_object(session, template): - for o in session.session.findObjects(template): - logging.debug("Found pkcs11 object: %s" % o) - return o - return None - - -def get_object_attributes(session, o): - attributes = session.session.getAttributeValue(o, all_attributes) - return dict(zip(all_attributes, attributes)) - - -def cert_der2pem(der): - x = base64.standard_b64encode(der) - r = "-----BEGIN CERTIFICATE-----\n" - while len(x) > 64: - r += x[0:64] - r += "\n" - x = x[64:] - r += x - r += "\n" - r += "-----END CERTIFICATE-----" - return r - - -def find_key(session, keyname): - if keyname not in session.keys: - key = find_object(session, [(CKA_LABEL, keyname), (CKA_CLASS, CKO_PRIVATE_KEY), (CKA_KEY_TYPE, CKK_RSA)]) - if key is None: - return None, None - key_a = get_object_attributes(session, key) - cert = find_object(session, [(CKA_ID, key_a[CKA_ID]), (CKA_CLASS, CKO_CERTIFICATE)]) - cert_pem = None - if cert is not None: - cert_a = get_object_attributes(session, cert) - cert_pem = cert_der2pem(intarray2bytes(cert_a[CKA_VALUE])) - logging.debug(cert) - session.keys[keyname] = (key, cert_pem) + self.keys = {} + self.use_count = 0 + + @property + def priority(self): + return self.use_count + + def __str__(self): + return "SessionInfo[session=%s,slot=%d,use_count=%d,keys=%d]" % (self.session, self.slot, self.use_count, len(self.keys)) + + def __cmp__(self, other): + return cmp(self.use_count, other.use_count) + + def find_object(self, template): + for o in self.session.findObjects(template): + logging.debug("Found pkcs11 object: %s" % o) + return o + return None + + def get_object_attributes(self, o): + attributes = self.session.getAttributeValue(o, all_attributes) + return dict(zip(all_attributes, attributes)) + + def find_key(self, keyname, find_cert=True): + if keyname not in self.keys: + key = self.find_object([(CKA_LABEL, keyname), (CKA_CLASS, CKO_PRIVATE_KEY), (CKA_KEY_TYPE, CKK_RSA)]) + if key is None: + return None, None + key_a = self.get_object_attributes(key) + cert_pem = None + if find_cert: + cert = self.find_object([(CKA_ID, key_a[CKA_ID]), (CKA_CLASS, CKO_CERTIFICATE)]) + if cert is not None: + cert_a = self.get_object_attributes(cert) + cert_pem = cert_der2pem(intarray2bytes(cert_a[CKA_VALUE])) + logging.debug(cert) + self.keys[keyname] = (key, cert_pem) + + return self.keys[keyname] + + @staticmethod + def open(lib, slot, pin=None): + sessions = _sessions() + if slot not in sessions: + session = lib.openSession(slot) + if pin is not None: + session.login(pin) + si = SessionInfo(session=session, slot=slot) + sessions[slot] = si + #print "opened session for %s:%d" % (lib, slot) + return sessions[slot] + + @staticmethod + def close_slot(slot): + sessions = _sessions() + if slot in sessions: + del sessions[slot] + + def close(self): + SessionInfo.close_slot(self.slot) + + +def _find_slot(label, lib): + slots = [] + for slot in lib.getSlotList(): + token_info = lib.getTokenInfo(slot) + if label == token_info.label.strip(): + slots.append(int(slot)) + return slots + + +def slots_for_label(label, lib): + try: + slot = int(label) + return [slot] + except ValueError: + return _find_slot(label, lib) + +seed = Random(time.time()) + + +def pkcs11(library_name, label, pin=None, low_mark=1): + pools = _pools() + sessions = _sessions() + + max_slots = len(slots_for_label(label, load_library(library_name))) + + def _del(*args, **kwargs): + si = args[0] + sd = kwargs['slots'] + if si.slot in sd: + del sd[si.slot] + si.close() + + def _bump(si): + si.use_count += 1 + + def _get(*args, **kwargs): + lib = load_library(library_name) + sd = kwargs['slots'] + + def _refill(): # if sd is getting a bit light - fill it back up + if len(sd) < low_mark: + for slot in slots_for_label(label, lib): + #print "found slot %d during refill" % slot + sd[slot] = True + + random_slot = None + while True: + _refill() + k = sd.keys() + random_slot = seed.choice(k) + #print random_slot + try: + return SessionInfo.open(lib, random_slot, pin) + except Exception, ex: # on first suspicion of failure - force the slot to be recreated + if random_slot in sd: + del sd[random_slot] + SessionInfo.close_slot(random_slot) + time.sleep(50/1000) # TODO - make retry delay configurable + logging.error(ex) + + return allocation(pools.setdefault(label, ObjectPool(_get, _del, _bump, maxSize=max_slots, slots=dict()))) - return session.keys[keyname] diff --git a/src/pyeleven/pool.py b/src/pyeleven/pool.py new file mode 100644 index 0000000..3f3d495 --- /dev/null +++ b/src/pyeleven/pool.py @@ -0,0 +1,51 @@ +# -*- coding:utf-8 -*- +try: + import Queue as Q +except ImportError: + import queue as Q + +from contextlib import contextmanager + + +class ObjectPool(object): + """A simple thread safe object pool""" + + def __init__(self, create, destroy, bump, *args, **kwargs): + super(ObjectPool, self).__init__() + self.create = create + self.destroy = destroy + self.bump = bump + self.args = args + self.kwargs = kwargs + self.maxSize = int(kwargs.get("maxSize", 1)) + self.queue = Q.PriorityQueue() + + def alloc(self): + if self.queue.qsize() < self.maxSize and self.queue.empty(): + n = self.maxSize - self.queue.qsize() + for i in range(0, n): # try to allocate enough objects to fill to maxSize + obj = self.create(*self.args, **self.kwargs) + #print "allocated %s" % obj + self.queue.put(obj) + return self.queue.get() + + def free(self, obj): + self.queue.put(obj) + + def invalidate(self, obj): + self.destroy(obj, *self.args, **self.kwargs) + + +@contextmanager +def allocation(pool): + obj = pool.alloc() + try: + yield obj + except Exception, e: + pool.invalidate(obj) + obj = None + raise e + finally: + if obj is not None: + pool.bump(obj) + pool.free(obj) diff --git a/src/pyeleven/test/__init__.py b/src/pyeleven/test/__init__.py index 1e72400..5dc310d 100644 --- a/src/pyeleven/test/__init__.py +++ b/src/pyeleven/test/__init__.py @@ -1,23 +1,30 @@ """ Testing the PKCS#11 shim layer """ -from flask import json -from .. import mechanism, intarray2bytes, find_key +from base64 import b64encode -__author__ = 'leifj' +from flask import json +from retrying import retry +from pyeleven.pk11 import slots_for_label, load_library +from ..utils import mechanism, intarray2bytes import pkg_resources import unittest import logging import os +import time +from shutil import copyfile import traceback import subprocess import tempfile -from PyKCS11 import PyKCS11Error -from PyKCS11.LowLevel import CKR_PIN_INCORRECT from .. import pk11 from unittest import TestCase -from .. import app +from .utils import ThreadPool +from threading import Thread +import random + +__author__ = 'leifj' + def _find_alts(alts): for a in alts: @@ -25,6 +32,7 @@ def _find_alts(alts): return a return None + P11_MODULE = _find_alts(['/usr/lib/libsofthsm.so', '/usr/lib/softhsm/libsofthsm.so']) P11_ENGINE = _find_alts(['/usr/lib/engines/engine_pkcs11.so']) P11_SPY = _find_alts(['/usr/lib/pkcs11/pkcs11-spy.so']) @@ -33,7 +41,6 @@ def _find_alts(alts): SOFTHSM = _find_alts(['/usr/bin/softhsm']) OPENSSL = _find_alts(['/usr/bin/openssl']) - if OPENSSL is None: raise unittest.SkipTest("OpenSSL not installed") @@ -53,7 +60,27 @@ def _find_alts(alts): softhsm_conf = None server_cert_pem = None server_cert_der = None -softhsm_db = None +softhsm_db_1 = None +softhsm_db_2 = None + + +def disable_tf(fn): + if os.path.exists(fn): + try: + os.rename(fn, "%s.bak" % fn) + except IOError, ex: + pass + else: + print "%s is gone!" % fn + + +def enable_tf(fn): + fn_old = "%s.bak" % fn + if os.path.exists(fn_old): + try: + os.rename(fn_old, fn) + except IOError, ex: + pass def _tf(): @@ -66,7 +93,7 @@ def _p(args): env = {} if softhsm_conf is not None: env['SOFTHSM_CONF'] = softhsm_conf - #print "env SOFTHSM_CONF=%s " % softhsm_conf +" ".join(args) + # print "env SOFTHSM_CONF=%s " % softhsm_conf +" ".join(args) proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) out, err = proc.communicate() if err is not None and len(err) > 0: @@ -77,17 +104,22 @@ def _p(args): if rv: raise RuntimeError("command exited with code != 0: %d" % rv) + @unittest.skipIf(P11_MODULE is None, "SoftHSM PKCS11 module not installed") def setup(): logging.debug("Creating test pkcs11 token using softhsm") try: global softhsm_conf - softhsm_db = _tf() + global softhsm_db_1 + global softhsm_db_2 + + softhsm_db_1 = _tf() + softhsm_db_2 = _tf() softhsm_conf = _tf() logging.debug("Generating softhsm.conf") with open(softhsm_conf, "w") as f: - f.write("#Generated by pyXMLSecurity test\n0:%s\n" % softhsm_db) + f.write("#Generated by pyeleven test\n0:%s\n1: %s\n" % (softhsm_db_1, softhsm_db_2)) logging.debug("Initializing the token") _p([SOFTHSM, '--slot', '0', @@ -169,6 +201,8 @@ def setup(): '-w', signer_cert_der, '--pin', 'secret1']) + copyfile(softhsm_db_1, softhsm_db_2) + except Exception, ex: traceback.print_exc() logging.warning("PKCS11 tests disabled: unable to initialize test token: %s" % ex) @@ -183,6 +217,7 @@ def teardown(self): class FlaskTestCase(TestCase): def setUp(self): + from .. import app os.environ['SOFTHSM_CONF'] = softhsm_conf app.config['TESTING'] = True app.config['PKCS11MODULE'] = P11_MODULE @@ -200,12 +235,37 @@ def test_info(self): def test_sign(self): rv = self.app.post("/0/test/sign", content_type='application/json', - data=json.dumps(dict(mech='RSAPKCS1', data="test".encode('base64')))) + data=json.dumps(dict(mech='RSAPKCS1', data=b64encode("test")))) + assert rv.data + d = json.loads(rv.data) + assert d is not None + assert 'slot' in d + assert 'signed' in d + + def test_1000_sign(self): + ts = time.time() + for i in range(0, 999): + rv = self.app.post("/test/test/sign", + content_type='application/json', + data=json.dumps(dict(mech='RSAPKCS1', data=b64encode("test")))) + assert rv.data + d = json.loads(rv.data) + assert d is not None + assert 'slot' in d + assert 'signed' in d + te = time.time() + print "1000 signatures (http): %2.3f sec (speed: %2.5f s/sig)" % (te - ts, (te - ts) / 1000) + + def test_label_sign(self): + rv = self.app.post("/test/test/sign", + content_type='application/json', + data=json.dumps(dict(mech='RSAPKCS1', data=b64encode("test")))) assert rv.data d = json.loads(rv.data) assert d is not None assert 'slot' in d assert 'signed' in d + assert d['signed'] def test_bad_sign_request(self): try: @@ -220,13 +280,32 @@ def test_slot_info(self): rv = self.app.get("/0") assert rv.data d = json.loads(rv.data) - print d + assert d + assert 'slots' in d + for nfo in d['slots']: + assert 'mechanisms' in nfo + assert 'slot' in nfo + assert 'token' in nfo + assert 'manufacturerID' in nfo['slot'] + assert 'SoftHSM' in nfo['slot']['manufacturerID'] + assert 'label' in nfo['token'] + assert 'test' in nfo['token']['label'] def test_token_info(self): rv = self.app.get("/") assert rv.data d = json.loads(rv.data) - print d + assert d + assert 'slots' in d + assert len(d['slots']) == 2 + assert 1 in d['slots'] + assert 0 in d['slots'] + assert 'labels' in d + assert 'test' in d['labels'] + test_slots = d['labels']['test'] + assert len(test_slots) == 2 + assert 1 in test_slots + assert 0 in test_slots class TestPKCS11(unittest.TestCase): @@ -235,21 +314,138 @@ def setUp(self): def test_open_session(self): os.environ['SOFTHSM_CONF'] = softhsm_conf + pk11.reset() with pk11.pkcs11(P11_MODULE, 0, "secret1") as session: assert session is not None + def test_multislot(self): + os.environ['SOFTHSM_CONF'] = softhsm_conf + pk11.reset() + lib = load_library(P11_MODULE) + slots = slots_for_label('test', lib) + assert 1 in slots + assert 0 in slots + assert len(slots) == 2 + def test_find_key(self): os.environ['SOFTHSM_CONF'] = softhsm_conf - with pk11.pkcs11(P11_MODULE, 0, "secret1") as session: - print session - assert session is not None - key, cert = find_key(session, 'test') + pk11.reset() + with pk11.pkcs11(P11_MODULE, 0, "secret1") as si: + assert si is not None + key, cert = si.find_key('test') + assert key is not None + assert cert is not None + + def test_find_key_spread(self): + os.environ['SOFTHSM_CONF'] = softhsm_conf + pk11.reset() + hits = {0: 0, 1: 0} + + @retry(stop_max_attempt_number=20) + def _try_sign(): + with pk11.pkcs11(P11_MODULE, 'test', "secret1") as si: + assert si is not None + key, cert = si.find_key('test') + assert key is not None + assert cert is not None + assert si.slot is not None + hits[si.slot] += 1 + if si.slot == random.choice([0, 1]): + raise ValueError("force a retry...") + + for i in range(0, 99): + _try_sign() + + assert hits[0] > 30 + assert hits[1] > 30 + + def test_find_key_by_label(self): + os.environ['SOFTHSM_CONF'] = softhsm_conf + pk11.reset() + with pk11.pkcs11(P11_MODULE, 'test', "secret1") as si: + assert si is not None + key, cert = si.find_key('test') assert key is not None assert cert is not None + def test_exception_reopen_session(self): + os.environ['SOFTHSM_CONF'] = softhsm_conf + pk11.reset() + for i in range(0, 10): + try: + with pk11.pkcs11(P11_MODULE, 'test', "secret1") as si: + assert si is not None + raise ValueError("oops...") + except ValueError: + pass + def test_sign(self): os.environ['SOFTHSM_CONF'] = softhsm_conf - with pk11.pkcs11(P11_MODULE, 0, "secret1") as si: - key, cert = find_key(si, 'test') + pk11.reset() + with pk11.pkcs11(P11_MODULE, 'test', "secret1") as si: + key, cert = si.find_key('test') signed = intarray2bytes(si.session.sign(key, 'test', mechanism('RSAPKCS1'))) - assert signed is not None \ No newline at end of file + assert signed is not None + + def test_1000_sign(self): + os.environ['SOFTHSM_CONF'] = softhsm_conf + pk11.reset() + ts = time.time() + for i in range(0, 999): + with pk11.pkcs11(P11_MODULE, 'test', "secret1") as si: + key, cert = si.find_key('test') + signed = intarray2bytes(si.session.sign(key, 'test', mechanism('RSAPKCS1'))) + assert signed is not None + te = time.time() + print "1000 signatures (p11): %2.3f sec (speed: %2.5f sec/s)" % (te - ts, (te - ts) / 1000) + + def test_stress_sign_sequential(self): + os.environ['SOFTHSM_CONF'] = softhsm_conf + pk11.reset() + + def _sign(msg): + with pk11.pkcs11(P11_MODULE, 0, "secret1") as si: + key, cert = si.find_key('test') + signed = intarray2bytes(si.session.sign(key, msg, mechanism('RSAPKCS1'))) + assert signed is not None + + for i in range(0, 999): + _sign("message %d" % i) + + def test_stress_sign_parallell_20(self): + os.environ['SOFTHSM_CONF'] = softhsm_conf + pk11.reset() + + def _sign(msg): + with pk11.pkcs11(P11_MODULE, 'test', "secret1") as si: + key, cert = si.find_key('test', find_cert=False) + signed = intarray2bytes(si.session.sign(key, msg, mechanism('RSAPKCS1'))) + assert signed is not None + + ts = time.time() + tp = ThreadPool(20) + for i in range(0, 999): + tp.add_task(_sign, "message %d" % i) + tp.wait_completion() + te = time.time() + print "1000 signatures (p11 parallell): %2.3f sec (speed: %2.5f sec/s)" % (te - ts, (te - ts) / 1000) + + def test_stress_sign_parallell_20_with_failovers(self): + os.environ['SOFTHSM_CONF'] = softhsm_conf + pk11.reset() + + @retry(stop_max_attempt_number=10) + def _sign(i): + msg = "message %d" % i + with pk11.pkcs11(P11_MODULE, 'test', "secret1") as si: + key, _ = si.find_key('test', find_cert=False) + signed = intarray2bytes(si.session.sign(key, msg, mechanism('RSAPKCS1'))) + assert signed is not None + + ts = time.time() + tp = ThreadPool(20) + for i in range(0, 999): # simulate 10 failures on each slot + tp.add_task(_sign, i) + tp.wait_completion() + te = time.time() + print "1000 signatures (p11 parallell): %2.3f sec (speed: %2.5f sec/s)" % (te - ts, (te - ts) / 1000) \ No newline at end of file diff --git a/src/pyeleven/test/utils.py b/src/pyeleven/test/utils.py new file mode 100644 index 0000000..5d05810 --- /dev/null +++ b/src/pyeleven/test/utils.py @@ -0,0 +1,35 @@ +from Queue import Queue +from threading import Thread + + +# from http://code.activestate.com/recipes/577187-python-thread-pool/ + +class Worker(Thread): + """Thread executing tasks from a given tasks queue""" + def __init__(self, tasks): + Thread.__init__(self) + self.tasks = tasks + self.daemon = True + self.start() + + def run(self): + while True: + func, args, kargs = self.tasks.get() + try: func(*args, **kargs) + except Exception, e: print e + self.tasks.task_done() + + +class ThreadPool: + """Pool of threads consuming tasks from a queue""" + def __init__(self, num_threads): + self.tasks = Queue(num_threads) + for _ in range(num_threads): Worker(self.tasks) + + def add_task(self, func, *args, **kargs): + """Add a task to the queue""" + self.tasks.put((func, args, kargs)) + + def wait_completion(self): + """Wait for completion of all the tasks in the queue""" + self.tasks.join() \ No newline at end of file diff --git a/src/pyeleven/utils.py b/src/pyeleven/utils.py new file mode 100644 index 0000000..98cf33d --- /dev/null +++ b/src/pyeleven/utils.py @@ -0,0 +1,24 @@ +import base64 +import PyKCS11 + + +def intarray2bytes(x): + return ''.join(chr(i) for i in x) + + +def mechanism(mech): + mn = "Mechanism%s" % mech + return getattr(PyKCS11, mn) + + +def cert_der2pem(der): + x = base64.standard_b64encode(der) + r = "-----BEGIN CERTIFICATE-----\n" + while len(x) > 64: + r += x[0:64] + r += "\n" + x = x[64:] + r += x + r += "\n" + r += "-----END CERTIFICATE-----" + return r