In [155]:
import torch

def get_parallel_db(db, remove_index):
    return torch.cat((db[0:remove_index],
                      db[remove_index+1:]))

def get_parallel_dbs(db):
    parallel_dbs = list()
    
    for i in range(len(db)):
        pdb = get_parallel_db(db, i)
        parallel_dbs.append(pdb)
    
    return parallel_dbs

def create_db_and_parallels(num_entries):
    db = torch.rand(num_entries) > 0.5
    pdbs = get_parallel_dbs(db)
    
    return db, pdbs

def query(db, threshold=5):
    return (db.sum() > threshold).float()

def sensitivity(query, n_entries=1000):
    db, pdbs = create_db_and_parallels(n_entries)
    
    full_db_result = query(db)
    
    sensitivity = 0
    for pdb in pdbs:
        pdb_result = query(pdb)

        db_distance = torch.abs(pdb_result - full_db_result)

        if(db_distance > sensitivity):
            sensitivity = db_distance
            
    return sensitivity

In [156]:
db, pdbs = create_db_and_parallels(100)
db

tensor([1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1,
        0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0,
        1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
        0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1,
        1, 1, 0, 1], dtype=torch.uint8)

In [157]:
true_result = torch.mean(db.float())
true_result

tensor(0.5400)

In [158]:
def query(db, noise=0.5):
    public_result = torch.mean(db.float())
    
    first_coin_flip = (torch.rand(len(db)) < noise).float()
    second_coin_flip = (torch.rand(len(db)) < noise).float()

    augumented_database = db.float() * first_coin_flip + (1 - first_coin_flip) * second_coin_flip

    private_result = torch.mean(augumented_database.float()) * 2 - 0.5
    
    return private_result, public_result

In [159]:
result = torch.rand(5)
result

tensor([0.5716, 0.3470, 0.9077, 0.5186, 0.9040])

In [160]:
result > 0.5

tensor([1, 0, 1, 1, 1], dtype=torch.uint8)

In [161]:
db, pdbs = create_db_and_parallels(10)

print("Without noise:" + str(torch.mean(db.float())))

private_result, public_result = query(db, noise=0.25)
print("With noise:" + str(private_result))

private_result, public_result = query(db, noise=1)
print("With noise:" + str(private_result))


Without noise:tensor(0.8000)
With noise:tensor(0.5000)
With noise:tensor(1.1000)
