In [None]:
from IPython.display import display, HTML
import pandas as pd
import numpy as np
def find_violations(df, cols):
    """
    计算的思想是:
    1. 根据整个FD (X,Y), 用 groupby(FD)找出所有的pattern;
    2. 这些pattern就是 .groups 得到的key, 对应的row index就是 value
    3. 针对X, 也就是除去 .groups 的最后一个元素key[:-1], Counter 计算 X对应的所有pattern的frequency
    4. 对于每个 (X,Y) 的 pattern (k1,k2,k3), 如果发现 X 部分 (k1,k2)的frequency > 1, 就是violations
    
    E.g. : 如果对(X,Y)group后, 有4个pattern (1,2,2), (1,2,3), (1,4,5), (1,6,5);
    但是count这些pattern的X部分, 发现 (1,2)的frequency>1, 说明(1,2)出现了两次,
    说明4个pattern中, X 部分有相同的, 这些X相同的pattern就是violation
    
    Args:
        df: dataframe
        cols: 指定的columns, 是一个list, 如果 FD: (key1, key2) -> key3, 
        那么就要传入cols = ['key1', 'key2', 'key3'].
    Return:
        a list of tuple, 每个tuple 由XY的pattern和对应的row index组成, 比如
        [((1, 2, 4), [1, 4]), ((1, 2, 3), [0])], 表示有两个violated pattern, 
        pattern1 是 (1,2,4), 对应的row index是 [1,4]; row index从0 开始
    """
    violations = []
    cols_groups = df.groupby(cols)
    # groups 返回 dict (key = group的key, value = row index)
    groups_keys = cols_groups.groups.keys()
    from collections import Counter
    # t[:-1] 是 key1, key2, 没有key3, 也就是FD 的 X
    c = Counter([t[:-1] for t in groups_keys])
    from collections import namedtuple
    # 使用的时候, 要用 V(group, index) 初始化, 而不是 用 Violation名字, 这个名字只是内部标识
    V = namedtuple('Violation', ['pattern','rowIndex'])
    for group, index in cols_groups.groups.items():
        # 对于k1k2 ->k3 来说, 如果k1,k2,k3的所有group中
        # k1,k2 对应的group多过 1, 说明 k3 有不一样的
        if c.get(group[:-1]) > 1:
            violation = V(group, index)
            violations.append(violation)
    return violations

def group_violations(violations):
    """
    把所有 X part 一样的 violations 都放到一个组中, 返回 一个dict, key为 X 的 pattern, value 为
     a list of violations, 并且该 list 根据 violation的 len(rowIndex) 从大到小排序.
     也就是说, value[0] 是 patten 的 frequency最高的 violation, 也就是minimal repair需要的基准. 
    Args:
        violations: the output of find_violations
    Return:
        a dict, key is the X part patten, value is violation object
    """
    from collections import defaultdict
    d_violations = defaultdict(list)
    for v in violations:
        d_violations[v.pattern[:-1]].append(v)
    for key, value in d_violations.items():
        value.sort(key = lambda x: len(x.rowIndex), reverse = True)
    return d_violations

def make_private(dataframe, p):
    """
    传入一个 dataframe, 以及随机的概率参数 p (也就是 epsilon), 把一个dataframe进行randomization private.
    Args:
        dataframe, pandas的dataframe
        p, differential privacy 参数 [0,1) , 越大, 表示越可能把当前的 value 进行 randomize;
    Return:
        privated dataframe after randomized 
    """
    import numpy as np
    # deep copy df, otherwise it affects the original non-private df
    privated_df = dataframe.copy(deep = True)
    count_row, count_col = privated_df.shape
    for col in privated_df:
        # 生成一些列的随机 p 序列, 避免每次都要生成
        p_array = np.random.rand(count_row)
        for i in range(count_row):
            if p_array[i] > p:
                privated_df[col][i] =privated_df[col].sample(1)
    return privated_df

def make_private2(dataframe, p):
    """
    用sample来代替上面的方法, 速度更快
    """
    import numpy as np
    # deep copy df, otherwise it affects the original non-private df
    privated_df = dataframe.copy(deep = True)
    count_row, count_col = privated_df.shape
    for col in privated_df:
        sample_col = privated_df[col].sample(frac=p, replace = False)
        reset_sample = sample_col.reset_index(drop = True)
        privated_df[col].update(reset_sample)
    return privated_df

def repair_operation(d_violations):
    """
    Args:
        the sorted grouped violations dict which is returned by group_violations func.
    Return:
        print out the repair operations
    """
    for key, vList in d_violations.items():
        print ('\n', '*'*20, 'violations', '*'*20)
        for v in vList:
            print (v)
        print ('\n', '*'*20, 'repair operations', '*'*20)
        ground_truth = vList[0].rowIndex[0]
    # 如果这里 r.rowIndex 加上 list(), 就会去掉 dtype, Int64Index
        error_index = [v.rowIndex for v in vList[1:]]
    #error_index = [list(v.rowIndex) for v in vList[1:]]
        print ('rowIndex repair: ', ground_truth, '<==', error_index)
        ground_truth_value = vList[0].pattern[-1]
        error_value = [v.pattern[-1] for v in vList[1:]]
        print ('value update: ', ground_truth_value, '<==', error_value )
        
def correct_repair(df, cols, p):
    violations = find_violations(df, cols)
    count_violations = sum([len(v.rowIndex) for v in violations])
    after_private_num_violations = count_violations  - count_violations * 1/2 * p
    after_private_clean_num_violations = after_private_num_violations * (1 - p)
    applied_repair = count_violations
    need_repair = count_violations  - count_violations  * p * p
    from collections import namedtuple
    Acc = namedtuple('accuracy', ['correct_repair', 'precision', 'recall'])
    
    return Acc(int(after_private_clean_num_violations), 
               int(after_private_clean_num_violations)/int(applied_repair), 
               int(after_private_clean_num_violations)/int(need_repair))

In [9]:
data_test = {'key1':[1]*5, 'key2':[2,2,3,3,2], 'key3':[3,4,3,3,4]}
frame_test = pd.DataFrame(data_test)
print (frame_test)
cols = ['key1', 'key2', 'key3']
violations_test = find_violations(frame_test, cols)
print ('original dataset: \n', frame_test)
repair_operation(group_violations(violations_test))
df_private = make_private2(frame_test, p = 0.8)
print ('\n private dataset: \n',df_private)
repair_operation(group_violations(find_violations(df_private,cols)))

   key1  key2  key3
0     1     2     3
1     1     2     4
2     1     3     3
3     1     3     3
4     1     2     4
original dataset: 
    key1  key2  key3
0     1     2     3
1     1     2     4
2     1     3     3
3     1     3     3
4     1     2     4

 ******************** violations ********************
Violation(pattern=(1, 2, 4), rowIndex=[1, 4])
Violation(pattern=(1, 2, 3), rowIndex=[0])

 ******************** repair operations ********************
rowIndex repair:  1 <== [[0]]
value update:  4 <== [3]

 private dataset: 
    key1  key2  key3
0     1     3     3
1     1     2     3
2     1     2     4
3     1     2     4
4     1     2     4

 ******************** violations ********************
Violation(pattern=(1, 2, 4), rowIndex=[2, 3, 4])
Violation(pattern=(1, 2, 3), rowIndex=[1])

 ******************** repair operations ********************
rowIndex repair:  2 <== [[1]]
value update:  4 <== [3]


In [3]:
# test for find_violations
book500_path = '../Data/books_500k.csv'
df_original = pd.read_csv(book500_path).dropna()
fd1 = ['book_title','publisher','book_author']
display(df_original.head())
df1 = df_original[['book_title','publisher','book_author']]
df = df1[20000:]
display(df.head())

Unnamed: 0,user_id,user_age,book_rating,isbn,book_title,book_author,publication_yr,publisher,img_url,city,state,country
1,99136,29.0,0,671027077,Far Harbor,JoAnn Ross,2000.0,Pocket,http://images.amazon.com/images/P/0671027077.0...,waynesboro,georgia,usa
2,216012,34.0,7,373218257,Love By Design,Nora Roberts,2003.0,Silhouette Books,http://images.amazon.com/images/P/0373218257.0...,clevelad,ohio,usa
5,142555,32.0,2,1860463614,Very Long Engagement,Sebastien Japrisot,1997.0,Havill Pr,http://images.amazon.com/images/P/1860463614.0...,belfast,northern ireland,united kingdom
6,12154,41.0,6,739405047,Lip Service,Mj Rose,0.0,Lady Chaterlys Library,http://images.amazon.com/images/P/0739405047.0...,pittsburgh,pennsylvania,usa
9,109461,30.0,0,312204671,It's Like That: A Spiritual Memoir,Joseph Simmons,2000.0,St. Martin's Press,http://images.amazon.com/images/P/0312204671.0...,oklahoma city,oklahoma,usa


Unnamed: 0,book_title,publisher,book_author
30841,The Ugly Duckling,Bantam Books,Iris Johansen
30844,Candide (Dover Thrift Editions),Dover Publications,Voltaire
30845,While I Was Gone,Ballantine Books,Sue Miller
30846,The Rapture: Truth or Consequences,Bantam Books,Hal Lindsey
30847,The Little Prince,Harvest Books,Antoine de Saint-Exupéry


In [29]:
# 通过sort 让相似的 violations 在一起
violations = find_violations(df, fd1)
sorted_violations = sorted(violations, key = lambda x:x[0])
for i in range(10):
    print (sorted_violations[i])

Violation(pattern=('A 6th Bowl of Chicken Soup for the Soul (Chicken Soup for the Soul)', 'Health Communications', 'Canfield. Jack'), rowIndex=[337046, 341078])
Violation(pattern=('A 6th Bowl of Chicken Soup for the Soul (Chicken Soup for the Soul)', 'Health Communications', 'Jack Canfield'), rowIndex=[376969])
Violation(pattern=("A Maiden's Grave", 'Signet Book', 'Jeff Deaver'), rowIndex=[35222, 99799, 122339, 129176, 155690, 161100, 167338, 186386, 268781, 396473, 404306, 437859, 444179])
Violation(pattern=("A Maiden's Grave", 'Signet Book', 'Jeffery Deaver'), rowIndex=[408444])
Violation(pattern=('A River Runs Through It, and Other Stories', 'University of Chicago Press', 'Norman F. Maclean'), rowIndex=[71456, 94574, 118279, 120564, 132347, 170672, 174290, 202310, 225492, 261786, 280896, 332663, 334213, 382831, 388202, 419360, 467799, 483501])
Violation(pattern=('A River Runs Through It, and Other Stories', 'University of Chicago Press', 'Norman Maclean'), rowIndex=[402729])
Violati

In [28]:
group_violations(violations)

defaultdict(list,
            {('1988',
              "St. Martin's Press"): [Violation(pattern=('1988', "St. Martin's Press", 'Andrew McGahan'), rowIndex=Int64Index([263637], dtype='int64')), Violation(pattern=('1988', "St. Martin's Press", 'Richard Lamm'), rowIndex=Int64Index([489717], dtype='int64'))],
             ('A 6th Bowl of Chicken Soup for the Soul (Chicken Soup for the Soul)',
              'Health Communications'): [Violation(pattern=('A 6th Bowl of Chicken Soup for the Soul (Chicken Soup for the Soul)', 'Health Communications', 'Canfield. Jack'), rowIndex=Int64Index([25566, 337046, 341078, 369627], dtype='int64')),
              Violation(pattern=('A 6th Bowl of Chicken Soup for the Soul (Chicken Soup for the Soul)', 'Health Communications', 'Jack Canfield'), rowIndex=Int64Index([3921, 376969], dtype='int64'))],
             ('A Little Princess',
              'Scholastic Paperbacks (T)'): [Violation(pattern=('A Little Princess', 'Scholastic Paperbacks (T)', 'Diane Molles

## minimal cardinality repair

1. 找出violations, X 部分相同, Y 部分不同的, 也就是找出 similar 的violations, 放入violation group
2. 对一个 violation group中的所有 violation, 计算他的 对应的 row 的个数 len(row_index)  
3. 找到该 violation group 中, len(row_index) 最大的那个 violation $v_m$, 然后把该 group 中其他的 violations 的 Y 都改成 $v_m$的 Y
4. 记录下被更改了的 row index, 以及更改前后的 value

In [31]:
print ('original dataset: \n')
display(df)

original dataset: 



Unnamed: 0,book_title,publisher,book_author
30841,The Ugly Duckling,Bantam Books,Iris Johansen
30844,Candide (Dover Thrift Editions),Dover Publications,Voltaire
30845,While I Was Gone,Ballantine Books,Sue Miller
30846,The Rapture: Truth or Consequences,Bantam Books,Hal Lindsey
30847,The Little Prince,Harvest Books,Antoine de Saint-Exupéry
30850,Before I Say Good-Bye : A Novel,Simon &amp; Schuster,Mary Higgins Clark
30852,Ultimate Spider-man: Power and Responsibility,Marvel Entertainment Group,Bill Jemas
30853,Kate Greenaway's Cross-Stitch Designs,Sterling Pub Co Inc,Julie Hasler
30855,Miss Julia Meets Her Match,Viking Books,Ann B. Ross
30856,An den Ufern versinkt die Zeit.,Goldmann,Barbara Taylor Bradford


In [51]:
repair_operation(group_violations(find_violations(df1, fd1)))


 ******************** violations ********************
Violation(pattern=('While I Was Gone', 'Ballantine Books', 'Sue Miller'), rowIndex=[7232, 10560, 11007, 14182, 20371, 22345, 23214, 30845, 32336, 44991, 49523, 50191, 57992, 65228, 66234, 66777, 75296, 75775, 80956, 83889, 88300, 92328, 94751, 97945, 105458, 109424, 118607, 119124, 122051, 123156, 127107, 132828, 143481, 143488, 152025, 154440, 154614, 158354, 163684, 167838, 169729, 175252, 175533, 177462, 178887, 179986, 183103, 183430, 196397, 197687, 203652, 211852, 212381, 216669, 221160, 225792, 228707, 232218, 235500, 239608, 240645, 242922, 247744, 256096, 259204, 267784, 270373, 270478, 273061, 274186, 276013, 288691, 299901, 304705, 306978, 319184, 324937, 327432, 344087, 347113, 348949, 353362, 358534, 358583, 371248, 377192, 378260, 386922, 391744, 400284, 407605, 411860, 416381, 425284, 425462, 429809, 434784, 442453, 444088, 452478, 459593, 463108, 463340, 464857, 469227, 475089, 479573, 479587, 482969, 492376, 493265

In [53]:
df_private = make_private2(df1, p = 0.2)
print ('\n private dataset: \n')
display(df_private)


 private dataset: 



Unnamed: 0,book_title,publisher,book_author
1,Nelson's Quick Reference Bible Dictionary : Ne...,Warner Books,Mary Manin Morrissey
2,Nutcases - Tort (Nutcases),Free Press,Edgar Allan Poe
5,Songmaster,Tor Books,Judy Mercer
6,Drowning Ruth,Tor Books,Michael Dorris
9,16 Lighthouse Road,Fawcett Books,Mary Stewart
10,Ada Blackjack : A True Story of Survival in th...,Penguin Books,Fergus Henderson
14,"The Siege (Star Trek Deep Space Nine, No 2)",Harlequin,Anne McCaffrey
18,"Key of Valor (Roberts, Nora. Key Trilogy, 3.)",Star Trek,Caroline Slate
19,Acts of Vengeance: A Mystery,HarperTrophy,Richard E Grant
21,"Breath, Eyes, Memory",Pinnacle Books,Katherine Applegate


In [54]:
repair_operation(group_violations(find_violations(df_private,fd1)))


 ******************** violations ********************
Violation(pattern=('Heat', 'Warner Books', 'William Goldman'), rowIndex=[281074])
Violation(pattern=('Heat', 'Warner Books', 'Seamus Dunn'), rowIndex=[6489])

 ******************** repair operations ********************
rowIndex repair:  281074 <== [[6489]]
value update:  William Goldman <== ['Seamus Dunn']

 ******************** violations ********************
Violation(pattern=('Dune (Remembering Tomorrow)', 'Ivy Books', 'Pascale Clark'), rowIndex=[2982])
Violation(pattern=('Dune (Remembering Tomorrow)', 'Ivy Books', 'Barbara Michaels'), rowIndex=[63143])

 ******************** repair operations ********************
rowIndex repair:  2982 <== [[63143]]
value update:  Pascale Clark <== ['Barbara Michaels']

 ******************** violations ********************
Violation(pattern=('Blue Diary', 'Harlequin', 'Alcoholics Anonymous World Service'), rowIndex=[3812])
Violation(pattern=('Blue Diary', 'Harlequin', 'James Hamilton-Paterson'

$$num(CorrectRepairs) = num(ViolationsBeforeClean) - num(ViolationsAfterClean)$$

In [8]:
for p in np.arange(0,0.55,0.05):
    num_correct_repair, precision, recall = correct_repair(df1, fd1, p)
    f1 = 2* (precision*recall)/(precision + recall)
    print('p = ', p, 
          ', correct_repair = ', num_correct_repair,
         ', precision = ', precision,
         ', recall = ', recall,
         ', f1 = ', f1)

p =  0.0 , correct_repair =  2760 , precision =  1.0 , recall =  1.0 , f1 =  1.0
p =  0.05 , correct_repair =  2556 , precision =  0.9260869565217391 , recall =  0.928441699963676 , f1 =  0.9272628333031018
p =  0.1 , correct_repair =  2359 , precision =  0.8547101449275363 , recall =  0.8634699853587116 , f1 =  0.8590677348871085
p =  0.15 , correct_repair =  2170 , precision =  0.7862318840579711 , recall =  0.8045977011494253 , f1 =  0.7953087777166942
p =  0.2 , correct_repair =  1987 , precision =  0.7199275362318841 , recall =  0.7500943752359381 , f1 =  0.7347014235533371
p =  0.25 , correct_repair =  1811 , precision =  0.6561594202898551 , recall =  0.7000386548125241 , f1 =  0.6773891902001122
p =  0.3 , correct_repair =  1642 , precision =  0.5949275362318841 , recall =  0.6539227399442453 , f1 =  0.6230316827926389
p =  0.35 , correct_repair =  1480 , precision =  0.5362318840579711 , recall =  0.6113176373399422 , f1 =  0.5713182783246478
p =  0.4 , correct_repair =  1324 

In [9]:
for p in np.arange(0,1,0.1):
    num_correct_repair, precision, recall = correct_repair(df1, fd1, p)
    f1 = 2* (precision*recall)/(precision + recall)
    print(p, 
          ',',100*precision,
         ', ', 100*recall,
         ', ', 100*f1)

0.0 , 100.0 ,  100.0 ,  100.0
0.1 , 85.47101449275362 ,  86.34699853587115 ,  85.90677348871085
0.2 , 71.9927536231884 ,  75.00943752359382 ,  73.4701423553337
0.3 , 59.49275362318841 ,  65.39227399442453 ,  62.303168279263886
0.4 , 47.971014492753625 ,  57.118205349439165 ,  52.14651437573848
0.5 , 37.5 ,  50.0 ,  42.857142857142854
0.6 , 27.97101449275362 ,  43.71460928652321 ,  34.1140079540433
0.7 , 19.492753623188406 ,  38.237384506041224 ,  25.82193424526038
0.8 , 11.992753623188406 ,  33.33333333333333 ,  17.639221955768715
0.9 , 5.471014492753623 ,  28.816793893129773 ,  9.196102314250915


In [1]:
l = [1,2,3]
l

[1, 2, 3]