In [None]:
# Imports
import pandas as pd
import numpy as np
import random
import math
import timeit
import itertools
import warnings
import pickle
import gc
import sys
import matplotlib.pyplot as plt
from os.path import join, isfile
from collections import Counter
import networkx as nx
import time
from scipy.special import gamma

from chronometer import ch # Easy chronometer that I implemented. First call: start, Second call: stop

# Silence all warnings for the rest of the session.
# NOTE(review): this also hides pandas/numpy deprecation warnings, which would otherwise
# point at APIs that change behaviour in newer releases — consider narrowing the filter.
warnings.filterwarnings('ignore')
# Print numpy floats with 4 decimals and without scientific notation.
np.set_printoptions(suppress=True, formatter={'float': lambda x: "{0:0.4f}".format(x)})

In [None]:
def load_files(data_path, chromosome_index=1):
    """Load the beacon genotypes, allele frequencies and LD table from ``data_path``.

    Files read (relative to ``data_path``):
      - "Beacon_164.txt": whitespace-separated genotype table, first column is the index.
        NOTE(review): the original author notes it must be sorted by marker id (rs...).
      - "MAF.txt": whitespace-separated table with a "chromosome" column; it is only used
        to select the beacon rows of chromosome ``chromosome_index`` (matched by row position).
      - "AF_CEU_all_chromosomes.txt": allele frequencies of all chromosomes.
      - "ld_new20_CEU_07_sortedM.txt": comma-separated LD table, first column is the index.

    Parameters
    ----------
    data_path : str
        Directory containing the four files.
    chromosome_index : int, optional
        Chromosome whose beacon rows are kept (rows whose MAF "chromosome" is "chr<index>").

    Returns
    -------
    tuple
        ``(beacon, af, ld)`` — beacon restricted to the chosen chromosome, ``af`` sorted
        by "markerId", and the raw LD table.
    """
    # sep=r"\s+" is the documented replacement for the deprecated delim_whitespace=True.
    beacon = pd.read_csv(join(data_path, "Beacon_164.txt"), index_col=0, sep=r"\s+")
    maf = pd.read_csv(join(data_path, "MAF.txt"), index_col=0, sep=r"\s+")  # Allele frequency of one chromosome
    af = pd.read_csv(join(data_path, "AF_CEU_all_chromosomes.txt"))  # Allele frequency of all chromosomes
    ld = pd.read_csv(join(data_path, "ld_new20_CEU_07_sortedM.txt"), sep=",", index_col=0)

    af = af.sort_values(by=["markerId"])  # Needs to be sorted

    # Boolean mask over the MAF rows, applied positionally to the beacon rows
    # (assumes both files list the same markers in the same order).
    on_chromosome = (maf["chromosome"] == "chr" + str(chromosome_index)).to_numpy()
    beacon = beacon.loc[on_chromosome]

    return (beacon, af, ld)

In [None]:
# Build the directed LD graph used by the QI attack from the columns of the LD table.
# The weight columns are renamed to "weight" on temporary copies, so edges can be read as
# graph[u][v]["weight"] instead of the unintuitive "Var11"/"Var12".
def construct_graph(ld):
    """Build a directed graph with one edge per direction for every row of ``ld``.

    Edges Var4 -> Var5 carry ``weight = Var11``; edges Var5 -> Var4 carry
    ``weight = Var12``. If both passes produce the same edge, the attributes of the
    Var5 -> Var4 pass win (``nx.compose`` keeps the second graph's attributes).

    ``ld`` is not modified. The previous version renamed its columns in place, leaving
    "Var12" renamed to "weight", so a second call saw two "weight" columns.
    """
    forward = ld[["Var4", "Var5", "Var11"]].rename(columns={"Var11": "weight"})
    G1 = nx.from_pandas_edgelist(forward, "Var4", "Var5", ["weight"], create_using=nx.DiGraph())

    backward = ld[["Var4", "Var5", "Var12"]].rename(columns={"Var12": "weight"})
    G2 = nx.from_pandas_edgelist(backward, "Var5", "Var4", ["weight"], create_using=nx.DiGraph())

    return nx.compose(G1, G2)

In [None]:
def filter_biallelic(curAF, markers, snps):
    """Keep the non-homozygous markers and replace each genotype by the allele to query.

    Parameters
    ----------
    curAF : pandas.DataFrame
        Allele-frequency table with "referenceAllele", "otherAllele" and
        "referenceAlleleFrequency" columns. NOTE(review): its rows are assumed to be
        aligned positionally with ``markers``/``snps``; nothing here checks that.
    markers : numpy.ndarray
        Marker ids, one per row.
    snps : array-like
        Genotype strings such as "AG", one per row. Not modified.

    Returns
    -------
    tuple
        ``(curAF, markers, snps)`` restricted to rows whose genotype is not "AA", "CC",
        "GG", "TT" or "NN". In the returned ``snps``, each entry is the reference allele
        when its frequency is < 0.5, otherwise the other allele.
    """
    genotypes = np.asarray(snps)
    homozygous_or_missing = ((genotypes == "AA") | (genotypes == "CC") | (genotypes == "GG")
                             | (genotypes == "NN") | (genotypes == "TT"))
    keep = ~homozygous_or_missing

    # Build a new array instead of writing into the caller's ``snps``.
    ref_is_minor = curAF["referenceAlleleFrequency"].to_numpy() < 0.5
    query_allele = np.where(ref_is_minor,
                            curAF["referenceAllele"].to_numpy(),
                            curAF["otherAllele"].to_numpy())

    return curAF[keep], markers[keep], query_allele[keep]

In [None]:
# Simulated beacon query: answers 1 (allele present) or 0 (absent).
def query_beacon(marker, allele, beacon):
    """Answer whether any genotype stored for ``marker`` carries ``allele``.

    The row ``beacon.loc[marker]`` is reduced to its distinct genotype strings, and the
    answer is 1 as soon as one of them has ``allele`` as its first or second character,
    otherwise 0. An unknown marker prints "No such marker" and answers 0.
    """
    try:
        distinct_genotypes = beacon.loc[marker].unique()
        carries_allele = any(
            genotype[0] == allele or genotype[1] == allele
            for genotype in distinct_genotypes
        )
        return 1 if carries_allele else 0
    except KeyError:
        print("No such marker")
        return 0

In [None]:
def plot_power(power):
    """Draw ``power`` (power after 0, 1, 2, ... queries) on the current matplotlib axes.

    The frame is hidden, x spans every query count and y is fixed to [0, 1].
    """
    ax = plt.gca()
    ax.set_frame_on(False)
    ax.set_xlabel("Number of Queries")
    ax.set_ylabel("Power")
    last_query = len(power) - 1
    ax.set_xlim(0, last_query)
    ax.set_ylim(0, 1)
    ax.plot(power)

# Calculate the power curve from case and control LRT statistics.
def calculate_power(cs, cn, test_size, query_count, t_alpha, do_power_plot=True):
    """Compute detection power after each query from case/control statistics.

    Parameters
    ----------
    cs, cn : numpy.ndarray
        Case / control statistics, one row per individual and one column per query.
        A value of exactly 0 means "no statistic recorded for this query". The inputs
        are not modified (the previous version reordered them in place).
    test_size : int
        Number of leading rows of ``cs`` and ``cn`` whose zeros are compacted.
    query_count : int
        Number of query columns to evaluate.
    t_alpha : float
        Percentile (0-100) of the control statistics used as the decision threshold.
    do_power_plot : bool, optional
        If True, draw the curve with ``plot_power``.

    Returns
    -------
    numpy.ndarray
        Length ``query_count + 1``. ``power[0]`` is 0, and ``power[q + 1]`` is the
        fraction of cases whose (q+1)-th recorded statistic lies below the control
        threshold. It is 0 when the column has no recorded control or case statistics.
    """
    cs = np.array(cs, copy=True)
    cn = np.array(cn, copy=True)

    # Move the recorded (non-zero) statistics to the front of each row, keeping their
    # order, so column q holds each person's statistic after q+1 recorded queries.
    for i in range(test_size):
        cs[i, :] = np.concatenate((cs[i, cs[i, :] != 0], cs[i, cs[i, :] == 0]))
        cn[i, :] = np.concatenate((cn[i, cn[i, :] != 0], cn[i, cn[i, :] == 0]))

    power = np.zeros(query_count + 1)

    for q in range(query_count):
        control_q = cn[cn[:, q] != 0, q]
        case_q = cs[cs[:, q] != 0, q]
        # np.percentile fails on empty input, and with no cases the ratio would be 0/0.
        if len(control_q) == 0 or len(case_q) == 0:
            continue
        threshold = np.percentile(control_q, t_alpha)
        # A case is detected when its statistic falls BELOW the control threshold.
        power[q + 1] = np.sum(case_q < threshold) / len(case_q)

    if do_power_plot:
        plot_power(power)

    return power

In [None]:
def calculate_delta(error, af_term, NN, br):
    """Log-likelihood-ratio statistic summed over a set of queries.

    ``af_term`` holds one allele frequency f per query and ``br`` the matching beacon
    responses (0/1). ``NN`` is the number of individuals in the beacon and ``error`` the
    error rate. Each query contributes

        log((1-f)^2 / error)
        + br * log(error / (1-f)^2 * (1 - (1-f)^(2N)) / (1 - error * (1-f)^(2N-2)))
    """
    not_f = 1 - af_term
    not_f_sq = np.square(not_f)
    d_n = np.power(not_f, 2 * NN)         # (1-f)^(2N)
    d_n_1 = np.power(not_f, 2 * NN - 2)   # (1-f)^(2N-2)

    constant_part = np.log(not_f_sq / error)
    response_part = np.log((error / not_f_sq) * ((1 - d_n) / (1 - error * d_n_1)))
    return np.sum(constant_part + response_part * br)

In [None]:
# Sum of the edge weights from one source node to several target nodes.
def sum_of_edge_weights(graph, source, targets):
    """Return the sum of ``graph[source][t]["weight"]`` over every ``t`` in ``targets``.

    Every target must be adjacent to ``source`` (otherwise KeyError). An empty
    ``targets`` gives 0. A plain generator is used; the target lists passed in are
    short enough that this has not been a bottleneck.
    """
    return sum(graph[source][t]["weight"] for t in targets)
    
# Average edge weight from one source node to several target nodes.
def mean_of_edge_weights(graph, source, targets):
    """Return ``sum_of_edge_weights(graph, source, targets) / len(targets)``, or 0 when
    ``targets`` is empty."""
    n_targets = len(targets)
    if n_targets == 0:
        return 0
    return sum_of_edge_weights(graph, source, targets) / n_targets

# Neighbours of a node restricted to an allowed set of nodes.
def filter_edges(graph, source, pass_nodes):
    """Return the neighbours of ``source`` that also appear in ``pass_nodes``.

    The result is a numpy array in the adjacency order of ``graph[source]``.
    """
    neighbours = np.array(list(graph[source]))
    # np.isin replaces the deprecated np.in1d. assume_unique is not set because callers
    # pass marker lists that are not guaranteed to be duplicate-free, and
    # assume_unique=True gives undefined results on duplicates.
    return neighbours[np.isin(neighbours, pass_nodes)]

In [None]:
def calculate_delta_SB(DN, DN_1, error, br, n):
    """LRT statistic used by ``attack_SB`` after ``n`` queries.

    ``DN``/``DN_1`` come from ``get_DN`` for N and N-1 individuals, ``error`` is the error
    rate and ``br`` the beacon responses (only their sum is used). Returns
    n * log(DN / (DN_1 * error)) + sum(br) * log(error * DN_1 * (1 - DN) / (DN * (1 - error * DN_1))).
    """
    per_query_term = np.log(DN / (DN_1 * error))
    per_yes_term = np.log((error * DN_1 * (1 - DN)) / (DN * (1 - error * DN_1)))
    yes_count = np.sum(br)
    return n * per_query_term + per_yes_term * yes_count

def get_DN(a, b, N):
    """Return Γ(a + b) / (Γ(b) · (2N + a + b)^a), the D_N term used by ``attack_SB``."""
    numerator = gamma(a + b)
    denominator = gamma(b) * (2 * N + a + b) ** a
    return numerator / denominator

In [None]:
def attack(beacon, h_vals, error, test_group, query_count, testAF, LD_Graph, attack_type="OPTIMAL"):
    """Run the OPTIMAL or QI attack and record the LRT statistic after every query.

    For each person (column of ``test_group``) and each threshold ``h`` in ``h_vals``
    (a dict: threshold in percent -> row index of the output), the person's
    non-homozygous markers (see ``filter_biallelic``) are processed as follows:
      - they are ordered by ascending minor-allele frequency, ties broken by marker id;
      - markers with MAF < h/100 are dropped;
      - the first ``query_count`` markers are queried against ``beacon``.

    attack_type
        "OPTIMAL": Delta after query q is ``calculate_delta`` over all answers so far.
        "QI": also uses ``LD_Graph``. When a queried marker has LD neighbours among the
            selected markers, those neighbours are marked as already asked. One
            representative of the neighbourhood is then queried: the one with the highest
            mean LD weight among the 30% of nodes with the fewest edges. Its answer is
            applied to its neighbours' frequencies, weighted by the mean edge weight, and
            added to a running sum. Queries marked as already asked are skipped, and
            their Delta stays 0.

    Returns
    -------
    numpy.ndarray
        Shape (len(h_vals), test_size, query_count). Zero entries mean no statistic was
        recorded; this includes the case where fewer than ``query_count`` markers
        survive the filter. ``calculate_power`` compacts these entries away.
    """
    NN, test_size = beacon.shape[1], test_group.shape[1]

    Delta = np.zeros((len(h_vals), test_size, query_count))

    # QI bookkeeping: True where a query's information was already used via an LD neighbour.
    alreadyAsked = np.zeros((len(h_vals), test_size, query_count), dtype=bool)

    for person in range(test_size):
        print("[{}] - Person ID: {}".format(person, test_group.columns[person]))
        curAF, temp_markers, temp_snps = filter_biallelic(testAF, test_group.index.to_numpy(), np.array(test_group.iloc[:, person].values))

        for h in h_vals:
            h_idx = h_vals[h]  # row of Delta / alreadyAsked for this threshold
            markers = np.array(temp_markers)  # copy
            snps = np.array(temp_snps)  # copy
            beacon_response = np.zeros((query_count), dtype=int)

            min_af = np.minimum(curAF["referenceAlleleFrequency"].values, curAF["otherAlleleFrequency"].values)

            # Ascending minor-allele frequency, ties broken by marker id.
            sorted_ind = np.lexsort((markers, min_af))
            min_af = min_af[sorted_ind]
            snps = snps[sorted_ind]
            markers = markers[sorted_ind]

            # Presumably avoids log(0) in calculate_delta for zero frequencies.
            min_af[min_af == 0] = 1e-9
            below_threshold = np.where(min_af < (h / 100))
            markers = np.delete(markers, below_threshold)
            snps = np.delete(snps, below_threshold)
            min_af = np.delete(min_af, below_threshold)

            innersum = 0  # QI: accumulated contribution of the LD-representative answers

            markers = markers[:query_count]
            min_af = min_af[:query_count]
            # Fewer than query_count markers may survive the filter. The remaining Delta
            # entries then stay 0 instead of raising IndexError.
            n_queries = len(markers)

            for q in range(n_queries):

                beacon_response[q] = query_beacon(markers[q], snps[q], beacon)

                if attack_type == "OPTIMAL":
                    Delta[h_idx, person, q] = calculate_delta(error, min_af[:q+1], NN, beacon_response[:q+1])

                elif attack_type == "QI":

                    if alreadyAsked[h_idx, person, q]:
                        continue

                    if markers[q] in LD_Graph:
                        # LD neighbours of markers[q] that are among the selected markers.
                        nodes = np.array(sorted(list(LD_Graph[markers[q]])))
                        nodes = nodes[np.isin(nodes, markers, assume_unique=True)]

                        indices = np.isin(markers, nodes, assume_unique=True)  # positions of those neighbours
                        MAFinner = min_af[indices]

                        # Neighbourhood = neighbours plus markers[q] itself.
                        nodes = np.concatenate((markers[indices], [markers[q]]))

                        if len(MAFinner) > 0:
                            # Integer positions also work when len(markers) < query_count.
                            alreadyAsked[h_idx, person, np.flatnonzero(indices)] = True

                            # Per node: mean LD weight to the rest of the neighbourhood,
                            # number of such edges, and the node's position in ``nodes``.
                            w_sum = np.zeros((len(nodes), 3))
                            for j in range(len(nodes)):
                                edges = filter_edges(LD_Graph, nodes[j], nodes)
                                w_sum[j, 0] = mean_of_edge_weights(LD_Graph, nodes[j], edges)
                                w_sum[j, 1] = len(edges)
                                w_sum[j, 2] = j

                            # Keep the 30% of nodes with the fewest edges, then query the one
                            # with the highest mean weight. TODO: make 0.3 a parameter?
                            w_sum = w_sum[np.argsort(w_sum[:, 1]), :]
                            w_sum = w_sum[:round(len(nodes) * 0.3), :]
                            query_m = nodes[int(w_sum[np.argmax(w_sum[:, 0]), 2])]

                            nodes = np.array(list(LD_Graph[query_m]))
                            indices = np.isin(markers, nodes, assume_unique=True)
                            MAFinner = min_af[indices]

                            edges = filter_edges(LD_Graph, query_m, markers)
                            mean = mean_of_edge_weights(LD_Graph, query_m, edges)
                            allele = snps[np.where(markers == query_m)[0][0]]  # allele to query for query_m

                            innersum += calculate_delta(error, MAFinner, NN, query_beacon(query_m, allele, beacon)) * mean

                    # Ordinary statistic over the answers not already covered via LD.
                    asked_mask = ~alreadyAsked[h_idx, person, :1+q]
                    Delta[h_idx, person, q] = innersum + calculate_delta(error, min_af[:1+q][asked_mask], NN, beacon_response[:1+q][asked_mask])

    return Delta

In [None]:
def attack_SB(beacon, error, test_group, query_count, testAF, a_prime, b_prime, rng=None, verbose=False):
    """Run the SB attack: query each person's markers in random order.

    For every person (column of ``test_group``), the non-homozygous markers (see
    ``filter_biallelic``) are queried against ``beacon`` in a random order. After query q,
    ``Delta[0, person, q]`` holds ``calculate_delta_SB`` over the first q+1 answers.

    Parameters
    ----------
    a_prime, b_prime : float
        Passed to ``get_DN`` as ``a_prime + 1`` and ``b_prime + 1``.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of the random marker order. Defaults to numpy's global generator
        (``np.random.permutation``), as before.
    verbose : bool, optional
        Print one progress line per person.

    Returns
    -------
    numpy.ndarray
        Shape (1, test_size, query_count); the single row corresponds to h = 0. Entries
        stay 0 when a person has fewer than ``query_count`` usable markers.
    """
    NN, test_size = beacon.shape[1], test_group.shape[1]

    DN = get_DN(a_prime + 1, b_prime + 1, NN)
    DN_1 = get_DN(a_prime + 1, b_prime + 1, NN-1)

    Delta = np.zeros((1, test_size, query_count))  # Only for h = 0
    permutation = np.random.permutation if rng is None else rng.permutation

    for person in range(test_size):
        if verbose:
            print("[{}] - Person ID: {}".format(person, test_group.columns[person]))
        curAF, markers, snps = filter_biallelic(testAF, test_group.index.to_numpy(), np.array(test_group.iloc[:, person].values))

        beacon_response = np.zeros((query_count), dtype=int)
        ind = permutation(len(markers))  # Select random markers in each iteration.

        # Do not run past the available markers (previously an IndexError).
        for q in range(min(query_count, len(markers))):
            beacon_response[q] = query_beacon(markers[ind[q]], snps[ind[q]], beacon)
            # Entries after q are still 0, so the sum inside only counts answers so far.
            Delta[0, person, q] = calculate_delta_SB(DN, DN_1, error, beacon_response, q+1)

    return Delta

In [None]:
def get_prefix_count(prefix, remaining_cor):
    """Count the columns of ``remaining_cor`` that match ``prefix`` row by row.

    ``prefix`` is compared as a column vector against ``remaining_cor`` (rows x columns).
    A column counts when its number of matching rows equals ``remaining_cor.shape[0]``.
    """
    prefix_column = np.array(prefix).reshape(-1, 1)
    matches_per_column = np.sum(remaining_cor == prefix_column, axis=0)
    fully_matching = matches_per_column == remaining_cor.shape[0]
    return int(np.count_nonzero(fully_matching))
    
def calculate_prob(prefix, remaining_cor):
    """Decide whether ``prefix`` determines the next genotype code unambiguously.

    For each candidate code 0, 1 and 2 appended to ``prefix``, count the fully matching
    columns of ``remaining_cor`` (e.g. shape (order+1, people)). Return True only when
    exactly one candidate has a non-zero count: 10/0/0 -> True, 8/2/0 -> False,
    0/0/0 -> False.
    """
    counts = [get_prefix_count(prefix + [code], remaining_cor) for code in range(3)]
    observed_codes = [count for count in counts if count != 0]
    return len(observed_codes) == 1

def enumerate_alleles(remaining_GI, testAF, individual=False):
    """Encode genotype strings as minor-allele codes and drop columns with missing calls.

    Parameters
    ----------
    remaining_GI : pandas.DataFrame or pandas.Series
        Genotype strings (e.g. "AG"), one row per marker, indexed by marker id. Pass a
        Series (a single individual) together with ``individual=True``.
    testAF : pandas.DataFrame
        Allele-frequency table with "markerId", "referenceAllele", "otherAllele" and
        "referenceAlleleFrequency" columns.
    individual : bool, optional
        True when ``remaining_GI`` is a Series; it is then treated as one column.

    Returns
    -------
    numpy.ndarray or int
        Object array of shape (markers, kept_columns), using this coding:
          - 2: the homozygous genotype of the chosen allele (the reference allele, unless
            its frequency exceeds 0.5);
          - 0: any other homozygous A/C/G/T call;
          - 1: any other two-character call.
        Every column containing a "NN" call is removed. Returns 0 if some marker has no
        row in ``testAF``. NOTE(review): callers must handle this sentinel.
    """
    table = remaining_GI.to_frame() if individual else remaining_GI  # Series -> single column
    m = table.shape[0]

    regexKeys = ["NN", "TT", "GG", "AA", "CC", ".."]
    # Results go into a separate array. The previous chained
    # ``table.iloc[i].replace(..., inplace=True)`` does not write back to ``table``
    # under pandas Copy-on-Write.
    coded = np.empty(table.shape, dtype=object)

    for i in range(m):
        marker = table.index[i]
        curAF = testAF[testAF["markerId"] == marker]
        if len(curAF) == 0:
            return 0

        allele = curAF.iloc[0]["referenceAllele"]
        if curAF.iloc[0]["referenceAlleleFrequency"] > 0.5:
            allele = curAF.iloc[0]["otherAllele"]

        # Keys are applied in order; once a cell becomes an int, later regexes skip it.
        regexVals = [-1, 0, 0, 0, 0, 1]
        regexVals[regexKeys.index(allele * 2)] = 2
        coded[i, :] = table.iloc[i].replace(regexKeys, regexVals, regex=True).to_numpy()

    # Drop every column (individual) with a missing call at any of the markers.
    no_missing = np.all(coded != -1, axis=0)
    return coded[:, no_missing]

            
def attack_GI(beacon, h_vals, error, test_group, remaining_GI, query_count, testAF, order):
    """Run the genotype-inference (GI) attack and record the LRT statistic per query.

    For each person and each threshold ``h`` in ``h_vals`` (a dict: threshold in
    percent -> row index of the output), the person's non-homozygous markers with
    MAF < h/100 are visited in ascending-MAF order. A marker is queried only if all of
    the following hold:
      (a) at least ``order`` markers precede it in ``test_group``'s row order;
      (b) none of the ``order + 1`` genotypes in that window is "NN";
      (c) ``calculate_prob`` says the person's ``order`` preceding genotype codes predict
          a unique code for it among the individuals in ``remaining_GI``.

    Returns
    -------
    numpy.ndarray
        Shape (len(h_vals), test_size, query_count). ``Delta[h, person, q]`` is
        ``calculate_delta`` after the (q+1)-th accepted query; unfilled entries stay 0.
    """
    NN, test_size = beacon.shape[1], test_group.shape[1]
    Delta = np.zeros((len(h_vals), test_size, query_count))  # Init delta

    for person in range(test_size):
        print("[{}] - Person ID: {}".format(person, test_group.columns[person]))
        curAF, markers, snps = filter_biallelic(testAF, test_group.index.to_numpy(), np.array(test_group.iloc[:, person].values))

        CHR = test_group.iloc[:, person]  # all genotypes of this person, in row order

        # Marker id -> row position; a dict lookup is much faster than np.where per marker.
        CHR_ind = dict(zip(CHR.index, range(len(CHR.index))))

        min_af = np.minimum(curAF["referenceAlleleFrequency"].values, curAF["otherAlleleFrequency"].values)

        # Ascending minor-allele frequency, ties broken by marker id.
        sorted_ind = np.lexsort((markers, min_af))
        min_af = min_af[sorted_ind]
        snps = snps[sorted_ind]
        markers = markers[sorted_ind]

        min_af[min_af == 0] = 1e-9  # NOTE(original author): the AF table itself is not adjusted the same way.

        for h in h_vals:
            h_idx = h_vals[h]  # row of Delta for this threshold

            beacon_response = np.zeros((query_count), dtype=int)
            maf = np.zeros((query_count))

            to_infer_idx = np.where(min_af < (h / 100))[0]

            q = 0  # number of queries actually made
            for marker_pos in to_infer_idx:
                if q == query_count:
                    break

                CHR_idx = CHR_ind[markers[marker_pos]]

                if CHR_idx < order:
                    continue  # not enough preceding markers for a full window

                # Skip markers whose window (order preceding markers + itself) has a missing
                # call. .iloc makes the positional access explicit: Series[int] on a label
                # index is deprecated and becomes label-based.
                if (CHR.iloc[CHR_idx - order:CHR_idx + 1] == "NN").any():
                    continue

                # === SLOW PART BEGINS ===================================================================
                correlations = enumerate_alleles(remaining_GI.iloc[CHR_idx - order:CHR_idx + 1], testAF, False)

                # TODO(original author): unsure whether the prefix should span order or order + 1 markers.
                prefix = enumerate_alleles(CHR.iloc[CHR_idx - order:CHR_idx], testAF, True)

                # enumerate_alleles returns 0 when a marker has no allele-frequency row; skip
                # such windows instead of crashing.
                if np.isscalar(correlations) or np.isscalar(prefix):
                    continue

                if not calculate_prob(list(prefix.flatten()), correlations):
                    continue
                # === SLOW PART ENDS ======================================================================

                maf[q] = min_af[marker_pos]
                beacon_response[q] = query_beacon(markers[marker_pos], snps[marker_pos], beacon)
                Delta[h_idx, person, q] = calculate_delta(error, maf[:q+1], NN, beacon_response[:q+1])

                q += 1

    return Delta
    

    
    

In [None]:
# Entry point of an experiment: run one attack on all test groups, then plot its power curves.
def start_test(attack_type, test_beacon, t_alpha, error, query_count, NN, h_vals, test_groups, remaining_GI, af, LD_Graph, a_prime, b_prime, order=4, do_power_plot=True):
    """Run ``attack_type`` on every test group and plot case-vs-control power.

    Parameters
    ----------
    attack_type : str
        One of "GI", "SB", "QI" or "OPTIMAL".
    test_groups : sequence of (name, DataFrame)
        The attack runs on ``group[1]`` of every entry. The first two results are used as
        case and control when computing power.
    NN : int
        Unused; every attack reads the beacon size from ``test_beacon.shape[1]``.
    order : int, optional
        GI attack only: number of preceding markers in the correlation window.
    do_power_plot : bool, optional
        Forwarded to ``calculate_power``.
    (The remaining parameters are forwarded unchanged to the attack functions.)

    Returns
    -------
    list of numpy.ndarray
        One Delta array per test group, exactly as returned by the attack. Copies are
        passed to the power computation, so the returned arrays are never reordered.

    Raises
    ------
    ValueError
        If ``attack_type`` is not recognised.
    """
    if attack_type not in ("GI", "SB", "QI", "OPTIMAL"):
        raise ValueError("Unknown attack type: {}".format(attack_type))

    delta_values = []

    for test_group in test_groups:
        group = test_group[1]
        ch(attack_type)  # start timer
        if attack_type == "GI":
            delta = attack_GI(test_beacon, h_vals, error, group, remaining_GI, query_count, af, order)
        elif attack_type == "SB":
            delta = attack_SB(test_beacon, error, group, query_count, af, a_prime, b_prime)
        else:  # "QI" or "OPTIMAL"
            delta = attack(test_beacon, h_vals, error, group, query_count, af, LD_Graph, attack_type)
        ch(attack_type)  # stop timer
        delta_values.append(delta)

    # ====================== PLOTTING
    case_delta, control_delta = delta_values[0], delta_values[1]
    n_people = case_delta.shape[1]  # previously read from a notebook global
    # SB returns a single row (h = 0); the other attacks return one row per threshold in h_vals.
    for h_idx in range(case_delta.shape[0]):
        calculate_power(case_delta[h_idx].copy(), control_delta[h_idx].copy(), n_people, query_count, t_alpha, do_power_plot=do_power_plot)
    plt.show()
    # ===============================

    return delta_values


In [None]:
# Cold start: load the input files once and build the LD graph.
# ch(label) toggles a named chronometer (first call starts it, second call stops it).
ch(reset=True)
data_path = "./data/" 

ch("Loading")
beacon, init_af, ld = load_files(data_path)
# Keep only allele-frequency rows for markers present in the beacon.
# NOTE(review): filter_biallelic later pairs these rows positionally with the beacon rows,
# so this relies on both tables listing the same markers in the same order — confirm.
af = init_af[init_af["markerId"].isin(beacon.index)] # Get rid of unnecessary lines in the attack code.
ch("Loading")

ch("Constructing Graph")
LD_Graph = construct_graph(ld)
ch("Constructing Graph")

In [None]:
# Randomly split the beacon's individuals (columns) into the groups used by the attacks.
# Seeding numpy's global generator makes this split reproducible on Restart & Run All.
# The same seed also fixes the random marker order that attack_SB draws later with
# np.random.permutation.
SEED = 42
np.random.seed(SEED)

n_individuals = beacon.shape[1]
assert n_individuals >= 80, "the case/beacon/control split below needs at least 80 individuals"
idx = np.random.permutation(n_individuals)

# Case (20): individuals that ARE in the beacon (the first 20 of its 60 members)
case = beacon.iloc[:, idx[0:20]]

# Beacon (60)
test_beacon = beacon.iloc[:, idx[0:60]]

# Control (20): individuals NOT in the beacon
control = beacon.iloc[:, idx[60:80]]

# Remaining individuals for the GI attack's genotype correlations (84 of the 164)
remaining_GI = beacon.iloc[:, idx[80:]]

test_groups = [("case", case), ("control", control)]

In [None]:
# Parameters:
# MAF thresholds in percent. OPTIMAL/QI drop markers with MAF < h/100; GI targets them.
H_THRESHOLDS = [0, 3, 5]
h_vals = {h: i for i, h in enumerate(H_THRESHOLDS)}  # threshold -> row index in the Delta arrays
NN = test_beacon.shape[1]   # individuals in the beacon (60), derived so it cannot drift from the split
test_size = case.shape[1]   # individuals per test group (20)
query_count = 2000          # queries per individual
error = 0.001               # error rate used by calculate_delta / calculate_delta_SB
t_alpha = 5                 # percentile of control statistics used as the detection threshold
order = 4                   # GI attack: number of preceding markers in the correlation window

# Passed to the SB attack as get_DN(a_prime + 1, b_prime + 1, ...).
# Alternative values tried earlier: a_prime = 0.1300, b_prime = 1.1300.
a_prime = 0.0735
b_prime = 1.0096

In [None]:
# Run the SB attack on the case and control groups and plot its power curve.
# NOTE(review): ch(reset=True) presumably resets all chronometers — confirm in chronometer.py.
ch(reset=True)
deltas = start_test("SB", test_beacon, t_alpha, error, query_count, NN, h_vals, test_groups, remaining_GI, af, LD_Graph, a_prime, b_prime, order=order, do_power_plot=True)
ch(reset=True)