In [1]:
import numpy as np
import utils
import torch
import pandas as pd

from tqdm import tqdm


%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
device

device(type='cpu')

In [4]:
dfx, dfy, cols = utils.get_dataset('adult_income', return_dataframe=True)

In [5]:
from sklearn.preprocessing import minmax_scale

In [6]:
minmax_separate = False

In [31]:
if minmax_separate:
    dfx_1 = dfx.loc[dfx.gender == 1]
    dfx_0 = dfx.loc[dfx.gender == 0]
    dfx_1[cols] = minmax_scale(dfx_1[cols])
    dfx_0[cols] = minmax_scale(dfx_0[cols])
else:
    dfx[cols] = minmax_scale(dfx[cols])
    dfx_1 = dfx.loc[dfx.gender == 1]
    dfx_0 = dfx.loc[dfx.gender == 0]
    dfy_1 = dfy.loc[dfx.gender == 1]
    dfy_0 = dfy.loc[dfx.gender == 0]

In [32]:
tensor_data_x0 = torch.tensor(dfx_0.values).to(device)
tensor_data_y0 = torch.tensor(dfy_0.values).to(device)
tensor_data_x1 = torch.tensor(dfx_1.values).to(device)
tensor_data_y1 = torch.tensor(dfy_1.values).to(device)

In [33]:
tensor_data_x0.shape

torch.Size([31114, 102])

In [34]:
import utils

In [35]:
INITIAL = 500

In [36]:
warm_start_samples_0 = utils.sample_from_tensor(tensor_data_x0.shape[0], INITIAL, device).long()
warm_start_samples_1 = utils.sample_from_tensor(tensor_data_x1.shape[0], INITIAL, device).long()

In [37]:
non_warm_start_samples_0 = utils.complement_idx(warm_start_samples_0, tensor_data_x0.shape[0])
non_warm_start_samples_1 = utils.complement_idx(warm_start_samples_1, tensor_data_x1.shape[0])

In [38]:
warm_start_x0, warm_start_y0 = tensor_data_x0[warm_start_samples_0], tensor_data_y0[warm_start_samples_0]
warm_start_x1, warm_start_y1 = tensor_data_x1[warm_start_samples_1], tensor_data_y1[warm_start_samples_1]

In [39]:
warm_start_y0.float().mean() - warm_start_y1.float().mean()

tensor(0.1840)

In [40]:
tensor_data_x0, tensor_data_y0 = tensor_data_x0[non_warm_start_samples_0], tensor_data_y0[non_warm_start_samples_0]
tensor_data_x1, tensor_data_y1 = tensor_data_x1[non_warm_start_samples_1], tensor_data_y1[non_warm_start_samples_1]

In [41]:
from blackbox_models import BlackBox

In [42]:
blackbox = BlackBox('Logistic', 102, 1)
blackbox.load_state_dict(torch.load("checkpoints/adult_income/blackbox/Logistic/best.pt", map_location='cpu'))
blackbox.eval()
blackbox.to(device);

In [43]:
k_nearby_points = 100
total_selection = 4000

In [44]:
tensor_data_y0.float().mean() - tensor_data_y1.float().mean()

tensor(0.1994)

In [45]:
warm_start_x0.shape

torch.Size([500, 102])

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class naive_model(nn.Module):
    def __init__(self, feature_dim,device):
        super(naive_model, self).__init__()
        self.device = device
        self.net = nn.Linear(feature_dim, feature_dim)
        self.feature_dim = feature_dim
    def forward(self):
        return self.net(torch.ones(1,self.feature_dim).to(self.device))

In [47]:
nm = naive_model(102, 'cpu')

In [48]:
nm().shape

torch.Size([1, 102])

In [49]:
# k(x ∗ , x ∗ ) + σ 2noise − k(x ∗ , x) > [K + σ 2noise I] −1 k(x ∗ , x)

def compute_variance(x_new,X):
    A = x_new@X.t()
    K =  X@X.t()
    ##k(x_new,x_new)           ## torch.sum(x_new.X[i,:])
    return torch.sum(x_new * x_new)  - A@torch.inverse(K+0.1)@A.t() 

In [50]:
x_new =  nm().float()
X =  warm_start_x0.float()

A = x_new@X.t()
A.shape

torch.Size([1, 500])

First point -> either sample from the first 1000 random points
or sample from the input space

Q = 5000
Qi = 1000

4000 <- (x' <- random)

In [51]:
def compute_variance(x_new,X,device):
    A = x_new@(X.t())
    K =  X@(X.t())
    K_plus_I_inv = torch.inverse(K+0.1*torch.eye(K.shape[0]).to(device))
    return torch.sum(x_new * x_new) +0.1 - A@K_plus_I_inv@A.t() 

def compute_mean(x_new,X,y,device):
    #  k(x ∗ , x) T [K + σ 2noise I] −1 y
    K =  X@(X.t())
    A = x_new@(X.t())
    K_plus_I_inv = torch.inverse(K+0.1*torch.eye(K.shape[0]).float().to(device))
    m = A@K_plus_I_inv@y
    return m

In [52]:
BUDGET = 5000

In [53]:
def compute_disparity(x1, x2):
    return torch.abs(blackbox(x1.float()).mean() - blackbox(x2.float()).mean())

In [54]:
candidates0 = warm_start_x0.clone().float()
candidates1 = warm_start_x1.clone().float()
y0_new = warm_start_y0.clone().float()
y1_new = warm_start_y1.clone().float()

for epoch_outer in tqdm(range((BUDGET - INITIAL)//k_nearby_points)):
    # s = utils.sample_randomly_from_input_space('adult_income').to(device) 
    nm = naive_model(102,device).to(device)
    # optimizer0 = torch.optim.AdamW((s,), lr=1e-5)
    optimizer0 = torch.optim.Adam(nm.parameters(),
                               1e-2,
                               weight_decay=1e-5)
    best_loss = 10e5
    count = 0
    losses = []
    print(f'Disparity @ {epoch_outer}: {compute_disparity(candidates0, candidates1)}')
    
    for epoch in range(1,50):
        optimizer0.zero_grad()
        #x_new = nm().float()
        x_new = nm().float()
        mean_term = compute_mean(x_new,candidates0.float(),y0_new,device) + compute_mean(x_new,candidates1,y1_new,device) # likelihood0(model0(nm())).variance.sum(axis=0)
        abs_mean = torch.abs(mean_term)
        var_term = compute_variance(x_new,candidates0,device) + compute_variance(x_new,candidates1,device) # likelihood0(model0(nm())).variance.sum(axis=0)
        # ll_term = 0 #- lambda_reg * ((s - x0_samples)**2).mean()
        # obj0 =  var_term #+ ll_term
        assert var_term >0
        # loss = abs_mean/var_term
        loss = abs_mean/torch.sqrt(var_term)
        # if var_term < 0:
        #     flagged_X0 = X0
        #     flagged_X1 = X1
        #     flagged_x_new = x_new
        # print("mean, var, sum",abs_mean.item(),var_term.item(),loss.item())
        # print("at epoch", epoch, "we have loss",loss.item())
        #print(loss)
        loss.backward()
        optimizer0.step()
        if loss < best_loss:
            best_loss = loss
            count = 0
            losses.append(loss)
        else:
            count += 1
        if count  == 10:
            print(f'--> Triggering Early stop: current loss: {loss}, best loss: {best_loss}, count: {count}/10')
            break
            
            
            
    
    qx0, qy0, tensor_data_x0, tensor_data_y0 = utils.query_nearby(x_new, tensor_data_x0, tensor_data_y0, k_nearby_points)
    qx1, qy1, tensor_data_x1, tensor_data_y1 = utils.query_nearby(x_new, tensor_data_x1, tensor_data_y1, k_nearby_points)

    candidates0 = torch.cat((candidates0, qx0.float()), dim=0)
    candidates1 = torch.cat((candidates1, qx1.float()), dim=0)
    y0_new = torch.cat((y0_new, qy0.float()), dim=0)
    y1_new = torch.cat((y1_new, qy1.float()), dim=0)
    # if epoch_outer % 1 == 0
    # #new_vals = torch.concatenate(candidates)
    #     new_queries,_, _ = vae0(candidates.float().to(device))
    #     new_labels = blackbox(new_queries)
    #     neg_queried = torch.concatenate([neg_queried, new_queries])
    #     neg_labels = torch.concatenate([neg_labels, (0.5*(torch.sign(new_labels - 0.5) + 1.0)).long().detach().clone()])
    #     model0, likelihood0 = train(neg_queried.detach().clone(),neg_labels.flatten(), device)
    #     model0.eval()
    #     likelihood0.eval()

  2%|▉                                           | 1/45 [00:00<00:06,  7.16it/s]

Disparity @ 0: 0.1871395707130432
--> Triggering Early stop: current loss: tensor([[0.4369]], grad_fn=<DivBackward0>), best loss: tensor([[0.2298]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 1: 0.1484164595603943


  4%|█▉                                          | 2/45 [00:00<00:11,  3.88it/s]

--> Triggering Early stop: current loss: tensor([[0.0943]], grad_fn=<DivBackward0>), best loss: tensor([[0.0200]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 2: 0.16368448734283447


  7%|██▉                                         | 3/45 [00:01<00:16,  2.49it/s]

--> Triggering Early stop: current loss: tensor([[0.0767]], grad_fn=<DivBackward0>), best loss: tensor([[0.0213]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 3: 0.15614007413387299


  9%|███▉                                        | 4/45 [00:01<00:20,  1.99it/s]

--> Triggering Early stop: current loss: tensor([[0.0855]], grad_fn=<DivBackward0>), best loss: tensor([[0.0184]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 4: 0.14818575978279114


 11%|████▉                                       | 5/45 [00:02<00:18,  2.13it/s]

--> Triggering Early stop: current loss: tensor([[0.5171]], grad_fn=<DivBackward0>), best loss: tensor([[0.1158]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 5: 0.13352511823177338


 13%|█████▊                                      | 6/45 [00:02<00:18,  2.12it/s]

--> Triggering Early stop: current loss: tensor([[0.6284]], grad_fn=<DivBackward0>), best loss: tensor([[0.0005]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 6: 0.12678229808807373


 16%|██████▊                                     | 7/45 [00:04<00:34,  1.12it/s]

--> Triggering Early stop: current loss: tensor([[0.0386]], grad_fn=<DivBackward0>), best loss: tensor([[0.0008]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 7: 0.13818728923797607


 18%|███████▊                                    | 8/45 [00:05<00:30,  1.21it/s]

--> Triggering Early stop: current loss: tensor([[0.6563]], grad_fn=<DivBackward0>), best loss: tensor([[0.0574]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 8: 0.13089868426322937


 20%|████████▊                                   | 9/45 [00:05<00:30,  1.19it/s]

--> Triggering Early stop: current loss: tensor([[0.2803]], grad_fn=<DivBackward0>), best loss: tensor([[0.0459]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 9: 0.12036687135696411


 22%|█████████▌                                 | 10/45 [00:09<00:55,  1.60s/it]

--> Triggering Early stop: current loss: tensor([[0.0397]], grad_fn=<DivBackward0>), best loss: tensor([[0.0056]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 10: 0.1145174503326416


 24%|██████████▌                                | 11/45 [00:13<01:18,  2.30s/it]

--> Triggering Early stop: current loss: tensor([[0.0333]], grad_fn=<DivBackward0>), best loss: tensor([[0.0007]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 11: 0.1167323887348175


 27%|███████████▍                               | 12/45 [00:14<01:04,  1.95s/it]

--> Triggering Early stop: current loss: tensor([[0.3097]], grad_fn=<DivBackward0>), best loss: tensor([[0.2768]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 12: 0.10800707340240479


 29%|████████████▍                              | 13/45 [00:18<01:21,  2.56s/it]

--> Triggering Early stop: current loss: tensor([[0.0635]], grad_fn=<DivBackward0>), best loss: tensor([[0.0057]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 13: 0.10402262210845947


 31%|█████████████▍                             | 14/45 [00:19<01:11,  2.30s/it]

--> Triggering Early stop: current loss: tensor([[0.3127]], grad_fn=<DivBackward0>), best loss: tensor([[0.0322]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 14: 0.10836195945739746


 33%|██████████████▎                            | 15/45 [00:21<01:07,  2.23s/it]

--> Triggering Early stop: current loss: tensor([[0.2832]], grad_fn=<DivBackward0>), best loss: tensor([[0.0872]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 15: 0.10347697138786316


 36%|███████████████▎                           | 16/45 [00:24<01:03,  2.21s/it]

--> Triggering Early stop: current loss: tensor([[0.5346]], grad_fn=<DivBackward0>), best loss: tensor([[0.1963]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 16: 0.10532471537590027


 38%|████████████████▏                          | 17/45 [00:32<01:56,  4.15s/it]

--> Triggering Early stop: current loss: tensor([[0.0523]], grad_fn=<DivBackward0>), best loss: tensor([[0.0071]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 17: 0.12034168839454651


 40%|█████████████████▏                         | 18/45 [00:41<02:28,  5.48s/it]

--> Triggering Early stop: current loss: tensor([[0.0373]], grad_fn=<DivBackward0>), best loss: tensor([[0.0060]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 18: 0.112088143825531


 42%|██████████████████▏                        | 19/45 [00:54<03:22,  7.78s/it]

--> Triggering Early stop: current loss: tensor([[0.0436]], grad_fn=<DivBackward0>), best loss: tensor([[0.0009]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 19: 0.12531307339668274


 44%|███████████████████                        | 20/45 [00:57<02:41,  6.47s/it]

--> Triggering Early stop: current loss: tensor([[0.4994]], grad_fn=<DivBackward0>), best loss: tensor([[0.0240]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 20: 0.13514089584350586


 47%|████████████████████                       | 21/45 [01:10<03:22,  8.44s/it]

--> Triggering Early stop: current loss: tensor([[0.0384]], grad_fn=<DivBackward0>), best loss: tensor([[0.0065]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 21: 0.13302266597747803


 49%|█████████████████████                      | 22/45 [01:16<02:51,  7.46s/it]

--> Triggering Early stop: current loss: tensor([[0.4592]], grad_fn=<DivBackward0>), best loss: tensor([[0.1712]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 22: 0.13921180367469788


 51%|█████████████████████▉                     | 23/45 [01:22<02:34,  7.02s/it]

--> Triggering Early stop: current loss: tensor([[0.4596]], grad_fn=<DivBackward0>), best loss: tensor([[0.4160]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 23: 0.13531583547592163


 53%|██████████████████████▉                    | 24/45 [01:39<03:29,  9.97s/it]

--> Triggering Early stop: current loss: tensor([[0.0489]], grad_fn=<DivBackward0>), best loss: tensor([[0.0075]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 24: 0.13388942182064056


 56%|███████████████████████▉                   | 25/45 [01:46<03:05,  9.26s/it]

--> Triggering Early stop: current loss: tensor([[0.3210]], grad_fn=<DivBackward0>), best loss: tensor([[0.2722]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 25: 0.12908664345741272


 58%|████████████████████████▊                  | 26/45 [01:58<03:12, 10.16s/it]

--> Triggering Early stop: current loss: tensor([[0.4040]], grad_fn=<DivBackward0>), best loss: tensor([[0.1885]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 26: 0.1253991723060608


 60%|█████████████████████████▊                 | 27/45 [02:08<03:02, 10.15s/it]

--> Triggering Early stop: current loss: tensor([[0.2992]], grad_fn=<DivBackward0>), best loss: tensor([[0.1861]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 27: 0.12195079028606415


 62%|██████████████████████████▊                | 28/45 [02:19<02:52, 10.16s/it]

--> Triggering Early stop: current loss: tensor([[0.2368]], grad_fn=<DivBackward0>), best loss: tensor([[0.1558]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 28: 0.11963291466236115


 64%|███████████████████████████▋               | 29/45 [02:31<02:53, 10.82s/it]

--> Triggering Early stop: current loss: tensor([[0.3064]], grad_fn=<DivBackward0>), best loss: tensor([[0.0881]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 29: 0.11546732485294342


 67%|████████████████████████████▋              | 30/45 [03:06<04:33, 18.21s/it]

--> Triggering Early stop: current loss: tensor([[0.0407]], grad_fn=<DivBackward0>), best loss: tensor([[0.0075]], grad_fn=<DivBackward0>), count: 10/10
Disparity @ 30: 0.11256153881549835


 67%|████████████████████████████▋              | 30/45 [03:16<01:38,  6.53s/it]


KeyboardInterrupt: 

In [146]:

new_y0 =  blackbox(.float())
new_y1 = blackbox(q1.float())
new_y0.mean()-new_y1.mean()
# warm_start_y0.float().mean() - warm_start_y1.float().mean()

tensor(0.1831, device='cuda:0', grad_fn=<SubBackward0>)

In [117]:
X = flagged_X0
X =flagged_X1
x_new = flagged_x_new
A = x_new@(X.t())
K =  X@(X.t())
##k(x_new,x_new)           ## torch.sum(x_new.X[i,:])
torch.sum(x_new * x_new) +0.1 - A@torch.inverse(K+0.1*torch.eye(K.shape[0]).to(device))@A.t() 

tensor([[80.3355]], device='cuda:0', grad_fn=<SubBackward0>)

In [115]:
torch.sum(x_new * x_new) 
x_new@x_new.t()
K

tensor([[8.2852, 4.3340, 4.4040,  ..., 4.0772, 3.3389, 6.1568],
        [4.3340, 8.5060, 4.5015,  ..., 5.0772, 3.4083, 6.1812],
        [4.4040, 4.5015, 8.6500,  ..., 4.0732, 2.5113, 5.2155],
        ...,
        [4.0772, 5.0772, 4.0732,  ..., 8.0376, 3.0772, 5.0455],
        [3.3389, 3.4083, 2.5113,  ..., 3.0772, 8.4153, 3.1837],
        [6.1568, 6.1812, 5.2155,  ..., 5.0455, 3.1837, 8.0868]],
       device='cuda:0')

In [101]:
var_term

tensor([[3.9061e+16]], device='cuda:0', grad_fn=<AddBackward0>)

In [138]:
for x in model0.parameters():
    print(x.dtype)

torch.float32
torch.float32
torch.float32
torch.float32


In [53]:
model0(torch.zeros((1,102)).to(device))

RuntimeError: You must train on the training inputs!

In [20]:
lambda_reg = 1.0
candidates = []
neg_queried = warm_start_x0.clone()
neg_labels = warm_start_y0.clone().unsqueeze(1)

for epoch_outer in tqdm(range(1, 4001)):
    x0_random = torch.normal(0.,1.,size=(1,102), dtype=torch.float32, requires_grad=True)
    optimizer0 = torch.optim.AdamW((x0_random,), lr=1.0)
    best_loss = 10e5
    count = 0
    losses = []
    for epoch in range(1,100):
        optimizer0.zero_grad()
        x0_samples = utils.postprocess(
            vae0.sample(x0_random.to(device), 100, device, 
                        **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False}).squeeze(1),
            'adult_income'
        )
        obj0 = likelihood0(model0(x0_samples.to(device))).variance.sum(axis=0).mean() - lambda_reg * ((x0_random.to(device) - x0_samples)**2).mean()
#         print(obj0.shape)
        loss = -obj0
        loss.backward()
        optimizer0.step()
        if loss < best_loss:
            best_loss = loss
            count = 0
            losses.append(loss)
        else:
            count += 1
        if count  == 5:
            break
    x0_query = dfx_0.iloc[np.argmin(np.linalg.norm(np.array(dfx_0) - x0_random.detach().clone().numpy()))]
    candidates.append(torch.Tensor([x0_query]).float().to(device))
    
    if epoch_outer % 100 == 0:
        new_vals = torch.concatenate(candidates)
        new_queries,_, _ = vae0(new_vals.to(device), **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False})
        new_labels = blackbox(new_queries)
        neg_queried = torch.concatenate([neg_queried, new_queries])
        neg_labels = torch.concatenate([neg_labels, (0.5*(torch.sign(new_labels - 0.5) + 1.0)).long().detach().clone()])
        print(neg_labels.shape, neg_labels.dtype)
        print(neg_queried.shape, neg_queried.dtype)
        model0, likelihood0 = train(neg_queried.detach().clone(),neg_labels.flatten())
        model0.eval()
        likelihood0.eval()
        candidates = []

  0%|          | 0/4000 [00:00<?, ?it/s]

torch.Size([1100, 1]) torch.int64
torch.Size([1100, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1200, 1]) torch.int64
torch.Size([1200, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1300, 1]) torch.int64
torch.Size([1300, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1400, 1]) torch.int64
torch.Size([1400, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1500, 1]) torch.int64
torch.Size([1500, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1600, 1]) torch.int64
torch.Size([1600, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1700, 1]) torch.int64
torch.Size([1700, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1800, 1]) torch.int64
torch.Size([1800, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([1900, 1]) torch.int64
torch.Size([1900, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2000, 1]) torch.int64
torch.Size([2000, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2100, 1]) torch.int64
torch.Size([2100, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2200, 1]) torch.int64
torch.Size([2200, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2300, 1]) torch.int64
torch.Size([2300, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2400, 1]) torch.int64
torch.Size([2400, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2500, 1]) torch.int64
torch.Size([2500, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2600, 1]) torch.int64
torch.Size([2600, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2700, 1]) torch.int64
torch.Size([2700, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2800, 1]) torch.int64
torch.Size([2800, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([2900, 1]) torch.int64
torch.Size([2900, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3000, 1]) torch.int64
torch.Size([3000, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3100, 1]) torch.int64
torch.Size([3100, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3200, 1]) torch.int64
torch.Size([3200, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3300, 1]) torch.int64
torch.Size([3300, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3400, 1]) torch.int64
torch.Size([3400, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3500, 1]) torch.int64
torch.Size([3500, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3600, 1]) torch.int64
torch.Size([3600, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3700, 1]) torch.int64
torch.Size([3700, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3800, 1]) torch.int64
torch.Size([3800, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([3900, 1]) torch.int64
torch.Size([3900, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4000, 1]) torch.int64
torch.Size([4000, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4100, 1]) torch.int64
torch.Size([4100, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4200, 1]) torch.int64
torch.Size([4200, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4300, 1]) torch.int64
torch.Size([4300, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4400, 1]) torch.int64
torch.Size([4400, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4500, 1]) torch.int64
torch.Size([4500, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4600, 1]) torch.int64
torch.Size([4600, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4700, 1]) torch.int64
torch.Size([4700, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4800, 1]) torch.int64
torch.Size([4800, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([4900, 1]) torch.int64
torch.Size([4900, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([5000, 1]) torch.int64
torch.Size([5000, 102]) torch.float32


  0%|          | 0/50 [00:00<?, ?it/s]

In [21]:
candidates = []
pos_queried = warm_start_x1.clone()
pos_labels = warm_start_y1.clone().unsqueeze(1)
print(pos_labels.shape, pos_labels.dtype)
print(pos_queried.shape, pos_queried.dtype)

for epoch_outer in tqdm(range(1, 4001)):
    x1_random = torch.normal(0.,1.,size=(1,102), dtype=torch.float32, requires_grad=True)
    optimizer1 = torch.optim.AdamW((x1_random,), lr=1.0)
    best_loss = 10e5
    count = 0
    losses = []
    for epoch in range(1,100):
        optimizer1.zero_grad()
        x1_samples = utils.postprocess(
            vae1.sample(x1_random.to(device), 100, device, 
                        **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False}).squeeze(1),
            'adult_income'
        )
        obj1 = likelihood1(model1(x1_samples.to(device))).variance.sum(axis=0).mean() - lambda_reg * ((x1_random.to(device) - x1_samples)**2).mean()
#         print(obj0.shape)
        loss = -obj1
        loss.backward()
        optimizer1.step()
        if loss < best_loss:
            best_loss = loss
            count = 0
            losses.append(loss)
        else:
            count += 1
        if count  == 5:
            break
    x1_query = dfx_1.iloc[np.argmin(np.linalg.norm(np.array(dfx_1) - x1_random.detach().clone().numpy()))]
    candidates.append(torch.Tensor([x1_query]).float().to(device))
    
    if epoch_outer % 100 == 0:
        new_vals = torch.concatenate(candidates)
        new_queries,_ , _ = vae1(new_vals.to(device), **{'tau': 1.0, 'tau_min': 0.1, 'anneal_rate': 3e-5, 'steps': 0, 'hard': False})
        new_labels = blackbox(new_queries)
        pos_queried = torch.concatenate([pos_queried, new_queries])
        pos_labels = torch.concatenate([pos_labels, (0.5*(torch.sign(new_labels - 0.5) + 1.0)).long().detach().clone()])
        model1, likelihood1 = train(pos_queried.detach().clone(),pos_labels.flatten())
        model1.eval()
        likelihood1.eval()
        candidates = []

torch.Size([1000, 1]) torch.int64
torch.Size([1000, 102]) torch.float32


  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [22]:
pos_labels.shape

torch.Size([5000, 1])

In [23]:
new_labels.shape

torch.Size([100, 1])

In [100]:
# parity in data
np.abs(dfy[dfx[dfx.gender == 0].index].mean() - dfy[dfx[dfx.gender == 1].index].mean())

0.19911019753072282

In [25]:
bb_input0 = torch.from_numpy(dfx_0.values).float().to(device)
bb_input1 = torch.from_numpy(dfx_1.values).float().to(device)

In [26]:
y0 = blackbox(bb_input0)
y1 = blackbox(bb_input1)

In [27]:
y0_ = torch.round(y0)
y1_ = torch.round(y1)

In [28]:
y0_.sum()

tensor(7659., device='cuda:0', grad_fn=<SumBackward0>)

In [29]:
torch.abs(y0_.mean() - y1_.mean())

tensor(0.1869, device='cuda:0', grad_fn=<AbsBackward0>)

In [35]:
dp_value = torch.abs(pos_labels.squeeze().float().mean() - neg_labels.squeeze().float().mean())
dp_value

tensor(0.0406, device='cuda:0')

In [36]:
pos_labels.float().mean()

tensor(0.0196, device='cuda:0')

In [37]:
neg_labels.float().mean()

tensor(0.0602, device='cuda:0')

In [39]:
torch.save(dp_value, 'results/adult_income/dp_value_vanilla_vae_reg.pt')