In [1]:
from enum import Enum
from hmac import compare_digest
from os import urandom

from Crypto.Cipher import AES
from Crypto.Hash import HMAC, SHA256
from gostcrypto import gosthmac
from pygost import gost3410
from pygost import gost34112012256
from pygost.gost3412 import GOST3412Kuznechik
from pygost.mgm import MGM
from pygost.mgm import nonce_prepare


class CipherSuiteEnum(Enum):
    """Identifiers for the cipher suites that can be named in a handshake offer."""
    DH_AES_128_GCM_SHA256 = 1
    ECDH_KUZNYECHIK_MGM_STREEBOG = 2


class CipherSuite:
    """Interface shared by the concrete cipher suites.

    Every method here is a placeholder that returns ``None``; subclasses
    override them with real implementations.
    """

    def encrypt(self, k, m):
        """Encrypt message ``m`` under key ``k`` (placeholder, returns ``None``)."""
        return None

    def decrypt(self, k, c, tag, nonce):
        """Decrypt ciphertext ``c`` using ``tag`` and ``nonce`` (placeholder, returns ``None``)."""
        return None

    def mac(self, data, key):
        """Compute a MAC over ``data`` with ``key`` (placeholder, returns ``None``)."""
        return None

    def check_mac(self, mac, data, key):
        """Check ``mac`` against ``data`` and ``key`` (placeholder, returns ``None``)."""
        return None


class DH_AES_128_GCM_SHA256(CipherSuite):
    """Cipher suite built on AES in GCM mode and HMAC-SHA256."""

    def get_enum(self):
        """Return the ``CipherSuiteEnum`` member that identifies this suite."""
        return CipherSuiteEnum.DH_AES_128_GCM_SHA256

    def mac(self, data, key):
        """Return the HMAC-SHA256 tag of ``data`` under ``key``."""
        h = HMAC.new(key, digestmod=SHA256)
        h.update(data)
        return h.digest()

    def check_mac(self, mac, data, key):
        """Return True when ``mac`` is the HMAC-SHA256 tag of ``data`` under ``key``.

        The comparison uses ``hmac.compare_digest`` so that its running time
        does not depend on how many leading bytes of the tag match.
        """
        expected = self.mac(data, key)
        # A plain ``==`` returned False for non-bytes input; keep that contract
        # instead of letting compare_digest raise TypeError.
        if not isinstance(mac, (bytes, bytearray, memoryview)):
            return False
        return compare_digest(expected, bytes(mac))

    def encrypt(self, k, m):
        """Encrypt ``m`` under key ``k`` with a fresh random 16-byte nonce.

        Returns:
            A ``(ciphertext, tag, nonce)`` tuple; this is the argument order
            that ``decrypt`` expects after the key.
        """
        nonce = urandom(16)
        cipher = AES.new(k, AES.MODE_GCM, nonce)
        ct = cipher.encrypt(m)
        tag = cipher.digest()
        return ct, tag, nonce

    def decrypt(self, k, c, tag, nonce):
        """Decrypt ``c`` and verify ``tag``; returns the plaintext.

        Tag verification is delegated to ``decrypt_and_verify``.
        """
        cipher = AES.new(k, AES.MODE_GCM, nonce)
        return cipher.decrypt_and_verify(c, tag)


class ECDH_KUZNYECHIK_MGM_STREEBOG(CipherSuite):
    """Cipher suite built on Kuznechik in MGM mode and HMAC_GOSTR3411_2012_256."""

    def get_enum(self):
        """Return the ``CipherSuiteEnum`` member that identifies this suite."""
        return CipherSuiteEnum.ECDH_KUZNYECHIK_MGM_STREEBOG

    def mac(self, data, key):
        """Return the HMAC_GOSTR3411_2012_256 tag of ``data`` under ``key``."""
        mac_obj = gosthmac.new('HMAC_GOSTR3411_2012_256', key, data=data)
        return mac_obj.digest()

    def check_mac(self, mac, data, key):
        """Return True when ``mac`` is the tag of ``data`` under ``key``.

        The comparison uses ``hmac.compare_digest`` so that its running time
        does not depend on how many leading bytes of the tag match.
        """
        expected = self.mac(data, key)
        # A plain ``==`` returned False for non-bytes input; keep that contract
        # instead of letting compare_digest raise TypeError.
        if not isinstance(mac, (bytes, bytearray, memoryview)):
            return False
        return compare_digest(bytes(expected), bytes(mac))

    def encrypt(self, k, m):
        """Seal ``m`` under key ``k`` with a fresh random nonce and no associated data.

        Returns:
            A ``(ciphertext, tag, nonce)`` tuple, where the tag is the last
            ``mgm.tag_size`` bytes of the sealed output and ``nonce`` is the
            value already passed through ``nonce_prepare``.
        """
        mgm = MGM(GOST3412Kuznechik(k).encrypt, GOST3412Kuznechik.blocksize)
        nonce = nonce_prepare(urandom(16))
        seal = mgm.seal(nonce, m, b'')
        ct = seal[:-mgm.tag_size]
        tag = seal[-mgm.tag_size:]
        return ct, tag, nonce

    def decrypt(self, k, c, tag, nonce):
        """Open ``c + tag`` under key ``k`` and return the plaintext.

        The nonce is passed through ``nonce_prepare`` again before use.
        """
        mgm = MGM(GOST3412Kuznechik(k).encrypt, GOST3412Kuznechik.blocksize)
        nonce = nonce_prepare(nonce)
        return mgm.open(nonce, c + tag, b'')

In [2]:
# Tests

key = b'\x11' * 32
message = b'hello'

# Same three checks for every suite: encrypt/decrypt round trip,
# MAC round trip, and the suite's enum identifier.
for cipher_suite in (DH_AES_128_GCM_SHA256(), ECDH_KUZNYECHIK_MGM_STREEBOG()):
    ct, tag, nonce = cipher_suite.encrypt(key, message)
    print(cipher_suite.decrypt(key, ct, tag, nonce))
    mac = cipher_suite.mac(message, key)
    print(cipher_suite.check_mac(mac, message, key))
    print(cipher_suite.get_enum())

b'hello'
True
CipherSuiteEnum.DH_AES_128_GCM_SHA256
b'hello'
True
CipherSuiteEnum.ECDH_KUZNYECHIK_MGM_STREEBOG


In [3]:
from collections import namedtuple
from pygost.gost3410 import CURVES


# Elliptic curve shared by every key pair, signature and ECDH exchange below.
curve = CURVES["id-tc26-gost-3410-12-512-paramSetA"]

# Lightweight record types exchanged between the CA, clients and servers.
Certificate = namedtuple('Certificate', 'id public_key signature')
KeyPair = namedtuple('KeyPair', 'public_key private_key')
Offer = namedtuple('Offer', 'cipher_suite group generator')
Mode = namedtuple('Mode', 'cipher_suite')

def generate_point():
    """Return the public point derived from a fresh 64-byte random scalar on ``curve``."""
    return gost3410.public_key(curve, gost3410.prv_unmarshal(urandom(64)))

def generate_key_pair():
    """Create a fresh key pair on ``curve``.

    The private key is unmarshalled from 64 random bytes and the public key is
    derived from it.

    Returns:
        KeyPair: ``(public_key, private_key)``.
    """
    # Use the module-qualified names from the notebook's first import cell so
    # this function does not depend on imports made in a later cell.
    prv = gost3410.prv_unmarshal(urandom(64))
    return KeyPair(gost3410.public_key(curve, prv), prv)

In [34]:
import pickle
from uuid import uuid4
from Crypto.Random import random
from Crypto.Protocol.KDF import HKDF
from Crypto.Util.number import getPrime, getRandomRange, getRandomNBitInteger, long_to_bytes
from pygost.gost3410 import verify
from gostcrypto import gosthmac


class Participant:
    """A party that owns a key pair, a UUID and a certificate issued by a CA.

    The certificate is requested from ``ca`` during construction; the CA in
    turn calls ``auth_get_nonce`` / ``auth_get_result`` to run a Schnorr-style
    proof that this participant knows its private key.
    """

    def __init__(self, ca):
        self._ca = ca
        self._key_pair = generate_key_pair()
        self._id = uuid4()
        self._certificate = ca.issue_certificate(self)

    def get_id(self):
        """Return this participant's UUID."""
        return self._id

    # Returns (Q, P).
    def auth_get_nonce(self):
        """Start the proof of key possession.

        Returns:
            ``(Q, P)`` where ``P`` is the curve's base point and ``Q = n * P``
            for a freshly drawn secret nonce ``n``.
        """
        self._auth_P = curve.x, curve.y
        # The nonce is drawn from [1, q). A 64-bit nonce would be dwarfed by
        # c * private_key in the response, so dividing the response by c
        # would reveal the private key to the verifier.
        self._auth_n = getRandomRange(1, curve.q)
        return curve.exp(self._auth_n, *self._auth_P), self._auth_P

    def auth_get_result(self, c):
        """Answer challenge ``c`` with ``(n + c * private_key) mod q``.

        Reducing modulo ``curve.q`` keeps the nonce masking the private key and
        leaves ``response * P`` unchanged, so the verifier's check
        ``response * P == Q + c * public_key`` is unaffected.
        """
        return (self._auth_n + c * self._key_pair.private_key) % curve.q

    def get_public_key(self):
        """Return this participant's public key."""
        return self._key_pair.public_key

    def get_certificate(self):
        """Return the certificate obtained from the CA at construction time."""
        return self._certificate

    def revoke_certificate(self, ca):
        """Ask ``ca`` to revoke this participant's certificate."""
        ca.revoke(self)

    def sign(self, args):
        """Sign the byte string ``args`` with this participant's private key.

        The data is hashed with ``gost34112012512`` and the digest is
        byte-reversed before it is handed to ``sign``.
        """
        dgst = gost34112012512.new(args).digest()[::-1]
        signature = sign(curve, self._key_pair.private_key, dgst)
        return signature


class TlsParticipant(Participant):
    """Participant that can additionally validate certificates and signatures."""

    def check_signature(self, signature, data, public_key):
        """Verify ``signature`` over ``data`` against ``public_key``.

        ``data`` is hashed with ``gost34112012512`` and the digest is
        byte-reversed before verification.
        """
        digest = gost34112012512.new(data).digest()
        return verify(curve, public_key, digest[::-1], signature)

    def check_certificate(self, certificate, ca):
        """Check that ``certificate`` is accepted by ``ca`` and carries the CA's signature.

        The signed payload is the certificate id rendered as text followed by
        the marshalled public key.
        """
        signed_blob = str(certificate.id).encode() + pub_marshal(certificate.public_key)
        accepted = ca.check_certificate(certificate)
        if not accepted:
            # The CA's verdict is returned unchanged; the signature is not checked.
            return accepted
        return self.check_signature(certificate.signature, signed_blob, ca.get_public_key())


class Client(TlsParticipant):
    """Client side of the TLS-like handshake.

    After a successful ``establish_connection`` the derived traffic keys are
    stored in ``self._k_cs`` and ``self._k_sc``.
    """

    def __init__(self, ca, cipher_suite: CipherSuite):
        super().__init__(ca)
        self._cipher_suite = cipher_suite

    def establish_connection(self, server, connection_type):
        """Run the handshake variant that matches this client's cipher suite.

        ``connection_type`` is forwarded to the variant but is not used by the
        handshake code in this class.

        Raises:
            ValueError: if the cipher suite has no matching handshake variant.
            ArithmeticError: if the certificate, signature or MAC check fails.
        """
        suite = self._cipher_suite.get_enum()
        if suite == CipherSuiteEnum.DH_AES_128_GCM_SHA256:
            self.establish_connection_dh(server, connection_type)
        elif suite == CipherSuiteEnum.ECDH_KUZNYECHIK_MGM_STREEBOG:
            self.establish_connection_ecdh(server, connection_type)
        else:
            # Previously this was a silent no-op that left no session keys.
            raise ValueError(f'unsupported cipher suite: {suite!r}')

    @staticmethod
    def _point_bytes(point):
        """Serialise a curve point as ``long_to_bytes(x) + long_to_bytes(y)``."""
        return long_to_bytes(point[0]) + long_to_bytes(point[1])

    def _finish_handshake(self, server, secret, transcript, ciphertexts):
        """Validate the server's encrypted hello parts and derive the traffic keys.

        Args:
            server: the peer; asked to check this client's certificate when the
                decrypted certificate request equals ``b'yes'``.
            secret: byte encoding of the shared secret.
            transcript: byte encoding of both hello messages.
            ciphertexts: ``(c1, c2, c3, c4)``, each a ``(ct, tag, nonce)`` tuple
                holding the certificate request, the certificate, the
                transcript signature and the transcript MAC respectively.

        Raises:
            ArithmeticError: on an invalid certificate, signature or MAC.
        """
        c1, c2, c3, c4 = ciphertexts
        k_ch, k_cm = HKDF(secret + transcript, 32, b'', SHA256, 2)

        print(f'{k_ch = }\n{k_cm = }')

        m1 = self._cipher_suite.decrypt(k_ch, *c1)

        # SECURITY: pickle.loads runs on bytes supplied by the peer, which can
        # execute arbitrary code during deserialisation. Acceptable only while
        # client and server live in the same trusted process.
        m2 = pickle.loads(self._cipher_suite.decrypt(k_ch, *c2))
        if not self.check_certificate(m2, self._ca):
            raise ArithmeticError('INVALID CERTIFICATE')

        m3 = self._cipher_suite.decrypt(k_ch, *c3)
        signed = transcript + c1[0] + c2[0]
        # Verify against the key inside the certificate that was just
        # validated, not against a key the peer reports separately. Otherwise
        # the certificate would not be bound to the signature at all.
        if not self.check_signature(m3, signed, m2.public_key):
            raise ArithmeticError('INVALID SIGNATURE')

        m4 = self._cipher_suite.decrypt(k_ch, *c4)
        if not self._cipher_suite.check_mac(m4, signed + c3[0], k_cm):
            raise ArithmeticError('INVALID MAC')

        self._k_cs, self._k_sc = HKDF(secret + signed + c3[0] + c4[0],
                                      32, b'', SHA256, 2)

        print(f'{self._k_cs = }\n{self._k_sc = }')

        # The server requested client authentication.
        if m1 == b'yes':
            if not server.check_certificate(self._certificate, self._ca):
                raise ArithmeticError('INVALID CERTIFICATE')

    def establish_connection_ecdh(self, server, connection_type):
        """Handshake using a Diffie-Hellman exchange on ``curve``."""
        p = curve.p
        g = (curve.x, curve.y)
        alpha = getRandomRange(2, p)

        client_pk = curve.exp(alpha, *g)
        client_nonce = getRandomNBitInteger(2048)
        offer = Offer(self._cipher_suite.get_enum(), p, g)

        server_hello = server.tls_get_hello(self, client_pk, client_nonce, offer)
        server_pk, server_nonce, mode, c1, c2, c3, c4 = server_hello

        psk = curve.exp(alpha, *server_pk)

        # Must match the server's encoding byte for byte.
        transcript = (self._point_bytes(client_pk) + long_to_bytes(client_nonce)
                      + str(offer).encode()
                      + self._point_bytes(server_pk) + long_to_bytes(server_nonce)
                      + str(mode).encode())

        self._finish_handshake(server, self._point_bytes(psk), transcript, (c1, c2, c3, c4))

    def establish_connection_dh(self, server, connection_type):
        """Handshake using Diffie-Hellman modulo a freshly generated 2048-bit prime."""
        p = getPrime(2048)
        g = getRandomRange(2, p)
        alpha = getRandomRange(2, p)

        client_pk = pow(g, alpha, p)
        client_nonce = getRandomNBitInteger(2048)
        offer = Offer(self._cipher_suite.get_enum(), p, g)

        server_hello = server.tls_get_hello(self, client_pk, client_nonce, offer)
        server_pk, server_nonce, mode, c1, c2, c3, c4 = server_hello

        psk = pow(server_pk, alpha, p)

        # Must match the server's encoding byte for byte.
        transcript = (long_to_bytes(client_pk) + long_to_bytes(client_nonce)
                      + str(offer).encode()
                      + long_to_bytes(server_pk) + long_to_bytes(server_nonce)
                      + str(mode).encode())

        self._finish_handshake(server, long_to_bytes(psk), transcript, (c1, c2, c3, c4))

    def check_mac(self, mac, data, key):
        """Check ``mac`` over ``data`` with this client's cipher suite."""
        return self._cipher_suite.check_mac(mac, data, key)


class Server(TlsParticipant):
    """Server side of the TLS-like handshake.

    ``tls_get_hello`` is the entry point. It selects the cipher suite named in
    the client's offer, builds the server hello and stores the derived traffic
    keys in ``self._k_cs`` and ``self._k_sc``.
    """

    def __init__(self, ca, cipher_suites: list[CipherSuite]):
        super().__init__(ca)
        # NOTE(review): this list is stored but no method in this class
        # consults it; tls_get_hello instantiates whichever suite the client
        # offers. Confirm whether offers outside the list should be rejected.
        self._cipher_suites = cipher_suites

    def check_mac(self, mac, data, key):
        """Check ``mac`` over ``data`` with the suite chosen by the last ``tls_get_hello``."""
        return self._cipher_suite.check_mac(mac, data, key)

    @staticmethod
    def _point_bytes(point):
        """Serialise a curve point as ``long_to_bytes(x) + long_to_bytes(y)``."""
        return long_to_bytes(point[0]) + long_to_bytes(point[1])

    def _build_hello(self, secret, transcript, cert_request):
        """Encrypt the four hello parts and derive the traffic keys.

        Args:
            secret: byte encoding of the shared secret.
            transcript: byte encoding of both hello messages.
            cert_request: plaintext of the first part (``b'yes'`` or ``b'no'``).

        Returns:
            ``(c1, c2, c3, c4)``: the encrypted certificate request, the
            pickled certificate, the transcript signature and the transcript
            MAC, each as returned by the cipher suite's ``encrypt``.
        """
        k_sh, k_sm = HKDF(secret + transcript, 32, b'', SHA256, 2)

        print(f'{k_sh = }\n{k_sm = }')

        c1 = self._cipher_suite.encrypt(k_sh, cert_request)
        c2 = self._cipher_suite.encrypt(k_sh, pickle.dumps(self._certificate))
        # Each later part covers the ciphertexts of the parts before it.
        signed = transcript + c1[0] + c2[0]
        c3 = self._cipher_suite.encrypt(k_sh, self.sign(signed))
        c4 = self._cipher_suite.encrypt(k_sh, self._cipher_suite.mac(signed + c3[0], k_sm))

        self._k_cs, self._k_sc = HKDF(secret + signed + c3[0] + c4[0],
                                      32, b'', SHA256, 2)

        print(f'{self._k_cs = }\n{self._k_sc = }')

        return c1, c2, c3, c4

    def tls_get_hello_dh(self, client, client_pk, client_nonce, offer):
        """Build the server hello for the modular Diffie-Hellman variant.

        The group modulus and generator are taken from ``offer``; this variant
        requests a client certificate (``b'yes'``).

        Returns:
            ``(server_pk, server_nonce, mode, c1, c2, c3, c4)``.
        """
        _, p, g = offer
        beta = getRandomRange(2, p)

        server_pk = pow(g, beta, p)
        server_nonce = getRandomNBitInteger(2048)
        mode = Mode(self._cipher_suite.get_enum())

        psk = pow(client_pk, beta, p)

        # Must match the client's encoding byte for byte.
        transcript = (long_to_bytes(client_pk) + long_to_bytes(client_nonce)
                      + str(offer).encode()
                      + long_to_bytes(server_pk) + long_to_bytes(server_nonce)
                      + str(mode).encode())

        c1, c2, c3, c4 = self._build_hello(long_to_bytes(psk), transcript, b'yes')

        return server_pk, server_nonce, mode, c1, c2, c3, c4

    def tls_get_hello_ecdh(self, client_pk, client_nonce, offer):
        """Build the server hello for the elliptic-curve variant.

        The base point is taken from ``offer``; this variant does not request
        a client certificate (``b'no'``).

        Returns:
            ``(server_pk, server_nonce, mode, c1, c2, c3, c4)``.
        """
        _, p, g = offer
        beta = getRandomRange(2, p)

        server_pk = curve.exp(beta, *g)
        server_nonce = getRandomNBitInteger(2048)
        mode = Mode(self._cipher_suite.get_enum())

        psk = curve.exp(beta, *client_pk)

        # Must match the client's encoding byte for byte.
        transcript = (self._point_bytes(client_pk) + long_to_bytes(client_nonce)
                      + str(offer).encode()
                      + self._point_bytes(server_pk) + long_to_bytes(server_nonce)
                      + str(mode).encode())

        c1, c2, c3, c4 = self._build_hello(self._point_bytes(psk), transcript, b'no')

        return server_pk, server_nonce, mode, c1, c2, c3, c4

    def tls_get_hello(self, client, client_pk, client_nonce, offer):
        """Answer a client hello using the cipher suite named in ``offer``.

        Raises:
            ValueError: if the offered cipher suite is not one this server
                implements (previously ``None`` was returned silently).
        """
        cipher_suite, group, generator = offer

        if cipher_suite == CipherSuiteEnum.DH_AES_128_GCM_SHA256:
            self._cipher_suite = DH_AES_128_GCM_SHA256()
            return self.tls_get_hello_dh(client, client_pk, client_nonce, offer)

        if cipher_suite == CipherSuiteEnum.ECDH_KUZNYECHIK_MGM_STREEBOG:
            self._cipher_suite = ECDH_KUZNYECHIK_MGM_STREEBOG()
            return self.tls_get_hello_ecdh(client_pk, client_nonce, offer)

        raise ValueError(f'unsupported cipher suite: {cipher_suite!r}')

In [35]:
from pygost import gost34112012512
from pygost.gost3410 import pub_marshal
from pygost.gost3410 import prv_unmarshal
from pygost.gost3410 import public_key
from pygost.gost3410 import sign


class CA:
    """Certificate authority: issues, revokes and looks up certificates.

    A participant must pass ``check_authenticity`` before a certificate is
    issued to it or revoked on its behalf.
    """

    def __init__(self):
        self._certificates = set()
        self._crl = set()
        self._key_pair = generate_key_pair()

    def check_authenticity(self, participant):
        """Run a challenge-response round proving ``participant`` knows its private key.

        The participant supplies ``(Q, P)``, receives a random 64-bit challenge
        ``c`` and answers with ``t``; the round succeeds when
        ``t * P == Q + c * public_key``.
        """
        self._auth_Par_pub = participant.get_public_key()
        self._auth_Q, self._auth_P = participant.auth_get_nonce()
        self._auth_c = random.getrandbits(64)
        self._auth_t = participant.auth_get_result(self._auth_c)

        response_point = curve.exp(self._auth_t, *self._auth_P)
        challenge_point = curve.exp(self._auth_c, *self._auth_Par_pub)
        expected_point = curve._add(*self._auth_Q, *challenge_point)
        return response_point == expected_point

    def issue_certificate(self, participant):
        """Issue and register a certificate for ``participant``.

        The CA signs the participant id rendered as text followed by the
        marshalled public key. If the authenticity check fails, a message is
        printed and ``None`` is returned.
        """
        par_id = participant.get_id()
        par_key = participant.get_public_key()

        authentic = self.check_authenticity(participant)
        if not authentic:
            print('INVALID PRIVATE KEY')
            return

        payload = str(par_id).encode() + pub_marshal(par_key)
        digest = gost34112012512.new(payload).digest()
        signature = sign(curve, self._key_pair.private_key, digest[::-1])

        certificate = Certificate(par_id, par_key, signature)
        self._certificates.add(certificate)
        return certificate

    def revoke(self, participant):
        """Move the participant's certificate from the issued set to the CRL.

        If the authenticity check fails, a message is printed and nothing
        changes. A certificate that was never issued here is ignored.
        """
        authentic = self.check_authenticity(participant)
        if not authentic:
            print('INVALID PRIVATE KEY')
            return

        par_cert = participant.get_certificate()
        if par_cert not in self._certificates:
            return
        self._certificates.remove(par_cert)
        self._crl.add(par_cert)

    def check_certificate(self, certificate):
        """Return True when ``certificate`` was issued here and is not on the CRL."""
        if certificate in self._crl:
            return False
        return certificate in self._certificates

    def get_public_key(self):
        """Return the CA's public key."""
        return self._key_pair.public_key

In [36]:
# Tests

# A freshly issued certificate must validate against the CA that issued it.
ca = CA()
par = Client(ca, DH_AES_128_GCM_SHA256())
par.check_certificate(par.get_certificate(), ca)

True

In [37]:
# End-to-end handshake: both sides should print identical key material.
ca = CA()
client = Client(ca, ECDH_KUZNYECHIK_MGM_STREEBOG())
# Pass suite instances (not the class object), and include the suite the
# client is about to negotiate.
server = Server(ca, [DH_AES_128_GCM_SHA256(), ECDH_KUZNYECHIK_MGM_STREEBOG()])

client.establish_connection(server, 1)

k_sh = b'\xdd\xba\xc3\xb8N\x1b\xda\xdc\x1e\xda5\x03\xa7\x92\x07\x95\x10\x1epT\x19\xed9\xad\xbd\xa0\xe7\xc8Pc\xd5\xc8'
k_sm = b')\xacb3\x0f1O\xc0\xbaxC\x84\xb4\x9fG\xa0\xd6\x904\x03\x93\xfb\xde@V\xa1]\xc2\xde\xee}\xc8'
self._k_cs = b'\xac\xd4\xbe\x01\x92\xc5\xf4\xb4>\xa6uM\xb1+7W\xf5=\xd8\xa9\x8e\r:\x847\xa8\xb8\x83\xbcN\xb9v'
self._k_sc = b'8y+\xa6\xdatt1\xc7P\xc9\xefby;\xb1Q\x7f\x7fi\rcW\xaatt\xaf\xd4\xb4y\x1f\xb8'
k_ch = b'\xdd\xba\xc3\xb8N\x1b\xda\xdc\x1e\xda5\x03\xa7\x92\x07\x95\x10\x1epT\x19\xed9\xad\xbd\xa0\xe7\xc8Pc\xd5\xc8'
k_cm = b')\xacb3\x0f1O\xc0\xbaxC\x84\xb4\x9fG\xa0\xd6\x904\x03\x93\xfb\xde@V\xa1]\xc2\xde\xee}\xc8'
self._k_cs = b'\xac\xd4\xbe\x01\x92\xc5\xf4\xb4>\xa6uM\xb1+7W\xf5=\xd8\xa9\x8e\r:\x847\xa8\xb8\x83\xbcN\xb9v'
self._k_sc = b'8y+\xa6\xdatt1\xc7P\xc9\xefby;\xb1Q\x7f\x7fi\rcW\xaatt\xaf\xd4\xb4y\x1f\xb8'
