In [5]:
import math
import random
import secrets

import numpy as np
import galois
# Target key size in bits. NOTE(review): not referenced by any cell shown here;
# GenderPayGapSurveyServer picks its own key size when it calls keygen.
bits = 128

from collections import defaultdict
from collections import Counter
from collections import namedtuple
from operator import truediv

In [6]:
def keygen(bits):
    """Generate a Paillier key pair whose modulus n = p*q is built from two `bits // 2`-bit primes.

    Uses the simplified Paillier variant with generator g = n + 1,
    lambda = (p - 1)(q - 1) and mu = lambda^-1 mod n.

    Args:
        bits: Target size of the modulus in bits (each prime gets `bits // 2` bits).

    Returns:
        (sk, pk) where sk = (lamb, mu) and pk = (n, g).

    Raises:
        ValueError: if lambda is not invertible modulo n.
    """
    def invmod(x, m):
        # galois.egcd returns (gcd, s, t) with s*x + t*m == gcd; s may be negative,
        # so normalise it into [0, m).
        gcd, s, _ = galois.egcd(x, m)
        if gcd != 1:
            raise ValueError(f"{x} has no inverse modulo {m}")
        return s % m

    prime_bits = bits // 2
    p = galois.random_prime(prime_bits)
    q = galois.random_prime(prime_bits)
    # Paillier requires two distinct primes; with small keys a repeated draw is possible.
    while q == p:
        q = galois.random_prime(prime_bits)

    n = p * q
    g = n + 1
    lamb = (p - 1) * (q - 1)
    mu = invmod(lamb, n)

    sk = (lamb, mu)
    pk = (n, g)
    return sk, pk

def encrypt(m, pk):
    """Encrypt the integer message `m` under the Paillier public key `pk`.

    Computes c = g^m * r^n mod n^2, where the blinding factor r is drawn uniformly
    from 1 <= r < n and resampled until gcd(r, n) == 1.

    Args:
        m: Plaintext integer; decryption recovers it modulo n.
        pk: Public key tuple (n, g) as produced by `keygen`.

    Returns:
        The ciphertext, an integer in [0, n^2).
    """
    n, g = pk
    n_sq = n ** 2
    # r must be a unit modulo n: r = n (which randint(1, n) could return) gives
    # r^n == 0 mod n^2 and wipes out the message. `secrets` is used because the
    # blinding factor must be unpredictable and `random` is not a CSPRNG.
    while True:
        r = secrets.randbelow(n - 1) + 1
        if math.gcd(r, n) == 1:
            break
    return (pow(g, m, n_sq) * pow(r, n, n_sq)) % n_sq

def decrypt(c, sk, pk):
    """Recover the plaintext of Paillier ciphertext `c`.

    Applies the L-function L(x) = (x - 1) // n to c^lamb mod n^2 and scales the
    result by mu modulo n.

    Args:
        c: Ciphertext integer.
        sk: Secret key tuple (lamb, mu).
        pk: Public key tuple (n, g); only n is used.

    Returns:
        The plaintext as an integer in [0, n).
    """
    lamb, mu = sk
    n, _g = pk
    modulus = n * n
    u = pow(c, lamb, modulus)
    l_value = (u - 1) // n
    return mu * l_value % n

def e_add_vec(ct, pk):
    """Homomorphically add every ciphertext in `ct`.

    In Paillier, multiplying ciphertexts modulo n^2 adds their plaintexts, so the
    product of all elements of `ct` is an encryption of their sum.

    Args:
        ct: Iterable of ciphertexts (Python ints or NumPy integer scalars).
        pk: Public key tuple (n, g); only n is used.

    Returns:
        A Python int ciphertext of the sum. An empty `ct` yields 1, the trivial
        encryption of 0 (the identity for homomorphic addition).
    """
    n, _g = pk
    n_sq = n ** 2
    total = 1
    for c in ct:
        # int() so NumPy fixed-width scalars cannot overflow during the product.
        total = (total * int(c)) % n_sq
    return total

In [7]:
class GenderPayGapSurveyParticipant:
    """Client side of the survey: encodes one response and submits it encrypted.

    A response is encoded as two matrices indexed [gender_row][age_band]. One holds
    the salary and the other holds a count of 1, each only in the cell for the
    respondent's gender and band; every other cell is 0. Every cell is encrypted,
    so the server cannot tell which cell is non-zero.
    """

    # Inclusive lower bound of each age band. Band i covers [AGE_BANDS[i], AGE_BANDS[i+1]);
    # the last band is open-ended.
    AGE_BANDS = [16, 26, 36, 46, 56, 66, 76, 86, 96, 106, 116]
    # Matrix row used for each accepted gender label.
    GENDER_INDEX = {"Male": 1, "Female": 0}

    def _band_index(self, age):
        """Return the index of the age band that contains `age`.

        Raises:
            ValueError: if `age` is below the youngest band's lower bound.
        """
        if age < self.AGE_BANDS[0]:
            raise ValueError(f"age {age} is below the youngest band ({self.AGE_BANDS[0]})")
        # Bounds are ascending, so the last lower bound <= age identifies the band.
        return max(i for i, lower in enumerate(self.AGE_BANDS) if age >= lower)

    def submit_salary(self, salary, gender, age, server):
        """Encrypt a survey response and submit it to `server`.

        Args:
            salary: Integer salary to report.
            gender: A key of GENDER_INDEX ("Male" or "Female").
            age: Respondent's age; must be at least AGE_BANDS[0].
            server: Object providing get_public_key() and submit_salary(enc_salaries, enc_counts).

        Raises:
            KeyError: if `gender` is not a key of GENDER_INDEX.
            ValueError: if `age` is below the youngest band.
        """
        pk = server.get_public_key()
        row = self.GENDER_INDEX[gender]
        band = self._band_index(age)

        n_rows = len(self.GENDER_INDEX)
        n_bands = len(self.AGE_BANDS)
        # Fresh (non-aliased) zero matrices; only the respondent's cell is filled.
        salary_cells = [[0] * n_bands for _ in range(n_rows)]
        count_cells = [[0] * n_bands for _ in range(n_rows)]
        salary_cells[row][band] = salary
        count_cells[row][band] = 1

        enc_sal = [[encrypt(v, pk) for v in cells] for cells in salary_cells]
        enc_band_ct = [[encrypt(v, pk) for v in cells] for cells in count_cells]
        server.submit_salary(enc_sal, enc_band_ct)
 
class GenderPayGapSurveyServer:
    """Collects encrypted survey responses and publishes per-band average salaries.

    Each submission is a pair of ciphertext matrices indexed [gender_row][age_band];
    row 0 is read as female and row 1 as male by `compute_average_salaries`.
    Individual responses are never decrypted; only homomorphic per-cell totals are.
    """

    def __init__(self, key_bits=32):
        """Start an empty survey with a freshly generated key pair.

        Args:
            key_bits: Modulus size passed to `keygen`. The default (32) keeps the
                original demo behaviour but is far too small to be secure. Decrypted
                totals are reduced modulo n, so every per-band salary sum must stay
                below n (roughly 2**key_bits) or it silently wraps around.
        """
        self.salaries = []
        self.band_cts = []
        self.sk, self.pk = keygen(key_bits)

    def get_public_key(self):
        """Return the public key (n, g) that participants encrypt with."""
        return self.pk

    def submit_salary(self, ct_salary_vector, ct_gender_vector):
        """Store one encrypted response.

        Args:
            ct_salary_vector: Encrypted salary matrix [gender_row][age_band].
            ct_gender_vector: Encrypted count matrix with the same layout (despite
                the name, it holds per-band respondent counts).
        """
        self.salaries.append(ct_salary_vector)
        self.band_cts.append(ct_gender_vector)

    def show_salaries(self):
        """Return the stored (still encrypted) salary matrices."""
        return self.salaries

    def _decrypt_cell_totals(self, submissions):
        """Homomorphically add each [row][band] cell across `submissions`, then decrypt.

        Plain Python ints are used throughout: ciphertexts are as large as n**2, and
        packing them into a fixed-width NumPy integer array risks silent overflow.
        """
        n_rows = len(submissions[0])
        n_bands = len(submissions[0][0])
        return [
            [
                decrypt(e_add_vec([sub[row][band] for sub in submissions], self.pk),
                        self.sk, self.pk)
                for band in range(n_bands)
            ]
            for row in range(n_rows)
        ]

    def compute_average_salaries(self):
        """Tally, decrypt, and average the submitted salaries per age band.

        Returns:
            A 4-tuple ("Female Salaries By Age Band:", female_averages,
            "Male Salaries By Age Band:", male_averages). Each averages list has one
            entry per age band; bands with no respondents report 0.

        Raises:
            ValueError: if no responses have been submitted.
        """
        if not self.salaries:
            raise ValueError("No survey responses have been submitted")
        salary_totals = self._decrypt_cell_totals(self.salaries)
        count_totals = self._decrypt_cell_totals(self.band_cts)
        averages = [
            [total / count if count else 0 for total, count in zip(sal_row, cnt_row)]
            for sal_row, cnt_row in zip(salary_totals, count_totals)
        ]
        # Row 0 is female and row 1 is male, matching the participant encoding.
        return ("Female Salaries By Age Band:", averages[0],
                "Male Salaries By Age Band:", averages[1])
       

In [8]:
# Demo: five respondents submit (salary, gender, age); the server then reports
# per-band averages without decrypting any individual response.
s = GenderPayGapSurveyServer()
survey_responses = [
    (10000, 'Male', 22),
    (30000, 'Female', 16),
    (15000, 'Male', 23),
    (20000, 'Female', 65),
    (20000, 'Female', 70),
]
for salary, gender, age in survey_responses:
    GenderPayGapSurveyParticipant().submit_salary(salary, gender, age, s)
s.compute_average_salaries()

('Female Salaries By Age Band:',
 [30000.0, 0, 0, 0, 20000.0, 20000.0, 0, 0, 0, 0, 0],
 'Male Salaries By Age Band:',
 [12500.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])