In [23]:
import dp_relational
import dp_relational.data.movies
import dp_relational.lib.qm
import dp_relational.lib.synth_data

import numpy as np
import time

# parameters

class ModelRunner:
    """Staged, cached pipeline for one relational synthetic-data experiment.

    Four stages run in order, each consuming the output of the previous one:

    1. dataset          -- ``dataset_generator(dmax)``
    2. synthetic tables -- one ``compute_single_table_synth_data`` call per
                           table, using ``n_syn1``/``n_syn2``, ``synth`` and
                           ``eps1``/``eps2``
    3. query manager    -- ``qm_generator(rel_dataset, k=..., df1_synth=...,
                           df2_synth=...)``
    4. cross answers    -- ``cross_generation_strategy(qm,
                           epsilon - eps1 - eps2, T=T)``

    ``update`` records which stages became stale when parameters change, and
    ``run`` recomputes only the stale stages, reusing cached results for the
    rest.
    """

    # Parameters grouped by the first pipeline stage that reads them.
    _DATASET_PARAMS = ("dataset_generator", "dmax")
    _SYN_TABLE_PARAMS = ("n_syn1", "n_syn2", "eps1", "eps2", "synth")
    _QM_PARAMS = ("qm_generator", "k")
    _CROSS_PARAMS = ("epsilon", "cross_generation_strategy", "T")

    class _StageTimer:
        """Context manager that stores a stage's wall-clock time in ``sink[name]``."""

        def __init__(self, sink, name):
            self.sink = sink
            self.name = name

        def __enter__(self):
            self._start = time.perf_counter()
            return self

        def __exit__(self, exception_type, exception_value, traceback):
            # The duration is recorded even when the stage raises; returning
            # None lets the exception propagate to the caller.
            self.sink[self.name] = time.perf_counter() - self._start

    def __init__(self, *args, **kwargs) -> None:
        """Create a runner; any arguments are forwarded unchanged to ``update``."""
        # Tunable parameters (see ``update``).
        self.dataset_generator = None
        self.n_syn1 = None
        self.n_syn2 = None
        self.synth = None
        self.epsilon = None
        self.eps1 = None
        self.eps2 = None
        self.k = None
        self.dmax = None
        self.qm_generator = None
        self.cross_generation_strategy = None
        self.T = None

        # Staleness flags, one per pipeline stage.
        self.regenerate_dataset = False
        self.regenerate_syn_tables = False
        self.regenerate_qm = False
        self.regenerate_cross_answers = False

        # Cached stage outputs.
        self.rel_dataset = None
        self.df1_synth = None
        self.df2_synth = None
        self.qm = None
        self.relationship_syn = None

        # Per-stage timings of the most recent ``run`` call.
        self.times = {}

        self.update(*args, **kwargs)

    def update(self, dataset_generator=None, n_syn1=None, n_syn2=None, synth=None,
                 epsilon=None, eps1=None, eps2=None, k=None,
                 dmax=None, qm_generator=None, cross_generation_strategy=None, T=None):
        """Change parameters and mark the affected pipeline stages as stale.

        An argument left as ``None`` keeps the current value, so a parameter
        cannot be reset to ``None`` through this method. A stage is marked
        stale when one of its own parameters differs (``!=``) from the stored
        value, and every stage after a stale stage is marked stale as well.
        Flags are only ever set here; ``run`` is what clears them.
        """
        requested = {
            "dataset_generator": dataset_generator,
            "n_syn1": n_syn1,
            "n_syn2": n_syn2,
            "synth": synth,
            "epsilon": epsilon,
            "eps1": eps1,
            "eps2": eps2,
            "k": k,
            "dmax": dmax,
            "qm_generator": qm_generator,
            "cross_generation_strategy": cross_generation_strategy,
            "T": T,
        }
        # None means "leave this parameter as it is".
        resolved = {
            name: (getattr(self, name) if value is None else value)
            for name, value in requested.items()
        }

        def changed(names):
            return any(resolved[name] != getattr(self, name) for name in names)

        if changed(self._DATASET_PARAMS):
            self.regenerate_dataset = True
        if changed(self._SYN_TABLE_PARAMS):
            self.regenerate_syn_tables = True
        if changed(self._QM_PARAMS):
            self.regenerate_qm = True
        if changed(self._CROSS_PARAMS):
            self.regenerate_cross_answers = True

        # ensure that all downstream stages are forced to regenerate
        self.regenerate_syn_tables |= self.regenerate_dataset
        self.regenerate_qm |= self.regenerate_syn_tables
        self.regenerate_cross_answers |= self.regenerate_qm

        # actually copy in the values
        for name, value in resolved.items():
            setattr(self, name, value)

    @staticmethod
    def _callable_name(value):
        """Return ``value.__name__`` when it exists, otherwise ``value`` itself."""
        return getattr(value, "__name__", value)

    def dump_parameters(self):
        """Return the current parameters as a plain dict.

        The three callables (``dataset_generator``, ``qm_generator`` and
        ``cross_generation_strategy``) are reported by ``__name__``. A value
        without a ``__name__`` (for example ``None`` on an unconfigured
        runner) is returned unchanged instead of raising.
        """
        return {
            "dataset_generator": self._callable_name(self.dataset_generator),
            "n_syn1": self.n_syn1,
            "n_syn2": self.n_syn2,
            "synth": self.synth,
            "epsilon": self.epsilon,
            "eps1": self.eps1,
            "eps2": self.eps2,
            "k": self.k,
            "dmax": self.dmax,
            "qm_generator": self._callable_name(self.qm_generator),
            "cross_generation_strategy": self._callable_name(self.cross_generation_strategy),
            "T": self.T
        }

    def run(self):
        """Recompute every stale stage, then evaluate the synthetic relationship table.

        Returns:
            ``(times, (ave_error, errors), relationship_syn)`` where ``times``
            maps the name of each stage executed in this call to its duration
            in seconds. Stages served from cache do not appear in ``times``.

        A stage's staleness flag is cleared only after the stage finishes
        without raising, so a failed stage is retried by the next ``run``
        instead of being skipped with missing or outdated cached data.
        """
        self.times = {}

        if self.regenerate_dataset:
            with self._StageTimer(self.times, "dataset_generation"):
                self.rel_dataset = self.dataset_generator(self.dmax)
            self.regenerate_dataset = False

        if self.regenerate_syn_tables:
            with self._StageTimer(self.times, "synth_table_generation"):
                self.df1_synth = dp_relational.lib.synth_data.compute_single_table_synth_data(
                    self.rel_dataset.table1.df, self.n_syn1, self.synth, epsilon=self.eps1)
                self.df2_synth = dp_relational.lib.synth_data.compute_single_table_synth_data(
                    self.rel_dataset.table2.df, self.n_syn2, self.synth, epsilon=self.eps2)
            self.regenerate_syn_tables = False

        if self.regenerate_qm:
            with self._StageTimer(self.times, "qm_init"):
                self.qm = self.qm_generator(self.rel_dataset, k=self.k, df1_synth=self.df1_synth, df2_synth=self.df2_synth)
            self.regenerate_qm = False

        if self.regenerate_cross_answers:
            with self._StageTimer(self.times, "cross_answers_gen"):
                # The relationship stage receives what is left of epsilon
                # after eps1 and eps2 are subtracted.
                self.relationship_syn = self.cross_generation_strategy(self.qm, self.epsilon - self.eps1 - self.eps2, T=self.T)
            self.regenerate_cross_answers = False

        ave_error, errors = dp_relational.lib.synth_data.evaluate_synthetic_rel_table(self.qm, self.relationship_syn)
        return (self.times, (ave_error, errors), self.relationship_syn)
        
def qm_generator_basic(rel_dataset, k, df1_synth, df2_synth):
    """Construct a ``QueryManagerBasic`` for the dataset and its two synthetic tables.

    All arguments are forwarded as-is: ``rel_dataset`` positionally, and
    ``k``, ``df1_synth`` and ``df2_synth`` as keywords.
    """
    manager_cls = dp_relational.lib.synth_data.QueryManagerBasic
    synth_tables = {"df1_synth": df1_synth, "df2_synth": df2_synth}
    return manager_cls(rel_dataset, k=k, **synth_tables)

def qm_generator_torch(rel_dataset, k, df1_synth, df2_synth):
    """Construct a ``QueryManagerTorch`` for the dataset and its two synthetic tables.

    All arguments are forwarded as-is: ``rel_dataset`` positionally, and
    ``k``, ``df1_synth`` and ``df2_synth`` as keywords.
    """
    manager_cls = dp_relational.lib.synth_data.QueryManagerTorch
    synth_tables = {"df1_synth": df1_synth, "df2_synth": df2_synth}
    return manager_cls(rel_dataset, k=k, **synth_tables)

def cross_generator_basic(qm, eps_rel, T, verbose=True):
    """Learn a relationship vector with the basic solver and build the synthetic table.

    Args:
        qm: query manager passed to both library calls.
        eps_rel: forwarded as the second positional argument of
            ``learn_relationship_vector_basic`` (presumably the privacy
            budget for the relationship stage -- confirm against the library).
        T: forwarded as ``T=``.
        verbose: forwarded as ``verbose=``. Defaults to ``True``, the value
            that used to be hard-coded.

    Returns:
        The result of ``make_synthetic_rel_table(qm, b_round)``.
    """
    b_round = dp_relational.lib.synth_data.learn_relationship_vector_basic(qm, eps_rel, T=T, verbose=verbose)
    relationship_syn = dp_relational.lib.synth_data.make_synthetic_rel_table(qm, b_round)
    return relationship_syn

def cross_generator_torch(qm, eps_rel, T, T_mirror=50, verbose=True):
    """Learn a relationship vector with the torch solver and build the sparse synthetic table.

    Args:
        qm: query manager passed to both library calls.
        eps_rel: forwarded as the second positional argument of
            ``learn_relationship_vector_torch`` (presumably the privacy
            budget for the relationship stage -- confirm against the library).
        T: forwarded as ``T=``.
        T_mirror: forwarded as ``T_mirror=``. Defaults to ``50``, the value
            that used to be hard-coded.
        verbose: forwarded as ``verbose=``. Defaults to ``True``, the value
            that used to be hard-coded.

    Returns:
        The result of ``make_synthetic_rel_table_sparse(qm, b_round)``.
    """
    b_round = dp_relational.lib.synth_data.learn_relationship_vector_torch(
        qm, eps_rel, T=T, T_mirror=T_mirror, verbose=verbose)
    relationship_syn = dp_relational.lib.synth_data.make_synthetic_rel_table_sparse(qm, b_round)
    return relationship_syn

# Shared runner used by the sweep cells below. The constructor forwards its
# keyword arguments to ModelRunner.update, so this marks every stage as stale
# for the first run().
runner = ModelRunner(
    dataset_generator=dp_relational.data.movies.dataset,
    n_syn1=776,
    n_syn2=1208,
    synth='mst',
    epsilon=3.0,
    eps1=1.0,
    eps2=1.0,
    k=2,
    dmax=10,
    qm_generator=qm_generator_torch,
    cross_generation_strategy=cross_generator_torch,
    T=7,
)

In [2]:
# Values of the runner's `epsilon` parameter to sweep over.
epsilons = [3, 5, 7, 9, 12, 16]
# One entry per run, in execution order; each entry is the tuple returned by
# runner.run(): (times, (ave_error, errors), relationship_syn).
results = []
for epsilon in epsilons:
    # Changing only epsilon marks just the cross-answer stage stale, so after
    # the very first run the dataset, synthetic tables and query manager are
    # reused from cache.
    runner.update(epsilon=epsilon)
    res = runner.run()
    print(res)
    results.append(res)



Fitting with 262144 dimensions
Getting cliques
Estimating marginals
Fitting with 294 dimensions
Getting cliques
Estimating marginals


100%|██████████| 7/7 [04:11<00:00, 35.95s/it]


({'dataset_generation': 0.6182566999999999, 'synth_table_generation': 336.1668837, 'qm_init': 0.19208209999999326, 'cross_answers_gen': 252.44764679999997}, (2.4551177574127343, array([0.12258706, 0.06480402, 0.09153019, ..., 0.00049628, 0.00019851,
       0.00408182])),      MovieID  UserID
0          2     623
1          4     825
2          7      68
3          7     507
4          7     562
..       ...     ...
397      764     402
398      764     933
399      765     564
400      765     927
401      771     525

[402 rows x 2 columns])


100%|██████████| 7/7 [03:50<00:00, 32.93s/it]


({'cross_answers_gen': 231.14690270000006}, (1.8541574562811791, array([0.00408978, 0.05291164, 0.05019882, ..., 0.00049628, 0.00019851,
       0.006588  ])),      MovieID  UserID
0          0     405
1          0     653
2          2     475
3          5    1154
4          7     374
..       ...     ...
396      768    1146
397      769    1174
398      775     547
399      775     551
400      775     930

[401 rows x 2 columns])


100%|██████████| 7/7 [01:37<00:00, 13.91s/it]


({'cross_answers_gen': 97.59466280000004}, (1.5194146444507544, array([0.0479602 , 0.01479797, 0.0119282 , ..., 0.00049628, 0.00019851,
       0.00159426])),      MovieID  UserID
0          1     164
1          1     949
2          6     515
3          6     993
4         15     223
..       ...     ...
397      768     334
398      769    1197
399      770     953
400      770    1040
401      774     916

[402 rows x 2 columns])


100%|██████████| 7/7 [01:28<00:00, 12.64s/it]


({'cross_answers_gen': 88.71277899999995}, (1.8551969034311566, array([0.08069825, 0.10170182, 0.00715779, ..., 0.00049628, 0.00019851,
       0.00160047])),      MovieID  UserID
0          1     361
1          1     368
2          2    1022
3          3     772
4          4      42
..       ...     ...
396      755     175
397      755     340
398      763     198
399      764    1037
400      775     930

[401 rows x 2 columns])


100%|██████████| 7/7 [01:29<00:00, 12.74s/it]


({'cross_answers_gen': 89.39093019999996}, (1.8465856088549075, array([0.01174129, 0.02474822, 0.02538523, ..., 0.00049628, 0.00228905,
       0.0008933 ])),      MovieID  UserID
0          1     746
1          2     674
2          8    1031
3          9     236
4         10     543
..       ...     ...
397      767     461
398      769     198
399      769     279
400      774    1022
401      775     631

[402 rows x 2 columns])


100%|██████████| 7/7 [01:28<00:00, 12.71s/it]


({'cross_answers_gen': 89.19384450000007}, (1.4135426070427277, array([0.00089776, 0.0019512 , 0.03274246, ..., 0.00049628, 0.00019851,
       0.00160047])),      MovieID  UserID
0          1     811
1          2     844
2          4     397
3          6     184
4          7      63
..       ...     ...
396      766     165
397      766     795
398      766    1156
399      771     993
400      773     773

[401 rows x 2 columns])


In [3]:
# Second sweep over the same epsilon values. It appends to the `results` list
# created in an earlier cell, so re-running this cell grows the list again.
for epsilon in epsilons:
    runner.update(epsilon=epsilon)
    res = runner.run()
    print(res)
    results.append(res)

100%|██████████| 7/7 [03:20<00:00, 28.62s/it]


({'cross_answers_gen': 200.62221650000015}, (2.464500766089596, array([0.08776119, 0.01479797, 0.02685357, ..., 0.00049628, 0.00019851,
       0.0008933 ])),      MovieID  UserID
0          0     960
1          2    1075
2          6     509
3         15     986
4         15    1044
..       ...     ...
397      760     223
398      766     970
399      766    1144
400      767     535
401      771    1133

[402 rows x 2 columns])


100%|██████████| 7/7 [04:55<00:00, 42.19s/it]


({'cross_answers_gen': 296.14116420000005}, (1.1424988374916156, array([0.06149254, 0.00236016, 0.07264891, ..., 0.00049628, 0.00019851,
       0.0008933 ])),      MovieID  UserID
0          1     164
1          8    1161
2         11     750
3         12     352
4         16     843
..       ...     ...
397      761    1151
398      764     449
399      767     585
400      772     941
401      774      29

[402 rows x 2 columns])


100%|██████████| 7/7 [04:54<00:00, 42.13s/it]


({'cross_answers_gen': 295.60381370000005}, (1.4465091750909465, array([0.02154613, 0.01691379, 0.02212038, ..., 0.00049628, 0.00019851,
       0.00409423])),      MovieID  UserID
0          0      29
1          1     553
2          1     594
3          3     235
4          3     516
..       ...     ...
396      764     122
397      765     529
398      769     798
399      769     846
400      770     347

[401 rows x 2 columns])


100%|██████████| 7/7 [01:40<00:00, 14.41s/it]


({'cross_answers_gen': 101.39058900000009}, (1.8375890513428625, array([0.02666667, 0.04241596, 0.00944064, ..., 0.00049628, 0.00019851,
       0.00408182])),      MovieID  UserID
0          0     416
1          4    1001
2          8     366
3          9     521
4         13     277
..       ...     ...
397      756     997
398      757     570
399      759     782
400      762       8
401      774     652

[402 rows x 2 columns])


100%|██████████| 7/7 [01:40<00:00, 14.37s/it]


({'cross_answers_gen': 100.89542080000001}, (1.6313477187173222, array([0.0241791 , 0.02500302, 0.01543498, ..., 0.00049628, 0.00019851,
       0.0008933 ])),      MovieID  UserID
0          0     375
1          0     439
2          0     508
3          3     144
4          3     647
..       ...     ...
397      764     886
398      768     240
399      768     726
400      771     908
401      775     660

[402 rows x 2 columns])


100%|██████████| 7/7 [02:17<00:00, 19.70s/it]


({'cross_answers_gen': 138.19017630000008}, (1.8850962081422684, array([0.00069652, 0.02997815, 0.01939089, ..., 0.00049628, 0.00019851,
       0.00408182])),      MovieID  UserID
0          6    1014
1          7     801
2         10     241
3         14     558
4         14     933
..       ...     ...
397      765     258
398      769     451
399      769    1094
400      770     619
401      772     383

[402 rows x 2 columns])


In [10]:
# Two more full sweeps over `epsilons`; `x` is only a repetition counter.
# Each run is appended to the shared `results` list, so every execution of
# this cell adds 2 * len(epsilons) entries.
for x in range(2):
    for epsilon in epsilons:
        runner.update(epsilon=epsilon)
        res = runner.run()
        print(res)
        results.append(res)


100%|██████████| 7/7 [04:12<00:00, 36.01s/it]


({'cross_answers_gen': 252.54946500000005}, (3.0014552351773816, array([0.15870324, 0.09031813, 0.02461415, ..., 0.00049628, 0.00229525,
       0.00160047])),      MovieID  UserID
0          2    1190
1          3     875
2          4     818
3         14     228
4         14     740
..       ...     ...
396      756     522
397      761     691
398      769     619
399      769     635
400      771      73

[401 rows x 2 columns])


100%|██████████| 7/7 [02:40<00:00, 22.97s/it]


({'cross_answers_gen': 161.49536269999953}, (1.353730245930535, array([0.03164179, 0.06977914, 0.03431626, ..., 0.00049628, 0.00019851,
       0.0008933 ])),      MovieID  UserID
0          1     141
1          1     281
2          2    1058
3          9     363
4         10      74
..       ...     ...
397      774     274
398      775     143
399      775     163
400      775     368
401      775     943

[402 rows x 2 columns])


100%|██████████| 7/7 [03:12<00:00, 27.44s/it]


({'cross_answers_gen': 192.69954919999873}, (1.6273941559115719, array([0.04578554, 0.02439509, 0.00531104, ..., 0.00049628, 0.00229525,
       0.0008933 ])),      MovieID  UserID
0          2      73
1          5     595
2          6     741
3          9     501
4         11     181
..       ...     ...
396      770     916
397      772     552
398      773     549
399      774     572
400      775     843

[401 rows x 2 columns])


100%|██████████| 7/7 [02:09<00:00, 18.57s/it]


({'cross_answers_gen': 130.6003643999993}, (2.0129448794992046, array([0.09383085, 0.05236621, 0.03284792, ..., 0.00049628, 0.00019851,
       0.00159426])),      MovieID  UserID
0          2    1133
1          4      82
2          4     347
3          4     505
4          7    1016
..       ...     ...
397      762    1163
398      765     613
399      766      11
400      769     891
401      775     602

[402 rows x 2 columns])


100%|██████████| 7/7 [01:33<00:00, 13.30s/it]


({'cross_answers_gen': 93.38140179999937}, (1.670506990824757, array([0.07890547, 0.07475427, 0.01045986, ..., 0.00199128, 0.00019851,
       0.00408182])),      MovieID  UserID
0          1     290
1          5     534
2          5     933
3          7    1087
4         10      38
..       ...     ...
397      770     728
398      772     352
399      772    1150
400      774     244
401      774     619

[402 rows x 2 columns])


100%|██████████| 7/7 [01:31<00:00, 13.01s/it]


({'cross_answers_gen': 91.33946569999898}, (1.6862149836327296, array([0.0315212 , 0.01442003, 0.02710791, ..., 0.00049628, 0.00019851,
       0.0008933 ])),      MovieID  UserID
0          4     302
1          4     784
2          6     760
3          6     784
4          8     177
..       ...     ...
396      761     651
397      770     846
398      773      28
399      775     347
400      775     549

[401 rows x 2 columns])


100%|██████████| 7/7 [01:32<00:00, 13.23s/it]


({'cross_answers_gen': 92.84864969999944}, (3.069781066128906, array([0.01835411, 0.04683898, 0.00466402, ..., 0.00049628, 0.00019851,
       0.0008933 ])),      MovieID  UserID
0          4     657
1          5     129
2          8     829
3         10     525
4         11    1136
..       ...     ...
396      773     340
397      773     473
398      774    1163
399      775     221
400      775     517

[401 rows x 2 columns])


100%|██████████| 7/7 [01:30<00:00, 12.93s/it]


({'cross_answers_gen': 90.7603321000006}, (1.6380730297597577, array([0.03412935, 0.04739108, 0.00050961, ..., 0.00049628, 0.00477661,
       0.0008933 ])),      MovieID  UserID
0          0     858
1          4     985
2          7     503
3         11     193
4         11     815
..       ...     ...
397      765     731
398      767    1124
399      771     798
400      772     971
401      773    1198

[402 rows x 2 columns])


100%|██████████| 7/7 [01:31<00:00, 13.01s/it]


({'cross_answers_gen': 91.303837200001}, (1.486885625748147, array([0.040798  , 0.02049269, 0.04521129, ..., 0.00199749, 0.00019851,
       0.0008933 ])),      MovieID  UserID
0          2     190
1          4     821
2          5     152
3          5     383
4          7     702
..       ...     ...
396      767    1191
397      770     134
398      773     284
399      774     949
400      775     456

[401 rows x 2 columns])


100%|██████████| 7/7 [01:31<00:00, 13.12s/it]


({'cross_answers_gen': 92.10404089999975}, (1.6512133873568995, array([0.00588529, 0.03296152, 0.0302487 , ..., 0.00049628, 0.00019851,
       0.0008933 ])),      MovieID  UserID
0          4    1143
1          6     379
2         11     500
3         12    1080
4         15     352
..       ...     ...
396      770     803
397      773     945
398      774     671
399      774    1056
400      775      43

[401 rows x 2 columns])


100%|██████████| 7/7 [01:34<00:00, 13.51s/it]


({'cross_answers_gen': 94.83119679999982}, (1.6616297247335605, array([0.0280597 , 0.03246571, 0.04675407, ..., 0.00049628, 0.00019851,
       0.00159426])),      MovieID  UserID
0          0     755
1          0    1206
2          1     840
3          1     866
4          2     224
..       ...     ...
397      764     573
398      766    1036
399      768     439
400      769     326
401      775      22

[402 rows x 2 columns])


100%|██████████| 7/7 [01:34<00:00, 13.45s/it]


({'cross_answers_gen': 94.40386650000073}, (1.6139474486770264, array([0.03910448, 0.00484772, 0.04031061, ..., 0.00049628, 0.00019851,
       0.00159426])),      MovieID  UserID
0          1     565
1          8     811
2         11    1075
3         13     379
4         15      57
..       ...     ...
397      768     122
398      768     286
399      768     619
400      773     384
401      774     963

[402 rows x 2 columns])


In [12]:
# Each result is (times, (ave_error, errors), relationship_syn), so x[1][0]
# is the average error of that run.
print([x[1][0] for x in results])
# NOTE(review): the count depends on how many times the sweep cells were
# executed in this kernel session, not only on the cells visible above.
print(len(results))

[2.4551177574127343, 1.8541574562811791, 1.5194146444507544, 1.8551969034311566, 1.8465856088549075, 1.4135426070427277, 2.464500766089596, 1.1424988374916156, 1.4465091750909465, 1.8375890513428625, 1.6313477187173222, 1.8850962081422684, 2.9671570545192942, 1.327093304327638, 1.0812450538940765, 1.9154043220164225, 1.692299690775056, 1.8690452895424334, 2.1787484322711417, 1.8146224122471912, 1.7680172339296074, 1.6495564636152023, 1.3304450481186063, 1.5554325148451291, 3.0014552351773816, 1.353730245930535, 1.6273941559115719, 2.0129448794992046, 1.670506990824757, 1.6862149836327296, 3.069781066128906, 1.6380730297597577, 1.486885625748147, 1.6512133873568995, 1.6616297247335605, 1.6139474486770264]
36


In [15]:
import pickle
# Persist the raw run results together with the epsilon grid so the sweep
# does not have to be re-run to analyse it.
with open("experiments1.pickle", "wb") as fp:
    pickle.dump((results, epsilons), fp)

In [16]:
# Read the saved (results, epsilons) tuple back as a round-trip check.
# NOTE(review): pickle.load can execute arbitrary code; only load pickle
# files that were produced by a trusted source such as the cell above.
with open("experiments1.pickle", "rb") as fp:   # Unpickling
    b = pickle.load(fp)
    print(b)

([({'dataset_generation': 0.6182566999999999, 'synth_table_generation': 336.1668837, 'qm_init': 0.19208209999999326, 'cross_answers_gen': 252.44764679999997}, (2.4551177574127343, array([0.12258706, 0.06480402, 0.09153019, ..., 0.00049628, 0.00019851,
       0.00408182])),      MovieID  UserID
0          2     623
1          4     825
2          7      68
3          7     507
4          7     562
..       ...     ...
397      764     402
398      764     933
399      765     564
400      765     927
401      771     525

[402 rows x 2 columns]), ({'cross_answers_gen': 231.14690270000006}, (1.8541574562811791, array([0.00408978, 0.05291164, 0.05019882, ..., 0.00049628, 0.00019851,
       0.006588  ])),      MovieID  UserID
0          0     405
1          0     653
2          2     475
3          5    1154
4          7     374
..       ...     ...
396      768    1146
397      769    1174
398      775     547
399      775     551
400      775     930

[401 rows x 2 columns]), ({'cross_an

In [22]:
# `results` holds complete sweeps back to back, each in the order of
# `epsilons`, so reshaping to (n_sweeps, len(epsilons)) gives one row per
# sweep and one column per epsilon. Using -1 lets the number of sweeps be
# inferred instead of being hard-coded.
ave_errors = np.array([x[1][0] for x in results]).reshape((-1, len(epsilons)))
# Median average error for each epsilon across sweeps. It is printed
# explicitly because only the last expression of a cell is displayed.
median_errors = np.median(ave_errors, axis=0)
print(median_errors)
epsilons

[3, 5, 7, 9, 12, 16]